from __future__ import annotations import argparse import gc import sys import time from pathlib import Path from typing import Any import torch ROOT_DIR = Path(__file__).resolve().parents[1] if str(ROOT_DIR) not in sys.path: sys.path.insert(0, str(ROOT_DIR)) from lib.modules import XNet2d from lib.tools import build_loss, build_optimizer from lib.utils.config import apply_dotlist_overrides, load_yaml_config def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Probe XNet2d CUDA memory with synthetic segmentation batches." ) parser.add_argument( "--config", default="configs/segmentation/train_sup_us_template.yaml", help="YAML config path.", ) parser.add_argument( "--batch-sizes", nargs="+", type=int, default=[4, 6, 8], help="Batch sizes to probe.", ) parser.add_argument( "--image-size", nargs=2, type=int, default=None, metavar=("H", "W"), help="Override dataset.image_size.", ) parser.add_argument( "--amp", action=argparse.BooleanOptionalAction, default=None, help="Override train.amp.", ) parser.add_argument( "--device", default="cuda", help="Device to probe. CUDA is required for memory numbers.", ) parser.add_argument( "--warmup", action="store_true", help="Run one unmeasured warmup step before measuring each batch size.", ) parser.add_argument( "--set", nargs="*", default=None, help="Override config values with key=value pairs.", ) return parser.parse_args() def build_model(cfg: dict[str, Any], device: torch.device) -> XNet2d: dataset_cfg = cfg["dataset"] model_cfg = cfg["model"] return XNet2d( in_channels=int( model_cfg.get("in_channels", dataset_cfg.get("in_channels", 3)) ), num_classes=int(dataset_cfg["num_classes"]), encoder_channels=tuple(model_cfg.get("encoder_channels", (32, 64, 128, 192))), encoder_depths=tuple(model_cfg.get("encoder_depths", (2, 2, 2, 2))), decoder_channels=tuple(model_cfg.get("decoder_channels", (128, 64, 32))), stem_channels=int(model_cfg.get("stem_channels", 24)), bottleneck_depth=int(model_cfg.get("bottleneck_depth", 1)), global_ratio=float(model_cfg.get("global_ratio", 2.0)), wavelet_type=str(model_cfg.get("wavelet_type", "haar")), wavelet_level=int(model_cfg.get("wavelet_level", 1)), use_wavelet_branch=bool(model_cfg.get("use_wavelet_branch", True)), use_global_branch_stage1=bool(model_cfg.get("use_global_branch_stage1", False)), ssm_d_state=int(model_cfg.get("ssm_d_state", 16)), ssm_forward_type=str(model_cfg.get("ssm_forward_type", "v3")), ssm_backend=str(model_cfg.get("ssm_backend", "auto")), use_frequency_refine=bool(model_cfg.get("use_frequency_refine", True)), low_freq_radius_h=float(model_cfg.get("low_freq_radius_h", 0.25)), low_freq_radius_w=float(model_cfg.get("low_freq_radius_w", 0.25)), learnable_low_freq_radius=bool( model_cfg.get("learnable_low_freq_radius", True) ), guide_mode=str(model_cfg.get("guide_mode", "affine")), out_channels=model_cfg.get("out_channels"), ).to(device) def release_cuda() -> None: gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() def make_batch( *, batch_size: int, in_channels: int, num_classes: int, image_size: tuple[int, int], device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: height, width = image_size image = torch.randn(batch_size, in_channels, height, width, device=device) if num_classes == 1: mask = torch.randint( 0, 2, (batch_size, 1, height, width), device=device ).float() else: mask = torch.randint( 0, num_classes, (batch_size, 1, height, width), device=device ) return image, mask def run_step( *, cfg: dict[str, Any], batch_size: int, image_size: tuple[int, int], device: torch.device, amp_enabled: bool, ) -> dict[str, float]: release_cuda() model = build_model(cfg, device) model.train() optimizer = build_optimizer(model, cfg["optimizer"]) loss_fn = build_loss(cfg["loss"]) dataset_cfg = cfg["dataset"] in_channels = int( dataset_cfg.get("in_channels", cfg["model"].get("in_channels", 3)) ) num_classes = int(dataset_cfg["num_classes"]) image, mask = make_batch( batch_size=batch_size, in_channels=in_channels, num_classes=num_classes, image_size=image_size, device=device, ) torch.cuda.synchronize(device) torch.cuda.reset_peak_memory_stats(device) start = time.perf_counter() optimizer.zero_grad(set_to_none=True) with torch.autocast(device_type=device.type, enabled=amp_enabled): outputs = model(image) loss = loss_fn(outputs["seg_logits"], mask) loss.backward() optimizer.step() torch.cuda.synchronize(device) elapsed = time.perf_counter() - start result = { "loss": float(loss.detach().cpu()), "seconds": elapsed, "allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2), "reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2), } del model, optimizer, loss_fn, image, mask, outputs, loss release_cuda() return result def print_header( cfg: dict[str, Any], image_size: tuple[int, int], device: torch.device, amp_enabled: bool, ) -> None: model_cfg = cfg["model"] print("XNet2d memory probe") print( f"device: {torch.cuda.get_device_name(device) if device.type == 'cuda' else device}" ) print(f"image_size: {list(image_size)}") print(f"amp: {amp_enabled}") print(f"encoder_channels: {model_cfg.get('encoder_channels')}") print(f"encoder_depths: {model_cfg.get('encoder_depths')}") print(f"global_ratio: {model_cfg.get('global_ratio')}") print() print( f"{'batch':>5} {'status':>8} {'allocated':>12} {'reserved':>12} " f"{'seconds':>8} {'loss/error'}" ) print("-" * 78) def main() -> None: args = parse_args() if args.device == "cuda" and not torch.cuda.is_available(): raise RuntimeError("CUDA is not available.") cfg_path = ( ROOT_DIR / args.config if not Path(args.config).is_absolute() else Path(args.config) ) cfg = apply_dotlist_overrides(load_yaml_config(cfg_path), args.set) device = torch.device(args.device) image_size = tuple(args.image_size or cfg["dataset"]["image_size"]) amp_enabled = bool( cfg.get("train", {}).get("amp", False) if args.amp is None else args.amp ) print_header(cfg, image_size, device, amp_enabled) for batch_size in args.batch_sizes: try: if args.warmup: run_step( cfg=cfg, batch_size=batch_size, image_size=image_size, device=device, amp_enabled=amp_enabled, ) result = run_step( cfg=cfg, batch_size=batch_size, image_size=image_size, device=device, amp_enabled=amp_enabled, ) print( f"{batch_size:>5} {'ok':>8} " f"{result['allocated_mb']:>9.1f} MB " f"{result['reserved_mb']:>9.1f} MB " f"{result['seconds']:>8.2f} " f"loss={result['loss']:.6f}" ) except torch.cuda.OutOfMemoryError as exc: release_cuda() print( f"{batch_size:>5} {'OOM':>8} " f"{'-':>12} {'-':>12} {'-':>8} {str(exc).splitlines()[0]}" ) except Exception as exc: release_cuda() print( f"{batch_size:>5} {'ERROR':>8} " f"{'-':>12} {'-':>12} {'-':>8} {type(exc).__name__}: {exc}" ) if __name__ == "__main__": main()