1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * \file tvm/node/reflection.h |
21 | * \brief Reflection and serialization of compiler IR/AST nodes. |
22 | */ |
23 | #ifndef TVM_NODE_REFLECTION_H_ |
24 | #define TVM_NODE_REFLECTION_H_ |
25 | |
26 | #include <tvm/node/structural_equal.h> |
27 | #include <tvm/node/structural_hash.h> |
28 | #include <tvm/runtime/c_runtime_api.h> |
29 | #include <tvm/runtime/data_type.h> |
30 | #include <tvm/runtime/memory.h> |
31 | #include <tvm/runtime/ndarray.h> |
32 | #include <tvm/runtime/object.h> |
33 | #include <tvm/runtime/packed_func.h> |
34 | |
35 | #include <string> |
36 | #include <type_traits> |
37 | #include <vector> |
38 | |
39 | namespace tvm { |
40 | |
41 | using runtime::Object; |
42 | using runtime::ObjectPtr; |
43 | using runtime::ObjectRef; |
44 | |
45 | /*! |
46 | * \brief Visitor class to get the attributes of an AST/IR node. |
47 | * The content is going to be called for each field. |
48 | * |
49 | * Each objects that wants reflection will need to implement |
50 | * a VisitAttrs function and call visitor->Visit on each of its field. |
51 | */ |
52 | class AttrVisitor { |
53 | public: |
54 | //! \cond Doxygen_Suppress |
55 | TVM_DLL virtual ~AttrVisitor() = default; |
56 | TVM_DLL virtual void Visit(const char* key, double* value) = 0; |
57 | TVM_DLL virtual void Visit(const char* key, int64_t* value) = 0; |
58 | TVM_DLL virtual void Visit(const char* key, uint64_t* value) = 0; |
59 | TVM_DLL virtual void Visit(const char* key, int* value) = 0; |
60 | TVM_DLL virtual void Visit(const char* key, bool* value) = 0; |
61 | TVM_DLL virtual void Visit(const char* key, std::string* value) = 0; |
62 | TVM_DLL virtual void Visit(const char* key, void** value) = 0; |
63 | TVM_DLL virtual void Visit(const char* key, DataType* value) = 0; |
64 | TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0; |
65 | TVM_DLL virtual void Visit(const char* key, runtime::ObjectRef* value) = 0; |
66 | template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type> |
67 | void Visit(const char* key, ENum* ptr) { |
68 | static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value, |
69 | "declare enum to be enum int to use visitor" ); |
70 | this->Visit(key, reinterpret_cast<int*>(ptr)); |
71 | } |
72 | //! \endcond |
73 | }; |
74 | |
75 | /*! |
76 | * \brief Virtual function table to support IR/AST node reflection. |
77 | * |
78 | * Functions are stored in columnar manner. |
79 | * Each column is a vector indexed by Object's type_index. |
80 | */ |
81 | class ReflectionVTable { |
82 | public: |
83 | /*! |
84 | * \brief Visitor function. |
85 | * \note We use function pointer, instead of std::function |
86 | * to reduce the dispatch overhead as field visit |
87 | * does not need as much customization. |
88 | */ |
89 | typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor); |
90 | /*! |
91 | * \brief Equality comparison function. |
92 | */ |
93 | typedef bool (*FSEqualReduce)(const Object* self, const Object* other, SEqualReducer equal); |
94 | /*! |
95 | * \brief Structural hash reduction function. |
96 | */ |
97 | typedef void (*FSHashReduce)(const Object* self, SHashReducer hash_reduce); |
98 | /*! |
99 | * \brief creator function. |
100 | * \param repr_bytes Repr bytes to create the object. |
101 | * If this is not empty then FReprBytes must be defined for the object. |
102 | * \return The created function. |
103 | */ |
104 | typedef ObjectPtr<Object> (*FCreate)(const std::string& repr_bytes); |
105 | /*! |
106 | * \brief Function to get a byte representation that can be used to recover the object. |
107 | * \param node The node pointer. |
108 | * \return bytes The bytes that can be used to recover the object. |
109 | */ |
110 | typedef std::string (*FReprBytes)(const Object* self); |
111 | /*! |
112 | * \brief Dispatch the VisitAttrs function. |
113 | * \param self The pointer to the object. |
114 | * \param visitor The attribute visitor. |
115 | */ |
116 | inline void VisitAttrs(Object* self, AttrVisitor* visitor) const; |
117 | /*! |
118 | * \brief Get repr bytes if any. |
119 | * \param self The pointer to the object. |
120 | * \param repr_bytes The output repr bytes, can be null, in which case the function |
121 | * simply queries if the ReprBytes function exists for the type. |
122 | * \return Whether repr bytes exists |
123 | */ |
124 | inline bool GetReprBytes(const Object* self, std::string* repr_bytes) const; |
125 | /*! |
126 | * \brief Dispatch the SEqualReduce function. |
127 | * \param self The pointer to the object. |
128 | * \param other The pointer to another object to be compared. |
129 | * \param equal The equality comparator. |
130 | * \return the result. |
131 | */ |
132 | bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const; |
133 | /*! |
134 | * \brief Dispatch the SHashReduce function. |
135 | * \param self The pointer to the object. |
136 | * \param hash_reduce The hash reducer. |
137 | * \return the result. |
138 | */ |
139 | void SHashReduce(const Object* self, SHashReducer hash_reduce) const; |
140 | /*! |
141 | * \brief Create an initial object using default constructor |
142 | * by type_key and global key. |
143 | * |
144 | * \param type_key The type key of the object. |
145 | * \param repr_bytes Bytes representation of the object if any. |
146 | */ |
147 | TVM_DLL ObjectPtr<Object> CreateInitObject(const std::string& type_key, |
148 | const std::string& repr_bytes = "" ) const; |
149 | /*! |
150 | * \brief Create an object by giving kwargs about its fields. |
151 | * |
152 | * \param type_key The type key. |
153 | * \param kwargs the arguments in format key1, value1, ..., key_n, value_n. |
154 | * \return The created object. |
155 | */ |
156 | TVM_DLL ObjectRef CreateObject(const std::string& type_key, const runtime::TVMArgs& kwargs); |
157 | /*! |
158 | * \brief Create an object by giving kwargs about its fields. |
159 | * |
160 | * \param type_key The type key. |
161 | * \param kwargs The field arguments. |
162 | * \return The created object. |
163 | */ |
164 | TVM_DLL ObjectRef CreateObject(const std::string& type_key, const Map<String, ObjectRef>& kwargs); |
165 | /*! |
166 | * \brief Get an field object by the attr name. |
167 | * \param self The pointer to the object. |
168 | * \param attr_name The name of the field. |
169 | * \return The corresponding attribute value. |
170 | * \note This function will throw an exception if the object does not contain the field. |
171 | */ |
172 | TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const String& attr_name) const; |
173 | |
174 | /*! |
175 | * \brief List all the fields in the object. |
176 | * \return All the fields. |
177 | */ |
178 | TVM_DLL std::vector<std::string> ListAttrNames(Object* self) const; |
179 | |
180 | /*! \return The global singleton. */ |
181 | TVM_DLL static ReflectionVTable* Global(); |
182 | |
183 | class Registry; |
184 | template <typename T, typename TraitName> |
185 | inline Registry Register(); |
186 | |
187 | private: |
188 | /*! \brief Attribute visitor. */ |
189 | std::vector<FVisitAttrs> fvisit_attrs_; |
190 | /*! \brief Structural equal function. */ |
191 | std::vector<FSEqualReduce> fsequal_reduce_; |
192 | /*! \brief Structural hash function. */ |
193 | std::vector<FSHashReduce> fshash_reduce_; |
194 | /*! \brief Creation function. */ |
195 | std::vector<FCreate> fcreate_; |
196 | /*! \brief ReprBytes function. */ |
197 | std::vector<FReprBytes> frepr_bytes_; |
198 | }; |
199 | |
200 | /*! \brief Registry of a reflection table. */ |
201 | class ReflectionVTable::Registry { |
202 | public: |
203 | Registry(ReflectionVTable* parent, uint32_t type_index) |
204 | : parent_(parent), type_index_(type_index) {} |
205 | /*! |
206 | * \brief Set fcreate function. |
207 | * \param f The creator function. |
208 | * \return Reference to self. |
209 | */ |
210 | Registry& set_creator(FCreate f) { // NOLINT(*) |
211 | ICHECK_LT(type_index_, parent_->fcreate_.size()); |
212 | parent_->fcreate_[type_index_] = f; |
213 | return *this; |
214 | } |
215 | /*! |
216 | * \brief Set bytes repr function. |
217 | * \param f The ReprBytes function. |
218 | * \return Reference to self. |
219 | */ |
220 | Registry& set_repr_bytes(FReprBytes f) { // NOLINT(*) |
221 | ICHECK_LT(type_index_, parent_->frepr_bytes_.size()); |
222 | parent_->frepr_bytes_[type_index_] = f; |
223 | return *this; |
224 | } |
225 | |
226 | private: |
227 | ReflectionVTable* parent_; |
228 | uint32_t type_index_; |
229 | }; |
230 | |
/*!
 * \brief Declaration prefix for the static Registry variables defined by
 *        TVM_REGISTER_REFLECTION_VTABLE, which appends __COUNTER__ via TVM_STR_CONCAT
 *        to make each variable name unique.
 * \note NOTE(review): identifiers containing a double underscore are reserved for the
 *       implementation in C++; kept as-is for consistency with other TVM registry macros.
 */
#define TVM_REFLECTION_REG_VAR_DEF \
  static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry __make_reflection
233 | |
/*!
 * \brief Directly register reflection VTable.
 * \param TypeName The name of the type.
 * \param TraitName A trait class that implements functions like VisitAttrs and SEqualReduce.
 *
 * \code
 *
 *  // Example SEqualReduce traits for runtime StringObj.
 *
 *  struct StringObjTrait {
 *    static constexpr const std::nullptr_t VisitAttrs = nullptr;
 *
 *    static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) {
 *      hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(key->data, key->size));
 *    }
 *
 *    static bool SEqualReduce(const runtime::StringObj* lhs,
 *                             const runtime::StringObj* rhs,
 *                             SEqualReducer equal) {
 *      if (lhs == rhs) return true;
 *      if (lhs->size != rhs->size) return false;
 *      if (lhs->data == rhs->data) return true;
 *      return std::memcmp(lhs->data, rhs->data, lhs->size) == 0;
 *    }
 *  };
 *
 *  TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
 *
 * \endcode
 *
 * \note Each trait member is either a static function taking the concrete type, or a
 *       static std::nullptr_t member set to nullptr (like VisitAttrs above) to leave
 *       that table entry unregistered.
 * \note The expansion defines a static Registry variable, so setters such as
 *       .set_creator(...) or .set_repr_bytes(...) may be chained after the macro.
 * \note This macro can be invoked in a different place from TVM_REGISTER_OBJECT_TYPE,
 *       and can be used to register reflection functions for runtime objects.
 */
#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \
  TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \
      ::tvm::ReflectionVTable::Global()->Register<TypeName, TraitName>()
270 | |
/*!
 * \brief Register a node type to the object registry and the reflection registry.
 * \param TypeName The name of the type.
 * \note This macro will call TVM_REGISTER_OBJECT_TYPE for the type as well, registers
 *       ::tvm::detail::ReflectionTrait<TypeName> as the reflection trait, and installs a
 *       creator that ignores the repr bytes and returns make_object<TypeName>().
 * \note The creator's return type is spelled with fully qualified names so the expansion
 *       does not depend on tvm::ObjectPtr / tvm::Object being visible at the call site.
 */
#define TVM_REGISTER_NODE_TYPE(TypeName)                                                     \
  TVM_REGISTER_OBJECT_TYPE(TypeName);                                                         \
  TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>)          \
      .set_creator([](const std::string&) -> ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> { \
        return ::tvm::runtime::make_object<TypeName>();                                       \
      })
282 | |
// Implementation details
namespace detail {

/*!
 * \brief Adapts T's own VisitAttrs member into a trait member.
 *
 * The bool parameter defaults to T::_type_has_method_visit_attrs (presumably set by the
 * object declaration macros; confirm). This primary template is used when it is false and
 * exposes VisitAttrs as a std::nullptr_t constant, which SelectVisitAttrs maps to a
 * null table entry.
 */
template <typename T, bool = T::_type_has_method_visit_attrs>
struct ImplVisitAttrs {
  static constexpr const std::nullptr_t VisitAttrs = nullptr;
};

/*! \brief Specialization for types that declare VisitAttrs: forwards to the member. */
template <typename T>
struct ImplVisitAttrs<T, true> {
  static void VisitAttrs(T* self, AttrVisitor* v) { self->VisitAttrs(v); }
};

/*!
 * \brief Adapts T's own SEqualReduce member into a trait member; a std::nullptr_t
 *        constant when T::_type_has_method_sequal_reduce is false.
 */
template <typename T, bool = T::_type_has_method_sequal_reduce>
struct ImplSEqualReduce {
  static constexpr const std::nullptr_t SEqualReduce = nullptr;
};

/*! \brief Specialization for types that declare SEqualReduce: forwards to the member. */
template <typename T>
struct ImplSEqualReduce<T, true> {
  static bool SEqualReduce(const T* self, const T* other, SEqualReducer equal) {
    return self->SEqualReduce(other, equal);
  }
};

/*!
 * \brief Adapts T's own SHashReduce member into a trait member; a std::nullptr_t
 *        constant when T::_type_has_method_shash_reduce is false.
 */
template <typename T, bool = T::_type_has_method_shash_reduce>
struct ImplSHashReduce {
  static constexpr const std::nullptr_t SHashReduce = nullptr;
};

/*! \brief Specialization for types that declare SHashReduce: forwards to the member. */
template <typename T>
struct ImplSHashReduce<T, true> {
  static void SHashReduce(const T* self, SHashReducer hash_reduce) {
    self->SHashReduce(hash_reduce);
  }
};

/*!
 * \brief Default trait used by TVM_REGISTER_NODE_TYPE. It combines the three Impl*
 *        adapters, so each reflection function either forwards to T's member or is null.
 */
template <typename T>
struct ReflectionTrait : public ImplVisitAttrs<T>,
                         public ImplSEqualReduce<T>,
                         public ImplSHashReduce<T> {};

/*!
 * \brief Turns TraitName::VisitAttrs into an entry assignable to FVisitAttrs.
 *
 * Dispatches on whether decltype(TraitName::VisitAttrs) is a (possibly cv-qualified)
 * std::nullptr_t. If it is, the entry is nullptr. Otherwise (specialization below)
 * it is a type-erased thunk.
 */
template <typename T, typename TraitName,
          bool = std::is_null_pointer<decltype(TraitName::VisitAttrs)>::value>
struct SelectVisitAttrs {
  static constexpr const std::nullptr_t VisitAttrs = nullptr;
};

/*!
 * \brief Thunk casting Object* back to T* before calling the trait. The static_cast is
 *        unchecked; it relies on the table being indexed by the object's runtime type index.
 */
template <typename T, typename TraitName>
struct SelectVisitAttrs<T, TraitName, false> {
  static void VisitAttrs(Object* self, AttrVisitor* v) {
    TraitName::VisitAttrs(static_cast<T*>(self), v);
  }
};

/*!
 * \brief Turns TraitName::SEqualReduce into an entry assignable to FSEqualReduce:
 *        nullptr when the trait member is a std::nullptr_t, otherwise a thunk.
 */
template <typename T, typename TraitName,
          bool = std::is_null_pointer<decltype(TraitName::SEqualReduce)>::value>
struct SelectSEqualReduce {
  static constexpr const std::nullptr_t SEqualReduce = nullptr;
};

/*!
 * \brief Thunk casting both operands to const T*. NOTE(review): assumes the caller only
 *        compares objects of the same type; confirm in the SEqualReduce dispatcher.
 */
template <typename T, typename TraitName>
struct SelectSEqualReduce<T, TraitName, false> {
  static bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) {
    return TraitName::SEqualReduce(static_cast<const T*>(self), static_cast<const T*>(other),
                                   equal);
  }
};

/*!
 * \brief Turns TraitName::SHashReduce into an entry assignable to FSHashReduce:
 *        nullptr when the trait member is a std::nullptr_t, otherwise a thunk.
 */
template <typename T, typename TraitName,
          bool = std::is_null_pointer<decltype(TraitName::SHashReduce)>::value>
struct SelectSHashReduce {
  static constexpr const std::nullptr_t SHashReduce = nullptr;
};

/*! \brief Thunk casting Object* back to const T* before calling the trait. */
template <typename T, typename TraitName>
struct SelectSHashReduce<T, TraitName, false> {
  static void SHashReduce(const Object* self, SHashReducer hash_reduce) {
    return TraitName::SHashReduce(static_cast<const T*>(self), hash_reduce);
  }
};

}  // namespace detail
366 | |
367 | template <typename T, typename TraitName> |
368 | inline ReflectionVTable::Registry ReflectionVTable::Register() { |
369 | uint32_t tindex = T::RuntimeTypeIndex(); |
370 | if (tindex >= fvisit_attrs_.size()) { |
371 | fvisit_attrs_.resize(tindex + 1, nullptr); |
372 | fcreate_.resize(tindex + 1, nullptr); |
373 | frepr_bytes_.resize(tindex + 1, nullptr); |
374 | fsequal_reduce_.resize(tindex + 1, nullptr); |
375 | fshash_reduce_.resize(tindex + 1, nullptr); |
376 | } |
377 | // functor that implements the redirection. |
378 | fvisit_attrs_[tindex] = ::tvm::detail::SelectVisitAttrs<T, TraitName>::VisitAttrs; |
379 | |
380 | fsequal_reduce_[tindex] = ::tvm::detail::SelectSEqualReduce<T, TraitName>::SEqualReduce; |
381 | |
382 | fshash_reduce_[tindex] = ::tvm::detail::SelectSHashReduce<T, TraitName>::SHashReduce; |
383 | |
384 | return Registry(this, tindex); |
385 | } |
386 | |
387 | inline void ReflectionVTable::VisitAttrs(Object* self, AttrVisitor* visitor) const { |
388 | uint32_t tindex = self->type_index(); |
389 | if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) { |
390 | return; |
391 | } |
392 | fvisit_attrs_[tindex](self, visitor); |
393 | } |
394 | |
395 | inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr_bytes) const { |
396 | uint32_t tindex = self->type_index(); |
397 | if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) { |
398 | if (repr_bytes != nullptr) { |
399 | *repr_bytes = frepr_bytes_[tindex](self); |
400 | } |
401 | return true; |
402 | } else { |
403 | return false; |
404 | } |
405 | } |
406 | |
407 | /*! |
408 | * \brief Given an object and an address of its attribute, return the key of the attribute. |
409 | * \return nullptr if no attribute with the given address exists. |
410 | */ |
411 | Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address); |
412 | |
413 | } // namespace tvm |
414 | #endif // TVM_NODE_REFLECTION_H_ |
415 | |