| 1234567891011121314151617181920212223242526272829303132333435363738394041424344 |
- import math
- import torch
- import torch.nn.functional as F
- import pytest
- from einops import rearrange
- from src.ops.vandermonde import log_vandermonde, log_vandermonde_fast
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires a CUDA device')
@pytest.mark.parametrize('L', [3, 17, 489, 2**10, 1047, 2**11, 2**12])
@pytest.mark.parametrize('N', [4, 8, 16, 32, 64, 128, 256])
# To debug a single configuration, narrow the lists above, e.g. L=[2048], N=[64].
def test_vand_mult_symmetric(N, L):
    """Compare ``log_vandermonde_fast`` with ``log_vandermonde`` (forward and backward).

    Both functions receive identical random inputs on the GPU: each one gets
    its own independent leaf copy. The test then compares the following
    with ``torch.allclose``:

    * the two outputs;
    * the two sets of gradients w.r.t. ``v`` and ``x`` under one shared
      random upstream gradient.

    Per the assertion message, ``log_vandermonde_fast`` is provided by the
    compiled "cauchy" extension and is ``None`` when that extension is
    missing.

    Args:
        N: Parametrized size. Each batch row of ``v`` and ``x`` has ``N // 2`` entries.
        L: Parametrized length, passed as the third argument to both functions.
    """
    assert log_vandermonde_fast is not None, 'cauchy extension is not installed'
    # Looser tolerances for larger problems; presumably float32 round-off grows with N and L.
    rtol, atol = (1e-4, 1e-4) if N <= 64 and L <= 1024 else (1e-3, 1e-3)
    device = 'cuda'
    batch_size = 4
    torch.random.manual_seed(2357)

    v = torch.randn(batch_size, N // 2, dtype=torch.cfloat, device=device, requires_grad=True)
    # Complex x: real part in [0, 0.001), imaginary part in [0, N).
    x = (0.001 * torch.rand(batch_size, N // 2, device=device)
         + 1j * N * torch.rand(batch_size, N // 2, device=device))
    x.requires_grad_()
    # Independent leaf copies so each implementation builds its own autograd graph.
    v_keops = v.detach().clone().requires_grad_()
    x_keops = x.detach().clone().requires_grad_()

    out_keops = log_vandermonde(v_keops, x_keops, L)
    out = log_vandermonde_fast(v, x, L)

    # One shared upstream gradient. Each graph is differentiated exactly once,
    # so retain_graph is not needed and the graph buffers can be freed.
    dout = torch.randn_like(out)
    dv_keops, dx_keops = torch.autograd.grad(out_keops, (v_keops, x_keops), dout)
    dv, dx = torch.autograd.grad(out, (v, x), dout)

    # Report absolute errors to help diagnose failures.
    for label, got, ref in (('out', out, out_keops), ('dv', dv, dv_keops), ('dx', dx, dx_keops)):
        err = (got - ref).abs()
        print(f'{label} error: max {err.amax().item():.6f}, mean {err.mean().item():.6f}')

    assert torch.allclose(out, out_keops, rtol=rtol, atol=atol)
    assert torch.allclose(dv, dv_keops, rtol=rtol, atol=atol)
    assert torch.allclose(dx, dx_keops, rtol=rtol, atol=atol)
|