# ssd_chunk_state.py
# Copyright (c) 2024, Tri Dao, Albert Gu.

"""We want triton==2.1.0 or 2.2.0 for this module."""
  4. import math
  5. import torch
  6. import torch.nn.functional as F
  7. import triton
  8. import triton.language as tl
  9. from einops import rearrange, repeat
  10. def init_to_zero(names):
  11. return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
  12. @triton.autotune(
  13. configs=[
  14. triton.Config({'BLOCK_SIZE_H': 1}),
  15. triton.Config({'BLOCK_SIZE_H': 2}),
  16. triton.Config({'BLOCK_SIZE_H': 4}),
  17. triton.Config({'BLOCK_SIZE_H': 8}),
  18. triton.Config({'BLOCK_SIZE_H': 16}),
  19. triton.Config({'BLOCK_SIZE_H': 32}),
  20. triton.Config({'BLOCK_SIZE_H': 64}),
  21. ],
  22. key=['chunk_size', 'nheads'],
  23. )
  24. @triton.jit
  25. def _chunk_cumsum_fwd_kernel(
  26. # Pointers to matrices
  27. dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,
  28. # Matrix dimension
  29. batch, seqlen, nheads, chunk_size,
  30. dt_min, dt_max,
  31. # Strides
  32. stride_dt_batch, stride_dt_seqlen, stride_dt_head,
  33. stride_A_head,
  34. stride_dt_bias_head,
  35. stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,
  36. stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
  37. # Meta-parameters
  38. DT_SOFTPLUS: tl.constexpr,
  39. HAS_DT_BIAS: tl.constexpr,
  40. BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,
  41. ):
  42. pid_b = tl.program_id(axis=0)
  43. pid_c = tl.program_id(axis=1)
  44. pid_h = tl.program_id(axis=2)
  45. dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
  46. dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
  47. dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
  48. offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
  49. offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
  50. dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
  51. A_ptrs = A_ptr + offs_h * stride_A_head
  52. dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)
  53. dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)
  54. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  55. dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
  56. if HAS_DT_BIAS:
  57. dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
  58. dt += dt_bias[:, None]
  59. if DT_SOFTPLUS:
  60. dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
  61. # As of Triton 2.2.0, tl.clamp is not available yet
  62. # dt = tl.clamp(dt, dt_min, dt_max)
  63. dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
  64. dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
  65. tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
  66. A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
  67. dA = dt * A[:, None]
  68. dA_cs = tl.cumsum(dA, axis=1)
  69. tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
  70. @triton.autotune(
  71. configs=[
  72. triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
  73. triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
  74. triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
  75. triton.Config({'BLOCK_SIZE_H': 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
  76. triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
  77. triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
  78. triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
  79. ],
  80. key=['chunk_size', 'nheads'],
  81. )
  82. @triton.jit
  83. def _chunk_cumsum_bwd_kernel(
  84. # Pointers to matrices
  85. ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr,
  86. ddt_ptr, dA_ptr, ddt_bias_ptr,
  87. # Matrix dimensions
  88. batch, seqlen, nheads, chunk_size,
  89. dt_min, dt_max,
  90. # Strides
  91. stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize,
  92. stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize,
  93. stride_dt_batch, stride_dt_seqlen, stride_dt_head,
  94. stride_A_head,
  95. stride_dt_bias_head,
  96. stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head,
  97. stride_dA_head,
  98. stride_ddt_bias_head,
  99. # Meta-parameters
  100. DT_SOFTPLUS: tl.constexpr,
  101. HAS_DT_BIAS: tl.constexpr,
  102. BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,
  103. ):
  104. pid_b = tl.program_id(axis=0)
  105. pid_c = tl.program_id(axis=1)
  106. pid_h = tl.program_id(axis=2)
  107. ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
  108. ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
  109. dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
  110. ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
  111. offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
  112. offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
  113. ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize)
  114. ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize)
  115. dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
  116. ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen)
  117. A_ptrs = A_ptr + offs_h * stride_A_head
  118. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  119. ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
  120. ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
  121. A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
  122. ddt = ddA * A[:, None] + ddt_out
  123. dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
  124. if HAS_DT_BIAS:
  125. dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
  126. dt += dt_bias[:, None]
  127. if DT_SOFTPLUS:
  128. dt_presoftplus = dt
  129. dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), ddt)
  130. clamp_mask = (dt < dt_min) | (dt > dt_max)
  131. # As of Triton 2.2.0, tl.clamp is not available yet
  132. # dt = tl.clamp(dt, dt_min, dt_max)
  133. dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
  134. dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
  135. ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0)
  136. ddt = tl.where(clamp_mask, 0.0, ddt)
  137. if DT_SOFTPLUS:
  138. ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
  139. tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit))
  140. dA = tl.sum(ddA * dt, axis=1)
  141. tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
  142. if HAS_DT_BIAS:
  143. ddt_bias = tl.sum(ddt, axis=1)
  144. tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads)
  145. @triton.autotune(
  146. configs=[
  147. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
  148. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
  149. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
  150. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
  151. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
  152. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
  153. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
  154. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
  155. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
  156. ],
  157. key=['hdim', 'dstate', 'chunk_size'],
  158. )
  159. @triton.jit
  160. def _chunk_state_fwd_kernel(
  161. # Pointers to matrices
  162. x_ptr, b_ptr, states_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
  163. # Matrix dimensions
  164. hdim, dstate, chunk_size,
  165. batch, seqlen, nheads_ngroups_ratio,
  166. # Strides
  167. stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
  168. stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
  169. stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
  170. stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
  171. stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
  172. stride_seq_idx_batch, stride_seq_idx_seqlen,
  173. # Meta-parameters
  174. HAS_SEQ_IDX: tl.constexpr,
  175. BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
  176. ):
  177. pid_bc = tl.program_id(axis=1)
  178. pid_c = pid_bc // batch
  179. pid_b = pid_bc - pid_c * batch
  180. pid_h = tl.program_id(axis=2)
  181. num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
  182. pid_m = tl.program_id(axis=0) // num_pid_n
  183. pid_n = tl.program_id(axis=0) % num_pid_n
  184. b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
  185. x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
  186. dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
  187. dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
  188. if HAS_SEQ_IDX:
  189. seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
  190. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  191. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  192. offs_k = tl.arange(0, BLOCK_SIZE_K)
  193. x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
  194. b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
  195. dt_ptrs = dt_ptr + offs_k * stride_dt_csize
  196. dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
  197. dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
  198. if HAS_SEQ_IDX:
  199. seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
  200. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  201. if HAS_SEQ_IDX:
  202. seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
  203. acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
  204. for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
  205. x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0)
  206. b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
  207. dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
  208. if HAS_SEQ_IDX:
  209. seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
  210. dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
  211. if not HAS_SEQ_IDX:
  212. scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
  213. else:
  214. scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
  215. b *= scale[:, None]
  216. b = b.to(x_ptr.dtype.element_ty)
  217. acc += tl.dot(x, b)
  218. x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
  219. b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
  220. dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
  221. dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
  222. if HAS_SEQ_IDX:
  223. seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
  224. states = acc.to(states_ptr.dtype.element_ty)
  225. states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
  226. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  227. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  228. states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
  229. c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
  230. tl.store(states_ptrs, states, mask=c_mask)
  231. @triton.autotune(
  232. configs=[
  233. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
  234. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
  235. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
  236. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
  237. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
  238. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
  239. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
  240. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
  241. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
  242. ],
  243. key=['chunk_size', 'hdim', 'dstate'],
  244. )
  245. @triton.jit
  246. def _chunk_state_bwd_dx_kernel(
  247. # Pointers to matrices
  248. x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr,
  249. dx_ptr, ddt_ptr, ddA_cumsum_ptr,
  250. # Matrix dimensions
  251. chunk_size, hdim, dstate,
  252. batch, seqlen, nheads_ngroups_ratio,
  253. # Strides
  254. stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
  255. stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
  256. stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
  257. stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
  258. stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
  259. stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
  260. stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
  261. stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
  262. # Meta-parameters
  263. BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
  264. BLOCK_SIZE_DSTATE: tl.constexpr,
  265. ):
  266. pid_bc = tl.program_id(axis=1)
  267. pid_c = pid_bc // batch
  268. pid_b = pid_bc - pid_c * batch
  269. pid_h = tl.program_id(axis=2)
  270. num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
  271. pid_m = tl.program_id(axis=0) // num_pid_n
  272. pid_n = tl.program_id(axis=0) % num_pid_n
  273. x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
  274. b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
  275. dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head
  276. dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
  277. ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
  278. ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
  279. dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
  280. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  281. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  282. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  283. # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
  284. offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
  285. b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)
  286. dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)
  287. if BLOCK_SIZE_DSTATE <= 128:
  288. b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)
  289. dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
  290. dstates = dstates.to(b_ptr.dtype.element_ty)
  291. acc = tl.dot(b, dstates)
  292. else:
  293. acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
  294. for k in range(0, dstate, BLOCK_SIZE_K):
  295. b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)
  296. dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
  297. dstates = dstates.to(b_ptr.dtype.element_ty)
  298. acc += tl.dot(b, dstates)
  299. b_ptrs += BLOCK_SIZE_K * stride_b_dstate
  300. dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
  301. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  302. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  303. dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
  304. dt_ptrs = dt_ptr + offs_m * stride_dt_csize
  305. dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
  306. dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
  307. dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
  308. acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
  309. x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
  310. x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
  311. ddt = tl.sum(acc * x, axis=1)
  312. ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
  313. tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
  314. ddA_cs = -(ddt * dt_m)
  315. ddA_cs_last = -tl.sum(ddA_cs)
  316. ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
  317. tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
  318. tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
  319. dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
  320. dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
  321. dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
  322. tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
  323. @triton.autotune(
  324. configs=[
  325. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  326. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  327. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  328. triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  329. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  330. triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  331. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  332. triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  333. ],
  334. key=['chunk_size', 'dstate', 'hdim'],
  335. )
  336. @triton.jit
  337. def _chunk_state_bwd_db_kernel(
  338. # Pointers to matrices
  339. x_ptr, dstates_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
  340. db_ptr, ddA_cumsum_ptr,
  341. # Matrix dimensions
  342. chunk_size, dstate, hdim,
  343. batch, seqlen, nheads, nheads_per_program, ngroups,
  344. # Strides
  345. stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
  346. stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
  347. stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
  348. stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
  349. stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
  350. stride_seq_idx_batch, stride_seq_idx_seqlen,
  351. stride_db_batch, stride_db_seqlen, stride_db_split, stride_db_group, stride_db_dstate,
  352. stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
  353. # Meta-parameters
  354. HAS_DDA_CS: tl.constexpr,
  355. HAS_SEQ_IDX: tl.constexpr,
  356. BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
  357. ):
  358. pid_bc = tl.program_id(axis=1)
  359. pid_c = pid_bc // batch
  360. pid_b = pid_bc - pid_c * batch
  361. pid_sg = tl.program_id(axis=2)
  362. pid_s = pid_sg // ngroups
  363. pid_g = pid_sg - pid_s * ngroups
  364. num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
  365. pid_m = tl.program_id(axis=0) // num_pid_n
  366. pid_n = tl.program_id(axis=0) % num_pid_n
  367. x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
  368. db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_g * stride_db_group + pid_s * stride_db_split
  369. dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_states_head
  370. dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
  371. dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
  372. if HAS_DDA_CS:
  373. b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head
  374. ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head
  375. if HAS_SEQ_IDX:
  376. seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
  377. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  378. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  379. offs_k = tl.arange(0, BLOCK_SIZE_K)
  380. x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim)
  381. dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim)
  382. dt_ptrs = dt_ptr + offs_m * stride_dt_csize
  383. dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
  384. if HAS_DDA_CS:
  385. b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate)
  386. ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
  387. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  388. acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
  389. if HAS_DDA_CS:
  390. b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
  391. if HAS_SEQ_IDX:
  392. seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
  393. seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
  394. nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)
  395. for h in range(nheads_iter):
  396. x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
  397. dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)
  398. dstates = dstates.to(x_ptrs.dtype.element_ty)
  399. db = tl.dot(x, dstates)
  400. dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
  401. dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
  402. dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
  403. if not HAS_SEQ_IDX:
  404. scale = tl.exp(dA_cs_last - dA_cs_m)
  405. else:
  406. scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
  407. db *= (scale * dt_m)[:, None]
  408. if HAS_DDA_CS:
  409. # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
  410. ddA_cs = tl.sum(db * b, axis=1)
  411. tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
  412. acc += db
  413. x_ptrs += stride_x_head
  414. dstates_ptrs += stride_states_head
  415. dt_ptrs += stride_dt_head
  416. dA_cumsum_ptr += stride_dA_cs_head
  417. dA_cumsum_ptrs += stride_dA_cs_head
  418. if HAS_DDA_CS:
  419. ddA_cumsum_ptrs += stride_ddA_cs_head
  420. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  421. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  422. # if HAS_SEQ_IDX:
  423. # seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
  424. # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
  425. # acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
  426. db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate)
  427. tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate))
  428. @triton.autotune(
  429. configs=[
  430. # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  431. # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  432. # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  433. # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  434. # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  435. # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  436. # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  437. # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  438. # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  439. triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  440. triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  441. triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  442. triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  443. triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  444. triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  445. triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  446. triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
  447. ],
  448. key=['chunk_size', 'hdim', 'dstate'],
  449. )
  450. @triton.jit
  451. def _chunk_state_bwd_ddAcs_stable_kernel(
  452. # Pointers to matrices
  453. x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
  454. ddA_cumsum_ptr,
  455. # Matrix dimensions
  456. chunk_size, hdim, dstate,
  457. batch, seqlen, nheads_ngroups_ratio,
  458. # Strides
  459. stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
  460. stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
  461. stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
  462. stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
  463. stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
  464. stride_seq_idx_batch, stride_seq_idx_seqlen,
  465. stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
  466. # Meta-parameters
  467. HAS_SEQ_IDX: tl.constexpr,
  468. BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
  469. BLOCK_SIZE_DSTATE: tl.constexpr,
  470. ):
  471. pid_bc = tl.program_id(axis=1)
  472. pid_c = pid_bc // batch
  473. pid_b = pid_bc - pid_c * batch
  474. pid_h = tl.program_id(axis=2)
  475. num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
  476. pid_m = tl.program_id(axis=0) // num_pid_n
  477. pid_n = tl.program_id(axis=0) % num_pid_n
  478. x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
  479. b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
  480. dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head
  481. dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
  482. ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
  483. dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
  484. if HAS_SEQ_IDX:
  485. seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
  486. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  487. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  488. chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
  489. # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
  490. offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
  491. b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)
  492. dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)
  493. if BLOCK_SIZE_DSTATE <= 128:
  494. b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)
  495. dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
  496. dstates = dstates.to(b_ptr.dtype.element_ty)
  497. acc = tl.dot(b, dstates)
  498. else:
  499. acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
  500. for k in range(0, dstate, BLOCK_SIZE_K):
  501. b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)
  502. dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
  503. dstates = dstates.to(b_ptr.dtype.element_ty)
  504. acc += tl.dot(b, dstates)
  505. b_ptrs += BLOCK_SIZE_K * stride_b_dstate
  506. dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
  507. offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  508. offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  509. dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
  510. dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
  511. if not HAS_SEQ_IDX:
  512. scale = tl.exp(dA_cs_last - dA_cs_m)
  513. else:
  514. seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
  515. seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
  516. scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
  517. acc *= scale[:, None]
  518. x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
  519. x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
  520. dt_ptrs = dt_ptr + offs_m * stride_dt_csize
  521. dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
  522. ddt = tl.sum(acc * x, axis=1)
  523. # ddA_cs = -(ddt * dt_m)
  524. # Triton 2.2.0 errors if we have the cumsum here, so we just write it out
  525. # then call torch.cumsum outside this kernel.
  526. # ddA_cs = tl.cumsum(ddt * dt_m)
  527. ddA_cs = ddt * dt_m
  528. ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
  529. # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
  530. tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
  def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
      """Allocate the chunked output buffers and launch ``_chunk_cumsum_fwd_kernel``.

      Argument:
          dt: (batch, seqlen, nheads)
          A: (nheads,)
          chunk_size: chunk length; the sequence is split into ceil(seqlen / chunk_size) chunks
          dt_bias: optional (nheads,) tensor, forwarded to the kernel
          dt_softplus: flag forwarded to the kernel
          dt_limit: pair whose two entries are forwarded to the kernel
      Return:
          dA_cumsum: (batch, nheads, nchunks, chunk_size), float32
          dt_out: (batch, nheads, nchunks, chunk_size), float32
      """
      batch, seqlen, nheads = dt.shape
      assert A.shape == (nheads,)
      has_dt_bias = dt_bias is not None
      if has_dt_bias:
          assert dt_bias.shape == (nheads,)
      nchunks = math.ceil(seqlen / chunk_size)
      chunked_shape = (batch, nheads, nchunks, chunk_size)
      dt_out = torch.empty(chunked_shape, device=dt.device, dtype=torch.float32)
      dA_cumsum = torch.empty(chunked_shape, device=dt.device, dtype=torch.float32)

      def swapped_strides(t):
          # Strides of a 4-D buffer, handed over with dims 1 and 2 exchanged.
          return (t.stride(0), t.stride(2), t.stride(1), t.stride(3))

      def grid_chunk_cs(META):
          return (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))

      dt_bias_stride = dt_bias.stride(0) if has_dt_bias else 0
      with torch.cuda.device(dt.device.index):
          _chunk_cumsum_fwd_kernel[grid_chunk_cs](
              dt, A, dt_bias, dt_out, dA_cumsum,
              int(batch), int(seqlen), int(nheads), int(chunk_size),
              dt_limit[0], dt_limit[1],
              dt.stride(0), dt.stride(1), dt.stride(2),
              A.stride(0),
              dt_bias_stride,
              *swapped_strides(dt_out),
              *swapped_strides(dA_cumsum),
              dt_softplus,
              HAS_DT_BIAS=has_dt_bias,
              BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
          )
      return dA_cumsum, dt_out
  def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), ddt=None):
      """Allocate the gradient buffers and launch ``_chunk_cumsum_bwd_kernel``.

      Argument:
          ddA: (batch, nheads, nchunks, chunk_size)
          ddt_out: (batch, nheads, nchunks, chunk_size)
          dt: (batch, seqlen, nheads)
          A: (nheads,)
          dt_bias: optional (nheads,) tensor, forwarded to the kernel
          dt_softplus: flag forwarded to the kernel
          dt_limit: pair whose two entries are forwarded to the kernel
          ddt: optional preallocated buffer with the same shape as dt
      Return:
          ddt: same shape as dt (the buffer passed in, or a fresh ``empty_like(dt)``)
          dA: float32 tensor shaped like A
          ddt_bias: float32 tensor shaped like dt_bias, or None when dt_bias is None
      """
      batch, seqlen, nheads = dt.shape
      _, _, nchunks, chunk_size = ddA.shape
      chunked_shape = (batch, nheads, nchunks, chunk_size)
      assert ddA.shape == chunked_shape
      assert ddt_out.shape == chunked_shape
      assert A.shape == (nheads,)
      has_dt_bias = dt_bias is not None
      ddt_bias = None
      if has_dt_bias:
          assert dt_bias.shape == (nheads,)
          ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
      if ddt is None:
          ddt = torch.empty_like(dt)
      else:
          assert ddt.shape == dt.shape
      # NOTE(review): dA / ddt_bias are allocated uninitialized; presumably the kernel
      # (or its autotune pre-hook) zeroes them before accumulating -- confirm.
      dA = torch.empty_like(A, dtype=torch.float32)

      def swapped_strides(t):
          # Strides of a 4-D buffer, handed over with dims 1 and 2 exchanged.
          return (t.stride(0), t.stride(2), t.stride(1), t.stride(3))

      def grid_chunk_cs(META):
          return (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))

      dt_bias_stride = dt_bias.stride(0) if has_dt_bias else 0
      ddt_bias_stride = ddt_bias.stride(0) if ddt_bias is not None else 0
      with torch.cuda.device(dt.device.index):
          _chunk_cumsum_bwd_kernel[grid_chunk_cs](
              ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias,
              int(batch), int(seqlen), int(nheads), int(chunk_size),
              dt_limit[0], dt_limit[1],
              *swapped_strides(ddA),
              *swapped_strides(ddt_out),
              dt.stride(0), dt.stride(1), dt.stride(2),
              A.stride(0),
              dt_bias_stride,
              ddt.stride(0), ddt.stride(1), ddt.stride(2),
              dA.stride(0),
              ddt_bias_stride,
              dt_softplus,
              HAS_DT_BIAS=has_dt_bias,
              BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
          )
      return ddt, dA, ddt_bias
  def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True):
      """Validate shapes, allocate ``states`` if needed, and launch ``_chunk_state_fwd_kernel``.

      Argument:
          B: (batch, seqlen, ngroups, dstate)
          x: (batch, seqlen, nheads, headdim)
          dt: (batch, nheads, nchunks, chunk_size)
          dA_cumsum: (batch, nheads, nchunks, chunk_size)
          seq_idx: optional (batch, seqlen) tensor
          states: optional preallocated (batch, nchunks, nheads, headdim, dstate) output
          states_in_fp32: when ``states`` is not given, allocate it as float32 (True)
              or with B's dtype (False)
      Return:
          states: (batch, nchunks, nheads, headdim, dstate)
      """
      batch, seqlen, nheads, headdim = x.shape
      _, _, nchunks, chunk_size = dt.shape
      _, _, ngroups, dstate = B.shape
      assert nheads % ngroups == 0
      assert B.shape == (batch, seqlen, ngroups, dstate)
      assert dt.shape == (batch, nheads, nchunks, chunk_size)
      assert dA_cumsum.shape == dt.shape
      has_seq_idx = seq_idx is not None
      if has_seq_idx:
          assert seq_idx.shape == (batch, seqlen)
      states_shape = (batch, nchunks, nheads, headdim, dstate)
      if states is None:
          states_dtype = torch.float32 if states_in_fp32 else B.dtype
          states = torch.empty(states_shape, device=x.device, dtype=states_dtype)
      else:
          assert states.shape == states_shape

      def swapped_strides(t):
          # Strides of a 4-D buffer, handed over with dims 1 and 2 exchanged.
          return (t.stride(0), t.stride(2), t.stride(1), t.stride(3))

      def grid(META):
          tiles = triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N'])
          return (tiles, batch * nchunks, nheads)

      seq_idx_strides = (seq_idx.stride(0), seq_idx.stride(1)) if has_seq_idx else (0, 0)
      with torch.cuda.device(x.device.index):
          _chunk_state_fwd_kernel[grid](
              x, B, states, dt, dA_cumsum, seq_idx,
              int(headdim), int(dstate), int(chunk_size),
              int(batch), int(seqlen), int(nheads // ngroups),
              x.stride(0), x.stride(1), x.stride(2), x.stride(3),
              B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
              states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
              *swapped_strides(dt),
              *swapped_strides(dA_cumsum),
              *seq_idx_strides,
              HAS_SEQ_IDX=has_seq_idx,
          )
      return states
  def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
      """Launch ``_chunk_state_bwd_dx_kernel`` to produce dx, ddt and ddA_cumsum.

      Argument:
          B: (batch, seqlen, ngroups, dstate)
          x: (batch, seqlen, nheads, headdim)
          dt: (batch, nheads, nchunks, chunk_size)
          dA_cumsum: (batch, nheads, nchunks, chunk_size)
          dstates: (batch, nchunks, nheads, headdim, dstate)
          dx: optional preallocated buffer with the same shape as x
      Return:
          dx: same shape as x
          ddt: (batch, nheads, nchunks, chunk_size), cast to dt's dtype
          ddA_cumsum: (batch, nheads, nchunks, chunk_size), cast to dA_cumsum's dtype
      """
      batch, seqlen, nheads, headdim = x.shape
      _, _, nchunks, chunk_size = dt.shape
      _, _, ngroups, dstate = B.shape
      assert nheads % ngroups == 0
      assert B.shape == (batch, seqlen, ngroups, dstate)
      assert dt.shape == (batch, nheads, nchunks, chunk_size)
      assert dA_cumsum.shape == dt.shape
      assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
      if dx is None:
          dx = torch.empty_like(x)
      else:
          assert dx.shape == x.shape
      # NOTE(review): both accumulators are allocated uninitialized; presumably the
      # kernel (or its autotune pre-hook) zeroes them -- confirm against the kernel.
      chunked_shape = (batch, nheads, nchunks, chunk_size)
      ddt = torch.empty(chunked_shape, device=dt.device, dtype=torch.float32)
      ddA_cumsum = torch.empty(chunked_shape, device=dA_cumsum.device, dtype=torch.float32)

      def swapped_strides(t):
          # Strides of a 4-D buffer, handed over with dims 1 and 2 exchanged.
          return (t.stride(0), t.stride(2), t.stride(1), t.stride(3))

      def grid_dx(META):
          tiles = triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N'])
          return (tiles, batch * nchunks, nheads)

      with torch.cuda.device(x.device.index):
          _chunk_state_bwd_dx_kernel[grid_dx](
              x, B, dstates, dt, dA_cumsum, dx, ddt, ddA_cumsum,
              int(chunk_size), int(headdim), int(dstate),
              int(batch), int(seqlen), int(nheads // ngroups),
              x.stride(0), x.stride(1), x.stride(2), x.stride(3),
              B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
              dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
              *swapped_strides(dt),
              *swapped_strides(dA_cumsum),
              dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
              *swapped_strides(ddt),
              *swapped_strides(ddA_cumsum),
              BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
          )
      ddt_cast = ddt.to(dt.dtype)
      ddA_cumsum_cast = ddA_cumsum.to(dA_cumsum.dtype)
      return dx, ddt_cast, ddA_cumsum_cast
  def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
      """Launch ``_chunk_state_bwd_db_kernel`` to compute dB (and optionally ddA_cumsum).

      Argument:
          x: (batch, seqlen, nheads, headdim)
          dt: (batch, nheads, nchunks, chunk_size)
          dA_cumsum: (batch, nheads, nchunks, chunk_size)
          dstates: (batch, nchunks, nheads, headdim, dstate)
          seq_idx: optional (batch, seqlen) tensor
          B: optional (batch, seqlen, ngroups, dstate); when given, ddA_cumsum is also computed
          ngroups: number of B groups; must divide nheads
      Return:
          dB: (batch, seqlen, ngroups, dstate), float32, when B is None
          (dB, ddA_cumsum) when B is given, with ddA_cumsum of shape
              (batch, nheads, nchunks, chunk_size), float32, cumulatively summed
              along the last dimension
      """
      batch, seqlen, nheads, headdim = x.shape
      _, _, nchunks, chunk_size = dt.shape
      dstate = dstates.shape[-1]
      # Same guard as the sibling launchers: without it nheads // ngroups below would
      # silently truncate and the trailing heads would never be visited.
      assert nheads % ngroups == 0
      assert dt.shape == (batch, nheads, nchunks, chunk_size)
      assert dA_cumsum.shape == dt.shape
      assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
      if seq_idx is not None:
          assert seq_idx.shape == (batch, seqlen)
      if B is not None:
          assert B.shape == (batch, seqlen, ngroups, dstate)
          B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
          # Use torch.empty since the Triton kernel will call init_to_zero
          ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
          ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3))
      else:
          B_strides = (0, 0, 0, 0)
          ddA_cumsum = None
          ddA_cumsum_strides = (0, 0, 0, 0)
      nheads_ngroups_ratio = nheads // ngroups
      sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
      # Heads handled per program: total (batch * nchunks * nheads) work divided by the
      # SM count, capped at the heads-per-group ratio and never below 1.
      nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1)
      nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
      # Each split writes its own slice of dB; the splits are summed after the launch.
      dB = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32)
      grid_db = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
                              batch * nchunks, nsplits * ngroups)
      with torch.cuda.device(x.device.index):
          _chunk_state_bwd_db_kernel[grid_db](
              x, dstates, B, dt, dA_cumsum, seq_idx, dB, ddA_cumsum,
              int(chunk_size), int(dstate), int(headdim),
              int(batch), int(seqlen), int(nheads), int(nheads_per_program), int(ngroups),
              x.stride(0), x.stride(1), x.stride(2), x.stride(3),
              dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
              *B_strides,
              dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
              dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
              *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
              dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), dB.stride(4),
              *ddA_cumsum_strides,
              HAS_DDA_CS=ddA_cumsum is not None,
              HAS_SEQ_IDX=seq_idx is not None,
              BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
          )
      dB = dB.sum(2)
      if ddA_cumsum is not None:
          # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
          # to the state of the chunk.
          # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
          # But it's easier to just do the cumsum for all elements, the result will be the same.
          torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
      return dB if B is None else (dB, ddA_cumsum)
  def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
      """Launch ``_chunk_state_bwd_ddAcs_stable_kernel`` and finish ddA_cumsum with a cumsum.

      Argument:
          B: (batch, seqlen, ngroups, dstate)
          x: (batch, seqlen, nheads, headdim)
          dt: (batch, nheads, nchunks, chunk_size)
          dA_cumsum: (batch, nheads, nchunks, chunk_size)
          dstates: (batch, nchunks, nheads, headdim, dstate)
          seq_idx: optional (batch, seqlen) tensor
      Return:
          ddA_cumsum: (batch, nheads, nchunks, chunk_size), float32
      """
      batch, seqlen, nheads, headdim = x.shape
      _, _, nchunks, chunk_size = dt.shape
      _, _, ngroups, dstate = B.shape
      assert nheads % ngroups == 0
      assert B.shape == (batch, seqlen, ngroups, dstate)
      assert dt.shape == (batch, nheads, nchunks, chunk_size)
      assert dA_cumsum.shape == dt.shape
      assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
      has_seq_idx = seq_idx is not None
      if has_seq_idx:
          assert seq_idx.shape == (batch, seqlen)
      # Use torch.empty since the Triton kernel will call init_to_zero
      ddA_cumsum = torch.empty((batch, nheads, nchunks, chunk_size), device=x.device, dtype=torch.float32)

      def swapped_strides(t):
          # Strides of a 4-D buffer, handed over with dims 1 and 2 exchanged.
          return (t.stride(0), t.stride(2), t.stride(1), t.stride(3))

      def grid_ddtcs(META):
          tiles = triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N'])
          return (tiles, batch * nchunks, nheads)

      seq_idx_strides = (seq_idx.stride(0), seq_idx.stride(1)) if has_seq_idx else (0, 0)
      with torch.cuda.device(x.device.index):
          _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
              x, B, dstates, dt, dA_cumsum, seq_idx, ddA_cumsum,
              int(chunk_size), int(headdim), int(dstate),
              int(batch), int(seqlen), int(nheads // ngroups),
              x.stride(0), x.stride(1), x.stride(2), x.stride(3),
              B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
              dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
              *swapped_strides(dt),
              *swapped_strides(dA_cumsum),
              *seq_idx_strides,
              *swapped_strides(ddA_cumsum),
              HAS_SEQ_IDX=has_seq_idx,
              BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
              BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
          )
      # In-place cumulative sum over every position except the first of each chunk.
      tail = ddA_cumsum[..., 1:]
      torch.cumsum(tail, dim=-1, out=tail)
      return ddA_cumsum
  class ChunkStateFn(torch.autograd.Function):
      """Autograd function: forward calls ``_chunk_state_fwd``; backward calls
      ``_chunk_state_bwd_dx`` and ``_chunk_state_bwd_db``."""

      @staticmethod
      def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
          """Compute the chunk states and save (B, x, dt, dA_cumsum) for backward."""
          batch, seqlen, nheads, headdim = x.shape
          _, _, nchunks, chunk_size = dt.shape
          assert seqlen <= nchunks * chunk_size
          _, _, ngroups, dstate = B.shape
          chunked_shape = (batch, nheads, nchunks, chunk_size)
          assert B.shape == (batch, seqlen, ngroups, dstate)
          assert dt.shape == chunked_shape
          assert dA_cumsum.shape == chunked_shape
          if B.stride(-1) != 1:
              B = B.contiguous()
          # Either M or K dimension should be contiguous
          if 1 not in (x.stride(-1), x.stride(1)):
              x = x.contiguous()
          states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
          ctx.save_for_backward(B, x, dt, dA_cumsum)
          return states

      @staticmethod
      def backward(ctx, dstates):
          """Return gradients for (B, x, dt, dA_cumsum); ``states_in_fp32`` gets None."""
          B, x, dt, dA_cumsum = ctx.saved_tensors
          batch, _, nheads, headdim = x.shape
          _, _, nchunks, _ = dt.shape
          _, _, ngroups, dstate = B.shape
          assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
          if dstates.stride(-1) != 1:
              dstates = dstates.contiguous()
          dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
          # B is not passed here, so only dB comes back; cast it to B's dtype.
          dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups).to(B.dtype)
          return dB, dx, ddt, ddA_cumsum, None
  def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
      """
      Differentiable entry point: applies ``ChunkStateFn`` to the arguments.

      Argument:
          B: (batch, seqlen, ngroups, dstate)
          x: (batch, seqlen, nheads, headdim)
          dt: (batch, nheads, nchunks, chunk_size)
          dA_cumsum: (batch, nheads, nchunks, chunk_size)
          states_in_fp32: passed through to ``ChunkStateFn``
      Return:
          states: (batch, nchunks, nheads, headdim, dstate)
      """
      return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
  def chunk_state_ref(B, x, dt, dA_cumsum):
      """
      Compute the chunk states with plain PyTorch/einops ops (no Triton kernel).

      Argument:
          B: (batch, seqlen, ngroups, dstate)
          x: (batch, seqlen, nheads, headdim)
          dt: (batch, nheads, nchunks, chunk_size)
          dA_cumsum: (batch, nheads, nchunks, chunk_size)
      Return:
          states: (batch, nchunks, nheads, headdim, dstate)
      """
      # Check constraints.
      batch, seqlen, nheads, headdim = x.shape
      dstate = B.shape[-1]
      _, _, nchunks, chunk_size = dt.shape
      assert seqlen <= nchunks * chunk_size
      assert x.shape == (batch, seqlen, nheads, headdim)
      assert dt.shape == (batch, nheads, nchunks, chunk_size)
      ngroups = B.shape[2]
      assert nheads % ngroups == 0
      assert B.shape == (batch, seqlen, ngroups, dstate)
      # Expand B from one entry per group to one entry per head.
      B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
      assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
      # Zero-pad the sequence dimension so it splits evenly into nchunks chunks.
      pad_len = nchunks * chunk_size - seqlen
      if pad_len > 0:
          x = F.pad(x, (0, 0, 0, 0, 0, pad_len))
          B = F.pad(B, (0, 0, 0, 0, 0, pad_len))
      x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
      B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
      # Decay from each position to the last position of its chunk.
      decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
      return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x)