#include // On pytorch 1.10 and CUDA 10.2, I get compilation errors on torch/csrc/api/include/torch/nn/cloneable.h // So we'll only include torch/python.h instead of torch/extension.h // Similar to https://github.com/getkeops/keops/blob/3efd428b55c724b12f23982c06de00bc4d02d903/pykeops/torch_headers.h.in#L8 // #include #include #include // For getCurrentCUDAStream #include // For atomicAdd on complex #include #include // For scalar_value_type #include "map.h" // For the MAP macro, i.e. for_each over the arguments #ifndef ITEMS_PER_THREAD_SYM_FWD_VALUES #define ITEMS_PER_THREAD_SYM_FWD_VALUES {2, 4, 8, 16, 32, 32, 32, 64, 64, 64} #endif #ifndef MAX_BLOCK_SIZE_VALUE #define MAX_BLOCK_SIZE_VALUE 256 #endif #ifndef ITEMS_PER_THREAD_SYM_BWD_VALUE #define ITEMS_PER_THREAD_SYM_BWD_VALUE 32 #endif static constexpr int ITEMS_PER_THREAD_SYM_FWD[] = ITEMS_PER_THREAD_SYM_FWD_VALUES; static constexpr int MAX_BLOCK_SIZE = MAX_BLOCK_SIZE_VALUE; static constexpr int ITEMS_PER_THREAD_SYM_BWD = ITEMS_PER_THREAD_SYM_BWD_VALUE; template using CudaAcsr = at::GenericPackedTensorAccessor; constexpr __host__ __device__ int div_up_const(int a, int b) { return (a + b - 1) / b; } __host__ __device__ static inline int div_up(int a, int b) { return (a + b - 1) / b;} template __global__ void cauchy_mult_sym_fwd_cuda_kernel(CudaAcsr v, CudaAcsr z, CudaAcsr w, CudaAcsr out, int L) { // Get the float type from the complex type // https://github.com/pytorch/pytorch/blob/bceb1db885cafa87fe8d037d8f22ae9649a1bba0/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp#L213 using float_t = typename at::scalar_value_type::type; constexpr int N = 1 << log_N; constexpr int blockDimx = div_up_const(N, items_per_thread); constexpr int blockDimy = MAX_BLOCK_SIZE / blockDimx; // We just want a shared array: // __shared__ scalar_t s_b[16]; // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270 // So we declare a char array and cast it. // The casting is subtle: https://stackoverflow.com/questions/12692310/convert-array-to-two-dimensional-array-by-pointer __shared__ char v_smem_char[N * sizeof(scalar_t)]; scalar_t *v_smem = (scalar_t *)&v_smem_char; __shared__ char w_smem_char[N * sizeof(scalar_t)]; scalar_t *w_smem = (scalar_t *)&w_smem_char; __shared__ char out_smem_char[blockDimy * sizeof(scalar_t)]; scalar_t *out_smem = (scalar_t *)&out_smem_char; int batch_idx = blockIdx.x; int tid = threadIdx.x + threadIdx.y * blockDim.x; int L_idx = blockIdx.y * blockDim.y + threadIdx.y; int L_block_start = blockIdx.y * blockDim.y; scalar_t z_t = L_block_start + threadIdx.y < L ? z[L_block_start + threadIdx.y] : scalar_t(0.f); for (int N_idx = threadIdx.x + threadIdx.y * blockDim.x; N_idx < N; N_idx += blockDim.x * blockDim.y) { v_smem[N_idx] = v[batch_idx][N_idx]; w_smem[N_idx] = w[batch_idx][N_idx]; } __syncthreads(); scalar_t result = 0; if (L_idx < L) { // Combining the two terms (a/b + c/d = (ad + bc)/(bd)) seems to increase numerical errors. // Using nvcc --use_fast_math yields the same speed between the two versions. // So we don't combine the two terms. #pragma unroll for (int item = 0; item < items_per_thread; ++item) { int N_idx = item * blockDimx + threadIdx.x; scalar_t v_t = v_smem[N_idx], w_t = w_smem[N_idx]; result += v_t / (z_t - w_t) + std::conj(v_t) / (z_t - std::conj(w_t)); } } // TODO: this only works for N a power of 2 #pragma unroll for (int offset = blockDimx / 2; offset > 0; offset /= 2) { result += WARP_SHFL_DOWN(result, offset); } if ((threadIdx.x == 0) && (L_idx < L)) { out_smem[threadIdx.y] = result; } __syncthreads(); if (tid < blockDim.y && L_block_start + tid < L) { out[batch_idx][L_block_start + tid] = out_smem[tid]; } } torch::Tensor cauchy_mult_sym_fwd_cuda(torch::Tensor v, torch::Tensor z, torch::Tensor w) { const int batch_size = v.size(0); const int N = v.size(1); const int L = z.size(0); auto out = torch::empty({batch_size, L}, torch::dtype(v.dtype()).device(v.device())); auto stream = at::cuda::getCurrentCUDAStream(); using scalar_t = c10::complex; const auto v_a = v.packed_accessor32(); const auto z_a = z.packed_accessor32(); const auto w_a = w.packed_accessor32(); auto out_a = out.packed_accessor32(); int log_N = int(log2((double) N)); int block_x = div_up(N, ITEMS_PER_THREAD_SYM_FWD[log_N - 1]); dim3 block(block_x, MAX_BLOCK_SIZE / block_x); dim3 grid(batch_size, div_up(L, block.y)); switch (log_N) { #define CASE_LOG_N(log_N_val) case log_N_val: \ cauchy_mult_sym_fwd_cuda_kernel \ <<>>(v_a, z_a, w_a, out_a, L); break; MAP(CASE_LOG_N, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) } #undef CASE_LOG_N C10_CUDA_KERNEL_LAUNCH_CHECK(); return out; } template __global__ void cauchy_mult_sym_bwd_cuda_kernel(CudaAcsr v, CudaAcsr z, CudaAcsr w, CudaAcsr dout, CudaAcsr dv, CudaAcsr dw, int L, int L_chunk_size) { // We just want a shared array: // __shared__ scalar_t s_b[16]; // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270 // So we declare a char array and cast it. // The casting is subtle: https://stackoverflow.com/questions/12692310/convert-array-to-two-dimensional-array-by-pointer __shared__ char dv_smem_char[C10_WARP_SIZE * sizeof(scalar_t)]; scalar_t *dv_smem = (scalar_t *)&dv_smem_char; __shared__ char dw_smem_char[C10_WARP_SIZE * sizeof(scalar_t)]; scalar_t *dw_smem = (scalar_t *)&dw_smem_char; int batch_idx = blockIdx.x; int N_idx = blockIdx.y; int L_chunk_idx = blockIdx.z; int tid = threadIdx.x; scalar_t w_conj_t = std::conj(w[batch_idx][N_idx]); scalar_t dv_t = 0; scalar_t dw_t = 0; #pragma unroll for (int item = 0; item < ITEMS_PER_THREAD_SYM_BWD; ++item) { int l = L_chunk_idx * L_chunk_size + item * blockDim.x + threadIdx.x; scalar_t dout_t, z_t; if (check_L_boundary) { dout_t = l < L ? dout[batch_idx][l] : 0; z_t = l < L ? z[l] : 1; } else { // Not checking boundary can speed it up quite a bit, around 30%. dout_t = dout[batch_idx][l]; z_t = z[l]; } scalar_t denom_1 = std::conj(z_t) - w_conj_t; scalar_t denom_2 = z_t - w_conj_t; scalar_t term_1 = dout_t / denom_1; scalar_t term_2 = std::conj(dout_t) / denom_2; dv_t += term_1 + term_2; dw_t += term_1 / denom_1 + term_2 / denom_2; } dv_t = at::native::cuda_utils::BlockReduceSum(dv_t, dv_smem); dw_t = at::native::cuda_utils::BlockReduceSum(dw_t, dw_smem); if (tid == 0) { dw[batch_idx][N_idx][L_chunk_idx] = dw_t * std::conj(v[batch_idx][N_idx]); dv[batch_idx][N_idx][L_chunk_idx] = dv_t; } } std::tuple cauchy_mult_sym_bwd_cuda(torch::Tensor v, torch::Tensor z, torch::Tensor w, torch::Tensor dout) { const int batch_size = v.size(0); const int N = v.size(1); const int L = z.size(0); constexpr int MAX_BLOCK_SIZE = 1024; constexpr int MAX_L_CHUNK_SIZE = ITEMS_PER_THREAD_SYM_BWD * MAX_BLOCK_SIZE; const int n_L_chunks = div_up(L, MAX_L_CHUNK_SIZE); auto dv = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(v.dtype()).device(v.device())); auto dw = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(w.dtype()).device(w.device())); auto stream = at::cuda::getCurrentCUDAStream(); using scalar_t = c10::complex; const auto v_a = v.packed_accessor32(); const auto z_a = z.packed_accessor32(); const auto w_a = w.packed_accessor32(); const auto dout_a = dout.packed_accessor32(); auto dv_a = dv.packed_accessor32(); auto dw_a = dw.packed_accessor32(); // Each block need to have a multiple of 32 threads, otherwise // at::native::cuda_utils::BlockReduceSum to produce wrong result. // int block_x = max(div_up(L, ITEMS_PER_THREAD_SYM_BWD), C10_WARP_SIZE); const int L_chunk_size = min(L, MAX_L_CHUNK_SIZE); int block_x = div_up(L_chunk_size, ITEMS_PER_THREAD_SYM_BWD * C10_WARP_SIZE) * C10_WARP_SIZE; bool check_L_boundary = L != block_x * ITEMS_PER_THREAD_SYM_BWD * n_L_chunks; dim3 block(block_x); dim3 grid(batch_size, N, n_L_chunks); check_L_boundary ? cauchy_mult_sym_bwd_cuda_kernel <<>>(v_a, z_a, w_a, dout_a, dv_a, dw_a, L, L_chunk_size) : cauchy_mult_sym_bwd_cuda_kernel <<>>(v_a, z_a, w_a, dout_a, dv_a, dw_a, L, L_chunk_size); C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(dv.sum(-1), dw.sum(-1)); } template __global__ void vand_log_mult_sym_fwd_cuda_kernel(CudaAcsr, 2> v, CudaAcsr, 2> x, CudaAcsr out, int L) { using cfloat_t = typename c10::complex; constexpr int N = 1 << log_N; constexpr int blockDimx = div_up_const(N, items_per_thread); constexpr int blockDimy = MAX_BLOCK_SIZE / blockDimx; // We just want a shared array: // __shared__ cfloat_t s_b[16]; // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270 // So we declare a char array and cast it. // The casting is subtle: https://stackoverflow.com/questions/12692310/convert-array-to-two-dimensional-array-by-pointer __shared__ char v_smem_char[N * sizeof(cfloat_t)]; cfloat_t *v_smem = (cfloat_t *)&v_smem_char; __shared__ char x_smem_char[N * sizeof(cfloat_t)]; cfloat_t *x_smem = (cfloat_t *)&x_smem_char; __shared__ float out_smem[blockDimy]; int batch_idx = blockIdx.x; int tid = threadIdx.x + threadIdx.y * blockDim.x; int L_idx = blockIdx.y * blockDim.y + threadIdx.y; int L_block_start = blockIdx.y * blockDim.y; for (int N_idx = threadIdx.x + threadIdx.y * blockDim.x; N_idx < N; N_idx += blockDim.x * blockDim.y) { v_smem[N_idx] = v[batch_idx][N_idx]; x_smem[N_idx] = x[batch_idx][N_idx]; } __syncthreads(); float result = 0; if (L_idx < L) { #pragma unroll for (int item = 0; item < items_per_thread; ++item) { int N_idx = item * blockDimx + threadIdx.x; cfloat_t v_t = v_smem[N_idx], x_t = x_smem[N_idx]; result += (std::exp(x_t * L_idx) * v_t).real_; } } // TODO: this only works for N a power of 2 #pragma unroll for (int offset = blockDimx / 2; offset > 0; offset /= 2) { result += WARP_SHFL_DOWN(result, offset); } if ((threadIdx.x == 0) && (L_idx < L)) { out_smem[threadIdx.y] = 2 * result; } __syncthreads(); if (tid < blockDim.y && L_block_start + tid < L) { out[batch_idx][L_block_start + tid] = out_smem[tid]; } } torch::Tensor vand_log_mult_sym_fwd_cuda(torch::Tensor v, torch::Tensor x, int L) { const int batch_size = v.size(0); const int N = v.size(1); auto opts = v.options(); auto out = torch::empty({batch_size, L}, opts.dtype(torch::kFloat32)); auto stream = at::cuda::getCurrentCUDAStream(); const auto v_a = v.packed_accessor32, 2, at::RestrictPtrTraits>(); const auto x_a = x.packed_accessor32, 2, at::RestrictPtrTraits>(); auto out_a = out.packed_accessor32(); int log_N = int(log2((double) N)); int block_x = div_up(N, ITEMS_PER_THREAD_SYM_FWD[log_N - 1]); dim3 block(block_x, MAX_BLOCK_SIZE / block_x); dim3 grid(batch_size, div_up(L, block.y)); switch (log_N) { #define CASE_LOG_N(log_N_val) case log_N_val: \ vand_log_mult_sym_fwd_cuda_kernel \ <<>>(v_a, x_a, out_a, L); break; MAP(CASE_LOG_N, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) } #undef CASE_LOG_N C10_CUDA_KERNEL_LAUNCH_CHECK(); return out; } template __global__ void vand_log_mult_sym_bwd_cuda_kernel(CudaAcsr, 2> v, CudaAcsr, 2> x, CudaAcsr dout, CudaAcsr, 3> dv, CudaAcsr, 3> dx, int L, int L_chunk_size) { using cfloat_t = typename c10::complex; // We just want a shared array: // __shared__ c10::complex s_b[16]; // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270 // So we declare a char array and cast it. // The casting is subtle: https://stackoverflow.com/questions/12692310/convert-array-to-two-dimensional-array-by-pointer __shared__ char dv_smem_char[C10_WARP_SIZE * sizeof(cfloat_t)]; cfloat_t *dv_smem = (cfloat_t *)&dv_smem_char; __shared__ char dx_smem_char[C10_WARP_SIZE * sizeof(cfloat_t)]; cfloat_t *dx_smem = (cfloat_t *)&dx_smem_char; int batch_idx = blockIdx.x; int N_idx = blockIdx.y; int L_chunk_idx = blockIdx.z; int tid = threadIdx.x; cfloat_t x_t = x[batch_idx][N_idx]; cfloat_t dv_t = 0; cfloat_t dx_t = 0; #pragma unroll for (int item = 0; item < ITEMS_PER_THREAD_SYM_BWD; ++item) { int l = L_chunk_idx * L_chunk_size + item * blockDim.x + threadIdx.x; float dout_t; if (check_L_boundary) { dout_t = l < L ? dout[batch_idx][l] : 0; } else { // Not checking boundary can speed it up quite a bit. dout_t = dout[batch_idx][l]; } // Need to conjugate as we're doing complex gradient. cfloat_t do_exp_x_t = dout_t * std::conj(std::exp(x_t * l)); dv_t += do_exp_x_t; dx_t += do_exp_x_t * l; } dv_t = at::native::cuda_utils::BlockReduceSum(dv_t, dv_smem); dx_t = at::native::cuda_utils::BlockReduceSum(dx_t, dx_smem); if (tid == 0) { dx[batch_idx][N_idx][L_chunk_idx] = 2 * dx_t * std::conj(v[batch_idx][N_idx]); dv[batch_idx][N_idx][L_chunk_idx] = 2 * dv_t; } } std::tuple vand_log_mult_sym_bwd_cuda(torch::Tensor v, torch::Tensor x, torch::Tensor dout) { const int batch_size = v.size(0); const int N = v.size(1); const int L = dout.size(1); constexpr int MAX_BLOCK_SIZE = 1024; constexpr int MAX_L_CHUNK_SIZE = ITEMS_PER_THREAD_SYM_BWD * MAX_BLOCK_SIZE; const int n_L_chunks = div_up(L, MAX_L_CHUNK_SIZE); auto dv = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(v.dtype()).device(v.device())); auto dx = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(x.dtype()).device(x.device())); auto stream = at::cuda::getCurrentCUDAStream(); using cfloat_t = c10::complex; const auto v_a = v.packed_accessor32(); const auto x_a = x.packed_accessor32(); const auto dout_a = dout.packed_accessor32(); auto dv_a = dv.packed_accessor32(); auto dx_a = dx.packed_accessor32(); // Each block need to have a multiple of 32 threads, otherwise // at::native::cuda_utils::BlockReduceSum to produce wrong result. // int block_x = max(div_up(L, ITEMS_PER_THREAD_SYM_BWD), C10_WARP_SIZE); const int L_chunk_size = min(L, MAX_L_CHUNK_SIZE); int block_x = div_up(L_chunk_size, ITEMS_PER_THREAD_SYM_BWD * C10_WARP_SIZE) * C10_WARP_SIZE; bool check_L_boundary = L != block_x * ITEMS_PER_THREAD_SYM_BWD * n_L_chunks; dim3 block(block_x); dim3 grid(batch_size, N, n_L_chunks); check_L_boundary ? vand_log_mult_sym_bwd_cuda_kernel <<>>(v_a, x_a, dout_a, dv_a, dx_a, L, L_chunk_size) : vand_log_mult_sym_bwd_cuda_kernel <<>>(v_a, x_a, dout_a, dv_a, dx_a, L, L_chunk_size); C10_CUDA_KERNEL_LAUNCH_CHECK(); return std::make_tuple(dv.sum(-1), dx.sum(-1)); }