test_vandermonde.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import math
  2. import torch
  3. import torch.nn.functional as F
  4. import pytest
  5. from einops import rearrange
  6. from src.ops.vandermonde import log_vandermonde, log_vandermonde_fast
  7. @pytest.mark.parametrize('L', [3, 17, 489, 2**10, 1047, 2**11, 2**12])
  8. @pytest.mark.parametrize('N', [4, 8, 16, 32, 64, 128, 256])
  9. # @pytest.mark.parametrize('L', [2048])
  10. # @pytest.mark.parametrize('N', [64])
  11. def test_vand_mult_symmetric(N, L):
  12. assert log_vandermonde_fast is not None, 'cauchy extension is not installed'
  13. rtol, atol = (1e-4, 1e-4) if N <= 64 and L <= 1024 else(1e-3, 1e-3)
  14. device = 'cuda'
  15. batch_size = 4
  16. torch.random.manual_seed(2357)
  17. v = torch.randn(batch_size, N // 2, dtype=torch.cfloat, device=device, requires_grad=True)
  18. x = (0.001 * torch.rand(batch_size, N // 2, device=device)
  19. + 1j * N * torch.rand(batch_size, N // 2, device=device))
  20. x.requires_grad_()
  21. v_keops = v.detach().clone().requires_grad_()
  22. x_keops = x.detach().clone().requires_grad_()
  23. out_keops = log_vandermonde(v_keops, x_keops, L)
  24. out = log_vandermonde_fast(v, x, L)
  25. err_out = (out - out_keops).abs()
  26. dout = torch.randn_like(out)
  27. dv_keops, dx_keops = torch.autograd.grad(out_keops, (v_keops, x_keops), dout, retain_graph=True)
  28. dv, dx = torch.autograd.grad(out, (v, x), dout, retain_graph=True)
  29. err_dv = (dv - dv_keops).abs()
  30. err_dx = (dx - dx_keops).abs()
  31. print(f'out error: max {err_out.amax().item():.6f}, mean {err_out.mean().item():.6f}')
  32. print(f'dv error: max {err_dv.amax().item():.6f}, mean {err_dv.mean().item():.6f}')
  33. print(f'dx relative error: max {err_dx.amax().item():.6f}, mean {err_dx.mean().item():.6f}')
  34. assert torch.allclose(out, out_keops, rtol=rtol, atol=atol)
  35. assert torch.allclose(dv, dv_keops, rtol=rtol, atol=atol)
  36. assert torch.allclose(dx, dx_keops, rtol=rtol, atol=atol)