1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file tvm/node/reflection.h
21 * \brief Reflection and serialization of compiler IR/AST nodes.
22 */
23#ifndef TVM_NODE_REFLECTION_H_
24#define TVM_NODE_REFLECTION_H_
25
26#include <tvm/node/structural_equal.h>
27#include <tvm/node/structural_hash.h>
28#include <tvm/runtime/c_runtime_api.h>
29#include <tvm/runtime/data_type.h>
30#include <tvm/runtime/memory.h>
31#include <tvm/runtime/ndarray.h>
32#include <tvm/runtime/object.h>
33#include <tvm/runtime/packed_func.h>
34
35#include <string>
36#include <type_traits>
37#include <vector>
38
39namespace tvm {
40
41using runtime::Object;
42using runtime::ObjectPtr;
43using runtime::ObjectRef;
44
45/*!
46 * \brief Visitor class to get the attributes of an AST/IR node.
47 * The content is going to be called for each field.
48 *
49 * Each objects that wants reflection will need to implement
50 * a VisitAttrs function and call visitor->Visit on each of its field.
51 */
52class AttrVisitor {
53 public:
54 //! \cond Doxygen_Suppress
55 TVM_DLL virtual ~AttrVisitor() = default;
56 TVM_DLL virtual void Visit(const char* key, double* value) = 0;
57 TVM_DLL virtual void Visit(const char* key, int64_t* value) = 0;
58 TVM_DLL virtual void Visit(const char* key, uint64_t* value) = 0;
59 TVM_DLL virtual void Visit(const char* key, int* value) = 0;
60 TVM_DLL virtual void Visit(const char* key, bool* value) = 0;
61 TVM_DLL virtual void Visit(const char* key, std::string* value) = 0;
62 TVM_DLL virtual void Visit(const char* key, void** value) = 0;
63 TVM_DLL virtual void Visit(const char* key, DataType* value) = 0;
64 TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0;
65 TVM_DLL virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
66 template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
67 void Visit(const char* key, ENum* ptr) {
68 static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
69 "declare enum to be enum int to use visitor");
70 this->Visit(key, reinterpret_cast<int*>(ptr));
71 }
72 //! \endcond
73};
74
75/*!
76 * \brief Virtual function table to support IR/AST node reflection.
77 *
78 * Functions are stored in columnar manner.
79 * Each column is a vector indexed by Object's type_index.
80 */
81class ReflectionVTable {
82 public:
83 /*!
84 * \brief Visitor function.
85 * \note We use function pointer, instead of std::function
86 * to reduce the dispatch overhead as field visit
87 * does not need as much customization.
88 */
89 typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor);
90 /*!
91 * \brief Equality comparison function.
92 */
93 typedef bool (*FSEqualReduce)(const Object* self, const Object* other, SEqualReducer equal);
94 /*!
95 * \brief Structural hash reduction function.
96 */
97 typedef void (*FSHashReduce)(const Object* self, SHashReducer hash_reduce);
98 /*!
99 * \brief creator function.
100 * \param repr_bytes Repr bytes to create the object.
101 * If this is not empty then FReprBytes must be defined for the object.
102 * \return The created function.
103 */
104 typedef ObjectPtr<Object> (*FCreate)(const std::string& repr_bytes);
105 /*!
106 * \brief Function to get a byte representation that can be used to recover the object.
107 * \param node The node pointer.
108 * \return bytes The bytes that can be used to recover the object.
109 */
110 typedef std::string (*FReprBytes)(const Object* self);
111 /*!
112 * \brief Dispatch the VisitAttrs function.
113 * \param self The pointer to the object.
114 * \param visitor The attribute visitor.
115 */
116 inline void VisitAttrs(Object* self, AttrVisitor* visitor) const;
117 /*!
118 * \brief Get repr bytes if any.
119 * \param self The pointer to the object.
120 * \param repr_bytes The output repr bytes, can be null, in which case the function
121 * simply queries if the ReprBytes function exists for the type.
122 * \return Whether repr bytes exists
123 */
124 inline bool GetReprBytes(const Object* self, std::string* repr_bytes) const;
125 /*!
126 * \brief Dispatch the SEqualReduce function.
127 * \param self The pointer to the object.
128 * \param other The pointer to another object to be compared.
129 * \param equal The equality comparator.
130 * \return the result.
131 */
132 bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const;
133 /*!
134 * \brief Dispatch the SHashReduce function.
135 * \param self The pointer to the object.
136 * \param hash_reduce The hash reducer.
137 * \return the result.
138 */
139 void SHashReduce(const Object* self, SHashReducer hash_reduce) const;
140 /*!
141 * \brief Create an initial object using default constructor
142 * by type_key and global key.
143 *
144 * \param type_key The type key of the object.
145 * \param repr_bytes Bytes representation of the object if any.
146 */
147 TVM_DLL ObjectPtr<Object> CreateInitObject(const std::string& type_key,
148 const std::string& repr_bytes = "") const;
149 /*!
150 * \brief Create an object by giving kwargs about its fields.
151 *
152 * \param type_key The type key.
153 * \param kwargs the arguments in format key1, value1, ..., key_n, value_n.
154 * \return The created object.
155 */
156 TVM_DLL ObjectRef CreateObject(const std::string& type_key, const runtime::TVMArgs& kwargs);
157 /*!
158 * \brief Create an object by giving kwargs about its fields.
159 *
160 * \param type_key The type key.
161 * \param kwargs The field arguments.
162 * \return The created object.
163 */
164 TVM_DLL ObjectRef CreateObject(const std::string& type_key, const Map<String, ObjectRef>& kwargs);
165 /*!
166 * \brief Get an field object by the attr name.
167 * \param self The pointer to the object.
168 * \param attr_name The name of the field.
169 * \return The corresponding attribute value.
170 * \note This function will throw an exception if the object does not contain the field.
171 */
172 TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const String& attr_name) const;
173
174 /*!
175 * \brief List all the fields in the object.
176 * \return All the fields.
177 */
178 TVM_DLL std::vector<std::string> ListAttrNames(Object* self) const;
179
180 /*! \return The global singleton. */
181 TVM_DLL static ReflectionVTable* Global();
182
183 class Registry;
184 template <typename T, typename TraitName>
185 inline Registry Register();
186
187 private:
188 /*! \brief Attribute visitor. */
189 std::vector<FVisitAttrs> fvisit_attrs_;
190 /*! \brief Structural equal function. */
191 std::vector<FSEqualReduce> fsequal_reduce_;
192 /*! \brief Structural hash function. */
193 std::vector<FSHashReduce> fshash_reduce_;
194 /*! \brief Creation function. */
195 std::vector<FCreate> fcreate_;
196 /*! \brief ReprBytes function. */
197 std::vector<FReprBytes> frepr_bytes_;
198};
199
/*!
 * \brief Registry of a reflection table.
 *
 * Handle returned by ReflectionVTable::Register(). It fills the per-type entries that
 * Register() itself does not set (the creator and repr-bytes columns) for one row of
 * the parent table. Setters return *this so calls can be chained.
 */
class ReflectionVTable::Registry {
 public:
  /*!
   * \param parent The vtable whose columns are updated.
   * \param type_index The type index (row) to update.
   */
  Registry(ReflectionVTable* parent, uint32_t type_index)
      : parent_(parent), type_index_(type_index) {}
  /*!
   * \brief Set fcreate function.
   * \param f The creator function.
   * \return Reference to self.
   * \note ICHECK-fails if the row does not exist yet; Register() grows the columns
   *       before constructing a Registry, so rows obtained from it are valid.
   */
  Registry& set_creator(FCreate f) {  // NOLINT(*)
    ICHECK_LT(type_index_, parent_->fcreate_.size());
    parent_->fcreate_[type_index_] = f;
    return *this;
  }
  /*!
   * \brief Set bytes repr function.
   * \param f The ReprBytes function.
   * \return Reference to self.
   * \note ICHECK-fails if the row does not exist yet.
   */
  Registry& set_repr_bytes(FReprBytes f) {  // NOLINT(*)
    ICHECK_LT(type_index_, parent_->frepr_bytes_.size());
    parent_->frepr_bytes_[type_index_] = f;
    return *this;
  }

 private:
  /*! \brief The table being populated (not owned). */
  ReflectionVTable* parent_;
  /*! \brief The row (type index) being populated. */
  uint32_t type_index_;
};
230
// Declaration prefix for the internal-linkage (static) Registry variable that the
// registration macros assign to. TVM_REGISTER_REFLECTION_VTABLE concatenates
// __COUNTER__ onto it via TVM_STR_CONCAT. Presumably this gives each registration in a
// translation unit a distinct variable name (__make_reflection0, ...); confirm against
// the TVM_STR_CONCAT definition.
#define TVM_REFLECTION_REG_VAR_DEF \
  static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry __make_reflection
233
234/*!
235 * \brief Directly register reflection VTable.
236 * \param TypeName The name of the type.
237 * \param TraitName A trait class that implements functions like VisitAttrs and SEqualReduce.
238 *
239 * \code
240 *
241 * // Example SEQualReduce traits for runtime StringObj.
242 *
243 * struct StringObjTrait {
244 * static constexpr const std::nullptr_t VisitAttrs = nullptr;
245 *
246 * static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) {
247 * hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(key->data, key->size));
248 * }
249 *
250 * static bool SEqualReduce(const runtime::StringObj* lhs,
251 * const runtime::StringObj* rhs,
252 * SEqualReducer equal) {
253 * if (lhs == rhs) return true;
254 * if (lhs->size != rhs->size) return false;
255 * if (lhs->data != rhs->data) return true;
256 * return std::memcmp(lhs->data, rhs->data, lhs->size) != 0;
257 * }
258 * };
259 *
260 * TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
261 *
262 * \endcode
263 *
264 * \note This macro can be called in different place as TVM_REGISTER_OBJECT_TYPE.
265 * And can be used to register the related reflection functions for runtime objects.
266 */
267#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \
268 TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \
269 ::tvm::ReflectionVTable::Global()->Register<TypeName, TraitName>()
270
271/*!
272 * \brief Register a node type to object registry and reflection registry.
273 * \param TypeName The name of the type.
274 * \note This macro will call TVM_REGISTER_OBJECT_TYPE for the type as well.
275 */
276#define TVM_REGISTER_NODE_TYPE(TypeName) \
277 TVM_REGISTER_OBJECT_TYPE(TypeName); \
278 TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>) \
279 .set_creator([](const std::string&) -> ObjectPtr<Object> { \
280 return ::tvm::runtime::make_object<TypeName>(); \
281 })
282
283// Implementation details
284namespace detail {
285
286template <typename T, bool = T::_type_has_method_visit_attrs>
287struct ImplVisitAttrs {
288 static constexpr const std::nullptr_t VisitAttrs = nullptr;
289};
290
291template <typename T>
292struct ImplVisitAttrs<T, true> {
293 static void VisitAttrs(T* self, AttrVisitor* v) { self->VisitAttrs(v); }
294};
295
296template <typename T, bool = T::_type_has_method_sequal_reduce>
297struct ImplSEqualReduce {
298 static constexpr const std::nullptr_t SEqualReduce = nullptr;
299};
300
301template <typename T>
302struct ImplSEqualReduce<T, true> {
303 static bool SEqualReduce(const T* self, const T* other, SEqualReducer equal) {
304 return self->SEqualReduce(other, equal);
305 }
306};
307
308template <typename T, bool = T::_type_has_method_shash_reduce>
309struct ImplSHashReduce {
310 static constexpr const std::nullptr_t SHashReduce = nullptr;
311};
312
313template <typename T>
314struct ImplSHashReduce<T, true> {
315 static void SHashReduce(const T* self, SHashReducer hash_reduce) {
316 self->SHashReduce(hash_reduce);
317 }
318};
319
320template <typename T>
321struct ReflectionTrait : public ImplVisitAttrs<T>,
322 public ImplSEqualReduce<T>,
323 public ImplSHashReduce<T> {};
324
325template <typename T, typename TraitName,
326 bool = std::is_null_pointer<decltype(TraitName::VisitAttrs)>::value>
327struct SelectVisitAttrs {
328 static constexpr const std::nullptr_t VisitAttrs = nullptr;
329};
330
331template <typename T, typename TraitName>
332struct SelectVisitAttrs<T, TraitName, false> {
333 static void VisitAttrs(Object* self, AttrVisitor* v) {
334 TraitName::VisitAttrs(static_cast<T*>(self), v);
335 }
336};
337
338template <typename T, typename TraitName,
339 bool = std::is_null_pointer<decltype(TraitName::SEqualReduce)>::value>
340struct SelectSEqualReduce {
341 static constexpr const std::nullptr_t SEqualReduce = nullptr;
342};
343
344template <typename T, typename TraitName>
345struct SelectSEqualReduce<T, TraitName, false> {
346 static bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) {
347 return TraitName::SEqualReduce(static_cast<const T*>(self), static_cast<const T*>(other),
348 equal);
349 }
350};
351
352template <typename T, typename TraitName,
353 bool = std::is_null_pointer<decltype(TraitName::SHashReduce)>::value>
354struct SelectSHashReduce {
355 static constexpr const std::nullptr_t SHashReduce = nullptr;
356};
357
358template <typename T, typename TraitName>
359struct SelectSHashReduce<T, TraitName, false> {
360 static void SHashReduce(const Object* self, SHashReducer hash_reduce) {
361 return TraitName::SHashReduce(static_cast<const T*>(self), hash_reduce);
362 }
363};
364
365} // namespace detail
366
367template <typename T, typename TraitName>
368inline ReflectionVTable::Registry ReflectionVTable::Register() {
369 uint32_t tindex = T::RuntimeTypeIndex();
370 if (tindex >= fvisit_attrs_.size()) {
371 fvisit_attrs_.resize(tindex + 1, nullptr);
372 fcreate_.resize(tindex + 1, nullptr);
373 frepr_bytes_.resize(tindex + 1, nullptr);
374 fsequal_reduce_.resize(tindex + 1, nullptr);
375 fshash_reduce_.resize(tindex + 1, nullptr);
376 }
377 // functor that implements the redirection.
378 fvisit_attrs_[tindex] = ::tvm::detail::SelectVisitAttrs<T, TraitName>::VisitAttrs;
379
380 fsequal_reduce_[tindex] = ::tvm::detail::SelectSEqualReduce<T, TraitName>::SEqualReduce;
381
382 fshash_reduce_[tindex] = ::tvm::detail::SelectSHashReduce<T, TraitName>::SHashReduce;
383
384 return Registry(this, tindex);
385}
386
387inline void ReflectionVTable::VisitAttrs(Object* self, AttrVisitor* visitor) const {
388 uint32_t tindex = self->type_index();
389 if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) {
390 return;
391 }
392 fvisit_attrs_[tindex](self, visitor);
393}
394
395inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr_bytes) const {
396 uint32_t tindex = self->type_index();
397 if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) {
398 if (repr_bytes != nullptr) {
399 *repr_bytes = frepr_bytes_[tindex](self);
400 }
401 return true;
402 } else {
403 return false;
404 }
405}
406
407/*!
408 * \brief Given an object and an address of its attribute, return the key of the attribute.
409 * \return nullptr if no attribute with the given address exists.
410 */
411Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address);
412
413} // namespace tvm
414#endif // TVM_NODE_REFLECTION_H_
415