cauchy_cuda.cu 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. #include <stdio.h>
  2. // On pytorch 1.10 and CUDA 10.2, I get compilation errors on torch/csrc/api/include/torch/nn/cloneable.h
  3. // So we'll only include torch/python.h instead of torch/extension.h
  4. // Similar to https://github.com/getkeops/keops/blob/3efd428b55c724b12f23982c06de00bc4d02d903/pykeops/torch_headers.h.in#L8
  5. // #include <torch/extension.h>
  6. #include <torch/python.h>
  7. #include <ATen/cuda/CUDAContext.h> // For getCurrentCUDAStream
  8. #include <THC/THCAtomics.cuh> // For atomicAdd on complex
  9. #include <ATen/native/cuda/block_reduce.cuh>
  10. #include <c10/util/complex.h> // For scalar_value_type
  11. #include "map.h" // For the MAP macro, i.e. for_each over the arguments
  12. #ifndef ITEMS_PER_THREAD_SYM_FWD_VALUES
  13. #define ITEMS_PER_THREAD_SYM_FWD_VALUES {2, 4, 8, 16, 32, 32, 32, 64, 64, 64}
  14. #endif
  15. #ifndef MAX_BLOCK_SIZE_VALUE
  16. #define MAX_BLOCK_SIZE_VALUE 256
  17. #endif
  18. #ifndef ITEMS_PER_THREAD_SYM_BWD_VALUE
  19. #define ITEMS_PER_THREAD_SYM_BWD_VALUE 32
  20. #endif
  21. static constexpr int ITEMS_PER_THREAD_SYM_FWD[] = ITEMS_PER_THREAD_SYM_FWD_VALUES;
  22. static constexpr int MAX_BLOCK_SIZE = MAX_BLOCK_SIZE_VALUE;
  23. static constexpr int ITEMS_PER_THREAD_SYM_BWD = ITEMS_PER_THREAD_SYM_BWD_VALUE;
  24. template <typename T, size_t N>
  25. using CudaAcsr = at::GenericPackedTensorAccessor<T, N, at::RestrictPtrTraits, int32_t>;
  26. constexpr __host__ __device__ int div_up_const(int a, int b) { return (a + b - 1) / b; }
  27. __host__ __device__ static inline int div_up(int a, int b) { return (a + b - 1) / b;}
  28. template <typename scalar_t, int log_N,
  29. int items_per_thread=ITEMS_PER_THREAD_SYM_FWD[log_N - 1]>
  30. __global__ void cauchy_mult_sym_fwd_cuda_kernel(CudaAcsr<scalar_t, 2> v,
  31. CudaAcsr<scalar_t, 1> z,
  32. CudaAcsr<scalar_t, 2> w,
  33. CudaAcsr<scalar_t, 2> out,
  34. int L) {
  35. // Get the float type from the complex type
  36. // https://github.com/pytorch/pytorch/blob/bceb1db885cafa87fe8d037d8f22ae9649a1bba0/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp#L213
  37. using float_t = typename at::scalar_value_type<scalar_t>::type;
  38. constexpr int N = 1 << log_N;
  39. constexpr int blockDimx = div_up_const(N, items_per_thread);
  40. constexpr int blockDimy = MAX_BLOCK_SIZE / blockDimx;
  41. // We just want a shared array:
  42. // __shared__ scalar_t s_b[16];
  43. // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270
  44. // So we declare a char array and cast it.
  45. // The casting is subtle: https://stackoverflow.com/questions/12692310/convert-array-to-two-dimensional-array-by-pointer
  46. __shared__ char v_smem_char[N * sizeof(scalar_t)];
  47. scalar_t *v_smem = (scalar_t *)&v_smem_char;
  48. __shared__ char w_smem_char[N * sizeof(scalar_t)];
  49. scalar_t *w_smem = (scalar_t *)&w_smem_char;
  50. __shared__ char out_smem_char[blockDimy * sizeof(scalar_t)];
  51. scalar_t *out_smem = (scalar_t *)&out_smem_char;
  52. int batch_idx = blockIdx.x;
  53. int tid = threadIdx.x + threadIdx.y * blockDim.x;
  54. int L_idx = blockIdx.y * blockDim.y + threadIdx.y;
  55. int L_block_start = blockIdx.y * blockDim.y;
  56. scalar_t z_t = L_block_start + threadIdx.y < L ? z[L_block_start + threadIdx.y] : scalar_t(0.f);
  57. for (int N_idx = threadIdx.x + threadIdx.y * blockDim.x; N_idx < N; N_idx += blockDim.x * blockDim.y) {
  58. v_smem[N_idx] = v[batch_idx][N_idx];
  59. w_smem[N_idx] = w[batch_idx][N_idx];
  60. }
  61. __syncthreads();
  62. scalar_t result = 0;
  63. if (L_idx < L) {
  64. // Combining the two terms (a/b + c/d = (ad + bc)/(bd)) seems to increase numerical errors.
  65. // Using nvcc --use_fast_math yields the same speed between the two versions.
  66. // So we don't combine the two terms.
  67. #pragma unroll
  68. for (int item = 0; item < items_per_thread; ++item) {
  69. int N_idx = item * blockDimx + threadIdx.x;
  70. scalar_t v_t = v_smem[N_idx], w_t = w_smem[N_idx];
  71. result += v_t / (z_t - w_t) + std::conj(v_t) / (z_t - std::conj(w_t));
  72. }
  73. }
  74. // TODO: this only works for N a power of 2
  75. #pragma unroll
  76. for (int offset = blockDimx / 2; offset > 0; offset /= 2) {
  77. result += WARP_SHFL_DOWN(result, offset);
  78. }
  79. if ((threadIdx.x == 0) && (L_idx < L)) {
  80. out_smem[threadIdx.y] = result;
  81. }
  82. __syncthreads();
  83. if (tid < blockDim.y && L_block_start + tid < L) {
  84. out[batch_idx][L_block_start + tid] = out_smem[tid];
  85. }
  86. }
  87. torch::Tensor cauchy_mult_sym_fwd_cuda(torch::Tensor v,
  88. torch::Tensor z,
  89. torch::Tensor w) {
  90. const int batch_size = v.size(0);
  91. const int N = v.size(1);
  92. const int L = z.size(0);
  93. auto out = torch::empty({batch_size, L}, torch::dtype(v.dtype()).device(v.device()));
  94. auto stream = at::cuda::getCurrentCUDAStream();
  95. using scalar_t = c10::complex<float>;
  96. const auto v_a = v.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>();
  97. const auto z_a = z.packed_accessor32<scalar_t, 1, at::RestrictPtrTraits>();
  98. const auto w_a = w.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>();
  99. auto out_a = out.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>();
  100. int log_N = int(log2((double) N));
  101. int block_x = div_up(N, ITEMS_PER_THREAD_SYM_FWD[log_N - 1]);
  102. dim3 block(block_x, MAX_BLOCK_SIZE / block_x);
  103. dim3 grid(batch_size, div_up(L, block.y));
  104. switch (log_N) {
  105. #define CASE_LOG_N(log_N_val) case log_N_val: \
  106. cauchy_mult_sym_fwd_cuda_kernel<scalar_t, log_N_val> \
  107. <<<grid, block, 0, stream>>>(v_a, z_a, w_a, out_a, L); break;
  108. MAP(CASE_LOG_N, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
  109. }
  110. #undef CASE_LOG_N
  111. C10_CUDA_KERNEL_LAUNCH_CHECK();
  112. return out;
  113. }
  114. template <typename scalar_t, bool check_L_boundary>
  115. __global__ void cauchy_mult_sym_bwd_cuda_kernel(CudaAcsr<scalar_t, 2> v,
  116. CudaAcsr<scalar_t, 1> z,
  117. CudaAcsr<scalar_t, 2> w,
  118. CudaAcsr<scalar_t, 2> dout,
  119. CudaAcsr<scalar_t, 3> dv,
  120. CudaAcsr<scalar_t, 3> dw,
  121. int L,
  122. int L_chunk_size) {
  123. // We just want a shared array:
  124. // __shared__ scalar_t s_b[16];
  125. // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270
  126. // So we declare a char array and cast it.
  127. // The casting is subtle: https://stackoverflow.com/questions/12692310/convert-array-to-two-dimensional-array-by-pointer
  128. __shared__ char dv_smem_char[C10_WARP_SIZE * sizeof(scalar_t)];
  129. scalar_t *dv_smem = (scalar_t *)&dv_smem_char;
  130. __shared__ char dw_smem_char[C10_WARP_SIZE * sizeof(scalar_t)];
  131. scalar_t *dw_smem = (scalar_t *)&dw_smem_char;
  132. int batch_idx = blockIdx.x;
  133. int N_idx = blockIdx.y;
  134. int L_chunk_idx = blockIdx.z;
  135. int tid = threadIdx.x;
  136. scalar_t w_conj_t = std::conj(w[batch_idx][N_idx]);
  137. scalar_t dv_t = 0;
  138. scalar_t dw_t = 0;
  139. #pragma unroll
  140. for (int item = 0; item < ITEMS_PER_THREAD_SYM_BWD; ++item) {
  141. int l = L_chunk_idx * L_chunk_size + item * blockDim.x + threadIdx.x;
  142. scalar_t dout_t, z_t;
  143. if (check_L_boundary) {
  144. dout_t = l < L ? dout[batch_idx][l] : 0;
  145. z_t = l < L ? z[l] : 1;
  146. } else { // Not checking boundary can speed it up quite a bit, around 30%.
  147. dout_t = dout[batch_idx][l];
  148. z_t = z[l];
  149. }
  150. scalar_t denom_1 = std::conj(z_t) - w_conj_t;
  151. scalar_t denom_2 = z_t - w_conj_t;
  152. scalar_t term_1 = dout_t / denom_1;
  153. scalar_t term_2 = std::conj(dout_t) / denom_2;
  154. dv_t += term_1 + term_2;
  155. dw_t += term_1 / denom_1 + term_2 / denom_2;
  156. }
  157. dv_t = at::native::cuda_utils::BlockReduceSum<scalar_t>(dv_t, dv_smem);
  158. dw_t = at::native::cuda_utils::BlockReduceSum<scalar_t>(dw_t, dw_smem);
  159. if (tid == 0) {
  160. dw[batch_idx][N_idx][L_chunk_idx] = dw_t * std::conj(v[batch_idx][N_idx]);
  161. dv[batch_idx][N_idx][L_chunk_idx] = dv_t;
  162. }
  163. }
  164. std::tuple<torch::Tensor, torch::Tensor>
  165. cauchy_mult_sym_bwd_cuda(torch::Tensor v,
  166. torch::Tensor z,
  167. torch::Tensor w,
  168. torch::Tensor dout) {
  169. const int batch_size = v.size(0);
  170. const int N = v.size(1);
  171. const int L = z.size(0);
  172. constexpr int MAX_BLOCK_SIZE = 1024;
  173. constexpr int MAX_L_CHUNK_SIZE = ITEMS_PER_THREAD_SYM_BWD * MAX_BLOCK_SIZE;
  174. const int n_L_chunks = div_up(L, MAX_L_CHUNK_SIZE);
  175. auto dv = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(v.dtype()).device(v.device()));
  176. auto dw = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(w.dtype()).device(w.device()));
  177. auto stream = at::cuda::getCurrentCUDAStream();
  178. using scalar_t = c10::complex<float>;
  179. const auto v_a = v.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>();
  180. const auto z_a = z.packed_accessor32<scalar_t, 1, at::RestrictPtrTraits>();
  181. const auto w_a = w.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>();
  182. const auto dout_a = dout.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>();
  183. auto dv_a = dv.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>();
  184. auto dw_a = dw.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>();
  185. // Each block need to have a multiple of 32 threads, otherwise
  186. // at::native::cuda_utils::BlockReduceSum to produce wrong result.
  187. // int block_x = max(div_up(L, ITEMS_PER_THREAD_SYM_BWD), C10_WARP_SIZE);
  188. const int L_chunk_size = min(L, MAX_L_CHUNK_SIZE);
  189. int block_x = div_up(L_chunk_size, ITEMS_PER_THREAD_SYM_BWD * C10_WARP_SIZE) * C10_WARP_SIZE;
  190. bool check_L_boundary = L != block_x * ITEMS_PER_THREAD_SYM_BWD * n_L_chunks;
  191. dim3 block(block_x);
  192. dim3 grid(batch_size, N, n_L_chunks);
  193. check_L_boundary
  194. ? cauchy_mult_sym_bwd_cuda_kernel<scalar_t, true>
  195. <<<grid, block, 0, stream>>>(v_a, z_a, w_a, dout_a, dv_a, dw_a, L, L_chunk_size)
  196. : cauchy_mult_sym_bwd_cuda_kernel<scalar_t, false>
  197. <<<grid, block, 0, stream>>>(v_a, z_a, w_a, dout_a, dv_a, dw_a, L, L_chunk_size);
  198. C10_CUDA_KERNEL_LAUNCH_CHECK();
  199. return std::make_tuple(dv.sum(-1), dw.sum(-1));
  200. }
  201. template <int log_N, int items_per_thread=ITEMS_PER_THREAD_SYM_FWD[log_N - 1]>
  202. __global__ void vand_log_mult_sym_fwd_cuda_kernel(CudaAcsr<c10::complex<float>, 2> v,
  203. CudaAcsr<c10::complex<float>, 2> x,
  204. CudaAcsr<float, 2> out,
  205. int L) {
  206. using cfloat_t = typename c10::complex<float>;
  207. constexpr int N = 1 << log_N;
  208. constexpr int blockDimx = div_up_const(N, items_per_thread);
  209. constexpr int blockDimy = MAX_BLOCK_SIZE / blockDimx;
  210. // We just want a shared array:
  211. // __shared__ cfloat_t s_b[16];
  212. // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270
  213. // So we declare a char array and cast it.
  214. // The casting is subtle: https://stackoverflow.com/questions/12692310/convert-array-to-two-dimensional-array-by-pointer
  215. __shared__ char v_smem_char[N * sizeof(cfloat_t)];
  216. cfloat_t *v_smem = (cfloat_t *)&v_smem_char;
  217. __shared__ char x_smem_char[N * sizeof(cfloat_t)];
  218. cfloat_t *x_smem = (cfloat_t *)&x_smem_char;
  219. __shared__ float out_smem[blockDimy];
  220. int batch_idx = blockIdx.x;
  221. int tid = threadIdx.x + threadIdx.y * blockDim.x;
  222. int L_idx = blockIdx.y * blockDim.y + threadIdx.y;
  223. int L_block_start = blockIdx.y * blockDim.y;
  224. for (int N_idx = threadIdx.x + threadIdx.y * blockDim.x; N_idx < N; N_idx += blockDim.x * blockDim.y) {
  225. v_smem[N_idx] = v[batch_idx][N_idx];
  226. x_smem[N_idx] = x[batch_idx][N_idx];
  227. }
  228. __syncthreads();
  229. float result = 0;
  230. if (L_idx < L) {
  231. #pragma unroll
  232. for (int item = 0; item < items_per_thread; ++item) {
  233. int N_idx = item * blockDimx + threadIdx.x;
  234. cfloat_t v_t = v_smem[N_idx], x_t = x_smem[N_idx];
  235. result += (std::exp(x_t * L_idx) * v_t).real_;
  236. }
  237. }
  238. // TODO: this only works for N a power of 2
  239. #pragma unroll
  240. for (int offset = blockDimx / 2; offset > 0; offset /= 2) {
  241. result += WARP_SHFL_DOWN(result, offset);
  242. }
  243. if ((threadIdx.x == 0) && (L_idx < L)) {
  244. out_smem[threadIdx.y] = 2 * result;
  245. }
  246. __syncthreads();
  247. if (tid < blockDim.y && L_block_start + tid < L) {
  248. out[batch_idx][L_block_start + tid] = out_smem[tid];
  249. }
  250. }
  251. torch::Tensor vand_log_mult_sym_fwd_cuda(torch::Tensor v, torch::Tensor x, int L) {
  252. const int batch_size = v.size(0);
  253. const int N = v.size(1);
  254. auto opts = v.options();
  255. auto out = torch::empty({batch_size, L}, opts.dtype(torch::kFloat32));
  256. auto stream = at::cuda::getCurrentCUDAStream();
  257. const auto v_a = v.packed_accessor32<c10::complex<float>, 2, at::RestrictPtrTraits>();
  258. const auto x_a = x.packed_accessor32<c10::complex<float>, 2, at::RestrictPtrTraits>();
  259. auto out_a = out.packed_accessor32<float, 2, at::RestrictPtrTraits>();
  260. int log_N = int(log2((double) N));
  261. int block_x = div_up(N, ITEMS_PER_THREAD_SYM_FWD[log_N - 1]);
  262. dim3 block(block_x, MAX_BLOCK_SIZE / block_x);
  263. dim3 grid(batch_size, div_up(L, block.y));
  264. switch (log_N) {
  265. #define CASE_LOG_N(log_N_val) case log_N_val: \
  266. vand_log_mult_sym_fwd_cuda_kernel<log_N_val> \
  267. <<<grid, block, 0, stream>>>(v_a, x_a, out_a, L); break;
  268. MAP(CASE_LOG_N, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
  269. }
  270. #undef CASE_LOG_N
  271. C10_CUDA_KERNEL_LAUNCH_CHECK();
  272. return out;
  273. }
  274. template <bool check_L_boundary>
  275. __global__ void vand_log_mult_sym_bwd_cuda_kernel(CudaAcsr<c10::complex<float>, 2> v,
  276. CudaAcsr<c10::complex<float>, 2> x,
  277. CudaAcsr<float, 2> dout,
  278. CudaAcsr<c10::complex<float>, 3> dv,
  279. CudaAcsr<c10::complex<float>, 3> dx,
  280. int L,
  281. int L_chunk_size) {
  282. using cfloat_t = typename c10::complex<float>;
  283. // We just want a shared array:
  284. // __shared__ c10::complex<float> s_b[16];
  285. // But it doesn't work for complex: https://github.com/pytorch/pytorch/issues/39270
  286. // So we declare a char array and cast it.
  287. // The casting is subtle: https://stackoverflow.com/questions/12692310/convert-array-to-two-dimensional-array-by-pointer
  288. __shared__ char dv_smem_char[C10_WARP_SIZE * sizeof(cfloat_t)];
  289. cfloat_t *dv_smem = (cfloat_t *)&dv_smem_char;
  290. __shared__ char dx_smem_char[C10_WARP_SIZE * sizeof(cfloat_t)];
  291. cfloat_t *dx_smem = (cfloat_t *)&dx_smem_char;
  292. int batch_idx = blockIdx.x;
  293. int N_idx = blockIdx.y;
  294. int L_chunk_idx = blockIdx.z;
  295. int tid = threadIdx.x;
  296. cfloat_t x_t = x[batch_idx][N_idx];
  297. cfloat_t dv_t = 0;
  298. cfloat_t dx_t = 0;
  299. #pragma unroll
  300. for (int item = 0; item < ITEMS_PER_THREAD_SYM_BWD; ++item) {
  301. int l = L_chunk_idx * L_chunk_size + item * blockDim.x + threadIdx.x;
  302. float dout_t;
  303. if (check_L_boundary) {
  304. dout_t = l < L ? dout[batch_idx][l] : 0;
  305. } else { // Not checking boundary can speed it up quite a bit.
  306. dout_t = dout[batch_idx][l];
  307. }
  308. // Need to conjugate as we're doing complex gradient.
  309. cfloat_t do_exp_x_t = dout_t * std::conj(std::exp(x_t * l));
  310. dv_t += do_exp_x_t;
  311. dx_t += do_exp_x_t * l;
  312. }
  313. dv_t = at::native::cuda_utils::BlockReduceSum<cfloat_t>(dv_t, dv_smem);
  314. dx_t = at::native::cuda_utils::BlockReduceSum<cfloat_t>(dx_t, dx_smem);
  315. if (tid == 0) {
  316. dx[batch_idx][N_idx][L_chunk_idx] = 2 * dx_t * std::conj(v[batch_idx][N_idx]);
  317. dv[batch_idx][N_idx][L_chunk_idx] = 2 * dv_t;
  318. }
  319. }
  320. std::tuple<torch::Tensor, torch::Tensor>
  321. vand_log_mult_sym_bwd_cuda(torch::Tensor v,
  322. torch::Tensor x,
  323. torch::Tensor dout) {
  324. const int batch_size = v.size(0);
  325. const int N = v.size(1);
  326. const int L = dout.size(1);
  327. constexpr int MAX_BLOCK_SIZE = 1024;
  328. constexpr int MAX_L_CHUNK_SIZE = ITEMS_PER_THREAD_SYM_BWD * MAX_BLOCK_SIZE;
  329. const int n_L_chunks = div_up(L, MAX_L_CHUNK_SIZE);
  330. auto dv = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(v.dtype()).device(v.device()));
  331. auto dx = torch::empty({batch_size, N, n_L_chunks}, torch::dtype(x.dtype()).device(x.device()));
  332. auto stream = at::cuda::getCurrentCUDAStream();
  333. using cfloat_t = c10::complex<float>;
  334. const auto v_a = v.packed_accessor32<cfloat_t, 2, at::RestrictPtrTraits>();
  335. const auto x_a = x.packed_accessor32<cfloat_t, 2, at::RestrictPtrTraits>();
  336. const auto dout_a = dout.packed_accessor32<float, 2, at::RestrictPtrTraits>();
  337. auto dv_a = dv.packed_accessor32<cfloat_t, 3, at::RestrictPtrTraits>();
  338. auto dx_a = dx.packed_accessor32<cfloat_t, 3, at::RestrictPtrTraits>();
  339. // Each block need to have a multiple of 32 threads, otherwise
  340. // at::native::cuda_utils::BlockReduceSum to produce wrong result.
  341. // int block_x = max(div_up(L, ITEMS_PER_THREAD_SYM_BWD), C10_WARP_SIZE);
  342. const int L_chunk_size = min(L, MAX_L_CHUNK_SIZE);
  343. int block_x = div_up(L_chunk_size, ITEMS_PER_THREAD_SYM_BWD * C10_WARP_SIZE) * C10_WARP_SIZE;
  344. bool check_L_boundary = L != block_x * ITEMS_PER_THREAD_SYM_BWD * n_L_chunks;
  345. dim3 block(block_x);
  346. dim3 grid(batch_size, N, n_L_chunks);
  347. check_L_boundary
  348. ? vand_log_mult_sym_bwd_cuda_kernel<true>
  349. <<<grid, block, 0, stream>>>(v_a, x_a, dout_a, dv_a, dx_a, L, L_chunk_size)
  350. : vand_log_mult_sym_bwd_cuda_kernel<false>
  351. <<<grid, block, 0, stream>>>(v_a, x_a, dout_a, dv_a, dx_a, L, L_chunk_size);
  352. C10_CUDA_KERNEL_LAUNCH_CHECK();
  353. return std::make_tuple(dv.sum(-1), dx.sum(-1));
  354. }