本文说明当前优化版全监督训练从命令行启动到模型更新、验证、日志上传和 checkpoint 保存的完整流程。目标是让后续运行实验时清楚知道入口脚本、YAML、Trainer、DataLoader、模型、loss、日志和 checkpoint 之间如何衔接。
当前优化训练主入口是:
DATASET=BUSI bash tools/run_optimized_supervised.sh
批量运行全部优化配置:
RUN_ALL=1 bash tools/run_optimized_supervised.sh
flowchart TD
accTitle: Optimized Training Flow
accDescr: This diagram shows how optimized YAML configs are selected, loaded, and used by the supervised XNet2d training loop.
command["命令行<br/>DATASET=BUSI bash tools/run_optimized_supervised.sh"]
script["tools/run_optimized_supervised.sh<br/>映射数据集、生成 split、选择 YAML"]
split{"需要项目级 split?"}
split_script["scripts/generate_project_split.py<br/>生成或复用 train.txt / val.txt"]
train_entry["tools/train.py<br/>读取 YAML 并应用 EXTRA_SET_ARGS"]
trainer_builder["lib/trainers/builder.py<br/>构建 supervised_segmentation trainer"]
trainer["lib/trainers/supervised.py<br/>build() + train()"]
data["lib/data/*<br/>索引、划分、Dataset、DataLoader"]
model["lib/modules/xnet_2d.py<br/>XNet2d forward"]
loss["lib/tools/loss.py<br/>DiceCELoss"]
optim["lib/tools/optim.py<br/>AdamW + cosine scheduler"]
log["SwanLab + terminal<br/>epoch metrics, lr, validation metrics"]
ckpt["outputs/experiments/optimized/*<br/>best / last checkpoint"]
command --> script
script --> split
split -->|"BUSI/BUS-UCLM/BUS-BRA/BUS_UC/CCAUI/DDTI"| split_script
split -->|"TN3K/TG3K/OTU_2d"| train_entry
split_script --> train_entry
train_entry --> trainer_builder
trainer_builder --> trainer
trainer --> data
trainer --> model
trainer --> loss
trainer --> optim
data --> trainer
model --> loss
loss --> optim
trainer --> log
trainer --> ckpt
优化配置统一放在:
configs/segmentation/optimized/
当前包含 9 个数据集配置:
| 数据集 | 配置文件 | 输入尺寸 | batch | lr | epochs |
|---|---|---|---|---|---|
BUSI |
sup_busi_opt.yaml |
384x384 |
48 | 8.0e-5 |
300 |
BUS-UCLM |
sup_bus_uclm_opt.yaml |
384x384 |
48 | 8.0e-5 |
320 |
BUS-BRA |
sup_bus_bra_opt.yaml |
384x384 |
48 | 1.2e-4 |
240 |
BUS_UC |
sup_bus_uc_opt.yaml |
384x384 |
48 | 1.1e-4 |
260 |
CCAUI |
sup_ccaui_opt.yaml |
384x384 |
48 | 1.1e-4 |
260 |
DDTI |
sup_ddti_opt.yaml |
384x384 |
48 | 8.0e-5 |
320 |
TN3K |
sup_tn3k_opt.yaml |
384x384 |
48 | 1.2e-4 |
220 |
TG3K |
sup_tg3k_opt.yaml |
384x384 |
48 | 1.2e-4 |
220 |
OTU_2d |
sup_otu_2d_opt.yaml |
384x384 |
48 | 1.2e-4 |
220 |
共同设置:
train.amp: true,默认使用混合精度训练。train.device: cuda,正式训练走 GPU。train.batch_size: 48 和 train.val_batch_size: 48,用于更充分利用 16GB GPU。dataset.image_size: [384, 384],保证各数据集实验公平。logging.use_swanlab: true 和 logging.swanlab_mode: cloud,默认上传 SwanLab。checkpoint.save: true,默认保存最优和最后 checkpoint。优化脚本位置:
tools/run_optimized_supervised.sh
脚本固定使用当前项目环境:
PYTHON="/opt/miniforge3/envs/xnet_mamba/bin/python"
脚本通过 config_path() 将数据集映射到配置文件,例如:
"BUSI") echo "configs/segmentation/optimized/sup_busi_opt.yaml" ;;
"TN3K") echo "configs/segmentation/optimized/sup_tn3k_opt.yaml" ;;
随后调用:
"$PYTHON" tools/train.py --config "$config" --set ${EXTRA_SET_ARGS}
如果不传 EXTRA_SET_ARGS,训练完全按 YAML 执行。只有显式传入覆盖项时,才会修改 YAML 中的字段。
以下数据集会先生成或复用项目级划分:
BUSI, BUS-UCLM, BUS-BRA, BUS_UC, CCAUI, DDTI
脚本执行:
/opt/miniforge3/envs/xnet_mamba/bin/python scripts/generate_project_split.py \
--dataset BUSI \
--root data/BUSI \
--seed 42
输出位置:
data/<DATASET>/splits/project/train.txt
data/<DATASET>/splits/project/val.txt
TN3K、TG3K、OTU_2d 走已有 split 逻辑或目录内划分,不由优化脚本生成项目级 split。
tools/train.py 入口训练入口文件:
tools/train.py
核心逻辑:
cfg = load_yaml_config(cfg_path)
cfg = apply_dotlist_overrides(cfg, args.set)
trainer = build_trainer(cfg, args=args)
trainer.train()
含义:
load_yaml_config() 从 YAML 读取完整配置字典。apply_dotlist_overrides() 只在 --set key=value 存在时覆盖配置。build_trainer() 根据 trainer.name 构建训练器。trainer.name 都是 supervised_segmentation。训练器实现:
lib/trainers/supervised.py
SupervisedSegmentationTrainer.build() 会完成:
cfg["model"] 和 cfg["dataset"] 构建 XNet2d。lib/tools/optim.py 构建 AdamW。lib/tools/optim.py 构建 warmup + cosine scheduler。lib/tools/loss.py 构建 MONAI DiceCELoss。logging.use_swanlab=true,初始化 SwanLab run。训练器基类 lib/trainers/base.py 负责设备、AMP、DataLoader、checkpoint、early stopping、SwanLab 日志和性能指标快照。
数据加载模块位于:
lib/data/
| 文件 | 职责 |
|---|---|
records.py |
定义统一样本记录 SegSampleRecord |
builder.py |
按数据集名称选择索引逻辑 |
indexers.py |
根据目录结构匹配 image 和 mask |
project_splits.py |
生成和读取项目级 train/val split |
datasets.py |
真正读取图像和 mask |
augment.py |
执行 resize、flip、rotate、brightness/contrast、noise |
loaders.py |
组装 Dataset 和 DataLoader |
训练 batch 的核心字段:
{
"image": Tensor[B, 3, H, W],
"mask": Tensor[B, 1, H, W],
"dataset_name": ...,
"sample_id": ...,
"class_name": ...,
"meta": ...,
}
在优化配置中,H=W=384,B=48。
主模型文件:
lib/modules/xnet_2d.py
flowchart TD
accTitle: XNet2d Forward Path
accDescr: This diagram summarizes how an input image moves through the XNet2d encoder, decoder, refinement, and output head.
input["输入图像<br/>B x 3 x 384 x 384"]
stem["XNetStem2d<br/>下采样并提升通道"]
encoder["XNetEncoder2d<br/>local + wavelet + VMamba global branches"]
bottleneck["bottleneck blocks"]
decoder["XNetDecoder2d<br/>多尺度上采样融合"]
refine{"use_frequency_refine?"}
freq["XFrequencyRefine2d<br/>FFT 频率精炼"]
head["segmentation head"]
output["输出<br/>seg_logits / logits"]
input --> stem
stem --> encoder
encoder --> bottleneck
bottleneck --> decoder
decoder --> refine
refine -->|"true"| freq
refine -->|"false"| head
freq --> head
head --> output
当前优化配置开启:
use_wavelet_branch: trueuse_frequency_refine: trueuse_global_branch_stage1: falsessm_backend: auto注意:AMP 下 ptwt.wavedec2 不支持 float16,因此小波分解和重建内部会临时用 fp32 计算,再转回原 dtype。整体训练仍然保持 AMP。
sequenceDiagram
autonumber
participant loader as DataLoader
participant trainer as Trainer
participant model as XNet2d
participant loss as DiceCELoss
participant optim as AdamW
participant log as Logger
loader->>trainer: batch(image, mask)
trainer->>trainer: image/mask.to(cuda)
trainer->>model: autocast forward(image)
model-->>trainer: outputs["seg_logits"]
trainer->>loss: loss(seg_logits, mask)
loss-->>trainer: total_loss
trainer->>trainer: GradScaler.scale(loss).backward()
trainer->>optim: unscale + grad clip + step
trainer->>trainer: GradScaler.update()
trainer->>log: print step snapshot to terminal only
训练 step 日志只打印到终端,不上传 SwanLab。它包含:
epochstepnum_stepsdata_timeiter_timegpu_memory_mblrtrain_totaltrain_seg每个 epoch 结束后,如果满足:
validation:
enabled: true
interval: 1
训练器会执行完整验证集 forward,并计算 val/total、val/seg、val/dice、val/iou。
SwanLab 初始化在 BaseTrainer._init_swanlab()。当前优化配置默认:
logging:
use_swanlab: true
project: X_SSL_Net
swanlab_mode: cloud
swanlab_logdir: swanlog
swanlab_mode: cloud 会同步云端,swanlab_logdir: swanlog 会在本地保留运行日志目录。训练过程只在每个 epoch 结束后上传 SwanLab,不上传 step 级曲线。上传字段包括:
| 阶段 | 指标 |
|---|---|
| epoch train | total, seg, grad_norm 如果本轮有记录 |
| epoch val | val_total, val_seg, val_dice, val_iou |
| epoch state | epoch, best_metric, no_improve_epochs, lr, early_stop, improved |
SwanLab 的本地数据主要保存在:
swanlog/run-*/backup.swanlab
这是 SwanLab 的二进制备份文件,不是普通文本日志。需要读取时使用:
/opt/miniforge3/envs/xnet_mamba/bin/python tools/export_swanlab_backup.py swanlog/run-YYYYMMDD_HHMMSS-<run_id> --exclude-system
默认会导出:
swanlog/run-*/exported/scalars.csv
swanlog/run-*/exported/scalars.jsonl
swanlog/run-*/exported/logs.csv
swanlog/run-*/exported/logs.jsonl
swanlog/run-*/exported/epoch_metrics.csv
swanlog/run-*/exported/epoch_metrics.jsonl
checkpoint 目录统一在:
outputs/experiments/optimized/<DATASET>
当前以验证 Dice 作为最优模型判断依据:
checkpoint:
monitor: dice
monitor_mode: max
单数据集正式训练:
DATASET=BUSI bash tools/run_optimized_supervised.sh
批量跑全部优化配置:
RUN_ALL=1 bash tools/run_optimized_supervised.sh
临时覆盖 batch:
DATASET=BUSI \
EXTRA_SET_ARGS="train.batch_size=32 train.val_batch_size=32" \
bash tools/run_optimized_supervised.sh
短跑检查配置链路:
DATASET=BUSI \
EXTRA_SET_ARGS="train.epochs=1 checkpoint.save=false logging.use_swanlab=false" \
bash tools/run_optimized_supervised.sh
建议先跑代表性主实验:
BUSITN3KTG3KBUS-BRA再补充扩展实验:
BUS-UCLMBUS_UCDDTIOTU_2dCCAUI| 现象 | 优先检查 |
|---|---|
| 找不到数据 | dataset.root 是否存在,DATASET 是否拼写正确 |
| 找不到 split | 对项目级 split 数据集先运行脚本,脚本会自动生成 |
| CUDA 不可用 | torch.cuda.is_available() 和当前运行权限 |
| SwanLab 不上传 | swanlab 是否登录、网络是否可用、logging.use_swanlab 是否为 true |
| 显存不足 | 用 EXTRA_SET_ARGS 降低 train.batch_size 和 train.val_batch_size |
| loss 为 NaN | 先降低 lr,再检查 mask 是否为空或异常 |
| 验证 Dice 长期为 0 | 检查 mask 读取、阈值、数据 split 和输出通道 |
| 文件 | 作用 |
|---|---|
tools/run_optimized_supervised.sh |
优化训练一键入口 |
tools/train.py |
统一 Python 训练入口 |
configs/segmentation/optimized/*.yaml |
各数据集优化训练配置 |
scripts/generate_project_split.py |
项目级 train/val split 生成 |
lib/trainers/supervised.py |
全监督分割训练循环 |
lib/trainers/base.py |
日志、checkpoint、设备、DataLoader 通用能力 |
lib/data/* |
数据索引、读取、增强和 DataLoader |
lib/modules/xnet_2d.py |
XNet2d 主模型 |
lib/tools/loss.py |
loss 构建 |
lib/tools/optim.py |
optimizer 和 scheduler 构建 |
scripts/probe_xnet_memory.py |
合成 batch 显存探测脚本 |