smoke_test.py 775 B

1234567891011121314151617181920212223242526
  1. import torch
  2. from .attentions_2d import WaveletAttentionGlobalBranch2d
  3. from .nets_2d import wavelet_fft_t2
  4. def run_smoke_test():
  5. with torch.no_grad():
  6. global_op = WaveletAttentionGlobalBranch2d(32, kernel_size=5, wt_levels=2)
  7. for shape in ((2, 32, 32, 32), (1, 32, 31, 29)):
  8. x = torch.randn(*shape)
  9. y = global_op(x)
  10. assert y.shape == x.shape, f"global_op shape mismatch: {shape} -> {tuple(y.shape)}"
  11. model = wavelet_fft_t2(num_classes=10)
  12. model.eval()
  13. x = torch.randn(2, 3, 193, 193)
  14. y = model(x)
  15. assert y.shape == (2, 10), f"model output shape mismatch: {tuple(y.shape)}"
  16. return "wavelet_fft smoke test passed"
  17. if __name__ == "__main__":
  18. print(run_smoke_test())