| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828 |
- # Copyright (c) 2024, Tri Dao, Albert Gu.
- """We want triton==2.1.0 or 2.2.0 for this
- """
- import math
- from packaging import version
- import torch
- import torch.nn.functional as F
- import triton
- import triton.language as tl
- from einops import rearrange, repeat
# Prefer the package-relative import; fall back to a top-level import when the
# relative form cannot be resolved (e.g. the module is loaded outside a package).
# Only ImportError is caught so that unrelated failures (KeyboardInterrupt,
# errors raised while executing ssd_bmm itself) are not silently masked.
try:
    from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
except ImportError:
    from ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
- TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
def init_to_zero(names):
    """Build a callback that zeroes selected entries of an argument mapping in place.

    Args:
        names: iterable of keys to look up in the mapping passed to the callback.

    Returns:
        A callable taking a mapping ``nargs``. For every key in ``names`` whose
        value is not ``None`` it calls ``.zero_()`` on that value and returns the
        list of the ``.zero_()`` results. A key that is absent from ``nargs``
        raises ``KeyError``.
    """
    def _zero_named(nargs):
        zeroed = []
        for name in names:
            value = nargs[name]
            # None entries are skipped rather than zeroed.
            if value is not None:
                zeroed.append(value.zero_())
        return zeroed

    return _zero_named
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
    ],
    key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],
)
@triton.jit
def _chunk_scan_fwd_kernel(
    # Pointers to matrices
    cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr,
    # Matrix dimensions
    chunk_size, hdim, dstate,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,
    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
    stride_D_head,
    # Meta-parameters
    IS_CAUSAL: tl.constexpr,
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    HAS_Z: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
    IS_TRITON_22: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile of the chunked scan.

    Grid layout: axis 0 enumerates (row tile, hdim tile) pairs, axis 1 is the
    flattened (chunk, batch) index (pid_c = pid // batch), axis 2 is the head.

    For rows m of the chunk and columns n of hdim the tile accumulates, in fp32:
      1. (C[m, :] @ prev_states[:, n]) * exp(dA_cumsum[m]); with HAS_SEQ_IDX the
         scale is zeroed for rows whose seq_idx differs from the last seq_idx of
         the previous chunk.
      2. sum_k cb[m, k] * exp(dA_cumsum[m] - dA_cumsum[k]) * dt[k] * x[k, n],
         restricted to k <= m when IS_CAUSAL.
      3. x[m, n] * D when HAS_D (D is per-head, or per-(head, hdim) if D_HAS_HDIM).
    With HAS_Z the un-gated result is stored to out_x and the result is then
    multiplied by z * sigmoid(z) before being stored to out. out_x is addressed
    with the strides of out.
    """
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # Advance every base pointer to this program's (batch, chunk, head).
    # cb and C are indexed by head group (pid_h // nheads_ngroups_ratio).
    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head
    prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)

    # Number of valid positions in this chunk (the last chunk may be partial).
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    if HAS_SEQ_IDX:
        # Last seq_idx of the previous chunk; masked to 0 for the first chunk.
        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Without the if (pid_c > -1), with Triton 2.1.0, I get
    # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed.
    # With Triton 2.2.0, this works
    if IS_TRITON_22 or pid_c > -1:
        # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
        offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
        C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate)
        prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate)
        if not HAS_SEQ_IDX:
            scale_m = tl.exp(dA_cs_m)
        else:
            scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
        if BLOCK_SIZE_DSTATE <= 128:
            C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0)
            prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
            prev_states = prev_states.to(C_ptr.dtype.element_ty)
            acc = tl.dot(C, prev_states) * scale_m[:, None]
        else:
            for k in range(0, dstate, BLOCK_SIZE_K):
                C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0)
                # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
                prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
                prev_states = prev_states.to(C_ptr.dtype.element_ty)
                acc += tl.dot(C, prev_states)
                # Step along dstate using the same strides the pointers were built
                # with, so non-unit dstate strides are addressed correctly.
                C_ptrs += BLOCK_SIZE_K * stride_C_dstate
                prev_states_ptrs += BLOCK_SIZE_K * stride_states_dstate
            acc *= scale_m[:, None]

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
    x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    # In the causal case columns beyond this row tile contribute nothing, so stop early.
    K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
    for k in range(0, K_MAX, BLOCK_SIZE_K):
        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
        # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
        # So we don't need masking wrt seq_idx here.
        cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
        cb *= dt_k
        if IS_CAUSAL:
            mask = offs_m[:, None] >= k + offs_k[None, :]
            cb = tl.where(mask, cb, 0.0)
        cb = cb.to(x_ptr.dtype.element_ty)
        x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0)
        acc += tl.dot(cb, x)
        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    if HAS_D:
        if D_HAS_HDIM:
            # No hdim stride is passed for D, so it is indexed with unit stride along hdim.
            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
        else:
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
                             mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
        acc += x_residual * D

    if HAS_Z:
        # Save the pre-gating result; out_x shares the strides of out, including
        # the hdim stride (applied here for consistency with out_ptrs below).
        out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
        out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim)
        tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))

        z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head
        z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :])
        z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32)
        acc *= z * tl.sigmoid(z)

    out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
    out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim)
    tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))
@triton.autotune(
    configs=[
        # triton.Config({'BLOCK_SIZE_N': 256}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_N': 128}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=8),
    ],
    key=['chunk_size', 'hdim', 'dstate'],
)
@triton.jit
def _chunk_scan_fwd_kernel_wip(
    # Pointers to matrices
    cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, B_ptr, prev_states_ptr, D_ptr,
    # Matrix dimensions
    chunk_size, hdim, dstate,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,
    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
    stride_B_batch, stride_B_seqlen, stride_B_head, stride_B_dstate,
    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
    stride_D_head,
    # Meta-parameters
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    HAS_Z: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    """Work-in-progress variant of the forward chunk scan (see the TODO below: not correct).

    Grid layout: axis 0 tiles hdim in BLOCK_SIZE_N columns, axis 1 is the
    flattened (chunk, batch) index, axis 2 is the head. Each program walks its
    chunk in BLOCK_SIZE_M-row steps. Per step it computes
    (C @ prev_states) * scale_m, adds x * D when HAS_D, stores the tile to out,
    and then adds B @ x to the locally held prev_states for the next step.

    In the live code the cb term and the z gating are disabled (commented out).
    cb_ptrs, num_pid_n, scale, z_ptr and out_x_ptr are computed or accepted but
    never used.
    """
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    pid_n = tl.program_id(axis=0)
    # Advance base pointers to this program's (batch, chunk, head); cb, C and B
    # are indexed by head group (pid_h // nheads_ngroups_ratio).
    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head
    B_ptr += pid_b * stride_B_batch + pid_c * chunk_size * stride_B_seqlen + (pid_h // nheads_ngroups_ratio) * stride_B_head
    prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
    out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head

    # offs_m is relative to the current row step; the tile pointers below are
    # advanced by BLOCK_SIZE_M rows at the end of every loop iteration.
    offs_m = tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE)

    C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate)
    # B is addressed transposed: (dstate, rows).
    B_ptrs = B_ptr + (offs_m[None, :] * stride_B_seqlen + offs_k_dstate[:, None] * stride_B_dstate)
    prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)  # unused
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_m[None, :] * stride_cb_csize_k)  # only advanced, never loaded
    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim)

    # Loaded once, then updated in registers across row steps.
    prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
    # if pid_c == 0:
    #     if pid_b == 0:
    #         if pid_h == 0:
    #             tl.device_print("", prev_states)

    # Number of valid positions in this chunk (the last chunk may be partial).
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

    # Disabled single-tile version kept for reference:
    # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    # scale_m = tl.exp(dA_cs_m)
    # C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0)
    # acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None]
    # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_m[None, :] < chunk_size), other=0.0).to(tl.float32)
    # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :]))
    # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    # cb *= dt_m
    # mask = offs_m[:, None] >= offs_m[None, :]
    # cb = tl.where(mask, cb, 0.0)
    # cb = cb.to(x_ptr.dtype.element_ty)
    # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0)
    # acc += tl.dot(cb, x)
    # if HAS_D:
    #     if D_HAS_HDIM:
    #         D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
    #     else:
    #         D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
    #     acc += x.to(tl.float32) * D
    # tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))

    for start_m in range(0, chunk_size_limit, BLOCK_SIZE_M):
        start_m = tl.multiple_of(start_m, BLOCK_SIZE_M)
        dA_cs_m = tl.load(dA_cumsum_ptr + (start_m + offs_m) * stride_dA_cs_csize, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32)
        if HAS_SEQ_IDX:
            # NOTE(review): start_m is added without multiplying by
            # stride_seq_idx_seqlen, and the address moves with start_m, unlike
            # the non-WIP kernel's "last index of previous chunk" — confirm intent.
            seq_idx_prev = tl.load(seq_idx_ptr + start_m - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
            seq_idx_m = tl.load(seq_idx_ptr + (start_m + offs_m) * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit - start_m, other=-1)
        if not HAS_SEQ_IDX:
            scale_m = tl.exp(dA_cs_m)
        else:
            scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
        C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_k_dstate[None, :] < dstate), other=0.0)
        acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None]
        # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size - start_m) & (offs_m[None, :] < chunk_size - start_m), other=0.0).to(tl.float32)
        # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :]))
        dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32)
        # cb *= dt_m
        # mask = offs_m[:, None] >= offs_m[None, :]
        # cb = tl.where(mask, cb, 0.0)
        # cb = cb.to(x_ptr.dtype.element_ty)
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim), other=0.0)
        # acc += tl.dot(cb, x)

        if HAS_D:
            if D_HAS_HDIM:
                # No hdim stride is passed for D, so it is indexed with unit stride along hdim.
                D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
            else:
                D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
            acc += x.to(tl.float32) * D

        # if HAS_Z:
        #     out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
        #     out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :])
        #     tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))

        #     z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head
        #     z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :])
        #     z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32)
        #     acc *= z * tl.sigmoid(z)

        tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim))

        # TODO: this is not correct, and quite a bit slower
        # Carry state to the next row step (skipped after the last step).
        if start_m + BLOCK_SIZE_M < chunk_size_limit:
            # B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0).to(tl.float32)
            B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0)
            dA_cs_last = tl.load(dA_cumsum_ptr + (start_m + BLOCK_SIZE_M) * stride_dA_cs_csize).to(tl.float32)
            # TODO: seq_idx
            # scale is computed but not applied: the "B *= scale" line is disabled.
            scale = tl.exp((dA_cs_last - dA_cs_m)) * dt_m
            # B *= scale
            B = B.to(x_ptr.dtype.element_ty)
            tmp = tl.dot(B, x)
            prev_states += tmp.to(prev_states.dtype)

        # Advance all tile pointers by one row step.
        C_ptrs += BLOCK_SIZE_M * stride_C_seqlen
        B_ptrs += BLOCK_SIZE_M * stride_B_seqlen
        cb_ptrs += BLOCK_SIZE_M * stride_cb_csize_m + BLOCK_SIZE_M * stride_cb_csize_k
        x_ptrs += BLOCK_SIZE_M * stride_x_seqlen
        dt_ptrs += BLOCK_SIZE_M * stride_dt_csize
        out_ptrs += BLOCK_SIZE_M * stride_out_seqlen
@triton.autotune(
    configs=[triton.Config({'BLOCK_SIZE_M': block_m}) for block_m in (32, 64, 128, 256)],
    key=["chunk_size", "hdim"],
)
@triton.jit
def _chunk_scan_bwd_dz_kernel(
    # Pointers to matrices
    dout_ptr, out_ptr, z_ptr, x_ptr, D_ptr, outz_ptr, dz_ptr, dout_x_ptr, dD_ptr, ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen,
    # Strides
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
    stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_D_head,
    stride_outz_batch, stride_outz_seqlen, stride_outz_head, stride_outz_hdim,
    stride_dz_batch, stride_dz_seqlen, stride_dz_head, stride_dz_hdim,
    stride_doutx_batch, stride_doutx_seqlen, stride_doutx_head, stride_doutx_hdim,
    stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
    # Meta-parameters
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    HAS_DDACS: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
):
    """Backward pass through the ``out * z * sigmoid(z)`` gating for one row tile.

    Grid layout: axis 0 tiles the chunk in BLOCK_SIZE_M rows, axis 1 is the
    flattened (chunk, batch) index, axis 2 is the head. A single BLOCK_SIZE_N
    tile starting at column 0 covers hdim.

    Per tile (all math in fp32):
      * outz = out * z * sigmoid(z) is stored when RECOMPUTE_OUTPUT.
      * dz = dout * out * sigmoid(z) * (1 + z * (1 - sigmoid(z))) is stored.
      * dout * z * sigmoid(z) is stored to dout_x.
      * With HAS_D, a partial dD (sum of gated dout * x over rows, and over hdim
        too unless D_HAS_HDIM) is stored at this tile's dD slot, and x * D is
        subtracted from out.
      * With HAS_DDACS, the row sums of gated dout * out are stored to ddA_cumsum.
    """
    pid_m = tl.program_id(axis=0)
    pid_bc = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    # First sequence position covered by this chunk.
    seq_start = pid_c * chunk_size

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_N)
    rows = offs_m[:, None]
    cols = offs_n[None, :]
    # The last chunk may be partial; every 2D load/store shares this mask.
    chunk_size_limit = min(chunk_size, seqlen - seq_start)
    tile_mask = (rows < chunk_size_limit) & (cols < hdim)

    # Load the three inputs of the gating.
    dout_ptr += pid_b * stride_dout_batch + seq_start * stride_dout_seqlen + pid_h * stride_dout_head
    dout_ptrs = dout_ptr + (rows * stride_dout_seqlen + cols * stride_dout_hdim)
    dout = tl.load(dout_ptrs, mask=tile_mask, other=0.0).to(tl.float32)

    out_ptr += pid_b * stride_out_batch + seq_start * stride_out_seqlen + pid_h * stride_out_head
    out_ptrs = out_ptr + (rows * stride_out_seqlen + cols * stride_out_hdim)
    out = tl.load(out_ptrs, mask=tile_mask, other=0.0).to(tl.float32)

    z_ptr += pid_b * stride_z_batch + seq_start * stride_z_seqlen + pid_h * stride_z_head
    z_ptrs = z_ptr + (rows * stride_z_seqlen + cols * stride_z_hdim)
    z = tl.load(z_ptrs, mask=tile_mask, other=0.0).to(tl.float32)
    z_sigmoid = tl.sigmoid(z)

    if RECOMPUTE_OUTPUT:
        outz_ptr += pid_b * stride_outz_batch + seq_start * stride_outz_seqlen + pid_h * stride_outz_head
        outz_ptrs = outz_ptr + (rows * stride_outz_seqlen + cols * stride_outz_hdim)
        outz = out * z * z_sigmoid
        tl.store(outz_ptrs, outz, mask=tile_mask)

    # Gradient w.r.t. z of out * z * sigmoid(z).
    dz_ptr += pid_b * stride_dz_batch + seq_start * stride_dz_seqlen + pid_h * stride_dz_head
    dz_ptrs = dz_ptr + (rows * stride_dz_seqlen + cols * stride_dz_hdim)
    dz = dout * out * z_sigmoid * (1 + z * (1 - z_sigmoid))
    tl.store(dz_ptrs, dz, mask=tile_mask)

    # Gradient w.r.t. the pre-gating output.
    dout_x_ptr += pid_b * stride_doutx_batch + seq_start * stride_doutx_seqlen + pid_h * stride_doutx_head
    dout_x_ptrs = dout_x_ptr + (rows * stride_doutx_seqlen + cols * stride_doutx_hdim)
    dout *= z * z_sigmoid
    tl.store(dout_x_ptrs, dout, mask=tile_mask)

    if HAS_D:
        x_ptr += pid_b * stride_x_batch + seq_start * stride_x_seqlen + pid_h * stride_x_head
        x_ptrs = x_ptr + (rows * stride_x_seqlen + cols * stride_x_hdim)
        x = tl.load(x_ptrs, mask=tile_mask, other=0.0).to(tl.float32)
        # Each row tile writes its own partial dD slot (indexed by pid_m).
        dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize
        if D_HAS_HDIM:
            dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
            dD = tl.sum(dout * x, axis=0)
            tl.store(dD_ptrs, dD, mask=offs_n < hdim)
            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
        else:
            dD = tl.sum(dout * x)
            tl.store(dD_ptr, dD)
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        # Remove the x * D term from out before it is used for ddA_cumsum.
        out -= x * D

    if HAS_DDACS:
        ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
        ddA_cs = tl.sum(dout * out, axis=1)
        tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size)
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
    ],
    key=['hdim', 'dstate', 'chunk_size'],
)
@triton.jit
def _chunk_scan_bwd_dstates_kernel(
    # Pointers to matrices
    dout_ptr, c_ptr, dprev_states_ptr, dA_cumsum_ptr, seq_idx_ptr,
    # Matrix dimensions
    hdim, dstate, chunk_size,
    batch, seqlen, nchunks, nheads_ngroups_ratio,
    # Strides
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_c_batch, stride_c_seqlen, stride_c_head, stride_c_dstate,
    stride_dprev_states_batch, stride_dprev_states_chunk, stride_dprev_states_head, stride_dprev_states_hdim, stride_dprev_states_dstate,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    # Meta-parameters
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (hdim tile x dstate tile) block of dprev_states for a chunk.

    Grid layout: axis 0 enumerates (hdim tile, dstate tile) pairs, axis 1 is the
    flattened (chunk, batch) index, axis 2 is the head.

    The result is sum over chunk positions k of
    dout[k, m] * scale[k] * c[k, n], where scale[k] = exp(dA_cumsum[k]); with
    HAS_SEQ_IDX the scale is zeroed for positions whose seq_idx differs from the
    last seq_idx of the previous chunk. Accumulation is in fp32 and the result
    is cast to the element type of dprev_states. nchunks is accepted but not
    read by this kernel.
    """
    pid_mn = tl.program_id(axis=0)
    pid_bc = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = pid_mn // num_pid_n
    pid_n = pid_mn % num_pid_n

    # Advance base pointers to this program's (batch, chunk, head); c is indexed
    # by head group (pid_h // nheads_ngroups_ratio).
    c_ptr += pid_b * stride_c_batch + pid_c * chunk_size * stride_c_seqlen + (pid_h // nheads_ngroups_ratio) * stride_c_head
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # Bounds along hdim (rows of the output) and dstate (columns of the output).
    hdim_mask = offs_m[:, None] < hdim
    dstate_mask = offs_n[None, :] < dstate

    # dout is addressed transposed: (hdim, chunk position).
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_hdim + offs_k[None, :] * stride_dout_seqlen)
    c_ptrs = c_ptr + (offs_n[None, :] * stride_c_dstate + offs_k[:, None] * stride_c_seqlen)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    if HAS_SEQ_IDX:
        seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
        # Last seq_idx of the previous chunk; masked to 0 for the first chunk.
        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)

    # The last chunk may be partial.
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_start in range(0, chunk_size_limit, BLOCK_SIZE_K):
        k_remaining = chunk_size_limit - k_start
        dout = tl.load(dout_ptrs, mask=hdim_mask & (offs_k[None, :] < k_remaining), other=0.0).to(tl.float32)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k_start, other=0.0).to(tl.float32)
        if HAS_SEQ_IDX:
            seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < k_remaining, other=-1)
            scale_k = tl.where(seq_idx_k == seq_idx_prev, tl.exp(dA_cs_k), 0.0)
        else:
            scale_k = tl.exp(dA_cs_k)
        # Scale in fp32, then cast back to the input dtype for the dot product.
        dout = (dout * scale_k).to(dout_ptr.dtype.element_ty)
        c = tl.load(c_ptrs, mask=(offs_k[:, None] < k_remaining) & dstate_mask, other=0.0)
        acc += tl.dot(dout, c)
        # Advance all loop pointers by one block of chunk positions.
        dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
        c_ptrs += BLOCK_SIZE_K * stride_c_seqlen
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
        if HAS_SEQ_IDX:
            seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen

    out = acc.to(dprev_states_ptr.dtype.element_ty)
    dprev_states_ptr += pid_b * stride_dprev_states_batch + pid_c * stride_dprev_states_chunk + pid_h * stride_dprev_states_head
    dprev_states_ptrs = dprev_states_ptr + (offs_m[:, None] * stride_dprev_states_hdim + offs_n[None, :] * stride_dprev_states_dstate)
    tl.store(dprev_states_ptrs, out, mask=hdim_mask & dstate_mask)
@triton.autotune(
    configs=[
        triton.Config(
            {'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n},
            num_stages=3,
            num_warps=4,
            pre_hook=init_to_zero(["ddA_cumsum_ptr"]),
        )
        for block_m, block_n in (
            (32, 128), (128, 32), (64, 128), (128, 64),
            (64, 64), (64, 32), (32, 64), (32, 32),
        )
    ],
    key=['chunk_size', 'dstate', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_dc_kernel(
    # Pointers to matrices
    dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr,
    dc_ptr, ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, dstate, hdim,
    batch, seqlen, nheads, nheads_per_program, ngroups,
    # Strides
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate,
    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_dc_batch, stride_dc_seqlen, stride_dc_split, stride_dc_group, stride_dc_dstate,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
    # Meta-parameters
    HAS_DDA_CS: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    """Accumulate ``dout @ prev_states`` over heads into one ``dc`` tile.

    Grid layout: axis 0 enumerates (row-tile, dstate-tile) pairs, axis 1
    packs chunk and batch as ``chunk * batch + b``, axis 2 packs split and
    group as ``split * ngroups + g``.  Each program walks at most
    ``nheads_per_program`` consecutive heads of its group.  Per head, the
    ``dout`` tile (rows are positions inside the chunk, reduced over
    ``hdim``) is multiplied with that head's ``prev_states``, scaled per row
    by ``exp(dA_cumsum)`` and added to the accumulator, which is finally
    written to ``dc_ptr``.

    ``HAS_SEQ_IDX``: rows whose sequence index differs from the index found
    at the position just before this chunk get a scale of 0.
    ``HAS_DDA_CS``: the row sums of ``dc * C`` are atomically added into
    ``ddA_cumsum_ptr`` (that buffer is registered with the ``init_to_zero``
    autotune pre-hook).

    ``hdim`` is reduced with a single ``tl.dot``; indices at or beyond
    ``BLOCK_SIZE_K`` are never loaded, so the launcher presumably passes
    ``BLOCK_SIZE_K >= hdim``.
    """
    # Unpack the combined program ids.
    packed_bc = tl.program_id(axis=1)
    chunk_idx = packed_bc // batch
    batch_idx = packed_bc - chunk_idx * batch
    packed_sg = tl.program_id(axis=2)
    split_idx = packed_sg // ngroups
    group_idx = packed_sg - split_idx * ngroups
    tiles_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    tile_m = tl.program_id(axis=0) // tiles_n
    tile_n = tl.program_id(axis=0) % tiles_n

    # First head owned by this program and first sequence position of the chunk.
    heads_per_group = nheads // ngroups
    head_start = group_idx * heads_per_group + split_idx * nheads_per_program
    seq_start = chunk_idx * chunk_size

    # Move every base pointer to this program's (batch, chunk, head/group) slice.
    dout_ptr += batch_idx * stride_dout_batch + seq_start * stride_dout_seqlen + head_start * stride_dout_head
    dc_ptr += batch_idx * stride_dc_batch + seq_start * stride_dc_seqlen + group_idx * stride_dc_group + split_idx * stride_dc_split
    prev_states_ptr += batch_idx * stride_prev_states_batch + chunk_idx * stride_prev_states_chunk + head_start * stride_prev_states_head
    dA_cumsum_ptr += batch_idx * stride_dA_cs_batch + chunk_idx * stride_dA_cs_chunk + head_start * stride_dA_cs_head
    if HAS_DDA_CS:
        C_ptr += batch_idx * stride_C_batch + seq_start * stride_C_seqlen + group_idx * stride_C_head
        ddA_cumsum_ptr += batch_idx * stride_ddA_cs_batch + chunk_idx * stride_ddA_cs_chunk + head_start * stride_ddA_cs_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += batch_idx * stride_seq_idx_batch + seq_start * stride_seq_idx_seqlen

    offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
    if HAS_DDA_CS:
        C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate)
        ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize

    # The final chunk of a sequence may hold fewer than chunk_size positions.
    chunk_size_limit = min(chunk_size, seqlen - seq_start)
    row_ok = offs_m < chunk_size_limit
    col_ok = offs_n < dstate
    k_ok = offs_k < hdim

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    if HAS_DDA_CS:
        c = tl.load(C_ptrs, mask=row_ok[:, None] & col_ok[None, :], other=0.0).to(tl.float32)
    if HAS_SEQ_IDX:
        # Sequence index one position before the chunk start (0 for the first chunk).
        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=chunk_idx >= 1, other=0)
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=row_ok, other=-1)

    # The last split of a group may own fewer heads than nheads_per_program.
    nheads_iter = min(nheads_per_program, heads_per_group - split_idx * nheads_per_program)
    for head in range(nheads_iter):
        dout = tl.load(dout_ptrs, mask=row_ok[:, None] & k_ok[None, :], other=0.0)
        prev_states = tl.load(prev_states_ptrs, mask=k_ok[:, None] & col_ok[None, :], other=0.0)
        # Bring both dot operands to the element type of dout.
        prev_states = prev_states.to(dout_ptrs.dtype.element_ty)
        dc = tl.dot(dout, prev_states)
        dA_cs_m = tl.load(dA_cumsum_ptrs, mask=row_ok, other=0.0).to(tl.float32)
        if HAS_SEQ_IDX:
            scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
        else:
            scale = tl.exp(dA_cs_m)
        dc *= scale[:, None]
        if HAS_DDA_CS:
            ddA_cs = tl.sum(dc * c, axis=1)
            tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
        acc += dc
        # Step all per-head pointers to the next head.
        dout_ptrs += stride_dout_head
        prev_states_ptrs += stride_prev_states_head
        dA_cumsum_ptrs += stride_dA_cs_head
        if HAS_DDA_CS:
            ddA_cumsum_ptrs += stride_ddA_cs_head

    # Offsets are rebuilt before the store, as in the original kernel.
    offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dc_ptrs = dc_ptr + (offs_m[:, None] * stride_dc_seqlen + offs_n[None, :] * stride_dc_dstate)
    tl.store(dc_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate))
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
    ],
    key=['chunk_size', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_dx_kernel(
    # Pointers to matrices
    x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, D_ptr,
    dx_ptr, ddt_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_D_head,
    stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
    stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
    # Meta-parameters
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one ``dx`` tile and its contribution to ``ddt`` for one chunk/head.

    Grid layout: axis 0 enumerates (row-tile, hdim-tile) pairs, axis 1 packs
    chunk and batch as ``chunk * batch + b``, axis 2 is the head.

    For rows ``m`` (positions in the chunk) and columns ``n`` (hdim)::

        acc[m, n] = sum_{k >= m, k < K_MAX} cb[m, k] * exp(dA_cs[k] - dA_cs[m]) * dout[k, n]
        dx[m, n]  = acc[m, n] * dt[m]  (+ dout[m, n] * D when HAS_D)
        ddt[m]   += sum_n acc[m, n] * x[m, n]

    ``dx`` is stored directly.  ``ddt`` is updated with ``tl.atomic_add``
    because every hdim tile of the same rows contributes to it; the buffer is
    registered with the ``init_to_zero`` autotune pre-hook.

    ``cb`` is indexed by group (``head // nheads_ngroups_ratio``).  When
    ``D_HAS_HDIM`` is set, ``D`` is read with unit stride along hdim.  No dD
    output is produced here (an earlier in-kernel dD path is disabled).
    """
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # Move every base pointer to this program's (batch, chunk, head) slice.
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
    dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

    # The final chunk of a sequence may hold fewer than chunk_size positions.
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # NOTE: restricting the loop to K_MAX = min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
    # was observed to give wrong results (possibly a Triton bug), so the whole
    # valid chunk is scanned.
    K_MAX = chunk_size_limit
    for k in range(0, K_MAX, BLOCK_SIZE_K):
        # The row mask is deliberately (offs_m < chunk_size): using
        # chunk_size_limit here was measured to be much slower.
        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
        dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
        cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
        # Keep only k >= m AND k < K_MAX.  Columns past K_MAX were loaded with
        # dA_cs_k = 0.0, so for a row whose dA_cs_m is very negative the
        # exponent above overflows to inf; cb is 0.0 there and 0.0 * inf = NaN.
        # Without the upper bound that NaN would survive the tl.where and be
        # spread by tl.dot into acc, and from there into dx and ddt.
        mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
        cb = tl.where(mask, cb, 0.0)
        cb = cb.to(dout_ptr.dtype.element_ty)
        acc += tl.dot(cb, dout)
        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
        dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
    dx = acc * dt_m[:, None]
    dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
    dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
    if HAS_D:
        # Add the dout * D term to dx.
        dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
        dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
        if D_HAS_HDIM:
            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
        else:
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        dx += dout_res * D
    tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))

    # ddt uses acc before the dt scaling; rows past the chunk end contribute 0.
    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
    ddt = tl.sum(acc * x, axis=1)
    ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
    tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
# Disabling HAS_DDA_CS for now since it's much slower
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n}, num_stages=3, num_warps=4)
        for block_m, block_n in (
            (32, 128), (128, 32), (64, 64), (64, 32), (32, 64), (32, 32),
        )
    ],
    key=['chunk_size', 'hdim'],
)
# Disabled alternatives: configs that tuned only BLOCK_SIZE_M, and
# triton.heuristics rules that derived BLOCK_SIZE_N (from chunk_size, or a fixed 32).
@triton.jit
def _chunk_scan_bwd_dcb_kernel(
    # Pointers to matrices
    x_ptr, dout_ptr, cb_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
    dcb_ptr, ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen, nheads, nheads_per_program, ngroups,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_dcb_batch, stride_dcb_chunk, stride_dcb_split, stride_dcb_group, stride_dcb_csize_m, stride_dcb_csize_n,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n,
    # Meta-parameters
    HAS_DDA_CS: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    """Accumulate one (chunk_size x chunk_size) ``dcb`` tile over heads.

    Grid layout: axis 0 enumerates (row-tile, column-tile) pairs of the
    chunk-by-chunk matrix, axis 1 packs chunk and batch as
    ``chunk * batch + b``, axis 2 packs split and group as
    ``split * ngroups + g``.  Each program walks at most
    ``nheads_per_program`` consecutive heads of its group.

    Per head the tile is ``dout[m, :] . x[n, :]`` (reduced over ``hdim`` in a
    single ``tl.dot``), multiplied by ``dt[n]`` and by
    ``exp(dA_cumsum[m] - dA_cumsum[n])``.  The summed tile is masked to
    ``m >= n`` (and, with ``HAS_SEQ_IDX``, to pairs sharing a sequence index)
    and stored to ``dcb_ptr``.  Tiles lying entirely above the diagonal are
    written as zeros without doing any of that work.

    With ``HAS_DDA_CS`` (not supported together with ``HAS_SEQ_IDX``) each
    head additionally stores a per-row-tile partial result into
    ``ddA_cumsum_ptr``.
    """
    # Unpack the combined program ids.
    packed_bc = tl.program_id(axis=1)
    chunk_idx = packed_bc // batch
    batch_idx = packed_bc - chunk_idx * batch
    packed_sg = tl.program_id(axis=2)
    split_idx = packed_sg // ngroups
    group_idx = packed_sg - split_idx * ngroups
    tiles_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
    tile_m = tl.program_id(axis=0) // tiles_n
    tile_n = tl.program_id(axis=0) % tiles_n

    # First head owned by this program and first sequence position of the chunk.
    heads_per_group = nheads // ngroups
    head_start = group_idx * heads_per_group + split_idx * nheads_per_program
    seq_start = chunk_idx * chunk_size

    x_ptr += batch_idx * stride_x_batch + seq_start * stride_x_seqlen + head_start * stride_x_head
    dout_ptr += batch_idx * stride_dout_batch + seq_start * stride_dout_seqlen + head_start * stride_dout_head
    dt_ptr += batch_idx * stride_dt_batch + chunk_idx * stride_dt_chunk + head_start * stride_dt_head
    dA_cumsum_ptr += batch_idx * stride_dA_cs_batch + chunk_idx * stride_dA_cs_chunk + head_start * stride_dA_cs_head
    if HAS_DDA_CS:
        cb_ptr += batch_idx * stride_cb_batch + chunk_idx * stride_cb_chunk + group_idx * stride_cb_head
        ddA_cumsum_ptr += batch_idx * stride_ddA_cs_batch + chunk_idx * stride_ddA_cs_chunk + head_start * stride_ddA_cs_head + tile_m * stride_ddA_cs_csize_m
    if HAS_SEQ_IDX:
        seq_idx_ptr += batch_idx * stride_seq_idx_batch + seq_start * stride_seq_idx_seqlen

    offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_n * stride_dt_csize
    if HAS_DDA_CS:
        cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)
        ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n

    # Early exit: every column of this tile is beyond every row (n > m), so
    # the m >= n mask would zero it all. Write zeros and stop.
    if tile_n * BLOCK_SIZE_N >= (tile_m + 1) * BLOCK_SIZE_M:
        dcb_ptr += batch_idx * stride_dcb_batch + chunk_idx * stride_dcb_chunk + group_idx * stride_dcb_group + split_idx * stride_dcb_split
        dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n)
        tl.store(dcb_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dcb_ptr.dtype.element_ty), mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
        return

    # The final chunk may be short; columns past the last row of the tile are not needed.
    chunk_size_limit = min(chunk_size, seqlen - seq_start)
    chunk_size_limit_n = min(chunk_size_limit, (tile_m + 1) * BLOCK_SIZE_M)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    if HAS_DDA_CS:
        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32)
    # The last split of a group may own fewer heads than nheads_per_program.
    nheads_iter = min(nheads_per_program, heads_per_group - split_idx * nheads_per_program)
    for head in range(nheads_iter):
        dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
        x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0)
        dcb = tl.dot(dout, x)
        dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
        dcb *= dt_n
        dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
        dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32)
        dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
        if HAS_DDA_CS:
            tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet")
            # Strictly-lower-triangular part, prefix-summed along n, then reduced over m.
            ddA_cs = dcb * cb
            mask = offs_m[:, None] >= offs_n[None, :] + 1
            ddA_cs = tl.where(mask, ddA_cs, 0.0)
            ddA_cs = tl.cumsum(ddA_cs, axis=1)
            ddA_cs = tl.where(mask, ddA_cs, 0.0)
            ddA_cs = tl.sum(ddA_cs, axis=0)
            # Results land one column to the right; the element at the base pointer is set to 0.
            tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1)
            tl.store(ddA_cumsum_ptr, 0.0)
        acc += dcb
        # Step all per-head pointers to the next head.
        dout_ptrs += stride_dout_head
        x_ptrs += stride_x_head
        dt_ptrs += stride_dt_head
        dA_cumsum_ptr += stride_dA_cs_head
        if HAS_DDA_CS:
            ddA_cumsum_ptr += stride_ddA_cs_head
            ddA_cumsum_ptrs += stride_ddA_cs_head

    # Offsets are rebuilt before the final masking and store, as in the original kernel.
    offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    if HAS_SEQ_IDX:
        # Different fill values (-1 / -2) keep out-of-range rows and columns from matching.
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
        seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
        acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
    mask = offs_m[:, None] >= offs_n[None, :]
    acc = tl.where(mask, acc, 0.0)
    dcb_ptr += batch_idx * stride_dcb_batch + chunk_idx * stride_dcb_chunk + group_idx * stride_dcb_group + split_idx * stride_dcb_split
    dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n)
    tl.store(dcb_ptrs, acc, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
# Not numerically stable and should not be used. Leaving here for reference.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': block_m})
        for block_m in (32, 64, 128, 256)
    ],
    key=["chunk_size", "hdim"],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_unstable_kernel(
    # Pointers to matrices
    dout_ptr, out_ptr, dt_ptr, ddt_ptr, x_ptr, D_ptr,
    ddA_cumsum_ptr, dD_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen,
    # Strides
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_D_head,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
    stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
    # Meta-parameters
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    SUBTRACT_DDTDT: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
):
    """Reference (unstable) ``ddA_cumsum`` as the row sum of ``dout * out``.

    Grid layout: axis 0 is the row tile inside the chunk, axis 1 packs chunk
    and batch as ``chunk * batch + b``, axis 2 is the head.  The whole
    ``hdim`` axis is handled by one tile of width ``BLOCK_SIZE_N`` (indices
    at or beyond ``BLOCK_SIZE_N`` are never touched).

    ``ddA_cumsum[m] = sum_n dout[m, n] * out[m, n]``, where with ``HAS_D``
    the ``x * D`` term is first subtracted from ``out``, and with
    ``SUBTRACT_DDTDT`` the product ``dt[m] * ddt[m]`` is subtracted from the
    result.  With ``HAS_D`` the kernel also stores a per-row-tile partial
    ``dD`` (a vector over hdim when ``D_HAS_HDIM``, otherwise one scalar).
    """
    packed_bc = tl.program_id(axis=1)
    chunk_idx = packed_bc // batch
    batch_idx = packed_bc - chunk_idx * batch
    head_idx = tl.program_id(axis=2)
    tile_m = tl.program_id(axis=0)
    seq_start = chunk_idx * chunk_size

    # Move the base pointers to this program's (batch, chunk, head) slice.
    dout_ptr += batch_idx * stride_dout_batch + seq_start * stride_dout_seqlen + head_idx * stride_dout_head
    out_ptr += batch_idx * stride_out_batch + seq_start * stride_out_seqlen + head_idx * stride_out_head
    dt_ptr += batch_idx * stride_dt_batch + chunk_idx * stride_dt_chunk + head_idx * stride_dt_head
    ddt_ptr += batch_idx * stride_ddt_batch + chunk_idx * stride_ddt_chunk + head_idx * stride_ddt_head
    ddA_cumsum_ptr += batch_idx * stride_ddA_cs_batch + chunk_idx * stride_ddA_cs_chunk + head_idx * stride_ddA_cs_head

    offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_N)
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
    out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim)
    if HAS_D:
        x_ptr += batch_idx * stride_x_batch + seq_start * stride_x_seqlen + head_idx * stride_x_head
        # dD has one slot per row tile (indexed through stride_dD_csize).
        dD_ptr += batch_idx * stride_dD_batch + chunk_idx * stride_dD_chunk + head_idx * stride_dD_head + tile_m * stride_dD_csize
        x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
        if D_HAS_HDIM:
            dD_ptrs = dD_ptr + offs_n * stride_dD_hdim

    # The final chunk of a sequence may hold fewer than chunk_size positions.
    chunk_size_limit = min(chunk_size, seqlen - seq_start)
    tile_mask = (offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)
    dout = tl.load(dout_ptrs, mask=tile_mask, other=0.0).to(tl.float32)
    out = tl.load(out_ptrs, mask=tile_mask, other=0.0).to(tl.float32)
    if HAS_D:
        x = tl.load(x_ptrs, mask=tile_mask, other=0.0).to(tl.float32)
        if D_HAS_HDIM:
            dD = tl.sum(dout * x, axis=0)
            tl.store(dD_ptrs, dD, mask=offs_n < hdim)
            # D is read with unit stride along hdim.
            D = tl.load(D_ptr + head_idx * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
        else:
            dD = tl.sum(dout * x)
            tl.store(dD_ptr, dD)
            D = tl.load(D_ptr + head_idx * stride_D_head).to(tl.float32)
        out -= x * D
    ddA_cs = tl.sum(dout * out, axis=1)
    if SUBTRACT_DDTDT:
        row_in_buffer = offs_m < chunk_size
        dt = tl.load(dt_ptr + offs_m * stride_dt_csize, mask=row_in_buffer, other=0.0).to(tl.float32)
        ddt = tl.load(ddt_ptr + offs_m * stride_ddt_csize, mask=row_in_buffer, other=0.0).to(tl.float32)
        ddA_cs -= dt * ddt
    tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size)
@triton.autotune(
    configs=[
        # Only BLOCK_SIZE_M is tuned; an earlier set that also fixed
        # BLOCK_SIZE_K = 32 is disabled.
        triton.Config({'BLOCK_SIZE_M': block_m}, num_stages=stages, num_warps=warps)
        for stages, warps in ((3, 4), (4, 8))
        for block_m in (16, 32, 64, 128)
    ],
    key=['chunk_size', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_stable_kernel_old(
    # Pointers to matrices
    x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr,
    ddAcs_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,
    stride_ddAcs_batch, stride_ddAcs_chunk, stride_ddAcs_head, stride_ddAcs_csize_m, stride_ddAcs_csize_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    """Older ``ddA_cumsum`` kernel: one partial result per row tile.

    Grid layout: axis 0 enumerates (row-tile, column-tile) pairs of the
    chunk-by-chunk matrix, axis 1 packs chunk and batch as
    ``chunk * batch + b``, axis 2 is the head.

    The tile ``dout[m, :] . x[n, :]`` (one ``tl.dot`` over ``hdim``) is
    multiplied by ``cb[m, n]``, ``dt[n]`` and
    ``exp(dA_cumsum[m] - dA_cumsum[n])``, masked to ``m >= n + 1``,
    prefix-summed along ``n``, masked again and reduced over ``m``.  The
    resulting vector is stored into row ``pid_m`` of ``ddAcs_ptr`` shifted
    one column to the right, and the first element of that row is set to 0.

    ``hdim`` indices at or beyond ``BLOCK_SIZE_K`` are never loaded, so the
    launcher presumably passes ``BLOCK_SIZE_K >= hdim``.
    """
    packed_bc = tl.program_id(axis=1)
    chunk_idx = packed_bc // batch
    batch_idx = packed_bc - chunk_idx * batch
    head_idx = tl.program_id(axis=2)
    tiles_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
    tile_m = tl.program_id(axis=0) // tiles_n
    tile_n = tl.program_id(axis=0) % tiles_n
    seq_start = chunk_idx * chunk_size

    # Move the base pointers to this program's (batch, chunk, head/group) slice.
    x_ptr += batch_idx * stride_x_batch + seq_start * stride_x_seqlen + head_idx * stride_x_head
    dout_ptr += batch_idx * stride_dout_batch + seq_start * stride_dout_seqlen + head_idx * stride_dout_head
    dt_ptr += batch_idx * stride_dt_batch + chunk_idx * stride_dt_chunk + head_idx * stride_dt_head
    dA_cumsum_ptr += batch_idx * stride_dA_cs_batch + chunk_idx * stride_dA_cs_chunk + head_idx * stride_dA_cs_head
    cb_ptr += batch_idx * stride_cb_batch + chunk_idx * stride_cb_chunk + (head_idx // nheads_ngroups_ratio) * stride_cb_head

    offs_m = tile_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_n * stride_dt_csize
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)

    # The final chunk may be short; columns past the last row of the tile are not needed.
    chunk_size_limit = min(chunk_size, seqlen - seq_start)
    chunk_size_limit_n = min(chunk_size_limit, (tile_m + 1) * BLOCK_SIZE_M)

    # Doing a matmul loop (blocked over hdim) with cumsum later on will cause
    # Triton to crash, so instead we do just one big matmul.
    dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
    x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0)
    acc = tl.dot(dout, x)
    cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32)
    acc *= cb
    dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
    acc *= dt_n
    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
    acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])

    # Strictly-lower-triangular part, prefix-summed along n, then reduced over m.
    mask = offs_m[:, None] >= offs_n[None, :] + 1
    acc = tl.where(mask, acc, 0.0)
    acc = tl.cumsum(acc, axis=1)
    acc = tl.where(mask, acc, 0.0)
    ddA_cs = tl.sum(acc, axis=0)

    ddAcs_ptr += batch_idx * stride_ddAcs_batch + chunk_idx * stride_ddAcs_chunk + head_idx * stride_ddAcs_head + tile_m * stride_ddAcs_csize_m
    offs_n = tile_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n
    # Results land one column to the right; the first element of the row is set to 0.
    tl.store(ddAcs_ptrs + stride_ddAcs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1)
    tl.store(ddAcs_ptr, 0.0)
    # A disabled experiment that swept the column axis in blocks of 64 inside
    # a loop (one dot / cumsum / store per block) used to follow here.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
    ],
    key=['chunk_size', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_stable_kernel(
    # Pointers to matrices
    x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr,
    ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # One program handles one BLOCK_SIZE_M block of rows (pid_m) of one
    # (batch, chunk, head) and writes a single row of partial results, selected
    # by pid_m along stride_ddA_cs_csize_m. For every column tile it forms
    #   acc[m, n] = (dout @ x)[m, n] * dt[n] * cb[m, n] * exp(dA_cs[m] - dA_cs[n])
    # keeps only the strictly-lower-triangular part (m > n), takes a running
    # cumulative sum along n (carried across tiles through `rowsum`), reduces
    # over m and stores the result one column to the right (column n + 1).
    #
    # Grid: axis 0 -> row block, axis 1 -> fused (chunk, batch), axis 2 -> head.
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    pid_m = tl.program_id(axis=0)

    # Move every base pointer to this program's (batch, chunk, head) slice.
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    # cb is stored per group, so the head index is mapped to its group.
    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_n * stride_dt_csize
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)
    ddAcs_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n

    # Results are stored shifted by one column, so column 0 is written explicitly.
    tl.store(ddA_cumsum_ptr, 0.0)

    # The last chunk of a sequence may be shorter than chunk_size.
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)

    # Only columns n < m contribute, so tiles starting at or beyond the end of
    # this row block are skipped.
    # Actually hi is (pid_m + 1) * BLOCK_SIZE_M - 1 but subtracting 1 makes it slower
    lo, hi = 0, (pid_m + 1) * BLOCK_SIZE_M
    for start_n in range(lo, hi, BLOCK_SIZE_N):
        start_n = tl.multiple_of(start_n, BLOCK_SIZE_N)
        # Doing a matmul loop with cumsum later on will cause Triton to crash.
        # Instead we do just one big matmul over the whole (padded) hdim.
        x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit - start_n), other=0.0)
        acc = tl.dot(dout, x)
        dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
        acc *= dt_n
        # If there's seq_idx, we already zero'ed out cb[i, j] for seq_idx[i] != seq_idx[j]
        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32)
        acc *= cb
        # The tile offset start_n must be scaled by the stride together with
        # offs_n; adding it unscaled is only correct when the stride is 1.
        dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
        acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
        # Keep strictly-lower-triangular entries (row index > column index).
        mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1
        acc = tl.where(mask, acc, 0.0)
        # Cumulative sum along n, continued from the previous tiles via rowsum.
        rowsum_new = rowsum + tl.sum(acc, axis=1)
        acc = rowsum[:, None] + tl.cumsum(acc, axis=1)
        rowsum = rowsum_new
        # The carried-in rowsum was added to every entry; mask again.
        acc = tl.where(mask, acc, 0.0)
        ddA_cs = tl.sum(acc, axis=0)
        tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - start_n - 1)
        x_ptrs += BLOCK_SIZE_N * stride_x_seqlen
        dt_ptrs += BLOCK_SIZE_N * stride_dt_csize
        cb_ptrs += BLOCK_SIZE_N * stride_cb_csize_n
        ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n

    # Need to zero out the rest, since we'll be summing the rows together.
    # The loop above advanced ddAcs_ptrs in whole BLOCK_SIZE_N tiles, i.e. to the
    # first multiple of BLOCK_SIZE_N at or above hi. The zero-fill has to resume
    # from that column: starting at hi itself would make the bounds mask
    # disagree with the pointer whenever hi is not a multiple of BLOCK_SIZE_N
    # (possible when BLOCK_SIZE_M < BLOCK_SIZE_N) and write past the row end.
    hi_aligned = tl.cdiv(hi, BLOCK_SIZE_N) * BLOCK_SIZE_N
    for start_n in range(hi_aligned, chunk_size, BLOCK_SIZE_N):
        tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), mask=offs_n < chunk_size - start_n - 1)
        ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n}, num_stages=3, num_warps=4,
                      pre_hook=init_to_zero(["ddA_cumsum_ptr"]))
        for block_m, block_n in ((32, 128), (128, 32), (64, 64), (64, 32), (32, 64), (32, 32))
    ],
    key=['chunk_size', 'dstate', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_prev_kernel(
    # Pointers to matrices
    dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr,
    ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, dstate, hdim,
    batch, seqlen, nchunks, nheads_ngroups_ratio,
    # Strides
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate,
    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
    # Meta-parameters
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # For each row m of the chunk this accumulates
    #   sum_n (dout @ prev_states)[m, n] * C[m, n] * scale[m]
    # into ddA_cumsum[m], where scale[m] = exp(dA_cumsum[m]) (or 0 when seq_idx
    # says row m belongs to a different sequence than the element just before
    # the chunk). Programs covering different dstate tiles add into the same
    # output element, hence the atomic add and the init_to_zero pre-hook.
    #
    # Grid: axis 0 -> fused (row block, dstate block), axis 1 -> fused
    # (chunk, batch), axis 2 -> head.
    pid_mn = tl.program_id(axis=0)
    pid_bc = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = pid_mn // num_pid_n
    pid_n = pid_mn % num_pid_n

    # Advance the base pointers to this program's (batch, chunk, head) slice.
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
    prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + pid_h * stride_prev_states_head
    # C is stored per group, so the head index is mapped to its group.
    C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head
    ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # The last chunk of a sequence may be shorter than chunk_size.
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    row_in_seq = offs_m < chunk_size_limit
    col_in_state = offs_n < dstate
    k_in_hdim = offs_k < hdim

    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim)
    C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize

    dout = tl.load(dout_ptrs, mask=row_in_seq[:, None] & k_in_hdim[None, :], other=0.0)
    prev_states = tl.load(prev_states_ptrs, mask=k_in_hdim[:, None] & col_in_state[None, :], other=0.0)
    # Bring prev_states to dout's element type before the dot product.
    prev_states = prev_states.to(dout_ptrs.dtype.element_ty)
    acc = tl.dot(dout, prev_states)
    c = tl.load(C_ptrs, mask=row_in_seq[:, None] & col_in_state[None, :], other=0.0).to(tl.float32)
    ddA_cs = tl.sum(acc * c, axis=1)

    dA_cs_m = tl.load(dA_cumsum_ptrs, mask=row_in_seq, other=0.0).to(tl.float32)
    if HAS_SEQ_IDX:
        # Sequence id of the element right before this chunk (0 for the first chunk).
        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=row_in_seq, other=-1)
        scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
    else:
        scale = tl.exp(dA_cs_m)
    ddA_cs *= scale

    ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
    tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None):
    """Validate shapes, allocate outputs and launch ``_chunk_scan_fwd_kernel``.

    Args:
        cb: (batch, nchunks, ngroups, chunk_size, chunk_size).
        x: (batch, seqlen, nheads, headdim).
        dt: (batch, nheads, nchunks, chunk_size).
        dA_cumsum: (batch, nheads, nchunks, chunk_size).
        C: (batch, seqlen, ngroups, dstate); ``nheads`` must be a multiple of
            ``ngroups``.
        states: (batch, nchunks, nheads, headdim, dstate).
        D: optional, (nheads, headdim) or (nheads,).
        z: optional, same shape as ``x``.
        seq_idx: optional, (batch, seqlen).

    Returns:
        ``(out, out_x)``. ``out`` has the shape, dtype and device of ``x``.
        ``out_x`` is a second buffer of the same shape when ``z`` is given,
        otherwise ``None``.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = C.shape
    assert nheads % ngroups == 0
    assert C.shape == (batch, seqlen, ngroups, dstate)
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    assert states.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    # Allocates output.
    out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
    if z is not None:
        out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
        # The kernel receives only out's strides, so out_x must share them.
        assert out_x.stride() == out.stride()
    else:
        out_x = None
    grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
                         batch * nchunks, nheads)
    z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))
                 if z is not None else (0, 0, 0, 0))
    # Launch on the device that owns the inputs, as the backward wrappers in
    # this file do; otherwise the kernel runs on the current CUDA device, which
    # may differ from x.device in multi-GPU processes.
    with torch.cuda.device(x.device.index):
        _chunk_scan_fwd_kernel[grid](
            cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D,
            int(chunk_size), int(headdim), int(dstate),
            int(batch), int(seqlen), int(nheads // ngroups),
            cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
            x.stride(0), x.stride(1), x.stride(2), x.stride(3),
            z_strides[0], z_strides[1], z_strides[2], z_strides[3],
            out.stride(0), out.stride(1), out.stride(2), out.stride(3),
            # dt / dA_cumsum are laid out (batch, head, chunk, csize) but their
            # strides are passed with dims 1 and 2 swapped.
            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
            C.stride(0), C.stride(1), C.stride(2), C.stride(3),
            states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
            D.stride(0) if D is not None else 0,
            # NOTE(review): positional flag whose meaning is fixed by the kernel
            # signature, which is not visible from this function.
            True,
            D is not None,
            D.dim() == 2 if D is not None else True,
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(int(dstate)), 16),
            HAS_Z=z is not None,
            HAS_SEQ_IDX=seq_idx is not None,
            IS_TRITON_22=TRITON_22,
        )
    return out, out_x
def _chunk_scan_fwd_wip(cb, x, dt, dA_cumsum, C, B, states, D=None, z=None, seq_idx=None):
    """Validate shapes, allocate outputs and launch ``_chunk_scan_fwd_kernel_wip``.

    Same contract as ``_chunk_scan_fwd`` with an additional ``B`` input; the
    launch grid has one program per headdim tile (no tiling over chunk rows)
    and ``BLOCK_SIZE_M`` is fixed to 128.

    Args:
        cb: (batch, nchunks, ngroups, chunk_size, chunk_size).
        x: (batch, seqlen, nheads, headdim).
        dt: (batch, nheads, nchunks, chunk_size).
        dA_cumsum: (batch, nheads, nchunks, chunk_size).
        C: (batch, seqlen, ngroups, dstate); ``nheads`` must be a multiple of
            ``ngroups``.
        B: same shape as ``C``.
        states: (batch, nchunks, nheads, headdim, dstate).
        D: optional, (nheads, headdim) or (nheads,).
        z: optional, same shape as ``x``.
        seq_idx: optional, (batch, seqlen).

    Returns:
        ``(out, out_x)``. ``out`` has the shape, dtype and device of ``x``.
        ``out_x`` is a second buffer of the same shape when ``z`` is given,
        otherwise ``None``.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = C.shape
    assert nheads % ngroups == 0
    assert C.shape == (batch, seqlen, ngroups, dstate)
    assert B.shape == C.shape
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    assert states.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    # Allocates output.
    out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
    if z is not None:
        out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
        # The kernel receives only out's strides, so out_x must share them.
        assert out_x.stride() == out.stride()
    else:
        out_x = None
    grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads)
    z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))
                 if z is not None else (0, 0, 0, 0))
    # Launch on the device that owns the inputs, as the backward wrappers in
    # this file do; otherwise the kernel runs on the current CUDA device, which
    # may differ from x.device in multi-GPU processes.
    with torch.cuda.device(x.device.index):
        _chunk_scan_fwd_kernel_wip[grid](
            cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, B, states, D,
            int(chunk_size), int(headdim), int(dstate),
            int(batch), int(seqlen), int(nheads // ngroups),
            cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
            x.stride(0), x.stride(1), x.stride(2), x.stride(3),
            z_strides[0], z_strides[1], z_strides[2], z_strides[3],
            out.stride(0), out.stride(1), out.stride(2), out.stride(3),
            # dt / dA_cumsum are laid out (batch, head, chunk, csize) but their
            # strides are passed with dims 1 and 2 swapped.
            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
            C.stride(0), C.stride(1), C.stride(2), C.stride(3),
            B.stride(0), B.stride(1), B.stride(2), B.stride(3),
            states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
            D.stride(0) if D is not None else 0,
            D is not None,
            D.dim() == 2 if D is not None else True,
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(int(dstate)), 16),
            BLOCK_SIZE_M=128,
            HAS_Z=z is not None,
            HAS_SEQ_IDX=seq_idx is not None,
        )
    return out, out_x
def _chunk_scan_bwd_dz(x, z, out, dout, chunk_size, has_ddAcs=True, D=None, dz=None, recompute_output=False):
    """Allocate buffers for and launch ``_chunk_scan_bwd_dz_kernel``.

    Args:
        x, z, out, dout: tensors sharing the shape
            (batch, seqlen, nheads, headdim).
        chunk_size: chunk length; ``nchunks`` is ``ceil(seqlen / chunk_size)``.
        has_ddAcs: when true, a float32 (batch, nheads, nchunks, chunk_size)
            buffer is allocated, passed to the kernel and returned.
        D: optional, (nheads, headdim) or (nheads,), unit stride in the last dim.
        dz: optional preallocated buffer with the shape of ``z``; allocated
            with ``torch.empty_like(z)`` when omitted.
        recompute_output: when true, an extra ``torch.empty_like(x)`` buffer is
            passed to the kernel and appended to the result.

    Returns:
        ``(dz, dout_x, dD)``, followed by ``ddA_cumsum`` if ``has_ddAcs`` and
        then by ``outz`` if ``recompute_output``. ``dD`` is ``None`` when ``D``
        is ``None``; otherwise it has the dtype of ``D`` and shape
        (nheads, headdim) or (nheads,) matching ``D``.
    """
    batch, seqlen, nheads, headdim = x.shape
    assert z.shape == x.shape
    assert out.shape == x.shape
    assert dout.shape == out.shape
    nchunks = math.ceil(seqlen / chunk_size)
    has_D = D is not None
    if has_D:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
        assert D.stride(-1) == 1

    ddA_cumsum = None
    ddA_strides = (0, 0, 0, 0)
    if has_ddAcs:
        ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
        s = ddA_cumsum.stride()
        # Strides are handed over with dims 1 and 2 swapped.
        ddA_strides = (s[0], s[2], s[1], s[3])

    if has_D:
        # One partial-result slice per row block; sized for a block of 32 rows,
        # presumably the smallest BLOCK_SIZE_M the kernel autotunes over.
        BLOCK_SIZE_min = 32
        dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
                         headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
        dD_strides = dD.stride()
    else:
        dD = None
        dD_strides = (0, 0, 0, 0, 0)

    if dz is None:
        dz = torch.empty_like(z)
    else:
        assert dz.shape == z.shape
    outz = torch.empty_like(x) if recompute_output else None
    outz_strides = outz.stride() if recompute_output else (0, 0, 0, 0)
    dout_x = torch.empty_like(dout)

    def grid_dz(META):
        return (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)

    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_dz_kernel[grid_dz](
            dout, out, z, x, D, outz,
            dz, dout_x, dD, ddA_cumsum,
            int(chunk_size), int(headdim),
            int(batch), int(seqlen),
            *dout.stride(),
            *out.stride(),
            *z.stride(),
            *x.stride(),
            D.stride(0) if has_D else 0,
            *outz_strides,
            *dz.stride(),
            *dout_x.stride(),
            # The block dimension (dim 0 of dD) is passed fourth.
            dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
            *ddA_strides,
            has_D,
            D.dim() == 2 if has_D else True,
            has_ddAcs,
            BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16),
            RECOMPUTE_OUTPUT=recompute_output,
        )

    if has_D:
        # Only the slices actually written by the selected config are reduced.
        BLOCK_SIZE_actual = _chunk_scan_bwd_dz_kernel.best_config.kwargs["BLOCK_SIZE_M"]
        n_valid_blocks = triton.cdiv(chunk_size, BLOCK_SIZE_actual)
        dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
        if D.dim() == 1:
            dD = rearrange(dD, "h 1 -> h")

    results = [dz, dout_x, dD]
    if has_ddAcs:
        results.append(ddA_cumsum)
    if recompute_output:
        results.append(outz)
    return tuple(results)
def _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=None, dtype=None):
    """Launch ``_chunk_scan_bwd_dstates_kernel`` into a fresh output buffer.

    Args:
        C: (batch, seqlen, ngroups, dstate); ``nheads`` must be a multiple of
            ``ngroups``.
        dA_cumsum: (batch, nheads, nchunks, chunk_size).
        dout: (batch, seqlen, nheads, headdim).
        seq_idx: optional, (batch, seqlen).
        dtype: dtype of the result; defaults to ``C.dtype``.

    Returns:
        ``dprev_states`` of shape (batch, nchunks, nheads, headdim, dstate),
        allocated on ``C.device``.
    """
    batch, seqlen, nheads, headdim = dout.shape
    _, _, nchunks, chunk_size = dA_cumsum.shape
    _, _, ngroups, dstate = C.shape
    assert nheads % ngroups == 0
    assert C.shape == (batch, seqlen, ngroups, dstate)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    has_seq_idx = seq_idx is not None
    if has_seq_idx:
        assert seq_idx.shape == (batch, seqlen)
    if dtype is None:
        dtype = C.dtype
    dprev_states = torch.empty(batch, nchunks, nheads, headdim, dstate, device=C.device, dtype=dtype)

    dA_s = dA_cumsum.stride()
    seq_idx_strides = seq_idx.stride() if has_seq_idx else (0, 0)

    def grid_dstates(META):
        n_tiles = triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N'])
        return (n_tiles, batch * nchunks, nheads)

    with torch.cuda.device(C.device.index):
        _chunk_scan_bwd_dstates_kernel[grid_dstates](
            dout, C, dprev_states, dA_cumsum, seq_idx,
            int(headdim), int(dstate), int(chunk_size),
            int(batch), int(seqlen), int(nchunks), int(nheads // ngroups),
            *dout.stride(),
            *C.stride(),
            *dprev_states.stride(),
            # dA_cumsum strides are passed with dims 1 and 2 swapped.
            dA_s[0], dA_s[2], dA_s[1], dA_s[3],
            *seq_idx_strides,
            HAS_SEQ_IDX=has_seq_idx,
        )
    return dprev_states
def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngroups=1):
    """Launch ``_chunk_scan_bwd_dc_kernel`` and reduce its per-split output.

    Args:
        prev_states: (batch, nchunks, nheads, headdim, dstate).
        dA_cumsum: (batch, nheads, nchunks, chunk_size).
        dout: (batch, seqlen, nheads, headdim).
        seq_idx: optional, (batch, seqlen).
        C: optional, (batch, seqlen, ngroups, dstate). When given, a float32
            (batch, nheads, nchunks, chunk_size) buffer is also allocated,
            passed to the kernel and returned.
        ngroups: number of groups the heads are divided into.

    Returns:
        ``dC`` of shape (batch, seqlen, ngroups, dstate) in float32 when ``C``
        is ``None``; otherwise ``(dC, ddA_cumsum_prev)``.
    """
    batch, nchunks, nheads, headdim, dstate = prev_states.shape
    _, seqlen, _, _ = dout.shape
    _, _, _, chunk_size = dA_cumsum.shape
    assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    assert dout.shape == (batch, seqlen, nheads, headdim)
    has_seq_idx = seq_idx is not None
    if has_seq_idx:
        assert seq_idx.shape == (batch, seqlen)

    if C is None:
        C_strides = (0, 0, 0, 0)
        ddA_cumsum_prev = None
        ddA_cumsum_prev_strides = (0, 0, 0, 0)
    else:
        assert C.shape == (batch, seqlen, ngroups, dstate)
        C_strides = C.stride()
        ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)
        s = ddA_cumsum_prev.stride()
        # Strides are handed over with dims 1 and 2 swapped.
        ddA_cumsum_prev_strides = (s[0], s[2], s[1], s[3])

    # Each program loops over up to nheads_per_program heads of one group:
    # roughly the total number of (batch, chunk, head) items per SM, clamped
    # to [1, heads per group].
    nheads_ngroups_ratio = nheads // ngroups
    sm_count = torch.cuda.get_device_properties(dout.device).multi_processor_count
    work_per_sm = math.ceil(batch * nchunks * nheads / sm_count)
    nheads_per_program = max(min(work_per_sm, nheads_ngroups_ratio), 1)
    nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
    # One float32 partial result per split; reduced over dim 2 after the launch.
    dC = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=dout.device, dtype=torch.float32)

    dA_s = dA_cumsum.stride()
    seq_idx_strides = seq_idx.stride() if has_seq_idx else (0, 0)

    def grid_dc(META):
        n_tiles = triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N'])
        return (n_tiles, batch * nchunks, nsplits * ngroups)

    with torch.cuda.device(dout.device.index):
        _chunk_scan_bwd_dc_kernel[grid_dc](
            dout, prev_states, C, dA_cumsum, seq_idx, dC, ddA_cumsum_prev,
            int(chunk_size), int(dstate), int(headdim),
            int(batch), int(seqlen), int(nheads), int(nheads_per_program), int(ngroups),
            *dout.stride(),
            *prev_states.stride(),
            *C_strides,
            dA_s[0], dA_s[2], dA_s[1], dA_s[3],
            *seq_idx_strides,
            *dC.stride(),
            *ddA_cumsum_prev_strides,
            HAS_DDA_CS=ddA_cumsum_prev is not None,
            HAS_SEQ_IDX=has_seq_idx,
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
        )
    dC = dC.sum(2)
    if C is None:
        return dC
    return dC, ddA_cumsum_prev
def _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=None, CB=None, ngroups=1):
    """Launch ``_chunk_scan_bwd_dcb_kernel`` and reduce its partial outputs.

    Args:
        x: (batch, seqlen, nheads, headdim).
        dt: (batch, nheads, nchunks, chunk_size).
        dA_cumsum: same shape as ``dt``.
        dout: same shape as ``x``.
        seq_idx: optional, (batch, seqlen).
        CB: optional, (batch, nchunks, ngroups, chunk_size, chunk_size). When
            given, a float32 ``ddA_cumsum`` scratch buffer is also allocated,
            passed to the kernel, reduced over its row-block dimension and
            returned.
        ngroups: number of groups the heads are divided into.

    Returns:
        ``dcb`` of shape (batch, nchunks, ngroups, chunk_size, chunk_size) in
        float32 when ``CB`` is ``None``; otherwise ``(dcb, ddA_cumsum)`` with
        ``ddA_cumsum`` of shape (batch, nheads, nchunks, chunk_size).
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dout.shape == x.shape
    has_seq_idx = seq_idx is not None
    if has_seq_idx:
        assert seq_idx.shape == (batch, seqlen)

    if CB is None:
        CB_strides = (0, 0, 0, 0, 0)
        ddA_cumsum = None
        ddA_cumsum_strides = (0, 0, 0, 0, 0)
    else:
        assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
        CB_strides = CB.stride()
        # One partial row per row block; sized for a block of 16 rows,
        # presumably the smallest BLOCK_SIZE_M the kernel autotunes over.
        BLOCK_SIZE_M_min = 16
        ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),
                                 chunk_size, device=x.device, dtype=torch.float32)
        s = ddA_cumsum.stride()
        # Strides are handed over with dims 1 and 2 swapped.
        ddA_cumsum_strides = (s[0], s[2], s[1], s[3], s[4])

    # Each program loops over up to nheads_per_program heads of one group:
    # roughly the total number of (batch, chunk, head) items per SM, clamped
    # to [1, heads per group].
    nheads_ngroups_ratio = nheads // ngroups
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    work_per_sm = math.ceil(batch * nchunks * nheads / sm_count)
    nheads_per_program = max(min(work_per_sm, nheads_ngroups_ratio), 1)
    nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
    # One float32 partial result per split; reduced over dim 2 after the launch.
    dcb = torch.empty(batch, nchunks, nsplits, ngroups, chunk_size, chunk_size, device=x.device, dtype=torch.float32)

    dt_s = dt.stride()
    dA_s = dA_cumsum.stride()
    seq_idx_strides = seq_idx.stride() if has_seq_idx else (0, 0)

    def grid_dcb(META):
        n_tiles = triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N'])
        return (n_tiles, batch * nchunks, nsplits * ngroups)

    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_dcb_kernel[grid_dcb](
            x, dout, CB, dt, dA_cumsum, seq_idx, dcb, ddA_cumsum,
            int(chunk_size), int(headdim),
            int(batch), int(seqlen), int(nheads), int(nheads_per_program), int(ngroups),
            *x.stride(),
            *dout.stride(),
            *CB_strides,
            dt_s[0], dt_s[2], dt_s[1], dt_s[3],
            dA_s[0], dA_s[2], dA_s[1], dA_s[3],
            *seq_idx_strides,
            *dcb.stride(),
            *ddA_cumsum_strides,
            HAS_DDA_CS=ddA_cumsum is not None,
            HAS_SEQ_IDX=has_seq_idx,
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
        )
    dcb = dcb.sum(2)
    if ddA_cumsum is not None:
        # Only the row blocks written by the selected config are reduced.
        BLOCK_SIZE_M_actual = _chunk_scan_bwd_dcb_kernel.best_config.kwargs["BLOCK_SIZE_M"]
        n_valid_blocks = triton.cdiv(chunk_size, BLOCK_SIZE_M_actual)
        ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3)
    if CB is None:
        return dcb
    return dcb, ddA_cumsum
def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None):
    """Launch ``_chunk_scan_bwd_dx_kernel`` to fill ``dx`` and ``ddt``.

    Args:
        cb: (batch, nchunks, ngroups, chunk_size, chunk_size); ``nheads`` must
            be a multiple of ``ngroups``.
        x: (batch, seqlen, nheads, headdim).
        dt: (batch, nheads, nchunks, chunk_size).
        dA_cumsum: same shape as ``dt``.
        dout: same shape as ``x``.
        D: optional; only its first stride and dimensionality are used here.

    Returns:
        ``(dx, ddt)``: ``dx`` is like ``x``; ``ddt`` has the shape of ``dt``,
        is computed in float32 and cast to ``dt.dtype`` before returning.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    ngroups = cb.shape[2]
    assert nheads % ngroups == 0
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dout.shape == x.shape

    dx = torch.empty_like(x)
    ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)

    has_D = D is not None
    dt_s = dt.stride()
    dA_s = dA_cumsum.stride()
    ddt_s = ddt.stride()

    def grid_dx(META):
        n_tiles = triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N'])
        return (n_tiles, batch * nchunks, nheads)

    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_dx_kernel[grid_dx](
            x, cb, dout, dt, dA_cumsum, D, dx, ddt,
            int(chunk_size), int(headdim),
            int(batch), int(seqlen), int(nheads // ngroups),
            *x.stride(),
            # The last two cb strides are passed in swapped order, so the
            # kernel addresses cb with its two chunk dimensions exchanged.
            cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(-1), cb.stride(-2),
            *dout.stride(),
            # dt-shaped tensors: strides passed with dims 1 and 2 swapped.
            dt_s[0], dt_s[2], dt_s[1], dt_s[3],
            dA_s[0], dA_s[2], dA_s[1], dA_s[3],
            D.stride(0) if has_D else 0,
            *dx.stride(),
            ddt_s[0], ddt_s[2], ddt_s[1], ddt_s[3],
            has_D,
            D.dim() == 2 if has_D else True,
        )
    return dx, ddt.to(dtype=dt.dtype)
def _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtract_ddtdt=True):
    """Launch ``_chunk_scan_bwd_ddAcs_unstable_kernel``.

    Not numerically stable and should not be used. Leaving here for reference.

    Args:
        x: (batch, seqlen, nheads, headdim).
        dt: (batch, nheads, nchunks, chunk_size).
        out, dout: same shape as ``x``.
        ddt: same shape as ``dt``.
        D: optional, (nheads, headdim) or (nheads,).
        subtract_ddtdt: flag forwarded unchanged to the kernel.

    Returns:
        ``(ddA_cumsum, dD)``: ``ddA_cumsum`` is like ``dt``; ``dD`` is ``None``
        when ``D`` is ``None``, otherwise it has the shape and dtype of ``D``.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert ddt.shape == dt.shape
    assert out.shape == x.shape
    assert dout.shape == x.shape
    has_D = D is not None
    if has_D:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    ddA_cumsum = torch.empty_like(dt)

    if has_D:
        # Triton gives wrong results if we write to the same location, so each
        # row block gets its own slice; sized for a block of 32 rows,
        # presumably the smallest BLOCK_SIZE_M the kernel autotunes over.
        BLOCK_SIZE_min = 32
        dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
                         headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
        dD_strides = dD.stride()
    else:
        dD = None
        dD_strides = (0, 0, 0, 0, 0)

    dt_s = dt.stride()
    ddt_s = ddt.stride()
    ddA_s = ddA_cumsum.stride()

    def grid_ddtcs(META):
        return (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)

    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_ddAcs_unstable_kernel[grid_ddtcs](
            dout, out, dt, ddt, x, D, ddA_cumsum, dD,
            int(chunk_size), int(headdim),
            int(batch), int(seqlen),
            *dout.stride(),
            *out.stride(),
            # dt-shaped tensors: strides passed with dims 1 and 2 swapped.
            dt_s[0], dt_s[2], dt_s[1], dt_s[3],
            ddt_s[0], ddt_s[2], ddt_s[1], ddt_s[3],
            *x.stride(),
            D.stride(0) if has_D else 0,
            ddA_s[0], ddA_s[2], ddA_s[1], ddA_s[3],
            # The block dimension (dim 0 of dD) is passed fourth.
            dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
            has_D,
            D.dim() == 2 if has_D else True,
            subtract_ddtdt,
            BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16),
        )
    if has_D:
        # Only the slices actually written by the selected config are reduced.
        BLOCK_SIZE_actual = _chunk_scan_bwd_ddAcs_unstable_kernel.best_config.kwargs["BLOCK_SIZE_M"]
        n_valid_blocks = triton.cdiv(chunk_size, BLOCK_SIZE_actual)
        dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
        if D.dim() == 1:
            dD = rearrange(dD, "h 1 -> h")
    return ddA_cumsum, dD
def _chunk_scan_bwd_ddAcs_stable_old(x, dt, dA_cumsum, dout, cb):
    """Launch ``_chunk_scan_bwd_ddAcs_stable_kernel_old`` and sum its row blocks.

    Args:
        x: (batch, seqlen, nheads, headdim).
        dt: (batch, nheads, nchunks, chunk_size).
        dA_cumsum: same shape as ``dt``.
        dout: same shape as ``x``.
        cb: (batch, nchunks, ngroups, chunk_size, chunk_size); ``nheads`` must
            be a multiple of ``ngroups``.

    Returns:
        float32 ``ddA_cumsum`` of shape (batch, nheads, nchunks, chunk_size).
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dout.shape == x.shape
    assert dA_cumsum.shape == dt.shape
    ngroups = cb.shape[2]
    assert nheads % ngroups == 0
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)

    # One partial row per row block; sized for a block of 16 rows, presumably
    # the smallest BLOCK_SIZE_M the kernel autotunes over.
    BLOCK_SIZE_M_min = 16
    ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),
                             chunk_size, device=x.device, dtype=torch.float32)

    dt_s = dt.stride()
    dA_s = dA_cumsum.stride()
    ddA_s = ddA_cumsum.stride()

    def grid_ddtcs(META):
        return (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)

    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_ddAcs_stable_kernel_old[grid_ddtcs](
            x, dout, dt, dA_cumsum, cb, ddA_cumsum,
            int(chunk_size), int(headdim),
            int(batch), int(seqlen), int(nheads // ngroups),
            *x.stride(),
            *dout.stride(),
            # Strides of (batch, head, chunk, ...) tensors are passed with
            # dims 1 and 2 swapped.
            dt_s[0], dt_s[2], dt_s[1], dt_s[3],
            dA_s[0], dA_s[2], dA_s[1], dA_s[3],
            *cb.stride(),
            ddA_s[0], ddA_s[2], ddA_s[1], ddA_s[3], ddA_s[4],
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
            BLOCK_SIZE_N=max(triton.next_power_of_2(chunk_size), 16),
        )
    # Only the row blocks written by the selected config are reduced.
    BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel_old.best_config.kwargs["BLOCK_SIZE_M"]
    n_valid_blocks = triton.cdiv(chunk_size, BLOCK_SIZE_M_actual)
    return ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3)
def _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, cb):
    """Launch ``_chunk_scan_bwd_ddAcs_stable_kernel`` and sum its row blocks.

    Args:
        x: (batch, seqlen, nheads, headdim).
        dt: (batch, nheads, nchunks, chunk_size).
        dA_cumsum: same shape as ``dt``.
        dout: same shape as ``x``.
        cb: (batch, nchunks, ngroups, chunk_size, chunk_size); ``nheads`` must
            be a multiple of ``ngroups``.

    Returns:
        float32 ``ddA_cumsum`` of shape (batch, nheads, nchunks, chunk_size).
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dout.shape == x.shape
    assert dA_cumsum.shape == dt.shape
    ngroups = cb.shape[2]
    assert nheads % ngroups == 0
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)

    # One partial row per row block; sized for a block of 32 rows, presumably
    # the smallest BLOCK_SIZE_M the kernel autotunes over.
    BLOCK_SIZE_M_min = 32
    ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),
                             chunk_size, device=x.device, dtype=torch.float32)

    dt_s = dt.stride()
    dA_s = dA_cumsum.stride()
    ddA_s = ddA_cumsum.stride()

    def grid_ddtcs(META):
        return (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads)

    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_ddAcs_stable_kernel[grid_ddtcs](
            x, dout, dt, dA_cumsum, cb, ddA_cumsum,
            int(chunk_size), int(headdim),
            int(batch), int(seqlen), int(nheads // ngroups),
            *x.stride(),
            *dout.stride(),
            # Strides of (batch, head, chunk, ...) tensors are passed with
            # dims 1 and 2 swapped.
            dt_s[0], dt_s[2], dt_s[1], dt_s[3],
            dA_s[0], dA_s[2], dA_s[1], dA_s[3],
            *cb.stride(),
            ddA_s[0], ddA_s[2], ddA_s[1], ddA_s[3], ddA_s[4],
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
        )
    # Only the row blocks written by the selected config are reduced.
    BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel.best_config.kwargs["BLOCK_SIZE_M"]
    n_valid_blocks = triton.cdiv(chunk_size, BLOCK_SIZE_M_actual)
    return ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3)
def _chunk_scan_bwd_ddAcs_prev(prev_states, C, dout, dA_cumsum, seq_idx=None):
    """
    Launch _chunk_scan_bwd_ddAcs_prev_kernel and return the buffer it fills.
    Argument:
        prev_states: (batch, nchunks, nheads, headdim, dstate)
        C: (batch, seqlen, ngroups, dstate)
        dout: (batch, seqlen, nheads, headdim)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
        seq_idx: (batch, seqlen) or None
    Return:
        ddA_cumsum_prev: (batch, nheads, nchunks, chunk_size), float32
    """
    batch, nchunks, nheads, headdim, dstate = prev_states.shape
    _, seqlen, _, _ = dout.shape
    _, _, _, chunk_size = dA_cumsum.shape
    assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    assert dout.shape == (batch, seqlen, nheads, headdim)
    ngroups = C.shape[2]
    assert nheads % ngroups == 0
    assert C.shape == (batch, seqlen, ngroups, dstate)

    has_seq_idx = seq_idx is not None
    if has_seq_idx:
        assert seq_idx.shape == (batch, seqlen)
        seq_idx_strides = (seq_idx.stride(0), seq_idx.stride(1))
    else:
        # Placeholder strides when there is no seq_idx tensor.
        seq_idx_strides = (0, 0)

    result = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)

    def launch_grid(meta):
        tiles_m = triton.cdiv(chunk_size, meta['BLOCK_SIZE_M'])
        tiles_n = triton.cdiv(dstate, meta['BLOCK_SIZE_N'])
        return (tiles_m * tiles_n, batch * nchunks, nheads)

    with torch.cuda.device(dout.device.index):
        _chunk_scan_bwd_ddAcs_prev_kernel[launch_grid](
            dout, prev_states, C, dA_cumsum, seq_idx, result,
            int(chunk_size), int(dstate), int(headdim),
            int(batch), int(seqlen), int(nchunks), int(nheads // ngroups),
            *dout.stride(),
            *prev_states.stride(),
            *C.stride(),
            # The (0, 2, 1, 3) order hands the kernel the chunk stride before the head stride.
            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
            *seq_idx_strides,
            result.stride(0), result.stride(2), result.stride(1), result.stride(3),
            HAS_SEQ_IDX=has_seq_idx,
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
        )
    return result
class ChunkScanFn(torch.autograd.Function):
    """Autograd wrapper around the chunk-scan forward and backward helpers."""

    @staticmethod
    def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
        """
        Argument:
            B: (batch, seqlen, ngroups, dstate)
            C: (batch, seqlen, ngroups, dstate)
            x: (batch, seqlen, nheads, headdim)
            dt: (batch, nheads, nchunks, chunk_size)
            dA_cumsum: (batch, nheads, nchunks, chunk_size)
            prev_states: (batch, nchunks, nheads, headdim, dstate)
            D: (nheads, headdim) or (nheads,), optional
            z: (batch, seqlen, nheads, headdim), optional
        Return:
            out: first value returned by _chunk_scan_fwd
        """
        # Validate shapes.
        batch, seqlen, nheads, headdim = x.shape
        _, _, ngroups, dstate = B.shape
        assert B.shape == (batch, seqlen, ngroups, dstate)
        _, _, nchunks, chunk_size = dt.shape
        assert seqlen == nchunks * chunk_size
        assert C.shape == B.shape
        assert z is None or z.shape == x.shape
        assert D is None or (D.shape == (nheads, headdim) or D.shape == (nheads,))
        per_chunk_shape = (batch, nheads, nchunks, chunk_size)
        assert dt.shape == per_chunk_shape
        assert dA_cumsum.shape == per_chunk_shape
        assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)

        # Make the inputs contiguous where the conditions below require it.
        B = B if B.stride(-1) == 1 else B.contiguous()
        C = C if C.stride(-1) == 1 else C.contiguous()
        # For x and z, either the M or the K dimension should be contiguous.
        if 1 not in (x.stride(-1), x.stride(1)):
            x = x.contiguous()
        if z is not None and 1 not in (z.stride(-1), z.stride(1)):
            z = z.contiguous()
        if D is not None and D.stride(-1) != 1:
            D = D.contiguous()

        CB = _bmm_chunk_fwd(C, B, chunk_size)
        out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, prev_states, D=D, z=z)
        # With z present, out_x is saved instead of out for use in backward.
        saved_out = out if z is None else out_x
        ctx.save_for_backward(saved_out, B, C, CB, x, dt, dA_cumsum, prev_states, D, z)
        return out

    @staticmethod
    def backward(ctx, dout):
        """
        Return gradients in the order of forward's inputs:
        (dB, dC, dx, ddt, ddA_cumsum, dprev_states, dD, dz).
        """
        dout = dout if dout.stride(-1) == 1 else dout.contiguous()
        out, B, C, CB, x, dt, dA_cumsum, prev_states, D, z = ctx.saved_tensors
        batch, seqlen, nheads, headdim = x.shape
        _, _, nchunks, chunk_size = dt.shape
        _, _, ngroups, dstate = B.shape
        assert dout.shape == (batch, seqlen, nheads, headdim)

        has_z = z is not None
        dz = None
        if has_z:
            # This call also yields dD and a first ddA_cumsum term, and replaces dout.
            dz, dout, dD, ddA_cumsum = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, D=D)

        dprev_states = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, dtype=prev_states.dtype)

        dC_from_states = _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, ngroups=ngroups)
        dC_from_states = dC_from_states.to(C.dtype)

        dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, ngroups=ngroups)
        dCB = dCB.to(CB.dtype)
        dB = _bmm_chunk_bwd(C, dCB)
        dCB_transposed = rearrange(dCB, "... l s -> ... s l")
        dC = _bmm_chunk_bwd(B, dCB_transposed, residual=dC_from_states)

        dx, ddt = _chunk_scan_bwd_dx(CB, x, dt, dA_cumsum, dout, D=D)

        # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
        # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
        if has_z:
            # ddA_cumsum and dD were already produced while computing dz.
            ddA_cumsum -= ddt * dt
        else:
            ddA_cumsum, dD = _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=D)
        ddA_cumsum = ddA_cumsum.to(dA_cumsum.dtype)

        return dB, dC, dx, ddt, ddA_cumsum, dprev_states, dD, dz
def chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
    """
    Differentiable chunk scan; forwards every argument to ChunkScanFn.apply.

    prev_states holds the initial_states at index 0 and the state for the
    next-to-last chunk at index -1.
    Argument:
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
        prev_states: (batch, nchunks, nheads, headdim, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    # autograd.Function.apply takes positional arguments only, so D and z are passed by position.
    positional_args = (B, C, x, dt, dA_cumsum, prev_states, D, z)
    return ChunkScanFn.apply(*positional_args)
def chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
    """
    Pure PyTorch reference implementation of the chunk scan.

    Within each chunk, output position l attends to positions s <= l with weight
    (C_l . B_s) * exp(dA_cumsum_l - dA_cumsum_s) * dt_s, and adds the contribution
    of the chunk's incoming state, C_l . prev_state * exp(dA_cumsum_l).
    Argument:
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
        prev_states: (batch, nchunks, nheads, headdim, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, ngroups, dstate = B.shape
    assert B.shape == (batch, seqlen, ngroups, dstate)
    _, _, nchunks, chunk_size = dt.shape
    assert seqlen == nchunks * chunk_size
    assert C.shape == B.shape

    # Share each group's B / C with all of its heads: "b l g d -> b l (g h) d".
    heads_per_group = nheads // ngroups
    B = B.repeat_interleave(heads_per_group, dim=2)
    C = C.repeat_interleave(heads_per_group, dim=2)

    # Split the sequence axis into (chunk, position in chunk): "b (c l) h n -> b c l h n".
    B_chunks = B.reshape(batch, nchunks, chunk_size, nheads, dstate)
    C_chunks = C.reshape(batch, nchunks, chunk_size, nheads, dstate)
    x_chunks = x.reshape(batch, nchunks, chunk_size, nheads, headdim)

    CB = torch.einsum("bclhn,bcshn->bchls", C_chunks, B_chunks)

    causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=torch.bool), diagonal=0)
    # (batch, nheads, nchunks, chunksize, chunksize)
    dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
    # Mask the exponent *before* exp(): above the diagonal (l < s) the difference can be
    # large and positive, so exp() would overflow to inf. Masking only afterwards keeps
    # the forward value at 0 but leaves 0 * inf = NaN in the gradients.
    dt_segment_sum = dt_segment_sum.masked_fill(~causal_mask, float("-inf"))
    decay = torch.exp(dt_segment_sum)
    # "b h c l s -> b c h l s"
    scores_decay = CB * decay.transpose(1, 2)
    # Kept so non-causal entries are exactly 0 regardless of the values in CB.
    scores_decay = scores_decay.masked_fill(~causal_mask, 0)

    out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), x_chunks)

    # "b h c l -> b c l h 1"
    state_decay_out = torch.exp(dA_cumsum.permute(0, 2, 3, 1).unsqueeze(-1))
    out_prev = torch.einsum('bclhn,bchpn->bclhp', C_chunks, prev_states.to(C.dtype)) * state_decay_out

    # "b c l h p -> b (c l) h p"
    out = (out + out_prev).reshape(batch, seqlen, nheads, headdim)
    if D is not None:
        if D.dim() == 1:
            # "h -> h 1" so a per-head D broadcasts over headdim.
            D = D.unsqueeze(-1)
        out = out + x * D
    return out if z is None else out * F.silu(z)
|