1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file common_subexpr_elim.cc
 * \brief Implementation of the Common Subexpression Elimination (CSE) pass,
    which rewrites statements and expressions in order to eliminate
    redundant computations. To achieve that, common (sub-)expressions
    are bound to fresh variables with let-in bindings, and the places
    where such an expression was used are replaced with the freshly
    introduced variable.
28 */
29
#include "common_subexpr_elim.h"

#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/string.h>
#include <tvm/tir/analysis.h>  // For the analysis which gives the size of an expr
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/function.h>  // For the class PrimFunc
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>  // For the decl of the function returning the pass

#include <algorithm>      // For the algorithms std::find_if and std::sort
#include <functional>     // For std::function
#include <iostream>
#include <sstream>        // For std::ostringstream / std::stringstream
#include <string>
#include <unordered_map>  // For the hashtable datatype
#include <utility>        // For std::pair and std::move
#include <vector>

#include "../analysis/check_contains.h"  // For the visitor CheckContains
#include "common_subexpr_elim_tools.h"   // For the auxiliary analysis (visitors) and tools
#include "replace_selected_expr.h"       // For the mutator ReplaceSelectedExpr
53
54namespace tvm {
55namespace tir {
56
57/*!
58 * \brief Check whether a computation is forbidden for being treated by the CSE pass.
59 The important thing about forbidden computations is that not only we won't want
60 to collect them for the CSE pass, but we also won't even want to collect computations
61 that contain them.
62 The reason is that reusing such computations would change the semantics of the program,
63 and therefore before doing any introduction of var or any reuse of already introduced
64 variables, we will make sure that the computation being considered is not forbidden, and
65 that it does not even contain a forbidden computation.
66 * \param expr The expression to check
67 * \return Whether `expr` is a forbidden computation or not
68 */
69bool CommonSubexpressionEliminator::ForbiddenComputation(const PrimExpr& expr) {
70 // Function calls, loads and buffer loads are absolutely forbidden as introducing them into
71 // variables would change the semantics of the program.
72 return (expr.as<CallNode>() != nullptr || expr.as<LoadNode>() != nullptr ||
73 expr.as<BufferLoadNode>() != nullptr);
74}
75
76/*!
77 * \brief Predicate used for verifying that a computation is eligible for being treated by
78 the CSE pass, i.e. for being introduced into a variable / for being replaced by a
79 variable.
80 Being eligible is a conjunction of a few conditions, like not being an atom (constant
81 or variable), not being a forbidden node, not containing a forbidden node, etc.
82 * \param expr The expression to check
83 * \return Whether `expr` is an eligible computation or not
84 */
85bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) {
86 return (
87 // In order to be eligible, the given expression should not be a constant
88 (expr.as<IntImmNode>() == nullptr) && (expr.as<FloatImmNode>() == nullptr) &&
89 (expr.as<StringImmNode>() == nullptr)
90 // and it should not be a variable
91 && (expr.as<VarNode>() == nullptr)
92 // and it should not be a forbidden computation (function calls and loads)
93 && (!ForbiddenComputation(expr))
94 // and it should not even contain a forbidden computation (function calls and loads)
95 // the reason is that we don't want to register expressions like (x + f(y)) or
96 // (x + Mem[i]) as introducing them into variables could change the semantics
97 && (!CheckContains::ExprContains(expr, ForbiddenComputation))
98 // and it should not be a ramp node or a broadcast node due to some internals TVM
99 // constraints (which check for these node explicitely without performing any
100 // evaluation first, so if they have been put into variables it fails)
101 && (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
102}
103
104/*!
105 * \brief Predicate used (when considering eligible computations) for only diving into
106 expressions that are allowed to contain eligible computations. Customize this predicate
107 if you want to make it forbidden to rewrite inside a specific node, like inside
108 a Load node for instance.
109 * \param expr The expression to check
110 * \return Whether `expr` can contain some eligible computations or not, and therefore
111 if recursing inside `expr` is necessary.
112 */
113bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExpr& expr) {
114 // Uncomment the next line to prevent the collection and the replacement of eligible computations
115 // inside the index of Load nodes. We initially thought that this would be needed in order to
116 // not harm the indexing mode of the CPU, but as we are still far from ASM code, we
117 // finally want to perform such simplifications, which tend to happen fairly frequently.
118
119 // return ( (expr.as<LoadNode>() == nullptr) && (expr.as<BufferLoadNode>() == nullptr) )
120 return true;
121}
122
123/*!
124 * \brief Implements an order on pairs (expression,frequency). First attempts to compare them
125 using the size of the expression. If it is the same, decides something else still
126 deterministic.
127 * \param a The first pair
128 * \param b The second pair
129 * \return A boolean telling if the first pair `a` comes before the second pair `b`
130 * \note We need this order to be deterministic in order to have a fully deterministic pass,
131 * as we will deal with elements that are coming from a hashtable, but the order in which
132 * they appeared in the hashtable was based on some runtime addresses, so it can potentially
133 * change with every execution.
134 */
135bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(std::pair<PrimExpr, size_t> a,
136 std::pair<PrimExpr, size_t> b) {
137 size_t a_size = CalculateExprComplexity(a.first);
138 size_t b_size = CalculateExprComplexity(b.first);
139
140 // Criteria 1 - Size of the expression comes first
141 // `a` comes before `b` if the size of `a` is bigger
142 if (a_size > b_size) {
143 return true;
144 }
145 // `a` does NOT come before `b` if the size of `b` is bigger
146 if (b_size > a_size) {
147 return false;
148 }
149
150 // Criteria 2 - If they had the same size, use the lexicographic order as a last resort
151 // as we need a deterministic order
152 std::stringstream a_stream;
153 std::stringstream b_stream;
154 a_stream << AsLegacyRepr(a.first);
155 b_stream << AsLegacyRepr(b.first);
156 return (a_stream.str().compare(b_stream.str()) < 0);
157}
158
159/*!
160 * \brief Generates a new fresh variable, whose name will be cse_var_i.
161 * \param type_annotation The type of the new variable to generate
162 * \return A new variable of type `type_annotation` called cse_var_i where i is the first available
163 integer.
164 */
165Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) {
166 // Increase `num_last_try_` for this new attempt
167 num_last_try_++;
168 // Builds the variable name, which is sce_var_i where i will go up from 1
169 std::string prefix = "cse_var_";
170 std::string name = prefix.append(std::to_string(num_last_try_));
171 // Builds a String using the std::string
172 String string_name(name);
173
174 // Check that the name that we want to use for the new variable isn't already being used
175 // (names don't really have to be unique as they are just hints, and having the same name
176 // doesn't means that it's the same variable, but it's clearer for dumps)
177 if (UsesVarName::StmtUsesVarName(initial_body_, string_name)) {
178 // If the name is already used, call ourselves recursively for trying with the next one
179 return GenerateNewVar(type_annotation);
180 }
181
182 // Increase `nb_var_` for this new generation of variable that we have just done
183 nb_var_++;
184
185 // Return a new Variable using the name built and the given type_annotation
186 return (Var(string_name, type_annotation));
187}
188
189/*!
190 * \brief Gives the number of variables generated by the CSE on the current function
191 (i.e., getter for `nb_var_`).
192 * \return A copy of `nb_var_`
193 */
194int CommonSubexpressionEliminator::GetNbVarGenerated() { return nb_var_; }
195
196/*!
197 * \brief Toplevel (static) method that performs Common Subexpression Elimination on
198 a given statement (which should be the body of a PrimFunc). This method should be
199 called for each PrimFunc definition.
200 * \param stmt The statement of the function being analyzed, on which we want to perform CSE
201 * \param context_init The initial context, which should contain the formal parameters
202 of the function being analyzed
203 * \return A new statement where CSE has been performed
204 */
205Stmt CommonSubexpressionEliminator::PerformCSE(const Stmt& stmt, const Context& context_init,
206 bool identify_equiv_terms) {
207 // As this function is being called for each PrimFunc definition, we create a new instance
208 // for the one we are having now.
209 CommonSubexpressionEliminator common_subexpression_eliminator(stmt, context_init,
210 identify_equiv_terms);
211 return common_subexpression_eliminator.VisitStmt(stmt);
212}
213
214/*!
215 * \brief Protected constructor of CommonSubexpressionEliminator.
216 * \param context_init The context at the beginning of the CSE pass. It should contain the
217 formal parameters of the function that will be analyzed
218 */
219CommonSubexpressionEliminator::CommonSubexpressionEliminator(const Stmt& stmt,
220 const Context& context_init,
221 bool identify_equiv_terms)
222 : initial_body_(stmt), context_(context_init), identify_equiv_terms_(identify_equiv_terms) {}
223
224/*!
225 * \brief The method which overrides the generic dispatcher of StmtExprMutator.
226 Entry point to the common subexpression elimination mutator for expressions.
227 * \param expr The expression to mutate
228 */
229PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) {
230 bool variables_created = false; // Will be needed for knowing if the CSE has created new vars
231 PrimExpr result = expr;
232
233 // Obtain the (syntactic) eligible computations done by the input expression, and keep it as
234 // a ComputationTable, which is a mapping from PrimExpr to size_t, where the size_t is the
235 // number of time this exact syntactic computation is being computed.
236 ComputationTable table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
237 expr, IsEligibleComputation, CanContainEligibleComputations);
238
239 // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
240 // containing *semantic* entities, i.e. where equivalent computations are merged.
241 std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
242 SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, identify_equiv_terms_);
243
244 // Sort the vector of semantic entities by decreasing size
245 std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(),
246 OrderOnExprAndFrequency);
247
248 // For each computation done (considering them from biggest to smallest)
249 for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
250 std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
251
252 bool ident_equiv_terms = identify_equiv_terms_; // To avoid the capture of "this"
253
254 // The predicate later used (when doing replacements) to select expressions that are
255 // equivalent to the current computation (`computation_and_nb.first`)
256 std::function<bool(const PrimExpr&)> predicate_selector =
257 [computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) {
258 // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
259 // that `current_expr` is an eligible computation even if we know that
260 // `computation_and_nb.first` is eligible by construction, in case that one day the
261 // equivalence relation would not preserve the eligibility any more (even though that
262 // would probably be a very weird equivalence).
263 return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) &&
264 IsEligibleComputation(current_expr));
265 };
266
267 // See if there is a pair (`var`, `value`) in the context where `value` is semantically
268 // equivalent to `computation_and_nb.first`
269 auto it_on_var = std::find_if(
270 context_.begin(), context_.end(),
271 [computation_and_nb, ident_equiv_terms](const std::pair<Var, MaybeValue>& var_and_value) {
272 // Note : safe to call value() as we check has_value() just before
273 return (var_and_value.second.has_value() &&
274 EquivalentTerms(var_and_value.second.value(), computation_and_nb.first,
275 ident_equiv_terms));
276 });
277
278 // Case where we have a perfectly equivalent computation already available in a variable
279 // introduced (i.e, present in context_).
280 // Note that this case is needed when the user has written something like
281 // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by
282 // an already existing variable holding A, when such a variable happens to exist.
283 if (it_on_var != context_.end()) {
284 // Replace in the current `result` everything that is selected by the selector with
285 // the existing variable, without diving into expressions in which we don't have the
286 // right to dive.
287 result = ReplaceSelectedExpr::ReplaceSelectedExprInExpr(
288 result, predicate_selector, it_on_var->first, CanContainEligibleComputations);
289 } else {
290 // The current computation is not equivalent to a computation already done. We will
291 // need to see if we want to introduce it.
292
293 // --- Chunk needed for reusing the UndefinedVars() analysis ---
294 // 1 - Wraps the computation into a statement
295 Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first);
296 // 2.1 - Transform the context into a vector of variables instead of pairs
297 std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value =
298 [](const std::pair<Var, MaybeValue>& pair) { return pair.first; };
299 std::vector<Var> vector_vars_known = VectorMap(context_, forget_value);
300 // 2.2 - Transform the std::vector into an Array
301 Array<Var> array_vars_known = Array<Var>(vector_vars_known);
302 // --- End of chunk needed for reusing the UndefinedVars() analysis ---
303
304 // We use the UndefinedVars() analysis to get the undefined vars of the computation
305 Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known);
306
307 // Check if we can introduce it : if it contains no undefined variables and if we want
308 // to introduce it according to the predicate
309 if (vars_undefined.empty() &&
310 PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) {
311 // Create a new variable for this computation
312 Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
313 // Replace in the current `result` everything that is selected by the selector with
314 // the new variable, without diving into expressions in which we don't have the
315 // right to dive.
316 result = ReplaceSelectedExpr::ReplaceSelectedExprInExpr(result, predicate_selector, new_var,
317 CanContainEligibleComputations);
318 // Build a let-in that introduces the new variable in the current `result`
319 result = Let(new_var, computation_and_nb.first, result);
320 // We don't add the variable to the context because the invariant is that the
321 // context is the context in which 'result' makes sense, and we've just updated it.
322 } else {
323 // Here it's not doable to introduce (via a let-in) the computation at this level
324 // as it contains variables that are not yet declared, and/or because the predicate
325 // did not select it.
326 // Either way, we will simply add to the vector of computations the direct subexprs
327 // of the current computation, as these ones might be good candidates
328 // for being introduced into variables.
329 // Note that we don't need to add all of its subexpressions, but only its *direct*
330 // subexpressions as we consider them from biggest to smallest, and if they were
331 // all added at once, then there could be dependencies between them, as commoning
332 // one of them could remove some other possibilities.
333
334 // Computing the direct subexpressions will return a small number of direct
335 // subexpressions (typically 0 to 3)
336 std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
337 computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations);
338 // The following insertion will maintain `semantic_comp_done_by_expr` sorted (by
339 // decreasing size/complexity), and it will only insert at locations > i as the
340 // direct subexprs are necessarily smaller than the current computation.
341 InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs,
342 identify_equiv_terms_);
343 }
344 }
345 // Note : we do not remove the current element, as we never look back in the local vector
346 } // End of for loop
347
348 // If the CSE pass has created some variables, then we run it again as more commoning could
349 // potentially happen using the new variables introduced
350 if (variables_created) {
351 result = VisitExpr(result);
352 } else {
353 // But if no changes were performed, we recurse inside the children by calling the dispatcher.
354 // Calling the dispatcher to the specific treatments, which will update the context
355 // appropriately before doing the recursive calls on the children nodes
356 result = StmtExprMutator::VisitExpr(result);
357 }
358
359 return result;
360}
361
362/*!
363 * \brief The method which overrides the specific treatment for a LetNode
364 */
365PrimExpr CommonSubexpressionEliminator::VisitExpr_(const LetNode* op) {
366 // At this point, we have already done the generic treatment of introducing (via let-in) what
367 // was doable at the toplevel of the given let-in.
368
369 // Save the context at the entry of the function
370 Context context_at_entry = context_;
371
372 // Recurse on the `value` field for potentially rewriting it
373 PrimExpr value_new = VisitExpr(op->value);
374
375 // Augment the context with the association (`var`, `value`) for preparing the next recursion
376 // on the `body`
377 context_.push_back({op->var, MaybeValue(op->value)});
378
379 // Recurse on the `body` (with this extended context)
380 // The recursive call will have potentially done new simplifications, because in this recursive
381 // call `var` will be a part of the context.
382 // (see in VisitExpr() that no introduction were performed when a computation was using an
383 // undefined variable, as that would lead to ill-formed code)
384 PrimExpr body_new = VisitExpr(op->body);
385
386 // Restaure the context to its content at the entrance to not carry out of scope declarations
387 // as the variable introduced by the let-in is not in scope outside of its body
388 context_ = context_at_entry;
389
390 // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
391 // have been done.
392
393 // If the `value` and the `body` of the let-in have been rewritten to the same thing
394 if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
395 // then return a reference to the same node
396 return GetRef<PrimExpr>(op);
397 } else {
398 // Otherwise return a let-in built with the new `value_new` and the new `body_new` that
399 // have just been obtained
400 return Let(op->var, value_new, body_new, op->span);
401 }
402}
403
/*!
 * \brief The method which overrides the generic dispatcher of StmtExprMutator.
 *        Entry point to the common subexpression elimination mutator for statements.
 *        The eligible computations done by `stmt` are considered from the biggest to the
 *        smallest. Each one is either replaced by an existing context variable holding an
 *        equivalent value, introduced into a new variable via a LetStmt, or decomposed into
 *        its direct subexpressions, which become new candidates. If variables were created,
 *        the whole treatment runs again on the result. Otherwise it recurses into the
 *        children through the dispatcher.
 * \param stmt The statement to mutate.
 * \return The rewritten statement.
 */
Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) {
  bool variables_created = false;  // Will be needed for knowing if the CSE has created new vars
  Stmt result = stmt;

  // Obtain the (syntactic) eligible computations done by the input statement, and keep it as
  // a ComputationTable, which is a mapping from PrimExpr to size_t, where the size_t is the
  // number of times this exact syntactic computation is being computed.
  ComputationTable table_syntactic_comp_done_by_stmt = ComputationsDoneBy::GetComputationsDoneBy(
      stmt, IsEligibleComputation, CanContainEligibleComputations);

  // Transform the hashtable of *syntactic* eligible computations into a vector of pairs
  // containing *semantic* entities, i.e. where equivalent computations are merged.
  std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_stmt =
      SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, identify_equiv_terms_);

  // Sort the vector of semantic entities by decreasing size
  std::sort(semantic_comp_done_by_stmt.begin(), semantic_comp_done_by_stmt.end(),
            OrderOnExprAndFrequency);

  // For each computation done (considering them from biggest to smallest). The vector may grow
  // inside the loop, so its size is re-read at every iteration.
  for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) {
    // NOTE(review): this is a reference into `semantic_comp_done_by_stmt`, which
    // InsertVectorToSortedSemanticComputations() may grow (and reallocate) below. It is only
    // safe because the reference is not used after that insertion in the same iteration (the
    // lambdas below capture a copy). Keep it that way, or take a copy.
    std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_stmt[i];

    bool ident_equiv_terms = identify_equiv_terms_;  // To avoid the capture of "this"

    // The predicate later used (when doing replacements) to select expressions that are
    // equivalent to the current computation (`computation_and_nb.first`)
    std::function<bool(const PrimExpr&)> predicate_selector =
        [computation_and_nb, ident_equiv_terms](const PrimExpr& current_expr) {
          // `current_expr` should be equivalent to `computation_and_nb.first`, but we also check
          // that `current_expr` is an eligible computation even if we know that
          // `computation_and_nb.first` is eligible by construction, in case that one day the
          // equivalence relation would not preserve the eligibility any more (even though that
          // would probably be a very weird equivalence).
          return (EquivalentTerms(current_expr, computation_and_nb.first, ident_equiv_terms) &&
                  IsEligibleComputation(current_expr));
        };

    // See if there is a pair (`var`, `value`) in the context where `value` is semantically
    // equivalent to `computation_and_nb.first`
    auto it_on_var = std::find_if(
        context_.begin(), context_.end(),
        [computation_and_nb, ident_equiv_terms](const std::pair<Var, MaybeValue>& var_and_value) {
          // Note : safe to call value() as we check has_value() just before
          return (var_and_value.second.has_value() &&
                  EquivalentTerms(var_and_value.second.value(), computation_and_nb.first,
                                  ident_equiv_terms));
        });

    // Case where we have a perfectly equivalent computation already available in a variable
    // introduced (i.e, present in context_).
    // Note that this case is needed when the user has written something like
    // [let x = A in ....A...A...] : we need to be able to replace all the occurrences of A by
    // an already existing variable holding A, when such a variable happens to exist.
    if (it_on_var != context_.end()) {
      // Replace in the current `result` everything that is selected by the selector with
      // the existing variable, without diving into expressions in which we don't have the
      // right to dive.
      result = ReplaceSelectedExpr::ReplaceSelectedExprInStmt(
          result, predicate_selector, it_on_var->first, CanContainEligibleComputations);
    } else {
      // The current computation is not equivalent to a computation already done. We will
      // need to see if we want to introduce it.

      // --- Chunk needed for reusing the UndefinedVars() analysis ---
      // 1 - Wraps the computation into a statement
      Stmt computation_wrapped_in_stmt = Evaluate(computation_and_nb.first);
      // 2.1 - Transform the context into a vector of variables instead of pairs
      std::function<Var(const std::pair<Var, MaybeValue>&)> forget_value =
          [](const std::pair<Var, MaybeValue>& pair) { return pair.first; };
      std::vector<Var> vector_vars_known = VectorMap(context_, forget_value);
      // 2.2 - Transform the std::vector into an Array
      Array<Var> array_vars_known = Array<Var>(vector_vars_known);
      // --- End of chunk needed for reusing the UndefinedVars() analysis ---

      // We use the UndefinedVars() analysis to get the undefined vars of the computation
      Array<Var> vars_undefined = UndefinedVars(computation_wrapped_in_stmt, array_vars_known);

      // Check if we can introduce it : if it contains no undefined variables and if we want
      // to introduce it according to the predicate
      if (vars_undefined.empty() &&
          PredicateIntroVarForComputation(computation_and_nb.first, computation_and_nb.second)) {
        // Create a new variable for this computation
        Var new_var = GenerateNewVar(computation_and_nb.first.dtype());
        // Remember that a variable was created: this triggers the re-run below
        variables_created = true;
        // Replace in the current `result` everything that is selected by the selector with
        // the new variable, without diving into expressions in which we don't have the
        // right to dive.
        result = ReplaceSelectedExpr::ReplaceSelectedExprInStmt(result, predicate_selector, new_var,
                                                                CanContainEligibleComputations);
        // Build a let-in that introduces the new variable in the current `result`
        result = LetStmt(new_var, computation_and_nb.first, result);
        // We don't add the variable to the context because the invariant is that the
        // context is the context in which 'result' makes sense, and we've just updated it.
      } else {
        // Here it's not doable to introduce (via a let-in) the computation at this level
        // as it contains variables that are not yet declared, and/or because the predicate
        // did not select it.
        // Either way, we will simply add to the vector of computations the direct subexprs
        // of the current computation, as these ones might be good candidates
        // for being introduced into variables.
        // Note that we don't need to add all of its subexpressions, but only its *direct*
        // subexpressions as we consider them from biggest to smallest, and if they were
        // all added at once, then there could be dependencies between them, as commoning
        // one of them could remove some other possibilities.

        // Computing the direct subexpressions will return a small number of direct
        // subexpressions (typically 0 to 3)
        std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
            computation_and_nb.first, IsEligibleComputation, CanContainEligibleComputations);
        // The following insertion will maintain `semantic_comp_done_by_stmt` sorted (by
        // decreasing size/complexity), and it will only insert at locations > i as the
        // direct subexprs are necessarily smaller than the current computation.
        // After this call, `computation_and_nb` must not be used any more (see the note above).
        InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs,
                                                 identify_equiv_terms_);
      }
    }
    // Note : we do not remove the current element, as we never look back in the local vector
  }  // End of for loop

  // If the CSE pass has created some variables, then we run it again as more commoning could
  // potentially happen using the new variables introduced
  if (variables_created) {
    result = VisitStmt(result);
  } else {
    // But if no changes were performed, we recurse inside the children by calling the dispatcher.
    // Calling the dispatcher to the specific treatments, which will update the context
    // appropriately before doing the recursive calls on the children nodes
    result = StmtExprMutator::VisitStmt(result);
  }

  return result;
}
542
543/*!
544 * \brief The method which overrides the specific treatment for a LetStmtNode
545 */
546Stmt CommonSubexpressionEliminator::VisitStmt_(const LetStmtNode* op) {
547 // At this point, we have already done the generic treatment of introducing (via let-in) what
548 // was doable at the toplevel of the given let-in.
549
550 // Save the context at the entry of the function
551 Context context_at_entry = context_;
552
553 // Recurse on the `value` field for potentially rewriting it
554 PrimExpr value_new = VisitExpr(op->value);
555
556 // Augment the context with the association (`var`, `value`) for preparing the next recursion
557 // on the `body`
558 context_.push_back({op->var, MaybeValue(op->value)});
559
560 // Recurse on the `body` (with this extended context)
561 // The recursive call will have potentially done new simplifications, because in this recursive
562 // call `var` will be a part of the context.
563 // (see in VisitStmt() that no introduction were performed when a computation was using an
564 // undefined variable, as that would lead to ill-formed code)
565 Stmt body_new = VisitStmt(op->body);
566
567 // Restaure the context to its content at the entrance to not carry out of scope declarations
568 // as the variable introduced by the let-in is not in scope outside of its body
569 context_ = context_at_entry;
570
571 // Rebuild the let-in with a new `value_new` and `body_new` where new simplifications might
572 // have been done.
573
574 // If the `value` and the `body` of the let-in have been rewritten to the same thing
575 if (value_new.same_as(op->value) && body_new.same_as(op->body)) {
576 // Return a reference to the same node
577 return GetRef<Stmt>(op);
578 } else {
579 // Otherwise return a let-in built with the new `value_new` and the new `body_new` that
580 // have just been obtained
581 return LetStmt(op->var, value_new, body_new, op->span);
582 }
583}
584
585/*!
586 * \brief The method which overrides the specific treatment for a ForNode
587 */
588Stmt CommonSubexpressionEliminator::VisitStmt_(const ForNode* op) {
589 // At this point, we have already done the generic treatment of introducing (via let-in) what
590 // was doable at the toplevel of the given for loop.
591
592 // Save the context at the entry of the function
593 Context context_at_entry = context_;
594
595 // Recurse on the `min` field for potentially rewriting it
596 PrimExpr min_new = VisitExpr(op->min);
597
598 // Recurse on the `extent` field for potentially rewriting it
599 PrimExpr extent_new = VisitExpr(op->extent);
600
601 // Augment the context with the association {loop_var, no value} (no value as its value will
602 // change during the execution of the loop) for preparing the next recursion on the `body`
603 context_.push_back({op->loop_var, MaybeValue()});
604
605 // Recurse on the `body` (with this extended context)
606 Stmt body_new = VisitStmt(op->body);
607
608 // Restaure the context to its content at the entrance to not carry out of scope declarations
609 // as the variable introduced by the for loop is not in scope outside of its body
610 context_ = context_at_entry;
611
612 // Rebuild the for loop with (potentially) a new `min_new`, `extent_new` and `body_new`, where
613 // new simplifications might have been done.
614
615 // If the `min`, `extent` and `body` of the for loop have been rewritten to the same thing
616 if (min_new.same_as(op->min) && extent_new.same_as(op->extent) && body_new.same_as(op->body)) {
617 // Return a reference to the same node
618 return GetRef<Stmt>(op);
619 } else {
620 // Otherwise return a for node built with the new `min_new`, `extent_new` and `body_new`
621 // that have just been obtained
622 return For(op->loop_var, min_new, extent_new, op->kind, body_new, op->thread_binding,
623 op->annotations, op->span);
624 }
625}
626
627namespace transform {
628
629/*!
630 * \brief The function which returns the pass for the Common Subexpression Elimination.
631 * \return The pass for performing CSE.
632 */
633Pass CommonSubexprElimTIR(bool enable_cse_tir, bool identify_equiv_terms) {
634 auto pass_func = [enable_cse_tir, identify_equiv_terms](PrimFunc f, IRModule m, PassContext ctx) {
635 if (enable_cse_tir) {
636 auto* n = f.CopyOnWrite();
637 Context context_init;
638 // Add to the initial context all the parameters of the function, as that is needed for
639 // doing commoning on terms that use these parameters (it is only possible to introduce
640 // a term into a new variable at a specific point in the program if all the variables that
641 // it uses have already been declared at this point)
642 for (auto current_param : f->params) {
643 // The parameters of the functions are variables associated with no value
644 context_init.push_back({current_param, MaybeValue()});
645 }
646
647 // Do the Common Subexpression Elimination on the body of the function, with the initial
648 // context that we have prepared
649 n->body = CommonSubexpressionEliminator::PerformCSE(std::move(f->body), context_init,
650 identify_equiv_terms);
651 }
652
653 return f;
654 };
655 return CreatePrimFuncPass(pass_func, 0, "tir.CommonSubexprElimTIR", {});
656}
657
// The pass can be invoked through the pass infrastructure. It is also registered as the global
// function "tir.transform.CommonSubexprElimTIR", which provides a Python binding for it.
TVM_REGISTER_GLOBAL("tir.transform.CommonSubexprElimTIR").set_body_typed(CommonSubexprElimTIR);
660
661} // namespace transform
662} // namespace tir
663} // namespace tvm
664