/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file runtime/vm/pooled_allocator.h
 * \brief A memory allocator that caches freed buffers for reuse.
 */
#ifndef TVM_RUNTIME_VM_POOLED_ALLOCATOR_H_
#define TVM_RUNTIME_VM_POOLED_ALLOCATOR_H_

#include <tvm/runtime/device_api.h>
#include <tvm/runtime/vm/memory_manager.h>

#include <atomic>
#include <mutex>
#include <unordered_map>
#include <vector>

namespace tvm {
namespace runtime {
namespace vm {

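/*!
 * \brief A size-bucketed pooling allocator.
 *
 * Requests are rounded up to a multiple of the page size, and freed buffers
 * are cached in per-size free lists instead of being returned to the device,
 * so repeated allocations of the same size avoid device allocation calls.
 */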
class PooledAllocator final : public Allocator {
 public:
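  /*! \brief The default page size, in bytes, used to round up allocation requests. */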
  static constexpr size_t kDefaultPageSize = 4096;

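  /*!
   * \brief Construct a pooled allocator for the given device.
   * \param dev The device to allocate memory on.
   * \param page_size The allocation granularity in bytes.
   */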
  explicit PooledAllocator(Device dev, size_t page_size = kDefaultPageSize)
      : Allocator(kPooled), page_size_(page_size), used_memory_(0), device_(dev) {}

  /*! \brief Release all cached buffers back to the device on destruction. */
  ~PooledAllocator() { ReleaseAll(); }

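  /*!
   * \brief Allocate a buffer, reusing a cached one of the same rounded size when possible.
   * \param nbytes The number of bytes requested.
   * \param alignment The required alignment; applies only to fresh device allocations.
   * \param type_hint A data type hint forwarded to the device API.
   * \return The allocated buffer.
   * \note Cached buffers are matched by size alone.
   */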
  Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override {
    std::lock_guard<std::recursive_mutex> lock(mu_);
    // Round the request up to a whole number of pages.
    size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_;
    // Reuse a cached buffer of this size if one is available.
    auto it = memory_pool_.find(size);
    if (it != memory_pool_.end() && !it->second.empty()) {
      auto&& pool = it->second;
      auto ret = pool.back();
      pool.pop_back();
      return ret;
    }
    // Otherwise allocate a fresh buffer from the device.
    Buffer buf;
    buf.device = device_;
    buf.size = size;
    try {
      buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint);
    } catch (InternalError& err) {
      // The device may be out of memory: release all cached buffers and retry once.
      LOG(WARNING) << "PooledAllocator got InternalError during allocation: " << err.message();
      LOG(WARNING) << "Trying to release all unused memory and reallocate...";
      ReleaseAll();
      buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint);
    }

    used_memory_.fetch_add(size, std::memory_order_relaxed);
    VLOG(1) << "allocate " << size << " B, used memory " << used_memory_ << " B";
    return buf;
  }

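  /*!
   * \brief Return a buffer to the pool for later reuse.
   * \param buffer The buffer to reclaim; it is cached, not freed on the device.
   */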
  void Free(const Buffer& buffer) override {
    std::lock_guard<std::recursive_mutex> lock(mu_);
    // Create the free list for this size on first use.
    if (memory_pool_.find(buffer.size) == memory_pool_.end()) {
      memory_pool_.emplace(buffer.size, std::vector<Buffer>{});
    }
    memory_pool_.at(buffer.size).push_back(buffer);
    VLOG(1) << "reclaim buffer " << buffer.size;
  }

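  /*! \brief Return the number of bytes currently allocated from the device. */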
  size_t UsedMemory() const override { return used_memory_.load(std::memory_order_relaxed); }

 private:
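  /*! \brief Free every cached buffer back to the device and reset the usage counter. */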
  void ReleaseAll() {
    std::lock_guard<std::recursive_mutex> lock(mu_);
    for (auto const& it : memory_pool_) {
      auto const& pool = it.second;
      for (auto const& buf : pool) {
        DeviceAPI::Get(buf.device)->FreeDataSpace(buf.device, buf.data);
      }
    }
    memory_pool_.clear();
    used_memory_ = 0;
    VLOG(1) << "release all buffers";
  }

 private:
  // Allocation granularity; requests are rounded up to a multiple of this size.
  size_t page_size_;
  // Total bytes currently allocated from the device (cached buffers included).
  std::atomic<size_t> used_memory_;
  // Free lists of reusable buffers, keyed by their page-rounded size.
  std::unordered_map<size_t, std::vector<Buffer>> memory_pool_;
  // Recursive because Alloc holds the lock while calling ReleaseAll, which locks again.
  std::recursive_mutex mu_;
  // The device this allocator serves.
  Device device_;
};
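// Illustrative usage sketch (not part of this header's API surface); assumes a
// CPU device and TVM's default kAllocAlignment:
//
//   Device dev{kDLCPU, 0};
//   PooledAllocator allocator(dev);
//   // 100 bytes is rounded up to one 4096-byte page.
//   Buffer buf = allocator.Alloc(100, kAllocAlignment, DLDataType{kDLFloat, 32, 1});
//   allocator.Free(buf);  // cached for reuse, not returned to the device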

}  // namespace vm
}  // namespace runtime
}  // namespace tvm

#endif  // TVM_RUNTIME_VM_POOLED_ALLOCATOR_H_