#pragma once
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/native/DispatchStub.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>

namespace at { namespace native {

using conv_depthwise2d_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 2>);
DECLARE_DISPATCH(conv_depthwise2d_backward_fn, conv_depthwise2d_backward_stub);
using conv_depthwise3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(conv_depthwise3d_backward_fn, conv_depthwise3d_backward_stub);
using cudnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
DECLARE_DISPATCH(cudnn_convolution_backward_fn, cudnn_convolution_backward_stub);
using mps_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mps_convolution_backward_fn, mps_convolution_backward_stub);
using cudnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, bool, std::array<bool,2>);
DECLARE_DISPATCH(cudnn_convolution_transpose_backward_fn, cudnn_convolution_transpose_backward_stub);
using miopen_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_convolution_backward_fn, miopen_convolution_backward_stub);
using miopen_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_convolution_transpose_backward_fn, miopen_convolution_transpose_backward_stub);
using miopen_depthwise_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, bool, bool, std::array<bool,3>);
DECLARE_DISPATCH(miopen_depthwise_convolution_backward_fn, miopen_depthwise_convolution_backward_stub);
using mkldnn_convolution_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mkldnn_convolution_backward_fn, mkldnn_convolution_backward_stub);
using mkldnn_convolution_transpose_fn = Tensor(*)(const Tensor&, const Tensor&, const c10::optional<Tensor>&,
    IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t);
DECLARE_DISPATCH(mkldnn_convolution_transpose_fn, mkldnn_convolution_transpose_stub);
using mkldnn_convolution_transpose_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, int64_t, std::array<bool,3>);
DECLARE_DISPATCH(mkldnn_convolution_transpose_backward_fn, mkldnn_convolution_transpose_backward_stub);
using slow_conv_dilated2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(slow_conv_dilated2d_backward_fn, slow_conv_dilated2d_backward_stub);
using slow_conv_dilated3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, std::array<bool, 3>);
DECLARE_DISPATCH(slow_conv_dilated3d_backward_fn, slow_conv_dilated3d_backward_stub);
using slow_conv_transpose2d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
DECLARE_DISPATCH(slow_conv_transpose2d_backward_fn, slow_conv_transpose2d_backward_stub);
using slow_conv_transpose3d_backward_fn = std::tuple<at::Tensor,at::Tensor,at::Tensor>(*)(
    const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef,
    at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, std::array<bool,3>);
DECLARE_DISPATCH(slow_conv_transpose3d_backward_fn, slow_conv_transpose3d_backward_stub);

namespace {
  static bool cudnnv8_heuristic_mode_b = c10::utils::check_env("TORCH_CUDNN_USE_HEURISTIC_MODE_B") == true;
}

static inline bool cudnnv8_enabled_check_debug() {
  static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true;
  static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true;
  static uint8_t cudnnv8_debugcount = 0;
  if (cudnnv8_debug && cudnnv8_debugcount < 10) {
    TORCH_WARN("TORCH_CUDNN_V8_API_DEBUG ON, V8 API ON: ", cudnnv8_flag,
        " TORCH_CUDNN_USE_HEURISTIC_MODE_B: ", cudnnv8_heuristic_mode_b);
    cudnnv8_debugcount++;
  }
  return cudnnv8_flag;
}

static inline bool cudnnv8_use_heur_mode_b() {
  return cudnnv8_heuristic_mode_b;
}
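
// A minimal illustration of the environment toggles above (the script name is just a
// hypothetical example, not part of this header):
//
//   TORCH_CUDNN_V8_API_DEBUG=1 TORCH_CUDNN_USE_HEURISTIC_MODE_B=1 python train.py
//
// With TORCH_CUDNN_V8_API_DEBUG set, cudnnv8_enabled_check_debug() emits the warning
// above for its first 10 calls; setting TORCH_CUDNN_V8_API_DISABLED=1 makes it return
// false, and cudnnv8_use_heur_mode_b() simply reports whether
// TORCH_CUDNN_USE_HEURISTIC_MODE_B was set.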

// Keep in sync with py::enum_ in Module.cpp
enum class ConvBackend {
  CudaDepthwise2d,
  CudaDepthwise3d,
  Cudnn,
  CudnnTranspose,
  Empty,
  Miopen,
  MiopenDepthwise,
  MiopenTranspose,
  Mkldnn,
  MkldnnTranspose,
  MkldnnEmpty,
  NnpackSpatial,
  Overrideable,
  Slow2d,
  Slow3d,
  SlowDilated2d,
  SlowDilated3d,
  SlowTranspose2d,
  SlowTranspose3d,
  Winograd3x3Depthwise,
  Xnnpack2d,
  Mps,
  MpsTranspose,
};

// Overload for selecting the convolution backend from the full set of convolution inputs.
// This overload is exposed to python for testing, etc.
TORCH_API ConvBackend select_conv_backend(
    const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt,
    IntArrayRef stride, SymIntArrayRef padding, IntArrayRef dilation,
    bool transposed, SymIntArrayRef output_padding, int64_t groups, const at::OptionalSymIntArrayRef bias_sizes_opt);

TORCH_API at::MemoryFormat _determine_backend_memory_format(const Tensor& input,
    const Tensor& weight,
    const ConvBackend backend);

// ---------------------------------------------------------------------
//
// Math
//
// ---------------------------------------------------------------------

constexpr int input_batch_size_dim = 0;  // also grad_input
constexpr int input_channels_dim = 1;
constexpr int output_batch_size_dim = 0;  // also grad_output
constexpr int output_channels_dim = 1;
constexpr int weight_output_channels_dim = 0;
constexpr int weight_input_channels_dim = 1;

// Often written as 2 + max_dim (extra dims for batch size and channels)
constexpr int max_dim = 3;
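// For example, the largest supported case is 3-d convolution, whose tensors have
// 2 + max_dim == 5 dimensions: (batch, channels, depth, height, width).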

// ---------------------------------------------------------------------
//
// Checking
//
// ---------------------------------------------------------------------

// Used on pad, stride and dilation
static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
{
  TORCH_CHECK(args.size() <= expected_size,
      "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
      expected_size, " (while checking arguments for ", c, ")");
  TORCH_CHECK(args.size() >= expected_size,
      "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
      expected_size, " (while checking arguments for ", c, ")");

  auto num_negative_values = std::count_if(args.begin(), args.end(), [](int64_t x){ return x < 0; });
  if (num_negative_values > 0) {
    std::stringstream ss;
    ss << arg_name << " should be greater than zero but got (";
    std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int64_t>(ss, ", "));
    ss << args.back() << ")" << " (while checking arguments for " << c << ")";
    AT_ERROR(ss.str());
  }
}
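
// A concrete (hypothetical) illustration: convolution_shape_check below passes
// expected_size = input->dim() - 2, so for a 4-d input (2-d convolution) a padding of
// {1, 1, 1} fails with "Too many padding values (3) supplied, expecting 2", and a
// stride containing a negative entry, e.g. {-1, 1}, trips the negative-value check.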


// NOTE [ Convolution checks ]
//
// NB: For many call sites, it is not strictly necessary to check all of
// these relationships (for example, for forward convolution, we compute
// the size of output ourselves, so we don't actually need to check
// output). However, writing a single function that does everything
// means we get to reuse it for both forwards and all backwards
// variants, even when the set of "real" inputs varies. The magic of
// relational computing!
//
// (There is one downside, which is that it is slightly harder to write
// error messages which are able to distinguish between real inputs
// (which the user can change) and computed inputs (which the user can
// only indirectly affect). It would be an interesting exercise to
// come up with a general framework to handle such situations.)
static void convolution_shape_check(
    CheckedFrom c,
    const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
  check_args(c, padding, input->dim() - 2, "padding");
  check_args(c, stride, padding.size(), "stride");
  check_args(c, dilation, padding.size(), "dilation");

  // Input
  checkDimRange(c, input, 3, 6 /* exclusive */);
  checkSize_symint(c, input, input_channels_dim, weight->size(1) * groups);

  // Weight
  checkSameDim(c, input, weight);

  // TODO: check that output->size() matches output_sizes
  // TODO: check that weight matches output->sizes()
  checkSameDim(c, input, output);
}

// NB: conv_output_size and conv_input_size are not bijections,
// as conv_output_size loses information; this is why conv_input_size
// takes an extra output_padding argument to resolve the ambiguity.
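//
// A small worked example of that ambiguity: with kernel 3, stride 2, padding 1 and no
// dilation, spatial input sizes 5 and 6 both produce output size 3
// ((5 + 2 - 3) / 2 + 1 == (6 + 2 - 3) / 2 + 1 == 3), so recovering the input size from
// the output size needs output_padding (0 gives back 5, 1 gives back 6).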

template <typename T>
static inline std::vector<T> _conv_output_size(
    ArrayRef<T> input_size, ArrayRef<T> weight_size,
    ArrayRef<T> padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
) {
  // ASSERT(input_size.size() > 2)
  // ASSERT(input_size.size() == weight_size.size())
  bool has_dilation = !dilation.empty();
  auto dim = input_size.size();
  std::vector<T> output_size(dim);
  output_size[0] = input_size[input_batch_size_dim];
  output_size[1] = weight_size[weight_output_channels_dim];
  for (const auto d : c10::irange(2, dim)) {
    auto dilation_ = has_dilation ? dilation[d - 2] : 1;
    auto kernel = dilation_ * (weight_size[d] - 1) + 1;
    output_size[d] = (input_size[d] + (2 * padding[d - 2]) - kernel) / stride[d - 2] + 1;
  }
  return output_size;
}
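
// For instance (hypothetical sizes): conv_output_size({1, 3, 5, 5}, {8, 3, 3, 3},
// /*padding=*/{1, 1}, /*stride=*/{2, 2}) gives {1, 8, 3, 3}; batch size is copied from
// the input, channels come from the weight's output-channel dimension, and each spatial
// dim is (5 + 2*1 - 3) / 2 + 1 == 3.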

static inline std::vector<int64_t> conv_output_size(
    IntArrayRef input_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
) {
  return _conv_output_size(input_size, weight_size, padding, stride, dilation);
}

static inline std::vector<c10::SymInt> conv_output_size(
    SymIntArrayRef input_size, SymIntArrayRef weight_size,
    SymIntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef()
) {
  return _conv_output_size(input_size, weight_size, padding, stride, dilation);
}

template <typename T>
std::vector<T> _conv_input_size(
    ArrayRef<T> output_size, ArrayRef<T> weight_size,
    ArrayRef<T> padding, ArrayRef<T> output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  // ASSERT(output_size.size() > 2)
  // ASSERT(output_size.size() == weight_size.size())
  auto dim = output_size.size();
  std::vector<T> input_size(dim);
  input_size[0] = output_size[output_batch_size_dim];
  input_size[1] = weight_size[weight_input_channels_dim] * groups;
  for (const auto d : c10::irange(2, dim)) {
    auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1;
    input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) +
        kernel + output_padding[d - 2];
  }
  return input_size;
}
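
// Continuing the hypothetical sizes used above: conv_input_size({1, 8, 3, 3},
// {8, 3, 3, 3}, /*padding=*/{1, 1}, /*output_padding=*/{0, 0}, /*stride=*/{2, 2},
// /*dilation=*/{1, 1}, /*groups=*/1) recovers {1, 3, 5, 5}; passing output_padding
// {1, 1} instead yields {1, 3, 6, 6}.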

static inline std::vector<c10::SymInt> conv_input_size(
    SymIntArrayRef output_size, SymIntArrayRef weight_size,
    SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
}

static inline std::vector<int64_t> conv_input_size(
    IntArrayRef output_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups);
}

template <typename T>
std::vector<T> _conv_weight_size(
    ArrayRef<T> input_size, ArrayRef<T> output_size,
    ArrayRef<T> padding, ArrayRef<T> output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  auto dim = input_size.size();
  std::vector<T> weight_size(dim);
  weight_size[0] = output_size[1];
  weight_size[1] = input_size[1] / groups;
  for (const auto d : c10::irange(2, dim)) {
    auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
        + padding[d - 2] * 2 - output_padding[d - 2];
    weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;
  }
  return weight_size;
}
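
// With the same hypothetical sizes: conv_weight_size({1, 3, 5, 5}, {1, 8, 3, 3},
// /*padding=*/{1, 1}, /*output_padding=*/{0, 0}, /*stride=*/{2, 2}, /*dilation=*/{1, 1},
// /*groups=*/1) gives {8, 3, 3, 3}: output channels, input channels / groups, then
// kernel = 5 - (3 - 1)*2 + 2*1 - 0 = 3 per spatial dim.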

static inline std::vector<c10::SymInt> conv_weight_size(
    SymIntArrayRef input_size, SymIntArrayRef output_size,
    SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
}

static inline std::vector<int64_t> conv_weight_size(
    IntArrayRef input_size, IntArrayRef output_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups);
}

static inline Tensor reshape_bias(int64_t dim, const Tensor& bias) {
  std::vector<int64_t> shape(dim, 1);
  shape[1] = -1;
  return bias.reshape(shape);
}
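
// e.g. for a 2-d convolution (dim == 4) a bias of shape {C} is reshaped to {1, C, 1, 1}
// so it broadcasts over the batch and spatial dimensions of an (N, C, H, W) output.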

static inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) {
  // Disable NHWC when the input or weight is float64.
  if (!at::detail::getCUDAHooks().compiledWithCuDNN() ||
      input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
    return at::MemoryFormat::Contiguous;
  }
  long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();
  auto weight_ndim = weight.ndimension();

  bool can_use_cudnn_channels_last_2d = (cudnn_version >= 7603) && (weight_ndim == 4) && (
      (input_memory_format == at::MemoryFormat::ChannelsLast) ||
      (weight_memory_format == at::MemoryFormat::ChannelsLast)
  );
  if (can_use_cudnn_channels_last_2d) {
    return at::MemoryFormat::ChannelsLast;
  }

  bool can_use_cudnn_channels_last_3d = (cudnn_version >= 8005) && (weight_ndim == 5) && (
      (input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
      (weight_memory_format == at::MemoryFormat::ChannelsLast3d)
  );
  if (can_use_cudnn_channels_last_3d) {
    return at::MemoryFormat::ChannelsLast3d;
  }

  return at::MemoryFormat::Contiguous;
}

static inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
  // Disable NHWC when the input or weight is float64.
  if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
      input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
    return false;
  }

  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();

  bool can_use_miopen_channels_last_2d = (
      (input_memory_format == at::MemoryFormat::ChannelsLast) ||
      (weight_memory_format == at::MemoryFormat::ChannelsLast)
  );

  bool can_use_miopen_channels_last_3d = false;

  return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
}

static inline bool mkldnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
  // Disable NHWC when the input or weight is float64.
  if (input.scalar_type() == at::kDouble ||
      weight.scalar_type() == at::kDouble) {
    return false;
  }

  // Disable NHWC for MkldnnCPU tensors.
  if (input.is_mkldnn() || weight.is_mkldnn()) {
    return false;
  }

  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();

  bool can_use_mkldnn_channels_last_2d =
      (input_memory_format == at::MemoryFormat::ChannelsLast) ||
      (weight_memory_format == at::MemoryFormat::ChannelsLast);

  // TODO: add channels last 3d support
  bool can_use_mkldnn_channels_last_3d = false;

  return can_use_mkldnn_channels_last_2d || can_use_mkldnn_channels_last_3d;
}

static inline bool thnn_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();

  bool can_use_thnn_channels_last_2d = input.device().is_cpu() && (
      (input_memory_format == at::MemoryFormat::ChannelsLast) ||
      (weight_memory_format == at::MemoryFormat::ChannelsLast));

  return can_use_thnn_channels_last_2d;
}

}} // namespace at::native