layers_2d.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
"""
General-purpose base layers (2D).
"""
  4. from collections import OrderedDict
  5. import torch
  6. import torch.nn as nn
  7. from timm.layers.drop import DropPath
  8. from timm.layers.mlp import Mlp
  9. from timm.layers.squeeze_excite import SqueezeExcite
  10. from timm.layers.weight_init import trunc_normal_
  11. class Scale(nn.Module):
  12. def __init__(self, dims, init_scale=1.0):
  13. super().__init__()
  14. self.weight = nn.Parameter(torch.ones(*dims) * init_scale)
  15. def forward(self, x):
  16. return self.weight * x
  17. class Residual(nn.Module):
  18. def __init__(self, module, drop=0.0):
  19. super().__init__()
  20. self.module = module
  21. self.drop = drop
  22. def forward(self, x):
  23. if self.training and self.drop > 0.0:
  24. keep = torch.rand(x.size(0), 1, 1, 1, device=x.device)
  25. keep = keep.ge_(self.drop).div(1.0 - self.drop).detach()
  26. return x + self.module(x) * keep
  27. return x + self.module(x)
  28. class FFN2d(nn.Module):
  29. def __init__(self, embed_dim, hidden_dim):
  30. super().__init__()
  31. self.mlp = Mlp(
  32. in_features=embed_dim,
  33. hidden_features=hidden_dim,
  34. out_features=embed_dim,
  35. act_layer=nn.ReLU,
  36. use_conv=True,
  37. bias=False,
  38. )
  39. for m in self.mlp.modules():
  40. if isinstance(m, nn.BatchNorm2d) and m.num_features == embed_dim:
  41. nn.init.constant_(m.weight, 0.0)
  42. nn.init.constant_(m.bias, 0.0)
  43. def forward(self, x):
  44. return self.mlp(x)
  45. class BNLinear1d(nn.Sequential):
  46. def __init__(self, in_features, out_features, bias=True, std=0.02):
  47. bn = nn.BatchNorm1d(in_features)
  48. linear = nn.Linear(in_features, out_features, bias=bias)
  49. trunc_normal_(linear.weight, std=std)
  50. if bias:
  51. nn.init.constant_(linear.bias, 0)
  52. super().__init__(OrderedDict([("bn", bn), ("linear", linear)]))
  53. class Conv2dBN(nn.Sequential):
  54. def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0,
  55. dilation=1, groups=1, bn_weight_init=1.0):
  56. conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
  57. bn = nn.BatchNorm2d(out_channels)
  58. nn.init.constant_(bn.weight, bn_weight_init)
  59. nn.init.constant_(bn.bias, 0)
  60. super().__init__(OrderedDict([("conv", conv), ("bn", bn)]))
  61. class DWConv2dBNReLU(nn.Sequential):
  62. def __init__(self, in_channels, out_channels, kernel_size=3, bn_weight_init=1.0):
  63. super().__init__(OrderedDict([
  64. ("dwconv3x3",
  65. nn.Conv2d(in_channels, in_channels, kernel_size, 1, kernel_size // 2, groups=in_channels, bias=False)),
  66. ("bn1", nn.BatchNorm2d(in_channels)),
  67. ("relu", nn.ReLU(inplace=True)),
  68. ("dwconv1x1", nn.Conv2d(in_channels, out_channels, 1, 1, 0, groups=in_channels, bias=False)),
  69. ("bn2", nn.BatchNorm2d(out_channels)),
  70. ]))
  71. for bn_name in ["bn1", "bn2"]:
  72. bn = getattr(self, bn_name)
  73. nn.init.constant_(bn.weight, bn_weight_init)
  74. nn.init.constant_(bn.bias, 0)
  75. class PatchMerging2d(nn.Module):
  76. def __init__(self, dim, out_dim):
  77. super().__init__()
  78. hidden_dim = int(dim * 4)
  79. self.conv1 = Conv2dBN(dim, hidden_dim, 1, 1, 0)
  80. self.act = nn.ReLU(inplace=True)
  81. self.conv2 = Conv2dBN(hidden_dim, hidden_dim, 3, 2, 1, groups=hidden_dim)
  82. self.se = SqueezeExcite(hidden_dim, rd_ratio=0.25)
  83. self.conv3 = Conv2dBN(hidden_dim, out_dim, 1, 1, 0)
  84. def forward(self, x):
  85. return self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))