# 优化配置训练流程说明 ## 1. 文档目的 本文说明当前优化版全监督训练从命令行启动到模型更新、验证、日志上传和 checkpoint 保存的完整流程。目标是让后续运行实验时清楚知道入口脚本、YAML、Trainer、DataLoader、模型、loss、日志和 checkpoint 之间如何衔接。 当前优化训练主入口是: ```bash DATASET=BUSI bash tools/run_optimized_supervised.sh ``` 批量运行全部优化配置: ```bash RUN_ALL=1 bash tools/run_optimized_supervised.sh ``` ## 2. 总体流程 ```mermaid flowchart TD accTitle: Optimized Training Flow accDescr: This diagram shows how optimized YAML configs are selected, loaded, and used by the supervised XNet2d training loop. command["命令行
DATASET=BUSI bash tools/run_optimized_supervised.sh"] script["tools/run_optimized_supervised.sh
映射数据集、生成 split、选择 YAML"] split{"需要项目级 split?"} split_script["scripts/generate_project_split.py
生成或复用 train.txt / val.txt"] train_entry["tools/train.py
读取 YAML 并应用 EXTRA_SET_ARGS"] trainer_builder["lib/trainers/builder.py
构建 supervised_segmentation trainer"] trainer["lib/trainers/supervised.py
build() + train()"] data["lib/data/*
索引、划分、Dataset、DataLoader"] model["lib/modules/xnet_2d.py
XNet2d forward"] loss["lib/tools/loss.py
DiceCELoss"] optim["lib/tools/optim.py
AdamW + cosine scheduler"] log["SwanLab + terminal
epoch metrics, lr, validation metrics"] ckpt["outputs/experiments/optimized/*
best / last checkpoint"] command --> script script --> split split -->|"BUSI/BUS-UCLM/BUS-BRA/BUS_UC/CCAUI/DDTI"| split_script split -->|"TN3K/TG3K/OTU_2d"| train_entry split_script --> train_entry train_entry --> trainer_builder trainer_builder --> trainer trainer --> data trainer --> model trainer --> loss trainer --> optim data --> trainer model --> loss loss --> optim trainer --> log trainer --> ckpt ``` ## 3. 优化配置文件 优化配置统一放在: ```text configs/segmentation/optimized/ ``` 当前包含 9 个数据集配置: | 数据集 | 配置文件 | 输入尺寸 | batch | lr | epochs | |---|---|---:|---:|---:|---:| | `BUSI` | `sup_busi_opt.yaml` | `384x384` | 48 | `8.0e-5` | 300 | | `BUS-UCLM` | `sup_bus_uclm_opt.yaml` | `384x384` | 48 | `8.0e-5` | 320 | | `BUS-BRA` | `sup_bus_bra_opt.yaml` | `384x384` | 48 | `1.2e-4` | 240 | | `BUS_UC` | `sup_bus_uc_opt.yaml` | `384x384` | 48 | `1.1e-4` | 260 | | `CCAUI` | `sup_ccaui_opt.yaml` | `384x384` | 48 | `1.1e-4` | 260 | | `DDTI` | `sup_ddti_opt.yaml` | `384x384` | 48 | `8.0e-5` | 320 | | `TN3K` | `sup_tn3k_opt.yaml` | `384x384` | 48 | `1.2e-4` | 220 | | `TG3K` | `sup_tg3k_opt.yaml` | `384x384` | 48 | `1.2e-4` | 220 | | `OTU_2d` | `sup_otu_2d_opt.yaml` | `384x384` | 48 | `1.2e-4` | 220 | 共同设置: 1. `train.amp: true`,默认使用混合精度训练。 2. `train.device: cuda`,正式训练走 GPU。 3. `train.batch_size: 48` 和 `train.val_batch_size: 48`,用于更充分利用 16GB GPU。 4. `dataset.image_size: [384, 384]`,保证各数据集实验公平。 5. `logging.use_swanlab: true` 和 `logging.swanlab_mode: cloud`,默认上传 SwanLab。 6. `checkpoint.save: true`,默认保存最优和最后 checkpoint。 ## 4. 运行脚本如何选择 YAML 优化脚本位置: ```text tools/run_optimized_supervised.sh ``` 脚本固定使用当前项目环境: ```bash PYTHON="/opt/miniforge3/envs/xnet_mamba/bin/python" ``` 脚本通过 `config_path()` 将数据集映射到配置文件,例如: ```bash "BUSI") echo "configs/segmentation/optimized/sup_busi_opt.yaml" ;; "TN3K") echo "configs/segmentation/optimized/sup_tn3k_opt.yaml" ;; ``` 随后调用: ```bash "$PYTHON" tools/train.py --config "$config" --set ${EXTRA_SET_ARGS} ``` 如果不传 `EXTRA_SET_ARGS`,训练完全按 YAML 执行。只有显式传入覆盖项时,才会修改 YAML 中的字段。 ## 5. split 生成流程 以下数据集会先生成或复用项目级划分: ```text BUSI, BUS-UCLM, BUS-BRA, BUS_UC, CCAUI, DDTI ``` 脚本执行: ```bash /opt/miniforge3/envs/xnet_mamba/bin/python scripts/generate_project_split.py \ --dataset BUSI \ --root data/BUSI \ --seed 42 ``` 输出位置: ```text data//splits/project/train.txt data//splits/project/val.txt ``` `TN3K`、`TG3K`、`OTU_2d` 走已有 split 逻辑或目录内划分,不由优化脚本生成项目级 split。 ## 6. `tools/train.py` 入口 训练入口文件: ```text tools/train.py ``` 核心逻辑: ```python cfg = load_yaml_config(cfg_path) cfg = apply_dotlist_overrides(cfg, args.set) trainer = build_trainer(cfg, args=args) trainer.train() ``` 含义: 1. `load_yaml_config()` 从 YAML 读取完整配置字典。 2. `apply_dotlist_overrides()` 只在 `--set key=value` 存在时覆盖配置。 3. `build_trainer()` 根据 `trainer.name` 构建训练器。 4. 当前优化配置的 `trainer.name` 都是 `supervised_segmentation`。 ## 7. Trainer 构建阶段 训练器实现: ```text lib/trainers/supervised.py ``` `SupervisedSegmentationTrainer.build()` 会完成: 1. 从 `cfg["model"]` 和 `cfg["dataset"]` 构建 `XNet2d`。 2. 使用 `lib/tools/optim.py` 构建 `AdamW`。 3. 使用 `lib/tools/optim.py` 构建 warmup + cosine scheduler。 4. 使用 `lib/tools/loss.py` 构建 MONAI `DiceCELoss`。 5. 构建训练 DataLoader。 6. 构建验证 DataLoader。 7. 如配置了 resume,则恢复 checkpoint。 8. 如 `logging.use_swanlab=true`,初始化 SwanLab run。 训练器基类 `lib/trainers/base.py` 负责设备、AMP、DataLoader、checkpoint、early stopping、SwanLab 日志和性能指标快照。 ## 8. 数据加载流程 数据加载模块位于: ```text lib/data/ ``` | 文件 | 职责 | |---|---| | `records.py` | 定义统一样本记录 `SegSampleRecord` | | `builder.py` | 按数据集名称选择索引逻辑 | | `indexers.py` | 根据目录结构匹配 image 和 mask | | `project_splits.py` | 生成和读取项目级 train/val split | | `datasets.py` | 真正读取图像和 mask | | `augment.py` | 执行 resize、flip、rotate、brightness/contrast、noise | | `loaders.py` | 组装 `Dataset` 和 `DataLoader` | 训练 batch 的核心字段: ```python { "image": Tensor[B, 3, H, W], "mask": Tensor[B, 1, H, W], "dataset_name": ..., "sample_id": ..., "class_name": ..., "meta": ..., } ``` 在优化配置中,`H=W=384`,`B=48`。 ## 9. 模型 forward 流程 主模型文件: ```text lib/modules/xnet_2d.py ``` ```mermaid flowchart TD accTitle: XNet2d Forward Path accDescr: This diagram summarizes how an input image moves through the XNet2d encoder, decoder, refinement, and output head. input["输入图像
B x 3 x 384 x 384"] stem["XNetStem2d
下采样并提升通道"] encoder["XNetEncoder2d
local + wavelet + VMamba global branches"] bottleneck["bottleneck blocks"] decoder["XNetDecoder2d
多尺度上采样融合"] refine{"use_frequency_refine?"} freq["XFrequencyRefine2d
FFT 频率精炼"] head["segmentation head"] output["输出
seg_logits / logits"] input --> stem stem --> encoder encoder --> bottleneck bottleneck --> decoder decoder --> refine refine -->|"true"| freq refine -->|"false"| head freq --> head head --> output ``` 当前优化配置开启: 1. `use_wavelet_branch: true` 2. `use_frequency_refine: true` 3. `use_global_branch_stage1: false` 4. `ssm_backend: auto` 注意:AMP 下 `ptwt.wavedec2` 不支持 `float16`,因此小波分解和重建内部会临时用 fp32 计算,再转回原 dtype。整体训练仍然保持 AMP。 ## 10. 单步训练流程 ```mermaid sequenceDiagram autonumber participant loader as DataLoader participant trainer as Trainer participant model as XNet2d participant loss as DiceCELoss participant optim as AdamW participant log as Logger loader->>trainer: batch(image, mask) trainer->>trainer: image/mask.to(cuda) trainer->>model: autocast forward(image) model-->>trainer: outputs["seg_logits"] trainer->>loss: loss(seg_logits, mask) loss-->>trainer: total_loss trainer->>trainer: GradScaler.scale(loss).backward() trainer->>optim: unscale + grad clip + step trainer->>trainer: GradScaler.update() trainer->>log: print step snapshot to terminal only ``` 训练 step 日志只打印到终端,不上传 SwanLab。它包含: 1. `epoch` 2. `step` 3. `num_steps` 4. `data_time` 5. `iter_time` 6. `gpu_memory_mb` 7. `lr` 8. `train_total` 9. `train_seg` ## 11. 验证、SwanLab 和 checkpoint 每个 epoch 结束后,如果满足: ```yaml validation: enabled: true interval: 1 ``` 训练器会执行完整验证集 forward,并计算 `val/total`、`val/seg`、`val/dice`、`val/iou`。 SwanLab 初始化在 `BaseTrainer._init_swanlab()`。当前优化配置默认: ```yaml logging: use_swanlab: true project: X_SSL_Net swanlab_mode: cloud swanlab_logdir: swanlog ``` `swanlab_mode: cloud` 会同步云端,`swanlab_logdir: swanlog` 会在本地保留运行日志目录。训练过程只在每个 epoch 结束后上传 SwanLab,不上传 step 级曲线。上传字段包括: | 阶段 | 指标 | |---|---| | epoch train | `total`, `seg`, `grad_norm` 如果本轮有记录 | | epoch val | `val_total`, `val_seg`, `val_dice`, `val_iou` | | epoch state | `epoch`, `best_metric`, `no_improve_epochs`, `lr`, `early_stop`, `improved` | SwanLab 的本地数据主要保存在: ```text swanlog/run-*/backup.swanlab ``` 这是 SwanLab 的二进制备份文件,不是普通文本日志。需要读取时使用: ```bash /opt/miniforge3/envs/xnet_mamba/bin/python tools/export_swanlab_backup.py swanlog/run-YYYYMMDD_HHMMSS- --exclude-system ``` 默认会导出: ```text swanlog/run-*/exported/scalars.csv swanlog/run-*/exported/scalars.jsonl swanlog/run-*/exported/logs.csv swanlog/run-*/exported/logs.jsonl swanlog/run-*/exported/epoch_metrics.csv swanlog/run-*/exported/epoch_metrics.jsonl ``` checkpoint 目录统一在: ```text outputs/experiments/optimized/ ``` 当前以验证 Dice 作为最优模型判断依据: ```yaml checkpoint: monitor: dice monitor_mode: max ``` ## 12. 常用运行方式 单数据集正式训练: ```bash DATASET=BUSI bash tools/run_optimized_supervised.sh ``` 批量跑全部优化配置: ```bash RUN_ALL=1 bash tools/run_optimized_supervised.sh ``` 临时覆盖 batch: ```bash DATASET=BUSI \ EXTRA_SET_ARGS="train.batch_size=32 train.val_batch_size=32" \ bash tools/run_optimized_supervised.sh ``` 短跑检查配置链路: ```bash DATASET=BUSI \ EXTRA_SET_ARGS="train.epochs=1 checkpoint.save=false logging.use_swanlab=false" \ bash tools/run_optimized_supervised.sh ``` ## 13. 当前推荐训练顺序 建议先跑代表性主实验: 1. `BUSI` 2. `TN3K` 3. `TG3K` 4. `BUS-BRA` 再补充扩展实验: 1. `BUS-UCLM` 2. `BUS_UC` 3. `DDTI` 4. `OTU_2d` 5. `CCAUI` ## 14. 出问题时先检查什么 | 现象 | 优先检查 | |---|---| | 找不到数据 | `dataset.root` 是否存在,`DATASET` 是否拼写正确 | | 找不到 split | 对项目级 split 数据集先运行脚本,脚本会自动生成 | | CUDA 不可用 | `torch.cuda.is_available()` 和当前运行权限 | | SwanLab 不上传 | `swanlab` 是否登录、网络是否可用、`logging.use_swanlab` 是否为 true | | 显存不足 | 用 `EXTRA_SET_ARGS` 降低 `train.batch_size` 和 `train.val_batch_size` | | loss 为 NaN | 先降低 lr,再检查 mask 是否为空或异常 | | 验证 Dice 长期为 0 | 检查 mask 读取、阈值、数据 split 和输出通道 | ## 15. 相关文件清单 | 文件 | 作用 | |---|---| | `tools/run_optimized_supervised.sh` | 优化训练一键入口 | | `tools/train.py` | 统一 Python 训练入口 | | `configs/segmentation/optimized/*.yaml` | 各数据集优化训练配置 | | `scripts/generate_project_split.py` | 项目级 train/val split 生成 | | `lib/trainers/supervised.py` | 全监督分割训练循环 | | `lib/trainers/base.py` | 日志、checkpoint、设备、DataLoader 通用能力 | | `lib/data/*` | 数据索引、读取、增强和 DataLoader | | `lib/modules/xnet_2d.py` | XNet2d 主模型 | | `lib/tools/loss.py` | loss 构建 | | `lib/tools/optim.py` | optimizer 和 scheduler 构建 | | `scripts/probe_xnet_memory.py` | 合成 batch 显存探测脚本 |