1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file tvm/relay/analysis/annotated_region_set.h
 * \brief Define data structures to extract and manipulate regions from
 * a relay function. Regions are delimited by region_begin and region_end
 * annotations placed on every input and output edge of the region.
 */
26
27#ifndef TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
28#define TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
29
30#include <tvm/relay/analysis.h>
31#include <tvm/relay/attrs/annotation.h>
32#include <tvm/relay/error.h>
33#include <tvm/relay/expr.h>
34#include <tvm/relay/expr_functor.h>
35#include <tvm/relay/transform.h>
36
37#include <list>
38#include <string>
39#include <unordered_set>
40#include <utility>
41#include <vector>
42
43namespace tvm {
44namespace relay {
45
46class AnnotatedRegion;
47class AnnotatedRegionSet;
48
49class AnnotatedRegionNode : public Object {
50 public:
51 void VisitAttrs(AttrVisitor* v) {
52 v->Visit("id", &id_);
53 v->Visit("target", &target_);
54 Array<Expr> nodes_array(nodes_.begin(), nodes_.end());
55 v->Visit("nodes", &nodes_array);
56 Array<Expr> args_array(ins_.begin(), ins_.end());
57 v->Visit("args", &args_array);
58 Array<Expr> rets_array(outs_.begin(), outs_.end());
59 v->Visit("rets", &rets_array);
60 }
61
62 /*! \brief Get the region ID. */
63 int GetID() const { return id_; }
64
65 /*! \brief Get the region name. */
66 std::string GetName() const { return func_name_; }
67
68 /*! \brief Get the region target. */
69 std::string GetTarget() const { return target_; }
70
71 /*! \brief Get the region's inputs. */
72 std::list<Expr> GetInputs() const { return ins_; }
73
74 /*! \brief Get the region's outputs. */
75 std::list<Expr> GetOutputs() const { return outs_; }
76
77 /*! \brief Get the region's nodes. */
78 std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> GetNodes() const { return nodes_; }
79
80 static constexpr const char* _type_key = "relay.AnnotatedRegion";
81 TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionNode, Object);
82
83 protected:
84 /*! \brief The region ID. */
85 int id_{-1};
86 /*! \brief The func name. */
87 std::string func_name_ = "default";
88 /*! \brief The target for this region. */
89 std::string target_ = "default";
90 /*! \brief The inputs to this region. */
91 std::list<Expr> ins_;
92 /*! \brief The outputs of this region */
93 std::list<Expr> outs_;
94 /*! \brief Nodes in this region. */
95 std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> nodes_;
96
97 friend class AnnotatedRegionSet;
98 friend class AnnotatedRegionSetNode;
99};
100
101/*!
102 * \brief An object to hold the properties of a region as used by the
103 * AnnotatedRegionSet class. This should be considered read-only.
104 */
105class AnnotatedRegion : public ObjectRef {
106 public:
107 AnnotatedRegion() {
108 auto n = make_object<AnnotatedRegionNode>();
109 data_ = std::move(n);
110 }
111
112 /*!
113 * \brief Construct from an object pointer.
114 * \param n The object pointer.
115 */
116 explicit AnnotatedRegion(ObjectPtr<Object> n) : ObjectRef(n) {}
117
118 /*! \return Mutable pointers to the node. */
119 AnnotatedRegionNode* operator->() const {
120 auto* ptr = get_mutable();
121 ICHECK(ptr != nullptr);
122 return static_cast<AnnotatedRegionNode*>(ptr);
123 }
124};
125
126class AnnotatedRegionSetNode : public Object {
127 using UnorderedRegionSet = std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual>;
128 // Create iterator alias for a RegionSet object.
129 using iterator = UnorderedRegionSet::iterator;
130 using const_iterator = UnorderedRegionSet::const_iterator;
131
132 public:
133 /*! \brief Default constructor. */
134 AnnotatedRegionSetNode() = default;
135
136 /*! \return The begin iterator */
137 iterator begin() { return regions_.begin(); }
138 /*! \return The end iterator */
139 iterator end() { return regions_.end(); }
140 /*! \return The const begin iterator */
141 const_iterator begin() const { return regions_.begin(); }
142 /*! \return The const end iterator */
143 const_iterator end() const { return regions_.end(); }
144
145 /*!
146 * \brief Get the region that an expression belongs to.
147 *
148 * \param expr Which expr to get the region for.
149 *
150 * \return A pointer to the region, nullptr if the expression
151 * doesn't belong to a region.
152 */
153 AnnotatedRegion GetRegion(const Expr& expr) const;
154
155 /*!
156 * \brief Merge src region into dest region.
157 *
158 * \param src The region to merge - will be erased.
159 * \param dest The region into which src will be merged.
160 */
161 void MergeRegions(AnnotatedRegion src, AnnotatedRegion dest);
162
163 void VisitAttrs(AttrVisitor* v) {
164 Array<AnnotatedRegion> regions_array(regions_.begin(), regions_.end());
165 v->Visit("regions", &regions_array);
166 }
167
168 static constexpr const char* _type_key = "relay.AnnotatedRegionSet";
169 TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionSetNode, Object);
170
171 private:
172 /*!
173 * \brief Add an expression to a region.
174 *
175 * \param dest The region to add the expression to.
176 * \param expr The expression.
177 */
178 void AddToRegion(AnnotatedRegion dest, const Expr& expr);
179
180 /*!
181 * \brief Make a new region for a target.
182 *
183 * \return The new region.
184 */
185 AnnotatedRegion MakeRegion(const std::string& func_name, const std::string& target);
186
187 std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual> regions_;
188 /*! \brief The next region ID to assign. */
189 int region_id_{0};
190
191 friend class AnnotatedRegionSet;
192};
193
194/*!
195 * \brief A class to hold a set of regions produced from a relay expression
196 * that contains 'region_begin' and 'region_end' style annotations. The
197 * regions should be disjoint. The class provides both a method to construct
198 * the region set of a given relay expression as well as additional methods
199 * to update and query regions.
200 */
201class AnnotatedRegionSet : public ObjectRef {
202 using UnorderedRegionSet = std::unordered_set<AnnotatedRegion, ObjectPtrHash, ObjectPtrEqual>;
203 // Create iterator alias for a RegionSet object.
204 using iterator = UnorderedRegionSet::iterator;
205 using const_iterator = UnorderedRegionSet::const_iterator;
206
207 public:
208 AnnotatedRegionSet() {
209 auto n = make_object<AnnotatedRegionSetNode>();
210 data_ = std::move(n);
211 }
212
213 /*!
214 * \brief Construct from an object pointer.
215 *
216 * \param n The object pointer.
217 */
218 explicit AnnotatedRegionSet(ObjectPtr<Object> n) : ObjectRef(n) {}
219
220 /*! \return The begin iterator. */
221 iterator begin() {
222 auto* n = operator->();
223 ICHECK(n);
224 return n->begin();
225 }
226 /*! \return The end iterator. */
227 iterator end() {
228 auto* n = operator->();
229 ICHECK(n);
230 return n->end();
231 }
232 /*! \return The begin iterator. */
233 const_iterator begin() const {
234 const auto* n = operator->();
235 ICHECK(n);
236 return n->begin();
237 }
238 /*! \return The end iterator. */
239 const_iterator end() const {
240 const auto* n = operator->();
241 ICHECK(n);
242 return n->end();
243 }
244
245 /*! \return mutable pointers to the node. */
246 AnnotatedRegionSetNode* operator->() const {
247 auto* ptr = get_mutable();
248 ICHECK(ptr != nullptr);
249 return static_cast<AnnotatedRegionSetNode*>(ptr);
250 }
251
252 /*! \return The region an expression belongs to. */
253 AnnotatedRegion operator[](const Expr& expr) {
254 const auto* n = operator->();
255 ICHECK(n);
256 return n->GetRegion(expr);
257 }
258
259 /*! \brief Create a RegionSet from a relay expression.
260 *
261 * \param expr The relay expr from which to construct the set.
262 * \param begin Region begin annotation operator.
263 * \param end Region end annotation operator.
264 * \param func_name function name
265 *
266 * \return The created RegionSet for the expression.
267 */
268 static AnnotatedRegionSet Create(const Expr& expr, const Op& begin, const Op& end,
269 const std::string& func_name = "default");
270
271 private:
272 /*! \brief Helper class to construct a RegionSet from an expr.*/
273 class Creator;
274};
275
276} // namespace relay
277} // namespace tvm
278
279#endif // TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
280