|
|
@@ -0,0 +1,261 @@
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import argparse
|
|
|
+import gc
|
|
|
+import sys
|
|
|
+import time
|
|
|
+from pathlib import Path
|
|
|
+from typing import Any
|
|
|
+
|
|
|
+import torch
|
|
|
+
|
|
|
+ROOT_DIR = Path(__file__).resolve().parents[1]
|
|
|
+if str(ROOT_DIR) not in sys.path:
|
|
|
+ sys.path.insert(0, str(ROOT_DIR))
|
|
|
+
|
|
|
+from lib.modules import XNet2d
|
|
|
+from lib.tools import build_loss, build_optimizer
|
|
|
+from lib.utils.config import apply_dotlist_overrides, load_yaml_config
|
|
|
+
|
|
|
+
|
|
|
+def parse_args() -> argparse.Namespace:
|
|
|
+ parser = argparse.ArgumentParser(
|
|
|
+ description="Probe XNet2d CUDA memory with synthetic segmentation batches."
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "--config",
|
|
|
+ default="configs/segmentation/train_sup_us_template.yaml",
|
|
|
+ help="YAML config path.",
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "--batch-sizes",
|
|
|
+ nargs="+",
|
|
|
+ type=int,
|
|
|
+ default=[4, 6, 8],
|
|
|
+ help="Batch sizes to probe.",
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "--image-size",
|
|
|
+ nargs=2,
|
|
|
+ type=int,
|
|
|
+ default=None,
|
|
|
+ metavar=("H", "W"),
|
|
|
+ help="Override dataset.image_size.",
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "--amp",
|
|
|
+ action=argparse.BooleanOptionalAction,
|
|
|
+ default=None,
|
|
|
+ help="Override train.amp.",
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "--device",
|
|
|
+ default="cuda",
|
|
|
+ help="Device to probe. CUDA is required for memory numbers.",
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "--warmup",
|
|
|
+ action="store_true",
|
|
|
+ help="Run one unmeasured warmup step before measuring each batch size.",
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ "--set",
|
|
|
+ nargs="*",
|
|
|
+ default=None,
|
|
|
+ help="Override config values with key=value pairs.",
|
|
|
+ )
|
|
|
+ return parser.parse_args()
|
|
|
+
|
|
|
+
|
|
|
+def build_model(cfg: dict[str, Any], device: torch.device) -> XNet2d:
|
|
|
+ dataset_cfg = cfg["dataset"]
|
|
|
+ model_cfg = cfg["model"]
|
|
|
+ return XNet2d(
|
|
|
+ in_channels=int(
|
|
|
+ model_cfg.get("in_channels", dataset_cfg.get("in_channels", 3))
|
|
|
+ ),
|
|
|
+ num_classes=int(dataset_cfg["num_classes"]),
|
|
|
+ encoder_channels=tuple(model_cfg.get("encoder_channels", (32, 64, 128, 192))),
|
|
|
+ encoder_depths=tuple(model_cfg.get("encoder_depths", (2, 2, 2, 2))),
|
|
|
+ decoder_channels=tuple(model_cfg.get("decoder_channels", (128, 64, 32))),
|
|
|
+ stem_channels=int(model_cfg.get("stem_channels", 24)),
|
|
|
+ bottleneck_depth=int(model_cfg.get("bottleneck_depth", 1)),
|
|
|
+ global_ratio=float(model_cfg.get("global_ratio", 2.0)),
|
|
|
+ wavelet_type=str(model_cfg.get("wavelet_type", "haar")),
|
|
|
+ wavelet_level=int(model_cfg.get("wavelet_level", 1)),
|
|
|
+ use_wavelet_branch=bool(model_cfg.get("use_wavelet_branch", True)),
|
|
|
+ use_global_branch_stage1=bool(model_cfg.get("use_global_branch_stage1", False)),
|
|
|
+ ssm_d_state=int(model_cfg.get("ssm_d_state", 16)),
|
|
|
+ ssm_forward_type=str(model_cfg.get("ssm_forward_type", "v3")),
|
|
|
+ ssm_backend=str(model_cfg.get("ssm_backend", "auto")),
|
|
|
+ use_frequency_refine=bool(model_cfg.get("use_frequency_refine", True)),
|
|
|
+ low_freq_radius_h=float(model_cfg.get("low_freq_radius_h", 0.25)),
|
|
|
+ low_freq_radius_w=float(model_cfg.get("low_freq_radius_w", 0.25)),
|
|
|
+ learnable_low_freq_radius=bool(
|
|
|
+ model_cfg.get("learnable_low_freq_radius", True)
|
|
|
+ ),
|
|
|
+ guide_mode=str(model_cfg.get("guide_mode", "affine")),
|
|
|
+ out_channels=model_cfg.get("out_channels"),
|
|
|
+ ).to(device)
|
|
|
+
|
|
|
+
|
|
|
+def release_cuda() -> None:
|
|
|
+ gc.collect()
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ torch.cuda.empty_cache()
|
|
|
+ torch.cuda.reset_peak_memory_stats()
|
|
|
+
|
|
|
+
|
|
|
+def make_batch(
|
|
|
+ *,
|
|
|
+ batch_size: int,
|
|
|
+ in_channels: int,
|
|
|
+ num_classes: int,
|
|
|
+ image_size: tuple[int, int],
|
|
|
+ device: torch.device,
|
|
|
+) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
+ height, width = image_size
|
|
|
+ image = torch.randn(batch_size, in_channels, height, width, device=device)
|
|
|
+ if num_classes == 1:
|
|
|
+ mask = torch.randint(
|
|
|
+ 0, 2, (batch_size, 1, height, width), device=device
|
|
|
+ ).float()
|
|
|
+ else:
|
|
|
+ mask = torch.randint(
|
|
|
+ 0, num_classes, (batch_size, 1, height, width), device=device
|
|
|
+ )
|
|
|
+ return image, mask
|
|
|
+
|
|
|
+
|
|
|
+def run_step(
|
|
|
+ *,
|
|
|
+ cfg: dict[str, Any],
|
|
|
+ batch_size: int,
|
|
|
+ image_size: tuple[int, int],
|
|
|
+ device: torch.device,
|
|
|
+ amp_enabled: bool,
|
|
|
+) -> dict[str, float]:
|
|
|
+ release_cuda()
|
|
|
+ model = build_model(cfg, device)
|
|
|
+ model.train()
|
|
|
+ optimizer = build_optimizer(model, cfg["optimizer"])
|
|
|
+ loss_fn = build_loss(cfg["loss"])
|
|
|
+
|
|
|
+ dataset_cfg = cfg["dataset"]
|
|
|
+ in_channels = int(
|
|
|
+ dataset_cfg.get("in_channels", cfg["model"].get("in_channels", 3))
|
|
|
+ )
|
|
|
+ num_classes = int(dataset_cfg["num_classes"])
|
|
|
+ image, mask = make_batch(
|
|
|
+ batch_size=batch_size,
|
|
|
+ in_channels=in_channels,
|
|
|
+ num_classes=num_classes,
|
|
|
+ image_size=image_size,
|
|
|
+ device=device,
|
|
|
+ )
|
|
|
+
|
|
|
+ torch.cuda.synchronize(device)
|
|
|
+ torch.cuda.reset_peak_memory_stats(device)
|
|
|
+ start = time.perf_counter()
|
|
|
+ optimizer.zero_grad(set_to_none=True)
|
|
|
+ with torch.autocast(device_type=device.type, enabled=amp_enabled):
|
|
|
+ outputs = model(image)
|
|
|
+ loss = loss_fn(outputs["seg_logits"], mask)
|
|
|
+ loss.backward()
|
|
|
+ optimizer.step()
|
|
|
+ torch.cuda.synchronize(device)
|
|
|
+ elapsed = time.perf_counter() - start
|
|
|
+
|
|
|
+ result = {
|
|
|
+ "loss": float(loss.detach().cpu()),
|
|
|
+ "seconds": elapsed,
|
|
|
+ "allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2),
|
|
|
+ "reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2),
|
|
|
+ }
|
|
|
+ del model, optimizer, loss_fn, image, mask, outputs, loss
|
|
|
+ release_cuda()
|
|
|
+ return result
|
|
|
+
|
|
|
+
|
|
|
+def print_header(
|
|
|
+ cfg: dict[str, Any],
|
|
|
+ image_size: tuple[int, int],
|
|
|
+ device: torch.device,
|
|
|
+ amp_enabled: bool,
|
|
|
+) -> None:
|
|
|
+ model_cfg = cfg["model"]
|
|
|
+ print("XNet2d memory probe")
|
|
|
+ print(
|
|
|
+ f"device: {torch.cuda.get_device_name(device) if device.type == 'cuda' else device}"
|
|
|
+ )
|
|
|
+ print(f"image_size: {list(image_size)}")
|
|
|
+ print(f"amp: {amp_enabled}")
|
|
|
+ print(f"encoder_channels: {model_cfg.get('encoder_channels')}")
|
|
|
+ print(f"encoder_depths: {model_cfg.get('encoder_depths')}")
|
|
|
+ print(f"global_ratio: {model_cfg.get('global_ratio')}")
|
|
|
+ print()
|
|
|
+ print(
|
|
|
+ f"{'batch':>5} {'status':>8} {'allocated':>12} {'reserved':>12} "
|
|
|
+ f"{'seconds':>8} {'loss/error'}"
|
|
|
+ )
|
|
|
+ print("-" * 78)
|
|
|
+
|
|
|
+
|
|
|
+def main() -> None:
|
|
|
+ args = parse_args()
|
|
|
+ if args.device == "cuda" and not torch.cuda.is_available():
|
|
|
+ raise RuntimeError("CUDA is not available.")
|
|
|
+
|
|
|
+ cfg_path = (
|
|
|
+ ROOT_DIR / args.config
|
|
|
+ if not Path(args.config).is_absolute()
|
|
|
+ else Path(args.config)
|
|
|
+ )
|
|
|
+ cfg = apply_dotlist_overrides(load_yaml_config(cfg_path), args.set)
|
|
|
+ device = torch.device(args.device)
|
|
|
+ image_size = tuple(args.image_size or cfg["dataset"]["image_size"])
|
|
|
+ amp_enabled = bool(
|
|
|
+ cfg.get("train", {}).get("amp", False) if args.amp is None else args.amp
|
|
|
+ )
|
|
|
+
|
|
|
+ print_header(cfg, image_size, device, amp_enabled)
|
|
|
+ for batch_size in args.batch_sizes:
|
|
|
+ try:
|
|
|
+ if args.warmup:
|
|
|
+ run_step(
|
|
|
+ cfg=cfg,
|
|
|
+ batch_size=batch_size,
|
|
|
+ image_size=image_size,
|
|
|
+ device=device,
|
|
|
+ amp_enabled=amp_enabled,
|
|
|
+ )
|
|
|
+ result = run_step(
|
|
|
+ cfg=cfg,
|
|
|
+ batch_size=batch_size,
|
|
|
+ image_size=image_size,
|
|
|
+ device=device,
|
|
|
+ amp_enabled=amp_enabled,
|
|
|
+ )
|
|
|
+ print(
|
|
|
+ f"{batch_size:>5} {'ok':>8} "
|
|
|
+ f"{result['allocated_mb']:>9.1f} MB "
|
|
|
+ f"{result['reserved_mb']:>9.1f} MB "
|
|
|
+ f"{result['seconds']:>8.2f} "
|
|
|
+ f"loss={result['loss']:.6f}"
|
|
|
+ )
|
|
|
+ except torch.cuda.OutOfMemoryError as exc:
|
|
|
+ release_cuda()
|
|
|
+ print(
|
|
|
+ f"{batch_size:>5} {'OOM':>8} "
|
|
|
+ f"{'-':>12} {'-':>12} {'-':>8} {str(exc).splitlines()[0]}"
|
|
|
+ )
|
|
|
+ except Exception as exc:
|
|
|
+ release_cuda()
|
|
|
+ print(
|
|
|
+ f"{batch_size:>5} {'ERROR':>8} "
|
|
|
+ f"{'-':>12} {'-':>12} {'-':>8} {type(exc).__name__}: {exc}"
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ main()
|