1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
#include <tvm/node/object_path.h>
#include <tvm/node/repr_printer.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <cstring>
#include <sstream>
#include <string>
#include <utility>
27 | |
28 | using namespace tvm::runtime; |
29 | |
30 | namespace tvm { |
31 | |
32 | // ============== ObjectPathNode ============== |
33 | |
34 | ObjectPathNode::ObjectPathNode(const ObjectPathNode* parent) |
35 | : parent_(GetRef<ObjectRef>(parent)), length_(parent == nullptr ? 1 : parent->length_ + 1) {} |
36 | |
37 | // --- GetParent --- |
38 | |
39 | Optional<ObjectPath> ObjectPathNode::GetParent() const { |
40 | if (parent_ == nullptr) { |
41 | return NullOpt; |
42 | } else { |
43 | return Downcast<ObjectPath>(parent_); |
44 | } |
45 | } |
46 | |
47 | TVM_REGISTER_GLOBAL("node.ObjectPathGetParent" ) |
48 | .set_body_method<ObjectPath>(&ObjectPathNode::GetParent); |
49 | |
50 | // --- Length --- |
51 | |
52 | int32_t ObjectPathNode::Length() const { return length_; } |
53 | |
54 | TVM_REGISTER_GLOBAL("node.ObjectPathLength" ).set_body_method<ObjectPath>(&ObjectPathNode::Length); |
55 | |
56 | // --- GetPrefix --- |
57 | |
58 | ObjectPath ObjectPathNode::GetPrefix(int32_t length) const { |
59 | CHECK_GE(length, 1) << "IndexError: Prefix length must be at least 1" ; |
60 | CHECK_LE(length, Length()) << "IndexError: Attempted to get a prefix longer than the path itself" ; |
61 | |
62 | const ObjectPathNode* node = this; |
63 | int32_t suffix_len = Length() - length; |
64 | for (int32_t i = 0; i < suffix_len; ++i) { |
65 | node = node->ParentNode(); |
66 | } |
67 | |
68 | return GetRef<ObjectPath>(node); |
69 | } |
70 | |
71 | TVM_REGISTER_GLOBAL("node.ObjectPathGetPrefix" ) |
72 | .set_body_method<ObjectPath>(&ObjectPathNode::GetPrefix); |
73 | |
74 | // --- IsPrefixOf --- |
75 | |
76 | bool ObjectPathNode::IsPrefixOf(const ObjectPath& other) const { |
77 | int32_t this_len = Length(); |
78 | if (this_len > other->Length()) { |
79 | return false; |
80 | } |
81 | return this->PathsEqual(other->GetPrefix(this_len)); |
82 | } |
83 | |
84 | TVM_REGISTER_GLOBAL("node.ObjectPathIsPrefixOf" ) |
85 | .set_body_method<ObjectPath>(&ObjectPathNode::IsPrefixOf); |
86 | |
87 | // --- Attr --- |
88 | |
89 | ObjectPath ObjectPathNode::Attr(const char* attr_key) const { |
90 | if (attr_key != nullptr) { |
91 | return ObjectPath(make_object<AttributeAccessPathNode>(this, attr_key)); |
92 | } else { |
93 | return ObjectPath(make_object<UnknownAttributeAccessPathNode>(this)); |
94 | } |
95 | } |
96 | |
97 | ObjectPath ObjectPathNode::Attr(Optional<String> attr_key) const { |
98 | if (attr_key.defined()) { |
99 | return ObjectPath(make_object<AttributeAccessPathNode>(this, attr_key.value())); |
100 | } else { |
101 | return ObjectPath(make_object<UnknownAttributeAccessPathNode>(this)); |
102 | } |
103 | } |
104 | |
105 | TVM_REGISTER_GLOBAL("node.ObjectPathAttr" ) |
106 | .set_body_typed([](const ObjectPath& object_path, Optional<String> attr_key) { |
107 | return object_path->Attr(attr_key); |
108 | }); |
109 | |
110 | // --- ArrayIndex --- |
111 | |
112 | ObjectPath ObjectPathNode::ArrayIndex(int32_t index) const { |
113 | return ObjectPath(make_object<ArrayIndexPathNode>(this, index)); |
114 | } |
115 | |
116 | TVM_REGISTER_GLOBAL("node.ObjectPathArrayIndex" ) |
117 | .set_body_method<ObjectPath>(&ObjectPathNode::ArrayIndex); |
118 | |
119 | // --- MissingArrayElement --- |
120 | |
121 | ObjectPath ObjectPathNode::MissingArrayElement(int32_t index) const { |
122 | return ObjectPath(make_object<MissingArrayElementPathNode>(this, index)); |
123 | } |
124 | |
125 | TVM_REGISTER_GLOBAL("node.ObjectPathMissingArrayElement" ) |
126 | .set_body_method<ObjectPath>(&ObjectPathNode::MissingArrayElement); |
127 | |
128 | // --- MapValue --- |
129 | |
130 | ObjectPath ObjectPathNode::MapValue(ObjectRef key) const { |
131 | return ObjectPath(make_object<MapValuePathNode>(this, std::move(key))); |
132 | } |
133 | |
134 | TVM_REGISTER_GLOBAL("node.ObjectPathMapValue" ) |
135 | .set_body_method<ObjectPath>(&ObjectPathNode::MapValue); |
136 | |
137 | // --- MissingMapEntry --- |
138 | |
139 | ObjectPath ObjectPathNode::MissingMapEntry() const { |
140 | return ObjectPath(make_object<MissingMapEntryPathNode>(this)); |
141 | } |
142 | |
143 | TVM_REGISTER_GLOBAL("node.ObjectPathMissingMapEntry" ) |
144 | .set_body_method<ObjectPath>(&ObjectPathNode::MissingMapEntry); |
145 | |
146 | // --- PathsEqual ---- |
147 | |
148 | bool ObjectPathNode::PathsEqual(const ObjectPath& other) const { |
149 | if (!other.defined() || Length() != other->Length()) { |
150 | return false; |
151 | } |
152 | |
153 | const ObjectPathNode* lhs = this; |
154 | const ObjectPathNode* rhs = static_cast<const ObjectPathNode*>(other.get()); |
155 | |
156 | while (lhs != nullptr && rhs != nullptr) { |
157 | if (lhs->type_index() != rhs->type_index()) { |
158 | return false; |
159 | } |
160 | if (!lhs->LastNodeEqual(rhs)) { |
161 | return false; |
162 | } |
163 | lhs = lhs->ParentNode(); |
164 | rhs = rhs->ParentNode(); |
165 | } |
166 | |
167 | return lhs == nullptr && rhs == nullptr; |
168 | } |
169 | |
170 | TVM_REGISTER_GLOBAL("node.ObjectPathEqual" ) |
171 | .set_body_method<ObjectPath>(&ObjectPathNode::PathsEqual); |
172 | |
173 | // --- Repr --- |
174 | |
175 | std::string GetObjectPathRepr(const ObjectPathNode* node) { |
176 | std::string ret; |
177 | while (node != nullptr) { |
178 | std::string node_str = node->LastNodeString(); |
179 | ret.append(node_str.rbegin(), node_str.rend()); |
180 | node = static_cast<const ObjectPathNode*>(node->GetParent().get()); |
181 | } |
182 | std::reverse(ret.begin(), ret.end()); |
183 | return ret; |
184 | } |
185 | |
186 | static void PrintObjectPathRepr(const ObjectRef& node, ReprPrinter* p) { |
187 | p->stream << GetObjectPathRepr(static_cast<const ObjectPathNode*>(node.get())); |
188 | } |
189 | |
190 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<ObjectPathNode>(PrintObjectPathRepr); |
191 | |
192 | // --- Private/protected methods --- |
193 | |
194 | const ObjectPathNode* ObjectPathNode::ParentNode() const { |
195 | return static_cast<const ObjectPathNode*>(parent_.get()); |
196 | } |
197 | |
198 | // ============== ObjectPath ============== |
199 | |
200 | /* static */ ObjectPath ObjectPath::Root() { return ObjectPath(make_object<RootPathNode>()); } |
201 | |
202 | TVM_REGISTER_GLOBAL("node.ObjectPathRoot" ).set_body_typed(ObjectPath::Root); |
203 | |
204 | // ============== Individual path classes ============== |
205 | |
206 | // ----- Root ----- |
207 | |
208 | RootPathNode::RootPathNode() : ObjectPathNode(nullptr) {} |
209 | |
210 | bool RootPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; } |
211 | |
212 | std::string RootPathNode::LastNodeString() const { return "<root>" ; } |
213 | |
214 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<RootPathNode>(PrintObjectPathRepr); |
215 | |
216 | // ----- AttributeAccess ----- |
217 | |
218 | AttributeAccessPathNode::AttributeAccessPathNode(const ObjectPathNode* parent, String attr_key) |
219 | : ObjectPathNode(parent), attr_key(std::move(attr_key)) {} |
220 | |
221 | bool AttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const { |
222 | const auto* otherAttrAccess = static_cast<const AttributeAccessPathNode*>(other); |
223 | return attr_key == otherAttrAccess->attr_key; |
224 | } |
225 | |
226 | std::string AttributeAccessPathNode::LastNodeString() const { return "." + attr_key; } |
227 | |
228 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
229 | .set_dispatch<AttributeAccessPathNode>(PrintObjectPathRepr); |
230 | |
231 | // ----- UnknownAttributeAccess ----- |
232 | |
233 | UnknownAttributeAccessPathNode::UnknownAttributeAccessPathNode(const ObjectPathNode* parent) |
234 | : ObjectPathNode(parent) {} |
235 | |
236 | bool UnknownAttributeAccessPathNode::LastNodeEqual(const ObjectPathNode* other) const { |
237 | // Consider any two unknown attribute accesses unequal |
238 | return false; |
239 | } |
240 | |
241 | std::string UnknownAttributeAccessPathNode::LastNodeString() const { |
242 | return ".<unknown attribute>" ; |
243 | } |
244 | |
245 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
246 | .set_dispatch<UnknownAttributeAccessPathNode>(PrintObjectPathRepr); |
247 | |
248 | // ----- ArrayIndexPath ----- |
249 | |
250 | ArrayIndexPathNode::ArrayIndexPathNode(const ObjectPathNode* parent, int32_t index) |
251 | : ObjectPathNode(parent), index(index) {} |
252 | |
253 | bool ArrayIndexPathNode::LastNodeEqual(const ObjectPathNode* other) const { |
254 | const auto* otherArrayIndex = static_cast<const ArrayIndexPathNode*>(other); |
255 | return index == otherArrayIndex->index; |
256 | } |
257 | |
258 | std::string ArrayIndexPathNode::LastNodeString() const { return "[" + std::to_string(index) + "]" ; } |
259 | |
260 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<ArrayIndexPathNode>(PrintObjectPathRepr); |
261 | |
262 | // ----- MissingArrayElement ----- |
263 | |
264 | MissingArrayElementPathNode::MissingArrayElementPathNode(const ObjectPathNode* parent, |
265 | int32_t index) |
266 | : ObjectPathNode(parent), index(index) {} |
267 | |
268 | bool MissingArrayElementPathNode::LastNodeEqual(const ObjectPathNode* other) const { |
269 | const auto* otherMissingElement = static_cast<const MissingArrayElementPathNode*>(other); |
270 | return index == otherMissingElement->index; |
271 | } |
272 | |
273 | std::string MissingArrayElementPathNode::LastNodeString() const { |
274 | return "[<missing element #" + std::to_string(index) + ">]" ; |
275 | } |
276 | |
277 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
278 | .set_dispatch<MissingArrayElementPathNode>(PrintObjectPathRepr); |
279 | |
280 | // ----- MapValue ----- |
281 | |
282 | MapValuePathNode::MapValuePathNode(const ObjectPathNode* parent, ObjectRef key) |
283 | : ObjectPathNode(parent), key(std::move(key)) {} |
284 | |
285 | bool MapValuePathNode::LastNodeEqual(const ObjectPathNode* other) const { |
286 | const auto* otherMapValue = static_cast<const MapValuePathNode*>(other); |
287 | return ObjectEqual()(key, otherMapValue->key); |
288 | } |
289 | |
290 | std::string MapValuePathNode::LastNodeString() const { |
291 | std::ostringstream s; |
292 | s << "[" << key << "]" ; |
293 | return s.str(); |
294 | } |
295 | |
296 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<MapValuePathNode>(PrintObjectPathRepr); |
297 | |
298 | // ----- MissingMapEntry ----- |
299 | |
300 | MissingMapEntryPathNode::MissingMapEntryPathNode(const ObjectPathNode* parent) |
301 | : ObjectPathNode(parent) {} |
302 | |
303 | bool MissingMapEntryPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; } |
304 | |
305 | std::string MissingMapEntryPathNode::LastNodeString() const { return "[<missing entry>]" ; } |
306 | |
307 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
308 | .set_dispatch<MissingMapEntryPathNode>(PrintObjectPathRepr); |
309 | |
310 | } // namespace tvm |
311 | |