1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file common_subexpr_elim_tools.h
22 * \brief Interface of analysis tools and utility functions used
23 by the Common Subexpression Elimination (CSE) pass.
24 */
25
26#ifndef TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
27#define TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
28
#include <tvm/runtime/container/string.h>
#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>  // For the class StmtExprVisitor

#include <functional>     // For std::function (predicates and VectorMap)
#include <optional>
#include <unordered_map>  // For the hashtable datatype
#include <utility>        // For pairs datatype
#include <vector>
40
41namespace tvm {
42namespace tir {
43
/*!
 * \brief A computation table is a hashtable which associates to each expression being computed
 *        a number: the number of times that this expression is computed.
 *        It is important to note that the hash used is a StructuralHash (and not an
 *        ObjectPtrHash), as deeply equal terms must be hashed identically.
 *        The comparison used is ExprDeepEqual, which is stricter than StructuralEqual (as it does
 *        not do variable remapping). It is therefore compatible with StructuralHash (which is
 *        intended to be used with StructuralEqual): two terms considered equal by ExprDeepEqual
 *        are also StructurallyEqual, and hence get the same StructuralHash.
 */
using ComputationTable = std::unordered_map<PrimExpr, size_t, StructuralHash, ExprDeepEqual>;
54
/*!
 * \brief A cache of computations is made of a pair of two hashtables, which respectively associate
 *        to each statement or expression of the program its computation table. Its purpose is
 *        to prevent the CSE pass from repeatedly recomputing the same tables of computations.
 * \note Unlike ComputationTable, both tables here are keyed with ObjectPtrHash / ObjectPtrEqual,
 *       i.e. by node identity rather than by structure: two distinct but structurally identical
 *       nodes get separate cache entries.
 */
struct ComputationCache {
  // Part of the cache for statements.
  // It maps each known statement (by identity) to its computation table.
  std::unordered_map<Stmt, ComputationTable, ObjectPtrHash, ObjectPtrEqual>
      cache_stmt_table_computations_;

  // Part of the cache for expressions.
  // It maps each known expression (by identity) to its computation table.
  std::unordered_map<PrimExpr, ComputationTable, ObjectPtrHash, ObjectPtrEqual>
      cache_expr_table_computations_;
};
71
/*!
 * \brief Visitor which returns in a hashtable the (syntactic) computations done by an expression
 *        or by a statement.
 * \note Computations here are considered syntactically, meaning that semantically equivalent
 *       computations that are not syntactically the same are not merged together.
 * \note NOTE(review): `cache_` is a static member. It is therefore shared by every call and
 *       every instance, and it is keyed only by the statement/expression, not by the two
 *       predicates. Cached results presumably assume the same predicates are passed on every
 *       call. No synchronization for this shared cache is visible in this header, so confirm
 *       the thread-safety expectations against the implementation.
 */
class ComputationsDoneBy : public StmtExprVisitor {
 public:
  // Toplevel (static) methods

  /*!
   * \brief Gets the table of computations done by an expression.
   * \param expr The expression to analyze.
   * \param is_eligible_computation Predicate used for knowing which computations are eligible.
   * \param can_contain_computations Predicate used for knowing in which nodes we can search for
   *        eligible computations.
   * \return The computation table of `expr`.
   */
  static ComputationTable GetComputationsDoneBy(
      const PrimExpr& expr, std::function<bool(const PrimExpr&)> is_eligible_computation,
      std::function<bool(const PrimExpr&)> can_contain_computations);
  /*!
   * \brief Gets the table of computations done by a statement.
   * \param stmt The statement to analyze.
   * \param is_eligible_computation Predicate used for knowing which computations are eligible.
   * \param can_contain_computations Predicate used for knowing in which nodes we can search for
   *        eligible computations.
   * \return The computation table of `stmt`.
   */
  static ComputationTable GetComputationsDoneBy(
      const Stmt& stmt, std::function<bool(const PrimExpr&)> is_eligible_computation,
      std::function<bool(const PrimExpr&)> can_contain_computations);

 protected:
  // Constructor (protected: outside code goes through the static methods above)
  ComputationsDoneBy(std::function<bool(const PrimExpr&)> is_eligible_computation,
                     std::function<bool(const PrimExpr&)> can_contain_computations);

  // Visitors that build `table_of_computations_`
  void VisitExpr(const PrimExpr& expr) override;
  void VisitStmt(const Stmt& stmt) override;

  // Specific handling for these control-flow statements; see the implementation for details
  void VisitStmt_(const IfThenElseNode* op) override;
  void VisitStmt_(const ForNode* op) override;
  void VisitStmt_(const WhileNode* op) override;

 private:
  // Helpers computing the table for the children of a node (presumably its direct sub-nodes;
  // confirm in the implementation)
  static ComputationTable ComputationsDoneByChildrenOf(
      const PrimExpr& expr, std::function<bool(const PrimExpr&)> is_eligible_computation,
      std::function<bool(const PrimExpr&)> can_contain_computations);
  static ComputationTable ComputationsDoneByChildrenOf(
      const Stmt& stmt, std::function<bool(const PrimExpr&)> is_eligible_computation,
      std::function<bool(const PrimExpr&)> can_contain_computations);

  // The predicate used for knowing which computations are eligible
  std::function<bool(const PrimExpr&)> is_eligible_computation_;
  // The predicate used for knowing in which nodes we can search for eligible computations
  std::function<bool(const PrimExpr&)> can_contain_computations_;
  // The object being constructed and "returned" by the VisitExpr()/VisitStmt() methods
  ComputationTable table_of_computations_;
  // Cache preventing repeated computation of the computations done by the same stmt or expr.
  // Static: shared across all instances and calls (see the class-level note).
  static ComputationCache cache_;
};
117
/*!
 * \brief Visitor that computes the *direct* subexpressions of a given expression.
 * \note Returns only the direct subexpressions of the given expression, not all of its
 *       subexpressions. So for instance, for (A+(B+C)) it will return A and (B+C) if they are
 *       eligible, but not B and C.
 */
class DirectSubexpr : public ExprVisitor {
 public:
  // Toplevel (static) function

  /*!
   * \brief Gets the direct subexpressions of an expression.
   * \param expr The expression whose direct subexpressions are wanted.
   * \param is_eligible_computation Predicate used for knowing which computations are eligible.
   * \param can_contain_computations Predicate used for knowing in which nodes we can search for
   *        eligible subexpressions.
   * \return The vector of direct subexpressions of `expr`.
   */
  static std::vector<PrimExpr> GetDirectSubexpressions(
      const PrimExpr& expr, std::function<bool(const PrimExpr&)> is_eligible_computation,
      std::function<bool(const PrimExpr&)> can_contain_computations);

 protected:
  // Constructor (protected: outside code goes through GetDirectSubexpressions())
  DirectSubexpr(std::function<bool(const PrimExpr&)> is_eligible_computation,
                std::function<bool(const PrimExpr&)> can_contain_computations);

  // Visitor that builds `direct_subexpr_`
  void VisitExpr(const PrimExpr& expr) override;

 private:
  // The predicate used for knowing which computations are eligible
  std::function<bool(const PrimExpr&)> is_eligible_computation_;
  // The predicate used for knowing in which nodes we can search for eligible subexpressions
  std::function<bool(const PrimExpr&)> can_contain_computations_;

  // Whether the VisitExpr() method has already been entered (false initially). Presumably used
  // to tell the root expression apart from its children; confirm in the implementation.
  bool entered_ = false;
  // The vector of direct subexpressions that we are building
  std::vector<PrimExpr> direct_subexpr_;
};
149
/*!
 * \brief Visitor which tells if a given expression or statement uses a given variable name.
 *        This is used by the CSE pass to make sure that we do not reuse existing names. Having
 *        the same name does not mean being the same variable, but avoiding duplicate names
 *        makes dumps clearer.
 */
class UsesVarName : public StmtExprVisitor {
 public:
  // Toplevel (static) methods

  /*! \brief Tells whether the expression `expr` uses a variable named `var_name`. */
  static bool ExprUsesVarName(const PrimExpr& expr, String var_name);
  /*! \brief Tells whether the statement `stmt` uses a variable named `var_name`. */
  static bool StmtUsesVarName(const Stmt& stmt, String var_name);

 protected:
  // Constructor (protected: outside code goes through the static methods above)
  explicit UsesVarName(String var_name);

  void VisitExpr(const PrimExpr& expr) override;
  void VisitStmt(const Stmt& stmt) override;

 private:
  // The variable name being searched for
  String var_name_;
  // The result of the search (false initially; presumably set once a use is found)
  bool uses_var_name_ = false;
};
173
// ---------------------------------------------------------------------------------------------
// Various utility functions for the CSE pass
// ---------------------------------------------------------------------------------------------

/*!
 * \brief Prints the content of a computation table.
 * \note Presumably a debugging aid; the output destination is chosen by the implementation.
 */
void PrintComputationTable(const ComputationTable& table);

// An optional PrimExpr (std::nullopt when there is no value).
using MaybeValue = std::optional<PrimExpr>;

/*!
 * \brief Tells whether the terms `a` and `b` are equal.
 * \note NOTE(review): presumably syntactic (deep) equality; confirm in the implementation.
 */
bool EqualTerms(const PrimExpr& a, const PrimExpr& b);
// Used for deciding the (decidable) equivalence relation.
// `do_normalization` toggles the normalization (presumably `expr` is returned as-is when false).
PrimExpr NormalizeTerm(const PrimExpr& expr, bool do_normalization);
// The equivalence relation, which is the syntactical equality when `identify_equiv_terms` is false
bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b, bool identify_equiv_terms);
// Converts a table of syntactic computations into a vector of (computation, count) pairs.
// When `identify_equiv_terms` is true, equivalent computations are presumably merged with their
// counts combined; confirm in the implementation.
std::vector<std::pair<PrimExpr, size_t>> SyntacticToSemanticComputations(
    const ComputationTable& table, bool identify_equiv_terms);
// Decides whether a new variable should be introduced for `computation`, which has been seen
// `nb_times_seen` times (the exact criterion is defined by the implementation).
bool PredicateIntroVarForComputation(const PrimExpr& computation, size_t nb_times_seen);
189
/*!
 * \brief Polymorphic (functional) map over a vector.
 * \param input The vector whose elements are transformed.
 * \param fun The function applied to every element of `input`.
 * \return A new vector with exactly as many elements as `input`, whose i-th element is
 *         `fun(input[i])`. The elements are produced in order.
 */
template <typename A, typename B>
std::vector<B> VectorMap(const std::vector<A>& input, std::function<B(const A&)> fun) {
  std::vector<B> mapped;
  // The output has the same length as the input, so a single allocation is enough.
  mapped.reserve(input.size());

  for (const A& element : input) {
    mapped.push_back(fun(element));
  }

  return mapped;
}
// NOTE: VectorMap is intentionally NOT explicitly instantiated here. An explicit instantiation
// *definition* placed in a header is emitted by every translation unit that includes the header.
// That violates the rule that such a definition appears at most once in a program ([temp.explicit];
// ill-formed, no diagnostic required). Because VectorMap is fully defined in this header, every
// use (e.g. A = std::pair<Var, MaybeValue>, B = Var) is implicitly instantiated where it is needed.
210
/*!
 * \brief Inserts `pair` (a computation and its count) into `sorted_vec`.
 * \param sorted_vec The vector to update in place. Judging by the name it is kept sorted; the
 *        ordering criterion is defined by the implementation (confirm there).
 * \param pair The (computation, count) pair to insert.
 */
void InsertElemToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
                                            const std::pair<PrimExpr, size_t>& pair);

/*!
 * \brief Adds every computation of `vec_to_add` to `sorted_vec`.
 * \param sorted_vec The vector to update in place (presumably kept sorted, as above).
 * \param vec_to_add The computations to add.
 * \param identify_equiv_terms Whether equivalent (not only syntactically equal) terms are
 *        identified, as in EquivalentTerms().
 * \param increase_count The amount by which each added computation's count is increased
 *        (defaults to 1).
 */
void InsertVectorToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
                                              const std::vector<PrimExpr>& vec_to_add,
                                              bool identify_equiv_terms, size_t increase_count = 1);
217
218} // namespace tir
219} // namespace tvm
220
221#endif // TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_
222