1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
#include "glow/Graph/Graph.h"
#include "BackendTestUtils.h"
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Flags/Flags.h"
#include "glow/Graph/Hook.h"
#include "glow/Graph/Node.h"
#include "glow/Graph/Nodes.h"
#include "glow/Graph/Utils.h"
#include "glow/IR/IR.h"
#include "glow/IR/Instrs.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FileSystem.h"

#include "gtest/gtest.h"

#include <algorithm>
33 | |
34 | using namespace glow; |
35 | |
36 | // Helper to find a node in the Function by name |
37 | static const Node *nodeByName(const Function *F, const std::string &name) { |
38 | for (auto &n : F->getNodes()) { |
39 | if (n.getName().str() == name) { |
40 | return &n; |
41 | } |
42 | } |
43 | return nullptr; |
44 | } |
45 | |
46 | /// Mock backend that does lower FC nodes. |
47 | class MockBackendNoLowerConv3D : public MockBackend { |
48 | bool shouldLower(const Node *N) const override { |
49 | if (N->getKind() == Kinded::Kind::Convolution3DNodeKind) { |
50 | return false; |
51 | } else { |
52 | return true; |
53 | } |
54 | } |
55 | }; |
56 | |
57 | TEST(Graph, testVariableErasure) { |
58 | Module MD; |
59 | auto &vars = MD.getConstants(); |
60 | EXPECT_EQ(vars.size(), 0); |
61 | EXPECT_EQ(std::distance(vars.begin(), vars.end()), vars.size()); |
62 | |
63 | Constant *V = MD.createConstant(ElemKind::FloatTy, {1, 1}, "dummy" ); |
64 | EXPECT_EQ(vars.size(), 1); |
65 | EXPECT_EQ(std::distance(vars.begin(), vars.end()), vars.size()); |
66 | |
67 | MD.eraseConstant(V); |
68 | EXPECT_EQ(vars.size(), 0); |
69 | EXPECT_EQ(std::distance(vars.begin(), vars.end()), vars.size()); |
70 | } |
71 | |
72 | /// Check that the clear method completely reset a module. |
73 | TEST(Graph, clear) { |
74 | Module M; |
75 | |
76 | // Check that the module is initially empty. |
77 | EXPECT_EQ(M.getConstants().size(), 0); |
78 | EXPECT_EQ(M.getPlaceholders().size(), 0); |
79 | EXPECT_EQ(M.getFunctions().size(), 0); |
80 | |
81 | // Create a few things. |
82 | M.createFunction("main" ); |
83 | M.createPlaceholder(ElemKind::FloatTy, {1}, "placeholder" , true); |
84 | M.createConstant(ElemKind::FloatTy, {1}, "var" ); |
85 | |
86 | EXPECT_EQ(M.getConstants().size(), 1); |
87 | EXPECT_EQ(M.getPlaceholders().size(), 1); |
88 | EXPECT_EQ(M.getFunctions().size(), 1); |
89 | |
90 | // Check that clearing the module makes it completely free of any kind of |
91 | // objects. |
92 | M.clear(); |
93 | EXPECT_EQ(M.getConstants().size(), 0); |
94 | EXPECT_EQ(M.getPlaceholders().size(), 0); |
95 | EXPECT_EQ(M.getFunctions().size(), 0); |
96 | } |
97 | |
98 | /// Check that the clear method works as expected. |
99 | TEST(Graph, clearFunctions) { |
100 | Module M; |
101 | |
102 | // Check that the module is initially empty. |
103 | EXPECT_EQ(M.getConstants().size(), 0); |
104 | EXPECT_EQ(M.getPlaceholders().size(), 0); |
105 | EXPECT_EQ(M.getFunctions().size(), 0); |
106 | |
107 | // Create a few things. |
108 | Function *F = M.createFunction("main" ); |
109 | auto *PH = M.createPlaceholder(ElemKind::FloatTy, {1}, "placeholder" , true); |
110 | auto *C = M.createConstant(ElemKind::FloatTy, {1}, "var" ); |
111 | auto *AN = F->createAdd("add" , PH, C); |
112 | F->createSave("save" , AN); |
113 | |
114 | EXPECT_EQ(M.getConstants().size(), 1); |
115 | EXPECT_EQ(M.getPlaceholders().size(), 2); // Input PH and PH for Save |
116 | EXPECT_EQ(M.getFunctions().size(), 1); |
117 | EXPECT_EQ(F->getNodes().size(), 2); // Add, Save |
118 | |
119 | M.clearFunctions(); |
120 | EXPECT_EQ(M.getConstants().size(), 1); |
121 | EXPECT_EQ(M.getPlaceholders().size(), 2); |
122 | ASSERT_EQ(M.getFunctions().size(), 1); |
123 | // Same Function ptr should exist, just nothing left in them. |
124 | EXPECT_EQ(*M.getFunctions().begin(), F); |
125 | EXPECT_EQ(F->getNodes().size(), 0); |
126 | } |
127 | |
128 | /// Test the graph nodes names and utilities. |
129 | TEST(Graph, testGraphNames) { |
130 | Module MD; |
131 | Function *F = MD.createFunction("F" ); |
132 | |
133 | Node *op1 = MD.createPlaceholder(ElemKind::FloatTy, {1, 10}, "op1" , |
134 | false /*isTrainable*/); |
135 | Node *op2 = MD.createConstant(ElemKind::FloatTy, {1, 10}, "op2" ); |
136 | Node *add = F->createAdd("add" , op1, op2); |
137 | auto *top = F->createTopK("top" , add, 5); |
138 | Node *save = F->createSave("out" , top->getValues()); |
139 | |
140 | EXPECT_TRUE(MD.getPlaceholderByNameSlow("op1" )); |
141 | EXPECT_TRUE(MD.getConstantByName("op2" )); |
142 | EXPECT_TRUE(F->getNodeByName("add" )); |
143 | EXPECT_TRUE(F->getNodeByName("top" )); |
144 | EXPECT_TRUE(F->getNodeByName("out_save" )); |
145 | |
146 | NodeValue op1Res = op1->getNthResult(0); |
147 | NodeValue op2Res = op2->getNthResult(0); |
148 | NodeValue addRes = add->getNthResult(0); |
149 | EXPECT_TRUE(top->getNumResults() == 2); |
150 | NodeValue topValRes = top->getNthResult(0); |
151 | NodeValue topIndRes = top->getNthResult(1); |
152 | |
153 | auto op1ResName = |
154 | op1Res.generateNodeOutputName(false /*stripResNoFor0thInput*/); |
155 | auto op2ResName = |
156 | op2Res.generateNodeOutputName(false /*stripResNoFor0thInput*/); |
157 | auto addResName = |
158 | addRes.generateNodeOutputName(true /*stripResNoFor0thInput*/); |
159 | auto topValResName = |
160 | topValRes.generateNodeOutputName(false /*stripResNoFor0thInput*/); |
161 | auto topIndResName = |
162 | topIndRes.generateNodeOutputName(false /*stripResNoFor0thInput*/); |
163 | |
164 | EXPECT_EQ(op1ResName, "op1:0" ); |
165 | EXPECT_EQ(op2ResName, "op2:0" ); |
166 | EXPECT_EQ(addResName, "add" ); |
167 | EXPECT_EQ(topValResName, "top:0" ); |
168 | EXPECT_EQ(topIndResName, "top:1" ); |
169 | |
170 | EXPECT_EQ(F->getNodeValueByName(op1ResName), op1Res); |
171 | EXPECT_EQ(F->getNodeValueByName(op2ResName), op2Res); |
172 | EXPECT_EQ(F->getNodeValueByName(addResName), addRes); |
173 | EXPECT_EQ(F->getNodeValueByName(topValResName), topValRes); |
174 | EXPECT_EQ(F->getNodeValueByName(topIndResName), topIndRes); |
175 | |
176 | EXPECT_EQ(F->getNodeValueByName("op1" ), op1Res); |
177 | EXPECT_EQ(F->getNodeValueByName("op2" ), op2Res); |
178 | EXPECT_EQ(F->getNodeValueByName("add:0" ), addRes); |
179 | |
180 | // Verify the node value is invalid for the SaveNode which has no outputs. |
181 | EXPECT_EQ(F->getNodeValueByName(save->getName()).getNode(), nullptr); |
182 | } |
183 | |
184 | /// Check node names. |
185 | TEST(Graph, testNodeNames) { |
186 | Module MD; |
187 | Function *F = MD.createFunction("F" ); |
188 | IRFunction M(F); |
189 | PlaceholderBindings bindings; |
190 | Node *K = |
191 | MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input" , true); |
192 | Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select" , true); |
193 | |
194 | K = F->createConv(bindings, "Conv1" , K, 16, 3, 2, 3, 1); |
195 | K = F->createRELU("Relu" , K); |
196 | K = F->createSoftMax("SoftMax" , K, S); |
197 | F->createSave("Save" , K); |
198 | F->dump(); |
199 | auto filePath = F->dumpDAG(); |
200 | auto backend = MockBackend(); |
201 | CompilationContext cctx; |
202 | lower(F, cctx, &backend); |
203 | ::optimize(F, CompilationMode::Train); |
204 | M.generateIR(backend); |
205 | M.dump(); |
206 | EXPECT_GT(M.getInstrs().size(), 0); |
207 | llvm::sys::fs::remove(filePath); |
208 | } |
209 | |
210 | /// Check that a createConv3D can be run. |
211 | TEST(Graph, simpleTestConv3D) { |
212 | Module MD; |
213 | Function *F = MD.createFunction("F" ); |
214 | IRFunction M(F); |
215 | PlaceholderBindings bindings; |
216 | Node *K = MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 100, 3}, |
217 | "input" , true); |
218 | K = F->createConv3D(bindings, /* name */ "Conv3D" , /* input */ K, |
219 | /* outChannels */ 16, /* kernel */ 3, /* stride */ 2, |
220 | /* pad */ 3, /* group */ 1); |
221 | K = F->createRELU("Relu" , K); |
222 | F->createSave("Save" , K); |
223 | F->dump(); |
224 | auto filePath = F->dumpDAG(); |
225 | auto backend = MockBackend(); |
226 | CompilationContext cctx; |
227 | lower(F, cctx, &backend); |
228 | ::optimize(F, CompilationMode::Train); |
229 | M.generateIR(backend); |
230 | M.dump(); |
231 | EXPECT_GT(M.getInstrs().size(), 0); |
232 | llvm::sys::fs::remove(filePath); |
233 | } |
234 | |
235 | /// Tests custom lowering from Node to Instruction IR |
236 | TEST(Graph, simpleTestConvCustomLower) { |
237 | Module MD; |
238 | Function *F = MD.createFunction("F" ); |
239 | IRFunction M(F); |
240 | PlaceholderBindings bindings; |
241 | Node *K = |
242 | MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input" , true); |
243 | Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select" , true); |
244 | |
245 | K = F->createConv(bindings, "Conv1" , K, 16, 3, 2, 3, 1); |
246 | K = F->createRELU("Relu" , K); |
247 | K = F->createSoftMax("SoftMax" , K, S); |
248 | F->createSave("Save" , K); |
249 | F->dump(); |
250 | auto filePath = F->dumpDAG(); |
251 | auto backend = MockBackendCustomIRGen(); |
252 | CompilationContext cctx; |
253 | lower(F, cctx, &backend); |
254 | ::optimize(F, CompilationMode::Train); |
255 | M.generateIR(MockBackendCustomIRGen()); |
256 | M.dump(); |
257 | auto &instrList = M.getInstrs(); |
258 | bool customHappened = false; |
259 | for (auto begin = instrList.begin(); begin != instrList.end(); ++begin) { |
260 | if (begin->getName().equals("CustomConvolutionInstruction" )) { |
261 | customHappened = true; |
262 | break; |
263 | } |
264 | } |
265 | |
266 | EXPECT_EQ(customHappened, true); |
267 | llvm::sys::fs::remove(filePath); |
268 | } |
269 | |
270 | /// Check that we can create convolution with float16. |
271 | TEST(Graph, float16Conv) { |
272 | Module MD; |
273 | Function *F = MD.createFunction("F" ); |
274 | PlaceholderBindings bindings; |
275 | Node *K = MD.createConstant(ElemKind::Float16Ty, {4, 320, 200, 3}, "input" ); |
276 | |
277 | auto *conv = F->createConv(bindings, "Conv" , K, 16, 3, 2, 3, 1); |
278 | F->createSave("Save" , conv); |
279 | EXPECT_TRUE(conv->verify()); |
280 | EXPECT_EQ(conv->getResult().getElementType(), ElemKind::Float16Ty); |
281 | EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::Float16Ty); |
282 | EXPECT_EQ(conv->getBias().getElementType(), ElemKind::Float16Ty); |
283 | |
284 | auto backend = MockBackend(); |
285 | CompilationContext cctx; |
286 | lower(F, cctx, &backend); |
287 | |
288 | IRFunction M(F); |
289 | |
290 | M.generateIR(backend); |
291 | EXPECT_GT(M.getInstrs().size(), 0); |
292 | auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(), |
293 | [](const Instruction &inst) -> bool { |
294 | return llvm::isa<ConvolutionInst>(inst); |
295 | }); |
296 | ASSERT_TRUE(convIt != M.getInstrs().end()); |
297 | const auto *convInst = llvm::cast<ConvolutionInst>(&*convIt); |
298 | EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::Float16Ty); |
299 | EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::Float16Ty); |
300 | EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::Float16Ty); |
301 | } |
302 | |
303 | /// Check that we can create layernormalization with float16. |
304 | TEST(Graph, float16LayerNorm) { |
305 | const auto origFlagVal = interpreter::flags::LowerLayerNormalization; |
306 | interpreter::flags::LowerLayerNormalization = false; |
307 | |
308 | Module MD; |
309 | Function *F = MD.createFunction("F" ); |
310 | |
311 | PlaceholderBindings bindings; |
312 | auto *input = |
313 | MD.createPlaceholder(ElemKind::Float16Ty, {1, 4, 5, 5}, "in" , false); |
314 | |
315 | Tensor scaleT(ElemKind::Float16Ty, {5, 5}); |
316 | scaleT.getHandle<float16_t>().randomize(0.0f, 1.0f, MD.getPRNG()); |
317 | Constant *scaleC = MD.createConstant("scale" , std::move(scaleT)); |
318 | Tensor biasT(ElemKind::Float16Ty, {5, 5}); |
319 | biasT.getHandle<float16_t>().randomize(0.0f, 1.0f, MD.getPRNG()); |
320 | Constant *biasC = MD.createConstant("bias" , std::move(biasT)); |
321 | |
322 | LayerNormalizationNode *LNN = F->createLayerNormalization( |
323 | "LN" , input->getType(), input, scaleC, biasC, 1e-5); |
324 | F->createSave("Save" , LNN); |
325 | |
326 | std::unique_ptr<const Backend> backend(createBackend("Interpreter" )); |
327 | |
328 | CompilationContext cctx; |
329 | lower(F, cctx, backend.get()); |
330 | |
331 | IRFunction M(F); |
332 | |
333 | M.generateIR(*backend); |
334 | EXPECT_GT(M.getInstrs().size(), 0); |
335 | auto lnIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(), |
336 | [](const Instruction &inst) -> bool { |
337 | return llvm::isa<LayerNormalizationInst>(inst); |
338 | }); |
339 | ASSERT_TRUE(lnIt != M.getInstrs().end()); |
340 | interpreter::flags::LowerLayerNormalization = origFlagVal; |
341 | } |
342 | |
343 | /// Check that we can create batch_matmul with float16. |
344 | TEST(Graph, float16BatchMatMul) { |
345 | const auto origFlagVal = interpreter::flags::LowerBatchMatMul; |
346 | interpreter::flags::LowerBatchMatMul = false; |
347 | |
348 | Module MD; |
349 | Function *F = MD.createFunction("F" ); |
350 | |
351 | PlaceholderBindings bindings; |
352 | auto *LHS = MD.createPlaceholder(ElemKind::Float16Ty, {2, 3, 4}, "A" , false); |
353 | auto *RHS = MD.createPlaceholder(ElemKind::Float16Ty, {2, 4, 5}, "B" , false); |
354 | |
355 | BatchMatMulNode *BMM = F->createBatchMatMul("BMM" , LHS, RHS); |
356 | F->createSave("Save" , BMM); |
357 | |
358 | std::unique_ptr<const Backend> backend(createBackend("Interpreter" )); |
359 | |
360 | CompilationContext cctx; |
361 | lower(F, cctx, backend.get()); |
362 | |
363 | IRFunction M(F); |
364 | |
365 | M.generateIR(*backend); |
366 | EXPECT_GT(M.getInstrs().size(), 0); |
367 | auto bmmIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(), |
368 | [](const Instruction &inst) -> bool { |
369 | return llvm::isa<BatchMatMulInst>(inst); |
370 | }); |
371 | ASSERT_TRUE(bmmIt != M.getInstrs().end()); |
372 | interpreter::flags::LowerBatchMatMul = origFlagVal; |
373 | } |
374 | |
375 | /// Check that we can create batch_matmul with float. |
376 | TEST(Graph, floatBatchMatMul) { |
377 | const auto origFlagVal = interpreter::flags::LowerBatchMatMul; |
378 | interpreter::flags::LowerBatchMatMul = false; |
379 | |
380 | Module MD; |
381 | Function *F = MD.createFunction("F" ); |
382 | |
383 | PlaceholderBindings bindings; |
384 | auto *LHS = MD.createPlaceholder(ElemKind::FloatTy, {2, 3, 4}, "A" , false); |
385 | auto *RHS = MD.createPlaceholder(ElemKind::FloatTy, {2, 4, 5}, "B" , false); |
386 | |
387 | BatchMatMulNode *BMM = F->createBatchMatMul("BMM" , LHS, RHS); |
388 | F->createSave("Save" , BMM); |
389 | |
390 | std::unique_ptr<const Backend> backend(createBackend("Interpreter" )); |
391 | |
392 | CompilationContext cctx; |
393 | lower(F, cctx, backend.get()); |
394 | |
395 | IRFunction M(F); |
396 | |
397 | M.generateIR(*backend); |
398 | EXPECT_GT(M.getInstrs().size(), 0); |
399 | auto bmmIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(), |
400 | [](const Instruction &inst) -> bool { |
401 | return llvm::isa<BatchMatMulInst>(inst); |
402 | }); |
403 | ASSERT_TRUE(bmmIt != M.getInstrs().end()); |
404 | interpreter::flags::LowerBatchMatMul = origFlagVal; |
405 | } |
406 | |
407 | /// Check that we can create conv3D with float16. |
408 | TEST(Graph, float16Conv3DLower) { |
409 | Module MD; |
410 | Function *F = MD.createFunction("F" ); |
411 | PlaceholderBindings bindings; |
412 | Node *K = |
413 | MD.createConstant(ElemKind::Float16Ty, {4, 320, 200, 200, 3}, "input" ); |
414 | |
415 | auto *conv = F->createConv3D(bindings, "Conv3D" , K, 16, 3, 2, 3, 1); |
416 | F->createSave("Save" , conv); |
417 | EXPECT_TRUE(conv->verify()); |
418 | EXPECT_EQ(conv->getResult().getElementType(), ElemKind::Float16Ty); |
419 | EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::Float16Ty); |
420 | EXPECT_EQ(conv->getBias().getElementType(), ElemKind::Float16Ty); |
421 | |
422 | auto backend = MockBackend(); |
423 | CompilationContext cctx; |
424 | lower(F, cctx, &backend); |
425 | |
426 | IRFunction M(F); |
427 | |
428 | M.generateIR(backend); |
429 | EXPECT_GT(M.getInstrs().size(), 0); |
430 | auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(), |
431 | [](const Instruction &inst) -> bool { |
432 | return llvm::isa<Convolution3DInst>(inst); |
433 | }); |
434 | ASSERT_TRUE(convIt == M.getInstrs().end()); |
435 | } |
436 | |
437 | /// Check that we can create conv3D with float16. |
438 | TEST(Graph, float16Conv3DNoLower) { |
439 | Module MD; |
440 | Function *F = MD.createFunction("F" ); |
441 | PlaceholderBindings bindings; |
442 | Node *K = |
443 | MD.createConstant(ElemKind::Float16Ty, {4, 320, 200, 200, 3}, "input" ); |
444 | |
445 | auto *conv = F->createConv3D(bindings, "Conv3D" , K, 16, 3, 2, 3, 1); |
446 | F->createSave("Save" , conv); |
447 | EXPECT_TRUE(conv->verify()); |
448 | EXPECT_EQ(conv->getResult().getElementType(), ElemKind::Float16Ty); |
449 | EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::Float16Ty); |
450 | EXPECT_EQ(conv->getBias().getElementType(), ElemKind::Float16Ty); |
451 | |
452 | auto backend = MockBackendNoLowerConv3D(); |
453 | CompilationContext cctx; |
454 | lower(F, cctx, &backend); |
455 | |
456 | IRFunction M(F); |
457 | |
458 | M.generateIR(backend); |
459 | EXPECT_GT(M.getInstrs().size(), 0); |
460 | auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(), |
461 | [](const Instruction &inst) -> bool { |
462 | return llvm::isa<Convolution3DInst>(inst); |
463 | }); |
464 | ASSERT_TRUE(convIt != M.getInstrs().end()); |
465 | const auto *convInst = llvm::cast<Convolution3DInst>(&*convIt); |
466 | EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::Float16Ty); |
467 | EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::Float16Ty); |
468 | EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::Float16Ty); |
469 | } |
470 | |
471 | /// Check that we can create batchNorm with float16. |
472 | TEST(Graph, float16BatchNorm) { |
473 | Module MD; |
474 | Function *F = MD.createFunction("F" ); |
475 | PlaceholderBindings bindings; |
476 | auto *input = |
477 | MD.createPlaceholder(ElemKind::Float16Ty, {1, 10, 20, 3}, "input" , false); |
478 | BatchNormalizationNode *BN = |
479 | F->createBatchNormalization(bindings, "batch" , input, 3, 0.0001, 0.9); |
480 | |
481 | EXPECT_TRUE(BN->verify()); |
482 | EXPECT_EQ(BN->getResult().getElementType(), ElemKind::Float16Ty); |
483 | EXPECT_EQ(BN->getScale().getElementType(), ElemKind::Float16Ty); |
484 | EXPECT_EQ(BN->getBias().getElementType(), ElemKind::Float16Ty); |
485 | EXPECT_EQ(BN->getMean().getElementType(), ElemKind::Float16Ty); |
486 | EXPECT_EQ(BN->getVar().getElementType(), ElemKind::Float16Ty); |
487 | |
488 | auto backend = MockBackend(); |
489 | CompilationContext cctx; |
490 | lower(F, cctx, &backend); |
491 | |
492 | EXPECT_TRUE(std::all_of( |
493 | F->getNodes().begin(), F->getNodes().end(), [](const Node &node) -> bool { |
494 | for (unsigned idx = 0, end = node.getNumResults(); idx != end; ++idx) { |
495 | if (node.getType(idx)->getElementType() != ElemKind::Float16Ty) { |
496 | return false; |
497 | } |
498 | } |
499 | return true; |
500 | })); |
501 | } |
502 | |
503 | /// Check that we can create convolution with bfloat16. |
504 | TEST(Graph, bfloat16Conv) { |
505 | Module MD; |
506 | Function *F = MD.createFunction("F" ); |
507 | PlaceholderBindings bindings; |
508 | Node *K = MD.createConstant(ElemKind::BFloat16Ty, {4, 320, 200, 3}, "input" ); |
509 | |
510 | auto *conv = F->createConv(bindings, "Conv" , K, 16, 3, 2, 3, 1); |
511 | F->createSave("Save" , conv); |
512 | EXPECT_TRUE(conv->verify()); |
513 | EXPECT_EQ(conv->getResult().getElementType(), ElemKind::BFloat16Ty); |
514 | EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::BFloat16Ty); |
515 | EXPECT_EQ(conv->getBias().getElementType(), ElemKind::BFloat16Ty); |
516 | |
517 | auto backend = MockBackend(); |
518 | CompilationContext cctx; |
519 | lower(F, cctx, &backend); |
520 | |
521 | IRFunction M(F); |
522 | |
523 | M.generateIR(backend); |
524 | EXPECT_GT(M.getInstrs().size(), 0); |
525 | auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(), |
526 | [](const Instruction &inst) -> bool { |
527 | return llvm::isa<ConvolutionInst>(inst); |
528 | }); |
529 | ASSERT_TRUE(convIt != M.getInstrs().end()); |
530 | const auto *convInst = llvm::cast<ConvolutionInst>(&*convIt); |
531 | EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::BFloat16Ty); |
532 | EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::BFloat16Ty); |
533 | EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::BFloat16Ty); |
534 | } |
535 | |
536 | /// Check that we can create conv3D with bfloat16. |
537 | TEST(Graph, bfloat16Conv3DLower) { |
538 | Module MD; |
539 | Function *F = MD.createFunction("F" ); |
540 | PlaceholderBindings bindings; |
541 | Node *K = |
542 | MD.createConstant(ElemKind::BFloat16Ty, {4, 320, 200, 200, 3}, "input" ); |
543 | |
544 | auto *conv = F->createConv3D(bindings, "Conv3D" , K, 16, 3, 2, 3, 1); |
545 | F->createSave("Save" , conv); |
546 | EXPECT_TRUE(conv->verify()); |
547 | EXPECT_EQ(conv->getResult().getElementType(), ElemKind::BFloat16Ty); |
548 | EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::BFloat16Ty); |
549 | EXPECT_EQ(conv->getBias().getElementType(), ElemKind::BFloat16Ty); |
550 | |
551 | auto backend = MockBackend(); |
552 | CompilationContext cctx; |
553 | lower(F, cctx, &backend); |
554 | |
555 | IRFunction M(F); |
556 | |
557 | M.generateIR(backend); |
558 | EXPECT_GT(M.getInstrs().size(), 0); |
559 | auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(), |
560 | [](const Instruction &inst) -> bool { |
561 | return llvm::isa<Convolution3DInst>(inst); |
562 | }); |
563 | ASSERT_TRUE(convIt == M.getInstrs().end()); |
564 | } |
565 | |
566 | /// Check that we can create conv3D with bfloat16. |
567 | TEST(Graph, bfloat16Conv3DNoLower) { |
568 | Module MD; |
569 | Function *F = MD.createFunction("F" ); |
570 | PlaceholderBindings bindings; |
571 | Node *K = |
572 | MD.createConstant(ElemKind::BFloat16Ty, {4, 320, 200, 200, 3}, "input" ); |
573 | |
574 | auto *conv = F->createConv3D(bindings, "Conv3D" , K, 16, 3, 2, 3, 1); |
575 | F->createSave("Save" , conv); |
576 | EXPECT_TRUE(conv->verify()); |
577 | EXPECT_EQ(conv->getResult().getElementType(), ElemKind::BFloat16Ty); |
578 | EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::BFloat16Ty); |
579 | EXPECT_EQ(conv->getBias().getElementType(), ElemKind::BFloat16Ty); |
580 | |
581 | auto backend = MockBackendNoLowerConv3D(); |
582 | CompilationContext cctx; |
583 | lower(F, cctx, &backend); |
584 | |
585 | IRFunction M(F); |
586 | |
587 | M.generateIR(backend); |
588 | EXPECT_GT(M.getInstrs().size(), 0); |
589 | auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(), |
590 | [](const Instruction &inst) -> bool { |
591 | return llvm::isa<Convolution3DInst>(inst); |
592 | }); |
593 | ASSERT_TRUE(convIt != M.getInstrs().end()); |
594 | const auto *convInst = llvm::cast<Convolution3DInst>(&*convIt); |
595 | EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::BFloat16Ty); |
596 | EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::BFloat16Ty); |
597 | EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::BFloat16Ty); |
598 | } |
599 | |
600 | /// Check that we can create batchNorm with float16. |
601 | TEST(Graph, bfloat16BatchNorm) { |
602 | Module MD; |
603 | Function *F = MD.createFunction("F" ); |
604 | PlaceholderBindings bindings; |
605 | auto *input = MD.createPlaceholder(ElemKind::BFloat16Ty, {1, 10, 20, 3}, |
606 | "input" , false); |
607 | BatchNormalizationNode *BN = |
608 | F->createBatchNormalization(bindings, "batch" , input, 3, 0.0001, 0.9); |
609 | |
610 | EXPECT_TRUE(BN->verify()); |
611 | EXPECT_EQ(BN->getResult().getElementType(), ElemKind::BFloat16Ty); |
612 | EXPECT_EQ(BN->getScale().getElementType(), ElemKind::BFloat16Ty); |
613 | EXPECT_EQ(BN->getBias().getElementType(), ElemKind::BFloat16Ty); |
614 | EXPECT_EQ(BN->getMean().getElementType(), ElemKind::BFloat16Ty); |
615 | EXPECT_EQ(BN->getVar().getElementType(), ElemKind::BFloat16Ty); |
616 | |
617 | auto backend = MockBackend(); |
618 | CompilationContext cctx; |
619 | lower(F, cctx, &backend); |
620 | |
621 | EXPECT_TRUE(std::all_of( |
622 | F->getNodes().begin(), F->getNodes().end(), [](const Node &node) -> bool { |
623 | for (unsigned idx = 0, end = node.getNumResults(); idx != end; ++idx) { |
624 | if (node.getType(idx)->getElementType() != ElemKind::BFloat16Ty) { |
625 | return false; |
626 | } |
627 | } |
628 | return true; |
629 | })); |
630 | } |
631 | |
/// Test that our use lists are correctly reflecting the state of the IR
/// and in particular that it is not polluted by temporary variable.
/// Temporary NodeValues (implicit or explicit) must not be counted as users;
/// only real consumers such as the ConvolutionNode and SaveNodes are.
TEST(Graph, useList) {
  Module MD;
  Function *F = MD.createFunction("F");
  IRFunction M(F);
  PlaceholderBindings bindings;
  auto *K =
      MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);

  EXPECT_EQ(K->getNumUsers(), 0);

  // The convolution becomes the single user of K.
  ConvolutionNode *conv = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, 1);

  EXPECT_TRUE(K->hasOneUse());
  EXPECT_EQ(K->getNumUsers(), 1);
  EXPECT_EQ(conv->getNumUsers(), 0);

  // The filter of the convolution is only used by the convolution node.
  // Calling getFilter() produces a temporary NodeValue; these checks make sure
  // such temporaries do not show up in the filter's use list.
  // NOTE(review): an earlier version of this comment said these checks were
  // "currently inverted"; the assertions below expect exactly one use.
  // Test with implicit temporary NodeValue.
  EXPECT_TRUE(conv->getFilter().getNode()->hasOneUse());
  EXPECT_EQ(conv->getFilter().getNode()->getNumUsers(), 1);

  // Test with explicit temporary NodeValue.
  Node *nodeFilter;
  {
    NodeValue tmp = conv->getFilter();
    EXPECT_TRUE(tmp.getNode()->hasOneUse());
    EXPECT_EQ(tmp.getNode()->getNumUsers(), 1);
    nodeFilter = tmp.getNode();
    // Test with NodeValue still around.
    EXPECT_TRUE(nodeFilter->hasOneUse());
    EXPECT_EQ(nodeFilter->getNumUsers(), 1);
  }

  // Test after the NodeValue has gone out of scope.
  EXPECT_TRUE(nodeFilter->hasOneUse());
  EXPECT_EQ(nodeFilter->getNumUsers(), 1);

  // Same kind of test but with the convolution node itself.
  {
    NodeValue tmpConvRes(conv, 0);
    EXPECT_EQ(conv->getNumUsers(), 0);
    EXPECT_EQ(tmpConvRes.getNode()->getNumUsers(), 0);
  }

  // Add a couple of uses to conv and make sure it reflects on its use list.
  // Saving conv into K adds one user to conv and one to K.
  F->createSave("Save", conv, K);

  EXPECT_FALSE(K->hasOneUse());
  EXPECT_EQ(K->getNumUsers(), 2);
  EXPECT_EQ(conv->getNumUsers(), 1);
  EXPECT_TRUE(conv->hasOneUse());

  // A temporary NodeValue of conv must not add a use.
  {
    NodeValue tmpConvRes(conv, 0);
    EXPECT_TRUE(tmpConvRes.getNode()->hasOneUse());
    EXPECT_TRUE(conv->hasOneUse());
    EXPECT_EQ(conv->getNumUsers(), 1);
    EXPECT_EQ(tmpConvRes.getNode()->getNumUsers(), 1);
  }

  // A second Save adds one more user to both conv and K.
  F->createSave("Save", conv, K);

  EXPECT_FALSE(K->hasOneUse());
  EXPECT_EQ(K->getNumUsers(), 3);
  EXPECT_EQ(conv->getNumUsers(), 2);
  EXPECT_FALSE(conv->hasOneUse());

  {
    NodeValue tmpConvRes(conv, 0);
    EXPECT_FALSE(tmpConvRes.getNode()->hasOneUse());
    EXPECT_FALSE(conv->hasOneUse());
    EXPECT_EQ(conv->getNumUsers(), 2);
    EXPECT_EQ(tmpConvRes.getNode()->getNumUsers(), 2);
  }
}
713 | |
714 | TEST(Graph, useListIteration) { |
715 | Module MD; |
716 | Function *F = MD.createFunction("F" ); |
717 | IRFunction M(F); |
718 | Node *K = |
719 | MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input" , true); |
720 | |
721 | EXPECT_EQ(K->getNumUsers(), 0); |
722 | |
723 | PlaceholderBindings bindings; |
724 | ConvolutionNode *conv1 = F->createConv(bindings, "Conv1" , K, 16, 3, 2, 3, 1); |
725 | ConvolutionNode *conv2 = F->createConv(bindings, "Conv2" , K, 16, 3, 2, 3, 1); |
726 | // Check the number of users for different nodes. |
727 | EXPECT_EQ(K->getNumUsers(), 2); |
728 | EXPECT_EQ(conv1->getNumUsers(), 0); |
729 | EXPECT_TRUE(conv2->getFilter().getNode()->hasOneUse()); |
730 | EXPECT_EQ(conv1->getFilter().getNode()->getNumUsers(), 1); |
731 | // Check that the first user of K is conv1. |
732 | EXPECT_EQ(K->getUsers().begin()->getUser(), conv1); |
733 | // Check that the second user of K is conv2. |
734 | EXPECT_EQ((++K->getUsers().begin())->getUser(), conv2); |
735 | } |
736 | |
737 | TEST(Graph, simpleTestFC) { |
738 | unsigned numInputs = 10; |
739 | Module MD; |
740 | Function *F = MD.createFunction("F" ); |
741 | IRFunction M(F); |
742 | |
743 | auto *A = MD.createPlaceholder(ElemKind::FloatTy, {numInputs, 2}, "A" , true); |
744 | auto *Ex = |
745 | MD.createPlaceholder(ElemKind::FloatTy, {numInputs, 1}, "Ex" , true); |
746 | |
747 | PlaceholderBindings bindings; |
748 | Node *O = F->createFullyConnected(bindings, "FC1" , A, 6); |
749 | O = F->createRELU("RELU1" , O); |
750 | O = F->createFullyConnected(bindings, "FC2" , O, 1); |
751 | O = F->createRELU("RELU2" , O); |
752 | O = F->createRegression("Regression" , O, Ex); |
753 | F->createSave("Save" , O); |
754 | F->dump(); |
755 | auto filePath = F->dumpDAG(); |
756 | auto backend = MockBackend(); |
757 | CompilationContext cctx; |
758 | lower(F, cctx, &backend); |
759 | ::optimize(F, CompilationMode::Train); |
760 | M.generateIR(backend); |
761 | M.dump(); |
762 | EXPECT_GT(M.getInstrs().size(), 0); |
763 | llvm::sys::fs::remove(filePath); |
764 | } |
765 | |
/// Check that optimizing a function with QuantizationMode::Profile inserts
/// QuantizationProfileNodes for float values only, and that a value read by
/// several nodes (placeholder A) is profiled only once.
TEST(Graph, QuantizationProfileNodes) {
  unsigned numInputs = 10;
  Module MD;
  Function *F = MD.createFunction("F");
  // NOTE(review): M is not referenced below; looks like a leftover. Confirm
  // before removing.
  IRFunction M(F);

  auto *A = MD.createPlaceholder(ElemKind::FloatTy, {numInputs, 2}, "A", true);

  // Add non float operation, which should not be profiled.
  auto *outQTy = F->getParent()->uniqueType(glow::ElemKind::Int8QTy,
                                            {numInputs, 2}, 1.5, 6);
  auto *quantize = F->createQuantize("quantize", A, outQTy);
  // Make sure that quantize is not optimized away.
  PlaceholderBindings bindings;
  F->createSave("save", quantize);

  // Multiple nodes read from the same variable.
  // Only one Quantization Profile node should be created for the output
  // from the variable.
  Node *O = F->createFullyConnected(bindings, "FC1", A, 6);
  Node *C = F->createFullyConnected(bindings, "FC2", A, 6);
  O = F->createRELU("RELU1", O);
  F->createSave("save", O);
  F->createSave("save", C);

  // The lowered-info map is handed to the compilation context together with
  // the bindings; profiling mode is selected via the precision config.
  LoweredInfoMap loweredMapForProf;
  CompilationContext cctx{&bindings, &loweredMapForProf};
  cctx.precisionConfig.quantMode = QuantizationMode::Profile;
  std::unique_ptr<Backend> backend(createBackend("Interpreter"));
  EXIT_ON_ERR(::optimizeFunction(F, *backend, cctx));

  // Count the profile nodes left in F after optimization.
  size_t numberOfProfileNodes =
      std::count_if(F->getNodes().begin(), F->getNodes().end(), [](Node &node) {
        return llvm::isa<QuantizationProfileNode>(&node);
      });

  // Expected breakdown:
  // 1 from A
  // 8 from two lowered FCs: MM, BA, weight PH, bias PH
  // 2 from RELU (lowered to Max+Splat)
  EXPECT_EQ(11, numberOfProfileNodes);
}
807 | |
808 | TEST(Graph, simpleQuant) { |
809 | ExecutionEngine EE; |
810 | auto &MD = EE.getModule(); |
811 | auto *F = MD.createFunction("main" ); |
812 | |
813 | unsigned depth = 16; |
814 | llvm::SmallVector<unsigned_t, 2> kernels = {5, 5}; |
815 | llvm::SmallVector<unsigned_t, 4> pads = {0, 0, 0, 0}; |
816 | llvm::SmallVector<unsigned_t, 2> steps = {1, 1}; |
817 | unsigned width = 224; |
818 | |
819 | auto *input = MD.createPlaceholder(ElemKind::Int8QTy, {1, width, width, 3}, |
820 | 0.4, 2, "Input" , true); |
821 | |
822 | // Calculate the size and allocate the output buffer. |
823 | std::array<dim_t, 4> filterDim = {{depth, kernels[0], kernels[1], 3}}; |
824 | auto *filter = |
825 | MD.createPlaceholder(ElemKind::Int8QTy, filterDim, 3.3, 4, "F" , true); |
826 | auto *bias = |
827 | MD.createPlaceholder(ElemKind::Int32QTy, {depth}, 1.3, 5, "B" , true); |
828 | |
829 | // Calculate the size and allocate the output buffer. |
830 | auto outSz = calculateConvPoolOutputDims(width, width, kernels, steps, pads); |
831 | std::array<dim_t, 4> outDims = {{1, outSz.first, outSz.second, 16}}; |
832 | auto t = F->getParent()->uniqueType(glow::ElemKind::Int8QTy, outDims, 1.5, 6); |
833 | |
834 | auto *conv = |
835 | F->createConv("conv" , input, filter, bias, t, kernels, steps, pads, 1); |
836 | |
837 | auto s = conv->getResult().getType()->size(); |
838 | auto *fcFilter = |
839 | MD.createPlaceholder(ElemKind::Int8QTy, {s, 6}, 0.4, 2, "F" , true); |
840 | auto *fcBias = |
841 | MD.createPlaceholder(ElemKind::Int32QTy, {6}, 0.4, 2, "B" , true); |
842 | Node *O = F->createFullyConnected("fc1" , conv, fcFilter, fcBias); |
843 | PlaceholderBindings bindings; |
844 | F->createSave("ret" , O); |
845 | EE.compile(CompilationMode::Infer); |
846 | } |
847 | |
848 | TEST(Graph, quantizeDequantizeNodes) { |
849 | ExecutionEngine EE; |
850 | auto &MD = EE.getModule(); |
851 | auto F = MD.createFunction("main" ); |
852 | |
853 | auto *input = MD.createPlaceholder(ElemKind::FloatTy, {1, 3}, "Input" , true); |
854 | auto qType = F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 0.3, 5); |
855 | |
856 | auto *Q = F->createQuantize("quantize" , input, qType); |
857 | |
858 | auto transform = |
859 | F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 1.4, 3); |
860 | auto *A = F->createRescaleQuantized("rescale" , Q, transform); |
861 | |
862 | auto *D = F->createDequantize("dequantize" , A, ElemKind::FloatTy); |
863 | PlaceholderBindings bindings; |
864 | F->createSave("ret" , D); |
865 | EE.compile(CompilationMode::Infer); |
866 | } |
867 | |
868 | TEST(Graph, quantizeGather) { |
869 | ExecutionEngine EE; |
870 | auto &mod = EE.getModule(); |
871 | auto *F = mod.createFunction("main" ); |
872 | auto *input = |
873 | mod.createPlaceholder(ElemKind::Int8QTy, {2, 2}, 0.4, 2, "input" , true); |
874 | auto *indices = mod.createPlaceholder(ElemKind::Int64ITy, {1}, "index" , true); |
875 | auto *gather = F->createGather("gather" , input, indices); |
876 | PlaceholderBindings bindings; |
877 | F->createSave("ret" , gather); |
878 | EE.compile(CompilationMode::Infer); |
879 | } |
880 | |
881 | TEST(Graph, cloneTest) { |
882 | Module M; |
883 | PlaceholderBindings bindings; |
884 | |
885 | Function *F = M.createFunction("main" ); |
886 | Node *K = |
887 | M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input" , true); |
888 | Node *S = M.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select" , true); |
889 | Node *conv = F->createConv(bindings, "Conv1" , K, 16, 3, 2, 3, 1); |
890 | Node *relu = F->createRELU("Relu" , conv); |
891 | Node *SM = F->createSoftMax("SoftMax" , relu, S); |
892 | F->createSave("Save" , SM); |
893 | |
894 | auto *newConv = F->addNode(conv->clone()); |
895 | auto *newRelu = F->addNode(relu->clone()); |
896 | auto *newSM = F->addNode(SM->clone()); |
897 | |
898 | EXPECT_TRUE(newConv != conv && conv->isEqual(*newConv)); |
899 | EXPECT_TRUE(newRelu != relu && relu->isEqual(*newRelu)); |
900 | EXPECT_TRUE(newSM != SM && SM->isEqual(*newSM)); |
901 | } |
902 | |
903 | TEST(Graph, moduleTest) { |
904 | Module M; |
905 | M.createFunction("one" ); |
906 | M.createFunction("two" ); |
907 | M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V1" , true); |
908 | M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V2" , true); |
909 | EXPECT_TRUE(M.hasFunction("one" )); |
910 | EXPECT_TRUE(M.hasFunction("two" )); |
911 | EXPECT_FALSE(M.hasFunction("four" )); |
912 | M.dumpDAG(); |
913 | } |
914 | |
915 | TEST(Graph, functionDependenciesTest) { |
916 | Module M; |
917 | auto *F1 = M.createFunction("one" ); |
918 | auto *F2 = M.createFunction("two" ); |
919 | auto *V1 = |
920 | M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V1" , true); |
921 | auto *V2 = |
922 | M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V2" , true); |
923 | auto *V3 = |
924 | M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V3" , true); |
925 | M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V4" , true); |
926 | |
927 | PlaceholderBindings bindings; |
928 | auto sum = F1->createSub("1_sub_2" , V1, V2); |
929 | F1->createSave("sv" , sum, V1); |
930 | F2->createSave("sv" , V3, V2); |
931 | |
932 | EXPECT_TRUE(M.hasFunction("one" )); |
933 | EXPECT_TRUE(M.hasFunction("two" )); |
934 | EXPECT_FALSE(M.hasFunction("four" )); |
935 | M.dumpDAG(); |
936 | } |
937 | |
938 | TEST(Graph, functionCloneTest) { |
939 | Module M; |
940 | PlaceholderBindings bindings; |
941 | |
942 | auto *F = M.createFunction("main" ); |
943 | Node *K = |
944 | M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input" , true); |
945 | Node *S = M.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select" , true); |
946 | Node *conv = F->createConv(bindings, "Conv" , K, 16, 3, 2, 3, 1); |
947 | Node *relu = F->createRELU("Relu" , conv); |
948 | Node *concat = F->createConcat("concat" , {relu, relu, relu}, 0); |
949 | Node *SM = F->createSoftMax("SoftMax" , concat, S); |
950 | F->createSave("Save" , SM); |
951 | |
952 | auto *newF = F->clone("new_main" ); |
953 | |
954 | EXPECT_TRUE(newF->verify()); |
955 | |
956 | EXPECT_EQ(newF->getNodes().size(), F->getNodes().size()); |
957 | EXPECT_EQ(newF->getParent(), F->getParent()); |
958 | } |
959 | |
960 | /// Compile the module \p M inside the execution engine \p EE and then run it |
961 | /// using the provided \p bindings. Use the provided \p inputName and \p |
962 | /// outputName. |
963 | static void compileAndRun(ExecutionEngine &EE, PlaceholderBindings &bindings, |
964 | Module &M, llvm::StringRef inputName, |
965 | llvm::StringRef outputName) { |
966 | EE.compile(glow::CompilationMode::Infer); |
967 | // Allocate stprage for placeholders and initialize inputs. |
968 | bindings.allocate(M.getPlaceholderByNameSlow(inputName)) |
969 | ->getHandle() |
970 | .clear(2.0); |
971 | bindings.allocate(M.getPlaceholderByNameSlow(outputName)); |
972 | EE.run(bindings); |
973 | } |
974 | |
/// Check the module cloning functionality: clone a module, destroy the
/// original, then verify the clone is still valid and produces exactly the
/// same results the original did.
TEST(Graph, moduleCloneTest) {
  // State related to the cloned module and its execution.
  ExecutionEngine clonedEE("Interpreter");
  Module &clonedM = clonedEE.getModule();
  PlaceholderBindings clonedBindings;
  Tensor clonedResult;
  // State related to the original module and its execution.
  PlaceholderBindings originalBindings;
  Tensor originalResult;
  // Name of the placeholder holding the results of executions.
  std::string resultName;
  {
    // Define the original execution engine and module. They live only in this
    // scope, so they are destroyed before the cloned module is run below.
    ExecutionEngine originalEE("Interpreter");
    Module &originalM = originalEE.getModule();

    // Create a function.
    auto *F = originalM.createFunction("main");
    auto *input1 = originalM.createPlaceholder(ElemKind::FloatTy,
                                               {4, 10, 10, 3}, "input", true);

    auto *add = F->createAdd("add", input1, input1);
    auto *relu = F->createRELU("Relu", add);
    auto *concat = F->createConcat("concat", {relu, relu, relu}, 0);
    auto *C = originalM.createConstant(concat->getResult().getType(), "C");
    C->getPayloadMutable().getHandle().clear(1.0f);
    auto *SM = F->createAdd("add", concat, C);
    auto *SN = F->createSave("Save", SM);
    // Remember the output placeholder's name so it can be looked up again in
    // the cloned module after the original has been destroyed.
    resultName = SN->getPlaceholder()->getName().str();

    // Clone the original module into the cloned module.
    originalM.clone(&clonedM);
    // The cloned module should have the same number of types, functions,
    // constants and placeholders.
    EXPECT_EQ(originalM.getFunctions().size(), clonedM.getFunctions().size());
    EXPECT_EQ(originalM.getPlaceholders().size(),
              clonedM.getPlaceholders().size());
    EXPECT_EQ(originalM.getConstants().size(), clonedM.getConstants().size());
    EXPECT_EQ(originalM.getTypes().size(), clonedM.getTypes().size());
    // String representations of the original and cloned modules should be the
    // same.
    EXPECT_EQ(originalM.toString(), clonedM.toString());
    for (auto *originalF : originalM.getFunctions()) {
      EXPECT_EQ(originalF->toString(),
                clonedM.getFunction(originalF->getName())->toString());
    }

    // Compile and run the original module.
    compileAndRun(originalEE, originalBindings, originalM, "input", resultName);
    // Store the result of running the original module.
    originalResult.assign(originalBindings.get(
        originalBindings.getPlaceholderByNameSlow(resultName)));
    // The old module should be removed when this scope ends. Thus, if the
    // cloned module clonedM refers to any deleted nodes from the original
    // module, it would result in a dangling reference and most likely in a
    // crash.
  }
  // Check that the cloned module is still alive and valid after the original
  // module was deleted.
  EXPECT_TRUE(clonedM.verify());
  // Compile and run the cloned model.
  compileAndRun(clonedEE, clonedBindings, clonedM, "input", resultName);
  // Store the result of running the cloned module.
  clonedResult.assign(
      clonedBindings.get(clonedBindings.getPlaceholderByNameSlow(resultName)));
  // The results of execution should be exactly the same in both cases
  // (compared with zero tolerance).
  EXPECT_TRUE(originalResult.isEqual(clonedResult, 0));
}
1043 | |
1044 | TEST(Graph, cloneWithPredicates) { |
1045 | Module M; |
1046 | PlaceholderBindings bindings; |
1047 | |
1048 | auto *F = M.createFunction("main" ); |
1049 | auto *input = |
1050 | M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input" , false); |
1051 | auto *counters = |
1052 | M.createPlaceholder(ElemKind::FloatTy, {10}, "counters" , false); |
1053 | auto *reluExt = F->createRELU("reluExt" , input); |
1054 | auto *reluInt = F->createRELU("reluInt" , input); |
1055 | auto *externalPredicate = |
1056 | M.createPlaceholder(ElemKind::Int64ITy, {1}, "predicate" , false); |
1057 | auto *C10 = F->createSplat("C10" , counters->getType(), 10.0); |
1058 | auto *internalPredicate = F->createCmpLTE("lte" , C10, counters); |
1059 | |
1060 | reluExt->setPredicate(externalPredicate); |
1061 | reluInt->setPredicate(internalPredicate); |
1062 | |
1063 | auto *newF = F->clone("new_main" ); |
1064 | |
1065 | EXPECT_TRUE(newF->verify()); |
1066 | EXPECT_EQ(newF->getNodes().size(), F->getNodes().size()); |
1067 | EXPECT_EQ(newF->getParent(), F->getParent()); |
1068 | |
1069 | // Original predicates are not changed |
1070 | EXPECT_EQ(reluExt->getPredicate().getNode(), externalPredicate); |
1071 | EXPECT_EQ(reluInt->getPredicate().getNode(), internalPredicate); |
1072 | // Clone of predicate that points to a node outside the graph |
1073 | // points to the same node (predicate is shared) |
1074 | EXPECT_EQ(nodeByName(newF, "reluExt" )->getPredicate().getNode(), |
1075 | externalPredicate); |
1076 | // Clone of predicate that points to a node that belongs to the graph |
1077 | // points to the predicate clone |
1078 | EXPECT_EQ(nodeByName(newF, "reluInt" )->getPredicate().getNode(), |
1079 | nodeByName(newF, "lte" )); |
1080 | } |
1081 | |
1082 | TEST(Graph, NodeValue) { |
1083 | ExecutionEngine EE; |
1084 | auto &mod = EE.getModule(); |
1085 | Function *F = mod.createFunction("main" ); |
1086 | PlaceholderBindings bindings; |
1087 | auto *inputX = mod.createPlaceholder(ElemKind::FloatTy, {1}, "input" , true); |
1088 | bindings.allocate(inputX)->init(Tensor::InitKind::Broadcast, 3.0, |
1089 | mod.getPRNG()); |
1090 | |
1091 | NodeValue a = F->createAdd("x2" , inputX, inputX); |
1092 | a = F->createAdd("x4" , a, a); |
1093 | a = F->createAdd("x8" , a, a); |
1094 | auto *S = F->createSave("Save" , a); |
1095 | auto *res = bindings.allocate(S->getPlaceholder()); |
1096 | |
1097 | EE.compile(CompilationMode::Infer); |
1098 | |
1099 | EE.run(bindings); |
1100 | |
1101 | EXPECT_EQ(res->getHandle().raw(0), 24); |
1102 | } |
1103 | |
1104 | /// Check that by deleting one function, the variables that refernced |
1105 | /// by this function, will reduce its number of uses by one. |
1106 | TEST(Graph, deleteFunction) { |
1107 | ExecutionEngine EE; |
1108 | auto &mod = EE.getModule(); |
1109 | Function *F1 = mod.createFunction("f1" ); |
1110 | auto *inputX = mod.createPlaceholder(ElemKind::FloatTy, {1}, "input" , true); |
1111 | F1->createLog("log1" , inputX); |
1112 | Function *F2 = mod.createFunction("f2" ); |
1113 | F2->createLog("log2" , inputX); |
1114 | // We check the number of user of inputX to be 2 as only F1 and F2 are |
1115 | // using it. |
1116 | EXPECT_EQ(inputX->getNumUsers(), 2); |
1117 | // Erase this function here to see if we can see the number of user of inputX |
1118 | // reduce to 1. |
1119 | mod.eraseFunction(F1); |
1120 | EXPECT_EQ(inputX->getNumUsers(), 1); |
1121 | } |
1122 | |
/// Build a small conv -> relu -> maxpool -> fc -> relu -> softmax network in
/// which the first three nodes carry a predicate placeholder, then compile and
/// run it. There are no result assertions: the test checks that predicated
/// nodes survive compilation and execution without error.
TEST(Graph, nodesWithPredicates) {
  ExecutionEngine EE;

  // Host-side input data copied into the input placeholder before running.
  // NOTE(review): contents are whatever the Tensor constructor leaves;
  // presumably zeros — confirm if the test ever starts checking values.
  Tensor inputs(ElemKind::FloatTy, {1, 32, 32, 3});

  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");
  F->setName("interpret");
  PlaceholderBindings bindings;
  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {1, 32, 32, 3}, "input", true);
  auto *ex = mod.createPlaceholder(ElemKind::Int64ITy, {1, 1}, "exp", true);
  auto *pred =
      mod.createPlaceholder(ElemKind::Int64ITy, {1}, "predicate", false);
  bindings.allocate(input);
  bindings.allocate(ex);
  bindings.allocate(pred);

  auto *CV0 = F->createConv(bindings, "conv1", input, 16, 5, 1, 2, 1);
  auto *RL0 = F->createRELU("relu1", CV0);
  auto *MP0 = F->createMaxPool("pool1", RL0, 2, 2, 0);

  // Guard the first three nodes with the same predicate placeholder.
  CV0->setPredicate(pred);
  RL0->setPredicate(pred);
  MP0->setPredicate(pred);

  auto *FCL1 = F->createFullyConnected(bindings, "fc", MP0->getResult(), 10);
  auto *RL3 = F->createRELU("relu4", FCL1);
  auto *SM = F->createSoftMax("sm", RL3, ex);
  auto *save = F->createSave("ret", SM);
  bindings.allocate(save->getPlaceholder());

  EE.compile(CompilationMode::Infer);

  updateInputPlaceholders(bindings, {input}, {&inputs});
  EE.run(bindings);
}
1160 | |
1161 | // Return the number of ConvolutionNode after lower. |
1162 | unsigned getConvNodeSize(llvm::StringRef kind) { |
1163 | Module mod; |
1164 | Function *F = mod.createFunction("main" ); |
1165 | IRFunction M(F); |
1166 | PlaceholderBindings bindings; |
1167 | auto *input = |
1168 | mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 1, 32}, "input" , true); |
1169 | ConvolutionNode *CN = F->createConv(bindings, "conv" , input, 6, 1, 1, 0, 2); |
1170 | F->createSave("save" , CN); |
1171 | |
1172 | std::unique_ptr<Backend> backend(createBackend(kind)); |
1173 | CompilationContext cctx; |
1174 | lower(F, cctx, backend.get()); |
1175 | |
1176 | unsigned count = 0; |
1177 | for (auto &n : F->getNodes()) { |
1178 | if (n.getKind() == Kinded::Kind::ConvolutionNodeKind) { |
1179 | count++; |
1180 | } |
1181 | } |
1182 | |
1183 | if (kind == "Interpreter" ) { |
1184 | EXPECT_EQ(count, 1); |
1185 | } |
1186 | |
1187 | return count; |
1188 | } |
1189 | |
// Check the unrolling grouped convolution opt status: it is disabled for the
// Interpreter, CPU and OpenCL backends, so each of them must leave the same
// number of ConvolutionNodes after lowering as the Interpreter does.
TEST(Graph, disableUnrollingGroupConv) {
  unsigned numberOfNodesInterpreter = getConvNodeSize("Interpreter");
  // Silence unused-variable warnings when neither optional backend is built.
  (void)numberOfNodesInterpreter;

#ifdef GLOW_WITH_CPU
  unsigned numberOfNodesCPU = getConvNodeSize("CPU");
  EXPECT_EQ(numberOfNodesCPU, numberOfNodesInterpreter);
#endif // GLOW_WITH_CPU

#ifdef GLOW_WITH_OPENCL
  unsigned numberOfNodesOpenCL = getConvNodeSize("OpenCL");
  EXPECT_EQ(numberOfNodesOpenCL, numberOfNodesInterpreter);
#endif // GLOW_WITH_OPENCL
}
1206 | |
/// Check that save nodes are properly scheduled.
/// That is, they happen after the last use of the related variable.
/// In that test, the order of the creation of the nodes give a valid schedule.
TEST(Graph, schedulingOfSavesOrderProvided) {
  ExecutionEngine EE;

  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");
  auto *A = mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "A", true);
  auto *B = mod.createPlaceholder(A->getType(), "B", true);
  auto *zero = mod.createPlaceholder(A->getType(), "zero", true);

  PlaceholderBindings bindings;
  bindings.allocate(A)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
  bindings.allocate(B)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
  bindings.allocate(zero)->init(Tensor::InitKind::Broadcast, 0.0,
                                mod.getPRNG());

  auto *addAB = F->createAdd("addAB", A, B);

  auto *saveNode = F->createSave("ret", addAB);
  auto *savePH = saveNode->getPlaceholder();
  bindings.allocate(savePH);
  // Overwrite A with zeros. This save must be scheduled after addAB has read
  // A; here it is also created after addAB.
  F->createSave("resetA", zero, A);

  // Copy the value of A.
  Tensor AOrig = bindings.get(A)->clone();

  EE.compile(CompilationMode::Infer);

  EE.run(bindings);
  auto *ret = bindings.get(savePH);
  auto handleAOrig = AOrig.getHandle<>();
  auto handleB = bindings.get(B)->getHandle<>();
  auto handleRet = ret->getHandle<>();
  // The sum must have been computed from A's value before the reset.
  bool allEqual = true;
  for (unsigned row = 0; row != 3; ++row) {
    for (unsigned column = 0; column != 32; ++column) {
      allEqual &= handleAOrig.at({row, column}) + handleB.at({row, column}) ==
                  handleRet.at({row, column});
    }
  }
  // And the reset itself must have taken effect.
  EXPECT_TRUE(bindings.get(A)->isEqual(*bindings.get(zero), 0.0));
  EXPECT_TRUE(allEqual);
}
1252 | |
/// Same as schedulingOfSavesOrderProvided except the order in which the nodes
/// are added to the function don't form a valid schedule.
/// In other words, the scheduler won't get away with scheduling
/// using only the order of the nodes in the list of nodes.
TEST(Graph, schedulingOfSaves) {
  ExecutionEngine EE;
  PlaceholderBindings bindings;

  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");
  auto *A = mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "A", true);
  auto *B = mod.createPlaceholder(A->getType(), "B", true);
  auto *zero = mod.createPlaceholder(A->getType(), "zero", true);
  // Deliberately created BEFORE the add that reads A: the scheduler must still
  // run this save after that read.
  F->createSave("resetA", zero, A);

  bindings.allocate(A)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
  bindings.allocate(B)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
  bindings.allocate(zero)->init(Tensor::InitKind::Broadcast, 0.0,
                                mod.getPRNG());

  auto *addAB = F->createAdd("addAB", A, B);

  auto *saveNode = F->createSave("ret", addAB);
  bindings.allocate(saveNode->getPlaceholder());

  // Copy the value of A.
  Tensor AOrig = bindings.get(A)->clone();
  auto *ret = saveNode->getPlaceholder();
  EE.compile(CompilationMode::Infer);

  EE.run(bindings);

  auto handleAOrig = AOrig.getHandle<>();
  auto handleB = bindings.get(B)->getHandle<>();
  auto handleRet = bindings.get(ret)->getHandle<>();
  // The sum must have been computed from A's value before the reset.
  bool allEqual = true;
  for (unsigned row = 0; row != 3; ++row) {
    for (unsigned column = 0; column != 32; ++column) {
      allEqual &= handleAOrig.at({row, column}) + handleB.at({row, column}) ==
                  handleRet.at({row, column});
    }
  }
  // And the reset itself must have taken effect.
  EXPECT_TRUE(bindings.get(A)->isEqual(*bindings.get(zero), 0.0));
  EXPECT_TRUE(allEqual);
}
1298 | |
1299 | /// Check that the parent link is properly updated while tweaking |
1300 | /// nodes and their function. |
1301 | TEST(Graph, parentLink) { |
1302 | ExecutionEngine EE; |
1303 | |
1304 | auto &mod = EE.getModule(); |
1305 | Constant *V = |
1306 | new Constant("V" , mod.uniqueType(ElemKind::FloatTy, {3, 32}), ANY_LAYOUT); |
1307 | |
1308 | // Variables don't belong to any function... |
1309 | EXPECT_EQ(V->getParent(), nullptr); |
1310 | // Even when we create them from a module... |
1311 | Constant *V2 = mod.createConstant(V->getType(), "V2" ); |
1312 | EXPECT_EQ(V2->getParent(), nullptr); |
1313 | |
1314 | Function *F = mod.createFunction("main" ); |
1315 | |
1316 | // Nodes created with function helper belong to the related function. |
1317 | auto *addNode = F->createAdd("addnode" , V, V2); |
1318 | EXPECT_EQ(addNode->getParent(), F); |
1319 | |
1320 | // Nodes created directly don't belong to any function. |
1321 | auto *addNode2 = new AddNode("addnode2" , V->getType(), addNode, addNode); |
1322 | EXPECT_EQ(addNode2->getParent(), nullptr); |
1323 | |
1324 | // Nodes added to a function belong to that function. |
1325 | F->addNode(addNode2); |
1326 | EXPECT_EQ(addNode2->getParent(), F); |
1327 | |
1328 | // Cloned nodes don't belong to anything. |
1329 | auto *clonedAddNode = addNode->clone(); |
1330 | EXPECT_EQ(clonedAddNode->getParent(), nullptr); |
1331 | |
1332 | // Check that the setter properly sets things. |
1333 | clonedAddNode->setParent(F); |
1334 | EXPECT_EQ(clonedAddNode->getParent(), F); |
1335 | clonedAddNode->setParent(nullptr); |
1336 | EXPECT_EQ(clonedAddNode->getParent(), nullptr); |
1337 | |
1338 | // Add the cloned node to F so that the memory is properly |
1339 | // cleaned at the end of the test. |
1340 | F->addNode(clonedAddNode); |
1341 | EXPECT_EQ(clonedAddNode->getParent(), F); |
1342 | |
1343 | delete V; |
1344 | } |
1345 | |
1346 | /// Check that verification can detect that Storage nodes are being used by |
1347 | /// Functions in a Module that doesn't own the Storage nodes. |
1348 | TEST(Graph, moduleLink) { |
1349 | ExecutionEngine EEA, EEB; |
1350 | |
1351 | auto &modA = EEA.getModule(); |
1352 | auto &modB = EEB.getModule(); |
1353 | |
1354 | auto *FA = modA.createFunction("FA" ); |
1355 | auto *FB = modB.createFunction("FB" ); |
1356 | |
1357 | auto *C = modA.createConstant(ElemKind::FloatTy, {1}, "C" ); |
1358 | auto *P = modA.createPlaceholder(ElemKind::FloatTy, {1}, "P" , false); |
1359 | |
1360 | auto *AA = FA->createAdd("AA" , C, P); |
1361 | FA->createSave("SA" , AA); |
1362 | |
1363 | // These nodes use Storage nodes that reside in modA |
1364 | auto *AB = FB->createAdd("AB" , C, P); |
1365 | FB->createSave("SB" , AB); |
1366 | |
1367 | EXPECT_TRUE(modA.verify()); |
1368 | EXPECT_FALSE( |
1369 | modB.verify()); // Module::verify calls Function::verify on all functions |
1370 | // within the module, so this should fail |
1371 | } |
1372 | |
1373 | /// Check that Cmp nodes are created with proper output types. |
1374 | TEST(Graph, cmpOutputTypes) { |
1375 | ExecutionEngine EE; |
1376 | |
1377 | auto &mod = EE.getModule(); |
1378 | Function *F = mod.createFunction("main" ); |
1379 | // Define two different quntized types. |
1380 | auto qType1 = F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 0.3, 5); |
1381 | auto qType2 = F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 0.4, 5); |
1382 | // Define two variables of quantized types. |
1383 | auto *qv1 = mod.createPlaceholder(qType1, "V1" , true); |
1384 | auto *qv2 = mod.createPlaceholder(qType2, "V2" , true); |
1385 | // Create cmp nodes using quantized inputs. |
1386 | auto *cmpNode1 = F->createCmpEQ("cmpeq" , qv1, qv2); |
1387 | auto *cmpNode2 = F->createCmpLTE("cmplte" , qv1, qv2); |
1388 | // Check that the output type of cmp nodes is BoolKind. |
1389 | EXPECT_TRUE(cmpNode1->getResult().getElementType() == ElemKind::BoolTy); |
1390 | EXPECT_TRUE(cmpNode2->getResult().getElementType() == ElemKind::BoolTy); |
1391 | |
1392 | // Define a non-quantized type. |
1393 | auto nqType3 = F->getParent()->uniqueType(ElemKind::FloatTy, {1, 3}); |
1394 | // Define two variables of non-quantized types. |
1395 | auto *nqv3 = mod.createPlaceholder(nqType3, "V3" , true); |
1396 | auto *nqv4 = mod.createPlaceholder(nqType3, "V4" , true); |
1397 | // Create cmp nodes using non-quantized inputs. |
1398 | auto *cmpNode3 = F->createCmpEQ("cmpeq" , nqv3, nqv4); |
1399 | auto *cmpNode4 = F->createCmpLTE("cmplte" , nqv3, nqv4); |
1400 | // Check that the output type of cmp nodes is BoolKind. |
1401 | EXPECT_TRUE(cmpNode3->getResult().getElementType() == ElemKind::BoolTy); |
1402 | EXPECT_TRUE(cmpNode4->getResult().getElementType() == ElemKind::BoolTy); |
1403 | } |
1404 | |
1405 | /// Check that the users of value are equal to expectedUsers. |
1406 | static bool |
1407 | hasAllTheseUses(const llvm::SmallPtrSetImpl<const Node *> &expectedUsers, |
1408 | const NodeValue &value) { |
1409 | llvm::SmallPtrSet<const Node *, 4> uses; |
1410 | for (const NodeUse &use : value.getUsers()) { |
1411 | const Node *user = use.getUser(); |
1412 | if (!expectedUsers.count(user)) { |
1413 | // We found a user that wasn't on the list. |
1414 | return false; |
1415 | } |
1416 | uses.insert(user); |
1417 | } |
1418 | return expectedUsers.size() == uses.size(); |
1419 | } |
1420 | |
1421 | /// Check that our uses lists are correct for nodes with multiple results. |
1422 | TEST(Graph, usesListsWithSeveralResult) { |
1423 | ExecutionEngine EE; |
1424 | PlaceholderBindings bindings; |
1425 | |
1426 | auto &mod = EE.getModule(); |
1427 | Function *F = mod.createFunction("main" ); |
1428 | auto *input = |
1429 | mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "input" , true); |
1430 | auto *topK = F->createTopK("topK" , input, 12); |
1431 | EXPECT_EQ(topK->getNumUsers(), 0); |
1432 | |
1433 | NodeValue values = topK->getValues(); |
1434 | NodeValue indices = topK->getIndices(); |
1435 | llvm::SmallPtrSet<const Node *, 4> savesOfValues; |
1436 | llvm::SmallPtrSet<const Node *, 4> savesOfIndices; |
1437 | |
1438 | EXPECT_EQ(indices.getNumUsers(), 0); |
1439 | EXPECT_EQ(values.getNumUsers(), 0); |
1440 | |
1441 | EXPECT_FALSE(indices.hasOneUse()); |
1442 | EXPECT_FALSE(values.hasOneUse()); |
1443 | |
1444 | EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices)); |
1445 | EXPECT_TRUE(hasAllTheseUses(savesOfValues, values)); |
1446 | |
1447 | // Now add a user to only one result of the topK node. |
1448 | savesOfValues.insert(F->createSave("saveValues1" , values)); |
1449 | |
1450 | // The whole node should inherit the uses of each of its results. |
1451 | EXPECT_EQ(topK->getNumUsers(), 1); |
1452 | |
1453 | // Each result should have its own use list. |
1454 | EXPECT_EQ(indices.getNumUsers(), 0); |
1455 | EXPECT_EQ(values.getNumUsers(), 1); |
1456 | |
1457 | EXPECT_FALSE(indices.hasOneUse()); |
1458 | EXPECT_TRUE(values.hasOneUse()); |
1459 | |
1460 | EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices)); |
1461 | EXPECT_TRUE(hasAllTheseUses(savesOfValues, values)); |
1462 | |
1463 | // Add a user to the other result of the topK node. |
1464 | savesOfIndices.insert(F->createSave("saveIndices1" , indices)); |
1465 | |
1466 | // The whole node should inherit the uses of each of its results. |
1467 | EXPECT_EQ(topK->getNumUsers(), 2); |
1468 | |
1469 | // Each result should have its own use list. |
1470 | EXPECT_EQ(indices.getNumUsers(), 1); |
1471 | EXPECT_EQ(values.getNumUsers(), 1); |
1472 | |
1473 | EXPECT_TRUE(indices.hasOneUse()); |
1474 | EXPECT_TRUE(values.hasOneUse()); |
1475 | |
1476 | EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices)); |
1477 | EXPECT_TRUE(hasAllTheseUses(savesOfValues, values)); |
1478 | |
1479 | // Add a couple more users of values and indices. |
1480 | // Interleaves the insertions in the uses list for both values and indices. |
1481 | savesOfValues.insert(F->createSave("saveValues2" , values)); |
1482 | savesOfValues.insert(F->createSave("saveValues3" , values)); |
1483 | savesOfIndices.insert(F->createSave("saveIndices2" , indices)); |
1484 | |
1485 | EXPECT_EQ(topK->getNumUsers(), 5); |
1486 | |
1487 | EXPECT_EQ(indices.getNumUsers(), 2); |
1488 | EXPECT_EQ(values.getNumUsers(), 3); |
1489 | |
1490 | EXPECT_FALSE(indices.hasOneUse()); |
1491 | EXPECT_FALSE(values.hasOneUse()); |
1492 | |
1493 | EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices)); |
1494 | EXPECT_TRUE(hasAllTheseUses(savesOfValues, values)); |
1495 | } |
1496 | |
1497 | /// Check that our uses lists are correct when accessed through |
1498 | /// NodeValue. |
1499 | TEST(Graph, usesListsThroughNodeValues) { |
1500 | ExecutionEngine EE; |
1501 | PlaceholderBindings bindings; |
1502 | |
1503 | auto &mod = EE.getModule(); |
1504 | Function *F = mod.createFunction("main" ); |
1505 | auto *input = |
1506 | mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "input" , true); |
1507 | auto *reLU = F->createRELU("reLU" , input); |
1508 | EXPECT_EQ(reLU->getNumUsers(), 0); |
1509 | |
1510 | NodeValue values = reLU->getResult(); |
1511 | llvm::SmallPtrSet<const Node *, 4> savesOfValues; |
1512 | |
1513 | EXPECT_EQ(values.getNumUsers(), 0); |
1514 | |
1515 | EXPECT_FALSE(values.hasOneUse()); |
1516 | |
1517 | EXPECT_TRUE(hasAllTheseUses(savesOfValues, values)); |
1518 | |
1519 | // Now add a user to only one result of the reLU node. |
1520 | savesOfValues.insert(F->createSave("saveValues1" , values)); |
1521 | |
1522 | // The whole node should inherit the uses of each of its results. |
1523 | EXPECT_EQ(reLU->getNumUsers(), 1); |
1524 | |
1525 | // The NodeValue should match. |
1526 | EXPECT_EQ(values.getNumUsers(), 1); |
1527 | EXPECT_TRUE(values.hasOneUse()); |
1528 | EXPECT_TRUE(hasAllTheseUses(savesOfValues, values)); |
1529 | |
1530 | // Add one more use. |
1531 | savesOfValues.insert(F->createSave("saveValues2" , values)); |
1532 | |
1533 | // The whole node should inherit the uses of each of its results. |
1534 | EXPECT_EQ(reLU->getNumUsers(), 2); |
1535 | |
1536 | EXPECT_EQ(values.getNumUsers(), 2); |
1537 | EXPECT_FALSE(values.hasOneUse()); |
1538 | EXPECT_TRUE(hasAllTheseUses(savesOfValues, values)); |
1539 | |
1540 | // Add a couple more users. |
1541 | savesOfValues.insert(F->createSave("saveValues3" , values)); |
1542 | savesOfValues.insert(F->createSave("saveValues4" , values)); |
1543 | |
1544 | EXPECT_EQ(reLU->getNumUsers(), 4); |
1545 | |
1546 | EXPECT_EQ(values.getNumUsers(), 4); |
1547 | EXPECT_FALSE(values.hasOneUse()); |
1548 | EXPECT_TRUE(hasAllTheseUses(savesOfValues, values)); |
1549 | } |
1550 | |
/// Verify that the pre-order visitor works correctly.
/// The Function holds two groups of nodes: mul1/mul2 -> mul3 saved by ret1,
/// and add1 -> add2 -> add3 where ret2 saves add2 while add3 has no users.
/// The expected sequence visits a node before its operands, lists every node
/// (and each save's output Placeholder) exactly once, and reaches the roots in
/// the same order they are created here (ret1, add3, ret2).
/// NOTE(review): the expectations appear tied to node creation order; keep the
/// construction sequence below unchanged when editing this test.
TEST(Graph, PreOrderTest) {
  Module M;
  PlaceholderBindings bindings; // Not used by this test.
  auto *F = M.createFunction("main" );

  auto *input1 =
      M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input1" , true);
  auto *input2 =
      M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input2" , true);
  // First group: zero is an operand of both mul1 and mul2.
  SplatNode *zero = F->createSplat("zero" , input1->getType(), 0.);
  MulNode *mul1 = F->createMul("mul1" , zero, input1);
  MulNode *mul2 = F->createMul("mul2" , zero, input2);
  MulNode *mul3 = F->createMul("mul3" , mul1, mul2);
  SaveNode *ret1 = F->createSave("ret1" , mul3);

  // Second group: ret2 saves add2, so add3 is a root without any users.
  SplatNode *one = F->createSplat("one" , input2->getType(), 1.0);
  AddNode *add1 = F->createAdd("add1" , input2, one);
  AddNode *add2 = F->createAdd("add2" , add1, one);
  AddNode *add3 = F->createAdd("add3" , add2, one);
  SaveNode *ret2 = F->createSave("ret2" , add2);

  GraphPreOrderVisitor visitor(*F);
  auto order = visitor.getPreOrder();

  // 10 nodes created above + input1, input2 and the two saves' outputs.
  ASSERT_EQ(order.size(), 14);
  // Root ret1, then its operand tree; zero is not revisited under mul2.
  EXPECT_EQ(order[0], ret1);
  EXPECT_EQ(order[1], mul3);
  EXPECT_EQ(order[2], mul1);
  EXPECT_EQ(order[3], zero);
  EXPECT_EQ(order[4], input1);
  EXPECT_EQ(order[5], mul2);
  EXPECT_EQ(order[6], input2);
  EXPECT_EQ(order[7], ret1->getOutput());
  // Root add3 and its operand tree; input2 was already visited above.
  EXPECT_EQ(order[8], add3);
  EXPECT_EQ(order[9], add2);
  EXPECT_EQ(order[10], add1);
  EXPECT_EQ(order[11], one);
  // Root ret2: add2 was already visited, so only the save and its output
  // remain.
  EXPECT_EQ(order[12], ret2);
  EXPECT_EQ(order[13], ret2->getOutput());
}
1592 | |
/// Verify that the post-order visitor works correctly.
/// Uses the same graph as PreOrderTest: mul1/mul2 -> mul3 saved by ret1, and
/// add1 -> add2 -> add3 where ret2 saves add2 while add3 has no users.
/// The expected sequence visits every operand before its user, lists every
/// node (and each save's output Placeholder) exactly once, and finishes the
/// roots in the order they are created here (ret1, add3, ret2).
/// NOTE(review): the expectations appear tied to node creation order; keep the
/// construction sequence below unchanged when editing this test.
TEST(Graph, PostOrderTest) {
  Module M;
  PlaceholderBindings bindings; // Not used by this test.
  auto *F = M.createFunction("main" );

  auto *input1 =
      M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input1" , true);
  auto *input2 =
      M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input2" , true);
  // First group: zero is an operand of both mul1 and mul2.
  SplatNode *zero = F->createSplat("zero" , input1->getType(), 0.);
  MulNode *mul1 = F->createMul("mul1" , zero, input1);
  MulNode *mul2 = F->createMul("mul2" , zero, input2);
  MulNode *mul3 = F->createMul("mul3" , mul1, mul2);
  SaveNode *ret1 = F->createSave("ret1" , mul3);

  // Second group: ret2 saves add2, so add3 is a root without any users.
  SplatNode *one = F->createSplat("one" , input2->getType(), 1.0);
  AddNode *add1 = F->createAdd("add1" , input2, one);
  AddNode *add2 = F->createAdd("add2" , add1, one);
  AddNode *add3 = F->createAdd("add3" , add2, one);
  SaveNode *ret2 = F->createSave("ret2" , add2);

  GraphPostOrderVisitor visitor(*F);
  auto order = visitor.getPostOrder();

  // 10 nodes created above + input1, input2 and the two saves' outputs.
  ASSERT_EQ(order.size(), 14);
  // ret1's operand tree (zero appears once), its output, then ret1 itself.
  EXPECT_EQ(order[0], zero);
  EXPECT_EQ(order[1], input1);
  EXPECT_EQ(order[2], mul1);
  EXPECT_EQ(order[3], input2);
  EXPECT_EQ(order[4], mul2);
  EXPECT_EQ(order[5], mul3);
  EXPECT_EQ(order[6], ret1->getOutput());
  EXPECT_EQ(order[7], ret1);
  // add3's operand tree; input2 was already visited above.
  EXPECT_EQ(order[8], one);
  EXPECT_EQ(order[9], add1);
  EXPECT_EQ(order[10], add2);
  EXPECT_EQ(order[11], add3);
  // ret2: add2 was already visited, so only its output and the save remain.
  EXPECT_EQ(order[12], ret2->getOutput());
  EXPECT_EQ(order[13], ret2);
}
1634 | |
1635 | TEST(Graph, placeholder) { |
1636 | Module MD; |
1637 | PlaceholderBindings bindings; |
1638 | Function *F = MD.createFunction("F" ); |
1639 | IRFunction M(F); |
1640 | Node *K = |
1641 | MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input" , false); |
1642 | Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select" , false); |
1643 | |
1644 | K = F->createFullyConnected(bindings, "FC" , K, 10); |
1645 | K = F->createRELU("Relu" , K); |
1646 | K = F->createSoftMax("SoftMax" , K, S); |
1647 | F->createSave("Save" , K); |
1648 | } |
1649 | |
1650 | /// Check that the setType API allows to change the type of the |
1651 | /// related result and only the related result. |
1652 | TEST(Graph, setType) { |
1653 | Module M; |
1654 | auto *F = M.createFunction("main" ); |
1655 | |
1656 | const dim_t inputDims[] = {4, 10}; |
1657 | const dim_t top5Dims[] = {4, 5}; |
1658 | auto *input = |
1659 | M.createPlaceholder(ElemKind::FloatTy, inputDims, "input" , true); |
1660 | TopKNode *topK = F->createTopK("add" , input, 5); |
1661 | TypeRef origTopKRes0 = M.uniqueType(ElemKind::FloatTy, top5Dims); |
1662 | TypeRef origTopKRes1 = M.uniqueType(ElemKind::Int64ITy, top5Dims); |
1663 | |
1664 | EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), origTopKRes0); |
1665 | EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), origTopKRes1); |
1666 | |
1667 | // Modify the type of result 0 and make sure type 1 is not |
1668 | // affected. Similarly the input shouldn't be affected. |
1669 | TypeRef inputTy = M.uniqueType(ElemKind::FloatTy, inputDims); |
1670 | TypeRef topKRes0 = M.uniqueType(ElemKind::Float16Ty, top5Dims); |
1671 | topK->setType(TopKNode::ValuesIdx, topKRes0); |
1672 | EXPECT_EQ(input->getType(), inputTy); |
1673 | EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), topKRes0); |
1674 | EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), origTopKRes1); |
1675 | |
1676 | // Make sure the NodeValue API works the same way |
1677 | // as the Node::setType API. |
1678 | NodeValue valRes1 = topK->getNthResult(TopKNode::IndicesIdx); |
1679 | valRes1.setType(topKRes0); |
1680 | EXPECT_EQ(input->getType(), inputTy); |
1681 | EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), topKRes0); |
1682 | EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), topKRes0); |
1683 | EXPECT_EQ(valRes1.getType(), topKRes0); |
1684 | |
1685 | // Now restore sane types. |
1686 | NodeValue valRes0 = topK->getNthResult(TopKNode::ValuesIdx); |
1687 | valRes0.setType(origTopKRes0); |
1688 | topK->setType(TopKNode::IndicesIdx, origTopKRes1); |
1689 | EXPECT_EQ(input->getType(), inputTy); |
1690 | EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), origTopKRes0); |
1691 | EXPECT_EQ(valRes0.getType(), origTopKRes0); |
1692 | EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), origTopKRes1); |
1693 | EXPECT_EQ(valRes1.getType(), origTopKRes1); |
1694 | } |
1695 | |
1696 | /// Check that we fixed the bug with Function::eraseNode. This method used to |
1697 | /// erase a node that was equal to the node we wanted to delete, which may be |
1698 | /// two different entities. |
1699 | /// To see this bug in action, we create a bunch of nodes with the same value. |
1700 | /// Then we erase them in reserve order. This reserve ordering was actually |
1701 | /// freeing the node in the original order, thus at some point we try to delete |
1702 | /// a node that has already deleted and an assert (debug mode) or segmentation |
1703 | /// fault (release would occur). |
1704 | /// Note: Which node is actually freed depend on the implementation of |
1705 | /// std::find, thus we cannot really predict when the bug occurs. |
1706 | TEST(Graph, eraseNodeBug) { |
1707 | Module M; |
1708 | auto *F = M.createFunction("main" ); |
1709 | |
1710 | auto *input = M.createPlaceholder(ElemKind::FloatTy, {3, 2}, "input" , true); |
1711 | std::vector<Node *> ReLUs; |
1712 | // Create a bunch of ReLUs. |
1713 | for (unsigned idx = 0; idx != 5; ++idx) { |
1714 | ReLUs.push_back(F->createRELU("relu" , input)); |
1715 | } |
1716 | // Check that we can erase all the nodes. |
1717 | for (int idx = 4; idx != -1; --idx) { |
1718 | F->eraseNode(ReLUs[idx]); |
1719 | } |
1720 | EXPECT_EQ(F->getNodes().size(), 0); |
1721 | } |
1722 | |
1723 | /// Verify that two Nodes with different predicates but the same inputs are not |
1724 | /// considered equal. |
1725 | TEST(Graph, nodeEqualityWithDifferentPredicates) { |
1726 | Module M; |
1727 | auto *F = M.createFunction("main" ); |
1728 | |
1729 | Node *in = M.createPlaceholder(ElemKind::FloatTy, {5}, "in" , false); |
1730 | Node *pred1 = M.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1731 | Node *pred2 = M.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1732 | |
1733 | Node *RN1 = F->createRELU("relu1" , in); |
1734 | RN1->setPredicate(pred1); |
1735 | |
1736 | Node *RN2 = F->createRELU("relu2" , in); |
1737 | RN2->setPredicate(pred2); |
1738 | |
1739 | EXPECT_FALSE(RN1->isEqual(*RN2)); |
1740 | } |
1741 | |
1742 | /// Check that verify doesn't allow for multiple writers to the same node. |
1743 | TEST(Graph, verifyOneWriter) { |
1744 | Module M; |
1745 | auto *F = M.createFunction("main" ); |
1746 | |
1747 | auto *input = M.createPlaceholder(ElemKind::FloatTy, {5}, "input" , false); |
1748 | auto *output = M.createPlaceholder(ElemKind::FloatTy, {5}, "output" , false); |
1749 | F->createSave("Save1" , input, output); |
1750 | F->createSave("Save2" , input, output); |
1751 | |
1752 | EXPECT_FALSE(M.verify()); |
1753 | } |
1754 | |
1755 | /// Check that verify doesn't allow for Constants to be written to. Note that |
1756 | /// createSave() cannot do this as the API only accepts Placeholders to write |
1757 | /// to, however it could happen during graph transformations, e.g. via |
1758 | /// replaceAllUsesOfWith() as shown here. |
1759 | TEST(Graph, verifyConstantNoWriters) { |
1760 | Module M; |
1761 | auto *F = M.createFunction("main" ); |
1762 | |
1763 | auto *input = M.createPlaceholder(ElemKind::FloatTy, {5}, "input" , false); |
1764 | auto *outputPH = M.createPlaceholder(ElemKind::FloatTy, {5}, "outPH" , false); |
1765 | F->createSave("save" , input, outputPH); |
1766 | |
1767 | // Replace the output Placeholder with a Constant. This should fail |
1768 | // verification. |
1769 | auto *outputC = M.createConstant(ElemKind::FloatTy, {5}, "outC" ); |
1770 | NodeValue(outputPH).replaceAllUsesOfWith(outputC); |
1771 | |
1772 | EXPECT_FALSE(M.verify()); |
1773 | } |
1774 | |
1775 | TEST(Graph, typeUnsafeReplaceAllUsesOfWith) { |
1776 | Module M; |
1777 | auto *F = M.createFunction("main" ); |
1778 | |
1779 | auto *LHS = M.createPlaceholder(ElemKind::FloatTy, {3, 4}, "A" , false); |
1780 | auto *RHS = M.createPlaceholder(ElemKind::FloatTy, {4, 5}, "B" , false); |
1781 | auto *FC = F->createMatMul("fc" , LHS, RHS); |
1782 | F->createSave("save" , FC); |
1783 | |
1784 | auto newLHS = M.createPlaceholder(ElemKind::FloatTy, {10, 10}, "A" , false); |
1785 | LHS->getOutput().typeUnsafeReplaceAllUsesOfWith(newLHS); |
1786 | } |
1787 | |
1788 | /// Check that the verifier will complain if a constant and its |
1789 | /// underlying tensor have mismatching types. |
1790 | /// Here the constant is updated but not the tensor. |
1791 | TEST(Graph, verifyConstantTensorTypeMatchesConstantTypeChanged) { |
1792 | Module M; |
1793 | |
1794 | auto *input = M.createConstant(ElemKind::FloatTy, {5}, "input" ); |
1795 | // Fresh constant should verify just fine. |
1796 | EXPECT_TRUE(input->verify()); |
1797 | |
1798 | input->setType(Storage::OutputIdx, M.uniqueType(ElemKind::Float16Ty, {5})); |
1799 | |
1800 | EXPECT_FALSE(input->verify()); |
1801 | } |
1802 | |
1803 | /// Check that the verifier will complain if a constant and its |
1804 | /// underlying tensor have mismatching types. |
1805 | /// Here the tensor is updated but not the constant. |
1806 | TEST(Graph, verifyConstantTensorTypeMatchesTensorTypeChanged) { |
1807 | Module M; |
1808 | |
1809 | auto *input = M.createConstant(ElemKind::FloatTy, {5}, "input" ); |
1810 | // Fresh constant should verify just fine. |
1811 | EXPECT_TRUE(input->verify()); |
1812 | input->getPayloadMutable().convertToType(ElemKind::Float16Ty); |
1813 | |
1814 | EXPECT_FALSE(input->verify()); |
1815 | } |
1816 | |
/// Check that Constants backed by unowned Tensors are in fact unowned until
/// a mutable reference to their payload is obtained at which point the backing
/// Tensor is copied and becomes owned.
/// The statement order matters: the copy is expected to happen inside
/// unownedC->getPayloadMutable(), so checks before that call expect shared
/// memory and checks after it expect independent memory.
TEST(Graph, verifyConstantWithUnownedTensorCopiesOnWrite) {
  Module M;

  Tensor originalT(ElemKind::FloatTy, {3});
  // unownedT refers to originalT's memory instead of allocating its own.
  Tensor unownedT = originalT.getUnowned({3});

  auto originalH = originalT.getHandle();

  // Fill the shared memory with 0, 1, 2.
  for (size_t i = 0; i < originalT.size(); i++) {
    originalH.raw(i) = i;
  }

  // Both Tensors should have the same underlying memory because unownedT shares
  // originalT's memory.
  EXPECT_EQ(originalT.getUnsafePtr(), unownedT.getUnsafePtr());

  // Hand both Tensors over to Constants.
  Constant *originalC = M.createConstant("original" , std::move(originalT));
  Constant *unownedC = M.createConstant("unowned" , std::move(unownedT));

  // Read-only access to both payloads; per the header comment, only a mutable
  // reference should trigger the copy.
  const Tensor &originalCT = originalC->getPayload();
  const Tensor &unownedCT = unownedC->getPayload();

  const auto originalCTH = originalCT.getHandle();
  const auto unownedCTH = unownedCT.getHandle();

  ASSERT_EQ(originalCTH.size(), unownedCTH.size());

  // Both Constants should have the same values because their Tensors have the
  // same underlying memory.
  for (size_t i = 0; i < originalCTH.size(); i++) {
    EXPECT_EQ(i, originalCTH.raw(i));
    EXPECT_EQ(i, unownedCTH.raw(i));
  }

  // Mutable access to the Constant that owns the memory.
  Tensor &originalCTM = originalC->getPayloadMutable();
  auto originalCTMH = originalCTM.getHandle();

  // Bump up the value in the original Constant, this change should be
  // reflected in the unowned Constant as well.
  for (size_t i = 0; i < originalCTMH.size(); i++) {
    originalCTMH.raw(i) += 1;
  }

  // After changing the values in the original Constant, we should see an update
  // in the values of the unowned Constant because they share the same
  // underlying memory.
  for (size_t i = 0; i < unownedCTH.size(); i++) {
    EXPECT_EQ(unownedCTH.raw(i), i + 1);
  }

  // This is the call expected to copy the unowned Constant's payload.
  Tensor &unownedCTM = unownedC->getPayloadMutable();
  auto unownedCTMH = unownedCTM.getHandle();

  ASSERT_EQ(originalCTH.size(), unownedCTMH.size());

  // After getting a mutable reference to the unowned Constant's payload, the
  // underlying memory should have been copied but should still contain the same
  // values as it did previously at this point.
  EXPECT_NE(unownedCTM.getUnsafePtr(), originalCT.getUnsafePtr());
  for (size_t i = 0; i < unownedCTMH.size(); i++) {
    EXPECT_EQ(unownedCTMH.raw(i), i + 1);
  }

  // Bump up the value in the original Constant again, this change should not be
  // reflected in the unowned Constant now because at this point, after a
  // mutable reference to its payload has been obtained, it should have its own
  // memory.
  for (size_t i = 0; i < originalCTMH.size(); i++) {
    originalCTMH.raw(i) += 1;
  }

  // Now that the unowned Constant's payload has been obtained as mutable, it
  // should have been copied and thus have its own memory and changes to the
  // original constant should not be reflected in the unowned Constant.
  for (size_t i = 0; i < unownedCTMH.size(); i++) {
    EXPECT_EQ(unownedCTMH.raw(i), i + 1);
  }
}
1898 | |
1899 | /// Check that hooking an intermediate node works. |
1900 | TEST(Graph, hookTest) { |
1901 | Module mod; |
1902 | auto *F = mod.createFunction("main" ); |
1903 | auto *in = mod.createPlaceholder(ElemKind::FloatTy, {1}, "in" , false); |
1904 | auto *relu1 = F->createRELU("relu1" , in); |
1905 | auto *relu2 = F->createRELU("relu2" , relu1); |
1906 | F->createSave("save" , relu2); |
1907 | EXPECT_EQ(F->getNodes().size(), 3); |
1908 | EXPECT_EQ(mod.getPlaceholders().size(), 2); |
1909 | |
1910 | // Hook the first relu and verify that the hooked graph looks right. |
1911 | auto hooked = glow::hookNode(F, relu1); |
1912 | auto const &nodes = hooked.function->getNodes(); |
1913 | ASSERT_EQ(mod.getPlaceholders().size(), 3); |
1914 | ASSERT_EQ(nodes.size(), 2); |
1915 | auto const *hookSave = *hooked.outputSaves.begin(); |
1916 | ASSERT_TRUE(hookSave); |
1917 | auto *inp = llvm::dyn_cast<ReluNode>(hookSave->getInput()); |
1918 | ASSERT_TRUE(inp); |
1919 | auto *ph = llvm::dyn_cast<Placeholder>(inp->getInput()); |
1920 | ASSERT_TRUE(ph); |
1921 | ASSERT_EQ(ph, in); |
1922 | } |
1923 | |
1924 | /// Check that getConstantsSize returns the correct size of constants. |
1925 | TEST(Graph, moduleSize) { |
1926 | Module mod; |
1927 | |
1928 | EXPECT_EQ(mod.getConstantsSize(), 0); |
1929 | |
1930 | auto *cons1 = mod.createConstant(ElemKind::FloatTy, {1}, "var" ); |
1931 | EXPECT_EQ(mod.getConstantsSize(), sizeof(float) * cons1->getPayload().size()); |
1932 | |
1933 | auto *cons2 = mod.createConstant(ElemKind::FloatTy, {1, 32, 32, 16}, "var2" ); |
1934 | EXPECT_EQ(mod.getConstantsSize(), |
1935 | sizeof(float) + sizeof(float) * cons2->getPayload().size()); |
1936 | } |
1937 | |
1938 | /// Check that getDataSize() returns the correct size of backing tensors. |
1939 | TEST(Graph, contextSize) { |
1940 | Module mod; |
1941 | PlaceholderBindings bindings; |
1942 | |
1943 | Placeholder *PH = |
1944 | mod.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input" , true); |
1945 | |
1946 | EXPECT_EQ(bindings.getDataSize(), 0); |
1947 | bindings.allocate(PH); |
1948 | EXPECT_EQ(bindings.get(PH)->size(), 4 * 320 * 200 * 3); |
1949 | EXPECT_EQ(bindings.getDataSize(), sizeof(float) * bindings.get(PH)->size()); |
1950 | } |
1951 | |
1952 | /// Check that clones of the context are distinct and share no references back |
1953 | /// to the original object. |
1954 | TEST(Graph, clonePlaceholderBindings) { |
1955 | Module mod; |
1956 | |
1957 | Placeholder *PH1 = |
1958 | mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "PH1" , false); |
1959 | |
1960 | PlaceholderBindings bindings1; |
1961 | bindings1.allocate(PH1); |
1962 | |
1963 | PlaceholderBindings bindings2 = bindings1.clone(); |
1964 | |
1965 | Tensor *t1 = bindings1.get(PH1); |
1966 | Tensor *t2 = bindings2.get(PH1); |
1967 | |
1968 | EXPECT_NE(t1, nullptr); |
1969 | EXPECT_NE(t2, nullptr); |
1970 | EXPECT_NE(t1, t2); |
1971 | |
1972 | // The new PlaceholderBindings has no references back, and changing it does |
1973 | // not affect bindings1 |
1974 | Placeholder *PH2 = |
1975 | mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "PH2" , false); |
1976 | |
1977 | bindings2.allocate(PH2); |
1978 | // now exists in bindings1 but not bindings2 |
1979 | EXPECT_EQ(bindings1.get(PH2), nullptr); |
1980 | EXPECT_NE(bindings2.get(PH2), nullptr); |
1981 | |
1982 | // Likewise changing bindings1 does not affect bindings2 |
1983 | bindings1.clear(); |
1984 | EXPECT_EQ(bindings1.count(PH1), 0); |
1985 | EXPECT_EQ(bindings2.count(PH1), 1); |
1986 | |
1987 | // Adds are distinct |
1988 | Placeholder *PH3 = |
1989 | mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "PH3" , false); |
1990 | bindings1.allocate(PH3); |
1991 | bindings2.allocate(PH3); |
1992 | EXPECT_NE(bindings1.get(PH3), nullptr); |
1993 | EXPECT_NE(bindings2.get(PH3), nullptr); |
1994 | EXPECT_NE(bindings1.get(PH3), bindings2.get(PH3)); |
1995 | } |
1996 | |
1997 | /// Check that running a function multiple times on cloned PlaceholderBindingss |
1998 | /// have distinct outputs. |
1999 | TEST(Graph, clonePlaceholderBindingsRuns) { |
2000 | ExecutionEngine EE; |
2001 | PseudoRNG PRNG; |
2002 | |
2003 | Tensor inputs(ElemKind::FloatTy, {1, 32, 32, 3}); |
2004 | auto &mod = EE.getModule(); |
2005 | Function *F = mod.createFunction("main" ); |
2006 | PlaceholderBindings bindings; |
2007 | auto *input = |
2008 | mod.createPlaceholder(ElemKind::FloatTy, {1, 32, 32, 3}, "input" , true); |
2009 | |
2010 | bindings.allocate(input); |
2011 | |
2012 | auto *FCL1 = F->createFullyConnected(bindings, "fc" , input, 10); |
2013 | auto *RL3 = F->createRELU("relu4" , FCL1); |
2014 | auto *save = F->createSave("ret" , RL3); |
2015 | auto *savePH = save->getPlaceholder(); |
2016 | |
2017 | bindings.allocate(save->getPlaceholder()); |
2018 | |
2019 | // Compile once. |
2020 | EE.compile(CompilationMode::Infer); |
2021 | |
2022 | // Run with random inputs. |
2023 | inputs.getHandle<>().randomize(-3.0, 3.0, PRNG); |
2024 | updateInputPlaceholders(bindings, {input}, {&inputs}); |
2025 | EE.run(bindings); |
2026 | |
2027 | // Clone the context. |
2028 | PlaceholderBindings bindings2 = bindings.clone(); |
2029 | |
2030 | // PlaceholderBindingss are identical. |
2031 | Tensor *saveBacking1, *saveBacking2; |
2032 | saveBacking1 = bindings.get(savePH); |
2033 | saveBacking2 = bindings2.get(savePH); |
2034 | EXPECT_NE(saveBacking1, saveBacking2); |
2035 | EXPECT_EQ(saveBacking1->size(), saveBacking2->size()); |
2036 | EXPECT_TRUE(saveBacking1->isEqual(*saveBacking2)); |
2037 | |
2038 | // Run again with different random inputs using the cloned context. |
2039 | Tensor inputs2(ElemKind::FloatTy, {1, 32, 32, 3}); |
2040 | inputs2.getHandle<>().randomize(-3.0, 3.0, PRNG); |
2041 | updateInputPlaceholders(bindings2, {input}, {&inputs2}); |
2042 | EE.run(bindings2); |
2043 | |
2044 | // PlaceholderBindingss are no longer identical. |
2045 | EXPECT_EQ(saveBacking1->size(), saveBacking2->size()); |
2046 | EXPECT_FALSE(saveBacking1->isEqual(*saveBacking2)); |
2047 | } |
2048 | |
2049 | /// Check that using the indices enums in nodes works correctly, with |
2050 | /// multi-input, multi-output, and single-input/output nodes. |
2051 | TEST(Graph, TestNodeEnums) { |
2052 | Module MD; |
2053 | Function *F = MD.createFunction("F" ); |
2054 | PlaceholderBindings bindings; |
2055 | Placeholder *I = |
2056 | MD.createPlaceholder(ElemKind::FloatTy, {10, 10}, "input" , true); |
2057 | Placeholder *O = MD.createPlaceholder(ElemKind::FloatTy, {3}, "output" , true); |
2058 | |
2059 | TopKNode *TKN = F->createTopK("topk" , I, 3); |
2060 | GatherNode *GN = |
2061 | F->createGather("gather" , TKN->getValues(), TKN->getIndices()); |
2062 | TanhNode *TN = F->createTanh("tanh" , GN); |
2063 | SaveNode *SN = F->createSave("save" , TN, O); |
2064 | |
2065 | // Check structure of Placeholders. |
2066 | EXPECT_EQ(I->getNthResult(Storage::OutputIdx), I->getOutput()); |
2067 | EXPECT_EQ(O->getNthResult(Storage::OutputIdx), O->getOutput()); |
2068 | |
2069 | // Check structure of TopK. |
2070 | EXPECT_EQ(TKN->getInput(), TKN->getNthInput(TopKNode::InputIdx)); |
2071 | EXPECT_EQ(TKN->getNthResult(TopKNode::ValuesIdx), TKN->getValues()); |
2072 | EXPECT_EQ(TKN->getNthResult(TopKNode::IndicesIdx), TKN->getIndices()); |
2073 | |
2074 | // Check structure of Gather. |
2075 | EXPECT_EQ(GN->getNthInput(GatherNode::DataIdx), GN->getData()); |
2076 | EXPECT_EQ(GN->getNthInput(GatherNode::IndicesIdx), GN->getIndices()); |
2077 | EXPECT_EQ(GN->getNthResult(GatherNode::ResultIdx), GN->getResult()); |
2078 | |
2079 | // Check structure of Tanh. |
2080 | EXPECT_EQ(TN->getNthInput(TanhNode::InputIdx), TN->getInput()); |
2081 | EXPECT_EQ(TN->getNthResult(TanhNode::ResultIdx), TN->getResult()); |
2082 | |
2083 | // Check structure of Save. |
2084 | EXPECT_EQ(SN->getNthInput(SaveNode::InputIdx), SN->getInput()); |
2085 | EXPECT_EQ(SN->getNthInput(SaveNode::OutputIdx), SN->getOutput()); |
2086 | |
2087 | // Check connection between Placeholder and TopK. |
2088 | EXPECT_EQ(TKN->getNthInput(TopKNode::InputIdx), I->getOutput()); |
2089 | |
2090 | // Check connections between TopK and Gather. |
2091 | EXPECT_EQ(TKN->getNthResult(TopKNode::ValuesIdx), |
2092 | GN->getNthInput(GatherNode::DataIdx)); |
2093 | EXPECT_EQ(TKN->getNthResult(TopKNode::IndicesIdx), |
2094 | GN->getNthInput(GatherNode::IndicesIdx)); |
2095 | |
2096 | // Check connection between Gather and Tanh. |
2097 | EXPECT_EQ(GN->getNthResult(GatherNode::ResultIdx), |
2098 | TN->getNthInput(TanhNode::InputIdx)); |
2099 | |
2100 | // Check connection between Gather and Tanh. |
2101 | EXPECT_EQ(TN->getNthResult(TanhNode::ResultIdx), |
2102 | SN->getNthInput(SaveNode::InputIdx)); |
2103 | |
2104 | // Check connection between Gather and Tanh. |
2105 | EXPECT_EQ(SN->getNthInput(SaveNode::OutputIdx), O->getOutput()); |
2106 | } |
2107 | |
2108 | /// Searched \p F for a single instance of a node of Kind T. If more than one is |
2109 | /// found, \returns nullptr, otherwise returns the single instance. |
2110 | template <class T> static T *findSingleInstanceOfNode(Function *F) { |
2111 | T *found = nullptr; |
2112 | for (auto &n : F->getNodes()) { |
2113 | if (auto *currNode = llvm::dyn_cast<T>(&n)) { |
2114 | if (found != nullptr) { |
2115 | return nullptr; |
2116 | } |
2117 | found = currNode; |
2118 | } |
2119 | } |
2120 | return found; |
2121 | } |
2122 | |
2123 | /// Check that group Conv is not lowered when specified to lower by backend if |
2124 | /// doNotLowerKinds contains Conv. |
2125 | TEST(Graph, GroupTestConvNoLower) { |
2126 | Module MD; |
2127 | Function *F = MD.createFunction("F" ); |
2128 | IRFunction M(F); |
2129 | PlaceholderBindings bindings; |
2130 | Node *K = |
2131 | MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 8}, "input" , true); |
2132 | Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select" , true); |
2133 | |
2134 | K = F->createConv(bindings, "Conv1" , K, 16, 3, 2, 3, /* group */ 8); |
2135 | K = F->createRELU("Relu" , K); |
2136 | K = F->createSoftMax("SoftMax" , K, S); |
2137 | F->createSave("Save" , K); |
2138 | F->dump(); |
2139 | auto filePath = F->dumpDAG(); |
2140 | auto backend = MockBackend(); |
2141 | |
2142 | { |
2143 | // Before we lower, we should have a single Conv node with group = 8. |
2144 | ConvolutionNode *CN = findSingleInstanceOfNode<ConvolutionNode>(F); |
2145 | if (!CN) { |
2146 | llvm::sys::fs::remove(filePath); |
2147 | } |
2148 | ASSERT_TRUE(CN); |
2149 | EXPECT_EQ(CN->getGroup(), 8); |
2150 | } |
2151 | |
2152 | // Now lower, but prevent ConvolutionNodeKinds from being lowered. |
2153 | KindSet doNotLower; |
2154 | doNotLower.insert(Kinded::Kind::ConvolutionNodeKind); |
2155 | CompilationContext cctx; |
2156 | lower(F, cctx, &backend, doNotLower); |
2157 | |
2158 | { |
2159 | // Now have lowered but should still have a single Conv node with group = 8. |
2160 | ConvolutionNode *CN = findSingleInstanceOfNode<ConvolutionNode>(F); |
2161 | if (!CN) { |
2162 | llvm::sys::fs::remove(filePath); |
2163 | } |
2164 | ASSERT_TRUE(CN); |
2165 | EXPECT_EQ(CN->getGroup(), 8); |
2166 | } |
2167 | } |
2168 | |
2169 | /// Check that getOutputSave returns SaveNode object for the correct Placeholder |
2170 | /// and nullptr in other cases. |
2171 | TEST(Graph, GetOutputSaveTest) { |
2172 | Module MD; |
2173 | Function *F = MD.createFunction("F" ); |
2174 | PlaceholderBindings bindings; |
2175 | Placeholder *I = |
2176 | MD.createPlaceholder(ElemKind::FloatTy, {10, 10}, "input" , true); |
2177 | Placeholder *O = MD.createPlaceholder(ElemKind::FloatTy, {3}, "output" , true); |
2178 | TopKNode *TKN = F->createTopK("topk" , I, 3); |
2179 | GatherNode *GN = |
2180 | F->createGather("gather" , TKN->getValues(), TKN->getIndices()); |
2181 | TanhNode *TN = F->createTanh("tanh" , GN); |
2182 | SaveNode *SN = F->createSave("save" , TN, O); |
2183 | |
2184 | // Check the return value of getOutputSave method. |
2185 | // Placeholder parent is null. |
2186 | auto *FoundNode = glow::getOutputSave(F, O); |
2187 | EXPECT_NE(nullptr, FoundNode); |
2188 | EXPECT_EQ(SN, FoundNode); |
2189 | |
2190 | // Placeholder parent is set to the correct value. |
2191 | O->setParent(F); |
2192 | EXPECT_EQ(F, O->getParent()); |
2193 | FoundNode = glow::getOutputSave(F, O); |
2194 | EXPECT_NE(nullptr, FoundNode); |
2195 | EXPECT_EQ(SN, FoundNode); |
2196 | |
2197 | // Invalid placeholder type is provided. |
2198 | EXPECT_EQ(nullptr, glow::getOutputSave(F, I)); |
2199 | |
2200 | // Save belongs to a different function |
2201 | Function *F2 = MD.createFunction("F2" ); |
2202 | TopKNode *TKN2 = F2->createTopK("topk" , I, 3); |
2203 | GatherNode *GN2 = |
2204 | F2->createGather("gather" , TKN2->getValues(), TKN2->getIndices()); |
2205 | TanhNode *TN2 = F2->createTanh("tanh" , GN2); |
2206 | SaveNode *SN2 = F2->createSave("save" , TN2, O); |
2207 | |
2208 | FoundNode = glow::getOutputSave(F, O); |
2209 | EXPECT_NE(nullptr, FoundNode); |
2210 | EXPECT_EQ(SN, FoundNode); |
2211 | |
2212 | O->setParent(F2); |
2213 | FoundNode = glow::getOutputSave(F2, O); |
2214 | EXPECT_NE(nullptr, FoundNode); |
2215 | EXPECT_EQ(SN2, FoundNode); |
2216 | } |
2217 | |
/// Check if dump functions work for Node, Function and Module.
/// For each of the three, dump(os), toString() and operator<< must produce
/// the same text, and that text must match the expected layout exactly. The
/// expected strings are raw literals, so their whitespace is significant.
TEST(Graph, testDumpStructure) {
  Module MD;
  Function *F = MD.createFunction("F" );
  IRFunction M(F); // Not otherwise used by this test.
  PlaceholderBindings bindings; // Not used by this test.
  Node *K = MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 100, 3},
                                 "input" , true);
  // Test Node: dump(os), toString() and operator<< must all agree.
  std::string storageN1;
  llvm::raw_string_ostream osN1(storageN1);
  K->dump(osN1);
  std::string mesN = K->toString();
  std::string expectMes = R"(Placeholder
Name : input
Layout : *
Output : float<4 x 320 x 200 x 100 x 3>
Trainable : 1
Static : 0
Users : 0
)" ;
  EXPECT_EQ(mesN, expectMes);
  EXPECT_EQ(mesN, osN1.str());
  std::string storageN2;
  llvm::raw_string_ostream osN2(storageN2);
  osN2 << K;
  EXPECT_EQ(mesN, osN2.str());
  // Test Function. This second Placeholder also asks for the name "input";
  // the expected text below shows it as "input__1".
  Placeholder *I =
      MD.createPlaceholder(ElemKind::FloatTy, {10, 10}, "input" , true);
  I->setStatic(true);
  Function *F2 = MD.createFunction("F2" );
  F2->createTopK("topk" , I, 3);
  std::string storageF1;
  llvm::raw_string_ostream osF1(storageF1);
  F2->dump(osF1);
  std::string mesF = F2->toString();
  std::string expectMesF = R"(Graph structure F2:
TopK
Name : topk
Input : float<10 x 10>
K : 3
Users : 0
Values : float<10 x 3>
Indices : index64<10 x 3>
Placeholder
Name : input__1
Layout : *
Output : float<10 x 10>
Trainable : 1
Static : 1
Users : 1
)" ;
  EXPECT_EQ(mesF, expectMesF);
  EXPECT_EQ(mesF, osF1.str());
  std::string storageF2;
  llvm::raw_string_ostream osF2(storageF2);
  osF2 << F2;
  EXPECT_EQ(mesF, osF2.str());
  // Reuse osF1: empty its backing string, then dump again with
  // skipUsersForStorage, which drops the Placeholder's "Users" line (the TopK
  // still reports its users).
  storageF1.clear();
  F2->dump(osF1, /* skipUsersForStorage */ true);
  mesF = F2->toString(/* skipUsersForStorage */ true);
  expectMesF = R"(Graph structure F2:
TopK
Name : topk
Input : float<10 x 10>
K : 3
Users : 0
Values : float<10 x 3>
Indices : index64<10 x 3>
Placeholder
Name : input__1
Layout : *
Output : float<10 x 10>
Trainable : 1
Static : 1
)" ;
  EXPECT_EQ(mesF, expectMesF);
  EXPECT_EQ(mesF, osF1.str());
  // Test Module: the expected text lists Constants, then Placeholders, then
  // the Function names.
  MD.createConstant(ElemKind::FloatTy, {1, 1}, "dummy" );
  std::string storageM1;
  llvm::raw_string_ostream osM1(storageM1);
  MD.dump(osM1);
  std::string mesM = MD.toString();
  std::string expectMesM = R"(Module structure:
Constant
Name : dummy
Layout : *
Output : float<1 x 1>
Users : 0

Placeholder
Name : input__1
Layout : *
Output : float<10 x 10>
Trainable : 1
Static : 1
Users : 1

Placeholder
Name : input
Layout : *
Output : float<4 x 320 x 200 x 100 x 3>
Trainable : 1
Static : 0
Users : 0

Function : F2
Function : F
)" ;
  EXPECT_EQ(mesM, expectMesM);
  EXPECT_EQ(mesM, osM1.str());
  std::string storageM2;
  llvm::raw_string_ostream osM2(storageM2);
  osM2 << MD;
  EXPECT_EQ(mesM, osM2.str());
}
2336 | |
2337 | /// Initialize tensor payload for testing purposes. The value at index i is set |
2338 | /// to i. |
2339 | template <typename ElemTy> static void initTensor(Tensor &T) { |
2340 | Handle<ElemTy> handle = T.getHandle<ElemTy>(); |
2341 | float val = 0; |
2342 | for (auto &elem : handle) { |
2343 | elem = val; |
2344 | val += 1.0; |
2345 | } |
2346 | } |
2347 | |
2348 | // Test that randomizing Constants in a Function works. |
2349 | TEST(Graph, testRandomizeConstants) { |
2350 | Module MD; |
2351 | Function *F = MD.createFunction("F" ); |
2352 | |
2353 | // Create tensors to be used in Constants |
2354 | Tensor floatT(ElemKind::FloatTy, {10}); |
2355 | initTensor<float>(floatT); |
2356 | |
2357 | Tensor halfT(ElemKind::Float16Ty, {10}); |
2358 | initTensor<float16_t>(halfT); |
2359 | |
2360 | Tensor bfloat16T(ElemKind::BFloat16Ty, {10}); |
2361 | initTensor<bfloat16_t>(bfloat16T); |
2362 | |
2363 | Tensor int8QT(ElemKind::Int8QTy, {10}, 1.0, 0); |
2364 | initTensor<int8_t>(int8QT); |
2365 | |
2366 | Tensor uint8QT(ElemKind::UInt8QTy, {10}, 1.0, 0); |
2367 | initTensor<uint8_t>(uint8QT); |
2368 | |
2369 | Tensor int16QT(ElemKind::Int16QTy, {10}, 1.0, 0); |
2370 | initTensor<int16_t>(int16QT); |
2371 | |
2372 | Tensor int32QT(ElemKind::Int32QTy, {10}, 1.0, 0); |
2373 | initTensor<int32_t>(int32QT); |
2374 | |
2375 | Tensor int32IT(ElemKind::Int32ITy, {10}); |
2376 | initTensor<int32_t>(int32IT); |
2377 | |
2378 | Tensor int64IT(ElemKind::Int64ITy, {10}); |
2379 | initTensor<int64_t>(int64IT); |
2380 | |
2381 | Tensor uint8FusedQT(ElemKind::UInt8FusedQTy, {16, 16}, 1.0, 0); |
2382 | initTensor<uint8_t>(uint8FusedQT); |
2383 | |
2384 | Tensor uint8FusedFP16QT(ElemKind::UInt8FusedFP16QTy, {16, 16}, 1.0, 0); |
2385 | initTensor<uint8_t>(uint8FusedFP16QT); |
2386 | |
2387 | Tensor uint4FusedFP16QT(ElemKind::UInt4FusedFP16QTy, {16, 16}, 1.0, 0); |
2388 | initTensor<uint8_t>(uint4FusedFP16QT); |
2389 | |
2390 | Tensor boolT(ElemKind::BoolTy, {10}); |
2391 | initTensor<bool>(boolT); |
2392 | |
2393 | // Create Constants and use them in F |
2394 | auto *floatC = MD.createConstant("floatC" , floatT); |
2395 | F->createAdd("add" , floatC, floatC); |
2396 | |
2397 | auto *halfC = MD.createConstant("halfC" , halfT); |
2398 | F->createAdd("add" , halfC, halfC); |
2399 | |
2400 | auto *bfloat16C = MD.createConstant("bloat16C" , bfloat16T); |
2401 | F->createAdd("add" , bfloat16C, bfloat16C); |
2402 | |
2403 | auto *int8QC = MD.createConstant("int8QC" , int8QT); |
2404 | F->createAdd("add" , int8QC, int8QC); |
2405 | |
2406 | auto *uint8QC = MD.createConstant("uint8QC" , uint8QT); |
2407 | F->createAdd("add" , uint8QC, uint8QC); |
2408 | |
2409 | auto *int16QC = MD.createConstant("int16QC" , int16QT); |
2410 | F->createAdd("add" , int16QC, int16QC); |
2411 | |
2412 | auto *int32QC = MD.createConstant("int32QC" , int32QT); |
2413 | F->createAdd("add" , int32QC, int32QC); |
2414 | |
2415 | auto *int32IC = MD.createConstant("int32IC" , int32IT); |
2416 | F->createAdd("add" , int32IC, int32IC); |
2417 | |
2418 | auto *int64IC = MD.createConstant("int64IC" , int64IT); |
2419 | F->createAdd("add" , int64IC, int64IC); |
2420 | |
2421 | auto *uint8FusedQC = MD.createConstant("uint8FusedQC" , uint8FusedQT); |
2422 | F->createAdd("add" , uint8FusedQC, uint8FusedQC); |
2423 | |
2424 | auto *uint8FusedFP16QC = |
2425 | MD.createConstant("uint8FusedFP16QC" , uint8FusedFP16QT); |
2426 | F->createAdd("add" , uint8FusedFP16QC, uint8FusedFP16QC); |
2427 | |
2428 | auto *uint4FusedFP16QC = |
2429 | MD.createConstant("uint4FusedFP16QC" , uint4FusedFP16QT); |
2430 | F->createAdd("add" , uint4FusedFP16QC, uint4FusedFP16QC); |
2431 | |
2432 | auto *boolC = MD.createConstant("boolC" , boolT); |
2433 | F->createAdd("add" , boolC, boolC); |
2434 | |
2435 | // Randomize Constants in F |
2436 | F->randomizeConstants(); |
2437 | |
2438 | // Check that no Constant is the same as what it started as |
2439 | EXPECT_FALSE(floatT.isEqual(floatC->getPayload())); |
2440 | EXPECT_FALSE(halfT.isEqual(halfC->getPayload())); |
2441 | EXPECT_FALSE(bfloat16T.isEqual(bfloat16C->getPayload())); |
2442 | EXPECT_FALSE(int8QT.isEqual(int8QC->getPayload())); |
2443 | EXPECT_FALSE(uint8QT.isEqual(uint8QC->getPayload())); |
2444 | EXPECT_FALSE(int16QT.isEqual(int16QC->getPayload())); |
2445 | EXPECT_FALSE(int32QT.isEqual(int32QC->getPayload())); |
2446 | EXPECT_FALSE(int32IT.isEqual(int32IC->getPayload())); |
2447 | EXPECT_FALSE(int64IT.isEqual(int64IC->getPayload())); |
2448 | EXPECT_FALSE(uint8FusedQT.isEqual(uint8FusedQC->getPayload())); |
2449 | EXPECT_FALSE(uint8FusedFP16QT.isEqual(uint8FusedFP16QC->getPayload())); |
2450 | EXPECT_FALSE(uint4FusedFP16QT.isEqual(uint4FusedFP16QC->getPayload())); |
2451 | EXPECT_FALSE(boolT.isEqual(boolC->getPayload())); |
2452 | } |
2453 | |
2454 | TEST(Graph, testSoftmaxMultiplier) { |
2455 | glow::ExecutionEngine EE; |
2456 | Module &M = EE.getModule(); |
2457 | Function *F = M.createFunction("F" ); |
2458 | |
2459 | float beta = 2.0; |
2460 | |
2461 | // Create a graph with single softmax. |
2462 | auto *inputPH = |
2463 | M.createPlaceholder(ElemKind::FloatTy, {1, 2}, "input" , false); |
2464 | auto *select = |
2465 | M.createPlaceholder(ElemKind::Int64ITy, {1, 1}, "select" , true); |
2466 | Node *softmaxNode = |
2467 | F->createSoftMax("softmax" , inputPH, select, inputPH->getType(), beta); |
2468 | auto *saveNode = F->createSave("output" , softmaxNode); |
2469 | |
2470 | PlaceholderBindings bindings; |
2471 | auto *inputT = bindings.allocate(inputPH); |
2472 | inputT->getHandle() = {1.0, 2.0}; |
2473 | Tensor expectedT(inputT->getType()); |
2474 | expectedT.getHandle() = {0.11920292, 0.88079703}; |
2475 | auto *outputT = bindings.allocate(saveNode->getPlaceholder()); |
2476 | EE.compile(CompilationMode::Infer); |
2477 | EE.run(bindings); |
2478 | |
2479 | EXPECT_TRUE(outputT->isEqual(expectedT)); |
2480 | } |
2481 | |
2482 | TEST(Graph, testLogSoftmaxMultiplier) { |
2483 | glow::ExecutionEngine EE; |
2484 | Module &M = EE.getModule(); |
2485 | Function *F = M.createFunction("F" ); |
2486 | |
2487 | float beta = 2.0; |
2488 | |
2489 | // Create a graph with single softmax. |
2490 | auto *inputPH = |
2491 | M.createPlaceholder(ElemKind::FloatTy, {1, 2}, "input" , false); |
2492 | auto *select = |
2493 | M.createPlaceholder(ElemKind::Int64ITy, {1, 1}, "select" , true); |
2494 | Node *softmaxNode = |
2495 | F->createLogSoftMax("softmax" , inputPH, select, inputPH->getType(), beta); |
2496 | auto *saveNode = F->createSave("output" , softmaxNode); |
2497 | |
2498 | PlaceholderBindings bindings; |
2499 | auto *inputT = bindings.allocate(inputPH); |
2500 | inputT->getHandle() = {1.0, 2.0}; |
2501 | Tensor expectedT(inputT->getType()); |
2502 | expectedT.getHandle() = {-2.1269, -0.1269}; |
2503 | auto *outputT = bindings.allocate(saveNode->getPlaceholder()); |
2504 | EE.compile(CompilationMode::Infer); |
2505 | EE.run(bindings); |
2506 | |
2507 | EXPECT_TRUE(outputT->isEqual(expectedT)); |
2508 | } |
2509 | |