| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619 |
- import os
- import time
- import math
- import copy
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import triton
- from functools import partial
- from collections import OrderedDict
- from vmamba import CrossScan, CrossMerge, CrossScan_Ab_1direction, CrossMerge_Ab_1direction, CrossScan_Ab_2direction, CrossMerge_Ab_2direction, getCSM
- from vmamba import CrossScanTriton, CrossMergeTriton, CrossScanTriton1b1
- from vmamba import VSSM, PatchMerging2D, Mlp, gMlp, LayerNorm2d, VSSBlock
class CHECKS:
    """Namespace of ad-hoc equivalence checks and micro-benchmarks for VMamba parts.

    Every entry is a static function invoked on the class itself, e.g.
    ``CHECKS.check_csm_triton()``. The GPU checks move their modules/tensors to
    CUDA and report differences and timings with ``print`` instead of asserting,
    so their output has to be read by a human.
    """

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _count_mismatched(reference, other):
        """Print every key of ``reference`` whose tensor differs from ``other[key]``.

        Args:
            reference: mapping name -> tensor (typically a ``state_dict()``).
            other: mapping that holds the same keys; a missing key raises ``KeyError``.

        Returns:
            The number of keys whose tensors are not element-wise equal.
        """
        mismatched = 0
        for key, value in reference.items():
            same = (value == other[key]).all()
            if not same:
                print(key, same)
                mismatched += 1
        return mismatched

    @staticmethod
    def _compare_pair(n1, n2, inp, inp2, align=None):
        """Compare two modules fed with equivalent inputs, then benchmark both.

        Runs ``n1(inp)`` and ``n2(inp2)`` under CUDA autocast and prints the max
        absolute difference of the outputs and of the input gradients (``n1``-side
        tensors are passed through ``align`` first so the layouts line up). Then it
        prints ``triton.testing.do_bench`` timings, in the order n1 forward,
        n2 forward, n1 forward+backward, n2 forward+backward, using
        ``.float()`` copies of the inputs.

        Args:
            n1, n2: the two modules under comparison.
            inp, inp2: leaf tensors with ``requires_grad`` set, fed to n1 / n2.
            align: optional callable mapping n1's layout onto n2's; identity if None.
        """
        if align is None:
            def align(t):
                return t
        with torch.cuda.amp.autocast():
            o1 = n1(inp)
            o2 = n2(inp2)
        print((align(o1) - o2).abs().max())
        # backward is run outside autocast, as recommended for AMP
        o1.sum().backward()
        o2.sum().backward()
        print((align(inp.grad) - inp2.grad).abs().max())
        # .float() returns the tensor itself when it is already fp32
        i1, i2 = inp.float(), inp2.float()
        ms2 = triton.testing.do_bench(lambda: n2(i2))
        ms1 = triton.testing.do_bench(lambda: n1(i1))
        ms4 = triton.testing.do_bench(lambda: n2(i2).sum().backward())
        ms3 = triton.testing.do_bench(lambda: n1(i1).sum().backward())
        print(ms1, ms2, ms3, ms4)

    # ------------------------------------------------------- equivalence checks
    @staticmethod
    def check_vssm_equals_vmambadp():
        """Check that the published VSSM reproduces the older VMamba2Dp.

        Test 1 copies VMamba2Dp's weights into VSSM and prints the summed absolute
        difference of fp16/autocast ``forward_backbone`` and ``forward`` outputs.
        Test 2 builds both models from the same seed and counts state-dict entries
        whose initial values differ. Returns early, with a message, when the
        reference sources under ``_ignore/`` cannot be imported.
        """
        try:
            from _ignore.vmamba.vmamba_bak1 import VMamba2Dp
            from _ignore.vmamba.vmamba_pub import VSSM
        except ImportError:
            print("original VSSM and VMamba2Dp not found.", flush=True)
            return

        # test 1 True =================================
        torch.manual_seed(time.time()); torch.cuda.manual_seed(time.time())
        oldvss = VMamba2Dp(depths=[2, 2, 6, 2]).half().cuda()
        newvss = VSSM(depths=[2, 2, 6, 2]).half().cuda()
        newvss.load_state_dict(oldvss.state_dict())
        inp = torch.randn((12, 3, 224, 224)).half().cuda()
        torch.cuda.manual_seed(0)
        with torch.cuda.amp.autocast():
            y1 = oldvss.forward_backbone(inp)
        torch.cuda.manual_seed(0)
        with torch.cuda.amp.autocast():
            y2 = newvss.forward_backbone(inp)
        print((y1 - y2).abs().sum())  # observed: tensor(0., device='cuda:0', grad_fn=<SumBackward0>)

        torch.cuda.manual_seed(0)
        with torch.cuda.amp.autocast():
            y1 = oldvss.forward(inp)
        torch.cuda.manual_seed(0)
        with torch.cuda.amp.autocast():
            y2 = newvss.forward(inp)
        print((y1 - y2).abs().sum())  # observed: tensor(0., device='cuda:0', grad_fn=<SumBackward0>)

        # test 2 True ==========================================
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        oldvss = VMamba2Dp(depths=[2, 2, 6, 2]).cuda()
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        newvss = VSSM(depths=[2, 2, 6, 2]).cuda()
        miss_align = CHECKS._count_mismatched(oldvss.state_dict(), newvss.state_dict())
        print("init miss align", miss_align)  # observed: init miss align 0

    @staticmethod
    def check_vssm1_equals_vssm(forward_type="v0"):
        """Check that VSSM with v1-compatible settings reproduces the published VSSM.

        ``VSSM_`` overrides ``_make_layer`` so each stage is ``blocks`` followed by a
        ``PatchMerging2D`` (when no ``downsample`` kwarg is given) or ``nn.Identity``.
        It also adds ``forward_backbone`` / ``forward1`` helpers. Test 1 compares
        outputs after copying weights; test 2 counts differing initial weights.

        Args:
            forward_type: SS2D forward variant passed to the new model (e.g. "v0").
        """
        try:
            from _ignore.vmamba.vmamba_pub import VSSM as VSSM0
        except ImportError:
            print("original VSSM and VMamba2Dp not found.", flush=True)
            return

        class VSSM_(VSSM):
            @staticmethod
            def _make_layer(*args, **kwargs):
                layer = VSSM._make_layer(*args, **kwargs)
                dim = kwargs.get("dim", None)
                norm_layer = kwargs.get("norm_layer", None)
                downsample = kwargs.get("downsample", None)
                blocks = layer.blocks

                # is this really applied? Yes, but been overriden later in VSSM!
                # The init below only touches detached clones, so its sole effect is
                # to advance the RNG the same way the original layer init did.
                def _init_weights(module: nn.Module):
                    for name, p in module.named_parameters():
                        if name in ["out_proj.weight"]:
                            p = p.clone().detach_()  # fake init, just to keep the seed ....
                            nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                blks = nn.Sequential(*copy.deepcopy(blocks))
                blks.apply(_init_weights)
                downsample = PatchMerging2D(dim, 2 * dim, norm_layer=norm_layer) if downsample is None else nn.Identity()

                return nn.Sequential(OrderedDict(
                    blocks=nn.Sequential(*blocks,),
                    downsample=downsample,
                ))

            def forward_backbone(self, x):
                x = self.patch_embed(x)
                for l in self.layers:
                    x = l(x)
                return x

            def forward1(self, x: torch.Tensor):
                x = self.patch_embed(x)
                for layer in self.layers:
                    x = layer(x)
                x = self.classifier.norm(x)
                # here: whether has contiguous would differ
                x = self.classifier.avgpool(x.permute(0, 3, 1, 2).contiguous()).flatten(1)
                x = self.classifier.head(x)
                return x

        # Alternative that only has initial difference (kept for reference):
        #   partial(VSSM, downsample_version="v1", patchembed_version="v1",
        #           mlp_ratio=0.0, ssm_ratio=2.0, forward_type=forward_type)
        # It would also need forward_backbone/forward1, which VSSM_ provides without
        # patching the shared VSSM class.
        # expected to be all the same
        VSSM1 = partial(VSSM_, downsample_version="none", patchembed_version="v1", mlp_ratio=0.0, ssm_ratio=2.0, forward_type=forward_type)

        # test 1 True =================================
        torch.manual_seed(time.time()); torch.cuda.manual_seed(time.time())
        oldvss = VSSM0(depths=[2, 2, 6, 2]).half().cuda()
        newvss = VSSM1(depths=[2, 2, 6, 2]).half().cuda()
        newvss.load_state_dict(oldvss.state_dict())
        inp = torch.randn((12, 3, 224, 224)).half().cuda()
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        with torch.cuda.amp.autocast():
            y1 = oldvss.forward_backbone(inp)
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        with torch.cuda.amp.autocast():
            y2 = newvss.forward_backbone(inp)
        print((y1 - y2).abs().sum())  # observed: tensor(0., device='cuda:0', grad_fn=<SumBackward0>)

        torch.manual_seed(0); torch.cuda.manual_seed(0)
        with torch.cuda.amp.autocast():
            y1 = oldvss.forward(inp)
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        with torch.cuda.amp.autocast():
            y2 = newvss.forward1(inp)
        print((y1 - y2).abs().sum())  # observed: tensor(2.5988e-05, device='cuda:0', grad_fn=<SumBackward0>)
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        with torch.cuda.amp.autocast():
            y3 = newvss.forward(inp)
        print((y1 - y3).abs().sum())  # observed: tensor(0., device='cuda:0', grad_fn=<SumBackward0>)

        # test 2 True ==========================================
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        oldvss = VSSM0(depths=[2, 2, 6, 2]).cuda()
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        newvss = VSSM1(depths=[2, 2, 6, 2]).cuda()
        oldvss2new = copy.deepcopy(newvss)
        oldvss2new.load_state_dict(oldvss.state_dict())
        miss_align = CHECKS._count_mismatched(oldvss2new.state_dict(), newvss.state_dict())
        print("init miss align", miss_align)  # observed: init miss align 0

    @staticmethod
    def check_vssm1_ssoflex_equals_mambassm():
        """Check that forward_type "v01" matches "v2" for an otherwise identical VSSM.

        Test 1 copies the "v2" model's weights into the "v01" model and prints the
        summed absolute difference of fp16/autocast outputs and input gradients.
        Test 2 counts parameters whose seeded initialisation differs.
        """
        # only has initial difference
        VSSM0 = partial(VSSM, downsample_version="v3", patchembed_version="v2", mlp_ratio=4.0, ssm_ratio=2.0, forward_type="v2")
        VSSM1 = partial(VSSM, downsample_version="v3", patchembed_version="v2", mlp_ratio=4.0, ssm_ratio=2.0, forward_type="v01")

        # test 1 True =================================
        torch.manual_seed(time.time()); torch.cuda.manual_seed(time.time())
        oldvss = VSSM0(depths=[2, 2, 6, 2]).half().cuda()
        newvss = VSSM1(depths=[2, 2, 6, 2]).half().cuda()
        newvss.load_state_dict(oldvss.state_dict())
        input0 = torch.randn((12, 3, 224, 224)).half().cuda().requires_grad_()
        input1 = input0.detach().clone().requires_grad_()
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        with torch.cuda.amp.autocast():
            y1 = oldvss.forward(input0)
        y1.sum().backward()
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        with torch.cuda.amp.autocast():
            y2 = newvss.forward(input1)
        y2.sum().backward()
        print((y1 - y2).abs().sum())  # observed: tensor(0., device='cuda:0', dtype=torch.float16, grad_fn=<SumBackward0>)
        print((input0.grad - input1.grad).abs().sum())  # observed: tensor(6.6016, device='cuda:0', dtype=torch.float16)

        # test 2 True ==========================================
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        oldvss = VSSM0(depths=[2, 2, 6, 2]).cuda()
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        newvss = VSSM1(depths=[2, 2, 6, 2]).cuda()
        oldvss2new = copy.deepcopy(newvss)
        oldvss2new.load_state_dict(oldvss.state_dict())
        miss_align = CHECKS._count_mismatched(oldvss2new.state_dict(), newvss.state_dict())
        print("init miss align", miss_align)  # observed: init miss align 0

    @staticmethod
    def check_csm_triton():
        """Validate and benchmark the cross-scan / cross-merge autograd functions.

        Compares pure-PyTorch reference implementations with ``CrossScan``,
        ``CrossMerge``, the Triton kernels, and the ``getCSM(1)`` / ``getCSM(2)``
        variants. Prints forward timings, forward+backward timings, and the max
        absolute error of outputs and input gradients. Finally prints the reference
        scan/merge of small ``arange`` tensors for visual inspection.
        """
        B, C, H, W = 256, 192, 56, 57
        dtype = torch.float32  # set to torch.float16 to exercise half precision
        device = torch.device("cuda")
        x = torch.randn((B, C, H, W), dtype=dtype, device=device).requires_grad_(True)
        y = torch.randn((B, 4, C, H, W), dtype=dtype, device=device).requires_grad_(True)
        x1 = x.clone().detach().requires_grad_(True)
        y1 = y.clone().detach().requires_grad_(True)

        def cross_scan(x: torch.Tensor):
            """(B, C, H, W) -> (B, 4, C, H*W): row-major, column-major, and both reversed."""
            B, C, H, W = x.shape
            L = H * W
            xs = torch.stack([
                x.view(B, C, L),
                torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L),
                torch.flip(x.contiguous().view(B, C, L), dims=[-1]),
                torch.flip(torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
            ], dim=1).view(B, 4, C, L)
            return xs

        def cross_merge(out_y: torch.Tensor):
            """(B, 4, D, H, W) -> (B, D, H*W): undo each scan order and sum the four."""
            B, K, D, H, W = out_y.shape
            L = H * W
            out_y = out_y.view(B, K, D, L)
            inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
            wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
            invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
            y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
            return y

        def cross_scan_1b1(x: torch.Tensor):
            """Like cross_scan, but input k (of 4) supplies only scan direction k."""
            B, K, C, H, W = x.shape
            L = H * W
            xs = torch.stack([
                x[:, 0].view(B, C, L),
                torch.transpose(x[:, 1], dim0=2, dim1=3).contiguous().view(B, C, L),
                torch.flip(x[:, 2].contiguous().view(B, C, L), dims=[-1]),
                torch.flip(torch.transpose(x[:, 3], dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
            ], dim=1).view(B, 4, C, L)
            return xs

        def unidi_scan(x):
            """Repeat the row-major sequence four times: (B, C, H, W) -> (B, 4, C, H*W)."""
            B, C, H, W = x.shape
            x = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
            return x

        def unidi_merge(ys):
            """Sum the four copies: (B, 4, C, H, W) -> (B, C, H*W)."""
            B, K, C, H, W = ys.shape
            return ys.view(B, 4, -1, H * W).sum(1)

        def bidi_scan(x):
            """Two forward and two reversed row-major copies: (B, C, H, W) -> (B, 4, C, H*W)."""
            B, C, H, W = x.shape
            x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
            x = torch.cat([x, x.flip(dims=[-1])], dim=1)
            return x

        def bidi_merge(ys):
            """Un-reverse copies 2-3 and sum all four: (B, 4, D, H, W) -> (B, D, H*W)."""
            B, K, D, H, W = ys.shape
            ys = ys.view(B, K, D, -1)
            ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
            return ys.contiguous().sum(1)

        # forward-only timings, then forward+backward timings
        res0 = triton.testing.do_bench(lambda: cross_scan(x))
        res1 = triton.testing.do_bench(lambda: CrossScan.apply(x))
        res2 = triton.testing.do_bench(lambda: CrossScanTriton.apply(x))
        res3 = triton.testing.do_bench(lambda: cross_merge(y))
        res4 = triton.testing.do_bench(lambda: CrossMerge.apply(y))
        res5 = triton.testing.do_bench(lambda: CrossMergeTriton.apply(y))
        print(res0, res1, res2, res3, res4, res5)
        res0 = triton.testing.do_bench(lambda: cross_scan(x).sum().backward())
        res1 = triton.testing.do_bench(lambda: CrossScan.apply(x).sum().backward())
        res2 = triton.testing.do_bench(lambda: CrossScanTriton.apply(x).sum().backward())
        res3 = triton.testing.do_bench(lambda: cross_merge(y).sum().backward())
        res4 = triton.testing.do_bench(lambda: CrossMerge.apply(y).sum().backward())
        res5 = triton.testing.do_bench(lambda: CrossMergeTriton.apply(y).sum().backward())
        print(res0, res1, res2, res3, res4, res5)

        print("test cross scan")
        for (cs0, cm0, cs1, cm1) in [
            (cross_scan, cross_merge, CrossScanTriton, CrossMergeTriton),
            (unidi_scan, unidi_merge, getCSM(1)[0], getCSM(1)[1]),
            (bidi_scan, bidi_merge, getCSM(2)[0], getCSM(2)[1]),
        ]:
            # the benchmarks above (and earlier iterations) accumulate grads; reset them
            x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
            o0 = cs0(x)
            o1 = cs1.apply(x1)
            o0.backward(y.view(B, 4, C, H * W))
            o1.backward(y.view(B, 4, C, H * W))
            print((o0 - o1).abs().max())
            print((x.grad - x1.grad).abs().max())
            o0 = cm0(y)
            o1 = cm1.apply(y1)
            o0.backward(x.view(B, C, H * W))
            o1.backward(x.view(B, C, H * W))
            print((o0 - o1).abs().max())
            print((y.grad - y1.grad).abs().max())
            x.grad, x1.grad, y.grad, y1.grad = None, None, None, None

        print("test cross scan one by one")
        o0 = cross_scan_1b1(y)
        o1 = CrossScanTriton1b1.apply(y1)
        o0.backward(y.view(B, 4, C, H * W))
        o1.backward(y.view(B, 4, C, H * W))
        print((o0 - o1).abs().max())
        print((y.grad - y1.grad).abs().max())
        x.grad, x1.grad, y.grad, y1.grad = None, None, None, None

        # small integer grids to eyeball the scan orders
        demo = torch.arange(0, 16, 1).view(1, 1, 4, 4)
        print(demo)
        print(cross_scan(demo))
        print(cross_merge(cross_scan(demo).view(1, 4, 1, 4, 4)))
        demo = torch.arange(0, 16, 1).view(1, 4, 1, 2, 2)
        print(demo)
        print(cross_scan_1b1(demo))

    @staticmethod
    def check_einsum():
        """Check that a grouped 1x1 ``conv1d`` equals the per-group einsum projection.

        ``einsum("bkrl,kdr->bkdl")`` versus ``F.conv1d(..., groups=K)``. Prints the
        max absolute difference of outputs and of the gradient w.r.t. ``x``.
        """
        B, D, L, R, K = 128, 192, 56 * 56, 12, 4
        o = torch.randn((B, K * D, L)).cuda()
        x = torch.randn((B, K, R, L)).cuda().requires_grad_(True)
        w = torch.randn((K, D, R)).cuda().requires_grad_(True)
        x1 = x.clone().detach().requires_grad_(True)
        w1 = w.clone().detach().requires_grad_(True)
        y1 = torch.einsum("bkrl,kdr->bkdl", x, w).contiguous().view(B, -1, L)
        y2 = F.conv1d(x1.view(B, -1, L), w1.view(K * D, R, 1), None, groups=K).contiguous().view(B, -1, L)
        print((y1 - y2).abs().max())
        y1.backward(o)
        y2.backward(o)
        print((x.grad - x1.grad).abs().max())

    # --------------------------------------------------------------- benchmarks
    @staticmethod
    def check_vssblock(hold_seconds=0.0):
        """Benchmark a VSSBlock against a torchvision ViT ``EncoderBlock`` of width 16.

        Prints forward and forward+backward timings (``triton.testing.do_bench``) of
        both blocks on a (16, 128, 128, 16) input; the EncoderBlock receives it as
        (16, 16384, 16) tokens.

        Args:
            hold_seconds: seconds to sleep after benchmarking (presumably to inspect
                the live process); the old code always slept 10000 s, now opt-in.
        """
        from torchvision.models.vision_transformer import EncoderBlock
        vb = VSSBlock(
            hidden_dim=16,
            drop_path=0.0,
            norm_layer=nn.LayerNorm,
            ssm_d_state=1,
            ssm_ratio=2,
            ssm_dt_rank="auto",
            ssm_act_layer=nn.SiLU,
            ssm_conv=3,
            ssm_conv_bias=False,
            ssm_drop_rate=0.0,
            ssm_init="v0",
            forward_type="v2",
            mlp_ratio=4,
            mlp_act_layer=nn.GELU,
            mlp_drop_rate=0.0,
            use_checkpoint=False,
        ).cuda()

        trans = EncoderBlock(
            num_heads=1,
            hidden_dim=16,
            mlp_dim=int(4.0 * 16),
            dropout=0.0,
            attention_dropout=0.0,
            norm_layer=nn.LayerNorm,
        ).cuda()
        inp = torch.randn((16, 128, 128, 16)).cuda().requires_grad_()
        inp2 = inp.detach().cuda().view(16, -1, 16).requires_grad_()
        for fn in (
            lambda: vb(inp),
            lambda: trans(inp2),
            lambda: vb(inp).sum().backward(),
            lambda: trans(inp2).sum().backward(),
        ):
            print(triton.testing.do_bench(fn, warmup=100))
        if hold_seconds > 0:
            time.sleep(hold_seconds)

    @staticmethod
    def check_ln2d():
        """Compare ``LayerNorm2d`` on NCHW with ``nn.LayerNorm`` on NHWC, then time them.

        Prints the max absolute output and input-gradient differences, then forward
        timings for LayerNorm2d, nn.LayerNorm, and LayerNorm2d again.
        """
        B, C, H, W = 128, 8192, 7, 7
        inp = torch.randn((B, C, H, W)).cuda().requires_grad_()
        inp2 = inp.detach().permute(0, 2, 3, 1).clone().requires_grad_()
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        n1 = LayerNorm2d(C).cuda()
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        n2 = nn.LayerNorm(C).cuda()
        o1 = n1(inp)
        o2 = n2(inp2)
        print((o1.permute(0, 2, 3, 1) - o2).abs().max())
        o1.backward(inp.data)
        o2.backward(inp.data.permute(0, 2, 3, 1))
        print((inp.grad.permute(0, 2, 3, 1) - inp2.grad).abs().max())
        ms1 = triton.testing.do_bench(lambda: n1(inp))
        ms2 = triton.testing.do_bench(lambda: n2(inp2))
        ms3 = triton.testing.do_bench(lambda: n1(inp))
        print(ms1, ms2, ms3)

    @staticmethod
    def check_linear_2d():
        """Compare a channels-first ``Mlp`` on NCHW with a channels-last one on NHWC.

        Both are built from the same seed. The two ``catch_random`` draws print
        equal values only if both constructors consumed the CPU RNG identically.
        """
        inp = torch.randn((64, 192, 56, 57)).cuda().requires_grad_()
        inp2 = inp.detach().permute(0, 2, 3, 1).clone().requires_grad_()
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        n1 = Mlp(192, 4 * 192, 384, channels_first=True).cuda()
        catch_random1 = torch.randn((1,))
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        n2 = Mlp(192, 4 * 192, 384, channels_first=False).cuda()
        catch_random2 = torch.randn((1,))
        print(catch_random1, catch_random2)
        CHECKS._compare_pair(n1, n2, inp, inp2, align=lambda t: t.permute(0, 2, 3, 1))

    @staticmethod
    def check_gmlp():
        """Run a channels-first ``Mlp`` and a channels-last ``gMlp`` side by side.

        The modules differ, so the printed differences are only a sanity signal;
        the benchmark timings are the main output.
        """
        inp = torch.randn((64, 192, 56, 57)).cuda().requires_grad_()
        inp2 = inp.detach().permute(0, 2, 3, 1).clone().requires_grad_()
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        n1 = Mlp(192, 4 * 192, 384, channels_first=True).cuda()
        catch_random1 = torch.randn((1,))
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        n2 = gMlp(192, 2 * 192, 384, channels_first=False).cuda()
        catch_random2 = torch.randn((1,))
        print(catch_random1, catch_random2)
        CHECKS._compare_pair(n1, n2, inp, inp2, align=lambda t: t.permute(0, 2, 3, 1))

    @staticmethod
    def check_channel_first():
        """Compare full VSSM models built with ``norm_layer="ln"`` versus ``"ln2d"``.

        Uses the same seed and an fp16 (64, 3, 224, 224) input; prints output and
        input-gradient differences, then fp32 timings.
        """
        inp = torch.randn((64, 3, 224, 224)).cuda().half().requires_grad_()
        inp2 = inp.detach().clone().requires_grad_()
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        n1 = VSSM(norm_layer="ln").cuda()
        catch_random1 = torch.randn((1,))
        torch.manual_seed(0); torch.cuda.manual_seed(0)
        n2 = VSSM(norm_layer="ln2d").cuda()
        catch_random2 = torch.randn((1,))
        print(catch_random1, catch_random2)
        CHECKS._compare_pair(n1, n2, inp, inp2)

    @staticmethod
    def check_profile():
        """Profile forward passes of a one-stage fp16 VSSM and print a kernel table.

        The profiler schedule skips 1 step, warms up for 1, records 2, then calls
        ``trace_handler``. The loop runs exactly that many model forwards; the old
        loop never called the model.
        """
        vss = VSSM(depths=[1], dims=1024).half().cuda()
        inp = torch.randn((128, 3, 56, 56)).half().cuda()
        torch.cuda.manual_seed(0)

        def trace_handler(prof: torch.profiler.profile):
            print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
            # print(prof.export_chrome_trace("./tracev1.json"))

        wait, warmup, active = 1, 1, 2
        with torch.cuda.amp.autocast():
            with torch.profiler.profile(
                with_modules=True,
                with_stack=True,
                profile_memory=True,
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                # skip `wait` steps, warm up for `warmup`, record `active`, once;
                # on_trace_ready fires when the recording window closes
                schedule=torch.profiler.schedule(
                    wait=wait,
                    warmup=warmup,
                    active=active,
                    repeat=1),
                on_trace_ready=trace_handler,
                # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log')
            ) as prof:
                for _ in range(wait + warmup + active):
                    # wrap sub-steps in torch.autograd.profiler.record_function(...)
                    # to label them in the table
                    vss(inp)
                    prof.step()

    # ------------------------------------------------------ checkpoint surgery
    @staticmethod
    def load22kto1k(state_dict=None, model=None, logger=None):
        """Adapt a pretrained checkpoint ``state_dict`` to ``model`` in place.

        Steps, in order:
          * drop ``relative_position_index``, ``relative_coords_table`` and
            ``attn_mask`` entries (the model always re-initialises them);
          * bicubically resize ``relative_position_bias_table`` and
            ``absolute_pos_embed`` entries whose token count differs from the
            model's (a head/channel mismatch is logged and the entry left as is);
          * reconcile the classifier: a 21841-class head loaded into a 1000-class
            model is remapped with ``data/map22kto1k.txt``; any other class-count
            mismatch zero-initialises ``model.head`` and deletes ``head.*``.

        Called with no arguments it is a no-op and returns ``None``.

        Args:
            state_dict: checkpoint mapping, modified in place; must contain
                ``head.weight`` and ``head.bias`` when used.
            model: target module providing ``state_dict()`` and a ``head`` layer.
            logger: logger for progress/warnings; defaults to this module's logger.

        Returns:
            None. ``state_dict`` (and possibly ``model.head``) are updated in place.
        """
        import logging

        if state_dict is None or model is None:
            return None
        if logger is None:
            logger = logging.getLogger(__name__)

        # buffers that are always re-created by the model: never load them
        for marker in ("relative_position_index", "relative_coords_table", "attn_mask"):
            for k in [k for k in state_dict if marker in k]:
                del state_dict[k]

        model_state = model.state_dict()

        # bicubic interpolate relative_position_bias_table if not match
        for k in [k for k in state_dict if "relative_position_bias_table" in k]:
            table_pretrained = state_dict[k]
            L1, nH1 = table_pretrained.size()
            L2, nH2 = model_state[k].size()
            if nH1 != nH2:
                logger.warning(f"Error in loading {k}, passing......")
            elif L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                table_resized = F.interpolate(
                    table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
                    mode='bicubic')
                state_dict[k] = table_resized.view(nH2, L2).permute(1, 0)

        # bicubic interpolate absolute_pos_embed if not match
        for k in [k for k in state_dict if "absolute_pos_embed" in k]:
            pos_pretrained = state_dict[k]
            _, L1, C1 = pos_pretrained.size()
            _, L2, C2 = model_state[k].size()
            if C1 != C2:
                # a channel mismatch cannot be fixed by spatial resizing
                logger.warning(f"Error in loading {k}, passing......")
            elif L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                grid = pos_pretrained.reshape(-1, S1, S1, C1).permute(0, 3, 1, 2)
                grid = F.interpolate(grid, size=(S2, S2), mode='bicubic')
                state_dict[k] = grid.permute(0, 2, 3, 1).flatten(1, 2)

        # check classifier, if not match, then re-init classifier to zero
        Nc1 = state_dict['head.bias'].shape[0]
        Nc2 = model.head.bias.shape[0]
        if Nc1 != Nc2:
            if Nc1 == 21841 and Nc2 == 1000:
                logger.info("loading ImageNet-22K weight to ImageNet-1K ......")
                map22kto1k_path = 'data/map22kto1k.txt'
                with open(map22kto1k_path) as f:
                    map22kto1k = [int(id22k.strip()) for id22k in f]
                state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
                state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
            else:
                torch.nn.init.constant_(model.head.bias, 0.)
                torch.nn.init.constant_(model.head.weight, 0.)
                del state_dict['head.weight']
                del state_dict['head.bias']
                logger.warning("Error in loading classifier head, re-init classifier head to 0")
        return None
if __name__ == "__main__":
    # Add entries to this tuple to run more diagnostics. Other entry points on
    # CHECKS: check_vssblock, check_vssm_equals_vmambadp,
    # check_vssm1_equals_vssm(forward_type="v0" or "v0_seq"),
    # check_vssm1_ssoflex_equals_mambassm, check_einsum, check_ln2d,
    # check_linear_2d, check_gmlp, check_channel_first.
    enabled_checks = (CHECKS.check_csm_triton,)
    for run_check in enabled_checks:
        run_check()
-
|