1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file tvm/node/object_path.h
 * \brief ObjectPath: a path from a root object to one of its descendants, built up from
 * attribute accesses, array indexing, map lookups, etc.
 */
25
26#ifndef TVM_NODE_OBJECT_PATH_H_
27#define TVM_NODE_OBJECT_PATH_H_
28
29#include <tvm/runtime/container/optional.h>
30#include <tvm/runtime/container/string.h>
31#include <tvm/runtime/object.h>
32
33#include <string>
34
35namespace tvm {
36
37using runtime::Object;
38using runtime::ObjectPtr;
39using runtime::ObjectRef;
40
41class ObjectPath;
42
43/*!
44 * \brief Path to an object from some root object.
45 *
46 * Motivation:
47 *
48 * Same IR node object can be referenced in several different contexts inside a larger IR object.
49 * For example, a variable could be referenced in several statements within a block.
50 *
51 * This makes it impossible to use an object pointer to uniquely identify a "location" within
52 * the larger IR object for error reporting purposes. The ObjectPath class addresses this problem
53 * by serving as a unique "locator".
54 */
55class ObjectPathNode : public Object {
56 public:
57 /*! \brief Get the parent path */
58 Optional<ObjectPath> GetParent() const;
59 /*!
60 * \brief Get the length of the path.
61 *
62 * For example, the path returned by `ObjectPath::Root()` has length 1.
63 */
64 int32_t Length() const;
65
66 /*!
67 * \brief Get a path prefix of the given length.
68 *
69 * Provided `length` must not exceed the `Length()` of this path.
70 */
71 ObjectPath GetPrefix(int32_t length) const;
72
73 /*!
74 * \brief Check if this path is a prefix of another path.
75 *
76 * The prefix is not strict, i.e. a path is considered a prefix of itself.
77 */
78 bool IsPrefixOf(const ObjectPath& other) const;
79
80 /*! \brief Check if two paths are equal. */
81 bool PathsEqual(const ObjectPath& other) const;
82
83 /*! \brief Extend this path with access to an object attribute. */
84 ObjectPath Attr(const char* attr_key) const;
85
86 /*! \brief Extend this path with access to an object attribute. */
87 ObjectPath Attr(Optional<String> attr_key) const;
88
89 /*! \brief Extend this path with access to an array element. */
90 ObjectPath ArrayIndex(int32_t index) const;
91
92 /*! \brief Extend this path with access to a missing array element. */
93 ObjectPath MissingArrayElement(int32_t index) const;
94
95 /*! \brief Extend this path with access to a map value. */
96 ObjectPath MapValue(ObjectRef key) const;
97
98 /*! \brief Extend this path with access to a missing map entry. */
99 ObjectPath MissingMapEntry() const;
100
101 static constexpr const char* _type_key = "ObjectPath";
102 TVM_DECLARE_BASE_OBJECT_INFO(ObjectPathNode, Object);
103
104 protected:
105 explicit ObjectPathNode(const ObjectPathNode* parent);
106
107 friend class ObjectPath;
108 friend std::string GetObjectPathRepr(const ObjectPathNode* node);
109
110 const ObjectPathNode* ParentNode() const;
111
112 /*! Compares just the last node of the path, without comparing the whole path. */
113 virtual bool LastNodeEqual(const ObjectPathNode* other) const = 0;
114
115 virtual std::string LastNodeString() const = 0;
116
117 private:
118 Optional<ObjectRef> parent_;
119 int32_t length_;
120};
121
122class ObjectPath : public ObjectRef {
123 public:
124 /*! \brief Create a path that represents the root object itself. */
125 static ObjectPath Root();
126
127 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef, ObjectPathNode);
128};
129
130//-------------------------------------------------------------------------
131//----- Concrete object path nodes ------------------------------------
132//-------------------------------------------------------------------------
133
134// ----- Root -----
135
136class RootPathNode final : public ObjectPathNode {
137 public:
138 explicit RootPathNode();
139
140 static constexpr const char* _type_key = "RootPath";
141 TVM_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode);
142
143 protected:
144 bool LastNodeEqual(const ObjectPathNode* other) const final;
145 std::string LastNodeString() const final;
146};
147
148class RootPath : public ObjectPath {
149 public:
150 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RootPath, ObjectPath, RootPathNode);
151};
152
153// ----- Attribute access -----
154
155class AttributeAccessPathNode final : public ObjectPathNode {
156 public:
157 /*! \brief Name of the attribute being accessed. Must be a static string. */
158 String attr_key;
159
160 explicit AttributeAccessPathNode(const ObjectPathNode* parent, String attr_key);
161
162 static constexpr const char* _type_key = "AttributeAccessPath";
163 TVM_DECLARE_FINAL_OBJECT_INFO(AttributeAccessPathNode, ObjectPathNode);
164
165 protected:
166 bool LastNodeEqual(const ObjectPathNode* other) const final;
167 std::string LastNodeString() const final;
168};
169
170class AttributeAccessPath : public ObjectPath {
171 public:
172 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AttributeAccessPath, ObjectPath,
173 AttributeAccessPathNode);
174};
175
176// ----- Unknown attribute access -----
177
178class UnknownAttributeAccessPathNode final : public ObjectPathNode {
179 public:
180 explicit UnknownAttributeAccessPathNode(const ObjectPathNode* parent);
181
182 static constexpr const char* _type_key = "UnknownAttributeAccessPath";
183 TVM_DECLARE_FINAL_OBJECT_INFO(UnknownAttributeAccessPathNode, ObjectPathNode);
184
185 protected:
186 bool LastNodeEqual(const ObjectPathNode* other) const final;
187 std::string LastNodeString() const final;
188};
189
190class UnknownAttributeAccessPath : public ObjectPath {
191 public:
192 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(UnknownAttributeAccessPath, ObjectPath,
193 UnknownAttributeAccessPathNode);
194};
195
196// ----- Array element access by index -----
197
198class ArrayIndexPathNode : public ObjectPathNode {
199 public:
200 /*! \brief Index of the array element that is being accessed. */
201 int32_t index;
202
203 explicit ArrayIndexPathNode(const ObjectPathNode* parent, int32_t index);
204
205 static constexpr const char* _type_key = "ArrayIndexPath";
206 TVM_DECLARE_FINAL_OBJECT_INFO(ArrayIndexPathNode, ObjectPathNode);
207
208 protected:
209 bool LastNodeEqual(const ObjectPathNode* other) const final;
210 std::string LastNodeString() const final;
211};
212
213class ArrayIndexPath : public ObjectPath {
214 public:
215 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ArrayIndexPath, ObjectPath, ArrayIndexPathNode);
216};
217
218// ----- Missing array element -----
219
220class MissingArrayElementPathNode : public ObjectPathNode {
221 public:
222 /*! \brief Index of the array element that is missing. */
223 int32_t index;
224
225 explicit MissingArrayElementPathNode(const ObjectPathNode* parent, int32_t index);
226
227 static constexpr const char* _type_key = "MissingArrayElementPath";
228 TVM_DECLARE_FINAL_OBJECT_INFO(MissingArrayElementPathNode, ObjectPathNode);
229
230 protected:
231 bool LastNodeEqual(const ObjectPathNode* other) const final;
232 std::string LastNodeString() const final;
233};
234
235class MissingArrayElementPath : public ObjectPath {
236 public:
237 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MissingArrayElementPath, ObjectPath,
238 MissingArrayElementPathNode);
239};
240
241// ----- Map value -----
242
243class MapValuePathNode : public ObjectPathNode {
244 public:
245 /*! \brief Key of the map entry that is being accessed */
246 ObjectRef key;
247
248 explicit MapValuePathNode(const ObjectPathNode* parent, ObjectRef key);
249
250 static constexpr const char* _type_key = "MapValuePath";
251 TVM_DECLARE_FINAL_OBJECT_INFO(MapValuePathNode, ObjectPathNode);
252
253 protected:
254 bool LastNodeEqual(const ObjectPathNode* other) const final;
255 std::string LastNodeString() const final;
256};
257
258class MapValuePath : public ObjectPath {
259 public:
260 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MapValuePath, ObjectPath, MapValuePathNode);
261};
262
263// ----- Missing map entry -----
264
265class MissingMapEntryPathNode : public ObjectPathNode {
266 public:
267 explicit MissingMapEntryPathNode(const ObjectPathNode* parent);
268
269 static constexpr const char* _type_key = "MissingMapEntryPath";
270 TVM_DECLARE_FINAL_OBJECT_INFO(MissingMapEntryPathNode, ObjectPathNode);
271
272 protected:
273 bool LastNodeEqual(const ObjectPathNode* other) const final;
274 std::string LastNodeString() const final;
275};
276
277class MissingMapEntryPath : public ObjectPath {
278 public:
279 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(MissingMapEntryPath, ObjectPath,
280 MissingMapEntryPathNode);
281};
282
283} // namespace tvm
284
285#endif // TVM_NODE_OBJECT_PATH_H_
286