test_selective_scan.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. # Modified by $@#Anonymous#@$ #20240123
  2. # Copyright (C) 2023, Tri Dao, Albert Gu.
  3. import math
  4. import torch
  5. import torch.nn.functional as F
  6. import pytest
  7. import torch
  8. import torch.nn.functional as F
  9. from torch.cuda.amp import custom_bwd, custom_fwd
  10. from einops import rearrange, repeat
  11. import time
  12. from functools import partial
  13. SSOFLEX_FLOAT = True
  14. def build_selective_scan_fn(selective_scan_cuda: object = None, mode="mamba_ssm", tag=None):
  15. MODE = mode
  16. class SelectiveScanFn(torch.autograd.Function):
  17. @staticmethod
  18. def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1):
  19. if u.stride(-1) != 1:
  20. u = u.contiguous()
  21. if delta.stride(-1) != 1:
  22. delta = delta.contiguous()
  23. if D is not None:
  24. D = D.contiguous()
  25. if B.stride(-1) != 1:
  26. B = B.contiguous()
  27. if C.stride(-1) != 1:
  28. C = C.contiguous()
  29. if z is not None and z.stride(-1) != 1:
  30. z = z.contiguous()
  31. if B.dim() == 3:
  32. B = rearrange(B, "b dstate l -> b 1 dstate l")
  33. ctx.squeeze_B = True
  34. if C.dim() == 3:
  35. C = rearrange(C, "b dstate l -> b 1 dstate l")
  36. ctx.squeeze_C = True
  37. if D is not None and (D.dtype != torch.float):
  38. ctx._d_dtype = D.dtype
  39. D = D.float()
  40. if delta_bias is not None and (delta_bias.dtype != torch.float):
  41. ctx._delta_bias_dtype = delta_bias.dtype
  42. delta_bias = delta_bias.float()
  43. assert u.shape[1] % (B.shape[1] * nrows) == 0
  44. assert nrows in [1, 2, 3, 4] # 8+ is too slow to compile
  45. if backnrows > 0:
  46. assert u.shape[1] % (B.shape[1] * backnrows) == 0
  47. assert backnrows in [1, 2, 3, 4] # 8+ is too slow to compile
  48. else:
  49. backnrows = nrows
  50. ctx.backnrows = backnrows
  51. if MODE in ["mamba_ssm"]:
  52. out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
  53. elif MODE in ["ssoflex"]:
  54. out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, SSOFLEX_FLOAT)
  55. elif MODE in ["sscore"]:
  56. out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows)
  57. elif MODE in ["sstest"]:
  58. out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, nrows)
  59. elif MODE in ["sscorendstate"]:
  60. assert A.shape[-1] == 1 and B.shape[2] == 1 and C.shape[2] == 1
  61. A = A.view(-1)
  62. B = B.squeeze(2)
  63. C = C.squeeze(2)
  64. out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
  65. else:
  66. raise NotImplementedError
  67. ctx.delta_softplus = delta_softplus
  68. ctx.has_z = z is not None
  69. last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
  70. if not ctx.has_z:
  71. ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
  72. return out if not return_last_state else (out, last_state)
  73. else:
  74. ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
  75. if MODE in ["mamba_ssm", "sstest"]:
  76. out_z = rest[0]
  77. return out_z if not return_last_state else (out_z, last_state)
  78. elif MODE in ["sscore", "ssoflex"]:
  79. return out if not return_last_state else (out, last_state)
  80. @staticmethod
  81. def backward(ctx, dout, *args):
  82. if not ctx.has_z:
  83. u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
  84. z = None
  85. out = None
  86. else:
  87. u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
  88. if dout.stride(-1) != 1:
  89. dout = dout.contiguous()
  90. # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
  91. # backward of selective_scan_cuda with the backward of chunk).
  92. # Here we just pass in None and dz will be allocated in the C++ code.
  93. if MODE in ["mamba_ssm"]:
  94. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
  95. u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
  96. False # option to recompute out_z, not used here
  97. )
  98. elif MODE in ["sstest"]:
  99. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
  100. u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
  101. False, ctx.backnrows # option to recompute out_z, not used here
  102. )
  103. elif MODE in ["sscore", "ssoflex"]:
  104. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
  105. u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, ctx.backnrows
  106. )
  107. elif MODE in ["sscorendstate"]:
  108. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
  109. u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
  110. )
  111. dA = dA.unsqueeze(1)
  112. dB = dB.unsqueeze(2)
  113. dC = dC.unsqueeze(2)
  114. else:
  115. raise NotImplementedError
  116. dz = rest[0] if ctx.has_z else None
  117. dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
  118. dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
  119. _dD = None
  120. if D is not None:
  121. if dD.dtype != getattr(ctx, "_d_dtype", dD.dtype):
  122. _dD = dD.to(ctx._d_dtype)
  123. else:
  124. _dD = dD
  125. _ddelta_bias = None
  126. if delta_bias is not None:
  127. if ddelta_bias.dtype != getattr(ctx, "_delta_bias_dtype", ddelta_bias.dtype):
  128. _ddelta_bias = ddelta_bias.to(ctx._delta_bias_dtype)
  129. else:
  130. _ddelta_bias = ddelta_bias
  131. return (du, ddelta, dA, dB, dC,
  132. dD if D is not None else None,
  133. dz,
  134. ddelta_bias if delta_bias is not None else None,
  135. None, None, None, None)
  136. def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, nrows=1, backnrows=-1):
  137. """if return_last_state is True, returns (out, last_state)
  138. last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
  139. not considered in the backward pass.
  140. """
  141. outs = SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state, nrows, backnrows)
  142. if mode in ["ssoflex"]:
  143. return outs.to(u.dtype) if not return_last_state else (outs[0].to(u.dtype), outs[1])
  144. else:
  145. return outs
  146. selective_scan_fn.__repr__ = lambda *_ :f"selective_scan_fn | {mode} | {tag}"
  147. return selective_scan_fn
  148. def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
  149. return_last_state=False):
  150. """
  151. u: r(B D L)
  152. delta: r(B D L)
  153. A: c(D N) or r(D N)
  154. B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
  155. C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
  156. D: r(D)
  157. z: r(B D L)
  158. delta_bias: r(D), fp32
  159. out: r(B D L)
  160. last_state (optional): r(B D dstate) or c(B D dstate)
  161. """
  162. dtype_in = u.dtype
  163. u = u.float()
  164. delta = delta.float()
  165. if delta_bias is not None:
  166. delta = delta + delta_bias[..., None].float()
  167. if delta_softplus:
  168. delta = F.softplus(delta)
  169. batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
  170. is_variable_B = B.dim() >= 3
  171. is_variable_C = C.dim() >= 3
  172. if A.is_complex():
  173. if is_variable_B:
  174. B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
  175. if is_variable_C:
  176. C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
  177. else:
  178. B = B.float()
  179. C = C.float()
  180. x = A.new_zeros((batch, dim, dstate))
  181. ys = []
  182. deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
  183. if not is_variable_B:
  184. deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
  185. else:
  186. if B.dim() == 3:
  187. deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
  188. else:
  189. B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
  190. deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
  191. if is_variable_C and C.dim() == 4:
  192. C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
  193. last_state = None
  194. for i in range(u.shape[2]):
  195. x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
  196. if not is_variable_C:
  197. y = torch.einsum('bdn,dn->bd', x, C)
  198. else:
  199. if C.dim() == 3:
  200. y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
  201. else:
  202. y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
  203. if i == u.shape[2] - 1:
  204. last_state = x
  205. if y.is_complex():
  206. y = y.real * 2
  207. ys.append(y)
  208. y = torch.stack(ys, dim=2) # (batch dim L)
  209. out = y if D is None else y + u * rearrange(D, "d -> d 1")
  210. if z is not None:
  211. out = out * F.silu(z)
  212. out = out.to(dtype=dtype_in)
  213. return out if not return_last_state else (out, last_state)
  214. def selective_scan_ref_v2(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
  215. return_last_state=False):
  216. """
  217. u: r(B D L)
  218. delta: r(B D L)
  219. A: c(D N) or r(D N)
  220. B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
  221. C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
  222. D: r(D)
  223. z: r(B D L)
  224. delta_bias: r(D), fp32
  225. out: r(B D L)
  226. last_state (optional): r(B D dstate) or c(B D dstate)
  227. """
  228. dtype_in = u.dtype
  229. A = A.to(dtype_in)
  230. B = B.to(dtype_in)
  231. C = C.to(dtype_in)
  232. D = D.to(dtype_in) if D is not None else None
  233. z = z.to(dtype_in) if z is not None else None
  234. delta = delta.to(dtype_in) if delta is not None else None
  235. delta_bias = delta_bias.to(dtype_in) if delta_bias is not None else None
  236. if delta_bias is not None:
  237. delta = delta + delta_bias[..., None]
  238. if delta_softplus:
  239. delta = F.softplus(delta)
  240. batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
  241. is_variable_B = B.dim() >= 3
  242. is_variable_C = C.dim() >= 3
  243. if A.is_complex():
  244. if is_variable_B:
  245. B = torch.view_as_complex(rearrange(B, "... (L two) -> ... L two", two=2))
  246. if is_variable_C:
  247. C = torch.view_as_complex(rearrange(C, "... (L two) -> ... L two", two=2))
  248. x = A.new_zeros((batch, dim, dstate))
  249. ys = []
  250. deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
  251. if not is_variable_B:
  252. deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
  253. else:
  254. if B.dim() == 3:
  255. deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
  256. else:
  257. B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
  258. deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
  259. if is_variable_C and C.dim() == 4:
  260. C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
  261. last_state = None
  262. for i in range(u.shape[2]):
  263. x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
  264. if not is_variable_C:
  265. y = torch.einsum('bdn,dn->bd', x, C)
  266. else:
  267. if C.dim() == 3:
  268. y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
  269. else:
  270. y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
  271. if i == u.shape[2] - 1:
  272. last_state = x
  273. if y.is_complex():
  274. y = y.real * 2
  275. ys.append(y)
  276. y = torch.stack(ys, dim=2) # (batch dim L)
  277. out = y if D is None else y + u * rearrange(D, "d -> d 1")
  278. if z is not None:
  279. out = out * F.silu(z)
  280. out = out.to(dtype=dtype_in)
  281. return out if not return_last_state else (out, last_state.float())
  282. def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False, *args, **kwargs):
  283. return selective_scan_ref_v2(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
  284. # MODE = None
  285. # MODE = "mamba_ssm"
  286. # MODE = "sscore"
  287. # MODE = "ssoflex"
  288. # MODE = "sstest"
  289. # MODE = "mamba_ssm_sscore" # 1344 items pass
  290. # MODE = "mamba_ssm_sscorendstate" # 1344 items pass
  291. MODE = "mamba_ssm_ssoflex" # 1344 items pass
  292. if MODE in ["mamba_ssm"]:
  293. import selective_scan_cuda
  294. selective_scan_fn = build_selective_scan_fn(selective_scan_cuda, mode=MODE)
  295. selective_scan_ref = selective_scan_ref
  296. elif MODE in ["ssoflex"]:
  297. import selective_scan_cuda_oflex
  298. selective_scan_cuda = selective_scan_cuda_oflex
  299. selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_oflex, mode=MODE)
  300. selective_scan_ref = selective_scan_ref
  301. elif MODE in ["sscore"]:
  302. import selective_scan_cuda_core
  303. selective_scan_cuda = selective_scan_cuda_core
  304. selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_core, mode=MODE)
  305. selective_scan_ref = selective_scan_ref
  306. elif MODE in ["sstest"]:
  307. import selective_scan_cuda_test
  308. selective_scan_cuda = selective_scan_cuda_test
  309. selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_test, mode=MODE)
  310. selective_scan_ref = selective_scan_ref
  311. elif MODE in ["mamba_ssm_sscore"]:
  312. import selective_scan_cuda_core
  313. import selective_scan_cuda
  314. selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_core, mode="sscore")
  315. selective_scan_ref = build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm")
  316. elif MODE in ["mamba_ssm_sstest"]:
  317. import selective_scan_cuda_test
  318. import selective_scan_cuda
  319. selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_test, mode="sstest")
  320. selective_scan_ref = build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm")
  321. elif MODE in ["mamba_ssm_sscorendstate"]:
  322. import selective_scan_cuda_core
  323. import selective_scan_cuda
  324. selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_core, mode="sscorendstate")
  325. selective_scan_ref = build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm")
  326. elif MODE in ["mamba_ssm_ssoflex"]:
  327. import selective_scan_cuda_oflex
  328. import selective_scan_cuda
  329. selective_scan_fn = build_selective_scan_fn(selective_scan_cuda_oflex, mode="ssoflex")
  330. selective_scan_ref = build_selective_scan_fn(selective_scan_cuda, mode="mamba_ssm")
  331. else:
  332. selective_scan_cuda = None
  333. print("use MODE:", MODE)
  334. DSTATE = [1]
  335. DIM = [768]
  336. DIM1 = [768]
  337. DIM1 = [24]
  338. BATCHSIZE = [2]
  339. # DSTATE = [1] if MODE in ["mamba_ssm_sscorendstate", "sscorendstate"] else [8]
  340. NROWS = [1,2,3,4]
  341. IDTYPE = MODE in [None]
  342. # @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
  343. @pytest.mark.parametrize('wtype', [torch.float32])
  344. @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
  345. @pytest.mark.parametrize('seqlen', [64, 128, 256, 512, 1024, 2048, 4096])
  346. @pytest.mark.parametrize("return_last_state", [True])
  347. @pytest.mark.parametrize('has_delta_bias', [False, True])
  348. @pytest.mark.parametrize('delta_softplus', [False, True])
  349. # @pytest.mark.parametrize('has_z', [False, True])
  350. @pytest.mark.parametrize('has_z', [False])
  351. @pytest.mark.parametrize('has_D', [False, True])
  352. @pytest.mark.parametrize("varBC_groups", [1, 2])
  353. # @pytest.mark.parametrize("is_variable_C", [False, True])
  354. @pytest.mark.parametrize("is_variable_C", [True])
  355. # @pytest.mark.parametrize("is_variable_B", [False, True])
  356. @pytest.mark.parametrize("is_variable_B", [True])
  357. @pytest.mark.parametrize("nrows", NROWS)
  358. @pytest.mark.parametrize("batch_size", BATCHSIZE)
  359. @pytest.mark.parametrize("dim", DIM)
  360. @pytest.mark.parametrize("dim1", DIM1)
  361. @pytest.mark.parametrize("dstate", DSTATE)
  362. def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias,
  363. delta_softplus, return_last_state, seqlen, itype, wtype, nrows, batch_size, dim, dim1, dstate):
  364. wtype = itype if IDTYPE else wtype
  365. print(f'method: {selective_scan_cuda}')
  366. if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
  367. pytest.skip() # This config is not applicable
  368. device = 'cuda'
  369. rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
  370. if itype == torch.bfloat16:
  371. rtol, atol = 3e-2, 5e-2
  372. rtolw, atolw = (1e-3, 1e-3)
  373. if has_z: # If we have z, the errors on the weights seem higher
  374. rtolw = max(rtolw, rtol)
  375. atolw = max(atolw, atol)
  376. # set seed
  377. torch.random.manual_seed(0)
  378. # batch_size = 2
  379. # dim = 24
  380. # dstate = 8
  381. is_complex = wtype == torch.complex64
  382. A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
  383. if not is_variable_B:
  384. B_shape = (dim, dstate)
  385. elif varBC_groups == 1:
  386. B_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
  387. else:
  388. B_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
  389. B = torch.randn(*B_shape, device=device, dtype=wtype if not is_variable_B else itype,
  390. requires_grad=True)
  391. if not is_variable_C:
  392. C_shape = (dim, dstate)
  393. elif varBC_groups == 1:
  394. C_shape = (batch_size, dstate, seqlen if not is_complex else seqlen * 2)
  395. else:
  396. C_shape = (batch_size, varBC_groups, dstate, seqlen if not is_complex else seqlen * 2)
  397. C = torch.randn(*C_shape, device=device, dtype=wtype if not is_variable_C else itype,
  398. requires_grad=True)
  399. if has_D:
  400. D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
  401. else:
  402. D = None
  403. if has_z:
  404. z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
  405. else:
  406. z = None
  407. if has_delta_bias:
  408. delta_bias = (0.5 * torch.rand(dim1, device=device, dtype=torch.float32)).requires_grad_()
  409. else:
  410. delta_bias = None
  411. u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype, requires_grad=True)
  412. delta = (0.5 * torch.rand(batch_size, dim1, seqlen, device=device, dtype=itype)).requires_grad_()
  413. A_ref = A.detach().clone().requires_grad_()
  414. B_ref = B.detach().clone().requires_grad_()
  415. C_ref = C.detach().clone().requires_grad_()
  416. D_ref = D.detach().clone().requires_grad_() if D is not None else None
  417. z_ref = z.detach().clone().requires_grad_() if z is not None else None
  418. u_ref = u.detach().clone().requires_grad_()
  419. delta_ref = delta.detach().clone().requires_grad_()
  420. delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
  421. if dim1 != dim:
  422. assert dim % dim1 == 0
  423. delta_ref = delta.unsqueeze(2).repeat(1, 1, dim // dim1, 1).contiguous().flatten(1, 2)
  424. delta_ref = delta_ref.detach().clone().requires_grad_()
  425. delta_bias_ref = delta_bias.unsqueeze(1).repeat(1, dim // dim1).view(-1).detach().clone().requires_grad_() if delta_bias is not None else None
  426. out, *rest = selective_scan_fn(
  427. u, delta, A, B, C, D, z=z,
  428. delta_bias=delta_bias, delta_softplus=delta_softplus,
  429. return_last_state=return_last_state, nrows=nrows
  430. )
  431. if return_last_state:
  432. state = rest[0]
  433. out_ref, *rest = selective_scan_ref(
  434. u_ref, delta_ref, A_ref, B_ref, C_ref, D_ref, z=z_ref,
  435. delta_bias=delta_bias_ref, delta_softplus=delta_softplus,
  436. return_last_state=return_last_state
  437. )
  438. if return_last_state:
  439. state_ref = rest[0]
  440. # dA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
  441. # dt_u = delta * u
  442. print(f'Output max diff: {(out - out_ref).abs().max().item()}')
  443. print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
  444. assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
  445. if return_last_state:
  446. print(f'State max diff: {(state - state_ref).abs().max().item()}')
  447. assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
  448. g = torch.randn_like(out)
  449. out_ref.backward(g)
  450. out.backward(g)
  451. print(f'du max diff: {(u.grad - u_ref.grad).abs().max().item()}')
  452. print(f'dA max diff: {(A.grad - A_ref.grad).abs().max().item()}')
  453. print(f'dB max diff: {(B.grad - B_ref.grad).abs().max().item()}')
  454. print(f'dC max diff: {(C.grad - C_ref.grad).abs().max().item()}')
  455. if has_D:
  456. print(f'dD max diff: {(D.grad - D_ref.grad).abs().max().item()}')
  457. assert torch.allclose(D.grad, D_ref.grad, rtol=rtolw, atol=atolw)
  458. if has_z:
  459. print(f'dz max diff: {(z.grad - z_ref.grad).abs().max().item()}')
  460. assert torch.allclose(z.grad, z_ref.grad, rtol=rtolw, atol=atolw)
  461. assert torch.allclose(u.grad, u_ref.grad.to(dtype=itype), rtol=rtol * 2, atol=atol * 2)
  462. assert torch.allclose(A.grad, A_ref.grad, rtol=rtolw, atol=atolw * 5)
  463. assert torch.allclose(B.grad, B_ref.grad, rtol=rtolw if not is_variable_B else rtol,
  464. atol=atolw if not is_variable_B else atol)
  465. assert torch.allclose(C.grad, C_ref.grad, rtol=rtolw if not is_variable_C else rtol,
  466. atol=atolw if not is_variable_C else atol)
  467. if dim == dim1:
  468. print(f'ddelta max diff: {(delta.grad - delta_ref.grad).abs().max().item()}')
  469. assert torch.allclose(delta.grad, delta_ref.grad.to(dtype=itype), rtol=rtol * 5, atol=atol * 10)
  470. if has_delta_bias:
  471. print(f'ddelta_bias max diff: {(delta_bias.grad - delta_bias_ref.grad).abs().max().item()}')
  472. assert torch.allclose(delta_bias.grad, delta_bias_ref.grad, rtol=rtolw, atol=atolw)
  473. else:
  474. dgr = delta_ref.grad.view(delta_ref.grad.shape[0], -1, dim // dim1, delta_ref.grad.shape[-1]).sum(2)
  475. print(f'ddelta max diff: {(delta.grad - dgr).abs().max().item()}')
  476. assert torch.allclose(delta.grad, dgr.to(dtype=itype), rtol=rtol * 5, atol=atol * 10), breakpoint()
  477. if has_delta_bias:
  478. dbr = delta_bias_ref.grad.view(-1, dim // dim1).sum(-1)
  479. print(f'ddelta_bias max diff: {(delta_bias.grad - dbr).abs().max().item()}')
  480. assert torch.allclose(delta_bias.grad, dbr, rtol=rtolw, atol=atolw)
  481. # test_selective_scan(True, True, 2, True, False, True, True, True, 64, torch.float32, torch.float32, 1, 2, 24, 24, 1)
  482. # test_selective_scan(True, True, 2, True, False, True, True, True, 64, torch.float32, torch.float32, 1, 2, 24, 12, 1)