analyze_swinv2_models.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. from __future__ import annotations
  2. import csv
  3. import json
  4. from pathlib import Path
  5. import sys
  6. import torch
  7. ROOT = Path(__file__).resolve().parents[3]
  8. WEIGHT_DIR = ROOT / "weights" / "swinv2"
  9. OUTPUT_JSON = ROOT / "tmp" / "swinv2_model_analysis.json"
  10. OUTPUT_CSV = ROOT / "tmp" / "swinv2_model_analysis.csv"
  11. if str(ROOT) not in sys.path:
  12. sys.path.insert(0, str(ROOT))
  13. from lib.modules import build_swinv2_auto
  14. def count_parameters(model: torch.nn.Module) -> tuple[int, int]:
  15. total = sum(param.numel() for param in model.parameters())
  16. trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
  17. return total, trainable
  18. def format_shape(shape) -> str:
  19. return "x".join(str(dim) for dim in shape)
  20. def write_reports(results: list[dict]):
  21. OUTPUT_JSON.write_text(json.dumps(results, indent=2, ensure_ascii=False), encoding="utf-8")
  22. fieldnames = [
  23. "model_name",
  24. "status",
  25. "config_path",
  26. "config_source",
  27. "weight_path",
  28. "weight_source",
  29. "input_shape",
  30. "total_params",
  31. "trainable_params",
  32. "feature_0_shape",
  33. "feature_1_shape",
  34. "feature_2_shape",
  35. "feature_3_shape",
  36. "feature_4_shape",
  37. "error_type",
  38. "error_message",
  39. ]
  40. with OUTPUT_CSV.open("w", encoding="utf-8", newline="") as handle:
  41. writer = csv.DictWriter(handle, fieldnames=fieldnames)
  42. writer.writeheader()
  43. for result in results:
  44. row = {
  45. "model_name": result.get("model_name"),
  46. "status": result.get("status"),
  47. "config_path": result.get("config_path"),
  48. "config_source": result.get("config_source"),
  49. "weight_path": result.get("weight_path"),
  50. "weight_source": result.get("weight_source"),
  51. "input_shape": result.get("input_shape"),
  52. "total_params": result.get("total_params"),
  53. "trainable_params": result.get("trainable_params"),
  54. "error_type": result.get("error_type"),
  55. "error_message": result.get("error_message"),
  56. }
  57. feature_shapes = result.get("feature_shapes", [])
  58. for idx in range(5):
  59. row[f"feature_{idx}_shape"] = feature_shapes[idx] if idx < len(feature_shapes) else None
  60. writer.writerow(row)
  61. def analyze_model(model_name: str):
  62. print(f"Model: {model_name}")
  63. try:
  64. built = build_swinv2_auto(
  65. model_name=model_name,
  66. return_config=True,
  67. return_resolution=True,
  68. verbose=False,
  69. )
  70. if not isinstance(built, tuple) or len(built) != 3:
  71. raise RuntimeError("build_swinv2_auto(return_config=True, return_resolution=True) must return 3 values")
  72. model, config, resolution = built
  73. model.eval()
  74. img_size = int(config.DATA.IMG_SIZE)
  75. in_chans = int(config.MODEL.SWINV2.IN_CHANS)
  76. input_shape = (1, in_chans, img_size, img_size)
  77. x = torch.randn(*input_shape)
  78. with torch.no_grad():
  79. features = model.forward_multiscale_features(x)
  80. total_params, trainable_params = count_parameters(model)
  81. feature_shapes = [format_shape(tuple(feature.shape)) for feature in features]
  82. print(f" Resolved config: {resolution['config_path']} ({resolution['config_source']})")
  83. print(f" Resolved weight: {resolution['weight_path']} ({resolution['weight_source']})")
  84. print(f" Input shape: {format_shape(input_shape)}")
  85. print(f" Parameters: total={total_params:,}, trainable={trainable_params:,}")
  86. print(" forward_multiscale_features:")
  87. for idx, feature_shape in enumerate(feature_shapes):
  88. print(f" [{idx}] {feature_shape}")
  89. result = {
  90. "model_name": model_name,
  91. "status": "ok",
  92. "config_path": resolution["config_path"],
  93. "config_source": resolution["config_source"],
  94. "weight_path": resolution["weight_path"],
  95. "weight_source": resolution["weight_source"],
  96. "input_shape": format_shape(input_shape),
  97. "total_params": total_params,
  98. "trainable_params": trainable_params,
  99. "feature_shapes": feature_shapes,
  100. "error_type": None,
  101. "error_message": None,
  102. }
  103. except Exception as exc:
  104. print(f" ERROR: {type(exc).__name__}: {exc}")
  105. result = {
  106. "model_name": model_name,
  107. "status": "error",
  108. "config_path": None,
  109. "config_source": None,
  110. "weight_path": None,
  111. "weight_source": None,
  112. "input_shape": None,
  113. "total_params": None,
  114. "trainable_params": None,
  115. "feature_shapes": [],
  116. "error_type": type(exc).__name__,
  117. "error_message": str(exc),
  118. }
  119. print()
  120. return result
  121. def main():
  122. model_names = sorted(path.stem for path in WEIGHT_DIR.glob("*.pth"))
  123. if not model_names:
  124. raise RuntimeError(f"No SwinV2 weights found under {WEIGHT_DIR}")
  125. print(f"Found {len(model_names)} SwinV2 weight files in {WEIGHT_DIR}")
  126. print()
  127. results = []
  128. for model_name in model_names:
  129. results.append(analyze_model(model_name))
  130. write_reports(results)
  131. print(f"Saved JSON report to {OUTPUT_JSON}")
  132. print(f"Saved CSV report to {OUTPUT_CSV}")
  133. if __name__ == "__main__":
  134. main()