1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file replace_selected_expr.h
 * \brief Interface of the pass that replaces, within a statement or an expression, every
 *        subexpression selected by a predicate with another expression.
 */
26 | |
27 | #ifndef TVM_TIR_TRANSFORMS_REPLACE_SELECTED_EXPR_H_ |
28 | #define TVM_TIR_TRANSFORMS_REPLACE_SELECTED_EXPR_H_ |
29 | |
#include <functional>

#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>  // For the class StmtExprMutator
34 | |
35 | namespace tvm { |
36 | namespace tir { |
37 | |
38 | /*! |
39 | * \brief Mutator for replacing the expressions selected by a predicate in a statement and/or |
40 | in an expression, which only replace inside of nodes in which it is allowed to perform |
41 | replacecements (given by a second predicate) |
42 | */ |
43 | class ReplaceSelectedExpr : public StmtExprMutator { |
44 | public: |
45 | // Toplevel (static) functions |
46 | static PrimExpr ReplaceSelectedExprInExpr( |
47 | const PrimExpr& expr, std::function<bool(const PrimExpr&)> predicate_selector, |
48 | const PrimExpr& new_expr, std::function<bool(const PrimExpr&)> can_replace_inside); |
49 | static Stmt ReplaceSelectedExprInStmt(const Stmt& stmt, |
50 | std::function<bool(const PrimExpr&)> predicate_selector, |
51 | const PrimExpr& new_expr, |
52 | std::function<bool(const PrimExpr&)> can_replace_inside); |
53 | |
54 | protected: |
55 | // Constructor |
56 | ReplaceSelectedExpr(std::function<bool(const PrimExpr&)> predicate_selector, |
57 | const PrimExpr& new_expr, |
58 | std::function<bool(const PrimExpr&)> can_replace_inside); |
59 | |
60 | PrimExpr VisitExpr(const PrimExpr& expr) override; |
61 | |
62 | private: |
63 | // The predicate used for selecting what will be replaced |
64 | std::function<bool(const PrimExpr&)> predicate_selector_; |
65 | // The expression used for replacing |
66 | const PrimExpr& new_expr_; |
67 | // The predicate used for knowning inside which nodes we can do rewriting |
68 | // (i.e. in which nodes it can recurse) |
69 | std::function<bool(const PrimExpr&)> can_replace_inside_; |
70 | }; |
71 | |
72 | } // namespace tir |
73 | } // namespace tvm |
74 | |
75 | #endif // TVM_TIR_TRANSFORMS_REPLACE_SELECTED_EXPR_H_ |
76 | |