1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file eliminate_common_subexpr.cc |
23 | * \brief Combine common subexpressions. |
24 | * |
 * This is an optimization pass that eliminates common subexpressions. During the pass, it replaces
 * an expression with an equivalent expression that was already seen earlier: a call to the same
 * operator with the same inputs and attributes, or a tuple projection of the same tuple at the same
 * index. The optional fskip callback can be used to exclude specific expressions from elimination.
28 | */ |
29 | #include <tvm/relay/analysis.h> |
30 | #include <tvm/relay/expr_functor.h> |
31 | #include <tvm/relay/transform.h> |
32 | |
33 | #include <unordered_map> |
34 | |
35 | #include "pattern_utils.h" |
36 | |
37 | namespace tvm { |
38 | namespace relay { |
39 | |
40 | class CommonSubexprEliminator : public MixedModeMutator { |
41 | public: |
42 | explicit CommonSubexprEliminator(runtime::TypedPackedFunc<bool(Expr)> fskip) : fskip_(fskip) {} |
43 | |
44 | Expr Rewrite_(const CallNode* call, const Expr& post) final { |
45 | static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful" ); |
46 | Expr new_expr = post; |
47 | const CallNode* new_call = new_expr.as<CallNode>(); |
48 | ICHECK(new_call); |
49 | const OpNode* op = new_call->op.as<OpNode>(); |
50 | StructuralEqual attrs_equal; |
51 | |
52 | if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) { |
53 | return new_expr; |
54 | } |
55 | if (fskip_ != nullptr && fskip_(new_expr)) { |
56 | return new_expr; |
57 | } |
58 | |
59 | auto it = expr_map_.find(new_call->op); |
60 | if (it != expr_map_.end()) { |
61 | for (const Expr& candidate_expr : it->second) { |
62 | if (const CallNode* candidate = candidate_expr.as<CallNode>()) { |
63 | bool is_equivalent = true; |
64 | if (!attrs_equal(new_call->attrs, candidate->attrs)) { |
65 | continue; |
66 | } |
67 | for (size_t i = 0; i < new_call->args.size(); i++) { |
68 | if (!new_call->args[i].same_as(candidate->args[i]) && |
69 | !IsEqualScalar(new_call->args[i], candidate->args[i])) { |
70 | is_equivalent = false; |
71 | break; |
72 | } |
73 | } |
74 | if (!is_equivalent) continue; |
75 | return GetRef<Call>(candidate); |
76 | } |
77 | } |
78 | } |
79 | expr_map_[new_call->op].push_back(new_expr); |
80 | return new_expr; |
81 | } |
82 | |
83 | Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final { |
84 | Expr new_expr = post; |
85 | const TupleGetItemNode* new_tuple_item = new_expr.as<TupleGetItemNode>(); |
86 | ICHECK(new_tuple_item); |
87 | |
88 | if (fskip_ != nullptr && fskip_(new_expr)) { |
89 | return new_expr; |
90 | } |
91 | |
92 | auto it = expr_map_.find(new_tuple_item->tuple); |
93 | if (it != expr_map_.end()) { |
94 | for (const Expr& candidate_expr : it->second) { |
95 | if (const TupleGetItemNode* candidate = candidate_expr.as<TupleGetItemNode>()) { |
96 | if (new_tuple_item->index == candidate->index) { |
97 | return GetRef<Expr>(candidate); |
98 | } |
99 | } |
100 | } |
101 | } |
102 | expr_map_[new_tuple_item->tuple].push_back(new_expr); |
103 | return new_expr; |
104 | } |
105 | |
106 | std::unordered_map<Expr, std::vector<Expr>, ObjectPtrHash, ObjectPtrEqual> expr_map_; |
107 | runtime::TypedPackedFunc<bool(Expr)> fskip_; |
108 | }; |
109 | |
110 | Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) { |
111 | return CommonSubexprEliminator(callback)(expr); |
112 | } |
113 | |
114 | namespace transform { |
115 | |
116 | Pass EliminateCommonSubexpr(PackedFunc fskip) { |
117 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
118 | [=](Function f, IRModule m, PassContext pc) { |
119 | return Downcast<Function>(EliminateCommonSubexpr(f, fskip)); |
120 | }; |
121 | return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr" , {"InferType" }); |
122 | } |
123 | |
124 | TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr" ) |
125 | .set_body_typed(EliminateCommonSubexpr); |
126 | |
127 | } // namespace transform |
128 | |
129 | } // namespace relay |
130 | } // namespace tvm |
131 | |