1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file eliminate_common_subexpr.cc
23 * \brief Combine common subexpressions.
24 *
 * This is an optimization pass that eliminates common subexpressions. During the pass, it tries
 * to replace an expression with a previously seen expression that has the same operator, inputs,
 * and attributes. The fskip callback argument allows callers to skip specific expressions.
28 */
29#include <tvm/relay/analysis.h>
30#include <tvm/relay/expr_functor.h>
31#include <tvm/relay/transform.h>
32
33#include <unordered_map>
34
35#include "pattern_utils.h"
36
37namespace tvm {
38namespace relay {
39
40class CommonSubexprEliminator : public MixedModeMutator {
41 public:
42 explicit CommonSubexprEliminator(runtime::TypedPackedFunc<bool(Expr)> fskip) : fskip_(fskip) {}
43
44 Expr Rewrite_(const CallNode* call, const Expr& post) final {
45 static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
46 Expr new_expr = post;
47 const CallNode* new_call = new_expr.as<CallNode>();
48 ICHECK(new_call);
49 const OpNode* op = new_call->op.as<OpNode>();
50 StructuralEqual attrs_equal;
51
52 if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
53 return new_expr;
54 }
55 if (fskip_ != nullptr && fskip_(new_expr)) {
56 return new_expr;
57 }
58
59 auto it = expr_map_.find(new_call->op);
60 if (it != expr_map_.end()) {
61 for (const Expr& candidate_expr : it->second) {
62 if (const CallNode* candidate = candidate_expr.as<CallNode>()) {
63 bool is_equivalent = true;
64 if (!attrs_equal(new_call->attrs, candidate->attrs)) {
65 continue;
66 }
67 for (size_t i = 0; i < new_call->args.size(); i++) {
68 if (!new_call->args[i].same_as(candidate->args[i]) &&
69 !IsEqualScalar(new_call->args[i], candidate->args[i])) {
70 is_equivalent = false;
71 break;
72 }
73 }
74 if (!is_equivalent) continue;
75 return GetRef<Call>(candidate);
76 }
77 }
78 }
79 expr_map_[new_call->op].push_back(new_expr);
80 return new_expr;
81 }
82
83 Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
84 Expr new_expr = post;
85 const TupleGetItemNode* new_tuple_item = new_expr.as<TupleGetItemNode>();
86 ICHECK(new_tuple_item);
87
88 if (fskip_ != nullptr && fskip_(new_expr)) {
89 return new_expr;
90 }
91
92 auto it = expr_map_.find(new_tuple_item->tuple);
93 if (it != expr_map_.end()) {
94 for (const Expr& candidate_expr : it->second) {
95 if (const TupleGetItemNode* candidate = candidate_expr.as<TupleGetItemNode>()) {
96 if (new_tuple_item->index == candidate->index) {
97 return GetRef<Expr>(candidate);
98 }
99 }
100 }
101 }
102 expr_map_[new_tuple_item->tuple].push_back(new_expr);
103 return new_expr;
104 }
105
106 std::unordered_map<Expr, std::vector<Expr>, ObjectPtrHash, ObjectPtrEqual> expr_map_;
107 runtime::TypedPackedFunc<bool(Expr)> fskip_;
108};
109
110Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) {
111 return CommonSubexprEliminator(callback)(expr);
112}
113
114namespace transform {
115
116Pass EliminateCommonSubexpr(PackedFunc fskip) {
117 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
118 [=](Function f, IRModule m, PassContext pc) {
119 return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
120 };
121 return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"});
122}
123
124TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
125 .set_body_typed(EliminateCommonSubexpr);
126
127} // namespace transform
128
129} // namespace relay
130} // namespace tvm
131