1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file nn.cc
22 * \brief Property def of nn operators.
23 */
24
25#include "nn.h"
26
27#include <tvm/auto_scheduler/compute_dag.h>
28#include <tvm/relay/attrs/image.h>
29#include <tvm/relay/attrs/nn.h>
30#include <tvm/relay/op.h>
31#include <tvm/tir/data_layout.h>
32#include <tvm/topi/nn.h>
33#include <tvm/topi/nn/bias_add.h>
34#include <tvm/topi/nn/flatten.h>
35#include <tvm/topi/nn/softmax.h>
36
37#include <algorithm>
38#include <string>
39#include <vector>
40
41#include "../../transforms/infer_layout_utils.h"
42#include "../make_op.h"
43#include "../op_common.h"
44#include "../type_relations.h"
45
46namespace tvm {
47namespace relay {
48
49// relay.nn.bias_add
50TVM_REGISTER_NODE_TYPE(BiasAddAttrs);
51
52bool BiasAddRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
53 const TypeReporter& reporter) {
54 ICHECK_EQ(types.size(), 3);
55 const auto* data = types[0].as<TensorTypeNode>();
56 if (data == nullptr) return false;
57
58 const BiasAddAttrs* param = attrs.as<BiasAddAttrs>();
59 ICHECK(param != nullptr);
60 int axis = param->axis;
61 if (axis < 0) {
62 axis = data->shape.size() + axis;
63 }
64 if (axis >= static_cast<int>(data->shape.size()) || axis < 0) {
65 reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
66 << "The axis in bias_add must be in range for the shape; "
67 << "attempted to access index " << param->axis << " of "
68 << PrettyPrint(data->shape));
69 return false;
70 }
71
72 // assign output type
73 reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
74 reporter->Assign(types[2], types[0]);
75 return true;
76}
77
// Positional relay function to create bias_add operator used by frontend FFI.
79Expr MakeBiasAdd(Expr data, Expr bias, int axis) {
80 auto attrs = make_object<BiasAddAttrs>();
81 attrs->axis = axis;
82 static const Op& op = Op::Get("nn.bias_add");
83 return Call(op, {data, bias}, Attrs(attrs), {});
84}
85
86TVM_REGISTER_GLOBAL("relay.op.nn._make.bias_add").set_body_typed(MakeBiasAdd);
87
88RELAY_REGISTER_OP("nn.bias_add")
89 .describe(R"code(Add bias to an axis of the input.
90
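For example (illustrative shapes, assuming an NCHW input)::

  data : (1, 64, 56, 56)
  bias : (64,)
  bias_add(data, bias, axis=1) : (1, 64, 56, 56)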
91)code" TVM_ADD_FILELINE)
92 .set_attrs_type<BiasAddAttrs>()
93 .set_num_inputs(2)
94 .add_argument("data", "nD Tensor", "Input data.")
95 .add_argument("bias", "1D Tensor", "Bias.")
96 .set_support_level(1)
97 .add_type_rel("BiasAdd", BiasAddRel)
98 .set_attr<TOpPattern>("TOpPattern", kBroadcast)
99 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
100 const Type& out_type) {
101 const auto* param = attrs.as<BiasAddAttrs>();
102 return tvm::Array<tvm::te::Tensor>{topi::nn::bias_add(inputs[0], inputs[1], param->axis)};
103 });
104
105// relay.nn.fifo_buffer
106TVM_REGISTER_NODE_TYPE(FIFOBufferAttrs);
107
108Expr MakeFIFOBuffer(Expr input, Expr buffer, int axis) {
109 auto attrs = make_object<FIFOBufferAttrs>();
110 attrs->axis = axis;
111 static const Op& op = Op::Get("nn.fifo_buffer");
112 return Call(op, {input, buffer}, Attrs(attrs), {});
113}
114
115bool FIFOBufferRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
116 const TypeReporter& reporter) {
117 ICHECK_EQ(types.size(), 3);
118 const auto* input = types[0].as<TensorTypeNode>();
119 const auto* buffer = types[1].as<TensorTypeNode>();
120 const FIFOBufferAttrs* param = attrs.as<FIFOBufferAttrs>();
121 if (input == nullptr || buffer == nullptr) {
122 return false;
123 }
124 ICHECK(param != nullptr);
125 ICHECK_EQ(input->shape.size(), buffer->shape.size());
126
127 const size_t buffer_axis = static_cast<size_t>(
128 param->axis < 0 ? static_cast<int>(buffer->shape.size()) + param->axis : param->axis);
129
130 reporter->Assert(buffer_axis < buffer->shape.size());
131 for (size_t i = 0; i < buffer->shape.size(); ++i) {
132 if (i != buffer_axis) {
133 reporter->AssertEQ(input->shape[i], buffer->shape[i]);
134 }
135 }
136 reporter->Assert(input->shape[buffer_axis] < buffer->shape[buffer_axis]);
137
138 Array<tvm::PrimExpr> oshape = buffer->shape;
139
140 reporter->Assign(types[2], TensorType(oshape, buffer->dtype));
141 return true;
142}
143
144TVM_REGISTER_GLOBAL("relay.op.nn._make.fifo_buffer").set_body_typed(MakeFIFOBuffer);
145
146RELAY_REGISTER_OP("nn.fifo_buffer")
147 .describe(R"code(FIFO buffer
148Compute equivalent of
149
150```
151concat(buffer, data, axis=axis) \
152.slice_axis(axis=axis, begin=data.shape[axis], end=data.shape[axis]+buffer.shape[axis])
153```
154
155Useful for
* Encoding explicit re-use of computation in convolution ops operating on a sliding window input
157* Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet.
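
For example (illustrative shapes, axis=1, keeping the last 4 steps)::

  buffer : (1, 4, 3)   # previously cached steps
  data   : (1, 1, 3)   # one new step
  fifo_buffer(data, buffer, axis=1) : (1, 4, 3)
  # the oldest step in `buffer` is dropped and `data` is appended at the end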
158)code" TVM_ADD_FILELINE)
159 .set_attrs_type<FIFOBufferAttrs>()
160 .set_num_inputs(2)
161 .add_argument("data", "Tensor", "Latest input")
162 .add_argument("buffer", "Tensor", "Buffer storing latest [length_buffer] inputs")
163 .set_support_level(3)
164 .add_type_rel("FIFOBuffer", FIFOBufferRel)
165 .set_attr<TOpPattern>("TOpPattern", kOpaque);
166
167// ------------------- relay.nn.matmul
168TVM_REGISTER_NODE_TYPE(MatmulAttrs);
169
170Expr MakeMatmul(Expr tensor_a, Expr tensor_b, IndexExpr units, DataType out_dtype, bool transpose_a,
171 bool transpose_b) {
172 auto attrs = make_object<MatmulAttrs>();
173 attrs->units = units;
174 attrs->out_dtype = out_dtype;
175 attrs->transpose_a = transpose_a;
176 attrs->transpose_b = transpose_b;
177 static const Op& matmul_op = Op::Get("nn.matmul");
178 return Call(matmul_op, {tensor_a, tensor_b}, Attrs(attrs), {});
179}
180
181TVM_REGISTER_GLOBAL("relay.op.nn._make.matmul").set_body_typed(MakeMatmul);
182
183RELAY_REGISTER_OP("nn.matmul")
184 .describe(R"code(Applies a linear transformation: :math:`C = A * B`. A & B can be transposed.
185
186- **tensor_a**: `(x1, x2, ..., xn, input_dim)` or `(x1, x2, ..., input_dim, xn)`
187- **tensor_b**: `(input_dim, units)` or `(units, input_dim)`
188- **out**: `(x1, x2, ..., xn, units)`.
189
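For example (illustrative shapes)::

  tensor_a : (8, 16, 32), tensor_b : (32, 10), transpose_a=False, transpose_b=False
  out      : (8, 16, 10)

  with transpose_b=True the same result expects tensor_b : (10, 32)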
190)code" TVM_ADD_FILELINE)
191 .set_attrs_type<MatmulAttrs>()
192 .set_num_inputs(2)
193 .add_argument("tensor_a", "nD Tensor", "The first input Tensor.")
194 .add_argument("tensor_b", "2D Tensor", "The second input Tensor.")
195 .set_support_level(1)
196 .add_type_rel("Matmul", MatmulRel<MatmulAttrs>)
197 .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
198
199// ------------------- relay.nn.matmul
200
201// ------------------- relay.nn.dense
202TVM_REGISTER_NODE_TYPE(DenseAttrs);
203
204// Positional relay function to create dense operator used by frontend FFI.
205Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
206 auto attrs = make_object<DenseAttrs>();
207 attrs->units = units;
208 attrs->out_dtype = out_dtype;
209 static const Op& op = Op::Get("nn.dense");
210 return Call(op, {data, weight}, Attrs(attrs), {});
211}
212
213InferCorrectLayoutOutput DenseInferCorrectLayout(const Attrs& attrs,
214 const Array<Layout>& new_in_layouts,
215 const Array<Layout>& old_in_layouts,
216 const Array<tvm::relay::Type>& old_in_types) {
217 return InferCorrectLayoutOutput({"NC", "NC"}, {"NC"}, attrs);
218}
219
220TVM_REGISTER_GLOBAL("relay.op.nn._make.dense").set_body_typed(MakeDense);
221
222RELAY_REGISTER_OP("nn.dense")
223 .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.
224
225- **data**: `(x1, x2, ..., xn, input_dim)`
226- **weight**: `(units, input_dim)`
227- **out**: `(x1, x2, ..., xn, units)`.
228
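For example (illustrative shapes)::

  data   : (32, 100)
  weight : (10, 100)
  out    : (32, 10)    # Y = X * W^T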
229)code" TVM_ADD_FILELINE)
230 .set_attrs_type<DenseAttrs>()
231 .set_num_inputs(2)
232 .add_argument("data", "nD Tensor", "Input data.")
233 .add_argument("weight", "2D Tensor", "Weight matrix.")
234 .set_support_level(1)
235 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", DenseInferCorrectLayout)
236 .add_type_rel("Dense", MatmulRel<DenseAttrs>)
237 .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
238// ------------------- relay.nn.dense
239
240// ------------------- relay.nn.contrib_dense_pack
241TVM_REGISTER_NODE_TYPE(DensePackAttrs);
242
243// Positional relay function to create dense_pack operator used by frontend FFI.
244Expr MakeDensePack(Expr data, Expr weight, tvm::String weight_layout, IndexExpr units,
245 DataType out_dtype) {
246 auto attrs = make_object<DensePackAttrs>();
247 attrs->units = units;
248 attrs->out_dtype = out_dtype;
249 attrs->weight_layout = std::move(weight_layout);
250 static const Op& op = Op::Get("nn.contrib_dense_pack");
251 return Call(op, {data, weight}, Attrs(attrs), {});
252}
253
254TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_dense_pack").set_body_typed(MakeDensePack);
255
256bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
257 const TypeReporter& reporter) {
258 ICHECK_EQ(types.size(), 3);
259 const auto* data = types[0].as<TensorTypeNode>();
260 const auto* weight = types[1].as<TensorTypeNode>();
261 if (data == nullptr || weight == nullptr) return false;
262
263 const DensePackAttrs* param = attrs.as<DensePackAttrs>();
264 ICHECK(param != nullptr);
265
266 ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported";
267 ICHECK(weight->shape.size() == 3 || weight->shape.size() == 4) << "Expect weight to be 3D or 4D";
268
269 Array<tvm::PrimExpr> oshape = data->shape;
270 oshape.Set(1, weight->shape[0] * weight->shape[2]);
271
272 DataType out_dtype = param->out_dtype;
273 if (out_dtype.bits() == 0) {
274 out_dtype = data->dtype;
275 }
276 // assign output type
277 reporter->Assign(types[2], TensorType(oshape, out_dtype));
278 return true;
279}
280
281InferCorrectLayoutOutput DensePackInferCorrectLayout(const Attrs& attrs,
282 const Array<Layout>& new_in_layouts,
283 const Array<Layout>& old_in_layouts,
284 const Array<tvm::relay::Type>& old_in_types) {
285 auto params = attrs.as<DensePackAttrs>();
286 ICHECK(params);
287 return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, attrs);
288}
289
290RELAY_REGISTER_OP("nn.contrib_dense_pack")
291 .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.
292
293- **data**: `(batch, input_dim)`
294- **weight**: `(units // pack_weight_tile, input_dim, pack_weight_tile)`
295- **out**: `(batch, units)`.
296
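For example (illustrative shapes, units=32 packed with pack_weight_tile=8)::

  data   : (8, 64)
  weight : (4, 64, 8)   # (32 // 8, 64, 8)
  out    : (8, 32)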
297)code" TVM_ADD_FILELINE)
    .set_attrs_type<DensePackAttrs>()
299 .set_num_inputs(2)
300 .add_argument("data", "2D Tensor", "Input data.")
301 .add_argument("weight", "3D Tensor", "Packed weight matrix.")
302 .set_support_level(10)
303 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", DensePackInferCorrectLayout)
304 .add_type_rel("DensePack", DensePackRel)
305 .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
306
307// ------------------- relay.nn.contrib_dense_pack
308
309// relay.leaky_relu
310TVM_REGISTER_NODE_TYPE(LeakyReluAttrs);
311
312// Positional relay function to create leaky relu operator used by frontend FFI.
313Expr MakeLeakyRelu(Expr data, double alpha) {
314 auto attrs = make_object<LeakyReluAttrs>();
315 attrs->alpha = alpha;
316 static const Op& op = Op::Get("nn.leaky_relu");
317 return Call(op, {data}, Attrs(attrs), {});
318}
319
320TVM_REGISTER_GLOBAL("relay.op.nn._make.leaky_relu").set_body_typed(MakeLeakyRelu);
321
322RELAY_REGISTER_OP("nn.leaky_relu")
323 .describe(R"code(Leaky version of a Rectified Linear Unit.
324
325`y = x > 0 ? x : alpha * x`
326
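For example (illustrative values, alpha=0.1)::

  x = [-2.0, 0.0, 3.0]
  y = [-0.2, 0.0, 3.0]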
327)code" TVM_ADD_FILELINE)
328 .set_attrs_type<LeakyReluAttrs>()
329 .set_num_inputs(1)
330 .add_argument("data", "Tensor", "Input data.")
331 .set_support_level(3)
332 .add_type_rel("Identity", IdentityRel)
333 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
334 .set_attr<TOpPattern>("TOpPattern", kElemWise)
335 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
336 const Type& out_type) {
337 const auto* param = attrs.as<LeakyReluAttrs>();
338 return Array<te::Tensor>{topi::leaky_relu(inputs[0], param->alpha)};
339 });
340
341// relay.prelu
342TVM_REGISTER_NODE_TYPE(PReluAttrs);
343
344bool PReluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
345 const TypeReporter& reporter) {
346 ICHECK_EQ(types.size(), 3);
347 const auto* data = types[0].as<TensorTypeNode>();
348 if (data == nullptr) return false;
349
350 const PReluAttrs* param = attrs.as<PReluAttrs>();
351 ICHECK(param != nullptr);
352
353 ICHECK(param->axis < static_cast<int>(data->shape.size()))
354 << "Wrong axis (" << param->axis << ")value.";
355
356 // assign alpha type
357 Array<IndexExpr> alpha_shape({data->shape[param->axis]});
358 reporter->Assign(types[1], TensorType(alpha_shape, data->dtype));
359
360 // assign output type
361 reporter->Assign(types[2], TensorType(data->shape, data->dtype));
362 return true;
363}
364
365InferCorrectLayoutOutput PReluInferCorrectLayout(const Attrs& attrs,
366 const Array<Layout>& new_in_layouts,
367 const Array<Layout>& old_in_layouts,
368 const Array<tvm::relay::Type>& old_in_types) {
369 ICHECK_EQ(old_in_layouts.size(), 2U);
370 ICHECK_EQ(old_in_types.size(), 2U);
371 Layout data_layout = old_in_layouts[0];
372 if (new_in_layouts.defined()) {
373 ICHECK_EQ(new_in_layouts.size(), 2U);
374 }
375 return InferCorrectLayoutOutput({data_layout, Layout("C")}, {data_layout}, attrs);
376}
377
378// Positional relay function to create prelu operator used by frontend FFI.
379Expr MakePRelu(Expr data, Expr alpha, int axis) {
380 auto attrs = make_object<PReluAttrs>();
381 attrs->axis = axis;
382 static const Op& op = Op::Get("nn.prelu");
383 return Call(op, {data, alpha}, Attrs(attrs), {});
384}
385
386TVM_REGISTER_GLOBAL("relay.op.nn._make.prelu").set_body_typed(MakePRelu);
387
388RELAY_REGISTER_OP("nn.prelu")
389 .describe(R"code(Parametric version of a Rectified Linear Unit.
390It accepts two arguments: an input ``x`` and a channelwise slope ``alpha``
and computes the output as :math:`y = x > 0 ? x : alpha * x`,
where :math:`*` is a channelwise multiplication for each sample in the batch.
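
For example (illustrative shapes, axis=1)::

  data  : (1, 3, 224, 224)
  alpha : (3,)            # one slope per channel
  out   : (1, 3, 224, 224)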
393)code" TVM_ADD_FILELINE)
394 .set_attrs_type<PReluAttrs>()
395 .set_num_inputs(2)
396 .add_argument("data", "Tensor", "Input data.")
397 .add_argument("alpha", "Tensor", "Input channelwise alpha.")
398 .set_support_level(3)
399 .add_type_rel("PRelu", PReluRel)
400 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout)
401 .set_attr<TOpPattern>("TOpPattern", kBroadcast)
402 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
403 const Type& out_type) {
404 const auto* param = attrs.as<PReluAttrs>();
405 return Array<te::Tensor>{topi::prelu(inputs[0], inputs[1], param->axis)};
406 });
407
408// relay.softmax
409TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
410
411bool SoftmaxRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
412 const TypeReporter& reporter) {
413 ICHECK_EQ(types.size(), 2);
414 const auto* data = types[0].as<TensorTypeNode>();
415 if (data == nullptr) return false;
416
417 const SoftmaxAttrs* param = attrs.as<SoftmaxAttrs>();
418 ICHECK(param != nullptr);
419 int axis = param->axis;
420 int ndim = static_cast<int>(data->shape.size());
421 if (axis >= ndim || axis < -ndim) {
422 reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
423 << "Wrong axis (" << axis << ") not in expected range: ["
424 << -ndim << ", " << ndim << ")");
425 return false;
426 }
427
428 reporter->Assign(types[1], types[0]);
429 return true;
430}
431
432TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax").set_body_typed([](Expr data, int axis) {
433 auto attrs = make_object<SoftmaxAttrs>();
434 attrs->axis = axis;
435 static const Op& op = Op::Get("nn.softmax");
436 return Call(op, {data}, Attrs(attrs), {});
437});
438
439RELAY_REGISTER_OP("nn.softmax")
440 .describe(R"code(Softmax layer.
441
442.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
443
444.. note::
445 This operator can be optimized away for inference.
446
447- **data**: The input data
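
For example (illustrative values, rounded)::

  softmax([1.0, 2.0, 3.0]) = [0.090, 0.245, 0.665]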
448)code" TVM_ADD_FILELINE)
449 .set_attrs_type<SoftmaxAttrs>()
450 .set_num_inputs(1)
451 .add_argument("data", "Tensor", "The input tensor.")
452 .set_support_level(1)
453 .add_type_rel("Softmax", SoftmaxRel)
454 .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
455
456// relay.fast_softmax
458
459TVM_REGISTER_GLOBAL("relay.op.nn._make.fast_softmax").set_body_typed([](Expr data, int axis) {
460 auto attrs = make_object<SoftmaxAttrs>();
461 attrs->axis = axis;
462 static const Op& op = Op::Get("nn.fast_softmax");
463 return Call(op, {data}, Attrs(attrs), {});
464});
465
466RELAY_REGISTER_OP("nn.fast_softmax")
467 .describe(R"code(Softmax layer.
Uses an approximation of the exponential function for faster speed.
469
470.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
471
472.. note::
473 This operator can be optimized away for inference.
474
475- **data**: The input data
476)code" TVM_ADD_FILELINE)
477 .set_attrs_type<SoftmaxAttrs>()
478 .set_num_inputs(1)
479 .add_argument("data", "Tensor", "The input tensor.")
480 .set_support_level(1)
481 .add_type_rel("Softmax", SoftmaxRel)
482 .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
483
484// relay.nn.log_softmax
485TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax").set_body_typed([](Expr data, int axis) {
486 auto attrs = make_object<SoftmaxAttrs>();
487 attrs->axis = axis;
488 static const Op& op = Op::Get("nn.log_softmax");
489 return Call(op, {data}, Attrs(attrs), {});
490});
491
492RELAY_REGISTER_OP("nn.log_softmax")
493 .describe(R"code(Computes log softmax.
494
495.. math:: \text{log_softmax}(x)_i = \log \frac{exp(x_i)}{\sum_j exp(x_j)}
496
497.. note::
498 This operator can be optimized away for inference.
499
500- **data**: The input data
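
For example (illustrative values, rounded)::

  log_softmax([1.0, 2.0, 3.0]) = [-2.408, -1.408, -0.408]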
501)code" TVM_ADD_FILELINE)
502 .set_attrs_type<SoftmaxAttrs>()
503 .set_num_inputs(1)
504 .add_argument("data", "Tensor", "The input tensor.")
505 .set_support_level(1)
506 .add_type_rel("Softmax", SoftmaxRel)
507 .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable)
508 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
509 const Type& out_type) {
510 const auto* param = attrs.as<SoftmaxAttrs>();
511 ICHECK(param != nullptr);
512 ICHECK(param->axis == -1 || param->axis == static_cast<int32_t>(inputs[0].ndim()) - 1)
513 << "log_softmax currently only works on last dimension";
514 return Array<te::Tensor>{topi::nn::log_softmax(inputs[0])};
515 });
516
517// relay.nn.batch_flatten
518bool BatchFlattenRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
519 const TypeReporter& reporter) {
520 ICHECK_EQ(types.size(), 2);
521 const auto* data = types[0].as<TensorTypeNode>();
522 if (data == nullptr) return false;
523 if (data->shape.size() == 0) return false;
524
525 auto target_dim = tir::make_const(DataType::Int(32), 1);
526
527 for (uint32_t i = 1; i < data->shape.size(); ++i) {
528 if (!data->shape[i].as<tir::AnyNode>()) {
529 target_dim = target_dim * data->shape[i];
530 } else {
531 target_dim = data->shape[i];
532 break;
533 }
534 }
535
536 std::vector<IndexExpr> oshape({data->shape[0], target_dim});
537
538 // assign output type
539 reporter->Assign(types[1], TensorType(oshape, data->dtype));
540 return true;
541}
542
543Expr MakeBatchFlatten(Expr data) {
544 static const Op& op = Op::Get("nn.batch_flatten");
545 return Call(op, {data}, Attrs(), {});
546}
547
548TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_flatten").set_body_typed(MakeBatchFlatten);
549
550RELAY_REGISTER_OP("nn.batch_flatten")
551 .describe(R"code(Flattens the input into a 2-D array.
552
For an input array with shape ``(d1, d2, ..., dk)``, the `batch_flatten` operation reshapes
the input array into an output array of shape ``(d1, d2*...*dk)``.
555
556Example::
557
558 x = [[
559 [1,2,3],
560 [4,5,6],
561 [7,8,9]
562 ],
563 [ [1,2,3],
564 [4,5,6],
565 [7,8,9]
566 ]],
567
568 batch_flatten(x) = [[ 1., 2., 3., 4., 5., 6., 7., 8., 9.],
569 [ 1., 2., 3., 4., 5., 6., 7., 8., 9.]]
570
571)code" TVM_ADD_FILELINE)
572 .set_num_inputs(1)
573 .add_argument("data", "Tensor", "The input tensor.")
574 .set_support_level(2)
575 .add_type_rel("BatchFlatten", BatchFlattenRel)
576 .set_attr<TOpPattern>("TOpPattern", kInjective)
577 .set_attr<FTVMCompute>("FTVMCompute",
578 [](const Attrs& attrs, const Array<te::Tensor>& inputs,
579 const Type& out_type) {
580 return Array<te::Tensor>{topi::nn::flatten(inputs[0])};
581 })
582 .set_attr<TReshapeOp>("TReshapeOp", true);
583
584// relu
585TVM_REGISTER_GLOBAL("relay.op.nn._make.relu").set_body_typed([](Expr data) {
586 static const Op& op = Op::Get("nn.relu");
587 return Call(op, {data}, Attrs(), {});
588});
589
590RELAY_REGISTER_OP("nn.relu")
591 .describe(R"code(Returns the relu input array, computed element-wise.
592
593.. math::
594 max(x, 0)
595
596)code" TVM_ADD_FILELINE)
597 .set_num_inputs(1)
598 .add_argument("data", "Tensor", "The input tensor.")
599 .set_support_level(1)
600 .add_type_rel("Identity", IdentityRel)
601 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
602 .set_attr<TOpPattern>("TOpPattern", kElemWise)
603 .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
604 const Type& out_type) {
605 return Array<te::Tensor>{topi::relu(inputs[0], 0.0f)};
606 });
607
608// Positional relay function to create LRN operator used by frontend FFI.
609TVM_REGISTER_NODE_TYPE(LRNAttrs);
610
611Expr MakeLRN(Expr data, int size, int axis, double alpha, double beta, double bias) {
612 auto attrs = make_object<LRNAttrs>();
613 attrs->size = size;
614 attrs->axis = axis;
615 attrs->alpha = alpha;
616 attrs->beta = beta;
617 attrs->bias = bias;
618 static const Op& op = Op::Get("nn.lrn");
619 return Call(op, {data}, Attrs(attrs), {});
620}
621
622TVM_REGISTER_GLOBAL("relay.op.nn._make.lrn").set_body_typed(MakeLRN);
623
624RELAY_REGISTER_OP("nn.lrn")
625 .describe(R"code(LRN layer.
626
627Normalize the input in a local region across or within feature maps.
Each input value is divided by :math:`(bias + (\alpha/n) \sum_i x_i^2)^\beta`,
where :math:`n` is the size of each local region, and the sum is taken over the region
centered at that value (zero padding is added where necessary).

.. math::

    data / (bias + (\alpha \cdot \sum data^2 / size))^{\beta}
635
636- **data**: The input tensor.
637)code" TVM_ADD_FILELINE)
638 .set_attrs_type<LRNAttrs>()
639 .set_num_inputs(1)
640 .add_argument("data", "Tensor", "The input tensor.")
641 .set_support_level(2)
642 .add_type_rel("Identity", IdentityRel)
643 .set_attr<TOpPattern>("TOpPattern", kOpaque);
644
645// Positional relay function to create L2Normalize operator used by frontend FFI.
646TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs);
647
648Expr MakeL2Normalize(Expr data, double eps, Array<Integer> axis) {
649 auto attrs = make_object<L2NormalizeAttrs>();
650 attrs->eps = eps;
651 attrs->axis = std::move(axis);
652 static const Op& op = Op::Get("nn.l2_normalize");
653 return Call(op, {data}, Attrs(attrs), {});
654}
655
656InferCorrectLayoutOutput L2NormalizeInferCorrectLayout(
657 const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
658 const Array<tvm::relay::Type>& old_in_types) {
659 const auto* attrs_ptr = attrs.as<L2NormalizeAttrs>();
660 ICHECK(attrs_ptr);
661 ObjectPtr<L2NormalizeAttrs> param = make_object<L2NormalizeAttrs>(*attrs_ptr);
662
663 Array<Array<IndexExpr>> old_in_shapes;
664 for (auto old_in_t : old_in_types) {
665 ICHECK(old_in_t.as<TensorTypeNode>());
666 old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
667 }
668 std::vector<size_t> axis_list;
669 for (auto i : param->axis) {
670 int64_t axis = i->value;
671 if (axis < 0) {
672 axis = axis + static_cast<size_t>(old_in_shapes[0].size());
673 }
674 axis_list.emplace_back(axis);
675 }
676
677 Layout ret = Layout::Undef();
678 if (new_in_layouts.defined() && old_in_layouts.defined()) {
679 for (size_t i = 0; i < axis_list.size(); ++i) {
680 const auto& axis_dim = old_in_layouts[0][axis_list[i]];
681 auto axis_index = new_in_layouts[0].IndexOf(axis_dim);
682 param->axis.Set(i, axis_index);
683 }
684 ret = new_in_layouts[0];
685 } else if (old_in_layouts.defined()) {
686 ret = old_in_layouts[0];
687 }
688
689 return InferCorrectLayoutOutput({ret}, {ret}, Attrs(param));
690}
691
692TVM_REGISTER_GLOBAL("relay.op.nn._make.l2_normalize").set_body_typed(MakeL2Normalize);
693
694RELAY_REGISTER_OP("nn.l2_normalize")
695 .describe(R"code(L2 Normalization layer.
696
697Normalizes along dimension axis using an L2 norm
698
699.. math::
700 output = x / sqrt(max(sum(x^2), epsilon))
701
702- **data**: The input tensor.
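
For example (illustrative values, axis=[0], small eps)::

  l2_normalize([3.0, 4.0]) = [0.6, 0.8]   # divided by sqrt(3^2 + 4^2) = 5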
703)code" TVM_ADD_FILELINE)
704 .set_attrs_type<L2NormalizeAttrs>()
705 .set_num_inputs(1)
706 .add_argument("data", "Tensor", "The input tensor.")
707 .set_support_level(2)
708 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", L2NormalizeInferCorrectLayout)
709 .add_type_rel("Identity", IdentityRel);
710
711// Dropout
712TVM_REGISTER_NODE_TYPE(DropoutAttrs);
713
714bool DropoutRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
715 const TypeReporter& reporter) {
716 ICHECK_EQ(types.size(), 2);
717 const auto* data = types[0].as<TensorTypeNode>();
718 if (data == nullptr) return false;
719
720 // dropout returns the original tensor with dropout applied
721 // and a mask tensor (1.0 where element not dropped, 0.0 where dropped)
722 auto ret_type = TensorType(data->shape, data->dtype);
723 reporter->Assign(types[1], TupleType(Array<Type>({ret_type, ret_type})));
724 return true;
725}
726
727Expr MakeDropout(Expr data, double rate) {
728 auto attrs = make_object<DropoutAttrs>();
729 attrs->rate = rate;
730 static const Op& op = Op::Get("nn.dropout");
731 return Call(op, {data}, Attrs(attrs), {});
732}
733
734TVM_REGISTER_GLOBAL("relay.op.nn._make.dropout").set_body_typed(MakeDropout);
735
736RELAY_REGISTER_OP("nn.dropout")
737 .describe(R"code(Applies the dropout operation to the input array.
738
739During training, each element of the input is set to zero with probability ``p``.
740The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input unchanged.
741
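For example (illustrative, rate=0.25)::

  kept elements are scaled by 1 / (1 - 0.25) = 4/3 ~ 1.333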
742)code" TVM_ADD_FILELINE)
743 .set_attrs_type<DropoutAttrs>()
744 .set_num_inputs(1)
745 .add_argument("data", "Tensor", "Input to which dropout will be applied.")
746 .set_support_level(1)
747 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
748 .set_attr<TOpPattern>("TOpPattern", kOpaque)
749 .add_type_rel("Dropout", DropoutRel)
750 .set_attr<TOpIsStateful>("TOpIsStateful", true);
751
752// batch_norm
753TVM_REGISTER_NODE_TYPE(BatchNormAttrs);
754
755InferCorrectLayoutOutput BatchNormInferCorrectLayout(const Attrs& attrs,
756 const Array<Layout>& new_in_layouts,
757 const Array<Layout>& old_in_layouts,
758 const Array<tvm::relay::Type>& old_in_types) {
759 const auto* attrs_ptr = attrs.as<BatchNormAttrs>();
760 ICHECK(attrs_ptr);
761 ObjectPtr<BatchNormAttrs> param = make_object<BatchNormAttrs>(*attrs_ptr);
762
763 Array<Array<IndexExpr>> old_in_shapes;
764 for (auto old_in_t : old_in_types) {
765 ICHECK(old_in_t.as<TensorTypeNode>());
766 old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
767 }
768
769 size_t axis =
770 param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);
771
772 Layout ret = Layout::Undef();
773
774 // If new_in_layouts are defined, this code tries to modify the layout.
775 if (new_in_layouts.defined() && old_in_layouts.defined()) {
776 // Get the new C axis. Extract the dim in old layout. Find the index of that dim in next layout.
777 const auto& bn_dim = old_in_layouts[0][axis];
778 auto new_index = new_in_layouts[0].IndexOf(bn_dim);
779 param->axis = new_index;
780 ret = new_in_layouts[0];
781 } else if (old_in_layouts.defined()) {
782 ret = old_in_layouts[0];
783 }
784 // BN has 5 inputs, 3 outputs. The last 4 inputs and last 2 outputs have "C" layout.
785 Layout c_layout = Layout("C");
786 return InferCorrectLayoutOutput({ret, c_layout, c_layout, c_layout, c_layout},
787 {ret, c_layout, c_layout}, Attrs(param));
788}
789
790bool BatchNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
791 const TypeReporter& reporter) {
792 ICHECK_EQ(types.size(), 6);
793 const auto* data = types[0].as<TensorTypeNode>();
794 if (data == nullptr) return false;
795
796 const BatchNormAttrs* param = attrs.as<BatchNormAttrs>();
797
798 // axis of -1 means use the last dimension
799 ICHECK(param->axis >= -1 && param->axis < (int)data->shape.size());
800 int axis = (param->axis != -1) ? param->axis : data->shape.size() - 1;
801 auto axis_size = data->shape[axis];
802
803 // if we are using beta and gamma, they need to be of shape (dim,)
804 reporter->Assign(types[1], TensorType({axis_size}, data->dtype));
805 reporter->Assign(types[2], TensorType({axis_size}, data->dtype));
806 reporter->Assign(types[3], TensorType({axis_size}, data->dtype));
807 reporter->Assign(types[4], TensorType({axis_size}, data->dtype));
808
809 // output is a tuple of the normed data (same shape as input), new running mean,
810 // new running variance, saved mean and saved variance (the latter are all
811 // vectors of length dim)
812 std::vector<Type> fields;
813 auto vec_ty = TensorType(Array<IndexExpr>({data->shape[axis]}), data->dtype);
814 fields.push_back(TensorType(data->shape, data->dtype));
815 fields.push_back(vec_ty);
816 fields.push_back(vec_ty);
817 reporter->Assign(types[5], TupleType(Array<Type>(fields)));
818 return true;
819}
820
821Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, int axis,
822 double epsilon, bool center, bool scale) {
823 auto attrs = make_object<BatchNormAttrs>();
824 attrs->axis = axis;
825 attrs->epsilon = epsilon;
826 attrs->center = center;
827 attrs->scale = scale;
828 static const Op& op = Op::Get("nn.batch_norm");
829 return Call(op, {data, gamma, beta, moving_mean, moving_var}, Attrs(attrs), {});
830}
831
832TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_norm").set_body_typed(MakeBatchNorm);
833
834RELAY_REGISTER_OP("nn.batch_norm")
835 .describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014).
836Normalizes the input at each batch, i.e. applies a transformation
837that maintains the mean activation close to 0 and the activation
838standard deviation close to 1.
839
840.. math::
841
842 data\_mean[i] = mean(data[:,i,:,...]) \\
843 data\_var[i] = var(data[:,i,:,...])
844
845Then compute the normalized output, which has the same shape as input, as following:
846
847.. math::
848
849 out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} \
850* gamma[i] + beta[i]
851
Both *mean* and *var* return a scalar by treating the input as a vector.
853
854Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` have shape *(k,)*.
855
856Besides the inputs and the outputs, this operator accepts two auxiliary
857states, ``moving_mean`` and ``moving_var``, which are *k*-length
858vectors. They are global statistics for the whole dataset, which are updated
859by::
860
861 moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
862 moving_var = moving_var * momentum + data_var * (1 - momentum)
863
864The parameter ``axis`` specifies which axis of the input shape denotes
865the 'channel' (separately normalized groups). The default is 1. Specifying -1 sets the channel
866axis to be the last item in the input shape.
867
868.. note::
869 This operator can be optimized away for inference.
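
For example (illustrative shapes, axis=1)::

  data : (8, 64, 32, 32)
  gamma, beta, moving_mean, moving_var : (64,)
  out  : ((8, 64, 32, 32), (64,), (64,))   # (normalized data, new mean, new variance)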
870)code" TVM_ADD_FILELINE)
871 .set_attrs_type<BatchNormAttrs>()
872 .set_num_inputs(5)
873 .add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
874 .add_argument("gamma", "Tensor", "The gamma scale factor.")
875 .add_argument("beta", "Tensor", "The beta offset factor.")
876 .add_argument("moving_mean", "Tensor", "Running mean of input.")
877 .add_argument("moving_var", "Tensor", "Running variance of input.")
878 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", BatchNormInferCorrectLayout)
879 .set_support_level(1)
880 .add_type_rel("BatchNorm", BatchNormRel)
881 .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
882
883// instance_norm
884TVM_REGISTER_NODE_TYPE(InstanceNormAttrs);
885
886template <typename T>
887InferCorrectLayoutOutput NormalizationInferCorrectLayout(
888 const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
889 const Array<tvm::relay::Type>& old_in_types) {
890 const auto* attrs_ptr = attrs.as<T>();
891 ICHECK(attrs_ptr);
892 ObjectPtr<T> param = make_object<T>(*attrs_ptr);
893
894 Array<Array<IndexExpr>> old_in_shapes;
895 for (auto old_in_t : old_in_types) {
896 ICHECK(old_in_t.as<TensorTypeNode>());
897 old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
898 }
899
900 size_t axis =
901 param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);
902
903 Layout ret = Layout::Undef();
904
905 // If new_in_layouts are defined, this code tries to modify the layout.
906 if (new_in_layouts.defined() && old_in_layouts.defined()) {
907 // Get the new C axis. Extract the dim in old layout. Find the index of that dim in next layout.
908 const auto& ln_dim = old_in_layouts[0][axis];
909 auto new_index = new_in_layouts[0].IndexOf(ln_dim);
910 param->axis = new_index;
911 ret = new_in_layouts[0];
912 } else if (old_in_layouts.defined()) {
913 ret = old_in_layouts[0];
914 }
915
  // The normalization ops here have 3 inputs and 1 output. The last 2 inputs have "C" layout.
917 Layout c_layout = Layout("C");
918 return InferCorrectLayoutOutput({ret, c_layout, c_layout}, {ret}, Attrs(param));
919}
920
921bool InstanceNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
922 const TypeReporter& reporter) {
923 ICHECK_EQ(types.size(), 4);
924 const auto* data = types[0].as<TensorTypeNode>();
925 if (data == nullptr) return false;
926 ICHECK_GT(data->shape.size(), 2);
927 const InstanceNormAttrs* param = attrs.as<InstanceNormAttrs>();
928 int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
929 ICHECK(axis >= 0 && axis < (int)data->shape.size());
930 reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
931 reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype));
932 reporter->Assign(types[3], TensorType(data->shape, data->dtype));
933
934 return true;
935}
936
937Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, bool center,
938 bool scale) {
939 auto attrs = make_object<InstanceNormAttrs>();
940 attrs->axis = axis;
941 attrs->epsilon = epsilon;
942 attrs->center = center;
943 attrs->scale = scale;
944 static const Op& op = Op::Get("nn.instance_norm");
945 return Call(op, {data, gamma, beta}, Attrs(attrs), {});
946}
947
948TVM_REGISTER_GLOBAL("relay.op.nn._make.instance_norm").set_body_typed(MakeInstanceNorm);
949
950RELAY_REGISTER_OP("nn.instance_norm")
951 .describe(R"code(Instance Normalization (Ulyanov and et al., 2016)
952Applies instance normalization to the n-dimensional input array.
953
954.. math::
955
956 out = \frac{data - mean(data)}{\sqrt{var(data)+\epsilon}}
957 * gamma + beta
958
Instance normalization is similar to batch normalization, but the mean and
variance are calculated separately for each object (instance) in a mini-batch,
not over the whole batch. The same normalization is applied at both training
and test time.
963
964Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
965have shape *(k,)*.
966
967The parameter ``axis`` specifies which axis of the input shape denotes
968the 'channel'. The default is 1. Specifying -1 sets the channel axis
969to be the last item in the input shape.
970
971.. note::
972
973 This operator can be optimized away for inference.
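
For example (illustrative shapes, axis=1)::

  data        : (2, 3, 32, 32)
  gamma, beta : (3,)
  out         : (2, 3, 32, 32)   # mean/var taken over H and W for each (n, c)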
974)code" TVM_ADD_FILELINE)
975 .set_attrs_type<InstanceNormAttrs>()
976 .set_num_inputs(3)
977 .add_argument("data", "Tensor", "Input to which instance_norm will be applied.")
978 .add_argument("gamma", "Tensor", "The gamma scale factor.")
979 .add_argument("beta", "Tensor", "The beta offset factor.")
980 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
981 NormalizationInferCorrectLayout<InstanceNormAttrs>)
982 .set_support_level(1)
983 .add_type_rel("InstanceNorm", InstanceNormRel);
984
985// layer_norm
986TVM_REGISTER_NODE_TYPE(LayerNormAttrs);
987
988bool LayerNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
989 const TypeReporter& reporter) {
990 ICHECK_EQ(types.size(), 4);
991 const auto* data = types[0].as<TensorTypeNode>();
992 if (data == nullptr) return false;
993 const LayerNormAttrs* param = attrs.as<LayerNormAttrs>();
994 int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
995 ICHECK(axis >= 0 && axis < (int)data->shape.size());
996 reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
997 reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype));
998 reporter->Assign(types[3], TensorType(data->shape, data->dtype));
999
1000 return true;
1001}
1002
1003Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, bool center,
1004 bool scale) {
1005 auto attrs = make_object<LayerNormAttrs>();
1006 attrs->axis = axis;
1007 attrs->epsilon = epsilon;
1008 attrs->center = center;
1009 attrs->scale = scale;
1010 static const Op& op = Op::Get("nn.layer_norm");
1011 return Call(op, {data, gamma, beta}, Attrs(attrs), {});
1012}
1013
1014TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm").set_body_typed(MakeLayerNorm);
1015
1016RELAY_REGISTER_OP("nn.layer_norm")
1017 .describe(R"code(
1018)code" TVM_ADD_FILELINE)
1019 .set_attrs_type<LayerNormAttrs>()
1020 .set_num_inputs(3)
1021 .add_argument("data", "Tensor", "Input to which layer_norm will be applied.")
1022 .add_argument("gamma", "Tensor", "The gamma scale factor.")
1023 .add_argument("beta", "Tensor", "The beta offset factor.")
1024 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
1025 NormalizationInferCorrectLayout<LayerNormAttrs>)
1026 .set_support_level(1)
1027 .add_type_rel("LayerNorm", LayerNormRel);
1028
1029// group_norm
1030TVM_REGISTER_NODE_TYPE(GroupNormAttrs);
1031
1032bool GroupNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1033 const TypeReporter& reporter) {
1034 ICHECK_EQ(types.size(), 4);
1035 const auto* data = types[0].as<TensorTypeNode>();
1036 if (data == nullptr) return false;
1037 const GroupNormAttrs* param = attrs.as<GroupNormAttrs>();
1038 int axis = param->axis >= 0 ? param->axis : param->axis + data->shape.size();
1039 ICHECK(axis >= 0 && axis < (int)data->shape.size());
1040 reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype));
1041 reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype));
1042 reporter->Assign(types[3], TensorType(data->shape, data->dtype));
1043
1044 return true;
1045}
1046
1047Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups, int axis, double epsilon,
1048 bool center, bool scale) {
1049 auto attrs = make_object<GroupNormAttrs>();
1050 attrs->num_groups = num_groups;
1051 attrs->axis = axis;
1052 attrs->epsilon = epsilon;
1053 attrs->center = center;
1054 attrs->scale = scale;
1055 static const Op& op = Op::Get("nn.group_norm");
1056 return Call(op, {data, gamma, beta}, Attrs(attrs), {});
1057}
1058
1059TVM_REGISTER_GLOBAL("relay.op.nn._make.group_norm").set_body_typed(MakeGroupNorm);
1060
1061RELAY_REGISTER_OP("nn.group_norm")
1062 .describe(R"code(
Group normalization normalizes over groups of channels for each training example.
Group Norm sits between Instance Norm and Layer Norm: putting all channels into a
single group makes it equivalent to Layer Normalization, while putting each channel
into its own group makes it equivalent to Instance Normalization.

https://arxiv.org/pdf/1803.08494.pdf

Applies group normalization to the n-dimensional input array by separating the input channels
into 'num_groups' groups, each containing 'num_channels / num_groups' channels.
The mean and standard deviation are calculated separately over each group. gamma and
beta are learnable per-channel affine transform parameter vectors of size num_channels.
1074
1075.. math::
1076
1077 out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}}
1078 * gamma + beta
1079
1080Unlike batch normalization, the mean and var are computed along a group of channels.
1081
1082If the input has size k on axis 1, then both gamma and beta have shape (k,).
1083
1084.. note::
1085
1086 This operator can be optimized away for inference.
1087
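For example (illustrative, num_channels=6, num_groups=3, axis=1)::

  data        : (2, 6, 32, 32)   # channels split into 3 groups of 2
  gamma, beta : (6,)
  out         : (2, 6, 32, 32)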
1088)code" TVM_ADD_FILELINE)
1089 .set_attrs_type<GroupNormAttrs>()
1090 .set_num_inputs(3)
1091 .add_argument("data", "Tensor", "Input to which group_norm will be applied.")
1092 .add_argument("gamma", "Tensor", "The gamma scale factor.")
1093 .add_argument("beta", "Tensor", "The beta offset factor.")
1094 .set_support_level(1)
1095 .add_type_rel("GroupNorm", GroupNormRel);
1096
1097// ------------------- relay.nn.batch_matmul
1098TVM_REGISTER_NODE_TYPE(BatchMatmulAttrs);
1099
1100// Positional relay function to create batch_matmul operator used by frontend FFI.
1101Expr MakeBatchMatmul(Expr tensor_a, Expr tensor_b, DataType out_dtype, bool transpose_a,
1102 bool transpose_b) {
1103 auto attrs = make_object<BatchMatmulAttrs>();
1104 attrs->out_dtype = out_dtype;
1105 attrs->transpose_a = transpose_a;
1106 attrs->transpose_b = transpose_b;
1107 static const Op& op = Op::Get("nn.batch_matmul");
1108 return Call(op, {tensor_a, tensor_b}, Attrs(attrs), {});
1109}
1110
1111TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul").set_body_typed(MakeBatchMatmul);
1112
1113RELAY_REGISTER_OP("nn.batch_matmul")
1114 .describe(R"code(Compute batch matrix multiplication of `tensor_a` and `tensor_b`.
1115
Both `tensor_a` and `tensor_b` can be transposed. For legacy reasons, we use the NT format
(transpose_a=False, transpose_b=True) by default.
1118
1119.. math::
1120
1121 batch\_matmul(A, B)[i, :, :] = matmul(A[i, :, :], B[i, :, :]^T)
1122
1123- **tensor_a**: `(b, m, k)` or `(b, k, m)`
1124- **tensor_b**: `(b, k, n)` or `(b, n, k)`
1125- **out**: `(b, m, n)`.
1126
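For example (illustrative shapes, default NT format)::

  tensor_a : (2, 3, 4)
  tensor_b : (2, 5, 4)   # transpose_b=True, so the reduction axis is the last one
  out      : (2, 3, 5)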
1127)code" TVM_ADD_FILELINE)
1128 .set_attrs_type<BatchMatmulAttrs>()
1129 .set_num_inputs(2)
1130 .add_argument("tensor_a", "3D Tensor", "The first input.")
1131 .add_argument("tensor_b", "3D Tensor", "The second input.")
1132 .set_support_level(10)
1133 .add_type_rel("BatchMatmul", BatchMatmulRel<BatchMatmulAttrs>)
1134 .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1135
1136// ------------------- relay.nn.batch_matmul
1137
1138// relay.nn.cross_entropy
1139bool CrossEntropyRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1140 const TypeReporter& reporter) {
1141 ICHECK_EQ(types.size(), 3);
1142 const auto* x = types[0].as<TensorTypeNode>();
1143 const auto* y = types[1].as<TensorTypeNode>();
1144 if (x == nullptr || y == nullptr) return false;
  ICHECK(x->shape.size() == 2 && y->shape.size() == 2)
      << "CrossEntropy: shapes of x and y are inconsistent, "
      << "x shape = " << x->shape << ", "
      << "y shape = " << y->shape;
  ICHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
      << "CrossEntropy: shapes of x and y are inconsistent, "
      << "x shape = " << x->shape << ", "
      << "y shape = " << y->shape;
  ICHECK(reporter->AssertEQ(x->shape[1], y->shape[1]))
      << "CrossEntropy: shapes of x and y are inconsistent, "
      << "x shape = " << x->shape << ", "
      << "y shape = " << y->shape;
1157 // assign output type
1158 reporter->Assign(types[2], TensorType({}, x->dtype));
1159 return true;
1160}
1161
1162// Positional relay function to create cross_entropy operator used by frontend FFI.
1163Expr MakeCrossEntropy(Expr predictions, Expr targets) {
1164 static const Op& op = Op::Get("nn.cross_entropy");
1165 return Call(op, {predictions, targets}, Attrs(), {});
1166}
1167
1168TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy").set_body_typed(MakeCrossEntropy);
1169
1170RELAY_REGISTER_OP("nn.cross_entropy")
1171 .describe(R"code(
1172Computes cross entropy given predictions and targets.
Takes the log of the predictions; this op does not accept logits.
1174)code" TVM_ADD_FILELINE)
1175 .set_num_inputs(2)
1176 .add_argument("x", "1D Tensor", "Predictions.")
1177 .add_argument("y", "1D Tensor", "Targets.")
1178 .set_support_level(10)
1179 .add_type_rel("CrossEntropy", CrossEntropyRel)
1180 .set_attr<TOpPattern>("TOpPattern", kOpaque);
1181
1182// relay.nn.dilate
1183TVM_REGISTER_NODE_TYPE(DilateAttrs);
1184
1185bool DilateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1186 const TypeReporter& reporter) {
1187 ICHECK_EQ(types.size(), 2);
1188 const auto* x = types[0].as<TensorTypeNode>();
1189 const DilateAttrs* param = attrs.as<DilateAttrs>();
1190 if (x == nullptr) return false;
1191 ICHECK_EQ(x->shape.size(), param->strides.size());
1192
1193 std::vector<IndexExpr> oshape;
1194 for (size_t i = 0; i < param->strides.size(); ++i) {
1195 if (!x->shape[i].as<tir::AnyNode>()) {
1196 oshape.push_back((x->shape[i] - 1) * param->strides[i] + 1);
1197 } else {
1198 oshape.push_back(x->shape[i]);
1199 }
1200 }
1201
1202 reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), x->dtype));
1203 return true;
1204}
1205
1206// Positional relay function to create dilate operator used by frontend FFI.
1207Expr MakeDilate(Expr data, Array<IndexExpr> strides, double dilation_value = 0.0) {
1208 auto attrs = make_object<DilateAttrs>();
1209 attrs->strides = std::move(strides);
1210 attrs->dilation_value = std::move(dilation_value);
1211 static const Op& op = Op::Get("nn.dilate");
1212 return Call(op, {data}, Attrs(attrs), {});
1213}
1214
1215TVM_REGISTER_GLOBAL("relay.op.nn._make.dilate").set_body_typed(MakeDilate);
1216
1217RELAY_REGISTER_OP("nn.dilate")
1218 .describe(R"code(
1219Dilate data with given dilation value (0 by default).
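
For example (illustrative, strides=(2,), dilation_value=0.0)::

  dilate([1, 2, 3]) = [1, 0, 2, 0, 3]   # output length (3 - 1) * 2 + 1 = 5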
1220)code" TVM_ADD_FILELINE)
1221 .set_num_inputs(1)
1222 .add_argument("x", "1D Tensor", "Data to dilate.")
1223 .set_support_level(10)
1224 .add_type_rel("Dilate", DilateRel)
1225 .set_attr<TOpPattern>("TOpPattern", kInjective);
1226
1227// relay.nn.cross_entropy_with_logits
1228// Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
1229Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) {
1230 static const Op& op = Op::Get("nn.cross_entropy_with_logits");
1231 return Call(op, {predictions, targets}, Attrs(), {});
1232}
1233
1234TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy_with_logits")
1235 .set_body_typed(MakeCrossEntropyWithLogits);
1236
1237RELAY_REGISTER_OP("nn.cross_entropy_with_logits")
1238 .describe(R"code(
1239Computes cross entropy given predictions and targets.
1240Accept logits.
1241)code" TVM_ADD_FILELINE)
1242 .set_num_inputs(2)
1243 .add_argument("x", "1D Tensor", "Predictions.")
1244 .add_argument("y", "1D Tensor", "Targets.")
1245 .set_support_level(10)
1246 .add_type_rel("CrossEntropy", CrossEntropyRel)
1247 .set_attr<TOpPattern>("TOpPattern", kOpaque);
1248
1249// Depth to space and space to depth
1250TVM_REGISTER_NODE_TYPE(SubPixelAttrs);
1251
1252// relay.nn.nll_loss
1253TVM_REGISTER_NODE_TYPE(NLLLossAttrs);
1254
1255bool NLLLossRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1256 const TypeReporter& reporter) {
1257 ICHECK_EQ(types.size(), 4) << "NLLLossRel expects 4 types, but " << types.size()
1258 << " were provided.";
1259 const auto* predictions = types[0].as<TensorTypeNode>();
1260 const auto* targets = types[1].as<TensorTypeNode>();
1261 const auto* weights = types[2].as<TensorTypeNode>();
1262 const NLLLossAttrs* param = attrs.as<NLLLossAttrs>();
1263 if (predictions == nullptr || targets == nullptr || weights == nullptr) return false;
1264 if (!(predictions->shape.size() - targets->shape.size() == 1)) {
1265 reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
1266 << "NLLLossRel: predictions should be one"
1267 << " dimension larger than targets,"
1268 << "predictions shape = " << predictions->shape
1269 << ", targets shape = " << targets->shape);
1270 return false;
1271 }
1272 if (!(weights->shape.size() == 1)) {
1273 reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
1274 << "NLLLossRel: weights should be a one dimension"
1275 << " Tensor with its length the number of classes,"
1276 << " but Tensor of dimension " << weights->shape.size()
1277 << " were provided.");
1278 return false;
1279 }
1280 if (!reporter->AssertEQ(predictions->shape[1], weights->shape[0])) {
1281 reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
1282 << "NLLLossRel: the second dimension of predictions"
1283 << " should be the number of classes, "
1284 << "which is the length of weights, "
1285 << "predictions shape = " << predictions->shape
1286 << ", weights shape = " << weights->shape);
1287 return false;
1288 }
1289 if (!(predictions->dtype == weights->dtype &&
1290 (predictions->dtype.is_float() || predictions->dtype.is_bfloat16()))) {
1291 reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
1292 << "NLLLossRel: predictions and weights should"
1293 << " be of the same floating type.");
1294 return false;
1295 }
1296 if (!targets->dtype.is_int()) {
1297 reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
1298 << "NLLLossRel: targets should be of int type.");
1299 return false;
1300 }
1301 // assign output type
1302 if (param->reduction == "none") {
1303 reporter->Assign(types[3], TensorType(targets->shape, predictions->dtype));
1304 } else {
1305 reporter->Assign(types[3], TensorType({}, predictions->dtype));
1306 }
1307 return true;
1308}
1309
// Handler to create a call to the nll_loss op used by front-end FFI
1311Expr MakeNLLLoss(Expr predictions, Expr targets, Expr weights, String reduction, int ignore_index) {
1312 auto attrs = make_object<NLLLossAttrs>();
1313 attrs->reduction = reduction;
1314 attrs->ignore_index = ignore_index;
1315 static const Op& op = Op::Get("nn.nll_loss");
1316 return Call(op, {predictions, targets, weights}, Attrs(attrs), {});
1317}
1318
1319TVM_REGISTER_GLOBAL("relay.op.nn._make.nll_loss").set_body_typed(MakeNLLLoss);
1320
1321RELAY_REGISTER_OP("nn.nll_loss")
1322 .describe(R"code(
1323Negative log likelihood loss for given prediction and target.
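
For example (illustrative shapes, N=4 samples, C=10 classes)::

  predictions : (4, 10)
  targets     : (4,)
  weights     : (10,)
  out         : (4,) if reduction="none", a scalar otherwise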
1324)code" TVM_ADD_FILELINE)
1325 .set_attrs_type<NLLLossAttrs>()
1326 .set_num_inputs(3)
1327 .add_argument("predictions", "Tensor", "The prediction tensor.")
1328 .add_argument("targets", "Tensor", "The target tensor.")
1329 .add_argument("weights", "Tensor", "The weight of each target values.")
1330 .add_type_rel("NLLLoss", NLLLossRel)
1331 .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
1332
1333bool DepthToSpaceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1334 const TypeReporter& reporter) {
1335 ICHECK_EQ(types.size(), 2);
1336 const auto* data = types[0].as<TensorTypeNode>();
1337 if (data == nullptr) return false;
1338
1339 static const Layout kNCHW("NCHW");
1340
1341 const SubPixelAttrs* param = attrs.as<SubPixelAttrs>();
1342 ICHECK(param != nullptr);
1343 const int block_size = param->block_size;
1344 const Layout in_layout(param->layout);
1345 auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
1346 ICHECK(layout_converter.defined())
1347 << "DepthToSpace only support input layouts that are convertible from NCHW."
1348 << " But got " << in_layout;
1349
1350 auto oshape = layout_converter.ForwardShape(data->shape);
1351 if (!oshape[1].as<tir::AnyNode>()) {
1352 oshape.Set(1, indexdiv(oshape[1], (block_size * block_size)));
1353 }
1354 if (!oshape[2].as<tir::AnyNode>()) {
1355 oshape.Set(2, oshape[2] * block_size);
1356 }
1357 if (!oshape[3].as<tir::AnyNode>()) {
1358 oshape.Set(3, oshape[3] * block_size);
1359 }
1360
1361 // Assign output type
1362 reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype));
1363
1364 return true;
1365}
1366
1367// Positional relay function to create DepthToSpace operator
1368// used by frontend FFI
1369Expr MakeDepthToSpace(Expr data, int block_size, String layout, String mode) {
1370 auto attrs = make_object<SubPixelAttrs>();
1371 attrs->block_size = block_size;
1372 attrs->layout = std::move(layout);
1373 attrs->mode = std::move(mode);
1374 static const Op& op = Op::Get("nn.depth_to_space");
1375 return Call(op, {data}, Attrs(attrs), {});
1376}
1377
1378TVM_REGISTER_GLOBAL("relay.op.nn._make.depth_to_space").set_body_typed(MakeDepthToSpace);
1379
1380RELAY_REGISTER_OP("nn.depth_to_space")
1381 .describe(R"code(Rearrange input channels into spatial pixels.
1382
1383- **data**: data is a 4D array of shape
1384 (batch, in_channels, in_height, in_width) for NCHW
1385
1386- **out**: Output is a 4D array of shape
  (batch, in_channels / (block_size * block_size), in_height * block_size, in_width * block_size) for NCHW.
1388
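For example (illustrative, NCHW, block_size=2)::

  data : (1, 16, 4, 4)
  out  : (1, 4, 8, 8)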
1389)code" TVM_ADD_FILELINE)
1390 .set_attrs_type<SubPixelAttrs>()
1391 .set_num_inputs(1)
1392 .add_argument("data", "Tensor", "The input tensor")
1393 .set_support_level(5)
1394 .add_type_rel("DepthToSpace", DepthToSpaceRel)
1395 .set_attr<TOpPattern>("TOpPattern", kInjective);
1396
1397bool SpaceToDepthRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1398 const TypeReporter& reporter) {
1399 ICHECK_EQ(types.size(), 2);
1400 const auto* data = types[0].as<TensorTypeNode>();
1401 if (data == nullptr) return false;
1402
1403 static const Layout kNCHW("NCHW");
1404
1405 const SubPixelAttrs* param = attrs.as<SubPixelAttrs>();
1406 ICHECK(param != nullptr);
1407 const int block_size = param->block_size;
1408 const Layout in_layout(param->layout);
1409 auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
1410 ICHECK(layout_converter.defined())
1411 << "SpaceToDepth only support input layouts that are convertible from NCHW."
1412 << " But got " << in_layout;
1413
1414 auto oshape = layout_converter.ForwardShape(data->shape);
1415 if (!oshape[1].as<tir::AnyNode>()) {
1416 oshape.Set(1, oshape[1] * (block_size * block_size));
1417 }
1418 if (!oshape[2].as<tir::AnyNode>()) {
1419 oshape.Set(2, indexdiv(oshape[2], block_size));
1420 }
1421 if (!oshape[3].as<tir::AnyNode>()) {
1422 oshape.Set(3, indexdiv(oshape[3], block_size));
1423 }
1424
1425 // Assign output type
1426 reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype));
1427
1428 return true;
1429}
1430
1431// Positional relay function to create SpaceToDepth operator
1432// used by frontend FFI
1433Expr MakeSpaceToDepth(Expr data, int block_size, String layout) {
1434 auto attrs = make_object<SubPixelAttrs>();
1435 attrs->block_size = block_size;
1436 attrs->layout = std::move(layout);
1437 static const Op& op = Op::Get("nn.space_to_depth");
1438 return Call(op, {data}, Attrs(attrs), {});
1439}
1440
1441TVM_REGISTER_GLOBAL("relay.op.nn._make.space_to_depth").set_body_typed(MakeSpaceToDepth);
1442
1443RELAY_REGISTER_OP("nn.space_to_depth")
1444 .describe(R"code(Rearrange spatial pixels into new output channels.
1445
1446- **data**: data is a 4D array of shape
1447 (batch, in_channels, in_height, in_width) for NCHW
1448
1449- **out**: Output is a 4D array of shape
1450 (batch, in_channels * block_size * block_size, in_height / block_size, in_width / block_size) for NCHW.
1451
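For example (illustrative, NCHW, block_size=2)::

  data : (1, 4, 8, 8)
  out  : (1, 16, 4, 4)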
)code" TVM_ADD_FILELINE)
    .set_attrs_type<SubPixelAttrs>()
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor")
    .set_support_level(5)
    .add_type_rel("SpaceToDepth", SpaceToDepthRel)
    .set_attr<TOpPattern>("TOpPattern", kInjective);

TVM_REGISTER_NODE_TYPE(SpaceToBatchNDAttrs);

// Positional relay function to create SpaceToBatchND operator
// used by frontend FFI
Expr MakeSpaceToBatchND(Expr data, Array<Integer> block_shape, Array<Array<IndexExpr>> paddings,
                        double pad_value) {
  auto attrs = make_object<SpaceToBatchNDAttrs>();
  attrs->block_shape = std::move(block_shape);
  attrs->paddings = std::move(paddings);
  attrs->pad_value = pad_value;
  static const Op& op = Op::Get("nn.space_to_batch_nd");
  return Call(op, {data}, Attrs(attrs), {});
}

bool SpaceToBatchNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                       const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);

  auto* input = types[0].as<TensorTypeNode>();
  // Input must be a TensorType
  if (input == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "SpaceToBatchND: expect input type to be TensorType but got " << types[0];
    return false;
  }

  if (input->shape.size() <= 1) return false;

  const auto* param = attrs.as<SpaceToBatchNDAttrs>();
  CHECK(param != nullptr);

  auto block_shape = param->block_shape;
  auto paddings = param->paddings;
  const int bdims = static_cast<int>(block_shape.size());
  const int pdims = static_cast<int>(paddings.size());
  // Paddings must be provided for each spatial dim.
  CHECK(pdims == bdims) << "SpaceToBatchND: Paddings must be provided for each spatial dim";

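  // Worked example (illustrative): in_shape = (1, 2, 2, 1), block_shape = (2, 2) and
  // paddings = ((0, 0), (0, 0)) give padded_shape = (1, 2, 2, 1) and
  // out_shape = (1 * 2 * 2, 2 / 2, 2 / 2, 1) = (4, 1, 1, 1).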
  // Apply paddings to input
  auto in_shape = input->shape;
  std::vector<IndexExpr> padded_shape(input->shape.begin(), input->shape.end());
  for (size_t i = 0; i < paddings.size(); i++) {
    CHECK_EQ(paddings[i].size(), 2U);
    auto pad_before = tir::as_const_int(param->paddings[i][0]);
    auto pad_after = tir::as_const_int(param->paddings[i][1]);
    // Paddings are expected to be compile-time constants here; guard against a
    // null-pointer dereference when they are not.
    CHECK(pad_before != nullptr && pad_after != nullptr)
        << "SpaceToBatchND: paddings must be constant integers";
    auto padding = tir::make_const(input->shape[i].dtype(), *pad_before + *pad_after);
    padded_shape[i + 1] = in_shape[i + 1] + padding;
  }

  auto block_shape_numele = tir::make_const(DataType::Int(32), 1);
  for (size_t i = 0; i < block_shape.size(); i++) {
    block_shape_numele *= block_shape[i];
  }

  // Construct output shape
  std::vector<IndexExpr> out_shape(padded_shape);
  out_shape[0] = in_shape[0] * block_shape_numele;
  for (size_t i = 1; i <= block_shape.size(); i++) {
    out_shape[i] = div(padded_shape[i], block_shape[i - 1]);
  }

  // Assign output shape
  reporter->Assign(types[1], TensorType(Array<IndexExpr>(out_shape), input->dtype));
  return true;
}

Array<te::Tensor> SpaceToBatchNDCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                        const Type& out_type) {
  const auto* param = attrs.as<SpaceToBatchNDAttrs>();
  CHECK(param != nullptr);

  auto b_shape = param->block_shape;
  auto paddings = param->paddings;
  Array<IndexExpr> pad_before;
  Array<IndexExpr> pad_after;

  for (size_t i = 0; i < paddings.size(); ++i) {
    pad_before.push_back(paddings[i][0]);
  }
  for (size_t i = 0; i < paddings.size(); ++i) {
    pad_after.push_back(paddings[i][1]);
  }
  const auto* out_ttype = out_type.as<TensorTypeNode>();
  return Array<te::Tensor>{
      topi::space_to_batch_nd(inputs[0], b_shape, pad_before, pad_after,
                              tvm::tir::make_const(out_ttype->dtype, param->pad_value))};
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.space_to_batch_nd").set_body_typed(MakeSpaceToBatchND);

RELAY_REGISTER_OP("nn.space_to_batch_nd")
    .describe(R"code(Divide spatial dimensions of the input into a grid of blocks
and interleave them into the batch dimension.

- **data**: data is an N-D array of shape
            (batch, spatial_shapes, remaining_shapes) for NHWC

- **out**: Output is an N-D array of shape
           (batch * prod(block_shape), padded_data[1] / block_shape[0], ..., padded_data[M] / block_shape[M-1],
           remaining_shape) for NHWC, where M is the number of spatial dimensions.

Example::

  x = [[[[1], [2]], [[3], [4]]]]

  space_to_batch_nd(x, block_shape = [2, 2]) =
    [[[[1]]], [[[2]]], [[[3]]], [[[4]]]]

)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_attrs_type<SpaceToBatchNDAttrs>()
    .set_support_level(5)
    .add_type_rel("SpaceToBatchND", SpaceToBatchNDRel)
    .set_attr<FTVMCompute>("FTVMCompute", SpaceToBatchNDCompute)
    .set_attr<TOpPattern>("TOpPattern", kInjective);

/*****************************************************************/

TVM_REGISTER_NODE_TYPE(BatchToSpaceNDAttrs);

// Positional relay function to create BatchToSpaceND operator
// used by frontend FFI
Expr MakeBatchToSpaceND(Expr data, Array<Integer> block_shape, Array<Array<IndexExpr>> crops) {
  auto attrs = make_object<BatchToSpaceNDAttrs>();
  attrs->block_shape = std::move(block_shape);
  attrs->crops = std::move(crops);
  static const Op& op = Op::Get("nn.batch_to_space_nd");
  return Call(op, {data}, Attrs(attrs), {});
}

bool BatchToSpaceNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                       const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);

  auto* input = types[0].as<TensorTypeNode>();
  // Input must be a TensorType
  if (input == nullptr) {
    CHECK(types[0].as<IncompleteTypeNode>())
        << "BatchToSpaceND: expect input type to be TensorType but got " << types[0];
    return false;
  }

  if (input->shape.size() <= 1) return false;

  const auto* param = attrs.as<BatchToSpaceNDAttrs>();
  CHECK(param != nullptr);

  auto block_shape = param->block_shape;
  auto crops = param->crops;
  const int bdims = static_cast<int>(block_shape.size());
  const int cdims = static_cast<int>(crops.size());
  const int indims = static_cast<int>(input->shape.size());
  // crops must be provided for each spatial dim.
  CHECK(cdims == bdims) << "BatchToSpaceND: crops must be provided for each spatial dim";
  CHECK(bdims < indims) << "BatchToSpaceND: block_shape rank must be less than input rank";

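  // Worked example (illustrative): in_shape = (4, 1, 1, 1), block_shape = (2, 2) and
  // crops = ((0, 0), (0, 0)) give out_shape = (4 / (2 * 2), 1 * 2, 1 * 2, 1) = (1, 2, 2, 1).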
  auto block_shape_numele = tir::make_const(DataType::Int(32), 1);
  for (size_t i = 0; i < block_shape.size(); i++) {
    block_shape_numele *= block_shape[i];
  }

  auto in_shape = input->shape;

  // Construct output shape
  // Start with the input shape; only the batch and spatial dims are modified.
  std::vector<IndexExpr> out_shape(input->shape.begin(), input->shape.end());
  out_shape[0] = in_shape[0] / block_shape_numele;
  for (size_t i = 1; i <= block_shape.size(); i++) {
    out_shape[i] = (in_shape[i] * block_shape[i - 1]) - crops[i - 1][0] - crops[i - 1][1];
  }
  for (int i = bdims + 1; i < indims; i++) {
    out_shape[i] = in_shape[i];
  }

  // Assign output shape
  reporter->Assign(types[1], TensorType(Array<IndexExpr>(out_shape), input->dtype));
  return true;
}

Array<te::Tensor> BatchToSpaceNDCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                        const Type& out_type) {
  const auto* param = attrs.as<BatchToSpaceNDAttrs>();
  CHECK(param != nullptr);

  auto b_shape = param->block_shape;
  auto crops = param->crops;
  Array<IndexExpr> crop_begin_list, crop_end_list;
  for (size_t i = 0; i < crops.size(); ++i) {
    crop_begin_list.push_back(crops[i][0]);
    crop_end_list.push_back(crops[i][1]);
  }

  return Array<te::Tensor>{
      topi::batch_to_space_nd(inputs[0], b_shape, crop_begin_list, crop_end_list)};
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_to_space_nd").set_body_typed(MakeBatchToSpaceND);

RELAY_REGISTER_OP("nn.batch_to_space_nd")
    .describe(R"code(Reshape the batch dimension into spatial dimensions.

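- **data**: data is an N-D array of shape
            (batch, spatial_shapes, remaining_shapes) for NHWC

- **out**: Output is an N-D array of shape
           (batch / prod(block_shape), in_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], ...,
           in_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], remaining_shape) for NHWC,
           where M is the number of spatial dimensions.
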
Example::

  x = [[[[1]]], [[[2]]], [[[3]]], [[[4]]]]

  batch_to_space_nd(x, block_shape = [2, 2]) =
    [[[[1], [2]], [[3], [4]]]]

)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_attrs_type<BatchToSpaceNDAttrs>()
    .set_support_level(5)
    .add_type_rel("BatchToSpaceND", BatchToSpaceNDRel)
    .set_attr<FTVMCompute>("FTVMCompute", BatchToSpaceNDCompute)
    .set_attr<TOpPattern>("TOpPattern", kInjective);

}  // namespace relay
}  // namespace tvm