image_encoder.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from typing import Optional, Tuple, Type
  9. from .common import LayerNorm2d, MLPBlock
  10. # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
  11. class ImageEncoderViT(nn.Module):
  12. def __init__(
  13. self,
  14. img_size: int = 1024,
  15. patch_size: int = 16,
  16. in_chans: int = 3,
  17. embed_dim: int = 768,
  18. depth: int = 12,
  19. num_heads: int = 12,
  20. mlp_ratio: float = 4.0,
  21. out_chans: int = 256,
  22. qkv_bias: bool = True,
  23. norm_layer: Type[nn.Module] = nn.LayerNorm,
  24. act_layer: Type[nn.Module] = nn.GELU,
  25. use_abs_pos: bool = True,
  26. use_rel_pos: bool = False,
  27. rel_pos_zero_init: bool = True,
  28. window_size: int = 0,
  29. global_attn_indexes: Tuple[int, ...] = (),
  30. ) -> None:
  31. """
  32. Args:
  33. img_size (int): Input image size.
  34. patch_size (int): Patch size.
  35. in_chans (int): Number of input image channels.
  36. embed_dim (int): Patch embedding dimension.
  37. depth (int): Depth of ViT.
  38. num_heads (int): Number of attention heads in each ViT block.
  39. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  40. qkv_bias (bool): If True, add a learnable bias to query, key, value.
  41. norm_layer (nn.Module): Normalization layer.
  42. act_layer (nn.Module): Activation layer.
  43. use_abs_pos (bool): If True, use absolute positional embeddings.
  44. use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
  45. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
  46. window_size (int): Window size for window attention blocks.
  47. global_attn_indexes (list): Indexes for blocks using global attention.
  48. """
  49. super().__init__()
  50. self.img_size = img_size
  51. self.patch_embed = PatchEmbed(
  52. kernel_size=(patch_size, patch_size),
  53. stride=(patch_size, patch_size),
  54. in_chans=in_chans,
  55. embed_dim=embed_dim,
  56. )
  57. self.pos_embed: Optional[nn.Parameter] = None
  58. if use_abs_pos:
  59. # Initialize absolute positional embedding with pretrain image size.
  60. self.pos_embed = nn.Parameter(
  61. torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
  62. )
  63. self.blocks = nn.ModuleList()
  64. for i in range(depth):
  65. block = Block(
  66. dim=embed_dim,
  67. num_heads=num_heads,
  68. mlp_ratio=mlp_ratio,
  69. qkv_bias=qkv_bias,
  70. norm_layer=norm_layer,
  71. act_layer=act_layer,
  72. use_rel_pos=use_rel_pos,
  73. rel_pos_zero_init=rel_pos_zero_init,
  74. window_size=window_size if i not in global_attn_indexes else 0,
  75. input_size=(img_size // patch_size, img_size // patch_size),
  76. )
  77. self.blocks.append(block)
  78. self.neck = nn.Sequential(
  79. nn.Conv2d(
  80. embed_dim,
  81. out_chans,
  82. kernel_size=1,
  83. bias=False,
  84. ),
  85. LayerNorm2d(out_chans),
  86. nn.Conv2d(
  87. out_chans,
  88. out_chans,
  89. kernel_size=3,
  90. padding=1,
  91. bias=False,
  92. ),
  93. LayerNorm2d(out_chans),
  94. )
  95. def forward(self, x: torch.Tensor) -> torch.Tensor:
  96. x = self.patch_embed(x)
  97. if self.pos_embed is not None:
  98. x = x + self.pos_embed
  99. for blk in self.blocks:
  100. x = blk(x)
  101. x = self.neck(x.permute(0, 3, 1, 2))
  102. return x
  103. class Block(nn.Module):
  104. """Transformer blocks with support of window attention and residual propagation blocks"""
  105. def __init__(
  106. self,
  107. dim: int,
  108. num_heads: int,
  109. mlp_ratio: float = 4.0,
  110. qkv_bias: bool = True,
  111. norm_layer: Type[nn.Module] = nn.LayerNorm,
  112. act_layer: Type[nn.Module] = nn.GELU,
  113. use_rel_pos: bool = False,
  114. rel_pos_zero_init: bool = True,
  115. window_size: int = 0,
  116. input_size: Optional[Tuple[int, int]] = None,
  117. ) -> None:
  118. """
  119. Args:
  120. dim (int): Number of input channels.
  121. num_heads (int): Number of attention heads in each ViT block.
  122. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  123. qkv_bias (bool): If True, add a learnable bias to query, key, value.
  124. norm_layer (nn.Module): Normalization layer.
  125. act_layer (nn.Module): Activation layer.
  126. use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
  127. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
  128. window_size (int): Window size for window attention blocks. If it equals 0, then
  129. use global attention.
  130. input_size (tuple(int, int) or None): Input resolution for calculating the relative
  131. positional parameter size.
  132. """
  133. super().__init__()
  134. self.norm1 = norm_layer(dim)
  135. self.attn = Attention(
  136. dim,
  137. num_heads=num_heads,
  138. qkv_bias=qkv_bias,
  139. use_rel_pos=use_rel_pos,
  140. rel_pos_zero_init=rel_pos_zero_init,
  141. input_size=input_size if window_size == 0 else (window_size, window_size),
  142. )
  143. self.norm2 = norm_layer(dim)
  144. self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
  145. self.window_size = window_size
  146. def forward(self, x: torch.Tensor) -> torch.Tensor:
  147. shortcut = x
  148. x = self.norm1(x)
  149. # Window partition
  150. if self.window_size > 0:
  151. H, W = x.shape[1], x.shape[2]
  152. x, pad_hw = window_partition(x, self.window_size)
  153. x = self.attn(x)
  154. # Reverse window partition
  155. if self.window_size > 0:
  156. x = window_unpartition(x, self.window_size, pad_hw, (H, W))
  157. x = shortcut + x
  158. x = x + self.mlp(self.norm2(x))
  159. return x
  160. class Attention(nn.Module):
  161. """Multi-head Attention block with relative position embeddings."""
  162. def __init__(
  163. self,
  164. dim: int,
  165. num_heads: int = 8,
  166. qkv_bias: bool = True,
  167. use_rel_pos: bool = False,
  168. rel_pos_zero_init: bool = True,
  169. input_size: Optional[Tuple[int, int]] = None,
  170. ) -> None:
  171. """
  172. Args:
  173. dim (int): Number of input channels.
  174. num_heads (int): Number of attention heads.
  175. qkv_bias (bool): If True, add a learnable bias to query, key, value.
  176. rel_pos (bool): If True, add relative positional embeddings to the attention map.
  177. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
  178. input_size (tuple(int, int) or None): Input resolution for calculating the relative
  179. positional parameter size.
  180. """
  181. super().__init__()
  182. self.num_heads = num_heads
  183. head_dim = dim // num_heads
  184. self.scale = head_dim**-0.5
  185. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  186. self.proj = nn.Linear(dim, dim)
  187. self.use_rel_pos = use_rel_pos
  188. if self.use_rel_pos:
  189. assert (
  190. input_size is not None
  191. ), "Input size must be provided if using relative positional encoding."
  192. # initialize relative positional embeddings
  193. self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
  194. self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
  195. def forward(self, x: torch.Tensor) -> torch.Tensor:
  196. B, H, W, _ = x.shape
  197. # qkv with shape (3, B, nHead, H * W, C)
  198. qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  199. # q, k, v with shape (B * nHead, H * W, C)
  200. q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
  201. attn = (q * self.scale) @ k.transpose(-2, -1)
  202. if self.use_rel_pos:
  203. attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
  204. attn = attn.softmax(dim=-1)
  205. x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
  206. x = self.proj(x)
  207. return x
  208. def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
  209. """
  210. Partition into non-overlapping windows with padding if needed.
  211. Args:
  212. x (tensor): input tokens with [B, H, W, C].
  213. window_size (int): window size.
  214. Returns:
  215. windows: windows after partition with [B * num_windows, window_size, window_size, C].
  216. (Hp, Wp): padded height and width before partition
  217. """
  218. B, H, W, C = x.shape
  219. pad_h = (window_size - H % window_size) % window_size
  220. pad_w = (window_size - W % window_size) % window_size
  221. if pad_h > 0 or pad_w > 0:
  222. x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
  223. Hp, Wp = H + pad_h, W + pad_w
  224. x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
  225. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
  226. return windows, (Hp, Wp)
  227. def window_unpartition(
  228. windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
  229. ) -> torch.Tensor:
  230. """
  231. Window unpartition into original sequences and removing padding.
  232. Args:
  233. windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
  234. window_size (int): window size.
  235. pad_hw (Tuple): padded height and width (Hp, Wp).
  236. hw (Tuple): original height and width (H, W) before padding.
  237. Returns:
  238. x: unpartitioned sequences with [B, H, W, C].
  239. """
  240. Hp, Wp = pad_hw
  241. H, W = hw
  242. B = windows.shape[0] // (Hp * Wp // window_size // window_size)
  243. x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
  244. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
  245. if Hp > H or Wp > W:
  246. x = x[:, :H, :W, :].contiguous()
  247. return x
  248. def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
  249. """
  250. Get relative positional embeddings according to the relative positions of
  251. query and key sizes.
  252. Args:
  253. q_size (int): size of query q.
  254. k_size (int): size of key k.
  255. rel_pos (Tensor): relative position embeddings (L, C).
  256. Returns:
  257. Extracted positional embeddings according to relative positions.
  258. """
  259. max_rel_dist = int(2 * max(q_size, k_size) - 1)
  260. # Interpolate rel pos if needed.
  261. if rel_pos.shape[0] != max_rel_dist:
  262. # Interpolate rel pos.
  263. rel_pos_resized = F.interpolate(
  264. rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
  265. size=max_rel_dist,
  266. mode="linear",
  267. )
  268. rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
  269. else:
  270. rel_pos_resized = rel_pos
  271. # Scale the coords with short length if shapes for q and k are different.
  272. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
  273. k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
  274. relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
  275. return rel_pos_resized[relative_coords.long()]
  276. def add_decomposed_rel_pos(
  277. attn: torch.Tensor,
  278. q: torch.Tensor,
  279. rel_pos_h: torch.Tensor,
  280. rel_pos_w: torch.Tensor,
  281. q_size: Tuple[int, int],
  282. k_size: Tuple[int, int],
  283. ) -> torch.Tensor:
  284. """
  285. Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
  286. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
  287. Args:
  288. attn (Tensor): attention map.
  289. q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
  290. rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
  291. rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
  292. q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
  293. k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
  294. Returns:
  295. attn (Tensor): attention map with added relative positional embeddings.
  296. """
  297. q_h, q_w = q_size
  298. k_h, k_w = k_size
  299. Rh = get_rel_pos(q_h, k_h, rel_pos_h)
  300. Rw = get_rel_pos(q_w, k_w, rel_pos_w)
  301. B, _, dim = q.shape
  302. r_q = q.reshape(B, q_h, q_w, dim)
  303. rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
  304. rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
  305. attn = (
  306. attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
  307. ).view(B, q_h * q_w, k_h * k_w)
  308. return attn
  309. class PatchEmbed(nn.Module):
  310. """
  311. Image to Patch Embedding.
  312. """
  313. def __init__(
  314. self,
  315. kernel_size: Tuple[int, int] = (16, 16),
  316. stride: Tuple[int, int] = (16, 16),
  317. padding: Tuple[int, int] = (0, 0),
  318. in_chans: int = 3,
  319. embed_dim: int = 768,
  320. ) -> None:
  321. """
  322. Args:
  323. kernel_size (Tuple): kernel size of the projection layer.
  324. stride (Tuple): stride of the projection layer.
  325. padding (Tuple): padding size of the projection layer.
  326. in_chans (int): Number of input image channels.
  327. embed_dim (int): Patch embedding dimension.
  328. """
  329. super().__init__()
  330. self.proj = nn.Conv2d(
  331. in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
  332. )
  333. def forward(self, x: torch.Tensor) -> torch.Tensor:
  334. x = self.proj(x)
  335. # B C H W -> B H W C
  336. x = x.permute(0, 2, 3, 1)
  337. return x