gemm_fp16_cublas.cpp

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>

#define CUBLAS_CHECK(condition)                                      \
  for (cublasStatus_t _cublas_check_status = (condition);            \
       _cublas_check_status != CUBLAS_STATUS_SUCCESS;)               \
    throw std::runtime_error("cuBLAS error " +                       \
                             std::to_string(_cublas_check_status) +  \
                             " at " + std::to_string(__LINE__));

#define CUDA_CHECK(condition)                                                 \
  for (cudaError_t _cuda_check_status = (condition);                          \
       _cuda_check_status != cudaSuccess;)                                    \
    throw std::runtime_error(                                                 \
        "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
        " at " + std::to_string(__LINE__));

/*
  NOTE: BLAS GEMM is column-major by default, but we need row-major output.
  A row-major matrix has exactly the same memory layout as its transpose
  stored column-major, and C = A * B implies C^T = B^T * A^T, so we compute
  the column-major product B^T * A^T and read the result as the row-major C.
*/
void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  const auto cuda_data_type = CUDA_R_16F;
  const auto cuda_c_data_type =
      c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
  const auto compute_type = CUDA_R_32F;
  const float sp_alpha = 1.f;
  // Swap a and b and use CUBLAS_OP_N; see the note above.
  std::swap(a, b);
  const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
  const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
  // m = (B^T).size(0) = B.size(1), which is a.size(-1) after the swap;
  // negative axes are used so the same code handles batch matmul.
  const int m = a.size(-1);
  const int k = a.size(-2);
  const int n = b.size(-2);
  const int cublas_lda = m;
  const int cublas_ldb = k;
  const int cublas_ldc = m;
  cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
#if CUDA_VERSION >= 11000
  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
#else
  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
#endif
  const float sp_beta = 0.f;
  if (a.sizes().size() == 2 && b.sizes().size() == 2) {
    CUBLAS_CHECK(cublasGemmEx(
        cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
        a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
        cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
        compute_type, algo));
  } else {
    // Batch matmul.
    assert(a.sizes().size() == 3 && b.sizes().size() == 3);
    const long long int cublas_stride_a = m * k;
    const long long int cublas_stride_b = k * n;
    const long long int cublas_stride_c = m * n;
    CUBLAS_CHECK(cublasGemmStridedBatchedEx(
        cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
        a.data_ptr(), cuda_data_type, cublas_lda, cublas_stride_a,
        b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b, &sp_beta,
        c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
        a.size(0), compute_type, algo));
  }
}