1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file constant_utils.h
22 * \brief Utility functions for handling constants in TVM expressions
23 */
24#ifndef TVM_TOPI_DETAIL_CONSTANT_UTILS_H_
25#define TVM_TOPI_DETAIL_CONSTANT_UTILS_H_
26
#include <tvm/arith/analyzer.h>
#include <tvm/te/operation.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>

#include <limits>
#include <string>
#include <vector>
34
35namespace tvm {
36namespace topi {
37namespace detail {
38
39using namespace tvm::te;
40
41/*!
42 * \brief Test whether the given Expr is a constant integer
43 *
44 * \param expr the Expr to query
45 *
46 * \return true if the given expr is a constant int or uint, false otherwise.
47 */
48inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance<tvm::tir::IntImmNode>(); }
49
50/*!
51 * \brief Test whether the given Array has every element as constant integer.
52 * Undefined elements are also treat as constants.
53 *
54 * \param array the array to query
55 *
56 * \return true if every element in array is constant int or uint, false otherwise.
57 */
58inline bool IsConstIntArray(Array<PrimExpr> array) {
59 bool is_const_int = true;
60 for (auto const& elem : array) {
61 is_const_int &= !elem.defined() || elem->IsInstance<tvm::tir::IntImmNode>();
62 }
63 return is_const_int;
64}
65
66/*!
67 * \brief Get the value of the given constant integer expression. An error
68 * is logged if the given expression is not a constant integer.
69 *
70 * \param expr The expression to get the value of
71 *
72 * \return The integer value.
73 */
74inline int64_t GetConstInt(PrimExpr expr) {
75 if (expr->IsInstance<tvm::IntImmNode>()) {
76 return expr.as<tvm::IntImmNode>()->value;
77 }
78 LOG(ERROR) << "expr must be a constant integer";
79 return -1;
80}
81
82/*!
83 * \brief Get the value of all the constant integer expressions in the given array
84 *
85 * \param exprs The array of expressions to get the values of
86 * \param var_name The name to be used when logging an error in the event that any
87 * of the expressions are not constant integers.
88 *
89 * \return A vector of the integer values
90 */
91inline std::vector<int> GetConstIntValues(Array<PrimExpr> exprs, const std::string& var_name) {
92 std::vector<int> result;
93 if (!exprs.defined()) return result;
94 for (auto expr : exprs) {
95 ICHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers";
96 result.push_back(GetConstInt(expr));
97 }
98 return result;
99}
100
101/*!
102 * \brief Get the value of all the constant integer expressions in the given array
103 *
104 * \param exprs The array of expressions to get the values of
105 * \param var_name The name to be used when logging an error in the event that any
106 * of the expressions are not constant integers.
107 *
108 * \return A vector of the int64_t values
109 */
110inline std::vector<int64_t> GetConstInt64Values(Array<PrimExpr> exprs,
111 const std::string& var_name) {
112 std::vector<int64_t> result;
113 if (!exprs.defined()) return result;
114 for (auto expr : exprs) {
115 ICHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers";
116 result.push_back(GetConstInt(expr));
117 }
118 return result;
119}
120
121/*!
122 * \brief Check whether the two expressions are equal or not, if not simplify the expressions and
123 * check again
124 * \note This is stronger equality check than tvm::tir::Equal
125 * \param lhs First expression
126 * \param rhs Second expression
127 * \return result True if both expressions are equal, else false
128 */
129inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) {
130 tvm::tir::ExprDeepEqual expr_equal;
131 bool result = expr_equal(lhs, rhs);
132 if (!result) {
133 PrimExpr t = tvm::arith::Analyzer().Simplify(lhs - rhs);
134 if (const IntImmNode* i = t.as<IntImmNode>()) {
135 result = i->value == 0;
136 }
137 }
138 return result;
139}
140
141} // namespace detail
142} // namespace topi
143} // namespace tvm
144#endif // TVM_TOPI_DETAIL_CONSTANT_UTILS_H_
145