# 优化配置训练流程说明
## 1. 文档目的
本文说明当前优化版全监督训练从命令行启动到模型更新、验证、日志上传和 checkpoint 保存的完整流程。目标是让后续运行实验时清楚知道入口脚本、YAML、Trainer、DataLoader、模型、loss、日志和 checkpoint 之间如何衔接。
当前优化训练主入口是:
```bash
DATASET=BUSI bash tools/run_optimized_supervised.sh
```
批量运行全部优化配置:
```bash
RUN_ALL=1 bash tools/run_optimized_supervised.sh
```
## 2. 总体流程
```mermaid
flowchart TD
accTitle: Optimized Training Flow
accDescr: This diagram shows how optimized YAML configs are selected, loaded, and used by the supervised XNet2d training loop.
command["命令行
DATASET=BUSI bash tools/run_optimized_supervised.sh"]
script["tools/run_optimized_supervised.sh
映射数据集、生成 split、选择 YAML"]
split{"需要项目级 split?"}
split_script["scripts/generate_project_split.py
生成或复用 train.txt / val.txt"]
train_entry["tools/train.py
读取 YAML 并应用 EXTRA_SET_ARGS"]
trainer_builder["lib/trainers/builder.py
构建 supervised_segmentation trainer"]
trainer["lib/trainers/supervised.py
build() + train()"]
data["lib/data/*
索引、划分、Dataset、DataLoader"]
model["lib/modules/xnet_2d.py
XNet2d forward"]
loss["lib/tools/loss.py
DiceCELoss"]
optim["lib/tools/optim.py
AdamW + cosine scheduler"]
log["SwanLab + terminal
epoch metrics, lr, validation metrics"]
ckpt["outputs/experiments/optimized/*
best / last checkpoint"]
command --> script
script --> split
split -->|"BUSI/BUS-UCLM/BUS-BRA/BUS_UC/CCAUI/DDTI"| split_script
split -->|"TN3K/TG3K/OTU_2d"| train_entry
split_script --> train_entry
train_entry --> trainer_builder
trainer_builder --> trainer
trainer --> data
trainer --> model
trainer --> loss
trainer --> optim
data --> trainer
model --> loss
loss --> optim
trainer --> log
trainer --> ckpt
```
## 3. 优化配置文件
优化配置统一放在:
```text
configs/segmentation/optimized/
```
当前包含 9 个数据集配置:
| 数据集 | 配置文件 | 输入尺寸 | batch | lr | epochs |
|---|---|---:|---:|---:|---:|
| `BUSI` | `sup_busi_opt.yaml` | `384x384` | 48 | `8.0e-5` | 300 |
| `BUS-UCLM` | `sup_bus_uclm_opt.yaml` | `384x384` | 48 | `8.0e-5` | 320 |
| `BUS-BRA` | `sup_bus_bra_opt.yaml` | `384x384` | 48 | `1.2e-4` | 240 |
| `BUS_UC` | `sup_bus_uc_opt.yaml` | `384x384` | 48 | `1.1e-4` | 260 |
| `CCAUI` | `sup_ccaui_opt.yaml` | `384x384` | 48 | `1.1e-4` | 260 |
| `DDTI` | `sup_ddti_opt.yaml` | `384x384` | 48 | `8.0e-5` | 320 |
| `TN3K` | `sup_tn3k_opt.yaml` | `384x384` | 48 | `1.2e-4` | 220 |
| `TG3K` | `sup_tg3k_opt.yaml` | `384x384` | 48 | `1.2e-4` | 220 |
| `OTU_2d` | `sup_otu_2d_opt.yaml` | `384x384` | 48 | `1.2e-4` | 220 |
共同设置:
1. `train.amp: true`,默认使用混合精度训练。
2. `train.device: cuda`,正式训练走 GPU。
3. `train.batch_size: 48` 和 `train.val_batch_size: 48`,用于更充分利用 16GB GPU。
4. `dataset.image_size: [384, 384]`,保证各数据集实验公平。
5. `logging.use_swanlab: true` 和 `logging.swanlab_mode: cloud`,默认上传 SwanLab。
6. `checkpoint.save: true`,默认保存最优和最后 checkpoint。
## 4. 运行脚本如何选择 YAML
优化脚本位置:
```text
tools/run_optimized_supervised.sh
```
脚本固定使用当前项目环境:
```bash
PYTHON="/opt/miniforge3/envs/xnet_mamba/bin/python"
```
脚本通过 `config_path()` 将数据集映射到配置文件,例如:
```bash
"BUSI") echo "configs/segmentation/optimized/sup_busi_opt.yaml" ;;
"TN3K") echo "configs/segmentation/optimized/sup_tn3k_opt.yaml" ;;
```
随后调用:
```bash
"$PYTHON" tools/train.py --config "$config" --set ${EXTRA_SET_ARGS}
```
如果不传 `EXTRA_SET_ARGS`,训练完全按 YAML 执行。只有显式传入覆盖项时,才会修改 YAML 中的字段。
## 5. split 生成流程
以下数据集会先生成或复用项目级划分:
```text
BUSI, BUS-UCLM, BUS-BRA, BUS_UC, CCAUI, DDTI
```
脚本执行:
```bash
/opt/miniforge3/envs/xnet_mamba/bin/python scripts/generate_project_split.py \
--dataset BUSI \
--root data/BUSI \
--seed 42
```
输出位置:
```text
data//splits/project/train.txt
data//splits/project/val.txt
```
`TN3K`、`TG3K`、`OTU_2d` 走已有 split 逻辑或目录内划分,不由优化脚本生成项目级 split。
## 6. `tools/train.py` 入口
训练入口文件:
```text
tools/train.py
```
核心逻辑:
```python
cfg = load_yaml_config(cfg_path)
cfg = apply_dotlist_overrides(cfg, args.set)
trainer = build_trainer(cfg, args=args)
trainer.train()
```
含义:
1. `load_yaml_config()` 从 YAML 读取完整配置字典。
2. `apply_dotlist_overrides()` 只在 `--set key=value` 存在时覆盖配置。
3. `build_trainer()` 根据 `trainer.name` 构建训练器。
4. 当前优化配置的 `trainer.name` 都是 `supervised_segmentation`。
## 7. Trainer 构建阶段
训练器实现:
```text
lib/trainers/supervised.py
```
`SupervisedSegmentationTrainer.build()` 会完成:
1. 从 `cfg["model"]` 和 `cfg["dataset"]` 构建 `XNet2d`。
2. 使用 `lib/tools/optim.py` 构建 `AdamW`。
3. 使用 `lib/tools/optim.py` 构建 warmup + cosine scheduler。
4. 使用 `lib/tools/loss.py` 构建 MONAI `DiceCELoss`。
5. 构建训练 DataLoader。
6. 构建验证 DataLoader。
7. 如配置了 resume,则恢复 checkpoint。
8. 如 `logging.use_swanlab=true`,初始化 SwanLab run。
训练器基类 `lib/trainers/base.py` 负责设备、AMP、DataLoader、checkpoint、early stopping、SwanLab 日志和性能指标快照。
## 8. 数据加载流程
数据加载模块位于:
```text
lib/data/
```
| 文件 | 职责 |
|---|---|
| `records.py` | 定义统一样本记录 `SegSampleRecord` |
| `builder.py` | 按数据集名称选择索引逻辑 |
| `indexers.py` | 根据目录结构匹配 image 和 mask |
| `project_splits.py` | 生成和读取项目级 train/val split |
| `datasets.py` | 真正读取图像和 mask |
| `augment.py` | 执行 resize、flip、rotate、brightness/contrast、noise |
| `loaders.py` | 组装 `Dataset` 和 `DataLoader` |
训练 batch 的核心字段:
```python
{
"image": Tensor[B, 3, H, W],
"mask": Tensor[B, 1, H, W],
"dataset_name": ...,
"sample_id": ...,
"class_name": ...,
"meta": ...,
}
```
在优化配置中,`H=W=384`,`B=48`。
## 9. 模型 forward 流程
主模型文件:
```text
lib/modules/xnet_2d.py
```
```mermaid
flowchart TD
accTitle: XNet2d Forward Path
accDescr: This diagram summarizes how an input image moves through the XNet2d encoder, decoder, refinement, and output head.
input["输入图像
B x 3 x 384 x 384"]
stem["XNetStem2d
下采样并提升通道"]
encoder["XNetEncoder2d
local + wavelet + VMamba global branches"]
bottleneck["bottleneck blocks"]
decoder["XNetDecoder2d
多尺度上采样融合"]
refine{"use_frequency_refine?"}
freq["XFrequencyRefine2d
FFT 频率精炼"]
head["segmentation head"]
output["输出
seg_logits / logits"]
input --> stem
stem --> encoder
encoder --> bottleneck
bottleneck --> decoder
decoder --> refine
refine -->|"true"| freq
refine -->|"false"| head
freq --> head
head --> output
```
当前优化配置开启:
1. `use_wavelet_branch: true`
2. `use_frequency_refine: true`
3. `use_global_branch_stage1: false`
4. `ssm_backend: auto`
注意:AMP 下 `ptwt.wavedec2` 不支持 `float16`,因此小波分解和重建内部会临时用 fp32 计算,再转回原 dtype。整体训练仍然保持 AMP。
## 10. 单步训练流程
```mermaid
sequenceDiagram
autonumber
participant loader as DataLoader
participant trainer as Trainer
participant model as XNet2d
participant loss as DiceCELoss
participant optim as AdamW
participant log as Logger
loader->>trainer: batch(image, mask)
trainer->>trainer: image/mask.to(cuda)
trainer->>model: autocast forward(image)
model-->>trainer: outputs["seg_logits"]
trainer->>loss: loss(seg_logits, mask)
loss-->>trainer: total_loss
trainer->>trainer: GradScaler.scale(loss).backward()
trainer->>optim: unscale + grad clip + step
trainer->>trainer: GradScaler.update()
trainer->>log: print step snapshot to terminal only
```
训练 step 日志只打印到终端,不上传 SwanLab。它包含:
1. `epoch`
2. `step`
3. `num_steps`
4. `data_time`
5. `iter_time`
6. `gpu_memory_mb`
7. `lr`
8. `train_total`
9. `train_seg`
## 11. 验证、SwanLab 和 checkpoint
每个 epoch 结束后,如果满足:
```yaml
validation:
enabled: true
interval: 1
```
训练器会执行完整验证集 forward,并计算 `val/total`、`val/seg`、`val/dice`、`val/iou`。
SwanLab 初始化在 `BaseTrainer._init_swanlab()`。当前优化配置默认:
```yaml
logging:
use_swanlab: true
project: X_SSL_Net
swanlab_mode: cloud
swanlab_logdir: swanlog
```
`swanlab_mode: cloud` 会同步云端,`swanlab_logdir: swanlog` 会在本地保留运行日志目录。训练过程只在每个 epoch 结束后上传 SwanLab,不上传 step 级曲线。上传字段包括:
| 阶段 | 指标 |
|---|---|
| epoch train | `total`, `seg`, `grad_norm` 如果本轮有记录 |
| epoch val | `val_total`, `val_seg`, `val_dice`, `val_iou` |
| epoch state | `epoch`, `best_metric`, `no_improve_epochs`, `lr`, `early_stop`, `improved` |
SwanLab 的本地数据主要保存在:
```text
swanlog/run-*/backup.swanlab
```
这是 SwanLab 的二进制备份文件,不是普通文本日志。需要读取时使用:
```bash
/opt/miniforge3/envs/xnet_mamba/bin/python tools/export_swanlab_backup.py swanlog/run-YYYYMMDD_HHMMSS- --exclude-system
```
默认会导出:
```text
swanlog/run-*/exported/scalars.csv
swanlog/run-*/exported/scalars.jsonl
swanlog/run-*/exported/logs.csv
swanlog/run-*/exported/logs.jsonl
swanlog/run-*/exported/epoch_metrics.csv
swanlog/run-*/exported/epoch_metrics.jsonl
```
checkpoint 目录统一在:
```text
outputs/experiments/optimized/
```
当前以验证 Dice 作为最优模型判断依据:
```yaml
checkpoint:
monitor: dice
monitor_mode: max
```
## 12. 常用运行方式
单数据集正式训练:
```bash
DATASET=BUSI bash tools/run_optimized_supervised.sh
```
批量跑全部优化配置:
```bash
RUN_ALL=1 bash tools/run_optimized_supervised.sh
```
临时覆盖 batch:
```bash
DATASET=BUSI \
EXTRA_SET_ARGS="train.batch_size=32 train.val_batch_size=32" \
bash tools/run_optimized_supervised.sh
```
短跑检查配置链路:
```bash
DATASET=BUSI \
EXTRA_SET_ARGS="train.epochs=1 checkpoint.save=false logging.use_swanlab=false" \
bash tools/run_optimized_supervised.sh
```
## 13. 当前推荐训练顺序
建议先跑代表性主实验:
1. `BUSI`
2. `TN3K`
3. `TG3K`
4. `BUS-BRA`
再补充扩展实验:
1. `BUS-UCLM`
2. `BUS_UC`
3. `DDTI`
4. `OTU_2d`
5. `CCAUI`
## 14. 出问题时先检查什么
| 现象 | 优先检查 |
|---|---|
| 找不到数据 | `dataset.root` 是否存在,`DATASET` 是否拼写正确 |
| 找不到 split | 对项目级 split 数据集先运行脚本,脚本会自动生成 |
| CUDA 不可用 | `torch.cuda.is_available()` 和当前运行权限 |
| SwanLab 不上传 | `swanlab` 是否登录、网络是否可用、`logging.use_swanlab` 是否为 true |
| 显存不足 | 用 `EXTRA_SET_ARGS` 降低 `train.batch_size` 和 `train.val_batch_size` |
| loss 为 NaN | 先降低 lr,再检查 mask 是否为空或异常 |
| 验证 Dice 长期为 0 | 检查 mask 读取、阈值、数据 split 和输出通道 |
## 15. 相关文件清单
| 文件 | 作用 |
|---|---|
| `tools/run_optimized_supervised.sh` | 优化训练一键入口 |
| `tools/train.py` | 统一 Python 训练入口 |
| `configs/segmentation/optimized/*.yaml` | 各数据集优化训练配置 |
| `scripts/generate_project_split.py` | 项目级 train/val split 生成 |
| `lib/trainers/supervised.py` | 全监督分割训练循环 |
| `lib/trainers/base.py` | 日志、checkpoint、设备、DataLoader 通用能力 |
| `lib/data/*` | 数据索引、读取、增强和 DataLoader |
| `lib/modules/xnet_2d.py` | XNet2d 主模型 |
| `lib/tools/loss.py` | loss 构建 |
| `lib/tools/optim.py` | optimizer 和 scheduler 构建 |
| `scripts/probe_xnet_memory.py` | 合成 batch 显存探测脚本 |