|
|
@@ -1,504 +1,504 @@
|
|
|
-from __future__ import annotations
|
|
|
-
|
|
|
-from argparse import Namespace
|
|
|
-from pathlib import Path
|
|
|
-from types import SimpleNamespace
|
|
|
-from typing import Any
|
|
|
-
|
|
|
-import torch
|
|
|
-import yaml
|
|
|
-
|
|
|
-from lib.SwinTransformer.models.swin_transformer_v2 import SwinTransformerV2
|
|
|
-
|
|
|
# ``parents[2]`` of this file's resolved path, i.e. the directory two levels
# above the folder that contains this module. Every location below hangs off it.
# NOTE(review): presumably the repository root -- confirm against the layout.
ROOT_DIR = Path(__file__).resolve().parents[2]
# Folder probed for ``<model_name>.yaml`` when no config path is supplied.
SWINV2_CONFIG_DIR = ROOT_DIR / "configs" / "swinv2"
# Folder probed for ``<model_name>.pth`` when no weight path is supplied.
SWINV2_WEIGHT_DIR = ROOT_DIR / "weights" / "swinv2"
# Text file read as one integer index per line; used to slice a 21841-class
# checkpoint head down to 1000 classes.
MAP22KTO1K_PATH = ROOT_DIR / "lib" / "SwinTransformer" / "data" / "map22kto1k.txt"

# Baseline configuration. ``_load_yaml_config`` copies this mapping and merges
# the YAML file on top of it, so these values apply only where nothing else
# sets the key. The mapping itself is never mutated by the loaders.
DEFAULTS: dict[str, Any] = {
    "DATA": {
        "IMG_SIZE": 224,
    },
    "MODEL": {
        "TYPE": "swinv2",
        "NAME": "swinv2_tiny_patch4_window8_256",
        "NUM_CLASSES": 1000,
        "DROP_RATE": 0.0,
        "DROP_PATH_RATE": 0.1,
        "PRETRAINED": "",
        # Architecture arguments forwarded to the SwinTransformerV2 constructor.
        "SWINV2": {
            "PATCH_SIZE": 4,
            "IN_CHANS": 3,
            "EMBED_DIM": 96,
            "DEPTHS": [2, 2, 6, 2],
            "NUM_HEADS": [3, 6, 12, 24],
            "WINDOW_SIZE": 7,
            "MLP_RATIO": 4.0,
            "QKV_BIAS": True,
            "APE": False,
            "PATCH_NORM": True,
            "PRETRAINED_WINDOW_SIZES": [0, 0, 0, 0],
        },
    },
    "TRAIN": {
        "USE_CHECKPOINT": False,
    },
}
|
|
|
-
|
|
|
-
|
|
|
-def _deep_copy_dict(value: dict[str, Any]) -> dict[str, Any]:
|
|
|
- copied: dict[str, Any] = {}
|
|
|
- for key, item in value.items():
|
|
|
- if isinstance(item, dict):
|
|
|
- copied[key] = _deep_copy_dict(item)
|
|
|
- elif isinstance(item, list):
|
|
|
- copied[key] = list(item)
|
|
|
- else:
|
|
|
- copied[key] = item
|
|
|
- return copied
|
|
|
-
|
|
|
-
|
|
|
-def _merge_dict(dst: dict[str, Any], src: dict[str, Any]) -> dict[str, Any]:
|
|
|
- for key, value in src.items():
|
|
|
- if isinstance(value, dict) and isinstance(dst.get(key), dict):
|
|
|
- _merge_dict(dst[key], value)
|
|
|
- else:
|
|
|
- dst[key] = value
|
|
|
- return dst
|
|
|
-
|
|
|
-
|
|
|
-def _dict_to_namespace(value: Any) -> Any:
|
|
|
- if isinstance(value, dict):
|
|
|
- return SimpleNamespace(**{key: _dict_to_namespace(item) for key, item in value.items()})
|
|
|
- if isinstance(value, list):
|
|
|
- return [_dict_to_namespace(item) for item in value]
|
|
|
- return value
|
|
|
-
|
|
|
-
|
|
|
-def _get_arg(args: Namespace | None, *names: str) -> Any:
|
|
|
- if args is None:
|
|
|
- return None
|
|
|
- for name in names:
|
|
|
- if hasattr(args, name):
|
|
|
- value = getattr(args, name)
|
|
|
- if value is not None:
|
|
|
- return value
|
|
|
- return None
|
|
|
-
|
|
|
-
|
|
|
-def _to_path(value: str | Path | None) -> Path | None:
|
|
|
- if value is None:
|
|
|
- return None
|
|
|
- return Path(value).expanduser().resolve()
|
|
|
-
|
|
|
-
|
|
|
def _resolve_model_name(
    model_name: str | None,
    config_path: Path | None,
    weight_path: Path | None,
    args: Namespace | None,
) -> str | None:
    """Pick the model name from the first source that yields a truthy value.

    Sources, in order: ``args.model_name`` / ``args.model``, the ``model_name``
    argument, the weight file stem, the config file stem.
    """
    cli_name = _get_arg(args, "model_name", "model")
    if cli_name:
        return cli_name
    if model_name:
        return model_name
    if weight_path is not None and weight_path.stem:
        return weight_path.stem
    # Last link of the chain is returned as-is, even when it is empty or None.
    return None if config_path is None else config_path.stem
|
|
|
-
|
|
|
-
|
|
|
def _resolve_config_path(model_name: str | None, config_path: str | Path | None, args: Namespace | None) -> Path | None:
    """Locate the YAML config file to load, or ``None`` if there is none.

    An ``args`` attribute (``cfg``, ``config`` or ``config_path``) wins over
    the ``config_path`` argument. If neither is given, ``<model_name>.yaml``
    under ``SWINV2_CONFIG_DIR`` is used when that file exists.
    """
    for explicit in (_get_arg(args, "cfg", "config", "config_path"), config_path):
        located = _to_path(explicit)
        if located is not None:
            return located

    if model_name is not None:
        bundled = SWINV2_CONFIG_DIR / f"{model_name}.yaml"
        if bundled.exists():
            return bundled
    return None
|
|
|
-
|
|
|
-
|
|
|
def _resolve_weight_path(model_name: str | None, weight_path: str | Path | None, args: Namespace | None) -> Path | None:
    """Locate the checkpoint file to load, or ``None`` if there is none.

    An ``args`` attribute (``pretrained``, ``weights``, ``weight_path``,
    ``ckpt`` or ``checkpoint``) wins over the ``weight_path`` argument. If
    neither is given, ``<model_name>.pth`` under ``SWINV2_WEIGHT_DIR`` is used
    when that file exists.
    """
    cli_value = _get_arg(args, "pretrained", "weights", "weight_path", "ckpt", "checkpoint")
    for explicit in (cli_value, weight_path):
        located = _to_path(explicit)
        if located is not None:
            return located

    if model_name is not None:
        bundled = SWINV2_WEIGHT_DIR / f"{model_name}.pth"
        if bundled.exists():
            return bundled
    return None
|
|
|
-
|
|
|
-
|
|
|
def _resolve_config_with_source(
    model_name: str | None,
    config_path: str | Path | None,
    args: Namespace | None,
) -> tuple[Path | None, str]:
    """Locate the YAML config and report where the choice came from.

    Returns:
        ``(path, source)`` where ``source`` is one of ``"args.cfg"``,
        ``"function config_path"``, ``"auto by MODEL.NAME"`` or
        ``"defaults only"`` (the latter paired with ``None``).
    """
    explicit_inputs = (
        (_get_arg(args, "cfg", "config", "config_path"), "args.cfg"),
        (config_path, "function config_path"),
    )
    for raw, origin in explicit_inputs:
        located = _to_path(raw)
        if located is not None:
            return located, origin

    if model_name is not None:
        bundled = SWINV2_CONFIG_DIR / f"{model_name}.yaml"
        if bundled.exists():
            return bundled, "auto by MODEL.NAME"
    return None, "defaults only"
|
|
|
-
|
|
|
-
|
|
|
def _resolve_weight_with_source(
    model_name: str | None,
    weight_path: str | Path | None,
    args: Namespace | None,
) -> tuple[Path | None, str]:
    """Locate the checkpoint file and report where the choice came from.

    Returns:
        ``(path, source)`` where ``source`` is one of ``"args.pretrained"``,
        ``"function weight_path"``, ``"auto by MODEL.NAME"`` or
        ``"not resolved"`` (the latter paired with ``None``).
    """
    explicit_inputs = (
        (
            _get_arg(args, "pretrained", "weights", "weight_path", "ckpt", "checkpoint"),
            "args.pretrained",
        ),
        (weight_path, "function weight_path"),
    )
    for raw, origin in explicit_inputs:
        located = _to_path(raw)
        if located is not None:
            return located, origin

    if model_name is not None:
        bundled = SWINV2_WEIGHT_DIR / f"{model_name}.pth"
        if bundled.exists():
            return bundled, "auto by MODEL.NAME"
    return None, "not resolved"
|
|
|
-
|
|
|
-
|
|
|
def _load_yaml_config(config_path: Path | None) -> dict[str, Any]:
    """Build a config dict from ``DEFAULTS`` overlaid with an optional YAML file.

    Args:
        config_path: YAML file to merge on top of the defaults, or ``None`` to
            use the defaults alone.

    Returns:
        A fresh dict; ``DEFAULTS`` itself is never modified.

    Raises:
        FileNotFoundError: ``config_path`` is given but does not exist.
        TypeError: The YAML document's top level is not a mapping.
    """
    config = _deep_copy_dict(DEFAULTS)
    if config_path is None:
        return config
    if not config_path.exists():
        raise FileNotFoundError(f"SwinV2 config not found: {config_path}")

    with config_path.open("r", encoding="utf-8") as handle:
        # An empty file parses to None; treat it as "no overrides".
        yaml_config = yaml.safe_load(handle) or {}
    if not isinstance(yaml_config, dict):
        # Without this check the merge below fails with an opaque
        # AttributeError on ``.items()`` that does not name the file.
        raise TypeError(
            "SwinV2 config must be a YAML mapping at the top level, got "
            f"{type(yaml_config).__name__}: {config_path}"
        )
    return _merge_dict(config, yaml_config)
|
|
|
-
|
|
|
-
|
|
|
-def _set_nested(config: dict[str, Any], path: tuple[str, ...], value: Any):
|
|
|
- current = config
|
|
|
- for key in path[:-1]:
|
|
|
- current = current.setdefault(key, {})
|
|
|
- current[path[-1]] = value
|
|
|
-
|
|
|
-
|
|
|
-def _collect_function_overrides(
|
|
|
- model_name: str | None,
|
|
|
- weight_path: Path | None,
|
|
|
- num_classes: int | None,
|
|
|
- img_size: int | None,
|
|
|
- in_chans: int | None,
|
|
|
- use_checkpoint: bool | None,
|
|
|
- model_kwargs: dict[str, Any],
|
|
|
-) -> list[tuple[tuple[str, ...], Any]]:
|
|
|
- overrides: list[tuple[tuple[str, ...], Any]] = []
|
|
|
- if model_name is not None:
|
|
|
- overrides.append((("MODEL", "NAME"), model_name))
|
|
|
- if weight_path is not None:
|
|
|
- overrides.append((("MODEL", "PRETRAINED"), str(weight_path)))
|
|
|
- if num_classes is not None:
|
|
|
- overrides.append((("MODEL", "NUM_CLASSES"), num_classes))
|
|
|
- if img_size is not None:
|
|
|
- overrides.append((("DATA", "IMG_SIZE"), img_size))
|
|
|
- if in_chans is not None:
|
|
|
- overrides.append((("MODEL", "SWINV2", "IN_CHANS"), in_chans))
|
|
|
- if use_checkpoint is not None:
|
|
|
- overrides.append((("TRAIN", "USE_CHECKPOINT"), use_checkpoint))
|
|
|
-
|
|
|
- model_key_map = {
|
|
|
- "patch_size": ("MODEL", "SWINV2", "PATCH_SIZE"),
|
|
|
- "embed_dim": ("MODEL", "SWINV2", "EMBED_DIM"),
|
|
|
- "depths": ("MODEL", "SWINV2", "DEPTHS"),
|
|
|
- "num_heads": ("MODEL", "SWINV2", "NUM_HEADS"),
|
|
|
- "window_size": ("MODEL", "SWINV2", "WINDOW_SIZE"),
|
|
|
- "mlp_ratio": ("MODEL", "SWINV2", "MLP_RATIO"),
|
|
|
- "qkv_bias": ("MODEL", "SWINV2", "QKV_BIAS"),
|
|
|
- "ape": ("MODEL", "SWINV2", "APE"),
|
|
|
- "patch_norm": ("MODEL", "SWINV2", "PATCH_NORM"),
|
|
|
- "pretrained_window_sizes": ("MODEL", "SWINV2", "PRETRAINED_WINDOW_SIZES"),
|
|
|
- "drop_rate": ("MODEL", "DROP_RATE"),
|
|
|
- "drop_path_rate": ("MODEL", "DROP_PATH_RATE"),
|
|
|
- }
|
|
|
- for key, path in model_key_map.items():
|
|
|
- if key in model_kwargs and model_kwargs[key] is not None:
|
|
|
- overrides.append((path, model_kwargs[key]))
|
|
|
- return overrides
|
|
|
-
|
|
|
-
|
|
|
def _collect_arg_overrides(args: Namespace | None) -> list[tuple[tuple[str, ...], Any]]:
    """Translate attributes found on ``args`` into config overrides.

    Each override is a ``(key path, value)`` pair. Several attribute names may
    feed the same config key; the first one holding a non-``None`` value is
    used. Returns an empty list when ``args`` is ``None``.
    """
    if args is None:
        return []

    # (accepted attribute names, config key path), in application order.
    alias_table = (
        (("model_name", "model"), ("MODEL", "NAME")),
        (("pretrained", "weights", "weight_path", "ckpt", "checkpoint"), ("MODEL", "PRETRAINED")),
        (("num_classes",), ("MODEL", "NUM_CLASSES")),
        (("img_size", "image_size", "input_size"), ("DATA", "IMG_SIZE")),
        (("in_chans", "in_channels"), ("MODEL", "SWINV2", "IN_CHANS")),
        (("patch_size",), ("MODEL", "SWINV2", "PATCH_SIZE")),
        (("embed_dim",), ("MODEL", "SWINV2", "EMBED_DIM")),
        (("depths",), ("MODEL", "SWINV2", "DEPTHS")),
        (("num_heads",), ("MODEL", "SWINV2", "NUM_HEADS")),
        (("window_size",), ("MODEL", "SWINV2", "WINDOW_SIZE")),
        (("mlp_ratio",), ("MODEL", "SWINV2", "MLP_RATIO")),
        (("qkv_bias",), ("MODEL", "SWINV2", "QKV_BIAS")),
        (("ape",), ("MODEL", "SWINV2", "APE")),
        (("patch_norm",), ("MODEL", "SWINV2", "PATCH_NORM")),
        (("pretrained_window_sizes",), ("MODEL", "SWINV2", "PRETRAINED_WINDOW_SIZES")),
        (("drop_rate",), ("MODEL", "DROP_RATE")),
        (("drop_path_rate",), ("MODEL", "DROP_PATH_RATE")),
        (("use_checkpoint",), ("TRAIN", "USE_CHECKPOINT")),
    )
    looked_up = ((path, _get_arg(args, *aliases)) for aliases, path in alias_table)
    return [(path, found) for path, found in looked_up if found is not None]
|
|
|
-
|
|
|
-
|
|
|
def _apply_overrides(config: dict[str, Any], overrides: list[tuple[tuple[str, ...], Any]]) -> dict[str, Any]:
    """Write every ``(key path, value)`` override into ``config`` and return it.

    ``config`` is modified in place; overrides are applied in list order, so a
    later entry for the same path wins.
    """
    for override in overrides:
        key_path, new_value = override
        _set_nested(config, key_path, new_value)
    return config
|
|
|
-
|
|
|
-
|
|
|
-def _extract_state_dict(checkpoint: Any) -> dict[str, torch.Tensor]:
|
|
|
- if isinstance(checkpoint, dict):
|
|
|
- for key in ("model", "state_dict"):
|
|
|
- if key in checkpoint and isinstance(checkpoint[key], dict):
|
|
|
- return checkpoint[key]
|
|
|
- return checkpoint
|
|
|
- raise TypeError(f"Unsupported checkpoint format: {type(checkpoint)!r}")
|
|
|
-
|
|
|
-
|
|
|
-def _remap_head_if_needed(model: SwinTransformerV2, state_dict: dict[str, torch.Tensor]):
|
|
|
- if "head.bias" not in state_dict or "head.weight" not in state_dict:
|
|
|
- return
|
|
|
-
|
|
|
- ckpt_classes = state_dict["head.bias"].shape[0]
|
|
|
- head_bias = getattr(model.head, "bias", None)
|
|
|
- model_classes = head_bias.shape[0] if head_bias is not None else 0
|
|
|
- if ckpt_classes == model_classes:
|
|
|
- return
|
|
|
-
|
|
|
- if ckpt_classes == 21841 and model_classes == 1000 and MAP22KTO1K_PATH.exists():
|
|
|
- with MAP22KTO1K_PATH.open("r", encoding="utf-8") as handle:
|
|
|
- indices = [int(line.strip()) for line in handle if line.strip()]
|
|
|
- state_dict["head.weight"] = state_dict["head.weight"][indices, :]
|
|
|
- state_dict["head.bias"] = state_dict["head.bias"][indices]
|
|
|
- return
|
|
|
-
|
|
|
- state_dict.pop("head.weight", None)
|
|
|
- state_dict.pop("head.bias", None)
|
|
|
-
|
|
|
-
|
|
|
def _sanitize_state_dict(model: SwinTransformerV2, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Reduce a checkpoint mapping to the entries ``model`` can load.

    Steps, in order:

    1. If any key starts with ``encoder.``, keep only those keys and strip
       the prefix.
    2. Drop keys containing ``relative_position_index``,
       ``relative_coords_table`` or ``attn_mask``; rename ``rpe_mlp`` to
       ``cpb_mlp`` in the remaining keys.
    3. Reconcile the classifier head via ``_remap_head_if_needed``.
    4. Keep only entries whose key exists in ``model.state_dict()`` with an
       identical shape.

    The input mapping is not modified; a new dict is returned.
    """
    prefix = "encoder."
    unwrapped = {
        name[len(prefix):]: tensor
        for name, tensor in state_dict.items()
        if name.startswith(prefix)
    }
    if unwrapped:
        state_dict = unwrapped

    skipped_markers = ("relative_position_index", "relative_coords_table", "attn_mask")
    remapped: dict[str, torch.Tensor] = {
        name.replace("rpe_mlp", "cpb_mlp"): tensor
        for name, tensor in state_dict.items()
        if not any(marker in name for marker in skipped_markers)
    }

    _remap_head_if_needed(model, remapped)

    expected = model.state_dict()
    return {
        name: tensor
        for name, tensor in remapped.items()
        if name in expected and expected[name].shape == tensor.shape
    }
|
|
|
-
|
|
|
-
|
|
|
-def _load_checkpoint(weight_path: Path) -> dict[str, Any]:
|
|
|
- try:
|
|
|
- return torch.load(weight_path, map_location="cpu")
|
|
|
- except Exception as exc:
|
|
|
- raise RuntimeError(f"Failed to load SwinV2 checkpoint: {weight_path}") from exc
|
|
|
-
|
|
|
-
|
|
|
def build_swinv2(
    model_name: str | None = None,
    config_path: str | Path | None = None,
    weight_path: str | Path | None = None,
    args: Namespace | None = None,
    *,
    num_classes: int | None = None,
    img_size: int | None = None,
    in_chans: int | None = None,
    use_checkpoint: bool | None = None,
    strict: bool = False,
    load_weights: bool = True,
    return_config: bool = False,
    **model_kwargs,
):
    """Construct a ``SwinTransformerV2`` and, by default, load its weights.

    Settings are layered, later layers winning:

    1. internal defaults
    2. YAML config under ``configs/swinv2``
    3. explicit function inputs outside config files
    4. command line style ``args`` overrides

    Args:
        model_name: Model name; also used to find ``<name>.yaml`` and
            ``<name>.pth`` when no explicit paths are given.
        config_path: YAML config file to load.
        weight_path: Checkpoint file to load.
        args: Namespace whose attributes override everything else.
        num_classes, img_size, in_chans, use_checkpoint: Direct overrides of
            the corresponding config entries; ``None`` leaves them untouched.
        strict: Forwarded to ``model.load_state_dict``.
        load_weights: When false, the model is returned without loading a
            checkpoint.
        return_config: When true, return ``(model, config)`` instead of the
            model alone; ``config`` is a nested ``SimpleNamespace``.
        **model_kwargs: Architecture overrides (``patch_size``, ``embed_dim``,
            ``depths``, ...); unrecognised keys are ignored.

    Raises:
        FileNotFoundError: ``load_weights`` is true and no checkpoint could be
            resolved, or the resolved file does not exist.
    """
    given_weight = _to_path(weight_path)
    name_hint = _resolve_model_name(model_name, _to_path(config_path), given_weight, args)
    cfg_file = _resolve_config_path(name_hint, config_path, args)
    settings = _load_yaml_config(cfg_file)

    # The name stored in the config is consulted only when nothing else named the model.
    final_name = _resolve_model_name(
        name_hint or settings["MODEL"]["NAME"],
        cfg_file,
        given_weight,
        args,
    )
    weight_file = _resolve_weight_path(final_name, weight_path, args)

    # Function inputs are applied first so that ``args`` values overwrite them.
    override_layers = (
        _collect_function_overrides(
            model_name=final_name,
            weight_path=weight_file,
            num_classes=num_classes,
            img_size=img_size,
            in_chans=in_chans,
            use_checkpoint=use_checkpoint,
            model_kwargs=model_kwargs,
        ),
        _collect_arg_overrides(args),
    )
    for layer in override_layers:
        settings = _apply_overrides(settings, layer)

    model_cfg = settings["MODEL"]
    swin_cfg = model_cfg["SWINV2"]
    ctor_kwargs = {
        "img_size": settings["DATA"]["IMG_SIZE"],
        "patch_size": swin_cfg["PATCH_SIZE"],
        "in_chans": swin_cfg["IN_CHANS"],
        "num_classes": model_cfg["NUM_CLASSES"],
        "embed_dim": swin_cfg["EMBED_DIM"],
        "depths": tuple(swin_cfg["DEPTHS"]),
        "num_heads": tuple(swin_cfg["NUM_HEADS"]),
        "window_size": swin_cfg["WINDOW_SIZE"],
        "mlp_ratio": swin_cfg["MLP_RATIO"],
        "qkv_bias": swin_cfg["QKV_BIAS"],
        "drop_rate": model_cfg["DROP_RATE"],
        "drop_path_rate": model_cfg["DROP_PATH_RATE"],
        "ape": swin_cfg["APE"],
        "patch_norm": swin_cfg["PATCH_NORM"],
        "use_checkpoint": settings["TRAIN"]["USE_CHECKPOINT"],
        "pretrained_window_sizes": tuple(swin_cfg["PRETRAINED_WINDOW_SIZES"]),
    }
    model = SwinTransformerV2(**ctor_kwargs)

    if load_weights:
        if weight_file is None:
            raise FileNotFoundError(
                f"No SwinV2 weight file resolved for model '{model_cfg['NAME']}'. "
                f"Expected one under {SWINV2_WEIGHT_DIR}."
            )
        if not weight_file.exists():
            raise FileNotFoundError(f"SwinV2 weight file not found: {weight_file}")

        raw_state = _extract_state_dict(_load_checkpoint(weight_file))
        model.load_state_dict(_sanitize_state_dict(model, raw_state), strict=strict)

    namespace_cfg = _dict_to_namespace(settings)
    return (model, namespace_cfg) if return_config else model
|
|
|
-
|
|
|
-
|
|
|
def build_swinv2_auto(
    model_name: str | None = None,
    config_path: str | Path | None = None,
    weight_path: str | Path | None = None,
    args: Namespace | None = None,
    *,
    verbose: bool = True,
    return_config: bool = False,
    return_resolution: bool = False,
    **kwargs,
):
    """Build a SwinV2 model and report which config and weight files were used.

    The config and weight paths are resolved here first, so that their origin
    can be recorded, and are then handed to ``build_swinv2`` together with
    ``args`` and ``**kwargs``.

    Args:
        model_name, config_path, weight_path, args: Same meaning as for
            ``build_swinv2``.
        verbose: Print a one-line summary of the resolved name, config and
            weights.
        return_config: Also return the namespace config.
        return_resolution: Also return a dict with the keys ``model_name``,
            ``config_path``, ``config_source``, ``weight_path`` and
            ``weight_source``.
        **kwargs: Forwarded to ``build_swinv2``.

    Returns:
        ``model`` alone, or a tuple starting with ``model`` followed by the
        config and/or the resolution dict (in that order) when requested.

    Raises:
        RuntimeError: ``build_swinv2`` did not return a ``(model, config)`` pair.
    """
    given_weight = _to_path(weight_path)
    name_hint = _resolve_model_name(model_name, _to_path(config_path), given_weight, args)
    cfg_file, cfg_origin = _resolve_config_with_source(name_hint, config_path, args)

    # Loaded only to learn the name stored in the config; build_swinv2 reads
    # the file again on its own.
    preview = _load_yaml_config(cfg_file)
    final_name = (
        _resolve_model_name(
            name_hint or preview["MODEL"]["NAME"],
            cfg_file,
            given_weight,
            args,
        )
        or preview["MODEL"]["NAME"]
    )
    weight_file, weight_origin = _resolve_weight_with_source(final_name, weight_path, args)

    built = build_swinv2(
        model_name=final_name,
        config_path=cfg_file,
        weight_path=weight_file,
        args=args,
        return_config=True,
        **kwargs,
    )
    if not (isinstance(built, tuple) and len(built) == 2):
        raise RuntimeError("build_swinv2(return_config=True) must return (model, config)")
    model, config = built

    resolution = {
        "model_name": config.MODEL.NAME,
        "config_path": None if cfg_file is None else str(cfg_file),
        "config_source": cfg_origin,
        "weight_path": None if weight_file is None else str(weight_file),
        "weight_source": weight_origin,
    }

    if verbose:
        print(
            "[build_swinv2_auto] "
            f"MODEL.NAME={resolution['model_name']} | "
            f"config={resolution['config_path']} ({resolution['config_source']}) | "
            f"weight={resolution['weight_path']} ({resolution['weight_source']})"
        )

    requested = ((return_config, config), (return_resolution, resolution))
    extras = [item for wanted, item in requested if wanted]
    return (model, *extras) if extras else model
|
|
|
-
|
|
|
-
|
|
|
# Public API of this module: the two builder entry points.
__all__ = ["build_swinv2", "build_swinv2_auto"]
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+from argparse import Namespace
|
|
|
+from pathlib import Path
|
|
|
+from types import SimpleNamespace
|
|
|
+from typing import Any
|
|
|
+
|
|
|
+import torch
|
|
|
+import yaml
|
|
|
+
|
|
|
+from lib.SwinTransformer.models.swin_transformer_v2 import SwinTransformerV2
|
|
|
+
|
|
|
# ``parents[2]`` of this file's resolved path, i.e. the directory two levels
# above the folder that contains this module. Every location below hangs off it.
# NOTE(review): presumably the repository root -- confirm against the layout.
ROOT_DIR = Path(__file__).resolve().parents[2]
# Folder probed for ``<model_name>.yaml`` when no config path is supplied.
SWINV2_CONFIG_DIR = ROOT_DIR / "configs" / "swinv2"
# Folder probed for ``<model_name>.pth`` when no weight path is supplied.
SWINV2_WEIGHT_DIR = ROOT_DIR / "weights" / "swinv2"
# Text file read as one integer index per line; used to slice a 21841-class
# checkpoint head down to 1000 classes.
MAP22KTO1K_PATH = ROOT_DIR / "lib" / "SwinTransformer" / "data" / "map22kto1k.txt"

# Baseline configuration. ``_load_yaml_config`` copies this mapping and merges
# the YAML file on top of it, so these values apply only where nothing else
# sets the key. The mapping itself is never mutated by the loaders.
DEFAULTS: dict[str, Any] = {
    "DATA": {
        "IMG_SIZE": 224,
    },
    "MODEL": {
        "TYPE": "swinv2",
        "NAME": "swinv2_tiny_patch4_window8_256",
        "NUM_CLASSES": 1000,
        "DROP_RATE": 0.0,
        "DROP_PATH_RATE": 0.1,
        "PRETRAINED": "",
        # Architecture arguments forwarded to the SwinTransformerV2 constructor.
        "SWINV2": {
            "PATCH_SIZE": 4,
            "IN_CHANS": 3,
            "EMBED_DIM": 96,
            "DEPTHS": [2, 2, 6, 2],
            "NUM_HEADS": [3, 6, 12, 24],
            "WINDOW_SIZE": 7,
            "MLP_RATIO": 4.0,
            "QKV_BIAS": True,
            "APE": False,
            "PATCH_NORM": True,
            "PRETRAINED_WINDOW_SIZES": [0, 0, 0, 0],
        },
    },
    "TRAIN": {
        "USE_CHECKPOINT": False,
    },
}
|
|
|
+
|
|
|
+
|
|
|
+def _deep_copy_dict(value: dict[str, Any]) -> dict[str, Any]:
|
|
|
+ copied: dict[str, Any] = {}
|
|
|
+ for key, item in value.items():
|
|
|
+ if isinstance(item, dict):
|
|
|
+ copied[key] = _deep_copy_dict(item)
|
|
|
+ elif isinstance(item, list):
|
|
|
+ copied[key] = list(item)
|
|
|
+ else:
|
|
|
+ copied[key] = item
|
|
|
+ return copied
|
|
|
+
|
|
|
+
|
|
|
+def _merge_dict(dst: dict[str, Any], src: dict[str, Any]) -> dict[str, Any]:
|
|
|
+ for key, value in src.items():
|
|
|
+ if isinstance(value, dict) and isinstance(dst.get(key), dict):
|
|
|
+ _merge_dict(dst[key], value)
|
|
|
+ else:
|
|
|
+ dst[key] = value
|
|
|
+ return dst
|
|
|
+
|
|
|
+
|
|
|
+def _dict_to_namespace(value: Any) -> Any:
|
|
|
+ if isinstance(value, dict):
|
|
|
+ return SimpleNamespace(**{key: _dict_to_namespace(item) for key, item in value.items()})
|
|
|
+ if isinstance(value, list):
|
|
|
+ return [_dict_to_namespace(item) for item in value]
|
|
|
+ return value
|
|
|
+
|
|
|
+
|
|
|
+def _get_arg(args: Namespace | None, *names: str) -> Any:
|
|
|
+ if args is None:
|
|
|
+ return None
|
|
|
+ for name in names:
|
|
|
+ if hasattr(args, name):
|
|
|
+ value = getattr(args, name)
|
|
|
+ if value is not None:
|
|
|
+ return value
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
+def _to_path(value: str | Path | None) -> Path | None:
|
|
|
+ if value is None:
|
|
|
+ return None
|
|
|
+ return Path(value).expanduser().resolve()
|
|
|
+
|
|
|
+
|
|
|
def _resolve_model_name(
    model_name: str | None,
    config_path: Path | None,
    weight_path: Path | None,
    args: Namespace | None,
) -> str | None:
    """Pick the model name from the first source that yields a truthy value.

    Sources, in order: ``args.model_name`` / ``args.model``, the ``model_name``
    argument, the weight file stem, the config file stem.
    """
    cli_name = _get_arg(args, "model_name", "model")
    if cli_name:
        return cli_name
    if model_name:
        return model_name
    if weight_path is not None and weight_path.stem:
        return weight_path.stem
    # Last link of the chain is returned as-is, even when it is empty or None.
    return None if config_path is None else config_path.stem
|
|
|
+
|
|
|
+
|
|
|
def _resolve_config_path(model_name: str | None, config_path: str | Path | None, args: Namespace | None) -> Path | None:
    """Locate the YAML config file to load, or ``None`` if there is none.

    An ``args`` attribute (``cfg``, ``config`` or ``config_path``) wins over
    the ``config_path`` argument. If neither is given, ``<model_name>.yaml``
    under ``SWINV2_CONFIG_DIR`` is used when that file exists.
    """
    for explicit in (_get_arg(args, "cfg", "config", "config_path"), config_path):
        located = _to_path(explicit)
        if located is not None:
            return located

    if model_name is not None:
        bundled = SWINV2_CONFIG_DIR / f"{model_name}.yaml"
        if bundled.exists():
            return bundled
    return None
|
|
|
+
|
|
|
+
|
|
|
def _resolve_weight_path(model_name: str | None, weight_path: str | Path | None, args: Namespace | None) -> Path | None:
    """Locate the checkpoint file to load, or ``None`` if there is none.

    An ``args`` attribute (``pretrained``, ``weights``, ``weight_path``,
    ``ckpt`` or ``checkpoint``) wins over the ``weight_path`` argument. If
    neither is given, ``<model_name>.pth`` under ``SWINV2_WEIGHT_DIR`` is used
    when that file exists.
    """
    cli_value = _get_arg(args, "pretrained", "weights", "weight_path", "ckpt", "checkpoint")
    for explicit in (cli_value, weight_path):
        located = _to_path(explicit)
        if located is not None:
            return located

    if model_name is not None:
        bundled = SWINV2_WEIGHT_DIR / f"{model_name}.pth"
        if bundled.exists():
            return bundled
    return None
|
|
|
+
|
|
|
+
|
|
|
def _resolve_config_with_source(
    model_name: str | None,
    config_path: str | Path | None,
    args: Namespace | None,
) -> tuple[Path | None, str]:
    """Locate the YAML config and report where the choice came from.

    Returns:
        ``(path, source)`` where ``source`` is one of ``"args.cfg"``,
        ``"function config_path"``, ``"auto by MODEL.NAME"`` or
        ``"defaults only"`` (the latter paired with ``None``).
    """
    explicit_inputs = (
        (_get_arg(args, "cfg", "config", "config_path"), "args.cfg"),
        (config_path, "function config_path"),
    )
    for raw, origin in explicit_inputs:
        located = _to_path(raw)
        if located is not None:
            return located, origin

    if model_name is not None:
        bundled = SWINV2_CONFIG_DIR / f"{model_name}.yaml"
        if bundled.exists():
            return bundled, "auto by MODEL.NAME"
    return None, "defaults only"
|
|
|
+
|
|
|
+
|
|
|
def _resolve_weight_with_source(
    model_name: str | None,
    weight_path: str | Path | None,
    args: Namespace | None,
) -> tuple[Path | None, str]:
    """Locate the checkpoint file and report where the choice came from.

    Returns:
        ``(path, source)`` where ``source`` is one of ``"args.pretrained"``,
        ``"function weight_path"``, ``"auto by MODEL.NAME"`` or
        ``"not resolved"`` (the latter paired with ``None``).
    """
    explicit_inputs = (
        (
            _get_arg(args, "pretrained", "weights", "weight_path", "ckpt", "checkpoint"),
            "args.pretrained",
        ),
        (weight_path, "function weight_path"),
    )
    for raw, origin in explicit_inputs:
        located = _to_path(raw)
        if located is not None:
            return located, origin

    if model_name is not None:
        bundled = SWINV2_WEIGHT_DIR / f"{model_name}.pth"
        if bundled.exists():
            return bundled, "auto by MODEL.NAME"
    return None, "not resolved"
|
|
|
+
|
|
|
+
|
|
|
def _load_yaml_config(config_path: Path | None) -> dict[str, Any]:
    """Build a config dict from ``DEFAULTS`` overlaid with an optional YAML file.

    Args:
        config_path: YAML file to merge on top of the defaults, or ``None`` to
            use the defaults alone.

    Returns:
        A fresh dict; ``DEFAULTS`` itself is never modified.

    Raises:
        FileNotFoundError: ``config_path`` is given but does not exist.
        TypeError: The YAML document's top level is not a mapping.
    """
    config = _deep_copy_dict(DEFAULTS)
    if config_path is None:
        return config
    if not config_path.exists():
        raise FileNotFoundError(f"SwinV2 config not found: {config_path}")

    with config_path.open("r", encoding="utf-8") as handle:
        # An empty file parses to None; treat it as "no overrides".
        yaml_config = yaml.safe_load(handle) or {}
    if not isinstance(yaml_config, dict):
        # Without this check the merge below fails with an opaque
        # AttributeError on ``.items()`` that does not name the file.
        raise TypeError(
            "SwinV2 config must be a YAML mapping at the top level, got "
            f"{type(yaml_config).__name__}: {config_path}"
        )
    return _merge_dict(config, yaml_config)
|
|
|
+
|
|
|
+
|
|
|
+def _set_nested(config: dict[str, Any], path: tuple[str, ...], value: Any):
|
|
|
+ current = config
|
|
|
+ for key in path[:-1]:
|
|
|
+ current = current.setdefault(key, {})
|
|
|
+ current[path[-1]] = value
|
|
|
+
|
|
|
+
|
|
|
+def _collect_function_overrides(
|
|
|
+ model_name: str | None,
|
|
|
+ weight_path: Path | None,
|
|
|
+ num_classes: int | None,
|
|
|
+ img_size: int | None,
|
|
|
+ in_chans: int | None,
|
|
|
+ use_checkpoint: bool | None,
|
|
|
+ model_kwargs: dict[str, Any],
|
|
|
+) -> list[tuple[tuple[str, ...], Any]]:
|
|
|
+ overrides: list[tuple[tuple[str, ...], Any]] = []
|
|
|
+ if model_name is not None:
|
|
|
+ overrides.append((("MODEL", "NAME"), model_name))
|
|
|
+ if weight_path is not None:
|
|
|
+ overrides.append((("MODEL", "PRETRAINED"), str(weight_path)))
|
|
|
+ if num_classes is not None:
|
|
|
+ overrides.append((("MODEL", "NUM_CLASSES"), num_classes))
|
|
|
+ if img_size is not None:
|
|
|
+ overrides.append((("DATA", "IMG_SIZE"), img_size))
|
|
|
+ if in_chans is not None:
|
|
|
+ overrides.append((("MODEL", "SWINV2", "IN_CHANS"), in_chans))
|
|
|
+ if use_checkpoint is not None:
|
|
|
+ overrides.append((("TRAIN", "USE_CHECKPOINT"), use_checkpoint))
|
|
|
+
|
|
|
+ model_key_map = {
|
|
|
+ "patch_size": ("MODEL", "SWINV2", "PATCH_SIZE"),
|
|
|
+ "embed_dim": ("MODEL", "SWINV2", "EMBED_DIM"),
|
|
|
+ "depths": ("MODEL", "SWINV2", "DEPTHS"),
|
|
|
+ "num_heads": ("MODEL", "SWINV2", "NUM_HEADS"),
|
|
|
+ "window_size": ("MODEL", "SWINV2", "WINDOW_SIZE"),
|
|
|
+ "mlp_ratio": ("MODEL", "SWINV2", "MLP_RATIO"),
|
|
|
+ "qkv_bias": ("MODEL", "SWINV2", "QKV_BIAS"),
|
|
|
+ "ape": ("MODEL", "SWINV2", "APE"),
|
|
|
+ "patch_norm": ("MODEL", "SWINV2", "PATCH_NORM"),
|
|
|
+ "pretrained_window_sizes": ("MODEL", "SWINV2", "PRETRAINED_WINDOW_SIZES"),
|
|
|
+ "drop_rate": ("MODEL", "DROP_RATE"),
|
|
|
+ "drop_path_rate": ("MODEL", "DROP_PATH_RATE"),
|
|
|
+ }
|
|
|
+ for key, path in model_key_map.items():
|
|
|
+ if key in model_kwargs and model_kwargs[key] is not None:
|
|
|
+ overrides.append((path, model_kwargs[key]))
|
|
|
+ return overrides
|
|
|
+
|
|
|
+
|
|
|
def _collect_arg_overrides(args: Namespace | None) -> list[tuple[tuple[str, ...], Any]]:
    """Translate attributes found on ``args`` into config overrides.

    Each override is a ``(key path, value)`` pair. Several attribute names may
    feed the same config key; the first one holding a non-``None`` value is
    used. Returns an empty list when ``args`` is ``None``.
    """
    if args is None:
        return []

    # (accepted attribute names, config key path), in application order.
    alias_table = (
        (("model_name", "model"), ("MODEL", "NAME")),
        (("pretrained", "weights", "weight_path", "ckpt", "checkpoint"), ("MODEL", "PRETRAINED")),
        (("num_classes",), ("MODEL", "NUM_CLASSES")),
        (("img_size", "image_size", "input_size"), ("DATA", "IMG_SIZE")),
        (("in_chans", "in_channels"), ("MODEL", "SWINV2", "IN_CHANS")),
        (("patch_size",), ("MODEL", "SWINV2", "PATCH_SIZE")),
        (("embed_dim",), ("MODEL", "SWINV2", "EMBED_DIM")),
        (("depths",), ("MODEL", "SWINV2", "DEPTHS")),
        (("num_heads",), ("MODEL", "SWINV2", "NUM_HEADS")),
        (("window_size",), ("MODEL", "SWINV2", "WINDOW_SIZE")),
        (("mlp_ratio",), ("MODEL", "SWINV2", "MLP_RATIO")),
        (("qkv_bias",), ("MODEL", "SWINV2", "QKV_BIAS")),
        (("ape",), ("MODEL", "SWINV2", "APE")),
        (("patch_norm",), ("MODEL", "SWINV2", "PATCH_NORM")),
        (("pretrained_window_sizes",), ("MODEL", "SWINV2", "PRETRAINED_WINDOW_SIZES")),
        (("drop_rate",), ("MODEL", "DROP_RATE")),
        (("drop_path_rate",), ("MODEL", "DROP_PATH_RATE")),
        (("use_checkpoint",), ("TRAIN", "USE_CHECKPOINT")),
    )
    looked_up = ((path, _get_arg(args, *aliases)) for aliases, path in alias_table)
    return [(path, found) for path, found in looked_up if found is not None]
|
|
|
+
|
|
|
+
|
|
|
def _apply_overrides(config: dict[str, Any], overrides: list[tuple[tuple[str, ...], Any]]) -> dict[str, Any]:
    """Write every ``(key path, value)`` override into ``config`` and return it.

    ``config`` is modified in place; overrides are applied in list order, so a
    later entry for the same path wins.
    """
    for override in overrides:
        key_path, new_value = override
        _set_nested(config, key_path, new_value)
    return config
|
|
|
+
|
|
|
+
|
|
|
+def _extract_state_dict(checkpoint: Any) -> dict[str, torch.Tensor]:
|
|
|
+ if isinstance(checkpoint, dict):
|
|
|
+ for key in ("model", "state_dict"):
|
|
|
+ if key in checkpoint and isinstance(checkpoint[key], dict):
|
|
|
+ return checkpoint[key]
|
|
|
+ return checkpoint
|
|
|
+ raise TypeError(f"Unsupported checkpoint format: {type(checkpoint)!r}")
|
|
|
+
|
|
|
+
|
|
|
+def _remap_head_if_needed(model: SwinTransformerV2, state_dict: dict[str, torch.Tensor]):
|
|
|
+ if "head.bias" not in state_dict or "head.weight" not in state_dict:
|
|
|
+ return
|
|
|
+
|
|
|
+ ckpt_classes = state_dict["head.bias"].shape[0]
|
|
|
+ head_bias = getattr(model.head, "bias", None)
|
|
|
+ model_classes = head_bias.shape[0] if head_bias is not None else 0
|
|
|
+ if ckpt_classes == model_classes:
|
|
|
+ return
|
|
|
+
|
|
|
+ if ckpt_classes == 21841 and model_classes == 1000 and MAP22KTO1K_PATH.exists():
|
|
|
+ with MAP22KTO1K_PATH.open("r", encoding="utf-8") as handle:
|
|
|
+ indices = [int(line.strip()) for line in handle if line.strip()]
|
|
|
+ state_dict["head.weight"] = state_dict["head.weight"][indices, :]
|
|
|
+ state_dict["head.bias"] = state_dict["head.bias"][indices]
|
|
|
+ return
|
|
|
+
|
|
|
+ state_dict.pop("head.weight", None)
|
|
|
+ state_dict.pop("head.bias", None)
|
|
|
+
|
|
|
+
|
|
|
def _sanitize_state_dict(model: SwinTransformerV2, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Reduce a checkpoint mapping to the entries ``model`` can load.

    Steps, in order:

    1. If any key starts with ``encoder.``, keep only those keys and strip
       the prefix.
    2. Drop keys containing ``relative_position_index``,
       ``relative_coords_table`` or ``attn_mask``; rename ``rpe_mlp`` to
       ``cpb_mlp`` in the remaining keys.
    3. Reconcile the classifier head via ``_remap_head_if_needed``.
    4. Keep only entries whose key exists in ``model.state_dict()`` with an
       identical shape.

    The input mapping is not modified; a new dict is returned.
    """
    prefix = "encoder."
    unwrapped = {
        name[len(prefix):]: tensor
        for name, tensor in state_dict.items()
        if name.startswith(prefix)
    }
    if unwrapped:
        state_dict = unwrapped

    skipped_markers = ("relative_position_index", "relative_coords_table", "attn_mask")
    remapped: dict[str, torch.Tensor] = {
        name.replace("rpe_mlp", "cpb_mlp"): tensor
        for name, tensor in state_dict.items()
        if not any(marker in name for marker in skipped_markers)
    }

    _remap_head_if_needed(model, remapped)

    expected = model.state_dict()
    return {
        name: tensor
        for name, tensor in remapped.items()
        if name in expected and expected[name].shape == tensor.shape
    }
|
|
|
+
|
|
|
+
|
|
|
+def _load_checkpoint(weight_path: Path) -> dict[str, Any]:
|
|
|
+ try:
|
|
|
+ return torch.load(weight_path, map_location="cpu")
|
|
|
+ except Exception as exc:
|
|
|
+ raise RuntimeError(f"Failed to load SwinV2 checkpoint: {weight_path}") from exc
|
|
|
+
|
|
|
+
|
|
|
def build_swinv2(
    model_name: str | None = None,
    config_path: str | Path | None = None,
    weight_path: str | Path | None = None,
    args: Namespace | None = None,
    *,
    num_classes: int | None = None,
    img_size: int | None = None,
    in_chans: int | None = None,
    use_checkpoint: bool | None = None,
    strict: bool = False,
    load_weights: bool = True,
    return_config: bool = False,
    **model_kwargs,
):
    """Construct a ``SwinTransformerV2`` and, unless disabled, load its weights.

    Settings are layered, with later layers taking precedence:
        1. internal defaults
        2. YAML config under ``configs/swinv2``
        3. explicit function inputs outside config files
        4. command line style ``args`` overrides

    Args:
        model_name: Optional model name used to resolve config and weights.
        config_path: Optional explicit config file.
        weight_path: Optional explicit weight file.
        args: Optional namespace of command line style overrides.
        num_classes, img_size, in_chans, use_checkpoint: Optional overrides
            forwarded to ``_collect_function_overrides``.
        strict: Passed to ``model.load_state_dict``.
        load_weights: When false, the model is returned without loading a
            checkpoint.
        return_config: When true, return ``(model, config)`` instead of just
            the model.
        **model_kwargs: Extra overrides forwarded to
            ``_collect_function_overrides``.

    Returns:
        The model, or ``(model, config)`` when ``return_config`` is true.

    Raises:
        FileNotFoundError: If ``load_weights`` is true and no weight file was
            resolved, or the resolved file does not exist.
    """

    user_weight_path = _to_path(weight_path)
    name_guess = _resolve_model_name(model_name, _to_path(config_path), user_weight_path, args)
    yaml_path = _resolve_config_path(name_guess, config_path, args)
    settings = _load_yaml_config(yaml_path)

    # Second resolution pass: seed with the config's own name when the first
    # pass produced nothing.
    name_seed = name_guess or settings["MODEL"]["NAME"]
    final_name = _resolve_model_name(name_seed, yaml_path, user_weight_path, args)
    weights_file = _resolve_weight_path(final_name, weight_path, args)

    from_call = _collect_function_overrides(
        model_name=final_name,
        weight_path=weights_file,
        num_classes=num_classes,
        img_size=img_size,
        in_chans=in_chans,
        use_checkpoint=use_checkpoint,
        model_kwargs=model_kwargs,
    )
    from_args = _collect_arg_overrides(args)
    # Function inputs are applied first so that ``args`` overrides win.
    for override_layer in (from_call, from_args):
        settings = _apply_overrides(settings, override_layer)

    model_section = settings["MODEL"]
    swin_section = model_section["SWINV2"]
    constructor_kwargs = {
        "img_size": settings["DATA"]["IMG_SIZE"],
        "patch_size": swin_section["PATCH_SIZE"],
        "in_chans": swin_section["IN_CHANS"],
        "num_classes": model_section["NUM_CLASSES"],
        "embed_dim": swin_section["EMBED_DIM"],
        "depths": tuple(swin_section["DEPTHS"]),
        "num_heads": tuple(swin_section["NUM_HEADS"]),
        "window_size": swin_section["WINDOW_SIZE"],
        "mlp_ratio": swin_section["MLP_RATIO"],
        "qkv_bias": swin_section["QKV_BIAS"],
        "drop_rate": model_section["DROP_RATE"],
        "drop_path_rate": model_section["DROP_PATH_RATE"],
        "ape": swin_section["APE"],
        "patch_norm": swin_section["PATCH_NORM"],
        "use_checkpoint": settings["TRAIN"]["USE_CHECKPOINT"],
        "pretrained_window_sizes": tuple(swin_section["PRETRAINED_WINDOW_SIZES"]),
    }
    model = SwinTransformerV2(**constructor_kwargs)

    if load_weights:
        if weights_file is None:
            raise FileNotFoundError(
                f"No SwinV2 weight file resolved for model '{model_section['NAME']}'. "
                f"Expected one under {SWINV2_WEIGHT_DIR}."
            )
        if not weights_file.exists():
            raise FileNotFoundError(f"SwinV2 weight file not found: {weights_file}")

        raw_state = _extract_state_dict(_load_checkpoint(weights_file))
        model.load_state_dict(_sanitize_state_dict(model, raw_state), strict=strict)

    namespace_config = _dict_to_namespace(settings)
    return (model, namespace_config) if return_config else model
|
|
|
+
|
|
|
+
|
|
|
def build_swinv2_auto(
    model_name: str | None = None,
    config_path: str | Path | None = None,
    weight_path: str | Path | None = None,
    args: Namespace | None = None,
    *,
    verbose: bool = True,
    return_config: bool = False,
    return_resolution: bool = False,
    **kwargs,
):
    """Build a SwinV2 model via ``build_swinv2`` and report what was resolved.

    The config and weight files are resolved together with a label describing
    where each came from, the model is built by ``build_swinv2`` (so the same
    precedence rules apply), and the resolution is optionally printed and/or
    returned.

    Args:
        model_name: Optional model name used to resolve config and weights.
        config_path: Optional explicit config file.
        weight_path: Optional explicit weight file.
        args: Optional namespace of command line style overrides.
        verbose: When true, print a one-line summary of the resolution.
        return_config: When true, include the config in the return value.
        return_resolution: When true, include the resolution dict (keys
            ``model_name``, ``config_path``, ``config_source``,
            ``weight_path``, ``weight_source``) in the return value.
        **kwargs: Forwarded to ``build_swinv2``.

    Returns:
        ``model`` alone, or a tuple starting with ``model`` followed by the
        config and/or the resolution dict, in that order, as requested.

    Raises:
        RuntimeError: If ``build_swinv2`` does not return a 2-tuple.
    """

    user_weight_path = _to_path(weight_path)
    name_guess = _resolve_model_name(model_name, _to_path(config_path), user_weight_path, args)
    yaml_path, yaml_origin = _resolve_config_with_source(name_guess, config_path, args)

    preview = _load_yaml_config(yaml_path)
    name_seed = name_guess or preview["MODEL"]["NAME"]
    final_name = _resolve_model_name(name_seed, yaml_path, user_weight_path, args)
    if not final_name:
        # Last resort: use the name recorded in the config file.
        final_name = preview["MODEL"]["NAME"]
    weights_file, weights_origin = _resolve_weight_with_source(final_name, weight_path, args)

    outcome = build_swinv2(
        model_name=final_name,
        config_path=yaml_path,
        weight_path=weights_file,
        args=args,
        return_config=True,
        **kwargs,
    )
    well_formed = isinstance(outcome, tuple) and len(outcome) == 2
    if not well_formed:
        raise RuntimeError("build_swinv2(return_config=True) must return (model, config)")
    model, config = outcome

    resolution = {
        "model_name": config.MODEL.NAME,
        "config_path": None if yaml_path is None else str(yaml_path),
        "config_source": yaml_origin,
        "weight_path": None if weights_file is None else str(weights_file),
        "weight_source": weights_origin,
    }

    if verbose:
        summary = (
            "[build_swinv2_auto] "
            f"MODEL.NAME={resolution['model_name']} | "
            f"config={resolution['config_path']} ({resolution['config_source']}) | "
            f"weight={resolution['weight_path']} ({resolution['weight_source']})"
        )
        print(summary)

    # Assemble the return value: the model first, then the requested extras.
    pieces = [model]
    if return_config:
        pieces.append(config)
    if return_resolution:
        pieces.append(resolution)
    return model if len(pieces) == 1 else tuple(pieces)
|
|
|
+
|
|
|
+
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    "build_swinv2",
    "build_swinv2_auto",
]
|