"""Smoke test for ``WaveletAttentionGlobalBranch2d`` and the ``wavelet_fft_t2`` model.

This module uses relative imports, so run it as a module
(``python -m <package>.<this_module>``) rather than as a plain script.
"""

import torch

from .attentions_2d import WaveletAttentionGlobalBranch2d
from .nets_2d import wavelet_fft_t2


def _check(condition: bool, message: str) -> None:
    """Raise ``AssertionError`` with *message* when *condition* is false.

    Used instead of a bare ``assert`` so the checks still run when Python is
    started with ``-O`` (which strips ``assert`` statements).
    """
    if not condition:
        raise AssertionError(message)


def run_smoke_test() -> str:
    """Run forward passes on random inputs and verify the output shapes.

    Checks that:
      * ``WaveletAttentionGlobalBranch2d(32, kernel_size=5, wt_levels=2)``
        returns a tensor with the same shape as its input, for both an
        even-sized ``(2, 32, 32, 32)`` and an odd-sized ``(1, 32, 31, 29)``
        input;
      * ``wavelet_fft_t2(num_classes=10)`` maps a ``(2, 3, 193, 193)`` batch
        to an output of shape ``(2, 10)``.

    Returns:
        The success message ``"wavelet_fft smoke test passed"``.

    Raises:
        AssertionError: If any output shape does not match the expectation.
    """
    with torch.no_grad():
        global_op = WaveletAttentionGlobalBranch2d(32, kernel_size=5, wt_levels=2)
        # Match the model below: evaluate in inference mode in case the branch
        # contains train/eval-dependent layers (e.g. dropout or norm layers).
        global_op.eval()
        # Cover an even-sized and an odd-sized spatial extent.
        for shape in ((2, 32, 32, 32), (1, 32, 31, 29)):
            x = torch.randn(*shape)
            y = global_op(x)
            _check(
                y.shape == x.shape,
                f"global_op shape mismatch: {shape} -> {tuple(y.shape)}",
            )

        model = wavelet_fft_t2(num_classes=10)
        model.eval()
        x = torch.randn(2, 3, 193, 193)
        y = model(x)
        _check(
            y.shape == (2, 10),
            f"model output shape mismatch: {tuple(y.shape)}",
        )

    return "wavelet_fft smoke test passed"


if __name__ == "__main__":
    print(run_smoke_test())