1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file common_subexpr_elim.h
 * \brief Interface of the Common Subexpression Elimination (CSE) pass, which rewrites statements
 *        and expressions in order to eliminate redundant computations. To achieve that,
 *        common (sub-)expressions are bound to variables with let-in bindings, and the
 *        places where the expression was used are replaced with the freshly introduced variable.
 */
27
28#ifndef TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_H_
29#define TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_H_
30
31#include <tvm/tir/expr.h>
32#include <tvm/tir/expr_functor.h>
33#include <tvm/tir/stmt.h>
34#include <tvm/tir/stmt_functor.h> // For the class StmtExprMutator
35#include <tvm/tir/var.h>
36
37#include <utility> // For std::pair
38#include <vector>
39
40#include "common_subexpr_elim_tools.h" // For the class MaybeValue
41
42namespace tvm {
43namespace tir {
44
/*!
 * \brief A context is a vector of pairs, each associating a Var with a MaybeValue
 *        (which is either an expression or nothing).
 *
 * NOTE(review): MaybeValue comes from common_subexpr_elim_tools.h, which is not visible
 * from this header alone -- confirm its exact definition there.
 */
using Context = std::vector<std::pair<Var, MaybeValue>>;
50
51/*!
52 * \brief Mutator that performs Common Subexpression Elimination (CSE) for the body of a
53 PrimFunc, mutating both its expressions and statements.
54 */
55class CommonSubexpressionEliminator : public StmtExprMutator {
56 public:
57 // Toplevel (static) function
58 static Stmt PerformCSE(const Stmt& stmt, const Context& context_init, bool identify_equiv_terms);
59
60 PrimExpr VisitExpr(const PrimExpr& expr) override;
61 Stmt VisitStmt(const Stmt& stmt) override;
62
63 int GetNbVarGenerated();
64
65 protected:
66 // Constructor
67 CommonSubexpressionEliminator(const Stmt& stmt, const Context& context_init,
68 bool identify_equiv_terms);
69
70 PrimExpr VisitExpr_(const LetNode* op) override;
71
72 Stmt VisitStmt_(const LetStmtNode* op) override;
73 Stmt VisitStmt_(const ForNode* op) override;
74
75 private:
76 Stmt initial_body_; // Kept for checking if names of new variables already exist
77 Context context_; // Context associating variables to (maybe) definitions
78 int num_last_try_ = 0; // Number of the last variable tried
79 int nb_var_ = 0; // Number of variables introduced by the CSE pass
80
81 bool identify_equiv_terms_ = false;
82
83 static bool ForbiddenComputation(const PrimExpr& expr);
84 static bool IsEligibleComputation(const PrimExpr& expr);
85 static bool CanContainEligibleComputations(const PrimExpr& expr);
86 static bool OrderOnExprAndFrequency(std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b);
87 Var GenerateNewVar(DataType type_annotation);
88};
89
90} // namespace tir
91} // namespace tvm
92
93#endif // TVM_TIR_TRANSFORMS_COMMON_SUBEXPR_ELIM_H_
94