1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file forward_rewrite.cc |
23 | * \brief Apply rewriting rules in a forward fashion. |
24 | */ |
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <functional>
#include <unordered_map>
#include <utility>

#include "pass_utils.h"
31 | |
32 | namespace tvm { |
33 | namespace relay { |
34 | |
35 | // Realizer class that realizes the expression |
36 | // Note that we can take benefit of its internal memo |
37 | // so that calling realize repeatively won't hurt perf. |
38 | class TempRealizer : private MixedModeMutator { |
39 | public: |
40 | Expr Realize(Expr expr) { return Mutate(expr); } |
41 | |
42 | private: |
43 | Expr DispatchVisitExpr(const Expr& expr) final { |
44 | Expr res; |
45 | if (const auto* temp = expr.as<TempExprNode>()) { |
46 | res = temp->Realize(); |
47 | } else { |
48 | res = MixedModeMutator::DispatchVisitExpr(expr); |
49 | } |
50 | return res; |
51 | } |
52 | }; |
53 | |
54 | class ForwardRewriter : private MixedModeMutator { |
55 | public: |
56 | ForwardRewriter(const OpAttrMap<FForwardRewrite>* rewrite_map, |
57 | std::function<ObjectRef(const Call&)> fcontext, |
58 | std::function<Expr(const Expr&)> fmulti_ref_trigger) |
59 | : rewrite_map_(rewrite_map), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {} |
60 | |
61 | ForwardRewriter(const FForwardRewrite* rewrite_func, |
62 | std::function<ObjectRef(const Call&)> fcontext, |
63 | std::function<Expr(const Expr&)> fmulti_ref_trigger) |
64 | : rewrite_func_(rewrite_func), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {} |
65 | |
66 | // Transform expression. |
67 | Expr Rewrite(const Expr& expr) { |
68 | if (fmulti_ref_trigger_ != nullptr) { |
69 | ref_counter_ = GetExprRefCount(expr); |
70 | } |
71 | return realizer_.Realize(this->VisitExpr(expr)); |
72 | } |
73 | |
74 | private: |
75 | // The rewrite rule. |
76 | const OpAttrMap<FForwardRewrite>* rewrite_map_{nullptr}; |
77 | const FForwardRewrite* rewrite_func_{nullptr}; |
78 | // The context.const |
79 | std::function<ObjectRef(const Call&)> fcontext_{nullptr}; |
80 | // The multiple reference trigger |
81 | std::function<Expr(const Expr&)> fmulti_ref_trigger_{nullptr}; |
82 | // Internal ref counter |
83 | std::unordered_map<const Object*, size_t> ref_counter_; |
84 | // internal realizer |
85 | TempRealizer realizer_; |
86 | |
87 | // Visit and allow non-realized version. |
88 | Expr GetTempExpr(const Expr& expr, const Expr& post) { |
89 | if (fmulti_ref_trigger_ != nullptr) { |
90 | Expr ret = post; |
91 | auto it = ref_counter_.find(expr.get()); |
92 | ICHECK(it != ref_counter_.end()); |
93 | if (it->second > 1) { |
94 | ret = fmulti_ref_trigger_(ret); |
95 | } |
96 | return ret; |
97 | } else { |
98 | return post; |
99 | } |
100 | } |
101 | |
102 | // Automatic fold TupleGetItem. |
103 | Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final { |
104 | Expr tuple = this->GetTempExpr(op->tuple, post.as<TupleGetItemNode>()->tuple); |
105 | if (const auto* ptuple = tuple.as<TupleNode>()) { |
106 | return ptuple->fields[op->index]; |
107 | } else { |
108 | if (tuple.same_as(op->tuple)) { |
109 | return GetRef<Expr>(op); |
110 | } else { |
111 | return TupleGetItem(tuple, op->index); |
112 | } |
113 | } |
114 | } |
115 | |
116 | Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) final { |
117 | tvm::Array<Expr> fields; |
118 | fields.reserve(tuple_node->fields.size()); |
119 | |
120 | const auto* post_tuple_node = post.as<TupleNode>(); |
121 | for (size_t i = 0; i < tuple_node->fields.size(); ++i) { |
122 | fields.push_back(this->GetTempExpr(tuple_node->fields[i], post_tuple_node->fields[i])); |
123 | } |
124 | |
125 | return WithFields(GetRef<Tuple>(tuple_node), fields); |
126 | } |
127 | |
128 | Expr Rewrite_(const CallNode* call_node, const Expr& post) final { |
129 | const Call& ref_call = GetRef<Call>(call_node); |
130 | PackedFunc frewrite; |
131 | if (rewrite_func_) { |
132 | frewrite = *rewrite_func_; |
133 | } else { |
134 | ICHECK(rewrite_map_); |
135 | frewrite = rewrite_map_->get(call_node->op, nullptr); |
136 | } |
137 | const auto* post_node = post.as<CallNode>(); |
138 | auto new_op = post_node->op; |
139 | if (new_op->IsInstance<FunctionNode>()) { |
140 | new_op = realizer_.Realize(new_op); |
141 | } |
142 | bool unchanged = call_node->op.same_as(new_op); |
143 | |
144 | Array<Expr> call_args; |
145 | for (size_t i = 0; i < call_node->args.size(); ++i) { |
146 | Expr new_arg = this->GetTempExpr(call_node->args[i], post_node->args[i]); |
147 | if (frewrite == nullptr) { |
148 | new_arg = realizer_.Realize(new_arg); |
149 | } |
150 | unchanged &= new_arg.same_as(call_node->args[i]); |
151 | call_args.push_back(new_arg); |
152 | } |
153 | // try to rewrite. |
154 | if (frewrite != nullptr) { |
155 | Expr res = frewrite(ref_call, call_args, |
156 | fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr)); |
157 | if (res.defined()) return res; |
158 | // abort, use old rule |
159 | for (size_t i = 0; i < call_args.size(); ++i) { |
160 | Expr arg = call_args[i]; |
161 | Expr new_arg = realizer_.Realize(arg); |
162 | if (!arg.same_as(new_arg)) { |
163 | call_args.Set(i, new_arg); |
164 | unchanged = false; |
165 | } |
166 | } |
167 | } |
168 | if (unchanged) return ref_call; |
169 | return Call(new_op, call_args, call_node->attrs, call_node->type_args, call_node->span); |
170 | } |
171 | }; |
172 | |
173 | Expr ForwardRewrite(const Expr& expr, const String& rewrite_map_name, |
174 | std::function<ObjectRef(const Call&)> fcontext, |
175 | std::function<Expr(const Expr&)> fmulti_ref_trigger) { |
176 | auto rewrite_map = Op::GetAttrMap<FForwardRewrite>(rewrite_map_name); |
177 | return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr); |
178 | } |
179 | |
180 | Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func, |
181 | std::function<ObjectRef(const Call&)> fcontext, |
182 | std::function<Expr(const Expr&)> fmulti_ref_trigger) { |
183 | return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr); |
184 | } |
185 | |
186 | } // namespace relay |
187 | } // namespace tvm |
188 | |