selective_state_update.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. # Copyright (c) 2024, Tri Dao, Albert Gu.
  2. """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
  3. """
  4. import math
  5. import torch
  6. import torch.nn.functional as F
  7. import triton
  8. import triton.language as tl
  9. from einops import rearrange, repeat
# Triton kernel computing one step of the selective state-space recurrence.
# Launch layout (read off the program_id offsets below): axis 0 tiles `dim` in chunks of
# BLOCK_SIZE_M rows, axis 1 indexes the batch, axis 2 indexes the head. Each program does:
#   state[b, h, m, :] = state * exp(A * dt) + (B * dt) * x    (written back through state_ptr)
#   out[b, h, m]      = sum_n(state * C) + x * D, then multiplied by z * sigmoid(z)
# The heuristics below derive the HAS_* flags from whether the optional pointer argument is
# None, and set the dstate tile width to the next power of two >= dstate.
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
    # Matrix dimensions
    batch, nheads, dim, dstate, nheads_ngroups_ratio,
    # Strides
    stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,
    stride_x_batch, stride_x_head, stride_x_dim,
    stride_dt_batch, stride_dt_head, stride_dt_dim,
    stride_dt_bias_head, stride_dt_bias_dim,
    stride_A_head, stride_A_dim, stride_A_dstate,
    stride_B_batch, stride_B_group, stride_B_dstate,
    stride_C_batch, stride_C_group, stride_C_dstate,
    stride_D_head, stride_D_dim,
    stride_z_batch, stride_z_head, stride_z_dim,
    stride_out_batch, stride_out_head, stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    TIE_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)  # tile index along dim
    pid_b = tl.program_id(axis=1)  # batch index
    pid_h = tl.program_id(axis=2)  # head index
    # Move every base pointer to this program's (batch, head) slice.
    state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
    if HAS_DT_BIAS:
        dt_bias_ptr += pid_h * stride_dt_bias_head
    A_ptr += pid_h * stride_A_head
    # B and C are shared by all heads of a group: group index = head // (nheads / ngroups).
    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head

    # Row offsets (along dim) for this tile; column offsets span the padded dstate width.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    if HAS_D:
        D_ptr += pid_h * stride_D_head
    A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    # Rows >= dim and columns >= dstate are masked on every load/store; masked loads read 0.
    # state is loaded without an explicit cast; the other operands are cast to float32.
    state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if not TIE_HDIM:
        # Untied: dt (plus optional bias) is a per-row vector and A a (rows, dstate) tile.
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if DT_SOFTPLUS:
            # softplus(dt) equals dt to float32 precision above 20, so large values pass
            # through unchanged instead of going through log1p(exp(dt)).
            dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
        A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
        dA = tl.exp(A * dt[:, None])
    else:
        # Tied head dim: dt, dt_bias and A are each read as one scalar from the head's
        # base pointer and reused for every row (and, for A, every dstate column).
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt)  # scalar, not a matrix

    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    if not TIE_HDIM:
        dB = B[None, :] * dt[:, None]
    else:
        dB = B * dt  # vector of size (dstate,)
    # Recurrence step; the new state is written back to the same location it was read from.
    state = state * dA + dB * x[:, None]
    tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
    # Readout: contract the updated state with C over dstate, add skip term, gate by silu(z).
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)
  107. def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
  108. """
  109. Argument:
  110. state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
  111. x: (batch, dim) or (batch, nheads, dim)
  112. dt: (batch, dim) or (batch, nheads, dim)
  113. A: (dim, dstate) or (nheads, dim, dstate)
  114. B: (batch, dstate) or (batch, ngroups, dstate)
  115. C: (batch, dstate) or (batch, ngroups, dstate)
  116. D: (dim,) or (nheads, dim)
  117. z: (batch, dim) or (batch, nheads, dim)
  118. dt_bias: (dim,) or (nheads, dim)
  119. Return:
  120. out: (batch, dim) or (batch, nheads, dim)
  121. """
  122. has_heads = state.dim() > 3
  123. if state.dim() == 3:
  124. state = state.unsqueeze(1)
  125. if x.dim() == 2:
  126. x = x.unsqueeze(1)
  127. if dt.dim() == 2:
  128. dt = dt.unsqueeze(1)
  129. if A.dim() == 2:
  130. A = A.unsqueeze(0)
  131. if B.dim() == 2:
  132. B = B.unsqueeze(1)
  133. if C.dim() == 2:
  134. C = C.unsqueeze(1)
  135. if D is not None and D.dim() == 1:
  136. D = D.unsqueeze(0)
  137. if z is not None and z.dim() == 2:
  138. z = z.unsqueeze(1)
  139. if dt_bias is not None and dt_bias.dim() == 1:
  140. dt_bias = dt_bias.unsqueeze(0)
  141. batch, nheads, dim, dstate = state.shape
  142. assert x.shape == (batch, nheads, dim)
  143. assert dt.shape == x.shape
  144. assert A.shape == (nheads, dim, dstate)
  145. ngroups = B.shape[1]
  146. assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
  147. assert B.shape == (batch, ngroups, dstate)
  148. assert C.shape == B.shape
  149. if D is not None:
  150. assert D.shape == (nheads, dim)
  151. if z is not None:
  152. assert z.shape == x.shape
  153. if dt_bias is not None:
  154. assert dt_bias.shape == (nheads, dim)
  155. out = torch.empty_like(x)
  156. grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
  157. z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
  158. # We don't want autotune since it will overwrite the state
  159. # We instead tune by hand.
  160. BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
  161. else ((16, 4) if dstate <= 32 else
  162. ((8, 4) if dstate <= 64 else
  163. ((4, 4) if dstate <= 128 else
  164. ((4, 8))))))
  165. tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0
  166. with torch.cuda.device(x.device.index):
  167. _selective_scan_update_kernel[grid](
  168. state, x, dt, dt_bias, A, B, C, D, z, out,
  169. batch, nheads, dim, dstate, nheads // ngroups,
  170. state.stride(0), state.stride(1), state.stride(2), state.stride(3),
  171. x.stride(0), x.stride(1), x.stride(2),
  172. dt.stride(0), dt.stride(1), dt.stride(2),
  173. *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
  174. A.stride(0), A.stride(1), A.stride(2),
  175. B.stride(0), B.stride(1), B.stride(2),
  176. C.stride(0), C.stride(1), C.stride(2),
  177. *(D.stride(0), D.stride(1)) if D is not None else 0,
  178. z_strides[0], z_strides[1], z_strides[2],
  179. out.stride(0), out.stride(1), out.stride(2),
  180. dt_softplus,
  181. tie_hdim,
  182. BLOCK_SIZE_M,
  183. num_warps=num_warps,
  184. )
  185. if not has_heads:
  186. out = out.squeeze(1)
  187. return out
  188. def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
  189. """
  190. Argument:
  191. state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
  192. x: (batch, dim) or (batch, nheads, dim)
  193. dt: (batch, dim) or (batch, nheads, dim)
  194. A: (dim, dstate) or (nheads, dim, dstate)
  195. B: (batch, dstate) or (batch, ngroups, dstate)
  196. C: (batch, dstate) or (batch, ngroups, dstate)
  197. D: (dim,) or (nheads, dim)
  198. z: (batch, dim) or (batch, nheads, dim)
  199. dt_bias: (dim,) or (nheads, dim)
  200. Return:
  201. out: (batch, dim) or (batch, nheads, dim)
  202. """
  203. has_heads = state.dim() > 3
  204. if state.dim() == 3:
  205. state = state.unsqueeze(1)
  206. if x.dim() == 2:
  207. x = x.unsqueeze(1)
  208. if dt.dim() == 2:
  209. dt = dt.unsqueeze(1)
  210. if A.dim() == 2:
  211. A = A.unsqueeze(0)
  212. if B.dim() == 2:
  213. B = B.unsqueeze(1)
  214. if C.dim() == 2:
  215. C = C.unsqueeze(1)
  216. if D is not None and D.dim() == 1:
  217. D = D.unsqueeze(0)
  218. if z is not None and z.dim() == 2:
  219. z = z.unsqueeze(1)
  220. if dt_bias is not None and dt_bias.dim() == 1:
  221. dt_bias = dt_bias.unsqueeze(0)
  222. batch, nheads, dim, dstate = state.shape
  223. assert x.shape == (batch, nheads, dim)
  224. assert dt.shape == x.shape
  225. assert A.shape == (nheads, dim, dstate)
  226. ngroups = B.shape[1]
  227. assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
  228. assert B.shape == (batch, ngroups, dstate)
  229. assert C.shape == B.shape
  230. if D is not None:
  231. assert D.shape == (nheads, dim)
  232. if z is not None:
  233. assert z.shape == x.shape
  234. if dt_bias is not None:
  235. assert dt_bias.shape == (nheads, dim)
  236. dt = dt + dt_bias
  237. dt = F.softplus(dt) if dt_softplus else dt
  238. dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate)
  239. B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
  240. C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
  241. dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate)
  242. state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate
  243. out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
  244. if D is not None:
  245. out += (x * D).to(out.dtype)
  246. out = (out if z is None else out * F.silu(z)).to(x.dtype)
  247. if not has_heads:
  248. out = out.squeeze(1)
  249. return out