1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
|
- #include <ATen/ATen.h>
- #include <ATen/cuda/CUDAContext.h>
- #include <c10/cuda/CUDAGuard.h>
- #include <torch/library.h>
- #include <ATen/native/cuda/KernelUtils.cuh>
- #include "cuda_helpers.h"
- namespace vision {
- namespace ops {
- namespace {
// Bilinearly samples a single H x W plane at the fractional position (y, x).
//
// Parameters:
//   input   pointer to the plane, indexed row-major as input[row * width + col].
//   height  number of rows in the plane.
//   width   number of columns in the plane.
//   y, x    sample coordinates in plane pixel units.
//   index   unused; kept only as a debugging aid for callers.
//
// Returns 0 when the sample lies outside [-1, height] x [-1, width]. Otherwise
// coordinates in (-1, 0] are clamped to 0, and samples on or past the last
// row/column are snapped onto it, so edge values are replicated rather than
// read out of bounds.
template <typename T>
__device__ T bilinear_interpolate(
    const T* input,
    int height,
    int width,
    T y,
    T x,
    int index /* index for debug only*/) {
  const bool outside = y < -1.0 || y > height || x < -1.0 || x > width;
  if (outside) {
    return 0;
  }

  // Pull slightly-negative coordinates onto the first row / column.
  if (y <= 0) {
    y = 0;
  }
  if (x <= 0) {
    x = 0;
  }

  // Corners of the 2x2 neighbourhood. On the last row / column the
  // neighbourhood degenerates to a single line and the coordinate is snapped
  // onto it so the fractional weight becomes zero.
  int row0 = (int)y;
  int col0 = (int)x;
  int row1 = row0 + 1;
  int col1 = col0 + 1;
  if (row0 >= height - 1) {
    row0 = row1 = height - 1;
    y = (T)row0;
  }
  if (col0 >= width - 1) {
    col0 = col1 = width - 1;
    x = (T)col0;
  }

  // Fractional distance from the low corner, and its complement.
  const T ly = y - row0;
  const T lx = x - col0;
  const T hy = 1. - ly;
  const T hx = 1. - lx;

  const T w1 = hy * hx;
  const T w2 = hy * lx;
  const T w3 = ly * hx;
  const T w4 = ly * lx;

  const T* top = input + row0 * width;
  const T* bottom = input + row1 * width;
  return w1 * top[col0] + w2 * top[col1] + w3 * bottom[col0] +
      w4 * bottom[col1];
}
// Position-sensitive RoI Align, forward pass.
//
// Each loop iteration handles one element of `output`. Its flat `index` is
// decomposed as (n, c_out, ph, pw) over a
// [num_rois, channels_out, pooled_height, pooled_width] layout. Bin (ph, pw) of
// output channel c_out reads from its own dedicated input channel
//   c_in = (c_out * pooled_height + ph) * pooled_width + pw.
// It averages roi_bin_grid_h x roi_bin_grid_w bilinear samples taken inside the
// bin, writes the mean to output[index], and records c_in in
// channel_mapping[index] for the backward pass.
//
// Parameters:
//   nthreads        number of output elements to compute.
//   input           feature map, indexed as contiguous NCHW whose C, H, W
//                   extents are `channels`, `height`, `width`.
//   spatial_scale   factor mapping RoI coordinates onto the feature map.
//   sampling_ratio  samples per bin side; <= 0 selects an adaptive count of
//                   ceil(roi_extent / pooled_extent).
//   rois            rows of 5 values: [batch_index, x1, y1, x2, y2].
//   channels_out    number of pooled output channels.
//   output          pooled result, one value per index.
//   channel_mapping input channel used for each output element.
//
// NOTE(review): the RoI batch index is trusted, not range-checked; callers
// must supply valid indices.
// NOTE(review): CUDA_1D_KERNEL_LOOP comes from cuda_helpers.h (not shown). The
// host launcher caps the grid at 4096 blocks, so the macro is assumed to
// stride over all `nthreads` elements; confirm.
template <typename T>
__global__ void ps_roi_align_forward_kernel_impl(
    int nthreads,
    const T* input,
    const T spatial_scale,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    int sampling_ratio,
    const T* rois,
    int channels_out,
    T* output,
    int* channel_mapping) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c_out, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c_out = (index / pooled_width / pooled_height) % channels_out;
    int n = index / pooled_width / pooled_height / channels_out;

    // (n, c_in, ph, pw) is the associated element in the input
    int c_in = (c_out * pooled_height + ph) * pooled_width + pw;

    // [start, end) interval for spatial sampling
    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical.
    // The -0.5 shift places samples at pixel centres.
    T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
    T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
    T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
    T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // Do not use floor/ceil; this implementation detail is critical
    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : ceil(roi_height / pooled_height);
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
    // A degenerate RoI (non-positive extent) with adaptive sampling yields an
    // empty grid. Clamp the divisor so such bins produce 0 instead of 0/0 = NaN.
    // This does not change the result whenever the grid is non-empty.
    const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1);

    // Plane offset computed in 64-bit: (batch * C + c_in) * H * W can exceed
    // INT_MAX for large feature maps even though every factor fits in int.
    const int64_t plane_offset =
        (static_cast<int64_t>(roi_batch_ind) * channels + c_in) *
        static_cast<int64_t>(height) * width;
    const T* offset_input = input + plane_offset;

    T out_sum = 0;
    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
      const T y = hstart +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h);
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = wstart +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);
        T val = bilinear_interpolate(offset_input, height, width, y, x, index);
        out_sum += val;
      }
    }

    out_sum /= count;
    output[index] = out_sum;
    channel_mapping[index] = c_in;
  }
}
// Computes the four bilinear weights and the corner indices for a sample at
// (y, x) in an H x W plane. These are the inputs needed to scatter a gradient
// back through bilinear_interpolate.
//
// Outputs:
//   w1..w4               weights of the (y_low, x_low), (y_low, x_high),
//                        (y_high, x_low) and (y_high, x_high) corners.
//   x_low/x_high, y_low/y_high  corner column/row indices.
//
// When the sample lies outside [-1, height] x [-1, width], all weights are set
// to 0 and all indices to -1, so the caller can skip the scatter. Otherwise
// coordinates are clamped and snapped exactly as in bilinear_interpolate.
// `index` is unused and kept only as a debugging aid.
template <typename T>
__device__ void bilinear_interpolate_gradient(
    int height,
    int width,
    T y,
    T x,
    T& w1,
    T& w2,
    T& w3,
    T& w4,
    int& x_low,
    int& x_high,
    int& y_low,
    int& y_high,
    int index /* index for debug only*/) {
  const bool outside = y < -1.0 || y > height || x < -1.0 || x > width;
  if (outside) {
    w1 = w2 = w3 = w4 = 0.;
    x_low = x_high = y_low = y_high = -1;
    return;
  }

  if (y <= 0) {
    y = 0;
  }
  if (x <= 0) {
    x = 0;
  }

  // Neighbourhood corners; collapse onto the last row / column at the edge.
  int row0 = (int)y;
  int col0 = (int)x;
  int row1 = row0 + 1;
  int col1 = col0 + 1;
  if (row0 >= height - 1) {
    row0 = row1 = height - 1;
    y = (T)row0;
  }
  if (col0 >= width - 1) {
    col0 = col1 = width - 1;
    x = (T)col0;
  }

  const T ly = y - row0;
  const T lx = x - col0;
  const T hy = 1. - ly;
  const T hx = 1. - lx;

  // Same weights the forward interpolation uses for v1..v4.
  w1 = hy * hx;
  w2 = hy * lx;
  w3 = ly * hx;
  w4 = ly * lx;

  y_low = row0;
  y_high = row1;
  x_low = col0;
  x_high = col1;
}
// Position-sensitive RoI Align, backward pass.
//
// Each loop iteration handles one element of `grad_output`, whose flat `index`
// is decomposed as (n, *, ph, pw) over a
// [num_rois, channels_out, pooled_height, pooled_width] layout. The input
// channel it was pooled from is read from channel_mapping[index]. The bin's
// sampling grid is recomputed exactly as in the forward kernel, and each
// sample's share of the gradient (grad * weight / count) is scattered into
// grad_input with atomic adds. Several outputs may hit the same input pixel,
// so accumulation order, and therefore the floating-point result, is not
// deterministic.
//
// Parameters:
//   nthreads        number of grad_output elements.
//   grad_output     gradient w.r.t. the pooled output.
//   channel_mapping input channel per output element, as recorded by forward.
//   spatial_scale, sampling_ratio, pooled_height, pooled_width
//                   must match the forward call.
//   channels, height, width
//                   C, H, W extents of grad_input (contiguous NCHW indexing).
//   channels_out    number of pooled output channels.
//   grad_input      accumulation target; expected to be zero-initialised by
//                   the caller.
//   rois            rows of 5 values: [batch_index, x1, y1, x2, y2].
//   memory_span     total element count of grad_input; forwarded to
//                   fastAtomicAdd as the buffer bound.
//
// NOTE(review): CUDA_1D_KERNEL_LOOP comes from cuda_helpers.h (not shown); it is
// assumed to stride over all `nthreads` elements. Confirm.
template <typename T>
__global__ void ps_roi_align_backward_kernel_impl(
    int nthreads,
    const T* grad_output,
    const int* channel_mapping,
    const T spatial_scale,
    int channels,
    int height,
    int width,
    int pooled_height,
    int pooled_width,
    int sampling_ratio,
    int channels_out,
    T* grad_input,
    const T* rois,
    const int64_t memory_span) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, *, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int n = index / pooled_width / pooled_height / channels_out;

    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical.
    // The -0.5 shift places samples at pixel centres, matching the forward pass.
    T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
    T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
    T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
    T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);

    // RoI extent is used as-is; no minimum size is enforced.
    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    T bin_size_h = roi_height / static_cast<T>(pooled_height);
    T bin_size_w = roi_width / static_cast<T>(pooled_width);

    // Input channel this output element was pooled from.
    int c_in = channel_mapping[index];

    // Do not use floor/ceil; this implementation detail is critical
    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;

    const T grad_output_this_bin = grad_output[index];

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h = (sampling_ratio > 0)
        ? sampling_ratio
        : ceil(roi_height / pooled_height); // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
    // Clamped to at least 1 so a degenerate RoI can never divide by zero. When
    // the product is <= 0 the sampling loops below do not scatter anything, so
    // the clamp does not change results.
    const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1);

    // Plane offset computed in 64-bit: (batch * C + c_in) * H * W can exceed
    // INT_MAX for large feature maps even though every factor fits in int.
    const int64_t offset =
        (static_cast<int64_t>(roi_batch_ind) * channels + c_in) *
        static_cast<int64_t>(height) * width;

    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
      const T y = hstart +
          static_cast<T>(iy + .5f) * bin_size_h /
              static_cast<T>(roi_bin_grid_h);
      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
        const T x = wstart +
            static_cast<T>(ix + .5f) * bin_size_w /
                static_cast<T>(roi_bin_grid_w);

        T w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;

        bilinear_interpolate_gradient(
            height,
            width,
            y,
            x,
            w1,
            w2,
            w3,
            w4,
            x_low,
            x_high,
            y_low,
            y_high,
            index);

        T g1 = grad_output_this_bin * w1 / count;
        T g2 = grad_output_this_bin * w2 / count;
        T g3 = grad_output_this_bin * w3 / count;
        T g4 = grad_output_this_bin * w4 / count;

        // Indices of -1 mark an out-of-bounds sample (no contribution).
        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          // offset is int64_t and so is memory_span, so fastAtomicAdd's
          // index type is deduced consistently as int64_t.
          at::native::fastAtomicAdd(
              grad_input,
              offset + y_low * width + x_low,
              memory_span,
              static_cast<T>(g1),
              true);
          at::native::fastAtomicAdd(
              grad_input,
              offset + y_low * width + x_high,
              memory_span,
              static_cast<T>(g2),
              true);
          at::native::fastAtomicAdd(
              grad_input,
              offset + y_high * width + x_low,
              memory_span,
              static_cast<T>(g3),
              true);
          at::native::fastAtomicAdd(
              grad_input,
              offset + y_high * width + x_high,
              memory_span,
              static_cast<T>(g4),
              true);
        } // if
      } // ix
    } // iy
  }
}
// Host launcher for the PS-RoI-Align forward pass on CUDA.
//
// Parameters:
//   input          feature map; sizes 1..3 are read as C, H, W, and the
//                  kernel receives a contiguous copy.
//   rois           tensor whose second dimension must be 5
//                  ([batch_index, x1, y1, x2, y2] per row).
//   spatial_scale  factor mapping RoI coordinates onto the feature map.
//   pooled_height, pooled_width
//                  output bin grid; C must be divisible by their product.
//   sampling_ratio samples per bin side; <= 0 means adaptive.
//
// Returns (output, channel_mapping). Both have shape
// [num_rois, C / (pooled_height * pooled_width), pooled_height, pooled_width].
// output has input's dtype; channel_mapping is int32 and records the input
// channel each output element was pooled from (consumed by the backward op).
//
// Raises (via TORCH_CHECK / at::check*) on non-CUDA tensors, mismatched
// devices or dtypes, a bad rois width, non-positive pooled sizes, or channels
// not divisible by pooled_height * pooled_width.
std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel(
    const at::Tensor& input,
    const at::Tensor& rois,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio) {
  // Check if input tensors are CUDA tensors
  TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
  TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
  TORCH_CHECK(
      rois.size(1) == 5, "Tensor rois should have shape as Tensor[K, 5]");

  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};

  at::CheckedFrom c = "ps_roi_align_forward_kernel";
  at::checkAllSameGPU(c, {input_t, rois_t});
  at::checkAllSameType(c, {input_t, rois_t});

  at::cuda::CUDAGuard device_guard(input.device());

  auto num_rois = rois.size(0);
  auto channels = input.size(1);
  auto height = input.size(2);
  auto width = input.size(3);

  // Reject empty pooling grids up front; otherwise the modulo/division below
  // is an integer division by zero on the host.
  TORCH_CHECK(
      pooled_height > 0 && pooled_width > 0,
      "pooled_height and pooled_width must be positive");
  TORCH_CHECK(
      channels % (pooled_height * pooled_width) == 0,
      "input channels must be a multiple of pooling height * pooling width");
  int channels_out = channels / (pooled_height * pooled_width);

  auto output = at::zeros(
      {num_rois, channels_out, pooled_height, pooled_width}, input.options());
  auto channel_mapping =
      at::zeros(output.sizes(), input.options().dtype(at::kInt));

  auto output_size = output.numel();
  if (output_size == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return std::make_tuple(output, channel_mapping);
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // 512 threads per block, at most 4096 blocks; the kernel loop is expected to
  // cover any elements beyond grid * block.
  dim3 grid(std::min(
      ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)),
      static_cast<int64_t>(4096)));
  dim3 block(512);

  auto input_ = input.contiguous(), rois_ = rois.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "ps_roi_align_forward_kernel", [&] {
        ps_roi_align_forward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
            output_size,
            input_.data_ptr<scalar_t>(),
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            sampling_ratio,
            rois_.data_ptr<scalar_t>(),
            channels_out,
            output.data_ptr<scalar_t>(),
            channel_mapping.data_ptr<int>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
  // No device-wide synchronization here. The kernel is enqueued on the current
  // stream, so later work on that stream already observes its results. The
  // former unchecked cudaDeviceSynchronize() stalled the host on every stream
  // and silently dropped any error it reported.
  return std::make_tuple(output, channel_mapping);
}
// Host launcher for the PS-RoI-Align backward pass on CUDA.
//
// Parameters:
//   grad            gradient w.r.t. the forward output.
//   rois            the RoIs passed to the forward call.
//   channel_mapping int32 tensor produced by the forward call; one entry per
//                   element of `grad`.
//   spatial_scale, pooled_height, pooled_width, sampling_ratio
//                   must match the forward call.
//   batch_size, channels, height, width
//                   shape of the gradient to produce.
//
// Returns a zero-initialised [batch_size, channels, height, width] tensor
// (grad's dtype) into which the per-sample gradients are atomically
// accumulated. Because of those atomics the op is flagged as nondeterministic.
//
// Raises (via TORCH_CHECK / at::check*) on non-CUDA tensors, mismatched
// devices or dtypes, non-positive pooled sizes, or a channel_mapping whose
// element count differs from grad's.
at::Tensor ps_roi_align_backward_kernel(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& channel_mapping,
    double spatial_scale,
    int64_t pooled_height,
    int64_t pooled_width,
    int64_t sampling_ratio,
    int64_t batch_size,
    int64_t channels,
    int64_t height,
    int64_t width) {
  // Check if input tensors are CUDA tensors
  TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor");
  TORCH_CHECK(rois.is_cuda(), "rois must be a CUDA tensor");
  TORCH_CHECK(
      channel_mapping.is_cuda(), "channel_mapping must be a CUDA tensor");

  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
      channel_mapping_t{channel_mapping, "channel_mapping", 3};

  at::CheckedFrom c = "ps_roi_align_backward_kernel";
  at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
  at::checkAllSameType(c, {grad_t, rois_t});

  at::cuda::CUDAGuard device_guard(grad.device());

  auto grad_input =
      at::zeros({batch_size, channels, height, width}, grad.options());

  // handle possibly empty gradients
  if (grad.numel() == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return grad_input;
  }

  // Guard the integer division below against a zero divisor.
  TORCH_CHECK(
      pooled_height > 0 && pooled_width > 0,
      "pooled_height and pooled_width must be positive");
  // The kernel reads channel_mapping[i] for every flat index i of grad, so the
  // two must have the same number of elements.
  TORCH_CHECK(
      channel_mapping.numel() == grad.numel(),
      "channel_mapping must have the same number of elements as grad");

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // 512 threads per block, at most 4096 blocks; the kernel loop is expected to
  // cover any elements beyond grid * block.
  dim3 grid(std::min(
      ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)),
      static_cast<int64_t>(4096)));
  dim3 block(512);

  int channels_out = channels / (pooled_height * pooled_width);

  at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel");

  // channel_mapping is indexed with grad_'s contiguous flat index, so it must
  // be contiguous as well.
  auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
  auto channel_mapping_ = channel_mapping.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad.scalar_type(), "ps_roi_align_backward_kernel", [&] {
        ps_roi_align_backward_kernel_impl<scalar_t><<<grid, block, 0, stream>>>(
            grad.numel(),
            grad_.data_ptr<scalar_t>(),
            channel_mapping_.data_ptr<int>(),
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            sampling_ratio,
            channels_out,
            grad_input.data_ptr<scalar_t>(),
            rois_.data_ptr<scalar_t>(),
            grad_input.numel());
      });
  AT_CUDA_CHECK(cudaGetLastError());
  return grad_input;
}
- } // namespace
// Registers the CUDA implementations of the PS-RoI-Align operators with the
// PyTorch dispatcher for the "torchvision" library under the CUDA dispatch key.
// NOTE(review): the operator schemas are assumed to be declared by a matching
// TORCH_LIBRARY(torchvision, ...) block elsewhere; it is not visible here.
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
  // Forward op: returns (output, channel_mapping).
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"),
      TORCH_FN(ps_roi_align_forward_kernel));
  // Backward helper: scatters grad back into an input-shaped tensor.
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"),
      TORCH_FN(ps_roi_align_backward_kernel));
}
- } // namespace ops
- } // namespace vision
|