ssd_chunk_scan.py 103 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828
  1. # Copyright (c) 2024, Tri Dao, Albert Gu.
  2. """We want triton==2.1.0 or 2.2.0 for this
  3. """
  4. import math
  5. from packaging import version
  6. import torch
  7. import torch.nn.functional as F
  8. import triton
  9. import triton.language as tl
  10. from einops import rearrange, repeat
  11. try:
  12. from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
  13. except:
  14. from ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
  15. TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
  16. def init_to_zero(names):
  17. return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
  18. @triton.autotune(
  19. configs=[
  20. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
  21. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
  22. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
  23. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
  24. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
  25. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),
  26. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),
  27. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
  28. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
  29. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
  30. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
  31. ],
  32. key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],
  33. )
  34. @triton.jit
  35. def _chunk_scan_fwd_kernel(
  36. # Pointers to matrices
  37. cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr,
  38. # Matrix dimensions
  39. chunk_size, hdim, dstate,
  40. batch, seqlen, nheads_ngroups_ratio,
  41. # Strides
  42. stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
  43. stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
  44. stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,
  45. stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
  46. stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
  47. stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
  48. stride_seq_idx_batch, stride_seq_idx_seqlen,
  49. stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
  50. stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
  51. stride_D_head,
  52. # Meta-parameters
  53. IS_CAUSAL: tl.constexpr,
  54. HAS_D: tl.constexpr,
  55. D_HAS_HDIM: tl.constexpr,
  56. HAS_Z: tl.constexpr,
  57. HAS_SEQ_IDX: tl.constexpr,
  58. BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
  59. BLOCK_SIZE_DSTATE: tl.constexpr,
  60. IS_TRITON_22: tl.constexpr,
  61. ):
  62. pid_bc = tl.program_id(axis=1)
  63. pid_c = pid_bc // batch
  64. pid_b = pid_bc - pid_c * batch
  65. pid_h = tl.program_id(axis=2)
  66. num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
  67. pid_m = tl.program_id(axis=0) // num_pid_n
  68. pid_n = tl.program_id(axis=0) % num_pid_n
  69. cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
  70. x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
  71. dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
  72. dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
  73. C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head
  74. prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
  75. if HAS_SEQ_IDX:
  76. seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
  77. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  78. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  79. dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
  80. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  81. if HAS_SEQ_IDX:
  82. seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
  83. seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
  84. acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
  85. # Without the if (pid_c > -1), with Triton 2.1.0, I get
  86. # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed.
  87. # With Triton 2.2.0, this works
  88. if IS_TRITON_22 or pid_c > -1:
  89. # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
  90. offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
  91. C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate)
  92. prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate)
  93. if not HAS_SEQ_IDX:
  94. scale_m = tl.exp(dA_cs_m)
  95. else:
  96. scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
  97. if BLOCK_SIZE_DSTATE <= 128:
  98. C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0)
  99. prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
  100. prev_states = prev_states.to(C_ptr.dtype.element_ty)
  101. acc = tl.dot(C, prev_states) * scale_m[:, None]
  102. else:
  103. for k in range(0, dstate, BLOCK_SIZE_K):
  104. C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0)
  105. # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
  106. prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
  107. prev_states = prev_states.to(C_ptr.dtype.element_ty)
  108. acc += tl.dot(C, prev_states)
  109. C_ptrs += BLOCK_SIZE_K
  110. prev_states_ptrs += BLOCK_SIZE_K
  111. acc *= scale_m[:, None]
  112. offs_k = tl.arange(0, BLOCK_SIZE_K)
  113. cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
  114. x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
  115. dt_ptrs = dt_ptr + offs_k * stride_dt_csize
  116. dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
  117. K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
  118. for k in range(0, K_MAX, BLOCK_SIZE_K):
  119. cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32)
  120. dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
  121. # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
  122. # So we don't need masking wrt seq_idx here.
  123. cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
  124. dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
  125. cb *= dt_k
  126. if IS_CAUSAL:
  127. mask = offs_m[:, None] >= k + offs_k[None, :]
  128. cb = tl.where(mask, cb, 0.0)
  129. cb = cb.to(x_ptr.dtype.element_ty)
  130. x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0)
  131. acc += tl.dot(cb, x)
  132. cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
  133. x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
  134. dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
  135. dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
  136. offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  137. offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  138. if HAS_D:
  139. if D_HAS_HDIM:
  140. D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
  141. else:
  142. D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
  143. x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
  144. mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
  145. acc += x_residual * D
  146. if HAS_Z:
  147. out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
  148. out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :])
  149. tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))
  150. z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head
  151. z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :])
  152. z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32)
  153. acc *= z * tl.sigmoid(z)
  154. out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
  155. out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim)
  156. tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))
  157. @triton.autotune(
  158. configs=[
  159. # triton.Config({'BLOCK_SIZE_N': 256}, num_stages=4, num_warps=4),
  160. # triton.Config({'BLOCK_SIZE_N': 128}, num_stages=4, num_warps=4),
  161. triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=4),
  162. triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4),
  163. triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=8),
  164. triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=8),
  165. ],
  166. key=['chunk_size', 'hdim', 'dstate'],
  167. )
  168. @triton.jit
  169. def _chunk_scan_fwd_kernel_wip(
  170. # Pointers to matrices
  171. cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, B_ptr, prev_states_ptr, D_ptr,
  172. # Matrix dimensions
  173. chunk_size, hdim, dstate,
  174. batch, seqlen, nheads_ngroups_ratio,
  175. # Strides
  176. stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
  177. stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
  178. stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,
  179. stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
  180. stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
  181. stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
  182. stride_seq_idx_batch, stride_seq_idx_seqlen,
  183. stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
  184. stride_B_batch, stride_B_seqlen, stride_B_head, stride_B_dstate,
  185. stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
  186. stride_D_head,
  187. # Meta-parameters
  188. HAS_D: tl.constexpr,
  189. D_HAS_HDIM: tl.constexpr,
  190. HAS_Z: tl.constexpr,
  191. HAS_SEQ_IDX: tl.constexpr,
  192. BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
  193. BLOCK_SIZE_DSTATE: tl.constexpr,
  194. ):
  195. pid_bc = tl.program_id(axis=1)
  196. pid_c = pid_bc // batch
  197. pid_b = pid_bc - pid_c * batch
  198. pid_h = tl.program_id(axis=2)
  199. pid_n = tl.program_id(axis=0)
  200. cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
  201. x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
  202. dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
  203. dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
  204. C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head
  205. B_ptr += pid_b * stride_B_batch + pid_c * chunk_size * stride_B_seqlen + (pid_h // nheads_ngroups_ratio) * stride_B_head
  206. prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
  207. if HAS_SEQ_IDX:
  208. seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
  209. out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
  210. offs_m = tl.arange(0, BLOCK_SIZE_M)
  211. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  212. offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE)
  213. C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate)
  214. B_ptrs = B_ptr + (offs_m[None, :] * stride_B_seqlen + offs_k_dstate[:, None] * stride_B_dstate)
  215. prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate)
  216. num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
  217. cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_m[None, :] * stride_cb_csize_k)
  218. x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
  219. dt_ptrs = dt_ptr + offs_m * stride_dt_csize
  220. out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim)
  221. prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
  222. # if pid_c == 0:
  223. # if pid_b == 0:
  224. # if pid_h == 0:
  225. # tl.device_print("", prev_states)
  226. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  227. # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
  228. # scale_m = tl.exp(dA_cs_m)
  229. # C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0)
  230. # acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None]
  231. # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_m[None, :] < chunk_size), other=0.0).to(tl.float32)
  232. # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :]))
  233. # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
  234. # cb *= dt_m
  235. # mask = offs_m[:, None] >= offs_m[None, :]
  236. # cb = tl.where(mask, cb, 0.0)
  237. # cb = cb.to(x_ptr.dtype.element_ty)
  238. # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0)
  239. # acc += tl.dot(cb, x)
  240. # if HAS_D:
  241. # if D_HAS_HDIM:
  242. # D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
  243. # else:
  244. # D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
  245. # acc += x.to(tl.float32) * D
  246. # tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
  247. for start_m in range(0, chunk_size_limit, BLOCK_SIZE_M):
  248. start_m = tl.multiple_of(start_m, BLOCK_SIZE_M)
  249. dA_cs_m = tl.load(dA_cumsum_ptr + (start_m + offs_m) * stride_dA_cs_csize, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32)
  250. if HAS_SEQ_IDX:
  251. seq_idx_prev = tl.load(seq_idx_ptr + start_m - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
  252. seq_idx_m = tl.load(seq_idx_ptr + (start_m + offs_m) * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit - start_m, other=-1)
  253. if not HAS_SEQ_IDX:
  254. scale_m = tl.exp(dA_cs_m)
  255. else:
  256. scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
  257. C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_k_dstate[None, :] < dstate), other=0.0)
  258. acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None]
  259. # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size - start_m) & (offs_m[None, :] < chunk_size - start_m), other=0.0).to(tl.float32)
  260. # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :]))
  261. dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32)
  262. # cb *= dt_m
  263. # mask = offs_m[:, None] >= offs_m[None, :]
  264. # cb = tl.where(mask, cb, 0.0)
  265. # cb = cb.to(x_ptr.dtype.element_ty)
  266. x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim), other=0.0)
  267. # acc += tl.dot(cb, x)
  268. if HAS_D:
  269. if D_HAS_HDIM:
  270. D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
  271. else:
  272. D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
  273. acc += x.to(tl.float32) * D
  274. # if HAS_Z:
  275. # out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
  276. # out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :])
  277. # tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))
  278. # z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head
  279. # z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :])
  280. # z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32)
  281. # acc *= z * tl.sigmoid(z)
  282. tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim))
  283. # TODO: this is not correct, and quite a bit slower
  284. if start_m + BLOCK_SIZE_M < chunk_size_limit:
  285. # B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0).to(tl.float32)
  286. B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0)
  287. dA_cs_last = tl.load(dA_cumsum_ptr + (start_m + BLOCK_SIZE_M) * stride_dA_cs_csize).to(tl.float32)
  288. # TODO: seq_idx
  289. scale = tl.exp((dA_cs_last - dA_cs_m)) * dt_m
  290. # B *= scale
  291. B = B.to(x_ptr.dtype.element_ty)
  292. tmp = tl.dot(B, x)
  293. prev_states += tmp.to(prev_states.dtype)
  294. C_ptrs += BLOCK_SIZE_M * stride_C_seqlen
  295. B_ptrs += BLOCK_SIZE_M * stride_B_seqlen
  296. cb_ptrs += BLOCK_SIZE_M * stride_cb_csize_m + BLOCK_SIZE_M * stride_cb_csize_k
  297. x_ptrs += BLOCK_SIZE_M * stride_x_seqlen
  298. dt_ptrs += BLOCK_SIZE_M * stride_dt_csize
  299. out_ptrs += BLOCK_SIZE_M * stride_out_seqlen
  300. @triton.autotune(
  301. configs=[
  302. triton.Config({'BLOCK_SIZE_M': 32}),
  303. triton.Config({'BLOCK_SIZE_M': 64}),
  304. triton.Config({'BLOCK_SIZE_M': 128}),
  305. triton.Config({'BLOCK_SIZE_M': 256}),
  306. ],
  307. key=["chunk_size", "hdim"],
  308. )
  309. @triton.jit
  310. def _chunk_scan_bwd_dz_kernel(
  311. # Pointers to matrices
  312. dout_ptr, out_ptr, z_ptr, x_ptr, D_ptr, outz_ptr, dz_ptr, dout_x_ptr, dD_ptr, ddA_cumsum_ptr,
  313. # Matrix dimensions
  314. chunk_size, hdim,
  315. batch, seqlen,
  316. # Strides
  317. stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
  318. stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
  319. stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,
  320. stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
  321. stride_D_head,
  322. stride_outz_batch, stride_outz_seqlen, stride_outz_head, stride_outz_hdim,
  323. stride_dz_batch, stride_dz_seqlen, stride_dz_head, stride_dz_hdim,
  324. stride_doutx_batch, stride_doutx_seqlen, stride_doutx_head, stride_doutx_hdim,
  325. stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
  326. stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
  327. # Meta-parameters
  328. HAS_D: tl.constexpr,
  329. D_HAS_HDIM: tl.constexpr,
  330. HAS_DDACS: tl.constexpr,
  331. RECOMPUTE_OUTPUT: tl.constexpr,
  332. BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
  333. ):
  334. pid_bc = tl.program_id(axis=1)
  335. pid_c = pid_bc // batch
  336. pid_b = pid_bc - pid_c * batch
  337. pid_h = tl.program_id(axis=2)
  338. pid_m = tl.program_id(axis=0)
  339. dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
  340. dout_x_ptr += pid_b * stride_doutx_batch + pid_c * chunk_size * stride_doutx_seqlen + pid_h * stride_doutx_head
  341. out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
  342. z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head
  343. dz_ptr += pid_b * stride_dz_batch + pid_c * chunk_size * stride_dz_seqlen + pid_h * stride_dz_head
  344. if RECOMPUTE_OUTPUT:
  345. outz_ptr += pid_b * stride_outz_batch + pid_c * chunk_size * stride_outz_seqlen + pid_h * stride_outz_head
  346. if HAS_DDACS:
  347. ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
  348. if HAS_D:
  349. x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
  350. dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize
  351. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  352. offs_n = tl.arange(0, BLOCK_SIZE_N)
  353. dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
  354. dout_x_ptrs = dout_x_ptr + (offs_m[:, None] * stride_doutx_seqlen + offs_n[None, :] * stride_doutx_hdim)
  355. out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim)
  356. z_ptrs = z_ptr + (offs_m[:, None] * stride_z_seqlen + offs_n[None, :] * stride_z_hdim)
  357. dz_ptrs = dz_ptr + (offs_m[:, None] * stride_dz_seqlen + offs_n[None, :] * stride_dz_hdim)
  358. if RECOMPUTE_OUTPUT:
  359. outz_ptrs = outz_ptr + (offs_m[:, None] * stride_outz_seqlen + offs_n[None, :] * stride_outz_hdim)
  360. if HAS_D:
  361. x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
  362. if D_HAS_HDIM:
  363. dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
  364. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  365. dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
  366. out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
  367. z = tl.load(z_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
  368. z_sigmoid = tl.sigmoid(z)
  369. if RECOMPUTE_OUTPUT:
  370. outz = out * z * z_sigmoid
  371. tl.store(outz_ptrs, outz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
  372. dz = dout * out * z_sigmoid * (1 + z * (1 - z_sigmoid))
  373. tl.store(dz_ptrs, dz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
  374. dout *= z * z_sigmoid
  375. tl.store(dout_x_ptrs, dout, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
  376. if HAS_D:
  377. x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
  378. if D_HAS_HDIM:
  379. dD = tl.sum(dout * x, axis=0)
  380. tl.store(dD_ptrs, dD, mask=offs_n < hdim)
  381. D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
  382. else:
  383. dD = tl.sum(dout * x)
  384. tl.store(dD_ptr, dD)
  385. D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
  386. out -= x * D
  387. if HAS_DDACS:
  388. ddA_cs = tl.sum(dout * out, axis=1)
  389. tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size)
  390. @triton.autotune(
  391. configs=[
  392. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
  393. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
  394. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
  395. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
  396. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
  397. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
  398. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
  399. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
  400. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
  401. ],
  402. key=['hdim', 'dstate', 'chunk_size'],
  403. )
  404. @triton.jit
  405. def _chunk_scan_bwd_dstates_kernel(
  406. # Pointers to matrices
  407. dout_ptr, c_ptr, dprev_states_ptr, dA_cumsum_ptr, seq_idx_ptr,
  408. # Matrix dimensions
  409. hdim, dstate, chunk_size,
  410. batch, seqlen, nchunks, nheads_ngroups_ratio,
  411. # Strides
  412. stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
  413. stride_c_batch, stride_c_seqlen, stride_c_head, stride_c_dstate,
  414. stride_dprev_states_batch, stride_dprev_states_chunk, stride_dprev_states_head, stride_dprev_states_hdim, stride_dprev_states_dstate,
  415. stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
  416. stride_seq_idx_batch, stride_seq_idx_seqlen,
  417. # Meta-parameters
  418. HAS_SEQ_IDX: tl.constexpr,
  419. BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
  420. ):
  421. pid_bc = tl.program_id(axis=1)
  422. pid_c = pid_bc // batch
  423. pid_b = pid_bc - pid_c * batch
  424. pid_h = tl.program_id(axis=2)
  425. num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
  426. pid_m = tl.program_id(axis=0) // num_pid_n
  427. pid_n = tl.program_id(axis=0) % num_pid_n
  428. c_ptr += pid_b * stride_c_batch + pid_c * chunk_size * stride_c_seqlen + (pid_h // nheads_ngroups_ratio) * stride_c_head
  429. dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
  430. dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
  431. if HAS_SEQ_IDX:
  432. seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
  433. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  434. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  435. offs_k = tl.arange(0, BLOCK_SIZE_K)
  436. dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_hdim + offs_k[None, :] * stride_dout_seqlen)
  437. c_ptrs = c_ptr + (offs_n[None, :] * stride_c_dstate + offs_k[:, None] * stride_c_seqlen)
  438. dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
  439. if HAS_SEQ_IDX:
  440. seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
  441. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  442. acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
  443. if HAS_SEQ_IDX:
  444. seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
  445. for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
  446. dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0).to(tl.float32)
  447. dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
  448. if not HAS_SEQ_IDX:
  449. scale_k = tl.exp(dA_cs_k)
  450. else:
  451. seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
  452. scale_k = tl.where(seq_idx_k == seq_idx_prev, tl.exp(dA_cs_k), 0.0)
  453. dout = (dout * scale_k).to(dout_ptr.dtype.element_ty)
  454. c = tl.load(c_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0)
  455. acc += tl.dot(dout, c)
  456. dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
  457. c_ptrs += BLOCK_SIZE_K * stride_c_seqlen
  458. dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
  459. if HAS_SEQ_IDX:
  460. seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
  461. out = acc.to(dprev_states_ptr.dtype.element_ty)
  462. dprev_states_ptr += pid_b * stride_dprev_states_batch + pid_c * stride_dprev_states_chunk + pid_h * stride_dprev_states_head
  463. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  464. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  465. dprev_states_ptrs = dprev_states_ptr + (offs_m[:, None] * stride_dprev_states_hdim + offs_n[None, :] * stride_dprev_states_dstate)
  466. tl.store(dprev_states_ptrs, out, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate))
  467. @triton.autotune(
  468. configs=[
  469. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  470. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  471. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  472. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  473. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  474. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  475. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  476. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  477. ],
  478. key=['chunk_size', 'dstate', 'hdim'],
  479. )
  480. @triton.jit
  481. def _chunk_scan_bwd_dc_kernel(
  482. # Pointers to matrices
  483. dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr,
  484. dc_ptr, ddA_cumsum_ptr,
  485. # Matrix dimensions
  486. chunk_size, dstate, hdim,
  487. batch, seqlen, nheads, nheads_per_program, ngroups,
  488. # Strides
  489. stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
  490. stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate,
  491. stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
  492. stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
  493. stride_seq_idx_batch, stride_seq_idx_seqlen,
  494. stride_dc_batch, stride_dc_seqlen, stride_dc_split, stride_dc_group, stride_dc_dstate,
  495. stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
  496. # Meta-parameters
  497. HAS_DDA_CS: tl.constexpr,
  498. HAS_SEQ_IDX: tl.constexpr,
  499. BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
  500. ):
  501. pid_bc = tl.program_id(axis=1)
  502. pid_c = pid_bc // batch
  503. pid_b = pid_bc - pid_c * batch
  504. pid_sg = tl.program_id(axis=2)
  505. pid_s = pid_sg // ngroups
  506. pid_g = pid_sg - pid_s * ngroups
  507. num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
  508. pid_m = tl.program_id(axis=0) // num_pid_n
  509. pid_n = tl.program_id(axis=0) % num_pid_n
  510. dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head
  511. dc_ptr += pid_b * stride_dc_batch + pid_c * chunk_size * stride_dc_seqlen + pid_g * stride_dc_group + pid_s * stride_dc_split
  512. prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_prev_states_head
  513. dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
  514. if HAS_DDA_CS:
  515. C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + pid_g * stride_C_head
  516. ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head
  517. if HAS_SEQ_IDX:
  518. seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
  519. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  520. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  521. offs_k = tl.arange(0, BLOCK_SIZE_K)
  522. dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
  523. prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim)
  524. dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
  525. if HAS_DDA_CS:
  526. C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate)
  527. ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
  528. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  529. acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
  530. if HAS_DDA_CS:
  531. c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
  532. if HAS_SEQ_IDX:
  533. seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
  534. seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
  535. nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)
  536. for h in range(nheads_iter):
  537. dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
  538. prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)
  539. prev_states = prev_states.to(dout_ptrs.dtype.element_ty)
  540. dc = tl.dot(dout, prev_states)
  541. dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
  542. if not HAS_SEQ_IDX:
  543. scale = tl.exp(dA_cs_m)
  544. else:
  545. scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
  546. dc *= scale[:, None]
  547. if HAS_DDA_CS:
  548. ddA_cs = tl.sum(dc * c, axis=1)
  549. tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
  550. acc += dc
  551. dout_ptrs += stride_dout_head
  552. prev_states_ptrs += stride_prev_states_head
  553. dA_cumsum_ptrs += stride_dA_cs_head
  554. if HAS_DDA_CS:
  555. ddA_cumsum_ptrs += stride_ddA_cs_head
  556. # if HAS_SEQ_IDX:
  557. # seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
  558. # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
  559. # acc = tl.where(seq_idx_m[:, None] == seq_idx_prev, acc, 0.0)
  560. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  561. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  562. dc_ptrs = dc_ptr + (offs_m[:, None] * stride_dc_seqlen + offs_n[None, :] * stride_dc_dstate)
  563. tl.store(dc_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate))
  564. @triton.autotune(
  565. configs=[
  566. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
  567. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
  568. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
  569. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
  570. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
  571. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
  572. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
  573. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
  574. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
  575. ],
  576. key=['chunk_size', 'hdim'],
  577. )
  578. @triton.jit
  579. def _chunk_scan_bwd_dx_kernel(
  580. # Pointers to matrices
  581. x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, D_ptr,
  582. dx_ptr, ddt_ptr, # dD_ptr,
  583. # Matrix dimensions
  584. chunk_size, hdim,
  585. batch, seqlen, nheads_ngroups_ratio,
  586. # Strides
  587. stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
  588. stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
  589. stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
  590. stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
  591. stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
  592. stride_D_head,
  593. stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
  594. stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
  595. # stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_hdim, stride_dD_csize,
  596. # Meta-parameters
  597. HAS_D: tl.constexpr,
  598. D_HAS_HDIM: tl.constexpr,
  599. BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
  600. ):
  601. pid_bc = tl.program_id(axis=1)
  602. pid_c = pid_bc // batch
  603. pid_b = pid_bc - pid_c * batch
  604. pid_h = tl.program_id(axis=2)
  605. num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
  606. pid_m = tl.program_id(axis=0) // num_pid_n
  607. pid_n = tl.program_id(axis=0) % num_pid_n
  608. x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
  609. cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
  610. dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
  611. dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
  612. ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
  613. dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
  614. # if HAS_D:
  615. # dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize
  616. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  617. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  618. offs_k = tl.arange(0, BLOCK_SIZE_K)
  619. cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
  620. dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
  621. dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
  622. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  623. dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
  624. acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
  625. # Idk why limiting K_MAX gives wrong results, is it a Triton bug?
  626. # K_MAX = min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
  627. K_MAX = chunk_size_limit
  628. for k in range(0, K_MAX, BLOCK_SIZE_K):
  629. # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
  630. cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
  631. dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
  632. dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
  633. cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
  634. mask = k + offs_k[None, :] >= offs_m[:, None]
  635. cb = tl.where(mask, cb, 0.0)
  636. cb = cb.to(dout_ptr.dtype.element_ty)
  637. acc += tl.dot(cb, dout)
  638. cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
  639. dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
  640. dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
  641. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  642. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  643. dt_ptrs = dt_ptr + offs_m * stride_dt_csize
  644. dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
  645. dx = acc * dt_m[:, None]
  646. dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
  647. dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
  648. if HAS_D:
  649. dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
  650. dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
  651. if D_HAS_HDIM:
  652. D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
  653. else:
  654. D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
  655. dx += dout_res * D
  656. tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
  657. x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
  658. x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
  659. ddt = tl.sum(acc * x, axis=1)
  660. ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
  661. tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
  662. # if HAS_D:
  663. # dout_new_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize + offs_n[None, :] * stride_dout_hdim)
  664. # dout = tl.load(dout_new_ptrs, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0).to(tl.float32)
  665. # dD = tl.sum(x * dout, axis=0)
  666. # tl.store(dD_ptr + offs_n * stride_dD_hdim, dD, mask=offs_n < N)
# Disabling HAS_DDA_CS for now since it's much slower
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8),
        # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8),
        # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8),
        # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8),
    ],
    key=['chunk_size', 'hdim'],
)
# @triton.heuristics({"BLOCK_SIZE_N": lambda args: max(triton.next_power_of_2(args["chunk_size"]), 16)})
# @triton.heuristics({"BLOCK_SIZE_N": lambda args: 32})
@triton.jit
def _chunk_scan_bwd_dcb_kernel(
    # Pointers to matrices
    x_ptr, dout_ptr, cb_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
    dcb_ptr, ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen, nheads, nheads_per_program, ngroups,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_dcb_batch, stride_dcb_chunk, stride_dcb_split, stride_dcb_group, stride_dcb_csize_m, stride_dcb_csize_n,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n,
    # Meta-parameters
    HAS_DDA_CS: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    """Write one (row tile m, column tile n) of dcb for a batch/chunk/group/split.

    Both tile axes index positions inside the chunk. Each program walks up to
    `nheads_per_program` consecutive heads of one group and sums, per head,

        dcb_h[m, n] = (dout_h[m, :] @ x_h[n, :]) * dt_h[n] * exp(dA_cs_h[m] - dA_cs_h[n])

    The head sum is zeroed where seq_idx[m] != seq_idx[n] (HAS_SEQ_IDX) and
    where m < n, then stored to dcb at this program's (group, split) slot.

    With HAS_DDA_CS (not supported together with HAS_SEQ_IDX, see the
    static_assert), a per-head ddA_cumsum contribution is also stored for this
    row tile: the strictly-lower-triangular part of dcb_h * cb is
    cumulative-summed along n, re-masked, and summed over m.

    The hdim reduction is a single BLOCK_SIZE_K-wide dot with no K loop;
    columns beyond `hdim` are masked to zero.
    """
    # Decode the launch grid: axis 1 packs (chunk, batch), axis 2 packs
    # (split, group), axis 0 packs (row tile, column tile).
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_sg = tl.program_id(axis=2)
    pid_s = pid_sg // ngroups
    pid_g = pid_sg - pid_s * ngroups
    num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # Per-head tensors start at the first head owned by this (group, split):
    # pid_g * (nheads // ngroups) + pid_s * nheads_per_program.
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
    if HAS_DDA_CS:
        # cb is indexed per group; ddA_cumsum additionally has one slot per row tile (pid_m).
        cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + pid_g * stride_cb_head
        ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # dout tile is [M, K]; x is read transposed as [K, N] so that dot(dout, x) is [M, N].
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_n * stride_dt_csize
    if HAS_DDA_CS:
        cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)
        ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n

    # Early exit: if the first column of the tile is past its last row, every
    # element has m < n and the final `offs_m >= offs_n` mask would zero it
    # all. Store zeros directly and skip the matmul.
    if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
        dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split
        dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n)
        tl.store(dcb_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dcb_ptr.dtype.element_ty), mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
        return

    # Valid positions in this chunk (the last chunk may be ragged). Columns are
    # further limited to those not beyond the last row of this row tile.
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    if HAS_DDA_CS:
        # cb is shared by all heads of the group, so it is loaded once before the head loop.
        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32)
    # The last split of a group may own fewer than nheads_per_program heads.
    nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)
    for h in range(nheads_iter):
        dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
        x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0)
        dcb = tl.dot(dout, x)
        dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
        # dt_n has one value per column n and broadcasts over the rows of dcb.
        dcb *= dt_n
        dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
        dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32)
        dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
        if HAS_DDA_CS:
            tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet")
            ddA_cs = dcb * cb
            # Keep only the strictly lower triangle (m > n), cumsum along n,
            # re-apply the mask, then reduce over the rows of this tile.
            mask = offs_m[:, None] >= offs_n[None, :] + 1
            ddA_cs = tl.where(mask, ddA_cs, 0.0)
            ddA_cs = tl.cumsum(ddA_cs, axis=1)
            ddA_cs = tl.where(mask, ddA_cs, 0.0)
            ddA_cs = tl.sum(ddA_cs, axis=0)
            # Results are written shifted by one column (n -> n + 1), and the
            # element at the current ddA_cumsum_ptr base is set to zero.
            tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1)
            tl.store(ddA_cumsum_ptr, 0.0)
        acc += dcb
        # Step every per-head stream (including the dA_cumsum base pointer) to the next head.
        dout_ptrs += stride_dout_head
        x_ptrs += stride_x_head
        dt_ptrs += stride_dt_head
        dA_cumsum_ptr += stride_dA_cs_head
        if HAS_DDA_CS:
            ddA_cumsum_ptr += stride_ddA_cs_head
            ddA_cumsum_ptrs += stride_ddA_cs_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    if HAS_SEQ_IDX:
        # Different fill values (-1 vs -2) ensure two out-of-range positions never compare equal.
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
        seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
        acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
    # Keep only the lower triangle including the diagonal (m >= n).
    mask = offs_m[:, None] >= offs_n[None, :]
    acc = tl.where(mask, acc, 0.0)
    dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split
    dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n)
    tl.store(dcb_ptrs, acc, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
  787. # Not numerically stable and should not be used. Leaving here for reference.
  788. @triton.autotune(
  789. configs=[
  790. triton.Config({'BLOCK_SIZE_M': 32}),
  791. triton.Config({'BLOCK_SIZE_M': 64}),
  792. triton.Config({'BLOCK_SIZE_M': 128}),
  793. triton.Config({'BLOCK_SIZE_M': 256}),
  794. ],
  795. key=["chunk_size", "hdim"],
  796. )
  797. @triton.jit
  798. def _chunk_scan_bwd_ddAcs_unstable_kernel(
  799. # Pointers to matrices
  800. dout_ptr, out_ptr, dt_ptr, ddt_ptr, x_ptr, D_ptr,
  801. ddA_cumsum_ptr, dD_ptr,
  802. # Matrix dimensions
  803. chunk_size, hdim,
  804. batch, seqlen,
  805. # Strides
  806. stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
  807. stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
  808. stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
  809. stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
  810. stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
  811. stride_D_head,
  812. stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
  813. stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
  814. # Meta-parameters
  815. HAS_D: tl.constexpr,
  816. D_HAS_HDIM: tl.constexpr,
  817. SUBTRACT_DDTDT: tl.constexpr,
  818. BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
  819. ):
  820. pid_bc = tl.program_id(axis=1)
  821. pid_c = pid_bc // batch
  822. pid_b = pid_bc - pid_c * batch
  823. pid_h = tl.program_id(axis=2)
  824. pid_m = tl.program_id(axis=0)
  825. dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
  826. out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
  827. dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
  828. ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
  829. ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
  830. if HAS_D:
  831. x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
  832. dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize
  833. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  834. offs_n = tl.arange(0, BLOCK_SIZE_N)
  835. dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
  836. out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim)
  837. if HAS_D:
  838. x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
  839. if D_HAS_HDIM:
  840. dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
  841. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  842. dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
  843. out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
  844. if HAS_D:
  845. x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
  846. if D_HAS_HDIM:
  847. dD = tl.sum(dout * x, axis=0)
  848. tl.store(dD_ptrs, dD, mask=offs_n < hdim)
  849. D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
  850. else:
  851. dD = tl.sum(dout * x)
  852. tl.store(dD_ptr, dD)
  853. D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
  854. out -= x * D
  855. ddA_cs = tl.sum(dout * out, axis=1)
  856. if SUBTRACT_DDTDT:
  857. dt = tl.load(dt_ptr + offs_m * stride_dt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
  858. ddt = tl.load(ddt_ptr + offs_m * stride_ddt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
  859. ddA_cs -= dt * ddt
  860. tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size)
  861. @triton.autotune(
  862. configs=[
  863. # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
  864. # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
  865. # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
  866. # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
  867. # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
  868. # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
  869. # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
  870. # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
  871. triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4),
  872. triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4),
  873. triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4),
  874. triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4),
  875. triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8),
  876. triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8),
  877. triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8),
  878. triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8),
  879. ],
  880. key=['chunk_size', 'hdim'],
  881. )
  882. @triton.jit
  883. def _chunk_scan_bwd_ddAcs_stable_kernel_old(
  884. # Pointers to matrices
  885. x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr,
  886. ddAcs_ptr,
  887. # Matrix dimensions
  888. chunk_size, hdim,
  889. batch, seqlen, nheads_ngroups_ratio,
  890. # Strides
  891. stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
  892. stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
  893. stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
  894. stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
  895. stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,
  896. stride_ddAcs_batch, stride_ddAcs_chunk, stride_ddAcs_head, stride_ddAcs_csize_m, stride_ddAcs_csize_n,
  897. # Meta-parameters
  898. BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
  899. ):
  900. pid_bc = tl.program_id(axis=1)
  901. pid_c = pid_bc // batch
  902. pid_b = pid_bc - pid_c * batch
  903. pid_h = tl.program_id(axis=2)
  904. num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
  905. pid_m = tl.program_id(axis=0) // num_pid_n
  906. pid_n = tl.program_id(axis=0) % num_pid_n
  907. x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
  908. dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
  909. dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
  910. dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
  911. cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
  912. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  913. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  914. offs_k = tl.arange(0, BLOCK_SIZE_K)
  915. dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
  916. x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
  917. dt_ptrs = dt_ptr + offs_n * stride_dt_csize
  918. cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)
  919. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  920. chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M)
  921. # Doing a matmul loop with cumsum later on will cause Triton to crash
  922. # Instead we do just one big matmul
  923. # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
  924. # for k in range(0, hdim, BLOCK_SIZE_K):
  925. # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0)
  926. # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0)
  927. # acc += tl.dot(dout, x)
  928. # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim
  929. # x_ptrs += BLOCK_SIZE_K * stride_x_hdim
  930. dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
  931. x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0)
  932. acc = tl.dot(dout, x)
  933. cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32)
  934. acc *= cb
  935. dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
  936. acc *= dt_n
  937. dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
  938. dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
  939. acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
  940. mask = offs_m[:, None] >= offs_n[None, :] + 1
  941. acc = tl.where(mask, acc, 0.0)
  942. acc = tl.cumsum(acc, axis=1)
  943. acc = tl.where(mask, acc, 0.0)
  944. ddA_cs = tl.sum(acc, axis=0)
  945. ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m
  946. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  947. ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n
  948. tl.store(ddAcs_ptrs + stride_ddAcs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1)
  949. tl.store(ddAcs_ptr, 0.0)
  950. # offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, 64)
  951. # offs_k = tl.arange(0, BLOCK_SIZE_K)
  952. # dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
  953. # x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
  954. # dt_ptrs = dt_ptr + offs_n * stride_dt_csize
  955. # cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)
  956. # chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  957. # chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M)
  958. # rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
  959. # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
  960. # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
  961. # ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m
  962. # ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n
  963. # for n in range(0, chunk_size_limit_n, 64):
  964. # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n - n), other=0.0)
  965. # acc = tl.dot(dout, x)
  966. # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - n), other=0.0).to(tl.float32)
  967. # acc *= cb
  968. # dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32)
  969. # acc *= dt_n
  970. # dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32)
  971. # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
  972. # mask = offs_m[:, None] >= offs_n[None, :] + 1 + n
  973. # acc = tl.where(mask, acc, 0.0)
  974. # acc = tl.cumsum(acc, axis=1)
  975. # acc = tl.where(mask, acc, 0.0)
  976. # ddA_cs = tl.sum(acc, axis=0)
  977. # tl.store(ddAcs_ptrs, ddA_cs, mask=offs_n < chunk_size - 1 - n)
  978. # # tl.store(ddAcs_ptr, 0.0)
  979. @triton.autotune(
  980. configs=[
  981. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
  982. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
  983. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
  984. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
  985. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
  986. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
  987. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
  988. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
  989. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
  990. ],
  991. key=['chunk_size', 'hdim'],
  992. )
  993. @triton.jit
  994. def _chunk_scan_bwd_ddAcs_stable_kernel(
  995. # Pointers to matrices
  996. x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr,
  997. ddA_cumsum_ptr,
  998. # Matrix dimensions
  999. chunk_size, hdim,
  1000. batch, seqlen, nheads_ngroups_ratio,
  1001. # Strides
  1002. stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
  1003. stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
  1004. stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
  1005. stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
  1006. stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,
  1007. stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n,
  1008. # Meta-parameters
  1009. BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
  1010. ):
  1011. pid_bc = tl.program_id(axis=1)
  1012. pid_c = pid_bc // batch
  1013. pid_b = pid_bc - pid_c * batch
  1014. pid_h = tl.program_id(axis=2)
  1015. pid_m = tl.program_id(axis=0)
  1016. x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
  1017. dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
  1018. dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
  1019. dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
  1020. cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
  1021. ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m
  1022. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  1023. offs_n = tl.arange(0, BLOCK_SIZE_N)
  1024. offs_k = tl.arange(0, BLOCK_SIZE_K)
  1025. dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
  1026. x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
  1027. dt_ptrs = dt_ptr + offs_n * stride_dt_csize
  1028. cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)
  1029. ddAcs_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n
  1030. tl.store(ddA_cumsum_ptr, 0.0)
  1031. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  1032. rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
  1033. dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
  1034. dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
  1035. # Actually hi is (pid_m + 1) * BLOCK_SIZE_M - 1 but subtracting 1 makes it slower
  1036. lo, hi = 0, (pid_m + 1) * BLOCK_SIZE_M
  1037. # lo, hi = 0, chunk_size
  1038. for start_n in range(lo, hi, BLOCK_SIZE_N):
  1039. start_n = tl.multiple_of(start_n, BLOCK_SIZE_N)
  1040. # Doing a matmul loop with cumsum later on will cause Triton to crash
  1041. # Instead we do just one big matmul
  1042. # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
  1043. # for k in range(0, hdim, BLOCK_SIZE_K):
  1044. # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0)
  1045. # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0)
  1046. # acc += tl.dot(dout, x)
  1047. # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim
  1048. # x_ptrs += BLOCK_SIZE_K * stride_x_hdim
  1049. # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0)
  1050. x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit - start_n), other=0.0)
  1051. acc = tl.dot(dout, x)
  1052. dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
  1053. acc *= dt_n
  1054. # If there's seq_idx, we already zero'ed out cb[i, j] for seq_idx[i] != seq_idx[j]
  1055. cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32)
  1056. acc *= cb
  1057. dA_cs_n = tl.load(dA_cumsum_ptr + start_n + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
  1058. acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
  1059. mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1
  1060. acc = tl.where(mask, acc, 0.0)
  1061. rowsum_new = rowsum + tl.sum(acc, axis=1)
  1062. acc = rowsum[:, None] + tl.cumsum(acc, axis=1)
  1063. rowsum = rowsum_new
  1064. acc = tl.where(mask, acc, 0.0)
  1065. ddA_cs = tl.sum(acc, axis=0)
  1066. tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - start_n - 1)
  1067. x_ptrs += BLOCK_SIZE_N * stride_x_seqlen
  1068. dt_ptrs += BLOCK_SIZE_N * stride_dt_csize
  1069. cb_ptrs += BLOCK_SIZE_N * stride_cb_csize_n
  1070. ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n
  1071. # Need to zero out the rest, since we'll be summing the rows together
  1072. for start_n in range(hi, chunk_size, BLOCK_SIZE_N):
  1073. tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), mask=offs_n < chunk_size - start_n - 1)
  1074. ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n
  1075. @triton.autotune(
  1076. configs=[
  1077. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  1078. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  1079. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  1080. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  1081. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  1082. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  1083. ],
  1084. key=['chunk_size', 'dstate', 'hdim'],
  1085. )
  1086. @triton.jit
  1087. def _chunk_scan_bwd_ddAcs_prev_kernel(
  1088. # Pointers to matrices
  1089. dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr,
  1090. ddA_cumsum_ptr,
  1091. # Matrix dimensions
  1092. chunk_size, dstate, hdim,
  1093. batch, seqlen, nchunks, nheads_ngroups_ratio,
  1094. # Strides
  1095. stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
  1096. stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate,
  1097. stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
  1098. stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
  1099. stride_seq_idx_batch, stride_seq_idx_seqlen,
  1100. stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
  1101. # Meta-parameters
  1102. HAS_SEQ_IDX: tl.constexpr,
  1103. BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
  1104. ):
  1105. pid_bc = tl.program_id(axis=1)
  1106. pid_c = pid_bc // batch
  1107. pid_b = pid_bc - pid_c * batch
  1108. pid_h = tl.program_id(axis=2)
  1109. num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
  1110. pid_m = tl.program_id(axis=0) // num_pid_n
  1111. pid_n = tl.program_id(axis=0) % num_pid_n
  1112. dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
  1113. prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + pid_h * stride_prev_states_head
  1114. C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head
  1115. ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
  1116. dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
  1117. if HAS_SEQ_IDX:
  1118. seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
  1119. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  1120. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  1121. offs_k = tl.arange(0, BLOCK_SIZE_K)
  1122. dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
  1123. prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim)
  1124. C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate)
  1125. dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
  1126. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  1127. dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
  1128. prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)
  1129. prev_states = prev_states.to(dout_ptrs.dtype.element_ty)
  1130. acc = tl.dot(dout, prev_states)
  1131. c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
  1132. ddA_cs = tl.sum(acc * c, axis=1)
  1133. dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
  1134. if not HAS_SEQ_IDX:
  1135. scale = tl.exp(dA_cs_m)
  1136. if HAS_SEQ_IDX:
  1137. seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
  1138. seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
  1139. scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
  1140. ddA_cs *= scale
  1141. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  1142. ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
  1143. tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
  1144. def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None):
  1145. batch, seqlen, nheads, headdim = x.shape
  1146. _, _, nchunks, chunk_size = dt.shape
  1147. _, _, ngroups, dstate = C.shape
  1148. assert nheads % ngroups == 0
  1149. assert C.shape == (batch, seqlen, ngroups, dstate)
  1150. assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
  1151. if z is not None:
  1152. assert z.shape == x.shape
  1153. if D is not None:
  1154. assert D.shape == (nheads, headdim) or D.shape == (nheads,)
  1155. assert dt.shape == (batch, nheads, nchunks, chunk_size)
  1156. assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
  1157. assert states.shape == (batch, nchunks, nheads, headdim, dstate)
  1158. if seq_idx is not None:
  1159. assert seq_idx.shape == (batch, seqlen)
  1160. # Allocates output.
  1161. out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
  1162. if z is not None:
  1163. out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
  1164. assert out_x.stride() == out.stride()
  1165. else:
  1166. out_x = None
  1167. grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
  1168. batch * nchunks, nheads)
  1169. z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))
  1170. if z is not None else (0, 0, 0, 0))
  1171. _chunk_scan_fwd_kernel[grid](
  1172. cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D,
  1173. int(chunk_size), int(headdim), int(dstate),
  1174. int(batch), int(seqlen), int(nheads // ngroups),
  1175. cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
  1176. x.stride(0), x.stride(1), x.stride(2), x.stride(3),
  1177. z_strides[0], z_strides[1], z_strides[2], z_strides[3],
  1178. out.stride(0), out.stride(1), out.stride(2), out.stride(3),
  1179. dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
  1180. dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
  1181. *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
  1182. C.stride(0), C.stride(1), C.stride(2), C.stride(3),
  1183. states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
  1184. D.stride(0) if D is not None else 0,
  1185. True,
  1186. D is not None,
  1187. D.dim() == 2 if D is not None else True,
  1188. BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(int(dstate)), 16),
  1189. HAS_Z=z is not None,
  1190. HAS_SEQ_IDX=seq_idx is not None,
  1191. IS_TRITON_22=TRITON_22,
  1192. )
  1193. return out, out_x
  1194. def _chunk_scan_fwd_wip(cb, x, dt, dA_cumsum, C, B, states, D=None, z=None, seq_idx=None):
  1195. batch, seqlen, nheads, headdim = x.shape
  1196. _, _, nchunks, chunk_size = dt.shape
  1197. _, _, ngroups, dstate = C.shape
  1198. assert nheads % ngroups == 0
  1199. assert C.shape == (batch, seqlen, ngroups, dstate)
  1200. assert B.shape == C.shape
  1201. assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
  1202. if z is not None:
  1203. assert z.shape == x.shape
  1204. if D is not None:
  1205. assert D.shape == (nheads, headdim) or D.shape == (nheads,)
  1206. assert dt.shape == (batch, nheads, nchunks, chunk_size)
  1207. assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
  1208. assert states.shape == (batch, nchunks, nheads, headdim, dstate)
  1209. if seq_idx is not None:
  1210. assert seq_idx.shape == (batch, seqlen)
  1211. # Allocates output.
  1212. out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
  1213. if z is not None:
  1214. out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
  1215. assert out_x.stride() == out.stride()
  1216. else:
  1217. out_x = None
  1218. grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads)
  1219. z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))
  1220. if z is not None else (0, 0, 0, 0))
  1221. _chunk_scan_fwd_kernel_wip[grid](
  1222. cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, B, states, D,
  1223. int(chunk_size), int(headdim), int(dstate),
  1224. int(batch), int(seqlen), int(nheads // ngroups),
  1225. cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
  1226. x.stride(0), x.stride(1), x.stride(2), x.stride(3),
  1227. z_strides[0], z_strides[1], z_strides[2], z_strides[3],
  1228. out.stride(0), out.stride(1), out.stride(2), out.stride(3),
  1229. dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
  1230. dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
  1231. *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
  1232. C.stride(0), C.stride(1), C.stride(2), C.stride(3),
  1233. B.stride(0), B.stride(1), B.stride(2), B.stride(3),
  1234. states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
  1235. D.stride(0) if D is not None else 0,
  1236. D is not None,
  1237. D.dim() == 2 if D is not None else True,
  1238. BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(int(dstate)), 16),
  1239. BLOCK_SIZE_M=128,
  1240. HAS_Z=z is not None,
  1241. HAS_SEQ_IDX=seq_idx is not None,
  1242. )
  1243. return out, out_x
  1244. def _chunk_scan_bwd_dz(x, z, out, dout, chunk_size, has_ddAcs=True, D=None, dz=None, recompute_output=False):
  1245. batch, seqlen, nheads, headdim = x.shape
  1246. assert z.shape == x.shape
  1247. assert out.shape == x.shape
  1248. assert dout.shape == out.shape
  1249. nchunks = math.ceil(seqlen / chunk_size)
  1250. if D is not None:
  1251. assert D.shape == (nheads, headdim) or D.shape == (nheads,)
  1252. assert D.stride(-1) == 1
  1253. if has_ddAcs:
  1254. ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
  1255. if D is not None:
  1256. BLOCK_SIZE_min = 32
  1257. dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
  1258. headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
  1259. else:
  1260. dD = None
  1261. if dz is not None:
  1262. assert dz.shape == z.shape
  1263. else:
  1264. dz = torch.empty_like(z)
  1265. if recompute_output:
  1266. outz = torch.empty_like(x)
  1267. dout_x = torch.empty_like(dout)
  1268. dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
  1269. if D is not None else (0, 0, 0, 0, 0))
  1270. grid_dz = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)
  1271. with torch.cuda.device(x.device.index):
  1272. _chunk_scan_bwd_dz_kernel[grid_dz](
  1273. dout, out, z, x, D, outz if recompute_output else None,
  1274. dz, dout_x, dD, ddA_cumsum if has_ddAcs else None,
  1275. int(chunk_size), int(headdim),
  1276. int(batch), int(seqlen),
  1277. dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
  1278. out.stride(0), out.stride(1), out.stride(2), out.stride(3),
  1279. z.stride(0), z.stride(1), z.stride(2), z.stride(3),
  1280. x.stride(0), x.stride(1), x.stride(2), x.stride(3),
  1281. D.stride(0) if D is not None else 0,
  1282. *((outz.stride(0), outz.stride(1), outz.stride(2), outz.stride(3)) if recompute_output else (0, 0, 0, 0)),
  1283. dz.stride(0), dz.stride(1), dz.stride(2), dz.stride(3),
  1284. dout_x.stride(0), dout_x.stride(1), dout_x.stride(2), dout_x.stride(3),
  1285. dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
  1286. *((ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3))
  1287. if has_ddAcs else (0, 0, 0, 0)),
  1288. D is not None,
  1289. D.dim() == 2 if D is not None else True,
  1290. has_ddAcs,
  1291. BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16),
  1292. RECOMPUTE_OUTPUT=recompute_output,
  1293. )
  1294. if D is not None:
  1295. BLOCK_SIZE_actual = _chunk_scan_bwd_dz_kernel.best_config.kwargs["BLOCK_SIZE_M"]
  1296. n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
  1297. dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
  1298. if D.dim() == 1:
  1299. dD = rearrange(dD, "h 1 -> h")
  1300. return_vals = (dz, dout_x, dD, ddA_cumsum) if has_ddAcs else (dz, dout_x, dD)
  1301. return return_vals if not recompute_output else (*return_vals, outz)
  1302. def _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=None, dtype=None):
  1303. batch, seqlen, nheads, headdim = dout.shape
  1304. _, _, nchunks, chunk_size = dA_cumsum.shape
  1305. _, _, ngroups, dstate = C.shape
  1306. assert nheads % ngroups == 0
  1307. assert C.shape == (batch, seqlen, ngroups, dstate)
  1308. assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
  1309. if seq_idx is not None:
  1310. assert seq_idx.shape == (batch, seqlen)
  1311. dtype = C.dtype if dtype is None else dtype
  1312. dprev_states = torch.empty(batch, nchunks, nheads, headdim, dstate, device=C.device, dtype=dtype)
  1313. grid_dstates = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
  1314. batch * nchunks, nheads)
  1315. with torch.cuda.device(C.device.index):
  1316. _chunk_scan_bwd_dstates_kernel[grid_dstates](
  1317. dout, C, dprev_states, dA_cumsum, seq_idx,
  1318. int(headdim), int(dstate), int(chunk_size),
  1319. int(batch), int(seqlen), int(nchunks), int(nheads // ngroups),
  1320. dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
  1321. C.stride(0), C.stride(1), C.stride(2), C.stride(3),
  1322. dprev_states.stride(0), dprev_states.stride(1), dprev_states.stride(2), dprev_states.stride(3), dprev_states.stride(4),
  1323. dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
  1324. *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
  1325. HAS_SEQ_IDX=seq_idx is not None,
  1326. )
  1327. return dprev_states
  1328. def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngroups=1):
  1329. batch, nchunks, nheads, headdim, dstate = prev_states.shape
  1330. _, seqlen, _, _ = dout.shape
  1331. _, _, _, chunk_size = dA_cumsum.shape
  1332. assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
  1333. assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
  1334. assert dout.shape == (batch, seqlen, nheads, headdim)
  1335. if seq_idx is not None:
  1336. assert seq_idx.shape == (batch, seqlen)
  1337. if C is not None:
  1338. assert C.shape == (batch, seqlen, ngroups, dstate)
  1339. C_strides = (C.stride(0), C.stride(1), C.stride(2), C.stride(3))
  1340. ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)
  1341. ddA_cumsum_prev_strides = (ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3))
  1342. else:
  1343. C_strides = (0, 0, 0, 0)
  1344. ddA_cumsum_prev = None
  1345. ddA_cumsum_prev_strides = (0, 0, 0, 0)
  1346. nheads_ngroups_ratio = nheads // ngroups
  1347. sm_count = torch.cuda.get_device_properties(dout.device).multi_processor_count
  1348. nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1)
  1349. nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
  1350. dC = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=dout.device, dtype=torch.float32)
  1351. grid_dc = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
  1352. batch * nchunks, nsplits * ngroups)
  1353. with torch.cuda.device(dout.device.index):
  1354. _chunk_scan_bwd_dc_kernel[grid_dc](
  1355. dout, prev_states, C, dA_cumsum, seq_idx, dC, ddA_cumsum_prev,
  1356. int(chunk_size), int(dstate), int(headdim),
  1357. int(batch), int(seqlen), int(nheads), int(nheads_per_program), int(ngroups),
  1358. dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
  1359. prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4),
  1360. *C_strides,
  1361. dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
  1362. *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
  1363. dC.stride(0), dC.stride(1), dC.stride(2), dC.stride(3), dC.stride(4),
  1364. *ddA_cumsum_prev_strides,
  1365. HAS_DDA_CS=ddA_cumsum_prev is not None,
  1366. HAS_SEQ_IDX=seq_idx is not None,
  1367. BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
  1368. )
  1369. dC = dC.sum(2)
  1370. return dC if C is None else (dC, ddA_cumsum_prev)
  1371. def _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=None, CB=None, ngroups=1):
  1372. batch, seqlen, nheads, headdim = x.shape
  1373. _, _, nchunks, chunk_size = dt.shape
  1374. assert dt.shape == (batch, nheads, nchunks, chunk_size)
  1375. assert dA_cumsum.shape == dt.shape
  1376. assert dout.shape == x.shape
  1377. if seq_idx is not None:
  1378. assert seq_idx.shape == (batch, seqlen)
  1379. if CB is not None:
  1380. assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
  1381. CB_strides = (CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(3), CB.stride(4))
  1382. BLOCK_SIZE_M_min = 16
  1383. ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),
  1384. chunk_size, device=x.device, dtype=torch.float32)
  1385. ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4))
  1386. else:
  1387. CB_strides = (0, 0, 0, 0, 0)
  1388. ddA_cumsum = None
  1389. ddA_cumsum_strides = (0, 0, 0, 0, 0)
  1390. nheads_ngroups_ratio = nheads // ngroups
  1391. sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
  1392. nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1)
  1393. nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
  1394. dcb = torch.empty(batch, nchunks, nsplits, ngroups, chunk_size, chunk_size, device=x.device, dtype=torch.float32)
  1395. grid_dcb = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
  1396. batch * nchunks, nsplits * ngroups)
  1397. with torch.cuda.device(x.device.index):
  1398. _chunk_scan_bwd_dcb_kernel[grid_dcb](
  1399. x, dout, CB, dt, dA_cumsum, seq_idx, dcb, ddA_cumsum,
  1400. int(chunk_size), int(headdim),
  1401. int(batch), int(seqlen), int(nheads), int(nheads_per_program), int(ngroups),
  1402. x.stride(0), x.stride(1), x.stride(2), x.stride(3),
  1403. dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
  1404. *CB_strides,
  1405. dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
  1406. dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
  1407. *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
  1408. dcb.stride(0), dcb.stride(1), dcb.stride(2), dcb.stride(3), dcb.stride(4), dcb.stride(5),
  1409. *ddA_cumsum_strides,
  1410. HAS_DDA_CS=ddA_cumsum is not None,
  1411. HAS_SEQ_IDX=seq_idx is not None,
  1412. BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
  1413. )
  1414. dcb = dcb.sum(2)
  1415. if ddA_cumsum is not None:
  1416. BLOCK_SIZE_M_actual = _chunk_scan_bwd_dcb_kernel.best_config.kwargs["BLOCK_SIZE_M"]
  1417. n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual
  1418. ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3)
  1419. return dcb if CB is None else (dcb, ddA_cumsum)
  1420. def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None):
  1421. batch, seqlen, nheads, headdim = x.shape
  1422. _, _, nchunks, chunk_size = dt.shape
  1423. ngroups = cb.shape[2]
  1424. assert nheads % ngroups == 0
  1425. assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
  1426. assert dt.shape == (batch, nheads, nchunks, chunk_size)
  1427. assert dA_cumsum.shape == dt.shape
  1428. assert dout.shape == x.shape
  1429. # if D is not None:
  1430. # BLOCK_SIZE_M_min = 32
  1431. # dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_M_min), batch, nchunks, nheads, headdim, device=D.device, dtype=torch.float32)
  1432. # else:
  1433. # dD = None
  1434. dx = torch.empty_like(x)
  1435. ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)
  1436. grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
  1437. batch * nchunks, nheads)
  1438. with torch.cuda.device(x.device.index):
  1439. _chunk_scan_bwd_dx_kernel[grid_dx](
  1440. x, cb, dout, dt, dA_cumsum, D, dx, ddt, # dD,
  1441. int(chunk_size), int(headdim),
  1442. int(batch), int(seqlen), int(nheads // ngroups),
  1443. x.stride(0), x.stride(1), x.stride(2), x.stride(3),
  1444. cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(-1), cb.stride(-2),
  1445. dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
  1446. dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
  1447. dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
  1448. D.stride(0) if D is not None else 0,
  1449. dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
  1450. ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
  1451. # dD.stride(1) if dD is not None else 0, dD.stride(2) if dD is not None else 0, dD.stride(3) if dD is not None else 0, dD.stride(4) if dD is not None else 0, dD.stride(0) if dD is not None else 0,
  1452. D is not None,
  1453. D.dim() == 2 if D is not None else True,
  1454. )
  1455. # if D is not None:
  1456. # BLOCK_SIZE_actual = _chunk_scan_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"]
  1457. # n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
  1458. # dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
  1459. return dx, ddt.to(dtype=dt.dtype)
  1460. def _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtract_ddtdt=True):
  1461. """Not numerically stable and should not be used. Leaving here for reference.
  1462. """
  1463. batch, seqlen, nheads, headdim = x.shape
  1464. _, _, nchunks, chunk_size = dt.shape
  1465. assert dt.shape == (batch, nheads, nchunks, chunk_size)
  1466. assert ddt.shape == dt.shape
  1467. assert out.shape == x.shape
  1468. assert dout.shape == x.shape
  1469. if D is not None:
  1470. assert D.shape == (nheads, headdim) or D.shape == (nheads,)
  1471. ddA_cumsum = torch.empty_like(dt)
  1472. grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)
  1473. if D is not None: # Triton gives wrong results if we write to the same location
  1474. BLOCK_SIZE_min = 32
  1475. dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
  1476. headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
  1477. else:
  1478. dD = None
  1479. dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
  1480. if D is not None else (0, 0, 0, 0, 0))
  1481. with torch.cuda.device(x.device.index):
  1482. _chunk_scan_bwd_ddAcs_unstable_kernel[grid_ddtcs](
  1483. dout, out, dt, ddt, x, D, ddA_cumsum, dD,
  1484. int(chunk_size), int(headdim),
  1485. int(batch), int(seqlen),
  1486. dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
  1487. out.stride(0), out.stride(1), out.stride(2), out.stride(3),
  1488. dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
  1489. ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
  1490. x.stride(0), x.stride(1), x.stride(2), x.stride(3),
  1491. D.stride(0) if D is not None else 0,
  1492. ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3),
  1493. dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
  1494. D is not None,
  1495. D.dim() == 2 if D is not None else True,
  1496. subtract_ddtdt,
  1497. BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16),
  1498. )
  1499. if D is not None:
  1500. BLOCK_SIZE_actual = _chunk_scan_bwd_ddAcs_unstable_kernel.best_config.kwargs["BLOCK_SIZE_M"]
  1501. n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
  1502. dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
  1503. if D.dim() == 1:
  1504. dD = rearrange(dD, "h 1 -> h")
  1505. return ddA_cumsum, dD
  1506. def _chunk_scan_bwd_ddAcs_stable_old(x, dt, dA_cumsum, dout, cb):
  1507. batch, seqlen, nheads, headdim = x.shape
  1508. _, _, nchunks, chunk_size = dt.shape
  1509. assert dt.shape == (batch, nheads, nchunks, chunk_size)
  1510. assert dout.shape == x.shape
  1511. assert dA_cumsum.shape == dt.shape
  1512. ngroups = cb.shape[2]
  1513. assert nheads % ngroups == 0
  1514. assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
  1515. BLOCK_SIZE_M_min = 16
  1516. ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),
  1517. chunk_size, device=x.device, dtype=torch.float32)
  1518. grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)
  1519. with torch.cuda.device(x.device.index):
  1520. _chunk_scan_bwd_ddAcs_stable_kernel_old[grid_ddtcs](
  1521. x, dout, dt, dA_cumsum, cb, ddA_cumsum,
  1522. int(chunk_size), int(headdim),
  1523. int(batch), int(seqlen), int(nheads // ngroups),
  1524. x.stride(0), x.stride(1), x.stride(2), x.stride(3),
  1525. dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
  1526. dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
  1527. dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
  1528. cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
  1529. ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4),
  1530. BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
  1531. BLOCK_SIZE_N=max(triton.next_power_of_2(chunk_size), 16),
  1532. )
  1533. BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel_old.best_config.kwargs["BLOCK_SIZE_M"]
  1534. n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual
  1535. ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3)
  1536. return ddA_cumsum
  1537. def _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, cb):
  1538. batch, seqlen, nheads, headdim = x.shape
  1539. _, _, nchunks, chunk_size = dt.shape
  1540. assert dt.shape == (batch, nheads, nchunks, chunk_size)
  1541. assert dout.shape == x.shape
  1542. assert dA_cumsum.shape == dt.shape
  1543. ngroups = cb.shape[2]
  1544. assert nheads % ngroups == 0
  1545. assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
  1546. BLOCK_SIZE_M_min = 32
  1547. ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),
  1548. chunk_size, device=x.device, dtype=torch.float32)
  1549. grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)
  1550. with torch.cuda.device(x.device.index):
  1551. _chunk_scan_bwd_ddAcs_stable_kernel[grid_ddtcs](
  1552. x, dout, dt, dA_cumsum, cb, ddA_cumsum,
  1553. int(chunk_size), int(headdim),
  1554. int(batch), int(seqlen), int(nheads // ngroups),
  1555. x.stride(0), x.stride(1), x.stride(2), x.stride(3),
  1556. dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
  1557. dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
  1558. dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
  1559. cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
  1560. ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4),
  1561. BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
  1562. )
  1563. BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel.best_config.kwargs["BLOCK_SIZE_M"]
  1564. n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual
  1565. ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3)
  1566. return ddA_cumsum
  1567. def _chunk_scan_bwd_ddAcs_prev(prev_states, C, dout, dA_cumsum, seq_idx=None):
  1568. batch, nchunks, nheads, headdim, dstate = prev_states.shape
  1569. _, seqlen, _, _ = dout.shape
  1570. _, _, _, chunk_size = dA_cumsum.shape
  1571. assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
  1572. assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
  1573. assert dout.shape == (batch, seqlen, nheads, headdim)
  1574. ngroups = C.shape[2]
  1575. assert nheads % ngroups == 0
  1576. assert C.shape == (batch, seqlen, ngroups, dstate)
  1577. if seq_idx is not None:
  1578. assert seq_idx.shape == (batch, seqlen)
  1579. ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)
  1580. grid_ddAcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
  1581. batch * nchunks, nheads)
  1582. with torch.cuda.device(dout.device.index):
  1583. _chunk_scan_bwd_ddAcs_prev_kernel[grid_ddAcs](
  1584. dout, prev_states, C, dA_cumsum, seq_idx, ddA_cumsum_prev,
  1585. int(chunk_size), int(dstate), int(headdim),
  1586. int(batch), int(seqlen), int(nchunks), int(nheads // ngroups),
  1587. dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
  1588. prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4),
  1589. C.stride(0), C.stride(1), C.stride(2), C.stride(3),
  1590. dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
  1591. *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
  1592. ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3),
  1593. HAS_SEQ_IDX=seq_idx is not None,
  1594. BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
  1595. )
  1596. return ddA_cumsum_prev
  1597. class ChunkScanFn(torch.autograd.Function):
  1598. @staticmethod
  1599. def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
  1600. # Check constraints.
  1601. batch, seqlen, nheads, headdim = x.shape
  1602. _, _, ngroups, dstate = B.shape
  1603. assert B.shape == (batch, seqlen, ngroups, dstate)
  1604. _, _, nchunks, chunk_size = dt.shape
  1605. assert seqlen == nchunks * chunk_size
  1606. assert C.shape == B.shape
  1607. if z is not None:
  1608. assert z.shape == x.shape
  1609. if D is not None:
  1610. assert D.shape == (nheads, headdim) or D.shape == (nheads,)
  1611. assert dt.shape == (batch, nheads, nchunks, chunk_size)
  1612. assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
  1613. assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
  1614. if B.stride(-1) != 1:
  1615. B = B.contiguous()
  1616. if C.stride(-1) != 1:
  1617. C = C.contiguous()
  1618. if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous
  1619. x = x.contiguous()
  1620. if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous
  1621. z = z.contiguous()
  1622. if D is not None and D.stride(-1) != 1:
  1623. D = D.contiguous()
  1624. CB = _bmm_chunk_fwd(C, B, chunk_size)
  1625. out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, prev_states, D=D, z=z)
  1626. ctx.save_for_backward(out if z is None else out_x, B, C, CB, x, dt, dA_cumsum, prev_states, D, z)
  1627. return out
  1628. @staticmethod
  1629. def backward(ctx, dout):
  1630. if dout.stride(-1) != 1:
  1631. dout = dout.contiguous()
  1632. out, B, C, CB, x, dt, dA_cumsum, prev_states, D, z = ctx.saved_tensors
  1633. batch, seqlen, nheads, headdim = x.shape
  1634. _, _, nchunks, chunk_size = dt.shape
  1635. _, _, ngroups, dstate = B.shape
  1636. assert dout.shape == (batch, seqlen, nheads, headdim)
  1637. if z is not None:
  1638. dz, dout, dD, ddA_cumsum = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, D=D)
  1639. else:
  1640. dz = None
  1641. dprev_states = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, dtype=prev_states.dtype)
  1642. dC = _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, ngroups=ngroups)
  1643. dC = dC.to(C.dtype)
  1644. dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, ngroups=ngroups)
  1645. dCB = dCB.to(CB.dtype)
  1646. dB = _bmm_chunk_bwd(C, dCB)
  1647. dC = _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC)
  1648. dx, ddt = _chunk_scan_bwd_dx(CB, x, dt, dA_cumsum, dout, D=D)
  1649. # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
  1650. # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
  1651. if z is not None:
  1652. ddA_cumsum -= ddt * dt
  1653. else: # If z is not None, we already calculated ddA_cumsum and dD when computing dz
  1654. ddA_cumsum, dD = _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=D)
  1655. ddA_cumsum = ddA_cumsum.to(dA_cumsum.dtype)
  1656. return dB, dC, dx, ddt, ddA_cumsum, dprev_states, dD, dz
  1657. def chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
  1658. """
  1659. prev_states contains the initial_states at index 0, and the state for the next-to-last chunk at index -1.
  1660. Argument:
  1661. B: (batch, seqlen, ngroups, dstate)
  1662. C: (batch, seqlen, ngroups, dstate)
  1663. x: (batch, seqlen, nheads, headdim)
  1664. dt: (batch, nheads, nchunks, chunk_size)
  1665. dA_cumsum: (batch, nheads, nchunks, chunk_size)
  1666. prev_states: (batch, nchunks, nheads, headdim, dstate)
  1667. D: (nheads, headdim) or (nheads,)
  1668. z: (batch, seqlen, nheads, headdim)
  1669. Return:
  1670. out: (batch, seqlen, nheads, headdim)
  1671. """
  1672. return ChunkScanFn.apply(B, C, x, dt, dA_cumsum, prev_states, D, z)
  1673. def chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
  1674. """
  1675. Argument:
  1676. B: (batch, seqlen, ngroups, dstate)
  1677. C: (batch, seqlen, ngroups, dstate)
  1678. x: (batch, seqlen, nheads, headdim)
  1679. dt: (batch, nheads, nchunks, chunk_size)
  1680. dA_cumsum: (batch, nheads, nchunks, chunk_size)
  1681. prev_states: (batch, nchunks, nheads, headdim, dstate)
  1682. D: (nheads, headdim) or (nheads,)
  1683. z: (batch, seqlen, nheads, headdim)
  1684. Return:
  1685. out: (batch, seqlen, nheads, headdim)
  1686. """
  1687. batch, seqlen, nheads, headdim = x.shape
  1688. _, _, ngroups, dstate = B.shape
  1689. assert B.shape == (batch, seqlen, ngroups, dstate)
  1690. _, _, nchunks, chunk_size = dt.shape
  1691. assert seqlen == nchunks * chunk_size
  1692. assert C.shape == B.shape
  1693. B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
  1694. C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups)
  1695. CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
  1696. rearrange(B, "b (c s) h n -> b c s h n", c=nchunks))
  1697. # (batch, nheads, nchunks, chunksize, chunksize)
  1698. dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
  1699. decay = torch.exp(dt_segment_sum)
  1700. scores_decay = CB * rearrange(decay, "b h c l s -> b c h l s")
  1701. causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
  1702. scores_decay = scores_decay.masked_fill(~causal_mask, 0)
  1703. out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
  1704. rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
  1705. state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
  1706. out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
  1707. prev_states.to(C.dtype)) * state_decay_out
  1708. out = out + out_prev
  1709. out = rearrange(out, "b c l h p -> b (c l) h p")
  1710. if D is not None:
  1711. if D.dim() == 1:
  1712. D = rearrange(D, "h -> h 1")
  1713. out = out + x * D
  1714. return out if z is None else out * F.silu(z)