"""Summarize the best checkpoints found under an outputs directory.

Every ``best.pth`` below the outputs root becomes one row in a CSV table and
in a Markdown table, both written to the results directory.
"""

from __future__ import annotations

import argparse
import csv
from pathlib import Path
from typing import Any

import torch

# Column order shared by the CSV and the Markdown summaries.
_COLUMNS = (
    "dataset",
    "mode",
    "ablation_case",
    "ratio",
    "epoch",
    "best_metric",
    "dice",
    "iou",
    "checkpoint",
)


def _mapping_or_empty(value: Any) -> Any:
    """Return *value* if it offers a dict-style ``get``; otherwise ``{}``.

    Checkpoints may store ``None`` (or some other non-mapping) under keys
    that are expected to hold a mapping; treating those as empty lets the
    callers fall back to their defaults instead of raising AttributeError.
    Duck typing is used on purpose so dict-like config objects still work.
    """
    return value if callable(getattr(value, "get", None)) else {}


def _infer_mode(path: Path) -> str:
    """Return ``"supervised"`` if that is a component of *path*, else ``"unknown"``."""
    return "supervised" if "supervised" in path.parts else "unknown"


def _infer_dataset(ckpt: dict[str, Any], path: Path) -> str:
    """Return the dataset name for a checkpoint.

    The name is read from ``ckpt["cfg"]["dataset"]`` (key ``dataset_name``,
    then ``name``). When the config does not provide one, the directory that
    directly follows ``supervised`` in *path* is used. ``"unknown"`` is
    returned when neither source yields a name.
    """
    cfg = _mapping_or_empty(ckpt.get("cfg"))
    dataset_cfg = _mapping_or_empty(cfg.get("dataset"))
    dataset_name = dataset_cfg.get("dataset_name") or dataset_cfg.get("name")
    if dataset_name:
        return str(dataset_name)

    parts = path.parts
    if "supervised" in parts:
        idx = parts.index("supervised")
        # The last component is the checkpoint file itself, so the component
        # after "supervised" only counts when it is an intermediate directory.
        if idx + 1 < len(parts) - 1:
            return parts[idx + 1]
    return "unknown"


def _infer_ratio(ckpt: dict[str, Any], path: Path) -> str:
    """Return the ratio column value; currently always the placeholder ``"-"``."""
    return "-"


def _infer_ablation_case(ckpt: dict[str, Any], path: Path) -> str:
    """Return the ablation-case column value; currently always ``"-"``."""
    return "-"


def _extract_metric(metrics: dict[str, Any], *names: str) -> float | None:
    """Return the first non-``None`` metric among *names* as a float.

    Names are tried in the given order; ``None`` is returned when none of
    them is present with a non-``None`` value.
    """
    for name in names:
        value = metrics.get(name)
        if value is not None:
            return float(value)
    return None


def collect_rows(outputs_dir: Path) -> list[dict[str, Any]]:
    """Load every ``best.pth`` below *outputs_dir* and build one row per file.

    Files are visited in sorted path order. Each row holds the keys listed
    in ``_COLUMNS``.
    """
    rows: list[dict[str, Any]] = []
    for best_path in sorted(outputs_dir.rglob("best.pth")):
        # NOTE(review): torch.load unpickles the file; only point this script
        # at checkpoints from a trusted source.
        ckpt = torch.load(best_path, map_location="cpu")
        metrics = _mapping_or_empty(ckpt.get("metrics"))
        rows.append(
            {
                "dataset": _infer_dataset(ckpt, best_path),
                "mode": _infer_mode(best_path),
                "ablation_case": _infer_ablation_case(ckpt, best_path),
                "ratio": _infer_ratio(ckpt, best_path),
                "epoch": ckpt.get("epoch"),
                "best_metric": ckpt.get("best_metric"),
                "dice": _extract_metric(metrics, "val_dice", "dice"),
                "iou": _extract_metric(metrics, "val_iou", "val_miou", "iou", "miou"),
                "checkpoint": str(best_path),
            }
        )
    return rows


def write_csv(rows: list[dict[str, Any]], path: Path) -> None:
    """Write *rows* to *path* as UTF-8 CSV with a header row.

    Parent directories of *path* are created when missing.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=list(_COLUMNS))
        writer.writeheader()
        writer.writerows(rows)


def _markdown_row(cells: list[str]) -> str:
    """Join *cells* into a single Markdown table row."""
    return "| " + " | ".join(cells) + " |"


def write_markdown(rows: list[dict[str, Any]], path: Path) -> None:
    """Write *rows* to *path* as a UTF-8 Markdown table.

    Parent directories of *path* are created when missing. When *rows* is
    empty, a single placeholder row of dashes is emitted so the table still
    has a body.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    lines = [
        "# 实验结果汇总",
        "",
        _markdown_row(list(_COLUMNS)),
        _markdown_row(["---"] * len(_COLUMNS)),
    ]
    # f-string formatting is kept for the cells so that each value renders
    # exactly as it would when interpolated directly into the row.
    lines.extend(
        _markdown_row([f"{row[name]}" for name in _COLUMNS]) for row in rows
    )
    if not rows:
        lines.append(_markdown_row(["-"] * len(_COLUMNS)))
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")


def main() -> None:
    """Parse the command line, write both summary tables, and print their paths."""
    parser = argparse.ArgumentParser(
        description="Summarize best experiment results from best.pth files."
    )
    parser.add_argument("--outputs-dir", default="outputs", help="Root output directory")
    parser.add_argument(
        "--results-dir", default="results", help="Directory to write summary tables"
    )
    args = parser.parse_args()

    outputs_dir = Path(args.outputs_dir)
    results_dir = Path(args.results_dir)
    rows = collect_rows(outputs_dir)

    csv_path = results_dir / "experiment_summary.csv"
    md_path = results_dir / "experiment_summary.md"
    write_csv(rows, csv_path)
    write_markdown(rows, md_path)

    print(
        {
            "num_results": len(rows),
            "csv": str(csv_path),
            "markdown": str(md_path),
        }
    )


if __name__ == "__main__":
    main()