1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "tools/loader/Loader.h"
18#include "ImporterTestUtils.h"
19#include "glow/ExecutionEngine/ExecutionEngine.h"
20#include "glow/Graph/Graph.h"
21#include "glow/Importer/Caffe2ModelLoader.h"
22#include "glow/Importer/ONNXModelLoader.h"
23#include "gtest/gtest.h"
24
25#include "llvm/ADT/StringMap.h"
26
27#ifndef GLOW_DATA_PATH
28#define GLOW_DATA_PATH
29#endif
30
31class LoaderTest : public ::testing::Test {
32protected:
33 // By default constant folding at load time is enabled in general, but we do
34 // many tests here loading Constants, so keep it false during these tests by
35 // default.
36 void SetUp() override { glow::setConstantFoldLoaderOpsFlag(false); }
37 void TearDown() override { glow::setConstantFoldLoaderOpsFlag(true); }
38};
39
40using namespace glow;
41
42namespace {
43const dim_t BATCH_SIZE = 3;
44const dim_t BATCH_SIZE_TFLITE = 1;
45const size_t MINI_BATCH_SIZE = 2;
46} // namespace
47
48// A Loader extension class for testing purpose.
49class testLoaderExtension : public LoaderExtension {
50public:
51 static int stage_;
52 static size_t index_;
53 static Loader *loader_;
54 static PlaceholderBindings *bindings_;
55 static ProtobufLoader *protobufLoader_;
56 static TFLiteModelLoader *tfliteloader_;
57 static bool destructed_;
58
59 testLoaderExtension() {
60 stage_ = 0;
61 index_ = 0;
62 loader_ = nullptr;
63 bindings_ = nullptr;
64 protobufLoader_ = nullptr;
65 destructed_ = false;
66 }
67
68 /// Called once after ONNX or Caffe2 model loading.
69 virtual void postModelLoad(Loader &loader, PlaceholderBindings &bindings,
70 ProtobufLoader &protobufLoader,
71 llvm::StringMap<Placeholder *> &outputMap,
72 llvm::ArrayRef<TypeRef> inputImageType) {
73
74 size_t compilationBatchSize = inputImageType[0]->dims()[0];
75 // To check the method was executed.
76 stage_ = 1;
77
78 // To check params are correctly set.
79 loader_ = &loader;
80 bindings_ = &bindings;
81 protobufLoader_ = &protobufLoader;
82 EXPECT_EQ(BATCH_SIZE, compilationBatchSize);
83 }
84 /// Called once at the beginning of the mini-batch inference.
85 virtual void inferInitMiniBatch(Loader &loader, PlaceholderBindings &bindings,
86 size_t minibatchIndex, size_t minibatchSize) {
87 // To check the method was executed.
88 stage_ = 2;
89
90 // To check params are correctly set.
91 loader_ = &loader;
92 bindings_ = &bindings;
93 index_ = minibatchIndex;
94 EXPECT_EQ(MINI_BATCH_SIZE, minibatchSize);
95 }
96 /// Called once after the completion of the mini-batch inference.
97 virtual void inferEndMiniBatch(Loader &loader, PlaceholderBindings &bindings,
98 size_t minibatchIndex, size_t minibatchSize) {
99 // To check the method was executed.
100 stage_ = 3;
101
102 // To check params are correctly set.
103 loader_ = &loader;
104 bindings_ = &bindings;
105 index_ = minibatchIndex;
106 EXPECT_EQ(MINI_BATCH_SIZE, minibatchSize);
107 }
108 /// Called once after TFLite.
109 virtual void postModelLoad(Loader &loader, PlaceholderBindings &bindings,
110 TFLiteModelLoader &tfliteLoader,
111 llvm::StringMap<Placeholder *> &outputMap,
112 llvm::ArrayRef<TypeRef> inputImageType) {
113
114 size_t compilationBatchSize = inputImageType[0]->dims()[0];
115 // To check the method was executed.
116 stage_ = 4;
117
118 // To check params are correctly set.
119 loader_ = &loader;
120 bindings_ = &bindings;
121 tfliteloader_ = &tfliteLoader;
122 EXPECT_EQ(BATCH_SIZE_TFLITE, compilationBatchSize);
123 }
124 virtual ~testLoaderExtension() { destructed_ = true; }
125};
126
127// A simple Loader second extension class.
128class secondTestLoaderExtension : public LoaderExtension {
129public:
130 static int stage_;
131 static bool destructed_;
132
133 secondTestLoaderExtension() {
134 stage_ = 0;
135 destructed_ = false;
136 }
137
138 /// Called once after ONNX or Caffe2 model loading.
139 virtual void postModelLoad(Loader &, PlaceholderBindings &, ProtobufLoader &,
140 llvm::StringMap<Placeholder *> &,
141 llvm::ArrayRef<TypeRef> inputImageType) {
142 stage_ = 1;
143 }
144 /// Called once at the beginning of the mini-batch inference.
145 virtual void inferInitMiniBatch(Loader &, PlaceholderBindings &, size_t,
146 size_t) {
147 stage_ = 2;
148 }
149 /// Called once after the completion of the mini-batch inference.
150 virtual void inferEndMiniBatch(Loader &, PlaceholderBindings &, size_t,
151 size_t) {
152 stage_ = 3;
153 }
154 virtual void postModelLoad(Loader &, PlaceholderBindings &,
155 TFLiteModelLoader &,
156 llvm::StringMap<Placeholder *> &,
157 llvm::ArrayRef<TypeRef> inputImageType) {
158 stage_ = 4;
159 }
160
161 virtual ~secondTestLoaderExtension() { destructed_ = true; }
162};
163
164// Static class members.
165int testLoaderExtension::stage_;
166size_t testLoaderExtension::index_;
167Loader *testLoaderExtension::loader_;
168PlaceholderBindings *testLoaderExtension::bindings_;
169ProtobufLoader *testLoaderExtension::protobufLoader_;
170TFLiteModelLoader *testLoaderExtension::tfliteloader_;
171bool testLoaderExtension::destructed_;
172int secondTestLoaderExtension::stage_;
173bool secondTestLoaderExtension::destructed_;
174
175/// This test simulates what can be a Glow application (like image_classifier).
176TEST_F(LoaderTest, LoaderExtensionCaffe2) {
177 {
178 std::unique_ptr<ExecutionContext> exContext =
179 glow::make_unique<ExecutionContext>();
180 PlaceholderBindings &bindings = *exContext->getPlaceholderBindings();
181 llvm::StringMap<Placeholder *> outputMap;
182
183 // Create a loader object.
184 Loader loader;
185
186 // Register Loader extensions.
187 loader.registerExtension(
188 std::unique_ptr<LoaderExtension>(new testLoaderExtension()));
189 loader.registerExtension(
190 std::unique_ptr<LoaderExtension>(new secondTestLoaderExtension()));
191
192 // Load a model
193 std::string NetDescFilename(
194 GLOW_DATA_PATH "tests/models/caffe2Models/sqr_predict_net.pbtxt");
195 std::string NetWeightFilename(
196 GLOW_DATA_PATH "tests/models/caffe2Models/empty_init_net.pbtxt");
197
198 Placeholder *output;
199 Tensor inputData(ElemKind::FloatTy, {BATCH_SIZE, 2});
200 Caffe2ModelLoader caffe2LD(NetDescFilename, NetWeightFilename, {"input"},
201 {&inputData.getType()}, *loader.getFunction());
202 output = EXIT_ON_ERR(caffe2LD.getSingleOutput());
203
204 // Check the model was loaded.
205 EXPECT_EQ(loader.getFunction()->getNodes().size(), 3);
206 auto *save = getSaveNodeFromDest(output);
207 ASSERT_TRUE(save);
208 auto *pow = llvm::dyn_cast<PowNode>(save->getInput().getNode());
209 ASSERT_TRUE(pow);
210 auto *input = llvm::dyn_cast<Placeholder>(pow->getLHS().getNode());
211 ASSERT_TRUE(input);
212 auto *splat = llvm::dyn_cast<SplatNode>(pow->getRHS().getNode());
213 ASSERT_TRUE(splat);
214
215 // Get bindings and call post model load extensions.
216 ASSERT_EQ(testLoaderExtension::stage_, 0);
217 loader.postModelLoad(bindings, caffe2LD, outputMap, &inputData.getType());
218 ASSERT_EQ(testLoaderExtension::stage_, 1);
219 ASSERT_EQ(testLoaderExtension::loader_, &loader);
220 ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
221 ASSERT_EQ(testLoaderExtension::protobufLoader_, &caffe2LD);
222 ASSERT_EQ(secondTestLoaderExtension::stage_, 1);
223
224 // Allocate tensors to back all inputs and outputs.
225 bindings.allocate(loader.getModule()->getPlaceholders());
226
227 // Compile the model.
228 CompilationContext cctx = loader.getCompilationContext();
229 cctx.bindings = &bindings;
230 loader.compile(cctx);
231
232 // Load data to input placeholders.
233 updateInputPlaceholdersByName(bindings, loader.getModule(), {"input"},
234 {&inputData});
235
236 // Run mini-batches.
237 for (size_t miniBatchIndex = 0; miniBatchIndex < BATCH_SIZE;
238 miniBatchIndex += MINI_BATCH_SIZE) {
239 // Minibatch inference initialization of loader extensions.
240 loader.inferInitMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE);
241 ASSERT_EQ(testLoaderExtension::stage_, 2);
242 ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex);
243 ASSERT_EQ(testLoaderExtension::loader_, &loader);
244 ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
245 ASSERT_EQ(testLoaderExtension::protobufLoader_, &caffe2LD);
246 ASSERT_EQ(secondTestLoaderExtension::stage_, 2);
247
248 // Perform the inference execution for a minibatch.
249 loader.runInference(exContext.get(), BATCH_SIZE);
250
251 // Minibatch inference initialization of loader extensions.
252 loader.inferEndMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE);
253 ASSERT_EQ(testLoaderExtension::stage_, 3);
254 ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex);
255 ASSERT_EQ(testLoaderExtension::loader_, &loader);
256 ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
257 ASSERT_EQ(testLoaderExtension::protobufLoader_, &caffe2LD);
258 ASSERT_EQ(secondTestLoaderExtension::stage_, 3);
259 }
260
261 // Extension object not destructed yet.
262 ASSERT_EQ(testLoaderExtension::destructed_, false);
263 ASSERT_EQ(secondTestLoaderExtension::destructed_, false);
264 } // End of the loader scope.
265
266 // Check that extensions were properly destructed by the Loader destruction.
267 ASSERT_EQ(testLoaderExtension::destructed_, true);
268 ASSERT_EQ(secondTestLoaderExtension::destructed_, true);
269}
270
271TEST_F(LoaderTest, LoaderExtensionTFlite) {
272 {
273 std::unique_ptr<ExecutionContext> exContext =
274 glow::make_unique<ExecutionContext>();
275 PlaceholderBindings &bindings = *exContext->getPlaceholderBindings();
276 llvm::StringMap<Placeholder *> outputMap;
277
278 // Create a loader object.
279 Loader loader;
280
281 // Register Loader extensions.
282 loader.registerExtension(
283 std::unique_ptr<LoaderExtension>(new testLoaderExtension()));
284 loader.registerExtension(
285 std::unique_ptr<LoaderExtension>(new secondTestLoaderExtension()));
286
287 // Load a model
288 std::string NetFilename(GLOW_DATA_PATH
289 "tests/models/tfliteModels/abs.tflite");
290
291 Tensor inputData(ElemKind::FloatTy, {BATCH_SIZE_TFLITE, 10});
292 TFLiteModelLoader LD(NetFilename, loader.getFunction());
293
294 // Check the model was loaded.
295 EXPECT_EQ(loader.getFunction()->getNodes().size(), 2);
296
297 // Get bindings and call post model load extensions.
298 ASSERT_EQ(testLoaderExtension::stage_, 0);
299 loader.postModelLoad(bindings, LD, outputMap, &inputData.getType());
300 ASSERT_EQ(testLoaderExtension::stage_, 4);
301 ASSERT_EQ(testLoaderExtension::loader_, &loader);
302 ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
303 ASSERT_EQ(testLoaderExtension::tfliteloader_, &LD);
304 ASSERT_EQ(secondTestLoaderExtension::stage_, 4);
305
306 // Allocate tensors to back all inputs and outputs.
307 bindings.allocate(loader.getModule()->getPlaceholders());
308
309 // Compile the model.
310 CompilationContext cctx = loader.getCompilationContext();
311 cctx.bindings = &bindings;
312 loader.compile(cctx);
313
314 // Load data to input placeholders.
315 updateInputPlaceholdersByName(bindings, loader.getModule(), {"input"},
316 {&inputData});
317
318 // Run mini-batches.
319 for (size_t miniBatchIndex = 0; miniBatchIndex < BATCH_SIZE;
320 miniBatchIndex += MINI_BATCH_SIZE) {
321 // Minibatch inference initialization of loader extensions.
322 loader.inferInitMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE);
323 ASSERT_EQ(testLoaderExtension::stage_, 2);
324 ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex);
325 ASSERT_EQ(testLoaderExtension::loader_, &loader);
326 ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
327 ASSERT_EQ(testLoaderExtension::tfliteloader_, &LD);
328 ASSERT_EQ(secondTestLoaderExtension::stage_, 2);
329
330 // Perform the inference execution for a minibatch.
331 loader.runInference(exContext.get(), BATCH_SIZE);
332
333 // Minibatch inference initialization of loader extensions.
334 loader.inferEndMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE);
335 ASSERT_EQ(testLoaderExtension::stage_, 3);
336 ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex);
337 ASSERT_EQ(testLoaderExtension::loader_, &loader);
338 ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
339 ASSERT_EQ(testLoaderExtension::tfliteloader_, &LD);
340 ASSERT_EQ(secondTestLoaderExtension::stage_, 3);
341 }
342
343 // Extension object not destructed yet.
344 ASSERT_EQ(testLoaderExtension::destructed_, false);
345 ASSERT_EQ(secondTestLoaderExtension::destructed_, false);
346 } // End of the loader scope.
347
348 // Check that extensions were properly destructed by the Loader destruction.
349 ASSERT_EQ(testLoaderExtension::destructed_, true);
350}
351
352TEST_F(LoaderTest, LoaderExtensionONNX) {
353 {
354 std::unique_ptr<ExecutionContext> exContext =
355 glow::make_unique<ExecutionContext>();
356 PlaceholderBindings &bindings = *exContext->getPlaceholderBindings();
357 llvm::StringMap<Placeholder *> outputMap;
358
359 // Create a loader object.
360 Loader loader;
361
362 // Register Loader extensions.
363 loader.registerExtension(
364 std::unique_ptr<LoaderExtension>(new testLoaderExtension()));
365 loader.registerExtension(
366 std::unique_ptr<LoaderExtension>(new secondTestLoaderExtension()));
367
368 // Load a model
369 std::string NetFilename(GLOW_DATA_PATH
370 "tests/models/onnxModels/clip.onnxtxt");
371
372 Tensor inputData(ElemKind::FloatTy, {BATCH_SIZE, 3});
373 ONNXModelLoader LD(NetFilename, {"x"}, {&inputData.getType()},
374 *loader.getFunction());
375
376 // Check the model was loaded.
377 EXPECT_EQ(loader.getFunction()->getNodes().size(), 2);
378
379 // Get bindings and call post model load extensions.
380 ASSERT_EQ(testLoaderExtension::stage_, 0);
381 loader.postModelLoad(bindings, LD, outputMap, &inputData.getType());
382 ASSERT_EQ(testLoaderExtension::stage_, 1);
383 ASSERT_EQ(testLoaderExtension::loader_, &loader);
384 ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
385 ASSERT_EQ(testLoaderExtension::protobufLoader_, &LD);
386 ASSERT_EQ(secondTestLoaderExtension::stage_, 1);
387
388 // Allocate tensors to back all inputs and outputs.
389 bindings.allocate(loader.getModule()->getPlaceholders());
390
391 // Compile the model.
392 CompilationContext cctx = loader.getCompilationContext();
393 cctx.bindings = &bindings;
394 loader.compile(cctx);
395
396 // Load data to input placeholders.
397 updateInputPlaceholdersByName(bindings, loader.getModule(), {"x"},
398 {&inputData});
399
400 // Run mini-batches.
401 for (size_t miniBatchIndex = 0; miniBatchIndex < BATCH_SIZE;
402 miniBatchIndex += MINI_BATCH_SIZE) {
403 // Minibatch inference initialization of loader extensions.
404 loader.inferInitMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE);
405 ASSERT_EQ(testLoaderExtension::stage_, 2);
406 ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex);
407 ASSERT_EQ(testLoaderExtension::loader_, &loader);
408 ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
409 ASSERT_EQ(testLoaderExtension::protobufLoader_, &LD);
410 ASSERT_EQ(secondTestLoaderExtension::stage_, 2);
411
412 // Perform the inference execution for a minibatch.
413 loader.runInference(exContext.get(), BATCH_SIZE);
414
415 // Minibatch inference initialization of loader extensions.
416 loader.inferEndMiniBatch(bindings, miniBatchIndex, MINI_BATCH_SIZE);
417 ASSERT_EQ(testLoaderExtension::stage_, 3);
418 ASSERT_EQ(testLoaderExtension::index_, miniBatchIndex);
419 ASSERT_EQ(testLoaderExtension::loader_, &loader);
420 ASSERT_EQ(testLoaderExtension::bindings_, &bindings);
421 ASSERT_EQ(testLoaderExtension::protobufLoader_, &LD);
422 ASSERT_EQ(secondTestLoaderExtension::stage_, 3);
423 }
424
425 // Extension object not destructed yet.
426 ASSERT_EQ(testLoaderExtension::destructed_, false);
427 ASSERT_EQ(secondTestLoaderExtension::destructed_, false);
428 } // End of the loader scope.
429
430 // Check that extensions were properly destructed by the Loader destruction.
431 ASSERT_EQ(testLoaderExtension::destructed_, true);
432 ASSERT_EQ(secondTestLoaderExtension::destructed_, true);
433}
434