1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/runtime/vm/memory_manager.cc
22 * \brief Allocate and manage memory for the runtime.
23 */
24#include <tvm/runtime/vm/memory_manager.h>
25
26#include <memory>
27#include <utility>
28
29#include "naive_allocator.h"
30#include "pooled_allocator.h"
31
32namespace tvm {
33namespace runtime {
34namespace vm {
35
36static void BufferDeleter(Object* obj) {
37 auto* ptr = static_cast<NDArray::Container*>(obj);
38 ICHECK(ptr->manager_ctx != nullptr);
39 Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
40 MemoryManager::GetAllocator(buffer->device)->Free(*(buffer));
41 delete buffer;
42 delete ptr;
43}
44
45void StorageObj::Deleter(Object* obj) {
46 auto* ptr = static_cast<NDArray::Container*>(obj);
47 // When invoking AllocNDArray we don't own the underlying allocation
48 // and should not delete the buffer, but instead let it be reclaimed
49 // by the storage object's destructor.
50 //
51 // We did bump the reference count by 1 to keep alive the StorageObj
52 // allocation in case this NDArray is the sole owner.
53 //
54 // We decrement the object allowing for the buffer to release our
55 // reference count from allocation.
56 StorageObj* storage = reinterpret_cast<StorageObj*>(ptr->manager_ctx);
57 storage->DecRef();
58 delete ptr;
59}
60
61inline void VerifyDataType(DLDataType dtype) {
62 ICHECK_GE(dtype.lanes, 1);
63 if (dtype.code == kDLFloat) {
64 ICHECK_EQ(dtype.bits % 8, 0);
65 } else {
66 // allow uint1 as a special flag for bool.
67 if (dtype.bits == 1 && dtype.code == kDLUInt) return;
68 ICHECK_EQ(dtype.bits % 8, 0);
69 }
70 ICHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
71}
72
73inline size_t GetDataAlignment(const DLTensor& arr) {
74 size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes;
75 if (align < kAllocAlignment) return kAllocAlignment;
76 return align;
77}
78
79NDArray StorageObj::AllocNDArray(size_t offset, std::vector<int64_t> shape, DLDataType dtype) {
80 VerifyDataType(dtype);
81
82 // crtical zone: allocate header, cannot throw
83 NDArray::Container* container =
84 new NDArray::Container(this->buffer.data, shape, dtype, this->buffer.device);
85 container->dl_tensor.byte_offset = offset;
86
87 container->SetDeleter(StorageObj::Deleter);
88 size_t needed_size = GetDataSize(container->dl_tensor);
89 this->IncRef();
90 // The manager context pointer must continue to point to the storage object
91 // which owns the backing memory, and keeps track of the reference count.
92 //
93 // When we free a container we extract the storage object, decrement its
94 // reference count, then destroy the container, but leave the underlying
95 // buffer intact.
96 container->manager_ctx = reinterpret_cast<void*>(this);
97
98 NDArray ret(GetObjectPtr<Object>(container));
99 // RAII in effect, now run the check.
100
101 ICHECK(offset + needed_size <= this->buffer.size)
102 << "storage allocation failure, attempted to allocate " << needed_size << " at offset "
103 << offset << " in region that is " << this->buffer.size << "bytes";
104
105 return ret;
106}
107
108MemoryManager* MemoryManager::Global() {
109 // NOTE: explicitly use new to avoid exit-time destruction of global state
110 // Global state will be recycled by OS as the process exits.
111 static auto* inst = new MemoryManager();
112 return inst;
113}
114
115Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) {
116 MemoryManager* m = MemoryManager::Global();
117 std::lock_guard<std::mutex> lock(m->mu_);
118 if (m->allocators_.find(dev) == m->allocators_.end()) {
119 std::unique_ptr<Allocator> alloc;
120 switch (type) {
121 case kNaive: {
122 VLOG(1) << "New naive allocator for " << DeviceName(dev.device_type) << "(" << dev.device_id
123 << ")";
124 alloc.reset(new NaiveAllocator(dev));
125 break;
126 }
127 case kPooled: {
128 VLOG(1) << "New pooled allocator for " << DeviceName(dev.device_type) << "("
129 << dev.device_id << ")";
130 alloc.reset(new PooledAllocator(dev));
131 break;
132 }
133 default:
134 LOG(FATAL) << "Unknown allocator type: " << type;
135 }
136 auto ret = alloc.get();
137 m->allocators_.emplace(dev, std::move(alloc));
138 return ret;
139 }
140 auto alloc = m->allocators_.at(dev).get();
141 if (alloc->type() != type) {
142 LOG(WARNING) << "The type of existing allocator for " << DeviceName(dev.device_type) << "("
143 << dev.device_id << ") is different from the request type (" << alloc->type()
144 << " vs " << type << ")";
145 }
146 return alloc;
147}
148
149Allocator* MemoryManager::GetAllocator(Device dev) {
150 MemoryManager* m = MemoryManager::Global();
151 std::lock_guard<std::mutex> lock(m->mu_);
152 auto it = m->allocators_.find(dev);
153 if (it == m->allocators_.end()) {
154 LOG(FATAL) << "Allocator for " << DeviceName(dev.device_type) << "(" << dev.device_id
155 << ") has not been created yet.";
156 }
157 return it->second.get();
158}
159
160NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType dtype, DLDevice dev) {
161 VerifyDataType(dtype);
162 NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, dev);
163 container->SetDeleter(BufferDeleter);
164 size_t size = GetDataSize(container->dl_tensor);
165 size_t alignment = GetDataAlignment(container->dl_tensor);
166 Buffer* buffer = new Buffer;
167 *buffer = this->Alloc(size, alignment, dtype);
168 container->manager_ctx = reinterpret_cast<void*>(buffer);
169 container->dl_tensor.data = buffer->data;
170 return NDArray(GetObjectPtr<Object>(container));
171}
172
173} // namespace vm
174} // namespace runtime
175} // namespace tvm
176