attnmap.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
# This is only a script!
  2. import os
  3. import torch
  4. import torch.nn as nn
  5. from functools import partial
  6. from PIL import Image
  7. from utils import visualize, get_dataset, AttnMamba, import_abspy, show_mask_on_image
  8. HOME = os.environ["HOME"].rstrip("/")
  9. def main_vssm(det_model=True, showpath= "show/vssmattnmap"):
  10. raw_attn = True
  11. stage = 2
  12. block_id = 1
  13. img_size = 512
  14. featHW = 32 # stage 2 so 32
  15. if not det_model:
  16. dataset = get_dataset(root='/media/Disk1/Dataset/ImageNet_ILSVRC2012/val', img_size=img_size, crop=False)
  17. idxs_posxs_posys = [
  18. [72, 0.7, 0.3], [72, 0.2, 0.8],
  19. [273, 0.7, 0.3], [282, 0.2, 0.8],
  20. [282, 0.7, 0.3], [282, 0.2, 0.8],
  21. [14602, 0.7, 0.3], [14602, 0.2, 0.8],
  22. [17460, 0.6, 0.3], [17460, 0.2, 0.6],
  23. [19256, 0.7, 0.3], [19256, 0.2, 0.3],
  24. [47512, 0.7, 0.3], [47512, 0.3, 0.6],
  25. ]
  26. # print([i for i, s in enumerate(dataset.samples) if "ILSVRC2012_val_00012107.JPEG" in s[0] ])
  27. else:
  28. # we want multiple objects, so we choose to use det model
  29. dataset = get_dataset(root='/media/Disk1/Dataset/MSCOCO2014/images/', img_size=img_size, ret="val2014", crop=False)
  30. idxs_posxs_posys = [
  31. [0, 0.3, 0.5], [0, 0.8, 0.8],
  32. [149, 0.7, 0.5], [149, 0.2, 0.4],
  33. [162, 0.7, 0.4], [162, 0.4, 0.4],
  34. [204, 0.3, 0.6], [204, 0.7, 0.2],
  35. [273, 0.2, 0.6], [273, 0.9, 0.5],
  36. [309, 0.1, 0.7], [309, 0.9, 0.8],
  37. ]
  38. # dataset = get_dataset(root='/media/Disk1/Dataset/ADEChallengeData2016/images/', img_size=img_size, ret="validation", crop=False)
  39. vmamba = import_abspy("vmamba", os.path.join(os.path.dirname(os.path.abspath(__file__)), "../classification/models"))
  40. vssm: nn.Module = vmamba.vmamba_tiny_s1l8().cuda().eval()
  41. if det_model:
  42. vssm.load_state_dict(AttnMamba.convert_state_dict_from_mmdet(torch.load(open(f"{HOME}/Workspace/PylanceAware/ckpts/private/detection/vssm1/detection/mask_rcnn_vssm_fpn_coco_tiny_ms_3x_s/epoch_36.pth", "rb"), map_location="cpu")["state_dict"]), strict=False)
  43. else:
  44. vssm.load_state_dict(torch.load(open(f"{HOME}/Workspace/PylanceAware/ckpts/publish/vssm1/classification/vssm1_tiny_0230s/vssm1_tiny_0230s_ckpt_epoch_264.pth", "rb"), map_location="cpu")["model"], strict=False)
  45. if raw_attn:
  46. setattr(vssm.layers[stage].blocks[block_id].op, "__DEBUG__", True)
  47. ss2ds = vssm.layers[stage].blocks[block_id].op
  48. else:
  49. [[ setattr(blk.op, "__DEBUG__", True) for blk in layer.blocks] for layer in vssm.layers ]
  50. ss2ds = [[blk.op for blk in layer.blocks] for layer in vssm.layers ]
  51. for idx, posx, posy in idxs_posxs_posys:
  52. img, label = dataset[idx]
  53. with torch.no_grad():
  54. out = vssm(img[None].cuda())
  55. print(out.argmax().item(), label, img.shape)
  56. os.makedirs(f"{showpath}/{idx}_{posx}_{posy}", exist_ok=True)
  57. deimg = img.cpu() * torch.tensor([0.25, 0.25, 0.25]).view(-1, 1, 1) + torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1)
  58. deimg = deimg.permute(1, 2, 0).cpu()
  59. Image.fromarray((deimg * 255).to(torch.uint8).numpy()).save(f"{showpath}/{idx}_{posx}_{posy}/imori.jpg")
  60. visualize.draw_image_grid(
  61. Image.fromarray((deimg * 255).to(torch.uint8).numpy()),
  62. [(posx * img_size, posy * img_size, img_size / featHW, img_size / featHW,)]
  63. ).save(f"{showpath}/{idx}_{posx}_{posy}/imori_grid.jpg")
  64. # continue
  65. for m0 in ["a0", "a1", "a2", "a3", "all", "nall"]:
  66. # for m0 in ["ao0", "ao1", "ao2", "ao3", "a0", "a1", "a2", "a3", "all", "nall"]:
  67. for m1 in ["CB", "CwBw", "ww"]:
  68. aaa = AttnMamba.get_attnmap_mamba(ss2ds, stage, f"{m0}_norm_{m1}", raw_attn=True, block_id=block_id)
  69. # attention map
  70. # visualize.visualize_attnmap(aaa, f"{showpath}/{idx}_{posx}_{posy}/attn_{m0}_norm_{m1}.jpg", colorbar=False, sticks=False)
  71. # diag attention map
  72. # visualize.visualize_attnmap(torch.diag(aaa).view(featHW, featHW), f"{showpath}/{idx}_{posx}_{posy}/attn_{m0}_norm_{m1}_diag.jpg", colorbar=False, sticks=False)
  73. # activation map
  74. mask = aaa[int(posy * featHW) * int(featHW) + int(posx * featHW)].view(featHW, featHW)
  75. visualize.visualize_attnmap(mask, f"{showpath}/{idx}_{posx}_{posy}/activation_{m0}_norm_{m1}.jpg", colorbar=False, sticks=False)
  76. def main_deit(det_model=True, showpath="show/deitdet"):
  77. raw_attn = True
  78. stage = 2
  79. block_id = 1
  80. img_size = 512
  81. featHW = 32 # stage 2 so 32
  82. if not det_model:
  83. dataset = get_dataset(root='/media/Disk1/Dataset/ImageNet_ILSVRC2012/val', img_size=img_size, crop=False)
  84. idxs_posxs_posys = [
  85. [0, 0.7, 0.3], [0, 0.2, 0.8],
  86. [149, 0.7, 0.5], [149, 0.2, 0.4],
  87. [162, 0.7, 0.4], [162, 0.4, 0.4],
  88. [204, 0.3, 0.6], [204, 0.7, 0.2],
  89. [273, 0.2, 0.6], [273, 0.9, 0.5],
  90. [309, 0.1, 0.7], [309, 0.9, 0.8],
  91. ]
  92. else:
  93. # we want multiple objects, so we choose to use det model
  94. dataset = get_dataset(root='/media/Disk1/Dataset/MSCOCO2014/images/', img_size=img_size, ret="val2014", crop=False)
  95. idxs_posxs_posys = [
  96. [0, 0.7, 0.3], [0, 0.2, 0.8],
  97. [149, 0.7, 0.5], [149, 0.2, 0.4],
  98. [162, 0.7, 0.4], [162, 0.4, 0.4],
  99. [204, 0.3, 0.6], [204, 0.7, 0.2],
  100. [273, 0.2, 0.6], [273, 0.9, 0.5],
  101. [309, 0.1, 0.7], [309, 0.9, 0.8],
  102. ]
  103. # dataset = get_dataset(root='/media/Disk1/Dataset/ADEChallengeData2016/images/', img_size=img_size, ret="validation", crop=False)
  104. attns = dict()
  105. deit_small_baseline = None
  106. if det_model:
  107. _deit = import_abspy("vit_adpter_baseline", f"{HOME}/Workspace/PylanceAware/ckpts/ckpts")
  108. # from ckpts.ckpts.vit_adpter_baseline import deit_small_baseline, Attention, WindowedAttention
  109. deit_small_baseline, Attention, WindowedAttention = _deit.deit_small_baseline, _deit.Attention, _deit.WindowedAttention
  110. sd = torch.load(f"{HOME}/Workspace/PylanceAware/ckpts/others/deit_small_patch16_224-cd65a155.pth", map_location=torch.device("cpu"))
  111. deit_small_baseline = deit_small_baseline().cuda()
  112. deit_small_baseline.load_state_dict(sd['model'], strict=False)
  113. def attn_forward(self: Attention, x, H, W):
  114. B, N, C = x.shape
  115. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  116. q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
  117. attn = (q @ k.transpose(-2, -1)) * self.scale
  118. attn = attn.softmax(dim=-1)
  119. setattr(self, "__data__", attn)
  120. attn = self.attn_drop(attn)
  121. x = (attn @ v).transpose(1, 2).reshape(B, N, C)
  122. x = self.proj(x)
  123. x = self.proj_drop(x)
  124. return x
  125. for n, m in deit_small_baseline.blocks.named_children():
  126. if isinstance(m.attn, WindowedAttention):
  127. pass
  128. elif isinstance(m.attn, Attention):
  129. m.attn.forward = partial(attn_forward, m.attn)
  130. attns.update({n: m.attn})
  131. else:
  132. assert False
  133. else:
  134. from utils import BuildModels
  135. model = BuildModels.build_deit_mmpretrain(with_ckpt=True, scale="small").cuda().eval()
  136. from mmpretrain.models.utils.attention import MultiheadAttention
  137. from mmpretrain.models.utils.attention import scaled_dot_product_attention_pyimpl
  138. def mattn_forward(self: MultiheadAttention, x):
  139. B, N, _ = x.shape
  140. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
  141. self.head_dims).permute(2, 0, 3, 1, 4)
  142. q, k, v = qkv[0], qkv[1], qkv[2]
  143. attn_drop = self.attn_drop if self.training else 0.
  144. scale = q.size(-1)**0.5
  145. attn_weight = q @ k.transpose(-2, -1) / scale
  146. setattr(self, "__data__", attn_weight)
  147. x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
  148. x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
  149. x = self.proj(x)
  150. x = self.out_drop(self.gamma1(self.proj_drop(x)))
  151. if self.v_shortcut:
  152. x = v.squeeze(1) + x
  153. return x
  154. for n, l in model.backbone.layers.named_children():
  155. l.attn.forward = partial(mattn_forward, l.attn)
  156. attns.update({n: l.attn})
  157. deit_small_baseline = model
  158. print(attns.keys())
  159. for idx, posx, posy in idxs_posxs_posys:
  160. img, label = dataset[idx]
  161. with torch.no_grad():
  162. deit_small_baseline(img[None].cuda())
  163. os.makedirs(f"{showpath}/{idx}_{posx}_{posy}", exist_ok=True)
  164. deimg = img.cpu() * torch.tensor([0.25, 0.25, 0.25]).view(-1, 1, 1) + torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1)
  165. deimg = deimg.permute(1, 2, 0).cpu()
  166. visualize.draw_image_grid(
  167. Image.fromarray((deimg * 255).to(torch.uint8).numpy()),
  168. [(posx * img_size, posy * img_size, img_size / featHW, img_size / featHW,)]
  169. ).save(f"{showpath}/{idx}_{posx}_{posy}/imori.jpg")
  170. for m0 in ["attn"]:
  171. for m1 in ["attn"]:
  172. aaa = getattr(attns['8'], "__data__")[0]
  173. aaa = ((aaa - aaa.min()) / (aaa.max() - aaa.min())).mean(dim=0)
  174. # attention map
  175. visualize.visualize_attnmap(aaa, f"{showpath}/{idx}_{posx}_{posy}/attn_{m0}_norm_{m1}.jpg", colorbar=False, sticks=False)
  176. # activation map
  177. if aaa.shape[0] == featHW * featHW + 1:
  178. aaa = aaa[1:, 1:]
  179. mask = aaa[int(posy * featHW) * int(featHW) + int(posx * featHW)].view(featHW, featHW)
  180. visualize.visualize_attnmap(mask, f"{showpath}/{idx}_{posx}_{posy}/activation_{m0}_norm_{m1}.jpg", colorbar=False, sticks=False)
  181. if __name__ == "__main__":
  182. this_path = os.path.dirname(os.path.abspath(__file__))
  183. # main_deit(det_model=True, showpath="show/deitdet")
  184. # main_deit(det_model=False, showpath="show/deitcls")
  185. main_vssm(det_model=False, showpath="show/vssmcls")
  186. # main_vssm(det_model=True, showpath="show/vssmdet")