benchmark_cauchy.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import math
  2. from functools import partial
  3. import torch
  4. from einops import rearrange
  5. from .cauchy import cauchy_mult_torch, cauchy_mult_keops, cauchy_mult
  6. from benchmark.utils import benchmark_all, benchmark_combined, benchmark_forward, benchmark_backward
  7. def generate_data(batch_size, N, L, symmetric=True, device='cuda'):
  8. if not symmetric:
  9. v = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True)
  10. w = torch.randn(batch_size, N, dtype=torch.complex64, device=device, requires_grad=True)
  11. z = torch.randn(L, dtype=torch.complex64, device=device)
  12. else:
  13. assert N % 2 == 0
  14. v_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device)
  15. v = torch.cat([v_half, v_half.conj()], dim=-1).requires_grad_(True)
  16. w_half = torch.randn(batch_size, N // 2, dtype=torch.complex64, device=device)
  17. w = torch.cat([w_half, w_half.conj()], dim=-1).requires_grad_(True)
  18. z = torch.exp(1j * torch.randn(L, dtype=torch.float32, device=device))
  19. return v, z, w
  20. if __name__ == '__main__':
  21. device = 'cuda'
  22. bs = 1024
  23. N = 64
  24. L = 16384
  25. v, z, w = generate_data(bs, N, L, symmetric=True)
  26. v_half = v[:, :N // 2].clone().detach().requires_grad_(True)
  27. w_half = w[:, :N // 2].clone().detach().requires_grad_(True)
  28. repeat = 30
  29. benchmark_all(repeat, cauchy_mult_keops, v, z, w, desc='Cauchy mult keops')
  30. fn = partial(cauchy_mult, symmetric=False)
  31. benchmark_all(repeat, fn, v, z, w, desc='Cauchy mult')
  32. fn = partial(cauchy_mult, symmetric=True)
  33. benchmark_all(repeat, fn, v_half, z, w_half, desc='Cauchy mult symmetric')