# Copyright (c) 2024, Albert Gu and Tri Dao. """Minimal implementation of SSD. This is the same as Listing 1 from the paper. """ import torch import torch.nn.functional as F from einops import rearrange, repeat def segsum_unstable(x): """Naive segment sum calculation.""" T = x.size(-1) x_cumsum = torch.cumsum(x, dim=-1) x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :] mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) x_segsum = x_segsum.masked_fill(~mask, -torch.inf) return x_segsum def segsum(x): """More stable segment sum calculation.""" T = x.size(-1) x = repeat(x, "... d -> ... d e", e=T) mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) x = x.masked_fill(~mask, 0) x_segsum = torch.cumsum(x, dim=-2) mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) x_segsum = x_segsum.masked_fill(~mask, -torch.inf) return x_segsum def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): """ Arguments: X: (batch, length, n_heads, d_head) A: (batch, length, n_heads) B: (batch, length, n_heads, d_state) C: (batch, length, n_heads, d_state) Return: Y: (batch, length, n_heads, d_head) """ assert X.dtype == A.dtype == B.dtype == C.dtype assert X.shape[1] % block_len == 0 # Rearrange into blocks/chunks X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)] A = rearrange(A, "b c l h -> b h c l") A_cumsum = torch.cumsum(A, dim=-1) # 1. Compute the output for each intra-chunk (diagonal blocks) L = torch.exp(segsum(A)) Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X) # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X) # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) if initial_states is None: initial_states = torch.zeros_like(states[:, :1]) states = torch.cat([initial_states, states], dim=1) decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) states, final_state = new_states[:, :-1], new_states[:, -1] # 4. Compute state -> output conversion per chunk # (left term of low-rank factorization of off-diagonal blocks; C terms) state_decay_out = torch.exp(A_cumsum) Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out) # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p") return Y, final_state # ===================================== # add below 2 lines in `_mamba_chunk_scan_combined_fwd`...: # tuple(...) def mamba_chunk_scan_combined_torch(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False): """ Argument: x: (batch, seqlen, nheads, headdim) dt: (batch, seqlen, nheads) A: (nheads) B: (batch, seqlen, ngroups, dstate) C: (batch, seqlen, ngroups, dstate) chunk_size: int D: (nheads, headdim) or (nheads,) z: (batch, seqlen, nheads, headdim) dt_bias: (nheads,) initial_states: (batch, nheads, headdim, dstate) seq_idx: (batch, seqlen) dt_softplus: Whether to apply softplus to dt Return: out: (batch, seqlen, nheads, headdim) """ batch, seqlen, ngroups, dstate = B.shape nheads, headdim = x.shape[2:] while seqlen % chunk_size != 0: chunk_size = chunk_size >> 1 if nheads != ngroups: assert nheads % ngroups == 0 B = B.view(batch, seqlen, ngroups, 1, dstate).repeat(1, 1, 1, nheads // ngroups, 1).view(batch, seqlen, nheads, dstate) C = C.view(batch, seqlen, ngroups, 1, dstate).repeat(1, 1, 1, nheads // ngroups, 1).view(batch, seqlen, nheads, dstate) if dt_bias is not None: dt = dt + dt_bias if dt_softplus: dt = F.softplus(dt) u = x * dt.unsqueeze(-1) w = A * dt y, state = ssd_minimal_discrete(u, w, B, C, block_len=chunk_size, initial_states=initial_states) if D is not None: y = y + D.view(y.shape[-2], -1) * x if z is not None: y = y * (z * torch.sigmoid(z)) return (y, state) if return_final_states else y WITH_TRITON = True # WITH_TRITON = False try: import triton except ImportError: WITH_TRITON = False if WITH_TRITON: try: from .ssd_combined import mamba_chunk_scan_combined except ImportError: from ssd_combined import mamba_chunk_scan_combined def selective_scan_chunk_fn(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, backend=None): fn = mamba_chunk_scan_combined_torch if backend == "torch" or (not WITH_TRITON) else mamba_chunk_scan_combined return fn(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states) # Simple test def test_correctness(): torch.manual_seed(42) ## Dimensions # Denoted (B, T, Q, D, P) in the paper batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64 nheads = dim // headdim # (H) in the paper ngroups = 1 # (G) in the paper ngroups = nheads # (G) in the paper dstate = 64 # (N) in the paper dtype = torch.float32 device = "cuda" x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device) dt = F.softplus(torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4).requires_grad_() A = (-torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))).requires_grad_() B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device) C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device) D = torch.randn(nheads, dtype=dtype, device=device) yto = selective_scan_chunk_fn(x, dt, A, B, C, chunk_size=64, D=D, backend="torch") ytr = selective_scan_chunk_fn(x, dt, A, B, C, chunk_size=64, D=D, backend="triton") print((yto - ytr).abs().max()) breakpoint() ... if __name__ == "__main__": test_correctness() breakpoint()