1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * Reflection utilities. |
22 | * \file node/reflection.cc |
23 | */ |
24 | #include <tvm/ir/attrs.h> |
25 | #include <tvm/node/node.h> |
26 | #include <tvm/node/reflection.h> |
27 | #include <tvm/runtime/registry.h> |
28 | |
29 | namespace tvm { |
30 | |
31 | using runtime::PackedFunc; |
32 | using runtime::TVMArgs; |
33 | using runtime::TVMRetValue; |
34 | |
35 | // Attr getter. |
36 | class AttrGetter : public AttrVisitor { |
37 | public: |
38 | const String& skey; |
39 | TVMRetValue* ret; |
40 | |
41 | AttrGetter(const String& skey, TVMRetValue* ret) : skey(skey), ret(ret) {} |
42 | |
43 | bool found_ref_object{false}; |
44 | |
45 | void Visit(const char* key, double* value) final { |
46 | if (skey == key) *ret = value[0]; |
47 | } |
48 | void Visit(const char* key, int64_t* value) final { |
49 | if (skey == key) *ret = value[0]; |
50 | } |
51 | void Visit(const char* key, uint64_t* value) final { |
52 | ICHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) |
53 | << "cannot return too big constant" ; |
54 | if (skey == key) *ret = static_cast<int64_t>(value[0]); |
55 | } |
56 | void Visit(const char* key, int* value) final { |
57 | if (skey == key) *ret = static_cast<int64_t>(value[0]); |
58 | } |
59 | void Visit(const char* key, bool* value) final { |
60 | if (skey == key) *ret = static_cast<int64_t>(value[0]); |
61 | } |
62 | void Visit(const char* key, void** value) final { |
63 | if (skey == key) *ret = static_cast<void*>(value[0]); |
64 | } |
65 | void Visit(const char* key, DataType* value) final { |
66 | if (skey == key) *ret = value[0]; |
67 | } |
68 | void Visit(const char* key, std::string* value) final { |
69 | if (skey == key) *ret = value[0]; |
70 | } |
71 | |
72 | void Visit(const char* key, runtime::NDArray* value) final { |
73 | if (skey == key) { |
74 | *ret = value[0]; |
75 | found_ref_object = true; |
76 | } |
77 | } |
78 | void Visit(const char* key, runtime::ObjectRef* value) final { |
79 | if (skey == key) { |
80 | *ret = value[0]; |
81 | found_ref_object = true; |
82 | } |
83 | } |
84 | }; |
85 | |
86 | runtime::TVMRetValue ReflectionVTable::GetAttr(Object* self, const String& field_name) const { |
87 | runtime::TVMRetValue ret; |
88 | AttrGetter getter(field_name, &ret); |
89 | |
90 | bool success; |
91 | if (getter.skey == "type_key" ) { |
92 | ret = self->GetTypeKey(); |
93 | success = true; |
94 | } else if (!self->IsInstance<DictAttrsNode>()) { |
95 | VisitAttrs(self, &getter); |
96 | success = getter.found_ref_object || ret.type_code() != kTVMNullptr; |
97 | } else { |
98 | // specially handle dict attr |
99 | DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self); |
100 | auto it = dnode->dict.find(getter.skey); |
101 | if (it != dnode->dict.end()) { |
102 | success = true; |
103 | ret = (*it).second; |
104 | } else { |
105 | success = false; |
106 | } |
107 | } |
108 | if (!success) { |
109 | LOG(FATAL) << "AttributeError: " << self->GetTypeKey() << " object has no attributed " |
110 | << getter.skey; |
111 | } |
112 | return ret; |
113 | } |
114 | |
115 | // List names; |
116 | class AttrDir : public AttrVisitor { |
117 | public: |
118 | std::vector<std::string>* names; |
119 | |
120 | void Visit(const char* key, double* value) final { names->push_back(key); } |
121 | void Visit(const char* key, int64_t* value) final { names->push_back(key); } |
122 | void Visit(const char* key, uint64_t* value) final { names->push_back(key); } |
123 | void Visit(const char* key, bool* value) final { names->push_back(key); } |
124 | void Visit(const char* key, int* value) final { names->push_back(key); } |
125 | void Visit(const char* key, void** value) final { names->push_back(key); } |
126 | void Visit(const char* key, DataType* value) final { names->push_back(key); } |
127 | void Visit(const char* key, std::string* value) final { names->push_back(key); } |
128 | void Visit(const char* key, runtime::NDArray* value) final { names->push_back(key); } |
129 | void Visit(const char* key, runtime::ObjectRef* value) final { names->push_back(key); } |
130 | }; |
131 | |
132 | std::vector<std::string> ReflectionVTable::ListAttrNames(Object* self) const { |
133 | std::vector<std::string> names; |
134 | AttrDir dir; |
135 | dir.names = &names; |
136 | |
137 | if (!self->IsInstance<DictAttrsNode>()) { |
138 | VisitAttrs(self, &dir); |
139 | } else { |
140 | // specially handle dict attr |
141 | DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self); |
142 | for (const auto& kv : dnode->dict) { |
143 | names.push_back(kv.first); |
144 | } |
145 | } |
146 | return names; |
147 | } |
148 | |
149 | ReflectionVTable* ReflectionVTable::Global() { |
150 | static ReflectionVTable inst; |
151 | return &inst; |
152 | } |
153 | |
154 | ObjectPtr<Object> ReflectionVTable::CreateInitObject(const std::string& type_key, |
155 | const std::string& repr_bytes) const { |
156 | uint32_t tindex = Object::TypeKey2Index(type_key); |
157 | if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) { |
158 | LOG(FATAL) << "TypeError: " << type_key << " is not registered via TVM_REGISTER_NODE_TYPE" ; |
159 | } |
160 | return fcreate_[tindex](repr_bytes); |
161 | } |
162 | |
163 | class NodeAttrSetter : public AttrVisitor { |
164 | public: |
165 | std::string type_key; |
166 | std::unordered_map<std::string, runtime::TVMArgValue> attrs; |
167 | |
168 | void Visit(const char* key, double* value) final { *value = GetAttr(key).operator double(); } |
169 | void Visit(const char* key, int64_t* value) final { *value = GetAttr(key).operator int64_t(); } |
170 | void Visit(const char* key, uint64_t* value) final { *value = GetAttr(key).operator uint64_t(); } |
171 | void Visit(const char* key, int* value) final { *value = GetAttr(key).operator int(); } |
172 | void Visit(const char* key, bool* value) final { *value = GetAttr(key).operator bool(); } |
173 | void Visit(const char* key, std::string* value) final { |
174 | *value = GetAttr(key).operator std::string(); |
175 | } |
176 | void Visit(const char* key, void** value) final { *value = GetAttr(key).operator void*(); } |
177 | void Visit(const char* key, DataType* value) final { *value = GetAttr(key).operator DataType(); } |
178 | void Visit(const char* key, runtime::NDArray* value) final { |
179 | *value = GetAttr(key).operator runtime::NDArray(); |
180 | } |
181 | void Visit(const char* key, ObjectRef* value) final { |
182 | *value = GetAttr(key).operator ObjectRef(); |
183 | } |
184 | |
185 | private: |
186 | runtime::TVMArgValue GetAttr(const char* key) { |
187 | auto it = attrs.find(key); |
188 | if (it == attrs.end()) { |
189 | LOG(FATAL) << type_key << ": require field " << key; |
190 | } |
191 | runtime::TVMArgValue v = it->second; |
192 | attrs.erase(it); |
193 | return v; |
194 | } |
195 | }; |
196 | |
197 | void InitNodeByPackedArgs(ReflectionVTable* reflection, Object* n, const TVMArgs& args) { |
198 | NodeAttrSetter setter; |
199 | setter.type_key = n->GetTypeKey(); |
200 | ICHECK_EQ(args.size() % 2, 0); |
201 | for (int i = 0; i < args.size(); i += 2) { |
202 | setter.attrs.emplace(args[i].operator std::string(), args[i + 1]); |
203 | } |
204 | reflection->VisitAttrs(n, &setter); |
205 | |
206 | if (setter.attrs.size() != 0) { |
207 | std::ostringstream os; |
208 | os << setter.type_key << " does not contain field " ; |
209 | for (const auto& kv : setter.attrs) { |
210 | os << " " << kv.first; |
211 | } |
212 | LOG(FATAL) << os.str(); |
213 | } |
214 | } |
215 | |
216 | ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, const TVMArgs& kwargs) { |
217 | ObjectPtr<Object> n = this->CreateInitObject(type_key); |
218 | if (n->IsInstance<BaseAttrsNode>()) { |
219 | static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs); |
220 | } else { |
221 | InitNodeByPackedArgs(this, n.get(), kwargs); |
222 | } |
223 | return ObjectRef(n); |
224 | } |
225 | |
226 | ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, |
227 | const Map<String, ObjectRef>& kwargs) { |
228 | // Redirect to the TVMArgs version |
229 | // It is not the most efficient way, but CreateObject is not meant to be used |
230 | // in a fast code-path and is mainly reserved as a flexible API for frontends. |
231 | std::vector<TVMValue> values(kwargs.size() * 2); |
232 | std::vector<int32_t> tcodes(kwargs.size() * 2); |
233 | runtime::TVMArgsSetter setter(values.data(), tcodes.data()); |
234 | int index = 0; |
235 | |
236 | for (const auto& kv : *static_cast<const MapNode*>(kwargs.get())) { |
237 | setter(index, Downcast<String>(kv.first).c_str()); |
238 | setter(index + 1, kv.second); |
239 | index += 2; |
240 | } |
241 | |
242 | return CreateObject(type_key, runtime::TVMArgs(values.data(), tcodes.data(), kwargs.size() * 2)); |
243 | } |
244 | |
245 | // Expose to FFI APIs. |
246 | void NodeGetAttr(TVMArgs args, TVMRetValue* ret) { |
247 | ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); |
248 | Object* self = static_cast<Object*>(args[0].value().v_handle); |
249 | *ret = ReflectionVTable::Global()->GetAttr(self, args[1]); |
250 | } |
251 | |
252 | void NodeListAttrNames(TVMArgs args, TVMRetValue* ret) { |
253 | ICHECK_EQ(args[0].type_code(), kTVMObjectHandle); |
254 | Object* self = static_cast<Object*>(args[0].value().v_handle); |
255 | |
256 | auto names = |
257 | std::make_shared<std::vector<std::string>>(ReflectionVTable::Global()->ListAttrNames(self)); |
258 | |
259 | *ret = PackedFunc([names](TVMArgs args, TVMRetValue* rv) { |
260 | int64_t i = args[0]; |
261 | if (i == -1) { |
262 | *rv = static_cast<int64_t>(names->size()); |
263 | } else { |
264 | *rv = (*names)[i]; |
265 | } |
266 | }); |
267 | } |
268 | |
269 | // API function to make node. |
270 | // args format: |
271 | // key1, value1, ..., key_n, value_n |
272 | void MakeNode(const TVMArgs& args, TVMRetValue* rv) { |
273 | std::string type_key = args[0]; |
274 | std::string empty_str; |
275 | TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1); |
276 | *rv = ReflectionVTable::Global()->CreateObject(type_key, kwargs); |
277 | } |
278 | |
279 | TVM_REGISTER_GLOBAL("node.NodeGetAttr" ).set_body(NodeGetAttr); |
280 | |
281 | TVM_REGISTER_GLOBAL("node.NodeListAttrNames" ).set_body(NodeListAttrNames); |
282 | |
283 | TVM_REGISTER_GLOBAL("node.MakeNode" ).set_body(MakeNode); |
284 | |
285 | namespace { |
286 | // Attribute visitor class for finding the attribute key by its address |
287 | class GetAttrKeyByAddressVisitor : public AttrVisitor { |
288 | public: |
289 | explicit GetAttrKeyByAddressVisitor(const void* attr_address) |
290 | : attr_address_(attr_address), key_(nullptr) {} |
291 | |
292 | void Visit(const char* key, double* value) final { DoVisit(key, value); } |
293 | void Visit(const char* key, int64_t* value) final { DoVisit(key, value); } |
294 | void Visit(const char* key, uint64_t* value) final { DoVisit(key, value); } |
295 | void Visit(const char* key, int* value) final { DoVisit(key, value); } |
296 | void Visit(const char* key, bool* value) final { DoVisit(key, value); } |
297 | void Visit(const char* key, std::string* value) final { DoVisit(key, value); } |
298 | void Visit(const char* key, void** value) final { DoVisit(key, value); } |
299 | void Visit(const char* key, DataType* value) final { DoVisit(key, value); } |
300 | void Visit(const char* key, runtime::NDArray* value) final { DoVisit(key, value); } |
301 | void Visit(const char* key, runtime::ObjectRef* value) final { DoVisit(key, value); } |
302 | |
303 | const char* GetKey() const { return key_; } |
304 | |
305 | private: |
306 | const void* attr_address_; |
307 | const char* key_; |
308 | |
309 | void DoVisit(const char* key, const void* candidate) { |
310 | if (attr_address_ == candidate) { |
311 | key_ = key; |
312 | } |
313 | } |
314 | }; |
315 | } // anonymous namespace |
316 | |
317 | Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address) { |
318 | GetAttrKeyByAddressVisitor visitor(attr_address); |
319 | ReflectionVTable::Global()->VisitAttrs(const_cast<Object*>(object), &visitor); |
320 | const char* key = visitor.GetKey(); |
321 | if (key == nullptr) { |
322 | return NullOpt; |
323 | } else { |
324 | return String(key); |
325 | } |
326 | } |
327 | |
328 | } // namespace tvm |
329 | |