| 123456789101112131415161718192021222324252627282930313233343536 |
import os
import sys

# Make the parent of this script's directory importable, so the project-local
# ``lib`` package imported further down (``from lib.modules import XNet2d``)
# resolves when this file is run directly rather than from an installed package.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
    # Prepend (not append) so the local checkout is found before any other
    # ``lib`` that may already be on the path; the membership check keeps
    # repeated imports from adding duplicate entries.
    sys.path.insert(0, parent_dir)
- import torch, time
- from lib.modules import XNet2d
# Build the XNet2d network (3 input channels, 1 output class, with the SSM
# backend/forward-type options below), then move it to the default CUDA
# device and put it in evaluation mode before benchmarking.
model = XNet2d(
    in_channels=3,
    num_classes=1,
    ssm_backend="auto",
    ssm_forward_type="v3",
)
model = model.cuda()
model = model.eval()
def benchmark_forward(net, inputs, *, warmup=5, iters=20):
    """Measure the mean wall-clock time of ``net(inputs)``.

    Runs ``warmup`` untimed passes, then ``iters`` timed passes, all inside a
    single ``torch.no_grad()`` context so gradient tracking is off and the
    context is not re-entered inside the timed region.

    Args:
        net: Callable (e.g. an ``nn.Module``) invoked as ``net(inputs)``.
        inputs: Input tensor. When it is on a CUDA device, that device is
            synchronized before starting and after stopping the clock,
            because CUDA kernel launches are asynchronous.
        warmup: Number of untimed passes; must be >= 0.
        iters: Number of timed passes; must be >= 1.

    Returns:
        A tuple ``(output, avg_sec)``: the result of the last timed pass and
        the mean seconds per timed pass.

    Raises:
        ValueError: If ``warmup`` < 0 or ``iters`` < 1.
    """
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")
    if iters < 1:
        # Also guarantees ``out`` is bound and avoids dividing by zero.
        raise ValueError(f"iters must be >= 1, got {iters}")

    def _sync():
        # Only CUDA needs an explicit barrier; CPU execution is synchronous.
        if inputs.is_cuda:
            torch.cuda.synchronize(inputs.device)

    with torch.no_grad():
        for _ in range(warmup):
            net(inputs)
        _sync()
        start = time.perf_counter()
        for _ in range(iters):
            out = net(inputs)
        _sync()
        # Same ``iters`` drives both the loop and the divisor, so they cannot
        # drift apart.
        avg_sec = (time.perf_counter() - start) / iters
    return out, avg_sec


if __name__ == "__main__":
    x = torch.randn(1, 3, 128, 128, device="cuda")
    y, avg_sec = benchmark_forward(model, x, warmup=5, iters=20)
    # The model output is indexed as a mapping with a "seg_logits" entry.
    print("seg_logits:", tuple(y["seg_logits"].shape))
    print("avg_forward_sec:", avg_sec)
|