1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file tvm/relay/transforms/pattern_utils.h |
23 | * \brief Header of internal operator functions |
24 | * These can be used for writing passes. |
25 | */ |
26 | #ifndef TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_ |
27 | #define TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_ |
28 | |
29 | #include <builtin_fp16.h> |
30 | #include <tvm/node/structural_equal.h> |
31 | #include <tvm/relay/analysis.h> |
32 | #include <tvm/relay/attrs/nn.h> |
33 | #include <tvm/relay/attrs/reduce.h> |
34 | #include <tvm/relay/attrs/transform.h> |
35 | #include <tvm/relay/expr.h> |
36 | #include <tvm/relay/op.h> |
37 | #include <tvm/relay/op_attr_types.h> |
38 | #include <tvm/runtime/registry.h> |
39 | #include <tvm/tir/data_layout.h> |
40 | |
41 | #include <limits> |
42 | #include <optional> |
43 | #include <string> |
44 | #include <utility> |
45 | #include <vector> |
46 | |
47 | #include "../backend/utils.h" |
48 | #include "../op/make_op.h" |
49 | |
50 | namespace tvm { |
51 | namespace relay { |
52 | |
53 | /*! |
54 | * \brief Dispatch DataType to the C++ data type |
55 | * during runtime. |
56 | */ |
57 | #define TVM_DTYPE_DISPATCH(type, DType, ...) \ |
58 | if (type == DataType::Float(64)) { \ |
59 | typedef double DType; \ |
60 | { __VA_ARGS__ } \ |
61 | } else if (type == DataType::Float(32)) { \ |
62 | typedef float DType; \ |
63 | { __VA_ARGS__ } \ |
64 | } else if (type == DataType::Float(16)) { \ |
65 | typedef uint16_t DType; \ |
66 | { __VA_ARGS__ } \ |
67 | } else if (type == DataType::BFloat(16)) { \ |
68 | typedef uint16_t DType; \ |
69 | { __VA_ARGS__ } \ |
70 | } else if (type == DataType::Int(64)) { \ |
71 | typedef int64_t DType; \ |
72 | { __VA_ARGS__ } \ |
73 | } else if (type == DataType::Int(32)) { \ |
74 | typedef int32_t DType; \ |
75 | { __VA_ARGS__ } \ |
76 | } else if (type == DataType::Int(16)) { \ |
77 | typedef int16_t DType; \ |
78 | { __VA_ARGS__ } \ |
79 | } else if (type == DataType::Int(8)) { \ |
80 | typedef int8_t DType; \ |
81 | { __VA_ARGS__ } \ |
82 | } else if (type == DataType::UInt(64)) { \ |
83 | typedef uint64_t DType; \ |
84 | { __VA_ARGS__ } \ |
85 | } else if (type == DataType::UInt(32)) { \ |
86 | typedef uint32_t DType; \ |
87 | { __VA_ARGS__ } \ |
88 | } else if (type == DataType::UInt(16)) { \ |
89 | typedef uint16_t DType; \ |
90 | { __VA_ARGS__ } \ |
91 | } else if (type == DataType::UInt(8)) { \ |
92 | typedef uint8_t DType; \ |
93 | { __VA_ARGS__ } \ |
94 | } else if (type == DataType::Bool()) { \ |
95 | typedef bool DType; \ |
96 | { __VA_ARGS__ } \ |
97 | } else if ((*tvm::runtime::Registry::Get("runtime._datatype_get_type_registered"))( \ |
98 | static_cast<uint8_t>(type.code()))) { \ |
99 | typedef double DType; \ |
100 | { __VA_ARGS__ } \ |
101 | } else { \ |
102 | LOG(FATAL) << "unknown data type " << type; \ |
103 | } |
104 | |
105 | /*! |
106 | * \brief Try to do the type inference over expr: |
107 | * |
108 | * Do the infer_type over each node in expr |
109 | * |
110 | * \param expr The IR expression |
111 | * \return infered expr if succeed. |
112 | */ |
113 | inline Expr InferType(const Expr& expr) { |
114 | auto mod = IRModule::FromExpr(expr); |
115 | mod = transform::InferType()(mod); |
116 | if (expr.as<FunctionNode>()) { |
    return mod->Lookup("main");
  } else {
    return mod->Lookup("main").as<FunctionNode>()->body;
120 | } |
121 | } |
122 | |
123 | /*! |
124 | * \brief Try to match lhs and rhs via broadcasting rule, such that: |
125 | * |
126 | * rhs matches the dimension of lhs specified by lhs_axes |
127 | * rhs's value equals 1 on rest of dimensions. |
128 | * |
129 | * \param tlhs The type of left operand (data) |
130 | * \param trhs The type right operand (bias) |
131 | * \param lhs_axes The axes on lhs to match. |
132 | * \param rhs_value A squeezed version of rhs which only contains matched dimension. |
133 | * \return Whether match is successful. |
134 | */ |
135 | inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, const TensorTypeNode* trhs, |
136 | const Array<Integer>& lhs_axes, Expr* rhs_value = nullptr) { |
137 | if (tlhs->shape.size() < trhs->shape.size()) return false; |
138 | StructuralEqual equal; |
139 | size_t base = tlhs->shape.size() - trhs->shape.size(); |
140 | size_t j = 0; |
141 | |
142 | // handle case trhs is simple constant |
143 | if (trhs->shape.size() == 0 && rhs_value != nullptr && lhs_axes.size() > 0) { |
144 | *rhs_value = MakeExpandDims(*rhs_value, 0, lhs_axes.size()); |
145 | for (size_t i = 0; i < lhs_axes.size(); i++) { |
      // Use this iteration's axis (`i`); `j` is still 0 here and would wrongly
      // repeat the first axis's extent on every step.
      int repeat_value =
          tlhs->shape[static_cast<size_t>(lhs_axes[i]->value)].as<IntImmNode>()->value;
148 | *rhs_value = MakeRepeat(*rhs_value, repeat_value, i); |
149 | } |
150 | return true; |
151 | } |
152 | |
153 | ObjectPtr<SqueezeAttrs> squeeze_attrs; |
154 | if (rhs_value != nullptr) { |
155 | squeeze_attrs = make_object<SqueezeAttrs>(); |
156 | } |
157 | |
158 | for (size_t i = 0; i < tlhs->shape.size(); ++i) { |
159 | if (j < lhs_axes.size() && i == static_cast<size_t>(lhs_axes[j]->value)) { |
160 | if (i < base || !equal(tlhs->shape[i], trhs->shape[i - base])) { |
161 | return false; |
162 | } |
163 | ++j; |
164 | } else if (i >= base) { |
165 | if (!tir::is_const_int(trhs->shape[i - base], 1)) { |
166 | return false; |
167 | } |
168 | if (rhs_value != nullptr) { |
169 | squeeze_attrs->axis.push_back(static_cast<int>(i - base)); |
170 | } |
171 | } |
172 | } |
173 | if (rhs_value != nullptr && squeeze_attrs->axis.size() != 0) { |
    static const Op& squeeze_op = Op::Get("squeeze");
175 | *rhs_value = Call(squeeze_op, {rhs_value[0]}, Attrs(squeeze_attrs), {}); |
176 | } |
177 | return true; |
178 | } |
179 | |
180 | /*! |
181 | * \brief Expand 1D Tensor to match axis. |
182 | * |
183 | * The result bias can be used to add or multiply to |
184 | * the target Tensor on the specified axis via broadcasting rule. |
185 | * |
186 | * \param bias The bias. |
187 | * \param target_ndim Target dimension. |
188 | * \param axes The axis on the output we want to match on. |
189 | */ |
190 | inline Expr ExpandBiasToMatchAxis(Expr bias, int target_ndim, const Array<Integer>& axes) { |
  static const Op& expand_dims = Op::Get("expand_dims");
192 | for (size_t i = axes.size(); i != 0; --i) { |
193 | if (i == axes.size()) { |
194 | int64_t num_pad_axis = target_ndim - axes[i - 1]->value - 1; |
195 | if (num_pad_axis > 0) { |
196 | auto attrs = make_object<ExpandDimsAttrs>(); |
197 | attrs->axis = i; |
198 | attrs->num_newaxis = static_cast<int>(num_pad_axis); |
199 | bias = Call(expand_dims, {bias}, Attrs(attrs), {}); |
200 | } |
201 | } else { |
202 | int64_t diff = axes[i]->value - axes[i - 1]->value; |
203 | ICHECK_GE(diff, 0L); |
204 | if (diff > 0) { |
205 | auto attrs = make_object<ExpandDimsAttrs>(); |
206 | attrs->axis = i; |
207 | attrs->num_newaxis = static_cast<int>(diff); |
208 | bias = Call(expand_dims, {bias}, Attrs(attrs), {}); |
209 | } |
210 | } |
211 | } |
212 | return bias; |
213 | } |
214 | |
215 | /*! |
216 | * \brief Check if the call is depthwise conv3d. |
217 | * |
218 | * \param call The conv call. |
219 | * \param param The conv attributes. |
220 | * \return Whether it is depthwise_conv3d. |
221 | */ |
222 | template <typename ATTRS> |
223 | inline bool IsDepthwiseConv(const Call& call, ATTRS param, const Layout& kernel_layout) { |
224 | static const Layout kOIXX = |
225 | backend::IsOp(call.as<CallNode>(), "nn.conv2d" ) ? Layout("OIHW" ) : Layout("OIDHW" ); |
226 | const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIXX); |
227 | auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape); |
228 | return tir::is_const_int(wshape[0], param->groups) && tir::is_const_int(wshape[1], 1); |
229 | } |
230 | |
231 | /*! |
232 | * \brief Get super-dimension of output channels of conv2d |
233 | * \param call The conv2d call. |
234 | * \return Super-dimension size of output channels of conv2d. |
235 | */ |
236 | inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) { |
237 | auto param = call->attrs.as<Conv2DAttrs>(); |
238 | auto tweight = call->args[1]->type_as<TensorTypeNode>(); |
239 | auto index = param->kernel_layout.operator std::string().find('O'); |
240 | ICHECK_NE(index, std::string::npos); |
241 | auto channels = tir::as_const_int(tweight->shape[index]); |
242 | return *channels; |
243 | } |
244 | |
245 | /*! |
246 | * \brief Is single value tensor (scalar). |
247 | * \param expr The expr. |
248 | * \return True if single value tensor. |
249 | */ |
250 | inline bool IsScalar(const Expr& expr) { |
251 | if (auto tensor_type = expr->checked_type().as<TensorTypeNode>()) { |
252 | for (auto dim_index_expr : tensor_type->shape) { |
253 | if (auto dim_index = dim_index_expr.as<IntImmNode>()) { |
254 | if (dim_index->value != 1) { |
255 | return false; |
256 | } |
257 | } else { |
258 | return false; |
259 | } |
260 | } |
261 | } else { |
262 | return false; |
263 | } |
264 | return true; |
265 | } |
266 | |
267 | /*! |
268 | * \brief Check if expr is a const scalar. |
269 | * \param expr The expr. |
270 | * \return True if const scalar. |
271 | */ |
272 | inline bool IsConstScalar(const Expr& expr) { |
273 | const auto* const_expr = expr.as<ConstantNode>(); |
274 | if (const_expr) { |
275 | return const_expr->is_scalar(); |
276 | } |
277 | return false; |
278 | } |
279 | |
280 | /*! |
281 | * \brief Create a Constant with a scalar |
282 | * |
283 | * \param dtype The data type. |
284 | * \param value The value of the scalar. |
285 | * \return A Constant. |
286 | */ |
287 | template <typename T> |
288 | inline Constant MakeConstantScalar(DataType dtype, T value) { |
289 | runtime::NDArray arr = runtime::NDArray::Empty({}, dtype, {kDLCPU, 0}); |
290 | TVM_DTYPE_DISPATCH(dtype, DType, { |
291 | if (dtype == DataType::Float(16)) { |
292 | // convert to float16 |
293 | // storage is uint16_t |
294 | *static_cast<DType*>(arr->data) = |
295 | __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value)); |
296 | } else if (dtype == DataType::BFloat(16)) { |
297 | // convert to bfloat16 |
298 | // storage is uint16_t |
299 | *static_cast<DType*>(arr->data) = |
300 | __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(static_cast<float>(value)); |
301 | } else { |
302 | *static_cast<DType*>(arr->data) = value; |
303 | } |
304 | }) |
305 | return Constant(arr); |
306 | } |
307 | |
308 | /*! |
309 | * \brief Create a Constant with a tensor. |
310 | * |
311 | * \param dtype The data type. |
312 | * \param value The vector of the tensor values. |
313 | * \return A Constant. |
314 | */ |
315 | template <typename T> |
316 | static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> shape, |
317 | std::vector<T> value) { |
318 | runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0}); |
319 | TVM_DTYPE_DISPATCH(dtype, DType, { |
320 | for (size_t i = 0; i < value.size(); i++) { |
321 | if (dtype == DataType::Float(16)) { |
322 | // convert to float16 |
323 | // storage is uint16_t |
324 | // Similar handling as that in MakeConstantScalar |
325 | *(static_cast<DType*>(arr->data) + i) = |
326 | __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>( |
327 | static_cast<float>(value[i])); |
328 | } else if (dtype == DataType::BFloat(16)) { |
329 | // convert to bfloat16 |
330 | // storage is uint16_t |
331 | *(static_cast<DType*>(arr->data) + i) = |
332 | __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>( |
333 | static_cast<float>(value[i])); |
334 | } else { |
335 | *(static_cast<DType*>(arr->data) + i) = value[i]; |
336 | } |
337 | } |
338 | }) |
339 | return Constant(arr); |
340 | } |
341 | |
342 | /*! |
343 | * \brief Create a Constant with a tensor. |
344 | * |
345 | * \param dtype The data type. |
346 | * \param value The array of the tensor values. |
347 | * \return A Constant. |
348 | */ |
349 | template <typename T> |
350 | static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> shape, |
351 | Array<T> value) { |
352 | runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0}); |
353 | TVM_DTYPE_DISPATCH(dtype, DType, { |
354 | for (size_t i = 0; i < value.size(); i++) { |
355 | if (dtype == DataType::Float(16)) { |
356 | // convert to float16 |
357 | // storage is uint16_t |
358 | // Similar handling as that in MakeConstantScalar |
359 | *(static_cast<DType*>(arr->data) + i) = |
360 | __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>( |
361 | static_cast<float>(value[i])); |
362 | } else if (dtype == DataType::BFloat(16)) { |
363 | // convert to bfloat16 |
364 | // storage is uint16_t |
365 | *(static_cast<DType*>(arr->data) + i) = |
366 | __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>( |
367 | static_cast<float>(value[i])); |
368 | } else { |
369 | *(static_cast<DType*>(arr->data) + i) = value[i]; |
370 | } |
371 | } |
372 | }) |
373 | return Constant(arr); |
374 | } |
375 | |
376 | /*! |
377 | * \brief Create a Constant tensor of zeros. |
378 | * |
379 | * \param dtype The data type. |
380 | * \param shape The shape of the output constant tensor. |
381 | * \return A Constant. |
382 | */ |
383 | static inline Constant MakeConstantZeros(DataType dtype, std::vector<int64_t> shape) { |
384 | runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0}); |
385 | int64_t data_size = 1; |
386 | for (int64_t dim : shape) { |
387 | data_size *= dim; |
388 | } |
389 | TVM_DTYPE_DISPATCH(dtype, DType, { |
390 | for (int64_t i = 0; i < data_size; i++) { |
391 | if (dtype == DataType::Float(16)) { |
392 | // convert to float16 |
393 | // storage is uint16_t |
394 | // Similar handling as that in MakeConstantScalar |
395 | *(static_cast<DType*>(arr->data) + i) = |
396 | __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(0)); |
397 | } else if (dtype == DataType::BFloat(16)) { |
398 | // convert to bfloat16 |
399 | // storage is uint16_t |
400 | *(static_cast<DType*>(arr->data) + i) = |
401 | __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(static_cast<float>(0)); |
402 | } else { |
403 | *(static_cast<DType*>(arr->data) + i) = 0; |
404 | } |
405 | } |
406 | }) |
407 | return Constant(arr); |
408 | } |
409 | |
410 | /*! |
411 | * \brief Check whether a shape is static and create corresponding Constant. |
412 | Eventually this will be removed and replaced with CheckConstantShapeArrayInteger |
413 | * |
414 | * \param shape The Array of the shape values. |
415 | * \return A Constant. |
416 | */ |
417 | static inline Constant CheckConstantShape(const Array<IndexExpr>& shape) { |
418 | auto shape_array = |
419 | runtime::NDArray::Empty({int64_t(shape.size())}, DataType::Int(64), {kDLCPU, 0}); |
420 | auto* shape_data = static_cast<int64_t*>(shape_array->data); |
421 | for (size_t i = 0; i < shape.size(); ++i) { |
422 | const auto& dim_val = shape[i].as<IntImmNode>(); |
423 | ICHECK(dim_val) << "Do not support symbolic shape for " |
424 | "Array format. Pass shape as Expr instead." ; |
425 | shape_data[i] = dim_val->value; |
426 | } |
427 | return Constant(shape_array); |
428 | } |
429 | |
430 | /*! |
431 | * \brief Check whether a shape is static and create corresponding Array<Integer>. Will replace |
432 | * CheckConstantShape after dynamic refactorization is complete |
433 | * |
434 | * \param shape The Array of the shape values. |
435 | * \return A Constant. |
436 | */ |
437 | static inline Array<Integer> CheckConstantShapeArrayInteger(const Array<IndexExpr>& shape) { |
438 | Array<Integer> constShape; |
439 | |
440 | for (size_t i = 0; i < shape.size(); ++i) { |
441 | const auto& dim_val = shape[i].as<IntImmNode>(); |
442 | ICHECK(dim_val) << "Do not support symbolic shape for " |
443 | "Array format. Pass shape as Expr instead." ; |
444 | |
445 | constShape.push_back(dim_val->value); |
446 | } |
447 | return constShape; |
448 | } |
449 | |
450 | /*! |
451 | * \brief Check if two expressions are equal scalars. |
452 | * \param a The expression to be checked. |
 * \param b The expression to be checked.
454 | * \return Whether two expressions are equal scalars. |
455 | */ |
456 | inline bool IsEqualScalar(const Expr& a, const Expr& b) { |
457 | const auto* constant_a = a.as<ConstantNode>(); |
458 | const auto* constant_b = b.as<ConstantNode>(); |
459 | if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) { |
460 | return false; |
461 | } |
462 | return tvm::StructuralEqual()(a, b); |
463 | } |
464 | |
465 | /*! |
466 | * \brief Convert an element of a NDArray with type int or float to scalar. |
467 | * \param array Input NDArray |
468 | * \param i element index |
469 | * \return Converted scalar value, or None if conversion failed |
470 | */ |
471 | static inline std::optional<long double> TryToScalar(const runtime::NDArray& array, size_t i = 0) { |
472 | if (array->dtype.code == kDLInt) { |
473 | if (array->dtype.bits == 8) { |
474 | return std::optional<long double>(reinterpret_cast<int8_t*>(array->data)[i]); |
475 | } else if (array->dtype.bits == 16) { |
476 | return std::optional<long double>(reinterpret_cast<int16_t*>(array->data)[i]); |
477 | } else if (array->dtype.bits == 32) { |
478 | return std::optional<long double>(reinterpret_cast<int32_t*>(array->data)[i]); |
479 | } else if (array->dtype.bits == 64) { |
480 | return std::optional<long double>(reinterpret_cast<int64_t*>(array->data)[i]); |
481 | } |
482 | } else if (array->dtype.code == kDLUInt) { |
483 | if (array->dtype.bits == 1) { // bool |
484 | return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]); |
485 | } else if (array->dtype.bits == 8) { |
486 | return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]); |
487 | } else if (array->dtype.bits == 16) { |
488 | return std::optional<long double>(reinterpret_cast<uint16_t*>(array->data)[i]); |
489 | } else if (array->dtype.bits == 32) { |
490 | return std::optional<long double>(reinterpret_cast<uint32_t*>(array->data)[i]); |
491 | } else if (array->dtype.bits == 64) { |
492 | return std::optional<long double>(reinterpret_cast<uint64_t*>(array->data)[i]); |
493 | } |
494 | } else if (array->dtype.code == kDLFloat) { |
495 | if (array->dtype.bits == 16) { |
496 | return std::optional<long double>( |
497 | __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>( |
498 | reinterpret_cast<uint16_t*>(array->data)[i])); |
499 | } |
500 | if (array->dtype.bits == 32) { |
501 | return std::optional<long double>(reinterpret_cast<float*>(array->data)[i]); |
502 | } else if (array->dtype.bits == 64) { |
503 | return std::optional<long double>(reinterpret_cast<double*>(array->data)[i]); |
504 | } |
505 | } else if (array->dtype.code == kDLBfloat) { |
506 | if (array->dtype.bits == 16) { |
507 | return std::optional<long double>(__extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>( |
508 | reinterpret_cast<uint16_t*>(array->data)[i])); |
509 | } |
510 | } |
511 | return std::nullopt; |
512 | } |
513 | |
514 | /*! |
515 | * \brief Convert an element of a NDArray with type int or float to scalar. |
516 | * \param array Input NDArray |
517 | * \param i element index |
518 | * \return Converted scalar value |
519 | */ |
520 | static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) { |
521 | auto try_value = TryToScalar(array, i); |
522 | ICHECK(try_value) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype); |
523 | return try_value.value(); |
524 | } |
525 | |
526 | /*! |
 * \brief Convert an NDArray with type int or float to Array<Integer>.
 * \param array The input NDArray.
 * \return The converted Array.
530 | */ |
531 | static inline Array<Integer> ToVector(const runtime::NDArray& array) { |
532 | size_t ndim = array.Shape().size(); |
  ICHECK_EQ(ndim, 1) << "This function should only be used for 1D NDArrays";
534 | size_t len = array.Shape().front(); |
535 | Array<Integer> out; |
536 | for (size_t i = 0; i < len; ++i) { |
537 | long double elem_val = ToScalar(array, i); |
538 | out.push_back(Integer(IntImm(DataType::Int(32), static_cast<int64_t>(elem_val)))); |
539 | } |
540 | return out; |
541 | } |
542 | |
543 | /*! |
 * \brief Convert an NDArray with type int or float to Array<FloatImm>.
 * \param array The input NDArray.
 * \return The converted Array.
547 | */ |
548 | static inline Array<FloatImm> ToFloatVector(const runtime::NDArray& array) { |
549 | size_t ndim = array.Shape().size(); |
  ICHECK_EQ(ndim, 1) << "This function should only be used for 1D NDArrays";
551 | size_t len = array.Shape().front(); |
552 | Array<FloatImm> out; |
553 | for (size_t i = 0; i < len; ++i) { |
554 | long double elem_val = ToScalar(array, i); |
555 | out.push_back(FloatImm(DataType::Float(32), static_cast<float>(elem_val))); |
556 | } |
557 | return out; |
558 | } |
559 | |
560 | /*! |
 * \brief Convert an NDArray with type int or float to Array<Array<Integer>>.
 * \param array The input NDArray.
 * \return The converted Array.
564 | */ |
565 | static inline Array<Array<Integer>> ToMatrix(const runtime::NDArray& array) { |
566 | size_t ndim = array.Shape().size(); |
  ICHECK_EQ(ndim, 2) << "This function should only be used for 2D NDArrays";
568 | size_t dim1 = array.Shape().at(0); |
569 | size_t dim2 = array.Shape().at(1); |
570 | |
571 | Array<Array<Integer>> out; |
572 | |
573 | for (size_t i = 0; i < dim1; ++i) { |
574 | Array<Integer> inner_out; |
575 | for (size_t j = 0; j < dim2; ++j) { |
      long double elem_val = ToScalar(array, i * dim2 + j);
577 | inner_out.push_back(Integer(static_cast<int>(elem_val))); |
578 | } |
579 | out.push_back(inner_out); |
580 | } |
581 | return out; |
582 | } |
583 | |
584 | inline Expr GetField(Expr t, size_t i) { return TupleGetItem(t, i); } |
585 | |
586 | inline Expr Pair(Expr l, Expr r) { return Tuple({l, r}); } |
587 | |
588 | inline Expr Exp(Expr e) { |
589 | static const Op& op = Op::Get("exp" ); |
590 | return Call(op, {e}); |
591 | } |
592 | |
593 | inline Expr Erf(Expr e) { |
594 | static const Op& op = Op::Get("erf" ); |
595 | return Call(op, {e}); |
596 | } |
597 | |
598 | inline Expr FastExp(Expr e) { |
599 | static const Op& op = Op::Get("fast_exp" ); |
600 | return Call(op, {e}); |
601 | } |
602 | |
603 | inline Expr FastErf(Expr e) { |
604 | static const Op& op = Op::Get("fast_erf" ); |
605 | return Call(op, {e}); |
606 | } |
607 | |
608 | inline Expr FastTanh(Expr e) { |
609 | static const Op& op = Op::Get("fast_tanh" ); |
610 | return Call(op, {e}); |
611 | } |
612 | |
613 | inline Expr FastSoftmax(Expr e, tvm::Attrs attr) { |
614 | static const Op& op = Op::Get("nn.fast_softmax" ); |
615 | return Call(op, {e}, attr); |
616 | } |
617 | |
618 | inline Expr Log(Expr e) { |
619 | static const Op& op = Op::Get("log" ); |
620 | return Call(op, {e}); |
621 | } |
622 | |
623 | inline Expr Tanh(Expr e) { |
624 | static const Op& op = Op::Get("tanh" ); |
625 | return Call(op, {e}); |
626 | } |
627 | |
628 | inline Expr Abs(Expr e) { |
629 | static const Op& op = Op::Get("abs" ); |
630 | return Call(op, {e}); |
631 | } |
632 | /*! |
633 | * \brief Get an immediate scalar from a Constant expr. |
634 | * |
635 | * \param expr The Constant expr. |
636 | * \return A scalar with type T. |
637 | */ |
638 | template <typename T> |
639 | T GetScalarFromConstant(Expr expr) { |
640 | const auto* n = expr.as<ConstantNode>(); |
641 | ICHECK(n) << "Expr must be a constant expr - " << AsText(expr, false); |
642 | ICHECK(n->is_scalar()); |
643 | return static_cast<T*>(n->data->data)[0]; |
644 | } |
645 | |
646 | inline Expr Cast(Expr x, DataType dtype) { return MakeCast(x, dtype); } |
647 | |
648 | inline Expr Negative(Expr x) { |
649 | static const Op& op = Op::Get("negative" ); |
650 | return Call(op, {x}, Attrs(), {}); |
651 | } |
652 | |
653 | inline Expr Sqrt(Expr x) { |
654 | static const Op& op = Op::Get("sqrt" ); |
655 | return Call(op, {x}, Attrs(), {}); |
656 | } |
657 | |
658 | inline Expr Sigmoid(Expr x) { |
659 | static const Op& op = Op::Get("sigmoid" ); |
660 | return Call(op, {x}, Attrs(), {}); |
661 | } |
662 | |
663 | inline Expr Rsqrt(Expr x) { |
664 | static const Op& op = Op::Get("rsqrt" ); |
665 | return Call(op, {x}, Attrs(), {}); |
666 | } |
667 | |
668 | inline Expr Relu(Expr x) { |
669 | static const Op& op = Op::Get("nn.relu" ); |
670 | return Call(op, {x}, Attrs(), {}); |
671 | } |
672 | |
673 | inline Expr Round(Expr x) { |
674 | static const Op& op = Op::Get("round" ); |
675 | return Call(op, {x}, Attrs(), {}); |
676 | } |
677 | |
678 | inline Expr Floor(Expr x) { |
679 | static const Op& op = Op::Get("floor" ); |
680 | return Call(op, {x}, Attrs(), {}); |
681 | } |
682 | |
683 | inline Expr Clip(Expr x, double a_min, double a_max) { return MakeClip(x, a_min, a_max); } |
684 | |
685 | inline Expr FixedPointMultiply(Expr x, int32_t multiplier, int32_t shift) { |
686 | static const Op& op = Op::Get("fixed_point_multiply" ); |
687 | auto attrs = make_object<FixedPointMultiplyAttrs>(); |
688 | attrs->multiplier = multiplier; |
689 | attrs->shift = shift; |
690 | return Call(op, {x}, Attrs(attrs), {}); |
691 | } |
692 | |
693 | inline Expr FixedPointMultiplyPerAxis(Expr x, Expr m, Expr lshift, Expr rshift, |
694 | bool is_lshift_required, bool is_rshift_required, |
695 | Array<Integer> axes) { |
696 | return MakeFixedPointMultiplyPerAxis(x, m, lshift, rshift, is_lshift_required, is_rshift_required, |
697 | axes); |
698 | } |
699 | |
700 | inline Expr Add(Expr lhs, Expr rhs) { |
701 | static const Op& op = Op::Get("add" ); |
702 | return Call(op, {lhs, rhs}, Attrs(), {}); |
703 | } |
704 | |
705 | inline Expr Subtract(Expr lhs, Expr rhs) { |
706 | static const Op& op = Op::Get("subtract" ); |
707 | return Call(op, {lhs, rhs}, Attrs(), {}); |
708 | } |
709 | |
710 | inline Expr Multiply(Expr lhs, Expr rhs) { |
711 | static const Op& op = Op::Get("multiply" ); |
712 | return Call(op, {lhs, rhs}, Attrs(), {}); |
713 | } |
714 | |
715 | inline Expr Divide(Expr lhs, Expr rhs) { |
716 | static const Op& op = Op::Get("divide" ); |
717 | return Call(op, {lhs, rhs}, Attrs(), {}); |
718 | } |
719 | |
720 | inline Expr Maximum(Expr lhs, Expr rhs) { |
721 | static const Op& op = Op::Get("maximum" ); |
722 | return Call(op, {lhs, rhs}, Attrs(), {}); |
723 | } |
724 | |
725 | inline Expr ZerosLike(Expr e) { |
726 | static const Op& op = Op::Get("zeros_like" ); |
727 | return Call(op, {e}); |
728 | } |
729 | |
730 | inline Expr Zeros(Array<IndexExpr> shape, DataType dtype) { |
731 | return MakeZeros(CheckConstantShapeArrayInteger(shape), dtype); |
732 | } |
733 | |
734 | inline Expr OnesLike(Expr e) { |
735 | static const Op& op = Op::Get("ones_like" ); |
736 | return Call(op, {e}); |
737 | } |
738 | |
739 | inline Expr Ones(Array<IndexExpr> shape, DataType dtype) { |
740 | return MakeOnes(CheckConstantShapeArrayInteger(shape), dtype); |
741 | } |
742 | |
743 | inline Expr CollapseSumLike(Expr e) { |
744 | static const Op& op = Op::Get("collapse_sum_like" ); |
745 | return Call(op, {e}); |
746 | } |
747 | |
748 | inline Expr Power(Expr lhs, Expr rhs) { |
749 | static const Op& op = Op::Get("power" ); |
750 | return Call(op, {lhs, rhs}, Attrs(), {}); |
751 | } |
752 | |
753 | inline Expr RightShift(Expr x, Expr nbit) { |
754 | static const Op& op = Op::Get("right_shift" ); |
755 | return Call(op, {x, nbit}, Attrs(), {}); |
756 | } |
757 | |
758 | inline Expr LeftShift(Expr x, Expr nbit) { |
759 | static const Op& op = Op::Get("left_shift" ); |
760 | return Call(op, {x, nbit}, Attrs(), {}); |
761 | } |
762 | |
763 | inline Expr ReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin, |
764 | Integer rhs_end) { |
765 | return MakeReshapeLike(lhs, rhs, lhs_begin, lhs_end, rhs_begin, rhs_end); |
766 | } |
767 | |
768 | inline Expr Copy(Expr data) { |
769 | static const Op& op = Op::Get("copy" ); |
770 | return Call(op, {data}, Attrs(), {}); |
771 | } |
772 | |
773 | inline Expr Mean(Expr data, Array<Integer> axis, bool keepdims, bool exclude) { |
774 | return MakeReduce(data, axis, keepdims, exclude, "mean" ); |
775 | } |
776 | |
777 | inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude, |
778 | bool unbiased = false) { |
779 | return MakeVariance(data, mean, axis, keepdims, exclude, unbiased); |
780 | } |
781 | |
782 | static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) { |
783 | static const Op& op = Op::Get("where" ); |
784 | return Call(op, {condition, x, y}); |
785 | } |
786 | |
787 | static inline Expr LogicalOr(const Expr& lhs, const Expr& rhs) { |
788 | static const Op& op = Op::Get("logical_or" ); |
789 | return Call(op, {lhs, rhs}, Attrs(), {}); |
790 | } |
791 | |
792 | static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { |
793 | static const Op& op = Op::Get("greater_equal" ); |
794 | return Call(op, {lhs, rhs}, Attrs(), {}); |
795 | } |
796 | |
797 | static inline Expr Equal(const Expr& lhs, const Expr& rhs) { |
798 | static const Op& op = Op::Get("equal" ); |
799 | return Call(op, {lhs, rhs}, Attrs(), {}); |
800 | } |
801 | |
802 | static inline Expr Less(const Expr& lhs, const Expr& rhs) { |
803 | static const Op& op = Op::Get("less" ); |
804 | return Call(op, {lhs, rhs}, Attrs(), {}); |
805 | } |
806 | |
static inline Expr IsFinite(const Expr& x) {
  static const Op& op = Op::Get("isfinite");
809 | return Call(op, {x}, Attrs(), {}); |
810 | } |
811 | |
812 | static inline Expr Full(Expr fill_value, Array<IndexExpr> shape, DataType dtype) { |
813 | return MakeFull(fill_value, CheckConstantShapeArrayInteger(shape), dtype); |
814 | } |
815 | |
816 | static inline Expr Conv2D(Expr data, Expr weight, Array<IndexExpr> strides, |
817 | Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups, |
818 | IndexExpr channels, Array<IndexExpr> kernel_size, std::string data_layout, |
819 | std::string kernel_layout, std::string out_layout, DataType out_dtype) { |
820 | return MakeConv<Conv2DAttrs>(data, weight, strides, padding, dilation, groups, channels, |
821 | kernel_size, data_layout, kernel_layout, out_layout, out_dtype, |
822 | "nn.conv2d" ); |
823 | } |
824 | |
825 | static inline Expr Dense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { |
826 | return MakeDense(data, weight, units, out_dtype); |
827 | } |
828 | |
829 | static inline Expr Sum(Expr data, Array<Integer> axis, bool keepdims, bool exclude) { |
830 | return MakeReduce(data, axis, keepdims, exclude, "sum" ); |
831 | } |
832 | |
833 | static inline Expr Prod(Expr data, Array<Integer> axis, bool keepdims, bool exclude) { |
834 | return MakeReduce(data, axis, keepdims, exclude, "prod" ); |
835 | } |
836 | |
837 | static inline Expr Reshape(Expr data, Array<Integer> newshape) { |
838 | return MakeReshape(data, newshape); |
839 | } |
840 | |
841 | static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides, |
842 | Array<IndexExpr> dilation, Array<IndexExpr> padding, |
843 | std::string layout, std::string out_layout, bool ceil_mode, |
844 | bool count_include_pad) { |
845 | return MakeAvgPool<AvgPool2DAttrs>(data, pool_size, strides, dilation, padding, layout, |
846 | out_layout, ceil_mode, count_include_pad, "nn.avg_pool2d" ); |
847 | } |
848 | |
849 | static inline Expr Pad(Expr data, Array<Array<IndexExpr>> pad_width, Expr pad_value, |
850 | std::string pad_mode) { |
851 | Array<Array<Integer>> pad_width_int; |
852 | for (size_t i = 0; i < pad_width.size(); ++i) { |
853 | pad_width_int.push_back(CheckConstantShapeArrayInteger(pad_width[i])); |
854 | } |
855 | return MakePad(data, pad_width_int, pad_value, pad_mode); |
856 | } |
857 | |
858 | static inline Expr Tile(Expr data, Array<Integer> reps) { return MakeTile(data, reps); } |
859 | |
860 | static inline Expr BroadCastTo(Expr data, Array<IndexExpr> shape) { |
861 | return MakeBroadCastTo(data, CheckConstantShapeArrayInteger(shape)); |
862 | } |
863 | |
864 | inline Expr Hardswish(Expr x) { |
865 | auto three = MakeConstantScalar(DataType::Float(32), 3.0); |
866 | auto six = MakeConstantScalar(DataType::Float(32), 6.0); |
867 | auto x2 = Add(x, three); |
868 | x2 = Clip(x2, 0.0, 6.0); |
869 | x2 = Multiply(x, x2); |
870 | x2 = Divide(x2, six); |
871 | return x2; |
872 | } |
873 | |
874 | } // namespace relay |
875 | } // namespace tvm |
876 | #endif // TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_ |
877 | |