1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #include "glow/Partitioner/Partitioner.h" |
17 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
18 | #include "glow/Exporter/ONNXModelWriter.h" |
19 | #include "glow/Graph/Graph.h" |
20 | #include "glow/Importer/ONNXModelLoader.h" |
21 | #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h" |
22 | #include "glow/Partitioner/PartitionerUtils.h" |
23 | |
24 | #include "llvm/Support/FileSystem.h" |
25 | |
26 | #include "gtest/gtest.h" |
27 | |
28 | using namespace glow; |
29 | |
/// Test fixture: provides a fresh Module with a "main" Function and a set of
/// PlaceholderBindings for each test case.
class PartitionerTest : public ::testing::Test {
public:
  PartitionerTest() : F_(mod_.createFunction("main")) {}

protected:
  /// Module owning the Functions/Constants/Placeholders a test creates.
  Module mod_;
  /// Convenience pointer to the "main" Function inside mod_. Note: several
  /// tests reassign this to point into an ExecutionEngine's own module.
  Function *F_;
  /// Maps Placeholders to their backing Tensors for execution.
  PlaceholderBindings bindings_;
};
39 | |
40 | /// Execute a graph of functions based on the given DAG. |
41 | static void executeDAG(DAGNode *G, Module &mod, PlaceholderBindings &bindings, |
42 | llvm::ArrayRef<Placeholder *> vars, |
43 | llvm::ArrayRef<Tensor *> inputs, ExecutionEngine *EE) { |
44 | std::unordered_map<std::string, Function *> name2func; |
45 | |
46 | for (auto *F : mod.getFunctions()) { |
47 | name2func[F->getName().str()] = F; |
48 | } |
49 | |
50 | std::vector<DAGNode *> exeList; |
51 | int endPt = 0; |
52 | int curPt = 0; |
53 | // The first node is always the dummy node. |
54 | exeList.push_back(G); |
55 | endPt++; |
56 | while (curPt < endPt) { |
57 | DAGNode *dag = exeList.at(curPt); |
58 | // The root in a G is always a dummy function. |
59 | if (curPt > 0) { |
60 | updateInputPlaceholders(bindings, vars, inputs); |
61 | EE->run(bindings, dag->name); |
62 | } |
63 | for (int i = 0, e = dag->children.size(); i < e; i++) { |
64 | exeList.push_back(dag->children.at(i)); |
65 | endPt++; |
66 | } |
67 | curPt++; |
68 | } |
69 | } |
70 | |
71 | /// \returns true if all the functions have the valid save node format: i.e. no |
72 | /// such pattern Save->Save. |
73 | static bool checkSaveNode(Module &mod) { |
74 | for (auto F : mod.getFunctions()) { |
75 | for (const Node &N : F->getNodes()) { |
76 | if (N.getKind() != Kinded::Kind::SaveNodeKind) { |
77 | continue; |
78 | } |
79 | auto *ph = llvm::dyn_cast<Placeholder>(N.getNthInput(0).getNode()); |
80 | if (!ph) { |
81 | continue; |
82 | } |
83 | // If this SaveNode use the output of another SaveNode, it is an illegal |
84 | // pattern. |
85 | for (auto &user : ph->getUsers()) { |
86 | if (user.getUser() == &N || !llvm::dyn_cast<SaveNode>(user.getUser())) { |
87 | continue; |
88 | } |
89 | return false; |
90 | } |
91 | } |
92 | } |
93 | return true; |
94 | } |
95 | |
96 | /// Serializes \p dagList and re-loads it. Compares the structure of the DAGs |
97 | /// before/after and verify results are still the same given \p devices |
98 | static void verifyDAGSerialization( |
99 | DAGListTy &dagList, Module &origMod, PlaceholderBindings &bindings, |
100 | llvm::ArrayRef<llvm::StringRef> inputNames, llvm::StringRef resultName, |
101 | const std::vector<DeviceInfo> &devices, llvm::ArrayRef<Tensor *> inputs, |
102 | const Tensor &ref, ConstantFoldingRecordMap *constFoldRecord = nullptr) { |
103 | llvm::SmallString<64> path; |
104 | auto tempFileRes = |
105 | llvm::sys::fs::createTemporaryFile("exporter" , "output.onnx" , path); |
106 | (void)tempFileRes; |
107 | assert(tempFileRes.value() == 0); |
108 | |
109 | std::string outputFilename(path.c_str()); |
110 | std::cout << "Writing to file: " << outputFilename << std::endl; |
111 | { |
112 | // Note: do not include Constant data when we write out; we will reuse the |
113 | // Module so we don't need to save it. |
114 | Error err = Error::empty(); |
115 | llvm::StringMap<std::string> ; |
116 | ONNXModelWriter onnxWR(outputFilename, dagList, 7, 9, &err, |
117 | /* textMode */ false, /* zipMode */ false, |
118 | /* includeConstantData */ false, extraMetadataProps, |
119 | constFoldRecord ? *constFoldRecord |
120 | : ConstantFoldingRecordMap()); |
121 | |
122 | if (ERR_TO_BOOL(std::move(err))) { |
123 | llvm::errs() << "ONNXModelWriter failed to write model: " |
124 | << outputFilename << "\n" ; |
125 | llvm::sys::fs::remove(outputFilename); |
126 | FAIL() << "Error exporting DAG." ; |
127 | } |
128 | } |
129 | |
130 | // Create a new EE using the same module. Note that we assume devices are |
131 | // homogenous here. |
132 | ExecutionEngine loadedEE(devices[0].backendName, devices[0].availableMemory, |
133 | /* ignoreUserDeviceConfig */ false, |
134 | /* numDevices */ devices.size()); |
135 | // Clone the original module into the one in the EE; we're going to |
136 | // deserialize the DAG into it as if we're reusing the same Module. |
137 | origMod.clone(&loadedEE.getModule()); |
138 | Module &loadedMod = loadedEE.getModule(); |
139 | CompilationContext loadedCctx; |
140 | runtime::PrePartitionedConfig PPC; |
141 | loadedCctx.prepartitionedConfig = &PPC; |
142 | { |
143 | // Clear out Functions from Nodes. We will reuse the empty Functions. |
144 | loadedMod.clearFunctions(); |
145 | // If we have a constant folding record then delete those Constants too |
146 | // since we're going to recreate them. Also delete the const fold Functions. |
147 | if (constFoldRecord) { |
148 | std::unordered_set<Function *> funsToDelete; |
149 | for (auto &pair : *constFoldRecord) { |
150 | Function *origF = pair.second->getParent(); |
151 | funsToDelete.insert(origF); |
152 | Constant *C = loadedMod.getConstantByName(pair.first->getName()); |
153 | ASSERT_TRUE(C); |
154 | loadedMod.eraseConstant(C); |
155 | } |
156 | for (Function *origF : funsToDelete) { |
157 | Function *loadedConstFoldF = loadedMod.getFunction(origF->getName()); |
158 | ASSERT_TRUE(loadedConstFoldF); |
159 | loadedMod.eraseFunction(loadedConstFoldF); |
160 | origMod.eraseFunction(origF); |
161 | } |
162 | } |
163 | Error err = Error::empty(); |
164 | ONNXModelLoader onnxLD( |
165 | outputFilename, {}, {}, loadedMod, "main" , &PPC, &err, |
166 | /* zipMode */ false, &loadedCctx.backendOpts.backendSpecificNodeInfo, |
167 | /* loadIntoExistingModule */ true); |
168 | if (ERR_TO_BOOL(std::move(err))) { |
169 | llvm::errs() << "ONNXModelLoader failed to load model: " << outputFilename |
170 | << "\n" ; |
171 | llvm::sys::fs::remove(outputFilename); |
172 | FAIL() << "Error importing DAG." ; |
173 | } |
174 | } |
175 | llvm::sys::fs::remove(outputFilename); |
176 | |
177 | // Now verify the DAG is the same, including all static properties of the DAG. |
178 | Partitioner loadedPartitioner(&loadedMod, devices, /* optimized */ true); |
179 | DAGListTy loadedDagList; |
180 | ASSIGN_VALUE_OR_FAIL_TEST(loadedDagList, |
181 | loadedPartitioner.partition(loadedCctx)); |
182 | |
183 | // Verify that two DAGs are the same. |
184 | ASSERT_EQ(dagList.size(), loadedDagList.size()); |
185 | ASSERT_EQ(dagList.size(), 1); |
186 | DAG &origDAG = dagList.front(); |
187 | DAG &loadedDAG = loadedDagList.front(); |
188 | EXPECT_EQ(origDAG.root->name, loadedDAG.root->name); |
189 | |
190 | // Map from orig DAGNodes to loaded DAGNodes. |
191 | std::unordered_map<const DAGNode *, const DAGNode *> origToLoaded; |
192 | for (DAGNodePtr &origN : origDAG.nodes) { |
193 | for (DAGNodePtr &loadedN : loadedDAG.nodes) { |
194 | if (origN->name != loadedN->name) { |
195 | continue; |
196 | } |
197 | origToLoaded[origN.get()] = loadedN.get(); |
198 | break; |
199 | } |
200 | } |
201 | ASSERT_EQ(origDAG.nodes.size(), origToLoaded.size()); |
202 | origToLoaded[origDAG.root.get()] = loadedDAG.root.get(); |
203 | |
204 | for (const auto &nPair : origToLoaded) { |
205 | const DAGNode *origN = nPair.first; |
206 | const DAGNode *loadedN = nPair.second; |
207 | #define CHECK_DAG_EQ(MEM_NAME) EXPECT_EQ(origN->MEM_NAME, loadedN->MEM_NAME); |
208 | CHECK_DAG_EQ(name); |
209 | CHECK_DAG_EQ(size); |
210 | CHECK_DAG_EQ(backendName); |
211 | CHECK_DAG_EQ(backendHints.executionUnits); |
212 | CHECK_DAG_EQ(logicalDevices.size()); |
213 | for (size_t i = 0, e = origN->logicalDevices.size(); i < e; i++) { |
214 | EXPECT_EQ(origN->logicalDevices[i], loadedN->logicalDevices[i]); |
215 | } |
216 | CHECK_DAG_EQ(replicationCount); |
217 | EXPECT_TRUE(std::equal(origN->backendSpecificOpts.begin(), |
218 | origN->backendSpecificOpts.end(), |
219 | loadedN->backendSpecificOpts.begin())); |
220 | #undef CHECK_DAG_EQ |
221 | |
222 | for (const DAGNode *origChild : origN->children) { |
223 | auto it = std::find_if(loadedN->children.begin(), loadedN->children.end(), |
224 | [=](auto *loadedChild) { |
225 | return loadedChild->name == origChild->name; |
226 | }); |
227 | EXPECT_NE(it, std::end(loadedN->children)); |
228 | } |
229 | for (const DAGNode *origParent : origN->parents) { |
230 | auto it = std::find_if(loadedN->parents.begin(), loadedN->parents.end(), |
231 | [=](auto *loadedParent) { |
232 | return loadedParent->name == origParent->name; |
233 | }); |
234 | EXPECT_NE(it, std::end(loadedN->parents)); |
235 | } |
236 | |
237 | // Skip checking root as there's no Function for them. |
238 | if (origN == origDAG.root.get()) { |
239 | continue; |
240 | } |
241 | Function *origF = origMod.getFunction(origN->name); |
242 | Function *loadedF = loadedMod.getFunction(loadedN->name); |
243 | ASSERT_TRUE(origF); |
244 | ASSERT_TRUE(loadedF); |
245 | EXPECT_EQ(origF->toString(), loadedF->toString()); |
246 | } |
247 | EXPECT_EQ(origMod.toString(), loadedMod.toString()); |
248 | |
249 | // Now reset bindings and run, checking results are bitwise equal from before |
250 | // and after serialization. Note that we still use the same PPC -- it will |
251 | // re-partition/setup the same DAG inside compilation. |
252 | loadedEE.compile(loadedCctx); |
253 | bindings.clear(); |
254 | bindings.allocate(loadedMod.getPlaceholders()); |
255 | std::vector<Placeholder *> inPHs; |
256 | for (const llvm::StringRef &inName : inputNames) { |
257 | inPHs.push_back(bindings.getPlaceholderByNameSlow(inName)); |
258 | } |
259 | updateInputPlaceholders(bindings, inPHs, inputs); |
260 | loadedEE.run(bindings); |
261 | Tensor test = |
262 | bindings.get(bindings.getPlaceholderByNameSlow(resultName))->clone(); |
263 | EXPECT_TRUE(ref.isEqual(test, 0.0f)); |
264 | } |
265 | |
/// This one tests the model with this feature: after BFS, the memory
/// consumption of all the nodes in each level won't exceed the device memory
/// constraints.
TEST_F(PartitionerTest, Basic1) {
  ExecutionEngine EER, EEP;
  // Keep the partitioned module intact after compile so it can be serialized
  // and re-run by verifyDAGSerialization below.
  EEP.setSkipModuleStrip(true);
  constexpr float range = 2.0;
  std::vector<ExecutionEngine *> engines{&EER, &EEP};
  // Since compiling modifies the module and partitioning modifies the function,
  // setup two EEs with identical functions for validation.
  // NOTE: the randomize() calls below consume each module's PRNG in statement
  // order; the bitwise ref/test comparison at the end relies on both modules
  // producing identical weights, so do not reorder these calls.
  for (auto EE : engines) {
    auto mod = &EE->getModule();
    F_ = mod->createFunction("main");
    auto *input =
        mod->createPlaceholder(ElemKind::FloatTy, {1, 32}, "input", false);
    auto *w1 = mod->createConstant(ElemKind::FloatTy, {32, 16}, "w1");
    auto *b1 = mod->createConstant(ElemKind::FloatTy, {16}, "b1");
    bindings_.allocate(input);
    w1->getHandle<>().randomize(-range, range, mod->getPRNG());
    b1->getHandle<>().randomize(-range, range, mod->getPRNG());

    // Initial FC.
    Node *I = F_->createFullyConnected("initial_fc", input, w1, b1);
    I = F_->createSigmoid("initial_sigmoid", I);

    // Left branch.
    auto *w2 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w2");
    auto *b2 = mod->createConstant(ElemKind::FloatTy, {16}, "b2");
    w2->getHandle<>().randomize(-range, range, mod->getPRNG());
    b2->getHandle<>().randomize(-range, range, mod->getPRNG());
    Node *L = F_->createFullyConnected("left_fc1", I, w2, b2);
    L = F_->createSigmoid("left_sigmoid1", L);
    auto *w3 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w3");
    auto *b3 = mod->createConstant(ElemKind::FloatTy, {8}, "b3");
    w3->getHandle<>().randomize(-range, range, mod->getPRNG());
    b3->getHandle<>().randomize(-range, range, mod->getPRNG());
    L = F_->createFullyConnected("left_fc2", L, w3, b3);
    L = F_->createSigmoid("left_sigmoid2", L);

    // Right branch.
    auto *w4 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w4");
    auto *b4 = mod->createConstant(ElemKind::FloatTy, {16}, "b4");
    w4->getHandle<>().randomize(-range, range, mod->getPRNG());
    b4->getHandle<>().randomize(-range, range, mod->getPRNG());
    Node *R = F_->createFullyConnected("right_fc1", I, w4, b4);
    R = F_->createSigmoid("right_sigmoid1", R);
    auto *w5 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w5");
    auto *b5 = mod->createConstant(ElemKind::FloatTy, {8}, "b5");
    w5->getHandle<>().randomize(-range, range, mod->getPRNG());
    b5->getHandle<>().randomize(-range, range, mod->getPRNG());
    R = F_->createFullyConnected("right_fc2", R, w5, b5);
    R = F_->createSigmoid("right_sigmoid2", R);

    // Join branches.
    auto *mul = F_->createMul("mul", L, R);
    F_->createSave("ret", mul);
  }

  // Infer using the un-partitioned graph to produce the reference result.
  Tensor in(ElemKind::FloatTy, {1, 32});
  in.getHandle<>().randomize(-range, range, EER.getModule().getPRNG());

  EER.compile(CompilationMode::Infer);
  bindings_.clear();
  bindings_.allocate(EER.getModule().getPlaceholders());
  updateInputPlaceholders(bindings_,
                          {bindings_.getPlaceholderByNameSlow("input")}, {&in});
  EER.run(bindings_);
  Tensor ref =
      bindings_.get(bindings_.getPlaceholderByNameSlow("ret"))->clone();

  // Three homogeneous devices; expect the graph to be split into 3 partitions.
  std::vector<DeviceInfo> devices = {
      {3072, "Interpreter"}, {3072, "Interpreter"}, {3072, "Interpreter"}};
  Partitioner myPartitioner(&EEP.getModule(), devices, true);
  CompilationContext cctx;
  auto dagList = myPartitioner.partition(cctx);
  ASSERT_TRUE((bool)dagList);
  EXPECT_EQ(EEP.getModule().getFunctions().size(), 3);
  EXPECT_EQ(dagList->size(), 1);
  EXPECT_TRUE(checkSaveNode(EEP.getModule()));

  // Run the partitioned graph and compare the results.
  bindings_.clear();
  bindings_.allocate(EEP.getModule().getPlaceholders());
  EEP.compile(cctx);
  executeDAG(dagList->begin()->root.get(), EEP.getModule(), bindings_,
             {bindings_.getPlaceholderByNameSlow("input")}, {&in}, &EEP);
  Tensor test =
      bindings_.get(bindings_.getPlaceholderByNameSlow("ret"))->clone();
  EXPECT_TRUE(ref.isEqual(test, 0.0f));
  // Round-trip the DAG through ONNX serialization and re-check results.
  verifyDAGSerialization(dagList.get(), EEP.getModule(), bindings_, {"input"},
                         "ret", devices, {&in}, ref);
}
359 | |
/// This one tests the model with this feature: after BFS, there is one level,
/// the memory consumption of all the nodes in which exceeds the device memory
/// constraints.
TEST_F(PartitionerTest, Basic2) {

  ExecutionEngine EER, EEP;
  // Keep the partitioned module intact so it can be serialized and re-run.
  EEP.setSkipModuleStrip(true);
  constexpr float range = 2.0;
  std::vector<ExecutionEngine *> engines{&EER, &EEP};
  // Build identical two-branch networks (no shared trunk this time) in both
  // modules. NOTE: randomize() consumes the PRNG in statement order; the
  // bitwise ref/test comparison relies on this order, so do not reorder.
  for (auto EE : engines) {
    auto mod = &EE->getModule();
    F_ = mod->createFunction("main");
    auto *input =
        mod->createPlaceholder(ElemKind::FloatTy, {1, 16}, "input", false);
    auto *input1 =
        mod->createPlaceholder(ElemKind::FloatTy, {1, 16}, "input1", false);
    bindings_.allocate(input);
    bindings_.allocate(input1);
    // Left branch.
    auto *w2 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w2");
    auto *b2 = mod->createConstant(ElemKind::FloatTy, {16}, "b2");
    w2->getHandle<>().randomize(-range, range, mod->getPRNG());
    b2->getHandle<>().randomize(-range, range, mod->getPRNG());
    Node *L = F_->createFullyConnected("left_fc1", input, w2, b2);
    L = F_->createSigmoid("left_sigmoid1", L);
    auto *w3 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w3");
    auto *b3 = mod->createConstant(ElemKind::FloatTy, {8}, "b3");
    w3->getHandle<>().randomize(-range, range, mod->getPRNG());
    b3->getHandle<>().randomize(-range, range, mod->getPRNG());
    L = F_->createFullyConnected("left_fc2", L, w3, b3);
    L = F_->createSigmoid("left_sigmoid2", L);

    // Right branch.
    auto *w4 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w4");
    auto *b4 = mod->createConstant(ElemKind::FloatTy, {16}, "b4");
    w4->getHandle<>().randomize(-range, range, mod->getPRNG());
    b4->getHandle<>().randomize(-range, range, mod->getPRNG());
    Node *R = F_->createFullyConnected("right_fc1", input1, w4, b4);
    R = F_->createSigmoid("right_sigmoid1", R);
    auto *w5 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w5");
    auto *b5 = mod->createConstant(ElemKind::FloatTy, {8}, "b5");
    w5->getHandle<>().randomize(-range, range, mod->getPRNG());
    b5->getHandle<>().randomize(-range, range, mod->getPRNG());
    R = F_->createFullyConnected("right_fc2", R, w5, b5);
    R = F_->createSigmoid("right_sigmoid2", R);

    // Join branches.
    auto *mul = F_->createMul("mul", L, R);
    F_->createSave("ret", mul);
  }

  // Infer using the un-partitioned graph to produce the reference result.
  // Both inputs are fed the same tensor.
  Tensor in(ElemKind::FloatTy, {1, 16});
  in.getHandle<>().randomize(-range, range, EER.getModule().getPRNG());
  EER.compile(CompilationMode::Infer);
  bindings_.clear();
  bindings_.allocate(EER.getModule().getPlaceholders());
  updateInputPlaceholders(bindings_,
                          {bindings_.getPlaceholderByNameSlow("input"),
                           bindings_.getPlaceholderByNameSlow("input1")},
                          {&in, &in});
  EER.run(bindings_);
  Tensor ref =
      bindings_.get(bindings_.getPlaceholderByNameSlow("ret"))->clone();

  // Four devices and saturateHost: expect 2 partitions, each replicated onto
  // 2 logical devices.
  std::vector<DeviceInfo> devices = {{2048, "Interpreter"},
                                     {2048, "Interpreter"},
                                     {2048, "Interpreter"},
                                     {2048, "Interpreter"}};
  Partitioner myPartitioner(&EEP.getModule(), devices);
  CompilationContext cctx;
  cctx.saturateHost = true;
  runtime::DAGListTy dagList;
  ASSIGN_VALUE_OR_FAIL_TEST(dagList, myPartitioner.partition(cctx));
  EXPECT_EQ(EEP.getModule().getFunctions().size(), 2);
  EXPECT_EQ(dagList.size(), 1);
  ASSERT_TRUE(checkSaveNode(EEP.getModule()));

  for (auto &dag : dagList) {
    for (auto &node : dag.nodes) {
      // Since saturateHost is set true, in this case, there should be 2 copys
      // of the partitions.
      EXPECT_EQ(node->logicalDevices.size(), 2);
    }
  }

  // Run the partitioned graph and compare the results. "input1" is set once
  // up-front here; executeDAG only refreshes "input" before each partition.
  bindings_.clear();
  bindings_.allocate(EEP.getModule().getPlaceholders());
  EEP.compile(cctx);
  updateInputPlaceholders(bindings_,
                          {bindings_.getPlaceholderByNameSlow("input"),
                           bindings_.getPlaceholderByNameSlow("input1")},
                          {&in, &in});
  executeDAG(dagList.begin()->root.get(), EEP.getModule(), bindings_,
             {bindings_.getPlaceholderByNameSlow("input")}, {&in}, &EEP);
  Tensor test =
      bindings_.get(bindings_.getPlaceholderByNameSlow("ret"))->clone();
  ASSERT_TRUE(ref.isEqual(test, 0.0f));
  // Round-trip the DAG through ONNX serialization and re-check results.
  verifyDAGSerialization(dagList, EEP.getModule(), bindings_,
                         {"input", "input1"}, "ret", devices, {&in, &in}, ref);
}
462 | |
/// This one tests the error msg: if the number of partitions is larger than
/// given number of devices, report an error.
TEST_F(PartitionerTest, Error1) {
  ExecutionEngine EER, EEP;
  constexpr float range = 2.0;
  std::vector<ExecutionEngine *> engines{&EER, &EEP};
  // Same two-branch network as Basic2, built in both modules. randomize()
  // consumes each PRNG in statement order; keep the order stable.
  for (auto EE : engines) {
    auto mod = &EE->getModule();
    F_ = mod->createFunction("main");
    auto *input =
        mod->createPlaceholder(ElemKind::FloatTy, {1, 16}, "input", false);
    auto *input1 =
        mod->createPlaceholder(ElemKind::FloatTy, {1, 16}, "input1", false);
    bindings_.allocate(input);
    bindings_.allocate(input1);
    // Left branch.
    auto *w2 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w2");
    auto *b2 = mod->createConstant(ElemKind::FloatTy, {16}, "b2");
    w2->getHandle<>().randomize(-range, range, mod->getPRNG());
    b2->getHandle<>().randomize(-range, range, mod->getPRNG());
    Node *L = F_->createFullyConnected("left_fc1", input, w2, b2);
    L = F_->createSigmoid("left_sigmoid1", L);
    auto *w3 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w3");
    auto *b3 = mod->createConstant(ElemKind::FloatTy, {8}, "b3");
    w3->getHandle<>().randomize(-range, range, mod->getPRNG());
    b3->getHandle<>().randomize(-range, range, mod->getPRNG());
    L = F_->createFullyConnected("left_fc2", L, w3, b3);
    L = F_->createSigmoid("left_sigmoid2", L);

    // Right branch.
    auto *w4 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w4");
    auto *b4 = mod->createConstant(ElemKind::FloatTy, {16}, "b4");
    w4->getHandle<>().randomize(-range, range, mod->getPRNG());
    b4->getHandle<>().randomize(-range, range, mod->getPRNG());
    Node *R = F_->createFullyConnected("right_fc1", input1, w4, b4);
    R = F_->createSigmoid("right_sigmoid1", R);
    auto *w5 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w5");
    auto *b5 = mod->createConstant(ElemKind::FloatTy, {8}, "b5");
    w5->getHandle<>().randomize(-range, range, mod->getPRNG());
    b5->getHandle<>().randomize(-range, range, mod->getPRNG());
    R = F_->createFullyConnected("right_fc2", R, w5, b5);
    R = F_->createSigmoid("right_sigmoid2", R);

    // Join branches.
    auto *mul = F_->createMul("mul", L, R);
    F_->createSave("ret", mul);
  }

  // Infer using the un-partitioned graph; the result itself is not compared
  // here -- this just validates the reference module runs.
  Tensor in(ElemKind::FloatTy, {1, 16});
  in.getHandle<>().randomize(-range, range, EER.getModule().getPRNG());

  EER.compile(CompilationMode::Infer);
  bindings_.clear();
  bindings_.allocate(EER.getModule().getPlaceholders());
  updateInputPlaceholders(bindings_,
                          {bindings_.getPlaceholderByNameSlow("input"),
                           bindings_.getPlaceholderByNameSlow("input1")},
                          {&in, &in});
  EER.run(bindings_);

  // Only one device: the graph needs more partitions than devices, so
  // partition() must return an Error.
  std::vector<DeviceInfo> devices = {{2048, "Interpreter"}};
  Partitioner myPartitioner(&EEP.getModule(), devices);
  CompilationContext cctx;
  auto dagList = myPartitioner.partition(cctx);
  EXPECT_TRUE(ERR_TO_BOOL(dagList.takeError()));
}
530 | |
/// This one tests the roofline computed with compute, memory and
/// communication costs
TEST_F(PartitionerTest, Basic1Roofline) {
  ExecutionEngine EEP;
  constexpr float range = 2.0;

  // Build the same trunk + two-branch network used by Basic1.
  auto mod = &EEP.getModule();
  F_ = mod->createFunction("main");
  auto *input =
      mod->createPlaceholder(ElemKind::FloatTy, {1, 32}, "input", false);
  auto *w1 = mod->createConstant(ElemKind::FloatTy, {32, 16}, "w1");
  auto *b1 = mod->createConstant(ElemKind::FloatTy, {16}, "b1");
  bindings_.allocate(input);
  w1->getHandle<>().randomize(-range, range, mod->getPRNG());
  b1->getHandle<>().randomize(-range, range, mod->getPRNG());

  // Initial FC.
  Node *I = F_->createFullyConnected("initial_fc", input, w1, b1);
  I = F_->createSigmoid("initial_sigmoid", I);

  // Left branch.
  auto *w2 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w2");
  auto *b2 = mod->createConstant(ElemKind::FloatTy, {16}, "b2");
  w2->getHandle<>().randomize(-range, range, mod->getPRNG());
  b2->getHandle<>().randomize(-range, range, mod->getPRNG());
  Node *L = F_->createFullyConnected("left_fc1", I, w2, b2);
  L = F_->createSigmoid("left_sigmoid1", L);
  auto *w3 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w3");
  auto *b3 = mod->createConstant(ElemKind::FloatTy, {8}, "b3");
  w3->getHandle<>().randomize(-range, range, mod->getPRNG());
  b3->getHandle<>().randomize(-range, range, mod->getPRNG());
  L = F_->createFullyConnected("left_fc2", L, w3, b3);
  L = F_->createSigmoid("left_sigmoid2", L);

  // Right branch.
  auto *w4 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w4");
  auto *b4 = mod->createConstant(ElemKind::FloatTy, {16}, "b4");
  w4->getHandle<>().randomize(-range, range, mod->getPRNG());
  b4->getHandle<>().randomize(-range, range, mod->getPRNG());
  Node *R = F_->createFullyConnected("right_fc1", I, w4, b4);
  R = F_->createSigmoid("right_sigmoid1", R);
  auto *w5 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w5");
  auto *b5 = mod->createConstant(ElemKind::FloatTy, {8}, "b5");
  w5->getHandle<>().randomize(-range, range, mod->getPRNG());
  b5->getHandle<>().randomize(-range, range, mod->getPRNG());
  R = F_->createFullyConnected("right_fc2", R, w5, b5);
  R = F_->createSigmoid("right_sigmoid2", R);

  // Join branches.
  auto *mul = F_->createMul("mul", L, R);
  F_->createSave("ret", mul);

  // Since the partitioner will look at all nodes in the function post
  // optimization and lowering, we need to do so here for the same list of
  // nodes.
  std::unique_ptr<Backend> backend(createBackend(EEP.getBackendName()));
  CompilationContext cctx;
  EXIT_ON_ERR(optimizeFunctionBeforeLowering(
      EEP.getModule().getFunction("main"), cctx));
  EXIT_ON_ERR(::glow::optimizeFunction(EEP.getModule().getFunction("main"),
                                       *backend, cctx));
  // Snapshot node pointers with their names so expected costs can be looked
  // up per node below.
  std::unordered_map<Node *, std::string> nodeNamesMap;
  for (auto &node : EEP.getModule().getFunction("main")->getNodes()) {
    nodeNamesMap[&node] = node.getName().str();
  }

  // check compute costs: expected per-node compute time, keyed by node name.
  std::unordered_map<std::string, float> expectedComputeTime{
      {"initial_sigmoid", 128}, {"left_sigmoid2", 64},
      {"right_sigmoid1", 128},  {"mul", 96},
      {"ret_save", 0},          {"initial_fc", 21760},
      {"left_fc1", 10240},      {"left_fc2", 5120},
      {"left_sigmoid1", 128},   {"right_fc1", 10240},
      {"right_fc2", 5120},      {"right_sigmoid2", 64},
  };

  // Synthetic roofline parameters fed into the cost model.
  BackendInfo backendInfo;
  backendInfo.sramCapacity = 100;
  backendInfo.peakCompute = 10;
  backendInfo.peakDramBw = 0.1;
  backendInfo.peakSramBw = 1;
  backendInfo.peakPCIeBw = 0.05;
  for (auto const &p : nodeNamesMap) {
    auto *N = p.first;
    EXPECT_EQ(getNodeComputeTime(N, backendInfo),
              expectedComputeTime[p.second]);
  }
}
619 | |
620 | TEST_F(PartitionerTest, SelectRepFunc) { |
621 | auto *inA = mod_.createConstant(ElemKind::FloatTy, {2}, "A" ); |
622 | auto *inB = mod_.createConstant(ElemKind::FloatTy, {2}, "B" ); |
623 | inA->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); |
624 | inB->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG()); |
625 | |
626 | auto *plus = F_->createAdd("AplusB" , inA, inB); |
627 | F_->createSave("save" , plus); |
628 | |
629 | Partitioner myPartitioner(&mod_, {{1000000, "Interpreter" }, |
630 | {1000000, "Interpreter" }, |
631 | {1000000, "Interpreter" }}); |
632 | |
633 | CompilationContext cctx; |
634 | auto dagList = myPartitioner.partition(cctx); |
635 | ASSERT_TRUE((bool)dagList); |
636 | } |
637 | |
638 | /// Create a mock backend and rewrite the isOpSupported function |
639 | /// to un-support the op \p unsupportedOpKind. |
640 | template <glow::Kinded::Kind unsupportedOpKind> |
641 | class MockBackend : public Backend { |
642 | public: |
643 | std::string backendName; |
644 | |
645 | class MockFunction : public CompiledFunction { |
646 | public: |
647 | MockFunction(llvm::StringRef backendName, runtime::RuntimeBundle &&bundle) |
648 | : CompiledFunction(std::move(bundle)), backendName(backendName) {} |
649 | |
650 | Error execute(ExecutionContext *) override { return Error::success(); } |
651 | |
652 | std::string getCompileBackendName() const override { return backendName; } |
653 | |
654 | std::string backendName; |
655 | }; |
656 | |
657 | std::string getBackendName() const override { return backendName; } |
658 | |
659 | Expected<std::unique_ptr<CompiledFunction>> |
660 | compile(Function *F, const BackendOptions &opts) const override { |
661 | return glow::make_unique<MockFunction>(backendName, |
662 | runtime::RuntimeBundle::create(*F)); |
663 | } |
664 | |
665 | bool isOpSupported(const NodeInfo &NI) const override { |
666 | if (NI.getKind() == unsupportedOpKind) { |
667 | return false; |
668 | } |
669 | return true; |
670 | } |
671 | |
672 | bool shouldLower(const Node *N) const override { return false; } |
673 | |
674 | bool generateInst(Node *N, IRGenVisitor &irgen) const override { |
675 | return false; |
676 | } |
677 | |
678 | Expected<double> estimateNodeCost(const Node * /*node */) const override { |
679 | return 2.0; |
680 | } |
681 | |
682 | runtime::DeviceManager * |
683 | createDeviceManager(const runtime::DeviceConfig &deviceConfig) override { |
684 | return nullptr; |
685 | } |
686 | }; |
687 | |
/// Mock backend named "CPU" that reports Sub nodes as unsupported.
class BackendWithoutSub : public MockBackend<Kinded::Kind::SubNodeKind> {
public:
  BackendWithoutSub() { backendName = "CPU"; }
};
/// Mock backend named "Interpreter" that reports Mul nodes as unsupported.
class BackendWithoutMul : public MockBackend<Kinded::Kind::MulNodeKind> {
public:
  BackendWithoutMul() { backendName = "Interpreter"; }
};
696 | |
697 | static void createSimpleModule(Module &mod) { |
698 | mod.clear(); |
699 | auto *F = mod.createFunction("test" ); |
700 | auto *input1 = |
701 | mod.createPlaceholder(ElemKind::FloatTy, {16}, "input1" , false); |
702 | auto *input2 = |
703 | mod.createPlaceholder(ElemKind::FloatTy, {16}, "input2" , false); |
704 | auto *input3 = |
705 | mod.createPlaceholder(ElemKind::FloatTy, {16}, "input3" , false); |
706 | auto *sub = F->createSub("sub" , input1, input2); |
707 | auto *mul = F->createMul("mul" , input1, input2); |
708 | auto *sum = F->createAdd("add" , sub, mul); |
709 | auto *sub2 = F->createSub("sub1" , sum, input3); |
710 | auto *save = F->createSave("ret" , sub2); |
711 | (void)save; |
712 | } |
713 | |
/// Builds a SparseNN-style model in \p mod (clearing it first): a set of SLS
/// tables whose outputs are concatenated and fed into a stack of
/// \p numFCLayers fully-connected layers, ending in a save.
/// \p shareSplatInputs makes all SLS nodes share Splat weights/lengths (and
///   LayerNorm nodes share Splat scale/bias) instead of per-table inputs.
/// \p addClipAndLayerNorm appends Clip -> LayerNorm -> Clip after each SLS.
/// \p addTile gives the first SLS a batch-1 lengths input and tiles its
///   output up to the batch size.
/// \p addTanh appends a Tanh after each SLS (only when the clip/layer-norm
///   and tile decorations do not apply to that table).
static void createSimpleSparseNNModule(Module &mod, bool shareSplatInputs,
                                       bool addClipAndLayerNorm, bool addTile,
                                       bool addTanh, dim_t numFCLayers) {
  mod.clear();
  auto *F = mod.createFunction("test" );

  // Create SLS inputs
  std::vector<NodeValue> slsOutputs;
  const dim_t tableWidth = 16;
  const dim_t numIndices = 80;
  const dim_t batchSize = 32;
  const dim_t tableEntries = 10;
  const size_t tableNum = 5;
  // Based on how SLS table width is calculated below, the fcWidth
  // is the sum of all the SLS tables' width.
  const dim_t fcWidth = tableNum * (tableNum + 1) / 2 * tableWidth;

  NodeValue weights;
  NodeValue lengths;
  NodeValue scale;
  NodeValue bias;
  if (shareSplatInputs) {
    // Shared by FusedRowwiseQuantizedSparseLengthsWeightedSumNode
    auto ty =
        F->getParent()->uniqueType(ElemKind::FloatTy, {numIndices * batchSize});
    weights = F->createSplat("ones" , ty, 1.0)->getResult();
    lengths =
        F->createSplat(
             "lengths" ,
             F->getParent()->uniqueType(ElemKind::Int32ITy, {batchSize}), 1)
            ->getResult();
    // Shared by LayerNormalizationNode
    scale = F->createSplat("LN_scale" ,
                           F->getParent()->uniqueType(ElemKind::FloatTy,
                                                      {fcWidth / tableNum}),
                           1.0)
                ->getResult();
    bias = F->createSplat("LN_bias" ,
                          F->getParent()->uniqueType(ElemKind::FloatTy,
                                                     {fcWidth / tableNum}),
                          1.0)
               ->getResult();
  }

  // Create SLS portion
  for (int table = 0; table < tableNum; table++) {
    // With shared splat inputs all tables get the same width; otherwise the
    // width grows linearly with the table index (summing to fcWidth).
    dim_t thisTableWidth =
        shareSplatInputs ? fcWidth / tableNum : (table + 1) * tableWidth;
    Tensor data(ElemKind::FloatTy, {tableEntries, thisTableWidth});
    auto *indices = mod.createPlaceholder(
        ElemKind::Int64ITy, {numIndices * batchSize}, "indices" , false);
    if (!shareSplatInputs) {
      weights = mod.createPlaceholder(ElemKind::FloatTy,
                                      {numIndices * batchSize}, "w" , false)
                    ->getOutput();

      lengths = mod.createPlaceholder(ElemKind::Int32ITy, {batchSize},
                                      "lengths" , false)
                    ->getOutput();
      // For the tile case the first table uses a batch-1 lengths input, so
      // its SLS output must be tiled up to batchSize below.
      if (addTile && table == 0) {
        lengths =
            mod.createPlaceholder(ElemKind::Int32ITy, {1}, "lengths" , false)
                ->getOutput();
      }
    }
    // Alternate the average-length hint between tables.
    float avgLength = (table % 2) ? 12.0f : 10.0f;
    auto *slsOutput = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
        "SLS" , data, weights, indices, lengths, ElemKind::UInt8FusedQTy,
        /*useFP16Accumulation*/ false,
        /* lengthsMode */ LengthsMode::Variable, /* avgLength */ avgLength);

    if (addClipAndLayerNorm) {
      /* Clip */
      auto *clipped = F->createClip("SLS_clipped" , slsOutput, 0.0f, 70.0f);

      /* Layer Norm*/
      if (!shareSplatInputs) {
        Tensor scaleT(ElemKind::FloatTy, {thisTableWidth});
        scaleT.getHandle().randomize(0.0f, 1.0f, mod.getPRNG());
        scale = mod.createConstant("LN_scale" , std::move(scaleT));

        Tensor biasT(ElemKind::FloatTy, {thisTableWidth});
        biasT.getHandle().randomize(0.0f, 1.0f, mod.getPRNG());
        bias = mod.createConstant("LN_bias" , std::move(biasT));
      }
      auto *layerNormed = F->createLayerNormalization(
          "LN" , clipped->getResult().getType(), clipped, scale, bias, 1e-5);

      /* Clip */
      auto *layerNormedClipped =
          F->createClip("LN_clipped" , layerNormed, 0.0f, 70.0f);
      slsOutputs.emplace_back(layerNormedClipped);
    } else if (addTile && table == 0) {
      /* Tile */
      auto *tiled = F->createTile("SLS_tiled" , slsOutput, batchSize, 0);
      slsOutputs.emplace_back(tiled);
    } else if (addTanh) {
      /* Tanh */
      auto *tanh = F->createTanh("SLS_tanh" , slsOutput);
      slsOutputs.emplace_back(tanh);
    } else {
      slsOutputs.emplace_back(slsOutput);
    }
  }

  // Create Concat
  auto *concat = F->createConcat("concat" , slsOutputs, 1);
  Node *cur = (Node *)concat;

  // Create FC portion
  for (dim_t layer = 0; layer < numFCLayers; layer++) {
    Tensor FCWeights(ElemKind::FloatTy, {fcWidth, fcWidth});
    FCWeights.getHandle().randomize(-0.5, 0.5, mod.getPRNG());
    Constant *weights = mod.createConstant("FCWeights" , FCWeights);
    Tensor FCBias(ElemKind::FloatTy, {fcWidth});
    FCBias.getHandle().randomize(-0.5, 0.5, mod.getPRNG());
    Constant *bias = mod.createConstant("FCBias" , FCBias);

    auto *FC = F->createFullyConnected("FC" , cur, weights, bias);
    cur = (Node *)FC;
  }

  auto *save = F->createSave("ret" , cur);
  (void)save;
}
839 | |
840 | /// \returns true if there is \p nodeKind kind of nodes in \p func. |
841 | static bool findNodeInFunction(const Function *func, |
842 | const Kinded::Kind nodeKind) { |
843 | for (const Node &N : func->getNodes()) { |
844 | if (N.getKind() == nodeKind) { |
845 | return true; |
846 | } |
847 | } |
848 | return false; |
849 | } |
850 | |
/// Checks that the DAG generated for the SparseNN partition unit tests is
/// correct. The network being validated is the one built by
/// createSimpleSparseNNModule(). Expects 4 "CPU" partitions total: 3
/// SLS partitions (one logical device each) and 1 non-SLS (FC) partition
/// (3 logical devices), and that each SLS partition plus the non-SLS
/// partition fits within \p deviceMemory.
static void
sparseNNPartitionValidation(const DAGListTy &dagList, uint64_t deviceMemory,
                            Module &mod, bool shareSplatInputs,
                            bool addClipAndLayerNorm, bool pairLNWithSLS,
                            bool addTile, bool pairTileWithSLS, bool addTanh,
                            std::string sparseNNPartitioningPairSLSWith) {
  int numOfCPUBackends = 0;
  int numOfSLSNodes = 0;
  int numOfFCNodes = 0;
  std::unordered_set<uint64_t> slsPartitionSizes;
  uint64_t nonSlsPartitionSize = 0;
  for (auto &dag : dagList) {
    auto tileAdded = false;
    for (auto &node : dag.nodes) {
      // Every partition must have been assigned to the "CPU" backend.
      ASSERT_TRUE(node->backendName == "CPU" );
      numOfCPUBackends++;
      auto *func = mod.getFunction(node->name);
      GraphMemInfo memInfo = getFunctionMemory(func);

      if (shareSplatInputs) {
        // When inputs are shared splats, the partitioner should have cloned
        // the Splat nodes so each consumer owns a private copy.
        for (const Node &N : func->getNodes()) {
          if (const auto *SLWS = llvm::dyn_cast<
                  FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(&N)) {
            EXPECT_TRUE(llvm::isa<SplatNode>(SLWS->getWeights()));
            EXPECT_TRUE(llvm::isa<SplatNode>(SLWS->getLengths()));
            // weight/length node is splat node, partitioner will clone it,
            // thus each user has its own copy.
            EXPECT_EQ(SLWS->getWeights().getNumUsers(), 1);
            EXPECT_EQ(SLWS->getLengths().getNumUsers(), 1);
          } else if (const auto *SLWS =
                         llvm::dyn_cast<LayerNormalizationNode>(&N)) {
            EXPECT_TRUE(llvm::isa<SplatNode>(SLWS->getScale()));
            EXPECT_TRUE(llvm::isa<SplatNode>(SLWS->getBias()));
            // scale/bias node is splat node, partitioner will clone it, thus
            // each user has its own copy.
            EXPECT_EQ(SLWS->getScale().getNumUsers(), 1);
            EXPECT_EQ(SLWS->getBias().getNumUsers(), 1);
          }
        }
      }

      // Classify the partition as SLS or non-SLS (FC) and check the
      // pairing flags placed the decoration nodes in the right partition.
      if (findNodeInFunction(
              func,
              Kinded::Kind::
                  FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind)) {
        numOfSLSNodes++;
        slsPartitionSizes.insert(memInfo.getTotalMemSize());
        EXPECT_EQ(node->logicalDevices.size(), 1);
        if (addClipAndLayerNorm && pairLNWithSLS) {
          EXPECT_TRUE(findNodeInFunction(
              func, Kinded::Kind::LayerNormalizationNodeKind));
          EXPECT_TRUE(findNodeInFunction(func, Kinded::Kind::ClipNodeKind));
        }
        if (addTile && pairTileWithSLS) {
          // Only one SLS partition contains the Tile; accumulate and check
          // after the loop.
          tileAdded |= findNodeInFunction(func, Kinded::Kind::TileNodeKind);
        }
        if (addTanh &&
            sparseNNPartitioningPairSLSWith.find("Tanh" ) != std::string::npos) {
          EXPECT_TRUE(findNodeInFunction(func, Kinded::Kind::TanhNodeKind));
        }
      } else if (findNodeInFunction(func,
                                    Kinded::Kind::FullyConnectedNodeKind)) {
        nonSlsPartitionSize = memInfo.getTotalMemSize();
        numOfFCNodes++;
        EXPECT_EQ(node->logicalDevices.size(), 3);
        if (addClipAndLayerNorm && !pairLNWithSLS) {
          EXPECT_TRUE(findNodeInFunction(
              func, Kinded::Kind::LayerNormalizationNodeKind));
          EXPECT_TRUE(findNodeInFunction(func, Kinded::Kind::ClipNodeKind));
        }
        if (addTanh &&
            sparseNNPartitioningPairSLSWith.find("Tanh" ) == std::string::npos) {
          EXPECT_TRUE(findNodeInFunction(func, Kinded::Kind::TanhNodeKind));
        }
      } else {
        FAIL() << "Unexpected partition" ;
      }
    }
    if (addTile && pairTileWithSLS) {
      EXPECT_TRUE(tileAdded);
    }
  }

  // 4 partitions (3 SLS + 1 FC)
  EXPECT_EQ(numOfCPUBackends, 4);
  EXPECT_EQ(numOfSLSNodes, 3);
  EXPECT_EQ(numOfFCNodes, 1);
  // Each SLS partition is co-located with the non-SLS partition on a device,
  // so their combined size must fit in device memory.
  for (uint64_t slsPartitionSize : slsPartitionSizes) {
    EXPECT_LE(slsPartitionSize + nonSlsPartitionSize, deviceMemory);
  }
}
945 | |
/// Drives a SparseNN partition test end to end: builds the module via
/// createSimpleSparseNNModule(), partitions it across 3 identical "CPU"
/// devices with the SparseNN partitioning scheme enabled, and validates the
/// resulting DAG with sparseNNPartitionValidation(). When \p forceFailure is
/// set the model gets 5 FC layers instead of 4 (too big to fit), and
/// partitioning is expected to fail; it is also expected to fail when
/// concatSLSOutputs and addTile are set without pairTileWithSLS.
static void testSimpleSparseNNPartitioning(
    Module &mod, bool shareSplatInputs, bool concatSLSOutputs,
    bool balancePerfModel, bool addClipAndLayerNorm, bool pairLNWithSLS,
    bool addTile, bool pairTileWithSLS, bool addTanh,
    std::string sparseNNPartitioningPairSLSWith, bool forceFailure = false) {
  createSimpleSparseNNModule(mod, shareSplatInputs, addClipAndLayerNorm,
                             addTile, addTanh, forceFailure ? 5 : 4);
  BackendWithoutSub backend1, backend2, backend3;
  std::vector<Backend *> backends;
  backends.emplace_back(&backend1);
  backends.emplace_back(&backend2);
  backends.emplace_back(&backend3);
  const uint64_t deviceMemory = 1250000;
  std::vector<DeviceInfo> devices = {
      {deviceMemory, "CPU" }, {deviceMemory, "CPU" }, {deviceMemory, "CPU" }};
  Partitioner partitioner(&mod, devices, backends);
  CompilationContext cctx;
  cctx.optimizationOpts.useSparseNNPartitioningScheme = true;
  cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards = 3;
  cctx.optimizationOpts.sparseNNPartitioningAddSLSConcats = concatSLSOutputs;
  cctx.optimizationOpts.sparseNNPartitioningBalancePerfModel = balancePerfModel;
  cctx.optimizationOpts.sparseNNPartitioningPairLNWithSLS = pairLNWithSLS;
  cctx.optimizationOpts.sparseNNPartitioningPairTileWithSLS = pairTileWithSLS;
  cctx.optimizationOpts.sparseNNPartitioningPairSLSWith =
      sparseNNPartitioningPairSLSWith;
  Expected<DAGListTy> dagList = partitioner.partition(cctx);
  bool failed = ERR_TO_BOOL(dagList.takeError());
  if (forceFailure) {
    EXPECT_TRUE(failed);
    return;
  }
  if (concatSLSOutputs && addTile && !pairTileWithSLS) {
    EXPECT_TRUE(failed);
    return;
  }
  // Expect 4 partitions (3 SLS + 1 FC) in a single DAG.
  EXPECT_EQ(mod.getFunctions().size(), 4);
  EXPECT_EQ(dagList->size(), 1);
  ASSERT_TRUE(checkSaveNode(mod));
  sparseNNPartitionValidation(*dagList, deviceMemory, mod, shareSplatInputs,
                              addClipAndLayerNorm, pairLNWithSLS, addTile,
                              pairTileWithSLS, addTanh,
                              sparseNNPartitioningPairSLSWith);
  mod.clear();
}
990 | |
/// Baseline SparseNN partition test: no shared inputs and no SLS-output
/// decorations (no clip/LN, tile, or tanh).
TEST_F(PartitionerTest, SimpleSparseNNPartitioning) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "" );
}
1003 | |
/// Test that pairLNWithSLS is a NOP when no LayerNorm exists in the model.
/// NOTE: the argument-label comments here previously mislabeled the 8th and
/// 9th arguments; per the signature they are pairTileWithSLS then addTanh.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningPairLNNOP) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ true,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/
                                 false,
                                 /*pairSLSWith*/ "" );
}
1017 | |
/// Clip/LayerNorm are added but not paired with SLS, so they should land in
/// the non-SLS (FC) partition.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningClipAndLayerNormInNonSLS) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ true,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "" );
}
1030 | |
/// Clip/LayerNorm are added and paired with SLS, so they should land in the
/// SLS partitions.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningClipAndLayerNormInSLS) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ true,
                                 /*pairLNWithSLS*/ true,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "" );
}
1043 | |
/// Test SparseNN partitioning with SLS-output concats enabled
/// (sparseNNPartitioningAddSLSConcats) alongside Clip/LayerNorm in the
/// non-SLS partition.
TEST_F(PartitionerTest, SimpleSparseNNPartitioning_ConcatSLSOutputs) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ true,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ true,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "" );
}
1055 | |
/// Test using user-defined backends for SparseNN partition when inputs are
/// shared Splats by all SLSs (validation checks the Splats were cloned per
/// user).
TEST_F(PartitionerTest, SimpleSparseNNPartitioning_SharedSplatInputs) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ true,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ true,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "" );
}
1069 | |
/// Test using user-defined backends for SparseNN partition when inputs are
/// shared Splats by all SLSs, and LayerNorm is paired into the SLS
/// partitions.
TEST_F(PartitionerTest,
       SimpleSparseNNPartitioning_SharedSplatInputsAndLayerNormInSLS) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ true,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ true,
                                 /*pairLNWithSLS*/ true,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "" );
}
1084 | |
/// Test SparseNN partitioning with the perf-model-based balancing option
/// (sparseNNPartitioningBalancePerfModel) enabled.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningBalancePerfModel) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ true,
                                 /*addClipAndLayerNorm*/ true,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "" );
}
1097 | |
/// Test pairTileWithSLS is a NOP when no Tile exists in the model.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningPairTileNOP) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ true,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ true,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "" );
}
1110 | |
/// Test that a model where one SLS node has first dimension = 1 (tiled up
/// afterwards) partitions fine when concatSLSOutputs is NOT enabled.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningTileAndConcatSLSOutputs) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ true,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "" );
}
1124 | |
/// Test that concatSLSOutputs fails when one SLS node has first dimension =
/// 1 while the others share a larger first dimension and pairTileWithSLS is
/// not set (the helper expects partition() to return an error here).
TEST_F(PartitionerTest,
       SimpleSparseNNPartitioningTileAndConcatSlsOutputsFails) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ true,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ true,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "" );
}
1139 | |
/// Test that concatSLSOutputs works on a model where one SLS node has first
/// dimension = 1 as long as pairTileWithSLS is also set.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningPairTileAndConcatSlsOutputs) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ true,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ true,
                                 /*pairTileWithSLS*/ true,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "" );
}
1153 | |
/// Test pairSLSWith="Tanh" is a NOP when no Tanh exists in the model.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningPairTanhNOP) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ true,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "Tanh" );
}
1166 | |
/// Test pairing Tanh nodes into the SLS partitions via pairSLSWith="Tanh".
TEST_F(PartitionerTest, SimpleSparseNNPartitioningPairTanhAndSLS) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ true,
                                 /*pairSLSWith*/ "Tanh" );
}
1179 | |
/// This test checks that we fail partitioning when we have a SLSPartition and
/// NonSLSPartition that, when summed together, cannot fit inside a device
/// (forceFailure builds an oversized 5-FC-layer model).
TEST_F(PartitionerTest, SimpleSparseNNPartitioningExpectFailure) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ true,
                                 /*pairLNWithSLS*/ true,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "LayerNorm" ,
                                 /*forceFailure*/ true);
}
1194 | |
/// Checks that the DAG generated for the heterogeneous partition unit tests
/// is correct. The network being validated is the one built by
/// createSimpleModule(). Expects 2 Interpreter partitions (holding the 2 Sub
/// nodes) and 1 CPU partition (holding the 1 Mul node).
static void heterogeneousPartitionValidation(const DAGListTy &dagList,
                                             Module &mod) {
  int numOfInterpreterBackends = 0;
  int numOfCPUBackends = 0;
  int numOfMulNodes = 0;
  int numOfSubNodes = 0;
  for (auto &dag : dagList) {
    for (auto &node : dag.nodes) {
      // Although the saturateHost is set true, there is no saturating of the
      // host in a heterogeneous partition.
      EXPECT_EQ(node->logicalDevices.size(), 1);
      if (node->backendName == "CPU" ) {
        numOfCPUBackends++;
        auto func = mod.getFunction(node->name);
        // Sub Node should not be assigned to CPU backend.
        EXPECT_FALSE(findNodeInFunction(func, Kinded::Kind::SubNodeKind));
        numOfMulNodes +=
            (findNodeInFunction(func, Kinded::Kind::MulNodeKind) == true);
      }
      if (node->backendName == "Interpreter" ) {
        numOfInterpreterBackends++;
        auto func = mod.getFunction(node->name);
        // Mul Node should not be assigned to Interpreter backend.
        EXPECT_FALSE(findNodeInFunction(func, Kinded::Kind::MulNodeKind));
        numOfSubNodes +=
            (findNodeInFunction(func, Kinded::Kind::SubNodeKind) == true);
      }
    }
  }
  EXPECT_EQ(numOfInterpreterBackends, 2);
  EXPECT_EQ(numOfCPUBackends, 1);
  EXPECT_EQ(numOfSubNodes, 2);
  EXPECT_EQ(numOfMulNodes, 1);
}
1232 | |
/// Test using user-defined backends for heterogeneous partition.
TEST_F(PartitionerTest, SimpleHeterogeneousPartitioning) {
  createSimpleModule(mod_);
  BackendWithoutSub backendWithoutSub1;
  BackendWithoutMul backendWithoutMul1, backendWithoutMul2;
  // Create two kinds of backends which support different ops, then do the
  // partition by assigning the ops to the corresponding backends.
  std::vector<Backend *> backends;
  backends.emplace_back(&backendWithoutMul1);
  backends.emplace_back(&backendWithoutMul2);
  backends.emplace_back(&backendWithoutSub1);
  std::vector<DeviceInfo> devices = {
      {3072, "Interpreter" }, {3072, "Interpreter" }, {3072, "CPU" }};
  Partitioner partitioner(&mod_, devices, backends);
  CompilationContext cctx;
  cctx.saturateHost = true;
  auto dagList = partitioner.partition(cctx);
  ASSERT_TRUE((bool)dagList);
  // Expect 3 partitions in a single DAG, with Save nodes intact.
  EXPECT_EQ(mod_.getFunctions().size(), 3);
  EXPECT_EQ(dagList->size(), 1);
  ASSERT_TRUE(checkSaveNode(mod_));
  heterogeneousPartitionValidation(dagList.get(), mod_);

  mod_.clear();
}
1258 | |
/// Test pre-defined non-supported ops used for choosing backend in
/// Heterogeneous Partition. In this test, "Mul" is not supported in the
/// Interpreter backend, and "Sub" is not supported in the CPU backend.
TEST_F(PartitionerTest, heterogeneousPartitioningWithNonSupportedNodes) {
#ifndef GLOW_WITH_CPU
  return;
#endif
  createSimpleModule(mod_);
  // DeviceInfo entries: {memory size, backend name, non-supported nodes}.
  std::vector<DeviceInfo> devices = {{3072, "Interpreter" , "Mul" },
                                     {3072, "Interpreter" , "Mul" },
                                     {3072, "CPU" , "Sub" }};
  Partitioner partitioner(&mod_, devices);
  CompilationContext cctx;
  auto dagList = partitioner.partition(cctx);
  ASSERT_TRUE((bool)dagList);
  EXPECT_EQ(mod_.getFunctions().size(), 3);
  EXPECT_EQ(dagList->size(), 1);
  ASSERT_TRUE(checkSaveNode(mod_));
  heterogeneousPartitionValidation(dagList.get(), mod_);

  mod_.clear();
}
1281 | |
/// Test pre-defined supported ops used for choosing backend in Heterogeneous
/// Partition. In this test, "Mul" is not supported in the Interpreter
/// backend, and "Sub" is not supported in the CPU backend. "Sub,Add,Save"
/// can be supported in the Interpreter backend and "Mul,Add,Save" can be
/// supported in the CPU backend.
TEST_F(PartitionerTest, heterogeneousPartitioningWithSupportedNodes) {
#ifndef GLOW_WITH_CPU
  return;
#endif
  createSimpleModule(mod_);
  std::vector<DeviceInfo> devices = {
      // {memory size, backend, non-supported nodes, supported nodes}
      {3072, "Interpreter" , "" , "Sub,Add,Save" },
      {3072, "Interpreter" , "" , "Sub,Add,Save" },
      {3072, "CPU" , "" , "Mul,Add,Save" }};
  Partitioner partitioner(&mod_, devices);
  CompilationContext cctx;
  auto dagList = partitioner.partition(cctx);
  ASSERT_TRUE((bool)dagList);
  EXPECT_EQ(mod_.getFunctions().size(), 3);
  EXPECT_EQ(dagList->size(), 1);
  ASSERT_TRUE(checkSaveNode(mod_));
  heterogeneousPartitionValidation(dagList.get(), mod_);

  mod_.clear();
}
1307 | |
/// Test assigning more than one partition to one device for a single
/// backendName.
TEST_F(PartitionerTest, logicalIDTest0) {
  auto *input1 = mod_.createConstant(ElemKind::FloatTy, {1, 100}, "input1" );
  input1->getHandle<>().randomize(-10, 10, mod_.getPRNG());
  auto *input2 =
      mod_.createPlaceholder(ElemKind::FloatTy, {100, 1}, "input2" , false);
  auto *input3 = mod_.createConstant(ElemKind::FloatTy, {1, 100}, "input5" );
  input3->getHandle<>().randomize(-10, 10, mod_.getPRNG());
  auto *mul0 = F_->createMatMul("mul0" , input1, input2);
  auto *mul1 = F_->createMatMul("mul1" , mul0, input3);
  auto *save = F_->createSave("ret" , mul1);
  (void)save;
  // Two identical Interpreter devices with small memory, forcing the graph
  // to be split across them.
  std::vector<DeviceInfo> devices = {{1000, "Interpreter" },
                                     {1000, "Interpreter" }};
  Partitioner partitioner(&mod_, devices);
  CompilationContext cctx;
  cctx.saturateHost = true;
  auto dagList = partitioner.partition(cctx);
  ASSERT_TRUE((bool)dagList);
  // Check there are 2 partitions.
  EXPECT_EQ(mod_.getFunctions().size(), 2);
  EXPECT_EQ(dagList->size(), 1);
  ASSERT_TRUE(checkSaveNode(mod_));

  for (auto &dag : dagList.get()) {
    // Check number of logical devices;
    llvm::SmallSet<DeviceIDTy, 4> usedID;
    for (auto &node : dag.nodes) {
      EXPECT_EQ(node->logicalDevices.size(), 1);
      usedID.insert(node->logicalDevices[0]);
    }
    // Check there are 2 logical devices.
    EXPECT_EQ(usedID.size(), 2);
  }
  mod_.clear();
}
1347 | |
1348 | /// Test assigning more than one partitions in to one device in Heterogeneous |
1349 | /// partition. |
1350 | TEST_F(PartitionerTest, logicalIDTest1) { |
1351 | createSimpleModule(mod_); |
1352 | BackendWithoutSub backendWithoutSub1, backendWithoutSub2; |
1353 | BackendWithoutMul backendWithoutMul1, backendWithoutMul2; |
1354 | // Create two backends which support different ops, then do the partition by |
1355 | // assigning the ops to the corresponding abackends. |
1356 | std::vector<Backend *> backends; |
1357 | backends.emplace_back(&backendWithoutMul1); |
1358 | backends.emplace_back(&backendWithoutSub1); |
1359 | std::vector<DeviceInfo> devices = {{3072, "Interpreter" }, {3072, "CPU" }}; |
1360 | Partitioner partitioner(&mod_, devices, backends); |
1361 | CompilationContext cctx; |
1362 | cctx.saturateHost = true; |
1363 | auto dagList = partitioner.partition(cctx); |
1364 | ASSERT_TRUE((bool)dagList); |
1365 | EXPECT_EQ(mod_.getFunctions().size(), 3); |
1366 | EXPECT_EQ(dagList->size(), 1); |
1367 | ASSERT_TRUE(checkSaveNode(mod_)); |
1368 | |
1369 | for (auto &dag : dagList.get()) { |
1370 | // Check number of logical devices; |
1371 | llvm::SmallSet<DeviceIDTy, 4> usedID; |
1372 | for (auto &node : dag.nodes) { |
1373 | // Although the saturateHost is set true, no saturating the host in |
1374 | // heterogeneous partiton. |
1375 | EXPECT_EQ(node->logicalDevices.size(), 1); |
1376 | usedID.insert(node->logicalDevices[0]); |
1377 | } |
1378 | EXPECT_EQ(usedID.size(), 2); |
1379 | } |
1380 | mod_.clear(); |
1381 | } |
1382 | |
/// Check the functions getGraphMemInfo and updateGraphMemInfoByAddingNode
/// (PartitionerUtils.cpp) on a node with more than one output (TopK).
TEST_F(PartitionerTest, graphMemInfoCalculation1) {
  // TODO: The values are too large when dim_t is 32b. Figure out how it's
  // computed and ensure it's computed correctly.
  if (DIM_T_BITWIDTH == 32)
    return;
  auto *inp1 =
      mod_.createPlaceholder(ElemKind::FloatTy, {2, 1, 3}, "input" , false);
  auto *inp2 =
      mod_.createPlaceholder(ElemKind::FloatTy, {2, 1, 3}, "input" , false);
  auto *indices =
      mod_.createPlaceholder(ElemKind::Int64ITy, {4, 1, 2}, "indices" , false);

  // TopK has two outputs: values and indices.
  auto *R1 = F_->createTopK("TopK1" , inp1, 2);
  auto *R2 = F_->createTopK("TopK2" , inp2, 2);

  // Concat the values and indices separately, both on the 0th dimension,
  // matching the shapes of the values and indices variables above.
  auto *CV =
      F_->createConcat("Concat.Values" , {R1->getValues(), R2->getValues()}, 0);
  auto *CI = F_->createConcat("Concat.Indices" ,
                              {R1->getIndices(), R2->getIndices()}, 0);

  auto *saveValues = F_->createSave("Save.Values" , CV);
  auto *saveIndices = F_->createSave("Save.Indices" , CI, indices);

  // Incrementally add nodes and verify the running GraphMemInfo
  // (inMemSize, outMemSize, constMemSize) after each addition.
  std::set<Node *> nodes1;
  GraphMemInfo res;
  res = updateGraphMemInfoByAddingNode(nodes1, res, R1);
  EXPECT_EQ(res, GraphMemInfo(24, 48, 0));
  nodes1.insert(R1);

  res = updateGraphMemInfoByAddingNode(nodes1, res, R2);
  EXPECT_EQ(res, GraphMemInfo(48, 96, 0));
  nodes1.insert(R2);

  res = updateGraphMemInfoByAddingNode(nodes1, res, CV);
  EXPECT_EQ(res, GraphMemInfo(48, 96, 0));
  nodes1.insert(CV);

  res = updateGraphMemInfoByAddingNode(nodes1, res, CI);
  EXPECT_EQ(res, GraphMemInfo(48, 96, 0));
  nodes1.insert(CI);

  res = updateGraphMemInfoByAddingNode(nodes1, res, saveValues);
  EXPECT_EQ(res, GraphMemInfo(48, 96, 0));
  nodes1.insert(saveValues);

  res = updateGraphMemInfoByAddingNode(nodes1, res, saveIndices);
  EXPECT_EQ(res, GraphMemInfo(48, 96, 0));
  nodes1.insert(saveIndices);

  // Now compute memory info for two disjoint node sets (the TopKs vs. the
  // concats/saves) directly via getGraphMemInfo.
  std::set<Node *> nodes2, nodes3;
  nodes2.insert(R1);
  nodes2.insert(R2);
  nodes3.insert(CV);
  nodes3.insert(CI);
  nodes3.insert(saveValues);
  nodes3.insert(saveIndices);
  GraphMemInfo res1 = getGraphMemInfo(nodes2, 1);
  GraphMemInfo res2 = getGraphMemInfo(nodes3, 1);
  GraphMemInfo ref1(48, 96, 0);
  GraphMemInfo ref2(96, 96, 0);
  EXPECT_EQ(res1, ref1);
  EXPECT_EQ(res2, ref2);
}
1450 | |
/// Check the functions updateGraphMemInfoByAddingNode and getGraphMemInfo
/// (PartitionerUtils.cpp) when Storage nodes (w2/b2) are shared by nodes in
/// different partitions.
TEST_F(PartitionerTest, graphMemInfoCalculation2) {
  auto *input =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 16}, "input" , false);

  // Left branch.
  auto *w2 = mod_.createConstant(ElemKind::FloatTy, {16, 16}, "w2" );
  auto *b2 = mod_.createConstant(ElemKind::FloatTy, {16}, "b2" );
  auto *L = F_->createFullyConnected("left_fc1" , input, w2, b2);
  auto *L1 = F_->createSigmoid("left_sigmoid1" , L);
  auto *w3 = mod_.createConstant(ElemKind::FloatTy, {16, 8}, "w3" );
  auto *b3 = mod_.createConstant(ElemKind::FloatTy, {8}, "b3" );
  auto *L2 = F_->createFullyConnected("left_fc2" , L1, w3, b3);
  auto *L3 = F_->createSigmoid("left_sigmoid2" , L2);

  // Right branch. Note: right_fc1 reuses w2/b2 from the left branch, so the
  // constant memory must not be double-counted within one node set.
  auto *R = F_->createFullyConnected("right_fc1" , input, w2, b2);
  auto *R1 = F_->createSigmoid("right_sigmoid1" , R);
  auto *w5 = mod_.createConstant(ElemKind::FloatTy, {16, 8}, "w5" );
  auto *b5 = mod_.createConstant(ElemKind::FloatTy, {8}, "b5" );
  auto *R2 = F_->createFullyConnected("right_fc2" , R1, w5, b5);
  auto *R3 = F_->createSigmoid("right_sigmoid2" , R2);

  // Join branches.
  auto *mul = F_->createMul("mul" , L3, R3);
  auto *save = F_->createSave("ret" , mul);

  // Incrementally build two node sets and verify the running GraphMemInfo
  // (inMemSize, outMemSize, constMemSize) after each addition.
  std::set<Node *> nodes1, nodes2;
  GraphMemInfo res1, res2;
  res1 = updateGraphMemInfoByAddingNode(nodes1, res1, L);
  EXPECT_EQ(res1, GraphMemInfo(64, 64, 1088));
  nodes1.insert(L);

  res1 = updateGraphMemInfoByAddingNode(nodes1, res1, R);
  EXPECT_EQ(res1, GraphMemInfo(64, 128, 1088));
  nodes1.insert(R);

  res1 = updateGraphMemInfoByAddingNode(nodes1, res1, R1);
  EXPECT_EQ(res1, GraphMemInfo(64, 128, 1088));
  nodes1.insert(R1);

  res1 = updateGraphMemInfoByAddingNode(nodes1, res1, R2);
  EXPECT_EQ(res1, GraphMemInfo(64, 96, 1632));
  nodes1.insert(R2);

  res1 = getGraphMemInfo(nodes1, 1);
  EXPECT_EQ(res1, GraphMemInfo(64, 96, 1632));

  res2 = updateGraphMemInfoByAddingNode(nodes2, res2, L1);
  EXPECT_EQ(res2, GraphMemInfo(64, 64, 0));
  nodes2.insert(L1);

  res2 = updateGraphMemInfoByAddingNode(nodes2, res2, L2);
  EXPECT_EQ(res2, GraphMemInfo(64, 32, 544));
  nodes2.insert(L2);

  res2 = updateGraphMemInfoByAddingNode(nodes2, res2, L3);
  EXPECT_EQ(res2, GraphMemInfo(64, 32, 544));
  nodes2.insert(L3);

  res2 = updateGraphMemInfoByAddingNode(nodes2, res2, mul);
  EXPECT_EQ(res2, GraphMemInfo(96, 32, 544));
  nodes2.insert(mul);

  res2 = updateGraphMemInfoByAddingNode(nodes2, res2, R3);
  EXPECT_EQ(res2, GraphMemInfo(96, 32, 544));
  nodes2.insert(R3);

  res2 = updateGraphMemInfoByAddingNode(nodes2, res2, save);
  EXPECT_EQ(res2, GraphMemInfo(96, 32, 544));
  nodes2.insert(save);

  res2 = getGraphMemInfo(nodes2, 1);
  EXPECT_EQ(res2, GraphMemInfo(96, 32, 544));
}
1527 | |
1528 | /// Check the function getFunctionMemory in PartitionerUtils.cpp to compute |
1529 | /// memory consumption of a simple function with same inputs used for multiple |
1530 | /// nodes. |
1531 | TEST_F(PartitionerTest, funcMemInfoCalculation1) { |
1532 | auto *input1 = |
1533 | mod_.createPlaceholder(ElemKind::FloatTy, {16}, "input1" , false); |
1534 | auto *input2 = |
1535 | mod_.createPlaceholder(ElemKind::FloatTy, {16}, "input2" , false); |
1536 | auto *input3 = |
1537 | mod_.createPlaceholder(ElemKind::FloatTy, {16}, "input3" , false); |
1538 | auto *sub = F_->createSub("sub" , input1, input2); |
1539 | auto *mul = F_->createMul("mul" , input1, input2); |
1540 | auto *sum = F_->createAdd("add" , sub, mul); |
1541 | auto *sub2 = F_->createSub("sub1" , sum, input3); |
1542 | auto *save = F_->createSave("ret" , sub2); |
1543 | (void)save; |
1544 | |
1545 | GraphMemInfo info = getFunctionMemory(F_); |
1546 | // 3x input Tensors of 16 fp32 each and 1x output of 16 fp32 values. |
1547 | GraphMemInfo res(192, 64, 0); |
1548 | EXPECT_EQ(res, info); |
1549 | } |
1550 | |
1551 | /// Check the function getFunctionMemory in PartitionerUtils.cpp to compute |
1552 | /// memory consumption of a function with constants. |
1553 | TEST_F(PartitionerTest, funcMemInfoCalculation2) { |
1554 | auto *input = |
1555 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 16}, "input" , false); |
1556 | |
1557 | // Left branch. |
1558 | auto *w2 = mod_.createConstant(ElemKind::FloatTy, {16, 16}, "w2" ); |
1559 | auto *b2 = mod_.createConstant(ElemKind::FloatTy, {16}, "b2" ); |
1560 | auto *L = F_->createFullyConnected("left_fc1" , input, w2, b2); |
1561 | auto *L1 = F_->createSigmoid("left_sigmoid1" , L); |
1562 | auto *w3 = mod_.createConstant(ElemKind::FloatTy, {16, 8}, "w3" ); |
1563 | auto *b3 = mod_.createConstant(ElemKind::FloatTy, {8}, "b3" ); |
1564 | auto *L2 = F_->createFullyConnected("left_fc2" , L1, w3, b3); |
1565 | auto *L3 = F_->createSigmoid("left_sigmoid2" , L2); |
1566 | |
1567 | // Right branch. |
1568 | auto *R = F_->createFullyConnected("right_fc1" , input, w2, b2); |
1569 | auto *R1 = F_->createSigmoid("right_sigmoid1" , R); |
1570 | auto *w5 = mod_.createConstant(ElemKind::FloatTy, {16, 8}, "w5" ); |
1571 | auto *b5 = mod_.createConstant(ElemKind::FloatTy, {8}, "b5" ); |
1572 | auto *R2 = F_->createFullyConnected("right_fc2" , R1, w5, b5); |
1573 | auto *R3 = F_->createSigmoid("right_sigmoid2" , R2); |
1574 | |
1575 | // Join branches. |
1576 | auto *mul = F_->createMul("mul" , L3, R3); |
1577 | auto *save = F_->createSave("ret" , mul); |
1578 | (void)save; |
1579 | |
1580 | GraphMemInfo info = getFunctionMemory(F_); |
1581 | // single input tensor (1*16) fp32 |
1582 | // single output tensor (1*8) fp32 |
1583 | // constants fp32 (16*16 + 16 + 16*8 + 8 + 16*8 + 8) |
1584 | GraphMemInfo res(64, 32, 2176); |
1585 | EXPECT_EQ(res, info); |
1586 | } |
1587 | |
1588 | /// Check the function getFunctionMemory in PartitionerUtils.cpp to compute |
1589 | /// memory consumption of a function with same inputs used for multiple |
1590 | /// nodes. |
1591 | TEST_F(PartitionerTest, funcMemInfoCalculation3) { |
1592 | auto *input1 = |
1593 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1" , false); |
1594 | auto *input2 = |
1595 | mod_.createPlaceholder(ElemKind::FloatTy, {10, 16}, "input2" , false); |
1596 | auto *input3 = |
1597 | mod_.createPlaceholder(ElemKind::FloatTy, {16, 20}, "input3" , false); |
1598 | auto *input4 = |
1599 | mod_.createPlaceholder(ElemKind::FloatTy, {20, 1}, "input4" , false); |
1600 | auto *input5 = |
1601 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 50}, "input5" , false); |
1602 | auto *mul0 = F_->createMatMul("mul0" , input1, input2); |
1603 | auto *mul1 = F_->createMatMul("mul1" , mul0, input3); |
1604 | auto *mul2 = F_->createMatMul("mul2" , mul1, input4); |
1605 | auto *mul3 = F_->createMatMul("mul3" , mul2, input5); |
1606 | auto *save = F_->createSave("ret" , mul3); |
1607 | (void)save; |
1608 | |
1609 | GraphMemInfo info = getFunctionMemory(F_); |
1610 | // input consists of 5 tensors (2*10 + 10*16 + 16*20 + 20*1 + 1*50 = 570) fp32 |
1611 | // output is tensor of 2*50 = 100 fp32 |
1612 | GraphMemInfo res(2280, 400, 0); |
1613 | EXPECT_EQ(res, info); |
1614 | } |
1615 | |
1616 | /// This one test the memoryUsageValidation in Partitioner : the memory usage |
1617 | /// of one single node is larger than the given device memory. |
1618 | TEST_F(PartitionerTest, memoryUsageValidation1) { |
1619 | auto *input1 = |
1620 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1" , false); |
1621 | auto *input2 = |
1622 | mod_.createPlaceholder(ElemKind::FloatTy, {10, 16}, "input2" , false); |
1623 | auto *mul0 = F_->createMatMul("mul0" , input1, input2); |
1624 | F_->createSave("ret" , mul0); |
1625 | |
1626 | std::vector<DeviceInfo> devices = {{500, "Interpreter" }, |
1627 | {500, "Interpreter" }}; |
1628 | Partitioner myPartitioner(&mod_, devices); |
1629 | CompilationContext cctx; |
1630 | auto dagList = myPartitioner.partition(cctx); |
1631 | EXPECT_TRUE(ERR_TO_BOOL(dagList.takeError())); |
1632 | } |
1633 | |
1634 | /// This one test dagValidation in partitioner : p1->p2, p2->p1. |
1635 | TEST_F(PartitionerTest, dagValidationWithBackendHints) { |
1636 | auto *input1 = |
1637 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1" , false); |
1638 | auto *input2 = |
1639 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input2" , false); |
1640 | auto *input3 = |
1641 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input3" , false); |
1642 | auto *add1 = F_->createAdd("add1" , input1, input2); |
1643 | auto *add2 = F_->createAdd("add2" , add1, input3); |
1644 | auto *sub1 = F_->createSub("sub1" , add1, add2); |
1645 | F_->createSave("save" , sub1); |
1646 | |
1647 | std::vector<DeviceInfo> devices = {{3072, "Interpreter" }, |
1648 | {3072, "Interpreter" }}; |
1649 | |
1650 | // User-defined partition: p1->p2, p2->p1. |
1651 | PartitionConfig partitionConfig; |
1652 | partitionConfig.funcName = "main" ; |
1653 | partitionConfig.numOfPartitions = 2; |
1654 | BackendHints bh1, bh2; |
1655 | bh1.executionUnits = 2; |
1656 | bh2.executionUnits = 3; |
1657 | partitionConfig.backendHints = {bh1, bh2}; |
1658 | partitionConfig.backendNames = {"Interpreter" , "Interpreter" }; |
1659 | partitionConfig.partitionNames = {"p1" , "p2" }; |
1660 | partitionConfig.nodeToPartition = {{"add2" , 0}}; |
1661 | auto partitioner = Partitioner(&mod_, devices, false, partitionConfig); |
1662 | CompilationContext cctx; |
1663 | auto dagList = partitioner.partition(cctx); |
1664 | EXPECT_TRUE(ERR_TO_BOOL(dagList.takeError())); |
1665 | } |
1666 | |
1667 | /// This one test dagValidation in partitioner : p1->p2, p2->p1. |
1668 | TEST_F(PartitionerTest, dagValidation1) { |
1669 | auto *input1 = |
1670 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1" , false); |
1671 | auto *input2 = |
1672 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input2" , false); |
1673 | auto *input3 = |
1674 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input3" , false); |
1675 | auto *add1 = F_->createAdd("add1" , input1, input2); |
1676 | auto *add2 = F_->createAdd("add2" , add1, input3); |
1677 | auto *sub1 = F_->createSub("sub1" , add1, add2); |
1678 | F_->createSave("save" , sub1); |
1679 | |
1680 | std::vector<DeviceInfo> devices = {{3072, "Interpreter" }, |
1681 | {3072, "Interpreter" }}; |
1682 | |
1683 | // User-defined partition: p1->p2, p2->p1. |
1684 | PartitionConfig partitionConfig; |
1685 | partitionConfig.funcName = "main" ; |
1686 | partitionConfig.numOfPartitions = 2; |
1687 | partitionConfig.backendNames = {"Interpreter" , "Interpreter" }; |
1688 | partitionConfig.partitionNames = {"p1" , "p2" }; |
1689 | partitionConfig.nodeToPartition = {{"add2" , 0}}; |
1690 | auto partitioner = Partitioner(&mod_, devices, false, partitionConfig); |
1691 | CompilationContext cctx; |
1692 | auto dagList = partitioner.partition(cctx); |
1693 | EXPECT_TRUE(ERR_TO_BOOL(dagList.takeError())); |
1694 | } |
1695 | |
1696 | /// This one test dagValidation in partitioner: p0->p1, p1->p2, p2->p1. |
1697 | TEST_F(PartitionerTest, dagValidation2) { |
1698 | auto *input1 = |
1699 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1" , false); |
1700 | auto *input2 = |
1701 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input2" , false); |
1702 | auto *input3 = |
1703 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input3" , false); |
1704 | auto *input4 = |
1705 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input4" , false); |
1706 | auto *add0 = F_->createAdd("add0" , input1, input2); |
1707 | auto *add1 = F_->createAdd("add1" , add0, input3); |
1708 | auto *add2 = F_->createAdd("add2" , add1, input4); |
1709 | auto *sub1 = F_->createSub("sub1" , add1, add2); |
1710 | F_->createSave("save" , sub1); |
1711 | |
1712 | std::vector<DeviceInfo> devices = { |
1713 | {3072, "Interpreter" }, {3072, "Interpreter" }, {3072, "Interpreter" }}; |
1714 | |
1715 | // User-defined partition: p0->p1, p1->p2, p2->p1. |
1716 | PartitionConfig partitionConfig; |
1717 | partitionConfig.funcName = "main" ; |
1718 | partitionConfig.numOfPartitions = 3; |
1719 | partitionConfig.backendNames = {"Interpreter" , "Interpreter" , "Interpreter" }; |
1720 | partitionConfig.partitionNames = {"p0" , "p1" , "p2" }; |
1721 | partitionConfig.nodeToPartition = {{"add0" , 0}, {"add2" , 2}}; |
1722 | auto partitioner = Partitioner(&mod_, devices, false, partitionConfig); |
1723 | CompilationContext cctx; |
1724 | auto dagList = partitioner.partition(cctx); |
1725 | EXPECT_TRUE(ERR_TO_BOOL(dagList.takeError())); |
1726 | } |
1727 | |
1728 | /// This one tests partition from a user-defined config. |
TEST_F(PartitionerTest, partitionFromConfig) {
#ifndef GLOW_WITH_CPU
  // The config below targets the CPU backend; skip when it isn't built.
  return;
#endif
  createSimpleModule(mod_);
  std::vector<DeviceInfo> devices = {
      {3072, "Interpreter"}, {3072, "Interpreter"}, {3072, "CPU"}};

  // User-defined partition: 3 partitions (2 interpreter, 1 cpu), Mul nodes to
  // CPU, others to Interpreter.
  PartitionConfig partitionConfig;
  partitionConfig.funcName = "test";
  partitionConfig.numOfPartitions = 3;
  partitionConfig.backendNames = {"Interpreter", "CPU", "Interpreter"};
  partitionConfig.partitionNames = {"p1", "p2", "p3"};
  partitionConfig.nodeToPartition = {{"sub", 0}, {"mul", 1}};
  // Pass the config via the Partitioner constructor (the direct-call variant
  // is covered by partitionFromConfigDirectCall below).
  Partitioner partitioner(&mod_, devices, false, partitionConfig);
  CompilationContext cctx;
  auto dagList = partitioner.partition(cctx);
  // Expect a successful partition: 3 functions in one DAG, with Save nodes
  // intact and the heterogeneous backend assignment respected.
  ASSERT_TRUE((bool)dagList);
  EXPECT_EQ(mod_.getFunctions().size(), 3);
  EXPECT_EQ(dagList->size(), 1);
  ASSERT_TRUE(checkSaveNode(mod_));
  heterogeneousPartitionValidation(dagList.get(), mod_);
}
1754 | |
1755 | /// Test user-defined partition with user specified logical devices through |
1756 | /// compilationContext. |
TEST_F(PartitionerTest, partitionFromConfigWithLogicalDevices) {
  auto *input1 =
      mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1", false);
  auto *input2 =
      mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input2", false);
  auto *input3 =
      mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input3", false);
  auto *add1 = F_->createAdd("add1", input1, input2);
  auto *add2 = F_->createAdd("add2", add1, input3);
  auto *sub1 = F_->createSub("sub1", add1, add2);
  F_->createSave("save", sub1);

  std::vector<DeviceInfo> devices = {
      {3072, "Interpreter"}, {3072, "Interpreter"}, {3072, "Interpreter"}};

  // User-defined partition into p0/p1/p2 where the logicalIDs entry assigns
  // partition p2 to two logical devices (0 and 1).
  PartitionConfig partitionConfig;
  partitionConfig.funcName = "main";
  partitionConfig.numOfPartitions = 3;
  partitionConfig.backendNames = {"Interpreter", "Interpreter", "Interpreter"};
  partitionConfig.partitionNames = {"p0", "p1", "p2"};
  partitionConfig.nodeToPartition = {{"add1", 0}, {"add2", 2}};
  partitionConfig.logicalIDs = {{0}, {1}, {0, 1}};
  // The config is delivered through the CompilationContext rather than the
  // Partitioner constructor.
  auto partitioner = Partitioner(&mod_, devices);
  CompilationContext cctx;
  cctx.partitionConfig = &partitionConfig;
  auto result = partitioner.partition(cctx);
  DAGListTy nodeList;
  EXPECT_FALSE(ERR_TO_BOOL(result.takeError()));
  nodeList = std::move(result.get());
  // Check that p2 has both 0 and 1 for logicalDevices.
  EXPECT_EQ(nodeList[0].nodes[2]->logicalDevices[0], 0);
  EXPECT_EQ(nodeList[0].nodes[2]->logicalDevices[1], 1);
}
1791 | |
1792 | /// Test user-defined partition with user specified logical devices through |
1793 | /// compilationContext using fp16. |
TEST_F(PartitionerTest, partitionFromConfigWithLogicalDevicesFp16) {
  auto *input1 =
      mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1", false);
  auto *input2 =
      mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input2", false);
  auto *input3 =
      mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input3", false);
  auto *add1 = F_->createAdd("add1", input1, input2);
  auto *add2 = F_->createAdd("add2", add1, input3);
  auto *sub1 = F_->createSub("sub1", add1, add2);
  F_->createSave("save", sub1);

  std::vector<DeviceInfo> devices = {
      {3072, "Interpreter"}, {3072, "Interpreter"}, {3072, "Interpreter"}};

  // User-defined partition into p0/p1/p2; p2 is assigned two logical devices
  // (0 and 1) via logicalIDs.
  PartitionConfig partitionConfig;
  partitionConfig.funcName = "main";
  partitionConfig.numOfPartitions = 3;
  partitionConfig.backendNames = {"Interpreter", "Interpreter", "Interpreter"};
  partitionConfig.partitionNames = {"p0", "p1", "p2"};
  partitionConfig.nodeToPartition = {{"add1", 0}, {"add2", 2}};
  partitionConfig.logicalIDs = {{0}, {1}, {0, 1}};
  auto partitioner = Partitioner(&mod_, devices);
  CompilationContext cctx;
  cctx.partitionConfig = &partitionConfig;
  // Request conversion of the graph to FP16 during optimization.
  PrecisionConfiguration pc;
  pc.convertToFP16 = true;
  cctx.precisionConfig = pc;
  auto result = partitioner.partition(cctx);

  // Do optimization
  // NOTE(review): the outer loop re-runs the inner optimization of every
  // function numOfPartitions times; a single pass over mod_.getFunctions()
  // looks sufficient -- presumably harmless only because optimizeFunction is
  // idempotent here. Confirm before simplifying.
  for (dim_t i = 0; i < partitionConfig.numOfPartitions; i++) {
    for (auto &func : mod_.getFunctions()) {
      std::unique_ptr<Backend> backend(createBackend("Interpreter"));
      auto err = ::glow::optimizeFunction(func, *backend, cctx);
      EXPECT_FALSE(err);
    }
  }

  DAGListTy nodeList;
  EXPECT_FALSE(ERR_TO_BOOL(result.takeError()));
  nodeList = std::move(result.get());
  // Check that p2 has both 0 and 1 for logicalDevices.
  EXPECT_EQ(nodeList[0].nodes[2]->logicalDevices[0], 0);
  EXPECT_EQ(nodeList[0].nodes[2]->logicalDevices[1], 1);
  // Check that the inputs and outputs of add1, add2 and sub1 are in fp16.
  // Save and ConvertTo nodes are exempt: they bridge fp16 internals to the
  // fp32 placeholders.
  for (auto const &F : mod_.getFunctions()) {
    for (auto const &N : F->getNodes()) {
      auto NI = NodeInfo(N);
      if (NI.getKind() != Kinded::Kind::SaveNodeKind &&
          NI.getKind() != Kinded::Kind::ConvertToNodeKind) {
        EXPECT_TRUE(
            NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Float16Ty}));
      }
    }
  }
}
1852 | |
1853 | /// This one tests calling PartitionFromConfig directly. |
1854 | TEST_F(PartitionerTest, partitionFromConfigDirectCall) { |
1855 | #ifndef GLOW_WITH_CPU |
1856 | return; |
1857 | #endif |
1858 | createSimpleModule(mod_); |
1859 | std::vector<DeviceInfo> devices = { |
1860 | {3072, "Interpreter" }, {3072, "Interpreter" }, {3072, "CPU" }}; |
1861 | |
1862 | // User-defined partition: 3 partitions (2 interpreter, 1 cpu), Mul nodes to |
1863 | // CPU, others to Interpreter. |
1864 | PartitionConfig partitionConfig; |
1865 | partitionConfig.funcName = "test" ; |
1866 | partitionConfig.numOfPartitions = 3; |
1867 | partitionConfig.backendNames = {"Interpreter" , "CPU" , "Interpreter" }; |
1868 | partitionConfig.partitionNames = {"p1" , "p2" , "p3" }; |
1869 | partitionConfig.nodeToPartition = {{"sub" , 0}, {"mul" , 1}}; |
1870 | Partitioner partitioner(&mod_, devices); |
1871 | CompilationContext cctx; |
1872 | auto dagList = partitioner.partitionFromConfig(partitionConfig, cctx); |
1873 | ASSERT_TRUE((bool)dagList); |
1874 | EXPECT_EQ(mod_.getFunctions().size(), 3); |
1875 | EXPECT_EQ(dagList->size(), 1); |
1876 | ASSERT_TRUE(checkSaveNode(mod_)); |
1877 | heterogeneousPartitionValidation(dagList.get(), mod_); |
1878 | } |
1879 | |
1880 | /// This one test load-balanced partition flow. |
TEST_F(PartitionerTest, loadBalancedPartition) {
  ExecutionEngine EER, EEP;
  // Keep the partitioned module un-stripped so the DAG can be executed
  // manually via executeDAG below.
  EEP.setSkipModuleStrip(true);
  constexpr float range = 2.0;
  std::vector<ExecutionEngine *> engines{&EER, &EEP};
  // Since compiling modifies the module and partitioning modifies the
  // function, setup two EEs with identical functions for validation.
  for (auto EE : engines) {
    auto mod = &EE->getModule();
    F_ = mod->createFunction("main");
    auto *input =
        mod->createPlaceholder(ElemKind::FloatTy, {1, 32}, "input", false);
    auto *w1 = mod->createConstant(ElemKind::FloatTy, {32, 16}, "w1");
    auto *b1 = mod->createConstant(ElemKind::FloatTy, {16}, "b1");
    bindings_.allocate(input);
    w1->getHandle<>().randomize(-range, range, mod->getPRNG());
    b1->getHandle<>().randomize(-range, range, mod->getPRNG());

    // Initial FC.
    Node *I = F_->createFullyConnected("initial_fc", input, w1, b1);
    I = F_->createSigmoid("initial_sigmoid", I);

    // Left branch.
    auto *w2 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w2");
    auto *b2 = mod->createConstant(ElemKind::FloatTy, {16}, "b2");
    w2->getHandle<>().randomize(-range, range, mod->getPRNG());
    b2->getHandle<>().randomize(-range, range, mod->getPRNG());
    Node *L = F_->createFullyConnected("left_fc1", I, w2, b2);
    L = F_->createSigmoid("left_sigmoid1", L);
    auto *w3 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w3");
    auto *b3 = mod->createConstant(ElemKind::FloatTy, {8}, "b3");
    w3->getHandle<>().randomize(-range, range, mod->getPRNG());
    b3->getHandle<>().randomize(-range, range, mod->getPRNG());
    L = F_->createFullyConnected("left_fc2", L, w3, b3);
    L = F_->createSigmoid("left_sigmoid2", L);

    // Right branch.
    auto *w4 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w4");
    auto *b4 = mod->createConstant(ElemKind::FloatTy, {16}, "b4");
    w4->getHandle<>().randomize(-range, range, mod->getPRNG());
    b4->getHandle<>().randomize(-range, range, mod->getPRNG());
    Node *R = F_->createFullyConnected("right_fc1", I, w4, b4);
    R = F_->createSigmoid("right_sigmoid1", R);
    auto *w5 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w5");
    auto *b5 = mod->createConstant(ElemKind::FloatTy, {8}, "b5");
    w5->getHandle<>().randomize(-range, range, mod->getPRNG());
    b5->getHandle<>().randomize(-range, range, mod->getPRNG());
    R = F_->createFullyConnected("right_fc2", R, w5, b5);
    R = F_->createSigmoid("right_sigmoid2", R);

    // Join branches.
    auto *mul = F_->createMul("mul", L, R);
    F_->createSave("ret", mul);
  }

  // Infer using the un-partitioned graph.
  Tensor in(ElemKind::FloatTy, {1, 32});
  in.getHandle<>().randomize(-range, range, EER.getModule().getPRNG());

  EER.compile(CompilationMode::Infer);
  bindings_.clear();
  bindings_.allocate(EER.getModule().getPlaceholders());
  updateInputPlaceholders(bindings_,
                          {bindings_.getPlaceholderByNameSlow("input")}, {&in});
  EER.run(bindings_);
  // Capture the reference result before the bindings are reused.
  Tensor ref =
      bindings_.get(bindings_.getPlaceholderByNameSlow("ret"))->clone();

  // Partition EEP's copy across three devices using the load-balanced flow;
  // expect 3 functions in a single DAG, each with its own Save node.
  std::vector<DeviceInfo> devices = {
      {3072, "Interpreter"}, {3072, "Interpreter"}, {3072, "Interpreter"}};
  Partitioner myPartitioner(&EEP.getModule(), devices, true);
  CompilationContext cctx;
  auto dagList = myPartitioner.loadBalancedPartition(cctx);
  ASSERT_TRUE((bool)dagList);
  EXPECT_EQ(EEP.getModule().getFunctions().size(), 3);
  EXPECT_EQ(dagList->size(), 1);
  EXPECT_TRUE(checkSaveNode(EEP.getModule()));

  // Run the paritioned graph and compare the results.
  bindings_.clear();
  bindings_.allocate(EEP.getModule().getPlaceholders());
  EEP.compile(cctx);
  executeDAG(dagList->begin()->root.get(), EEP.getModule(), bindings_,
             {bindings_.getPlaceholderByNameSlow("input")}, {&in}, &EEP);
  Tensor test =
      bindings_.get(bindings_.getPlaceholderByNameSlow("ret"))->clone();
  // Partitioning must not change numerics at all (tolerance 0.0).
  EXPECT_TRUE(ref.isEqual(test, 0.0f));
  // Also check that the DAG survives a serialization round-trip.
  verifyDAGSerialization(*dagList, EEP.getModule(), bindings_, {"input"}, "ret",
                         devices, {&in}, ref);
}
1971 | |
1972 | /// This tests the pre-partitioned flow. |
1973 | TEST_F(PartitionerTest, PrePartitionedTest) { |
1974 | CompilationContext cctx; |
1975 | PrePartitionedConfig PPC; |
1976 | cctx.prepartitionedConfig = &PPC; |
1977 | Function *F0 = F_; |
1978 | Function *F1 = mod_.createFunction("main_1" ); |
1979 | Function *F2 = mod_.createFunction("main_2" ); |
1980 | PPC.funcs.push_back(F0); |
1981 | PPC.funcs.push_back(F1); |
1982 | PPC.funcs.push_back(F2); |
1983 | PPC.logicalIDs.resize(3); |
1984 | PPC.logicalIDs[0].push_back(0); |
1985 | PPC.logicalIDs[1].push_back(1); |
1986 | PPC.logicalIDs[2].push_back(1); |
1987 | PPC.logicalIDs[2].push_back(2); |
1988 | PPC.backendSpecificOpts.emplace_back( |
1989 | BackendSpecificOptions{{"opt0" , "val0" }, {"opt1" , "val1" }}); |
1990 | PPC.backendSpecificOpts.emplace_back( |
1991 | BackendSpecificOptions{{"opt2" , "val2" }}); |
1992 | PPC.backendSpecificOpts.emplace_back(BackendSpecificOptions{}); |
1993 | PPC.replicationCounts.push_back(3); |
1994 | PPC.replicationCounts.push_back(4); |
1995 | PPC.replicationCounts.push_back(1); |
1996 | PPC.backendHints.push_back({7, {"a" }}); |
1997 | PPC.backendHints.push_back({8, {"b" }}); |
1998 | PPC.backendHints.push_back({9, {"c" , "d" }}); |
1999 | |
2000 | auto *I0 = mod_.createPlaceholder(ElemKind::FloatTy, {5, 5}, "I0" , false); |
2001 | auto *I1 = mod_.createPlaceholder(ElemKind::FloatTy, {5, 5}, "I1" , false); |
2002 | auto *I2 = mod_.createPlaceholder(ElemKind::FloatTy, {5, 5}, "I1" , false); |
2003 | |
2004 | // Partition 0 is a MatMul and Save. |
2005 | MatMulNode *MM = F0->createMatMul("MM" , I0, I1); |
2006 | SaveNode *SMM = F0->createSave("SMM" , MM); |
2007 | |
2008 | // Partition 1 loads from the Partition 0 MatMul. |
2009 | AddNode *AN = F1->createAdd("AN" , SMM->getPlaceholder(), I2); |
2010 | SaveNode *SAN = F1->createSave("SAN" , AN); |
2011 | |
2012 | // Partition 2 loads from both Partition 0 and 1. |
2013 | MulNode *MN = F2->createMul("MN" , SMM->getPlaceholder(), I0); |
2014 | SubNode *SN = F2->createSub("SN" , SAN->getPlaceholder(), MN); |
2015 | SaveNode *finalSave = F2->createSave("finalSave" , SN); |
2016 | |
2017 | const runtime::DeviceInfo dev{/* 16GB: */ 0x400000000, "Interpreter" }; |
2018 | const std::vector<runtime::DeviceInfo> devices(3, dev); |
2019 | Partitioner partitioner(&mod_, devices); |
2020 | DAGListTy d; |
2021 | ASSIGN_VALUE_OR_FAIL_TEST(d, partitioner.setupPrepartitionedModule(cctx)); |
2022 | |
2023 | // Note: DAG should look like: F0 -> F1 |
2024 | // \ | |
2025 | // v v |
2026 | // F2 |
2027 | |
2028 | ASSERT_EQ(d.size(), 1); |
2029 | |
2030 | DAGNodePtr &root = d[0].root; |
2031 | EXPECT_EQ(root->module, &mod_); |
2032 | |
2033 | ASSERT_EQ(root->children.size(), 1); |
2034 | DAGNode *D0 = root->children[0]; |
2035 | ASSERT_EQ(D0->name, F0->getName()); |
2036 | ASSERT_EQ(D0->parents.size(), 1); |
2037 | EXPECT_EQ(D0->parents[0], root.get()); |
2038 | ASSERT_EQ(D0->logicalDevices.size(), 1); |
2039 | EXPECT_EQ(D0->logicalDevices[0], 0); |
2040 | EXPECT_EQ(D0->size, I0->getType()->getSizeInBytes() + |
2041 | I1->getType()->getSizeInBytes() + |
2042 | SMM->getPlaceholder()->getType()->getSizeInBytes()); |
2043 | EXPECT_EQ(D0->backendSpecificOpts.size(), 2); |
2044 | ASSERT_TRUE(D0->backendSpecificOpts.count("opt0" )); |
2045 | EXPECT_EQ(D0->backendSpecificOpts.at("opt0" ), "val0" ); |
2046 | ASSERT_TRUE(D0->backendSpecificOpts.count("opt1" )); |
2047 | EXPECT_EQ(D0->backendSpecificOpts.at("opt1" ), "val1" ); |
2048 | EXPECT_EQ(D0->replicationCount, 3); |
2049 | EXPECT_EQ(D0->backendHints.executionUnits, 7); |
2050 | ASSERT_EQ(D0->backendHints.SRAMPrioritization.size(), 1); |
2051 | EXPECT_EQ(D0->backendHints.SRAMPrioritization[0], "a" ); |
2052 | |
2053 | ASSERT_EQ(D0->children.size(), 2); |
2054 | DAGNode *D1 = (D0->children[0]->name == F1->getName()) ? D0->children[0] |
2055 | : D0->children[1]; |
2056 | ASSERT_EQ(D1->parents.size(), 1); |
2057 | EXPECT_EQ(D1->parents[0], D0); |
2058 | ASSERT_EQ(D1->name, F1->getName()); |
2059 | ASSERT_EQ(D1->logicalDevices.size(), 1); |
2060 | EXPECT_EQ(D1->logicalDevices[0], 1); |
2061 | EXPECT_EQ(D1->size, I2->getType()->getSizeInBytes() + |
2062 | SAN->getPlaceholder()->getType()->getSizeInBytes() + |
2063 | SMM->getPlaceholder()->getType()->getSizeInBytes()); |
2064 | EXPECT_EQ(D1->backendSpecificOpts.size(), 1); |
2065 | ASSERT_TRUE(D1->backendSpecificOpts.count("opt2" )); |
2066 | EXPECT_EQ(D1->backendSpecificOpts.at("opt2" ), "val2" ); |
2067 | EXPECT_EQ(D1->replicationCount, 4); |
2068 | EXPECT_EQ(D1->backendHints.executionUnits, 8); |
2069 | ASSERT_EQ(D1->backendHints.SRAMPrioritization.size(), 1); |
2070 | EXPECT_EQ(D1->backendHints.SRAMPrioritization[0], "b" ); |
2071 | |
2072 | DAGNode *D2 = (D1 == D0->children[0]) ? D0->children[1] : D0->children[0]; |
2073 | ASSERT_EQ(D2->name, F2->getName()); |
2074 | ASSERT_EQ(D1->children.size(), 1); |
2075 | EXPECT_EQ(D1->children[0], D2); |
2076 | ASSERT_EQ(D2->parents.size(), 2); |
2077 | EXPECT_TRUE(D2->parents[0] == D0 || D2->parents[1] == D0); |
2078 | EXPECT_TRUE(D2->parents[0] == D1 || D2->parents[1] == D1); |
2079 | EXPECT_NE(D2->parents[0], D2->parents[1]); |
2080 | ASSERT_EQ(D2->logicalDevices.size(), 2); |
2081 | EXPECT_TRUE(D2->logicalDevices[0] == 1 || D2->logicalDevices[0] == 2); |
2082 | EXPECT_TRUE(D2->logicalDevices[1] == 1 || D2->logicalDevices[1] == 2); |
2083 | EXPECT_NE(D2->logicalDevices[0], D2->logicalDevices[1]); |
2084 | EXPECT_EQ(D2->size, |
2085 | I0->getType()->getSizeInBytes() + |
2086 | SAN->getPlaceholder()->getType()->getSizeInBytes() + |
2087 | SMM->getPlaceholder()->getType()->getSizeInBytes() + |
2088 | finalSave->getPlaceholder()->getType()->getSizeInBytes()); |
2089 | EXPECT_EQ(D2->backendSpecificOpts.size(), 0); |
2090 | EXPECT_EQ(D2->replicationCount, 1); |
2091 | EXPECT_EQ(D2->backendHints.executionUnits, 9); |
2092 | ASSERT_EQ(D2->backendHints.SRAMPrioritization.size(), 2); |
2093 | EXPECT_EQ(D2->backendHints.SRAMPrioritization[0], "c" ); |
2094 | EXPECT_EQ(D2->backendHints.SRAMPrioritization[1], "d" ); |
2095 | } |
2096 | |
2097 | /// Test that constant folding (de)serialization works along with partitioning. |
2098 | TEST_F(PartitionerTest, RecordedConstantFolding) { |
2099 | ExecutionEngine EER, EEP; |
2100 | EEP.setSkipModuleStrip(true); |
2101 | constexpr float range = 2.0; |
2102 | std::vector<ExecutionEngine *> engines{&EER, &EEP}; |
2103 | // Since compiling modifies the module and partitioning modifies the function, |
2104 | // setup two EEs with identical functions for validation. |
2105 | for (auto EE : engines) { |
2106 | auto mod = &EE->getModule(); |
2107 | F_ = mod->createFunction("main" ); |
2108 | auto *input = |
2109 | mod->createPlaceholder(ElemKind::FloatTy, {1, 32}, "input" , false); |
2110 | auto *w1 = mod->createConstant(ElemKind::FloatTy, {32, 16}, "w1" ); |
2111 | auto *b1 = mod->createConstant(ElemKind::FloatTy, {16}, "b1" ); |
2112 | bindings_.allocate(input); |
2113 | w1->getHandle<>().randomize(-range, range, mod->getPRNG()); |
2114 | b1->getHandle<>().randomize(-range, range, mod->getPRNG()); |
2115 | |
2116 | // Initial FC. |
2117 | Node *I = F_->createFullyConnected("initial_fc" , input, w1, b1); |
2118 | I = F_->createSigmoid("initial_sigmoid" , I); |
2119 | |
2120 | // Left branch. Note that w2 and b2 will be constant folded. |
2121 | auto *w2 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w2" ); |
2122 | auto *b2 = mod->createConstant(ElemKind::FloatTy, {16}, "b2" ); |
2123 | w2->getHandle<>().randomize(-range, range, mod->getPRNG()); |
2124 | b2->getHandle<>().randomize(-range, range, mod->getPRNG()); |
2125 | auto *w2Clip = F_->createClip("clip_w2" , w2, -1, 1); |
2126 | auto *b2Clip = F_->createClip("clip_b2" , b2, -1, 1); |
2127 | Node *L = F_->createFullyConnected("left_fc1" , I, w2Clip, b2Clip); |
2128 | L = F_->createSigmoid("left_sigmoid1" , L); |
2129 | auto *w3 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w3" ); |
2130 | auto *b3 = mod->createConstant(ElemKind::FloatTy, {8}, "b3" ); |
2131 | w3->getHandle<>().randomize(-range, range, mod->getPRNG()); |
2132 | b3->getHandle<>().randomize(-range, range, mod->getPRNG()); |
2133 | L = F_->createFullyConnected("left_fc2" , L, w3, b3); |
2134 | L = F_->createSigmoid("left_sigmoid2" , L); |
2135 | |
2136 | // Right branch. Note that w4 will be constant folded. |
2137 | auto *w4 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w4" ); |
2138 | auto *b4 = mod->createConstant(ElemKind::FloatTy, {16}, "b4" ); |
2139 | w4->getHandle<>().randomize(-range, range, mod->getPRNG()); |
2140 | b4->getHandle<>().randomize(-range, range, mod->getPRNG()); |
2141 | auto *w4Sig = F_->createSigmoid("w4_sig" , w4); |
2142 | Node *R = F_->createFullyConnected("right_fc1" , I, w4Sig, b4); |
2143 | R = F_->createSigmoid("right_sigmoid1" , R); |
2144 | auto *w5 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w5" ); |
2145 | auto *b5 = mod->createConstant(ElemKind::FloatTy, {8}, "b5" ); |
2146 | w5->getHandle<>().randomize(-range, range, mod->getPRNG()); |
2147 | b5->getHandle<>().randomize(-range, range, mod->getPRNG()); |
2148 | R = F_->createFullyConnected("right_fc2" , R, w5, b5); |
2149 | R = F_->createSigmoid("right_sigmoid2" , R); |
2150 | |
2151 | // Join branches. |
2152 | auto *mul = F_->createMul("mul" , L, R); |
2153 | F_->createSave("ret" , mul); |
2154 | } |
2155 | |
2156 | // Infer using the un-partitioned graph. |
2157 | Tensor in(ElemKind::FloatTy, {1, 32}); |
2158 | in.getHandle<>().randomize(-range, range, EER.getModule().getPRNG()); |
2159 | |
2160 | EER.compile(CompilationMode::Infer); |
2161 | bindings_.clear(); |
2162 | bindings_.allocate(EER.getModule().getPlaceholders()); |
2163 | updateInputPlaceholders(bindings_, |
2164 | {bindings_.getPlaceholderByNameSlow("input" )}, {&in}); |
2165 | EER.run(bindings_); |
2166 | Tensor ref = |
2167 | bindings_.get(bindings_.getPlaceholderByNameSlow("ret" ))->clone(); |
2168 | |
2169 | // Now try with partitioning, and partitioning + constant fold recording. |
2170 | auto &modP = EEP.getModule(); |
2171 | |
2172 | CompilationContext cctx; |
2173 | ASSERT_EQ(modP.getFunctions().size(), 1); |
2174 | Function *origF = *modP.getFunctions().begin(); |
2175 | ConstantFoldingRecordMap record = constantFoldAndRecord(origF, cctx); |
2176 | runDCEPass(origF, cctx); |
2177 | // Expect 3 Constants were folded: w2, b2, and w4 from above. |
2178 | ASSERT_EQ(record.size(), 3); |
2179 | |
2180 | const DeviceInfo devI{3072, "Interpreter" }; |
2181 | std::vector<DeviceInfo> devices = {devI, devI, devI}; |
2182 | Partitioner myPartitioner(&modP, devices, /* optimized */ true); |
2183 | EXPECT_TRUE(checkSaveNode(modP)); |
2184 | |
2185 | DAGListTy dagList; |
2186 | ASSIGN_VALUE_OR_FAIL_TEST(dagList, myPartitioner.loadBalancedPartition(cctx)); |
2187 | |
2188 | ASSERT_EQ(dagList.size(), 1); |
2189 | const auto &dag = *dagList.begin(); |
2190 | EXPECT_EQ(dag.nodes.size(), 3); |
2191 | |
2192 | // Verify that we serialize and deserialize the DAG correctly including with |
2193 | // the constant folding record, and that results are bitwise equal. |
2194 | verifyDAGSerialization(dagList, modP, bindings_, {"input" }, "ret" , devices, |
2195 | {&in}, ref, &record); |
2196 | |
2197 | // Now run the original partitioned model and verify it also is bitwise equal. |
2198 | bindings_.clear(); |
2199 | bindings_.allocate(modP.getPlaceholders()); |
2200 | EEP.compile(cctx); |
2201 | |
2202 | executeDAG(dagList.begin()->root.get(), modP, bindings_, |
2203 | {bindings_.getPlaceholderByNameSlow("input" )}, {&in}, &EEP); |
2204 | Tensor test = |
2205 | bindings_.get(bindings_.getPlaceholderByNameSlow("ret" ))->clone(); |
2206 | EXPECT_TRUE(ref.isEqual(test, 0.0f)); |
2207 | } |
2208 | |
2209 | /// This test verifies that resourceCount is being checked correctly, we set the |
2210 | /// resourceCount to 1 and expect an error. |
2211 | TEST_F(PartitionerTest, resourceCountValidationTest) { |
2212 | auto *input1 = |
2213 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1" , false); |
2214 | auto *input2 = |
2215 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input2" , false); |
2216 | auto *input3 = |
2217 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input3" , false); |
2218 | auto *add1 = F_->createAdd("add1" , input1, input2); |
2219 | auto *add2 = F_->createAdd("add2" , add1, input3); |
2220 | auto *sub1 = F_->createSub("sub1" , add1, add2); |
2221 | F_->createSave("save" , sub1); |
2222 | |
2223 | std::vector<DeviceInfo> devices = {{3072, "Interpreter" , {}}, |
2224 | {3072, "Interpreter" , {}}}; |
2225 | |
2226 | devices[0].inputCountMax = 1; |
2227 | devices[1].inputCountMax = 1; |
2228 | // User-defined partition: p1->p2, p2->p1. |
2229 | PartitionConfig partitionConfig; |
2230 | partitionConfig.funcName = "main" ; |
2231 | partitionConfig.numOfPartitions = 2; |
2232 | BackendHints bh1, bh2; |
2233 | bh1.executionUnits = 2; |
2234 | bh2.executionUnits = 3; |
2235 | partitionConfig.backendHints = {bh1, bh2}; |
2236 | partitionConfig.backendNames = {"Interpreter" , "Interpreter" }; |
2237 | partitionConfig.partitionNames = {"p1" , "p2" }; |
2238 | partitionConfig.nodeToPartition = {{"add2" , 0}}; |
2239 | auto partitioner = Partitioner(&mod_, devices, false, partitionConfig); |
2240 | CompilationContext cctx; |
2241 | auto dagList = partitioner.partition(cctx); |
2242 | EXPECT_TRUE(ERR_TO_BOOL(dagList.takeError())); |
2243 | } |
2244 | |
2245 | /// Tests that the given net is assigned and duplicated on the given logical |
2246 | /// devices. |
2247 | TEST_F(PartitionerTest, saturateKDevicesTest) { |
2248 | createSimpleModule(mod_); |
2249 | std::vector<DeviceInfo> devices = {{2048, "Interpreter" , {}}, |
2250 | {2048, "Interpreter" , {}}, |
2251 | {2048, "Interpreter" , {}}}; |
2252 | auto partitioner = Partitioner(&mod_, devices, false); |
2253 | // Partitioner should create DAG without partitioning, duplicate it and |
2254 | // assign to the given logical devices. |
2255 | DAGListTy dagList; |
2256 | CompilationContext cctx; |
2257 | cctx.saturateHost = true; |
2258 | cctx.saturateKDevices = 2; |
2259 | ASSIGN_VALUE_OR_FAIL_TEST(dagList, partitioner.partition(cctx)); |
2260 | EXPECT_EQ(dagList.size(), 1); |
2261 | |
2262 | int numOfInterpreterBackends = 0; |
2263 | for (auto &dag : dagList) { |
2264 | for (auto &node : dag.nodes) { |
2265 | // Verify the node is assigned to K devices. |
2266 | EXPECT_EQ(node->logicalDevices.size(), cctx.saturateKDevices); |
2267 | |
2268 | if (node->backendName == "Interpreter" ) { |
2269 | numOfInterpreterBackends++; |
2270 | } |
2271 | } |
2272 | } |
2273 | EXPECT_EQ(numOfInterpreterBackends, 1); |
2274 | } |
2275 | |