"""Benchmark the forward-pass latency of ``XNet2d`` on a CUDA device.

Run directly as a script. It builds the model in eval mode, runs a few untimed
warm-up passes, then prints the shape of ``seg_logits`` and the mean
wall-clock time of the timed forward passes.
"""

import os
import sys
import time

import torch

# Make this script's parent directory importable so that ``lib`` resolves no
# matter which working directory the script is launched from. This must run
# before the project import below.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from lib.modules import XNet2d  # noqa: E402  (depends on the sys.path tweak above)


def benchmark(
    *,
    batch_size: int = 1,
    image_size: int = 128,
    warmup_iters: int = 5,
    timed_iters: int = 20,
) -> None:
    """Time ``XNet2d`` forward passes on a random CUDA input and print results.

    Args:
        batch_size: Batch dimension of the random input tensor.
        image_size: Height and width of the square random input tensor.
        warmup_iters: Untimed forward passes run first, so that one-off
            start-up costs are excluded from the average.
        timed_iters: Number of forward passes averaged in the reported time.

    Raises:
        ValueError: If ``timed_iters`` is less than 1.
        RuntimeError: If no CUDA device is available.
    """
    if timed_iters < 1:
        raise ValueError(f"timed_iters must be >= 1, got {timed_iters}")
    if not torch.cuda.is_available():
        raise RuntimeError("This benchmark requires a CUDA device.")

    in_channels = 3
    model = (
        XNet2d(
            in_channels=in_channels,
            num_classes=1,
            ssm_backend="auto",
            ssm_forward_type="v3",
        )
        .cuda()
        .eval()
    )
    x = torch.randn(batch_size, in_channels, image_size, image_size, device="cuda")

    # Enter no_grad once for the whole measured region, not once per iteration.
    with torch.no_grad():
        for _ in range(warmup_iters):
            model(x)
        # CUDA work is launched asynchronously: wait for the warm-up passes
        # to finish so they do not leak into the timed window.
        torch.cuda.synchronize()

        start = time.perf_counter()
        for _ in range(timed_iters):
            y = model(x)
        torch.cuda.synchronize()
        # Stop the clock before any printing so console I/O is not measured.
        elapsed = time.perf_counter() - start

    print("seg_logits:", tuple(y["seg_logits"].shape))
    print("avg_forward_sec:", elapsed / timed_iters)


if __name__ == "__main__":
    benchmark()