# test_v3.py (760 B)

  1. import os
  2. current_dir = os.path.dirname(os.path.abspath(__file__))
  3. parent_dir = os.path.dirname(current_dir)
  4. if parent_dir not in os.sys.path:
  5. os.sys.path.insert(0, parent_dir)
  6. import torch, time
  7. from lib.modules import XNet2d
  8. model = (
  9. XNet2d(
  10. in_channels=3,
  11. num_classes=1,
  12. ssm_backend="auto",
  13. ssm_forward_type="v3",
  14. )
  15. .cuda()
  16. .eval()
  17. )
  18. x = torch.randn(1, 3, 128, 128, device="cuda")
  19. for _ in range(5):
  20. with torch.no_grad():
  21. _ = model(x)
  22. torch.cuda.synchronize()
  23. t = time.perf_counter()
  24. for _ in range(20):
  25. with torch.no_grad():
  26. y = model(x)
  27. torch.cuda.synchronize()
  28. print("seg_logits:", tuple(y["seg_logits"].shape))
  29. print("avg_forward_sec:", (time.perf_counter() - t) / 20)