csms6s.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import time
  2. import torch
  3. import warnings
  4. WITH_SELECTIVESCAN_OFLEX = True
  5. WITH_SELECTIVESCAN_CORE = False
  6. WITH_SELECTIVESCAN_MAMBA = True
  7. try:
  8. import selective_scan_cuda_oflex
  9. except ImportError:
  10. WITH_SELECTIVESCAN_OFLEX = False
  11. warnings.warn("Can not import selective_scan_cuda_oflex. This affects speed.")
  12. print("Can not import selective_scan_cuda_oflex. This affects speed.", flush=True)
  13. try:
  14. import selective_scan_cuda_core
  15. except ImportError:
  16. WITH_SELECTIVESCAN_CORE = False
  17. try:
  18. import selective_scan_cuda
  19. except ImportError:
  20. WITH_SELECTIVESCAN_MAMBA = False
  21. def selective_scan_torch(
  22. u: torch.Tensor, # (B, K * C, L)
  23. delta: torch.Tensor, # (B, K * C, L)
  24. A: torch.Tensor, # (K * C, N)
  25. B: torch.Tensor, # (B, K, N, L)
  26. C: torch.Tensor, # (B, K, N, L)
  27. D: torch.Tensor = None, # (K * C)
  28. delta_bias: torch.Tensor = None, # (K * C)
  29. delta_softplus=True,
  30. oflex=True,
  31. *args,
  32. **kwargs
  33. ):
  34. dtype_in = u.dtype
  35. Batch, K, N, L = B.shape
  36. KCdim = u.shape[1]
  37. Cdim = int(KCdim / K)
  38. assert u.shape == (Batch, KCdim, L)
  39. assert delta.shape == (Batch, KCdim, L)
  40. assert A.shape == (KCdim, N)
  41. assert C.shape == B.shape
  42. if delta_bias is not None:
  43. delta = delta + delta_bias[..., None]
  44. if delta_softplus:
  45. delta = torch.nn.functional.softplus(delta)
  46. u, delta, A, B, C = u.float(), delta.float(), A.float(), B.float(), C.float()
  47. B = B.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
  48. C = C.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
  49. deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
  50. deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
  51. if True:
  52. x = A.new_zeros((Batch, KCdim, N))
  53. ys = []
  54. for i in range(L):
  55. x = deltaA[:, :, i, :] * x + deltaB_u[:, :, i, :]
  56. y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
  57. ys.append(y)
  58. y = torch.stack(ys, dim=2) # (B, C, L)
  59. out = y if D is None else y + u * D.unsqueeze(-1)
  60. return out if oflex else out.to(dtype=dtype_in)
  61. class SelectiveScanCuda(torch.autograd.Function):
  62. @staticmethod
  63. @torch.cuda.amp.custom_fwd
  64. def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True, backend=None):
  65. ctx.delta_softplus = delta_softplus
  66. backend = "oflex" if WITH_SELECTIVESCAN_OFLEX and (backend is None) else backend
  67. backend = "core" if WITH_SELECTIVESCAN_CORE and (backend is None) else backend
  68. backend = "mamba" if WITH_SELECTIVESCAN_MAMBA and (backend is None) else backend
  69. ctx.backend = backend
  70. if backend == "oflex":
  71. out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
  72. elif backend == "core":
  73. out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
  74. elif backend == "mamba":
  75. out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
  76. ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
  77. return out
  78. @staticmethod
  79. @torch.cuda.amp.custom_bwd
  80. def backward(ctx, dout, *args):
  81. u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
  82. backend = ctx.backend
  83. if dout.stride(-1) != 1:
  84. dout = dout.contiguous()
  85. if backend == "oflex":
  86. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
  87. u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
  88. )
  89. elif backend == "core":
  90. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
  91. u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
  92. )
  93. elif backend == "mamba":
  94. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
  95. u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
  96. False
  97. )
  98. return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None
  99. def selective_scan_fn(
  100. u: torch.Tensor, # (B, K * C, L)
  101. delta: torch.Tensor, # (B, K * C, L)
  102. A: torch.Tensor, # (K * C, N)
  103. B: torch.Tensor, # (B, K, N, L)
  104. C: torch.Tensor, # (B, K, N, L)
  105. D: torch.Tensor = None, # (K * C)
  106. delta_bias: torch.Tensor = None, # (K * C)
  107. delta_softplus=True,
  108. oflex=True,
  109. backend=None,
  110. ):
  111. WITH_CUDA = (WITH_SELECTIVESCAN_OFLEX or WITH_SELECTIVESCAN_CORE or WITH_SELECTIVESCAN_MAMBA)
  112. fn = selective_scan_torch if backend == "torch" or (not WITH_CUDA) else SelectiveScanCuda.apply
  113. return fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex, backend)
  114. # fvcore flops =======================================
  115. def print_jit_input_names(inputs):
  116. print("input params: ", end=" ", flush=True)
  117. try:
  118. for i in range(10):
  119. print(inputs[i].debugName(), end=" ", flush=True)
  120. except Exception as e:
  121. pass
  122. print("", flush=True)
  123. def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False):
  124. """
  125. u: r(B D L)
  126. delta: r(B D L)
  127. A: r(D N)
  128. B: r(B N L)
  129. C: r(B N L)
  130. D: r(D)
  131. z: r(B D L)
  132. delta_bias: r(D), fp32
  133. ignores:
  134. [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
  135. """
  136. assert not with_complex
  137. # https://github.com/state-spaces/mamba/issues/110
  138. flops = 9 * B * L * D * N
  139. if with_D:
  140. flops += B * D * L
  141. if with_Z:
  142. flops += B * D * L
  143. return flops
  144. # this is only for selective_scan_ref...
  145. def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
  146. """
  147. u: r(B D L)
  148. delta: r(B D L)
  149. A: r(D N)
  150. B: r(B N L)
  151. C: r(B N L)
  152. D: r(D)
  153. z: r(B D L)
  154. delta_bias: r(D), fp32
  155. ignores:
  156. [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
  157. """
  158. import numpy as np
  159. # fvcore.nn.jit_handles
  160. def get_flops_einsum(input_shapes, equation):
  161. np_arrs = [np.zeros(s) for s in input_shapes]
  162. optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
  163. for line in optim.split("\n"):
  164. if "optimized flop" in line.lower():
  165. # divided by 2 because we count MAC (multiply-add counted as one flop)
  166. flop = float(np.floor(float(line.split(":")[-1]) / 2))
  167. return flop
  168. assert not with_complex
  169. flops = 0 # below code flops = 0
  170. flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
  171. if with_Group:
  172. flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
  173. else:
  174. flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
  175. in_for_flops = B * D * N
  176. if with_Group:
  177. in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
  178. else:
  179. in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
  180. flops += L * in_for_flops
  181. if with_D:
  182. flops += B * D * L
  183. if with_Z:
  184. flops += B * D * L
  185. return flops
  186. def selective_scan_flop_jit(inputs, outputs, backend="prefixsum", verbose=True):
  187. if verbose:
  188. print_jit_input_names(inputs)
  189. flops_fn = flops_selective_scan_ref if backend == "naive" else flops_selective_scan_fn
  190. B, D, L = inputs[0].type().sizes()
  191. N = inputs[2].type().sizes()[1]
  192. flops = flops_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False)
  193. return flops
  194. if __name__ == "__main__":
  195. def params(B, K, C, N, L, device = torch.device("cuda"), itype = torch.float):
  196. As = (-0.5 * torch.rand(K * C, N, device=device, dtype=torch.float32)).requires_grad_()
  197. Bs = torch.randn((B, K, N, L), device=device, dtype=itype).requires_grad_()
  198. Cs = torch.randn((B, K, N, L), device=device, dtype=itype).requires_grad_()
  199. Ds = torch.randn((K * C), device=device, dtype=torch.float32).requires_grad_()
  200. u = torch.randn((B, K * C, L), device=device, dtype=itype).requires_grad_()
  201. delta = (0.5 * torch.rand((B, K * C, L), device=device, dtype=itype)).requires_grad_()
  202. delta_bias = (0.5 * torch.rand((K * C), device=device, dtype=torch.float32)).requires_grad_()
  203. return u, delta, As, Bs, Cs, Ds, delta_bias
  204. def bench(func, xs, Warmup=30, NTimes=20):
  205. import time
  206. torch.cuda.synchronize()
  207. for r in range(Warmup):
  208. for x in xs:
  209. func(x)
  210. torch.cuda.synchronize()
  211. tim0 = time.time()
  212. for r in range(NTimes):
  213. for x in xs:
  214. func(x)
  215. torch.cuda.synchronize()
  216. return (time.time() - tim0) / NTimes
  217. def check():
  218. u, delta, As, Bs, Cs, Ds, delta_bias = params(1, 4, 16, 8, 512, itype=torch.float16)
  219. u1, delta1, As1, Bs1, Cs1, Ds1, delta_bias1 = [x.clone().detach().requires_grad_() for x in [u, delta, As, Bs, Cs, Ds, delta_bias]]
  220. # out_ref = selective_scan_fn(u, delta, As, Bs, Cs, Ds, delta_bias, True, backend="torch")
  221. out = selective_scan_fn(u1, delta1, As1, Bs1, Cs1, Ds1, delta_bias1, True, backend="oflex")
  222. out_ref = selective_scan_fn(u, delta, As, Bs, Cs, Ds, delta_bias, True, backend="mamba")
  223. print((out_ref - out).abs().max())
  224. out.sum().backward()
  225. out_ref.sum().backward()
  226. for x, y in zip([u, As, Bs, Cs, Ds, delta, delta_bias], [u1, As1, Bs1, Cs1, Ds1, delta1, delta_bias1]):
  227. print((x.grad - y.grad).abs().max())
  228. u, delta, As, Bs, Cs, Ds, delta_bias = params(128, 4, 96, 8, 56 * 56)
  229. print(bench(lambda x: selective_scan_fn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], True, backend="oflex"), [(u, delta, As, Bs, Cs, Ds, delta_bias),]))
  230. print(bench(lambda x: selective_scan_fn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], True, backend="mamba"), [(u, delta, As, Bs, Cs, Ds, delta_bias),]))
  231. print(bench(lambda x: selective_scan_fn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], True, backend="torch"), [(u, delta, As, Bs, Cs, Ds, delta_bias),]))
  232. check()