// cauchy.cpp
#include <cmath>
#include <tuple>
#include <utility>
#include <vector>

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
  6. #define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
  7. #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
  8. torch::Tensor cauchy_mult_sym_fwd_cuda(torch::Tensor v,
  9. torch::Tensor z,
  10. torch::Tensor w);
  11. std::tuple<torch::Tensor, torch::Tensor> cauchy_mult_sym_bwd_cuda(torch::Tensor v,
  12. torch::Tensor z,
  13. torch::Tensor w,
  14. torch::Tensor dout);
  15. namespace cauchy {
  16. torch::Tensor cauchy_mult_sym_fwd(torch::Tensor v,
  17. torch::Tensor z,
  18. torch::Tensor w) {
  19. CHECK_DEVICE(v); CHECK_DEVICE(z); CHECK_DEVICE(w);
  20. const auto batch_size = v.size(0);
  21. const auto N = v.size(1);
  22. const auto L = z.size(0);
  23. CHECK_SHAPE(v, batch_size, N);
  24. CHECK_SHAPE(z, L);
  25. CHECK_SHAPE(w, batch_size, N);
  26. // Otherwise the kernel will be launched from cuda:0 device
  27. // Cast to char to avoid compiler warning about narrowing
  28. at::cuda::CUDAGuard device_guard{(char)v.get_device()};
  29. return cauchy_mult_sym_fwd_cuda(v, z, w);
  30. }
  31. std::tuple<torch::Tensor, torch::Tensor>
  32. cauchy_mult_sym_bwd(torch::Tensor v,
  33. torch::Tensor z,
  34. torch::Tensor w,
  35. torch::Tensor dout) {
  36. CHECK_DEVICE(v); CHECK_DEVICE(z); CHECK_DEVICE(w); CHECK_DEVICE(dout);
  37. const auto batch_size = v.size(0);
  38. const auto N = v.size(1);
  39. const auto L = z.size(0);
  40. CHECK_SHAPE(v, batch_size, N);
  41. CHECK_SHAPE(z, L);
  42. CHECK_SHAPE(w, batch_size, N);
  43. CHECK_SHAPE(dout, batch_size, L);
  44. // Otherwise the kernel will be launched from cuda:0 device
  45. // Cast to char to avoid compiler warning about narrowing
  46. at::cuda::CUDAGuard device_guard{(char)v.get_device()};
  47. return cauchy_mult_sym_bwd_cuda(v, z, w, dout);
  48. }
  49. } // cauchy
  50. torch::Tensor vand_log_mult_sym_fwd_cuda(torch::Tensor v, torch::Tensor x, int L);
  51. std::tuple<torch::Tensor, torch::Tensor>
  52. vand_log_mult_sym_bwd_cuda(torch::Tensor v, torch::Tensor x, torch::Tensor dout);
  53. namespace vand {
  54. torch::Tensor vand_log_mult_sym_fwd(torch::Tensor v, torch::Tensor x, int L) {
  55. CHECK_DEVICE(v); CHECK_DEVICE(x);
  56. const auto batch_size = v.size(0);
  57. const auto N = v.size(1);
  58. CHECK_SHAPE(v, batch_size, N);
  59. CHECK_SHAPE(x, batch_size, N);
  60. // Otherwise the kernel will be launched from cuda:0 device
  61. // Cast to char to avoid compiler warning about narrowing
  62. at::cuda::CUDAGuard device_guard{(char)v.get_device()};
  63. return vand_log_mult_sym_fwd_cuda(v, x, L);
  64. }
  65. std::tuple<torch::Tensor, torch::Tensor>
  66. vand_log_mult_sym_bwd(torch::Tensor v, torch::Tensor x, torch::Tensor dout) {
  67. CHECK_DEVICE(v); CHECK_DEVICE(x); CHECK_DEVICE(dout);
  68. const auto batch_size = v.size(0);
  69. const auto N = v.size(1);
  70. const auto L = dout.size(1);
  71. CHECK_SHAPE(v, batch_size, N);
  72. CHECK_SHAPE(x, batch_size, N);
  73. CHECK_SHAPE(dout, batch_size, L);
  74. // Otherwise the kernel will be launched from cuda:0 device
  75. // Cast to char to avoid compiler warning about narrowing
  76. at::cuda::CUDAGuard device_guard{(char)v.get_device()};
  77. return vand_log_mult_sym_bwd_cuda(v, x, dout);
  78. }
  79. } // vand
  80. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  81. m.def("cauchy_mult_sym_fwd", &cauchy::cauchy_mult_sym_fwd,
  82. "Cauchy multiply symmetric forward");
  83. m.def("cauchy_mult_sym_bwd", &cauchy::cauchy_mult_sym_bwd,
  84. "Cauchy multiply symmetric backward");
  85. m.def("vand_log_mult_sym_fwd", &vand::vand_log_mult_sym_fwd,
  86. "Log Vandermonde multiply symmetric forward");
  87. m.def("vand_log_mult_sym_bwd", &vand::vand_log_mult_sym_bwd,
  88. "Log Vandermonde multiply symmetric backward");
  89. }