1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file workspace_pool.h |
22 | * \brief Workspace pool utility. |
23 | */ |
#include "workspace_pool.h"

#include <limits>
#include <memory>
27 | |
28 | namespace tvm { |
29 | namespace runtime { |
30 | |
// Allocation granularity in bytes (4 KiB): every workspace request is rounded
// up to a whole multiple of this value before it is served.
constexpr size_t kWorkspacePageSize = size_t{4} * 1024;
33 | |
34 | class WorkspacePool::Pool { |
35 | public: |
36 | // constructor |
37 | Pool() { |
38 | // safe guard header on each list. |
39 | Entry e; |
40 | e.data = nullptr; |
41 | e.size = 0; |
42 | free_list_.push_back(e); |
43 | allocated_.push_back(e); |
44 | } |
45 | // allocate from pool |
46 | void* Alloc(Device dev, DeviceAPI* device, size_t nbytes) { |
47 | // Allocate align to page. |
48 | nbytes = (nbytes + (kWorkspacePageSize - 1)) / kWorkspacePageSize * kWorkspacePageSize; |
49 | if (nbytes == 0) nbytes = kWorkspacePageSize; |
50 | Entry e; |
51 | DLDataType type; |
52 | type.code = kDLUInt; |
53 | type.bits = 8; |
54 | type.lanes = 1; |
55 | if (free_list_.size() == 2) { |
56 | e = free_list_.back(); |
57 | free_list_.pop_back(); |
58 | if (e.size < nbytes) { |
59 | // resize the page |
60 | device->FreeDataSpace(dev, e.data); |
61 | e.data = device->AllocDataSpace(dev, nbytes, kTempAllocaAlignment, type); |
62 | e.size = nbytes; |
63 | } |
64 | } else if (free_list_.size() == 1) { |
65 | e.data = device->AllocDataSpace(dev, nbytes, kTempAllocaAlignment, type); |
66 | e.size = nbytes; |
67 | } else { |
68 | if (free_list_.back().size >= nbytes) { |
69 | // find smallest fit |
70 | auto it = free_list_.end() - 2; |
71 | for (; it->size >= nbytes; --it) { |
72 | } |
73 | e = *(it + 1); |
74 | free_list_.erase(it + 1); |
75 | } else { |
76 | // resize the page |
77 | e = free_list_.back(); |
78 | free_list_.pop_back(); |
79 | device->FreeDataSpace(dev, e.data); |
80 | e.data = device->AllocDataSpace(dev, nbytes, kTempAllocaAlignment, type); |
81 | e.size = nbytes; |
82 | } |
83 | } |
84 | allocated_.push_back(e); |
85 | return e.data; |
86 | } |
87 | // free resource back to pool |
88 | void Free(void* data) { |
89 | Entry e; |
90 | if (allocated_.back().data == data) { |
91 | // quick path, last allocated. |
92 | e = allocated_.back(); |
93 | allocated_.pop_back(); |
94 | } else { |
95 | int index = static_cast<int>(allocated_.size()) - 2; |
96 | for (; index > 0 && allocated_[index].data != data; --index) { |
97 | } |
98 | ICHECK_GT(index, 0) << "trying to free things that has not been allocated" ; |
99 | e = allocated_[index]; |
100 | allocated_.erase(allocated_.begin() + index); |
101 | } |
102 | if (free_list_.back().size < e.size) { |
103 | free_list_.push_back(e); |
104 | } else if (free_list_.size() == 2) { |
105 | free_list_.push_back(free_list_.back()); |
106 | free_list_[1] = e; |
107 | } else { |
108 | size_t i = free_list_.size() - 1; |
109 | free_list_.resize(free_list_.size() + 1); |
110 | for (; e.size < free_list_[i].size; --i) { |
111 | free_list_[i + 1] = free_list_[i]; |
112 | } |
113 | free_list_[i + 1] = e; |
114 | } |
115 | } |
116 | // Release all resources |
117 | void Release(Device dev, DeviceAPI* device) { |
118 | for (size_t i = 1; i < free_list_.size(); ++i) { |
119 | device->FreeDataSpace(dev, free_list_[i].data); |
120 | } |
121 | free_list_.clear(); |
122 | } |
123 | |
124 | private: |
125 | /*! \brief a single entry in the pool */ |
126 | struct Entry { |
127 | void* data; |
128 | size_t size; |
129 | }; |
130 | /*! \brief List of free items, sorted from small to big size */ |
131 | std::vector<Entry> free_list_; |
132 | /*! \brief List of allocated items */ |
133 | std::vector<Entry> allocated_; |
134 | }; |
135 | |
136 | WorkspacePool::WorkspacePool(DLDeviceType device_type, DeviceAPI* device) |
137 | : device_type_(device_type), device_(device) {} |
138 | |
139 | WorkspacePool::~WorkspacePool() { |
140 | for (size_t i = 0; i < array_.size(); ++i) { |
141 | if (array_[i] != nullptr) { |
142 | Device dev; |
143 | dev.device_type = device_type_; |
144 | dev.device_id = static_cast<int>(i); |
145 | array_[i]->Release(dev, device_); |
146 | delete array_[i]; |
147 | } |
148 | } |
149 | } |
150 | |
151 | void* WorkspacePool::AllocWorkspace(Device dev, size_t size) { |
152 | if (static_cast<size_t>(dev.device_id) >= array_.size()) { |
153 | array_.resize(dev.device_id + 1, nullptr); |
154 | } |
155 | if (array_[dev.device_id] == nullptr) { |
156 | array_[dev.device_id] = new Pool(); |
157 | } |
158 | return array_[dev.device_id]->Alloc(dev, device_, size); |
159 | } |
160 | |
161 | void WorkspacePool::FreeWorkspace(Device dev, void* ptr) { |
162 | ICHECK(static_cast<size_t>(dev.device_id) < array_.size() && array_[dev.device_id] != nullptr); |
163 | array_[dev.device_id]->Free(ptr); |
164 | } |
165 | |
166 | } // namespace runtime |
167 | } // namespace tvm |
168 | |