Преглед изворни кода

feat(segmentation): 添加多个超声分割数据集的优化配置文件

添加了以下数据集的优化配置文件:
- BUS-BRA
- BUS_UC
- BUS-UCLM
- BUSI
- CCAUI
- DDTI
- OTU_2d
- TG3K
- TN3K

同时更新了模板文件中的swanlab日志配置,并修复了
wavelet变换中的数据类型处理问题以及训练器中的日志记录时机。
kekezack пре 1 недеља
родитељ
комит
66181add62

+ 127 - 0
configs/segmentation/optimized/sup_bus_bra_opt.yaml

@@ -0,0 +1,127 @@
+trainer:
+  name: supervised_segmentation
+
+train:
+  seed: 42
+  deterministic: false
+  epochs: 240
+  batch_size: 48
+  val_batch_size: 48
+  accum_steps: 1
+  amp: true
+  num_workers: 4
+  pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 2
+  device: cuda
+  grad_clip:
+    enabled: true
+    max_norm: 1.0
+    norm_type: 2.0
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+loss:
+  name: dicece
+  task_mode: binary
+  params:
+    include_background: true
+    lambda_dice: 0.7
+    lambda_ce: 0.3
+
+validation:
+  enabled: true
+  interval: 1
+  threshold: 0.5
+  early_stopping: true
+  early_stopping_patience: 55
+  early_stopping_min_delta: 0.0
+  metrics:
+    task_mode: binary
+    metrics:
+      - name: dice
+      - name: iou
+
+dataset:
+  name: ultrasound_sup_seg
+  dataset_name: BUS-BRA
+  root: data/BUS-BRA
+  split: train
+  split_file: null
+  val_split: val
+  val_split_file: null
+  image_size: [384, 384]
+  in_channels: 3
+  num_classes: 1
+
+model:
+  in_channels: 3
+  encoder_channels: [32, 64, 128, 192]
+  encoder_depths: [2, 2, 2, 2]
+  decoder_channels: [128, 64, 32]
+  stem_channels: 24
+  bottleneck_depth: 1
+  global_ratio: 2.0
+  wavelet_type: haar
+  wavelet_level: 1
+  use_wavelet_branch: true
+  use_global_branch_stage1: false
+  ssm_d_state: 16
+  ssm_forward_type: v3
+  ssm_backend: auto
+  use_frequency_refine: true
+  low_freq_radius_h: 0.25
+  low_freq_radius_w: 0.25
+  learnable_low_freq_radius: true
+  guide_mode: affine
+  out_channels: null
+
+optimizer:
+  name: adamw
+  lr: 1.2e-4
+  weight_decay: 0.05
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 15
+  params:
+    T_max: 230
+    eta_min: 1.0e-6
+
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_brightness_contrast: true
+    brightness_limit: 0.18
+    contrast_limit: 0.18
+    random_gaussian_noise: true
+    gaussian_noise_std: 0.03
+  val: {}
+
+checkpoint:
+  dir: outputs/experiments/optimized/BUS-BRA
+  save: true
+  save_last: true
+  monitor: dice
+  monitor_mode: max
+  resume: null
+  resume_strict: true
+  resume_training: true
+
+logging:
+  log_interval: 10
+  print_training_setup: true
+  use_swanlab: true
+  project: X_SSL_Net
+  experiment_name: xnet_sup_bus_bra_opt
+  swanlab_mode: cloud
+  swanlab_logdir: swanlog

+ 127 - 0
configs/segmentation/optimized/sup_bus_uc_opt.yaml

@@ -0,0 +1,127 @@
+trainer:
+  name: supervised_segmentation
+
+train:
+  seed: 42
+  deterministic: false
+  epochs: 260
+  batch_size: 48
+  val_batch_size: 48
+  accum_steps: 1
+  amp: true
+  num_workers: 4
+  pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 2
+  device: cuda
+  grad_clip:
+    enabled: true
+    max_norm: 1.0
+    norm_type: 2.0
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+loss:
+  name: dicece
+  task_mode: binary
+  params:
+    include_background: true
+    lambda_dice: 0.7
+    lambda_ce: 0.3
+
+validation:
+  enabled: true
+  interval: 1
+  threshold: 0.5
+  early_stopping: true
+  early_stopping_patience: 60
+  early_stopping_min_delta: 0.0
+  metrics:
+    task_mode: binary
+    metrics:
+      - name: dice
+      - name: iou
+
+dataset:
+  name: ultrasound_sup_seg
+  dataset_name: BUS_UC
+  root: data/BUS_UC
+  split: train
+  split_file: null
+  val_split: val
+  val_split_file: null
+  image_size: [384, 384]
+  in_channels: 3
+  num_classes: 1
+
+model:
+  in_channels: 3
+  encoder_channels: [32, 64, 128, 192]
+  encoder_depths: [2, 2, 2, 2]
+  decoder_channels: [128, 64, 32]
+  stem_channels: 24
+  bottleneck_depth: 1
+  global_ratio: 2.0
+  wavelet_type: haar
+  wavelet_level: 1
+  use_wavelet_branch: true
+  use_global_branch_stage1: false
+  ssm_d_state: 16
+  ssm_forward_type: v3
+  ssm_backend: auto
+  use_frequency_refine: true
+  low_freq_radius_h: 0.25
+  low_freq_radius_w: 0.25
+  learnable_low_freq_radius: true
+  guide_mode: affine
+  out_channels: null
+
+optimizer:
+  name: adamw
+  lr: 1.1e-4
+  weight_decay: 0.05
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 15
+  params:
+    T_max: 250
+    eta_min: 1.0e-6
+
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_brightness_contrast: true
+    brightness_limit: 0.18
+    contrast_limit: 0.18
+    random_gaussian_noise: true
+    gaussian_noise_std: 0.03
+  val: {}
+
+checkpoint:
+  dir: outputs/experiments/optimized/BUS_UC
+  save: true
+  save_last: true
+  monitor: dice
+  monitor_mode: max
+  resume: null
+  resume_strict: true
+  resume_training: true
+
+logging:
+  log_interval: 10
+  print_training_setup: true
+  use_swanlab: true
+  project: X_SSL_Net
+  experiment_name: xnet_sup_bus_uc_opt
+  swanlab_mode: cloud
+  swanlab_logdir: swanlog

+ 127 - 0
configs/segmentation/optimized/sup_bus_uclm_opt.yaml

@@ -0,0 +1,127 @@
+trainer:
+  name: supervised_segmentation
+
+train:
+  seed: 42
+  deterministic: false
+  epochs: 320
+  batch_size: 48
+  val_batch_size: 48
+  accum_steps: 1
+  amp: true
+  num_workers: 4
+  pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 2
+  device: cuda
+  grad_clip:
+    enabled: true
+    max_norm: 1.0
+    norm_type: 2.0
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+loss:
+  name: dicece
+  task_mode: binary
+  params:
+    include_background: true
+    lambda_dice: 0.7
+    lambda_ce: 0.3
+
+validation:
+  enabled: true
+  interval: 1
+  threshold: 0.5
+  early_stopping: true
+  early_stopping_patience: 80
+  early_stopping_min_delta: 0.0
+  metrics:
+    task_mode: binary
+    metrics:
+      - name: dice
+      - name: iou
+
+dataset:
+  name: ultrasound_sup_seg
+  dataset_name: BUS-UCLM
+  root: data/BUS-UCLM
+  split: train
+  split_file: null
+  val_split: val
+  val_split_file: null
+  image_size: [384, 384]
+  in_channels: 3
+  num_classes: 1
+
+model:
+  in_channels: 3
+  encoder_channels: [32, 64, 128, 192]
+  encoder_depths: [2, 2, 2, 2]
+  decoder_channels: [128, 64, 32]
+  stem_channels: 24
+  bottleneck_depth: 1
+  global_ratio: 2.0
+  wavelet_type: haar
+  wavelet_level: 1
+  use_wavelet_branch: true
+  use_global_branch_stage1: false
+  ssm_d_state: 16
+  ssm_forward_type: v3
+  ssm_backend: auto
+  use_frequency_refine: true
+  low_freq_radius_h: 0.25
+  low_freq_radius_w: 0.25
+  learnable_low_freq_radius: true
+  guide_mode: affine
+  out_channels: null
+
+optimizer:
+  name: adamw
+  lr: 8.0e-5
+  weight_decay: 0.05
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 15
+  params:
+    T_max: 310
+    eta_min: 1.0e-6
+
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_brightness_contrast: true
+    brightness_limit: 0.20
+    contrast_limit: 0.20
+    random_gaussian_noise: true
+    gaussian_noise_std: 0.035
+  val: {}
+
+checkpoint:
+  dir: outputs/experiments/optimized/BUS-UCLM
+  save: true
+  save_last: true
+  monitor: dice
+  monitor_mode: max
+  resume: null
+  resume_strict: true
+  resume_training: true
+
+logging:
+  log_interval: 10
+  print_training_setup: true
+  use_swanlab: true
+  project: X_SSL_Net
+  experiment_name: xnet_sup_bus_uclm_opt
+  swanlab_mode: cloud
+  swanlab_logdir: swanlog

+ 127 - 0
configs/segmentation/optimized/sup_busi_opt.yaml

@@ -0,0 +1,127 @@
+trainer:
+  name: supervised_segmentation
+
+train:
+  seed: 42
+  deterministic: false
+  epochs: 300
+  batch_size: 48
+  val_batch_size: 48
+  accum_steps: 1
+  amp: true
+  num_workers: 4
+  pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 2
+  device: cuda
+  grad_clip:
+    enabled: true
+    max_norm: 1.0
+    norm_type: 2.0
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+loss:
+  name: dicece
+  task_mode: binary
+  params:
+    include_background: true
+    lambda_dice: 0.7
+    lambda_ce: 0.3
+
+validation:
+  enabled: true
+  interval: 1
+  threshold: 0.5
+  early_stopping: true
+  early_stopping_patience: 70
+  early_stopping_min_delta: 0.0
+  metrics:
+    task_mode: binary
+    metrics:
+      - name: dice
+      - name: iou
+
+dataset:
+  name: ultrasound_sup_seg
+  dataset_name: BUSI
+  root: data/BUSI
+  split: train
+  split_file: null
+  val_split: val
+  val_split_file: null
+  image_size: [384, 384]
+  in_channels: 3
+  num_classes: 1
+
+model:
+  in_channels: 3
+  encoder_channels: [32, 64, 128, 192]
+  encoder_depths: [2, 2, 2, 2]
+  decoder_channels: [128, 64, 32]
+  stem_channels: 24
+  bottleneck_depth: 1
+  global_ratio: 2.0
+  wavelet_type: haar
+  wavelet_level: 1
+  use_wavelet_branch: true
+  use_global_branch_stage1: false
+  ssm_d_state: 16
+  ssm_forward_type: v3
+  ssm_backend: auto
+  use_frequency_refine: true
+  low_freq_radius_h: 0.25
+  low_freq_radius_w: 0.25
+  learnable_low_freq_radius: true
+  guide_mode: affine
+  out_channels: null
+
+optimizer:
+  name: adamw
+  lr: 8.0e-5
+  weight_decay: 0.05
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 15
+  params:
+    T_max: 290
+    eta_min: 1.0e-6
+
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_brightness_contrast: true
+    brightness_limit: 0.20
+    contrast_limit: 0.20
+    random_gaussian_noise: true
+    gaussian_noise_std: 0.035
+  val: {}
+
+checkpoint:
+  dir: outputs/experiments/optimized/BUSI
+  save: true
+  save_last: true
+  monitor: dice
+  monitor_mode: max
+  resume: null
+  resume_strict: true
+  resume_training: true
+
+logging:
+  log_interval: 10
+  print_training_setup: true
+  use_swanlab: true
+  project: X_SSL_Net
+  experiment_name: xnet_sup_busi_opt
+  swanlab_mode: cloud
+  swanlab_logdir: swanlog

+ 127 - 0
configs/segmentation/optimized/sup_ccaui_opt.yaml

@@ -0,0 +1,127 @@
+trainer:
+  name: supervised_segmentation
+
+train:
+  seed: 42
+  deterministic: false
+  epochs: 260
+  batch_size: 48
+  val_batch_size: 48
+  accum_steps: 1
+  amp: true
+  num_workers: 4
+  pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 2
+  device: cuda
+  grad_clip:
+    enabled: true
+    max_norm: 1.0
+    norm_type: 2.0
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+loss:
+  name: dicece
+  task_mode: binary
+  params:
+    include_background: true
+    lambda_dice: 0.7
+    lambda_ce: 0.3
+
+validation:
+  enabled: true
+  interval: 1
+  threshold: 0.5
+  early_stopping: true
+  early_stopping_patience: 55
+  early_stopping_min_delta: 0.0
+  metrics:
+    task_mode: binary
+    metrics:
+      - name: dice
+      - name: iou
+
+dataset:
+  name: ultrasound_sup_seg
+  dataset_name: CCAUI
+  root: data/CCAUI
+  split: train
+  split_file: null
+  val_split: val
+  val_split_file: null
+  image_size: [384, 384]
+  in_channels: 3
+  num_classes: 1
+
+model:
+  in_channels: 3
+  encoder_channels: [32, 64, 128, 192]
+  encoder_depths: [2, 2, 2, 2]
+  decoder_channels: [128, 64, 32]
+  stem_channels: 24
+  bottleneck_depth: 1
+  global_ratio: 2.0
+  wavelet_type: haar
+  wavelet_level: 1
+  use_wavelet_branch: true
+  use_global_branch_stage1: false
+  ssm_d_state: 16
+  ssm_forward_type: v3
+  ssm_backend: auto
+  use_frequency_refine: true
+  low_freq_radius_h: 0.25
+  low_freq_radius_w: 0.25
+  learnable_low_freq_radius: true
+  guide_mode: affine
+  out_channels: null
+
+optimizer:
+  name: adamw
+  lr: 1.1e-4
+  weight_decay: 0.05
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 15
+  params:
+    T_max: 250
+    eta_min: 1.0e-6
+
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_brightness_contrast: true
+    brightness_limit: 0.15
+    contrast_limit: 0.15
+    random_gaussian_noise: true
+    gaussian_noise_std: 0.025
+  val: {}
+
+checkpoint:
+  dir: outputs/experiments/optimized/CCAUI
+  save: true
+  save_last: true
+  monitor: dice
+  monitor_mode: max
+  resume: null
+  resume_strict: true
+  resume_training: true
+
+logging:
+  log_interval: 10
+  print_training_setup: true
+  use_swanlab: true
+  project: X_SSL_Net
+  experiment_name: xnet_sup_ccaui_opt
+  swanlab_mode: cloud
+  swanlab_logdir: swanlog

+ 127 - 0
configs/segmentation/optimized/sup_ddti_opt.yaml

@@ -0,0 +1,127 @@
+trainer:
+  name: supervised_segmentation
+
+train:
+  seed: 42
+  deterministic: false
+  epochs: 320
+  batch_size: 48
+  val_batch_size: 48
+  accum_steps: 1
+  amp: true
+  num_workers: 4
+  pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 2
+  device: cuda
+  grad_clip:
+    enabled: true
+    max_norm: 1.0
+    norm_type: 2.0
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+loss:
+  name: dicece
+  task_mode: binary
+  params:
+    include_background: true
+    lambda_dice: 0.7
+    lambda_ce: 0.3
+
+validation:
+  enabled: true
+  interval: 1
+  threshold: 0.5
+  early_stopping: true
+  early_stopping_patience: 80
+  early_stopping_min_delta: 0.0
+  metrics:
+    task_mode: binary
+    metrics:
+      - name: dice
+      - name: iou
+
+dataset:
+  name: ultrasound_sup_seg
+  dataset_name: DDTI
+  root: data/DDTI
+  split: train
+  split_file: null
+  val_split: val
+  val_split_file: null
+  image_size: [384, 384]
+  in_channels: 3
+  num_classes: 1
+
+model:
+  in_channels: 3
+  encoder_channels: [32, 64, 128, 192]
+  encoder_depths: [2, 2, 2, 2]
+  decoder_channels: [128, 64, 32]
+  stem_channels: 24
+  bottleneck_depth: 1
+  global_ratio: 2.0
+  wavelet_type: haar
+  wavelet_level: 1
+  use_wavelet_branch: true
+  use_global_branch_stage1: false
+  ssm_d_state: 16
+  ssm_forward_type: v3
+  ssm_backend: auto
+  use_frequency_refine: true
+  low_freq_radius_h: 0.25
+  low_freq_radius_w: 0.25
+  learnable_low_freq_radius: true
+  guide_mode: affine
+  out_channels: null
+
+optimizer:
+  name: adamw
+  lr: 8.0e-5
+  weight_decay: 0.05
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 15
+  params:
+    T_max: 310
+    eta_min: 1.0e-6
+
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_brightness_contrast: true
+    brightness_limit: 0.18
+    contrast_limit: 0.18
+    random_gaussian_noise: true
+    gaussian_noise_std: 0.03
+  val: {}
+
+checkpoint:
+  dir: outputs/experiments/optimized/DDTI
+  save: true
+  save_last: true
+  monitor: dice
+  monitor_mode: max
+  resume: null
+  resume_strict: true
+  resume_training: true
+
+logging:
+  log_interval: 10
+  print_training_setup: true
+  use_swanlab: true
+  project: X_SSL_Net
+  experiment_name: xnet_sup_ddti_opt
+  swanlab_mode: cloud
+  swanlab_logdir: swanlog

+ 127 - 0
configs/segmentation/optimized/sup_otu_2d_opt.yaml

@@ -0,0 +1,127 @@
+trainer:
+  name: supervised_segmentation
+
+train:
+  seed: 42
+  deterministic: false
+  epochs: 220
+  batch_size: 48
+  val_batch_size: 48
+  accum_steps: 1
+  amp: true
+  num_workers: 4
+  pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 2
+  device: cuda
+  grad_clip:
+    enabled: true
+    max_norm: 1.0
+    norm_type: 2.0
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+loss:
+  name: dicece
+  task_mode: binary
+  params:
+    include_background: true
+    lambda_dice: 0.7
+    lambda_ce: 0.3
+
+validation:
+  enabled: true
+  interval: 1
+  threshold: 0.5
+  early_stopping: true
+  early_stopping_patience: 50
+  early_stopping_min_delta: 0.0
+  metrics:
+    task_mode: binary
+    metrics:
+      - name: dice
+      - name: iou
+
+dataset:
+  name: ultrasound_sup_seg
+  dataset_name: OTU_2d
+  root: data/OTU_2d
+  split: train
+  split_file: null
+  val_split: val
+  val_split_file: null
+  image_size: [384, 384]
+  in_channels: 3
+  num_classes: 1
+
+model:
+  in_channels: 3
+  encoder_channels: [32, 64, 128, 192]
+  encoder_depths: [2, 2, 2, 2]
+  decoder_channels: [128, 64, 32]
+  stem_channels: 24
+  bottleneck_depth: 1
+  global_ratio: 2.0
+  wavelet_type: haar
+  wavelet_level: 1
+  use_wavelet_branch: true
+  use_global_branch_stage1: false
+  ssm_d_state: 16
+  ssm_forward_type: v3
+  ssm_backend: auto
+  use_frequency_refine: true
+  low_freq_radius_h: 0.25
+  low_freq_radius_w: 0.25
+  learnable_low_freq_radius: true
+  guide_mode: affine
+  out_channels: null
+
+optimizer:
+  name: adamw
+  lr: 1.2e-4
+  weight_decay: 0.05
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 15
+  params:
+    T_max: 210
+    eta_min: 1.0e-6
+
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_brightness_contrast: true
+    brightness_limit: 0.15
+    contrast_limit: 0.15
+    random_gaussian_noise: true
+    gaussian_noise_std: 0.025
+  val: {}
+
+checkpoint:
+  dir: outputs/experiments/optimized/OTU_2d
+  save: true
+  save_last: true
+  monitor: dice
+  monitor_mode: max
+  resume: null
+  resume_strict: true
+  resume_training: true
+
+logging:
+  log_interval: 10
+  print_training_setup: true
+  use_swanlab: true
+  project: X_SSL_Net
+  experiment_name: xnet_sup_otu_2d_opt
+  swanlab_mode: cloud
+  swanlab_logdir: swanlog

+ 127 - 0
configs/segmentation/optimized/sup_tg3k_opt.yaml

@@ -0,0 +1,127 @@
+trainer:
+  name: supervised_segmentation
+
+train:
+  seed: 42
+  deterministic: false
+  epochs: 220
+  batch_size: 48
+  val_batch_size: 48
+  accum_steps: 1
+  amp: true
+  num_workers: 4
+  pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 2
+  device: cuda
+  grad_clip:
+    enabled: true
+    max_norm: 1.0
+    norm_type: 2.0
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+loss:
+  name: dicece
+  task_mode: binary
+  params:
+    include_background: true
+    lambda_dice: 0.7
+    lambda_ce: 0.3
+
+validation:
+  enabled: true
+  interval: 1
+  threshold: 0.5
+  early_stopping: true
+  early_stopping_patience: 50
+  early_stopping_min_delta: 0.0
+  metrics:
+    task_mode: binary
+    metrics:
+      - name: dice
+      - name: iou
+
+dataset:
+  name: ultrasound_sup_seg
+  dataset_name: TG3K
+  root: data/TG3K
+  split: train
+  split_file: null
+  val_split: val
+  val_split_file: null
+  image_size: [384, 384]
+  in_channels: 3
+  num_classes: 1
+
+model:
+  in_channels: 3
+  encoder_channels: [32, 64, 128, 192]
+  encoder_depths: [2, 2, 2, 2]
+  decoder_channels: [128, 64, 32]
+  stem_channels: 24
+  bottleneck_depth: 1
+  global_ratio: 2.0
+  wavelet_type: haar
+  wavelet_level: 1
+  use_wavelet_branch: true
+  use_global_branch_stage1: false
+  ssm_d_state: 16
+  ssm_forward_type: v3
+  ssm_backend: auto
+  use_frequency_refine: true
+  low_freq_radius_h: 0.25
+  low_freq_radius_w: 0.25
+  learnable_low_freq_radius: true
+  guide_mode: affine
+  out_channels: null
+
+optimizer:
+  name: adamw
+  lr: 1.2e-4
+  weight_decay: 0.05
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 15
+  params:
+    T_max: 210
+    eta_min: 1.0e-6
+
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_brightness_contrast: true
+    brightness_limit: 0.15
+    contrast_limit: 0.15
+    random_gaussian_noise: true
+    gaussian_noise_std: 0.025
+  val: {}
+
+checkpoint:
+  dir: outputs/experiments/optimized/TG3K
+  save: true
+  save_last: true
+  monitor: dice
+  monitor_mode: max
+  resume: null
+  resume_strict: true
+  resume_training: true
+
+logging:
+  log_interval: 10
+  print_training_setup: true
+  use_swanlab: true
+  project: X_SSL_Net
+  experiment_name: xnet_sup_tg3k_opt
+  swanlab_mode: cloud
+  swanlab_logdir: swanlog

+ 127 - 0
configs/segmentation/optimized/sup_tn3k_opt.yaml

@@ -0,0 +1,127 @@
+trainer:
+  name: supervised_segmentation
+
+train:
+  seed: 42
+  deterministic: false
+  epochs: 220
+  batch_size: 48
+  val_batch_size: 48
+  accum_steps: 1
+  amp: true
+  num_workers: 4
+  pin_memory: true
+  persistent_workers: true
+  prefetch_factor: 2
+  device: cuda
+  grad_clip:
+    enabled: true
+    max_norm: 1.0
+    norm_type: 2.0
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+loss:
+  name: dicece
+  task_mode: binary
+  params:
+    include_background: true
+    lambda_dice: 0.7
+    lambda_ce: 0.3
+
+validation:
+  enabled: true
+  interval: 1
+  threshold: 0.5
+  early_stopping: true
+  early_stopping_patience: 50
+  early_stopping_min_delta: 0.0
+  metrics:
+    task_mode: binary
+    metrics:
+      - name: dice
+      - name: iou
+
+dataset:
+  name: ultrasound_sup_seg
+  dataset_name: TN3K
+  root: data/TN3K
+  split: trainval
+  split_file: null
+  val_split: test
+  val_split_file: null
+  image_size: [384, 384]
+  in_channels: 3
+  num_classes: 1
+
+model:
+  in_channels: 3
+  encoder_channels: [32, 64, 128, 192]
+  encoder_depths: [2, 2, 2, 2]
+  decoder_channels: [128, 64, 32]
+  stem_channels: 24
+  bottleneck_depth: 1
+  global_ratio: 2.0
+  wavelet_type: haar
+  wavelet_level: 1
+  use_wavelet_branch: true
+  use_global_branch_stage1: false
+  ssm_d_state: 16
+  ssm_forward_type: v3
+  ssm_backend: auto
+  use_frequency_refine: true
+  low_freq_radius_h: 0.25
+  low_freq_radius_w: 0.25
+  learnable_low_freq_radius: true
+  guide_mode: affine
+  out_channels: null
+
+optimizer:
+  name: adamw
+  lr: 1.2e-4
+  weight_decay: 0.05
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 15
+  params:
+    T_max: 210
+    eta_min: 1.0e-6
+
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_brightness_contrast: true
+    brightness_limit: 0.15
+    contrast_limit: 0.15
+    random_gaussian_noise: true
+    gaussian_noise_std: 0.025
+  val: {}
+
+checkpoint:
+  dir: outputs/experiments/optimized/TN3K
+  save: true
+  save_last: true
+  monitor: dice
+  monitor_mode: max
+  resume: null
+  resume_strict: true
+  resume_training: true
+
+logging:
+  log_interval: 10
+  print_training_setup: true
+  use_swanlab: true
+  project: X_SSL_Net
+  experiment_name: xnet_sup_tn3k_opt
+  swanlab_mode: cloud
+  swanlab_logdir: swanlog

+ 2 - 1
configs/segmentation/train_sup_us_template.yaml

@@ -130,4 +130,5 @@ logging:
   use_swanlab: true
   project: X_SSL_Net
   experiment_name: xnet_sup_busi
-  swanlab_mode: null
+  swanlab_mode: cloud
+  swanlab_logdir: swanlog

+ 2 - 1
configs/segmentation/us_exp_sup_busi.yaml

@@ -120,4 +120,5 @@ logging:
   use_swanlab: true
   project: X_SSL_Net
   experiment_name: xnet_sup_busi
-  swanlab_mode: null
+  swanlab_mode: cloud
+  swanlab_logdir: swanlog

+ 2 - 1
configs/segmentation/us_exp_sup_busi_ablation.yaml

@@ -120,4 +120,5 @@ logging:
   use_swanlab: true
   project: X_SSL_Net
   experiment_name: xnet_sup_busi_no_wavelet_no_freq
-  swanlab_mode: null
+  swanlab_mode: cloud
+  swanlab_logdir: swanlog

+ 4 - 1
lib/modules/lib_mamba/vmamba.py

@@ -11,7 +11,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 from timm.layers import DropPath, trunc_normal_
-from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count
 
 try:
     from .csm_triton import cross_scan_fn, cross_merge_fn
@@ -1503,6 +1502,8 @@ class VSSM(nn.Module):
         return x
 
     def flops(self, shape=(3, 224, 224), verbose=True):
+        from fvcore.nn import flop_count, parameter_count
+
         # shape = self.__input_shape__[1:]
         supported_ops={
             "aten::silu": None, # as relu is in _IGNORED_OPS
@@ -1799,6 +1800,8 @@ def vmamba_base_m2():
 
 
 if __name__ == "__main__":
+    from fvcore.nn import parameter_count
+
     model_ref = vmamba_tiny_s1l8()
 
     model = VSSM(

+ 13 - 6
lib/modules/xnet_2d.py

@@ -75,19 +75,26 @@ class XWaveletTransform2d(nn.Module):
         self.wavelet_level = wavelet_level
 
     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        coeffs = ptwt.wavedec2(x, self.wavelet_type, level=self.wavelet_level)
+        original_dtype = x.dtype
+        with torch.autocast(device_type=x.device.type, enabled=False):
+            coeffs = ptwt.wavedec2(
+                x.float(), self.wavelet_type, level=self.wavelet_level
+            )
         ll = coeffs[0]
         high_parts = coeffs[1]
         high = torch.cat(high_parts, dim=1)
-        return ll, high
+        return ll.to(original_dtype), high.to(original_dtype)
 
     def inverse(
         self, ll: torch.Tensor, high: torch.Tensor, output_size: tuple[int, int]
     ) -> torch.Tensor:
-        lh, hl, hh = torch.chunk(high, 3, dim=1)
-        coeffs = [ll, (lh, hl, hh)]
-        x = ptwt.waverec2(coeffs, self.wavelet_type)
-        return x[:, :, : output_size[0], : output_size[1]]
+        original_dtype = ll.dtype
+        with torch.autocast(device_type=ll.device.type, enabled=False):
+            lh, hl, hh = torch.chunk(high.float(), 3, dim=1)
+            coeffs = [ll.float(), (lh, hl, hh)]
+            x = ptwt.waverec2(coeffs, self.wavelet_type)
+        x = x[:, :, : output_size[0], : output_size[1]]
+        return x.to(original_dtype)
 
 
 class XWaveletBranch2d(nn.Module):

+ 1 - 1
lib/trainers/base.py

@@ -371,6 +371,7 @@ class BaseTrainer(ABC):
             project=logging_cfg.get("project", "X_SSL_Net"),
             name=run_name,
             config=self.cfg,
+            logdir=logging_cfg.get("swanlab_logdir", "swanlog"),
             mode=logging_cfg.get("swanlab_mode"),
         )
 
@@ -563,7 +564,6 @@ class BaseTrainer(ABC):
         )
         if "lr" in snapshot:
             log_metrics[f"{prefix}/lr"] = float(snapshot["lr"])
-        self._log_metrics(log_metrics, step=epoch * max(1, num_steps) + step)
 
     @staticmethod
     def _average_metric_sums(metric_sums: dict[str, float], steps: int) -> dict[str, float]:

+ 261 - 0
scripts/probe_xnet_memory.py

@@ -0,0 +1,261 @@
+from __future__ import annotations
+
+import argparse
+import gc
+import sys
+import time
+from pathlib import Path
+from typing import Any
+
+import torch
+
+ROOT_DIR = Path(__file__).resolve().parents[1]
+if str(ROOT_DIR) not in sys.path:
+    sys.path.insert(0, str(ROOT_DIR))
+
+from lib.modules import XNet2d
+from lib.tools import build_loss, build_optimizer
+from lib.utils.config import apply_dotlist_overrides, load_yaml_config
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Probe XNet2d CUDA memory with synthetic segmentation batches."
+    )
+    parser.add_argument(
+        "--config",
+        default="configs/segmentation/train_sup_us_template.yaml",
+        help="YAML config path.",
+    )
+    parser.add_argument(
+        "--batch-sizes",
+        nargs="+",
+        type=int,
+        default=[4, 6, 8],
+        help="Batch sizes to probe.",
+    )
+    parser.add_argument(
+        "--image-size",
+        nargs=2,
+        type=int,
+        default=None,
+        metavar=("H", "W"),
+        help="Override dataset.image_size.",
+    )
+    parser.add_argument(
+        "--amp",
+        action=argparse.BooleanOptionalAction,
+        default=None,
+        help="Override train.amp.",
+    )
+    parser.add_argument(
+        "--device",
+        default="cuda",
+        help="Device to probe. CUDA is required for memory numbers.",
+    )
+    parser.add_argument(
+        "--warmup",
+        action="store_true",
+        help="Run one unmeasured warmup step before measuring each batch size.",
+    )
+    parser.add_argument(
+        "--set",
+        nargs="*",
+        default=None,
+        help="Override config values with key=value pairs.",
+    )
+    return parser.parse_args()
+
+
+def build_model(cfg: dict[str, Any], device: torch.device) -> XNet2d:
+    dataset_cfg = cfg["dataset"]
+    model_cfg = cfg["model"]
+    return XNet2d(
+        in_channels=int(
+            model_cfg.get("in_channels", dataset_cfg.get("in_channels", 3))
+        ),
+        num_classes=int(dataset_cfg["num_classes"]),
+        encoder_channels=tuple(model_cfg.get("encoder_channels", (32, 64, 128, 192))),
+        encoder_depths=tuple(model_cfg.get("encoder_depths", (2, 2, 2, 2))),
+        decoder_channels=tuple(model_cfg.get("decoder_channels", (128, 64, 32))),
+        stem_channels=int(model_cfg.get("stem_channels", 24)),
+        bottleneck_depth=int(model_cfg.get("bottleneck_depth", 1)),
+        global_ratio=float(model_cfg.get("global_ratio", 2.0)),
+        wavelet_type=str(model_cfg.get("wavelet_type", "haar")),
+        wavelet_level=int(model_cfg.get("wavelet_level", 1)),
+        use_wavelet_branch=bool(model_cfg.get("use_wavelet_branch", True)),
+        use_global_branch_stage1=bool(model_cfg.get("use_global_branch_stage1", False)),
+        ssm_d_state=int(model_cfg.get("ssm_d_state", 16)),
+        ssm_forward_type=str(model_cfg.get("ssm_forward_type", "v3")),
+        ssm_backend=str(model_cfg.get("ssm_backend", "auto")),
+        use_frequency_refine=bool(model_cfg.get("use_frequency_refine", True)),
+        low_freq_radius_h=float(model_cfg.get("low_freq_radius_h", 0.25)),
+        low_freq_radius_w=float(model_cfg.get("low_freq_radius_w", 0.25)),
+        learnable_low_freq_radius=bool(
+            model_cfg.get("learnable_low_freq_radius", True)
+        ),
+        guide_mode=str(model_cfg.get("guide_mode", "affine")),
+        out_channels=model_cfg.get("out_channels"),
+    ).to(device)
+
+
+def release_cuda() -> None:
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+
+
+def make_batch(
+    *,
+    batch_size: int,
+    in_channels: int,
+    num_classes: int,
+    image_size: tuple[int, int],
+    device: torch.device,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    height, width = image_size
+    image = torch.randn(batch_size, in_channels, height, width, device=device)
+    if num_classes == 1:
+        mask = torch.randint(
+            0, 2, (batch_size, 1, height, width), device=device
+        ).float()
+    else:
+        mask = torch.randint(
+            0, num_classes, (batch_size, 1, height, width), device=device
+        )
+    return image, mask
+
+
+def run_step(
+    *,
+    cfg: dict[str, Any],
+    batch_size: int,
+    image_size: tuple[int, int],
+    device: torch.device,
+    amp_enabled: bool,
+) -> dict[str, float]:
+    release_cuda()
+    model = build_model(cfg, device)
+    model.train()
+    optimizer = build_optimizer(model, cfg["optimizer"])
+    loss_fn = build_loss(cfg["loss"])
+
+    dataset_cfg = cfg["dataset"]
+    in_channels = int(
+        dataset_cfg.get("in_channels", cfg["model"].get("in_channels", 3))
+    )
+    num_classes = int(dataset_cfg["num_classes"])
+    image, mask = make_batch(
+        batch_size=batch_size,
+        in_channels=in_channels,
+        num_classes=num_classes,
+        image_size=image_size,
+        device=device,
+    )
+
+    torch.cuda.synchronize(device)
+    torch.cuda.reset_peak_memory_stats(device)
+    start = time.perf_counter()
+    optimizer.zero_grad(set_to_none=True)
+    with torch.autocast(device_type=device.type, enabled=amp_enabled):
+        outputs = model(image)
+        loss = loss_fn(outputs["seg_logits"], mask)
+    loss.backward()
+    optimizer.step()
+    torch.cuda.synchronize(device)
+    elapsed = time.perf_counter() - start
+
+    result = {
+        "loss": float(loss.detach().cpu()),
+        "seconds": elapsed,
+        "allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2),
+        "reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2),
+    }
+    del model, optimizer, loss_fn, image, mask, outputs, loss
+    release_cuda()
+    return result
+
+
+def print_header(
+    cfg: dict[str, Any],
+    image_size: tuple[int, int],
+    device: torch.device,
+    amp_enabled: bool,
+) -> None:
+    model_cfg = cfg["model"]
+    print("XNet2d memory probe")
+    print(
+        f"device: {torch.cuda.get_device_name(device) if device.type == 'cuda' else device}"
+    )
+    print(f"image_size: {list(image_size)}")
+    print(f"amp: {amp_enabled}")
+    print(f"encoder_channels: {model_cfg.get('encoder_channels')}")
+    print(f"encoder_depths: {model_cfg.get('encoder_depths')}")
+    print(f"global_ratio: {model_cfg.get('global_ratio')}")
+    print()
+    print(
+        f"{'batch':>5}  {'status':>8}  {'allocated':>12}  {'reserved':>12}  "
+        f"{'seconds':>8}  {'loss/error'}"
+    )
+    print("-" * 78)
+
+
+def main() -> None:
+    args = parse_args()
+    if args.device == "cuda" and not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available.")
+
+    cfg_path = (
+        ROOT_DIR / args.config
+        if not Path(args.config).is_absolute()
+        else Path(args.config)
+    )
+    cfg = apply_dotlist_overrides(load_yaml_config(cfg_path), args.set)
+    device = torch.device(args.device)
+    image_size = tuple(args.image_size or cfg["dataset"]["image_size"])
+    amp_enabled = bool(
+        cfg.get("train", {}).get("amp", False) if args.amp is None else args.amp
+    )
+
+    print_header(cfg, image_size, device, amp_enabled)
+    for batch_size in args.batch_sizes:
+        try:
+            if args.warmup:
+                run_step(
+                    cfg=cfg,
+                    batch_size=batch_size,
+                    image_size=image_size,
+                    device=device,
+                    amp_enabled=amp_enabled,
+                )
+            result = run_step(
+                cfg=cfg,
+                batch_size=batch_size,
+                image_size=image_size,
+                device=device,
+                amp_enabled=amp_enabled,
+            )
+            print(
+                f"{batch_size:>5}  {'ok':>8}  "
+                f"{result['allocated_mb']:>9.1f} MB  "
+                f"{result['reserved_mb']:>9.1f} MB  "
+                f"{result['seconds']:>8.2f}  "
+                f"loss={result['loss']:.6f}"
+            )
+        except torch.cuda.OutOfMemoryError as exc:
+            release_cuda()
+            print(
+                f"{batch_size:>5}  {'OOM':>8}  "
+                f"{'-':>12}  {'-':>12}  {'-':>8}  {str(exc).splitlines()[0]}"
+            )
+        except Exception as exc:
+            release_cuda()
+            print(
+                f"{batch_size:>5}  {'ERROR':>8}  "
+                f"{'-':>12}  {'-':>12}  {'-':>8}  {type(exc).__name__}: {exc}"
+            )
+
+
+if __name__ == "__main__":
+    main()

+ 25 - 1
tests/test_xnet_2d.py

@@ -1,12 +1,36 @@
 from __future__ import annotations
 
+import importlib
+import sys
+import warnings
+
 import torch
 from torch import nn
 
-from lib.modules.xnet_2d import XNet2d, XTEB2d
+
+def test_importing_xnet2d_does_not_emit_deprecation_warnings() -> None:
+    modules_to_clear = [
+        name
+        for name in sys.modules
+        if name == "lib.modules.xnet_2d" or name.startswith("lib.modules.lib_mamba")
+    ]
+    for name in modules_to_clear:
+        sys.modules.pop(name, None)
+
+    with warnings.catch_warnings(record=True) as caught:
+        warnings.simplefilter("always", DeprecationWarning)
+        importlib.import_module("lib.modules.xnet_2d")
+
+    assert not [
+        warning
+        for warning in caught
+        if issubclass(warning.category, DeprecationWarning)
+    ]
 
 
 def test_xnet2d_forward_preserves_segmentation_shape() -> None:
+    from lib.modules.xnet_2d import XNet2d, XTEB2d
+
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = XNet2d(
         in_channels=3,

+ 136 - 0
tools/export_swanlab_backup.py

@@ -0,0 +1,136 @@
+from __future__ import annotations
+
+import argparse
+import csv
+import json
+import sys
+from collections import defaultdict
+from pathlib import Path
+from typing import Any
+
+from swanlab.data.porter import DataPorter
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Export SwanLab backup.swanlab records to readable CSV/JSONL files."
+    )
+    parser.add_argument(
+        "run_dir",
+        help="SwanLab run directory, e.g. swanlog/run-20260530_115103-...",
+    )
+    parser.add_argument(
+        "--out-dir",
+        default=None,
+        help="Output directory. Defaults to <run_dir>/exported.",
+    )
+    parser.add_argument(
+        "--exclude-system",
+        action="store_true",
+        help="Exclude SwanLab system metrics whose keys start with __swanlab__.",
+    )
+    return parser.parse_args()
+
+
+def scalar_to_row(scalar: Any) -> dict[str, Any]:
+    metric = scalar.metric or {}
+    return {
+        "key": scalar.key,
+        "step": scalar.step,
+        "epoch": scalar.epoch,
+        "index": metric.get("index"),
+        "data": metric.get("data"),
+        "create_time": metric.get("create_time"),
+    }
+
+
+def log_to_row(log: Any) -> dict[str, Any]:
+    return {
+        "level": log.level,
+        "message": log.message,
+        "create_time": log.create_time,
+        "epoch": log.epoch,
+    }
+
+
+def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
+    with path.open("w", encoding="utf-8") as handle:
+        for row in rows:
+            handle.write(json.dumps(row, ensure_ascii=False, sort_keys=True) + "\n")
+
+
+def write_csv(path: Path, rows: list[dict[str, Any]], fieldnames: list[str]) -> None:
+    with path.open("w", encoding="utf-8", newline="") as handle:
+        writer = csv.DictWriter(handle, fieldnames=fieldnames)
+        writer.writeheader()
+        writer.writerows(rows)
+
+
+def build_epoch_table(scalar_rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    grouped: dict[int, dict[str, Any]] = defaultdict(dict)
+    for row in scalar_rows:
+        epoch = row.get("step")
+        key = row.get("key")
+        if epoch is None or key is None:
+            continue
+        if str(key).startswith("__swanlab__"):
+            continue
+        grouped[int(epoch)]["epoch"] = int(epoch)
+        grouped[int(epoch)][str(key)] = row.get("data")
+    return [grouped[epoch] for epoch in sorted(grouped)]
+
+
+def main() -> None:
+    args = parse_args()
+    run_dir = Path(args.run_dir)
+    if not (run_dir / "backup.swanlab").exists():
+        raise FileNotFoundError(f"backup.swanlab not found under {run_dir}")
+
+    out_dir = Path(args.out_dir) if args.out_dir is not None else run_dir / "exported"
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    with DataPorter().open_for_sync(str(run_dir), backend="python") as porter:
+        project, experiment = porter.parse()
+        scalar_rows = [scalar_to_row(scalar) for scalar in porter._scalars]
+        log_rows = [log_to_row(log) for log in porter._logs]
+
+    if args.exclude_system:
+        scalar_rows = [
+            row for row in scalar_rows if not str(row["key"]).startswith("__swanlab__")
+        ]
+
+    scalar_fields = ["key", "step", "epoch", "index", "data", "create_time"]
+    log_fields = ["level", "message", "create_time", "epoch"]
+    write_csv(out_dir / "scalars.csv", scalar_rows, scalar_fields)
+    write_jsonl(out_dir / "scalars.jsonl", scalar_rows)
+    write_csv(out_dir / "logs.csv", log_rows, log_fields)
+    write_jsonl(out_dir / "logs.jsonl", log_rows)
+
+    epoch_rows = build_epoch_table(scalar_rows)
+    if epoch_rows:
+        fields = ["epoch"]
+        for row in epoch_rows:
+            for key in row:
+                if key not in fields:
+                    fields.append(key)
+        write_csv(out_dir / "epoch_metrics.csv", epoch_rows, fields)
+        write_jsonl(out_dir / "epoch_metrics.jsonl", epoch_rows)
+
+    summary = {
+        "project": project.name,
+        "experiment_id": experiment.id,
+        "experiment_name": experiment.name,
+        "num_scalars": len(scalar_rows),
+        "num_logs": len(log_rows),
+        "num_epoch_rows": len(epoch_rows),
+        "out_dir": str(out_dir),
+    }
+    print(summary)
+
+
+if __name__ == "__main__":
+    try:
+        main()
+    except Exception as exc:
+        print(f"error: {type(exc).__name__}: {exc}", file=sys.stderr)
+        raise

+ 99 - 0
tools/run_optimized_supervised.sh

@@ -0,0 +1,99 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "$ROOT_DIR"
+
+PYTHON="/opt/miniforge3/envs/xnet_mamba/bin/python"
+DATASET="${DATASET:-BUSI}"
+RUN_ALL="${RUN_ALL:-0}"
+SEED="${SEED:-42}"
+EXTRA_SET_ARGS="${EXTRA_SET_ARGS:-}"
+
+dataset_root() {
+  case "$1" in
+    "BUSI") echo "data/BUSI" ;;
+    "BUS-UCLM") echo "data/BUS-UCLM" ;;
+    "BUS-BRA") echo "data/BUS-BRA" ;;
+    "BUS_UC") echo "data/BUS_UC" ;;
+    "CCAUI") echo "data/CCAUI" ;;
+    "DDTI") echo "data/DDTI" ;;
+    "TN3K") echo "data/TN3K" ;;
+    "TG3K") echo "data/TG3K" ;;
+    "OTU_2d") echo "data/OTU_2d" ;;
+    *) echo "Unsupported dataset: $1" >&2; exit 1 ;;
+  esac
+}
+
+config_path() {
+  case "$1" in
+    "BUSI") echo "configs/segmentation/optimized/sup_busi_opt.yaml" ;;
+    "BUS-UCLM") echo "configs/segmentation/optimized/sup_bus_uclm_opt.yaml" ;;
+    "BUS-BRA") echo "configs/segmentation/optimized/sup_bus_bra_opt.yaml" ;;
+    "BUS_UC") echo "configs/segmentation/optimized/sup_bus_uc_opt.yaml" ;;
+    "CCAUI") echo "configs/segmentation/optimized/sup_ccaui_opt.yaml" ;;
+    "DDTI") echo "configs/segmentation/optimized/sup_ddti_opt.yaml" ;;
+    "TN3K") echo "configs/segmentation/optimized/sup_tn3k_opt.yaml" ;;
+    "TG3K") echo "configs/segmentation/optimized/sup_tg3k_opt.yaml" ;;
+    "OTU_2d") echo "configs/segmentation/optimized/sup_otu_2d_opt.yaml" ;;
+    *) echo "Unsupported dataset: $1" >&2; exit 1 ;;
+  esac
+}
+
+needs_project_split() {
+  case "$1" in
+    "BUSI"|"BUS-UCLM"|"BUS-BRA"|"BUS_UC"|"CCAUI"|"DDTI") return 0 ;;
+    *) return 1 ;;
+  esac
+}
+
+prepare_split() {
+  local dataset="$1"
+  local root
+  root="$(dataset_root "$dataset")"
+  if needs_project_split "$dataset"; then
+    echo "[split] ${dataset}"
+    "$PYTHON" scripts/generate_project_split.py --dataset "$dataset" --root "$root" --seed "$SEED"
+  fi
+}
+
+run_one() {
+  local dataset="$1"
+  local config
+  config="$(config_path "$dataset")"
+  prepare_split "$dataset"
+  "$PYTHON" - "$config" ${EXTRA_SET_ARGS} <<'PY'
+import sys
+
+from lib.utils.config import apply_dotlist_overrides, load_yaml_config
+
+config = sys.argv[1]
+overrides = sys.argv[2:]
+cfg = apply_dotlist_overrides(load_yaml_config(config), overrides)
+print("[config]", config)
+print(
+    "[effective]",
+    f"dataset={cfg['dataset']['dataset_name']}",
+    f"root={cfg['dataset']['root']}",
+    f"image_size={cfg['dataset']['image_size']}",
+    f"batch_size={cfg['train']['batch_size']}",
+    f"val_batch_size={cfg['train']['val_batch_size']}",
+    f"amp={cfg['train']['amp']}",
+    f"lr={cfg['optimizer']['lr']}",
+    f"swanlab={cfg['logging']['use_swanlab']}",
+    f"swanlab_mode={cfg['logging'].get('swanlab_mode')}",
+    f"swanlab_logdir={cfg['logging'].get('swanlab_logdir', 'swanlog')}",
+    f"experiment={cfg['logging']['experiment_name']}",
+)
+PY
+  echo "[train] ${dataset} using ${config}"
+  "$PYTHON" tools/train.py --config "$config" --set ${EXTRA_SET_ARGS}
+}
+
+if [[ "$RUN_ALL" == "1" ]]; then
+  for dataset in BUSI BUS-UCLM BUS-BRA BUS_UC CCAUI DDTI TN3K TG3K OTU_2d; do
+    run_one "$dataset"
+  done
+else
+  run_one "$DATASET"
+fi

+ 0 - 91
tools/run_us_experiments.sh

@@ -1,91 +0,0 @@
-#!/usr/bin/env bash
-set -euo pipefail
-
-ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
-cd "$ROOT_DIR"
-
-# ===== 可直接改这里 =====
-DATASET="${DATASET:-BUSI}"          # BUS-UCLM | BUSI | BUS-BRA | BUS_UC | CCAUI | DDTI | OTU_2d | TN3K | TG3K
-SEED="${SEED:-42}"
-RUN_ALL_SUP="${RUN_ALL_SUP:-0}"     # 1 表示跑内置所有全监督实验
-PYTHON_BIN="${PYTHON_BIN:-python}"
-EXTRA_SET_ARGS="${EXTRA_SET_ARGS:-}"
-
-# ===== 数据集根目录 =====
-dataset_root() {
-  case "$1" in
-    "BUS-UCLM") echo "data/BUS-UCLM" ;;
-    "BUSI") echo "data/BUSI" ;;
-    "BUS-BRA") echo "data/BUS-BRA" ;;
-    "BUS_UC") echo "data/BUS_UC" ;;
-    "CCAUI") echo "data/CCAUI" ;;
-    "DDTI") echo "data/DDTI" ;;
-    "OTU_2d") echo "data/OTU_2d" ;;
-    "TN3K") echo "data/TN3K" ;;
-    "TG3K") echo "data/TG3K" ;;
-    *) echo "Unsupported dataset: $1" >&2; exit 1 ;;
-  esac
-}
-
-# ===== 是否需要项目级 train/val =====
-needs_project_split() {
-  case "$1" in
-    "BUS-UCLM"|"BUSI"|"BUS-BRA"|"BUS_UC"|"CCAUI"|"DDTI") return 0 ;;
-    *) return 1 ;;
-  esac
-}
-
-prepare_project_splits() {
-  local dataset="$1"
-  local root
-  root="$(dataset_root "$dataset")"
-
-  if needs_project_split "$dataset"; then
-    echo "[split] generate project split for ${dataset}"
-    "$PYTHON_BIN" scripts/generate_project_split.py --dataset "$dataset" --root "$root" --seed "$SEED"
-  fi
-}
-
-run_supervised() {
-  local dataset="$1"
-  local root
-  root="$(dataset_root "$dataset")"
-  prepare_project_splits "$dataset"
-  echo "[train] supervised ${dataset}"
-  "$PYTHON_BIN" tools/train.py \
-    --config configs/segmentation/train_sup_us_template.yaml \
-    --set \
-      dataset.dataset_name="$dataset" \
-      dataset.root="$root" \
-      checkpoint.dir="outputs/experiments/supervised/${dataset}" \
-      logging.experiment_name="sup_${dataset}" \
-      ${EXTRA_SET_ARGS}
-}
-
-run_all_supervised_suite() {
-  local datasets=(
-    "BUS-UCLM"
-    "BUSI"
-    "BUS-BRA"
-    "BUS_UC"
-    "CCAUI"
-    "DDTI"
-    "OTU_2d"
-    "TN3K"
-    "TG3K"
-  )
-  for ds in "${datasets[@]}"; do
-    run_supervised "$ds"
-  done
-}
-
-main() {
-  if [[ "$RUN_ALL_SUP" == "1" ]]; then
-    run_all_supervised_suite
-    exit 0
-  fi
-
-  run_supervised "$DATASET"
-}
-
-main "$@"

+ 3 - 3
tools/summarize_results.sh

@@ -4,9 +4,9 @@ set -euo pipefail
 ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
 cd "$ROOT_DIR"
 
-PYTHON_BIN="${PYTHON_BIN:-python}"
-OUTPUTS_DIR="${OUTPUTS_DIR:-outputs}"
-RESULTS_DIR="${RESULTS_DIR:-results}"
+PYTHON_BIN="${PYTHON_BIN:-/opt/miniforge3/envs/xnet_mamba/bin/python}"
+OUTPUTS_DIR="${OUTPUTS_DIR:-outputs/experiments/optimized}"
+RESULTS_DIR="${RESULTS_DIR:-results/optimized}"
 
 "$PYTHON_BIN" tools/summarize_results.py --outputs-dir "$OUTPUTS_DIR" --results-dir "$RESULTS_DIR"