| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263 |
# Copyright (c) 2024, Tri Dao, Albert Gu.
"""Triton kernel for the single-step selective state update, its launcher, and a PyTorch reference.

Targets triton==2.1.0, triton==2.2.0, or triton==2.3.0.
"""
- import math
- import torch
- import torch.nn.functional as F
- import triton
- import triton.language as tl
- from einops import rearrange, repeat
# One step of the selective-scan recurrence, computed tile by tile.
#
# Launch grid (see `grid` in selective_state_update):
#   axis 0 -> block of BLOCK_SIZE_M rows along `dim`
#   axis 1 -> batch index
#   axis 2 -> head index
# Each program loads a (BLOCK_SIZE_M, BLOCK_SIZE_DSTATE) tile of `state`, updates it
#   state = state * exp(A * dt) + (B * dt) * x
# writes it back in place, and stores BLOCK_SIZE_M outputs
#   out = sum_n(state * C) [+ x * D] [* silu(z)].
#
# HAS_DT_BIAS / HAS_D / HAS_Z are derived from whether the pointer argument is None,
# so optional tensors that were not supplied are never dereferenced.
# BLOCK_SIZE_DSTATE rounds dstate up to a power of two (tl.arange needs a power-of-two
# extent); every access along that axis is masked with `offs_n < dstate`.
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
    # Matrix dimensions
    batch, nheads, dim, dstate, nheads_ngroups_ratio,
    # Strides
    stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,
    stride_x_batch, stride_x_head, stride_x_dim,
    stride_dt_batch, stride_dt_head, stride_dt_dim,
    stride_dt_bias_head, stride_dt_bias_dim,
    stride_A_head, stride_A_dim, stride_A_dstate,
    stride_B_batch, stride_B_group, stride_B_dstate,
    stride_C_batch, stride_C_group, stride_C_dstate,
    stride_D_head, stride_D_dim,
    stride_z_batch, stride_z_head, stride_z_dim,
    stride_out_batch, stride_out_head, stride_out_dim,
    # Meta-parameters
    # DT_SOFTPLUS: apply softplus to dt (after the optional bias).
    DT_SOFTPLUS: tl.constexpr,
    # TIE_HDIM: A, dt (and dt_bias) hold a single value per head (zero strides along
    # dim/dstate, detected by the launcher), so scalars are loaded instead of tiles.
    TIE_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    # Move every base pointer to this program's (batch, head) slice.
    state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
    if HAS_DT_BIAS:
        dt_bias_ptr += pid_h * stride_dt_bias_head
    A_ptr += pid_h * stride_A_head
    # B and C are shared by groups of heads: nheads_ngroups_ratio consecutive heads
    # read the same group row.
    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
    # offs_m indexes rows along `dim`, offs_n indexes columns along `dstate`.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    if HAS_D:
        D_ptr += pid_h * stride_D_head
    A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim
    # Out-of-range rows (last dim block) and padded dstate columns load as 0.
    state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if not TIE_HDIM:
        # General case: dt is a vector over `dim`, A is a (dim, dstate) tile.
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if DT_SOFTPLUS:
            # softplus(dt) = log(1 + exp(dt)). Above 20 it equals dt to float32 precision,
            # so dt is passed through there (discarding any overflowed exp from the other branch).
            dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
        A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
        dA = tl.exp(A * dt[:, None])
    else:
        # Tied case: a single dt / A value for the whole head.
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt) # scalar, not a matrix
    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if not TIE_HDIM:
        dB = B[None, :] * dt[:, None]
    else:
        dB = B * dt # vector of size (dstate,)
    # Recurrence update, written back to `state` in place.
    state = state * dA + dB * x[:, None]
    tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
    # Readout: contract the updated state with C over dstate.
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        # SiLU gate: z * sigmoid(z).
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)
def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
    """
    Run one selective-scan step with the Triton kernel, updating ``state`` in place.

    Per element::

        dt'   = dt + dt_bias              (bias optional)
        dt'   = softplus(dt')             (if dt_softplus)
        state = state * exp(A * dt') + (B * dt') * x
        out   = sum_n(state * C) [+ x * D] [* silu(z)]

    B and C are shared across heads: each of the ``ngroups`` groups serves
    ``nheads // ngroups`` consecutive heads.

    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate); modified in place
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim); optional
        z: (batch, dim) or (batch, nheads, dim); optional
        dt_bias: (dim,) or (nheads, dim); optional
        dt_softplus: apply softplus to dt after adding dt_bias
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    Raises:
        AssertionError: if the input shapes are inconsistent.
    """
    has_heads = state.dim() > 3
    # Normalize the head-less layout to a single head.
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)

    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)

    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)

    # Absent optional tensors still occupy their stride slots in the kernel's
    # positional signature; pass zeros (the kernel skips them via HAS_* flags).
    # Note: `*(a, b) if cond else 0` would unpack `0` and raise TypeError, so the
    # fallback must be a tuple of the right length.
    dt_bias_strides = (dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0)
    D_strides = (D.stride(0), D.stride(1)) if D is not None else (0, 0)
    z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)

    # We don't want autotune since it will overwrite the state (each trial launch
    # would mutate it). We instead tune by hand on dstate.
    if dstate <= 16:
        BLOCK_SIZE_M, num_warps = 32, 4
    elif dstate <= 32:
        BLOCK_SIZE_M, num_warps = 16, 4
    elif dstate <= 64:
        BLOCK_SIZE_M, num_warps = 8, 4
    elif dstate <= 128:
        BLOCK_SIZE_M, num_warps = 4, 4
    else:
        BLOCK_SIZE_M, num_warps = 4, 8

    # Scalar-per-head fast path: A, dt (and dt_bias, if given) are broadcast with
    # zero strides along dim/dstate.
    tie_hdim = (
        A.stride(-1) == 0
        and A.stride(-2) == 0
        and dt.stride(-1) == 0
        and (dt_bias is None or dt_bias.stride(-1) == 0)
    )

    with torch.cuda.device(x.device.index):
        _selective_scan_update_kernel[grid](
            state, x, dt, dt_bias, A, B, C, D, z, out,
            batch, nheads, dim, dstate, nheads // ngroups,
            state.stride(0), state.stride(1), state.stride(2), state.stride(3),
            x.stride(0), x.stride(1), x.stride(2),
            dt.stride(0), dt.stride(1), dt.stride(2),
            *dt_bias_strides,
            A.stride(0), A.stride(1), A.stride(2),
            B.stride(0), B.stride(1), B.stride(2),
            C.stride(0), C.stride(1), C.stride(2),
            *D_strides,
            *z_strides,
            out.stride(0), out.stride(1), out.stride(2),
            dt_softplus,
            tie_hdim,
            BLOCK_SIZE_M,
            num_warps=num_warps,
        )
    if not has_heads:
        out = out.squeeze(1)
    return out
def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
    """
    Pure-PyTorch reference for one selective-scan step; updates ``state`` in place.

    Per element::

        dt'   = dt + dt_bias              (bias optional)
        dt'   = softplus(dt')             (if dt_softplus)
        state = state * exp(A * dt') + (B * dt') * x
        out   = sum_n(state * C) [+ x * D] [* silu(z)]

    Each of the ``ngroups`` B/C groups is shared by ``nheads // ngroups``
    consecutive heads.

    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate); modified in place
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim); optional
        z: (batch, dim) or (batch, nheads, dim); optional
        dt_bias: (dim,) or (nheads, dim); optional
        dt_softplus: apply softplus to dt after adding dt_bias
    Return:
        out: (batch, dim) or (batch, nheads, dim), in x's dtype
    Raises:
        AssertionError: if the input shapes are inconsistent.
    """
    has_heads = state.dim() > 3
    # Normalize the head-less layout to a single head.
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)

    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
        # Only add the bias when one was supplied (dt + None would raise).
        dt = dt + dt_bias
    dt = F.softplus(dt) if dt_softplus else dt

    dt_col = dt.unsqueeze(-1)  # (batch, nheads, dim, 1)
    dA = torch.exp(dt_col * A)  # (batch, nheads, dim, dstate)
    # Expand groups to heads: head h uses group h // (nheads // ngroups).
    heads_per_group = nheads // ngroups
    B = B.repeat_interleave(heads_per_group, dim=1)  # (batch, nheads, dstate)
    C = C.repeat_interleave(heads_per_group, dim=1)  # (batch, nheads, dstate)
    dB = dt_col * B.unsqueeze(-2)  # (batch, nheads, dim, dstate)
    state.copy_(state * dA + dB * x.unsqueeze(-1))  # (batch, nheads, dim, dstate)
    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
    if D is not None:
        out += (x * D).to(out.dtype)
    out = (out if z is None else out * F.silu(z)).to(x.dtype)
    if not has_heads:
        out = out.squeeze(1)
    return out
|