1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file tvm/runtime/memory.h
21 * \brief Runtime memory management.
22 */
23#ifndef TVM_RUNTIME_MEMORY_H_
24#define TVM_RUNTIME_MEMORY_H_
25
#include <tvm/runtime/object.h>

#include <cstdlib>
#include <limits>
#include <memory>
#include <new>
#include <type_traits>
#include <utility>
31
32namespace tvm {
33namespace runtime {
/*!
 * \brief Allocate an object using the default allocator (SimpleObjAllocator).
 * \param args The arguments forwarded to the constructor of T.
 * \tparam T The object type; it must derive from Object (enforced by a static_assert).
 * \tparam Args The constructor argument types.
 * \return The ObjectPtr to the allocated object.
 */
template <typename T, typename... Args>
inline ObjectPtr<T> make_object(Args&&... args);
42
// Implementation details follow.
//
// The current design allows the allocator strategy to be
// swapped when necessary.
//
// Possible future allocator optimizations:
// - Arena allocator that transfers ownership of memory to the arena (deleter_ = nullptr).
// - Thread-local object pools: one pool per size and alignment requirement.
// - Specialization by object type, giving each object type its own allocator.
52
53/*!
54 * \brief Base class of object allocators that implements make.
55 * Use curiously recurring template pattern.
56 *
57 * \tparam Derived The derived class.
58 */
59template <typename Derived>
60class ObjAllocatorBase {
61 public:
62 /*!
63 * \brief Make a new object using the allocator.
64 * \tparam T The type to be allocated.
65 * \tparam Args The constructor signature.
66 * \param args The arguments.
67 */
68 template <typename T, typename... Args>
69 inline ObjectPtr<T> make_object(Args&&... args) {
70 using Handler = typename Derived::template Handler<T>;
71 static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object");
72 T* ptr = Handler::New(static_cast<Derived*>(this), std::forward<Args>(args)...);
73 ptr->type_index_ = T::RuntimeTypeIndex();
74 ptr->deleter_ = Handler::Deleter();
75 return ObjectPtr<T>(ptr);
76 }
77
78 /*!
79 * \tparam ArrayType The type to be allocated.
80 * \tparam ElemType The type of array element.
81 * \tparam Args The constructor signature.
82 * \param num_elems The number of array elements.
83 * \param args The arguments.
84 */
85 template <typename ArrayType, typename ElemType, typename... Args>
86 inline ObjectPtr<ArrayType> make_inplace_array(size_t num_elems, Args&&... args) {
87 using Handler = typename Derived::template ArrayHandler<ArrayType, ElemType>;
88 static_assert(std::is_base_of<Object, ArrayType>::value,
89 "make_inplace_array can only be used to create Object");
90 ArrayType* ptr =
91 Handler::New(static_cast<Derived*>(this), num_elems, std::forward<Args>(args)...);
92 ptr->type_index_ = ArrayType::RuntimeTypeIndex();
93 ptr->deleter_ = Handler::Deleter();
94 return ObjectPtr<ArrayType>(ptr);
95 }
96};
97
98// Simple allocator that uses new/delete.
99class SimpleObjAllocator : public ObjAllocatorBase<SimpleObjAllocator> {
100 public:
101 template <typename T>
102 class Handler {
103 public:
104 using StorageType = typename std::aligned_storage<sizeof(T), alignof(T)>::type;
105
106 template <typename... Args>
107 static T* New(SimpleObjAllocator*, Args&&... args) {
108 // NOTE: the first argument is not needed for SimpleObjAllocator
109 // It is reserved for special allocators that needs to recycle
110 // the object to itself (e.g. in the case of object pool).
111 //
112 // In the case of an object pool, an allocator needs to create
113 // a special chunk memory that hides reference to the allocator
114 // and call allocator's release function in the deleter.
115
116 // NOTE2: Use inplace new to allocate
117 // This is used to get rid of warning when deleting a virtual
118 // class with non-virtual destructor.
119 // We are fine here as we captured the right deleter during construction.
120 // This is also the right way to get storage type for an object pool.
121 StorageType* data = new StorageType();
122 new (data) T(std::forward<Args>(args)...);
123 return reinterpret_cast<T*>(data);
124 }
125
126 static Object::FDeleter Deleter() { return Deleter_; }
127
128 private:
129 static void Deleter_(Object* objptr) {
130 // NOTE: this is important to cast back to T*
131 // because objptr and tptr may not be the same
132 // depending on how sub-class allocates the space.
133 T* tptr = static_cast<T*>(objptr);
134 // It is important to do tptr->T::~T(),
135 // so that we explicitly call the specific destructor
136 // instead of tptr->~T(), which could mean the intention
137 // call a virtual destructor(which may not be available and is not required).
138 tptr->T::~T();
139 delete reinterpret_cast<StorageType*>(tptr);
140 }
141 };
142
143 // Array handler that uses new/delete.
144 template <typename ArrayType, typename ElemType>
145 class ArrayHandler {
146 public:
147 using StorageType = typename std::aligned_storage<sizeof(ArrayType), alignof(ArrayType)>::type;
148 // for now only support elements that aligns with array header.
149 static_assert(alignof(ArrayType) % alignof(ElemType) == 0 &&
150 sizeof(ArrayType) % alignof(ElemType) == 0,
151 "element alignment constraint");
152
153 template <typename... Args>
154 static ArrayType* New(SimpleObjAllocator*, size_t num_elems, Args&&... args) {
155 // NOTE: the first argument is not needed for ArrayObjAllocator
156 // It is reserved for special allocators that needs to recycle
157 // the object to itself (e.g. in the case of object pool).
158 //
159 // In the case of an object pool, an allocator needs to create
160 // a special chunk memory that hides reference to the allocator
161 // and call allocator's release function in the deleter.
162 // NOTE2: Use inplace new to allocate
163 // This is used to get rid of warning when deleting a virtual
164 // class with non-virtual destructor.
165 // We are fine here as we captured the right deleter during construction.
166 // This is also the right way to get storage type for an object pool.
167 size_t unit = sizeof(StorageType);
168 size_t requested_size = num_elems * sizeof(ElemType) + sizeof(ArrayType);
169 size_t num_storage_slots = (requested_size + unit - 1) / unit;
170 StorageType* data = new StorageType[num_storage_slots];
171 new (data) ArrayType(std::forward<Args>(args)...);
172 return reinterpret_cast<ArrayType*>(data);
173 }
174
175 static Object::FDeleter Deleter() { return Deleter_; }
176
177 private:
178 static void Deleter_(Object* objptr) {
179 // NOTE: this is important to cast back to ArrayType*
180 // because objptr and tptr may not be the same
181 // depending on how sub-class allocates the space.
182 ArrayType* tptr = static_cast<ArrayType*>(objptr);
183 // It is important to do tptr->ArrayType::~ArrayType(),
184 // so that we explicitly call the specific destructor
185 // instead of tptr->~ArrayType(), which could mean the intention
186 // call a virtual destructor(which may not be available and is not required).
187 tptr->ArrayType::~ArrayType();
188 StorageType* p = reinterpret_cast<StorageType*>(tptr);
189 delete[] p;
190 }
191 };
192};
193
194template <typename T, typename... Args>
195inline ObjectPtr<T> make_object(Args&&... args) {
196 return SimpleObjAllocator().make_object<T>(std::forward<Args>(args)...);
197}
198
199template <typename ArrayType, typename ElemType, typename... Args>
200inline ObjectPtr<ArrayType> make_inplace_array_object(size_t num_elems, Args&&... args) {
201 return SimpleObjAllocator().make_inplace_array<ArrayType, ElemType>(num_elems,
202 std::forward<Args>(args)...);
203}
204
205} // namespace runtime
206} // namespace tvm
207#endif // TVM_RUNTIME_MEMORY_H_
208