| 1234567891011121314151617181920212223242526272829303132 |
"""Smoke test for ``XNet2d``: build a small model on CUDA and run one forward pass.

Run directly as a script. It prints the sorted keys of the model's output
mapping and the shape of its ``"seg_logits"`` entry.
"""
import os
import sys

import torch


def ensure_parent_on_path(script_path=None):
    """Prepend the parent of the script's directory to ``sys.path`` if absent.

    This makes the sibling ``lib`` package importable when the script is run
    from its own directory.

    Args:
        script_path: Path of the script to resolve from; defaults to this
            module's ``__file__``.

    Returns:
        The directory that is now guaranteed to be on ``sys.path``.
    """
    script_dir = os.path.dirname(os.path.abspath(script_path or __file__))
    parent_dir = os.path.dirname(script_dir)
    # Use the public ``sys`` module; ``os.sys`` is an undocumented
    # implementation detail of ``os`` and not guaranteed to exist.
    if parent_dir not in sys.path:
        sys.path.insert(0, parent_dir)
    return parent_dir


def main():
    """Build a small XNet2d, run one random 128x128 input through it, print results.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    ensure_parent_on_path()
    # Imported here, after sys.path has been patched, so ``lib`` resolves.
    from lib.modules import XNet2d

    # An explicit raise (unlike ``assert``) still fires under ``python -O``.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")

    model = (
        XNet2d(
            in_channels=3,
            num_classes=1,
            encoder_channels=(32, 64, 128, 192),
            encoder_depths=(2, 2, 2, 2),
            decoder_channels=(128, 64, 32),
            # NOTE(review): presumably lets XNet2d choose its SSM
            # implementation -- confirm in lib.modules.
            ssm_backend="auto",
        )
        .cuda()
        .eval()
    )

    # Batch of one random image; dim 1 matches in_channels=3
    # (assumes NCHW layout -- TODO confirm against XNet2d).
    x = torch.randn(1, 3, 128, 128, device="cuda")
    with torch.no_grad():
        y = model(x)

    # The model output is indexed like a mapping; report its keys and the
    # shape of the segmentation logits.
    print("keys:", sorted(y.keys()))
    print("seg_logits:", tuple(y["seg_logits"].shape))


if __name__ == "__main__":
    main()
|