"""Smoke test: build an XNet2d model on CUDA and run a single forward pass.

Prints the sorted keys of the model's output mapping and the shape of its
``"seg_logits"`` entry for a random 1x3x128x128 input.
"""

import os
import sys

# Add the parent of this file's directory to sys.path so the `lib` package
# can be imported when this script is run directly. Use the public `sys`
# module rather than the undocumented `os.sys` attribute.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

# These imports must come after the sys.path adjustment above.
import torch  # noqa: E402
from lib.modules import XNet2d  # noqa: E402

# Explicit check instead of `assert`: asserts are removed under `python -O`,
# which would let the script continue and fail later at `.cuda()`.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available")

# Build the model, move it to the GPU, and switch to eval mode.
model = (
    XNet2d(
        in_channels=3,
        num_classes=1,
        encoder_channels=(32, 64, 128, 192),
        encoder_depths=(2, 2, 2, 2),
        decoder_channels=(128, 64, 32),
        ssm_backend="auto",
    )
    .cuda()
    .eval()
)

# Random input whose channel count matches in_channels=3 above.
x = torch.randn(1, 3, 128, 128, device="cuda")

with torch.no_grad():
    y = model(x)

# y is indexed by key below, so the model is expected to return a mapping.
print("keys:", sorted(y.keys()))
print("seg_logits:", tuple(y["seg_logits"].shape))