1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file forward_rewrite.cc
23 * \brief Apply rewriting rules in a forward fashion.
24 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <functional>
#include <unordered_map>
#include <utility>

#include "pass_utils.h"
31
32namespace tvm {
33namespace relay {
34
35// Realizer class that realizes the expression
36// Note that we can take benefit of its internal memo
37// so that calling realize repeatively won't hurt perf.
38class TempRealizer : private MixedModeMutator {
39 public:
40 Expr Realize(Expr expr) { return Mutate(expr); }
41
42 private:
43 Expr DispatchVisitExpr(const Expr& expr) final {
44 Expr res;
45 if (const auto* temp = expr.as<TempExprNode>()) {
46 res = temp->Realize();
47 } else {
48 res = MixedModeMutator::DispatchVisitExpr(expr);
49 }
50 return res;
51 }
52};
53
54class ForwardRewriter : private MixedModeMutator {
55 public:
56 ForwardRewriter(const OpAttrMap<FForwardRewrite>* rewrite_map,
57 std::function<ObjectRef(const Call&)> fcontext,
58 std::function<Expr(const Expr&)> fmulti_ref_trigger)
59 : rewrite_map_(rewrite_map), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {}
60
61 ForwardRewriter(const FForwardRewrite* rewrite_func,
62 std::function<ObjectRef(const Call&)> fcontext,
63 std::function<Expr(const Expr&)> fmulti_ref_trigger)
64 : rewrite_func_(rewrite_func), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {}
65
66 // Transform expression.
67 Expr Rewrite(const Expr& expr) {
68 if (fmulti_ref_trigger_ != nullptr) {
69 ref_counter_ = GetExprRefCount(expr);
70 }
71 return realizer_.Realize(this->VisitExpr(expr));
72 }
73
74 private:
75 // The rewrite rule.
76 const OpAttrMap<FForwardRewrite>* rewrite_map_{nullptr};
77 const FForwardRewrite* rewrite_func_{nullptr};
78 // The context.const
79 std::function<ObjectRef(const Call&)> fcontext_{nullptr};
80 // The multiple reference trigger
81 std::function<Expr(const Expr&)> fmulti_ref_trigger_{nullptr};
82 // Internal ref counter
83 std::unordered_map<const Object*, size_t> ref_counter_;
84 // internal realizer
85 TempRealizer realizer_;
86
87 // Visit and allow non-realized version.
88 Expr GetTempExpr(const Expr& expr, const Expr& post) {
89 if (fmulti_ref_trigger_ != nullptr) {
90 Expr ret = post;
91 auto it = ref_counter_.find(expr.get());
92 ICHECK(it != ref_counter_.end());
93 if (it->second > 1) {
94 ret = fmulti_ref_trigger_(ret);
95 }
96 return ret;
97 } else {
98 return post;
99 }
100 }
101
102 // Automatic fold TupleGetItem.
103 Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
104 Expr tuple = this->GetTempExpr(op->tuple, post.as<TupleGetItemNode>()->tuple);
105 if (const auto* ptuple = tuple.as<TupleNode>()) {
106 return ptuple->fields[op->index];
107 } else {
108 if (tuple.same_as(op->tuple)) {
109 return GetRef<Expr>(op);
110 } else {
111 return TupleGetItem(tuple, op->index);
112 }
113 }
114 }
115
116 Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) final {
117 tvm::Array<Expr> fields;
118 fields.reserve(tuple_node->fields.size());
119
120 const auto* post_tuple_node = post.as<TupleNode>();
121 for (size_t i = 0; i < tuple_node->fields.size(); ++i) {
122 fields.push_back(this->GetTempExpr(tuple_node->fields[i], post_tuple_node->fields[i]));
123 }
124
125 return WithFields(GetRef<Tuple>(tuple_node), fields);
126 }
127
128 Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
129 const Call& ref_call = GetRef<Call>(call_node);
130 PackedFunc frewrite;
131 if (rewrite_func_) {
132 frewrite = *rewrite_func_;
133 } else {
134 ICHECK(rewrite_map_);
135 frewrite = rewrite_map_->get(call_node->op, nullptr);
136 }
137 const auto* post_node = post.as<CallNode>();
138 auto new_op = post_node->op;
139 if (new_op->IsInstance<FunctionNode>()) {
140 new_op = realizer_.Realize(new_op);
141 }
142 bool unchanged = call_node->op.same_as(new_op);
143
144 Array<Expr> call_args;
145 for (size_t i = 0; i < call_node->args.size(); ++i) {
146 Expr new_arg = this->GetTempExpr(call_node->args[i], post_node->args[i]);
147 if (frewrite == nullptr) {
148 new_arg = realizer_.Realize(new_arg);
149 }
150 unchanged &= new_arg.same_as(call_node->args[i]);
151 call_args.push_back(new_arg);
152 }
153 // try to rewrite.
154 if (frewrite != nullptr) {
155 Expr res = frewrite(ref_call, call_args,
156 fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr));
157 if (res.defined()) return res;
158 // abort, use old rule
159 for (size_t i = 0; i < call_args.size(); ++i) {
160 Expr arg = call_args[i];
161 Expr new_arg = realizer_.Realize(arg);
162 if (!arg.same_as(new_arg)) {
163 call_args.Set(i, new_arg);
164 unchanged = false;
165 }
166 }
167 }
168 if (unchanged) return ref_call;
169 return Call(new_op, call_args, call_node->attrs, call_node->type_args, call_node->span);
170 }
171};
172
173Expr ForwardRewrite(const Expr& expr, const String& rewrite_map_name,
174 std::function<ObjectRef(const Call&)> fcontext,
175 std::function<Expr(const Expr&)> fmulti_ref_trigger) {
176 auto rewrite_map = Op::GetAttrMap<FForwardRewrite>(rewrite_map_name);
177 return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr);
178}
179
180Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func,
181 std::function<ObjectRef(const Call&)> fcontext,
182 std::function<Expr(const Expr&)> fmulti_ref_trigger) {
183 return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr);
184}
185
186} // namespace relay
187} // namespace tvm
188