1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file tvm/relay/transforms/pattern_utils.h
23 * \brief Header of internal operator functions
24 * These can be used for writing passes.
25 */
26#ifndef TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_
27#define TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_
28
29#include <builtin_fp16.h>
30#include <tvm/node/structural_equal.h>
31#include <tvm/relay/analysis.h>
32#include <tvm/relay/attrs/nn.h>
33#include <tvm/relay/attrs/reduce.h>
34#include <tvm/relay/attrs/transform.h>
35#include <tvm/relay/expr.h>
36#include <tvm/relay/op.h>
37#include <tvm/relay/op_attr_types.h>
38#include <tvm/runtime/registry.h>
39#include <tvm/tir/data_layout.h>
40
41#include <limits>
42#include <optional>
43#include <string>
44#include <utility>
45#include <vector>
46
47#include "../backend/utils.h"
48#include "../op/make_op.h"
49
50namespace tvm {
51namespace relay {
52
53/*!
54 * \brief Dispatch DataType to the C++ data type
55 * during runtime.
56 */
57#define TVM_DTYPE_DISPATCH(type, DType, ...) \
58 if (type == DataType::Float(64)) { \
59 typedef double DType; \
60 { __VA_ARGS__ } \
61 } else if (type == DataType::Float(32)) { \
62 typedef float DType; \
63 { __VA_ARGS__ } \
64 } else if (type == DataType::Float(16)) { \
65 typedef uint16_t DType; \
66 { __VA_ARGS__ } \
67 } else if (type == DataType::BFloat(16)) { \
68 typedef uint16_t DType; \
69 { __VA_ARGS__ } \
70 } else if (type == DataType::Int(64)) { \
71 typedef int64_t DType; \
72 { __VA_ARGS__ } \
73 } else if (type == DataType::Int(32)) { \
74 typedef int32_t DType; \
75 { __VA_ARGS__ } \
76 } else if (type == DataType::Int(16)) { \
77 typedef int16_t DType; \
78 { __VA_ARGS__ } \
79 } else if (type == DataType::Int(8)) { \
80 typedef int8_t DType; \
81 { __VA_ARGS__ } \
82 } else if (type == DataType::UInt(64)) { \
83 typedef uint64_t DType; \
84 { __VA_ARGS__ } \
85 } else if (type == DataType::UInt(32)) { \
86 typedef uint32_t DType; \
87 { __VA_ARGS__ } \
88 } else if (type == DataType::UInt(16)) { \
89 typedef uint16_t DType; \
90 { __VA_ARGS__ } \
91 } else if (type == DataType::UInt(8)) { \
92 typedef uint8_t DType; \
93 { __VA_ARGS__ } \
94 } else if (type == DataType::Bool()) { \
95 typedef bool DType; \
96 { __VA_ARGS__ } \
97 } else if ((*tvm::runtime::Registry::Get("runtime._datatype_get_type_registered"))( \
98 static_cast<uint8_t>(type.code()))) { \
99 typedef double DType; \
100 { __VA_ARGS__ } \
101 } else { \
102 LOG(FATAL) << "unknown data type " << type; \
103 }
104
105/*!
106 * \brief Try to do the type inference over expr:
107 *
108 * Do the infer_type over each node in expr
109 *
110 * \param expr The IR expression
111 * \return infered expr if succeed.
112 */
113inline Expr InferType(const Expr& expr) {
114 auto mod = IRModule::FromExpr(expr);
115 mod = transform::InferType()(mod);
116 if (expr.as<FunctionNode>()) {
117 return mod->Lookup("main");
118 } else {
119 return mod->Lookup("main").as<FunctionNode>()->body;
120 }
121}
122
123/*!
124 * \brief Try to match lhs and rhs via broadcasting rule, such that:
125 *
126 * rhs matches the dimension of lhs specified by lhs_axes
127 * rhs's value equals 1 on rest of dimensions.
128 *
129 * \param tlhs The type of left operand (data)
130 * \param trhs The type right operand (bias)
131 * \param lhs_axes The axes on lhs to match.
132 * \param rhs_value A squeezed version of rhs which only contains matched dimension.
133 * \return Whether match is successful.
134 */
135inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, const TensorTypeNode* trhs,
136 const Array<Integer>& lhs_axes, Expr* rhs_value = nullptr) {
137 if (tlhs->shape.size() < trhs->shape.size()) return false;
138 StructuralEqual equal;
139 size_t base = tlhs->shape.size() - trhs->shape.size();
140 size_t j = 0;
141
142 // handle case trhs is simple constant
143 if (trhs->shape.size() == 0 && rhs_value != nullptr && lhs_axes.size() > 0) {
144 *rhs_value = MakeExpandDims(*rhs_value, 0, lhs_axes.size());
145 for (size_t i = 0; i < lhs_axes.size(); i++) {
146 int repeat_value =
147 tlhs->shape[static_cast<size_t>(lhs_axes[j]->value)].as<IntImmNode>()->value;
148 *rhs_value = MakeRepeat(*rhs_value, repeat_value, i);
149 }
150 return true;
151 }
152
153 ObjectPtr<SqueezeAttrs> squeeze_attrs;
154 if (rhs_value != nullptr) {
155 squeeze_attrs = make_object<SqueezeAttrs>();
156 }
157
158 for (size_t i = 0; i < tlhs->shape.size(); ++i) {
159 if (j < lhs_axes.size() && i == static_cast<size_t>(lhs_axes[j]->value)) {
160 if (i < base || !equal(tlhs->shape[i], trhs->shape[i - base])) {
161 return false;
162 }
163 ++j;
164 } else if (i >= base) {
165 if (!tir::is_const_int(trhs->shape[i - base], 1)) {
166 return false;
167 }
168 if (rhs_value != nullptr) {
169 squeeze_attrs->axis.push_back(static_cast<int>(i - base));
170 }
171 }
172 }
173 if (rhs_value != nullptr && squeeze_attrs->axis.size() != 0) {
174 static const Op& squeeze_op = Op::Get("squeeze");
175 *rhs_value = Call(squeeze_op, {rhs_value[0]}, Attrs(squeeze_attrs), {});
176 }
177 return true;
178}
179
180/*!
181 * \brief Expand 1D Tensor to match axis.
182 *
183 * The result bias can be used to add or multiply to
184 * the target Tensor on the specified axis via broadcasting rule.
185 *
186 * \param bias The bias.
187 * \param target_ndim Target dimension.
188 * \param axes The axis on the output we want to match on.
189 */
190inline Expr ExpandBiasToMatchAxis(Expr bias, int target_ndim, const Array<Integer>& axes) {
191 static const Op& expand_dims = Op::Get("expand_dims");
192 for (size_t i = axes.size(); i != 0; --i) {
193 if (i == axes.size()) {
194 int64_t num_pad_axis = target_ndim - axes[i - 1]->value - 1;
195 if (num_pad_axis > 0) {
196 auto attrs = make_object<ExpandDimsAttrs>();
197 attrs->axis = i;
198 attrs->num_newaxis = static_cast<int>(num_pad_axis);
199 bias = Call(expand_dims, {bias}, Attrs(attrs), {});
200 }
201 } else {
202 int64_t diff = axes[i]->value - axes[i - 1]->value;
203 ICHECK_GE(diff, 0L);
204 if (diff > 0) {
205 auto attrs = make_object<ExpandDimsAttrs>();
206 attrs->axis = i;
207 attrs->num_newaxis = static_cast<int>(diff);
208 bias = Call(expand_dims, {bias}, Attrs(attrs), {});
209 }
210 }
211 }
212 return bias;
213}
214
215/*!
216 * \brief Check if the call is depthwise conv3d.
217 *
218 * \param call The conv call.
219 * \param param The conv attributes.
220 * \return Whether it is depthwise_conv3d.
221 */
222template <typename ATTRS>
223inline bool IsDepthwiseConv(const Call& call, ATTRS param, const Layout& kernel_layout) {
224 static const Layout kOIXX =
225 backend::IsOp(call.as<CallNode>(), "nn.conv2d") ? Layout("OIHW") : Layout("OIDHW");
226 const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIXX);
227 auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
228 return tir::is_const_int(wshape[0], param->groups) && tir::is_const_int(wshape[1], 1);
229}
230
231/*!
232 * \brief Get super-dimension of output channels of conv2d
233 * \param call The conv2d call.
234 * \return Super-dimension size of output channels of conv2d.
235 */
236inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
237 auto param = call->attrs.as<Conv2DAttrs>();
238 auto tweight = call->args[1]->type_as<TensorTypeNode>();
239 auto index = param->kernel_layout.operator std::string().find('O');
240 ICHECK_NE(index, std::string::npos);
241 auto channels = tir::as_const_int(tweight->shape[index]);
242 return *channels;
243}
244
245/*!
246 * \brief Is single value tensor (scalar).
247 * \param expr The expr.
248 * \return True if single value tensor.
249 */
250inline bool IsScalar(const Expr& expr) {
251 if (auto tensor_type = expr->checked_type().as<TensorTypeNode>()) {
252 for (auto dim_index_expr : tensor_type->shape) {
253 if (auto dim_index = dim_index_expr.as<IntImmNode>()) {
254 if (dim_index->value != 1) {
255 return false;
256 }
257 } else {
258 return false;
259 }
260 }
261 } else {
262 return false;
263 }
264 return true;
265}
266
267/*!
268 * \brief Check if expr is a const scalar.
269 * \param expr The expr.
270 * \return True if const scalar.
271 */
272inline bool IsConstScalar(const Expr& expr) {
273 const auto* const_expr = expr.as<ConstantNode>();
274 if (const_expr) {
275 return const_expr->is_scalar();
276 }
277 return false;
278}
279
280/*!
281 * \brief Create a Constant with a scalar
282 *
283 * \param dtype The data type.
284 * \param value The value of the scalar.
285 * \return A Constant.
286 */
287template <typename T>
288inline Constant MakeConstantScalar(DataType dtype, T value) {
289 runtime::NDArray arr = runtime::NDArray::Empty({}, dtype, {kDLCPU, 0});
290 TVM_DTYPE_DISPATCH(dtype, DType, {
291 if (dtype == DataType::Float(16)) {
292 // convert to float16
293 // storage is uint16_t
294 *static_cast<DType*>(arr->data) =
295 __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
296 } else if (dtype == DataType::BFloat(16)) {
297 // convert to bfloat16
298 // storage is uint16_t
299 *static_cast<DType*>(arr->data) =
300 __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(static_cast<float>(value));
301 } else {
302 *static_cast<DType*>(arr->data) = value;
303 }
304 })
305 return Constant(arr);
306}
307
308/*!
309 * \brief Create a Constant with a tensor.
310 *
311 * \param dtype The data type.
312 * \param value The vector of the tensor values.
313 * \return A Constant.
314 */
315template <typename T>
316static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> shape,
317 std::vector<T> value) {
318 runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0});
319 TVM_DTYPE_DISPATCH(dtype, DType, {
320 for (size_t i = 0; i < value.size(); i++) {
321 if (dtype == DataType::Float(16)) {
322 // convert to float16
323 // storage is uint16_t
324 // Similar handling as that in MakeConstantScalar
325 *(static_cast<DType*>(arr->data) + i) =
326 __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
327 static_cast<float>(value[i]));
328 } else if (dtype == DataType::BFloat(16)) {
329 // convert to bfloat16
330 // storage is uint16_t
331 *(static_cast<DType*>(arr->data) + i) =
332 __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(
333 static_cast<float>(value[i]));
334 } else {
335 *(static_cast<DType*>(arr->data) + i) = value[i];
336 }
337 }
338 })
339 return Constant(arr);
340}
341
342/*!
343 * \brief Create a Constant with a tensor.
344 *
345 * \param dtype The data type.
346 * \param value The array of the tensor values.
347 * \return A Constant.
348 */
349template <typename T>
350static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> shape,
351 Array<T> value) {
352 runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0});
353 TVM_DTYPE_DISPATCH(dtype, DType, {
354 for (size_t i = 0; i < value.size(); i++) {
355 if (dtype == DataType::Float(16)) {
356 // convert to float16
357 // storage is uint16_t
358 // Similar handling as that in MakeConstantScalar
359 *(static_cast<DType*>(arr->data) + i) =
360 __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
361 static_cast<float>(value[i]));
362 } else if (dtype == DataType::BFloat(16)) {
363 // convert to bfloat16
364 // storage is uint16_t
365 *(static_cast<DType*>(arr->data) + i) =
366 __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(
367 static_cast<float>(value[i]));
368 } else {
369 *(static_cast<DType*>(arr->data) + i) = value[i];
370 }
371 }
372 })
373 return Constant(arr);
374}
375
376/*!
377 * \brief Create a Constant tensor of zeros.
378 *
379 * \param dtype The data type.
380 * \param shape The shape of the output constant tensor.
381 * \return A Constant.
382 */
383static inline Constant MakeConstantZeros(DataType dtype, std::vector<int64_t> shape) {
384 runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0});
385 int64_t data_size = 1;
386 for (int64_t dim : shape) {
387 data_size *= dim;
388 }
389 TVM_DTYPE_DISPATCH(dtype, DType, {
390 for (int64_t i = 0; i < data_size; i++) {
391 if (dtype == DataType::Float(16)) {
392 // convert to float16
393 // storage is uint16_t
394 // Similar handling as that in MakeConstantScalar
395 *(static_cast<DType*>(arr->data) + i) =
396 __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(0));
397 } else if (dtype == DataType::BFloat(16)) {
398 // convert to bfloat16
399 // storage is uint16_t
400 *(static_cast<DType*>(arr->data) + i) =
401 __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(static_cast<float>(0));
402 } else {
403 *(static_cast<DType*>(arr->data) + i) = 0;
404 }
405 }
406 })
407 return Constant(arr);
408}
409
410/*!
411 * \brief Check whether a shape is static and create corresponding Constant.
412 Eventually this will be removed and replaced with CheckConstantShapeArrayInteger
413 *
414 * \param shape The Array of the shape values.
415 * \return A Constant.
416 */
417static inline Constant CheckConstantShape(const Array<IndexExpr>& shape) {
418 auto shape_array =
419 runtime::NDArray::Empty({int64_t(shape.size())}, DataType::Int(64), {kDLCPU, 0});
420 auto* shape_data = static_cast<int64_t*>(shape_array->data);
421 for (size_t i = 0; i < shape.size(); ++i) {
422 const auto& dim_val = shape[i].as<IntImmNode>();
423 ICHECK(dim_val) << "Do not support symbolic shape for "
424 "Array format. Pass shape as Expr instead.";
425 shape_data[i] = dim_val->value;
426 }
427 return Constant(shape_array);
428}
429
430/*!
431 * \brief Check whether a shape is static and create corresponding Array<Integer>. Will replace
432 * CheckConstantShape after dynamic refactorization is complete
433 *
434 * \param shape The Array of the shape values.
435 * \return A Constant.
436 */
437static inline Array<Integer> CheckConstantShapeArrayInteger(const Array<IndexExpr>& shape) {
438 Array<Integer> constShape;
439
440 for (size_t i = 0; i < shape.size(); ++i) {
441 const auto& dim_val = shape[i].as<IntImmNode>();
442 ICHECK(dim_val) << "Do not support symbolic shape for "
443 "Array format. Pass shape as Expr instead.";
444
445 constShape.push_back(dim_val->value);
446 }
447 return constShape;
448}
449
450/*!
451 * \brief Check if two expressions are equal scalars.
452 * \param a The expression to be checked.
453 * \param b The expression to be checked
454 * \return Whether two expressions are equal scalars.
455 */
456inline bool IsEqualScalar(const Expr& a, const Expr& b) {
457 const auto* constant_a = a.as<ConstantNode>();
458 const auto* constant_b = b.as<ConstantNode>();
459 if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) {
460 return false;
461 }
462 return tvm::StructuralEqual()(a, b);
463}
464
465/*!
466 * \brief Convert an element of a NDArray with type int or float to scalar.
467 * \param array Input NDArray
468 * \param i element index
469 * \return Converted scalar value, or None if conversion failed
470 */
471static inline std::optional<long double> TryToScalar(const runtime::NDArray& array, size_t i = 0) {
472 if (array->dtype.code == kDLInt) {
473 if (array->dtype.bits == 8) {
474 return std::optional<long double>(reinterpret_cast<int8_t*>(array->data)[i]);
475 } else if (array->dtype.bits == 16) {
476 return std::optional<long double>(reinterpret_cast<int16_t*>(array->data)[i]);
477 } else if (array->dtype.bits == 32) {
478 return std::optional<long double>(reinterpret_cast<int32_t*>(array->data)[i]);
479 } else if (array->dtype.bits == 64) {
480 return std::optional<long double>(reinterpret_cast<int64_t*>(array->data)[i]);
481 }
482 } else if (array->dtype.code == kDLUInt) {
483 if (array->dtype.bits == 1) { // bool
484 return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
485 } else if (array->dtype.bits == 8) {
486 return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
487 } else if (array->dtype.bits == 16) {
488 return std::optional<long double>(reinterpret_cast<uint16_t*>(array->data)[i]);
489 } else if (array->dtype.bits == 32) {
490 return std::optional<long double>(reinterpret_cast<uint32_t*>(array->data)[i]);
491 } else if (array->dtype.bits == 64) {
492 return std::optional<long double>(reinterpret_cast<uint64_t*>(array->data)[i]);
493 }
494 } else if (array->dtype.code == kDLFloat) {
495 if (array->dtype.bits == 16) {
496 return std::optional<long double>(
497 __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
498 reinterpret_cast<uint16_t*>(array->data)[i]));
499 }
500 if (array->dtype.bits == 32) {
501 return std::optional<long double>(reinterpret_cast<float*>(array->data)[i]);
502 } else if (array->dtype.bits == 64) {
503 return std::optional<long double>(reinterpret_cast<double*>(array->data)[i]);
504 }
505 } else if (array->dtype.code == kDLBfloat) {
506 if (array->dtype.bits == 16) {
507 return std::optional<long double>(__extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>(
508 reinterpret_cast<uint16_t*>(array->data)[i]));
509 }
510 }
511 return std::nullopt;
512}
513
514/*!
515 * \brief Convert an element of a NDArray with type int or float to scalar.
516 * \param array Input NDArray
517 * \param i element index
518 * \return Converted scalar value
519 */
520static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) {
521 auto try_value = TryToScalar(array, i);
522 ICHECK(try_value) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
523 return try_value.value();
524}
525
526/*!
527 * \brief Convert a NDArray with type int or float to Array<Integer>.
528 * \param array Input NDArray
529 * \return Converted Array.
530 */
531static inline Array<Integer> ToVector(const runtime::NDArray& array) {
532 size_t ndim = array.Shape().size();
533 ICHECK_EQ(ndim, 1) << "This function should only be used for 1D NDArrays";
534 size_t len = array.Shape().front();
535 Array<Integer> out;
536 for (size_t i = 0; i < len; ++i) {
537 long double elem_val = ToScalar(array, i);
538 out.push_back(Integer(IntImm(DataType::Int(32), static_cast<int64_t>(elem_val))));
539 }
540 return out;
541}
542
543/*!
544 * \brief Convert a NDArray with type int or float to Array<FloatImm>.
545 * \param array Input NDArray
546 * \return Converted Array.
547 */
548static inline Array<FloatImm> ToFloatVector(const runtime::NDArray& array) {
549 size_t ndim = array.Shape().size();
550 ICHECK_EQ(ndim, 1) << "This function should only be used for 1D NDArrays";
551 size_t len = array.Shape().front();
552 Array<FloatImm> out;
553 for (size_t i = 0; i < len; ++i) {
554 long double elem_val = ToScalar(array, i);
555 out.push_back(FloatImm(DataType::Float(32), static_cast<float>(elem_val)));
556 }
557 return out;
558}
559
560/*!
561 * \brief Convert a NDArray with type int or float to Array<Array<Integer>>.
562 * \param array Input NDArray
563 * \return Converted Array.
564 */
565static inline Array<Array<Integer>> ToMatrix(const runtime::NDArray& array) {
566 size_t ndim = array.Shape().size();
567 ICHECK_EQ(ndim, 2) << "This function should only used for 2D NDArrays";
568 size_t dim1 = array.Shape().at(0);
569 size_t dim2 = array.Shape().at(1);
570
571 Array<Array<Integer>> out;
572
573 for (size_t i = 0; i < dim1; ++i) {
574 Array<Integer> inner_out;
575 for (size_t j = 0; j < dim2; ++j) {
576 double elem_val = ToScalar(array, i * dim2 + j);
577 inner_out.push_back(Integer(static_cast<int>(elem_val)));
578 }
579 out.push_back(inner_out);
580 }
581 return out;
582}
583
584inline Expr GetField(Expr t, size_t i) { return TupleGetItem(t, i); }
585
586inline Expr Pair(Expr l, Expr r) { return Tuple({l, r}); }
587
588inline Expr Exp(Expr e) {
589 static const Op& op = Op::Get("exp");
590 return Call(op, {e});
591}
592
593inline Expr Erf(Expr e) {
594 static const Op& op = Op::Get("erf");
595 return Call(op, {e});
596}
597
598inline Expr FastExp(Expr e) {
599 static const Op& op = Op::Get("fast_exp");
600 return Call(op, {e});
601}
602
603inline Expr FastErf(Expr e) {
604 static const Op& op = Op::Get("fast_erf");
605 return Call(op, {e});
606}
607
608inline Expr FastTanh(Expr e) {
609 static const Op& op = Op::Get("fast_tanh");
610 return Call(op, {e});
611}
612
613inline Expr FastSoftmax(Expr e, tvm::Attrs attr) {
614 static const Op& op = Op::Get("nn.fast_softmax");
615 return Call(op, {e}, attr);
616}
617
618inline Expr Log(Expr e) {
619 static const Op& op = Op::Get("log");
620 return Call(op, {e});
621}
622
623inline Expr Tanh(Expr e) {
624 static const Op& op = Op::Get("tanh");
625 return Call(op, {e});
626}
627
628inline Expr Abs(Expr e) {
629 static const Op& op = Op::Get("abs");
630 return Call(op, {e});
631}
632/*!
633 * \brief Get an immediate scalar from a Constant expr.
634 *
635 * \param expr The Constant expr.
636 * \return A scalar with type T.
637 */
638template <typename T>
639T GetScalarFromConstant(Expr expr) {
640 const auto* n = expr.as<ConstantNode>();
641 ICHECK(n) << "Expr must be a constant expr - " << AsText(expr, false);
642 ICHECK(n->is_scalar());
643 return static_cast<T*>(n->data->data)[0];
644}
645
646inline Expr Cast(Expr x, DataType dtype) { return MakeCast(x, dtype); }
647
648inline Expr Negative(Expr x) {
649 static const Op& op = Op::Get("negative");
650 return Call(op, {x}, Attrs(), {});
651}
652
653inline Expr Sqrt(Expr x) {
654 static const Op& op = Op::Get("sqrt");
655 return Call(op, {x}, Attrs(), {});
656}
657
658inline Expr Sigmoid(Expr x) {
659 static const Op& op = Op::Get("sigmoid");
660 return Call(op, {x}, Attrs(), {});
661}
662
663inline Expr Rsqrt(Expr x) {
664 static const Op& op = Op::Get("rsqrt");
665 return Call(op, {x}, Attrs(), {});
666}
667
668inline Expr Relu(Expr x) {
669 static const Op& op = Op::Get("nn.relu");
670 return Call(op, {x}, Attrs(), {});
671}
672
673inline Expr Round(Expr x) {
674 static const Op& op = Op::Get("round");
675 return Call(op, {x}, Attrs(), {});
676}
677
678inline Expr Floor(Expr x) {
679 static const Op& op = Op::Get("floor");
680 return Call(op, {x}, Attrs(), {});
681}
682
683inline Expr Clip(Expr x, double a_min, double a_max) { return MakeClip(x, a_min, a_max); }
684
685inline Expr FixedPointMultiply(Expr x, int32_t multiplier, int32_t shift) {
686 static const Op& op = Op::Get("fixed_point_multiply");
687 auto attrs = make_object<FixedPointMultiplyAttrs>();
688 attrs->multiplier = multiplier;
689 attrs->shift = shift;
690 return Call(op, {x}, Attrs(attrs), {});
691}
692
693inline Expr FixedPointMultiplyPerAxis(Expr x, Expr m, Expr lshift, Expr rshift,
694 bool is_lshift_required, bool is_rshift_required,
695 Array<Integer> axes) {
696 return MakeFixedPointMultiplyPerAxis(x, m, lshift, rshift, is_lshift_required, is_rshift_required,
697 axes);
698}
699
700inline Expr Add(Expr lhs, Expr rhs) {
701 static const Op& op = Op::Get("add");
702 return Call(op, {lhs, rhs}, Attrs(), {});
703}
704
705inline Expr Subtract(Expr lhs, Expr rhs) {
706 static const Op& op = Op::Get("subtract");
707 return Call(op, {lhs, rhs}, Attrs(), {});
708}
709
710inline Expr Multiply(Expr lhs, Expr rhs) {
711 static const Op& op = Op::Get("multiply");
712 return Call(op, {lhs, rhs}, Attrs(), {});
713}
714
715inline Expr Divide(Expr lhs, Expr rhs) {
716 static const Op& op = Op::Get("divide");
717 return Call(op, {lhs, rhs}, Attrs(), {});
718}
719
720inline Expr Maximum(Expr lhs, Expr rhs) {
721 static const Op& op = Op::Get("maximum");
722 return Call(op, {lhs, rhs}, Attrs(), {});
723}
724
725inline Expr ZerosLike(Expr e) {
726 static const Op& op = Op::Get("zeros_like");
727 return Call(op, {e});
728}
729
730inline Expr Zeros(Array<IndexExpr> shape, DataType dtype) {
731 return MakeZeros(CheckConstantShapeArrayInteger(shape), dtype);
732}
733
734inline Expr OnesLike(Expr e) {
735 static const Op& op = Op::Get("ones_like");
736 return Call(op, {e});
737}
738
739inline Expr Ones(Array<IndexExpr> shape, DataType dtype) {
740 return MakeOnes(CheckConstantShapeArrayInteger(shape), dtype);
741}
742
743inline Expr CollapseSumLike(Expr e) {
744 static const Op& op = Op::Get("collapse_sum_like");
745 return Call(op, {e});
746}
747
748inline Expr Power(Expr lhs, Expr rhs) {
749 static const Op& op = Op::Get("power");
750 return Call(op, {lhs, rhs}, Attrs(), {});
751}
752
753inline Expr RightShift(Expr x, Expr nbit) {
754 static const Op& op = Op::Get("right_shift");
755 return Call(op, {x, nbit}, Attrs(), {});
756}
757
758inline Expr LeftShift(Expr x, Expr nbit) {
759 static const Op& op = Op::Get("left_shift");
760 return Call(op, {x, nbit}, Attrs(), {});
761}
762
763inline Expr ReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin,
764 Integer rhs_end) {
765 return MakeReshapeLike(lhs, rhs, lhs_begin, lhs_end, rhs_begin, rhs_end);
766}
767
768inline Expr Copy(Expr data) {
769 static const Op& op = Op::Get("copy");
770 return Call(op, {data}, Attrs(), {});
771}
772
773inline Expr Mean(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
774 return MakeReduce(data, axis, keepdims, exclude, "mean");
775}
776
777inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude,
778 bool unbiased = false) {
779 return MakeVariance(data, mean, axis, keepdims, exclude, unbiased);
780}
781
782static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) {
783 static const Op& op = Op::Get("where");
784 return Call(op, {condition, x, y});
785}
786
787static inline Expr LogicalOr(const Expr& lhs, const Expr& rhs) {
788 static const Op& op = Op::Get("logical_or");
789 return Call(op, {lhs, rhs}, Attrs(), {});
790}
791
792static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
793 static const Op& op = Op::Get("greater_equal");
794 return Call(op, {lhs, rhs}, Attrs(), {});
795}
796
797static inline Expr Equal(const Expr& lhs, const Expr& rhs) {
798 static const Op& op = Op::Get("equal");
799 return Call(op, {lhs, rhs}, Attrs(), {});
800}
801
802static inline Expr Less(const Expr& lhs, const Expr& rhs) {
803 static const Op& op = Op::Get("less");
804 return Call(op, {lhs, rhs}, Attrs(), {});
805}
806
807static inline Expr IsFinite(const Expr x) {
808 static const Op& op = Op::Get("isfinite");
809 return Call(op, {x}, Attrs(), {});
810}
811
812static inline Expr Full(Expr fill_value, Array<IndexExpr> shape, DataType dtype) {
813 return MakeFull(fill_value, CheckConstantShapeArrayInteger(shape), dtype);
814}
815
816static inline Expr Conv2D(Expr data, Expr weight, Array<IndexExpr> strides,
817 Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
818 IndexExpr channels, Array<IndexExpr> kernel_size, std::string data_layout,
819 std::string kernel_layout, std::string out_layout, DataType out_dtype) {
820 return MakeConv<Conv2DAttrs>(data, weight, strides, padding, dilation, groups, channels,
821 kernel_size, data_layout, kernel_layout, out_layout, out_dtype,
822 "nn.conv2d");
823}
824
825static inline Expr Dense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
826 return MakeDense(data, weight, units, out_dtype);
827}
828
829static inline Expr Sum(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
830 return MakeReduce(data, axis, keepdims, exclude, "sum");
831}
832
833static inline Expr Prod(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
834 return MakeReduce(data, axis, keepdims, exclude, "prod");
835}
836
837static inline Expr Reshape(Expr data, Array<Integer> newshape) {
838 return MakeReshape(data, newshape);
839}
840
841static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
842 Array<IndexExpr> dilation, Array<IndexExpr> padding,
843 std::string layout, std::string out_layout, bool ceil_mode,
844 bool count_include_pad) {
845 return MakeAvgPool<AvgPool2DAttrs>(data, pool_size, strides, dilation, padding, layout,
846 out_layout, ceil_mode, count_include_pad, "nn.avg_pool2d");
847}
848
849static inline Expr Pad(Expr data, Array<Array<IndexExpr>> pad_width, Expr pad_value,
850 std::string pad_mode) {
851 Array<Array<Integer>> pad_width_int;
852 for (size_t i = 0; i < pad_width.size(); ++i) {
853 pad_width_int.push_back(CheckConstantShapeArrayInteger(pad_width[i]));
854 }
855 return MakePad(data, pad_width_int, pad_value, pad_mode);
856}
857
858static inline Expr Tile(Expr data, Array<Integer> reps) { return MakeTile(data, reps); }
859
860static inline Expr BroadCastTo(Expr data, Array<IndexExpr> shape) {
861 return MakeBroadCastTo(data, CheckConstantShapeArrayInteger(shape));
862}
863
864inline Expr Hardswish(Expr x) {
865 auto three = MakeConstantScalar(DataType::Float(32), 3.0);
866 auto six = MakeConstantScalar(DataType::Float(32), 6.0);
867 auto x2 = Add(x, three);
868 x2 = Clip(x2, 0.0, 6.0);
869 x2 = Multiply(x, x2);
870 x2 = Divide(x2, six);
871 return x2;
872}
873
874} // namespace relay
875} // namespace tvm
876#endif // TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_
877