1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file tvm/runtime/object.h
21 * \brief A managed object in the TVM runtime.
22 */
23#ifndef TVM_RUNTIME_OBJECT_H_
24#define TVM_RUNTIME_OBJECT_H_
25
26#include <tvm/runtime/c_runtime_api.h>
27#include <tvm/runtime/logging.h>
28
29#include <string>
30#include <type_traits>
31#include <utility>
32
33/*!
34 * \brief Whether or not use atomic reference counter.
35 * If the reference counter is not atomic,
36 * an object cannot be owned by multiple threads.
37 * We can, however, move an object across threads
38 */
39#ifndef TVM_OBJECT_ATOMIC_REF_COUNTER
40#define TVM_OBJECT_ATOMIC_REF_COUNTER 1
41#endif
42
43#if TVM_OBJECT_ATOMIC_REF_COUNTER
44#include <atomic>
45#endif // TVM_OBJECT_ATOMIC_REF_COUNTER
46
47namespace tvm {
48namespace runtime {
49
/*!
 * \brief Namespace for the list of type index.
 * \note Use struct so that we have to use TypeIndex::ENumName to refer to
 *  the constant, but still able to use enum.
 */
struct TypeIndex {
  enum {
    /*! \brief Root object type. */
    kRoot = 0,
    // Standard static index assignments,
    // Frontends can take benefit of these constants.
    /*! \brief runtime::Module. */
    kRuntimeModule = 1,
    /*! \brief runtime::NDArray. */
    kRuntimeNDArray = 2,
    /*! \brief runtime::String. */
    kRuntimeString = 3,
    /*! \brief runtime::Array. */
    kRuntimeArray = 4,
    /*! \brief runtime::Map. */
    kRuntimeMap = 5,
    /*! \brief runtime::ShapeTuple. */
    kRuntimeShapeTuple = 6,
    /*! \brief runtime::PackedFunc. */
    kRuntimePackedFunc = 7,
    // Static assignments that may be subject to change (values are implicit,
    // so do not rely on their exact numbers).
    kRuntimeClosure,
    kRuntimeADT,
    /*! \brief One past the last statically assigned type index. */
    kStaticIndexEnd,
    /*! \brief Type index is allocated during runtime. */
    kDynamic = kStaticIndexEnd
  };
};  // struct TypeIndex
83
84/*!
85 * \brief base class of all object containers.
86 *
87 * Sub-class of objects should declare the following static constexpr fields:
88 *
89 * - _type_index:
90 * Static type index of the object, if assigned to TypeIndex::kDynamic
91 * the type index will be assigned during runtime.
92 * Runtime type index can be accessed by ObjectType::TypeIndex();
93 * - _type_key:
94 * The unique string identifier of the type.
95 * - _type_final:
96 * Whether the type is terminal type(there is no subclass of the type in the object system).
97 * This field is automatically set by macro TVM_DECLARE_FINAL_OBJECT_INFO
98 * It is still OK to sub-class a terminal object type T and construct it using make_object.
99 * But IsInstance check will only show that the object type is T(instead of the sub-class).
100 *
101 * The following two fields are necessary for base classes that can be sub-classed.
102 *
103 * - _type_child_slots:
104 * Number of reserved type index slots for child classes.
105 * Used for runtime optimization for type checking in IsInstance.
 *   If an object's type_index is within the half-open range
 *   [type_index, type_index + _type_child_slots), then the object can quickly be
 *   decided to be a sub-class of the current object class.
 *   If not, a fallback mechanism is used to check the global type table.
 *   Recommendation: set to the estimated number of children needed.
110 * - _type_child_slots_can_overflow:
111 * Whether we can add additional child classes even if the number of child classes
112 * exceeds the _type_child_slots. A fallback mechanism to check global type table will be
113 * used. Recommendation: set to false for optimal runtime speed if we know exact number of children.
114 *
115 * Two macros are used to declare helper functions in the object:
116 * - Use TVM_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed.
117 * - Use TVM_DECLARE_FINAL_OBJECT_INFO for object classes that cannot be sub-classed.
118 *
119 * New objects can be created using make_object function.
120 * Which will automatically populate the type_index and deleter of the object.
121 *
122 * \sa make_object
123 * \sa ObjectPtr
124 * \sa ObjectRef
125 *
126 * \code
127 *
128 * // Create a base object
129 * class BaseObj : public Object {
130 * public:
131 * // object fields
132 * int field0;
133 *
134 * // object properties
135 * static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
136 * static constexpr const char* _type_key = "test.BaseObj";
137 * TVM_DECLARE_BASE_OBJECT_INFO(BaseObj, Object);
138 * };
139 *
140 * class LeafObj : public BaseObj {
141 * public:
142 * // fields
143 * int child_field0;
144 * // object properties
145 * static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
146 * static constexpr const char* _type_key = "test.LeafObj";
 *   TVM_DECLARE_FINAL_OBJECT_INFO(LeafObj, BaseObj);
148 * };
149 *
150 * // The following code should be put into a cc file.
151 * TVM_REGISTER_OBJECT_TYPE(BaseObj);
152 * TVM_REGISTER_OBJECT_TYPE(LeafObj);
153 *
154 * // Usage example.
155 * void TestObjects() {
156 * // create an object
157 * ObjectRef leaf_ref(make_object<LeafObj>());
158 * // cast to a specific instance
159 * const LeafObj* leaf_ptr = leaf_ref.as<LeafObj>();
160 * ICHECK(leaf_ptr != nullptr);
161 * // can also cast to the base class.
162 * ICHECK(leaf_ref.as<BaseObj>() != nullptr);
163 * }
164 *
165 * \endcode
166 */
167class TVM_DLL Object {
168 public:
169 /*!
170 * \brief Object deleter
171 * \param self pointer to the Object.
172 */
173 typedef void (*FDeleter)(Object* self);
174 /*! \return The internal runtime type index of the object. */
175 uint32_t type_index() const { return type_index_; }
176 /*!
177 * \return the type key of the object.
178 * \note this operation is expensive, can be used for error reporting.
179 */
180 std::string GetTypeKey() const { return TypeIndex2Key(type_index_); }
181 /*!
182 * \return A hash value of the return of GetTypeKey.
183 */
184 size_t GetTypeKeyHash() const { return TypeIndex2KeyHash(type_index_); }
185 /*!
186 * Check if the object is an instance of TargetType.
187 * \tparam TargetType The target type to be checked.
188 * \return Whether the target type is true.
189 */
190 template <typename TargetType>
191 inline bool IsInstance() const;
192 /*!
193 * \return Whether the cell has only one reference
194 * \note We use stl style naming to be consistent with known API in shared_ptr.
195 */
196 inline bool unique() const;
197 /*!
198 * \brief Get the type key of the corresponding index from runtime.
199 * \param tindex The type index.
200 * \return the result.
201 */
202 static std::string TypeIndex2Key(uint32_t tindex);
203 /*!
204 * \brief Get the type key hash of the corresponding index from runtime.
205 * \param tindex The type index.
206 * \return the related key-hash.
207 */
208 static size_t TypeIndex2KeyHash(uint32_t tindex);
209 /*!
210 * \brief Get the type index of the corresponding key from runtime.
211 * \param key The type key.
212 * \return the result.
213 */
214 static uint32_t TypeKey2Index(const std::string& key);
215
216#if TVM_OBJECT_ATOMIC_REF_COUNTER
217 using RefCounterType = std::atomic<int32_t>;
218#else
219 using RefCounterType = int32_t;
220#endif
221
222 static constexpr const char* _type_key = "runtime.Object";
223
224 static uint32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kRoot; }
225 static uint32_t RuntimeTypeIndex() { return TypeIndex::kRoot; }
226
227 // Default object type properties for sub-classes
228 static constexpr bool _type_final = false;
229 static constexpr uint32_t _type_child_slots = 0;
230 static constexpr bool _type_child_slots_can_overflow = true;
231 // member information
232 static constexpr bool _type_has_method_visit_attrs = true;
233 static constexpr bool _type_has_method_sequal_reduce = false;
234 static constexpr bool _type_has_method_shash_reduce = false;
235 // NOTE: the following field is not type index of Object
236 // but was intended to be used by sub-classes as default value.
237 // The type index of Object is TypeIndex::kRoot
238 static constexpr uint32_t _type_index = TypeIndex::kDynamic;
239
240 // Default constructor and copy constructor
241 Object() {}
242 // Override the copy and assign constructors to do nothing.
243 // This is to make sure only contents, but not deleter and ref_counter
244 // are copied when a child class copies itself.
245 // This will enable us to use make_object<ObjectClass>(*obj_ptr)
246 // to copy an existing object.
247 Object(const Object& other) { // NOLINT(*)
248 }
249 Object(Object&& other) { // NOLINT(*)
250 }
251 Object& operator=(const Object& other) { // NOLINT(*)
252 return *this;
253 }
254 Object& operator=(Object&& other) { // NOLINT(*)
255 return *this;
256 }
257
258 protected:
259 // The fields of the base object cell.
260 /*! \brief Type index(tag) that indicates the type of the object. */
261 uint32_t type_index_{0};
262 /*! \brief The internal reference counter */
263 RefCounterType ref_counter_{0};
264 /*!
265 * \brief deleter of this object to enable customized allocation.
266 * If the deleter is nullptr, no deletion will be performed.
267 * The creator of the object must always set the deleter field properly.
268 */
269 FDeleter deleter_ = nullptr;
270 // Invariant checks.
271 static_assert(sizeof(int32_t) == sizeof(RefCounterType) &&
272 alignof(int32_t) == sizeof(RefCounterType),
273 "RefCounter ABI check.");
274
275 /*!
276 * \brief Get the type index using type key.
277 *
278 * When the function is first time called for a type,
279 * it will register the type to the type table in the runtime.
280 * If the static_tindex is TypeIndex::kDynamic, the function will
281 * allocate a runtime type index.
282 * Otherwise, we will populate the type table and return the static index.
283 *
284 * \param key the type key.
285 * \param static_tindex The current _type_index field.
286 * can be TypeIndex::kDynamic.
287 * \param parent_tindex The index of the parent.
288 * \param type_child_slots Number of slots reserved for its children.
289 * \param type_child_slots_can_overflow Whether to allow child to overflow the slots.
290 * \return The allocated type index.
291 */
292 static uint32_t GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex,
293 uint32_t parent_tindex, uint32_t type_child_slots,
294 bool type_child_slots_can_overflow);
295
296 // reference counter related operations
297 /*! \brief developer function, increases reference counter. */
298 inline void IncRef();
299 /*!
300 * \brief developer function, decrease reference counter.
301 * \note The deleter will be called when ref_counter_ becomes zero.
302 */
303 inline void DecRef();
304
305 private:
306 /*!
307 * \return The usage count of the cell.
308 * \note We use stl style naming to be consistent with known API in shared_ptr.
309 */
310 inline int use_count() const;
311 /*!
312 * \brief Check of this object is derived from the parent.
313 * \param parent_tindex The parent type index.
314 * \return The derivation results.
315 */
316 bool DerivedFrom(uint32_t parent_tindex) const;
317 // friend classes
318 template <typename>
319 friend class ObjAllocatorBase;
320 template <typename>
321 friend class ObjectPtr;
322 friend class TVMRetValue;
323 friend class ObjectInternal;
324};
325
/*!
 * \brief Get a reference type from a raw object ptr type
 *
 *  It is always important to get a reference type
 *  if we want to return a value as reference or keep
 *  the object alive beyond the scope of the function.
 *
 * \param ptr The object pointer
 * \tparam RefType The reference type
 * \tparam ObjectType The object type
 * \return The corresponding RefType
 */
template <typename RefType, typename ObjectType>
inline RefType GetRef(const ObjectType* ptr);

/*!
 * \brief Downcast a base reference type to a more specific type.
 *
 * \param ref The input reference
 * \return The corresponding SubRef.
 * \tparam SubRef The target specific reference type.
 * \tparam BaseRef the current reference type.
 * \note Fails an ICHECK if ref is defined but not an instance of SubRef's container
 *  type, or if ref is undefined and SubRef is not nullable.
 */
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref);
351
352/*!
353 * \brief A custom smart pointer for Object.
354 * \tparam T the content data type.
355 * \sa make_object
356 */
357template <typename T>
358class ObjectPtr {
359 public:
360 /*! \brief default constructor */
361 ObjectPtr() {}
362 /*! \brief default constructor */
363 ObjectPtr(std::nullptr_t) {} // NOLINT(*)
364 /*!
365 * \brief copy constructor
366 * \param other The value to be moved
367 */
368 ObjectPtr(const ObjectPtr<T>& other) // NOLINT(*)
369 : ObjectPtr(other.data_) {}
370 /*!
371 * \brief copy constructor
372 * \param other The value to be moved
373 */
374 template <typename U>
375 ObjectPtr(const ObjectPtr<U>& other) // NOLINT(*)
376 : ObjectPtr(other.data_) {
377 static_assert(std::is_base_of<T, U>::value,
378 "can only assign of child class ObjectPtr to parent");
379 }
380 /*!
381 * \brief move constructor
382 * \param other The value to be moved
383 */
384 ObjectPtr(ObjectPtr<T>&& other) // NOLINT(*)
385 : data_(other.data_) {
386 other.data_ = nullptr;
387 }
388 /*!
389 * \brief move constructor
390 * \param other The value to be moved
391 */
392 template <typename Y>
393 ObjectPtr(ObjectPtr<Y>&& other) // NOLINT(*)
394 : data_(other.data_) {
395 static_assert(std::is_base_of<T, Y>::value,
396 "can only assign of child class ObjectPtr to parent");
397 other.data_ = nullptr;
398 }
399 /*! \brief destructor */
400 ~ObjectPtr() { this->reset(); }
401 /*!
402 * \brief Swap this array with another Object
403 * \param other The other Object
404 */
405 void swap(ObjectPtr<T>& other) { // NOLINT(*)
406 std::swap(data_, other.data_);
407 }
408 /*!
409 * \return Get the content of the pointer
410 */
411 T* get() const { return static_cast<T*>(data_); }
412 /*!
413 * \return The pointer
414 */
415 T* operator->() const { return get(); }
416 /*!
417 * \return The reference
418 */
419 T& operator*() const { // NOLINT(*)
420 return *get();
421 }
422 /*!
423 * \brief copy assignment
424 * \param other The value to be assigned.
425 * \return reference to self.
426 */
427 ObjectPtr<T>& operator=(const ObjectPtr<T>& other) { // NOLINT(*)
428 // takes in plane operator to enable copy elison.
429 // copy-and-swap idiom
430 ObjectPtr(other).swap(*this); // NOLINT(*)
431 return *this;
432 }
433 /*!
434 * \brief move assignment
435 * \param other The value to be assigned.
436 * \return reference to self.
437 */
438 ObjectPtr<T>& operator=(ObjectPtr<T>&& other) { // NOLINT(*)
439 // copy-and-swap idiom
440 ObjectPtr(std::move(other)).swap(*this); // NOLINT(*)
441 return *this;
442 }
443 /*!
444 * \brief nullptr check
445 * \return result of comparison of internal pointer with nullptr.
446 */
447 explicit operator bool() const { return get() != nullptr; }
448 /*! \brief reset the content of ptr to be nullptr */
449 void reset() {
450 if (data_ != nullptr) {
451 data_->DecRef();
452 data_ = nullptr;
453 }
454 }
455 /*! \return The use count of the ptr, for debug purposes */
456 int use_count() const { return data_ != nullptr ? data_->use_count() : 0; }
457 /*! \return whether the reference is unique */
458 bool unique() const { return data_ != nullptr && data_->use_count() == 1; }
459 /*! \return Whether two ObjectPtr do not equal each other */
460 bool operator==(const ObjectPtr<T>& other) const { return data_ == other.data_; }
461 /*! \return Whether two ObjectPtr equals each other */
462 bool operator!=(const ObjectPtr<T>& other) const { return data_ != other.data_; }
463 /*! \return Whether the pointer is nullptr */
464 bool operator==(std::nullptr_t null) const { return data_ == nullptr; }
465 /*! \return Whether the pointer is not nullptr */
466 bool operator!=(std::nullptr_t null) const { return data_ != nullptr; }
467
468 private:
469 /*! \brief internal pointer field */
470 Object* data_{nullptr};
471 /*!
472 * \brief constructor from Object
473 * \param data The data pointer
474 */
475 explicit ObjectPtr(Object* data) : data_(data) {
476 if (data != nullptr) {
477 data_->IncRef();
478 }
479 }
480 /*!
481 * \brief Move an ObjectPtr from an RValueRef argument.
482 * \param ref The rvalue reference.
483 * \return the moved result.
484 */
485 static ObjectPtr<T> MoveFromRValueRefArg(Object** ref) {
486 ObjectPtr<T> ptr;
487 ptr.data_ = *ref;
488 *ref = nullptr;
489 return ptr;
490 }
491 // friend classes
492 friend class Object;
493 friend class ObjectRef;
494 friend struct ObjectPtrHash;
495 template <typename>
496 friend class ObjectPtr;
497 template <typename>
498 friend class ObjAllocatorBase;
499 friend class TVMPODValue_;
500 friend class TVMArgsSetter;
501 friend class TVMRetValue;
502 friend class TVMArgValue;
503 friend class TVMMovableArgValue_;
504 template <typename RelayRefType, typename ObjType>
505 friend RelayRefType GetRef(const ObjType* ptr);
506 template <typename BaseType, typename ObjType>
507 friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr);
508};
509
510/*! \brief Base class of all object reference */
511class ObjectRef {
512 public:
513 /*! \brief default constructor */
514 ObjectRef() = default;
515 /*! \brief Constructor from existing object ptr */
516 explicit ObjectRef(ObjectPtr<Object> data) : data_(data) {}
517 /*!
518 * \brief Comparator
519 * \param other Another object ref.
520 * \return the compare result.
521 */
522 bool same_as(const ObjectRef& other) const { return data_ == other.data_; }
523 /*!
524 * \brief Comparator
525 * \param other Another object ref.
526 * \return the compare result.
527 */
528 bool operator==(const ObjectRef& other) const { return data_ == other.data_; }
529 /*!
530 * \brief Comparator
531 * \param other Another object ref.
532 * \return the compare result.
533 */
534 bool operator!=(const ObjectRef& other) const { return data_ != other.data_; }
535 /*!
536 * \brief Comparator
537 * \param other Another object ref by address.
538 * \return the compare result.
539 */
540 bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); }
541 /*!
542 * \return whether the object is defined(not null).
543 */
544 bool defined() const { return data_ != nullptr; }
545 /*! \return the internal object pointer */
546 const Object* get() const { return data_.get(); }
547 /*! \return the internal object pointer */
548 const Object* operator->() const { return get(); }
549 /*! \return whether the reference is unique */
550 bool unique() const { return data_.unique(); }
551 /*! \return The use count of the ptr, for debug purposes */
552 int use_count() const { return data_.use_count(); }
553 /*!
554 * \brief Try to downcast the internal Object to a
555 * raw pointer of a corresponding type.
556 *
557 * The function will return a nullptr if the cast failed.
558 *
559 * if (const Add *add = node_ref.As<Add>()) {
560 * // This is an add node
561 * }
562 * \tparam ObjectType the target type, must be a subtype of Object/
563 */
564 template <typename ObjectType>
565 inline const ObjectType* as() const;
566
567 /*! \brief type indicate the container type. */
568 using ContainerType = Object;
569 // Default type properties for the reference class.
570 static constexpr bool _type_is_nullable = true;
571
572 protected:
573 /*! \brief Internal pointer that backs the reference. */
574 ObjectPtr<Object> data_;
575 /*! \return return a mutable internal ptr, can be used by sub-classes. */
576 Object* get_mutable() const { return data_.get(); }
577 /*!
578 * \brief Internal helper function downcast a ref without check.
579 * \note Only used for internal dev purposes.
580 * \tparam T The target reference type.
581 * \return The casted result.
582 */
583 template <typename T>
584 static T DowncastNoCheck(ObjectRef ref) {
585 return T(std::move(ref.data_));
586 }
587 /*!
588 * \brief Clear the object ref data field without DecRef
589 * after we successfully moved the field.
590 * \param ref The reference data.
591 */
592 static void FFIClearAfterMove(ObjectRef* ref) { ref->data_.data_ = nullptr; }
593 /*!
594 * \brief Internal helper function get data_ as ObjectPtr of ObjectType.
595 * \note only used for internal dev purpose.
596 * \tparam ObjectType The corresponding object type.
597 * \return the corresponding type.
598 */
599 template <typename ObjectType>
600 static ObjectPtr<ObjectType> GetDataPtr(const ObjectRef& ref) {
601 return ObjectPtr<ObjectType>(ref.data_.data_);
602 }
603 // friend classes.
604 friend struct ObjectPtrHash;
605 friend class TVMRetValue;
606 friend class TVMArgsSetter;
607 friend class ObjectInternal;
608 template <typename SubRef, typename BaseRef>
609 friend SubRef Downcast(BaseRef ref);
610};
611
612/*!
613 * \brief Get an object ptr type from a raw object ptr.
614 *
615 * \param ptr The object pointer
616 * \tparam BaseType The reference type
617 * \tparam ObjectType The object type
618 * \return The corresponding RefType
619 */
620template <typename BaseType, typename ObjectType>
621inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);
622
623/*! \brief ObjectRef hash functor */
624struct ObjectPtrHash {
625 size_t operator()(const ObjectRef& a) const { return operator()(a.data_); }
626
627 template <typename T>
628 size_t operator()(const ObjectPtr<T>& a) const {
629 return std::hash<Object*>()(a.get());
630 }
631};
632
633/*! \brief ObjectRef equal functor */
634struct ObjectPtrEqual {
635 bool operator()(const ObjectRef& a, const ObjectRef& b) const { return a.same_as(b); }
636
637 template <typename T>
638 size_t operator()(const ObjectPtr<T>& a, const ObjectPtr<T>& b) const {
639 return a == b;
640 }
641};
642
/*!
 * \brief helper macro to declare a base object type that can be inherited.
 *
 *  Used inside the class body. It expands to a static_assert that ParentType is
 *  not marked final, plus two static member functions:
 *  - RuntimeTypeIndex(): returns TypeName::_type_index when it is statically
 *    assigned (not TypeIndex::kDynamic), otherwise _GetOrAllocRuntimeTypeIndex().
 *  - _GetOrAllocRuntimeTypeIndex(): calls Object::GetOrAllocRuntimeTypeIndex with
 *    the type key, static index, parent's index, child slots and overflow flag,
 *    and caches the result in a function-local static.
 *  The inner static_assert requires a child that inherits a non-zero
 *  _type_child_slots from its parent to declare its own, smaller value.
 *
 * \param TypeName The name of the current type.
 * \param ParentType The name of the ParentType
 */
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)                                     \
  static_assert(!ParentType::_type_final, "ParentObj marked as final");                        \
  static uint32_t RuntimeTypeIndex() {                                                         \
    static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 ||    \
                      TypeName::_type_child_slots < ParentType::_type_child_slots,             \
                  "Need to set _type_child_slots when parent specifies it.");                  \
    if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {                        \
      return TypeName::_type_index;                                                            \
    }                                                                                          \
    return _GetOrAllocRuntimeTypeIndex();                                                      \
  }                                                                                            \
  static uint32_t _GetOrAllocRuntimeTypeIndex() {                                              \
    static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex(                               \
        TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \
        TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow);                \
    return tindex;                                                                             \
  }
665
/*!
 * \brief helper macro to declare type information in a final class.
 *
 *  Marks the type as final with zero reserved child slots, then declares the
 *  same helpers as TVM_DECLARE_BASE_OBJECT_INFO. _type_child_slots uses
 *  uint32_t to match the type Object declares for the same field.
 *
 * \param TypeName The name of the current type.
 * \param ParentType The name of the ParentType
 */
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
  static const constexpr bool _type_final = true;           \
  static const constexpr uint32_t _type_child_slots = 0;    \
  TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
675
/*! \brief helper macro to suppress unused warning */
#ifdef __GNUC__
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif

/*!
 * \brief Token-paste two arguments. The two-level form makes the arguments
 *  macro-expanded before pasting (needed for __COUNTER__).
 */
#define TVM_STR_CONCAT_(lhs, rhs) lhs##rhs
#define TVM_STR_CONCAT(lhs, rhs) TVM_STR_CONCAT_(lhs, rhs)

/*! \brief Declaration prefix of the static variables emitted by TVM_REGISTER_OBJECT_TYPE. */
#define TVM_OBJECT_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid

/*!
 * \brief Helper macro to register the object type to runtime.
 *  Makes sure that the runtime type table is correctly populated.
 *  Defines a uniquely named (via __COUNTER__) static variable initialized with
 *  TypeName::_GetOrAllocRuntimeTypeIndex().
 *
 *  Use this macro in the cc file for each terminal class.
 */
#define TVM_REGISTER_OBJECT_TYPE(TypeName) \
  TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = TypeName::_GetOrAllocRuntimeTypeIndex()
696
/*!
 * \brief Explicitly default the copy/move constructors and copy/move assignment operators.
 * \param TypeName The class typename.
 */
#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
  TypeName(const TypeName& other) = default;              \
  TypeName& operator=(const TypeName& other) = default;   \
  TypeName(TypeName&& other) = default;                   \
  TypeName& operator=(TypeName&& other) = default;
706
/*!
 * \brief Define object reference methods.
 * \param TypeName The object reference type name.
 * \param ParentType The parent type of the objectref
 * \param ObjectName The type name of the object.
 * \note The ObjectPtr argument is moved into ParentType to avoid an extra
 *  reference-count increment/decrement pair per construction.
 */
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)                        \
  TypeName() = default;                                                                        \
  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n)                       \
      : ParentType(std::move(n)) {}                                                            \
  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                                           \
  const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
  const ObjectName* get() const { return operator->(); }                                       \
  using ContainerType = ObjectName;
720
/*!
 * \brief Define object reference methods that is not nullable.
 *
 * \param TypeName The object reference type name.
 * \param ParentType The parent type of the objectref
 * \param ObjectName The type name of the object.
 * \note No default constructor is declared, and _type_is_nullable is false.
 *  The ObjectPtr argument is moved into ParentType to avoid extra ref-count traffic.
 */
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)            \
  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n)                       \
      : ParentType(std::move(n)) {}                                                            \
  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                                           \
  const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
  const ObjectName* get() const { return operator->(); }                                       \
  static constexpr bool _type_is_nullable = false;                                             \
  using ContainerType = ObjectName;
735
/*!
 * \brief Define object reference methods of whose content is mutable.
 * \param TypeName The object reference type name.
 * \param ParentType The parent type of the objectref
 * \param ObjectName The type name of the object.
 * \note We recommend making objects immutable when possible.
 *       This macro is only reserved for objects that stores runtime states.
 *       The ObjectPtr argument is moved into ParentType to avoid extra ref-count traffic.
 */
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
  TypeName() = default;                                                         \
  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                            \
  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n)        \
      : ParentType(std::move(n)) {}                                             \
  ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); } \
  using ContainerType = ObjectName;
750
/*!
 * \brief Define object reference methods that is both not nullable and mutable.
 *
 * \param TypeName The object reference type name.
 * \param ParentType The parent type of the objectref
 * \param ObjectName The type name of the object.
 * \note The ObjectPtr argument is moved into ParentType to avoid extra ref-count traffic.
 */
#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n)                   \
      : ParentType(std::move(n)) {}                                                        \
  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                                       \
  ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); }         \
  ObjectName* get() const { return operator->(); }                                         \
  static constexpr bool _type_is_nullable = false;                                         \
  using ContainerType = ObjectName;
765
/*!
 * \brief Define CopyOnWrite function in an ObjectRef.
 * \param ObjectName The Type of the Node.
 *
 *  CopyOnWrite returns a mutable raw pointer to the referenced node. When the
 *  node is shared (the reference is not unique), it is first replaced by a fresh
 *  copy built with make_object<ObjectName>, so the other holders keep the old
 *  value. The reference must be defined (checked with ICHECK).
 *
 * \code
 *
 *  MyCOWObjectRef ref, ref2;
 *  ref2 = ref;
 *  ref.CopyOnWrite()->value = new_value;
 *  assert(ref2->value == old_value);
 *  assert(ref->value == new_value);
 *
 * \endcode
 */
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)                         \
  ObjectName* CopyOnWrite() {                                                \
    ICHECK(data_ != nullptr);                                                \
    if (data_.unique()) {                                                    \
      return static_cast<ObjectName*>(data_.get());                          \
    }                                                                        \
    ObjectPtr<Object>(make_object<ObjectName>(*(operator->()))).swap(data_); \
    return static_cast<ObjectName*>(data_.get());                            \
  }
794
795// Implementations details below
796// Object reference counting.
797#if TVM_OBJECT_ATOMIC_REF_COUNTER
798
799inline void Object::IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); }
800
801inline void Object::DecRef() {
802 if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
803 std::atomic_thread_fence(std::memory_order_acquire);
804 if (this->deleter_ != nullptr) {
805 (*this->deleter_)(this);
806 }
807 }
808}
809
810inline int Object::use_count() const { return ref_counter_.load(std::memory_order_relaxed); }
811
812#else
813
814inline void Object::IncRef() { ++ref_counter_; }
815
816inline void Object::DecRef() {
817 if (--ref_counter_ == 0) {
818 if (this->deleter_ != nullptr) {
819 (*this->deleter_)(this);
820 }
821 }
822}
823
824inline int Object::use_count() const { return ref_counter_; }
825
826#endif // TVM_OBJECT_ATOMIC_REF_COUNTER
827
828template <typename TargetType>
829inline bool Object::IsInstance() const {
830 const Object* self = this;
831 // NOTE: the following code can be optimized by
832 // compiler dead-code elimination for already known constants.
833 if (self != nullptr) {
834 // Everything is a subclass of object.
835 if (std::is_same<TargetType, Object>::value) return true;
836 if (TargetType::_type_final) {
837 // if the target type is a final type
838 // then we only need to check the equivalence.
839 return self->type_index_ == TargetType::RuntimeTypeIndex();
840 } else {
841 // if target type is a non-leaf type
842 // Check if type index falls into the range of reserved slots.
843 uint32_t begin = TargetType::RuntimeTypeIndex();
844 // The condition will be optimized by constant-folding.
845 if (TargetType::_type_child_slots != 0) {
846 uint32_t end = begin + TargetType::_type_child_slots;
847 if (self->type_index_ >= begin && self->type_index_ < end) return true;
848 } else {
849 if (self->type_index_ == begin) return true;
850 }
851 if (!TargetType::_type_child_slots_can_overflow) return false;
852 // Invariance: parent index is always smaller than the child.
853 if (self->type_index_ < TargetType::RuntimeTypeIndex()) return false;
854 // The rare slower-path, check type hierarchy.
855 return self->DerivedFrom(TargetType::RuntimeTypeIndex());
856 }
857 } else {
858 return false;
859 }
860}
861
862inline bool Object::unique() const { return use_count() == 1; }
863
864template <typename ObjectType>
865inline const ObjectType* ObjectRef::as() const {
866 if (data_ != nullptr && data_->IsInstance<ObjectType>()) {
867 return static_cast<ObjectType*>(data_.get());
868 } else {
869 return nullptr;
870 }
871}
872
873template <typename RefType, typename ObjType>
874inline RefType GetRef(const ObjType* ptr) {
875 static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
876 "Can only cast to the ref of same container type");
877 if (!RefType::_type_is_nullable) {
878 ICHECK(ptr != nullptr);
879 }
880 return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
881}
882
883template <typename BaseType, typename ObjType>
884inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
885 static_assert(std::is_base_of<BaseType, ObjType>::value,
886 "Can only cast to the ref of same container type");
887 return ObjectPtr<BaseType>(static_cast<Object*>(ptr));
888}
889
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
  if (!ref.defined()) {
    // A null reference may only be downcast to a nullable reference type.
    ICHECK(SubRef::_type_is_nullable) << "Downcast from nullptr to not nullable reference of "
                                      << SubRef::ContainerType::_type_key;
  } else {
    // A defined reference must actually hold an instance of the target container.
    ICHECK(ref->template IsInstance<typename SubRef::ContainerType>())
        << "Downcast from " << ref->GetTypeKey() << " to " << SubRef::ContainerType::_type_key
        << " failed.";
  }
  return SubRef(std::move(ref.data_));
}
902
903} // namespace runtime
904} // namespace tvm
905
906#endif // TVM_RUNTIME_OBJECT_H_
907