1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#include "annotated_region_set.h"
21
22#include <tvm/relay/error.h>
23#include <tvm/relay/expr.h>
24
25#include <unordered_map>
26#include <vector>
27
28namespace tvm {
29namespace relay {
30
31AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const {
32 for (auto candidate : regions_) {
33 if (candidate->nodes_.find(expr) != candidate->nodes_.end()) {
34 return candidate;
35 }
36 }
37 return AnnotatedRegion(nullptr);
38}
39
40void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, AnnotatedRegion dest) {
41 if (dest == src) {
42 return;
43 }
44
45 // Merge src to dest and erase src.
46 dest->nodes_.insert(src->nodes_.begin(), src->nodes_.end());
47 for (const auto& input : src->ins_) {
48 dest->ins_.push_back(input);
49 }
50 for (const auto& output : src->outs_) {
51 dest->outs_.push_back(output);
52 }
53 // if any of the outputs of src are inputs of dest, they become internal nodes
54 // so remove them from outs
55 std::vector<Expr> ins_to_remove;
56 for (const auto& input : dest->ins_) {
57 auto call = Downcast<Call>(input);
58 auto it = src->nodes_.find(call->args[0]);
59 if (it != src->nodes_.end()) {
60 dest->outs_.remove(*it);
61 ins_to_remove.push_back(input);
62 }
63 }
64 for (const auto& input : ins_to_remove) {
65 dest->ins_.remove(input);
66 }
67 regions_.erase(src);
68}
69
70void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr) {
71 auto src = GetRegion(expr);
72 if (src.defined()) {
73 MergeRegions(src, dest);
74 } else {
75 dest->nodes_.insert(expr);
76 }
77}
78
79AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& func_name,
80 const std::string& target) {
81 auto ret = regions_.emplace(AnnotatedRegion());
82 (*ret.first)->id_ = region_id_++;
83 (*ret.first)->target_ = target;
84 (*ret.first)->func_name_ = func_name;
85 return *ret.first;
86}
87
88class AnnotatedRegionSet::Creator : protected MixedModeVisitor {
89 public:
90 Creator(const Op& region_begin_op, const Op& region_end_op,
91 const std::string& func_name = "default")
92 : begin_op_(region_begin_op), end_op_(region_end_op), func_name_(func_name) {}
93
94 AnnotatedRegionSet Create(const Expr& expr) {
95 VisitExpr(expr);
96 return std::move(region_set_);
97 }
98
99 void AddToArgRegion(Expr expr, Array<Expr> args) {
100 // Merge argument regions and add itself to the region.
101
102 // Find the first open region.
103 AnnotatedRegion region;
104 for (auto arg : args) {
105 const CallNode* end = arg.as<CallNode>();
106 if (end && end->op == end_op_) { // Ignore closed regions.
107 continue;
108 }
109
110 region = region_set_->GetRegion(arg);
111 if (region.defined()) {
112 break;
113 }
114 }
115
116 // Try to merge open regions.
117 for (auto arg : args) {
118 const CallNode* end = arg.as<CallNode>();
119 if (end && end->op == end_op_) { // Ignore closed regions.
120 continue;
121 }
122
123 auto arg_region = region_set_->GetRegion(arg);
124 ICHECK_EQ(region.defined(), arg_region.defined())
125 << "Arg regions are inconsistent: " << AsText(expr);
126 if (region.defined() && region != arg_region) {
127 region_set_->MergeRegions(arg_region, region);
128 }
129 }
130 if (region.defined()) {
131 region_set_->AddToRegion(region, expr);
132 }
133 }
134
135 void VisitExpr_(const CallNode* call) {
136 auto op_node = call->op.as<OpNode>();
137
138 if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
139 AddToArgRegion(GetRef<Call>(call), call->args);
140 } else if (call->op == begin_op_) {
141 // The annotation node is inserted on edge so it must have only one argument.
142 ICHECK_EQ(call->args.size(), 1U);
143 std::string target = call->attrs.as<CompilerAttrs>()->compiler;
144
145 // Check if the argument already belongs to a region
146 auto region = region_set_->GetRegion(GetRef<Call>(call));
147 ICHECK(!region.defined());
148
149 // Create a new region.
150 region = region_set_->MakeRegion(func_name_, target);
151 region->nodes_.insert(GetRef<Call>(call));
152 region->ins_.push_back(GetRef<Call>(call));
153 } else {
154 ICHECK_EQ(call->op, end_op_);
155 // The annotation node is inserted on edge so it must have only one argument.
156 ICHECK_EQ(call->args.size(), 1U);
157 std::string target = call->attrs.as<CompilerAttrs>()->compiler;
158
159 // Check if the argument already belongs to a region
160 auto region = region_set_->GetRegion(call->args[0]);
161 if (!region.defined()) {
162 throw CompileError(ErrorBuilder()
163 << "Cannot find the corresponding region for end annotation:\n"
164 << AsText(GetRef<Call>(call), false));
165 } else {
166 // If the argument is belonged to a region, it must have the same target.
167 // Otherwise we should see a region_begin op.
168 ICHECK_EQ(region->GetTarget(), target);
169 }
170 region->nodes_.insert(GetRef<Call>(call));
171 region->outs_.push_back(GetRef<Call>(call));
172 }
173 }
174
175 void VisitExpr_(const TupleNode* op) { AddToArgRegion(GetRef<Tuple>(op), op->fields); }
176
177 void VisitExpr_(const TupleGetItemNode* g) {
178 Array<Expr> args = {g->tuple};
179 AddToArgRegion(GetRef<TupleGetItem>(g), args);
180 }
181
182 void VisitExpr_(const LetNode* op) {
183 Array<Expr> args = {op->var, op->value, op->body};
184 AddToArgRegion(GetRef<Let>(op), args);
185 ExprVisitor::VisitExpr_(op);
186 }
187
188 void VisitExpr_(const IfNode* op) {
189 Array<Expr> args = {op->cond, op->true_branch, op->false_branch};
190 AddToArgRegion(GetRef<If>(op), args);
191 ExprVisitor::VisitExpr_(op);
192 }
193
194 void VisitExpr_(const RefCreateNode* op) {
195 Array<Expr> args = {op->value};
196 AddToArgRegion(GetRef<RefCreate>(op), args);
197 ExprVisitor::VisitExpr_(op);
198 }
199
200 void VisitExpr_(const RefReadNode* op) {
201 Array<Expr> args = {op->ref};
202 AddToArgRegion(GetRef<RefRead>(op), args);
203 ExprVisitor::VisitExpr_(op);
204 }
205
206 void VisitExpr_(const RefWriteNode* op) {
207 Array<Expr> args = {op->ref};
208 AddToArgRegion(GetRef<RefWrite>(op), args);
209 ExprVisitor::VisitExpr_(op);
210 }
211
212 private:
213 /*! \brief The region set being constructed.*/
214 AnnotatedRegionSet region_set_;
215 /*! \brief Region 'begin' annotation operator. */
216 const Op begin_op_;
217 /*! \brief Region 'end' annotation operator. */
218 const Op end_op_;
219 /*! \brief The unique function name that is used to be the name of this region set. */
220 const std::string func_name_;
221};
222
223AnnotatedRegionSet AnnotatedRegionSet::Create(const Expr& expr, const Op& begin, const Op& end,
224 const std::string& func_name) {
225 return Creator(begin, end, func_name).Create(expr);
226}
227
228TVM_REGISTER_NODE_TYPE(AnnotatedRegionNode);
229TVM_REGISTER_NODE_TYPE(AnnotatedRegionSetNode);
230
231TVM_REGISTER_GLOBAL("relay.analysis.AnnotatedRegionSet")
232 .set_body_typed([](Expr expr, Op begin, Op end) {
233 return AnnotatedRegionSet::Create(expr, begin, end);
234 });
235
236TVM_REGISTER_GLOBAL("relay.analysis.GetRegion")
237 .set_body_typed([](AnnotatedRegionSet region_set, Expr expr) {
238 return region_set->GetRegion(expr);
239 });
240
241} // namespace relay
242} // namespace tvm
243