build_swinv2.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. from __future__ import annotations
  2. from argparse import Namespace
  3. from pathlib import Path
  4. from types import SimpleNamespace
  5. from typing import Any
  6. import torch
  7. import yaml
  8. from lib.SwinTransformer.models.swin_transformer_v2 import SwinTransformerV2
  9. ROOT_DIR = Path(__file__).resolve().parents[2]
  10. SWINV2_CONFIG_DIR = ROOT_DIR / "configs" / "swinv2"
  11. SWINV2_WEIGHT_DIR = ROOT_DIR / "weights" / "swinv2"
  12. MAP22KTO1K_PATH = ROOT_DIR / "lib" / "SwinTransformer" / "data" / "map22kto1k.txt"
  13. DEFAULTS: dict[str, Any] = {
  14. "DATA": {
  15. "IMG_SIZE": 224,
  16. },
  17. "MODEL": {
  18. "TYPE": "swinv2",
  19. "NAME": "swinv2_tiny_patch4_window8_256",
  20. "NUM_CLASSES": 1000,
  21. "DROP_RATE": 0.0,
  22. "DROP_PATH_RATE": 0.1,
  23. "PRETRAINED": "",
  24. "SWINV2": {
  25. "PATCH_SIZE": 4,
  26. "IN_CHANS": 3,
  27. "EMBED_DIM": 96,
  28. "DEPTHS": [2, 2, 6, 2],
  29. "NUM_HEADS": [3, 6, 12, 24],
  30. "WINDOW_SIZE": 7,
  31. "MLP_RATIO": 4.0,
  32. "QKV_BIAS": True,
  33. "APE": False,
  34. "PATCH_NORM": True,
  35. "PRETRAINED_WINDOW_SIZES": [0, 0, 0, 0],
  36. },
  37. },
  38. "TRAIN": {
  39. "USE_CHECKPOINT": False,
  40. },
  41. }
  42. def _deep_copy_dict(value: dict[str, Any]) -> dict[str, Any]:
  43. copied: dict[str, Any] = {}
  44. for key, item in value.items():
  45. if isinstance(item, dict):
  46. copied[key] = _deep_copy_dict(item)
  47. elif isinstance(item, list):
  48. copied[key] = list(item)
  49. else:
  50. copied[key] = item
  51. return copied
  52. def _merge_dict(dst: dict[str, Any], src: dict[str, Any]) -> dict[str, Any]:
  53. for key, value in src.items():
  54. if isinstance(value, dict) and isinstance(dst.get(key), dict):
  55. _merge_dict(dst[key], value)
  56. else:
  57. dst[key] = value
  58. return dst
  59. def _dict_to_namespace(value: Any) -> Any:
  60. if isinstance(value, dict):
  61. return SimpleNamespace(**{key: _dict_to_namespace(item) for key, item in value.items()})
  62. if isinstance(value, list):
  63. return [_dict_to_namespace(item) for item in value]
  64. return value
  65. def _get_arg(args: Namespace | None, *names: str) -> Any:
  66. if args is None:
  67. return None
  68. for name in names:
  69. if hasattr(args, name):
  70. value = getattr(args, name)
  71. if value is not None:
  72. return value
  73. return None
  74. def _to_path(value: str | Path | None) -> Path | None:
  75. if value is None:
  76. return None
  77. return Path(value).expanduser().resolve()
  78. def _resolve_model_name(
  79. model_name: str | None,
  80. config_path: Path | None,
  81. weight_path: Path | None,
  82. args: Namespace | None,
  83. ) -> str | None:
  84. return (
  85. _get_arg(args, "model_name", "model")
  86. or model_name
  87. or (weight_path.stem if weight_path is not None else None)
  88. or (config_path.stem if config_path is not None else None)
  89. )
  90. def _resolve_config_path(model_name: str | None, config_path: str | Path | None, args: Namespace | None) -> Path | None:
  91. cli_cfg = _to_path(_get_arg(args, "cfg", "config", "config_path"))
  92. if cli_cfg is not None:
  93. return cli_cfg
  94. explicit_cfg = _to_path(config_path)
  95. if explicit_cfg is not None:
  96. return explicit_cfg
  97. if model_name is None:
  98. return None
  99. candidate = SWINV2_CONFIG_DIR / f"{model_name}.yaml"
  100. return candidate if candidate.exists() else None
  101. def _resolve_weight_path(model_name: str | None, weight_path: str | Path | None, args: Namespace | None) -> Path | None:
  102. cli_weight = _to_path(_get_arg(args, "pretrained", "weights", "weight_path", "ckpt", "checkpoint"))
  103. if cli_weight is not None:
  104. return cli_weight
  105. explicit_weight = _to_path(weight_path)
  106. if explicit_weight is not None:
  107. return explicit_weight
  108. if model_name is None:
  109. return None
  110. candidate = SWINV2_WEIGHT_DIR / f"{model_name}.pth"
  111. return candidate if candidate.exists() else None
  112. def _resolve_config_with_source(
  113. model_name: str | None,
  114. config_path: str | Path | None,
  115. args: Namespace | None,
  116. ) -> tuple[Path | None, str]:
  117. cli_cfg = _to_path(_get_arg(args, "cfg", "config", "config_path"))
  118. if cli_cfg is not None:
  119. return cli_cfg, "args.cfg"
  120. explicit_cfg = _to_path(config_path)
  121. if explicit_cfg is not None:
  122. return explicit_cfg, "function config_path"
  123. if model_name is None:
  124. return None, "defaults only"
  125. candidate = SWINV2_CONFIG_DIR / f"{model_name}.yaml"
  126. if candidate.exists():
  127. return candidate, "auto by MODEL.NAME"
  128. return None, "defaults only"
  129. def _resolve_weight_with_source(
  130. model_name: str | None,
  131. weight_path: str | Path | None,
  132. args: Namespace | None,
  133. ) -> tuple[Path | None, str]:
  134. cli_weight = _to_path(_get_arg(args, "pretrained", "weights", "weight_path", "ckpt", "checkpoint"))
  135. if cli_weight is not None:
  136. return cli_weight, "args.pretrained"
  137. explicit_weight = _to_path(weight_path)
  138. if explicit_weight is not None:
  139. return explicit_weight, "function weight_path"
  140. if model_name is None:
  141. return None, "not resolved"
  142. candidate = SWINV2_WEIGHT_DIR / f"{model_name}.pth"
  143. if candidate.exists():
  144. return candidate, "auto by MODEL.NAME"
  145. return None, "not resolved"
  146. def _load_yaml_config(config_path: Path | None) -> dict[str, Any]:
  147. config = _deep_copy_dict(DEFAULTS)
  148. if config_path is None:
  149. return config
  150. if not config_path.exists():
  151. raise FileNotFoundError(f"SwinV2 config not found: {config_path}")
  152. with config_path.open("r", encoding="utf-8") as handle:
  153. yaml_config = yaml.safe_load(handle) or {}
  154. return _merge_dict(config, yaml_config)
  155. def _set_nested(config: dict[str, Any], path: tuple[str, ...], value: Any):
  156. current = config
  157. for key in path[:-1]:
  158. current = current.setdefault(key, {})
  159. current[path[-1]] = value
  160. def _collect_function_overrides(
  161. model_name: str | None,
  162. weight_path: Path | None,
  163. num_classes: int | None,
  164. img_size: int | None,
  165. in_chans: int | None,
  166. use_checkpoint: bool | None,
  167. model_kwargs: dict[str, Any],
  168. ) -> list[tuple[tuple[str, ...], Any]]:
  169. overrides: list[tuple[tuple[str, ...], Any]] = []
  170. if model_name is not None:
  171. overrides.append((("MODEL", "NAME"), model_name))
  172. if weight_path is not None:
  173. overrides.append((("MODEL", "PRETRAINED"), str(weight_path)))
  174. if num_classes is not None:
  175. overrides.append((("MODEL", "NUM_CLASSES"), num_classes))
  176. if img_size is not None:
  177. overrides.append((("DATA", "IMG_SIZE"), img_size))
  178. if in_chans is not None:
  179. overrides.append((("MODEL", "SWINV2", "IN_CHANS"), in_chans))
  180. if use_checkpoint is not None:
  181. overrides.append((("TRAIN", "USE_CHECKPOINT"), use_checkpoint))
  182. model_key_map = {
  183. "patch_size": ("MODEL", "SWINV2", "PATCH_SIZE"),
  184. "embed_dim": ("MODEL", "SWINV2", "EMBED_DIM"),
  185. "depths": ("MODEL", "SWINV2", "DEPTHS"),
  186. "num_heads": ("MODEL", "SWINV2", "NUM_HEADS"),
  187. "window_size": ("MODEL", "SWINV2", "WINDOW_SIZE"),
  188. "mlp_ratio": ("MODEL", "SWINV2", "MLP_RATIO"),
  189. "qkv_bias": ("MODEL", "SWINV2", "QKV_BIAS"),
  190. "ape": ("MODEL", "SWINV2", "APE"),
  191. "patch_norm": ("MODEL", "SWINV2", "PATCH_NORM"),
  192. "pretrained_window_sizes": ("MODEL", "SWINV2", "PRETRAINED_WINDOW_SIZES"),
  193. "drop_rate": ("MODEL", "DROP_RATE"),
  194. "drop_path_rate": ("MODEL", "DROP_PATH_RATE"),
  195. }
  196. for key, path in model_key_map.items():
  197. if key in model_kwargs and model_kwargs[key] is not None:
  198. overrides.append((path, model_kwargs[key]))
  199. return overrides
  200. def _collect_arg_overrides(args: Namespace | None) -> list[tuple[tuple[str, ...], Any]]:
  201. if args is None:
  202. return []
  203. key_map = {
  204. ("model_name", "model"): ("MODEL", "NAME"),
  205. ("pretrained", "weights", "weight_path", "ckpt", "checkpoint"): ("MODEL", "PRETRAINED"),
  206. ("num_classes",): ("MODEL", "NUM_CLASSES"),
  207. ("img_size", "image_size", "input_size"): ("DATA", "IMG_SIZE"),
  208. ("in_chans", "in_channels"): ("MODEL", "SWINV2", "IN_CHANS"),
  209. ("patch_size",): ("MODEL", "SWINV2", "PATCH_SIZE"),
  210. ("embed_dim",): ("MODEL", "SWINV2", "EMBED_DIM"),
  211. ("depths",): ("MODEL", "SWINV2", "DEPTHS"),
  212. ("num_heads",): ("MODEL", "SWINV2", "NUM_HEADS"),
  213. ("window_size",): ("MODEL", "SWINV2", "WINDOW_SIZE"),
  214. ("mlp_ratio",): ("MODEL", "SWINV2", "MLP_RATIO"),
  215. ("qkv_bias",): ("MODEL", "SWINV2", "QKV_BIAS"),
  216. ("ape",): ("MODEL", "SWINV2", "APE"),
  217. ("patch_norm",): ("MODEL", "SWINV2", "PATCH_NORM"),
  218. ("pretrained_window_sizes",): ("MODEL", "SWINV2", "PRETRAINED_WINDOW_SIZES"),
  219. ("drop_rate",): ("MODEL", "DROP_RATE"),
  220. ("drop_path_rate",): ("MODEL", "DROP_PATH_RATE"),
  221. ("use_checkpoint",): ("TRAIN", "USE_CHECKPOINT"),
  222. }
  223. overrides: list[tuple[tuple[str, ...], Any]] = []
  224. for names, path in key_map.items():
  225. value = _get_arg(args, *names)
  226. if value is not None:
  227. overrides.append((path, value))
  228. return overrides
  229. def _apply_overrides(config: dict[str, Any], overrides: list[tuple[tuple[str, ...], Any]]) -> dict[str, Any]:
  230. for path, value in overrides:
  231. _set_nested(config, path, value)
  232. return config
  233. def _extract_state_dict(checkpoint: Any) -> dict[str, torch.Tensor]:
  234. if isinstance(checkpoint, dict):
  235. for key in ("model", "state_dict"):
  236. if key in checkpoint and isinstance(checkpoint[key], dict):
  237. return checkpoint[key]
  238. return checkpoint
  239. raise TypeError(f"Unsupported checkpoint format: {type(checkpoint)!r}")
  240. def _remap_head_if_needed(model: SwinTransformerV2, state_dict: dict[str, torch.Tensor]):
  241. if "head.bias" not in state_dict or "head.weight" not in state_dict:
  242. return
  243. ckpt_classes = state_dict["head.bias"].shape[0]
  244. head_bias = getattr(model.head, "bias", None)
  245. model_classes = head_bias.shape[0] if head_bias is not None else 0
  246. if ckpt_classes == model_classes:
  247. return
  248. if ckpt_classes == 21841 and model_classes == 1000 and MAP22KTO1K_PATH.exists():
  249. with MAP22KTO1K_PATH.open("r", encoding="utf-8") as handle:
  250. indices = [int(line.strip()) for line in handle if line.strip()]
  251. state_dict["head.weight"] = state_dict["head.weight"][indices, :]
  252. state_dict["head.bias"] = state_dict["head.bias"][indices]
  253. return
  254. state_dict.pop("head.weight", None)
  255. state_dict.pop("head.bias", None)
  256. def _sanitize_state_dict(model: SwinTransformerV2, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
  257. if any(key.startswith("encoder.") for key in state_dict):
  258. state_dict = {
  259. key.replace("encoder.", "", 1): value
  260. for key, value in state_dict.items()
  261. if key.startswith("encoder.")
  262. }
  263. remapped: dict[str, torch.Tensor] = {}
  264. for key, value in state_dict.items():
  265. if "relative_position_index" in key or "relative_coords_table" in key or "attn_mask" in key:
  266. continue
  267. remapped[key.replace("rpe_mlp", "cpb_mlp")] = value
  268. _remap_head_if_needed(model, remapped)
  269. model_state = model.state_dict()
  270. filtered: dict[str, torch.Tensor] = {}
  271. for key, value in remapped.items():
  272. if key in model_state and model_state[key].shape == value.shape:
  273. filtered[key] = value
  274. return filtered
  275. def _load_checkpoint(weight_path: Path) -> dict[str, Any]:
  276. try:
  277. return torch.load(weight_path, map_location="cpu")
  278. except Exception as exc:
  279. raise RuntimeError(f"Failed to load SwinV2 checkpoint: {weight_path}") from exc
  280. def build_swinv2(
  281. model_name: str | None = None,
  282. config_path: str | Path | None = None,
  283. weight_path: str | Path | None = None,
  284. args: Namespace | None = None,
  285. *,
  286. num_classes: int | None = None,
  287. img_size: int | None = None,
  288. in_chans: int | None = None,
  289. use_checkpoint: bool | None = None,
  290. strict: bool = False,
  291. load_weights: bool = True,
  292. return_config: bool = False,
  293. **model_kwargs,
  294. ):
  295. """Build a SwinTransformerV2 with loaded weights.
  296. Precedence order:
  297. 1. internal defaults
  298. 2. YAML config under ``configs/swinv2``
  299. 3. explicit function inputs outside config files
  300. 4. command line style ``args`` overrides
  301. """
  302. explicit_weight_path = _to_path(weight_path)
  303. initial_model_name = _resolve_model_name(model_name, _to_path(config_path), explicit_weight_path, args)
  304. resolved_config_path = _resolve_config_path(initial_model_name, config_path, args)
  305. config_dict = _load_yaml_config(resolved_config_path)
  306. resolved_model_name = _resolve_model_name(
  307. initial_model_name or config_dict["MODEL"]["NAME"],
  308. resolved_config_path,
  309. explicit_weight_path,
  310. args,
  311. )
  312. resolved_weight_path = _resolve_weight_path(resolved_model_name, weight_path, args)
  313. function_overrides = _collect_function_overrides(
  314. model_name=resolved_model_name,
  315. weight_path=resolved_weight_path,
  316. num_classes=num_classes,
  317. img_size=img_size,
  318. in_chans=in_chans,
  319. use_checkpoint=use_checkpoint,
  320. model_kwargs=model_kwargs,
  321. )
  322. arg_overrides = _collect_arg_overrides(args)
  323. config_dict = _apply_overrides(config_dict, function_overrides)
  324. config_dict = _apply_overrides(config_dict, arg_overrides)
  325. model_cfg = config_dict["MODEL"]
  326. swinv2_cfg = model_cfg["SWINV2"]
  327. model = SwinTransformerV2(
  328. img_size=config_dict["DATA"]["IMG_SIZE"],
  329. patch_size=swinv2_cfg["PATCH_SIZE"],
  330. in_chans=swinv2_cfg["IN_CHANS"],
  331. num_classes=model_cfg["NUM_CLASSES"],
  332. embed_dim=swinv2_cfg["EMBED_DIM"],
  333. depths=tuple(swinv2_cfg["DEPTHS"]),
  334. num_heads=tuple(swinv2_cfg["NUM_HEADS"]),
  335. window_size=swinv2_cfg["WINDOW_SIZE"],
  336. mlp_ratio=swinv2_cfg["MLP_RATIO"],
  337. qkv_bias=swinv2_cfg["QKV_BIAS"],
  338. drop_rate=model_cfg["DROP_RATE"],
  339. drop_path_rate=model_cfg["DROP_PATH_RATE"],
  340. ape=swinv2_cfg["APE"],
  341. patch_norm=swinv2_cfg["PATCH_NORM"],
  342. use_checkpoint=config_dict["TRAIN"]["USE_CHECKPOINT"],
  343. pretrained_window_sizes=tuple(swinv2_cfg["PRETRAINED_WINDOW_SIZES"]),
  344. )
  345. if load_weights:
  346. if resolved_weight_path is None:
  347. raise FileNotFoundError(
  348. f"No SwinV2 weight file resolved for model '{model_cfg['NAME']}'. "
  349. f"Expected one under {SWINV2_WEIGHT_DIR}."
  350. )
  351. if not resolved_weight_path.exists():
  352. raise FileNotFoundError(f"SwinV2 weight file not found: {resolved_weight_path}")
  353. checkpoint = _load_checkpoint(resolved_weight_path)
  354. state_dict = _sanitize_state_dict(model, _extract_state_dict(checkpoint))
  355. model.load_state_dict(state_dict, strict=strict)
  356. config = _dict_to_namespace(config_dict)
  357. if return_config:
  358. return model, config
  359. return model
  360. def build_swinv2_auto(
  361. model_name: str | None = None,
  362. config_path: str | Path | None = None,
  363. weight_path: str | Path | None = None,
  364. args: Namespace | None = None,
  365. *,
  366. verbose: bool = True,
  367. return_config: bool = False,
  368. return_resolution: bool = False,
  369. **kwargs,
  370. ):
  371. """Auto-resolve SwinV2 config and weights by ``MODEL.NAME`` and print sources.
  372. This wrapper keeps the same precedence rules as ``build_swinv2`` while making
  373. the config/weight resolution explicit for callers.
  374. """
  375. explicit_weight_path = _to_path(weight_path)
  376. candidate_model_name = _resolve_model_name(model_name, _to_path(config_path), explicit_weight_path, args)
  377. resolved_config_path, config_source = _resolve_config_with_source(candidate_model_name, config_path, args)
  378. temp_config = _load_yaml_config(resolved_config_path)
  379. final_model_name = _resolve_model_name(
  380. candidate_model_name or temp_config["MODEL"]["NAME"],
  381. resolved_config_path,
  382. explicit_weight_path,
  383. args,
  384. ) or temp_config["MODEL"]["NAME"]
  385. resolved_weight_path, weight_source = _resolve_weight_with_source(final_model_name, weight_path, args)
  386. built = build_swinv2(
  387. model_name=final_model_name,
  388. config_path=resolved_config_path,
  389. weight_path=resolved_weight_path,
  390. args=args,
  391. return_config=True,
  392. **kwargs,
  393. )
  394. if not isinstance(built, tuple) or len(built) != 2:
  395. raise RuntimeError("build_swinv2(return_config=True) must return (model, config)")
  396. model, config = built
  397. resolution = {
  398. "model_name": config.MODEL.NAME,
  399. "config_path": str(resolved_config_path) if resolved_config_path is not None else None,
  400. "config_source": config_source,
  401. "weight_path": str(resolved_weight_path) if resolved_weight_path is not None else None,
  402. "weight_source": weight_source,
  403. }
  404. if verbose:
  405. print(
  406. "[build_swinv2_auto] "
  407. f"MODEL.NAME={resolution['model_name']} | "
  408. f"config={resolution['config_path']} ({resolution['config_source']}) | "
  409. f"weight={resolution['weight_path']} ({resolution['weight_source']})"
  410. )
  411. if return_config and return_resolution:
  412. return model, config, resolution
  413. if return_config:
  414. return model, config
  415. if return_resolution:
  416. return model, resolution
  417. return model
  418. __all__ = ["build_swinv2", "build_swinv2_auto"]