1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/node/object_path.h |
22 | * ObjectPath class that represents a path from a root object to one of its descendants |
23 | * via attribute access, array indexing etc. |
24 | */ |
25 | |
26 | #ifndef TVM_NODE_OBJECT_PATH_H_ |
27 | #define TVM_NODE_OBJECT_PATH_H_ |
28 | |
29 | #include <tvm/runtime/container/optional.h> |
30 | #include <tvm/runtime/container/string.h> |
31 | #include <tvm/runtime/object.h> |
32 | |
33 | #include <string> |
34 | |
35 | namespace tvm { |
36 | |
37 | using runtime::Object; |
38 | using runtime::ObjectPtr; |
39 | using runtime::ObjectRef; |
40 | |
41 | class ObjectPath; |
42 | |
43 | /*! |
44 | * \brief Path to an object from some root object. |
45 | * |
46 | * Motivation: |
47 | * |
48 | * Same IR node object can be referenced in several different contexts inside a larger IR object. |
49 | * For example, a variable could be referenced in several statements within a block. |
50 | * |
51 | * This makes it impossible to use an object pointer to uniquely identify a "location" within |
52 | * the larger IR object for error reporting purposes. The ObjectPath class addresses this problem |
53 | * by serving as a unique "locator". |
54 | */ |
55 | class ObjectPathNode : public Object { |
56 | public: |
57 | /*! \brief Get the parent path */ |
58 | Optional<ObjectPath> GetParent() const; |
59 | /*! |
60 | * \brief Get the length of the path. |
61 | * |
62 | * For example, the path returned by `ObjectPath::Root()` has length 1. |
63 | */ |
64 | int32_t Length() const; |
65 | |
66 | /*! |
67 | * \brief Get a path prefix of the given length. |
68 | * |
69 | * Provided `length` must not exceed the `Length()` of this path. |
70 | */ |
71 | ObjectPath GetPrefix(int32_t length) const; |
72 | |
73 | /*! |
74 | * \brief Check if this path is a prefix of another path. |
75 | * |
76 | * The prefix is not strict, i.e. a path is considered a prefix of itself. |
77 | */ |
78 | bool IsPrefixOf(const ObjectPath& other) const; |
79 | |
80 | /*! \brief Check if two paths are equal. */ |
81 | bool PathsEqual(const ObjectPath& other) const; |
82 | |
83 | /*! \brief Extend this path with access to an object attribute. */ |
84 | ObjectPath Attr(const char* attr_key) const; |
85 | |
86 | /*! \brief Extend this path with access to an object attribute. */ |
87 | ObjectPath Attr(Optional<String> attr_key) const; |
88 | |
89 | /*! \brief Extend this path with access to an array element. */ |
90 | ObjectPath ArrayIndex(int32_t index) const; |
91 | |
92 | /*! \brief Extend this path with access to a missing array element. */ |
93 | ObjectPath MissingArrayElement(int32_t index) const; |
94 | |
95 | /*! \brief Extend this path with access to a map value. */ |
96 | ObjectPath MapValue(ObjectRef key) const; |
97 | |
98 | /*! \brief Extend this path with access to a missing map entry. */ |
99 | ObjectPath MissingMapEntry() const; |
100 | |
101 | static constexpr const char* _type_key = "ObjectPath" ; |
102 | TVM_DECLARE_BASE_OBJECT_INFO(ObjectPathNode, Object); |
103 | |
104 | protected: |
105 | explicit ObjectPathNode(const ObjectPathNode* parent); |
106 | |
107 | friend class ObjectPath; |
108 | friend std::string GetObjectPathRepr(const ObjectPathNode* node); |
109 | |
110 | const ObjectPathNode* ParentNode() const; |
111 | |
112 | /*! Compares just the last node of the path, without comparing the whole path. */ |
113 | virtual bool LastNodeEqual(const ObjectPathNode* other) const = 0; |
114 | |
115 | virtual std::string LastNodeString() const = 0; |
116 | |
117 | private: |
118 | Optional<ObjectRef> parent_; |
119 | int32_t length_; |
120 | }; |
121 | |
122 | class ObjectPath : public ObjectRef { |
123 | public: |
124 | /*! \brief Create a path that represents the root object itself. */ |
125 | static ObjectPath Root(); |
126 | |
127 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef, ObjectPathNode); |
128 | }; |
129 | |
130 | //------------------------------------------------------------------------- |
131 | //----- Concrete object path nodes ------------------------------------ |
132 | //------------------------------------------------------------------------- |
133 | |
134 | // ----- Root ----- |
135 | |
136 | class RootPathNode final : public ObjectPathNode { |
137 | public: |
138 | explicit RootPathNode(); |
139 | |
140 | static constexpr const char* _type_key = "RootPath" ; |
141 | TVM_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode); |
142 | |
143 | protected: |
144 | bool LastNodeEqual(const ObjectPathNode* other) const final; |
145 | std::string LastNodeString() const final; |
146 | }; |
147 | |
148 | class RootPath : public ObjectPath { |
149 | public: |
150 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RootPath, ObjectPath, RootPathNode); |
151 | }; |
152 | |
153 | // ----- Attribute access ----- |
154 | |
155 | class AttributeAccessPathNode final : public ObjectPathNode { |
156 | public: |
157 | /*! \brief Name of the attribute being accessed. Must be a static string. */ |
158 | String attr_key; |
159 | |
160 | explicit AttributeAccessPathNode(const ObjectPathNode* parent, String attr_key); |
161 | |
162 | static constexpr const char* _type_key = "AttributeAccessPath" ; |
163 | TVM_DECLARE_FINAL_OBJECT_INFO(AttributeAccessPathNode, ObjectPathNode); |
164 | |
165 | protected: |
166 | bool LastNodeEqual(const ObjectPathNode* other) const final; |
167 | std::string LastNodeString() const final; |
168 | }; |
169 | |
170 | class AttributeAccessPath : public ObjectPath { |
171 | public: |
172 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AttributeAccessPath, ObjectPath, |
173 | AttributeAccessPathNode); |
174 | }; |
175 | |
176 | // ----- Unknown attribute access ----- |
177 | |
178 | class UnknownAttributeAccessPathNode final : public ObjectPathNode { |
179 | public: |
180 | explicit UnknownAttributeAccessPathNode(const ObjectPathNode* parent); |
181 | |
182 | static constexpr const char* _type_key = "UnknownAttributeAccessPath" ; |
183 | TVM_DECLARE_FINAL_OBJECT_INFO(UnknownAttributeAccessPathNode, ObjectPathNode); |
184 | |
185 | protected: |
186 | bool LastNodeEqual(const ObjectPathNode* other) const final; |
187 | std::string LastNodeString() const final; |
188 | }; |
189 | |
190 | class UnknownAttributeAccessPath : public ObjectPath { |
191 | public: |
192 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(UnknownAttributeAccessPath, ObjectPath, |
193 | UnknownAttributeAccessPathNode); |
194 | }; |
195 | |
196 | // ----- Array element access by index ----- |
197 | |
198 | class ArrayIndexPathNode : public ObjectPathNode { |
199 | public: |
200 | /*! \brief Index of the array element that is being accessed. */ |
201 | int32_t index; |
202 | |
203 | explicit ArrayIndexPathNode(const ObjectPathNode* parent, int32_t index); |
204 | |
205 | static constexpr const char* _type_key = "ArrayIndexPath" ; |
206 | TVM_DECLARE_FINAL_OBJECT_INFO(ArrayIndexPathNode, ObjectPathNode); |
207 | |
208 | protected: |
209 | bool LastNodeEqual(const ObjectPathNode* other) const final; |
210 | std::string LastNodeString() const final; |
211 | }; |
212 | |
213 | class ArrayIndexPath : public ObjectPath { |
214 | public: |
215 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ArrayIndexPath, ObjectPath, ArrayIndexPathNode); |
216 | }; |
217 | |
218 | // ----- Missing array element ----- |
219 | |
220 | class MissingArrayElementPathNode : public ObjectPathNode { |
221 | public: |
222 | /*! \brief Index of the array element that is missing. */ |
223 | int32_t index; |
224 | |
225 | explicit MissingArrayElementPathNode(const ObjectPathNode* parent, int32_t index); |
226 | |
227 | static constexpr const char* _type_key = "MissingArrayElementPath" ; |
228 | TVM_DECLARE_FINAL_OBJECT_INFO(MissingArrayElementPathNode, ObjectPathNode); |
229 | |
230 | protected: |
231 | bool LastNodeEqual(const ObjectPathNode* other) const final; |
232 | std::string LastNodeString() const final; |
233 | }; |
234 | |
235 | class MissingArrayElementPath : public ObjectPath { |
236 | public: |
237 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MissingArrayElementPath, ObjectPath, |
238 | MissingArrayElementPathNode); |
239 | }; |
240 | |
241 | // ----- Map value ----- |
242 | |
243 | class MapValuePathNode : public ObjectPathNode { |
244 | public: |
245 | /*! \brief Key of the map entry that is being accessed */ |
246 | ObjectRef key; |
247 | |
248 | explicit MapValuePathNode(const ObjectPathNode* parent, ObjectRef key); |
249 | |
250 | static constexpr const char* _type_key = "MapValuePath" ; |
251 | TVM_DECLARE_FINAL_OBJECT_INFO(MapValuePathNode, ObjectPathNode); |
252 | |
253 | protected: |
254 | bool LastNodeEqual(const ObjectPathNode* other) const final; |
255 | std::string LastNodeString() const final; |
256 | }; |
257 | |
258 | class MapValuePath : public ObjectPath { |
259 | public: |
260 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MapValuePath, ObjectPath, MapValuePathNode); |
261 | }; |
262 | |
263 | // ----- Missing map entry ----- |
264 | |
265 | class MissingMapEntryPathNode : public ObjectPathNode { |
266 | public: |
267 | explicit MissingMapEntryPathNode(const ObjectPathNode* parent); |
268 | |
269 | static constexpr const char* _type_key = "MissingMapEntryPath" ; |
270 | TVM_DECLARE_FINAL_OBJECT_INFO(MissingMapEntryPathNode, ObjectPathNode); |
271 | |
272 | protected: |
273 | bool LastNodeEqual(const ObjectPathNode* other) const final; |
274 | std::string LastNodeString() const final; |
275 | }; |
276 | |
277 | class MissingMapEntryPath : public ObjectPath { |
278 | public: |
279 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MissingMapEntryPath, ObjectPath, |
280 | MissingMapEntryPathNode); |
281 | }; |
282 | |
283 | } // namespace tvm |
284 | |
285 | #endif // TVM_NODE_OBJECT_PATH_H_ |
286 | |