"""Build SwinTransformerV2 models from layered defaults, YAML, and overrides."""

from __future__ import annotations

from argparse import Namespace
from pathlib import Path
from types import SimpleNamespace
from typing import Any

import torch
import yaml

from lib.SwinTransformer.models.swin_transformer_v2 import SwinTransformerV2

ROOT_DIR = Path(__file__).resolve().parents[2]
SWINV2_CONFIG_DIR = ROOT_DIR / "configs" / "swinv2"
SWINV2_WEIGHT_DIR = ROOT_DIR / "weights" / "swinv2"
MAP22KTO1K_PATH = ROOT_DIR / "lib" / "SwinTransformer" / "data" / "map22kto1k.txt"

# Attribute names probed (in this order) on the ``args`` namespace.
_MODEL_NAME_ARG_NAMES: tuple[str, ...] = ("model_name", "model")
_CONFIG_ARG_NAMES: tuple[str, ...] = ("cfg", "config", "config_path")
_WEIGHT_ARG_NAMES: tuple[str, ...] = (
    "pretrained",
    "weights",
    "weight_path",
    "ckpt",
    "checkpoint",
)

# Baseline configuration; YAML files and overrides are merged on top of a copy.
DEFAULTS: dict[str, Any] = {
    "DATA": {
        "IMG_SIZE": 224,
    },
    "MODEL": {
        "TYPE": "swinv2",
        "NAME": "swinv2_tiny_patch4_window8_256",
        "NUM_CLASSES": 1000,
        "DROP_RATE": 0.0,
        "DROP_PATH_RATE": 0.1,
        "PRETRAINED": "",
        "SWINV2": {
            "PATCH_SIZE": 4,
            "IN_CHANS": 3,
            "EMBED_DIM": 96,
            "DEPTHS": [2, 2, 6, 2],
            "NUM_HEADS": [3, 6, 12, 24],
            "WINDOW_SIZE": 7,
            "MLP_RATIO": 4.0,
            "QKV_BIAS": True,
            "APE": False,
            "PATCH_NORM": True,
            "PRETRAINED_WINDOW_SIZES": [0, 0, 0, 0],
        },
    },
    "TRAIN": {
        "USE_CHECKPOINT": False,
    },
}


def _deep_copy_dict(value: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``value`` with nested dicts copied recursively.

    Lists are copied one level deep (``list(item)``); every other value is
    shared with the input.
    """
    copied: dict[str, Any] = {}
    for key, item in value.items():
        if isinstance(item, dict):
            copied[key] = _deep_copy_dict(item)
        elif isinstance(item, list):
            copied[key] = list(item)
        else:
            copied[key] = item
    return copied


def _merge_dict(dst: dict[str, Any], src: dict[str, Any]) -> dict[str, Any]:
    """Recursively merge ``src`` into ``dst`` in place and return ``dst``.

    A value is merged key by key only when both sides hold a dict for that
    key; otherwise the ``src`` value replaces the ``dst`` value.
    """
    for key, value in src.items():
        if isinstance(value, dict) and isinstance(dst.get(key), dict):
            _merge_dict(dst[key], value)
        else:
            dst[key] = value
    return dst


def _dict_to_namespace(value: Any) -> Any:
    """Convert nested dicts into ``SimpleNamespace`` objects.

    Lists are converted element by element; any other value is returned as is.
    """
    if isinstance(value, dict):
        return SimpleNamespace(
            **{key: _dict_to_namespace(item) for key, item in value.items()}
        )
    if isinstance(value, list):
        return [_dict_to_namespace(item) for item in value]
    return value


def _get_arg(args: Namespace | None, *names: str) -> Any:
    """Return the first non-``None`` attribute of ``args`` among ``names``.

    Returns ``None`` when ``args`` is ``None`` or no listed attribute is set.
    """
    if args is None:
        return None
    for name in names:
        if hasattr(args, name):
            value = getattr(args, name)
            if value is not None:
                return value
    return None


def _to_path(value: str | Path | None) -> Path | None:
    """Return ``value`` as an expanded, resolved ``Path``; pass ``None`` through."""
    if value is None:
        return None
    return Path(value).expanduser().resolve()


def _resolve_model_name(
    model_name: str | None,
    config_path: Path | None,
    weight_path: Path | None,
    args: Namespace | None,
) -> str | None:
    """Pick a model name from the first truthy candidate.

    Candidates, in order: ``args.model_name`` / ``args.model``, the explicit
    ``model_name``, the stem of ``weight_path``, the stem of ``config_path``.
    """
    return (
        _get_arg(args, *_MODEL_NAME_ARG_NAMES)
        or model_name
        or (weight_path.stem if weight_path is not None else None)
        or (config_path.stem if config_path is not None else None)
    )


def _resolve_config_path(
    model_name: str | None,
    config_path: str | Path | None,
    args: Namespace | None,
) -> Path | None:
    """Return the config path chosen by ``_resolve_config_with_source``."""
    # Delegate so the resolution rules live in exactly one place.
    return _resolve_config_with_source(model_name, config_path, args)[0]


def _resolve_weight_path(
    model_name: str | None,
    weight_path: str | Path | None,
    args: Namespace | None,
) -> Path | None:
    """Return the weight path chosen by ``_resolve_weight_with_source``."""
    # Delegate so the resolution rules live in exactly one place.
    return _resolve_weight_with_source(model_name, weight_path, args)[0]


def _resolve_config_with_source(
    model_name: str | None,
    config_path: str | Path | None,
    args: Namespace | None,
) -> tuple[Path | None, str]:
    """Resolve the YAML config path and describe where it came from.

    Order: a config attribute on ``args``, then the explicit ``config_path``,
    then ``<SWINV2_CONFIG_DIR>/<model_name>.yaml`` if that file exists.

    Returns:
        ``(path, source)``; ``path`` is ``None`` with source
        ``"defaults only"`` when nothing could be resolved.
    """
    cli_cfg = _to_path(_get_arg(args, *_CONFIG_ARG_NAMES))
    if cli_cfg is not None:
        return cli_cfg, "args.cfg"

    explicit_cfg = _to_path(config_path)
    if explicit_cfg is not None:
        return explicit_cfg, "function config_path"

    if model_name is None:
        return None, "defaults only"

    candidate = SWINV2_CONFIG_DIR / f"{model_name}.yaml"
    if candidate.exists():
        return candidate, "auto by MODEL.NAME"
    return None, "defaults only"


def _resolve_weight_with_source(
    model_name: str | None,
    weight_path: str | Path | None,
    args: Namespace | None,
) -> tuple[Path | None, str]:
    """Resolve the checkpoint path and describe where it came from.

    Order: a weight attribute on ``args``, then the explicit ``weight_path``,
    then ``<SWINV2_WEIGHT_DIR>/<model_name>.pth`` if that file exists.

    Returns:
        ``(path, source)``; ``path`` is ``None`` with source
        ``"not resolved"`` when nothing could be resolved.
    """
    cli_weight = _to_path(_get_arg(args, *_WEIGHT_ARG_NAMES))
    if cli_weight is not None:
        return cli_weight, "args.pretrained"

    explicit_weight = _to_path(weight_path)
    if explicit_weight is not None:
        return explicit_weight, "function weight_path"

    if model_name is None:
        return None, "not resolved"

    candidate = SWINV2_WEIGHT_DIR / f"{model_name}.pth"
    if candidate.exists():
        return candidate, "auto by MODEL.NAME"
    return None, "not resolved"


def _load_yaml_config(config_path: Path | None) -> dict[str, Any]:
    """Return a copy of ``DEFAULTS`` with the YAML file merged on top.

    Args:
        config_path: YAML file to merge, or ``None`` to use defaults only.

    Raises:
        FileNotFoundError: ``config_path`` is given but does not exist.
        TypeError: the YAML document's root is not a mapping.
    """
    config = _deep_copy_dict(DEFAULTS)
    if config_path is None:
        return config
    if not config_path.exists():
        raise FileNotFoundError(f"SwinV2 config not found: {config_path}")

    with config_path.open("r", encoding="utf-8") as handle:
        # An empty file loads as None; treat it as "no overrides".
        yaml_config = yaml.safe_load(handle) or {}

    # Without this check a list/scalar root fails inside _merge_dict with an
    # AttributeError that does not mention the offending file.
    if not isinstance(yaml_config, dict):
        raise TypeError(
            f"SwinV2 config must be a YAML mapping, got "
            f"{type(yaml_config).__name__}: {config_path}"
        )
    return _merge_dict(config, yaml_config)


def _set_nested(config: dict[str, Any], path: tuple[str, ...], value: Any) -> None:
    """Assign ``value`` at the nested key ``path``, creating missing dicts."""
    current = config
    for key in path[:-1]:
        current = current.setdefault(key, {})
    current[path[-1]] = value


def _collect_function_overrides(
    model_name: str | None,
    weight_path: Path | None,
    num_classes: int | None,
    img_size: int | None,
    in_chans: int | None,
    use_checkpoint: bool | None,
    model_kwargs: dict[str, Any],
) -> list[tuple[tuple[str, ...], Any]]:
    """Translate explicit function inputs into ``(config path, value)`` pairs.

    Inputs that are ``None`` produce no override. Keys of ``model_kwargs``
    that are not in the mapping below are ignored.
    """
    overrides: list[tuple[tuple[str, ...], Any]] = []
    if model_name is not None:
        overrides.append((("MODEL", "NAME"), model_name))
    if weight_path is not None:
        overrides.append((("MODEL", "PRETRAINED"), str(weight_path)))
    if num_classes is not None:
        overrides.append((("MODEL", "NUM_CLASSES"), num_classes))
    if img_size is not None:
        overrides.append((("DATA", "IMG_SIZE"), img_size))
    if in_chans is not None:
        overrides.append((("MODEL", "SWINV2", "IN_CHANS"), in_chans))
    if use_checkpoint is not None:
        overrides.append((("TRAIN", "USE_CHECKPOINT"), use_checkpoint))

    model_key_map = {
        "patch_size": ("MODEL", "SWINV2", "PATCH_SIZE"),
        "embed_dim": ("MODEL", "SWINV2", "EMBED_DIM"),
        "depths": ("MODEL", "SWINV2", "DEPTHS"),
        "num_heads": ("MODEL", "SWINV2", "NUM_HEADS"),
        "window_size": ("MODEL", "SWINV2", "WINDOW_SIZE"),
        "mlp_ratio": ("MODEL", "SWINV2", "MLP_RATIO"),
        "qkv_bias": ("MODEL", "SWINV2", "QKV_BIAS"),
        "ape": ("MODEL", "SWINV2", "APE"),
        "patch_norm": ("MODEL", "SWINV2", "PATCH_NORM"),
        "pretrained_window_sizes": ("MODEL", "SWINV2", "PRETRAINED_WINDOW_SIZES"),
        "drop_rate": ("MODEL", "DROP_RATE"),
        "drop_path_rate": ("MODEL", "DROP_PATH_RATE"),
    }
    for key, path in model_key_map.items():
        value = model_kwargs.get(key)
        if value is not None:
            overrides.append((path, value))
    return overrides


def _collect_arg_overrides(args: Namespace | None) -> list[tuple[tuple[str, ...], Any]]:
    """Translate attributes of ``args`` into ``(config path, value)`` pairs.

    Each entry lists accepted attribute aliases; the first alias holding a
    non-``None`` value wins. Returns an empty list when ``args`` is ``None``.
    """
    if args is None:
        return []

    key_map = {
        _MODEL_NAME_ARG_NAMES: ("MODEL", "NAME"),
        _WEIGHT_ARG_NAMES: ("MODEL", "PRETRAINED"),
        ("num_classes",): ("MODEL", "NUM_CLASSES"),
        ("img_size", "image_size", "input_size"): ("DATA", "IMG_SIZE"),
        ("in_chans", "in_channels"): ("MODEL", "SWINV2", "IN_CHANS"),
        ("patch_size",): ("MODEL", "SWINV2", "PATCH_SIZE"),
        ("embed_dim",): ("MODEL", "SWINV2", "EMBED_DIM"),
        ("depths",): ("MODEL", "SWINV2", "DEPTHS"),
        ("num_heads",): ("MODEL", "SWINV2", "NUM_HEADS"),
        ("window_size",): ("MODEL", "SWINV2", "WINDOW_SIZE"),
        ("mlp_ratio",): ("MODEL", "SWINV2", "MLP_RATIO"),
        ("qkv_bias",): ("MODEL", "SWINV2", "QKV_BIAS"),
        ("ape",): ("MODEL", "SWINV2", "APE"),
        ("patch_norm",): ("MODEL", "SWINV2", "PATCH_NORM"),
        ("pretrained_window_sizes",): ("MODEL", "SWINV2", "PRETRAINED_WINDOW_SIZES"),
        ("drop_rate",): ("MODEL", "DROP_RATE"),
        ("drop_path_rate",): ("MODEL", "DROP_PATH_RATE"),
        ("use_checkpoint",): ("TRAIN", "USE_CHECKPOINT"),
    }

    overrides: list[tuple[tuple[str, ...], Any]] = []
    for names, path in key_map.items():
        value = _get_arg(args, *names)
        if value is not None:
            overrides.append((path, value))
    return overrides


def _apply_overrides(
    config: dict[str, Any],
    overrides: list[tuple[tuple[str, ...], Any]],
) -> dict[str, Any]:
    """Write each override into ``config`` in order and return ``config``."""
    for path, value in overrides:
        _set_nested(config, path, value)
    return config


def _extract_state_dict(checkpoint: Any) -> dict[str, torch.Tensor]:
    """Return the state dict held by a loaded checkpoint object.

    If the checkpoint is a dict with a dict under ``"model"`` or
    ``"state_dict"`` (checked in that order), that inner dict is returned;
    otherwise the checkpoint dict itself is returned.

    Raises:
        TypeError: the checkpoint is not a dict.
    """
    if isinstance(checkpoint, dict):
        for key in ("model", "state_dict"):
            if key in checkpoint and isinstance(checkpoint[key], dict):
                return checkpoint[key]
        return checkpoint
    raise TypeError(f"Unsupported checkpoint format: {type(checkpoint)!r}")


def _remap_head_if_needed(
    model: SwinTransformerV2,
    state_dict: dict[str, torch.Tensor],
) -> None:
    """Reconcile the checkpoint's classifier head with the model, in place.

    Does nothing when the checkpoint lacks ``head.weight``/``head.bias`` or
    when the class counts already match. A 21841-class head is reduced to
    1000 classes using the indices in ``MAP22KTO1K_PATH`` when the model has
    1000 classes and that file exists. In every other mismatch the head
    entries are removed from ``state_dict``.
    """
    if "head.bias" not in state_dict or "head.weight" not in state_dict:
        return

    ckpt_classes = state_dict["head.bias"].shape[0]
    # A head without a ``bias`` attribute is treated as having 0 classes.
    head_bias = getattr(model.head, "bias", None)
    model_classes = head_bias.shape[0] if head_bias is not None else 0
    if ckpt_classes == model_classes:
        return

    if ckpt_classes == 21841 and model_classes == 1000 and MAP22KTO1K_PATH.exists():
        with MAP22KTO1K_PATH.open("r", encoding="utf-8") as handle:
            indices = [int(line.strip()) for line in handle if line.strip()]
        state_dict["head.weight"] = state_dict["head.weight"][indices, :]
        state_dict["head.bias"] = state_dict["head.bias"][indices]
        return

    state_dict.pop("head.weight", None)
    state_dict.pop("head.bias", None)


def _sanitize_state_dict(
    model: SwinTransformerV2,
    state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Return the subset of ``state_dict`` that can be loaded into ``model``.

    Steps, in order:
      1. If any key starts with ``encoder.``, keep only those keys and strip
         that prefix.
      2. Drop keys containing ``relative_position_index``,
         ``relative_coords_table`` or ``attn_mask``; rename ``rpe_mlp`` to
         ``cpb_mlp`` in the remaining keys.
      3. Reconcile the classifier head (see ``_remap_head_if_needed``).
      4. Keep only keys present in the model with an identical shape.
    """
    if any(key.startswith("encoder.") for key in state_dict):
        state_dict = {
            key.replace("encoder.", "", 1): value
            for key, value in state_dict.items()
            if key.startswith("encoder.")
        }

    skipped_markers = ("relative_position_index", "relative_coords_table", "attn_mask")
    remapped: dict[str, torch.Tensor] = {}
    for key, value in state_dict.items():
        if any(marker in key for marker in skipped_markers):
            continue
        remapped[key.replace("rpe_mlp", "cpb_mlp")] = value

    _remap_head_if_needed(model, remapped)

    model_state = model.state_dict()
    return {
        key: value
        for key, value in remapped.items()
        if key in model_state and model_state[key].shape == value.shape
    }


def _load_checkpoint(weight_path: Path) -> Any:
    """Load a checkpoint file onto the CPU.

    Raises:
        RuntimeError: loading failed for any reason; the original exception
            is chained as the cause.
    """
    # NOTE(review): torch.load unpickles the file. Depending on the installed
    # PyTorch version's ``weights_only`` default this can execute arbitrary
    # code, so only load checkpoints from trusted sources.
    try:
        return torch.load(weight_path, map_location="cpu")
    except Exception as exc:
        raise RuntimeError(f"Failed to load SwinV2 checkpoint: {weight_path}") from exc


def build_swinv2(
    model_name: str | None = None,
    config_path: str | Path | None = None,
    weight_path: str | Path | None = None,
    args: Namespace | None = None,
    *,
    num_classes: int | None = None,
    img_size: int | None = None,
    in_chans: int | None = None,
    use_checkpoint: bool | None = None,
    strict: bool = False,
    load_weights: bool = True,
    return_config: bool = False,
    **model_kwargs,
):
    """Build a SwinTransformerV2 with loaded weights.

    Precedence order:
    1. internal defaults
    2. YAML config under ``configs/swinv2``
    3. explicit function inputs outside config files
    4. command line style ``args`` overrides

    Args:
        model_name: Model name used to auto-locate config and weight files.
        config_path: Explicit YAML config file.
        weight_path: Explicit checkpoint file.
        args: Optional namespace whose attributes override everything else.
        num_classes, img_size, in_chans, use_checkpoint: Optional overrides
            for the matching config entries.
        strict: Forwarded to ``model.load_state_dict``.
        load_weights: When ``False`` no checkpoint is required or loaded.
        return_config: When ``True`` return ``(model, config)`` where
            ``config`` is the final configuration as nested namespaces.
        **model_kwargs: Further architecture overrides (for example
            ``window_size``); unrecognised keys are ignored.

    Returns:
        The model, or ``(model, config)`` when ``return_config`` is set.

    Raises:
        FileNotFoundError: the config file is missing, or ``load_weights`` is
            set and no existing weight file could be resolved.
        RuntimeError: the checkpoint file could not be loaded.
    """
    explicit_weight_path = _to_path(weight_path)
    initial_model_name = _resolve_model_name(
        model_name, _to_path(config_path), explicit_weight_path, args
    )
    resolved_config_path = _resolve_config_path(initial_model_name, config_path, args)
    config_dict = _load_yaml_config(resolved_config_path)

    # Second pass: the YAML's MODEL.NAME is a fallback when nothing else
    # supplied a name.
    resolved_model_name = _resolve_model_name(
        initial_model_name or config_dict["MODEL"]["NAME"],
        resolved_config_path,
        explicit_weight_path,
        args,
    )
    resolved_weight_path = _resolve_weight_path(resolved_model_name, weight_path, args)

    function_overrides = _collect_function_overrides(
        model_name=resolved_model_name,
        weight_path=resolved_weight_path,
        num_classes=num_classes,
        img_size=img_size,
        in_chans=in_chans,
        use_checkpoint=use_checkpoint,
        model_kwargs=model_kwargs,
    )
    arg_overrides = _collect_arg_overrides(args)
    # Applied in this order so that ``args`` values win over function inputs.
    config_dict = _apply_overrides(config_dict, function_overrides)
    config_dict = _apply_overrides(config_dict, arg_overrides)

    model_cfg = config_dict["MODEL"]
    swinv2_cfg = model_cfg["SWINV2"]

    # Fail fast: validate the weight file before paying for model construction.
    if load_weights:
        if resolved_weight_path is None:
            raise FileNotFoundError(
                f"No SwinV2 weight file resolved for model '{model_cfg['NAME']}'. "
                f"Expected one under {SWINV2_WEIGHT_DIR}."
            )
        if not resolved_weight_path.exists():
            raise FileNotFoundError(f"SwinV2 weight file not found: {resolved_weight_path}")

    model = SwinTransformerV2(
        img_size=config_dict["DATA"]["IMG_SIZE"],
        patch_size=swinv2_cfg["PATCH_SIZE"],
        in_chans=swinv2_cfg["IN_CHANS"],
        num_classes=model_cfg["NUM_CLASSES"],
        embed_dim=swinv2_cfg["EMBED_DIM"],
        depths=tuple(swinv2_cfg["DEPTHS"]),
        num_heads=tuple(swinv2_cfg["NUM_HEADS"]),
        window_size=swinv2_cfg["WINDOW_SIZE"],
        mlp_ratio=swinv2_cfg["MLP_RATIO"],
        qkv_bias=swinv2_cfg["QKV_BIAS"],
        drop_rate=model_cfg["DROP_RATE"],
        drop_path_rate=model_cfg["DROP_PATH_RATE"],
        ape=swinv2_cfg["APE"],
        patch_norm=swinv2_cfg["PATCH_NORM"],
        use_checkpoint=config_dict["TRAIN"]["USE_CHECKPOINT"],
        pretrained_window_sizes=tuple(swinv2_cfg["PRETRAINED_WINDOW_SIZES"]),
    )

    if load_weights:
        checkpoint = _load_checkpoint(resolved_weight_path)
        state_dict = _sanitize_state_dict(model, _extract_state_dict(checkpoint))
        model.load_state_dict(state_dict, strict=strict)

    if return_config:
        return model, _dict_to_namespace(config_dict)
    return model


def build_swinv2_auto(
    model_name: str | None = None,
    config_path: str | Path | None = None,
    weight_path: str | Path | None = None,
    args: Namespace | None = None,
    *,
    verbose: bool = True,
    return_config: bool = False,
    return_resolution: bool = False,
    **kwargs,
):
    """Auto-resolve SwinV2 config and weights by ``MODEL.NAME`` and print sources.

    This wrapper keeps the same precedence rules as ``build_swinv2`` while
    making the config/weight resolution explicit for callers.

    Args:
        model_name, config_path, weight_path, args: As for ``build_swinv2``.
        verbose: Print a one-line summary of the resolved name and files.
        return_config: Include the final config namespace in the result.
        return_resolution: Include a dict with ``model_name``,
            ``config_path``, ``config_source``, ``weight_path`` and
            ``weight_source`` in the result.
        **kwargs: Forwarded to ``build_swinv2``.

    Returns:
        ``model`` followed by ``config`` and/or ``resolution`` when requested,
        in that order.
    """
    explicit_weight_path = _to_path(weight_path)
    candidate_model_name = _resolve_model_name(
        model_name, _to_path(config_path), explicit_weight_path, args
    )
    resolved_config_path, config_source = _resolve_config_with_source(
        candidate_model_name, config_path, args
    )
    temp_config = _load_yaml_config(resolved_config_path)

    final_model_name = _resolve_model_name(
        candidate_model_name or temp_config["MODEL"]["NAME"],
        resolved_config_path,
        explicit_weight_path,
        args,
    ) or temp_config["MODEL"]["NAME"]
    resolved_weight_path, weight_source = _resolve_weight_with_source(
        final_model_name, weight_path, args
    )

    built = build_swinv2(
        model_name=final_model_name,
        config_path=resolved_config_path,
        weight_path=resolved_weight_path,
        args=args,
        return_config=True,
        **kwargs,
    )
    if not isinstance(built, tuple) or len(built) != 2:
        raise RuntimeError("build_swinv2(return_config=True) must return (model, config)")
    model, config = built

    resolution = {
        "model_name": config.MODEL.NAME,
        "config_path": str(resolved_config_path) if resolved_config_path is not None else None,
        "config_source": config_source,
        "weight_path": str(resolved_weight_path) if resolved_weight_path is not None else None,
        "weight_source": weight_source,
    }

    if verbose:
        print(
            "[build_swinv2_auto] "
            f"MODEL.NAME={resolution['model_name']} | "
            f"config={resolution['config_path']} ({resolution['config_source']}) | "
            f"weight={resolution['weight_path']} ({resolution['weight_source']})"
        )

    if return_config and return_resolution:
        return model, config, resolution
    if return_config:
        return model, config
    if return_resolution:
        return model, resolution
    return model


__all__ = ["build_swinv2", "build_swinv2_auto"]