1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file replace_selected_expr.cc
 * \brief Implementation of the pass that replaces, in a statement or an expression, all the
 *        subexpressions that are selected with a predicate by another expression.
 */
26
#include "replace_selected_expr.h"

#include <functional>  // For std::function
#include <utility>     // For std::move

#include <tvm/ir/transform.h>  // For the class Pass and the class PassContext
#include <tvm/tir/expr.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/function.h>  // For the class PrimFunc
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>  // For the declaration of the pass
36
37namespace tvm {
38namespace tir {
39
40/*!
41 * \brief Toplevel (static) function that replace in an expression
42 everything that is selected by a predicate.
43 * \param expr The PrimExpr in which replacements will be performed
44 * \param new_expr The new expression replacing everything that's selected by the predicate
45 * \param predicate_selector The predicate which tells what to replace in `expr`
46 * \param can_replace_inside The predicate which tells in which nodes we are allowed to recurse
47 for pursuing further replacements.
48 * \return A new expression where the replacements have been done
49 */
50PrimExpr ReplaceSelectedExpr::ReplaceSelectedExprInExpr(
51 const PrimExpr& expr, std::function<bool(const PrimExpr&)> predicate_selector,
52 const PrimExpr& new_expr, std::function<bool(const PrimExpr&)> can_replace_inside) {
53 ReplaceSelectedExpr replace_expr_selected(predicate_selector, new_expr, can_replace_inside);
54 return replace_expr_selected.VisitExpr(expr);
55}
56
57/*!
58 * \brief Toplevel (static) function that replace in a statement what is selected by a predicate.
59 * \param stmt The Stmt in which replacements will be performed
60 * \param new_expr The new expression that will replace everything that's selected by the predicate
61 * \param predicate_selector The predicate which tells what to replace in `stmt`
62 * \param can_replace_inside The predicate which tells in which nodes we are allowed to recurse
63 for pursuing further replacements
64 * \return A new statement where the replacements have been done
65 */
66Stmt ReplaceSelectedExpr::ReplaceSelectedExprInStmt(
67 const Stmt& stmt, std::function<bool(const PrimExpr&)> predicate_selector,
68 const PrimExpr& new_expr, std::function<bool(const PrimExpr&)> can_replace_inside) {
69 ReplaceSelectedExpr replace_expr_selected(predicate_selector, new_expr, can_replace_inside);
70 return replace_expr_selected.VisitStmt(stmt);
71}
72
73/*!
74 * \brief Protected constructor of ReplaceSelectedExpr.
75 * \param predicate_selector The predicate which tells what to replace
76 * \param new_expr The new expression that will replace everything that's selected by the predicate
77 * \param can_replace_inside The predicate which tells in which nodes we are allowed to recurse
78 for pursuing further replacements
79 */
80ReplaceSelectedExpr::ReplaceSelectedExpr(std::function<bool(const PrimExpr&)> predicate_selector,
81 const PrimExpr& new_expr,
82 std::function<bool(const PrimExpr&)> can_replace_inside)
83 : predicate_selector_(predicate_selector),
84 new_expr_(new_expr),
85 can_replace_inside_(can_replace_inside) {}
86
87/*!
88 * \brief The method which overrides the generic dispatcher of StmtExprMutator
89 * \param expr The expression to mutate
90 */
91PrimExpr ReplaceSelectedExpr::VisitExpr(const PrimExpr& expr) {
92 // If the current expression is selected by the predicate
93 if (predicate_selector_(expr)) {
94 // Then simply return the new expression
95 return new_expr_;
96 } else {
97 // If replacing inside the current expression is allowed
98 if (can_replace_inside_(expr)) {
99 // then we continue the exploration recursively
100 return StmtExprMutator::VisitExpr(expr);
101 } else {
102 // otherwise we simply return the current expression
103 return expr;
104 }
105 }
106}
107
108} // namespace tir
109} // namespace tvm
110