| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368 |
- #include <stdio.h>
- // On pytorch 1.10 and CUDA 10.2, I get compilation errors on torch/csrc/api/include/torch/nn/cloneable.h
- // So we'll only include torch/python.h instead of torch/extension.h
- // Similar to https://github.com/getkeops/keops/blob/3efd428b55c724b12f23982c06de00bc4d02d903/pykeops/torch_headers.h.in#L8
- // #include <torch/extension.h>
- #include <torch/python.h>
- #include <ATen/cuda/CUDAContext.h> // For getCurrentCUDAStream
- #include <THC/THCAtomics.cuh> // For atomicAdd on complex
- #include <ATen/native/cuda/block_reduce.cuh>
- #include <c10/util/complex.h> // For scalar_value_type
- #include "map.h" // For the MAP macro, i.e. for_each over the arguments
// ---------------------------------------------------------------------------
// Compile-time tuning parameters. Each *_VALUE(S) macro may be predefined
// (e.g. with -D on the compiler command line) to override the default below.
// ---------------------------------------------------------------------------

// Per-thread item counts for the forward kernels. The table is indexed with
// log2(N) - 1, and the launchers dispatch log2(N) over 1..10, so at least ten
// entries are required.
#ifndef ITEMS_PER_THREAD_SYM_FWD_VALUES
#define ITEMS_PER_THREAD_SYM_FWD_VALUES {2, 4, 8, 16, 32, 32, 32, 64, 64, 64}
#endif

// Total number of threads per block used by the forward kernels
// (block.x * block.y, with block.y = MAX_BLOCK_SIZE / block.x).
#ifndef MAX_BLOCK_SIZE_VALUE
#define MAX_BLOCK_SIZE_VALUE 256
#endif

// Number of L positions each thread accumulates in the backward kernels.
#ifndef ITEMS_PER_THREAD_SYM_BWD_VALUE
#define ITEMS_PER_THREAD_SYM_BWD_VALUE 32
#endif

static constexpr int ITEMS_PER_THREAD_SYM_FWD[] = ITEMS_PER_THREAD_SYM_FWD_VALUES;
static constexpr int MAX_BLOCK_SIZE = MAX_BLOCK_SIZE_VALUE;
static constexpr int ITEMS_PER_THREAD_SYM_BWD = ITEMS_PER_THREAD_SYM_BWD_VALUE;

// Fail early, with a readable message, when an override is unusable.
static_assert(sizeof(ITEMS_PER_THREAD_SYM_FWD) / sizeof(ITEMS_PER_THREAD_SYM_FWD[0]) >= 10,
              "ITEMS_PER_THREAD_SYM_FWD_VALUES needs one entry for each log2(N) in 1..10");
static_assert(MAX_BLOCK_SIZE >= 1 && MAX_BLOCK_SIZE <= 1024,
              "MAX_BLOCK_SIZE_VALUE must be a valid CUDA block size (1..1024)");
static_assert(ITEMS_PER_THREAD_SYM_BWD >= 1,
              "ITEMS_PER_THREAD_SYM_BWD_VALUE must be positive");
// Shorthand for the accessor type the kernels take as arguments: a packed
// tensor accessor over `Elem` with `Rank` dimensions, restrict-qualified
// pointers and 32-bit indices.
template <typename Elem, size_t Rank>
using CudaAcsr =
    at::GenericPackedTensorAccessor<Elem, Rank, at::RestrictPtrTraits, int32_t>;
// Integer division of a by b, rounded up. constexpr so it can be used for
// compile-time sizes such as block dimensions.
constexpr __host__ __device__ int div_up_const(int a, int b) {
  return (a + b - 1) / b;
}
// Run-time counterpart of div_up_const: integer division of a by b, rounded up.
__host__ __device__ static inline int div_up(int a, int b) {
  const int biased = a + b - 1;
  return biased / b;
}
/**
 * Forward kernel. With N = 1 << log_N, for batch row b = blockIdx.x and every
 * l < L it writes
 *
 *   out[b][l] = sum_{n < N} (       v[b][n]  / (z[l] -      w[b][n])
 *                            + conj(v[b][n]) / (z[l] - conj(w[b][n])) )
 *
 * Expected launch layout (this is what cauchy_mult_sym_fwd_cuda sets up):
 *   block = (blockDimx, MAX_BLOCK_SIZE / blockDimx)
 *           with blockDimx = ceil(N / items_per_thread)
 *   grid  = (batch, ceil(L / block.y))
 * Every threadIdx.y owns one l; the blockDimx threads sharing that
 * threadIdx.y each sum items_per_thread terms and combine their partial sums
 * with warp shuffles.
 *
 * Static shared memory: N elements for v, N for w, blockDimy for the results.
 */
template <typename scalar_t, int log_N,
          int items_per_thread=ITEMS_PER_THREAD_SYM_FWD[log_N - 1]>
__global__ void cauchy_mult_sym_fwd_cuda_kernel(CudaAcsr<scalar_t, 2> v,
                                                CudaAcsr<scalar_t, 1> z,
                                                CudaAcsr<scalar_t, 2> w,
                                                CudaAcsr<scalar_t, 2> out,
                                                int L) {
  constexpr int N = 1 << log_N;
  constexpr int blockDimx = div_up_const(N, items_per_thread);
  constexpr int blockDimy = MAX_BLOCK_SIZE / blockDimx;
  // The accumulation loop reads smem[item * blockDimx + threadIdx.x] for every
  // item, so the tiling has to cover N exactly or it would read past the buffers.
  static_assert(blockDimx * items_per_thread == N,
                "items_per_thread must divide N exactly");
  // The shuffle reduction below only combines lanes of a single warp.
  static_assert(blockDimx <= C10_WARP_SIZE,
                "a row of blockDimx threads must fit inside one warp");
  static_assert(blockDimy >= 1, "MAX_BLOCK_SIZE is smaller than blockDimx");

  // We just want a shared array:
  //   __shared__ scalar_t s_b[16];
  // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270
  // So we declare byte arrays and cast them. A plain char array only promises
  // 1-byte alignment, hence the explicit __align__ before reinterpreting it as
  // scalar_t.
  __shared__ __align__(sizeof(scalar_t)) char v_smem_char[N * sizeof(scalar_t)];
  scalar_t *v_smem = reinterpret_cast<scalar_t *>(v_smem_char);
  __shared__ __align__(sizeof(scalar_t)) char w_smem_char[N * sizeof(scalar_t)];
  scalar_t *w_smem = reinterpret_cast<scalar_t *>(w_smem_char);
  __shared__ __align__(sizeof(scalar_t)) char out_smem_char[blockDimy * sizeof(scalar_t)];
  scalar_t *out_smem = reinterpret_cast<scalar_t *>(out_smem_char);

  int batch_idx = blockIdx.x;
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int L_block_start = blockIdx.y * blockDim.y;
  int L_idx = L_block_start + threadIdx.y;
  // Threads past the end of z get a placeholder; they never accumulate or store.
  scalar_t z_t = L_idx < L ? z[L_idx] : scalar_t(0.f);

  // All threads of the block cooperatively stage row batch_idx of v and w.
  for (int N_idx = tid; N_idx < N; N_idx += blockDim.x * blockDim.y) {
    v_smem[N_idx] = v[batch_idx][N_idx];
    w_smem[N_idx] = w[batch_idx][N_idx];
  }
  __syncthreads();

  scalar_t result = 0;
  if (L_idx < L) {
    // Combining the two terms (a/b + c/d = (ad + bc)/(bd)) seems to increase numerical errors.
    // Using nvcc --use_fast_math yields the same speed between the two versions.
    // So we don't combine the two terms.
    #pragma unroll
    for (int item = 0; item < items_per_thread; ++item) {
      int N_idx = item * blockDimx + threadIdx.x;
      scalar_t v_t = v_smem[N_idx], w_t = w_smem[N_idx];
      result += v_t / (z_t - w_t) + std::conj(v_t) / (z_t - std::conj(w_t));
    }
  }

  // Tree reduction across the blockDimx threads of a row; executed by every
  // thread so that all lanes take part in the shuffles. Only threadIdx.x == 0
  // ends up with the full sum.
  // TODO: this only works for N a power of 2
  #pragma unroll
  for (int offset = blockDimx / 2; offset > 0; offset /= 2) {
    result += WARP_SHFL_DOWN(result, offset);
  }
  if ((threadIdx.x == 0) && (L_idx < L)) {
    out_smem[threadIdx.y] = result;
  }
  __syncthreads();

  // The first blockDim.y threads (by flat id) write the block's results, so
  // consecutive threads store to consecutive l.
  if (tid < blockDim.y && L_block_start + tid < L) {
    out[batch_idx][L_block_start + tid] = out_smem[tid];
  }
}
/**
 * Host launcher for cauchy_mult_sym_fwd_cuda_kernel.
 *
 * Takes batch_size and N from the first two dimensions of v and L from the
 * first dimension of z, allocates a {batch_size, L} output with the dtype and
 * device of v, and launches the kernel on the current CUDA stream. v and w
 * are accessed as rank-2 and z as rank-1 c10::complex<float> tensors.
 *
 * N must be a power of two in [2, 1024]: the kernel is only instantiated for
 * log2(N) in 1..10 and the per-thread item table is indexed with log2(N) - 1.
 * Any other N raises an error instead of indexing outside that table or
 * returning an uninitialized output.
 *
 * @return the {batch_size, L} result; returned without launching anything
 *         when batch_size or L is 0 (a grid dimension of 0 is not launchable).
 */
torch::Tensor cauchy_mult_sym_fwd_cuda(torch::Tensor v,
                                       torch::Tensor z,
                                       torch::Tensor w) {
  const int batch_size = v.size(0);
  const int N = v.size(1);
  const int L = z.size(0);
  constexpr int MAX_LOG_N = 10;  // largest log2(N) dispatched below
  TORCH_CHECK(N >= 2 && N <= (1 << MAX_LOG_N) && (N & (N - 1)) == 0,
              "cauchy_mult_sym_fwd_cuda: N must be a power of two between 2 and ",
              1 << MAX_LOG_N, ", but got ", N);
  auto out = torch::empty({batch_size, L}, torch::dtype(v.dtype()).device(v.device()));
  if (batch_size == 0 || L == 0) {
    return out;  // nothing to compute
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  using scalar_t = c10::complex<float>;
  const auto v_a = v.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>();
  const auto z_a = z.packed_accessor32<scalar_t, 1, at::RestrictPtrTraits>();
  const auto w_a = w.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>();
  auto out_a = out.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>();
  // Exact integer log2; N was checked to be a power of two above.
  int log_N = 0;
  while ((1 << log_N) < N) {
    ++log_N;
  }
  // Must mirror the compile-time blockDimx / blockDimy computed in the kernel.
  int block_x = div_up(N, ITEMS_PER_THREAD_SYM_FWD[log_N - 1]);
  dim3 block(block_x, MAX_BLOCK_SIZE / block_x);
  dim3 grid(batch_size, div_up(L, block.y));
  switch (log_N) {
    #define CASE_LOG_N(log_N_val) case log_N_val: \
      cauchy_mult_sym_fwd_cuda_kernel<scalar_t, log_N_val> \
        <<<grid, block, 0, stream>>>(v_a, z_a, w_a, out_a, L); break;
    MAP(CASE_LOG_N, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
    default:
      TORCH_CHECK(false, "cauchy_mult_sym_fwd_cuda: unsupported log2(N) = ", log_N);
  }
  #undef CASE_LOG_N
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return out;
}
/**
 * Backward kernel. One block handles one (batch, n, L-chunk) triple taken from
 * (blockIdx.x, blockIdx.y, blockIdx.z) and writes the partial sums over the
 * l values of that chunk:
 *
 *   term_1(l) = dout[b][l]       / (conj(z[l]) - conj(w[b][n]))
 *   term_2(l) = conj(dout[b][l]) / (z[l]       - conj(w[b][n]))
 *
 *   dv[b][n][chunk] = sum_l (term_1 + term_2)
 *   dw[b][n][chunk] = conj(v[b][n]) *
 *       sum_l ( term_1 / (conj(z[l]) - conj(w[b][n]))
 *             + term_2 / (z[l]       - conj(w[b][n])) )
 *
 * The block is one-dimensional. Thread t visits
 *   l = chunk * L_chunk_size + item * blockDim.x + t
 * for item in [0, ITEMS_PER_THREAD_SYM_BWD).
 *
 * check_L_boundary = true guards against l >= L: such positions use dout = 0
 * and z = 1. With check_L_boundary = false every visited l must be < L.
 * NOTE(review): a padded position still evaluates 0 / (1 - conj(w)), which is
 * only harmless while w != 1 - confirm this cannot occur for real inputs.
 *
 * Static shared memory: two scratch buffers of C10_WARP_SIZE elements each,
 * handed to BlockReduceSum.
 */
template <typename scalar_t, bool check_L_boundary>
__global__ void cauchy_mult_sym_bwd_cuda_kernel(CudaAcsr<scalar_t, 2> v,
                                                CudaAcsr<scalar_t, 1> z,
                                                CudaAcsr<scalar_t, 2> w,
                                                CudaAcsr<scalar_t, 2> dout,
                                                CudaAcsr<scalar_t, 3> dv,
                                                CudaAcsr<scalar_t, 3> dw,
                                                int L,
                                                int L_chunk_size) {
  // We just want a shared array:
  //   __shared__ scalar_t s_b[16];
  // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270
  // So we declare byte arrays and cast them. A plain char array only promises
  // 1-byte alignment, hence the explicit __align__ before reinterpreting it as
  // scalar_t.
  __shared__ __align__(sizeof(scalar_t)) char dv_smem_char[C10_WARP_SIZE * sizeof(scalar_t)];
  scalar_t *dv_smem = reinterpret_cast<scalar_t *>(dv_smem_char);
  __shared__ __align__(sizeof(scalar_t)) char dw_smem_char[C10_WARP_SIZE * sizeof(scalar_t)];
  scalar_t *dw_smem = reinterpret_cast<scalar_t *>(dw_smem_char);

  int batch_idx = blockIdx.x;
  int N_idx = blockIdx.y;
  int L_chunk_idx = blockIdx.z;
  int tid = threadIdx.x;
  scalar_t w_conj_t = std::conj(w[batch_idx][N_idx]);
  scalar_t dv_t = 0;
  scalar_t dw_t = 0;
  #pragma unroll
  for (int item = 0; item < ITEMS_PER_THREAD_SYM_BWD; ++item) {
    int l = L_chunk_idx * L_chunk_size + item * blockDim.x + threadIdx.x;
    scalar_t dout_t, z_t;
    if (check_L_boundary) {
      dout_t = l < L ? dout[batch_idx][l] : 0;
      z_t = l < L ? z[l] : 1;
    } else { // Not checking boundary can speed it up quite a bit, around 30%.
      dout_t = dout[batch_idx][l];
      z_t = z[l];
    }
    scalar_t denom_1 = std::conj(z_t) - w_conj_t;
    scalar_t denom_2 = z_t - w_conj_t;
    scalar_t term_1 = dout_t / denom_1;
    scalar_t term_2 = std::conj(dout_t) / denom_2;
    dv_t += term_1 + term_2;
    dw_t += term_1 / denom_1 + term_2 / denom_2;
  }
  // Sum the per-thread partials over the whole block.
  dv_t = at::native::cuda_utils::BlockReduceSum<scalar_t>(dv_t, dv_smem);
  dw_t = at::native::cuda_utils::BlockReduceSum<scalar_t>(dw_t, dw_smem);
  // Thread 0 stores the block's totals.
  if (tid == 0) {
    dw[batch_idx][N_idx][L_chunk_idx] = dw_t * std::conj(v[batch_idx][N_idx]);
    dv[batch_idx][N_idx][L_chunk_idx] = dv_t;
  }
}
/**
 * Host launcher for cauchy_mult_sym_bwd_cuda_kernel.
 *
 * Takes batch_size and N from the first two dimensions of v and L from the
 * first dimension of z. The L axis is split into chunks of at most
 * ITEMS_PER_THREAD_SYM_BWD * 1024 positions; the kernel writes one partial
 * result per (batch, n, chunk) and the chunk axis is summed here.
 *
 * v, w and dout are accessed as rank-2 and z as rank-1 c10::complex<float>
 * tensors.
 *
 * @return (dv, dw), each of shape {batch_size, N}. When batch_size, N or L is
 *         0 the (empty or all-zero) sums are returned without launching a
 *         kernel, because a block or grid dimension of 0 is not launchable.
 */
std::tuple<torch::Tensor, torch::Tensor>
cauchy_mult_sym_bwd_cuda(torch::Tensor v,
                         torch::Tensor z,
                         torch::Tensor w,
                         torch::Tensor dout) {
  const int batch_size = v.size(0);
  const int N = v.size(1);
  const int L = z.size(0);
  // Block-size cap for the backward kernel; named differently from the
  // file-level MAX_BLOCK_SIZE, which has a different value.
  constexpr int BWD_MAX_BLOCK_SIZE = 1024;
  constexpr int MAX_L_CHUNK_SIZE = ITEMS_PER_THREAD_SYM_BWD * BWD_MAX_BLOCK_SIZE;
  const int n_L_chunks = div_up(L, MAX_L_CHUNK_SIZE);
  auto dv = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(v.dtype()).device(v.device()));
  auto dw = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(w.dtype()).device(w.device()));
  if (dv.numel() == 0) {
    // Summing the empty chunk axis (L == 0) yields zeros of shape
    // {batch_size, N}; an empty batch or N yields an empty result.
    return std::make_tuple(dv.sum(-1), dw.sum(-1));
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  using scalar_t = c10::complex<float>;
  const auto v_a = v.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>();
  const auto z_a = z.packed_accessor32<scalar_t, 1, at::RestrictPtrTraits>();
  const auto w_a = w.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>();
  const auto dout_a = dout.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>();
  auto dv_a = dv.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>();
  auto dw_a = dw.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>();
  // Each block needs a multiple of C10_WARP_SIZE threads, otherwise
  // at::native::cuda_utils::BlockReduceSum produces a wrong result.
  const int L_chunk_size = L < MAX_L_CHUNK_SIZE ? L : MAX_L_CHUNK_SIZE;
  const int block_x = div_up(L_chunk_size, ITEMS_PER_THREAD_SYM_BWD * C10_WARP_SIZE) * C10_WARP_SIZE;
  // The unchecked kernel may only be used when the launched threads cover
  // exactly L positions. The product is formed in 64 bits so it cannot wrap.
  const int64_t covered = static_cast<int64_t>(block_x) * ITEMS_PER_THREAD_SYM_BWD * n_L_chunks;
  const bool check_L_boundary = covered != L;
  dim3 block(block_x);
  dim3 grid(batch_size, N, n_L_chunks);
  if (check_L_boundary) {
    cauchy_mult_sym_bwd_cuda_kernel<scalar_t, true>
      <<<grid, block, 0, stream>>>(v_a, z_a, w_a, dout_a, dv_a, dw_a, L, L_chunk_size);
  } else {
    cauchy_mult_sym_bwd_cuda_kernel<scalar_t, false>
      <<<grid, block, 0, stream>>>(v_a, z_a, w_a, dout_a, dv_a, dw_a, L, L_chunk_size);
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return std::make_tuple(dv.sum(-1), dw.sum(-1));
}
/**
 * Forward kernel. With N = 1 << log_N, for batch row b = blockIdx.x and every
 * l < L it writes
 *
 *   out[b][l] = 2 * sum_{n < N} Re( exp(x[b][n] * l) * v[b][n] )
 *
 * Expected launch layout (this is what vand_log_mult_sym_fwd_cuda sets up):
 *   block = (blockDimx, MAX_BLOCK_SIZE / blockDimx)
 *           with blockDimx = ceil(N / items_per_thread)
 *   grid  = (batch, ceil(L / block.y))
 * Every threadIdx.y owns one l; the blockDimx threads sharing that
 * threadIdx.y each sum items_per_thread terms and combine their partial sums
 * with warp shuffles.
 *
 * Static shared memory: N complex elements for v, N for x, blockDimy floats
 * for the results.
 */
template <int log_N, int items_per_thread=ITEMS_PER_THREAD_SYM_FWD[log_N - 1]>
__global__ void vand_log_mult_sym_fwd_cuda_kernel(CudaAcsr<c10::complex<float>, 2> v,
                                                  CudaAcsr<c10::complex<float>, 2> x,
                                                  CudaAcsr<float, 2> out,
                                                  int L) {
  using cfloat_t = typename c10::complex<float>;
  constexpr int N = 1 << log_N;
  constexpr int blockDimx = div_up_const(N, items_per_thread);
  constexpr int blockDimy = MAX_BLOCK_SIZE / blockDimx;
  // The accumulation loop reads smem[item * blockDimx + threadIdx.x] for every
  // item, so the tiling has to cover N exactly or it would read past the buffers.
  static_assert(blockDimx * items_per_thread == N,
                "items_per_thread must divide N exactly");
  // The shuffle reduction below only combines lanes of a single warp.
  static_assert(blockDimx <= C10_WARP_SIZE,
                "a row of blockDimx threads must fit inside one warp");
  static_assert(blockDimy >= 1, "MAX_BLOCK_SIZE is smaller than blockDimx");

  // We just want a shared array:
  //   __shared__ cfloat_t s_b[16];
  // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270
  // So we declare byte arrays and cast them. A plain char array only promises
  // 1-byte alignment, hence the explicit __align__ before reinterpreting it as
  // cfloat_t.
  __shared__ __align__(sizeof(cfloat_t)) char v_smem_char[N * sizeof(cfloat_t)];
  cfloat_t *v_smem = reinterpret_cast<cfloat_t *>(v_smem_char);
  __shared__ __align__(sizeof(cfloat_t)) char x_smem_char[N * sizeof(cfloat_t)];
  cfloat_t *x_smem = reinterpret_cast<cfloat_t *>(x_smem_char);
  __shared__ float out_smem[blockDimy];

  int batch_idx = blockIdx.x;
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int L_idx = blockIdx.y * blockDim.y + threadIdx.y;
  int L_block_start = blockIdx.y * blockDim.y;

  // All threads of the block cooperatively stage row batch_idx of v and x.
  for (int N_idx = tid; N_idx < N; N_idx += blockDim.x * blockDim.y) {
    v_smem[N_idx] = v[batch_idx][N_idx];
    x_smem[N_idx] = x[batch_idx][N_idx];
  }
  __syncthreads();

  float result = 0;
  if (L_idx < L) {
    #pragma unroll
    for (int item = 0; item < items_per_thread; ++item) {
      int N_idx = item * blockDimx + threadIdx.x;
      cfloat_t v_t = v_smem[N_idx], x_t = x_smem[N_idx];
      result += (std::exp(x_t * L_idx) * v_t).real_;
    }
  }

  // Tree reduction across the blockDimx threads of a row; executed by every
  // thread so that all lanes take part in the shuffles. Only threadIdx.x == 0
  // ends up with the full sum.
  // TODO: this only works for N a power of 2
  #pragma unroll
  for (int offset = blockDimx / 2; offset > 0; offset /= 2) {
    result += WARP_SHFL_DOWN(result, offset);
  }
  if ((threadIdx.x == 0) && (L_idx < L)) {
    out_smem[threadIdx.y] = 2 * result;
  }
  __syncthreads();

  // The first blockDim.y threads (by flat id) write the block's results, so
  // consecutive threads store to consecutive l.
  if (tid < blockDim.y && L_block_start + tid < L) {
    out[batch_idx][L_block_start + tid] = out_smem[tid];
  }
}
/**
 * Host launcher for vand_log_mult_sym_fwd_cuda_kernel.
 *
 * Takes batch_size and N from the first two dimensions of v, allocates a
 * float32 {batch_size, L} output with the remaining options of v, and launches
 * the kernel on the current CUDA stream. v and x are accessed as rank-2
 * c10::complex<float> tensors.
 *
 * N must be a power of two in [2, 1024]: the kernel is only instantiated for
 * log2(N) in 1..10 and the per-thread item table is indexed with log2(N) - 1.
 * Any other N raises an error instead of indexing outside that table or
 * returning an uninitialized output.
 *
 * @param L number of output positions per batch row
 * @return the {batch_size, L} result; returned without launching anything
 *         when batch_size or L is 0 (a grid dimension of 0 is not launchable).
 */
torch::Tensor vand_log_mult_sym_fwd_cuda(torch::Tensor v, torch::Tensor x, int L) {
  const int batch_size = v.size(0);
  const int N = v.size(1);
  constexpr int MAX_LOG_N = 10;  // largest log2(N) dispatched below
  TORCH_CHECK(N >= 2 && N <= (1 << MAX_LOG_N) && (N & (N - 1)) == 0,
              "vand_log_mult_sym_fwd_cuda: N must be a power of two between 2 and ",
              1 << MAX_LOG_N, ", but got ", N);
  auto opts = v.options();
  auto out = torch::empty({batch_size, L}, opts.dtype(torch::kFloat32));
  if (batch_size == 0 || L == 0) {
    return out;  // nothing to compute
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  const auto v_a = v.packed_accessor32<c10::complex<float>, 2, at::RestrictPtrTraits>();
  const auto x_a = x.packed_accessor32<c10::complex<float>, 2, at::RestrictPtrTraits>();
  auto out_a = out.packed_accessor32<float, 2, at::RestrictPtrTraits>();
  // Exact integer log2; N was checked to be a power of two above.
  int log_N = 0;
  while ((1 << log_N) < N) {
    ++log_N;
  }
  // Must mirror the compile-time blockDimx / blockDimy computed in the kernel.
  int block_x = div_up(N, ITEMS_PER_THREAD_SYM_FWD[log_N - 1]);
  dim3 block(block_x, MAX_BLOCK_SIZE / block_x);
  dim3 grid(batch_size, div_up(L, block.y));
  switch (log_N) {
    #define CASE_LOG_N(log_N_val) case log_N_val: \
      vand_log_mult_sym_fwd_cuda_kernel<log_N_val> \
        <<<grid, block, 0, stream>>>(v_a, x_a, out_a, L); break;
    MAP(CASE_LOG_N, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
    default:
      TORCH_CHECK(false, "vand_log_mult_sym_fwd_cuda: unsupported log2(N) = ", log_N);
  }
  #undef CASE_LOG_N
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return out;
}
/**
 * Backward kernel. One block handles one (batch, n, L-chunk) triple taken from
 * (blockIdx.x, blockIdx.y, blockIdx.z) and writes the partial sums over the
 * l values of that chunk:
 *
 *   dv[b][n][chunk] = 2 * sum_l dout[b][l] * conj(exp(x[b][n] * l))
 *   dx[b][n][chunk] = 2 * (sum_l dout[b][l] * conj(exp(x[b][n] * l)) * l)
 *                       * conj(v[b][n])
 *
 * The block is one-dimensional. Thread t visits
 *   l = chunk * L_chunk_size + item * blockDim.x + t
 * for item in [0, ITEMS_PER_THREAD_SYM_BWD).
 *
 * check_L_boundary = true guards against l >= L by using dout = 0 there.
 * With check_L_boundary = false every visited l must be < L.
 *
 * Static shared memory: two scratch buffers of C10_WARP_SIZE complex elements
 * each, handed to BlockReduceSum.
 */
template <bool check_L_boundary>
__global__ void vand_log_mult_sym_bwd_cuda_kernel(CudaAcsr<c10::complex<float>, 2> v,
                                                  CudaAcsr<c10::complex<float>, 2> x,
                                                  CudaAcsr<float, 2> dout,
                                                  CudaAcsr<c10::complex<float>, 3> dv,
                                                  CudaAcsr<c10::complex<float>, 3> dx,
                                                  int L,
                                                  int L_chunk_size) {
  using cfloat_t = typename c10::complex<float>;
  // We just want a shared array:
  //   __shared__ c10::complex<float> s_b[16];
  // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270
  // So we declare byte arrays and cast them. A plain char array only promises
  // 1-byte alignment, hence the explicit __align__ before reinterpreting it as
  // cfloat_t.
  __shared__ __align__(sizeof(cfloat_t)) char dv_smem_char[C10_WARP_SIZE * sizeof(cfloat_t)];
  cfloat_t *dv_smem = reinterpret_cast<cfloat_t *>(dv_smem_char);
  __shared__ __align__(sizeof(cfloat_t)) char dx_smem_char[C10_WARP_SIZE * sizeof(cfloat_t)];
  cfloat_t *dx_smem = reinterpret_cast<cfloat_t *>(dx_smem_char);

  int batch_idx = blockIdx.x;
  int N_idx = blockIdx.y;
  int L_chunk_idx = blockIdx.z;
  int tid = threadIdx.x;
  cfloat_t x_t = x[batch_idx][N_idx];
  cfloat_t dv_t = 0;
  cfloat_t dx_t = 0;
  #pragma unroll
  for (int item = 0; item < ITEMS_PER_THREAD_SYM_BWD; ++item) {
    int l = L_chunk_idx * L_chunk_size + item * blockDim.x + threadIdx.x;
    float dout_t;
    if (check_L_boundary) {
      dout_t = l < L ? dout[batch_idx][l] : 0;
    } else { // Not checking boundary can speed it up quite a bit.
      dout_t = dout[batch_idx][l];
    }
    // Need to conjugate as we're doing complex gradient.
    cfloat_t do_exp_x_t = dout_t * std::conj(std::exp(x_t * l));
    dv_t += do_exp_x_t;
    dx_t += do_exp_x_t * l;
  }
  // Sum the per-thread partials over the whole block.
  dv_t = at::native::cuda_utils::BlockReduceSum<cfloat_t>(dv_t, dv_smem);
  dx_t = at::native::cuda_utils::BlockReduceSum<cfloat_t>(dx_t, dx_smem);
  // Thread 0 stores the block's totals.
  if (tid == 0) {
    dx[batch_idx][N_idx][L_chunk_idx] = 2 * dx_t * std::conj(v[batch_idx][N_idx]);
    dv[batch_idx][N_idx][L_chunk_idx] = 2 * dv_t;
  }
}
/**
 * Host launcher for vand_log_mult_sym_bwd_cuda_kernel.
 *
 * Takes batch_size and N from the first two dimensions of v and L from the
 * second dimension of dout. The L axis is split into chunks of at most
 * ITEMS_PER_THREAD_SYM_BWD * 1024 positions; the kernel writes one partial
 * result per (batch, n, chunk) and the chunk axis is summed here.
 *
 * v and x are accessed as rank-2 c10::complex<float> tensors, dout as a
 * rank-2 float tensor.
 *
 * @return (dv, dx), each of shape {batch_size, N}. When batch_size, N or L is
 *         0 the (empty or all-zero) sums are returned without launching a
 *         kernel, because a block or grid dimension of 0 is not launchable.
 */
std::tuple<torch::Tensor, torch::Tensor>
vand_log_mult_sym_bwd_cuda(torch::Tensor v,
                           torch::Tensor x,
                           torch::Tensor dout) {
  const int batch_size = v.size(0);
  const int N = v.size(1);
  const int L = dout.size(1);
  // Block-size cap for the backward kernel; named differently from the
  // file-level MAX_BLOCK_SIZE, which has a different value.
  constexpr int BWD_MAX_BLOCK_SIZE = 1024;
  constexpr int MAX_L_CHUNK_SIZE = ITEMS_PER_THREAD_SYM_BWD * BWD_MAX_BLOCK_SIZE;
  const int n_L_chunks = div_up(L, MAX_L_CHUNK_SIZE);
  auto dv = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(v.dtype()).device(v.device()));
  auto dx = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(x.dtype()).device(x.device()));
  if (dv.numel() == 0) {
    // Summing the empty chunk axis (L == 0) yields zeros of shape
    // {batch_size, N}; an empty batch or N yields an empty result.
    return std::make_tuple(dv.sum(-1), dx.sum(-1));
  }
  auto stream = at::cuda::getCurrentCUDAStream();
  using cfloat_t = c10::complex<float>;
  const auto v_a = v.packed_accessor32<cfloat_t, 2, at::RestrictPtrTraits>();
  const auto x_a = x.packed_accessor32<cfloat_t, 2, at::RestrictPtrTraits>();
  const auto dout_a = dout.packed_accessor32<float, 2, at::RestrictPtrTraits>();
  auto dv_a = dv.packed_accessor32<cfloat_t, 3, at::RestrictPtrTraits>();
  auto dx_a = dx.packed_accessor32<cfloat_t, 3, at::RestrictPtrTraits>();
  // Each block needs a multiple of C10_WARP_SIZE threads, otherwise
  // at::native::cuda_utils::BlockReduceSum produces a wrong result.
  const int L_chunk_size = L < MAX_L_CHUNK_SIZE ? L : MAX_L_CHUNK_SIZE;
  const int block_x = div_up(L_chunk_size, ITEMS_PER_THREAD_SYM_BWD * C10_WARP_SIZE) * C10_WARP_SIZE;
  // The unchecked kernel may only be used when the launched threads cover
  // exactly L positions. The product is formed in 64 bits so it cannot wrap.
  const int64_t covered = static_cast<int64_t>(block_x) * ITEMS_PER_THREAD_SYM_BWD * n_L_chunks;
  const bool check_L_boundary = covered != L;
  dim3 block(block_x);
  dim3 grid(batch_size, N, n_L_chunks);
  if (check_L_boundary) {
    vand_log_mult_sym_bwd_cuda_kernel<true>
      <<<grid, block, 0, stream>>>(v_a, x_a, dout_a, dv_a, dx_a, L, L_chunk_size);
  } else {
    vand_log_mult_sym_bwd_cuda_kernel<false>
      <<<grid, block, 0, stream>>>(v_a, x_a, dout_a, dv_a, dx_a, L, L_chunk_size);
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return std::make_tuple(dv.sum(-1), dx.sum(-1));
}
|