| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691 |
- import torch
- import warnings
- WITH_TRITON = True
- # WITH_TRITON = False
- try:
- import triton
- import triton.language as tl
- except:
- WITH_TRITON = False
- warnings.warn("Triton not installed, fall back to pytorch implements.")
- # to make sure cached_property can be loaded for triton
- if WITH_TRITON:
- try:
- from functools import cached_property
- except:
- warnings.warn("if you are using py37, add this line to functools.py: "
- "cached_property = lambda func: property(lru_cache()(func))")
- # torch implementation ========================================
- def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
- if in_channel_first:
- B, C, H, W = x.shape
- if scans == 0:
- y = x.new_empty((B, 4, C, H * W))
- y[:, 0, :, :] = x.flatten(2, 3)
- y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
- y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1])
- elif scans == 1:
- y = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
- elif scans == 2:
- y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
- y = torch.cat([y, y.flip(dims=[-1])], dim=1)
- else:
- B, H, W, C = x.shape
- if scans == 0:
- y = x.new_empty((B, H * W, 4, C))
- y[:, :, 0, :] = x.flatten(1, 2)
- y[:, :, 1, :] = x.transpose(dim0=1, dim1=2).flatten(1, 2)
- y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1])
- elif scans == 1:
- y = x.view(B, H * W, 1, C).repeat(1, 1, 4, 1)
- elif scans == 2:
- y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1)
- y = torch.cat([y, y.flip(dims=[1])], dim=2)
- if in_channel_first and (not out_channel_first):
- y = y.permute(0, 3, 1, 2).contiguous()
- elif (not in_channel_first) and out_channel_first:
- y = y.permute(0, 2, 3, 1).contiguous()
- return y
- def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
- if out_channel_first:
- B, K, D, H, W = y.shape
- y = y.view(B, K, D, -1)
- if scans == 0:
- y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
- y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
- elif scans == 1:
- y = y.sum(1)
- elif scans == 2:
- y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
- y = y.sum(1)
- else:
- B, H, W, K, D = y.shape
- y = y.view(B, -1, K, D)
- if scans == 0:
- y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
- y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D)
- elif scans == 1:
- y = y.sum(2)
- elif scans == 2:
- y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
- y = y.sum(2)
- if in_channel_first and (not out_channel_first):
- y = y.permute(0, 2, 1).contiguous()
- elif (not in_channel_first) and out_channel_first:
- y = y.permute(0, 2, 1).contiguous()
-
- return y
- def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
- if in_channel_first:
- B, _, C, H, W = x.shape
- if scans == 0:
- y = torch.stack([
- x[:, 0].flatten(2, 3),
- x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3),
- torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
- torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
- ], dim=1)
- elif scans == 1:
- y = x.flatten(2, 3)
- elif scans == 2:
- y = torch.stack([
- x[:, 0].flatten(2, 3),
- x[:, 1].flatten(2, 3),
- torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
- torch.flip(x[:, 3].flatten(2, 3), dims=[-1]),
- ], dim=1)
- else:
- B, H, W, _, C = x.shape
- if scans == 0:
- y = torch.stack([
- x[:, :, :, 0].flatten(1, 2),
- x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2),
- torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
- torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
- ], dim=2)
- elif scans == 1:
- y = x.flatten(1, 2)
- elif scans == 2:
- y = torch.stack([
- x[:, 0].flatten(1, 2),
- x[:, 1].flatten(1, 2),
- torch.flip(x[:, 2].flatten(1, 2), dims=[-1]),
- torch.flip(x[:, 3].flatten(1, 2), dims=[-1]),
- ], dim=2)
- if in_channel_first and (not out_channel_first):
- y = y.permute(0, 3, 1, 2).contiguous()
- elif (not in_channel_first) and out_channel_first:
- y = y.permute(0, 2, 3, 1).contiguous()
- return y
- def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
- if out_channel_first:
- B, K, D, H, W = y.shape
- y = y.view(B, K, D, -1)
- if scans == 0:
- y = torch.stack([
- y[:, 0],
- y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3),
- torch.flip(y[:, 2], dims=[-1]),
- torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
- ], dim=1)
- elif scans == 1:
- y = y
- elif scans == 2:
- y = torch.stack([
- y[:, 0],
- y[:, 1],
- torch.flip(y[:, 2], dims=[-1]),
- torch.flip(y[:, 3], dims=[-1]),
- ], dim=1)
- else:
- B, H, W, K, D = y.shape
- y = y.view(B, -1, K, D)
- if scans == 0:
- y = torch.stack([
- y[:, :, 0],
- y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2),
- torch.flip(y[:, :, 2], dims=[1]),
- torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
- ], dim=2)
- elif scans == 1:
- y = y
- elif scans == 2:
- y = torch.stack([
- y[:, :, 0],
- y[:, :, 1],
- torch.flip(y[:, :, 2], dims=[1]),
- torch.flip(y[:, :, 3], dims=[1]),
- ], dim=2)
- if out_channel_first and (not in_channel_first):
- y = y.permute(0, 3, 1, 2).contiguous()
- elif (not out_channel_first) and in_channel_first:
- y = y.permute(0, 2, 3, 1).contiguous()
- return y
- class CrossScanF(torch.autograd.Function):
- @staticmethod
- def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
- # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
- # y: (B, 4, C, H * W) | (B, H * W, 4, C)
- ctx.in_channel_first = in_channel_first
- ctx.out_channel_first = out_channel_first
- ctx.one_by_one = one_by_one
- ctx.scans = scans
- if one_by_one:
- B, K, C, H, W = x.shape
- if not in_channel_first:
- B, H, W, K, C = x.shape
- else:
- B, C, H, W = x.shape
- if not in_channel_first:
- B, H, W, C = x.shape
- ctx.shape = (B, C, H, W)
- _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
- y = _fn(x, in_channel_first, out_channel_first, scans)
- return y
-
- @staticmethod
- def backward(ctx, ys: torch.Tensor):
- # out: (b, k, d, l)
- in_channel_first = ctx.in_channel_first
- out_channel_first = ctx.out_channel_first
- one_by_one = ctx.one_by_one
- scans = ctx.scans
- B, C, H, W = ctx.shape
- ys = ys.view(B, -1, C, H, W) if out_channel_first else ys.view(B, H, W, -1, C)
- _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
- y = _fn(ys, in_channel_first, out_channel_first, scans)
-
- if one_by_one:
- y = y.view(B, 4, -1, H, W) if in_channel_first else y.view(B, H, W, 4, -1)
- else:
- y = y.view(B, -1, H, W) if in_channel_first else y.view(B, H, W, -1)
- return y, None, None, None, None
- class CrossMergeF(torch.autograd.Function):
- @staticmethod
- def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
- # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
- # y: (B, 4, C, H * W) | (B, H * W, 4, C)
- ctx.in_channel_first = in_channel_first
- ctx.out_channel_first = out_channel_first
- ctx.one_by_one = one_by_one
- ctx.scans = scans
- B, K, C, H, W = ys.shape
- if not out_channel_first:
- B, H, W, K, C = ys.shape
- ctx.shape = (B, C, H, W)
-
- _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
- y = _fn(ys, in_channel_first, out_channel_first, scans)
- return y
-
- @staticmethod
- def backward(ctx, x: torch.Tensor):
- # B, D, L = x.shape
- # out: (b, k, d, h, w)
- in_channel_first = ctx.in_channel_first
- out_channel_first = ctx.out_channel_first
- one_by_one = ctx.one_by_one
- scans = ctx.scans
- B, C, H, W = ctx.shape
-
- if not one_by_one:
- if in_channel_first:
- x = x.view(B, C, H, W)
- else:
- x = x.view(B, H, W, C)
- else:
- if in_channel_first:
- x = x.view(B, 4, C, H, W)
- else:
- x = x.view(B, H, W, 4, C)
-
- _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
- x = _fn(x, in_channel_first, out_channel_first, scans)
- x = x.view(B, 4, C, H, W) if out_channel_first else x.view(B, H, W, 4, C)
-
- return x, None, None, None, None
- # triton implements ========================================
- @triton.jit
- def triton_cross_scan_flex(
- x: tl.tensor, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
- y: tl.tensor, # (B, 4, C, H, W) | (B, H, W, 4, C)
- x_layout: tl.constexpr,
- y_layout: tl.constexpr,
- operation: tl.constexpr,
- onebyone: tl.constexpr,
- scans: tl.constexpr,
- BC: tl.constexpr,
- BH: tl.constexpr,
- BW: tl.constexpr,
- DC: tl.constexpr,
- DH: tl.constexpr,
- DW: tl.constexpr,
- NH: tl.constexpr,
- NW: tl.constexpr,
- ):
- # x_layout = 0
- # y_layout = 1 # 0 BCHW, 1 BHWC
- # operation = 0 # 0 scan, 1 merge
- # onebyone = 0 # 0 false, 1 true
- # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional
- i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
- i_h, i_w = (i_hw // NW), (i_hw % NW)
- _mask_h = (i_h * BH + tl.arange(0, BH)) < DH
- _mask_w = (i_w * BW + tl.arange(0, BW)) < DW
- _mask_hw = _mask_h[:, None] & _mask_w[None, :]
- _for_C = min(DC - i_c * BC, BC)
- pos_h = (i_h * BH + tl.arange(0, BH)[:, None])
- pos_w = (i_w * BW + tl.arange(0, BW)[None, :])
- neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None])
- neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :])
- if scans == 0:
- # none; trans; flip; trans + flip;
- HWRoute0 = pos_h * DW + pos_w
- HWRoute1 = pos_w * DH + pos_h # trans
- HWRoute2 = neg_h * DW + neg_w # flip
- HWRoute3 = neg_w * DH + neg_h # trans + flip
- elif scans == 1:
- # none; none; none; none;
- HWRoute0 = pos_h * DW + pos_w
- HWRoute1 = HWRoute0
- HWRoute2 = HWRoute0
- HWRoute3 = HWRoute0
- elif scans == 2:
- # none; none; flip; flip;
- HWRoute0 = pos_h * DW + pos_w
- HWRoute1 = HWRoute0
- HWRoute2 = neg_h * DW + neg_w # flip
- HWRoute3 = HWRoute2
- _tmp1 = DC * DH * DW
- y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
- if y_layout == 0:
- p_y1 = y_ptr_base + HWRoute0
- p_y2 = y_ptr_base + _tmp1 + HWRoute1
- p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2
- p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3
- else:
- p_y1 = y_ptr_base + HWRoute0 * 4 * DC
- p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC
- p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
- p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC
-
- if onebyone == 0:
- x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
- if x_layout == 0:
- p_x = x_ptr_base + HWRoute0
- else:
- p_x = x_ptr_base + HWRoute0 * DC
- if operation == 0:
- for idxc in range(_for_C):
- _idx_x = idxc * DH * DW if x_layout == 0 else idxc
- _idx_y = idxc * DH * DW if y_layout == 0 else idxc
- _x = tl.load(p_x + _idx_x, mask=_mask_hw)
- tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
- tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
- tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)
- tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)
- elif operation == 1:
- for idxc in range(_for_C):
- _idx_x = idxc * DH * DW if x_layout == 0 else idxc
- _idx_y = idxc * DH * DW if y_layout == 0 else idxc
- _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
- _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
- _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)
- _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)
- tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)
- else:
- x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
- if x_layout == 0:
- p_x1 = x_ptr_base + HWRoute0
- p_x2 = p_x1 + _tmp1
- p_x3 = p_x2 + _tmp1
- p_x4 = p_x3 + _tmp1
- else:
- p_x1 = x_ptr_base + HWRoute0 * 4 * DC
- p_x2 = p_x1 + DC
- p_x3 = p_x2 + DC
- p_x4 = p_x3 + DC
-
- if operation == 0:
- for idxc in range(_for_C):
- _idx_x = idxc * DH * DW if x_layout == 0 else idxc
- _idx_y = idxc * DH * DW if y_layout == 0 else idxc
- tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
- tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
- tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)
- tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)
- else:
- for idxc in range(_for_C):
- _idx_x = idxc * DH * DW if x_layout == 0 else idxc
- _idx_y = idxc * DH * DW if y_layout == 0 else idxc
- tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw)
- tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw)
- tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw)
- tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw)
- class CrossScanTritonF(torch.autograd.Function):
- @staticmethod
- def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
- if one_by_one:
- if in_channel_first:
- B, _, C, H, W = x.shape
- else:
- B, H, W, _, C = x.shape
- else:
- if in_channel_first:
- B, C, H, W = x.shape
- else:
- B, H, W, C = x.shape
- B, C, H, W = int(B), int(C), int(H), int(W)
- BC, BH, BW = 1, 32, 32
- NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
-
- ctx.in_channel_first = in_channel_first
- ctx.out_channel_first = out_channel_first
- ctx.one_by_one = one_by_one
- ctx.scans = scans
- ctx.shape = (B, C, H, W)
- ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
- y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C))
- triton_cross_scan_flex[(NH * NW, NC, B)](
- x.contiguous(), y,
- (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
- BC, BH, BW, C, H, W, NH, NW
- )
- return y
-
- @staticmethod
- def backward(ctx, y: torch.Tensor):
- in_channel_first = ctx.in_channel_first
- out_channel_first = ctx.out_channel_first
- one_by_one = ctx.one_by_one
- scans = ctx.scans
- B, C, H, W = ctx.shape
- BC, BH, BW, NC, NH, NW = ctx.triton_shape
- if one_by_one:
- x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C))
- else:
- x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C))
-
- triton_cross_scan_flex[(NH * NW, NC, B)](
- x, y.contiguous(),
- (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
- BC, BH, BW, C, H, W, NH, NW
- )
- return x, None, None, None, None
- class CrossMergeTritonF(torch.autograd.Function):
- @staticmethod
- def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
- if out_channel_first:
- B, _, C, H, W = y.shape
- else:
- B, H, W, _, C = y.shape
- B, C, H, W = int(B), int(C), int(H), int(W)
- BC, BH, BW = 1, 32, 32
- NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
- ctx.in_channel_first = in_channel_first
- ctx.out_channel_first = out_channel_first
- ctx.one_by_one = one_by_one
- ctx.scans = scans
- ctx.shape = (B, C, H, W)
- ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
- if one_by_one:
- x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C))
- else:
- x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))
- triton_cross_scan_flex[(NH * NW, NC, B)](
- x, y.contiguous(),
- (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
- BC, BH, BW, C, H, W, NH, NW
- )
- return x
-
- @staticmethod
- def backward(ctx, x: torch.Tensor):
- in_channel_first = ctx.in_channel_first
- out_channel_first = ctx.out_channel_first
- one_by_one = ctx.one_by_one
- scans = ctx.scans
- B, C, H, W = ctx.shape
- BC, BH, BW, NC, NH, NW = ctx.triton_shape
- y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C))
- triton_cross_scan_flex[(NH * NW, NC, B)](
- x.contiguous(), y,
- (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
- BC, BH, BW, C, H, W, NH, NW
- )
- return y, None, None, None, None, None
- # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
- def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
- # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
- # y: (B, 4, C, L) | (B, L, 4, C)
- # scans: 0: cross scan; 1 unidirectional; 2: bidirectional;
- CSF = CrossScanTritonF if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF
- with torch.cuda.device(x.device):
- return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
- # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
- def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
- # y: (B, 4, C, L) | (B, L, 4, C)
- # x: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C)
- # scans: 0: cross scan; 1 unidirectional; 2: bidirectional;
- CMF = CrossMergeTritonF if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF
- with torch.cuda.device(y.device):
- return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
- # checks =================================================================
- class CHECK:
- def check_csm_triton():
- B, C, H, W = 256, 192, 56, 57
- dtype=torch.float16
- dtype=torch.float32
- x = torch.randn((B, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
- y = torch.randn((B, 4, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
- x1 = x.clone().detach().requires_grad_(True)
- y1 = y.clone().detach().requires_grad_(True)
- def cross_scan(x: torch.Tensor):
- B, C, H, W = x.shape
- L = H * W
- xs = torch.stack([
- x.view(B, C, L),
- torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L),
- torch.flip(x.contiguous().view(B, C, L), dims=[-1]),
- torch.flip(torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
- ], dim=1).view(B, 4, C, L)
- return xs
-
- def cross_merge(out_y: torch.Tensor):
- B, K, D, H, W = out_y.shape
- L = H * W
- out_y = out_y.view(B, K, D, L)
- inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
- wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
- invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
- y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
- return y
- def cross_scan_1b1(x: torch.Tensor):
- B, K, C, H, W = x.shape
- L = H * W
- xs = torch.stack([
- x[:, 0].view(B, C, L),
- torch.transpose(x[:, 1], dim0=2, dim1=3).contiguous().view(B, C, L),
- torch.flip(x[:, 2].contiguous().view(B, C, L), dims=[-1]),
- torch.flip(torch.transpose(x[:, 3], dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
- ], dim=1).view(B, 4, C, L)
- return xs
-
- def unidi_scan(x):
- B, C, H, W = x.shape
- x = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
- return x
-
- def unidi_merge(ys):
- B, K, C, H, W = ys.shape
- return ys.view(B, 4, -1, H * W).sum(1)
- def bidi_scan(x):
- B, C, H, W = x.shape
- x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
- x = torch.cat([x, x.flip(dims=[-1])], dim=1)
- return x
-
- def bidi_merge(ys):
- B, K, D, H, W = ys.shape
- ys = ys.view(B, K, D, -1)
- ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
- return ys.contiguous().sum(1)
- if True:
- res0 = triton.testing.do_bench(lambda :cross_scan(x))
- res1 = triton.testing.do_bench(lambda :cross_scan_fn(x, True, True, False))
- # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x))
- res3 = triton.testing.do_bench(lambda :cross_merge(y))
- res4 = triton.testing.do_bench(lambda :cross_merge_fn(y, True, True, False))
- # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y))
- # print(res0, res1, res2, res3, res4, res5)
- print(res0, res1, res3, res4)
- res0 = triton.testing.do_bench(lambda :cross_scan(x).sum().backward())
- res1 = triton.testing.do_bench(lambda :cross_scan_fn(x, True, True, False).sum().backward())
- # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x).sum().backward())
- res3 = triton.testing.do_bench(lambda :cross_merge(y).sum().backward())
- res4 = triton.testing.do_bench(lambda :cross_merge_fn(y, True, True, False).sum().backward())
- # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y).sum().backward())
- # print(res0, res1, res2, res3, res4, res5)
- print(res0, res1, res3, res4)
- print("test cross scan")
- for (cs0, cm0, cs1, cm1) in [
- # channel_first -> channel_first
- (cross_scan, cross_merge, cross_scan_fn, cross_merge_fn),
- (unidi_scan, unidi_merge, lambda x: cross_scan_fn(x, scans=1), lambda x: cross_merge_fn(x, scans=1)),
- (bidi_scan, bidi_merge, lambda x: cross_scan_fn(x, scans=2), lambda x: cross_merge_fn(x, scans=2)),
-
- # flex: BLC->BCL; BCL->BLC; BLC->BLC;
- (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False), lambda x: cross_merge_fn(x, in_channel_first=False).permute(0, 2, 1)),
- (cross_scan, cross_merge, lambda x: cross_scan_fn(x, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), out_channel_first=False)),
- (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), in_channel_first=False, out_channel_first=False).permute(0, 2, 1)),
-
- # previous
- # (cross_scan, cross_merge, lambda x: CrossScanTriton.apply(x), lambda x: CrossMergeTriton.apply(x)),
- # (unidi_scan, unidi_merge, lambda x: getCSM(1)[0].apply(x), lambda x: getCSM(1)[1].apply(x)),
- # (bidi_scan, bidi_merge, lambda x: getCSM(2)[0].apply(x), lambda x: getCSM(2)[1].apply(x)),
- ]:
- x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
- o0 = cs0(x)
- o1 = cs1(x1)
- o0.backward(y.view(B, 4, C, H * W))
- o1.backward(y.view(B, 4, C, H * W))
- print((o0 - o1).abs().max())
- print((x.grad - x1.grad).abs().max())
- o0 = cm0(y)
- o1 = cm1(y1)
- o0.backward(x.view(B, C, H * W))
- o1.backward(x.view(B, C, H * W))
- print((o0 - o1).abs().max())
- print((y.grad - y1.grad).abs().max())
- x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
- print("===============", flush=True)
- print("test cross scan one by one")
- for (cs0, cs1) in [
- (cross_scan_1b1, lambda x: cross_scan_fn(x, one_by_one=True)),
- # (cross_scan_1b1, lambda x: CrossScanTriton1b1.apply(x)),
- ]:
- o0 = cs0(y)
- o1 = cs1(y1)
- o0.backward(y.view(B, 4, C, H * W))
- o1.backward(y.view(B, 4, C, H * W))
- print((o0 - o1).abs().max())
- print((y.grad - y1.grad).abs().max())
- x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
- print("===============", flush=True)
- def check_csm_scan3():
- if False:
- x = torch.arange(0, 16).view(1, 1, 4, 4).cuda()
- out1 = cross_scan_fn(x, scans=3, force_torch=True).view(1, 4, 1, 4, 4)
- out2 = cross_merge_fn(out1, scans=3, force_torch=True).view(1, 1, 4, 4)
- out4 = cross_merge_fn(out1, one_by_one=True, scans=3, force_torch=True).view(1, 4, 1, 4, 4)
- out3 = cross_scan_fn(out4, one_by_one=True, scans=3, force_torch=True).view(1, 4, 1, 4, 4)
- out5 = cross_scan_fn(x.view(1, 4, 4, 1), in_channel_first=False, out_channel_first=False, scans=3, force_torch=True).view(1, 4, 4, 4, 1)
- out6 = cross_merge_fn(out5, in_channel_first=False, out_channel_first=False, scans=3, force_torch=True).view(1, 4, 4, 1)
- out8 = cross_merge_fn(out5, in_channel_first=False, out_channel_first=False, one_by_one=True, scans=3, force_torch=True).view(1, 4, 4, 4, 1)
- out7 = cross_scan_fn(out8, in_channel_first=False, out_channel_first=False, one_by_one=True, scans=3, force_torch=True).view(1, 4, 4, 4, 1)
- print(out1.view(4, -1))
- print(out2.view(-1))
- print(out3.view(4, -1))
- print(out4.view(4, -1))
- print(out5.view(-1, 4).t())
- print(out6.view(-1))
- print(out7.view(-1, 4).t())
- print(out8.view(-1, 4).t())
- B, C, H, W = 27, 253, 57, 58
- x = torch.randn((B, C, H, W)).cuda()
- for scans in [0, 1, 2, 3]:
- o1 = cross_scan_fn(x, scans=scans, force_torch=True).view(B, 4, C, H, W)
- print((cross_scan_fn(x, scans=scans) == cross_scan_fn(x, scans=scans, force_torch=True)).all())
- print((cross_merge_fn(o1, scans=scans) == cross_merge_fn(o1, scans=scans, force_torch=True)).all())
- kwargs = dict(in_channel_first=False, out_channel_first=False)
- x2 = x.permute(0, 2, 3, 1).contiguous()
- o2 = o1.permute(0, 3, 4, 1, 2).contiguous()
- print((cross_scan_fn(x, scans=scans, **kwargs) == cross_scan_fn(x, scans=scans, force_torch=True, **kwargs)).all())
- print((cross_merge_fn(o2, scans=scans, **kwargs) == cross_merge_fn(o2, scans=scans, force_torch=True, **kwargs)).all())
- breakpoint()
- if __name__ == "__main__":
- CHECK.check_csm_scan3()
- CHECK.check_csm_triton()
|