| 123456789101112131415161718192021222324252627282930313233343536373839404142434445 |
- import torch
- from structured_kernels import vand_log_mult_sym_fwd, vand_log_mult_sym_bwd
def log_vandermonde_cuda(v, z, L):
    """Shape-handling wrapper around ``LogVandMultiplySymmetric``.

    ``v`` and ``z`` are broadcast against each other, made contiguous and
    flattened to two dimensions ``(-1, N)``, where ``N`` is the size of the
    last axis, before being handed to the autograd function.  The result is
    viewed back so that its leading dimensions are the broadcast leading
    dimensions and its last dimension has size ``L``.
    """
    v_bc, z_bc = torch.broadcast_tensors(v, z)
    leading_dims = v_bc.shape[:-1]

    # Broadcasting may yield expanded (strided) views; ``view`` below needs
    # contiguous memory.
    v_flat = v_bc.contiguous()
    z_flat = z_bc.contiguous()

    N = v_flat.size(-1)
    assert z_flat.size(-1) == N

    out = LogVandMultiplySymmetric.apply(
        v_flat.view(-1, N),
        z_flat.view(-1, N),
        L,
    )
    return out.view(*leading_dims, L)
class LogVandMultiplySymmetric(torch.autograd.Function):
    """Autograd binding for the ``vand_log_mult_sym_fwd`` / ``_bwd`` kernels.

    ``forward`` validates its inputs and dispatches to the forward kernel;
    ``backward`` dispatches to the backward kernel using the tensors saved in
    ``forward``.
    """

    @staticmethod
    def forward(ctx, v, x, L):
        """Validate the inputs and run the forward kernel.

        Args:
            ctx: autograd context; ``v`` and ``x`` are saved on it.
            v: 2-D tensor; its second dimension ``N`` must be a power of two
                between 2 and 1024 inclusive.
            x: tensor passed to the kernel alongside ``v``
                (presumably the same shape as ``v`` -- TODO confirm).
            L: integer length argument forwarded to the kernel; must not
                exceed ``32 * 1024 * 64 * 1024``.

        Returns:
            Whatever ``vand_log_mult_sym_fwd(v, x, L)`` returns.

        Raises:
            NotImplementedError: if ``N`` or ``L`` is unsupported, or if
                either ``v`` or ``x`` is not a CUDA tensor.
        """
        _, N = v.shape  # also enforces that ``v`` is exactly 2-D

        supported_N_values = [1 << log_n for log_n in range(1, 11)]
        if N not in supported_N_values:
            raise NotImplementedError(f'Only support N values in {supported_N_values}')

        max_L_value = 32 * 1024 * 64 * 1024
        if L > max_L_value:
            raise NotImplementedError(f'Only support L values <= {max_L_value}')

        # Both operands must live on a CUDA device.  The parentheses matter:
        # ``not v.is_cuda and x.is_cuda`` would bind as
        # ``(not v.is_cuda) and x.is_cuda`` and let CPU/CPU inputs through.
        if not (v.is_cuda and x.is_cuda):
            raise NotImplementedError('Only support CUDA tensors')

        ctx.save_for_backward(v, x)
        return vand_log_mult_sym_fwd(v, x, L)

    @staticmethod
    def backward(ctx, dout):
        """Run the backward kernel on the tensors saved by ``forward``.

        Returns:
            ``(dv, dx, None)`` -- gradients for ``v`` and ``x``; ``L`` is a
            non-tensor argument and therefore receives no gradient.
        """
        v, x = ctx.saved_tensors
        dv, dx = vand_log_mult_sym_bwd(v, x, dout)
        return dv, dx, None
# Expose the fast path only when both extension kernels are available.
# Each kernel is compared against None explicitly: ``a and b is not None``
# binds as ``a and (b is not None)``, which would only truth-test the first.
if vand_log_mult_sym_fwd is not None and vand_log_mult_sym_bwd is not None:
    log_vandermonde_fast = LogVandMultiplySymmetric.apply
else:
    log_vandermonde_fast = None
|