summarize_results.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. from __future__ import annotations
  2. import argparse
  3. import csv
  4. from pathlib import Path
  5. from typing import Any
  6. import torch
  7. def _infer_mode(path: Path) -> str:
  8. parts = set(path.parts)
  9. if "supervised" in parts:
  10. return "supervised"
  11. return "unknown"
  12. def _infer_dataset(ckpt: dict[str, Any], path: Path) -> str:
  13. cfg = ckpt.get("cfg", {})
  14. dataset_cfg = cfg.get("dataset", {})
  15. dataset_name = dataset_cfg.get("dataset_name") or dataset_cfg.get("name")
  16. if dataset_name:
  17. return str(dataset_name)
  18. parts = path.parts
  19. if "supervised" in parts:
  20. idx = parts.index("supervised")
  21. if idx + 1 < len(parts):
  22. return parts[idx + 1]
  23. return "unknown"
  24. def _infer_ratio(ckpt: dict[str, Any], path: Path) -> str:
  25. return "-"
  26. def _extract_metric(metrics: dict[str, Any], *names: str) -> float | None:
  27. for name in names:
  28. value = metrics.get(name)
  29. if value is not None:
  30. return float(value)
  31. return None
  32. def collect_rows(outputs_dir: Path) -> list[dict[str, Any]]:
  33. rows: list[dict[str, Any]] = []
  34. for best_path in sorted(outputs_dir.rglob("best.pth")):
  35. ckpt = torch.load(best_path, map_location="cpu")
  36. metrics = ckpt.get("metrics", {}) or {}
  37. row = {
  38. "dataset": _infer_dataset(ckpt, best_path),
  39. "mode": _infer_mode(best_path),
  40. "ratio": _infer_ratio(ckpt, best_path),
  41. "epoch": ckpt.get("epoch"),
  42. "best_metric": ckpt.get("best_metric"),
  43. "dice": _extract_metric(metrics, "val_dice", "dice"),
  44. "iou": _extract_metric(metrics, "val_iou", "val_miou", "iou", "miou"),
  45. "checkpoint": str(best_path),
  46. }
  47. rows.append(row)
  48. return rows
  49. def write_csv(rows: list[dict[str, Any]], path: Path) -> None:
  50. path.parent.mkdir(parents=True, exist_ok=True)
  51. fieldnames = ["dataset", "mode", "ratio", "epoch", "best_metric", "dice", "iou", "checkpoint"]
  52. with path.open("w", encoding="utf-8", newline="") as handle:
  53. writer = csv.DictWriter(handle, fieldnames=fieldnames)
  54. writer.writeheader()
  55. writer.writerows(rows)
  56. def write_markdown(rows: list[dict[str, Any]], path: Path) -> None:
  57. path.parent.mkdir(parents=True, exist_ok=True)
  58. lines = [
  59. "# 实验结果汇总",
  60. "",
  61. "| dataset | mode | ratio | epoch | best_metric | dice | iou | checkpoint |",
  62. "| --- | --- | --- | --- | --- | --- | --- | --- |",
  63. ]
  64. for row in rows:
  65. lines.append(
  66. f"| {row['dataset']} | {row['mode']} | {row['ratio']} | {row['epoch']} | "
  67. f"{row['best_metric']} | {row['dice']} | {row['iou']} | {row['checkpoint']} |"
  68. )
  69. if not rows:
  70. lines.append("| - | - | - | - | - | - | - | - |")
  71. path.write_text("\n".join(lines) + "\n", encoding="utf-8")
  72. def main() -> None:
  73. parser = argparse.ArgumentParser(description="Summarize best experiment results from best.pth files.")
  74. parser.add_argument("--outputs-dir", default="outputs", help="Root output directory")
  75. parser.add_argument("--results-dir", default="results", help="Directory to write summary tables")
  76. args = parser.parse_args()
  77. outputs_dir = Path(args.outputs_dir)
  78. results_dir = Path(args.results_dir)
  79. rows = collect_rows(outputs_dir)
  80. csv_path = results_dir / "experiment_summary.csv"
  81. md_path = results_dir / "experiment_summary.md"
  82. write_csv(rows, csv_path)
  83. write_markdown(rows, md_path)
  84. print({"num_results": len(rows), "csv": str(csv_path), "markdown": str(md_path)})
  85. if __name__ == "__main__":
  86. main()