1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file common_subexpr_elim_tools.h |
22 | * \brief Interface of analysis tools and utility functions used |
23 | by the Common Subexpression Elimination (CSE) pass. |
24 | */ |
25 | |
26 | #ifndef TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_ |
27 | #define TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_ |
28 | |
#include <tvm/runtime/container/string.h>
#include <tvm/tir/analysis.h>  // For the ExprDeepEqual analysis
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>  // For the class StmtExprVisitor

#include <functional>  // For std::function
#include <optional>
#include <unordered_map>  // For the hashtable datatype
#include <utility>        // For pairs datatype
#include <vector>
40 | |
41 | namespace tvm { |
42 | namespace tir { |
43 | |
/*!
 * \brief A computation table is a hashtable which associates to each expression being computed
 *        a number (which is the number of times that it is computed).
 * \note The hash used is a StructuralHash (and not an ObjectPtrHash), as deeply equal terms
 *       need to be hashed similarly.
 *       The comparison used is ExprDeepEqual, which is stricter than StructuralEqual (as it does
 *       not do variables remapping), so it is compatible with StructuralHash (intended to be used
 *       with StructuralEqual).
 */
using ComputationTable = std::unordered_map<PrimExpr, size_t, StructuralHash, ExprDeepEqual>;
54 | |
/*!
 * \brief A cache of computations is made of a pair of two hashtables, which respectively associate
 *        to each statement or expression of the program its computation table. Its purpose is
 *        to prevent the CSE pass from repeatedly recomputing the same tables of computations.
 * \note Both hashtables are keyed with ObjectPtrHash / ObjectPtrEqual, unlike ComputationTable
 *       which is keyed with StructuralHash / ExprDeepEqual.
 */
struct ComputationCache {
  // Part of the cache for statements
  // It maps each known statement to its computation table
  std::unordered_map<Stmt, ComputationTable, ObjectPtrHash, ObjectPtrEqual>
      cache_stmt_table_computations_;

  // Part of the cache for expressions
  // It maps each known expression to its computation table
  std::unordered_map<PrimExpr, ComputationTable, ObjectPtrHash, ObjectPtrEqual>
      cache_expr_table_computations_;
};
71 | |
/*!
 * \brief Visitor which returns in a hashtable the (syntactic) computations done by an expression
 *        or by a statement.
 * \note Computations here are considered syntactically, meaning that semantically equivalent
 *       computations that are not syntactically the same are not merged together.
 * \note The constructor and the visiting methods are protected: the class is meant to be used
 *       through its static GetComputationsDoneBy() entry points.
 */
class ComputationsDoneBy : public StmtExprVisitor {
 public:
  // Toplevel (static) methods

  /*!
   * \brief Returns the table of computations done by an expression.
   * \param expr The expression to analyze.
   * \param is_eligible_computation The predicate used for knowing which computations are eligible.
   * \param can_contain_computations The predicate used for knowing in which nodes we can search
   *        for eligible computations.
   * \return The computation table built for `expr`.
   */
  static ComputationTable GetComputationsDoneBy(
      const PrimExpr& expr, std::function<bool(const PrimExpr&)> is_eligible_computation,
      std::function<bool(const PrimExpr&)> can_contain_computations);

  /*!
   * \brief Returns the table of computations done by a statement.
   * \param stmt The statement to analyze.
   * \param is_eligible_computation The predicate used for knowing which computations are eligible.
   * \param can_contain_computations The predicate used for knowing in which nodes we can search
   *        for eligible computations.
   * \return The computation table built for `stmt`.
   */
  static ComputationTable GetComputationsDoneBy(
      const Stmt& stmt, std::function<bool(const PrimExpr&)> is_eligible_computation,
      std::function<bool(const PrimExpr&)> can_contain_computations);

 protected:
  // Constructor
  // Takes the two predicates that are stored in is_eligible_computation_ and
  // can_contain_computations_
  ComputationsDoneBy(std::function<bool(const PrimExpr&)> is_eligible_computation,
                     std::function<bool(const PrimExpr&)> can_contain_computations);

  // Generic entry points of the visitor for expressions and statements
  void VisitExpr(const PrimExpr& expr) override;
  void VisitStmt(const Stmt& stmt) override;

  // Node-specific overrides for the statements IfThenElse, For and While
  // NOTE(review): why these three kinds of nodes need a dedicated treatment is decided in the
  // implementation file, which is not visible from this header
  void VisitStmt_(const IfThenElseNode* op) override;
  void VisitStmt_(const ForNode* op) override;
  void VisitStmt_(const WhileNode* op) override;

 private:
  // Private static helpers, overloaded for an expression and for a statement
  // NOTE(review): presumably they gather the computations done by the children of the given
  // node rather than by the node itself - confirm against the implementation
  static ComputationTable ComputationsDoneByChildrenOf(
      const PrimExpr& expr, std::function<bool(const PrimExpr&)> is_eligible_computation,
      std::function<bool(const PrimExpr&)> can_contain_computations);
  static ComputationTable ComputationsDoneByChildrenOf(
      const Stmt& stmt, std::function<bool(const PrimExpr&)> is_eligible_computation,
      std::function<bool(const PrimExpr&)> can_contain_computations);

  // The predicate used for knowing which computations are eligible
  std::function<bool(const PrimExpr&)> is_eligible_computation_;
  // The predicate used for knowing in which nodes we can search for eligible computations
  std::function<bool(const PrimExpr&)> can_contain_computations_;
  // The object being constructed and "returned" by the VisitExpr()/VisitStmt() methods
  ComputationTable table_of_computations_;
  // Cache for preventing to compute repeatedly the computations done by the same stmt or expr
  // It is a static member, so a single cache is shared by all the instances of this class
  static ComputationCache cache_;
};
117 | |
/*!
 * \brief Visitor that computes the *direct* subexpressions of a given expression.
 * \note Returns only the direct subexpressions of the given expression, not all the subexprs.
 *       So for instance, for (A+(B+C)) it will return A and (B+C) if they are eligible,
 *       but not B and C.
 * \note The constructor and VisitExpr() are protected: the class is meant to be used through
 *       the static GetDirectSubexpressions() entry point.
 */
class DirectSubexpr : public ExprVisitor {
 public:
  // Toplevel (static) function

  /*!
   * \brief Returns the direct subexpressions of an expression.
   * \param expr The expression whose direct subexpressions are wanted.
   * \param is_eligible_computation The predicate used for knowing which computations are eligible.
   * \param can_contain_computations The predicate used for knowing in which nodes we can search
   *        for eligible subexpressions.
   * \return The vector of direct subexpressions of `expr`.
   */
  static std::vector<PrimExpr> GetDirectSubexpressions(
      const PrimExpr& expr, std::function<bool(const PrimExpr&)> is_eligible_computation,
      std::function<bool(const PrimExpr&)> can_contain_computations);

 protected:
  // Constructor
  // Takes the two predicates that are stored in is_eligible_computation_ and
  // can_contain_computations_
  DirectSubexpr(std::function<bool(const PrimExpr&)> is_eligible_computation,
                std::function<bool(const PrimExpr&)> can_contain_computations);

  // Generic entry point of the visitor for expressions
  void VisitExpr(const PrimExpr& expr) override;

 private:
  // The predicate used for knowing which computations are eligible
  std::function<bool(const PrimExpr&)> is_eligible_computation_;
  // The predicate used for knowing in which nodes we can search for eligible subexpressions
  std::function<bool(const PrimExpr&)> can_contain_computations_;

  // We haven't entered the VisitExpr() method yet
  // (the flag starts at false; where it gets flipped is in the implementation file)
  bool entered_ = false;
  // The vector of direct subexpressions that we are building
  std::vector<PrimExpr> direct_subexpr_;
};
149 | |
/*!
 * \brief Visitor which tells if a given expression or statement uses a given variable name.
 *        This is used by the CSE pass to make sure that we do not reuse existing names,
 *        even though having the same name does not mean that it's the same variable, but it's
 *        clearer for dumps.
 * \note The constructor and the visiting methods are protected: the class is meant to be used
 *       through its static ExprUsesVarName() and StmtUsesVarName() entry points.
 */
class UsesVarName : public StmtExprVisitor {
 public:
  // Toplevel (static) methods

  /*!
   * \brief Tells if an expression uses a given variable name.
   * \param expr The expression to inspect.
   * \param var_name The variable name to look for.
   * \return Whether `expr` uses the name `var_name`.
   */
  static bool ExprUsesVarName(const PrimExpr& expr, String var_name);

  /*!
   * \brief Tells if a statement uses a given variable name.
   * \param stmt The statement to inspect.
   * \param var_name The variable name to look for.
   * \return Whether `stmt` uses the name `var_name`.
   */
  static bool StmtUsesVarName(const Stmt& stmt, String var_name);

 protected:
  // Constructor
  // Takes the variable name to look for
  explicit UsesVarName(String var_name);

  // Generic entry points of the visitor for expressions and statements
  void VisitExpr(const PrimExpr& expr) override;
  void VisitStmt(const Stmt& stmt) override;

 private:
  // The variable name that we are looking for
  String var_name_;
  // The result being built by the visit, which starts at false
  bool uses_var_name_ = false;
};
173 | |
/*!
 * \brief Various utility functions for the CSE pass
 */
// Prints a computation table
// NOTE(review): where and in which format it is printed is decided in the implementation file
void PrintComputationTable(const ComputationTable& table);

// An optional expression: it either holds a PrimExpr or is empty
using MaybeValue = std::optional<PrimExpr>;

// Tells if the two terms `a` and `b` are equal
// NOTE(review): presumably the syntactical equality, by opposition to EquivalentTerms() - confirm
bool EqualTerms(const PrimExpr& a, const PrimExpr& b);
// Used for deciding the (decidable) equivalence relation
// NOTE(review): presumably `expr` is returned as is when `do_normalization` is false - confirm
PrimExpr NormalizeTerm(const PrimExpr& expr, bool do_normalization);
// The equivalence relation, which is the syntactical equality when `identify_equiv_terms` is false
bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b, bool identify_equiv_terms);
// Builds a vector of pairs (expression, count) from a table of computations
// NOTE(review): presumably the entries of `table` which are equivalent according to
// EquivalentTerms() get merged together when `identify_equiv_terms` is true - confirm
std::vector<std::pair<PrimExpr, size_t>> SyntacticToSemanticComputations(
    const ComputationTable& table, bool identify_equiv_terms);
// Predicate over a computation and the number of times that it has been seen
// NOTE(review): presumably tells if a new variable should be introduced for `computation` -
// confirm against the implementation
bool PredicateIntroVarForComputation(const PrimExpr& computation, size_t nb_times_seen);
189 | |
/*!
 * \brief Polymorphic (functional) map on a vector, which builds a new vector with the same number
 *        of elements, where each element is the application of a given function on the
 *        corresponding element in the input vector.
 * \tparam A The type of the elements of the input vector.
 * \tparam B The type of the elements of the resulting vector.
 * \param input The vector whose elements are going to be mapped. It is left untouched.
 * \param fun The function applied to each element of `input`, in order, from the first element
 *        to the last one. It is never called when `input` is empty.
 * \return A vector of the same size as `input`, whose i-th element is fun(input[i]).
 * \note As `fun` is a std::function, the template arguments cannot be deduced from a lambda:
 *       either pass a std::function, or spell out A and B explicitly at the call site.
 */
template <typename A, typename B>
std::vector<B> VectorMap(const std::vector<A>& input, std::function<B(const A&)> fun) {
  std::vector<B> result;
  // For efficiency, allocate immediately the size needed as the result will have
  // the same size as the input
  result.reserve(input.size());

  // A range-based for loop visits the elements in order, so `fun` is applied from the first
  // element to the last one, which matters if `fun` has side effects
  for (const A& elem : input) {
    result.push_back(fun(elem));
  }

  return result;
}
// Explicitly instantiate the template function for A=std::pair<Var,MaybeValue> and B=Var
// NOTE(review): this is an explicit instantiation *definition* located in a header, so it is
// repeated in every translation unit including this file, whereas the C++ standard allows at
// most one such definition per program. Consider declaring it `extern template` here and
// keeping the definition in a single .cc file - confirm that no caller relies on it first.
template std::vector<Var> VectorMap(const std::vector<std::pair<Var, MaybeValue>>&,
                                    std::function<Var(const std::pair<Var, MaybeValue>&)>);
210 | |
// Inserts a pair (expression, count) into a vector of semantic computations
// `sorted_vec` is received as a pointer to a non-const vector, and `pair` as a const reference
// NOTE(review): presumably `sorted_vec` is already sorted and the insertion keeps it sorted;
// the ordering criterion lives in the implementation file - confirm
void InsertElemToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
                                            const std::pair<PrimExpr, size_t>& pair);

// Inserts a vector of expressions into a vector of semantic computations
// `increase_count` defaults to 1 when the caller does not provide it
// NOTE(review): presumably an expression of `vec_to_add` which is equivalent (according to
// `identify_equiv_terms`) to one already in `sorted_vec` has its count increased by
// `increase_count` instead of being inserted again - confirm against the implementation
void InsertVectorToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
                                              const std::vector<PrimExpr>& vec_to_add,
                                              bool identify_equiv_terms, size_t increase_count = 1);
217 | |
218 | } // namespace tir |
219 | } // namespace tvm |
220 | |
221 | #endif // TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_TOOLS_H_ |
222 | |