1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Graph/Graph.h"
18#include "BackendTestUtils.h"
19#include "glow/ExecutionEngine/ExecutionEngine.h"
20#include "glow/Flags/Flags.h"
21#include "glow/Graph/Hook.h"
22#include "glow/Graph/Node.h"
23#include "glow/Graph/Nodes.h"
24#include "glow/Graph/Utils.h"
25#include "glow/IR/IR.h"
26#include "glow/IR/Instrs.h"
27#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
28
29#include "llvm/ADT/SmallPtrSet.h"
30#include "llvm/Support/FileSystem.h"
31
32#include "gtest/gtest.h"
33
34using namespace glow;
35
36// Helper to find a node in the Function by name
37static const Node *nodeByName(const Function *F, const std::string &name) {
38 for (auto &n : F->getNodes()) {
39 if (n.getName().str() == name) {
40 return &n;
41 }
42 }
43 return nullptr;
44}
45
46/// Mock backend that does lower FC nodes.
47class MockBackendNoLowerConv3D : public MockBackend {
48 bool shouldLower(const Node *N) const override {
49 if (N->getKind() == Kinded::Kind::Convolution3DNodeKind) {
50 return false;
51 } else {
52 return true;
53 }
54 }
55};
56
57TEST(Graph, testVariableErasure) {
58 Module MD;
59 auto &vars = MD.getConstants();
60 EXPECT_EQ(vars.size(), 0);
61 EXPECT_EQ(std::distance(vars.begin(), vars.end()), vars.size());
62
63 Constant *V = MD.createConstant(ElemKind::FloatTy, {1, 1}, "dummy");
64 EXPECT_EQ(vars.size(), 1);
65 EXPECT_EQ(std::distance(vars.begin(), vars.end()), vars.size());
66
67 MD.eraseConstant(V);
68 EXPECT_EQ(vars.size(), 0);
69 EXPECT_EQ(std::distance(vars.begin(), vars.end()), vars.size());
70}
71
72/// Check that the clear method completely reset a module.
73TEST(Graph, clear) {
74 Module M;
75
76 // Check that the module is initially empty.
77 EXPECT_EQ(M.getConstants().size(), 0);
78 EXPECT_EQ(M.getPlaceholders().size(), 0);
79 EXPECT_EQ(M.getFunctions().size(), 0);
80
81 // Create a few things.
82 M.createFunction("main");
83 M.createPlaceholder(ElemKind::FloatTy, {1}, "placeholder", true);
84 M.createConstant(ElemKind::FloatTy, {1}, "var");
85
86 EXPECT_EQ(M.getConstants().size(), 1);
87 EXPECT_EQ(M.getPlaceholders().size(), 1);
88 EXPECT_EQ(M.getFunctions().size(), 1);
89
90 // Check that clearing the module makes it completely free of any kind of
91 // objects.
92 M.clear();
93 EXPECT_EQ(M.getConstants().size(), 0);
94 EXPECT_EQ(M.getPlaceholders().size(), 0);
95 EXPECT_EQ(M.getFunctions().size(), 0);
96}
97
98/// Check that the clear method works as expected.
99TEST(Graph, clearFunctions) {
100 Module M;
101
102 // Check that the module is initially empty.
103 EXPECT_EQ(M.getConstants().size(), 0);
104 EXPECT_EQ(M.getPlaceholders().size(), 0);
105 EXPECT_EQ(M.getFunctions().size(), 0);
106
107 // Create a few things.
108 Function *F = M.createFunction("main");
109 auto *PH = M.createPlaceholder(ElemKind::FloatTy, {1}, "placeholder", true);
110 auto *C = M.createConstant(ElemKind::FloatTy, {1}, "var");
111 auto *AN = F->createAdd("add", PH, C);
112 F->createSave("save", AN);
113
114 EXPECT_EQ(M.getConstants().size(), 1);
115 EXPECT_EQ(M.getPlaceholders().size(), 2); // Input PH and PH for Save
116 EXPECT_EQ(M.getFunctions().size(), 1);
117 EXPECT_EQ(F->getNodes().size(), 2); // Add, Save
118
119 M.clearFunctions();
120 EXPECT_EQ(M.getConstants().size(), 1);
121 EXPECT_EQ(M.getPlaceholders().size(), 2);
122 ASSERT_EQ(M.getFunctions().size(), 1);
123 // Same Function ptr should exist, just nothing left in them.
124 EXPECT_EQ(*M.getFunctions().begin(), F);
125 EXPECT_EQ(F->getNodes().size(), 0);
126}
127
128/// Test the graph nodes names and utilities.
129TEST(Graph, testGraphNames) {
130 Module MD;
131 Function *F = MD.createFunction("F");
132
133 Node *op1 = MD.createPlaceholder(ElemKind::FloatTy, {1, 10}, "op1",
134 false /*isTrainable*/);
135 Node *op2 = MD.createConstant(ElemKind::FloatTy, {1, 10}, "op2");
136 Node *add = F->createAdd("add", op1, op2);
137 auto *top = F->createTopK("top", add, 5);
138 Node *save = F->createSave("out", top->getValues());
139
140 EXPECT_TRUE(MD.getPlaceholderByNameSlow("op1"));
141 EXPECT_TRUE(MD.getConstantByName("op2"));
142 EXPECT_TRUE(F->getNodeByName("add"));
143 EXPECT_TRUE(F->getNodeByName("top"));
144 EXPECT_TRUE(F->getNodeByName("out_save"));
145
146 NodeValue op1Res = op1->getNthResult(0);
147 NodeValue op2Res = op2->getNthResult(0);
148 NodeValue addRes = add->getNthResult(0);
149 EXPECT_TRUE(top->getNumResults() == 2);
150 NodeValue topValRes = top->getNthResult(0);
151 NodeValue topIndRes = top->getNthResult(1);
152
153 auto op1ResName =
154 op1Res.generateNodeOutputName(false /*stripResNoFor0thInput*/);
155 auto op2ResName =
156 op2Res.generateNodeOutputName(false /*stripResNoFor0thInput*/);
157 auto addResName =
158 addRes.generateNodeOutputName(true /*stripResNoFor0thInput*/);
159 auto topValResName =
160 topValRes.generateNodeOutputName(false /*stripResNoFor0thInput*/);
161 auto topIndResName =
162 topIndRes.generateNodeOutputName(false /*stripResNoFor0thInput*/);
163
164 EXPECT_EQ(op1ResName, "op1:0");
165 EXPECT_EQ(op2ResName, "op2:0");
166 EXPECT_EQ(addResName, "add");
167 EXPECT_EQ(topValResName, "top:0");
168 EXPECT_EQ(topIndResName, "top:1");
169
170 EXPECT_EQ(F->getNodeValueByName(op1ResName), op1Res);
171 EXPECT_EQ(F->getNodeValueByName(op2ResName), op2Res);
172 EXPECT_EQ(F->getNodeValueByName(addResName), addRes);
173 EXPECT_EQ(F->getNodeValueByName(topValResName), topValRes);
174 EXPECT_EQ(F->getNodeValueByName(topIndResName), topIndRes);
175
176 EXPECT_EQ(F->getNodeValueByName("op1"), op1Res);
177 EXPECT_EQ(F->getNodeValueByName("op2"), op2Res);
178 EXPECT_EQ(F->getNodeValueByName("add:0"), addRes);
179
180 // Verify the node value is invalid for the SaveNode which has no outputs.
181 EXPECT_EQ(F->getNodeValueByName(save->getName()).getNode(), nullptr);
182}
183
184/// Check node names.
185TEST(Graph, testNodeNames) {
186 Module MD;
187 Function *F = MD.createFunction("F");
188 IRFunction M(F);
189 PlaceholderBindings bindings;
190 Node *K =
191 MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
192 Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", true);
193
194 K = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, 1);
195 K = F->createRELU("Relu", K);
196 K = F->createSoftMax("SoftMax", K, S);
197 F->createSave("Save", K);
198 F->dump();
199 auto filePath = F->dumpDAG();
200 auto backend = MockBackend();
201 CompilationContext cctx;
202 lower(F, cctx, &backend);
203 ::optimize(F, CompilationMode::Train);
204 M.generateIR(backend);
205 M.dump();
206 EXPECT_GT(M.getInstrs().size(), 0);
207 llvm::sys::fs::remove(filePath);
208}
209
210/// Check that a createConv3D can be run.
211TEST(Graph, simpleTestConv3D) {
212 Module MD;
213 Function *F = MD.createFunction("F");
214 IRFunction M(F);
215 PlaceholderBindings bindings;
216 Node *K = MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 100, 3},
217 "input", true);
218 K = F->createConv3D(bindings, /* name */ "Conv3D", /* input */ K,
219 /* outChannels */ 16, /* kernel */ 3, /* stride */ 2,
220 /* pad */ 3, /* group */ 1);
221 K = F->createRELU("Relu", K);
222 F->createSave("Save", K);
223 F->dump();
224 auto filePath = F->dumpDAG();
225 auto backend = MockBackend();
226 CompilationContext cctx;
227 lower(F, cctx, &backend);
228 ::optimize(F, CompilationMode::Train);
229 M.generateIR(backend);
230 M.dump();
231 EXPECT_GT(M.getInstrs().size(), 0);
232 llvm::sys::fs::remove(filePath);
233}
234
235/// Tests custom lowering from Node to Instruction IR
236TEST(Graph, simpleTestConvCustomLower) {
237 Module MD;
238 Function *F = MD.createFunction("F");
239 IRFunction M(F);
240 PlaceholderBindings bindings;
241 Node *K =
242 MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
243 Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", true);
244
245 K = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, 1);
246 K = F->createRELU("Relu", K);
247 K = F->createSoftMax("SoftMax", K, S);
248 F->createSave("Save", K);
249 F->dump();
250 auto filePath = F->dumpDAG();
251 auto backend = MockBackendCustomIRGen();
252 CompilationContext cctx;
253 lower(F, cctx, &backend);
254 ::optimize(F, CompilationMode::Train);
255 M.generateIR(MockBackendCustomIRGen());
256 M.dump();
257 auto &instrList = M.getInstrs();
258 bool customHappened = false;
259 for (auto begin = instrList.begin(); begin != instrList.end(); ++begin) {
260 if (begin->getName().equals("CustomConvolutionInstruction")) {
261 customHappened = true;
262 break;
263 }
264 }
265
266 EXPECT_EQ(customHappened, true);
267 llvm::sys::fs::remove(filePath);
268}
269
270/// Check that we can create convolution with float16.
271TEST(Graph, float16Conv) {
272 Module MD;
273 Function *F = MD.createFunction("F");
274 PlaceholderBindings bindings;
275 Node *K = MD.createConstant(ElemKind::Float16Ty, {4, 320, 200, 3}, "input");
276
277 auto *conv = F->createConv(bindings, "Conv", K, 16, 3, 2, 3, 1);
278 F->createSave("Save", conv);
279 EXPECT_TRUE(conv->verify());
280 EXPECT_EQ(conv->getResult().getElementType(), ElemKind::Float16Ty);
281 EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::Float16Ty);
282 EXPECT_EQ(conv->getBias().getElementType(), ElemKind::Float16Ty);
283
284 auto backend = MockBackend();
285 CompilationContext cctx;
286 lower(F, cctx, &backend);
287
288 IRFunction M(F);
289
290 M.generateIR(backend);
291 EXPECT_GT(M.getInstrs().size(), 0);
292 auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
293 [](const Instruction &inst) -> bool {
294 return llvm::isa<ConvolutionInst>(inst);
295 });
296 ASSERT_TRUE(convIt != M.getInstrs().end());
297 const auto *convInst = llvm::cast<ConvolutionInst>(&*convIt);
298 EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::Float16Ty);
299 EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::Float16Ty);
300 EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::Float16Ty);
301}
302
303/// Check that we can create layernormalization with float16.
304TEST(Graph, float16LayerNorm) {
305 const auto origFlagVal = interpreter::flags::LowerLayerNormalization;
306 interpreter::flags::LowerLayerNormalization = false;
307
308 Module MD;
309 Function *F = MD.createFunction("F");
310
311 PlaceholderBindings bindings;
312 auto *input =
313 MD.createPlaceholder(ElemKind::Float16Ty, {1, 4, 5, 5}, "in", false);
314
315 Tensor scaleT(ElemKind::Float16Ty, {5, 5});
316 scaleT.getHandle<float16_t>().randomize(0.0f, 1.0f, MD.getPRNG());
317 Constant *scaleC = MD.createConstant("scale", std::move(scaleT));
318 Tensor biasT(ElemKind::Float16Ty, {5, 5});
319 biasT.getHandle<float16_t>().randomize(0.0f, 1.0f, MD.getPRNG());
320 Constant *biasC = MD.createConstant("bias", std::move(biasT));
321
322 LayerNormalizationNode *LNN = F->createLayerNormalization(
323 "LN", input->getType(), input, scaleC, biasC, 1e-5);
324 F->createSave("Save", LNN);
325
326 std::unique_ptr<const Backend> backend(createBackend("Interpreter"));
327
328 CompilationContext cctx;
329 lower(F, cctx, backend.get());
330
331 IRFunction M(F);
332
333 M.generateIR(*backend);
334 EXPECT_GT(M.getInstrs().size(), 0);
335 auto lnIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
336 [](const Instruction &inst) -> bool {
337 return llvm::isa<LayerNormalizationInst>(inst);
338 });
339 ASSERT_TRUE(lnIt != M.getInstrs().end());
340 interpreter::flags::LowerLayerNormalization = origFlagVal;
341}
342
343/// Check that we can create batch_matmul with float16.
344TEST(Graph, float16BatchMatMul) {
345 const auto origFlagVal = interpreter::flags::LowerBatchMatMul;
346 interpreter::flags::LowerBatchMatMul = false;
347
348 Module MD;
349 Function *F = MD.createFunction("F");
350
351 PlaceholderBindings bindings;
352 auto *LHS = MD.createPlaceholder(ElemKind::Float16Ty, {2, 3, 4}, "A", false);
353 auto *RHS = MD.createPlaceholder(ElemKind::Float16Ty, {2, 4, 5}, "B", false);
354
355 BatchMatMulNode *BMM = F->createBatchMatMul("BMM", LHS, RHS);
356 F->createSave("Save", BMM);
357
358 std::unique_ptr<const Backend> backend(createBackend("Interpreter"));
359
360 CompilationContext cctx;
361 lower(F, cctx, backend.get());
362
363 IRFunction M(F);
364
365 M.generateIR(*backend);
366 EXPECT_GT(M.getInstrs().size(), 0);
367 auto bmmIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
368 [](const Instruction &inst) -> bool {
369 return llvm::isa<BatchMatMulInst>(inst);
370 });
371 ASSERT_TRUE(bmmIt != M.getInstrs().end());
372 interpreter::flags::LowerBatchMatMul = origFlagVal;
373}
374
375/// Check that we can create batch_matmul with float.
376TEST(Graph, floatBatchMatMul) {
377 const auto origFlagVal = interpreter::flags::LowerBatchMatMul;
378 interpreter::flags::LowerBatchMatMul = false;
379
380 Module MD;
381 Function *F = MD.createFunction("F");
382
383 PlaceholderBindings bindings;
384 auto *LHS = MD.createPlaceholder(ElemKind::FloatTy, {2, 3, 4}, "A", false);
385 auto *RHS = MD.createPlaceholder(ElemKind::FloatTy, {2, 4, 5}, "B", false);
386
387 BatchMatMulNode *BMM = F->createBatchMatMul("BMM", LHS, RHS);
388 F->createSave("Save", BMM);
389
390 std::unique_ptr<const Backend> backend(createBackend("Interpreter"));
391
392 CompilationContext cctx;
393 lower(F, cctx, backend.get());
394
395 IRFunction M(F);
396
397 M.generateIR(*backend);
398 EXPECT_GT(M.getInstrs().size(), 0);
399 auto bmmIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
400 [](const Instruction &inst) -> bool {
401 return llvm::isa<BatchMatMulInst>(inst);
402 });
403 ASSERT_TRUE(bmmIt != M.getInstrs().end());
404 interpreter::flags::LowerBatchMatMul = origFlagVal;
405}
406
407/// Check that we can create conv3D with float16.
408TEST(Graph, float16Conv3DLower) {
409 Module MD;
410 Function *F = MD.createFunction("F");
411 PlaceholderBindings bindings;
412 Node *K =
413 MD.createConstant(ElemKind::Float16Ty, {4, 320, 200, 200, 3}, "input");
414
415 auto *conv = F->createConv3D(bindings, "Conv3D", K, 16, 3, 2, 3, 1);
416 F->createSave("Save", conv);
417 EXPECT_TRUE(conv->verify());
418 EXPECT_EQ(conv->getResult().getElementType(), ElemKind::Float16Ty);
419 EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::Float16Ty);
420 EXPECT_EQ(conv->getBias().getElementType(), ElemKind::Float16Ty);
421
422 auto backend = MockBackend();
423 CompilationContext cctx;
424 lower(F, cctx, &backend);
425
426 IRFunction M(F);
427
428 M.generateIR(backend);
429 EXPECT_GT(M.getInstrs().size(), 0);
430 auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
431 [](const Instruction &inst) -> bool {
432 return llvm::isa<Convolution3DInst>(inst);
433 });
434 ASSERT_TRUE(convIt == M.getInstrs().end());
435}
436
437/// Check that we can create conv3D with float16.
438TEST(Graph, float16Conv3DNoLower) {
439 Module MD;
440 Function *F = MD.createFunction("F");
441 PlaceholderBindings bindings;
442 Node *K =
443 MD.createConstant(ElemKind::Float16Ty, {4, 320, 200, 200, 3}, "input");
444
445 auto *conv = F->createConv3D(bindings, "Conv3D", K, 16, 3, 2, 3, 1);
446 F->createSave("Save", conv);
447 EXPECT_TRUE(conv->verify());
448 EXPECT_EQ(conv->getResult().getElementType(), ElemKind::Float16Ty);
449 EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::Float16Ty);
450 EXPECT_EQ(conv->getBias().getElementType(), ElemKind::Float16Ty);
451
452 auto backend = MockBackendNoLowerConv3D();
453 CompilationContext cctx;
454 lower(F, cctx, &backend);
455
456 IRFunction M(F);
457
458 M.generateIR(backend);
459 EXPECT_GT(M.getInstrs().size(), 0);
460 auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
461 [](const Instruction &inst) -> bool {
462 return llvm::isa<Convolution3DInst>(inst);
463 });
464 ASSERT_TRUE(convIt != M.getInstrs().end());
465 const auto *convInst = llvm::cast<Convolution3DInst>(&*convIt);
466 EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::Float16Ty);
467 EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::Float16Ty);
468 EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::Float16Ty);
469}
470
471/// Check that we can create batchNorm with float16.
472TEST(Graph, float16BatchNorm) {
473 Module MD;
474 Function *F = MD.createFunction("F");
475 PlaceholderBindings bindings;
476 auto *input =
477 MD.createPlaceholder(ElemKind::Float16Ty, {1, 10, 20, 3}, "input", false);
478 BatchNormalizationNode *BN =
479 F->createBatchNormalization(bindings, "batch", input, 3, 0.0001, 0.9);
480
481 EXPECT_TRUE(BN->verify());
482 EXPECT_EQ(BN->getResult().getElementType(), ElemKind::Float16Ty);
483 EXPECT_EQ(BN->getScale().getElementType(), ElemKind::Float16Ty);
484 EXPECT_EQ(BN->getBias().getElementType(), ElemKind::Float16Ty);
485 EXPECT_EQ(BN->getMean().getElementType(), ElemKind::Float16Ty);
486 EXPECT_EQ(BN->getVar().getElementType(), ElemKind::Float16Ty);
487
488 auto backend = MockBackend();
489 CompilationContext cctx;
490 lower(F, cctx, &backend);
491
492 EXPECT_TRUE(std::all_of(
493 F->getNodes().begin(), F->getNodes().end(), [](const Node &node) -> bool {
494 for (unsigned idx = 0, end = node.getNumResults(); idx != end; ++idx) {
495 if (node.getType(idx)->getElementType() != ElemKind::Float16Ty) {
496 return false;
497 }
498 }
499 return true;
500 }));
501}
502
503/// Check that we can create convolution with bfloat16.
504TEST(Graph, bfloat16Conv) {
505 Module MD;
506 Function *F = MD.createFunction("F");
507 PlaceholderBindings bindings;
508 Node *K = MD.createConstant(ElemKind::BFloat16Ty, {4, 320, 200, 3}, "input");
509
510 auto *conv = F->createConv(bindings, "Conv", K, 16, 3, 2, 3, 1);
511 F->createSave("Save", conv);
512 EXPECT_TRUE(conv->verify());
513 EXPECT_EQ(conv->getResult().getElementType(), ElemKind::BFloat16Ty);
514 EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::BFloat16Ty);
515 EXPECT_EQ(conv->getBias().getElementType(), ElemKind::BFloat16Ty);
516
517 auto backend = MockBackend();
518 CompilationContext cctx;
519 lower(F, cctx, &backend);
520
521 IRFunction M(F);
522
523 M.generateIR(backend);
524 EXPECT_GT(M.getInstrs().size(), 0);
525 auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
526 [](const Instruction &inst) -> bool {
527 return llvm::isa<ConvolutionInst>(inst);
528 });
529 ASSERT_TRUE(convIt != M.getInstrs().end());
530 const auto *convInst = llvm::cast<ConvolutionInst>(&*convIt);
531 EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::BFloat16Ty);
532 EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::BFloat16Ty);
533 EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::BFloat16Ty);
534}
535
536/// Check that we can create conv3D with bfloat16.
537TEST(Graph, bfloat16Conv3DLower) {
538 Module MD;
539 Function *F = MD.createFunction("F");
540 PlaceholderBindings bindings;
541 Node *K =
542 MD.createConstant(ElemKind::BFloat16Ty, {4, 320, 200, 200, 3}, "input");
543
544 auto *conv = F->createConv3D(bindings, "Conv3D", K, 16, 3, 2, 3, 1);
545 F->createSave("Save", conv);
546 EXPECT_TRUE(conv->verify());
547 EXPECT_EQ(conv->getResult().getElementType(), ElemKind::BFloat16Ty);
548 EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::BFloat16Ty);
549 EXPECT_EQ(conv->getBias().getElementType(), ElemKind::BFloat16Ty);
550
551 auto backend = MockBackend();
552 CompilationContext cctx;
553 lower(F, cctx, &backend);
554
555 IRFunction M(F);
556
557 M.generateIR(backend);
558 EXPECT_GT(M.getInstrs().size(), 0);
559 auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
560 [](const Instruction &inst) -> bool {
561 return llvm::isa<Convolution3DInst>(inst);
562 });
563 ASSERT_TRUE(convIt == M.getInstrs().end());
564}
565
566/// Check that we can create conv3D with bfloat16.
567TEST(Graph, bfloat16Conv3DNoLower) {
568 Module MD;
569 Function *F = MD.createFunction("F");
570 PlaceholderBindings bindings;
571 Node *K =
572 MD.createConstant(ElemKind::BFloat16Ty, {4, 320, 200, 200, 3}, "input");
573
574 auto *conv = F->createConv3D(bindings, "Conv3D", K, 16, 3, 2, 3, 1);
575 F->createSave("Save", conv);
576 EXPECT_TRUE(conv->verify());
577 EXPECT_EQ(conv->getResult().getElementType(), ElemKind::BFloat16Ty);
578 EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::BFloat16Ty);
579 EXPECT_EQ(conv->getBias().getElementType(), ElemKind::BFloat16Ty);
580
581 auto backend = MockBackendNoLowerConv3D();
582 CompilationContext cctx;
583 lower(F, cctx, &backend);
584
585 IRFunction M(F);
586
587 M.generateIR(backend);
588 EXPECT_GT(M.getInstrs().size(), 0);
589 auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
590 [](const Instruction &inst) -> bool {
591 return llvm::isa<Convolution3DInst>(inst);
592 });
593 ASSERT_TRUE(convIt != M.getInstrs().end());
594 const auto *convInst = llvm::cast<Convolution3DInst>(&*convIt);
595 EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::BFloat16Ty);
596 EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::BFloat16Ty);
597 EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::BFloat16Ty);
598}
599
600/// Check that we can create batchNorm with float16.
601TEST(Graph, bfloat16BatchNorm) {
602 Module MD;
603 Function *F = MD.createFunction("F");
604 PlaceholderBindings bindings;
605 auto *input = MD.createPlaceholder(ElemKind::BFloat16Ty, {1, 10, 20, 3},
606 "input", false);
607 BatchNormalizationNode *BN =
608 F->createBatchNormalization(bindings, "batch", input, 3, 0.0001, 0.9);
609
610 EXPECT_TRUE(BN->verify());
611 EXPECT_EQ(BN->getResult().getElementType(), ElemKind::BFloat16Ty);
612 EXPECT_EQ(BN->getScale().getElementType(), ElemKind::BFloat16Ty);
613 EXPECT_EQ(BN->getBias().getElementType(), ElemKind::BFloat16Ty);
614 EXPECT_EQ(BN->getMean().getElementType(), ElemKind::BFloat16Ty);
615 EXPECT_EQ(BN->getVar().getElementType(), ElemKind::BFloat16Ty);
616
617 auto backend = MockBackend();
618 CompilationContext cctx;
619 lower(F, cctx, &backend);
620
621 EXPECT_TRUE(std::all_of(
622 F->getNodes().begin(), F->getNodes().end(), [](const Node &node) -> bool {
623 for (unsigned idx = 0, end = node.getNumResults(); idx != end; ++idx) {
624 if (node.getType(idx)->getElementType() != ElemKind::BFloat16Ty) {
625 return false;
626 }
627 }
628 return true;
629 }));
630}
631
632/// Test that our use lists are correctly reflecting the state of the IR
633/// and in particular that it is not polluted by temporary variable.
634TEST(Graph, useList) {
635 Module MD;
636 Function *F = MD.createFunction("F");
637 IRFunction M(F);
638 PlaceholderBindings bindings;
639 auto *K =
640 MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
641
642 EXPECT_EQ(K->getNumUsers(), 0);
643
644 ConvolutionNode *conv = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, 1);
645
646 EXPECT_TRUE(K->hasOneUse());
647 EXPECT_EQ(K->getNumUsers(), 1);
648 EXPECT_EQ(conv->getNumUsers(), 0);
649
650 // Although the filter of the convolution is only used by the convolution
651 // node, calling getFilter creates a temporary NodeValue that messes up
652 // with the actual use list.
653 // Therefore those checks are currently inverted but should be
654 // fixed eventually.
655 // Test with implicit temporary NodeValue.
656 EXPECT_TRUE(conv->getFilter().getNode()->hasOneUse());
657 EXPECT_EQ(conv->getFilter().getNode()->getNumUsers(), 1);
658
659 // Test with explicit temporary NodeValue.
660 Node *nodeFilter;
661 {
662 NodeValue tmp = conv->getFilter();
663 EXPECT_TRUE(tmp.getNode()->hasOneUse());
664 EXPECT_EQ(tmp.getNode()->getNumUsers(), 1);
665 nodeFilter = tmp.getNode();
666 // Test with NodeValue still around.
667 EXPECT_TRUE(nodeFilter->hasOneUse());
668 EXPECT_EQ(nodeFilter->getNumUsers(), 1);
669 }
670
671 // Test with NodeValue took out.
672 EXPECT_TRUE(nodeFilter->hasOneUse());
673 EXPECT_EQ(nodeFilter->getNumUsers(), 1);
674
675 // Same kind of test but with the convolution node itself.
676 {
677 NodeValue tmpConvRes(conv, 0);
678 EXPECT_EQ(conv->getNumUsers(), 0);
679 EXPECT_EQ(tmpConvRes.getNode()->getNumUsers(), 0);
680 }
681
682 // Add a couple of uses to conv and make sure it reflects on its use list.
683 F->createSave("Save", conv, K);
684
685 EXPECT_FALSE(K->hasOneUse());
686 EXPECT_EQ(K->getNumUsers(), 2);
687 EXPECT_EQ(conv->getNumUsers(), 1);
688 EXPECT_TRUE(conv->hasOneUse());
689
690 {
691 NodeValue tmpConvRes(conv, 0);
692 EXPECT_TRUE(tmpConvRes.getNode()->hasOneUse());
693 EXPECT_TRUE(conv->hasOneUse());
694 EXPECT_EQ(conv->getNumUsers(), 1);
695 EXPECT_EQ(tmpConvRes.getNode()->getNumUsers(), 1);
696 }
697
698 F->createSave("Save", conv, K);
699
700 EXPECT_FALSE(K->hasOneUse());
701 EXPECT_EQ(K->getNumUsers(), 3);
702 EXPECT_EQ(conv->getNumUsers(), 2);
703 EXPECT_FALSE(conv->hasOneUse());
704
705 {
706 NodeValue tmpConvRes(conv, 0);
707 EXPECT_FALSE(tmpConvRes.getNode()->hasOneUse());
708 EXPECT_FALSE(conv->hasOneUse());
709 EXPECT_EQ(conv->getNumUsers(), 2);
710 EXPECT_EQ(tmpConvRes.getNode()->getNumUsers(), 2);
711 }
712}
713
714TEST(Graph, useListIteration) {
715 Module MD;
716 Function *F = MD.createFunction("F");
717 IRFunction M(F);
718 Node *K =
719 MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
720
721 EXPECT_EQ(K->getNumUsers(), 0);
722
723 PlaceholderBindings bindings;
724 ConvolutionNode *conv1 = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, 1);
725 ConvolutionNode *conv2 = F->createConv(bindings, "Conv2", K, 16, 3, 2, 3, 1);
726 // Check the number of users for different nodes.
727 EXPECT_EQ(K->getNumUsers(), 2);
728 EXPECT_EQ(conv1->getNumUsers(), 0);
729 EXPECT_TRUE(conv2->getFilter().getNode()->hasOneUse());
730 EXPECT_EQ(conv1->getFilter().getNode()->getNumUsers(), 1);
731 // Check that the first user of K is conv1.
732 EXPECT_EQ(K->getUsers().begin()->getUser(), conv1);
733 // Check that the second user of K is conv2.
734 EXPECT_EQ((++K->getUsers().begin())->getUser(), conv2);
735}
736
737TEST(Graph, simpleTestFC) {
738 unsigned numInputs = 10;
739 Module MD;
740 Function *F = MD.createFunction("F");
741 IRFunction M(F);
742
743 auto *A = MD.createPlaceholder(ElemKind::FloatTy, {numInputs, 2}, "A", true);
744 auto *Ex =
745 MD.createPlaceholder(ElemKind::FloatTy, {numInputs, 1}, "Ex", true);
746
747 PlaceholderBindings bindings;
748 Node *O = F->createFullyConnected(bindings, "FC1", A, 6);
749 O = F->createRELU("RELU1", O);
750 O = F->createFullyConnected(bindings, "FC2", O, 1);
751 O = F->createRELU("RELU2", O);
752 O = F->createRegression("Regression", O, Ex);
753 F->createSave("Save", O);
754 F->dump();
755 auto filePath = F->dumpDAG();
756 auto backend = MockBackend();
757 CompilationContext cctx;
758 lower(F, cctx, &backend);
759 ::optimize(F, CompilationMode::Train);
760 M.generateIR(backend);
761 M.dump();
762 EXPECT_GT(M.getInstrs().size(), 0);
763 llvm::sys::fs::remove(filePath);
764}
765
766TEST(Graph, QuantizationProfileNodes) {
767 unsigned numInputs = 10;
768 Module MD;
769 Function *F = MD.createFunction("F");
770 IRFunction M(F);
771
772 auto *A = MD.createPlaceholder(ElemKind::FloatTy, {numInputs, 2}, "A", true);
773
774 // Add non float operation, which should not be profiled.
775 auto *outQTy = F->getParent()->uniqueType(glow::ElemKind::Int8QTy,
776 {numInputs, 2}, 1.5, 6);
777 auto *quantize = F->createQuantize("quantize", A, outQTy);
778 // Make sure that quantize is not optimized away.
779 PlaceholderBindings bindings;
780 F->createSave("save", quantize);
781
782 // Multiple nodes read from the same variable.
783 // Only one Quantization Profile node should be created for the output
784 // from the variable.
785 Node *O = F->createFullyConnected(bindings, "FC1", A, 6);
786 Node *C = F->createFullyConnected(bindings, "FC2", A, 6);
787 O = F->createRELU("RELU1", O);
788 F->createSave("save", O);
789 F->createSave("save", C);
790
791 LoweredInfoMap loweredMapForProf;
792 CompilationContext cctx{&bindings, &loweredMapForProf};
793 cctx.precisionConfig.quantMode = QuantizationMode::Profile;
794 std::unique_ptr<Backend> backend(createBackend("Interpreter"));
795 EXIT_ON_ERR(::optimizeFunction(F, *backend, cctx));
796
797 size_t numberOfProfileNodes =
798 std::count_if(F->getNodes().begin(), F->getNodes().end(), [](Node &node) {
799 return llvm::isa<QuantizationProfileNode>(&node);
800 });
801
802 // 1 from A
803 // 8 from two lowered FCs: MM, BA, weight PH, bias PH
804 // 2 from RELU (lowered to Max+Splat)
805 EXPECT_EQ(11, numberOfProfileNodes);
806}
807
808TEST(Graph, simpleQuant) {
809 ExecutionEngine EE;
810 auto &MD = EE.getModule();
811 auto *F = MD.createFunction("main");
812
813 unsigned depth = 16;
814 llvm::SmallVector<unsigned_t, 2> kernels = {5, 5};
815 llvm::SmallVector<unsigned_t, 4> pads = {0, 0, 0, 0};
816 llvm::SmallVector<unsigned_t, 2> steps = {1, 1};
817 unsigned width = 224;
818
819 auto *input = MD.createPlaceholder(ElemKind::Int8QTy, {1, width, width, 3},
820 0.4, 2, "Input", true);
821
822 // Calculate the size and allocate the output buffer.
823 std::array<dim_t, 4> filterDim = {{depth, kernels[0], kernels[1], 3}};
824 auto *filter =
825 MD.createPlaceholder(ElemKind::Int8QTy, filterDim, 3.3, 4, "F", true);
826 auto *bias =
827 MD.createPlaceholder(ElemKind::Int32QTy, {depth}, 1.3, 5, "B", true);
828
829 // Calculate the size and allocate the output buffer.
830 auto outSz = calculateConvPoolOutputDims(width, width, kernels, steps, pads);
831 std::array<dim_t, 4> outDims = {{1, outSz.first, outSz.second, 16}};
832 auto t = F->getParent()->uniqueType(glow::ElemKind::Int8QTy, outDims, 1.5, 6);
833
834 auto *conv =
835 F->createConv("conv", input, filter, bias, t, kernels, steps, pads, 1);
836
837 auto s = conv->getResult().getType()->size();
838 auto *fcFilter =
839 MD.createPlaceholder(ElemKind::Int8QTy, {s, 6}, 0.4, 2, "F", true);
840 auto *fcBias =
841 MD.createPlaceholder(ElemKind::Int32QTy, {6}, 0.4, 2, "B", true);
842 Node *O = F->createFullyConnected("fc1", conv, fcFilter, fcBias);
843 PlaceholderBindings bindings;
844 F->createSave("ret", O);
845 EE.compile(CompilationMode::Infer);
846}
847
848TEST(Graph, quantizeDequantizeNodes) {
849 ExecutionEngine EE;
850 auto &MD = EE.getModule();
851 auto F = MD.createFunction("main");
852
853 auto *input = MD.createPlaceholder(ElemKind::FloatTy, {1, 3}, "Input", true);
854 auto qType = F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 0.3, 5);
855
856 auto *Q = F->createQuantize("quantize", input, qType);
857
858 auto transform =
859 F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 1.4, 3);
860 auto *A = F->createRescaleQuantized("rescale", Q, transform);
861
862 auto *D = F->createDequantize("dequantize", A, ElemKind::FloatTy);
863 PlaceholderBindings bindings;
864 F->createSave("ret", D);
865 EE.compile(CompilationMode::Infer);
866}
867
868TEST(Graph, quantizeGather) {
869 ExecutionEngine EE;
870 auto &mod = EE.getModule();
871 auto *F = mod.createFunction("main");
872 auto *input =
873 mod.createPlaceholder(ElemKind::Int8QTy, {2, 2}, 0.4, 2, "input", true);
874 auto *indices = mod.createPlaceholder(ElemKind::Int64ITy, {1}, "index", true);
875 auto *gather = F->createGather("gather", input, indices);
876 PlaceholderBindings bindings;
877 F->createSave("ret", gather);
878 EE.compile(CompilationMode::Infer);
879}
880
881TEST(Graph, cloneTest) {
882 Module M;
883 PlaceholderBindings bindings;
884
885 Function *F = M.createFunction("main");
886 Node *K =
887 M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
888 Node *S = M.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", true);
889 Node *conv = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, 1);
890 Node *relu = F->createRELU("Relu", conv);
891 Node *SM = F->createSoftMax("SoftMax", relu, S);
892 F->createSave("Save", SM);
893
894 auto *newConv = F->addNode(conv->clone());
895 auto *newRelu = F->addNode(relu->clone());
896 auto *newSM = F->addNode(SM->clone());
897
898 EXPECT_TRUE(newConv != conv && conv->isEqual(*newConv));
899 EXPECT_TRUE(newRelu != relu && relu->isEqual(*newRelu));
900 EXPECT_TRUE(newSM != SM && SM->isEqual(*newSM));
901}
902
903TEST(Graph, moduleTest) {
904 Module M;
905 M.createFunction("one");
906 M.createFunction("two");
907 M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V1", true);
908 M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V2", true);
909 EXPECT_TRUE(M.hasFunction("one"));
910 EXPECT_TRUE(M.hasFunction("two"));
911 EXPECT_FALSE(M.hasFunction("four"));
912 M.dumpDAG();
913}
914
915TEST(Graph, functionDependenciesTest) {
916 Module M;
917 auto *F1 = M.createFunction("one");
918 auto *F2 = M.createFunction("two");
919 auto *V1 =
920 M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V1", true);
921 auto *V2 =
922 M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V2", true);
923 auto *V3 =
924 M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V3", true);
925 M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V4", true);
926
927 PlaceholderBindings bindings;
928 auto sum = F1->createSub("1_sub_2", V1, V2);
929 F1->createSave("sv", sum, V1);
930 F2->createSave("sv", V3, V2);
931
932 EXPECT_TRUE(M.hasFunction("one"));
933 EXPECT_TRUE(M.hasFunction("two"));
934 EXPECT_FALSE(M.hasFunction("four"));
935 M.dumpDAG();
936}
937
938TEST(Graph, functionCloneTest) {
939 Module M;
940 PlaceholderBindings bindings;
941
942 auto *F = M.createFunction("main");
943 Node *K =
944 M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
945 Node *S = M.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", true);
946 Node *conv = F->createConv(bindings, "Conv", K, 16, 3, 2, 3, 1);
947 Node *relu = F->createRELU("Relu", conv);
948 Node *concat = F->createConcat("concat", {relu, relu, relu}, 0);
949 Node *SM = F->createSoftMax("SoftMax", concat, S);
950 F->createSave("Save", SM);
951
952 auto *newF = F->clone("new_main");
953
954 EXPECT_TRUE(newF->verify());
955
956 EXPECT_EQ(newF->getNodes().size(), F->getNodes().size());
957 EXPECT_EQ(newF->getParent(), F->getParent());
958}
959
960/// Compile the module \p M inside the execution engine \p EE and then run it
961/// using the provided \p bindings. Use the provided \p inputName and \p
962/// outputName.
963static void compileAndRun(ExecutionEngine &EE, PlaceholderBindings &bindings,
964 Module &M, llvm::StringRef inputName,
965 llvm::StringRef outputName) {
966 EE.compile(glow::CompilationMode::Infer);
967 // Allocate stprage for placeholders and initialize inputs.
968 bindings.allocate(M.getPlaceholderByNameSlow(inputName))
969 ->getHandle()
970 .clear(2.0);
971 bindings.allocate(M.getPlaceholderByNameSlow(outputName));
972 EE.run(bindings);
973}
974
975/// Check the module cloning functionality.
976TEST(Graph, moduleCloneTest) {
977 // State related to the cloned module and its execution.
978 ExecutionEngine clonedEE("Interpreter");
979 Module &clonedM = clonedEE.getModule();
980 PlaceholderBindings clonedBindings;
981 Tensor clonedResult;
982 // State related to the original module and its execution.
983 PlaceholderBindings originalBindings;
984 Tensor originalResult;
985 // Name of the placeholder holding the results of executions.
986 std::string resultName;
987 {
988 // Define the original execution engine and module.
989 ExecutionEngine originalEE("Interpreter");
990 Module &originalM = originalEE.getModule();
991
992 // Create a function.
993 auto *F = originalM.createFunction("main");
994 auto *input1 = originalM.createPlaceholder(ElemKind::FloatTy,
995 {4, 10, 10, 3}, "input", true);
996
997 auto *add = F->createAdd("add", input1, input1);
998 auto *relu = F->createRELU("Relu", add);
999 auto *concat = F->createConcat("concat", {relu, relu, relu}, 0);
1000 auto *C = originalM.createConstant(concat->getResult().getType(), "C");
1001 C->getPayloadMutable().getHandle().clear(1.0f);
1002 auto *SM = F->createAdd("add", concat, C);
1003 auto *SN = F->createSave("Save", SM);
1004 resultName = SN->getPlaceholder()->getName().str();
1005
1006 // Clone the original module into the cloned module.
1007 originalM.clone(&clonedM);
1008 // The cloned module should have the same numer of types, functions,
1009 // constants and placeholders.
1010 EXPECT_EQ(originalM.getFunctions().size(), clonedM.getFunctions().size());
1011 EXPECT_EQ(originalM.getPlaceholders().size(),
1012 clonedM.getPlaceholders().size());
1013 EXPECT_EQ(originalM.getConstants().size(), clonedM.getConstants().size());
1014 EXPECT_EQ(originalM.getTypes().size(), clonedM.getTypes().size());
1015 // String representations of the original and cloned modules should be the
1016 // same.
1017 EXPECT_EQ(originalM.toString(), clonedM.toString());
1018 for (auto *originalF : originalM.getFunctions()) {
1019 EXPECT_EQ(originalF->toString(),
1020 clonedM.getFunction(originalF->getName())->toString());
1021 }
1022
1023 // Compile and run the original module.
1024 compileAndRun(originalEE, originalBindings, originalM, "input", resultName);
1025 // Store the result of running the original module.
1026 originalResult.assign(originalBindings.get(
1027 originalBindings.getPlaceholderByNameSlow(resultName)));
1028 // The old module should be removed when this scope ends. Thus, if the
1029 // cloned module newM refers to any deleted nodes from the original module,
1030 // it would result in a dangling reference and most likely in a crash.
1031 }
1032 // Check that the cloned module is still alive and valid after the original
1033 // module was deleted.
1034 EXPECT_TRUE(clonedM.verify());
1035 // Compile and run the cloned model.
1036 compileAndRun(clonedEE, clonedBindings, clonedM, "input", resultName);
1037 // Store the result of running the cloned module.
1038 clonedResult.assign(
1039 clonedBindings.get(clonedBindings.getPlaceholderByNameSlow(resultName)));
1040 // The results of execution should be exactly the same in both cases.
1041 EXPECT_TRUE(originalResult.isEqual(clonedResult, 0));
1042}
1043
1044TEST(Graph, cloneWithPredicates) {
1045 Module M;
1046 PlaceholderBindings bindings;
1047
1048 auto *F = M.createFunction("main");
1049 auto *input =
1050 M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", false);
1051 auto *counters =
1052 M.createPlaceholder(ElemKind::FloatTy, {10}, "counters", false);
1053 auto *reluExt = F->createRELU("reluExt", input);
1054 auto *reluInt = F->createRELU("reluInt", input);
1055 auto *externalPredicate =
1056 M.createPlaceholder(ElemKind::Int64ITy, {1}, "predicate", false);
1057 auto *C10 = F->createSplat("C10", counters->getType(), 10.0);
1058 auto *internalPredicate = F->createCmpLTE("lte", C10, counters);
1059
1060 reluExt->setPredicate(externalPredicate);
1061 reluInt->setPredicate(internalPredicate);
1062
1063 auto *newF = F->clone("new_main");
1064
1065 EXPECT_TRUE(newF->verify());
1066 EXPECT_EQ(newF->getNodes().size(), F->getNodes().size());
1067 EXPECT_EQ(newF->getParent(), F->getParent());
1068
1069 // Original predicates are not changed
1070 EXPECT_EQ(reluExt->getPredicate().getNode(), externalPredicate);
1071 EXPECT_EQ(reluInt->getPredicate().getNode(), internalPredicate);
1072 // Clone of predicate that points to a node outside the graph
1073 // points to the same node (predicate is shared)
1074 EXPECT_EQ(nodeByName(newF, "reluExt")->getPredicate().getNode(),
1075 externalPredicate);
1076 // Clone of predicate that points to a node that belongs to the graph
1077 // points to the predicate clone
1078 EXPECT_EQ(nodeByName(newF, "reluInt")->getPredicate().getNode(),
1079 nodeByName(newF, "lte"));
1080}
1081
1082TEST(Graph, NodeValue) {
1083 ExecutionEngine EE;
1084 auto &mod = EE.getModule();
1085 Function *F = mod.createFunction("main");
1086 PlaceholderBindings bindings;
1087 auto *inputX = mod.createPlaceholder(ElemKind::FloatTy, {1}, "input", true);
1088 bindings.allocate(inputX)->init(Tensor::InitKind::Broadcast, 3.0,
1089 mod.getPRNG());
1090
1091 NodeValue a = F->createAdd("x2", inputX, inputX);
1092 a = F->createAdd("x4", a, a);
1093 a = F->createAdd("x8", a, a);
1094 auto *S = F->createSave("Save", a);
1095 auto *res = bindings.allocate(S->getPlaceholder());
1096
1097 EE.compile(CompilationMode::Infer);
1098
1099 EE.run(bindings);
1100
1101 EXPECT_EQ(res->getHandle().raw(0), 24);
1102}
1103
1104/// Check that by deleting one function, the variables that refernced
1105/// by this function, will reduce its number of uses by one.
1106TEST(Graph, deleteFunction) {
1107 ExecutionEngine EE;
1108 auto &mod = EE.getModule();
1109 Function *F1 = mod.createFunction("f1");
1110 auto *inputX = mod.createPlaceholder(ElemKind::FloatTy, {1}, "input", true);
1111 F1->createLog("log1", inputX);
1112 Function *F2 = mod.createFunction("f2");
1113 F2->createLog("log2", inputX);
1114 // We check the number of user of inputX to be 2 as only F1 and F2 are
1115 // using it.
1116 EXPECT_EQ(inputX->getNumUsers(), 2);
1117 // Erase this function here to see if we can see the number of user of inputX
1118 // reduce to 1.
1119 mod.eraseFunction(F1);
1120 EXPECT_EQ(inputX->getNumUsers(), 1);
1121}
1122
1123TEST(Graph, nodesWithPredicates) {
1124 ExecutionEngine EE;
1125
1126 Tensor inputs(ElemKind::FloatTy, {1, 32, 32, 3});
1127
1128 auto &mod = EE.getModule();
1129 Function *F = mod.createFunction("main");
1130 F->setName("interpret");
1131 PlaceholderBindings bindings;
1132 auto *input =
1133 mod.createPlaceholder(ElemKind::FloatTy, {1, 32, 32, 3}, "input", true);
1134 auto *ex = mod.createPlaceholder(ElemKind::Int64ITy, {1, 1}, "exp", true);
1135 auto *pred =
1136 mod.createPlaceholder(ElemKind::Int64ITy, {1}, "predicate", false);
1137 bindings.allocate(input);
1138 bindings.allocate(ex);
1139 bindings.allocate(pred);
1140
1141 auto *CV0 = F->createConv(bindings, "conv1", input, 16, 5, 1, 2, 1);
1142 auto *RL0 = F->createRELU("relu1", CV0);
1143 auto *MP0 = F->createMaxPool("pool1", RL0, 2, 2, 0);
1144
1145 CV0->setPredicate(pred);
1146 RL0->setPredicate(pred);
1147 MP0->setPredicate(pred);
1148
1149 auto *FCL1 = F->createFullyConnected(bindings, "fc", MP0->getResult(), 10);
1150 auto *RL3 = F->createRELU("relu4", FCL1);
1151 auto *SM = F->createSoftMax("sm", RL3, ex);
1152 auto *save = F->createSave("ret", SM);
1153 bindings.allocate(save->getPlaceholder());
1154
1155 EE.compile(CompilationMode::Infer);
1156
1157 updateInputPlaceholders(bindings, {input}, {&inputs});
1158 EE.run(bindings);
1159}
1160
1161// Return the number of ConvolutionNode after lower.
1162unsigned getConvNodeSize(llvm::StringRef kind) {
1163 Module mod;
1164 Function *F = mod.createFunction("main");
1165 IRFunction M(F);
1166 PlaceholderBindings bindings;
1167 auto *input =
1168 mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 1, 32}, "input", true);
1169 ConvolutionNode *CN = F->createConv(bindings, "conv", input, 6, 1, 1, 0, 2);
1170 F->createSave("save", CN);
1171
1172 std::unique_ptr<Backend> backend(createBackend(kind));
1173 CompilationContext cctx;
1174 lower(F, cctx, backend.get());
1175
1176 unsigned count = 0;
1177 for (auto &n : F->getNodes()) {
1178 if (n.getKind() == Kinded::Kind::ConvolutionNodeKind) {
1179 count++;
1180 }
1181 }
1182
1183 if (kind == "Interpreter") {
1184 EXPECT_EQ(count, 1);
1185 }
1186
1187 return count;
1188}
1189
1190// Check the unrolling grouped convolution opt status:
1191// -- disabled for Interpreter, CPU and OpenCL backend,
1192TEST(Graph, disableUnrollingGroupConv) {
1193 unsigned numberOfNodesInterpreter = getConvNodeSize("Interpreter");
1194 (void)numberOfNodesInterpreter;
1195
1196#ifdef GLOW_WITH_CPU
1197 unsigned numberOfNodesCPU = getConvNodeSize("CPU");
1198 EXPECT_EQ(numberOfNodesCPU, numberOfNodesInterpreter);
1199#endif // GLOW_WITH_CPU
1200
1201#ifdef GLOW_WITH_OPENCL
1202 unsigned numberOfNodesOpenCL = getConvNodeSize("OpenCL");
1203 EXPECT_EQ(numberOfNodesOpenCL, numberOfNodesInterpreter);
1204#endif // GLOW_WITH_OPENCL
1205}
1206
1207/// Check that save nodes are properly scheduled.
1208/// That is, they happen after the last use of the related variable.
1209/// In that test, the order of the creation of the nodes give a valid schedule.
1210TEST(Graph, schedulingOfSavesOrderProvided) {
1211 ExecutionEngine EE;
1212
1213 auto &mod = EE.getModule();
1214 Function *F = mod.createFunction("main");
1215 auto *A = mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "A", true);
1216 auto *B = mod.createPlaceholder(A->getType(), "B", true);
1217 auto *zero = mod.createPlaceholder(A->getType(), "zero", true);
1218
1219 PlaceholderBindings bindings;
1220 bindings.allocate(A)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
1221 bindings.allocate(B)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
1222 bindings.allocate(zero)->init(Tensor::InitKind::Broadcast, 0.0,
1223 mod.getPRNG());
1224
1225 auto *addAB = F->createAdd("addAB", A, B);
1226
1227 auto *saveNode = F->createSave("ret", addAB);
1228 auto *savePH = saveNode->getPlaceholder();
1229 bindings.allocate(savePH);
1230 F->createSave("resetA", zero, A);
1231
1232 // Copy the value of A.
1233 Tensor AOrig = bindings.get(A)->clone();
1234
1235 EE.compile(CompilationMode::Infer);
1236
1237 EE.run(bindings);
1238 auto *ret = bindings.get(savePH);
1239 auto handleAOrig = AOrig.getHandle<>();
1240 auto handleB = bindings.get(B)->getHandle<>();
1241 auto handleRet = ret->getHandle<>();
1242 bool allEqual = true;
1243 for (unsigned row = 0; row != 3; ++row) {
1244 for (unsigned column = 0; column != 32; ++column) {
1245 allEqual &= handleAOrig.at({row, column}) + handleB.at({row, column}) ==
1246 handleRet.at({row, column});
1247 }
1248 }
1249 EXPECT_TRUE(bindings.get(A)->isEqual(*bindings.get(zero), 0.0));
1250 EXPECT_TRUE(allEqual);
1251}
1252
1253/// Same as schedulingOfSavesOrderProvided except the order in which the nodes
1254/// are added to the function don't form a valid schedule.
1255/// In other words, the scheduler won't get away with scheduling
1256/// using only the order of the nodes in the list of nodes.
1257TEST(Graph, schedulingOfSaves) {
1258 ExecutionEngine EE;
1259 PlaceholderBindings bindings;
1260
1261 auto &mod = EE.getModule();
1262 Function *F = mod.createFunction("main");
1263 auto *A = mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "A", true);
1264 auto *B = mod.createPlaceholder(A->getType(), "B", true);
1265 auto *zero = mod.createPlaceholder(A->getType(), "zero", true);
1266 F->createSave("resetA", zero, A);
1267
1268 bindings.allocate(A)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
1269 bindings.allocate(B)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
1270 bindings.allocate(zero)->init(Tensor::InitKind::Broadcast, 0.0,
1271 mod.getPRNG());
1272
1273 auto *addAB = F->createAdd("addAB", A, B);
1274
1275 auto *saveNode = F->createSave("ret", addAB);
1276 bindings.allocate(saveNode->getPlaceholder());
1277
1278 // Copy the value of A.
1279 Tensor AOrig = bindings.get(A)->clone();
1280 auto *ret = saveNode->getPlaceholder();
1281 EE.compile(CompilationMode::Infer);
1282
1283 EE.run(bindings);
1284
1285 auto handleAOrig = AOrig.getHandle<>();
1286 auto handleB = bindings.get(B)->getHandle<>();
1287 auto handleRet = bindings.get(ret)->getHandle<>();
1288 bool allEqual = true;
1289 for (unsigned row = 0; row != 3; ++row) {
1290 for (unsigned column = 0; column != 32; ++column) {
1291 allEqual &= handleAOrig.at({row, column}) + handleB.at({row, column}) ==
1292 handleRet.at({row, column});
1293 }
1294 }
1295 EXPECT_TRUE(bindings.get(A)->isEqual(*bindings.get(zero), 0.0));
1296 EXPECT_TRUE(allEqual);
1297}
1298
1299/// Check that the parent link is properly updated while tweaking
1300/// nodes and their function.
1301TEST(Graph, parentLink) {
1302 ExecutionEngine EE;
1303
1304 auto &mod = EE.getModule();
1305 Constant *V =
1306 new Constant("V", mod.uniqueType(ElemKind::FloatTy, {3, 32}), ANY_LAYOUT);
1307
1308 // Variables don't belong to any function...
1309 EXPECT_EQ(V->getParent(), nullptr);
1310 // Even when we create them from a module...
1311 Constant *V2 = mod.createConstant(V->getType(), "V2");
1312 EXPECT_EQ(V2->getParent(), nullptr);
1313
1314 Function *F = mod.createFunction("main");
1315
1316 // Nodes created with function helper belong to the related function.
1317 auto *addNode = F->createAdd("addnode", V, V2);
1318 EXPECT_EQ(addNode->getParent(), F);
1319
1320 // Nodes created directly don't belong to any function.
1321 auto *addNode2 = new AddNode("addnode2", V->getType(), addNode, addNode);
1322 EXPECT_EQ(addNode2->getParent(), nullptr);
1323
1324 // Nodes added to a function belong to that function.
1325 F->addNode(addNode2);
1326 EXPECT_EQ(addNode2->getParent(), F);
1327
1328 // Cloned nodes don't belong to anything.
1329 auto *clonedAddNode = addNode->clone();
1330 EXPECT_EQ(clonedAddNode->getParent(), nullptr);
1331
1332 // Check that the setter properly sets things.
1333 clonedAddNode->setParent(F);
1334 EXPECT_EQ(clonedAddNode->getParent(), F);
1335 clonedAddNode->setParent(nullptr);
1336 EXPECT_EQ(clonedAddNode->getParent(), nullptr);
1337
1338 // Add the cloned node to F so that the memory is properly
1339 // cleaned at the end of the test.
1340 F->addNode(clonedAddNode);
1341 EXPECT_EQ(clonedAddNode->getParent(), F);
1342
1343 delete V;
1344}
1345
1346/// Check that verification can detect that Storage nodes are being used by
1347/// Functions in a Module that doesn't own the Storage nodes.
1348TEST(Graph, moduleLink) {
1349 ExecutionEngine EEA, EEB;
1350
1351 auto &modA = EEA.getModule();
1352 auto &modB = EEB.getModule();
1353
1354 auto *FA = modA.createFunction("FA");
1355 auto *FB = modB.createFunction("FB");
1356
1357 auto *C = modA.createConstant(ElemKind::FloatTy, {1}, "C");
1358 auto *P = modA.createPlaceholder(ElemKind::FloatTy, {1}, "P", false);
1359
1360 auto *AA = FA->createAdd("AA", C, P);
1361 FA->createSave("SA", AA);
1362
1363 // These nodes use Storage nodes that reside in modA
1364 auto *AB = FB->createAdd("AB", C, P);
1365 FB->createSave("SB", AB);
1366
1367 EXPECT_TRUE(modA.verify());
1368 EXPECT_FALSE(
1369 modB.verify()); // Module::verify calls Function::verify on all functions
1370 // within the module, so this should fail
1371}
1372
1373/// Check that Cmp nodes are created with proper output types.
1374TEST(Graph, cmpOutputTypes) {
1375 ExecutionEngine EE;
1376
1377 auto &mod = EE.getModule();
1378 Function *F = mod.createFunction("main");
1379 // Define two different quntized types.
1380 auto qType1 = F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 0.3, 5);
1381 auto qType2 = F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 0.4, 5);
1382 // Define two variables of quantized types.
1383 auto *qv1 = mod.createPlaceholder(qType1, "V1", true);
1384 auto *qv2 = mod.createPlaceholder(qType2, "V2", true);
1385 // Create cmp nodes using quantized inputs.
1386 auto *cmpNode1 = F->createCmpEQ("cmpeq", qv1, qv2);
1387 auto *cmpNode2 = F->createCmpLTE("cmplte", qv1, qv2);
1388 // Check that the output type of cmp nodes is BoolKind.
1389 EXPECT_TRUE(cmpNode1->getResult().getElementType() == ElemKind::BoolTy);
1390 EXPECT_TRUE(cmpNode2->getResult().getElementType() == ElemKind::BoolTy);
1391
1392 // Define a non-quantized type.
1393 auto nqType3 = F->getParent()->uniqueType(ElemKind::FloatTy, {1, 3});
1394 // Define two variables of non-quantized types.
1395 auto *nqv3 = mod.createPlaceholder(nqType3, "V3", true);
1396 auto *nqv4 = mod.createPlaceholder(nqType3, "V4", true);
1397 // Create cmp nodes using non-quantized inputs.
1398 auto *cmpNode3 = F->createCmpEQ("cmpeq", nqv3, nqv4);
1399 auto *cmpNode4 = F->createCmpLTE("cmplte", nqv3, nqv4);
1400 // Check that the output type of cmp nodes is BoolKind.
1401 EXPECT_TRUE(cmpNode3->getResult().getElementType() == ElemKind::BoolTy);
1402 EXPECT_TRUE(cmpNode4->getResult().getElementType() == ElemKind::BoolTy);
1403}
1404
1405/// Check that the users of value are equal to expectedUsers.
1406static bool
1407hasAllTheseUses(const llvm::SmallPtrSetImpl<const Node *> &expectedUsers,
1408 const NodeValue &value) {
1409 llvm::SmallPtrSet<const Node *, 4> uses;
1410 for (const NodeUse &use : value.getUsers()) {
1411 const Node *user = use.getUser();
1412 if (!expectedUsers.count(user)) {
1413 // We found a user that wasn't on the list.
1414 return false;
1415 }
1416 uses.insert(user);
1417 }
1418 return expectedUsers.size() == uses.size();
1419}
1420
1421/// Check that our uses lists are correct for nodes with multiple results.
1422TEST(Graph, usesListsWithSeveralResult) {
1423 ExecutionEngine EE;
1424 PlaceholderBindings bindings;
1425
1426 auto &mod = EE.getModule();
1427 Function *F = mod.createFunction("main");
1428 auto *input =
1429 mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "input", true);
1430 auto *topK = F->createTopK("topK", input, 12);
1431 EXPECT_EQ(topK->getNumUsers(), 0);
1432
1433 NodeValue values = topK->getValues();
1434 NodeValue indices = topK->getIndices();
1435 llvm::SmallPtrSet<const Node *, 4> savesOfValues;
1436 llvm::SmallPtrSet<const Node *, 4> savesOfIndices;
1437
1438 EXPECT_EQ(indices.getNumUsers(), 0);
1439 EXPECT_EQ(values.getNumUsers(), 0);
1440
1441 EXPECT_FALSE(indices.hasOneUse());
1442 EXPECT_FALSE(values.hasOneUse());
1443
1444 EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices));
1445 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1446
1447 // Now add a user to only one result of the topK node.
1448 savesOfValues.insert(F->createSave("saveValues1", values));
1449
1450 // The whole node should inherit the uses of each of its results.
1451 EXPECT_EQ(topK->getNumUsers(), 1);
1452
1453 // Each result should have its own use list.
1454 EXPECT_EQ(indices.getNumUsers(), 0);
1455 EXPECT_EQ(values.getNumUsers(), 1);
1456
1457 EXPECT_FALSE(indices.hasOneUse());
1458 EXPECT_TRUE(values.hasOneUse());
1459
1460 EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices));
1461 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1462
1463 // Add a user to the other result of the topK node.
1464 savesOfIndices.insert(F->createSave("saveIndices1", indices));
1465
1466 // The whole node should inherit the uses of each of its results.
1467 EXPECT_EQ(topK->getNumUsers(), 2);
1468
1469 // Each result should have its own use list.
1470 EXPECT_EQ(indices.getNumUsers(), 1);
1471 EXPECT_EQ(values.getNumUsers(), 1);
1472
1473 EXPECT_TRUE(indices.hasOneUse());
1474 EXPECT_TRUE(values.hasOneUse());
1475
1476 EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices));
1477 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1478
1479 // Add a couple more users of values and indices.
1480 // Interleaves the insertions in the uses list for both values and indices.
1481 savesOfValues.insert(F->createSave("saveValues2", values));
1482 savesOfValues.insert(F->createSave("saveValues3", values));
1483 savesOfIndices.insert(F->createSave("saveIndices2", indices));
1484
1485 EXPECT_EQ(topK->getNumUsers(), 5);
1486
1487 EXPECT_EQ(indices.getNumUsers(), 2);
1488 EXPECT_EQ(values.getNumUsers(), 3);
1489
1490 EXPECT_FALSE(indices.hasOneUse());
1491 EXPECT_FALSE(values.hasOneUse());
1492
1493 EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices));
1494 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1495}
1496
1497/// Check that our uses lists are correct when accessed through
1498/// NodeValue.
1499TEST(Graph, usesListsThroughNodeValues) {
1500 ExecutionEngine EE;
1501 PlaceholderBindings bindings;
1502
1503 auto &mod = EE.getModule();
1504 Function *F = mod.createFunction("main");
1505 auto *input =
1506 mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "input", true);
1507 auto *reLU = F->createRELU("reLU", input);
1508 EXPECT_EQ(reLU->getNumUsers(), 0);
1509
1510 NodeValue values = reLU->getResult();
1511 llvm::SmallPtrSet<const Node *, 4> savesOfValues;
1512
1513 EXPECT_EQ(values.getNumUsers(), 0);
1514
1515 EXPECT_FALSE(values.hasOneUse());
1516
1517 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1518
1519 // Now add a user to only one result of the reLU node.
1520 savesOfValues.insert(F->createSave("saveValues1", values));
1521
1522 // The whole node should inherit the uses of each of its results.
1523 EXPECT_EQ(reLU->getNumUsers(), 1);
1524
1525 // The NodeValue should match.
1526 EXPECT_EQ(values.getNumUsers(), 1);
1527 EXPECT_TRUE(values.hasOneUse());
1528 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1529
1530 // Add one more use.
1531 savesOfValues.insert(F->createSave("saveValues2", values));
1532
1533 // The whole node should inherit the uses of each of its results.
1534 EXPECT_EQ(reLU->getNumUsers(), 2);
1535
1536 EXPECT_EQ(values.getNumUsers(), 2);
1537 EXPECT_FALSE(values.hasOneUse());
1538 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1539
1540 // Add a couple more users.
1541 savesOfValues.insert(F->createSave("saveValues3", values));
1542 savesOfValues.insert(F->createSave("saveValues4", values));
1543
1544 EXPECT_EQ(reLU->getNumUsers(), 4);
1545
1546 EXPECT_EQ(values.getNumUsers(), 4);
1547 EXPECT_FALSE(values.hasOneUse());
1548 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1549}
1550
1551/// Verify that the pre-order visitor works correctly.
1552TEST(Graph, PreOrderTest) {
1553 Module M;
1554 PlaceholderBindings bindings;
1555 auto *F = M.createFunction("main");
1556
1557 auto *input1 =
1558 M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input1", true);
1559 auto *input2 =
1560 M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input2", true);
1561 SplatNode *zero = F->createSplat("zero", input1->getType(), 0.);
1562 MulNode *mul1 = F->createMul("mul1", zero, input1);
1563 MulNode *mul2 = F->createMul("mul2", zero, input2);
1564 MulNode *mul3 = F->createMul("mul3", mul1, mul2);
1565 SaveNode *ret1 = F->createSave("ret1", mul3);
1566
1567 SplatNode *one = F->createSplat("one", input2->getType(), 1.0);
1568 AddNode *add1 = F->createAdd("add1", input2, one);
1569 AddNode *add2 = F->createAdd("add2", add1, one);
1570 AddNode *add3 = F->createAdd("add3", add2, one);
1571 SaveNode *ret2 = F->createSave("ret2", add2);
1572
1573 GraphPreOrderVisitor visitor(*F);
1574 auto order = visitor.getPreOrder();
1575
1576 ASSERT_EQ(order.size(), 14);
1577 EXPECT_EQ(order[0], ret1);
1578 EXPECT_EQ(order[1], mul3);
1579 EXPECT_EQ(order[2], mul1);
1580 EXPECT_EQ(order[3], zero);
1581 EXPECT_EQ(order[4], input1);
1582 EXPECT_EQ(order[5], mul2);
1583 EXPECT_EQ(order[6], input2);
1584 EXPECT_EQ(order[7], ret1->getOutput());
1585 EXPECT_EQ(order[8], add3);
1586 EXPECT_EQ(order[9], add2);
1587 EXPECT_EQ(order[10], add1);
1588 EXPECT_EQ(order[11], one);
1589 EXPECT_EQ(order[12], ret2);
1590 EXPECT_EQ(order[13], ret2->getOutput());
1591}
1592
1593/// Verify that the post-order visitor works correctly.
1594TEST(Graph, PostOrderTest) {
1595 Module M;
1596 PlaceholderBindings bindings;
1597 auto *F = M.createFunction("main");
1598
1599 auto *input1 =
1600 M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input1", true);
1601 auto *input2 =
1602 M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input2", true);
1603 SplatNode *zero = F->createSplat("zero", input1->getType(), 0.);
1604 MulNode *mul1 = F->createMul("mul1", zero, input1);
1605 MulNode *mul2 = F->createMul("mul2", zero, input2);
1606 MulNode *mul3 = F->createMul("mul3", mul1, mul2);
1607 SaveNode *ret1 = F->createSave("ret1", mul3);
1608
1609 SplatNode *one = F->createSplat("one", input2->getType(), 1.0);
1610 AddNode *add1 = F->createAdd("add1", input2, one);
1611 AddNode *add2 = F->createAdd("add2", add1, one);
1612 AddNode *add3 = F->createAdd("add3", add2, one);
1613 SaveNode *ret2 = F->createSave("ret2", add2);
1614
1615 GraphPostOrderVisitor visitor(*F);
1616 auto order = visitor.getPostOrder();
1617
1618 ASSERT_EQ(order.size(), 14);
1619 EXPECT_EQ(order[0], zero);
1620 EXPECT_EQ(order[1], input1);
1621 EXPECT_EQ(order[2], mul1);
1622 EXPECT_EQ(order[3], input2);
1623 EXPECT_EQ(order[4], mul2);
1624 EXPECT_EQ(order[5], mul3);
1625 EXPECT_EQ(order[6], ret1->getOutput());
1626 EXPECT_EQ(order[7], ret1);
1627 EXPECT_EQ(order[8], one);
1628 EXPECT_EQ(order[9], add1);
1629 EXPECT_EQ(order[10], add2);
1630 EXPECT_EQ(order[11], add3);
1631 EXPECT_EQ(order[12], ret2->getOutput());
1632 EXPECT_EQ(order[13], ret2);
1633}
1634
1635TEST(Graph, placeholder) {
1636 Module MD;
1637 PlaceholderBindings bindings;
1638 Function *F = MD.createFunction("F");
1639 IRFunction M(F);
1640 Node *K =
1641 MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", false);
1642 Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", false);
1643
1644 K = F->createFullyConnected(bindings, "FC", K, 10);
1645 K = F->createRELU("Relu", K);
1646 K = F->createSoftMax("SoftMax", K, S);
1647 F->createSave("Save", K);
1648}
1649
1650/// Check that the setType API allows to change the type of the
1651/// related result and only the related result.
1652TEST(Graph, setType) {
1653 Module M;
1654 auto *F = M.createFunction("main");
1655
1656 const dim_t inputDims[] = {4, 10};
1657 const dim_t top5Dims[] = {4, 5};
1658 auto *input =
1659 M.createPlaceholder(ElemKind::FloatTy, inputDims, "input", true);
1660 TopKNode *topK = F->createTopK("add", input, 5);
1661 TypeRef origTopKRes0 = M.uniqueType(ElemKind::FloatTy, top5Dims);
1662 TypeRef origTopKRes1 = M.uniqueType(ElemKind::Int64ITy, top5Dims);
1663
1664 EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), origTopKRes0);
1665 EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), origTopKRes1);
1666
1667 // Modify the type of result 0 and make sure type 1 is not
1668 // affected. Similarly the input shouldn't be affected.
1669 TypeRef inputTy = M.uniqueType(ElemKind::FloatTy, inputDims);
1670 TypeRef topKRes0 = M.uniqueType(ElemKind::Float16Ty, top5Dims);
1671 topK->setType(TopKNode::ValuesIdx, topKRes0);
1672 EXPECT_EQ(input->getType(), inputTy);
1673 EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), topKRes0);
1674 EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), origTopKRes1);
1675
1676 // Make sure the NodeValue API works the same way
1677 // as the Node::setType API.
1678 NodeValue valRes1 = topK->getNthResult(TopKNode::IndicesIdx);
1679 valRes1.setType(topKRes0);
1680 EXPECT_EQ(input->getType(), inputTy);
1681 EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), topKRes0);
1682 EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), topKRes0);
1683 EXPECT_EQ(valRes1.getType(), topKRes0);
1684
1685 // Now restore sane types.
1686 NodeValue valRes0 = topK->getNthResult(TopKNode::ValuesIdx);
1687 valRes0.setType(origTopKRes0);
1688 topK->setType(TopKNode::IndicesIdx, origTopKRes1);
1689 EXPECT_EQ(input->getType(), inputTy);
1690 EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), origTopKRes0);
1691 EXPECT_EQ(valRes0.getType(), origTopKRes0);
1692 EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), origTopKRes1);
1693 EXPECT_EQ(valRes1.getType(), origTopKRes1);
1694}
1695
1696/// Check that we fixed the bug with Function::eraseNode. This method used to
1697/// erase a node that was equal to the node we wanted to delete, which may be
1698/// two different entities.
1699/// To see this bug in action, we create a bunch of nodes with the same value.
1700/// Then we erase them in reserve order. This reserve ordering was actually
1701/// freeing the node in the original order, thus at some point we try to delete
1702/// a node that has already deleted and an assert (debug mode) or segmentation
1703/// fault (release would occur).
1704/// Note: Which node is actually freed depend on the implementation of
1705/// std::find, thus we cannot really predict when the bug occurs.
1706TEST(Graph, eraseNodeBug) {
1707 Module M;
1708 auto *F = M.createFunction("main");
1709
1710 auto *input = M.createPlaceholder(ElemKind::FloatTy, {3, 2}, "input", true);
1711 std::vector<Node *> ReLUs;
1712 // Create a bunch of ReLUs.
1713 for (unsigned idx = 0; idx != 5; ++idx) {
1714 ReLUs.push_back(F->createRELU("relu", input));
1715 }
1716 // Check that we can erase all the nodes.
1717 for (int idx = 4; idx != -1; --idx) {
1718 F->eraseNode(ReLUs[idx]);
1719 }
1720 EXPECT_EQ(F->getNodes().size(), 0);
1721}
1722
1723/// Verify that two Nodes with different predicates but the same inputs are not
1724/// considered equal.
1725TEST(Graph, nodeEqualityWithDifferentPredicates) {
1726 Module M;
1727 auto *F = M.createFunction("main");
1728
1729 Node *in = M.createPlaceholder(ElemKind::FloatTy, {5}, "in", false);
1730 Node *pred1 = M.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1731 Node *pred2 = M.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1732
1733 Node *RN1 = F->createRELU("relu1", in);
1734 RN1->setPredicate(pred1);
1735
1736 Node *RN2 = F->createRELU("relu2", in);
1737 RN2->setPredicate(pred2);
1738
1739 EXPECT_FALSE(RN1->isEqual(*RN2));
1740}
1741
1742/// Check that verify doesn't allow for multiple writers to the same node.
1743TEST(Graph, verifyOneWriter) {
1744 Module M;
1745 auto *F = M.createFunction("main");
1746
1747 auto *input = M.createPlaceholder(ElemKind::FloatTy, {5}, "input", false);
1748 auto *output = M.createPlaceholder(ElemKind::FloatTy, {5}, "output", false);
1749 F->createSave("Save1", input, output);
1750 F->createSave("Save2", input, output);
1751
1752 EXPECT_FALSE(M.verify());
1753}
1754
1755/// Check that verify doesn't allow for Constants to be written to. Note that
1756/// createSave() cannot do this as the API only accepts Placeholders to write
1757/// to, however it could happen during graph transformations, e.g. via
1758/// replaceAllUsesOfWith() as shown here.
1759TEST(Graph, verifyConstantNoWriters) {
1760 Module M;
1761 auto *F = M.createFunction("main");
1762
1763 auto *input = M.createPlaceholder(ElemKind::FloatTy, {5}, "input", false);
1764 auto *outputPH = M.createPlaceholder(ElemKind::FloatTy, {5}, "outPH", false);
1765 F->createSave("save", input, outputPH);
1766
1767 // Replace the output Placeholder with a Constant. This should fail
1768 // verification.
1769 auto *outputC = M.createConstant(ElemKind::FloatTy, {5}, "outC");
1770 NodeValue(outputPH).replaceAllUsesOfWith(outputC);
1771
1772 EXPECT_FALSE(M.verify());
1773}
1774
1775TEST(Graph, typeUnsafeReplaceAllUsesOfWith) {
1776 Module M;
1777 auto *F = M.createFunction("main");
1778
1779 auto *LHS = M.createPlaceholder(ElemKind::FloatTy, {3, 4}, "A", false);
1780 auto *RHS = M.createPlaceholder(ElemKind::FloatTy, {4, 5}, "B", false);
1781 auto *FC = F->createMatMul("fc", LHS, RHS);
1782 F->createSave("save", FC);
1783
1784 auto newLHS = M.createPlaceholder(ElemKind::FloatTy, {10, 10}, "A", false);
1785 LHS->getOutput().typeUnsafeReplaceAllUsesOfWith(newLHS);
1786}
1787
1788/// Check that the verifier will complain if a constant and its
1789/// underlying tensor have mismatching types.
1790/// Here the constant is updated but not the tensor.
1791TEST(Graph, verifyConstantTensorTypeMatchesConstantTypeChanged) {
1792 Module M;
1793
1794 auto *input = M.createConstant(ElemKind::FloatTy, {5}, "input");
1795 // Fresh constant should verify just fine.
1796 EXPECT_TRUE(input->verify());
1797
1798 input->setType(Storage::OutputIdx, M.uniqueType(ElemKind::Float16Ty, {5}));
1799
1800 EXPECT_FALSE(input->verify());
1801}
1802
1803/// Check that the verifier will complain if a constant and its
1804/// underlying tensor have mismatching types.
1805/// Here the tensor is updated but not the constant.
1806TEST(Graph, verifyConstantTensorTypeMatchesTensorTypeChanged) {
1807 Module M;
1808
1809 auto *input = M.createConstant(ElemKind::FloatTy, {5}, "input");
1810 // Fresh constant should verify just fine.
1811 EXPECT_TRUE(input->verify());
1812 input->getPayloadMutable().convertToType(ElemKind::Float16Ty);
1813
1814 EXPECT_FALSE(input->verify());
1815}
1816
1817/// Check that Constants backed by unowned Tensors are in fact unowned until
1818/// a mutable reference to their payload is obtained at which point the backing
1819/// Tensor is copied and becomes owned.
1820TEST(Graph, verifyConstantWithUnownedTensorCopiesOnWrite) {
1821 Module M;
1822
1823 Tensor originalT(ElemKind::FloatTy, {3});
1824 Tensor unownedT = originalT.getUnowned({3});
1825
1826 auto originalH = originalT.getHandle();
1827
1828 for (size_t i = 0; i < originalT.size(); i++) {
1829 originalH.raw(i) = i;
1830 }
1831
1832 // Both Tensors should have the same underlying memory because unownedT shares
1833 // originalT's memory.
1834 EXPECT_EQ(originalT.getUnsafePtr(), unownedT.getUnsafePtr());
1835
1836 Constant *originalC = M.createConstant("original", std::move(originalT));
1837 Constant *unownedC = M.createConstant("unowned", std::move(unownedT));
1838
1839 const Tensor &originalCT = originalC->getPayload();
1840 const Tensor &unownedCT = unownedC->getPayload();
1841
1842 const auto originalCTH = originalCT.getHandle();
1843 const auto unownedCTH = unownedCT.getHandle();
1844
1845 ASSERT_EQ(originalCTH.size(), unownedCTH.size());
1846
1847 // Both Constants should have the same values because their Tensors have the
1848 // same underlying memory.
1849 for (size_t i = 0; i < originalCTH.size(); i++) {
1850 EXPECT_EQ(i, originalCTH.raw(i));
1851 EXPECT_EQ(i, unownedCTH.raw(i));
1852 }
1853
1854 Tensor &originalCTM = originalC->getPayloadMutable();
1855 auto originalCTMH = originalCTM.getHandle();
1856
1857 // Bump up the value in the original Constant, this change should be
1858 // reflected in the unowned Constant as well.
1859 for (size_t i = 0; i < originalCTMH.size(); i++) {
1860 originalCTMH.raw(i) += 1;
1861 }
1862
1863 // After changing the values in the original Constant, we should see an update
1864 // in the values of the unowned Constant because they share the same
1865 // underlying memory.
1866 for (size_t i = 0; i < unownedCTH.size(); i++) {
1867 EXPECT_EQ(unownedCTH.raw(i), i + 1);
1868 }
1869
1870 Tensor &unownedCTM = unownedC->getPayloadMutable();
1871 auto unownedCTMH = unownedCTM.getHandle();
1872
1873 ASSERT_EQ(originalCTH.size(), unownedCTMH.size());
1874
1875 // After getting a mutable reference to the unowned Constant's payload, the
1876 // underlying memory should have been copied but should still contain the same
1877 // values as it did previously at this point.
1878 EXPECT_NE(unownedCTM.getUnsafePtr(), originalCT.getUnsafePtr());
1879 for (size_t i = 0; i < unownedCTMH.size(); i++) {
1880 EXPECT_EQ(unownedCTMH.raw(i), i + 1);
1881 }
1882
1883 // Bump up the value in the original Constant again, this change should not be
1884 // reflected in the unowned Constant now because at this point, after a
1885 // mutable reference to its payload has been obtained, it should have it's own
1886 // memory.
1887 for (size_t i = 0; i < originalCTMH.size(); i++) {
1888 originalCTMH.raw(i) += 1;
1889 }
1890
1891 // Now that the unowned Constant's payload has been obtained as mutable, it
1892 // should have been copied and thus have its own memory and changes to the
1893 // original constant should not be reflected in the unowned Constant.
1894 for (size_t i = 0; i < unownedCTMH.size(); i++) {
1895 EXPECT_EQ(unownedCTMH.raw(i), i + 1);
1896 }
1897}
1898
1899/// Check that hooking an intermediate node works.
1900TEST(Graph, hookTest) {
1901 Module mod;
1902 auto *F = mod.createFunction("main");
1903 auto *in = mod.createPlaceholder(ElemKind::FloatTy, {1}, "in", false);
1904 auto *relu1 = F->createRELU("relu1", in);
1905 auto *relu2 = F->createRELU("relu2", relu1);
1906 F->createSave("save", relu2);
1907 EXPECT_EQ(F->getNodes().size(), 3);
1908 EXPECT_EQ(mod.getPlaceholders().size(), 2);
1909
1910 // Hook the first relu and verify that the hooked graph looks right.
1911 auto hooked = glow::hookNode(F, relu1);
1912 auto const &nodes = hooked.function->getNodes();
1913 ASSERT_EQ(mod.getPlaceholders().size(), 3);
1914 ASSERT_EQ(nodes.size(), 2);
1915 auto const *hookSave = *hooked.outputSaves.begin();
1916 ASSERT_TRUE(hookSave);
1917 auto *inp = llvm::dyn_cast<ReluNode>(hookSave->getInput());
1918 ASSERT_TRUE(inp);
1919 auto *ph = llvm::dyn_cast<Placeholder>(inp->getInput());
1920 ASSERT_TRUE(ph);
1921 ASSERT_EQ(ph, in);
1922}
1923
1924/// Check that getConstantsSize returns the correct size of constants.
1925TEST(Graph, moduleSize) {
1926 Module mod;
1927
1928 EXPECT_EQ(mod.getConstantsSize(), 0);
1929
1930 auto *cons1 = mod.createConstant(ElemKind::FloatTy, {1}, "var");
1931 EXPECT_EQ(mod.getConstantsSize(), sizeof(float) * cons1->getPayload().size());
1932
1933 auto *cons2 = mod.createConstant(ElemKind::FloatTy, {1, 32, 32, 16}, "var2");
1934 EXPECT_EQ(mod.getConstantsSize(),
1935 sizeof(float) + sizeof(float) * cons2->getPayload().size());
1936}
1937
1938/// Check that getDataSize() returns the correct size of backing tensors.
1939TEST(Graph, contextSize) {
1940 Module mod;
1941 PlaceholderBindings bindings;
1942
1943 Placeholder *PH =
1944 mod.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
1945
1946 EXPECT_EQ(bindings.getDataSize(), 0);
1947 bindings.allocate(PH);
1948 EXPECT_EQ(bindings.get(PH)->size(), 4 * 320 * 200 * 3);
1949 EXPECT_EQ(bindings.getDataSize(), sizeof(float) * bindings.get(PH)->size());
1950}
1951
1952/// Check that clones of the context are distinct and share no references back
1953/// to the original object.
1954TEST(Graph, clonePlaceholderBindings) {
1955 Module mod;
1956
1957 Placeholder *PH1 =
1958 mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "PH1", false);
1959
1960 PlaceholderBindings bindings1;
1961 bindings1.allocate(PH1);
1962
1963 PlaceholderBindings bindings2 = bindings1.clone();
1964
1965 Tensor *t1 = bindings1.get(PH1);
1966 Tensor *t2 = bindings2.get(PH1);
1967
1968 EXPECT_NE(t1, nullptr);
1969 EXPECT_NE(t2, nullptr);
1970 EXPECT_NE(t1, t2);
1971
1972 // The new PlaceholderBindings has no references back, and changing it does
1973 // not affect bindings1
1974 Placeholder *PH2 =
1975 mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "PH2", false);
1976
1977 bindings2.allocate(PH2);
1978 // now exists in bindings1 but not bindings2
1979 EXPECT_EQ(bindings1.get(PH2), nullptr);
1980 EXPECT_NE(bindings2.get(PH2), nullptr);
1981
1982 // Likewise changing bindings1 does not affect bindings2
1983 bindings1.clear();
1984 EXPECT_EQ(bindings1.count(PH1), 0);
1985 EXPECT_EQ(bindings2.count(PH1), 1);
1986
1987 // Adds are distinct
1988 Placeholder *PH3 =
1989 mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "PH3", false);
1990 bindings1.allocate(PH3);
1991 bindings2.allocate(PH3);
1992 EXPECT_NE(bindings1.get(PH3), nullptr);
1993 EXPECT_NE(bindings2.get(PH3), nullptr);
1994 EXPECT_NE(bindings1.get(PH3), bindings2.get(PH3));
1995}
1996
1997/// Check that running a function multiple times on cloned PlaceholderBindingss
1998/// have distinct outputs.
1999TEST(Graph, clonePlaceholderBindingsRuns) {
2000 ExecutionEngine EE;
2001 PseudoRNG PRNG;
2002
2003 Tensor inputs(ElemKind::FloatTy, {1, 32, 32, 3});
2004 auto &mod = EE.getModule();
2005 Function *F = mod.createFunction("main");
2006 PlaceholderBindings bindings;
2007 auto *input =
2008 mod.createPlaceholder(ElemKind::FloatTy, {1, 32, 32, 3}, "input", true);
2009
2010 bindings.allocate(input);
2011
2012 auto *FCL1 = F->createFullyConnected(bindings, "fc", input, 10);
2013 auto *RL3 = F->createRELU("relu4", FCL1);
2014 auto *save = F->createSave("ret", RL3);
2015 auto *savePH = save->getPlaceholder();
2016
2017 bindings.allocate(save->getPlaceholder());
2018
2019 // Compile once.
2020 EE.compile(CompilationMode::Infer);
2021
2022 // Run with random inputs.
2023 inputs.getHandle<>().randomize(-3.0, 3.0, PRNG);
2024 updateInputPlaceholders(bindings, {input}, {&inputs});
2025 EE.run(bindings);
2026
2027 // Clone the context.
2028 PlaceholderBindings bindings2 = bindings.clone();
2029
2030 // PlaceholderBindingss are identical.
2031 Tensor *saveBacking1, *saveBacking2;
2032 saveBacking1 = bindings.get(savePH);
2033 saveBacking2 = bindings2.get(savePH);
2034 EXPECT_NE(saveBacking1, saveBacking2);
2035 EXPECT_EQ(saveBacking1->size(), saveBacking2->size());
2036 EXPECT_TRUE(saveBacking1->isEqual(*saveBacking2));
2037
2038 // Run again with different random inputs using the cloned context.
2039 Tensor inputs2(ElemKind::FloatTy, {1, 32, 32, 3});
2040 inputs2.getHandle<>().randomize(-3.0, 3.0, PRNG);
2041 updateInputPlaceholders(bindings2, {input}, {&inputs2});
2042 EE.run(bindings2);
2043
2044 // PlaceholderBindingss are no longer identical.
2045 EXPECT_EQ(saveBacking1->size(), saveBacking2->size());
2046 EXPECT_FALSE(saveBacking1->isEqual(*saveBacking2));
2047}
2048
2049/// Check that using the indices enums in nodes works correctly, with
2050/// multi-input, multi-output, and single-input/output nodes.
2051TEST(Graph, TestNodeEnums) {
2052 Module MD;
2053 Function *F = MD.createFunction("F");
2054 PlaceholderBindings bindings;
2055 Placeholder *I =
2056 MD.createPlaceholder(ElemKind::FloatTy, {10, 10}, "input", true);
2057 Placeholder *O = MD.createPlaceholder(ElemKind::FloatTy, {3}, "output", true);
2058
2059 TopKNode *TKN = F->createTopK("topk", I, 3);
2060 GatherNode *GN =
2061 F->createGather("gather", TKN->getValues(), TKN->getIndices());
2062 TanhNode *TN = F->createTanh("tanh", GN);
2063 SaveNode *SN = F->createSave("save", TN, O);
2064
2065 // Check structure of Placeholders.
2066 EXPECT_EQ(I->getNthResult(Storage::OutputIdx), I->getOutput());
2067 EXPECT_EQ(O->getNthResult(Storage::OutputIdx), O->getOutput());
2068
2069 // Check structure of TopK.
2070 EXPECT_EQ(TKN->getInput(), TKN->getNthInput(TopKNode::InputIdx));
2071 EXPECT_EQ(TKN->getNthResult(TopKNode::ValuesIdx), TKN->getValues());
2072 EXPECT_EQ(TKN->getNthResult(TopKNode::IndicesIdx), TKN->getIndices());
2073
2074 // Check structure of Gather.
2075 EXPECT_EQ(GN->getNthInput(GatherNode::DataIdx), GN->getData());
2076 EXPECT_EQ(GN->getNthInput(GatherNode::IndicesIdx), GN->getIndices());
2077 EXPECT_EQ(GN->getNthResult(GatherNode::ResultIdx), GN->getResult());
2078
2079 // Check structure of Tanh.
2080 EXPECT_EQ(TN->getNthInput(TanhNode::InputIdx), TN->getInput());
2081 EXPECT_EQ(TN->getNthResult(TanhNode::ResultIdx), TN->getResult());
2082
2083 // Check structure of Save.
2084 EXPECT_EQ(SN->getNthInput(SaveNode::InputIdx), SN->getInput());
2085 EXPECT_EQ(SN->getNthInput(SaveNode::OutputIdx), SN->getOutput());
2086
2087 // Check connection between Placeholder and TopK.
2088 EXPECT_EQ(TKN->getNthInput(TopKNode::InputIdx), I->getOutput());
2089
2090 // Check connections between TopK and Gather.
2091 EXPECT_EQ(TKN->getNthResult(TopKNode::ValuesIdx),
2092 GN->getNthInput(GatherNode::DataIdx));
2093 EXPECT_EQ(TKN->getNthResult(TopKNode::IndicesIdx),
2094 GN->getNthInput(GatherNode::IndicesIdx));
2095
2096 // Check connection between Gather and Tanh.
2097 EXPECT_EQ(GN->getNthResult(GatherNode::ResultIdx),
2098 TN->getNthInput(TanhNode::InputIdx));
2099
2100 // Check connection between Gather and Tanh.
2101 EXPECT_EQ(TN->getNthResult(TanhNode::ResultIdx),
2102 SN->getNthInput(SaveNode::InputIdx));
2103
2104 // Check connection between Gather and Tanh.
2105 EXPECT_EQ(SN->getNthInput(SaveNode::OutputIdx), O->getOutput());
2106}
2107
2108/// Searched \p F for a single instance of a node of Kind T. If more than one is
2109/// found, \returns nullptr, otherwise returns the single instance.
2110template <class T> static T *findSingleInstanceOfNode(Function *F) {
2111 T *found = nullptr;
2112 for (auto &n : F->getNodes()) {
2113 if (auto *currNode = llvm::dyn_cast<T>(&n)) {
2114 if (found != nullptr) {
2115 return nullptr;
2116 }
2117 found = currNode;
2118 }
2119 }
2120 return found;
2121}
2122
2123/// Check that group Conv is not lowered when specified to lower by backend if
2124/// doNotLowerKinds contains Conv.
2125TEST(Graph, GroupTestConvNoLower) {
2126 Module MD;
2127 Function *F = MD.createFunction("F");
2128 IRFunction M(F);
2129 PlaceholderBindings bindings;
2130 Node *K =
2131 MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 8}, "input", true);
2132 Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", true);
2133
2134 K = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, /* group */ 8);
2135 K = F->createRELU("Relu", K);
2136 K = F->createSoftMax("SoftMax", K, S);
2137 F->createSave("Save", K);
2138 F->dump();
2139 auto filePath = F->dumpDAG();
2140 auto backend = MockBackend();
2141
2142 {
2143 // Before we lower, we should have a single Conv node with group = 8.
2144 ConvolutionNode *CN = findSingleInstanceOfNode<ConvolutionNode>(F);
2145 if (!CN) {
2146 llvm::sys::fs::remove(filePath);
2147 }
2148 ASSERT_TRUE(CN);
2149 EXPECT_EQ(CN->getGroup(), 8);
2150 }
2151
2152 // Now lower, but prevent ConvolutionNodeKinds from being lowered.
2153 KindSet doNotLower;
2154 doNotLower.insert(Kinded::Kind::ConvolutionNodeKind);
2155 CompilationContext cctx;
2156 lower(F, cctx, &backend, doNotLower);
2157
2158 {
2159 // Now have lowered but should still have a single Conv node with group = 8.
2160 ConvolutionNode *CN = findSingleInstanceOfNode<ConvolutionNode>(F);
2161 if (!CN) {
2162 llvm::sys::fs::remove(filePath);
2163 }
2164 ASSERT_TRUE(CN);
2165 EXPECT_EQ(CN->getGroup(), 8);
2166 }
2167}
2168
2169/// Check that getOutputSave returns SaveNode object for the correct Placeholder
2170/// and nullptr in other cases.
2171TEST(Graph, GetOutputSaveTest) {
2172 Module MD;
2173 Function *F = MD.createFunction("F");
2174 PlaceholderBindings bindings;
2175 Placeholder *I =
2176 MD.createPlaceholder(ElemKind::FloatTy, {10, 10}, "input", true);
2177 Placeholder *O = MD.createPlaceholder(ElemKind::FloatTy, {3}, "output", true);
2178 TopKNode *TKN = F->createTopK("topk", I, 3);
2179 GatherNode *GN =
2180 F->createGather("gather", TKN->getValues(), TKN->getIndices());
2181 TanhNode *TN = F->createTanh("tanh", GN);
2182 SaveNode *SN = F->createSave("save", TN, O);
2183
2184 // Check the return value of getOutputSave method.
2185 // Placeholder parent is null.
2186 auto *FoundNode = glow::getOutputSave(F, O);
2187 EXPECT_NE(nullptr, FoundNode);
2188 EXPECT_EQ(SN, FoundNode);
2189
2190 // Placeholder parent is set to the correct value.
2191 O->setParent(F);
2192 EXPECT_EQ(F, O->getParent());
2193 FoundNode = glow::getOutputSave(F, O);
2194 EXPECT_NE(nullptr, FoundNode);
2195 EXPECT_EQ(SN, FoundNode);
2196
2197 // Invalid placeholder type is provided.
2198 EXPECT_EQ(nullptr, glow::getOutputSave(F, I));
2199
2200 // Save belongs to a different function
2201 Function *F2 = MD.createFunction("F2");
2202 TopKNode *TKN2 = F2->createTopK("topk", I, 3);
2203 GatherNode *GN2 =
2204 F2->createGather("gather", TKN2->getValues(), TKN2->getIndices());
2205 TanhNode *TN2 = F2->createTanh("tanh", GN2);
2206 SaveNode *SN2 = F2->createSave("save", TN2, O);
2207
2208 FoundNode = glow::getOutputSave(F, O);
2209 EXPECT_NE(nullptr, FoundNode);
2210 EXPECT_EQ(SN, FoundNode);
2211
2212 O->setParent(F2);
2213 FoundNode = glow::getOutputSave(F2, O);
2214 EXPECT_NE(nullptr, FoundNode);
2215 EXPECT_EQ(SN2, FoundNode);
2216}
2217
2218/// Check if dump functions work for Node, Function and Module.
2219TEST(Graph, testDumpStructure) {
2220 Module MD;
2221 Function *F = MD.createFunction("F");
2222 IRFunction M(F);
2223 PlaceholderBindings bindings;
2224 Node *K = MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 100, 3},
2225 "input", true);
2226 // Test Node
2227 std::string storageN1;
2228 llvm::raw_string_ostream osN1(storageN1);
2229 K->dump(osN1);
2230 std::string mesN = K->toString();
2231 std::string expectMes = R"(Placeholder
2232Name : input
2233Layout : *
2234Output : float<4 x 320 x 200 x 100 x 3>
2235Trainable : 1
2236Static : 0
2237Users : 0
2238)";
2239 EXPECT_EQ(mesN, expectMes);
2240 EXPECT_EQ(mesN, osN1.str());
2241 std::string storageN2;
2242 llvm::raw_string_ostream osN2(storageN2);
2243 osN2 << K;
2244 EXPECT_EQ(mesN, osN2.str());
2245 // Test Function
2246 Placeholder *I =
2247 MD.createPlaceholder(ElemKind::FloatTy, {10, 10}, "input", true);
2248 I->setStatic(true);
2249 Function *F2 = MD.createFunction("F2");
2250 F2->createTopK("topk", I, 3);
2251 std::string storageF1;
2252 llvm::raw_string_ostream osF1(storageF1);
2253 F2->dump(osF1);
2254 std::string mesF = F2->toString();
2255 std::string expectMesF = R"(Graph structure F2:
2256TopK
2257Name : topk
2258Input : float<10 x 10>
2259K : 3
2260Users : 0
2261Values : float<10 x 3>
2262Indices : index64<10 x 3>
2263Placeholder
2264Name : input__1
2265Layout : *
2266Output : float<10 x 10>
2267Trainable : 1
2268Static : 1
2269Users : 1
2270)";
2271 EXPECT_EQ(mesF, expectMesF);
2272 EXPECT_EQ(mesF, osF1.str());
2273 std::string storageF2;
2274 llvm::raw_string_ostream osF2(storageF2);
2275 osF2 << F2;
2276 EXPECT_EQ(mesF, osF2.str());
2277 storageF1.clear();
2278 F2->dump(osF1, /* skipUsersForStorage */ true);
2279 mesF = F2->toString(/* skipUsersForStorage */ true);
2280 expectMesF = R"(Graph structure F2:
2281TopK
2282Name : topk
2283Input : float<10 x 10>
2284K : 3
2285Users : 0
2286Values : float<10 x 3>
2287Indices : index64<10 x 3>
2288Placeholder
2289Name : input__1
2290Layout : *
2291Output : float<10 x 10>
2292Trainable : 1
2293Static : 1
2294)";
2295 EXPECT_EQ(mesF, expectMesF);
2296 EXPECT_EQ(mesF, osF1.str());
2297 // Test Module
2298 MD.createConstant(ElemKind::FloatTy, {1, 1}, "dummy");
2299 std::string storageM1;
2300 llvm::raw_string_ostream osM1(storageM1);
2301 MD.dump(osM1);
2302 std::string mesM = MD.toString();
2303 std::string expectMesM = R"(Module structure:
2304Constant
2305Name : dummy
2306Layout : *
2307Output : float<1 x 1>
2308Users : 0
2309
2310Placeholder
2311Name : input__1
2312Layout : *
2313Output : float<10 x 10>
2314Trainable : 1
2315Static : 1
2316Users : 1
2317
2318Placeholder
2319Name : input
2320Layout : *
2321Output : float<4 x 320 x 200 x 100 x 3>
2322Trainable : 1
2323Static : 0
2324Users : 0
2325
2326Function : F2
2327Function : F
2328)";
2329 EXPECT_EQ(mesM, expectMesM);
2330 EXPECT_EQ(mesM, osM1.str());
2331 std::string storageM2;
2332 llvm::raw_string_ostream osM2(storageM2);
2333 osM2 << MD;
2334 EXPECT_EQ(mesM, osM2.str());
2335}
2336
2337/// Initialize tensor payload for testing purposes. The value at index i is set
2338/// to i.
2339template <typename ElemTy> static void initTensor(Tensor &T) {
2340 Handle<ElemTy> handle = T.getHandle<ElemTy>();
2341 float val = 0;
2342 for (auto &elem : handle) {
2343 elem = val;
2344 val += 1.0;
2345 }
2346}
2347
2348// Test that randomizing Constants in a Function works.
2349TEST(Graph, testRandomizeConstants) {
2350 Module MD;
2351 Function *F = MD.createFunction("F");
2352
2353 // Create tensors to be used in Constants
2354 Tensor floatT(ElemKind::FloatTy, {10});
2355 initTensor<float>(floatT);
2356
2357 Tensor halfT(ElemKind::Float16Ty, {10});
2358 initTensor<float16_t>(halfT);
2359
2360 Tensor bfloat16T(ElemKind::BFloat16Ty, {10});
2361 initTensor<bfloat16_t>(bfloat16T);
2362
2363 Tensor int8QT(ElemKind::Int8QTy, {10}, 1.0, 0);
2364 initTensor<int8_t>(int8QT);
2365
2366 Tensor uint8QT(ElemKind::UInt8QTy, {10}, 1.0, 0);
2367 initTensor<uint8_t>(uint8QT);
2368
2369 Tensor int16QT(ElemKind::Int16QTy, {10}, 1.0, 0);
2370 initTensor<int16_t>(int16QT);
2371
2372 Tensor int32QT(ElemKind::Int32QTy, {10}, 1.0, 0);
2373 initTensor<int32_t>(int32QT);
2374
2375 Tensor int32IT(ElemKind::Int32ITy, {10});
2376 initTensor<int32_t>(int32IT);
2377
2378 Tensor int64IT(ElemKind::Int64ITy, {10});
2379 initTensor<int64_t>(int64IT);
2380
2381 Tensor uint8FusedQT(ElemKind::UInt8FusedQTy, {16, 16}, 1.0, 0);
2382 initTensor<uint8_t>(uint8FusedQT);
2383
2384 Tensor uint8FusedFP16QT(ElemKind::UInt8FusedFP16QTy, {16, 16}, 1.0, 0);
2385 initTensor<uint8_t>(uint8FusedFP16QT);
2386
2387 Tensor uint4FusedFP16QT(ElemKind::UInt4FusedFP16QTy, {16, 16}, 1.0, 0);
2388 initTensor<uint8_t>(uint4FusedFP16QT);
2389
2390 Tensor boolT(ElemKind::BoolTy, {10});
2391 initTensor<bool>(boolT);
2392
2393 // Create Constants and use them in F
2394 auto *floatC = MD.createConstant("floatC", floatT);
2395 F->createAdd("add", floatC, floatC);
2396
2397 auto *halfC = MD.createConstant("halfC", halfT);
2398 F->createAdd("add", halfC, halfC);
2399
2400 auto *bfloat16C = MD.createConstant("bloat16C", bfloat16T);
2401 F->createAdd("add", bfloat16C, bfloat16C);
2402
2403 auto *int8QC = MD.createConstant("int8QC", int8QT);
2404 F->createAdd("add", int8QC, int8QC);
2405
2406 auto *uint8QC = MD.createConstant("uint8QC", uint8QT);
2407 F->createAdd("add", uint8QC, uint8QC);
2408
2409 auto *int16QC = MD.createConstant("int16QC", int16QT);
2410 F->createAdd("add", int16QC, int16QC);
2411
2412 auto *int32QC = MD.createConstant("int32QC", int32QT);
2413 F->createAdd("add", int32QC, int32QC);
2414
2415 auto *int32IC = MD.createConstant("int32IC", int32IT);
2416 F->createAdd("add", int32IC, int32IC);
2417
2418 auto *int64IC = MD.createConstant("int64IC", int64IT);
2419 F->createAdd("add", int64IC, int64IC);
2420
2421 auto *uint8FusedQC = MD.createConstant("uint8FusedQC", uint8FusedQT);
2422 F->createAdd("add", uint8FusedQC, uint8FusedQC);
2423
2424 auto *uint8FusedFP16QC =
2425 MD.createConstant("uint8FusedFP16QC", uint8FusedFP16QT);
2426 F->createAdd("add", uint8FusedFP16QC, uint8FusedFP16QC);
2427
2428 auto *uint4FusedFP16QC =
2429 MD.createConstant("uint4FusedFP16QC", uint4FusedFP16QT);
2430 F->createAdd("add", uint4FusedFP16QC, uint4FusedFP16QC);
2431
2432 auto *boolC = MD.createConstant("boolC", boolT);
2433 F->createAdd("add", boolC, boolC);
2434
2435 // Randomize Constants in F
2436 F->randomizeConstants();
2437
2438 // Check that no Constant is the same as what it started as
2439 EXPECT_FALSE(floatT.isEqual(floatC->getPayload()));
2440 EXPECT_FALSE(halfT.isEqual(halfC->getPayload()));
2441 EXPECT_FALSE(bfloat16T.isEqual(bfloat16C->getPayload()));
2442 EXPECT_FALSE(int8QT.isEqual(int8QC->getPayload()));
2443 EXPECT_FALSE(uint8QT.isEqual(uint8QC->getPayload()));
2444 EXPECT_FALSE(int16QT.isEqual(int16QC->getPayload()));
2445 EXPECT_FALSE(int32QT.isEqual(int32QC->getPayload()));
2446 EXPECT_FALSE(int32IT.isEqual(int32IC->getPayload()));
2447 EXPECT_FALSE(int64IT.isEqual(int64IC->getPayload()));
2448 EXPECT_FALSE(uint8FusedQT.isEqual(uint8FusedQC->getPayload()));
2449 EXPECT_FALSE(uint8FusedFP16QT.isEqual(uint8FusedFP16QC->getPayload()));
2450 EXPECT_FALSE(uint4FusedFP16QT.isEqual(uint4FusedFP16QC->getPayload()));
2451 EXPECT_FALSE(boolT.isEqual(boolC->getPayload()));
2452}
2453
2454TEST(Graph, testSoftmaxMultiplier) {
2455 glow::ExecutionEngine EE;
2456 Module &M = EE.getModule();
2457 Function *F = M.createFunction("F");
2458
2459 float beta = 2.0;
2460
2461 // Create a graph with single softmax.
2462 auto *inputPH =
2463 M.createPlaceholder(ElemKind::FloatTy, {1, 2}, "input", false);
2464 auto *select =
2465 M.createPlaceholder(ElemKind::Int64ITy, {1, 1}, "select", true);
2466 Node *softmaxNode =
2467 F->createSoftMax("softmax", inputPH, select, inputPH->getType(), beta);
2468 auto *saveNode = F->createSave("output", softmaxNode);
2469
2470 PlaceholderBindings bindings;
2471 auto *inputT = bindings.allocate(inputPH);
2472 inputT->getHandle() = {1.0, 2.0};
2473 Tensor expectedT(inputT->getType());
2474 expectedT.getHandle() = {0.11920292, 0.88079703};
2475 auto *outputT = bindings.allocate(saveNode->getPlaceholder());
2476 EE.compile(CompilationMode::Infer);
2477 EE.run(bindings);
2478
2479 EXPECT_TRUE(outputT->isEqual(expectedT));
2480}
2481
2482TEST(Graph, testLogSoftmaxMultiplier) {
2483 glow::ExecutionEngine EE;
2484 Module &M = EE.getModule();
2485 Function *F = M.createFunction("F");
2486
2487 float beta = 2.0;
2488
2489 // Create a graph with single softmax.
2490 auto *inputPH =
2491 M.createPlaceholder(ElemKind::FloatTy, {1, 2}, "input", false);
2492 auto *select =
2493 M.createPlaceholder(ElemKind::Int64ITy, {1, 1}, "select", true);
2494 Node *softmaxNode =
2495 F->createLogSoftMax("softmax", inputPH, select, inputPH->getType(), beta);
2496 auto *saveNode = F->createSave("output", softmaxNode);
2497
2498 PlaceholderBindings bindings;
2499 auto *inputT = bindings.allocate(inputPH);
2500 inputT->getHandle() = {1.0, 2.0};
2501 Tensor expectedT(inputT->getType());
2502 expectedT.getHandle() = {-2.1269, -0.1269};
2503 auto *outputT = bindings.allocate(saveNode->getPlaceholder());
2504 EE.compile(CompilationMode::Infer);
2505 EE.run(bindings);
2506
2507 EXPECT_TRUE(outputT->isEqual(expectedT));
2508}
2509