# train_seg_multiclass_template.yaml
# Training template for multiclass lung-ultrasound segmentation with Swin-UNet.
  1. train:
  2. seed: 42
  3. epochs: 100
  4. batch_size: 8
  5. accum_steps: 1
  6. amp: true
  7. num_workers: 8
  8. device: cuda
  9. dataset:
  10. name: lung_ultrasound_seg_multiclass
  11. root: data/lung_ultrasound
  12. task_name: multiclass
  13. image_size: [256, 256]
  14. in_channels: 3
  15. num_classes: 4
  16. class_names: [background, lun, pe, b]
  17. train_split: train
  18. val_split: val
  19. test_split: test
  20. mask_suffix: .png
  21. image_suffix: .png
  22. model:
  23. name: swin_unet
  24. encoder_name: swinv2_base_patch4_window12_192_22k
  25. in_channels: 3
  26. out_channels: 4
  27. img_size: 256
  28. drop_rate: 0.0
  29. drop_path_rate: 0.2
  30. pretrain:
  31. enabled: true
  32. source: imagenet22k
  33. checkpoint: weights/swinv2_base_patch4_window12_192_22k.pth
  34. strict: false
  35. loss:
  36. name: generalized_dice_focal
  37. task_mode: multiclass
  38. params:
  39. include_background: false
  40. metrics:
  41. task_mode: multiclass
  42. metrics:
  43. - name: dice
  44. - name: miou
  45. optimizer:
  46. name: adamw
  47. lr: 5.0e-5
  48. weight_decay: 0.05
  49. betas: [0.9, 0.999]
  50. scheduler:
  51. name: cosine
  52. warmup:
  53. name: linear
  54. params:
  55. start_factor: 0.1
  56. total_iters: 10
  57. params:
  58. T_max: 100
  59. eta_min: 1.0e-6
  60. augmentation:
  61. train:
  62. random_flip: true
  63. random_rotate_90: true
  64. random_resized_crop: false
  65. random_brightness_contrast: true
  66. random_gaussian_noise: true
  67. val:
  68. center_crop: false
  69. validation:
  70. enabled: true
  71. interval: 1
  72. metrics: [dice, miou]
  73. save_best: true
  74. monitor: dice
  75. mode: max
  76. checkpoint:
  77. dir: outputs/segmentation/train_seg_multiclass
  78. save_last: true
  79. save_best_only: false
  80. keep_top_k: 3
  81. logging:
  82. log_interval: 20
  83. use_tensorboard: true
  84. tensorboard_dir: outputs/tensorboard/train_seg_multiclass