vandermonde.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. import torch
  2. from structured_kernels import vand_log_mult_sym_fwd, vand_log_mult_sym_bwd
  3. def log_vandermonde_cuda(v, z, L):
  4. """ Wrap the cuda method to deal with shapes """
  5. v, z = torch.broadcast_tensors(v, z)
  6. shape = v.shape
  7. v = v.contiguous()
  8. z = z.contiguous()
  9. N = v.size(-1)
  10. assert z.size(-1) == N
  11. y = LogVandMultiplySymmetric.apply(v.view(-1, N), z.view(-1, N), L)
  12. y = y.view(*shape[:-1], L)
  13. return y
  14. class LogVandMultiplySymmetric(torch.autograd.Function):
  15. @staticmethod
  16. def forward(ctx, v, x, L):
  17. batch, N = v.shape
  18. supported_N_values = [1 << log_n for log_n in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
  19. if not N in supported_N_values:
  20. raise NotImplementedError(f'Only support N values in {supported_N_values}')
  21. max_L_value = 32 * 1024 * 64 * 1024
  22. if L > max_L_value:
  23. raise NotImplementedError(f'Only support L values <= {max_L_value}')
  24. if not v.is_cuda and x.is_cuda:
  25. raise NotImplementedError(f'Only support CUDA tensors')
  26. ctx.save_for_backward(v, x)
  27. return vand_log_mult_sym_fwd(v, x, L)
  28. @staticmethod
  29. def backward(ctx, dout):
  30. v, x = ctx.saved_tensors
  31. dv, dx = vand_log_mult_sym_bwd(v, x, dout)
  32. return dv, dx, None
  33. if vand_log_mult_sym_fwd and vand_log_mult_sym_bwd is not None:
  34. log_vandermonde_fast = LogVandMultiplySymmetric.apply
  35. else:
  36. log_vandermonde_fast = None