The current active mainline of X_SSL_Net is a fully supervised training project for 2D ultrasound image segmentation.
Current actual training pipeline:
shell script
-> tools/train.py
-> SupervisedSegmentationTrainer
-> SegmentationRecordDataset / DataLoader
-> XNet2d
-> seg_logits
-> DiceCE loss / BCE fallback
-> Dice / IoU validation
-> best.pth / last.pth
Current actual model mainline:
XNet2d = CNN-Wavelet-VMamba encoder + plain U-Net skip decoder segmentation network
The main training run currently uses a single segmentation head:
outputs["seg_logits"]
The current mainline does not call:
lib/sam2
lib/SwinTransformer
lib/sam2 and lib/SwinTransformer are currently kept as external code assets and are not part of the current training path.
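The project can be summarized as follows:
Fully supervised segmentation training with XNet2d on 2D ultrasound datasets such as BUSI / DDTI / TN3K / TG3K.
The XNet2d encoder models features with three branches: local, wavelet, and VMamba-style SS2D.
The decoder restores the mask with plain U-Net same-scale skips plus frequency refinement.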
Most common entry point:
DATASET=BUSI bash tools/run_optimized_supervised.sh
Entry point for a short debugging run:
DATASET=BUSI \
EXTRA_SET_ARGS="train.epochs=1 train.batch_size=8 train.val_batch_size=8 logging.use_swanlab=false checkpoint.dir=outputs/validation/xnet_oflex_b8" \
bash tools/run_optimized_supervised.sh
File:
tools/run_optimized_supervised.sh
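It does four things:
1. Reads DATASET (together with SEED and EXTRA_SET_ARGS).
2. Resolves the dataset root directory.
3. Generates the project-level train/val split when the dataset needs one.
4. Launches tools/train.py.
Supported dataset names: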
BUS-UCLM
BUSI
BUS-BRA
BUS_UC
CCAUI
DDTI
OTU_2d
TN3K
TG3K
Dataset root mapping:
BUSI -> data/BUSI
DDTI -> data/DDTI
TN3K -> data/TN3K
TG3K -> data/TG3K
BUS_UC -> data/BUS_UC
...
Datasets that use a project-level split:
BUS-UCLM, BUSI, BUS-BRA, BUS_UC, CCAUI, DDTI
Datasets that use the official split:
OTU_2d, TN3K, TG3K
User command
|
| DATASET=BUSI EXTRA_SET_ARGS="..." bash tools/run_optimized_supervised.sh
v
+----------------------------------------------------------------------------------+
| tools/run_optimized_supervised.sh |
+----------------------------------------------------------------------------------+
| 1. read DATASET / SEED / EXTRA_SET_ARGS |
| 2. dataset_root(DATASET) |
| 3. if DATASET needs project split: |
| python scripts/generate_project_split.py --dataset DATASET --root ROOT |
| 4. python tools/train.py |
| --config configs/segmentation/optimized/*.yaml |
| --set dataset.dataset_name=DATASET dataset.root=ROOT ... EXTRA_SET_ARGS |
+----------------------------------------------------------------------------------+
|
v
+----------------------------------------------------------------------------------+
| tools/train.py |
+----------------------------------------------------------------------------------+
| 1. parse --config / --trainer / --set |
| 2. load yaml config |
| 3. apply dotlist overrides |
| 4. optional override trainer.name |
| 5. build_trainer(cfg) |
| 6. trainer.train() |
+----------------------------------------------------------------------------------+
Current main config:
configs/segmentation/optimized/*.yaml
Segmentation configs currently kept:
configs/segmentation/optimized/*.yaml
configs/segmentation/us_exp_sup_busi.yaml
configs/segmentation/us_exp_sup_busi_ablation.yaml
tools/train.py supports:
--set key=value key=value ...
For example:
--set train.epochs=1 train.batch_size=8 model.use_frequency_refine=false
Override logic:
load_yaml_config(path)
|
v
apply_dotlist_overrides(cfg, args.set)
|
v
nested dict update
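The override step above can be sketched as a small nested-dict update. This is a minimal illustration, not the repository's code: the function names `apply_dotlist_overrides_sketch` and `parse_value` and the value-parsing rule (booleans, null, Python literals, else plain string) are assumptions.

```python
import ast
from typing import Any


def parse_value(text: str) -> Any:
    """Turn an override string into a Python value (assumed parsing rule)."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    if lowered in ("null", "none"):
        return None
    try:
        return ast.literal_eval(text)  # numbers, lists such as [2,2,3,2]
    except (ValueError, SyntaxError, TypeError):
        return text  # anything else stays a plain string, e.g. a path


def apply_dotlist_overrides_sketch(cfg: dict, overrides: list) -> dict:
    """Apply 'a.b.c=value' items to a nested dict in place and return it."""
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"override must look like key=value: {item!r}")
        *parents, leaf = key.split(".")
        node = cfg
        for name in parents:
            node = node.setdefault(name, {})  # create missing sections
        node[leaf] = parse_value(raw)
    return cfg


if __name__ == "__main__":
    cfg = {"train": {"epochs": 200, "batch_size": 4}}
    apply_dotlist_overrides_sketch(cfg, ["train.epochs=1", "train.batch_size=8"])
    print(cfg)
```

Later items win over earlier ones, so values appended through EXTRA_SET_ARGS would take precedence over the script's own settings under this scheme.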
Training:
train:
epochs: 200
batch_size: 4
val_batch_size: 4
amp: true
num_workers: 4
pin_memory: true
persistent_workers: true
prefetch_factor: 2
device: cuda
grad_clip:
enabled: true
max_norm: 1.0
Data:
dataset:
dataset_name: BUSI
root: data/BUSI
split: train
val_split: val
image_size: [256, 256]
in_channels: 3
num_classes: 1
Model:
model:
in_channels: 3
encoder_channels: [32, 64, 128, 192]
encoder_depths: [2, 2, 2, 2]
decoder_channels: [128, 64, 32]
stem_channels: 24
bottleneck_depth: 1
global_ratio: 2.0
wavelet_type: haar
wavelet_level: 1
use_wavelet_branch: true
use_global_branch_stage1: false
ssm_d_state: 16
ssm_forward_type: v3
ssm_backend: auto
use_frequency_refine: true
guide_mode: affine # compatibility only; current decoder no longer uses guide path
out_channels: null
Optimization:
optimizer:
name: adamw
lr: 1.0e-4
weight_decay: 0.05
scheduler:
name: cosine
warmup:
name: linear
params:
start_factor: 0.1
total_iters: 10
params:
T_max: 190
eta_min: 1.0e-6
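To make the schedule concrete, the sketch below evaluates the learning rate per epoch in closed form. It assumes the linear warmup runs first for 10 epochs and the cosine phase follows for T_max=190 epochs (10 + 190 matches train.epochs=200), with one scheduler step per epoch. How the repository actually chains the two schedulers is not shown here, so treat `lr_at_epoch` as a hypothetical helper.

```python
import math


def lr_at_epoch(epoch, base_lr=1.0e-4, start_factor=0.1, warmup_iters=10,
                t_max=190, eta_min=1.0e-6):
    """Closed-form learning rate for linear warmup followed by cosine decay."""
    if epoch < warmup_iters:
        # Linear ramp from start_factor * base_lr up to base_lr.
        factor = start_factor + (1.0 - start_factor) * epoch / warmup_iters
        return base_lr * factor
    # Cosine annealing from base_lr down to eta_min over t_max epochs.
    t = min(epoch - warmup_iters, t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * t / t_max))


if __name__ == "__main__":
    for epoch in (0, 5, 10, 105, 199):
        print(epoch, f"{lr_at_epoch(epoch):.3e}")
```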
Loss and metrics:
loss:
name: dicece
task_mode: binary
params:
include_background: true
lambda_dice: 0.7
lambda_ce: 0.3
validation:
threshold: 0.5
metrics:
task_mode: binary
metrics:
- name: dice
- name: iou
Entry point:
lib/trainers/builder.py::build_trainer
Current trainer:
lib/trainers/supervised.py::SupervisedSegmentationTrainer
Build flow:
build_trainer(cfg)
|
v
read cfg.trainer.name
|
v
TRAINER_REGISTRY["supervised_segmentation"]
|
v
trainer = SupervisedSegmentationTrainer(cfg, args)
|
v
trainer.build()
|
v
return trainer
SupervisedSegmentationTrainer.build() does the following:
1. dataset_cfg = cfg["dataset"]
2. model_cfg = cfg["model"]
3. train_cfg = cfg["train"]
4. build XNet2d from model_cfg
5. move model to device
6. build optimizer
7. build scheduler
8. build loss if cfg.loss is not null
9. build train dataloader
10. build validation dataloader
11. maybe resume checkpoint
12. maybe init SwanLab
File:
lib/trainers/base.py
Shared responsibilities:
BaseTrainer
├─ random seed
├─ device selection
├─ output directory
├─ AMP GradScaler
├─ batch size resolution
├─ dataloader construction helper
├─ validation metric construction
├─ checkpoint save / resume
├─ early stopping
├─ SwanLab logging
├─ training setup summary
├─ step performance logging
└─ epoch finalization
Device selection:
cfg.train.device == "cuda" and torch.cuda.is_available()
-> cuda
else
-> cpu
AMP switch:
cfg.train.amp == true and device == cuda
-> enabled
else
-> disabled
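The two decision tables above reduce to a pair of boolean checks. The sketch below is a hypothetical helper, not the BaseTrainer code; `cuda_available` is passed in explicitly so the logic can be read and tested without a GPU (in practice it would come from `torch.cuda.is_available()`).

```python
def resolve_device_and_amp(requested_device, amp_requested, cuda_available):
    """Return (device, amp_enabled) following the tables above."""
    use_cuda = requested_device == "cuda" and cuda_available
    device = "cuda" if use_cuda else "cpu"
    # AMP is only turned on when the run really ends up on CUDA.
    amp_enabled = bool(amp_requested) and device == "cuda"
    return device, amp_enabled


if __name__ == "__main__":
    print(resolve_device_and_amp("cuda", True, cuda_available=False))
```

The CPU fallback silently disables AMP, so a config with amp: true still runs on a machine without CUDA.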
Verified target environment:
conda env: xnet_mamba
torch: 2.10.0+cu126
GPU: NVIDIA GeForce RTX 4070 Ti SUPER
selective_scan_cuda_oflex: available
Entry point:
lib/data/builder.py::build_dataset_index
Core registry:
BUS-UCLM -> paired images/masks
BUSI -> Dataset_BUSI_with_GT/{benign,malignant,normal}
BUS-BRA -> prefixed image/mask matching
BUS_UC -> All / Benign / Malignant folders
CCAUI -> US images / Expert mask images
DDTI -> XML annotation records
OTU_2d -> images / annotations
TN3K -> trainval/test image/mask folders
TG3K -> thyroid-image / thyroid-mask
Entry point:
lib/data/loaders.py::apply_official_split
Flow:
build_dataset_index(dataset_name, root)
|
v
if split is requested:
|
+-- OTU_2d: read train.txt / val.txt
|
+-- TN3K: read tn3k-trainval.json or use test folder
|
+-- TG3K: read tg3k-trainval.json
|
+-- project split dataset:
read data/<dataset>/splits/project/train.txt or val.txt
Project-level split generation:
scripts/generate_project_split.py
|
v
generate_project_splits()
|
v
write:
data/<dataset>/splits/project/train.txt
data/<dataset>/splits/project/val.txt
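A seeded, reproducible train/val split of this kind can be sketched as follows. The file format (one sample id per line), the validation ratio, the seed, and the helper names `write_project_split` and `read_split` are all assumptions for illustration; the real generate_project_splits() may differ.

```python
import random
from pathlib import Path


def write_project_split(sample_ids, split_dir, val_ratio=0.2, seed=0):
    """Shuffle ids with a fixed seed and write train.txt / val.txt."""
    ids = sorted(sample_ids)            # stable order before shuffling
    random.Random(seed).shuffle(ids)    # local RNG, global state untouched
    n_val = max(1, round(len(ids) * val_ratio)) if len(ids) > 1 else 0
    val_ids, train_ids = sorted(ids[:n_val]), sorted(ids[n_val:])
    out = Path(split_dir)
    out.mkdir(parents=True, exist_ok=True)
    (out / "train.txt").write_text("\n".join(train_ids) + "\n", encoding="utf-8")
    (out / "val.txt").write_text("\n".join(val_ids) + "\n", encoding="utf-8")
    return train_ids, val_ids


def read_split(split_dir, split):
    """Read the sample ids of one split, skipping blank lines."""
    text = (Path(split_dir) / f"{split}.txt").read_text(encoding="utf-8")
    return [line.strip() for line in text.splitlines() if line.strip()]
```

Sorting before the seeded shuffle keeps the result independent of filesystem listing order, which is what makes the split repeatable across machines.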
File:
lib/data/datasets.py::SegmentationRecordDataset
Reading a single sample:
record
|
+-- image_path -> PIL RGB -> float32 [3,H,W] in [0,1]
|
+-- mask_path -> PIL L -> binary float32 [1,H,W]
|
+-- DDTI special:
annotation_path XML -> build_ddti_mask() -> binary [1,H,W]
|
+-- joint augmentation
|
+-- resize image to dataset.image_size
|
+-- resize mask to dataset.image_size
|
v
{
"image": image,
"mask": mask,
"dataset_name": ...,
"sample_id": ...,
"split": ...,
"class_name": ...,
"meta": ...
}
File:
lib/data/augment.py::SegmentationAugmentation
Currently supported:
spatial:
random horizontal flip
random vertical flip
random rotate 90
intensity:
random brightness / contrast
random gaussian noise
clamp to [0,1]
File:
lib/data/collate.py::record_collate_fn
Logic:
if all tensor shapes same:
torch.stack(values, dim=0)
else:
keep list
strings / dict / metadata:
keep list
Final batch:
image: [B,3,256,256]
mask : [B,1,256,256]
SupervisedSegmentationTrainer.build()
|
v
_build_segmentation_loader(split="train")
|
v
build_dataloader()
|
v
build_record_dataset()
|
v
build_dataset_index()
|
v
apply_official_split()
|
v
SegmentationRecordDataset(records, transforms)
|
v
DataLoader(
batch_size,
shuffle,
num_workers,
pin_memory,
persistent_workers,
prefetch_factor,
collate_fn=record_collate_fn
)
Note: DataLoader workers usually start on the first iteration, that is, after ======== END TRAINING SETUP ========. When num_workers > 0, the first batch may incur a one-time wait.
File:
lib/modules/xnet_2d.py
Parameter count with the current defaults:
total parameters: 9,432,129
trainable parameters: 9,432,129
Top-level structure:
XNet2d
├─ XNetEncoder2d
│ ├─ XNetStem2d
│ ├─ Encoder Stage 1: XTEB2d x 2
│ ├─ Downsample 1
│ ├─ Encoder Stage 2: XTEB2d x 2
│ ├─ Downsample 2
│ ├─ Encoder Stage 3: XTEB2d x 2
│ ├─ Downsample 3
│ └─ Encoder Stage 4: XTEB2d x 2
│
├─ Bottleneck: XTEB2d x 1
│
├─ XNetDecoder2d
│ ├─ dec4: XCRB2d(E4, E3)
│ ├─ dec3: XCRB2d(D4, E2)
│ ├─ dec2: XCRB2d(D3, E1)
│ └─ head_refine
│
└─ XNetSegHead2d
Taking an input of [B,3,256,256] as an example, with the default channels [32,64,128,192]:
Input
[B, 3, 256, 256]
|
v
XNetStem2d
Conv3x3 s2: [B, 24, 128, 128]
DWConv3x3: [B, 24, 128, 128]
PWConv1x1: [B, 32, 128, 128]
Conv3x3 s2: [B, 32, 64, 64]
|
v
E1 = Encoder Stage 1, XTEB x2
[B, 32, 64, 64]
|
v
Down1
[B, 64, 32, 32]
|
v
E2 = Encoder Stage 2, XTEB x2
[B, 64, 32, 32]
|
v
Down2
[B, 128, 16, 16]
|
v
E3 = Encoder Stage 3, XTEB x2
[B, 128, 16, 16]
|
v
Down3
[B, 192, 8, 8]
|
v
E4 = Encoder Stage 4, XTEB x2
[B, 192, 8, 8]
|
v
Bottleneck XTEB x1
[B, 192, 8, 8]
Decoder:
E4 [B,192,8,8]
|
v
dec4 input:
decoder input: E4 [B,192,8,8]
same-scale skip: E3 [B,128,16,16]
output D4 [B,128,16,16]
D4 [B,128,16,16]
|
v
dec3 input:
decoder input: D4 [B,128,16,16]
same-scale skip: E2 [B,64,32,32]
output D3 [B,64,32,32]
D3 [B,64,32,32]
|
v
dec2 input:
decoder input: D3 [B,64,32,32]
same-scale skip: E1 [B,32,64,64]
output D2 [B,32,64,64]
D2 [B,32,64,64]
|
v
HeadRefine
[B,32,64,64]
|
v
SegHead + upsample to input size
[B,1,256,256]
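The shape walkthrough above can be reproduced with plain arithmetic. The sketch below only tracks (channels, height, width) and assumes every stride-2 layer halves the resolution exactly, which holds when the input side is divisible by 32; `trace_shapes` is a hypothetical helper and does not touch the real XNet2d.

```python
def trace_shapes(height=256, width=256, stem_channels=24,
                 encoder_channels=(32, 64, 128, 192),
                 decoder_channels=(128, 64, 32), num_classes=1):
    """Return per-stage (C, H, W) shapes for the default XNet2d layout."""
    shapes = {}
    h, w = height // 2, width // 2              # first stem conv, stride 2
    shapes["stem_conv1"] = (stem_channels, h, w)
    h, w = h // 2, w // 2                       # second stem conv, stride 2
    shapes["E1"] = (encoder_channels[0], h, w)
    for stage, channels in enumerate(encoder_channels[1:], start=2):
        h, w = h // 2, w // 2                   # Down1 / Down2 / Down3
        shapes[f"E{stage}"] = (channels, h, w)
    shapes["bottleneck"] = shapes[f"E{len(encoder_channels)}"]
    # Each decoder block upsamples to the resolution of its same-scale skip.
    skip_stages = range(len(encoder_channels) - 1, 0, -1)
    for channels, skip in zip(decoder_channels, skip_stages):
        _, skip_h, skip_w = shapes[f"E{skip}"]
        shapes[f"D{skip + 1}"] = (channels, skip_h, skip_w)
    shapes["seg_logits"] = (num_classes, height, width)
    return shapes


if __name__ == "__main__":
    for name, shape in trace_shapes().items():
        print(f"{name:11s} {shape}")
```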
XTEB2d is the basic block of the encoder.
Meaning of the name:
XTEB = XNet Tri-branch Encoding Block
Input and output:
input : X [B,C,H,W]
output: Y [B,C,H,W]
Internal structure:
X
│
├─ pre_norm: 1x1 Conv2dBN
│
├─ Local branch
│ ├─ DWConv3x3 + PWConv1x1
│ └─ DWConv5x5 + PWConv1x1
│
├─ Wavelet branch
│ ├─ Haar DWT
│ │ ├─ LL
│ │ └─ LH/HL/HH high bands
│ ├─ LL projection
│ ├─ high-band projection
│ └─ inverse Haar transform
│
├─ Global branch
│ ├─ 1x1 pre projection
│ ├─ VMamba-style SS2D
│ └─ 1x1 post projection
│
├─ concat(local, wavelet, global)
├─ 1x1 fusion
├─ channel gate from GAP + MLP + sigmoid
├─ residual add
└─ lightweight FFN + residual add
As formulas:
X0 = PreNorm(X)
L = Local(X0)
W = Wavelet(X0)
G = GlobalSS2D(X0)
F = Fuse([L,W,G])
Y = X + Post(F)
Z = Y + FFN(Y)
Local branch responsibilities:
Local texture, boundaries, and short-range structure
Structure:
DWConv3x3 -> ReLU -> PWConv1x1
DWConv5x5 -> ReLU -> PWConv1x1
sum
Wavelet branch responsibilities:
Low-frequency contours plus high-frequency boundaries and texture
Structure:
Haar DWT:
LL -> low-frequency structure
LH/HL/HH -> high-frequency directional details
LL -> Conv projection
High bands -> depthwise conv + pointwise conv
IDWT -> output projection
Current limitations:
wavelet_type = haar
wavelet_level = 1
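Global branch responsibilities:
Efficient long-range dependency modeling and global structural consistency
Current implementation: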
lib/modules/lib_mamba/vmamba.py::SS2D
Source:
VMamba-style SS2D operator
Backend selection:
ssm_backend = auto
|
+-- if x.is_cuda:
selective_scan_backend = oflex
scan_force_torch = false
|
+-- else:
selective_scan_backend = torch
scan_force_torch = true
ssm_backend = oflex
-> force oflex
ssm_backend = torch
-> force torch fallback
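The selection rules above amount to a small lookup. The function below is an illustrative sketch with a hypothetical name, `resolve_ssm_backend`; it returns the pair (selective_scan_backend, scan_force_torch) named in the diagram and takes `is_cuda` as a plain boolean instead of inspecting a tensor.

```python
def resolve_ssm_backend(ssm_backend, is_cuda):
    """Map the ssm_backend setting to (selective_scan_backend, scan_force_torch)."""
    if ssm_backend == "auto":
        # auto picks the CUDA kernel on GPU tensors and the torch path otherwise.
        ssm_backend = "oflex" if is_cuda else "torch"
    if ssm_backend == "oflex":
        return "oflex", False
    if ssm_backend == "torch":
        return "torch", True
    raise ValueError(f"unknown ssm_backend: {ssm_backend!r}")


if __name__ == "__main__":
    print(resolve_ssm_backend("auto", is_cuda=True))
```

Note that forcing oflex on a CPU tensor is passed through unchanged here; whether the real code guards against that combination is not stated in this document.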
Current defaults:
ssm_forward_type = v3
ssm_backend = auto
Verified in the xnet_mamba + RTX 4070 Ti SUPER environment:
selective_scan_cuda_oflex import OK
WITH_SELECTIVESCAN_OFLEX = True
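XCRB2d is the basic block of the decoder. The current implementation has dropped the old X-shaped diagonal guide and reverted to plain U-Net same-scale skip connections.
Meaning of the name: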
XCRB = XNet Reconstruction Block
Inputs:
decoder input: deeper decoder or bottleneck feature
same-scale skip: encoder feature at target scale
Internal structure:
decoder input
|
v
bilinear upsample to skip size
|
v
1x1 projection
|
+-----------------------------+
|
same-scale skip |
| |
v |
1x1 projection |
| |
+----------- concat ----------+
|
v
3x3 fusion
|
v
optional frequency refine
|
v
residual spatial refine
The current decoder uses plain U-Net lateral skips:
same-scale path:
E3 -> D4
E2 -> D3
E1 -> D2
Plain-text sketch:
Encoder: E1 ---------------------------> D2
Encoder: E2 ---------------------------> D3
Encoder: E3 ---------------------------> D4
Encoder: E4 -> Bottleneck -> decoder up path
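Configs may still contain guide_mode=affine, but the current decoder no longer builds XGuideProjector2d or XGuideModulation2d. The parameter exists only for compatibility with old YAML files and takes no part in the forward pass.
use_frequency_refine=true by default.
Flow: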
feature F
|
v
cast to float32 if needed
|
v
rfft2
|
+-- low frequency mask
|
+-- high frequency residual
|
v
low/high learnable gates
|
v
irfft2
|
v
cast back to input dtype
|
v
depthwise conv refine
The FFT computation is explicitly kept in float32 here to avoid triggering the "ComplexHalf support is experimental" warning under AMP.
XNet2d.forward(x) returns:
{
"logits": logits,
"seg_logits": logits,
"encoder_features": encoder_features,
"decoder_features": decoder_features,
"guides": [],
}
Training uses only:
outputs["seg_logits"]
The remaining outputs are for:
debug
visualization
future auxiliary analysis
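There is currently no boundary auxiliary output.
There is currently no diagonal guide output; guides stays an empty list for compatibility with the old debugging interface.
Entry point: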
SupervisedSegmentationTrainer.train()
Flow:
train()
|
v
print training setup
|
v
for epoch in range(start_epoch, epochs):
|
+-- model.train()
+-- optimizer.zero_grad()
+-- for step, batch in train_loader:
|
+-- measure data_time
|
+-- image = batch["image"].to(device)
+-- mask = batch["mask"].to(device)
|
+-- with autocast(enabled=amp):
outputs = model(image)
seg_logits = outputs["seg_logits"]
seg_loss = loss(seg_logits, mask)
total_loss = seg_loss
|
+-- scaled_total_loss = total_loss / accum_steps
+-- grad_scaler.scale(scaled_total_loss).backward()
|
+-- if should optimizer step:
unscale gradients if grad clipping enabled
clip grad norm
grad_scaler.step(optimizer)
grad_scaler.update()
optimizer.zero_grad()
|
+-- log step every logging.log_interval
|
+-- scheduler.step()
|
+-- validate if enabled and interval matches
|
+-- finalize epoch
|
+-- merge train / val metrics
+-- update best metric
+-- save best.pth if improved
+-- save last.pth if enabled
+-- early stopping check
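The "if should optimizer step" branch in the loop above is where gradient accumulation is decided. A minimal sketch of one plausible rule follows: step at the end of every accum_steps window, and also on the last batch so that leftover gradients are not dropped. The last-batch flush and both helper names are assumptions, not confirmed behavior of the trainer.

```python
def should_optimizer_step(step_index, num_steps, accum_steps):
    """True when the optimizer should step after this 0-based batch index."""
    window_complete = (step_index + 1) % accum_steps == 0
    last_batch = step_index + 1 == num_steps
    return window_complete or last_batch


def count_optimizer_steps(num_steps, accum_steps):
    """Number of optimizer updates performed in one epoch under this rule."""
    return sum(should_optimizer_step(i, num_steps, accum_steps)
               for i in range(num_steps))


if __name__ == "__main__":
    # 10 batches with accum_steps=4: updates after batches 4, 8 and 10.
    print(count_optimizer_steps(num_steps=10, accum_steps=4))
```

Dividing each loss by accum_steps before backward, as the flow does, makes a full window's accumulated gradient equal to the mean over its batches; a shorter final window is under-weighted under this simple rule.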
The current config uses:
MONAI DiceCELoss
Build path:
cfg.loss
|
v
lib/tools/loss.py::build_loss
|
v
DiceCELoss(sigmoid=True, include_background=True, lambda_dice=0.7, lambda_ce=0.3)
If loss: null:
torch.nn.functional.binary_cross_entropy_with_logits(seg_logits, mask)
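This fallback is suitable for smoke tests when MONAI is temporarily missing from the environment; it is not recommended as the default for formal paper training.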
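To show what the weighted sum with lambda_dice=0.7 and lambda_ce=0.3 computes, and how it relates to the BCE fallback, here is a pure-Python sketch over one flattened binary mask. It is a simplification under stated assumptions: MONAI's DiceCELoss works on batched tensors with its own reduction and smoothing defaults, so values will not match it exactly, and all function names here are hypothetical.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy in the stable form max(x,0) - x*t + log(1+exp(-|x|))."""
    total = sum(max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
                for x, t in zip(logits, targets))
    return total / len(logits)


def soft_dice_loss(logits, targets, smooth=1e-5):
    """1 - soft Dice coefficient computed on sigmoid probabilities."""
    probs = [sigmoid(x) for x in logits]
    intersection = sum(p * t for p, t in zip(probs, targets))
    denominator = sum(probs) + sum(targets)
    return 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)


def dice_ce(logits, targets, lambda_dice=0.7, lambda_ce=0.3):
    """Weighted sum of the Dice term and the cross-entropy term."""
    return (lambda_dice * soft_dice_loss(logits, targets)
            + lambda_ce * bce_with_logits(logits, targets))


if __name__ == "__main__":
    print(dice_ce([4.0, 2.0, -3.0, -1.0], [1, 1, 0, 0]))
```

Setting lambda_dice=0 and lambda_ce=1 in this sketch reduces it to the plain BCE-with-logits fallback.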
Validation function:
SupervisedSegmentationTrainer._validate()
Flow:
model.eval()
build validation metrics
for batch in val_loader:
image -> device
mask -> device
outputs, losses = _compute_losses(image, mask)
update loss sums
update metrics with outputs["seg_logits"]
average val loss
compute Dice / IoU
reset metric states
return val_metrics
Metric input handling:
binary mode:
pred = sigmoid(logits) >= threshold
target = target > 0
multiclass mode:
pred = argmax(logits)
target = one-hot or class index
Current defaults:
threshold = 0.5
metrics = Dice, IoU
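The binary-mode handling above can be illustrated on a flattened mask with plain Python. This is a sketch of the standard Dice and IoU definitions, not the project's metric classes; returning 1.0 when both prediction and target are empty is an assumed convention, and the function names are hypothetical.

```python
import math


def binarize_logits(logits, threshold=0.5):
    """pred = sigmoid(logit) >= threshold, as 0/1 integers."""
    return [1 if 1.0 / (1.0 + math.exp(-x)) >= threshold else 0 for x in logits]


def dice_and_iou(pred, target):
    """Dice and IoU for two flat binary masks; target is binarized with > 0."""
    target = [1 if t > 0 else 0 for t in target]
    true_positive = sum(1 for p, t in zip(pred, target) if p and t)
    pred_sum, target_sum = sum(pred), sum(target)
    if pred_sum + target_sum == 0:
        return 1.0, 1.0  # both masks empty: treated as a perfect match (assumption)
    dice = 2.0 * true_positive / (pred_sum + target_sum)
    iou = true_positive / (pred_sum + target_sum - true_positive)
    return dice, iou


if __name__ == "__main__":
    pred = binarize_logits([3.0, 2.0, -1.0, -4.0])
    print(dice_and_iou(pred, [1, 0, 1, 0]))
```

For a single mask the two are linked by Dice = 2 * IoU / (1 + IoU), so Dice is never lower than IoU.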
Checkpoint directory:
cfg.checkpoint.dir
The default script overrides it to:
outputs/experiments/supervised/<DATASET>
Saved files:
best.pth
last.pth
Checkpoint contents:
epoch
cfg
metrics
model state_dict
optimizer state_dict
scheduler state_dict
grad_scaler state_dict
best_metric
no_improve_epochs
Best-checkpoint criterion:
monitor = dice
monitor_mode = max
That is:
higher val_dice is better
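The best-metric bookkeeping implied by monitor_mode and the no_improve_epochs field in the checkpoint can be sketched as below. `update_best` is a hypothetical helper; the actual trainer may add a minimum-delta or patience check that is not described here.

```python
def update_best(current, best, no_improve_epochs, mode="max"):
    """Return (best, no_improve_epochs, improved) after one validation round."""
    if mode not in ("max", "min"):
        raise ValueError(f"monitor_mode must be 'max' or 'min', got {mode!r}")
    if best is None:
        improved = True                      # first validation always counts
    elif mode == "max":
        improved = current > best            # e.g. Dice: higher is better
    else:
        improved = current < best            # e.g. a loss: lower is better
    if improved:
        return current, 0, True              # caller would save best.pth here
    return best, no_improve_epochs + 1, False


if __name__ == "__main__":
    best, stale = None, 0
    for val_dice in (0.71, 0.78, 0.77, 0.80):
        best, stale, improved = update_best(val_dice, best, stale)
        print(val_dice, best, stale, improved)
```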
Printed every logging.log_interval steps:
epoch
step
num_steps
data_time
iter_time
gpu_memory_mb
lr
train_total
train_seg
train_grad_norm
Meaning:
data_time:
Time from the end of the previous step until the current batch is available.
When num_workers > 0, the worker startup cost for the first batch is paid after END TRAINING SETUP.
iter_time:
Training compute time of the current step, including forward, loss, backward, and optimizer step.
gpu_memory_mb:
Peak allocated GPU memory from torch.cuda.max_memory_allocated, reported in MB.
Current measured reference:
batch_size = 8
image_size = 256
ssm_backend = auto -> oflex
iter_time ≈ 0.09 - 0.11 s / step
GPU memory ≈ 850 MB
Batch from DataLoader
|
+-- image [B,3,256,256]
+-- mask [B,1,256,256]
|
v
image.to(cuda), mask.to(cuda)
|
v
autocast(enabled=True)
|
v
XNet2d(image)
|
+-- encoder_features = [E1,E2,E3,E4]
|
+-- bottleneck(E4)
|
+-- decoder_out, decoder_features, guides=[]
|
+-- segmentation_head(decoder_out)
|
v
seg_logits [B,1,256,256]
|
v
DiceCELoss(seg_logits, mask)
|
v
total_loss
|
v
GradScaler.scale(total_loss).backward()
|
v
clip gradients
|
v
optimizer.step()
GPU environment check:
python -c "import sys, torch; print(sys.executable); print(torch.__version__); print(torch.cuda.is_available()); print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'no cuda')"
oflex check:
python -c "import torch; import selective_scan_cuda_oflex; print('oflex import OK')"
python -c "import torch; from lib.modules.lib_mamba import csms6s; print(csms6s.WITH_SELECTIVESCAN_OFLEX)"
Forward-pass check:
python - <<'PY'
import torch
from lib.modules import XNet2d
model = XNet2d(in_channels=3, num_classes=1, ssm_backend="auto", ssm_forward_type="v3").cuda().eval()
x = torch.randn(1, 3, 128, 128, device="cuda")
with torch.no_grad():
y = model(x)
print(sorted(y.keys()))
print(tuple(y["seg_logits"].shape))
PY
Short training run:
DATASET=BUSI \
EXTRA_SET_ARGS="train.epochs=1 train.batch_size=8 train.val_batch_size=8 logging.use_swanlab=false checkpoint.dir=outputs/validation/xnet_oflex_b8" \
bash tools/run_optimized_supervised.sh
Ablation with frequency refine disabled:
DATASET=BUSI \
EXTRA_SET_ARGS="train.epochs=1 train.batch_size=8 train.val_batch_size=8 model.use_frequency_refine=false logging.use_swanlab=false checkpoint.dir=outputs/validation/xnet_oflex_b8_no_freq" \
bash tools/run_optimized_supervised.sh
Summarize results:
bash tools/summarize_results.sh
sed -n '1,40p' results/experiment_summary.md
Stage 1: training pipeline stability
BUSI smoke
BUSI batch size 8
BUSI no frequency refine
Stage 2: thyroid mainline
DDTI
TN3K
TG3K
DDTI -> TN3K / TN3K -> DDTI cross-dataset generalization
Stage 3: breast extension
BUSI
BUS_UC
BUS-BRA
BUS-UCLM
Stage 4: core ablations
use_wavelet_branch=false
use_frequency_refine=false
ssm_backend=torch
use_global_branch_stage1=true
encoder_depths=[2,2,3,2]
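The main training run uses only seg_logits.
lib/sam2 is kept but does not take part in training.
lib/SwinTransformer is kept but does not take part in training.
With ssm_backend=auto, CUDA runs should take the oflex path, which is the current default after the speed optimization.
The FFT computation in XFrequencyRefine2d uses float32 to avoid the ComplexHalf warning under AMP.
When num_workers > 0, the first entry into dataloader iteration may incur a one-time wait after END TRAINING SETUP.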