ssd_combined.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987
  1. # Copyright (c) 2024, Tri Dao, Albert Gu.
  2. """We want triton==2.1.0 or 2.2.0 for this
  3. """
  4. from typing import Optional
  5. import math
  6. from packaging import version
  7. import torch
  8. import torch.nn.functional as F
  9. from torch import Tensor
  10. from torch.cuda.amp import custom_bwd, custom_fwd
  11. import triton
  12. import triton.language as tl
  13. from einops import rearrange, repeat
  14. try:
  15. from causal_conv1d import causal_conv1d_fn
  16. import causal_conv1d_cuda
  17. except ImportError:
  18. causal_conv1d_fn, causal_conv1d_cuda = None, None
  19. try:
  20. from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
  21. from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
  22. from .ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
  23. from .ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
  24. from .ssd_chunk_state import chunk_state, chunk_state_ref
  25. from .ssd_state_passing import _state_passing_fwd, _state_passing_bwd
  26. from .ssd_state_passing import state_passing, state_passing_ref
  27. from .ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
  28. from .ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
  29. from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
  30. from .ssd_chunk_scan import chunk_scan, chunk_scan_ref
  31. from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
  32. from .layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
  33. from .k_activations import _swiglu_fwd, _swiglu_bwd
  34. except:
  35. from ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
  36. from ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
  37. from ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
  38. from ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
  39. from ssd_chunk_state import chunk_state, chunk_state_ref
  40. from ssd_state_passing import _state_passing_fwd, _state_passing_bwd
  41. from ssd_state_passing import state_passing, state_passing_ref
  42. from ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
  43. from ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
  44. from ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
  45. from ssd_chunk_scan import chunk_scan, chunk_scan_ref
  46. from ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
  47. from layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
  48. from k_activations import _swiglu_fwd, _swiglu_bwd
  49. TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
  50. if True:
  51. class tTensor(torch.Tensor):
  52. @property
  53. def shape(self):
  54. shape = super().shape
  55. return tuple([int(s) for s in shape])
  56. to_ttensor = lambda *args: tuple([tTensor(x) for x in args]) if len(args) > 1 else tTensor(args[0])
  57. def init_to_zero(names):
  58. return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
# Backward of the chunked scan with respect to x, fused with the chunk-state term.
#
# Grid: axis 0 packs (pid_m, pid_n) tiles over (chunk position, headdim), axis 1 packs
# (pid_b, pid_c) = (batch, chunk), axis 2 is the head. Each program builds a
# (BLOCK_SIZE_M x BLOCK_SIZE_N) tile
#   acc[m, n] = scale[m] * sum_s b[m, s] * dstates[s, n]                     (state term)
#             + sum_{k >= m} cb[m, k] * exp(dA_cs[k] - dA_cs[m]) * dout[k, n]   (in-chunk term)
# where scale[m] = exp(dA_cs[last] - dA_cs[m]), zeroed for rows whose seq_idx differs from
# the last valid row's seq_idx. It then writes
#   dx  = acc * dt[m] (+ dout * D when HAS_D),
#   ddt[m] += sum_n acc[m, n] * x[m, n]   (atomic: several pid_n tiles share a row),
#   dD partial sums, one slot per pid_m tile (when HAS_D).
# The pre_hook zeroes ddt_ptr before launching because ddt is only ever atomically added to.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
    ],
    key=['chunk_size', 'hdim', 'dstate'],
)
@triton.jit
def _chunk_scan_chunk_state_bwd_dx_kernel(
    # Pointers to matrices
    x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr,
    b_ptr, dstates_ptr,
    dx_ptr, ddt_ptr, dD_ptr,
    # Matrix dimensions
    chunk_size, hdim, dstate,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    # NOTE: the host wrapper passes CB's last two strides swapped, so the cb tile is read transposed.
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_D_head,
    stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
    stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,
    stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
    stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
    stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
    # Meta-parameters
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
    IS_TRITON_22: tl.constexpr,
):
    # Decode program ids: axis 1 = batch * nchunks (batch fastest), axis 0 = (m tile, n tile).
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # Move base pointers to this (batch, chunk, head); CB and B are indexed by group,
    # i.e. head // (nheads / ngroups).
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
    dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # The last chunk may be shorter than chunk_size.
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
    # Decay from row m to the end of the chunk; with seq_idx, rows from a different
    # sequence than the chunk's last valid row get no contribution from dstates.
    if not HAS_SEQ_IDX:
        scale = tl.exp(dA_cs_last - dA_cs_m)
    else:
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
        seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
        scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
    # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
    # However, we're getting error with the Triton compiler 2.1.0 for that code path:
    # Unexpected mma -> mma layout conversion
    # Triton 2.2.0 fixes this
    # State term: acc = (B @ dstates) * scale, either as one dot over the whole (padded)
    # dstate or, otherwise, as a loop over dstate in BLOCK_SIZE_K slices.
    offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
    b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate)
    dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate)
    if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
        b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0)
        dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
        dstates = dstates.to(b_ptr.dtype.element_ty)
        acc = tl.dot(b, dstates) * scale[:, None]
    else:
        for k in range(0, dstate, BLOCK_SIZE_K):
            b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0)
            dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
            dstates = dstates.to(b_ptr.dtype.element_ty)
            acc += tl.dot(b, dstates)
            b_ptrs += BLOCK_SIZE_K * stride_b_dstate
            dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
        acc *= scale[:, None]
    # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
    # dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
    # ddt = tl.sum(acc * x, axis=1) * dt_m
    # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
    # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
    # In-chunk term: only positions k >= m feed row m, so k starts at this tile's first row.
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
    dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    K_MAX = chunk_size_limit
    K_MIN = pid_m * BLOCK_SIZE_M
    cb_ptrs += K_MIN * stride_cb_csize_k
    dout_ptrs += K_MIN * stride_dout_seqlen
    dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
    for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
        k = tl.multiple_of(k, BLOCK_SIZE_K)
        # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
        dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
        # Apply the decay from m to k, then keep only the causal part (k >= m).
        cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
        mask = k + offs_k[None, :] >= offs_m[:, None]
        cb = tl.where(mask, cb, 0.0)
        cb = cb.to(dout_ptr.dtype.element_ty)
        acc += tl.dot(cb, dout)
        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
        dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
    dx = acc * dt_m[:, None]
    dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
    dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
    # Skip connection: dx += dout * D (D per column when D_HAS_HDIM, else one scalar per head).
    if HAS_D:
        dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
        dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
        if D_HAS_HDIM:
            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
        else:
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        dx += dout_res * D
    tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
    # dD partial sums: one slot per pid_m tile (dD_csize axis); the host sums the slots.
    if HAS_D:
        dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize
        if D_HAS_HDIM:
            dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
            dD = tl.sum(dout_res * x, axis=0)
            tl.store(dD_ptrs, dD, mask=offs_n < hdim)
        else:
            # NOTE(review): this store is not offset by pid_n. When hdim > BLOCK_SIZE_N, the
            # num_pid_n programs for the same (pid_m, b, c, h) each write a partial sum over
            # their own columns to the same slot, and all but one partial sum is lost.
            dD = tl.sum(dout_res * x)
            tl.store(dD_ptr, dD)
    # ddt[m] = sum_n acc[m, n] * x[m, n]; atomic because every pid_n tile adds to the same rows.
    ddt = tl.sum(acc * x, axis=1)
    ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
    tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
  213. def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):
  214. batch, seqlen, nheads, headdim = x.shape
  215. _, _, nchunks, chunk_size = dt.shape
  216. _, _, ngroups, dstate = B.shape
  217. assert nheads % ngroups == 0
  218. assert B.shape == (batch, seqlen, ngroups, dstate)
  219. assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
  220. assert dt.shape == (batch, nheads, nchunks, chunk_size)
  221. assert dA_cumsum.shape == dt.shape
  222. assert dout.shape == x.shape
  223. assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
  224. if seq_idx is not None:
  225. assert seq_idx.shape == (batch, seqlen)
  226. if D is not None:
  227. assert D.shape == (nheads, headdim) or D.shape == (nheads,)
  228. assert D.stride(-1) == 1
  229. BLOCK_SIZE_min = 32
  230. dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
  231. headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
  232. else:
  233. dD = None
  234. dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
  235. if D is not None else (0, 0, 0, 0, 0))
  236. if dx is None:
  237. dx = torch.empty_like(x)
  238. else:
  239. assert dx.shape == x.shape
  240. ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)
  241. grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
  242. batch * nchunks, nheads)
  243. with torch.cuda.device(x.device.index):
  244. _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
  245. x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD,
  246. int(chunk_size), int(headdim), int(dstate),
  247. int(batch), int(seqlen), int(nheads // ngroups),
  248. x.stride(0), x.stride(1), x.stride(2), x.stride(3),
  249. CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),
  250. dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
  251. dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
  252. dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
  253. *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
  254. D.stride(0) if D is not None else 0,
  255. B.stride(0), B.stride(1), B.stride(2), B.stride(3),
  256. dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
  257. dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
  258. ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
  259. dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
  260. D is not None,
  261. D.dim() == 2 if D is not None else True,
  262. HAS_SEQ_IDX=seq_idx is not None,
  263. BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
  264. IS_TRITON_22=TRITON_22
  265. )
  266. if D is not None:
  267. BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"]
  268. n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
  269. dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
  270. if D.dim() == 1:
  271. dD = rearrange(dD, "h 1 -> h")
  272. return dx, ddt.to(dtype=dt.dtype), dD
  273. # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
  274. # @torch.compile(fullgraph=True)
  275. def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
  276. batch, seqlen, nheads, headdim = x.shape
  277. _, _, ngroups, dstate = B.shape
  278. assert nheads % ngroups == 0
  279. assert B.shape == (batch, seqlen, ngroups, dstate)
  280. assert x.shape == (batch, seqlen, nheads, headdim)
  281. assert dt.shape == (batch, seqlen, nheads)
  282. assert A.shape == (nheads,)
  283. assert C.shape == B.shape
  284. if z is not None:
  285. assert z.shape == x.shape
  286. if D is not None:
  287. assert D.shape == (nheads, headdim) or D.shape == (nheads,)
  288. if seq_idx is not None:
  289. assert seq_idx.shape == (batch, seqlen)
  290. if B.stride(-1) != 1:
  291. B = B.contiguous()
  292. if C.stride(-1) != 1:
  293. C = C.contiguous()
  294. if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous
  295. x = x.contiguous()
  296. if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous
  297. z = z.contiguous()
  298. if D is not None and D.stride(-1) != 1:
  299. D = D.contiguous()
  300. if initial_states is not None:
  301. assert initial_states.shape == (batch, nheads, headdim, dstate)
  302. # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
  303. # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
  304. # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
  305. # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
  306. dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
  307. states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
  308. # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
  309. # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
  310. # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
  311. states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
  312. initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
  313. seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype)
  314. states, final_states = [rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]]
  315. # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
  316. # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
  317. CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
  318. out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx)
  319. return out, out_x, dt, dA_cumsum, states, final_states
# @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
# @torch.compile(fullgraph=True)
def _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None,
                                   dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False,
                                   dt_limit=(0.0, float("inf")),
                                   dx=None, ddt=None, dB=None, dC=None, dz=None, recompute_output=False):
    """Backward pass of the fused chunked SSD scan.

    Forward intermediates are recomputed from the raw inputs rather than saved:
    dA_cumsum and the processed dt, CB, chunk states, and passed states. The
    backward kernels are then chained.

    Args:
        dout: (batch, seqlen, nheads, headdim) upstream gradient.
        x, out: same shape as dout. ``out`` is the forward output consumed by
            ``_chunk_scan_bwd_dz`` (presumably the pre-gating output when z is given;
            confirm against the caller).
        dt: (batch, seqlen, nheads) raw dt, as passed to the forward.
        A: (nheads,).
        B, C: (batch, seqlen, ngroups, dstate).
        initial_states: optional (batch, nheads, headdim, dstate).
        dfinal_states: optional gradient w.r.t. the final states (assumed to have
            initial_states' layout; TODO confirm).
        dx, ddt, dB, dC, dz: optional preallocated gradient buffers. dB, dC and ddt, when
            given, receive the results in place.
        recompute_output: when True, the output is appended to the return tuple.
            With z, it is recomputed by ``_chunk_scan_bwd_dz``; without z, it is ``out``.

    Returns:
        (dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states), plus outz when
        recompute_output is True.
    """
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    batch, seqlen, nheads, headdim = x.shape
    # NOTE(review): nchunks is not used below.
    nchunks = math.ceil(seqlen / chunk_size)
    _, _, ngroups, dstate = B.shape
    assert dout.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, seqlen, nheads)
    assert A.shape == (nheads,)
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert C.shape == B.shape
    assert out.shape == x.shape
    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    if dx is not None:
        assert dx.shape == x.shape
    # dB_given / dC_given / ddt_given are the output buffers that are finally returned;
    # the local names dB, dC and ddt are rebound to intermediate kernel results below.
    if dB is not None:
        assert dB.shape == B.shape
        dB_given = dB
    else:
        dB_given = torch.empty_like(B)
    if dC is not None:
        assert dC.shape == C.shape
        dC_given = dC
    else:
        dC_given = torch.empty_like(C)
    if dz is not None:
        assert z is not None
        assert dz.shape == z.shape
    if ddt is not None:
        assert ddt.shape == dt.shape
        ddt_given = ddt
    else:
        ddt_given = torch.empty_like(dt)
    # TD: For some reason Triton (2.1.0 and 2.2.0) errors with
    # "[CUDA]: invalid device context" (e.g. during varlen test), and cloning makes it work. Idk why.
    dt_in = dt.clone()
    # Recompute the forward intermediates; dt is rebound to the processed dt from here on.
    dA_cumsum, dt = _chunk_cumsum_fwd(dt_in, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus,
                                      dt_limit=dt_limit)
    CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
    states, _ = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
                                   initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
                                   seq_idx=seq_idx, chunk_size=chunk_size)
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)
    # With a gate z, its backward runs first and replaces dout with the gradient
    # w.r.t. the pre-gate output; it also produces dD.
    if z is not None:
        dz, dout, dD, *rest = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, has_ddAcs=False, D=D, dz=dz, recompute_output=recompute_output)
        outz = rest[0] if recompute_output else out
    else:
        dz = None
        outz = out
    dstates = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype)
    # dstates has length nchunks, containing the gradient to initial states at index 0 and
    # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
    # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
    # will be used in matmul in the next kernels.
    dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
        rearrange(dstates, "... p n -> ... (p n)"),
        dfinal_states=rearrange(dfinal_states, "... p n -> ... (p n)") if dfinal_states is not None else None,
        seq_idx=seq_idx,
        has_initial_states=initial_states is not None,
        dstates_dtype=x.dtype,
        states_dtype=x.dtype,
        chunk_size=chunk_size,
    )
    # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
    # gradient to the final states at index (nchunks - 1)
    # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
    # The final states is not stored.
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)
    dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
    dinitial_states = rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate) if dinitial_states is not None else None
    # ddt here is this kernel's partial result; it is folded into ddt_given by _chunk_cumsum_bwd below.
    dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx)
    # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
    dB, ddA_next = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups)
    # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
    dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups)
    # Computing ddA with the dcb kernel is much slower, so we're not using it for now
    dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
    # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
    dCB = dCB.to(CB.dtype)
    # Add the contribution through CB onto the direct dB / dC terms (passed as residual),
    # writing the sums into the output buffers.
    _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
    _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
    # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
    # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
    if z is None:
        dD = dD_from_x
    # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
    # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
    # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
    # be a lot of underflow.
    # This is already done as part of bwd_dC kernel
    # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
    # Fold the per-chunk state-passing gradient into the last position, then take a
    # reverse cumulative sum along the chunk axis.
    ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
    ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
    # This is already done as part of bwd_dB kernel
    # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
    # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
    ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
    ddA += ddA_next + ddA_prev

    # Map ddA and ddt back through the dt processing (bias / softplus / limit) onto the raw dt_in.
    ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(ddA, ddt, dt_in, A, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, ddt=ddt_given)
    # These 2 lines are just to test ddt and dA being computed by old code
    # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
    # ddt_given.copy_(ddt)
    return_vals = (dx, ddt_given, dA, dB_given, dC_given, dD, dz, ddt_bias, dinitial_states)
    return return_vals if not recompute_output else (*return_vals, outz)
  432. # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
  433. # ddt_given.copy_(ddt)
  434. return_vals = (dx, ddt_given, dA, dB_given, dC_given, dD, dz, ddt_bias, dinitial_states)
  435. return return_vals if not recompute_output else (*return_vals, outz)
  436. def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
  437. """
  438. Argument:
  439. dout: (batch, seqlen, nheads, headdim)
  440. x: (batch, seqlen, nheads, headdim)
  441. dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
  442. A: (nheads) or (dim, dstate)
  443. B: (batch, seqlen, ngroups, dstate)
  444. C: (batch, seqlen, ngroups, dstate)
  445. D: (nheads, headdim) or (nheads,)
  446. z: (batch, seqlen, nheads, headdim)
  447. Return:
  448. out: (batch, seqlen, nheads, headdim)
  449. """
  450. import selective_scan
  451. batch, seqlen, nheads, headdim = x.shape
  452. chunk_size = dt.shape[-1]
  453. _, _, ngroups, dstate = B.shape
  454. assert nheads % ngroups == 0
  455. x = rearrange(x, "b l h p -> b (h p) l")
  456. squeeze_dt = dt.dim() == 4
  457. if dt.dim() == 4:
  458. dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
  459. dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
  460. squeeze_A = A.dim() == 1
  461. if A.dim() == 1:
  462. A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
  463. else:
  464. A = A.to(dtype=torch.float32)
  465. B = rearrange(B, "b l g n -> b g n l")
  466. C = rearrange(C, "b l g n -> b g n l")
  467. if D is not None:
  468. if D.dim() == 2:
  469. D = rearrange(D, "h p -> (h p)")
  470. else:
  471. D = repeat(D, "h -> (h p)", p=headdim)
  472. if z is not None:
  473. z = rearrange(z, "b l h p -> b (h p) l")
  474. if x.stride(-1) != 1:
  475. x = x.contiguous()
  476. if dt.stride(-1) != 1:
  477. dt = dt.contiguous()
  478. if D is not None:
  479. D = D.contiguous()
  480. if B.stride(-1) != 1:
  481. B = B.contiguous()
  482. if C.stride(-1) != 1:
  483. C = C.contiguous()
  484. if z is not None and z.stride(-1) != 1:
  485. z = z.contiguous()
  486. _, intermediate, *rest = selective_scan.fwd(x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False)
  487. if z is not None:
  488. out = rest[0]
  489. else:
  490. out = None
  491. dout = rearrange(dout, "b l h p -> b (h p) l")
  492. if dout.stride(-1) != 1:
  493. dout = dout.contiguous()
  494. # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
  495. # backward of selective_scan with the backward of chunk).
  496. # Here we just pass in None and dz will be allocated in the C++ code.
  497. _, ddt, dA, *rest = selective_scan.bwd(
  498. x, dt.to(dtype=x.dtype), A, B, C, D, z, None, dout, intermediate, out, None, False,
  499. False # option to recompute out_z, not used here
  500. )
  501. ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
  502. if squeeze_dt:
  503. ddt = ddt.float().sum(dim=2)
  504. if squeeze_A:
  505. dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
  506. return ddt, dA
  507. class MambaChunkScanCombinedFn(torch.autograd.Function):
  508. @staticmethod
  509. def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False):
  510. ctx.dt_dtype = dt.dtype
  511. out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=dt_softplus, dt_limit=dt_limit)
  512. ctx.save_for_backward(out if z is None else out_x, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx)
  513. ctx.dt_softplus = dt_softplus
  514. ctx.chunk_size = chunk_size
  515. ctx.dt_limit = dt_limit
  516. ctx.return_final_states = return_final_states
  517. return out if not return_final_states else (out, final_states)
  518. @staticmethod
  519. def backward(ctx, dout, *args):
  520. out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors
  521. dfinal_states = args[0] if ctx.return_final_states else None
  522. dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit)
  523. return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None
  524. # @torch.jit.ignore
  525. def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False):
  526. """
  527. Argument:
  528. x: (batch, seqlen, nheads, headdim)
  529. dt: (batch, seqlen, nheads)
  530. A: (nheads)
  531. B: (batch, seqlen, ngroups, dstate)
  532. C: (batch, seqlen, ngroups, dstate)
  533. chunk_size: int
  534. D: (nheads, headdim) or (nheads,)
  535. z: (batch, seqlen, nheads, headdim)
  536. dt_bias: (nheads,)
  537. initial_states: (batch, nheads, headdim, dstate)
  538. seq_idx: (batch, seqlen)
  539. dt_softplus: Whether to apply softplus to dt
  540. Return:
  541. out: (batch, seqlen, nheads, headdim)
  542. """
  543. return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)
  544. def mamba_chunk_scan(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):
  545. """
  546. Argument:
  547. x: (batch, seqlen, nheads, headdim)
  548. dt: (batch, seqlen, nheads)
  549. A: (nheads)
  550. B: (batch, seqlen, ngroups, dstate)
  551. C: (batch, seqlen, ngroups, dstate)
  552. D: (nheads, headdim) or (nheads,)
  553. z: (batch, seqlen, nheads, headdim)
  554. dt_bias: (nheads,)
  555. Return:
  556. out: (batch, seqlen, nheads, headdim)
  557. """
  558. batch, seqlen, nheads, headdim = x.shape
  559. dstate = B.shape[-1]
  560. if seqlen % chunk_size != 0:
  561. dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
  562. dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
  563. dt = dt.float() # We want high precision for this before cumsum
  564. if dt_bias is not None:
  565. dt = dt + rearrange(dt_bias, "h -> h 1 1")
  566. if dt_softplus:
  567. dt = F.softplus(dt)
  568. dA = dt * rearrange(A, "h -> h 1 1")
  569. dA = dt * rearrange(A, "h -> h 1 1")
  570. dA_cumsum = torch.cumsum(dA, dim=-1)
  571. # 1. Compute the state for each chunk
  572. states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
  573. # 2. Pass the state to all the chunks by weighted cumsum.
  574. states = rearrange(state_passing(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0],
  575. "... (p n) -> ... p n", n=dstate)
  576. # 3. Compute the output for each chunk
  577. out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
  578. return out
  579. def ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):
  580. """
  581. Argument:
  582. x: (batch, seqlen, nheads, headdim)
  583. dt: (batch, seqlen, nheads)
  584. A: (nheads)
  585. B: (batch, seqlen, ngroups, dstate)
  586. C: (batch, seqlen, ngroups, dstate)
  587. D: (nheads, headdim) or (nheads,)
  588. z: (batch, seqlen, nheads, headdim)
  589. dt_bias: (nheads,)
  590. Return:
  591. out: (batch, seqlen, nheads, headdim)
  592. """
  593. batch, seqlen, nheads, headdim = x.shape
  594. dstate = B.shape[-1]
  595. if seqlen % chunk_size != 0:
  596. dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
  597. dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
  598. dt = dt.float() # We want high precision for this before cumsum
  599. if dt_bias is not None:
  600. dt = dt + rearrange(dt_bias, "h -> h 1 1")
  601. if dt_softplus:
  602. dt = F.softplus(dt)
  603. dA = dt * rearrange(A, "h -> h 1 1")
  604. dA_cumsum = torch.cumsum(dA, dim=-1)
  605. # 1. Compute the state for each chunk
  606. states = chunk_state_ref(B, x, dt, dA_cumsum)
  607. states_dtype = states.dtype
  608. if states.dtype not in [torch.float32, torch.float64]:
  609. states = states.to(torch.float32)
  610. # 2. Pass the state to all the chunks by weighted cumsum.
  611. # state_passing_ref is much less numerically stable
  612. states = rearrange(state_passing_ref(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0],
  613. "... (p n) -> ... p n", n=dstate)
  614. states = states.to(states_dtype)
  615. # 3. Compute the output for each chunk
  616. out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
  617. return out
  618. def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
  619. """
  620. Argument:
  621. x: (batch, seqlen, nheads, headdim)
  622. dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
  623. A: (nheads) or (dim, dstate)
  624. B: (batch, seqlen, ngroups, dstate)
  625. C: (batch, seqlen, ngroups, dstate)
  626. D: (nheads, headdim) or (nheads,)
  627. z: (batch, seqlen, nheads, headdim)
  628. dt_bias: (nheads,) or (nheads, headdim)
  629. Return:
  630. out: (batch, seqlen, nheads, headdim)
  631. """
  632. from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
  633. batch, seqlen, nheads, headdim = x.shape
  634. _, _, ngroups, dstate = B.shape
  635. x = rearrange(x, "b l h p -> b (h p) l")
  636. if dt.dim() == 3:
  637. dt = repeat(dt, "b l h -> b l h p", p=headdim)
  638. dt = rearrange(dt, "b l h p -> b (h p) l")
  639. if A.dim() == 1:
  640. A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
  641. else:
  642. A = A.to(dtype=torch.float32)
  643. B = rearrange(B, "b l g n -> b g n l")
  644. C = rearrange(C, "b l g n -> b g n l")
  645. if D is not None:
  646. if D.dim() == 2:
  647. D = rearrange(D, "h p -> (h p)")
  648. else:
  649. D = repeat(D, "h -> (h p)", p=headdim)
  650. if z is not None:
  651. z = rearrange(z, "b l h p -> b (h p) l")
  652. if dt_bias is not None:
  653. if dt_bias.dim() == 1:
  654. dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
  655. dt_bias = rearrange(dt_bias, "h p -> (h p)")
  656. if dt_limit != (0.0, float("inf")):
  657. if dt_bias is not None:
  658. dt = dt + rearrange(dt_bias, "d -> d 1")
  659. if dt_softplus:
  660. dt = F.softplus(dt)
  661. dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
  662. dt_bias = None
  663. dt_softplus = None
  664. out = selective_scan_fn(x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus)
  665. return rearrange(out, "b (h p) l -> b l h p", p=headdim)
  666. def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D=None, z=None,
  667. dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")),
  668. activation="silu", headdim=None, ngroups=1):
  669. """
  670. Argument:
  671. xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
  672. conv1d_weight: (dim + 2 * ngroups * dstate, width)
  673. conv1d_bias: (dim + 2 * ngroups * dstate,)
  674. dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
  675. A: (nheads)
  676. D: (nheads, headdim) or (nheads,)
  677. z: (batch, seqlen, dim)
  678. dt_bias: (nheads) or (nheads, headdim)
  679. headdim: if D is 1D and z is None, headdim must be passed in
  680. Return:
  681. out: (batch, seqlen, dim)
  682. """
  683. batch, seqlen, nheads = dt.shape[:3]
  684. assert nheads % ngroups == 0
  685. if z is not None:
  686. dim = z.shape[-1]
  687. assert dim % nheads == 0
  688. headdim = dim // nheads
  689. else:
  690. if D.dim() == 1:
  691. assert headdim is not None
  692. else:
  693. headdim = D.shape[1]
  694. dim = nheads * headdim
  695. xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation),
  696. "b d s -> b s d")
  697. dstate = (xBC.shape[-1] - dim) // ngroups // 2
  698. x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
  699. x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
  700. B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
  701. C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
  702. z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
  703. out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(), z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
  704. return rearrange(out, "b s h p -> b s (h p)")
  705. class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):
  706. @staticmethod
  707. @custom_fwd
  708. def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
  709. rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None,
  710. ngroups=1, norm_before_gate=True):
  711. assert activation in [None, "silu", "swish"]
  712. if D.dim() == 1:
  713. assert headdim is not None
  714. nheads, = D.shape
  715. else:
  716. nheads, headdim = D.shape
  717. batch, seqlen, _ = zxbcdt.shape
  718. dim = nheads * headdim
  719. assert nheads % ngroups == 0
  720. dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
  721. d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
  722. assert d_nonssm >= 0
  723. assert zxbcdt.shape == (batch, seqlen, 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads)
  724. assert dt_bias.shape == (nheads,)
  725. assert A.shape == (nheads,)
  726. zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1)
  727. seq_idx = seq_idx.contiguous() if seq_idx is not None else None
  728. xBC_conv = rearrange(
  729. causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
  730. conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]),
  731. "b d s -> b s d"
  732. )
  733. x, B, C = torch.split(xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
  734. x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
  735. B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
  736. C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
  737. z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
  738. if rmsnorm_weight is None:
  739. out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
  740. out = rearrange(out, "b s h p -> b s (h p)")
  741. rstd = None
  742. if d_nonssm > 0:
  743. out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
  744. else:
  745. out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
  746. # reshape input data into 2D tensor
  747. x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
  748. z_rms = rearrange(z, "b s h p -> (b s) (h p)")
  749. rmsnorm_weight = rmsnorm_weight.contiguous()
  750. if d_nonssm == 0:
  751. out = None
  752. else:
  753. out01 = torch.empty((batch, seqlen, d_nonssm + dim), dtype=x_rms.dtype, device=x_rms.device)
  754. out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
  755. _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
  756. out, _, rstd = _layer_norm_fwd(x_rms, rmsnorm_weight, None, rmsnorm_eps, z_rms, out=out,
  757. group_size=dim // ngroups,
  758. norm_before_gate=norm_before_gate, is_rms_norm=True)
  759. if d_nonssm == 0:
  760. out = rearrange(out, "(b s) d -> b s d", b=batch)
  761. else:
  762. out = out01
  763. ctx.outproj_weight_dtype = outproj_weight.dtype if outproj_weight is not None else None
  764. if outproj_weight is not None:
  765. if torch.is_autocast_enabled():
  766. dtype = torch.get_autocast_gpu_dtype()
  767. out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
  768. outproj_bias = outproj_bias.to(dtype) if outproj_bias is not None else None
  769. out = F.linear(out, outproj_weight, outproj_bias)
  770. else:
  771. assert outproj_bias is None
  772. ctx.save_for_backward(zxbcdt, conv1d_weight, conv1d_bias,
  773. out_x, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias)
  774. ctx.dt_limit = dt_limit
  775. ctx.return_final_states = return_final_states
  776. ctx.activation = activation
  777. ctx.rmsnorm_eps = rmsnorm_eps
  778. ctx.norm_before_gate = norm_before_gate
  779. ctx.chunk_size = chunk_size
  780. ctx.headdim = headdim
  781. ctx.ngroups = ngroups
  782. return out if not return_final_states else (out, final_states)
  783. @staticmethod
  784. @custom_bwd
  785. def backward(ctx, dout, *args):
  786. zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors
  787. dfinal_states = args[0] if ctx.return_final_states else None
  788. headdim = ctx.headdim
  789. nheads = D.shape[0]
  790. dim = nheads * headdim
  791. assert nheads % ctx.ngroups == 0
  792. dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
  793. d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
  794. assert d_nonssm >= 0
  795. recompute_output = outproj_weight is not None
  796. if recompute_output:
  797. out_recompute = torch.empty(*out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype)
  798. out0_recompute, out1_recompute = out_recompute.split([d_nonssm, dim], dim=-1)
  799. zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
  800. # Recompute x, B, C
  801. xBC_conv = rearrange(
  802. causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
  803. conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]),
  804. "b d s -> b s d"
  805. )
  806. x, B, C = torch.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
  807. x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
  808. B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
  809. C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
  810. dzxbcdt = torch.empty_like(zxbcdt)
  811. dzx0, dz, dxBC_given, ddt_given = torch.split(dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
  812. dxBC = torch.empty_like(xBC)
  813. dx, dB, dC = torch.split(dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
  814. z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
  815. dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
  816. dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
  817. dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
  818. if outproj_weight is not None:
  819. dout_og = dout
  820. dout = F.linear(dout, outproj_weight.t())
  821. if d_nonssm > 0:
  822. dout0, dout = dout.split([d_nonssm, dim], dim=-1)
  823. _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
  824. dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
  825. if rmsnorm_weight is None:
  826. dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
  827. dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = _mamba_chunk_scan_combined_bwd(
  828. dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC, dz=dz, recompute_output=recompute_output
  829. )
  830. out_for_linear = rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
  831. drmsnorm_weight = None
  832. else:
  833. batch = dout.shape[0]
  834. dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
  835. dz = rearrange(dz, "b l d -> (b l) d")
  836. x_rms = rearrange(out, "b s h p -> (b s) (h p)")
  837. z_rms = rearrange(z, "b s h p -> (b s) (h p)")
  838. out1_recompute = rearrange(out1_recompute, "b s d -> (b s) d") if recompute_output else None
  839. dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None)
  840. out_for_linear = out_recompute if recompute_output else None
  841. dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
  842. dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
  843. dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC
  844. )
  845. if outproj_weight is not None:
  846. doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
  847. doutproj_bias = dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
  848. else:
  849. doutproj_weight, doutproj_bias = None, None
  850. dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
  851. dxBC_given, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
  852. rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
  853. rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, dxBC_given, False, ctx.activation in ["silu", "swish"]
  854. )
  855. dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
  856. return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None
  857. def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
  858. """
  859. Argument:
  860. zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
  861. conv1d_weight: (dim + 2 * ngroups * dstate, width)
  862. conv1d_bias: (dim + 2 * ngroups * dstate,)
  863. dt_bias: (nheads,)
  864. A: (nheads)
  865. D: (nheads, headdim) or (nheads,)
  866. initial_states: (batch, nheads, headdim, dstate)
  867. seq_idx: (batch, seqlen), int32
  868. rmsnorm_weight: (dim,)
  869. outproj_weight: (out_dim, dim)
  870. outproj_bias: (out_dim,)
  871. headdim: if D is 1D, headdim must be passed in
  872. norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
  873. Return:
  874. out: (batch, seqlen, dim)
  875. """
  876. return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
  877. def mamba_split_conv1d_scan_ref(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, dt_limit=(0.0, float("inf")), activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
  878. """
  879. Argument:
  880. zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
  881. conv1d_weight: (dim + 2 * ngroups * dstate, width)
  882. conv1d_bias: (dim + 2 * ngroups * dstate,)
  883. dt_bias: (nheads,)
  884. A: (nheads)
  885. D: (nheads, headdim) or (nheads,)
  886. rmsnorm_weight: (dim,)
  887. outproj_weight: (out_dim, dim)
  888. outproj_bias: (out_dim,)
  889. headdim: if D is 1D, headdim must be passed in
  890. norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
  891. Return:
  892. out: (batch, seqlen, dim)
  893. """
  894. if D.dim() == 1:
  895. assert headdim is not None
  896. nheads, = D.shape
  897. else:
  898. nheads, headdim = D.shape
  899. assert nheads % ngroups == 0
  900. batch, seqlen, _ = zxbcdt.shape
  901. dim = nheads * headdim
  902. dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
  903. assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
  904. assert dt_bias.shape == (nheads,)
  905. assert A.shape == (nheads,)
  906. if rmsnorm_weight is not None:
  907. assert rmsnorm_weight.shape == (dim,)
  908. z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
  909. xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation),
  910. "b d s -> b s d")
  911. x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
  912. x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
  913. B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
  914. C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
  915. z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
  916. out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(),
  917. z=z if rmsnorm_weight is None else None, dt_bias=dt_bias, dt_softplus=True, dt_limit=dt_limit)
  918. out = rearrange(out, "b s h p -> b s (h p)")
  919. if rmsnorm_weight is not None:
  920. out = rmsnorm_fn(out, rmsnorm_weight, None, z=rearrange(z, "b l h p -> b l (h p)"), eps=rmsnorm_eps,
  921. norm_before_gate=norm_before_gate)
  922. if outproj_weight is not None:
  923. out = F.linear(out, outproj_weight, outproj_bias)
  924. return out