1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20#include <tvm/node/object_path.h>
21#include <tvm/node/repr_printer.h>
22#include <tvm/runtime/memory.h>
23#include <tvm/runtime/registry.h>
24
25#include <algorithm>
26#include <cstring>
27
28using namespace tvm::runtime;
29
30namespace tvm {
31
32// ============== ObjectPathNode ==============
33
34ObjectPathNode::ObjectPathNode(const ObjectPathNode* parent)
35 : parent_(GetRef<ObjectRef>(parent)), length_(parent == nullptr ? 1 : parent->length_ + 1) {}
36
37// --- GetParent ---
38
39Optional<ObjectPath> ObjectPathNode::GetParent() const {
40 if (parent_ == nullptr) {
41 return NullOpt;
42 } else {
43 return Downcast<ObjectPath>(parent_);
44 }
45}
46
47TVM_REGISTER_GLOBAL("node.ObjectPathGetParent")
48 .set_body_method<ObjectPath>(&ObjectPathNode::GetParent);
49
50// --- Length ---
51
52int32_t ObjectPathNode::Length() const { return length_; }
53
54TVM_REGISTER_GLOBAL("node.ObjectPathLength").set_body_method<ObjectPath>(&ObjectPathNode::Length);
55
56// --- GetPrefix ---
57
58ObjectPath ObjectPathNode::GetPrefix(int32_t length) const {
59 CHECK_GE(length, 1) << "IndexError: Prefix length must be at least 1";
60 CHECK_LE(length, Length()) << "IndexError: Attempted to get a prefix longer than the path itself";
61
62 const ObjectPathNode* node = this;
63 int32_t suffix_len = Length() - length;
64 for (int32_t i = 0; i < suffix_len; ++i) {
65 node = node->ParentNode();
66 }
67
68 return GetRef<ObjectPath>(node);
69}
70
71TVM_REGISTER_GLOBAL("node.ObjectPathGetPrefix")
72 .set_body_method<ObjectPath>(&ObjectPathNode::GetPrefix);
73
74// --- IsPrefixOf ---
75
76bool ObjectPathNode::IsPrefixOf(const ObjectPath& other) const {
77 int32_t this_len = Length();
78 if (this_len > other->Length()) {
79 return false;
80 }
81 return this->PathsEqual(other->GetPrefix(this_len));
82}
83
84TVM_REGISTER_GLOBAL("node.ObjectPathIsPrefixOf")
85 .set_body_method<ObjectPath>(&ObjectPathNode::IsPrefixOf);
86
87// --- Attr ---
88
89ObjectPath ObjectPathNode::Attr(const char* attr_key) const {
90 if (attr_key != nullptr) {
91 return ObjectPath(make_object<AttributeAccessPathNode>(this, attr_key));
92 } else {
93 return ObjectPath(make_object<UnknownAttributeAccessPathNode>(this));
94 }
95}
96
97ObjectPath ObjectPathNode::Attr(Optional<String> attr_key) const {
98 if (attr_key.defined()) {
99 return ObjectPath(make_object<AttributeAccessPathNode>(this, attr_key.value()));
100 } else {
101 return ObjectPath(make_object<UnknownAttributeAccessPathNode>(this));
102 }
103}
104
105TVM_REGISTER_GLOBAL("node.ObjectPathAttr")
106 .set_body_typed([](const ObjectPath& object_path, Optional<String> attr_key) {
107 return object_path->Attr(attr_key);
108 });
109
110// --- ArrayIndex ---
111
112ObjectPath ObjectPathNode::ArrayIndex(int32_t index) const {
113 return ObjectPath(make_object<ArrayIndexPathNode>(this, index));
114}
115
116TVM_REGISTER_GLOBAL("node.ObjectPathArrayIndex")
117 .set_body_method<ObjectPath>(&ObjectPathNode::ArrayIndex);
118
119// --- MissingArrayElement ---
120
121ObjectPath ObjectPathNode::MissingArrayElement(int32_t index) const {
122 return ObjectPath(make_object<MissingArrayElementPathNode>(this, index));
123}
124
125TVM_REGISTER_GLOBAL("node.ObjectPathMissingArrayElement")
126 .set_body_method<ObjectPath>(&ObjectPathNode::MissingArrayElement);
127
128// --- MapValue ---
129
130ObjectPath ObjectPathNode::MapValue(ObjectRef key) const {
131 return ObjectPath(make_object<MapValuePathNode>(this, std::move(key)));
132}
133
134TVM_REGISTER_GLOBAL("node.ObjectPathMapValue")
135 .set_body_method<ObjectPath>(&ObjectPathNode::MapValue);
136
137// --- MissingMapEntry ---
138
139ObjectPath ObjectPathNode::MissingMapEntry() const {
140 return ObjectPath(make_object<MissingMapEntryPathNode>(this));
141}
142
143TVM_REGISTER_GLOBAL("node.ObjectPathMissingMapEntry")
144 .set_body_method<ObjectPath>(&ObjectPathNode::MissingMapEntry);
145
146// --- PathsEqual ----
147
148bool ObjectPathNode::PathsEqual(const ObjectPath& other) const {
149 if (!other.defined() || Length() != other->Length()) {
150 return false;
151 }
152
153 const ObjectPathNode* lhs = this;
154 const ObjectPathNode* rhs = static_cast<const ObjectPathNode*>(other.get());
155
156 while (lhs != nullptr && rhs != nullptr) {
157 if (lhs->type_index() != rhs->type_index()) {
158 return false;
159 }
160 if (!lhs->LastNodeEqual(rhs)) {
161 return false;
162 }
163 lhs = lhs->ParentNode();
164 rhs = rhs->ParentNode();
165 }
166
167 return lhs == nullptr && rhs == nullptr;
168}
169
170TVM_REGISTER_GLOBAL("node.ObjectPathEqual")
171 .set_body_method<ObjectPath>(&ObjectPathNode::PathsEqual);
172
173// --- Repr ---
174
175std::string GetObjectPathRepr(const ObjectPathNode* node) {
176 std::string ret;
177 while (node != nullptr) {
178 std::string node_str = node->LastNodeString();
179 ret.append(node_str.rbegin(), node_str.rend());
180 node = static_cast<const ObjectPathNode*>(node->GetParent().get());
181 }
182 std::reverse(ret.begin(), ret.end());
183 return ret;
184}
185
186static void PrintObjectPathRepr(const ObjectRef& node, ReprPrinter* p) {
187 p->stream << GetObjectPathRepr(static_cast<const ObjectPathNode*>(node.get()));
188}
189
190TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<ObjectPathNode>(PrintObjectPathRepr);
191
192// --- Private/protected methods ---
193
194const ObjectPathNode* ObjectPathNode::ParentNode() const {
195 return static_cast<const ObjectPathNode*>(parent_.get());
196}
197
198// ============== ObjectPath ==============
199
200/* static */ ObjectPath ObjectPath::Root() { return ObjectPath(make_object<RootPathNode>()); }
201
202TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root);
203
204// ============== Individual path classes ==============
205
206// ----- Root -----
207
208RootPathNode::RootPathNode() : ObjectPathNode(nullptr) {}
209
210bool RootPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; }
211
212std::string RootPathNode::LastNodeString() const { return "<root>"; }
213
214TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<RootPathNode>(PrintObjectPathRepr);
215
216// ----- AttributeAccess -----
217
218AttributeAccessPathNode::AttributeAccessPathNode(const ObjectPathNode* parent, String attr_key)
219 : ObjectPathNode(parent), attr_key(std::move(attr_key)) {}
220
221bool AttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const {
222 const auto* otherAttrAccess = static_cast<const AttributeAccessPathNode*>(other);
223 return attr_key == otherAttrAccess->attr_key;
224}
225
226std::string AttributeAccessPathNode::LastNodeString() const { return "." + attr_key; }
227
228TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
229 .set_dispatch<AttributeAccessPathNode>(PrintObjectPathRepr);
230
231// ----- UnknownAttributeAccess -----
232
233UnknownAttributeAccessPathNode::UnknownAttributeAccessPathNode(const ObjectPathNode* parent)
234 : ObjectPathNode(parent) {}
235
236bool UnknownAttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const {
237 // Consider any two unknown attribute accesses unequal
238 return false;
239}
240
241std::string UnknownAttributeAccessPathNode::LastNodeString() const {
242 return ".<unknown attribute>";
243}
244
245TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
246 .set_dispatch<UnknownAttributeAccessPathNode>(PrintObjectPathRepr);
247
248// ----- ArrayIndexPath -----
249
250ArrayIndexPathNode::ArrayIndexPathNode(const ObjectPathNode* parent, int32_t index)
251 : ObjectPathNode(parent), index(index) {}
252
253bool ArrayIndexPathNode::LastNodeEqual(const ObjectPathNode* other) const {
254 const auto* otherArrayIndex = static_cast<const ArrayIndexPathNode*>(other);
255 return index == otherArrayIndex->index;
256}
257
258std::string ArrayIndexPathNode::LastNodeString() const { return "[" + std::to_string(index) + "]"; }
259
260TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<ArrayIndexPathNode>(PrintObjectPathRepr);
261
262// ----- MissingArrayElement -----
263
264MissingArrayElementPathNode::MissingArrayElementPathNode(const ObjectPathNode* parent,
265 int32_t index)
266 : ObjectPathNode(parent), index(index) {}
267
268bool MissingArrayElementPathNode::LastNodeEqual(const ObjectPathNode* other) const {
269 const auto* otherMissingElement = static_cast<const MissingArrayElementPathNode*>(other);
270 return index == otherMissingElement->index;
271}
272
273std::string MissingArrayElementPathNode::LastNodeString() const {
274 return "[<missing element #" + std::to_string(index) + ">]";
275}
276
277TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
278 .set_dispatch<MissingArrayElementPathNode>(PrintObjectPathRepr);
279
280// ----- MapValue -----
281
282MapValuePathNode::MapValuePathNode(const ObjectPathNode* parent, ObjectRef key)
283 : ObjectPathNode(parent), key(std::move(key)) {}
284
285bool MapValuePathNode::LastNodeEqual(const ObjectPathNode* other) const {
286 const auto* otherMapValue = static_cast<const MapValuePathNode*>(other);
287 return ObjectEqual()(key, otherMapValue->key);
288}
289
290std::string MapValuePathNode::LastNodeString() const {
291 std::ostringstream s;
292 s << "[" << key << "]";
293 return s.str();
294}
295
296TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<MapValuePathNode>(PrintObjectPathRepr);
297
298// ----- MissingMapEntry -----
299
300MissingMapEntryPathNode::MissingMapEntryPathNode(const ObjectPathNode* parent)
301 : ObjectPathNode(parent) {}
302
303bool MissingMapEntryPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; }
304
305std::string MissingMapEntryPathNode::LastNodeString() const { return "[<missing entry>]"; }
306
307TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
308 .set_dispatch<MissingMapEntryPathNode>(PrintObjectPathRepr);
309
310} // namespace tvm
311