test_cauchy.py 4.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import math
  2. import torch
  3. import pytest
  4. from einops import rearrange
  5. from cauchy import cauchy_mult_torch, cauchy_mult_keops, cauchy_mult
  def generate_data(batch_size, N, L, symmetric=True, device='cuda'):
      """Build random inputs ``(v, z, w)`` for the Cauchy-multiply tests.

      Args:
          batch_size: Leading (batch) dimension of ``v`` and ``w``.
          N: Length of the last dimension of ``v`` and ``w``. Must be even
              when ``symmetric`` is True.
          L: Number of entries in ``z``.
          symmetric: If True, ``v`` and ``w`` are built as ``[x, x.conj()]``
              along the last dimension (each half has ``N // 2`` entries) and
              ``z = exp(1j * theta)`` for Gaussian ``theta``, i.e. points on the
              complex unit circle. If False, ``v``, ``w`` and ``z`` are plain
              complex64 Gaussian tensors.
          device: Torch device the tensors are allocated on.

      Returns:
          Tuple ``(v, z, w)``. ``v`` and ``w`` are complex64 tensors of shape
          ``(batch_size, N)`` with ``requires_grad=True``; ``z`` has shape
          ``(L,)`` and does not require grad.

      Raises:
          ValueError: If ``symmetric`` is True and ``N`` is odd.
      """
      if not symmetric:
          v = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True)
          w = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True)
          z = torch.randn(L, dtype=torch.complex64, device=device)
          return v, z, w

      # An ``assert`` here would vanish under ``python -O`` and an odd N would
      # then silently yield tensors of length N - 1 instead of failing.
      if N % 2 != 0:
          raise ValueError(f'symmetric data requires an even N, got N={N}')
      # Keep the RNG draw order (v_half, w_half, z) so seeded runs are reproducible.
      v_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device)
      v = torch.cat([v_half, v_half.conj()], dim=-1).requires_grad_(True)
      w_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device)
      w = torch.cat([w_half, w_half.conj()], dim=-1).requires_grad_(True)
      z = torch.exp(1j * torch.randn(L, dtype=torch.float32, device=device))
      return v, z, w
  def grad_to_half_grad(dx):
      """Fold a gradient w.r.t. ``[x, x.conj()]`` into a gradient w.r.t. ``x``.

      When a tensor is assembled as ``cat([x, x.conj()], dim=-1)``, ``x`` gets
      gradient from both halves: the first half directly, the second half
      through the conjugation (hence the ``.conj()`` on that half).

      Args:
          dx: Gradient with respect to the concatenated tensor. Its last
              dimension must have even length.

      Returns:
          Tensor whose last dimension is half as long:
          ``dx[..., :n] + dx[..., n:].conj()`` with ``n = dx.shape[-1] // 2``.

      Raises:
          ValueError: If the last dimension of ``dx`` has odd length.
      """
      size = dx.shape[-1]
      # ``chunk(2)`` on an odd length yields unequal pieces; a trailing piece of
      # size 1 would silently broadcast in the sum instead of raising.
      if size % 2 != 0:
          raise ValueError(f'expected an even-length last dimension, got {size}')
      half = size // 2
      return dx[..., :half] + dx[..., half:].conj()
  def _relative_error(x, ref):
      """Return the elementwise relative error ``|x - ref| / |ref|``."""
      return (x - ref).abs() / ref.abs()


  @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires a CUDA device')
  @pytest.mark.parametrize('L', [3, 17, 489, 2**10, 1047, 2**11, 2**12, 2**13, 2**14, 2**18])
  @pytest.mark.parametrize('N', [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048])
  def test_cauchy_mult_symmetric(N, L):
      """Compare ``cauchy_mult`` on conjugate-symmetric inputs against KeOps.

      ``cauchy_mult_torch`` evaluated in complex128 (then cast back to
      complex64) is the reference. ``cauchy_mult`` receives only the first
      halves ``v_half``/``w_half``. For the output and for the gradients
      w.r.t. those halves, its max and mean relative error must not exceed
      ``tol_factor`` times the corresponding ``cauchy_mult_keops`` error plus
      ``atol``.
      """
      atol = 1e-4
      tol_factor = 2.0  # Our error shouldn't be this much higher than Keops' error
      device = 'cuda'
      batch_size = 4
      torch.random.manual_seed(2357)
      v, z, w = generate_data(batch_size, N, L, symmetric=True, device=device)
      # The fused kernel only takes the first half; the second half is its conjugate.
      v_half = v[:, :N // 2].clone().detach().requires_grad_(True)
      w_half = w[:, :N // 2].clone().detach().requires_grad_(True)

      # Double-precision reference, cast back to complex64 for comparison.
      out_torch = cauchy_mult_torch(v.cdouble(), z.cdouble(), w.cdouble(), symmetric=True).cfloat()
      out_keops = cauchy_mult_keops(v, z, w)
      out = cauchy_mult(v_half, z, w_half)

      dout = torch.randn_like(out)
      dv_torch, dw_torch = torch.autograd.grad(out_torch, (v, w), dout, retain_graph=True)
      # NOTE(review): plain slicing (rather than grad_to_half_grad, as used for
      # KeOps below) assumes cauchy_mult_torch(..., symmetric=True) sends all
      # gradient to the first half of v and w -- TODO confirm.
      dv_torch, dw_torch = dv_torch[:, :N // 2], dw_torch[:, :N // 2]
      dv_keops, dw_keops = torch.autograd.grad(out_keops, (v, w), dout, retain_graph=True)
      dv_keops, dw_keops = grad_to_half_grad(dv_keops), grad_to_half_grad(dw_keops)
      dv, dw = torch.autograd.grad(out, (v_half, w_half), dout, retain_graph=True)

      # name -> (our relative error, KeOps relative error), both vs. the reference.
      errors = {
          'out': (_relative_error(out, out_torch), _relative_error(out_keops, out_torch)),
          'dv': (_relative_error(dv, dv_torch), _relative_error(dv_keops, dv_torch)),
          'dw': (_relative_error(dw, dw_torch), _relative_error(dw_keops, dw_torch)),
      }
      for name, (err, err_keops) in errors.items():
          print(f'Keops {name} relative error: max {err_keops.amax().item():.6f}, mean {err_keops.mean().item():.6f}')
          print(f'{name} relative error: max {err.amax().item():.6f}, mean {err.mean().item():.6f}')

      for name, (err, err_keops) in errors.items():
          assert err.amax() <= err_keops.amax() * tol_factor + atol, f'{name}: max relative error too high'
          assert err.mean() <= err_keops.mean() * tol_factor + atol, f'{name}: mean relative error too high'