run_us_experiments.sh 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. #!/usr/bin/env bash
  2. set -euo pipefail
  3. ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
  4. cd "$ROOT_DIR"
  5. # ===== 可直接改这里 =====
  6. DATASET="${DATASET:-BUSI}" # BUS-UCLM | BUSI | BUS-BRA | BUS_UC | CCAUI | DDTI | OTU_2d | TN3K | TG3K
  7. SEED="${SEED:-42}"
  8. RUN_ALL_SUP="${RUN_ALL_SUP:-0}" # 1 表示跑内置所有全监督实验
  9. PYTHON_BIN="${PYTHON_BIN:-python}"
  10. EXTRA_SET_ARGS="${EXTRA_SET_ARGS:-}"
  11. # ===== 数据集根目录 =====
  12. dataset_root() {
  13. case "$1" in
  14. "BUS-UCLM") echo "data/BUS-UCLM" ;;
  15. "BUSI") echo "data/BUSI" ;;
  16. "BUS-BRA") echo "data/BUS-BRA" ;;
  17. "BUS_UC") echo "data/BUS_UC" ;;
  18. "CCAUI") echo "data/CCAUI" ;;
  19. "DDTI") echo "data/DDTI" ;;
  20. "OTU_2d") echo "data/OTU_2d" ;;
  21. "TN3K") echo "data/TN3K" ;;
  22. "TG3K") echo "data/TG3K" ;;
  23. *) echo "Unsupported dataset: $1" >&2; exit 1 ;;
  24. esac
  25. }
  26. # ===== 是否需要项目级 train/val =====
  27. needs_project_split() {
  28. case "$1" in
  29. "BUS-UCLM"|"BUSI"|"BUS-BRA"|"BUS_UC"|"CCAUI"|"DDTI") return 0 ;;
  30. *) return 1 ;;
  31. esac
  32. }
  33. prepare_project_splits() {
  34. local dataset="$1"
  35. local root
  36. root="$(dataset_root "$dataset")"
  37. if needs_project_split "$dataset"; then
  38. echo "[split] generate project split for ${dataset}"
  39. "$PYTHON_BIN" scripts/generate_project_split.py --dataset "$dataset" --root "$root" --seed "$SEED"
  40. fi
  41. }
  42. run_supervised() {
  43. local dataset="$1"
  44. local root
  45. root="$(dataset_root "$dataset")"
  46. prepare_project_splits "$dataset"
  47. echo "[train] supervised ${dataset}"
  48. "$PYTHON_BIN" tools/train.py \
  49. --config configs/segmentation/train_sup_us_template.yaml \
  50. --set \
  51. dataset.dataset_name="$dataset" \
  52. dataset.root="$root" \
  53. checkpoint.dir="outputs/experiments/supervised/${dataset}" \
  54. logging.experiment_name="sup_${dataset}" \
  55. ${EXTRA_SET_ARGS}
  56. }
  57. run_all_supervised_suite() {
  58. local datasets=(
  59. "BUS-UCLM"
  60. "BUSI"
  61. "BUS-BRA"
  62. "BUS_UC"
  63. "CCAUI"
  64. "DDTI"
  65. "OTU_2d"
  66. "TN3K"
  67. "TG3K"
  68. )
  69. for ds in "${datasets[@]}"; do
  70. run_supervised "$ds"
  71. done
  72. }
  73. main() {
  74. if [[ "$RUN_ALL_SUP" == "1" ]]; then
  75. run_all_supervised_suite
  76. exit 0
  77. fi
  78. run_supervised "$DATASET"
  79. }
  80. main "$@"