| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- """
- 通用基础层(2D)。
- """
- from collections import OrderedDict
- import torch
- import torch.nn as nn
- from timm.layers.drop import DropPath
- from timm.layers.mlp import Mlp
- from timm.layers.squeeze_excite import SqueezeExcite
- from timm.layers.weight_init import trunc_normal_
class Scale(nn.Module):
    """Multiply the input by a learnable weight tensor.

    Args:
        dims: Iterable of sizes; it is unpacked into ``torch.ones`` to give
            the weight its shape.
        init_scale: Value every element of the weight starts at.
    """

    def __init__(self, dims, init_scale=1.0):
        super().__init__()
        initial_weight = torch.ones(*dims) * init_scale
        self.weight = nn.Parameter(initial_weight)

    def forward(self, x):
        """Return ``weight * x``; broadcasting is left to the ``*`` operator."""
        scaled = self.weight * x
        return scaled
class Residual(nn.Module):
    """Add a wrapped module's output to its input, optionally dropping the branch.

    In training mode with ``drop > 0`` the branch output of each sample
    (index along dim 0) is zeroed with probability ``drop`` and the surviving
    samples are rescaled by ``1 / (1 - drop)``. In every other case the result
    is simply ``x + module(x)``.

    Args:
        module: The residual branch; called as ``module(x)``.
        drop: Per-sample probability of zeroing the branch during training.
    """

    def __init__(self, module, drop=0.0):
        super().__init__()
        self.module = module
        self.drop = drop

    def forward(self, x):
        """Return ``x`` plus the (possibly dropped and rescaled) branch output."""
        if not (self.training and self.drop > 0.0):
            return x + self.module(x)

        # One mask value per sample, with a singleton for every remaining
        # dimension, so the mask broadcasts correctly for any input rank
        # (a fixed (B, 1, 1, 1) mask mis-broadcasts on non-4D inputs).
        mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
        # Draw the mask before running the branch so the order of random
        # number consumption is unchanged.
        keep = torch.rand(mask_shape, device=x.device).ge_(self.drop)

        keep_prob = 1.0 - self.drop
        if keep_prob > 0.0:
            # Rescale survivors; skipped when nothing can survive, which
            # would otherwise compute 0 / 0 and yield NaN.
            keep = keep.div_(keep_prob)

        branch = self.module(x)
        # Match the branch dtype so a float32 mask does not promote a
        # half-precision branch to float32.
        return x + branch * keep.to(branch.dtype)
class FFN2d(nn.Module):
    """Feed-forward block that wraps timm's ``Mlp`` in convolutional mode.

    The ``Mlp`` is configured with ``use_conv=True``, ``nn.ReLU`` activation
    and ``bias=False``, mapping ``embed_dim -> hidden_dim -> embed_dim``.

    Args:
        embed_dim: Passed as both ``in_features`` and ``out_features``.
        hidden_dim: Passed as ``hidden_features``.
    """

    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        mlp_config = dict(
            in_features=embed_dim,
            hidden_features=hidden_dim,
            out_features=embed_dim,
            act_layer=nn.ReLU,
            use_conv=True,
            bias=False,
        )
        self.mlp = Mlp(**mlp_config)

        # Zero the affine parameters of every BatchNorm2d inside the Mlp
        # whose channel count equals embed_dim.
        # NOTE(review): whether the Mlp built above contains any BatchNorm2d
        # depends on timm's defaults (no norm_layer is passed here) - confirm
        # this loop actually matches a layer.
        for layer in self.mlp.modules():
            if not isinstance(layer, nn.BatchNorm2d):
                continue
            if layer.num_features != embed_dim:
                continue
            for param in (layer.weight, layer.bias):
                nn.init.constant_(param, 0.0)

    def forward(self, x):
        """Run ``x`` through the wrapped ``Mlp`` and return its output."""
        out = self.mlp(x)
        return out
class BNLinear1d(nn.Sequential):
    """``BatchNorm1d`` followed by ``Linear``, registered as ``bn`` and ``linear``.

    Args:
        in_features: Feature count for the batch norm and the linear input.
        out_features: Output feature count of the linear layer.
        bias: Whether the linear layer has a bias; when truthy it starts at 0.
        std: ``std`` passed to ``trunc_normal_`` for the linear weight.
    """

    def __init__(self, in_features, out_features, bias=True, std=0.02):
        norm = nn.BatchNorm1d(in_features)
        proj = nn.Linear(in_features, out_features, bias=bias)

        # Initialise the projection before the layers are registered.
        trunc_normal_(proj.weight, std=std)
        if bias:
            nn.init.constant_(proj.bias, 0)

        layers = OrderedDict()
        layers["bn"] = norm
        layers["linear"] = proj
        super().__init__(layers)
class Conv2dBN(nn.Sequential):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d``, named ``conv`` and ``bn``.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution and the batch norm.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        groups: Convolution groups.
        bn_weight_init: Constant the batch-norm weight is initialised to;
            the batch-norm bias is always initialised to 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0,
                 dilation=1, groups=1, bn_weight_init=1.0):
        conv_layer = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        norm_layer = nn.BatchNorm2d(out_channels)
        nn.init.constant_(norm_layer.weight, bn_weight_init)
        nn.init.constant_(norm_layer.bias, 0)

        named_layers = OrderedDict(conv=conv_layer, bn=norm_layer)
        super().__init__(named_layers)
class DWConv2dBNReLU(nn.Sequential):
    """Grouped conv -> BN -> ReLU -> grouped 1x1 conv -> BN.

    Layers are registered as ``dwconv3x3``, ``bn1``, ``relu``, ``dwconv1x1``
    and ``bn2``. Both convolutions use ``groups=in_channels`` and no bias.
    The first convolution keeps the channel count and pads by
    ``kernel_size // 2``; the second maps ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Input channels; also the group count of both convolutions.
        out_channels: Channels produced by the 1x1 convolution and ``bn2``.
        kernel_size: Kernel size of the first convolution (the layer is named
            ``dwconv3x3`` regardless of this value).
        bn_weight_init: Constant both batch-norm weights are initialised to;
            both batch-norm biases are initialised to 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, bn_weight_init=1.0):
        half_kernel = kernel_size // 2
        spatial_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=1,
            padding=half_kernel,
            groups=in_channels,
            bias=False,
        )
        norm1 = nn.BatchNorm2d(in_channels)
        activation = nn.ReLU(inplace=True)
        # The 1x1 conv is grouped as well, not a dense pointwise projection.
        grouped_1x1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=in_channels,
            bias=False,
        )
        norm2 = nn.BatchNorm2d(out_channels)

        for norm in (norm1, norm2):
            nn.init.constant_(norm.weight, bn_weight_init)
            nn.init.constant_(norm.bias, 0)

        super().__init__(OrderedDict([
            ("dwconv3x3", spatial_conv),
            ("bn1", norm1),
            ("relu", activation),
            ("dwconv1x1", grouped_1x1),
            ("bn2", norm2),
        ]))
class PatchMerging2d(nn.Module):
    """Stride-2 merging block: 1x1 expand, grouped 3x3 stride-2 conv, SE, 1x1 project.

    The hidden width is ``int(dim * 4)``. Forward order is
    ``conv1 -> act -> conv2 -> act -> se -> conv3``.

    Args:
        dim: Input channel count.
        out_dim: Output channel count.
    """

    def __init__(self, dim, out_dim):
        super().__init__()
        expanded = int(dim * 4)
        # 1x1 expansion to the hidden width.
        self.conv1 = Conv2dBN(dim, expanded, kernel_size=1, stride=1, padding=0)
        self.act = nn.ReLU(inplace=True)
        # 3x3, stride 2, padding 1, one group per hidden channel.
        self.conv2 = Conv2dBN(expanded, expanded, kernel_size=3, stride=2, padding=1,
                              groups=expanded)
        self.se = SqueezeExcite(expanded, rd_ratio=0.25)
        # 1x1 projection to the requested output width.
        self.conv3 = Conv2dBN(expanded, out_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Apply the block to ``x`` and return the result of ``conv3``."""
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = self.se(x)
        return self.conv3(x)
|