1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21* \file common_subexpr_elim_tools.cc
22* \brief Implementation of analysis tools and utility functions used
23 by the Common Subexpression Elimination (CSE) pass.
24*/
25
26#include "common_subexpr_elim_tools.h"
27
28#include <tvm/arith/analyzer.h> // For the arith::Analyzer::Simplify() method simplifying terms
29#include <tvm/ir/transform.h> // For the class Pass and the class PassContext
30#include <tvm/runtime/container/string.h>
31#include <tvm/tir/analysis.h> // For the ExprDeepEqual analysis
32#include <tvm/tir/expr.h>
33#include <tvm/tir/expr_functor.h>
34#include <tvm/tir/function.h> // For the class PrimFunc
35#include <tvm/tir/stmt.h>
36#include <tvm/tir/stmt_functor.h>
37#include <tvm/tir/transform.h> // For the declaration of the pass
38
39#include <algorithm> // For std::find_if
40#include <unordered_map> // For the hashtable datatype
41#include <utility>
42#include <vector>
43
44#include "../analysis/check_contains.h" // For the CheckContains analysis
45
46namespace tvm {
47namespace tir {
48
// cache_ is a static data member of the class ComputationsDoneBy, and C++ requires such a static
// member to be defined here, outside of the class, otherwise it causes a linking error.
51ComputationCache ComputationsDoneBy::cache_;
52
53/* ********************************** Class ComputationsDoneBy **********************************
54*********************************************************************************************** */
55
/* This utility class of the CSE pass offers a way of knowing the eligible computations done by a
   statement or expression. A "computation" here is a syntactic entity, represented by a PrimExpr.
   This analysis returns a hashtable associating PrimExpr (a computation done) to a number (which
   is the number of times that this computation is being seen).
   This analysis is used by the CSE pass in order to find potential candidates for being introduced
   into new variables (after having merged semantically equivalent computations).
62
   This analysis is parametrized by two predicates: `is_eligible_computation` and
   `can_contain_computations`.
   The first one helps to select only "eligible" computations, and the second one helps to only
   select computations that are located at an appropriate location (i.e., it tells in which nodes
   the analysis can recurse). The user of the class must define these notions of "eligible
   computation" and of "nodes that can contain eligible computations" for their own use case.
69
   - On a statement, this analysis often returns the union of all the computations that appear in
   its child nodes (i.e., the union of the results of the recursive calls).
   For instance, on the input statement [let a = x+y in Mem[i1+i2] = a+b] it will report (x+y)
   seen once, (i1+i2) seen once, and (a+b) also seen once when used with typical predicates.
   On some nodes, it will return something more complicated that uses the intersection of the
   computations done by the children nodes.
   For instance, on the input statement [if (x+y>z) then a = x+y else a = b-x] it will return
   (x+y) seen twice but it won't report b-x as it is seen only in the else branch.
78
79 - On an expression, this analysis returns the expression itself, except if it is not eligible
80 for being introduced by the CSE pass into a variable according to `is_eligible_computation_`
81 (often because it's a load node or a function call node for instance), in which case it will
82 return the union of the recursive calls on its children, as long as the other predicate
83 `can_contain_computations` evaluates to true to let the algorithm recurse deeper.
84 With typical predicates, on the expression ((w+x)+(y+z)) it will return only the expression
85 itself. But on the expression Load[i1+i2] it might return only (i1+i2) as the full Load node
86 might not be eligible.
87
   This class uses an internal cache of results, so that if one queries it several times on the
   same statement or expression, it will just retrieve the result from its internal cache.
   That avoids some systematic recomputations, which would otherwise happen as the CSE pass first
   analyses the program at the toplevel (asking for the computations done by the root), and then
   dives deeper and deeper into the program, asking for the computations done by the children of
   the root, which were necessarily previously obtained when computing the computations done by
   the root (as the computations done by the root are by definition the union of the computations
   done by the children nodes).
96
97 The somehow difficult aspect of the implementation is the interaction between this caching of
98 results, and the fact that the VisitStmt()/VisitExpr() of an analyzer (a StmtExprVisitor) are
99 void methods which can't return anything, and instead need to accumulate a result into a member
100 variable, which is called `table_of_computations_` here.
101
   In particular, as the specialized methods (all the VisitStmt_() and VisitExpr_() methods) just
   call VisitStmt()/VisitExpr() on all the children nodes within the same instance, if we don't
   want to override each of these specialized methods to change this behaviour, then
   `table_of_computations_` will necessarily be shared by all the children of a given node.
   That requires being careful when trying to write into the cache.
107*/
108
109/*!
110 * \brief Does the union of two tables of computations.
111 * \param table_main Pointer to one of the two tables. The union will be written into it.
112 * \param table_aux The other table, which won't change.
113 * \note Does it directly in the first argument A for efficiency, as the union of A and B
114 * necessarily gives something which contains A, so we avoid its copy.
115 */
116void UnionOfComputationTables(ComputationTable* table_main, const ComputationTable& table_aux) {
117 if (table_main == nullptr) {
118 return;
119 }
120 // Adds each element of the second table to the first one
121 for (const auto& current : table_aux) {
122 (*table_main)[current.first] += current.second;
123 }
124}
125
126/*!
127 * \brief Does the union of three tables of computations.
128 * \param table1 One of the three tables, which won't change.
129 * \param table2 One of the three tables, which won't change.
130 * \param table3 One of the three tables, which won't change.
131 * \note We don't need (at least yet) to have a function working for N tables, even if this
132 * function for 3 tables seems at first glance redundant with the one for 2 tables defined
133 * just above. The reason is that in order to do the union for N tables, we need to know how
134 * to do it for two. That's because we would compute for N tables using the associativity
135 * of the union : T1 U T2 U T3 ... U Tn = ((T1 U T2) U T3) ... U Tn
136 * Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
137 * (at least for now) for N=3, there is at the moment no need for such a generic union over
138 * N tables.
139 */
140ComputationTable UnionOfComputationTables(const ComputationTable& table1,
141 const ComputationTable& table2,
142 const ComputationTable& table3) {
143 ComputationTable result = table1; // Copy needed as the union of 2 writes into its first arg
144 UnionOfComputationTables(&result, table2);
145 UnionOfComputationTables(&result, table3);
146
147 return result;
148}
149
150/*!
151 * \brief Does the intersection of two tables of computations.
152 * \param table1 One of the two tables, which won't change.
153 * \param table2 The other table, which also won't change.
154 */
155ComputationTable IntersectComputationTables(const ComputationTable& table1,
156 const ComputationTable& table2) {
157 ComputationTable result;
158 for (const auto& current : table1) {
159 auto it = table2.find(current.first);
160 if (it != table2.end()) {
161 result[current.first] = current.second + it->second;
162 }
163 }
164 return result;
165}
166
167/*!
168 * \brief Does the intersection of three tables of computations.
169 * \param table1 One of the three tables, which won't change.
170 * \param table2 One of the three tables, which won't change.
171 * \param table3 One of the three tables, which won't change.
172 * \note We don't need (at least yet) to have a function working for N tables, even if this
173 * function for 3 tables seems at first glance redundant with the one for 2 tables defined
174 * just above. The reason is that in order to do the intersection for N tables, we need to
175 * know how to do it for two. That's because we would compute for N tables using the
176 * associativity of the intersection : T1 Inter T2 Inter T3 ... Inter Tn
177 * = ((T1 Inter T2) Inter T3) ... Inter Tn
178 * Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used
179 * (at least for now) for N=3, there is at the moment no need for such a generic intersection
180 * over N tables.
181 */
182ComputationTable IntersectComputationTables(const ComputationTable& table1,
183 const ComputationTable& table2,
184 const ComputationTable& table3) {
185 ComputationTable result = IntersectComputationTables(table1, table2);
186 result = IntersectComputationTables(result, table3);
187 return result;
188}
189
190/*!
191 * \brief Recompute the number of times that each computation in table_main is seen in the tables
192 contained by the vector of tables vecTables. It sets each element to the sum of the times
193 it is seen in each individual table.
194 * \param table_main The main table, for which we recompute the counters.
195 * \param vecTables The vector of tables which won't change.
196 * \note This function is needed because both the intersection (A Inter B) and the union
197 * (A U B U C) adds the individual counters found in A, B and C. So when we treat for
198 * instance an If (which contains a Cond, a Then branch and an Else branch),
199 * it will compute (Then Inter Else) U (Cond Inter Then) U (Cond Inter Else).
200 * In order to get back to the appropriate number (for instance, 3 if seen one time in each
201 * bloc), it is therefore necessary to recompute the counters afterwards, which is what this
202 * function does.
203 */
204void RecomputeNbTimesSeen(ComputationTable* table_main,
205 const std::vector<const ComputationTable*>& vec_tables) {
206 if (table_main == nullptr) {
207 return;
208 }
209 // For each element in the main table
210 for (auto& current_elem : *table_main) {
211 // We will recompute its associated counter.
212 // Set its count to zero as so far it has been seen zero times
213 current_elem.second = 0;
214 // For each table in the vector of tables
215 for (auto current_table : vec_tables) {
216 // Try to find current_elem in the current table
217 auto it = current_table->find(current_elem.first);
218 if (it != current_table->end()) {
219 // If found, increase its counter by the value found in the current table
220 current_elem.second += it->second;
221 }
222 }
223 }
224}
225
226/*!
227 * \brief Builds a table for a node that has three children. A computation will be reported
228 as being computed if it appears in at least two of the children, i.e. if it will aways be
229 computed, regardless of the execution path.
230 * \param table_child1 The table of computations done by the first child.
231 * \param table_child2 The table of computations done by the second child.
232 * \param table_child3 The table of computations done by the third child.
233 * \note This function will be used for obtaining the computations done by If nodes and by For
234 * nodes, which both have three children.
235 */
236ComputationTable BuildTableForThreeChildrenNode(const ComputationTable& table_child1,
237 const ComputationTable& table_child2,
238 const ComputationTable& table_child3) {
239 ComputationTable result;
240 // We look at what the children have in common
241 ComputationTable child2_inter_child3 = IntersectComputationTables(table_child2, table_child3);
242 ComputationTable child1_inter_child2 = IntersectComputationTables(table_child1, table_child2);
243 ComputationTable child1_inter_child3 = IntersectComputationTables(table_child1, table_child3);
244
245 // We do the union of all the things they have in common
246 result = UnionOfComputationTables(child2_inter_child3, child1_inter_child2, child1_inter_child3);
247
248 // Now we need to recompute the numbers associated with each computation, because both the
249 // intersections and the union might have increased the counters, which can now be wrong.
250 std::vector<const ComputationTable*> vec_tables = {&table_child1, &table_child2, &table_child3};
251 RecomputeNbTimesSeen(&result, vec_tables);
252
253 return result;
254}
255
256/*!
257 * \brief Toplevel (static) method for a PrimExpr
258 * \param expr The expr for which we want to know the computations done
259 * \param is_eligible_computation The predicate which decides if an expression is eligible for
260 being introduced in a new variable
261 * \param can_contain_computations The predicate which decides if an expression can contain an
262 eligible computation
263 */
264ComputationTable ComputationsDoneBy::GetComputationsDoneBy(
265 const PrimExpr& expr, std::function<bool(const PrimExpr&)> is_eligible_computation,
266 std::function<bool(const PrimExpr&)> can_contain_computations) {
267 // Chunk for avoiding the lookup (and writing) in the cache for an atom (constant or variable),
268 // for which the table of computations is empty.
269 // (We don't want to use a "line of cache" of that, as that would cost an empty table of
270 // computations in memory for absolutely no gain)
271 if (expr.as<IntImmNode>() != nullptr || expr.as<FloatImmNode>() != nullptr ||
272 expr.as<StringImmNode>() != nullptr || expr.as<VarNode>() != nullptr) {
273 // Return an empty table
274 return {};
275 }
276
277 // See if we have already computed the (table of) computations done by `expr`
278 auto it_table_expr = cache_.cache_expr_table_computations_.find(expr);
279 if (it_table_expr != cache_.cache_expr_table_computations_.end()) {
280 // then we just return it
281 return it_table_expr->second;
282 }
283
284 // Otherwise we will need to compute it, by using an instance of the class ComputationsDoneBy
285 // (as we are currently in a static method)
286 ComputationsDoneBy computations_done_by(is_eligible_computation, can_contain_computations);
287 // Call the VisitExpr() method on it to start the visit
288 computations_done_by.VisitExpr(expr);
289 // Copy the `table_of_computations_` (that `computations_done_by` has computed) into the cache
290 // for the future queries
291 cache_.cache_expr_table_computations_[expr] = computations_done_by.table_of_computations_;
292
293 return computations_done_by.table_of_computations_;
294}
295
296/*!
297 * \brief Toplevel (static) method for a Stmt
298 * \param stmt The stmt for which we want to know the computations done
299 * \param is_eligible_computation The predicate which decides if an expression is eligible for
300 being introduced in a new variable
301 * \param can_contain_computations The predicate which decides if an expression can contain an
302 eligible computation
303 */
304ComputationTable ComputationsDoneBy::GetComputationsDoneBy(
305 const Stmt& stmt, std::function<bool(const PrimExpr&)> is_eligible_computation,
306 std::function<bool(const PrimExpr&)> can_contain_computations) {
307 // See if we have already computed the (table of) computations done by `stmt`
308 auto it_table_stmt = cache_.cache_stmt_table_computations_.find(stmt);
309 if (it_table_stmt != cache_.cache_stmt_table_computations_.end()) {
310 // then we just return it
311 return it_table_stmt->second;
312 }
313
314 // Otherwise we will need to compute it, by using an instance of the class ComputationsDoneBy
315 // (as we are currently in a static method)
316 ComputationsDoneBy computations_done_by(is_eligible_computation, can_contain_computations);
317 // Call the VisitStmt() method on it to start the visit
318 computations_done_by.VisitStmt(stmt);
319 // Copy the `table_of_computations_` that `computations_done_by` has computed into the cache
320 // for the future queries
321 cache_.cache_stmt_table_computations_[stmt] = computations_done_by.table_of_computations_;
322
323 return computations_done_by.table_of_computations_;
324}
325
326/*!
327 * \brief Protected constructor of ComputationsDoneBy.
328 * \param is_eligible_computation The predicate which decides if an expression is eligible for
329 being introduced in a new variable
330 * \param can_contain_computations The predicate which decides if an expression can contain an
331 eligible computation
332 */
333ComputationsDoneBy::ComputationsDoneBy(
334 std::function<bool(const PrimExpr&)> is_eligible_computation,
335 std::function<bool(const PrimExpr&)> can_contain_computations)
336 : is_eligible_computation_(is_eligible_computation),
337 can_contain_computations_(can_contain_computations) {}
338
339/*!
340 * \brief The method which overrides the generic dispatcher of StmtExprVisitor for expressions
341 * \param expr The expression to visit
342 */
343void ComputationsDoneBy::VisitExpr(const PrimExpr& expr) {
344 // Chunk for avoiding the lookup (and writing) in the cache for an atom (constant or variable),
345 // for which the table of computations is empty.
346 // (We don't want to use a "line of cache" of that, as that would cost an empty table of
347 // computations in memory for absolutely no gain)
348 if (expr.as<IntImmNode>() != nullptr || expr.as<FloatImmNode>() != nullptr ||
349 expr.as<StringImmNode>() != nullptr || expr.as<VarNode>() != nullptr) {
350 return;
351 }
352
353 // See if we have already computed the (table of) computations done by `expr`
354 auto it_table_expr = cache_.cache_expr_table_computations_.find(expr);
355 if (it_table_expr != cache_.cache_expr_table_computations_.end()) {
356 // We need to do the union with `table_of_computations_` instead of just writing into it,
357 // because some other childs might have added things into it too. The reason for that is
358 // that `table_of_computations_` is shared between the child nodes of a given expression.
359 UnionOfComputationTables(&table_of_computations_, it_table_expr->second);
360 return;
361 }
362
363 // If we reach this point, it means that we have never computed before the computations done
364 // by 'expr' and will do so now.
365
366 // If the given expression is an eligible computation, we simply "return it" by adding it into
367 // the "result variable" that `table_of_computations_` is.
368 if (is_eligible_computation_(expr)) {
369 // We can add `expr` to the table of computations
370 table_of_computations_[expr]++;
371 return;
372 }
373
374 // If we reach this point, then the given expression is not an eligible computation.
375 // But perhaps we have the right to dive into it to find some smaller eligible computations
376 if (can_contain_computations_(expr)) {
377 ComputationTable temp =
378 ComputationsDoneByChildrenOf(expr, is_eligible_computation_, can_contain_computations_);
379 // We need to do the union with `table_of_computations_` instead of just writing into it,
380 // because some other childs might have added things into it too. The reason for that is
381 // that `table_of_computations_` is shared between the child nodes of a given expression.
382 UnionOfComputationTables(&table_of_computations_, temp);
383 return;
384 }
385
386 // Note that we do not continue by calling the general disptacher
387 // StmtExprVisitor::VisitExpr(expr) as we want the full computations, not their subexpressions.
388}
389
390/*!
391 * \brief The method which overrides the generic dispatcher of StmtExprVisitor for statements
392 * \param stmt The statement to visit
393 */
394void ComputationsDoneBy::VisitStmt(const Stmt& stmt) {
395 // See if we have already computed the (table of) computations done by `stmt`
396 auto it_table_stmt = cache_.cache_stmt_table_computations_.find(stmt);
397 if (it_table_stmt != cache_.cache_stmt_table_computations_.end()) {
398 // We need to do the union with `table_of_computations_` instead of just writing into it,
399 // because some other childs might have added things into it too. The reason for that is
400 // that `table_of_computations_` is shared between the child nodes of a given statement.
401 UnionOfComputationTables(&table_of_computations_, it_table_stmt->second);
402 return;
403 }
404
405 // If we reach this point, it means that we have never computed before the computations done
406 // by `stmt` and will do so now.
407
408 // The computations done by a Stmt node are just the ones done by its children
409 ComputationTable temp =
410 ComputationsDoneByChildrenOf(stmt, is_eligible_computation_, can_contain_computations_);
411 // We need to do the union with `table_of_computations_` instead of just writing into it,
412 // because some other childs might have added things into it too. The reason for that is
413 // that `table_of_computations_` is shared between the child nodes of a given expression.
414 UnionOfComputationTables(&table_of_computations_, temp);
415}
416
417/*!
418 * \brief The method which overrides the specific treatment for an IfThenElseNode
419 */
420void ComputationsDoneBy::VisitStmt_(const IfThenElseNode* op) {
421 // We build the computations done by each of its child, but unlike the overridden method we will
422 // remember each table of computations so that we can at the end compute the needed intersections
423
424 // Calls the VisitExpr() method on the `condition` child
425 VisitExpr(op->condition);
426 ComputationTable computations_done_by_cond = table_of_computations_;
427 // Clear it for not importing the computations of the condition in the computations of the then
428 table_of_computations_.clear();
429
430 // Then calls the VisitStmt() method on the `then_case` child
431 VisitStmt(op->then_case);
432 ComputationTable computations_done_by_then = table_of_computations_;
433 // Clear it for not importing the computations of the then in the computations of the else
434 table_of_computations_.clear();
435
436 ComputationTable computations_done_by_else;
437 if (op->else_case) {
438 // And finally calls the VisitStmt() method on the `else_case` child
439 VisitStmt(op->else_case.value());
440 computations_done_by_else = table_of_computations_;
441 table_of_computations_.clear();
442 }
443
444 // Build a table of computations for this node with three children
445 table_of_computations_ = BuildTableForThreeChildrenNode(
446 computations_done_by_cond, computations_done_by_then, computations_done_by_else);
447
448 // Copy the `table_of_computations_` into the cache
449 // for the future queries
450 Stmt ref_to_op = GetRef<Stmt>(op);
451 cache_.cache_stmt_table_computations_[ref_to_op] = table_of_computations_;
452}
453
454/*!
455 * \brief The method which overrides the specific treatment for a ForNode
456 */
457void ComputationsDoneBy::VisitStmt_(const ForNode* op) {
458 // We build the computations done by each of its child, but unlike the overridden method we will
459 // remember each table of computations so that we can at the end compute the needed intersections
460
461 // Calls the VisitExpr() method on the `min` child
462 VisitExpr(op->min);
463 ComputationTable computations_done_by_min = table_of_computations_;
464 // Clear it for not importing the computations of the min in the computations of the extent
465 table_of_computations_.clear();
466
467 // Then calls the VisitStmt() method on the `extent` child
468 VisitExpr(op->extent);
469 ComputationTable computations_done_by_extent = table_of_computations_;
470 // Clear it for not importing the computations of the extent in the computations of the body
471 table_of_computations_.clear();
472
473 ComputationTable computations_done_by_body;
474 // And finally calls the VisitStmt() method on the `body` child
475 VisitStmt(op->body);
476 computations_done_by_body = table_of_computations_;
477 table_of_computations_.clear();
478
479 // Build a table of computations for this node with three children
480 table_of_computations_ = BuildTableForThreeChildrenNode(
481 computations_done_by_min, computations_done_by_extent, computations_done_by_body);
482
483 // Copy the `table_of_computations_` into the cache
484 // for the future queries
485 Stmt ref_to_op = GetRef<Stmt>(op);
486 cache_.cache_stmt_table_computations_[ref_to_op] = table_of_computations_;
487}
488
489/*!
490 * \brief The method which overrides the specific treatment for a WhileNode
491 */
492void ComputationsDoneBy::VisitStmt_(const WhileNode* op) {
493 // We build the computations done by each of its child, but unlike the overridden method we will
494 // remember each table of computations so that we can at the end compute the needed intersection
495
496 // Calls the VisitExpr() method on the `condition` child
497 VisitExpr(op->condition);
498 ComputationTable computations_done_by_condition = table_of_computations_;
499 // Clear it for not importing the computations of the min in the computations of the extent
500 table_of_computations_.clear();
501
502 // Then calls the VisitStmt() method on the `body` child
503 VisitStmt(op->body);
504 ComputationTable computations_done_by_body = table_of_computations_;
505 // Clear it for not importing the computations of the extent in the computations of the body
506 table_of_computations_.clear();
507
508 // Build a table of computations for this node with two children by computing what is
509 // is common between the two child, i.e. computing their intersection
510 table_of_computations_ =
511 IntersectComputationTables(computations_done_by_condition, computations_done_by_body);
512
513 // Copy the `table_of_computations_` into the cache
514 // for the future queries
515 Stmt ref_to_op = GetRef<Stmt>(op);
516 cache_.cache_stmt_table_computations_[ref_to_op] = table_of_computations_;
517}
518
519/*!
520 * \brief Static method that returns the computations done by the children of an expression.
521 * \param expr The expression to analyze
522 * \param is_eligible_computation The predicate which decides if an expression is eligible for
523 being introduced in a new variable
524 * \param can_contain_computations The predicate which decides if an expression can contain an
525 eligible computation
526 * \return The hashtable containing the (syntactic) computations done by children nodes of `expr`
527 */
528ComputationTable ComputationsDoneBy::ComputationsDoneByChildrenOf(
529 const PrimExpr& expr, std::function<bool(const PrimExpr&)> is_eligible_computation,
530 std::function<bool(const PrimExpr&)> can_contain_computations) {
531 // We will be using an instance of the class ComputationsDoneBy for the child nodes
532 // (ie, they will share the "result" that `table_of_computations_` is)
533 ComputationsDoneBy computations_done_by(is_eligible_computation, can_contain_computations);
534 // Calls the *dispatcher* (not the overriden method)
535 computations_done_by.StmtExprVisitor::VisitExpr(expr);
536 // Now we can copy `table_of_computations_` into the cache for the future queries
537 // Note : in the table, the computations done by `expr` is set to the computations done by its
538 // children, because otherwise we would not have needed to compute them.
539 cache_.cache_expr_table_computations_[expr] = computations_done_by.table_of_computations_;
540
541 return computations_done_by.table_of_computations_;
542}
543
544/*!
545 * \brief Static method that returns the computations done by the children of a statement.
546 * \param stmt The statement to analyze.
547 * \param is_eligible_computation The predicate which decides if an expression is eligible for
548 being introduced in a new variable
549 * \param can_contain_computations The predicate which decides if an expression can contain an
550 eligible computation
551 * \return The hashtable contaning the (syntactic) computations done by children nodes of `stmt`
552 */
553ComputationTable ComputationsDoneBy::ComputationsDoneByChildrenOf(
554 const Stmt& stmt, std::function<bool(const PrimExpr&)> is_eligible_computation,
555 std::function<bool(const PrimExpr&)> can_contain_computations) {
556 // We will be using an instance of the class ComputationsDoneBy for the child nodes
557 // (ie, they will share the "result" that `table_of_computations_` is)
558 ComputationsDoneBy computations_done_by(is_eligible_computation, can_contain_computations);
559 // Calls the *dispatcher* (not the overriden method)
560 computations_done_by.StmtExprVisitor::VisitStmt(stmt);
561 // So now we can copy table_of_computations_ into the cache for the future queries
562 // Note : in the table, the computations done by `stmt` is set to the computations done by its
563 // children, because that's exactly what we mean by "the computations of a statement".
564 cache_.cache_stmt_table_computations_[stmt] = computations_done_by.table_of_computations_;
565
566 return computations_done_by.table_of_computations_;
567}
568
569/* *********************************** Class DirectSubexpr **************************************
570*********************************************************************************************** */
571
/* This utility class of the CSE pass offers a way of obtaining the direct subexpressions
   of a given expression.
   For instance, for (A+(B+C)) it will return A and (B+C) if they are eligible, but not B and C.
   If one of the direct subexpressions is not eligible, it will consider the direct subexprs of
   this ineligible expression (and so on if one of them is not eligible).
   But before continuing recursively on an ineligible term, it makes sure that it has the right to
   do so by checking if `can_contain_computations` evaluates to true.
579
   This is used by the CSE pass, which will first attempt to introduce large computations into new
   variables, and only when that's not possible (either because the computation uses some variables
   not yet within scope, or because it is not computed enough for being a good candidate), it will
   consider its direct subexpressions. That avoids computing all the subexpressions at once, and
   instead evaluates them lazily, if and when needed.
585*/
586
587/*!
588 * \brief Toplevel (static) function that returns the direct subexpressions of a given expression
589 * \param expr The expression to analyze.
590 * \param is_eligible_computation The predicate which decides if an expression is eligible for
591 being introduced in a new variable
592 * \param can_contain_computations The predicate which decides if an expression can contain an
593 eligible computation
594 * \return A vector of PrimExpr containing the direct subexpressions of `expr`
595 */
596std::vector<PrimExpr> DirectSubexpr::GetDirectSubexpressions(
597 const PrimExpr& expr, std::function<bool(const PrimExpr&)> is_eligible_computation,
598 std::function<bool(const PrimExpr&)> can_contain_computations) {
599 DirectSubexpr direct_subexpr(is_eligible_computation, can_contain_computations);
600 direct_subexpr.VisitExpr(expr);
601
602 return direct_subexpr.direct_subexpr_;
603}
604
605/*!
606 * \brief Protected constructor of DirectSubexpr.
607 */
608DirectSubexpr::DirectSubexpr(std::function<bool(const PrimExpr&)> is_eligible_computation,
609 std::function<bool(const PrimExpr&)> can_contain_computations)
610 : is_eligible_computation_(is_eligible_computation),
611 can_contain_computations_(can_contain_computations) {}
612
613/*!
614 * \brief The method which overrides the generic dispatcher of ExprVisitor
615 * \param expr The expression to visit
616 */
617void DirectSubexpr::VisitExpr(const PrimExpr& expr) {
618 // If we have already entered (meaning that we are not dealing with the original expression)
619 if (entered_) {
620 if (is_eligible_computation_(expr)) {
621 direct_subexpr_.push_back(expr);
622 return;
623 } else {
624 if (can_contain_computations_(expr)) {
625 ExprVisitor::VisitExpr(expr);
626 }
627 return;
628 }
629 }
630
631 // If we reach this point, it means that we haven't visited any child node yet, and will need
632 // to dive into the expression, if it is allowed to contain eligible computations
633 if (can_contain_computations_(expr)) {
634 // Take note that now we have already visited some node
635 entered_ = true;
636 ExprVisitor::VisitExpr(expr);
637 }
638}
639
640/* ************************************ Class UsesVarName *************************************
641*********************************************************************************************** */
642
643/*!
644 * \brief Toplevel (static) function that tells if a given expression uses a given variable name.
645 * \param expr The expression to analyze
646 * \param var_name The variable name to check for
647 * \return A boolean telling if `expr` uses `var_name`
648 */
649bool UsesVarName::ExprUsesVarName(const PrimExpr& expr, String var_name) {
650 UsesVarName uses_var_name(var_name);
651 uses_var_name.VisitExpr(expr);
652
653 return uses_var_name.uses_var_name_;
654}
655
656/*!
657 * \brief Toplevel (static) function that tells if a given statement uses a given variable name.
658 * \param stmt The statement to analyze
659 * \param var_name The variable name to check for
660 * \return A boolean telling if `stmt` uses `var_name`
661 */
662bool UsesVarName::StmtUsesVarName(const Stmt& stmt, String var_name) {
663 UsesVarName uses_var_name(var_name);
664 uses_var_name.VisitStmt(stmt);
665
666 return uses_var_name.uses_var_name_;
667}
668
669/*!
670 * \brief Protected constructor of UsesVarName.
671 * \param var_name The String that we are looking for
672 */
673UsesVarName::UsesVarName(String var_name) : var_name_(var_name) {}
674
675/*!
676 * \brief The method which overrides the generic dispatcher of StmtExprVisitor for expressions.
677 * \param expr The expression to visit
678 */
679void UsesVarName::VisitExpr(const PrimExpr& expr) {
680 if (auto var_node = expr.as<VarNode>()) {
681 if (var_node->name_hint == var_name_) {
682 uses_var_name_ = true;
683 return;
684 }
685 }
686 StmtExprVisitor::VisitExpr(expr);
687}
688
689/*!
690 * \brief The method which overrides the generic dispatcher of StmtExprVisitor for statements.
691 * \param stmt The statement to visit
692 */
693void UsesVarName::VisitStmt(const Stmt& stmt) {
694 // We keep exploring only if `uses_var_name_` is false
695 if (!uses_var_name_) {
696 // and in order to do that we call the general dispatcher
697 StmtExprVisitor::VisitStmt(stmt);
698 }
699 // As otherwise we already have our answer
700}
701
702/* ********************************** Utility functions for CSE *********************************
703*********************************************************************************************** */
704
705/*!
706 * \brief Print a table of computation.
707 */
708void PrintComputationTable(const ComputationTable& table) {
709 std::cout << "{" << std::endl;
710 for (const auto& current : table) {
711 std::cout << "(" << current.first << ", " << current.second << ")" << std::endl;
712 }
713 std::cout << "}" << std::endl;
714}
715
716/*!
717 * \brief Decides if two terms are equal syntactically
718 */
719bool EqualTerms(const PrimExpr& a, const PrimExpr& b) {
720 ExprDeepEqual deep_equal_;
721 return deep_equal_(a, b);
722}
723
724/*!
725 * \brief Normalization function of a term, use to decide the equivalence relation of interest
726 * \param expr The expression to normalize
727 * \param do_normalization Whether we want the function to actually do normalization
728 * \note This function can be customized
729 */
730PrimExpr NormalizeTerm(const PrimExpr& expr, bool do_normalization) {
731 if (do_normalization) {
732 // Customize here!
733 // We could decide to normalize terms in a way that identifies them modulo commutativity
734 // (like x+y and y+x), or modulo associativity (like (x+y)+z and x+(y+z)), etc.
735 // For that, a normalization procedure (or an incomplete "pseudo-normalization" like
736 // arith::Analyzer::Simplify) will be used.
737
738 // One possible customization:
739 // Here is just an attempt to do more commonings by using the pseudo-normalization function
740 // offered by arith::Analyzer::Simplify(). "pseudo" because while it is correct (i.e.
741 // the simplification is indeed equivalent to the original term), it is incomplete (i.e.
742 // the returned term is not guaranteed to be a normal form).
743 arith::Analyzer analyzer;
744 return analyzer.Simplify(expr);
745 } else {
746 // If `do_normalization` is false, the equivalence relation just checks the syntactic equality,
747 // so the normalization is just the identity function.
748 return expr;
749 }
750}
751
752/*!
753 * \brief Decides if two terms are equivalent semantically
754 */
755bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b, bool identify_equiv_terms) {
756 // We restrict the equivalence to be decidable by a normalization procedure that is used to
757 // normalize both sides, and to then compare the normal forms with the strict syntactical
758 // equality
759 return EqualTerms(NormalizeTerm(a, identify_equiv_terms), NormalizeTerm(b, identify_equiv_terms));
760}
761
762/*!
763 * \brief Transforms a hashtable of syntactic computations into a vector or pairs
764 (expression, counter) where equivalent computations are merged and their counters added.
765 This function simply looks for semantically equivalent terms in order to get the real
766 total number of times a computation (and semantically equivalent ones) is seen.
767 * \param table The table to transform
768 \note This function is needed because the advantage of the hashtable was the constant lookup.
769 But in order to have this constant lookup, we could not collapse semantically equivalent
770 computations.
771 Attention, the pairs returned are deterministic and will always be the same (as the same
772 canonical representant will always be chosen for a given class of equivalence), but the
773 order in which these pairs appear in the result is not deterministic, as it is based on
774 the order in which we found items in the "normalized hashtable" `norm_table`). The caller
775 is expected to sort the result anyway.
776 */
777std::vector<std::pair<PrimExpr, size_t>> SyntacticToSemanticComputations(
778 const ComputationTable& table, bool identify_equiv_terms) {
779 std::vector<std::pair<PrimExpr, size_t>> result;
780
781 // If we do NOT identify equivalent terms, then we simply need to transform the input hashtable
782 // into a vector, without doing anything else.
783 if (!identify_equiv_terms) {
784 // The result will contain exactly as many elements as the input `table` has
785 result.reserve(table.size());
786 for (const auto& elem : table) {
787 result.push_back(elem);
788 }
789
790 return result;
791 }
792
793 // Otherwise, in order to identify equivalent terms, we will go through a table `norm_table`
794 // where normal forms are the keys., and use it to efficiently merge equivalent terms.
795
796 // In order to produce the result (a vector of semantical entities), the input table will be
797 // normalized. This normalized table will keep the count for each set of equivalent terms
798 // (i.e. each equivalence class), together with a term that did appear in this equivalence class
799 // (in practice, the first term of the equivalence class that was encoutered).
800 std::unordered_map<PrimExpr, std::pair<PrimExpr, size_t>, StructuralHash, ExprDeepEqual>
801 norm_table;
802
803 // In order to avoid frequent rehashing if the norm_table becomes big, we immediately ask for
804 // enough space to store the amount of elements that the input table has, as it's clearly an
805 // upper bound (in the worst case, each element is its own representant, and there is as many
806 // equivalence classes as there are elements)
807 norm_table.reserve(table.size());
808
809 // Transform the input hashtable to a vector and sort it according to some order, as we will be
810 // iterating through its items soon, and the order of appearance will be used to determine the
811 // individual representant for each class of equivalence, which we want to be deterministic
812 // (otherwise {x+y, y+x} could be both replaced by x+y, and on another run by y+x).
813 std::vector<std::pair<PrimExpr, size_t>> sorted_items_of_table(table.begin(), table.end());
814
815 // We do the ordering by comparing the string repr of each expr to get a determinstic ordering
816 sort(sorted_items_of_table.begin(), sorted_items_of_table.end(),
817 [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
818 std::stringstream a_stream;
819 std::stringstream b_stream;
820 a_stream << AsLegacyRepr(a.first);
821 b_stream << AsLegacyRepr(b.first);
822 return a_stream.str().compare(b_stream.str()) < 0;
823 });
824
825 for (const auto& elem : sorted_items_of_table) {
826 PrimExpr norm_elem = NormalizeTerm(elem.first, identify_equiv_terms);
827 // If the normalized term is not already a key in the normalized table
828 auto it_found = norm_table.find(norm_elem);
829 if (it_found == norm_table.end()) {
830 // Then we add the mapping `norm_elem` -> (`elem`.first, `elem`.second) to the norm table
831 // (i.e. `norm_elem` has been seen `elem`.second many times so far, and the chosen element
832 // to represent the equivalence class will be `elem`.first as it's the first element of the
833 // class that we see)
834 norm_table[norm_elem] = elem;
835 } else {
836 // Otherwise, it's not the first time we see a term in this equivalence class, so we just
837 // increase the count of this equivalence class as we now have `elem`.second additional items
838 // coming to the equivalence class.
839 it_found->second.second += elem.second;
840 }
841 }
842
843 // norm_table.size() is the number of equivalence class that we have built, so it's exactly the
844 // number of items that we will return in the vector of semantical entities
845 result.reserve(norm_table.size());
846
847 // Transform the intermediate hashtable `norm_table` into a vector, forgetting the keys,
848 // (which are the normal forms), as they won't be used as the canonical representants (which are
849 // instead the first element of each class that is effectively seen)
850 // Careful : the pairs will never change (the canonical represantants chosen will always be the
851 // same), but the order in which the pairs are produced can vary as we are iterating through the
852 // hashtable `norm_table`. It is not an issue as the called will be sorting the result anyway.
853 std::unordered_map<PrimExpr, std::pair<PrimExpr, size_t>, StructuralHash,
854 ExprDeepEqual>::const_iterator it_norm_table;
855 for (it_norm_table = norm_table.begin(); it_norm_table != norm_table.end(); ++it_norm_table) {
856 result.push_back(it_norm_table->second);
857 }
858
859 return result;
860}
861
862/*!
863 * \brief Predicate that decides if a computation, that is seen `nb_times_seen`, should be
864 introduced in a variable or not.
865 */
866bool PredicateIntroVarForComputation(const PrimExpr& computation, size_t nb_times_seen) {
867 // This predicate could later implement something more fine grained that would take in account
868 // the size of the expression too, as for instance a very large computation could be introduced
869 // as soon as two occurrences are seen, but a smaller one could need three or more occurrences
870 // for being introduced in a variable.
871
872 // But for now, we factorize any eligible item that we see at least twice, regardless of its size
873 return nb_times_seen >= 2;
874}
875
876/*!
877 * \brief Inserts a pair (expr,nb) to a sorted vector of such pairs (which is sorted by decreasing
878 size of expressions) and maintain the vector sorted while doing so.
879 */
880void InsertElemToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
881 const std::pair<PrimExpr, size_t>& pair) {
882 if (sorted_vec == nullptr) {
883 return;
884 }
885 // Find the insertion point using std::lower_bound on a comparison that uses
886 // CalculateExprComplexity(), which computes the "size" of an expr.
887 // std::lower_boud returns an iterator pointing to the first element on which the comparison
888 // does not return true with the given value (`pair` here), i.e, an iterator pointing to the
889 // first element that is not greater or equal than `pair`, i.e, the first element that is
890 // strictly smaller than `pair`.
891 auto insertion_point = std::lower_bound(
892 sorted_vec->begin(), sorted_vec->end(), pair,
893 [](const std::pair<PrimExpr, size_t>& left, const std::pair<PrimExpr, size_t>& right) {
894 return (CalculateExprComplexity(left.first) >= CalculateExprComplexity(right.first));
895 });
896 sorted_vec->insert(insertion_point, pair);
897}
898
899/*!
900 * \brief Inserts a vector of expressions into a sorted vector of computations (which is sorted by
901 decreasing size of the expression) and maintain the vector sorted while doing so.
902 */
903void InsertVectorToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
904 const std::vector<PrimExpr>& vec_to_add,
905 bool identify_equiv_terms, size_t increase_count) {
906 if (sorted_vec == nullptr) {
907 return;
908 }
909 for (auto elem_to_add : vec_to_add) {
910 // See if the current element to add (or an equivalent one) is already present
911 // in the sorted vector
912 auto it_found =
913 std::find_if(sorted_vec->begin(), sorted_vec->end(),
914 [elem_to_add, identify_equiv_terms](std::pair<PrimExpr, size_t> elem) {
915 return EquivalentTerms(elem.first, elem_to_add, identify_equiv_terms);
916 });
917
918 // If we found `elem_to_add` (or an equivalent expression) already in sorted_vec
919 if (it_found != sorted_vec->end()) {
920 // then we just increase its associated count
921 it_found->second += increase_count;
922 } else {
923 // Otherwise we add the pair (`elem_to_add`,1) at the right place
924 InsertElemToSortedSemanticComputations(sorted_vec, {elem_to_add, increase_count});
925 }
926 }
927}
928
929} // namespace tir
930} // namespace tvm
931