| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- from __future__ import annotations
- import argparse
- import csv
- from pathlib import Path
- from typing import Any
- import torch
- def _infer_mode(path: Path) -> str:
- parts = set(path.parts)
- if "supervised" in parts:
- return "supervised"
- return "unknown"
- def _infer_dataset(ckpt: dict[str, Any], path: Path) -> str:
- cfg = ckpt.get("cfg", {})
- dataset_cfg = cfg.get("dataset", {})
- dataset_name = dataset_cfg.get("dataset_name") or dataset_cfg.get("name")
- if dataset_name:
- return str(dataset_name)
- parts = path.parts
- if "supervised" in parts:
- idx = parts.index("supervised")
- if idx + 1 < len(parts):
- return parts[idx + 1]
- return "unknown"
- def _infer_ratio(ckpt: dict[str, Any], path: Path) -> str:
- return "-"
- def _infer_ablation_case(ckpt: dict[str, Any], path: Path) -> str:
- return "-"
- def _extract_metric(metrics: dict[str, Any], *names: str) -> float | None:
- for name in names:
- value = metrics.get(name)
- if value is not None:
- return float(value)
- return None
def collect_rows(outputs_dir: Path) -> list[dict[str, Any]]:
    """Build one summary row per ``best.pth`` found under *outputs_dir*.

    Checkpoint files are discovered recursively and processed in sorted path
    order. Each is loaded onto the CPU, and its ``epoch``, ``best_metric``, and
    the dice/IoU entries of its ``metrics`` mapping are collected together with
    the inferred dataset, mode, ablation case, and ratio.

    Returns:
        A list of row dicts keyed by the summary column names.
    """
    summary: list[dict[str, Any]] = []
    for ckpt_path in sorted(outputs_dir.rglob("best.pth")):
        # NOTE(review): torch.load unpickles the file, so only point this at
        # trusted output directories. Newer torch releases default to
        # weights_only=True, which may reject checkpoints holding non-tensor
        # Python objects — confirm against the torch version in use.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        # ``or {}`` also handles a "metrics" key that is present but None.
        metric_values = checkpoint.get("metrics", {}) or {}
        summary.append(
            {
                "dataset": _infer_dataset(checkpoint, ckpt_path),
                "mode": _infer_mode(ckpt_path),
                "ablation_case": _infer_ablation_case(checkpoint, ckpt_path),
                "ratio": _infer_ratio(checkpoint, ckpt_path),
                "epoch": checkpoint.get("epoch"),
                "best_metric": checkpoint.get("best_metric"),
                "dice": _extract_metric(metric_values, "val_dice", "dice"),
                "iou": _extract_metric(metric_values, "val_iou", "val_miou", "iou", "miou"),
                "checkpoint": str(ckpt_path),
            }
        )
    return summary
def write_csv(rows: list[dict[str, Any]], path: Path) -> None:
    """Write *rows* to *path* as a UTF-8 CSV file with a fixed header.

    Parent directories are created if missing and any existing file is
    overwritten. ``csv.DictWriter`` raises ``ValueError`` for keys outside the
    column list; missing keys and ``None`` values are written as empty fields.
    """
    columns = [
        "dataset",
        "mode",
        "ablation_case",
        "ratio",
        "epoch",
        "best_metric",
        "dice",
        "iou",
        "checkpoint",
    ]
    path.parent.mkdir(parents=True, exist_ok=True)
    # newline="" lets the csv module control line endings itself.
    with path.open("w", encoding="utf-8", newline="") as out:
        table = csv.DictWriter(out, fieldnames=columns)
        table.writeheader()
        for record in rows:
            table.writerow(record)
def write_markdown(rows: list[dict[str, Any]], path: Path) -> None:
    """Write *rows* as a Markdown summary table to *path*.

    The file starts with a level-1 heading followed by a table with one row
    per entry in *rows*; if *rows* is empty a single all-``-`` placeholder row
    is emitted. Missing values (``None``) are rendered as ``-`` — the same
    placeholder used elsewhere in the table — and literal ``|`` characters are
    escaped so a cell can never spill into extra columns. Parent directories
    are created as needed and any existing file is overwritten.

    Args:
        rows: Row dicts that must contain every column key listed below.
        path: Destination ``.md`` file.

    Raises:
        KeyError: If a row lacks one of the column keys.
    """
    columns = ["dataset", "mode", "ablation_case", "ratio", "epoch", "best_metric", "dice", "iou", "checkpoint"]

    def _cell(value: Any) -> str:
        """Render one table cell: '-' for None, pipes escaped for GFM tables."""
        # f-string formatting (not str()) keeps the original rendering of values.
        return "-" if value is None else f"{value}".replace("|", "\\|")

    def _table_row(cells: list[str]) -> str:
        """Join already-rendered cells into one Markdown table line."""
        return "| " + " | ".join(cells) + " |"

    path.parent.mkdir(parents=True, exist_ok=True)
    lines = [
        "# 实验结果汇总",
        "",
        _table_row(columns),
        _table_row(["---"] * len(columns)),
    ]
    lines.extend(_table_row([_cell(row[col]) for col in columns]) for row in rows)
    if not rows:
        lines.append(_table_row(["-"] * len(columns)))
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")
def main() -> None:
    """Command-line entry point: summarize ``best.pth`` checkpoints into tables.

    Recursively scans ``--outputs-dir`` for ``best.pth`` files, writes
    ``experiment_summary.csv`` and ``experiment_summary.md`` into
    ``--results-dir``, then prints a dict with the row count and both paths.
    """
    cli = argparse.ArgumentParser(description="Summarize best experiment results from best.pth files.")
    cli.add_argument("--outputs-dir", default="outputs", help="Root output directory")
    cli.add_argument("--results-dir", default="results", help="Directory to write summary tables")
    opts = cli.parse_args()

    summary_dir = Path(opts.results_dir)
    rows = collect_rows(Path(opts.outputs_dir))

    csv_path = summary_dir / "experiment_summary.csv"
    md_path = summary_dir / "experiment_summary.md"
    write_csv(rows, csv_path)
    write_markdown(rows, md_path)

    report = {"num_results": len(rows), "csv": str(csv_path), "markdown": str(md_path)}
    print(report)
# Script entry point; main() returns None, so SystemExit(None) exits with status 0.
if __name__ == "__main__":
    raise SystemExit(main())
|