Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

ps_roi_align.cpp 3.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
  1. #include "ps_roi_align.h"
  2. #include <ATen/core/dispatch/Dispatcher.h>
  3. #include <torch/library.h>
  4. #include <torch/types.h>
  5. namespace vision {
  6. namespace ops {
  7. std::tuple<at::Tensor, at::Tensor> ps_roi_align(
  8. const at::Tensor& input,
  9. const at::Tensor& rois,
  10. double spatial_scale,
  11. int64_t pooled_height,
  12. int64_t pooled_width,
  13. int64_t sampling_ratio) {
  14. C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align");
  15. static auto op = c10::Dispatcher::singleton()
  16. .findSchemaOrThrow("torchvision::ps_roi_align", "")
  17. .typed<decltype(ps_roi_align)>();
  18. return op.call(
  19. input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
  20. }
  21. std::tuple<at::Tensor, at::Tensor> ps_roi_align_symint(
  22. const at::Tensor& input,
  23. const at::Tensor& rois,
  24. double spatial_scale,
  25. c10::SymInt pooled_height,
  26. c10::SymInt pooled_width,
  27. int64_t sampling_ratio) {
  28. C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align");
  29. static auto op = c10::Dispatcher::singleton()
  30. .findSchemaOrThrow("torchvision::ps_roi_align", "")
  31. .typed<decltype(ps_roi_align_symint)>();
  32. return op.call(
  33. input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
  34. }
  35. namespace detail {
  36. at::Tensor _ps_roi_align_backward(
  37. const at::Tensor& grad,
  38. const at::Tensor& rois,
  39. const at::Tensor& channel_mapping,
  40. double spatial_scale,
  41. int64_t pooled_height,
  42. int64_t pooled_width,
  43. int64_t sampling_ratio,
  44. int64_t batch_size,
  45. int64_t channels,
  46. int64_t height,
  47. int64_t width) {
  48. static auto op =
  49. c10::Dispatcher::singleton()
  50. .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "")
  51. .typed<decltype(_ps_roi_align_backward)>();
  52. return op.call(
  53. grad,
  54. rois,
  55. channel_mapping,
  56. spatial_scale,
  57. pooled_height,
  58. pooled_width,
  59. sampling_ratio,
  60. batch_size,
  61. channels,
  62. height,
  63. width);
  64. }
  65. at::Tensor _ps_roi_align_backward_symint(
  66. const at::Tensor& grad,
  67. const at::Tensor& rois,
  68. const at::Tensor& channel_mapping,
  69. double spatial_scale,
  70. c10::SymInt pooled_height,
  71. c10::SymInt pooled_width,
  72. int64_t sampling_ratio,
  73. c10::SymInt batch_size,
  74. c10::SymInt channels,
  75. c10::SymInt height,
  76. c10::SymInt width) {
  77. static auto op =
  78. c10::Dispatcher::singleton()
  79. .findSchemaOrThrow("torchvision::_ps_roi_align_backward", "")
  80. .typed<decltype(_ps_roi_align_backward_symint)>();
  81. return op.call(
  82. grad,
  83. rois,
  84. channel_mapping,
  85. spatial_scale,
  86. pooled_height,
  87. pooled_width,
  88. sampling_ratio,
  89. batch_size,
  90. channels,
  91. height,
  92. width);
  93. }
  94. } // namespace detail
  95. TORCH_LIBRARY_FRAGMENT(torchvision, m) {
  96. m.def(TORCH_SELECTIVE_SCHEMA(
  97. "torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio) -> (Tensor, Tensor)"));
  98. m.def(TORCH_SELECTIVE_SCHEMA(
  99. "torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor"));
  100. }
  101. } // namespace ops
  102. } // namespace vision
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...