| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- # NOTE: this file is only a standalone script; run it directly.
- import os
- import torch
- import torch.nn as nn
- from functools import partial
- from PIL import Image
- from utils import visualize, get_dataset, AttnMamba, import_abspy, show_mask_on_image
- HOME = os.environ["HOME"].rstrip("/")
def main_vssm(det_model=True, showpath="show/vssmattnmap"):
    """Dump per-query activation maps of one VSSM block to ``showpath``.

    For every ``[idx, posx, posy]`` entry of a hard-coded list the function
    runs image ``dataset[idx]`` through a ``vmamba_tiny_s1l8`` backbone, saves
    the de-normalized image (plain, and with the query cell drawn on it), and
    then, for each attention-map mode, saves the row of the map that belongs
    to the query location, reshaped to ``featHW x featHW``.

    ``posx`` / ``posy`` are relative coordinates: they are multiplied by
    ``img_size`` for drawing and by ``featHW`` for indexing the map.

    Args:
        det_model: If true, use MSCOCO val2014 images and the mmdet Mask R-CNN
            checkpoint; otherwise use ImageNet val images and the
            classification checkpoint.
        showpath: Root output directory. One ``{idx}_{posx}_{posy}``
            sub-directory is created per query.

    Note:
        Dataset and checkpoint locations are hard-coded, and the model is
        moved to CUDA, so a GPU is required.
    """
    raw_attn = True  # True: flag only the selected block; False: flag every block
    stage = 2
    block_id = 1
    img_size = 512
    featHW = 32  # stage 2 so 32

    if not det_model:
        dataset = get_dataset(root='/media/Disk1/Dataset/ImageNet_ILSVRC2012/val', img_size=img_size, crop=False)
        # NOTE(review): the second row pairs idx 273 with idx 282 and the third
        # row repeats [282, 0.2, 0.8]; every other row uses one idx twice.
        # Possibly a typo -- confirm before relying on these figures.
        idxs_posxs_posys = [
            [72, 0.7, 0.3], [72, 0.2, 0.8],
            [273, 0.7, 0.3], [282, 0.2, 0.8],
            [282, 0.7, 0.3], [282, 0.2, 0.8],
            [14602, 0.7, 0.3], [14602, 0.2, 0.8],
            [17460, 0.6, 0.3], [17460, 0.2, 0.6],
            [19256, 0.7, 0.3], [19256, 0.2, 0.3],
            [47512, 0.7, 0.3], [47512, 0.3, 0.6],
        ]
        # To look up an index by file name:
        # print([i for i, s in enumerate(dataset.samples) if "ILSVRC2012_val_00012107.JPEG" in s[0] ])
        ckpt_path = f"{HOME}/Workspace/PylanceAware/ckpts/publish/vssm1/classification/vssm1_tiny_0230s/vssm1_tiny_0230s_ckpt_epoch_264.pth"
    else:
        # we want multiple objects, so we choose to use det model
        dataset = get_dataset(root='/media/Disk1/Dataset/MSCOCO2014/images/', img_size=img_size, ret="val2014", crop=False)
        idxs_posxs_posys = [
            [0, 0.3, 0.5], [0, 0.8, 0.8],
            [149, 0.7, 0.5], [149, 0.2, 0.4],
            [162, 0.7, 0.4], [162, 0.4, 0.4],
            [204, 0.3, 0.6], [204, 0.7, 0.2],
            [273, 0.2, 0.6], [273, 0.9, 0.5],
            [309, 0.1, 0.7], [309, 0.9, 0.8],
        ]
        ckpt_path = f"{HOME}/Workspace/PylanceAware/ckpts/private/detection/vssm1/detection/mask_rcnn_vssm_fpn_coco_tiny_ms_3x_s/epoch_36.pth"

    # Alternative dataset:
    # dataset = get_dataset(root='/media/Disk1/Dataset/ADEChallengeData2016/images/', img_size=img_size, ret="validation", crop=False)

    vmamba = import_abspy("vmamba", os.path.join(os.path.dirname(os.path.abspath(__file__)), "../classification/models"))
    vssm: nn.Module = vmamba.vmamba_tiny_s1l8().cuda().eval()

    # Load inside a context manager so the checkpoint file handle is closed.
    with open(ckpt_path, "rb") as ckpt_file:
        ckpt = torch.load(ckpt_file, map_location="cpu")
    if det_model:
        state_dict = AttnMamba.convert_state_dict_from_mmdet(ckpt["state_dict"])
    else:
        state_dict = ckpt["model"]
    vssm.load_state_dict(state_dict, strict=False)

    # Flag the op(s) with ``__DEBUG__``; presumably this makes them keep the
    # tensors that AttnMamba.get_attnmap_mamba reads -- confirm in AttnMamba.
    if raw_attn:
        ss2ds = vssm.layers[stage].blocks[block_id].op
        setattr(ss2ds, "__DEBUG__", True)
    else:
        ss2ds = [[blk.op for blk in layer.blocks] for layer in vssm.layers]
        for layer_ops in ss2ds:
            for op in layer_ops:
                setattr(op, "__DEBUG__", True)

    # De-normalization constants; assumes get_dataset normalizes with
    # mean 0.5 / std 0.25 per channel -- TODO confirm.
    std = torch.tensor([0.25, 0.25, 0.25]).view(-1, 1, 1)
    mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1)
    cell = img_size / featHW  # side of one feature-map cell in pixels

    for idx, posx, posy in idxs_posxs_posys:
        img, label = dataset[idx]
        with torch.no_grad():
            out = vssm(img[None].cuda())
        print(out.argmax().item(), label, img.shape)

        outdir = f"{showpath}/{idx}_{posx}_{posy}"
        os.makedirs(outdir, exist_ok=True)

        deimg = (img.cpu() * std + mean).permute(1, 2, 0)
        # Clamp before the cast: out-of-range floats would wrap around in uint8.
        rgb = (deimg.clamp(0, 1) * 255).to(torch.uint8).numpy()
        Image.fromarray(rgb).save(f"{outdir}/imori.jpg")
        visualize.draw_image_grid(
            Image.fromarray(rgb),
            [(posx * img_size, posy * img_size, cell, cell,)]
        ).save(f"{outdir}/imori_grid.jpg")

        # Row-major index of the query cell in the flattened featHW x featHW grid.
        query = int(posy * featHW) * int(featHW) + int(posx * featHW)

        # Other available modes: "ao0", "ao1", "ao2", "ao3"
        for m0 in ["a0", "a1", "a2", "a3", "all", "nall"]:
            for m1 in ["CB", "CwBw", "ww"]:
                mode = f"{m0}_norm_{m1}"
                # NOTE(review): raw_attn is passed as a literal True here,
                # independent of the local ``raw_attn`` switch above.
                aaa = AttnMamba.get_attnmap_mamba(ss2ds, stage, mode, raw_attn=True, block_id=block_id)
                # Full attention map / its diagonal can be dumped as well:
                # visualize.visualize_attnmap(aaa, f"{showpath}/{idx}_{posx}_{posy}/attn_{m0}_norm_{m1}.jpg", colorbar=False, sticks=False)
                # visualize.visualize_attnmap(torch.diag(aaa).view(featHW, featHW), f"{showpath}/{idx}_{posx}_{posy}/attn_{m0}_norm_{m1}_diag.jpg", colorbar=False, sticks=False)
                # activation map: the map row of the query cell, viewed as a grid
                mask = aaa[query].view(featHW, featHW)
                visualize.visualize_attnmap(mask, f"{outdir}/activation_{mode}.jpg", colorbar=False, sticks=False)
def main_deit(det_model=True, showpath="show/deitdet"):
    """Dump attention and per-query activation maps of a DeiT-small model.

    The ``forward`` of the model's attention modules is replaced by a version
    that also stores the attention tensor on the module as ``__data__``. For
    every ``[idx, posx, posy]`` query the image is run through the model, the
    stored tensor of layer ``'8'`` is min-max normalized and averaged over
    heads, and both the full map and the row belonging to the query location
    (reshaped to ``featHW x featHW``) are saved under ``showpath``.

    Args:
        det_model: If true, use MSCOCO val2014 images and the
            ``vit_adpter_baseline`` DeiT implementation; otherwise use
            ImageNet val images and the mmpretrain DeiT built by
            ``BuildModels.build_deit_mmpretrain``.
        showpath: Root output directory. One ``{idx}_{posx}_{posy}``
            sub-directory is created per query.

    Raises:
        AssertionError: If a block of the detection model has an attention
            module that is neither ``WindowedAttention`` nor ``Attention``.

    Note:
        Dataset and checkpoint locations are hard-coded, and the model is
        moved to CUDA, so a GPU is required.
    """
    img_size = 512
    featHW = 32  # token-grid side; assumes 16-pixel patches at 512 px -- TODO confirm
    layer_key = '8'  # name of the layer whose attention is visualized

    if not det_model:
        dataset = get_dataset(root='/media/Disk1/Dataset/ImageNet_ILSVRC2012/val', img_size=img_size, crop=False)
    else:
        # we want multiple objects, so we choose to use det model
        dataset = get_dataset(root='/media/Disk1/Dataset/MSCOCO2014/images/', img_size=img_size, ret="val2014", crop=False)

    # The same query list is used for both datasets.
    idxs_posxs_posys = [
        [0, 0.7, 0.3], [0, 0.2, 0.8],
        [149, 0.7, 0.5], [149, 0.2, 0.4],
        [162, 0.7, 0.4], [162, 0.4, 0.4],
        [204, 0.3, 0.6], [204, 0.7, 0.2],
        [273, 0.2, 0.6], [273, 0.9, 0.5],
        [309, 0.1, 0.7], [309, 0.9, 0.8],
    ]

    # Alternative dataset:
    # dataset = get_dataset(root='/media/Disk1/Dataset/ADEChallengeData2016/images/', img_size=img_size, ret="validation", crop=False)

    attns = dict()  # layer name -> patched attention module
    if det_model:
        _deit = import_abspy("vit_adpter_baseline", f"{HOME}/Workspace/PylanceAware/ckpts/ckpts")
        # from ckpts.ckpts.vit_adpter_baseline import deit_small_baseline, Attention, WindowedAttention
        Attention, WindowedAttention = _deit.Attention, _deit.WindowedAttention
        sd = torch.load(f"{HOME}/Workspace/PylanceAware/ckpts/others/deit_small_patch16_224-cd65a155.pth", map_location=torch.device("cpu"))
        # eval(): run inference-mode forward passes, consistent with the
        # classification branch below.
        model = _deit.deit_small_baseline().cuda().eval()
        model.load_state_dict(sd['model'], strict=False)

        def attn_forward(self: Attention, x, H, W):
            """Attention forward that also keeps the post-softmax weights.

            ``H`` and ``W`` are accepted for signature compatibility and are
            not used here.
            """
            B, N, C = x.shape
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            setattr(self, "__data__", attn)  # expose the weights for visualization
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)
            x = self.proj(x)
            x = self.proj_drop(x)
            return x

        for n, m in model.blocks.named_children():
            if isinstance(m.attn, WindowedAttention):
                continue  # windowed attention blocks are left untouched
            if not isinstance(m.attn, Attention):
                # Explicit raise instead of ``assert False`` so the check
                # survives ``python -O``.
                raise AssertionError(f"unexpected attention module in block {n}: {type(m.attn)}")
            m.attn.forward = partial(attn_forward, m.attn)
            attns[n] = m.attn
    else:
        from utils import BuildModels
        from mmpretrain.models.utils.attention import MultiheadAttention

        model = BuildModels.build_deit_mmpretrain(with_ckpt=True, scale="small").cuda().eval()

        def mattn_forward(self: MultiheadAttention, x):
            """MultiheadAttention forward that also keeps the scaled q.k^T scores.

            NOTE(review): the stored tensor is taken *before* softmax, whereas
            the detection branch stores post-softmax weights -- confirm that
            this difference is intended.
            """
            B, N, _ = x.shape
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                      self.head_dims).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
            attn_drop = self.attn_drop if self.training else 0.
            scale = q.size(-1)**0.5
            attn_weight = q @ k.transpose(-2, -1) / scale
            setattr(self, "__data__", attn_weight)  # expose the scores for visualization
            x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
            x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
            x = self.proj(x)
            x = self.out_drop(self.gamma1(self.proj_drop(x)))
            if self.v_shortcut:
                x = v.squeeze(1) + x
            return x

        for n, l in model.backbone.layers.named_children():
            l.attn.forward = partial(mattn_forward, l.attn)
            attns[n] = l.attn

    print(attns.keys())

    # De-normalization constants; assumes get_dataset normalizes with
    # mean 0.5 / std 0.25 per channel -- TODO confirm.
    std = torch.tensor([0.25, 0.25, 0.25]).view(-1, 1, 1)
    mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1)
    cell = img_size / featHW  # side of one token cell in pixels
    mode = "attn_norm_attn"  # keeps the file names parallel to main_vssm's "{m0}_norm_{m1}"

    for idx, posx, posy in idxs_posxs_posys:
        img, label = dataset[idx]
        with torch.no_grad():
            model(img[None].cuda())

        outdir = f"{showpath}/{idx}_{posx}_{posy}"
        os.makedirs(outdir, exist_ok=True)

        deimg = (img.cpu() * std + mean).permute(1, 2, 0)
        # Clamp before the cast: out-of-range floats would wrap around in uint8.
        rgb = (deimg.clamp(0, 1) * 255).to(torch.uint8).numpy()
        visualize.draw_image_grid(
            Image.fromarray(rgb),
            [(posx * img_size, posy * img_size, cell, cell,)]
        ).save(f"{outdir}/imori.jpg")

        # First batch element; min-max normalize over all heads, then average
        # over the leading (head) dimension.
        aaa = getattr(attns[layer_key], "__data__")[0]
        aaa = ((aaa - aaa.min()) / (aaa.max() - aaa.min())).mean(dim=0)

        # attention map
        visualize.visualize_attnmap(aaa, f"{outdir}/attn_{mode}.jpg", colorbar=False, sticks=False)

        # activation map
        if aaa.shape[0] == featHW * featHW + 1:
            # One extra token (presumably the class token at index 0): drop
            # its row and column so that only the featHW x featHW grid remains.
            aaa = aaa[1:, 1:]
        # Row-major index of the query cell in the flattened token grid.
        query = int(posy * featHW) * int(featHW) + int(posx * featHW)
        mask = aaa[query].view(featHW, featHW)
        visualize.visualize_attnmap(mask, f"{outdir}/activation_{mode}.jpg", colorbar=False, sticks=False)
if __name__ == "__main__":
    # Script entry point. Exactly one configuration is active at a time;
    # swap the commented call to generate another figure set.
    # main_deit(det_model=True, showpath="show/deitdet")
    # main_deit(det_model=False, showpath="show/deitcls")
    main_vssm(det_model=False, showpath="show/vssmcls")
    # main_vssm(det_model=True, showpath="show/vssmdet")
|