1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file common_subexpr_elim.cc |
22 | * \brief Implementation of the Common Subexpressions Elimination (CSE) pass |
23 | which rewrites statements and expressions in order to eliminate |
24 | redundant computations. In order to achieve that, common (sub-) |
25 | expressions are introduced into variables with let-in bindings, |
26 | and the places where the expression was used are replaced with |
27 | the freshly introduced variable. |
28 | */ |
29 | |
30 | #include "common_subexpr_elim.h" |
31 | |
32 | #include <tvm/ir/transform.h> // For the class Pass and the class PassContext |
33 | #include <tvm/runtime/container/array.h> |
34 | #include <tvm/runtime/container/string.h> |
35 | #include <tvm/tir/analysis.h> // For the analysis which gives the size of an expr |
36 | #include <tvm/tir/expr.h> |
37 | #include <tvm/tir/expr_functor.h> |
38 | #include <tvm/tir/function.h> // For the class PrimFunc |
39 | #include <tvm/tir/stmt.h> |
40 | #include <tvm/tir/stmt_functor.h> |
41 | #include <tvm/tir/transform.h> // For the decl of the function returning the pass |
42 | |
43 | #include <algorithm> // For the algorithm std::find |
44 | #include <iostream> |
45 | #include <string> |
46 | #include <unordered_map> // For the hashtable datatype |
47 | #include <utility> // For std::pair and std::move |
48 | #include <vector> |
49 | |
50 | #include "../analysis/check_contains.h" // For the visitor CheckContains |
51 | #include "common_subexpr_elim_tools.h" // For the auxiliary analysis (visitors) and tools |
52 | #include "replace_selected_expr.h" // For the mutator ReplaceSelectedExpr |
53 | |
54 | namespace tvm { |
55 | namespace tir { |
56 | |
57 | /*! |
58 | * \brief Check whether a computation is forbidden for being treated by the CSE pass. |
59 | The important thing about forbidden computations is that not only we won't want |
60 | to collect them for the CSE pass, but we also won't even want to collect computations |
61 | that contain them. |
62 | The reason is that reusing such computations would change the semantics of the program, |
63 | and therefore before doing any introduction of var or any reuse of already introduced |
64 | variables, we will make sure that the computation being considered is not forbidden, and |
65 | that it does not even contain a forbidden computation. |
66 | * \param expr The expression to check |
67 | * \return Whether `expr` is a forbidden computation or not |
68 | */ |
69 | bool CommonSubexpressionEliminator::ForbiddenComputation(const PrimExpr& expr) { |
70 | // Function calls, loads and buffer loads are absolutely forbidden as introducing them into |
71 | // variables would change the semantics of the program. |
72 | return (expr.as<CallNode>() != nullptr || expr.as<LoadNode>() != nullptr || |
73 | expr.as<BufferLoadNode>() != nullptr); |
74 | } |
75 | |
76 | /*! |
77 | * \brief Predicate used for verifying that a computation is eligible for being treated by |
78 | the CSE pass, i.e. for being introduced into a variable / for being replaced by a |
79 | variable. |
80 | Being eligible is a conjunction of a few conditions, like not being an atom (constant |
81 | or variable), not being a forbidden node, not containing a forbidden node, etc. |
82 | * \param expr The expression to check |
83 | * \return Whether `expr` is an eligible computation or not |
84 | */ |
85 | bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) { |
86 | return ( |
87 | // In order to be eligible, the given expression should not be a constant |
88 | (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) && |
89 | (expr.as<StringImmNode>() == nullptr) |
90 | // and it should not be a variable |
91 | && (expr.as<VarNode>() == nullptr) |
92 | // and it should not be a forbidden computation (function calls and loads) |
93 | && (!ForbiddenComputation(expr)) |
94 | // and it should not even contain a forbidden computation (function calls and loads) |
95 | // the reason is that we don't want to register expressions like (x + f(y)) or |
96 | // (x + Mem[i]) as introducing them into variables could change the semantics |
97 | && (!CheckContains::ExprContains(expr, ForbiddenComputation)) |
98 | // and it should not be a ramp node or a broadcast node due to some internals TVM |
99 | // constraints (which check for these node explicitely without performing any |
100 | // evaluation first, so if they have been put into variables it fails) |
101 | && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr)); |
102 | } |
103 | |
104 | /*! |
105 | * \brief Predicate used (when considering eligible computations) for only diving into |
106 | expressions that are allowed to contain eligible computations. Customize this predicate |
107 | if you want to make it forbidden to rewrite inside a specific node, like inside |
108 | a Load node for instance. |
109 | * \param expr The expression to check |
110 | * \return Whether `expr` can contain some eligible computations or not, and therefore |
111 | if recursing inside `expr` is necessary. |
112 | */ |
113 | bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExpr& expr) { |
114 | // Uncomment the next line to prevent the collection and the replacement of eligible computations |
115 | // inside the index of Load nodes. We initially thought that this would be needed in order to |
116 | // not harm the indexing mode of the CPU, but as we are still far from ASM code, we |
117 | // finally want to perform such simplifications, which tend to happen fairly frequently. |
118 | |
119 | // return ( (expr.as<LoadNode>() == nullptr) && (expr.as<BufferLoadNode>() == nullptr) ) |
120 | return true; |
121 | } |
122 | |
123 | /*! |
124 | * \brief Implements an order on pairs (expression,frequency). First attempts to compare them |
125 | using the size of the expression. If it is the same, decides something else still |
126 | deterministic. |
127 | * \param a The first pair |
128 | * \param b The second pair |
129 | * \return A boolean telling if the first pair `a` comes before the second pair `b` |
130 | * \note We need this order to be deterministic in order to have a fully deterministic pass, |
131 | * as we will deal with elements that are coming from a hashtable, but the order in which |
132 | * they appeared in the hashtable was based on some runtime addresses, so it can potentially |
133 | * change with every execution. |
134 | */ |
135 | bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(std::pair<PrimExpr, size_t> a, |
136 | std::pair<PrimExpr, size_t> b) { |
137 | size_t a_size = CalculateExprComplexity(a.first); |
138 | size_t b_size = CalculateExprComplexity(b.first); |
139 | |
140 | // Criteria 1 - Size of the expression comes first |
141 | // `a` comes before `b` if the size of `a` is bigger |
142 | if (a_size > b_size) { |
143 | return true; |
144 | } |
145 | // `a` does NOT come before `b` if the size of `b` is bigger |
146 | if (b_size > a_size) { |
147 | return false; |
148 | } |
149 | |
150 | // Criteria 2 - If they had the same size, use the lexicographic order as a last resort |
151 | // as we need a deterministic order |
152 | std::stringstream a_stream; |
153 | std::stringstream b_stream; |
154 | a_stream << AsLegacyRepr(a.first); |
155 | b_stream << AsLegacyRepr(b.first); |
156 | return (a_stream.str().compare(b_stream.str()) < 0); |
157 | } |
158 | |
159 | /*! |
160 | * \brief Generates a new fresh variable, whose name will be cse_var_i. |
161 | * \param type_annotation The type of the new variable to generate |
162 | * \return A new variable of type `type_annotation` called cse_var_i where i is the first available |
163 | integer. |
164 | */ |
165 | Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) { |
166 | // Increase `num_last_try_` for this new attempt |
167 | num_last_try_++; |
168 | // Builds the variable name, which is sce_var_i where i will go up from 1 |
169 | std::string prefix = "cse_var_" ; |
170 | std::string name = prefix.append(std::to_string(num_last_try_)); |
171 | // Builds a String using the std::string |
172 | String string_name(name); |
173 | |
174 | // Check that the name that we want to use for the new variable isn't already being used |
175 | // (names don't really have to be unique as they are just hints, and having the same name |
176 | // doesn't means that it's the same variable, but it's clearer for dumps) |
177 | if (UsesVarName::StmtUsesVarName(initial_body_, string_name)) { |
178 | // If the name is already used, call ourselves recursively for trying with the next one |
179 | return GenerateNewVar(type_annotation); |
180 | } |
181 | |
182 | // Increase `nb_var_` for this new generation of variable that we have just done |
183 | nb_var_++; |
184 | |
185 | // Return a new Variable using the name built and the given type_annotation |
186 | return (Var(string_name, type_annotation)); |
187 | } |
188 | |
189 | /*! |
190 | * \brief Gives the number of variables generated by the CSE on the current function |
191 | (i.e., getter for `nb_var_`). |
192 | * \return A copy of `nb_var_` |
193 | */ |
194 | int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; } |
195 | |
196 | /*! |
197 | * \brief Toplevel (static) method that performs Common Subexpression Elimination on |
198 | a given statement (which should be the body of a PrimFunc). This method should be |
199 | called for each PrimFunc definition. |
200 | * \param stmt The statement of the function being analyzed, on which we want to perform CSE |
201 | * \param context_init The initial context, which should contain the formal parameters |
202 | of the function being analyzed |
203 | * \return A new statement where CSE has been performed |
204 | */ |
205 | Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init, |
206 | bool identify_equiv_terms) { |
207 | // As this function is being called for each PrimFunc definition, we create a new instance |
208 | // for the one we are having now. |
209 | CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init, |
210 | identify_equiv_terms); |
211 | return common_subexpression_eliminator.VisitStmt(stmt); |
212 | } |
213 | |
214 | /*! |
215 | * \brief Protected constructor of CommonSubexpressionEliminator. |
216 | * \param context_init The context at the beginning of the CSE pass. It should contain the |
217 | formal parameters of the function that will be analyzed |
218 | */ |
219 | CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt, |
220 | const Context& context_init, |
221 | bool identify_equiv_terms) |
222 | : initial_body_(stmt), context_(context_init), identify_equiv_terms_(identify_equiv_terms) {} |
223 | |
224 | /*! |
225 | * \brief The method which overrides the generic dispatcher of StmtExprMutator. |
226 | Entry point to the common subexpression elimination mutator for expressions. |
227 | * \param expr The expression to mutate |
228 | */ |
229 | PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) { |
230 | bool variables_created = false; // Will be needed for knowing if the CSE has created new vars |
231 | PrimExpr result = expr; |
232 | |
233 | // Obtain the (syntactic) eligible computations done by the input expression, and keep it as |
234 | // a ComputationTable, which is a mapping from PrimExpr to size_t, where the size_t is the |
235 | // number of time this exact syntactic computation is being computed. |
236 | ComputationTable table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy( |
237 | expr, IsEligibleComputation, CanContainEligibleComputations); |
238 | |
239 | // Transform the hashtable of *syntactic* eligible computations into a vector of pairs |
240 | // containing *semantic* entities, i.e. where equivalent computations are merged. |
241 | std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr = |
242 | SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, identify_equiv_terms_); |
243 | |
244 | // Sort the vector of semantic entities by decreasing size |
245 | std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(), |
246 | OrderOnExprAndFrequency); |
247 | |
248 | // For each computation done (considering them from biggest to smallest) |
249 | for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) { |
250 | std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i]; |
251 | |
252 | bool ident_equiv_terms = identify_equiv_terms_; // To avoid the capture of "this" |
253 | |
254 | // The predicate later used (when doing replacements) to select expressions that are |
255 | // equivalent to the current computation (`computation_and_nb.first`) |
256 | std::function<bool(const PrimExpr&)> predicate_selector = |
257 | [computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) { |
258 | // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check |
259 | // that `current_expr` is an eligible computation even if we know that |
260 | // `computation_and_nb.first` is eligible by construction, in case that one day the |
261 | // equivalence relation would not preserve the eligibility any more (even though that |
262 | // would probably be a very weird equivalence). |
263 | return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) && |
264 | IsEligibleComputation(current_expr)); |
265 | }; |
266 | |
267 | // See if there is a pair (`var`, `value`) in the context where `value` is semantically |
268 | // equivalent to `computation_and_nb.first` |
269 | auto it_on_var = std::find_if( |
270 | context_.begin(), context_.end(), |
271 | [computation_and_nb, ident_equiv_terms](const std::pair<Var, MaybeValue>& var_and_value) { |
272 | // Note : safe to call value() as we check has_value() just before |
273 | return (var_and_value.second.has_value() && |
274 | EquivalentTerms(var_and_value.second.value(), computation_and_nb.first, |
275 | ident_equiv_terms)); |
276 | }); |
277 | |
278 | // Case where we have a perfectly equivalent computation already available in a variable |
279 | // introduced (i.e, present in context_). |
280 | // Note that this case is needed when the user has written something like |
281 | // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by |
282 | // an already existing variable holding A, when such a variable happens to exist. |
283 | if (it_on_var != context_.end()) { |
284 | // Replace in the current `result` everything that is selected by the selector with |
285 | // the existing variable, without diving into expressions in which we don't have the |
286 | // right to dive. |
287 | result = ReplaceSelectedExpr::ReplaceSelectedExprInExpr( |
288 | result, predicate_selector, it_on_var->first, CanContainEligibleComputations); |
289 | } else { |
290 | // The current computation is not equivalent to a computation already done. We will |
291 | // need to see if we want to introduce it. |
292 | |
293 | // --- Chunk needed for reusing the UndefinedVars() analysis --- |
294 | // 1 - Wraps the computation into a statement |
295 | Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first); |
296 | // 2.1 - Transform the context into a vector of variables instead of pairs |
297 | std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value = |
298 | [](const std::pair<Var, MaybeValue>& pair) { return pair.first; }; |
299 | std::vector<Var> vector_vars_known = VectorMap(context_, forget_value); |
300 | // 2.2 - Transform the std::vector into an Array |
301 | Array<Var> array_vars_known = Array<Var>(vector_vars_known); |
302 | // --- End of chunk needed for reusing the UndefinedVars() analysis --- |
303 | |
304 | // We use the UndefinedVars() analysis to get the undefined vars of the computation |
305 | Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known); |
306 | |
307 | // Check if we can introduce it : if it contains no undefined variables and if we want |
308 | // to introduce it according to the predicate |
309 | if (vars_undefined.empty() && |
310 | PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) { |
311 | // Create a new variable for this computation |
312 | Var new_var = GenerateNewVar(computation_and_nb.first.dtype()); |
313 | // Replace in the current `result` everything that is selected by the selector with |
314 | // the new variable, without diving into expressions in which we don't have the |
315 | // right to dive. |
316 | result = ReplaceSelectedExpr::ReplaceSelectedExprInExpr(result, predicate_selector, new_var, |
317 | CanContainEligibleComputations); |
318 | // Build a let-in that introduces the new variable in the current `result` |
319 | result = Let(new_var, computation_and_nb.first, result); |
320 | // We don't add the variable to the context because the invariant is that the |
321 | // context is the context in which 'result' makes sense, and we've just updated it. |
322 | } else { |
323 | // Here it's not doable to introduce (via a let-in) the computation at this level |
324 | // as it contains variables that are not yet declared, and/or because the predicate |
325 | // did not select it. |
326 | // Either way, we will simply add to the vector of computations the direct subexprs |
327 | // of the current computation, as these ones might be good candidates |
328 | // for being introduced into variables. |
329 | // Note that we don't need to add all of its subexpressions, but only its *direct* |
330 | // subexpressions as we consider them from biggest to smallest, and if they were |
331 | // all added at once, then there could be dependencies between them, as commoning |
332 | // one of them could remove some other possibilities. |
333 | |
334 | // Computing the direct subexpressions will return a small number of direct |
335 | // subexpressions (typically 0 to 3) |
336 | std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions( |
337 | computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations); |
338 | // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by |
339 | // decreasing size/complexity), and it will only insert at locations > i as the |
340 | // direct subexprs are necessarily smaller than the current computation. |
341 | InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs, |
342 | identify_equiv_terms_); |
343 | } |
344 | } |
345 | // Note : we do not remove the current element, as we never look back in the local vector |
346 | } // End of for loop |
347 | |
348 | // If the CSE pass has created some variables, then we run it again as more commoning could |
349 | // potentially happen using the new variables introduced |
350 | if (variables_created) { |
351 | result = VisitExpr(result); |
352 | } else { |
353 | // But if no changes were performed, we recurse inside the children by calling the dispatcher. |
354 | // Calling the dispatcher to the specific treatments, which will update the context |
355 | // appropriately before doing the recursive calls on the children nodes |
356 | result = StmtExprMutator::VisitExpr(result); |
357 | } |
358 | |
359 | return result; |
360 | } |
361 | |
362 | /*! |
363 | * \brief The method which overrides the specific treatment for a LetNode |
364 | */ |
365 | PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) { |
366 | // At this point, we have already done the generic treatment of introducing (via let-in) what |
367 | // was doable at the toplevel of the given let-in. |
368 | |
369 | // Save the context at the entry of the function |
370 | Context context_at_entry = context_; |
371 | |
372 | // Recurse on the `value` field for potentially rewriting it |
373 | PrimExpr value_new = VisitExpr(op->value); |
374 | |
375 | // Augment the context with the association (`var`, `value`) for preparing the next recursion |
376 | // on the `body` |
377 | context_.push_back({op->var, MaybeValue(op->value)}); |
378 | |
379 | // Recurse on the `body` (with this extended context) |
380 | // The recursive call will have potentially done new simplifications, because in this recursive |
381 | // call `var` will be a part of the context. |
382 | // (see in VisitExpr() that no introduction were performed when a computation was using an |
383 | // undefined variable, as that would lead to ill-formed code) |
384 | PrimExpr body_new = VisitExpr(op->body); |
385 | |
386 | // Restaure the context to its content at the entrance to not carry out of scope declarations |
387 | // as the variable introduced by the let-in is not in scope outside of its body |
388 | context_ = context_at_entry; |
389 | |
390 | // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might |
391 | // have been done. |
392 | |
393 | // If the `value` and the `body` of the let-in have been rewritten to the same thing |
394 | if (value_new.same_as(op->value) && body_new.same_as(op->body)) { |
395 | // then return a reference to the same node |
396 | return GetRef<PrimExpr>(op); |
397 | } else { |
398 | // Otherwise return a let-in built with the new `value_new` and the new `body_new` that |
399 | // have just been obtained |
400 | return Let(op->var, value_new, body_new, op->span); |
401 | } |
402 | } |
403 | |
404 | /*! |
405 | * \brief The method which overrides the generic dispatcher of StmtExprMutator. |
406 | Entry point to the common subexpression elimination mutator for statements. |
407 | * \param stmt The statement to mutate. |
408 | */ |
409 | Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) { |
410 | bool variables_created = false; // Will be needed for knowing if the CSE has created new vars |
411 | Stmt result = stmt; |
412 | |
413 | // Obtain the (syntactic) eligible computations done by the input statement, and keep it as |
414 | // a ComputationTable, which is a mapping from PrimExpr to size_t, where the size_t is the |
415 | // number of time this exact syntactic computation is being computed. |
416 | ComputationTable table_syntactic_comp_done_by_stmt = ComputationsDoneBy::GetComputationsDoneBy( |
417 | stmt, IsEligibleComputation, CanContainEligibleComputations); |
418 | |
419 | // Transform the hashtable of *syntactic* eligible computations into a vector of pairs |
420 | // containing *semantic* entities, i.e. where equivalent computations are merged. |
421 | std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_stmt = |
422 | SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, identify_equiv_terms_); |
423 | |
424 | // Sort the vector of semantic entities by decreasing size |
425 | std::sort(semantic_comp_done_by_stmt.begin(), semantic_comp_done_by_stmt.end(), |
426 | OrderOnExprAndFrequency); |
427 | |
428 | // For each computation done (considering them from biggest to smallest) |
429 | for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) { |
430 | std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_stmt[i]; |
431 | |
432 | bool ident_equiv_terms = identify_equiv_terms_; // To avoid the capture of "this" |
433 | |
434 | // The predicate later used (when doing replacements) to select expressions that are |
435 | // equivalent to the current computation (`computation_and_nb.first`) |
436 | std::function<bool(const PrimExpr&)> predicate_selector = |
437 | [computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) { |
438 | // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check |
439 | // that `current_expr` is an eligible computation even if we know that |
440 | // `computation_and_nb.first` is eligible by construction, in case that one day the |
441 | // equivalence relation would not preserve the eligibility any more (even though that |
442 | // would probably be a very weird equivalence). |
443 | return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) && |
444 | IsEligibleComputation(current_expr)); |
445 | }; |
446 | |
447 | // See if there is a pair (`var`, `value`) in the context where `value` is semantically |
448 | // equivalent to `computation_and_nb.first` |
449 | auto it_on_var = std::find_if( |
450 | context_.begin(), context_.end(), |
451 | [computation_and_nb, ident_equiv_terms](const std::pair<Var, MaybeValue>& var_and_value) { |
452 | // Note : safe to call value() as we check has_value() just before |
453 | return (var_and_value.second.has_value() && |
454 | EquivalentTerms(var_and_value.second.value(), computation_and_nb.first, |
455 | ident_equiv_terms)); |
456 | }); |
457 | |
458 | // Case where we have a perfectly equivalent computation already available in a variable |
459 | // introduced (i.e, present in context_). |
460 | // Note that this case is needed when the user has written something like |
461 | // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by |
462 | // an already existing variable holding A, when such a variable happens to exist. |
463 | if (it_on_var != context_.end()) { |
464 | // Replace in the current `result` everything that is selected by the selector with |
465 | // the existing variable, without diving into expressions in which we don't have the |
466 | // right to dive. |
467 | result = ReplaceSelectedExpr::ReplaceSelectedExprInStmt( |
468 | result, predicate_selector, it_on_var->first, CanContainEligibleComputations); |
469 | } else { |
470 | // The current computation is not equivalent to a computation already done. We will |
471 | // need to see if we want to introduce it. |
472 | |
473 | // --- Chunk needed for reusing the UndefinedVars() analysis --- |
474 | // 1 - Wraps the computation into a statement |
475 | Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first); |
476 | // 2.1 - Transform the context into a vector of variables instead of pairs |
477 | std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value = |
478 | [](const std::pair<Var, MaybeValue>& pair) { return pair.first; }; |
479 | std::vector<Var> vector_vars_known = VectorMap(context_, forget_value); |
480 | // 2.2 - Transform the std::vector into an Array |
481 | Array<Var> array_vars_known = Array<Var>(vector_vars_known); |
482 | // --- End of chunk needed for reusing the UndefinedVars() analysis --- |
483 | |
484 | // We use the UndefinedVars() analysis to get the undefined vars of the computation |
485 | Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known); |
486 | |
487 | // Check if we can introduce it : if it contains no undefined variables and if we want |
488 | // to introduce it according to the predicate |
489 | if (vars_undefined.empty() && |
490 | PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) { |
491 | // Create a new variable for this computation |
492 | Var new_var = GenerateNewVar(computation_and_nb.first.dtype()); |
493 | variables_created = true; |
494 | // Replace in the current `result` everything that is selected by the selector with |
495 | // the new variable, without diving into expressions in which we don't have the |
496 | // right to dive. |
497 | result = ReplaceSelectedExpr::ReplaceSelectedExprInStmt(result, predicate_selector, new_var, |
498 | CanContainEligibleComputations); |
499 | // Build a let-in that introduces the new variable in the current `result` |
500 | result = LetStmt(new_var, computation_and_nb.first, result); |
501 | // We don't add the variable to the context because the invariant is that the |
502 | // context is the context in which 'result' makes sense, and we've just updated it. |
503 | } else { |
504 | // Here it's not doable to introduce (via a let-in) the computation at this level |
505 | // as it contains variables that are not yet declared, and/or because the predicate |
506 | // did not select it. |
507 | // Either way, we will simply add to the vector of computations the direct subexprs |
508 | // of the current computation, as these ones might be good candidates |
509 | // for being introduced into variables. |
510 | // Note that we don't need to add all of its subexpressions, but only its *direct* |
511 | // subexpressions as we consider them from biggest to smallest, and if they were |
512 | // all added at once, then there could be dependencies between them, as commoning |
513 | // one of them could remove some other possibilities. |
514 | |
515 | // Computing the direct subexpressions will return a small number of direct |
516 | // subexpressions (typically 0 to 3) |
517 | std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions( |
518 | computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations); |
519 | // The following insertion will maintain `semantic_comp_done_by_stmt` sorted (by |
520 | // decreasing size/complexity), and it will only insert at locations > i as the |
521 | // direct subexprs are necessarily smaller than the current computation. |
522 | InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs, |
523 | identify_equiv_terms_); |
524 | } |
525 | } |
526 | // Note : we do not remove the current element, as we never look back in the local vector |
527 | } // End of for loop |
528 | |
529 | // If the CSE pass has created some variables, then we run it again as more commoning could |
530 | // potentially happen using the new variables introduced |
531 | if (variables_created) { |
532 | result = VisitStmt(result); |
533 | } else { |
534 | // But if no changes were performed, we recurse inside the children by calling the dispatcher. |
535 | // Calling the dispatcher to the specific treatments, which will update the context |
536 | // appropriately before doing the recursive calls on the children nodes |
537 | result = StmtExprMutator::VisitStmt(result); |
538 | } |
539 | |
540 | return result; |
541 | } |
542 | |
543 | /*! |
544 | * \brief The method which overrides the specific treatment for a LetStmtNode |
545 | */ |
546 | Stmt CommonSubexpressionEliminator::VisitStmt_(const LetStmtNode* op) { |
547 | // At this point, we have already done the generic treatment of introducing (via let-in) what |
548 | // was doable at the toplevel of the given let-in. |
549 | |
550 | // Save the context at the entry of the function |
551 | Context context_at_entry = context_; |
552 | |
553 | // Recurse on the `value` field for potentially rewriting it |
554 | PrimExpr value_new = VisitExpr(op->value); |
555 | |
556 | // Augment the context with the association (`var`, `value`) for preparing the next recursion |
557 | // on the `body` |
558 | context_.push_back({op->var, MaybeValue(op->value)}); |
559 | |
560 | // Recurse on the `body` (with this extended context) |
561 | // The recursive call will have potentially done new simplifications, because in this recursive |
562 | // call `var` will be a part of the context. |
563 | // (see in VisitStmt() that no introduction were performed when a computation was using an |
564 | // undefined variable, as that would lead to ill-formed code) |
565 | Stmt body_new = VisitStmt(op->body); |
566 | |
567 | // Restaure the context to its content at the entrance to not carry out of scope declarations |
568 | // as the variable introduced by the let-in is not in scope outside of its body |
569 | context_ = context_at_entry; |
570 | |
571 | // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might |
572 | // have been done. |
573 | |
574 | // If the `value` and the `body` of the let-in have been rewritten to the same thing |
575 | if (value_new.same_as(op->value) && body_new.same_as(op->body)) { |
576 | // Return a reference to the same node |
577 | return GetRef<Stmt>(op); |
578 | } else { |
579 | // Otherwise return a let-in built with the new `value_new` and the new `body_new` that |
580 | // have just been obtained |
581 | return LetStmt(op->var, value_new, body_new, op->span); |
582 | } |
583 | } |
584 | |
585 | /*! |
586 | * \brief The method which overrides the specific treatment for a ForNode |
587 | */ |
588 | Stmt CommonSubexpressionEliminator::VisitStmt_(const ForNode* op) { |
589 | // At this point, we have already done the generic treatment of introducing (via let-in) what |
590 | // was doable at the toplevel of the given for loop. |
591 | |
592 | // Save the context at the entry of the function |
593 | Context context_at_entry = context_; |
594 | |
595 | // Recurse on the `min` field for potentially rewriting it |
596 | PrimExpr min_new = VisitExpr(op->min); |
597 | |
598 | // Recurse on the `extent` field for potentially rewriting it |
599 | PrimExpr extent_new = VisitExpr(op->extent); |
600 | |
601 | // Augment the context with the association {loop_var, no value} (no value as its value will |
602 | // change during the execution of the loop) for preparing the next recursion on the `body` |
603 | context_.push_back({op->loop_var, MaybeValue()}); |
604 | |
605 | // Recurse on the `body` (with this extended context) |
606 | Stmt body_new = VisitStmt(op->body); |
607 | |
608 | // Restaure the context to its content at the entrance to not carry out of scope declarations |
609 | // as the variable introduced by the for loop is not in scope outside of its body |
610 | context_ = context_at_entry; |
611 | |
612 | // Rebuild the for loop with (potentially) a new `min_new`, `extent_new` and `body_new`, where |
613 | // new simplifications might have been done. |
614 | |
615 | // If the `min`, `extent` and `body` of the for loop have been rewritten to the same thing |
616 | if (min_new.same_as(op->min) && extent_new.same_as(op->extent) && body_new.same_as(op->body)) { |
617 | // Return a reference to the same node |
618 | return GetRef<Stmt>(op); |
619 | } else { |
620 | // Otherwise return a for node built with the new `min_new`, `extent_new` and `body_new` |
621 | // that have just been obtained |
622 | return For(op->loop_var, min_new, extent_new, op->kind, body_new, op->thread_binding, |
623 | op->annotations, op->span); |
624 | } |
625 | } |
626 | |
627 | namespace transform { |
628 | |
629 | /*! |
630 | * \brief The function which returns the pass for the Common Subexpression Elimination. |
631 | * \return The pass for performing CSE. |
632 | */ |
633 | Pass CommonSubexprElimTIR(bool enable_cse_tir, bool identify_equiv_terms) { |
634 | auto pass_func = [enable_cse_tir, identify_equiv_terms](PrimFunc f, IRModule m, PassContext ctx) { |
635 | if (enable_cse_tir) { |
636 | auto* n = f.CopyOnWrite(); |
637 | Context context_init; |
638 | // Add to the initial context all the parameters of the function, as that is needed for |
639 | // doing commoning on terms that use these parameters (it is only possible to introduce |
640 | // a term into a new variable at a specific point in the program if all the variables that |
641 | // it uses have already been declared at this point) |
642 | for (auto current_param : f->params) { |
643 | // The parameters of the functions are variables associated with no value |
644 | context_init.push_back({current_param, MaybeValue()}); |
645 | } |
646 | |
647 | // Do the Common Subexpression Elimination on the body of the function, with the initial |
648 | // context that we have prepared |
649 | n->body = CommonSubexpressionEliminator::PerformCSE(std::move(f->body), context_init, |
650 | identify_equiv_terms); |
651 | } |
652 | |
653 | return f; |
654 | }; |
655 | return CreatePrimFuncPass(pass_func, 0, "tir.CommonSubexprElimTIR" , {}); |
656 | } |
657 | |
658 | // The pass can now be invoked via the pass infrastructure, but we also add a Python binding for it |
659 | TVM_REGISTER_GLOBAL("tir.transform.CommonSubexprElimTIR" ).set_body_typed(CommonSubexprElimTIR); |
660 | |
661 | } // namespace transform |
662 | } // namespace tir |
663 | } // namespace tvm |
664 | |