| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261 |
- from __future__ import annotations
- import argparse
- import gc
- import sys
- import time
- from pathlib import Path
- from typing import Any
- import torch
# Parent of the directory that contains this file (presumably the repository
# root -- TODO confirm against the project layout).
ROOT_DIR = Path(__file__).resolve().parents[1]
# Put ROOT_DIR at the front of sys.path, once, so the `lib.*` imports that
# follow can be resolved when this file is executed directly as a script.
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))
- from lib.modules import XNet2d
- from lib.tools import build_loss, build_optimizer
- from lib.utils.config import apply_dotlist_overrides, load_yaml_config
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the memory probe.

    The options are declared as a table of ``(flag, add_argument kwargs)``
    pairs and registered in table order.

    Returns:
        The namespace parsed from ``sys.argv`` with the attributes ``config``,
        ``batch_sizes``, ``image_size``, ``amp``, ``device``, ``warmup`` and
        ``set``.
    """
    cli = argparse.ArgumentParser(
        description="Probe XNet2d CUDA memory with synthetic segmentation batches."
    )

    # Built on every call so the mutable list default is never shared.
    option_table = (
        (
            "--config",
            {
                "default": "configs/segmentation/train_sup_us_template.yaml",
                "help": "YAML config path.",
            },
        ),
        (
            "--batch-sizes",
            {
                "nargs": "+",
                "type": int,
                "default": [4, 6, 8],
                "help": "Batch sizes to probe.",
            },
        ),
        (
            "--image-size",
            {
                "nargs": 2,
                "type": int,
                "default": None,
                "metavar": ("H", "W"),
                "help": "Override dataset.image_size.",
            },
        ),
        (
            "--amp",
            {
                # Provides both --amp and --no-amp; None means "use the config".
                "action": argparse.BooleanOptionalAction,
                "default": None,
                "help": "Override train.amp.",
            },
        ),
        (
            "--device",
            {
                "default": "cuda",
                "help": "Device to probe. CUDA is required for memory numbers.",
            },
        ),
        (
            "--warmup",
            {
                "action": "store_true",
                "help": "Run one unmeasured warmup step before measuring each batch size.",
            },
        ),
        (
            "--set",
            {
                "nargs": "*",
                "default": None,
                "help": "Override config values with key=value pairs.",
            },
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()
def build_model(cfg: dict[str, Any], device: torch.device) -> XNet2d:
    """Construct an ``XNet2d`` from the config and move it to ``device``.

    Args:
        cfg: Config mapping; the ``"dataset"`` and ``"model"`` sections are
            read. ``dataset.num_classes`` is required, every other value
            falls back to the default listed below when absent.
        device: Device the constructed network is moved to.

    Returns:
        The network, already placed on ``device``.
    """
    data_opts = cfg["dataset"]
    net_opts = cfg["model"]
    option = net_opts.get

    # The model section wins for in_channels; the dataset section is only
    # the fallback, and 3 is used when neither defines it.
    hyperparams = {
        "in_channels": int(option("in_channels", data_opts.get("in_channels", 3))),
        "num_classes": int(data_opts["num_classes"]),
        "encoder_channels": tuple(option("encoder_channels", (32, 64, 128, 192))),
        "encoder_depths": tuple(option("encoder_depths", (2, 2, 2, 2))),
        "decoder_channels": tuple(option("decoder_channels", (128, 64, 32))),
        "stem_channels": int(option("stem_channels", 24)),
        "bottleneck_depth": int(option("bottleneck_depth", 1)),
        "global_ratio": float(option("global_ratio", 2.0)),
        "wavelet_type": str(option("wavelet_type", "haar")),
        "wavelet_level": int(option("wavelet_level", 1)),
        "use_wavelet_branch": bool(option("use_wavelet_branch", True)),
        "use_global_branch_stage1": bool(option("use_global_branch_stage1", False)),
        "ssm_d_state": int(option("ssm_d_state", 16)),
        "ssm_forward_type": str(option("ssm_forward_type", "v3")),
        "ssm_backend": str(option("ssm_backend", "auto")),
        "use_frequency_refine": bool(option("use_frequency_refine", True)),
        "low_freq_radius_h": float(option("low_freq_radius_h", 0.25)),
        "low_freq_radius_w": float(option("low_freq_radius_w", 0.25)),
        "learnable_low_freq_radius": bool(option("learnable_low_freq_radius", True)),
        "guide_mode": str(option("guide_mode", "affine")),
        # Passed through uncoerced; None when the config does not set it.
        "out_channels": option("out_channels"),
    }

    network = XNet2d(**hyperparams)
    return network.to(device)
def release_cuda() -> None:
    """Collect Python garbage and, if CUDA is usable, reset its memory state.

    Runs the garbage collector first, then empties the CUDA caching
    allocator and resets the peak-memory counters. Without CUDA only the
    garbage collection takes place.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
def make_batch(
    *,
    batch_size: int,
    in_channels: int,
    num_classes: int,
    image_size: tuple[int, int],
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Create one random image/mask pair directly on ``device``.

    Args:
        batch_size: Number of samples in the batch.
        in_channels: Channel count of the image tensor.
        num_classes: Number of segmentation classes; ``1`` selects a binary
            float mask, any other value an integer mask of class indices.
        image_size: ``(height, width)`` of the batch.
        device: Device on which both tensors are allocated.

    Returns:
        ``(image, mask)`` where ``image`` is standard-normal noise of shape
        ``(batch_size, in_channels, height, width)`` and ``mask`` has shape
        ``(batch_size, 1, height, width)``.
    """
    rows, cols = image_size
    image = torch.randn(batch_size, in_channels, rows, cols, device=device)

    binary = num_classes == 1
    # A single-class setup still needs the two labels {0, 1}.
    label_count = 2 if binary else num_classes
    mask = torch.randint(0, label_count, (batch_size, 1, rows, cols), device=device)
    if binary:
        mask = mask.float()
    return image, mask
def run_step(
    *,
    cfg: dict[str, Any],
    batch_size: int,
    image_size: tuple[int, int],
    device: torch.device,
    amp_enabled: bool,
) -> dict[str, float]:
    """Run one synthetic training step and report its cost.

    A fresh model, optimizer and loss are built from ``cfg``, one random
    batch is pushed through forward, backward and the optimizer step, and
    everything is released again before returning.

    Args:
        cfg: Config mapping with ``"dataset"``, ``"model"``, ``"optimizer"``
            and ``"loss"`` sections.
        batch_size: Number of samples in the synthetic batch.
        image_size: ``(height, width)`` of the synthetic batch.
        device: Device the step runs on.
        amp_enabled: Whether the forward pass and loss run under autocast.

    Returns:
        A dict with ``"loss"``, ``"seconds"`` (wall time of the step),
        ``"allocated_mb"`` and ``"reserved_mb"`` (peak CUDA memory in MiB).
        On a non-CUDA device the two memory entries are ``nan`` because no
        CUDA statistics exist.

    Raises:
        Whatever the model, loss or optimizer raise, including
        ``torch.cuda.OutOfMemoryError``; nothing is caught here.
    """
    release_cuda()
    # CUDA-only calls reject other device types, so they are guarded.
    on_cuda = device.type == "cuda"

    model = build_model(cfg, device)
    model.train()
    optimizer = build_optimizer(model, cfg["optimizer"])
    loss_fn = build_loss(cfg["loss"])

    dataset_cfg = cfg["dataset"]
    # Same precedence as build_model (model section first, dataset section
    # as fallback) so the batch always has the channel count the network
    # was built for, even when both sections define in_channels.
    in_channels = int(
        cfg["model"].get("in_channels", dataset_cfg.get("in_channels", 3))
    )
    num_classes = int(dataset_cfg["num_classes"])
    image, mask = make_batch(
        batch_size=batch_size,
        in_channels=in_channels,
        num_classes=num_classes,
        image_size=image_size,
        device=device,
    )

    if on_cuda:
        # Finish setup work and restart the peak counters so that only the
        # training step itself is measured.
        torch.cuda.synchronize(device)
        torch.cuda.reset_peak_memory_stats(device)
    start = time.perf_counter()

    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device.type, enabled=amp_enabled):
        outputs = model(image)
        loss = loss_fn(outputs["seg_logits"], mask)
    loss.backward()
    optimizer.step()

    if on_cuda:
        # Kernels launch asynchronously; wait before stopping the clock.
        torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start

    if on_cuda:
        allocated_mb = torch.cuda.max_memory_allocated(device) / (1024**2)
        reserved_mb = torch.cuda.max_memory_reserved(device) / (1024**2)
    else:
        allocated_mb = reserved_mb = float("nan")

    result = {
        "loss": float(loss.detach().cpu()),
        "seconds": elapsed,
        "allocated_mb": allocated_mb,
        "reserved_mb": reserved_mb,
    }

    # Drop every reference to device tensors so the cache can be emptied.
    del model, optimizer, loss_fn, image, mask, outputs, loss
    release_cuda()
    return result
def print_header(
    cfg: dict[str, Any],
    image_size: tuple[int, int],
    device: torch.device,
    amp_enabled: bool,
) -> None:
    """Print the run summary and the column titles of the result table.

    Args:
        cfg: Config mapping; only its ``"model"`` section is read.
        image_size: ``(height, width)`` used for the synthetic batches.
        device: Probed device; for CUDA devices the hardware name is shown.
        amp_enabled: Whether autocast is enabled for the probe.
    """
    model_opts = cfg["model"]

    print("XNet2d memory probe")
    if device.type == "cuda":
        device_label = torch.cuda.get_device_name(device)
    else:
        device_label = device
    print(f"device: {device_label}")
    print(f"image_size: {list(image_size)}")
    print(f"amp: {amp_enabled}")
    # Missing model options are shown as None.
    for key in ("encoder_channels", "encoder_depths", "global_ratio"):
        print(f"{key}: {model_opts.get(key)}")
    print()

    column_titles = (
        f"{'batch':>5} {'status':>8} {'allocated':>12} {'reserved':>12} "
        f"{'seconds':>8} {'loss/error'}"
    )
    print(column_titles)
    print("-" * 78)
def main() -> None:
    """Probe every requested batch size and print one table row per size.

    The config is loaded (relative paths are resolved against ``ROOT_DIR``),
    command-line overrides are applied, and ``run_step`` is executed for each
    batch size, optionally preceded by one unmeasured warmup step. A failing
    batch size is reported as an ``OOM`` or ``ERROR`` row and the probe
    carries on with the next size.

    Raises:
        RuntimeError: If a CUDA device is requested but CUDA is unavailable.
    """
    args = parse_args()
    device = torch.device(args.device)
    # Test the parsed device type, not the raw string, so indexed specs such
    # as "cuda:0" are rejected too when CUDA is missing.
    if device.type == "cuda" and not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available.")

    cfg_path = Path(args.config)
    if not cfg_path.is_absolute():
        cfg_path = ROOT_DIR / cfg_path
    cfg = apply_dotlist_overrides(load_yaml_config(cfg_path), args.set)

    image_size = tuple(args.image_size or cfg["dataset"]["image_size"])
    # An explicit --amp / --no-amp wins over the train.amp config value.
    amp_enabled = bool(
        cfg.get("train", {}).get("amp", False) if args.amp is None else args.amp
    )
    print_header(cfg, image_size, device, amp_enabled)

    step_kwargs = {
        "cfg": cfg,
        "image_size": image_size,
        "device": device,
        "amp_enabled": amp_enabled,
    }
    for batch_size in args.batch_sizes:
        result = None
        status = "ok"
        detail = ""
        try:
            if args.warmup:
                run_step(batch_size=batch_size, **step_kwargs)
            result = run_step(batch_size=batch_size, **step_kwargs)
        except torch.cuda.OutOfMemoryError as exc:
            status = "OOM"
            # Only the first line; an empty message must not crash the probe.
            detail = next(iter(str(exc).splitlines()), "CUDA out of memory")
        except Exception as exc:
            # Deliberate catch-all: one bad batch size must not stop the sweep.
            status = "ERROR"
            detail = f"{type(exc).__name__}: {exc}"

        if result is None:
            # Clean up only after the handler has finished: while the
            # exception is alive its traceback keeps the failed step's
            # tensors referenced, so the cache could not be emptied.
            release_cuda()
            print(
                f"{batch_size:>5} {status:>8} "
                f"{'-':>12} {'-':>12} {'-':>8} {detail}"
            )
            continue

        print(
            f"{batch_size:>5} {status:>8} "
            f"{result['allocated_mb']:>9.1f} MB "
            f"{result['reserved_mb']:>9.1f} MB "
            f"{result['seconds']:>8.2f} "
            f"loss={result['loss']:.6f}"
        )
# Run the probe only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|