1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file convolution.cc |
22 | * \brief Convolution operators |
23 | */ |
24 | #include "convolution.h" |
25 | |
26 | #include <tvm/relay/attrs/nn.h> |
27 | #include <tvm/relay/op.h> |
28 | #include <tvm/tir/data_layout.h> |
29 | |
30 | #include <vector> |
31 | |
32 | #include "../../transforms/infer_layout_utils.h" |
33 | #include "../op_common.h" |
34 | #include "convolution_make.h" |
35 | |
36 | namespace tvm { |
37 | namespace relay { |
38 | |
39 | Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_name) { |
40 | auto attrs = make_object<ConvWinogradWeightTransformAttrs>(); |
41 | attrs->tile_size = tile_size; |
42 | const Op& op = Op::Get(op_name); |
43 | return Call(op, {weight}, Attrs(attrs), {}); |
44 | } |
45 | |
46 | Expr MakeConvGemmWeightTransform(Expr weight, int tile_rows, int tile_cols, std::string op_name) { |
47 | auto attrs = make_object<ConvGemmWeightTransformAttrs>(); |
48 | attrs->tile_rows = tile_rows; |
49 | attrs->tile_cols = tile_cols; |
50 | const Op& op = Op::Get(op_name); |
51 | return Call(op, {weight}, Attrs(attrs), {}); |
52 | } |
53 | |
// relay.nn.conv1d
// Register Conv1DAttrs with the object/node system so the attrs can be
// created and inspected through reflection and the FFI.
TVM_REGISTER_NODE_TYPE(Conv1DAttrs);
56 | |
// Type relation for nn.conv1d.
// types holds [data, weight, output]; the relation infers whatever is still
// unknown (weight type from kernel_size/channels, or output shape from the
// weight) and reports it via `reporter`.  Returns false when inference must
// wait for more type information (data or weight still unresolved).
bool Conv1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* weight = types[1].as<TensorTypeNode>();
  if (data == nullptr) return false;
  // Canonical layouts all shape arithmetic below is done in.
  static const Layout kNCW("NCW");
  static const Layout kOIW("OIW");

  const auto* param = attrs.as<Conv1DAttrs>();
  ICHECK(param != nullptr);
  const Layout in_layout(param->data_layout);
  const Layout kernel_layout(param->kernel_layout);

  // User-supplied layouts must be bijectively convertible to the canonical ones.
  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW);
  ICHECK(trans_in_layout.defined())
      << "Conv only support input layouts that are convertible from NCW."
      << " But got " << in_layout;

  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW);
  ICHECK(trans_kernel_layout.defined())
      << "Conv only support kernel layouts that are convertible from OIW."
      << " But got " << kernel_layout;

  // out_layout defaults to data_layout when left empty.
  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW);
  ICHECK(trans_out_layout.defined())
      << "Conv only support output layouts that are convertible from NCW."
      << " But got " << out_layout;

  // Data shape converted into canonical NCW space.
  Array<IndexExpr> dshape_ncw = trans_in_layout.ForwardShape(data->shape);

  IndexExpr channels, dilated_ksize;
  // infer weight if the kernel_size and channels are defined
  if (param->kernel_size.defined() && param->channels.defined()) {
    Array<IndexExpr> wshape;

    // Expected weight shape in OIW: (channels, in_channels / groups, ksize).
    wshape = {{param->channels, indexdiv(dshape_ncw[1], param->groups), param->kernel_size[0]}};

    wshape = trans_kernel_layout.BackwardShape(wshape);
    channels = param->channels;
    // Effective kernel extent after applying dilation.
    dilated_ksize = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
    DataType weight_dtype = data->dtype;
    if (weight != nullptr) {
      weight_dtype = weight->dtype;
    }
    // assign result to reporter
    reporter->Assign(types[1], TensorType(wshape, weight_dtype));
  } else {
    // use weight to infer the conv shape.
    if (weight == nullptr) return false;
    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
    if (param->kernel_size.defined()) {
      // check the size
      ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
          << "Conv1D: shape of weight is inconsistent with kernel_size, "
          << " kernel_size=" << param->kernel_size << " wshape=" << wshape;
    }
    if (param->channels.defined()) {
      ICHECK(reporter->AssertEQ(param->channels, wshape[0]))
          << "Conv1D: shape of weight is inconsistent with channels, "
          << " channels=" << param->channels << " wshape=" << wshape;
    }
    // Only enforce input-channel agreement when both extents are static
    // (Any means the dimension is dynamic and cannot be checked here).
    if (!dshape_ncw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
      ICHECK(reporter->AssertEQ(dshape_ncw[1], wshape[1]));
    }
    channels = wshape[0];
    dilated_ksize = 1 + (wshape[2] - 1) * param->dilation[0];
  }
  // dilation
  Array<IndexExpr> oshape({dshape_ncw[0], channels, 0});

  if (!dshape_ncw[2].as<tir::AnyNode>()) {
    // Standard conv output width:
    //   (W + pad_left + pad_right - dilated_ksize) / stride + 1
    oshape.Set(2, indexdiv(dshape_ncw[2] + param->padding[0] + param->padding[1] - dilated_ksize,
                           param->strides[0]) +
                      1);
  } else {
    // Dynamic width: propagate Any unchanged.
    oshape.Set(2, dshape_ncw[2]);
  }

  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    // Unspecified out_dtype (0 bits) falls back to the input dtype.
    out_dtype = data->dtype;
  }
  // Convert the canonical NCW shape back to the requested output layout.
  oshape = trans_out_layout.BackwardShape(oshape);
  // assign output type
  reporter->Assign(types[2], TensorType(oshape, out_dtype));
  return true;
}
146 | |
// FFI entry point: relay.op.nn._make.conv1d builds an nn.conv1d call node.
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d")
    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
                       String out_layout, DataType out_dtype) {
      // Delegate to the shared MakeConv helper with the op name bound.
      return MakeConv<Conv1DAttrs>(data, weight, strides, padding, dilation, groups, channels,
                                   kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
                                   "nn.conv1d");
    });
156 | |
// Operator registration for nn.conv1d: documentation, attrs type, arity,
// type relation and layout/fusion properties.
RELAY_REGISTER_OP("nn.conv1d")
    .describe(R"code(1D convolution layer (e.g. spatial convolution over sequences).

This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.

- **data**: This depends on the `layout` parameter. Input is 3D array of shape
            (batch_size, in_channels, width) if `layout` is `NCW`.
- **weight**: (channels, in_channels, kernel_size)
- **out**:  This depends on the `layout` parameter. Output is 3D array of shape
            (batch_size, channels, out_width) if `layout` is `NCW`.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv1DAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .add_type_rel("Conv1D", Conv1DRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv1DAttrs>)
    // Output-elementwise-fusable: elementwise ops after conv can be fused in.
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
178 | |
// relay.nn.conv2d
// Register Conv2DAttrs with the object/node system so the attrs can be
// created and inspected through reflection and the FFI.
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
181 | |
182 | bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
183 | const TypeReporter& reporter) { |
184 | ICHECK_EQ(types.size(), 3); |
185 | const auto* data = types[0].as<TensorTypeNode>(); |
186 | const auto* weight = types[1].as<TensorTypeNode>(); |
187 | if (data == nullptr) return false; |
188 | static const Layout kNCHW("NCHW" ); |
189 | Layout kOIHW("OIHW" ); |
190 | |
191 | const auto* param = attrs.as<Conv2DAttrs>(); |
192 | DataType out_dtype = param->out_dtype; |
193 | if (out_dtype.bits() == 0) { |
194 | out_dtype = data->dtype; |
195 | if (out_dtype.bits() == 0 && weight != nullptr) { |
196 | out_dtype = weight->dtype; |
197 | } |
198 | } |
199 | TensorType meta_schedule_weight{nullptr}; |
200 | if (param->meta_schedule_original_shape.size() != 0) { |
201 | meta_schedule_weight = TensorType(param->meta_schedule_original_shape, out_dtype); |
202 | weight = meta_schedule_weight.get(); |
203 | } |
204 | ICHECK(param != nullptr); |
205 | const Layout in_layout(param->data_layout); |
206 | const Layout kernel_layout(param->kernel_layout); |
207 | |
208 | bool is_dnnl_group_conv = false; |
209 | if (param->groups > 1 && kernel_layout.name().find("G" ) != std::string::npos) { |
210 | kOIHW = Layout("GOIHW" ); |
211 | is_dnnl_group_conv = true; |
212 | } |
213 | |
214 | const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); |
215 | if (!trans_in_layout.defined()) { |
216 | reporter->GetDiagCtx().Emit( |
217 | Diagnostic::Error(reporter->GetSpan()) |
218 | << "conv2d only support input layouts that are convertible from NCHW." |
219 | << " The provided layout is: " << in_layout); |
220 | return false; |
221 | } |
222 | |
223 | const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); |
224 | if (!trans_kernel_layout.defined()) { |
225 | reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan()) |
226 | << "conv2d only support kernel layouts that are convertible from " |
227 | << kOIHW << "." |
228 | << " The provided layout is: " << kernel_layout); |
229 | return false; |
230 | } |
231 | |
232 | Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); |
233 | const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); |
234 | if (!trans_out_layout.defined()) { |
235 | reporter->GetDiagCtx().Emit( |
236 | Diagnostic::Error(reporter->GetSpan()) |
237 | << "conv2d only support output layouts that are convertible from NCHW." |
238 | << "The provided layout is: " << out_layout); |
239 | return false; |
240 | } |
241 | |
242 | Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape); |
243 | bool is_depthwise = false; |
244 | if (param->groups > 1) { |
245 | if (!(weight && weight->shape.defined())) { |
246 | reporter->GetDiagCtx().Emit( |
247 | Diagnostic::Error(reporter->GetSpan()) |
248 | << "Weight shape must be specified when groups is greater than 1." ); |
249 | return false; |
250 | } |
251 | |
252 | Array<IndexExpr> wshape_oihw = trans_kernel_layout.ForwardShape(weight->shape); |
253 | if (tvm::tir::ExprDeepEqual()(param->groups, dshape_nchw[1]) && |
254 | tvm::tir::ExprDeepEqual()(param->groups, wshape_oihw[0])) { |
255 | is_depthwise = true; |
256 | } |
257 | } |
258 | |
259 | IndexExpr channels, dilated_ksize_y, dilated_ksize_x; |
260 | // infer weight if the kernel_size and channels are defined |
261 | if (param->kernel_size.defined() && param->channels.defined()) { |
262 | ICHECK_EQ(param->kernel_size.size(), 2); |
263 | ICHECK_EQ(param->dilation.size(), 2); |
264 | Array<IndexExpr> wshape; |
265 | |
266 | if (is_dnnl_group_conv) { |
267 | // infer weight's shape for group convolution |
268 | wshape = {{param->groups, indexdiv(param->channels, param->groups), |
269 | indexdiv(dshape_nchw[1], param->groups), param->kernel_size[0], |
270 | param->kernel_size[1]}}; |
271 | } else if (is_depthwise) { |
272 | // infer weight's shape for depthwise convolution |
273 | wshape = {{dshape_nchw[1], indexdiv(param->channels, dshape_nchw[1]), param->kernel_size[0], |
274 | param->kernel_size[1]}}; |
275 | } else { |
276 | wshape = {{param->channels, indexdiv(dshape_nchw[1], param->groups), param->kernel_size[0], |
277 | param->kernel_size[1]}}; |
278 | } |
279 | |
280 | wshape = trans_kernel_layout.BackwardShape(wshape); |
281 | channels = param->channels; |
282 | dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; |
283 | dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; |
284 | DataType weight_dtype = data->dtype; |
285 | if (weight != nullptr) { |
286 | weight_dtype = weight->dtype; |
287 | } |
288 | |
289 | if (param->auto_scheduler_rewritten_layout.size() != 0) { |
290 | // If the layout is rewritten by auto-scheduler, |
291 | // we just forcly apply the layout provided by auto-scheduler and |
292 | // skip the normal inference logic. |
293 | {} // do nothing |
294 | } else if (param->meta_schedule_original_shape.size() == 0) { |
295 | // Normal case: assign result to reporter |
296 | reporter->Assign(types[1], TensorType(wshape, weight_dtype)); |
297 | } |
298 | } else { |
299 | // use weight to infer the conv shape. |
300 | if (weight == nullptr) return false; |
301 | |
302 | Array<PrimExpr> wshape; |
303 | if (param->auto_scheduler_rewritten_layout.size() != 0) { |
304 | // works for the default kernel layout "HWIO" |
305 | ICHECK_EQ(param->kernel_layout, "HWIO" ); |
306 | wshape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout, |
307 | {"ry" , "rx" , "rc" , "ff" }); |
308 | } else { |
309 | wshape = weight->shape; |
310 | } |
311 | |
312 | wshape = trans_kernel_layout.ForwardShape(wshape); |
313 | if (param->kernel_size.defined()) { |
314 | ICHECK_EQ(param->kernel_size.size(), 2); |
315 | |
316 | if (!reporter->AssertEQ(param->kernel_size[0], wshape[2])) { |
317 | reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan()) |
318 | << "Conv2D: shape of weight is inconsistent with kernel_size," |
319 | << " kernel_size=" << param->kernel_size |
320 | << " wshape=" << wshape); |
321 | } |
322 | |
323 | if (!reporter->AssertEQ(param->kernel_size[1], wshape[3])) { |
324 | reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan()) |
325 | << "Conv2D: shape of weight is inconsistent with kernel_size," |
326 | << " kernel_size=" << param->kernel_size |
327 | << " wshape=" << wshape); |
328 | return false; |
329 | } |
330 | } |
331 | |
332 | if (param->channels.defined() && !reporter->AssertEQ(param->channels, wshape[0])) { |
333 | reporter->GetDiagCtx().Emit( |
334 | Diagnostic::Error(reporter->GetSpan()) |
335 | << "conv2D: the first dimensions of the weight tensor (" << wshape << ")" |
336 | << "does not match the number of channels (" << param->channels << ")." ); |
337 | return false; |
338 | } |
339 | |
340 | if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) { |
341 | if (!reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1])) { |
342 | reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan()) |
343 | << "conv2d: requires that `" |
344 | << indexdiv(dshape_nchw[1], param->groups) << "`," |
345 | << " the input channels (" << dshape_nchw[1] << ")" |
346 | << " divided by groups (" << param->groups << ")" |
347 | << ",\n must match the input channels" |
348 | << " of the weight `" << wshape[1] |
349 | << "`, where the weight shape is (" << wshape << ")." ); |
350 | return false; |
351 | } |
352 | } |
353 | channels = wshape[0]; |
354 | dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; |
355 | dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; |
356 | } |
357 | // dilation |
358 | Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0}); |
359 | |
360 | IndexExpr pad_h, pad_w; |
361 | GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); |
362 | if (!dshape_nchw[2].as<tir::AnyNode>()) { |
363 | oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); |
364 | } else { |
365 | oshape.Set(2, dshape_nchw[2]); |
366 | } |
367 | |
368 | if (!dshape_nchw[3].as<tir::AnyNode>()) { |
369 | oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); |
370 | } else { |
371 | oshape.Set(3, dshape_nchw[3]); |
372 | } |
373 | oshape = trans_out_layout.BackwardShape(oshape); |
374 | // assign output type |
375 | reporter->Assign(types[2], TensorType(oshape, out_dtype)); |
376 | return true; |
377 | } |
378 | |
// FFI entry point: relay.op.nn._make.conv2d builds an nn.conv2d call node.
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d")
    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
                       String out_layout, DataType out_dtype) {
      // Delegate to the shared MakeConv helper with the op name bound.
      return MakeConv<Conv2DAttrs>(data, weight, strides, padding, dilation, groups, channels,
                                   kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
                                   "nn.conv2d");
    });
388 | |
// Operator registration for nn.conv2d: documentation, attrs type, arity,
// type relation and layout/fusion properties.
RELAY_REGISTER_OP("nn.conv2d")
    .describe(R"code(2D convolution layer (e.g. spatial convolution over images).

This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.

- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
- **out**:  This depends on the `layout` parameter. Output is 4D array of shape
            (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .add_type_rel("Conv2D", Conv2DRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>)
    // Output-elementwise-fusable: elementwise ops after conv can be fused in.
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
410 | |
// relay.nn.conv3d
// Register Conv3DAttrs with the object/node system so the attrs can be
// created and inspected through reflection and the FFI.
TVM_REGISTER_NODE_TYPE(Conv3DAttrs);
413 | |
// Type relation for nn.conv3d.  Mirrors Conv2DRel in canonical NCDHW/OIDHW
// space, but reports layout/shape violations through hard ICHECKs rather
// than diagnostics.
bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* weight = types[1].as<TensorTypeNode>();
  if (data == nullptr) return false;
  // Canonical layouts all shape arithmetic below is done in.
  static const Layout kNCDHW("NCDHW");
  static const Layout kOIDHW("OIDHW");

  const auto* param = attrs.as<Conv3DAttrs>();
  ICHECK(param != nullptr);
  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    // Unspecified out_dtype: fall back to data dtype, then weight dtype.
    out_dtype = data->dtype;
    if (out_dtype.bits() == 0 && weight != nullptr) {
      out_dtype = weight->dtype;
    }
  }
  // When meta-schedule has rewritten the weight layout, reason about the
  // original shape it recorded rather than the rewritten tensor type.
  TensorType meta_schedule_weight{nullptr};
  if (param->meta_schedule_original_shape.size() != 0) {
    meta_schedule_weight = TensorType(param->meta_schedule_original_shape, out_dtype);
    weight = meta_schedule_weight.get();
  }
  const Layout in_layout(param->data_layout);
  const Layout kernel_layout(param->kernel_layout);

  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW);
  ICHECK(trans_in_layout.defined())
      << "Conv only support input layouts that are convertible from NCDHW."
      << " But got " << in_layout;

  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW);
  ICHECK(trans_kernel_layout.defined())
      << "Conv only support kernel layouts that are convertible from OIDHW."
      << " But got " << kernel_layout;

  // out_layout defaults to data_layout when left empty.
  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW);
  ICHECK(trans_out_layout.defined())
      << "Conv only support output layouts that are convertible from NCDHW."
      << " But got " << out_layout;

  Array<IndexExpr> dshape_ncdhw = trans_in_layout.ForwardShape(data->shape);

  IndexExpr channels, dilated_ksize_z, dilated_ksize_y, dilated_ksize_x;
  // infer weight if the kernel_size and channels are defined
  if (param->kernel_size.defined() && param->channels.defined()) {
    ICHECK_EQ(param->kernel_size.size(), 3);
    ICHECK_EQ(param->dilation.size(), 3);
    // Expected weight shape in OIDHW: (channels, in_channels / groups, kd, kh, kw).
    Array<IndexExpr> wshape({param->channels, indexdiv(dshape_ncdhw[1], param->groups),
                             param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]});
    wshape = trans_kernel_layout.BackwardShape(wshape);
    channels = param->channels;
    // Effective kernel extents after applying dilation (depth, height, width).
    dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
    dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
    dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2];
    DataType weight_dtype = data->dtype;
    if (weight != nullptr) {
      weight_dtype = weight->dtype;
    }

    if (param->auto_scheduler_rewritten_layout.size() != 0) {
      // If the layout is rewritten by auto-scheduler,
      // we just forcly apply the layout provided by auto-scheduler and
      // skip the normal inference logic.
      {}  // do nothing
    } else if (param->meta_schedule_original_shape.size() == 0) {
      // Normal case: assign result to reporter
      reporter->Assign(types[1], TensorType(wshape, weight_dtype));
    }

  } else {
    // use weight to infer the conv shape.
    if (weight == nullptr) return false;

    Array<PrimExpr> wshape;
    if (param->auto_scheduler_rewritten_layout.size() != 0) {
      // works for the default kernel layout "DHWIO"
      ICHECK_EQ(param->kernel_layout, "DHWIO");
      wshape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout,
                                                           {"rd", "rh", "rw", "rc", "cc"});
    } else {
      wshape = weight->shape;
    }

    wshape = trans_kernel_layout.ForwardShape(wshape);
    if (param->kernel_size.defined()) {
      ICHECK_EQ(param->kernel_size.size(), 3);
      // check the size
      ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
             reporter->AssertEQ(param->kernel_size[1], wshape[3]) &&
             reporter->AssertEQ(param->kernel_size[2], wshape[4]))
          << "Conv3D: shape of weight is inconsistent with kernel_size, "
          << " kernel_size=" << param->kernel_size << " wshape=" << wshape;
    }

    if (param->channels.defined()) {
      ICHECK(reporter->AssertEQ(param->channels, wshape[0]))
          << "Conv3D: shape of weight is inconsistent with channels, "
          << " channels=" << param->channels << " wshape=" << wshape;
    }

    // Only enforce input-channel agreement when both extents are static.
    if (!dshape_ncdhw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
      ICHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[1]));
    }
    channels = wshape[0];
    dilated_ksize_z = 1 + (wshape[2] - 1) * param->dilation[0];
    dilated_ksize_y = 1 + (wshape[3] - 1) * param->dilation[1];
    dilated_ksize_x = 1 + (wshape[4] - 1) * param->dilation[2];
  }
  // dilation
  Array<IndexExpr> oshape({dshape_ncdhw[0], channels, 0, 0, 0});

  IndexExpr pad_d, pad_h, pad_w;
  GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
  // Output depth/height/width: (in + total_pad - dilated_ksize) / stride + 1;
  // dynamic (Any) extents are propagated unchanged.
  if (!dshape_ncdhw[2].as<tir::AnyNode>()) {
    oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z, param->strides[0]) + 1);
  } else {
    oshape.Set(2, dshape_ncdhw[2]);
  }

  if (!dshape_ncdhw[3].as<tir::AnyNode>()) {
    oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y, param->strides[1]) + 1);
  } else {
    oshape.Set(3, dshape_ncdhw[3]);
  }

  if (!dshape_ncdhw[4].as<tir::AnyNode>()) {
    oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x, param->strides[2]) + 1);
  } else {
    oshape.Set(4, dshape_ncdhw[4]);
  }
  oshape = trans_out_layout.BackwardShape(oshape);
  // assign output type
  reporter->Assign(types[2], TensorType(oshape, out_dtype));
  return true;
}
551 | |
// FFI entry point: relay.op.nn._make.conv3d builds an nn.conv3d call node.
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv3d")
    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
                       String out_layout, DataType out_dtype) {
      // Delegate to the shared MakeConv helper with the op name bound.
      return MakeConv<Conv3DAttrs>(data, weight, strides, padding, dilation, groups, channels,
                                   kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
                                   "nn.conv3d");
    });
561 | |
// Operator registration for nn.conv3d: documentation, attrs type, arity,
// type relation and layout/fusion properties.
RELAY_REGISTER_OP("nn.conv3d")
    .describe(R"code(3D convolution layer (e.g. convolution over 3D image data,
like Magnetic Resonance Imaging (MRI) data in medicine).

This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.

- **data**: This depends on the `layout` parameter. Input is 5D array of shape
            (batch_size, in_channels, depth, height, width) if `layout` is `NCDHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2])
- **out**:  This depends on the `layout` parameter. Output is 5D array of shape
            (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv3DAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .add_type_rel("Conv3D", Conv3DRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv3DAttrs>)
    // Output-elementwise-fusable: elementwise ops after conv can be fused in.
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
584 | |
// relay.nn.conv3d_transpose
// Register Conv3DTransposeAttrs with the object/node system so the attrs can
// be created and inspected through reflection and the FFI.
TVM_REGISTER_NODE_TYPE(Conv3DTransposeAttrs);
587 | |
588 | template <typename AttrType> |
589 | bool Conv3DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
590 | const TypeReporter& reporter) { |
591 | ICHECK_EQ(types.size(), 3); |
592 | const auto* data = types[0].as<TensorTypeNode>(); |
593 | const auto* weight = types[1].as<TensorTypeNode>(); |
594 | if (data == nullptr) return false; |
595 | |
596 | static const Layout kNCDHW("NCDHW" ); |
597 | static const Layout kOIDHW("OIDHW" ); |
598 | |
599 | const Conv3DTransposeAttrs* param = attrs.as<Conv3DTransposeAttrs>(); |
600 | ICHECK(param != nullptr); |
601 | const Layout in_layout(param->data_layout); |
602 | const Layout kernel_layout(param->kernel_layout); |
603 | |
604 | const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW); |
605 | ICHECK(trans_in_layout.defined()) |
606 | << "Conv3d_transpose only support input layouts that are convertible from NCDHW." |
607 | << " But got " << in_layout; |
608 | |
609 | const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW); |
610 | ICHECK(trans_kernel_layout.defined()) |
611 | << "Conv3d_transpose only support kernel layouts that are convertible from OIDHW." |
612 | << " But got " << kernel_layout; |
613 | |
614 | Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); |
615 | const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW); |
616 | ICHECK(trans_out_layout.defined()) |
617 | << "Conv3d_transpose only support output layouts that are convertible from NCDHW." |
618 | << " But got " << out_layout; |
619 | |
620 | IndexExpr channels, dilated_ksize_d, dilated_ksize_y, dilated_ksize_x; |
621 | |
622 | auto dshape_ncdhw = trans_in_layout.ForwardShape(data->shape); |
623 | |
624 | // infer weight if the kernel_size and channels are defined |
625 | if (param->kernel_size.defined() && param->channels.defined()) { |
626 | ICHECK_EQ(param->kernel_size.size(), 3); |
627 | ICHECK_EQ(param->dilation.size(), 3); |
628 | |
629 | Array<IndexExpr> wshape({dshape_ncdhw[1], indexdiv(param->channels, param->groups), |
630 | param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]}); |
631 | |
632 | wshape = trans_kernel_layout.BackwardShape(wshape); |
633 | dilated_ksize_d = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; |
634 | dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; |
635 | dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2]; |
636 | channels = param->channels; |
637 | |
638 | DataType weight_dtype = data->dtype; |
639 | if (weight != nullptr) { |
640 | weight_dtype = weight->dtype; |
641 | } |
642 | // assign result to reporter |
643 | reporter->Assign(types[1], TensorType(wshape, weight_dtype)); |
644 | } else { |
645 | // use weight to infer the conv shape. |
646 | if (weight == nullptr) return false; |
647 | auto wshape = trans_kernel_layout.ForwardShape(weight->shape); |
648 | if (param->kernel_size.defined()) { |
649 | ICHECK_EQ(param->kernel_size.size(), 3); |
650 | // check the size |
651 | ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && |
652 | reporter->AssertEQ(param->kernel_size[1], wshape[3]) && |
653 | reporter->AssertEQ(param->kernel_size[2], wshape[4])) |
654 | << "Conv3D: shape of weight is inconsistent with kernel_size, " |
655 | << " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape); |
656 | } |
657 | if (param->channels.defined()) { |
658 | ICHECK(reporter->AssertEQ(param->channels, wshape[1])) |
659 | << "Conv3D: shape of weight is inconsistent with channels, " |
660 | << " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape); |
661 | } |
662 | if (!dshape_ncdhw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) { |
663 | ICHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[0])); |
664 | } |
665 | channels = wshape[1]; |
666 | dilated_ksize_d = 1 + (wshape[2] - 1) * param->dilation[0]; |
667 | dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; |
668 | dilated_ksize_y = 1 + (wshape[4] - 1) * param->dilation[2]; |
669 | } |
670 | |
671 | // dilation |
672 | Array<IndexExpr> oshape({dshape_ncdhw[0], channels, 0, 0, 0}); |
673 | IndexExpr pad_d, pad_h, pad_w; |
674 | GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); |
675 | |
676 | if (!dshape_ncdhw[2].as<tir::AnyNode>()) { |
677 | oshape.Set(2, (param->strides[0] * (dshape_ncdhw[2] - 1) + dilated_ksize_d - pad_d + |
678 | param->output_padding[0])); |
679 | } else { |
680 | oshape.Set(2, dshape_ncdhw[2]); |
681 | } |
682 | if (!dshape_ncdhw[3].as<tir::AnyNode>()) { |
683 | oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_y - pad_h + |
684 | param->output_padding[1])); |
685 | } else { |
686 | oshape.Set(3, dshape_ncdhw[3]); |
687 | } |
688 | if (!dshape_ncdhw[4].as<tir::AnyNode>()) { |
689 | oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_x - pad_w + |
690 | param->output_padding[2])); |
691 | } else { |
692 | oshape.Set(4, dshape_ncdhw[4]); |
693 | } |
694 | |
695 | DataType out_dtype = param->out_dtype; |
696 | if (out_dtype.bits() == 0) { |
697 | out_dtype = data->dtype; |
698 | } |
699 | oshape = trans_out_layout.BackwardShape(oshape); |
700 | reporter->Assign(types[2], TensorType(oshape, out_dtype)); |
701 | return true; |
702 | } |
703 | |
// FFI entry point: relay.op.nn._make.conv3d_transpose builds an
// nn.conv3d_transpose call node (note the extra output_padding argument).
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv3d_transpose")
    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
                       String out_layout, Array<IndexExpr> output_padding, DataType out_dtype) {
      // Delegate to the shared MakeConvTranspose helper with the op name bound.
      return MakeConvTranspose<Conv3DTransposeAttrs>(
          data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout,
          kernel_layout, out_layout, output_padding, out_dtype, "nn.conv3d_transpose");
    });
713 | |
// Op registration for nn.conv3d_transpose: attrs type, arity, layout inference
// and the type relation implementing the output-shape formula below.
RELAY_REGISTER_OP("nn.conv3d_transpose")
    .describe(R"code(Transposed 3D convolution layer (sometimes called Deconvolution 3D).

The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
of a normal convolution, i.e., from something that has the shape of the
output of some convolution to something that has the shape of its input
while maintaining a connectivity pattern that is compatible with
said convolution.

- **data**: This depends on the `layout` parameter. Input is 5D array of shape
            (batch_size, in_channels, depth, height, width) if `layout` is `NCDHW`.
- **weight**: (in_channels, channels, kernel_size[0], kernel_size[1], kernel_size[2])
- **bias**: (channels,)
- **out**: This depends on the `layout` parameter. Output is 5D array of shape
           (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`.

out_depth and out_height and out_width are calculated as::
    out_depth = (depth-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
    out_height = (height-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1]
    out_width = (width-1)*strides[2]-2*padding[2]+kernel_size[2]+output_padding[2]

)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv3DTransposeAttrs>()
    .set_num_inputs(2)  // data, weight
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                   ConvInferCorrectLayout<Conv3DTransposeAttrs>)
    .add_type_rel("Conv3DTranspose", Conv3DTransposeRel<Conv3DTransposeAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
746 | |
// relay.nn.conv2d_transpose
// Register the attrs node so Conv2DTransposeAttrs is reflectable over the FFI.
TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
749 | |
/*!
 * \brief Type relation for nn.conv2d_transpose.
 *
 * types = [data_type, weight_type, output_type]. Either infers the weight
 * shape from the kernel_size/channels attributes and assigns it to types[1],
 * or derives kernel size and channel count from the provided weight tensor.
 * The output spatial dims follow the transposed-convolution formula:
 *   out = (in - 1) * stride + dilated_ksize - padding_sum + output_padding
 */
bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                        const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* weight = types[1].as<TensorTypeNode>();
  if (data == nullptr) return false;

  static const Layout kNCHW("NCHW");
  // Deliberately mutable: switched to GIOHW below for DNNL-style grouped kernels.
  Layout kIOHW("IOHW");

  const Conv2DTransposeAttrs* param = attrs.as<Conv2DTransposeAttrs>();
  ICHECK(param != nullptr);
  const Layout in_layout(param->data_layout);
  const Layout kernel_layout(param->kernel_layout);

  // DNNL represents grouped transposed conv with an explicit group axis ("G")
  // in the kernel layout, e.g. GIOHW.
  bool is_dnnl_group_conv = false;
  if (param->groups > 1 && kernel_layout.name().find("G") != std::string::npos) {
    kIOHW = Layout("GIOHW");
    is_dnnl_group_conv = true;
  }

  // Bijective converters between the user layouts and the canonical layouts.
  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
  ICHECK(trans_in_layout.defined())
      << "Conv2DTransposed only support input layouts that are convertible from NCHW."
      << " But got " << in_layout;

  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kIOHW);
  ICHECK(trans_kernel_layout.defined())
      << "Conv2DTransposed only support kernel layouts that are convertible from " << kIOHW << "."
      << " But got " << kernel_layout << " " << kIOHW;

  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
  ICHECK(trans_out_layout.defined())
      << "Conv2DTransposed only support output layouts that are convertible from NCHW."
      << " But got " << out_layout;

  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;

  auto dshape_nchw = trans_in_layout.ForwardShape(data->shape);

  // infer weight if the kernel_size and channels are defined
  if (param->kernel_size.defined() && param->channels.defined()) {
    ICHECK_EQ(param->kernel_size.size(), 2);
    ICHECK_EQ(param->dilation.size(), 2);

    Array<IndexExpr> wshape;
    if (is_dnnl_group_conv) {
      // infer weight's shape for group convolution: (G, I/G, O/G, kh, kw)
      wshape = {{param->groups, indexdiv(dshape_nchw[1], param->groups),
                 indexdiv(param->channels, param->groups), param->kernel_size[0],
                 param->kernel_size[1]}};
    } else {
      // infer weight's shape in IOHW order: (in_channels, out_channels/groups, kh, kw)
      wshape = {{dshape_nchw[1], indexdiv(param->channels, param->groups), param->kernel_size[0],
                 param->kernel_size[1]}};
    }

    wshape = trans_kernel_layout.BackwardShape(wshape);
    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
    channels = param->channels;

    DataType weight_dtype = data->dtype;
    if (weight != nullptr) {
      weight_dtype = weight->dtype;
    }
    // assign result to reporter
    reporter->Assign(types[1], TensorType(wshape, weight_dtype));
  } else {
    // use weight to infer the conv shape.
    if (weight == nullptr) return false;
    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
    if (param->kernel_size.defined()) {
      ICHECK_EQ(param->kernel_size.size(), 2);
      // check the size
      ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
             reporter->AssertEQ(param->kernel_size[1], wshape[3]))
          << "Conv2DTransposed: shape of weight is inconsistent with kernel_size, "
          << " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape);
    }
    if (param->channels.defined()) {
      ICHECK(reporter->AssertEQ(indexdiv(param->channels, param->groups), wshape[1]))
          << "Conv2DTransposed: shape of weight is inconsistent with out_channels, "
          << " out_channels // groups != weight.shape[1] "
          << " out_channels=" << param->channels << " groups=" << param->groups
          << " weight.shape=" << Array<IndexExpr>(wshape);
    }
    // Only validate input channels when both sides are statically known.
    if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
      ICHECK(reporter->AssertEQ(dshape_nchw[1], wshape[0]))
          << "Conv2DTransposed: shape of weight is inconsistent with in_channels."
          << " data.shape= " << Array<IndexExpr>(dshape_nchw) << " groups= " << param->groups
          << " weight.shape= " << Array<IndexExpr>(wshape);
    }
    channels = wshape[1];
    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
  }
  // Output shape: H/W use the transposed-conv formula; dynamic (Any) dims
  // pass through unchanged.
  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
  IndexExpr pad_h, pad_w;
  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
  if (!dshape_nchw[2].as<tir::AnyNode>()) {
    oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - pad_h +
                   param->output_padding[0]));
  } else {
    oshape.Set(2, dshape_nchw[2]);
  }
  if (!dshape_nchw[3].as<tir::AnyNode>()) {
    oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - pad_w +
                   param->output_padding[1]));
  } else {
    oshape.Set(3, dshape_nchw[3]);
  }

  // out_dtype.bits() == 0 means "unspecified": inherit the input dtype.
  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = data->dtype;
  }
  oshape = trans_out_layout.BackwardShape(oshape);
  reporter->Assign(types[2], TensorType(oshape, out_dtype));
  return true;
}
873 | |
// FFI binding: builds a nn.conv2d_transpose call expression from the Python-side
// arguments via the shared MakeConvTranspose helper.
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_transpose")
    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
                       String out_layout, Array<IndexExpr> output_padding, DataType out_dtype) {
      return MakeConvTranspose<Conv2DTransposeAttrs>(
          data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout,
          kernel_layout, out_layout, output_padding, out_dtype, "nn.conv2d_transpose");
    });
883 | |
884 | RELAY_REGISTER_OP("nn.conv2d_transpose" ) |
885 | .describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution). |
886 | |
887 | The need for transposed convolutions generally arises |
888 | from the desire to use a transformation going in the opposite direction |
889 | of a normal convolution, i.e., from something that has the shape of the |
890 | output of some convolution to something that has the shape of its input |
891 | while maintaining a connectivity pattern that is compatible with |
892 | said convolution. |
893 | |
894 | - **data**: This depends on the `layout` parameter. Input is 4D array of shape |
895 | (batch_size, in_channels, height, width) if `layout` is `NCHW`. |
896 | - **weight**: (in_channels, channels, kernel_size[0], kernel_size[1]) |
897 | - **bias**: (channels,) |
898 | - **out**: This depends on the `layout` parameter. Output is 4D array of shape |
899 | v (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. |
900 | |
901 | out_height and out_width are calculated as:: |
902 | out_height = (height-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0] |
903 | out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1] |
904 | |
905 | )code" TVM_ADD_FILELINE) |
906 | .set_attrs_type<Conv2DTransposeAttrs>() |
907 | .set_num_inputs(2) |
908 | .add_argument("data" , "Tensor" , "The input tensor." ) |
909 | .add_argument("weight" , "Tensor" , "The weight tensor." ) |
910 | .set_support_level(2) |
911 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , |
912 | ConvInferCorrectLayout<Conv2DTransposeAttrs>) |
913 | .add_type_rel("Conv2DTranspose" , Conv2DTransposeRel) |
914 | .set_attr<TOpPattern>("TOpPattern" , kOutEWiseFusable); |
915 | |
// relay.nn.conv1d_transpose
// Register the attrs node so Conv1DTransposeAttrs is reflectable over the FFI.
TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs);
918 | |
919 | bool Conv1DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
920 | const TypeReporter& reporter) { |
921 | ICHECK_EQ(types.size(), 3); |
922 | const auto* data = types[0].as<TensorTypeNode>(); |
923 | const auto* weight = types[1].as<TensorTypeNode>(); |
924 | if (data == nullptr) return false; |
925 | |
926 | static const Layout kNCW("NCW" ); |
927 | static const Layout kOIW("OIW" ); |
928 | |
929 | const Conv1DTransposeAttrs* param = attrs.as<Conv1DTransposeAttrs>(); |
930 | ICHECK(param != nullptr); |
931 | const Layout in_layout(param->data_layout); |
932 | const Layout kernel_layout(param->kernel_layout); |
933 | |
934 | const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW); |
935 | ICHECK(trans_in_layout.defined()) |
936 | << "Conv only support input layouts that are convertible from NCW." |
937 | << " But got " << in_layout; |
938 | |
939 | const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW); |
940 | ICHECK(trans_kernel_layout.defined()) |
941 | << "Conv only support kernel layouts that are convertible from OIW." |
942 | << " But got " << kernel_layout; |
943 | |
944 | Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); |
945 | const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW); |
946 | ICHECK(trans_out_layout.defined()) |
947 | << "Conv only support output layouts that are convertible from NCW." |
948 | << " But got " << out_layout; |
949 | |
950 | IndexExpr channels, dilated_ksize_y, dilated_ksize_x; |
951 | |
952 | auto dshape_ncw = trans_in_layout.ForwardShape(data->shape); |
953 | |
954 | // infer weight if the kernel_size and channels are defined |
955 | if (param->kernel_size.defined() && param->channels.defined()) { |
956 | ICHECK_EQ(param->kernel_size.size(), 1); |
957 | ICHECK_EQ(param->dilation.size(), 1); |
958 | |
959 | Array<IndexExpr> wshape( |
960 | {dshape_ncw[1], indexdiv(param->channels, param->groups), param->kernel_size[0]}); |
961 | |
962 | wshape = trans_kernel_layout.BackwardShape(wshape); |
963 | dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; |
964 | channels = param->channels; |
965 | |
966 | DataType weight_dtype = data->dtype; |
967 | if (weight != nullptr) { |
968 | weight_dtype = weight->dtype; |
969 | } |
970 | // assign result to reporter |
971 | reporter->Assign(types[1], TensorType(wshape, weight_dtype)); |
972 | } else { |
973 | // use weight to infer the conv shape. |
974 | if (weight == nullptr) return false; |
975 | auto wshape = trans_kernel_layout.ForwardShape(weight->shape); |
976 | if (param->kernel_size.defined()) { |
977 | ICHECK_EQ(param->kernel_size.size(), 1); |
978 | // check the size |
979 | ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2])) |
980 | << "Conv1D: shape of weight is inconsistent with kernel_size, " |
981 | << " kernel_size=" << param->kernel_size << " wshape=" << Array<IndexExpr>(wshape); |
982 | } |
983 | if (param->channels.defined()) { |
984 | ICHECK(reporter->AssertEQ(param->channels, wshape[1])) |
985 | << "Conv1D: shape of weight is inconsistent with channels, " |
986 | << " channels=" << param->channels << " wshape=" << Array<IndexExpr>(wshape); |
987 | } |
988 | if (!dshape_ncw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) { |
989 | ICHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0])); |
990 | } |
991 | channels = wshape[1]; |
992 | dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0]; |
993 | } |
994 | // dilation |
995 | IndexExpr pad_w; |
996 | GetPaddingWidth(param->padding, &pad_w); |
997 | Array<IndexExpr> oshape({dshape_ncw[0], channels, 0}); |
998 | if (!dshape_ncw[2].as<tir::AnyNode>()) { |
999 | oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - pad_w + |
1000 | param->output_padding[0])); |
1001 | } else { |
1002 | oshape.Set(2, dshape_ncw[2]); |
1003 | } |
1004 | |
1005 | DataType out_dtype = param->out_dtype; |
1006 | if (out_dtype.bits() == 0) { |
1007 | out_dtype = data->dtype; |
1008 | } |
1009 | oshape = trans_out_layout.BackwardShape(oshape); |
1010 | reporter->Assign(types[2], TensorType(oshape, out_dtype)); |
1011 | return true; |
1012 | } |
1013 | |
// FFI binding: builds a nn.conv1d_transpose call expression from the Python-side
// arguments via the shared MakeConvTranspose helper.
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d_transpose")
    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
                       String out_layout, Array<IndexExpr> output_padding, DataType out_dtype) {
      return MakeConvTranspose<Conv1DTransposeAttrs>(
          data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout,
          kernel_layout, out_layout, output_padding, out_dtype, "nn.conv1d_transpose");
    });
1023 | |
// Op registration for nn.conv1d_transpose.
// NOTE(review): unlike the 2D/3D variants above, no FInferCorrectLayout is
// registered here — confirm whether layout inference is intentionally omitted.
RELAY_REGISTER_OP("nn.conv1d_transpose")
    .describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution).

The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
of a normal convolution, i.e., from something that has the shape of the
output of some convolution to something that has the shape of its input
while maintaining a connectivity pattern that is compatible with
said convolution.

- **data**: This depends on the `layout` parameter. Input is 3D array of shape
            (batch_size, in_channels, width) if `layout` is `NCW`.
- **weight**: (in_channels, channels, kernel_size[0])
- **bias**: (channels,)
- **out**: This depends on the `layout` parameter. Output is 3D array of shape
           (batch_size, channels, out_width) if `layout` is `NCW`.

out_width is calculated as::
    out_width = (width-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]

)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv1DTransposeAttrs>()
    .set_num_inputs(2)  // data, weight
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(2)
    .add_type_rel("Conv1DTranspose", Conv1DTransposeRel)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1052 | |
// relay.nn.contrib_conv2d_winograd_without_weight_transform
// Register the attrs node shared by the winograd conv2d op below.
TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs);
1055 | |
// FFI binding: builds a winograd conv2d call (weight already pre-transformed)
// via the shared MakeConvWinograd helper.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform")
    .set_body_typed([](Expr data, Expr weight, int tile_size, Array<IndexExpr> strides,
                       Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
                       IndexExpr channels, Array<IndexExpr> kernel_size, String data_layout,
                       String kernel_layout, String out_layout, DataType out_dtype) {
      return MakeConvWinograd<Conv2DWinogradAttrs>(
          data, weight, tile_size, strides, padding, dilation, groups, channels, kernel_size,
          data_layout, kernel_layout, out_layout, out_dtype,
          "nn.contrib_conv2d_winograd_without_weight_transform");
    });
1066 | |
// Op registration for the winograd conv2d that consumes a pre-transformed weight.
// The weight shape is deliberately unchecked (backend-dependent packing).
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
    .describe(R"code(Compute conv2d with winograd algorithm. Only supports NCHW layout.
This operator assumes the weight tensor is already pre-transformed by
nn.contrib_conv2d_winograd_weight_transform.

- **data**: Input is 4D array of shape (batch_size, in_channels, height, width)
- **weight**: Any shape
            We do not check the shape for this input tensor. Since different backend
            has different layout strategy.

- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DWinogradAttrs>()
    .set_num_inputs(2)  // data, weight
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(10)
    .add_type_rel("Conv2DWinograd", Conv2DWinogradRel<Conv2DWinogradAttrs>)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                   ConvInferCorrectLayout<Conv2DWinogradAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1088 | |
// relay.nn.contrib_conv2d_winograd_weight_transform
// Attrs node shared by the 2D and 3D winograd weight-transform ops.
TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs);
1091 | |
1092 | bool Conv2DWinogradWeightTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1093 | const TypeReporter& reporter) { |
1094 | ICHECK_EQ(types.size(), 2); |
1095 | const auto* data = types[0].as<TensorTypeNode>(); |
1096 | if (data == nullptr) return false; |
1097 | |
1098 | const ConvWinogradWeightTransformAttrs* param = attrs.as<ConvWinogradWeightTransformAttrs>(); |
1099 | ICHECK(param != nullptr); |
1100 | |
1101 | ICHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout" ; |
1102 | |
1103 | std::vector<IndexExpr> oshape{ |
1104 | param->tile_size + data->shape[2] - 1, |
1105 | param->tile_size + data->shape[3] - 1, |
1106 | data->shape[0], |
1107 | data->shape[1], |
1108 | }; |
1109 | |
1110 | reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), data->dtype)); |
1111 | return true; |
1112 | } |
1113 | |
// FFI binding for the 2D winograd weight-transform constructor.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
    .set_body_typed([](Expr weight, int tile_size) {
      return MakeConvWinogradWeightTransform(weight, tile_size,
                                             "nn.contrib_conv2d_winograd_weight_transform");
    });
1119 | |
// Op registration for the standalone winograd weight transform; kept separate
// so the FoldConstant/precompute passes can evaluate it ahead of time.
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
    .describe(R"code(Weight transformation of winograd fast convolution algorithm.

Separate this into another operator in order to enable Precompute Pass to compute the
weight transformation in advance.

- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
)code" TVM_ADD_FILELINE)
    .set_attrs_type<ConvWinogradWeightTransformAttrs>()
    .set_num_inputs(1)  // weight
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(10)
    .add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1134 | |
// relay.nn.contrib_conv3d_winograd_without_weight_transform
// Register the attrs node for the 3D winograd convolution below.
TVM_REGISTER_NODE_TYPE(Conv3DWinogradAttrs);
1137 | |
1138 | bool Conv3DWinogradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1139 | const TypeReporter& reporter) { |
1140 | ICHECK_EQ(types.size(), 3); |
1141 | const auto* data = types[0].as<TensorTypeNode>(); |
1142 | if (data == nullptr) return false; |
1143 | static const Layout kNCDHW("NCDHW" ); |
1144 | static const Layout kOIDHW("OIDHW" ); |
1145 | |
1146 | const auto* param = attrs.as<Conv3DWinogradAttrs>(); |
1147 | ICHECK(param != nullptr); |
1148 | const Layout in_layout(param->data_layout); |
1149 | const Layout kernel_layout(param->kernel_layout); |
1150 | |
1151 | const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW); |
1152 | ICHECK(trans_in_layout.defined()) |
1153 | << "Conv only support input layouts that are convertible from NCDHW." |
1154 | << " But got " << in_layout; |
1155 | |
1156 | const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW); |
1157 | ICHECK(trans_kernel_layout.defined()) |
1158 | << "Conv only support kernel layouts that are convertible from OIDHW." |
1159 | << " But got " << kernel_layout; |
1160 | |
1161 | Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); |
1162 | const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW); |
1163 | ICHECK(trans_out_layout.defined()) |
1164 | << "Conv only support output layouts that are convertible from NCDHW." |
1165 | << " But got " << out_layout; |
1166 | |
1167 | Array<IndexExpr> dshape_ncdhw = trans_in_layout.ForwardShape(data->shape); |
1168 | |
1169 | IndexExpr channels, dilated_ksize_d, dilated_ksize_y, dilated_ksize_x; |
1170 | |
1171 | ICHECK(param->kernel_size.defined() && param->channels.defined()) |
1172 | << "The kernel size and channels of a Conv must be set or inferred by previous pass" ; |
1173 | |
1174 | ICHECK_EQ(param->kernel_size.size(), 3); |
1175 | ICHECK_EQ(param->dilation.size(), 3); |
1176 | |
1177 | channels = param->channels; |
1178 | dilated_ksize_d = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; |
1179 | dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; |
1180 | dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2]; |
1181 | |
1182 | // NOTE: Do not check weight shape here! |
1183 | // Different backend requires different layout to compute |
1184 | // the batch gemm stage in winograd efficiently, but we want to |
1185 | // make this op work for all backends. |
1186 | // So we accept all weight shapes, and assume the TOPI developers |
1187 | // can handle this correctly in alter_op_layout. |
1188 | |
1189 | // dilation |
1190 | Array<IndexExpr> oshape({dshape_ncdhw[0], channels, 0, 0, 0}); |
1191 | |
1192 | IndexExpr pad_d, pad_h, pad_w; |
1193 | GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); |
1194 | if (!dshape_ncdhw[2].as<tir::AnyNode>()) { |
1195 | oshape.Set(2, (dshape_ncdhw[2] + pad_d - dilated_ksize_d) / param->strides[0] + 1); |
1196 | } else { |
1197 | oshape.Set(2, dshape_ncdhw[2]); |
1198 | } |
1199 | if (!dshape_ncdhw[2].as<tir::AnyNode>()) { |
1200 | oshape.Set(3, (dshape_ncdhw[3] + pad_h - dilated_ksize_y) / param->strides[1] + 1); |
1201 | } else { |
1202 | oshape.Set(3, dshape_ncdhw[3]); |
1203 | } |
1204 | if (!dshape_ncdhw[4].as<tir::AnyNode>()) { |
1205 | oshape.Set(4, (dshape_ncdhw[4] + pad_w - dilated_ksize_x) / param->strides[2] + 1); |
1206 | } else { |
1207 | oshape.Set(4, dshape_ncdhw[4]); |
1208 | } |
1209 | |
1210 | DataType out_dtype = param->out_dtype; |
1211 | if (out_dtype.bits() == 0) { |
1212 | out_dtype = data->dtype; |
1213 | } |
1214 | oshape = trans_out_layout.BackwardShape(oshape); |
1215 | // assign output type |
1216 | reporter->Assign(types[2], TensorType(oshape, out_dtype)); |
1217 | return true; |
1218 | } |
1219 | |
// FFI binding: builds a winograd conv3d call (weight already pre-transformed)
// via the shared MakeConvWinograd helper.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_without_weight_transform")
    .set_body_typed([](Expr data, Expr weight, int tile_size, Array<IndexExpr> strides,
                       Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
                       IndexExpr channels, Array<IndexExpr> kernel_size, String data_layout,
                       String kernel_layout, String out_layout, DataType out_dtype) {
      return MakeConvWinograd<Conv3DWinogradAttrs>(
          data, weight, tile_size, strides, padding, dilation, groups, channels, kernel_size,
          data_layout, kernel_layout, out_layout, out_dtype,
          "nn.contrib_conv3d_winograd_without_weight_transform");
    });
1230 | |
// Op registration for the winograd conv3d that consumes a pre-transformed weight.
// The weight shape is deliberately unchecked (backend-dependent packing).
RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform")
    .describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout.
This operator assumes the weight tensor is already pre-transformed by
nn.contrib_conv3d_winograd_weight_transform.

- **data**: Input is 5D array of shape (batch_size, in_channels, depth, height, width)
- **weight**: Any shape
            We do not check the shape for this input tensor. Since different backend
            has different layout strategy.

- **out**: Output is 5D array of shape (batch_size, channels, depth, out_height, out_width)
)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv3DWinogradAttrs>()
    .set_num_inputs(2)  // data, weight
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(10)
    .add_type_rel("Conv3DWinograd", Conv3DWinogradRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                   ConvInferCorrectLayout<Conv3DWinogradAttrs>)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1252 | |
// relay.nn.contrib_conv3d_winograd_weight_transform
// FFI binding for the 3D winograd weight-transform constructor.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_weight_transform")
    .set_body_typed([](Expr weight, int tile_size) {
      return MakeConvWinogradWeightTransform(weight, tile_size,
                                             "nn.contrib_conv3d_winograd_weight_transform");
    });
1259 | |
1260 | bool Conv3DWinogradWeightTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1261 | const TypeReporter& reporter) { |
1262 | ICHECK_EQ(types.size(), 2); |
1263 | const auto* data = types[0].as<TensorTypeNode>(); |
1264 | if (data == nullptr) return false; |
1265 | |
1266 | const ConvWinogradWeightTransformAttrs* param = attrs.as<ConvWinogradWeightTransformAttrs>(); |
1267 | ICHECK(param != nullptr); |
1268 | |
1269 | ICHECK_EQ(data->shape.size(), 5) << "Only support NCDHW normal kernel layout" ; |
1270 | |
1271 | // Shape of packed weights depends on whether depth is being transformed or not. |
1272 | Array<IndexExpr> oshape({0, 0, 0, data->shape[0], data->shape[1]}); |
1273 | auto* depth_imm = data->shape[2].as<IntImmNode>(); |
1274 | bool transform_depth = (depth_imm->value > 2) && (depth_imm->value < 8); |
1275 | if (transform_depth) { |
1276 | oshape.Set(0, param->tile_size + data->shape[2] - 1); |
1277 | oshape.Set(1, param->tile_size + data->shape[3] - 1); |
1278 | oshape.Set(2, param->tile_size + data->shape[4] - 1); |
1279 | } else { |
1280 | oshape.Set(0, param->tile_size + data->shape[3] - 1); |
1281 | oshape.Set(1, param->tile_size + data->shape[4] - 1); |
1282 | oshape.Set(2, data->shape[2]); |
1283 | } |
1284 | |
1285 | reporter->Assign(types[1], TensorType(oshape, data->dtype)); |
1286 | return true; |
1287 | } |
1288 | |
// Op registration for the standalone 3D winograd weight transform; separated
// so the precompute pass can fold it ahead of time.
RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_weight_transform")
    .describe(R"code(Weight transformation of winograd fast 3d convolution algorithm.

Separate this into another operator in order to enable Precompute Pass to compute the
weight transformation in advance.

- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2])
)code" TVM_ADD_FILELINE)
    .set_attrs_type<ConvWinogradWeightTransformAttrs>()
    .set_num_inputs(1)  // weight
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(10)
    .add_type_rel("Conv3DWinogradWeightTransform", Conv3DWinogradWeightTransformRel)
    .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1303 | |
// relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
// Register the attrs node for the NNPACK winograd weight transform below.
TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs);
1306 | |
1307 | bool Conv2DWinogradNNPACKWeightTransformRel(const Array<Type>& types, int num_inputs, |
1308 | const Attrs& attrs, const TypeReporter& reporter) { |
1309 | ICHECK_EQ(types.size(), 2); |
1310 | const auto* data = types[0].as<TensorTypeNode>(); |
1311 | if (data == nullptr) { |
1312 | return false; |
1313 | } |
1314 | |
1315 | const Conv2DWinogradNNPACKWeightTransformAttrs* param = |
1316 | attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>(); |
1317 | ICHECK(param != nullptr); |
1318 | |
1319 | ICHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout" ; |
1320 | |
1321 | std::vector<IndexExpr> oshape{ |
1322 | data->shape[0], |
1323 | data->shape[1], |
1324 | 8, |
1325 | 8, |
1326 | }; |
1327 | |
1328 | DataType out_dtype = param->out_dtype; |
1329 | if (out_dtype.bits() == 0) { |
1330 | out_dtype = data->dtype; |
1331 | } |
1332 | reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), out_dtype)); |
1333 | return true; |
1334 | } |
1335 | |
1336 | Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, int convolution_algorithm, |
1337 | DataType out_dtype) { |
1338 | auto attrs = make_object<Conv2DWinogradNNPACKWeightTransformAttrs>(); |
1339 | attrs->convolution_algorithm = convolution_algorithm; |
1340 | attrs->out_dtype = std::move(out_dtype); |
1341 | static const Op& op = Op::Get("nn.contrib_conv2d_winograd_nnpack_weight_transform" ); |
1342 | return Call(op, {weight}, Attrs(attrs), {}); |
1343 | } |
1344 | |
// FFI binding for the NNPACK winograd weight-transform constructor.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform")
    .set_body_typed(MakeConv2DWinogradNNPACKWeightTransform);
1347 | |
// Op registration for the NNPACK winograd weight transform. Marked kOpaque:
// it must not be fused, since NNPACK executes it as a standalone routine.
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform")
    .describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK.
Separate this into another symbol in order to enable Precompute Pass to compute the
weight transformation in advance.

- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])

)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DWinogradNNPACKWeightTransformAttrs>()
    .set_num_inputs(1)  // weight
    .add_argument("weight", "Tensor", "The weight tensor.")
    .set_support_level(10)
    .add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel)
    .set_attr<TOpPattern>("TOpPattern", kOpaque);
1362 | |
// relay.nn.contrib_conv2d_gemm_without_weight_transform
// FFI entry: builds the GEMM-based conv2d call. The weight input is expected
// to already be pre-transformed by nn.contrib_conv2d_gemm_weight_transform.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_without_weight_transform" )
    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, tvm::String data_layout,
                       tvm::String kernel_layout, tvm::String out_layout, DataType out_dtype) {
      return MakeConvGemm<Conv2DAttrs>(
          data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout,
          kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_gemm_without_weight_transform" );
    });
1373 | |
1374 | bool Conv2DGemmRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1375 | const TypeReporter& reporter) { |
1376 | ICHECK_EQ(types.size(), 3); |
1377 | const auto* data = types[0].as<TensorTypeNode>(); |
1378 | if (data == nullptr) return false; |
1379 | static const Layout kNHWC("NHWC" ); |
1380 | static const Layout kHWIO("HWIO" ); |
1381 | |
1382 | const auto* param = attrs.as<Conv2DAttrs>(); |
1383 | ICHECK(param != nullptr); |
1384 | const Layout in_layout(param->data_layout); |
1385 | const Layout kernel_layout(param->kernel_layout); |
1386 | |
1387 | const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNHWC); |
1388 | ICHECK(trans_in_layout.defined()) |
1389 | << "Conv only support input layouts that are convertible from NHWC." |
1390 | << " But got " << in_layout; |
1391 | |
1392 | const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kHWIO); |
1393 | ICHECK(trans_kernel_layout.defined()) |
1394 | << "Conv only support kernel layouts that are convertible from HWIO." |
1395 | << " But got " << kernel_layout; |
1396 | |
1397 | Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); |
1398 | const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNHWC); |
1399 | ICHECK(trans_out_layout.defined()) |
1400 | << "Conv only support output layouts that are convertible from NHWC." |
1401 | << " But got " << out_layout; |
1402 | |
1403 | Array<IndexExpr> dshape_nhwc = trans_in_layout.ForwardShape(data->shape); |
1404 | |
1405 | IndexExpr channels, dilated_ksize_y, dilated_ksize_x; |
1406 | |
1407 | ICHECK(param->kernel_size.defined() && param->channels.defined()) |
1408 | << "The kernel size and channels of a Conv must be set or inferred by previous pass" ; |
1409 | |
1410 | ICHECK_EQ(param->kernel_size.size(), 2); |
1411 | ICHECK_EQ(param->dilation.size(), 2); |
1412 | |
1413 | channels = param->channels; |
1414 | dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; |
1415 | dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; |
1416 | |
1417 | // NOTE: Do not check weight shape here! |
1418 | |
1419 | // dilation |
1420 | Array<IndexExpr> oshape({dshape_nhwc[0], 0, 0, channels}); |
1421 | |
1422 | IndexExpr pad_h, pad_w; |
1423 | GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); |
1424 | if (!dshape_nhwc[2].as<tir::AnyNode>()) { |
1425 | oshape.Set(1, (dshape_nhwc[1] + pad_h - dilated_ksize_y) / param->strides[0] + 1); |
1426 | } else { |
1427 | oshape.Set(1, dshape_nhwc[1]); |
1428 | } |
1429 | if (!dshape_nhwc[3].as<tir::AnyNode>()) { |
1430 | oshape.Set(2, (dshape_nhwc[2] + pad_w - dilated_ksize_x) / param->strides[1] + 1); |
1431 | } else { |
1432 | oshape.Set(2, dshape_nhwc[2]); |
1433 | } |
1434 | |
1435 | DataType out_dtype = param->out_dtype; |
1436 | if (out_dtype.bits() == 0) { |
1437 | out_dtype = data->dtype; |
1438 | } |
1439 | oshape = trans_out_layout.BackwardShape(oshape); |
1440 | // assign output type |
1441 | reporter->Assign(types[2], TensorType(oshape, out_dtype)); |
1442 | return true; |
1443 | } |
1444 | |
// Operator registration for the GEMM conv2d that consumes a pre-transformed
// weight. Pattern kOutEWiseFusable lets elementwise ops fuse into its output.
RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_without_weight_transform" )
    .describe(R"code(Compute conv2d with gemm algorithm. Only supports NHWC layout.
This operator assumes the weight tensor is already pre-transformed by
nn.contrib_conv2d_gemm_weight_transform.

- **data**: Input is 4D array of shape (batch_size, height, width, in_channels)
- **weight**: Any shape
            We do not check the shape for this input tensor. Since different backend
            has different layout strategy.

- **out**:  Output is 4D array of shape (batch_size, channels, out_height, out_width)
)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DAttrs>()
    .set_num_inputs(2)
    .add_argument("data" , "Tensor" , "The input tensor." )
    .add_argument("weight" , "Tensor" , "The weight tensor." )
    .set_support_level(10)
    .add_type_rel("Conv2DGemm" , Conv2DGemmRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ConvInferCorrectLayout<Conv2DAttrs>)
    .set_attr<TOpPattern>("TOpPattern" , kOutEWiseFusable);
1465 | |
// relay.nn.contrib_conv2d_gemm_weight_transform

// Register the attrs node so ConvGemmWeightTransformAttrs is visible to the
// object system / FFI reflection.
TVM_REGISTER_NODE_TYPE(ConvGemmWeightTransformAttrs);
1469 | |
1470 | // Gemm convolution shape relations |
1471 | // In order to run GEMM we need to block-transpose and interleave the K x N weights matrix W. |
1472 | // The high level idea is to subdivide W in tiles of tile_cols x tile_rows, and transpose and |
1473 | // interleave them. The final output is a [N//tile_rows, K//tile_cols, tile_rows, tile_cols] |
1474 | // matrix that we call W_interleaved_t. |
1475 | // |
1476 | // In the following picture, we show how the first [tile_cols,tile_rows] block of W is transformed |
1477 | // for tile_rows = 4 and tile_cols = 16 |
1478 | // |
1479 | // W[0,0,:,:] W_interleaved_t[0,0,:,:] |
1480 | // +-------------------------------+ +----------------------------------- + |
1481 | // |W[0,0] W[0,1] W[0,2] W[0,3] | |W[0,0] W[1,0] W[2,0] ... W[15,0]| |
1482 | // |W[1,0] W[1,1] W[1,2] W[1,3] | --\ |W[0,1] W[1,1] W[2,1] ... W[15,1]| |
1483 | // |W[2,0] W[2,1] W[2,2] W[2,3] | --/ |W[0,2] W[1,2] W[2,2] ... W[15,2]| |
1484 | // | ... ... ... ... | |W[0,3] W[1,3] W[2,3] ... W[15,3]| |
1485 | // | ... ... ... ... | +------------------------------------+ |
1486 | // |W[15,0] W[15,1] W[15,2] W[15,3]| |
1487 | // +-------------------------------+ |
1488 | // |
1489 | // Tile columns is usually the direction of the reduction. So, if our target can reduce k elements |
1490 | // at the time, we should set tile_cols = k. |
1491 | // Tile rows is connected with the number of registers available for the given target. |
1492 | // |
1493 | bool Conv2DGemmWeightTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1494 | const TypeReporter& reporter) { |
1495 | ICHECK_EQ(types.size(), 2); |
1496 | const auto* weight = types[0].as<TensorTypeNode>(); |
1497 | if (weight == nullptr) return false; |
1498 | |
1499 | const ConvGemmWeightTransformAttrs* param = attrs.as<ConvGemmWeightTransformAttrs>(); |
1500 | ICHECK(param != nullptr); |
1501 | int n = param->tile_rows; |
1502 | int k = param->tile_cols; |
1503 | |
1504 | ICHECK_EQ(weight->shape.size(), 4) << "Only support HWIO kernel layout" ; |
1505 | |
1506 | const auto K = weight->shape[0] * weight->shape[1] * weight->shape[2]; |
1507 | const auto N = weight->shape[3]; |
1508 | |
1509 | auto K_mod_k = indexmod(K, k); |
1510 | auto N_mod_n = indexmod(N, n); |
1511 | |
1512 | auto pad_K = tvm::if_then_else(K_mod_k != 0, k - K_mod_k, tir::make_zero(DataType::Int(32))); |
1513 | auto pad_N = tvm::if_then_else(N_mod_n != 0, n - N_mod_n, tir::make_zero(DataType::Int(32))); |
1514 | |
1515 | const auto N_padded = N + pad_N; |
1516 | const auto K_padded = K + pad_K; |
1517 | |
1518 | Array<IndexExpr> oshape{ |
1519 | indexdiv(N_padded, n), |
1520 | indexdiv(K_padded, k), |
1521 | n, |
1522 | k, |
1523 | }; |
1524 | |
1525 | reporter->Assign(types[1], TensorType(oshape, weight->dtype)); |
1526 | return true; |
1527 | } |
1528 | |
// FFI entry: builds the GEMM weight pre-transform call with the given tile sizes.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_weight_transform" )
    .set_body_typed([](Expr weights, int tile_rows, int tile_cols) {
      return MakeConvGemmWeightTransform(weights, tile_rows, tile_cols,
                                         "nn.contrib_conv2d_gemm_weight_transform" );
    });

// Operator registration for the GEMM weight pre-transform. Its output feeds
// nn.contrib_conv2d_gemm_without_weight_transform.
RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_weight_transform" )
    .describe(R"code(Weight transformation of GEMM convolution algorithm.

Separate this into another operator in order to enable Precompute Pass to compute the
weight transformation in advance.

)code" TVM_ADD_FILELINE)
    .set_attrs_type<ConvGemmWeightTransformAttrs>()
    .set_num_inputs(1)
    .add_argument("weights" , "Tensor" , "The weights tensor." )
    .set_support_level(10)
    .add_type_rel("Conv2DGemmWeightTransform" , Conv2DGemmWeightTransformRel)
    .set_attr<TOpPattern>("TOpPattern" , kOutEWiseFusable);
1548 | |
// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc" )
    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
                       String out_layout, DataType out_dtype) {
      return MakeConv<Conv2DAttrs>(data, weight, strides, padding, dilation, groups, channels,
                                   kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
                                   "nn.contrib_conv2d_NCHWc" );
    });

// Operator registration for conv2d on packed NCHWc data.
// NOTE(review): this reuses Conv2DWinogradRel as the type relation;
// presumably because (like this op) that relation skips weight-shape
// checking, which cannot be applied to the 6D packed weight — confirm.
RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc" )
    .describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout.
- **data**: Input is 5D packed tensor.
- **weight**: 6D packed tensor.

- **out**:  Output is 5D packed tensor
)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DAttrs>()
    .set_num_inputs(2)
    .add_argument("data" , "Tensor" , "The input tensor." )
    .add_argument("weight" , "Tensor" , "The weight tensor." )
    .set_support_level(10)
    .add_type_rel("Conv2DNCHWc" , Conv2DWinogradRel<Conv2DAttrs>)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ConvInferCorrectLayout<Conv2DAttrs>)
    .set_attr<TOpPattern>("TOpPattern" , kOutEWiseFusable);
1576 | |
// Positional relay function to create depthwise conv2d NCHWc operator
// used by frontend FFI.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc" )
    .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, String data_layout, String kernel_layout,
                       String out_layout, DataType out_dtype) {
      return MakeConv<Conv2DAttrs>(data, weight, strides, padding, dilation, groups, channels,
                                   kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
                                   "nn.contrib_depthwise_conv2d_NCHWc" );
    });

// Operator registration for depthwise conv2d on packed NCHWc data.
// NOTE(review): the type relation here is the generic Conv2DRel — confirm it
// is intended for the packed 5D data / 6D weight described below.
RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc" )
    .describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout.
- **data**: Input is 5D packed tensor.
- **weight**: 6D packed tensor.

- **out**:  Output is 5D packed tensor
)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DAttrs>()
    .set_num_inputs(2)
    .add_argument("data" , "Tensor" , "The input tensor." )
    .add_argument("weight" , "Tensor" , "The weight tensor." )
    .set_support_level(10)
    .add_type_rel("Conv2D" , Conv2DRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ConvInferCorrectLayout<Conv2DAttrs>)
    .set_attr<TOpPattern>("TOpPattern" , kOutEWiseFusable);
1604 | |
TVM_REGISTER_NODE_TYPE(DeformableConv2DAttrs);

// Deformable Convolution shape relations.
//
// types = [data_type, offset_type, weight_type, output_type]. Shapes are
// normalized to NCHW / OIHW for computation. Both the offset type and
// (when kernel_size/channels attrs are set) the weight type are assigned
// here in addition to the output type.
bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                         const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 4);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* weight = types[2].as<TensorTypeNode>();

  ICHECK(data);
  static const Layout kNCHW("NCHW" );
  static const Layout kOIHW("OIHW" );

  auto* param = attrs.as<DeformableConv2DAttrs>();
  ICHECK(param != nullptr);
  const Layout in_layout(param->data_layout);
  const Layout kernel_layout(param->kernel_layout);

  // Layout convertibility is reported as a diagnostic (not ICHECK) so the
  // user sees a proper type error with a span.
  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
  if (!trans_in_layout.defined()) {
    reporter->GetDiagCtx().Emit(
        Diagnostic::Error(reporter->GetSpan())
        << "deformable_conv2d only support input layouts that are convertible from NCHW."
        << " The provided layout is: " << in_layout);
    return false;
  }

  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
  if (!trans_kernel_layout.defined()) {
    reporter->GetDiagCtx().Emit(
        Diagnostic::Error(reporter->GetSpan())
        << "deformable_conv2d only support kernel layouts that are convertible from OIHW."
        << " The provided layout is: " << kernel_layout);
    return false;
  }

  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
  if (!trans_out_layout.defined()) {
    reporter->GetDiagCtx().Emit(
        Diagnostic::Error(reporter->GetSpan())
        << "deformable_conv2d only support output layouts that are convertible from NCHW."
        << "The provided layout is: " << out_layout);
    return false;
  }

  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);

  IndexExpr channels, dilated_ksize_y, dilated_ksize_x, ksize_y, ksize_x;

  // infer weight shape if kernel_size and channels are defined
  if (param->kernel_size.defined() && param->channels.defined()) {
    ICHECK_EQ(param->kernel_size.size(), 2);
    ICHECK_EQ(param->dilation.size(), 2);
    // Weight shape in OIHW, derived from the attrs.
    Array<IndexExpr> wshape({param->channels, indexdiv(dshape_nchw[1], param->groups),
                             param->kernel_size[0], param->kernel_size[1]});

    wshape = trans_kernel_layout.BackwardShape(wshape);
    channels = param->channels;
    ksize_y = param->kernel_size[0];
    ksize_x = param->kernel_size[1];
    // Effective kernel extents once dilation is applied.
    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
    // assign result to reporter
    reporter->Assign(types[2], TensorType(wshape, data->dtype));
  } else {
    // use weight to infer the conv shape.
    if (weight == nullptr) return false;
    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);

    // Cross-check any attrs that ARE set against the weight shape.
    if (param->kernel_size.defined()) {
      ICHECK_EQ(param->kernel_size.size(), 2);
      // check the size
      ICHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
             reporter->AssertEQ(param->kernel_size[1], wshape[3]))
          << "DeformableConv2D: shape of weight is inconsistent with kernel_size, "
          << " kernel_size=" << param->kernel_size << " wshape=" << wshape;
    }
    if (param->channels.defined()) {
      ICHECK(reporter->AssertEQ(param->channels, wshape[0]))
          << "DeformableConv2D: shape of weight is inconsistent with channels, "
          << " channels=" << param->channels << " wshape=" << wshape;
    }
    if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
      ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1]));
    }
    channels = wshape[0];
    ksize_y = wshape[2];
    ksize_x = wshape[3];
    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
  }
  // dilation
  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});

  IndexExpr pad_h, pad_w;
  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
  oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1);
  oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1);
  DataType out_dtype = param->out_dtype;

  // infer offset shape: one (dy, dx) pair per kernel element per deformable
  // group, at every output position.
  Array<IndexExpr> offset_shape(
      {dshape_nchw[0], 2 * ksize_y * ksize_x * param->deformable_groups, oshape[2], oshape[3]});
  offset_shape = trans_in_layout.BackwardShape(offset_shape);
  reporter->Assign(types[1], TensorType(offset_shape, data->dtype));
  if (out_dtype.bits() == 0) {
    // Unspecified output dtype: inherit from the input.
    out_dtype = data->dtype;
  }

  oshape = trans_out_layout.BackwardShape(oshape);
  reporter->Assign(types[3], TensorType(oshape, out_dtype));
  return true;
}
1719 | |
1720 | InferCorrectLayoutOutput DeformableConvInferCorrectLayout( |
1721 | const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts, |
1722 | const Array<tvm::relay::Type>& old_in_types) { |
1723 | const auto* params = attrs.as<DeformableConv2DAttrs>(); |
1724 | return InferCorrectLayoutOutput( |
1725 | {params->data_layout, params->data_layout, params->kernel_layout}, |
1726 | {params->out_layout == "" ? params->data_layout : params->out_layout}, attrs); |
1727 | } |
1728 | |
// Operator registration for deformable conv2d (https://arxiv.org/abs/1703.06211).
RELAY_REGISTER_OP("nn.deformable_conv2d" )
    .describe(R"code(Compute 2-D deformable convolution on 4-D input.
The deformable convolution operation is described in https://arxiv.org/abs/1703.06211

For 2-D deformable convolution, the shapes are
- **data**: (batch_size, channel, height, width)
- **offset**: (batch_size, deformable_groups * kernel[0] * kernel[1] * 2, out_height, out_width)
- **weight**: (num_filter, channel, kernel[0], kernel[1])
- **out**: (batch_size, num_filter, out_height, out_width).

If `deformable_groups` is larger than 1, denoted by *dg*, then split the
input `offset` evenly into *dg* parts along the channel axis, and also evenly split `out`
evenly into *dg* parts along the channel axis. Next compute the deformable convolution, apply the
*i*-th part of the offset part on the *i*-th out.

If `groups` is larger than 1, denoted by *g*, then split the input `data` evenly into *g* parts
along the channel axis, and also evenly split `weight` along the first dimension. Next compute
the convolution on the *i*-th part of the data with the *i*-th weight part. The output is obtained
by concating all the *g* results.
)code" TVM_ADD_FILELINE)
    .set_attrs_type<DeformableConv2DAttrs>()
    .set_num_inputs(3)
    .add_argument("data" , "Tensor" , "The input tensor." )
    .add_argument("offset" , "Tensor" , "The offset tensor." )
    .add_argument("weight" , "Tensor" , "The weight tensor." )
    .set_support_level(5)
    .add_type_rel("DeformableConv2D" , DeformableConv2DRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , DeformableConvInferCorrectLayout)
    .set_attr<TOpPattern>("TOpPattern" , kOutEWiseFusable);

// Positional relay function to create deformable_conv2d operator
// used by frontend FFI.
TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d" )
    .set_body_typed([](Expr data, Expr offset, Expr weight, Array<IndexExpr> strides,
                       Array<IndexExpr> padding, Array<IndexExpr> dilation, int deformable_groups,
                       int groups, int channels, Array<IndexExpr> kernel_size, String data_layout,
                       String kernel_layout, String out_layout, DataType out_dtype) {
      return MakeDeformableConv<DeformableConv2DAttrs>(
          data, offset, weight, strides, padding, dilation, deformable_groups, groups, channels,
          kernel_size, data_layout, kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d" );
    });
1770 | |
1771 | inline Expr MakeConv2dBackwardWeight(Expr grad, Expr data, Array<IndexExpr> strides, |
1772 | Array<IndexExpr> padding, Array<IndexExpr> dilation, |
1773 | int groups, IndexExpr channels, Array<IndexExpr> kernel_size, |
1774 | std::string grad_layout, std::string data_layout, |
1775 | std::string kernel_layout, DataType out_dtype) { |
1776 | auto attrs = make_object<Conv2DAttrs>(); |
1777 | attrs->strides = std::move(strides); |
1778 | attrs->padding = std::move(padding); |
1779 | attrs->dilation = std::move(dilation); |
1780 | attrs->groups = groups; |
1781 | attrs->channels = std::move(channels); |
1782 | attrs->kernel_size = std::move(kernel_size); |
1783 | attrs->out_dtype = std::move(out_dtype); |
1784 | attrs->data_layout = std::move(grad_layout); |
1785 | attrs->kernel_layout = std::move(data_layout); |
1786 | attrs->out_layout = std::move(kernel_layout); |
1787 | const Op& op = Op::Get("nn.conv2d_backward_weight" ); |
1788 | return Call(op, {grad, data}, Attrs(attrs), {}); |
1789 | } |
1790 | |
// FFI entry: exposes MakeConv2dBackwardWeight to the Python frontend.
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_backward_weight" )
    .set_body_typed([](Expr grad, Expr data, Array<IndexExpr> strides, Array<IndexExpr> padding,
                       Array<IndexExpr> dilation, int groups, IndexExpr channels,
                       Array<IndexExpr> kernel_size, String grad_layout, String data_layout,
                       String kernel_layout, DataType out_dtype) {
      return MakeConv2dBackwardWeight(grad, data, strides, padding, dilation, groups, channels,
                                      kernel_size, grad_layout, data_layout, kernel_layout,
                                      out_dtype);
    });
1800 | |
1801 | bool Conv2DBackwardWeightRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1802 | const TypeReporter& reporter) { |
1803 | ICHECK_EQ(types.size(), 3); |
1804 | const auto* grad = types[0].as<TensorTypeNode>(); |
1805 | const auto* data = types[1].as<TensorTypeNode>(); |
1806 | if (data == nullptr) return false; |
1807 | |
1808 | static const Layout kNCHW("NCHW" ); |
1809 | static const Layout kOIHW("OIHW" ); |
1810 | |
1811 | const auto* param = attrs.as<Conv2DAttrs>(); |
1812 | ICHECK(param != nullptr); |
1813 | // Require kernel_size to be passed, to simplify the output shape determination. |
1814 | ICHECK(param->kernel_size.defined()) << "kernel_size attribute needs to be specified" ; |
1815 | |
1816 | // We repurpose Conv2dAttrs for Conv2DBackwardWeight, note the meanings of layouts. |
1817 | const Layout grad_layout(param->data_layout); |
1818 | const Layout in_layout(param->kernel_layout); |
1819 | const Layout kernel_layout(param->out_layout); |
1820 | |
1821 | const auto trans_grad_layout = tir::BijectiveLayout(grad_layout, kNCHW); |
1822 | const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); |
1823 | const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); |
1824 | |
1825 | Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape); |
1826 | Array<IndexExpr> grad_shape_nchw = trans_grad_layout.ForwardShape(grad->shape); |
1827 | |
1828 | auto in_channels = dshape_nchw[1]; |
1829 | auto out_channels = grad_shape_nchw[1]; |
1830 | |
1831 | auto in_channels_intimm = in_channels.as<IntImmNode>(); |
1832 | auto out_channels_intimm = out_channels.as<IntImmNode>(); |
1833 | ICHECK(in_channels_intimm); |
1834 | ICHECK(out_channels_intimm); |
1835 | |
1836 | IndexExpr weight_dim_i; |
1837 | if (in_channels_intimm->value == out_channels_intimm->value && |
1838 | in_channels_intimm->value == param->groups) { |
1839 | // depthwise |
1840 | ICHECK(param->channels.defined()) |
1841 | << "out_channels attribute not specified for depth wise conv2d." ; |
1842 | weight_dim_i = indexdiv(param->channels, param->groups); |
1843 | } else { |
1844 | weight_dim_i = indexdiv(in_channels, param->groups); |
1845 | } |
1846 | |
1847 | Array<IndexExpr> wshape_oihw{out_channels, weight_dim_i, param->kernel_size[0], |
1848 | param->kernel_size[1]}; |
1849 | auto wshape = trans_kernel_layout.BackwardShape(wshape_oihw); |
1850 | |
1851 | const auto dw_dtype = (param->out_dtype == DataType() || param->out_dtype.is_void()) |
1852 | ? grad->dtype |
1853 | : param->out_dtype; |
1854 | |
1855 | reporter->Assign(types[2], TensorType(wshape, dw_dtype)); |
1856 | return true; |
1857 | } |
1858 | |
// Operator registration for the conv2d weight-gradient op.
RELAY_REGISTER_OP("nn.conv2d_backward_weight" )
    .describe(R"code(The gradient of the 2D convolution layer with respect to the weight.

This layer computes the gradient of the conv2d op with respect to weight,
given the original input data and the output gradient.

- **grad**: (batch, channels, out_height, out_width) if `layout` is `NCHW`.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **out**:  This depends on the `layout` parameter. Output is 4D array of shape
            (channels, in_channels, kernel_size[0], kernel_size[1]) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
    .set_attrs_type<Conv2DAttrs>()
    .set_num_inputs(2)
    .add_argument("grad" , "Tensor" , "The gradient tensor." )
    .add_argument("data" , "Tensor" , "The input tensor." )
    .set_support_level(2)
    .add_type_rel("Conv2DBackwardWeight" , Conv2DBackwardWeightRel)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ConvInferCorrectLayout<Conv2DAttrs>)
    .set_attr<TOpPattern>("TOpPattern" , kOutEWiseFusable);
1879 | |
1880 | } // namespace relay |
1881 | } // namespace tvm |
1882 | |