| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- # Copyright (c) 2024, Albert Gu and Tri Dao.
- """Minimal implementation of SSD.
- This is the same as Listing 1 from the paper.
- """
- import torch
- import torch.nn.functional as F
- from einops import rearrange, repeat
def segsum_unstable(x):
    """Segment sums as a difference of prefix sums (naive version).

    For ``x`` of shape (..., T) the result has shape (..., T, T). Entry
    [..., i, j] equals x[..., j+1] + ... + x[..., i] when j <= i (so the
    diagonal is 0) and -inf strictly above the diagonal.
    """
    length = x.size(-1)
    prefix = x.cumsum(dim=-1)
    # pairwise[..., i, j] = prefix[..., i] - prefix[..., j]
    pairwise = prefix.unsqueeze(-1) - prefix.unsqueeze(-2)
    lower = torch.ones(length, length, dtype=torch.bool, device=x.device).tril()
    return pairwise.masked_fill(lower.logical_not(), -torch.inf)
def segsum(x):
    """Segment sums accumulated directly from a masked copy (stable version).

    Produces the same (..., T, T) tensor as ``segsum_unstable``: entry
    [..., i, j] is x[..., j+1] + ... + x[..., i] for j <= i and -inf above
    the diagonal. Instead of subtracting two prefix sums, only the entries
    x[..., j+1..i] are summed for each column j.
    """
    length = x.size(-1)
    # grid[..., d, e] = x[..., d] for every column e
    grid = x.unsqueeze(-1).expand(*x.shape, length)
    strictly_lower = torch.ones(length, length, dtype=torch.bool, device=x.device).tril(diagonal=-1)
    # Keep x[d] only where d > e, then accumulate down each column.
    sums = grid.masked_fill(strictly_lower.logical_not(), 0).cumsum(dim=-2)
    lower = torch.ones(length, length, dtype=torch.bool, device=x.device).tril()
    return sums.masked_fill(lower.logical_not(), -torch.inf)
def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    """
    Chunked SSD scan: the sequence is split into chunks of ``block_len``.
    Intra-chunk outputs use a masked decay matrix; inter-chunk outputs come
    from states passed between chunk boundaries.

    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads) per-step log-decays; exp of their
            segment sums forms the decay matrix.
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
        block_len: chunk length; must divide ``length``.
        initial_states: optional (batch, 1, n_heads, d_head, d_state) state
            placed before the first chunk (note the singleton chunk axis);
            zeros when None.
    Return:
        Y: (batch, length, n_heads, d_head)
        final_state: (batch, n_heads, d_head, d_state), the state after the
            last chunk.
    """
    # All inputs must share one dtype, and the length must split into whole chunks.
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0
    # Rearrange into blocks/chunks
    X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
    A = rearrange(A, "b c l h -> b h c l")
    # Running log-decay within each chunk.
    A_cumsum = torch.cumsum(A, dim=-1)
    # 1. Compute the output for each intra-chunk (diagonal blocks)
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
    # 2. Compute the state at the end of each chunk from that chunk's inputs
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    # decay_states decays each position to the last position of its chunk.
    decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    # Prepend the initial state along the chunk axis: (b, c+1, h, p, n).
    states = torch.cat([initial_states, states], dim=1)
    # Decay between chunk boundaries, built from each chunk's total log-decay.
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    # Entry c is the state entering chunk c; the last entry is the final state.
    states, final_state = new_states[:, :-1], new_states[:, -1]
    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
    # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
    Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p")
    return Y, final_state
- # =====================================
- # add below 2 lines in `_mamba_chunk_scan_combined_fwd`...:
- # tuple(...)
def mamba_chunk_scan_combined_torch(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    seq_idx=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    return_final_states=False,
):
    """PyTorch reference for ``mamba_chunk_scan_combined``, built on ``ssd_minimal_discrete``.

    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        chunk_size: int >= 1; halved until it divides ``seqlen``.
        D: (nheads, headdim) or (nheads,) skip weight; adds ``D * x``.
        z: (batch, seqlen, nheads, headdim) gate; output becomes ``y * silu(z)``.
        dt_bias: (nheads,)
        initial_states: (batch, nheads, headdim, dstate). A tensor that already
            has the singleton chunk axis (batch, 1, nheads, headdim, dstate)
            is also accepted.
        seq_idx: (batch, seqlen); not supported here, must be None.
        dt_softplus: Whether to apply softplus to dt
        dt_limit: (min, max) range that dt is clamped to after bias/softplus.
        return_final_states: also return the final state.
    Return:
        out: (batch, seqlen, nheads, headdim); ``(out, final_states)`` with
        final_states of shape (batch, nheads, headdim, dstate) when
        ``return_final_states`` is True.
    Raises:
        ValueError: if ``chunk_size`` < 1.
        NotImplementedError: if ``seq_idx`` is not None.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    if seq_idx is not None:
        # This reference path has no seq_idx handling; fail loudly rather
        # than return results that silently ignore it.
        raise NotImplementedError("seq_idx is not supported by mamba_chunk_scan_combined_torch")

    seqlen, ngroups = B.shape[1], B.shape[2]
    nheads = x.shape[2]

    # ssd_minimal_discrete requires seqlen to be a multiple of the chunk
    # length; halving stops at chunk_size == 1 at the latest.
    while seqlen % chunk_size != 0:
        chunk_size >>= 1

    if nheads != ngroups:
        assert nheads % ngroups == 0
        # Head h reads group h // (nheads // ngroups). repeat_interleave gives
        # that layout and, unlike .view(), also works for non-contiguous B / C.
        heads_per_group = nheads // ngroups
        B = B.repeat_interleave(heads_per_group, dim=2)
        C = C.repeat_interleave(heads_per_group, dim=2)

    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)
    # Honour dt_limit (previously ignored). The default (0, inf) only clips
    # negative steps.
    dt = dt.clamp(min=dt_limit[0], max=dt_limit[1])

    u = x * dt.unsqueeze(-1)  # step-scaled input
    w = A * dt  # per-step log-decay

    # ssd_minimal_discrete asserts that X, A, B and C share one dtype; promote
    # mixed inputs (e.g. bf16 x/B/C with fp32 dt/A) to a common dtype.
    compute_dtype = torch.promote_types(
        torch.promote_types(u.dtype, w.dtype), torch.promote_types(B.dtype, C.dtype)
    )
    u, w, B, C = (t.to(compute_dtype) for t in (u, w, B, C))

    if initial_states is not None:
        # ssd_minimal_discrete concatenates along a chunk axis (dim 1), so the
        # documented (batch, nheads, headdim, dstate) layout needs that axis.
        if initial_states.dim() == 4:
            initial_states = initial_states.unsqueeze(1)
        initial_states = initial_states.to(compute_dtype)

    y, final_state = ssd_minimal_discrete(u, w, B, C, block_len=chunk_size, initial_states=initial_states)
    if D is not None:
        # (nheads,) -> (nheads, 1) and (nheads, headdim) both broadcast over x.
        y = y + D.reshape(y.shape[-2], -1) * x
    if z is not None:
        y = y * (z * torch.sigmoid(z))  # SiLU gate
    return (y, final_state) if return_final_states else y
# Use the Triton kernels only when both Triton and the ``ssd_combined``
# module import cleanly; otherwise fall back to the PyTorch implementation.
try:
    import triton  # noqa: F401  (imported only to probe availability)
except ImportError:
    WITH_TRITON = False
else:
    WITH_TRITON = True

if WITH_TRITON:
    try:
        from .ssd_combined import mamba_chunk_scan_combined
    except ImportError:
        # The relative import fails when this file is not loaded as part of a package.
        try:
            from ssd_combined import mamba_chunk_scan_combined
        except ImportError as err:
            import warnings

            warnings.warn(
                f"Triton is installed but ssd_combined could not be imported ({err}); "
                "falling back to the PyTorch implementation.",
                RuntimeWarning,
            )
            WITH_TRITON = False
def selective_scan_chunk_fn(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, backend=None):
    """Dispatch the chunked SSD scan to the Triton kernel or the PyTorch reference.

    All arguments except ``backend`` are forwarded unchanged; see
    ``mamba_chunk_scan_combined_torch`` for their shapes.

    backend: ``"torch"`` forces the PyTorch implementation. Any other value
        (including ``None`` and ``"triton"``) selects the Triton kernel when
        it was imported (``WITH_TRITON``). Otherwise it falls back to PyTorch
        without raising.
    """
    use_torch = backend == "torch" or not WITH_TRITON
    fn = mamba_chunk_scan_combined_torch if use_torch else mamba_chunk_scan_combined
    # Forward optionals by keyword so a different positional order in either
    # implementation's signature cannot silently misroute arguments.
    return fn(
        x,
        dt,
        A,
        B,
        C,
        chunk_size,
        D=D,
        z=z,
        dt_bias=dt_bias,
        initial_states=initial_states,
        seq_idx=seq_idx,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
        return_final_states=return_final_states,
    )
# Simple test
def test_correctness():
    """Compare the PyTorch reference against the Triton kernel on random inputs.

    Prints the maximum absolute difference between the two outputs. The
    comparison requires Triton and a CUDA device. Without Triton,
    ``selective_scan_chunk_fn`` would run the PyTorch path for both calls,
    so the check is skipped instead of reporting a meaningless zero.
    """
    if not (WITH_TRITON and torch.cuda.is_available()):
        print("Skipping test_correctness: requires Triton and a CUDA device.")
        return

    torch.manual_seed(42)
    ## Dimensions
    # Denoted (B, T, Q, D, P) in the paper
    batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
    nheads = dim // headdim  # (H) in the paper
    ngroups = nheads  # (G) in the paper; set to 1 to share B/C across all heads
    dstate = 64  # (N) in the paper
    dtype = torch.float32
    device = "cuda"
    x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
    dt = F.softplus(torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4).requires_grad_()
    A = (-torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))).requires_grad_()
    B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
    C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
    D = torch.randn(nheads, dtype=dtype, device=device)
    yto = selective_scan_chunk_fn(x, dt, A, B, C, chunk_size=chunk_size, D=D, backend="torch")
    ytr = selective_scan_chunk_fn(x, dt, A, B, C, chunk_size=chunk_size, D=D, backend="triton")
    print((yto - ytr).abs().max())
if __name__ == "__main__":
    # Run the torch-vs-Triton comparison when executed as a script.
    test_correctness()
|