1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file common_subexpr_elim.h
 * \brief Interface of the Common Subexpression Elimination (CSE) pass, which rewrites statements
 *        and expressions to eliminate redundant computations. To do so, common (sub-)expressions
 *        are bound to freshly introduced variables with let-in bindings, and every place where
 *        such an expression was used is replaced with the corresponding new variable.
 */
27 | |
28 | #ifndef TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_H_ |
29 | #define TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_H_ |
30 | |
31 | #include <tvm/tir/expr.h> |
32 | #include <tvm/tir/expr_functor.h> |
33 | #include <tvm/tir/stmt.h> |
34 | #include <tvm/tir/stmt_functor.h> // For the class StmtExprMutator |
35 | #include <tvm/tir/var.h> |
36 | |
37 | #include <utility> // For std::pair |
38 | #include <vector> |
39 | |
40 | #include "common_subexpr_elim_tools.h" // For the class MaybeValue |
41 | |
42 | namespace tvm { |
43 | namespace tir { |
44 | |
45 | /*! |
46 | * \brief A context is a vector of pairs that associates Var to MaybeValue |
47 | (which are either an expression or nothing) |
48 | */ |
49 | using Context = std::vector<std::pair<Var, MaybeValue>>; |
50 | |
51 | /*! |
52 | * \brief Mutator that performs Common Subexpression Elimination (CSE) for the body of a |
53 | PrimFunc, mutating both its expressions and statements. |
54 | */ |
55 | class CommonSubexpressionEliminator : public StmtExprMutator { |
56 | public: |
57 | // Toplevel (static) function |
58 | static Stmt PerformCSE(const Stmt& stmt, const Context& context_init, bool identify_equiv_terms); |
59 | |
60 | PrimExpr VisitExpr(const PrimExpr& expr) override; |
61 | Stmt VisitStmt(const Stmt& stmt) override; |
62 | |
63 | int GetNbVarGenerated(); |
64 | |
65 | protected: |
66 | // Constructor |
67 | CommonSubexpressionEliminator(const Stmt& stmt, const Context& context_init, |
68 | bool identify_equiv_terms); |
69 | |
70 | PrimExpr VisitExpr_(const LetNode* op) override; |
71 | |
72 | Stmt VisitStmt_(const LetStmtNode* op) override; |
73 | Stmt VisitStmt_(const ForNode* op) override; |
74 | |
75 | private: |
76 | Stmt initial_body_; // Kept for checking if names of new variables already exist |
77 | Context context_; // Context associating variables to (maybe) definitions |
78 | int num_last_try_ = 0; // Number of the last variable tried |
79 | int nb_var_ = 0; // Number of variables introduced by the CSE pass |
80 | |
81 | bool identify_equiv_terms_ = false; |
82 | |
83 | static bool ForbiddenComputation(const PrimExpr& expr); |
84 | static bool IsEligibleComputation(const PrimExpr& expr); |
85 | static bool CanContainEligibleComputations(const PrimExpr& expr); |
86 | static bool OrderOnExprAndFrequency(std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b); |
87 | Var GenerateNewVar(DataType type_annotation); |
88 | }; |
89 | |
90 | } // namespace tir |
91 | } // namespace tvm |
92 | |
93 | #endif // TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_H_ |
94 | |