1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "ImporterTestUtils.h"
17#include "glow/ExecutionEngine/ExecutionEngine.h"
18#include "glow/Graph/Graph.h"
19#include "glow/Graph/Nodes.h"
20#include "glow/Graph/PlaceholderBindings.h"
21#include "glow/Importer/TFLiteModelLoader.h"
22#include "gtest/gtest.h"
23
24#include "llvm/Support/CommandLine.h"
25
26#include <fstream>
27
28namespace {
29
30llvm::cl::OptionCategory tfliteModelTestCat("TFLITE Test Options");
31
32llvm::cl::opt<bool> tflitePrintTestTensorsOpt(
33 "tflite-dump-test-tensors", llvm::cl::init(false), llvm::cl::Optional,
34 llvm::cl::desc(
35 "Print input/expected tensors from test files. Default is false."),
36 llvm::cl::cat(tfliteModelTestCat));
37
38} // namespace
39
40using namespace glow;
41
42class TFLiteImporterTest : public ::testing::Test {};
43
/// Returns the path of the TensorFlowLite model file \p name. The path is
/// relative, i.e. resolved against the working directory of the test run.
static std::string getModelPath(const std::string &name) {
  return "tests/models/tfliteModels/" + name;
}
48
49/// Utility function to load a binary file from \p fileName into \p tensor.
50/// The binary files have a special format with an extra byte of '0' at the
51/// start of the file followed by the actual tensor binary content. The extra
52/// '0' leading byte was required in order for the GIT system to correctly
53/// recognize the files as being binary files.
54static void loadTensor(Tensor *tensor, const std::string &fileName) {
55 std::ifstream file;
56 file.open(fileName, std::ios::binary);
57 assert(file.is_open() && "Error opening tensor file!");
58 file.seekg(1);
59 file.read(tensor->getUnsafePtr(), tensor->getSizeInBytes());
60 file.close();
61}
62
63/// Utility function to load and run TensorFlowLite model named \p modelName.
64/// The model with the name <name>.tflite is also associated with binary files
65/// used to validate numerically the model. The binary files have the following
66/// naming convention: <name>.inp0, <name>.inp1, etc for the model inputs and
67/// <name>.out0, <name>.out1, etc for the model reference outputs. When testing
68/// the output of the model a maximum error of \p maxError is allowed.
69static void loadAndRunModel(std::string modelName, float maxError = 1e-6) {
70 ExecutionEngine EE;
71 auto &mod = EE.getModule();
72 Function *F = mod.createFunction("main");
73
74 // Load TensorFlowLite model.
75 std::string modelPath = getModelPath(modelName);
76 { TFLiteModelLoader(modelPath, F); }
77
78 // Allocate tensors for all placeholders.
79 PlaceholderBindings bindings;
80 bindings.allocate(mod.getPlaceholders());
81
82 // Get model input/output placeholders.
83 PlaceholderList inputPH;
84 PlaceholderList outputPH;
85 for (const auto &ph : mod.getPlaceholders()) {
86 if (isInput(ph, *F)) {
87 inputPH.push_back(ph);
88 } else {
89 outputPH.push_back(ph);
90 }
91 }
92
93 // Load data into the input placeholders.
94 size_t dotPos = llvm::StringRef(modelPath).find_first_of('.');
95 std::string dataBasename = std::string(modelPath).substr(0, dotPos);
96 size_t inpIdx = 0;
97 for (const auto &inpPH : inputPH) {
98 std::string inpFilename = dataBasename + ".inp" + std::to_string(inpIdx++);
99 Tensor *inpT = bindings.get(inpPH);
100 loadTensor(inpT, inpFilename);
101 if (tflitePrintTestTensorsOpt) {
102 llvm::outs() << "Input Placeholder: " << inpPH->getName() << "\n";
103 inpT->dump();
104 }
105 }
106
107 // Run model.
108 EE.compile(CompilationMode::Infer);
109 EE.run(bindings);
110
111 // Compare output data versus reference.
112 size_t outIdx = 0;
113 for (const auto &outPH : outputPH) {
114 std::string refFilename = dataBasename + ".out" + std::to_string(outIdx++);
115
116 // Get output tensor.
117 Tensor *outT = bindings.get(outPH);
118
119 // Load reference tensor.
120 Tensor refT(outT->getType());
121 loadTensor(&refT, refFilename);
122 if (tflitePrintTestTensorsOpt) {
123 llvm::outs() << "Reference Tensor:\n";
124 refT.dump();
125 llvm::outs() << "Output Placeholder: " << outPH->getName() << "\n";
126 outT->dump();
127 }
128
129 // Compare.
130 ASSERT_TRUE(outT->isEqual(refT, maxError, /* verbose */ true));
131 }
132}
133
134#define TFLITE_UNIT_TEST(name, model) \
135 TEST(TFLiteImporterTest, name) { loadAndRunModel(model); }
136
137TFLITE_UNIT_TEST(Add, "add.tflite")
138
139TFLITE_UNIT_TEST(AvgPool2D_PaddingSame, "avgpool2d_same.tflite")
140TFLITE_UNIT_TEST(AvgPool2D_PaddingValid, "avgpool2d_valid.tflite")
141
142TFLITE_UNIT_TEST(Concat, "concat.tflite")
143TFLITE_UNIT_TEST(ConcatNegAxis, "concat_neg_axis.tflite")
144
145TFLITE_UNIT_TEST(Conv2D_PaddingSame, "conv2d_same.tflite")
146TFLITE_UNIT_TEST(Conv2D_PaddingValid, "conv2d_valid.tflite")
147TFLITE_UNIT_TEST(Conv2D_FusedRelu, "conv2d_relu.tflite")
148
149TFLITE_UNIT_TEST(DepthwiseConv2D_Ch1Mult1, "depthwise_conv2d_c1_m1.tflite")
150TFLITE_UNIT_TEST(DepthwiseConv2D_Ch1Mult2, "depthwise_conv2d_c1_m2.tflite")
151TFLITE_UNIT_TEST(DepthwiseConv2D_Ch2Mult1, "depthwise_conv2d_c2_m1.tflite")
152TFLITE_UNIT_TEST(DepthwiseConv2D_Ch2Mult2, "depthwise_conv2d_c2_m2.tflite")
153
154TFLITE_UNIT_TEST(HardSwish, "hardSwish.tflite")
155
156TFLITE_UNIT_TEST(Floor, "floor.tflite")
157
158TFLITE_UNIT_TEST(FullyConnected, "fully_connected.tflite")
159
160TFLITE_UNIT_TEST(Sigmoid, "sigmoid.tflite")
161
162TFLITE_UNIT_TEST(MaxPool2D_PaddingSame, "maxpool2d_same.tflite")
163TFLITE_UNIT_TEST(MaxPool2D_PaddingValid, "maxpool2d_valid.tflite")
164
165TFLITE_UNIT_TEST(Mul, "mul.tflite")
166
167TFLITE_UNIT_TEST(Relu, "relu.tflite")
168
169TFLITE_UNIT_TEST(ReluN1To1, "relu_n1to1.tflite")
170
171TFLITE_UNIT_TEST(Relu6, "relu6.tflite")
172
173TFLITE_UNIT_TEST(Reshape, "reshape.tflite")
174TFLITE_UNIT_TEST(ReshapeNegShape, "reshape_neg_shape.tflite")
175
176TFLITE_UNIT_TEST(Softmax, "softmax.tflite")
177
178TFLITE_UNIT_TEST(Tanh, "tanh.tflite")
179
180TFLITE_UNIT_TEST(Pad, "pad.tflite")
181
182TFLITE_UNIT_TEST(Transpose, "transpose.tflite")
183
184TFLITE_UNIT_TEST(MeanKeepDims, "mean_keep_dims.tflite")
185TFLITE_UNIT_TEST(MeanNoKeepDims, "mean_no_keep_dims.tflite")
186TFLITE_UNIT_TEST(MeanMultipleAxisKeepDims,
187 "mean_multiple_axis_keep_dims.tflite")
188TFLITE_UNIT_TEST(MeanMultipleAxisNoKeepDims,
189 "mean_multiple_axis_no_keep_dims.tflite")
190
191TFLITE_UNIT_TEST(Sub, "sub.tflite")
192
193TFLITE_UNIT_TEST(Div, "div.tflite")
194
195TFLITE_UNIT_TEST(Exp, "exp.tflite")
196
197TFLITE_UNIT_TEST(Split, "split.tflite")
198
199TFLITE_UNIT_TEST(PRelu, "prelu.tflite")
200
201TFLITE_UNIT_TEST(Maximum, "max.tflite")
202
203TFLITE_UNIT_TEST(ArgMax, "arg_max.tflite")
204
205TFLITE_UNIT_TEST(Minimum, "min.tflite")
206
207TFLITE_UNIT_TEST(Less, "less.tflite")
208
209TFLITE_UNIT_TEST(Neg, "neg.tflite")
210
211TFLITE_UNIT_TEST(Greater, "greater.tflite")
212
213TFLITE_UNIT_TEST(GreaterEqual, "greater_equal.tflite")
214
215TFLITE_UNIT_TEST(LessEqual, "less_equal.tflite")
216
217TFLITE_UNIT_TEST(Slice, "slice.tflite")
218TFLITE_UNIT_TEST(SliceNegSize, "slice_neg_size.tflite")
219
220TFLITE_UNIT_TEST(StridedSliceTest0, "strided_slice_test0.tflite")
221TFLITE_UNIT_TEST(StridedSliceTest1, "strided_slice_test1.tflite")
222TFLITE_UNIT_TEST(StridedSliceTest2, "strided_slice_test2.tflite")
223TFLITE_UNIT_TEST(StridedSliceTest3, "strided_slice_test3.tflite")
224TFLITE_UNIT_TEST(StridedSliceTest4, "strided_slice_test4.tflite")
225TFLITE_UNIT_TEST(StridedSliceTest5, "strided_slice_test5.tflite")
226TFLITE_UNIT_TEST(StridedSliceTest6, "strided_slice_test6.tflite")
227
228TFLITE_UNIT_TEST(Sin, "sin.tflite")
229
230TFLITE_UNIT_TEST(Tile, "tile.tflite")
231
232TFLITE_UNIT_TEST(ResizeBilinear, "resize_bilinear.tflite")
233
234TFLITE_UNIT_TEST(ResizeNearest, "resize_nearest.tflite")
235
236TFLITE_UNIT_TEST(SpaceToDepth, "space_to_depth.tflite")
237
238TFLITE_UNIT_TEST(DepthToSpace, "depth_to_space.tflite")
239
240TFLITE_UNIT_TEST(CastF32ToInt32, "cast_f32_to_int32.tflite")
241
242TFLITE_UNIT_TEST(GatherAxis0, "gather_axis0.tflite")
243TFLITE_UNIT_TEST(GatherAxis1, "gather_axis1.tflite")
244
245TFLITE_UNIT_TEST(GatherND, "gather_nd.tflite")
246
247TFLITE_UNIT_TEST(LogSoftmax, "log_softmax.tflite")
248
249TFLITE_UNIT_TEST(Select, "select.tflite")
250
251TFLITE_UNIT_TEST(SpaceToBatchNd, "spaceToBatchNd.tflite")
252TFLITE_UNIT_TEST(BatchToSpaceNd, "batchToSpaceNd.tflite")
253
254TFLITE_UNIT_TEST(Equal, "equal.tflite")
255
256TFLITE_UNIT_TEST(NotEqual, "not_equal.tflite")
257
258TFLITE_UNIT_TEST(Log, "log.tflite")
259
260TFLITE_UNIT_TEST(Sqrt, "sqrt.tflite")
261
262TFLITE_UNIT_TEST(Rsqrt, "rsqrt.tflite")
263
264TFLITE_UNIT_TEST(Pow, "pow.tflite")
265
266TFLITE_UNIT_TEST(ArgMin, "arg_min.tflite")
267
268TFLITE_UNIT_TEST(Pack, "pack.tflite")
269
270TFLITE_UNIT_TEST(LogicalOr, "logical_or.tflite")
271
272TFLITE_UNIT_TEST(LogicalAnd, "logical_and.tflite")
273
274TFLITE_UNIT_TEST(LogicalNot, "logical_not.tflite")
275
276TFLITE_UNIT_TEST(Unpack, "unpack.tflite")
277
278TFLITE_UNIT_TEST(Square, "square.tflite")
279
280TFLITE_UNIT_TEST(LeakyRelu, "leaky_relu.tflite")
281
282TFLITE_UNIT_TEST(Abs, "abs.tflite")
283
284TFLITE_UNIT_TEST(Ceil, "ceil.tflite")
285
286TFLITE_UNIT_TEST(Cos, "cos.tflite")
287
288TFLITE_UNIT_TEST(Round, "round.tflite")
289
290TFLITE_UNIT_TEST(Add_broadcast, "add_broadcast.tflite")
291TFLITE_UNIT_TEST(Sub_broadcast, "sub_broadcast.tflite")
292TFLITE_UNIT_TEST(Div_broadcast, "div_broadcast.tflite")
293TFLITE_UNIT_TEST(Mul_broadcast, "mul_broadcast.tflite")
294TFLITE_UNIT_TEST(Min_broadcast, "min_broadcast.tflite")
295TFLITE_UNIT_TEST(Max_broadcast, "max_broadcast.tflite")
296
297#undef TFLITE_UNIT_TEST
298
299/// Test Regular TFLiteDetectionPostProcess node.
300TEST(TFLiteImporterTest, TFLiteDetectionPostProcessRegular) {
301 ExecutionEngine EE;
302 auto &mod = EE.getModule();
303 Function *F = mod.createFunction("main");
304
305 // Load TensorFlowLite model.
306 std::string modelPath =
307 getModelPath("tflite_detection_post_processing_regular.tflite");
308 { TFLiteModelLoader(modelPath, F); }
309
310 // Allocate tensors for all placeholders.
311 PlaceholderBindings bindings;
312 bindings.allocate(mod.getPlaceholders());
313
314 // Get model input/output placeholders.
315 std::vector<Placeholder *> inputPH;
316 std::vector<Placeholder *> outputPH;
317 for (const auto &ph : mod.getPlaceholders()) {
318 if (isInput(ph, *F)) {
319 inputPH.push_back(ph);
320 } else {
321 outputPH.push_back(ph);
322 }
323 }
324
325 // Load data into the input placeholders.
326 loadTensor(bindings.get(inputPH[0]),
327 getModelPath("tflite_detection_post_processing_boxes.bin"));
328 loadTensor(bindings.get(inputPH[1]),
329 getModelPath("tflite_detection_post_processing_scores.bin"));
330
331 // Run model.
332 EE.compile(CompilationMode::Infer);
333 EE.run(bindings);
334
335 // Compare output data versus reference.
336 std::vector<float> detectionBoxesRef = {
337 0.270546197891235, 0.036445915699005, 0.625426292419434,
338 0.715417265892029, 0.008843034505844, 0.453001916408539,
339 0.434335529804230, 1.007383584976196, 0.264277368783951,
340 0.225462928414345, 0.431514173746109, 0.499467015266418,
341 0.012970104813576, 0.489649474620819, 0.433307945728302,
342 1.010598421096802, 0.208248645067215, 0.414025753736496,
343 0.256930917501450, 0.457198470830917, 0.259306669235229,
344 0.276896983385086, 0.413792371749878, 0.558155655860901,
345 0.296046763658524, 0.024428725242615, 0.620571494102478,
346 0.726388156414032, 0.100624501705170, 0.478332787752151,
347 0.341053903102875, 0.616274893283844, 0.195692524313927,
348 0.446290910243988, 0.264245152473450, 0.527587413787842,
349 0.232087373733521, 0.244561776518822, 0.373351573944092,
350 0.512895405292511,
351 };
352 std::vector<int32_t> detectionClassesRef = {
353 2, 7, 2, 5, 2, 2, 32, 7, 2, 2,
354 };
355 std::vector<float> detectionScoresRef = {
356 0.709131240844727, 0.694569468498230, 0.563223838806152,
357 0.540955007076263, 0.452089250087738, 0.439201682806015,
358 0.433123916387558, 0.432144701480865, 0.416427463293076,
359 0.408173263072968,
360 };
361 int32_t numDetectionsRef = 10;
362 auto detectionBoxesH = bindings.get(outputPH[0])->getHandle<float>();
363 auto detectionClassesH = bindings.get(outputPH[1])->getHandle<int32_t>();
364 auto detectionScoresH = bindings.get(outputPH[2])->getHandle<float>();
365 auto numDetectionsH = bindings.get(outputPH[3])->getHandle<int32_t>();
366 for (size_t idx = 0; idx < 4 * numDetectionsRef; ++idx) {
367 EXPECT_FLOAT_EQ(detectionBoxesH.raw(idx), detectionBoxesRef[idx]);
368 }
369 for (size_t idx = 0; idx < numDetectionsRef; ++idx) {
370 EXPECT_EQ(detectionClassesH.raw(idx), detectionClassesRef[idx]);
371 EXPECT_EQ(detectionScoresH.raw(idx), detectionScoresRef[idx]);
372 }
373 EXPECT_EQ(numDetectionsH.raw(0), numDetectionsRef);
374}
375
376/// Test Fast TFLiteDetectionPostProcess node.
377TEST(TFLiteImporterTest, TFLiteDetectionPostProcessFast) {
378 ExecutionEngine EE;
379 auto &mod = EE.getModule();
380 Function *F = mod.createFunction("main");
381
382 // Load TensorFlowLite model.
383 std::string modelPath =
384 getModelPath("tflite_detection_post_processing_fast.tflite");
385 { TFLiteModelLoader(modelPath, F); }
386
387 // Allocate tensors for all placeholders.
388 PlaceholderBindings bindings;
389 bindings.allocate(mod.getPlaceholders());
390
391 // Get model input/output placeholders.
392 std::vector<Placeholder *> inputPH;
393 std::vector<Placeholder *> outputPH;
394 for (const auto &ph : mod.getPlaceholders()) {
395 if (isInput(ph, *F)) {
396 inputPH.push_back(ph);
397 } else {
398 outputPH.push_back(ph);
399 }
400 }
401
402 // Load data into the input placeholders.
403 loadTensor(bindings.get(inputPH[0]),
404 getModelPath("tflite_detection_post_processing_boxes.bin"));
405 loadTensor(bindings.get(inputPH[1]),
406 getModelPath("tflite_detection_post_processing_scores.bin"));
407
408 // Run model.
409 EE.compile(CompilationMode::Infer);
410 EE.run(bindings);
411
412 // Compare output data versus reference.
413 std::vector<float> detectionBoxesRef = {
414 0.270546197891235, 0.036445915699005, 0.625426292419434,
415 0.715417265892029, 0.008843034505844, 0.453001916408539,
416 0.434335529804230, 1.007383584976196, 0.264277368783951,
417 0.225462928414345, 0.431514173746109, 0.499467015266418,
418 0.208248645067215, 0.414025753736496, 0.256930917501450,
419 0.457198470830917, 0.259306669235229, 0.276896983385086,
420 0.413792371749878, 0.558155655860901, 0.100624501705170,
421 0.478332787752151, 0.341053903102875, 0.616274893283844,
422 0.195692524313927, 0.446290910243988, 0.264245152473450,
423 0.527587413787842, 0.232087373733521, 0.244561776518822,
424 0.373351573944092, 0.512895405292511, 0.275883287191391,
425 0.037467807531357, 0.595628619194031, 0.463419944047928,
426 0.203831464052200, 0.354441434144974, 0.266103237867355,
427 0.427350491285324,
428 };
429 std::vector<int32_t> detectionClassesRef = {
430 2, 7, 2, 2, 2, 7, 2, 2, 2, 2,
431 };
432 std::vector<float> detectionScoresRef = {
433 0.709131240844727, 0.694569468498230, 0.563223838806152,
434 0.452089250087738, 0.439201682806015, 0.432144701480865,
435 0.416427463293076, 0.408173263072968, 0.405113369226456,
436 0.398936122655869,
437 };
438 int32_t numDetectionsRef = 10;
439 auto detectionBoxesH = bindings.get(outputPH[0])->getHandle<float>();
440 auto detectionClassesH = bindings.get(outputPH[1])->getHandle<int32_t>();
441 auto detectionScoresH = bindings.get(outputPH[2])->getHandle<float>();
442 auto numDetectionsH = bindings.get(outputPH[3])->getHandle<int32_t>();
443
444 for (size_t idx = 0; idx < 4 * numDetectionsRef; ++idx) {
445 EXPECT_FLOAT_EQ(detectionBoxesH.raw(idx), detectionBoxesRef[idx]);
446 }
447 for (size_t idx = 0; idx < numDetectionsRef; ++idx) {
448 EXPECT_EQ(detectionClassesH.raw(idx), detectionClassesRef[idx]);
449 EXPECT_EQ(detectionScoresH.raw(idx), detectionScoresRef[idx]);
450 }
451 EXPECT_EQ(numDetectionsH.raw(0), numDetectionsRef);
452}
453