wrapper.cpp
#include <torch/extension.h>
#include "ATen/ATen.h"
#include <iostream>
#include <c10/cuda/CUDAGuard.h>

typedef at::Half fp16;

// CUDA kernel launchers, declared here and defined in the accompanying .cu sources.
template <typename F>
void cuda_wkv_forward(int B, int T, int C,
                      float *w, float *u, F *k, F *v, F *y,
                      float *aa, float *bb, float *pp);

template <typename F>
void cuda_mm8_seq(int B, int N, int M,
                  F *x, int x_stride,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  F *y, int y_stride);

template <typename F>
void cuda_mm8_one(int N, int M,
                  F *x,
                  uint8_t *w, int w_stride,
                  F *mx, F *rx,
                  F *my, F *ry,
                  float *y);

// WKV forward pass; dispatches to the kernel on the dtype of k (fp16 or fp32).
void wkv_forward(int64_t B, int64_t T, int64_t C,
                 torch::Tensor &w, torch::Tensor &u,
                 torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
                 torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) {
    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
    switch (k.scalar_type()) {
    case c10::ScalarType::Half:
        cuda_wkv_forward(B, T, C,
                         w.data_ptr<float>(), u.data_ptr<float>(),
                         k.data_ptr<fp16>(), v.data_ptr<fp16>(), y.data_ptr<fp16>(),
                         aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
        break;
    case c10::ScalarType::Float:
        cuda_wkv_forward(B, T, C,
                         w.data_ptr<float>(), u.data_ptr<float>(),
                         k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(),
                         aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
        break;
    default:
        assert(false && "Only FP16 and FP32 are currently supported");
    }
}

// Batched int8 matmul: x (B x N) times quantized w (N x M, uint8) -> y (B x M).
void mm8_seq(int64_t B, int64_t N, int64_t M,
             torch::Tensor &x, torch::Tensor &w,
             torch::Tensor &mx, torch::Tensor &rx,
             torch::Tensor &my, torch::Tensor &ry,
             torch::Tensor &y) {
    assert(x.stride(1) == 1);
    assert(w.stride(1) == 1);
    assert(mx.stride(0) == 1 && rx.stride(0) == 1);
    assert(my.stride(0) == 1 && ry.stride(0) == 1);
    assert(y.stride(1) == 1);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
    switch (x.scalar_type()) {
    case c10::ScalarType::Half:
        cuda_mm8_seq(
            B, N, M,
            x.data_ptr<fp16>(), x.stride(0),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
            my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
            y.data_ptr<fp16>(), y.stride(0));
        break;
    case c10::ScalarType::Float:
        cuda_mm8_seq(
            B, N, M,
            x.data_ptr<float>(), x.stride(0),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<float>(), rx.data_ptr<float>(),
            my.data_ptr<float>(), ry.data_ptr<float>(),
            y.data_ptr<float>(), y.stride(0));
        break;
    default:
        assert(false && "Only FP16 and FP32 are currently supported");
    }
}

// Int8 matrix-vector product: x (length N) times quantized w (N x M) -> y (length M, fp32).
void mm8_one(int64_t N, int64_t M,
             torch::Tensor &x, torch::Tensor &w,
             torch::Tensor &mx, torch::Tensor &rx,
             torch::Tensor &my, torch::Tensor &ry,
             torch::Tensor &y) {
    assert(x.stride(0) == 1);
    assert(w.stride(1) == 1);
    assert(mx.stride(0) == 1 && rx.stride(0) == 1);
    assert(my.stride(0) == 1 && ry.stride(0) == 1);
    assert(y.stride(0) == 1);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
    switch (x.scalar_type()) {
    case c10::ScalarType::Half:
        cuda_mm8_one(
            N, M,
            x.data_ptr<fp16>(),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
            my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
            y.data_ptr<float>());
        break;
    case c10::ScalarType::Float:
        cuda_mm8_one(
            N, M,
            x.data_ptr<float>(),
            w.data_ptr<uint8_t>(), w.stride(0),
            mx.data_ptr<float>(), rx.data_ptr<float>(),
            my.data_ptr<float>(), ry.data_ptr<float>(),
            y.data_ptr<float>());
        break;
    default:
        assert(false && "Only FP16 and FP32 are currently supported");
    }
}

using torch::Tensor;

#ifndef DISABLE_CUBLAS_GEMM
void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
#endif

// Python bindings.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("wkv_forward", &wkv_forward, "wkv forward");
    m.def("mm8_seq", &mm8_seq, "mm8 seq");
    m.def("mm8_one", &mm8_one, "mm8 one");
#ifndef DISABLE_CUBLAS_GEMM
    m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
#endif
}

// TorchScript operator registration under the "rwkv" namespace.
TORCH_LIBRARY(rwkv, m) {
    m.def("wkv_forward", wkv_forward);
    m.def("mm8_seq", mm8_seq);
    m.def("mm8_one", mm8_one);
#ifndef DISABLE_CUBLAS_GEMM
    m.def("gemm_fp16_cublas", gemm_fp16_cublas);
#endif
}
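
Note that wrapper.cpp only declares the CUDA launchers (cuda_wkv_forward, cuda_mm8_seq, cuda_mm8_one); their definitions live in separate .cu sources not shown here. For orientation, below is a minimal, hypothetical sketch of what a naive cuda_mm8_one launcher and kernel could look like. The per-row/per-column scale-and-offset dequantization using mx, rx, my, ry is an assumption inferred from the tensor shapes asserted in mm8_one, not taken from this file, and a production kernel would tile, vectorize, and reduce in shared memory rather than loop serially per output column.

    #include <cuda_runtime.h>
    #include <cstdint>

    // Hypothetical sketch: one thread per output column m of y (length M).
    // Assumed dequantization (an assumption, not taken from wrapper.cpp):
    //   w_fp[n][m] = (w[n][m] + 0.5) * ry[n] * rx[m] + my[n] + mx[m]
    template <typename F>
    __global__ void kernel_mm8_one(int N, int M,
                                   const F *x,
                                   const uint8_t *w, int w_stride,
                                   const F *mx, const F *rx,
                                   const F *my, const F *ry,
                                   float *y) {
        const int m = blockIdx.x * blockDim.x + threadIdx.x;
        if (m >= M) return;
        float acc = 0.0f;
        for (int n = 0; n < N; ++n) {
            // Dequantize one weight and accumulate x[n] * w_fp[n][m].
            const float w_fp = (static_cast<float>(w[n * w_stride + m]) + 0.5f)
                               * static_cast<float>(ry[n]) * static_cast<float>(rx[m])
                               + static_cast<float>(my[n]) + static_cast<float>(mx[m]);
            acc += static_cast<float>(x[n]) * w_fp;
        }
        y[m] = acc;
    }

    // Host launcher matching the declaration used by wrapper.cpp.
    template <typename F>
    void cuda_mm8_one(int N, int M,
                      F *x,
                      uint8_t *w, int w_stride,
                      F *mx, F *rx,
                      F *my, F *ry,
                      float *y) {
        const int threads = 256;
        const int blocks = (M + threads - 1) / threads;
        kernel_mm8_one<F><<<blocks, threads>>>(N, M, x, w, w_stride, mx, rx, my, ry, y);
    }

    // Explicit instantiation for fp32; an fp16 instantiation would be analogous,
    // with at::Half (or __half) converted to float inside the kernel.
    template void cuda_mm8_one<float>(int, int, float *, uint8_t *, int,
                                      float *, float *, float *, float *, float *);
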