Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

rwkv5_op.cpp 2.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
  1. #include <torch/extension.h>
  2. #include "ATen/ATen.h"
  3. #include <c10/cuda/CUDAGuard.h>
  4. typedef at::BFloat16 bf16;
  5. typedef at::Half fp16;
  6. typedef float fp32;
  7. void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
  8. void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
  9. void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
  10. void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
  11. const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
  12. cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
  13. }
  14. void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
  15. const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
  16. cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
  17. }
  18. void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
  19. const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
  20. cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
  21. }
  22. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  23. m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16");
  24. m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16");
  25. m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32");
  26. }
  27. TORCH_LIBRARY(rwkv5, m) {
  28. m.def("forward_bf16", forward_bf16);
  29. m.def("forward_fp16", forward_fp16);
  30. m.def("forward_fp32", forward_fp32);
  31. }
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...