# swin_transformer_v2.py (31 KB)
# --------------------------------------------------------
# Swin Transformer V2
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.layers import DropPath, to_2tuple, trunc_normal_
  13. class Mlp(nn.Module):
  14. def __init__(
  15. self,
  16. in_features,
  17. hidden_features=None,
  18. out_features=None,
  19. act_layer=nn.GELU,
  20. drop=0.0,
  21. ):
  22. super().__init__()
  23. out_features = out_features or in_features
  24. hidden_features = hidden_features or in_features
  25. self.fc1 = nn.Linear(in_features, hidden_features)
  26. self.act = act_layer()
  27. self.fc2 = nn.Linear(hidden_features, out_features)
  28. self.drop = nn.Dropout(drop)
  29. def forward(self, x):
  30. x = self.fc1(x)
  31. x = self.act(x)
  32. x = self.drop(x)
  33. x = self.fc2(x)
  34. x = self.drop(x)
  35. return x
  36. def window_partition(x, window_size):
  37. """
  38. Args:
  39. x: (B, H, W, C)
  40. window_size (int): window size
  41. Returns:
  42. windows: (num_windows*B, window_size, window_size, C)
  43. """
  44. B, H, W, C = x.shape
  45. x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
  46. windows = (
  47. x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
  48. )
  49. return windows
  50. def window_reverse(windows, window_size, H, W):
  51. """
  52. Args:
  53. windows: (num_windows*B, window_size, window_size, C)
  54. window_size (int): Window size
  55. H (int): Height of image
  56. W (int): Width of image
  57. Returns:
  58. x: (B, H, W, C)
  59. """
  60. B = int(windows.shape[0] / (H * W / window_size / window_size))
  61. x = windows.view(
  62. B, H // window_size, W // window_size, window_size, window_size, -1
  63. )
  64. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
  65. return x
  66. class WindowAttention(nn.Module):
  67. r"""Window based multi-head self attention (W-MSA) module with relative position bias.
  68. It supports both of shifted and non-shifted window.
  69. Args:
  70. dim (int): Number of input channels.
  71. window_size (tuple[int]): The height and width of the window.
  72. num_heads (int): Number of attention heads.
  73. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  74. attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
  75. proj_drop (float, optional): Dropout ratio of output. Default: 0.0
  76. pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
  77. """
  78. def __init__(
  79. self,
  80. dim,
  81. window_size,
  82. num_heads,
  83. qkv_bias=True,
  84. attn_drop=0.0,
  85. proj_drop=0.0,
  86. pretrained_window_size=(0, 0),
  87. ):
  88. super().__init__()
  89. self.dim = dim
  90. self.window_size = window_size # Wh, Ww
  91. self.pretrained_window_size = pretrained_window_size
  92. self.num_heads = num_heads
  93. self.logit_scale = nn.Parameter(
  94. torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
  95. )
  96. # mlp to generate continuous relative position bias
  97. self.cpb_mlp = nn.Sequential(
  98. nn.Linear(2, 512, bias=True),
  99. nn.ReLU(inplace=True),
  100. nn.Linear(512, num_heads, bias=False),
  101. )
  102. # get relative_coords_table
  103. relative_coords_h = torch.arange(
  104. -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32
  105. )
  106. relative_coords_w = torch.arange(
  107. -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32
  108. )
  109. relative_coords_table = (
  110. torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
  111. .permute(1, 2, 0)
  112. .contiguous()
  113. .unsqueeze(0)
  114. ) # 1, 2*Wh-1, 2*Ww-1, 2
  115. if pretrained_window_size[0] > 0:
  116. relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
  117. relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
  118. else:
  119. relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
  120. relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
  121. relative_coords_table *= 8 # normalize to -8, 8
  122. relative_coords_table = (
  123. torch.sign(relative_coords_table)
  124. * torch.log2(torch.abs(relative_coords_table) + 1.0)
  125. / np.log2(8)
  126. )
  127. self.register_buffer("relative_coords_table", relative_coords_table)
  128. # get pair-wise relative position index for each token inside the window
  129. coords_h = torch.arange(self.window_size[0])
  130. coords_w = torch.arange(self.window_size[1])
  131. coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
  132. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
  133. relative_coords = (
  134. coords_flatten[:, :, None] - coords_flatten[:, None, :]
  135. ) # 2, Wh*Ww, Wh*Ww
  136. relative_coords = relative_coords.permute(
  137. 1, 2, 0
  138. ).contiguous() # Wh*Ww, Wh*Ww, 2
  139. relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
  140. relative_coords[:, :, 1] += self.window_size[1] - 1
  141. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  142. relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  143. self.register_buffer("relative_position_index", relative_position_index)
  144. self.qkv = nn.Linear(dim, dim * 3, bias=False)
  145. if qkv_bias:
  146. self.q_bias = nn.Parameter(torch.zeros(dim))
  147. self.v_bias = nn.Parameter(torch.zeros(dim))
  148. else:
  149. self.q_bias = None
  150. self.v_bias = None
  151. self.attn_drop = nn.Dropout(attn_drop)
  152. self.proj = nn.Linear(dim, dim)
  153. self.proj_drop = nn.Dropout(proj_drop)
  154. self.softmax = nn.Softmax(dim=-1)
  155. def forward(self, x, mask=None):
  156. """
  157. Args:
  158. x: input features with shape of (num_windows*B, N, C)
  159. mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
  160. """
  161. B_, N, C = x.shape
  162. qkv_bias = None
  163. if self.q_bias is not None:
  164. qkv_bias = torch.cat(
  165. (
  166. self.q_bias,
  167. torch.zeros_like(self.v_bias, requires_grad=False),
  168. self.v_bias,
  169. )
  170. )
  171. qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
  172. qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  173. q, k, v = (
  174. qkv[0],
  175. qkv[1],
  176. qkv[2],
  177. ) # make torchscript happy (cannot use tensor as tuple)
  178. # cosine attention
  179. attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
  180. logit_scale = torch.clamp(
  181. self.logit_scale,
  182. max=torch.log(torch.tensor(1.0 / 0.01, device=self.logit_scale.device)),
  183. ).exp()
  184. attn = attn * logit_scale
  185. relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(
  186. -1, self.num_heads
  187. )
  188. relative_position_bias = relative_position_bias_table[
  189. self.relative_position_index.view(-1)
  190. ].view(
  191. self.window_size[0] * self.window_size[1],
  192. self.window_size[0] * self.window_size[1],
  193. -1,
  194. ) # Wh*Ww,Wh*Ww,nH
  195. relative_position_bias = relative_position_bias.permute(
  196. 2, 0, 1
  197. ).contiguous() # nH, Wh*Ww, Wh*Ww
  198. relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
  199. attn = attn + relative_position_bias.unsqueeze(0)
  200. if mask is not None:
  201. nW = mask.shape[0]
  202. attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
  203. 1
  204. ).unsqueeze(0)
  205. attn = attn.view(-1, self.num_heads, N, N)
  206. attn = self.softmax(attn)
  207. else:
  208. attn = self.softmax(attn)
  209. attn = self.attn_drop(attn)
  210. x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
  211. x = self.proj(x)
  212. x = self.proj_drop(x)
  213. return x
  214. def extra_repr(self) -> str:
  215. return (
  216. f"dim={self.dim}, window_size={self.window_size}, "
  217. f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}"
  218. )
  219. def flops(self, N):
  220. # calculate flops for 1 window with token length of N
  221. flops = 0
  222. # qkv = self.qkv(x)
  223. flops += N * self.dim * 3 * self.dim
  224. # attn = (q @ k.transpose(-2, -1))
  225. flops += self.num_heads * N * (self.dim // self.num_heads) * N
  226. # x = (attn @ v)
  227. flops += self.num_heads * N * N * (self.dim // self.num_heads)
  228. # x = self.proj(x)
  229. flops += N * self.dim * self.dim
  230. return flops
  231. class SwinTransformerBlock(nn.Module):
  232. r"""Swin Transformer Block.
  233. Args:
  234. dim (int): Number of input channels.
  235. input_resolution (tuple[int]): Input resulotion.
  236. num_heads (int): Number of attention heads.
  237. window_size (int): Window size.
  238. shift_size (int): Shift size for SW-MSA.
  239. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  240. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  241. drop (float, optional): Dropout rate. Default: 0.0
  242. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  243. drop_path (float, optional): Stochastic depth rate. Default: 0.0
  244. act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
  245. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  246. pretrained_window_size (int): Window size in pre-training.
  247. """
  248. def __init__(
  249. self,
  250. dim,
  251. input_resolution,
  252. num_heads,
  253. window_size=7,
  254. shift_size=0,
  255. mlp_ratio=4.0,
  256. qkv_bias=True,
  257. drop=0.0,
  258. attn_drop=0.0,
  259. drop_path=0.0,
  260. act_layer=nn.GELU,
  261. norm_layer=nn.LayerNorm,
  262. pretrained_window_size=0,
  263. ):
  264. super().__init__()
  265. self.dim = dim
  266. self.input_resolution = input_resolution
  267. self.num_heads = num_heads
  268. self.window_size = window_size
  269. self.shift_size = shift_size
  270. self.mlp_ratio = mlp_ratio
  271. if min(self.input_resolution) <= self.window_size:
  272. # if window size is larger than input resolution, we don't partition windows
  273. self.shift_size = 0
  274. self.window_size = min(self.input_resolution)
  275. assert (
  276. 0 <= self.shift_size < self.window_size
  277. ), "shift_size must in 0-window_size"
  278. self.norm1 = norm_layer(dim)
  279. self.attn = WindowAttention(
  280. dim,
  281. window_size=to_2tuple(self.window_size),
  282. num_heads=num_heads,
  283. qkv_bias=qkv_bias,
  284. attn_drop=attn_drop,
  285. proj_drop=drop,
  286. pretrained_window_size=to_2tuple(pretrained_window_size),
  287. )
  288. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  289. self.norm2 = norm_layer(dim)
  290. mlp_hidden_dim = int(dim * mlp_ratio)
  291. self.mlp = Mlp(
  292. in_features=dim,
  293. hidden_features=mlp_hidden_dim,
  294. act_layer=act_layer,
  295. drop=drop,
  296. )
  297. if self.shift_size > 0:
  298. # calculate attention mask for SW-MSA
  299. H, W = self.input_resolution
  300. img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
  301. h_slices = (
  302. slice(0, -self.window_size),
  303. slice(-self.window_size, -self.shift_size),
  304. slice(-self.shift_size, None),
  305. )
  306. w_slices = (
  307. slice(0, -self.window_size),
  308. slice(-self.window_size, -self.shift_size),
  309. slice(-self.shift_size, None),
  310. )
  311. cnt = 0
  312. for h in h_slices:
  313. for w in w_slices:
  314. img_mask[:, h, w, :] = cnt
  315. cnt += 1
  316. mask_windows = window_partition(
  317. img_mask, self.window_size
  318. ) # nW, window_size, window_size, 1
  319. mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
  320. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  321. attn_mask = attn_mask.masked_fill(
  322. attn_mask != 0, float(-100.0)
  323. ).masked_fill(attn_mask == 0, float(0.0))
  324. else:
  325. attn_mask = None
  326. self.register_buffer("attn_mask", attn_mask)
  327. def forward(self, x):
  328. H, W = self.input_resolution
  329. B, L, C = x.shape
  330. assert L == H * W, "input feature has wrong size"
  331. shortcut = x
  332. x = x.view(B, H, W, C)
  333. # cyclic shift
  334. if self.shift_size > 0:
  335. shifted_x = torch.roll(
  336. x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
  337. )
  338. else:
  339. shifted_x = x
  340. # partition windows
  341. x_windows = window_partition(
  342. shifted_x, self.window_size
  343. ) # nW*B, window_size, window_size, C
  344. x_windows = x_windows.view(
  345. -1, self.window_size * self.window_size, C
  346. ) # nW*B, window_size*window_size, C
  347. # W-MSA/SW-MSA
  348. attn_windows = self.attn(
  349. x_windows, mask=self.attn_mask
  350. ) # nW*B, window_size*window_size, C
  351. # merge windows
  352. attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
  353. shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
  354. # reverse cyclic shift
  355. if self.shift_size > 0:
  356. x = torch.roll(
  357. shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
  358. )
  359. else:
  360. x = shifted_x
  361. x = x.view(B, H * W, C)
  362. x = shortcut + self.drop_path(self.norm1(x))
  363. # FFN
  364. x = x + self.drop_path(self.norm2(self.mlp(x)))
  365. return x
  366. def extra_repr(self) -> str:
  367. return (
  368. f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
  369. f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
  370. )
  371. def flops(self):
  372. flops = 0
  373. H, W = self.input_resolution
  374. # norm1
  375. flops += self.dim * H * W
  376. # W-MSA/SW-MSA
  377. nW = H * W / self.window_size / self.window_size
  378. flops += nW * self.attn.flops(self.window_size * self.window_size)
  379. # mlp
  380. flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
  381. # norm2
  382. flops += self.dim * H * W
  383. return flops
  384. class PatchMerging(nn.Module):
  385. r"""Patch Merging Layer.
  386. Args:
  387. input_resolution (tuple[int]): Resolution of input feature.
  388. dim (int): Number of input channels.
  389. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  390. """
  391. def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
  392. super().__init__()
  393. self.input_resolution = input_resolution
  394. self.dim = dim
  395. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  396. self.norm = norm_layer(2 * dim)
  397. def forward(self, x):
  398. """
  399. x: B, H*W, C
  400. """
  401. H, W = self.input_resolution
  402. B, L, C = x.shape
  403. assert L == H * W, "input feature has wrong size"
  404. assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
  405. x = x.view(B, H, W, C)
  406. x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
  407. x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
  408. x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
  409. x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
  410. x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
  411. x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
  412. x = self.reduction(x)
  413. x = self.norm(x)
  414. return x
  415. def extra_repr(self) -> str:
  416. return f"input_resolution={self.input_resolution}, dim={self.dim}"
  417. def flops(self):
  418. H, W = self.input_resolution
  419. flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
  420. flops += H * W * self.dim // 2
  421. return flops
  422. class BasicLayer(nn.Module):
  423. """A basic Swin Transformer layer for one stage.
  424. Args:
  425. dim (int): Number of input channels.
  426. input_resolution (tuple[int]): Input resolution.
  427. depth (int): Number of blocks.
  428. num_heads (int): Number of attention heads.
  429. window_size (int): Local window size.
  430. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  431. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  432. drop (float, optional): Dropout rate. Default: 0.0
  433. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  434. drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
  435. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  436. downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
  437. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
  438. pretrained_window_size (int): Local window size in pre-training.
  439. """
  440. def __init__(
  441. self,
  442. dim,
  443. input_resolution,
  444. depth,
  445. num_heads,
  446. window_size,
  447. mlp_ratio=4.0,
  448. qkv_bias=True,
  449. drop=0.0,
  450. attn_drop=0.0,
  451. drop_path=0.0,
  452. norm_layer=nn.LayerNorm,
  453. downsample=None,
  454. use_checkpoint=False,
  455. pretrained_window_size=0,
  456. ):
  457. super().__init__()
  458. self.dim = dim
  459. self.input_resolution = input_resolution
  460. self.depth = depth
  461. self.use_checkpoint = use_checkpoint
  462. # build blocks
  463. self.blocks = nn.ModuleList(
  464. [
  465. SwinTransformerBlock(
  466. dim=dim,
  467. input_resolution=input_resolution,
  468. num_heads=num_heads,
  469. window_size=window_size,
  470. shift_size=0 if (i % 2 == 0) else window_size // 2,
  471. mlp_ratio=mlp_ratio,
  472. qkv_bias=qkv_bias,
  473. drop=drop,
  474. attn_drop=attn_drop,
  475. drop_path=(
  476. drop_path[i] if isinstance(drop_path, list) else drop_path
  477. ),
  478. norm_layer=norm_layer,
  479. pretrained_window_size=pretrained_window_size,
  480. )
  481. for i in range(depth)
  482. ]
  483. )
  484. # patch merging layer
  485. if downsample is not None:
  486. self.downsample = downsample(
  487. input_resolution, dim=dim, norm_layer=norm_layer
  488. )
  489. else:
  490. self.downsample = None
  491. def forward(self, x):
  492. for blk in self.blocks:
  493. if self.use_checkpoint:
  494. x = checkpoint.checkpoint(blk, x)
  495. else:
  496. x = blk(x)
  497. if self.downsample is not None:
  498. x = self.downsample(x)
  499. return x
  500. def extra_repr(self) -> str:
  501. return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
  502. def flops(self):
  503. flops = 0
  504. for blk in self.blocks:
  505. flops += blk.flops()
  506. if self.downsample is not None:
  507. flops += self.downsample.flops()
  508. return flops
  509. def _init_respostnorm(self):
  510. for blk in self.blocks:
  511. nn.init.constant_(blk.norm1.bias, 0)
  512. nn.init.constant_(blk.norm1.weight, 0)
  513. nn.init.constant_(blk.norm2.bias, 0)
  514. nn.init.constant_(blk.norm2.weight, 0)
  515. class PatchEmbed(nn.Module):
  516. r"""Image to Patch Embedding
  517. Args:
  518. img_size (int): Image size. Default: 224.
  519. patch_size (int): Patch token size. Default: 4.
  520. in_chans (int): Number of input image channels. Default: 3.
  521. embed_dim (int): Number of linear projection output channels. Default: 96.
  522. norm_layer (nn.Module, optional): Normalization layer. Default: None
  523. """
  524. def __init__(
  525. self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
  526. ):
  527. super().__init__()
  528. img_size = to_2tuple(img_size)
  529. patch_size = to_2tuple(patch_size)
  530. patches_resolution = [
  531. img_size[0] // patch_size[0],
  532. img_size[1] // patch_size[1],
  533. ]
  534. self.img_size = img_size
  535. self.patch_size = patch_size
  536. self.patches_resolution = patches_resolution
  537. self.num_patches = patches_resolution[0] * patches_resolution[1]
  538. self.in_chans = in_chans
  539. self.embed_dim = embed_dim
  540. self.proj = nn.Conv2d(
  541. in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
  542. )
  543. if norm_layer is not None:
  544. self.norm = norm_layer(embed_dim)
  545. else:
  546. self.norm = None
  547. def forward(self, x):
  548. B, C, H, W = x.shape
  549. # FIXME look at relaxing size constraints
  550. assert (
  551. H == self.img_size[0] and W == self.img_size[1]
  552. ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
  553. x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
  554. if self.norm is not None:
  555. x = self.norm(x)
  556. return x
  557. def flops(self):
  558. Ho, Wo = self.patches_resolution
  559. flops = (
  560. Ho
  561. * Wo
  562. * self.embed_dim
  563. * self.in_chans
  564. * (self.patch_size[0] * self.patch_size[1])
  565. )
  566. if self.norm is not None:
  567. flops += Ho * Wo * self.embed_dim
  568. return flops
  569. class SwinTransformerV2(nn.Module):
  570. r"""Swin Transformer
  571. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
  572. https://arxiv.org/pdf/2103.14030
  573. Args:
  574. img_size (int | tuple(int)): Input image size. Default 224
  575. patch_size (int | tuple(int)): Patch size. Default: 4
  576. in_chans (int): Number of input image channels. Default: 3
  577. num_classes (int): Number of classes for classification head. Default: 1000
  578. embed_dim (int): Patch embedding dimension. Default: 96
  579. depths (tuple(int)): Depth of each Swin Transformer layer.
  580. num_heads (tuple(int)): Number of attention heads in different layers.
  581. window_size (int): Window size. Default: 7
  582. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
  583. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
  584. drop_rate (float): Dropout rate. Default: 0
  585. attn_drop_rate (float): Attention dropout rate. Default: 0
  586. drop_path_rate (float): Stochastic depth rate. Default: 0.1
  587. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
  588. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
  589. patch_norm (bool): If True, add normalization after patch embedding. Default: True
  590. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
  591. pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
  592. """
  593. def __init__(
  594. self,
  595. img_size=224,
  596. patch_size=4,
  597. in_chans=3,
  598. num_classes=1000,
  599. embed_dim=96,
  600. depths=(2, 2, 6, 2),
  601. num_heads=(3, 6, 12, 24),
  602. window_size=7,
  603. mlp_ratio=4.0,
  604. qkv_bias=True,
  605. drop_rate=0.0,
  606. attn_drop_rate=0.0,
  607. drop_path_rate=0.1,
  608. norm_layer=nn.LayerNorm,
  609. ape=False,
  610. patch_norm=True,
  611. use_checkpoint=False,
  612. pretrained_window_sizes=(0, 0, 0, 0),
  613. **kwargs,
  614. ):
  615. super().__init__()
  616. self.num_classes = num_classes
  617. self.num_layers = len(depths)
  618. self.embed_dim = embed_dim
  619. self.ape = ape
  620. self.patch_norm = patch_norm
  621. self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
  622. self.mlp_ratio = mlp_ratio
  623. # split image into non-overlapping patches
  624. self.patch_embed = PatchEmbed(
  625. img_size=img_size,
  626. patch_size=patch_size,
  627. in_chans=in_chans,
  628. embed_dim=embed_dim,
  629. norm_layer=norm_layer if self.patch_norm else None,
  630. )
  631. num_patches = self.patch_embed.num_patches
  632. patches_resolution = self.patch_embed.patches_resolution
  633. self.patches_resolution = patches_resolution
  634. # absolute position embedding
  635. if self.ape:
  636. # noinspection PyTypeChecker
  637. self.absolute_pos_embed = nn.Parameter(
  638. torch.zeros(1, num_patches, embed_dim)
  639. )
  640. trunc_normal_(self.absolute_pos_embed, std=0.02)
  641. self.pos_drop = nn.Dropout(p=drop_rate)
  642. # stochastic depth
  643. dpr = [
  644. x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
  645. ] # stochastic depth decay rule
  646. # build layers
  647. self.layers = nn.ModuleList()
  648. for i_layer in range(self.num_layers):
  649. layer = BasicLayer(
  650. dim=int(embed_dim * 2**i_layer),
  651. input_resolution=(
  652. patches_resolution[0] // (2**i_layer),
  653. patches_resolution[1] // (2**i_layer),
  654. ),
  655. depth=depths[i_layer],
  656. num_heads=num_heads[i_layer],
  657. window_size=window_size,
  658. mlp_ratio=self.mlp_ratio,
  659. qkv_bias=qkv_bias,
  660. drop=drop_rate,
  661. attn_drop=attn_drop_rate,
  662. drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
  663. norm_layer=norm_layer,
  664. downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
  665. use_checkpoint=use_checkpoint,
  666. pretrained_window_size=pretrained_window_sizes[i_layer],
  667. )
  668. self.layers.append(layer)
  669. self.norm = norm_layer(self.num_features)
  670. self.avgpool = nn.AdaptiveAvgPool1d(1)
  671. self.head = (
  672. nn.Linear(self.num_features, num_classes)
  673. if num_classes > 0
  674. else nn.Identity()
  675. )
  676. self.apply(self._init_weights)
  677. for bly in self.layers:
  678. # noinspection PyProtectedMember
  679. bly._init_respostnorm()
  680. @staticmethod
  681. def _tokens_to_feature_map(x, resolution):
  682. """Convert flattened tokens ``B, L, C`` to feature maps ``B, C, H, W``."""
  683. B, L, C = x.shape
  684. H, W = resolution
  685. assert L == H * W, f"token length {L} does not match resolution {H}x{W}"
  686. return x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
  687. @staticmethod
  688. def proj_out(x, normalize=False):
  689. if normalize:
  690. x = (
  691. F.layer_norm(x.permute(0, 2, 3, 1), [x.shape[1]])
  692. .permute(0, 3, 1, 2)
  693. .contiguous()
  694. )
  695. return x
  696. def forward_multiscale_features(self, x, normalize=True, include_patch_embed=True):
  697. """Return MONAI-style multiscale feature maps.
  698. Args:
  699. x: input tensor with shape ``B, C, H, W``.
  700. normalize: if True, apply per-location layer norm after reshaping to ``B, C, H, W``.
  701. include_patch_embed: if True, include the patch embedding output as the first feature.
  702. Returns:
  703. A list of feature maps. When ``include_patch_embed`` is True, the outputs follow
  704. ``[patch_embed, layer1, layer2, layer3, layer4]`` and match MONAI's Swin encoder
  705. convention. Otherwise returns the 4 standard stage outputs.
  706. """
  707. x = self.patch_embed(x)
  708. if self.ape:
  709. x = x + self.absolute_pos_embed
  710. x = self.pos_drop(x)
  711. features = []
  712. resolution = tuple(self.patches_resolution)
  713. if include_patch_embed:
  714. features.append(
  715. self.proj_out(self._tokens_to_feature_map(x, resolution), normalize)
  716. )
  717. for layer in self.layers:
  718. x = layer(x)
  719. if layer.downsample is not None:
  720. resolution = (resolution[0] // 2, resolution[1] // 2)
  721. features.append(
  722. self.proj_out(self._tokens_to_feature_map(x, resolution), normalize)
  723. )
  724. return features
  725. def forward_stage_features(self, x, normalize=True):
  726. """Return standard stage outputs before the next stage downsamples them.
  727. This matches the common timm/mmseg backbone convention and returns 4 feature maps
  728. with strides ``4, 8, 16, 32`` relative to the input image.
  729. """
  730. x = self.patch_embed(x)
  731. if self.ape:
  732. x = x + self.absolute_pos_embed
  733. x = self.pos_drop(x)
  734. features = []
  735. resolution = tuple(self.patches_resolution)
  736. for layer in self.layers:
  737. for blk in layer.blocks:
  738. if layer.use_checkpoint:
  739. x = checkpoint.checkpoint(blk, x)
  740. else:
  741. x = blk(x)
  742. features.append(
  743. self.proj_out(self._tokens_to_feature_map(x, resolution), normalize)
  744. )
  745. if layer.downsample is not None:
  746. x = layer.downsample(x)
  747. resolution = (resolution[0] // 2, resolution[1] // 2)
  748. return features
  749. @staticmethod
  750. def _init_weights(m):
  751. if isinstance(m, nn.Linear):
  752. trunc_normal_(m.weight, std=0.02)
  753. if isinstance(m, nn.Linear) and m.bias is not None:
  754. nn.init.constant_(m.bias, 0)
  755. elif isinstance(m, nn.LayerNorm):
  756. nn.init.constant_(m.bias, 0)
  757. nn.init.constant_(m.weight, 1.0)
  758. @torch.jit.ignore
  759. def no_weight_decay(self):
  760. return {"absolute_pos_embed"}
  761. @torch.jit.ignore
  762. def no_weight_decay_keywords(self):
  763. return {"cpb_mlp", "logit_scale", "relative_position_bias_table"}
  764. def forward_features(self, x):
  765. x = self.patch_embed(x)
  766. if self.ape:
  767. x = x + self.absolute_pos_embed
  768. x = self.pos_drop(x)
  769. for layer in self.layers:
  770. x = layer(x)
  771. x = self.norm(x) # B L C
  772. x = self.avgpool(x.transpose(1, 2)) # B C 1
  773. x = torch.flatten(x, 1)
  774. return x
  775. def forward(self, x):
  776. x = self.forward_features(x)
  777. x = self.head(x)
  778. return x
  779. def flops(self):
  780. flops = 0
  781. flops += self.patch_embed.flops()
  782. for i, layer in enumerate(self.layers):
  783. flops += layer.flops()
  784. flops += (
  785. self.num_features
  786. * self.patches_resolution[0]
  787. * self.patches_resolution[1]
  788. // (2**self.num_layers)
  789. )
  790. flops += self.num_features * self.num_classes
  791. return flops