| 1234567891011121314151617181920212223242526 |
- import torch
- from .attentions_2d import WaveletAttentionGlobalBranch2d
- from .nets_2d import wavelet_fft_t2
def run_smoke_test():
    """Run shape checks for the wavelet-FFT modules on random CPU inputs.

    Checks that ``WaveletAttentionGlobalBranch2d(32, kernel_size=5,
    wt_levels=2)`` returns an output with the same shape as its input for a
    ``(2, 32, 32, 32)`` and a ``(1, 32, 31, 29)`` tensor. It also checks that
    ``wavelet_fft_t2(num_classes=10)`` maps a ``(2, 3, 193, 193)`` batch to
    an output of shape ``(2, 10)``.

    Returns:
        str: A success message when every check passes.

    Raises:
        AssertionError: If any output shape differs from the expected one.
            Raised explicitly rather than via ``assert`` so the checks are
            not stripped when Python runs with ``-O``.
    """
    # Shape checks need no autograd graph; skip building one.
    with torch.no_grad():
        global_op = WaveletAttentionGlobalBranch2d(32, kernel_size=5, wt_levels=2)
        # Evaluate in inference mode, matching the full-model check below.
        global_op.eval()
        # The second shape has odd spatial sizes; presumably it checks that
        # the wavelet branch handles non-even H/W (e.g. via padding) — confirm.
        for shape in ((2, 32, 32, 32), (1, 32, 31, 29)):
            x = torch.randn(*shape)
            y = global_op(x)
            if y.shape != x.shape:
                raise AssertionError(
                    f"global_op shape mismatch: {shape} -> {tuple(y.shape)}"
                )

        model = wavelet_fft_t2(num_classes=10)
        model.eval()
        x = torch.randn(2, 3, 193, 193)
        y = model(x)
        if y.shape != (2, 10):
            raise AssertionError(f"model output shape mismatch: {tuple(y.shape)}")
    return "wavelet_fft smoke test passed"
# Script entry point: run the smoke test and report its result message.
if __name__ == "__main__":
    message = run_smoke_test()
    print(message)
|