1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * Reflection utilities.
22 * \file node/reflection.cc
23 */
24#include <tvm/ir/attrs.h>
25#include <tvm/node/node.h>
26#include <tvm/node/reflection.h>
27#include <tvm/runtime/registry.h>
28
29namespace tvm {
30
31using runtime::PackedFunc;
32using runtime::TVMArgs;
33using runtime::TVMRetValue;
34
35// Attr getter.
36class AttrGetter : public AttrVisitor {
37 public:
38 const String& skey;
39 TVMRetValue* ret;
40
41 AttrGetter(const String& skey, TVMRetValue* ret) : skey(skey), ret(ret) {}
42
43 bool found_ref_object{false};
44
45 void Visit(const char* key, double* value) final {
46 if (skey == key) *ret = value[0];
47 }
48 void Visit(const char* key, int64_t* value) final {
49 if (skey == key) *ret = value[0];
50 }
51 void Visit(const char* key, uint64_t* value) final {
52 ICHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
53 << "cannot return too big constant";
54 if (skey == key) *ret = static_cast<int64_t>(value[0]);
55 }
56 void Visit(const char* key, int* value) final {
57 if (skey == key) *ret = static_cast<int64_t>(value[0]);
58 }
59 void Visit(const char* key, bool* value) final {
60 if (skey == key) *ret = static_cast<int64_t>(value[0]);
61 }
62 void Visit(const char* key, void** value) final {
63 if (skey == key) *ret = static_cast<void*>(value[0]);
64 }
65 void Visit(const char* key, DataType* value) final {
66 if (skey == key) *ret = value[0];
67 }
68 void Visit(const char* key, std::string* value) final {
69 if (skey == key) *ret = value[0];
70 }
71
72 void Visit(const char* key, runtime::NDArray* value) final {
73 if (skey == key) {
74 *ret = value[0];
75 found_ref_object = true;
76 }
77 }
78 void Visit(const char* key, runtime::ObjectRef* value) final {
79 if (skey == key) {
80 *ret = value[0];
81 found_ref_object = true;
82 }
83 }
84};
85
86runtime::TVMRetValue ReflectionVTable::GetAttr(Object* self, const String& field_name) const {
87 runtime::TVMRetValue ret;
88 AttrGetter getter(field_name, &ret);
89
90 bool success;
91 if (getter.skey == "type_key") {
92 ret = self->GetTypeKey();
93 success = true;
94 } else if (!self->IsInstance<DictAttrsNode>()) {
95 VisitAttrs(self, &getter);
96 success = getter.found_ref_object || ret.type_code() != kTVMNullptr;
97 } else {
98 // specially handle dict attr
99 DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self);
100 auto it = dnode->dict.find(getter.skey);
101 if (it != dnode->dict.end()) {
102 success = true;
103 ret = (*it).second;
104 } else {
105 success = false;
106 }
107 }
108 if (!success) {
109 LOG(FATAL) << "AttributeError: " << self->GetTypeKey() << " object has no attributed "
110 << getter.skey;
111 }
112 return ret;
113}
114
115// List names;
116class AttrDir : public AttrVisitor {
117 public:
118 std::vector<std::string>* names;
119
120 void Visit(const char* key, double* value) final { names->push_back(key); }
121 void Visit(const char* key, int64_t* value) final { names->push_back(key); }
122 void Visit(const char* key, uint64_t* value) final { names->push_back(key); }
123 void Visit(const char* key, bool* value) final { names->push_back(key); }
124 void Visit(const char* key, int* value) final { names->push_back(key); }
125 void Visit(const char* key, void** value) final { names->push_back(key); }
126 void Visit(const char* key, DataType* value) final { names->push_back(key); }
127 void Visit(const char* key, std::string* value) final { names->push_back(key); }
128 void Visit(const char* key, runtime::NDArray* value) final { names->push_back(key); }
129 void Visit(const char* key, runtime::ObjectRef* value) final { names->push_back(key); }
130};
131
132std::vector<std::string> ReflectionVTable::ListAttrNames(Object* self) const {
133 std::vector<std::string> names;
134 AttrDir dir;
135 dir.names = &names;
136
137 if (!self->IsInstance<DictAttrsNode>()) {
138 VisitAttrs(self, &dir);
139 } else {
140 // specially handle dict attr
141 DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self);
142 for (const auto& kv : dnode->dict) {
143 names.push_back(kv.first);
144 }
145 }
146 return names;
147}
148
149ReflectionVTable* ReflectionVTable::Global() {
150 static ReflectionVTable inst;
151 return &inst;
152}
153
154ObjectPtr<Object> ReflectionVTable::CreateInitObject(const std::string& type_key,
155 const std::string& repr_bytes) const {
156 uint32_t tindex = Object::TypeKey2Index(type_key);
157 if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) {
158 LOG(FATAL) << "TypeError: " << type_key << " is not registered via TVM_REGISTER_NODE_TYPE";
159 }
160 return fcreate_[tindex](repr_bytes);
161}
162
163class NodeAttrSetter : public AttrVisitor {
164 public:
165 std::string type_key;
166 std::unordered_map<std::string, runtime::TVMArgValue> attrs;
167
168 void Visit(const char* key, double* value) final { *value = GetAttr(key).operator double(); }
169 void Visit(const char* key, int64_t* value) final { *value = GetAttr(key).operator int64_t(); }
170 void Visit(const char* key, uint64_t* value) final { *value = GetAttr(key).operator uint64_t(); }
171 void Visit(const char* key, int* value) final { *value = GetAttr(key).operator int(); }
172 void Visit(const char* key, bool* value) final { *value = GetAttr(key).operator bool(); }
173 void Visit(const char* key, std::string* value) final {
174 *value = GetAttr(key).operator std::string();
175 }
176 void Visit(const char* key, void** value) final { *value = GetAttr(key).operator void*(); }
177 void Visit(const char* key, DataType* value) final { *value = GetAttr(key).operator DataType(); }
178 void Visit(const char* key, runtime::NDArray* value) final {
179 *value = GetAttr(key).operator runtime::NDArray();
180 }
181 void Visit(const char* key, ObjectRef* value) final {
182 *value = GetAttr(key).operator ObjectRef();
183 }
184
185 private:
186 runtime::TVMArgValue GetAttr(const char* key) {
187 auto it = attrs.find(key);
188 if (it == attrs.end()) {
189 LOG(FATAL) << type_key << ": require field " << key;
190 }
191 runtime::TVMArgValue v = it->second;
192 attrs.erase(it);
193 return v;
194 }
195};
196
197void InitNodeByPackedArgs(ReflectionVTable* reflection, Object* n, const TVMArgs& args) {
198 NodeAttrSetter setter;
199 setter.type_key = n->GetTypeKey();
200 ICHECK_EQ(args.size() % 2, 0);
201 for (int i = 0; i < args.size(); i += 2) {
202 setter.attrs.emplace(args[i].operator std::string(), args[i + 1]);
203 }
204 reflection->VisitAttrs(n, &setter);
205
206 if (setter.attrs.size() != 0) {
207 std::ostringstream os;
208 os << setter.type_key << " does not contain field ";
209 for (const auto& kv : setter.attrs) {
210 os << " " << kv.first;
211 }
212 LOG(FATAL) << os.str();
213 }
214}
215
216ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, const TVMArgs& kwargs) {
217 ObjectPtr<Object> n = this->CreateInitObject(type_key);
218 if (n->IsInstance<BaseAttrsNode>()) {
219 static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs);
220 } else {
221 InitNodeByPackedArgs(this, n.get(), kwargs);
222 }
223 return ObjectRef(n);
224}
225
226ObjectRef ReflectionVTable::CreateObject(const std::string& type_key,
227 const Map<String, ObjectRef>& kwargs) {
228 // Redirect to the TVMArgs version
229 // It is not the most efficient way, but CreateObject is not meant to be used
230 // in a fast code-path and is mainly reserved as a flexible API for frontends.
231 std::vector<TVMValue> values(kwargs.size() * 2);
232 std::vector<int32_t> tcodes(kwargs.size() * 2);
233 runtime::TVMArgsSetter setter(values.data(), tcodes.data());
234 int index = 0;
235
236 for (const auto& kv : *static_cast<const MapNode*>(kwargs.get())) {
237 setter(index, Downcast<String>(kv.first).c_str());
238 setter(index + 1, kv.second);
239 index += 2;
240 }
241
242 return CreateObject(type_key, runtime::TVMArgs(values.data(), tcodes.data(), kwargs.size() * 2));
243}
244
245// Expose to FFI APIs.
246void NodeGetAttr(TVMArgs args, TVMRetValue* ret) {
247 ICHECK_EQ(args[0].type_code(), kTVMObjectHandle);
248 Object* self = static_cast<Object*>(args[0].value().v_handle);
249 *ret = ReflectionVTable::Global()->GetAttr(self, args[1]);
250}
251
252void NodeListAttrNames(TVMArgs args, TVMRetValue* ret) {
253 ICHECK_EQ(args[0].type_code(), kTVMObjectHandle);
254 Object* self = static_cast<Object*>(args[0].value().v_handle);
255
256 auto names =
257 std::make_shared<std::vector<std::string>>(ReflectionVTable::Global()->ListAttrNames(self));
258
259 *ret = PackedFunc([names](TVMArgs args, TVMRetValue* rv) {
260 int64_t i = args[0];
261 if (i == -1) {
262 *rv = static_cast<int64_t>(names->size());
263 } else {
264 *rv = (*names)[i];
265 }
266 });
267}
268
269// API function to make node.
270// args format:
271// key1, value1, ..., key_n, value_n
272void MakeNode(const TVMArgs& args, TVMRetValue* rv) {
273 std::string type_key = args[0];
274 std::string empty_str;
275 TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1);
276 *rv = ReflectionVTable::Global()->CreateObject(type_key, kwargs);
277}
278
// Register the FFI helpers above as global functions under the "node." prefix,
// so callers can look them up by name.
TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr);

TVM_REGISTER_GLOBAL("node.NodeListAttrNames").set_body(NodeListAttrNames);

TVM_REGISTER_GLOBAL("node.MakeNode").set_body(MakeNode);
284
285namespace {
286// Attribute visitor class for finding the attribute key by its address
287class GetAttrKeyByAddressVisitor : public AttrVisitor {
288 public:
289 explicit GetAttrKeyByAddressVisitor(const void* attr_address)
290 : attr_address_(attr_address), key_(nullptr) {}
291
292 void Visit(const char* key, double* value) final { DoVisit(key, value); }
293 void Visit(const char* key, int64_t* value) final { DoVisit(key, value); }
294 void Visit(const char* key, uint64_t* value) final { DoVisit(key, value); }
295 void Visit(const char* key, int* value) final { DoVisit(key, value); }
296 void Visit(const char* key, bool* value) final { DoVisit(key, value); }
297 void Visit(const char* key, std::string* value) final { DoVisit(key, value); }
298 void Visit(const char* key, void** value) final { DoVisit(key, value); }
299 void Visit(const char* key, DataType* value) final { DoVisit(key, value); }
300 void Visit(const char* key, runtime::NDArray* value) final { DoVisit(key, value); }
301 void Visit(const char* key, runtime::ObjectRef* value) final { DoVisit(key, value); }
302
303 const char* GetKey() const { return key_; }
304
305 private:
306 const void* attr_address_;
307 const char* key_;
308
309 void DoVisit(const char* key, const void* candidate) {
310 if (attr_address_ == candidate) {
311 key_ = key;
312 }
313 }
314};
315} // anonymous namespace
316
317Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address) {
318 GetAttrKeyByAddressVisitor visitor(attr_address);
319 ReflectionVTable::Global()->VisitAttrs(const_cast<Object*>(object), &visitor);
320 const char* key = visitor.GetKey();
321 if (key == nullptr) {
322 return NullOpt;
323 } else {
324 return String(key);
325 }
326}
327
328} // namespace tvm
329