# probe_xnet_memory.py
  1. from __future__ import annotations
  2. import argparse
  3. import gc
  4. import sys
  5. import time
  6. from pathlib import Path
  7. from typing import Any
  8. import torch
  9. ROOT_DIR = Path(__file__).resolve().parents[1]
  10. if str(ROOT_DIR) not in sys.path:
  11. sys.path.insert(0, str(ROOT_DIR))
  12. from lib.modules import XNet2d
  13. from lib.tools import build_loss, build_optimizer
  14. from lib.utils.config import apply_dotlist_overrides, load_yaml_config
  15. def parse_args() -> argparse.Namespace:
  16. parser = argparse.ArgumentParser(
  17. description="Probe XNet2d CUDA memory with synthetic segmentation batches."
  18. )
  19. parser.add_argument(
  20. "--config",
  21. default="configs/segmentation/train_sup_us_template.yaml",
  22. help="YAML config path.",
  23. )
  24. parser.add_argument(
  25. "--batch-sizes",
  26. nargs="+",
  27. type=int,
  28. default=[4, 6, 8],
  29. help="Batch sizes to probe.",
  30. )
  31. parser.add_argument(
  32. "--image-size",
  33. nargs=2,
  34. type=int,
  35. default=None,
  36. metavar=("H", "W"),
  37. help="Override dataset.image_size.",
  38. )
  39. parser.add_argument(
  40. "--amp",
  41. action=argparse.BooleanOptionalAction,
  42. default=None,
  43. help="Override train.amp.",
  44. )
  45. parser.add_argument(
  46. "--device",
  47. default="cuda",
  48. help="Device to probe. CUDA is required for memory numbers.",
  49. )
  50. parser.add_argument(
  51. "--warmup",
  52. action="store_true",
  53. help="Run one unmeasured warmup step before measuring each batch size.",
  54. )
  55. parser.add_argument(
  56. "--set",
  57. nargs="*",
  58. default=None,
  59. help="Override config values with key=value pairs.",
  60. )
  61. return parser.parse_args()
  62. def build_model(cfg: dict[str, Any], device: torch.device) -> XNet2d:
  63. dataset_cfg = cfg["dataset"]
  64. model_cfg = cfg["model"]
  65. return XNet2d(
  66. in_channels=int(
  67. model_cfg.get("in_channels", dataset_cfg.get("in_channels", 3))
  68. ),
  69. num_classes=int(dataset_cfg["num_classes"]),
  70. encoder_channels=tuple(model_cfg.get("encoder_channels", (32, 64, 128, 192))),
  71. encoder_depths=tuple(model_cfg.get("encoder_depths", (2, 2, 2, 2))),
  72. decoder_channels=tuple(model_cfg.get("decoder_channels", (128, 64, 32))),
  73. stem_channels=int(model_cfg.get("stem_channels", 24)),
  74. bottleneck_depth=int(model_cfg.get("bottleneck_depth", 1)),
  75. global_ratio=float(model_cfg.get("global_ratio", 2.0)),
  76. wavelet_type=str(model_cfg.get("wavelet_type", "haar")),
  77. wavelet_level=int(model_cfg.get("wavelet_level", 1)),
  78. use_wavelet_branch=bool(model_cfg.get("use_wavelet_branch", True)),
  79. use_global_branch_stage1=bool(model_cfg.get("use_global_branch_stage1", False)),
  80. ssm_d_state=int(model_cfg.get("ssm_d_state", 16)),
  81. ssm_forward_type=str(model_cfg.get("ssm_forward_type", "v3")),
  82. ssm_backend=str(model_cfg.get("ssm_backend", "auto")),
  83. use_frequency_refine=bool(model_cfg.get("use_frequency_refine", True)),
  84. low_freq_radius_h=float(model_cfg.get("low_freq_radius_h", 0.25)),
  85. low_freq_radius_w=float(model_cfg.get("low_freq_radius_w", 0.25)),
  86. learnable_low_freq_radius=bool(
  87. model_cfg.get("learnable_low_freq_radius", True)
  88. ),
  89. guide_mode=str(model_cfg.get("guide_mode", "affine")),
  90. out_channels=model_cfg.get("out_channels"),
  91. ).to(device)
  92. def release_cuda() -> None:
  93. gc.collect()
  94. if torch.cuda.is_available():
  95. torch.cuda.empty_cache()
  96. torch.cuda.reset_peak_memory_stats()
  97. def make_batch(
  98. *,
  99. batch_size: int,
  100. in_channels: int,
  101. num_classes: int,
  102. image_size: tuple[int, int],
  103. device: torch.device,
  104. ) -> tuple[torch.Tensor, torch.Tensor]:
  105. height, width = image_size
  106. image = torch.randn(batch_size, in_channels, height, width, device=device)
  107. if num_classes == 1:
  108. mask = torch.randint(
  109. 0, 2, (batch_size, 1, height, width), device=device
  110. ).float()
  111. else:
  112. mask = torch.randint(
  113. 0, num_classes, (batch_size, 1, height, width), device=device
  114. )
  115. return image, mask
  116. def run_step(
  117. *,
  118. cfg: dict[str, Any],
  119. batch_size: int,
  120. image_size: tuple[int, int],
  121. device: torch.device,
  122. amp_enabled: bool,
  123. ) -> dict[str, float]:
  124. release_cuda()
  125. model = build_model(cfg, device)
  126. model.train()
  127. optimizer = build_optimizer(model, cfg["optimizer"])
  128. loss_fn = build_loss(cfg["loss"])
  129. dataset_cfg = cfg["dataset"]
  130. in_channels = int(
  131. dataset_cfg.get("in_channels", cfg["model"].get("in_channels", 3))
  132. )
  133. num_classes = int(dataset_cfg["num_classes"])
  134. image, mask = make_batch(
  135. batch_size=batch_size,
  136. in_channels=in_channels,
  137. num_classes=num_classes,
  138. image_size=image_size,
  139. device=device,
  140. )
  141. torch.cuda.synchronize(device)
  142. torch.cuda.reset_peak_memory_stats(device)
  143. start = time.perf_counter()
  144. optimizer.zero_grad(set_to_none=True)
  145. with torch.autocast(device_type=device.type, enabled=amp_enabled):
  146. outputs = model(image)
  147. loss = loss_fn(outputs["seg_logits"], mask)
  148. loss.backward()
  149. optimizer.step()
  150. torch.cuda.synchronize(device)
  151. elapsed = time.perf_counter() - start
  152. result = {
  153. "loss": float(loss.detach().cpu()),
  154. "seconds": elapsed,
  155. "allocated_mb": torch.cuda.max_memory_allocated(device) / (1024**2),
  156. "reserved_mb": torch.cuda.max_memory_reserved(device) / (1024**2),
  157. }
  158. del model, optimizer, loss_fn, image, mask, outputs, loss
  159. release_cuda()
  160. return result
  161. def print_header(
  162. cfg: dict[str, Any],
  163. image_size: tuple[int, int],
  164. device: torch.device,
  165. amp_enabled: bool,
  166. ) -> None:
  167. model_cfg = cfg["model"]
  168. print("XNet2d memory probe")
  169. print(
  170. f"device: {torch.cuda.get_device_name(device) if device.type == 'cuda' else device}"
  171. )
  172. print(f"image_size: {list(image_size)}")
  173. print(f"amp: {amp_enabled}")
  174. print(f"encoder_channels: {model_cfg.get('encoder_channels')}")
  175. print(f"encoder_depths: {model_cfg.get('encoder_depths')}")
  176. print(f"global_ratio: {model_cfg.get('global_ratio')}")
  177. print()
  178. print(
  179. f"{'batch':>5} {'status':>8} {'allocated':>12} {'reserved':>12} "
  180. f"{'seconds':>8} {'loss/error'}"
  181. )
  182. print("-" * 78)
  183. def main() -> None:
  184. args = parse_args()
  185. if args.device == "cuda" and not torch.cuda.is_available():
  186. raise RuntimeError("CUDA is not available.")
  187. cfg_path = (
  188. ROOT_DIR / args.config
  189. if not Path(args.config).is_absolute()
  190. else Path(args.config)
  191. )
  192. cfg = apply_dotlist_overrides(load_yaml_config(cfg_path), args.set)
  193. device = torch.device(args.device)
  194. image_size = tuple(args.image_size or cfg["dataset"]["image_size"])
  195. amp_enabled = bool(
  196. cfg.get("train", {}).get("amp", False) if args.amp is None else args.amp
  197. )
  198. print_header(cfg, image_size, device, amp_enabled)
  199. for batch_size in args.batch_sizes:
  200. try:
  201. if args.warmup:
  202. run_step(
  203. cfg=cfg,
  204. batch_size=batch_size,
  205. image_size=image_size,
  206. device=device,
  207. amp_enabled=amp_enabled,
  208. )
  209. result = run_step(
  210. cfg=cfg,
  211. batch_size=batch_size,
  212. image_size=image_size,
  213. device=device,
  214. amp_enabled=amp_enabled,
  215. )
  216. print(
  217. f"{batch_size:>5} {'ok':>8} "
  218. f"{result['allocated_mb']:>9.1f} MB "
  219. f"{result['reserved_mb']:>9.1f} MB "
  220. f"{result['seconds']:>8.2f} "
  221. f"loss={result['loss']:.6f}"
  222. )
  223. except torch.cuda.OutOfMemoryError as exc:
  224. release_cuda()
  225. print(
  226. f"{batch_size:>5} {'OOM':>8} "
  227. f"{'-':>12} {'-':>12} {'-':>8} {str(exc).splitlines()[0]}"
  228. )
  229. except Exception as exc:
  230. release_cuda()
  231. print(
  232. f"{batch_size:>5} {'ERROR':>8} "
  233. f"{'-':>12} {'-':>12} {'-':>8} {type(exc).__name__}: {exc}"
  234. )
  235. if __name__ == "__main__":
  236. main()