1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file constant_utils.h |
22 | * \brief Utility functions for handling constants in TVM expressions |
23 | */ |
24 | #ifndef TVM_TOPI_DETAIL_CONSTANT_UTILS_H_ |
25 | #define TVM_TOPI_DETAIL_CONSTANT_UTILS_H_ |
26 | |
27 | #include <tvm/arith/analyzer.h> |
28 | #include <tvm/te/operation.h> |
29 | #include <tvm/tir/analysis.h> |
30 | #include <tvm/tir/expr.h> |
31 | |
32 | #include <string> |
33 | #include <vector> |
34 | |
35 | namespace tvm { |
36 | namespace topi { |
37 | namespace detail { |
38 | |
39 | using namespace tvm::te; |
40 | |
41 | /*! |
42 | * \brief Test whether the given Expr is a constant integer |
43 | * |
44 | * \param expr the Expr to query |
45 | * |
46 | * \return true if the given expr is a constant int or uint, false otherwise. |
47 | */ |
48 | inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance<tvm::tir::IntImmNode>(); } |
49 | |
50 | /*! |
51 | * \brief Test whether the given Array has every element as constant integer. |
52 | * Undefined elements are also treat as constants. |
53 | * |
54 | * \param array the array to query |
55 | * |
56 | * \return true if every element in array is constant int or uint, false otherwise. |
57 | */ |
58 | inline bool IsConstIntArray(Array<PrimExpr> array) { |
59 | bool is_const_int = true; |
60 | for (auto const& elem : array) { |
61 | is_const_int &= !elem.defined() || elem->IsInstance<tvm::tir::IntImmNode>(); |
62 | } |
63 | return is_const_int; |
64 | } |
65 | |
66 | /*! |
67 | * \brief Get the value of the given constant integer expression. An error |
68 | * is logged if the given expression is not a constant integer. |
69 | * |
70 | * \param expr The expression to get the value of |
71 | * |
72 | * \return The integer value. |
73 | */ |
74 | inline int64_t GetConstInt(PrimExpr expr) { |
75 | if (expr->IsInstance<tvm::IntImmNode>()) { |
76 | return expr.as<tvm::IntImmNode>()->value; |
77 | } |
78 | LOG(ERROR) << "expr must be a constant integer" ; |
79 | return -1; |
80 | } |
81 | |
82 | /*! |
83 | * \brief Get the value of all the constant integer expressions in the given array |
84 | * |
85 | * \param exprs The array of expressions to get the values of |
86 | * \param var_name The name to be used when logging an error in the event that any |
87 | * of the expressions are not constant integers. |
88 | * |
89 | * \return A vector of the integer values |
90 | */ |
91 | inline std::vector<int> GetConstIntValues(Array<PrimExpr> exprs, const std::string& var_name) { |
92 | std::vector<int> result; |
93 | if (!exprs.defined()) return result; |
94 | for (auto expr : exprs) { |
95 | ICHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers" ; |
96 | result.push_back(GetConstInt(expr)); |
97 | } |
98 | return result; |
99 | } |
100 | |
101 | /*! |
102 | * \brief Get the value of all the constant integer expressions in the given array |
103 | * |
104 | * \param exprs The array of expressions to get the values of |
105 | * \param var_name The name to be used when logging an error in the event that any |
106 | * of the expressions are not constant integers. |
107 | * |
108 | * \return A vector of the int64_t values |
109 | */ |
110 | inline std::vector<int64_t> GetConstInt64Values(Array<PrimExpr> exprs, |
111 | const std::string& var_name) { |
112 | std::vector<int64_t> result; |
113 | if (!exprs.defined()) return result; |
114 | for (auto expr : exprs) { |
115 | ICHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers" ; |
116 | result.push_back(GetConstInt(expr)); |
117 | } |
118 | return result; |
119 | } |
120 | |
121 | /*! |
122 | * \brief Check whether the two expressions are equal or not, if not simplify the expressions and |
123 | * check again |
124 | * \note This is stronger equality check than tvm::tir::Equal |
125 | * \param lhs First expression |
126 | * \param rhs Second expression |
127 | * \return result True if both expressions are equal, else false |
128 | */ |
129 | inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) { |
130 | tvm::tir::ExprDeepEqual expr_equal; |
131 | bool result = expr_equal(lhs, rhs); |
132 | if (!result) { |
133 | PrimExpr t = tvm::arith::Analyzer().Simplify(lhs - rhs); |
134 | if (const IntImmNode* i = t.as<IntImmNode>()) { |
135 | result = i->value == 0; |
136 | } |
137 | } |
138 | return result; |
139 | } |
140 | |
141 | } // namespace detail |
142 | } // namespace topi |
143 | } // namespace tvm |
144 | #endif // TVM_TOPI_DETAIL_CONSTANT_UTILS_H_ |
145 | |