Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

rwkv5.cu 2.7 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <stdio.h>
#include <assert.h>
#include <cuda_runtime.h>
#include "ATen/ATen.h"
  4. typedef at::BFloat16 bf16;
  5. typedef at::Half fp16;
  6. typedef float fp32;
  7. template <typename F>
  8. __global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
  9. const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
  10. F *__restrict__ const _y)
  11. {
  12. const int b = blockIdx.x / H;
  13. const int h = blockIdx.x % H;
  14. const int i = threadIdx.x;
  15. _w += h*_N_;
  16. _u += h*_N_;
  17. _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
  18. __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
  19. float state[_N_];
  20. #pragma unroll
  21. for (int j = 0; j < _N_; j++)
  22. state[j] = _state[j];
  23. __syncthreads();
  24. u[i] = float(_u[i]);
  25. w[i] = _w[i];
  26. __syncthreads();
  27. for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
  28. {
  29. __syncthreads();
  30. r[i] = float(_r[t]);
  31. k[i] = float(_k[t]);
  32. __syncthreads();
  33. const float v = float(_v[t]);
  34. float y = 0;
  35. #pragma unroll
  36. for (int j = 0; j < _N_; j+=4)
  37. {
  38. const float4& r_ = (float4&)(r[j]);
  39. const float4& k_ = (float4&)(k[j]);
  40. const float4& w_ = (float4&)(w[j]);
  41. const float4& u_ = (float4&)(u[j]);
  42. float4& s = (float4&)(state[j]);
  43. float4 x;
  44. x.x = k_.x * v;
  45. x.y = k_.y * v;
  46. x.z = k_.z * v;
  47. x.w = k_.w * v;
  48. y += r_.x * (u_.x * x.x + s.x);
  49. y += r_.y * (u_.y * x.y + s.y);
  50. y += r_.z * (u_.z * x.z + s.z);
  51. y += r_.w * (u_.w * x.w + s.w);
  52. s.x = s.x * w_.x + x.x;
  53. s.y = s.y * w_.y + x.y;
  54. s.z = s.z * w_.z + x.z;
  55. s.w = s.w * w_.w + x.w;
  56. }
  57. _y[t] = F(y);
  58. }
  59. #pragma unroll
  60. for (int j = 0; j < _N_; j++)
  61. _state[j] = state[j];
  62. }
  63. void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
  64. {
  65. assert(H*_N_ == C);
  66. kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
  67. }
  68. void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
  69. {
  70. assert(H*_N_ == C);
  71. kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
  72. }
  73. void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
  74. {
  75. assert(H*_N_ == C);
  76. kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
  77. }
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...