summarize_results.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
from __future__ import annotations

import argparse
import csv
from collections.abc import Mapping
from pathlib import Path
from typing import Any

import torch
  7. def _infer_mode(path: Path) -> str:
  8. parts = set(path.parts)
  9. if "supervised" in parts:
  10. return "supervised"
  11. return "unknown"
  12. def _infer_dataset(ckpt: dict[str, Any], path: Path) -> str:
  13. cfg = ckpt.get("cfg", {})
  14. dataset_cfg = cfg.get("dataset", {})
  15. dataset_name = dataset_cfg.get("dataset_name") or dataset_cfg.get("name")
  16. if dataset_name:
  17. return str(dataset_name)
  18. parts = path.parts
  19. if "supervised" in parts:
  20. idx = parts.index("supervised")
  21. if idx + 1 < len(parts):
  22. return parts[idx + 1]
  23. return "unknown"
  24. def _infer_ratio(ckpt: dict[str, Any], path: Path) -> str:
  25. return "-"
  26. def _infer_ablation_case(ckpt: dict[str, Any], path: Path) -> str:
  27. return "-"
  28. def _extract_metric(metrics: dict[str, Any], *names: str) -> float | None:
  29. for name in names:
  30. value = metrics.get(name)
  31. if value is not None:
  32. return float(value)
  33. return None
  34. def collect_rows(outputs_dir: Path) -> list[dict[str, Any]]:
  35. rows: list[dict[str, Any]] = []
  36. for best_path in sorted(outputs_dir.rglob("best.pth")):
  37. ckpt = torch.load(best_path, map_location="cpu")
  38. metrics = ckpt.get("metrics", {}) or {}
  39. row = {
  40. "dataset": _infer_dataset(ckpt, best_path),
  41. "mode": _infer_mode(best_path),
  42. "ablation_case": _infer_ablation_case(ckpt, best_path),
  43. "ratio": _infer_ratio(ckpt, best_path),
  44. "epoch": ckpt.get("epoch"),
  45. "best_metric": ckpt.get("best_metric"),
  46. "dice": _extract_metric(metrics, "val_dice", "dice"),
  47. "iou": _extract_metric(metrics, "val_iou", "val_miou", "iou", "miou"),
  48. "checkpoint": str(best_path),
  49. }
  50. rows.append(row)
  51. return rows
  52. def write_csv(rows: list[dict[str, Any]], path: Path) -> None:
  53. path.parent.mkdir(parents=True, exist_ok=True)
  54. fieldnames = ["dataset", "mode", "ablation_case", "ratio", "epoch", "best_metric", "dice", "iou", "checkpoint"]
  55. with path.open("w", encoding="utf-8", newline="") as handle:
  56. writer = csv.DictWriter(handle, fieldnames=fieldnames)
  57. writer.writeheader()
  58. writer.writerows(rows)
  59. def write_markdown(rows: list[dict[str, Any]], path: Path) -> None:
  60. path.parent.mkdir(parents=True, exist_ok=True)
  61. lines = [
  62. "# 实验结果汇总",
  63. "",
  64. "| dataset | mode | ablation_case | ratio | epoch | best_metric | dice | iou | checkpoint |",
  65. "| --- | --- | --- | --- | --- | --- | --- | --- | --- |",
  66. ]
  67. for row in rows:
  68. lines.append(
  69. f"| {row['dataset']} | {row['mode']} | {row['ablation_case']} | {row['ratio']} | {row['epoch']} | "
  70. f"{row['best_metric']} | {row['dice']} | {row['iou']} | {row['checkpoint']} |"
  71. )
  72. if not rows:
  73. lines.append("| - | - | - | - | - | - | - | - | - |")
  74. path.write_text("\n".join(lines) + "\n", encoding="utf-8")
  75. def main() -> None:
  76. parser = argparse.ArgumentParser(description="Summarize best experiment results from best.pth files.")
  77. parser.add_argument("--outputs-dir", default="outputs", help="Root output directory")
  78. parser.add_argument("--results-dir", default="results", help="Directory to write summary tables")
  79. args = parser.parse_args()
  80. outputs_dir = Path(args.outputs_dir)
  81. results_dir = Path(args.results_dir)
  82. rows = collect_rows(outputs_dir)
  83. csv_path = results_dir / "experiment_summary.csv"
  84. md_path = results_dir / "experiment_summary.md"
  85. write_csv(rows, csv_path)
  86. write_markdown(rows, md_path)
  87. print(
  88. {
  89. "num_results": len(rows),
  90. "csv": str(csv_path),
  91. "markdown": str(md_path),
  92. }
  93. )
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()