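Current Project Walkthrough and Plain-Text Architecture Flowcharts

1. Project scope

The active mainline of X_SSL_Net is a fully supervised training project for 2D ultrasound image segmentation.

Actual training pipeline: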

shell script
-> tools/train.py
-> SupervisedSegmentationTrainer
-> SegmentationRecordDataset / DataLoader
-> XNet2d
-> seg_logits
-> DiceCE loss / BCE fallback
-> Dice / IoU validation
-> best.pth / last.pth

Actual model mainline:

XNet2d = CNN-Wavelet-VMamba encoder + plain U-Net skip decoder segmentation network

The main training run uses a single segmentation head:

outputs["seg_logits"]

The mainline does not call:

  1. lib/sam2
  2. lib/SwinTransformer
  3. SwinV2 segmentation config
  4. boundary auxiliary head
  5. semi-supervised trainer

lib/sam2 and lib/SwinTransformer are kept as external code assets and are not on the current training path.

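2. One-line overview

The project can be summarized as:

Fully supervised segmentation training of XNet2d on 2D ultrasound datasets such as BUSI / DDTI / TN3K / TG3K.
The XNet2d encoder models features with three branches (local + wavelet + VMamba-style SS2D);
the decoder restores the mask with plain U-Net same-scale skips plus frequency refinement.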

3. Entry points

3.1 Recommended shell entry point

Most common entry:

DATASET=BUSI bash tools/run_optimized_supervised.sh

Short debug run:

DATASET=BUSI \
EXTRA_SET_ARGS="train.epochs=1 train.batch_size=8 train.val_batch_size=8 logging.use_swanlab=false checkpoint.dir=outputs/validation/xnet_oflex_b8" \
bash tools/run_optimized_supervised.sh

3.2 What the shell script does

File:

tools/run_optimized_supervised.sh

It does four things:

  1. Parse DATASET
  2. Map the dataset root directory
  3. For datasets that need a project-level split, generate or load train/val
  4. Call tools/train.py

Supported dataset names:

BUS-UCLM
BUSI
BUS-BRA
BUS_UC
CCAUI
DDTI
OTU_2d
TN3K
TG3K

Dataset root mapping:

BUSI      -> data/BUSI
DDTI      -> data/DDTI
TN3K      -> data/TN3K
TG3K      -> data/TG3K
BUS_UC    -> data/BUS_UC
...

Datasets using a project-level split:

BUS-UCLM, BUSI, BUS-BRA, BUS_UC, CCAUI, DDTI

Datasets using the official split:

OTU_2d, TN3K, TG3K
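For reference, the dataset handling above can be mirrored in a few lines of Python. This is a hypothetical sketch, not the shell script itself: split_kind, dataset_root, and plan_commands are illustrative names, only the five roots listed above are filled in (the rest of the mapping is elided in this document), and the --config argument is omitted because the exact YAML file name is not given.

```python
# Hypothetical Python mirror of the dataset handling in tools/run_optimized_supervised.sh.
PROJECT_SPLIT = {"BUS-UCLM", "BUSI", "BUS-BRA", "BUS_UC", "CCAUI", "DDTI"}
OFFICIAL_SPLIT = {"OTU_2d", "TN3K", "TG3K"}

# Only the roots spelled out in this document; the remaining entries are elided there.
KNOWN_ROOTS = {
    "BUSI": "data/BUSI",
    "DDTI": "data/DDTI",
    "TN3K": "data/TN3K",
    "TG3K": "data/TG3K",
    "BUS_UC": "data/BUS_UC",
}


def split_kind(dataset: str) -> str:
    """Return 'project' or 'official'; reject unsupported dataset names."""
    if dataset in PROJECT_SPLIT:
        return "project"
    if dataset in OFFICIAL_SPLIT:
        return "official"
    raise ValueError(f"unsupported DATASET: {dataset}")


def dataset_root(dataset: str) -> str:
    split_kind(dataset)  # validate the name first
    if dataset not in KNOWN_ROOTS:
        raise KeyError(f"root for {dataset} is not listed in this document")
    return KNOWN_ROOTS[dataset]


def plan_commands(dataset: str) -> list[list[str]]:
    """Commands the script would run, in order (the --config path is left out)."""
    root = dataset_root(dataset)
    commands = []
    if split_kind(dataset) == "project":
        commands.append(["python", "scripts/generate_project_split.py",
                         "--dataset", dataset, "--root", root])
    commands.append(["python", "tools/train.py", "--set",
                     f"dataset.dataset_name={dataset}", f"dataset.root={root}"])
    return commands


print(plan_commands("BUSI"))
```

Project-split datasets get an extra split-generation command before training; official-split datasets go straight to tools/train.py.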

4. Overall flowchart from shell to Python

User command
  |
  |  DATASET=BUSI EXTRA_SET_ARGS="..." bash tools/run_optimized_supervised.sh
  v
+----------------------------------------------------------------------------------+
| tools/run_optimized_supervised.sh                                                 |
+----------------------------------------------------------------------------------+
| 1. read DATASET / SEED / EXTRA_SET_ARGS                                           |
| 2. dataset_root(DATASET)                                                          |
| 3. if DATASET needs project split:                                                |
|      python scripts/generate_project_split.py --dataset DATASET --root ROOT       |
| 4. python tools/train.py                                                          |
|      --config configs/segmentation/optimized/*.yaml                               |
|      --set dataset.dataset_name=DATASET dataset.root=ROOT ... EXTRA_SET_ARGS      |
+----------------------------------------------------------------------------------+
  |
  v
+----------------------------------------------------------------------------------+
| tools/train.py                                                                    |
+----------------------------------------------------------------------------------+
| 1. parse --config / --trainer / --set                                             |
| 2. load yaml config                                                               |
| 3. apply dotlist overrides                                                        |
| 4. optional override trainer.name                                                 |
| 5. build_trainer(cfg)                                                             |
| 6. trainer.train()                                                                |
+----------------------------------------------------------------------------------+

5. Configuration system

Current main config:

configs/segmentation/optimized/*.yaml

Segmentation configs kept in the repository:

configs/segmentation/optimized/*.yaml
configs/segmentation/us_exp_sup_busi.yaml
configs/segmentation/us_exp_sup_busi_ablation.yaml

5.1 Overriding config values

tools/train.py supports:

--set key=value key=value ...

For example:

--set train.epochs=1 train.batch_size=8 model.use_frequency_refine=false

Override logic:

load_yaml_config(path)
  |
  v
apply_dotlist_overrides(cfg, args.set)
  |
  v
nested dict update
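A minimal sketch of this override step, assuming YAML-like value parsing (true/false, null, integers, floats, bracketed lists, otherwise plain strings). The names apply_dotlist_overrides_sketch and parse_override_value are hypothetical; the real apply_dotlist_overrides may parse values differently, for example through a YAML loader.

```python
import json


def parse_override_value(text: str):
    """Best-effort YAML-like scalar parsing for a single override value."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    if lowered in ("null", "none", "~"):
        return None
    if text.startswith("["):
        return json.loads(text)  # e.g. "[2,2,3,2]" -> [2, 2, 3, 2]
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text  # fall back to a plain string, e.g. a checkpoint path


def apply_dotlist_overrides_sketch(cfg: dict, dotlist: list[str]) -> dict:
    """Apply 'a.b.c=value' items to a nested dict in place, creating missing levels."""
    for item in dotlist:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"override must look like key=value: {item!r}")
        *parents, leaf = key.split(".")
        node = cfg
        for name in parents:
            node = node.setdefault(name, {})  # descend, creating sub-dicts as needed
        node[leaf] = parse_override_value(raw)
    return cfg


cfg = {"train": {"epochs": 200, "batch_size": 4}, "model": {"use_frequency_refine": True}}
apply_dotlist_overrides_sketch(
    cfg, ["train.epochs=1", "train.batch_size=8", "model.use_frequency_refine=false"]
)
print(cfg)
```

Only the overridden leaves change; sibling keys such as other train.* entries keep their YAML values.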

5.2 Key configuration values

Training:

train:
  epochs: 200
  batch_size: 4
  val_batch_size: 4
  amp: true
  num_workers: 4
  pin_memory: true
  persistent_workers: true
  prefetch_factor: 2
  device: cuda
  grad_clip:
    enabled: true
    max_norm: 1.0

Data:

dataset:
  dataset_name: BUSI
  root: data/BUSI
  split: train
  val_split: val
  image_size: [256, 256]
  in_channels: 3
  num_classes: 1

Model:

model:
  in_channels: 3
  encoder_channels: [32, 64, 128, 192]
  encoder_depths: [2, 2, 2, 2]
  decoder_channels: [128, 64, 32]
  stem_channels: 24
  bottleneck_depth: 1
  global_ratio: 2.0
  wavelet_type: haar
  wavelet_level: 1
  use_wavelet_branch: true
  use_global_branch_stage1: false
  ssm_d_state: 16
  ssm_forward_type: v3
  ssm_backend: auto
  use_frequency_refine: true
  guide_mode: affine  # compatibility only; current decoder no longer uses guide path
  out_channels: null

Optimization:

optimizer:
  name: adamw
  lr: 1.0e-4
  weight_decay: 0.05

scheduler:
  name: cosine
  warmup:
    name: linear
    params:
      start_factor: 0.1
      total_iters: 10
  params:
    T_max: 190
    eta_min: 1.0e-6
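The warmup and cosine settings fit together: 10 warmup epochs plus T_max = 190 cosine epochs add up to train.epochs = 200. The sketch below gives a closed-form learning rate per epoch. It assumes the scheduler is stepped once per epoch (as in the training loop in section 15) and that warmup is chained in front of cosine in the style of PyTorch's LinearLR followed by CosineAnnealingLR. lr_at_epoch is a hypothetical helper, and the real scheduler may differ by one epoch at the warmup boundary.

```python
import math


def lr_at_epoch(epoch: int, base_lr: float = 1.0e-4, start_factor: float = 0.1,
                warmup_epochs: int = 10, t_max: int = 190, eta_min: float = 1.0e-6) -> float:
    """Closed-form LR for linear warmup followed by cosine annealing (per-epoch stepping)."""
    if epoch < warmup_epochs:
        # Multiplier ramps linearly from start_factor towards 1.0
        factor = start_factor + (1.0 - start_factor) * epoch / warmup_epochs
        return base_lr * factor
    # Clamp so the sketch stays at eta_min after T_max instead of rising again
    t = min(epoch - warmup_epochs, t_max)
    return eta_min + (base_lr - eta_min) * (1.0 + math.cos(math.pi * t / t_max)) / 2.0


for e in (0, 5, 10, 105, 199):
    print(e, f"{lr_at_epoch(e):.3e}")
```

With these defaults the LR starts at 1e-5, peaks at 1e-4 at the end of warmup, and decays towards 1e-6 by the last epoch.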

Loss and metrics:

loss:
  name: dicece
  task_mode: binary
  params:
    include_background: true
    lambda_dice: 0.7
    lambda_ce: 0.3

validation:
  threshold: 0.5
  metrics:
    task_mode: binary
    metrics:
      - name: dice
      - name: iou

6. Trainer build flow

Entry point:

lib/trainers/builder.py::build_trainer

Current trainer:

lib/trainers/supervised.py::SupervisedSegmentationTrainer

Build flow:

build_trainer(cfg)
  |
  v
read cfg.trainer.name
  |
  v
TRAINER_REGISTRY["supervised_segmentation"]
  |
  v
trainer = SupervisedSegmentationTrainer(cfg, args)
  |
  v
trainer.build()
  |
  v
return trainer

SupervisedSegmentationTrainer.build() performs:

1. dataset_cfg = cfg["dataset"]
2. model_cfg   = cfg["model"]
3. train_cfg   = cfg["train"]

4. build XNet2d from model_cfg
5. move model to device
6. build optimizer
7. build scheduler
8. build loss if cfg.loss is not null
9. build train dataloader
10. build validation dataloader
11. maybe resume checkpoint
12. maybe init SwanLab

7. BaseTrainer shared responsibilities

File:

lib/trainers/base.py

Shared responsibilities:

BaseTrainer
├─ random seed
├─ device selection
├─ output directory
├─ AMP GradScaler
├─ batch size resolution
├─ dataloader construction helper
├─ validation metric construction
├─ checkpoint save / resume
├─ early stopping
├─ SwanLab logging
├─ training setup summary
├─ step performance logging
└─ epoch finalization

Device selection:

cfg.train.device == "cuda" and torch.cuda.is_available()
  -> cuda
else
  -> cpu

AMP switch:

cfg.train.amp == true and device == cuda
  -> enabled
else
  -> disabled
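The two rules above fit in one small function. resolve_device_and_amp is a hypothetical name. It takes CUDA availability as an argument (in the trainer this would be torch.cuda.is_available()), so the logic can be checked without a GPU.

```python
def resolve_device_and_amp(requested_device: str, amp_requested: bool,
                           cuda_available: bool) -> tuple[str, bool]:
    """CUDA only if requested and present; AMP only when the resolved device is CUDA."""
    device = "cuda" if requested_device == "cuda" and cuda_available else "cpu"
    amp_enabled = bool(amp_requested) and device == "cuda"
    return device, amp_enabled


# A machine without a GPU silently falls back to CPU with AMP off.
print(resolve_device_and_amp("cuda", True, cuda_available=False))
```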

Verified target environment:

conda env: xnet_mamba
torch: 2.10.0+cu126
GPU: NVIDIA GeForce RTX 4070 Ti SUPER
selective_scan_cuda_oflex: available

8. Data pipeline

8.1 Building the data index

Entry point:

lib/data/builder.py::build_dataset_index

Core registry:

BUS-UCLM -> paired images/masks
BUSI     -> Dataset_BUSI_with_GT/{benign,malignant,normal}
BUS-BRA  -> prefixed image/mask matching
BUS_UC   -> All / Benign / Malignant folders
CCAUI    -> US images / Expert mask images
DDTI     -> XML annotation records
OTU_2d   -> images / annotations
TN3K     -> trainval/test image/mask folders
TG3K     -> thyroid-image / thyroid-mask

8.2 Applying splits

Entry point:

lib/data/loaders.py::apply_official_split

Flow:

build_dataset_index(dataset_name, root)
  |
  v
if split is requested:
  |
  +-- OTU_2d: read train.txt / val.txt
  |
  +-- TN3K: read tn3k-trainval.json or use test folder
  |
  +-- TG3K: read tg3k-trainval.json
  |
  +-- project split dataset:
        read data/<dataset>/splits/project/train.txt or val.txt

Project-level split generation:

scripts/generate_project_split.py
  |
  v
generate_project_splits()
  |
  v
write:
  data/<dataset>/splits/project/train.txt
  data/<dataset>/splits/project/val.txt
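The document does not state the split ratio, the seed, or whether the split is stratified. The sketch below assumes a plain seeded random split. split_ids, write_split, val_ratio=0.2, and seed=42 are placeholders, not values taken from scripts/generate_project_split.py. Sorting the IDs before shuffling makes the result independent of file-system listing order, which is what keeps a stored split reproducible.

```python
import random
from pathlib import Path


def split_ids(sample_ids, val_ratio: float = 0.2, seed: int = 42):
    """Deterministic train/val split of unique sample IDs (ratio and seed are placeholders)."""
    ids = sorted(set(sample_ids))       # sort first so input order does not matter
    rng = random.Random(seed)           # local RNG: leaves the global random state untouched
    rng.shuffle(ids)
    n_val = int(round(len(ids) * val_ratio))
    return sorted(ids[n_val:]), sorted(ids[:n_val])


def write_split(train, val, out_dir):
    """Write one ID per line, matching the train.txt / val.txt layout above."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    (out / "train.txt").write_text("".join(f"{i}\n" for i in train))
    (out / "val.txt").write_text("".join(f"{i}\n" for i in val))


train, val = split_ids([f"case_{i:03d}" for i in range(10)])
print(len(train), len(val))
```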

8.3 Dataset loading

File:

lib/data/datasets.py::SegmentationRecordDataset

Per-sample loading:

record
  |
  +-- image_path -> PIL RGB -> float32 [3,H,W] in [0,1]
  |
  +-- mask_path  -> PIL L -> binary float32 [1,H,W]
  |
  +-- DDTI special:
        annotation_path XML -> build_ddti_mask() -> binary [1,H,W]
  |
  +-- joint augmentation
  |
  +-- resize image to dataset.image_size
  |
  +-- resize mask to dataset.image_size
  |
  v
{
  "image": image,
  "mask": mask,
  "dataset_name": ...,
  "sample_id": ...,
  "split": ...,
  "class_name": ...,
  "meta": ...
}

8.4 augmentation

File:

lib/data/augment.py::SegmentationAugmentation

Currently supported:

spatial:
  random horizontal flip
  random vertical flip
  random rotate 90

intensity:
  random brightness / contrast
  random gaussian noise
  clamp to [0,1]

8.5 collate

File:

lib/data/collate.py::record_collate_fn

Logic:

if all tensor shapes same:
  torch.stack(values, dim=0)
else:
  keep list

strings / dict / metadata:
  keep list

Final batch:

image: [B,3,256,256]
mask : [B,1,256,256]

9. Dataloader flowchart

SupervisedSegmentationTrainer.build()
  |
  v
_build_segmentation_loader(split="train")
  |
  v
build_dataloader()
  |
  v
build_record_dataset()
  |
  v
build_dataset_index()
  |
  v
apply_official_split()
  |
  v
SegmentationRecordDataset(records, transforms)
  |
  v
DataLoader(
  batch_size,
  shuffle,
  num_workers,
  pin_memory,
  persistent_workers,
  prefetch_factor,
  collate_fn=record_collate_fn
)

Note: DataLoader workers are usually started on the first iteration, i.e. after ======== END TRAINING SETUP ========. With num_workers > 0, the first batch may therefore incur a one-time wait.

10. XNet2d overall structure

File:

lib/modules/xnet_2d.py

Parameter count with default settings:

total parameters:     9,432,129
trainable parameters: 9,432,129

Top-level structure:

XNet2d
├─ XNetEncoder2d
│  ├─ XNetStem2d
│  ├─ Encoder Stage 1: XTEB2d x 2
│  ├─ Downsample 1
│  ├─ Encoder Stage 2: XTEB2d x 2
│  ├─ Downsample 2
│  ├─ Encoder Stage 3: XTEB2d x 2
│  ├─ Downsample 3
│  └─ Encoder Stage 4: XTEB2d x 2
│
├─ Bottleneck: XTEB2d x 1
│
├─ XNetDecoder2d
│  ├─ dec4: XCRB2d(E4, E3)
│  ├─ dec3: XCRB2d(D4, E2)
│  ├─ dec2: XCRB2d(D3, E1)
│  └─ head_refine
│
└─ XNetSegHead2d

11. XNet2d plain-text architecture diagram

Example with input [B,3,256,256] and the default encoder channels [32,64,128,192]:

Input
[B, 3, 256, 256]
  |
  v
XNetStem2d
  Conv3x3 s2:       [B, 24, 128, 128]
  DWConv3x3:        [B, 24, 128, 128]
  PWConv1x1:        [B, 32, 128, 128]
  Conv3x3 s2:       [B, 32,  64,  64]
  |
  v
E1 = Encoder Stage 1, XTEB x2
[B, 32, 64, 64]
  |
  v
Down1
[B, 64, 32, 32]
  |
  v
E2 = Encoder Stage 2, XTEB x2
[B, 64, 32, 32]
  |
  v
Down2
[B, 128, 16, 16]
  |
  v
E3 = Encoder Stage 3, XTEB x2
[B, 128, 16, 16]
  |
  v
Down3
[B, 192, 8, 8]
  |
  v
E4 = Encoder Stage 4, XTEB x2
[B, 192, 8, 8]
  |
  v
Bottleneck XTEB x1
[B, 192, 8, 8]

Decoder:

Bottleneck(E4) [B,192,8,8]
  |
  v
dec4 input:
  decoder input: Bottleneck(E4) [B,192,8,8]
  same-scale skip: E3 [B,128,16,16]
  output D4 [B,128,16,16]

D4 [B,128,16,16]
  |
  v
dec3 input:
  decoder input: D4 [B,128,16,16]
  same-scale skip: E2 [B,64,32,32]
  output D3 [B,64,32,32]

D3 [B,64,32,32]
  |
  v
dec2 input:
  decoder input: D3 [B,64,32,32]
  same-scale skip: E1 [B,32,64,64]
  output D2 [B,32,64,64]

D2 [B,32,64,64]
  |
  v
HeadRefine
[B,32,64,64]
  |
  v
SegHead + upsample to input size
[B,1,256,256]
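The spatial sizes in the diagram follow from standard Conv2d output arithmetic. The helper below reproduces them, assuming every stride-2 operation behaves like a 3x3 convolution with padding 1 (which halves even sizes) and that each decoder block is resized to its same-scale skip. xnet2d_shapes is a hypothetical helper; the stem's intermediate 24-channel tensors are not tracked.

```python
def conv_out(size: int, kernel: int = 3, stride: int = 1, padding: int = 1) -> int:
    """Standard Conv2d output size: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def xnet2d_shapes(height: int, width: int, encoder_channels=(32, 64, 128, 192),
                  decoder_channels=(128, 64, 32), num_classes: int = 1) -> dict:
    """Per-sample (C, H, W) shapes of the main XNet2d tensors under the stated assumptions."""
    shapes = {}
    h, w = height, width
    # Stem: two stride-2 convs -> 1/4 of the input resolution
    for _ in range(2):
        h, w = conv_out(h, stride=2), conv_out(w, stride=2)
    for i, c in enumerate(encoder_channels):
        if i > 0:  # Down_i halves the resolution before encoder stage i + 1
            h, w = conv_out(h, stride=2), conv_out(w, stride=2)
        shapes[f"E{i + 1}"] = (c, h, w)
    # Bottleneck keeps E4's shape; each decoder block is resized to its same-scale skip
    for name, skip, c in zip(("D4", "D3", "D2"), ("E3", "E2", "E1"), decoder_channels):
        shapes[name] = (c,) + shapes[skip][1:]
    # SegHead + upsample back to the input size
    shapes["seg_logits"] = (num_classes, height, width)
    return shapes


for key, value in xnet2d_shapes(256, 256).items():
    print(key, value)
```

For the 128x128 input used in the forward check in section 21, the same arithmetic puts E4 at 4x4 while seg_logits stays at the input size.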

12. XTEB2d in detail

XTEB2d is the basic building block of the encoder.

Name:

XTEB = XNet Tri-branch Encoding Block

Input / output:

input : X [B,C,H,W]
output: Y [B,C,H,W]

Internal structure:

X
│
├─ pre_norm: 1x1 Conv2dBN
│
├─ Local branch
│   ├─ DWConv3x3 + PWConv1x1
│   └─ DWConv5x5 + PWConv1x1
│
├─ Wavelet branch
│   ├─ Haar DWT
│   │   ├─ LL
│   │   └─ LH/HL/HH high bands
│   ├─ LL projection
│   ├─ high-band projection
│   └─ inverse Haar transform
│
├─ Global branch
│   ├─ 1x1 pre projection
│   ├─ VMamba-style SS2D
│   └─ 1x1 post projection
│
├─ concat(local, wavelet, global)
├─ 1x1 fusion
├─ channel gate from GAP + MLP + sigmoid
├─ residual add
└─ lightweight FFN + residual add

Formulation:

X0 = PreNorm(X)

L = Local(X0)
W = Wavelet(X0)
G = GlobalSS2D(X0)

F = Fuse([L,W,G])
Y = X + Post(F)
Z = Y + FFN(Y)

12.1 Local branch

Role:

Local texture, boundaries, and short-range structure

Structure:

DWConv3x3 -> ReLU -> PWConv1x1
DWConv5x5 -> ReLU -> PWConv1x1
sum

12.2 Wavelet branch

Role:

Low-frequency contours + high-frequency boundaries/texture

Structure:

Haar DWT:
  LL      -> low-frequency structure
  LH/HL/HH -> high-frequency directional details

LL -> Conv projection
High bands -> depthwise conv + pointwise conv
IDWT -> output projection

Current limitations:

wavelet_type = haar
wavelet_level = 1
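A pure-Python sketch of one level of the orthonormal 2D Haar transform and its exact inverse, to make the four sub-bands concrete. It is illustrative only. The wavelet branch in XTEB2d works on tensors, and the LH/HL naming used here (LH = top-minus-bottom, HL = left-minus-right) is one convention among several, not necessarily the one used in the code.

```python
def haar_dwt2(x):
    """One-level orthonormal 2D Haar DWT of an even-sized 2D list -> (LL, LH, HL, HH)."""
    h, w = len(x), len(x[0])
    assert h % 2 == 0 and w % 2 == 0, "level-1 Haar needs even height and width"
    ll, lh, hl, hh = ([[0.0] * (w // 2) for _ in range(h // 2)] for _ in range(4))
    for i in range(0, h, 2):
        for j in range(0, w, 2):
            a, b = x[i][j], x[i][j + 1]          # top-left, top-right
            c, d = x[i + 1][j], x[i + 1][j + 1]  # bottom-left, bottom-right
            r, s = i // 2, j // 2
            ll[r][s] = (a + b + c + d) / 2  # scaled local average: coarse structure
            lh[r][s] = (a + b - c - d) / 2  # top minus bottom: horizontal edges
            hl[r][s] = (a - b + c - d) / 2  # left minus right: vertical edges
            hh[r][s] = (a - b - c + d) / 2  # diagonal detail
    return ll, lh, hl, hh


def haar_idwt2(ll, lh, hl, hh):
    """Exact inverse of haar_dwt2 (the 2x2 transform is its own inverse)."""
    h2, w2 = len(ll), len(ll[0])
    x = [[0.0] * (2 * w2) for _ in range(2 * h2)]
    for r in range(h2):
        for s in range(w2):
            v_ll, v_lh, v_hl, v_hh = ll[r][s], lh[r][s], hl[r][s], hh[r][s]
            x[2 * r][2 * s] = (v_ll + v_lh + v_hl + v_hh) / 2
            x[2 * r][2 * s + 1] = (v_ll + v_lh - v_hl - v_hh) / 2
            x[2 * r + 1][2 * s] = (v_ll - v_lh + v_hl - v_hh) / 2
            x[2 * r + 1][2 * s + 1] = (v_ll - v_lh - v_hl + v_hh) / 2
    return x


image = [[float(4 * r + c) for c in range(4)] for r in range(4)]
ll, lh, hl, hh = haar_dwt2(image)
print(ll)  # 2x2 coarse approximation of the 4x4 ramp
```

On a flat region all three high bands are zero, so only edges and texture survive in LH/HL/HH; this is what lets the branch treat contours and fine detail separately.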

12.3 Global SS2D branch

Role:

Efficient long-range dependency modeling and global structural consistency

Current implementation:

lib/modules/lib_mamba/vmamba.py::SS2D

Origin:

VMamba-style SS2D operator

Backend selection:

ssm_backend = auto
  |
  +-- if x.is_cuda:
        selective_scan_backend = oflex
        scan_force_torch = false
  |
  +-- else:
        selective_scan_backend = torch
        scan_force_torch = true

ssm_backend = oflex
  -> force oflex

ssm_backend = torch
  -> force torch fallback

Current defaults:

ssm_forward_type = v3
ssm_backend = auto

Verified in the xnet_mamba + RTX 4070 Ti SUPER environment:

selective_scan_cuda_oflex import OK
WITH_SELECTIVESCAN_OFLEX = True
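The backend selection rules in this section can be summarized as a small function returning (selective_scan_backend, scan_force_torch). resolve_ssm_backend is a hypothetical name. The document only spells out both values for auto, so the scan_force_torch values for the forced oflex and torch modes are assumptions.

```python
def resolve_ssm_backend(ssm_backend: str, x_is_cuda: bool) -> tuple[str, bool]:
    """Return (selective_scan_backend, scan_force_torch) for a given config value."""
    if ssm_backend == "auto":
        # CUDA inputs use the oflex kernel; CPU inputs fall back to the pure-torch scan.
        return ("oflex", False) if x_is_cuda else ("torch", True)
    if ssm_backend == "oflex":
        return "oflex", False  # forced regardless of device (flag value assumed)
    if ssm_backend == "torch":
        return "torch", True   # forced torch fallback (flag value assumed)
    raise ValueError(f"unknown ssm_backend: {ssm_backend!r}")


print(resolve_ssm_backend("auto", x_is_cuda=True))
```

Because auto decides per input, the same config runs on a CPU-only machine (slowly, via torch) and picks up oflex automatically on CUDA.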

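13. XCRB2d in detail

XCRB2d is the basic building block of the decoder. The current implementation removes the old X-shaped diagonal guide and returns to plain U-Net same-scale skip connections.

Name:

XCRB = XNet Reconstruction Block

Inputs: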

decoder input: deeper decoder or bottleneck feature
same-scale skip: encoder feature at target scale

Internal structure:

decoder input
  |
  v
bilinear upsample to skip size
  |
  v
1x1 projection
  |
  +-----------------------------+
                                |
same-scale skip                 |
  |                             |
  v                             |
1x1 projection                  |
  |                             |
  +----------- concat ----------+
                  |
                  v
             3x3 fusion
                  |
                  v
        optional frequency refine
                  |
                  v
        residual spatial refine

13.1 Plain U-Net skip information flow

The current decoder uses plain U-Net lateral skips:

same-scale path:
  E3 -> D4
  E2 -> D3
  E1 -> D2

Plain-text sketch:

Encoder: E1 ---------------------------> D2

Encoder: E2 ---------------------------> D3

Encoder: E3 ---------------------------> D4

Encoder: E4 -> Bottleneck -> decoder up path

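13.2 guide_mode compatibility note

guide_mode=affine may still appear in configs, but the current decoder no longer builds XGuideProjector2d or XGuideModulation2d. The parameter is kept only for compatibility with old YAML files and takes no part in the forward pass.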

13.3 Frequency refine

Enabled by default: use_frequency_refine=true

Flow:

feature F
  |
  v
cast to float32 if needed
  |
  v
rfft2
  |
  +-- low frequency mask
  |
  +-- high frequency residual
  |
  v
low/high learnable gates
  |
  v
irfft2
  |
  v
cast back to input dtype
  |
  v
depthwise conv refine

The FFT is explicitly computed in float32 here, which avoids the "ComplexHalf support is experimental" warning that AMP would otherwise trigger.
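To make the low/high split concrete, the sketch below builds a low-frequency mask on the rfft2 frequency grid (shape H x (W//2 + 1)) and applies separate gates to the low band and to the high-frequency residual. The document does not give the mask shape or the gate granularity, so the radial cutoff (0.25 cycles/pixel in the usage) and the scalar gates g_low and g_high are assumptions. rfft2_low_mask and gate_spectrum are hypothetical names.

```python
import math


def fftfreq(n: int) -> list[float]:
    """Same ordering as torch.fft.fftfreq(n): [0, 1, ..., -2, -1] / n."""
    return [(k if k <= (n - 1) // 2 else k - n) / n for k in range(n)]


def rfft2_low_mask(h: int, w: int, cutoff: float) -> list[list[float]]:
    """1.0 inside a radial low-pass disk on the rfft2 grid of shape (h, w // 2 + 1)."""
    fy = fftfreq(h)                          # full (signed) frequency axis along H
    fx = [k / w for k in range(w // 2 + 1)]  # non-negative half axis along W (rfftfreq)
    return [[1.0 if math.hypot(y, x) <= cutoff else 0.0 for x in fx] for y in fy]


def gate_spectrum(spec, mask, g_low: float, g_high: float):
    """Scale the low band by g_low and the high-frequency residual (1 - mask) by g_high."""
    return [[z * (g_low * m + g_high * (1.0 - m)) for z, m in zip(row_z, row_m)]
            for row_z, row_m in zip(spec, mask)]


mask = rfft2_low_mask(8, 8, cutoff=0.25)
spectrum = [[complex(1, 1)] * 5 for _ in range(8)]  # stand-in for an rfft2 output
boosted = gate_spectrum(spectrum, mask, g_low=1.0, g_high=1.5)
print(sum(map(sum, mask)), "low-frequency bins kept")
```

With g_low = g_high = 1 the spectrum passes through unchanged, so the learnable gates start from an identity-like setting and only deviate where training finds it useful.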

14. XNet2d forward outputs

XNet2d.forward(x) returns:

{
    "logits": logits,
    "seg_logits": logits,
    "encoder_features": encoder_features,
    "decoder_features": decoder_features,
    "guides": [],
}

Training uses only:

outputs["seg_logits"]

The other outputs are used for:

debug
visualization
future auxiliary analysis

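There is currently no boundary auxiliary output. There is no diagonal guide output either; guides stays an empty list for compatibility with the old debugging interface.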

15. Training loop in detail

Entry point:

SupervisedSegmentationTrainer.train()

Flow:

train()
  |
  v
print training setup
  |
  v
for epoch in range(start_epoch, epochs):
  |
  +-- model.train()
  +-- optimizer.zero_grad()
  +-- for step, batch in train_loader:
        |
        +-- measure data_time
        |
        +-- image = batch["image"].to(device)
        +-- mask  = batch["mask"].to(device)
        |
        +-- with autocast(enabled=amp):
              outputs = model(image)
              seg_logits = outputs["seg_logits"]
              seg_loss = loss(seg_logits, mask)
              total_loss = seg_loss
        |
        +-- scaled_total_loss = total_loss / accum_steps
        +-- grad_scaler.scale(scaled_total_loss).backward()
        |
        +-- if should optimizer step:
              unscale gradients if grad clipping enabled
              clip grad norm
              grad_scaler.step(optimizer)
              grad_scaler.update()
              optimizer.zero_grad()
        |
        +-- log step every logging.log_interval
  |
  +-- scheduler.step()
  |
  +-- validate if enabled and interval matches
  |
  +-- finalize epoch
        |
        +-- merge train / val metrics
        +-- update best metric
        +-- save best.pth if improved
        +-- save last.pth if enabled
        +-- early stopping check
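The loop divides the loss by accum_steps so that, for equal-sized micro-batches, the accumulated gradient matches the gradient of the average loss over the accumulated micro-batches. The "should optimizer step" condition is not spelled out here; the sketch below assumes the common rule of stepping every accum_steps micro-batches and always on the last batch of the epoch. Both function names are hypothetical.

```python
def should_optimizer_step(step: int, num_steps: int, accum_steps: int) -> bool:
    """Step after every accum_steps micro-batches, and always on the epoch's last batch."""
    return (step + 1) % accum_steps == 0 or (step + 1) == num_steps


def optimizer_step_indices(num_steps: int, accum_steps: int) -> list[int]:
    """0-based step indices at which the optimizer would update."""
    return [s for s in range(num_steps) if should_optimizer_step(s, num_steps, accum_steps)]


print(optimizer_step_indices(num_steps=10, accum_steps=4))
```

The last-batch clause keeps leftover gradients from leaking into the next epoch when num_steps is not a multiple of accum_steps; with accum_steps = 1 every step updates, which is plain training.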

16. Loss path

The current config uses:

MONAI DiceCELoss

Build path:

cfg.loss
  |
  v
lib/tools/loss.py::build_loss
  |
  v
DiceCELoss(sigmoid=True, include_background=True, lambda_dice=0.7, lambda_ce=0.3)

If loss: null, the trainer falls back to:

torch.nn.functional.binary_cross_entropy_with_logits(seg_logits, mask)

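This fallback is meant for smoke tests when MONAI is temporarily missing from the environment; it is not recommended as the default for formal (paper) training runs.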
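The sketch below writes out the weighted Dice + BCE objective for a single-channel binary mask flattened into one vector, using the 0.7 / 0.3 weights from the config. It is a hand-written illustration, not MONAI's DiceCELoss: MONAI reduces the Dice term per sample and channel and has its own smoothing options, so values will not match exactly. bce_with_logits computes the same quantity as the mean-reduced binary_cross_entropy_with_logits fallback. All function names are hypothetical.

```python
import math


def stable_sigmoid(x: float) -> float:
    """Sigmoid that does not overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def bce_with_logits(logits, targets) -> float:
    """Mean of max(x, 0) - x*t + log(1 + exp(-|x|)), the stable BCE-with-logits form."""
    terms = [max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
             for x, t in zip(logits, targets)]
    return sum(terms) / len(terms)


def soft_dice_loss(logits, targets, smooth: float = 1e-5) -> float:
    """1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth), with p = sigmoid(logits)."""
    probs = [stable_sigmoid(x) for x in logits]
    inter = sum(p * t for p, t in zip(probs, targets))
    denom = sum(probs) + sum(targets)
    return 1.0 - (2.0 * inter + smooth) / (denom + smooth)


def dice_ce_sketch(logits, targets, lambda_dice: float = 0.7, lambda_ce: float = 0.3) -> float:
    """Weighted sum mirroring lambda_dice / lambda_ce in the loss config."""
    return (lambda_dice * soft_dice_loss(logits, targets)
            + lambda_ce * bce_with_logits(logits, targets))


print(dice_ce_sketch([8.0, 8.0, -8.0, -8.0], [1, 1, 0, 0]))  # confident and correct: small
```

The Dice term depends on overlap ratios and is insensitive to how much background the image contains, while the BCE term gives every pixel a gradient; weighting Dice higher favors region overlap on small lesions.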

17. Validation path

Validation function:

SupervisedSegmentationTrainer._validate()

Flow:

model.eval()
build validation metrics
for batch in val_loader:
  image -> device
  mask -> device
  outputs, losses = _compute_losses(image, mask)
  update loss sums
  update metrics with outputs["seg_logits"]

average val loss
compute Dice / IoU
reset metric states
return val_metrics

Metric input handling:

binary mode:
  pred = sigmoid(logits) >= threshold
  target = target > 0

multiclass mode:
  pred = argmax(logits)
  target = one-hot or class index

Current defaults:

threshold = 0.5
metrics = Dice, IoU
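A per-sample sketch of the binary path above. Comparing logits against logit(threshold) is equivalent to comparing sigmoid(logits) against threshold and avoids overflow. How the trainer aggregates over batches (per-image mean or global counts) and how it scores images whose prediction and target are both empty is not stated here. empty_score is a placeholder for that convention, which matters for empty masks such as those in BUSI's normal folder. binary_dice_iou is a hypothetical name.

```python
import math


def binary_dice_iou(logits, targets, threshold: float = 0.5, empty_score: float = 1.0):
    """Dice and IoU for one flattened sample after thresholding sigmoid(logits)."""
    cut = math.log(threshold / (1.0 - threshold))  # sigmoid(x) >= t  <=>  x >= logit(t)
    pred = [x >= cut for x in logits]
    gt = [t > 0 for t in targets]                  # target = target > 0
    inter = sum(p and g for p, g in zip(pred, gt))
    p_sum, g_sum = sum(pred), sum(gt)
    if p_sum + g_sum == 0:
        return empty_score, empty_score            # both empty: convention is an assumption
    dice = 2.0 * inter / (p_sum + g_sum)
    iou = inter / (p_sum + g_sum - inter)
    return dice, iou


print(binary_dice_iou([2.0, -2.0, 2.0, -2.0], [1, 0, 0, 0]))
```

For a single sample the two metrics are tied by IoU = Dice / (2 - Dice), so they rank individual predictions identically; averaged over a dataset they can still differ in emphasis.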

18. Checkpoint path

Checkpoint directory:

cfg.checkpoint.dir

The default script overrides it with:

outputs/experiments/supervised/<DATASET>

Saved files:

best.pth
last.pth

Checkpoint contents:

epoch
cfg
metrics
model state_dict
optimizer state_dict
scheduler state_dict
grad_scaler state_dict
best_metric
no_improve_epochs

Best-checkpoint criterion:

monitor = dice
monitor_mode = max

That is:

higher val_dice is better
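The best-checkpoint and early-stopping logic can be sketched as a small tracker whose fields mirror best_metric and no_improve_epochs in the checkpoint. BestTracker is hypothetical, and patience and min_delta are placeholders because the document does not give their values.

```python
class BestTracker:
    """Track the monitored metric (mode='max' for dice) and epochs without improvement."""

    def __init__(self, mode: str = "max", patience: int = 20, min_delta: float = 0.0):
        if mode not in ("max", "min"):
            raise ValueError(f"unknown mode: {mode!r}")
        self.mode, self.patience, self.min_delta = mode, patience, min_delta
        self.best_metric = None
        self.no_improve_epochs = 0

    def update(self, value: float) -> bool:
        """Return True if value is a new best, i.e. best.pth should be saved."""
        if self.best_metric is None:
            improved = True
        elif self.mode == "max":
            improved = value > self.best_metric + self.min_delta
        else:
            improved = value < self.best_metric - self.min_delta
        if improved:
            self.best_metric = value
            self.no_improve_epochs = 0
        else:
            self.no_improve_epochs += 1
        return improved

    def should_stop(self) -> bool:
        """Early stopping fires after `patience` epochs without improvement (0 disables it)."""
        return self.patience > 0 and self.no_improve_epochs >= self.patience


tracker = BestTracker(mode="max", patience=2)
for val_dice in (0.70, 0.72, 0.71, 0.72):
    print(val_dice, tracker.update(val_dice), tracker.should_stop())
```

Saving best_metric and no_improve_epochs in the checkpoint lets a resumed run continue the patience count instead of restarting it.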

19. Logging and performance fields

Printed every logging.log_interval steps:

epoch
step
num_steps
data_time
iter_time
gpu_memory_mb
lr
train_total
train_seg
train_grad_norm

Meaning:

data_time:
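  Time from the end of the previous step until the current batch is available.
  With num_workers > 0, the one-time worker startup cost for the first batch lands after END TRAINING SETUP.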

iter_time:
  Compute time of the current training step, including forward, loss, backward, and optimizer step.

gpu_memory_mb:
  torch.cuda.max_memory_allocated() (peak allocated memory), reported in MB.
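A minimal sketch of how data_time and iter_time can be measured to match the definitions above. timed_steps is hypothetical, and the exact placement of the timers in the trainer is not shown in this document. On GPU, train_step should call torch.cuda.synchronize() before returning; otherwise asynchronous kernel launches make iter_time look shorter than the real compute time.

```python
import time


def timed_steps(loader, train_step):
    """Yield (data_time, iter_time) for each batch drawn from loader."""
    end = time.perf_counter()
    for batch in loader:                       # blocks until the next batch is ready
        data_time = time.perf_counter() - end  # wait since the previous step finished
        start = time.perf_counter()
        train_step(batch)                      # forward + loss + backward + optimizer step
        iter_time = time.perf_counter() - start
        end = time.perf_counter()
        yield data_time, iter_time


def fake_loader(n, delay=0.01):
    """Stand-in for a DataLoader: sleeps to simulate batch loading."""
    for i in range(n):
        time.sleep(delay)
        yield i


for data_time, iter_time in timed_steps(fake_loader(3), lambda batch: time.sleep(0.005)):
    print(f"data_time={data_time:.3f}s iter_time={iter_time:.3f}s")
```

If data_time stays comparable to iter_time after the first step, the input pipeline rather than the model is the bottleneck, and num_workers or prefetch_factor is the knob to try.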

Current measured reference:

batch_size = 8
image_size = 256
ssm_backend = auto -> oflex
iter_time ≈ 0.09 - 0.11 s / step
GPU memory ≈ 850 MB

20. End-to-end flowchart from input to loss

Batch from DataLoader
  |
  +-- image [B,3,256,256]
  +-- mask  [B,1,256,256]
  |
  v
image.to(cuda), mask.to(cuda)
  |
  v
autocast(enabled=True)
  |
  v
XNet2d(image)
  |
  +-- encoder_features = [E1,E2,E3,E4]
  |
  +-- bottleneck(E4)
  |
  +-- decoder_out, decoder_features, guides=[]
  |
  +-- segmentation_head(decoder_out)
  |
  v
seg_logits [B,1,256,256]
  |
  v
DiceCELoss(seg_logits, mask)
  |
  v
total_loss
  |
  v
GradScaler.scale(total_loss).backward()
  |
  v
grad_scaler.unscale_(optimizer), then clip gradients
  |
  v
grad_scaler.step(optimizer), grad_scaler.update()

21. Key commands

GPU environment check:

python -c "import sys, torch; print(sys.executable); print(torch.__version__); print(torch.cuda.is_available()); print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'no cuda')"

oflex check:

python -c "import torch; import selective_scan_cuda_oflex; print('oflex import OK')"
python -c "import torch; from lib.modules.lib_mamba import csms6s; print(csms6s.WITH_SELECTIVESCAN_OFLEX)"

Forward-pass check:

python - <<'PY'
import torch
from lib.modules import XNet2d

model = XNet2d(in_channels=3, num_classes=1, ssm_backend="auto", ssm_forward_type="v3").cuda().eval()
x = torch.randn(1, 3, 128, 128, device="cuda")
with torch.no_grad():
    y = model(x)
print(sorted(y.keys()))
print(tuple(y["seg_logits"].shape))
PY

Short training run:

DATASET=BUSI \
EXTRA_SET_ARGS="train.epochs=1 train.batch_size=8 train.val_batch_size=8 logging.use_swanlab=false checkpoint.dir=outputs/validation/xnet_oflex_b8" \
bash tools/run_optimized_supervised.sh

Ablation with frequency refine disabled:

DATASET=BUSI \
EXTRA_SET_ARGS="train.epochs=1 train.batch_size=8 train.val_batch_size=8 model.use_frequency_refine=false logging.use_swanlab=false checkpoint.dir=outputs/validation/xnet_oflex_b8_no_freq" \
bash tools/run_optimized_supervised.sh

Summarize results:

bash tools/summarize_results.sh
sed -n '1,40p' results/experiment_summary.md

22. Recommended experiment roadmap

Phase 1: training pipeline stability

BUSI smoke
BUSI batch size 8
BUSI no frequency refine

Phase 2: thyroid mainline

DDTI
TN3K
TG3K
DDTI -> TN3K / TN3K -> DDTI cross-dataset generalization

Phase 3: breast extension

BUSI
BUS_UC
BUS-BRA
BUS-UCLM

Phase 4: core ablations

use_wavelet_branch=false
use_frequency_refine=false
ssm_backend=torch
use_global_branch_stage1=true
encoder_depths=[2,2,3,2]

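23. Current scope and caveats

  1. This document describes the active XNet2d fully supervised mainline.
  2. The current training mainline optimizes only seg_logits.
  3. lib/sam2 is kept but not used in training.
  4. lib/SwinTransformer is kept but not used in training.
  5. On CUDA, ssm_backend=auto should resolve to oflex; this is the default path after the current speed optimizations.
  6. XFrequencyRefine2d runs its FFT in float32 to avoid the ComplexHalf warning under AMP.
  7. With num_workers > 0, the first dataloader iteration may incur a one-time wait after END TRAINING SETUP.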