| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504 |
- from __future__ import annotations
- from argparse import Namespace
- from pathlib import Path
- from types import SimpleNamespace
- from typing import Any
- import torch
- import yaml
- from lib.SwinTransformer.models.swin_transformer_v2 import SwinTransformerV2
# Repository root: two directory levels above the folder containing this file.
ROOT_DIR = Path(__file__).resolve().parents[2]
# Per-model YAML configs, looked up as ``<MODEL.NAME>.yaml``.
SWINV2_CONFIG_DIR = ROOT_DIR.joinpath("configs", "swinv2")
# Per-model checkpoints, looked up as ``<MODEL.NAME>.pth``.
SWINV2_WEIGHT_DIR = ROOT_DIR.joinpath("weights", "swinv2")
# Row-index table used to slice a 21841-class head down to 1000 classes.
MAP22KTO1K_PATH = ROOT_DIR.joinpath("lib", "SwinTransformer", "data", "map22kto1k.txt")
# Built-in baseline configuration. _load_yaml_config copies it and layers a YAML
# file on top; build_swinv2 then applies function and CLI overrides.
# NOTE(review): MODEL.NAME advertises "window8_256" while DATA.IMG_SIZE and
# SWINV2.WINDOW_SIZE default to 224 and 7 - confirm which is intended when no
# YAML config is found for the default name.
DEFAULTS: dict[str, Any] = dict(
    DATA=dict(IMG_SIZE=224),
    MODEL=dict(
        TYPE="swinv2",
        NAME="swinv2_tiny_patch4_window8_256",
        NUM_CLASSES=1000,
        DROP_RATE=0.0,
        DROP_PATH_RATE=0.1,
        PRETRAINED="",
        SWINV2=dict(
            PATCH_SIZE=4,
            IN_CHANS=3,
            EMBED_DIM=96,
            DEPTHS=[2, 2, 6, 2],
            NUM_HEADS=[3, 6, 12, 24],
            WINDOW_SIZE=7,
            MLP_RATIO=4.0,
            QKV_BIAS=True,
            APE=False,
            PATCH_NORM=True,
            PRETRAINED_WINDOW_SIZES=[0, 0, 0, 0],
        ),
    ),
    TRAIN=dict(USE_CHECKPOINT=False),
)
- def _deep_copy_dict(value: dict[str, Any]) -> dict[str, Any]:
- copied: dict[str, Any] = {}
- for key, item in value.items():
- if isinstance(item, dict):
- copied[key] = _deep_copy_dict(item)
- elif isinstance(item, list):
- copied[key] = list(item)
- else:
- copied[key] = item
- return copied
- def _merge_dict(dst: dict[str, Any], src: dict[str, Any]) -> dict[str, Any]:
- for key, value in src.items():
- if isinstance(value, dict) and isinstance(dst.get(key), dict):
- _merge_dict(dst[key], value)
- else:
- dst[key] = value
- return dst
- def _dict_to_namespace(value: Any) -> Any:
- if isinstance(value, dict):
- return SimpleNamespace(**{key: _dict_to_namespace(item) for key, item in value.items()})
- if isinstance(value, list):
- return [_dict_to_namespace(item) for item in value]
- return value
- def _get_arg(args: Namespace | None, *names: str) -> Any:
- if args is None:
- return None
- for name in names:
- if hasattr(args, name):
- value = getattr(args, name)
- if value is not None:
- return value
- return None
- def _to_path(value: str | Path | None) -> Path | None:
- if value is None:
- return None
- return Path(value).expanduser().resolve()
def _resolve_model_name(
    model_name: str | None,
    config_path: Path | None,
    weight_path: Path | None,
    args: Namespace | None,
) -> str | None:
    """Pick the model name from the first source that yields a non-empty value.

    Priority: ``args.model_name``/``args.model``, then ``model_name``, then the
    weight file's stem, then the config file's stem. Falsy candidates (``None``,
    ``""``) are skipped. The config stem (or ``None``) is the final fallback.
    """
    cli_name = _get_arg(args, "model_name", "model")
    if cli_name:
        return cli_name
    if model_name:
        return model_name
    weight_stem = None if weight_path is None else weight_path.stem
    if weight_stem:
        return weight_stem
    return None if config_path is None else config_path.stem
def _resolve_config_path(model_name: str | None, config_path: str | Path | None, args: Namespace | None) -> Path | None:
    """Return the YAML config path to load, or ``None`` to use the defaults only.

    Lookup order: CLI ``cfg``/``config``/``config_path`` attribute, then the
    ``config_path`` argument, then ``SWINV2_CONFIG_DIR/<model_name>.yaml`` if it
    exists. This delegates to ``_resolve_config_with_source`` so the order is
    defined in one place and the two resolvers cannot drift apart.
    """
    resolved_path, _source = _resolve_config_with_source(model_name, config_path, args)
    return resolved_path
def _resolve_weight_path(model_name: str | None, weight_path: str | Path | None, args: Namespace | None) -> Path | None:
    """Return the checkpoint path to load, or ``None`` if none could be resolved.

    Lookup order: CLI ``pretrained``/``weights``/``weight_path``/``ckpt``/
    ``checkpoint`` attribute, then the ``weight_path`` argument, then
    ``SWINV2_WEIGHT_DIR/<model_name>.pth`` if it exists. This delegates to
    ``_resolve_weight_with_source`` so the order is defined in one place.
    """
    resolved_path, _source = _resolve_weight_with_source(model_name, weight_path, args)
    return resolved_path
def _resolve_config_with_source(
    model_name: str | None,
    config_path: str | Path | None,
    args: Namespace | None,
) -> tuple[Path | None, str]:
    """Resolve the YAML config path and describe where it came from.

    Lookup order:
      1. CLI namespace attribute ``cfg`` / ``config`` / ``config_path``;
      2. the ``config_path`` argument;
      3. ``SWINV2_CONFIG_DIR/<model_name>.yaml`` when that file exists.

    Returns ``(path, source_label)``. ``path`` is ``None`` (label
    ``"defaults only"``) when only the built-in defaults apply. Explicit paths
    are returned without checking that they exist.
    """
    explicit_sources = (
        (_get_arg(args, "cfg", "config", "config_path"), "args.cfg"),
        (config_path, "function config_path"),
    )
    for raw_value, label in explicit_sources:
        resolved = _to_path(raw_value)
        if resolved is not None:
            return resolved, label
    if model_name is not None:
        by_name = SWINV2_CONFIG_DIR / f"{model_name}.yaml"
        if by_name.exists():
            return by_name, "auto by MODEL.NAME"
    return None, "defaults only"
def _resolve_weight_with_source(
    model_name: str | None,
    weight_path: str | Path | None,
    args: Namespace | None,
) -> tuple[Path | None, str]:
    """Resolve the checkpoint path and describe where it came from.

    Lookup order:
      1. CLI namespace attribute ``pretrained`` / ``weights`` / ``weight_path`` /
         ``ckpt`` / ``checkpoint``;
      2. the ``weight_path`` argument;
      3. ``SWINV2_WEIGHT_DIR/<model_name>.pth`` when that file exists.

    Returns ``(path, source_label)``. ``path`` is ``None`` (label
    ``"not resolved"``) when no source applies. Explicit paths are returned
    without checking that they exist.
    """
    explicit_sources = (
        (_get_arg(args, "pretrained", "weights", "weight_path", "ckpt", "checkpoint"), "args.pretrained"),
        (weight_path, "function weight_path"),
    )
    for raw_value, label in explicit_sources:
        resolved = _to_path(raw_value)
        if resolved is not None:
            return resolved, label
    if model_name is not None:
        by_name = SWINV2_WEIGHT_DIR / f"{model_name}.pth"
        if by_name.exists():
            return by_name, "auto by MODEL.NAME"
    return None, "not resolved"
def _load_yaml_config(config_path: Path | None) -> dict[str, Any]:
    """Return a copy of ``DEFAULTS`` merged with the YAML file at ``config_path``.

    Args:
        config_path: YAML file to layer over the defaults, or ``None`` to use
            the defaults alone.

    Returns:
        A fresh nested dict; ``DEFAULTS`` itself is never modified.

    Raises:
        FileNotFoundError: ``config_path`` is given but does not exist.
        TypeError: the YAML document's top level is not a mapping.
    """
    config = _deep_copy_dict(DEFAULTS)
    if config_path is None:
        return config
    if not config_path.exists():
        raise FileNotFoundError(f"SwinV2 config not found: {config_path}")
    with config_path.open("r", encoding="utf-8") as handle:
        # safe_load: a config file must not be able to construct arbitrary objects.
        yaml_config = yaml.safe_load(handle) or {}
    if not isinstance(yaml_config, dict):
        # A list/scalar document would otherwise fail inside _merge_dict with
        # an AttributeError on ``.items()`` that does not name the bad file.
        raise TypeError(
            f"SwinV2 config must be a YAML mapping, got {type(yaml_config).__name__}: {config_path}"
        )
    return _merge_dict(config, yaml_config)
- def _set_nested(config: dict[str, Any], path: tuple[str, ...], value: Any):
- current = config
- for key in path[:-1]:
- current = current.setdefault(key, {})
- current[path[-1]] = value
- def _collect_function_overrides(
- model_name: str | None,
- weight_path: Path | None,
- num_classes: int | None,
- img_size: int | None,
- in_chans: int | None,
- use_checkpoint: bool | None,
- model_kwargs: dict[str, Any],
- ) -> list[tuple[tuple[str, ...], Any]]:
- overrides: list[tuple[tuple[str, ...], Any]] = []
- if model_name is not None:
- overrides.append((("MODEL", "NAME"), model_name))
- if weight_path is not None:
- overrides.append((("MODEL", "PRETRAINED"), str(weight_path)))
- if num_classes is not None:
- overrides.append((("MODEL", "NUM_CLASSES"), num_classes))
- if img_size is not None:
- overrides.append((("DATA", "IMG_SIZE"), img_size))
- if in_chans is not None:
- overrides.append((("MODEL", "SWINV2", "IN_CHANS"), in_chans))
- if use_checkpoint is not None:
- overrides.append((("TRAIN", "USE_CHECKPOINT"), use_checkpoint))
- model_key_map = {
- "patch_size": ("MODEL", "SWINV2", "PATCH_SIZE"),
- "embed_dim": ("MODEL", "SWINV2", "EMBED_DIM"),
- "depths": ("MODEL", "SWINV2", "DEPTHS"),
- "num_heads": ("MODEL", "SWINV2", "NUM_HEADS"),
- "window_size": ("MODEL", "SWINV2", "WINDOW_SIZE"),
- "mlp_ratio": ("MODEL", "SWINV2", "MLP_RATIO"),
- "qkv_bias": ("MODEL", "SWINV2", "QKV_BIAS"),
- "ape": ("MODEL", "SWINV2", "APE"),
- "patch_norm": ("MODEL", "SWINV2", "PATCH_NORM"),
- "pretrained_window_sizes": ("MODEL", "SWINV2", "PRETRAINED_WINDOW_SIZES"),
- "drop_rate": ("MODEL", "DROP_RATE"),
- "drop_path_rate": ("MODEL", "DROP_PATH_RATE"),
- }
- for key, path in model_key_map.items():
- if key in model_kwargs and model_kwargs[key] is not None:
- overrides.append((path, model_kwargs[key]))
- return overrides
def _collect_arg_overrides(args: Namespace | None) -> list[tuple[tuple[str, ...], Any]]:
    """Translate CLI-style ``args`` attributes into ``(config_key_path, value)`` pairs.

    Each config key may be fed by several attribute aliases; the first alias
    present with a non-``None`` value wins. Keys with no such alias are omitted.
    Returns ``[]`` when ``args`` is ``None``.
    """
    if args is None:
        return []
    swinv2 = ("MODEL", "SWINV2")
    alias_table = (
        (("model_name", "model"), ("MODEL", "NAME")),
        (("pretrained", "weights", "weight_path", "ckpt", "checkpoint"), ("MODEL", "PRETRAINED")),
        (("num_classes",), ("MODEL", "NUM_CLASSES")),
        (("img_size", "image_size", "input_size"), ("DATA", "IMG_SIZE")),
        (("in_chans", "in_channels"), swinv2 + ("IN_CHANS",)),
        (("patch_size",), swinv2 + ("PATCH_SIZE",)),
        (("embed_dim",), swinv2 + ("EMBED_DIM",)),
        (("depths",), swinv2 + ("DEPTHS",)),
        (("num_heads",), swinv2 + ("NUM_HEADS",)),
        (("window_size",), swinv2 + ("WINDOW_SIZE",)),
        (("mlp_ratio",), swinv2 + ("MLP_RATIO",)),
        (("qkv_bias",), swinv2 + ("QKV_BIAS",)),
        (("ape",), swinv2 + ("APE",)),
        (("patch_norm",), swinv2 + ("PATCH_NORM",)),
        (("pretrained_window_sizes",), swinv2 + ("PRETRAINED_WINDOW_SIZES",)),
        (("drop_rate",), ("MODEL", "DROP_RATE")),
        (("drop_path_rate",), ("MODEL", "DROP_PATH_RATE")),
        (("use_checkpoint",), ("TRAIN", "USE_CHECKPOINT")),
    )
    looked_up = ((target, _get_arg(args, *aliases)) for aliases, target in alias_table)
    return [(target, found) for target, found in looked_up if found is not None]
def _apply_overrides(config: dict[str, Any], overrides: list[tuple[tuple[str, ...], Any]]) -> dict[str, Any]:
    """Write each ``(key_path, value)`` override into ``config`` in list order.

    Later entries win when two overrides target the same key. ``config`` is
    mutated in place and also returned for convenience.
    """
    for key_path, new_value in overrides:
        _set_nested(config, key_path, new_value)
    return config
- def _extract_state_dict(checkpoint: Any) -> dict[str, torch.Tensor]:
- if isinstance(checkpoint, dict):
- for key in ("model", "state_dict"):
- if key in checkpoint and isinstance(checkpoint[key], dict):
- return checkpoint[key]
- return checkpoint
- raise TypeError(f"Unsupported checkpoint format: {type(checkpoint)!r}")
def _remap_head_if_needed(model: SwinTransformerV2, state_dict: dict[str, torch.Tensor]):
    """Reconcile the checkpoint's classifier head with the model's head, in place.

    - No ``head.weight``/``head.bias`` pair, or matching class counts: no-op.
    - A 21841-class checkpoint loaded into a 1000-class model, with
      ``MAP22KTO1K_PATH`` present: keep only the head rows listed in that file.
    - Any other mismatch: drop both head entries so the model keeps its own
      head parameters.

    A model head without a ``bias`` attribute counts as 0 classes.
    """
    if not ("head.weight" in state_dict and "head.bias" in state_dict):
        return
    ckpt_classes = state_dict["head.bias"].shape[0]
    model_bias = getattr(model.head, "bias", None)
    model_classes = 0 if model_bias is None else model_bias.shape[0]
    if ckpt_classes == model_classes:
        return
    can_map_22k = ckpt_classes == 21841 and model_classes == 1000 and MAP22KTO1K_PATH.exists()
    if not can_map_22k:
        for head_key in ("head.weight", "head.bias"):
            state_dict.pop(head_key, None)
        return
    with MAP22KTO1K_PATH.open("r", encoding="utf-8") as handle:
        # One integer row index per non-blank line.
        keep_rows = [int(text) for text in (line.strip() for line in handle) if text]
    state_dict["head.weight"] = state_dict["head.weight"][keep_rows, :]
    state_dict["head.bias"] = state_dict["head.bias"][keep_rows]
def _sanitize_state_dict(model: SwinTransformerV2, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Adapt a checkpoint state dict so it can be loaded into ``model``.

    Steps:
      1. If any key starts with ``encoder.``, keep only those keys, with that
         prefix removed (presumably a pretraining wrapper around the backbone;
         TODO confirm).
      2. Skip keys containing ``relative_position_index``,
         ``relative_coords_table`` or ``attn_mask``.
      3. Rename ``rpe_mlp`` to ``cpb_mlp`` in key names.
      4. Reconcile the classifier head via ``_remap_head_if_needed``.
      5. Keep only keys present in ``model.state_dict()`` with identical shapes.
    """
    encoder_prefix = "encoder."
    if any(name.startswith(encoder_prefix) for name in state_dict):
        state_dict = {
            name[len(encoder_prefix):]: tensor
            for name, tensor in state_dict.items()
            if name.startswith(encoder_prefix)
        }
    skipped_markers = ("relative_position_index", "relative_coords_table", "attn_mask")
    remapped = {
        name.replace("rpe_mlp", "cpb_mlp"): tensor
        for name, tensor in state_dict.items()
        if not any(marker in name for marker in skipped_markers)
    }
    _remap_head_if_needed(model, remapped)
    expected = model.state_dict()
    return {
        name: tensor
        for name, tensor in remapped.items()
        if name in expected and expected[name].shape == tensor.shape
    }
- def _load_checkpoint(weight_path: Path) -> dict[str, Any]:
- try:
- return torch.load(weight_path, map_location="cpu")
- except Exception as exc:
- raise RuntimeError(f"Failed to load SwinV2 checkpoint: {weight_path}") from exc
def build_swinv2(
    model_name: str | None = None,
    config_path: str | Path | None = None,
    weight_path: str | Path | None = None,
    args: Namespace | None = None,
    *,
    num_classes: int | None = None,
    img_size: int | None = None,
    in_chans: int | None = None,
    use_checkpoint: bool | None = None,
    strict: bool = False,
    load_weights: bool = True,
    return_config: bool = False,
    **model_kwargs,
):
    """Build a SwinTransformerV2 with loaded weights.

    Precedence order (later entries override earlier ones):
        1. internal defaults
        2. YAML config under ``configs/swinv2``
        3. explicit function inputs outside config files
        4. command line style ``args`` overrides

    If no model name, config, weight path or ``args`` name is supplied, the
    name falls back to ``DEFAULTS["MODEL"]["NAME"]``. The YAML config for that
    name is then looked up as well, so the architecture matches the weight
    file that is auto-resolved by the same name.

    Args:
        model_name: Name used to auto-locate ``<name>.yaml`` / ``<name>.pth``.
        config_path: Explicit YAML config (CLI ``cfg``/``config``/``config_path`` wins).
        weight_path: Explicit checkpoint (CLI ``pretrained``/``weights``/... wins).
        args: Optional CLI-style namespace; its non-``None`` attributes win.
        num_classes, img_size, in_chans, use_checkpoint: Config overrides.
        strict: Forwarded to ``load_state_dict``.
        load_weights: When false, skip checkpoint resolution checks and loading.
        return_config: Also return the final config as nested ``SimpleNamespace``.
        **model_kwargs: Extra architecture overrides (``embed_dim``, ``depths``,
            ``window_size``, ...); unrecognised keys are ignored.

    Returns:
        The model, or ``(model, config)`` when ``return_config`` is true.

    Raises:
        FileNotFoundError: A resolved YAML config is missing, or weights are
            requested but no file was resolved / the file is missing.
        RuntimeError: The checkpoint could not be read.
    """
    explicit_weight_path = _to_path(weight_path)
    initial_model_name = _resolve_model_name(model_name, _to_path(config_path), explicit_weight_path, args)
    resolved_config_path = _resolve_config_path(initial_model_name, config_path, args)
    config_dict = _load_yaml_config(resolved_config_path)
    resolved_model_name = _resolve_model_name(
        initial_model_name or config_dict["MODEL"]["NAME"],
        resolved_config_path,
        explicit_weight_path,
        args,
    )
    if (
        resolved_config_path is None
        and resolved_model_name
        and resolved_model_name != initial_model_name
    ):
        # The name was only learned from the defaults, after the first config
        # lookup ran without it. Weights are resolved by this name below, so
        # resolve its YAML too. Otherwise e.g. a "window8_256" checkpoint would
        # be loaded into a model built with the default img/window settings.
        fallback_config_path = _resolve_config_path(resolved_model_name, None, args)
        if fallback_config_path is not None:
            resolved_config_path = fallback_config_path
            config_dict = _load_yaml_config(resolved_config_path)
    resolved_weight_path = _resolve_weight_path(resolved_model_name, weight_path, args)
    function_overrides = _collect_function_overrides(
        model_name=resolved_model_name,
        weight_path=resolved_weight_path,
        num_classes=num_classes,
        img_size=img_size,
        in_chans=in_chans,
        use_checkpoint=use_checkpoint,
        model_kwargs=model_kwargs,
    )
    arg_overrides = _collect_arg_overrides(args)
    # Function inputs are applied first and CLI args last, so args win (3 < 4).
    config_dict = _apply_overrides(config_dict, function_overrides)
    config_dict = _apply_overrides(config_dict, arg_overrides)
    model_cfg = config_dict["MODEL"]
    swinv2_cfg = model_cfg["SWINV2"]
    model = SwinTransformerV2(
        img_size=config_dict["DATA"]["IMG_SIZE"],
        patch_size=swinv2_cfg["PATCH_SIZE"],
        in_chans=swinv2_cfg["IN_CHANS"],
        num_classes=model_cfg["NUM_CLASSES"],
        embed_dim=swinv2_cfg["EMBED_DIM"],
        depths=tuple(swinv2_cfg["DEPTHS"]),
        num_heads=tuple(swinv2_cfg["NUM_HEADS"]),
        window_size=swinv2_cfg["WINDOW_SIZE"],
        mlp_ratio=swinv2_cfg["MLP_RATIO"],
        qkv_bias=swinv2_cfg["QKV_BIAS"],
        drop_rate=model_cfg["DROP_RATE"],
        drop_path_rate=model_cfg["DROP_PATH_RATE"],
        ape=swinv2_cfg["APE"],
        patch_norm=swinv2_cfg["PATCH_NORM"],
        use_checkpoint=config_dict["TRAIN"]["USE_CHECKPOINT"],
        pretrained_window_sizes=tuple(swinv2_cfg["PRETRAINED_WINDOW_SIZES"]),
    )
    if load_weights:
        if resolved_weight_path is None:
            raise FileNotFoundError(
                f"No SwinV2 weight file resolved for model '{model_cfg['NAME']}'. "
                f"Expected one under {SWINV2_WEIGHT_DIR}."
            )
        if not resolved_weight_path.exists():
            raise FileNotFoundError(f"SwinV2 weight file not found: {resolved_weight_path}")
        checkpoint = _load_checkpoint(resolved_weight_path)
        state_dict = _sanitize_state_dict(model, _extract_state_dict(checkpoint))
        # Keys dropped by _sanitize_state_dict show up as missing under
        # strict=True (presumably why strict defaults to False).
        model.load_state_dict(state_dict, strict=strict)
    config = _dict_to_namespace(config_dict)
    if return_config:
        return model, config
    return model
def build_swinv2_auto(
    model_name: str | None = None,
    config_path: str | Path | None = None,
    weight_path: str | Path | None = None,
    args: Namespace | None = None,
    *,
    verbose: bool = True,
    return_config: bool = False,
    return_resolution: bool = False,
    **kwargs,
):
    """Auto-resolve SwinV2 config and weights by ``MODEL.NAME`` and print sources.

    This wrapper keeps the same precedence rules as ``build_swinv2`` while making
    the config/weight resolution explicit for callers. It resolves both paths
    itself, passes them to ``build_swinv2`` and reports where each came from.

    Args:
        model_name, config_path, weight_path, args: As for ``build_swinv2``.
        verbose: Print a one-line resolution summary.
        return_config: Also return the final config namespace.
        return_resolution: Also return a dict with ``model_name``,
            ``config_path``, ``config_source``, ``weight_path`` and
            ``weight_source``.
        **kwargs: Forwarded to ``build_swinv2``.

    Returns:
        ``model``, ``(model, config)``, ``(model, resolution)`` or
        ``(model, config, resolution)`` depending on the ``return_*`` flags.

    Raises:
        RuntimeError: ``build_swinv2`` did not return ``(model, config)``;
            plus anything ``build_swinv2`` raises.
    """
    explicit_weight_path = _to_path(weight_path)
    candidate_model_name = _resolve_model_name(model_name, _to_path(config_path), explicit_weight_path, args)
    resolved_config_path, config_source = _resolve_config_with_source(candidate_model_name, config_path, args)
    temp_config = _load_yaml_config(resolved_config_path)
    final_model_name = _resolve_model_name(
        candidate_model_name or temp_config["MODEL"]["NAME"],
        resolved_config_path,
        explicit_weight_path,
        args,
    ) or temp_config["MODEL"]["NAME"]
    if resolved_config_path is None and final_model_name != candidate_model_name:
        # The name was only learned from the defaults, after the first config
        # lookup ran without it. build_swinv2, called with this name, would pick
        # up ``<name>.yaml`` on its own. Resolve it here so the reported source
        # matches what is actually built.
        resolved_config_path, config_source = _resolve_config_with_source(final_model_name, None, args)
    resolved_weight_path, weight_source = _resolve_weight_with_source(final_model_name, weight_path, args)
    built = build_swinv2(
        model_name=final_model_name,
        config_path=resolved_config_path,
        weight_path=resolved_weight_path,
        args=args,
        return_config=True,
        **kwargs,
    )
    if not isinstance(built, tuple) or len(built) != 2:
        raise RuntimeError("build_swinv2(return_config=True) must return (model, config)")
    model, config = built
    resolution = {
        "model_name": config.MODEL.NAME,
        "config_path": str(resolved_config_path) if resolved_config_path is not None else None,
        "config_source": config_source,
        "weight_path": str(resolved_weight_path) if resolved_weight_path is not None else None,
        "weight_source": weight_source,
    }
    if verbose:
        print(
            "[build_swinv2_auto] "
            f"MODEL.NAME={resolution['model_name']} | "
            f"config={resolution['config_path']} ({resolution['config_source']}) | "
            f"weight={resolution['weight_path']} ({resolution['weight_source']})"
        )
    if return_config and return_resolution:
        return model, config, resolution
    if return_config:
        return model, config
    if return_resolution:
        return model, resolution
    return model
# Names exported by ``from <this module> import *``: the two public builders.
__all__ = ["build_swinv2", "build_swinv2_auto"]
|