1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file tvm/relay/analysis/annotated_region_set.h
 * \brief Defines data structures to extract and manipulate regions from
 * a Relay function. Regions are delimited by region_begin and region_end
 * annotations that exist on all the input and output edges of the region.
 */
26 | |
27 | #ifndef TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ |
28 | #define TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ |
29 | |
30 | #include <tvm/relay/analysis.h> |
31 | #include <tvm/relay/attrs/annotation.h> |
32 | #include <tvm/relay/error.h> |
33 | #include <tvm/relay/expr.h> |
34 | #include <tvm/relay/expr_functor.h> |
35 | #include <tvm/relay/transform.h> |
36 | |
37 | #include <list> |
38 | #include <string> |
39 | #include <unordered_set> |
40 | #include <utility> |
41 | #include <vector> |
42 | |
43 | namespace tvm { |
44 | namespace relay { |
45 | |
46 | class AnnotatedRegion; |
47 | class AnnotatedRegionSet; |
48 | |
49 | class AnnotatedRegionNode : public Object { |
50 | public: |
51 | void VisitAttrs(AttrVisitor* v) { |
52 | v->Visit("id" , &id_); |
53 | v->Visit("target" , &target_); |
54 | Array<Expr> nodes_array(nodes_.begin(), nodes_.end()); |
55 | v->Visit("nodes" , &nodes_array); |
56 | Array<Expr> args_array(ins_.begin(), ins_.end()); |
57 | v->Visit("args" , &args_array); |
58 | Array<Expr> rets_array(outs_.begin(), outs_.end()); |
59 | v->Visit("rets" , &rets_array); |
60 | } |
61 | |
62 | /*! \brief Get the region ID. */ |
63 | int GetID() const { return id_; } |
64 | |
65 | /*! \brief Get the region name. */ |
66 | std::string GetName() const { return func_name_; } |
67 | |
68 | /*! \brief Get the region target. */ |
69 | std::string GetTarget() const { return target_; } |
70 | |
71 | /*! \brief Get the region's inputs. */ |
72 | std::list<Expr> GetInputs() const { return ins_; } |
73 | |
74 | /*! \brief Get the region's outputs. */ |
75 | std::list<Expr> GetOutputs() const { return outs_; } |
76 | |
77 | /*! \brief Get the region's nodes. */ |
78 | std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> GetNodes() const { return nodes_; } |
79 | |
80 | static constexpr const char* _type_key = "relay.AnnotatedRegion" ; |
81 | TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionNode, Object); |
82 | |
83 | protected: |
84 | /*! \brief The region ID. */ |
85 | int id_{-1}; |
86 | /*! \brief The func name. */ |
87 | std::string func_name_ = "default" ; |
88 | /*! \brief The target for this region. */ |
89 | std::string target_ = "default" ; |
90 | /*! \brief The inputs to this region. */ |
91 | std::list<Expr> ins_; |
92 | /*! \brief The outputs of this region */ |
93 | std::list<Expr> outs_; |
94 | /*! \brief Nodes in this region. */ |
95 | std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> nodes_; |
96 | |
97 | friend class AnnotatedRegionSet; |
98 | friend class AnnotatedRegionSetNode; |
99 | }; |
100 | |
101 | /*! |
102 | * \brief An object to hold the properties of a region as used by the |
103 | * AnnotatedRegionSet class. This should be considered read-only. |
104 | */ |
105 | class AnnotatedRegion : public ObjectRef { |
106 | public: |
107 | AnnotatedRegion() { |
108 | auto n = make_object<AnnotatedRegionNode>(); |
109 | data_ = std::move(n); |
110 | } |
111 | |
112 | /*! |
113 | * \brief Construct from an object pointer. |
114 | * \param n The object pointer. |
115 | */ |
116 | explicit AnnotatedRegion(ObjectPtr<Object> n) : ObjectRef(n) {} |
117 | |
118 | /*! \return Mutable pointers to the node. */ |
119 | AnnotatedRegionNode* operator->() const { |
120 | auto* ptr = get_mutable(); |
121 | ICHECK(ptr != nullptr); |
122 | return static_cast<AnnotatedRegionNode*>(ptr); |
123 | } |
124 | }; |
125 | |
126 | class AnnotatedRegionSetNode : public Object { |
127 | using UnorderedRegionSet = std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual>; |
128 | // Create iterator alias for a RegionSet object. |
129 | using iterator = UnorderedRegionSet::iterator; |
130 | using const_iterator = UnorderedRegionSet::const_iterator; |
131 | |
132 | public: |
133 | /*! \brief Default constructor. */ |
134 | AnnotatedRegionSetNode() = default; |
135 | |
136 | /*! \return The begin iterator */ |
137 | iterator begin() { return regions_.begin(); } |
138 | /*! \return The end iterator */ |
139 | iterator end() { return regions_.end(); } |
140 | /*! \return The const begin iterator */ |
141 | const_iterator begin() const { return regions_.begin(); } |
142 | /*! \return The const end iterator */ |
143 | const_iterator end() const { return regions_.end(); } |
144 | |
145 | /*! |
146 | * \brief Get the region that an expression belongs to. |
147 | * |
148 | * \param expr Which expr to get the region for. |
149 | * |
150 | * \return A pointer to the region, nullptr if the expression |
151 | * doesn't belong to a region. |
152 | */ |
153 | AnnotatedRegion GetRegion(const Expr& expr) const; |
154 | |
155 | /*! |
156 | * \brief Merge src region into dest region. |
157 | * |
158 | * \param src The region to merge - will be erased. |
159 | * \param dest The region into which src will be merged. |
160 | */ |
161 | void MergeRegions(AnnotatedRegion src, AnnotatedRegion dest); |
162 | |
163 | void VisitAttrs(AttrVisitor* v) { |
164 | Array<AnnotatedRegion> regions_array(regions_.begin(), regions_.end()); |
165 | v->Visit("regions" , ®ions_array); |
166 | } |
167 | |
168 | static constexpr const char* _type_key = "relay.AnnotatedRegionSet" ; |
169 | TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionSetNode, Object); |
170 | |
171 | private: |
172 | /*! |
173 | * \brief Add an expression to a region. |
174 | * |
175 | * \param dest The region to add the expression to. |
176 | * \param expr The expression. |
177 | */ |
178 | void AddToRegion(AnnotatedRegion dest, const Expr& expr); |
179 | |
180 | /*! |
181 | * \brief Make a new region for a target. |
182 | * |
183 | * \return The new region. |
184 | */ |
185 | AnnotatedRegion MakeRegion(const std::string& func_name, const std::string& target); |
186 | |
187 | std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual> regions_; |
188 | /*! \brief The next region ID to assign. */ |
189 | int region_id_{0}; |
190 | |
191 | friend class AnnotatedRegionSet; |
192 | }; |
193 | |
194 | /*! |
195 | * \brief A class to hold a set of regions produced from a relay expression |
196 | * that contains 'region_begin' and 'region_end' style annotations. The |
197 | * regions should be disjoint. The class provides both a method to construct |
198 | * the region set of a given relay expression as well as additional methods |
199 | * to update and query regions. |
200 | */ |
201 | class AnnotatedRegionSet : public ObjectRef { |
202 | using UnorderedRegionSet = std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual>; |
203 | // Create iterator alias for a RegionSet object. |
204 | using iterator = UnorderedRegionSet::iterator; |
205 | using const_iterator = UnorderedRegionSet::const_iterator; |
206 | |
207 | public: |
208 | AnnotatedRegionSet() { |
209 | auto n = make_object<AnnotatedRegionSetNode>(); |
210 | data_ = std::move(n); |
211 | } |
212 | |
213 | /*! |
214 | * \brief Construct from an object pointer. |
215 | * |
216 | * \param n The object pointer. |
217 | */ |
218 | explicit AnnotatedRegionSet(ObjectPtr<Object> n) : ObjectRef(n) {} |
219 | |
220 | /*! \return The begin iterator. */ |
221 | iterator begin() { |
222 | auto* n = operator->(); |
223 | ICHECK(n); |
224 | return n->begin(); |
225 | } |
226 | /*! \return The end iterator. */ |
227 | iterator end() { |
228 | auto* n = operator->(); |
229 | ICHECK(n); |
230 | return n->end(); |
231 | } |
232 | /*! \return The begin iterator. */ |
233 | const_iterator begin() const { |
234 | const auto* n = operator->(); |
235 | ICHECK(n); |
236 | return n->begin(); |
237 | } |
238 | /*! \return The end iterator. */ |
239 | const_iterator end() const { |
240 | const auto* n = operator->(); |
241 | ICHECK(n); |
242 | return n->end(); |
243 | } |
244 | |
245 | /*! \return mutable pointers to the node. */ |
246 | AnnotatedRegionSetNode* operator->() const { |
247 | auto* ptr = get_mutable(); |
248 | ICHECK(ptr != nullptr); |
249 | return static_cast<AnnotatedRegionSetNode*>(ptr); |
250 | } |
251 | |
252 | /*! \return The region an expression belongs to. */ |
253 | AnnotatedRegion operator[](const Expr& expr) { |
254 | const auto* n = operator->(); |
255 | ICHECK(n); |
256 | return n->GetRegion(expr); |
257 | } |
258 | |
259 | /*! \brief Create a RegionSet from a relay expression. |
260 | * |
261 | * \param expr The relay expr from which to construct the set. |
262 | * \param begin Region begin annotation operator. |
263 | * \param end Region end annotation operator. |
264 | * \param func_name function name |
265 | * |
266 | * \return The created RegionSet for the expression. |
267 | */ |
268 | static AnnotatedRegionSet Create(const Expr& expr, const Op& begin, const Op& end, |
269 | const std::string& func_name = "default" ); |
270 | |
271 | private: |
272 | /*! \brief Helper class to construct a RegionSet from an expr.*/ |
273 | class Creator; |
274 | }; |
275 | |
276 | } // namespace relay |
277 | } // namespace tvm |
278 | |
279 | #endif // TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ |
280 | |