1 | /* |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Support/TensorPool.h" |
18 | |
19 | namespace glow { |
20 | |
21 | llvm::Optional<Tensor> TensorPool::get(TypeRef ty) { |
22 | stats_.totalGets++; |
23 | |
24 | std::unique_lock<std::mutex> l(lock_); |
25 | |
26 | auto it = pools_.find(*ty); |
27 | |
28 | if (it == pools_.end()) { |
29 | if (preventInlineAllocs_) { |
30 | return llvm::Optional<Tensor>(); |
31 | } |
32 | |
33 | stats_.totalTypes++; |
34 | it = pools_.emplace(*ty, std::vector<Tensor>()).first; |
35 | } |
36 | |
37 | if (it->second.empty()) { |
38 | if (preventInlineAllocs_) { |
39 | return llvm::Optional<Tensor>(); |
40 | } |
41 | |
42 | // Don't need to alloc under the lock. |
43 | l.unlock(); |
44 | stats_.totalAllocs++; |
45 | stats_.inlineAllocs++; |
46 | // Don't add it to the queue because it's being claimed now. |
47 | return Tensor(ty, this); |
48 | } |
49 | |
50 | auto &queue = it->second; |
51 | Tensor t = std::move(queue.back()); |
52 | queue.pop_back(); |
53 | stats_.currentBuffers--; |
54 | return t; |
55 | } |
56 | |
57 | void TensorPool::reclaim(Tensor &&t) { |
58 | std::lock_guard<std::mutex> l(lock_); |
59 | auto it = pools_.find(t.getType()); |
60 | assert(it != pools_.end() && "Type has not been initialized" ); |
61 | stats_.totalReclaims++; |
62 | stats_.currentBuffers++; |
63 | it->second.emplace_back(std::move(t)); |
64 | } |
65 | |
66 | void TensorPool::reserve(TypeRef ty, size_t count) { |
67 | std::vector<Tensor> temp; |
68 | temp.reserve(count); |
69 | for (unsigned i = 0; i < count; ++i) { |
70 | stats_.totalAllocs++; |
71 | temp.emplace_back(ty, this); |
72 | } |
73 | |
74 | { |
75 | std::lock_guard<std::mutex> l(lock_); |
76 | auto it = pools_.find(*ty); |
77 | if (it == pools_.end()) { |
78 | stats_.totalTypes++; |
79 | } |
80 | |
81 | std::vector<Tensor> &queue = pools_[*ty]; |
82 | std::move(temp.begin(), temp.end(), std::back_inserter(queue)); |
83 | stats_.currentBuffers += count; |
84 | } |
85 | } |
86 | |
87 | void TensorPool::clear() { |
88 | std::lock_guard<std::mutex> l(lock_); |
89 | for (auto &p : pools_) { |
90 | stats_.currentBuffers -= p.second.size(); |
91 | stats_.totalFrees += p.second.size(); |
92 | p.second.clear(); |
93 | } |
94 | assert(stats_.currentBuffers == 0); |
95 | } |
96 | |
97 | } // namespace glow |
98 | |