1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file transform.cc
22 * \brief Transform operators.
23 */
24#include "transform.h"
25
26#include <tvm/relay/attrs/transform.h>
27#include <tvm/relay/error.h>
28#include <tvm/relay/expr.h>
29#include <tvm/relay/op.h>
30#include <tvm/runtime/packed_func.h>
31#include <tvm/tir/data_layout.h>
32#include <tvm/tir/expr.h>
33#include <tvm/tir/op.h>
34#include <tvm/topi/broadcast.h>
35#include <tvm/topi/detail/constant_utils.h>
36#include <tvm/topi/elemwise.h>
37#include <tvm/topi/nn.h>
38#include <tvm/topi/reduction.h>
39#include <tvm/topi/transform.h>
40
41#include <sstream>
42#include <vector>
43
44#include "../../transforms/infer_layout_utils.h"
45#include "../../transforms/pass_utils.h"
46#include "../../transforms/pattern_utils.h"
47#include "../make_op.h"
48#include "../op_common.h"
49#include "../type_relations.h"
50
51namespace tvm {
52namespace relay {
53using tir::IntImmNode;
54
55TVM_REGISTER_NODE_TYPE(SlidingWindowAttrs);
56
57bool SlidingWindowRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
58 const TypeReporter& reporter) {
59 // `types` contains: [data, result]
60 ICHECK_EQ(types.size(), 2);
61 const auto* data = types[0].as<TensorTypeNode>();
62 if (data == nullptr) {
63 reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
64 << "SlidingWindow operator expects input to be of TensorType "
65 << "but got " << PrettyPrint(types[0]));
66 return false;
67 }
68 const auto* param = attrs.as<SlidingWindowAttrs>();
69 const int axis = param->axis;
70
71 std::vector<IndexExpr> oshape;
72
73 // Dimensions up until `axis` remain the same.
74 for (int i = 0; i < axis; ++i) {
75 oshape.emplace_back(data->shape[i]);
76 }
77
78 // New dimensions which result from sliding the window in each dimension. One new dimension per
79 // window dimension.
80 for (size_t i = 0; i < param->window_shape.size(); ++i) {
81 // Length of the shape along this dimension.
82 auto dim_len = data->shape[axis + i];
83 // Length of the window along this dimension.
84 auto window_len = param->window_shape[i];
85 // Strides along this dimension.
86 auto stride = param->strides[i];
87
88 oshape.push_back(floordiv(dim_len - (window_len - 1) + stride - 1, stride));
89 }
90
91 // Dimensions comprising the window.
92 for (size_t i = 0; i < param->window_shape.size(); ++i) {
93 oshape.push_back(param->window_shape[i]);
94 }
95
96 reporter->Assign(types[1], TensorType(oshape, data->dtype));
97 return true;
98}
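
// Worked example (illustrative shapes, not from the original source): for data of shape
// (2, 10, 12) with axis=1, window_shape=(3, 4) and strides=(1, 2), the loop above yields
// floordiv(10 - 2 + 1 - 1, 1) = 8 and floordiv(12 - 3 + 2 - 1, 2) = 5, so the inferred
// output type is Tensor[(2, 8, 5, 3, 4), dtype].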
99
100Array<te::Tensor> SlidingWindowCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
101 const Type& out_type) {
102 const SlidingWindowAttrs* param = attrs.as<SlidingWindowAttrs>();
103 ICHECK(param != nullptr);
104 return {topi::sliding_window(inputs[0], param->axis, param->window_shape, param->strides)};
105}
106
107Expr MakeSlidingWindow(Expr data, int axis, Array<Integer> window_shape, Array<Integer> strides) {
108 auto attrs = make_object<SlidingWindowAttrs>();
109 attrs->axis = axis;
110 attrs->window_shape = window_shape;
111 attrs->strides = strides;
112 static const Op& op = Op::Get("sliding_window");
113 return Call(op, {data}, Attrs(attrs), {});
114}
115
116TVM_REGISTER_GLOBAL("relay.ir.sliding_window").set_body_typed(MakeSlidingWindow);
117
118RELAY_REGISTER_OP("sliding_window")
119 .describe(R"code(Slide window over a tensor.)code" TVM_ADD_FILELINE)
120 .set_num_inputs(1)
121 .set_attrs_type<SlidingWindowAttrs>()
122 .add_argument("data", "Tensor", "The input tensor.")
123 .add_type_rel("SlidingWindow", SlidingWindowRel)
124 .set_attr<TOpPattern>("TOpPattern", kOpaque);
125
126// relay.cast
127TVM_REGISTER_NODE_TYPE(CastAttrs);
128
129bool CastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
130 const TypeReporter& reporter) {
131 ICHECK_EQ(types.size(), 2);
132 const auto* data = types[0].as<TensorTypeNode>();
133 if (data == nullptr) {
134 ICHECK(types[0].as<IncompleteTypeNode>())
135 << "cast: expect input type to be TensorType but get " << types[0];
136 return false;
137 }
138 const auto* param = attrs.as<CastAttrs>();
139 reporter->Assign(types[1], TensorType(data->shape, param->dtype));
140 return true;
141}
142
143Array<te::Tensor> CastCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
144 const Type& out_type) {
145 const CastAttrs* param = attrs.as<CastAttrs>();
146 ICHECK(param != nullptr);
147 DataType dtype = param->dtype;
148 return {topi::cast(inputs[0], dtype)};
149}
150
151Expr MakeCast(Expr data, DataType dtype) {
152 auto attrs = make_object<CastAttrs>();
153 attrs->dtype = dtype;
154 static const Op& op = Op::Get("cast");
155 return Call(op, {data}, Attrs(attrs), {});
156}
157
158TVM_REGISTER_GLOBAL("relay.ir.cast").set_body_typed(MakeCast);
159
160RELAY_REGISTER_OP("cast")
161 .describe(R"code(Cast the data into a new data type.
162
163)code" TVM_ADD_FILELINE)
164 .set_num_inputs(1)
165 .set_attrs_type<CastAttrs>()
166 .add_argument("data", "Tensor", "The input tensor.")
167 .set_support_level(3)
168 .add_type_rel("Cast", CastRel)
169 .set_attr<FTVMCompute>("FTVMCompute", CastCompute)
170 .set_attr<TOpPattern>("TOpPattern", kElemWise)
171 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
172
173// relay.cast_like
174bool CastLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
175 const TypeReporter& reporter) {
176 ICHECK_EQ(types.size(), 3);
177 const auto* data = types[0].as<TensorTypeNode>();
178 if (data == nullptr) {
    ICHECK(types[0].as<IncompleteTypeNode>())
        << "cast_like: expect input type to be TensorType but get " << types[0];
181 return false;
182 }
183 const auto* dtype_like = types[1].as<TensorTypeNode>();
184 if (dtype_like == nullptr) {
    ICHECK(types[1].as<IncompleteTypeNode>())
        << "cast_like: expect input type to be TensorType but get " << types[1];
187 return false;
188 }
189 reporter->Assign(types[2], TensorType(data->shape, dtype_like->dtype));
190 return true;
191}
192
193Array<te::Tensor> CastLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
194 const Type& out_type) {
195 return {topi::cast(inputs[0], inputs[1]->dtype)};
196}
197
198Expr MakeCastLike(Expr data, Expr dtype_like) {
199 static const Op& op = Op::Get("cast_like");
200 return Call(op, {data, dtype_like}, Attrs(), {});
201}
202
203TVM_REGISTER_GLOBAL("relay.ir.cast_like").set_body_typed(MakeCastLike);
204
205RELAY_REGISTER_OP("cast_like")
206 .describe(R"code(Cast the data into the type of another tensor.
207)code" TVM_ADD_FILELINE)
208 .set_num_inputs(2)
209 .add_argument("data", "Tensor", "The input tensor.")
210 .add_argument("dtype_like", "Tensor", "The tensor to cast to.")
211 .set_support_level(3)
212 .add_type_rel("CastLike", CastLikeRel)
213 .set_attr<FTVMCompute>("FTVMCompute", CastLikeCompute)
214 .set_attr<TOpPattern>("TOpPattern", kElemWise)
215 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
216
217Array<te::Tensor> ReinterpretCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
218 const Type& out_type) {
219 const CastAttrs* param = attrs.as<CastAttrs>();
220 ICHECK(param != nullptr);
221 DataType dtype = param->dtype;
222 return {topi::reinterpret(inputs[0], dtype)};
223}
224
225Expr MakeReinterpret(Expr data, DataType dtype) {
226 auto attrs = make_object<CastAttrs>();
227 attrs->dtype = dtype;
228 static const Op& op = Op::Get("reinterpret");
229 return Call(op, {data}, Attrs(attrs), {});
230}
231
232TVM_REGISTER_GLOBAL("relay._make.reinterpret").set_body_typed(MakeReinterpret);
233
234RELAY_REGISTER_OP("reinterpret")
235 .describe(R"code(Reinterpret the data into a new data type.
236)code" TVM_ADD_FILELINE)
237 .set_num_inputs(1)
238 .set_attrs_type<CastAttrs>()
239 .add_argument("data", "Tensor", "The input tensor.")
240 .set_support_level(3)
241 .add_type_rel("Reinterpret", CastRel)
242 .set_attr<FTVMCompute>("FTVMCompute", ReinterpretCompute)
243 .set_attr<TOpPattern>("TOpPattern", kElemWise)
244 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
245
246// relay.expand_dims
247TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
248
249bool ExpandDimsRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
250 const TypeReporter& reporter) {
251 // `types` contains: [data, result]
252 ICHECK_EQ(types.size(), 2);
253 const auto* data = types[0].as<TensorTypeNode>();
254 if (data == nullptr) {
255 ICHECK(types[0].as<IncompleteTypeNode>())
256 << "expand_dims: expect input type to be TensorType but get " << types[0];
257 return false;
258 }
259 const auto* param = attrs.as<ExpandDimsAttrs>();
260 const int ndim = static_cast<int>(data->shape.size());
261 const int axis = param->axis;
262 const int num_newaxis = param->num_newaxis;
263 ICHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`"
264 << ", but got num_newaxis = " << num_newaxis;
265 ICHECK(-ndim - 1 <= axis && axis <= ndim)
266 << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
267 << ", but got axis = " << axis << ", and data.ndim = " << ndim;
268 const int pivot = axis < 0 ? ndim + axis + 1 : axis;
269 std::vector<IndexExpr> oshape;
270 oshape.reserve(ndim + num_newaxis);
271 for (int i = 0; i < pivot; ++i) {
272 oshape.emplace_back(data->shape[i]);
273 }
274 for (int i = 0; i < num_newaxis; ++i) {
275 oshape.emplace_back(1);
276 }
277 for (int i = pivot; i < ndim; ++i) {
278 oshape.emplace_back(data->shape[i]);
279 }
280 reporter->Assign(types[1], TensorType(oshape, data->dtype));
281 return true;
282}
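
// Worked example (illustrative shapes): for data of shape (2, 3) with axis=-1 and
// num_newaxis=2, the pivot is ndim + axis + 1 = 2, so the new axes are appended at
// the end and the inferred output shape is (2, 3, 1, 1).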
283
284Array<te::Tensor> ExpandDimsCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
285 const Type& out_type) {
286 const ExpandDimsAttrs* param = attrs.as<ExpandDimsAttrs>();
287 ICHECK(param != nullptr);
288 return {topi::expand_dims(inputs[0], param->axis, param->num_newaxis)};
289}
290
291Expr MakeExpandDims(Expr data, int axis, int num_newaxis) {
292 auto attrs = make_object<ExpandDimsAttrs>();
293 attrs->axis = axis;
294 attrs->num_newaxis = num_newaxis;
295 static const Op& op = Op::Get("expand_dims");
296 return Call(op, {data}, Attrs(attrs), {});
297}
298
299TVM_REGISTER_GLOBAL("relay.op._make.expand_dims").set_body_typed(MakeExpandDims);
300
301RELAY_REGISTER_OP("expand_dims")
302 .describe(R"code(Insert `num_newaxis` axes at the position given by `axis`
303
304- **data**: The input data to the operator.
305
306)code" TVM_ADD_FILELINE)
307 .set_num_inputs(1)
308 .set_attrs_type<ExpandDimsAttrs>()
309 .add_argument("data", "Tensor", "The input tensor.")
310 .set_support_level(1)
311 .add_type_rel("ExpandDims", ExpandDimsRel)
312 .set_attr<FTVMCompute>("FTVMCompute", ExpandDimsCompute)
313 .set_attr<TOpPattern>("TOpPattern", kBroadcast)
314 .set_attr<TReshapeOp>("TReshapeOp", true);
315
316// relay.concatenate
317TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
318
319Array<te::Tensor> ConcatenateCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
320 const Type& out_type) {
321 const ConcatenateAttrs* param = attrs.as<ConcatenateAttrs>();
322 ICHECK(param != nullptr);
323 return {topi::concatenate(inputs, param->axis)};
324}
325
326Expr MakeConcatenate(Expr data, int axis) {
327 auto attrs = make_object<ConcatenateAttrs>();
328 attrs->axis = axis;
329 static const Op& op = Op::Get("concatenate");
330 return Call(op, {data}, Attrs(attrs), {});
331}
332
333TVM_REGISTER_GLOBAL("relay.op._make.concatenate").set_body_typed(MakeConcatenate);
334
335RELAY_REGISTER_OP("concatenate")
336 .describe(R"code(Concatenate the input tensors along the given axis.
337
338- **data** : A list of tensors.
339
340- **axis** : The axis along which the tensors are concatenated.
341
342)code" TVM_ADD_FILELINE)
343 .set_attrs_type<ConcatenateAttrs>()
344 .set_num_inputs(1)
345 .add_argument("data", "Tensor", "The input list of tensors.")
346 .set_support_level(1)
347 .add_type_rel("Concatenate", ConcatenateRel<ConcatenateAttrs>)
348 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConcatenateLayout)
349 .set_attr<TOpPattern>("TOpPattern", kInjective);
350
351TVM_REGISTER_NODE_TYPE(StackAttrs);
352
353bool StackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
354 const TypeReporter& reporter) {
355 // types: [data, result]
356 ICHECK_EQ(types.size(), 2);
357 const auto* tensor_tuple = types[0].as<TupleTypeNode>();
358 if (tensor_tuple == nullptr) {
    ICHECK(types[0].as<IncompleteTypeNode>())
        << "stack: expect input type to be TupleType but get " << types[0];
361 return false;
362 }
363 for (auto field : tensor_tuple->fields) {
364 if (field.as<IncompleteTypeNode>()) {
365 return false;
366 }
367 }
368 const auto* param = attrs.as<StackAttrs>();
369 const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
370 const int ndim = static_cast<int>(first->shape.size());
371
372 // Sanity check: axis
373 int axis = param->axis.IntValue();
374 ICHECK(-(ndim + 1) <= axis && axis < ndim + 1)
375 << "stack only accepts `axis` in [-(ndim+1), ndim+1)"
376 << ", but got axis = " << axis << ", and ndim = " << ndim;
377 axis = axis < 0 ? ndim + axis + 1 : axis;
378
379 // Sanity check: ndim and dtype.
380 const DataType dtype = first->dtype;
381 for (const Type& ele : tensor_tuple->fields) {
382 const auto& e = Downcast<TensorType>(ele);
383 int e_ndim = static_cast<int>(e->shape.size());
384 const DataType& e_dtype = e->dtype;
385 ICHECK_EQ(e_ndim, ndim) << "relay.stack requires all tensors have the same ndim";
386 ICHECK_EQ(e_dtype, dtype) << "relay.stack requires all tensors have the same dtype";
387 for (size_t j = 0; j < first->shape.size(); ++j) {
388 if (j == static_cast<size_t>(axis)) continue;
389 if (first->shape[j].as<AnyNode>() || e->shape[j].as<AnyNode>() ||
390 reporter->AssertEQ(first->shape[j], e->shape[j]))
391 continue;
392 throw CompileError(
393 "relay.stack requires all tensors have the same shape "
394 "on non-stacking axes");
395 }
396 }
397
398 // Calculate shape
399 std::vector<IndexExpr> oshape;
400 oshape.reserve(ndim + 1);
401 const int stack_dim = static_cast<int>(tensor_tuple->fields.size());
402 for (int i = 0; i < axis; ++i) {
403 oshape.emplace_back(first->shape[i]);
404 }
405 oshape.emplace_back(stack_dim);
406 for (int i = axis; i < ndim; ++i) {
407 oshape.emplace_back(first->shape[i]);
408 }
409 reporter->Assign(types[1], TensorType(oshape, dtype));
410 return true;
411}
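
// Worked example (illustrative shapes): stacking a tuple of three tensors, each of
// shape (2, 3), along axis=1 inserts a new dimension whose length is the tuple size,
// giving an output shape of (2, 3, 3).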
412
413Array<te::Tensor> StackCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
414 const Type& out_type) {
415 const StackAttrs* param = attrs.as<StackAttrs>();
416 ICHECK(param != nullptr);
417 return {topi::stack(inputs, param->axis.IntValue())};
418}
419
420Expr MakeStack(Expr data, int axis) {
421 auto attrs = make_object<StackAttrs>();
422 attrs->axis = axis;
423 static const Op& op = Op::Get("stack");
424 return Call(op, {data}, Attrs(attrs), {});
425}
426
427TVM_REGISTER_GLOBAL("relay.op._make.stack").set_body_typed(MakeStack);
428
429RELAY_REGISTER_OP("stack")
430 .describe(R"code(Stack the input tensors along the given axis.
431
432- **data** : A list of tensors.
433
434- **axis** : The axis along which the tensors are stacked.
435
436)code" TVM_ADD_FILELINE)
437 .set_attrs_type<StackAttrs>()
438 .set_num_inputs(1)
439 .add_argument("data", "Tensor", "The input list of tensors.")
440 .set_support_level(3)
441 .add_type_rel("Stack", StackRel)
442 .set_attr<FTVMCompute>("FTVMCompute", StackCompute)
443 .set_attr<TOpPattern>("TOpPattern", kInjective);
444
445/* relay.transpose */
446TVM_REGISTER_NODE_TYPE(TransposeAttrs);
447
448bool TransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
449 const TypeReporter& reporter) {
450 // types: [data, result]
451 ICHECK_EQ(types.size(), 2);
452 const auto* data = types[0].as<TensorTypeNode>();
453 if (data == nullptr) {
454 ICHECK(types[0].as<IncompleteTypeNode>())
455 << "transpose: expect input type to be TensorType but get " << types[0];
456 return false;
457 }
458 const auto* param = attrs.as<TransposeAttrs>();
459 const int ndim = data->shape.size();
460 const Array<Integer>& axes = param->axes;
461 // check dimension match
462 ICHECK(!axes.defined() || static_cast<int>(axes.size()) == ndim)
463 << "Dimension mismatch: axes has " << axes.size() << " elements"
464 << ", but data.ndim = " << ndim;
465 // construct int_axes
466 std::vector<int> int_axes;
467 int_axes.reserve(ndim);
  // An undefined `axes` means None was passed, i.e. reverse the axes.
469 if (!axes.defined()) {
470 for (int i = ndim - 1; i >= 0; --i) {
471 int_axes.push_back(i);
472 }
473 } else {
474 std::vector<int> axis_used(ndim, 0);
475 for (const Integer& e : axes) {
476 int64_t axis = e.IntValue();
477 // sanity check for axis and ndim
478 ICHECK(-ndim <= axis && axis < ndim)
479 << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)"
480 << ", but got axis = " << axis << ", and data.ndim = " << ndim;
481 axis = axis < 0 ? axis + ndim : axis;
482 // sanity check for duplication
483 ICHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis;
484 axis_used[axis] = 1;
485 int_axes.push_back(static_cast<int>(axis));
486 }
487 }
488 std::vector<IndexExpr> oshape;
489 oshape.reserve(ndim);
490 for (int axis : int_axes) {
491 oshape.push_back(data->shape[axis]);
492 }
493 reporter->Assign(types[1], TensorType(oshape, data->dtype));
494 return true;
495}
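
// Worked example (illustrative shapes): for data of shape (2, 3, 4), an undefined
// `axes` reverses the dimensions and yields (4, 3, 2); axes=(1, -1, 0) normalizes
// to (1, 2, 0) and yields (3, 4, 2).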
496
497InferCorrectLayoutOutput TransposeInferCorrectLayout(const Attrs& attrs,
498 const Array<Layout>& new_in_layouts,
499 const Array<Layout>& old_in_layouts,
500 const Array<tvm::relay::Type>& old_in_types) {
501 const auto* attrs_ptr = attrs.as<TransposeAttrs>();
502 ICHECK(attrs_ptr);
503 ObjectPtr<TransposeAttrs> params = make_object<TransposeAttrs>(*attrs_ptr);
504
505 std::string in_layout_str = "";
506 std::string out_layout_str = "";
507
508 // Infer the input layout string and update the axes.
509 if (old_in_layouts.defined() && old_in_layouts[0].defined()) {
510 ICHECK_EQ(old_in_layouts.size(), 1);
511 auto old_layout = old_in_layouts[0];
512 Array<Integer> old_axes = params->axes;
513
514 // Deal with default axes and negative axes.
515 if (!old_axes.defined() || old_axes.size() == 0) {
516 for (int i = old_layout.ndim() - 1; i >= 0; --i) {
517 old_axes.push_back(i);
518 }
519 }
520 for (size_t i = 0; i < old_axes.size(); ++i) {
521 int axis = static_cast<int>(old_axes[i]->value);
522 if (axis < 0) {
523 int pos_axis = static_cast<int>(old_layout.ndim()) + axis;
524 old_axes.Set(i, pos_axis);
525 }
526 }
527
528 if (new_in_layouts.defined() && new_in_layouts[0].defined()) {
529 ICHECK_EQ(new_in_layouts.size(), 1);
530 auto new_layout = new_in_layouts[0];
531
532 // Update the axes based on the new layout.
533 Array<Integer> new_axes = Array<Integer>();
534 for (auto axis : old_axes) {
535 auto new_axis = new_layout.IndexOf(old_layout[axis->value]);
536 if (new_axis == -1) { // Cannot find the target axis in the new layout.
537 new_axes.clear();
538 break;
539 }
540 new_axes.push_back(new_axis);
541 }
542 if (new_axes.defined() && new_axes.size() == new_layout.ndim()) {
543 params->axes = std::move(new_axes);
544 in_layout_str = new_layout.name();
545 }
546 }
547
548 // If the input layout string cannot be determined, propagate the old layout.
549 if (in_layout_str == "") {
550 params->axes = std::move(old_axes);
551 in_layout_str = old_layout.name();
552 }
553 }
554
555 // Infer the output layout string based on the input layout and the axes.
556 Attrs new_attrs(params);
557 if (in_layout_str != "") {
558 for (auto axis : params->axes) {
559 ICHECK_LT(axis->value, in_layout_str.length());
560 out_layout_str += in_layout_str[axis->value];
561 }
562 try {
563 return InferCorrectLayoutOutput({Layout(in_layout_str)}, {Layout(out_layout_str)}, new_attrs);
564 } catch (const tvm::Error& e) {
565 // If the layout string is invalid for any reason, give up.
566 return InferCorrectLayoutOutput({Layout::Undef()}, {Layout::Undef()}, attrs);
567 }
568 }
569 return InferCorrectLayoutOutput({Layout::Undef()}, {Layout::Undef()}, attrs);
570}
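
// Illustrative layout example (assumed layouts): with an old input layout of NCHW and
// axes=(0, 2, 3, 1), the output layout string is built as "NHWC". If the input is
// instead converted to NHWC, the axes are rewritten to (0, 1, 2, 3) and both the
// inferred input and output layouts become NHWC.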
571
572Array<te::Tensor> TransposeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
573 const Type& out_type) {
574 const auto* param = attrs.as<TransposeAttrs>();
575 ICHECK(param != nullptr);
576 return Array<te::Tensor>{topi::transpose(inputs[0], param->axes)};
577}
578
579Expr MakeTranspose(Expr data, Array<Integer> axes) {
580 auto attrs = make_object<TransposeAttrs>();
581 attrs->axes = std::move(axes);
582 static const Op& op = Op::Get("transpose");
583 return Call(op, {data}, Attrs(attrs), {});
584}
585
586TVM_REGISTER_GLOBAL("relay.op._make.transpose").set_body_typed(MakeTranspose);
587
588RELAY_REGISTER_OP("transpose")
589 .describe(R"code(Permutes the dimensions of an array.
590
591- **data**: The input data to the operator.
592
593- **axes**: The target axes order, reverse order if not specified.
594
595)code" TVM_ADD_FILELINE)
596 .set_num_inputs(1)
597 .set_attrs_type<TransposeAttrs>()
598 .add_argument("data", "Tensor", "The input tensor.")
599 .set_support_level(3)
600 .add_type_rel("Transpose", TransposeRel)
601 .set_attr<FTVMCompute>("FTVMCompute", TransposeCompute)
602 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", TransposeInferCorrectLayout)
603 .set_attr<TOpPattern>("TOpPattern", kInjective);
604
605/* relay.reshape */
606TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
607TVM_REGISTER_NODE_TYPE(ReshapeLikeAttrs);
608
609Array<IndexExpr> InferNewShape(const Array<IndexExpr>& data_shape, const Attrs& attrs,
610 bool reverse) {
611 const auto* param = attrs.as<ReshapeAttrs>();
612 Array<IndexExpr> oshape;
613 Array<IndexExpr> ishape;
614 Array<Integer> newshape;
615
616 if (reverse) {
617 ishape.Assign(data_shape.rbegin(), data_shape.rend());
618 newshape.Assign(param->newshape.rbegin(), param->newshape.rend());
619 } else {
620 ishape = data_shape;
621 newshape = param->newshape;
622 }
623
624 bool allowzero = param->allowzero;
625
626 std::unordered_set<size_t> used_input_dims;
627 std::unordered_set<size_t> used_output_dims;
628 size_t src_idx = 0;
629 int infer_idx = -1;
630
631 for (size_t i = 0; i < newshape.size(); ++i) {
632 int svalue = newshape[i]->value;
633 // special flag handling for shape inference.
634 if (svalue > 0) {
635 oshape.push_back(newshape[i]);
636 ++src_idx;
637 } else if (svalue == 0) {
638 if (allowzero) {
        // with allowzero, a 0 in newshape is kept as-is, giving an explicit zero-sized dimension
640 oshape.push_back(newshape[i]);
641 ++src_idx;
642 } else {
        // 0 means to copy the dimension at the equivalent position in the data tensor
644 ICHECK_LT(src_idx, ishape.size());
645 used_input_dims.insert(src_idx);
646 used_output_dims.insert(oshape.size());
647 oshape.push_back(ishape[src_idx++]);
648 }
649 } else if (svalue == -1) {
650 // inference based on rest
      ICHECK_LT(infer_idx, 0) << "At most one dimension can be inferred (-1)";
652 infer_idx = i;
653 oshape.push_back(1);
654 ++src_idx;
655 } else if (svalue == -2) {
656 // copy all remaining dims from source
657 while (src_idx < ishape.size()) {
658 used_input_dims.insert(src_idx);
659 used_output_dims.insert(oshape.size());
660 oshape.push_back(ishape[src_idx++]);
661 }
662 } else if (svalue == -3) {
663 // merge two dims from source
664 ICHECK_LT(src_idx + 1, ishape.size());
665 used_input_dims.insert(src_idx);
666 IndexExpr d1 = ishape[src_idx++];
667 used_input_dims.insert(src_idx);
668 IndexExpr d2 = ishape[src_idx++];
669 used_output_dims.insert(oshape.size());
670 if (d1.as<AnyNode>() || d2.as<AnyNode>()) {
671 oshape.push_back(Any());
672 } else {
673 oshape.push_back(d1 * d2);
674 }
675 } else if (svalue == -4) {
676 // split the source dim s into two dims
677 // read the left dim and then the right dim (either can be -1)
678 ICHECK_LT(i + 2, newshape.size());
679 ICHECK_LT(src_idx, ishape.size());
680 used_input_dims.insert(src_idx);
681 IndexExpr d0 = ishape[src_idx++];
682 Integer d1 = newshape[++i];
683 Integer d2 = newshape[++i];
684 if (d1->value == -1) {
685 ICHECK_NE(d2->value, -1) << "Split dims cannot both be -1.";
686 used_output_dims.insert(oshape.size());
687 if (d0.as<AnyNode>()) {
688 oshape.push_back(Any());
689 } else {
690 oshape.push_back(indexdiv(d0, d2));
691 }
692 used_output_dims.insert(oshape.size());
693 oshape.push_back(d2);
694 } else {
695 used_output_dims.insert(oshape.size());
696 oshape.push_back(d1);
697 used_output_dims.insert(oshape.size());
698 if (d2->value == -1) {
699 if (d0.as<AnyNode>()) {
700 oshape.push_back(Any());
701 } else {
702 oshape.push_back(indexdiv(d0, d1));
703 }
704 } else {
705 oshape.push_back(d2);
706 }
707 }
708 } else {
709 LOG(FATAL) << "Unsupported special value: " << svalue;
710 }
711 }
712
713 if (infer_idx >= 0) {
714 IndexExpr infer_dim = 1;
715 for (size_t i = 0; i < ishape.size(); ++i) {
716 if (used_input_dims.count(i) != 0) {
717 continue;
718 }
719 if (ishape[i].as<AnyNode>()) {
720 infer_dim = Any();
721 break;
722 }
723 infer_dim *= ishape[i];
724 }
725 if (!infer_dim.as<AnyNode>()) {
726 for (size_t i = 0; i < oshape.size(); ++i) {
727 if (used_output_dims.count(i) != 0) {
728 continue;
729 }
730 if (oshape[i].as<AnyNode>()) {
731 infer_dim = Any();
732 break;
733 }
734 infer_dim = indexdiv(infer_dim, oshape[i]);
735 }
736 }
737 arith::Analyzer ana;
738 infer_dim = ana.Simplify(infer_dim);
739 oshape.Set(infer_idx, infer_dim);
740 }
741
742 return oshape;
743}
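
// Worked example (illustrative shapes): for data_shape=(2, 3, 4) and newshape=(0, -1),
// the 0 copies the first input dimension (2) and the -1 slot is filled with the
// remaining elements, 3 * 4 = 12, giving an output shape of (2, 12).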
744
745bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
746 const TypeReporter& reporter) {
747 // types: [data, result]
748 ICHECK_EQ(types.size(), 2);
749 const auto* data = types[0].as<TensorTypeNode>();
750 if (data == nullptr) {
751 ICHECK(types[0].as<IncompleteTypeNode>())
752 << "reshape: expect input type to be TensorType but get " << types[0];
753 return false;
754 }
755
756 const auto& oshape = InferNewShape(data->shape, attrs, false);
757
  // Verify that the product of the dimensions in the output shape equals the
  // product of the dimensions in the input shape.
760 Array<IndexExpr> data_shape;
761 data_shape = data->shape;
762
763 bool found_dynamic = false;
764 int64_t oshape_sum = 1;
765 for (auto& x : oshape) {
    // Check if we have a dynamic shape. If we do, we can't verify if the
    // reshape is valid. Dynamic shapes are marked by Any, but can also
    // occur from SizeVars. In the case of SizeVar, the shape expression can
    // be an AST. We can't easily tell whether we have an AST because of a SizeVar
    // or for some other reason, so our check for a dynamic shape is simply whether
    // we can convert the dimension to an integer or not.
772 if (!x->IsInstance<tvm::Integer::ContainerType>()) {
773 found_dynamic = true;
774 break;
775 }
776 oshape_sum *= Downcast<tvm::Integer>(x)->value;
777 }
778 int64_t data_shape_sum = 1;
779 for (auto& x : data_shape) {
780 if (!x->IsInstance<tvm::Integer::ContainerType>()) {
781 found_dynamic = true;
782 break;
783 }
784 data_shape_sum *= Downcast<tvm::Integer>(x)->value;
785 }
786 if (!found_dynamic && oshape_sum != data_shape_sum) {
787 std::ostringstream oshape_str, data_shape_str;
788 for (auto iter = oshape.begin(); iter != oshape.end(); iter++) {
789 oshape_str << (iter != oshape.begin() ? "," : "") << *iter;
790 }
791 for (auto iter = data_shape.begin(); iter != data_shape.end(); iter++) {
792 data_shape_str << (iter != data_shape.begin() ? "," : "") << *iter;
793 }
794 ICHECK_EQ(oshape_sum, data_shape_sum)
795 << "Input tensor shape(" << data_shape_str.str() << ") and reshaped shape("
796 << oshape_str.str() << ") are not compatible!";
797 }
798
799 reporter->Assign(types[1], TensorType(oshape, data->dtype));
800 return true;
801}
802
803bool ReverseReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
804 const TypeReporter& reporter) {
805 // types: [data, result]
806 ICHECK_EQ(types.size(), 2);
807 const auto* data = types[0].as<TensorTypeNode>();
808 if (data == nullptr) {
809 ICHECK(types[0].as<IncompleteTypeNode>())
810 << "reshape: expect input type to be TensorType but get " << types[0];
811 return false;
812 }
813
814 const auto& oshape = InferNewShape(data->shape, attrs, true);
815
  // Verify that the product of the dimensions in the output shape equals the
  // product of the dimensions in the input shape.
818 Array<IndexExpr> data_shape;
819 data_shape.Assign(data->shape.rbegin(), data->shape.rend());
820
821 bool found_dynamic = false;
822 int64_t oshape_sum = 1;
823 for (auto& x : oshape) {
    // Check if we have a dynamic shape. If we do, we can't verify if the
    // reshape is valid. Dynamic shapes are marked by Any, but can also
    // occur from SizeVars. In the case of SizeVar, the shape expression can
    // be an AST. We can't easily tell whether we have an AST because of a SizeVar
    // or for some other reason, so our check for a dynamic shape is simply whether
    // we can convert the dimension to an integer or not.
830 if (!x->IsInstance<tvm::Integer::ContainerType>()) {
831 found_dynamic = true;
832 break;
833 }
834 oshape_sum *= Downcast<tvm::Integer>(x)->value;
835 }
836 int64_t data_shape_sum = 1;
837 for (auto& x : data_shape) {
838 if (!x->IsInstance<tvm::Integer::ContainerType>()) {
839 found_dynamic = true;
840 break;
841 }
842 data_shape_sum *= Downcast<tvm::Integer>(x)->value;
843 }
844 if (!found_dynamic) {
845 ICHECK_EQ(oshape_sum, data_shape_sum)
846 << "Input tensor shape and reshaped shape are not compatible";
847 }
848
849 reporter->Assign(types[1],
850 TensorType(Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
851 return true;
852}
853
854Array<PrimExpr> infer_reshape_like(const Array<PrimExpr>& lhs_shape,
855 const Array<PrimExpr>& rhs_shape, const Attrs& attrs) {
856 const auto* like_attrs = attrs.as<ReshapeLikeAttrs>();
857 CHECK(!like_attrs->lhs_end.defined() || like_attrs->lhs_end.as<IntImmNode>())
858 << "lhs_end must be a concrete integer or None";
859 CHECK(!like_attrs->rhs_end.defined() || like_attrs->rhs_end.as<IntImmNode>())
860 << "rhs_end must be a concrete integer or None";
861
862 int64_t lhs_shape_size = static_cast<int64_t>(lhs_shape.size());
863 int64_t rhs_shape_size = static_cast<int64_t>(rhs_shape.size());
864 int64_t lhs_begin = static_cast<int64_t>(like_attrs->lhs_begin);
865 int64_t lhs_end =
866 like_attrs->lhs_end.defined() ? like_attrs->lhs_end.as<IntImmNode>()->value : lhs_shape_size;
867 int64_t rhs_begin = static_cast<int64_t>(like_attrs->rhs_begin);
868 int64_t rhs_end =
869 like_attrs->rhs_end.defined() ? like_attrs->rhs_end.as<IntImmNode>()->value : rhs_shape_size;
870
871 // handle negative axes
872 lhs_begin = lhs_begin < 0 ? lhs_begin + lhs_shape_size : lhs_begin;
873 lhs_end = lhs_end < 0 ? lhs_end + lhs_shape_size : lhs_end;
874 rhs_begin = rhs_begin < 0 ? rhs_begin + rhs_shape_size : rhs_begin;
875 rhs_end = rhs_end < 0 ? rhs_end + rhs_shape_size : rhs_end;
876
877 Array<PrimExpr> shape_like;
878 for (auto i = 0; i < lhs_begin; i++) {
879 shape_like.push_back(lhs_shape[i]);
880 }
881 for (auto i = rhs_begin; i < rhs_end; i++) {
882 shape_like.push_back(rhs_shape[i]);
883 }
884 for (auto i = lhs_end; i < lhs_shape_size; i++) {
885 shape_like.push_back(lhs_shape[i]);
886 }
887 return shape_like;
888}
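
// Worked example (illustrative shapes, mirroring the reshape_like docstring below): for
// lhs_shape=(1, 2, 3, 4), rhs_shape=(6, 2, 2, 3), lhs_begin=1 and rhs_end=3 (lhs_end
// defaulting to 4, rhs_begin to 0), the result keeps lhs dim 0, takes rhs dims [0, 3)
// and appends nothing from lhs, giving (1, 6, 2, 2).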
889
890Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
891 const Type& out_type) {
892 // Quick path for reshape_like
893 if (!attrs.as<ReshapeAttrs>()) {
894 ICHECK(attrs.as<ReshapeLikeAttrs>() != nullptr);
895 auto shape_like = infer_reshape_like(inputs[0]->shape, inputs[1]->shape, attrs);
896 return {topi::reshape(inputs[0], shape_like)};
897 }
898
899 const auto* out_ttype = out_type.as<TensorTypeNode>();
900 ICHECK(out_ttype != nullptr);
901 Array<IndexExpr> newshape;
902 bool newshape_has_any = false;
903 for (auto val : out_ttype->shape) {
904 if (val->IsInstance<tir::AnyNode>() || val->IsInstance<tir::VarNode>()) {
905 newshape_has_any = true;
906 break;
907 } else {
908 newshape.push_back(val);
909 }
910 }
911
912 if (newshape_has_any) {
913 newshape = InferNewShape(inputs[0]->shape, attrs, false);
914 }
915 return {topi::reshape(inputs[0], newshape)};
916}
917
918Expr MakeReshape(Expr data, Array<Integer> newshape, bool allowzero) {
919 auto attrs = make_object<ReshapeAttrs>();
920 attrs->newshape = std::move(newshape);
921 attrs->allowzero = allowzero;
922 static const Op& op = Op::Get("reshape");
923 return Call(op, {data}, Attrs(attrs), {});
924}
925
926TVM_REGISTER_GLOBAL("relay.op._make.reshape").set_body_typed(MakeReshape);
927
928RELAY_REGISTER_OP("reshape")
929 .describe(R"code(Reshapes the input array.
930
To give the user more convenience without doing manual shape inference,
some dimensions of the shape can take special values from the set {0, -1, -2, -3, -4}.
The significance of each is explained below:
936
937- ``0`` copy this dimension from the input to the output shape.
938
939Example::
940
941- data.shape = (2,3,4), newshape = (4,0,2), result.shape = (4,3,2)
942- data.shape = (2,3,4), newshape = (2,0,0), result.shape = (2,3,4)
943
- ``-1`` infers the dimension of the output shape by using the remainder of the input dimensions,
keeping the size of the new array the same as that of the input array.
At most one dimension of the shape can be -1.
947
948Example::
949
950- data.shape = (2,3,4), newshape = (6,1,-1), result.shape = (6,1,4)
951- data.shape = (2,3,4), newshape = (3,-1,8), result.shape = (3,1,8)
952- data.shape = (2,3,4), newshape = (-1,), result.shape = (24,)
953
- ``-2`` copy all of the remaining input dimensions to the output shape.
955
956Example::
957
958- data.shape = (2,3,4), newshape = (-2,), result.shape = (2,3,4)
959- data.shape = (2,3,4), newshape = (2,-2), result.shape = (2,3,4)
960- data.shape = (2,3,4), newshape = (-2,1,1), result.shape = (2,3,4,1,1)
961
962- ``-3`` use the product of two consecutive dimensions of the input shape as the output dimension.
963
964Example::
965
966- data.shape = (2,3,4), newshape = (-3,4), result.shape = (6,4)
967- data.shape = (2,3,4,5), newshape = (-3,-3), result.shape = (6,20)
968- data.shape = (2,3,4), newshape = (0,-3), result.shape = (2,12)
969- data.shape = (2,3,4), newshape = (-3,-2), result.shape = (6,4)
970
- ``-4`` split one dimension of the input into the two dimensions passed subsequent to -4 in the new shape (either of which can be -1).
972
973Example::
974
- data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape = (1,2,3,4)
976- data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4)
977
978)code" TVM_ADD_FILELINE)
979 .set_num_inputs(1)
980 .set_attrs_type<ReshapeAttrs>()
981 .add_argument("data", "Tensor", "The input tensor.")
982 .set_support_level(3)
983 .add_type_rel("Reshape", ReshapeRel)
984 .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
985 .set_attr<TOpPattern>("TOpPattern", kInjective)
986 .set_attr<TReshapeOp>("TReshapeOp", true);
987
988/*!
989 * \brief ReshapeLikeRel User defined type constraint function.
 * \param types The types of the inputs and the output.
 * \param num_inputs Number of input types in the args.
991 * \param attrs The additional attributes of the operator.
992 * \param reporter The reporter to report solution to.
993 * \return False if the relation has not been resolved, it might be resolved later.
994 * True if this relation has been resolved.
995 */
996bool ReshapeLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
997 const TypeReporter& reporter) {
998 ICHECK(attrs.as<ReshapeLikeAttrs>() != nullptr);
999 ICHECK_EQ(types.size(), 3);
1000 const auto* data = types[0].as<TensorTypeNode>();
1001 if (data == nullptr) {
1002 return false;
1003 }
1004 const auto* reshape_like = types[1].as<TensorTypeNode>();
1005 if (reshape_like == nullptr) {
1006 return false;
1007 }
1008 auto shape_like = infer_reshape_like(data->shape, reshape_like->shape, attrs);
  // Only perform this check when the input data has a static shape.
1010 bool is_static_shape = true;
1011 for (size_t i = 0; i < data->shape.size(); ++i) {
1012 if (!data->shape[i].as<IntImmNode>()) {
1013 is_static_shape = false;
1014 break;
1015 }
1016 }
1017 auto output_type = TensorType(shape_like, data->dtype);
1018 if (is_static_shape) {
1019 ICHECK(reporter->AssertEQ(data->Size(), output_type->Size()))
1020 << "Reshape inputs size should be compatible, "
1021 << "but found data_shape " << data->shape << " not same as output_shape "
1022 << output_type->shape;
1023 }
1024 reporter->Assign(types[2], output_type);
1025 return true;
1026}
1027
1028Expr MakeReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin,
1029 Integer rhs_end) {
1030 auto attrs = make_object<ReshapeLikeAttrs>();
1031 attrs->lhs_begin = std::move(lhs_begin);
1032 attrs->lhs_end = std::move(lhs_end);
1033 attrs->rhs_begin = std::move(rhs_begin);
1034 attrs->rhs_end = std::move(rhs_end);
1035 static const Op& op = Op::Get("reshape_like");
1036 return Call(op, {lhs, rhs}, Attrs(attrs), {});
1037}
1038
1039TVM_REGISTER_GLOBAL("relay.op._make.reshape_like").set_body_typed(MakeReshapeLike);
1040
1041RELAY_REGISTER_OP("reshape_like")
1042 .describe(R"code(Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, the `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.
.. note::
    The sizes of both arrays should be compatible.
1047Example::
1048
1049 data.shape == (1, 2, 3, 4)
1050 shape_like.shape == (6, 2, 2, 3)
1051
1052 ret = reshape_like(data, shape_like, lhs_begin=1, rhs_end=3)
1053 ret.shape == (1, 6, 2, 2)
1054)code" TVM_ADD_FILELINE)
1055 .set_attrs_type<ReshapeLikeAttrs>()
1056 .set_num_inputs(2)
1057 .add_argument("data", "Tensor", "The input tensor.")
1058 .add_argument("shape_like", "Tensor", "Shape tensor.")
1059 .set_support_level(3)
1060 .add_type_rel("ReshapeLike", ReshapeLikeRel)
1061 .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
1062 .set_attr<TOpPattern>("TOpPattern", kInjective);
1063
1064// ArgWhere
1065bool ArgWhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1066 const TypeReporter& reporter) {
1067 ICHECK_EQ(num_inputs, 1);
1068 auto tt = types[0].as<TensorTypeNode>();
1069
1070 if (tt == nullptr) {
1071 return false;
1072 }
1073
1074 const auto& input_shape = tt->shape;
1075 const auto& input_rank = input_shape.size();
1076 std::vector<IndexExpr> result_shape;
1077 result_shape.push_back(Any());
1078 result_shape.push_back(IntImm(DataType::Int(32), input_rank));
1079 reporter->Assign(types[1], TensorType(result_shape, DataType::Int(32)));
1080 return true;
1081}
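
// Worked example (illustrative shape): for a condition tensor of shape (3, 4), the
// inferred output type is Tensor[(?, 2), int32]; at runtime each row of the result
// holds the (row, col) index of one non-zero element.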
1082
1083TVM_REGISTER_GLOBAL("relay.op._make.argwhere").set_body_typed([](Expr data) {
1084 static const Op& op = Op::Get("argwhere");
1085 return Call(op, {data}, Attrs(), {});
1086});
1087
1088RELAY_REGISTER_OP("argwhere")
1089 .describe(R"doc(Find the indices of elements of a tensor that are
1090non-zero)doc" TVM_ADD_FILELINE)
1091 .set_num_inputs(1)
1092 .add_argument("condition", "Tensor", "The input condition tensor.")
1093 .add_type_rel("ArgWhere", ArgWhereRel)
1094 .set_attr<TOpIsStateful>("TOpIsStateful", false)
1095 .set_attr<TOpPattern>("TOpPattern", kOpaque)
1096 .set_support_level(10);
1097
1098// Scatter
1099TVM_REGISTER_NODE_TYPE(ScatterAttrs);
1100
1101// Scatter
1102bool ScatterRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1103 const TypeReporter& reporter) {
1104 ICHECK_EQ(num_inputs, 3);
1105 ICHECK_EQ(types.size(), 4);
1106 auto data = types[0].as<TensorTypeNode>();
1107 if (data == nullptr) {
1108 return false;
1109 }
1110 auto indices = types[1].as<TensorTypeNode>();
1111 if (indices == nullptr) {
1112 return false;
1113 }
1114 auto updates = types[2].as<TensorTypeNode>();
1115 if (updates == nullptr) {
1116 return false;
1117 }
1118 ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
1119 << "indices of scatter must be tensor of integer";
1120 const auto param = attrs.as<ScatterAttrs>();
1121 ICHECK(param != nullptr);
1122 reporter->Assign(types[3], TensorType(data->shape, data->dtype));
1123 return true;
1124}
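
// Illustrative 1-D example (assumed values): with data=[0, 0, 0, 0], indices=[2, 0],
// updates=[9, 8] and axis=0, the result copies data and writes updates at the indexed
// positions, giving [8, 0, 9, 0]; the output type always matches data, as assigned above.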
1125
1126TVM_REGISTER_GLOBAL("relay.op._make.scatter")
1127 .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) {
1128 auto attrs = make_object<ScatterAttrs>();
1129 attrs->axis = std::move(axis);
1130 static const Op& op = Op::Get("scatter");
1131 return Call(op, {data, indices, updates}, Attrs(attrs), {});
1132 });
1133
1134RELAY_REGISTER_OP("scatter")
1135 .describe(
1136 R"doc(Update data at positions defined by indices with values in updates)doc" TVM_ADD_FILELINE)
1137 .set_num_inputs(3)
1138 .add_argument("data", "Tensor", "The input data tensor.")
1139 .add_argument("indices", "Tensor", "The indices location tensor.")
1140 .add_argument("updates", "Tensor", "The values to update the input with.")
1141 .add_type_rel("Scatter", ScatterRel)
1142 .set_attr<TOpIsStateful>("TOpIsStateful", false)
1143 .set_attr<TOpPattern>("TOpPattern", kOpaque)
1144 .set_support_level(10);
1145
1146// Scatter_add
1147TVM_REGISTER_NODE_TYPE(ScatterAddAttrs);
1148
1149// Scatter Add
1150bool ScatterAddRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1151 const TypeReporter& reporter) {
1152 ICHECK_EQ(num_inputs, 3);
1153 ICHECK_EQ(types.size(), 4);
1154 auto data = types[0].as<TensorTypeNode>();
1155 if (data == nullptr) {
1156 return false;
1157 }
1158 auto indices = types[1].as<TensorTypeNode>();
1159 if (indices == nullptr) {
1160 return false;
1161 }
1162 auto updates = types[2].as<TensorTypeNode>();
1163 if (updates == nullptr) {
1164 return false;
1165 }
1166 ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
1167 << "indices of scatter_add must be tensor of integer";
1168 const auto param = attrs.as<ScatterAddAttrs>();
1169 ICHECK(param != nullptr);
1170 reporter->Assign(types[3], TensorType(data->shape, data->dtype));
1171 return true;
1172}
1173
1174TVM_REGISTER_GLOBAL("relay.op._make.scatter_add")
1175 .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) {
1176 auto attrs = make_object<ScatterAddAttrs>();
1177 attrs->axis = std::move(axis);
1178 static const Op& op = Op::Get("scatter_add");
1179 return Call(op, {data, indices, updates}, Attrs(attrs), {});
1180 });
1181
1182RELAY_REGISTER_OP("scatter_add")
1183 .describe(
1184 R"doc(Update data by adding values in updates at positions defined by indices)doc" TVM_ADD_FILELINE)
1185 .set_num_inputs(3)
1186 .add_argument("data", "Tensor", "The input data tensor.")
1187 .add_argument("indices", "Tensor", "The indices location tensor.")
1188 .add_argument("updates", "Tensor", "The values to update the input with.")
1189 .add_type_rel("ScatterAdd", ScatterAddRel)
1190 .set_attr<TOpIsStateful>("TOpIsStateful", false)
1191 .set_attr<TOpPattern>("TOpPattern", kOpaque)
1192 .set_support_level(10);
1193
1194// scatter_nd operator
1195TVM_REGISTER_NODE_TYPE(ScatterNDAttrs);
1196
1197bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1198 const TypeReporter& reporter) {
1199 // `types` contains: [data, indices, updates, result]
1200 ICHECK_EQ(types.size(), 4);
1201 const auto* data = types[0].as<TensorTypeNode>();
1202 const auto* indices = types[1].as<TensorTypeNode>();
1203 const auto* updates = types[2].as<TensorTypeNode>();
1204 if (data == nullptr) {
1205 ICHECK(types[0].as<IncompleteTypeNode>())
1206 << "ScatterND: expect input data type to be TensorType but got " << types[0];
1207 return false;
1208 }
1209 if (indices == nullptr) {
1210 ICHECK(types[1].as<IncompleteTypeNode>())
1211 << "ScatterND: expect indices type to be TensorType but got " << types[1];
1212 return false;
1213 }
1214 if (updates == nullptr) {
1215 ICHECK(types[2].as<IncompleteTypeNode>())
1216 << "ScatterND: expect updates type to be TensorType but got " << types[2];
1217 return false;
1218 }
1219 ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
1220 << "ScatterND: indices must be a tensor of integers.";
1221
1222 const auto out_shape = data->shape;
1223 const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
1224 ICHECK(mdim) << "ScatterND needs a static shape for the first axis of indices, got "
1225 << indices->shape;
1226 const size_t kdim = indices->shape.size() - 1;
1227 const size_t ndim = out_shape.size();
  ICHECK_LE(size_t(mdim->value), ndim)
      << "ScatterND: Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and indices "
         "with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to N.";
  // Indices: (M, Y_0, .. Y_{K-1})  updates: (Y_0, .. Y_{K-1}, ...), verify the Y dims.
1232 for (size_t i = 0; i < kdim; i++) {
1233 reporter->AssertEQ(indices->shape[i + 1], updates->shape[i]);
1234 }
1235
1236 std::vector<IndexExpr> oshape;
1237 for (auto& x : out_shape) {
1238 oshape.push_back(x);
1239 }
1240
  // updates: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1})  out: (X_0, .. X_{N-1}), verify X_M to X_{N-1}
  for (size_t i = mdim->value; i < ndim; i++) {
    reporter->AssertEQ(updates->shape[i - mdim->value + kdim], oshape[i]);
1244 }
1245
1246 reporter->Assign(types[3], TensorType(data->shape, data->dtype));
1247 return true;
1248}
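
// Worked example (illustrative shapes): with data of shape (4, 5, 6) (N=3), indices of
// shape (2, 3) (M=2, K=1) and updates of shape (3, 6), the checks above require
// indices[1] == updates[0] (3 == 3) and updates[1] == data[2] (6 == 6), and the output
// has the same shape as data, (4, 5, 6).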
1249
1250Expr MakeScatterND(Expr data, Expr indices, Expr updates, String mode) {
1251 auto attrs = make_object<ScatterNDAttrs>();
1252 attrs->mode = std::move(mode);
1253 static const Op& op = Op::Get("scatter_nd");
1254 return Call(op, {data, indices, updates}, Attrs(attrs), {});
1255}
1256
1257TVM_REGISTER_GLOBAL("relay.op._make.scatter_nd").set_body_typed(MakeScatterND);
1258
1259// scatter_nd operator has extern schedules for CPU and GPU devices.
1260// Fusing extern schedules with Injective schedules leads to errors.
1261// So, converting the scatter_nd to Opaque to prevent compilation failures
1262RELAY_REGISTER_OP("scatter_nd")
1263 .describe(R"code(Scatter elements or slices from data and store to a tensor
1264whose shape is defined by indices.
1265
1266Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}) and indices with shape
1267(M, Y_0, ..., Y_{K-1}), the output will have shape (X_0, X_1, ..., X_{N-1}).
1268)code" TVM_ADD_FILELINE)
1269 .set_num_inputs(3)
1270 .add_argument("data", "Tensor", "The input tensor.")
1271 .add_argument("indices", "Tensor", "The indices tensor.")
1272 .add_argument("updates", "Tensor", "The input tensor.")
1273 .set_support_level(3)
1274 .add_type_rel("ScatterND", ScatterNDRel)
1275 .set_attr<TOpPattern>("TOpPattern", kOpaque);
1276
1277// Take
1278TVM_REGISTER_NODE_TYPE(TakeAttrs);
1279
1280bool TakeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1281 const TypeReporter& reporter) {
1282 // `types` contains: [data, indices, result]
1283 ICHECK_EQ(types.size(), 3);
1284 const auto* data = types[0].as<TensorTypeNode>();
1285 if (data == nullptr) {
1286 return false;
1287 }
1288 const auto* indices = types[1].as<TensorTypeNode>();
1289 if (indices == nullptr) {
1290 return false;
1291 }
1292 ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
1293 << "indices of take must be tensor of integer";
1294 const auto param = attrs.as<TakeAttrs>();
1295 ICHECK(param != nullptr);
1296
1297 if (!param->axis.defined()) {
1298 std::vector<IndexExpr> oshape(indices->shape.begin(), indices->shape.end());
1299 reporter->Assign(types[2], TensorType(oshape, data->dtype));
1300 return true;
1301 }
1302
1303 std::vector<IndexExpr> oshape;
1304 const auto ndim_data = static_cast<int>(data->shape.size());
1305 const auto ndim_indices = static_cast<int>(indices->shape.size());
1306 int axis = static_cast<int>(param->axis->value);
1307 int batch_dims = static_cast<int>(param->batch_dims->value);
1308 if (axis < 0) axis += ndim_data;
  if (batch_dims < 0) batch_dims += ndim_indices;
  ICHECK_LE(axis, ndim_data) << "axis should be within the rank of the data"
                             << ", but got axis = " << axis;
  ICHECK_LE(batch_dims, ndim_indices) << "batch_dims should be within the rank of the indices"
                                      << ", but got batch_dims = " << batch_dims;
  ICHECK_LE(batch_dims, axis) << "batch_dims should be less than or equal to axis"
                              << ", but got batch_dims = " << batch_dims << " and axis = " << axis;
1316
1317 oshape.reserve(ndim_data - 1 + ndim_indices - batch_dims);
1318 for (int i = 0; i < batch_dims; ++i) {
1319 oshape.emplace_back(data->shape[i]);
1320 }
1321 for (int i = batch_dims; i < axis; ++i) {
1322 oshape.emplace_back(data->shape[i]);
1323 }
1324 for (int i = batch_dims; i < ndim_indices; ++i) {
1325 oshape.emplace_back(indices->shape[i]);
1326 }
1327 for (int i = axis + 1; i < ndim_data; ++i) {
1328 oshape.emplace_back(data->shape[i]);
1329 }
1330
1331 reporter->Assign(types[2], TensorType(oshape, data->dtype));
1332 return true;
1333}
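
// Worked example (illustrative shapes): for data of shape (2, 5, 6) and indices of
// shape (2, 3) with axis=1 and batch_dims=1, the output keeps the shared batch
// dimension (2), takes the remaining indices dimension (3), and appends the data
// dimensions after `axis`, giving (2, 3, 6).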
1334
1335Array<te::Tensor> TakeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
1336 const Type& out_type) {
1337 const auto* param = attrs.as<TakeAttrs>();
1338 ICHECK(param != nullptr);
1339 if (!param->axis.defined()) {
1340 return Array<te::Tensor>{
1341 topi::take(inputs[0], inputs[1], param->batch_dims.IntValue(), param->mode)};
1342 } else {
1343 return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->batch_dims.IntValue(),
1344 param->axis.IntValue(), param->mode)};
1345 }
1346}
1347
1348Expr MakeTake(Expr data, Expr indices, Integer batch_dims, Integer axis, String mode) {
1349 auto attrs = make_object<TakeAttrs>();
1350 attrs->batch_dims = std::move(batch_dims);
1351 attrs->axis = std::move(axis);
1352 attrs->mode = std::move(mode);
1353 static const Op& op = Op::Get("take");
1354 return Call(op, {data, indices}, Attrs(attrs), {});
1355}
1356
1357TVM_REGISTER_GLOBAL("relay.op._make.take").set_body_typed(MakeTake);
1358
1359RELAY_REGISTER_OP("take")
1360 .describe(R"code(Take elements from an array along an axis.
1361
1362When axis is not None, this function does the same thing as 'fancy' indexing
1363(indexing arrays using arrays); however, it can be easier to use if you need
1364elements along a given axis.
1365
**Note** that when axis is None, the flattened input array is used.
1367
1368Examples::
1369
1370 a = [[ 1, 2],
1371 [ 3, 4]]
1372 indices = [3, 0, 2]
1373 take(a, indices) = [ 4, 1, 3]
1374
1375 a = [[ 1., 2.],
1376 [ 3., 4.]]
1377 indices = [1, 0]
1378 take(a, indices, axis=1) = [[ 2., 1.],
1379 [ 4., 3.]]
1380
1381)code" TVM_ADD_FILELINE)
1382 .set_attrs_type<TakeAttrs>()
1383 .set_num_inputs(2)
1384 .add_argument("data", "Tensor", "The input tensor.")
1385 .add_argument("indices", "Tensor", "The indices tensor.")
1386 .set_support_level(3)
1387 .add_type_rel("Take", TakeRel)
1388 .set_attr<FTVMCompute>("FTVMCompute", TakeCompute)
1389 .set_attr<TOpPattern>("TOpPattern", kInjective);
1390
1391// Init ops
1392TVM_REGISTER_NODE_TYPE(InitOpAttrs);
1393
1394bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1395 const TypeReporter& reporter) {
1396 ICHECK_EQ(types.size(), 2);
1397 const InitOpAttrs* param = attrs.as<InitOpAttrs>();
1398 const auto* fill_value = types[0].as<TensorTypeNode>();
1399 if (fill_value == nullptr) {
1400 return false;
1401 }
1402
1403 DataType out_dtype = param->dtype;
1404 if (out_dtype.bits() == 0) {
1405 out_dtype = fill_value->dtype;
1406 }
1407
1408 ICHECK_EQ(fill_value->shape.size(), 0)
1409 << "Fill value should be a scalar but has dimension " << fill_value->shape.size() << ".";
1410
1411 std::vector<IndexExpr> oshape;
1412 const Array<Integer>& cshape_array = param->shape.value();
1413 for (size_t i = 0; i < cshape_array.size(); ++i) {
1414 oshape.push_back(cshape_array[i]);
1415 }
1416 reporter->Assign(types[1], TensorType(oshape, out_dtype));
1417 return true;
1418}
1419
1420Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType dtype) {
1421 auto attrs = make_object<InitOpAttrs>();
1422 attrs->dtype = std::move(dtype);
1423 attrs->shape = std::move(shape);
1424 static const Op& op = Op::Get("full");
1425 return Call(op, {fill_value}, Attrs(attrs), {});
1426}
1427
1428Array<te::Tensor> FullCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
1429 const Type& out_type) {
1430 const auto* out_ttype = out_type.as<TensorTypeNode>();
1431 return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())};
1432}
1433
1434TVM_REGISTER_GLOBAL("relay.op._make.full").set_body_typed(MakeFull);
1435
1436RELAY_REGISTER_OP("full")
1437 .describe(R"code(Fill array with scalar value.
1438
1439)code" TVM_ADD_FILELINE)
1440 .set_attrs_type<InitOpAttrs>()
1441 .set_num_inputs(1)
1442 .add_argument("fill_value", "double", "The value to fill.")
1443 .set_support_level(3)
1444 .add_type_rel("Full", FullRel)
1445 .set_attr<FTVMCompute>("FTVMCompute", FullCompute)
1446 .set_attr<TOpPattern>("TOpPattern", kElemWise);
1447
1448bool InitOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1449 const TypeReporter& reporter) {
1450 // types = [ret_type]
1451 ICHECK_EQ(types.size(), 1);
1452
1453 const InitOpAttrs* param = attrs.as<InitOpAttrs>();
1454 ICHECK(param);
1455
1456 DataType out_dtype = param->dtype;
1457 std::vector<IndexExpr> oshape;
1458
1459 const Array<Integer>& cshape_array = param->shape.value();
1460 for (size_t i = 0; i < cshape_array.size(); ++i) {
1461 oshape.push_back(cshape_array[i]);
1462 }
1463 reporter->Assign(types[0], TensorType(oshape, out_dtype));
1464 return true;
1465}
1466
1467Expr MakeZeros(Array<Integer> shape, DataType dtype) {
1468 auto attrs = make_object<InitOpAttrs>();
1469 attrs->shape = std::move(shape);
1470 attrs->dtype = std::move(dtype);
1471 static const Op& op = Op::Get("zeros");
1472 return Call(op, {}, Attrs(attrs), {});
1473}
1474
1475TVM_REGISTER_GLOBAL("relay.op._make.zeros").set_body_typed(MakeZeros);
1476
1477RELAY_REGISTER_OP("zeros")
1478 .describe(R"code(Fill array with zeros.
1479
1480)code" TVM_ADD_FILELINE)
1481 .set_attrs_type<InitOpAttrs>()
1482 .set_num_inputs(0)
1483 .set_support_level(3)
1484 .add_type_rel("InitOp", InitOpRel);
1485
1486Expr MakeOnes(Array<Integer> shape, DataType dtype) {
1487 auto attrs = make_object<InitOpAttrs>();
1488 attrs->shape = std::move(shape);
1489 attrs->dtype = std::move(dtype);
1490 static const Op& op = Op::Get("ones");
1491 return Call(op, {}, Attrs(attrs), {});
1492}
1493
1494TVM_REGISTER_GLOBAL("relay.op._make.ones").set_body_typed(MakeOnes);
1495
1496RELAY_REGISTER_OP("ones")
1497 .describe(R"code(Fill array with ones.
1498
1499)code" TVM_ADD_FILELINE)
1500 .set_attrs_type<InitOpAttrs>()
1501 .set_num_inputs(0)
1502 .set_support_level(3)
1503 .add_type_rel("InitOp", InitOpRel);
1504
1505bool FullLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1506 const TypeReporter& reporter) {
1507 ICHECK_EQ(types.size(), 3);
1508 const auto* data = types[0].as<TensorTypeNode>();
1509 if (data == nullptr) {
1510 return false;
1511 }
1512 const auto* fill_value = types[1].as<TensorTypeNode>();
1513 if (fill_value == nullptr) {
1514 return false;
1515 }
1516
1517 ICHECK_EQ(fill_value->shape.size(), 0)
1518 << "The fill value should be a scalar but here it has dimension " << fill_value->shape.size()
1519 << ".";
1520
1521 reporter->Assign(types[2], TensorType(data->shape, data->dtype));
1522 return true;
1523}
1524
1525Array<te::Tensor> FullLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
1526 const Type& out_type) {
1527 return {topi::full_like(inputs[0], inputs[1]())};
1528}
1529
1530Expr MakeFullLike(Expr data, Expr fill_value) {
1531 static const Op& op = Op::Get("full_like");
1532 return Call(op, {data, fill_value}, Attrs(), {});
1533}
1534
1535TVM_REGISTER_GLOBAL("relay.op._make.full_like").set_body_typed(MakeFullLike);
1536
1537RELAY_REGISTER_OP("full_like")
1538 .describe(R"code(Return an scalar value array with the same shape
1539and type as the input array.
1540
1541)code" TVM_ADD_FILELINE)
1542 .set_num_inputs(2)
1543 .add_argument("data", "Tensor", "The input tensor.")
1544 .add_argument("fill_value", "double", "Scalar value to fill.")
1545 .set_support_level(3)
1546 .add_type_rel("FullLike", FullLikeRel)
1547 .set_attr<FTVMCompute>("FTVMCompute", FullLikeCompute)
1548 .set_attr<TOpPattern>("TOpPattern", kElemWise);
1549
1550// arange operator
1551TVM_REGISTER_NODE_TYPE(ArangeAttrs);
1552
1553bool ArangeRel(const Array<Type>& types, int num_inputs, const Attrs& raw_attrs,
1554 const TypeReporter& reporter) {
1555 ICHECK_EQ(types.size(), 4);
1556 const ArangeAttrs* attrs = raw_attrs.as<ArangeAttrs>();
1557 const ConstantNode *cstart, *cstop, *cstep;
1558
1559 reporter->Assign(types[0], types[1]);
1560 reporter->Assign(types[1], types[2]);
1561 reporter->Assign(types[2], TensorType({}, attrs->dtype));
1562
1563 if ((cstart = attrs->start.as<ConstantNode>()) && (cstop = attrs->stop.as<ConstantNode>()) &&
1564 (cstep = attrs->step.as<ConstantNode>())) {
1565 double start = ToScalar(cstart->data);
1566 double stop = ToScalar(cstop->data);
1567 double step = ToScalar(cstep->data);
1568 int32_t num_elem = static_cast<int32_t>(std::ceil((stop - start) / step));
1569 ICHECK_GT(num_elem, 0) << "Invalid arange attributes (start, stop, step): " << attrs->start
1570 << ", " << attrs->stop << ", " << attrs->step;
1571 reporter->Assign(types[3], TensorType({num_elem}, attrs->dtype));
1572 return true;
1573 } else {
1574 reporter->Assign(types[3], TensorType({Any()}, attrs->dtype));
1575 return true;
1576 }
1577}
1578
1579inline te::Tensor DynamicArange(const te::Tensor& start, const te::Tensor& stop,
1580 const te::Tensor& step, tvm::DataType dtype,
1581 std::string name = "T_arange_dynamic",
1582 std::string tag = topi::kInjective) {
1583 ICHECK_EQ(start.ndim(), 0);
1584 ICHECK_EQ(stop.ndim(), 0);
1585 ICHECK_EQ(step.ndim(), 0);
1586 tvm::PrimExpr num_elem = tvm::tir::Var("num_elem");
1587 return te::compute(
1588 {num_elem},
1589 [&](const Array<tvm::tir::Var>& indices) {
1590 Array<PrimExpr> empty_indices;
1591 return tvm::cast(dtype, start(empty_indices) + step(empty_indices) * indices[0]);
1592 },
1593 name, tag);
1594}
1595
1596Array<te::Tensor> ArangeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
1597 const Type& out_type) {
1598 const ArangeAttrs* param = attrs.as<ArangeAttrs>();
1599 ICHECK(param != nullptr);
1600 te::Tensor start = inputs[0];
1601 te::Tensor stop = inputs[1];
1602 te::Tensor step = inputs[2];
1603 return {DynamicArange(start, stop, step, param->dtype)};
1604}
1605
1606Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype) {
1607 auto attrs = make_object<ArangeAttrs>();
1608 attrs->start = start;
1609 attrs->stop = stop;
1610 attrs->step = step;
1611 attrs->dtype = dtype;
1612 static const Op& op = Op::Get("arange");
1613 return Call(op, {start, stop, step}, Attrs(attrs), {});
1614}
1615
1616TVM_REGISTER_GLOBAL("relay.op._make.arange").set_body_typed(MakeArange);
1617
// An issue with the existing design is that typing this operator precisely
// requires a form of dependent typing: the result shape depends on the
// argument values.
//
// Supporting this in general is challenging, so we duplicate the secondary
// arguments as both call arguments and attributes.
//
// In this way we reify the arguments at both the value and the type level.
//
// In the case where the arguments are constant we can immediately recover
// the precise type of arange.
//
// In general I think we should avoid this pattern, and instead introduce
// a secondary shape analysis to recover more precise information.
1631RELAY_REGISTER_OP("arange")
1632 .describe(R"code(Returns evenly spaced values within a given interval.
1633
1634)code" TVM_ADD_FILELINE)
1635 .set_attrs_type<ArangeAttrs>()
1636 .set_num_inputs(3)
1637 .add_argument("start", "Expr", "Start of interval. The interval includes this value.")
1638 .add_argument("end", "Expr", "Stop of interval. The interval does not include this value.")
1639 .add_argument("step", "Expr", "Spacing between values.")
1640 .set_support_level(3)
1641 .add_type_rel("Arange", ArangeRel)
1642 .set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
1643 // TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape
1644 .set_attr<TOpPattern>("TOpPattern", kOpaque)
1645 .set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions);
1646
1647// repeat operator
1648TVM_REGISTER_NODE_TYPE(RepeatAttrs);
1649
1650bool RepeatRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1651 const TypeReporter& reporter) {
1652 // `types` contains: [data, result]
1653 ICHECK_EQ(types.size(), 2);
1654 const auto* data = types[0].as<TensorTypeNode>();
1655 if (data == nullptr) {
1656 ICHECK(types[0].as<IncompleteTypeNode>())
1657 << "repeat: expect input type to be TensorType but get " << types[0];
1658 return false;
1659 }
1660 const auto* param = attrs.as<RepeatAttrs>();
1661 const int ndim = static_cast<int>(data->shape.size());
1662 const int repeats = param->repeats.IntValue();
1663 const int axis = param->axis.IntValue();
1664 ICHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`"
1665 << ", but got repeats = " << repeats;
1666 ICHECK(-ndim - 1 <= axis && axis <= ndim)
1667 << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
1668 << ", but got axis = " << axis << ", and data.ndim = " << ndim;
1669 const int pivot = axis < 0 ? ndim + axis : axis;
1670 std::vector<IndexExpr> oshape;
1671 oshape.reserve(ndim + repeats);
1672 for (int i = 0; i < pivot; ++i) {
1673 oshape.emplace_back(data->shape[i]);
1674 }
1675 if (data->shape[pivot].as<AnyNode>()) {
1676 oshape.emplace_back(Any());
1677 } else {
1678 oshape.emplace_back(data->shape[pivot] * repeats);
1679 }
1680 for (int i = pivot + 1; i < ndim; ++i) {
1681 oshape.emplace_back(data->shape[i]);
1682 }
1683 reporter->Assign(types[1], TensorType(oshape, data->dtype));
1684 return true;
1685}
1686
1687Array<te::Tensor> RepeatCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
1688 const Type& out_type) {
1689 const RepeatAttrs* param = attrs.as<RepeatAttrs>();
1690 ICHECK(param != nullptr);
1691 return {topi::repeat(inputs[0], param->repeats.IntValue(), param->axis.IntValue())};
1692}
1693
1694Expr MakeRepeat(Expr data, int repeats, int axis) {
1695 auto attrs = make_object<RepeatAttrs>();
1696 attrs->repeats = repeats;
1697 attrs->axis = axis;
1698 static const Op& op = Op::Get("repeat");
1699 return Call(op, {data}, Attrs(attrs), {});
1700}
1701
1702TVM_REGISTER_GLOBAL("relay.op._make.repeat").set_body_typed(MakeRepeat);
1703
1704RELAY_REGISTER_OP("repeat")
1705 .describe(R"code(Repeat elements of an array `repeats` times along axis `axis`
1706
1707- **data**: The input data to the operator.
1708
1709)code" TVM_ADD_FILELINE)
1710 .set_num_inputs(1)
1711 .set_attrs_type<RepeatAttrs>()
1712 .add_argument("data", "Tensor", "The input tensor.")
1713 .set_support_level(3)
1714 .add_type_rel("Repeat", RepeatRel)
1715 .set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
1716 .set_attr<TOpPattern>("TOpPattern", kBroadcast);
1717
1718bool SparseFillEmptyRowsRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1719 const TypeReporter& reporter) {
1720 // types: [sparse_indices, sparse_values, dense_shape, default_value, result]
  ICHECK_EQ(types.size(), 5) << "SparseFillEmptyRowsRel expects 5 types but " << types.size()
                             << " provided";
1723 std::vector<Type> fields;
  auto sparse_indices = types[0].as<TensorTypeNode>();
  if (sparse_indices == nullptr) {
    // The input type is not resolved yet; let type inference retry later.
    return false;
  }
  auto ndims = sparse_indices->shape[1];
1726 fields.push_back(TensorType(Array<PrimExpr>{Any(), ndims}, tvm::DataType::Int(64)));
1727 fields.push_back(TensorType(Array<PrimExpr>{Any()}, tvm::DataType::Int(64)));
1728 fields.push_back(TensorType(Array<PrimExpr>{Any()}, tvm::DataType::Int(64)));
1729 reporter->Assign(types[types.size() - 1], TupleType(Array<Type>(fields)));
1730 return true;
1731}
1732
1733Expr MakeSparseFillEmptyRows(Expr sparse_indices, Expr sparse_values, Expr dense_shape,
1734 Expr default_value) {
1735 static const Op& op = Op::Get("sparse_fill_empty_rows");
1736 return Call(op, {sparse_indices, sparse_values, dense_shape, default_value}, Attrs(), {});
1737}
1738
1739TVM_REGISTER_GLOBAL("relay.op._make.sparse_fill_empty_rows")
1740 .set_body_typed(MakeSparseFillEmptyRows);
1741
1742RELAY_REGISTER_OP("sparse_fill_empty_rows")
1743 .describe(
1744 R"code(Fill empty rows of a sparse tensor with a default value.)code" TVM_ADD_FILELINE)
1745 .set_num_inputs(4)
1746 .add_argument("sparse_indices", "Tensor",
1747 "A 2-D int64 tensor of shape [N, ndims], which specifies the indices of the"
1748 "elements in the sparse tensor that contain nonzero values. COO Format")
1749 .add_argument(
1750 "sparse_values", "Tensor",
1751 "A 1-D tensor[N] which supplies the values for each element in indices. COO Format")
1752 .add_argument("dense_shape", "Tensor",
1753 "A 1-D int64 tensor of shape [ndims], which specifies the dense_shape of the"
1754 "sparse tensor. Takes a list indicating the number of elements in each "
1755 "dimension")
1756 .add_argument("default_value", "Tensor",
1757 "The value to fill for empty rows, with the same type as sparse_values")
1758 .add_type_rel("sparse_fill_empty_rows", SparseFillEmptyRowsRel)
1759 .set_support_level(3)
1760 .set_attr<TOpPattern>("TOpPattern", kOpaque);
1761
1762bool SparseReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1763 const TypeReporter& reporter) {
1764 // types: [sparse_indices, prev_shape, new_shape, result]
  ICHECK_EQ(types.size(), 4) << "SparseReshapeRel expects 4 types but " << types.size()
                             << " provided";
  ICHECK_EQ(num_inputs, 3) << "SparseReshapeRel expects 3 inputs but " << num_inputs
                           << " provided";
1768 auto sparse_indices = types[0].as<TensorTypeNode>();
1769 auto prev_shape = types[1].as<TensorTypeNode>();
1770 auto new_shape = types[2].as<TensorTypeNode>();
1771 if (sparse_indices == nullptr || prev_shape == nullptr || new_shape == nullptr) {
1772 return false;
1773 }
1774 CHECK(sparse_indices->dtype.is_int()) << "sparse_indices must be tensor of integers";
1775 CHECK(prev_shape->dtype.is_int()) << "prev_shape must be tensor of integers";
1776 CHECK(new_shape->dtype.is_int()) << "new_shape must be tensor of integers";
1777 ICHECK_EQ(sparse_indices->shape.size(), 2) << "sparse_indices must be 2-D tensor";
1778 ICHECK_EQ(prev_shape->shape.size(), 1) << "prev_shape must be 1-D tensor";
1779 ICHECK_EQ(new_shape->shape.size(), 1) << "new_shape must be 1-D tensor";
1780 std::vector<Type> fields;
1781 Array<PrimExpr> new_sparse_indices_shape{sparse_indices->shape[0], new_shape->shape[0]};
1782 fields.push_back(TensorType(new_sparse_indices_shape, sparse_indices->dtype));
1783 fields.push_back(TensorType(new_shape->shape, new_shape->dtype));
1784 reporter->Assign(types[3], TupleType(Array<Type>(fields)));
1785 return true;
1786}
1787
1788Expr MakeSparseReshape(Expr sparse_indices, Expr prev_shape, Expr new_shape) {
1789 static const Op& op = Op::Get("sparse_reshape");
1790 return Call(op, {sparse_indices, prev_shape, new_shape}, Attrs(), {});
1791}
1792
1793TVM_REGISTER_GLOBAL("relay.op._make.sparse_reshape").set_body_typed(MakeSparseReshape);
1794
1795RELAY_REGISTER_OP("sparse_reshape")
1796 .describe(R"code(Return new sparse indices of the reshaped tensor
1797)code" TVM_ADD_FILELINE)
1798 .set_num_inputs(3)
1799 .add_argument("sparse_indices", "Tensor",
1800 "A 2-D tensor of shape [N, ndims], which specifies the indices of the"
1801 "elements in the sparse tensor that contain nonzero values. COO Format")
1802 .add_argument("prev_shape", "Tensor",
1803 "A 1-D tensor of shape [ndims], which specifies the previous dense shape of the"
1804 "sparse tensor")
1805 .add_argument("new_shape", "Tensor",
1806 "A 1-D tensor of shape [ndims], which specifies the desired dense shape of the"
1807 "sparse tensor")
1808 .add_type_rel("sparse_reshape", SparseReshapeRel)
1809 .set_attr<TOpPattern>("TOpPattern", kInjective)
1810 .set_support_level(3);
1811
1812TVM_REGISTER_NODE_TYPE(StftAttrs);
1813
1814bool STFTRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1815 const TypeReporter& reporter) {
1816 // types: [data, window, result]
  ICHECK_EQ(types.size(), 3) << "STFTRel expects 3 types but " << types.size() << " provided";
  ICHECK_EQ(num_inputs, 2) << "stft: expect 2 inputs but " << num_inputs << " provided";
  auto data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    ICHECK(types[0].as<IncompleteTypeNode>())
        << "stft: expect input type to be TensorType but get " << types[0];
1823 return false;
1824 }
1825 const auto* param = attrs.as<StftAttrs>();
1826 const int ndim = static_cast<int>(data->shape.size());
1827 std::vector<IndexExpr> oshape;
1828 int dim = 0;
1829 if (ndim == 2) {
1830 oshape.push_back(data->shape[0]); // batch dimension
1831 dim += 1;
1832 }
1833 oshape.push_back(param->onesided ? param->n_fft / 2 + 1 : param->n_fft);
1834 if (data->shape[dim].as<AnyNode>())
1835 oshape.push_back(Any());
1836 else
1837 oshape.push_back(indexdiv((data->shape[dim] - param->n_fft), param->hop_length) +
1838 1); // n_frames
1839 oshape.push_back(2);
1840 reporter->Assign(types[2], TensorType(oshape, data->dtype));
1841 return true;
1842}
1843
1844Expr MakeSTFT(Expr data, int n_fft, int hop_length, int win_length, Expr window, bool normalized,
1845 bool onesided) {
1846 auto attrs = make_object<StftAttrs>();
1847 attrs->n_fft = n_fft;
1848 attrs->hop_length = hop_length;
1849 attrs->win_length = win_length;
1850 attrs->normalized = normalized;
1851 attrs->onesided = onesided;
1852 static const Op& op = Op::Get("stft");
1853 return Call(op, {data, window}, Attrs(attrs), {});
1854}
1855
1856TVM_REGISTER_GLOBAL("relay.op._make.stft").set_body_typed(MakeSTFT);
1857
1858RELAY_REGISTER_OP("stft")
1859 .describe(
1860 R"code(The STFT computes the Fourier transform of short overlapping windows of the input.
1861)code" TVM_ADD_FILELINE)
1862 .set_num_inputs(2)
1863 .add_argument("data", "Tensor", "the input tensor")
1864 .add_argument("window", "Tensor", "the optional window function")
1865 .add_type_rel("stft", STFTRel)
1866 .set_support_level(3)
1867 .set_attr<TOpPattern>("TOpPattern", kOpaque);
1868
1869// meshgrid operator
1870TVM_REGISTER_NODE_TYPE(MeshgridAttrs);
1871
1872bool MeshgridRel(const Array<Type>& types, int num_inputs, const Attrs& raw_attrs,
1873 const TypeReporter& reporter) {
1874 // types: [data, result]
1875 ICHECK_EQ(types.size(), 2);
1876 const MeshgridAttrs* attrs = raw_attrs.as<MeshgridAttrs>();
  if (types[0].as<IncompleteTypeNode>() != nullptr) {
    // The input type has not been resolved yet; defer to a later inference round.
    // (Checking this before the TupleType cast is required, since the previous
    // else-if ordering made this branch unreachable.)
    return false;
  }
  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
  if (tensor_tuple == nullptr) {
    throw CompileError(ErrorBuilder()
                       << "meshgrid requires a tuple of tensors as the first argument, found "
                       << PrettyPrint(types[0]));
  }
1885 const int data_length = static_cast<int>(tensor_tuple->fields.size());
1886
1887 // Get first dtype.
1888 const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
1889 const DataType dtype = first->dtype;
1890
1891 // Get size of output grid.
1892 std::vector<IndexExpr> grid_shape;
1893 grid_shape.reserve(data_length);
1894 for (const Type& ele : tensor_tuple->fields) {
1895 if (ele.as<IncompleteTypeNode>()) {
1896 return false;
1897 }
1898 const auto& e = Downcast<TensorType>(ele);
1899 int e_ndim = static_cast<int>(e->shape.size());
1900 const DataType& e_dtype = e->dtype;
1901 if (e_dtype != dtype) {
1902 throw CompileError("relay.meshgrid requires all tensors have the same dtype");
1903 }
1904 if (e_ndim == 0) {
1905 grid_shape.emplace_back(1);
1906 } else if (e_ndim == 1) {
1907 grid_shape.emplace_back(e->shape[0]);
1908 } else {
1909 throw CompileError("relay.meshgrid requires all tensors be either scalars or 1-D vectors.");
1910 }
1911 }
1912
1913 // "xy" mode swaps first two dimensions
1914 if (attrs->indexing == "xy" && grid_shape.size() >= 2) {
1915 std::swap(grid_shape[0], grid_shape[1]);
1916 }
1917
1918 // There is one output grid for each input, all with same shape.
1919 std::vector<Type> grids;
1920 grids.reserve(data_length);
1921 for (int i = 0; i < data_length; i++) {
1922 grids.emplace_back(TensorType(grid_shape, dtype));
1923 }
1924 reporter->Assign(types[1], TupleType(Array<Type>(grids)));
1925 return true;
1926}
1927
1928Array<te::Tensor> MeshgridCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
1929 const Type& out_type) {
1930 const MeshgridAttrs* param = attrs.as<MeshgridAttrs>();
1931 ICHECK(param != nullptr);
1932 return {topi::meshgrid(inputs, param->indexing)};
1933}
1934
1935Expr MakeMeshgrid(Expr data, String indexing) {
1936 auto attrs = make_object<MeshgridAttrs>();
1937 attrs->indexing = std::move(indexing);
1938 static const Op& op = Op::Get("meshgrid");
1939 return Call(op, {data}, Attrs(attrs), {});
1940}
1941
1942TVM_REGISTER_GLOBAL("relay.op._make.meshgrid").set_body_typed(MakeMeshgrid);
1943
1944RELAY_REGISTER_OP("meshgrid")
1945 .describe(R"code(Create coordinate matrices from coordinate vectors.
1946
1947)code" TVM_ADD_FILELINE)
1948 .set_attrs_type<MeshgridAttrs>()
1949 .set_num_inputs(1)
1950 .add_argument("data", "Tensor", "The input list of tensors.")
1951 .set_support_level(3)
1952 .add_type_rel("Meshgrid", MeshgridRel)
1953 .set_attr<FTVMCompute>("FTVMCompute", MeshgridCompute)
1954 .set_attr<TOpPattern>("TOpPattern", kInjective);
1955
1956// tile operator
1957TVM_REGISTER_NODE_TYPE(TileAttrs);
1958
1959bool TileRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
1960 const TypeReporter& reporter) {
1961 // `types` contains: [data, result]
1962 ICHECK_EQ(types.size(), 2);
1963 const auto* data = types[0].as<TensorTypeNode>();
1964 if (data == nullptr) {
1965 ICHECK(types[0].as<IncompleteTypeNode>())
1966 << "tile: expect input type to be TensorType but get " << types[0];
1967 return false;
1968 }
1969 const auto* param = attrs.as<TileAttrs>();
1970 const size_t ndim = data->shape.size();
1971 const Array<Integer>& reps = param->reps;
1972 // check dimension match
1973 ICHECK(reps.defined()) << "repetition array is not defined. data.ndim = " << ndim;
1974 const size_t rndim = reps.size();
1975 for (size_t i = 0; i < rndim; ++i) {
1976 if (const tvm::tir::IntImmNode* val = reps[i].as<tvm::tir::IntImmNode>()) {
1977 ICHECK_GT(val->value, 0) << "Tile reps value should always be larger than 0, but get: "
1978 << val->value;
1979 }
1980 }
1981 size_t tndim = (ndim > rndim) ? ndim : rndim;
1982 // re-construct data shape or reps shape
1983 std::vector<IndexExpr> data_shape;
1984 std::vector<IndexExpr> reps_shape;
1985 data_shape.reserve(tndim);
1986 reps_shape.reserve(tndim);
1987 if (ndim == rndim) {
1988 for (size_t i = 0; i < tndim; ++i) {
1989 data_shape.emplace_back(data->shape[i]);
1990 reps_shape.emplace_back(reps[i]);
1991 }
1992 } else if (ndim > rndim) {
1993 for (size_t i = 0; i < ndim; ++i) {
1994 data_shape.emplace_back(data->shape[i]);
1995 }
1996 for (size_t i = 0; i < (ndim - rndim); ++i) {
1997 reps_shape.emplace_back(1);
1998 }
1999 for (size_t i = 0; i < rndim; ++i) {
2000 reps_shape.emplace_back(reps[i]);
2001 }
2002 } else {
2003 for (size_t i = 0; i < rndim; ++i) {
2004 reps_shape.emplace_back(reps[i]);
2005 }
2006 for (size_t i = 0; i < (rndim - ndim); ++i) {
2007 data_shape.emplace_back(1);
2008 }
2009 for (size_t i = 0; i < ndim; ++i) {
2010 data_shape.emplace_back(data->shape[i]);
2011 }
2012 }
2013 std::vector<IndexExpr> oshape;
2014 oshape.reserve(tndim);
2015 for (size_t i = 0; i < tndim; ++i) {
    // Keep Any if the dimension is dynamic
2017 if (!data_shape[i].as<IntImmNode>()) {
2018 oshape.emplace_back(Any());
2019 } else {
2020 oshape.emplace_back(data_shape[i] * reps_shape[i]);
2021 }
2022 }
2023 reporter->Assign(types[1], TensorType(oshape, data->dtype));
2024 return true;
2025}
2026
2027Array<te::Tensor> TileCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
2028 const Type& out_type) {
2029 const TileAttrs* param = attrs.as<TileAttrs>();
2030 ICHECK(param != nullptr);
2031 return {topi::tile(inputs[0], param->reps)};
2032}
2033
2034Expr MakeTile(Expr data, Array<Integer> reps) {
2035 auto attrs = make_object<TileAttrs>();
2036 attrs->reps = reps;
2037 static const Op& op = Op::Get("tile");
2038 return Call(op, {data}, Attrs(attrs), {});
2039}
2040
2041TVM_REGISTER_GLOBAL("relay.op._make.tile").set_body_typed(MakeTile);
2042
2043RELAY_REGISTER_OP("tile")
2044 .describe(R"code(Repeat the whole array multiple times.
2045
2046- **data**: The input data to the operator.
2047
2048)code" TVM_ADD_FILELINE)
2049 .set_num_inputs(1)
2050 .set_attrs_type<TileAttrs>()
2051 .add_argument("data", "Tensor", "The input tensor.")
2052 .set_support_level(3)
2053 .add_type_rel("Tile", TileRel)
2054 .set_attr<FTVMCompute>("FTVMCompute", TileCompute)
2055 .set_attr<TOpPattern>("TOpPattern", kBroadcast);
2056
2057// reverse operator
2058TVM_REGISTER_NODE_TYPE(ReverseAttrs);
2059
2060bool ReverseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
2061 const TypeReporter& reporter) {
2062 // `types` contains: [data, result]
2063 ICHECK_EQ(types.size(), 2);
2064 const auto* data = types[0].as<TensorTypeNode>();
2065 if (data == nullptr) {
2066 ICHECK(types[0].as<IncompleteTypeNode>())
2067 << "reverse: expect input type to be TensorType but get " << types[0];
2068 return false;
2069 }
2070 const auto* param = attrs.as<ReverseAttrs>();
2071 const int ndim = static_cast<int>(data->shape.size());
2072 const int axis = param->axis.IntValue();
2073 ICHECK(-ndim <= axis && axis < ndim)
2074 << "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]"
2075 << ", but got axis = " << axis << ", and data.ndim = " << ndim;
2076 reporter->Assign(types[1], types[0]);
2077 return true;
2078}
2079
2080Array<te::Tensor> ReverseCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
2081 const Type& out_type) {
2082 const ReverseAttrs* param = attrs.as<ReverseAttrs>();
2083 ICHECK(param != nullptr);
2084 // pass empty seq_length tensor to reverse_sequence
2085 return {topi::reverse_sequence(inputs[0], te::Tensor(), param->axis.IntValue())};
2086}
2087
2088Expr MakeReverse(Expr data, int axis) {
2089 auto attrs = make_object<ReverseAttrs>();
2090 attrs->axis = axis;
2091 static const Op& op = Op::Get("reverse");
2092 return Call(op, {data}, Attrs(attrs), {});
2093}
2094
2095TVM_REGISTER_GLOBAL("relay.op._make.reverse").set_body_typed(MakeReverse);
2096
2097RELAY_REGISTER_OP("reverse")
2098 .describe(R"code(Reverses the order of elements along given `axis` while preserving array shape.
2099
2100- **data**: The input data to the operator.
2101
2102)code" TVM_ADD_FILELINE)
2103 .set_num_inputs(1)
2104 .set_attrs_type<ReverseAttrs>()
2105 .add_argument("data", "Tensor", "The input tensor.")
2106 .set_support_level(3)
2107 .add_type_rel("Reverse", ReverseRel)
2108 .set_attr<FTVMCompute>("FTVMCompute", ReverseCompute)
2109 .set_attr<TOpPattern>("TOpPattern", kInjective);
2110
2111// reverse sequence operator
2112TVM_REGISTER_NODE_TYPE(ReverseSequenceAttrs);
2113
2114bool ReverseSequenceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
2115 const TypeReporter& reporter) {
2116 // `types` contains: [data, seq_lengths, result]
2117 ICHECK_EQ(types.size(), 3);
2118 const auto* data = types[0].as<TensorTypeNode>();
2119
2120 if (data == nullptr) {
2121 ICHECK(types[0].as<IncompleteTypeNode>())
2122 << "reverse_sequence: expect input type to be TensorType but get " << types[0];
2123 return false;
2124 }
2125
2126 const auto* seq_lengths = types[1].as<TensorTypeNode>();
2127 if (seq_lengths == nullptr) {
2128 ICHECK(types[1].as<IncompleteTypeNode>())
2129 << "reverse_sequence: expect input type to be TensorType but get " << types[1];
2130 return false;
2131 }
2132
2133 const int seq_lengths_dim = static_cast<int>(seq_lengths->shape.size());
  ICHECK(seq_lengths_dim == 1) << "For reverse_sequence, seq_lengths must be a 1D vector";
  ICHECK(seq_lengths->dtype.is_int())
      << "For reverse_sequence, seq_lengths must be a tensor of integers";
2137
2138 const auto* param = attrs.as<ReverseSequenceAttrs>();
2139 const int ndim = static_cast<int>(data->shape.size());
2140 int batch_axis = param->batch_axis.IntValue();
2141 ICHECK(-ndim <= batch_axis && batch_axis < ndim)
2142 << "reverse_sequence only accepts `batch_axis` in [-data.ndim, data.ndim - 1]"
2143 << ", but got batch_axis = " << batch_axis << ", and data.ndim = " << ndim;
2144
2145 if (batch_axis < 0) {
2146 batch_axis = static_cast<int>(data->shape.size()) + batch_axis;
2147 }
  ICHECK(reporter->Assert(seq_lengths->shape[0] == data->shape[batch_axis]))
      << "For reverse_sequence, the seq_lengths size should match the dimension of the batch axis"
      << ", but got dimension of batch_axis = " << data->shape[batch_axis]
      << ", and seq_length size = " << seq_lengths->shape[0];

  const int seq_axis = param->seq_axis.IntValue();
  ICHECK(-ndim <= seq_axis && seq_axis < ndim)
      << "reverse_sequence only accepts `seq_axis` in [-data.ndim, data.ndim - 1]"
      << ", but got seq_axis = " << seq_axis << ", and data.ndim = " << ndim;
2157
2158 reporter->Assign(types[2], types[0]);
2159 return true;
2160}
2161
2162Array<te::Tensor> ReverseSequenceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
2163 const Type& out_type) {
2164 const ReverseSequenceAttrs* param = attrs.as<ReverseSequenceAttrs>();
2165 ICHECK(param != nullptr);
2166 return {topi::reverse_sequence(inputs[0], inputs[1], param->seq_axis.IntValue(),
2167 param->batch_axis.IntValue())};
2168}
2169
2170Expr MakeReverseSequence(Expr data, Expr seq_lengths, int seq_axis, int batch_axis) {
2171 auto attrs = make_object<ReverseSequenceAttrs>();
2172 attrs->seq_axis = seq_axis;
2173 attrs->batch_axis = batch_axis;
2174 static const Op& op = Op::Get("reverse_sequence");
2175 return Call(op, {data, seq_lengths}, Attrs(attrs), {});
2176}
2177
2178TVM_REGISTER_GLOBAL("relay.op._make.reverse_sequence").set_body_typed(MakeReverseSequence);
2179
2180RELAY_REGISTER_OP("reverse_sequence")
2181 .describe(R"code(Reverses the tensor for variable length slices.
2182Input is first sliced along batch axis and then elements are reversed along seq axis.
2183
2184- **data**: The input data to the operator.
2185
2186- **seq_lengths**: A 1D Tensor with length data.dims[batch_axis].
2187
2188- **seq_axis**: The axis along which the elements will be reversed. Default is 1.
2189
2190- **batch_axis**: The axis along which the tensor will be sliced. Default is 0.
2191
2192)code" TVM_ADD_FILELINE)
2193 .set_num_inputs(2)
2194 .set_attrs_type<ReverseSequenceAttrs>()
2195 .add_argument("data", "Tensor", "The input tensor.")
2196 .add_argument("seq_lengths", "Tensor", "A 1D Tensor with length data.dims[batch_axis]")
2197 .set_support_level(3)
2198 .add_type_rel("ReverseSequence", ReverseSequenceRel)
2199 .set_attr<FTVMCompute>("FTVMCompute", ReverseSequenceCompute)
2200 .set_attr<TOpPattern>("TOpPattern", kInjective);
2201
2202// where operator
2203bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
2204 const TypeReporter& reporter) {
2205 ICHECK_EQ(types.size(), 4U);
2206 const auto* condition = types[0].as<TensorTypeNode>();
2207 const auto* x = types[1].as<TensorTypeNode>();
2208 const auto* y = types[2].as<TensorTypeNode>();
2209
2210 if (condition == nullptr || x == nullptr || y == nullptr) {
2211 return false;
2212 }
2213
2214 ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
2215 << y->dtype;
2216
2217 auto tensor_ty_condition = GetRef<TensorType>(condition);
2218 auto tensor_ty_x = GetRef<TensorType>(x);
2219 auto tensor_ty_y = GetRef<TensorType>(y);
2220
2221 auto b_ty = ConcreteBroadcast(tensor_ty_x, tensor_ty_y, x->dtype);
2222 auto ret_ty = ConcreteBroadcast(tensor_ty_condition, b_ty, b_ty->dtype);
2223
2224 reporter->Assign(types[3], ret_ty);
2225 return true;
2226}
2227
2228// Positional relay function to create where operator.
2229Expr MakeWhere(const Expr& condition, const Expr& x, const Expr& y) {
2230 static const Op& op = Op::Get("where");
2231 return Call(op, {condition, x, y});
2232}
2233
2234Array<te::Tensor> WhereCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
2235 const Type& out_type) {
2236 return {topi::where(inputs[0], inputs[1], inputs[2])};
2237}
2238
2239TVM_REGISTER_GLOBAL("relay.op._make.where").set_body_typed(MakeWhere);
2240
2241RELAY_REGISTER_OP("where")
2242 .describe(R"code(
2243Return the elements, either from x or y, depending on the condition.
2244
Given three ndarrays, condition, x, and y, return an ndarray whose elements are
taken from x or y depending on whether the corresponding element of condition is true or false.
2247
2248Shapes of condition, x, and y must be broadcastable to a common shape, which
2249is the output shape of this op. Semantics follow numpy where function.
2250https://numpy.org/doc/stable/reference/generated/numpy.where.html
2251
2252Note that all non-zero values are interpreted as True in condition.
2253
2254Examples::
2255
2256 x = [[1, 2], [3, 4]]
2257 y = [[5, 6], [7, 8]]
2258 cond = [[0, 1], [-1, 0]]
2259 where(cond, x, y) = [[5, 2], [3, 8]]
2260
2261
2262 cond = [[1], [0]]
2263 where(cond, x, y) = [[1, 2], [7, 8]]
2264
2265 cond = [0, 1]
2266 where(cond, 1, -1) = [-1, 1]
2267
2268)code" TVM_ADD_FILELINE)
2269 .add_argument("condition", "Tensor", "Condition array")
2270 .add_argument("x", "Tensor", "First array to be selected")
2271 .add_argument("y", "Tensor", "Second array to be selected")
2272 .set_num_inputs(3)
2273 .set_support_level(4)
2274 .add_type_rel("Where", WhereRel)
2275 .set_attr<FTVMCompute>("FTVMCompute", WhereCompute)
2276 .set_attr<TOpPattern>("TOpPattern", kBroadcast);
2277
2278// Squeeze
2279TVM_REGISTER_NODE_TYPE(SqueezeAttrs);
2280
2281Expr MakeSqueeze(Expr data, Array<Integer> axis) {
2282 auto attrs = make_object<SqueezeAttrs>();
2283 attrs->axis = std::move(axis);
2284 static const Op& op = Op::Get("squeeze");
2285 return Call(op, {data}, Attrs(attrs), {});
2286}
2287
2288TVM_REGISTER_GLOBAL("relay.op._make.squeeze").set_body_typed(MakeSqueeze);
2289
2290bool SqueezeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
2291 const TypeReporter& reporter) {
2292 ICHECK_EQ(types.size(), 2);
2293 const auto* data = types[0].as<TensorTypeNode>();
2294 if (data == nullptr) {
2295 return false;
2296 }
2297 const auto* param = attrs.as<SqueezeAttrs>();
2298 ICHECK(param != nullptr);
2299 std::vector<IndexExpr> result_shape;
2300 // if axes is None, squeeze all axes of dimension 1
2301 if (!param->axis.defined()) {
2302 for (const auto& e : data->shape) {
2303 if (!e.as<IntImmNode>()) {
2304 LOG(FATAL) << "axis needs to be defined for dynamic input.";
2305 }
2306 const int64_t* axis_ptr = tir::as_const_int(e);
2307 ICHECK(axis_ptr != nullptr) << "the axes attribute must be concrete";
2308 if (*axis_ptr != 1) {
2309 result_shape.push_back(e);
2310 }
2311 }
2312 } else {
    // Pair up each original dimension with a boolean that controls whether it is kept in the
    // final shape.
2314 std::vector<std::pair<IndexExpr, bool>> original_shape;
2315 for (const auto& e : data->shape) {
2316 original_shape.push_back(std::pair<IndexExpr, bool>(e, true));
2317 }
2318 for (const auto& e : param->axis) {
2319 int64_t axis_val = e->value;
2320 if (axis_val < 0) {
2321 axis_val += static_cast<int64_t>(original_shape.size());
2322 }
2323 ICHECK_GE(axis_val, 0);
2324 ICHECK_LT(axis_val, original_shape.size());
2325 original_shape.at(axis_val).second = false;
2326 }
2327 for (const auto& p : original_shape) {
2328 if (p.second) {
2329 result_shape.push_back(p.first);
2330 } else {
2331 if (const int64_t* axis_ptr = tir::as_const_int(p.first)) {
2332 ICHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1";
2333 }
2334 }
2335 }
2336 }
2337 reporter->Assign(types[1], TensorType(result_shape, data->dtype));
2338 return true;
2339}
2340
2341Array<te::Tensor> SqueezeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
2342 const Type& out_type) {
2343 const SqueezeAttrs* param = attrs.as<SqueezeAttrs>();
2344 ICHECK(param != nullptr);
2345 return {topi::squeeze(inputs[0], param->axis)};
2346}
2347
2348InferCorrectLayoutOutput SqueezeInferCorrectLayout(const Attrs& attrs,
2349 const Array<Layout>& new_in_layouts,
2350 const Array<Layout>& old_in_layouts,
2351 const Array<tvm::relay::Type>& old_in_types) {
2352 const auto* attrs_ptr = attrs.as<SqueezeAttrs>();
2353 ICHECK(attrs_ptr);
2354 ObjectPtr<SqueezeAttrs> params = make_object<SqueezeAttrs>(*attrs_ptr);
2355
2356 Layout inferred_input = new_in_layouts.defined() ? new_in_layouts[0] : old_in_layouts[0];
2357 Layout inferred_output = inferred_input;
2358
2359 ICHECK(old_in_types[0].as<TensorTypeNode>());
2360 const auto& shape = old_in_types[0].as<TensorTypeNode>()->shape;
2361
2362 // axis to squeeze
2363 Array<Integer> axis;
2364 if (params->axis.defined()) {
2365 axis = params->axis;
2366 } else {
2367 // if axes is None, squeeze all axes of dimension 1
2368 for (size_t i = 0; i < shape.size(); i++) {
2369 if (topi::detail::GetConstInt(shape[i]) == 1) {
2370 axis.push_back(i);
2371 }
2372 }
2373 }
2374
2375 // If new_in_layouts are defined, this code tries to modify the layout
2376 if (new_in_layouts.defined() && old_in_layouts.defined()) {
2377 Array<Integer> new_axis;
2378 for (const auto& e : axis) {
2379 const auto& dim = old_in_layouts[0][e.IntValue()];
2380 new_axis.push_back((new_in_layouts[0]).IndexOf(dim));
2381 }
2382 params->axis = new_axis;
2383 axis = new_axis;
2384 }
2385
2386 // Infer output layout
2387 Array<tir::IterVar> kept_axes;
2388 for (size_t i = 0; i < inferred_input.ndim(); i++) {
2389 bool is_dim_kept = true;
2390
2391 // Check whether the dim should be kept
2392 for (const auto& e : axis) {
2393 int64_t axis_val = e->value;
2394 if (axis_val < 0) {
2395 axis_val += inferred_input.ndim();
2396 }
2397 if (static_cast<int64_t>(i) == axis_val) {
2398 is_dim_kept = false;
2399 break;
2400 }
2401 }
2402
2403 if (is_dim_kept) {
2404 kept_axes.push_back(inferred_input->axes[i]);
2405 }
2406 }
2407 inferred_output = Layout(kept_axes);
2408
2409 return InferCorrectLayoutOutput({inferred_input}, {inferred_output}, Attrs(params));
2410}
2411
2412RELAY_REGISTER_OP("squeeze")
2413 .describe(R"code(Squeeze the input tensor at the dimensions given by axes
2414
2415- **data**: The input data to the operator.
2416
2417)code" TVM_ADD_FILELINE)
2418 .set_num_inputs(1)
2419 .set_attrs_type<SqueezeAttrs>()
2420 .add_argument("data", "Tensor", "The input tensor.")
2421 .set_support_level(3)
2422 .add_type_rel("Squeeze", SqueezeRel)
2423 .set_attr<FTVMCompute>("FTVMCompute", SqueezeCompute)
2424 .set_attr<TOpPattern>("TOpPattern", kInjective)
2425 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", SqueezeInferCorrectLayout)
2426 .set_attr<TReshapeOp>("TReshapeOp", true);
2427
2428// CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
2429bool CollapseSumLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
2430 const TypeReporter& reporter) {
2431 ICHECK_EQ(types.size(), 3);
2432 reporter->Assign(types[2], types[1]);
2433 return BroadcastRel({types[0], types[1], types[0]}, 2, Attrs(), reporter);
2434}
2435
2436Expr MakeCollapseSumLike(Expr data, Expr collapse_type) {
2437 static const Op& op = Op::Get("collapse_sum_like");
2438 return Call(op, {data, collapse_type}, Attrs(), {});
2439}
2440
2441Array<te::Tensor> CollapseSumLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
2442 const Type& out_type) {
2443 const auto* out_ttype = out_type.as<TensorTypeNode>();
2444 ICHECK(out_ttype != nullptr);
2445 return {topi::collapse_sum(inputs[0], out_ttype->shape)};
2446}
2447
2448TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_like").set_body_typed(MakeCollapseSumLike);
2449
2450RELAY_REGISTER_OP("collapse_sum_like")
2451 .describe(R"code(Collapse the first input to match the shape of the second input.
2452)code" TVM_ADD_FILELINE)
2453 .set_num_inputs(2)
2454 .add_argument("data", "Tensor", "The input tensor.")
2455 .add_argument("collapse_type", "Tensor", "Provide the type to collapse to.")
2456 .set_support_level(10)
2457 .add_type_rel("CollapseSumLike", CollapseSumLikeRel)
2458 .set_attr<FTVMCompute>("FTVMCompute", CollapseSumLikeCompute)
2459 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
2460
2461// CollapseSumTo: <A, B> -> B where Broadcast(A, B) = A
2462bool CollapseSumToRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
2463 const TypeReporter& reporter) {
2464 ICHECK_EQ(types.size(), 3);
2465 const InitOpAttrs* param = attrs.as<InitOpAttrs>();
2466
  const auto* target_shape = types[1].as<TensorTypeNode>();
  const auto* data = types[0].as<TensorTypeNode>();
  if (target_shape == nullptr || data == nullptr) {
    return false;
  }
  DataType out_dtype = data->dtype;
2469
2470 const IntImmNode* rank = target_shape->shape[0].as<IntImmNode>();
2471 ICHECK(rank) << "Parameter must have static rank";
2472
2473 std::vector<IndexExpr> oshape;
2474 if (param->shape) {
2475 const Array<Integer>& cshape_array = param->shape.value();
2476 for (size_t i = 0; i < cshape_array.size(); i++) {
2477 oshape.push_back(cshape_array[i]);
2478 }
2479 } else {
2480 for (int i = 0; i < rank->value; i++) {
2481 oshape.push_back(Any());
2482 }
2483 }
2484 reporter->Assign(types[2], TensorType(oshape, out_dtype));
2485 return BroadcastRel({types[0], types[2], types[0]}, 2, Attrs(), reporter);
2486}
2487
2488Expr MakeCollapseSumTo(Expr data, Expr shape) {
2489 static const Op& op = Op::Get("collapse_sum_to");
2490 auto attrs = make_object<InitOpAttrs>();
2491 if (const auto* cshape = shape.as<ConstantNode>()) {
2492 attrs->shape = ToVector(cshape->data);
2493 }
2494 return Call(op, {data, shape}, Attrs(attrs), {});
2495}
2496
2497TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_to").set_body_typed(MakeCollapseSumTo);
2498
2499RELAY_REGISTER_OP("collapse_sum_to")
2500 .describe(R"code(Broadcast the first input to match the shape argument.
2501)code" TVM_ADD_FILELINE)
2502 .set_num_inputs(2)
2503 .add_argument("data", "Tensor", "The input tensor.")
2504 .add_argument("shape", "Tensor", "Target shape.")
2505 .set_support_level(4)
2506 .add_type_rel("CollapseSumTo", CollapseSumToRel)
2507 .set_attr<FTVMCompute>("FTVMCompute", CollapseSumLikeCompute)
2508 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
2509
2510bool BroadCastToRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
2511 const TypeReporter& reporter) {
  // types = [data_type, ret_type]; the target shape lives in attrs because it is static
2513 ICHECK_EQ(types.size(), 2);
2514
2515 const InitOpAttrs* param = attrs.as<InitOpAttrs>();
2516 ICHECK(param);
2517
2518 DataType out_dtype;
2519 if (auto ttype = types[0].as<TensorTypeNode>()) {
2520 out_dtype = ttype->dtype;
2521 } else {
2522 ICHECK(types[0].as<IncompleteTypeNode>())
2523 << "Broadcast: expect to be TensorType but get " << types[0];
2524 return false;
2525 }
2526
2527 std::vector<IndexExpr> oshape;
2528
2529 const Array<Integer>& cshape_array = param->shape.value();
2530 for (size_t i = 0; i < cshape_array.size(); ++i) {
2531 oshape.push_back(cshape_array[i]);
2532 }
2533 reporter->Assign(types[1], TensorType(oshape, out_dtype));
2534 return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
2535}
2536
2537Expr MakeBroadCastTo(Expr data, Array<Integer> shape) {
2538 static const Op& op = Op::Get("broadcast_to");
2539 auto attrs = make_object<InitOpAttrs>();
2540
2541 attrs->shape = std::move(shape);
2542 return Call(op, {data}, Attrs(attrs), {});
2543}
2544
2545Array<te::Tensor> BroadCastToCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
2546 const Type& out_type) {
  const auto* out_ttype = out_type.as<TensorTypeNode>();
  ICHECK(out_ttype != nullptr);
  return {topi::broadcast_to(inputs[0], out_ttype->shape)};
2549}
2550
2551TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to").set_body_typed(MakeBroadCastTo);
2552
2553RELAY_REGISTER_OP("broadcast_to")
2554 .describe(R"code(Broadcast the first input to match the shape argument.
2555)code" TVM_ADD_FILELINE)
2556 .set_num_inputs(1)
2557 .add_argument("data", "Tensor", "The input tensor.")
2558 .set_support_level(4)
2559 .add_type_rel("BroadCastTo", BroadCastToRel)
2560 .set_attrs_type<InitOpAttrs>()
2561 .set_attr<FTVMCompute>("FTVMCompute", BroadCastToCompute)
2562 .set_attr<TOpPattern>("TOpPattern", kBroadcast);
2563
2564// BroadCastToLike: <A, B> -> B where BroadCast(A, B) = B
2565bool BroadCastToLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
2566 const TypeReporter& reporter) {
2567 ICHECK_EQ(types.size(), 3);
2568 reporter->Assign(types[2], types[1]);
2569 return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
2570}
2571
2572Expr MakeBroadCastToLike(Expr data, Expr broadcast_type) {
2573 static const Op& op = Op::Get("broadcast_to_like");
2574 return Call(op, {data, broadcast_type}, Attrs(), {});
2575}
2576
2577Array<te::Tensor> BroadCastToLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
2578 const Type& out_type) {
2579 const auto* out_ttype = out_type.as<TensorTypeNode>();
2580 ICHECK(out_ttype != nullptr);
2581 return {topi::broadcast_to(inputs[0], out_ttype->shape)};
2582}
2583
2584TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to_like").set_body_typed(MakeBroadCastToLike);
2585
2586RELAY_REGISTER_OP("broadcast_to_like")
2587 .describe(R"code(Broadcast the first input to match the shape of the second input.
2588)code" TVM_ADD_FILELINE)
2589 .set_num_inputs(2)
2590 .add_argument("data", "Tensor", "The input tensor.")
2591 .add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.")
2592 .set_support_level(10)
2593 .add_type_rel("BroadCastToLike", BroadCastToLikeRel)
2594 .set_attr<FTVMCompute>("FTVMCompute", BroadCastToLikeCompute)
2595 .set_attr<TOpPattern>("TOpPattern", kBroadcast);
2596
2597// Adapter function to make int array.
2598Array<Integer> GetIntArray(Array<IndexExpr> arr) {
2599 for (size_t i = 0; i < arr.size(); ++i) {
2600 ICHECK(!arr[i].defined() || arr[i].as<IntImmNode>()) << "Expect an int array";
2601 }
2602 return Downcast<Array<Integer>>(arr);
2603}
2604
2605// strided_slice
2606TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
2607
2608bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
2609 const TypeReporter& reporter) {
2610 ICHECK_EQ(types.size(), 2);
2611 const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
2612 if (param == nullptr) {
2613 return false;
2614 }
2615 const auto* data = types[0].as<TensorTypeNode>();
2616
2617 if (data == nullptr) {
2618 return false;
2619 }
2620
2621 ICHECK(param->begin) << "strided_slice received invalid begin " << param->begin;
2622 ICHECK(param->end) << "strided_slice received invalid end " << param->end;
2623 ICHECK(param->strides) << "strided_slice received invalid strides " << param->strides;
2624
2625 auto begin = param->begin.value();
2626 auto end = param->end.value();
2627 auto strides = param->strides.value();
2628
2629 const size_t src_tensor_dim = static_cast<size_t>(data->shape.size());
2630 Array<Integer> axes;
2631 if (param->axes) {
2632 axes = param->axes.value();
2633 ICHECK(axes.size() == begin.size() && axes.size() == end.size() &&
2634 axes.size() == strides.size())
2635 << "axes, begin, end, and strides must have the same length";
2636 } else {
2637 for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
2638
2639 const IntImm one = IntImm(DataType::Int(64), 1);
2640 const IntImm zero = IntImm(DataType::Int(64), 0);
2641 const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits<int64_t>::max());
2642
2643 for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
2644 strides.push_back(one);
2645 }
2646 for (size_t i = begin.size(); i < src_tensor_dim; ++i) {
2647 begin.push_back(topi::GetConstInt(strides[i]) > 0 ? zero : max_range);
2648 }
2649 for (size_t i = end.size(); i < src_tensor_dim; ++i) {
2650 end.push_back(topi::GetConstInt(strides[i]) < 0 ? zero : max_range);
2651 }
2652 }
2653 auto oshape =
2654 topi::StridedSliceOutputShape(data->shape, begin, end, strides, axes, param->slice_mode);
2655 reporter->Assign(types[1], TensorType(oshape, data->dtype));
2656 return true;
2657}
2658
2659InferCorrectLayoutOutput StridedSliceInferCorrectLayout(
2660 const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
2661 const Array<tvm::relay::Type>& old_in_types) {
2662 Array<Array<IndexExpr>> old_in_shapes;
2663 for (auto old_in_t : old_in_types) {
2664 ICHECK(old_in_t.as<TensorTypeNode>());
2665 old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
2666 }
2667
2668 ICHECK(old_in_layouts.defined());
2669 ICHECK_GE(old_in_layouts.size(), 1);
2670 ICHECK(old_in_shapes.defined());
2671 ICHECK_GE(old_in_shapes.size(), 1);
2672
2673 auto layout = old_in_layouts[0];
2674 InferCorrectLayoutOutput out_default{{Layout::Undef()}, {Layout::Undef()}, attrs};
2675
2676 if (layout.defined() && new_in_layouts.defined()) {
2677 ICHECK_GE(new_in_layouts.size(), 1);
2678 auto new_layout = new_in_layouts[0];
2679 auto shape = old_in_shapes[0];
2680
2681 const auto* attrs_ptr = attrs.as<StridedSliceAttrs>();
2682 ICHECK(attrs_ptr);
2683 ObjectPtr<StridedSliceAttrs> params = make_object<StridedSliceAttrs>(*attrs_ptr);
2684
2685 Array<Integer> begin, end, strides;
2686 if (params->begin && params->end && params->strides) {
2687 for (Integer i : params->strides.value()) {
2688 ICHECK(i.defined());
2689 auto slice_val = Integer(IntImm(i->dtype, i->value));
2690 strides.push_back(params->slice_mode == "size" ? Integer(IntImm(i->dtype, 1)) : slice_val);
2691 }
2692
2693 for (Integer i : params->begin.value()) {
2694 ICHECK(i.defined());
2695 begin.push_back(IntImm(i->dtype, i->value));
2696 }
2697 for (Integer i : params->end.value()) {
2698 ICHECK(i.defined());
2699 end.push_back(IntImm(i->dtype, i->value));
2700 }
2701 }
2702
2703 Array<Integer> new_begin, new_end, new_strides;
2704
2705 // Handles layout conversion like NHWC -> NCHW
2706 auto old_layout_name = layout.name();
2707 auto new_layout_name = new_layout.name();
2708
2709 if (old_layout_name.rfind(new_layout_name, 0) != 0 &&
2710 new_layout_name.rfind(old_layout_name, 0) != 0) {
2711 if (old_layout_name.size() != new_layout_name.size()) {
          // Converting between layouts of different lengths, e.g. NHW4c -> NCHW, is not supported.
2713 return out_default;
2714 } else {
2715 if (params->axes) {
2716 auto axes = params->axes.value();
2717 Array<Integer> new_axes;
2718
2719 for (size_t i = 0; i < axes.size(); ++i) {
2720 auto old_idx = axes[i].IntValue();
2721 auto new_idx = new_layout.IndexOf(layout[old_idx]);
2722 new_begin.push_back(begin[i]);
2723 new_end.push_back(end[i]);
2724 new_strides.push_back(strides[i]);
2725 new_axes.push_back(new_idx);
2726 }
2727 params->axes = new_axes;
2728
2729 } else {
2730 for (size_t i = 0; i < new_layout_name.size(); ++i) {
2731 auto index = layout.IndexOf(new_layout[i]);
2732 if (index == -1) {
2733 return out_default;
2734 }
2735
2736 size_t new_index = static_cast<size_t>(index);
2737 int64_t bg, ed, st;
2738 if (strides.defined() && new_index < strides.size() && strides[new_index].defined()) {
2739 st = strides[new_index]->value;
2740 } else {
2741 st = 1;
2742 }
2743 if (new_index < begin.size() && begin[new_index].defined()) {
2744 bg = begin[new_index]->value;
2745 } else {
2746 bg = 0;
2747 }
2748 if (new_index < end.size() && end[new_index].defined()) {
2749 ed = end[new_index]->value;
2750 } else {
2751 ed = shape[new_index].as<IntImmNode>()->value;
2752 }
2753
2754 new_begin.push_back(IntImm(begin[0]->dtype, bg));
2755 new_end.push_back(IntImm(end[0]->dtype, ed));
2756 new_strides.push_back(IntImm(strides[0]->dtype, st));
2757 }
2758 }
2759
2760 params->begin = new_begin;
2761 params->end = new_end;
2762 params->strides = new_strides;
2763 layout = new_layout;
2764 }
  } else if (old_layout_name.size() <
             new_layout_name.size()) {  // only handle the expanding direction (e.g. NCHW -> NCHW4c);
                                        // contractions such as NCHW4c -> NCHW keep the old layout
2767 if (params->axes) {
2768 auto axes = params->axes.value();
2769 Array<Integer> new_axes;
2770 for (size_t i = 0; i < axes.size(); ++i) {
2771 auto old_idx = axes[i].IntValue();
2772 auto new_idx = new_layout.IndexOf(layout[old_idx]);
2773 new_axes.push_back(new_idx);
2774
2775 const LayoutAxis& axis = layout[old_idx];
2776 ICHECK(axis.IsPrimal());
2777 auto factor = new_layout.FactorOf(axis);
2778 if (factor == -1) {
2779 new_begin.push_back(begin[i]);
2780 new_end.push_back(end[i]);
2781 } else {
2782 if (strides.defined() && i < strides.size()) {
2783 auto stride = strides[i];
2784 // arbitrary stride is not supported
2785 if (stride.defined() && stride->value != 1) {
2786 return out_default;
2787 }
2788 }
2789 int64_t bg = begin[i].IntValue();
2790 int64_t ed = end[i].IntValue();
2791 if (bg % factor || ed % factor) {
2792 // transform to original layout
2793 return out_default;
2794 }
2795 new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor)));
2796 new_end.push_back(IntImm(end[0]->dtype, (ed / factor)));
2797 }
2798 }
2799 params->axes = new_axes;
2800
2801 } else {
2802 for (size_t i = 0; i < begin.size(); i++) {
2803 const LayoutAxis& axis = layout[i];
2804 ICHECK(axis.IsPrimal());
2805 auto factor = new_layout.FactorOf(axis);
2806 if (factor == -1) {
2807 new_begin.push_back(IntImm(begin[i]->dtype, begin[i].IntValue()));
2808 new_end.push_back(IntImm(end[i]->dtype, end[i].IntValue()));
2809 } else {
2810 if (strides.defined() && i < strides.size()) {
2811 auto stride = strides[i];
2812 // arbitrary stride is not supported
2813 if (stride.defined() && stride->value != 1) {
2814 return out_default;
2815 }
2816 }
2817 int64_t bg = begin[i].defined() ? begin[i]->value : 0;
2818 int64_t ed;
2819 if (!end[i].defined()) {
2820 ed = shape[i].as<IntImmNode>()->value;
2821 } else if (params->slice_mode == "size") {
2822 if (end[i]->value < 0) {
2823 ed = shape[i].as<IntImmNode>()->value;
2824 } else {
2825 ed = bg + end[i]->value;
2826 }
2827 } else {
2828 ed = end[i]->value;
2829 }
2830
2831 if (bg % factor || ed % factor) {
2832 // transform to original layout
2833 return out_default;
2834 }
2835 new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor)));
2836 new_end.push_back(IntImm(end[0]->dtype, (ed / factor)));
2837 }
2838 }
2839 }
2840
2841 layout = new_layout;
2842 params->begin = new_begin;
2843 params->end = new_end;
2844 }
2845 return InferCorrectLayoutOutput({layout}, {layout}, Attrs(params));
2846 }
2847 return InferCorrectLayoutOutput({layout}, {layout}, attrs);
2848}
2849
2850Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
2851 const Type& out_type) {
2852 const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
2853 ICHECK(param != nullptr);
2854 ICHECK(param->begin && param->end && param->strides);
2855 Array<Integer> begin = param->begin.value();
2856 Array<Integer> end = param->end.value();
2857 Array<Integer> strides = param->strides.value();
2858 if (param->axes) {
2859 auto axes = param->axes.value();
2860 return Array<te::Tensor>{
2861 topi::strided_slice_with_axes(inputs[0], begin, end, strides, axes, param->slice_mode)};
2862 }
2863 return Array<te::Tensor>{topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)};
2864}
2865
2866// Positional relay function to create StridedSlice operator used by frontend FFI.
2867Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides,
2868 String slice_mode, Optional<Array<Integer>> axes) {
2869 auto attrs = make_object<StridedSliceAttrs>();
2870 attrs->begin = std::move(begin);
2871 attrs->end = std::move(end);
2872 attrs->strides = std::move(strides);
2873 attrs->slice_mode = slice_mode;
2874 attrs->axes = std::move(axes);
2875 static const Op& op = Op::Get("strided_slice");
2876 return Call(op, {data}, Attrs(attrs), {});
2877}
2878
2879TVM_REGISTER_GLOBAL("relay.op._make.strided_slice").set_body_typed(MakeStridedSlice);
2880
2881RELAY_REGISTER_OP("strided_slice")
2882 .describe(R"code(Strided slice of an array.
2883
2884Examples::
2885
2886 x = [[ 1., 4., 7., 10.],
2887 [ 2., 5., 8., 11.],
2888 [ 3., 6., 9., 12.]]
2889
2890 strided_slice(x, begin=[0, 1], end=[2, 4], stride=[1, 1]) = [[ 4., 7., 10.],
2891 [ 5., 8., 11.]]
2892
2893 x = [[[ 1., 2.],
2894 [ 3., 4.]],
2895
2896 [[ 5., 6.],
2897 [ 7., 8.]]]
2898
2899 strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1., 2.],
2900 [ 3., 4.]],
2901
2902 [[ 5., 6.],
2903 [ 7., 8.]]]
2904)code" TVM_ADD_FILELINE)
2905 .set_num_inputs(1)
2906 .add_argument("data", "Tensor", "The input tensor.")
2907 .set_support_level(4)
2908 .set_attrs_type<StridedSliceAttrs>()
2909 .add_type_rel("StridedSlice", StridedSliceRel)
2910 .set_attr<FTVMCompute>("FTVMCompute", StridedSliceCompute)
2911 .set_attr<TOpPattern>("TOpPattern", kInjective)
2912 .set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions)
2913 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", StridedSliceInferCorrectLayout);
2914
2915// strided_set
2916bool StridedSetRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
2917 const TypeReporter& reporter) {
2918 ICHECK_EQ(types.size(), 6);
2919 reporter->Assign(types[5], types[0]);
2920 return true;
2921}
2922
2923Expr MakeStridedSet(Expr data, Expr v, Expr begin, Expr end, Expr strides) {
2924 static const Op& op = Op::Get("strided_set");
2925 return Call(op, {data, v, begin, end, strides}, {});
2926}
2927
2928TVM_REGISTER_GLOBAL("relay.op._make.strided_set").set_body_typed(MakeStridedSet);
2929
2930RELAY_REGISTER_OP("strided_set")
2931 .describe(R"code(Strided set of an array.
2932Example::
2933
2934 x = [[ 1., 4., 7., 10.],
2935 [ 2., 5., 8., 11.],
2936 [ 3., 6., 9., 12.]]
2937
2938 v = [[ 11., 22., 33.]
2939 [ 44., 55., 66.]]
2940
2941 strided_set(x, v, begin=[0, 1], end=[2, 4], stride=[1, 1]) = \
2942 [[ 1., 11., 22., 33.],
2943 [ 2., 44., 55., 66.],
2944 [ 3., 6., 9., 12.]]
2945)code" TVM_ADD_FILELINE)
2946 .set_num_inputs(5)
2947 .add_argument("data", "Tensor", "The input tensor.")
2948 .add_argument("v", "Tensor", "The data to set.")
2949 .add_argument("begin", "Tensor", "Indices for the start of the slice.")
2950 .add_argument("end", "Tensor", "Indices indicating the end of the slice.")
2951 .add_argument("strides", "Tensor", "The strides values.")
2952 .set_support_level(4)
2953 .set_attr<TOpPattern>("TOpPattern", kInjective)
2954 .add_type_rel("StridedSet", StridedSetRel);
2955
2956// relay.split
2957TVM_REGISTER_NODE_TYPE(SplitAttrs);
2958
2959InferCorrectLayoutOutput SplitInferCorrectLayout(const Attrs& attrs,
2960 const Array<Layout>& new_in_layouts,
2961 const Array<Layout>& old_in_layouts,
2962 const Array<tvm::relay::Type>& old_in_types) {
2963 const auto* attrs_ptr = attrs.as<SplitAttrs>();
2964 ICHECK(attrs_ptr);
2965 ObjectPtr<SplitAttrs> param = make_object<SplitAttrs>(*attrs_ptr);
2966
2967 Array<Array<IndexExpr>> old_in_shapes;
2968 for (auto old_in_t : old_in_types) {
2969 ICHECK(old_in_t.as<TensorTypeNode>());
2970 old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
2971 }
2972
2973 size_t axis =
2974 param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis);
2975
2976 Layout ret = Layout::Undef();
2977 size_t size = 0;
2978 if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) {
2979 size = sections->value;
2980 } else {
2981 size = Downcast<Array<Integer>>(param->indices_or_sections).size() + 1;
2982 }
2983
2984 // If new_in_layouts are defined, this code tries to modify the layout.
2985 if (new_in_layouts.defined() && old_in_layouts.defined()) {
2986 bool divisible = true;
2987 const auto& sp_dim = old_in_layouts[0][axis];
2988 auto new_index = new_in_layouts[0].IndexOf(sp_dim);
2989 param->axis = new_index;
2990 int factor = new_in_layouts[0].FactorOf(sp_dim);
2991 if (factor > 1) {
2992 if (!param->indices_or_sections.as<IntImmNode>()) {
2993 auto ios = Downcast<Array<Integer>>(param->indices_or_sections);
2994 Array<Integer> new_ios;
2995 for (const auto& v : ios) {
2996 const IntImmNode* vint = v.as<IntImmNode>();
2997 new_ios.push_back(vint->value / factor);
2998 if (vint->value % factor) {
2999 divisible = false;
3000 }
3001 }
3002 if (divisible) {
3003 param->indices_or_sections = new_ios;
3004 }
3005 }
3006 }
3007 if (divisible) {
3008 ret = new_in_layouts[0];
3009 } else {
3010 ret = old_in_layouts[0];
3011 }
3012 } else if (old_in_layouts.defined()) {
3013 ret = old_in_layouts[0];
3014 }
3015
3016 return InferCorrectLayoutOutput({ret}, {Array<Layout>(size, ret)}, Attrs(param));
3017}
3018
3019bool SplitRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
3020 const TypeReporter& reporter) {
3021 // `types` contains: [data, result]
3022 ICHECK_EQ(types.size(), 2);
3023 const auto* data = types[0].as<TensorTypeNode>();
3024 if (data == nullptr) return false;
3025 ICHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
3026 const auto param = attrs.as<SplitAttrs>();
3027 ICHECK(param != nullptr);
3028 auto axis = param->axis;
3029 if (axis < 0) {
3030 axis += data->shape.size();
3031 }
3032 ICHECK_LT(axis, data->shape.size()) << "axis should be within the input dimension range.";
3033 ICHECK_GE(axis, 0) << "axis should be within the input dimension range.";
3034
3035 if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) {
3036 if (!data->shape[axis].as<AnyNode>()) {
3037 ICHECK(reporter->Assert(indexmod(data->shape[axis], sections->value) ==
3038 tir::make_zero(DataType::Int(64))))
        << "indices_or_sections needs to evenly divide input.shape[axis]";
3040 }
3041 std::vector<Type> fields;
3042 for (int i = 0; i < sections->value; ++i) {
3043 std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
3044 if (data->shape[axis].as<AnyNode>()) {
3045 oshape[axis] = Any();
3046 } else {
3047 oshape[axis] = indexdiv(oshape[axis], sections->value);
3048 }
3049 auto vec_type = TensorType(oshape, data->dtype);
3050 fields.push_back(vec_type);
3051 }
3052 reporter->Assign(types[1], TupleType(Array<Type>(fields)));
3053 } else {
3054 Array<IndexExpr> indices;
3055 for (auto i : Downcast<Array<Integer>>(param->indices_or_sections)) {
3056 indices.push_back(IntImm(DataType::Int(32), i.as<IntImmNode>()->value));
3057 }
3058 auto begin = IndexExpr(tir::make_zero(DataType::Int(32)));
3059 std::vector<Type> fields;
3060 for (unsigned int i = 0; i < indices.size(); ++i) {
3061 ICHECK(reporter->Assert(indices[i] > begin))
          << "indices_or_sections needs to be a sorted ascending list";
3063 std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
3064 oshape[axis] = indices[i] - begin;
3065 begin = indices[i];
3066 auto vec_type = TensorType(oshape, data->dtype);
3067 fields.push_back(vec_type);
3068 }
3069 if (!data->shape[axis].as<AnyNode>()) {
3070 ICHECK(reporter->Assert(begin < data->shape[axis]))
          << "The last split index must be strictly smaller than input.shape[axis]";
3072 }
3073 std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end());
3074 if (data->shape[axis].as<AnyNode>()) {
3075 oshape[axis] = Any();
3076 } else {
3077 oshape[axis] = data->shape[axis] - begin;
3078 }
3079 auto vec_type = TensorType(oshape, data->dtype);
3080 fields.push_back(vec_type);
3081 reporter->Assign(types[1], TupleType(Array<Type>(fields)));
3082 }
3083 return true;
3084}
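// Illustrative shape examples for the relation above:
//  - data Tensor[(4, 6), float32], indices_or_sections = 3, axis = 1
//    -> a tuple of three Tensor[(4, 2), float32].
//  - data Tensor[(4, 6), float32], indices_or_sections = [2, 5], axis = 1
//    -> (Tensor[(4, 2)], Tensor[(4, 3)], Tensor[(4, 1)]), i.e. the slices
//       [0, 2), [2, 5) and [5, 6) along axis 1.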
3085
3086Array<te::Tensor> SplitCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
3087 const Type& out_type) {
3088 const auto param = attrs.as<SplitAttrs>();
3089 ICHECK(param != nullptr);
3090
3091 if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) {
3092 int64_t num_sections = sections->value;
3093 return Array<te::Tensor>{topi::split_sections(inputs[0], num_sections, param->axis)};
3094 } else {
3095 Array<PrimExpr> indices;
3096 for (auto i : Downcast<Array<Integer>>(param->indices_or_sections)) {
3097 indices.push_back(IntImm(DataType::Int(32), i.as<IntImmNode>()->value));
3098 }
3099 return Array<te::Tensor>{topi::split(inputs[0], indices, param->axis)};
3100 }
3101}
3102
3103Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis) {
3104 auto attrs = make_object<SplitAttrs>();
3105 attrs->axis = axis;
3106 attrs->indices_or_sections = std::move(indices_or_sections);
3107 static const Op& op = Op::Get("split");
3108 return Call(op, {data}, Attrs(attrs), {});
3109}
3110
3111TVM_REGISTER_GLOBAL("relay.op._make.split").set_body([](const TVMArgs& args, TVMRetValue* rv) {
3112 if (args.type_codes[1] == kDLInt) {
3113 // Note: we change it from Int(64) to Int(32) for now as
3114 // combine_parallel_dense will transform the graph with Int(32).
    // More investigation is needed to determine which one we should use.
3116 *rv =
3117 MakeSplit(args[0], tir::make_const(DataType::Int(32), static_cast<int>(args[1])), args[2]);
3118 } else {
3119 *rv = MakeSplit(args[0], args[1], args[2]);
3120 }
3121});
3122
3123RELAY_REGISTER_OP("split")
3124 .describe(R"code(Splits an array along a particular axis into multiple sub-arrays.
3125
Indices or sections to split into. Accepts an int or a tuple.
If indices_or_sections is an integer, the input will be divided equally
along the given axis. If such a split is not possible, an error is raised.
3129
3130If indices_or_sections is a tuple of sorted integers,
3131the entries indicate where along axis the array is split.
3132
3133)code" TVM_ADD_FILELINE)
3134 .set_attrs_type<SplitAttrs>()
3135 .set_num_inputs(1)
3136 .add_argument("data", "Tensor", "The input tensor.")
3137 .set_support_level(3)
3138 .add_type_rel("Split", SplitRel)
3139 .set_attr<FTVMCompute>("FTVMCompute", SplitCompute)
3140 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", SplitInferCorrectLayout)
3141 .set_attr<TOpPattern>("TOpPattern", kInjective);
3142
3143// relay.slice_like
3144TVM_REGISTER_NODE_TYPE(SliceLikeAttrs);
3145
3146/*!
3147 * \brief SliceLikeRel User defined type constraint function.
3148 * \param num_inputs Number of input types in the args.
3149 * \param attrs The additional attributes of the operator.
3150 * \param reporter The reporter to report solution to.
3151 * \return False if the relation has not been resolved, it might be resolved later.
3152 * True if this relation has been resolved.
3153 */
3154bool SliceLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
3155 const TypeReporter& reporter) {
3156 ICHECK_EQ(types.size(), 3);
3157 const auto* data = types[0].as<TensorTypeNode>();
3158 if (data == nullptr) {
3159 return false;
3160 }
3161
3162 const auto* target = types[1].as<TensorTypeNode>();
3163 if (target == nullptr) {
3164 return false;
3165 }
3166
3167 const auto param = attrs.as<SliceLikeAttrs>();
3168 ICHECK(param != nullptr);
3169
3170 const Array<IndexExpr>& dshape = data->shape;
3171 const Array<IndexExpr>& target_shape = target->shape;
3172 std::vector<IndexExpr> oshape(dshape.begin(), dshape.end());
3173
3174 if (!param->axes.defined()) {
3175 for (size_t i = 0; i < dshape.size(); ++i) {
3176 if (i < target_shape.size()) {
3177 oshape[i] = target_shape[i];
3178 ICHECK(reporter->Assert(oshape[i] <= dshape[i]))
3179 << "End index of axis " << i << " exceeds input shape: " << oshape[i] << " vs "
3180 << dshape[i];
3181 }
3182 }
3183 } else {
3184 ICHECK(param->axes.size() != 0) << "Axes cannot be empty.";
3185 for (Integer val : param->axes) {
3186 int axis = val->value;
3187 if (axis < 0) {
3188 axis += dshape.size();
3189 }
3190 ICHECK(axis < static_cast<int>(target_shape.size()))
3191 << "Axis " << axis << " exceeds dimension " << target_shape.size() << " of target_shape.";
3192 oshape[axis] = target_shape[axis];
3193 ICHECK(reporter->Assert(oshape[axis] <= dshape[axis]))
3194 << "End index of axis " << axis << " exceeds input shape: " << oshape[axis] << " vs "
3195 << dshape[axis];
3196 }
3197 }
3198
3199 reporter->Assign(types[2], TensorType(oshape, data->dtype));
3200 return true;
3201}
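// Illustrative example: data Tensor[(3, 4, 5)] and shape_like Tensor[(1, 2, 3)] with
// axes = [1, 2] give Tensor[(3, 2, 3)]; with axes undefined every axis covered by
// shape_like is sliced, giving Tensor[(1, 2, 3)].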
3202
3203Expr MakeSliceLike(Expr data, Expr shape_like, Array<Integer> axes) {
3204 auto attrs = make_object<SliceLikeAttrs>();
3205 attrs->axes = std::move(axes);
3206 static const Op& op = Op::Get("slice_like");
3207 return Call(op, {data, shape_like}, Attrs(attrs), {});
3208}
3209
3210InferCorrectLayoutOutput SliceLikeInferCorrectLayout(const Attrs& attrs,
3211 const Array<Layout>& new_in_layouts,
3212 const Array<Layout>& old_in_layouts,
3213 const Array<tvm::relay::Type>& old_in_types) {
3214 Array<Integer> new_axes;
3215 if (old_in_layouts.defined() && new_in_layouts.defined()) {
3216 ICHECK_EQ(new_in_layouts.size(), 2);
3217 ICHECK_EQ(new_in_layouts[0]->name, new_in_layouts[1]->name);
3218 ICHECK_EQ(old_in_layouts.size(), 2);
3219 ICHECK_EQ(old_in_layouts[0]->name, old_in_layouts[1]->name);
3220
3221 auto old_layout = old_in_layouts[0];
3222 auto new_layout = new_in_layouts[0];
3223
3224 const auto* attrs_ptr = attrs.as<SliceLikeAttrs>();
3225 ICHECK(attrs_ptr);
3226 ObjectPtr<SliceLikeAttrs> params = make_object<SliceLikeAttrs>(*attrs_ptr);
3227
3228 for (auto axis : params->axes) {
3229 auto new_axis = new_layout.IndexOf(old_layout[axis->value]);
3230 // Cannot find the target axis in the new layout.
3231 if (new_axis == -1) {
3232 new_axes.clear();
3233 break;
3234 }
3235 new_axes.push_back(new_axis);
3236 }
3237 if (!new_axes.empty()) {
3238 params->axes = std::move(new_axes);
3239 return InferCorrectLayoutOutput({new_layout, new_layout}, {new_layout}, Attrs(params));
3240 }
3241 }
3242
3243 if (old_in_layouts.defined()) {
3244 ICHECK_EQ(old_in_layouts.size(), 2);
3245 return InferCorrectLayoutOutput({old_in_layouts[0], old_in_layouts[1]}, {old_in_layouts[1]},
3246 attrs);
3247 }
3248 return InferCorrectLayoutOutput({Layout::Undef(), Layout::Undef()}, {Layout::Undef()}, attrs);
3249}
3250
3251Array<te::Tensor> SliceLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
3252 const Type& out_type) {
3253 const auto* param = attrs.as<SliceLikeAttrs>();
3254 ICHECK(param != nullptr);
3255 Array<IndexExpr> src_shape = inputs[0]->shape;
3256 Array<IndexExpr> target_shape = inputs[1]->shape;
3257 Array<Integer> begin_idx, end_idx, strides;
3258 for (size_t i = 0; i < src_shape.size(); ++i) {
3259 begin_idx.push_back(0);
3260 strides.push_back(1);
3261 }
3262 for (auto s : src_shape) {
3263 ICHECK(s->IsInstance<tvm::IntImmNode>()) << "slice_like does not support dynamic input shape";
3264 end_idx.push_back(topi::GetConstInt(s));
3265 }
3266 if (!param->axes.defined()) {
3267 for (size_t i = 0; i < src_shape.size(); ++i) {
3268 if (i < target_shape.size()) {
3269 ICHECK(target_shape[i]->IsInstance<tvm::IntImmNode>())
3270 << "slice_like does not support dynamic output shape";
3271 end_idx.Set(i, topi::GetConstInt(target_shape[i]));
3272 ICHECK_LE(topi::GetConstInt(end_idx[i]), topi::GetConstInt(src_shape[i]))
3273 << "End index of axis " << i
3274 << " exceeds input shape: " << topi::GetConstInt(end_idx[i]) << " vs "
3275 << topi::GetConstInt(src_shape[i]);
3276 }
3277 }
3278 } else {
3279 for (Integer axis : param->axes) {
3280 int a = axis.IntValue();
3281 if (a < 0) {
3282 a = static_cast<int>(src_shape.size()) + a;
3283 }
3284 ICHECK(target_shape[a]->IsInstance<tvm::IntImmNode>())
3285 << "slice_like does not support dynamic output shape";
3286 end_idx.Set(a, topi::GetConstInt(target_shape[a]));
3287 ICHECK_LE(topi::GetConstInt(end_idx[a]), topi::GetConstInt(src_shape[a]))
3288 << "End index of axis " << a << " exceeds input shape: " << topi::GetConstInt(end_idx[a])
3289 << " vs " << topi::GetConstInt(src_shape[a]);
3290 }
3291 }
3292 return Array<te::Tensor>{topi::strided_slice(inputs[0], begin_idx, end_idx, strides, "end")};
3293}
3294
3295TVM_REGISTER_GLOBAL("relay.op._make.slice_like").set_body_typed(MakeSliceLike);
3296
3297RELAY_REGISTER_OP("slice_like")
    .describe(R"code(Slice the first input with respect to the second input.
3299)code" TVM_ADD_FILELINE)
3300 .set_attrs_type<SliceLikeAttrs>()
3301 .set_num_inputs(2)
3302 .add_argument("data", "Tensor", "The input tensor.")
3303 .add_argument("shape_like", "Tensor", "Shape tensor.")
3304 .set_support_level(10)
3305 .add_type_rel("SliceLike", SliceLikeRel)
3306 .set_attr<FTVMCompute>("FTVMCompute", SliceLikeCompute)
3307 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", SliceLikeInferCorrectLayout)
3308 .set_attr<TOpPattern>("TOpPattern", kInjective);
3309
3310// relay.layout_transform
3311TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
3312
3313Array<te::Tensor> LayoutTransformCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
3314 const Type& out_type) {
3315 const auto* param = attrs.as<LayoutTransformAttrs>();
3316 ICHECK(param != nullptr);
3317 return Array<te::Tensor>{topi::layout_transform(inputs[0], param->src_layout, param->dst_layout)};
3318}
3319
3320bool LayoutTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
3321 const TypeReporter& reporter) {
3322 const auto* data = types[0].as<TensorTypeNode>();
3323 if (data == nullptr) {
3324 ICHECK(types[0].as<IncompleteTypeNode>())
3325 << "LayoutTransform: expect input data type to be TensorType but get " << types[0];
3326 return false;
3327 }
3328 const LayoutTransformAttrs* params = attrs.as<LayoutTransformAttrs>();
3329
3330 Layout src_layout(params->src_layout);
3331 Layout dst_layout(params->dst_layout);
3332
3333 ICHECK(src_layout.defined() && dst_layout.defined()) << "cannot convert from/to undefined layout";
3334 auto layout_converter = tir::BijectiveLayout(src_layout, dst_layout);
3335 ICHECK(layout_converter.defined())
3336 << "cannot convert from " << params->src_layout << " to " << params->dst_layout;
3337
3338 const auto& out_shape = layout_converter.ForwardShape(data->shape);
3339 reporter->Assign(types[1], TensorType(out_shape, data->dtype));
3340 return true;
3341}
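// Illustrative example: with src_layout = "NCHW" and dst_layout = "NCHW16c", an input of
// type Tensor[(1, 32, 224, 224)] is assigned the output type Tensor[(1, 2, 224, 224, 16)]
// via BijectiveLayout::ForwardShape.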
3342
3343Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout) {
3344 auto attrs = make_object<LayoutTransformAttrs>();
3345 attrs->src_layout = std::move(src_layout);
3346 attrs->dst_layout = std::move(dst_layout);
3347 static const Op& op = Op::Get("layout_transform");
3348 return Call(op, {data}, Attrs(attrs), {});
3349}
3350
3351TVM_REGISTER_GLOBAL("relay.op._make.layout_transform").set_body_typed(MakeLayoutTransform);
3352
3353RELAY_REGISTER_OP("layout_transform")
3354 .describe(R"code(Transform the input data layout.
3355
3356For transforming from NCHW to N16cHWC, the `__layout_transform__` operator reshapes
3357the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
3358
3359)code" TVM_ADD_FILELINE)
3360 .set_attrs_type<LayoutTransformAttrs>()
3361 .set_num_inputs(1)
3362 .add_argument("data", "Tensor", "The input tensor.")
3363 .add_type_rel("layout_transform", LayoutTransformRel)
3364 .set_support_level(5)
3365 .set_attr<FTVMCompute>("FTVMCompute", LayoutTransformCompute);
3366
3367// relay.auto_scheduler_layout_transform
3368TVM_REGISTER_NODE_TYPE(AutoSchedulerLayoutTransformAttrs);
3369
3370Array<te::Tensor> AutoSchedulerLayoutTransformCompute(const Attrs& attrs,
3371 const Array<te::Tensor>& inputs,
3372 const Type& out_type) {
3373 const auto* param = attrs.as<AutoSchedulerLayoutTransformAttrs>();
3374 CHECK(param != nullptr);
3375 return Array<te::Tensor>{
3376 topi::auto_scheduler_layout_transform(inputs[0], param->src_layout, param->dst_layout)};
3377}
3378
3379bool AutoSchedulerLayoutTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
3380 const TypeReporter& reporter) {
3381 const auto* data = types[0].as<TensorTypeNode>();
3382 CHECK(data != nullptr);
3383 const AutoSchedulerLayoutTransformAttrs* params = attrs.as<AutoSchedulerLayoutTransformAttrs>();
3384
3385 Array<IndexExpr> dst_shape;
3386 std::vector<std::string> dst_axes;
3387
3388 topi::parse_auto_scheduler_layout(params->dst_layout, &dst_shape, &dst_axes);
3389
3390 reporter->Assign(types[1], TensorType(dst_shape, data->dtype));
3391 return true;
3392}
3393
3394Expr MakeAutoSchedulerLayoutTransform(Expr data, String src_layout, String dst_layout) {
3395 auto attrs = make_object<AutoSchedulerLayoutTransformAttrs>();
3396 attrs->src_layout = std::move(src_layout);
3397 attrs->dst_layout = std::move(dst_layout);
3398 static const Op& op = Op::Get("auto_scheduler_layout_transform");
3399 return Call(op, {data}, Attrs(attrs), {});
3400}
3401
3402TVM_REGISTER_GLOBAL("relay.op._make.auto_scheduler_layout_transform")
3403 .set_body_typed(MakeAutoSchedulerLayoutTransform);
3404
3405RELAY_REGISTER_OP("auto_scheduler_layout_transform")
3406 .describe(R"code(Transform the input kernel layout.
3407)code" TVM_ADD_FILELINE)
3408 .set_attrs_type<AutoSchedulerLayoutTransformAttrs>()
3409 .set_num_inputs(1)
3410 .add_argument("data", "Tensor", "The input tensor.")
3411 .add_type_rel("auto_scheduler_layout_transform", AutoSchedulerLayoutTransformRel)
3412 .set_support_level(5)
3413 .set_attr<FTVMCompute>("FTVMCompute", AutoSchedulerLayoutTransformCompute);
3414
3415// relay.meta_schedule_layout_transform
3416TVM_REGISTER_NODE_TYPE(MetaScheduleLayoutTransformAttrs);
3417
3418Array<te::Tensor> MetaScheduleLayoutTransformCompute(const Attrs& attrs,
3419 const Array<te::Tensor>& inputs,
3420 const Type& out_type) {
3421 const auto* param = attrs.as<MetaScheduleLayoutTransformAttrs>();
3422 CHECK(param != nullptr);
3423 return Array<te::Tensor>{topi::meta_schedule_layout_transform(inputs[0], param->index_map)};
3424}
3425
3426bool MetaScheduleLayoutTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
3427 const TypeReporter& reporter) {
3428 TensorType data_type = Downcast<TensorType>(types[0]);
3429 const MetaScheduleLayoutTransformAttrs* params = attrs.as<MetaScheduleLayoutTransformAttrs>();
3430 ICHECK(params);
3431 Array<PrimExpr> new_shape = params->index_map->MapShape(data_type->shape);
3432 reporter->Assign(types[1], TensorType(new_shape, data_type->dtype));
3433 return true;
3434}
3435
3436Expr MakeMetaScheduleLayoutTransform(Expr data, tir::IndexMap index_map) {
3437 static const Op& op = Op::Get("meta_schedule_layout_transform");
3438 auto attrs = make_object<MetaScheduleLayoutTransformAttrs>();
3439 attrs->index_map = index_map;
3440 return Call(op, {data}, Attrs(attrs), {});
3441}
3442
3443TVM_REGISTER_GLOBAL("relay.op._make.meta_schedule_layout_transform")
3444 .set_body_typed(MakeMetaScheduleLayoutTransform);
3445
3446RELAY_REGISTER_OP("meta_schedule_layout_transform")
3447 .describe(R"code(Transform the input kernel layout.
3448)code" TVM_ADD_FILELINE)
3449 .set_attrs_type<MetaScheduleLayoutTransformAttrs>()
3450 .set_num_inputs(1)
3451 .add_argument("data", "Tensor", "The input tensor.")
3452 .add_type_rel("meta_schedule_layout_transform", MetaScheduleLayoutTransformRel)
3453 .set_support_level(5)
3454 .set_attr<FTVMCompute>("FTVMCompute", MetaScheduleLayoutTransformCompute);
3455
3456// relay._contrib_reverse_reshape
3457Expr MakeReverseReshape(Expr data, Array<Integer> newshape) {
3458 auto attrs = make_object<ReshapeAttrs>();
3459 attrs->newshape = std::move(newshape);
3460 static const Op& op = Op::Get("contrib_reverse_reshape");
3461 return Call(op, {data}, Attrs(attrs), {});
3462}
3463
3464TVM_REGISTER_GLOBAL("relay.op._make.contrib_reverse_reshape").set_body_typed(MakeReverseReshape);
3465
3466RELAY_REGISTER_OP("contrib_reverse_reshape")
3467 .describe(R"code(Reshapes the input array where the special values are inferred from
3468right to left.
3469
The special values have the same semantics as reshape. The difference is that
special values are inferred from right to left. This is illustrated by the
example below.

Example::

- data.shape = (10,5,4), newshape = (-1,0), reshape results in (40,5)
- data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (50,4)
3478
3479)code" TVM_ADD_FILELINE)
3480 .set_num_inputs(1)
3481 .set_attrs_type<ReshapeAttrs>()
3482 .add_argument("data", "Tensor", "The input tensor.")
3483 .set_support_level(10)
3484 .add_type_rel("ReverseReshape", ReverseReshapeRel)
3485 .set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
3486 .set_attr<TOpPattern>("TOpPattern", kInjective)
3487 .set_attr<TReshapeOp>("TReshapeOp", true);
3488
3489// gather operator
3490TVM_REGISTER_NODE_TYPE(GatherAttrs);
3491
3492bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
3493 const TypeReporter& reporter) {
3494 // `types` contains: [data, indices, result]
3495 ICHECK_EQ(types.size(), 3);
3496 const auto* data = types[0].as<TensorTypeNode>();
3497 const auto* indices = types[1].as<TensorTypeNode>();
3498 if (data == nullptr) {
3499 ICHECK(types[0].as<IncompleteTypeNode>())
3500 << "Gather: expect input data type to be TensorType but get " << types[0];
3501 return false;
3502 }
3503 if (indices == nullptr) {
3504 ICHECK(types[1].as<IncompleteTypeNode>())
3505 << "Gather: expect indices type to be TensorType but get " << types[1];
3506 return false;
3507 }
  ICHECK(indices->dtype.is_int()) << "indices of gather must be a tensor of integers";
3509 const auto param = attrs.as<GatherAttrs>();
3510 ICHECK(param != nullptr);
3511 ICHECK(param->axis.defined());
3512
3513 const auto ndim_data = data->shape.size();
3514 const auto ndim_indices = indices->shape.size();
3515 int axis = param->axis->value;
3516 ICHECK_EQ(ndim_data, ndim_indices);
3517 if (axis < 0) {
3518 axis += ndim_data;
3519 }
3520 ICHECK_GE(axis, 0);
3521 ICHECK_LT(axis, ndim_data);
3522
3523 std::vector<IndexExpr> oshape;
3524 oshape.reserve(ndim_data);
3525 for (size_t i = 0; i < ndim_data; ++i) {
3526 if (i == static_cast<size_t>(axis)) {
3527 if (indices->shape[i].as<IntImmNode>()) {
3528 const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]);
3529 ICHECK_GE(*indice_shape_i, 1);
3530 }
3531 } else {
3532 ICHECK(reporter->AssertEQ(indices->shape[i], data->shape[i]));
3533 }
3534 oshape.emplace_back(indices->shape[i]);
3535 }
3536 reporter->Assign(types[2], TensorType(oshape, data->dtype));
3537 return true;
3538}
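// Illustrative example: data Tensor[(3, 4), float32], indices Tensor[(3, 2), int32],
// axis = 1. Every non-axis dimension of indices must match data (3 == 3), and the output
// takes the shape of indices, Tensor[(3, 2), float32], with
// out[i][j] = data[i][indices[i][j]].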
3539
3540Array<te::Tensor> GatherCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
3541 const Type& out_type) {
3542 const auto* param = attrs.as<GatherAttrs>();
3543 return {topi::gather(inputs[0], param->axis.IntValue(), inputs[1])};
3544}
3545
3546Expr MakeGather(Expr data, Integer axis, Expr indices) {
3547 auto attrs = make_object<GatherAttrs>();
3548 attrs->axis = std::move(axis);
3549 static const Op& op = Op::Get("gather");
3550 return Call(op, {data, indices}, Attrs(attrs), {});
3551}
3552
3553TVM_REGISTER_GLOBAL("relay.op._make.gather").set_body_typed(MakeGather);
3554
3555RELAY_REGISTER_OP("gather")
3556 .describe(R"code(Gather values along given axis from given indices.
3557
3558E.g. for a 3D tensor, output is computed as:
3559
3560 out[i][j][k] = data[indices[i][j][k]][j][k] # if axis == 0
3561 out[i][j][k] = data[i][indices[i][j][k]][k] # if axis == 1
3562 out[i][j][k] = data[i][j][indices[i][j][k]] # if axis == 2
3563
``indices`` must have the same shape as ``data``, except at dimension ``axis``,
which may have any non-zero size. Output will have the same shape as ``indices``.
3566)code" TVM_ADD_FILELINE)
3567 .set_attrs_type<GatherAttrs>()
3568 .set_num_inputs(2)
3569 .add_argument("data", "Tensor", "The input data to the operator.")
3570 .add_argument("indices", "Tensor", "The indices of values to gather.")
3571 .set_support_level(3)
3572 .add_type_rel("Gather", GatherRel)
3573 .set_attr<FTVMCompute>("FTVMCompute", GatherCompute)
3574 .set_attr<TOpPattern>("TOpPattern", kInjective);
3575
3576TVM_REGISTER_NODE_TYPE(GatherNDAttrs);
3577
3578// gather_nd operator
3579bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
3580 const TypeReporter& reporter) {
3581 // `types` contains: [data, indices, result]
3582 ICHECK_EQ(types.size(), 3);
3583 const auto* data = types[0].as<TensorTypeNode>();
3584 const auto* indices = types[1].as<TensorTypeNode>();
3585 if (data == nullptr) {
3586 ICHECK(types[0].as<IncompleteTypeNode>())
3587 << "GatherND: expect input data type to be TensorType but get " << types[0];
3588 return false;
3589 }
3590 if (indices == nullptr) {
3591 ICHECK(types[1].as<IncompleteTypeNode>())
3592 << "GatherND: expect indices type to be TensorType but get " << types[1];
3593 return false;
3594 }
3595 const size_t ndim = data->shape.size();
3596 const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
3597 ICHECK(mdim) << "GatherND needs a static shape for the first axis of indices, got "
3598 << indices->shape;
3599 const size_t kdim = indices->shape.size() - 1;
  ICHECK(size_t(mdim->value) <= ndim)
      << "GatherND: indices.shape[0] must not exceed the rank of data.";
3601
3602 const auto param = attrs.as<GatherNDAttrs>();
3603 ICHECK(param != nullptr);
3604
3605 for (int i = 0; i < param->batch_dims->value; ++i) {
3606 ICHECK(reporter->AssertEQ(
3607 data->shape[i], indices->shape[i + 1])); // +1 since the first axis is the index tuple
3608 }
3609
3610 Array<IndexExpr> oshape;
3611 for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]);
3612 for (size_t i = mdim->value + param->batch_dims->value; i < ndim; ++i)
3613 oshape.push_back(data->shape[i]);
3614 reporter->Assign(types[2], TensorType(oshape, data->dtype));
3615 return true;
3616}
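// Illustrative shape examples for the relation above:
//  - data Tensor[(2, 3, 4)], indices Tensor[(2, 5)], batch_dims = 0: the index-tuple
//    length M is 2, so the output is Tensor[(5, 4)].
//  - data Tensor[(2, 3, 4)], indices Tensor[(1, 2, 3)], batch_dims = 1: M is 1, the
//    batch dimension (2) must match between data and indices, and the output is
//    Tensor[(2, 3, 4)].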
3617
3618Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
3619 const Type& out_type) {
3620 const auto* param = attrs.as<GatherNDAttrs>();
3621 ICHECK(param);
3622 return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims.IntValue())};
3623}
3624
3625Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0,
3626 Optional<Integer> index_rank = NullValue<Integer>()) {
3627 static const Op& op = Op::Get("gather_nd");
3628 auto attrs = make_object<GatherNDAttrs>();
3629 attrs->batch_dims = batch_dims;
3630 attrs->index_rank = index_rank;
3631 return Call(op, {data, indices}, Attrs(attrs));
3632}
3633
3634TVM_REGISTER_GLOBAL("relay.op._make.gather_nd").set_body_typed(MakeGatherND);
3635
3636RELAY_REGISTER_OP("gather_nd")
3637 .describe(R"code(Gather elements or slices from data and store to
3638 a tensor whose shape is defined by indices.
3639
3640Optionally, batch_dims, the number of batch dimensions, can be given, whose
3641default value is 0.
3642
Let B denote batch_dims, and let the shapes of data and indices be
(X_0, X_1, ..., X_{N-1}) and (M, Y_0, ..., Y_{K-1}), respectively.
3645
3646When B > 0, indexing will start from the B-th axis, and it must be the case that
3647X_0, ... X_{B-1} == Y_0, ... Y_{B-1}. The output will have a shape
3648(X_0, ..., X_{B-1}, Y_B, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}), where M + B <= N.
3649
3650When B == 0 (the default case), the output shape will be (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}).
3651
3652In both cases, if M + B == N, the output shape will simply be (Y_0, ..., Y_{K-1}).
3653)code" TVM_ADD_FILELINE)
3654 .set_num_inputs(2)
3655 .set_attrs_type<GatherNDAttrs>()
3656 .add_argument("data", "Tensor", "The input tensor.")
3657 .add_argument("indices", "Tensor", "The indices of values to gather.")
3658 .set_support_level(3)
3659 .add_type_rel("GatherND", GatherNDRel)
3660 .set_attr<FTVMCompute>("FTVMCompute", GatherNDCompute)
3661 .set_attr<TOpPattern>("TOpPattern", kInjective);
3662
3663// relay.sequence_mask
3664TVM_REGISTER_NODE_TYPE(SequenceMaskAttrs);
3665
3666bool SequenceMaskRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
3667 const TypeReporter& reporter) {
3668 // `types` contains: [data, valid_length, result]
3669 ICHECK_EQ(types.size(), 3);
3670 const auto* data = types[0].as<TensorTypeNode>();
3671 const auto* valid_length = types[1].as<TensorTypeNode>();
3672 ICHECK(data);
3673 ICHECK(valid_length);
3674 const auto param = attrs.as<SequenceMaskAttrs>();
3675 Array<IndexExpr> valid_length_shape;
3676 ICHECK(param->axis == 0 || param->axis == 1);
3677 valid_length_shape.push_back(data->shape[1 - param->axis]);
3678 reporter->Assign(types[1], TensorType(valid_length_shape, valid_length->dtype));
3679 reporter->Assign(types[2], types[0]);
3680 return true;
3681}
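// Illustrative example: for data Tensor[(max_length, batch_size, feature)] and axis = 0,
// the relation unifies valid_length with Tensor[(batch_size,)] (the dimension opposite
// the length axis) and gives the output the same type as data.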
3682
3683Array<te::Tensor> SequenceMaskCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
3684 const Type& out_type) {
3685 const auto* param = attrs.as<SequenceMaskAttrs>();
3686 ICHECK(param != nullptr);
3687 return Array<te::Tensor>{
3688 topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis)};
3689}
3690
3691Expr MakeSequenceMask(Expr data, Expr valid_length, double mask_value, int axis) {
3692 auto attrs = make_object<SequenceMaskAttrs>();
3693 attrs->mask_value = std::move(mask_value);
3694 attrs->axis = std::move(axis);
3695 static const Op& op = Op::Get("sequence_mask");
3696 return Call(op, {data, valid_length}, Attrs(attrs), {});
3697}
3698
3699TVM_REGISTER_GLOBAL("relay.op._make.sequence_mask").set_body_typed(MakeSequenceMask);
3700
3701RELAY_REGISTER_OP("sequence_mask")
3702 .describe(
3703 R"code(Sets all elements outside the expected length of the sequence to a constant value.
3704
3705This function takes an n-dimensional input array of the form [MAX_LENGTH, batch_size, ...] or
3706[batch_size, MAX_LENGTH, ...] and returns an array of the same shape.
3707
3708`axis` means the axis of the length dimension and can only be 0 or 1. If axis is 0,
3709the data must have shape [MAX_LENGTH, batch_size, ...]. Otherwise (axis=1), the data must have
3710shape [batch_size, MAX_LENGTH, ...].
3711
`valid_length` gives the length of each sequence. `valid_length` should be
a 1-D int array with positive values and dimension [batch_size,].
3714
3715Examples::
3716
3717 x = [[[ 1., 2., 3.],
3718 [ 4., 5., 6.]],
3719
3720 [[ 7., 8., 9.],
3721 [ 10., 11., 12.]],
3722
3723 [[ 13., 14., 15.],
3724 [ 16., 17., 18.]]]
3725
3726 // valid_length [1, 1] means only the first block of each batch will be kept
3727 // and other blocks are masked with default mask value = 0
3728 sequence_mask(x, valid_length=[1, 1]) =
3729 [[[ 1., 2., 3.],
3730 [ 4., 5., 6.]],
3731
3732 [[ 0., 0., 0.],
3733 [ 0., 0., 0.]],
3734
3735 [[ 0., 0., 0.],
3736 [ 0., 0., 0.]]]
3737
3738 // valid_length [2, 3] means the first 2 blocks of the 1st batch will be kept
3739 // and the first 3 blocks of the 2nd batch will be kept
3740 // the masked values are set to be the specified mask value = 0.1
3741 sequence_mask(x, valid_length=[2, 3], mask_value=0.1) =
3742 [[[ 1., 2., 3.],
3743 [ 4., 5., 6.]],
3744
3745 [[ 7., 8., 9.],
3746 [ 10., 11., 12.]],
3747
3748 [[ 0.1, 0.1, 0.1],
3749 [ 16., 17., 18.]]]
3750)code" TVM_ADD_FILELINE)
3751 .set_attrs_type<SequenceMaskAttrs>()
3752 .set_num_inputs(2)
3753 .add_argument("data", "Tensor", "The input tensor.")
3754 .add_argument("valid_length", "Tensor", "The real (valid) length of each sequence.")
3755 .set_support_level(10)
3756 .add_type_rel("SequenceMask", SequenceMaskRel)
3757 .set_attr<FTVMCompute>("FTVMCompute", SequenceMaskCompute)
3758 .set_attr<TOpPattern>("TOpPattern", kInjective);
3759
3760// relay.one_hot
3761TVM_REGISTER_NODE_TYPE(OneHotAttrs);
3762
3763bool OneHotRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
3764 const TypeReporter& reporter) {
3765 // `types` contains: [indices, on_value, off_value, result]
3766 ICHECK_EQ(types.size(), 4);
3767 const auto* indices = types[0].as<TensorTypeNode>();
3768 ICHECK(indices);
3769
3770 const auto param = attrs.as<OneHotAttrs>();
3771 ICHECK_GT(param->depth, 0);
3772
3773 Array<IndexExpr> oshape;
3774 int ndim = indices->shape.size() + 1;
3775 int indices_index = 0;
3776 int true_axis = (param->axis == -1) ? indices->shape.size() : param->axis;
3777 for (int i = 0; i < ndim; i++) {
3778 if (i == true_axis) {
3779 oshape.push_back(Integer(param->depth));
3780 } else {
3781 oshape.push_back(indices->shape[indices_index++]);
3782 }
3783 }
3784
3785 reporter->Assign(types[3], TensorType(oshape, param->dtype));
3786 return true;
3787}
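// Illustrative example: indices Tensor[(3,), int32] with depth = 4 and axis = -1 yields
// Tensor[(3, 4)]; with axis = 0 the depth dimension is placed first, yielding
// Tensor[(4, 3)]. The output dtype is taken from the attribute, not from
// on_value/off_value.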
3788
3789Array<te::Tensor> OneHotCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
3790 const Type& out_type) {
3791 const auto* param = attrs.as<OneHotAttrs>();
3792 ICHECK(param != nullptr);
3793 return Array<te::Tensor>{
3794 topi::one_hot(inputs[0], inputs[1](), inputs[2](), param->depth, param->axis, param->dtype)};
3795}
3796
3797Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype) {
3798 auto attrs = make_object<OneHotAttrs>();
3799 attrs->depth = std::move(depth);
3800 attrs->axis = axis;
3801 attrs->dtype = dtype;
3802 static const Op& op = Op::Get("one_hot");
3803 return Call(op, {indices, on_value, off_value}, Attrs(attrs), {});
3804}
3805
3806TVM_REGISTER_GLOBAL("relay.op._make.one_hot").set_body_typed(MakeOneHot);
3807
3808RELAY_REGISTER_OP("one_hot")
    .describe(R"code(Returns a one-hot tensor where the locations represented by indices take value
    on_value, while all other locations take value off_value. Final shape is <indices shape> x depth.
3811
3812 **indices** Locations to set to 1.
3813
3814 **on_value** Value to fill at indices.
3815
3816 **off_value** Value to fill at all other positions besides indices.
3817
3818 **depth** Depth of the one-hot dimension.
3819
3820 **axis** Axis to fill.
3821
3822 **dtype**)code" TVM_ADD_FILELINE)
3823 .set_attrs_type<OneHotAttrs>()
3824 .set_num_inputs(3)
3825 .add_argument("indices", "Tensor", "Locations to set to on_value.")
3826 .add_argument("on_value", "Expr", "Value to fill at indices.")
3827 .add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.")
3828 .set_support_level(10)
3829 .add_type_rel("OneHot", OneHotRel)
3830 .set_attr<FTVMCompute>("FTVMCompute", OneHotCompute)
3831 .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
3832
3833/* relay.unravel_index */
3834bool UnRavelIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
3835 const TypeReporter& reporter) {
3836 ICHECK_EQ(types.size(), 3);
3837
3838 const auto* indices = types[0].as<TensorTypeNode>();
3839 if (indices == nullptr) {
3840 ICHECK(types[0].as<IncompleteTypeNode>())
3841 << "unravel_index: expect input type to be TensorType but get " << types[0];
3842 return false;
3843 }
3844 ICHECK(indices->dtype.is_int() || indices->dtype.is_uint())
3845 << "indices of unravel_index must be tensor of integer";
3846
3847 const auto* shape = types[1].as<TensorTypeNode>();
3848 if (shape == nullptr) {
3849 ICHECK(types[1].as<IncompleteTypeNode>())
3850 << "unravel_index: expect input type to be TensorType but get " << types[1];
3851 return false;
3852 }
3853 ICHECK(shape->dtype.is_int() || shape->dtype.is_uint())
3854 << "shape of unravel_index must be tensor of integer";
3855
3856 Array<IndexExpr> indices_shape;
3857 Array<IndexExpr> shape_shape;
3858 indices_shape = indices->shape;
3859 shape_shape = shape->shape;
3860
3861 Array<IndexExpr> oshape;
3862 oshape.push_back(shape_shape[0]);
3863 if (indices_shape.size() != 0) {
3864 oshape.push_back(indices_shape[0]);
3865 }
3866 reporter->Assign(types[2], TensorType(oshape, indices->dtype));
3867 return true;
3868}
3869
3870Array<te::Tensor> UnRavelIndexCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
3871 const Type& out_type) {
3872 return Array<te::Tensor>{topi::unravel_index(inputs[0], inputs[1])};
3873}
3874
3875Expr MakeUnRavelIndex(Expr data, Expr shape) {
3876 static const Op& op = Op::Get("unravel_index");
3877 return Call(op, {data, shape}, Attrs(), {});
3878}
3879
3880TVM_REGISTER_GLOBAL("relay.op._make.unravel_index").set_body_typed(MakeUnRavelIndex);
3881
3882RELAY_REGISTER_OP("unravel_index")
3883 .describe(
3884 R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays.
3885
3886Example::
3887 - unravel_index([22, 41, 37], (7, 6)) = [[3, 6, 6], [4, 5, 1]]
3888)code" TVM_ADD_FILELINE)
3889 .set_num_inputs(2)
3890 .add_argument("data", "Tensor", "The input tensor.")
3891 .add_argument("shape", "Tensor", "The shape tensor.")
3892 .set_support_level(3)
3893 .add_type_rel("UnRavelIndexRel", UnRavelIndexRel)
3894 .set_attr<FTVMCompute>("FTVMCompute", UnRavelIndexCompute)
3895 .set_attr<TOpPattern>("TOpPattern", kInjective);
3896
3897// sparse_to_dense
3898TVM_REGISTER_NODE_TYPE(SparseToDenseAttrs);
3899
3900bool SparseToDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
3901 const TypeReporter& reporter) {
3902 ICHECK_EQ(num_inputs, 3);
3903 auto sparse_indices = types[0].as<TensorTypeNode>();
3904 auto sparse_values = types[1].as<TensorTypeNode>();
3905 auto default_value = types[2].as<TensorTypeNode>();
3906
3907 if (sparse_indices == nullptr || sparse_values == nullptr || default_value == nullptr) {
3908 return false;
3909 }
3910
3911 ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices must be tensor of integers";
3912
  ICHECK_LE(sparse_indices->shape.size(), 3)
      << "sparse_indices must be a 0-D, 1-D or 2-D tensor";

  ICHECK_LE(sparse_values->shape.size(), 2) << "sparse_values must be a 0-D or 1-D tensor";
3917
3918 ICHECK_EQ(default_value->shape.size(), 0) << "default_value should be a scalar";
3919
3920 const auto* param = attrs.as<SparseToDenseAttrs>();
3921 ICHECK(param != nullptr);
3922
3923 Array<IndexExpr> oshape;
3924 for (auto i : param->output_shape) {
3925 oshape.push_back(i);
3926 }
3927 reporter->Assign(types[3], TensorType(oshape, sparse_values->dtype));
3928 return true;
3929}
3930
3931Array<te::Tensor> SparseToDenseCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
3932 const Type& out_type) {
3933 ICHECK_EQ(inputs.size(), 3);
3934 const auto* param = attrs.as<SparseToDenseAttrs>();
3935 ICHECK(param != nullptr);
3936 Array<IndexExpr> output_shape;
3937 for (auto val : param->output_shape) {
3938 output_shape.push_back(val);
3939 }
3940 return {topi::sparse_to_dense(inputs[0], output_shape, inputs[1], inputs[2]())};
3941}
3942
3943Expr MakeSparseToDense(Expr indices, Array<Integer> output_shape, Expr values, Expr default_value) {
3944 auto attrs = make_object<SparseToDenseAttrs>();
3945 attrs->output_shape = std::move(output_shape);
3946 static const Op& op = Op::Get("sparse_to_dense");
3947 return Call(op, {indices, values, default_value}, Attrs(attrs));
3948}
3949
3950TVM_REGISTER_GLOBAL("relay.op._make.sparse_to_dense").set_body_typed(MakeSparseToDense);
3951
3952RELAY_REGISTER_OP("sparse_to_dense")
    .describe(R"code(Builds a dense tensor from a sparse representation.
3954
3955 - **sparse_indices**: A 0-D, 1-D, or 2-D tensor of integers containing location of sparse values
3956
3957 - **output_shape**: A list of integers. Shape of the dense output tensor.
3958
3959 - **sparse_values**: A 0-D or 1-D tensor containing the sparse values for the sparse indices.
3960
3961 - **default_value**: A 0-D tensor containing the default value for the remaining locations. Defaults to 0.
3962
3963 Example::
    - sparse_to_dense([[0, 0], [1, 2]], [3, 4], [1, 2], 0) = [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]]
3965
3966 )code" TVM_ADD_FILELINE)
3967 .set_num_inputs(3)
3968 .set_support_level(3)
3969 .set_attrs_type<SparseToDenseAttrs>()
3970 .add_argument("sparse_indices", "Tensor", "Contains sparse indices.")
3971 .add_argument("sparse_values", "Tensor", "Contains values for sparse indices.")
3972 .add_argument("default_value", "Tensor", "Value to set for non-sparse indices. Defaults to 0.")
3973 .add_type_rel("SparseToDense", SparseToDenseRel)
3974 .set_attr<TOpIsStateful>("TOpIsStateful", false)
3975 .set_attr<TOpPattern>("TOpPattern", kOpaque)
3976 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
3977 .set_attr<FTVMCompute>("FTVMCompute", SparseToDenseCompute);
3978
3979// relay.matrix_set_diag
3980TVM_REGISTER_NODE_TYPE(MatrixSetDiagAttrs);
3981
3982bool MatrixSetDiagRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
3983 const TypeReporter& reporter) {
3984 // `types` contains: [input, diagonal, result]
3985 ICHECK_EQ(types.size(), 3);
3986
3987 const auto* input = types[0].as<TensorTypeNode>();
3988 ICHECK(input);
3989
3990 const auto* diagonal = types[1].as<TensorTypeNode>();
3991 ICHECK(diagonal);
3992
3993 const auto param = attrs.as<MatrixSetDiagAttrs>();
3994 ICHECK_GE(param->k2, param->k1);
3995
3996 int d_ndims = diagonal->shape.size();
3997 int i_ndims = input->shape.size();
3998
3999 reporter->Assert(input->shape[i_ndims - 2] > -param->k1);
4000 reporter->Assert(input->shape[i_ndims - 1] > param->k2);
4001
4002 for (int i = 0; i < d_ndims - 2; i++) {
4003 reporter->AssertEQ(input->shape[i], diagonal->shape[i]);
4004 }
4005 if (param->k1 != param->k2) {
4006 reporter->AssertEQ(diagonal->shape[d_ndims - 2], param->k2 - param->k1 + 1);
4007 } else if (d_ndims >= 2) {
4008 reporter->AssertEQ(input->shape[d_ndims - 2], diagonal->shape[d_ndims - 2]);
4009 }
4010 auto max_diag_len = if_then_else(input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0) <=
4011 input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0),
4012 input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0),
4013 input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0));
4014 reporter->AssertEQ(diagonal->shape[d_ndims - 1], max_diag_len);
4015
4016 reporter->Assign(types[2], TensorType(input->shape, input->dtype));
4017 return true;
4018}
4019
4020Array<te::Tensor> MatrixSetDiagCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
4021 const Type& out_type) {
4022 const auto* param = attrs.as<MatrixSetDiagAttrs>();
4023 ICHECK(param != nullptr);
4024 return Array<te::Tensor>{topi::matrix_set_diag(inputs[0], inputs[1], param->k1, param->k2,
4025 param->super_diag_right_align,
4026 param->sub_diag_right_align)};
4027}
4028
4029Expr MakeMatrixSetDiag(Expr input, Expr diagonal, int k1, int k2, bool super_diag_right_align,
4030 bool sub_diag_right_align) {
4031 auto attrs = make_object<MatrixSetDiagAttrs>();
4032 attrs->k1 = k1;
4033 attrs->k2 = k2;
4034 attrs->super_diag_right_align = super_diag_right_align;
4035 attrs->sub_diag_right_align = sub_diag_right_align;
4036 static const Op& op = Op::Get("matrix_set_diag");
4037 return Call(op, {input, diagonal}, Attrs(attrs), {});
4038}
4039
4040TVM_REGISTER_GLOBAL("relay.op._make.matrix_set_diag").set_body_typed(MakeMatrixSetDiag);
4041
4042RELAY_REGISTER_OP("matrix_set_diag")
4043 .describe(
4044 R"code(Returns a tensor with the diagonals of input tensor replaced with the provided diagonal values.
4045 **input** Input tensor.
4046 **diagonal** Values to be filled in the diagonal.
4047 **k1** Lower limit (included) of the range of diagonals.
4048 **k2** Upper limit (included) of the range of diagonals.
4049 **super_diag_right_align** Bool, true iff super-diagonal is right aligned (left-padded).
4050 **sub_diag_right_align** Bool, true iff sub-diagonal is right aligned (left-padded).
4051 )code" TVM_ADD_FILELINE)
4052 .set_attrs_type<MatrixSetDiagAttrs>()
4053 .set_num_inputs(2)
4054 .add_argument("input", "Tensor", "Input Tensor.")
4055 .add_argument("diagonal", "Tensor", "Values to be filled in the diagonal.")
4056 .set_support_level(10)
4057 .add_type_rel("MatrixSetDiag", MatrixSetDiagRel)
4058 .set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
4059 .set_attr<TOpPattern>("TOpPattern", kInjective);
4060
4061// adv_index
4062bool AdvIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
4063 const TypeReporter& reporter) {
4064 ICHECK_EQ(num_inputs, 1);
  auto inputs = types[0].as<TupleTypeNode>();
  // Check the tuple type before touching its fields to avoid dereferencing a null pointer.
  if (inputs == nullptr) {
    return false;
  }
  auto data = inputs->fields[0].as<TensorTypeNode>();
  if (data == nullptr) {
    return false;
  }
4071 ICHECK_LE(inputs->fields.size() - 1, data->shape.size()) << "too many indices for data!";
4072
4073 Array<IndexExpr> oshape;
4074 TensorType broadcast_type = Downcast<TensorType>(inputs->fields[1]);
4075 for (size_t i = 2; i < inputs->fields.size(); ++i) {
4076 broadcast_type =
4077 ConcreteBroadcast(broadcast_type, Downcast<TensorType>(inputs->fields[i]), data->dtype);
4078 }
4079
4080 for (const auto& dim : broadcast_type->shape) {
4081 oshape.push_back(dim);
4082 }
4083 for (size_t i = inputs->fields.size() - 1; i < data->shape.size(); ++i) {
4084 oshape.push_back(data->shape[i]);
4085 }
4086 reporter->Assign(types[1], TensorType(oshape, data->dtype));
4087 return true;
4088}
4089
4090Array<te::Tensor> AdvIndexCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
4091 const Type& out_type) {
4092 Array<te::Tensor> indices;
4093 for (size_t i = 1; i < inputs.size(); ++i) {
4094 indices.push_back(inputs[i]);
4095 }
4096 return {topi::adv_index(inputs[0], indices)};
4097}
4098
4099Expr MakeAdvIndex(Expr inputs) {
4100 static const Op& op = Op::Get("adv_index");
4101 return Call(op, {inputs}, Attrs(), {});
4102}
4103
4104TVM_REGISTER_GLOBAL("relay.op._make.adv_index").set_body_typed(MakeAdvIndex);
4105
4106RELAY_REGISTER_OP("adv_index")
4107 .describe(R"code(Numpy style advanced indexing. Index with a list of tensors.
4108 )code" TVM_ADD_FILELINE)
4109 .set_num_inputs(1)
4110 .set_support_level(3)
4111 .add_argument("inputs", "Tuple of Tensors", "Input tensor and indices.")
4112 .add_type_rel("AdvIndex", AdvIndexRel)
4113 .set_attr<TOpIsStateful>("TOpIsStateful", false)
4114 .set_attr<TOpPattern>("TOpPattern", kInjective)
4115 .set_attr<FTVMCompute>("FTVMCompute", AdvIndexCompute);
4116
4117TVM_REGISTER_NODE_TYPE(ScanopAttrs);
4118
4119bool ScanopRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
4120 const TypeReporter& reporter) {
4121 // types: [data, output]
4122 ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output";
4123 const auto* data = types[0].as<TensorTypeNode>();
4124 if (data == nullptr) {
4125 ICHECK(types[0].as<IncompleteTypeNode>())
4126 << "Scanop: expect input type to be TensorType but get " << types[0];
4127 return false;
4128 }
4129
4130 const auto* param = attrs.as<ScanopAttrs>();
4131
4132 auto dtype = param->dtype;
4133 if (dtype.is_void()) {
4134 dtype = data->dtype;
4135 }
4136
4137 if (param->axis.defined()) {
4138 reporter->Assign(types[1], TensorType(data->shape, dtype));
4139 } else {
4140 auto prod = data->shape[0];
4141 for (size_t i = 1; i < data->shape.size(); ++i) {
4142 prod = prod * data->shape[i];
4143 }
4144 reporter->Assign(types[1], TensorType({prod}, dtype));
4145 }
4146
4147 return true;
4148}
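// Illustrative example: for data Tensor[(2, 3), int32], cumsum(data, axis=1) keeps the
// shape (2, 3), while cumsum(data) with no axis flattens first and produces
// Tensor[(6,), int32]; an explicit dtype attribute overrides the input dtype.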
4149
4150Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Bool exclusive) {
4151 auto attrs = make_object<ScanopAttrs>();
4152 attrs->dtype = dtype;
4153 attrs->axis = axis;
4154 attrs->exclusive = exclusive;
4155 static const Op& op = Op::Get("cumsum");
4156 return Call(op, {data}, Attrs(attrs), {});
4157}
4158
4159TVM_REGISTER_GLOBAL("relay.op._make.cumsum").set_body_typed(MakeCumsum);
4160
4161RELAY_REGISTER_OP("cumsum")
4162 .describe(
4163 R"doc(Return the cumulative sum of the elements along a given axis.)doc" TVM_ADD_FILELINE)
4164 .set_num_inputs(1)
4165 .add_argument("data", "Tensor", "The input tensor.")
4166 .set_support_level(3)
4167 .add_type_rel("Cumsum", ScanopRel)
4168 .set_attr<TOpPattern>("TOpPattern", kOpaque);
4169
4170Expr MakeCumprod(Expr data, Integer axis, DataType dtype, Bool exclusive) {
4171 auto attrs = make_object<ScanopAttrs>();
4172 attrs->dtype = dtype;
4173 attrs->axis = axis;
4174 attrs->exclusive = exclusive;
4175 static const Op& op = Op::Get("cumprod");
4176 return Call(op, {data}, Attrs(attrs), {});
4177}
4178
4179TVM_REGISTER_GLOBAL("relay.op._make.cumprod").set_body_typed(MakeCumprod);
4180
4181RELAY_REGISTER_OP("cumprod")
4182 .describe(
4183 R"doc(Return the cumulative product of the elements along a given axis.)doc" TVM_ADD_FILELINE)
4184 .set_num_inputs(1)
4185 .add_argument("data", "Tensor", "The input tensor.")
4186 .set_support_level(3)
4187 .add_type_rel("Cumprod", ScanopRel)
4188 .set_attr<TOpPattern>("TOpPattern", kOpaque);
4189
4190TVM_REGISTER_NODE_TYPE(UniqueAttrs);
4191
4192bool UniqueRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
4193 const TypeReporter& reporter) {
4194 // types: [data, result]
4195 ICHECK_EQ(types.size(), 2) << "Unique: expect 2 types but " << types.size() << " provided";
  ICHECK_EQ(num_inputs, 1) << "Unique: expect 1 input but " << num_inputs << " provided";
4197 auto data = types[0].as<TensorTypeNode>();
4198 if (data == nullptr) {
4199 ICHECK(types[0].as<IncompleteTypeNode>())
4200 << "Unique: expect input type to be TensorType but get " << types[0];
4201 return false;
4202 }
4203 const int ndim = static_cast<int>(data->shape.size());
4204 ICHECK_EQ(ndim, 1) << "Unique: input must be 1-D tensor";
4205
4206 std::vector<Type> fields;
4207 fields.push_back(TensorType(data->shape, data->dtype)); // unique
4208 fields.push_back(TensorType(data->shape, DataType::Int(32))); // indices
4209 fields.push_back(TensorType(data->shape, DataType::Int(32))); // inverse_indices
4210 fields.push_back(TensorType(Array<PrimExpr>{1}, DataType::Int(32))); // num_unique
4211 const auto* param = attrs.as<UniqueAttrs>();
4212 if (param->return_counts) {
4213 fields.push_back(TensorType(data->shape, DataType::Int(32))); // counts
4214 }
4215 reporter->Assign(types[1], TupleType(Array<Type>(fields)));
4216 return true;
4217}
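// Illustrative example: for a 1-D input Tensor[(5,), float32] the relation yields the
// tuple (unique Tensor[(5,), float32], indices Tensor[(5,), int32], inverse_indices
// Tensor[(5,), int32], num_unique Tensor[(1,), int32]) plus, when return_counts is set,
// counts Tensor[(5,), int32]. unique, indices and counts are sized to the input length
// because the true number of unique values is only known at runtime.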
4218
4219Expr MakeUnique(Expr data, bool sorted, bool return_counts) {
4220 auto attrs = make_object<UniqueAttrs>();
4221 attrs->sorted = sorted;
4222 attrs->return_counts = return_counts;
4223 static const Op& op = Op::Get("unique");
4224 return Call(op, {data}, Attrs(attrs), {});
4225}
4226
4227TVM_REGISTER_GLOBAL("relay.op._make.unique").set_body_typed(MakeUnique);
4228
4229RELAY_REGISTER_OP("unique")
4230 .describe(
4231 R"code(This operation returns the unique elements and the new index of each item in a given 1-D array.
4232 )code" TVM_ADD_FILELINE)
4233 .set_num_inputs(1)
4234 .add_argument("data", "Tensor", "The input tensor")
4235 .add_type_rel("unique", UniqueRel)
4236 .set_support_level(3)
4237 .set_attr<TOpPattern>("TOpPattern", kOpaque);
4238
4239// invert_permutation
4240Expr MakeInvertPermutation(Expr data) {
4241 static const Op& op = Op::Get("invert_permutation");
4242 return Call(op, {data}, Attrs(), {});
4243}
4244
4245TVM_REGISTER_GLOBAL("relay.op._make.invert_permutation").set_body_typed(MakeInvertPermutation);
4246
4247RELAY_REGISTER_OP("invert_permutation")
4248 .describe(R"doc(Computes the inverse permutation of a tensor.)doc" TVM_ADD_FILELINE)
4249 .set_num_inputs(1)
4250 .add_argument("data", "Tensor", "The input tensor.")
4251 .add_type_rel("Identity", IdentityRel)
4252 .set_support_level(1)
4253 .set_attr<TOpPattern>("TOpPattern", kInjective)
4254 .set_attr<TOpIsStateful>("TOpIsStateful", false);
4255
4256// Trilu
4257
4258TVM_REGISTER_NODE_TYPE(TriluAttrs);
4259
4260bool TriluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
4261 const TypeReporter& reporter) {
4262 // types: [data, k, result]
4263 ICHECK_EQ(types.size(), 3) << "Trilu: expect 3 types but " << types.size() << " provided";
4264 ICHECK_EQ(num_inputs, 2) << "Trilu: expect 2 inputs but " << num_inputs << " provided";
4265 auto data = types[0].as<TensorTypeNode>();
4266 if (data == nullptr) {
4267 ICHECK(types[0].as<IncompleteTypeNode>())
4268 << "Trilu: expect input type to be TensorType but get " << types[0];
4269 return false;
4270 }
4271
4272 auto k = types[1].as<TensorTypeNode>();
4273 if (k == nullptr) {
4274 ICHECK(types[1].as<IncompleteTypeNode>())
4275 << "Trilu: expect k type to be TensorType but get " << types[1];
4276 return false;
4277 }
4278
4279 ICHECK(k->shape.size() == 0) << "Trilu: k must be a 0-D tensor but get " << k;
4280
4281 // Output shape is the same as input shape.
4282 reporter->Assign(types[2], TensorType(data->shape, data->dtype));
4283 return true;
4284}
4285
4286Expr MakeTrilu(Expr data, Expr k, bool upper) {
4287 auto attrs = make_object<TriluAttrs>();
4288 attrs->upper = upper;
4289 static const Op& op = Op::Get("trilu");
4290 return Call(op, {data, k}, Attrs(attrs), {});
4291}
4292
4293TVM_REGISTER_GLOBAL("relay.op._make.trilu").set_body_typed(MakeTrilu);
4294
4295RELAY_REGISTER_OP("trilu")
4296 .describe(
4297 R"code(Filters out the upper or lower portion of an input tensor on one side of a diagonal.
4298 )code" TVM_ADD_FILELINE)
4299 .set_num_inputs(2)
4300 .add_argument("data", "Tensor", "The input tensor")
4301 .add_argument("k", "Tensor", "The number of diagonals above or below the main to exclude.")
4302 .add_type_rel("trilu", TriluRel)
4303 .set_support_level(3)
4304 .set_attr<TOpPattern>("TOpPattern", kElemWise);
4305
4306// FixedPointMultiplyPerAxis
4307
4308TVM_REGISTER_NODE_TYPE(FixedPointMultiplyPerAxisAttrs);
4309
4310bool FixedPointMultiplyPerAxisRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
4311 const TypeReporter& reporter) {
4312 ICHECK_EQ(types.size(), 5) << "FixedPointMultiplyPerAxis: expect 5 types but " << types.size()
4313 << " provided";
4314 ICHECK_EQ(num_inputs, 4) << "FixedPointMultiplyPerAxis: expect 4 inputs but " << num_inputs
4315 << " provided";
4316
4317 for (int i = 0; i < num_inputs; i++) {
4318 auto data = types[i].as<TensorTypeNode>();
4319 if (data == nullptr) {
4320 ICHECK(types[i].as<IncompleteTypeNode>())
4321 << "FixedPointMultiplyPerAxis: expect input type to be TensorType but get " << types[i];
4322 return false;
4323 }
4324 }
4325
4326 return IdentityRel({types[0], types[4]}, 1, attrs, reporter);
4327}
4328
4329InferCorrectLayoutOutput FixedPointMultiplyPerAxisInferCorrectLayout(
4330 const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
4331 const Array<tvm::relay::Type>& old_in_types) {
4332 const auto* attrs_ptr = attrs.as<FixedPointMultiplyPerAxisAttrs>();
4333 ICHECK(attrs_ptr);
4334 ObjectPtr<FixedPointMultiplyPerAxisAttrs> param =
4335 make_object<FixedPointMultiplyPerAxisAttrs>(*attrs_ptr);
4336
4337 Array<Array<IndexExpr>> old_in_shapes;
4338 for (auto old_in_t : old_in_types) {
4339 ICHECK(old_in_t.as<TensorTypeNode>());
4340 old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
4341 }
4342
4343 Array<Layout> input_layouts, output_layouts;
4344
4345 if (new_in_layouts.defined()) {
4346 const Layout& new_layout = new_in_layouts[0];
4347 const Layout& old_layout = old_in_layouts[0];
4348
4349 std::unordered_set<std::string> old_dims;
4350 for (auto axis : param->axes) {
4351 ICHECK_GE(axis->value, 0) << "Axis out of bounds in FixedPointMultiplyPerAxis operator.";
4352 ICHECK_LT(axis->value, old_in_shapes[0].size())
4353 << "Axis out of bounds in FixedPointMultiplyPerAxis operator.";
4354 old_dims.emplace(old_layout[axis->value].name());
4355 }
4356
4357 Array<tvm::Integer> new_axes;
4358 std::string new_layout_string = "";
4359 for (size_t axis_index = 0; axis_index < new_layout->axes.size(); ++axis_index) {
4360 const auto& layout_axis = LayoutAxis::Get(new_layout->axes[axis_index]);
4361 const std::string& layout_dim = layout_axis.name();
4362 if (layout_axis.IsPrimal()) {
4363 if (old_dims.count(layout_dim)) {
4364 new_axes.push_back(tvm::Integer(axis_index));
4365 new_layout_string += layout_dim;
4366 }
4367 } else {
4368 auto primal_dim = layout_axis.ToPrimal().name();
4369 if (old_dims.count(primal_dim)) {
4370 new_axes.push_back(tvm::Integer(axis_index));
4371 new_layout_string += std::to_string(new_layout.FactorOf(layout_axis)) + layout_dim;
4372 }
4373 }
4374 }
4375
4376 Layout channel_layout = Layout(new_layout_string);
4377
4378 input_layouts = {new_layout, channel_layout, channel_layout, channel_layout};
4379 output_layouts = {new_layout};
4380 param->axes = std::move(new_axes);
4381 } else if (old_in_layouts.defined()) {
4382 ICHECK_EQ(old_in_layouts.size(), 4);
4383 ICHECK_EQ(param->axes.size(), 1); // Not tested other cases
4384 const Layout& old_layout = old_in_layouts[0];
4385 if (old_layout.defined()) {
4386 std::string layout_string = old_layout[param->axes[0]->value].name();
4387 Layout channel_layout = Layout(layout_string);
4388
4389 input_layouts = {old_layout, channel_layout, channel_layout, channel_layout};
4390 output_layouts = {old_layout};
4391 } else {
4392 // Set the layouts to undef.
4393 Layout undef = Layout::Undef();
4394 input_layouts = Array<Layout>(4, undef);
4395 output_layouts = {undef};
4396 }
4397 } else {
4398 // Set the layouts to undef.
4399 Layout undef = Layout::Undef();
4400 input_layouts = Array<Layout>(4, undef);
4401 output_layouts = {undef};
4402 }
4403
4404 return InferCorrectLayoutOutput(input_layouts, output_layouts, Attrs(param));
4405}
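// Illustrative example (a sketch, not covering every case): with old data layout NCHW,
// axes = [1], and new data layout NCHW4c, the per-channel operands (the multiplier and
// the two shift tensors) receive the layout "C4c" and the axes attribute is rewritten to
// [1, 4], covering both the primal channel axis and its 4-wide packed sub-axis.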
4406
4407Expr MakeFixedPointMultiplyPerAxis(Expr x, Expr m, Expr lshift, Expr rshift,
4408 bool is_lshift_required, bool is_rshift_required,
4409 Array<Integer> axes) {
4410 auto attrs = make_object<FixedPointMultiplyPerAxisAttrs>();
4411 attrs->is_lshift_required = is_lshift_required;
4412 attrs->is_rshift_required = is_rshift_required;
4413 attrs->axes = std::move(axes);
4414 static const Op& op = Op::Get("fixed_point_multiply_per_axis");
4415 return Call(op, {x, m, lshift, rshift}, Attrs(attrs), {});
4416}
4417
4418TVM_REGISTER_GLOBAL("relay.op._make.fixed_point_multiply_per_axis")
4419 .set_body_typed(MakeFixedPointMultiplyPerAxis);
4420
4421RELAY_REGISTER_OP("fixed_point_multiply_per_axis")
4422 .describe(R"code(per channel fixed point multiplication)code" TVM_ADD_FILELINE)
4423 .set_num_inputs(4)
4424 .add_argument("data", "Tensor", "The input tensor.")
4425 .add_argument("fp_multiplier", "Tensor", "The multipliers tensor.")
4426 .add_argument("left_shift", "Tensor", "The left shifts tensor.")
4427 .add_argument("right_shift", "Tensor", "The right shifts tensor.")
4428 .add_type_rel("FixedPointMultiplyPerAxis", FixedPointMultiplyPerAxisRel)
4429 .set_attr<TOpPattern>("TOpPattern", kBroadcast)
4430 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
4431 FixedPointMultiplyPerAxisInferCorrectLayout)
4432 .set_attrs_type<FixedPointMultiplyPerAxisAttrs>()
4433 .set_support_level(10);
4434
4435} // namespace relay
4436} // namespace tvm
4437