1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #include "annotated_region_set.h" |
21 | |
22 | #include <tvm/relay/error.h> |
23 | #include <tvm/relay/expr.h> |
24 | |
25 | #include <unordered_map> |
26 | #include <vector> |
27 | |
28 | namespace tvm { |
29 | namespace relay { |
30 | |
31 | AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const { |
32 | for (auto candidate : regions_) { |
33 | if (candidate->nodes_.find(expr) != candidate->nodes_.end()) { |
34 | return candidate; |
35 | } |
36 | } |
37 | return AnnotatedRegion(nullptr); |
38 | } |
39 | |
40 | void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, AnnotatedRegion dest) { |
41 | if (dest == src) { |
42 | return; |
43 | } |
44 | |
45 | // Merge src to dest and erase src. |
46 | dest->nodes_.insert(src->nodes_.begin(), src->nodes_.end()); |
47 | for (const auto& input : src->ins_) { |
48 | dest->ins_.push_back(input); |
49 | } |
50 | for (const auto& output : src->outs_) { |
51 | dest->outs_.push_back(output); |
52 | } |
53 | // if any of the outputs of src are inputs of dest, they become internal nodes |
54 | // so remove them from outs |
55 | std::vector<Expr> ins_to_remove; |
56 | for (const auto& input : dest->ins_) { |
57 | auto call = Downcast<Call>(input); |
58 | auto it = src->nodes_.find(call->args[0]); |
59 | if (it != src->nodes_.end()) { |
60 | dest->outs_.remove(*it); |
61 | ins_to_remove.push_back(input); |
62 | } |
63 | } |
64 | for (const auto& input : ins_to_remove) { |
65 | dest->ins_.remove(input); |
66 | } |
67 | regions_.erase(src); |
68 | } |
69 | |
70 | void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr) { |
71 | auto src = GetRegion(expr); |
72 | if (src.defined()) { |
73 | MergeRegions(src, dest); |
74 | } else { |
75 | dest->nodes_.insert(expr); |
76 | } |
77 | } |
78 | |
79 | AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& func_name, |
80 | const std::string& target) { |
81 | auto ret = regions_.emplace(AnnotatedRegion()); |
82 | (*ret.first)->id_ = region_id_++; |
83 | (*ret.first)->target_ = target; |
84 | (*ret.first)->func_name_ = func_name; |
85 | return *ret.first; |
86 | } |
87 | |
88 | class AnnotatedRegionSet::Creator : protected MixedModeVisitor { |
89 | public: |
90 | Creator(const Op& region_begin_op, const Op& region_end_op, |
91 | const std::string& func_name = "default" ) |
92 | : begin_op_(region_begin_op), end_op_(region_end_op), func_name_(func_name) {} |
93 | |
94 | AnnotatedRegionSet Create(const Expr& expr) { |
95 | VisitExpr(expr); |
96 | return std::move(region_set_); |
97 | } |
98 | |
99 | void AddToArgRegion(Expr expr, Array<Expr> args) { |
100 | // Merge argument regions and add itself to the region. |
101 | |
102 | // Find the first open region. |
103 | AnnotatedRegion region; |
104 | for (auto arg : args) { |
105 | const CallNode* end = arg.as<CallNode>(); |
106 | if (end && end->op == end_op_) { // Ignore closed regions. |
107 | continue; |
108 | } |
109 | |
110 | region = region_set_->GetRegion(arg); |
111 | if (region.defined()) { |
112 | break; |
113 | } |
114 | } |
115 | |
116 | // Try to merge open regions. |
117 | for (auto arg : args) { |
118 | const CallNode* end = arg.as<CallNode>(); |
119 | if (end && end->op == end_op_) { // Ignore closed regions. |
120 | continue; |
121 | } |
122 | |
123 | auto arg_region = region_set_->GetRegion(arg); |
124 | ICHECK_EQ(region.defined(), arg_region.defined()) |
125 | << "Arg regions are inconsistent: " << AsText(expr); |
126 | if (region.defined() && region != arg_region) { |
127 | region_set_->MergeRegions(arg_region, region); |
128 | } |
129 | } |
130 | if (region.defined()) { |
131 | region_set_->AddToRegion(region, expr); |
132 | } |
133 | } |
134 | |
135 | void VisitExpr_(const CallNode* call) { |
136 | auto op_node = call->op.as<OpNode>(); |
137 | |
138 | if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) { |
139 | AddToArgRegion(GetRef<Call>(call), call->args); |
140 | } else if (call->op == begin_op_) { |
141 | // The annotation node is inserted on edge so it must have only one argument. |
142 | ICHECK_EQ(call->args.size(), 1U); |
143 | std::string target = call->attrs.as<CompilerAttrs>()->compiler; |
144 | |
145 | // Check if the argument already belongs to a region |
146 | auto region = region_set_->GetRegion(GetRef<Call>(call)); |
147 | ICHECK(!region.defined()); |
148 | |
149 | // Create a new region. |
150 | region = region_set_->MakeRegion(func_name_, target); |
151 | region->nodes_.insert(GetRef<Call>(call)); |
152 | region->ins_.push_back(GetRef<Call>(call)); |
153 | } else { |
154 | ICHECK_EQ(call->op, end_op_); |
155 | // The annotation node is inserted on edge so it must have only one argument. |
156 | ICHECK_EQ(call->args.size(), 1U); |
157 | std::string target = call->attrs.as<CompilerAttrs>()->compiler; |
158 | |
159 | // Check if the argument already belongs to a region |
160 | auto region = region_set_->GetRegion(call->args[0]); |
161 | if (!region.defined()) { |
162 | throw CompileError(ErrorBuilder() |
163 | << "Cannot find the corresponding region for end annotation:\n" |
164 | << AsText(GetRef<Call>(call), false)); |
165 | } else { |
166 | // If the argument is belonged to a region, it must have the same target. |
167 | // Otherwise we should see a region_begin op. |
168 | ICHECK_EQ(region->GetTarget(), target); |
169 | } |
170 | region->nodes_.insert(GetRef<Call>(call)); |
171 | region->outs_.push_back(GetRef<Call>(call)); |
172 | } |
173 | } |
174 | |
175 | void VisitExpr_(const TupleNode* op) { AddToArgRegion(GetRef<Tuple>(op), op->fields); } |
176 | |
177 | void VisitExpr_(const TupleGetItemNode* g) { |
178 | Array<Expr> args = {g->tuple}; |
179 | AddToArgRegion(GetRef<TupleGetItem>(g), args); |
180 | } |
181 | |
182 | void VisitExpr_(const LetNode* op) { |
183 | Array<Expr> args = {op->var, op->value, op->body}; |
184 | AddToArgRegion(GetRef<Let>(op), args); |
185 | ExprVisitor::VisitExpr_(op); |
186 | } |
187 | |
188 | void VisitExpr_(const IfNode* op) { |
189 | Array<Expr> args = {op->cond, op->true_branch, op->false_branch}; |
190 | AddToArgRegion(GetRef<If>(op), args); |
191 | ExprVisitor::VisitExpr_(op); |
192 | } |
193 | |
194 | void VisitExpr_(const RefCreateNode* op) { |
195 | Array<Expr> args = {op->value}; |
196 | AddToArgRegion(GetRef<RefCreate>(op), args); |
197 | ExprVisitor::VisitExpr_(op); |
198 | } |
199 | |
200 | void VisitExpr_(const RefReadNode* op) { |
201 | Array<Expr> args = {op->ref}; |
202 | AddToArgRegion(GetRef<RefRead>(op), args); |
203 | ExprVisitor::VisitExpr_(op); |
204 | } |
205 | |
206 | void VisitExpr_(const RefWriteNode* op) { |
207 | Array<Expr> args = {op->ref}; |
208 | AddToArgRegion(GetRef<RefWrite>(op), args); |
209 | ExprVisitor::VisitExpr_(op); |
210 | } |
211 | |
212 | private: |
213 | /*! \brief The region set being constructed.*/ |
214 | AnnotatedRegionSet region_set_; |
215 | /*! \brief Region 'begin' annotation operator. */ |
216 | const Op begin_op_; |
217 | /*! \brief Region 'end' annotation operator. */ |
218 | const Op end_op_; |
219 | /*! \brief The unique function name that is used to be the name of this region set. */ |
220 | const std::string func_name_; |
221 | }; |
222 | |
223 | AnnotatedRegionSet AnnotatedRegionSet::Create(const Expr& expr, const Op& begin, const Op& end, |
224 | const std::string& func_name) { |
225 | return Creator(begin, end, func_name).Create(expr); |
226 | } |
227 | |
228 | TVM_REGISTER_NODE_TYPE(AnnotatedRegionNode); |
229 | TVM_REGISTER_NODE_TYPE(AnnotatedRegionSetNode); |
230 | |
231 | TVM_REGISTER_GLOBAL("relay.analysis.AnnotatedRegionSet" ) |
232 | .set_body_typed([](Expr expr, Op begin, Op end) { |
233 | return AnnotatedRegionSet::Create(expr, begin, end); |
234 | }); |
235 | |
236 | TVM_REGISTER_GLOBAL("relay.analysis.GetRegion" ) |
237 | .set_body_typed([](AnnotatedRegionSet region_set, Expr expr) { |
238 | return region_set->GetRegion(expr); |
239 | }); |
240 | |
241 | } // namespace relay |
242 | } // namespace tvm |
243 | |