swin_unetr.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125
  1. # Copyright (c) MONAI Consortium
  2. # Licensed under the Apache License, Version 2.0 (the "License");
  3. # you may not use this file except in compliance with the License.
  4. # You may obtain a copy of the License at
  5. # http://www.apache.org/licenses/LICENSE-2.0
  6. # Unless required by applicable law or agreed to in writing, software
  7. # distributed under the License is distributed on an "AS IS" BASIS,
  8. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # See the License for the specific language governing permissions and
  10. # limitations under the License.
  11. from __future__ import annotations
  12. import itertools
  13. from collections.abc import Sequence
  14. import numpy as np
  15. import torch
  16. import torch.nn as nn
  17. import torch.nn.functional as F
  18. import torch.utils.checkpoint as checkpoint
  19. from torch.nn import LayerNorm
  20. from monai.networks.blocks import MLPBlock as Mlp
  21. from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
  22. from monai.networks.layers import DropPath, trunc_normal_
  23. from monai.utils import ensure_tuple_rep, look_up_option, optional_import
  24. rearrange, _ = optional_import("einops", name="rearrange")
  25. __all__ = [
  26. "SwinUNETR",
  27. "window_partition",
  28. "window_reverse",
  29. "WindowAttention",
  30. "SwinTransformerBlock",
  31. "PatchMerging",
  32. "PatchMergingV2",
  33. "MERGING_MODE",
  34. "BasicLayer",
  35. "SwinTransformer",
  36. ]
  37. class SwinUNETR(nn.Module):
  38. """
  39. Swin UNETR based on: "Hatamizadeh et al.,
  40. Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
  41. <https://arxiv.org/abs/2201.01266>"
  42. """
  43. def __init__(
  44. self,
  45. in_channels: int,
  46. out_channels: int,
  47. patch_size: int = 2,
  48. depths: Sequence[int] = (2, 2, 2, 2),
  49. num_heads: Sequence[int] = (3, 6, 12, 24),
  50. window_size: Sequence[int] | int = 7,
  51. qkv_bias: bool = True,
  52. mlp_ratio: float = 4.0,
  53. feature_size: int = 24,
  54. norm_name: tuple | str = "instance",
  55. drop_rate: float = 0.0,
  56. attn_drop_rate: float = 0.0,
  57. dropout_path_rate: float = 0.0,
  58. normalize: bool = True,
  59. norm_layer: type[LayerNorm] = nn.LayerNorm,
  60. patch_norm: bool = False,
  61. use_checkpoint: bool = False,
  62. spatial_dims: int = 3,
  63. downsample: str | nn.Module = "merging",
  64. use_v2: bool = False,
  65. ) -> None:
  66. """
  67. Args:
  68. in_channels: dimension of input channels.
  69. out_channels: dimension of output channels.
  70. patch_size: size of the patch token.
  71. feature_size: dimension of network feature size.
  72. depths: number of layers in each stage.
  73. num_heads: number of attention heads.
  74. window_size: local window size.
  75. qkv_bias: add a learnable bias to query, key, value.
  76. mlp_ratio: ratio of mlp hidden dim to embedding dim.
  77. norm_name: feature normalization type and arguments.
  78. drop_rate: dropout rate.
  79. attn_drop_rate: attention dropout rate.
  80. dropout_path_rate: drop path rate.
  81. normalize: normalize output intermediate features in each stage.
  82. norm_layer: normalization layer.
  83. patch_norm: whether to apply normalization to the patch embedding. Default is False.
  84. use_checkpoint: use gradient checkpointing for reduced memory usage.
  85. spatial_dims: number of spatial dims.
  86. downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
  87. user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
  88. The default is currently `"merging"` (the original version defined in v0.9.0).
  89. use_v2: using swinunetr_v2, which adds a residual convolution block at the beggining of each swin stage.
  90. Examples::
  91. # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
  92. >>> net = SwinUNETR(in_channels=1, out_channels=4, feature_size=48)
  93. # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
  94. >>> net = SwinUNETR(in_channels=4, out_channels=3, depths=(2,4,2,2))
  95. # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
  96. >>> net = SwinUNETR(in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
  97. """
  98. super().__init__()
  99. if spatial_dims not in (2, 3):
  100. raise ValueError("spatial dimension should be 2 or 3.")
  101. self.patch_size = patch_size
  102. patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
  103. window_size = ensure_tuple_rep(window_size, spatial_dims)
  104. if not (0 <= drop_rate <= 1):
  105. raise ValueError("dropout rate should be between 0 and 1.")
  106. if not (0 <= attn_drop_rate <= 1):
  107. raise ValueError("attention dropout rate should be between 0 and 1.")
  108. if not (0 <= dropout_path_rate <= 1):
  109. raise ValueError("drop path rate should be between 0 and 1.")
  110. if feature_size % 12 != 0:
  111. raise ValueError("feature_size should be divisible by 12.")
  112. self.normalize = normalize
  113. self.swinViT = SwinTransformer(
  114. in_chans=in_channels,
  115. embed_dim=feature_size,
  116. window_size=window_size,
  117. patch_size=patch_sizes,
  118. depths=depths,
  119. num_heads=num_heads,
  120. mlp_ratio=mlp_ratio,
  121. qkv_bias=qkv_bias,
  122. drop_rate=drop_rate,
  123. attn_drop_rate=attn_drop_rate,
  124. drop_path_rate=dropout_path_rate,
  125. norm_layer=norm_layer,
  126. patch_norm=patch_norm,
  127. use_checkpoint=use_checkpoint,
  128. spatial_dims=spatial_dims,
  129. downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
  130. use_v2=use_v2,
  131. )
  132. self.encoder1 = UnetrBasicBlock(
  133. spatial_dims=spatial_dims,
  134. in_channels=in_channels,
  135. out_channels=feature_size,
  136. kernel_size=3,
  137. stride=1,
  138. norm_name=norm_name,
  139. res_block=True,
  140. )
  141. self.encoder2 = UnetrBasicBlock(
  142. spatial_dims=spatial_dims,
  143. in_channels=feature_size,
  144. out_channels=feature_size,
  145. kernel_size=3,
  146. stride=1,
  147. norm_name=norm_name,
  148. res_block=True,
  149. )
  150. self.encoder3 = UnetrBasicBlock(
  151. spatial_dims=spatial_dims,
  152. in_channels=2 * feature_size,
  153. out_channels=2 * feature_size,
  154. kernel_size=3,
  155. stride=1,
  156. norm_name=norm_name,
  157. res_block=True,
  158. )
  159. self.encoder4 = UnetrBasicBlock(
  160. spatial_dims=spatial_dims,
  161. in_channels=4 * feature_size,
  162. out_channels=4 * feature_size,
  163. kernel_size=3,
  164. stride=1,
  165. norm_name=norm_name,
  166. res_block=True,
  167. )
  168. self.encoder10 = UnetrBasicBlock(
  169. spatial_dims=spatial_dims,
  170. in_channels=16 * feature_size,
  171. out_channels=16 * feature_size,
  172. kernel_size=3,
  173. stride=1,
  174. norm_name=norm_name,
  175. res_block=True,
  176. )
  177. self.decoder5 = UnetrUpBlock(
  178. spatial_dims=spatial_dims,
  179. in_channels=16 * feature_size,
  180. out_channels=8 * feature_size,
  181. kernel_size=3,
  182. upsample_kernel_size=2,
  183. norm_name=norm_name,
  184. res_block=True,
  185. )
  186. self.decoder4 = UnetrUpBlock(
  187. spatial_dims=spatial_dims,
  188. in_channels=feature_size * 8,
  189. out_channels=feature_size * 4,
  190. kernel_size=3,
  191. upsample_kernel_size=2,
  192. norm_name=norm_name,
  193. res_block=True,
  194. )
  195. self.decoder3 = UnetrUpBlock(
  196. spatial_dims=spatial_dims,
  197. in_channels=feature_size * 4,
  198. out_channels=feature_size * 2,
  199. kernel_size=3,
  200. upsample_kernel_size=2,
  201. norm_name=norm_name,
  202. res_block=True,
  203. )
  204. self.decoder2 = UnetrUpBlock(
  205. spatial_dims=spatial_dims,
  206. in_channels=feature_size * 2,
  207. out_channels=feature_size,
  208. kernel_size=3,
  209. upsample_kernel_size=2,
  210. norm_name=norm_name,
  211. res_block=True,
  212. )
  213. self.decoder1 = UnetrUpBlock(
  214. spatial_dims=spatial_dims,
  215. in_channels=feature_size,
  216. out_channels=feature_size,
  217. kernel_size=3,
  218. upsample_kernel_size=2,
  219. norm_name=norm_name,
  220. res_block=True,
  221. )
  222. self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)
  223. def load_from(self, weights):
  224. layers1_0: BasicLayer = self.swinViT.layers1[0] # type: ignore[assignment]
  225. layers2_0: BasicLayer = self.swinViT.layers2[0] # type: ignore[assignment]
  226. layers3_0: BasicLayer = self.swinViT.layers3[0] # type: ignore[assignment]
  227. layers4_0: BasicLayer = self.swinViT.layers4[0] # type: ignore[assignment]
  228. wstate = weights["state_dict"]
  229. with torch.no_grad():
  230. self.swinViT.patch_embed.proj.weight.copy_(wstate["module.patch_embed.proj.weight"])
  231. self.swinViT.patch_embed.proj.bias.copy_(wstate["module.patch_embed.proj.bias"])
  232. for bname, block in layers1_0.blocks.named_children():
  233. block.load_from(weights, n_block=bname, layer="layers1") # type: ignore[operator]
  234. if layers1_0.downsample is not None:
  235. d = layers1_0.downsample
  236. d.reduction.weight.copy_(wstate["module.layers1.0.downsample.reduction.weight"]) # type: ignore
  237. d.norm.weight.copy_(wstate["module.layers1.0.downsample.norm.weight"]) # type: ignore
  238. d.norm.bias.copy_(wstate["module.layers1.0.downsample.norm.bias"]) # type: ignore
  239. for bname, block in layers2_0.blocks.named_children():
  240. block.load_from(weights, n_block=bname, layer="layers2") # type: ignore[operator]
  241. if layers2_0.downsample is not None:
  242. d = layers2_0.downsample
  243. d.reduction.weight.copy_(wstate["module.layers2.0.downsample.reduction.weight"]) # type: ignore
  244. d.norm.weight.copy_(wstate["module.layers2.0.downsample.norm.weight"]) # type: ignore
  245. d.norm.bias.copy_(wstate["module.layers2.0.downsample.norm.bias"]) # type: ignore
  246. for bname, block in layers3_0.blocks.named_children():
  247. block.load_from(weights, n_block=bname, layer="layers3") # type: ignore[operator]
  248. if layers3_0.downsample is not None:
  249. d = layers3_0.downsample
  250. d.reduction.weight.copy_(wstate["module.layers3.0.downsample.reduction.weight"]) # type: ignore
  251. d.norm.weight.copy_(wstate["module.layers3.0.downsample.norm.weight"]) # type: ignore
  252. d.norm.bias.copy_(wstate["module.layers3.0.downsample.norm.bias"]) # type: ignore
  253. for bname, block in layers4_0.blocks.named_children():
  254. block.load_from(weights, n_block=bname, layer="layers4") # type: ignore[operator]
  255. if layers4_0.downsample is not None:
  256. d = layers4_0.downsample
  257. d.reduction.weight.copy_(wstate["module.layers4.0.downsample.reduction.weight"]) # type: ignore
  258. d.norm.weight.copy_(wstate["module.layers4.0.downsample.norm.weight"]) # type: ignore
  259. d.norm.bias.copy_(wstate["module.layers4.0.downsample.norm.bias"]) # type: ignore
  260. @torch.jit.unused
  261. def _check_input_size(self, spatial_shape):
  262. img_size = np.array(spatial_shape)
  263. remainder = (img_size % np.power(self.patch_size, 5)) > 0
  264. if remainder.any():
  265. wrong_dims = (np.where(remainder)[0] + 2).tolist()
  266. raise ValueError(
  267. f"spatial dimensions {wrong_dims} of input image (spatial shape: {spatial_shape})"
  268. f" must be divisible by {self.patch_size}**5."
  269. )
  270. def forward(self, x_in):
  271. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  272. self._check_input_size(x_in.shape[2:])
  273. hidden_states_out = self.swinViT(x_in, self.normalize)
  274. enc0 = self.encoder1(x_in)
  275. enc1 = self.encoder2(hidden_states_out[0])
  276. enc2 = self.encoder3(hidden_states_out[1])
  277. enc3 = self.encoder4(hidden_states_out[2])
  278. dec4 = self.encoder10(hidden_states_out[4])
  279. dec3 = self.decoder5(dec4, hidden_states_out[3])
  280. dec2 = self.decoder4(dec3, enc3)
  281. dec1 = self.decoder3(dec2, enc2)
  282. dec0 = self.decoder2(dec1, enc1)
  283. out = self.decoder1(dec0, enc0)
  284. logits = self.out(out)
  285. return logits
  286. def window_partition(x, window_size):
  287. """window partition operation based on: "Liu et al.,
  288. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
  289. <https://arxiv.org/abs/2103.14030>"
  290. https://github.com/microsoft/Swin-Transformer
  291. Args:
  292. x: input tensor.
  293. window_size: local window size.
  294. """
  295. x_shape = x.size() # length 4 or 5 only
  296. if len(x_shape) == 5:
  297. b, d, h, w, c = x_shape
  298. x = x.view(
  299. b,
  300. d // window_size[0],
  301. window_size[0],
  302. h // window_size[1],
  303. window_size[1],
  304. w // window_size[2],
  305. window_size[2],
  306. c,
  307. )
  308. windows = (
  309. x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c)
  310. )
  311. else: # if len(x_shape) == 4:
  312. b, h, w, c = x.shape
  313. x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)
  314. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c)
  315. return windows
  316. def window_reverse(windows, window_size, dims):
  317. """window reverse operation based on: "Liu et al.,
  318. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
  319. <https://arxiv.org/abs/2103.14030>"
  320. https://github.com/microsoft/Swin-Transformer
  321. Args:
  322. windows: windows tensor.
  323. window_size: local window size.
  324. dims: dimension values.
  325. """
  326. if len(dims) == 4:
  327. b, d, h, w = dims
  328. x = windows.view(
  329. b,
  330. d // window_size[0],
  331. h // window_size[1],
  332. w // window_size[2],
  333. window_size[0],
  334. window_size[1],
  335. window_size[2],
  336. -1,
  337. )
  338. x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1)
  339. elif len(dims) == 3:
  340. b, h, w = dims
  341. x = windows.view(b, h // window_size[0], w // window_size[1], window_size[0], window_size[1], -1)
  342. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
  343. return x
  344. def get_window_size(x_size, window_size, shift_size=None):
  345. """Computing window size based on: "Liu et al.,
  346. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
  347. <https://arxiv.org/abs/2103.14030>"
  348. https://github.com/microsoft/Swin-Transformer
  349. Args:
  350. x_size: input size.
  351. window_size: local window size.
  352. shift_size: window shifting size.
  353. """
  354. use_window_size = list(window_size)
  355. if shift_size is not None:
  356. use_shift_size = list(shift_size)
  357. for i in range(len(x_size)):
  358. if x_size[i] <= window_size[i]:
  359. use_window_size[i] = x_size[i]
  360. if shift_size is not None:
  361. use_shift_size[i] = 0
  362. if shift_size is None:
  363. return tuple(use_window_size)
  364. else:
  365. return tuple(use_window_size), tuple(use_shift_size)
  366. class WindowAttention(nn.Module):
  367. """
  368. Window based multi-head self attention module with relative position bias based on: "Liu et al.,
  369. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
  370. <https://arxiv.org/abs/2103.14030>"
  371. https://github.com/microsoft/Swin-Transformer
  372. """
  373. def __init__(
  374. self,
  375. dim: int,
  376. num_heads: int,
  377. window_size: Sequence[int],
  378. qkv_bias: bool = False,
  379. attn_drop: float = 0.0,
  380. proj_drop: float = 0.0,
  381. ) -> None:
  382. """
  383. Args:
  384. dim: number of feature channels.
  385. num_heads: number of attention heads.
  386. window_size: local window size.
  387. qkv_bias: add a learnable bias to query, key, value.
  388. attn_drop: attention dropout rate.
  389. proj_drop: dropout rate of output.
  390. """
  391. super().__init__()
  392. self.dim = dim
  393. self.window_size = window_size
  394. self.num_heads = num_heads
  395. head_dim = dim // num_heads
  396. self.scale = head_dim ** -0.5
  397. mesh_args = torch.meshgrid.__kwdefaults__
  398. if len(self.window_size) == 3:
  399. self.relative_position_bias_table = nn.Parameter(
  400. torch.zeros(
  401. (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
  402. num_heads,
  403. )
  404. )
  405. coords_d = torch.arange(self.window_size[0])
  406. coords_h = torch.arange(self.window_size[1])
  407. coords_w = torch.arange(self.window_size[2])
  408. if mesh_args is not None:
  409. coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij"))
  410. else:
  411. coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))
  412. coords_flatten = torch.flatten(coords, 1)
  413. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
  414. relative_coords = relative_coords.permute(1, 2, 0).contiguous()
  415. relative_coords[:, :, 0] += self.window_size[0] - 1
  416. relative_coords[:, :, 1] += self.window_size[1] - 1
  417. relative_coords[:, :, 2] += self.window_size[2] - 1
  418. relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
  419. relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
  420. elif len(self.window_size) == 2:
  421. self.relative_position_bias_table = nn.Parameter(
  422. torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
  423. )
  424. coords_h = torch.arange(self.window_size[0])
  425. coords_w = torch.arange(self.window_size[1])
  426. if mesh_args is not None:
  427. coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))
  428. else:
  429. coords = torch.stack(torch.meshgrid(coords_h, coords_w))
  430. coords_flatten = torch.flatten(coords, 1)
  431. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
  432. relative_coords = relative_coords.permute(1, 2, 0).contiguous()
  433. relative_coords[:, :, 0] += self.window_size[0] - 1
  434. relative_coords[:, :, 1] += self.window_size[1] - 1
  435. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  436. relative_position_index = relative_coords.sum(-1)
  437. self.register_buffer("relative_position_index", relative_position_index)
  438. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  439. self.attn_drop = nn.Dropout(attn_drop)
  440. self.proj = nn.Linear(dim, dim)
  441. self.proj_drop = nn.Dropout(proj_drop)
  442. trunc_normal_(self.relative_position_bias_table, std=0.02)
  443. self.softmax = nn.Softmax(dim=-1)
  444. def forward(self, x, mask):
  445. b, n, c = x.shape
  446. qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
  447. q, k, v = qkv[0], qkv[1], qkv[2]
  448. q = q * self.scale
  449. attn = q @ k.transpose(-2, -1)
  450. relative_position_bias = self.relative_position_bias_table[
  451. self.relative_position_index.clone()[:n, :n].reshape(-1) # type: ignore[operator]
  452. ].reshape(n, n, -1)
  453. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
  454. attn = attn + relative_position_bias.unsqueeze(0)
  455. if mask is not None:
  456. nw = mask.shape[0]
  457. attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
  458. attn = attn.view(-1, self.num_heads, n, n)
  459. attn = self.softmax(attn)
  460. else:
  461. attn = self.softmax(attn)
  462. attn = self.attn_drop(attn).to(v.dtype)
  463. x = (attn @ v).transpose(1, 2).reshape(b, n, c)
  464. x = self.proj(x)
  465. x = self.proj_drop(x)
  466. return x
  467. class SwinTransformerBlock(nn.Module):
  468. """
  469. Swin Transformer block based on: "Liu et al.,
  470. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
  471. <https://arxiv.org/abs/2103.14030>"
  472. https://github.com/microsoft/Swin-Transformer
  473. """
  474. def __init__(
  475. self,
  476. dim: int,
  477. num_heads: int,
  478. window_size: Sequence[int],
  479. shift_size: Sequence[int],
  480. mlp_ratio: float = 4.0,
  481. qkv_bias: bool = True,
  482. drop: float = 0.0,
  483. attn_drop: float = 0.0,
  484. drop_path: float = 0.0,
  485. act_layer: str = "GELU",
  486. norm_layer: type[LayerNorm] = nn.LayerNorm,
  487. use_checkpoint: bool = False,
  488. ) -> None:
  489. """
  490. Args:
  491. dim: number of feature channels.
  492. num_heads: number of attention heads.
  493. window_size: local window size.
  494. shift_size: window shift size.
  495. mlp_ratio: ratio of mlp hidden dim to embedding dim.
  496. qkv_bias: add a learnable bias to query, key, value.
  497. drop: dropout rate.
  498. attn_drop: attention dropout rate.
  499. drop_path: stochastic depth rate.
  500. act_layer: activation layer.
  501. norm_layer: normalization layer.
  502. use_checkpoint: use gradient checkpointing for reduced memory usage.
  503. """
  504. super().__init__()
  505. self.dim = dim
  506. self.num_heads = num_heads
  507. self.window_size = window_size
  508. self.shift_size = shift_size
  509. self.mlp_ratio = mlp_ratio
  510. self.use_checkpoint = use_checkpoint
  511. self.norm1 = norm_layer(dim)
  512. self.attn = WindowAttention(
  513. dim,
  514. window_size=self.window_size,
  515. num_heads=num_heads,
  516. qkv_bias=qkv_bias,
  517. attn_drop=attn_drop,
  518. proj_drop=drop,
  519. )
  520. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  521. self.norm2 = norm_layer(dim)
  522. mlp_hidden_dim = int(dim * mlp_ratio)
  523. self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin")
  524. def forward_part1(self, x, mask_matrix):
  525. x_shape = x.size()
  526. x = self.norm1(x)
  527. if len(x_shape) == 5:
  528. b, d, h, w, c = x.shape
  529. window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
  530. pad_l = pad_t = pad_d0 = 0
  531. pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0]
  532. pad_b = (window_size[1] - h % window_size[1]) % window_size[1]
  533. pad_r = (window_size[2] - w % window_size[2]) % window_size[2]
  534. x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
  535. _, dp, hp, wp, _ = x.shape
  536. dims = [b, dp, hp, wp]
  537. else: # elif len(x_shape) == 4
  538. b, h, w, c = x.shape
  539. window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
  540. pad_l = pad_t = 0
  541. pad_b = (window_size[0] - h % window_size[0]) % window_size[0]
  542. pad_r = (window_size[1] - w % window_size[1]) % window_size[1]
  543. x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
  544. _, hp, wp, _ = x.shape
  545. dims = [b, hp, wp]
  546. if any(i > 0 for i in shift_size):
  547. if len(x_shape) == 5:
  548. shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
  549. elif len(x_shape) == 4:
  550. shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
  551. attn_mask = mask_matrix
  552. else:
  553. shifted_x = x
  554. attn_mask = None
  555. x_windows = window_partition(shifted_x, window_size)
  556. attn_windows = self.attn(x_windows, mask=attn_mask)
  557. attn_windows = attn_windows.view(-1, *(window_size + (c,)))
  558. shifted_x = window_reverse(attn_windows, window_size, dims)
  559. if any(i > 0 for i in shift_size):
  560. if len(x_shape) == 5:
  561. x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
  562. elif len(x_shape) == 4:
  563. x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
  564. else:
  565. x = shifted_x
  566. if len(x_shape) == 5:
  567. if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
  568. x = x[:, :d, :h, :w, :].contiguous()
  569. elif len(x_shape) == 4:
  570. if pad_r > 0 or pad_b > 0:
  571. x = x[:, :h, :w, :].contiguous()
  572. return x
  573. def forward_part2(self, x):
  574. return self.drop_path(self.mlp(self.norm2(x)))
  575. def load_from(self, weights, n_block, layer):
  576. root = f"module.{layer}.0.blocks.{n_block}."
  577. block_names = [
  578. "norm1.weight",
  579. "norm1.bias",
  580. "attn.relative_position_bias_table",
  581. "attn.relative_position_index",
  582. "attn.qkv.weight",
  583. "attn.qkv.bias",
  584. "attn.proj.weight",
  585. "attn.proj.bias",
  586. "norm2.weight",
  587. "norm2.bias",
  588. "mlp.fc1.weight",
  589. "mlp.fc1.bias",
  590. "mlp.fc2.weight",
  591. "mlp.fc2.bias",
  592. ]
  593. with torch.no_grad():
  594. self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]])
  595. self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]])
  596. self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]])
  597. self.attn.relative_position_index.copy_(
  598. weights["state_dict"][root + block_names[3]]) # type: ignore[operator]
  599. self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]])
  600. self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]])
  601. self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]])
  602. self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]])
  603. self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]])
  604. self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]])
  605. self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]])
  606. self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]])
  607. self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]])
  608. self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]])
  609. def forward(self, x, mask_matrix):
  610. shortcut = x
  611. if self.use_checkpoint:
  612. x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix, use_reentrant=False)
  613. else:
  614. x = self.forward_part1(x, mask_matrix)
  615. x = shortcut + self.drop_path(x)
  616. if self.use_checkpoint:
  617. x = x + checkpoint.checkpoint(self.forward_part2, x, use_reentrant=False)
  618. else:
  619. x = x + self.forward_part2(x)
  620. return x
  621. class PatchMergingV2(nn.Module):
  622. """
  623. Patch merging layer based on: "Liu et al.,
  624. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
  625. <https://arxiv.org/abs/2103.14030>"
  626. https://github.com/microsoft/Swin-Transformer
  627. """
  628. def __init__(self, dim: int, norm_layer: type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3) -> None:
  629. """
  630. Args:
  631. dim: number of feature channels.
  632. norm_layer: normalization layer.
  633. spatial_dims: number of spatial dims.
  634. """
  635. super().__init__()
  636. self.dim = dim
  637. if spatial_dims == 3:
  638. self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)
  639. self.norm = norm_layer(8 * dim)
  640. elif spatial_dims == 2:
  641. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  642. self.norm = norm_layer(4 * dim)
  643. def forward(self, x):
  644. x_shape = x.size()
  645. if len(x_shape) == 5:
  646. b, d, h, w, c = x_shape
  647. pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
  648. if pad_input:
  649. x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2))
  650. x = torch.cat(
  651. [x[:, i::2, j::2, k::2, :] for i, j, k in itertools.product(range(2), range(2), range(2))], -1
  652. )
  653. elif len(x_shape) == 4:
  654. b, h, w, c = x_shape
  655. pad_input = (h % 2 == 1) or (w % 2 == 1)
  656. if pad_input:
  657. x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))
  658. x = torch.cat([x[:, j::2, i::2, :] for i, j in itertools.product(range(2), range(2))], -1)
  659. x = self.norm(x)
  660. x = self.reduction(x)
  661. return x
  662. class PatchMerging(PatchMergingV2):
  663. """The `PatchMerging` module previously defined in v0.9.0."""
  664. def forward(self, x):
  665. x_shape = x.size()
  666. if len(x_shape) == 4:
  667. return super().forward(x)
  668. if len(x_shape) != 5:
  669. raise ValueError(f"expecting 5D x, got {x.shape}.")
  670. b, d, h, w, c = x_shape
  671. pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1)
  672. if pad_input:
  673. x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2))
  674. x0 = x[:, 0::2, 0::2, 0::2, :]
  675. x1 = x[:, 1::2, 0::2, 0::2, :]
  676. x2 = x[:, 0::2, 1::2, 0::2, :]
  677. x3 = x[:, 0::2, 0::2, 1::2, :]
  678. x4 = x[:, 1::2, 1::2, 0::2, :]
  679. x5 = x[:, 1::2, 0::2, 1::2, :]
  680. x6 = x[:, 0::2, 1::2, 1::2, :]
  681. x7 = x[:, 1::2, 1::2, 1::2, :]
  682. x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1)
  683. x = self.norm(x)
  684. x = self.reduction(x)
  685. return x
  686. MERGING_MODE = {"merging": PatchMerging, "mergingv2": PatchMergingV2}
  687. def compute_mask(dims, window_size, shift_size, device):
  688. """Computing region masks based on: "Liu et al.,
  689. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
  690. <https://arxiv.org/abs/2103.14030>"
  691. https://github.com/microsoft/Swin-Transformer
  692. Args:
  693. dims: dimension values.
  694. window_size: local window size.
  695. shift_size: shift size.
  696. device: device.
  697. """
  698. cnt = 0
  699. if len(dims) == 3:
  700. d, h, w = dims
  701. img_mask = torch.zeros((1, d, h, w, 1), device=device)
  702. for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
  703. for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
  704. for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None):
  705. img_mask[:, d, h, w, :] = cnt
  706. cnt += 1
  707. elif len(dims) == 2:
  708. h, w = dims
  709. img_mask = torch.zeros((1, h, w, 1), device=device)
  710. for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None):
  711. for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None):
  712. img_mask[:, h, w, :] = cnt
  713. cnt += 1
  714. mask_windows = window_partition(img_mask, window_size)
  715. mask_windows = mask_windows.squeeze(-1)
  716. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  717. attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
  718. return attn_mask
  719. class BasicLayer(nn.Module):
  720. """
  721. Basic Swin Transformer layer in one stage based on: "Liu et al.,
  722. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
  723. <https://arxiv.org/abs/2103.14030>"
  724. https://github.com/microsoft/Swin-Transformer
  725. """
  726. def __init__(
  727. self,
  728. dim: int,
  729. depth: int,
  730. num_heads: int,
  731. window_size: Sequence[int],
  732. drop_path: list,
  733. mlp_ratio: float = 4.0,
  734. qkv_bias: bool = False,
  735. drop: float = 0.0,
  736. attn_drop: float = 0.0,
  737. norm_layer: type[LayerNorm] = nn.LayerNorm,
  738. downsample: nn.Module | None = None,
  739. use_checkpoint: bool = False,
  740. ) -> None:
  741. """
  742. Args:
  743. dim: number of feature channels.
  744. depth: number of layers in each stage.
  745. num_heads: number of attention heads.
  746. window_size: local window size.
  747. drop_path: stochastic depth rate.
  748. mlp_ratio: ratio of mlp hidden dim to embedding dim.
  749. qkv_bias: add a learnable bias to query, key, value.
  750. drop: dropout rate.
  751. attn_drop: attention dropout rate.
  752. norm_layer: normalization layer.
  753. downsample: an optional downsampling layer at the end of the layer.
  754. use_checkpoint: use gradient checkpointing for reduced memory usage.
  755. """
  756. super().__init__()
  757. self.window_size = window_size
  758. self.shift_size = tuple(i // 2 for i in window_size)
  759. self.no_shift = tuple(0 for i in window_size)
  760. self.depth = depth
  761. self.use_checkpoint = use_checkpoint
  762. self.blocks = nn.ModuleList(
  763. [
  764. SwinTransformerBlock(
  765. dim=dim,
  766. num_heads=num_heads,
  767. window_size=self.window_size,
  768. shift_size=self.no_shift if (i % 2 == 0) else self.shift_size,
  769. mlp_ratio=mlp_ratio,
  770. qkv_bias=qkv_bias,
  771. drop=drop,
  772. attn_drop=attn_drop,
  773. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  774. norm_layer=norm_layer,
  775. use_checkpoint=use_checkpoint,
  776. )
  777. for i in range(depth)
  778. ]
  779. )
  780. self.downsample = downsample
  781. if callable(self.downsample):
  782. self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))
  783. def forward(self, x):
  784. x_shape = x.size()
  785. if len(x_shape) == 5:
  786. b, c, d, h, w = x_shape
  787. window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size)
  788. x = rearrange(x, "b c d h w -> b d h w c")
  789. dp = int(np.ceil(d / window_size[0])) * window_size[0]
  790. hp = int(np.ceil(h / window_size[1])) * window_size[1]
  791. wp = int(np.ceil(w / window_size[2])) * window_size[2]
  792. attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device)
  793. for blk in self.blocks:
  794. x = blk(x, attn_mask)
  795. x = x.view(b, d, h, w, -1)
  796. if self.downsample is not None:
  797. x = self.downsample(x)
  798. x = rearrange(x, "b d h w c -> b c d h w")
  799. elif len(x_shape) == 4:
  800. b, c, h, w = x_shape
  801. window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
  802. x = rearrange(x, "b c h w -> b h w c")
  803. hp = int(np.ceil(h / window_size[0])) * window_size[0]
  804. wp = int(np.ceil(w / window_size[1])) * window_size[1]
  805. attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device)
  806. for blk in self.blocks:
  807. x = blk(x, attn_mask)
  808. x = x.view(b, h, w, -1)
  809. if self.downsample is not None:
  810. x = self.downsample(x)
  811. x = rearrange(x, "b h w c -> b c h w")
  812. return x
class SwinTransformer(nn.Module):
    """
    Swin Transformer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Encoder backbone: a patch embedding followed by the Swin stages ``layers1`` ... ``layers4``.
    ``forward`` returns the embedded input together with the output of each of the four stages.
    """

    def __init__(
        self,
        in_chans: int,
        embed_dim: int,
        window_size: Sequence[int],
        patch_size: Sequence[int],
        depths: Sequence[int],
        num_heads: Sequence[int],
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: type[LayerNorm] = nn.LayerNorm,
        patch_norm: bool = False,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
        downsample="merging",
        use_v2: bool = False,
    ) -> None:
        """
        Args:
            in_chans: number of input channels.
            embed_dim: number of linear projection output channels; stage ``i`` works on ``embed_dim * 2**i`` channels.
            window_size: local window size.
            patch_size: patch size.
            depths: number of blocks in each stage. Only the first four stages are stored and used by ``forward``.
            num_heads: number of attention heads in each stage.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            drop_path_rate: maximum stochastic depth rate, reached by the last block of the last stage.
            norm_layer: normalization layer.
            patch_norm: add normalization after patch embedding.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: spatial dimension.
            downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
                The default is currently `"merging"` (the original version defined in v0.9.0).
            use_v2: using swinunetr_v2, which adds a residual convolution block at the beginning of each swin stage.
        """
        # NOTE(review): submodules are created in a fixed order; changing it alters parameter registration
        # order and the RNG draws used for initialization, so keep statement order as is.
        super().__init__()
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.window_size = window_size
        self.patch_size = patch_size
        self.patch_embed = PatchEmbed(
            patch_size=self.patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None,  # type: ignore
            spatial_dims=spatial_dims,
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        # stochastic depth rates rise linearly from 0 to drop_path_rate across all blocks of all stages
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.use_v2 = use_v2
        # one single-element ModuleList per stage; the attribute names become part of the state-dict keys,
        # so renaming them would presumably break loading existing checkpoints
        self.layers1 = nn.ModuleList()
        self.layers2 = nn.ModuleList()
        self.layers3 = nn.ModuleList()
        self.layers4 = nn.ModuleList()
        if self.use_v2:
            # residual conv blocks run before each stage in the v2 variant
            self.layers1c = nn.ModuleList()
            self.layers2c = nn.ModuleList()
            self.layers3c = nn.ModuleList()
            self.layers4c = nn.ModuleList()
        # a string picks a built-in merging class from MERGING_MODE; anything else is used as given
        down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample
        for i_layer in range(self.num_layers):
            # each stage receives its own contiguous slice of the per-block drop-path rates
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=self.window_size,
                drop_path=dpr[sum(depths[:i_layer]): sum(depths[: i_layer + 1])],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                norm_layer=norm_layer,
                downsample=down_sample_mod,
                use_checkpoint=use_checkpoint,
            )
            # NOTE(review): only stages 0-3 are stored; for len(depths) > 4 the extra stages are built and then
            # discarded, and forward always indexes layers1..layers4 (so fewer than 4 stages fails in forward).
            if i_layer == 0:
                self.layers1.append(layer)
            elif i_layer == 1:
                self.layers2.append(layer)
            elif i_layer == 2:
                self.layers3.append(layer)
            elif i_layer == 3:
                self.layers4.append(layer)
            if self.use_v2:
                # channel-preserving residual conv block matching this stage's width
                layerc = UnetrBasicBlock(
                    spatial_dims=spatial_dims,
                    in_channels=embed_dim * 2 ** i_layer,
                    out_channels=embed_dim * 2 ** i_layer,
                    kernel_size=3,
                    stride=1,
                    norm_name="instance",
                    res_block=True,
                )
                if i_layer == 0:
                    self.layers1c.append(layerc)
                elif i_layer == 1:
                    self.layers2c.append(layerc)
                elif i_layer == 2:
                    self.layers3c.append(layerc)
                elif i_layer == 3:
                    self.layers4c.append(layerc)
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

    def proj_out(self, x, normalize=False):
        """Optionally layer-normalize a channel-first tensor over its channel axis (dim 1).

        The normalization has no learnable affine parameters. ``x`` is returned unchanged when ``normalize``
        is False or when ``x`` is neither 4D nor 5D.
        """
        if normalize:
            x_shape = x.shape
            # Force trace() to generate a constant by casting to int
            ch = int(x_shape[1])
            if len(x_shape) == 5:
                # F.layer_norm normalizes trailing dims, so move channels last and back
                x = rearrange(x, "n c d h w -> n d h w c")
                x = F.layer_norm(x, [ch])
                x = rearrange(x, "n d h w c -> n c d h w")
            elif len(x_shape) == 4:
                x = rearrange(x, "n c h w -> n h w c")
                x = F.layer_norm(x, [ch])
                x = rearrange(x, "n h w c -> n c h w")
        return x

    def forward(self, x, normalize=True):
        """Run the patch embedding and the four stages.

        Returns:
            list ``[x0_out, x1_out, x2_out, x3_out, x4_out]``: the embedded input followed by each stage output,
            each passed through ``proj_out(..., normalize)``. With ``use_v2``, the residual conv block of a
            stage is applied after that stage's input has been recorded, so the returned features do not
            include it.
        """
        x0 = self.patch_embed(x)
        x0 = self.pos_drop(x0)
        x0_out = self.proj_out(x0, normalize)
        # .contiguous() presumably guards against non-contiguous views from earlier reshapes — confirm before removing
        if self.use_v2:
            x0 = self.layers1c[0](x0.contiguous())
        x1 = self.layers1[0](x0.contiguous())
        x1_out = self.proj_out(x1, normalize)
        if self.use_v2:
            x1 = self.layers2c[0](x1.contiguous())
        x2 = self.layers2[0](x1.contiguous())
        x2_out = self.proj_out(x2, normalize)
        if self.use_v2:
            x2 = self.layers3c[0](x2.contiguous())
        x3 = self.layers3[0](x2.contiguous())
        x3_out = self.proj_out(x3, normalize)
        if self.use_v2:
            x3 = self.layers4c[0](x3.contiguous())
        x4 = self.layers4[0](x3.contiguous())
        x4_out = self.proj_out(x4, normalize)
        return [x0_out, x1_out, x2_out, x3_out, x4_out]
  965. def filter_swinunetr(key, value):
  966. """
  967. A filter function used to filter the pretrained weights from [1], then the weights can be loaded into MONAI SwinUNETR Model.
  968. This function is typically used with `monai.networks.copy_model_state`
  969. [1] "Valanarasu JM et al., Disruptive Autoencoders: Leveraging Low-level features for 3D Medical Image Pre-training
  970. <https://arxiv.org/abs/2307.16896>"
  971. Args:
  972. key: the key in the source state dict used for the update.
  973. value: the value in the source state dict used for the update.
  974. Examples::
  975. import torch
  976. from monai.apps import download_url
  977. from monai.networks.utils import copy_model_state
  978. from monai.networks.nets.swin_unetr import SwinUNETR, filter_swinunetr
  979. model = SwinUNETR(in_channels=1, out_channels=3, feature_size=48)
  980. resource = (
  981. "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth"
  982. )
  983. ssl_weights_path = "./ssl_pretrained_weights.pth"
  984. download_url(resource, ssl_weights_path)
  985. ssl_weights = torch.load(ssl_weights_path, weights_only=True)["model"]
  986. dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr)
  987. """
  988. if key in [
  989. "encoder.mask_token",
  990. "encoder.norm.weight",
  991. "encoder.norm.bias",
  992. "out.conv.conv.weight",
  993. "out.conv.conv.bias",
  994. ]:
  995. return None
  996. if key[:8] == "encoder.":
  997. if key[8:19] == "patch_embed":
  998. new_key = "swinViT." + key[8:]
  999. else:
  1000. new_key = "swinViT." + key[8:18] + key[20:]
  1001. return new_key, value
  1002. else:
  1003. return None