1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "glow/Support/TensorPool.h"
17#include "glow/Graph/Graph.h"
18#include "glow/Graph/PlaceholderBindings.h"
19#include "gtest/gtest.h"
20
21#include "llvm/ADT/STLExtras.h"
22
23#include <future>
24#include <vector>
25
26using namespace glow;
27
28/// Can get Tensor from the pool without allocation.
29TEST(TensorPool, BasicTest) {
30 TensorPool pool;
31 Type ty(ElemKind::FloatTy, {1, 2, 3});
32 pool.reserve(&ty, 1);
33
34 Tensor T = std::move(pool.get(&ty).getValue());
35 EXPECT_TRUE(T.getType().isEqual(ty));
36 EXPECT_EQ(T.dims(), ty.dims());
37
38 const auto &stats = pool.getStats();
39 EXPECT_EQ(stats.totalTypes, 1);
40 EXPECT_EQ(stats.currentBuffers, 0);
41 EXPECT_EQ(stats.totalAllocs, 1);
42 EXPECT_EQ(stats.inlineAllocs, 0);
43 EXPECT_EQ(stats.totalGets, 1);
44 EXPECT_EQ(stats.totalReclaims, 0);
45
46 pool.reclaim(std::move(T));
47}
48
49/// Can get a tensor, return it and get it again without allocation.
50TEST(TensorPool, ReclaimAndGet) {
51 TensorPool pool;
52 Type ty(ElemKind::FloatTy, {1, 2, 3});
53 pool.reserve(&ty, 1);
54
55 Tensor T = std::move(pool.get(&ty).getValue());
56 auto *backingPtr = T.getUnsafePtr();
57
58 pool.reclaim(std::move(T));
59
60 Tensor T2 = std::move(pool.get(&ty).getValue());
61 // They are the same buffer.
62 EXPECT_EQ(T2.getUnsafePtr(), backingPtr);
63
64 const auto &stats = pool.getStats();
65 EXPECT_EQ(stats.totalTypes, 1);
66 EXPECT_EQ(stats.currentBuffers, 0);
67 EXPECT_EQ(stats.totalAllocs, 1);
68 EXPECT_EQ(stats.inlineAllocs, 0);
69 EXPECT_EQ(stats.totalGets, 2);
70 EXPECT_EQ(stats.totalReclaims, 1);
71
72 pool.reclaim(std::move(T2));
73}
74
75/// The pool auto resizes when it's empty.
76TEST(TensorPool, Extends) {
77 TensorPool pool;
78 Type ty(ElemKind::FloatTy, {1, 2, 3});
79 pool.reserve(&ty, 1);
80
81 Tensor T = std::move(pool.get(&ty).getValue());
82 Tensor T2 = std::move(pool.get(&ty).getValue());
83 EXPECT_TRUE(T.getType().isEqual(T2.getType()));
84 EXPECT_TRUE(T.getType().isEqual(ty));
85 EXPECT_TRUE(T2.getType().isEqual(ty));
86
87 // They are not the same buffer.
88 EXPECT_NE(T.getUnsafePtr(), T2.getUnsafePtr());
89
90 const auto &stats = pool.getStats();
91 EXPECT_EQ(stats.totalTypes, 1);
92 EXPECT_EQ(stats.currentBuffers, 0);
93 EXPECT_EQ(stats.totalAllocs, 2);
94 EXPECT_EQ(stats.inlineAllocs, 1);
95 EXPECT_EQ(stats.totalGets, 2);
96 EXPECT_EQ(stats.totalReclaims, 0);
97
98 pool.reclaim(std::move(T));
99 pool.reclaim(std::move(T2));
100}
101
102/// The pool doesn't resize when you tell it not to.
103TEST(TensorPool, DoesntExtend) {
104 TensorPool pool(true);
105 Type ty(ElemKind::FloatTy, {1, 2, 3});
106 pool.reserve(&ty, 1);
107
108 Tensor T = std::move(pool.get(&ty).getValue());
109 Type Tt = T.getType();
110
111 auto T2opt = pool.get(&ty);
112 EXPECT_FALSE(T2opt.hasValue());
113
114 pool.reclaim(std::move(T));
115
116 T = std::move(pool.get(&ty).getValue());
117 EXPECT_EQ(Tt, T.getType());
118
119 const auto &stats = pool.getStats();
120 EXPECT_EQ(stats.totalTypes, 1);
121 EXPECT_EQ(stats.currentBuffers, 0);
122 EXPECT_EQ(stats.totalAllocs, 1);
123 EXPECT_EQ(stats.inlineAllocs, 0);
124 EXPECT_EQ(stats.totalGets, 3);
125 EXPECT_EQ(stats.totalReclaims, 1);
126
127 pool.reclaim(std::move(T));
128}
129
130/// Still works if you don't reserve it.
131TEST(TensorPool, Noreserve) {
132 TensorPool pool;
133 Type ty(ElemKind::FloatTy, {1, 2, 3});
134
135 Tensor T = std::move(pool.get(&ty).getValue());
136 Tensor T2 = std::move(pool.get(&ty).getValue());
137
138 EXPECT_TRUE(T.getType().isEqual(T2.getType()));
139
140 const auto &stats = pool.getStats();
141 EXPECT_EQ(stats.totalTypes, 1);
142 EXPECT_EQ(stats.currentBuffers, 0);
143 EXPECT_EQ(stats.totalAllocs, 2);
144 EXPECT_EQ(stats.inlineAllocs, 2);
145 EXPECT_EQ(stats.totalGets, 2);
146 EXPECT_EQ(stats.totalReclaims, 0);
147
148 pool.reclaim(std::move(T));
149 pool.reclaim(std::move(T2));
150}
151
152/// Can handle multiple types of Tensors.
153TEST(TensorPool, MultipleTypes) {
154 TensorPool pool;
155 Type ty(ElemKind::FloatTy, {1, 2, 3});
156 Type ty2(ElemKind::Int8QTy, {3, 2, 1}, 1.0, 4);
157
158 // Six total buffers.
159 pool.reserve(&ty, 1);
160 pool.reserve(&ty2, 5);
161
162 std::vector<Tensor> tensors;
163 // Ten total allocs.
164 for (int i = 0; i < 5; ++i) {
165 Tensor T = std::move(pool.get(&ty).getValue());
166 Tensor T2 = std::move(pool.get(&ty2).getValue());
167 EXPECT_FALSE(T.getType().isEqual(T2.getType()));
168 EXPECT_TRUE(T.getType().isEqual(ty));
169 EXPECT_TRUE(T2.getType().isEqual(ty2));
170 EXPECT_NE(T.dims(), T2.dims());
171 EXPECT_NE(T.getUnsafePtr(), T2.getUnsafePtr());
172
173 tensors.emplace_back(std::move(T));
174 tensors.emplace_back(std::move(T2));
175 }
176
177 const auto &stats = pool.getStats();
178 EXPECT_EQ(stats.totalTypes, 2);
179 EXPECT_EQ(stats.currentBuffers, 0);
180 EXPECT_EQ(stats.totalAllocs, 10);
181 EXPECT_EQ(stats.inlineAllocs, 4); // Four allocs inline.
182 EXPECT_EQ(stats.totalGets, 10);
183 EXPECT_EQ(stats.totalReclaims, 0);
184
185 for (auto &t : tensors) {
186 pool.reclaim(std::move(t));
187 }
188
189 const auto &stats2 = pool.getStats();
190 EXPECT_EQ(stats2.totalTypes, 2);
191 EXPECT_EQ(stats2.currentBuffers, 10);
192 EXPECT_EQ(stats.totalReclaims, 10);
193}
194
195/// Reclaims still work with multiple types of Tensors.
196TEST(TensorPool, MultipleTypesReclaim) {
197 TensorPool pool;
198 Type ty(ElemKind::FloatTy, {1, 2, 3});
199 Type ty2(ElemKind::Int8QTy, {3, 2, 1}, 1.0, 4);
200 pool.reserve(&ty, 1);
201 pool.reserve(&ty2, 1);
202
203 Tensor T = std::move(pool.get(&ty).getValue());
204 Tensor T2 = std::move(pool.get(&ty2).getValue());
205
206 pool.reclaim(std::move(T));
207 pool.reclaim(std::move(T2));
208
209 T = std::move(pool.get(&ty).getValue());
210 T2 = std::move(pool.get(&ty2).getValue());
211
212 pool.reclaim(std::move(T));
213 pool.reclaim(std::move(T2));
214
215 const auto &stats = pool.getStats();
216 EXPECT_EQ(stats.totalTypes, 2);
217 EXPECT_EQ(stats.currentBuffers, 2);
218 EXPECT_EQ(stats.totalAllocs, 2);
219 EXPECT_EQ(stats.inlineAllocs, 0);
220 EXPECT_EQ(stats.totalGets, 4);
221 EXPECT_EQ(stats.totalReclaims, 4);
222}
223
224/// Inserting a managed Tensor into the PlaceholderBindings does reclaim when
225/// the bindings are cleared or destroyed.
226TEST(TensorPool, PlaceholderBindingsReclaim) {
227 TensorPool pool;
228 Type ty(ElemKind::FloatTy, {1, 2, 3});
229
230 PlaceholderBindings bindings;
231 Module mod;
232
233 auto *PH = mod.createPlaceholder(&ty, "test", false);
234 bindings.insert(PH, std::move(pool.get(&ty).getValue()));
235
236 /// Insert a non managed tensor.
237 auto *PH2 = mod.createPlaceholder(&ty, "test2", false);
238 Tensor T2(ty);
239 bindings.insert(PH2, std::move(T2));
240
241 bindings.clear();
242
243 /// Bindings had two Tensors but only the first was reclaimed.
244 const auto &stats = pool.getStats();
245 EXPECT_EQ(stats.totalTypes, 1);
246 EXPECT_EQ(stats.currentBuffers, 1);
247 EXPECT_EQ(stats.totalAllocs, 1);
248 EXPECT_EQ(stats.inlineAllocs, 1);
249 EXPECT_EQ(stats.totalGets, 1);
250 EXPECT_EQ(stats.totalReclaims, 1);
251
252 bindings.insert(PH, std::move(pool.get(&ty).getValue()));
253
254 bindings.erase(PH);
255 const auto &stats2 = pool.getStats();
256 EXPECT_EQ(stats.currentBuffers, 1);
257 EXPECT_EQ(stats.totalGets, 2);
258 EXPECT_EQ(stats2.totalReclaims, 2);
259}
260
261/// Clearing the Tensor pool removes contents but the pool still works.
262TEST(TensorPool, Clear) {
263 TensorPool pool;
264 Type ty(ElemKind::FloatTy, {1, 2, 3});
265
266 Tensor T = std::move(pool.get(&ty).getValue());
267 pool.reclaim(std::move(T));
268
269 const auto &stats = pool.getStats();
270 EXPECT_EQ(stats.totalTypes, 1);
271 EXPECT_EQ(stats.currentBuffers, 1);
272 EXPECT_EQ(stats.totalAllocs, 1);
273 EXPECT_EQ(stats.inlineAllocs, 1);
274 EXPECT_EQ(stats.totalGets, 1);
275 EXPECT_EQ(stats.totalReclaims, 1);
276 EXPECT_EQ(stats.totalFrees, 0);
277
278 pool.clear();
279
280 T = std::move(pool.get(&ty).getValue());
281 pool.reclaim(std::move(T));
282
283 const auto &stats2 = pool.getStats();
284 EXPECT_EQ(stats2.totalTypes, 1);
285 EXPECT_EQ(stats2.currentBuffers, 1);
286 EXPECT_EQ(stats2.totalAllocs, 2);
287 EXPECT_EQ(stats2.inlineAllocs, 2);
288 EXPECT_EQ(stats2.totalGets, 2);
289 EXPECT_EQ(stats2.totalReclaims, 2);
290 EXPECT_EQ(stats2.totalFrees, 1);
291}
292