from __future__ import annotations import csv import json from pathlib import Path import sys import torch ROOT = Path(__file__).resolve().parents[3] WEIGHT_DIR = ROOT / "weights" / "swinv2" OUTPUT_JSON = ROOT / "tmp" / "swinv2_model_analysis.json" OUTPUT_CSV = ROOT / "tmp" / "swinv2_model_analysis.csv" if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from lib.modules import build_swinv2_auto def count_parameters(model: torch.nn.Module) -> tuple[int, int]: total = sum(param.numel() for param in model.parameters()) trainable = sum(param.numel() for param in model.parameters() if param.requires_grad) return total, trainable def format_shape(shape) -> str: return "x".join(str(dim) for dim in shape) def write_reports(results: list[dict]): OUTPUT_JSON.write_text(json.dumps(results, indent=2, ensure_ascii=False), encoding="utf-8") fieldnames = [ "model_name", "status", "config_path", "config_source", "weight_path", "weight_source", "input_shape", "total_params", "trainable_params", "feature_0_shape", "feature_1_shape", "feature_2_shape", "feature_3_shape", "feature_4_shape", "error_type", "error_message", ] with OUTPUT_CSV.open("w", encoding="utf-8", newline="") as handle: writer = csv.DictWriter(handle, fieldnames=fieldnames) writer.writeheader() for result in results: row = { "model_name": result.get("model_name"), "status": result.get("status"), "config_path": result.get("config_path"), "config_source": result.get("config_source"), "weight_path": result.get("weight_path"), "weight_source": result.get("weight_source"), "input_shape": result.get("input_shape"), "total_params": result.get("total_params"), "trainable_params": result.get("trainable_params"), "error_type": result.get("error_type"), "error_message": result.get("error_message"), } feature_shapes = result.get("feature_shapes", []) for idx in range(5): row[f"feature_{idx}_shape"] = feature_shapes[idx] if idx < len(feature_shapes) else None writer.writerow(row) def analyze_model(model_name: str): print(f"Model: {model_name}") try: built = build_swinv2_auto( model_name=model_name, return_config=True, return_resolution=True, verbose=False, ) if not isinstance(built, tuple) or len(built) != 3: raise RuntimeError("build_swinv2_auto(return_config=True, return_resolution=True) must return 3 values") model, config, resolution = built model.eval() img_size = int(config.DATA.IMG_SIZE) in_chans = int(config.MODEL.SWINV2.IN_CHANS) input_shape = (1, in_chans, img_size, img_size) x = torch.randn(*input_shape) with torch.no_grad(): features = model.forward_multiscale_features(x) total_params, trainable_params = count_parameters(model) feature_shapes = [format_shape(tuple(feature.shape)) for feature in features] print(f" Resolved config: {resolution['config_path']} ({resolution['config_source']})") print(f" Resolved weight: {resolution['weight_path']} ({resolution['weight_source']})") print(f" Input shape: {format_shape(input_shape)}") print(f" Parameters: total={total_params:,}, trainable={trainable_params:,}") print(" forward_multiscale_features:") for idx, feature_shape in enumerate(feature_shapes): print(f" [{idx}] {feature_shape}") result = { "model_name": model_name, "status": "ok", "config_path": resolution["config_path"], "config_source": resolution["config_source"], "weight_path": resolution["weight_path"], "weight_source": resolution["weight_source"], "input_shape": format_shape(input_shape), "total_params": total_params, "trainable_params": trainable_params, "feature_shapes": feature_shapes, "error_type": None, "error_message": None, } except Exception as exc: print(f" ERROR: {type(exc).__name__}: {exc}") result = { "model_name": model_name, "status": "error", "config_path": None, "config_source": None, "weight_path": None, "weight_source": None, "input_shape": None, "total_params": None, "trainable_params": None, "feature_shapes": [], "error_type": type(exc).__name__, "error_message": str(exc), } print() return result def main(): model_names = sorted(path.stem for path in WEIGHT_DIR.glob("*.pth")) if not model_names: raise RuntimeError(f"No SwinV2 weights found under {WEIGHT_DIR}") print(f"Found {len(model_names)} SwinV2 weight files in {WEIGHT_DIR}") print() results = [] for model_name in model_names: results.append(analyze_model(model_name)) write_reports(results) print(f"Saved JSON report to {OUTPUT_JSON}") print(f"Saved CSV report to {OUTPUT_CSV}") if __name__ == "__main__": main()