1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file common_subexpr_elim_tools.cc |
22 | * \brief Implementation of analysis tools and utility functions used |
23 | by the Common Subexpression Elimination (CSE) pass. |
24 | */ |
25 | |
26 | #include "common_subexpr_elim_tools.h" |
27 | |
28 | #include <tvm/arith/analyzer.h> // For the arith::Analyzer::Simplify() method simplifying terms |
29 | #include <tvm/ir/transform.h> // For the class Pass and the class PassContext |
30 | #include <tvm/runtime/container/string.h> |
31 | #include <tvm/tir/analysis.h> // For the ExprDeepEqual analysis |
32 | #include <tvm/tir/expr.h> |
33 | #include <tvm/tir/expr_functor.h> |
34 | #include <tvm/tir/function.h> // For the class PrimFunc |
35 | #include <tvm/tir/stmt.h> |
36 | #include <tvm/tir/stmt_functor.h> |
37 | #include <tvm/tir/transform.h> // For the declaration of the pass |
38 | |
39 | #include <algorithm> // For std::find_if |
40 | #include <unordered_map> // For the hashtable datatype |
41 | #include <utility> |
42 | #include <vector> |
43 | |
44 | #include "../analysis/check_contains.h" // For the CheckContains analysis |
45 | |
46 | namespace tvm { |
47 | namespace tir { |
48 | |
// `cache_` is a static data member of the class ComputationsDoneBy. C++ requires such a static
// member to be defined outside of the class (which is done here), otherwise linking fails.
// It is the cache queried and filled by the methods of ComputationsDoneBy below : it stores the
// tables of computations already obtained for expressions (`cache_expr_table_computations_`)
// and for statements (`cache_stmt_table_computations_`).
ComputationCache ComputationsDoneBy::cache_;
52 | |
53 | /* ********************************** Class ComputationsDoneBy ********************************** |
54 | *********************************************************************************************** */ |
55 | |
56 | /* This utility class of the CSE pass offers a way of knowing the eligible computations done by a |
   statement or expression. A "computation" here is a syntactical entity, represented by a
   PrimExpr. This analysis returns a hashtable associating PrimExpr (a computation done) to a
   number (which is the number of times that this computation is being seen).
60 | This analysis is used by the CSE pass in order to find potential candidates for being introduced |
61 | into new variables (after having merged semantically equivalent computations). |
62 | |
63 | This analysis is parametrized by two predicates : `is_eligible_computation` and |
64 | `can_contain_computations`. |
65 | The first one helps to select only "eligible" computations, and the second one helps to only |
66 | select computations that are located at appropriate location (i.e., it tells in which nodes the |
67 | analysis can recurse). The user of the class must define these notions of "eligible computation" |
   and of "nodes that can contain eligible computations" for their own use case.
69 | |
   - On a statement, this analysis often returns the union of all the computations that appear in
     its child nodes (i.e., the union of the results of the recursive calls).
72 | For instance, on the input statement [let a = x+y in Mem[i1+i2] = a+b] it will report (x+y) |
73 | seen once, (i1+i2) seen once, and (a+b) also seen once when used with typical predicates. |
74 | On some nodes, it will return something more complicated that uses the intersection of the |
75 | computations done by the children nodes. |
76 | For instance, on the input statement [if (x+y>z) then a = x+y else a = b-x] it will return |
     (x+y) seen twice but it won't report b-x, as it is seen only in the else branch.
78 | |
79 | - On an expression, this analysis returns the expression itself, except if it is not eligible |
80 | for being introduced by the CSE pass into a variable according to `is_eligible_computation_` |
81 | (often because it's a load node or a function call node for instance), in which case it will |
82 | return the union of the recursive calls on its children, as long as the other predicate |
83 | `can_contain_computations` evaluates to true to let the algorithm recurse deeper. |
84 | With typical predicates, on the expression ((w+x)+(y+z)) it will return only the expression |
85 | itself. But on the expression Load[i1+i2] it might return only (i1+i2) as the full Load node |
86 | might not be eligible. |
87 | |
88 | This class uses an internal cache of results, so that if one queries it several times on the |
89 | same statement or expression, it will just retrieve the result from its internal cache. |
90 | That avoids some systematic recomputations, which would otherwise happen as the CSE pass first |
   analyses the program at the toplevel (asking for the computations done by the root), and then
   dives deeper and deeper into the program, asking for the computations done by the children of
   the root, which were necessarily already obtained when computing the computations done by the
94 | root (as the computations done by the root are by definition the union of the computations done |
95 | by the children nodes). |
96 | |
97 | The somehow difficult aspect of the implementation is the interaction between this caching of |
98 | results, and the fact that the VisitStmt()/VisitExpr() of an analyzer (a StmtExprVisitor) are |
99 | void methods which can't return anything, and instead need to accumulate a result into a member |
100 | variable, which is called `table_of_computations_` here. |
101 | |
102 | In particular, as the specialized methods (all the VisitStmt_() and VisitExpr_() methods), just |
103 | call VisitStmt()/VisitExpr() on all the children nodes within the same instance, if we don't |
   want to override each of these specialized methods to change this behaviour, then
   `table_of_computations_` will necessarily be shared by all the children of a given node.
106 | That requires to be careful when trying to write into the cache. |
107 | */ |
108 | |
109 | /*! |
110 | * \brief Does the union of two tables of computations. |
111 | * \param table_main Pointer to one of the two tables. The union will be written into it. |
112 | * \param table_aux The other table, which won't change. |
113 | * \note Does it directly in the first argument A for efficiency, as the union of A and B |
114 | * necessarily gives something which contains A, so we avoid its copy. |
115 | */ |
116 | void UnionOfComputationTables(ComputationTable* table_main, const ComputationTable& table_aux) { |
117 | if (table_main == nullptr) { |
118 | return; |
119 | } |
120 | // Adds each element of the second table to the first one |
121 | for (const auto& current : table_aux) { |
122 | (*table_main)[current.first] += current.second; |
123 | } |
124 | } |
125 | |
126 | /*! |
127 | * \brief Does the union of three tables of computations. |
128 | * \param table1 One of the three tables, which won't change. |
129 | * \param table2 One of the three tables, which won't change. |
130 | * \param table3 One of the three tables, which won't change. |
131 | * \note We don't need (at least yet) to have a function working for N tables, even if this |
132 | * function for 3 tables seems at first glance redundant with the one for 2 tables defined |
133 | * just above. The reason is that in order to do the union for N tables, we need to know how |
134 | * to do it for two. That's because we would compute for N tables using the associativity |
135 | * of the union : T1 U T2 U T3 ... U Tn = ((T1 U T2) U T3) ... U Tn |
136 | * Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used |
137 | * (at least for now) for N=3, there is at the moment no need for such a generic union over |
138 | * N tables. |
139 | */ |
140 | ComputationTable UnionOfComputationTables(const ComputationTable& table1, |
141 | const ComputationTable& table2, |
142 | const ComputationTable& table3) { |
143 | ComputationTable result = table1; // Copy needed as the union of 2 writes into its first arg |
144 | UnionOfComputationTables(&result, table2); |
145 | UnionOfComputationTables(&result, table3); |
146 | |
147 | return result; |
148 | } |
149 | |
150 | /*! |
151 | * \brief Does the intersection of two tables of computations. |
152 | * \param table1 One of the two tables, which won't change. |
153 | * \param table2 The other table, which also won't change. |
154 | */ |
155 | ComputationTable IntersectComputationTables(const ComputationTable& table1, |
156 | const ComputationTable& table2) { |
157 | ComputationTable result; |
158 | for (const auto& current : table1) { |
159 | auto it = table2.find(current.first); |
160 | if (it != table2.end()) { |
161 | result[current.first] = current.second + it->second; |
162 | } |
163 | } |
164 | return result; |
165 | } |
166 | |
167 | /*! |
168 | * \brief Does the intersection of three tables of computations. |
169 | * \param table1 One of the three tables, which won't change. |
170 | * \param table2 One of the three tables, which won't change. |
171 | * \param table3 One of the three tables, which won't change. |
172 | * \note We don't need (at least yet) to have a function working for N tables, even if this |
173 | * function for 3 tables seems at first glance redundant with the one for 2 tables defined |
174 | * just above. The reason is that in order to do the intersection for N tables, we need to |
175 | * know how to do it for two. That's because we would compute for N tables using the |
176 | * associativity of the intersection : T1 Inter T2 Inter T3 ... Inter Tn |
177 | * = ((T1 Inter T2) Inter T3) ... Inter Tn |
178 | * Therefore, we need one for 2 tables anyway. And as the one for N (N>=3) would only be used |
179 | * (at least for now) for N=3, there is at the moment no need for such a generic intersection |
180 | * over N tables. |
181 | */ |
182 | ComputationTable IntersectComputationTables(const ComputationTable& table1, |
183 | const ComputationTable& table2, |
184 | const ComputationTable& table3) { |
185 | ComputationTable result = IntersectComputationTables(table1, table2); |
186 | result = IntersectComputationTables(result, table3); |
187 | return result; |
188 | } |
189 | |
190 | /*! |
191 | * \brief Recompute the number of times that each computation in table_main is seen in the tables |
192 | contained by the vector of tables vecTables. It sets each element to the sum of the times |
193 | it is seen in each individual table. |
194 | * \param table_main The main table, for which we recompute the counters. |
195 | * \param vecTables The vector of tables which won't change. |
196 | * \note This function is needed because both the intersection (A Inter B) and the union |
197 | * (A U B U C) adds the individual counters found in A, B and C. So when we treat for |
198 | * instance an If (which contains a Cond, a Then branch and an Else branch), |
199 | * it will compute (Then Inter Else) U (Cond Inter Then) U (Cond Inter Else). |
200 | * In order to get back to the appropriate number (for instance, 3 if seen one time in each |
201 | * bloc), it is therefore necessary to recompute the counters afterwards, which is what this |
202 | * function does. |
203 | */ |
204 | void RecomputeNbTimesSeen(ComputationTable* table_main, |
205 | const std::vector<const ComputationTable*>& vec_tables) { |
206 | if (table_main == nullptr) { |
207 | return; |
208 | } |
209 | // For each element in the main table |
210 | for (auto& current_elem : *table_main) { |
211 | // We will recompute its associated counter. |
212 | // Set its count to zero as so far it has been seen zero times |
213 | current_elem.second = 0; |
214 | // For each table in the vector of tables |
215 | for (auto current_table : vec_tables) { |
216 | // Try to find current_elem in the current table |
217 | auto it = current_table->find(current_elem.first); |
218 | if (it != current_table->end()) { |
219 | // If found, increase its counter by the value found in the current table |
220 | current_elem.second += it->second; |
221 | } |
222 | } |
223 | } |
224 | } |
225 | |
226 | /*! |
227 | * \brief Builds a table for a node that has three children. A computation will be reported |
228 | as being computed if it appears in at least two of the children, i.e. if it will aways be |
229 | computed, regardless of the execution path. |
230 | * \param table_child1 The table of computations done by the first child. |
231 | * \param table_child2 The table of computations done by the second child. |
232 | * \param table_child3 The table of computations done by the third child. |
233 | * \note This function will be used for obtaining the computations done by If nodes and by For |
234 | * nodes, which both have three children. |
235 | */ |
236 | ComputationTable BuildTableForThreeChildrenNode(const ComputationTable& table_child1, |
237 | const ComputationTable& table_child2, |
238 | const ComputationTable& table_child3) { |
239 | ComputationTable result; |
240 | // We look at what the children have in common |
241 | ComputationTable child2_inter_child3 = IntersectComputationTables(table_child2, table_child3); |
242 | ComputationTable child1_inter_child2 = IntersectComputationTables(table_child1, table_child2); |
243 | ComputationTable child1_inter_child3 = IntersectComputationTables(table_child1, table_child3); |
244 | |
245 | // We do the union of all the things they have in common |
246 | result = UnionOfComputationTables(child2_inter_child3, child1_inter_child2, child1_inter_child3); |
247 | |
248 | // Now we need to recompute the numbers associated with each computation, because both the |
249 | // intersections and the union might have increased the counters, which can now be wrong. |
250 | std::vector<const ComputationTable*> vec_tables = {&table_child1, &table_child2, &table_child3}; |
251 | RecomputeNbTimesSeen(&result, vec_tables); |
252 | |
253 | return result; |
254 | } |
255 | |
256 | /*! |
257 | * \brief Toplevel (static) method for a PrimExpr |
258 | * \param expr The expr for which we want to know the computations done |
259 | * \param is_eligible_computation The predicate which decides if an expression is eligible for |
260 | being introduced in a new variable |
261 | * \param can_contain_computations The predicate which decides if an expression can contain an |
262 | eligible computation |
263 | */ |
264 | ComputationTable ComputationsDoneBy::GetComputationsDoneBy( |
265 | const PrimExpr& expr, std::function<bool(const PrimExpr&)> is_eligible_computation, |
266 | std::function<bool(const PrimExpr&)> can_contain_computations) { |
267 | // Chunk for avoiding the lookup (and writing) in the cache for an atom (constant or variable), |
268 | // for which the table of computations is empty. |
269 | // (We don't want to use a "line of cache" of that, as that would cost an empty table of |
270 | // computations in memory for absolutely no gain) |
271 | if (expr.as<IntImmNode>() != nullptr || expr.as<FloatImmNode>() != nullptr || |
272 | expr.as<StringImmNode>() != nullptr || expr.as<VarNode>() != nullptr) { |
273 | // Return an empty table |
274 | return {}; |
275 | } |
276 | |
277 | // See if we have already computed the (table of) computations done by `expr` |
278 | auto it_table_expr = cache_.cache_expr_table_computations_.find(expr); |
279 | if (it_table_expr != cache_.cache_expr_table_computations_.end()) { |
280 | // then we just return it |
281 | return it_table_expr->second; |
282 | } |
283 | |
284 | // Otherwise we will need to compute it, by using an instance of the class ComputationsDoneBy |
285 | // (as we are currently in a static method) |
286 | ComputationsDoneBy computations_done_by(is_eligible_computation, can_contain_computations); |
287 | // Call the VisitExpr() method on it to start the visit |
288 | computations_done_by.VisitExpr(expr); |
289 | // Copy the `table_of_computations_` (that `computations_done_by` has computed) into the cache |
290 | // for the future queries |
291 | cache_.cache_expr_table_computations_[expr] = computations_done_by.table_of_computations_; |
292 | |
293 | return computations_done_by.table_of_computations_; |
294 | } |
295 | |
296 | /*! |
297 | * \brief Toplevel (static) method for a Stmt |
298 | * \param stmt The stmt for which we want to know the computations done |
299 | * \param is_eligible_computation The predicate which decides if an expression is eligible for |
300 | being introduced in a new variable |
301 | * \param can_contain_computations The predicate which decides if an expression can contain an |
302 | eligible computation |
303 | */ |
304 | ComputationTable ComputationsDoneBy::GetComputationsDoneBy( |
305 | const Stmt& stmt, std::function<bool(const PrimExpr&)> is_eligible_computation, |
306 | std::function<bool(const PrimExpr&)> can_contain_computations) { |
307 | // See if we have already computed the (table of) computations done by `stmt` |
308 | auto it_table_stmt = cache_.cache_stmt_table_computations_.find(stmt); |
309 | if (it_table_stmt != cache_.cache_stmt_table_computations_.end()) { |
310 | // then we just return it |
311 | return it_table_stmt->second; |
312 | } |
313 | |
314 | // Otherwise we will need to compute it, by using an instance of the class ComputationsDoneBy |
315 | // (as we are currently in a static method) |
316 | ComputationsDoneBy computations_done_by(is_eligible_computation, can_contain_computations); |
317 | // Call the VisitStmt() method on it to start the visit |
318 | computations_done_by.VisitStmt(stmt); |
319 | // Copy the `table_of_computations_` that `computations_done_by` has computed into the cache |
320 | // for the future queries |
321 | cache_.cache_stmt_table_computations_[stmt] = computations_done_by.table_of_computations_; |
322 | |
323 | return computations_done_by.table_of_computations_; |
324 | } |
325 | |
326 | /*! |
327 | * \brief Protected constructor of ComputationsDoneBy. |
328 | * \param is_eligible_computation The predicate which decides if an expression is eligible for |
329 | being introduced in a new variable |
330 | * \param can_contain_computations The predicate which decides if an expression can contain an |
331 | eligible computation |
332 | */ |
333 | ComputationsDoneBy::ComputationsDoneBy( |
334 | std::function<bool(const PrimExpr&)> is_eligible_computation, |
335 | std::function<bool(const PrimExpr&)> can_contain_computations) |
336 | : is_eligible_computation_(is_eligible_computation), |
337 | can_contain_computations_(can_contain_computations) {} |
338 | |
339 | /*! |
340 | * \brief The method which overrides the generic dispatcher of StmtExprVisitor for expressions |
341 | * \param expr The expression to visit |
342 | */ |
343 | void ComputationsDoneBy::VisitExpr(const PrimExpr& expr) { |
344 | // Chunk for avoiding the lookup (and writing) in the cache for an atom (constant or variable), |
345 | // for which the table of computations is empty. |
346 | // (We don't want to use a "line of cache" of that, as that would cost an empty table of |
347 | // computations in memory for absolutely no gain) |
348 | if (expr.as<IntImmNode>() != nullptr || expr.as<FloatImmNode>() != nullptr || |
349 | expr.as<StringImmNode>() != nullptr || expr.as<VarNode>() != nullptr) { |
350 | return; |
351 | } |
352 | |
353 | // See if we have already computed the (table of) computations done by `expr` |
354 | auto it_table_expr = cache_.cache_expr_table_computations_.find(expr); |
355 | if (it_table_expr != cache_.cache_expr_table_computations_.end()) { |
356 | // We need to do the union with `table_of_computations_` instead of just writing into it, |
357 | // because some other childs might have added things into it too. The reason for that is |
358 | // that `table_of_computations_` is shared between the child nodes of a given expression. |
359 | UnionOfComputationTables(&table_of_computations_, it_table_expr->second); |
360 | return; |
361 | } |
362 | |
363 | // If we reach this point, it means that we have never computed before the computations done |
364 | // by 'expr' and will do so now. |
365 | |
366 | // If the given expression is an eligible computation, we simply "return it" by adding it into |
367 | // the "result variable" that `table_of_computations_` is. |
368 | if (is_eligible_computation_(expr)) { |
369 | // We can add `expr` to the table of computations |
370 | table_of_computations_[expr]++; |
371 | return; |
372 | } |
373 | |
374 | // If we reach this point, then the given expression is not an eligible computation. |
375 | // But perhaps we have the right to dive into it to find some smaller eligible computations |
376 | if (can_contain_computations_(expr)) { |
377 | ComputationTable temp = |
378 | ComputationsDoneByChildrenOf(expr, is_eligible_computation_, can_contain_computations_); |
379 | // We need to do the union with `table_of_computations_` instead of just writing into it, |
380 | // because some other childs might have added things into it too. The reason for that is |
381 | // that `table_of_computations_` is shared between the child nodes of a given expression. |
382 | UnionOfComputationTables(&table_of_computations_, temp); |
383 | return; |
384 | } |
385 | |
386 | // Note that we do not continue by calling the general disptacher |
387 | // StmtExprVisitor::VisitExpr(expr) as we want the full computations, not their subexpressions. |
388 | } |
389 | |
390 | /*! |
391 | * \brief The method which overrides the generic dispatcher of StmtExprVisitor for statements |
392 | * \param stmt The statement to visit |
393 | */ |
394 | void ComputationsDoneBy::VisitStmt(const Stmt& stmt) { |
395 | // See if we have already computed the (table of) computations done by `stmt` |
396 | auto it_table_stmt = cache_.cache_stmt_table_computations_.find(stmt); |
397 | if (it_table_stmt != cache_.cache_stmt_table_computations_.end()) { |
398 | // We need to do the union with `table_of_computations_` instead of just writing into it, |
399 | // because some other childs might have added things into it too. The reason for that is |
400 | // that `table_of_computations_` is shared between the child nodes of a given statement. |
401 | UnionOfComputationTables(&table_of_computations_, it_table_stmt->second); |
402 | return; |
403 | } |
404 | |
405 | // If we reach this point, it means that we have never computed before the computations done |
406 | // by `stmt` and will do so now. |
407 | |
408 | // The computations done by a Stmt node are just the ones done by its children |
409 | ComputationTable temp = |
410 | ComputationsDoneByChildrenOf(stmt, is_eligible_computation_, can_contain_computations_); |
411 | // We need to do the union with `table_of_computations_` instead of just writing into it, |
412 | // because some other childs might have added things into it too. The reason for that is |
413 | // that `table_of_computations_` is shared between the child nodes of a given expression. |
414 | UnionOfComputationTables(&table_of_computations_, temp); |
415 | } |
416 | |
417 | /*! |
418 | * \brief The method which overrides the specific treatment for an IfThenElseNode |
419 | */ |
420 | void ComputationsDoneBy::VisitStmt_(const IfThenElseNode* op) { |
421 | // We build the computations done by each of its child, but unlike the overridden method we will |
422 | // remember each table of computations so that we can at the end compute the needed intersections |
423 | |
424 | // Calls the VisitExpr() method on the `condition` child |
425 | VisitExpr(op->condition); |
426 | ComputationTable computations_done_by_cond = table_of_computations_; |
427 | // Clear it for not importing the computations of the condition in the computations of the then |
428 | table_of_computations_.clear(); |
429 | |
430 | // Then calls the VisitStmt() method on the `then_case` child |
431 | VisitStmt(op->then_case); |
432 | ComputationTable computations_done_by_then = table_of_computations_; |
433 | // Clear it for not importing the computations of the then in the computations of the else |
434 | table_of_computations_.clear(); |
435 | |
436 | ComputationTable computations_done_by_else; |
437 | if (op->else_case) { |
438 | // And finally calls the VisitStmt() method on the `else_case` child |
439 | VisitStmt(op->else_case.value()); |
440 | computations_done_by_else = table_of_computations_; |
441 | table_of_computations_.clear(); |
442 | } |
443 | |
444 | // Build a table of computations for this node with three children |
445 | table_of_computations_ = BuildTableForThreeChildrenNode( |
446 | computations_done_by_cond, computations_done_by_then, computations_done_by_else); |
447 | |
448 | // Copy the `table_of_computations_` into the cache |
449 | // for the future queries |
450 | Stmt ref_to_op = GetRef<Stmt>(op); |
451 | cache_.cache_stmt_table_computations_[ref_to_op] = table_of_computations_; |
452 | } |
453 | |
454 | /*! |
455 | * \brief The method which overrides the specific treatment for a ForNode |
456 | */ |
457 | void ComputationsDoneBy::VisitStmt_(const ForNode* op) { |
458 | // We build the computations done by each of its child, but unlike the overridden method we will |
459 | // remember each table of computations so that we can at the end compute the needed intersections |
460 | |
461 | // Calls the VisitExpr() method on the `min` child |
462 | VisitExpr(op->min); |
463 | ComputationTable computations_done_by_min = table_of_computations_; |
464 | // Clear it for not importing the computations of the min in the computations of the extent |
465 | table_of_computations_.clear(); |
466 | |
467 | // Then calls the VisitStmt() method on the `extent` child |
468 | VisitExpr(op->extent); |
469 | ComputationTable computations_done_by_extent = table_of_computations_; |
470 | // Clear it for not importing the computations of the extent in the computations of the body |
471 | table_of_computations_.clear(); |
472 | |
473 | ComputationTable computations_done_by_body; |
474 | // And finally calls the VisitStmt() method on the `body` child |
475 | VisitStmt(op->body); |
476 | computations_done_by_body = table_of_computations_; |
477 | table_of_computations_.clear(); |
478 | |
479 | // Build a table of computations for this node with three children |
480 | table_of_computations_ = BuildTableForThreeChildrenNode( |
481 | computations_done_by_min, computations_done_by_extent, computations_done_by_body); |
482 | |
483 | // Copy the `table_of_computations_` into the cache |
484 | // for the future queries |
485 | Stmt ref_to_op = GetRef<Stmt>(op); |
486 | cache_.cache_stmt_table_computations_[ref_to_op] = table_of_computations_; |
487 | } |
488 | |
489 | /*! |
490 | * \brief The method which overrides the specific treatment for a WhileNode |
491 | */ |
492 | void ComputationsDoneBy::VisitStmt_(const WhileNode* op) { |
493 | // We build the computations done by each of its child, but unlike the overridden method we will |
494 | // remember each table of computations so that we can at the end compute the needed intersection |
495 | |
496 | // Calls the VisitExpr() method on the `condition` child |
497 | VisitExpr(op->condition); |
498 | ComputationTable computations_done_by_condition = table_of_computations_; |
499 | // Clear it for not importing the computations of the min in the computations of the extent |
500 | table_of_computations_.clear(); |
501 | |
502 | // Then calls the VisitStmt() method on the `body` child |
503 | VisitStmt(op->body); |
504 | ComputationTable computations_done_by_body = table_of_computations_; |
505 | // Clear it for not importing the computations of the extent in the computations of the body |
506 | table_of_computations_.clear(); |
507 | |
508 | // Build a table of computations for this node with two children by computing what is |
509 | // is common between the two child, i.e. computing their intersection |
510 | table_of_computations_ = |
511 | IntersectComputationTables(computations_done_by_condition, computations_done_by_body); |
512 | |
513 | // Copy the `table_of_computations_` into the cache |
514 | // for the future queries |
515 | Stmt ref_to_op = GetRef<Stmt>(op); |
516 | cache_.cache_stmt_table_computations_[ref_to_op] = table_of_computations_; |
517 | } |
518 | |
519 | /*! |
520 | * \brief Static method that returns the computations done by the children of an expression. |
521 | * \param expr The expression to analyze |
522 | * \param is_eligible_computation The predicate which decides if an expression is eligible for |
523 | being introduced in a new variable |
524 | * \param can_contain_computations The predicate which decides if an expression can contain an |
525 | eligible computation |
526 | * \return The hashtable containing the (syntactic) computations done by children nodes of `expr` |
527 | */ |
528 | ComputationTable ComputationsDoneBy::ComputationsDoneByChildrenOf( |
529 | const PrimExpr& expr, std::function<bool(const PrimExpr&)> is_eligible_computation, |
530 | std::function<bool(const PrimExpr&)> can_contain_computations) { |
531 | // We will be using an instance of the class ComputationsDoneBy for the child nodes |
532 | // (ie, they will share the "result" that `table_of_computations_` is) |
533 | ComputationsDoneBy computations_done_by(is_eligible_computation, can_contain_computations); |
534 | // Calls the *dispatcher* (not the overriden method) |
535 | computations_done_by.StmtExprVisitor::VisitExpr(expr); |
536 | // Now we can copy `table_of_computations_` into the cache for the future queries |
537 | // Note : in the table, the computations done by `expr` is set to the computations done by its |
538 | // children, because otherwise we would not have needed to compute them. |
539 | cache_.cache_expr_table_computations_[expr] = computations_done_by.table_of_computations_; |
540 | |
541 | return computations_done_by.table_of_computations_; |
542 | } |
543 | |
544 | /*! |
545 | * \brief Static method that returns the computations done by the children of a statement. |
546 | * \param stmt The statement to analyze. |
547 | * \param is_eligible_computation The predicate which decides if an expression is eligible for |
548 | being introduced in a new variable |
549 | * \param can_contain_computations The predicate which decides if an expression can contain an |
550 | eligible computation |
551 | * \return The hashtable contaning the (syntactic) computations done by children nodes of `stmt` |
552 | */ |
553 | ComputationTable ComputationsDoneBy::ComputationsDoneByChildrenOf( |
554 | const Stmt& stmt, std::function<bool(const PrimExpr&)> is_eligible_computation, |
555 | std::function<bool(const PrimExpr&)> can_contain_computations) { |
556 | // We will be using an instance of the class ComputationsDoneBy for the child nodes |
557 | // (ie, they will share the "result" that `table_of_computations_` is) |
558 | ComputationsDoneBy computations_done_by(is_eligible_computation, can_contain_computations); |
559 | // Calls the *dispatcher* (not the overriden method) |
560 | computations_done_by.StmtExprVisitor::VisitStmt(stmt); |
561 | // So now we can copy table_of_computations_ into the cache for the future queries |
562 | // Note : in the table, the computations done by `stmt` is set to the computations done by its |
563 | // children, because that's exactly what we mean by "the computations of a statement". |
564 | cache_.cache_stmt_table_computations_[stmt] = computations_done_by.table_of_computations_; |
565 | |
566 | return computations_done_by.table_of_computations_; |
567 | } |
568 | |
569 | /* *********************************** Class DirectSubexpr ************************************** |
570 | *********************************************************************************************** */ |
571 | |
/* This utility class of the CSE pass offers a way of obtaining the direct subexpressions
   of a given expression.
   For instance, for (A+(B+C)) it will return A and (B+C) if they are eligible, but not B and C.
   If one of the direct subexpressions is not eligible, it will consider the direct subexprs of
   this ineligible expression (and so on, if one of them is not eligible either).
   But before continuing recursively on an ineligible term, it makes sure that it has the right
   to do so by checking if `can_contain_computations` evaluates to true.

   This is used by the CSE pass, which will first attempt to introduce large computations into new
   variables, and only when that's not possible (either because the computation uses some variables
   not yet within scope, or because it is not computed often enough to be a good candidate), it
   will consider its direct subexpressions. That avoids computing all the subexpressions at once,
   and instead evaluates them lazily, if and when needed.
*/
586 | |
587 | /*! |
588 | * \brief Toplevel (static) function that returns the direct subexpressions of a given expression |
589 | * \param expr The expression to analyze. |
590 | * \param is_eligible_computation The predicate which decides if an expression is eligible for |
591 | being introduced in a new variable |
592 | * \param can_contain_computations The predicate which decides if an expression can contain an |
593 | eligible computation |
594 | * \return A vector of PrimExpr containing the direct subexpressions of `expr` |
595 | */ |
596 | std::vector<PrimExpr> DirectSubexpr::GetDirectSubexpressions( |
597 | const PrimExpr& expr, std::function<bool(const PrimExpr&)> is_eligible_computation, |
598 | std::function<bool(const PrimExpr&)> can_contain_computations) { |
599 | DirectSubexpr direct_subexpr(is_eligible_computation, can_contain_computations); |
600 | direct_subexpr.VisitExpr(expr); |
601 | |
602 | return direct_subexpr.direct_subexpr_; |
603 | } |
604 | |
605 | /*! |
606 | * \brief Protected constructor of DirectSubexpr. |
607 | */ |
608 | DirectSubexpr::DirectSubexpr(std::function<bool(const PrimExpr&)> is_eligible_computation, |
609 | std::function<bool(const PrimExpr&)> can_contain_computations) |
610 | : is_eligible_computation_(is_eligible_computation), |
611 | can_contain_computations_(can_contain_computations) {} |
612 | |
613 | /*! |
614 | * \brief The method which overrides the generic dispatcher of ExprVisitor |
615 | * \param expr The expression to visit |
616 | */ |
617 | void DirectSubexpr::VisitExpr(const PrimExpr& expr) { |
618 | // If we have already entered (meaning that we are not dealing with the original expression) |
619 | if (entered_) { |
620 | if (is_eligible_computation_(expr)) { |
621 | direct_subexpr_.push_back(expr); |
622 | return; |
623 | } else { |
624 | if (can_contain_computations_(expr)) { |
625 | ExprVisitor::VisitExpr(expr); |
626 | } |
627 | return; |
628 | } |
629 | } |
630 | |
631 | // If we reach this point, it means that we haven't visited any child node yet, and will need |
632 | // to dive into the expression, if it is allowed to contain eligible computations |
633 | if (can_contain_computations_(expr)) { |
634 | // Take note that now we have already visited some node |
635 | entered_ = true; |
636 | ExprVisitor::VisitExpr(expr); |
637 | } |
638 | } |
639 | |
640 | /* ************************************ Class UsesVarName ************************************* |
641 | *********************************************************************************************** */ |
642 | |
643 | /*! |
644 | * \brief Toplevel (static) function that tells if a given expression uses a given variable name. |
645 | * \param expr The expression to analyze |
646 | * \param var_name The variable name to check for |
647 | * \return A boolean telling if `expr` uses `var_name` |
648 | */ |
649 | bool UsesVarName::ExprUsesVarName(const PrimExpr& expr, String var_name) { |
650 | UsesVarName uses_var_name(var_name); |
651 | uses_var_name.VisitExpr(expr); |
652 | |
653 | return uses_var_name.uses_var_name_; |
654 | } |
655 | |
656 | /*! |
657 | * \brief Toplevel (static) function that tells if a given statement uses a given variable name. |
658 | * \param stmt The statement to analyze |
659 | * \param var_name The variable name to check for |
660 | * \return A boolean telling if `stmt` uses `var_name` |
661 | */ |
662 | bool UsesVarName::StmtUsesVarName(const Stmt& stmt, String var_name) { |
663 | UsesVarName uses_var_name(var_name); |
664 | uses_var_name.VisitStmt(stmt); |
665 | |
666 | return uses_var_name.uses_var_name_; |
667 | } |
668 | |
669 | /*! |
670 | * \brief Protected constructor of UsesVarName. |
671 | * \param var_name The String that we are looking for |
672 | */ |
673 | UsesVarName::UsesVarName(String var_name) : var_name_(var_name) {} |
674 | |
675 | /*! |
676 | * \brief The method which overrides the generic dispatcher of StmtExprVisitor for expressions. |
677 | * \param expr The expression to visit |
678 | */ |
679 | void UsesVarName::VisitExpr(const PrimExpr& expr) { |
680 | if (auto var_node = expr.as<VarNode>()) { |
681 | if (var_node->name_hint == var_name_) { |
682 | uses_var_name_ = true; |
683 | return; |
684 | } |
685 | } |
686 | StmtExprVisitor::VisitExpr(expr); |
687 | } |
688 | |
689 | /*! |
690 | * \brief The method which overrides the generic dispatcher of StmtExprVisitor for statements. |
691 | * \param stmt The statement to visit |
692 | */ |
693 | void UsesVarName::VisitStmt(const Stmt& stmt) { |
694 | // We keep exploring only if `uses_var_name_` is false |
695 | if (!uses_var_name_) { |
696 | // and in order to do that we call the general dispatcher |
697 | StmtExprVisitor::VisitStmt(stmt); |
698 | } |
699 | // As otherwise we already have our answer |
700 | } |
701 | |
702 | /* ********************************** Utility functions for CSE ********************************* |
703 | *********************************************************************************************** */ |
704 | |
705 | /*! |
706 | * \brief Print a table of computation. |
707 | */ |
708 | void PrintComputationTable(const ComputationTable& table) { |
709 | std::cout << "{" << std::endl; |
710 | for (const auto& current : table) { |
711 | std::cout << "(" << current.first << ", " << current.second << ")" << std::endl; |
712 | } |
713 | std::cout << "}" << std::endl; |
714 | } |
715 | |
716 | /*! |
717 | * \brief Decides if two terms are equal syntactically |
718 | */ |
719 | bool EqualTerms(const PrimExpr& a, const PrimExpr& b) { |
720 | ExprDeepEqual deep_equal_; |
721 | return deep_equal_(a, b); |
722 | } |
723 | |
724 | /*! |
725 | * \brief Normalization function of a term, use to decide the equivalence relation of interest |
726 | * \param expr The expression to normalize |
727 | * \param do_normalization Whether we want the function to actually do normalization |
728 | * \note This function can be customized |
729 | */ |
730 | PrimExpr NormalizeTerm(const PrimExpr& expr, bool do_normalization) { |
731 | if (do_normalization) { |
732 | // Customize here! |
733 | // We could decide to normalize terms in a way that identifies them modulo commutativity |
734 | // (like x+y and y+x), or modulo associativity (like (x+y)+z and x+(y+z)), etc. |
735 | // For that, a normalization procedure (or an incomplete "pseudo-normalization" like |
736 | // arith::Analyzer::Simplify) will be used. |
737 | |
738 | // One possible customization: |
739 | // Here is just an attempt to do more commonings by using the pseudo-normalization function |
740 | // offered by arith::Analyzer::Simplify(). "pseudo" because while it is correct (i.e. |
741 | // the simplification is indeed equivalent to the original term), it is incomplete (i.e. |
742 | // the returned term is not guaranteed to be a normal form). |
743 | arith::Analyzer analyzer; |
744 | return analyzer.Simplify(expr); |
745 | } else { |
746 | // If `do_normalization` is false, the equivalence relation just checks the syntactic equality, |
747 | // so the normalization is just the identity function. |
748 | return expr; |
749 | } |
750 | } |
751 | |
752 | /*! |
753 | * \brief Decides if two terms are equivalent semantically |
754 | */ |
755 | bool EquivalentTerms(const PrimExpr& a, const PrimExpr& b, bool identify_equiv_terms) { |
756 | // We restrict the equivalence to be decidable by a normalization procedure that is used to |
757 | // normalize both sides, and to then compare the normal forms with the strict syntactical |
758 | // equality |
759 | return EqualTerms(NormalizeTerm(a, identify_equiv_terms), NormalizeTerm(b, identify_equiv_terms)); |
760 | } |
761 | |
762 | /*! |
763 | * \brief Transforms a hashtable of syntactic computations into a vector or pairs |
764 | (expression, counter) where equivalent computations are merged and their counters added. |
765 | This function simply looks for semantically equivalent terms in order to get the real |
766 | total number of times a computation (and semantically equivalent ones) is seen. |
767 | * \param table The table to transform |
768 | \note This function is needed because the advantage of the hashtable was the constant lookup. |
769 | But in order to have this constant lookup, we could not collapse semantically equivalent |
770 | computations. |
771 | Attention, the pairs returned are deterministic and will always be the same (as the same |
772 | canonical representant will always be chosen for a given class of equivalence), but the |
773 | order in which these pairs appear in the result is not deterministic, as it is based on |
774 | the order in which we found items in the "normalized hashtable" `norm_table`). The caller |
775 | is expected to sort the result anyway. |
776 | */ |
777 | std::vector<std::pair<PrimExpr, size_t>> SyntacticToSemanticComputations( |
778 | const ComputationTable& table, bool identify_equiv_terms) { |
779 | std::vector<std::pair<PrimExpr, size_t>> result; |
780 | |
781 | // If we do NOT identify equivalent terms, then we simply need to transform the input hashtable |
782 | // into a vector, without doing anything else. |
783 | if (!identify_equiv_terms) { |
784 | // The result will contain exactly as many elements as the input `table` has |
785 | result.reserve(table.size()); |
786 | for (const auto& elem : table) { |
787 | result.push_back(elem); |
788 | } |
789 | |
790 | return result; |
791 | } |
792 | |
793 | // Otherwise, in order to identify equivalent terms, we will go through a table `norm_table` |
794 | // where normal forms are the keys., and use it to efficiently merge equivalent terms. |
795 | |
796 | // In order to produce the result (a vector of semantical entities), the input table will be |
797 | // normalized. This normalized table will keep the count for each set of equivalent terms |
798 | // (i.e. each equivalence class), together with a term that did appear in this equivalence class |
799 | // (in practice, the first term of the equivalence class that was encoutered). |
800 | std::unordered_map<PrimExpr, std::pair<PrimExpr, size_t>, StructuralHash, ExprDeepEqual> |
801 | norm_table; |
802 | |
803 | // In order to avoid frequent rehashing if the norm_table becomes big, we immediately ask for |
804 | // enough space to store the amount of elements that the input table has, as it's clearly an |
805 | // upper bound (in the worst case, each element is its own representant, and there is as many |
806 | // equivalence classes as there are elements) |
807 | norm_table.reserve(table.size()); |
808 | |
809 | // Transform the input hashtable to a vector and sort it according to some order, as we will be |
810 | // iterating through its items soon, and the order of appearance will be used to determine the |
811 | // individual representant for each class of equivalence, which we want to be deterministic |
812 | // (otherwise {x+y, y+x} could be both replaced by x+y, and on another run by y+x). |
813 | std::vector<std::pair<PrimExpr, size_t>> sorted_items_of_table(table.begin(), table.end()); |
814 | |
815 | // We do the ordering by comparing the string repr of each expr to get a determinstic ordering |
816 | sort(sorted_items_of_table.begin(), sorted_items_of_table.end(), |
817 | [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) { |
818 | std::stringstream a_stream; |
819 | std::stringstream b_stream; |
820 | a_stream << AsLegacyRepr(a.first); |
821 | b_stream << AsLegacyRepr(b.first); |
822 | return a_stream.str().compare(b_stream.str()) < 0; |
823 | }); |
824 | |
825 | for (const auto& elem : sorted_items_of_table) { |
826 | PrimExpr norm_elem = NormalizeTerm(elem.first, identify_equiv_terms); |
827 | // If the normalized term is not already a key in the normalized table |
828 | auto it_found = norm_table.find(norm_elem); |
829 | if (it_found == norm_table.end()) { |
830 | // Then we add the mapping `norm_elem` -> (`elem`.first, `elem`.second) to the norm table |
831 | // (i.e. `norm_elem` has been seen `elem`.second many times so far, and the chosen element |
832 | // to represent the equivalence class will be `elem`.first as it's the first element of the |
833 | // class that we see) |
834 | norm_table[norm_elem] = elem; |
835 | } else { |
836 | // Otherwise, it's not the first time we see a term in this equivalence class, so we just |
837 | // increase the count of this equivalence class as we now have `elem`.second additional items |
838 | // coming to the equivalence class. |
839 | it_found->second.second += elem.second; |
840 | } |
841 | } |
842 | |
843 | // norm_table.size() is the number of equivalence class that we have built, so it's exactly the |
844 | // number of items that we will return in the vector of semantical entities |
845 | result.reserve(norm_table.size()); |
846 | |
847 | // Transform the intermediate hashtable `norm_table` into a vector, forgetting the keys, |
848 | // (which are the normal forms), as they won't be used as the canonical representants (which are |
849 | // instead the first element of each class that is effectively seen) |
850 | // Careful : the pairs will never change (the canonical represantants chosen will always be the |
851 | // same), but the order in which the pairs are produced can vary as we are iterating through the |
852 | // hashtable `norm_table`. It is not an issue as the called will be sorting the result anyway. |
853 | std::unordered_map<PrimExpr, std::pair<PrimExpr, size_t>, StructuralHash, |
854 | ExprDeepEqual>::const_iterator it_norm_table; |
855 | for (it_norm_table = norm_table.begin(); it_norm_table != norm_table.end(); ++it_norm_table) { |
856 | result.push_back(it_norm_table->second); |
857 | } |
858 | |
859 | return result; |
860 | } |
861 | |
862 | /*! |
863 | * \brief Predicate that decides if a computation, that is seen `nb_times_seen`, should be |
864 | introduced in a variable or not. |
865 | */ |
866 | bool PredicateIntroVarForComputation(const PrimExpr& computation, size_t nb_times_seen) { |
867 | // This predicate could later implement something more fine grained that would take in account |
868 | // the size of the expression too, as for instance a very large computation could be introduced |
869 | // as soon as two occurrences are seen, but a smaller one could need three or more occurrences |
870 | // for being introduced in a variable. |
871 | |
872 | // But for now, we factorize any eligible item that we see at least twice, regardless of its size |
873 | return nb_times_seen >= 2; |
874 | } |
875 | |
876 | /*! |
877 | * \brief Inserts a pair (expr,nb) to a sorted vector of such pairs (which is sorted by decreasing |
878 | size of expressions) and maintain the vector sorted while doing so. |
879 | */ |
880 | void InsertElemToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec, |
881 | const std::pair<PrimExpr, size_t>& pair) { |
882 | if (sorted_vec == nullptr) { |
883 | return; |
884 | } |
885 | // Find the insertion point using std::lower_bound on a comparison that uses |
886 | // CalculateExprComplexity(), which computes the "size" of an expr. |
887 | // std::lower_boud returns an iterator pointing to the first element on which the comparison |
888 | // does not return true with the given value (`pair` here), i.e, an iterator pointing to the |
889 | // first element that is not greater or equal than `pair`, i.e, the first element that is |
890 | // strictly smaller than `pair`. |
891 | auto insertion_point = std::lower_bound( |
892 | sorted_vec->begin(), sorted_vec->end(), pair, |
893 | [](const std::pair<PrimExpr, size_t>& left, const std::pair<PrimExpr, size_t>& right) { |
894 | return (CalculateExprComplexity(left.first) >= CalculateExprComplexity(right.first)); |
895 | }); |
896 | sorted_vec->insert(insertion_point, pair); |
897 | } |
898 | |
899 | /*! |
900 | * \brief Inserts a vector of expressions into a sorted vector of computations (which is sorted by |
901 | decreasing size of the expression) and maintain the vector sorted while doing so. |
902 | */ |
903 | void InsertVectorToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec, |
904 | const std::vector<PrimExpr>& vec_to_add, |
905 | bool identify_equiv_terms, size_t increase_count) { |
906 | if (sorted_vec == nullptr) { |
907 | return; |
908 | } |
909 | for (auto elem_to_add : vec_to_add) { |
910 | // See if the current element to add (or an equivalent one) is already present |
911 | // in the sorted vector |
912 | auto it_found = |
913 | std::find_if(sorted_vec->begin(), sorted_vec->end(), |
914 | [elem_to_add, identify_equiv_terms](std::pair<PrimExpr, size_t> elem) { |
915 | return EquivalentTerms(elem.first, elem_to_add, identify_equiv_terms); |
916 | }); |
917 | |
918 | // If we found `elem_to_add` (or an equivalent expression) already in sorted_vec |
919 | if (it_found != sorted_vec->end()) { |
920 | // then we just increase its associated count |
921 | it_found->second += increase_count; |
922 | } else { |
923 | // Otherwise we add the pair (`elem_to_add`,1) at the right place |
924 | InsertElemToSortedSemanticComputations(sorted_vec, {elem_to_add, increase_count}); |
925 | } |
926 | } |
927 | } |
928 | |
929 | } // namespace tir |
930 | } // namespace tvm |
931 | |