1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "BackendTestUtils.h"
17
18#include "glow/Backend/BackendUtils.h"
19#include "glow/Backends/Interpreter/Interpreter.h"
20#include "glow/Base/TensorSerialization.h"
21#include "glow/ExecutionEngine/ExecutionEngine.h"
22#include "glow/Graph/Graph.h"
23#include "glow/Graph/PlaceholderBindings.h"
24#include "glow/IR/IRBuilder.h"
25#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
26#include "glow/Optimizer/IROptimizer/IROptimizer.h"
27
28#include "gtest/gtest.h"
29
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/Support/Casting.h"
32#include "llvm/Support/FileSystem.h"
33
34#include <future>
35
36using namespace glow;
37
/// An enum to indicate what type of placeholder it is. Used by the
/// ContiguousPlaceholder test below to check that placeholders are laid out
/// in memory in the order Input | InputOutput | Output | None.
enum class PlaceholderType {
  // Read by the network only.
  InputPlaceholder = 0,
  // Both read and written by the network.
  InputOutputPlaceholder = 1,
  // Written by the network only.
  OutputPlaceholder = 2,
  // Neither read nor written.
  NonePlaceholder = 3
};
45
/// Parameterized test fixture: creates an ExecutionEngine for the backend
/// whose name is supplied as the test parameter.
class BackendExecTest : public ::testing::TestWithParam<std::string> {
public:
  ExecutionEngine EE_{GetParam()};
};
50
/// Stateless fixture variant: the backend name comes from
/// BackendStatelessTest::getBackendName() rather than the test parameter.
class BackendExecStatelessTest : public BackendStatelessTest {
public:
  ExecutionEngine EE_{getBackendName()};
};
55
/// Check that profiling instrumentation captures the running min/max of the
/// values observed on a placeholder across multiple inference runs.
TEST(Interpreter, profileQuantizationForANetwork) {
  ExecutionEngine EE;
  PlaceholderBindings bindings;
  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");
  Tensor inputs(ElemKind::FloatTy, {1, 4});
  inputs.getHandle() = {1, 1.2f, 0.5f, 1.3f};

  auto *A = mod.createPlaceholder(ElemKind::FloatTy, {1, 4}, "A", false);
  auto *Ex = mod.createPlaceholder(ElemKind::FloatTy, {1, 4}, "E", false);
  Node *O = F->createFullyConnected(bindings, "fc", A, 4);
  O = F->createRELU("relu", O);
  O = F->createRegression("reg", O, Ex);
  F->createSave("ret", O);

  // Compile in profiling mode so QuantizationProfile nodes get inserted.
  LoweredInfoMap loweredMap;
  CompilationContext cctx{&bindings, &loweredMap};
  cctx.precisionConfig.quantMode = QuantizationMode::Profile;

  bindings.allocate(A);
  bindings.allocate(Ex);
  EE.compile(cctx);
  // Allocate any remaining placeholders created during compilation (e.g. the
  // profiling placeholders).
  bindings.allocate(mod.getPlaceholders());

  // TODO: Verify histogram itself, for now just verify min and max.
  // Run inference first time and capture tensor stats.
  updateInputPlaceholders(bindings, {A}, {&inputs});
  EE.run(bindings);
  // Because we are quantizing the partitioner deleted the original function and
  // created a new one, get the new function.
  F = mod.getFunctions().front();

  QuantizationProfileNode *profile{nullptr};
  // Find QPN for node A.
  for (auto &node : F->getNodes()) {
    if (QuantizationProfileNode *QPN =
            llvm::dyn_cast<QuantizationProfileNode>(&node)) {
      Node *observedNode = QPN->getInput().getNode();
      if (observedNode == A) {
        profile = QPN;
        break;
      }
    }
  }

  EXPECT_TRUE(profile != nullptr);

  // The computation-info tensor holds the observed stats: raw(0) is the min
  // and raw(1) is the max, as checked below.
  auto CI = bindings.get(profile->getComputationInfoPlaceholder())
                ->getHandle<float>();
  float min = CI.raw(0);
  float max = CI.raw(1);
  EXPECT_NEAR(0.5, min, 0.00001);
  EXPECT_NEAR(1.3, max, 0.00001);

  // Run inference for the second time with new min and max. The profile
  // accumulates across runs: min shrinks to 0.2 and max grows to 1.6.
  inputs.getHandle() = {0.2f, 1.6f, 0.5f, 1.3f};
  updateInputPlaceholders(bindings, {A}, {&inputs});
  EE.run(bindings);
  min = CI.raw(0);
  max = CI.raw(1);
  EXPECT_NEAR(0.2, min, 0.00001);
  EXPECT_NEAR(1.6, max, 0.00001);
}
119
120/// Creates an interpreter with a given \p name and custom instruction handler
121/// \p hook. \retruns a newly created custom interpreter.
122static Backend *createCustomInterpreter(llvm::StringRef name,
123 IRInstructionProcessingFn hook) {
124 auto interpreter = new Interpreter();
125 interpreter->setIRInstructionProcessingHandler(hook);
126 interpreter->setName(name);
127 return interpreter;
128}
129
130#ifdef GLOW_WITH_CPU
131
/// A couple of counters to check that custom processing has happened. They
/// are incremented from customInterpreterHook below and reset at the start of
/// each test that uses them.
static unsigned numCustomProcessedSupportedInstructions;
static unsigned numCustomProcessedUnsupportedInstructions;
135
/// An interceptor to be invoked when executing the interpreter instructions.
/// Returning true tells the interpreter to skip its standard processing of
/// the instruction; returning false falls back to the default behavior.
static IRInstructionProcessingFn customInterpreterHook =
    [](const Instruction *I, IRInstructionProcessingStage executionStage,
       void *ctx) -> bool {
  // Only handle instructions in the processing stage.
  if (executionStage != IRInstructionProcessingStage::PROCESSING) {
    return false;
  }
  llvm::outs() << "Intercept instruction execution: " << I << "\n";
  // This is an example of handling an instruction that is normally not
  // supported by a vanilla interpreter. This way new backends or tests can
  // extend the functionality of the interpreter and support custom
  // instructions.
  if (llvm::isa<CPUMaxSplatInst>(I)) {
    llvm::outs() << "Apply special processing for an instruction not supported "
                    "by the interpreter: "
                 << I << "\n";
    numCustomProcessedUnsupportedInstructions++;
    // Tell the backend to skip standard processing of this instruction.
    return true;
  }
  // This is an example of implementing a custom handling of an instruction that
  // is supported by a vanilla interpreter. This way new backends or tests can
  // change the behavior of the interpreter for specific instructions.
  if (llvm::isa<ElementSubInst>(I)) {
    llvm::outs() << "Apply special processing instruction: " << I << "\n";
    numCustomProcessedSupportedInstructions++;
    // Tell the backend to skip standard processing of this instruction.
    return true;
  }
  return false;
};
168
/// Check support for intercepting and customizing the processing of
/// instructions supported by the Interpreter.
TEST(Interpreter, customPrePostAroundProcessing) {
  // Register a custom backend.
  REGISTER_DYNAMIC_GLOW_BACKEND_FACTORY(
      CustomInterpreterFactory, Interpreter, "CustomInterpreter",
      createCustomInterpreter("CustomInterpreter", customInterpreterHook))

  ExecutionEngine EE("CustomInterpreter");
  auto &mod = EE.getModule();
  auto *F = mod.createFunction("test");
  auto *input1 =
      mod.createPlaceholder(ElemKind::FloatTy, {1, 10, 10, 3}, "in1", false);
  auto *input2 =
      mod.createPlaceholder(ElemKind::FloatTy, {1, 10, 10, 3}, "in2", false);
  // Build a graph with one Add and one Sub; customInterpreterHook intercepts
  // only the ElementSub instruction.
  auto *add = F->createAdd("add", input1, input2);
  auto *sub = F->createSub("sub", add, input1);
  F->createSave("save", sub);
  PlaceholderBindings bindings;
  bindings.allocate({input1, input2});
  EE.compile(CompilationMode::Infer);
  // Reset the counters the hook increments.
  numCustomProcessedSupportedInstructions = 0;
  numCustomProcessedUnsupportedInstructions = 0;
  // Process the function by means of the custom backend.
  EE.run(bindings);
  // Sub operation should have been processed in a custom way.
  EXPECT_EQ(numCustomProcessedSupportedInstructions, 1);
  EXPECT_EQ(numCustomProcessedUnsupportedInstructions, 0);
}
198
199TEST(Interpreter, customHandleUnsupportedInstruction) {
200 // Register a custom Interpreter-based backend.
201 REGISTER_DYNAMIC_GLOW_BACKEND_FACTORY(
202 CustomInterpreterFactory, Interpreter, "CustomInterpreter",
203 createCustomInterpreter("CustomInterpreter", customInterpreterHook))
204 // Create CPU and custom interpreter backends.
205 ExecutionEngine cpuEE("CPU");
206 ExecutionEngine customInterpreterEE("CustomInterpreter");
207 auto *customInterpreterBackend =
208 &customInterpreterEE.getBackend("CustomInterpreter");
209 auto *cpuBackend = &cpuEE.getBackend("CPU");
210 auto &mod = cpuEE.getModule();
211 auto *F = mod.createFunction("test");
212 auto *input1 =
213 mod.createPlaceholder(ElemKind::FloatTy, {1, 10, 10, 3}, "in1", false);
214 auto *splatTy = mod.uniqueType(ElemKind::FloatTy, {1, 10, 10, 3});
215 auto *splat = F->createSplat("splat", splatTy, 3);
216 auto *maxsplat = F->createMax("max", input1, splat);
217 auto *save = F->createSave("save", maxsplat);
218 std::unique_ptr<PlaceholderBindings> cpuBindings(new PlaceholderBindings);
219 cpuBindings->allocate({input1, save->getPlaceholder()});
220 std::unique_ptr<PlaceholderBindings> customInterpreterBindings(
221 new PlaceholderBindings);
222 customInterpreterBindings->allocate({input1, save->getPlaceholder()});
223 CompilationContext cctx;
224 cctx.compMode = CompilationMode::Infer;
225 FAIL_TEST_IF_ERR(glow::optimizeFunction(F, *cpuBackend, cctx));
226 // Generate the low-level IR for the CPU backend.
227 std::unique_ptr<IRFunction> cpuIR =
228 glow::generateAndOptimizeIR(F, *cpuBackend, false);
229 // Clone the low-level IR.
230 auto clonedCpuIR = cpuIR->clone("newTest");
231 // Compile the cloned IR for the custom Interpreter backend.
232 std::unique_ptr<IRFunction> customInterpreterIR(clonedCpuIR);
233 auto customInterpreterCompiledF(
234 reinterpret_cast<BackendUsingGlowIR *>(customInterpreterBackend)
235 ->compileIR(std::move(customInterpreterIR)));
236 auto cpuCompiledF(reinterpret_cast<BackendUsingGlowIR *>(cpuBackend)
237 ->compileIR(std::move(cpuIR)));
238 ExecutionContext cpuExecCtx(std::move(cpuBindings));
239 // Execute on the CPU backend.
240 FAIL_TEST_IF_ERR(cpuCompiledF->execute(&cpuExecCtx));
241 ExecutionContext customInterpreterExecCtx(
242 std::move(customInterpreterBindings));
243 numCustomProcessedUnsupportedInstructions = 0;
244 // Execute on the custom Interpreter backend. The usual Interpreter backend
245 // would not be able to handle some of the custom IR instructions defined by
246 // the CPU backend, but the custom interpreter backend can process them.
247 numCustomProcessedSupportedInstructions = 0;
248 numCustomProcessedUnsupportedInstructions = 0;
249 FAIL_TEST_IF_ERR(
250 customInterpreterCompiledF->execute(&customInterpreterExecCtx));
251 EXPECT_EQ(numCustomProcessedUnsupportedInstructions, 1);
252}
253
254#endif
255
/// An interceptor to be invoked when executing the interpreter instructions.
/// This is similar in spirit to customInterpreterHook.
/// Please remember, the user can choose to do what she wants to with funcImpl
/// -- they can compile and call it (like CUDA), or just invoke it, if it's a
/// handle to an external function. In this case, based on funcImpl being
/// "PLUS" or "MINUS", we add/subtract the inputs and write the result to the
/// output; any other funcImpl zeroes the output.
static IRInstructionProcessingFn externFnCallInterpreterHook =
    [](const Instruction *I, IRInstructionProcessingStage executionStage,
       void *ctx) -> bool {
  // Only handle instructions in the processing stage.
  if (executionStage != IRInstructionProcessingStage::PROCESSING) {
    return false;
  }

  if (llvm::isa<ExternalFunctionCallInst>(I)) {
    // ctx carries the BoundInterpreterFunction currently executing this
    // instruction; it provides access to the weight tensors by value.
    auto boundInterpFn = reinterpret_cast<BoundInterpreterFunction *>(ctx);
    auto EFCI = llvm::dyn_cast<ExternalFunctionCallInst>(I);
    auto funcImpl = EFCI->getFunctionImpl();

    // NOTE(review): assumes operands 1 and 2 are the two inputs (with the
    // destination at operand 0) -- confirm against the
    // ExternalFunctionCallInst operand layout.
    auto output = EFCI->getDest();
    auto input1 = EFCI->getOperand(1).first;
    auto input2 = EFCI->getOperand(2).first;
    auto out = boundInterpFn->getWeightHandle<float>(output);
    auto in1 = boundInterpFn->getWeightHandle<float>(input1);
    auto in2 = boundInterpFn->getWeightHandle<float>(input2);

    // In this simple test, we check the funcImpl of the instruction is PLUS
    // or MINUS. If so, we return the sum or difference of the inputs. If it's
    // anything else we zero the output. Note that this test shows a simple use
    // of the ExternalFunctionCallInst. The user based on their needs can
    // compile, compile and invoke, or invoke an external function. PLEASE NOTE
    // HERE WE COULD HAVE INVOKED AN EXTERNAL FUNCTION OR COMPILED AND RAN CODE.
    if (funcImpl == "PLUS") {
      for (dim_t i = 0, e = out.size(); i < e; i++) {
        out.raw(i) = in1.raw(i) + in2.raw(i);
      }
    } else if (funcImpl == "MINUS") {
      for (dim_t i = 0, e = out.size(); i < e; i++) {
        out.raw(i) = in1.raw(i) - in2.raw(i);
      }
    } else {
      // Only PLUS and MINUS are supported.
      for (dim_t i = 0, e = out.size(); i < e; i++) {
        out.raw(i) = 0.0;
      }
    }
    // Tell the backend to skip standard processing of this instruction.
    return true;
  }
  return false;
};
307
308TEST(Interpreter, ExternalFunctionCallTest) {
309 // Register a custom Interpreter-based backend with a hook for handling the
310 // ExternalFunctionCall instructions.
311 REGISTER_DYNAMIC_GLOW_BACKEND_FACTORY(
312 CustomInterpreterFactory, Interpreter, "CustomInterpreter",
313 createCustomInterpreter("CustomInterpreter", externFnCallInterpreterHook))
314 // Create a custom interpreter backend.
315 ExecutionEngine customInterpreterEE("CustomInterpreter");
316 auto *customInterpreterBackend =
317 &customInterpreterEE.getBackend("CustomInterpreter");
318 auto &mod = customInterpreterEE.getModule();
319 auto *F = mod.createFunction("test");
320
321 Tensor inputs(ElemKind::FloatTy, {10});
322 inputs.zero();
323
324 auto *input1 = mod.createPlaceholder(ElemKind::FloatTy, {10}, "in1", false);
325 auto *input2 = mod.createPlaceholder(ElemKind::FloatTy, {10}, "in2", false);
326
327 // For this test, we send in a toy external function. We call it plus_call.
328 // The functionImpl is just a string "PLUS". Based on this string being equal
329 // to "PLUS", we compute an add operation with the inputs and store it to the
330 // output. PLEASE NOTE: This test is just a toy example. The user can choose
331 // to do what she wants to with funcImpl -- they can compile and call it (like
332 // CUDA), or just invoke it, if it's a handle to an external function, and use
333 // the inputs and outputs as they see fit.
334
335 std::string fnName = "plus_call";
336 // This can be source code like OpenCL, CUDA, or a handle to a function.
337 std::string fnImplPlus = "PLUS";
338 std::string fnImplMinus = "MINUS";
339 std::string fnImplMul = "MUL";
340 std::string fnKind = "CUSTOM_OP";
341
342 auto *extFnCallPlus = F->createExternalFunctionCall(
343 "external_function_call", input1->getType(), {input1, input2}, fnName,
344 fnImplPlus, fnKind);
345 auto *extFnCallMinus = F->createExternalFunctionCall(
346 "external_function_call", input1->getType(), {input1, input2}, fnName,
347 fnImplMinus, fnKind);
348 auto *extFnCallMul = F->createExternalFunctionCall(
349 "external_function_call", input1->getType(), {input1, input2}, fnName,
350 fnImplMul, fnKind);
351 auto *savePlus = F->createSave("save", extFnCallPlus);
352 auto *saveMinus = F->createSave("save", extFnCallMinus);
353 auto *saveMul = F->createSave("save", extFnCallMul);
354
355 std::unique_ptr<PlaceholderBindings> customInterpreterBindings(
356 new PlaceholderBindings);
357 customInterpreterBindings->allocate(
358 {input1, input2, savePlus->getPlaceholder(), saveMinus->getPlaceholder(),
359 saveMul->getPlaceholder()});
360
361 // Now get the tensors and set their values.
362 auto inTensor1 = customInterpreterBindings->get(input1)->getHandle<float>();
363 auto inTensor2 = customInterpreterBindings->get(input2)->getHandle<float>();
364 for (dim_t i = 0, e = inTensor1.size(); i < e; i++) {
365 inTensor1.raw(i) = 5.0;
366 inTensor2.raw(i) = 4.0;
367 }
368
369 // Generate the IR for the custom backend.
370 std::unique_ptr<IRFunction> customInterpreterIR =
371 glow::generateAndOptimizeIR(F, *customInterpreterBackend, false);
372
373 auto customInterpreterCompiledF(
374 reinterpret_cast<BackendUsingGlowIR *>(customInterpreterBackend)
375 ->compileIR(std::move(customInterpreterIR)));
376
377 ExecutionContext customInterpreterExecCtx(
378 std::move(customInterpreterBindings));
379
380 FAIL_TEST_IF_ERR(
381 customInterpreterCompiledF->execute(&customInterpreterExecCtx));
382
383 // Get bindings, then get the input and output tensors.
384 auto *bindings = customInterpreterExecCtx.getPlaceholderBindings();
385 auto in1 = bindings->get(input1)->getHandle<float>();
386 auto in2 = bindings->get(input2)->getHandle<float>();
387 auto outputPlus =
388 bindings->get(savePlus->getPlaceholder())->getHandle<float>();
389 auto outputMinus =
390 bindings->get(saveMinus->getPlaceholder())->getHandle<float>();
391 auto outputMul = bindings->get(saveMul->getPlaceholder())->getHandle<float>();
392
393 // Verify the output tensors. Add and Minus should have been processed in the
394 // hook. Mul is not supported, and this ouptut should be zero'd.
395 for (dim_t i = 0, e = outputPlus.size(); i < e; i++) {
396 EXPECT_TRUE(outputPlus.raw(i) == in1.raw(i) + in2.raw(i));
397 EXPECT_TRUE(outputMinus.raw(i) == in1.raw(i) - in2.raw(i));
398 EXPECT_TRUE(outputMul.raw(i) == 0.0);
399 }
400}
401
402/// Check that new backends and backend factories can be registered dynamically.
403TEST(Interpreter, DynamicBackendFactory) {
404 // Use a static variable here, because the macro invocation below creates a
405 // new class and C++ does not allow for capturing of local variables.
406 static std::string backendName;
407 for (unsigned i = 0; i < 16; ++i) {
408 {
409 backendName = "CustomInterpreter" + std::to_string(i);
410 // Dynamically create a new backend factory and register it.
411 REGISTER_DYNAMIC_GLOW_BACKEND_FACTORY(CustomInterpreterFactory,
412 Interpreter, backendName,
413 []() -> Backend * {
414 // Dynamically create a backend
415 // and give it a name.
416 auto *backend = new Interpreter;
417 backend->setName(backendName);
418 return backend;
419 }())
420 ExecutionEngine EE(backendName);
421 auto *backend = &EE.getBackend(backendName);
422 ASSERT_NE(backend, nullptr);
423 // Check that a new backend is registered.
424 auto backends = getAvailableBackends();
425 EXPECT_NE(std::find(backends.begin(), backends.end(), backendName),
426 backends.end());
427 // The new backend factory will be destroyed at the end of this scope.
428 }
429 // Check that a new backend is not registered anymore after its factory was
430 // destroyed.
431 auto backends = getAvailableBackends();
432 EXPECT_EQ(std::find(backends.begin(), backends.end(), backendName),
433 backends.end());
434 }
435}
436
437/// Test that the symbol category for a symbol is properly set.
438TEST(RuntimeBundle, BundleSymbolInfo) {
439
440 ExecutionEngine EE;
441 auto &mod = EE.getModule();
442 PlaceholderBindings bindings;
443
444 Tensor inputs(ElemKind::FloatTy, {1, 10, 10, 3});
445 inputs.getHandle().randomize(-2, 2, mod.getPRNG());
446
447 // Create a simple graph that has placeholders, constants, activations, and a
448 // tensor_view.
449 Function *F = mod.createFunction("main");
450 auto *input =
451 mod.createPlaceholder(ElemKind::FloatTy, {1, 10, 10, 3}, "in", false);
452
453 auto *ex = mod.createConstant(ElemKind::Int64ITy, {1, 1}, "exp");
454
455 auto *FC = F->createFullyConnected(bindings, "FC", input, 30);
456 auto *RL = F->createRELU("RL2", FC);
457 auto *SM = F->createSoftMax("sm", RL, ex);
458 auto *S = F->createSave("ret", SM);
459 auto *qp = F->createQuantizationProfile(bindings, "qp", input);
460
461 EE.compile(CompilationMode::Infer);
462 runtime::DAG *dag;
463 ASSIGN_VALUE_OR_FAIL_TEST(dag, EE.getDAG("main"));
464 assert(dag->nodes.size() > 0 && "Empty DAG list");
465 auto table = dag->nodes[0]->runtimeBundle->getSymbolTable();
466
467 // Check that placeholders and constants are correctly labelled.
468 EXPECT_EQ(
469 table.find(S->getPlaceholder()->getName().str())->second.symbolCategory,
470 glow::runtime::SymbolCategory::Placeholder);
471 EXPECT_EQ(table.find(ex->getName().str())->second.symbolCategory,
472 glow::runtime::SymbolCategory::Constant);
473 // Check that activations are labelled correctly.
474 EXPECT_EQ(table.find("FC_res")->second.symbolCategory,
475 glow::runtime::SymbolCategory::Activation);
476 // Check that tensor views have the same label as their parent symbol. In this
477 // case same as "input".
478 EXPECT_EQ(table.find("FC_reshape2D_tensorview")->second.symbolCategory,
479 glow::runtime::SymbolCategory::PlaceholderTensorView);
480
481 // Check that placeholders and constants input/output flags are correctly set.
482 EXPECT_EQ(table.find(S->getPlaceholder()->getName().str())->second.input,
483 false);
484 EXPECT_EQ(table.find(S->getPlaceholder()->getName().str())->second.output,
485 true);
486 EXPECT_EQ(table.find(ex->getName().str())->second.input, false);
487 EXPECT_EQ(table.find(ex->getName().str())->second.output, false);
488 EXPECT_EQ(table.find(input->getName().str())->second.input, true);
489 EXPECT_EQ(table.find(input->getName().str())->second.output, false);
490 // HistogramPlaceholder node is not an input node, it is an output node.
491 EXPECT_EQ(
492 table.find(qp->getHistogramPlaceholder()->getName().str())->second.input,
493 false);
494 EXPECT_EQ(
495 table.find(qp->getHistogramPlaceholder()->getName().str())->second.output,
496 true);
497 // Check that activations are labelled correctly.
498 EXPECT_EQ(table.find("FC_res")->second.input, false);
499 EXPECT_EQ(table.find("FC_res")->second.output, false);
500 // Check that tensor views are labelled correctly.
501 EXPECT_EQ(table.find("FC_reshape2D_tensorview")->second.input, false);
502 EXPECT_EQ(table.find("FC_reshape2D_tensorview")->second.output, false);
503}
504
// Test that using a buffer in a TensorView instruction doesn't get it marked
// as an input buffer.
TEST(IR, testInputToTensorView) {
  Module mod;
  Function *F = mod.createFunction("main");
  IRFunction M(F);
  IRBuilder builder(&M);
  auto T0 = mod.uniqueType(ElemKind::FloatTy, {1024, 1024});
  auto T1 = mod.uniqueType(ElemKind::FloatTy, {512, 1024});
  auto *input0 = builder.createWeightVar(T1, "A");
  auto *input1 = builder.createWeightVar(T1, "B");
  auto *output = builder.createWeightVar(T0, "C");
  // Two views over "output"; the adds below only write through these views.
  // NOTE(review): "outuput" is a pre-existing misspelling in the view names
  // (testIsInput below spells them "output_view*"); the name has no effect on
  // the isInput check.
  auto *tvo0 =
      builder.createTensorViewInst("outuput_view0", output, T0, {0, 0});
  auto *tvo1 =
      builder.createTensorViewInst("outuput_view1", output, T0, {512, 0});
  builder.createElementAddInst("add0", tvo0, input0, input1);
  builder.createElementAddInst("add1", tvo1, input0, input1);
  // "output" is only used as an input to a TensorView instruction. The buffer
  // should not be marked as an input buffer since that doesn't include any
  // reads of the buffer.
  EXPECT_EQ(isInput(output), false);
}
528
// Test the correctness of isInput: a buffer counts as an input only if the
// first real access to it (possibly through a tensor view) is a read.
TEST(IR, testIsInput) {
  Module mod;
  Function *F = mod.createFunction("main");
  IRFunction M(F);
  IRBuilder builder(&M);
  auto T0 = mod.uniqueType(ElemKind::FloatTy, {1024, 1024});
  auto T1 = mod.uniqueType(ElemKind::FloatTy, {512, 1024});
  auto *input0 = builder.createWeightVar(T1, "A");
  auto *input1 = builder.createWeightVar(T1, "B");
  auto *output0 = builder.createWeightVar(T0, "C0");
  auto *output1 = builder.createWeightVar(T0, "C1");
  auto *tvo0 =
      builder.createTensorViewInst("output_view0", output0, T0, {0, 0});
  auto *tvo1 =
      builder.createTensorViewInst("output_view1", output1, T0, {512, 0});
  // tvo0 is used as src and dest in this instruction. This is a first operation
  // using output0 and it first reads from it. Thus output0 should be reported
  // as input.
  builder.createElementAddInst("add0", tvo0, tvo0, input1);
  // Write into tvo1. This is the first operation touching output1 and it is a
  // write.
  builder.createElementAddInst("add1", tvo1, input0, input1);
  // Read from tvo1 and then write into it.
  builder.createElementAddInst("add2", tvo1, tvo1, input1);
  // output0 is used as an input to a TensorView instruction tvo0, which doesn't
  // count. But then tvo0 is used as an input and output for the same add
  // instruction. Thus, it is an input.
  EXPECT_EQ(isInput(output0), true);
  // output1 was first written into and then read. Therefore it is not an input.
  EXPECT_EQ(isInput(output1), false);
}
561
562// Test if the placeholders are allocated contiguously as
563// Input|InputOutput|Output.
564TEST(RuntimeBundle, ContiguousPlaceholder) {
565 ExecutionEngine EE;
566 PlaceholderBindings bindings;
567 auto &mod = EE.getModule();
568 Function *F = mod.createFunction("main");
569 Tensor inputs(ElemKind::FloatTy, {1, 4});
570 inputs.getHandle() = {1, 1.2f, 0.5f, 1.3f};
571
572 auto *A = mod.createPlaceholder(ElemKind::FloatTy, {1, 4}, "A", false);
573 auto *B = mod.createPlaceholder(ElemKind::FloatTy, {1, 4}, "B", false);
574 auto *Ex = mod.createPlaceholder(ElemKind::FloatTy, {1, 4}, "E", false);
575 auto *add = F->createAdd("add", A, Ex);
576 auto *sub = F->createSub("sub", B, add);
577 F->createSave("ret", sub);
578
579 LoweredInfoMap loweredMap;
580 CompilationContext cctx{&bindings, &loweredMap};
581 cctx.precisionConfig.quantMode = QuantizationMode::Profile;
582
583 bindings.allocate(A);
584 bindings.allocate(Ex);
585 EE.compile(cctx);
586 runtime::DAG *dag;
587 ASSIGN_VALUE_OR_FAIL_TEST(dag, EE.getDAG("main"));
588 auto &table = dag->nodes[0]->runtimeBundle->getSymbolTable();
589
590 std::vector<glow::runtime::RuntimeSymbolInfo> tableContainer;
591 // Only check placeholders.
592 for (auto v : table) {
593 if (v.second.symbolCategory == glow::runtime::SymbolCategory::Placeholder) {
594 tableContainer.push_back(v.second);
595 }
596 }
597 // Sort the placeholders by offset.
598 sort(tableContainer.begin(), tableContainer.end(),
599 [](const glow::runtime::RuntimeSymbolInfo &a,
600 const glow::runtime::RuntimeSymbolInfo &b) {
601 return (a.offset < b.offset);
602 });
603
604 // Define the order of placeholders.
605 auto order = [](glow::runtime::RuntimeSymbolInfo i) -> PlaceholderType {
606 if (i.input) {
607 if (!i.output) {
608 // input only
609 return PlaceholderType::InputPlaceholder;
610 } else {
611 // input & output
612 return PlaceholderType::InputOutputPlaceholder;
613 }
614 } else {
615 if (i.output) {
616 // output only
617 return PlaceholderType::OutputPlaceholder;
618 } else {
619 // neither
620 return PlaceholderType::NonePlaceholder;
621 }
622 }
623 };
624 // The order function of placeholders should be increasing.
625 PlaceholderType prev = PlaceholderType::InputPlaceholder;
626 bool flag = true;
627 for (auto v : tableContainer) {
628 PlaceholderType tmp = order(v);
629 if (tmp > prev) {
630 prev = tmp;
631 } else if (tmp < prev) {
632 flag = false;
633 break;
634 }
635 }
636
637 EXPECT_EQ(flag, true);
638}
639
/// Run a small conv+save network end-to-end on the parameterized backend to
/// check basic compilation and inference.
TEST_P(BackendExecTest, simpleInference) {
  PlaceholderBindings bindings;

  auto &mod = EE_.getModule();
  Function *F = mod.createFunction("main");
  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {1, 10, 10, 3}, "in", false);
  auto *conv = F->createConv(bindings, "conv", input, 10, 5, 1, 0, 1);
  auto *res = F->createSave("save", conv);

  // Turn the conv parameters into constants; only the network input and the
  // save placeholder are kept as placeholders.
  ::glow::convertPlaceholdersToConstants(F, bindings,
                                         {input, res->getPlaceholder()});
  bindings.allocate(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
  bindings.allocate(res->getPlaceholder());

  EE_.compile(CompilationMode::Infer);
  EE_.run(bindings);
}
658
/// Utility function to create a simple network in which a tensor \p tensor is
/// dumped using the debug instrumentation mechanism using the given \p format
/// and filename \p filename. Note that the backend being tested must inherit
/// from BackendUsingGlowIR and implement the compileIR() function for this test
/// to work.
static void runDebugPrint(ExecutionEngine &EE, std::string backendName,
                          Tensor &tensor, std::string format,
                          std::string filename) {
  auto &mod = EE.getModule();
  auto ctx = glow::make_unique<ExecutionContext>();
  Function *F = mod.createFunction("main");
  auto *IV = mod.createPlaceholder(&tensor.getType(), "tensor", false);
  auto *IVTensor = ctx->getPlaceholderBindings()->allocate(IV);
  IVTensor->assign(&tensor);
  auto *save = F->createSave("save", IV);
  ctx->getPlaceholderBindings()->allocate(save->getPlaceholder());

  // Generate low-level IR by hand and attach a DebugPrint instruction to the
  // first weight, so the tensor gets dumped when the function runs.
  std::unique_ptr<Backend> backend(createBackend(backendName));
  auto IR = glow::make_unique<IRFunction>(F);
  IR->generateIR(*backend.get());
  IRBuilder(IR.get()).createDebugPrintInst("print", *IR->getWeights().begin(),
                                           format, filename);

  auto function = reinterpret_cast<BackendUsingGlowIR *>(backend.get())
                      ->compileIR(std::move(IR));

  // Since we are compiling IR by hand we cannot go through the normal EE route.
  // Create and initialize the device.
  auto config =
      glow::make_unique<runtime::DeviceConfig>(backend->getBackendName());
  std::unique_ptr<runtime::DeviceManager> device(
      runtime::DeviceManager::createDeviceManager(*config));
  EXIT_ON_ERR(device->init());
  // Load the function on the device.
  std::string name = "main";
  runtime::FunctionMapTy functionMap;
  functionMap[name] = function.get();

  // addNetwork is asynchronous; block on a promise until its callback fires
  // and propagate any error it reports.
  std::promise<void> addPromise;
  auto fut = addPromise.get_future();
  Error addErr = Error::empty();
  device->addNetwork(&EE.getModule(), std::move(functionMap),
                     [&addPromise, &addErr](const Module *, Error err) {
                       addErr = std::move(err);
                       addPromise.set_value();
                     });
  fut.wait();
  EXIT_ON_ERR(std::move(addErr));
  // Run the function, again blocking until the asynchronous callback
  // completes; the callback hands the execution context back to us.
  std::promise<void> runPromise;
  fut = runPromise.get_future();
  Error runErr = Error::empty();
  device->runFunction(name, std::move(ctx),
                      [&runPromise, &runErr,
                       &ctx](runtime::RunIdentifierTy, Error err,
                             std::unique_ptr<ExecutionContext> contextPtr) {
                        ctx = std::move(contextPtr);
                        runErr = std::move(err);
                        runPromise.set_value();
                      });
  fut.wait();
  EXIT_ON_ERR(std::move(runErr));
}
722
723/// Utility function to test the debug instrumentation mechanism for a tensor
724/// \p tensorRef using the given \p format.
725template <typename type>
726static void testDebugPrint(ExecutionEngine &EE, std::string backendName,
727 Tensor &tensorRef, std::string format) {
728 // Create temporary file.
729 llvm::SmallString<64> path;
730 auto tempFileRes = llvm::sys::fs::createTemporaryFile("tensor", ".dat", path);
731 if (tempFileRes.value() != 0) {
732 FAIL() << "Failed to create temp file to write into.";
733 }
734 // Run debug print.
735 runDebugPrint(EE, backendName, tensorRef, format, path.str().str());
736 // Read tensor back.
737 Tensor tensorTest;
738 if (format == "bin") {
739 TensorSerializationOptions opts;
740 opts.withType = true;
741 glow::loadTensorFromBinaryFile(tensorTest, path.str(), opts);
742 } else if (format == "txt") {
743 TensorSerializationOptions opts;
744 opts.withType = true;
745 glow::loadTensorFromTextFile(tensorTest, path.str(), opts);
746 } else if (format == "rawbin") {
747 TensorSerializationOptions opts;
748 opts.withType = false;
749 tensorTest = Tensor(tensorRef.getType());
750 glow::loadTensorFromBinaryFile(tensorTest, path.str(), opts);
751 } else if (format == "rawtxt") {
752 TensorSerializationOptions opts;
753 opts.withType = false;
754 tensorTest = Tensor(tensorRef.getType());
755 glow::loadTensorFromTextFile(tensorTest, path.str(), opts);
756 } else {
757 FAIL() << "Invalid DebugPrint format!";
758 }
759 // Remove temporary file.
760 llvm::sys::fs::remove(path);
761 // Compare the two tensors.
762 EXPECT_EQ(tensorRef.getType(), tensorTest.getType());
763 auto handleRef = tensorRef.getHandle<type>();
764 auto handleTest = tensorTest.getHandle<type>();
765 EXPECT_EQ(handleRef.size(), handleTest.size());
766 EXPECT_EQ(handleRef.actualSize(), handleTest.actualSize());
767 for (size_t idx = 0; idx < tensorTest.actualSize(); idx++) {
768 EXPECT_EQ(handleTest.raw(idx), handleRef.raw(idx));
769 }
770}
771
/// Test dumping a tensor to the console ("console" format, empty filename).
TEST_P(BackendExecStatelessTest, DebugPrint_Console) {
  ENABLED_BACKENDS("CPU", "Interpreter");
  Tensor tensorRef(ElemKind::FloatTy, {4});
  tensorRef.getHandle<float>() = {1, 2, 3, 4};
  runDebugPrint(EE_, getBackendName(), tensorRef, "console", "");
}
779
780/// Test dumping to file in binary format.
781TEST_P(BackendExecStatelessTest, DebugPrint_Bin_FloatTy) {
782 ENABLED_BACKENDS("CPU", "Interpreter");
783 Tensor tensorRef(ElemKind::FloatTy, {4});
784 tensorRef.getHandle<float>() = {1, 2, 3, 4};
785 testDebugPrint<float>(EE_, getBackendName(), tensorRef, "bin");
786}
787
788TEST_P(BackendExecStatelessTest, DebugPrint_Bin_Int8QTy) {
789 ENABLED_BACKENDS("CPU", "Interpreter");
790 Tensor tensorRef(ElemKind::Int8QTy, {4}, 1.0, 0);
791 tensorRef.getHandle<int8_t>() = {1, 2, 3, 4};
792 testDebugPrint<int8_t>(EE_, getBackendName(), tensorRef, "bin");
793}
794
795TEST_P(BackendExecStatelessTest, DebugPrint_Bin_Int16QTy) {
796 ENABLED_BACKENDS("CPU", "Interpreter");
797 Tensor tensorRef(ElemKind::Int16QTy, {4}, 1.0, 0);
798 tensorRef.getHandle<int16_t>() = {1, 2, 3, 4};
799 testDebugPrint<int16_t>(EE_, getBackendName(), tensorRef, "bin");
800}
801
802TEST_P(BackendExecStatelessTest, DebugPrint_Bin_Int32QTy) {
803 ENABLED_BACKENDS("CPU", "Interpreter");
804 Tensor tensorRef(ElemKind::Int32QTy, {4}, 1.0, 0);
805 tensorRef.getHandle<int32_t>() = {1, 2, 3, 4};
806 testDebugPrint<int32_t>(EE_, getBackendName(), tensorRef, "bin");
807}
808
809TEST_P(BackendExecStatelessTest, DebugPrint_Bin_Int32ITy) {
810 ENABLED_BACKENDS("CPU", "Interpreter");
811 Tensor tensorRef(ElemKind::Int32ITy, {4});
812 tensorRef.getHandle<int32_t>() = {1, 2, 3, 4};
813 testDebugPrint<int32_t>(EE_, getBackendName(), tensorRef, "bin");
814}
815
816TEST_P(BackendExecStatelessTest, DebugPrint_Bin_Int64ITy) {
817 ENABLED_BACKENDS("CPU", "Interpreter");
818 Tensor tensorRef(ElemKind::Int64ITy, {4});
819 tensorRef.getHandle<int64_t>() = {1, 2, 3, 4};
820 testDebugPrint<int64_t>(EE_, getBackendName(), tensorRef, "bin");
821}
822
823TEST_P(BackendExecStatelessTest, DebugPrint_Bin_BoolTy) {
824 ENABLED_BACKENDS("CPU", "Interpreter");
825 Tensor tensorRef(ElemKind::BoolTy, {4});
826 tensorRef.getHandle<bool>() = {0, 1, 0, 1};
827 testDebugPrint<bool>(EE_, getBackendName(), tensorRef, "bin");
828}
829
830/// Test dumping to file in text format.
831TEST_P(BackendExecStatelessTest, DebugPrint_Txt_FloatTy) {
832 ENABLED_BACKENDS("CPU", "Interpreter");
833 Tensor tensorRef(ElemKind::FloatTy, {4});
834 tensorRef.getHandle<float>() = {1, 2, 3, 4};
835 testDebugPrint<float>(EE_, getBackendName(), tensorRef, "txt");
836}
837
838TEST_P(BackendExecStatelessTest, DebugPrint_Txt_Int8QTy) {
839 ENABLED_BACKENDS("CPU", "Interpreter");
840 Tensor tensorRef(ElemKind::Int8QTy, {4}, 1.0, 0);
841 tensorRef.getHandle<int8_t>() = {1, 2, 3, 4};
842 testDebugPrint<int8_t>(EE_, getBackendName(), tensorRef, "txt");
843}
844
845TEST_P(BackendExecStatelessTest, DebugPrint_Txt_Int16QTy) {
846 ENABLED_BACKENDS("CPU", "Interpreter");
847 Tensor tensorRef(ElemKind::Int16QTy, {4}, 1.0, 0);
848 tensorRef.getHandle<int16_t>() = {1, 2, 3, 4};
849 testDebugPrint<int16_t>(EE_, getBackendName(), tensorRef, "txt");
850}
851
852TEST_P(BackendExecStatelessTest, DebugPrint_Txt_Int32QTy) {
853 ENABLED_BACKENDS("CPU", "Interpreter");
854 Tensor tensorRef(ElemKind::Int32QTy, {4}, 1.0, 0);
855 tensorRef.getHandle<int32_t>() = {1, 2, 3, 4};
856 testDebugPrint<int32_t>(EE_, getBackendName(), tensorRef, "txt");
857}
858
859TEST_P(BackendExecStatelessTest, DebugPrint_Txt_Int32ITy) {
860 ENABLED_BACKENDS("CPU", "Interpreter");
861 Tensor tensorRef(ElemKind::Int32ITy, {4});
862 tensorRef.getHandle<int32_t>() = {1, 2, 3, 4};
863 testDebugPrint<int32_t>(EE_, getBackendName(), tensorRef, "txt");
864}
865
866TEST_P(BackendExecStatelessTest, DebugPrint_Txt_Int64ITy) {
867 ENABLED_BACKENDS("CPU", "Interpreter");
868 Tensor tensorRef(ElemKind::Int64ITy, {4});
869 tensorRef.getHandle<int64_t>() = {1, 2, 3, 4};
870 testDebugPrint<int64_t>(EE_, getBackendName(), tensorRef, "txt");
871}
872
873TEST_P(BackendExecStatelessTest, DebugPrint_Txt_BoolTy) {
874 ENABLED_BACKENDS("CPU", "Interpreter");
875 Tensor tensorRef(ElemKind::BoolTy, {4});
876 tensorRef.getHandle<bool>() = {0, 1, 0, 1};
877 testDebugPrint<bool>(EE_, getBackendName(), tensorRef, "txt");
878}
879
880/// Test dumping to file in raw binary format.
881TEST_P(BackendExecStatelessTest, DebugPrint_RawBin_FloatTy) {
882 ENABLED_BACKENDS("CPU", "Interpreter");
883 Tensor tensorRef(ElemKind::FloatTy, {4});
884 tensorRef.getHandle<float>() = {1, 2, 3, 4};
885 testDebugPrint<float>(EE_, getBackendName(), tensorRef, "rawbin");
886}
887
888TEST_P(BackendExecStatelessTest, DebugPrint_RawBin_Int8QTy) {
889 ENABLED_BACKENDS("CPU", "Interpreter");
890 Tensor tensorRef(ElemKind::Int8QTy, {4}, 1.0, 0);
891 tensorRef.getHandle<int8_t>() = {1, 2, 3, 4};
892 testDebugPrint<int8_t>(EE_, getBackendName(), tensorRef, "rawbin");
893}
894
895TEST_P(BackendExecStatelessTest, DebugPrint_RawBin_Int16QTy) {
896 ENABLED_BACKENDS("CPU", "Interpreter");
897 Tensor tensorRef(ElemKind::Int16QTy, {4}, 1.0, 0);
898 tensorRef.getHandle<int16_t>() = {1, 2, 3, 4};
899 testDebugPrint<int16_t>(EE_, getBackendName(), tensorRef, "rawbin");
900}
901
902TEST_P(BackendExecStatelessTest, DebugPrint_RawBin_Int32QTy) {
903 ENABLED_BACKENDS("CPU", "Interpreter");
904 Tensor tensorRef(ElemKind::Int32QTy, {4}, 1.0, 0);
905 tensorRef.getHandle<int32_t>() = {1, 2, 3, 4};
906 testDebugPrint<int32_t>(EE_, getBackendName(), tensorRef, "rawbin");
907}
908
909TEST_P(BackendExecStatelessTest, DebugPrint_RawBin_Int32ITy) {
910 ENABLED_BACKENDS("CPU", "Interpreter");
911 Tensor tensorRef(ElemKind::Int32ITy, {4});
912 tensorRef.getHandle<int32_t>() = {1, 2, 3, 4};
913 testDebugPrint<int32_t>(EE_, getBackendName(), tensorRef, "rawbin");
914}
915
916TEST_P(BackendExecStatelessTest, DebugPrint_RawBin_Int64ITy) {
917 ENABLED_BACKENDS("CPU", "Interpreter");
918 Tensor tensorRef(ElemKind::Int64ITy, {4});
919 tensorRef.getHandle<int64_t>() = {1, 2, 3, 4};
920 testDebugPrint<int64_t>(EE_, getBackendName(), tensorRef, "rawbin");
921}
922
923TEST_P(BackendExecStatelessTest, DebugPrint_RawBin_BoolTy) {
924 ENABLED_BACKENDS("CPU", "Interpreter");
925 Tensor tensorRef(ElemKind::BoolTy, {4});
926 tensorRef.getHandle<bool>() = {0, 1, 0, 1};
927 testDebugPrint<bool>(EE_, getBackendName(), tensorRef, "rawbin");
928}
929
930/// Test dumping to file in raw text format.
931TEST_P(BackendExecStatelessTest, DebugPrint_RawTxt_FloatTy) {
932 ENABLED_BACKENDS("CPU", "Interpreter");
933 Tensor tensorRef(ElemKind::FloatTy, {4});
934 tensorRef.getHandle<float>() = {1, 2, 3, 4};
935 testDebugPrint<float>(EE_, getBackendName(), tensorRef, "rawtxt");
936}
937
938TEST_P(BackendExecStatelessTest, DebugPrint_RawTxt_Int8QTy) {
939 ENABLED_BACKENDS("CPU", "Interpreter");
940 Tensor tensorRef(ElemKind::Int8QTy, {4}, 1.0, 0);
941 tensorRef.getHandle<int8_t>() = {1, 2, 3, 4};
942 testDebugPrint<int8_t>(EE_, getBackendName(), tensorRef, "rawtxt");
943}
944
945TEST_P(BackendExecStatelessTest, DebugPrint_RawTxt_Int16QTy) {
946 ENABLED_BACKENDS("CPU", "Interpreter");
947 Tensor tensorRef(ElemKind::Int16QTy, {4}, 1.0, 0);
948 tensorRef.getHandle<int16_t>() = {1, 2, 3, 4};
949 testDebugPrint<int16_t>(EE_, getBackendName(), tensorRef, "rawtxt");
950}
951
952TEST_P(BackendExecStatelessTest, DebugPrint_RawTxt_Int32QTy) {
953 ENABLED_BACKENDS("CPU", "Interpreter");
954 Tensor tensorRef(ElemKind::Int32QTy, {4}, 1.0, 0);
955 tensorRef.getHandle<int32_t>() = {1, 2, 3, 4};
956 testDebugPrint<int32_t>(EE_, getBackendName(), tensorRef, "rawtxt");
957}
958
959TEST_P(BackendExecStatelessTest, DebugPrint_RawTxt_Int32ITy) {
960 ENABLED_BACKENDS("CPU", "Interpreter");
961 Tensor tensorRef(ElemKind::Int32ITy, {4});
962 tensorRef.getHandle<int32_t>() = {1, 2, 3, 4};
963 testDebugPrint<int32_t>(EE_, getBackendName(), tensorRef, "rawtxt");
964}
965
966TEST_P(BackendExecStatelessTest, DebugPrint_RawTxt_Int64ITy) {
967 ENABLED_BACKENDS("CPU", "Interpreter");
968 Tensor tensorRef(ElemKind::Int64ITy, {4});
969 tensorRef.getHandle<int64_t>() = {1, 2, 3, 4};
970 testDebugPrint<int64_t>(EE_, getBackendName(), tensorRef, "rawtxt");
971}
972
973TEST_P(BackendExecStatelessTest, DebugPrint_RawTxt_BoolTy) {
974 ENABLED_BACKENDS("CPU", "Interpreter");
975 Tensor tensorRef(ElemKind::BoolTy, {4});
976 tensorRef.getHandle<bool>() = {0, 1, 0, 1};
977 testDebugPrint<bool>(EE_, getBackendName(), tensorRef, "rawtxt");
978}
979
980/// Test the compile method on the backend completes without error when
981/// collectConstants is false.
982TEST_P(BackendExecTest, CompileWithoutConstants) {
983 Module mod;
984 PlaceholderBindings bindings;
985 Function *F = mod.createFunction("main");
986 auto *X = mod.createPlaceholder(ElemKind::FloatTy, {3}, "X", false);
987 auto *XTensor = bindings.allocate(X);
988 XTensor->getHandle() = {1., 2., 3.};
989 auto *pow = F->createPow("Pow1", X, 2.0);
990 auto *save = F->createSave("save", pow);
991 bindings.allocate(save->getPlaceholder());
992 std::unique_ptr<Backend> backend(createBackend(GetParam()));
993 BackendOptions opts;
994 opts.collectConstants = false;
995 auto function = EXIT_ON_ERR(backend->compile(F, opts));
996}
997
998/// Test that the runtimeBundle includes only symbols from its function and not
999/// the whole module.
1000TEST_P(BackendExecTest, BundleFunctionSymbolsOnly) {
1001 Module mod;
1002 PlaceholderBindings bindings;
1003 Function *F = mod.createFunction("main");
1004 auto *X = mod.createConstant(ElemKind::FloatTy, {3}, "X");
1005 X->getHandle() = {1., 2., 3.};
1006 auto *pow = F->createPow("Pow1", X, 2.0);
1007 auto *save = F->createSave("save", pow);
1008 bindings.allocate(save->getPlaceholder());
1009 PlaceholderBindings bindings2;
1010 Function *F2 = mod.createFunction("main2");
1011 auto *X2 = mod.createConstant(ElemKind::FloatTy, {3}, "X2");
1012 X2->getHandle() = {1., 2., 3.};
1013 auto *pow2 = F2->createPow("Pow2", X2, 2.0);
1014 auto *save2 = F2->createSave("save2", pow2);
1015 bindings2.allocate(save2->getPlaceholder());
1016
1017 std::unique_ptr<Backend> backend(createBackend(GetParam()));
1018 auto function = EXIT_ON_ERR(backend->compile(F));
1019 auto function2 = EXIT_ON_ERR(backend->compile(F2));
1020 auto table1 = function->getRuntimeBundle().getSymbolTable();
1021 auto table2 = function2->getRuntimeBundle().getSymbolTable();
1022 /// Make sure no symbol in table1 is in table2.
1023 for (auto sym : table1) {
1024 auto it = table2.find(sym.first);
1025 EXPECT_TRUE(it == table2.end());
1026 }
1027}
1028
1029/// Test that a shared constant is in the bundle of both functions.
1030TEST_P(BackendExecTest, BundleSharedConstant) {
1031 Module mod;
1032 PlaceholderBindings bindings;
1033 Function *F = mod.createFunction("main");
1034 auto *X = mod.createConstant(ElemKind::FloatTy, {3}, "X");
1035 X->getHandle() = {1., 2., 3.};
1036 auto *pow = F->createPow("Pow1", X, 2.0);
1037 auto *save = F->createSave("save", pow);
1038 bindings.allocate(save->getPlaceholder());
1039 PlaceholderBindings bindings2;
1040 Function *F2 = mod.createFunction("main2");
1041 auto *pow2 = F2->createPow("Pow2", X, 2.0);
1042 auto *save2 = F2->createSave("save2", pow2);
1043 bindings2.allocate(save2->getPlaceholder());
1044
1045 std::unique_ptr<Backend> backend(createBackend(GetParam()));
1046 auto function = EXIT_ON_ERR(backend->compile(F));
1047 auto function2 = EXIT_ON_ERR(backend->compile(F2));
1048 auto table1 = function->getRuntimeBundle().getSymbolTable();
1049 auto table2 = function2->getRuntimeBundle().getSymbolTable();
1050 /// Make sure X is in both tables.
1051 auto it = table1.find(X->getName().str());
1052 auto it2 = table2.find(X->getName().str());
1053 EXPECT_TRUE(it != table1.end());
1054 EXPECT_TRUE(it2 != table2.end());
1055}
1056
1057/// Test compiling a vector of functions completes without error.
1058TEST_P(BackendExecTest, compileVectorOfFunctions) {
1059 Module mod;
1060 std::vector<Function *> functions;
1061 llvm::StringMap<BackendOptions> optsMap;
1062 BackendOptions opts;
1063
1064 for (unsigned int i = 0; i < 3; i++) {
1065 Function *F = mod.createFunction("function" + std::to_string(i));
1066 auto *X = mod.createPlaceholder(ElemKind::FloatTy, {3},
1067 "X" + std::to_string(i), false);
1068 auto *pow = F->createPow("Pow" + std::to_string(i), X, 2.0);
1069 F->createSave("save" + std::to_string(i), pow);
1070 functions.push_back(F);
1071 optsMap.insert({F->getName(), opts});
1072 }
1073 std::unique_ptr<Backend> backend(createBackend(GetParam()));
1074
1075 auto functionOrErr = backend->compileFunctions(functions, optsMap);
1076 ASSERT_TRUE((bool)functionOrErr);
1077}
1078
1079/// This test checks that we can compile a function without depending on the
1080/// graph representation. We compile some function and then delete the function.
1081/// Later we execute the code and check that things work.
1082TEST_P(BackendExecTest, decoupleCodegenFromGraph) {
1083 auto &mod = EE_.getModule();
1084 PlaceholderBindings bindings;
1085
1086 Function *F = mod.createFunction("main");
1087 auto *X = mod.createPlaceholder(ElemKind::FloatTy, {3}, "X", false);
1088 auto *XTensor = bindings.allocate(X);
1089 XTensor->getHandle() = {1., 2., 3.};
1090 auto *pow = F->createPow("Pow1", X, 2.0);
1091 auto *save = F->createSave("save", pow);
1092 auto *saveTensor = bindings.allocate(save->getPlaceholder());
1093 EE_.compile(CompilationMode::Infer);
1094
1095 // Erase all of the functions to ensure that the compiled code does not
1096 // depend on the graph.
1097 mod.eraseFunctions();
1098
1099 // We can run the compiled code without having the graph representation
1100 // around.
1101 EE_.run(bindings);
1102
1103 auto HX = saveTensor->getHandle();
1104 EXPECT_NEAR(HX.at({0}), 1, 1E-5);
1105 EXPECT_NEAR(HX.at({1}), 4, 1E-5);
1106 EXPECT_NEAR(HX.at({2}), 9, 1E-5);
1107}
1108
1109/// Check that we can pass information to the execution engine using Placeholder
1110/// variables and read it back using Save nodes (in variables).
1111TEST_P(BackendExecTest, simplePlaceholderValue) {
1112 Tensor data{99.0, 35.0, 2.0, 3.0};
1113 auto &mod = EE_.getModule();
1114 Function *F = mod.createFunction("main");
1115 auto *input = mod.createPlaceholder(ElemKind::FloatTy, {4}, "input", false);
1116 PlaceholderBindings bindings({input}, {&data});
1117 SaveNode *S = F->createSave("ret", input);
1118 auto *STensor = bindings.allocate(S->getPlaceholder());
1119
1120 EE_.compile(CompilationMode::Infer);
1121 EE_.run(bindings);
1122 EXPECT_TRUE(STensor->isEqual(data));
1123}
1124
/// Add and compile a network, then add and compile another so that the first
/// CompiledFunction does not know about every Placeholder in the module.
/// The two graphs are built identically (same weights), so after running both
/// compiled functions their outputs must match.
TEST_P(BackendExecTest, compileThenAddNetwork) {
  PlaceholderBindings bindings1, bindings2;

  auto &mod = EE_.getModule();
  Tensor inputs(ElemKind::FloatTy, {1, 10, 10, 3});
  inputs.getHandle().randomize(-2, 2, mod.getPRNG());

  // Create a simple graph that uses some placeholders.
  Function *F = mod.createFunction("main");
  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {1, 10, 10, 3}, "in", false);

  auto *FC = F->createFullyConnected(bindings1, "FC", input, 30);
  auto *RL = F->createRELU("RL2", FC);
  auto *S = F->createSave("ret", RL);

  Placeholder *FC_weights =
      llvm::dyn_cast<Placeholder>(FC->getWeights().getNode());

  // Recreate that graph in a different Function.
  Function *F2 = mod.createFunction("other");
  auto *input2 =
      mod.createPlaceholder(ElemKind::FloatTy, {1, 10, 10, 3}, "in", false);

  auto *FC2 = F2->createFullyConnected(bindings2, "FC", input2, 30);

  // FC2 now has random weights we replace with FC1's weights so the output is
  // the same. This must happen before placeholders are converted to
  // constants below, otherwise the weight values would be baked in.
  Placeholder *FC2_weights =
      llvm::dyn_cast<Placeholder>(FC2->getWeights().getNode());
  bindings2.get(FC2_weights)->assign(bindings1.get(FC_weights));

  auto *RL2 = F2->createRELU("RL2", FC2);
  auto *S2 = F2->createSave("ret", RL2);

  // Freeze everything except the network inputs and outputs, then compile
  // both functions in one shot.
  convertPlaceholdersToConstants(F, bindings1, {input, S->getPlaceholder()});
  convertPlaceholdersToConstants(F2, bindings2, {input2, S2->getPlaceholder()});
  EE_.compile(CompilationMode::Infer);

  // Allocate all placeholders. Done after compile: each binding set only
  // learns about the module's full placeholder list here.
  bindings1.allocate(mod.getPlaceholders());
  bindings2.allocate(mod.getPlaceholders());
  updateInputPlaceholders(bindings1, {input}, {&inputs});
  updateInputPlaceholders(bindings2, {input2}, {&inputs});

  EE_.run(bindings1, "main");
  EE_.run(bindings2, "other");

  // The graphs were the same so their outputs should be as well.
  EXPECT_TRUE(bindings1.get(S->getPlaceholder())
                  ->isEqual(*bindings2.get(S2->getPlaceholder())));
}
1179
1180/// Test the basic functionality of the bindings.
1181TEST(PlaceholderBindings, basicPlaceholderBindingsTest) {
1182 Module mod;
1183 TypeRef ty = mod.uniqueType(ElemKind::FloatTy, {1, 32, 32, 3});
1184
1185 Tensor T1(ty);
1186
1187 // Create a bindings for some threaded execution.
1188 PlaceholderBindings C;
1189
1190 // Create a simple graph, just to have a few placeholders.
1191 Function *F = mod.createFunction("main");
1192 auto *input1 = mod.createPlaceholder(ty, "input1", false);
1193 auto *input2 = mod.createPlaceholder(ty, "input2", false);
1194 auto *input3 = mod.createPlaceholder(ty, "input3", false);
1195 auto *add = F->createAdd("add", input1, input2);
1196 auto *save = F->createSave("ret", add);
1197 auto *savePlaceholder = save->getPlaceholder();
1198 C.allocate(savePlaceholder);
1199
1200 C.insert(input1, std::move(T1));
1201 Tensor *I2 = C.allocate(input2);
1202
1203 // Check that the right placeholders are found.
1204 EXPECT_TRUE(C.count(input1));
1205 EXPECT_TRUE(C.count(input2));
1206 EXPECT_TRUE(C.count(savePlaceholder));
1207 EXPECT_FALSE(C.count(nullptr));
1208
1209 // Try to fetch some placeholders that exist and some that don't.
1210 auto *V1 = C.get(input1);
1211 auto *V2 = C.get(input2);
1212 auto *V3 = C.get(input3);
1213 auto *S = C.get(savePlaceholder);
1214 EXPECT_NE(V1, nullptr);
1215 EXPECT_NE(V2, nullptr);
1216 EXPECT_EQ(V3, nullptr);
1217 EXPECT_NE(S, nullptr);
1218
1219 // The tensor that we got while allocating T2 is the same one that we got
1220 // while searching the bindings.
1221 EXPECT_EQ(I2, V2);
1222
1223 // Check that all of the placeholders are allocated.
1224 C.allocate(input3);
1225 EXPECT_EQ(nullptr, C.getFirstUnallocated(mod.getPlaceholders()));
1226
1227 // Check that some placeholders are unallocated.
1228 C.clear();
1229 EXPECT_NE(nullptr, C.getFirstUnallocated(mod.getPlaceholders()));
1230
1231 // Check that all of the placeholders are allocated.
1232 C.allocate(mod.getPlaceholders());
1233 EXPECT_EQ(nullptr, C.getFirstUnallocated(mod.getPlaceholders()));
1234}
1235
1236/// Check if the dump function works for Type.
1237TEST(BackendExecTest, dumpType) {
1238 Module mod;
1239 TypeRef tyA = mod.uniqueType(ElemKind::FloatTy, {1, 32, 32, 3});
1240 std::string storage;
1241 llvm::raw_string_ostream os(storage);
1242 tyA->dump(os);
1243 std::string mesA = tyA->toString();
1244 std::string expectA = "float<1 x 32 x 32 x 3>";
1245 EXPECT_EQ(mesA, expectA);
1246 EXPECT_EQ(mesA, os.str());
1247}
1248
// Instantiate the parameterized test suites above for every available backend.
INSTANTIATE_BACKEND_TEST(BackendExecTest);
INSTANTIATE_BACKEND_TEST(BackendExecStatelessTest);
1251