1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file transform.cc |
22 | * \brief Transform operators. |
23 | */ |
24 | #include "transform.h" |
25 | |
26 | #include <tvm/relay/attrs/transform.h> |
27 | #include <tvm/relay/error.h> |
28 | #include <tvm/relay/expr.h> |
29 | #include <tvm/relay/op.h> |
30 | #include <tvm/runtime/packed_func.h> |
31 | #include <tvm/tir/data_layout.h> |
32 | #include <tvm/tir/expr.h> |
33 | #include <tvm/tir/op.h> |
34 | #include <tvm/topi/broadcast.h> |
35 | #include <tvm/topi/detail/constant_utils.h> |
36 | #include <tvm/topi/elemwise.h> |
37 | #include <tvm/topi/nn.h> |
38 | #include <tvm/topi/reduction.h> |
39 | #include <tvm/topi/transform.h> |
40 | |
41 | #include <sstream> |
42 | #include <vector> |
43 | |
44 | #include "../../transforms/infer_layout_utils.h" |
45 | #include "../../transforms/pass_utils.h" |
46 | #include "../../transforms/pattern_utils.h" |
47 | #include "../make_op.h" |
48 | #include "../op_common.h" |
49 | #include "../type_relations.h" |
50 | |
51 | namespace tvm { |
52 | namespace relay { |
53 | using tir::IntImmNode; |
54 | |
55 | TVM_REGISTER_NODE_TYPE(SlidingWindowAttrs); |
56 | |
57 | bool SlidingWindowRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
58 | const TypeReporter& reporter) { |
59 | // `types` contains: [data, result] |
60 | ICHECK_EQ(types.size(), 2); |
61 | const auto* data = types[0].as<TensorTypeNode>(); |
62 | if (data == nullptr) { |
63 | reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) |
64 | << "SlidingWindow operator expects input to be of TensorType " |
65 | << "but got " << PrettyPrint(types[0])); |
66 | return false; |
67 | } |
68 | const auto* param = attrs.as<SlidingWindowAttrs>(); |
69 | const int axis = param->axis; |
70 | |
71 | std::vector<IndexExpr> oshape; |
72 | |
73 | // Dimensions up until `axis` remain the same. |
74 | for (int i = 0; i < axis; ++i) { |
75 | oshape.emplace_back(data->shape[i]); |
76 | } |
77 | |
78 | // New dimensions which result from sliding the window in each dimension. One new dimension per |
79 | // window dimension. |
80 | for (size_t i = 0; i < param->window_shape.size(); ++i) { |
81 | // Length of the shape along this dimension. |
82 | auto dim_len = data->shape[axis + i]; |
83 | // Length of the window along this dimension. |
84 | auto window_len = param->window_shape[i]; |
85 | // Strides along this dimension. |
86 | auto stride = param->strides[i]; |
87 | |
88 | oshape.push_back(floordiv(dim_len - (window_len - 1) + stride - 1, stride)); |
89 | } |
90 | |
91 | // Dimensions comprising the window. |
92 | for (size_t i = 0; i < param->window_shape.size(); ++i) { |
93 | oshape.push_back(param->window_shape[i]); |
94 | } |
95 | |
96 | reporter->Assign(types[1], TensorType(oshape, data->dtype)); |
97 | return true; |
98 | } |
99 | |
100 | Array<te::Tensor> SlidingWindowCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
101 | const Type& out_type) { |
102 | const SlidingWindowAttrs* param = attrs.as<SlidingWindowAttrs>(); |
103 | ICHECK(param != nullptr); |
104 | return {topi::sliding_window(inputs[0], param->axis, param->window_shape, param->strides)}; |
105 | } |
106 | |
107 | Expr MakeSlidingWindow(Expr data, int axis, Array<Integer> window_shape, Array<Integer> strides) { |
108 | auto attrs = make_object<SlidingWindowAttrs>(); |
109 | attrs->axis = axis; |
110 | attrs->window_shape = window_shape; |
111 | attrs->strides = strides; |
112 | static const Op& op = Op::Get("sliding_window" ); |
113 | return Call(op, {data}, Attrs(attrs), {}); |
114 | } |
115 | |
116 | TVM_REGISTER_GLOBAL("relay.ir.sliding_window" ).set_body_typed(MakeSlidingWindow); |
117 | |
118 | RELAY_REGISTER_OP("sliding_window" ) |
    .describe(R"code(Slide a window over a tensor.)code" TVM_ADD_FILELINE)
120 | .set_num_inputs(1) |
121 | .set_attrs_type<SlidingWindowAttrs>() |
122 | .add_argument("data" , "Tensor" , "The input tensor." ) |
123 | .add_type_rel("SlidingWindow" , SlidingWindowRel) |
124 | .set_attr<TOpPattern>("TOpPattern" , kOpaque); |
125 | |
126 | // relay.cast |
127 | TVM_REGISTER_NODE_TYPE(CastAttrs); |
128 | |
129 | bool CastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
130 | const TypeReporter& reporter) { |
131 | ICHECK_EQ(types.size(), 2); |
132 | const auto* data = types[0].as<TensorTypeNode>(); |
133 | if (data == nullptr) { |
134 | ICHECK(types[0].as<IncompleteTypeNode>()) |
135 | << "cast: expect input type to be TensorType but get " << types[0]; |
136 | return false; |
137 | } |
138 | const auto* param = attrs.as<CastAttrs>(); |
139 | reporter->Assign(types[1], TensorType(data->shape, param->dtype)); |
140 | return true; |
141 | } |
142 | |
143 | Array<te::Tensor> CastCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
144 | const Type& out_type) { |
145 | const CastAttrs* param = attrs.as<CastAttrs>(); |
146 | ICHECK(param != nullptr); |
147 | DataType dtype = param->dtype; |
148 | return {topi::cast(inputs[0], dtype)}; |
149 | } |
150 | |
151 | Expr MakeCast(Expr data, DataType dtype) { |
152 | auto attrs = make_object<CastAttrs>(); |
153 | attrs->dtype = dtype; |
154 | static const Op& op = Op::Get("cast" ); |
155 | return Call(op, {data}, Attrs(attrs), {}); |
156 | } |
157 | |
158 | TVM_REGISTER_GLOBAL("relay.ir.cast" ).set_body_typed(MakeCast); |
159 | |
160 | RELAY_REGISTER_OP("cast" ) |
161 | .describe(R"code(Cast the data into a new data type. |
162 | |
163 | )code" TVM_ADD_FILELINE) |
164 | .set_num_inputs(1) |
165 | .set_attrs_type<CastAttrs>() |
166 | .add_argument("data" , "Tensor" , "The input tensor." ) |
167 | .set_support_level(3) |
168 | .add_type_rel("Cast" , CastRel) |
169 | .set_attr<FTVMCompute>("FTVMCompute" , CastCompute) |
170 | .set_attr<TOpPattern>("TOpPattern" , kElemWise) |
171 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ElemwiseArbitraryLayout); |
172 | |
173 | // relay.cast_like |
174 | bool CastLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
175 | const TypeReporter& reporter) { |
176 | ICHECK_EQ(types.size(), 3); |
177 | const auto* data = types[0].as<TensorTypeNode>(); |
178 | if (data == nullptr) { |
179 | ICHECK(types[0].as<IncompleteTypeNode>()) |
180 | << "cast: expect input type to be TensorType but get " << types[0]; |
181 | return false; |
182 | } |
183 | const auto* dtype_like = types[1].as<TensorTypeNode>(); |
184 | if (dtype_like == nullptr) { |
185 | ICHECK(types[1].as<IncompleteTypeNode>()) |
186 | << "cast: expect input type to be TensorType but get " << types[1]; |
187 | return false; |
188 | } |
189 | reporter->Assign(types[2], TensorType(data->shape, dtype_like->dtype)); |
190 | return true; |
191 | } |
192 | |
193 | Array<te::Tensor> CastLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
194 | const Type& out_type) { |
195 | return {topi::cast(inputs[0], inputs[1]->dtype)}; |
196 | } |
197 | |
198 | Expr MakeCastLike(Expr data, Expr dtype_like) { |
199 | static const Op& op = Op::Get("cast_like" ); |
200 | return Call(op, {data, dtype_like}, Attrs(), {}); |
201 | } |
202 | |
203 | TVM_REGISTER_GLOBAL("relay.ir.cast_like" ).set_body_typed(MakeCastLike); |
204 | |
205 | RELAY_REGISTER_OP("cast_like" ) |
206 | .describe(R"code(Cast the data into the type of another tensor. |
207 | )code" TVM_ADD_FILELINE) |
208 | .set_num_inputs(2) |
209 | .add_argument("data" , "Tensor" , "The input tensor." ) |
210 | .add_argument("dtype_like" , "Tensor" , "The tensor to cast to." ) |
211 | .set_support_level(3) |
212 | .add_type_rel("CastLike" , CastLikeRel) |
213 | .set_attr<FTVMCompute>("FTVMCompute" , CastLikeCompute) |
214 | .set_attr<TOpPattern>("TOpPattern" , kElemWise) |
215 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ElemwiseArbitraryLayout); |
216 | |
217 | Array<te::Tensor> ReinterpretCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
218 | const Type& out_type) { |
219 | const CastAttrs* param = attrs.as<CastAttrs>(); |
220 | ICHECK(param != nullptr); |
221 | DataType dtype = param->dtype; |
222 | return {topi::reinterpret(inputs[0], dtype)}; |
223 | } |
224 | |
225 | Expr MakeReinterpret(Expr data, DataType dtype) { |
226 | auto attrs = make_object<CastAttrs>(); |
227 | attrs->dtype = dtype; |
228 | static const Op& op = Op::Get("reinterpret" ); |
229 | return Call(op, {data}, Attrs(attrs), {}); |
230 | } |
231 | |
232 | TVM_REGISTER_GLOBAL("relay._make.reinterpret" ).set_body_typed(MakeReinterpret); |
233 | |
234 | RELAY_REGISTER_OP("reinterpret" ) |
235 | .describe(R"code(Reinterpret the data into a new data type. |
236 | )code" TVM_ADD_FILELINE) |
237 | .set_num_inputs(1) |
238 | .set_attrs_type<CastAttrs>() |
239 | .add_argument("data" , "Tensor" , "The input tensor." ) |
240 | .set_support_level(3) |
241 | .add_type_rel("Reinterpret" , CastRel) |
242 | .set_attr<FTVMCompute>("FTVMCompute" , ReinterpretCompute) |
243 | .set_attr<TOpPattern>("TOpPattern" , kElemWise) |
244 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ElemwiseArbitraryLayout); |
245 | |
246 | // relay.expand_dims |
247 | TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs); |
248 | |
249 | bool ExpandDimsRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
250 | const TypeReporter& reporter) { |
251 | // `types` contains: [data, result] |
252 | ICHECK_EQ(types.size(), 2); |
253 | const auto* data = types[0].as<TensorTypeNode>(); |
254 | if (data == nullptr) { |
255 | ICHECK(types[0].as<IncompleteTypeNode>()) |
256 | << "expand_dims: expect input type to be TensorType but get " << types[0]; |
257 | return false; |
258 | } |
259 | const auto* param = attrs.as<ExpandDimsAttrs>(); |
260 | const int ndim = static_cast<int>(data->shape.size()); |
261 | const int axis = param->axis; |
262 | const int num_newaxis = param->num_newaxis; |
263 | ICHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`" |
264 | << ", but got num_newaxis = " << num_newaxis; |
265 | ICHECK(-ndim - 1 <= axis && axis <= ndim) |
266 | << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" |
267 | << ", but got axis = " << axis << ", and data.ndim = " << ndim; |
268 | const int pivot = axis < 0 ? ndim + axis + 1 : axis; |
269 | std::vector<IndexExpr> oshape; |
270 | oshape.reserve(ndim + num_newaxis); |
271 | for (int i = 0; i < pivot; ++i) { |
272 | oshape.emplace_back(data->shape[i]); |
273 | } |
274 | for (int i = 0; i < num_newaxis; ++i) { |
275 | oshape.emplace_back(1); |
276 | } |
277 | for (int i = pivot; i < ndim; ++i) { |
278 | oshape.emplace_back(data->shape[i]); |
279 | } |
280 | reporter->Assign(types[1], TensorType(oshape, data->dtype)); |
281 | return true; |
282 | } |
283 | |
284 | Array<te::Tensor> ExpandDimsCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
285 | const Type& out_type) { |
286 | const ExpandDimsAttrs* param = attrs.as<ExpandDimsAttrs>(); |
287 | ICHECK(param != nullptr); |
288 | return {topi::expand_dims(inputs[0], param->axis, param->num_newaxis)}; |
289 | } |
290 | |
291 | Expr MakeExpandDims(Expr data, int axis, int num_newaxis) { |
292 | auto attrs = make_object<ExpandDimsAttrs>(); |
293 | attrs->axis = axis; |
294 | attrs->num_newaxis = num_newaxis; |
295 | static const Op& op = Op::Get("expand_dims" ); |
296 | return Call(op, {data}, Attrs(attrs), {}); |
297 | } |
298 | |
299 | TVM_REGISTER_GLOBAL("relay.op._make.expand_dims" ).set_body_typed(MakeExpandDims); |
300 | |
301 | RELAY_REGISTER_OP("expand_dims" ) |
302 | .describe(R"code(Insert `num_newaxis` axes at the position given by `axis` |
303 | |
304 | - **data**: The input data to the operator. |
305 | |
306 | )code" TVM_ADD_FILELINE) |
307 | .set_num_inputs(1) |
308 | .set_attrs_type<ExpandDimsAttrs>() |
309 | .add_argument("data" , "Tensor" , "The input tensor." ) |
310 | .set_support_level(1) |
311 | .add_type_rel("ExpandDims" , ExpandDimsRel) |
312 | .set_attr<FTVMCompute>("FTVMCompute" , ExpandDimsCompute) |
313 | .set_attr<TOpPattern>("TOpPattern" , kBroadcast) |
314 | .set_attr<TReshapeOp>("TReshapeOp" , true); |
315 | |
316 | // relay.concatenate |
317 | TVM_REGISTER_NODE_TYPE(ConcatenateAttrs); |
318 | |
319 | Array<te::Tensor> ConcatenateCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
320 | const Type& out_type) { |
321 | const ConcatenateAttrs* param = attrs.as<ConcatenateAttrs>(); |
322 | ICHECK(param != nullptr); |
323 | return {topi::concatenate(inputs, param->axis)}; |
324 | } |
325 | |
326 | Expr MakeConcatenate(Expr data, int axis) { |
327 | auto attrs = make_object<ConcatenateAttrs>(); |
328 | attrs->axis = axis; |
329 | static const Op& op = Op::Get("concatenate" ); |
330 | return Call(op, {data}, Attrs(attrs), {}); |
331 | } |
332 | |
333 | TVM_REGISTER_GLOBAL("relay.op._make.concatenate" ).set_body_typed(MakeConcatenate); |
334 | |
335 | RELAY_REGISTER_OP("concatenate" ) |
336 | .describe(R"code(Concatenate the input tensors along the given axis. |
337 | |
338 | - **data** : A list of tensors. |
339 | |
340 | - **axis** : The axis along which the tensors are concatenated. |
341 | |
342 | )code" TVM_ADD_FILELINE) |
343 | .set_attrs_type<ConcatenateAttrs>() |
344 | .set_num_inputs(1) |
345 | .add_argument("data" , "Tensor" , "The input list of tensors." ) |
346 | .set_support_level(1) |
347 | .add_type_rel("Concatenate" , ConcatenateRel<ConcatenateAttrs>) |
348 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ConcatenateLayout) |
349 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
350 | |
351 | TVM_REGISTER_NODE_TYPE(StackAttrs); |
352 | |
353 | bool StackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
354 | const TypeReporter& reporter) { |
355 | // types: [data, result] |
356 | ICHECK_EQ(types.size(), 2); |
357 | const auto* tensor_tuple = types[0].as<TupleTypeNode>(); |
358 | if (tensor_tuple == nullptr) { |
359 | ICHECK(types[0].as<IncompleteTypeNode>()) |
360 | << "cast: expect input type to be TupleType but get " << types[0]; |
361 | return false; |
362 | } |
363 | for (auto field : tensor_tuple->fields) { |
364 | if (field.as<IncompleteTypeNode>()) { |
365 | return false; |
366 | } |
367 | } |
368 | const auto* param = attrs.as<StackAttrs>(); |
369 | const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]); |
370 | const int ndim = static_cast<int>(first->shape.size()); |
371 | |
372 | // Sanity check: axis |
373 | int axis = param->axis.IntValue(); |
374 | ICHECK(-(ndim + 1) <= axis && axis < ndim + 1) |
375 | << "stack only accepts `axis` in [-(ndim+1), ndim+1)" |
376 | << ", but got axis = " << axis << ", and ndim = " << ndim; |
377 | axis = axis < 0 ? ndim + axis + 1 : axis; |
378 | |
379 | // Sanity check: ndim and dtype. |
380 | const DataType dtype = first->dtype; |
381 | for (const Type& ele : tensor_tuple->fields) { |
382 | const auto& e = Downcast<TensorType>(ele); |
383 | int e_ndim = static_cast<int>(e->shape.size()); |
384 | const DataType& e_dtype = e->dtype; |
385 | ICHECK_EQ(e_ndim, ndim) << "relay.stack requires all tensors have the same ndim" ; |
386 | ICHECK_EQ(e_dtype, dtype) << "relay.stack requires all tensors have the same dtype" ; |
387 | for (size_t j = 0; j < first->shape.size(); ++j) { |
388 | if (j == static_cast<size_t>(axis)) continue; |
389 | if (first->shape[j].as<AnyNode>() || e->shape[j].as<AnyNode>() || |
390 | reporter->AssertEQ(first->shape[j], e->shape[j])) |
391 | continue; |
392 | throw CompileError( |
393 | "relay.stack requires all tensors have the same shape " |
394 | "on non-stacking axes" ); |
395 | } |
396 | } |
397 | |
398 | // Calculate shape |
399 | std::vector<IndexExpr> oshape; |
400 | oshape.reserve(ndim + 1); |
401 | const int stack_dim = static_cast<int>(tensor_tuple->fields.size()); |
402 | for (int i = 0; i < axis; ++i) { |
403 | oshape.emplace_back(first->shape[i]); |
404 | } |
405 | oshape.emplace_back(stack_dim); |
406 | for (int i = axis; i < ndim; ++i) { |
407 | oshape.emplace_back(first->shape[i]); |
408 | } |
409 | reporter->Assign(types[1], TensorType(oshape, dtype)); |
410 | return true; |
411 | } |
412 | |
413 | Array<te::Tensor> StackCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
414 | const Type& out_type) { |
415 | const StackAttrs* param = attrs.as<StackAttrs>(); |
416 | ICHECK(param != nullptr); |
417 | return {topi::stack(inputs, param->axis.IntValue())}; |
418 | } |
419 | |
420 | Expr MakeStack(Expr data, int axis) { |
421 | auto attrs = make_object<StackAttrs>(); |
422 | attrs->axis = axis; |
423 | static const Op& op = Op::Get("stack" ); |
424 | return Call(op, {data}, Attrs(attrs), {}); |
425 | } |
426 | |
427 | TVM_REGISTER_GLOBAL("relay.op._make.stack" ).set_body_typed(MakeStack); |
428 | |
429 | RELAY_REGISTER_OP("stack" ) |
430 | .describe(R"code(Stack the input tensors along the given axis. |
431 | |
432 | - **data** : A list of tensors. |
433 | |
434 | - **axis** : The axis along which the tensors are stacked. |
435 | |
436 | )code" TVM_ADD_FILELINE) |
437 | .set_attrs_type<StackAttrs>() |
438 | .set_num_inputs(1) |
439 | .add_argument("data" , "Tensor" , "The input list of tensors." ) |
440 | .set_support_level(3) |
441 | .add_type_rel("Stack" , StackRel) |
442 | .set_attr<FTVMCompute>("FTVMCompute" , StackCompute) |
443 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
444 | |
445 | /* relay.transpose */ |
446 | TVM_REGISTER_NODE_TYPE(TransposeAttrs); |
447 | |
448 | bool TransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
449 | const TypeReporter& reporter) { |
450 | // types: [data, result] |
451 | ICHECK_EQ(types.size(), 2); |
452 | const auto* data = types[0].as<TensorTypeNode>(); |
453 | if (data == nullptr) { |
454 | ICHECK(types[0].as<IncompleteTypeNode>()) |
455 | << "transpose: expect input type to be TensorType but get " << types[0]; |
456 | return false; |
457 | } |
458 | const auto* param = attrs.as<TransposeAttrs>(); |
459 | const int ndim = data->shape.size(); |
460 | const Array<Integer>& axes = param->axes; |
461 | // check dimension match |
462 | ICHECK(!axes.defined() || static_cast<int>(axes.size()) == ndim) |
463 | << "Dimension mismatch: axes has " << axes.size() << " elements" |
464 | << ", but data.ndim = " << ndim; |
465 | // construct int_axes |
466 | std::vector<int> int_axes; |
467 | int_axes.reserve(ndim); |
  // `axes` being undefined means None, i.e. reverse the order of the axes.
469 | if (!axes.defined()) { |
470 | for (int i = ndim - 1; i >= 0; --i) { |
471 | int_axes.push_back(i); |
472 | } |
473 | } else { |
474 | std::vector<int> axis_used(ndim, 0); |
475 | for (const Integer& e : axes) { |
476 | int64_t axis = e.IntValue(); |
477 | // sanity check for axis and ndim |
478 | ICHECK(-ndim <= axis && axis < ndim) |
479 | << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)" |
480 | << ", but got axis = " << axis << ", and data.ndim = " << ndim; |
481 | axis = axis < 0 ? axis + ndim : axis; |
482 | // sanity check for duplication |
483 | ICHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis; |
484 | axis_used[axis] = 1; |
485 | int_axes.push_back(static_cast<int>(axis)); |
486 | } |
487 | } |
488 | std::vector<IndexExpr> oshape; |
489 | oshape.reserve(ndim); |
490 | for (int axis : int_axes) { |
491 | oshape.push_back(data->shape[axis]); |
492 | } |
493 | reporter->Assign(types[1], TensorType(oshape, data->dtype)); |
494 | return true; |
495 | } |
496 | |
497 | InferCorrectLayoutOutput TransposeInferCorrectLayout(const Attrs& attrs, |
498 | const Array<Layout>& new_in_layouts, |
499 | const Array<Layout>& old_in_layouts, |
500 | const Array<tvm::relay::Type>& old_in_types) { |
501 | const auto* attrs_ptr = attrs.as<TransposeAttrs>(); |
502 | ICHECK(attrs_ptr); |
503 | ObjectPtr<TransposeAttrs> params = make_object<TransposeAttrs>(*attrs_ptr); |
504 | |
505 | std::string in_layout_str = "" ; |
506 | std::string out_layout_str = "" ; |
507 | |
508 | // Infer the input layout string and update the axes. |
509 | if (old_in_layouts.defined() && old_in_layouts[0].defined()) { |
510 | ICHECK_EQ(old_in_layouts.size(), 1); |
511 | auto old_layout = old_in_layouts[0]; |
512 | Array<Integer> old_axes = params->axes; |
513 | |
514 | // Deal with default axes and negative axes. |
515 | if (!old_axes.defined() || old_axes.size() == 0) { |
516 | for (int i = old_layout.ndim() - 1; i >= 0; --i) { |
517 | old_axes.push_back(i); |
518 | } |
519 | } |
520 | for (size_t i = 0; i < old_axes.size(); ++i) { |
521 | int axis = static_cast<int>(old_axes[i]->value); |
522 | if (axis < 0) { |
523 | int pos_axis = static_cast<int>(old_layout.ndim()) + axis; |
524 | old_axes.Set(i, pos_axis); |
525 | } |
526 | } |
527 | |
528 | if (new_in_layouts.defined() && new_in_layouts[0].defined()) { |
529 | ICHECK_EQ(new_in_layouts.size(), 1); |
530 | auto new_layout = new_in_layouts[0]; |
531 | |
532 | // Update the axes based on the new layout. |
533 | Array<Integer> new_axes = Array<Integer>(); |
534 | for (auto axis : old_axes) { |
535 | auto new_axis = new_layout.IndexOf(old_layout[axis->value]); |
536 | if (new_axis == -1) { // Cannot find the target axis in the new layout. |
537 | new_axes.clear(); |
538 | break; |
539 | } |
540 | new_axes.push_back(new_axis); |
541 | } |
542 | if (new_axes.defined() && new_axes.size() == new_layout.ndim()) { |
543 | params->axes = std::move(new_axes); |
544 | in_layout_str = new_layout.name(); |
545 | } |
546 | } |
547 | |
548 | // If the input layout string cannot be determined, propagate the old layout. |
549 | if (in_layout_str == "" ) { |
550 | params->axes = std::move(old_axes); |
551 | in_layout_str = old_layout.name(); |
552 | } |
553 | } |
554 | |
555 | // Infer the output layout string based on the input layout and the axes. |
556 | Attrs new_attrs(params); |
557 | if (in_layout_str != "" ) { |
558 | for (auto axis : params->axes) { |
559 | ICHECK_LT(axis->value, in_layout_str.length()); |
560 | out_layout_str += in_layout_str[axis->value]; |
561 | } |
562 | try { |
563 | return InferCorrectLayoutOutput({Layout(in_layout_str)}, {Layout(out_layout_str)}, new_attrs); |
564 | } catch (const tvm::Error& e) { |
565 | // If the layout string is invalid for any reason, give up. |
566 | return InferCorrectLayoutOutput({Layout::Undef()}, {Layout::Undef()}, attrs); |
567 | } |
568 | } |
569 | return InferCorrectLayoutOutput({Layout::Undef()}, {Layout::Undef()}, attrs); |
570 | } |
571 | |
572 | Array<te::Tensor> TransposeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
573 | const Type& out_type) { |
574 | const auto* param = attrs.as<TransposeAttrs>(); |
575 | ICHECK(param != nullptr); |
576 | return Array<te::Tensor>{topi::transpose(inputs[0], param->axes)}; |
577 | } |
578 | |
579 | Expr MakeTranspose(Expr data, Array<Integer> axes) { |
580 | auto attrs = make_object<TransposeAttrs>(); |
581 | attrs->axes = std::move(axes); |
582 | static const Op& op = Op::Get("transpose" ); |
583 | return Call(op, {data}, Attrs(attrs), {}); |
584 | } |
585 | |
586 | TVM_REGISTER_GLOBAL("relay.op._make.transpose" ).set_body_typed(MakeTranspose); |
587 | |
588 | RELAY_REGISTER_OP("transpose" ) |
589 | .describe(R"code(Permutes the dimensions of an array. |
590 | |
591 | - **data**: The input data to the operator. |
592 | |
593 | - **axes**: The target axes order, reverse order if not specified. |
594 | |
595 | )code" TVM_ADD_FILELINE) |
596 | .set_num_inputs(1) |
597 | .set_attrs_type<TransposeAttrs>() |
598 | .add_argument("data" , "Tensor" , "The input tensor." ) |
599 | .set_support_level(3) |
600 | .add_type_rel("Transpose" , TransposeRel) |
601 | .set_attr<FTVMCompute>("FTVMCompute" , TransposeCompute) |
602 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , TransposeInferCorrectLayout) |
603 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
604 | |
605 | /* relay.reshape */ |
606 | TVM_REGISTER_NODE_TYPE(ReshapeAttrs); |
607 | TVM_REGISTER_NODE_TYPE(ReshapeLikeAttrs); |
608 | |
609 | Array<IndexExpr> InferNewShape(const Array<IndexExpr>& data_shape, const Attrs& attrs, |
610 | bool reverse) { |
611 | const auto* param = attrs.as<ReshapeAttrs>(); |
612 | Array<IndexExpr> oshape; |
613 | Array<IndexExpr> ishape; |
614 | Array<Integer> newshape; |
615 | |
616 | if (reverse) { |
617 | ishape.Assign(data_shape.rbegin(), data_shape.rend()); |
618 | newshape.Assign(param->newshape.rbegin(), param->newshape.rend()); |
619 | } else { |
620 | ishape = data_shape; |
621 | newshape = param->newshape; |
622 | } |
623 | |
624 | bool allowzero = param->allowzero; |
625 | |
626 | std::unordered_set<size_t> used_input_dims; |
627 | std::unordered_set<size_t> used_output_dims; |
628 | size_t src_idx = 0; |
629 | int infer_idx = -1; |
630 | |
631 | for (size_t i = 0; i < newshape.size(); ++i) { |
632 | int svalue = newshape[i]->value; |
633 | // special flag handling for shape inference. |
634 | if (svalue > 0) { |
635 | oshape.push_back(newshape[i]); |
636 | ++src_idx; |
637 | } else if (svalue == 0) { |
638 | if (allowzero) { |
        // With allowzero, 0 is kept as-is and denotes an actual zero-sized dimension.
640 | oshape.push_back(newshape[i]); |
641 | ++src_idx; |
642 | } else { |
        // Without allowzero, 0 means to copy the dimension at the equivalent position in the data tensor.
644 | ICHECK_LT(src_idx, ishape.size()); |
645 | used_input_dims.insert(src_idx); |
646 | used_output_dims.insert(oshape.size()); |
647 | oshape.push_back(ishape[src_idx++]); |
648 | } |
649 | } else if (svalue == -1) { |
650 | // inference based on rest |
651 | ICHECK_LT(infer_idx, 0) << "One and only one dim can be inferred" ; |
652 | infer_idx = i; |
653 | oshape.push_back(1); |
654 | ++src_idx; |
655 | } else if (svalue == -2) { |
656 | // copy all remaining dims from source |
657 | while (src_idx < ishape.size()) { |
658 | used_input_dims.insert(src_idx); |
659 | used_output_dims.insert(oshape.size()); |
660 | oshape.push_back(ishape[src_idx++]); |
661 | } |
662 | } else if (svalue == -3) { |
663 | // merge two dims from source |
664 | ICHECK_LT(src_idx + 1, ishape.size()); |
665 | used_input_dims.insert(src_idx); |
666 | IndexExpr d1 = ishape[src_idx++]; |
667 | used_input_dims.insert(src_idx); |
668 | IndexExpr d2 = ishape[src_idx++]; |
669 | used_output_dims.insert(oshape.size()); |
670 | if (d1.as<AnyNode>() || d2.as<AnyNode>()) { |
671 | oshape.push_back(Any()); |
672 | } else { |
673 | oshape.push_back(d1 * d2); |
674 | } |
675 | } else if (svalue == -4) { |
      // split the source dim into two dims
      // read the left dim and then the right dim (either one can be -1)
678 | ICHECK_LT(i + 2, newshape.size()); |
679 | ICHECK_LT(src_idx, ishape.size()); |
680 | used_input_dims.insert(src_idx); |
681 | IndexExpr d0 = ishape[src_idx++]; |
682 | Integer d1 = newshape[++i]; |
683 | Integer d2 = newshape[++i]; |
684 | if (d1->value == -1) { |
685 | ICHECK_NE(d2->value, -1) << "Split dims cannot both be -1." ; |
686 | used_output_dims.insert(oshape.size()); |
687 | if (d0.as<AnyNode>()) { |
688 | oshape.push_back(Any()); |
689 | } else { |
690 | oshape.push_back(indexdiv(d0, d2)); |
691 | } |
692 | used_output_dims.insert(oshape.size()); |
693 | oshape.push_back(d2); |
694 | } else { |
695 | used_output_dims.insert(oshape.size()); |
696 | oshape.push_back(d1); |
697 | used_output_dims.insert(oshape.size()); |
698 | if (d2->value == -1) { |
699 | if (d0.as<AnyNode>()) { |
700 | oshape.push_back(Any()); |
701 | } else { |
702 | oshape.push_back(indexdiv(d0, d1)); |
703 | } |
704 | } else { |
705 | oshape.push_back(d2); |
706 | } |
707 | } |
708 | } else { |
709 | LOG(FATAL) << "Unsupported special value: " << svalue; |
710 | } |
711 | } |
712 | |
713 | if (infer_idx >= 0) { |
714 | IndexExpr infer_dim = 1; |
715 | for (size_t i = 0; i < ishape.size(); ++i) { |
716 | if (used_input_dims.count(i) != 0) { |
717 | continue; |
718 | } |
719 | if (ishape[i].as<AnyNode>()) { |
720 | infer_dim = Any(); |
721 | break; |
722 | } |
723 | infer_dim *= ishape[i]; |
724 | } |
725 | if (!infer_dim.as<AnyNode>()) { |
726 | for (size_t i = 0; i < oshape.size(); ++i) { |
727 | if (used_output_dims.count(i) != 0) { |
728 | continue; |
729 | } |
730 | if (oshape[i].as<AnyNode>()) { |
731 | infer_dim = Any(); |
732 | break; |
733 | } |
734 | infer_dim = indexdiv(infer_dim, oshape[i]); |
735 | } |
736 | } |
737 | arith::Analyzer ana; |
738 | infer_dim = ana.Simplify(infer_dim); |
739 | oshape.Set(infer_idx, infer_dim); |
740 | } |
741 | |
742 | return oshape; |
743 | } |
744 | |
745 | bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
746 | const TypeReporter& reporter) { |
747 | // types: [data, result] |
748 | ICHECK_EQ(types.size(), 2); |
749 | const auto* data = types[0].as<TensorTypeNode>(); |
750 | if (data == nullptr) { |
751 | ICHECK(types[0].as<IncompleteTypeNode>()) |
752 | << "reshape: expect input type to be TensorType but get " << types[0]; |
753 | return false; |
754 | } |
755 | |
756 | const auto& oshape = InferNewShape(data->shape, attrs, false); |
757 | |
  // Verify that the product of dimensions in the output shape equals the
  // product of dimensions in the input shape.
760 | Array<IndexExpr> data_shape; |
761 | data_shape = data->shape; |
762 | |
763 | bool found_dynamic = false; |
764 | int64_t oshape_sum = 1; |
765 | for (auto& x : oshape) { |
    // Check if we have a dynamic shape. If we do, we can't verify that the
    // reshape is valid. Dynamic shapes are marked by using Any, but can also
    // occur from SizeVars. In the case of a SizeVar, the shape expression can
    // be an arbitrary AST, and we can't easily tell whether that AST came from
    // a ShapeVar or something else, so the check for a dynamic shape is simply
    // whether the dimension can be converted to an integer or not.
772 | if (!x->IsInstance<tvm::Integer::ContainerType>()) { |
773 | found_dynamic = true; |
774 | break; |
775 | } |
776 | oshape_sum *= Downcast<tvm::Integer>(x)->value; |
777 | } |
778 | int64_t data_shape_sum = 1; |
779 | for (auto& x : data_shape) { |
780 | if (!x->IsInstance<tvm::Integer::ContainerType>()) { |
781 | found_dynamic = true; |
782 | break; |
783 | } |
784 | data_shape_sum *= Downcast<tvm::Integer>(x)->value; |
785 | } |
786 | if (!found_dynamic && oshape_sum != data_shape_sum) { |
787 | std::ostringstream oshape_str, data_shape_str; |
788 | for (auto iter = oshape.begin(); iter != oshape.end(); iter++) { |
789 | oshape_str << (iter != oshape.begin() ? "," : "" ) << *iter; |
790 | } |
791 | for (auto iter = data_shape.begin(); iter != data_shape.end(); iter++) { |
792 | data_shape_str << (iter != data_shape.begin() ? "," : "" ) << *iter; |
793 | } |
794 | ICHECK_EQ(oshape_sum, data_shape_sum) |
795 | << "Input tensor shape(" << data_shape_str.str() << ") and reshaped shape(" |
796 | << oshape_str.str() << ") are not compatible!" ; |
797 | } |
798 | |
799 | reporter->Assign(types[1], TensorType(oshape, data->dtype)); |
800 | return true; |
801 | } |
802 | |
803 | bool ReverseReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
804 | const TypeReporter& reporter) { |
805 | // types: [data, result] |
806 | ICHECK_EQ(types.size(), 2); |
807 | const auto* data = types[0].as<TensorTypeNode>(); |
808 | if (data == nullptr) { |
809 | ICHECK(types[0].as<IncompleteTypeNode>()) |
810 | << "reshape: expect input type to be TensorType but get " << types[0]; |
811 | return false; |
812 | } |
813 | |
814 | const auto& oshape = InferNewShape(data->shape, attrs, true); |
815 | |
  // Verify that the product of dimensions in the output shape equals the
  // product of dimensions in the input shape.
818 | Array<IndexExpr> data_shape; |
819 | data_shape.Assign(data->shape.rbegin(), data->shape.rend()); |
820 | |
821 | bool found_dynamic = false; |
822 | int64_t oshape_sum = 1; |
823 | for (auto& x : oshape) { |
    // Check if we have a dynamic shape. If we do, we can't verify that the
    // reshape is valid. Dynamic shapes are marked by using Any, but can also
    // occur from SizeVars. In the case of a SizeVar, the shape expression can
    // be an arbitrary AST, and we can't easily tell whether that AST came from
    // a ShapeVar or something else, so the check for a dynamic shape is simply
    // whether the dimension can be converted to an integer or not.
830 | if (!x->IsInstance<tvm::Integer::ContainerType>()) { |
831 | found_dynamic = true; |
832 | break; |
833 | } |
834 | oshape_sum *= Downcast<tvm::Integer>(x)->value; |
835 | } |
836 | int64_t data_shape_sum = 1; |
837 | for (auto& x : data_shape) { |
838 | if (!x->IsInstance<tvm::Integer::ContainerType>()) { |
839 | found_dynamic = true; |
840 | break; |
841 | } |
842 | data_shape_sum *= Downcast<tvm::Integer>(x)->value; |
843 | } |
844 | if (!found_dynamic) { |
845 | ICHECK_EQ(oshape_sum, data_shape_sum) |
846 | << "Input tensor shape and reshaped shape are not compatible" ; |
847 | } |
848 | |
849 | reporter->Assign(types[1], |
850 | TensorType(Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype)); |
851 | return true; |
852 | } |
853 | |
854 | Array<PrimExpr> infer_reshape_like(const Array<PrimExpr>& lhs_shape, |
855 | const Array<PrimExpr>& rhs_shape, const Attrs& attrs) { |
856 | const auto* like_attrs = attrs.as<ReshapeLikeAttrs>(); |
857 | CHECK(!like_attrs->lhs_end.defined() || like_attrs->lhs_end.as<IntImmNode>()) |
858 | << "lhs_end must be a concrete integer or None" ; |
859 | CHECK(!like_attrs->rhs_end.defined() || like_attrs->rhs_end.as<IntImmNode>()) |
860 | << "rhs_end must be a concrete integer or None" ; |
861 | |
862 | int64_t lhs_shape_size = static_cast<int64_t>(lhs_shape.size()); |
863 | int64_t rhs_shape_size = static_cast<int64_t>(rhs_shape.size()); |
864 | int64_t lhs_begin = static_cast<int64_t>(like_attrs->lhs_begin); |
865 | int64_t lhs_end = |
866 | like_attrs->lhs_end.defined() ? like_attrs->lhs_end.as<IntImmNode>()->value : lhs_shape_size; |
867 | int64_t rhs_begin = static_cast<int64_t>(like_attrs->rhs_begin); |
868 | int64_t rhs_end = |
869 | like_attrs->rhs_end.defined() ? like_attrs->rhs_end.as<IntImmNode>()->value : rhs_shape_size; |
870 | |
871 | // handle negative axes |
872 | lhs_begin = lhs_begin < 0 ? lhs_begin + lhs_shape_size : lhs_begin; |
873 | lhs_end = lhs_end < 0 ? lhs_end + lhs_shape_size : lhs_end; |
874 | rhs_begin = rhs_begin < 0 ? rhs_begin + rhs_shape_size : rhs_begin; |
875 | rhs_end = rhs_end < 0 ? rhs_end + rhs_shape_size : rhs_end; |
876 | |
877 | Array<PrimExpr> shape_like; |
878 | for (auto i = 0; i < lhs_begin; i++) { |
879 | shape_like.push_back(lhs_shape[i]); |
880 | } |
881 | for (auto i = rhs_begin; i < rhs_end; i++) { |
882 | shape_like.push_back(rhs_shape[i]); |
883 | } |
884 | for (auto i = lhs_end; i < lhs_shape_size; i++) { |
885 | shape_like.push_back(lhs_shape[i]); |
886 | } |
887 | return shape_like; |
888 | } |
889 | |
890 | Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
891 | const Type& out_type) { |
892 | // Quick path for reshape_like |
893 | if (!attrs.as<ReshapeAttrs>()) { |
894 | ICHECK(attrs.as<ReshapeLikeAttrs>() != nullptr); |
895 | auto shape_like = infer_reshape_like(inputs[0]->shape, inputs[1]->shape, attrs); |
896 | return {topi::reshape(inputs[0], shape_like)}; |
897 | } |
898 | |
899 | const auto* out_ttype = out_type.as<TensorTypeNode>(); |
900 | ICHECK(out_ttype != nullptr); |
901 | Array<IndexExpr> newshape; |
902 | bool newshape_has_any = false; |
903 | for (auto val : out_ttype->shape) { |
904 | if (val->IsInstance<tir::AnyNode>() || val->IsInstance<tir::VarNode>()) { |
905 | newshape_has_any = true; |
906 | break; |
907 | } else { |
908 | newshape.push_back(val); |
909 | } |
910 | } |
911 | |
912 | if (newshape_has_any) { |
913 | newshape = InferNewShape(inputs[0]->shape, attrs, false); |
914 | } |
915 | return {topi::reshape(inputs[0], newshape)}; |
916 | } |
917 | |
918 | Expr MakeReshape(Expr data, Array<Integer> newshape, bool allowzero) { |
919 | auto attrs = make_object<ReshapeAttrs>(); |
920 | attrs->newshape = std::move(newshape); |
921 | attrs->allowzero = allowzero; |
922 | static const Op& op = Op::Get("reshape" ); |
923 | return Call(op, {data}, Attrs(attrs), {}); |
924 | } |
925 | |
926 | TVM_REGISTER_GLOBAL("relay.op._make.reshape" ).set_body_typed(MakeReshape); |
927 | |
928 | RELAY_REGISTER_OP("reshape" ) |
929 | .describe(R"code(Reshapes the input array. |
930 | |
Example::

  data.shape = (2,3,4), newshape = (6,1,-1), result.shape = (6,1,4)

To give the user more convenience without requiring manual shape inference,
some dimensions of the shape can take special values from the set {0, -1, -2, -3, -4}.
The significance of each is explained below:
936 | |
937 | - ``0`` copy this dimension from the input to the output shape. |
938 | |
939 | Example:: |
940 | |
941 | - data.shape = (2,3,4), newshape = (4,0,2), result.shape = (4,3,2) |
942 | - data.shape = (2,3,4), newshape = (2,0,0), result.shape = (2,3,4) |
943 | |
944 | - ``-1`` infers the dimension of the output shape by using the remainder of the input dimensions |
945 | keeping the size of the new array same as that of the input array. |
946 | At most one dimension of shape can be -1. |
947 | |
948 | Example:: |
949 | |
950 | - data.shape = (2,3,4), newshape = (6,1,-1), result.shape = (6,1,4) |
951 | - data.shape = (2,3,4), newshape = (3,-1,8), result.shape = (3,1,8) |
952 | - data.shape = (2,3,4), newshape = (-1,), result.shape = (24,) |
953 | |
954 | - ``-2`` copy all/remainder of the input dimensions to the output shape. |
955 | |
956 | Example:: |
957 | |
958 | - data.shape = (2,3,4), newshape = (-2,), result.shape = (2,3,4) |
959 | - data.shape = (2,3,4), newshape = (2,-2), result.shape = (2,3,4) |
960 | - data.shape = (2,3,4), newshape = (-2,1,1), result.shape = (2,3,4,1,1) |
961 | |
962 | - ``-3`` use the product of two consecutive dimensions of the input shape as the output dimension. |
963 | |
964 | Example:: |
965 | |
966 | - data.shape = (2,3,4), newshape = (-3,4), result.shape = (6,4) |
967 | - data.shape = (2,3,4,5), newshape = (-3,-3), result.shape = (6,20) |
968 | - data.shape = (2,3,4), newshape = (0,-3), result.shape = (2,12) |
969 | - data.shape = (2,3,4), newshape = (-3,-2), result.shape = (6,4) |
970 | |
971 | - ``-4`` split one dimension of the input into two dimensions passed subsequent to -4 in shape (can contain -1). |
972 | |
973 | Example:: |
974 | |
975 | - data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape =(1,2,3,4) |
976 | - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4) |
977 | |
978 | )code" TVM_ADD_FILELINE) |
979 | .set_num_inputs(1) |
980 | .set_attrs_type<ReshapeAttrs>() |
981 | .add_argument("data" , "Tensor" , "The input tensor." ) |
982 | .set_support_level(3) |
983 | .add_type_rel("Reshape" , ReshapeRel) |
984 | .set_attr<FTVMCompute>("FTVMCompute" , ReshapeCompute) |
985 | .set_attr<TOpPattern>("TOpPattern" , kInjective) |
986 | .set_attr<TReshapeOp>("TReshapeOp" , true); |
987 | |
988 | /*! |
989 | * \brief ReshapeLikeRel User defined type constraint function. |
990 | * \param num_inputs Number of input types in the args. |
991 | * \param attrs The additional attributes of the operator. |
992 | * \param reporter The reporter to report solution to. |
993 | * \return False if the relation has not been resolved, it might be resolved later. |
994 | * True if this relation has been resolved. |
995 | */ |
996 | bool ReshapeLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
997 | const TypeReporter& reporter) { |
998 | ICHECK(attrs.as<ReshapeLikeAttrs>() != nullptr); |
999 | ICHECK_EQ(types.size(), 3); |
1000 | const auto* data = types[0].as<TensorTypeNode>(); |
1001 | if (data == nullptr) { |
1002 | return false; |
1003 | } |
1004 | const auto* reshape_like = types[1].as<TensorTypeNode>(); |
1005 | if (reshape_like == nullptr) { |
1006 | return false; |
1007 | } |
1008 | auto shape_like = infer_reshape_like(data->shape, reshape_like->shape, attrs); |
  // Only check when the input data has a static shape.
1010 | bool is_static_shape = true; |
1011 | for (size_t i = 0; i < data->shape.size(); ++i) { |
1012 | if (!data->shape[i].as<IntImmNode>()) { |
1013 | is_static_shape = false; |
1014 | break; |
1015 | } |
1016 | } |
1017 | auto output_type = TensorType(shape_like, data->dtype); |
1018 | if (is_static_shape) { |
1019 | ICHECK(reporter->AssertEQ(data->Size(), output_type->Size())) |
1020 | << "Reshape inputs size should be compatible, " |
1021 | << "but found data_shape " << data->shape << " not same as output_shape " |
1022 | << output_type->shape; |
1023 | } |
1024 | reporter->Assign(types[2], output_type); |
1025 | return true; |
1026 | } |
1027 | |
1028 | Expr MakeReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin, |
1029 | Integer rhs_end) { |
1030 | auto attrs = make_object<ReshapeLikeAttrs>(); |
1031 | attrs->lhs_begin = std::move(lhs_begin); |
1032 | attrs->lhs_end = std::move(lhs_end); |
1033 | attrs->rhs_begin = std::move(rhs_begin); |
1034 | attrs->rhs_end = std::move(rhs_end); |
1035 | static const Op& op = Op::Get("reshape_like" ); |
1036 | return Call(op, {lhs, rhs}, Attrs(attrs), {}); |
1037 | } |
1038 | |
1039 | TVM_REGISTER_GLOBAL("relay.op._make.reshape_like" ).set_body_typed(MakeReshapeLike); |
1040 | |
1041 | RELAY_REGISTER_OP("reshape_like" ) |
    .describe(R"code(Reshapes the input array to the shape of another array.
For an input array with shape ``(d1, d2, ..., dk)``, the `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.
.. note::
    Sizes for both arrays should be compatible.
1047 | Example:: |
1048 | |
1049 | data.shape == (1, 2, 3, 4) |
1050 | shape_like.shape == (6, 2, 2, 3) |
1051 | |
1052 | ret = reshape_like(data, shape_like, lhs_begin=1, rhs_end=3) |
1053 | ret.shape == (1, 6, 2, 2) |
1054 | )code" TVM_ADD_FILELINE) |
1055 | .set_attrs_type<ReshapeLikeAttrs>() |
1056 | .set_num_inputs(2) |
1057 | .add_argument("data" , "Tensor" , "The input tensor." ) |
1058 | .add_argument("shape_like" , "Tensor" , "Shape tensor." ) |
1059 | .set_support_level(3) |
1060 | .add_type_rel("ReshapeLike" , ReshapeLikeRel) |
1061 | .set_attr<FTVMCompute>("FTVMCompute" , ReshapeCompute) |
1062 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
1063 | |
1064 | // ArgWhere |
1065 | bool ArgWhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1066 | const TypeReporter& reporter) { |
1067 | ICHECK_EQ(num_inputs, 1); |
1068 | auto tt = types[0].as<TensorTypeNode>(); |
1069 | |
1070 | if (tt == nullptr) { |
1071 | return false; |
1072 | } |
1073 | |
1074 | const auto& input_shape = tt->shape; |
1075 | const auto& input_rank = input_shape.size(); |
1076 | std::vector<IndexExpr> result_shape; |
1077 | result_shape.push_back(Any()); |
1078 | result_shape.push_back(IntImm(DataType::Int(32), input_rank)); |
1079 | reporter->Assign(types[1], TensorType(result_shape, DataType::Int(32))); |
1080 | return true; |
1081 | } |
1082 | |
1083 | TVM_REGISTER_GLOBAL("relay.op._make.argwhere" ).set_body_typed([](Expr data) { |
1084 | static const Op& op = Op::Get("argwhere" ); |
1085 | return Call(op, {data}, Attrs(), {}); |
1086 | }); |
1087 | |
1088 | RELAY_REGISTER_OP("argwhere" ) |
1089 | .describe(R"doc(Find the indices of elements of a tensor that are |
1090 | non-zero)doc" TVM_ADD_FILELINE) |
1091 | .set_num_inputs(1) |
1092 | .add_argument("condition" , "Tensor" , "The input condition tensor." ) |
1093 | .add_type_rel("ArgWhere" , ArgWhereRel) |
1094 | .set_attr<TOpIsStateful>("TOpIsStateful" , false) |
1095 | .set_attr<TOpPattern>("TOpPattern" , kOpaque) |
1096 | .set_support_level(10); |
1097 | |
1098 | // Scatter |
1099 | TVM_REGISTER_NODE_TYPE(ScatterAttrs); |
1100 | |
1101 | // Scatter |
1102 | bool ScatterRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1103 | const TypeReporter& reporter) { |
1104 | ICHECK_EQ(num_inputs, 3); |
1105 | ICHECK_EQ(types.size(), 4); |
1106 | auto data = types[0].as<TensorTypeNode>(); |
1107 | if (data == nullptr) { |
1108 | return false; |
1109 | } |
1110 | auto indices = types[1].as<TensorTypeNode>(); |
1111 | if (indices == nullptr) { |
1112 | return false; |
1113 | } |
1114 | auto updates = types[2].as<TensorTypeNode>(); |
1115 | if (updates == nullptr) { |
1116 | return false; |
1117 | } |
1118 | ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()) |
1119 | << "indices of scatter must be tensor of integer" ; |
1120 | const auto param = attrs.as<ScatterAttrs>(); |
1121 | ICHECK(param != nullptr); |
1122 | reporter->Assign(types[3], TensorType(data->shape, data->dtype)); |
1123 | return true; |
1124 | } |
1125 | |
1126 | TVM_REGISTER_GLOBAL("relay.op._make.scatter" ) |
1127 | .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) { |
1128 | auto attrs = make_object<ScatterAttrs>(); |
1129 | attrs->axis = std::move(axis); |
1130 | static const Op& op = Op::Get("scatter" ); |
1131 | return Call(op, {data, indices, updates}, Attrs(attrs), {}); |
1132 | }); |
1133 | |
1134 | RELAY_REGISTER_OP("scatter" ) |
1135 | .describe( |
1136 | R"doc(Update data at positions defined by indices with values in updates)doc" TVM_ADD_FILELINE) |
1137 | .set_num_inputs(3) |
1138 | .add_argument("data" , "Tensor" , "The input data tensor." ) |
1139 | .add_argument("indices" , "Tensor" , "The indices location tensor." ) |
1140 | .add_argument("updates" , "Tensor" , "The values to update the input with." ) |
1141 | .add_type_rel("Scatter" , ScatterRel) |
1142 | .set_attr<TOpIsStateful>("TOpIsStateful" , false) |
1143 | .set_attr<TOpPattern>("TOpPattern" , kOpaque) |
1144 | .set_support_level(10); |
1145 | |
1146 | // Scatter_add |
1147 | TVM_REGISTER_NODE_TYPE(ScatterAddAttrs); |
1148 | |
1149 | // Scatter Add |
1150 | bool ScatterAddRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1151 | const TypeReporter& reporter) { |
1152 | ICHECK_EQ(num_inputs, 3); |
1153 | ICHECK_EQ(types.size(), 4); |
1154 | auto data = types[0].as<TensorTypeNode>(); |
1155 | if (data == nullptr) { |
1156 | return false; |
1157 | } |
1158 | auto indices = types[1].as<TensorTypeNode>(); |
1159 | if (indices == nullptr) { |
1160 | return false; |
1161 | } |
1162 | auto updates = types[2].as<TensorTypeNode>(); |
1163 | if (updates == nullptr) { |
1164 | return false; |
1165 | } |
1166 | ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()) |
1167 | << "indices of scatter_add must be tensor of integer" ; |
1168 | const auto param = attrs.as<ScatterAddAttrs>(); |
1169 | ICHECK(param != nullptr); |
1170 | reporter->Assign(types[3], TensorType(data->shape, data->dtype)); |
1171 | return true; |
1172 | } |
1173 | |
1174 | TVM_REGISTER_GLOBAL("relay.op._make.scatter_add" ) |
1175 | .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) { |
1176 | auto attrs = make_object<ScatterAddAttrs>(); |
1177 | attrs->axis = std::move(axis); |
1178 | static const Op& op = Op::Get("scatter_add" ); |
1179 | return Call(op, {data, indices, updates}, Attrs(attrs), {}); |
1180 | }); |
1181 | |
1182 | RELAY_REGISTER_OP("scatter_add" ) |
1183 | .describe( |
1184 | R"doc(Update data by adding values in updates at positions defined by indices)doc" TVM_ADD_FILELINE) |
1185 | .set_num_inputs(3) |
1186 | .add_argument("data" , "Tensor" , "The input data tensor." ) |
1187 | .add_argument("indices" , "Tensor" , "The indices location tensor." ) |
1188 | .add_argument("updates" , "Tensor" , "The values to update the input with." ) |
1189 | .add_type_rel("ScatterAdd" , ScatterAddRel) |
1190 | .set_attr<TOpIsStateful>("TOpIsStateful" , false) |
1191 | .set_attr<TOpPattern>("TOpPattern" , kOpaque) |
1192 | .set_support_level(10); |
1193 | |
1194 | // scatter_nd operator |
1195 | TVM_REGISTER_NODE_TYPE(ScatterNDAttrs); |
1196 | |
1197 | bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1198 | const TypeReporter& reporter) { |
1199 | // `types` contains: [data, indices, updates, result] |
1200 | ICHECK_EQ(types.size(), 4); |
1201 | const auto* data = types[0].as<TensorTypeNode>(); |
1202 | const auto* indices = types[1].as<TensorTypeNode>(); |
1203 | const auto* updates = types[2].as<TensorTypeNode>(); |
1204 | if (data == nullptr) { |
1205 | ICHECK(types[0].as<IncompleteTypeNode>()) |
1206 | << "ScatterND: expect input data type to be TensorType but got " << types[0]; |
1207 | return false; |
1208 | } |
1209 | if (indices == nullptr) { |
1210 | ICHECK(types[1].as<IncompleteTypeNode>()) |
1211 | << "ScatterND: expect indices type to be TensorType but got " << types[1]; |
1212 | return false; |
1213 | } |
1214 | if (updates == nullptr) { |
1215 | ICHECK(types[2].as<IncompleteTypeNode>()) |
1216 | << "ScatterND: expect updates type to be TensorType but got " << types[2]; |
1217 | return false; |
1218 | } |
1219 | ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()) |
1220 | << "ScatterND: indices must be a tensor of integers." ; |
1221 | |
1222 | const auto out_shape = data->shape; |
1223 | const IntImmNode* mdim = indices->shape[0].as<IntImmNode>(); |
1224 | ICHECK(mdim) << "ScatterND needs a static shape for the first axis of indices, got " |
1225 | << indices->shape; |
1226 | const size_t kdim = indices->shape.size() - 1; |
1227 | const size_t ndim = out_shape.size(); |
1228 | ICHECK_LE(size_t(mdim->value), ndim) |
1229 | << "ScatterND: Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and indices " |
1230 | "with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to N." ; |
1231 | // Indices: (M, Y_0, .. Y_{K-1}) data: (Y_0, .. Y_{K-1}, ...), verify Y's. |
1232 | for (size_t i = 0; i < kdim; i++) { |
1233 | reporter->AssertEQ(indices->shape[i + 1], updates->shape[i]); |
1234 | } |
1235 | |
1236 | std::vector<IndexExpr> oshape; |
1237 | for (auto& x : out_shape) { |
1238 | oshape.push_back(x); |
1239 | } |
1240 | |
  // updates: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify X_M to X_{N-1}
  for (size_t i = mdim->value; i < ndim; i++) {
    reporter->AssertEQ(updates->shape[i - mdim->value + kdim], oshape[i]);
1244 | } |
1245 | |
1246 | reporter->Assign(types[3], TensorType(data->shape, data->dtype)); |
1247 | return true; |
1248 | } |
1249 | |
1250 | Expr MakeScatterND(Expr data, Expr indices, Expr updates, String mode) { |
1251 | auto attrs = make_object<ScatterNDAttrs>(); |
1252 | attrs->mode = std::move(mode); |
1253 | static const Op& op = Op::Get("scatter_nd" ); |
1254 | return Call(op, {data, indices, updates}, Attrs(attrs), {}); |
1255 | } |
1256 | |
1257 | TVM_REGISTER_GLOBAL("relay.op._make.scatter_nd" ).set_body_typed(MakeScatterND); |
1258 | |
// The scatter_nd operator has extern schedules for CPU and GPU devices, and
// fusing extern schedules with injective schedules leads to errors.
// So scatter_nd is marked Opaque to prevent such fusion and the resulting compilation failures.
1262 | RELAY_REGISTER_OP("scatter_nd" ) |
1263 | .describe(R"code(Scatter elements or slices from data and store to a tensor |
1264 | whose shape is defined by indices. |
1265 | |
1266 | Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}) and indices with shape |
1267 | (M, Y_0, ..., Y_{K-1}), the output will have shape (X_0, X_1, ..., X_{N-1}). |
1268 | )code" TVM_ADD_FILELINE) |
1269 | .set_num_inputs(3) |
1270 | .add_argument("data" , "Tensor" , "The input tensor." ) |
1271 | .add_argument("indices" , "Tensor" , "The indices tensor." ) |
    .add_argument("updates", "Tensor", "The updates tensor.")
1273 | .set_support_level(3) |
1274 | .add_type_rel("ScatterND" , ScatterNDRel) |
1275 | .set_attr<TOpPattern>("TOpPattern" , kOpaque); |
1276 | |
1277 | // Take |
1278 | TVM_REGISTER_NODE_TYPE(TakeAttrs); |
1279 | |
1280 | bool TakeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1281 | const TypeReporter& reporter) { |
1282 | // `types` contains: [data, indices, result] |
1283 | ICHECK_EQ(types.size(), 3); |
1284 | const auto* data = types[0].as<TensorTypeNode>(); |
1285 | if (data == nullptr) { |
1286 | return false; |
1287 | } |
1288 | const auto* indices = types[1].as<TensorTypeNode>(); |
1289 | if (indices == nullptr) { |
1290 | return false; |
1291 | } |
1292 | ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()) |
1293 | << "indices of take must be tensor of integer" ; |
1294 | const auto param = attrs.as<TakeAttrs>(); |
1295 | ICHECK(param != nullptr); |
1296 | |
1297 | if (!param->axis.defined()) { |
1298 | std::vector<IndexExpr> oshape(indices->shape.begin(), indices->shape.end()); |
1299 | reporter->Assign(types[2], TensorType(oshape, data->dtype)); |
1300 | return true; |
1301 | } |
1302 | |
1303 | std::vector<IndexExpr> oshape; |
1304 | const auto ndim_data = static_cast<int>(data->shape.size()); |
1305 | const auto ndim_indices = static_cast<int>(indices->shape.size()); |
1306 | int axis = static_cast<int>(param->axis->value); |
1307 | int batch_dims = static_cast<int>(param->batch_dims->value); |
1308 | if (axis < 0) axis += ndim_data; |
  if (batch_dims < 0) batch_dims += ndim_indices;
  ICHECK_LE(axis, ndim_data) << "axis should be within the data shape"
                             << ", but got = " << axis;
  ICHECK_LE(batch_dims, ndim_indices) << "batch_dims should be within the indices shape"
                                      << ", but got = " << batch_dims;
  ICHECK_LE(batch_dims, axis) << "batch_dims should be less than or equal to axis"
                              << ", but got = " << batch_dims;
1316 | |
1317 | oshape.reserve(ndim_data - 1 + ndim_indices - batch_dims); |
1318 | for (int i = 0; i < batch_dims; ++i) { |
1319 | oshape.emplace_back(data->shape[i]); |
1320 | } |
1321 | for (int i = batch_dims; i < axis; ++i) { |
1322 | oshape.emplace_back(data->shape[i]); |
1323 | } |
1324 | for (int i = batch_dims; i < ndim_indices; ++i) { |
1325 | oshape.emplace_back(indices->shape[i]); |
1326 | } |
1327 | for (int i = axis + 1; i < ndim_data; ++i) { |
1328 | oshape.emplace_back(data->shape[i]); |
1329 | } |
1330 | |
1331 | reporter->Assign(types[2], TensorType(oshape, data->dtype)); |
1332 | return true; |
1333 | } |
1334 | |
1335 | Array<te::Tensor> TakeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
1336 | const Type& out_type) { |
1337 | const auto* param = attrs.as<TakeAttrs>(); |
1338 | ICHECK(param != nullptr); |
1339 | if (!param->axis.defined()) { |
1340 | return Array<te::Tensor>{ |
1341 | topi::take(inputs[0], inputs[1], param->batch_dims.IntValue(), param->mode)}; |
1342 | } else { |
1343 | return Array<te::Tensor>{topi::take(inputs[0], inputs[1], param->batch_dims.IntValue(), |
1344 | param->axis.IntValue(), param->mode)}; |
1345 | } |
1346 | } |
1347 | |
1348 | Expr MakeTake(Expr data, Expr indices, Integer batch_dims, Integer axis, String mode) { |
1349 | auto attrs = make_object<TakeAttrs>(); |
1350 | attrs->batch_dims = std::move(batch_dims); |
1351 | attrs->axis = std::move(axis); |
1352 | attrs->mode = std::move(mode); |
1353 | static const Op& op = Op::Get("take" ); |
1354 | return Call(op, {data, indices}, Attrs(attrs), {}); |
1355 | } |
1356 | |
1357 | TVM_REGISTER_GLOBAL("relay.op._make.take" ).set_body_typed(MakeTake); |
1358 | |
1359 | RELAY_REGISTER_OP("take" ) |
1360 | .describe(R"code(Take elements from an array along an axis. |
1361 | |
1362 | When axis is not None, this function does the same thing as 'fancy' indexing |
1363 | (indexing arrays using arrays); however, it can be easier to use if you need |
1364 | elements along a given axis. |
1365 | |
**Note** that when axis is None, the flattened input array is used.
1367 | |
1368 | Examples:: |
1369 | |
1370 | a = [[ 1, 2], |
1371 | [ 3, 4]] |
1372 | indices = [3, 0, 2] |
1373 | take(a, indices) = [ 4, 1, 3] |
1374 | |
1375 | a = [[ 1., 2.], |
1376 | [ 3., 4.]] |
1377 | indices = [1, 0] |
1378 | take(a, indices, axis=1) = [[ 2., 1.], |
1379 | [ 4., 3.]] |
1380 | |
1381 | )code" TVM_ADD_FILELINE) |
1382 | .set_attrs_type<TakeAttrs>() |
1383 | .set_num_inputs(2) |
1384 | .add_argument("data" , "Tensor" , "The input tensor." ) |
1385 | .add_argument("indices" , "Tensor" , "The indices tensor." ) |
1386 | .set_support_level(3) |
1387 | .add_type_rel("Take" , TakeRel) |
1388 | .set_attr<FTVMCompute>("FTVMCompute" , TakeCompute) |
1389 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
1390 | |
1391 | // Init ops |
1392 | TVM_REGISTER_NODE_TYPE(InitOpAttrs); |
1393 | |
1394 | bool FullRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1395 | const TypeReporter& reporter) { |
1396 | ICHECK_EQ(types.size(), 2); |
1397 | const InitOpAttrs* param = attrs.as<InitOpAttrs>(); |
1398 | const auto* fill_value = types[0].as<TensorTypeNode>(); |
1399 | if (fill_value == nullptr) { |
1400 | return false; |
1401 | } |
1402 | |
1403 | DataType out_dtype = param->dtype; |
1404 | if (out_dtype.bits() == 0) { |
1405 | out_dtype = fill_value->dtype; |
1406 | } |
1407 | |
1408 | ICHECK_EQ(fill_value->shape.size(), 0) |
1409 | << "Fill value should be a scalar but has dimension " << fill_value->shape.size() << "." ; |
1410 | |
1411 | std::vector<IndexExpr> oshape; |
1412 | const Array<Integer>& cshape_array = param->shape.value(); |
1413 | for (size_t i = 0; i < cshape_array.size(); ++i) { |
1414 | oshape.push_back(cshape_array[i]); |
1415 | } |
1416 | reporter->Assign(types[1], TensorType(oshape, out_dtype)); |
1417 | return true; |
1418 | } |
1419 | |
1420 | Expr MakeFull(Expr fill_value, Array<Integer> shape, DataType dtype) { |
1421 | auto attrs = make_object<InitOpAttrs>(); |
1422 | attrs->dtype = std::move(dtype); |
1423 | attrs->shape = std::move(shape); |
1424 | static const Op& op = Op::Get("full" ); |
1425 | return Call(op, {fill_value}, Attrs(attrs), {}); |
1426 | } |
1427 | |
1428 | Array<te::Tensor> FullCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
1429 | const Type& out_type) { |
1430 | const auto* out_ttype = out_type.as<TensorTypeNode>(); |
1431 | return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())}; |
1432 | } |
1433 | |
1434 | TVM_REGISTER_GLOBAL("relay.op._make.full" ).set_body_typed(MakeFull); |
1435 | |
1436 | RELAY_REGISTER_OP("full" ) |
1437 | .describe(R"code(Fill array with scalar value. |
1438 | |
1439 | )code" TVM_ADD_FILELINE) |
1440 | .set_attrs_type<InitOpAttrs>() |
1441 | .set_num_inputs(1) |
1442 | .add_argument("fill_value" , "double" , "The value to fill." ) |
1443 | .set_support_level(3) |
1444 | .add_type_rel("Full" , FullRel) |
1445 | .set_attr<FTVMCompute>("FTVMCompute" , FullCompute) |
1446 | .set_attr<TOpPattern>("TOpPattern" , kElemWise); |
1447 | |
1448 | bool InitOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1449 | const TypeReporter& reporter) { |
1450 | // types = [ret_type] |
1451 | ICHECK_EQ(types.size(), 1); |
1452 | |
1453 | const InitOpAttrs* param = attrs.as<InitOpAttrs>(); |
1454 | ICHECK(param); |
1455 | |
1456 | DataType out_dtype = param->dtype; |
1457 | std::vector<IndexExpr> oshape; |
1458 | |
1459 | const Array<Integer>& cshape_array = param->shape.value(); |
1460 | for (size_t i = 0; i < cshape_array.size(); ++i) { |
1461 | oshape.push_back(cshape_array[i]); |
1462 | } |
1463 | reporter->Assign(types[0], TensorType(oshape, out_dtype)); |
1464 | return true; |
1465 | } |
1466 | |
1467 | Expr MakeZeros(Array<Integer> shape, DataType dtype) { |
1468 | auto attrs = make_object<InitOpAttrs>(); |
1469 | attrs->shape = std::move(shape); |
1470 | attrs->dtype = std::move(dtype); |
1471 | static const Op& op = Op::Get("zeros" ); |
1472 | return Call(op, {}, Attrs(attrs), {}); |
1473 | } |
1474 | |
1475 | TVM_REGISTER_GLOBAL("relay.op._make.zeros" ).set_body_typed(MakeZeros); |
1476 | |
1477 | RELAY_REGISTER_OP("zeros" ) |
1478 | .describe(R"code(Fill array with zeros. |
1479 | |
1480 | )code" TVM_ADD_FILELINE) |
1481 | .set_attrs_type<InitOpAttrs>() |
1482 | .set_num_inputs(0) |
1483 | .set_support_level(3) |
1484 | .add_type_rel("InitOp" , InitOpRel); |
1485 | |
1486 | Expr MakeOnes(Array<Integer> shape, DataType dtype) { |
1487 | auto attrs = make_object<InitOpAttrs>(); |
1488 | attrs->shape = std::move(shape); |
1489 | attrs->dtype = std::move(dtype); |
1490 | static const Op& op = Op::Get("ones" ); |
1491 | return Call(op, {}, Attrs(attrs), {}); |
1492 | } |
1493 | |
1494 | TVM_REGISTER_GLOBAL("relay.op._make.ones" ).set_body_typed(MakeOnes); |
1495 | |
1496 | RELAY_REGISTER_OP("ones" ) |
1497 | .describe(R"code(Fill array with ones. |
1498 | |
1499 | )code" TVM_ADD_FILELINE) |
1500 | .set_attrs_type<InitOpAttrs>() |
1501 | .set_num_inputs(0) |
1502 | .set_support_level(3) |
1503 | .add_type_rel("InitOp" , InitOpRel); |
1504 | |
1505 | bool FullLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1506 | const TypeReporter& reporter) { |
1507 | ICHECK_EQ(types.size(), 3); |
1508 | const auto* data = types[0].as<TensorTypeNode>(); |
1509 | if (data == nullptr) { |
1510 | return false; |
1511 | } |
1512 | const auto* fill_value = types[1].as<TensorTypeNode>(); |
1513 | if (fill_value == nullptr) { |
1514 | return false; |
1515 | } |
1516 | |
1517 | ICHECK_EQ(fill_value->shape.size(), 0) |
1518 | << "The fill value should be a scalar but here it has dimension " << fill_value->shape.size() |
1519 | << "." ; |
1520 | |
1521 | reporter->Assign(types[2], TensorType(data->shape, data->dtype)); |
1522 | return true; |
1523 | } |
1524 | |
1525 | Array<te::Tensor> FullLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
1526 | const Type& out_type) { |
1527 | return {topi::full_like(inputs[0], inputs[1]())}; |
1528 | } |
1529 | |
1530 | Expr MakeFullLike(Expr data, Expr fill_value) { |
1531 | static const Op& op = Op::Get("full_like" ); |
1532 | return Call(op, {data, fill_value}, Attrs(), {}); |
1533 | } |
1534 | |
1535 | TVM_REGISTER_GLOBAL("relay.op._make.full_like" ).set_body_typed(MakeFullLike); |
1536 | |
1537 | RELAY_REGISTER_OP("full_like" ) |
    .describe(R"code(Return an array of the same shape and type as the input array,
filled with the given scalar value.
1540 | |
1541 | )code" TVM_ADD_FILELINE) |
1542 | .set_num_inputs(2) |
1543 | .add_argument("data" , "Tensor" , "The input tensor." ) |
1544 | .add_argument("fill_value" , "double" , "Scalar value to fill." ) |
1545 | .set_support_level(3) |
1546 | .add_type_rel("FullLike" , FullLikeRel) |
1547 | .set_attr<FTVMCompute>("FTVMCompute" , FullLikeCompute) |
1548 | .set_attr<TOpPattern>("TOpPattern" , kElemWise); |
1549 | |
1550 | // arange operator |
1551 | TVM_REGISTER_NODE_TYPE(ArangeAttrs); |
1552 | |
1553 | bool ArangeRel(const Array<Type>& types, int num_inputs, const Attrs& raw_attrs, |
1554 | const TypeReporter& reporter) { |
1555 | ICHECK_EQ(types.size(), 4); |
1556 | const ArangeAttrs* attrs = raw_attrs.as<ArangeAttrs>(); |
1557 | const ConstantNode *cstart, *cstop, *cstep; |
1558 | |
1559 | reporter->Assign(types[0], types[1]); |
1560 | reporter->Assign(types[1], types[2]); |
1561 | reporter->Assign(types[2], TensorType({}, attrs->dtype)); |
1562 | |
1563 | if ((cstart = attrs->start.as<ConstantNode>()) && (cstop = attrs->stop.as<ConstantNode>()) && |
1564 | (cstep = attrs->step.as<ConstantNode>())) { |
1565 | double start = ToScalar(cstart->data); |
1566 | double stop = ToScalar(cstop->data); |
1567 | double step = ToScalar(cstep->data); |
1568 | int32_t num_elem = static_cast<int32_t>(std::ceil((stop - start) / step)); |
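    // For example, start = 1, stop = 10, step = 2 gives num_elem = ceil((10 - 1) / 2) = 5,
    // covering the values 1, 3, 5, 7 and 9.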
1569 | ICHECK_GT(num_elem, 0) << "Invalid arange attributes (start, stop, step): " << attrs->start |
1570 | << ", " << attrs->stop << ", " << attrs->step; |
1571 | reporter->Assign(types[3], TensorType({num_elem}, attrs->dtype)); |
1572 | return true; |
1573 | } else { |
1574 | reporter->Assign(types[3], TensorType({Any()}, attrs->dtype)); |
1575 | return true; |
1576 | } |
1577 | } |
1578 | |
1579 | inline te::Tensor DynamicArange(const te::Tensor& start, const te::Tensor& stop, |
1580 | const te::Tensor& step, tvm::DataType dtype, |
1581 | std::string name = "T_arange_dynamic" , |
1582 | std::string tag = topi::kInjective) { |
1583 | ICHECK_EQ(start.ndim(), 0); |
1584 | ICHECK_EQ(stop.ndim(), 0); |
1585 | ICHECK_EQ(step.ndim(), 0); |
1586 | tvm::PrimExpr num_elem = tvm::tir::Var("num_elem" ); |
1587 | return te::compute( |
1588 | {num_elem}, |
1589 | [&](const Array<tvm::tir::Var>& indices) { |
1590 | Array<PrimExpr> empty_indices; |
1591 | return tvm::cast(dtype, start(empty_indices) + step(empty_indices) * indices[0]); |
1592 | }, |
1593 | name, tag); |
1594 | } |
1595 | |
1596 | Array<te::Tensor> ArangeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
1597 | const Type& out_type) { |
1598 | const ArangeAttrs* param = attrs.as<ArangeAttrs>(); |
1599 | ICHECK(param != nullptr); |
1600 | te::Tensor start = inputs[0]; |
1601 | te::Tensor stop = inputs[1]; |
1602 | te::Tensor step = inputs[2]; |
1603 | return {DynamicArange(start, stop, step, param->dtype)}; |
1604 | } |
1605 | |
1606 | Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype) { |
1607 | auto attrs = make_object<ArangeAttrs>(); |
1608 | attrs->start = start; |
1609 | attrs->stop = stop; |
1610 | attrs->step = step; |
1611 | attrs->dtype = dtype; |
1612 | static const Op& op = Op::Get("arange" ); |
1613 | return Call(op, {start, stop, step}, Attrs(attrs), {}); |
1614 | } |
1615 | |
1616 | TVM_REGISTER_GLOBAL("relay.op._make.arange" ).set_body_typed(MakeArange); |
1617 | |
1618 | // An issue with the existing design is that we require dependency |
1619 | // to type the operator precisely. |
1620 | // |
1621 | // Supporting this in general is challenging so we duplicate the |
1622 | // secondary arguments as args and attributes. |
1623 | // |
1624 | // In this way reify the arguments at both the value and type level. |
1625 | // |
1626 | // In the case our arguments are constant we can immediately recover |
1627 | // the type of arange. |
1628 | // |
1629 | // In general I think we should avoid this pattern, and introduce |
1630 | // a secondary shape analysis to recover more precise information. |
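//
// For example (illustrative): arange with constant arguments such as
// (start=0, stop=16, step=2, dtype=float32) can be typed statically as
// Tensor[(8,), float32], while non-constant arguments fall back to
// Tensor[(?,), float32] and rely on the dynamic codegen path above.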
1631 | RELAY_REGISTER_OP("arange" ) |
1632 | .describe(R"code(Returns evenly spaced values within a given interval. |
1633 | |
1634 | )code" TVM_ADD_FILELINE) |
1635 | .set_attrs_type<ArangeAttrs>() |
1636 | .set_num_inputs(3) |
1637 | .add_argument("start" , "Expr" , "Start of interval. The interval includes this value." ) |
1638 | .add_argument("end" , "Expr" , "Stop of interval. The interval does not include this value." ) |
1639 | .add_argument("step" , "Expr" , "Spacing between values." ) |
1640 | .set_support_level(3) |
1641 | .add_type_rel("Arange" , ArangeRel) |
1642 | .set_attr<FTVMCompute>("FTVMCompute" , ArangeCompute) |
1643 | // TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape |
1644 | .set_attr<TOpPattern>("TOpPattern" , kOpaque) |
1645 | .set_attr<AnyCodegenStrategy>("AnyCodegenStrategy" , kVariableDimensions); |
1646 | |
1647 | // repeat operator |
1648 | TVM_REGISTER_NODE_TYPE(RepeatAttrs); |
1649 | |
1650 | bool RepeatRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1651 | const TypeReporter& reporter) { |
1652 | // `types` contains: [data, result] |
1653 | ICHECK_EQ(types.size(), 2); |
1654 | const auto* data = types[0].as<TensorTypeNode>(); |
1655 | if (data == nullptr) { |
1656 | ICHECK(types[0].as<IncompleteTypeNode>()) |
1657 | << "repeat: expect input type to be TensorType but get " << types[0]; |
1658 | return false; |
1659 | } |
1660 | const auto* param = attrs.as<RepeatAttrs>(); |
1661 | const int ndim = static_cast<int>(data->shape.size()); |
1662 | const int repeats = param->repeats.IntValue(); |
1663 | const int axis = param->axis.IntValue(); |
1664 | ICHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`" |
1665 | << ", but got repeats = " << repeats; |
1666 | ICHECK(-ndim - 1 <= axis && axis <= ndim) |
1667 | << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" |
1668 | << ", but got axis = " << axis << ", and data.ndim = " << ndim; |
1669 | const int pivot = axis < 0 ? ndim + axis : axis; |
1670 | std::vector<IndexExpr> oshape; |
1671 | oshape.reserve(ndim + repeats); |
1672 | for (int i = 0; i < pivot; ++i) { |
1673 | oshape.emplace_back(data->shape[i]); |
1674 | } |
1675 | if (data->shape[pivot].as<AnyNode>()) { |
1676 | oshape.emplace_back(Any()); |
1677 | } else { |
1678 | oshape.emplace_back(data->shape[pivot] * repeats); |
1679 | } |
1680 | for (int i = pivot + 1; i < ndim; ++i) { |
1681 | oshape.emplace_back(data->shape[i]); |
1682 | } |
1683 | reporter->Assign(types[1], TensorType(oshape, data->dtype)); |
1684 | return true; |
1685 | } |
1686 | |
1687 | Array<te::Tensor> RepeatCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
1688 | const Type& out_type) { |
1689 | const RepeatAttrs* param = attrs.as<RepeatAttrs>(); |
1690 | ICHECK(param != nullptr); |
1691 | return {topi::repeat(inputs[0], param->repeats.IntValue(), param->axis.IntValue())}; |
1692 | } |
1693 | |
1694 | Expr MakeRepeat(Expr data, int repeats, int axis) { |
1695 | auto attrs = make_object<RepeatAttrs>(); |
1696 | attrs->repeats = repeats; |
1697 | attrs->axis = axis; |
1698 | static const Op& op = Op::Get("repeat" ); |
1699 | return Call(op, {data}, Attrs(attrs), {}); |
1700 | } |
1701 | |
1702 | TVM_REGISTER_GLOBAL("relay.op._make.repeat" ).set_body_typed(MakeRepeat); |
1703 | |
1704 | RELAY_REGISTER_OP("repeat" ) |
1705 | .describe(R"code(Repeat elements of an array `repeats` times along axis `axis` |
1706 | |
1707 | - **data**: The input data to the operator. |
1708 | |
1709 | )code" TVM_ADD_FILELINE) |
1710 | .set_num_inputs(1) |
1711 | .set_attrs_type<RepeatAttrs>() |
1712 | .add_argument("data" , "Tensor" , "The input tensor." ) |
1713 | .set_support_level(3) |
1714 | .add_type_rel("Repeat" , RepeatRel) |
1715 | .set_attr<FTVMCompute>("FTVMCompute" , RepeatCompute) |
1716 | .set_attr<TOpPattern>("TOpPattern" , kBroadcast); |
1717 | |
1718 | bool SparseFillEmptyRowsRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1719 | const TypeReporter& reporter) { |
1720 | // types: [sparse_indices, sparse_values, dense_shape, default_value, result] |
  ICHECK_EQ(types.size(), 5) << "SparseFillEmptyRowsRel expects 5 types but " << types.size()
                             << " provided";
1723 | std::vector<Type> fields; |
1724 | auto sparse_indices = types[0].as<TensorTypeNode>(); |
1725 | auto ndims = sparse_indices->shape[1]; |
1726 | fields.push_back(TensorType(Array<PrimExpr>{Any(), ndims}, tvm::DataType::Int(64))); |
1727 | fields.push_back(TensorType(Array<PrimExpr>{Any()}, tvm::DataType::Int(64))); |
1728 | fields.push_back(TensorType(Array<PrimExpr>{Any()}, tvm::DataType::Int(64))); |
1729 | reporter->Assign(types[types.size() - 1], TupleType(Array<Type>(fields))); |
1730 | return true; |
1731 | } |
1732 | |
1733 | Expr MakeSparseFillEmptyRows(Expr sparse_indices, Expr sparse_values, Expr dense_shape, |
1734 | Expr default_value) { |
1735 | static const Op& op = Op::Get("sparse_fill_empty_rows" ); |
1736 | return Call(op, {sparse_indices, sparse_values, dense_shape, default_value}, Attrs(), {}); |
1737 | } |
1738 | |
1739 | TVM_REGISTER_GLOBAL("relay.op._make.sparse_fill_empty_rows" ) |
1740 | .set_body_typed(MakeSparseFillEmptyRows); |
1741 | |
1742 | RELAY_REGISTER_OP("sparse_fill_empty_rows" ) |
1743 | .describe( |
1744 | R"code(Fill empty rows of a sparse tensor with a default value.)code" TVM_ADD_FILELINE) |
1745 | .set_num_inputs(4) |
1746 | .add_argument("sparse_indices" , "Tensor" , |
1747 | "A 2-D int64 tensor of shape [N, ndims], which specifies the indices of the" |
1748 | "elements in the sparse tensor that contain nonzero values. COO Format" ) |
1749 | .add_argument( |
1750 | "sparse_values" , "Tensor" , |
1751 | "A 1-D tensor[N] which supplies the values for each element in indices. COO Format" ) |
1752 | .add_argument("dense_shape" , "Tensor" , |
1753 | "A 1-D int64 tensor of shape [ndims], which specifies the dense_shape of the" |
1754 | "sparse tensor. Takes a list indicating the number of elements in each " |
1755 | "dimension" ) |
1756 | .add_argument("default_value" , "Tensor" , |
1757 | "The value to fill for empty rows, with the same type as sparse_values" ) |
1758 | .add_type_rel("sparse_fill_empty_rows" , SparseFillEmptyRowsRel) |
1759 | .set_support_level(3) |
1760 | .set_attr<TOpPattern>("TOpPattern" , kOpaque); |
1761 | |
1762 | bool SparseReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1763 | const TypeReporter& reporter) { |
1764 | // types: [sparse_indices, prev_shape, new_shape, result] |
1765 | ICHECK_EQ(types.size(), 4) << "SparseReshapeRel expects 4 types but " << types.size() |
1766 | << " provided" ; |
  ICHECK_EQ(num_inputs, 3) << "SparseReshapeRel expects 3 inputs but " << num_inputs << " provided";
1768 | auto sparse_indices = types[0].as<TensorTypeNode>(); |
1769 | auto prev_shape = types[1].as<TensorTypeNode>(); |
1770 | auto new_shape = types[2].as<TensorTypeNode>(); |
1771 | if (sparse_indices == nullptr || prev_shape == nullptr || new_shape == nullptr) { |
1772 | return false; |
1773 | } |
  CHECK(sparse_indices->dtype.is_int()) << "sparse_indices must be a tensor of integers";
  CHECK(prev_shape->dtype.is_int()) << "prev_shape must be a tensor of integers";
  CHECK(new_shape->dtype.is_int()) << "new_shape must be a tensor of integers";
1777 | ICHECK_EQ(sparse_indices->shape.size(), 2) << "sparse_indices must be 2-D tensor" ; |
1778 | ICHECK_EQ(prev_shape->shape.size(), 1) << "prev_shape must be 1-D tensor" ; |
1779 | ICHECK_EQ(new_shape->shape.size(), 1) << "new_shape must be 1-D tensor" ; |
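  // For example, N nonzero elements reshaped from a rank-3 to a rank-2 dense shape produce
  // new sparse indices of shape (N, 2) together with the rank-2 new shape tensor.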
1780 | std::vector<Type> fields; |
1781 | Array<PrimExpr> new_sparse_indices_shape{sparse_indices->shape[0], new_shape->shape[0]}; |
1782 | fields.push_back(TensorType(new_sparse_indices_shape, sparse_indices->dtype)); |
1783 | fields.push_back(TensorType(new_shape->shape, new_shape->dtype)); |
1784 | reporter->Assign(types[3], TupleType(Array<Type>(fields))); |
1785 | return true; |
1786 | } |
1787 | |
1788 | Expr MakeSparseReshape(Expr sparse_indices, Expr prev_shape, Expr new_shape) { |
1789 | static const Op& op = Op::Get("sparse_reshape" ); |
1790 | return Call(op, {sparse_indices, prev_shape, new_shape}, Attrs(), {}); |
1791 | } |
1792 | |
1793 | TVM_REGISTER_GLOBAL("relay.op._make.sparse_reshape" ).set_body_typed(MakeSparseReshape); |
1794 | |
1795 | RELAY_REGISTER_OP("sparse_reshape" ) |
1796 | .describe(R"code(Return new sparse indices of the reshaped tensor |
1797 | )code" TVM_ADD_FILELINE) |
1798 | .set_num_inputs(3) |
1799 | .add_argument("sparse_indices" , "Tensor" , |
1800 | "A 2-D tensor of shape [N, ndims], which specifies the indices of the" |
1801 | "elements in the sparse tensor that contain nonzero values. COO Format" ) |
1802 | .add_argument("prev_shape" , "Tensor" , |
1803 | "A 1-D tensor of shape [ndims], which specifies the previous dense shape of the" |
1804 | "sparse tensor" ) |
1805 | .add_argument("new_shape" , "Tensor" , |
1806 | "A 1-D tensor of shape [ndims], which specifies the desired dense shape of the" |
1807 | "sparse tensor" ) |
1808 | .add_type_rel("sparse_reshape" , SparseReshapeRel) |
1809 | .set_attr<TOpPattern>("TOpPattern" , kInjective) |
1810 | .set_support_level(3); |
1811 | |
1812 | TVM_REGISTER_NODE_TYPE(StftAttrs); |
1813 | |
1814 | bool STFTRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1815 | const TypeReporter& reporter) { |
1816 | // types: [data, window, result] |
  ICHECK_EQ(types.size(), 3) << "STFTRel expects 3 types but " << types.size() << " provided";
  ICHECK_EQ(num_inputs, 2) << "STFTRel expects 2 inputs but " << num_inputs << " provided";
1819 | auto data = types[0].as<TensorTypeNode>(); |
1820 | if (data == nullptr) { |
    ICHECK(types[0].as<IncompleteTypeNode>())
        << "stft: expect input type to be TensorType but get " << types[0];
1823 | return false; |
1824 | } |
1825 | const auto* param = attrs.as<StftAttrs>(); |
1826 | const int ndim = static_cast<int>(data->shape.size()); |
1827 | std::vector<IndexExpr> oshape; |
1828 | int dim = 0; |
1829 | if (ndim == 2) { |
1830 | oshape.push_back(data->shape[0]); // batch dimension |
1831 | dim += 1; |
1832 | } |
1833 | oshape.push_back(param->onesided ? param->n_fft / 2 + 1 : param->n_fft); |
1834 | if (data->shape[dim].as<AnyNode>()) |
1835 | oshape.push_back(Any()); |
1836 | else |
1837 | oshape.push_back(indexdiv((data->shape[dim] - param->n_fft), param->hop_length) + |
1838 | 1); // n_frames |
1839 | oshape.push_back(2); |
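  // For example, a 1-D input of 16000 samples with n_fft = 400, hop_length = 160 and
  // onesided = true yields oshape = (201, 98, 2): 400 / 2 + 1 = 201 frequency bins,
  // (16000 - 400) / 160 + 1 = 98 frames, and 2 for the real/imaginary parts.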
1840 | reporter->Assign(types[2], TensorType(oshape, data->dtype)); |
1841 | return true; |
1842 | } |
1843 | |
1844 | Expr MakeSTFT(Expr data, int n_fft, int hop_length, int win_length, Expr window, bool normalized, |
1845 | bool onesided) { |
1846 | auto attrs = make_object<StftAttrs>(); |
1847 | attrs->n_fft = n_fft; |
1848 | attrs->hop_length = hop_length; |
1849 | attrs->win_length = win_length; |
1850 | attrs->normalized = normalized; |
1851 | attrs->onesided = onesided; |
1852 | static const Op& op = Op::Get("stft" ); |
1853 | return Call(op, {data, window}, Attrs(attrs), {}); |
1854 | } |
1855 | |
1856 | TVM_REGISTER_GLOBAL("relay.op._make.stft" ).set_body_typed(MakeSTFT); |
1857 | |
1858 | RELAY_REGISTER_OP("stft" ) |
1859 | .describe( |
1860 | R"code(The STFT computes the Fourier transform of short overlapping windows of the input. |
1861 | )code" TVM_ADD_FILELINE) |
1862 | .set_num_inputs(2) |
1863 | .add_argument("data" , "Tensor" , "the input tensor" ) |
1864 | .add_argument("window" , "Tensor" , "the optional window function" ) |
1865 | .add_type_rel("stft" , STFTRel) |
1866 | .set_support_level(3) |
1867 | .set_attr<TOpPattern>("TOpPattern" , kOpaque); |
1868 | |
1869 | // meshgrid operator |
1870 | TVM_REGISTER_NODE_TYPE(MeshgridAttrs); |
1871 | |
1872 | bool MeshgridRel(const Array<Type>& types, int num_inputs, const Attrs& raw_attrs, |
1873 | const TypeReporter& reporter) { |
1874 | // types: [data, result] |
1875 | ICHECK_EQ(types.size(), 2); |
1876 | const MeshgridAttrs* attrs = raw_attrs.as<MeshgridAttrs>(); |
1877 | const auto* tensor_tuple = types[0].as<TupleTypeNode>(); |
1878 | if (tensor_tuple == nullptr) { |
1879 | throw CompileError(ErrorBuilder() |
1880 | << "meshgrid requires a tuple of tensors as the first argument, found " |
1881 | << PrettyPrint(types[0])); |
1882 | } else if (types[0].as<IncompleteTypeNode>() != nullptr) { |
1883 | return false; |
1884 | } |
1885 | const int data_length = static_cast<int>(tensor_tuple->fields.size()); |
1886 | |
1887 | // Get first dtype. |
1888 | const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]); |
1889 | const DataType dtype = first->dtype; |
1890 | |
1891 | // Get size of output grid. |
1892 | std::vector<IndexExpr> grid_shape; |
1893 | grid_shape.reserve(data_length); |
1894 | for (const Type& ele : tensor_tuple->fields) { |
1895 | if (ele.as<IncompleteTypeNode>()) { |
1896 | return false; |
1897 | } |
1898 | const auto& e = Downcast<TensorType>(ele); |
1899 | int e_ndim = static_cast<int>(e->shape.size()); |
1900 | const DataType& e_dtype = e->dtype; |
1901 | if (e_dtype != dtype) { |
      throw CompileError("relay.meshgrid requires all tensors to have the same dtype");
1903 | } |
1904 | if (e_ndim == 0) { |
1905 | grid_shape.emplace_back(1); |
1906 | } else if (e_ndim == 1) { |
1907 | grid_shape.emplace_back(e->shape[0]); |
1908 | } else { |
      throw CompileError("relay.meshgrid requires all tensors to be either scalars or 1-D vectors.");
1910 | } |
1911 | } |
1912 | |
1913 | // "xy" mode swaps first two dimensions |
1914 | if (attrs->indexing == "xy" && grid_shape.size() >= 2) { |
1915 | std::swap(grid_shape[0], grid_shape[1]); |
1916 | } |
1917 | |
1918 | // There is one output grid for each input, all with same shape. |
1919 | std::vector<Type> grids; |
1920 | grids.reserve(data_length); |
1921 | for (int i = 0; i < data_length; i++) { |
1922 | grids.emplace_back(TensorType(grid_shape, dtype)); |
1923 | } |
1924 | reporter->Assign(types[1], TupleType(Array<Type>(grids))); |
1925 | return true; |
1926 | } |
1927 | |
1928 | Array<te::Tensor> MeshgridCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
1929 | const Type& out_type) { |
1930 | const MeshgridAttrs* param = attrs.as<MeshgridAttrs>(); |
1931 | ICHECK(param != nullptr); |
1932 | return {topi::meshgrid(inputs, param->indexing)}; |
1933 | } |
1934 | |
1935 | Expr MakeMeshgrid(Expr data, String indexing) { |
1936 | auto attrs = make_object<MeshgridAttrs>(); |
1937 | attrs->indexing = std::move(indexing); |
1938 | static const Op& op = Op::Get("meshgrid" ); |
1939 | return Call(op, {data}, Attrs(attrs), {}); |
1940 | } |
1941 | |
1942 | TVM_REGISTER_GLOBAL("relay.op._make.meshgrid" ).set_body_typed(MakeMeshgrid); |
1943 | |
1944 | RELAY_REGISTER_OP("meshgrid" ) |
1945 | .describe(R"code(Create coordinate matrices from coordinate vectors. |
1946 | |
1947 | )code" TVM_ADD_FILELINE) |
1948 | .set_attrs_type<MeshgridAttrs>() |
1949 | .set_num_inputs(1) |
1950 | .add_argument("data" , "Tensor" , "The input list of tensors." ) |
1951 | .set_support_level(3) |
1952 | .add_type_rel("Meshgrid" , MeshgridRel) |
1953 | .set_attr<FTVMCompute>("FTVMCompute" , MeshgridCompute) |
1954 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
1955 | |
1956 | // tile operator |
1957 | TVM_REGISTER_NODE_TYPE(TileAttrs); |
1958 | |
1959 | bool TileRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
1960 | const TypeReporter& reporter) { |
1961 | // `types` contains: [data, result] |
1962 | ICHECK_EQ(types.size(), 2); |
1963 | const auto* data = types[0].as<TensorTypeNode>(); |
1964 | if (data == nullptr) { |
1965 | ICHECK(types[0].as<IncompleteTypeNode>()) |
1966 | << "tile: expect input type to be TensorType but get " << types[0]; |
1967 | return false; |
1968 | } |
1969 | const auto* param = attrs.as<TileAttrs>(); |
1970 | const size_t ndim = data->shape.size(); |
1971 | const Array<Integer>& reps = param->reps; |
1972 | // check dimension match |
1973 | ICHECK(reps.defined()) << "repetition array is not defined. data.ndim = " << ndim; |
1974 | const size_t rndim = reps.size(); |
1975 | for (size_t i = 0; i < rndim; ++i) { |
1976 | if (const tvm::tir::IntImmNode* val = reps[i].as<tvm::tir::IntImmNode>()) { |
      ICHECK_GT(val->value, 0) << "Tile reps value should always be larger than 0, but got: "
1978 | << val->value; |
1979 | } |
1980 | } |
1981 | size_t tndim = (ndim > rndim) ? ndim : rndim; |
1982 | // re-construct data shape or reps shape |
1983 | std::vector<IndexExpr> data_shape; |
1984 | std::vector<IndexExpr> reps_shape; |
1985 | data_shape.reserve(tndim); |
1986 | reps_shape.reserve(tndim); |
1987 | if (ndim == rndim) { |
1988 | for (size_t i = 0; i < tndim; ++i) { |
1989 | data_shape.emplace_back(data->shape[i]); |
1990 | reps_shape.emplace_back(reps[i]); |
1991 | } |
1992 | } else if (ndim > rndim) { |
1993 | for (size_t i = 0; i < ndim; ++i) { |
1994 | data_shape.emplace_back(data->shape[i]); |
1995 | } |
1996 | for (size_t i = 0; i < (ndim - rndim); ++i) { |
1997 | reps_shape.emplace_back(1); |
1998 | } |
1999 | for (size_t i = 0; i < rndim; ++i) { |
2000 | reps_shape.emplace_back(reps[i]); |
2001 | } |
2002 | } else { |
2003 | for (size_t i = 0; i < rndim; ++i) { |
2004 | reps_shape.emplace_back(reps[i]); |
2005 | } |
2006 | for (size_t i = 0; i < (rndim - ndim); ++i) { |
2007 | data_shape.emplace_back(1); |
2008 | } |
2009 | for (size_t i = 0; i < ndim; ++i) { |
2010 | data_shape.emplace_back(data->shape[i]); |
2011 | } |
2012 | } |
2013 | std::vector<IndexExpr> oshape; |
2014 | oshape.reserve(tndim); |
2015 | for (size_t i = 0; i < tndim; ++i) { |
    // Preserve Any for dynamic dimensions
2017 | if (!data_shape[i].as<IntImmNode>()) { |
2018 | oshape.emplace_back(Any()); |
2019 | } else { |
2020 | oshape.emplace_back(data_shape[i] * reps_shape[i]); |
2021 | } |
2022 | } |
2023 | reporter->Assign(types[1], TensorType(oshape, data->dtype)); |
2024 | return true; |
2025 | } |
2026 | |
2027 | Array<te::Tensor> TileCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
2028 | const Type& out_type) { |
2029 | const TileAttrs* param = attrs.as<TileAttrs>(); |
2030 | ICHECK(param != nullptr); |
2031 | return {topi::tile(inputs[0], param->reps)}; |
2032 | } |
2033 | |
2034 | Expr MakeTile(Expr data, Array<Integer> reps) { |
2035 | auto attrs = make_object<TileAttrs>(); |
2036 | attrs->reps = reps; |
2037 | static const Op& op = Op::Get("tile" ); |
2038 | return Call(op, {data}, Attrs(attrs), {}); |
2039 | } |
2040 | |
2041 | TVM_REGISTER_GLOBAL("relay.op._make.tile" ).set_body_typed(MakeTile); |
2042 | |
2043 | RELAY_REGISTER_OP("tile" ) |
2044 | .describe(R"code(Repeat the whole array multiple times. |
2045 | |
2046 | - **data**: The input data to the operator. |
2047 | |
2048 | )code" TVM_ADD_FILELINE) |
2049 | .set_num_inputs(1) |
2050 | .set_attrs_type<TileAttrs>() |
2051 | .add_argument("data" , "Tensor" , "The input tensor." ) |
2052 | .set_support_level(3) |
2053 | .add_type_rel("Tile" , TileRel) |
2054 | .set_attr<FTVMCompute>("FTVMCompute" , TileCompute) |
2055 | .set_attr<TOpPattern>("TOpPattern" , kBroadcast); |
2056 | |
2057 | // reverse operator |
2058 | TVM_REGISTER_NODE_TYPE(ReverseAttrs); |
2059 | |
2060 | bool ReverseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
2061 | const TypeReporter& reporter) { |
2062 | // `types` contains: [data, result] |
2063 | ICHECK_EQ(types.size(), 2); |
2064 | const auto* data = types[0].as<TensorTypeNode>(); |
2065 | if (data == nullptr) { |
2066 | ICHECK(types[0].as<IncompleteTypeNode>()) |
2067 | << "reverse: expect input type to be TensorType but get " << types[0]; |
2068 | return false; |
2069 | } |
2070 | const auto* param = attrs.as<ReverseAttrs>(); |
2071 | const int ndim = static_cast<int>(data->shape.size()); |
2072 | const int axis = param->axis.IntValue(); |
2073 | ICHECK(-ndim <= axis && axis < ndim) |
2074 | << "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]" |
2075 | << ", but got axis = " << axis << ", and data.ndim = " << ndim; |
2076 | reporter->Assign(types[1], types[0]); |
2077 | return true; |
2078 | } |
2079 | |
2080 | Array<te::Tensor> ReverseCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
2081 | const Type& out_type) { |
2082 | const ReverseAttrs* param = attrs.as<ReverseAttrs>(); |
2083 | ICHECK(param != nullptr); |
2084 | // pass empty seq_length tensor to reverse_sequence |
2085 | return {topi::reverse_sequence(inputs[0], te::Tensor(), param->axis.IntValue())}; |
2086 | } |
2087 | |
2088 | Expr MakeReverse(Expr data, int axis) { |
2089 | auto attrs = make_object<ReverseAttrs>(); |
2090 | attrs->axis = axis; |
2091 | static const Op& op = Op::Get("reverse" ); |
2092 | return Call(op, {data}, Attrs(attrs), {}); |
2093 | } |
2094 | |
2095 | TVM_REGISTER_GLOBAL("relay.op._make.reverse" ).set_body_typed(MakeReverse); |
2096 | |
2097 | RELAY_REGISTER_OP("reverse" ) |
2098 | .describe(R"code(Reverses the order of elements along given `axis` while preserving array shape. |
2099 | |
2100 | - **data**: The input data to the operator. |
2101 | |
2102 | )code" TVM_ADD_FILELINE) |
2103 | .set_num_inputs(1) |
2104 | .set_attrs_type<ReverseAttrs>() |
2105 | .add_argument("data" , "Tensor" , "The input tensor." ) |
2106 | .set_support_level(3) |
2107 | .add_type_rel("Reverse" , ReverseRel) |
2108 | .set_attr<FTVMCompute>("FTVMCompute" , ReverseCompute) |
2109 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
2110 | |
2111 | // reverse sequence operator |
2112 | TVM_REGISTER_NODE_TYPE(ReverseSequenceAttrs); |
2113 | |
2114 | bool ReverseSequenceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
2115 | const TypeReporter& reporter) { |
2116 | // `types` contains: [data, seq_lengths, result] |
2117 | ICHECK_EQ(types.size(), 3); |
2118 | const auto* data = types[0].as<TensorTypeNode>(); |
2119 | |
2120 | if (data == nullptr) { |
2121 | ICHECK(types[0].as<IncompleteTypeNode>()) |
2122 | << "reverse_sequence: expect input type to be TensorType but get " << types[0]; |
2123 | return false; |
2124 | } |
2125 | |
2126 | const auto* seq_lengths = types[1].as<TensorTypeNode>(); |
2127 | if (seq_lengths == nullptr) { |
2128 | ICHECK(types[1].as<IncompleteTypeNode>()) |
2129 | << "reverse_sequence: expect input type to be TensorType but get " << types[1]; |
2130 | return false; |
2131 | } |
2132 | |
2133 | const int seq_lengths_dim = static_cast<int>(seq_lengths->shape.size()); |
  ICHECK(seq_lengths_dim == 1) << "For reverse_sequence, seq_lengths must be a 1D vector";
  ICHECK(seq_lengths->dtype.is_int())
      << "For reverse_sequence, seq_lengths must be a tensor of integers";
2137 | |
2138 | const auto* param = attrs.as<ReverseSequenceAttrs>(); |
2139 | const int ndim = static_cast<int>(data->shape.size()); |
2140 | int batch_axis = param->batch_axis.IntValue(); |
2141 | ICHECK(-ndim <= batch_axis && batch_axis < ndim) |
2142 | << "reverse_sequence only accepts `batch_axis` in [-data.ndim, data.ndim - 1]" |
2143 | << ", but got batch_axis = " << batch_axis << ", and data.ndim = " << ndim; |
2144 | |
2145 | if (batch_axis < 0) { |
2146 | batch_axis = static_cast<int>(data->shape.size()) + batch_axis; |
2147 | } |
  ICHECK(reporter->Assert(seq_lengths->shape[0] == data->shape[batch_axis]))
      << "For reverse_sequence, seq_lengths size should match the dimension of the batch axis"
      << ", but got dimension of batch_axis = " << data->shape[batch_axis]
      << ", and seq_lengths size = " << seq_lengths->shape[0];
2152 | |
2153 | const int seq_axis = param->seq_axis.IntValue(); |
  ICHECK(-ndim <= seq_axis && seq_axis < ndim)
      << "reverse_sequence only accepts `seq_axis` in [-data.ndim, data.ndim - 1]"
      << ", but got seq_axis = " << seq_axis << ", and data.ndim = " << ndim;
2157 | |
2158 | reporter->Assign(types[2], types[0]); |
2159 | return true; |
2160 | } |
2161 | |
2162 | Array<te::Tensor> ReverseSequenceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
2163 | const Type& out_type) { |
2164 | const ReverseSequenceAttrs* param = attrs.as<ReverseSequenceAttrs>(); |
2165 | ICHECK(param != nullptr); |
2166 | return {topi::reverse_sequence(inputs[0], inputs[1], param->seq_axis.IntValue(), |
2167 | param->batch_axis.IntValue())}; |
2168 | } |
2169 | |
2170 | Expr MakeReverseSequence(Expr data, Expr seq_lengths, int seq_axis, int batch_axis) { |
2171 | auto attrs = make_object<ReverseSequenceAttrs>(); |
2172 | attrs->seq_axis = seq_axis; |
2173 | attrs->batch_axis = batch_axis; |
2174 | static const Op& op = Op::Get("reverse_sequence" ); |
2175 | return Call(op, {data, seq_lengths}, Attrs(attrs), {}); |
2176 | } |
2177 | |
2178 | TVM_REGISTER_GLOBAL("relay.op._make.reverse_sequence" ).set_body_typed(MakeReverseSequence); |
2179 | |
2180 | RELAY_REGISTER_OP("reverse_sequence" ) |
2181 | .describe(R"code(Reverses the tensor for variable length slices. |
2182 | Input is first sliced along batch axis and then elements are reversed along seq axis. |
2183 | |
2184 | - **data**: The input data to the operator. |
2185 | |
2186 | - **seq_lengths**: A 1D Tensor with length data.dims[batch_axis]. |
2187 | |
2188 | - **seq_axis**: The axis along which the elements will be reversed. Default is 1. |
2189 | |
2190 | - **batch_axis**: The axis along which the tensor will be sliced. Default is 0. |
2191 | |
2192 | )code" TVM_ADD_FILELINE) |
2193 | .set_num_inputs(2) |
2194 | .set_attrs_type<ReverseSequenceAttrs>() |
2195 | .add_argument("data" , "Tensor" , "The input tensor." ) |
2196 | .add_argument("seq_lengths" , "Tensor" , "A 1D Tensor with length data.dims[batch_axis]" ) |
2197 | .set_support_level(3) |
2198 | .add_type_rel("ReverseSequence" , ReverseSequenceRel) |
2199 | .set_attr<FTVMCompute>("FTVMCompute" , ReverseSequenceCompute) |
2200 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
2201 | |
2202 | // where operator |
2203 | bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
2204 | const TypeReporter& reporter) { |
2205 | ICHECK_EQ(types.size(), 4U); |
2206 | const auto* condition = types[0].as<TensorTypeNode>(); |
2207 | const auto* x = types[1].as<TensorTypeNode>(); |
2208 | const auto* y = types[2].as<TensorTypeNode>(); |
2209 | |
2210 | if (condition == nullptr || x == nullptr || y == nullptr) { |
2211 | return false; |
2212 | } |
2213 | |
2214 | ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs " |
2215 | << y->dtype; |
2216 | |
2217 | auto tensor_ty_condition = GetRef<TensorType>(condition); |
2218 | auto tensor_ty_x = GetRef<TensorType>(x); |
2219 | auto tensor_ty_y = GetRef<TensorType>(y); |
2220 | |
2221 | auto b_ty = ConcreteBroadcast(tensor_ty_x, tensor_ty_y, x->dtype); |
2222 | auto ret_ty = ConcreteBroadcast(tensor_ty_condition, b_ty, b_ty->dtype); |
2223 | |
2224 | reporter->Assign(types[3], ret_ty); |
2225 | return true; |
2226 | } |
2227 | |
2228 | // Positional relay function to create where operator. |
2229 | Expr MakeWhere(const Expr& condition, const Expr& x, const Expr& y) { |
2230 | static const Op& op = Op::Get("where" ); |
2231 | return Call(op, {condition, x, y}); |
2232 | } |
2233 | |
2234 | Array<te::Tensor> WhereCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
2235 | const Type& out_type) { |
2236 | return {topi::where(inputs[0], inputs[1], inputs[2])}; |
2237 | } |
2238 | |
2239 | TVM_REGISTER_GLOBAL("relay.op._make.where" ).set_body_typed(MakeWhere); |
2240 | |
2241 | RELAY_REGISTER_OP("where" ) |
2242 | .describe(R"code( |
2243 | Return the elements, either from x or y, depending on the condition. |
2244 | |
2245 | Given three ndarrays, condition, x, and y, return an ndarray with the elements |
from x or y, depending on whether the corresponding elements of condition are true or false.
2247 | |
2248 | Shapes of condition, x, and y must be broadcastable to a common shape, which |
is the output shape of this op. Semantics follow the numpy where function:
2250 | https://numpy.org/doc/stable/reference/generated/numpy.where.html |
2251 | |
2252 | Note that all non-zero values are interpreted as True in condition. |
2253 | |
2254 | Examples:: |
2255 | |
2256 | x = [[1, 2], [3, 4]] |
2257 | y = [[5, 6], [7, 8]] |
2258 | cond = [[0, 1], [-1, 0]] |
2259 | where(cond, x, y) = [[5, 2], [3, 8]] |
2260 | |
2261 | |
2262 | cond = [[1], [0]] |
2263 | where(cond, x, y) = [[1, 2], [7, 8]] |
2264 | |
2265 | cond = [0, 1] |
2266 | where(cond, 1, -1) = [-1, 1] |
2267 | |
2268 | )code" TVM_ADD_FILELINE) |
2269 | .add_argument("condition" , "Tensor" , "Condition array" ) |
2270 | .add_argument("x" , "Tensor" , "First array to be selected" ) |
2271 | .add_argument("y" , "Tensor" , "Second array to be selected" ) |
2272 | .set_num_inputs(3) |
2273 | .set_support_level(4) |
2274 | .add_type_rel("Where" , WhereRel) |
2275 | .set_attr<FTVMCompute>("FTVMCompute" , WhereCompute) |
2276 | .set_attr<TOpPattern>("TOpPattern" , kBroadcast); |
2277 | |
2278 | // Squeeze |
2279 | TVM_REGISTER_NODE_TYPE(SqueezeAttrs); |
2280 | |
2281 | Expr MakeSqueeze(Expr data, Array<Integer> axis) { |
2282 | auto attrs = make_object<SqueezeAttrs>(); |
2283 | attrs->axis = std::move(axis); |
2284 | static const Op& op = Op::Get("squeeze" ); |
2285 | return Call(op, {data}, Attrs(attrs), {}); |
2286 | } |
2287 | |
2288 | TVM_REGISTER_GLOBAL("relay.op._make.squeeze" ).set_body_typed(MakeSqueeze); |
2289 | |
2290 | bool SqueezeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
2291 | const TypeReporter& reporter) { |
2292 | ICHECK_EQ(types.size(), 2); |
2293 | const auto* data = types[0].as<TensorTypeNode>(); |
2294 | if (data == nullptr) { |
2295 | return false; |
2296 | } |
2297 | const auto* param = attrs.as<SqueezeAttrs>(); |
2298 | ICHECK(param != nullptr); |
2299 | std::vector<IndexExpr> result_shape; |
2300 | // if axes is None, squeeze all axes of dimension 1 |
2301 | if (!param->axis.defined()) { |
2302 | for (const auto& e : data->shape) { |
2303 | if (!e.as<IntImmNode>()) { |
2304 | LOG(FATAL) << "axis needs to be defined for dynamic input." ; |
2305 | } |
2306 | const int64_t* axis_ptr = tir::as_const_int(e); |
2307 | ICHECK(axis_ptr != nullptr) << "the axes attribute must be concrete" ; |
2308 | if (*axis_ptr != 1) { |
2309 | result_shape.push_back(e); |
2310 | } |
2311 | } |
2312 | } else { |
    // pair up original shape with a boolean which controls whether it will be in the final shape.
2314 | std::vector<std::pair<IndexExpr, bool>> original_shape; |
2315 | for (const auto& e : data->shape) { |
2316 | original_shape.push_back(std::pair<IndexExpr, bool>(e, true)); |
2317 | } |
2318 | for (const auto& e : param->axis) { |
2319 | int64_t axis_val = e->value; |
2320 | if (axis_val < 0) { |
2321 | axis_val += static_cast<int64_t>(original_shape.size()); |
2322 | } |
2323 | ICHECK_GE(axis_val, 0); |
2324 | ICHECK_LT(axis_val, original_shape.size()); |
2325 | original_shape.at(axis_val).second = false; |
2326 | } |
2327 | for (const auto& p : original_shape) { |
2328 | if (p.second) { |
2329 | result_shape.push_back(p.first); |
2330 | } else { |
2331 | if (const int64_t* axis_ptr = tir::as_const_int(p.first)) { |
2332 | ICHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not equal to 1" ; |
2333 | } |
2334 | } |
2335 | } |
2336 | } |
2337 | reporter->Assign(types[1], TensorType(result_shape, data->dtype)); |
2338 | return true; |
2339 | } |
2340 | |
2341 | Array<te::Tensor> SqueezeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
2342 | const Type& out_type) { |
2343 | const SqueezeAttrs* param = attrs.as<SqueezeAttrs>(); |
2344 | ICHECK(param != nullptr); |
2345 | return {topi::squeeze(inputs[0], param->axis)}; |
2346 | } |
2347 | |
2348 | InferCorrectLayoutOutput SqueezeInferCorrectLayout(const Attrs& attrs, |
2349 | const Array<Layout>& new_in_layouts, |
2350 | const Array<Layout>& old_in_layouts, |
2351 | const Array<tvm::relay::Type>& old_in_types) { |
2352 | const auto* attrs_ptr = attrs.as<SqueezeAttrs>(); |
2353 | ICHECK(attrs_ptr); |
2354 | ObjectPtr<SqueezeAttrs> params = make_object<SqueezeAttrs>(*attrs_ptr); |
2355 | |
2356 | Layout inferred_input = new_in_layouts.defined() ? new_in_layouts[0] : old_in_layouts[0]; |
2357 | Layout inferred_output = inferred_input; |
2358 | |
2359 | ICHECK(old_in_types[0].as<TensorTypeNode>()); |
2360 | const auto& shape = old_in_types[0].as<TensorTypeNode>()->shape; |
2361 | |
2362 | // axis to squeeze |
2363 | Array<Integer> axis; |
2364 | if (params->axis.defined()) { |
2365 | axis = params->axis; |
2366 | } else { |
2367 | // if axes is None, squeeze all axes of dimension 1 |
2368 | for (size_t i = 0; i < shape.size(); i++) { |
2369 | if (topi::detail::GetConstInt(shape[i]) == 1) { |
2370 | axis.push_back(i); |
2371 | } |
2372 | } |
2373 | } |
2374 | |
2375 | // If new_in_layouts are defined, this code tries to modify the layout |
2376 | if (new_in_layouts.defined() && old_in_layouts.defined()) { |
2377 | Array<Integer> new_axis; |
2378 | for (const auto& e : axis) { |
2379 | const auto& dim = old_in_layouts[0][e.IntValue()]; |
2380 | new_axis.push_back((new_in_layouts[0]).IndexOf(dim)); |
2381 | } |
2382 | params->axis = new_axis; |
2383 | axis = new_axis; |
2384 | } |
2385 | |
2386 | // Infer output layout |
2387 | Array<tir::IterVar> kept_axes; |
2388 | for (size_t i = 0; i < inferred_input.ndim(); i++) { |
2389 | bool is_dim_kept = true; |
2390 | |
2391 | // Check whether the dim should be kept |
2392 | for (const auto& e : axis) { |
2393 | int64_t axis_val = e->value; |
2394 | if (axis_val < 0) { |
2395 | axis_val += inferred_input.ndim(); |
2396 | } |
2397 | if (static_cast<int64_t>(i) == axis_val) { |
2398 | is_dim_kept = false; |
2399 | break; |
2400 | } |
2401 | } |
2402 | |
2403 | if (is_dim_kept) { |
2404 | kept_axes.push_back(inferred_input->axes[i]); |
2405 | } |
2406 | } |
2407 | inferred_output = Layout(kept_axes); |
2408 | |
2409 | return InferCorrectLayoutOutput({inferred_input}, {inferred_output}, Attrs(params)); |
2410 | } |
2411 | |
2412 | RELAY_REGISTER_OP("squeeze" ) |
2413 | .describe(R"code(Squeeze the input tensor at the dimensions given by axes |
2414 | |
2415 | - **data**: The input data to the operator. |
2416 | |
2417 | )code" TVM_ADD_FILELINE) |
2418 | .set_num_inputs(1) |
2419 | .set_attrs_type<SqueezeAttrs>() |
2420 | .add_argument("data" , "Tensor" , "The input tensor." ) |
2421 | .set_support_level(3) |
2422 | .add_type_rel("Squeeze" , SqueezeRel) |
2423 | .set_attr<FTVMCompute>("FTVMCompute" , SqueezeCompute) |
2424 | .set_attr<TOpPattern>("TOpPattern" , kInjective) |
2425 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , SqueezeInferCorrectLayout) |
2426 | .set_attr<TReshapeOp>("TReshapeOp" , true); |
2427 | |
2428 | // CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A |
2429 | bool CollapseSumLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
2430 | const TypeReporter& reporter) { |
2431 | ICHECK_EQ(types.size(), 3); |
2432 | reporter->Assign(types[2], types[1]); |
2433 | return BroadcastRel({types[0], types[1], types[0]}, 2, Attrs(), reporter); |
2434 | } |
2435 | |
2436 | Expr MakeCollapseSumLike(Expr data, Expr collapse_type) { |
2437 | static const Op& op = Op::Get("collapse_sum_like" ); |
2438 | return Call(op, {data, collapse_type}, Attrs(), {}); |
2439 | } |
2440 | |
2441 | Array<te::Tensor> CollapseSumLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
2442 | const Type& out_type) { |
2443 | const auto* out_ttype = out_type.as<TensorTypeNode>(); |
2444 | ICHECK(out_ttype != nullptr); |
2445 | return {topi::collapse_sum(inputs[0], out_ttype->shape)}; |
2446 | } |
2447 | |
2448 | TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_like" ).set_body_typed(MakeCollapseSumLike); |
2449 | |
2450 | RELAY_REGISTER_OP("collapse_sum_like" ) |
2451 | .describe(R"code(Collapse the first input to match the shape of the second input. |
2452 | )code" TVM_ADD_FILELINE) |
2453 | .set_num_inputs(2) |
2454 | .add_argument("data" , "Tensor" , "The input tensor." ) |
2455 | .add_argument("collapse_type" , "Tensor" , "Provide the type to collapse to." ) |
2456 | .set_support_level(10) |
2457 | .add_type_rel("CollapseSumLike" , CollapseSumLikeRel) |
2458 | .set_attr<FTVMCompute>("FTVMCompute" , CollapseSumLikeCompute) |
2459 | .set_attr<TOpPattern>("TOpPattern" , kCommReduce); |
2460 | |
2461 | // CollapseSumTo: <A, B> -> B where Broadcast(A, B) = A |
2462 | bool CollapseSumToRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
2463 | const TypeReporter& reporter) { |
2464 | ICHECK_EQ(types.size(), 3); |
2465 | const InitOpAttrs* param = attrs.as<InitOpAttrs>(); |
2466 | |
2467 | const auto* target_shape = types[1].as<TensorTypeNode>(); |
2468 | DataType out_dtype = types[0].as<TensorTypeNode>()->dtype; |
2469 | |
2470 | const IntImmNode* rank = target_shape->shape[0].as<IntImmNode>(); |
2471 | ICHECK(rank) << "Parameter must have static rank" ; |
2472 | |
2473 | std::vector<IndexExpr> oshape; |
2474 | if (param->shape) { |
2475 | const Array<Integer>& cshape_array = param->shape.value(); |
2476 | for (size_t i = 0; i < cshape_array.size(); i++) { |
2477 | oshape.push_back(cshape_array[i]); |
2478 | } |
2479 | } else { |
2480 | for (int i = 0; i < rank->value; i++) { |
2481 | oshape.push_back(Any()); |
2482 | } |
2483 | } |
2484 | reporter->Assign(types[2], TensorType(oshape, out_dtype)); |
2485 | return BroadcastRel({types[0], types[2], types[0]}, 2, Attrs(), reporter); |
2486 | } |
2487 | |
2488 | Expr MakeCollapseSumTo(Expr data, Expr shape) { |
2489 | static const Op& op = Op::Get("collapse_sum_to" ); |
2490 | auto attrs = make_object<InitOpAttrs>(); |
2491 | if (const auto* cshape = shape.as<ConstantNode>()) { |
2492 | attrs->shape = ToVector(cshape->data); |
2493 | } |
2494 | return Call(op, {data, shape}, Attrs(attrs), {}); |
2495 | } |
2496 | |
2497 | TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_to" ).set_body_typed(MakeCollapseSumTo); |
2498 | |
2499 | RELAY_REGISTER_OP("collapse_sum_to" ) |
2500 | .describe(R"code(Broadcast the first input to match the shape argument. |
2501 | )code" TVM_ADD_FILELINE) |
2502 | .set_num_inputs(2) |
2503 | .add_argument("data" , "Tensor" , "The input tensor." ) |
2504 | .add_argument("shape" , "Tensor" , "Target shape." ) |
2505 | .set_support_level(4) |
2506 | .add_type_rel("CollapseSumTo" , CollapseSumToRel) |
2507 | .set_attr<FTVMCompute>("FTVMCompute" , CollapseSumLikeCompute) |
2508 | .set_attr<TOpPattern>("TOpPattern" , kCommReduce); |
2509 | |
2510 | bool BroadCastToRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
2511 | const TypeReporter& reporter) { |
  // types = [data_type, ret_type]; the broadcast target shape is carried in attrs since it is static
2513 | ICHECK_EQ(types.size(), 2); |
2514 | |
2515 | const InitOpAttrs* param = attrs.as<InitOpAttrs>(); |
2516 | ICHECK(param); |
2517 | |
2518 | DataType out_dtype; |
2519 | if (auto ttype = types[0].as<TensorTypeNode>()) { |
2520 | out_dtype = ttype->dtype; |
2521 | } else { |
2522 | ICHECK(types[0].as<IncompleteTypeNode>()) |
2523 | << "Broadcast: expect to be TensorType but get " << types[0]; |
2524 | return false; |
2525 | } |
2526 | |
2527 | std::vector<IndexExpr> oshape; |
2528 | |
2529 | const Array<Integer>& cshape_array = param->shape.value(); |
2530 | for (size_t i = 0; i < cshape_array.size(); ++i) { |
2531 | oshape.push_back(cshape_array[i]); |
2532 | } |
2533 | reporter->Assign(types[1], TensorType(oshape, out_dtype)); |
2534 | return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter); |
2535 | } |
2536 | |
2537 | Expr MakeBroadCastTo(Expr data, Array<Integer> shape) { |
2538 | static const Op& op = Op::Get("broadcast_to" ); |
2539 | auto attrs = make_object<InitOpAttrs>(); |
2540 | |
2541 | attrs->shape = std::move(shape); |
2542 | return Call(op, {data}, Attrs(attrs), {}); |
2543 | } |
2544 | |
2545 | Array<te::Tensor> BroadCastToCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
2546 | const Type& out_type) { |
2547 | const auto* out_ttype = out_type.as<TensorTypeNode>(); |
2548 | return {topi::broadcast_to(inputs[0], out_ttype->shape)}; |
2549 | } |
2550 | |
2551 | TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to" ).set_body_typed(MakeBroadCastTo); |
2552 | |
2553 | RELAY_REGISTER_OP("broadcast_to" ) |
2554 | .describe(R"code(Broadcast the first input to match the shape argument. |
2555 | )code" TVM_ADD_FILELINE) |
2556 | .set_num_inputs(1) |
2557 | .add_argument("data" , "Tensor" , "The input tensor." ) |
2558 | .set_support_level(4) |
2559 | .add_type_rel("BroadCastTo" , BroadCastToRel) |
2560 | .set_attrs_type<InitOpAttrs>() |
2561 | .set_attr<FTVMCompute>("FTVMCompute" , BroadCastToCompute) |
2562 | .set_attr<TOpPattern>("TOpPattern" , kBroadcast); |
2563 | |
2564 | // BroadCastToLike: <A, B> -> B where BroadCast(A, B) = B |
2565 | bool BroadCastToLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
2566 | const TypeReporter& reporter) { |
2567 | ICHECK_EQ(types.size(), 3); |
2568 | reporter->Assign(types[2], types[1]); |
2569 | return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter); |
2570 | } |
2571 | |
2572 | Expr MakeBroadCastToLike(Expr data, Expr broadcast_type) { |
2573 | static const Op& op = Op::Get("broadcast_to_like" ); |
2574 | return Call(op, {data, broadcast_type}, Attrs(), {}); |
2575 | } |
2576 | |
2577 | Array<te::Tensor> BroadCastToLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
2578 | const Type& out_type) { |
2579 | const auto* out_ttype = out_type.as<TensorTypeNode>(); |
2580 | ICHECK(out_ttype != nullptr); |
2581 | return {topi::broadcast_to(inputs[0], out_ttype->shape)}; |
2582 | } |
2583 | |
2584 | TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to_like" ).set_body_typed(MakeBroadCastToLike); |
2585 | |
2586 | RELAY_REGISTER_OP("broadcast_to_like" ) |
2587 | .describe(R"code(Broadcast the first input to match the shape of the second input. |
2588 | )code" TVM_ADD_FILELINE) |
2589 | .set_num_inputs(2) |
2590 | .add_argument("data" , "Tensor" , "The input tensor." ) |
2591 | .add_argument("broadcast_type" , "Tensor" , "Provide the type to broadcast to." ) |
2592 | .set_support_level(10) |
2593 | .add_type_rel("BroadCastToLike" , BroadCastToLikeRel) |
2594 | .set_attr<FTVMCompute>("FTVMCompute" , BroadCastToLikeCompute) |
2595 | .set_attr<TOpPattern>("TOpPattern" , kBroadcast); |
2596 | |
2597 | // Adapter function to make int array. |
2598 | Array<Integer> GetIntArray(Array<IndexExpr> arr) { |
2599 | for (size_t i = 0; i < arr.size(); ++i) { |
2600 | ICHECK(!arr[i].defined() || arr[i].as<IntImmNode>()) << "Expect an int array" ; |
2601 | } |
2602 | return Downcast<Array<Integer>>(arr); |
2603 | } |
2604 | |
2605 | // strided_slice |
2606 | TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); |
2607 | |
2608 | bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
2609 | const TypeReporter& reporter) { |
2610 | ICHECK_EQ(types.size(), 2); |
2611 | const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>(); |
2612 | if (param == nullptr) { |
2613 | return false; |
2614 | } |
2615 | const auto* data = types[0].as<TensorTypeNode>(); |
2616 | |
2617 | if (data == nullptr) { |
2618 | return false; |
2619 | } |
2620 | |
2621 | ICHECK(param->begin) << "strided_slice received invalid begin " << param->begin; |
2622 | ICHECK(param->end) << "strided_slice received invalid end " << param->end; |
2623 | ICHECK(param->strides) << "strided_slice received invalid strides " << param->strides; |
2624 | |
2625 | auto begin = param->begin.value(); |
2626 | auto end = param->end.value(); |
2627 | auto strides = param->strides.value(); |
2628 | |
2629 | const size_t src_tensor_dim = static_cast<size_t>(data->shape.size()); |
2630 | Array<Integer> axes; |
2631 | if (param->axes) { |
2632 | axes = param->axes.value(); |
2633 | ICHECK(axes.size() == begin.size() && axes.size() == end.size() && |
2634 | axes.size() == strides.size()) |
2635 | << "axes, begin, end, and strides must have the same length" ; |
2636 | } else { |
2637 | for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i); |
2638 | |
2639 | const IntImm one = IntImm(DataType::Int(64), 1); |
2640 | const IntImm zero = IntImm(DataType::Int(64), 0); |
2641 | const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits<int64_t>::max()); |
2642 | |
2643 | for (size_t i = strides.size(); i < src_tensor_dim; ++i) { |
2644 | strides.push_back(one); |
2645 | } |
2646 | for (size_t i = begin.size(); i < src_tensor_dim; ++i) { |
2647 | begin.push_back(topi::GetConstInt(strides[i]) > 0 ? zero : max_range); |
2648 | } |
2649 | for (size_t i = end.size(); i < src_tensor_dim; ++i) { |
2650 | end.push_back(topi::GetConstInt(strides[i]) < 0 ? zero : max_range); |
2651 | } |
2652 | } |
2653 | auto oshape = |
2654 | topi::StridedSliceOutputShape(data->shape, begin, end, strides, axes, param->slice_mode); |
2655 | reporter->Assign(types[1], TensorType(oshape, data->dtype)); |
2656 | return true; |
2657 | } |
2658 | |
2659 | InferCorrectLayoutOutput StridedSliceInferCorrectLayout( |
2660 | const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts, |
2661 | const Array<tvm::relay::Type>& old_in_types) { |
2662 | Array<Array<IndexExpr>> old_in_shapes; |
2663 | for (auto old_in_t : old_in_types) { |
2664 | ICHECK(old_in_t.as<TensorTypeNode>()); |
2665 | old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape); |
2666 | } |
2667 | |
2668 | ICHECK(old_in_layouts.defined()); |
2669 | ICHECK_GE(old_in_layouts.size(), 1); |
2670 | ICHECK(old_in_shapes.defined()); |
2671 | ICHECK_GE(old_in_shapes.size(), 1); |
2672 | |
2673 | auto layout = old_in_layouts[0]; |
2674 | InferCorrectLayoutOutput out_default{{Layout::Undef()}, {Layout::Undef()}, attrs}; |
2675 | |
2676 | if (layout.defined() && new_in_layouts.defined()) { |
2677 | ICHECK_GE(new_in_layouts.size(), 1); |
2678 | auto new_layout = new_in_layouts[0]; |
2679 | auto shape = old_in_shapes[0]; |
2680 | |
2681 | const auto* attrs_ptr = attrs.as<StridedSliceAttrs>(); |
2682 | ICHECK(attrs_ptr); |
2683 | ObjectPtr<StridedSliceAttrs> params = make_object<StridedSliceAttrs>(*attrs_ptr); |
2684 | |
2685 | Array<Integer> begin, end, strides; |
2686 | if (params->begin && params->end && params->strides) { |
2687 | for (Integer i : params->strides.value()) { |
2688 | ICHECK(i.defined()); |
2689 | auto slice_val = Integer(IntImm(i->dtype, i->value)); |
2690 | strides.push_back(params->slice_mode == "size" ? Integer(IntImm(i->dtype, 1)) : slice_val); |
2691 | } |
2692 | |
2693 | for (Integer i : params->begin.value()) { |
2694 | ICHECK(i.defined()); |
2695 | begin.push_back(IntImm(i->dtype, i->value)); |
2696 | } |
2697 | for (Integer i : params->end.value()) { |
2698 | ICHECK(i.defined()); |
2699 | end.push_back(IntImm(i->dtype, i->value)); |
2700 | } |
2701 | } |
2702 | |
2703 | Array<Integer> new_begin, new_end, new_strides; |
2704 | |
2705 | // Handles layout conversion like NHWC -> NCHW |
2706 | auto old_layout_name = layout.name(); |
2707 | auto new_layout_name = new_layout.name(); |
2708 | |
2709 | if (old_layout_name.rfind(new_layout_name, 0) != 0 && |
2710 | new_layout_name.rfind(old_layout_name, 0) != 0) { |
2711 | if (old_layout_name.size() != new_layout_name.size()) { |
        // Layout conversions that change the rank (e.g. NHW4c -> NCHW) are not supported here.
2713 | return out_default; |
2714 | } else { |
2715 | if (params->axes) { |
2716 | auto axes = params->axes.value(); |
2717 | Array<Integer> new_axes; |
2718 | |
2719 | for (size_t i = 0; i < axes.size(); ++i) { |
2720 | auto old_idx = axes[i].IntValue(); |
2721 | auto new_idx = new_layout.IndexOf(layout[old_idx]); |
2722 | new_begin.push_back(begin[i]); |
2723 | new_end.push_back(end[i]); |
2724 | new_strides.push_back(strides[i]); |
2725 | new_axes.push_back(new_idx); |
2726 | } |
2727 | params->axes = new_axes; |
2728 | |
2729 | } else { |
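          // No explicit axes were given: walk the axes of the new layout, find each
          // one's position in the old layout, and carry over the corresponding
          // begin/end/stride, defaulting to the full extent (begin 0, end = axis
          // size, stride 1) when the original slice left that axis unspecified.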
2730 | for (size_t i = 0; i < new_layout_name.size(); ++i) { |
2731 | auto index = layout.IndexOf(new_layout[i]); |
2732 | if (index == -1) { |
2733 | return out_default; |
2734 | } |
2735 | |
2736 | size_t new_index = static_cast<size_t>(index); |
2737 | int64_t bg, ed, st; |
2738 | if (strides.defined() && new_index < strides.size() && strides[new_index].defined()) { |
2739 | st = strides[new_index]->value; |
2740 | } else { |
2741 | st = 1; |
2742 | } |
2743 | if (new_index < begin.size() && begin[new_index].defined()) { |
2744 | bg = begin[new_index]->value; |
2745 | } else { |
2746 | bg = 0; |
2747 | } |
2748 | if (new_index < end.size() && end[new_index].defined()) { |
2749 | ed = end[new_index]->value; |
2750 | } else { |
2751 | ed = shape[new_index].as<IntImmNode>()->value; |
2752 | } |
2753 | |
2754 | new_begin.push_back(IntImm(begin[0]->dtype, bg)); |
2755 | new_end.push_back(IntImm(end[0]->dtype, ed)); |
2756 | new_strides.push_back(IntImm(strides[0]->dtype, st)); |
2757 | } |
2758 | } |
2759 | |
2760 | params->begin = new_begin; |
2761 | params->end = new_end; |
2762 | params->strides = new_strides; |
2763 | layout = new_layout; |
2764 | } |
    } else if (old_layout_name.size() <
               new_layout_name.size()) {  // only handle conversions that add tiling axes
                                          // (e.g. NCHW -> NCHW4c); the reverse is not handled
2767 | if (params->axes) { |
2768 | auto axes = params->axes.value(); |
2769 | Array<Integer> new_axes; |
2770 | for (size_t i = 0; i < axes.size(); ++i) { |
2771 | auto old_idx = axes[i].IntValue(); |
2772 | auto new_idx = new_layout.IndexOf(layout[old_idx]); |
2773 | new_axes.push_back(new_idx); |
2774 | |
2775 | const LayoutAxis& axis = layout[old_idx]; |
2776 | ICHECK(axis.IsPrimal()); |
2777 | auto factor = new_layout.FactorOf(axis); |
2778 | if (factor == -1) { |
2779 | new_begin.push_back(begin[i]); |
2780 | new_end.push_back(end[i]); |
2781 | } else { |
2782 | if (strides.defined() && i < strides.size()) { |
2783 | auto stride = strides[i]; |
2784 | // arbitrary stride is not supported |
2785 | if (stride.defined() && stride->value != 1) { |
2786 | return out_default; |
2787 | } |
2788 | } |
2789 | int64_t bg = begin[i].IntValue(); |
2790 | int64_t ed = end[i].IntValue(); |
2791 | if (bg % factor || ed % factor) { |
2792 | // transform to original layout |
2793 | return out_default; |
2794 | } |
2795 | new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor))); |
2796 | new_end.push_back(IntImm(end[0]->dtype, (ed / factor))); |
2797 | } |
2798 | } |
2799 | params->axes = new_axes; |
2800 | |
2801 | } else { |
2802 | for (size_t i = 0; i < begin.size(); i++) { |
2803 | const LayoutAxis& axis = layout[i]; |
2804 | ICHECK(axis.IsPrimal()); |
2805 | auto factor = new_layout.FactorOf(axis); |
2806 | if (factor == -1) { |
2807 | new_begin.push_back(IntImm(begin[i]->dtype, begin[i].IntValue())); |
2808 | new_end.push_back(IntImm(end[i]->dtype, end[i].IntValue())); |
2809 | } else { |
2810 | if (strides.defined() && i < strides.size()) { |
2811 | auto stride = strides[i]; |
2812 | // arbitrary stride is not supported |
2813 | if (stride.defined() && stride->value != 1) { |
2814 | return out_default; |
2815 | } |
2816 | } |
2817 | int64_t bg = begin[i].defined() ? begin[i]->value : 0; |
2818 | int64_t ed; |
2819 | if (!end[i].defined()) { |
2820 | ed = shape[i].as<IntImmNode>()->value; |
2821 | } else if (params->slice_mode == "size" ) { |
2822 | if (end[i]->value < 0) { |
2823 | ed = shape[i].as<IntImmNode>()->value; |
2824 | } else { |
2825 | ed = bg + end[i]->value; |
2826 | } |
2827 | } else { |
2828 | ed = end[i]->value; |
2829 | } |
2830 | |
2831 | if (bg % factor || ed % factor) { |
2832 | // transform to original layout |
2833 | return out_default; |
2834 | } |
2835 | new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor))); |
2836 | new_end.push_back(IntImm(end[0]->dtype, (ed / factor))); |
2837 | } |
2838 | } |
2839 | } |
2840 | |
2841 | layout = new_layout; |
2842 | params->begin = new_begin; |
2843 | params->end = new_end; |
2844 | } |
2845 | return InferCorrectLayoutOutput({layout}, {layout}, Attrs(params)); |
2846 | } |
2847 | return InferCorrectLayoutOutput({layout}, {layout}, attrs); |
2848 | } |
2849 | |
2850 | Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
2851 | const Type& out_type) { |
2852 | const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>(); |
2853 | ICHECK(param != nullptr); |
2854 | ICHECK(param->begin && param->end && param->strides); |
2855 | Array<Integer> begin = param->begin.value(); |
2856 | Array<Integer> end = param->end.value(); |
2857 | Array<Integer> strides = param->strides.value(); |
2858 | if (param->axes) { |
2859 | auto axes = param->axes.value(); |
2860 | return Array<te::Tensor>{ |
2861 | topi::strided_slice_with_axes(inputs[0], begin, end, strides, axes, param->slice_mode)}; |
2862 | } |
2863 | return Array<te::Tensor>{topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)}; |
2864 | } |
2865 | |
2866 | // Positional relay function to create StridedSlice operator used by frontend FFI. |
2867 | Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides, |
2868 | String slice_mode, Optional<Array<Integer>> axes) { |
2869 | auto attrs = make_object<StridedSliceAttrs>(); |
2870 | attrs->begin = std::move(begin); |
2871 | attrs->end = std::move(end); |
2872 | attrs->strides = std::move(strides); |
2873 | attrs->slice_mode = slice_mode; |
2874 | attrs->axes = std::move(axes); |
2875 | static const Op& op = Op::Get("strided_slice" ); |
2876 | return Call(op, {data}, Attrs(attrs), {}); |
2877 | } |
2878 | |
2879 | TVM_REGISTER_GLOBAL("relay.op._make.strided_slice" ).set_body_typed(MakeStridedSlice); |
2880 | |
2881 | RELAY_REGISTER_OP("strided_slice" ) |
2882 | .describe(R"code(Strided slice of an array. |
2883 | |
2884 | Examples:: |
2885 | |
2886 | x = [[ 1., 4., 7., 10.], |
2887 | [ 2., 5., 8., 11.], |
2888 | [ 3., 6., 9., 12.]] |
2889 | |
2890 | strided_slice(x, begin=[0, 1], end=[2, 4], stride=[1, 1]) = [[ 4., 7., 10.], |
2891 | [ 5., 8., 11.]] |
2892 | |
2893 | x = [[[ 1., 2.], |
2894 | [ 3., 4.]], |
2895 | |
2896 | [[ 5., 6.], |
2897 | [ 7., 8.]]] |
2898 | |
2899 | strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1., 2.], |
2900 | [ 3., 4.]], |
2901 | |
2902 | [[ 5., 6.], |
2903 | [ 7., 8.]]] |
2904 | )code" TVM_ADD_FILELINE) |
2905 | .set_num_inputs(1) |
2906 | .add_argument("data" , "Tensor" , "The input tensor." ) |
2907 | .set_support_level(4) |
2908 | .set_attrs_type<StridedSliceAttrs>() |
2909 | .add_type_rel("StridedSlice" , StridedSliceRel) |
2910 | .set_attr<FTVMCompute>("FTVMCompute" , StridedSliceCompute) |
2911 | .set_attr<TOpPattern>("TOpPattern" , kInjective) |
2912 | .set_attr<AnyCodegenStrategy>("AnyCodegenStrategy" , kVariableDimensions) |
2913 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , StridedSliceInferCorrectLayout); |
2914 | |
2915 | // strided_set |
2916 | bool StridedSetRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
2917 | const TypeReporter& reporter) { |
2918 | ICHECK_EQ(types.size(), 6); |
2919 | reporter->Assign(types[5], types[0]); |
2920 | return true; |
2921 | } |
2922 | |
2923 | Expr MakeStridedSet(Expr data, Expr v, Expr begin, Expr end, Expr strides) { |
2924 | static const Op& op = Op::Get("strided_set" ); |
2925 | return Call(op, {data, v, begin, end, strides}, {}); |
2926 | } |
2927 | |
2928 | TVM_REGISTER_GLOBAL("relay.op._make.strided_set" ).set_body_typed(MakeStridedSet); |
2929 | |
2930 | RELAY_REGISTER_OP("strided_set" ) |
2931 | .describe(R"code(Strided set of an array. |
2932 | Example:: |
2933 | |
2934 | x = [[ 1., 4., 7., 10.], |
2935 | [ 2., 5., 8., 11.], |
2936 | [ 3., 6., 9., 12.]] |
2937 | |
2938 | v = [[ 11., 22., 33.] |
2939 | [ 44., 55., 66.]] |
2940 | |
2941 | strided_set(x, v, begin=[0, 1], end=[2, 4], stride=[1, 1]) = \ |
2942 | [[ 1., 11., 22., 33.], |
2943 | [ 2., 44., 55., 66.], |
2944 | [ 3., 6., 9., 12.]] |
2945 | )code" TVM_ADD_FILELINE) |
2946 | .set_num_inputs(5) |
2947 | .add_argument("data" , "Tensor" , "The input tensor." ) |
2948 | .add_argument("v" , "Tensor" , "The data to set." ) |
2949 | .add_argument("begin" , "Tensor" , "Indices for the start of the slice." ) |
2950 | .add_argument("end" , "Tensor" , "Indices indicating the end of the slice." ) |
2951 | .add_argument("strides" , "Tensor" , "The strides values." ) |
2952 | .set_support_level(4) |
2953 | .set_attr<TOpPattern>("TOpPattern" , kInjective) |
2954 | .add_type_rel("StridedSet" , StridedSetRel); |
2955 | |
2956 | // relay.split |
2957 | TVM_REGISTER_NODE_TYPE(SplitAttrs); |
2958 | |
2959 | InferCorrectLayoutOutput SplitInferCorrectLayout(const Attrs& attrs, |
2960 | const Array<Layout>& new_in_layouts, |
2961 | const Array<Layout>& old_in_layouts, |
2962 | const Array<tvm::relay::Type>& old_in_types) { |
2963 | const auto* attrs_ptr = attrs.as<SplitAttrs>(); |
2964 | ICHECK(attrs_ptr); |
2965 | ObjectPtr<SplitAttrs> param = make_object<SplitAttrs>(*attrs_ptr); |
2966 | |
2967 | Array<Array<IndexExpr>> old_in_shapes; |
2968 | for (auto old_in_t : old_in_types) { |
2969 | ICHECK(old_in_t.as<TensorTypeNode>()); |
2970 | old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape); |
2971 | } |
2972 | |
2973 | size_t axis = |
2974 | param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast<size_t>(param->axis); |
2975 | |
2976 | Layout ret = Layout::Undef(); |
2977 | size_t size = 0; |
2978 | if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) { |
2979 | size = sections->value; |
2980 | } else { |
2981 | size = Downcast<Array<Integer>>(param->indices_or_sections).size() + 1; |
2982 | } |
2983 | |
2984 | // If new_in_layouts are defined, this code tries to modify the layout. |
2985 | if (new_in_layouts.defined() && old_in_layouts.defined()) { |
2986 | bool divisible = true; |
2987 | const auto& sp_dim = old_in_layouts[0][axis]; |
2988 | auto new_index = new_in_layouts[0].IndexOf(sp_dim); |
2989 | param->axis = new_index; |
2990 | int factor = new_in_layouts[0].FactorOf(sp_dim); |
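    // If the split axis is tiled in the new layout (factor > 1), the split points
    // refer to the original axis and must be scaled down by the tiling factor;
    // if any split point is not divisible, fall back to the old layout.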
2991 | if (factor > 1) { |
2992 | if (!param->indices_or_sections.as<IntImmNode>()) { |
2993 | auto ios = Downcast<Array<Integer>>(param->indices_or_sections); |
2994 | Array<Integer> new_ios; |
2995 | for (const auto& v : ios) { |
2996 | const IntImmNode* vint = v.as<IntImmNode>(); |
2997 | new_ios.push_back(vint->value / factor); |
2998 | if (vint->value % factor) { |
2999 | divisible = false; |
3000 | } |
3001 | } |
3002 | if (divisible) { |
3003 | param->indices_or_sections = new_ios; |
3004 | } |
3005 | } |
3006 | } |
3007 | if (divisible) { |
3008 | ret = new_in_layouts[0]; |
3009 | } else { |
3010 | ret = old_in_layouts[0]; |
3011 | } |
3012 | } else if (old_in_layouts.defined()) { |
3013 | ret = old_in_layouts[0]; |
3014 | } |
3015 | |
3016 | return InferCorrectLayoutOutput({ret}, {Array<Layout>(size, ret)}, Attrs(param)); |
3017 | } |
3018 | |
3019 | bool SplitRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
3020 | const TypeReporter& reporter) { |
3021 | // `types` contains: [data, result] |
3022 | ICHECK_EQ(types.size(), 2); |
3023 | const auto* data = types[0].as<TensorTypeNode>(); |
3024 | if (data == nullptr) return false; |
3025 | ICHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty" ; |
3026 | const auto param = attrs.as<SplitAttrs>(); |
3027 | ICHECK(param != nullptr); |
3028 | auto axis = param->axis; |
3029 | if (axis < 0) { |
3030 | axis += data->shape.size(); |
3031 | } |
3032 | ICHECK_LT(axis, data->shape.size()) << "axis should be within the input dimension range." ; |
3033 | ICHECK_GE(axis, 0) << "axis should be within the input dimension range." ; |
3034 | |
3035 | if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) { |
3036 | if (!data->shape[axis].as<AnyNode>()) { |
3037 | ICHECK(reporter->Assert(indexmod(data->shape[axis], sections->value) == |
3038 | tir::make_zero(DataType::Int(64)))) |
3039 | << "indices_or_sections need to be able to divide input.shape[axis]" ; |
3040 | } |
3041 | std::vector<Type> fields; |
3042 | for (int i = 0; i < sections->value; ++i) { |
3043 | std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end()); |
3044 | if (data->shape[axis].as<AnyNode>()) { |
3045 | oshape[axis] = Any(); |
3046 | } else { |
3047 | oshape[axis] = indexdiv(oshape[axis], sections->value); |
3048 | } |
3049 | auto vec_type = TensorType(oshape, data->dtype); |
3050 | fields.push_back(vec_type); |
3051 | } |
3052 | reporter->Assign(types[1], TupleType(Array<Type>(fields))); |
3053 | } else { |
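    // indices_or_sections holds explicit split points: section i spans
    // [indices[i-1], indices[i]) along the split axis, and a final section runs
    // from the last index to the end of the axis.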
3054 | Array<IndexExpr> indices; |
3055 | for (auto i : Downcast<Array<Integer>>(param->indices_or_sections)) { |
3056 | indices.push_back(IntImm(DataType::Int(32), i.as<IntImmNode>()->value)); |
3057 | } |
3058 | auto begin = IndexExpr(tir::make_zero(DataType::Int(32))); |
3059 | std::vector<Type> fields; |
3060 | for (unsigned int i = 0; i < indices.size(); ++i) { |
3061 | ICHECK(reporter->Assert(indices[i] > begin)) |
3062 | << "indices_or_sections need to be a sorted ascending list" ; |
3063 | std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end()); |
3064 | oshape[axis] = indices[i] - begin; |
3065 | begin = indices[i]; |
3066 | auto vec_type = TensorType(oshape, data->dtype); |
3067 | fields.push_back(vec_type); |
3068 | } |
3069 | if (!data->shape[axis].as<AnyNode>()) { |
      ICHECK(reporter->Assert(begin < data->shape[axis]))
          << "The last split index must be smaller than input.shape[axis]";
3072 | } |
3073 | std::vector<IndexExpr> oshape(data->shape.begin(), data->shape.end()); |
3074 | if (data->shape[axis].as<AnyNode>()) { |
3075 | oshape[axis] = Any(); |
3076 | } else { |
3077 | oshape[axis] = data->shape[axis] - begin; |
3078 | } |
3079 | auto vec_type = TensorType(oshape, data->dtype); |
3080 | fields.push_back(vec_type); |
3081 | reporter->Assign(types[1], TupleType(Array<Type>(fields))); |
3082 | } |
3083 | return true; |
3084 | } |
3085 | |
3086 | Array<te::Tensor> SplitCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
3087 | const Type& out_type) { |
3088 | const auto param = attrs.as<SplitAttrs>(); |
3089 | ICHECK(param != nullptr); |
3090 | |
3091 | if (const IntImmNode* sections = param->indices_or_sections.as<IntImmNode>()) { |
3092 | int64_t num_sections = sections->value; |
3093 | return Array<te::Tensor>{topi::split_sections(inputs[0], num_sections, param->axis)}; |
3094 | } else { |
3095 | Array<PrimExpr> indices; |
3096 | for (auto i : Downcast<Array<Integer>>(param->indices_or_sections)) { |
3097 | indices.push_back(IntImm(DataType::Int(32), i.as<IntImmNode>()->value)); |
3098 | } |
3099 | return Array<te::Tensor>{topi::split(inputs[0], indices, param->axis)}; |
3100 | } |
3101 | } |
3102 | |
3103 | Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis) { |
3104 | auto attrs = make_object<SplitAttrs>(); |
3105 | attrs->axis = axis; |
3106 | attrs->indices_or_sections = std::move(indices_or_sections); |
3107 | static const Op& op = Op::Get("split" ); |
3108 | return Call(op, {data}, Attrs(attrs), {}); |
3109 | } |
3110 | |
3111 | TVM_REGISTER_GLOBAL("relay.op._make.split" ).set_body([](const TVMArgs& args, TVMRetValue* rv) { |
3112 | if (args.type_codes[1] == kDLInt) { |
    // Note: we change it from Int(64) to Int(32) for now, since
    // combine_parallel_dense will transform the graph with Int(32).
    // More investigation is needed to determine which one we should use.
3116 | *rv = |
3117 | MakeSplit(args[0], tir::make_const(DataType::Int(32), static_cast<int>(args[1])), args[2]); |
3118 | } else { |
3119 | *rv = MakeSplit(args[0], args[1], args[2]); |
3120 | } |
3121 | }); |
3122 | |
3123 | RELAY_REGISTER_OP("split" ) |
3124 | .describe(R"code(Splits an array along a particular axis into multiple sub-arrays. |
3125 | |
Indices or sections to split into. Accepts an int or a tuple.
If indices_or_sections is an integer, the input will be divided equally
along the given axis. If such a split is not possible, an error is raised.
3129 | |
3130 | If indices_or_sections is a tuple of sorted integers, |
3131 | the entries indicate where along axis the array is split. |
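
For example, splitting along ``axis=1``::

  x = [[ 1., 2., 3., 4.],
       [ 5., 6., 7., 8.]]

  split(x, indices_or_sections=2, axis=1) yields two tensors:
      [[ 1., 2.],         [[ 3., 4.],
       [ 5., 6.]]   and    [ 7., 8.]]

  split(x, indices_or_sections=[1, 3], axis=1) yields tensors of shape
  (2, 1), (2, 2), and (2, 1).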
3132 | |
3133 | )code" TVM_ADD_FILELINE) |
3134 | .set_attrs_type<SplitAttrs>() |
3135 | .set_num_inputs(1) |
3136 | .add_argument("data" , "Tensor" , "The input tensor." ) |
3137 | .set_support_level(3) |
3138 | .add_type_rel("Split" , SplitRel) |
3139 | .set_attr<FTVMCompute>("FTVMCompute" , SplitCompute) |
3140 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , SplitInferCorrectLayout) |
3141 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
3142 | |
3143 | // relay.slice_like |
3144 | TVM_REGISTER_NODE_TYPE(SliceLikeAttrs); |
3145 | |
3146 | /*! |
3147 | * \brief SliceLikeRel User defined type constraint function. |
3148 | * \param num_inputs Number of input types in the args. |
3149 | * \param attrs The additional attributes of the operator. |
3150 | * \param reporter The reporter to report solution to. |
3151 | * \return False if the relation has not been resolved, it might be resolved later. |
3152 | * True if this relation has been resolved. |
3153 | */ |
3154 | bool SliceLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
3155 | const TypeReporter& reporter) { |
3156 | ICHECK_EQ(types.size(), 3); |
3157 | const auto* data = types[0].as<TensorTypeNode>(); |
3158 | if (data == nullptr) { |
3159 | return false; |
3160 | } |
3161 | |
3162 | const auto* target = types[1].as<TensorTypeNode>(); |
3163 | if (target == nullptr) { |
3164 | return false; |
3165 | } |
3166 | |
3167 | const auto param = attrs.as<SliceLikeAttrs>(); |
3168 | ICHECK(param != nullptr); |
3169 | |
3170 | const Array<IndexExpr>& dshape = data->shape; |
3171 | const Array<IndexExpr>& target_shape = target->shape; |
3172 | std::vector<IndexExpr> oshape(dshape.begin(), dshape.end()); |
3173 | |
3174 | if (!param->axes.defined()) { |
3175 | for (size_t i = 0; i < dshape.size(); ++i) { |
3176 | if (i < target_shape.size()) { |
3177 | oshape[i] = target_shape[i]; |
3178 | ICHECK(reporter->Assert(oshape[i] <= dshape[i])) |
3179 | << "End index of axis " << i << " exceeds input shape: " << oshape[i] << " vs " |
3180 | << dshape[i]; |
3181 | } |
3182 | } |
3183 | } else { |
3184 | ICHECK(param->axes.size() != 0) << "Axes cannot be empty." ; |
3185 | for (Integer val : param->axes) { |
3186 | int axis = val->value; |
3187 | if (axis < 0) { |
3188 | axis += dshape.size(); |
3189 | } |
3190 | ICHECK(axis < static_cast<int>(target_shape.size())) |
3191 | << "Axis " << axis << " exceeds dimension " << target_shape.size() << " of target_shape." ; |
3192 | oshape[axis] = target_shape[axis]; |
3193 | ICHECK(reporter->Assert(oshape[axis] <= dshape[axis])) |
3194 | << "End index of axis " << axis << " exceeds input shape: " << oshape[axis] << " vs " |
3195 | << dshape[axis]; |
3196 | } |
3197 | } |
3198 | |
3199 | reporter->Assign(types[2], TensorType(oshape, data->dtype)); |
3200 | return true; |
3201 | } |
3202 | |
3203 | Expr MakeSliceLike(Expr data, Expr shape_like, Array<Integer> axes) { |
3204 | auto attrs = make_object<SliceLikeAttrs>(); |
3205 | attrs->axes = std::move(axes); |
3206 | static const Op& op = Op::Get("slice_like" ); |
3207 | return Call(op, {data, shape_like}, Attrs(attrs), {}); |
3208 | } |
3209 | |
3210 | InferCorrectLayoutOutput SliceLikeInferCorrectLayout(const Attrs& attrs, |
3211 | const Array<Layout>& new_in_layouts, |
3212 | const Array<Layout>& old_in_layouts, |
3213 | const Array<tvm::relay::Type>& old_in_types) { |
3214 | Array<Integer> new_axes; |
3215 | if (old_in_layouts.defined() && new_in_layouts.defined()) { |
3216 | ICHECK_EQ(new_in_layouts.size(), 2); |
3217 | ICHECK_EQ(new_in_layouts[0]->name, new_in_layouts[1]->name); |
3218 | ICHECK_EQ(old_in_layouts.size(), 2); |
3219 | ICHECK_EQ(old_in_layouts[0]->name, old_in_layouts[1]->name); |
3220 | |
3221 | auto old_layout = old_in_layouts[0]; |
3222 | auto new_layout = new_in_layouts[0]; |
3223 | |
3224 | const auto* attrs_ptr = attrs.as<SliceLikeAttrs>(); |
3225 | ICHECK(attrs_ptr); |
3226 | ObjectPtr<SliceLikeAttrs> params = make_object<SliceLikeAttrs>(*attrs_ptr); |
3227 | |
3228 | for (auto axis : params->axes) { |
3229 | auto new_axis = new_layout.IndexOf(old_layout[axis->value]); |
3230 | // Cannot find the target axis in the new layout. |
3231 | if (new_axis == -1) { |
3232 | new_axes.clear(); |
3233 | break; |
3234 | } |
3235 | new_axes.push_back(new_axis); |
3236 | } |
3237 | if (!new_axes.empty()) { |
3238 | params->axes = std::move(new_axes); |
3239 | return InferCorrectLayoutOutput({new_layout, new_layout}, {new_layout}, Attrs(params)); |
3240 | } |
3241 | } |
3242 | |
3243 | if (old_in_layouts.defined()) { |
3244 | ICHECK_EQ(old_in_layouts.size(), 2); |
3245 | return InferCorrectLayoutOutput({old_in_layouts[0], old_in_layouts[1]}, {old_in_layouts[1]}, |
3246 | attrs); |
3247 | } |
3248 | return InferCorrectLayoutOutput({Layout::Undef(), Layout::Undef()}, {Layout::Undef()}, attrs); |
3249 | } |
3250 | |
3251 | Array<te::Tensor> SliceLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
3252 | const Type& out_type) { |
3253 | const auto* param = attrs.as<SliceLikeAttrs>(); |
3254 | ICHECK(param != nullptr); |
3255 | Array<IndexExpr> src_shape = inputs[0]->shape; |
3256 | Array<IndexExpr> target_shape = inputs[1]->shape; |
3257 | Array<Integer> begin_idx, end_idx, strides; |
3258 | for (size_t i = 0; i < src_shape.size(); ++i) { |
3259 | begin_idx.push_back(0); |
3260 | strides.push_back(1); |
3261 | } |
3262 | for (auto s : src_shape) { |
3263 | ICHECK(s->IsInstance<tvm::IntImmNode>()) << "slice_like does not support dynamic input shape" ; |
3264 | end_idx.push_back(topi::GetConstInt(s)); |
3265 | } |
3266 | if (!param->axes.defined()) { |
3267 | for (size_t i = 0; i < src_shape.size(); ++i) { |
3268 | if (i < target_shape.size()) { |
3269 | ICHECK(target_shape[i]->IsInstance<tvm::IntImmNode>()) |
3270 | << "slice_like does not support dynamic output shape" ; |
3271 | end_idx.Set(i, topi::GetConstInt(target_shape[i])); |
3272 | ICHECK_LE(topi::GetConstInt(end_idx[i]), topi::GetConstInt(src_shape[i])) |
3273 | << "End index of axis " << i |
3274 | << " exceeds input shape: " << topi::GetConstInt(end_idx[i]) << " vs " |
3275 | << topi::GetConstInt(src_shape[i]); |
3276 | } |
3277 | } |
3278 | } else { |
3279 | for (Integer axis : param->axes) { |
3280 | int a = axis.IntValue(); |
3281 | if (a < 0) { |
3282 | a = static_cast<int>(src_shape.size()) + a; |
3283 | } |
3284 | ICHECK(target_shape[a]->IsInstance<tvm::IntImmNode>()) |
3285 | << "slice_like does not support dynamic output shape" ; |
3286 | end_idx.Set(a, topi::GetConstInt(target_shape[a])); |
3287 | ICHECK_LE(topi::GetConstInt(end_idx[a]), topi::GetConstInt(src_shape[a])) |
3288 | << "End index of axis " << a << " exceeds input shape: " << topi::GetConstInt(end_idx[a]) |
3289 | << " vs " << topi::GetConstInt(src_shape[a]); |
3290 | } |
3291 | } |
3292 | return Array<te::Tensor>{topi::strided_slice(inputs[0], begin_idx, end_idx, strides, "end" )}; |
3293 | } |
3294 | |
3295 | TVM_REGISTER_GLOBAL("relay.op._make.slice_like" ).set_body_typed(MakeSliceLike); |
3296 | |
3297 | RELAY_REGISTER_OP("slice_like" ) |
    .describe(R"code(Slice the first input with respect to the second input.
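
For example, with ``data`` of shape (3, 4, 5) and ``shape_like`` of shape (2, 3, 3)::

  slice_like(data, shape_like)              has shape (2, 3, 3)
  slice_like(data, shape_like, axes=(1, 2)) has shape (3, 3, 3)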
3299 | )code" TVM_ADD_FILELINE) |
3300 | .set_attrs_type<SliceLikeAttrs>() |
3301 | .set_num_inputs(2) |
3302 | .add_argument("data" , "Tensor" , "The input tensor." ) |
3303 | .add_argument("shape_like" , "Tensor" , "Shape tensor." ) |
3304 | .set_support_level(10) |
3305 | .add_type_rel("SliceLike" , SliceLikeRel) |
3306 | .set_attr<FTVMCompute>("FTVMCompute" , SliceLikeCompute) |
3307 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , SliceLikeInferCorrectLayout) |
3308 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
3309 | |
3310 | // relay.layout_transform |
3311 | TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs); |
3312 | |
3313 | Array<te::Tensor> LayoutTransformCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
3314 | const Type& out_type) { |
3315 | const auto* param = attrs.as<LayoutTransformAttrs>(); |
3316 | ICHECK(param != nullptr); |
3317 | return Array<te::Tensor>{topi::layout_transform(inputs[0], param->src_layout, param->dst_layout)}; |
3318 | } |
3319 | |
3320 | bool LayoutTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
3321 | const TypeReporter& reporter) { |
3322 | const auto* data = types[0].as<TensorTypeNode>(); |
3323 | if (data == nullptr) { |
3324 | ICHECK(types[0].as<IncompleteTypeNode>()) |
3325 | << "LayoutTransform: expect input data type to be TensorType but get " << types[0]; |
3326 | return false; |
3327 | } |
3328 | const LayoutTransformAttrs* params = attrs.as<LayoutTransformAttrs>(); |
3329 | |
3330 | Layout src_layout(params->src_layout); |
3331 | Layout dst_layout(params->dst_layout); |
3332 | |
3333 | ICHECK(src_layout.defined() && dst_layout.defined()) << "cannot convert from/to undefined layout" ; |
3334 | auto layout_converter = tir::BijectiveLayout(src_layout, dst_layout); |
3335 | ICHECK(layout_converter.defined()) |
3336 | << "cannot convert from " << params->src_layout << " to " << params->dst_layout; |
3337 | |
3338 | const auto& out_shape = layout_converter.ForwardShape(data->shape); |
3339 | reporter->Assign(types[1], TensorType(out_shape, data->dtype)); |
3340 | return true; |
3341 | } |
3342 | |
3343 | Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout) { |
3344 | auto attrs = make_object<LayoutTransformAttrs>(); |
3345 | attrs->src_layout = std::move(src_layout); |
3346 | attrs->dst_layout = std::move(dst_layout); |
3347 | static const Op& op = Op::Get("layout_transform" ); |
3348 | return Call(op, {data}, Attrs(attrs), {}); |
3349 | } |
3350 | |
3351 | TVM_REGISTER_GLOBAL("relay.op._make.layout_transform" ).set_body_typed(MakeLayoutTransform); |
3352 | |
3353 | RELAY_REGISTER_OP("layout_transform" ) |
3354 | .describe(R"code(Transform the input data layout. |
3355 | |
For transforming from NCHW to N16cHWC, the `layout_transform` operator reshapes
the input array by output[n, c, h, w, C] = data[n, C*16 + c, h, w]
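
Similarly, transforming a tensor of shape (1, 32, 14, 14) from `NCHW` to
`NCHW16c` yields a tensor of shape (1, 2, 14, 14, 16).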
3358 | |
3359 | )code" TVM_ADD_FILELINE) |
3360 | .set_attrs_type<LayoutTransformAttrs>() |
3361 | .set_num_inputs(1) |
3362 | .add_argument("data" , "Tensor" , "The input tensor." ) |
3363 | .add_type_rel("layout_transform" , LayoutTransformRel) |
3364 | .set_support_level(5) |
3365 | .set_attr<FTVMCompute>("FTVMCompute" , LayoutTransformCompute); |
3366 | |
3367 | // relay.auto_scheduler_layout_transform |
3368 | TVM_REGISTER_NODE_TYPE(AutoSchedulerLayoutTransformAttrs); |
3369 | |
3370 | Array<te::Tensor> AutoSchedulerLayoutTransformCompute(const Attrs& attrs, |
3371 | const Array<te::Tensor>& inputs, |
3372 | const Type& out_type) { |
3373 | const auto* param = attrs.as<AutoSchedulerLayoutTransformAttrs>(); |
3374 | CHECK(param != nullptr); |
3375 | return Array<te::Tensor>{ |
3376 | topi::auto_scheduler_layout_transform(inputs[0], param->src_layout, param->dst_layout)}; |
3377 | } |
3378 | |
3379 | bool AutoSchedulerLayoutTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
3380 | const TypeReporter& reporter) { |
3381 | const auto* data = types[0].as<TensorTypeNode>(); |
3382 | CHECK(data != nullptr); |
3383 | const AutoSchedulerLayoutTransformAttrs* params = attrs.as<AutoSchedulerLayoutTransformAttrs>(); |
3384 | |
3385 | Array<IndexExpr> dst_shape; |
3386 | std::vector<std::string> dst_axes; |
3387 | |
3388 | topi::parse_auto_scheduler_layout(params->dst_layout, &dst_shape, &dst_axes); |
3389 | |
3390 | reporter->Assign(types[1], TensorType(dst_shape, data->dtype)); |
3391 | return true; |
3392 | } |
3393 | |
3394 | Expr MakeAutoSchedulerLayoutTransform(Expr data, String src_layout, String dst_layout) { |
3395 | auto attrs = make_object<AutoSchedulerLayoutTransformAttrs>(); |
3396 | attrs->src_layout = std::move(src_layout); |
3397 | attrs->dst_layout = std::move(dst_layout); |
3398 | static const Op& op = Op::Get("auto_scheduler_layout_transform" ); |
3399 | return Call(op, {data}, Attrs(attrs), {}); |
3400 | } |
3401 | |
3402 | TVM_REGISTER_GLOBAL("relay.op._make.auto_scheduler_layout_transform" ) |
3403 | .set_body_typed(MakeAutoSchedulerLayoutTransform); |
3404 | |
3405 | RELAY_REGISTER_OP("auto_scheduler_layout_transform" ) |
3406 | .describe(R"code(Transform the input kernel layout. |
3407 | )code" TVM_ADD_FILELINE) |
3408 | .set_attrs_type<AutoSchedulerLayoutTransformAttrs>() |
3409 | .set_num_inputs(1) |
3410 | .add_argument("data" , "Tensor" , "The input tensor." ) |
3411 | .add_type_rel("auto_scheduler_layout_transform" , AutoSchedulerLayoutTransformRel) |
3412 | .set_support_level(5) |
3413 | .set_attr<FTVMCompute>("FTVMCompute" , AutoSchedulerLayoutTransformCompute); |
3414 | |
3415 | // relay.meta_schedule_layout_transform |
3416 | TVM_REGISTER_NODE_TYPE(MetaScheduleLayoutTransformAttrs); |
3417 | |
3418 | Array<te::Tensor> MetaScheduleLayoutTransformCompute(const Attrs& attrs, |
3419 | const Array<te::Tensor>& inputs, |
3420 | const Type& out_type) { |
3421 | const auto* param = attrs.as<MetaScheduleLayoutTransformAttrs>(); |
3422 | CHECK(param != nullptr); |
3423 | return Array<te::Tensor>{topi::meta_schedule_layout_transform(inputs[0], param->index_map)}; |
3424 | } |
3425 | |
3426 | bool MetaScheduleLayoutTransformRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
3427 | const TypeReporter& reporter) { |
3428 | TensorType data_type = Downcast<TensorType>(types[0]); |
3429 | const MetaScheduleLayoutTransformAttrs* params = attrs.as<MetaScheduleLayoutTransformAttrs>(); |
3430 | ICHECK(params); |
3431 | Array<PrimExpr> new_shape = params->index_map->MapShape(data_type->shape); |
3432 | reporter->Assign(types[1], TensorType(new_shape, data_type->dtype)); |
3433 | return true; |
3434 | } |
3435 | |
3436 | Expr MakeMetaScheduleLayoutTransform(Expr data, tir::IndexMap index_map) { |
3437 | static const Op& op = Op::Get("meta_schedule_layout_transform" ); |
3438 | auto attrs = make_object<MetaScheduleLayoutTransformAttrs>(); |
3439 | attrs->index_map = index_map; |
3440 | return Call(op, {data}, Attrs(attrs), {}); |
3441 | } |
3442 | |
3443 | TVM_REGISTER_GLOBAL("relay.op._make.meta_schedule_layout_transform" ) |
3444 | .set_body_typed(MakeMetaScheduleLayoutTransform); |
3445 | |
3446 | RELAY_REGISTER_OP("meta_schedule_layout_transform" ) |
3447 | .describe(R"code(Transform the input kernel layout. |
3448 | )code" TVM_ADD_FILELINE) |
3449 | .set_attrs_type<MetaScheduleLayoutTransformAttrs>() |
3450 | .set_num_inputs(1) |
3451 | .add_argument("data" , "Tensor" , "The input tensor." ) |
3452 | .add_type_rel("meta_schedule_layout_transform" , MetaScheduleLayoutTransformRel) |
3453 | .set_support_level(5) |
3454 | .set_attr<FTVMCompute>("FTVMCompute" , MetaScheduleLayoutTransformCompute); |
3455 | |
3456 | // relay._contrib_reverse_reshape |
3457 | Expr MakeReverseReshape(Expr data, Array<Integer> newshape) { |
3458 | auto attrs = make_object<ReshapeAttrs>(); |
3459 | attrs->newshape = std::move(newshape); |
3460 | static const Op& op = Op::Get("contrib_reverse_reshape" ); |
3461 | return Call(op, {data}, Attrs(attrs), {}); |
3462 | } |
3463 | |
3464 | TVM_REGISTER_GLOBAL("relay.op._make.contrib_reverse_reshape" ).set_body_typed(MakeReverseReshape); |
3465 | |
3466 | RELAY_REGISTER_OP("contrib_reverse_reshape" ) |
3467 | .describe(R"code(Reshapes the input array where the special values are inferred from |
3468 | right to left. |
3469 | |
3470 | Example:: |
3471 | |
3472 | The special values have the same semantics as reshape. The difference is that |
3473 | special values are inferred from right to left. It can be explained in the |
3474 | example below:: |
3475 | |
- data.shape = (10,5,4), newshape = (-1,0), reshape results in (40,5)
- data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (50,4)
3478 | |
3479 | )code" TVM_ADD_FILELINE) |
3480 | .set_num_inputs(1) |
3481 | .set_attrs_type<ReshapeAttrs>() |
3482 | .add_argument("data" , "Tensor" , "The input tensor." ) |
3483 | .set_support_level(10) |
3484 | .add_type_rel("ReverseReshape" , ReverseReshapeRel) |
3485 | .set_attr<FTVMCompute>("FTVMCompute" , ReshapeCompute) |
3486 | .set_attr<TOpPattern>("TOpPattern" , kInjective) |
3487 | .set_attr<TReshapeOp>("TReshapeOp" , true); |
3488 | |
3489 | // gather operator |
3490 | TVM_REGISTER_NODE_TYPE(GatherAttrs); |
3491 | |
3492 | bool GatherRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
3493 | const TypeReporter& reporter) { |
3494 | // `types` contains: [data, indices, result] |
3495 | ICHECK_EQ(types.size(), 3); |
3496 | const auto* data = types[0].as<TensorTypeNode>(); |
3497 | const auto* indices = types[1].as<TensorTypeNode>(); |
3498 | if (data == nullptr) { |
3499 | ICHECK(types[0].as<IncompleteTypeNode>()) |
3500 | << "Gather: expect input data type to be TensorType but get " << types[0]; |
3501 | return false; |
3502 | } |
3503 | if (indices == nullptr) { |
3504 | ICHECK(types[1].as<IncompleteTypeNode>()) |
3505 | << "Gather: expect indices type to be TensorType but get " << types[1]; |
3506 | return false; |
3507 | } |
3508 | ICHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer" ; |
3509 | const auto param = attrs.as<GatherAttrs>(); |
3510 | ICHECK(param != nullptr); |
3511 | ICHECK(param->axis.defined()); |
3512 | |
3513 | const auto ndim_data = data->shape.size(); |
3514 | const auto ndim_indices = indices->shape.size(); |
3515 | int axis = param->axis->value; |
3516 | ICHECK_EQ(ndim_data, ndim_indices); |
3517 | if (axis < 0) { |
3518 | axis += ndim_data; |
3519 | } |
3520 | ICHECK_GE(axis, 0); |
3521 | ICHECK_LT(axis, ndim_data); |
3522 | |
3523 | std::vector<IndexExpr> oshape; |
3524 | oshape.reserve(ndim_data); |
3525 | for (size_t i = 0; i < ndim_data; ++i) { |
3526 | if (i == static_cast<size_t>(axis)) { |
3527 | if (indices->shape[i].as<IntImmNode>()) { |
3528 | const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]); |
3529 | ICHECK_GE(*indice_shape_i, 1); |
3530 | } |
3531 | } else { |
3532 | ICHECK(reporter->AssertEQ(indices->shape[i], data->shape[i])); |
3533 | } |
3534 | oshape.emplace_back(indices->shape[i]); |
3535 | } |
3536 | reporter->Assign(types[2], TensorType(oshape, data->dtype)); |
3537 | return true; |
3538 | } |
3539 | |
3540 | Array<te::Tensor> GatherCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
3541 | const Type& out_type) { |
3542 | const auto* param = attrs.as<GatherAttrs>(); |
3543 | return {topi::gather(inputs[0], param->axis.IntValue(), inputs[1])}; |
3544 | } |
3545 | |
3546 | Expr MakeGather(Expr data, Integer axis, Expr indices) { |
3547 | auto attrs = make_object<GatherAttrs>(); |
3548 | attrs->axis = std::move(axis); |
3549 | static const Op& op = Op::Get("gather" ); |
3550 | return Call(op, {data, indices}, Attrs(attrs), {}); |
3551 | } |
3552 | |
3553 | TVM_REGISTER_GLOBAL("relay.op._make.gather" ).set_body_typed(MakeGather); |
3554 | |
3555 | RELAY_REGISTER_OP("gather" ) |
3556 | .describe(R"code(Gather values along given axis from given indices. |
3557 | |
3558 | E.g. for a 3D tensor, output is computed as: |
3559 | |
3560 | out[i][j][k] = data[indices[i][j][k]][j][k] # if axis == 0 |
3561 | out[i][j][k] = data[i][indices[i][j][k]][k] # if axis == 1 |
3562 | out[i][j][k] = data[i][j][indices[i][j][k]] # if axis == 2 |
3563 | |
``indices`` must have the same shape as ``data``, except at dimension ``axis``,
which only needs to be non-empty. The output will have the same shape as ``indices``.
3566 | )code" TVM_ADD_FILELINE) |
3567 | .set_attrs_type<GatherAttrs>() |
3568 | .set_num_inputs(2) |
3569 | .add_argument("data" , "Tensor" , "The input data to the operator." ) |
3570 | .add_argument("indices" , "Tensor" , "The indices of values to gather." ) |
3571 | .set_support_level(3) |
3572 | .add_type_rel("Gather" , GatherRel) |
3573 | .set_attr<FTVMCompute>("FTVMCompute" , GatherCompute) |
3574 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
3575 | |
3576 | TVM_REGISTER_NODE_TYPE(GatherNDAttrs); |
3577 | |
3578 | // gather_nd operator |
3579 | bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
3580 | const TypeReporter& reporter) { |
3581 | // `types` contains: [data, indices, result] |
3582 | ICHECK_EQ(types.size(), 3); |
3583 | const auto* data = types[0].as<TensorTypeNode>(); |
3584 | const auto* indices = types[1].as<TensorTypeNode>(); |
3585 | if (data == nullptr) { |
3586 | ICHECK(types[0].as<IncompleteTypeNode>()) |
3587 | << "GatherND: expect input data type to be TensorType but get " << types[0]; |
3588 | return false; |
3589 | } |
3590 | if (indices == nullptr) { |
3591 | ICHECK(types[1].as<IncompleteTypeNode>()) |
3592 | << "GatherND: expect indices type to be TensorType but get " << types[1]; |
3593 | return false; |
3594 | } |
3595 | const size_t ndim = data->shape.size(); |
3596 | const IntImmNode* mdim = indices->shape[0].as<IntImmNode>(); |
3597 | ICHECK(mdim) << "GatherND needs a static shape for the first axis of indices, got " |
3598 | << indices->shape; |
3599 | const size_t kdim = indices->shape.size() - 1; |
  ICHECK(size_t(mdim->value) <= ndim)
      << "GatherND: the size of the first axis of indices must not exceed the rank of data.";
3601 | |
3602 | const auto param = attrs.as<GatherNDAttrs>(); |
3603 | ICHECK(param != nullptr); |
3604 | |
3605 | for (int i = 0; i < param->batch_dims->value; ++i) { |
3606 | ICHECK(reporter->AssertEQ( |
3607 | data->shape[i], indices->shape[i + 1])); // +1 since the first axis is the index tuple |
3608 | } |
3609 | |
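  // Output shape: the trailing dims of indices (the gathered positions) followed
  // by the data dims that are not covered by the index tuples, i.e.
  // data.shape[mdim + batch_dims:].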
3610 | Array<IndexExpr> oshape; |
3611 | for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]); |
3612 | for (size_t i = mdim->value + param->batch_dims->value; i < ndim; ++i) |
3613 | oshape.push_back(data->shape[i]); |
3614 | reporter->Assign(types[2], TensorType(oshape, data->dtype)); |
3615 | return true; |
3616 | } |
3617 | |
3618 | Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
3619 | const Type& out_type) { |
3620 | const auto* param = attrs.as<GatherNDAttrs>(); |
3621 | ICHECK(param); |
3622 | return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims.IntValue())}; |
3623 | } |
3624 | |
3625 | Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, |
3626 | Optional<Integer> index_rank = NullValue<Integer>()) { |
3627 | static const Op& op = Op::Get("gather_nd" ); |
3628 | auto attrs = make_object<GatherNDAttrs>(); |
3629 | attrs->batch_dims = batch_dims; |
3630 | attrs->index_rank = index_rank; |
3631 | return Call(op, {data, indices}, Attrs(attrs)); |
3632 | } |
3633 | |
3634 | TVM_REGISTER_GLOBAL("relay.op._make.gather_nd" ).set_body_typed(MakeGatherND); |
3635 | |
3636 | RELAY_REGISTER_OP("gather_nd" ) |
3637 | .describe(R"code(Gather elements or slices from data and store to |
3638 | a tensor whose shape is defined by indices. |
3639 | |
3640 | Optionally, batch_dims, the number of batch dimensions, can be given, whose |
3641 | default value is 0. |
3642 | |
3643 | Let B denote batch_dims, and data, indices shape be (X_0, X_1, ..., X_{N-1}), |
3644 | (M, Y_0, ..., Y_{K-1}) respectively. |
3645 | |
3646 | When B > 0, indexing will start from the B-th axis, and it must be the case that |
3647 | X_0, ... X_{B-1} == Y_0, ... Y_{B-1}. The output will have a shape |
3648 | (X_0, ..., X_{B-1}, Y_B, ..., Y_{K-1}, X_{M+B}, ..., X_{N-1}), where M + B <= N. |
3649 | |
3650 | When B == 0 (the default case), the output shape will be (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}). |
3651 | |
3652 | In both cases, if M + B == N, the output shape will simply be (Y_0, ..., Y_{K-1}). |
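
For example, with B == 0::

  data    = [[0, 1], [2, 3]]
  indices = [[1, 1, 0], [0, 1, 1]]

  gather_nd(data, indices) = [2, 3, 1]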
3653 | )code" TVM_ADD_FILELINE) |
3654 | .set_num_inputs(2) |
3655 | .set_attrs_type<GatherNDAttrs>() |
3656 | .add_argument("data" , "Tensor" , "The input tensor." ) |
3657 | .add_argument("indices" , "Tensor" , "The indices of values to gather." ) |
3658 | .set_support_level(3) |
3659 | .add_type_rel("GatherND" , GatherNDRel) |
3660 | .set_attr<FTVMCompute>("FTVMCompute" , GatherNDCompute) |
3661 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
3662 | |
3663 | // relay.sequence_mask |
3664 | TVM_REGISTER_NODE_TYPE(SequenceMaskAttrs); |
3665 | |
3666 | bool SequenceMaskRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
3667 | const TypeReporter& reporter) { |
3668 | // `types` contains: [data, valid_length, result] |
3669 | ICHECK_EQ(types.size(), 3); |
3670 | const auto* data = types[0].as<TensorTypeNode>(); |
3671 | const auto* valid_length = types[1].as<TensorTypeNode>(); |
3672 | ICHECK(data); |
3673 | ICHECK(valid_length); |
3674 | const auto param = attrs.as<SequenceMaskAttrs>(); |
3675 | Array<IndexExpr> valid_length_shape; |
3676 | ICHECK(param->axis == 0 || param->axis == 1); |
3677 | valid_length_shape.push_back(data->shape[1 - param->axis]); |
3678 | reporter->Assign(types[1], TensorType(valid_length_shape, valid_length->dtype)); |
3679 | reporter->Assign(types[2], types[0]); |
3680 | return true; |
3681 | } |
3682 | |
3683 | Array<te::Tensor> SequenceMaskCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
3684 | const Type& out_type) { |
3685 | const auto* param = attrs.as<SequenceMaskAttrs>(); |
3686 | ICHECK(param != nullptr); |
3687 | return Array<te::Tensor>{ |
3688 | topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis)}; |
3689 | } |
3690 | |
3691 | Expr MakeSequenceMask(Expr data, Expr valid_length, double mask_value, int axis) { |
3692 | auto attrs = make_object<SequenceMaskAttrs>(); |
3693 | attrs->mask_value = std::move(mask_value); |
3694 | attrs->axis = std::move(axis); |
3695 | static const Op& op = Op::Get("sequence_mask" ); |
3696 | return Call(op, {data, valid_length}, Attrs(attrs), {}); |
3697 | } |
3698 | |
3699 | TVM_REGISTER_GLOBAL("relay.op._make.sequence_mask" ).set_body_typed(MakeSequenceMask); |
3700 | |
3701 | RELAY_REGISTER_OP("sequence_mask" ) |
3702 | .describe( |
3703 | R"code(Sets all elements outside the expected length of the sequence to a constant value. |
3704 | |
3705 | This function takes an n-dimensional input array of the form [MAX_LENGTH, batch_size, ...] or |
3706 | [batch_size, MAX_LENGTH, ...] and returns an array of the same shape. |
3707 | |
3708 | `axis` means the axis of the length dimension and can only be 0 or 1. If axis is 0, |
3709 | the data must have shape [MAX_LENGTH, batch_size, ...]. Otherwise (axis=1), the data must have |
3710 | shape [batch_size, MAX_LENGTH, ...]. |
3711 | |
`valid_length` gives the length of each sequence. It should be a 1D int array
with positive entries and shape [batch_size,].
3714 | |
3715 | Examples:: |
3716 | |
3717 | x = [[[ 1., 2., 3.], |
3718 | [ 4., 5., 6.]], |
3719 | |
3720 | [[ 7., 8., 9.], |
3721 | [ 10., 11., 12.]], |
3722 | |
3723 | [[ 13., 14., 15.], |
3724 | [ 16., 17., 18.]]] |
3725 | |
3726 | // valid_length [1, 1] means only the first block of each batch will be kept |
3727 | // and other blocks are masked with default mask value = 0 |
3728 | sequence_mask(x, valid_length=[1, 1]) = |
3729 | [[[ 1., 2., 3.], |
3730 | [ 4., 5., 6.]], |
3731 | |
3732 | [[ 0., 0., 0.], |
3733 | [ 0., 0., 0.]], |
3734 | |
3735 | [[ 0., 0., 0.], |
3736 | [ 0., 0., 0.]]] |
3737 | |
3738 | // valid_length [2, 3] means the first 2 blocks of the 1st batch will be kept |
3739 | // and the first 3 blocks of the 2nd batch will be kept |
3740 | // the masked values are set to be the specified mask value = 0.1 |
3741 | sequence_mask(x, valid_length=[2, 3], mask_value=0.1) = |
3742 | [[[ 1., 2., 3.], |
3743 | [ 4., 5., 6.]], |
3744 | |
3745 | [[ 7., 8., 9.], |
3746 | [ 10., 11., 12.]], |
3747 | |
3748 | [[ 0.1, 0.1, 0.1], |
3749 | [ 16., 17., 18.]]] |
3750 | )code" TVM_ADD_FILELINE) |
3751 | .set_attrs_type<SequenceMaskAttrs>() |
3752 | .set_num_inputs(2) |
3753 | .add_argument("data" , "Tensor" , "The input tensor." ) |
3754 | .add_argument("valid_length" , "Tensor" , "The real (valid) length of each sequence." ) |
3755 | .set_support_level(10) |
3756 | .add_type_rel("SequenceMask" , SequenceMaskRel) |
3757 | .set_attr<FTVMCompute>("FTVMCompute" , SequenceMaskCompute) |
3758 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
3759 | |
3760 | // relay.one_hot |
3761 | TVM_REGISTER_NODE_TYPE(OneHotAttrs); |
3762 | |
3763 | bool OneHotRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
3764 | const TypeReporter& reporter) { |
3765 | // `types` contains: [indices, on_value, off_value, result] |
3766 | ICHECK_EQ(types.size(), 4); |
3767 | const auto* indices = types[0].as<TensorTypeNode>(); |
3768 | ICHECK(indices); |
3769 | |
3770 | const auto param = attrs.as<OneHotAttrs>(); |
3771 | ICHECK_GT(param->depth, 0); |
3772 | |
3773 | Array<IndexExpr> oshape; |
3774 | int ndim = indices->shape.size() + 1; |
3775 | int indices_index = 0; |
3776 | int true_axis = (param->axis == -1) ? indices->shape.size() : param->axis; |
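  // The output rank is the indices rank plus one: `depth` is inserted at
  // `true_axis` and the remaining dims are copied from `indices`.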
3777 | for (int i = 0; i < ndim; i++) { |
3778 | if (i == true_axis) { |
3779 | oshape.push_back(Integer(param->depth)); |
3780 | } else { |
3781 | oshape.push_back(indices->shape[indices_index++]); |
3782 | } |
3783 | } |
3784 | |
3785 | reporter->Assign(types[3], TensorType(oshape, param->dtype)); |
3786 | return true; |
3787 | } |
3788 | |
3789 | Array<te::Tensor> OneHotCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
3790 | const Type& out_type) { |
3791 | const auto* param = attrs.as<OneHotAttrs>(); |
3792 | ICHECK(param != nullptr); |
3793 | return Array<te::Tensor>{ |
3794 | topi::one_hot(inputs[0], inputs[1](), inputs[2](), param->depth, param->axis, param->dtype)}; |
3795 | } |
3796 | |
3797 | Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype) { |
3798 | auto attrs = make_object<OneHotAttrs>(); |
3799 | attrs->depth = std::move(depth); |
3800 | attrs->axis = axis; |
3801 | attrs->dtype = dtype; |
3802 | static const Op& op = Op::Get("one_hot" ); |
3803 | return Call(op, {indices, on_value, off_value}, Attrs(attrs), {}); |
3804 | } |
3805 | |
3806 | TVM_REGISTER_GLOBAL("relay.op._make.one_hot" ).set_body_typed(MakeOneHot); |
3807 | |
3808 | RELAY_REGISTER_OP("one_hot" ) |
    .describe(R"code(Returns a one-hot tensor where the locations represented by indices take value ``on_value``,
other locations take value ``off_value``. The output shape is the shape of ``indices`` with an extra dimension of size ``depth`` inserted at ``axis``.
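
For example::

  one_hot([0, 1, 2], on_value=1, off_value=0, depth=3) = [[1, 0, 0],
                                                          [0, 1, 0],
                                                          [0, 0, 1]]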
3811 | |
**indices** Locations to set to ``on_value``.
3813 | |
3814 | **on_value** Value to fill at indices. |
3815 | |
3816 | **off_value** Value to fill at all other positions besides indices. |
3817 | |
3818 | **depth** Depth of the one-hot dimension. |
3819 | |
3820 | **axis** Axis to fill. |
3821 | |
3822 | **dtype**)code" TVM_ADD_FILELINE) |
3823 | .set_attrs_type<OneHotAttrs>() |
3824 | .set_num_inputs(3) |
3825 | .add_argument("indices" , "Tensor" , "Locations to set to on_value." ) |
3826 | .add_argument("on_value" , "Expr" , "Value to fill at indices." ) |
3827 | .add_argument("off_value" , "Expr" , "Value to fill at all other positions besides indices." ) |
3828 | .set_support_level(10) |
3829 | .add_type_rel("OneHot" , OneHotRel) |
3830 | .set_attr<FTVMCompute>("FTVMCompute" , OneHotCompute) |
3831 | .set_attr<TOpPattern>("TOpPattern" , kOutEWiseFusable); |
3832 | |
3833 | /* relay.unravel_index */ |
3834 | bool UnRavelIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
3835 | const TypeReporter& reporter) { |
3836 | ICHECK_EQ(types.size(), 3); |
3837 | |
3838 | const auto* indices = types[0].as<TensorTypeNode>(); |
3839 | if (indices == nullptr) { |
3840 | ICHECK(types[0].as<IncompleteTypeNode>()) |
3841 | << "unravel_index: expect input type to be TensorType but get " << types[0]; |
3842 | return false; |
3843 | } |
3844 | ICHECK(indices->dtype.is_int() || indices->dtype.is_uint()) |
3845 | << "indices of unravel_index must be tensor of integer" ; |
3846 | |
3847 | const auto* shape = types[1].as<TensorTypeNode>(); |
3848 | if (shape == nullptr) { |
3849 | ICHECK(types[1].as<IncompleteTypeNode>()) |
3850 | << "unravel_index: expect input type to be TensorType but get " << types[1]; |
3851 | return false; |
3852 | } |
3853 | ICHECK(shape->dtype.is_int() || shape->dtype.is_uint()) |
3854 | << "shape of unravel_index must be tensor of integer" ; |
3855 | |
3856 | Array<IndexExpr> indices_shape; |
3857 | Array<IndexExpr> shape_shape; |
3858 | indices_shape = indices->shape; |
3859 | shape_shape = shape->shape; |
3860 | |
3861 | Array<IndexExpr> oshape; |
3862 | oshape.push_back(shape_shape[0]); |
3863 | if (indices_shape.size() != 0) { |
3864 | oshape.push_back(indices_shape[0]); |
3865 | } |
3866 | reporter->Assign(types[2], TensorType(oshape, indices->dtype)); |
3867 | return true; |
3868 | } |
3869 | |
3870 | Array<te::Tensor> UnRavelIndexCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
3871 | const Type& out_type) { |
3872 | return Array<te::Tensor>{topi::unravel_index(inputs[0], inputs[1])}; |
3873 | } |
3874 | |
3875 | Expr MakeUnRavelIndex(Expr data, Expr shape) { |
3876 | static const Op& op = Op::Get("unravel_index" ); |
3877 | return Call(op, {data, shape}, Attrs(), {}); |
3878 | } |
3879 | |
3880 | TVM_REGISTER_GLOBAL("relay.op._make.unravel_index" ).set_body_typed(MakeUnRavelIndex); |
3881 | |
3882 | RELAY_REGISTER_OP("unravel_index" ) |
3883 | .describe( |
3884 | R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays. |
3885 | |
3886 | Example:: |
3887 | - unravel_index([22, 41, 37], (7, 6)) = [[3, 6, 6], [4, 5, 1]] |
3888 | )code" TVM_ADD_FILELINE) |
3889 | .set_num_inputs(2) |
3890 | .add_argument("data" , "Tensor" , "The input tensor." ) |
3891 | .add_argument("shape" , "Tensor" , "The shape tensor." ) |
3892 | .set_support_level(3) |
3893 | .add_type_rel("UnRavelIndexRel" , UnRavelIndexRel) |
3894 | .set_attr<FTVMCompute>("FTVMCompute" , UnRavelIndexCompute) |
3895 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
3896 | |
3897 | // sparse_to_dense |
3898 | TVM_REGISTER_NODE_TYPE(SparseToDenseAttrs); |
3899 | |
3900 | bool SparseToDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
3901 | const TypeReporter& reporter) { |
3902 | ICHECK_EQ(num_inputs, 3); |
3903 | auto sparse_indices = types[0].as<TensorTypeNode>(); |
3904 | auto sparse_values = types[1].as<TensorTypeNode>(); |
3905 | auto default_value = types[2].as<TensorTypeNode>(); |
3906 | |
3907 | if (sparse_indices == nullptr || sparse_values == nullptr || default_value == nullptr) { |
3908 | return false; |
3909 | } |
3910 | |
3911 | ICHECK(sparse_indices->dtype.is_int()) << "sparse_indices must be tensor of integers" ; |
3912 | |
  ICHECK_LE(sparse_indices->shape.size(), 3)
      << "sparse_indices must be a 0D, 1D, or 2D tensor";

  ICHECK_LE(sparse_values->shape.size(), 2) << "sparse_values must be a 0D or 1D tensor";
3917 | |
3918 | ICHECK_EQ(default_value->shape.size(), 0) << "default_value should be a scalar" ; |
3919 | |
3920 | const auto* param = attrs.as<SparseToDenseAttrs>(); |
3921 | ICHECK(param != nullptr); |
3922 | |
3923 | Array<IndexExpr> oshape; |
3924 | for (auto i : param->output_shape) { |
3925 | oshape.push_back(i); |
3926 | } |
3927 | reporter->Assign(types[3], TensorType(oshape, sparse_values->dtype)); |
3928 | return true; |
3929 | } |
3930 | |
3931 | Array<te::Tensor> SparseToDenseCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, |
3932 | const Type& out_type) { |
3933 | ICHECK_EQ(inputs.size(), 3); |
3934 | const auto* param = attrs.as<SparseToDenseAttrs>(); |
3935 | ICHECK(param != nullptr); |
3936 | Array<IndexExpr> output_shape; |
3937 | for (auto val : param->output_shape) { |
3938 | output_shape.push_back(val); |
3939 | } |
3940 | return {topi::sparse_to_dense(inputs[0], output_shape, inputs[1], inputs[2]())}; |
3941 | } |
3942 | |
3943 | Expr MakeSparseToDense(Expr indices, Array<Integer> output_shape, Expr values, Expr default_value) { |
3944 | auto attrs = make_object<SparseToDenseAttrs>(); |
3945 | attrs->output_shape = std::move(output_shape); |
3946 | static const Op& op = Op::Get("sparse_to_dense" ); |
3947 | return Call(op, {indices, values, default_value}, Attrs(attrs)); |
3948 | } |
3949 | |
3950 | TVM_REGISTER_GLOBAL("relay.op._make.sparse_to_dense" ).set_body_typed(MakeSparseToDense); |
3951 | |
3952 | RELAY_REGISTER_OP("sparse_to_dense" ) |
3953 | .describe(R"code(A dense tensor from a sparse representation. |
3954 | |
3955 | - **sparse_indices**: A 0-D, 1-D, or 2-D tensor of integers containing location of sparse values |
3956 | |
3957 | - **output_shape**: A list of integers. Shape of the dense output tensor. |
3958 | |
3959 | - **sparse_values**: A 0-D or 1-D tensor containing the sparse values for the sparse indices. |
3960 | |
3961 | - **default_value**: A 0-D tensor containing the default value for the remaining locations. Defaults to 0. |
3962 | |
3963 | Example:: |
- sparse_to_dense([[0, 0], [1, 2]], [3, 4], [1, 2], 0) = [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]]
3965 | |
3966 | )code" TVM_ADD_FILELINE) |
3967 | .set_num_inputs(3) |
3968 | .set_support_level(3) |
3969 | .set_attrs_type<SparseToDenseAttrs>() |
3970 | .add_argument("sparse_indices" , "Tensor" , "Contains sparse indices." ) |
3971 | .add_argument("sparse_values" , "Tensor" , "Contains values for sparse indices." ) |
3972 | .add_argument("default_value" , "Tensor" , "Value to set for non-sparse indices. Defaults to 0." ) |
3973 | .add_type_rel("SparseToDense" , SparseToDenseRel) |
3974 | .set_attr<TOpIsStateful>("TOpIsStateful" , false) |
3975 | .set_attr<TOpPattern>("TOpPattern" , kOpaque) |
3976 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , ElemwiseArbitraryLayout) |
3977 | .set_attr<FTVMCompute>("FTVMCompute" , SparseToDenseCompute); |
3978 | |
3979 | // relay.matrix_set_diag |
3980 | TVM_REGISTER_NODE_TYPE(MatrixSetDiagAttrs); |
3981 | |
3982 | bool MatrixSetDiagRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
3983 | const TypeReporter& reporter) { |
3984 | // `types` contains: [input, diagonal, result] |
3985 | ICHECK_EQ(types.size(), 3); |
3986 | |
3987 | const auto* input = types[0].as<TensorTypeNode>(); |
3988 | ICHECK(input); |
3989 | |
3990 | const auto* diagonal = types[1].as<TensorTypeNode>(); |
3991 | ICHECK(diagonal); |
3992 | |
3993 | const auto param = attrs.as<MatrixSetDiagAttrs>(); |
3994 | ICHECK_GE(param->k2, param->k1); |
3995 | |
3996 | int d_ndims = diagonal->shape.size(); |
3997 | int i_ndims = input->shape.size(); |
3998 | |
3999 | reporter->Assert(input->shape[i_ndims - 2] > -param->k1); |
4000 | reporter->Assert(input->shape[i_ndims - 1] > param->k2); |
4001 | |
4002 | for (int i = 0; i < d_ndims - 2; i++) { |
4003 | reporter->AssertEQ(input->shape[i], diagonal->shape[i]); |
4004 | } |
4005 | if (param->k1 != param->k2) { |
4006 | reporter->AssertEQ(diagonal->shape[d_ndims - 2], param->k2 - param->k1 + 1); |
4007 | } else if (d_ndims >= 2) { |
4008 | reporter->AssertEQ(input->shape[d_ndims - 2], diagonal->shape[d_ndims - 2]); |
4009 | } |
  auto max_diag_len = if_then_else(input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0) <=
                                       input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0),
                                   input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0),
                                   input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0));
  reporter->AssertEQ(diagonal->shape[d_ndims - 1], max_diag_len);

  reporter->Assign(types[2], TensorType(input->shape, input->dtype));
  return true;
}

Array<te::Tensor> MatrixSetDiagCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                       const Type& out_type) {
  const auto* param = attrs.as<MatrixSetDiagAttrs>();
  ICHECK(param != nullptr);
  return Array<te::Tensor>{topi::matrix_set_diag(inputs[0], inputs[1], param->k1, param->k2,
                                                 param->super_diag_right_align,
                                                 param->sub_diag_right_align)};
}

Expr MakeMatrixSetDiag(Expr input, Expr diagonal, int k1, int k2, bool super_diag_right_align,
                       bool sub_diag_right_align) {
  auto attrs = make_object<MatrixSetDiagAttrs>();
  attrs->k1 = k1;
  attrs->k2 = k2;
  attrs->super_diag_right_align = super_diag_right_align;
  attrs->sub_diag_right_align = sub_diag_right_align;
  static const Op& op = Op::Get("matrix_set_diag");
  return Call(op, {input, diagonal}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.matrix_set_diag").set_body_typed(MakeMatrixSetDiag);

RELAY_REGISTER_OP("matrix_set_diag")
    .describe(
        R"code(Returns a tensor with the diagonals of the input tensor replaced with the provided diagonal values.
**input** Input tensor.
**diagonal** Values to be filled in the diagonal.
**k1** Lower limit (included) of the range of diagonals.
**k2** Upper limit (included) of the range of diagonals.
**super_diag_right_align** Bool, true iff the super-diagonal is right aligned (left-padded).
**sub_diag_right_align** Bool, true iff the sub-diagonal is right aligned (left-padded).
)code" TVM_ADD_FILELINE)
    .set_attrs_type<MatrixSetDiagAttrs>()
    .set_num_inputs(2)
    .add_argument("input", "Tensor", "Input Tensor.")
    .add_argument("diagonal", "Tensor", "Values to be filled in the diagonal.")
    .set_support_level(10)
    .add_type_rel("MatrixSetDiag", MatrixSetDiagRel)
    .set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
    .set_attr<TOpPattern>("TOpPattern", kInjective);

// adv_index
bool AdvIndexRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
  ICHECK_EQ(num_inputs, 1);
  // Check the tuple type before dereferencing it.
  auto inputs = types[0].as<TupleTypeNode>();
  if (inputs == nullptr) {
    return false;
  }
  auto data = inputs->fields[0].as<TensorTypeNode>();
  if (data == nullptr) {
    return false;
  }
  ICHECK_LE(inputs->fields.size() - 1, data->shape.size()) << "too many indices for data!";

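  // The output shape is the broadcast of all index shapes followed by the data dimensions that
  // are not indexed. Illustrative example (not from the source): data of shape (3, 4, 5) indexed
  // by two tensors of shape (2,) gives an output of shape (2, 5).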
  Array<IndexExpr> oshape;
  TensorType broadcast_type = Downcast<TensorType>(inputs->fields[1]);
  for (size_t i = 2; i < inputs->fields.size(); ++i) {
    broadcast_type =
        ConcreteBroadcast(broadcast_type, Downcast<TensorType>(inputs->fields[i]), data->dtype);
  }

  for (const auto& dim : broadcast_type->shape) {
    oshape.push_back(dim);
  }
  for (size_t i = inputs->fields.size() - 1; i < data->shape.size(); ++i) {
    oshape.push_back(data->shape[i]);
  }
  reporter->Assign(types[1], TensorType(oshape, data->dtype));
  return true;
}

Array<te::Tensor> AdvIndexCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                  const Type& out_type) {
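  // inputs[0] is the data tensor; the remaining fields of the tuple are the index tensors.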
  Array<te::Tensor> indices;
  for (size_t i = 1; i < inputs.size(); ++i) {
    indices.push_back(inputs[i]);
  }
  return {topi::adv_index(inputs[0], indices)};
}

Expr MakeAdvIndex(Expr inputs) {
  static const Op& op = Op::Get("adv_index");
  return Call(op, {inputs}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.adv_index").set_body_typed(MakeAdvIndex);

RELAY_REGISTER_OP("adv_index")
    .describe(R"code(NumPy-style advanced indexing. Index with a list of tensors.
)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .set_support_level(3)
    .add_argument("inputs", "Tuple of Tensors", "Input tensor and indices.")
    .add_type_rel("AdvIndex", AdvIndexRel)
    .set_attr<TOpIsStateful>("TOpIsStateful", false)
    .set_attr<TOpPattern>("TOpPattern", kInjective)
    .set_attr<FTVMCompute>("FTVMCompute", AdvIndexCompute);

TVM_REGISTER_NODE_TYPE(ScanopAttrs);

bool ScanopRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
  // types: [data, output]
  ICHECK_EQ(types.size(), 2) << "Expects two types, one for the input and another for the output";
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    ICHECK(types[0].as<IncompleteTypeNode>())
        << "Scanop: expect input type to be TensorType but get " << types[0];
    return false;
  }

  const auto* param = attrs.as<ScanopAttrs>();

  auto dtype = param->dtype;
  if (dtype.is_void()) {
    dtype = data->dtype;
  }

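  // When `axis` is undefined the scan runs over the flattened input (NumPy semantics), so e.g. a
  // (2, 3) input produces a (6,) output.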
  if (param->axis.defined()) {
    reporter->Assign(types[1], TensorType(data->shape, dtype));
  } else {
    auto prod = data->shape[0];
    for (size_t i = 1; i < data->shape.size(); ++i) {
      prod = prod * data->shape[i];
    }
    reporter->Assign(types[1], TensorType({prod}, dtype));
  }

  return true;
}

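// Illustrative semantics (assuming NumPy/TF-style scans, not taken from this file):
// cumsum([1, 2, 3]) = [1, 3, 6]; with exclusive=True the current element is excluded,
// giving [0, 1, 3].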
Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Bool exclusive) {
  auto attrs = make_object<ScanopAttrs>();
  attrs->dtype = dtype;
  attrs->axis = axis;
  attrs->exclusive = exclusive;
  static const Op& op = Op::Get("cumsum");
  return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.cumsum").set_body_typed(MakeCumsum);

RELAY_REGISTER_OP("cumsum")
    .describe(
        R"doc(Return the cumulative sum of the elements along a given axis.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_support_level(3)
    .add_type_rel("Cumsum", ScanopRel)
    .set_attr<TOpPattern>("TOpPattern", kOpaque);

Expr MakeCumprod(Expr data, Integer axis, DataType dtype, Bool exclusive) {
  auto attrs = make_object<ScanopAttrs>();
  attrs->dtype = dtype;
  attrs->axis = axis;
  attrs->exclusive = exclusive;
  static const Op& op = Op::Get("cumprod");
  return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.cumprod").set_body_typed(MakeCumprod);

RELAY_REGISTER_OP("cumprod")
    .describe(
        R"doc(Return the cumulative product of the elements along a given axis.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_support_level(3)
    .add_type_rel("Cumprod", ScanopRel)
    .set_attr<TOpPattern>("TOpPattern", kOpaque);

TVM_REGISTER_NODE_TYPE(UniqueAttrs);

bool UniqueRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
  // types: [data, result]
  ICHECK_EQ(types.size(), 2) << "Unique: expect 2 types but " << types.size() << " provided";
  ICHECK_EQ(num_inputs, 1) << "Unique: expect 1 input but " << num_inputs << " provided";
  auto data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    ICHECK(types[0].as<IncompleteTypeNode>())
        << "Unique: expect input type to be TensorType but get " << types[0];
    return false;
  }
  const int ndim = static_cast<int>(data->shape.size());
  ICHECK_EQ(ndim, 1) << "Unique: input must be a 1-D tensor";

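  // The result is a tuple of tensors, each padded to the input length. Illustrative example
  // (assumed semantics, not taken from this file), with sorted=false:
  //   unique([4, 5, 1, 2, 3, 3, 4, 5]) ->
  //     unique          = [4, 5, 1, 2, 3, _, _, _]
  //     indices         = [0, 1, 2, 3, 4, _, _, _]
  //     inverse_indices = [0, 1, 2, 3, 4, 4, 0, 1]
  //     num_unique      = [5]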
  std::vector<Type> fields;
  fields.push_back(TensorType(data->shape, data->dtype));               // unique
  fields.push_back(TensorType(data->shape, DataType::Int(32)));         // indices
  fields.push_back(TensorType(data->shape, DataType::Int(32)));         // inverse_indices
  fields.push_back(TensorType(Array<PrimExpr>{1}, DataType::Int(32)));  // num_unique
  const auto* param = attrs.as<UniqueAttrs>();
  if (param->return_counts) {
    fields.push_back(TensorType(data->shape, DataType::Int(32)));  // counts
  }
  reporter->Assign(types[1], TupleType(Array<Type>(fields)));
  return true;
}

Expr MakeUnique(Expr data, bool sorted, bool return_counts) {
  auto attrs = make_object<UniqueAttrs>();
  attrs->sorted = sorted;
  attrs->return_counts = return_counts;
  static const Op& op = Op::Get("unique");
  return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.unique").set_body_typed(MakeUnique);

RELAY_REGISTER_OP("unique")
    .describe(
        R"code(This operation returns the unique elements and the new index of each item in a given 1-D array.
)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor")
    .add_type_rel("unique", UniqueRel)
    .set_support_level(3)
    .set_attr<TOpPattern>("TOpPattern", kOpaque);

// invert_permutation
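// For a permutation `x`, the result `y` satisfies y[x[i]] = i; e.g. (illustrative)
// invert_permutation([2, 0, 1]) = [1, 2, 0].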
Expr MakeInvertPermutation(Expr data) {
  static const Op& op = Op::Get("invert_permutation");
  return Call(op, {data}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.invert_permutation").set_body_typed(MakeInvertPermutation);

RELAY_REGISTER_OP("invert_permutation")
    .describe(R"doc(Computes the inverse permutation of a tensor.)doc" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_type_rel("Identity", IdentityRel)
    .set_support_level(1)
    .set_attr<TOpPattern>("TOpPattern", kInjective)
    .set_attr<TOpIsStateful>("TOpIsStateful", false);

// Trilu
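// trilu keeps the upper (upper = true) or lower (upper = false) triangle of the last two
// dimensions, relative to the k-th diagonal, and zeroes out the rest; e.g. (illustrative)
// trilu([[1, 2], [3, 4]], k = 0, upper = true) = [[1, 2], [0, 4]].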

TVM_REGISTER_NODE_TYPE(TriluAttrs);

bool TriluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
              const TypeReporter& reporter) {
  // types: [data, k, result]
  ICHECK_EQ(types.size(), 3) << "Trilu: expect 3 types but " << types.size() << " provided";
  ICHECK_EQ(num_inputs, 2) << "Trilu: expect 2 inputs but " << num_inputs << " provided";
  auto data = types[0].as<TensorTypeNode>();
  if (data == nullptr) {
    ICHECK(types[0].as<IncompleteTypeNode>())
        << "Trilu: expect input type to be TensorType but get " << types[0];
    return false;
  }

  auto k = types[1].as<TensorTypeNode>();
  if (k == nullptr) {
    ICHECK(types[1].as<IncompleteTypeNode>())
        << "Trilu: expect k type to be TensorType but get " << types[1];
    return false;
  }

  ICHECK(k->shape.size() == 0) << "Trilu: k must be a 0-D tensor but get " << k;

  // Output shape is the same as input shape.
  reporter->Assign(types[2], TensorType(data->shape, data->dtype));
  return true;
}

Expr MakeTrilu(Expr data, Expr k, bool upper) {
  auto attrs = make_object<TriluAttrs>();
  attrs->upper = upper;
  static const Op& op = Op::Get("trilu");
  return Call(op, {data, k}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.trilu").set_body_typed(MakeTrilu);

RELAY_REGISTER_OP("trilu")
    .describe(
        R"code(Filters out the upper or lower portion of an input tensor on one side of a diagonal.
)code" TVM_ADD_FILELINE)
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor")
    .add_argument("k", "Tensor", "The number of diagonals above or below the main to exclude.")
    .add_type_rel("trilu", TriluRel)
    .set_support_level(3)
    .set_attr<TOpPattern>("TOpPattern", kElemWise);

// FixedPointMultiplyPerAxis

TVM_REGISTER_NODE_TYPE(FixedPointMultiplyPerAxisAttrs);

bool FixedPointMultiplyPerAxisRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                                  const TypeReporter& reporter) {
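  // types: [data, fp_multiplier, left_shift, right_shift, result]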
  ICHECK_EQ(types.size(), 5) << "FixedPointMultiplyPerAxis: expect 5 types but " << types.size()
                             << " provided";
  ICHECK_EQ(num_inputs, 4) << "FixedPointMultiplyPerAxis: expect 4 inputs but " << num_inputs
                           << " provided";

  for (int i = 0; i < num_inputs; i++) {
    auto data = types[i].as<TensorTypeNode>();
    if (data == nullptr) {
      ICHECK(types[i].as<IncompleteTypeNode>())
          << "FixedPointMultiplyPerAxis: expect input type to be TensorType but get " << types[i];
      return false;
    }
  }

  return IdentityRel({types[0], types[4]}, 1, attrs, reporter);
}

InferCorrectLayoutOutput FixedPointMultiplyPerAxisInferCorrectLayout(
    const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
    const Array<tvm::relay::Type>& old_in_types) {
  const auto* attrs_ptr = attrs.as<FixedPointMultiplyPerAxisAttrs>();
  ICHECK(attrs_ptr);
  ObjectPtr<FixedPointMultiplyPerAxisAttrs> param =
      make_object<FixedPointMultiplyPerAxisAttrs>(*attrs_ptr);

  Array<Array<IndexExpr>> old_in_shapes;
  for (auto old_in_t : old_in_types) {
    ICHECK(old_in_t.as<TensorTypeNode>());
    old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
  }

  Array<Layout> input_layouts, output_layouts;

  if (new_in_layouts.defined()) {
    const Layout& new_layout = new_in_layouts[0];
    const Layout& old_layout = old_in_layouts[0];

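    // Translate `axes` from the old layout to the new one and build the layout of the
    // per-channel operands. Illustrative example (assuming the common NCHW -> NCHW4c case):
    // with axes = {1}, the per-channel operands get layout "C4c" and axes becomes {1, 4}.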
    std::unordered_set<std::string> old_dims;
    for (auto axis : param->axes) {
      ICHECK_GE(axis->value, 0) << "Axis out of bounds in FixedPointMultiplyPerAxis operator.";
      ICHECK_LT(axis->value, old_in_shapes[0].size())
          << "Axis out of bounds in FixedPointMultiplyPerAxis operator.";
      old_dims.emplace(old_layout[axis->value].name());
    }

    Array<tvm::Integer> new_axes;
    std::string new_layout_string = "";
    for (size_t axis_index = 0; axis_index < new_layout->axes.size(); ++axis_index) {
      const auto& layout_axis = LayoutAxis::Get(new_layout->axes[axis_index]);
      const std::string& layout_dim = layout_axis.name();
      if (layout_axis.IsPrimal()) {
        if (old_dims.count(layout_dim)) {
          new_axes.push_back(tvm::Integer(axis_index));
          new_layout_string += layout_dim;
        }
      } else {
        auto primal_dim = layout_axis.ToPrimal().name();
        if (old_dims.count(primal_dim)) {
          new_axes.push_back(tvm::Integer(axis_index));
          new_layout_string += std::to_string(new_layout.FactorOf(layout_axis)) + layout_dim;
        }
      }
    }

    Layout channel_layout = Layout(new_layout_string);

    input_layouts = {new_layout, channel_layout, channel_layout, channel_layout};
    output_layouts = {new_layout};
    param->axes = std::move(new_axes);
  } else if (old_in_layouts.defined()) {
    ICHECK_EQ(old_in_layouts.size(), 4);
    ICHECK_EQ(param->axes.size(), 1);  // Not tested for other cases
    const Layout& old_layout = old_in_layouts[0];
    if (old_layout.defined()) {
      std::string layout_string = old_layout[param->axes[0]->value].name();
      Layout channel_layout = Layout(layout_string);

      input_layouts = {old_layout, channel_layout, channel_layout, channel_layout};
      output_layouts = {old_layout};
    } else {
      // Set the layouts to undef.
      Layout undef = Layout::Undef();
      input_layouts = Array<Layout>(4, undef);
      output_layouts = {undef};
    }
  } else {
    // Set the layouts to undef.
    Layout undef = Layout::Undef();
    input_layouts = Array<Layout>(4, undef);
    output_layouts = {undef};
  }

  return InferCorrectLayoutOutput(input_layouts, output_layouts, Attrs(param));
}

Expr MakeFixedPointMultiplyPerAxis(Expr x, Expr m, Expr lshift, Expr rshift,
                                   bool is_lshift_required, bool is_rshift_required,
                                   Array<Integer> axes) {
  auto attrs = make_object<FixedPointMultiplyPerAxisAttrs>();
  attrs->is_lshift_required = is_lshift_required;
  attrs->is_rshift_required = is_rshift_required;
  attrs->axes = std::move(axes);
  static const Op& op = Op::Get("fixed_point_multiply_per_axis");
  return Call(op, {x, m, lshift, rshift}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.fixed_point_multiply_per_axis")
    .set_body_typed(MakeFixedPointMultiplyPerAxis);

RELAY_REGISTER_OP("fixed_point_multiply_per_axis")
    .describe(R"code(Per-channel fixed point multiplication.)code" TVM_ADD_FILELINE)
    .set_num_inputs(4)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("fp_multiplier", "Tensor", "The multipliers tensor.")
    .add_argument("left_shift", "Tensor", "The left shifts tensor.")
    .add_argument("right_shift", "Tensor", "The right shifts tensor.")
    .add_type_rel("FixedPointMultiplyPerAxis", FixedPointMultiplyPerAxisRel)
    .set_attr<TOpPattern>("TOpPattern", kBroadcast)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
                                   FixedPointMultiplyPerAxisInferCorrectLayout)
    .set_attrs_type<FixedPointMultiplyPerAxisAttrs>()
    .set_support_level(10);

}  // namespace relay
}  // namespace tvm