1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * \file tvm/runtime/object.h |
21 | * \brief A managed object in the TVM runtime. |
22 | */ |
23 | #ifndef TVM_RUNTIME_OBJECT_H_ |
24 | #define TVM_RUNTIME_OBJECT_H_ |
25 | |
26 | #include <tvm/runtime/c_runtime_api.h> |
27 | #include <tvm/runtime/logging.h> |
28 | |
29 | #include <string> |
30 | #include <type_traits> |
31 | #include <utility> |
32 | |
/*!
 * \brief Whether to use an atomic reference counter (defaults to 1).
 *
 *  With a non-atomic counter an object cannot be owned by several threads
 *  at the same time; it can still be moved from one thread to another.
 *  Define this macro to 0 before including this header to opt out.
 */
#if !defined(TVM_OBJECT_ATOMIC_REF_COUNTER)
#define TVM_OBJECT_ATOMIC_REF_COUNTER 1
#endif
42 | |
43 | #if TVM_OBJECT_ATOMIC_REF_COUNTER |
44 | #include <atomic> |
45 | #endif // TVM_OBJECT_ATOMIC_REF_COUNTER |
46 | |
47 | namespace tvm { |
48 | namespace runtime { |
49 | |
/*!
 * \brief Namespace for the list of type index.
 * \note Use struct so that we have to use TypeIndex::EnumName to refer to
 *       the constant, but still able to use enum.
 * \note The numeric values below are part of the ABI shared with frontends;
 *       entries without an explicit value take the previous value + 1, so
 *       the declaration order matters.
 */
struct TypeIndex {
  enum {
    /*! \brief Root object type. */
    kRoot = 0,
    // Standard static index assignments,
    // Frontends can take benefit of these constants.
    /*! \brief runtime::Module. */
    kRuntimeModule = 1,
    /*! \brief runtime::NDArray. */
    kRuntimeNDArray = 2,
    /*! \brief runtime::String. */
    kRuntimeString = 3,
    /*! \brief runtime::Array. */
    kRuntimeArray = 4,
    /*! \brief runtime::Map. */
    kRuntimeMap = 5,
    /*! \brief runtime::ShapeTuple. */
    kRuntimeShapeTuple = 6,
    /*! \brief runtime::PackedFunc. */
    kRuntimePackedFunc = 7,
    // static assignments that may subject to change.
    /*! \brief Implicitly 8 (kRuntimePackedFunc + 1); may change. */
    kRuntimeClosure,
    /*! \brief Implicitly 9; may change. */
    kRuntimeADT,
    /*! \brief One past the last statically assigned index (implicitly 10). */
    kStaticIndexEnd,
    /*!
     * \brief Type index is allocated during runtime.
     *  Used as the _type_index value that requests runtime allocation;
     *  numerically equal to kStaticIndexEnd.
     */
    kDynamic = kStaticIndexEnd
  };
};  // struct TypeIndex
83 | |
84 | /*! |
85 | * \brief base class of all object containers. |
86 | * |
87 | * Sub-class of objects should declare the following static constexpr fields: |
88 | * |
89 | * - _type_index: |
90 | * Static type index of the object, if assigned to TypeIndex::kDynamic |
91 | * the type index will be assigned during runtime. |
92 | * Runtime type index can be accessed by ObjectType::TypeIndex(); |
93 | * - _type_key: |
94 | * The unique string identifier of the type. |
95 | * - _type_final: |
96 | * Whether the type is terminal type(there is no subclass of the type in the object system). |
97 | * This field is automatically set by macro TVM_DECLARE_FINAL_OBJECT_INFO |
98 | * It is still OK to sub-class a terminal object type T and construct it using make_object. |
99 | * But IsInstance check will only show that the object type is T(instead of the sub-class). |
100 | * |
101 | * The following two fields are necessary for base classes that can be sub-classed. |
102 | * |
103 | * - _type_child_slots: |
104 | * Number of reserved type index slots for child classes. |
105 | * Used for runtime optimization for type checking in IsInstance. |
106 | * If an object's type_index is within range of [type_index, type_index + _type_child_slots] |
107 | * Then the object can be quickly decided as sub-class of the current object class. |
108 | * If not, a fallback mechanism is used to check the global type table. |
109 | * Recommendation: set to estimate number of children needed. |
110 | * - _type_child_slots_can_overflow: |
111 | * Whether we can add additional child classes even if the number of child classes |
112 | * exceeds the _type_child_slots. A fallback mechanism to check global type table will be |
113 | * used. Recommendation: set to false for optimal runtime speed if we know exact number of children. |
114 | * |
115 | * Two macros are used to declare helper functions in the object: |
116 | * - Use TVM_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed. |
117 | * - Use TVM_DECLARE_FINAL_OBJECT_INFO for object classes that cannot be sub-classed. |
118 | * |
119 | * New objects can be created using make_object function. |
120 | * Which will automatically populate the type_index and deleter of the object. |
121 | * |
122 | * \sa make_object |
123 | * \sa ObjectPtr |
124 | * \sa ObjectRef |
125 | * |
126 | * \code |
127 | * |
128 | * // Create a base object |
129 | * class BaseObj : public Object { |
130 | * public: |
131 | * // object fields |
132 | * int field0; |
133 | * |
134 | * // object properties |
135 | * static constexpr const uint32_t _type_index = TypeIndex::kDynamic; |
136 | * static constexpr const char* _type_key = "test.BaseObj"; |
137 | * TVM_DECLARE_BASE_OBJECT_INFO(BaseObj, Object); |
138 | * }; |
139 | * |
140 | * class LeafObj : public BaseObj { |
141 | * public: |
142 | * // fields |
143 | * int child_field0; |
144 | * // object properties |
145 | * static constexpr const uint32_t _type_index = TypeIndex::kDynamic; |
146 | * static constexpr const char* _type_key = "test.LeafObj"; |
147 | * TVM_DECLARE_BASE_OBJECT_INFO(LeafObj, Object); |
148 | * }; |
149 | * |
150 | * // The following code should be put into a cc file. |
151 | * TVM_REGISTER_OBJECT_TYPE(BaseObj); |
152 | * TVM_REGISTER_OBJECT_TYPE(LeafObj); |
153 | * |
154 | * // Usage example. |
155 | * void TestObjects() { |
156 | * // create an object |
157 | * ObjectRef leaf_ref(make_object<LeafObj>()); |
158 | * // cast to a specific instance |
159 | * const LeafObj* leaf_ptr = leaf_ref.as<LeafObj>(); |
160 | * ICHECK(leaf_ptr != nullptr); |
161 | * // can also cast to the base class. |
162 | * ICHECK(leaf_ref.as<BaseObj>() != nullptr); |
163 | * } |
164 | * |
165 | * \endcode |
166 | */ |
167 | class TVM_DLL Object { |
168 | public: |
169 | /*! |
170 | * \brief Object deleter |
171 | * \param self pointer to the Object. |
172 | */ |
173 | typedef void (*FDeleter)(Object* self); |
174 | /*! \return The internal runtime type index of the object. */ |
175 | uint32_t type_index() const { return type_index_; } |
176 | /*! |
177 | * \return the type key of the object. |
178 | * \note this operation is expensive, can be used for error reporting. |
179 | */ |
180 | std::string GetTypeKey() const { return TypeIndex2Key(type_index_); } |
181 | /*! |
182 | * \return A hash value of the return of GetTypeKey. |
183 | */ |
184 | size_t GetTypeKeyHash() const { return TypeIndex2KeyHash(type_index_); } |
185 | /*! |
186 | * Check if the object is an instance of TargetType. |
187 | * \tparam TargetType The target type to be checked. |
188 | * \return Whether the target type is true. |
189 | */ |
190 | template <typename TargetType> |
191 | inline bool IsInstance() const; |
192 | /*! |
193 | * \return Whether the cell has only one reference |
194 | * \note We use stl style naming to be consistent with known API in shared_ptr. |
195 | */ |
196 | inline bool unique() const; |
197 | /*! |
198 | * \brief Get the type key of the corresponding index from runtime. |
199 | * \param tindex The type index. |
200 | * \return the result. |
201 | */ |
202 | static std::string TypeIndex2Key(uint32_t tindex); |
203 | /*! |
204 | * \brief Get the type key hash of the corresponding index from runtime. |
205 | * \param tindex The type index. |
206 | * \return the related key-hash. |
207 | */ |
208 | static size_t TypeIndex2KeyHash(uint32_t tindex); |
209 | /*! |
210 | * \brief Get the type index of the corresponding key from runtime. |
211 | * \param key The type key. |
212 | * \return the result. |
213 | */ |
214 | static uint32_t TypeKey2Index(const std::string& key); |
215 | |
216 | #if TVM_OBJECT_ATOMIC_REF_COUNTER |
217 | using RefCounterType = std::atomic<int32_t>; |
218 | #else |
219 | using RefCounterType = int32_t; |
220 | #endif |
221 | |
222 | static constexpr const char* _type_key = "runtime.Object" ; |
223 | |
224 | static uint32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kRoot; } |
225 | static uint32_t RuntimeTypeIndex() { return TypeIndex::kRoot; } |
226 | |
227 | // Default object type properties for sub-classes |
228 | static constexpr bool _type_final = false; |
229 | static constexpr uint32_t _type_child_slots = 0; |
230 | static constexpr bool _type_child_slots_can_overflow = true; |
231 | // member information |
232 | static constexpr bool _type_has_method_visit_attrs = true; |
233 | static constexpr bool _type_has_method_sequal_reduce = false; |
234 | static constexpr bool _type_has_method_shash_reduce = false; |
235 | // NOTE: the following field is not type index of Object |
236 | // but was intended to be used by sub-classes as default value. |
237 | // The type index of Object is TypeIndex::kRoot |
238 | static constexpr uint32_t _type_index = TypeIndex::kDynamic; |
239 | |
240 | // Default constructor and copy constructor |
241 | Object() {} |
242 | // Override the copy and assign constructors to do nothing. |
243 | // This is to make sure only contents, but not deleter and ref_counter |
244 | // are copied when a child class copies itself. |
245 | // This will enable us to use make_object<ObjectClass>(*obj_ptr) |
246 | // to copy an existing object. |
247 | Object(const Object& other) { // NOLINT(*) |
248 | } |
249 | Object(Object&& other) { // NOLINT(*) |
250 | } |
251 | Object& operator=(const Object& other) { // NOLINT(*) |
252 | return *this; |
253 | } |
254 | Object& operator=(Object&& other) { // NOLINT(*) |
255 | return *this; |
256 | } |
257 | |
258 | protected: |
259 | // The fields of the base object cell. |
260 | /*! \brief Type index(tag) that indicates the type of the object. */ |
261 | uint32_t type_index_{0}; |
262 | /*! \brief The internal reference counter */ |
263 | RefCounterType ref_counter_{0}; |
264 | /*! |
265 | * \brief deleter of this object to enable customized allocation. |
266 | * If the deleter is nullptr, no deletion will be performed. |
267 | * The creator of the object must always set the deleter field properly. |
268 | */ |
269 | FDeleter deleter_ = nullptr; |
270 | // Invariant checks. |
271 | static_assert(sizeof(int32_t) == sizeof(RefCounterType) && |
272 | alignof(int32_t) == sizeof(RefCounterType), |
273 | "RefCounter ABI check." ); |
274 | |
275 | /*! |
276 | * \brief Get the type index using type key. |
277 | * |
278 | * When the function is first time called for a type, |
279 | * it will register the type to the type table in the runtime. |
280 | * If the static_tindex is TypeIndex::kDynamic, the function will |
281 | * allocate a runtime type index. |
282 | * Otherwise, we will populate the type table and return the static index. |
283 | * |
284 | * \param key the type key. |
285 | * \param static_tindex The current _type_index field. |
286 | * can be TypeIndex::kDynamic. |
287 | * \param parent_tindex The index of the parent. |
288 | * \param type_child_slots Number of slots reserved for its children. |
289 | * \param type_child_slots_can_overflow Whether to allow child to overflow the slots. |
290 | * \return The allocated type index. |
291 | */ |
292 | static uint32_t GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex, |
293 | uint32_t parent_tindex, uint32_t type_child_slots, |
294 | bool type_child_slots_can_overflow); |
295 | |
296 | // reference counter related operations |
297 | /*! \brief developer function, increases reference counter. */ |
298 | inline void IncRef(); |
299 | /*! |
300 | * \brief developer function, decrease reference counter. |
301 | * \note The deleter will be called when ref_counter_ becomes zero. |
302 | */ |
303 | inline void DecRef(); |
304 | |
305 | private: |
306 | /*! |
307 | * \return The usage count of the cell. |
308 | * \note We use stl style naming to be consistent with known API in shared_ptr. |
309 | */ |
310 | inline int use_count() const; |
311 | /*! |
312 | * \brief Check of this object is derived from the parent. |
313 | * \param parent_tindex The parent type index. |
314 | * \return The derivation results. |
315 | */ |
316 | bool DerivedFrom(uint32_t parent_tindex) const; |
317 | // friend classes |
318 | template <typename> |
319 | friend class ObjAllocatorBase; |
320 | template <typename> |
321 | friend class ObjectPtr; |
322 | friend class TVMRetValue; |
323 | friend class ObjectInternal; |
324 | }; |
325 | |
/*!
 * \brief Get a reference type from a raw object ptr type
 *
 *  It is always important to get a reference type
 *  if we want to return a value as reference or keep
 *  the object alive beyond the scope of the function.
 *
 * \param ptr The object pointer
 * \tparam RelayRefType The reference type to produce
 * \tparam ObjectType The object type
 * \return The corresponding RelayRefType
 * \note Declared a friend of ObjectPtr so it can use ObjectPtr's private
 *       raw-pointer constructor.
 */
template <typename RelayRefType, typename ObjectType>
inline RelayRefType GetRef(const ObjectType* ptr);
340 | |
/*!
 * \brief Downcast a base reference type to a more specific type.
 *
 * \tparam SubRef The target specific reference type.
 * \tparam BaseRef the current reference type.
 * \param ref The input reference
 * \return The corresponding SubRef.
 * \note ObjectRef declares this function a friend, so its implementation may
 *       access ObjectRef's protected members.
 */
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref);
351 | |
352 | /*! |
353 | * \brief A custom smart pointer for Object. |
354 | * \tparam T the content data type. |
355 | * \sa make_object |
356 | */ |
357 | template <typename T> |
358 | class ObjectPtr { |
359 | public: |
360 | /*! \brief default constructor */ |
361 | ObjectPtr() {} |
362 | /*! \brief default constructor */ |
363 | ObjectPtr(std::nullptr_t) {} // NOLINT(*) |
364 | /*! |
365 | * \brief copy constructor |
366 | * \param other The value to be moved |
367 | */ |
368 | ObjectPtr(const ObjectPtr<T>& other) // NOLINT(*) |
369 | : ObjectPtr(other.data_) {} |
370 | /*! |
371 | * \brief copy constructor |
372 | * \param other The value to be moved |
373 | */ |
374 | template <typename U> |
375 | ObjectPtr(const ObjectPtr<U>& other) // NOLINT(*) |
376 | : ObjectPtr(other.data_) { |
377 | static_assert(std::is_base_of<T, U>::value, |
378 | "can only assign of child class ObjectPtr to parent" ); |
379 | } |
380 | /*! |
381 | * \brief move constructor |
382 | * \param other The value to be moved |
383 | */ |
384 | ObjectPtr(ObjectPtr<T>&& other) // NOLINT(*) |
385 | : data_(other.data_) { |
386 | other.data_ = nullptr; |
387 | } |
388 | /*! |
389 | * \brief move constructor |
390 | * \param other The value to be moved |
391 | */ |
392 | template <typename Y> |
393 | ObjectPtr(ObjectPtr<Y>&& other) // NOLINT(*) |
394 | : data_(other.data_) { |
395 | static_assert(std::is_base_of<T, Y>::value, |
396 | "can only assign of child class ObjectPtr to parent" ); |
397 | other.data_ = nullptr; |
398 | } |
399 | /*! \brief destructor */ |
400 | ~ObjectPtr() { this->reset(); } |
401 | /*! |
402 | * \brief Swap this array with another Object |
403 | * \param other The other Object |
404 | */ |
405 | void swap(ObjectPtr<T>& other) { // NOLINT(*) |
406 | std::swap(data_, other.data_); |
407 | } |
408 | /*! |
409 | * \return Get the content of the pointer |
410 | */ |
411 | T* get() const { return static_cast<T*>(data_); } |
412 | /*! |
413 | * \return The pointer |
414 | */ |
415 | T* operator->() const { return get(); } |
416 | /*! |
417 | * \return The reference |
418 | */ |
419 | T& operator*() const { // NOLINT(*) |
420 | return *get(); |
421 | } |
422 | /*! |
423 | * \brief copy assignment |
424 | * \param other The value to be assigned. |
425 | * \return reference to self. |
426 | */ |
427 | ObjectPtr<T>& operator=(const ObjectPtr<T>& other) { // NOLINT(*) |
428 | // takes in plane operator to enable copy elison. |
429 | // copy-and-swap idiom |
430 | ObjectPtr(other).swap(*this); // NOLINT(*) |
431 | return *this; |
432 | } |
433 | /*! |
434 | * \brief move assignment |
435 | * \param other The value to be assigned. |
436 | * \return reference to self. |
437 | */ |
438 | ObjectPtr<T>& operator=(ObjectPtr<T>&& other) { // NOLINT(*) |
439 | // copy-and-swap idiom |
440 | ObjectPtr(std::move(other)).swap(*this); // NOLINT(*) |
441 | return *this; |
442 | } |
443 | /*! |
444 | * \brief nullptr check |
445 | * \return result of comparison of internal pointer with nullptr. |
446 | */ |
447 | explicit operator bool() const { return get() != nullptr; } |
448 | /*! \brief reset the content of ptr to be nullptr */ |
449 | void reset() { |
450 | if (data_ != nullptr) { |
451 | data_->DecRef(); |
452 | data_ = nullptr; |
453 | } |
454 | } |
455 | /*! \return The use count of the ptr, for debug purposes */ |
456 | int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } |
457 | /*! \return whether the reference is unique */ |
458 | bool unique() const { return data_ != nullptr && data_->use_count() == 1; } |
459 | /*! \return Whether two ObjectPtr do not equal each other */ |
460 | bool operator==(const ObjectPtr<T>& other) const { return data_ == other.data_; } |
461 | /*! \return Whether two ObjectPtr equals each other */ |
462 | bool operator!=(const ObjectPtr<T>& other) const { return data_ != other.data_; } |
463 | /*! \return Whether the pointer is nullptr */ |
464 | bool operator==(std::nullptr_t null) const { return data_ == nullptr; } |
465 | /*! \return Whether the pointer is not nullptr */ |
466 | bool operator!=(std::nullptr_t null) const { return data_ != nullptr; } |
467 | |
468 | private: |
469 | /*! \brief internal pointer field */ |
470 | Object* data_{nullptr}; |
471 | /*! |
472 | * \brief constructor from Object |
473 | * \param data The data pointer |
474 | */ |
475 | explicit ObjectPtr(Object* data) : data_(data) { |
476 | if (data != nullptr) { |
477 | data_->IncRef(); |
478 | } |
479 | } |
480 | /*! |
481 | * \brief Move an ObjectPtr from an RValueRef argument. |
482 | * \param ref The rvalue reference. |
483 | * \return the moved result. |
484 | */ |
485 | static ObjectPtr<T> MoveFromRValueRefArg(Object** ref) { |
486 | ObjectPtr<T> ptr; |
487 | ptr.data_ = *ref; |
488 | *ref = nullptr; |
489 | return ptr; |
490 | } |
491 | // friend classes |
492 | friend class Object; |
493 | friend class ObjectRef; |
494 | friend struct ObjectPtrHash; |
495 | template <typename> |
496 | friend class ObjectPtr; |
497 | template <typename> |
498 | friend class ObjAllocatorBase; |
499 | friend class TVMPODValue_; |
500 | friend class TVMArgsSetter; |
501 | friend class TVMRetValue; |
502 | friend class TVMArgValue; |
503 | friend class TVMMovableArgValue_; |
504 | template <typename RelayRefType, typename ObjType> |
505 | friend RelayRefType GetRef(const ObjType* ptr); |
506 | template <typename BaseType, typename ObjType> |
507 | friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr); |
508 | }; |
509 | |
510 | /*! \brief Base class of all object reference */ |
511 | class ObjectRef { |
512 | public: |
513 | /*! \brief default constructor */ |
514 | ObjectRef() = default; |
515 | /*! \brief Constructor from existing object ptr */ |
516 | explicit ObjectRef(ObjectPtr<Object> data) : data_(data) {} |
517 | /*! |
518 | * \brief Comparator |
519 | * \param other Another object ref. |
520 | * \return the compare result. |
521 | */ |
522 | bool same_as(const ObjectRef& other) const { return data_ == other.data_; } |
523 | /*! |
524 | * \brief Comparator |
525 | * \param other Another object ref. |
526 | * \return the compare result. |
527 | */ |
528 | bool operator==(const ObjectRef& other) const { return data_ == other.data_; } |
529 | /*! |
530 | * \brief Comparator |
531 | * \param other Another object ref. |
532 | * \return the compare result. |
533 | */ |
534 | bool operator!=(const ObjectRef& other) const { return data_ != other.data_; } |
535 | /*! |
536 | * \brief Comparator |
537 | * \param other Another object ref by address. |
538 | * \return the compare result. |
539 | */ |
540 | bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); } |
541 | /*! |
542 | * \return whether the object is defined(not null). |
543 | */ |
544 | bool defined() const { return data_ != nullptr; } |
545 | /*! \return the internal object pointer */ |
546 | const Object* get() const { return data_.get(); } |
547 | /*! \return the internal object pointer */ |
548 | const Object* operator->() const { return get(); } |
549 | /*! \return whether the reference is unique */ |
550 | bool unique() const { return data_.unique(); } |
551 | /*! \return The use count of the ptr, for debug purposes */ |
552 | int use_count() const { return data_.use_count(); } |
553 | /*! |
554 | * \brief Try to downcast the internal Object to a |
555 | * raw pointer of a corresponding type. |
556 | * |
557 | * The function will return a nullptr if the cast failed. |
558 | * |
559 | * if (const Add *add = node_ref.As<Add>()) { |
560 | * // This is an add node |
561 | * } |
562 | * \tparam ObjectType the target type, must be a subtype of Object/ |
563 | */ |
564 | template <typename ObjectType> |
565 | inline const ObjectType* as() const; |
566 | |
567 | /*! \brief type indicate the container type. */ |
568 | using ContainerType = Object; |
569 | // Default type properties for the reference class. |
570 | static constexpr bool _type_is_nullable = true; |
571 | |
572 | protected: |
573 | /*! \brief Internal pointer that backs the reference. */ |
574 | ObjectPtr<Object> data_; |
575 | /*! \return return a mutable internal ptr, can be used by sub-classes. */ |
576 | Object* get_mutable() const { return data_.get(); } |
577 | /*! |
578 | * \brief Internal helper function downcast a ref without check. |
579 | * \note Only used for internal dev purposes. |
580 | * \tparam T The target reference type. |
581 | * \return The casted result. |
582 | */ |
583 | template <typename T> |
584 | static T DowncastNoCheck(ObjectRef ref) { |
585 | return T(std::move(ref.data_)); |
586 | } |
587 | /*! |
588 | * \brief Clear the object ref data field without DecRef |
589 | * after we successfully moved the field. |
590 | * \param ref The reference data. |
591 | */ |
592 | static void FFIClearAfterMove(ObjectRef* ref) { ref->data_.data_ = nullptr; } |
593 | /*! |
594 | * \brief Internal helper function get data_ as ObjectPtr of ObjectType. |
595 | * \note only used for internal dev purpose. |
596 | * \tparam ObjectType The corresponding object type. |
597 | * \return the corresponding type. |
598 | */ |
599 | template <typename ObjectType> |
600 | static ObjectPtr<ObjectType> GetDataPtr(const ObjectRef& ref) { |
601 | return ObjectPtr<ObjectType>(ref.data_.data_); |
602 | } |
603 | // friend classes. |
604 | friend struct ObjectPtrHash; |
605 | friend class TVMRetValue; |
606 | friend class TVMArgsSetter; |
607 | friend class ObjectInternal; |
608 | template <typename SubRef, typename BaseRef> |
609 | friend SubRef Downcast(BaseRef ref); |
610 | }; |
611 | |
/*!
 * \brief Get an object ptr type from a raw object ptr.
 *
 * \param ptr The object pointer
 * \tparam BaseType The content type of the returned ObjectPtr
 * \tparam ObjectType The object type of ptr
 * \return An ObjectPtr<BaseType> referring to the object pointed to by ptr
 * \note Declared a friend of ObjectPtr so it can use ObjectPtr's private
 *       raw-pointer constructor.
 */
template <typename BaseType, typename ObjectType>
inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);
622 | |
623 | /*! \brief ObjectRef hash functor */ |
624 | struct ObjectPtrHash { |
625 | size_t operator()(const ObjectRef& a) const { return operator()(a.data_); } |
626 | |
627 | template <typename T> |
628 | size_t operator()(const ObjectPtr<T>& a) const { |
629 | return std::hash<Object*>()(a.get()); |
630 | } |
631 | }; |
632 | |
633 | /*! \brief ObjectRef equal functor */ |
634 | struct ObjectPtrEqual { |
635 | bool operator()(const ObjectRef& a, const ObjectRef& b) const { return a.same_as(b); } |
636 | |
637 | template <typename T> |
638 | size_t operator()(const ObjectPtr<T>& a, const ObjectPtr<T>& b) const { |
639 | return a == b; |
640 | } |
641 | }; |
642 | |
/*!
 * \brief helper macro to declare a base object type that can be inherited.
 *
 *  Expands, inside the class body, to:
 *  - a static_assert that ParentType is not marked final;
 *  - RuntimeTypeIndex(): returns the static _type_index when it is not
 *    TypeIndex::kDynamic, otherwise defers to _GetOrAllocRuntimeTypeIndex();
 *    it also static_asserts that, when both the type and its parent reserve
 *    child slots, the type reserves fewer slots than its parent;
 *  - _GetOrAllocRuntimeTypeIndex(): registers the type via
 *    Object::GetOrAllocRuntimeTypeIndex (evaluating the parent's
 *    _GetOrAllocRuntimeTypeIndex() first, so the parent is registered
 *    before the child) and caches the result in a function-local static.
 *
 *  Fields not declared by the class (_type_index, _type_child_slots,
 *  _type_child_slots_can_overflow) are found on its ancestors by name
 *  lookup, ultimately falling back to the defaults declared on Object.
 *  The class should declare its own _type_key; otherwise the parent's key
 *  would be picked up.
 *
 * \param TypeName The name of the current type.
 * \param ParentType The name of the ParentType
 */
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)                                     \
  static_assert(!ParentType::_type_final, "ParentObj marked as final");                       \
  static uint32_t RuntimeTypeIndex() {                                                         \
    static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 ||   \
                      TypeName::_type_child_slots < ParentType::_type_child_slots,             \
                  "Need to set _type_child_slots when parent specifies it.");                  \
    if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {                        \
      return TypeName::_type_index;                                                            \
    }                                                                                          \
    return _GetOrAllocRuntimeTypeIndex();                                                      \
  }                                                                                            \
  static uint32_t _GetOrAllocRuntimeTypeIndex() {                                              \
    static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex(                               \
        TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \
        TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow);                \
    return tindex;                                                                             \
  }
665 | |
/*!
 * \brief helper macro to declare type information in a final class.
 *
 *  Marks the type as final (_type_final = true), reserves no child slots,
 *  and then expands TVM_DECLARE_BASE_OBJECT_INFO for the type.
 *
 * \param TypeName The name of the current type.
 * \param ParentType The name of the ParentType
 * \note _type_child_slots is declared uint32_t, the same type as the
 *       default declared on Object and the parameter type of
 *       Object::GetOrAllocRuntimeTypeIndex.
 */
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
  static const constexpr bool _type_final = true;           \
  static const constexpr uint32_t _type_child_slots = 0;    \
  TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
675 | |
/*!
 * \brief Attribute that silences "unused" warnings.
 *  Expands to __attribute__((unused)) on GCC-compatible compilers
 *  (those defining __GNUC__) and to nothing elsewhere.
 */
#ifdef __GNUC__
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif
682 | |
/*!
 * \brief Token-paste two arguments after macro-expanding them.
 *
 *  TVM_STR_CONCAT_ pastes its arguments directly. Arguments adjacent to ##
 *  are not macro-expanded, so TVM_STR_CONCAT routes through an extra level
 *  to expand arguments such as __COUNTER__ or __LINE__ before pasting.
 * \note Parameter names avoid double underscores, which are reserved for
 *       the implementation in C++.
 */
#define TVM_STR_CONCAT_(lhs, rhs) lhs##rhs
#define TVM_STR_CONCAT(lhs, rhs) TVM_STR_CONCAT_(lhs, rhs)
685 | |
/*!
 * \brief Declaration prefix for the static registration variables emitted by
 *        TVM_REGISTER_OBJECT_TYPE, which appends a unique suffix to the
 *        variable name via TVM_STR_CONCAT(..., __COUNTER__).
 * \note NOTE(review): identifiers containing a double underscore are reserved
 *       for the implementation in C++; consider a different variable prefix.
 */
#define TVM_OBJECT_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid
687 | |
/*!
 * \brief Helper macro to register the object type to runtime.
 *  Makes sure that the runtime type table is correctly populated.
 *
 *  Defines a uniquely named static uint32_t whose initializer calls
 *  TypeName::_GetOrAllocRuntimeTypeIndex(), so registration happens as a
 *  side effect of initializing that variable.
 *
 *  Use this macro in the cc file for each object class to be registered
 *  (base classes as well as terminal ones).
 * \param TypeName The object class to register.
 */
#define TVM_REGISTER_OBJECT_TYPE(TypeName) \
  TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = TypeName::_GetOrAllocRuntimeTypeIndex()
696 | |
/*!
 * \brief Explicitly default the copy/move constructors and the copy/move
 *        assignment operators of a class.
 * \param TypeName The class typename.
 * \note The expansion already ends with a semicolon.
 */
#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
  TypeName(TypeName&& rhs) = default;                     \
  TypeName(const TypeName& rhs) = default;                \
  TypeName& operator=(TypeName&& rhs) = default;          \
  TypeName& operator=(const TypeName& rhs) = default;
706 | |
/*!
 * \brief Define object reference methods.
 * \param TypeName The object reference type name
 * \param ParentType The parent type of the objectref
 * \param ObjectName The type name of the object.
 * \note The ObjectPtr constructor argument is taken by value and moved into
 *       ParentType, avoiding an extra reference-count increment/decrement.
 */
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)                        \
  TypeName() = default;                                                                        \
  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n)                       \
      : ParentType(::std::move(n)) {}                                                          \
  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                                           \
  const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
  const ObjectName* get() const { return operator->(); }                                       \
  using ContainerType = ObjectName;
720 | |
/*!
 * \brief Define object reference methods that is not nullable.
 *
 *  Unlike TVM_DEFINE_OBJECT_REF_METHODS, no default constructor is declared
 *  and _type_is_nullable is set to false.
 *
 * \param TypeName The object reference type name
 * \param ParentType The parent type of the objectref
 * \param ObjectName The type name of the object.
 * \note The ObjectPtr constructor argument is taken by value and moved into
 *       ParentType, avoiding an extra reference-count increment/decrement.
 */
#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)            \
  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n)                       \
      : ParentType(::std::move(n)) {}                                                          \
  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                                           \
  const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
  const ObjectName* get() const { return operator->(); }                                       \
  static constexpr bool _type_is_nullable = false;                                             \
  using ContainerType = ObjectName;
735 | |
/*!
 * \brief Define object reference methods of whose content is mutable.
 * \param TypeName The object reference type name
 * \param ParentType The parent type of the objectref
 * \param ObjectName The type name of the object.
 * \note We recommend making objects immutable when possible.
 *       This macro is only reserved for objects that stores runtime states.
 * \note The ObjectPtr constructor argument is taken by value and moved into
 *       ParentType, avoiding an extra reference-count increment/decrement.
 */
#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)     \
  TypeName() = default;                                                             \
  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                                \
  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n)            \
      : ParentType(::std::move(n)) {}                                               \
  ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); } \
  using ContainerType = ObjectName;
750 | |
/*!
 * \brief Define object reference methods that is both not nullable and mutable.
 *
 * \param TypeName The object reference type name
 * \param ParentType The parent type of the objectref
 * \param ObjectName The type name of the object.
 * \note The ObjectPtr constructor argument is taken by value and moved into
 *       ParentType, avoiding an extra reference-count increment/decrement.
 */
#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
  explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n)                  \
      : ParentType(::std::move(n)) {}                                                     \
  TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                                      \
  ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); }       \
  ObjectName* get() const { return operator->(); }                                        \
  static constexpr bool _type_is_nullable = false;                                        \
  using ContainerType = ObjectName;
765 | |
/*!
 * \brief Define CopyOnWrite function in an ObjectRef.
 * \param ObjectName The Type of the Node.
 *
 *  CopyOnWrite will generate a unique copy of the internal node.
 *  The node will be copied if it is referenced by multiple places.
 *  The function returns the raw pointer to the node to allow modification
 *  of the content.
 *
 * \note The generated method ICHECKs that the reference is defined. When the
 *       node is shared, it builds the copy with make_object<ObjectName>
 *       (copy-constructing from the current node) and swaps it into data_.
 * \note make_object and ObjectPtr are used unqualified, so they must be
 *       visible at the point where the macro is expanded; the enclosing class
 *       must also provide operator->() returning the node.
 *
 * \code
 *
 *  MyCOWObjectRef ref, ref2;
 *  ref2 = ref;
 *  ref.CopyOnWrite()->value = new_value;
 *  assert(ref2->value == old_value);
 *  assert(ref->value == new_value);
 *
 * \endcode
 */
#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)     \
  ObjectName* CopyOnWrite() {                            \
    ICHECK(data_ != nullptr);                            \
    if (!data_.unique()) {                               \
      auto n = make_object<ObjectName>(*(operator->())); \
      ObjectPtr<Object>(std::move(n)).swap(data_);       \
    }                                                    \
    return static_cast<ObjectName*>(data_.get());        \
  }
794 | |
// Implementation details below.
// Object reference counting.
#if TVM_OBJECT_ATOMIC_REF_COUNTER

/*!
 * \brief Atomically increment the reference count.
 *
 * Relaxed ordering: the increment itself publishes no data. This presumes callers
 * only IncRef while already holding a reference to the object
 * (NOTE(review): confirm this holds for every caller).
 */
inline void Object::IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); }

/*!
 * \brief Atomically decrement the reference count and invoke deleter_ when it hits zero.
 *
 * fetch_sub returns the value held before the decrement, so a result of 1 means
 * this call dropped the last reference. Every decrement is a release operation.
 * Only the thread that observed the final decrement issues the acquire fence.
 * Together these make writes that other owners performed before their own DecRef
 * visible before the deleter runs. When deleter_ is null, nothing is invoked.
 */
inline void Object::DecRef() {
  if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
    std::atomic_thread_fence(std::memory_order_acquire);
    if (this->deleter_ != nullptr) {
      (*this->deleter_)(this);
    }
  }
}

/*!
 * \brief Current reference count.
 * \note Relaxed load: if other threads hold references concurrently, the value
 *  may already be stale by the time the caller uses it.
 */
inline int Object::use_count() const { return ref_counter_.load(std::memory_order_relaxed); }

#else

// Non-atomic variant. Per the TVM_OBJECT_ATOMIC_REF_COUNTER note at the top of
// this file, an object must not be owned by multiple threads in this configuration.

/*! \brief Increment the reference count (no synchronization). */
inline void Object::IncRef() { ++ref_counter_; }

/*! \brief Decrement the reference count; invoke deleter_ (if set) when it reaches zero. */
inline void Object::DecRef() {
  if (--ref_counter_ == 0) {
    if (this->deleter_ != nullptr) {
      (*this->deleter_)(this);
    }
  }
}

/*! \brief Current reference count. */
inline int Object::use_count() const { return ref_counter_; }

#endif  // TVM_OBJECT_ATOMIC_REF_COUNTER
827 | |
828 | template <typename TargetType> |
829 | inline bool Object::IsInstance() const { |
830 | const Object* self = this; |
831 | // NOTE: the following code can be optimized by |
832 | // compiler dead-code elimination for already known constants. |
833 | if (self != nullptr) { |
834 | // Everything is a subclass of object. |
835 | if (std::is_same<TargetType, Object>::value) return true; |
836 | if (TargetType::_type_final) { |
837 | // if the target type is a final type |
838 | // then we only need to check the equivalence. |
839 | return self->type_index_ == TargetType::RuntimeTypeIndex(); |
840 | } else { |
841 | // if target type is a non-leaf type |
842 | // Check if type index falls into the range of reserved slots. |
843 | uint32_t begin = TargetType::RuntimeTypeIndex(); |
844 | // The condition will be optimized by constant-folding. |
845 | if (TargetType::_type_child_slots != 0) { |
846 | uint32_t end = begin + TargetType::_type_child_slots; |
847 | if (self->type_index_ >= begin && self->type_index_ < end) return true; |
848 | } else { |
849 | if (self->type_index_ == begin) return true; |
850 | } |
851 | if (!TargetType::_type_child_slots_can_overflow) return false; |
852 | // Invariance: parent index is always smaller than the child. |
853 | if (self->type_index_ < TargetType::RuntimeTypeIndex()) return false; |
854 | // The rare slower-path, check type hierarchy. |
855 | return self->DerivedFrom(TargetType::RuntimeTypeIndex()); |
856 | } |
857 | } else { |
858 | return false; |
859 | } |
860 | } |
861 | |
862 | inline bool Object::unique() const { return use_count() == 1; } |
863 | |
864 | template <typename ObjectType> |
865 | inline const ObjectType* ObjectRef::as() const { |
866 | if (data_ != nullptr && data_->IsInstance<ObjectType>()) { |
867 | return static_cast<ObjectType*>(data_.get()); |
868 | } else { |
869 | return nullptr; |
870 | } |
871 | } |
872 | |
873 | template <typename RefType, typename ObjType> |
874 | inline RefType GetRef(const ObjType* ptr) { |
875 | static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value, |
876 | "Can only cast to the ref of same container type" ); |
877 | if (!RefType::_type_is_nullable) { |
878 | ICHECK(ptr != nullptr); |
879 | } |
880 | return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr)))); |
881 | } |
882 | |
883 | template <typename BaseType, typename ObjType> |
884 | inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) { |
885 | static_assert(std::is_base_of<BaseType, ObjType>::value, |
886 | "Can only cast to the ref of same container type" ); |
887 | return ObjectPtr<BaseType>(static_cast<Object*>(ptr)); |
888 | } |
889 | |
/*!
 * \brief Checked downcast from a reference of type BaseRef to SubRef.
 * \tparam SubRef The target reference type.
 * \tparam BaseRef The source reference type.
 * \param ref The reference to convert (taken by value; its pointer is moved out).
 * \return A SubRef holding the same node.
 * \note ICHECK-fails when a non-null node is not an instance of SubRef::ContainerType,
 *  or when a null reference is downcast to a non-nullable SubRef.
 */
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
  using TargetNode = typename SubRef::ContainerType;
  if (!ref.defined()) {
    // A null reference may only become a nullable reference type.
    ICHECK(SubRef::_type_is_nullable) << "Downcast from nullptr to not nullable reference of "
                                      << TargetNode::_type_key;
  } else {
    // A non-null node must actually be an instance of the target container type.
    ICHECK(ref->template IsInstance<TargetNode>())
        << "Downcast from " << ref->GetTypeKey() << " to " << TargetNode::_type_key
        << " failed.";
  }
  return SubRef(std::move(ref.data_));
}
902 | |
903 | } // namespace runtime |
904 | } // namespace tvm |
905 | |
906 | #endif // TVM_RUNTIME_OBJECT_H_ |
907 | |