#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Parallel.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/ConvolutionMM3d.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/Pool.h>
#include <ATen/native/cpu/DepthwiseConvKernel.h>
#include <ATen/native/utils/ParamUtils.h>
#include <ATen/native/xnnpack/Engine.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <c10/macros/Macros.h>
#include <limits>
#include <utility>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/permute.h>
#endif

#if AT_NNPACK_ENABLED()
#include <nnpack.h>
#endif

#if AT_MKLDNN_ENABLED()
#include <ATen/native/mkldnn/Utils.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_conv_depthwise2d.h>
#include <ATen/ops/_convolution.h>
#include <ATen/ops/_convolution_double_backward_native.h>
#include <ATen/ops/_convolution_mode.h>
#include <ATen/ops/_convolution_mode_native.h>
#include <ATen/ops/_convolution_native.h>
#include <ATen/ops/_mps_convolution.h>
#include <ATen/ops/_mps_convolution_transpose.h>
#include <ATen/ops/_nnpack_available.h>
#include <ATen/ops/_nnpack_spatial_convolution.h>
#include <ATen/ops/_slow_conv2d_backward.h>
#include <ATen/ops/_unsafe_view.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/constant_pad_nd.h>
#include <ATen/ops/conv1d_native.h>
#include <ATen/ops/conv2d_native.h>
#include <ATen/ops/conv3d_native.h>
#include <ATen/ops/conv_depthwise3d.h>
#include <ATen/ops/conv_transpose1d_native.h>
#include <ATen/ops/conv_transpose2d_native.h>
#include <ATen/ops/conv_transpose3d_native.h>
#include <ATen/ops/convolution.h>
#include <ATen/ops/convolution_backward_native.h>
#include <ATen/ops/convolution_backward_overrideable.h>
#include <ATen/ops/convolution_backward_overrideable_native.h>
#include <ATen/ops/convolution_native.h>
#include <ATen/ops/convolution_overrideable.h>
#include <ATen/ops/convolution_overrideable_native.h>
#include <ATen/ops/cudnn_convolution.h>
#include <ATen/ops/cudnn_convolution_transpose.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/miopen_convolution.h>
#include <ATen/ops/miopen_convolution_transpose.h>
#include <ATen/ops/miopen_depthwise_convolution.h>
#include <ATen/ops/mkldnn_convolution.h>
#include <ATen/ops/mps_convolution_backward.h>
#include <ATen/ops/mps_convolution_transpose_backward.h>
#include <ATen/ops/slow_conv3d.h>
#include <ATen/ops/slow_conv_dilated2d.h>
#include <ATen/ops/slow_conv_dilated3d.h>
#include <ATen/ops/slow_conv_transpose2d.h>
#include <ATen/ops/slow_conv_transpose3d.h>
#include <ATen/ops/thnn_conv2d.h>
#include <ATen/ops/view_as_real.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

constexpr int MIOPEN_DIM_MAX = 5;

namespace at { namespace native {

// Check workload to activate fast depthwise FP16 cudnn conv kernels
template <typename T>
bool check_cudnn_depthwise_workload(const at::Tensor& input, int stride) {
  auto w = at::symint::size<T>(input, 3);  // same as h
  auto ch = at::symint::size<T>(input, 1);
  auto bs = at::symint::size<T>(input, 0);
  if (stride == 1) {
    if (w >= 7) {
      // All batch sizes and nb_channels
      if (w >= 112) {
        return true;
      }

      // large nb_channels
      if (ch >= 1024) {
        // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
        if (w >= 56) {
          return true;
        } else if (bs >= 32) {
          return true;
        }
      }

      // batch_size specific
      if (bs >= 128) {
        // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
        if (ch >= 512) {
          return true;
        } else if (ch >= 64) {
          if (w >= 14) {
            return true;
          }
        } else if ((ch >= 32) && (w >= 28)) {
          return true;
        }
      } else if (bs >= 64) {
        // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
        if ((ch >= 256) && (w >= 14)) {
          return true;
        } else if ((ch >= 32) && (w >= 28)) {
          return true;
        }
      } else if (bs >= 32) {
        // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
        if ((ch >= 256) && (w >= 14)) {
          return true;
        } else if ((ch >= 128) && (w >= 28)) {
          return true;
        } else if ((ch >= 32) && (w >= 56)) {
          return true;
        }
      } else if (bs >= 16) {
        if ((ch >= 1024) && (w >= 14)) {
          return true;
        }
        // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
        if ((ch >= 256) && (w >= 28)) {
          return true;
        } else if ((ch >= 32) && (w >= 56)) {
          return true;
        }
      } else if (bs >= 8) {
        // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
        if ((ch >= 512) && (w >= 28)) {
          return true;
        } else if ((ch >= 64) && (w >= 56)) {
          return true;
        }
      }
    }
  } else if (stride == 2) {
    if (ch < 256) {
      return false;
    }

    if (w >= 7) {
      if (bs >= 128) {
        // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
        if (ch >= 1024) {
          return true;
        } else if ((ch >= 512) && (w >= 14)) {
          return true;
        } else if (w >= 28) {
          return true;
        }
      } else if (bs >= 64) {
        // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
        if ((ch >= 512) && (w >= 14)) {
          return true;
        } else if (w >= 28) {
          return true;
        }
      } else if (bs >= 32) {
        // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
        if ((ch >= 1024) && (w >= 14)) {
          return true;
        } else if (w >= 28) {
          return true;
        }
      } else if (bs >= 16) {
        // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
        if ((ch >= 512) && (w >= 28)) {
          return true;
        } else if (w >= 56) {
          return true;
        }
      } else if (bs >= 8) {
        // NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
        if ((ch >= 1024) && (w >= 28)) {
          return true;
        } else if (w >= 56) {
          return true;
        }
      } else if (bs >= 1) {
        if ((ch >= 512) && (w >= 112)) {
          return true;
        }
      }
    }
  }
  return false;
}

// simplified version for cudnn 8.2 and above
template <typename T>
bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, int stride, const at::Tensor& weight) {
  // 1D conv
  if (at::symint::size<T>(input, 2) == 1 && stride == 1) {
    return true;
  }

  // 2d conv
  // only square filters
  if (at::symint::size<T>(weight, 2) != at::symint::size<T>(weight, 3)) return false;
  auto filter = at::symint::size<T>(weight, 3);
  // only 1/3/5 filter
  if (filter != 1 && filter != 3 && filter != 5) return false;
  // we don't enforce square input but only check width to reduce heuristic space
  if (at::symint::size<T>(input, 3) < 7) return false;  // min width 7
  auto w = at::symint::size<T>(input, 3);
  // only 1/2 stride, use cudnn for all stride 1
  if (stride == 1) return true;
  if (stride != 2) return false;

  auto ch = at::symint::size<T>(input, 1);
  auto bs = at::symint::size<T>(input, 0);
  // special case since bs = 1 shows good perf in lots of cases
  if (bs == 1) {
    if (filter == 1 && w <= 28) return true;
    if (filter == 3 || filter == 5) return true;
  } else {
    if (filter == 1 && bs <= 16 && ch >= 128 && w <= 7) return true;
    if (filter == 3 || filter == 5) {
      if ((ch >= 512) || (ch >= 256 && w >= 28)) return true;
    }
  }
  return false;
}


bool xnnpack_use_convolution2d(
    const Tensor& input,
    const Tensor& weight,
    const at::OptionalIntArrayRef bias_sizes_opt,
    const IntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool transposed) {
  return xnnpack::use_convolution2d(input, weight, bias_sizes_opt, padding, stride, dilation, groups, transposed);
}

bool xnnpack_use_convolution2d(
    const Tensor& input,
    const Tensor& weight,
    const at::OptionalSymIntArrayRef bias_sizes_opt,
    const SymIntArrayRef padding,
    const IntArrayRef stride,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool transposed) {
  // Never use xnnpack for symbolic tracing
  return false;
}

// This struct is templated so that we can run backend selection in a dynamic
// shapes context; all of the real kernel selection in eager mode runs with
// int64_t
template <typename T>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct ConvParams {
  std::vector<int64_t> stride;
  std::vector<T> padding;
  std::vector<int64_t> dilation;
  bool transposed;
  std::vector<T> output_padding;
  int groups;
  bool benchmark;
  bool deterministic;
  bool cudnn_enabled;
  bool allow_tf32;

  bool is_strided() const {
    bool is_strided = false;
    for (auto s : stride) {
      is_strided |= (s != 1);
    }
    return is_strided;
  }

  bool is_dilated() const {
    bool is_dilated = false;
    for (auto d : dilation) {
      is_dilated |= (d != 1);
    }
    return is_dilated;
  }

  bool is_padded() const {
    bool is_padded = false;
    for (auto p : padding) {
      is_padded |= (p != 0);
    }
    return is_padded;
  }

  bool is_output_padding_neg() const {
    bool is_neg = false;
    for (const auto& p : output_padding) {
      is_neg |= (p < 0);
    }
    return is_neg;
  }

  bool is_output_padding_big() const {
    bool is_big = false;
    for (auto i : c10::irange(output_padding.size())) {
      is_big |= (output_padding[i] >= stride[i]);
    }
    return is_big;
  }

  bool is_padding_neg() const {
    bool is_neg = false;
    for (const auto& p : padding) {
      is_neg |= (p < 0);
    }
    return is_neg;
  }

  bool is_stride_nonpos() const {
    bool is_nonpos = false;
    for (auto s : stride) {
      is_nonpos |= (s <= 0);
    }
    return is_nonpos;
  }

  void view1d_as_2d() {
    if (stride.size() == 1) {
      stride.insert(stride.begin(), 1);
      padding.insert(padding.begin(), 0);
      dilation.insert(dilation.begin(), 1);
      output_padding.insert(output_padding.begin(), 0);
    }
  }
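
  // Illustrative example (not exercised anywhere in this file): a conv1d
  // call with stride = {2}, padding = {1}, dilation = {1},
  // output_padding = {0} is rewritten by view1d_as_2d() into the 2d
  // parameters stride = {1, 2}, padding = {0, 1}, dilation = {1, 1},
  // output_padding = {0, 0}; combined with view4d() later in this file,
  // which unsqueezes input and weight to 4-D, the 1d convolution then runs
  // as a height-1 2d convolution.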

  bool use_cpu_depthwise3x3_winograd(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias) const {
#if defined(__ARM_NEON__)
    // Currently only 3x3 depthwise convolutions on tensors of float are supported.
    return (input.ndimension() == 4) &&
           (at::symint::size<T>(input, 1) == groups) &&
           (weight.ndimension() == 4) &&
           (at::symint::size<T>(weight, 0) % at::symint::size<T>(input, 1) == 0) &&
           (at::symint::size<T>(weight, 1) == 1) &&
           (at::symint::size<T>(weight, 2) == 3) &&
           (at::symint::size<T>(weight, 3) == 3) &&
           (input.device().is_cpu()) &&
           (input.scalar_type() == at::kFloat) &&
           input.is_contiguous() &&
           (weight.device().is_cpu()) &&
           (weight.scalar_type() == at::kFloat) &&
           weight.is_contiguous() &&
           (!bias.has_value() || bias->is_contiguous()) &&
           !is_strided() &&
           !is_dilated() &&
           !transposed;
#else
    return false;
#endif
  }

  bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const {
    constexpr int64_t int_max = std::numeric_limits<int>::max();
    auto numel_input = at::symint::numel<T>(input);
    // empty input
    if (numel_input == 0) {
      return false;
    }
    // input size can not be reduced to the range of int by splitting the batch dim
    auto n = at::symint::size<T>(input, 0);
    if (numel_input / n > int_max) {
      return true;
    }
    // output size can not be reduced to the range of int by splitting the batch dim
    T outsize = 1;
    if (transposed) {
      auto o = conv_input_size(at::symint::sizes<T>(input), at::symint::sizes<T>(weight), padding, output_padding, stride, dilation, groups);
      outsize = c10::multiply_integers(o.begin() + 1, o.end());
    } else {
      auto o = conv_output_size(at::symint::sizes<T>(input), at::symint::sizes<T>(weight), padding, stride, dilation);
      outsize = c10::multiply_integers(o.begin() + 1, o.end());
    }
    return outsize > int_max;
  }
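
  // Worked example (illustrative): a float NCHW input of shape
  // [1, 64, 8192, 4096] has 64 * 8192 * 4096 = 2^31 elements; with batch
  // size 1, splitting over the batch dim cannot shrink the per-sample size,
  // numel / n exceeds int_max, and this returns true. Reshaped to
  // [2, 64, 8192, 2048], numel / n = 2^30 fits in int32, so the input check
  // passes and the analogous check on the conv output size decides.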

  bool use_cudnn(const at::Tensor& input, const at::Tensor& weight) const {
    // Note [Mobile check segfaults]
    // cudnn and miopen are guaranteed not to be on mobile, and T102591915 / T110194934 suggest
    // that maybe the compiledWithCuDNN() check sometimes segfaults (though I can't imagine how)
#if !defined(C10_MOBILE)
    if (needs_64bit_indexing_no_split(input, weight)) {
      return false;
    }
    if (!detail::getCUDAHooks().compiledWithCuDNN()) {
      return false;
    }
    if (!input.is_cuda() || !cudnn_enabled) {
      return false;
    }
    if (input.scalar_type() == at::kBFloat16 || weight.scalar_type() == at::kBFloat16) {
      if (!(detail::getCUDAHooks().supportsBFloat16ConvolutionWithCuDNNv8() && at::native::cudnnv8_enabled_check_debug())) {
        return false;
      }
    }
    if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous) {
      // bypass dilation checks for channels_last convolution
      if (deterministic && is_dilated()) {
        // cudnn doesn't support deterministic dilated convolution fully yet
        return false;
      }
      if (is_dilated()) {
        return detail::getCUDAHooks().supportsDilatedConvolutionWithCuDNN() && !is_output_padding_big();
      }
    }
    return !is_output_padding_big();
#else
    return false;
#endif
  }

  // Use cudnn for FP16 depthwise convolutions
  bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const {
    if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous && use_cudnn(input, weight)) {
      // always use cudnn_depthwise for channels_last format
      return true;
    }
    if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) {
      long cudnn_version = detail::getCUDAHooks().versionCuDNN();
      if (cudnn_version >= 8200) {
        bool kernel_cond = (use_cudnn(input, weight) &&
                            input.scalar_type() == kHalf &&  // only for FP16
                            weight.scalar_type() == kHalf &&
                            is_depthwise(input, weight) &&
                            input.ndimension() == 4 &&  // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
                            !is_dilated() &&  // no dilation supported
                            (stride[0] == stride[1] || at::symint::size<T>(input, 2) == 1) &&  // square or 1d
                            at::symint::size<T>(input, 1) >= 32);  // min 32 channels supported
        if (kernel_cond) {
          return check_cudnn_depthwise_workload_with_filter<T>(input, stride[1], weight);
        }
      }
      // keep (7600 <= cudnn < 8200) code unchanged
      bool kernel_cond = (cudnn_version >= 7600 &&
                          use_cudnn(input, weight) &&
                          input.scalar_type() == kHalf &&  // only for FP16
                          weight.scalar_type() == kHalf &&
                          is_depthwise(input, weight) &&
                          input.ndimension() == 4 &&  // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
                          at::symint::size<T>(weight, 2) == at::symint::size<T>(weight, 3) &&  // only square kernels
                          at::symint::size<T>(input, 2) >= 7 &&  // min width/height 7
                          !is_dilated() &&  // no dilation supported
                          stride[0] == stride[1] &&  // equal strides
                          ((at::symint::size<T>(weight, 3) == 3) || (at::symint::size<T>(weight, 3) == 1)) &&
                          at::symint::size<T>(input, 1) >= 32);  // min 32 channels supported
      if (kernel_cond) {
        return check_cudnn_depthwise_workload<T>(input, stride[0]);
      } else {
        return false;
      }
    } else {
      return false;
    }
  }

  bool use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const {
    if (needs_64bit_indexing_no_split(input, weight)) {
      return false;
    }
    return ((input.scalar_type() == at::kFloat) || (input.scalar_type() == at::kHalf) || (input.scalar_type() == at::kBFloat16))
           && detail::getCUDAHooks().compiledWithMIOpen()
           && input.is_cuda()
           && input.dim() <= MIOPEN_DIM_MAX
           && !(groups > 1 && is_dilated())  // MIOpen currently does not support dilation with groups of size > 1
           && !(input.scalar_type() == at::kBFloat16 && bias_defined)  // MIOpen currently doesn't support bias with bfloat16
           && cudnn_enabled;
  }

  bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const {
#if AT_MKLDNN_ENABLED()
    if (!at::globalContext().userEnabledMkldnn()) {
      return false;
    }
    if (transposed && is_output_padding_big()) {
      return false;
    }
    if (transposed && groups > 1 && at::symint::size<T>(input, 1) == groups) {
      return false;
    }
    if (input.device().is_cpu() && input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) {
      return true;
    }
    return (input.is_mkldnn()) ||  // input is mkldnn Tensor
           (input.device().is_cpu() &&
            input.scalar_type() == kFloat &&  // only on CPU Float Tensors
            // For 1x1 filters, MKLDNN is faster than THNN when multi-threaded,
            // but THNN is faster when single-threaded.
            (is_strided() || is_dilated() || at::symint::size<T>(input, 0) >= 16 ||
             at::symint::size<T>(weight, -1) != 1 || at::symint::size<T>(weight, -2) != 1 || at::get_num_threads() > 1) &&
            (groups > 1
             || (at::symint::size<T>(weight, -1) > 3 && at::symint::size<T>(weight, -2) > 3)
             || at::symint::size<T>(input, 0) > 1
             || at::symint::size<T>(input, 0) * at::symint::size<T>(input, 1) * at::symint::size<T>(input, 2) * at::symint::size<T>(input, 3) > 20480)  // for some cases, native is faster
           );

#endif
    return false;
  }

  bool use_nnpack(const at::Tensor& input, const at::Tensor& weight) const {
#if AT_NNPACK_ENABLED()
    return at::_nnpack_available() &&
           input.device().is_cpu() &&
           input.scalar_type() == kFloat &&  // only on CPU Float Tensors
           !is_dilated() &&  // or dilation
           !transposed &&  // or transposed tensors
           input.ndimension() == 4 &&  // must be in NCHW format
           weight.ndimension() == 4 &&
           (at::symint::size<T>(weight, 2) < 17) && (at::symint::size<T>(weight, 3) < 17)  // NNPACK only supports kernels up to 16x16
#if !defined(C10_MOBILE)
           && at::symint::size<T>(input, 0) >= 16  // batch size must be large enough for good perf; tunable
#endif
        ;
#endif
    return false;
  }

  bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight,
                   const at::OptionalArrayRef<T> bias_sizes_opt) const {
#if defined(C10_MOBILE)
    if (!transposed) {
      // NB: for the call here, it MATTERS that we are templated. If you
      // untemplate this to always use SymInt, the function
      // xnnpack_use_convolution2d will always return false
      return (at::symint::size<T>(input, 1) == groups) &&
             xnnpack_use_convolution2d(
                 input,
                 weight,
                 bias_sizes_opt,
                 padding,
                 stride,
                 dilation,
                 groups,
                 transposed);
    }
#endif
    return false;
  }

  bool use_mps(const at::Tensor& input, const at::Tensor& weight) const {
    // These checks need to be expanded. Currently we have a very limited set
    // of checks for MPS.
#ifdef USE_MPS
    if (needs_64bit_indexing_no_split(input, weight)) {
      return false;
    }
    if (!input.is_mps()) {
      return false;
    }
    return true;
#else
    return false;
#endif
  }

  // We currently only have depthwise support for the case where groups ==
  // nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of
  // a depthwise multiplier)
  bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const {
    return input.is_cuda() &&
           !transposed &&
           (input.ndimension() == 4 || input.ndimension() == 5) &&
           at::symint::size<T>(input, 1) == groups &&
           groups > 1 &&  // no point if there is only a single group
           at::symint::size<T>(weight, 0) % at::symint::size<T>(input, 1) == 0;  // output channels must be a multiple of input channels
  }
};
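
// Illustrative sketch (not part of the dispatch logic; it mirrors how
// select_conv_backend below fills the struct): a plain conv2d with a 3x3
// kernel, stride 1, padding 1 would be described by a ConvParams<int64_t>
// with stride = {1, 1}, padding = {1, 1}, dilation = {1, 1},
// transposed = false, output_padding = {0, 0}, groups = 1, and the
// benchmark / deterministic / cudnn_enabled / allow_tf32 flags copied from
// at::globalContext().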

DEFINE_DISPATCH(conv_depthwise2d_backward_stub);
DEFINE_DISPATCH(conv_depthwise3d_backward_stub);
DEFINE_DISPATCH(cudnn_convolution_backward_stub);
DEFINE_DISPATCH(cudnn_convolution_transpose_backward_stub);
DEFINE_DISPATCH(slow_conv_transpose3d_backward_stub);
DEFINE_DISPATCH(convolution_depthwise3x3_winograd_stub);
DEFINE_DISPATCH(miopen_convolution_backward_stub);
DEFINE_DISPATCH(miopen_convolution_transpose_backward_stub);
DEFINE_DISPATCH(miopen_depthwise_convolution_backward_stub);
DEFINE_DISPATCH(mkldnn_convolution_backward_stub);
DEFINE_DISPATCH(mkldnn_convolution_transpose_stub);
DEFINE_DISPATCH(mkldnn_convolution_transpose_backward_stub);
DEFINE_DISPATCH(slow_conv_dilated2d_backward_stub);
DEFINE_DISPATCH(slow_conv_dilated3d_backward_stub);
DEFINE_DISPATCH(slow_conv_transpose2d_backward_stub);
REGISTER_NO_CPU_DISPATCH(conv_depthwise2d_backward_stub);
REGISTER_NO_CPU_DISPATCH(conv_depthwise3d_backward_stub);
REGISTER_NO_CPU_DISPATCH(cudnn_convolution_backward_stub);
REGISTER_NO_CPU_DISPATCH(cudnn_convolution_transpose_backward_stub);
REGISTER_NO_CPU_DISPATCH(miopen_convolution_backward_stub);
REGISTER_NO_CPU_DISPATCH(miopen_convolution_transpose_backward_stub);
REGISTER_NO_CPU_DISPATCH(miopen_depthwise_convolution_backward_stub);

template <typename T>
std::ostream& operator<<(std::ostream& out, const ConvParams<T>& params) {
  out << "ConvParams {"
      << " stride = " << IntArrayRef{params.stride}
      << " padding = " << ArrayRef<T>{params.padding}
      << " dilation = " << IntArrayRef{params.dilation}
      << " transposed = " << params.transposed
      << " output_padding = " << ArrayRef<T>{params.output_padding}
      << " groups = " << params.groups
      << " benchmark = " << params.benchmark
      << " deterministic = " << params.deterministic
      << " cudnn_enabled = " << params.cudnn_enabled
      << " allow_tf32 = " << params.allow_tf32
      << "}";
  return out;
}

template <typename T>
static void check_shape_forward(const at::Tensor& input,
                                const c10::ArrayRef<T>& weight_sizes, const at::Tensor& bias,
                                const ConvParams<T>& params) {
  int64_t k = input.ndimension();
  int64_t weight_dim = weight_sizes.size();
  int64_t groups = params.groups;
  const auto& padding = params.padding;
  const auto& dilation = params.dilation;
  bool transposed = params.transposed;

  TORCH_CHECK(!params.is_padding_neg(), "negative padding is not supported");
  TORCH_CHECK(!params.is_output_padding_neg(), "negative output_padding is not supported");
  TORCH_CHECK(!params.is_stride_nonpos(), "non-positive stride is not supported");

  TORCH_CHECK(weight_dim == k,
              "Expected ", weight_dim, "-dimensional input for ", weight_dim,
              "-dimensional weight ", weight_sizes, ", but got ", k, "-dimensional input of size ",
              at::symint::sizes<T>(input), " instead");
  TORCH_CHECK(weight_sizes[0] >= groups,
              "Given groups=", groups, ", expected weight to be at least ", groups,
              " at dimension 0, but got weight of size ", weight_sizes, " instead");
  TORCH_CHECK(weight_sizes[0] % groups == 0,
              "Given groups=", groups, ", expected weight to be divisible by ",
              groups, " at dimension 0, but got weight of size [", weight_sizes,
              "] instead");

  if (!transposed) {
    std::vector<T> input_shape;
    std::vector<T> kernel_shape;
    bool kernel_size_correct = true;

    TORCH_CHECK(at::symint::size<T>(input, 1) == (weight_sizes[1] * groups),
                "Given groups=", groups, ", weight of size ", weight_sizes,
                ", expected input", at::symint::sizes<T>(input), " to have ",
                (weight_sizes[1] * groups), " channels, but got ", at::symint::size<T>(input, 1),
                " channels instead");

    TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size<T>(bias, 0) == weight_sizes[0]),
                "Given weight of size ", weight_sizes,
                ", expected bias to be 1-dimensional with ", weight_sizes[0], " elements",
                ", but got bias of size ", at::symint::sizes<T>(bias), " instead");

    for (const auto i : c10::irange(2, k)) {
      input_shape.push_back(at::symint::size<T>(input, i) + 2 * padding[i - 2]);
      // log new kernel size considering dilation
      kernel_shape.push_back(dilation[i - 2] * (weight_sizes[i] - 1) + 1);
      if (input_shape.back() < kernel_shape.back()) {
        kernel_size_correct = false;
      }
    }

    TORCH_CHECK(input_shape.size() == kernel_shape.size(), "Inconsistent shape between Input and Kernel");

    if (!kernel_size_correct) {
      // If kernel size is incorrect
      std::ostringstream input_ss;
      std::ostringstream kernel_ss;
      std::string separator = "";

      for (int i = 0, len = input_shape.size(); i < len; ++i) {
        input_ss << separator << input_shape[i];
        kernel_ss << separator << kernel_shape[i];
        separator = " x ";
      }

      AT_ERROR("Calculated padded input size per channel: (", input_ss.str(), "). "
               "Kernel size: (", kernel_ss.str(), "). Kernel size can't be greater than actual input size");
    }
  } else {  // transposed
    TORCH_CHECK(at::symint::size<T>(input, 1) == weight_sizes[0],
                "Given transposed=", transposed, ", weight of size ", weight_sizes,
                ", expected input", at::symint::sizes<T>(input), " to have ", weight_sizes[0],
                " channels, but got ", at::symint::size<T>(input, 1), " channels instead");
    TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size<T>(bias, 0) == weight_sizes[1] * groups),
                "Given transposed=", transposed, ", weight of size ", weight_sizes,
                ", expected bias to be 1-dimensional with ", weight_sizes[1] * groups, " elements",
                ", but got bias of size ", at::symint::sizes<T>(bias), " instead");
  }
}

template <typename T>
static void check_shape_backward(
    const at::Tensor& input,
    const c10::ArrayRef<T>& weight_sizes,
    const ConvParams<T>& params) {
  check_shape_forward<T>(input, weight_sizes, /*bias=*/ Tensor(), params);
}

// Given an input tensor and an expected number of spatial dimensions, checks that the
// input is a valid shape and returns the batched form of the input.
//
// Args:
//     input (Tensor): Input tensor
//     num_spatial_dims (int): Number of spatial dimensions expected for the input
//     func_name (string): Function name to produce a nice error message for invalid input
//
// Returns a std::tuple containing:
//     batched_input (Tensor): Input with a batch dimension
//     is_batched (bool): Indicates whether the original input was already batched
static std::tuple<Tensor, bool> batchify(
    const Tensor& input,
    const int64_t num_spatial_dims,
    const std::string& func_name) {
  const auto dim_count_no_batch = num_spatial_dims + 1;
  const auto dim_count_batch = dim_count_no_batch + 1;
  const auto is_batched = (input.dim() == dim_count_batch);
  TORCH_CHECK(input.dim() == dim_count_no_batch || is_batched,
              "Expected ", dim_count_no_batch, "D (unbatched) or ", dim_count_batch,
              "D (batched) input to ", func_name, ", but got input of size: ", input.sizes());
  return std::make_tuple(is_batched ? input : input.unsqueeze(0), is_batched);
}
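
// Example (illustrative): for conv2d, num_spatial_dims == 2, so a (C, H, W)
// input is accepted and unsqueezed to (1, C, H, W) with is_batched == false,
// letting the caller squeeze the batch dimension back out of the output,
// while an (N, C, H, W) input passes through unchanged with
// is_batched == true. Any other rank trips the TORCH_CHECK above.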

static void check_input_same_type_as_parameters(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias) {
  TORCH_CHECK(input.options().type_equal(weight.options()),
              "Input type (", input.toString(), ") and weight type (", weight.toString(),
              ") should be the same");
  TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
              "Input type (", input.toString(), ") and bias type (", bias.toString(),
              ") should be the same");
}

static void check_input_same_type_as_parameters(
    const Tensor& input,
    const Tensor& weight) {
  check_input_same_type_as_parameters(input, weight, /*bias=*/ Tensor());
}

static void check_input_same_type_as_parameters(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const ConvBackend backend) {
  if (backend == ConvBackend::Mkldnn || backend == ConvBackend::MkldnnTranspose) {
    TORCH_CHECK(input.options().type_equal(weight.options())
                || (input.is_mkldnn() && weight.device().is_cpu() && weight.scalar_type() == kFloat),
                "Input type (", input.toString(), ") and weight type (", weight.toString(),
                ") should be the same or input should be a MKLDNN tensor and weight is a dense tensor");
    TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options()))
                || (input.is_mkldnn() && bias.device().is_cpu() && bias.scalar_type() == kFloat),
                "Input type (", input.toString(), ") and bias type (", bias.toString(),
                ") should be the same or input should be a MKLDNN tensor and bias is a dense tensor");
  } else {
    check_input_same_type_as_parameters(input, weight, bias);
  }
}

static auto view4d(const at::Tensor& tensor) -> at::Tensor {
  TORCH_CHECK(tensor.ndimension() == 3,
              "expected 3D tensor, got tensor with ", tensor.ndimension(),
              " dimensions instead");
  return tensor.unsqueeze(2);
}

static auto view3d(const at::Tensor& tensor) -> at::Tensor {
  TORCH_CHECK(tensor.ndimension() == 4,
              "expected 4D tensor, got tensor with ", tensor.ndimension(),
              " dimensions instead");
  return tensor.squeeze(2);
}

static at::Tensor subtensor(at::Tensor& tensor, int dim, int groups, int g) {
  if (!tensor.defined()) {
    return at::Tensor();
  }
  const auto memory_format = tensor.suggest_memory_format();
  int64_t n = tensor.sizes()[dim] / groups;
  return tensor.narrow(dim, n * g, n).contiguous(memory_format);
}
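
// Example (illustrative): with groups == 2 and a weight of shape
// [64, 8, 3, 3], subtensor(weight, /*dim=*/0, /*groups=*/2, g) returns the
// contiguous [32, 8, 3, 3] slice for group g in {0, 1}, i.e.
// weight.narrow(0, 32 * g, 32).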

namespace {

std::pair<Tensor, Tensor> complex_to_real(const Tensor& inp) {
  auto inp_view_as_real = at::view_as_real(inp);
  auto dim_i = inp_view_as_real.dim() - 1;
  auto i_r = inp_view_as_real.select(dim_i, 0);
  auto i_i = inp_view_as_real.select(dim_i, 1);
  return std::make_pair(i_r, i_i);
}

at::Tensor complex_convolution(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups) {
  check_input_same_type_as_parameters(input, weight, bias);
  Tensor i_r, i_i, w_r, w_i;
  std::tie(i_r, i_i) = complex_to_real(input.resolve_conj());
  std::tie(w_r, w_i) = complex_to_real(weight.resolve_conj());

  // [NOTE] Complex Convolution
  // conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0))
  // where W, x and b are all complex inputs.
  // With the Gauss trick:
  // a = conv(Wr, xr, br),
  // b = conv(Wi, xi, 0),
  // c = conv(Wr + Wi, xr + xi, bi + br)
  // conv(W, x, b) = a - b + i(c - a - b)
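  // Sanity check of the trick (expanding the products; b carries no bias):
  //   c - a - b = (Wr + Wi) * (xr + xi) + (br + bi)
  //             - (Wr * xr + br)
  //             - (Wi * xi)
  //             = Wi * xr + Wr * xi + bi,
  // which is exactly the imaginary part conv(Wi, xr, bi) + conv(Wr, xi, 0),
  // while a - b = Wr * xr - Wi * xi + br is the real part, so three real
  // convolutions suffice instead of four.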
  Tensor a, b, c;
  if (!bias.defined()) {
    a = at::convolution(i_r, w_r, bias, stride, padding, dilation, transposed, output_padding, groups);
    b = at::convolution(i_i, w_i, bias, stride, padding, dilation, transposed, output_padding, groups);
    c = at::convolution(i_r + i_i, w_r + w_i, bias, stride, padding, dilation, transposed, output_padding, groups);
  } else {
    Tensor b_r, b_i;
    std::tie(b_r, b_i) = complex_to_real(bias.resolve_conj());
    a = at::convolution(i_r, w_r, b_r, stride, padding, dilation, transposed, output_padding, groups);
    b = at::convolution(i_i, w_i, Tensor(), stride, padding, dilation, transposed, output_padding, groups);
    c = at::convolution(i_r + i_i, w_r + w_i, b_r + b_i, stride, padding, dilation, transposed, output_padding, groups);
  }

  auto i = c10::Scalar(c10::complex<double>(0, 1));
  return a - b + i * (c - a - b);
}

at::Tensor complex_convolution_mode(
    const at::Tensor& input,
    const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias_opt,
    at::IntArrayRef stride,
    c10::string_view padding,
    at::IntArrayRef dilation,
    int64_t groups) {
  auto bias = bias_opt.value_or(Tensor());
  check_input_same_type_as_parameters(input, weight, bias);
  Tensor i_r, i_i, w_r, w_i;
  std::tie(i_r, i_i) = complex_to_real(input.resolve_conj());
  std::tie(w_r, w_i) = complex_to_real(weight.resolve_conj());

  // See [NOTE] Complex Convolution
  Tensor a, b, c;
  if (!bias.defined()) {
    a = at::_convolution_mode(i_r, w_r, bias, stride, padding, dilation, groups);
    b = at::_convolution_mode(i_i, w_i, bias, stride, padding, dilation, groups);
    c = at::_convolution_mode(i_r + i_i, w_r + w_i, bias, stride, padding, dilation, groups);
  } else {
    Tensor b_r, b_i;
    std::tie(b_r, b_i) = complex_to_real(bias.resolve_conj());
    a = at::_convolution_mode(i_r, w_r, b_r, stride, padding, dilation, groups);
    b = at::_convolution_mode(i_i, w_i, Tensor(), stride, padding, dilation, groups);
    c = at::_convolution_mode(i_r + i_i, w_r + w_i, b_r + b_i, stride, padding, dilation, groups);
  }

  auto i = c10::Scalar(c10::complex<double>(0, 1));
  return a - b + i * (c - a - b);
}

} // namespace

at::Tensor conv1d(
    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d");
  Tensor output;
  if (at::isComplexType(input_.scalar_type())) {
    output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups);
  } else {
    output = at::convolution(input, weight, bias, stride, padding, dilation, false, {0}, groups);
  }
  return is_batched ? std::move(output) : output.squeeze(0);
}

at::Tensor conv2d(
    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  TORCH_CHECK(
      !bias.defined() || bias.dtype() == input_.dtype(),
      "Input type (",
      input_.dtype().name(),
      ") and bias type (",
      bias.dtype().name(),
      ") should be the same");

  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d");
  Tensor output;
  if (at::isComplexType(input_.scalar_type())) {
    output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
  } else {
    output = at::convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0}}, groups);
  }
  return is_batched ? std::move(output) : output.squeeze(0);
}

at::Tensor conv3d(
    const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  Tensor input;
  bool is_batched;
  std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d");
  Tensor output;
  if (at::isComplexType(input_.scalar_type())) {
    output = complex_convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups);
  } else {
    output = at::convolution(input, weight, bias, stride, padding, dilation, false, {{0, 0, 0}}, groups);
  }
  return is_batched ? std::move(output) : output.squeeze(0);
}


static Tensor convolution_same(
    const Tensor& input, const Tensor& weight, const Tensor& bias,
    IntArrayRef stride, IntArrayRef dilation, int64_t groups) {

  auto k = weight.dim();
  TORCH_CHECK(k > 2, "weight should have at least three dimensions");
  auto dim = static_cast<size_t>(k - 2);
  auto weight_sizes = weight.sym_sizes();
  auto input_sizes = input.sym_sizes();
  TORCH_CHECK(k == input.dim(),
              "Expected ", k, "-dimensional input for ",
              k, "-dimensional weight ", weight_sizes, ", but got ",
              input.dim(), "-dimensional input of size ",
              input.sizes(), " instead");
  TORCH_CHECK(stride.size() == dim || stride.size() == 1U,
              "stride cannot broadcast to ", dim, " dimensions");
  TORCH_CHECK(dilation.size() == dim || dilation.size() == 1U,
              "dilation cannot broadcast to ", dim, " dimensions");
  for (auto i : c10::irange(stride.size())) {
    TORCH_CHECK(stride[i] == 1, "padding='same' is not supported for strided convolutions");
  }

  // Calculate the correct padding
  SymDimVector padding_l, padding_r;
  bool symmetric_padding = true;
  for (auto i : c10::irange(dim)) {
    auto s = stride.size() == 1 ? stride[0] : stride[i];
    auto d = dilation.size() == 1 ? dilation[0] : dilation[i];
    auto pad = pooling_same_mode_padding_lr(
        input_sizes[i + 2], weight_sizes[i + 2], s, d);
    padding_l.push_back(pad.first);
    padding_r.push_back(pad.second);
    if (pad.first != pad.second) {
      symmetric_padding = false;
    }
  }

  if (symmetric_padding) {
    // All backends handle symmetric padding natively
    SymDimVector output_padding(static_cast<size_t>(dim));
    return at::convolution_symint(input, weight, bias, stride, padding_l, dilation,
                                  false, output_padding, groups);
  }

  TORCH_WARN_ONCE("Using padding='same' with even kernel lengths and odd dilation may"
                  " require a zero-padded copy of the input to be created");
  SmallVector<c10::SymInt, kDimVectorStaticSize * 2> pad_nd(static_cast<size_t>(2 * dim));
  for (auto i : c10::irange(dim)) {
    // Apply padding by the difference, leaving only a symmetric padding
    auto delta_pad = padding_r[i] - padding_l[i];
    auto pad_idx = 2 * (dim - 1 - i);  // F.pad goes from last dim to first
    if (delta_pad > 0) {
      pad_nd[pad_idx + 1] = delta_pad;
    } else {
      pad_nd[pad_idx] = delta_pad;
      padding_l[i] = padding_r[i];
    }
  }
  auto padded_input = at::constant_pad_nd_symint(input, pad_nd, 0);
  SymDimVector output_padding(static_cast<size_t>(dim));
  return at::convolution_symint(padded_input, weight, bias, stride, padding_l,
                                dilation, false, output_padding, groups);
}
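
// Worked example (illustrative): for a 1d input with kernel size k = 3,
// stride 1, dilation 1, the total 'same' padding is d * (k - 1) = 2, so
// pooling_same_mode_padding_lr yields the symmetric pair (1, 1) and the
// fast path above is taken. For an even kernel k = 4 the total padding 3 is
// odd, the pair is asymmetric (e.g. (1, 2)), and the input is first
// explicitly padded by the difference via at::constant_pad_nd_symint so
// that a symmetric convolution can follow.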
1022 | |
1023 | Tensor _convolution_mode( |
1024 | const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt, |
1025 | IntArrayRef stride, c10::string_view padding, IntArrayRef dilation, |
1026 | int64_t groups) { |
1027 | // See [Note: hacky wrapper removal for optional tensor] |
1028 | c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); |
1029 | const Tensor& bias = *bias_maybe_owned; |
1030 | |
1031 | if (padding == "same" ) { |
1032 | return at::native::convolution_same( |
1033 | input, weight, bias, stride, dilation, groups); |
1034 | } else if (padding == "valid" ) { |
1035 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) |
1036 | const int64_t padding_[] = {0}; |
1037 | return at::convolution( |
1038 | input, weight, bias, stride, padding_, dilation, false, padding_, groups); |
1039 | } |
1040 | TORCH_CHECK(false, "Invalid padding string: '" , padding, "'" ); |
1041 | } |
1042 | |
1043 | at::Tensor conv1d( |
1044 | const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias, |
1045 | IntArrayRef stride, c10::string_view padding, IntArrayRef dilation, |
1046 | int64_t groups) { |
1047 | Tensor input; |
1048 | bool is_batched; |
1049 | std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv1d" ); |
1050 | Tensor output; |
1051 | if (at::isComplexType(input_.scalar_type())) { |
1052 | output = complex_convolution_mode(input, weight, bias, stride, std::move(padding), dilation, groups); |
1053 | } else { |
1054 | output = at::_convolution_mode(input, weight, bias, stride, std::move(padding), dilation, groups); |
1055 | } |
1056 | return is_batched ? std::move(output) : output.squeeze(0); |
1057 | } |
1058 | |
1059 | at::Tensor conv2d( |
1060 | const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias, |
1061 | IntArrayRef stride, c10::string_view padding, IntArrayRef dilation, |
1062 | int64_t groups) { |
1063 | Tensor input; |
1064 | bool is_batched; |
1065 | std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv2d" ); |
1066 | Tensor output; |
1067 | if (at::isComplexType(input_.scalar_type())) { |
1068 | output = complex_convolution_mode(input, weight, bias, stride, std::move(padding), dilation, groups); |
1069 | } else { |
1070 | output = at::_convolution_mode(input, weight, bias, stride, std::move(padding), dilation, groups); |
1071 | } |
1072 | return is_batched ? std::move(output) : output.squeeze(0); |
1073 | } |
1074 | |
1075 | at::Tensor conv3d( |
1076 | const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias, |
1077 | IntArrayRef stride, c10::string_view padding, IntArrayRef dilation, |
1078 | int64_t groups) { |
1079 | Tensor input; |
1080 | bool is_batched; |
1081 | std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv3d" ); |
1082 | Tensor output; |
1083 | if (at::isComplexType(input_.scalar_type())) { |
1084 | output = complex_convolution_mode(input, weight, bias, stride, std::move(padding), dilation, groups); |
1085 | } else { |
1086 | output = at::_convolution_mode(input, weight, bias, stride, std::move(padding), dilation, groups); |
1087 | } |
1088 | return is_batched ? std::move(output) : output.squeeze(0); |
1089 | } |
1090 | |
1091 | at::Tensor conv_transpose1d( |
1092 | const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt, |
1093 | IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) { |
1094 | // See [Note: hacky wrapper removal for optional tensor] |
1095 | c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); |
1096 | const Tensor& bias = *bias_maybe_owned; |
1097 | |
1098 | Tensor input; |
1099 | bool is_batched; |
1100 | std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 1, "conv_transpose1d" ); |
1101 | Tensor output; |
1102 | if (at::isComplexType(input_.scalar_type())) { |
1103 | output = complex_convolution( |
1104 | input, weight, bias, stride, padding, dilation, true, output_padding, groups); |
1105 | } else { |
1106 | output = at::convolution( |
1107 | input, weight, bias, stride, padding, dilation, true, output_padding, groups); |
1108 | } |
1109 | return is_batched ? std::move(output) : output.squeeze(0); |
1110 | } |
1111 | |
1112 | at::Tensor conv_transpose2d( |
1113 | const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt, |
1114 | IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) { |
1115 | // See [Note: hacky wrapper removal for optional tensor] |
1116 | c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); |
1117 | const Tensor& bias = *bias_maybe_owned; |
1118 | |
1119 | Tensor input; |
1120 | bool is_batched; |
1121 | std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 2, "conv_transpose2d" ); |
1122 | Tensor output; |
1123 | if (at::isComplexType(input_.scalar_type())) { |
1124 | output = complex_convolution( |
1125 | input, weight, bias, stride, padding, dilation, true, output_padding, groups); |
1126 | } else { |
1127 | output = at::convolution( |
1128 | input, weight, bias, stride, padding, dilation, true, output_padding, groups); |
1129 | } |
1130 | return is_batched ? std::move(output) : output.squeeze(0); |
1131 | } |
1132 | |
1133 | at::Tensor conv_transpose3d( |
1134 | const Tensor& input_, const Tensor& weight, const c10::optional<Tensor>& bias_opt, |
1135 | IntArrayRef stride, IntArrayRef padding, IntArrayRef output_padding, int64_t groups, IntArrayRef dilation) { |
1136 | // See [Note: hacky wrapper removal for optional tensor] |
1137 | c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); |
1138 | const Tensor& bias = *bias_maybe_owned; |
1139 | |
1140 | Tensor input; |
1141 | bool is_batched; |
1142 | std::tie(input, is_batched) = batchify(input_, /*num_spatial_dims=*/ 3, "conv_transpose3d" ); |
1143 | Tensor output; |
1144 | if (at::isComplexType(input_.scalar_type())) { |
1145 | output = complex_convolution( |
1146 | input, weight, bias, stride, padding, dilation, true, output_padding, groups); |
1147 | } else { |
1148 | output = at::convolution( |
1149 | input, weight, bias, stride, padding, dilation, true, output_padding, groups); |
1150 | } |
1151 | return is_batched ? std::move(output) : output.squeeze(0); |
1152 | } |
1153 | |
1154 | at::Tensor convolution( |
1155 | const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt, |
1156 | IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, |
1157 | bool transposed, IntArrayRef output_padding, int64_t groups) { |
1158 | // See [Note: hacky wrapper removal for optional tensor] |
1159 | c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); |
1160 | const Tensor& bias = *bias_maybe_owned; |
1161 | |
1162 | auto& ctx = at::globalContext(); |
1163 | // See Note [Enabling Deterministic Operations] |
1164 | bool deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); |
1165 | return at::_convolution(input, weight, bias, stride, padding, dilation, |
1166 | transposed, output_padding, groups, |
1167 | ctx.benchmarkCuDNN(), deterministic, ctx.userEnabledCuDNN(), ctx.allowTF32CuDNN()); |
1168 | } |
1169 | |
1170 | at::Tensor convolution_overrideable( |
1171 | const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt, |
1172 | IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, |
1173 | bool transposed, IntArrayRef output_padding, int64_t groups) { |
1174 | TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_overrideable not implemented. You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function " ); |
1175 | } |
1176 | |
1177 | // Function to select the convolution backend based on the inputs and params. |
1178 | // This overload is used within the convolution internals but not exposed to python. |
1179 | // NB: The forward pass provides a bias tensor while the backward pass provides |
1180 | // a bool indicating whether the bias is defined. This is done to save memory by |
1181 | // avoiding saving the full bias tensor for backward. |
1182 | template <typename T> |
1183 | ConvBackend _select_conv_backend( |
1184 | const Tensor& input, |
1185 | const Tensor& weight, |
1186 | const c10::optional<Tensor>& bias, |
1187 | const at::OptionalArrayRef<T> bias_sizes_opt, |
1188 | const bool need_backward, |
1189 | const ConvParams<T>& params) { |
1190 | |
1191 | // don't send empty inputs through backends |
1192 | if (at::symint::size<T>(input, 0) == 0 || at::symint::size<T>(input, 1) == 0) { |
1193 | return input.is_mkldnn() ? ConvBackend::MkldnnEmpty : ConvBackend::Empty; |
1194 | } else if (at::symint::numel<T>(input) == 0) { |
1195 | TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: " , at::symint::sizes<T>(input)); |
1196 | } |
1197 | |
1198 | if (params.is_depthwise(input, weight)) { |
1199 | if (params.use_cudnn_depthwise(input, weight)) { |
1200 | return ConvBackend::Cudnn; |
1201 | } else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) { |
1202 | return ConvBackend::MiopenDepthwise; |
1203 | } else { |
1204 | if (input.ndimension() == 4) { |
1205 | return ConvBackend::CudaDepthwise2d; |
1206 | } else if (input.ndimension() == 5) { |
1207 | return ConvBackend::CudaDepthwise3d; |
1208 | } else { |
1209 | // unsupported |
1210 | } |
1211 | } |
1212 | } else if (params.use_cudnn(input, weight)) { |
1213 | if (params.transposed) { |
1214 | return ConvBackend::CudnnTranspose; |
1215 | } else { |
1216 | return ConvBackend::Cudnn; |
1217 | } |
1218 | } else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) { |
1219 | if (params.transposed) { |
1220 | return ConvBackend::MiopenTranspose; |
1221 | } else { |
1222 | return ConvBackend::Miopen; |
1223 | } |
1224 | } else if (params.use_mkldnn(input, weight)) { |
1225 | if (params.transposed) { |
1226 | return ConvBackend::MkldnnTranspose; |
1227 | } else { |
1228 | return ConvBackend::Mkldnn; |
1229 | } |
1230 | } else if (!need_backward && params.use_xnnpack(input, weight, bias_sizes_opt)) { |
1231 | // Using prepacked conv is preferred, but XNNPACK is still the fastest |
1232 | // option for NHWC. |
1233 | return ConvBackend::Xnnpack2d; |
1234 | // 3x3 depthwith convolutions implementation is inference only |
1235 | } else if (!need_backward && params.use_cpu_depthwise3x3_winograd(input, weight, bias)) { |
1236 | return ConvBackend::Winograd3x3Depthwise; |
1237 | } else if ( |
1238 | !params.transposed && (input.ndimension() == 5) && |
1239 | (input.device().is_cpu()) && |
1240 | !params.is_dilated()) { |
1241 | // fast path for grouped conv3d |
1242 | return ConvBackend::Slow3d; |
1243 | } else if (input.device().is_cpu() || input.is_cuda()) { |
1244 | // backends without support for groups |
1245 | if (params.transposed) { |
1246 | if (input.ndimension() == 4) { |
1247 | return ConvBackend::SlowTranspose2d; |
1248 | } else if (input.ndimension() == 5) { |
1249 | return ConvBackend::SlowTranspose3d; |
1250 | } else { |
1251 | // unsupported |
1252 | } |
1253 | } else { /* Not transposed */ |
1254 | if (input.ndimension() == 4) { |
1255 | if (params.is_dilated()) { |
1256 | return ConvBackend::SlowDilated2d; |
1257 | } else { /* dim == 4, non-dilated */ |
1258 | if (params.use_nnpack(input, weight)) { |
1259 | return ConvBackend::NnpackSpatial; |
1260 | } else { |
1261 | /* CPU implementation has specialized MM kernels |
1262 | for non-dilated case here */ |
1263 | return ConvBackend::Slow2d; |
1264 | } |
1265 | } |
1266 | } else if (input.ndimension() == 5 && (input.is_cuda() || params.is_dilated())) { |
1267 | return ConvBackend::SlowDilated3d; |
1268 | } else if (input.ndimension() == 5) { /* dim == 5, CPU, non-dilated */ |
1269 | /* CPU implementation has specialized MM kernels |
1270 | for non-dilated case here */ |
1271 | return ConvBackend::Slow3d; |
1272 | } else { |
1273 | // unsupported |
1274 | } |
1275 | } |
1276 | } else if (params.use_mps(input, weight)) { |
1277 | if (params.transposed) { |
1278 | return ConvBackend::MpsTranspose; |
1279 | } else { |
1280 | return ConvBackend::Mps; |
1281 | } |
1282 | } else { |
1283 | // Only reach here when input is backend with out-of-source implementation. |
1284 | return ConvBackend::Overrideable; |
1285 | } |
1286 | |
1287 | // Error out if no suitable backend was found. |
1288 | AT_ERROR("unsupported ConvNd parameters" ); |
1289 | } |
1290 | |
1291 | // Selects a backend for convolution based on the inputs and params. |
1292 | ConvBackend select_conv_backend( |
1293 | const Tensor& input_r, const Tensor& weight_r, const c10::optional<Tensor>& bias_opt, |
1294 | IntArrayRef stride_, SymIntArrayRef padding_, IntArrayRef dilation_, |
1295 | bool transposed_, SymIntArrayRef output_padding_, int64_t groups_, const at::OptionalSymIntArrayRef bias_sizes_opt) { |
1296 | c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); |
1297 | const Tensor& bias = *bias_maybe_owned; |
1298 | |
1299 | auto& ctx = at::globalContext(); |
1300 | auto k = weight_r.ndimension(); |
1301 | int64_t dim = k - 2; |
1302 | ConvParams<c10::SymInt> params; |
1303 | params.stride = expand_param_if_needed(stride_, "stride", dim); |
1304 | params.padding = expand_param_if_needed(padding_, "padding", dim); |
1305 | params.dilation = expand_param_if_needed(dilation_, "dilation", dim); |
1306 | params.transposed = transposed_; |
1307 | params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim); |
1308 | params.groups = groups_; |
1309 | params.benchmark = ctx.benchmarkCuDNN(); |
1310 | params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); |
1311 | params.cudnn_enabled = ctx.userEnabledCuDNN(); |
1312 | params.allow_tf32 = ctx.allowTF32CuDNN(); |
1313 | |
1314 | auto input = input_r; |
1315 | auto weight = weight_r; |
1316 | check_shape_forward(input, weight.sym_sizes(), bias, params); |
1317 | |
1318 | // Expand 1d -> 2d. |
1319 | // This is only done for backends that don't natively support 1d spatial input. |
1320 | if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { |
1321 | // avoid accidentally going through NHWC for permuted 3d input. |
1322 | input = input.contiguous(); |
1323 | params.view1d_as_2d(); |
1324 | input = view4d(input); |
1325 | weight = view4d(weight); |
1326 | } |
1327 | |
1328 | auto bias_sizes = bias.defined() ? c10::optional<SymIntArrayRef>(bias.sym_sizes()) : bias_sizes_opt; |
1329 | bool need_backward = GradMode::is_enabled() && |
1330 | (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad())); |
1331 | return _select_conv_backend(input, weight, bias, bias_sizes, need_backward, params); |
1332 | } |
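// Illustrative usage sketch (hypothetical caller, for exposition only): query
// which backend would serve a convolution without running it, e.g. from
// profiling or debugging code.
//
//   at::Tensor in = at::randn({1, 3, 224, 224});
//   at::Tensor w = at::randn({8, 3, 3, 3});
//   ConvBackend b = at::native::select_conv_backend(
//       in, w, /*bias_opt=*/c10::nullopt,
//       /*stride_=*/{1, 1}, /*padding_=*/{1, 1}, /*dilation_=*/{1, 1},
//       /*transposed_=*/false, /*output_padding_=*/{0, 0}, /*groups_=*/1,
//       /*bias_sizes_opt=*/c10::nullopt);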
1333 | |
1334 | // For backward-compatibility reasons, keep an overload that does not require bias_opt. |
1335 | ConvBackend select_conv_backend( |
1336 | const Tensor& input, |
1337 | const Tensor& weight, |
1338 | const at::OptionalIntArrayRef bias_sizes_opt, |
1339 | const bool need_backward, |
1340 | const ConvParams<int64_t>& params) { |
1341 | return _select_conv_backend(input, weight, {}, bias_sizes_opt, need_backward, params); |
1342 | } |
1343 | |
1344 | at::Tensor _convolution_nogroup_backend( |
1345 | const Tensor& input, |
1346 | const Tensor& weight, |
1347 | const Tensor& bias, |
1348 | const ConvBackend backend, |
1349 | const ConvParams<int64_t>& params) { |
1350 | auto kernel_size = weight.sizes().slice(2); |
1351 | switch(backend) { |
1352 | case ConvBackend::NnpackSpatial: |
1353 | #if AT_NNPACK_ENABLED() |
1354 | return at::_nnpack_spatial_convolution(input, weight, bias, params.padding, params.stride); |
1355 | #else |
1356 | TORCH_INTERNAL_ASSERT(false, "NnpackSpatial backend was selected in PyTorch compiled without nnpack support" ); |
1357 | #endif |
1358 | case ConvBackend::Slow2d: |
1359 | return at::thnn_conv2d(input, weight, kernel_size, bias, params.stride, params.padding); |
1360 | case ConvBackend::SlowDilated2d: |
1361 | return at::slow_conv_dilated2d( |
1362 | input, weight, kernel_size, bias, params.stride, params.padding, params.dilation); |
1363 | case ConvBackend::SlowDilated3d: |
1364 | return at::slow_conv_dilated3d( |
1365 | input, weight, kernel_size, bias, params.stride, params.padding, params.dilation); |
1366 | case ConvBackend::SlowTranspose2d: |
1367 | return at::slow_conv_transpose2d( |
1368 | input, weight, kernel_size, bias, params.stride, params.padding, params.output_padding, params.dilation); |
1369 | case ConvBackend::SlowTranspose3d: |
1370 | return at::slow_conv_transpose3d( |
1371 | input, weight, kernel_size, bias, params.stride, params.padding, params.output_padding, params.dilation); |
1372 | default: |
1373 | TORCH_CHECK(false, "Unsupported conv nogroup backend encountered"); |
1374 | } |
1375 | } |
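// Note: the "nogroup" backends above only implement groups == 1. Callers such
// as _convolution below handle groups > 1 by slicing the input channels
// (dim 1), the weight output channels (dim 0), and the bias (dim 0) into
// per-group subtensors, invoking this function once per group, and
// concatenating the outputs along dim 1. Sketch (illustrative): with
// C_in == 8 and groups == 2, group g convolves input channels [4 * g, 4 * g + 4).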
1376 | |
1377 | static inline std::vector<int64_t> calc_output_size( |
1378 | const Tensor& input, |
1379 | const Tensor& weight, |
1380 | const ConvParams<int64_t>& params) { |
1381 | std::vector<int64_t> output_size = params.transposed ? |
1382 | conv_input_size(input.sizes(), weight.sizes(), params.padding, params.output_padding, |
1383 | params.stride, params.dilation, params.groups) : |
1384 | conv_output_size(input.sizes(), weight.sizes(), params.padding, params.stride, params.dilation); |
1385 | |
1386 | // Handle the case of zero input channels. |
1387 | if (input.size(1) == 0) { |
1388 | output_size[input_channels_dim] = 0; |
1389 | } |
1390 | return output_size; |
1391 | } |
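// Illustrative sketch (unused helper, for exposition only): the per-dimension
// formula behind conv_output_size used above. For a non-transposed conv each
// spatial dim maps as
//   out = (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1
// and the transposed case (conv_input_size) inverts it:
//   out = (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + output_padding + 1
// e.g. in = 7, kernel = 3, pad = 1, stride = 2, dilation = 1 gives out = 4.
C10_UNUSED static int64_t conv_output_size_1d_sketch(
    int64_t in, int64_t kernel, int64_t pad, int64_t stride, int64_t dilation) {
  // Integer division truncates, which matches the floor in the standard
  // formula for positive operands.
  return (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1;
}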
1392 | |
1393 | static inline at::MemoryFormat determine_backend_memory_format( |
1394 | const Tensor& input, |
1395 | const Tensor& weight, |
1396 | const ConvBackend backend) { |
1397 | at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous; |
1398 | auto k = weight.ndimension(); |
1399 | #if !defined(C10_MOBILE) |
1400 | // See Note [Mobile check segfaults] |
1401 | switch(backend) { |
1402 | case ConvBackend::Cudnn: |
1403 | case ConvBackend::CudnnTranspose: |
1404 | if (detail::getCUDAHooks().compiledWithCuDNN()) { |
1405 | backend_memory_format = cudnn_conv_suggest_memory_format(input, weight); |
1406 | } |
1407 | break; |
1408 | case ConvBackend::Miopen: |
1409 | case ConvBackend::MiopenDepthwise: |
1410 | case ConvBackend::MiopenTranspose: |
1411 | if (detail::getCUDAHooks().compiledWithMIOpen() && miopen_conv_use_channels_last(input, weight)) { |
1412 | TORCH_INTERNAL_ASSERT((k == 4 || k == 5), |
1413 | "Expected 4D or 5D input for miopen memory format selection in determine_backend_memory_format()" ); |
1414 | backend_memory_format = (k == 5) ? at::MemoryFormat::Contiguous /* rather than ChannelsLast3d, which is deliberately not used here for 5d inputs */ : at::MemoryFormat::ChannelsLast; |
1415 | } |
1416 | break; |
1417 | case ConvBackend::Mkldnn: |
1418 | case ConvBackend::MkldnnTranspose: |
1419 | if (mkldnn_conv_use_channels_last(input, weight)) { |
1420 | backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast; |
1421 | } |
1422 | break; |
1423 | case ConvBackend::Slow2d: |
1424 | case ConvBackend::SlowDilated2d: |
1425 | case ConvBackend::SlowTranspose2d: |
1426 | if (thnn_conv_use_channels_last(input, weight)) { |
1427 | backend_memory_format = at::MemoryFormat::ChannelsLast; |
1428 | } |
1429 | break; |
1430 | default: |
1431 | backend_memory_format = at::MemoryFormat::Contiguous; |
1432 | } |
1433 | #endif |
1434 | return backend_memory_format; |
1435 | } |
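// Illustrative sketch (this mirrors the typical caller pattern in this file):
// the suggested format is applied by making the operands contiguous in it
// before dispatching to the backend, e.g.
//
//   auto fmt = determine_backend_memory_format(input, weight, backend);
//   input = input.contiguous(fmt);
//   weight = weight.contiguous(fmt);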
1436 | |
1437 | at::MemoryFormat _determine_backend_memory_format( |
1438 | const Tensor& input, |
1439 | const Tensor& weight, |
1440 | const ConvBackend backend) { |
1441 | return determine_backend_memory_format(input, weight, backend); |
1442 | } |
1443 | |
1444 | at::Tensor _convolution( |
1445 | const Tensor& input_r, const Tensor& weight_r, const c10::optional<Tensor>& bias_r_opt, |
1446 | IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_, |
1447 | bool transposed_, IntArrayRef output_padding_, int64_t groups_, |
1448 | bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) { |
1449 | // See [Note: hacky wrapper removal for optional tensor] |
1450 | c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt); |
1451 | const Tensor& bias_r = *bias_r_maybe_owned; |
1452 | |
1453 | auto input = input_r; |
1454 | auto weight = weight_r; |
1455 | auto bias = bias_r; |
1456 | auto k = weight.ndimension(); |
1457 | c10::IntArrayRef weight_sizes = weight.sizes(); |
1458 | int64_t dim = k - 2; |
1459 | |
1460 | TORCH_CHECK(dim > 0, "weight should have at least three dimensions"); |
1461 | TORCH_CHECK(groups_ > 0, "non-positive groups is not supported"); |
1462 | |
1463 | ConvParams<int64_t> params; |
1464 | params.stride = expand_param_if_needed(stride_, "stride", dim); |
1465 | params.padding = expand_param_if_needed(padding_, "padding", dim); |
1466 | params.dilation = expand_param_if_needed(dilation_, "dilation", dim); |
1467 | params.transposed = transposed_; |
1468 | params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim); |
1469 | params.groups = groups_; |
1470 | params.benchmark = benchmark; |
1471 | params.deterministic = deterministic; |
1472 | params.cudnn_enabled = cudnn_enabled; |
1473 | params.allow_tf32 = allow_tf32; |
1474 | |
1475 | check_shape_forward(input, weight_sizes, bias, params); |
1476 | |
1477 | // Expand 1d -> 2d. |
1478 | // This is only done for backends that don't natively support 1d spatial input. |
1479 | if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { |
1480 | // avoid accidentally going through NHWC for permuted 3d input. |
1481 | input = input.contiguous(); |
1482 | params.view1d_as_2d(); |
1483 | input = view4d(input); |
1484 | weight = view4d(weight); |
1485 | } |
1486 | |
1487 | // Select appropriate backend to use. |
1488 | auto bias_sizes_opt = bias.defined() ? c10::optional<IntArrayRef>(bias.sizes()) : c10::nullopt; |
1489 | bool need_backward = GradMode::is_enabled() && |
1490 | (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad())); |
1491 | ConvBackend backend = _select_conv_backend(input, weight, bias, c10::OptionalIntArrayRef(bias_sizes_opt), need_backward, params); |
1492 | at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend); |
1493 | |
1494 | // Call the backend. |
1495 | Tensor output; |
1496 | auto kernel_size = weight.sizes().slice(2); |
1497 | switch (backend) { |
1498 | case ConvBackend::CudaDepthwise2d: |
1499 | output = at::_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias, |
1500 | params.stride, params.padding, params.dilation); |
1501 | break; |
1502 | case ConvBackend::CudaDepthwise3d: |
1503 | output = at::conv_depthwise3d(input.contiguous(), weight, kernel_size, bias, |
1504 | params.stride, params.padding, params.dilation); |
1505 | break; |
1506 | case ConvBackend::Cudnn: |
1507 | check_input_same_type_as_parameters(input, weight, bias); |
1508 | output = at::cudnn_convolution( |
1509 | input.contiguous(backend_memory_format), weight, params.padding, params.stride, |
1510 | params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32); |
1511 | if (bias.defined()) { |
1512 | output.add_(reshape_bias(input.dim(), bias)); |
1513 | } |
1514 | break; |
1515 | case ConvBackend::CudnnTranspose: |
1516 | check_input_same_type_as_parameters(input, weight, bias); |
1517 | output = at::cudnn_convolution_transpose( |
1518 | input.contiguous(backend_memory_format), weight, params.padding, params.output_padding, |
1519 | params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32); |
1520 | if (bias.defined()) { |
1521 | output.add_(reshape_bias(input.dim(), bias)); |
1522 | } |
1523 | break; |
1524 | case ConvBackend::Empty: |
1525 | { |
1526 | Tensor weight_view; |
1527 | // Use permute and clone to avoid at::_unsafe_view(weight, -1) failure for non-contiguous cases where |
1528 | // view size is not compatible with input tensor's size and stride. |
1529 | if (weight.is_contiguous()) { |
1530 | weight_view = at::_unsafe_view(weight, -1); |
1531 | } else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast)) { |
1532 | weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 1}), -1); |
1533 | } else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast3d)) { |
1534 | weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 4, 1}), -1); |
1535 | } else { |
1536 | weight_view = at::_unsafe_view(weight.clone(at::MemoryFormat::Contiguous), -1); |
1537 | } |
1538 | |
1539 | output = (input.size(1) == 0) ? (input.view(-1) * weight_view) : (input * weight_view[0]); |
1540 | if (bias.defined()) { |
1541 | output.add_(bias[0]); |
1542 | } |
1543 | output = output.view(calc_output_size(input, weight, params)); |
1544 | break; |
1545 | } |
1546 | case ConvBackend::Miopen: |
1547 | check_input_same_type_as_parameters(input, weight, bias); |
1548 | output = at::miopen_convolution( |
1549 | input.contiguous(backend_memory_format), weight, bias, params.padding, params.stride, |
1550 | params.dilation, params.groups, params.benchmark, params.deterministic); |
1551 | break; |
1552 | case ConvBackend::MiopenDepthwise: |
1553 | output = at::miopen_depthwise_convolution( |
1554 | input.contiguous(backend_memory_format), weight, bias, params.padding, params.stride, |
1555 | params.dilation, params.groups, params.benchmark, params.deterministic); |
1556 | break; |
1557 | case ConvBackend::MiopenTranspose: |
1558 | check_input_same_type_as_parameters(input, weight, bias); |
1559 | output = at::miopen_convolution_transpose( |
1560 | input.contiguous(backend_memory_format), weight, bias, params.padding, params.output_padding, |
1561 | params.stride, params.dilation, params.groups, params.benchmark, params.deterministic); |
1562 | break; |
1563 | case ConvBackend::Mkldnn: |
1564 | #if AT_MKLDNN_ENABLED() |
1565 | check_input_same_type_as_parameters(input, weight, bias, backend); |
1566 | if (!input.is_mkldnn()) { |
1567 | // need to ensure contiguous for non-mkldnn tensors |
1568 | input = input.contiguous(backend_memory_format); |
1569 | weight = weight.contiguous(backend_memory_format); |
1570 | bias = bias.defined() ? bias.contiguous() : bias; |
1571 | } |
1572 | output = at::mkldnn_convolution( |
1573 | input, weight, bias, params.padding, params.stride, params.dilation, params.groups); |
1574 | #else |
1575 | TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support" ); |
1576 | #endif |
1577 | break; |
1578 | case ConvBackend::MkldnnTranspose: |
1579 | #if AT_MKLDNN_ENABLED() |
1580 | check_input_same_type_as_parameters(input, weight, bias, backend); |
1581 | if (!input.is_mkldnn()) { |
1582 | // need to ensure contiguous for non-mkldnn tensors |
1583 | input = input.contiguous(backend_memory_format); |
1584 | weight = weight.contiguous(backend_memory_format); |
1585 | bias = bias.defined() ? bias.contiguous() : bias; |
1586 | } |
1587 | output = mkldnn_convolution_transpose_stub(input.device().type(), |
1588 | input, weight, bias, params.padding, params.output_padding, params.stride, params.dilation, params.groups); |
1589 | #else |
1590 | TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support" ); |
1591 | #endif |
1592 | break; |
1593 | case ConvBackend::MkldnnEmpty: |
1594 | #if AT_MKLDNN_ENABLED() |
1595 | output = empty_mkldnn( |
1596 | calc_output_size(input, weight, params), optTypeMetaToScalarType(input.options().dtype_opt()), |
1597 | input.options().layout_opt(), input.options().device_opt(), input.options().pinned_memory_opt()); |
1598 | #else |
1599 | TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support" ); |
1600 | #endif |
1601 | break; |
1602 | case ConvBackend::Overrideable: |
1603 | output = at::convolution_overrideable( |
1604 | input, weight, bias, params.stride, params.padding, params.dilation, params.transposed, |
1605 | params.output_padding, params.groups); |
1606 | break; |
1607 | case ConvBackend::Slow3d: |
1608 | output = at::slow_conv3d(input, weight, kernel_size, bias, params.stride, params.padding); |
1609 | break; |
1610 | case ConvBackend::Winograd3x3Depthwise: |
1611 | output = convolution_depthwise3x3_winograd_stub( |
1612 | input.device().type(), input, weight, bias, params.stride, params.padding, params.groups); |
1613 | break; |
1614 | case ConvBackend::Xnnpack2d: |
1615 | output = xnnpack::convolution2d( |
1616 | input, weight, bias, params.padding, params.stride, params.dilation, params.groups); |
1617 | break; |
1618 | // Handle backends that don't natively support groups > 1. |
1619 | case ConvBackend::NnpackSpatial: |
1620 | case ConvBackend::Slow2d: |
1621 | case ConvBackend::SlowDilated2d: |
1622 | case ConvBackend::SlowDilated3d: |
1623 | case ConvBackend::SlowTranspose2d: |
1624 | case ConvBackend::SlowTranspose3d: |
1625 | input = input.contiguous(backend_memory_format); |
1626 | weight = weight.contiguous(backend_memory_format); |
1627 | if (params.groups == 1) { |
1628 | output = _convolution_nogroup_backend(input, weight, bias, backend, params); |
1629 | } else { |
1630 | std::vector<Tensor> outputs(params.groups); |
1631 | for (const auto g : c10::irange(params.groups)) { |
1632 | auto input_g = subtensor(input, 1, params.groups, g); |
1633 | auto weight_g = subtensor(weight, 0, params.groups, g); |
1634 | auto bias_g = subtensor(bias, 0, params.groups, g); |
1635 | outputs[g] = _convolution_nogroup_backend(input_g, weight_g, bias_g, backend, params); |
1636 | } |
1637 | output = at::cat(outputs, 1); |
1638 | } |
1639 | break; |
1640 | case ConvBackend::Mps: |
1641 | #ifdef USE_MPS |
1642 | TORCH_CHECK(input.options().type_equal(weight.options()), |
1643 | "Input type (" , input.toString(), ") and weight type (" , weight.toString(), |
1644 | ") should be the same" ); |
1645 | TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())), |
1646 | "Input type (" , input.toString(), ") and bias type (" , bias.toString(), |
1647 | ") should be the same" ); |
1648 | |
1649 | output = at::_mps_convolution(input.contiguous(), weight, bias.defined() ? bias.contiguous() : bias, |
1650 | params.padding, params.stride, params.dilation, |
1651 | params.groups); |
1652 | #else |
1653 | TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support" ); |
1654 | #endif |
1655 | break; |
1656 | case ConvBackend::MpsTranspose: |
1657 | #ifdef USE_MPS |
1658 | TORCH_CHECK(input.options().type_equal(weight.options()), |
1659 | "Input type (" , input.toString(), ") and weight type (" , weight.toString(), |
1660 | ") should be the same" ); |
1661 | TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())), |
1662 | "Input type (" , input.toString(), ") and bias type (" , bias.toString(), |
1663 | ") should be the same" ); |
1664 | output = at::_mps_convolution_transpose( |
1665 | input.contiguous(backend_memory_format), weight, |
1666 | params.padding, params.output_padding, |
1667 | params.stride, params.dilation, params.groups); |
1668 | if (bias.defined()) { |
1669 | output.add_(reshape_bias(input.dim(), bias)); |
1670 | } |
1671 | #else |
1672 | TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support" ); |
1673 | #endif |
1674 | break; |
1675 | } |
1676 | |
1677 | if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { |
1678 | output = view3d(output); |
1679 | } |
1680 | |
1681 | return output; |
1682 | } |
1683 | |
1684 | at::Tensor _convolution( |
1685 | const Tensor& input_r, const Tensor& weight_r, const c10::optional<Tensor>& bias_r_opt, |
1686 | IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_, |
1687 | bool transposed_, IntArrayRef output_padding_, int64_t groups_, |
1688 | bool benchmark, bool deterministic, bool cudnn_enabled) |
1689 | { |
1690 | // See [Note: hacky wrapper removal for optional tensor] |
1691 | c10::MaybeOwned<Tensor> bias_r_maybe_owned = at::borrow_from_optional_tensor(bias_r_opt); |
1692 | const Tensor& bias_r = *bias_r_maybe_owned; |
1693 | |
1694 | return at::_convolution(input_r, weight_r, bias_r, stride_, padding_, dilation_, transposed_, output_padding_, groups_, benchmark, deterministic, cudnn_enabled, at::globalContext().allowTF32CuDNN()); |
1695 | } |
1696 | |
1697 | std::tuple<Tensor, Tensor, Tensor> convolution_backward_overrideable( |
1698 | const Tensor& grad_output, const Tensor& input, const Tensor& weight, |
1699 | IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, |
1700 | bool transposed, IntArrayRef output_padding, int64_t groups, std::array<bool, 3> output_mask) { |
1701 | TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_backward_overrideable: You are likely triggering this with a tensor backend other than CPU/CUDA/MKLDNN. If this is intended, please use TORCH_LIBRARY_IMPL to override this function"); |
1702 | return std::tuple<Tensor, Tensor, Tensor>( |
1703 | at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT), |
1704 | at::empty_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT), |
1705 | at::empty({})); |
1706 | } |
1707 | |
1708 | static Tensor subvariable(const Tensor& var, int dim, int groups, int g) { |
1709 | int64_t n = var.sizes()[dim] / groups; |
1710 | auto result = var.narrow(dim, n * g, n); |
1711 | return result; |
1712 | } |
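// Math note for _convolution_double_backward below: convolution is bilinear
// in (input, weight) and linear in bias, so the directional derivative of the
// forward output along a perturbation (ggI, ggW, ggb) decomposes as
//
//   ggO = conv(ggI, W) + conv(I, ggW) + broadcast(ggb)
//
// which is exactly how ggO is assembled. gW is computed by convolving ggI
// with gO with the batch dimension transposed into the channel position (and
// stride/dilation swapped), and gI is the ordinary data gradient, i.e. a
// transposed convolution of gO against ggW.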
1713 | |
1714 | std::tuple<Tensor,Tensor,Tensor> _convolution_double_backward( const c10::optional<Tensor>& ggI_opt, const c10::optional<Tensor>& ggW_r_opt, const c10::optional<Tensor>& ggb_opt, |
1715 | const Tensor& gO_r, const Tensor& weight_r, const Tensor& input, |
1716 | IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_, |
1717 | bool transposed_, IntArrayRef output_padding_, int64_t groups_, |
1718 | std::array<bool, 3> output_mask) { |
1719 | // See [Note: hacky wrapper removal for optional tensor] |
1720 | c10::MaybeOwned<Tensor> ggI_maybe_owned = at::borrow_from_optional_tensor(ggI_opt); |
1721 | const Tensor& ggI = *ggI_maybe_owned; |
1722 | const Tensor& ggW_r = c10::value_or_else(ggW_r_opt, [] {return Tensor();}); |
1723 | const Tensor& ggb = c10::value_or_else(ggb_opt, [] {return Tensor();}); |
1724 | |
1725 | |
1726 | auto ggW = ggW_r; |
1727 | auto gO = gO_r; |
1728 | auto weight = weight_r; |
1729 | |
1730 | int64_t dim = weight.ndimension() - 2; |
1731 | ConvParams<int64_t> params; |
1732 | params.stride = expand_param_if_needed(stride_, "stride", dim); |
1733 | params.padding = expand_param_if_needed(padding_, "padding", dim); |
1734 | params.dilation = expand_param_if_needed(dilation_, "dilation", dim); |
1735 | params.transposed = transposed_; |
1736 | params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim); |
1737 | // TODO: hacky way of inferring the groups number for grouped Conv3D |
1738 | // See: https://github.com/pytorch/pytorch/pull/36355 |
1739 | if (!params.transposed && input.dim() > 4) { |
1740 | // Avoid undefined behavior when num channels == 0; params are unused for that case. |
1741 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
1742 | params.groups = (weight.size(1) > 0) ? input.size(1) / weight.size(1) : -1; |
1743 | } else { |
1744 | params.groups = groups_; |
1745 | } |
1746 | |
1747 | // Compute ggO = conv(ggI, w) + conv(i, ggW) + ggb |
1748 | Tensor ggO; |
1749 | if (input.numel() != 0) { |
1750 | if (ggI.defined()) { |
1751 | if (weight.is_cuda()) { |
1752 | weight = weight.contiguous(); |
1753 | } |
1754 | ggO = at::convolution(ggI, weight, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups); |
1755 | } |
1756 | |
1757 | if (ggW.defined()) { |
1758 | if (ggW.is_cuda()) { |
1759 | ggW = ggW.contiguous(); |
1760 | } |
1761 | auto ggW_term = at::convolution(input, ggW, Tensor(), params.stride, params.padding, params.dilation, params.transposed, params.output_padding, params.groups); |
1762 | if (ggO.defined()) { |
1763 | ggO = ggO + ggW_term; |
1764 | } else { |
1765 | ggO = ggW_term; |
1766 | } |
1767 | } |
1768 | } |
1769 | |
1770 | if (ggb.defined()) { |
1771 | // View ggb as (1, ggb.size(0), 1, 1, ...), then expand to gO's shape |
1772 | |
1773 | // View |
1774 | std::vector<int64_t> new_size(gO.ndimension(), 1); |
1775 | new_size[1] = ggb.sizes()[0]; |
1776 | auto ggb_contiguous = ggb.contiguous(); |
1777 | auto ggb_view = ggb_contiguous.view(new_size); |
1778 | |
1779 | // Expand |
1780 | auto ggb_expanded = ggb_view.expand(gO.sizes()); |
1781 | |
1782 | if (ggO.defined()) { |
1783 | ggO = ggO + ggb_expanded; |
1784 | } else { |
1785 | ggO = ggb_expanded; |
1786 | } |
1787 | } |
1788 | |
1789 | // Compute gW = conv(ggI, gO) |
1790 | Tensor gW; |
1791 | if (ggI.defined()) { |
1792 | |
1793 | // Modified params with correct padding |
1794 | ConvParams<int64_t> gw_conv_params(params); |
1795 | |
1796 | // Disable groups as they are handled separately |
1797 | auto groups = gw_conv_params.groups; |
1798 | gw_conv_params.groups = 1; |
1799 | std::swap(gw_conv_params.dilation, gw_conv_params.stride); |
1800 | |
1801 | // Transpose gO and ggI to accumulate over batch |
1802 | auto gOt = gO.transpose(0, 1); |
1803 | auto ggIt = ggI.transpose(0, 1); |
1804 | |
1805 | Tensor gWt; |
1806 | // Compute conv |
1807 | if (input.numel() != 0) { |
1808 | if (groups == 1) { |
1809 | |
1810 | if (gOt.is_cuda()) { |
1811 | gOt = gOt.contiguous(); |
1812 | } |
1813 | // Compute conv |
1814 | if (params.transposed) { |
1815 | gw_conv_params.transposed = false; |
1816 | gWt = at::convolution(gOt, ggIt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups); |
1817 | } else { |
1818 | gWt = at::convolution(ggIt, gOt, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups); |
1819 | } |
1820 | } else { |
1821 | std::vector<Tensor> gWt_list(groups); |
1822 | for (const auto g : c10::irange(groups)) { |
1823 | auto ggIt_g = subvariable(ggIt, 0, groups, g); |
1824 | auto gOt_g = subvariable(gOt, 0, groups, g); |
1825 | if (gOt_g.is_cuda()) { |
1826 | gOt_g = gOt_g.contiguous(); |
1827 | } |
1828 | |
1829 | // Compute conv |
1830 | if (params.transposed) { |
1831 | gw_conv_params.transposed = false; |
1832 | gWt_list[g] = at::convolution(gOt_g, ggIt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups); |
1833 | } else { |
1834 | gWt_list[g] = at::convolution(ggIt_g, gOt_g, Tensor(), gw_conv_params.stride, gw_conv_params.padding, gw_conv_params.dilation, gw_conv_params.transposed, gw_conv_params.output_padding, gw_conv_params.groups); |
1835 | } |
1836 | } |
1837 | |
1838 | gWt = at::cat(gWt_list, 1); |
1839 | } |
1840 | |
1841 | // Transpose gW to match chan_in and chan_out |
1842 | gW = gWt.transpose(0, 1); |
1843 | |
1844 | // narrow gW to only relevant portion |
1845 | // we do it this way instead of narrowing the input itself because |
1846 | // the ConvForward kernels don't support asymmetric padding. |
1847 | auto gW_size = gW.sizes(); |
1848 | auto w_size = weight.sizes(); |
1849 | for (const auto i : c10::irange(2, gW_size.size())) { |
1850 | if (gW_size[i] > w_size[i]) { |
1851 | gW = gW.narrow(i, 0, w_size[i]); |
1852 | gW_size = gW.sizes(); |
1853 | } |
1854 | } |
1855 | } |
1856 | } |
1857 | |
1858 | // Compute gI = convT(gO, ggW) if !transposed |
1859 | // gI = conv(gO, ggW) if transposed |
1860 | Tensor gI; |
1861 | if (input.numel() != 0) { |
1862 | if (ggW.defined()) { |
1863 | ConvParams<int64_t> gi_conv_params(params); |
1864 | gi_conv_params.transposed = !params.transposed; |
1865 | |
1866 | if (params.transposed) { |
1867 | if (gO.is_cuda()) { |
1868 | gO = gO.contiguous(); |
1869 | } |
1870 | gI = at::convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups); |
1871 | |
1872 | // narrow gI to only relevant portion |
1873 | // we do it this way because negative output_padding is not supported |
1874 | // TODO: figure out if we can narrow gO and save some compute, |
1875 | // rather than narrowing the computed gI |
1876 | auto gI_size = gI.sizes(); |
1877 | auto i_size = input.sizes(); |
1878 | for (const auto i : c10::irange(2, gI_size.size())) { |
1879 | if (gI_size[i] > i_size[i]) { |
1880 | gI = gI.narrow(i, 0, i_size[i]); |
1881 | gI_size = gI.sizes(); |
1882 | } |
1883 | } |
1884 | } else { |
1885 | // calculate output_padding |
1886 | // TODO: figure out why this needs to be computed... |
1887 | auto kernel_size = weight.sizes().slice(2); |
1888 | auto input_shape = input.sizes().slice(2); |
1889 | auto grad_output_shape = gO.sizes().slice(2); |
1890 | |
1891 | for (const auto i : c10::irange(kernel_size.size())) { |
1892 | // Check whether the whole input has been used |
1893 | auto expected_input_shape = (kernel_size[i] - 1) * gi_conv_params.dilation[i] |
1894 | - 2 * gi_conv_params.padding[i] |
1895 | + (gi_conv_params.stride[i] * (grad_output_shape[i] - 1) + 1); |
1896 | if (expected_input_shape != input_shape[i]) { |
1897 | gi_conv_params.output_padding[i] = input_shape[i] - expected_input_shape; |
1898 | } |
1899 | } |
1900 | |
1901 | if (gO.is_cuda()) { |
1902 | gO = gO.contiguous(); |
1903 | } |
1904 | |
1905 | gI = at::convolution(gO, ggW, Tensor(), gi_conv_params.stride, gi_conv_params.padding, gi_conv_params.dilation, gi_conv_params.transposed, gi_conv_params.output_padding, gi_conv_params.groups); |
1906 | } |
1907 | } |
1908 | } |
1909 | |
1910 | return std::tuple<Tensor,Tensor,Tensor>{ggO, gI, gW}; |
1911 | } |
1912 | |
1913 | std::tuple<at::Tensor, at::Tensor, at::Tensor> _convolution_backward_nogroup_backend( |
1914 | const Tensor& grad_output, |
1915 | const Tensor& input, |
1916 | const Tensor& weight, |
1917 | const std::array<bool, 3> output_mask, |
1918 | const ConvBackend backend, |
1919 | const ConvParams<int64_t>& params) { |
1920 | auto kernel_size = weight.sizes().slice(2); |
1921 | switch(backend) { |
1922 | case ConvBackend::Slow2d: |
1923 | return at::_slow_conv2d_backward( |
1924 | grad_output, input, weight, kernel_size, params.stride, params.padding, output_mask); |
1925 | // NB: nnpack backward does not support strided convolutions; use slow impl instead |
1926 | case ConvBackend::NnpackSpatial: |
1927 | case ConvBackend::SlowDilated2d: |
1928 | return slow_conv_dilated2d_backward_stub( |
1929 | input.device().type(), |
1930 | grad_output, input, weight, kernel_size, params.stride, params.padding, params.dilation, output_mask); |
1931 | case ConvBackend::SlowDilated3d: |
1932 | return slow_conv_dilated3d_backward_stub( |
1933 | input.device().type(), |
1934 | grad_output, input, weight, kernel_size, params.stride, params.padding, params.dilation, output_mask); |
1935 | case ConvBackend::SlowTranspose2d: |
1936 | return slow_conv_transpose2d_backward_stub( |
1937 | input.device().type(), grad_output, input, weight, kernel_size, params.stride, params.padding, |
1938 | params.output_padding, params.dilation, output_mask); |
1939 | case ConvBackend::SlowTranspose3d: |
1940 | return slow_conv_transpose3d_backward_stub( |
1941 | input.device().type(), grad_output, input, weight, kernel_size, params.stride, params.padding, |
1942 | params.output_padding, params.dilation, output_mask); |
1943 | default: |
1944 | TORCH_CHECK(false, "Unsupported conv nogroup backend encountered"); |
1945 | } |
1946 | } |
1947 | |
1948 | // Backward pass for convolution. Computes gradients for input, weight, and bias depending on the |
1949 | // output_mask setting. This function supports 1D, 2D, or 3D spatial convolution and currently requires |
1950 | // a single batch dimension to be present. |
1951 | // |
1952 | // Args: |
1953 | // grad_output_: tensor of shape (N, C_out, L_out), (N, C_out, H_out, W_out), or (N, C_out, D_out, H_out, W_out) |
1954 | // input_: tensor of shape (N, C_in, L_in), (N, C_in, H_in, W_in), or (N, C_in, D_in, H_in, W_in) |
1955 | // weight_: tensor of shape (C_out, C_in // groups, *kernel_size); dimension of kernel_size must match the number |
1956 | // of input spatial dimensions |
1957 | // bias_sizes_opt: if specified, indicates that a bias was used in the forward pass and contains the shape |
1958 | // of the bias. While the bias shape can be computed from other inputs, it is provided to this function for |
1959 | // ease of use. The bias shape is (weight.shape[0]) for normal convolution and (weight.shape[1] * groups) |
1960 | // for transposed convolution. |
1961 | // stride: single value or an array with dimension matching the number of input spatial dimensions |
1962 | // padding: single value or an array with dimension matching the number of input spatial dimensions |
1963 | // dilation: single value or an array with dimension matching the number of input spatial dimensions |
1964 | // transposed: boolean indicating whether the convolution is transposed |
1965 | // output_padding: single value or an array with dimension matching the number of input spatial dimensions; |
1966 | // only supported when transposed is true |
1967 | // groups: number of groups for grouped convolution |
1968 | // output_mask: 3-dim boolean array specifying which gradients to compute in input, weight, bias order |
1969 | std::tuple<Tensor, Tensor, Tensor> convolution_backward( |
1970 | const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_, |
1971 | const at::OptionalIntArrayRef bias_sizes_opt, |
1972 | IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding, |
1973 | int64_t groups, std::array<bool, 3> output_mask) { |
1974 | auto grad_output = grad_output_; |
1975 | auto input = input_; |
1976 | auto weight = weight_; |
1977 | |
1978 | auto k = weight.ndimension(); |
1979 | int64_t dim = k - 2; |
1980 | |
1981 | TORCH_CHECK(dim > 0, "weight should have at least three dimensions"); |
1982 | |
1983 | auto& ctx = at::globalContext(); |
1984 | ConvParams<int64_t> params; |
1985 | params.stride = expand_param_if_needed(stride, "stride", dim); |
1986 | params.padding = expand_param_if_needed(padding, "padding", dim); |
1987 | params.dilation = expand_param_if_needed(dilation, "dilation", dim); |
1988 | params.transposed = transposed; |
1989 | params.output_padding = expand_param_if_needed(output_padding, "output_padding", dim); |
1990 | params.groups = groups; |
1991 | params.benchmark = ctx.benchmarkCuDNN(); |
1992 | params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); |
1993 | params.cudnn_enabled = ctx.userEnabledCuDNN(); |
1994 | params.allow_tf32 = ctx.allowTF32CuDNN(); |
1995 | |
1996 | // Validate inputs. |
1997 | check_shape_backward(input, weight.sizes(), params); |
1998 | TORCH_CHECK(input.dim() == grad_output.dim(), |
1999 | "Expected input and grad_output to have the same number of dimensions, but got: " , |
2000 | input.dim(), " and " , grad_output.dim()); |
2001 | |
2002 | // output_padding is only supported for transposed convolutions |
2003 | if (!params.transposed) { |
2004 | for (auto pad : params.output_padding) { |
2005 | TORCH_CHECK(pad == 0, "output_padding is not supported for non-transposed convolutions; got: ", |
2006 | params.output_padding); |
2007 | } |
2008 | } |
2009 | |
2010 | // Expand 1d -> 2d. |
2011 | // This is only done for backends that don't natively support 1d spatial input. |
2012 | if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { |
2013 | // avoid accidentally going through NHWC for permuted 3d input. |
2014 | input = input.contiguous(); |
2015 | params.view1d_as_2d(); |
2016 | grad_output = view4d(grad_output); |
2017 | input = view4d(input); |
2018 | weight = view4d(weight); |
2019 | } |
2020 | |
2021 | // Select appropriate backend to use. |
2022 | ConvBackend backend = select_conv_backend(input, weight, bias_sizes_opt, /*need_backward=*/ true, params); |
2023 | at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend); |
2024 | |
2025 | // Call the backend. |
2026 | Tensor backend_grad_input, backend_grad_weight, backend_grad_bias; |
2027 | auto kernel_size = weight.sizes().slice(2); |
2028 | switch(backend) { |
2029 | case ConvBackend::CudaDepthwise2d: |
2030 | { |
2031 | std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]}; |
2032 | std::tie(backend_grad_input, backend_grad_weight) = |
2033 | conv_depthwise2d_backward_stub(input.device().type(), grad_output, input, |
2034 | weight, kernel_size, params.stride, params.padding, params.dilation, input_weight_output_mask); |
2035 | break; |
2036 | } |
2037 | case ConvBackend::CudaDepthwise3d: |
2038 | TORCH_CHECK(input.ndimension() == 5, "CudaDepthwise3d backend expects a 5d input"); |
2039 | std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) = |
2040 | conv_depthwise3d_backward_stub( |
2041 | input.device().type(), grad_output, input, weight, kernel_size, params.stride, |
2042 | params.padding, params.dilation, output_mask); |
2043 | break; |
2044 | case ConvBackend::Cudnn: |
2045 | { |
2046 | check_input_same_type_as_parameters(input, weight); |
2047 | std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]}; |
2048 | std::tie(backend_grad_input, backend_grad_weight) = cudnn_convolution_backward_stub( |
2049 | input.device().type(), |
2050 | // Only make input contiguous when it is necessary for the backwards computation |
2051 | output_mask[1] ? input.contiguous(backend_memory_format) : input, |
2052 | grad_output, weight, params.padding, params.stride, |
2053 | params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32, |
2054 | input_weight_output_mask); |
2055 | break; |
2056 | } |
2057 | case ConvBackend::Mps: |
2058 | { |
2059 | #ifdef USE_MPS |
2060 | check_input_same_type_as_parameters(input, weight); |
2061 | std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) = |
2062 | at::mps_convolution_backward(input, grad_output, weight, params.padding, |
2063 | params.stride, params.dilation, params.groups, output_mask); |
2064 | #else |
2065 | TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support" ); |
2066 | #endif |
2067 | break; |
2068 | } |
2069 | case ConvBackend::MpsTranspose: |
2070 | { |
2071 | #ifdef USE_MPS |
2072 | check_input_same_type_as_parameters(input, weight); |
2073 | std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]}; |
2074 | std::tie(backend_grad_input, backend_grad_weight) = at::mps_convolution_transpose_backward( |
2075 | // Only make input contiguous when it is necessary for the backwards computation |
2076 | output_mask[1] ? input.contiguous(backend_memory_format) : input, |
2077 | grad_output, weight, params.padding, params.output_padding, |
2078 | params.stride, params.dilation, params.groups, input_weight_output_mask); |
2079 | #else |
2080 | TORCH_INTERNAL_ASSERT(false, "MPS backend was selected in PyTorch without support" ); |
2081 | #endif |
2082 | break; |
2083 | } |
2084 | case ConvBackend::CudnnTranspose: |
2085 | { |
2086 | check_input_same_type_as_parameters(input, weight); |
2087 | std::array<bool, 2> input_weight_output_mask = {output_mask[0], output_mask[1]}; |
2088 | std::tie(backend_grad_input, backend_grad_weight) = cudnn_convolution_transpose_backward_stub( |
2089 | input.device().type(), |
2090 | // Only make input contiguous when it is necessary for the backwards computation |
2091 | output_mask[1] ? input.contiguous(backend_memory_format) : input, |
2092 | grad_output, weight, params.padding, params.output_padding, |
2093 | params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32, |
2094 | input_weight_output_mask); |
2095 | break; |
2096 | } |
2097 | case ConvBackend::Empty: |
2098 | if (output_mask[0]) { |
2099 | backend_grad_input = at::zeros_like(input); |
2100 | } |
2101 | if (output_mask[1]) { |
2102 | backend_grad_weight = at::zeros_like(weight); |
2103 | } |
2104 | if (output_mask[2]) { |
2105 | backend_grad_bias = at::zeros(*bias_sizes_opt, weight.options()); |
2106 | } |
2107 | break; |
2108 | case ConvBackend::MkldnnEmpty: |
2109 | #if AT_MKLDNN_ENABLED() |
2110 | if (output_mask[0]) { |
2111 | if (input.is_mkldnn()) { |
2112 | backend_grad_input = empty_mkldnn(input.sizes(), optTypeMetaToScalarType(input.options().dtype_opt()), |
2113 | input.options().layout_opt(), input.options().device_opt(), input.options().pinned_memory_opt()); |
2114 | backend_grad_input.zero_(); |
2115 | } else { |
2116 | backend_grad_input = at::zeros_like(input); |
2117 | } |
2118 | } |
2119 | if (output_mask[1]) { |
2120 | // mkldnn weight is not supported during training by the mkldnn backend |
2121 | backend_grad_weight = at::zeros_like(weight); |
2122 | } |
2123 | if (output_mask[2]) { |
2124 | // mkldnn bias is not supported during training by the mkldnn backend |
2125 | backend_grad_bias = at::zeros(*bias_sizes_opt, weight.options()); |
2126 | } |
2127 | #else |
2128 | TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support" ); |
2129 | #endif |
2130 | break; |
2131 | case ConvBackend::Miopen: |
2132 | check_input_same_type_as_parameters(input, weight); |
2133 | std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) = |
2134 | miopen_convolution_backward_stub( |
2135 | input.device().type(), |
2136 | input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.stride, |
2137 | params.dilation, params.groups, params.benchmark, params.deterministic, output_mask); |
2138 | break; |
2139 | case ConvBackend::MiopenDepthwise: |
2140 | std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) = |
2141 | miopen_depthwise_convolution_backward_stub( |
2142 | input.device().type(), |
2143 | input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.stride, |
2144 | params.dilation, params.groups, params.benchmark, params.deterministic, output_mask); |
2145 | break; |
2146 | case ConvBackend::MiopenTranspose: |
2147 | check_input_same_type_as_parameters(input, weight); |
2148 | std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) = |
2149 | miopen_convolution_transpose_backward_stub( |
2150 | input.device().type(), |
2151 | input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.output_padding, |
2152 | params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, output_mask); |
2153 | break; |
2154 | case ConvBackend::Mkldnn: |
2155 | TORCH_CHECK(!weight.is_mkldnn(), |
2156 | "The MKLDNN backend does not support weight as an MKLDNN tensor during training" ); |
2157 | if (!input.is_mkldnn()) { |
2158 | input = input.contiguous(backend_memory_format); |
2159 | weight = weight.contiguous(backend_memory_format); |
2160 | } |
2161 | std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) = |
2162 | mkldnn_convolution_backward_stub(input.device().type(), input, grad_output, weight, params.padding, |
2163 | params.stride, params.dilation, params.groups, output_mask); |
2164 | break; |
2165 | case ConvBackend::MkldnnTranspose: |
2166 | TORCH_CHECK(!weight.is_mkldnn(), |
2167 | "The MKLDNN backend does not support weight as an MKLDNN tensor during training" ); |
2168 | if (!input.is_mkldnn()) { |
2169 | input = input.contiguous(backend_memory_format); |
2170 | weight = weight.contiguous(backend_memory_format); |
2171 | } |
2172 | std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) = |
2173 | mkldnn_convolution_transpose_backward_stub(input.device().type(), input, grad_output, weight, params.padding, |
2174 | params.output_padding, params.stride, params.dilation, params.groups, output_mask); |
2175 | break; |
2176 | case ConvBackend::Overrideable: |
2177 | // Only reached when the input is on a backend with an out-of-source implementation. |
2178 | std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) = |
2179 | at::convolution_backward_overrideable(grad_output, input, weight, params.stride, params.padding, |
2180 | params.dilation, params.transposed, params.output_padding, params.groups, output_mask); |
2181 | break; |
2182 | case ConvBackend::Slow3d: |
2183 | // Note that no CUDA implementation of this kernel exists currently. |
2184 | std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) = |
2185 | slow_conv3d_backward_cpu( |
2186 | grad_output, input, weight, kernel_size, |
2187 | params.stride, params.padding, output_mask); |
2188 | break; |
2189 | // Handle backends that don't natively support groups > 1. |
2190 | case ConvBackend::NnpackSpatial: |
2191 | case ConvBackend::Slow2d: |
2192 | case ConvBackend::SlowDilated2d: |
2193 | case ConvBackend::SlowDilated3d: |
2194 | case ConvBackend::SlowTranspose2d: |
2195 | case ConvBackend::SlowTranspose3d: |
2196 | { |
2197 | input = input.contiguous(backend_memory_format); |
2198 | weight = weight.contiguous(backend_memory_format); |
2199 | if (params.groups == 1) { |
2200 | std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) = |
2201 | _convolution_backward_nogroup_backend( |
2202 | grad_output, input, weight, output_mask, backend, params); |
2203 | } else { |
2204 | std::vector<Tensor> backend_grad_inputs(params.groups); |
2205 | std::vector<Tensor> backend_grad_weights(params.groups); |
2206 | std::vector<Tensor> backend_grad_biases(params.groups); |
2207 | for (const auto g : c10::irange(params.groups)) { |
2208 | auto grad_output_g = subtensor(grad_output, 1, params.groups, g); |
2209 | auto input_g = subtensor(input, 1, params.groups, g); |
2210 | auto weight_g = subtensor(weight, 0, params.groups, g); |
2211 | std::tie(backend_grad_inputs[g], backend_grad_weights[g], backend_grad_biases[g]) = |
2212 | _convolution_backward_nogroup_backend( |
2213 | grad_output_g, input_g, weight_g, output_mask, backend, params); |
2214 | } |
2215 | if (output_mask[0]) { |
2216 | backend_grad_input = at::cat(backend_grad_inputs, 1); |
2217 | } |
2218 | if (output_mask[1]) { |
2219 | backend_grad_weight = at::cat(backend_grad_weights, 0); |
2220 | } |
2221 | if (output_mask[2]) { |
2222 | backend_grad_bias = at::cat(backend_grad_biases, 0); |
2223 | } |
2224 | } |
2225 | break; |
2226 | } |
2227 | // Backward is not supported for these backends. |
2228 | case ConvBackend::Winograd3x3Depthwise: |
2229 | TORCH_CHECK(false, "Backward is not supported for depthwise 3x3 winograd" ); |
2230 | break; |
2231 | case ConvBackend::Xnnpack2d: |
2232 | TORCH_CHECK(false, "Backward is not supported for xnnpack" ); |
2233 | break; |
2234 | } |
2235 | |
2236 | // Convert 2D inputs back to 1D for backends that don't natively support 1D |
2237 | // spatial inputs. |
2238 | if (output_mask[0]) { |
2239 | if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { |
2240 | backend_grad_input = view3d(backend_grad_input); |
2241 | } |
2242 | } |
2243 | if (output_mask[1]) { |
2244 | if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { |
2245 | backend_grad_weight = view3d(backend_grad_weight); |
2246 | } |
2247 | } |
2248 | if (output_mask[2]) { |
2249 | if (!backend_grad_bias.defined()) { |
2250 | // Calculate bias gradients outside of the backend for those that don't support it. |
2251 | backend_grad_bias = grad_output.sum((dim == 3) ? IntArrayRef{0, 2, 3, 4} : IntArrayRef{0, 2, 3}); |
2252 | } |
2253 | } |
2254 | |
2255 | return std::make_tuple(backend_grad_input, backend_grad_weight, backend_grad_bias); |
2256 | } |
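// Illustrative usage sketch (hypothetical caller): request only the weight
// gradient via output_mask so backends can skip grad_input/grad_bias work.
//
//   auto grads = at::native::convolution_backward(
//       grad_out, input, weight, /*bias_sizes_opt=*/c10::nullopt,
//       /*stride=*/{1, 1}, /*padding=*/{0, 0}, /*dilation=*/{1, 1},
//       /*transposed=*/false, /*output_padding=*/{0, 0}, /*groups=*/1,
//       /*output_mask=*/{{false, true, false}});
//   at::Tensor grad_weight = std::get<1>(grads);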
2257 | |
2258 | }} // at::native |
2259 | |