1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file convolution.cc
22 * \brief Convolution operators
23 */
24#include "convolution.h"
25
26#include <tvm/relay/attrs/nn.h>
27#include <tvm/relay/op.h>
28#include <tvm/tir/data_layout.h>
29
30#include <vector>
31
32#include "../../transforms/infer_layout_utils.h"
33#include "../op_common.h"
34#include "convolution_make.h"
35
36namespace tvm {
37namespace relay {
38
39Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_name) {
40 auto attrs = make_object<ConvWinogradWeightTransformAttrs>();
41 attrs->tile_size = tile_size;
42 const Op& op = Op::Get(op_name);
43 return Call(op, {weight}, Attrs(attrs), {});
44}
45
46Expr MakeConvGemmWeightTransform(Expr weight, int tile_rows, int tile_cols, std::string op_name) {
47 auto attrs = make_object<ConvGemmWeightTransformAttrs>();
48 attrs->tile_rows = tile_rows;
49 attrs->tile_cols = tile_cols;
50 const Op& op = Op::Get(op_name);
51 return Call(op, {weight}, Attrs(attrs), {});
52}
53
// relay.nn.conv1d
// Register Conv1DAttrs with the node system so it can be reflected and serialized.
TVM_REGISTER_NODE_TYPE(Conv1DAttrs);
56
57bool Conv1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
58 const TypeReporter& reporter) {
59 ICHECK_EQ(types.size(), 3);
60 const auto* data = types[0].as<TensorTypeNode>();
61 const auto* weight = types[1].as<TensorTypeNode>();
62 if (data == nullptr) return false;
63 static const Layout kNCW("NCW");
64 static const Layout kOIW("OIW");
65
66 const auto* param = attrs.as<Conv1DAttrs>();
67 ICHECK(param != nullptr);
68 const Layout in_layout(param->data_layout);
69 const Layout kernel_layout(param->kernel_layout);
70
71 const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW);
72 ICHECK(trans_in_layout.defined())
73 << "Conv only support input layouts that are convertible from NCW."
74 << " But got " << in_layout;
75
76 const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW);
77 ICHECK(trans_kernel_layout.defined())
78 << "Conv only support kernel layouts that are convertible from OIW."
79 << " But got " << kernel_layout;
80
81 Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
82 const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW);
83 ICHECK(trans_out_layout.defined())
84 << "Conv only support output layouts that are convertible from NCW."
85 << " But got " << out_layout;
86
87 Array<IndexExpr> dshape_ncw = trans_in_layout.ForwardShape(data->shape);
88
89 IndexExpr channels, dilated_ksize;
90 // infer weight if the kernel_size and channels are defined
91 if (param->kernel_size.defined() && param->channels.defined()) {
92 Array<IndexExpr> wshape;
93
94 wshape = {{param->channels, indexdiv(dshape_ncw[1], param->groups), param->kernel_size[0]}};
95
96 wshape = trans_kernel_layout.BackwardShape(wshape);
97 channels = param->channels;
98 dilated_ksize = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
99 DataType weight_dtype = data->dtype;
100 if (weight != nullptr) {
101 weight_dtype = weight->dtype;
102 }
103 // assign result to reporter
104 reporter->Assign(types[1], TensorType(wshape, weight_dtype));
105 } else {
106 // use weight to infer the conv shape.
107 if (weight == nullptr) return false;
108 auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
109 if (param->kernel_size.defined()) {
110 // check the size
111 ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
112 << "Conv1D: shape of weight is inconsistent with kernel_size, "
113 << " kernel_size=" << param->kernel_size << " wshape=" << wshape;
114 }
115 if (param->channels.defined()) {
116 ICHECK(reporter->AssertEQ(param->channels, wshape[0]))
117 << "Conv1D: shape of weight is inconsistent with channels, "
118 << " channels=" << param->channels << " wshape=" << wshape;
119 }
120 if (!dshape_ncw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
121 ICHECK(reporter->AssertEQ(dshape_ncw[1], wshape[1]));
122 }
123 channels = wshape[0];
124 dilated_ksize = 1 + (wshape[2] - 1) * param->dilation[0];
125 }
126 // dilation
127 Array<IndexExpr> oshape({dshape_ncw[0], channels, 0});
128
129 if (!dshape_ncw[2].as<tir::AnyNode>()) {
130 oshape.Set(2, indexdiv(dshape_ncw[2] + param->padding[0] + param->padding[1] - dilated_ksize,
131 param->strides[0]) +
132 1);
133 } else {
134 oshape.Set(2, dshape_ncw[2]);
135 }
136
137 DataType out_dtype = param->out_dtype;
138 if (out_dtype.bits() == 0) {
139 out_dtype = data->dtype;
140 }
141 oshape = trans_out_layout.BackwardShape(oshape);
142 // assign output type
143 reporter->Assign(types[2], TensorType(oshape, out_dtype));
144 return true;
145}
146
147TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d")
148 .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
149 Array<IndexExpr> dilation, int groups, IndexExpr channels,
150 Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
151 String out_layout, DataType out_dtype) {
152 return MakeConv<Conv1DAttrs>(data, weight, strides, padding, dilation, groups, channels,
153 kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
154 "nn.conv1d");
155 });
156
// Operator registration for nn.conv1d: documentation, arguments, type relation,
// layout-inference hook, and fusion pattern.
RELAY_REGISTER_OP("nn.conv1d")
    .describe(R"code(1D convolution layer (e.g. spatial convolution over sequences).

This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.

- **data**: This depends on the `layout` parameter. Input is 3D array of shape
            (batch_size, in_channels, width) if `layout` is `NCW`.
- **weight**: (channels, in_channels, kernel_size)
- **out**:  This depends on the `layout` parameter. Output is 3D array of shape
            (batch_size, channels, out_width) if `layout` is `NCW`.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv1DAttrs>()
    .set_num_inputs(2)  // data, weight
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .add_type_rel("Conv1D", Conv1DRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv1DAttrs>)
    // Output is elementwise-fusable with downstream ops.
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
178
// relay.nn.conv2d
// Register Conv2DAttrs with the node system so it can be reflected and serialized.
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
181
182bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
183 const TypeReporter& reporter) {
184 ICHECK_EQ(types.size(), 3);
185 const auto* data = types[0].as<TensorTypeNode>();
186 const auto* weight = types[1].as<TensorTypeNode>();
187 if (data == nullptr) return false;
188 static const Layout kNCHW("NCHW");
189 Layout kOIHW("OIHW");
190
191 const auto* param = attrs.as<Conv2DAttrs>();
192 DataType out_dtype = param->out_dtype;
193 if (out_dtype.bits() == 0) {
194 out_dtype = data->dtype;
195 if (out_dtype.bits() == 0 && weight != nullptr) {
196 out_dtype = weight->dtype;
197 }
198 }
199 TensorType meta_schedule_weight{nullptr};
200 if (param->meta_schedule_original_shape.size() != 0) {
201 meta_schedule_weight = TensorType(param->meta_schedule_original_shape, out_dtype);
202 weight = meta_schedule_weight.get();
203 }
204 ICHECK(param != nullptr);
205 const Layout in_layout(param->data_layout);
206 const Layout kernel_layout(param->kernel_layout);
207
208 bool is_dnnl_group_conv = false;
209 if (param->groups > 1 && kernel_layout.name().find("G") != std::string::npos) {
210 kOIHW = Layout("GOIHW");
211 is_dnnl_group_conv = true;
212 }
213
214 const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
215 if (!trans_in_layout.defined()) {
216 reporter->GetDiagCtx().Emit(
217 Diagnostic::Error(reporter->GetSpan())
218 << "conv2d only support input layouts that are convertible from NCHW."
219 << " The provided layout is: " << in_layout);
220 return false;
221 }
222
223 const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
224 if (!trans_kernel_layout.defined()) {
225 reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan())
226 << "conv2d only support kernel layouts that are convertible from "
227 << kOIHW << "."
228 << " The provided layout is: " << kernel_layout);
229 return false;
230 }
231
232 Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
233 const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
234 if (!trans_out_layout.defined()) {
235 reporter->GetDiagCtx().Emit(
236 Diagnostic::Error(reporter->GetSpan())
237 << "conv2d only support output layouts that are convertible from NCHW."
238 << "The provided layout is: " << out_layout);
239 return false;
240 }
241
242 Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
243 bool is_depthwise = false;
244 if (param->groups > 1) {
245 if (!(weight && weight->shape.defined())) {
246 reporter->GetDiagCtx().Emit(
247 Diagnostic::Error(reporter->GetSpan())
248 << "Weight shape must be specified when groups is greater than 1.");
249 return false;
250 }
251
252 Array<IndexExpr> wshape_oihw = trans_kernel_layout.ForwardShape(weight->shape);
253 if (tvm::tir::ExprDeepEqual()(param->groups, dshape_nchw[1]) &&
254 tvm::tir::ExprDeepEqual()(param->groups, wshape_oihw[0])) {
255 is_depthwise = true;
256 }
257 }
258
259 IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
260 // infer weight if the kernel_size and channels are defined
261 if (param->kernel_size.defined() && param->channels.defined()) {
262 ICHECK_EQ(param->kernel_size.size(), 2);
263 ICHECK_EQ(param->dilation.size(), 2);
264 Array<IndexExpr> wshape;
265
266 if (is_dnnl_group_conv) {
267 // infer weight's shape for group convolution
268 wshape = {{param->groups, indexdiv(param->channels, param->groups),
269 indexdiv(dshape_nchw[1], param->groups), param->kernel_size[0],
270 param->kernel_size[1]}};
271 } else if (is_depthwise) {
272 // infer weight's shape for depthwise convolution
273 wshape = {{dshape_nchw[1], indexdiv(param->channels, dshape_nchw[1]), param->kernel_size[0],
274 param->kernel_size[1]}};
275 } else {
276 wshape = {{param->channels, indexdiv(dshape_nchw[1], param->groups), param->kernel_size[0],
277 param->kernel_size[1]}};
278 }
279
280 wshape = trans_kernel_layout.BackwardShape(wshape);
281 channels = param->channels;
282 dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
283 dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
284 DataType weight_dtype = data->dtype;
285 if (weight != nullptr) {
286 weight_dtype = weight->dtype;
287 }
288
289 if (param->auto_scheduler_rewritten_layout.size() != 0) {
290 // If the layout is rewritten by auto-scheduler,
291 // we just forcly apply the layout provided by auto-scheduler and
292 // skip the normal inference logic.
293 {} // do nothing
294 } else if (param->meta_schedule_original_shape.size() == 0) {
295 // Normal case: assign result to reporter
296 reporter->Assign(types[1], TensorType(wshape, weight_dtype));
297 }
298 } else {
299 // use weight to infer the conv shape.
300 if (weight == nullptr) return false;
301
302 Array<PrimExpr> wshape;
303 if (param->auto_scheduler_rewritten_layout.size() != 0) {
304 // works for the default kernel layout "HWIO"
305 ICHECK_EQ(param->kernel_layout, "HWIO");
306 wshape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout,
307 {"ry", "rx", "rc", "ff"});
308 } else {
309 wshape = weight->shape;
310 }
311
312 wshape = trans_kernel_layout.ForwardShape(wshape);
313 if (param->kernel_size.defined()) {
314 ICHECK_EQ(param->kernel_size.size(), 2);
315
316 if (!reporter->AssertEQ(param->kernel_size[0], wshape[2])) {
317 reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan())
318 << "Conv2D: shape of weight is inconsistent with kernel_size,"
319 << " kernel_size=" << param->kernel_size
320 << " wshape=" << wshape);
321 }
322
323 if (!reporter->AssertEQ(param->kernel_size[1], wshape[3])) {
324 reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan())
325 << "Conv2D: shape of weight is inconsistent with kernel_size,"
326 << " kernel_size=" << param->kernel_size
327 << " wshape=" << wshape);
328 return false;
329 }
330 }
331
332 if (param->channels.defined() && !reporter->AssertEQ(param->channels, wshape[0])) {
333 reporter->GetDiagCtx().Emit(
334 Diagnostic::Error(reporter->GetSpan())
335 << "conv2D: the first dimensions of the weight tensor (" << wshape << ")"
336 << "does not match the number of channels (" << param->channels << ").");
337 return false;
338 }
339
340 if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
341 if (!reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1])) {
342 reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan())
343 << "conv2d: requires that `"
344 << indexdiv(dshape_nchw[1], param->groups) << "`,"
345 << " the input channels (" << dshape_nchw[1] << ")"
346 << " divided by groups (" << param->groups << ")"
347 << ",\n must match the input channels"
348 << " of the weight `" << wshape[1]
349 << "`, where the weight shape is (" << wshape << ").");
350 return false;
351 }
352 }
353 channels = wshape[0];
354 dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
355 dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
356 }
357 // dilation
358 Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
359
360 IndexExpr pad_h, pad_w;
361 GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
362 if (!dshape_nchw[2].as<tir::AnyNode>()) {
363 oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1);
364 } else {
365 oshape.Set(2, dshape_nchw[2]);
366 }
367
368 if (!dshape_nchw[3].as<tir::AnyNode>()) {
369 oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1);
370 } else {
371 oshape.Set(3, dshape_nchw[3]);
372 }
373 oshape = trans_out_layout.BackwardShape(oshape);
374 // assign output type
375 reporter->Assign(types[2], TensorType(oshape, out_dtype));
376 return true;
377}
378
379TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d")
380 .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
381 Array<IndexExpr> dilation, int groups, IndexExpr channels,
382 Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
383 String out_layout, DataType out_dtype) {
384 return MakeConv<Conv2DAttrs>(data, weight, strides, padding, dilation, groups, channels,
385 kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
386 "nn.conv2d");
387 });
388
// Operator registration for nn.conv2d: documentation, arguments, type relation,
// layout-inference hook, and fusion pattern.
RELAY_REGISTER_OP("nn.conv2d")
    .describe(R"code(2D convolution layer (e.g. spatial convolution over images).

This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.

- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
- **out**:  This depends on the `layout` parameter. Output is 4D array of shape
            (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DAttrs>()
    .set_num_inputs(2)  // data, weight
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .add_type_rel("Conv2D", Conv2DRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>)
    // Output is elementwise-fusable with downstream ops.
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
410
// relay.nn.conv3d
// Register Conv3DAttrs with the node system so it can be reflected and serialized.
TVM_REGISTER_NODE_TYPE(Conv3DAttrs);
413
/*!
 * \brief Type relation for nn.conv3d.
 *
 * types = [data_type, weight_type, out_type]. If kernel_size and channels
 * attributes are defined, the weight type is inferred from the data type;
 * otherwise the weight type is used to infer channels and kernel extents.
 * All shape arithmetic is performed in the canonical NCDHW/OIDHW layouts.
 */
bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* weight = types[1].as<TensorTypeNode>();
  if (data == nullptr) return false;
  static const Layout kNCDHW("NCDHW");
  static const Layout kOIDHW("OIDHW");

  const auto* param = attrs.as<Conv3DAttrs>();
  ICHECK(param != nullptr);
  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    // Unspecified output dtype: fall back to data dtype, then weight dtype.
    out_dtype = data->dtype;
    if (out_dtype.bits() == 0 && weight != nullptr) {
      out_dtype = weight->dtype;
    }
  }
  TensorType meta_schedule_weight{nullptr};
  if (param->meta_schedule_original_shape.size() != 0) {
    // meta-schedule rewrote the weight layout; use the recorded original shape.
    meta_schedule_weight = TensorType(param->meta_schedule_original_shape, out_dtype);
    weight = meta_schedule_weight.get();
  }
  const Layout in_layout(param->data_layout);
  const Layout kernel_layout(param->kernel_layout);

  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW);
  ICHECK(trans_in_layout.defined())
      << "Conv only support input layouts that are convertible from NCDHW."
      << " But got " << in_layout;

  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW);
  ICHECK(trans_kernel_layout.defined())
      << "Conv only support kernel layouts that are convertible from OIDHW."
      << " But got " << kernel_layout;

  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW);
  ICHECK(trans_out_layout.defined())
      << "Conv only support output layouts that are convertible from NCDHW."
      << " But got " << out_layout;

  Array<IndexExpr> dshape_ncdhw = trans_in_layout.ForwardShape(data->shape);

  IndexExpr channels, dilated_ksize_z, dilated_ksize_y, dilated_ksize_x;
  // infer weight if the kernel_size and channels are defined
  if (param->kernel_size.defined() && param->channels.defined()) {
    ICHECK_EQ(param->kernel_size.size(), 3);
    ICHECK_EQ(param->dilation.size(), 3);
    // Expected weight shape in OIDHW.
    Array<IndexExpr> wshape({param->channels, indexdiv(dshape_ncdhw[1], param->groups),
                             param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]});
    wshape = trans_kernel_layout.BackwardShape(wshape);
    channels = param->channels;
    // Effective (dilated) kernel extents per spatial axis.
    dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
    dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
    dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2];
    DataType weight_dtype = data->dtype;
    if (weight != nullptr) {
      weight_dtype = weight->dtype;
    }

    if (param->auto_scheduler_rewritten_layout.size() != 0) {
      // If the layout is rewritten by auto-scheduler,
      // we just forcly apply the layout provided by auto-scheduler and
      // skip the normal inference logic.
      {}  // do nothing
    } else if (param->meta_schedule_original_shape.size() == 0) {
      // Normal case: assign result to reporter
      reporter->Assign(types[1], TensorType(wshape, weight_dtype));
    }

  } else {
    // use weight to infer the conv shape.
    if (weight == nullptr) return false;

    Array<PrimExpr> wshape;
    if (param->auto_scheduler_rewritten_layout.size() != 0) {
      // works for the default kernel layout "DHWIO"
      ICHECK_EQ(param->kernel_layout, "DHWIO");
      wshape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout,
                                                           {"rd", "rh", "rw", "rc", "cc"});
    } else {
      wshape = weight->shape;
    }

    wshape = trans_kernel_layout.ForwardShape(wshape);
    if (param->kernel_size.defined()) {
      ICHECK_EQ(param->kernel_size.size(), 3);
      // check the size
      ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
             reporter->AssertEQ(param->kernel_size[1], wshape[3]) &&
             reporter->AssertEQ(param->kernel_size[2], wshape[4]))
          << "Conv3D: shape of weight is inconsistent with kernel_size, "
          << " kernel_size=" << param->kernel_size << " wshape=" << wshape;
    }

    if (param->channels.defined()) {
      ICHECK(reporter->AssertEQ(param->channels, wshape[0]))
          << "Conv3D: shape of weight is inconsistent with channels, "
          << " channels=" << param->channels << " wshape=" << wshape;
    }

    if (!dshape_ncdhw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
      ICHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[1]));
    }
    channels = wshape[0];
    dilated_ksize_z = 1 + (wshape[2] - 1) * param->dilation[0];
    dilated_ksize_y = 1 + (wshape[3] - 1) * param->dilation[1];
    dilated_ksize_x = 1 + (wshape[4] - 1) * param->dilation[2];
  }
  // dilation
  Array<IndexExpr> oshape({dshape_ncdhw[0], channels, 0, 0, 0});

  IndexExpr pad_d, pad_h, pad_w;
  GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
  // out = (in + pad - dilated_ksize) / stride + 1 per spatial axis;
  // dynamic (Any) extents are propagated unchanged.
  if (!dshape_ncdhw[2].as<tir::AnyNode>()) {
    oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z, param->strides[0]) + 1);
  } else {
    oshape.Set(2, dshape_ncdhw[2]);
  }

  if (!dshape_ncdhw[3].as<tir::AnyNode>()) {
    oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y, param->strides[1]) + 1);
  } else {
    oshape.Set(3, dshape_ncdhw[3]);
  }

  if (!dshape_ncdhw[4].as<tir::AnyNode>()) {
    oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x, param->strides[2]) + 1);
  } else {
    oshape.Set(4, dshape_ncdhw[4]);
  }
  oshape = trans_out_layout.BackwardShape(oshape);
  // assign output type
  reporter->Assign(types[2], TensorType(oshape, out_dtype));
  return true;
}
551
552TVM_REGISTER_GLOBAL("relay.op.nn._make.conv3d")
553 .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
554 Array<IndexExpr> dilation, int groups, IndexExpr channels,
555 Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
556 String out_layout, DataType out_dtype) {
557 return MakeConv<Conv3DAttrs>(data, weight, strides, padding, dilation, groups, channels,
558 kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
559 "nn.conv3d");
560 });
561
// Operator registration for nn.conv3d: documentation, arguments, type relation,
// layout-inference hook, and fusion pattern.
RELAY_REGISTER_OP("nn.conv3d")
    .describe(R"code(3D convolution layer (e.g. convolution over 3D image data,
like Magnetic Resonance Imaging (MRI) data in medicine).

This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.

- **data**: This depends on the `layout` parameter. Input is 5D array of shape
            (batch_size, in_channels, depth, height, width) if `layout` is `NCDHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2])
- **out**:  This depends on the `layout` parameter. Output is 5D array of shape
            (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv3DAttrs>()
    .set_num_inputs(2)  // data, weight
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .add_type_rel("Conv3D", Conv3DRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv3DAttrs>)
    // Output is elementwise-fusable with downstream ops.
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
584
// relay.nn.conv3d_transpose
// Register Conv3DTransposeAttrs with the node system so it can be reflected and serialized.
TVM_REGISTER_NODE_TYPE(Conv3DTransposeAttrs);
587
588template <typename AttrType>
589bool Conv3DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
590 const TypeReporter& reporter) {
591 ICHECK_EQ(types.size(), 3);
592 const auto* data = types[0].as<TensorTypeNode>();
593 const auto* weight = types[1].as<TensorTypeNode>();
594 if (data == nullptr) return false;
595
596 static const Layout kNCDHW("NCDHW");
597 static const Layout kOIDHW("OIDHW");
598
599 const Conv3DTransposeAttrs* param = attrs.as<Conv3DTransposeAttrs>();
600 ICHECK(param != nullptr);
601 const Layout in_layout(param->data_layout);
602 const Layout kernel_layout(param->kernel_layout);
603
604 const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW);
605 ICHECK(trans_in_layout.defined())
606 << "Conv3d_transpose only support input layouts that are convertible from NCDHW."
607 << " But got " << in_layout;
608
609 const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW);
610 ICHECK(trans_kernel_layout.defined())
611 << "Conv3d_transpose only support kernel layouts that are convertible from OIDHW."
612 << " But got " << kernel_layout;
613
614 Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
615 const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW);
616 ICHECK(trans_out_layout.defined())
617 << "Conv3d_transpose only support output layouts that are convertible from NCDHW."
618 << " But got " << out_layout;
619
620 IndexExpr channels, dilated_ksize_d, dilated_ksize_y, dilated_ksize_x;
621
622 auto dshape_ncdhw = trans_in_layout.ForwardShape(data->shape);
623
624 // infer weight if the kernel_size and channels are defined
625 if (param->kernel_size.defined() && param->channels.defined()) {
626 ICHECK_EQ(param->kernel_size.size(), 3);
627 ICHECK_EQ(param->dilation.size(), 3);
628
629 Array<IndexExpr> wshape({dshape_ncdhw[1], indexdiv(param->channels, param->groups),
630 param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]});
631
632 wshape = trans_kernel_layout.BackwardShape(wshape);
633 dilated_ksize_d = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
634 dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
635 dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2];
636 channels = param->channels;
637
638 DataType weight_dtype = data->dtype;
639 if (weight != nullptr) {
640 weight_dtype = weight->dtype;
641 }
642 // assign result to reporter
643 reporter->Assign(types[1], TensorType(wshape, weight_dtype));
644 } else {
645 // use weight to infer the conv shape.
646 if (weight == nullptr) return false;
647 auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
648 if (param->kernel_size.defined()) {
649 ICHECK_EQ(param->kernel_size.size(), 3);
650 // check the size
651 ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
652 reporter->AssertEQ(param->kernel_size[1], wshape[3]) &&
653 reporter->AssertEQ(param->kernel_size[2], wshape[4]))
654 << "Conv3D: shape of weight is inconsistent with kernel_size, "
655 << " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape);
656 }
657 if (param->channels.defined()) {
658 ICHECK(reporter->AssertEQ(param->channels, wshape[1]))
659 << "Conv3D: shape of weight is inconsistent with channels, "
660 << " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
661 }
662 if (!dshape_ncdhw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
663 ICHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[0]));
664 }
665 channels = wshape[1];
666 dilated_ksize_d = 1 + (wshape[2] - 1) * param->dilation[0];
667 dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
668 dilated_ksize_y = 1 + (wshape[4] - 1) * param->dilation[2];
669 }
670
671 // dilation
672 Array<IndexExpr> oshape({dshape_ncdhw[0], channels, 0, 0, 0});
673 IndexExpr pad_d, pad_h, pad_w;
674 GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
675
676 if (!dshape_ncdhw[2].as<tir::AnyNode>()) {
677 oshape.Set(2, (param->strides[0] * (dshape_ncdhw[2] - 1) + dilated_ksize_d - pad_d +
678 param->output_padding[0]));
679 } else {
680 oshape.Set(2, dshape_ncdhw[2]);
681 }
682 if (!dshape_ncdhw[3].as<tir::AnyNode>()) {
683 oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_y - pad_h +
684 param->output_padding[1]));
685 } else {
686 oshape.Set(3, dshape_ncdhw[3]);
687 }
688 if (!dshape_ncdhw[4].as<tir::AnyNode>()) {
689 oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_x - pad_w +
690 param->output_padding[2]));
691 } else {
692 oshape.Set(4, dshape_ncdhw[4]);
693 }
694
695 DataType out_dtype = param->out_dtype;
696 if (out_dtype.bits() == 0) {
697 out_dtype = data->dtype;
698 }
699 oshape = trans_out_layout.BackwardShape(oshape);
700 reporter->Assign(types[2], TensorType(oshape, out_dtype));
701 return true;
702}
703
704TVM_REGISTER_GLOBAL("relay.op.nn._make.conv3d_transpose")
705 .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
706 Array<IndexExpr> dilation, int groups, IndexExpr channels,
707 Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
708 String out_layout, Array<IndexExpr> output_padding, DataType out_dtype) {
709 return MakeConvTranspose<Conv3DTransposeAttrs>(
710 data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout,
711 kernel_layout, out_layout, output_padding, out_dtype, "nn.conv3d_transpose");
712 });
713
// Operator registration for nn.conv3d_transpose: documentation, arguments,
// layout-inference hook, type relation, and fusion pattern.
RELAY_REGISTER_OP("nn.conv3d_transpose")
    .describe(R"code(Transposed 3D convolution layer (sometimes called Deconvolution 3D).

The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
of a normal convolution, i.e., from something that has the shape of the
output of some convolution to something that has the shape of its input
while maintaining a connectivity pattern that is compatible with
said convolution.

- **data**: This depends on the `layout` parameter. Input is 5D array of shape
            (batch_size, in_channels, depth, height, width) if `layout` is `NCDHW`.
- **weight**: (in_channels, channels, kernel_size[0], kernel_size[1], kernel_size[2])
- **bias**: (channels,)
- **out**:  This depends on the `layout` parameter. Output is 5D array of shape
            (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`.

            out_depth and out_height and out_width are calculated as::
                out_depth = (depth-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
                out_height = (height-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1]
                out_width = (width-1)*strides[2]-2*padding[2]+kernel_size[2]+output_padding[2]

)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv3DTransposeAttrs>()
    .set_num_inputs(2)  // data, weight
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                   ConvInferCorrectLayout<Conv3DTransposeAttrs>)
    .add_type_rel("Conv3DTranspose", Conv3DTransposeRel<Conv3DTransposeAttrs>)
    // Output is elementwise-fusable with downstream ops.
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
746
// relay.nn.conv2d_transpose
// Register Conv2DTransposeAttrs with the node system so it can be reflected and serialized.
TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
749
// Type relation for nn.conv2d_transpose.
// types = [data_type, weight_type, out_type]. When both kernel_size and
// channels are present in the attributes, the weight type is inferred from
// them; otherwise the weight shape drives the checks. Finally the output
// shape is computed with the transposed-convolution formula.
bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                        const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* weight = types[1].as<TensorTypeNode>();
  if (data == nullptr) return false;

  static const Layout kNCHW("NCHW");
  // Deliberately non-static: replaced by "GIOHW" below for grouped (DNNL-style) kernels.
  Layout kIOHW("IOHW");

  const Conv2DTransposeAttrs* param = attrs.as<Conv2DTransposeAttrs>();
  ICHECK(param != nullptr);
  const Layout in_layout(param->data_layout);
  const Layout kernel_layout(param->kernel_layout);

  // A "G" axis in the kernel layout marks a DNNL-style grouped kernel (GIOHW).
  bool is_dnnl_group_conv = false;
  if (param->groups > 1 && kernel_layout.name().find("G") != std::string::npos) {
    kIOHW = Layout("GIOHW");
    is_dnnl_group_conv = true;
  }

  // All shape arithmetic below happens in canonical NCHW / (G)IOHW coordinates.
  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
  ICHECK(trans_in_layout.defined())
      << "Conv2DTransposed only support input layouts that are convertible from NCHW."
      << " But got " << in_layout;

  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kIOHW);
  ICHECK(trans_kernel_layout.defined())
      << "Conv2DTransposed only support kernel layouts that are convertible from " << kIOHW << "."
      << " But got " << kernel_layout << " " << kIOHW;

  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
  ICHECK(trans_out_layout.defined())
      << "Conv2DTransposed only support output layouts that are convertible from NCHW."
      << " But got " << out_layout;

  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;

  auto dshape_nchw = trans_in_layout.ForwardShape(data->shape);

  // infer weight if the kernel_size and channels are defined
  if (param->kernel_size.defined() && param->channels.defined()) {
    ICHECK_EQ(param->kernel_size.size(), 2);
    ICHECK_EQ(param->dilation.size(), 2);

    Array<IndexExpr> wshape;
    if (is_dnnl_group_conv) {
      // infer weight's shape for group convolution (GIOHW)
      wshape = {{param->groups, indexdiv(dshape_nchw[1], param->groups),
                 indexdiv(param->channels, param->groups), param->kernel_size[0],
                 param->kernel_size[1]}};
    } else {
      // infer weight's shape for ordinary (IOHW) convolution
      wshape = {{dshape_nchw[1], indexdiv(param->channels, param->groups), param->kernel_size[0],
                 param->kernel_size[1]}};
    }

    // Convert the canonical weight shape back to the user's kernel layout.
    wshape = trans_kernel_layout.BackwardShape(wshape);
    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
    channels = param->channels;

    DataType weight_dtype = data->dtype;
    if (weight != nullptr) {
      weight_dtype = weight->dtype;
    }
    // assign result to reporter
    reporter->Assign(types[1], TensorType(wshape, weight_dtype));
  } else {
    // use weight to infer the conv shape.
    if (weight == nullptr) return false;
    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
    if (param->kernel_size.defined()) {
      ICHECK_EQ(param->kernel_size.size(), 2);
      // check the size
      ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
             reporter->AssertEQ(param->kernel_size[1], wshape[3]))
          << "Conv2DTransposed: shape of weight is inconsistent with kernel_size, "
          << " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape);
    }
    if (param->channels.defined()) {
      ICHECK(reporter->AssertEQ(indexdiv(param->channels, param->groups), wshape[1]))
          << "Conv2DTransposed: shape of weight is inconsistent with out_channels, "
          << " out_channels // groups != weight.shape[1] "
          << " out_channels=" << param->channels << " groups=" << param->groups
          << " weight.shape=" << Array<IndexExpr>(wshape);
    }
    // Only compare in_channels when both extents are static.
    if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
      ICHECK(reporter->AssertEQ(dshape_nchw[1], wshape[0]))
          << "Conv2DTransposed: shape of weight is inconsistent with in_channels."
          << " data.shape= " << Array<IndexExpr>(dshape_nchw) << " groups= " << param->groups
          << " weight.shape= " << Array<IndexExpr>(wshape);
    }
    channels = wshape[1];
    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
  }
  // Transposed-conv output size per spatial axis:
  //   out = strides * (in - 1) + dilated_kernel - padding + output_padding
  // (pad_h/pad_w from GetPaddingHeightWidth already cover both sides, matching
  // the "2*padding" in the op docstring). Dynamic (Any) extents pass through.
  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
  IndexExpr pad_h, pad_w;
  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
  if (!dshape_nchw[2].as<tir::AnyNode>()) {
    oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - pad_h +
                   param->output_padding[0]));
  } else {
    oshape.Set(2, dshape_nchw[2]);
  }
  if (!dshape_nchw[3].as<tir::AnyNode>()) {
    oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - pad_w +
                   param->output_padding[1]));
  } else {
    oshape.Set(3, dshape_nchw[3]);
  }

  // Default the output dtype to the input dtype when unspecified.
  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = data->dtype;
  }
  oshape = trans_out_layout.BackwardShape(oshape);
  reporter->Assign(types[2], TensorType(oshape, out_dtype));
  return true;
}
873
// FFI hook: builds a nn.conv2d_transpose call expression from Python-side arguments.
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_transpose")
    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
                       String out_layout, Array<IndexExpr> output_padding, DataType out_dtype) {
      return MakeConvTranspose<Conv2DTransposeAttrs>(
          data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout,
          kernel_layout, out_layout, output_padding, out_dtype, "nn.conv2d_transpose");
    });
883
884RELAY_REGISTER_OP("nn.conv2d_transpose")
885 .describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).
886
887The need for transposed convolutions generally arises
888from the desire to use a transformation going in the opposite direction
889of a normal convolution, i.e., from something that has the shape of the
890output of some convolution to something that has the shape of its input
891while maintaining a connectivity pattern that is compatible with
892said convolution.
893
894- **data**: This depends on the `layout` parameter. Input is 4D array of shape
895 (batch_size, in_channels, height, width) if `layout` is `NCHW`.
896- **weight**: (in_channels, channels, kernel_size[0], kernel_size[1])
897- **bias**: (channels,)
898- **out**: This depends on the `layout` parameter. Output is 4D array of shape
899v (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
900
901 out_height and out_width are calculated as::
902 out_height = (height-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
903 out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1]
904
905)code" TVM_ADD_FILELINE)
906 .set_attrs_type<Conv2DTransposeAttrs>()
907 .set_num_inputs(2)
908 .add_argument("data", "Tensor", "The input tensor.")
909 .add_argument("weight", "Tensor", "The weight tensor.")
910 .set_support_level(2)
911 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
912 ConvInferCorrectLayout<Conv2DTransposeAttrs>)
913 .add_type_rel("Conv2DTranspose", Conv2DTransposeRel)
914 .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
915
// relay.nn.conv1d_transpose
// Register Conv1DTransposeAttrs with the node system (enables reflection / FFI access).
TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs);
918
919bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
920 const TypeReporter& reporter) {
921 ICHECK_EQ(types.size(), 3);
922 const auto* data = types[0].as<TensorTypeNode>();
923 const auto* weight = types[1].as<TensorTypeNode>();
924 if (data == nullptr) return false;
925
926 static const Layout kNCW("NCW");
927 static const Layout kOIW("OIW");
928
929 const Conv1DTransposeAttrs* param = attrs.as<Conv1DTransposeAttrs>();
930 ICHECK(param != nullptr);
931 const Layout in_layout(param->data_layout);
932 const Layout kernel_layout(param->kernel_layout);
933
934 const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW);
935 ICHECK(trans_in_layout.defined())
936 << "Conv only support input layouts that are convertible from NCW."
937 << " But got " << in_layout;
938
939 const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW);
940 ICHECK(trans_kernel_layout.defined())
941 << "Conv only support kernel layouts that are convertible from OIW."
942 << " But got " << kernel_layout;
943
944 Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
945 const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW);
946 ICHECK(trans_out_layout.defined())
947 << "Conv only support output layouts that are convertible from NCW."
948 << " But got " << out_layout;
949
950 IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
951
952 auto dshape_ncw = trans_in_layout.ForwardShape(data->shape);
953
954 // infer weight if the kernel_size and channels are defined
955 if (param->kernel_size.defined() && param->channels.defined()) {
956 ICHECK_EQ(param->kernel_size.size(), 1);
957 ICHECK_EQ(param->dilation.size(), 1);
958
959 Array<IndexExpr> wshape(
960 {dshape_ncw[1], indexdiv(param->channels, param->groups), param->kernel_size[0]});
961
962 wshape = trans_kernel_layout.BackwardShape(wshape);
963 dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
964 channels = param->channels;
965
966 DataType weight_dtype = data->dtype;
967 if (weight != nullptr) {
968 weight_dtype = weight->dtype;
969 }
970 // assign result to reporter
971 reporter->Assign(types[1], TensorType(wshape, weight_dtype));
972 } else {
973 // use weight to infer the conv shape.
974 if (weight == nullptr) return false;
975 auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
976 if (param->kernel_size.defined()) {
977 ICHECK_EQ(param->kernel_size.size(), 1);
978 // check the size
979 ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
980 << "Conv1D: shape of weight is inconsistent with kernel_size, "
981 << " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape);
982 }
983 if (param->channels.defined()) {
984 ICHECK(reporter->AssertEQ(param->channels, wshape[1]))
985 << "Conv1D: shape of weight is inconsistent with channels, "
986 << " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape);
987 }
988 if (!dshape_ncw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
989 ICHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0]));
990 }
991 channels = wshape[1];
992 dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0];
993 }
994 // dilation
995 IndexExpr pad_w;
996 GetPaddingWidth(param->padding, &pad_w);
997 Array<IndexExpr> oshape({dshape_ncw[0], channels, 0});
998 if (!dshape_ncw[2].as<tir::AnyNode>()) {
999 oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - pad_w +
1000 param->output_padding[0]));
1001 } else {
1002 oshape.Set(2, dshape_ncw[2]);
1003 }
1004
1005 DataType out_dtype = param->out_dtype;
1006 if (out_dtype.bits() == 0) {
1007 out_dtype = data->dtype;
1008 }
1009 oshape = trans_out_layout.BackwardShape(oshape);
1010 reporter->Assign(types[2], TensorType(oshape, out_dtype));
1011 return true;
1012}
1013
// FFI hook: builds a nn.conv1d_transpose call expression from Python-side arguments.
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d_transpose")
    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
                       String out_layout, Array<IndexExpr> output_padding, DataType out_dtype) {
      return MakeConvTranspose<Conv1DTransposeAttrs>(
          data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout,
          kernel_layout, out_layout, output_padding, out_dtype, "nn.conv1d_transpose");
    });
1023
// Operator registration for nn.conv1d_transpose: documentation, arguments,
// support level, and the Conv1DTransposeRel type relation.
RELAY_REGISTER_OP("nn.conv1d_transpose")
    .describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution).

The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
of a normal convolution, i.e., from something that has the shape of the
output of some convolution to something that has the shape of its input
while maintaining a connectivity pattern that is compatible with
said convolution.

- **data**: This depends on the `layout` parameter. Input is 3D array of shape
            (batch_size, in_channels, width) if `layout` is `NCW`.
- **weight**: (in_channels, channels, kernel_size[0])
- **bias**: (channels,)
- **out**: This depends on the `layout` parameter. Output is 3D array of shape
           (batch_size, channels, out_width) if `layout` is `NCW`.

            out_width is calculated as::
                out_width = (width-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]

)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv1DTransposeAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .add_type_rel("Conv1DTranspose", Conv1DTransposeRel)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1052
// relay.nn.contrib_conv2d_winograd_without_weight_transform
// Register Conv2DWinogradAttrs with the node system (enables reflection / FFI access).
TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs);
1055
// FFI hook: builds a nn.contrib_conv2d_winograd_without_weight_transform call.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform")
    .set_body_typed([](Expr data, Expr weight, int tile_size, Array<IndexExpr> strides,
                       Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
                       IndexExpr channels, Array<IndexExpr> kernel_size, String data_layout,
                       String kernel_layout, String out_layout, DataType out_dtype) {
      return MakeConvWinograd<Conv2DWinogradAttrs>(
          data, weight, tile_size, strides, padding, dilation, groups, channels, kernel_size,
          data_layout, kernel_layout, out_layout, out_dtype,
          "nn.contrib_conv2d_winograd_without_weight_transform");
    });
1066
// Operator registration: type checking via Conv2DWinogradRel; the weight shape
// is intentionally unchecked because backends use different packed layouts.
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
    .describe(R"code(Compute conv2d with winograd algorithm. Only supports NCHW layout.
                 This operator assumes the weight tensor is already pre-transformed by
                 nn.contrib_conv2d_winograd_weight_transform.

- **data**: Input is 4D array of shape (batch_size, in_channels, height, width)
- **weight**: Any shape
            We do not check the shape for this input tensor. Since different backend
            has different layout strategy.

- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DWinogradAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(10)
    .add_type_rel("Conv2DWinograd", Conv2DWinogradRel<Conv2DWinogradAttrs>)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                   ConvInferCorrectLayout<Conv2DWinogradAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1088
// relay.nn.contrib_conv2d_winograd_weight_transform
// Register ConvWinogradWeightTransformAttrs with the node system (shared by 2D/3D transforms).
TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs);
1091
1092bool Conv2DWinogradWeightTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1093 const TypeReporter& reporter) {
1094 ICHECK_EQ(types.size(), 2);
1095 const auto* data = types[0].as<TensorTypeNode>();
1096 if (data == nullptr) return false;
1097
1098 const ConvWinogradWeightTransformAttrs* param = attrs.as<ConvWinogradWeightTransformAttrs>();
1099 ICHECK(param != nullptr);
1100
1101 ICHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
1102
1103 std::vector<IndexExpr> oshape{
1104 param->tile_size + data->shape[2] - 1,
1105 param->tile_size + data->shape[3] - 1,
1106 data->shape[0],
1107 data->shape[1],
1108 };
1109
1110 reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), data->dtype));
1111 return true;
1112}
1113
// FFI hook: builds a nn.contrib_conv2d_winograd_weight_transform call.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
    .set_body_typed([](Expr weight, int tile_size) {
      return MakeConvWinogradWeightTransform(weight, tile_size,
                                             "nn.contrib_conv2d_winograd_weight_transform");
    });
1119
// Operator registration: shape/type checking is done by Conv2DWinogradWeightTransformRel.
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
    .describe(R"code(Weight transformation of winograd fast convolution algorithm.

Separate this into another operator in order to enable Precompute Pass to compute the
weight transformation in advance.

- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
)code" TVM_ADD_FILELINE)
    .set_attrs_type<ConvWinogradWeightTransformAttrs>()
    .set_num_inputs(1)
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(10)
    .add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1134
// relay.nn.contrib_conv3d_winograd_without_weight_transform
// Register Conv3DWinogradAttrs with the node system (enables reflection / FFI access).
TVM_REGISTER_NODE_TYPE(Conv3DWinogradAttrs);
1137
1138bool Conv3DWinogradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1139 const TypeReporter& reporter) {
1140 ICHECK_EQ(types.size(), 3);
1141 const auto* data = types[0].as<TensorTypeNode>();
1142 if (data == nullptr) return false;
1143 static const Layout kNCDHW("NCDHW");
1144 static const Layout kOIDHW("OIDHW");
1145
1146 const auto* param = attrs.as<Conv3DWinogradAttrs>();
1147 ICHECK(param != nullptr);
1148 const Layout in_layout(param->data_layout);
1149 const Layout kernel_layout(param->kernel_layout);
1150
1151 const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW);
1152 ICHECK(trans_in_layout.defined())
1153 << "Conv only support input layouts that are convertible from NCDHW."
1154 << " But got " << in_layout;
1155
1156 const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW);
1157 ICHECK(trans_kernel_layout.defined())
1158 << "Conv only support kernel layouts that are convertible from OIDHW."
1159 << " But got " << kernel_layout;
1160
1161 Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
1162 const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW);
1163 ICHECK(trans_out_layout.defined())
1164 << "Conv only support output layouts that are convertible from NCDHW."
1165 << " But got " << out_layout;
1166
1167 Array<IndexExpr> dshape_ncdhw = trans_in_layout.ForwardShape(data->shape);
1168
1169 IndexExpr channels, dilated_ksize_d, dilated_ksize_y, dilated_ksize_x;
1170
1171 ICHECK(param->kernel_size.defined() && param->channels.defined())
1172 << "The kernel size and channels of a Conv must be set or inferred by previous pass";
1173
1174 ICHECK_EQ(param->kernel_size.size(), 3);
1175 ICHECK_EQ(param->dilation.size(), 3);
1176
1177 channels = param->channels;
1178 dilated_ksize_d = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
1179 dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
1180 dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2];
1181
1182 // NOTE: Do not check weight shape here!
1183 // Different backend requires different layout to compute
1184 // the batch gemm stage in winograd efficiently, but we want to
1185 // make this op work for all backends.
1186 // So we accept all weight shapes, and assume the TOPI developers
1187 // can handle this correctly in alter_op_layout.
1188
1189 // dilation
1190 Array<IndexExpr> oshape({dshape_ncdhw[0], channels, 0, 0, 0});
1191
1192 IndexExpr pad_d, pad_h, pad_w;
1193 GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
1194 if (!dshape_ncdhw[2].as<tir::AnyNode>()) {
1195 oshape.Set(2, (dshape_ncdhw[2] + pad_d - dilated_ksize_d) / param->strides[0] + 1);
1196 } else {
1197 oshape.Set(2, dshape_ncdhw[2]);
1198 }
1199 if (!dshape_ncdhw[2].as<tir::AnyNode>()) {
1200 oshape.Set(3, (dshape_ncdhw[3] + pad_h - dilated_ksize_y) / param->strides[1] + 1);
1201 } else {
1202 oshape.Set(3, dshape_ncdhw[3]);
1203 }
1204 if (!dshape_ncdhw[4].as<tir::AnyNode>()) {
1205 oshape.Set(4, (dshape_ncdhw[4] + pad_w - dilated_ksize_x) / param->strides[2] + 1);
1206 } else {
1207 oshape.Set(4, dshape_ncdhw[4]);
1208 }
1209
1210 DataType out_dtype = param->out_dtype;
1211 if (out_dtype.bits() == 0) {
1212 out_dtype = data->dtype;
1213 }
1214 oshape = trans_out_layout.BackwardShape(oshape);
1215 // assign output type
1216 reporter->Assign(types[2], TensorType(oshape, out_dtype));
1217 return true;
1218}
1219
// FFI hook: builds a nn.contrib_conv3d_winograd_without_weight_transform call.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_without_weight_transform")
    .set_body_typed([](Expr data, Expr weight, int tile_size, Array<IndexExpr> strides,
                       Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
                       IndexExpr channels, Array<IndexExpr> kernel_size, String data_layout,
                       String kernel_layout, String out_layout, DataType out_dtype) {
      return MakeConvWinograd<Conv3DWinogradAttrs>(
          data, weight, tile_size, strides, padding, dilation, groups, channels, kernel_size,
          data_layout, kernel_layout, out_layout, out_dtype,
          "nn.contrib_conv3d_winograd_without_weight_transform");
    });
1230
1231RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform")
1232 .describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout.
1233 This operator assumes the weight tensor is already pre-transformed by
1234 nn.contrib_conv3d_winograd_weight_transform.
1235
1236- **data**: Input is 5D array of shape (batch_size, in_channels, depth, height, width)
1237- **weight**: Any shape
1238 We do not check the shape for this input tensor. Since different backend
1239 has different layout strategy.
1240
1241- **out**: Output is 5D array of shape (batch_size, channels, depth, out_height, out_width)
1242)code" TVM_ADD_FILELINE)
1243 .set_attrs_type<Conv3DWinogradAttrs>()
1244 .set_num_inputs(2)
1245 .add_argument("data", "Tensor", "The input tensor.")
1246 .add_argument("weight", "Tensor", "The weight tensor.")
1247 .set_support_level(10)
1248 .add_type_rel("Conv3DWinograd", Conv3DWinogradRel)
1249 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
1250 ConvInferCorrectLayout<Conv3DWinogradAttrs>)
1251 .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1252
// relay.nn.contrib_conv3d_winograd_weight_transform
// FFI hook: builds a nn.contrib_conv3d_winograd_weight_transform call.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_weight_transform")
    .set_body_typed([](Expr weight, int tile_size) {
      return MakeConvWinogradWeightTransform(weight, tile_size,
                                             "nn.contrib_conv3d_winograd_weight_transform");
    });
1259
1260bool Conv3DWinogradWeightTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1261 const TypeReporter& reporter) {
1262 ICHECK_EQ(types.size(), 2);
1263 const auto* data = types[0].as<TensorTypeNode>();
1264 if (data == nullptr) return false;
1265
1266 const ConvWinogradWeightTransformAttrs* param = attrs.as<ConvWinogradWeightTransformAttrs>();
1267 ICHECK(param != nullptr);
1268
1269 ICHECK_EQ(data->shape.size(), 5) << "Only support NCDHW normal kernel layout";
1270
1271 // Shape of packed weights depends on whether depth is being transformed or not.
1272 Array<IndexExpr> oshape({0, 0, 0, data->shape[0], data->shape[1]});
1273 auto* depth_imm = data->shape[2].as<IntImmNode>();
1274 bool transform_depth = (depth_imm->value > 2) && (depth_imm->value < 8);
1275 if (transform_depth) {
1276 oshape.Set(0, param->tile_size + data->shape[2] - 1);
1277 oshape.Set(1, param->tile_size + data->shape[3] - 1);
1278 oshape.Set(2, param->tile_size + data->shape[4] - 1);
1279 } else {
1280 oshape.Set(0, param->tile_size + data->shape[3] - 1);
1281 oshape.Set(1, param->tile_size + data->shape[4] - 1);
1282 oshape.Set(2, data->shape[2]);
1283 }
1284
1285 reporter->Assign(types[1], TensorType(oshape, data->dtype));
1286 return true;
1287}
1288
// Operator registration: shape/type checking is done by Conv3DWinogradWeightTransformRel.
RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_weight_transform")
    .describe(R"code(Weight transformation of winograd fast 3d convolution algorithm.

Separate this into another operator in order to enable Precompute Pass to compute the
weight transformation in advance.

- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2])
)code" TVM_ADD_FILELINE)
    .set_attrs_type<ConvWinogradWeightTransformAttrs>()
    .set_num_inputs(1)
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(10)
    .add_type_rel("Conv3DWinogradWeightTransform", Conv3DWinogradWeightTransformRel)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1303
// relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
// Register Conv2DWinogradNNPACKWeightTransformAttrs with the node system.
TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs);
1306
1307bool Conv2DWinogradNNPACKWeightTransformRel(const Array<Type>& types, int num_inputs,
1308 const Attrs& attrs, const TypeReporter& reporter) {
1309 ICHECK_EQ(types.size(), 2);
1310 const auto* data = types[0].as<TensorTypeNode>();
1311 if (data == nullptr) {
1312 return false;
1313 }
1314
1315 const Conv2DWinogradNNPACKWeightTransformAttrs* param =
1316 attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>();
1317 ICHECK(param != nullptr);
1318
1319 ICHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
1320
1321 std::vector<IndexExpr> oshape{
1322 data->shape[0],
1323 data->shape[1],
1324 8,
1325 8,
1326 };
1327
1328 DataType out_dtype = param->out_dtype;
1329 if (out_dtype.bits() == 0) {
1330 out_dtype = data->dtype;
1331 }
1332 reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), out_dtype));
1333 return true;
1334}
1335
1336Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, int convolution_algorithm,
1337 DataType out_dtype) {
1338 auto attrs = make_object<Conv2DWinogradNNPACKWeightTransformAttrs>();
1339 attrs->convolution_algorithm = convolution_algorithm;
1340 attrs->out_dtype = std::move(out_dtype);
1341 static const Op& op = Op::Get("nn.contrib_conv2d_winograd_nnpack_weight_transform");
1342 return Call(op, {weight}, Attrs(attrs), {});
1343}
1344
// FFI hook: exposes MakeConv2DWinogradNNPACKWeightTransform to the Python API.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform")
    .set_body_typed(MakeConv2DWinogradNNPACKWeightTransform);
1347
// Operator registration: marked kOpaque (unlike the other weight transforms)
// so it is never fused — the transform is performed by the NNPACK library.
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform")
    .describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK.
Separate this into another symbol in order to enable Precompute Pass to compute the
weight transformation in advance.

- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])

)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DWinogradNNPACKWeightTransformAttrs>()
    .set_num_inputs(1)
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(10)
    .add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel)
    .set_attr<TOpPattern>("TOpPattern", kOpaque);
1362
// relay.nn.contrib_conv2d_gemm_without_weight_transform
// FFI hook: builds a nn.contrib_conv2d_gemm_without_weight_transform call.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_without_weight_transform")
    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, tvm::String data_layout,
                       tvm::String kernel_layout, tvm::String out_layout, DataType out_dtype) {
      return MakeConvGemm<Conv2DAttrs>(
          data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout,
          kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_gemm_without_weight_transform");
    });
1373
1374bool Conv2DGemmRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1375 const TypeReporter& reporter) {
1376 ICHECK_EQ(types.size(), 3);
1377 const auto* data = types[0].as<TensorTypeNode>();
1378 if (data == nullptr) return false;
1379 static const Layout kNHWC("NHWC");
1380 static const Layout kHWIO("HWIO");
1381
1382 const auto* param = attrs.as<Conv2DAttrs>();
1383 ICHECK(param != nullptr);
1384 const Layout in_layout(param->data_layout);
1385 const Layout kernel_layout(param->kernel_layout);
1386
1387 const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNHWC);
1388 ICHECK(trans_in_layout.defined())
1389 << "Conv only support input layouts that are convertible from NHWC."
1390 << " But got " << in_layout;
1391
1392 const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kHWIO);
1393 ICHECK(trans_kernel_layout.defined())
1394 << "Conv only support kernel layouts that are convertible from HWIO."
1395 << " But got " << kernel_layout;
1396
1397 Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
1398 const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNHWC);
1399 ICHECK(trans_out_layout.defined())
1400 << "Conv only support output layouts that are convertible from NHWC."
1401 << " But got " << out_layout;
1402
1403 Array<IndexExpr> dshape_nhwc = trans_in_layout.ForwardShape(data->shape);
1404
1405 IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
1406
1407 ICHECK(param->kernel_size.defined() && param->channels.defined())
1408 << "The kernel size and channels of a Conv must be set or inferred by previous pass";
1409
1410 ICHECK_EQ(param->kernel_size.size(), 2);
1411 ICHECK_EQ(param->dilation.size(), 2);
1412
1413 channels = param->channels;
1414 dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
1415 dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
1416
1417 // NOTE: Do not check weight shape here!
1418
1419 // dilation
1420 Array<IndexExpr> oshape({dshape_nhwc[0], 0, 0, channels});
1421
1422 IndexExpr pad_h, pad_w;
1423 GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
1424 if (!dshape_nhwc[2].as<tir::AnyNode>()) {
1425 oshape.Set(1, (dshape_nhwc[1] + pad_h - dilated_ksize_y) / param->strides[0] + 1);
1426 } else {
1427 oshape.Set(1, dshape_nhwc[1]);
1428 }
1429 if (!dshape_nhwc[3].as<tir::AnyNode>()) {
1430 oshape.Set(2, (dshape_nhwc[2] + pad_w - dilated_ksize_x) / param->strides[1] + 1);
1431 } else {
1432 oshape.Set(2, dshape_nhwc[2]);
1433 }
1434
1435 DataType out_dtype = param->out_dtype;
1436 if (out_dtype.bits() == 0) {
1437 out_dtype = data->dtype;
1438 }
1439 oshape = trans_out_layout.BackwardShape(oshape);
1440 // assign output type
1441 reporter->Assign(types[2], TensorType(oshape, out_dtype));
1442 return true;
1443}
1444
1445RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_without_weight_transform")
1446 .describe(R"code(Compute conv2d with gemm algorithm. Only supports NHWC layout.
1447 This operator assumes the weight tensor is already pre-transformed by
1448 nn.contrib_conv2d_gemm_weight_transform.
1449
1450- **data**: Input is 4D array of shape (batch_size, height, width, in_channels)
1451- **weight**: Any shape
1452 We do not check the shape for this input tensor. Since different backend
1453 has different layout strategy.
1454
1455- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
1456)code" TVM_ADD_FILELINE)
1457 .set_attrs_type<Conv2DAttrs>()
1458 .set_num_inputs(2)
1459 .add_argument("data", "Tensor", "The input tensor.")
1460 .add_argument("weight", "Tensor", "The weight tensor.")
1461 .set_support_level(10)
1462 .add_type_rel("Conv2DGemm", Conv2DGemmRel)
1463 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>)
1464 .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1465
// relay.nn.contrib_conv2d_gemm_weight_transform

// Register ConvGemmWeightTransformAttrs with the node system (enables reflection / FFI access).
TVM_REGISTER_NODE_TYPE(ConvGemmWeightTransformAttrs);
1469
1470// Gemm convolution shape relations
1471// In order to run GEMM we need to block-transpose and interleave the K x N weights matrix W.
1472// The high level idea is to subdivide W in tiles of tile_cols x tile_rows, and transpose and
1473// interleave them. The final output is a [N//tile_rows, K//tile_cols, tile_rows, tile_cols]
1474// matrix that we call W_interleaved_t.
1475//
1476// In the following picture, we show how the first [tile_cols,tile_rows] block of W is transformed
1477// for tile_rows = 4 and tile_cols = 16
1478//
1479// W[0,0,:,:] W_interleaved_t[0,0,:,:]
1480// +-------------------------------+ +----------------------------------- +
1481// |W[0,0] W[0,1] W[0,2] W[0,3] | |W[0,0] W[1,0] W[2,0] ... W[15,0]|
1482// |W[1,0] W[1,1] W[1,2] W[1,3] | --\ |W[0,1] W[1,1] W[2,1] ... W[15,1]|
1483// |W[2,0] W[2,1] W[2,2] W[2,3] | --/ |W[0,2] W[1,2] W[2,2] ... W[15,2]|
1484// | ... ... ... ... | |W[0,3] W[1,3] W[2,3] ... W[15,3]|
1485// | ... ... ... ... | +------------------------------------+
1486// |W[15,0] W[15,1] W[15,2] W[15,3]|
1487// +-------------------------------+
1488//
1489// Tile columns is usually the direction of the reduction. So, if our target can reduce k elements
1490// at the time, we should set tile_cols = k.
1491// Tile rows is connected with the number of registers available for the given target.
1492//
1493bool Conv2DGemmWeightTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1494 const TypeReporter& reporter) {
1495 ICHECK_EQ(types.size(), 2);
1496 const auto* weight = types[0].as<TensorTypeNode>();
1497 if (weight == nullptr) return false;
1498
1499 const ConvGemmWeightTransformAttrs* param = attrs.as<ConvGemmWeightTransformAttrs>();
1500 ICHECK(param != nullptr);
1501 int n = param->tile_rows;
1502 int k = param->tile_cols;
1503
1504 ICHECK_EQ(weight->shape.size(), 4) << "Only support HWIO kernel layout";
1505
1506 const auto K = weight->shape[0] * weight->shape[1] * weight->shape[2];
1507 const auto N = weight->shape[3];
1508
1509 auto K_mod_k = indexmod(K, k);
1510 auto N_mod_n = indexmod(N, n);
1511
1512 auto pad_K = tvm::if_then_else(K_mod_k != 0, k - K_mod_k, tir::make_zero(DataType::Int(32)));
1513 auto pad_N = tvm::if_then_else(N_mod_n != 0, n - N_mod_n, tir::make_zero(DataType::Int(32)));
1514
1515 const auto N_padded = N + pad_N;
1516 const auto K_padded = K + pad_K;
1517
1518 Array<IndexExpr> oshape{
1519 indexdiv(N_padded, n),
1520 indexdiv(K_padded, k),
1521 n,
1522 k,
1523 };
1524
1525 reporter->Assign(types[1], TensorType(oshape, weight->dtype));
1526 return true;
1527}
1528
// FFI entry point used by the Python frontend to construct
// nn.contrib_conv2d_gemm_weight_transform calls.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_weight_transform")
    .set_body_typed([](Expr weights, int tile_rows, int tile_cols) {
      return MakeConvGemmWeightTransform(weights, tile_rows, tile_cols,
                                         "nn.contrib_conv2d_gemm_weight_transform");
    });
1534
// Registration of the GEMM weight-transform op. Kept as a separate operator so
// the transformation can be hoisted out and precomputed (e.g. by FoldConstant)
// instead of being redone on every conv2d invocation.
RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_weight_transform")
    .describe(R"code(Weight transformation of GEMM convolution algorithm.

Separate this into another operator in order to enable Precompute Pass to compute the
weight transformation in advance.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ConvGemmWeightTransformAttrs>()
    .set_num_inputs(1)
    .add_argument("weights", "Tensor", "The weights tensor.")
    .set_support_level(10)
    .add_type_rel("Conv2DGemmWeightTransform", Conv2DGemmWeightTransformRel)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1548
// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc")
    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
                       String out_layout, DataType out_dtype) {
      // Delegates to the generic conv maker with the NCHWc op name; attrs are
      // plain Conv2DAttrs, the packed layout lives in the layout strings.
      return MakeConv<Conv2DAttrs>(data, weight, strides, padding, dilation, groups, channels,
                                   kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
                                   "nn.contrib_conv2d_NCHWc");
    });
1560
// NOTE(review): the type relation registered here is Conv2DWinogradRel, not a
// dedicated NCHWc relation — presumably because the weight is a packed 6D
// tensor whose shape cannot be validated by the ordinary Conv2DRel, and the
// Winograd relation relies on the kernel_size/channels attributes instead.
// Confirm against Conv2DWinogradRel's definition earlier in this file.
RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
    .describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout.
- **data**: Input is 5D packed tensor.
- **weight**: 6D packed tensor.

- **out**: Output is 5D packed tensor
)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(10)
    .add_type_rel("Conv2DNCHWc", Conv2DWinogradRel<Conv2DAttrs>)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1576
// Positional relay function to create depthwise conv2d NCHWc operator
// used by frontend FFI.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc")
    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
                       String out_layout, DataType out_dtype) {
      // Same maker as the non-depthwise variant; only the op name differs.
      return MakeConv<Conv2DAttrs>(data, weight, strides, padding, dilation, groups, channels,
                                   kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
                                   "nn.contrib_depthwise_conv2d_NCHWc");
    });
1588
// Registration of the depthwise NCHWc conv2d. Unlike the non-depthwise NCHWc
// op above, this one registers the plain Conv2DRel as its type relation.
RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
    .describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout.
- **data**: Input is 5D packed tensor.
- **weight**: 6D packed tensor.

- **out**: Output is 5D packed tensor
)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(10)
    .add_type_rel("Conv2D", Conv2DRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1604
1605TVM_REGISTER_NODE_TYPE(DeformableConv2DAttrs);
1606
1607// Deformable Convolution shape relations.
1608bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1609 const TypeReporter& reporter) {
1610 ICHECK_EQ(types.size(), 4);
1611 const auto* data = types[0].as<TensorTypeNode>();
1612 const auto* weight = types[2].as<TensorTypeNode>();
1613
1614 ICHECK(data);
1615 static const Layout kNCHW("NCHW");
1616 static const Layout kOIHW("OIHW");
1617
1618 auto* param = attrs.as<DeformableConv2DAttrs>();
1619 ICHECK(param != nullptr);
1620 const Layout in_layout(param->data_layout);
1621 const Layout kernel_layout(param->kernel_layout);
1622
1623 const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
1624 if (!trans_in_layout.defined()) {
1625 reporter->GetDiagCtx().Emit(
1626 Diagnostic::Error(reporter->GetSpan())
1627 << "deformable_conv2d only support input layouts that are convertible from NCHW."
1628 << " The provided layout is: " << in_layout);
1629 return false;
1630 }
1631
1632 const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
1633 if (!trans_kernel_layout.defined()) {
1634 reporter->GetDiagCtx().Emit(
1635 Diagnostic::Error(reporter->GetSpan())
1636 << "deformable_conv2d only support kernel layouts that are convertible from OIHW."
1637 << " The provided layout is: " << kernel_layout);
1638 return false;
1639 }
1640
1641 Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
1642 const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
1643 if (!trans_out_layout.defined()) {
1644 reporter->GetDiagCtx().Emit(
1645 Diagnostic::Error(reporter->GetSpan())
1646 << "deformable_conv2d only support output layouts that are convertible from NCHW."
1647 << "The provided layout is: " << out_layout);
1648 return false;
1649 }
1650
1651 Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
1652
1653 IndexExpr channels, dilated_ksize_y, dilated_ksize_x, ksize_y, ksize_x;
1654
1655 // infer weight shape if kernel_size and channels are defiend
1656 if (param->kernel_size.defined() && param->channels.defined()) {
1657 ICHECK_EQ(param->kernel_size.size(), 2);
1658 ICHECK_EQ(param->dilation.size(), 2);
1659 Array<IndexExpr> wshape({param->channels, indexdiv(dshape_nchw[1], param->groups),
1660 param->kernel_size[0], param->kernel_size[1]});
1661
1662 wshape = trans_kernel_layout.BackwardShape(wshape);
1663 channels = param->channels;
1664 ksize_y = param->kernel_size[0];
1665 ksize_x = param->kernel_size[1];
1666 dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
1667 dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
1668 // assign result to reporter
1669 reporter->Assign(types[2], TensorType(wshape, data->dtype));
1670 } else {
1671 // use weight to infer the conv shape.
1672 if (weight == nullptr) return false;
1673 auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
1674
1675 if (param->kernel_size.defined()) {
1676 ICHECK_EQ(param->kernel_size.size(), 2);
1677 // check the size
1678 ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
1679 reporter->AssertEQ(param->kernel_size[1], wshape[3]))
1680 << "DeformableConv2D: shape of weight is inconsistent with kernel_size, "
1681 << " kernel_size=" << param->kernel_size << " wshape=" << wshape;
1682 }
1683 if (param->channels.defined()) {
1684 ICHECK(reporter->AssertEQ(param->channels, wshape[0]))
1685 << "DeformableConv2D: shape of weight is inconsistent with channels, "
1686 << " channels=" << param->channels << " wshape=" << wshape;
1687 }
1688 if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
1689 ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1]));
1690 }
1691 channels = wshape[0];
1692 ksize_y = wshape[2];
1693 ksize_x = wshape[3];
1694 dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
1695 dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
1696 }
1697 // dilation
1698 Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
1699
1700 IndexExpr pad_h, pad_w;
1701 GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
1702 oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1);
1703 oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1);
1704 DataType out_dtype = param->out_dtype;
1705
1706 // infer offset shape
1707 Array<IndexExpr> offset_shape(
1708 {dshape_nchw[0], 2 * ksize_y * ksize_x * param->deformable_groups, oshape[2], oshape[3]});
1709 offset_shape = trans_in_layout.BackwardShape(offset_shape);
1710 reporter->Assign(types[1], TensorType(offset_shape, data->dtype));
1711 if (out_dtype.bits() == 0) {
1712 out_dtype = data->dtype;
1713 }
1714
1715 oshape = trans_out_layout.BackwardShape(oshape);
1716 reporter->Assign(types[3], TensorType(oshape, out_dtype));
1717 return true;
1718}
1719
1720InferCorrectLayoutOutput DeformableConvInferCorrectLayout(
1721 const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
1722 const Array<tvm::relay::Type>& old_in_types) {
1723 const auto* params = attrs.as<DeformableConv2DAttrs>();
1724 return InferCorrectLayoutOutput(
1725 {params->data_layout, params->data_layout, params->kernel_layout},
1726 {params->out_layout == "" ? params->data_layout : params->out_layout}, attrs);
1727}
1728
// relay.nn.deformable_conv2d
RELAY_REGISTER_OP("nn.deformable_conv2d")
    .describe(R"code(Compute 2-D deformable convolution on 4-D input.
The deformable convolution operation is described in https://arxiv.org/abs/1703.06211

For 2-D deformable convolution, the shapes are
- **data**: (batch_size, channel, height, width)
- **offset**: (batch_size, deformable_groups * kernel[0] * kernel[1] * 2, out_height, out_width)
- **weight**: (num_filter, channel, kernel[0], kernel[1])
- **out**: (batch_size, num_filter, out_height, out_width).

If `deformable_groups` is larger than 1, denoted by *dg*, then split the
input `offset` evenly into *dg* parts along the channel axis, and also evenly split `out`
evenly into *dg* parts along the channel axis. Next compute the deformable convolution, apply the
*i*-th part of the offset part on the *i*-th out.

If `groups` is larger than 1, denoted by *g*, then split the input `data` evenly into *g* parts
along the channel axis, and also evenly split `weight` along the first dimension. Next compute
the convolution on the *i*-th part of the data with the *i*-th weight part. The output is obtained
by concating all the *g* results.
)code" TVM_ADD_FILELINE)
    .set_attrs_type<DeformableConv2DAttrs>()
    .set_num_inputs(3)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("offset", "Tensor", "The offset tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(5)
    .add_type_rel("DeformableConv2D", DeformableConv2DRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", DeformableConvInferCorrectLayout)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1758
// Positional relay function to create deformable_conv2d operator
// used by frontend FFI.
TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d")
    .set_body_typed([](Expr data, Expr offset, Expr weight, Array<IndexExpr> strides,
                       Array<IndexExpr> padding, Array<IndexExpr> dilation, int deformable_groups,
                       int groups, int channels, Array<IndexExpr> kernel_size, String data_layout,
                       String kernel_layout, String out_layout, DataType out_dtype) {
      // Three-input maker (data, offset, weight); deformable_groups is the
      // extra knob relative to plain conv2d.
      return MakeDeformableConv<DeformableConv2DAttrs>(
          data, offset, weight, strides, padding, dilation, deformable_groups, groups, channels,
          kernel_size, data_layout, kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d");
    });
1770
1771inline Expr MakeConv2dBackwardWeight(Expr grad, Expr data, Array<IndexExpr> strides,
1772 Array<IndexExpr> padding, Array<IndexExpr> dilation,
1773 int groups, IndexExpr channels, Array<IndexExpr> kernel_size,
1774 std::string grad_layout, std::string data_layout,
1775 std::string kernel_layout, DataType out_dtype) {
1776 auto attrs = make_object<Conv2DAttrs>();
1777 attrs->strides = std::move(strides);
1778 attrs->padding = std::move(padding);
1779 attrs->dilation = std::move(dilation);
1780 attrs->groups = groups;
1781 attrs->channels = std::move(channels);
1782 attrs->kernel_size = std::move(kernel_size);
1783 attrs->out_dtype = std::move(out_dtype);
1784 attrs->data_layout = std::move(grad_layout);
1785 attrs->kernel_layout = std::move(data_layout);
1786 attrs->out_layout = std::move(kernel_layout);
1787 const Op& op = Op::Get("nn.conv2d_backward_weight");
1788 return Call(op, {grad, data}, Attrs(attrs), {});
1789}
1790
// FFI entry point used by the Python frontend to construct
// nn.conv2d_backward_weight calls.
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_backward_weight")
    .set_body_typed([](Expr grad, Expr data, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, String grad_layout, String data_layout,
                       String kernel_layout, DataType out_dtype) {
      return MakeConv2dBackwardWeight(grad, data, strides, padding, dilation, groups, channels,
                                      kernel_size, grad_layout, data_layout, kernel_layout,
                                      out_dtype);
    });
1800
1801bool Conv2DBackwardWeightRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1802 const TypeReporter& reporter) {
1803 ICHECK_EQ(types.size(), 3);
1804 const auto* grad = types[0].as<TensorTypeNode>();
1805 const auto* data = types[1].as<TensorTypeNode>();
1806 if (data == nullptr) return false;
1807
1808 static const Layout kNCHW("NCHW");
1809 static const Layout kOIHW("OIHW");
1810
1811 const auto* param = attrs.as<Conv2DAttrs>();
1812 ICHECK(param != nullptr);
1813 // Require kernel_size to be passed, to simplify the output shape determination.
1814 ICHECK(param->kernel_size.defined()) << "kernel_size attribute needs to be specified";
1815
1816 // We repurpose Conv2dAttrs for Conv2DBackwardWeight, note the meanings of layouts.
1817 const Layout grad_layout(param->data_layout);
1818 const Layout in_layout(param->kernel_layout);
1819 const Layout kernel_layout(param->out_layout);
1820
1821 const auto trans_grad_layout = tir::BijectiveLayout(grad_layout, kNCHW);
1822 const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
1823 const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
1824
1825 Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
1826 Array<IndexExpr> grad_shape_nchw = trans_grad_layout.ForwardShape(grad->shape);
1827
1828 auto in_channels = dshape_nchw[1];
1829 auto out_channels = grad_shape_nchw[1];
1830
1831 auto in_channels_intimm = in_channels.as<IntImmNode>();
1832 auto out_channels_intimm = out_channels.as<IntImmNode>();
1833 ICHECK(in_channels_intimm);
1834 ICHECK(out_channels_intimm);
1835
1836 IndexExpr weight_dim_i;
1837 if (in_channels_intimm->value == out_channels_intimm->value &&
1838 in_channels_intimm->value == param->groups) {
1839 // depthwise
1840 ICHECK(param->channels.defined())
1841 << "out_channels attribute not specified for depth wise conv2d.";
1842 weight_dim_i = indexdiv(param->channels, param->groups);
1843 } else {
1844 weight_dim_i = indexdiv(in_channels, param->groups);
1845 }
1846
1847 Array<IndexExpr> wshape_oihw{out_channels, weight_dim_i, param->kernel_size[0],
1848 param->kernel_size[1]};
1849 auto wshape = trans_kernel_layout.BackwardShape(wshape_oihw);
1850
1851 const auto dw_dtype = (param->out_dtype == DataType() || param->out_dtype.is_void())
1852 ? grad->dtype
1853 : param->out_dtype;
1854
1855 reporter->Assign(types[2], TensorType(wshape, dw_dtype));
1856 return true;
1857}
1858
// relay.nn.conv2d_backward_weight
RELAY_REGISTER_OP("nn.conv2d_backward_weight")
    .describe(R"code(The gradient of the 2D convolution layer with respect to the weight.

This layer computes the gradient of the conv2d op with respect to weight,
given the original input data and the output gradient.

- **grad**: (batch, channels, out_height, out_width) if `layout` is `NCHW`.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
           (channels, in_channels, kernel_size[0], kernel_size[1]) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DAttrs>()
    .set_num_inputs(2)
    .add_argument("grad", "Tensor", "The gradient tensor.")
    .add_argument("data", "Tensor", "The input tensor.")
    .set_support_level(2)
    .add_type_rel("Conv2DBackwardWeight", Conv2DBackwardWeightRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1879
1880} // namespace relay
1881} // namespace tvm
1882