ssd_state_passing.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. # Copyright (c) 2024, Tri Dao, Albert Gu.
  2. """We want triton==2.1.0 or 2.2.0 for this
  3. """
  4. import math
  5. import torch
  6. import torch.nn.functional as F
  7. import triton
  8. import triton.language as tl
  9. from einops import rearrange, repeat
  10. @triton.autotune(
  11. configs=[
  12. triton.Config({'BLOCK_SIZE': 64}),
  13. triton.Config({'BLOCK_SIZE': 128}),
  14. triton.Config({'BLOCK_SIZE': 256}),
  15. triton.Config({'BLOCK_SIZE': 512}),
  16. triton.Config({'BLOCK_SIZE': 1024}),
  17. triton.Config({'BLOCK_SIZE': 2048}),
  18. ],
  19. key=['dim'],
  20. )
  21. @triton.jit
  22. def _state_passing_fwd_kernel(
  23. # Pointers to matrices
  24. states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,
  25. # Matrix dimensions
  26. dim, nchunks, seqlen, chunk_size,
  27. # Strides
  28. stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,
  29. stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
  30. stride_final_states_batch, stride_final_states_head, stride_final_states_dim,
  31. stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
  32. stride_initstates_batch, stride_initstates_head, stride_initstates_dim,
  33. stride_seq_idx_batch, stride_seq_idx_seqlen,
  34. # Meta-parameters
  35. HAS_INITSTATES: tl.constexpr,
  36. HAS_SEQ_IDX: tl.constexpr,
  37. BLOCK_SIZE: tl.constexpr,
  38. ):
  39. pid_b = tl.program_id(axis=1)
  40. pid_h = tl.program_id(axis=2)
  41. pid_m = tl.program_id(axis=0)
  42. states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
  43. dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
  44. out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
  45. final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
  46. if HAS_INITSTATES:
  47. initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head
  48. if HAS_SEQ_IDX:
  49. seq_idx_ptr += pid_b * stride_seq_idx_batch
  50. offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  51. states_ptrs = states_ptr + offs_m * stride_states_dim
  52. out_ptrs = out_ptr + offs_m * stride_out_dim
  53. final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
  54. if not HAS_INITSTATES:
  55. states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
  56. else:
  57. initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim
  58. states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
  59. tl.store(out_ptrs, states, mask=offs_m < dim)
  60. out_ptrs += stride_out_chunk
  61. seq_idx = 0
  62. for c in range(nchunks):
  63. new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
  64. dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
  65. scale = tl.exp(dA_cs)
  66. if HAS_SEQ_IDX:
  67. seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
  68. scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
  69. seq_idx = seq_idx_new
  70. states = scale * states + new_states
  71. if c < nchunks - 1:
  72. tl.store(out_ptrs, states, mask=offs_m < dim)
  73. else:
  74. tl.store(final_states_ptrs, states, mask=offs_m < dim)
  75. states_ptrs += stride_states_chunk
  76. dA_cs_ptr += stride_dA_cs_chunk
  77. out_ptrs += stride_out_chunk
  78. @triton.autotune(
  79. configs=[
  80. triton.Config({'BLOCK_SIZE': 64}),
  81. triton.Config({'BLOCK_SIZE': 128}),
  82. triton.Config({'BLOCK_SIZE': 256}),
  83. triton.Config({'BLOCK_SIZE': 512}),
  84. triton.Config({'BLOCK_SIZE': 1024}),
  85. triton.Config({'BLOCK_SIZE': 2048}),
  86. ],
  87. key=['dim'],
  88. )
  89. @triton.jit
  90. def _state_passing_bwd_kernel(
  91. # Pointers to matrices
  92. dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,
  93. dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,
  94. # Matrix dimensions
  95. dim, nchunks, seqlen, chunk_size,
  96. # Strides
  97. stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,
  98. stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
  99. stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
  100. stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,
  101. stride_seq_idx_batch, stride_seq_idx_seqlen,
  102. stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,
  103. stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,
  104. stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,
  105. # Meta-parameters
  106. CONVERT_STATES: tl.constexpr,
  107. HAS_DFINAL_STATES: tl.constexpr,
  108. HAS_DINITSTATES: tl.constexpr,
  109. HAS_SEQ_IDX: tl.constexpr,
  110. BLOCK_SIZE: tl.constexpr,
  111. ):
  112. pid_b = tl.program_id(axis=1)
  113. pid_h = tl.program_id(axis=2)
  114. pid_m = tl.program_id(axis=0)
  115. dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk
  116. dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk
  117. ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m
  118. out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
  119. dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk
  120. if CONVERT_STATES:
  121. states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
  122. if HAS_DFINAL_STATES:
  123. dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head
  124. if HAS_DINITSTATES:
  125. dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head
  126. if HAS_SEQ_IDX:
  127. seq_idx_ptr += pid_b * stride_seq_idx_batch
  128. offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  129. dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim
  130. out_ptrs = out_ptr + offs_m * stride_out_dim
  131. dout_ptrs = dout_ptr + offs_m * stride_dout_dim
  132. if CONVERT_STATES:
  133. states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim
  134. if HAS_DFINAL_STATES:
  135. dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)
  136. else:
  137. dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
  138. tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
  139. if HAS_SEQ_IDX:
  140. seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)
  141. dstates_ptrs -= stride_dstates_chunk
  142. for c in range(nchunks - 1):
  143. dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
  144. scale = tl.exp(dA_cs)
  145. if HAS_SEQ_IDX:
  146. seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))
  147. scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
  148. seq_idx = seq_idx_new
  149. out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
  150. if CONVERT_STATES:
  151. tl.store(states_converted_ptrs, out, mask=offs_m < dim)
  152. ddA = tl.sum(out * dstates) * scale
  153. tl.store(ddA_cs_ptr, ddA)
  154. dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
  155. dstates = scale * dstates + dout
  156. tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
  157. dout_ptrs -= stride_dout_chunk
  158. dstates_ptrs -= stride_dstates_chunk
  159. dA_cs_ptr -= stride_dA_cs_chunk
  160. ddA_cs_ptr -= stride_ddA_cs_chunk
  161. out_ptrs -= stride_out_chunk
  162. if CONVERT_STATES:
  163. states_converted_ptrs -= stride_out_chunk
  164. if CONVERT_STATES:
  165. out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
  166. tl.store(states_converted_ptrs, out, mask=offs_m < dim)
  167. if not HAS_DINITSTATES:
  168. tl.store(ddA_cs_ptr, 0.0)
  169. else:
  170. dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
  171. scale = tl.exp(dA_cs)
  172. if HAS_SEQ_IDX:
  173. scale = tl.where(seq_idx == 0, scale, 0.0)
  174. out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
  175. ddA = tl.sum(out * dstates) * scale
  176. tl.store(ddA_cs_ptr, ddA)
  177. dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
  178. dstates = scale * dstates + dout
  179. tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)
  180. def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,
  181. out_dtype=None):
  182. batch, nchunks, nheads, dim = states.shape
  183. assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
  184. if initial_states is not None:
  185. assert initial_states.shape == (batch, nheads, dim)
  186. if seq_idx is not None:
  187. assert chunk_size is not None
  188. seqlen = seq_idx.shape[-1]
  189. assert seq_idx.shape == (batch, seqlen)
  190. out_dtype = states.dtype if out_dtype is None else out_dtype
  191. out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)
  192. final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)
  193. grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
  194. with torch.cuda.device(states.device.index):
  195. _state_passing_fwd_kernel[grid](
  196. states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,
  197. int(dim), int(nchunks), int(seqlen if seq_idx is not None else 0), int(chunk_size if seq_idx is not None else 0),
  198. states.stride(0), states.stride(1), states.stride(2), states.stride(3),
  199. out.stride(0), out.stride(1), out.stride(2), out.stride(3),
  200. final_states.stride(0), final_states.stride(1), final_states.stride(2),
  201. dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
  202. *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
  203. if initial_states is not None else (0, 0, 0)),
  204. *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
  205. HAS_INITSTATES=initial_states is not None,
  206. HAS_SEQ_IDX=seq_idx is not None,
  207. )
  208. return out, final_states
  209. def _state_passing_bwd(
  210. states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,
  211. dstates_dtype=None, states_dtype=None, chunk_size=None
  212. ):
  213. """
  214. states contains the initial_states at index 0. The final states are not included in states.
  215. """
  216. batch, nchunks, nheads, dim = states.shape
  217. assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
  218. assert dout.shape == (batch, nchunks, nheads, dim)
  219. if seq_idx is not None:
  220. assert chunk_size is not None
  221. seqlen = seq_idx.shape[-1]
  222. assert seq_idx.shape == (batch, seqlen)
  223. dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
  224. if states_dtype is not None and states_dtype != states.dtype:
  225. states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
  226. assert states_converted.stride() == states.stride()
  227. else:
  228. states_converted = None
  229. if has_initial_states:
  230. dinitstates = torch.empty_like(dstates[:, 0])
  231. else:
  232. dinitstates = None
  233. if dfinal_states is not None:
  234. assert dfinal_states.shape == (batch, nheads, dim)
  235. BLOCK_SIZE_min = 64
  236. n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min
  237. ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,
  238. dtype=torch.float32, device=dA_chunk_cumsum.device)
  239. grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
  240. with torch.cuda.device(dout.device.index):
  241. _state_passing_bwd_kernel[grid](
  242. dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,
  243. dstates, ddA_chunk_cumsum, dinitstates, states_converted,
  244. int(dim), int(nchunks), int(seqlen if seq_idx is not None else 0), int(chunk_size if seq_idx is not None else 0),
  245. dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
  246. states.stride(0), states.stride(1), states.stride(2), states.stride(3),
  247. dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
  248. *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))
  249. if dfinal_states is not None else (0, 0, 0)),
  250. *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
  251. dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),
  252. ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),
  253. *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))
  254. if dinitstates is not None else (0, 0, 0)),
  255. CONVERT_STATES=states_converted is not None,
  256. HAS_DFINAL_STATES=dfinal_states is not None,
  257. HAS_DINITSTATES=dinitstates is not None,
  258. HAS_SEQ_IDX=seq_idx is not None,
  259. )
  260. BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"]
  261. n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
  262. ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)
  263. if states_dtype is not None and states_dtype == states.dtype:
  264. states_converted = states
  265. return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)
  266. class StatePassingFn(torch.autograd.Function):
  267. @staticmethod
  268. def forward(ctx, states, dA_chunk_cumsum, initial_states=None):
  269. batch, nchunks, nheads, dim = states.shape
  270. assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
  271. if states.stride(-1) != 1:
  272. states = states.contiguous()
  273. out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states)
  274. ctx.save_for_backward(out, dA_chunk_cumsum)
  275. ctx.has_initial_states = initial_states is not None
  276. return out, final_states
  277. @staticmethod
  278. def backward(ctx, dout, dfinal_states):
  279. out, dA_chunk_cumsum = ctx.saved_tensors
  280. batch, nchunks, nheads, dim = out.shape
  281. assert dout.shape == (batch, nchunks, nheads, dim)
  282. assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
  283. assert dfinal_states.shape == (batch, nheads, dim)
  284. if dout.stride(-1) != 1:
  285. dout = dout.contiguous()
  286. dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd(
  287. out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states , has_initial_states=ctx.has_initial_states
  288. )
  289. return dstates, ddA_chunk_cumsum, dinitstates
  290. def state_passing(states, dA_chunk_cumsum, initial_states=None):
  291. """
  292. Argument:
  293. states: (batch, nchunks, nheads, dim)
  294. dA_chunk_cumsum: (batch, nheads, nchunks)
  295. initial_states: (batch, nheads, dim)
  296. Return:
  297. out: (batch, nchunks, nheads, dim)
  298. final_states: (batch, nheads, dim)
  299. """
  300. return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states)
  301. def state_passing_ref(states, dA_chunk_cumsum, initial_states=None):
  302. """
  303. Argument:
  304. states: (batch, nchunks, nheads, dim)
  305. dA_chunk_cumsum: (batch, nheads, nchunks)
  306. initial_states: (batch, nheads, dim)
  307. Return:
  308. out: (batch, nchunks, nheads, dim)
  309. final_states: (batch, nheads, dim)
  310. """
  311. if initial_states is None:
  312. initial_states = torch.zeros_like(states[:, 0])
  313. states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], dim=1)
  314. dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0))
  315. dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1)
  316. nchunks = dA_chunk_cumsum.shape[-1]
  317. # (batch, nheads, nchunks, nchunks)
  318. dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :]
  319. # (batch, nheads, nchunks, nchunks)
  320. decay_chunk = torch.exp(dt_chunk_segment_sum)
  321. causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0)
  322. decay_chunk = decay_chunk.masked_fill(~causal_mask, 0)
  323. out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states)
  324. return out[:, :-1], out[:, -1]