优化配置训练流程说明.md 12 KB

优化配置训练流程说明

1. 文档目的

本文说明当前优化版全监督训练从命令行启动到模型更新、验证、日志上传和 checkpoint 保存的完整流程。目标是让后续运行实验时清楚知道入口脚本、YAML、Trainer、DataLoader、模型、loss、日志和 checkpoint 之间如何衔接。

当前优化训练主入口是:

DATASET=BUSI bash tools/run_optimized_supervised.sh

批量运行全部优化配置:

RUN_ALL=1 bash tools/run_optimized_supervised.sh

2. 总体流程

flowchart TD
    accTitle: Optimized Training Flow
    accDescr: This diagram shows how optimized YAML configs are selected, loaded, and used by the supervised XNet2d training loop.

    command["命令行<br/>DATASET=BUSI bash tools/run_optimized_supervised.sh"]
    script["tools/run_optimized_supervised.sh<br/>映射数据集、生成 split、选择 YAML"]
    split{"需要项目级 split?"}
    split_script["scripts/generate_project_split.py<br/>生成或复用 train.txt / val.txt"]
    train_entry["tools/train.py<br/>读取 YAML 并应用 EXTRA_SET_ARGS"]
    trainer_builder["lib/trainers/builder.py<br/>构建 supervised_segmentation trainer"]
    trainer["lib/trainers/supervised.py<br/>build() + train()"]
    data["lib/data/*<br/>索引、划分、Dataset、DataLoader"]
    model["lib/modules/xnet_2d.py<br/>XNet2d forward"]
    loss["lib/tools/loss.py<br/>DiceCELoss"]
    optim["lib/tools/optim.py<br/>AdamW + cosine scheduler"]
    log["SwanLab + terminal<br/>epoch metrics, lr, validation metrics"]
    ckpt["outputs/experiments/optimized/*<br/>best / last checkpoint"]

    command --> script
    script --> split
    split -->|"BUSI/BUS-UCLM/BUS-BRA/BUS_UC/CCAUI/DDTI"| split_script
    split -->|"TN3K/TG3K/OTU_2d"| train_entry
    split_script --> train_entry
    train_entry --> trainer_builder
    trainer_builder --> trainer
    trainer --> data
    trainer --> model
    trainer --> loss
    trainer --> optim
    data --> trainer
    model --> loss
    loss --> optim
    trainer --> log
    trainer --> ckpt

3. 优化配置文件

优化配置统一放在:

configs/segmentation/optimized/

当前包含 9 个数据集配置:

数据集 配置文件 输入尺寸 batch lr epochs
BUSI sup_busi_opt.yaml 384x384 48 8.0e-5 300
BUS-UCLM sup_bus_uclm_opt.yaml 384x384 48 8.0e-5 320
BUS-BRA sup_bus_bra_opt.yaml 384x384 48 1.2e-4 240
BUS_UC sup_bus_uc_opt.yaml 384x384 48 1.1e-4 260
CCAUI sup_ccaui_opt.yaml 384x384 48 1.1e-4 260
DDTI sup_ddti_opt.yaml 384x384 48 8.0e-5 320
TN3K sup_tn3k_opt.yaml 384x384 48 1.2e-4 220
TG3K sup_tg3k_opt.yaml 384x384 48 1.2e-4 220
OTU_2d sup_otu_2d_opt.yaml 384x384 48 1.2e-4 220

共同设置:

  1. train.amp: true,默认使用混合精度训练。
  2. train.device: cuda,正式训练走 GPU。
  3. train.batch_size: 48train.val_batch_size: 48,用于更充分利用 16GB GPU。
  4. dataset.image_size: [384, 384],保证各数据集实验公平。
  5. logging.use_swanlab: truelogging.swanlab_mode: cloud,默认上传 SwanLab。
  6. checkpoint.save: true,默认保存最优和最后 checkpoint。

4. 运行脚本如何选择 YAML

优化脚本位置:

tools/run_optimized_supervised.sh

脚本固定使用当前项目环境:

PYTHON="/opt/miniforge3/envs/xnet_mamba/bin/python"

脚本通过 config_path() 将数据集映射到配置文件,例如:

"BUSI") echo "configs/segmentation/optimized/sup_busi_opt.yaml" ;;
"TN3K") echo "configs/segmentation/optimized/sup_tn3k_opt.yaml" ;;

随后调用:

"$PYTHON" tools/train.py --config "$config" --set ${EXTRA_SET_ARGS}

如果不传 EXTRA_SET_ARGS,训练完全按 YAML 执行。只有显式传入覆盖项时,才会修改 YAML 中的字段。

5. split 生成流程

以下数据集会先生成或复用项目级划分:

BUSI, BUS-UCLM, BUS-BRA, BUS_UC, CCAUI, DDTI

脚本执行:

/opt/miniforge3/envs/xnet_mamba/bin/python scripts/generate_project_split.py \
  --dataset BUSI \
  --root data/BUSI \
  --seed 42

输出位置:

data/<DATASET>/splits/project/train.txt
data/<DATASET>/splits/project/val.txt

TN3KTG3KOTU_2d 走已有 split 逻辑或目录内划分,不由优化脚本生成项目级 split。

6. tools/train.py 入口

训练入口文件:

tools/train.py

核心逻辑:

cfg = load_yaml_config(cfg_path)
cfg = apply_dotlist_overrides(cfg, args.set)
trainer = build_trainer(cfg, args=args)
trainer.train()

含义:

  1. load_yaml_config() 从 YAML 读取完整配置字典。
  2. apply_dotlist_overrides() 只在 --set key=value 存在时覆盖配置。
  3. build_trainer() 根据 trainer.name 构建训练器。
  4. 当前优化配置的 trainer.name 都是 supervised_segmentation

7. Trainer 构建阶段

训练器实现:

lib/trainers/supervised.py

SupervisedSegmentationTrainer.build() 会完成:

  1. cfg["model"]cfg["dataset"] 构建 XNet2d
  2. 使用 lib/tools/optim.py 构建 AdamW
  3. 使用 lib/tools/optim.py 构建 warmup + cosine scheduler。
  4. 使用 lib/tools/loss.py 构建 MONAI DiceCELoss
  5. 构建训练 DataLoader。
  6. 构建验证 DataLoader。
  7. 如配置了 resume,则恢复 checkpoint。
  8. logging.use_swanlab=true,初始化 SwanLab run。

训练器基类 lib/trainers/base.py 负责设备、AMP、DataLoader、checkpoint、early stopping、SwanLab 日志和性能指标快照。

8. 数据加载流程

数据加载模块位于:

lib/data/
文件 职责
records.py 定义统一样本记录 SegSampleRecord
builder.py 按数据集名称选择索引逻辑
indexers.py 根据目录结构匹配 image 和 mask
project_splits.py 生成和读取项目级 train/val split
datasets.py 真正读取图像和 mask
augment.py 执行 resize、flip、rotate、brightness/contrast、noise
loaders.py 组装 DatasetDataLoader

训练 batch 的核心字段:

{
    "image": Tensor[B, 3, H, W],
    "mask": Tensor[B, 1, H, W],
    "dataset_name": ...,
    "sample_id": ...,
    "class_name": ...,
    "meta": ...,
}

在优化配置中,H=W=384B=48

9. 模型 forward 流程

主模型文件:

lib/modules/xnet_2d.py
flowchart TD
    accTitle: XNet2d Forward Path
    accDescr: This diagram summarizes how an input image moves through the XNet2d encoder, decoder, refinement, and output head.

    input["输入图像<br/>B x 3 x 384 x 384"]
    stem["XNetStem2d<br/>下采样并提升通道"]
    encoder["XNetEncoder2d<br/>local + wavelet + VMamba global branches"]
    bottleneck["bottleneck blocks"]
    decoder["XNetDecoder2d<br/>多尺度上采样融合"]
    refine{"use_frequency_refine?"}
    freq["XFrequencyRefine2d<br/>FFT 频率精炼"]
    head["segmentation head"]
    output["输出<br/>seg_logits / logits"]

    input --> stem
    stem --> encoder
    encoder --> bottleneck
    bottleneck --> decoder
    decoder --> refine
    refine -->|"true"| freq
    refine -->|"false"| head
    freq --> head
    head --> output

当前优化配置开启:

  1. use_wavelet_branch: true
  2. use_frequency_refine: true
  3. use_global_branch_stage1: false
  4. ssm_backend: auto

注意:AMP 下 ptwt.wavedec2 不支持 float16,因此小波分解和重建内部会临时用 fp32 计算,再转回原 dtype。整体训练仍然保持 AMP。

10. 单步训练流程

sequenceDiagram
    autonumber
    participant loader as DataLoader
    participant trainer as Trainer
    participant model as XNet2d
    participant loss as DiceCELoss
    participant optim as AdamW
    participant log as Logger

    loader->>trainer: batch(image, mask)
    trainer->>trainer: image/mask.to(cuda)
    trainer->>model: autocast forward(image)
    model-->>trainer: outputs["seg_logits"]
    trainer->>loss: loss(seg_logits, mask)
    loss-->>trainer: total_loss
    trainer->>trainer: GradScaler.scale(loss).backward()
    trainer->>optim: unscale + grad clip + step
    trainer->>trainer: GradScaler.update()
    trainer->>log: print step snapshot to terminal only

训练 step 日志只打印到终端,不上传 SwanLab。它包含:

  1. epoch
  2. step
  3. num_steps
  4. data_time
  5. iter_time
  6. gpu_memory_mb
  7. lr
  8. train_total
  9. train_seg

11. 验证、SwanLab 和 checkpoint

每个 epoch 结束后,如果满足:

validation:
  enabled: true
  interval: 1

训练器会执行完整验证集 forward,并计算 val/totalval/segval/diceval/iou

SwanLab 初始化在 BaseTrainer._init_swanlab()。当前优化配置默认:

logging:
  use_swanlab: true
  project: X_SSL_Net
  swanlab_mode: cloud
  swanlab_logdir: swanlog

swanlab_mode: cloud 会同步云端,swanlab_logdir: swanlog 会在本地保留运行日志目录。训练过程只在每个 epoch 结束后上传 SwanLab,不上传 step 级曲线。上传字段包括:

阶段 指标
epoch train total, seg, grad_norm 如果本轮有记录
epoch val val_total, val_seg, val_dice, val_iou
epoch state epoch, best_metric, no_improve_epochs, lr, early_stop, improved

SwanLab 的本地数据主要保存在:

swanlog/run-*/backup.swanlab

这是 SwanLab 的二进制备份文件,不是普通文本日志。需要读取时使用:

/opt/miniforge3/envs/xnet_mamba/bin/python tools/export_swanlab_backup.py swanlog/run-YYYYMMDD_HHMMSS-<run_id> --exclude-system

默认会导出:

swanlog/run-*/exported/scalars.csv
swanlog/run-*/exported/scalars.jsonl
swanlog/run-*/exported/logs.csv
swanlog/run-*/exported/logs.jsonl
swanlog/run-*/exported/epoch_metrics.csv
swanlog/run-*/exported/epoch_metrics.jsonl

checkpoint 目录统一在:

outputs/experiments/optimized/<DATASET>

当前以验证 Dice 作为最优模型判断依据:

checkpoint:
  monitor: dice
  monitor_mode: max

12. 常用运行方式

单数据集正式训练:

DATASET=BUSI bash tools/run_optimized_supervised.sh

批量跑全部优化配置:

RUN_ALL=1 bash tools/run_optimized_supervised.sh

临时覆盖 batch:

DATASET=BUSI \
EXTRA_SET_ARGS="train.batch_size=32 train.val_batch_size=32" \
bash tools/run_optimized_supervised.sh

短跑检查配置链路:

DATASET=BUSI \
EXTRA_SET_ARGS="train.epochs=1 checkpoint.save=false logging.use_swanlab=false" \
bash tools/run_optimized_supervised.sh

13. 当前推荐训练顺序

建议先跑代表性主实验:

  1. BUSI
  2. TN3K
  3. TG3K
  4. BUS-BRA

再补充扩展实验:

  1. BUS-UCLM
  2. BUS_UC
  3. DDTI
  4. OTU_2d
  5. CCAUI

14. 出问题时先检查什么

现象 优先检查
找不到数据 dataset.root 是否存在,DATASET 是否拼写正确
找不到 split 对项目级 split 数据集先运行脚本,脚本会自动生成
CUDA 不可用 torch.cuda.is_available() 和当前运行权限
SwanLab 不上传 swanlab 是否登录、网络是否可用、logging.use_swanlab 是否为 true
显存不足 EXTRA_SET_ARGS 降低 train.batch_sizetrain.val_batch_size
loss 为 NaN 先降低 lr,再检查 mask 是否为空或异常
验证 Dice 长期为 0 检查 mask 读取、阈值、数据 split 和输出通道

15. 相关文件清单

文件 作用
tools/run_optimized_supervised.sh 优化训练一键入口
tools/train.py 统一 Python 训练入口
configs/segmentation/optimized/*.yaml 各数据集优化训练配置
scripts/generate_project_split.py 项目级 train/val split 生成
lib/trainers/supervised.py 全监督分割训练循环
lib/trainers/base.py 日志、checkpoint、设备、DataLoader 通用能力
lib/data/* 数据索引、读取、增强和 DataLoader
lib/modules/xnet_2d.py XNet2d 主模型
lib/tools/loss.py loss 构建
lib/tools/optim.py optimizer 和 scheduler 构建
scripts/probe_xnet_memory.py 合成 batch 显存探测脚本