"""Common basic layers (2D)."""

from collections import OrderedDict

import torch
import torch.nn as nn
from timm.layers.drop import DropPath
from timm.layers.mlp import Mlp
from timm.layers.squeeze_excite import SqueezeExcite
from timm.layers.weight_init import trunc_normal_


class Scale(nn.Module):
    """Multiply the input by a learnable weight tensor.

    Args:
        dims: Iterable of ints giving the shape of the weight tensor.
        init_scale: Value every weight element is initialised to.
    """

    def __init__(self, dims, init_scale=1.0):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(*dims) * init_scale)

    def forward(self, x):
        """Return ``weight * x`` (regular broadcasting rules apply)."""
        return self.weight * x


class Residual(nn.Module):
    """Residual wrapper ``x + module(x)`` with optional per-sample branch drop.

    In training mode with ``drop > 0`` the branch output of each sample is
    zeroed with probability ``drop``; surviving samples are rescaled by
    ``1 / (1 - drop)``. In eval mode, or with ``drop <= 0``, the wrapper is
    a plain residual connection.

    Args:
        module: The wrapped branch; called as ``module(x)``.
        drop: Probability of dropping the branch for a sample.
    """

    def __init__(self, module, drop=0.0):
        super().__init__()
        self.module = module
        self.drop = drop

    def forward(self, x):
        """Apply the residual connection, dropping the branch stochastically."""
        if self.training and self.drop > 0.0:
            # One random draw per sample, broadcast over all remaining dims.
            # Building the shape from x.dim() keeps the (B, 1, 1, 1) mask for
            # 4-D inputs while also supporting inputs of other ranks.
            mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
            keep = torch.rand(mask_shape, device=x.device).ge_(self.drop)

            # Rescale survivors so the expected branch output is unchanged.
            # When drop >= 1 nothing survives; skipping the division avoids
            # the 0 / 0 = NaN the rescale would otherwise produce.
            keep_prob = 1.0 - self.drop
            if keep_prob > 0.0:
                keep = keep.div(keep_prob)

            # Match the input dtype so a half/bfloat16 branch is not silently
            # promoted to float32 by the float32 mask.
            keep = keep.to(x.dtype)
            return x + self.module(x) * keep
        return x + self.module(x)


class FFN2d(nn.Module):
    """Feed-forward block built on timm's ``Mlp`` in convolutional mode.

    Args:
        embed_dim: Number of input and output features.
        hidden_dim: Number of hidden features.
    """

    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.mlp = Mlp(
            in_features=embed_dim,
            hidden_features=hidden_dim,
            out_features=embed_dim,
            act_layer=nn.ReLU,
            use_conv=True,
            bias=False,
        )
        # Zero-initialise any BatchNorm2d over ``embed_dim`` channels found
        # inside the Mlp.
        # NOTE(review): no norm layer is passed to Mlp here, so this loop may
        # not match any submodule -- confirm against the installed timm.
        for m in self.mlp.modules():
            if isinstance(m, nn.BatchNorm2d) and m.num_features == embed_dim:
                nn.init.constant_(m.weight, 0.0)
                nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        """Run the input through the wrapped Mlp."""
        return self.mlp(x)


class BNLinear1d(nn.Sequential):
    """``BatchNorm1d`` followed by ``Linear``, registered as ``bn`` and ``linear``.

    Args:
        in_features: Input feature count (also the BatchNorm1d width).
        out_features: Output feature count of the linear layer.
        bias: Whether the linear layer has a bias; if so it is set to 0.
        std: Standard deviation passed to ``trunc_normal_`` for the weight.
    """

    def __init__(self, in_features, out_features, bias=True, std=0.02):
        bn = nn.BatchNorm1d(in_features)
        linear = nn.Linear(in_features, out_features, bias=bias)
        trunc_normal_(linear.weight, std=std)
        if bias:
            nn.init.constant_(linear.bias, 0)
        super().__init__(OrderedDict([("bn", bn), ("linear", linear)]))


class Conv2dBN(nn.Sequential):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` (``conv`` then ``bn``).

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution and BatchNorm width.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        groups: Convolution groups.
        bn_weight_init: Constant the BatchNorm weight is initialised to; the
            BatchNorm bias is always initialised to 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, groups=1, bn_weight_init=1.0):
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias=False)
        bn = nn.BatchNorm2d(out_channels)
        nn.init.constant_(bn.weight, bn_weight_init)
        nn.init.constant_(bn.bias, 0)
        super().__init__(OrderedDict([("conv", conv), ("bn", bn)]))
class DWConv2dBNReLU(nn.Sequential):
    """Grouped 3x3-style conv -> BN -> ReLU -> grouped 1x1 conv -> BN.

    Children are registered, in order, as ``dwconv3x3``, ``bn1``, ``relu``,
    ``dwconv1x1`` and ``bn2``. Both convolutions are bias-free and use
    ``groups=in_channels``.

    Args:
        in_channels: Channels of the input and of the first convolution.
        out_channels: Channels produced by the 1x1 convolution.
        kernel_size: Kernel size of the first convolution; its padding is
            ``kernel_size // 2`` and its stride is 1.
        bn_weight_init: Constant used for the weight of both BatchNorm
            layers; their biases are set to 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, bn_weight_init=1.0):
        super().__init__()

        # Modules are created in the same order they run, so the registered
        # child names and order are exactly as listed in the class docstring.
        self.add_module(
            "dwconv3x3",
            nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size,
                1,
                kernel_size // 2,
                groups=in_channels,
                bias=False,
            ),
        )
        self.add_module("bn1", nn.BatchNorm2d(in_channels))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module(
            "dwconv1x1",
            nn.Conv2d(in_channels, out_channels, 1, 1, 0, groups=in_channels, bias=False),
        )
        self.add_module("bn2", nn.BatchNorm2d(out_channels))

        # Constant initialisation of both normalisation layers.
        for norm in (self.bn1, self.bn2):
            nn.init.constant_(norm.weight, bn_weight_init)
            nn.init.constant_(norm.bias, 0)


class PatchMerging2d(nn.Module):
    """Stride-2 merging block: 1x1 expand, grouped 3x3 stride-2, SE, 1x1 project.

    The hidden width is ``int(dim * 4)``. The same in-place ReLU module is
    applied after the first and after the second convolution.

    Args:
        dim: Input channel count.
        out_dim: Output channel count.
    """

    def __init__(self, dim, out_dim):
        super().__init__()
        expanded = int(dim * 4)
        self.conv1 = Conv2dBN(dim, expanded, 1, 1, 0)
        self.act = nn.ReLU(inplace=True)
        self.conv2 = Conv2dBN(expanded, expanded, 3, 2, 1, groups=expanded)
        self.se = SqueezeExcite(expanded, rd_ratio=0.25)
        self.conv3 = Conv2dBN(expanded, out_dim, 1, 1, 0)

    def forward(self, x):
        """Run expand -> act -> stride-2 conv -> act -> SE -> project."""
        out = self.conv1(x)
        out = self.act(out)
        out = self.conv2(out)
        out = self.act(out)
        out = self.se(out)
        return self.conv3(out)