# Project Walkthrough and Plain-Text Architecture Flowcharts

## 1. Project Scope

The active main line of `X_SSL_Net` is a fully supervised training project for 2D ultrasound image segmentation.

Actual training pipeline:

```text
shell script
  -> tools/train.py
  -> SupervisedSegmentationTrainer
  -> SegmentationRecordDataset / DataLoader
  -> XNet2d
  -> seg_logits
  -> DiceCE loss / BCE fallback
  -> Dice / IoU validation
  -> best.pth / last.pth
```

Actual model:

```text
XNet2d = X-shaped CNN-Wavelet-VMamba hybrid segmentation network
```

Main training uses a single segmentation head:

```text
outputs["seg_logits"]
```

The main line does not call:

1. `lib/sam2`
2. `lib/SwinTransformer`
3. SwinV2 segmentation config
4. boundary auxiliary head
5. semi-supervised trainer

`lib/sam2` and `lib/SwinTransformer` are kept as external code assets and are not part of the current training path.

## 2. One-Sentence Overview

The project can be summarized as:

```text
Fully supervised segmentation training of XNet2d on 2D ultrasound datasets such as BUSI / DDTI / TN3K / TG3K.
The XNet2d encoder models features with three branches: local + wavelet + VMamba-style SS2D.
The decoder recovers the mask with same-scale skips + diagonal guides + frequency refinement.
```

## 3. Entry Points

### 3.1 Recommended shell entry point

Most common entry point:

```bash
DATASET=BUSI bash tools/run_us_experiments.sh
```

Short debug run:

```bash
DATASET=BUSI \
EXTRA_SET_ARGS="train.epochs=1 train.batch_size=8 train.val_batch_size=8 logging.use_swanlab=false checkpoint.dir=outputs/validation/xnet_oflex_b8" \
bash tools/run_us_experiments.sh
```

### 3.2 Responsibilities of the shell script

File:

```text
tools/run_us_experiments.sh
```

It does four things:

1. Parses `DATASET`.
2. Maps the dataset name to its root directory.
3. For datasets that need a project-level split, generates or loads `train/val`.
4. Calls `tools/train.py`.

Supported dataset names:

```text
BUS-UCLM
BUSI
BUS-BRA
BUS_UC
CCAUI
DDTI
OTU_2d
TN3K
TG3K
```

Dataset root mapping:

```text
BUSI   -> data/BUSI
DDTI   -> data/DDTI
TN3K   -> data/TN3K
TG3K   -> data/TG3K
BUS_UC -> data/BUS_UC
...
```

Datasets with a project-level split:

```text
BUS-UCLM, BUSI, BUS-BRA, BUS_UC, CCAUI, DDTI
```

Datasets with an official split:

```text
OTU_2d, TN3K, TG3K
```
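The dataset dispatch done by `tools/run_us_experiments.sh` can be mirrored in Python as a rough sketch.

Assumptions:

- Every root is `data/` followed by the dataset name, as in the mappings listed above. That table is elided with `...`, so other datasets may differ.
- `dataset_root` and `needs_project_split` are hypothetical helper names, not functions from the repository.

```python
from pathlib import Path

# Split categories exactly as listed in the section above.
PROJECT_SPLIT = {"BUS-UCLM", "BUSI", "BUS-BRA", "BUS_UC", "CCAUI", "DDTI"}
OFFICIAL_SPLIT = {"OTU_2d", "TN3K", "TG3K"}
SUPPORTED = PROJECT_SPLIT | OFFICIAL_SPLIT


def dataset_root(name: str, data_dir: str = "data") -> Path:
    """Hypothetical helper: map a dataset name to its root directory.

    Assumes the data/<name> pattern shown for BUSI, DDTI, TN3K, TG3K and BUS_UC.
    """
    if name not in SUPPORTED:
        raise ValueError(f"unsupported DATASET: {name!r}")
    return Path(data_dir) / name


def needs_project_split(name: str) -> bool:
    """True if the dataset relies on generate_project_split.py instead of an official split."""
    return name in PROJECT_SPLIT
```

For example, `dataset_root("BUSI")` gives `data/BUSI`. `needs_project_split("TN3K")` is `False`, because TN3K ships an official split.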
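## 4. Overall Flow from Shell to Python

```text
User command
    |
    |  DATASET=BUSI EXTRA_SET_ARGS="..." bash tools/run_us_experiments.sh
    v
+------------------------------------------------------------------------------+
| tools/run_us_experiments.sh                                                  |
+------------------------------------------------------------------------------+
| 1. read DATASET / SEED / EXTRA_SET_ARGS                                      |
| 2. dataset_root(DATASET)                                                     |
| 3. if DATASET needs project split:                                           |
|    python scripts/generate_project_split.py --dataset DATASET --root ROOT    |
| 4. python tools/train.py                                                     |
|    --config configs/segmentation/train_sup_us_template.yaml                  |
|    --set dataset.dataset_name=DATASET dataset.root=ROOT ... EXTRA_SET_ARGS   |
+------------------------------------------------------------------------------+
    |
    v
+------------------------------------------------------------------------------+
| tools/train.py                                                               |
+------------------------------------------------------------------------------+
| 1. parse --config / --trainer / --set                                        |
| 2. load yaml config                                                          |
| 3. apply dotlist overrides                                                   |
| 4. optional override trainer.name                                            |
| 5. build_trainer(cfg)                                                        |
| 6. trainer.train()                                                           |
+------------------------------------------------------------------------------+
```

## 5. Configuration System

Main config:

```text
configs/segmentation/train_sup_us_template.yaml
```

Retained segmentation configs:

```text
configs/segmentation/train_sup_us_template.yaml
configs/segmentation/us_exp_sup_busi.yaml
configs/segmentation/us_exp_sup_busi_ablation.yaml
```

### 5.1 Overriding config values

`tools/train.py` supports:

```text
--set key=value key=value ...
```

For example:

```bash
python tools/train.py \
  --config configs/segmentation/train_sup_us_template.yaml \
  --set train.epochs=1 train.batch_size=8 model.use_frequency_refine=false
```

Override logic:

```text
load_yaml_config(path)
  |
  v
apply_dotlist_overrides(cfg, args.set)
  |
  v
nested dict update
```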
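Below is a minimal sketch of what a dotlist override step can look like. It makes these assumptions:

- Each `key=value` token splits on the first `=`.
- Dots in the key denote nesting.
- Values are parsed with `json.loads`. This covers `true`/`false`/`null`, numbers and lists such as `[2,2,3,2]`; anything else stays a plain string.

The name `apply_dotlist_overrides` matches the function in `tools/train.py`. The body is an illustrative re-implementation, not the project's code, which loads values through its own YAML handling.

```python
import copy
import json
from typing import Any, Iterable


def parse_value(text: str) -> Any:
    """Parse an override value; JSON covers true/false/null, numbers and [lists]."""
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return text  # plain strings such as "haar" or a checkpoint path


def apply_dotlist_overrides(cfg: dict, dotlist: Iterable[str]) -> dict:
    """Return a copy of cfg with each "a.b.c=value" token applied as a nested update."""
    out = copy.deepcopy(cfg)
    for item in dotlist:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"override must look like key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = out
        for name in parents:
            node = node.setdefault(name, {})  # create missing intermediate sections
        node[leaf] = parse_value(raw)
    return out
```

Copying the config first keeps the loaded YAML untouched, so the original values can still be logged next to the overridden ones.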
### 5.2 Key settings

Training:

```yaml
train:
  epochs: 200
  batch_size: 4
  val_batch_size: 4
  amp: true
  num_workers: 4
  pin_memory: true
  persistent_workers: true
  prefetch_factor: 2
  device: cuda
  grad_clip:
    enabled: true
    max_norm: 1.0
```

Dataset:

```yaml
dataset:
  dataset_name: BUSI
  root: data/BUSI
  split: train
  val_split: val
  image_size: [256, 256]
  in_channels: 3
  num_classes: 1
```

Model:

```yaml
model:
  in_channels: 3
  encoder_channels: [32, 64, 128, 192]
  encoder_depths: [2, 2, 2, 2]
  decoder_channels: [128, 64, 32]
  stem_channels: 24
  bottleneck_depth: 1
  global_ratio: 2.0
  wavelet_type: haar
  wavelet_level: 1
  use_wavelet_branch: true
  use_global_branch_stage1: false
  ssm_d_state: 16
  ssm_forward_type: v3
  ssm_backend: auto
  use_frequency_refine: true
  guide_mode: affine
  out_channels: null
```

Optimization:

```yaml
optimizer:
  name: adamw
  lr: 1.0e-4
  weight_decay: 0.05

scheduler:
  name: cosine
  warmup:
    name: linear
    params:
      start_factor: 0.1
      total_iters: 10
  params:
    T_max: 190
    eta_min: 1.0e-6
```

Loss and metrics:

```yaml
loss:
  name: dicece
  task_mode: binary
  params:
    include_background: true
    lambda_dice: 0.7
    lambda_ce: 0.3

validation:
  threshold: 0.5

metrics:
  task_mode: binary
  metrics:
    - name: dice
    - name: iou
```
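The optimizer and scheduler blocks imply a specific per-epoch learning rate, which the pure-Python sketch below computes.

Assumptions:

- The linear warmup runs first for `total_iters=10` epochs, with the factor rising from 0.1 to 1.0 as in PyTorch's `LinearLR`.
- The cosine phase then follows the closed form of `CosineAnnealingLR` for `T_max=190` epochs. Note that 10 + 190 matches `train.epochs: 200`.
- The scheduler is stepped once per epoch, as in the training loop of §15.
- `lr_at_epoch` is a hypothetical name.

```python
import math


def lr_at_epoch(epoch: int,
                base_lr: float = 1.0e-4,
                start_factor: float = 0.1,
                warmup_epochs: int = 10,
                t_max: int = 190,
                eta_min: float = 1.0e-6) -> float:
    """Learning rate at a given epoch: linear warmup, then cosine annealing."""
    if epoch < warmup_epochs:
        # Linear warmup: the factor goes from start_factor to 1.0 over warmup_epochs.
        factor = start_factor + (1.0 - start_factor) * epoch / warmup_epochs
        return base_lr * factor
    # Cosine phase, counted from the end of warmup and clamped at t_max.
    t = min(epoch - warmup_epochs, t_max)
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2
```

With the defaults, the rate starts at `1e-5`, reaches the full `1e-4` at epoch 10, and decays to `eta_min = 1e-6` at epoch 200.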
## 6. Trainer Construction

Entry point:

```text
lib/trainers/builder.py::build_trainer
```

Current trainer:

```text
lib/trainers/supervised.py::SupervisedSegmentationTrainer
```

Construction flow:

```text
build_trainer(cfg)
  |
  v
read cfg.trainer.name
  |
  v
TRAINER_REGISTRY["supervised_segmentation"]
  |
  v
trainer = SupervisedSegmentationTrainer(cfg, args)
  |
  v
trainer.build()
  |
  v
return trainer
```

`SupervisedSegmentationTrainer.build()` performs these steps:

```text
1. dataset_cfg = cfg["dataset"]
2. model_cfg = cfg["model"]
3. train_cfg = cfg["train"]
4. build XNet2d from model_cfg
5. move model to device
6. build optimizer
7. build scheduler
8. build loss if cfg.loss is not null
9. build train dataloader
10. build validation dataloader
11. maybe resume checkpoint
12. maybe init SwanLab
```

## 7. Shared Responsibilities of BaseTrainer

File:

```text
lib/trainers/base.py
```

Shared responsibilities:

```text
BaseTrainer
├─ random seed
├─ device selection
├─ output directory
├─ AMP GradScaler
├─ batch size resolution
├─ dataloader construction helper
├─ validation metric construction
├─ checkpoint save / resume
├─ early stopping
├─ SwanLab logging
├─ training setup summary
├─ step performance logging
└─ epoch finalization
```

Device selection:

```text
cfg.train.device == "cuda" and torch.cuda.is_available() -> cuda
else                                                      -> cpu
```

AMP switch:

```text
cfg.train.amp == true and device == cuda -> enabled
else                                     -> disabled
```

Verified target environment:

```text
conda env: xnet_mamba
torch:     2.10.0+cu126
GPU:       NVIDIA GeForce RTX 4070 Ti SUPER
selective_scan_cuda_oflex: available
```
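The registry lookup in §6 and the device/AMP rules in §7 can be sketched together. Everything in this sketch is illustrative:

- `ToySupervisedTrainer`, `register_trainer`, `resolve_device` and `resolve_amp` are hypothetical names.
- Falling back to `"supervised_segmentation"` when `trainer.name` is absent is an assumption.
- `cuda_available` is passed in explicitly so the sketch runs without PyTorch. The real trainer would use `torch.cuda.is_available()`.

```python
from typing import Callable, Dict

# Hypothetical registry mirroring the TRAINER_REGISTRY lookup shown above.
TRAINER_REGISTRY: Dict[str, type] = {}


def register_trainer(name: str) -> Callable[[type], type]:
    """Class decorator that records a trainer class under a config name."""
    def decorator(cls: type) -> type:
        TRAINER_REGISTRY[name] = cls
        return cls
    return decorator


def resolve_device(train_cfg: dict, cuda_available: bool) -> str:
    # "cuda" only if requested AND available; otherwise fall back to CPU.
    return "cuda" if train_cfg.get("device") == "cuda" and cuda_available else "cpu"


def resolve_amp(train_cfg: dict, device: str) -> bool:
    # AMP is enabled only when requested and running on CUDA.
    return bool(train_cfg.get("amp", False)) and device == "cuda"


@register_trainer("supervised_segmentation")
class ToySupervisedTrainer:
    def __init__(self, cfg: dict, cuda_available: bool = False):
        self.cfg = cfg
        self.device = resolve_device(cfg["train"], cuda_available)
        self.amp = resolve_amp(cfg["train"], self.device)
        self.built = False

    def build(self) -> None:
        # The real build() creates the model, optimizer, scheduler, loss and loaders.
        self.built = True


def build_trainer(cfg: dict, cuda_available: bool = False):
    name = cfg.get("trainer", {}).get("name", "supervised_segmentation")
    if name not in TRAINER_REGISTRY:
        raise KeyError(f"unknown trainer {name!r}; registered: {sorted(TRAINER_REGISTRY)}")
    trainer = TRAINER_REGISTRY[name](cfg, cuda_available)
    trainer.build()
    return trainer
```

A registry keeps `tools/train.py` unaware of concrete trainer classes. A new trainer only needs to register itself under a new `trainer.name`.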
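## 8. Data Pipeline

### 8.1 Building the data index

Entry point:

```text
lib/data/builder.py::build_dataset_index
```

Core registry:

```text
BUS-UCLM -> paired images/masks
BUSI     -> Dataset_BUSI_with_GT/{benign,malignant,normal}
BUS-BRA  -> prefixed image/mask matching
BUS_UC   -> All / Benign / Malignant folders
CCAUI    -> US images / Expert mask images
DDTI     -> XML annotation records
OTU_2d   -> images / annotations
TN3K     -> trainval/test image/mask folders
TG3K     -> thyroid-image / thyroid-mask
```

### 8.2 Applying splits

Entry point:

```text
lib/data/loaders.py::apply_official_split
```

Flow:

```text
build_dataset_index(dataset_name, root)
  |
  v
if a split is requested:
  |
  +-- OTU_2d: read train.txt / val.txt
  |
  +-- TN3K: read tn3k-trainval.json or use the test folder
  |
  +-- TG3K: read tg3k-trainval.json
  |
  +-- project-split dataset: read data//splits/project/train.txt or val.txt
```

Project-level split generation:

```text
scripts/generate_project_split.py
  |
  v
generate_project_splits()
  |
  v
write:
  data//splits/project/train.txt
  data//splits/project/val.txt
```

### 8.3 Dataset reading

File:

```text
lib/data/datasets.py::SegmentationRecordDataset
```

Reading a single sample:

```text
record
  |
  +-- image_path -> PIL RGB -> float32 [3,H,W] in [0,1]
  |
  +-- mask_path -> PIL L -> binary float32 [1,H,W]
  |
  +-- DDTI special case: annotation_path XML -> build_ddti_mask() -> binary [1,H,W]
  |
  +-- joint augmentation
  |
  +-- resize image to dataset.image_size
  |
  +-- resize mask to dataset.image_size
  |
  v
{
  "image": image,
  "mask": mask,
  "dataset_name": ...,
  "sample_id": ...,
  "split": ...,
  "class_name": ...,
  "meta": ...
}
```

### 8.4 Augmentation

File:

```text
lib/data/augment.py::SegmentationAugmentation
```

Currently supported:

```text
spatial:
  random horizontal flip
  random vertical flip
  random rotate 90

intensity:
  random brightness / contrast
  random gaussian noise
  clamp to [0,1]
```

### 8.5 Collate

File:

```text
lib/data/collate.py::record_collate_fn
```

Logic:

```text
tensor fields:
  if all tensor shapes are the same: torch.stack(values, dim=0)
  else: keep as a list
strings / dicts / metadata: keep as a list
```

Final batch:

```text
image: [B,3,256,256]
mask : [B,1,256,256]
```

## 9. Dataloader Flowchart

```text
SupervisedSegmentationTrainer.build()
  |
  v
_build_segmentation_loader(split="train")
  |
  v
build_dataloader()
  |
  v
build_record_dataset()
  |
  v
build_dataset_index()
  |
  v
apply_official_split()
  |
  v
SegmentationRecordDataset(records, transforms)
  |
  v
DataLoader(
  batch_size,
  shuffle,
  num_workers,
  pin_memory,
  persistent_workers,
  prefetch_factor,
  collate_fn=record_collate_fn
)
```

Note: `DataLoader` worker processes usually start when the loader is first iterated. That happens after `======== END TRAINING SETUP ========` is printed. With `num_workers > 0`, the first batch may therefore incur a one-time wait.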
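The collate rule in §8.5 can be sketched as follows.

- The stand-in uses NumPy arrays and `np.stack` in place of tensors and `torch.stack`, so it runs without PyTorch.
- `record_collate_sketch` is a hypothetical name.
- The real `record_collate_fn` may differ in details.

```python
from typing import Any, Dict, List

import numpy as np


def record_collate_sketch(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Batch a list of sample dicts: stack same-shaped arrays, keep everything else as lists."""
    batch: Dict[str, Any] = {}
    for key in samples[0]:
        values = [s[key] for s in samples]
        if all(isinstance(v, np.ndarray) for v in values):
            shapes = {v.shape for v in values}
            # Stack along a new batch axis only when every sample has the same shape.
            batch[key] = np.stack(values, axis=0) if len(shapes) == 1 else values
        else:
            # Strings, dicts and other metadata stay as plain lists.
            batch[key] = values
    return batch
```

Keeping mismatched shapes as a list, rather than raising, means a bad resize shows up as a shape error in the model, not a crash inside a worker process.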
## 10. XNet2d Overall Structure

File:

```text
lib/modules/xnet_2d.py
```

Default parameter count:

```text
total parameters:     9,432,129
trainable parameters: 9,432,129
```

Top-level structure:

```text
XNet2d
├─ XNetEncoder2d
│  ├─ XNetStem2d
│  ├─ Encoder Stage 1: XTEB2d x 2
│  ├─ Downsample 1
│  ├─ Encoder Stage 2: XTEB2d x 2
│  ├─ Downsample 2
│  ├─ Encoder Stage 3: XTEB2d x 2
│  ├─ Downsample 3
│  └─ Encoder Stage 4: XTEB2d x 2
│
├─ Bottleneck: XTEB2d x 1
│
├─ XNetDecoder2d
│  ├─ guide4: E4 -> D4 affine guide
│  ├─ dec4: XCRB2d(E4, E3, guide4)
│  ├─ guide3: E3 -> D3 affine guide
│  ├─ dec3: XCRB2d(D4, E2, guide3)
│  ├─ guide2: E2 -> D2 affine guide
│  ├─ dec2: XCRB2d(D3, E1, guide2)
│  └─ head_refine
│
└─ XNetSegHead2d
```

## 11. XNet2d Plain-Text Architecture Diagram

The shapes below assume an input of `[B,3,256,256]` and the default channels `[32,64,128,192]`:

```text
Input [B, 3, 256, 256]
  |
  v
XNetStem2d
  Conv3x3 s2:  [B, 24, 128, 128]
  DWConv3x3:   [B, 24, 128, 128]
  PWConv1x1:   [B, 32, 128, 128]
  Conv3x3 s2:  [B, 32,  64,  64]
  |
  v
E1 = Encoder Stage 1, XTEB x2   [B,  32, 64, 64]
  |
  v
Down1                           [B,  64, 32, 32]
  |
  v
E2 = Encoder Stage 2, XTEB x2   [B,  64, 32, 32]
  |
  v
Down2                           [B, 128, 16, 16]
  |
  v
E3 = Encoder Stage 3, XTEB x2   [B, 128, 16, 16]
  |
  v
Down3                           [B, 192,  8,  8]
  |
  v
E4 = Encoder Stage 4, XTEB x2   [B, 192,  8,  8]
  |
  v
Bottleneck XTEB x1              [B, 192,  8,  8]
```

Decoder:

```text
E4 [B,192,8,8]
  |
  +-- guide4 = Phi(E4) -> resize to E3 size -> affine gamma/beta for d4
  |
  v
dec4
  input:
    decoder input:   E4 [B,192,8,8]
    same-scale skip: E3 [B,128,16,16]
    guide:           g4
  output:            D4 [B,128,16,16]

D4 [B,128,16,16]
  |
  +-- guide3 = Phi(E3) -> resize to E2 size -> affine gamma/beta for d3
  |
  v
dec3
  input:
    decoder input:   D4 [B,128,16,16]
    same-scale skip: E2 [B,64,32,32]
    guide:           g3
  output:            D3 [B,64,32,32]

D3 [B,64,32,32]
  |
  +-- guide2 = Phi(E2) -> resize to E1 size -> affine gamma/beta for d2
  |
  v
dec2
  input:
    decoder input:   D3 [B,64,32,32]
    same-scale skip: E1 [B,32,64,64]
    guide:           g2
  output:            D2 [B,32,64,64]

D2 [B,32,64,64]
  |
  v
HeadRefine                        [B,32,64,64]
  |
  v
SegHead + upsample to input size  [B,1,256,256]
```
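The spatial sizes in the diagram follow from a handful of stride-2 layers, which the shape-trace sketch below reproduces.

- It assumes every stride-2 layer is a 3x3 convolution with padding 1. The kernel and padding of `Downsample` are not specified in this document.
- `conv_out` and `trace_shapes` are hypothetical helpers.
- The sketch reproduces the encoder sizes and shows where the final 4x upsample in the head comes from.

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output length of a convolution along one spatial axis (standard formula)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(h: int, w: int, channels=(32, 64, 128, 192),
                 stem: int = 24) -> List[Tuple[str, Tuple[int, int, int]]]:
    """Trace (C, H, W) through the stem and encoder stages; XTEB blocks keep shape."""
    trace = []
    h, w = conv_out(h), conv_out(w)          # first stem Conv3x3 s2
    trace.append(("stem conv s2", (stem, h, w)))
    h, w = conv_out(h), conv_out(w)          # second stem Conv3x3 s2 (after DW/PW convs)
    trace.append(("E1", (channels[0], h, w)))
    for i, c in enumerate(channels[1:], start=2):
        h, w = conv_out(h), conv_out(w)      # Down{i-1}, assumed 3x3 s2 p1
        trace.append((f"E{i}", (c, h, w)))
    return trace
```

D2 sits at the E1 resolution (64x64 for a 256x256 input), so the segmentation head has to upsample by 4 to return to input size. For the 128x128 input used in the forward check of §21, the same trace ends with E4 at 4x4.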
## 12. XTEB2d in Detail

`XTEB2d` is the basic encoder block.

Name:

```text
XTEB = XNet Tri-branch Encoding Block
```

Input and output:

```text
input : X [B,C,H,W]
output: Z [B,C,H,W]
```

Internal structure:

```text
X
│
├─ pre_norm: 1x1 Conv2dBN
│
├─ Local branch
│  ├─ DWConv3x3 + PWConv1x1
│  └─ DWConv5x5 + PWConv1x1
│
├─ Wavelet branch
│  ├─ Haar DWT
│  │  ├─ LL
│  │  └─ LH/HL/HH high bands
│  ├─ LL projection
│  ├─ high-band projection
│  └─ inverse Haar transform
│
├─ Global branch
│  ├─ 1x1 pre projection
│  ├─ VMamba-style SS2D
│  └─ 1x1 post projection
│
├─ concat(local, wavelet, global)
├─ 1x1 fusion
├─ channel gate from GAP + MLP + sigmoid
├─ residual add
└─ lightweight FFN + residual add
```

As formulas:

```text
X0 = PreNorm(X)
L  = Local(X0)
W  = Wavelet(X0)
G  = GlobalSS2D(X0)
F  = Fuse([L, W, G])
Y  = X + Post(F)
Z  = Y + FFN(Y)
```

### 12.1 Local branch

Role:

```text
local texture, boundaries, short-range structure
```

Structure:

```text
DWConv3x3 -> ReLU -> PWConv1x1
DWConv5x5 -> ReLU -> PWConv1x1
sum of the two paths
```

### 12.2 Wavelet branch

Role:

```text
low-frequency contours + high-frequency boundaries/texture
```

Structure:

```text
Haar DWT:
  LL       -> low-frequency structure
  LH/HL/HH -> high-frequency directional details

LL         -> conv projection
High bands -> depthwise conv + pointwise conv
IDWT       -> output projection
```

Current limitations:

```text
wavelet_type  = haar
wavelet_level = 1
```

### 12.3 Global SS2D branch

Role:

```text
efficient long-range dependency modeling, global structural consistency
```

Current implementation:

```text
lib/modules/lib_mamba/vmamba.py::SS2D
```

Origin:

```text
VMamba-style SS2D operator
```

Backend selection:

```text
ssm_backend = auto
  |
  +-- if x.is_cuda:
  |     selective_scan_backend = oflex
  |     scan_force_torch = false
  |
  +-- else:
        selective_scan_backend = torch
        scan_force_torch = true

ssm_backend = oflex -> force oflex
ssm_backend = torch -> force torch fallback
```

Current defaults:

```text
ssm_forward_type = v3
ssm_backend      = auto
```

Verified in the `xnet_mamba` + RTX 4070 Ti SUPER environment:

```text
selective_scan_cuda_oflex import OK
WITH_SELECTIVESCAN_OFLEX = True
```
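To make the Haar decomposition of §12.2 concrete, the sketch below implements a single-level orthonormal 2D Haar DWT and its inverse in NumPy. It works on one `[H,W]` channel with even `H` and `W`.

- It illustrates the transform only. Sub-band naming (which of LH/HL holds horizontal versus vertical detail) varies between libraries.
- The project's own implementation, including its normalization, may differ.
- `haar_dwt2` and `haar_idwt2` are hypothetical names.

```python
import numpy as np


def haar_dwt2(x: np.ndarray):
    """Single-level orthonormal 2D Haar DWT of an [H, W] array (H, W even)."""
    a = x[0::2, 0::2]  # top-left pixel of each 2x2 block
    b = x[0::2, 1::2]  # top-right
    c = x[1::2, 0::2]  # bottom-left
    d = x[1::2, 1::2]  # bottom-right
    ll = (a + b + c + d) / 2   # scaled block average: low-frequency structure
    lh = (a - b + c - d) / 2   # left-right differences
    hl = (a + b - c - d) / 2   # top-bottom differences
    hh = (a - b - c + d) / 2   # diagonal differences
    return ll, lh, hl, hh


def haar_idwt2(ll, lh, hl, hh) -> np.ndarray:
    """Inverse of haar_dwt2: rebuilds the [H, W] array exactly."""
    h, w = ll.shape
    x = np.empty((2 * h, 2 * w), dtype=np.result_type(ll, lh, hl, hh))
    x[0::2, 0::2] = (ll + lh + hl + hh) / 2
    x[0::2, 1::2] = (ll - lh + hl - hh) / 2
    x[1::2, 0::2] = (ll + lh - hl - hh) / 2
    x[1::2, 1::2] = (ll - lh - hl + hh) / 2
    return x
```

Each output band has half the input resolution, and the transform is lossless. This is why the branch can process LL and the high bands separately and still map back to `[H,W]` through the inverse transform.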
## 13. XCRB2d in Detail

`XCRB2d` is the basic decoder block.

Name:

```text
XCRB = XNet Cross-guided Reconstruction Block
```

Inputs:

```text
decoder input:   deeper decoder or bottleneck feature
same-scale skip: encoder feature at the target scale
diagonal guide:  semantic guide from a deeper encoder stage
```

Internal structure:

```text
decoder input                       same-scale skip
      |                                    |
      v                                    |
bilinear upsample to skip size             |
      |                                    |
      v                                    v
1x1 projection                       1x1 projection
      |                                    |
      +-------------- concat --------------+
                        |
                        v
                    3x3 fusion
                        |
                        v
              guide affine modulation
                        |
                        v
             optional frequency refine
                        |
                        v
              residual spatial refine
```

### 13.1 X-shaped information flow

The decoder is more than a plain U-Net with horizontal skips. It uses both paths:

```text
same-scale path:
  E3 -> D4
  E2 -> D3
  E1 -> D2

diagonal guide path:
  E4 -> D4
  E3 -> D3
  E2 -> D2
```

Plain-text sketch:

```text
Encoder              same-scale skip                 Decoder
E1  ----------------------------------------------->  D2
E2  ------------------------------------>  D3
E3  ------------------------->  D4
E4

Encoder              diagonal guide                  Decoder
E2  ...............................................>  D2
E3  ....................................>  D3
E4  .........................>  D4

Decoding order: D4 -> D3 -> D2
```

### 13.2 Guide modulation

The default is `guide_mode=affine`.

Flow:

```text
guide feature
  |
  v
resize to target decoder scale
  |
  v
projection -> [gamma, beta]
  |
  v
gamma = sigmoid(gamma) + 0.5
  |
  v
F' = gamma * F + beta
```

Because `sigmoid` lies in (0, 1), the scale `gamma` stays in (0.5, 1.5). The guide can therefore only rescale decoder features moderately around identity, plus a learned shift `beta`.

### 13.3 Frequency refine

The default is `use_frequency_refine=true`.

Flow:

```text
feature F
  |
  v
cast to float32 if needed
  |
  v
rfft2
  |
  +-- low-frequency mask
  |
  +-- high-frequency residual
  |
  v
learnable low/high gates
  |
  v
irfft2
  |
  v
cast back to input dtype
  |
  v
depthwise conv refine
```

The FFT is explicitly computed in `float32` to avoid the `ComplexHalf support is experimental` warning under AMP.
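Below is a NumPy sketch of the spectral part of the frequency-refine flow, for one `[H,W]` channel. None of the following is taken from `XFrequencyRefine2d`; these are assumptions:

- The low-frequency mask is a hard radial cutoff, with `radius` given as a fraction of the Nyquist frequency.
- The learnable gates are replaced by two fixed scalars, `g_low` and `g_high`.
- The final depthwise conv is omitted.

With both gates equal to 1 the sketch reduces to the identity, which is a useful sanity check.

```python
import numpy as np


def frequency_refine_sketch(x: np.ndarray, g_low: float = 1.0, g_high: float = 1.0,
                            radius: float = 0.25) -> np.ndarray:
    """Split an [H, W] map into low/high frequency parts, reweight them, and transform back."""
    orig_dtype = x.dtype
    xf = x.astype(np.float32)                      # run the FFT in float32
    h, w = xf.shape
    spec = np.fft.rfft2(xf)                        # complex spectrum, shape [H, W // 2 + 1]
    fy = np.fft.fftfreq(h)[:, None]                # cycles/pixel along H, in [-0.5, 0.5)
    fx = np.fft.rfftfreq(w)[None, :]               # non-negative frequencies along W
    low_mask = (np.sqrt(fy ** 2 + fx ** 2) <= radius * 0.5).astype(np.float32)
    low = spec * low_mask                          # low-frequency component
    high = spec - low                              # high-frequency residual
    refined = g_low * low + g_high * high          # stand-in for the learnable gates
    out = np.fft.irfft2(refined, s=(h, w))         # back to the spatial domain
    return out.astype(orig_dtype)                  # cast back to the input dtype
```

Setting `g_high` below 1 smooths the map, because the DC term and slow variations sit inside the low-frequency mask. Setting it above 1 sharpens edges and speckle.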
## 14. XNet2d Forward Outputs

`XNet2d.forward(x)` returns:

```python
{
    "logits": logits,
    "seg_logits": logits,
    "encoder_features": encoder_features,
    "decoder_features": decoder_features,
    "guides": guides,
}
```

Training uses only:

```text
outputs["seg_logits"]
```

The remaining outputs serve:

```text
debug visualization
future auxiliary analysis
```

There is currently no boundary auxiliary output.

## 15. Training Loop in Detail

Entry point:

```text
SupervisedSegmentationTrainer.train()
```

Flow:

```text
train()
  |
  v
print training setup
  |
  v
for epoch in range(start_epoch, epochs):
  |
  +-- model.train()
  +-- optimizer.zero_grad()
  +-- for step, batch in train_loader:
  |     |
  |     +-- measure data_time
  |     |
  |     +-- image = batch["image"].to(device)
  |     +-- mask  = batch["mask"].to(device)
  |     |
  |     +-- with autocast(enabled=amp):
  |     |     outputs    = model(image)
  |     |     seg_logits = outputs["seg_logits"]
  |     |     seg_loss   = loss(seg_logits, mask)
  |     |     total_loss = seg_loss
  |     |
  |     +-- scaled_total_loss = total_loss / accum_steps
  |     +-- grad_scaler.scale(scaled_total_loss).backward()
  |     |
  |     +-- if it is time for an optimizer step:
  |     |     if grad clipping is enabled:
  |     |       unscale gradients
  |     |       clip grad norm
  |     |     grad_scaler.step(optimizer)
  |     |     grad_scaler.update()
  |     |     optimizer.zero_grad()
  |     |
  |     +-- log step every logging.log_interval
  |
  +-- scheduler.step()
  |
  +-- validate if enabled and the interval matches
  |
  +-- finalize epoch
        |
        +-- merge train / val metrics
        +-- update best metric
        +-- save best.pth if improved
        +-- save last.pth if enabled
        +-- early stopping check
```

## 16. Loss Path

The current config uses:

```text
MONAI DiceCELoss
```

Build path:

```text
cfg.loss
  |
  v
lib/tools/loss.py::build_loss
  |
  v
DiceCELoss(sigmoid=True, include_background=True, lambda_dice=0.7, lambda_ce=0.3)
```

If `loss: null`:

```text
torch.nn.functional.binary_cross_entropy_with_logits(seg_logits, mask)
```

This fallback is meant for smoke tests when MONAI is temporarily missing from the environment. It is not recommended as the default for paper-grade training runs.

## 17. Validation Path

Validation function:

```text
SupervisedSegmentationTrainer._validate()
```

Flow:

```text
model.eval()
build validation metrics

for batch in val_loader:
  image -> device
  mask  -> device
  outputs, losses = _compute_losses(image, mask)
  update loss sums
  update metrics with outputs["seg_logits"]

average val loss
compute Dice / IoU
reset metric states
return val_metrics
```

Metric input handling:

```text
binary mode:
  pred   = sigmoid(logits) >= threshold
  target = target > 0

multiclass mode:
  pred   = argmax(logits)
  target = one-hot or class index
```

Current defaults:

```text
threshold = 0.5
metrics   = Dice, IoU
```
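A NumPy sketch of the binary-mode metrics: logits are thresholded after a sigmoid, targets are binarized with `> 0`, and Dice and IoU are computed. Assumptions:

- The metrics are pooled over every pixel in the batch. The project's metric objects may instead average per image.
- An empty prediction scored against an empty target counts as 1.0.
- `binary_dice_iou` is a hypothetical name.

```python
import numpy as np


def binary_dice_iou(logits: np.ndarray, target: np.ndarray, threshold: float = 0.5):
    """Dice and IoU for binary segmentation, pooled over every pixel in the batch."""
    prob = 1.0 / (1.0 + np.exp(-logits))       # sigmoid
    pred = prob >= threshold                    # binary prediction
    gt = target > 0                             # binary target
    inter = np.logical_and(pred, gt).sum()
    pred_sum, gt_sum = pred.sum(), gt.sum()
    if pred_sum + gt_sum == 0:                  # both empty: treat as a perfect match
        return 1.0, 1.0
    dice = 2.0 * inter / (pred_sum + gt_sum)
    iou = inter / (pred_sum + gt_sum - inter)   # union = |pred| + |gt| - |pred & gt|
    return float(dice), float(iou)
```

A probability threshold of 0.5 is the same as thresholding the raw logits at 0. For a single mask the two metrics are linked by `IoU = Dice / (2 - Dice)`, so they usually rank checkpoints the same way.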
## 18. Checkpoint Path

Checkpoint directory:

```text
cfg.checkpoint.dir
```

The default script overrides it to:

```text
outputs/experiments/supervised/
```

Saved files:

```text
best.pth
last.pth
```

Checkpoint contents:

```text
epoch
cfg
metrics
model state_dict
optimizer state_dict
scheduler state_dict
grad_scaler state_dict
best_metric
no_improve_epochs
```

Best-model criterion:

```text
monitor      = dice
monitor_mode = max
```

In other words, a higher `val_dice` is better.
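The bookkeeping behind `best_metric` and `no_improve_epochs` can be sketched as a small tracker. Assumptions:

- `monitor_mode = max` on `val_dice`.
- An epoch only counts as an improvement if it is strictly better.
- Early stopping uses a `patience` counter.

`patience` and the class name `BestTracker` are hypothetical, because the early-stopping settings are not listed in this document.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BestTracker:
    """Tracks best_metric / no_improve_epochs for best.pth saving and early stopping."""
    mode: str = "max"                      # monitor_mode: "max" for val_dice
    patience: int = 20                     # hypothetical early-stopping patience
    best_metric: Optional[float] = None    # None means no epoch has been seen yet
    no_improve_epochs: int = 0

    def update(self, value: float) -> bool:
        """Record one epoch's monitored value; return True if best.pth should be saved."""
        if self.best_metric is None:
            improved = True
        elif self.mode == "max":
            improved = value > self.best_metric
        else:
            improved = value < self.best_metric
        if improved:
            self.best_metric = value
            self.no_improve_epochs = 0
        else:
            self.no_improve_epochs += 1
        return improved

    @property
    def should_stop(self) -> bool:
        return self.no_improve_epochs >= self.patience
```

Both `best_metric` and `no_improve_epochs` are stored in the checkpoint. A resumed run can therefore rebuild the tracker from the saved values, and the early-stopping counter does not reset.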
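## 19. Logging and Performance Fields

Every `logging.log_interval` steps, the trainer prints:

```text
epoch
step
num_steps
data_time
iter_time
gpu_memory_mb
lr
train_total
train_seg
train_grad_norm
```

Meaning:

```text
data_time:
  time from the end of the previous step until the current batch is available.
  With num_workers > 0, the one-time worker startup cost occurs after END TRAINING SETUP.

iter_time:
  training compute time of the current step, including forward, loss, backward and optimizer step.

gpu_memory_mb:
  torch.cuda.max_memory_allocated (peak allocated memory, in MB).
```

Current measured reference:

```text
batch_size  = 8
image_size  = 256
ssm_backend = auto -> oflex
iter_time  ≈ 0.09 - 0.11 s / step
GPU memory ≈ 850 MB
```

## 20. End-to-End Flow from Input to Loss

```text
Batch from DataLoader
  |
  +-- image [B,3,256,256]
  +-- mask  [B,1,256,256]
  |
  v
image.to(cuda), mask.to(cuda)
  |
  v
autocast(enabled=True)
  |
  v
XNet2d(image)
  |
  +-- encoder_features = [E1,E2,E3,E4]
  |
  +-- bottleneck(E4)
  |
  +-- decoder_out, decoder_features, guides
  |
  +-- segmentation_head(decoder_out)
  |
  v
seg_logits [B,1,256,256]
  |
  v
DiceCELoss(seg_logits, mask)
  |
  v
total_loss
  |
  v
GradScaler.scale(total_loss).backward()
  |
  v
unscale + clip gradients
  |
  v
GradScaler.step(optimizer), GradScaler.update()
```

## 21. Key Commands

GPU environment check:

```bash
python -c "import sys, torch; print(sys.executable); print(torch.__version__); print(torch.cuda.is_available()); print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'no cuda')"
```

oflex check:

```bash
python -c "import torch; import selective_scan_cuda_oflex; print('oflex import OK')"
python -c "import torch; from lib.modules.lib_mamba import csms6s; print(csms6s.WITH_SELECTIVESCAN_OFLEX)"
```

Forward check:

```bash
python - <<'PY'
import torch
from lib.modules import XNet2d

model = XNet2d(in_channels=3, num_classes=1, ssm_backend="auto", ssm_forward_type="v3").cuda().eval()
x = torch.randn(1, 3, 128, 128, device="cuda")
with torch.no_grad():
    y = model(x)
print(sorted(y.keys()))
print(tuple(y["seg_logits"].shape))
PY
```

Short training run:

```bash
DATASET=BUSI \
EXTRA_SET_ARGS="train.epochs=1 train.batch_size=8 train.val_batch_size=8 logging.use_swanlab=false checkpoint.dir=outputs/validation/xnet_oflex_b8" \
bash tools/run_us_experiments.sh
```

Ablation with frequency refine disabled:

```bash
DATASET=BUSI \
EXTRA_SET_ARGS="train.epochs=1 train.batch_size=8 train.val_batch_size=8 model.use_frequency_refine=false logging.use_swanlab=false checkpoint.dir=outputs/validation/xnet_oflex_b8_no_freq" \
bash tools/run_us_experiments.sh
```

Summarize results:

```bash
bash tools/summarize_results.sh
sed -n '1,40p' results/experiment_summary.md
```

## 22. Recommended Experiment Roadmap

Phase 1: training pipeline stability

```text
BUSI smoke
BUSI batch size 8
BUSI no frequency refine
```

Phase 2: thyroid main line

```text
DDTI
TN3K
TG3K
DDTI -> TN3K / TN3K -> DDTI cross-dataset generalization
```

Phase 3: breast extension

```text
BUSI
BUS_UC
BUS-BRA
BUS-UCLM
```

Phase 4: core ablations

```text
use_wavelet_branch=false
use_frequency_refine=false
ssm_backend=torch
use_global_branch_stage1=true
encoder_depths=[2,2,3,2]
```

## 23. Current Scope and Caveats

1. This document describes the active XNet2d fully supervised main line.
2. The main training line optimizes only `seg_logits`.
3. `lib/sam2` is kept but does not take part in training.
4. `lib/SwinTransformer` is kept but does not take part in training.
5. On CUDA, `ssm_backend=auto` should take the `oflex` path, which is the default after the current speed optimizations.
6. `XFrequencyRefine2d` computes its FFT in float32 to avoid the ComplexHalf warning under AMP.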
7. With `num_workers > 0`, the first dataloader iteration may incur a one-time wait after `END TRAINING SETUP`.