ssd_minimal.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Copyright (c) 2024, Albert Gu and Tri Dao.
  2. """Minimal implementation of SSD.
  3. This is the same as Listing 1 from the paper.
  4. """
  5. import torch
  6. import torch.nn.functional as F
  7. from einops import rearrange, repeat
  8. def segsum_unstable(x):
  9. """Naive segment sum calculation."""
  10. T = x.size(-1)
  11. x_cumsum = torch.cumsum(x, dim=-1)
  12. x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
  13. mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
  14. x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
  15. return x_segsum
  16. def segsum(x):
  17. """More stable segment sum calculation."""
  18. T = x.size(-1)
  19. x = repeat(x, "... d -> ... d e", e=T)
  20. mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
  21. x = x.masked_fill(~mask, 0)
  22. x_segsum = torch.cumsum(x, dim=-2)
  23. mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
  24. x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
  25. return x_segsum
  26. def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
  27. """
  28. Arguments:
  29. X: (batch, length, n_heads, d_head)
  30. A: (batch, length, n_heads)
  31. B: (batch, length, n_heads, d_state)
  32. C: (batch, length, n_heads, d_state)
  33. Return:
  34. Y: (batch, length, n_heads, d_head)
  35. """
  36. assert X.dtype == A.dtype == B.dtype == C.dtype
  37. assert X.shape[1] % block_len == 0
  38. # Rearrange into blocks/chunks
  39. X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
  40. A = rearrange(A, "b c l h -> b h c l")
  41. A_cumsum = torch.cumsum(A, dim=-1)
  42. # 1. Compute the output for each intra-chunk (diagonal blocks)
  43. L = torch.exp(segsum(A))
  44. Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
  45. # 2. Compute the state for each intra-chunk
  46. # (right term of low-rank factorization of off-diagonal blocks; B terms)
  47. decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
  48. states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
  49. # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
  50. # (middle term of factorization of off-diag blocks; A terms)
  51. if initial_states is None:
  52. initial_states = torch.zeros_like(states[:, :1])
  53. states = torch.cat([initial_states, states], dim=1)
  54. decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
  55. new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
  56. states, final_state = new_states[:, :-1], new_states[:, -1]
  57. # 4. Compute state -> output conversion per chunk
  58. # (left term of low-rank factorization of off-diagonal blocks; C terms)
  59. state_decay_out = torch.exp(A_cumsum)
  60. Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
  61. # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
  62. Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p")
  63. return Y, final_state
  64. # =====================================
  65. # add below 2 lines in `_mamba_chunk_scan_combined_fwd`...:
  66. # tuple(...)
  67. def mamba_chunk_scan_combined_torch(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False):
  68. """
  69. Argument:
  70. x: (batch, seqlen, nheads, headdim)
  71. dt: (batch, seqlen, nheads)
  72. A: (nheads)
  73. B: (batch, seqlen, ngroups, dstate)
  74. C: (batch, seqlen, ngroups, dstate)
  75. chunk_size: int
  76. D: (nheads, headdim) or (nheads,)
  77. z: (batch, seqlen, nheads, headdim)
  78. dt_bias: (nheads,)
  79. initial_states: (batch, nheads, headdim, dstate)
  80. seq_idx: (batch, seqlen)
  81. dt_softplus: Whether to apply softplus to dt
  82. Return:
  83. out: (batch, seqlen, nheads, headdim)
  84. """
  85. batch, seqlen, ngroups, dstate = B.shape
  86. nheads, headdim = x.shape[2:]
  87. while seqlen % chunk_size != 0:
  88. chunk_size = chunk_size >> 1
  89. if nheads != ngroups:
  90. assert nheads % ngroups == 0
  91. B = B.view(batch, seqlen, ngroups, 1, dstate).repeat(1, 1, 1, nheads // ngroups, 1).view(batch, seqlen, nheads, dstate)
  92. C = C.view(batch, seqlen, ngroups, 1, dstate).repeat(1, 1, 1, nheads // ngroups, 1).view(batch, seqlen, nheads, dstate)
  93. if dt_bias is not None:
  94. dt = dt + dt_bias
  95. if dt_softplus:
  96. dt = F.softplus(dt)
  97. u = x * dt.unsqueeze(-1)
  98. w = A * dt
  99. y, state = ssd_minimal_discrete(u, w, B, C, block_len=chunk_size, initial_states=initial_states)
  100. if D is not None:
  101. y = y + D.view(y.shape[-2], -1) * x
  102. if z is not None:
  103. y = y * (z * torch.sigmoid(z))
  104. return (y, state) if return_final_states else y
  105. WITH_TRITON = True
  106. # WITH_TRITON = False
  107. try:
  108. import triton
  109. except ImportError:
  110. WITH_TRITON = False
  111. if WITH_TRITON:
  112. try:
  113. from .ssd_combined import mamba_chunk_scan_combined
  114. except ImportError:
  115. from ssd_combined import mamba_chunk_scan_combined
  116. def selective_scan_chunk_fn(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, backend=None):
  117. fn = mamba_chunk_scan_combined_torch if backend == "torch" or (not WITH_TRITON) else mamba_chunk_scan_combined
  118. return fn(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)
  119. # Simple test
  120. def test_correctness():
  121. torch.manual_seed(42)
  122. ## Dimensions
  123. # Denoted (B, T, Q, D, P) in the paper
  124. batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
  125. nheads = dim // headdim # (H) in the paper
  126. ngroups = 1 # (G) in the paper
  127. ngroups = nheads # (G) in the paper
  128. dstate = 64 # (N) in the paper
  129. dtype = torch.float32
  130. device = "cuda"
  131. x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
  132. dt = F.softplus(torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4).requires_grad_()
  133. A = (-torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))).requires_grad_()
  134. B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
  135. C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
  136. D = torch.randn(nheads, dtype=dtype, device=device)
  137. yto = selective_scan_chunk_fn(x, dt, A, B, C, chunk_size=64, D=D, backend="torch")
  138. ytr = selective_scan_chunk_fn(x, dt, A, B, C, chunk_size=64, D=D, backend="triton")
  139. print((yto - ytr).abs().max())
  140. breakpoint()
  141. ...
  142. if __name__ == "__main__":
  143. test_correctness()
  144. breakpoint()