| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158 |
- from __future__ import annotations
- import csv
- import json
- from pathlib import Path
- import sys
- import torch
# Repository root: three directory levels above the folder containing this file.
ROOT = Path(__file__).resolve().parents[3]
# Directory scanned for SwinV2 ``*.pth`` checkpoints.
WEIGHT_DIR = ROOT.joinpath("weights", "swinv2")
# Destinations of the JSON and CSV analysis reports.
OUTPUT_JSON = ROOT.joinpath("tmp", "swinv2_model_analysis.json")
OUTPUT_CSV = ROOT.joinpath("tmp", "swinv2_model_analysis.csv")

# Put the repository root first on the import path (only once), presumably so
# that the ``lib`` package imported right after this block resolves.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
- from lib.modules import build_swinv2_auto
def count_parameters(model: torch.nn.Module) -> tuple[int, int]:
    """Count the parameter elements of ``model``.

    Args:
        model: Module whose ``parameters()`` are inspected.

    Returns:
        A ``(total, trainable)`` pair: the summed ``numel()`` of every
        parameter, and of only those parameters with ``requires_grad`` set.
    """
    total = 0
    trainable = 0
    for tensor in model.parameters():
        size = tensor.numel()
        total += size
        if tensor.requires_grad:
            trainable += size
    return total, trainable
def format_shape(shape) -> str:
    """Render ``shape`` as its dimensions joined by ``x``, e.g. ``1x3x224x224``.

    An empty shape yields an empty string.
    """
    return "x".join(map(str, shape))
- def write_reports(results: list[dict]):
- OUTPUT_JSON.write_text(json.dumps(results, indent=2, ensure_ascii=False), encoding="utf-8")
- fieldnames = [
- "model_name",
- "status",
- "config_path",
- "config_source",
- "weight_path",
- "weight_source",
- "input_shape",
- "total_params",
- "trainable_params",
- "feature_0_shape",
- "feature_1_shape",
- "feature_2_shape",
- "feature_3_shape",
- "feature_4_shape",
- "error_type",
- "error_message",
- ]
- with OUTPUT_CSV.open("w", encoding="utf-8", newline="") as handle:
- writer = csv.DictWriter(handle, fieldnames=fieldnames)
- writer.writeheader()
- for result in results:
- row = {
- "model_name": result.get("model_name"),
- "status": result.get("status"),
- "config_path": result.get("config_path"),
- "config_source": result.get("config_source"),
- "weight_path": result.get("weight_path"),
- "weight_source": result.get("weight_source"),
- "input_shape": result.get("input_shape"),
- "total_params": result.get("total_params"),
- "trainable_params": result.get("trainable_params"),
- "error_type": result.get("error_type"),
- "error_message": result.get("error_message"),
- }
- feature_shapes = result.get("feature_shapes", [])
- for idx in range(5):
- row[f"feature_{idx}_shape"] = feature_shapes[idx] if idx < len(feature_shapes) else None
- writer.writerow(row)
def analyze_model(model_name: str):
    """Build one SwinV2 model, run a dummy forward pass and summarise it.

    The model is obtained from ``build_swinv2_auto``, switched to eval mode and
    fed a random ``(1, in_chans, img_size, img_size)`` tensor through
    ``forward_multiscale_features`` under ``torch.no_grad()``. Progress and
    results are printed to stdout.

    Args:
        model_name: Name passed to ``build_swinv2_auto``.

    Returns:
        A dict with the keys ``model_name``, ``status`` (``"ok"`` or
        ``"error"``), ``config_path``, ``config_source``, ``weight_path``,
        ``weight_source``, ``input_shape``, ``total_params``,
        ``trainable_params``, ``feature_shapes``, ``error_type`` and
        ``error_message``. Exceptions are not propagated: they are printed and
        stored in the record, which keeps every field that was resolved before
        the failure (the remaining fields stay ``None`` / empty).
    """
    print(f"Model: {model_name}")
    # Filled in progressively so that a failure part-way through (for example
    # in the forward pass) still reports which config and weights were used.
    result = {
        "model_name": model_name,
        "status": "error",
        "config_path": None,
        "config_source": None,
        "weight_path": None,
        "weight_source": None,
        "input_shape": None,
        "total_params": None,
        "trainable_params": None,
        "feature_shapes": [],
        "error_type": None,
        "error_message": None,
    }
    try:
        built = build_swinv2_auto(
            model_name=model_name,
            return_config=True,
            return_resolution=True,
            verbose=False,
        )
        if not isinstance(built, tuple) or len(built) != 3:
            raise RuntimeError("build_swinv2_auto(return_config=True, return_resolution=True) must return 3 values")
        model, config, resolution = built
        for key in ("config_path", "config_source", "weight_path", "weight_source"):
            result[key] = resolution[key]

        model.eval()
        img_size = int(config.DATA.IMG_SIZE)
        in_chans = int(config.MODEL.SWINV2.IN_CHANS)
        input_shape = (1, in_chans, img_size, img_size)
        result["input_shape"] = format_shape(input_shape)

        x = torch.randn(*input_shape)
        with torch.no_grad():
            features = model.forward_multiscale_features(x)

        total_params, trainable_params = count_parameters(model)
        feature_shapes = [format_shape(tuple(feature.shape)) for feature in features]
        result["total_params"] = total_params
        result["trainable_params"] = trainable_params
        result["feature_shapes"] = feature_shapes

        print(f" Resolved config: {resolution['config_path']} ({resolution['config_source']})")
        print(f" Resolved weight: {resolution['weight_path']} ({resolution['weight_source']})")
        print(f" Input shape: {format_shape(input_shape)}")
        print(f" Parameters: total={total_params:,}, trainable={trainable_params:,}")
        print(" forward_multiscale_features:")
        for idx, feature_shape in enumerate(feature_shapes):
            print(f" [{idx}] {feature_shape}")
        result["status"] = "ok"
    except Exception as exc:
        # Deliberately broad: one broken checkpoint must not abort the analysis
        # of the remaining models; the failure is recorded in the report.
        print(f" ERROR: {type(exc).__name__}: {exc}")
        result["error_type"] = type(exc).__name__
        result["error_message"] = str(exc)
    print()
    return result
def main():
    """Analyse every ``*.pth`` checkpoint in ``WEIGHT_DIR`` and write the reports.

    Model names are the checkpoint file stems, processed in sorted order.

    Raises:
        RuntimeError: If ``WEIGHT_DIR`` contains no ``*.pth`` file.
    """
    checkpoints = WEIGHT_DIR.glob("*.pth")
    model_names = sorted(checkpoint.stem for checkpoint in checkpoints)
    if not model_names:
        raise RuntimeError(f"No SwinV2 weights found under {WEIGHT_DIR}")

    # ``end="\n\n"`` emits the summary line followed by one blank line.
    print(f"Found {len(model_names)} SwinV2 weight files in {WEIGHT_DIR}", end="\n\n")

    results = [analyze_model(name) for name in model_names]
    write_reports(results)

    for label, destination in (("JSON", OUTPUT_JSON), ("CSV", OUTPUT_CSV)):
        print(f"Saved {label} report to {destination}")


if __name__ == "__main__":
    main()
|