1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "tools/loader/Loader.h" |
18 | #include "ImporterTestUtils.h" |
19 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
20 | #include "glow/Graph/Graph.h" |
21 | #include "glow/Importer/Caffe2ModelLoader.h" |
22 | #include "glow/Importer/ONNXModelLoader.h" |
23 | #include "gtest/gtest.h" |
24 | |
25 | #include "llvm/ADT/StringMap.h" |
26 | |
27 | #ifndef GLOW_DATA_PATH |
28 | #define GLOW_DATA_PATH |
29 | #endif |
30 | |
31 | class LoaderTest : public ::testing::Test { |
32 | protected: |
33 | // By default constant folding at load time is enabled in general, but we do |
34 | // many tests here loading Constants, so keep it false during these tests by |
35 | // default. |
36 | void SetUp() override { glow::setConstantFoldLoaderOpsFlag(false); } |
37 | void TearDown() override { glow::setConstantFoldLoaderOpsFlag(true); } |
38 | }; |
39 | |
40 | using namespace glow; |
41 | |
42 | namespace { |
43 | const dim_t BATCH_SIZE = 3; |
44 | const dim_t BATCH_SIZE_TFLITE = 1; |
45 | const size_t MINI_BATCH_SIZE = 2; |
46 | } // namespace |
47 | |
48 | // A Loader extension class for testing purpose. |
49 | class testLoaderExtension : public LoaderExtension { |
50 | public: |
51 | static int stage_; |
52 | static size_t index_; |
53 | static Loader *loader_; |
54 | static PlaceholderBindings *bindings_; |
55 | static ProtobufLoader *protobufLoader_; |
56 | static TFLiteModelLoader *tfliteloader_; |
57 | static bool destructed_; |
58 | |
59 | testLoaderExtension() { |
60 | stage_ = 0; |
61 | index_ = 0; |
62 | loader_ = nullptr; |
63 | bindings_ = nullptr; |
64 | protobufLoader_ = nullptr; |
65 | destructed_ = false; |
66 | } |
67 | |
68 | /// Called once after ONNX or Caffe2 model loading. |
69 | virtual void postModelLoad(Loader &loader, PlaceholderBindings &bindings, |
70 | ProtobufLoader &protobufLoader, |
71 | llvm::StringMap<Placeholder *> &outputMap, |
72 | llvm::ArrayRef<TypeRef> inputImageType) { |
73 | |
74 | size_t compilationBatchSize = inputImageType[0]->dims()[0]; |
75 | // To check the method was executed. |
76 | stage_ = 1; |
77 | |
78 | // To check params are correctly set. |
79 | loader_ = &loader; |
80 | bindings_ = &bindings; |
81 | protobufLoader_ = &protobufLoader; |
82 | EXPECT_EQ(BATCH_SIZE, compilationBatchSize); |
83 | } |
84 | /// Called once at the beginning of the mini-batch inference. |
85 | virtual void inferInitMiniBatch(Loader &loader, PlaceholderBindings &bindings, |
86 | size_t minibatchIndex, size_t minibatchSize) { |
87 | // To check the method was executed. |
88 | stage_ = 2; |
89 | |
90 | // To check params are correctly set. |
91 | loader_ = &loader; |
92 | bindings_ = &bindings; |
93 | index_ = minibatchIndex; |
94 | EXPECT_EQ(MINI_BATCH_SIZE, minibatchSize); |
95 | } |
96 | /// Called once after the completion of the mini-batch inference. |
97 | virtual void inferEndMiniBatch(Loader &loader, PlaceholderBindings &bindings, |
98 | size_t minibatchIndex, size_t minibatchSize) { |
99 | // To check the method was executed. |
100 | stage_ = 3; |
101 | |
102 | // To check params are correctly set. |
103 | loader_ = &loader; |
104 | bindings_ = &bindings; |
105 | index_ = minibatchIndex; |
106 | EXPECT_EQ(MINI_BATCH_SIZE, minibatchSize); |
107 | } |
108 | /// Called once after TFLite. |
109 | virtual void postModelLoad(Loader &loader, PlaceholderBindings &bindings, |
110 | TFLiteModelLoader &tfliteLoader, |
111 | llvm::StringMap<Placeholder *> &outputMap, |
112 | llvm::ArrayRef<TypeRef> inputImageType) { |
113 | |
114 | size_t compilationBatchSize = inputImageType[0]->dims()[0]; |
115 | // To check the method was executed. |
116 | stage_ = 4; |
117 | |
118 | // To check params are correctly set. |
119 | loader_ = &loader; |
120 | bindings_ = &bindings; |
121 | tfliteloader_ = &tfliteLoader; |
122 | EXPECT_EQ(BATCH_SIZE_TFLITE, compilationBatchSize); |
123 | } |
124 | virtual ~testLoaderExtension() { destructed_ = true; } |
125 | }; |
126 | |
127 | // A simple Loader second extension class. |
128 | class secondTestLoaderExtension : public LoaderExtension { |
129 | public: |
130 | static int stage_; |
131 | static bool destructed_; |
132 | |
133 | secondTestLoaderExtension() { |
134 | stage_ = 0; |
135 | destructed_ = false; |
136 | } |
137 | |
138 | /// Called once after ONNX or Caffe2 model loading. |
139 | virtual void postModelLoad(Loader &, PlaceholderBindings &, ProtobufLoader &, |
140 | llvm::StringMap<Placeholder *> &, |
141 | llvm::ArrayRef<TypeRef> inputImageType) { |
142 | stage_ = 1; |
143 | } |
144 | /// Called once at the beginning of the mini-batch inference. |
145 | virtual void inferInitMiniBatch(Loader &, PlaceholderBindings &, size_t, |
146 | size_t) { |
147 | stage_ = 2; |
148 | } |
149 | /// Called once after the completion of the mini-batch inference. |
150 | virtual void inferEndMiniBatch(Loader &, PlaceholderBindings &, size_t, |
151 | size_t) { |
152 | stage_ = 3; |
153 | } |
154 | virtual void postModelLoad(Loader &, PlaceholderBindings &, |
155 | TFLiteModelLoader &, |
156 | llvm::StringMap<Placeholder *> &, |
157 | llvm::ArrayRef<TypeRef> inputImageType) { |
158 | stage_ = 4; |
159 | } |
160 | |
161 | virtual ~secondTestLoaderExtension() { destructed_ = true; } |
162 | }; |
163 | |
164 | // Static class members. |
165 | int testLoaderExtension::stage_; |
166 | size_t testLoaderExtension::index_; |
167 | Loader *testLoaderExtension::loader_; |
168 | PlaceholderBindings *testLoaderExtension::bindings_; |
169 | ProtobufLoader *testLoaderExtension::protobufLoader_; |
170 | TFLiteModelLoader *testLoaderExtension::tfliteloader_; |
171 | bool testLoaderExtension::destructed_; |
172 | int secondTestLoaderExtension::stage_; |
173 | bool secondTestLoaderExtension::destructed_; |
174 | |
175 | /// This test simulates what can be a Glow application (like image_classifier). |
176 | TEST_F(LoaderTest, LoaderExtensionCaffe2) { |
177 | { |
178 | std::unique_ptr<ExecutionContext> exContext = |
179 | glow::make_unique<ExecutionContext>(); |
180 | PlaceholderBindings &bindings = *exContext->getPlaceholderBindings(); |
181 | llvm::StringMap<Placeholder *> outputMap; |
182 | |
183 | // Create a loader object. |
184 | Loader loader; |
185 | |
186 | // Register Loader extensions. |
187 | loader.registerExtension( |
188 | std::unique_ptr<LoaderExtension>(new testLoaderExtension())); |
189 | loader.registerExtension( |
190 | std::unique_ptr<LoaderExtension>(new secondTestLoaderExtension())); |
191 | |
192 | // Load a model |
193 | std::string NetDescFilename( |
194 | GLOW_DATA_PATH "tests/models/caffe2Models/sqr_predict_net.pbtxt" ); |
195 | std::string NetWeightFilename( |
196 | GLOW_DATA_PATH "tests/models/caffe2Models/empty_init_net.pbtxt" ); |
197 | |
198 | Placeholder *output; |
199 | Tensor inputData(ElemKind::FloatTy, {BATCH_SIZE, 2}); |
200 | Caffe2ModelLoader caffe2LD(NetDescFilename, NetWeightFilename, {"input" }, |
201 | {&inputData.getType()}, *loader.getFunction()); |
202 | output = EXIT_ON_ERR(caffe2LD.getSingleOutput()); |
203 | |
204 | // Check the model was loaded. |
205 | EXPECT_EQ(loader.getFunction()->getNodes().size(), 3); |
206 | auto *save = getSaveNodeFromDest(output); |
207 | ASSERT_TRUE(save); |
208 | auto *pow = llvm::dyn_cast<PowNode>(save->getInput().getNode()); |
209 | ASSERT_TRUE(pow); |
210 | auto *input = llvm::dyn_cast<Placeholder>(pow->getLHS().getNode()); |
211 | ASSERT_TRUE(input); |
212 | auto *splat = llvm::dyn_cast<SplatNode>(pow->getRHS().getNode()); |
213 | ASSERT_TRUE(splat); |
214 | |
215 | // Get bindings and call post model load extensions. |
216 | ASSERT_EQ(testLoaderExtension::stage_, 0); |
217 | loader.postModelLoad(bindings, caffe2LD, outputMap, &inputData.getType()); |
218 | ASSERT_EQ(testLoaderExtension::stage_, 1); |
219 | ASSERT_EQ(testLoaderExtension::loader_, &loader); |
220 | ASSERT_EQ(testLoaderExtension::bindings_, &bindings); |
221 | ASSERT_EQ(testLoaderExtension::protobufLoader_, &caffe2LD); |
222 | ASSERT_EQ(secondTestLoaderExtension::stage_, 1); |
223 | |
224 | // Allocate tensors to back all inputs and outputs. |
225 | bindings.allocate(loader.getModule()->getPlaceholders()); |
226 | |
227 | // Compile the model. |
228 | CompilationContext cctx = loader.getCompilationContext(); |
229 | cctx.bindings = &bindings; |
230 | loader.compile(cctx); |
231 | |
232 | // Load data to input placeholders. |
233 | updateInputPlaceholdersByName(bindings, loader.getModule(), {"input" }, |
234 | {&inputData}); |
235 | |
236 | // Run mini-batches. |
237 | for (size_t miniBatchIndex = 0; miniBatchIndex < BATCH_SIZE; |
238 | miniBatchIndex += MINI_BATCH_SIZE) { |
239 | // Minibatch inference initialization of loader extensions. |
240 | loader.inferInitMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE); |
241 | ASSERT_EQ(testLoaderExtension::stage_, 2); |
242 | ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex); |
243 | ASSERT_EQ(testLoaderExtension::loader_, &loader); |
244 | ASSERT_EQ(testLoaderExtension::bindings_, &bindings); |
245 | ASSERT_EQ(testLoaderExtension::protobufLoader_, &caffe2LD); |
246 | ASSERT_EQ(secondTestLoaderExtension::stage_, 2); |
247 | |
248 | // Perform the inference execution for a minibatch. |
249 | loader.runInference(exContext.get(), BATCH_SIZE); |
250 | |
251 | // Minibatch inference initialization of loader extensions. |
252 | loader.inferEndMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE); |
253 | ASSERT_EQ(testLoaderExtension::stage_, 3); |
254 | ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex); |
255 | ASSERT_EQ(testLoaderExtension::loader_, &loader); |
256 | ASSERT_EQ(testLoaderExtension::bindings_, &bindings); |
257 | ASSERT_EQ(testLoaderExtension::protobufLoader_, &caffe2LD); |
258 | ASSERT_EQ(secondTestLoaderExtension::stage_, 3); |
259 | } |
260 | |
261 | // Extension object not destructed yet. |
262 | ASSERT_EQ(testLoaderExtension::destructed_, false); |
263 | ASSERT_EQ(secondTestLoaderExtension::destructed_, false); |
264 | } // End of the loader scope. |
265 | |
266 | // Check that extensions were properly destructed by the Loader destruction. |
267 | ASSERT_EQ(testLoaderExtension::destructed_, true); |
268 | ASSERT_EQ(secondTestLoaderExtension::destructed_, true); |
269 | } |
270 | |
271 | TEST_F(LoaderTest, LoaderExtensionTFlite) { |
272 | { |
273 | std::unique_ptr<ExecutionContext> exContext = |
274 | glow::make_unique<ExecutionContext>(); |
275 | PlaceholderBindings &bindings = *exContext->getPlaceholderBindings(); |
276 | llvm::StringMap<Placeholder *> outputMap; |
277 | |
278 | // Create a loader object. |
279 | Loader loader; |
280 | |
281 | // Register Loader extensions. |
282 | loader.registerExtension( |
283 | std::unique_ptr<LoaderExtension>(new testLoaderExtension())); |
284 | loader.registerExtension( |
285 | std::unique_ptr<LoaderExtension>(new secondTestLoaderExtension())); |
286 | |
287 | // Load a model |
288 | std::string NetFilename(GLOW_DATA_PATH |
289 | "tests/models/tfliteModels/abs.tflite" ); |
290 | |
291 | Tensor inputData(ElemKind::FloatTy, {BATCH_SIZE_TFLITE, 10}); |
292 | TFLiteModelLoader LD(NetFilename, loader.getFunction()); |
293 | |
294 | // Check the model was loaded. |
295 | EXPECT_EQ(loader.getFunction()->getNodes().size(), 2); |
296 | |
297 | // Get bindings and call post model load extensions. |
298 | ASSERT_EQ(testLoaderExtension::stage_, 0); |
299 | loader.postModelLoad(bindings, LD, outputMap, &inputData.getType()); |
300 | ASSERT_EQ(testLoaderExtension::stage_, 4); |
301 | ASSERT_EQ(testLoaderExtension::loader_, &loader); |
302 | ASSERT_EQ(testLoaderExtension::bindings_, &bindings); |
303 | ASSERT_EQ(testLoaderExtension::tfliteloader_, &LD); |
304 | ASSERT_EQ(secondTestLoaderExtension::stage_, 4); |
305 | |
306 | // Allocate tensors to back all inputs and outputs. |
307 | bindings.allocate(loader.getModule()->getPlaceholders()); |
308 | |
309 | // Compile the model. |
310 | CompilationContext cctx = loader.getCompilationContext(); |
311 | cctx.bindings = &bindings; |
312 | loader.compile(cctx); |
313 | |
314 | // Load data to input placeholders. |
315 | updateInputPlaceholdersByName(bindings, loader.getModule(), {"input" }, |
316 | {&inputData}); |
317 | |
318 | // Run mini-batches. |
319 | for (size_t miniBatchIndex = 0; miniBatchIndex < BATCH_SIZE; |
320 | miniBatchIndex += MINI_BATCH_SIZE) { |
321 | // Minibatch inference initialization of loader extensions. |
322 | loader.inferInitMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE); |
323 | ASSERT_EQ(testLoaderExtension::stage_, 2); |
324 | ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex); |
325 | ASSERT_EQ(testLoaderExtension::loader_, &loader); |
326 | ASSERT_EQ(testLoaderExtension::bindings_, &bindings); |
327 | ASSERT_EQ(testLoaderExtension::tfliteloader_, &LD); |
328 | ASSERT_EQ(secondTestLoaderExtension::stage_, 2); |
329 | |
330 | // Perform the inference execution for a minibatch. |
331 | loader.runInference(exContext.get(), BATCH_SIZE); |
332 | |
333 | // Minibatch inference initialization of loader extensions. |
334 | loader.inferEndMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE); |
335 | ASSERT_EQ(testLoaderExtension::stage_, 3); |
336 | ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex); |
337 | ASSERT_EQ(testLoaderExtension::loader_, &loader); |
338 | ASSERT_EQ(testLoaderExtension::bindings_, &bindings); |
339 | ASSERT_EQ(testLoaderExtension::tfliteloader_, &LD); |
340 | ASSERT_EQ(secondTestLoaderExtension::stage_, 3); |
341 | } |
342 | |
343 | // Extension object not destructed yet. |
344 | ASSERT_EQ(testLoaderExtension::destructed_, false); |
345 | ASSERT_EQ(secondTestLoaderExtension::destructed_, false); |
346 | } // End of the loader scope. |
347 | |
348 | // Check that extensions were properly destructed by the Loader destruction. |
349 | ASSERT_EQ(testLoaderExtension::destructed_, true); |
350 | } |
351 | |
352 | TEST_F(LoaderTest, LoaderExtensionONNX) { |
353 | { |
354 | std::unique_ptr<ExecutionContext> exContext = |
355 | glow::make_unique<ExecutionContext>(); |
356 | PlaceholderBindings &bindings = *exContext->getPlaceholderBindings(); |
357 | llvm::StringMap<Placeholder *> outputMap; |
358 | |
359 | // Create a loader object. |
360 | Loader loader; |
361 | |
362 | // Register Loader extensions. |
363 | loader.registerExtension( |
364 | std::unique_ptr<LoaderExtension>(new testLoaderExtension())); |
365 | loader.registerExtension( |
366 | std::unique_ptr<LoaderExtension>(new secondTestLoaderExtension())); |
367 | |
368 | // Load a model |
369 | std::string NetFilename(GLOW_DATA_PATH |
370 | "tests/models/onnxModels/clip.onnxtxt" ); |
371 | |
372 | Tensor inputData(ElemKind::FloatTy, {BATCH_SIZE, 3}); |
373 | ONNXModelLoader LD(NetFilename, {"x" }, {&inputData.getType()}, |
374 | *loader.getFunction()); |
375 | |
376 | // Check the model was loaded. |
377 | EXPECT_EQ(loader.getFunction()->getNodes().size(), 2); |
378 | |
379 | // Get bindings and call post model load extensions. |
380 | ASSERT_EQ(testLoaderExtension::stage_, 0); |
381 | loader.postModelLoad(bindings, LD, outputMap, &inputData.getType()); |
382 | ASSERT_EQ(testLoaderExtension::stage_, 1); |
383 | ASSERT_EQ(testLoaderExtension::loader_, &loader); |
384 | ASSERT_EQ(testLoaderExtension::bindings_, &bindings); |
385 | ASSERT_EQ(testLoaderExtension::protobufLoader_, &LD); |
386 | ASSERT_EQ(secondTestLoaderExtension::stage_, 1); |
387 | |
388 | // Allocate tensors to back all inputs and outputs. |
389 | bindings.allocate(loader.getModule()->getPlaceholders()); |
390 | |
391 | // Compile the model. |
392 | CompilationContext cctx = loader.getCompilationContext(); |
393 | cctx.bindings = &bindings; |
394 | loader.compile(cctx); |
395 | |
396 | // Load data to input placeholders. |
397 | updateInputPlaceholdersByName(bindings, loader.getModule(), {"x" }, |
398 | {&inputData}); |
399 | |
400 | // Run mini-batches. |
401 | for (size_t miniBatchIndex = 0; miniBatchIndex < BATCH_SIZE; |
402 | miniBatchIndex += MINI_BATCH_SIZE) { |
403 | // Minibatch inference initialization of loader extensions. |
404 | loader.inferInitMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE); |
405 | ASSERT_EQ(testLoaderExtension::stage_, 2); |
406 | ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex); |
407 | ASSERT_EQ(testLoaderExtension::loader_, &loader); |
408 | ASSERT_EQ(testLoaderExtension::bindings_, &bindings); |
409 | ASSERT_EQ(testLoaderExtension::protobufLoader_, &LD); |
410 | ASSERT_EQ(secondTestLoaderExtension::stage_, 2); |
411 | |
412 | // Perform the inference execution for a minibatch. |
413 | loader.runInference(exContext.get(), BATCH_SIZE); |
414 | |
415 | // Minibatch inference initialization of loader extensions. |
416 | loader.inferEndMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE); |
417 | ASSERT_EQ(testLoaderExtension::stage_, 3); |
418 | ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex); |
419 | ASSERT_EQ(testLoaderExtension::loader_, &loader); |
420 | ASSERT_EQ(testLoaderExtension::bindings_, &bindings); |
421 | ASSERT_EQ(testLoaderExtension::protobufLoader_, &LD); |
422 | ASSERT_EQ(secondTestLoaderExtension::stage_, 3); |
423 | } |
424 | |
425 | // Extension object not destructed yet. |
426 | ASSERT_EQ(testLoaderExtension::destructed_, false); |
427 | ASSERT_EQ(secondTestLoaderExtension::destructed_, false); |
428 | } // End of the loader scope. |
429 | |
430 | // Check that extensions were properly destructed by the Loader destruction. |
431 | ASSERT_EQ(testLoaderExtension::destructed_, true); |
432 | ASSERT_EQ(secondTestLoaderExtension::destructed_, true); |
433 | } |
434 | |