1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
 * \file workspace_pool.cc
 * \brief Implementation of the workspace pool utility.
23 */
#include "workspace_pool.h"

#include <algorithm>
#include <iterator>
#include <limits>
#include <memory>
#include <vector>
27
28namespace tvm {
29namespace runtime {
30
// Allocation granularity of the workspace pool, in bytes (4 KiB).
constexpr size_t kWorkspacePageSize = 4 * 1024;
33
34class WorkspacePool::Pool {
35 public:
36 // constructor
37 Pool() {
38 // safe guard header on each list.
39 Entry e;
40 e.data = nullptr;
41 e.size = 0;
42 free_list_.push_back(e);
43 allocated_.push_back(e);
44 }
45 // allocate from pool
46 void* Alloc(Device dev, DeviceAPI* device, size_t nbytes) {
47 // Allocate align to page.
48 nbytes = (nbytes + (kWorkspacePageSize - 1)) / kWorkspacePageSize * kWorkspacePageSize;
49 if (nbytes == 0) nbytes = kWorkspacePageSize;
50 Entry e;
51 DLDataType type;
52 type.code = kDLUInt;
53 type.bits = 8;
54 type.lanes = 1;
55 if (free_list_.size() == 2) {
56 e = free_list_.back();
57 free_list_.pop_back();
58 if (e.size < nbytes) {
59 // resize the page
60 device->FreeDataSpace(dev, e.data);
61 e.data = device->AllocDataSpace(dev, nbytes, kTempAllocaAlignment, type);
62 e.size = nbytes;
63 }
64 } else if (free_list_.size() == 1) {
65 e.data = device->AllocDataSpace(dev, nbytes, kTempAllocaAlignment, type);
66 e.size = nbytes;
67 } else {
68 if (free_list_.back().size >= nbytes) {
69 // find smallest fit
70 auto it = free_list_.end() - 2;
71 for (; it->size >= nbytes; --it) {
72 }
73 e = *(it + 1);
74 free_list_.erase(it + 1);
75 } else {
76 // resize the page
77 e = free_list_.back();
78 free_list_.pop_back();
79 device->FreeDataSpace(dev, e.data);
80 e.data = device->AllocDataSpace(dev, nbytes, kTempAllocaAlignment, type);
81 e.size = nbytes;
82 }
83 }
84 allocated_.push_back(e);
85 return e.data;
86 }
87 // free resource back to pool
88 void Free(void* data) {
89 Entry e;
90 if (allocated_.back().data == data) {
91 // quick path, last allocated.
92 e = allocated_.back();
93 allocated_.pop_back();
94 } else {
95 int index = static_cast<int>(allocated_.size()) - 2;
96 for (; index > 0 && allocated_[index].data != data; --index) {
97 }
98 ICHECK_GT(index, 0) << "trying to free things that has not been allocated";
99 e = allocated_[index];
100 allocated_.erase(allocated_.begin() + index);
101 }
102 if (free_list_.back().size < e.size) {
103 free_list_.push_back(e);
104 } else if (free_list_.size() == 2) {
105 free_list_.push_back(free_list_.back());
106 free_list_[1] = e;
107 } else {
108 size_t i = free_list_.size() - 1;
109 free_list_.resize(free_list_.size() + 1);
110 for (; e.size < free_list_[i].size; --i) {
111 free_list_[i + 1] = free_list_[i];
112 }
113 free_list_[i + 1] = e;
114 }
115 }
116 // Release all resources
117 void Release(Device dev, DeviceAPI* device) {
118 for (size_t i = 1; i < free_list_.size(); ++i) {
119 device->FreeDataSpace(dev, free_list_[i].data);
120 }
121 free_list_.clear();
122 }
123
124 private:
125 /*! \brief a single entry in the pool */
126 struct Entry {
127 void* data;
128 size_t size;
129 };
130 /*! \brief List of free items, sorted from small to big size */
131 std::vector<Entry> free_list_;
132 /*! \brief List of allocated items */
133 std::vector<Entry> allocated_;
134};
135
136WorkspacePool::WorkspacePool(DLDeviceType device_type, DeviceAPI* device)
137 : device_type_(device_type), device_(device) {}
138
139WorkspacePool::~WorkspacePool() {
140 for (size_t i = 0; i < array_.size(); ++i) {
141 if (array_[i] != nullptr) {
142 Device dev;
143 dev.device_type = device_type_;
144 dev.device_id = static_cast<int>(i);
145 array_[i]->Release(dev, device_);
146 delete array_[i];
147 }
148 }
149}
150
151void* WorkspacePool::AllocWorkspace(Device dev, size_t size) {
152 if (static_cast<size_t>(dev.device_id) >= array_.size()) {
153 array_.resize(dev.device_id + 1, nullptr);
154 }
155 if (array_[dev.device_id] == nullptr) {
156 array_[dev.device_id] = new Pool();
157 }
158 return array_[dev.device_id]->Alloc(dev, device_, size);
159}
160
161void WorkspacePool::FreeWorkspace(Device dev, void* ptr) {
162 ICHECK(static_cast<size_t>(dev.device_id) < array_.size() && array_[dev.device_id] != nullptr);
163 array_[dev.device_id]->Free(ptr);
164}
165
166} // namespace runtime
167} // namespace tvm
168