1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file runtime/pooled_allocator.h |
22 | */ |
23 | #ifndef TVM_RUNTIME_VM_POOLED_ALLOCATOR_H_ |
24 | #define TVM_RUNTIME_VM_POOLED_ALLOCATOR_H_ |
25 | |
26 | #include <tvm/runtime/device_api.h> |
27 | #include <tvm/runtime/vm/memory_manager.h> |
28 | |
29 | #include <atomic> |
30 | #include <mutex> |
31 | #include <unordered_map> |
32 | #include <vector> |
33 | |
34 | namespace tvm { |
35 | namespace runtime { |
36 | namespace vm { |
37 | |
38 | class PooledAllocator final : public Allocator { |
39 | public: |
40 | static constexpr size_t kDefaultPageSize = 4096; |
41 | |
42 | explicit PooledAllocator(Device dev, size_t page_size = kDefaultPageSize) |
43 | : Allocator(kPooled), page_size_(page_size), used_memory_(0), device_(dev) {} |
44 | |
45 | ~PooledAllocator() { ReleaseAll(); } |
46 | |
47 | Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override { |
48 | std::lock_guard<std::recursive_mutex> lock(mu_); |
49 | size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_; |
50 | auto&& it = memory_pool_.find(size); |
51 | if (it != memory_pool_.end() && !it->second.empty()) { |
52 | auto&& pool = it->second; |
53 | auto ret = pool.back(); |
54 | pool.pop_back(); |
55 | return ret; |
56 | } |
57 | Buffer buf; |
58 | buf.device = device_; |
59 | buf.size = size; |
60 | try { |
61 | buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); |
62 | } catch (InternalError& err) { |
63 | LOG(WARNING) << "PooledAllocator got InternalError during allocation: " << err.message(); |
64 | LOG(WARNING) << "Trying to release all unused memory and reallocate..." ; |
65 | ReleaseAll(); |
66 | buf.data = DeviceAPI::Get(device_)->AllocDataSpace(device_, size, alignment, type_hint); |
67 | } |
68 | |
69 | used_memory_.fetch_add(size, std::memory_order_relaxed); |
70 | VLOG(1) << "allocate " << size << " B, used memory " << used_memory_ << " B" ; |
71 | return buf; |
72 | } |
73 | |
74 | void Free(const Buffer& buffer) override { |
75 | std::lock_guard<std::recursive_mutex> lock(mu_); |
76 | if (memory_pool_.find(buffer.size) == memory_pool_.end()) { |
77 | memory_pool_.emplace(buffer.size, std::vector<Buffer>{}); |
78 | } |
79 | memory_pool_.at(buffer.size).push_back(buffer); |
80 | VLOG(1) << "reclaim buffer " << buffer.size; |
81 | } |
82 | |
83 | size_t UsedMemory() const override { return used_memory_.load(std::memory_order_relaxed); } |
84 | |
85 | private: |
86 | void ReleaseAll() { |
87 | std::lock_guard<std::recursive_mutex> lock(mu_); |
88 | for (auto const& it : memory_pool_) { |
89 | auto const& pool = it.second; |
90 | for (auto const& buf : pool) { |
91 | DeviceAPI::Get(buf.device)->FreeDataSpace(buf.device, buf.data); |
92 | } |
93 | } |
94 | memory_pool_.clear(); |
95 | used_memory_ = 0; |
96 | VLOG(1) << "release all buffers" ; |
97 | } |
98 | |
99 | private: |
100 | size_t page_size_; |
101 | std::atomic<size_t> used_memory_; |
102 | std::unordered_map<size_t, std::vector<Buffer>> memory_pool_; |
103 | std::recursive_mutex mu_; |
104 | Device device_; |
105 | }; |
106 | |
107 | } // namespace vm |
108 | } // namespace runtime |
109 | } // namespace tvm |
110 | |
111 | #endif // TVM_RUNTIME_VM_POOLED_ALLOCATOR_H_ |
112 | |