cauchy.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from pathlib import Path
  2. import torch
  3. from einops import rearrange
  4. from structured_kernels import cauchy_mult_sym_fwd, cauchy_mult_sym_bwd
  5. # try:
  6. # from cauchy_mult import cauchy_mult_sym_fwd, cauchy_mult_sym_bwd
  7. # except ImportError:
  8. # from torch.utils.cpp_extension import load
  9. # current_dir = Path(__file__).parent.absolute()
  10. # cauchy_mult_extension = load(
  11. # name='cauchy_mult',
  12. # sources=[str(current_dir / 'cauchy.cpp'), str(current_dir / 'cauchy_cuda.cu')],
  13. # extra_cflags=['-g', '-march=native', '-funroll-loops'],
  14. # extra_cuda_cflags=['-O3', '-lineinfo', '--use_fast_math'],
  15. # extra_include_paths=str(current_dir),
  16. # build_directory=str(current_dir),
  17. # verbose=True
  18. # )
  19. # cauchy_mult_sym_fwd = cauchy_mult_extension.cauchy_mult_sym_fwd
  20. # cauchy_mult_sym_bwd = cauchy_mult_extension.cauchy_mult_sym_bwd
  21. def cauchy_mult_torch(v: torch.Tensor, z: torch.Tensor, w: torch.Tensor,
  22. symmetric=True) -> torch.Tensor:
  23. """
  24. v: (B, N)
  25. z: (L)
  26. w: (B, N)
  27. symmetric: whether to assume that v and w contain complex conjugate pairs, of the form
  28. [v_half, v_half.conj()] and [w_half, w_half.conj()]
  29. """
  30. if not symmetric:
  31. return (rearrange(v, 'b n -> b 1 n') / (rearrange(z, 'l -> l 1') - rearrange(w, 'b n -> b 1 n'))).sum(dim=-1)
  32. else:
  33. N = v.shape[-1]
  34. assert N % 2 == 0
  35. vv = rearrange(v[:, :N // 2], 'b n -> b 1 n')
  36. zz = rearrange(z, 'l -> l 1')
  37. ww = rearrange(w[:, :N // 2], 'b n -> b 1 n')
  38. # return 2 * ((zz * vv.real - vv.real * ww.real - vv.imag * ww.imag)
  39. # / (zz * zz - 2 * zz * ww.real + ww.abs().square())).sum(dim=-1)
  40. return (vv / (zz - ww) + vv.conj() / (zz - ww.conj())).sum(dim=-1)
  41. def cauchy_mult_keops(v, z, w):
  42. from pykeops.torch import LazyTensor
  43. v_l = LazyTensor(rearrange(v, 'b N -> b 1 N 1'))
  44. z_l = LazyTensor(rearrange(z, 'L -> 1 L 1 1'))
  45. w_l = LazyTensor(rearrange(w, 'b N -> b 1 N 1'))
  46. sub = z_l - w_l # (b N L 1), for some reason it doesn't display the last dimension
  47. div = v_l / sub
  48. s = div.sum(dim=2, backend='GPU')
  49. return s.squeeze(-1)
  50. def _cauchy_mult(v, z, w):
  51. return CauchyMultiplySymmetric.apply(v, z, w)
  52. def cauchy_mult(v, z, w):
  53. """ Wrap the cuda method to deal with shapes """
  54. v, w = torch.broadcast_tensors(v, w)
  55. shape = v.shape
  56. # z_shape = z.shape
  57. # z = z.squeeze()
  58. assert len(z.shape) == 1
  59. v = v.contiguous()
  60. w = w.contiguous()
  61. z = z.contiguous()
  62. N = v.size(-1)
  63. assert w.size(-1) == N
  64. y = _cauchy_mult(v.view(-1, N), z, w.view(-1, N))
  65. y = y.view(*shape[:-1], z.size(-1))
  66. return y
  67. class CauchyMultiplySymmetric(torch.autograd.Function):
  68. @staticmethod
  69. def forward(ctx, v, z, w):
  70. batch, N = v.shape
  71. supported_N_values = [1 << log_n for log_n in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
  72. L = z.shape[-1]
  73. if not N in supported_N_values:
  74. raise NotImplementedError(f'Only support N values in {supported_N_values}')
  75. max_L_value = 32 * 1024 * 64 * 1024
  76. if L > max_L_value:
  77. raise NotImplementedError(f'Only support L values <= {max_L_value}')
  78. if not (v.is_cuda and z.is_cuda and w.is_cuda):
  79. raise NotImplementedError(f'Only support CUDA tensors')
  80. ctx.save_for_backward(v, z, w)
  81. return cauchy_mult_sym_fwd(v, z, w)
  82. @staticmethod
  83. def backward(ctx, dout):
  84. v, z, w = ctx.saved_tensors
  85. dv, dw = cauchy_mult_sym_bwd(v, z, w, dout)
  86. return dv, None, dw