/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "glow/Partitioner/Partitioner.h"
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Exporter/ONNXModelWriter.h"
#include "glow/Graph/Graph.h"
#include "glow/Importer/ONNXModelLoader.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
#include "glow/Partitioner/PartitionerUtils.h"

#include "llvm/Support/FileSystem.h"

#include "gtest/gtest.h"

using namespace glow;

class PartitionerTest : public ::testing::Test {
public:
  PartitionerTest() : F_(mod_.createFunction("main")) {}

protected:
  Module mod_;
  Function *F_;
  PlaceholderBindings bindings_;
};

/// Execute a graph of functions based on the given DAG.
static void executeDAG(DAGNode *G, Module &mod, PlaceholderBindings &bindings,
                       llvm::ArrayRef<Placeholder *> vars,
                       llvm::ArrayRef<Tensor *> inputs, ExecutionEngine *EE) {
  std::unordered_map<std::string, Function *> name2func;

  for (auto *F : mod.getFunctions()) {
    name2func[F->getName().str()] = F;
  }

  std::vector<DAGNode *> exeList;
  int endPt = 0;
  int curPt = 0;
  // The first node is always the dummy node.
  exeList.push_back(G);
  endPt++;
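  // Walk the DAG in BFS (level) order: run each partition as it is dequeued,
  // then enqueue its children.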
  while (curPt < endPt) {
    DAGNode *dag = exeList.at(curPt);
    // The root of the DAG is always a dummy node, so only run the non-root
    // partitions.
    if (curPt > 0) {
      updateInputPlaceholders(bindings, vars, inputs);
      EE->run(bindings, dag->name);
    }
    for (int i = 0, e = dag->children.size(); i < e; i++) {
      exeList.push_back(dag->children.at(i));
      endPt++;
    }
    curPt++;
  }
}

/// \returns true if all the functions have a valid save node format, i.e.
/// there is no Save->Save pattern.
static bool checkSaveNode(Module &mod) {
  for (auto F : mod.getFunctions()) {
    for (const Node &N : F->getNodes()) {
      if (N.getKind() != Kinded::Kind::SaveNodeKind) {
        continue;
      }
      auto *ph = llvm::dyn_cast<Placeholder>(N.getNthInput(0).getNode());
      if (!ph) {
        continue;
      }
      // If this SaveNode uses the output of another SaveNode, it is an
      // illegal pattern.
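      // e.g. another SaveNode writes into placeholder P, and this SaveNode
      // saves P again (Save->Save).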
      for (auto &user : ph->getUsers()) {
        if (user.getUser() == &N || !llvm::dyn_cast<SaveNode>(user.getUser())) {
          continue;
        }
        return false;
      }
    }
  }
  return true;
}

/// Serializes \p dagList and re-loads it. Compares the structure of the DAGs
/// before/after and verifies results are still the same given \p devices.
static void verifyDAGSerialization(
    DAGListTy &dagList, Module &origMod, PlaceholderBindings &bindings,
    llvm::ArrayRef<llvm::StringRef> inputNames, llvm::StringRef resultName,
    const std::vector<DeviceInfo> &devices, llvm::ArrayRef<Tensor *> inputs,
    const Tensor &ref, ConstantFoldingRecordMap *constFoldRecord = nullptr) {
  llvm::SmallString<64> path;
  auto tempFileRes =
      llvm::sys::fs::createTemporaryFile("exporter", "output.onnx", path);
  (void)tempFileRes;
  assert(tempFileRes.value() == 0);
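  // createTemporaryFile returns a std::error_code; the (void) cast avoids an
  // unused-variable warning in release builds, where assert() compiles out.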

  std::string outputFilename(path.c_str());
  std::cout << "Writing to file: " << outputFilename << std::endl;
  {
    // Note: do not include Constant data when we write out; we will reuse the
    // Module so we don't need to save it.
    Error err = Error::empty();
    llvm::StringMap<std::string> extraMetadataProps;
    ONNXModelWriter onnxWR(outputFilename, dagList, 7, 9, &err,
                           /* textMode */ false, /* zipMode */ false,
                           /* includeConstantData */ false, extraMetadataProps,
                           constFoldRecord ? *constFoldRecord
                                           : ConstantFoldingRecordMap());

    if (ERR_TO_BOOL(std::move(err))) {
      llvm::errs() << "ONNXModelWriter failed to write model: "
                   << outputFilename << "\n";
      llvm::sys::fs::remove(outputFilename);
      FAIL() << "Error exporting DAG.";
    }
  }

  // Create a new EE using the same module. Note that we assume devices are
  // homogeneous here.
  ExecutionEngine loadedEE(devices[0].backendName, devices[0].availableMemory,
                           /* ignoreUserDeviceConfig */ false,
                           /* numDevices */ devices.size());
  // Clone the original module into the one in the EE; we're going to
  // deserialize the DAG into it as if we're reusing the same Module.
  origMod.clone(&loadedEE.getModule());
  Module &loadedMod = loadedEE.getModule();
  CompilationContext loadedCctx;
  runtime::PrePartitionedConfig PPC;
  loadedCctx.prepartitionedConfig = &PPC;
  {
    // Clear the Nodes out of the Functions. We will reuse the empty
    // Functions.
    loadedMod.clearFunctions();
    // If we have a constant folding record then delete those Constants too,
    // since we're going to recreate them. Also delete the const fold
    // Functions.
    if (constFoldRecord) {
      std::unordered_set<Function *> funsToDelete;
      for (auto &pair : *constFoldRecord) {
        Function *origF = pair.second->getParent();
        funsToDelete.insert(origF);
        Constant *C = loadedMod.getConstantByName(pair.first->getName());
        ASSERT_TRUE(C);
        loadedMod.eraseConstant(C);
      }
      for (Function *origF : funsToDelete) {
        Function *loadedConstFoldF = loadedMod.getFunction(origF->getName());
        ASSERT_TRUE(loadedConstFoldF);
        loadedMod.eraseFunction(loadedConstFoldF);
        origMod.eraseFunction(origF);
      }
    }
    Error err = Error::empty();
    ONNXModelLoader onnxLD(
        outputFilename, {}, {}, loadedMod, "main", &PPC, &err,
        /* zipMode */ false, &loadedCctx.backendOpts.backendSpecificNodeInfo,
        /* loadIntoExistingModule */ true);
    if (ERR_TO_BOOL(std::move(err))) {
      llvm::errs() << "ONNXModelLoader failed to load model: "
                   << outputFilename << "\n";
      llvm::sys::fs::remove(outputFilename);
      FAIL() << "Error importing DAG.";
    }
  }
  llvm::sys::fs::remove(outputFilename);

  // Now verify the DAG is the same, including all static properties of the
  // DAG.
  Partitioner loadedPartitioner(&loadedMod, devices, /* optimized */ true);
  DAGListTy loadedDagList;
  ASSIGN_VALUE_OR_FAIL_TEST(loadedDagList,
                            loadedPartitioner.partition(loadedCctx));

  // Verify that the two DAGs are the same.
  ASSERT_EQ(dagList.size(), loadedDagList.size());
  ASSERT_EQ(dagList.size(), 1);
  DAG &origDAG = dagList.front();
  DAG &loadedDAG = loadedDagList.front();
  EXPECT_EQ(origDAG.root->name, loadedDAG.root->name);

  // Map from orig DAGNodes to loaded DAGNodes.
  std::unordered_map<const DAGNode *, const DAGNode *> origToLoaded;
  for (DAGNodePtr &origN : origDAG.nodes) {
    for (DAGNodePtr &loadedN : loadedDAG.nodes) {
      if (origN->name != loadedN->name) {
        continue;
      }
      origToLoaded[origN.get()] = loadedN.get();
      break;
    }
  }
  ASSERT_EQ(origDAG.nodes.size(), origToLoaded.size());
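  // The root is a dummy node with no corresponding Function, so pair the two
  // roots directly instead of matching by name.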
  origToLoaded[origDAG.root.get()] = loadedDAG.root.get();

  for (const auto &nPair : origToLoaded) {
    const DAGNode *origN = nPair.first;
    const DAGNode *loadedN = nPair.second;
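// Compare one member of the original and loaded DAGNodes for equality.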
#define CHECK_DAG_EQ(MEM_NAME) EXPECT_EQ(origN->MEM_NAME, loadedN->MEM_NAME);
    CHECK_DAG_EQ(name);
    CHECK_DAG_EQ(size);
    CHECK_DAG_EQ(backendName);
    CHECK_DAG_EQ(backendHints.executionUnits);
    CHECK_DAG_EQ(logicalDevices.size());
    for (size_t i = 0, e = origN->logicalDevices.size(); i < e; i++) {
      EXPECT_EQ(origN->logicalDevices[i], loadedN->logicalDevices[i]);
    }
    CHECK_DAG_EQ(replicationCount);
    EXPECT_TRUE(std::equal(origN->backendSpecificOpts.begin(),
                           origN->backendSpecificOpts.end(),
                           loadedN->backendSpecificOpts.begin()));
#undef CHECK_DAG_EQ

    for (const DAGNode *origChild : origN->children) {
      auto it = std::find_if(loadedN->children.begin(), loadedN->children.end(),
                             [=](auto *loadedChild) {
                               return loadedChild->name == origChild->name;
                             });
      EXPECT_NE(it, std::end(loadedN->children));
    }
    for (const DAGNode *origParent : origN->parents) {
      auto it = std::find_if(loadedN->parents.begin(), loadedN->parents.end(),
                             [=](auto *loadedParent) {
                               return loadedParent->name == origParent->name;
                             });
      EXPECT_NE(it, std::end(loadedN->parents));
    }

    // Skip checking the root as there's no Function for it.
    if (origN == origDAG.root.get()) {
      continue;
    }
    Function *origF = origMod.getFunction(origN->name);
    Function *loadedF = loadedMod.getFunction(loadedN->name);
    ASSERT_TRUE(origF);
    ASSERT_TRUE(loadedF);
    EXPECT_EQ(origF->toString(), loadedF->toString());
  }
  EXPECT_EQ(origMod.toString(), loadedMod.toString());

  // Now reset bindings and run, checking results are bitwise equal from before
  // and after serialization. Note that we still use the same PPC -- it will
  // re-partition/setup the same DAG inside compilation.
  loadedEE.compile(loadedCctx);
  bindings.clear();
  bindings.allocate(loadedMod.getPlaceholders());
  std::vector<Placeholder *> inPHs;
  for (const llvm::StringRef &inName : inputNames) {
    inPHs.push_back(bindings.getPlaceholderByNameSlow(inName));
  }
  updateInputPlaceholders(bindings, inPHs, inputs);
  loadedEE.run(bindings);
  Tensor test =
      bindings.get(bindings.getPlaceholderByNameSlow(resultName))->clone();
  EXPECT_TRUE(ref.isEqual(test, 0.0f));
}

/// Tests a model where, after BFS, the total memory consumption of the nodes
/// in each level does not exceed the device memory constraints.
TEST_F(PartitionerTest, Basic1) {
  ExecutionEngine EER, EEP;
  EEP.setSkipModuleStrip(true);
  constexpr float range = 2.0;
  std::vector<ExecutionEngine *> engines{&EER, &EEP};
  // Since compiling modifies the module and partitioning modifies the
  // function, set up two EEs with identical functions for validation.
  for (auto EE : engines) {
    auto mod = &EE->getModule();
    F_ = mod->createFunction("main");
    auto *input =
        mod->createPlaceholder(ElemKind::FloatTy, {1, 32}, "input", false);
    auto *w1 = mod->createConstant(ElemKind::FloatTy, {32, 16}, "w1");
    auto *b1 = mod->createConstant(ElemKind::FloatTy, {16}, "b1");
    bindings_.allocate(input);
    w1->getHandle<>().randomize(-range, range, mod->getPRNG());
    b1->getHandle<>().randomize(-range, range, mod->getPRNG());

    // Initial FC.
    Node *I = F_->createFullyConnected("initial_fc", input, w1, b1);
    I = F_->createSigmoid("initial_sigmoid", I);

    // Left branch.
    auto *w2 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w2");
    auto *b2 = mod->createConstant(ElemKind::FloatTy, {16}, "b2");
    w2->getHandle<>().randomize(-range, range, mod->getPRNG());
    b2->getHandle<>().randomize(-range, range, mod->getPRNG());
    Node *L = F_->createFullyConnected("left_fc1", I, w2, b2);
    L = F_->createSigmoid("left_sigmoid1", L);
    auto *w3 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w3");
    auto *b3 = mod->createConstant(ElemKind::FloatTy, {8}, "b3");
    w3->getHandle<>().randomize(-range, range, mod->getPRNG());
    b3->getHandle<>().randomize(-range, range, mod->getPRNG());
    L = F_->createFullyConnected("left_fc2", L, w3, b3);
    L = F_->createSigmoid("left_sigmoid2", L);

    // Right branch.
    auto *w4 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w4");
    auto *b4 = mod->createConstant(ElemKind::FloatTy, {16}, "b4");
    w4->getHandle<>().randomize(-range, range, mod->getPRNG());
    b4->getHandle<>().randomize(-range, range, mod->getPRNG());
    Node *R = F_->createFullyConnected("right_fc1", I, w4, b4);
    R = F_->createSigmoid("right_sigmoid1", R);
    auto *w5 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w5");
    auto *b5 = mod->createConstant(ElemKind::FloatTy, {8}, "b5");
    w5->getHandle<>().randomize(-range, range, mod->getPRNG());
    b5->getHandle<>().randomize(-range, range, mod->getPRNG());
    R = F_->createFullyConnected("right_fc2", R, w5, b5);
    R = F_->createSigmoid("right_sigmoid2", R);

    // Join branches.
    auto *mul = F_->createMul("mul", L, R);
    F_->createSave("ret", mul);
  }

  // Infer using the un-partitioned graph.
  Tensor in(ElemKind::FloatTy, {1, 32});
  in.getHandle<>().randomize(-range, range, EER.getModule().getPRNG());

  EER.compile(CompilationMode::Infer);
  bindings_.clear();
  bindings_.allocate(EER.getModule().getPlaceholders());
  updateInputPlaceholders(bindings_,
                          {bindings_.getPlaceholderByNameSlow("input")}, {&in});
  EER.run(bindings_);
  Tensor ref =
      bindings_.get(bindings_.getPlaceholderByNameSlow("ret"))->clone();

  std::vector<DeviceInfo> devices = {
      {3072, "Interpreter"}, {3072, "Interpreter"}, {3072, "Interpreter"}};
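  // 3072 bytes per device cannot hold the whole graph, so the partitioner
  // must split it; three partitions are expected below.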
  Partitioner myPartitioner(&EEP.getModule(), devices, true);
  CompilationContext cctx;
  auto dagList = myPartitioner.partition(cctx);
  ASSERT_TRUE((bool)dagList);
  EXPECT_EQ(EEP.getModule().getFunctions().size(), 3);
  EXPECT_EQ(dagList->size(), 1);
  EXPECT_TRUE(checkSaveNode(EEP.getModule()));

  // Run the partitioned graph and compare the results.
  bindings_.clear();
  bindings_.allocate(EEP.getModule().getPlaceholders());
  EEP.compile(cctx);
  executeDAG(dagList->begin()->root.get(), EEP.getModule(), bindings_,
             {bindings_.getPlaceholderByNameSlow("input")}, {&in}, &EEP);
  Tensor test =
      bindings_.get(bindings_.getPlaceholderByNameSlow("ret"))->clone();
  EXPECT_TRUE(ref.isEqual(test, 0.0f));
  verifyDAGSerialization(dagList.get(), EEP.getModule(), bindings_, {"input"},
                         "ret", devices, {&in}, ref);
}

/// Tests a model where, after BFS, there is one level whose total node memory
/// consumption exceeds the device memory constraints.
TEST_F(PartitionerTest, Basic2) {

  ExecutionEngine EER, EEP;
  EEP.setSkipModuleStrip(true);
  constexpr float range = 2.0;
  std::vector<ExecutionEngine *> engines{&EER, &EEP};
  for (auto EE : engines) {
    auto mod = &EE->getModule();
    F_ = mod->createFunction("main");
    auto *input =
        mod->createPlaceholder(ElemKind::FloatTy, {1, 16}, "input", false);
    auto *input1 =
        mod->createPlaceholder(ElemKind::FloatTy, {1, 16}, "input1", false);
    bindings_.allocate(input);
    bindings_.allocate(input1);
    // Left branch.
    auto *w2 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w2");
    auto *b2 = mod->createConstant(ElemKind::FloatTy, {16}, "b2");
    w2->getHandle<>().randomize(-range, range, mod->getPRNG());
    b2->getHandle<>().randomize(-range, range, mod->getPRNG());
    Node *L = F_->createFullyConnected("left_fc1", input, w2, b2);
    L = F_->createSigmoid("left_sigmoid1", L);
    auto *w3 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w3");
    auto *b3 = mod->createConstant(ElemKind::FloatTy, {8}, "b3");
    w3->getHandle<>().randomize(-range, range, mod->getPRNG());
    b3->getHandle<>().randomize(-range, range, mod->getPRNG());
    L = F_->createFullyConnected("left_fc2", L, w3, b3);
    L = F_->createSigmoid("left_sigmoid2", L);

    // Right branch.
    auto *w4 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w4");
    auto *b4 = mod->createConstant(ElemKind::FloatTy, {16}, "b4");
    w4->getHandle<>().randomize(-range, range, mod->getPRNG());
    b4->getHandle<>().randomize(-range, range, mod->getPRNG());
    Node *R = F_->createFullyConnected("right_fc1", input1, w4, b4);
    R = F_->createSigmoid("right_sigmoid1", R);
    auto *w5 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w5");
    auto *b5 = mod->createConstant(ElemKind::FloatTy, {8}, "b5");
    w5->getHandle<>().randomize(-range, range, mod->getPRNG());
    b5->getHandle<>().randomize(-range, range, mod->getPRNG());
    R = F_->createFullyConnected("right_fc2", R, w5, b5);
    R = F_->createSigmoid("right_sigmoid2", R);

    // Join branches.
    auto *mul = F_->createMul("mul", L, R);
    F_->createSave("ret", mul);
  }

  // Infer using the un-partitioned graph.
  Tensor in(ElemKind::FloatTy, {1, 16});
  in.getHandle<>().randomize(-range, range, EER.getModule().getPRNG());
  EER.compile(CompilationMode::Infer);
  bindings_.clear();
  bindings_.allocate(EER.getModule().getPlaceholders());
  updateInputPlaceholders(bindings_,
                          {bindings_.getPlaceholderByNameSlow("input"),
                           bindings_.getPlaceholderByNameSlow("input1")},
                          {&in, &in});
  EER.run(bindings_);
  Tensor ref =
      bindings_.get(bindings_.getPlaceholderByNameSlow("ret"))->clone();

  std::vector<DeviceInfo> devices = {{2048, "Interpreter"},
                                     {2048, "Interpreter"},
                                     {2048, "Interpreter"},
                                     {2048, "Interpreter"}};
  Partitioner myPartitioner(&EEP.getModule(), devices);
  CompilationContext cctx;
  cctx.saturateHost = true;
  runtime::DAGListTy dagList;
  ASSIGN_VALUE_OR_FAIL_TEST(dagList, myPartitioner.partition(cctx));
  EXPECT_EQ(EEP.getModule().getFunctions().size(), 2);
  EXPECT_EQ(dagList.size(), 1);
  ASSERT_TRUE(checkSaveNode(EEP.getModule()));

  for (auto &dag : dagList) {
    for (auto &node : dag.nodes) {
      // Since saturateHost is set to true, there should be two copies of each
      // partition in this case.
      EXPECT_EQ(node->logicalDevices.size(), 2);
    }
  }

  // Run the partitioned graph and compare the results.
  bindings_.clear();
  bindings_.allocate(EEP.getModule().getPlaceholders());
  EEP.compile(cctx);
  updateInputPlaceholders(bindings_,
                          {bindings_.getPlaceholderByNameSlow("input"),
                           bindings_.getPlaceholderByNameSlow("input1")},
                          {&in, &in});
  executeDAG(dagList.begin()->root.get(), EEP.getModule(), bindings_,
             {bindings_.getPlaceholderByNameSlow("input")}, {&in}, &EEP);
  Tensor test =
      bindings_.get(bindings_.getPlaceholderByNameSlow("ret"))->clone();
  ASSERT_TRUE(ref.isEqual(test, 0.0f));
  verifyDAGSerialization(dagList, EEP.getModule(), bindings_,
                         {"input", "input1"}, "ret", devices, {&in, &in}, ref);
}

/// Tests the error path: if the number of partitions is larger than the given
/// number of devices, an error is reported.
TEST_F(PartitionerTest, Error1) {
  ExecutionEngine EER, EEP;
  constexpr float range = 2.0;
  std::vector<ExecutionEngine *> engines{&EER, &EEP};
  for (auto EE : engines) {
    auto mod = &EE->getModule();
    F_ = mod->createFunction("main");
    auto *input =
        mod->createPlaceholder(ElemKind::FloatTy, {1, 16}, "input", false);
    auto *input1 =
        mod->createPlaceholder(ElemKind::FloatTy, {1, 16}, "input1", false);
    bindings_.allocate(input);
    bindings_.allocate(input1);
    // Left branch.
    auto *w2 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w2");
    auto *b2 = mod->createConstant(ElemKind::FloatTy, {16}, "b2");
    w2->getHandle<>().randomize(-range, range, mod->getPRNG());
    b2->getHandle<>().randomize(-range, range, mod->getPRNG());
    Node *L = F_->createFullyConnected("left_fc1", input, w2, b2);
    L = F_->createSigmoid("left_sigmoid1", L);
    auto *w3 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w3");
    auto *b3 = mod->createConstant(ElemKind::FloatTy, {8}, "b3");
    w3->getHandle<>().randomize(-range, range, mod->getPRNG());
    b3->getHandle<>().randomize(-range, range, mod->getPRNG());
    L = F_->createFullyConnected("left_fc2", L, w3, b3);
    L = F_->createSigmoid("left_sigmoid2", L);

    // Right branch.
    auto *w4 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w4");
    auto *b4 = mod->createConstant(ElemKind::FloatTy, {16}, "b4");
    w4->getHandle<>().randomize(-range, range, mod->getPRNG());
    b4->getHandle<>().randomize(-range, range, mod->getPRNG());
    Node *R = F_->createFullyConnected("right_fc1", input1, w4, b4);
    R = F_->createSigmoid("right_sigmoid1", R);
    auto *w5 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w5");
    auto *b5 = mod->createConstant(ElemKind::FloatTy, {8}, "b5");
    w5->getHandle<>().randomize(-range, range, mod->getPRNG());
    b5->getHandle<>().randomize(-range, range, mod->getPRNG());
    R = F_->createFullyConnected("right_fc2", R, w5, b5);
    R = F_->createSigmoid("right_sigmoid2", R);

    // Join branches.
    auto *mul = F_->createMul("mul", L, R);
    F_->createSave("ret", mul);
  }

  // Infer using the un-partitioned graph.
  Tensor in(ElemKind::FloatTy, {1, 16});
  in.getHandle<>().randomize(-range, range, EER.getModule().getPRNG());

  EER.compile(CompilationMode::Infer);
  bindings_.clear();
  bindings_.allocate(EER.getModule().getPlaceholders());
  updateInputPlaceholders(bindings_,
                          {bindings_.getPlaceholderByNameSlow("input"),
                           bindings_.getPlaceholderByNameSlow("input1")},
                          {&in, &in});
  EER.run(bindings_);

  std::vector<DeviceInfo> devices = {{2048, "Interpreter"}};
  Partitioner myPartitioner(&EEP.getModule(), devices);
  CompilationContext cctx;
  auto dagList = myPartitioner.partition(cctx);
  EXPECT_TRUE(ERR_TO_BOOL(dagList.takeError()));
}

/// Tests the roofline cost model computed from compute, memory and
/// communication costs.
TEST_F(PartitionerTest, Basic1Roofline) {
  ExecutionEngine EEP;
  constexpr float range = 2.0;

  auto mod = &EEP.getModule();
  F_ = mod->createFunction("main");
  auto *input =
      mod->createPlaceholder(ElemKind::FloatTy, {1, 32}, "input", false);
  auto *w1 = mod->createConstant(ElemKind::FloatTy, {32, 16}, "w1");
  auto *b1 = mod->createConstant(ElemKind::FloatTy, {16}, "b1");
  bindings_.allocate(input);
  w1->getHandle<>().randomize(-range, range, mod->getPRNG());
  b1->getHandle<>().randomize(-range, range, mod->getPRNG());

  // Initial FC.
  Node *I = F_->createFullyConnected("initial_fc", input, w1, b1);
  I = F_->createSigmoid("initial_sigmoid", I);

  // Left branch.
  auto *w2 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w2");
  auto *b2 = mod->createConstant(ElemKind::FloatTy, {16}, "b2");
  w2->getHandle<>().randomize(-range, range, mod->getPRNG());
  b2->getHandle<>().randomize(-range, range, mod->getPRNG());
  Node *L = F_->createFullyConnected("left_fc1", I, w2, b2);
  L = F_->createSigmoid("left_sigmoid1", L);
  auto *w3 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w3");
  auto *b3 = mod->createConstant(ElemKind::FloatTy, {8}, "b3");
  w3->getHandle<>().randomize(-range, range, mod->getPRNG());
  b3->getHandle<>().randomize(-range, range, mod->getPRNG());
  L = F_->createFullyConnected("left_fc2", L, w3, b3);
  L = F_->createSigmoid("left_sigmoid2", L);

  // Right branch.
  auto *w4 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w4");
  auto *b4 = mod->createConstant(ElemKind::FloatTy, {16}, "b4");
  w4->getHandle<>().randomize(-range, range, mod->getPRNG());
  b4->getHandle<>().randomize(-range, range, mod->getPRNG());
  Node *R = F_->createFullyConnected("right_fc1", I, w4, b4);
  R = F_->createSigmoid("right_sigmoid1", R);
  auto *w5 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w5");
  auto *b5 = mod->createConstant(ElemKind::FloatTy, {8}, "b5");
  w5->getHandle<>().randomize(-range, range, mod->getPRNG());
  b5->getHandle<>().randomize(-range, range, mod->getPRNG());
  R = F_->createFullyConnected("right_fc2", R, w5, b5);
  R = F_->createSigmoid("right_sigmoid2", R);

  // Join branches.
  auto *mul = F_->createMul("mul", L, R);
  F_->createSave("ret", mul);

  // Since the partitioner looks at all nodes in the function after
  // optimization and lowering, do the same here to get the same list of
  // nodes.
  std::unique_ptr<Backend> backend(createBackend(EEP.getBackendName()));
  CompilationContext cctx;
  EXIT_ON_ERR(optimizeFunctionBeforeLowering(
      EEP.getModule().getFunction("main"), cctx));
  EXIT_ON_ERR(::glow::optimizeFunction(EEP.getModule().getFunction("main"),
                                       *backend, cctx));
  std::unordered_map<Node *, std::string> nodeNamesMap;
  for (auto &node : EEP.getModule().getFunction("main")->getNodes()) {
    nodeNamesMap[&node] = node.getName().str();
  }

  // Check compute costs.
  std::unordered_map<std::string, float> expectedComputeTime{
      {"initial_sigmoid", 128}, {"left_sigmoid2", 64},
      {"right_sigmoid1", 128},  {"mul", 96},
      {"ret_save", 0},          {"initial_fc", 21760},
      {"left_fc1", 10240},      {"left_fc2", 5120},
      {"left_sigmoid1", 128},   {"right_fc1", 10240},
      {"right_fc2", 5120},      {"right_sigmoid2", 64},
  };

  BackendInfo backendInfo;
  backendInfo.sramCapacity = 100;
  backendInfo.peakCompute = 10;
  backendInfo.peakDramBw = 0.1;
  backendInfo.peakSramBw = 1;
  backendInfo.peakPCIeBw = 0.05;
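  // Mock roofline parameters (peak compute and SRAM/DRAM/PCIe bandwidths)
  // consumed by getNodeComputeTime() to derive the expected times above.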
  for (auto const &p : nodeNamesMap) {
    auto *N = p.first;
    EXPECT_EQ(getNodeComputeTime(N, backendInfo),
              expectedComputeTime[p.second]);
  }
}

TEST_F(PartitionerTest, SelectRepFunc) {
  auto *inA = mod_.createConstant(ElemKind::FloatTy, {2}, "A");
  auto *inB = mod_.createConstant(ElemKind::FloatTy, {2}, "B");
  inA->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG());
  inB->getHandle<>().randomize(-2.0, 2.0, mod_.getPRNG());

  auto *plus = F_->createAdd("AplusB", inA, inB);
  F_->createSave("save", plus);

  Partitioner myPartitioner(&mod_, {{1000000, "Interpreter"},
                                    {1000000, "Interpreter"},
                                    {1000000, "Interpreter"}});

  CompilationContext cctx;
  auto dagList = myPartitioner.partition(cctx);
  ASSERT_TRUE((bool)dagList);
}

/// Create a mock backend whose isOpSupported() rejects ops of kind
/// \p unsupportedOpKind.
template <glow::Kinded::Kind unsupportedOpKind>
class MockBackend : public Backend {
public:
  std::string backendName;

  class MockFunction : public CompiledFunction {
  public:
    MockFunction(llvm::StringRef backendName, runtime::RuntimeBundle &&bundle)
        : CompiledFunction(std::move(bundle)), backendName(backendName) {}

    Error execute(ExecutionContext *) override { return Error::success(); }

    std::string getCompileBackendName() const override { return backendName; }

    std::string backendName;
  };

  std::string getBackendName() const override { return backendName; }

  Expected<std::unique_ptr<CompiledFunction>>
  compile(Function *F, const BackendOptions &opts) const override {
    return glow::make_unique<MockFunction>(backendName,
                                           runtime::RuntimeBundle::create(*F));
  }

  bool isOpSupported(const NodeInfo &NI) const override {
    if (NI.getKind() == unsupportedOpKind) {
      return false;
    }
    return true;
  }

  bool shouldLower(const Node *N) const override { return false; }

  bool generateInst(Node *N, IRGenVisitor &irgen) const override {
    return false;
  }

  Expected<double> estimateNodeCost(const Node * /* node */) const override {
    return 2.0;
  }

  runtime::DeviceManager *
  createDeviceManager(const runtime::DeviceConfig &deviceConfig) override {
    return nullptr;
  }
};

class BackendWithoutSub : public MockBackend<Kinded::Kind::SubNodeKind> {
public:
  BackendWithoutSub() { backendName = "CPU"; }
};
class BackendWithoutMul : public MockBackend<Kinded::Kind::MulNodeKind> {
public:
  BackendWithoutMul() { backendName = "Interpreter"; }
};

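/// Builds a small graph (two Subs, a Mul, and an Add) used by the
/// heterogeneous-partition tests below.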
static void createSimpleModule(Module &mod) {
  mod.clear();
  auto *F = mod.createFunction("test");
  auto *input1 =
      mod.createPlaceholder(ElemKind::FloatTy, {16}, "input1", false);
  auto *input2 =
      mod.createPlaceholder(ElemKind::FloatTy, {16}, "input2", false);
  auto *input3 =
      mod.createPlaceholder(ElemKind::FloatTy, {16}, "input3", false);
  auto *sub = F->createSub("sub", input1, input2);
  auto *mul = F->createMul("mul", input1, input2);
  auto *sum = F->createAdd("add", sub, mul);
  auto *sub2 = F->createSub("sub1", sum, input3);
  auto *save = F->createSave("ret", sub2);
  (void)save;
}

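/// Builds a SparseNN-style model: five SLS tables feed a Concat, which feeds
/// \p numFCLayers FC layers. Optionally adds Clip+LayerNorm, Tile, or Tanh
/// after each SLS, and optionally shares Splat inputs across the tables.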
static void createSimpleSparseNNModule(Module &mod, bool shareSplatInputs,
                                       bool addClipAndLayerNorm, bool addTile,
                                       bool addTanh, dim_t numFCLayers) {
  mod.clear();
  auto *F = mod.createFunction("test");

  // Create SLS inputs.
  std::vector<NodeValue> slsOutputs;
  const dim_t tableWidth = 16;
  const dim_t numIndices = 80;
  const dim_t batchSize = 32;
  const dim_t tableEntries = 10;
  const size_t tableNum = 5;
  // Based on how the SLS table widths are calculated below, fcWidth is the
  // sum of all the SLS tables' widths.
  const dim_t fcWidth = tableNum * (tableNum + 1) / 2 * tableWidth;
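  // With tableNum = 5 and tableWidth = 16: fcWidth = 5 * 6 / 2 * 16 = 240.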

  NodeValue weights;
  NodeValue lengths;
  NodeValue scale;
  NodeValue bias;
  if (shareSplatInputs) {
    // Shared by FusedRowwiseQuantizedSparseLengthsWeightedSumNode.
    auto ty =
        F->getParent()->uniqueType(ElemKind::FloatTy, {numIndices * batchSize});
    weights = F->createSplat("ones", ty, 1.0)->getResult();
    lengths =
        F->createSplat(
             "lengths",
             F->getParent()->uniqueType(ElemKind::Int32ITy, {batchSize}), 1)
            ->getResult();
    // Shared by LayerNormalizationNode.
    scale = F->createSplat("LN_scale",
                           F->getParent()->uniqueType(ElemKind::FloatTy,
                                                      {fcWidth / tableNum}),
                           1.0)
                ->getResult();
    bias = F->createSplat("LN_bias",
                          F->getParent()->uniqueType(ElemKind::FloatTy,
                                                     {fcWidth / tableNum}),
                          1.0)
               ->getResult();
  }

  // Create SLS portion.
  for (int table = 0; table < tableNum; table++) {
    dim_t thisTableWidth =
        shareSplatInputs ? fcWidth / tableNum : (table + 1) * tableWidth;
    Tensor data(ElemKind::FloatTy, {tableEntries, thisTableWidth});
    auto *indices = mod.createPlaceholder(
        ElemKind::Int64ITy, {numIndices * batchSize}, "indices", false);
    if (!shareSplatInputs) {
      weights = mod.createPlaceholder(ElemKind::FloatTy,
                                      {numIndices * batchSize}, "w", false)
                    ->getOutput();

      lengths = mod.createPlaceholder(ElemKind::Int32ITy, {batchSize},
                                      "lengths", false)
                    ->getOutput();
      if (addTile && table == 0) {
        lengths =
            mod.createPlaceholder(ElemKind::Int32ITy, {1}, "lengths", false)
                ->getOutput();
      }
    }
    float avgLength = (table % 2) ? 12.0f : 10.0f;
    auto *slsOutput = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
        "SLS", data, weights, indices, lengths, ElemKind::UInt8FusedQTy,
        /* useFP16Accumulation */ false,
        /* lengthsMode */ LengthsMode::Variable, /* avgLength */ avgLength);

    if (addClipAndLayerNorm) {
      /* Clip */
      auto *clipped = F->createClip("SLS_clipped", slsOutput, 0.0f, 70.0f);

      /* Layer Norm */
      if (!shareSplatInputs) {
        Tensor scaleT(ElemKind::FloatTy, {thisTableWidth});
        scaleT.getHandle().randomize(0.0f, 1.0f, mod.getPRNG());
        scale = mod.createConstant("LN_scale", std::move(scaleT));

        Tensor biasT(ElemKind::FloatTy, {thisTableWidth});
        biasT.getHandle().randomize(0.0f, 1.0f, mod.getPRNG());
        bias = mod.createConstant("LN_bias", std::move(biasT));
      }
      auto *layerNormed = F->createLayerNormalization(
          "LN", clipped->getResult().getType(), clipped, scale, bias, 1e-5);

      /* Clip */
      auto *layerNormedClipped =
          F->createClip("LN_clipped", layerNormed, 0.0f, 70.0f);
      slsOutputs.emplace_back(layerNormedClipped);
    } else if (addTile && table == 0) {
      /* Tile */
      auto *tiled = F->createTile("SLS_tiled", slsOutput, batchSize, 0);
      slsOutputs.emplace_back(tiled);
    } else if (addTanh) {
      /* Tanh */
      auto *tanh = F->createTanh("SLS_tanh", slsOutput);
      slsOutputs.emplace_back(tanh);
    } else {
      slsOutputs.emplace_back(slsOutput);
    }
  }

  // Create Concat.
  auto *concat = F->createConcat("concat", slsOutputs, 1);
  Node *cur = (Node *)concat;

  // Create FC portion.
  for (dim_t layer = 0; layer < numFCLayers; layer++) {
    Tensor FCWeights(ElemKind::FloatTy, {fcWidth, fcWidth});
    FCWeights.getHandle().randomize(-0.5, 0.5, mod.getPRNG());
    Constant *weights = mod.createConstant("FCWeights", FCWeights);
    Tensor FCBias(ElemKind::FloatTy, {fcWidth});
    FCBias.getHandle().randomize(-0.5, 0.5, mod.getPRNG());
    Constant *bias = mod.createConstant("FCBias", FCBias);

    auto *FC = F->createFullyConnected("FC", cur, weights, bias);
    cur = (Node *)FC;
  }

  auto *save = F->createSave("ret", cur);
  (void)save;
}

/// \returns true if there are nodes of kind \p nodeKind in \p func.
static bool findNodeInFunction(const Function *func,
                               const Kinded::Kind nodeKind) {
  for (const Node &N : func->getNodes()) {
    if (N.getKind() == nodeKind) {
      return true;
    }
  }
  return false;
}

/// Checks that the generated DAG is correct for the SparseNN partition unit
/// tests. The network used for the check is generated by
/// createSimpleSparseNNModule(Module &mod).
static void
sparseNNPartitionValidation(const DAGListTy &dagList, uint64_t deviceMemory,
                            Module &mod, bool shareSplatInputs,
                            bool addClipAndLayerNorm, bool pairLNWithSLS,
                            bool addTile, bool pairTileWithSLS, bool addTanh,
                            std::string sparseNNPartitioningPairSLSWith) {
  int numOfCPUBackends = 0;
  int numOfSLSNodes = 0;
  int numOfFCNodes = 0;
  std::unordered_set<uint64_t> slsPartitionSizes;
  uint64_t nonSlsPartitionSize = 0;
  for (auto &dag : dagList) {
    auto tileAdded = false;
    for (auto &node : dag.nodes) {
      ASSERT_TRUE(node->backendName == "CPU");
      numOfCPUBackends++;
      auto *func = mod.getFunction(node->name);
      GraphMemInfo memInfo = getFunctionMemory(func);

      if (shareSplatInputs) {
        for (const Node &N : func->getNodes()) {
          if (const auto *SLWS = llvm::dyn_cast<
                  FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(&N)) {
            EXPECT_TRUE(llvm::isa<SplatNode>(SLWS->getWeights()));
            EXPECT_TRUE(llvm::isa<SplatNode>(SLWS->getLengths()));
            // The weights/lengths node is a Splat node; the partitioner
            // clones it so each user has its own copy.
            EXPECT_EQ(SLWS->getWeights().getNumUsers(), 1);
            EXPECT_EQ(SLWS->getLengths().getNumUsers(), 1);
          } else if (const auto *LN =
                         llvm::dyn_cast<LayerNormalizationNode>(&N)) {
            EXPECT_TRUE(llvm::isa<SplatNode>(LN->getScale()));
            EXPECT_TRUE(llvm::isa<SplatNode>(LN->getBias()));
            // The scale/bias node is a Splat node; the partitioner clones it
            // so each user has its own copy.
            EXPECT_EQ(LN->getScale().getNumUsers(), 1);
            EXPECT_EQ(LN->getBias().getNumUsers(), 1);
          }
        }
      }

      if (findNodeInFunction(
              func,
              Kinded::Kind::
                  FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind)) {
        numOfSLSNodes++;
        slsPartitionSizes.insert(memInfo.getTotalMemSize());
        EXPECT_EQ(node->logicalDevices.size(), 1);
        if (addClipAndLayerNorm && pairLNWithSLS) {
          EXPECT_TRUE(findNodeInFunction(
              func, Kinded::Kind::LayerNormalizationNodeKind));
          EXPECT_TRUE(findNodeInFunction(func, Kinded::Kind::ClipNodeKind));
        }
        if (addTile && pairTileWithSLS) {
          tileAdded |= findNodeInFunction(func, Kinded::Kind::TileNodeKind);
        }
        if (addTanh &&
            sparseNNPartitioningPairSLSWith.find("Tanh") != std::string::npos) {
          EXPECT_TRUE(findNodeInFunction(func, Kinded::Kind::TanhNodeKind));
        }
      } else if (findNodeInFunction(func,
                                    Kinded::Kind::FullyConnectedNodeKind)) {
        nonSlsPartitionSize = memInfo.getTotalMemSize();
        numOfFCNodes++;
        EXPECT_EQ(node->logicalDevices.size(), 3);
        if (addClipAndLayerNorm && !pairLNWithSLS) {
          EXPECT_TRUE(findNodeInFunction(
              func, Kinded::Kind::LayerNormalizationNodeKind));
          EXPECT_TRUE(findNodeInFunction(func, Kinded::Kind::ClipNodeKind));
        }
        if (addTanh &&
            sparseNNPartitioningPairSLSWith.find("Tanh") == std::string::npos) {
          EXPECT_TRUE(findNodeInFunction(func, Kinded::Kind::TanhNodeKind));
        }
      } else {
        FAIL() << "Unexpected partition";
      }
    }
    if (addTile && pairTileWithSLS) {
      EXPECT_TRUE(tileAdded);
    }
  }

  // 4 partitions (3 SLS + 1 FC).
  EXPECT_EQ(numOfCPUBackends, 4);
  EXPECT_EQ(numOfSLSNodes, 3);
  EXPECT_EQ(numOfFCNodes, 1);
  for (uint64_t slsPartitionSize : slsPartitionSizes) {
    EXPECT_LE(slsPartitionSize + nonSlsPartitionSize, deviceMemory);
  }
}

static void testSimpleSparseNNPartitioning(
    Module &mod, bool shareSplatInputs, bool concatSLSOutputs,
    bool balancePerfModel, bool addClipAndLayerNorm, bool pairLNWithSLS,
    bool addTile, bool pairTileWithSLS, bool addTanh,
    std::string sparseNNPartitioningPairSLSWith, bool forceFailure = false) {
  createSimpleSparseNNModule(mod, shareSplatInputs, addClipAndLayerNorm,
                             addTile, addTanh, forceFailure ? 5 : 4);
  BackendWithoutSub backend1, backend2, backend3;
  std::vector<Backend *> backends;
  backends.emplace_back(&backend1);
  backends.emplace_back(&backend2);
  backends.emplace_back(&backend3);
  const uint64_t deviceMemory = 1250000;
  std::vector<DeviceInfo> devices = {
      {deviceMemory, "CPU"}, {deviceMemory, "CPU"}, {deviceMemory, "CPU"}};
  Partitioner partitioner(&mod, devices, backends);
  CompilationContext cctx;
  cctx.optimizationOpts.useSparseNNPartitioningScheme = true;
  cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards = 3;
  cctx.optimizationOpts.sparseNNPartitioningAddSLSConcats = concatSLSOutputs;
  cctx.optimizationOpts.sparseNNPartitioningBalancePerfModel = balancePerfModel;
  cctx.optimizationOpts.sparseNNPartitioningPairLNWithSLS = pairLNWithSLS;
  cctx.optimizationOpts.sparseNNPartitioningPairTileWithSLS = pairTileWithSLS;
  cctx.optimizationOpts.sparseNNPartitioningPairSLSWith =
      sparseNNPartitioningPairSLSWith;
  Expected<DAGListTy> dagList = partitioner.partition(cctx);
  bool failed = ERR_TO_BOOL(dagList.takeError());
  if (forceFailure) {
    EXPECT_TRUE(failed);
    return;
  }
  if (concatSLSOutputs && addTile && !pairTileWithSLS) {
    EXPECT_TRUE(failed);
    return;
  }
  EXPECT_EQ(mod.getFunctions().size(), 4);
  EXPECT_EQ(dagList->size(), 1);
  ASSERT_TRUE(checkSaveNode(mod));
  sparseNNPartitionValidation(*dagList, deviceMemory, mod, shareSplatInputs,
                              addClipAndLayerNorm, pairLNWithSLS, addTile,
                              pairTileWithSLS, addTanh,
                              sparseNNPartitioningPairSLSWith);
  mod.clear();
}

/// Test using user-defined backends for SparseNN partition.
TEST_F(PartitionerTest, SimpleSparseNNPartitioning) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "");
}

/// Test that the pairLNWithSLS flag is a NOP when LN doesn't exist.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningPairLNNOP) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ true,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "");
}

/// Test using user-defined backends for SparseNN partition.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningClipAndLayerNormInNonSLS) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ true,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "");
}

/// Test using user-defined backends for SparseNN partition.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningClipAndLayerNormInSLS) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ true,
                                 /*pairLNWithSLS*/ true,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "");
}

TEST_F(PartitionerTest, SimpleSparseNNPartitioning_ConcatSLSOutputs) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ true,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ true,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "");
}

/// Test using user-defined backends for SparseNN partition when the inputs
/// are Splats shared by all SLSs.
TEST_F(PartitionerTest, SimpleSparseNNPartitioning_SharedSplatInputs) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ true,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ true,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "");
}

/// Test using user-defined backends for SparseNN partition when the inputs
/// are Splats shared by all SLSs, and LN is included in the frontier.
TEST_F(PartitionerTest,
       SimpleSparseNNPartitioning_SharedSplatInputsAndLayerNormInSLS) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ true,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ true,
                                 /*pairLNWithSLS*/ true,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "");
}

/// Test using user-defined backends for SparseNN partition.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningBalancePerfModel) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ true,
                                 /*addClipAndLayerNorm*/ true,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "");
}

/// Test that pairTileWithSLS is a NOP when Tile doesn't exist.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningPairTileNOP) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ true,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ true,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "");
}

/// Test that concatenating SLS nodes, where one node has first dimension = 1,
/// works without concatSLSOutputs.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningTileAndConcatSLSOutputs) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ true,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "");
}

/// Test that concatSLSOutputs fails when one SLS node has first dimension = 1
/// while the other nodes share a larger first dimension and the
/// pairTileWithSLS flag is off.
TEST_F(PartitionerTest,
       SimpleSparseNNPartitioningTileAndConcatSlsOutputsFails) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ true,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ true,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "");
}

/// Test that concatSLSOutputs works when one SLS node has first dimension = 1
/// while the other nodes share a larger first dimension and the
/// pairTileWithSLS flag is on.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningPairTileAndConcatSlsOutputs) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ true,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ true,
                                 /*pairTileWithSLS*/ true,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "");
}

/// Test that pairSLSWith is a NOP when Tanh doesn't exist.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningPairTanhNOP) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ true,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "Tanh");
}

/// Test using user-defined backends for SparseNN partition.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningPairTanhAndSLS) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ false,
                                 /*pairLNWithSLS*/ false,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ true,
                                 /*pairSLSWith*/ "Tanh");
}

/// This test checks that partitioning fails when we have an SLS partition and
/// a non-SLS partition that, when summed together, cannot fit inside a
/// device.
TEST_F(PartitionerTest, SimpleSparseNNPartitioningExpectFailure) {
  testSimpleSparseNNPartitioning(mod_, /*shareSplatInputs*/ false,
                                 /*concatSLSOutputs*/ false,
                                 /*balancePerfModel*/ false,
                                 /*addClipAndLayerNorm*/ true,
                                 /*pairLNWithSLS*/ true,
                                 /*addTile*/ false,
                                 /*pairTileWithSLS*/ false,
                                 /*addTanh*/ false,
                                 /*pairSLSWith*/ "LayerNorm",
                                 /*forceFailure*/ true);
}

/// Checks that the generated DAG is correct for the heterogeneous partition
/// unit tests. The network used for the check is generated by
/// createSimpleModule(Module &mod).
static void heterogeneousPartitionValidation(const DAGListTy &dagList,
                                             Module &mod) {
  int numOfInterpreterBackends = 0;
  int numOfCPUBackends = 0;
  int numOfMulNodes = 0;
  int numOfSubNodes = 0;
  for (auto &dag : dagList) {
    for (auto &node : dag.nodes) {
      // Although saturateHost is set to true, the host is not saturated in a
      // heterogeneous partition.
      EXPECT_EQ(node->logicalDevices.size(), 1);
      if (node->backendName == "CPU") {
        numOfCPUBackends++;
        auto func = mod.getFunction(node->name);
        // Sub nodes should not be assigned to the CPU backend.
        EXPECT_FALSE(findNodeInFunction(func, Kinded::Kind::SubNodeKind));
        numOfMulNodes +=
            (findNodeInFunction(func, Kinded::Kind::MulNodeKind) == true);
      }
      if (node->backendName == "Interpreter") {
        numOfInterpreterBackends++;
        auto func = mod.getFunction(node->name);
        // Mul nodes should not be assigned to the Interpreter backend.
        EXPECT_FALSE(findNodeInFunction(func, Kinded::Kind::MulNodeKind));
        numOfSubNodes +=
            (findNodeInFunction(func, Kinded::Kind::SubNodeKind) == true);
      }
    }
  }
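  // createSimpleModule() has two Sub nodes (Interpreter-only here) and one
  // Mul node (CPU-only here), so expect two Interpreter partitions and one
  // CPU partition.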
  EXPECT_EQ(numOfInterpreterBackends, 2);
  EXPECT_EQ(numOfCPUBackends, 1);
  EXPECT_EQ(numOfSubNodes, 2);
  EXPECT_EQ(numOfMulNodes, 1);
}

/// Test using user-defined backends for heterogeneous partition.
TEST_F(PartitionerTest, SimpleHeterogeneousPartitioning) {
  createSimpleModule(mod_);
  BackendWithoutSub backendWithoutSub1;
  BackendWithoutMul backendWithoutMul1, backendWithoutMul2;
  // Create two kinds of backends which support different ops, then partition
  // by assigning the ops to the corresponding backends.
  std::vector<Backend *> backends;
  backends.emplace_back(&backendWithoutMul1);
  backends.emplace_back(&backendWithoutMul2);
  backends.emplace_back(&backendWithoutSub1);
  std::vector<DeviceInfo> devices = {
      {3072, "Interpreter"}, {3072, "Interpreter"}, {3072, "CPU"}};
  Partitioner partitioner(&mod_, devices, backends);
  CompilationContext cctx;
  cctx.saturateHost = true;
  auto dagList = partitioner.partition(cctx);
  ASSERT_TRUE((bool)dagList);
  EXPECT_EQ(mod_.getFunctions().size(), 3);
  EXPECT_EQ(dagList->size(), 1);
  ASSERT_TRUE(checkSaveNode(mod_));
  heterogeneousPartitionValidation(dagList.get(), mod_);

  mod_.clear();
}

/// Test pre-defined unsupported ops used for choosing a backend in
/// heterogeneous partition. In this test, "Mul" is not supported in the
/// Interpreter backend, and "Sub" is not supported in the CPU backend.
TEST_F(PartitionerTest, heterogeneousPartitioningWithNonSupportedNodes) {
#ifndef GLOW_WITH_CPU
  return;
#endif
  createSimpleModule(mod_);
  std::vector<DeviceInfo> devices = {{3072, "Interpreter", "Mul"},
                                     {3072, "Interpreter", "Mul"},
                                     {3072, "CPU", "Sub"}};
  Partitioner partitioner(&mod_, devices);
  CompilationContext cctx;
  auto dagList = partitioner.partition(cctx);
  ASSERT_TRUE((bool)dagList);
  EXPECT_EQ(mod_.getFunctions().size(), 3);
  EXPECT_EQ(dagList->size(), 1);
  ASSERT_TRUE(checkSaveNode(mod_));
  heterogeneousPartitionValidation(dagList.get(), mod_);

  mod_.clear();
}

/// Test pre-defined supported ops used for choosing a backend in
/// heterogeneous partition. In this test, "Mul" is not supported in the
/// Interpreter backend, and "Sub" is not supported in the CPU backend.
/// "Sub,Add,Save" are supported in the Interpreter backend and "Mul,Add,Save"
/// are supported in the CPU backend.
TEST_F(PartitionerTest, heterogeneousPartitioningWithSupportedNodes) {
#ifndef GLOW_WITH_CPU
  return;
#endif
  createSimpleModule(mod_);
  std::vector<DeviceInfo> devices = {
      // {memory size, backend, non-supported nodes, supported nodes}
      {3072, "Interpreter", "", "Sub,Add,Save"},
      {3072, "Interpreter", "", "Sub,Add,Save"},
      {3072, "CPU", "", "Mul,Add,Save"}};
  Partitioner partitioner(&mod_, devices);
  CompilationContext cctx;
  auto dagList = partitioner.partition(cctx);
  ASSERT_TRUE((bool)dagList);
  EXPECT_EQ(mod_.getFunctions().size(), 3);
  EXPECT_EQ(dagList->size(), 1);
  ASSERT_TRUE(checkSaveNode(mod_));
  heterogeneousPartitionValidation(dagList.get(), mod_);

  mod_.clear();
}

/// Test assigning more than one partition to one device for a single
/// backendName.
TEST_F(PartitionerTest, logicalIDTest0) {
  auto *input1 = mod_.createConstant(ElemKind::FloatTy, {1, 100}, "input1");
  input1->getHandle<>().randomize(-10, 10, mod_.getPRNG());
  auto *input2 =
      mod_.createPlaceholder(ElemKind::FloatTy, {100, 1}, "input2", false);
  auto *input3 = mod_.createConstant(ElemKind::FloatTy, {1, 100}, "input5");
  input3->getHandle<>().randomize(-10, 10, mod_.getPRNG());
  auto *mul0 = F_->createMatMul("mul0", input1, input2);
  auto *mul1 = F_->createMatMul("mul1", mul0, input3);
  auto *save = F_->createSave("ret", mul1);
  (void)save;
  std::vector<DeviceInfo> devices = {{1000, "Interpreter"},
                                     {1000, "Interpreter"}};
  // Each device has only 1000 bytes of memory, so the two MatMuls end up in
  // two partitions.
  Partitioner partitioner(&mod_, devices);
  CompilationContext cctx;
  cctx.saturateHost = true;
  auto dagList = partitioner.partition(cctx);
  ASSERT_TRUE((bool)dagList);
  // Check there are 2 partitions.
  EXPECT_EQ(mod_.getFunctions().size(), 2);
  EXPECT_EQ(dagList->size(), 1);
  ASSERT_TRUE(checkSaveNode(mod_));

  for (auto &dag : dagList.get()) {
    // Check the number of logical devices.
    llvm::SmallSet<DeviceIDTy, 4> usedID;
    for (auto &node : dag.nodes) {
      EXPECT_EQ(node->logicalDevices.size(), 1);
      usedID.insert(node->logicalDevices[0]);
    }
    // Check there are 2 logical devices.
    EXPECT_EQ(usedID.size(), 2);
  }
  mod_.clear();
}

/// Test assigning more than one partition to one device in a heterogeneous
/// partition.
TEST_F(PartitionerTest, logicalIDTest1) {
  createSimpleModule(mod_);
  BackendWithoutSub backendWithoutSub1, backendWithoutSub2;
  BackendWithoutMul backendWithoutMul1, backendWithoutMul2;
  // Create two backends which support different ops, then partition by
  // assigning the ops to the corresponding backends.
  std::vector<Backend *> backends;
  backends.emplace_back(&backendWithoutMul1);
  backends.emplace_back(&backendWithoutSub1);
  std::vector<DeviceInfo> devices = {{3072, "Interpreter"}, {3072, "CPU"}};
  Partitioner partitioner(&mod_, devices, backends);
  CompilationContext cctx;
  cctx.saturateHost = true;
  auto dagList = partitioner.partition(cctx);
  ASSERT_TRUE((bool)dagList);
  EXPECT_EQ(mod_.getFunctions().size(), 3);
  EXPECT_EQ(dagList->size(), 1);
  ASSERT_TRUE(checkSaveNode(mod_));

  for (auto &dag : dagList.get()) {
    // Check the number of logical devices.
    llvm::SmallSet<DeviceIDTy, 4> usedID;
    for (auto &node : dag.nodes) {
      // Although saturateHost is set to true, the host is not saturated in a
      // heterogeneous partition.
      EXPECT_EQ(node->logicalDevices.size(), 1);
      usedID.insert(node->logicalDevices[0]);
    }
    EXPECT_EQ(usedID.size(), 2);
  }
  mod_.clear();
}

/// Check that getGraphMemInfo and updateGraphMemInfoByAddingNode in
/// PartitionerUtils.cpp handle a single Node with more than one output.
TEST_F(PartitionerTest, graphMemInfoCalculation1) {
  // TODO: The values are too large when dim_t is 32b. Figure out how it's
  // computed and ensure it's computed correctly.
  if (DIM_T_BITWIDTH == 32)
    return;
  auto *inp1 =
      mod_.createPlaceholder(ElemKind::FloatTy, {2, 1, 3}, "input", false);
  auto *inp2 =
      mod_.createPlaceholder(ElemKind::FloatTy, {2, 1, 3}, "input", false);
  auto *indices =
      mod_.createPlaceholder(ElemKind::Int64ITy, {4, 1, 2}, "indices", false);

  auto *R1 = F_->createTopK("TopK1", inp1, 2);
  auto *R2 = F_->createTopK("TopK2", inp2, 2);

  // Concat the values and indices separately, both on the 0th dimension,
  // matching the shapes of the values and indices variables above.
  auto *CV =
      F_->createConcat("Concat.Values", {R1->getValues(), R2->getValues()}, 0);
  auto *CI = F_->createConcat("Concat.Indices",
                              {R1->getIndices(), R2->getIndices()}, 0);

  auto *saveValues = F_->createSave("Save.Values", CV);
  auto *saveIndices = F_->createSave("Save.Indices", CI, indices);

  std::set<Node *> nodes1;
  GraphMemInfo res;
  res = updateGraphMemInfoByAddingNode(nodes1, res, R1);
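  // TopK1 reads inp1 (2x1x3 floats = 24 bytes) and writes values (2x1x2
  // floats = 16 bytes) plus indices (2x1x2 int64 = 32 bytes), so the expected
  // GraphMemInfo is (in = 24, out = 48, const = 0).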
  EXPECT_EQ(res, GraphMemInfo(24, 48, 0));
  nodes1.insert(R1);

  res = updateGraphMemInfoByAddingNode(nodes1, res, R2);
  EXPECT_EQ(res, GraphMemInfo(48, 96, 0));
  nodes1.insert(R2);

  res = updateGraphMemInfoByAddingNode(nodes1, res, CV);
  EXPECT_EQ(res, GraphMemInfo(48, 96, 0));
  nodes1.insert(CV);

  res = updateGraphMemInfoByAddingNode(nodes1, res, CI);
  EXPECT_EQ(res, GraphMemInfo(48, 96, 0));
  nodes1.insert(CI);

  res = updateGraphMemInfoByAddingNode(nodes1, res, saveValues);
  EXPECT_EQ(res, GraphMemInfo(48, 96, 0));
  nodes1.insert(saveValues);

  res = updateGraphMemInfoByAddingNode(nodes1, res, saveIndices);
  EXPECT_EQ(res, GraphMemInfo(48, 96, 0));
  nodes1.insert(saveIndices);

  std::set<Node *> nodes2, nodes3;
  nodes2.insert(R1);
  nodes2.insert(R2);
  nodes3.insert(CV);
  nodes3.insert(CI);
  nodes3.insert(saveValues);
  nodes3.insert(saveIndices);
  GraphMemInfo res1 = getGraphMemInfo(nodes2, 1);
  GraphMemInfo res2 = getGraphMemInfo(nodes3, 1);
  GraphMemInfo ref1(48, 96, 0);
  GraphMemInfo ref2(96, 96, 0);
  EXPECT_EQ(res1, ref1);
  EXPECT_EQ(res2, ref2);
}
1450
1451/// Check the function updateGraphMemInfoByAddingNode and getGraphMemInfo to
1452/// handle shared Storage node in PartitionerUtils.cpp
1453TEST_F(PartitionerTest, graphMemInfoCalculation2) {
1454 auto *input =
1455 mod_.createPlaceholder(ElemKind::FloatTy, {1, 16}, "input", false);
1456
1457 // Left branch.
1458 auto *w2 = mod_.createConstant(ElemKind::FloatTy, {16, 16}, "w2");
1459 auto *b2 = mod_.createConstant(ElemKind::FloatTy, {16}, "b2");
1460 auto *L = F_->createFullyConnected("left_fc1", input, w2, b2);
1461 auto *L1 = F_->createSigmoid("left_sigmoid1", L);
1462 auto *w3 = mod_.createConstant(ElemKind::FloatTy, {16, 8}, "w3");
1463 auto *b3 = mod_.createConstant(ElemKind::FloatTy, {8}, "b3");
1464 auto *L2 = F_->createFullyConnected("left_fc2", L1, w3, b3);
1465 auto *L3 = F_->createSigmoid("left_sigmoid2", L2);
1466
1467 // Right branch.
1468 auto *R = F_->createFullyConnected("right_fc1", input, w2, b2);
1469 auto *R1 = F_->createSigmoid("right_sigmoid1", R);
1470 auto *w5 = mod_.createConstant(ElemKind::FloatTy, {16, 8}, "w5");
1471 auto *b5 = mod_.createConstant(ElemKind::FloatTy, {8}, "b5");
1472 auto *R2 = F_->createFullyConnected("right_fc2", R1, w5, b5);
1473 auto *R3 = F_->createSigmoid("right_sigmoid2", R2);
1474
1475 // Join branches.
1476 auto *mul = F_->createMul("mul", L3, R3);
1477 auto *save = F_->createSave("ret", mul);
1478
1479 std::set<Node *> nodes1, nodes2;
1480 GraphMemInfo res1, res2;
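  // Adding L counts the input ({1, 16} fp32 = 64 bytes) in, its output (64
  // bytes) out, and the constants w2 (16 * 16 * 4 = 1024 bytes) plus b2 (64
  // bytes) = 1088 bytes.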
1481 res1 = updateGraphMemInfoByAddingNode(nodes1, res1, L);
1482 EXPECT_EQ(res1, GraphMemInfo(64, 64, 1088));
1483 nodes1.insert(L);
1484
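  // R shares w2 and b2 with L, so constMemSize stays at 1088; only R's
  // 64-byte output is added.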
1485 res1 = updateGraphMemInfoByAddingNode(nodes1, res1, R);
1486 EXPECT_EQ(res1, GraphMemInfo(64, 128, 1088));
1487 nodes1.insert(R);
1488
1489 res1 = updateGraphMemInfoByAddingNode(nodes1, res1, R1);
1490 EXPECT_EQ(res1, GraphMemInfo(64, 128, 1088));
1491 nodes1.insert(R1);
1492
1493 res1 = updateGraphMemInfoByAddingNode(nodes1, res1, R2);
1494 EXPECT_EQ(res1, GraphMemInfo(64, 96, 1632));
1495 nodes1.insert(R2);
1496
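  // Recomputing from scratch with getGraphMemInfo must match the
  // incrementally maintained result.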
1497 res1 = getGraphMemInfo(nodes1, 1);
1498 EXPECT_EQ(res1, GraphMemInfo(64, 96, 1632));
1499
1500 res2 = updateGraphMemInfoByAddingNode(nodes2, res2, L1);
1501 EXPECT_EQ(res2, GraphMemInfo(64, 64, 0));
1502 nodes2.insert(L1);
1503
1504 res2 = updateGraphMemInfoByAddingNode(nodes2, res2, L2);
1505 EXPECT_EQ(res2, GraphMemInfo(64, 32, 544));
1506 nodes2.insert(L2);
1507
1508 res2 = updateGraphMemInfoByAddingNode(nodes2, res2, L3);
1509 EXPECT_EQ(res2, GraphMemInfo(64, 32, 544));
1510 nodes2.insert(L3);
1511
1512 res2 = updateGraphMemInfoByAddingNode(nodes2, res2, mul);
1513 EXPECT_EQ(res2, GraphMemInfo(96, 32, 544));
1514 nodes2.insert(mul);
1515
1516 res2 = updateGraphMemInfoByAddingNode(nodes2, res2, R3);
1517 EXPECT_EQ(res2, GraphMemInfo(96, 32, 544));
1518 nodes2.insert(R3);
1519
1520 res2 = updateGraphMemInfoByAddingNode(nodes2, res2, save);
1521 EXPECT_EQ(res2, GraphMemInfo(96, 32, 544));
1522 nodes2.insert(save);
1523
1524 res2 = getGraphMemInfo(nodes2, 1);
1525 EXPECT_EQ(res2, GraphMemInfo(96, 32, 544));
1526}
1527
/// Check the function getFunctionMemory in PartitionerUtils.cpp: compute the
/// memory consumption of a simple function where the same inputs are used by
/// multiple nodes.
1531TEST_F(PartitionerTest, funcMemInfoCalculation1) {
1532 auto *input1 =
1533 mod_.createPlaceholder(ElemKind::FloatTy, {16}, "input1", false);
1534 auto *input2 =
1535 mod_.createPlaceholder(ElemKind::FloatTy, {16}, "input2", false);
1536 auto *input3 =
1537 mod_.createPlaceholder(ElemKind::FloatTy, {16}, "input3", false);
1538 auto *sub = F_->createSub("sub", input1, input2);
1539 auto *mul = F_->createMul("mul", input1, input2);
1540 auto *sum = F_->createAdd("add", sub, mul);
1541 auto *sub2 = F_->createSub("sub1", sum, input3);
1542 auto *save = F_->createSave("ret", sub2);
1543 (void)save;
1544
1545 GraphMemInfo info = getFunctionMemory(F_);
  // 3 input tensors of 16 fp32 values each (192 bytes) and 1 output of 16
  // fp32 values (64 bytes).
1547 GraphMemInfo res(192, 64, 0);
1548 EXPECT_EQ(res, info);
1549}
1550
/// Check the function getFunctionMemory in PartitionerUtils.cpp: compute the
/// memory consumption of a function with constants.
1553TEST_F(PartitionerTest, funcMemInfoCalculation2) {
1554 auto *input =
1555 mod_.createPlaceholder(ElemKind::FloatTy, {1, 16}, "input", false);
1556
1557 // Left branch.
1558 auto *w2 = mod_.createConstant(ElemKind::FloatTy, {16, 16}, "w2");
1559 auto *b2 = mod_.createConstant(ElemKind::FloatTy, {16}, "b2");
1560 auto *L = F_->createFullyConnected("left_fc1", input, w2, b2);
1561 auto *L1 = F_->createSigmoid("left_sigmoid1", L);
1562 auto *w3 = mod_.createConstant(ElemKind::FloatTy, {16, 8}, "w3");
1563 auto *b3 = mod_.createConstant(ElemKind::FloatTy, {8}, "b3");
1564 auto *L2 = F_->createFullyConnected("left_fc2", L1, w3, b3);
1565 auto *L3 = F_->createSigmoid("left_sigmoid2", L2);
1566
1567 // Right branch.
1568 auto *R = F_->createFullyConnected("right_fc1", input, w2, b2);
1569 auto *R1 = F_->createSigmoid("right_sigmoid1", R);
1570 auto *w5 = mod_.createConstant(ElemKind::FloatTy, {16, 8}, "w5");
1571 auto *b5 = mod_.createConstant(ElemKind::FloatTy, {8}, "b5");
1572 auto *R2 = F_->createFullyConnected("right_fc2", R1, w5, b5);
1573 auto *R3 = F_->createSigmoid("right_sigmoid2", R2);
1574
1575 // Join branches.
1576 auto *mul = F_->createMul("mul", L3, R3);
1577 auto *save = F_->createSave("ret", mul);
1578 (void)save;
1579
1580 GraphMemInfo info = getFunctionMemory(F_);
  // Single input tensor (1 x 16) fp32 = 64 bytes.
  // Single output tensor (1 x 8) fp32 = 32 bytes.
  // Constants fp32: (16*16 + 16 + 16*8 + 8 + 16*8 + 8) * 4 = 2176 bytes.
1584 GraphMemInfo res(64, 32, 2176);
1585 EXPECT_EQ(res, info);
1586}
1587
/// Check the function getFunctionMemory in PartitionerUtils.cpp: compute the
/// memory consumption of a function built from a chain of MatMul nodes.
1591TEST_F(PartitionerTest, funcMemInfoCalculation3) {
1592 auto *input1 =
1593 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1", false);
1594 auto *input2 =
1595 mod_.createPlaceholder(ElemKind::FloatTy, {10, 16}, "input2", false);
1596 auto *input3 =
1597 mod_.createPlaceholder(ElemKind::FloatTy, {16, 20}, "input3", false);
1598 auto *input4 =
1599 mod_.createPlaceholder(ElemKind::FloatTy, {20, 1}, "input4", false);
1600 auto *input5 =
1601 mod_.createPlaceholder(ElemKind::FloatTy, {1, 50}, "input5", false);
1602 auto *mul0 = F_->createMatMul("mul0", input1, input2);
1603 auto *mul1 = F_->createMatMul("mul1", mul0, input3);
1604 auto *mul2 = F_->createMatMul("mul2", mul1, input4);
1605 auto *mul3 = F_->createMatMul("mul3", mul2, input5);
1606 auto *save = F_->createSave("ret", mul3);
1607 (void)save;
1608
1609 GraphMemInfo info = getFunctionMemory(F_);
  // Input consists of 5 tensors (2*10 + 10*16 + 16*20 + 20*1 + 1*50 = 570
  // fp32 values = 2280 bytes); the output is a 2*50 = 100 fp32 tensor (400
  // bytes).
1612 GraphMemInfo res(2280, 400, 0);
1613 EXPECT_EQ(res, info);
1614}
1615
/// This tests memoryUsageValidation in the Partitioner: the memory usage of
/// a single node is larger than the given device memory.
1618TEST_F(PartitionerTest, memoryUsageValidation1) {
1619 auto *input1 =
1620 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1", false);
1621 auto *input2 =
1622 mod_.createPlaceholder(ElemKind::FloatTy, {10, 16}, "input2", false);
1623 auto *mul0 = F_->createMatMul("mul0", input1, input2);
1624 F_->createSave("ret", mul0);
1625
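  // mul0 alone needs 80 (input1) + 640 (input2) + 128 (output) = 848 bytes,
  // which exceeds the 500-byte device memory, so partitioning must fail.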
1626 std::vector<DeviceInfo> devices = {{500, "Interpreter"},
1627 {500, "Interpreter"}};
1628 Partitioner myPartitioner(&mod_, devices);
1629 CompilationContext cctx;
1630 auto dagList = myPartitioner.partition(cctx);
1631 EXPECT_TRUE(ERR_TO_BOOL(dagList.takeError()));
1632}
1633
/// This tests dagValidation in the Partitioner with backend hints; the
/// user-defined partition contains a cycle: p1->p2, p2->p1.
1635TEST_F(PartitionerTest, dagValidationWithBackendHints) {
1636 auto *input1 =
1637 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1", false);
1638 auto *input2 =
1639 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input2", false);
1640 auto *input3 =
1641 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input3", false);
1642 auto *add1 = F_->createAdd("add1", input1, input2);
1643 auto *add2 = F_->createAdd("add2", add1, input3);
1644 auto *sub1 = F_->createSub("sub1", add1, add2);
1645 F_->createSave("save", sub1);
1646
1647 std::vector<DeviceInfo> devices = {{3072, "Interpreter"},
1648 {3072, "Interpreter"}};
1649
1650 // User-defined partition: p1->p2, p2->p1.
1651 PartitionConfig partitionConfig;
1652 partitionConfig.funcName = "main";
1653 partitionConfig.numOfPartitions = 2;
1654 BackendHints bh1, bh2;
1655 bh1.executionUnits = 2;
1656 bh2.executionUnits = 3;
1657 partitionConfig.backendHints = {bh1, bh2};
1658 partitionConfig.backendNames = {"Interpreter", "Interpreter"};
1659 partitionConfig.partitionNames = {"p1", "p2"};
1660 partitionConfig.nodeToPartition = {{"add2", 0}};
1661 auto partitioner = Partitioner(&mod_, devices, false, partitionConfig);
1662 CompilationContext cctx;
1663 auto dagList = partitioner.partition(cctx);
1664 EXPECT_TRUE(ERR_TO_BOOL(dagList.takeError()));
1665}
1666
/// This tests dagValidation in the Partitioner: p1->p2, p2->p1.
1668TEST_F(PartitionerTest, dagValidation1) {
1669 auto *input1 =
1670 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1", false);
1671 auto *input2 =
1672 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input2", false);
1673 auto *input3 =
1674 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input3", false);
1675 auto *add1 = F_->createAdd("add1", input1, input2);
1676 auto *add2 = F_->createAdd("add2", add1, input3);
1677 auto *sub1 = F_->createSub("sub1", add1, add2);
1678 F_->createSave("save", sub1);
1679
1680 std::vector<DeviceInfo> devices = {{3072, "Interpreter"},
1681 {3072, "Interpreter"}};
1682
1683 // User-defined partition: p1->p2, p2->p1.
1684 PartitionConfig partitionConfig;
1685 partitionConfig.funcName = "main";
1686 partitionConfig.numOfPartitions = 2;
1687 partitionConfig.backendNames = {"Interpreter", "Interpreter"};
1688 partitionConfig.partitionNames = {"p1", "p2"};
1689 partitionConfig.nodeToPartition = {{"add2", 0}};
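  // With this assignment the DAG has edges p1 -> p2 and p2 -> p1, i.e. a
  // cycle, so partition() is expected to return an error.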
1690 auto partitioner = Partitioner(&mod_, devices, false, partitionConfig);
1691 CompilationContext cctx;
1692 auto dagList = partitioner.partition(cctx);
1693 EXPECT_TRUE(ERR_TO_BOOL(dagList.takeError()));
1694}
1695
/// This tests dagValidation in the Partitioner: p0->p1, p1->p2, p2->p1.
1697TEST_F(PartitionerTest, dagValidation2) {
1698 auto *input1 =
1699 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1", false);
1700 auto *input2 =
1701 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input2", false);
1702 auto *input3 =
1703 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input3", false);
1704 auto *input4 =
1705 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input4", false);
1706 auto *add0 = F_->createAdd("add0", input1, input2);
1707 auto *add1 = F_->createAdd("add1", add0, input3);
1708 auto *add2 = F_->createAdd("add2", add1, input4);
1709 auto *sub1 = F_->createSub("sub1", add1, add2);
1710 F_->createSave("save", sub1);
1711
1712 std::vector<DeviceInfo> devices = {
1713 {3072, "Interpreter"}, {3072, "Interpreter"}, {3072, "Interpreter"}};
1714
1715 // User-defined partition: p0->p1, p1->p2, p2->p1.
1716 PartitionConfig partitionConfig;
1717 partitionConfig.funcName = "main";
1718 partitionConfig.numOfPartitions = 3;
1719 partitionConfig.backendNames = {"Interpreter", "Interpreter", "Interpreter"};
1720 partitionConfig.partitionNames = {"p0", "p1", "p2"};
1721 partitionConfig.nodeToPartition = {{"add0", 0}, {"add2", 2}};
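  // With this assignment the DAG has edges p0 -> p1, p1 -> p2, and p2 -> p1;
  // the p1/p2 cycle is expected to be rejected.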
1722 auto partitioner = Partitioner(&mod_, devices, false, partitionConfig);
1723 CompilationContext cctx;
1724 auto dagList = partitioner.partition(cctx);
1725 EXPECT_TRUE(ERR_TO_BOOL(dagList.takeError()));
1726}
1727
/// This tests partitioning from a user-defined config.
1729TEST_F(PartitionerTest, partitionFromConfig) {
1730#ifndef GLOW_WITH_CPU
1731 return;
1732#endif
1733 createSimpleModule(mod_);
1734 std::vector<DeviceInfo> devices = {
1735 {3072, "Interpreter"}, {3072, "Interpreter"}, {3072, "CPU"}};
1736
1737 // User-defined partition: 3 partitions (2 interpreter, 1 cpu), Mul nodes to
1738 // CPU, others to Interpreter.
1739 PartitionConfig partitionConfig;
1740 partitionConfig.funcName = "test";
1741 partitionConfig.numOfPartitions = 3;
1742 partitionConfig.backendNames = {"Interpreter", "CPU", "Interpreter"};
1743 partitionConfig.partitionNames = {"p1", "p2", "p3"};
1744 partitionConfig.nodeToPartition = {{"sub", 0}, {"mul", 1}};
1745 Partitioner partitioner(&mod_, devices, false, partitionConfig);
1746 CompilationContext cctx;
1747 auto dagList = partitioner.partition(cctx);
1748 ASSERT_TRUE((bool)dagList);
1749 EXPECT_EQ(mod_.getFunctions().size(), 3);
1750 EXPECT_EQ(dagList->size(), 1);
1751 ASSERT_TRUE(checkSaveNode(mod_));
1752 heterogeneousPartitionValidation(dagList.get(), mod_);
1753}
1754
1755/// Test user-defined partition with user specified logical devices through
1756/// compilationContext.
1757TEST_F(PartitionerTest, partitionFromConfigWithLogicalDevices) {
1758 auto *input1 =
1759 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1", false);
1760 auto *input2 =
1761 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input2", false);
1762 auto *input3 =
1763 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input3", false);
1764 auto *add1 = F_->createAdd("add1", input1, input2);
1765 auto *add2 = F_->createAdd("add2", add1, input3);
1766 auto *sub1 = F_->createSub("sub1", add1, add2);
1767 F_->createSave("save", sub1);
1768
1769 std::vector<DeviceInfo> devices = {
1770 {3072, "Interpreter"}, {3072, "Interpreter"}, {3072, "Interpreter"}};
1771
  // User-defined partition where p2 is assigned to both logical devices 0
  // and 1.
1773 PartitionConfig partitionConfig;
1774 partitionConfig.funcName = "main";
1775 partitionConfig.numOfPartitions = 3;
1776 partitionConfig.backendNames = {"Interpreter", "Interpreter", "Interpreter"};
1777 partitionConfig.partitionNames = {"p0", "p1", "p2"};
1778 partitionConfig.nodeToPartition = {{"add1", 0}, {"add2", 2}};
1779 partitionConfig.logicalIDs = {{0}, {1}, {0, 1}};
1780 auto partitioner = Partitioner(&mod_, devices);
1781 CompilationContext cctx;
1782 cctx.partitionConfig = &partitionConfig;
1783 auto result = partitioner.partition(cctx);
1784 DAGListTy nodeList;
1785 EXPECT_FALSE(ERR_TO_BOOL(result.takeError()));
1786 nodeList = std::move(result.get());
1787 // Check that p2 has both 0 and 1 for logicalDevices.
1788 EXPECT_EQ(nodeList[0].nodes[2]->logicalDevices[0], 0);
1789 EXPECT_EQ(nodeList[0].nodes[2]->logicalDevices[1], 1);
1790}
1791
1792/// Test user-defined partition with user specified logical devices through
1793/// compilationContext using fp16.
1794TEST_F(PartitionerTest, partitionFromConfigWithLogicalDevicesFp16) {
1795 auto *input1 =
1796 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1", false);
1797 auto *input2 =
1798 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input2", false);
1799 auto *input3 =
1800 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input3", false);
1801 auto *add1 = F_->createAdd("add1", input1, input2);
1802 auto *add2 = F_->createAdd("add2", add1, input3);
1803 auto *sub1 = F_->createSub("sub1", add1, add2);
1804 F_->createSave("save", sub1);
1805
1806 std::vector<DeviceInfo> devices = {
1807 {3072, "Interpreter"}, {3072, "Interpreter"}, {3072, "Interpreter"}};
1808
  // User-defined partition where p2 is assigned to both logical devices 0
  // and 1.
1810 PartitionConfig partitionConfig;
1811 partitionConfig.funcName = "main";
1812 partitionConfig.numOfPartitions = 3;
1813 partitionConfig.backendNames = {"Interpreter", "Interpreter", "Interpreter"};
1814 partitionConfig.partitionNames = {"p0", "p1", "p2"};
1815 partitionConfig.nodeToPartition = {{"add1", 0}, {"add2", 2}};
1816 partitionConfig.logicalIDs = {{0}, {1}, {0, 1}};
1817 auto partitioner = Partitioner(&mod_, devices);
1818 CompilationContext cctx;
1819 cctx.partitionConfig = &partitionConfig;
1820 PrecisionConfiguration pc;
1821 pc.convertToFP16 = true;
1822 cctx.precisionConfig = pc;
1823 auto result = partitioner.partition(cctx);
1824
  // Optimize each partitioned function.
1826 for (dim_t i = 0; i < partitionConfig.numOfPartitions; i++) {
1827 for (auto &func : mod_.getFunctions()) {
1828 std::unique_ptr<Backend> backend(createBackend("Interpreter"));
1829 auto err = ::glow::optimizeFunction(func, *backend, cctx);
1830 EXPECT_FALSE(err);
1831 }
1832 }
1833
1834 DAGListTy nodeList;
1835 EXPECT_FALSE(ERR_TO_BOOL(result.takeError()));
1836 nodeList = std::move(result.get());
1837 // Check that p2 has both 0 and 1 for logicalDevices.
1838 EXPECT_EQ(nodeList[0].nodes[2]->logicalDevices[0], 0);
1839 EXPECT_EQ(nodeList[0].nodes[2]->logicalDevices[1], 1);
  // Check that the inputs and outputs of add1, add2, and sub1 are in fp16.
1841 for (auto const &F : mod_.getFunctions()) {
1842 for (auto const &N : F->getNodes()) {
1843 auto NI = NodeInfo(N);
1844 if (NI.getKind() != Kinded::Kind::SaveNodeKind &&
1845 NI.getKind() != Kinded::Kind::ConvertToNodeKind) {
1846 EXPECT_TRUE(
1847 NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Float16Ty}));
1848 }
1849 }
1850 }
1851}
1852
/// This tests calling partitionFromConfig directly.
1854TEST_F(PartitionerTest, partitionFromConfigDirectCall) {
1855#ifndef GLOW_WITH_CPU
1856 return;
1857#endif
1858 createSimpleModule(mod_);
1859 std::vector<DeviceInfo> devices = {
1860 {3072, "Interpreter"}, {3072, "Interpreter"}, {3072, "CPU"}};
1861
1862 // User-defined partition: 3 partitions (2 interpreter, 1 cpu), Mul nodes to
1863 // CPU, others to Interpreter.
1864 PartitionConfig partitionConfig;
1865 partitionConfig.funcName = "test";
1866 partitionConfig.numOfPartitions = 3;
1867 partitionConfig.backendNames = {"Interpreter", "CPU", "Interpreter"};
1868 partitionConfig.partitionNames = {"p1", "p2", "p3"};
1869 partitionConfig.nodeToPartition = {{"sub", 0}, {"mul", 1}};
1870 Partitioner partitioner(&mod_, devices);
1871 CompilationContext cctx;
1872 auto dagList = partitioner.partitionFromConfig(partitionConfig, cctx);
1873 ASSERT_TRUE((bool)dagList);
1874 EXPECT_EQ(mod_.getFunctions().size(), 3);
1875 EXPECT_EQ(dagList->size(), 1);
1876 ASSERT_TRUE(checkSaveNode(mod_));
1877 heterogeneousPartitionValidation(dagList.get(), mod_);
1878}
1879
/// This tests the load-balanced partition flow.
1881TEST_F(PartitionerTest, loadBalancedPartition) {
1882 ExecutionEngine EER, EEP;
1883 EEP.setSkipModuleStrip(true);
1884 constexpr float range = 2.0;
1885 std::vector<ExecutionEngine *> engines{&EER, &EEP};
  // Since compiling modifies the module and partitioning modifies the
  // function, set up two EEs with identical functions for validation.
1888 for (auto EE : engines) {
1889 auto mod = &EE->getModule();
1890 F_ = mod->createFunction("main");
1891 auto *input =
1892 mod->createPlaceholder(ElemKind::FloatTy, {1, 32}, "input", false);
1893 auto *w1 = mod->createConstant(ElemKind::FloatTy, {32, 16}, "w1");
1894 auto *b1 = mod->createConstant(ElemKind::FloatTy, {16}, "b1");
1895 bindings_.allocate(input);
1896 w1->getHandle<>().randomize(-range, range, mod->getPRNG());
1897 b1->getHandle<>().randomize(-range, range, mod->getPRNG());
1898
1899 // Initial FC.
1900 Node *I = F_->createFullyConnected("initial_fc", input, w1, b1);
1901 I = F_->createSigmoid("initial_sigmoid", I);
1902
1903 // Left branch.
1904 auto *w2 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w2");
1905 auto *b2 = mod->createConstant(ElemKind::FloatTy, {16}, "b2");
1906 w2->getHandle<>().randomize(-range, range, mod->getPRNG());
1907 b2->getHandle<>().randomize(-range, range, mod->getPRNG());
1908 Node *L = F_->createFullyConnected("left_fc1", I, w2, b2);
1909 L = F_->createSigmoid("left_sigmoid1", L);
1910 auto *w3 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w3");
1911 auto *b3 = mod->createConstant(ElemKind::FloatTy, {8}, "b3");
1912 w3->getHandle<>().randomize(-range, range, mod->getPRNG());
1913 b3->getHandle<>().randomize(-range, range, mod->getPRNG());
1914 L = F_->createFullyConnected("left_fc2", L, w3, b3);
1915 L = F_->createSigmoid("left_sigmoid2", L);
1916
1917 // Right branch.
1918 auto *w4 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w4");
1919 auto *b4 = mod->createConstant(ElemKind::FloatTy, {16}, "b4");
1920 w4->getHandle<>().randomize(-range, range, mod->getPRNG());
1921 b4->getHandle<>().randomize(-range, range, mod->getPRNG());
1922 Node *R = F_->createFullyConnected("right_fc1", I, w4, b4);
1923 R = F_->createSigmoid("right_sigmoid1", R);
1924 auto *w5 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w5");
1925 auto *b5 = mod->createConstant(ElemKind::FloatTy, {8}, "b5");
1926 w5->getHandle<>().randomize(-range, range, mod->getPRNG());
1927 b5->getHandle<>().randomize(-range, range, mod->getPRNG());
1928 R = F_->createFullyConnected("right_fc2", R, w5, b5);
1929 R = F_->createSigmoid("right_sigmoid2", R);
1930
1931 // Join branches.
1932 auto *mul = F_->createMul("mul", L, R);
1933 F_->createSave("ret", mul);
1934 }
1935
1936 // Infer using the un-partitioned graph.
1937 Tensor in(ElemKind::FloatTy, {1, 32});
1938 in.getHandle<>().randomize(-range, range, EER.getModule().getPRNG());
1939
1940 EER.compile(CompilationMode::Infer);
1941 bindings_.clear();
1942 bindings_.allocate(EER.getModule().getPlaceholders());
1943 updateInputPlaceholders(bindings_,
1944 {bindings_.getPlaceholderByNameSlow("input")}, {&in});
1945 EER.run(bindings_);
1946 Tensor ref =
1947 bindings_.get(bindings_.getPlaceholderByNameSlow("ret"))->clone();
1948
1949 std::vector<DeviceInfo> devices = {
1950 {3072, "Interpreter"}, {3072, "Interpreter"}, {3072, "Interpreter"}};
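  // With three identical devices, the load balancer is expected to split
  // "main" into three partitions, one per logical device.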
1951 Partitioner myPartitioner(&EEP.getModule(), devices, true);
1952 CompilationContext cctx;
1953 auto dagList = myPartitioner.loadBalancedPartition(cctx);
1954 ASSERT_TRUE((bool)dagList);
1955 EXPECT_EQ(EEP.getModule().getFunctions().size(), 3);
1956 EXPECT_EQ(dagList->size(), 1);
1957 EXPECT_TRUE(checkSaveNode(EEP.getModule()));
1958
  // Run the partitioned graph and compare the results.
1960 bindings_.clear();
1961 bindings_.allocate(EEP.getModule().getPlaceholders());
1962 EEP.compile(cctx);
1963 executeDAG(dagList->begin()->root.get(), EEP.getModule(), bindings_,
1964 {bindings_.getPlaceholderByNameSlow("input")}, {&in}, &EEP);
1965 Tensor test =
1966 bindings_.get(bindings_.getPlaceholderByNameSlow("ret"))->clone();
1967 EXPECT_TRUE(ref.isEqual(test, 0.0f));
1968 verifyDAGSerialization(*dagList, EEP.getModule(), bindings_, {"input"}, "ret",
1969 devices, {&in}, ref);
1970}
1971
1972/// This tests the pre-partitioned flow.
1973TEST_F(PartitionerTest, PrePartitionedTest) {
1974 CompilationContext cctx;
1975 PrePartitionedConfig PPC;
1976 cctx.prepartitionedConfig = &PPC;
1977 Function *F0 = F_;
1978 Function *F1 = mod_.createFunction("main_1");
1979 Function *F2 = mod_.createFunction("main_2");
1980 PPC.funcs.push_back(F0);
1981 PPC.funcs.push_back(F1);
1982 PPC.funcs.push_back(F2);
1983 PPC.logicalIDs.resize(3);
1984 PPC.logicalIDs[0].push_back(0);
1985 PPC.logicalIDs[1].push_back(1);
1986 PPC.logicalIDs[2].push_back(1);
1987 PPC.logicalIDs[2].push_back(2);
1988 PPC.backendSpecificOpts.emplace_back(
1989 BackendSpecificOptions{{"opt0", "val0"}, {"opt1", "val1"}});
1990 PPC.backendSpecificOpts.emplace_back(
1991 BackendSpecificOptions{{"opt2", "val2"}});
1992 PPC.backendSpecificOpts.emplace_back(BackendSpecificOptions{});
1993 PPC.replicationCounts.push_back(3);
1994 PPC.replicationCounts.push_back(4);
1995 PPC.replicationCounts.push_back(1);
1996 PPC.backendHints.push_back({7, {"a"}});
1997 PPC.backendHints.push_back({8, {"b"}});
1998 PPC.backendHints.push_back({9, {"c", "d"}});
1999
2000 auto *I0 = mod_.createPlaceholder(ElemKind::FloatTy, {5, 5}, "I0", false);
2001 auto *I1 = mod_.createPlaceholder(ElemKind::FloatTy, {5, 5}, "I1", false);
  auto *I2 = mod_.createPlaceholder(ElemKind::FloatTy, {5, 5}, "I2", false);
2003
2004 // Partition 0 is a MatMul and Save.
2005 MatMulNode *MM = F0->createMatMul("MM", I0, I1);
2006 SaveNode *SMM = F0->createSave("SMM", MM);
2007
2008 // Partition 1 loads from the Partition 0 MatMul.
2009 AddNode *AN = F1->createAdd("AN", SMM->getPlaceholder(), I2);
2010 SaveNode *SAN = F1->createSave("SAN", AN);
2011
2012 // Partition 2 loads from both Partition 0 and 1.
2013 MulNode *MN = F2->createMul("MN", SMM->getPlaceholder(), I0);
2014 SubNode *SN = F2->createSub("SN", SAN->getPlaceholder(), MN);
2015 SaveNode *finalSave = F2->createSave("finalSave", SN);
2016
2017 const runtime::DeviceInfo dev{/* 16GB: */ 0x400000000, "Interpreter"};
2018 const std::vector<runtime::DeviceInfo> devices(3, dev);
2019 Partitioner partitioner(&mod_, devices);
2020 DAGListTy d;
2021 ASSIGN_VALUE_OR_FAIL_TEST(d, partitioner.setupPrepartitionedModule(cctx));
2022
2023 // Note: DAG should look like: F0 -> F1
2024 // \ |
2025 // v v
2026 // F2
2027
2028 ASSERT_EQ(d.size(), 1);
2029
2030 DAGNodePtr &root = d[0].root;
2031 EXPECT_EQ(root->module, &mod_);
2032
2033 ASSERT_EQ(root->children.size(), 1);
2034 DAGNode *D0 = root->children[0];
2035 ASSERT_EQ(D0->name, F0->getName());
2036 ASSERT_EQ(D0->parents.size(), 1);
2037 EXPECT_EQ(D0->parents[0], root.get());
2038 ASSERT_EQ(D0->logicalDevices.size(), 1);
2039 EXPECT_EQ(D0->logicalDevices[0], 0);
2040 EXPECT_EQ(D0->size, I0->getType()->getSizeInBytes() +
2041 I1->getType()->getSizeInBytes() +
2042 SMM->getPlaceholder()->getType()->getSizeInBytes());
2043 EXPECT_EQ(D0->backendSpecificOpts.size(), 2);
2044 ASSERT_TRUE(D0->backendSpecificOpts.count("opt0"));
2045 EXPECT_EQ(D0->backendSpecificOpts.at("opt0"), "val0");
2046 ASSERT_TRUE(D0->backendSpecificOpts.count("opt1"));
2047 EXPECT_EQ(D0->backendSpecificOpts.at("opt1"), "val1");
2048 EXPECT_EQ(D0->replicationCount, 3);
2049 EXPECT_EQ(D0->backendHints.executionUnits, 7);
2050 ASSERT_EQ(D0->backendHints.SRAMPrioritization.size(), 1);
2051 EXPECT_EQ(D0->backendHints.SRAMPrioritization[0], "a");
2052
2053 ASSERT_EQ(D0->children.size(), 2);
2054 DAGNode *D1 = (D0->children[0]->name == F1->getName()) ? D0->children[0]
2055 : D0->children[1];
2056 ASSERT_EQ(D1->parents.size(), 1);
2057 EXPECT_EQ(D1->parents[0], D0);
2058 ASSERT_EQ(D1->name, F1->getName());
2059 ASSERT_EQ(D1->logicalDevices.size(), 1);
2060 EXPECT_EQ(D1->logicalDevices[0], 1);
2061 EXPECT_EQ(D1->size, I2->getType()->getSizeInBytes() +
2062 SAN->getPlaceholder()->getType()->getSizeInBytes() +
2063 SMM->getPlaceholder()->getType()->getSizeInBytes());
2064 EXPECT_EQ(D1->backendSpecificOpts.size(), 1);
2065 ASSERT_TRUE(D1->backendSpecificOpts.count("opt2"));
2066 EXPECT_EQ(D1->backendSpecificOpts.at("opt2"), "val2");
2067 EXPECT_EQ(D1->replicationCount, 4);
2068 EXPECT_EQ(D1->backendHints.executionUnits, 8);
2069 ASSERT_EQ(D1->backendHints.SRAMPrioritization.size(), 1);
2070 EXPECT_EQ(D1->backendHints.SRAMPrioritization[0], "b");
2071
2072 DAGNode *D2 = (D1 == D0->children[0]) ? D0->children[1] : D0->children[0];
2073 ASSERT_EQ(D2->name, F2->getName());
2074 ASSERT_EQ(D1->children.size(), 1);
2075 EXPECT_EQ(D1->children[0], D2);
2076 ASSERT_EQ(D2->parents.size(), 2);
2077 EXPECT_TRUE(D2->parents[0] == D0 || D2->parents[1] == D0);
2078 EXPECT_TRUE(D2->parents[0] == D1 || D2->parents[1] == D1);
2079 EXPECT_NE(D2->parents[0], D2->parents[1]);
2080 ASSERT_EQ(D2->logicalDevices.size(), 2);
2081 EXPECT_TRUE(D2->logicalDevices[0] == 1 || D2->logicalDevices[0] == 2);
2082 EXPECT_TRUE(D2->logicalDevices[1] == 1 || D2->logicalDevices[1] == 2);
2083 EXPECT_NE(D2->logicalDevices[0], D2->logicalDevices[1]);
2084 EXPECT_EQ(D2->size,
2085 I0->getType()->getSizeInBytes() +
2086 SAN->getPlaceholder()->getType()->getSizeInBytes() +
2087 SMM->getPlaceholder()->getType()->getSizeInBytes() +
2088 finalSave->getPlaceholder()->getType()->getSizeInBytes());
2089 EXPECT_EQ(D2->backendSpecificOpts.size(), 0);
2090 EXPECT_EQ(D2->replicationCount, 1);
2091 EXPECT_EQ(D2->backendHints.executionUnits, 9);
2092 ASSERT_EQ(D2->backendHints.SRAMPrioritization.size(), 2);
2093 EXPECT_EQ(D2->backendHints.SRAMPrioritization[0], "c");
2094 EXPECT_EQ(D2->backendHints.SRAMPrioritization[1], "d");
2095}
2096
2097/// Test that constant folding (de)serialization works along with partitioning.
2098TEST_F(PartitionerTest, RecordedConstantFolding) {
2099 ExecutionEngine EER, EEP;
2100 EEP.setSkipModuleStrip(true);
2101 constexpr float range = 2.0;
2102 std::vector<ExecutionEngine *> engines{&EER, &EEP};
  // Since compiling modifies the module and partitioning modifies the
  // function, set up two EEs with identical functions for validation.
2105 for (auto EE : engines) {
2106 auto mod = &EE->getModule();
2107 F_ = mod->createFunction("main");
2108 auto *input =
2109 mod->createPlaceholder(ElemKind::FloatTy, {1, 32}, "input", false);
2110 auto *w1 = mod->createConstant(ElemKind::FloatTy, {32, 16}, "w1");
2111 auto *b1 = mod->createConstant(ElemKind::FloatTy, {16}, "b1");
2112 bindings_.allocate(input);
2113 w1->getHandle<>().randomize(-range, range, mod->getPRNG());
2114 b1->getHandle<>().randomize(-range, range, mod->getPRNG());
2115
2116 // Initial FC.
2117 Node *I = F_->createFullyConnected("initial_fc", input, w1, b1);
2118 I = F_->createSigmoid("initial_sigmoid", I);
2119
2120 // Left branch. Note that w2 and b2 will be constant folded.
2121 auto *w2 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w2");
2122 auto *b2 = mod->createConstant(ElemKind::FloatTy, {16}, "b2");
2123 w2->getHandle<>().randomize(-range, range, mod->getPRNG());
2124 b2->getHandle<>().randomize(-range, range, mod->getPRNG());
2125 auto *w2Clip = F_->createClip("clip_w2", w2, -1, 1);
2126 auto *b2Clip = F_->createClip("clip_b2", b2, -1, 1);
2127 Node *L = F_->createFullyConnected("left_fc1", I, w2Clip, b2Clip);
2128 L = F_->createSigmoid("left_sigmoid1", L);
2129 auto *w3 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w3");
2130 auto *b3 = mod->createConstant(ElemKind::FloatTy, {8}, "b3");
2131 w3->getHandle<>().randomize(-range, range, mod->getPRNG());
2132 b3->getHandle<>().randomize(-range, range, mod->getPRNG());
2133 L = F_->createFullyConnected("left_fc2", L, w3, b3);
2134 L = F_->createSigmoid("left_sigmoid2", L);
2135
2136 // Right branch. Note that w4 will be constant folded.
2137 auto *w4 = mod->createConstant(ElemKind::FloatTy, {16, 16}, "w4");
2138 auto *b4 = mod->createConstant(ElemKind::FloatTy, {16}, "b4");
2139 w4->getHandle<>().randomize(-range, range, mod->getPRNG());
2140 b4->getHandle<>().randomize(-range, range, mod->getPRNG());
2141 auto *w4Sig = F_->createSigmoid("w4_sig", w4);
2142 Node *R = F_->createFullyConnected("right_fc1", I, w4Sig, b4);
2143 R = F_->createSigmoid("right_sigmoid1", R);
2144 auto *w5 = mod->createConstant(ElemKind::FloatTy, {16, 8}, "w5");
2145 auto *b5 = mod->createConstant(ElemKind::FloatTy, {8}, "b5");
2146 w5->getHandle<>().randomize(-range, range, mod->getPRNG());
2147 b5->getHandle<>().randomize(-range, range, mod->getPRNG());
2148 R = F_->createFullyConnected("right_fc2", R, w5, b5);
2149 R = F_->createSigmoid("right_sigmoid2", R);
2150
2151 // Join branches.
2152 auto *mul = F_->createMul("mul", L, R);
2153 F_->createSave("ret", mul);
2154 }
2155
2156 // Infer using the un-partitioned graph.
2157 Tensor in(ElemKind::FloatTy, {1, 32});
2158 in.getHandle<>().randomize(-range, range, EER.getModule().getPRNG());
2159
2160 EER.compile(CompilationMode::Infer);
2161 bindings_.clear();
2162 bindings_.allocate(EER.getModule().getPlaceholders());
2163 updateInputPlaceholders(bindings_,
2164 {bindings_.getPlaceholderByNameSlow("input")}, {&in});
2165 EER.run(bindings_);
2166 Tensor ref =
2167 bindings_.get(bindings_.getPlaceholderByNameSlow("ret"))->clone();
2168
2169 // Now try with partitioning, and partitioning + constant fold recording.
2170 auto &modP = EEP.getModule();
2171
2172 CompilationContext cctx;
2173 ASSERT_EQ(modP.getFunctions().size(), 1);
2174 Function *origF = *modP.getFunctions().begin();
2175 ConstantFoldingRecordMap record = constantFoldAndRecord(origF, cctx);
2176 runDCEPass(origF, cctx);
2177 // Expect 3 Constants were folded: w2, b2, and w4 from above.
2178 ASSERT_EQ(record.size(), 3);
2179
2180 const DeviceInfo devI{3072, "Interpreter"};
2181 std::vector<DeviceInfo> devices = {devI, devI, devI};
2182 Partitioner myPartitioner(&modP, devices, /* optimized */ true);
2183 EXPECT_TRUE(checkSaveNode(modP));
2184
2185 DAGListTy dagList;
2186 ASSIGN_VALUE_OR_FAIL_TEST(dagList, myPartitioner.loadBalancedPartition(cctx));
2187
2188 ASSERT_EQ(dagList.size(), 1);
2189 const auto &dag = *dagList.begin();
2190 EXPECT_EQ(dag.nodes.size(), 3);
2191
  // Verify that we serialize and deserialize the DAG correctly, including
  // the constant folding record, and that the results are bitwise equal.
2194 verifyDAGSerialization(dagList, modP, bindings_, {"input"}, "ret", devices,
2195 {&in}, ref, &record);
2196
  // Now run the original partitioned model and verify that its results are
  // also bitwise equal.
2198 bindings_.clear();
2199 bindings_.allocate(modP.getPlaceholders());
2200 EEP.compile(cctx);
2201
2202 executeDAG(dagList.begin()->root.get(), modP, bindings_,
2203 {bindings_.getPlaceholderByNameSlow("input")}, {&in}, &EEP);
2204 Tensor test =
2205 bindings_.get(bindings_.getPlaceholderByNameSlow("ret"))->clone();
2206 EXPECT_TRUE(ref.isEqual(test, 0.0f));
2207}
2208
/// This test verifies that resource counts are validated correctly: we set
/// inputCountMax to 1 and expect an error.
2211TEST_F(PartitionerTest, resourceCountValidationTest) {
2212 auto *input1 =
2213 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input1", false);
2214 auto *input2 =
2215 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input2", false);
2216 auto *input3 =
2217 mod_.createPlaceholder(ElemKind::FloatTy, {2, 10}, "input3", false);
2218 auto *add1 = F_->createAdd("add1", input1, input2);
2219 auto *add2 = F_->createAdd("add2", add1, input3);
2220 auto *sub1 = F_->createSub("sub1", add1, add2);
2221 F_->createSave("save", sub1);
2222
2223 std::vector<DeviceInfo> devices = {{3072, "Interpreter", {}},
2224 {3072, "Interpreter", {}}};
2225
2226 devices[0].inputCountMax = 1;
2227 devices[1].inputCountMax = 1;
2228 // User-defined partition: p1->p2, p2->p1.
2229 PartitionConfig partitionConfig;
2230 partitionConfig.funcName = "main";
2231 partitionConfig.numOfPartitions = 2;
2232 BackendHints bh1, bh2;
2233 bh1.executionUnits = 2;
2234 bh2.executionUnits = 3;
2235 partitionConfig.backendHints = {bh1, bh2};
2236 partitionConfig.backendNames = {"Interpreter", "Interpreter"};
2237 partitionConfig.partitionNames = {"p1", "p2"};
2238 partitionConfig.nodeToPartition = {{"add2", 0}};
2239 auto partitioner = Partitioner(&mod_, devices, false, partitionConfig);
2240 CompilationContext cctx;
2241 auto dagList = partitioner.partition(cctx);
2242 EXPECT_TRUE(ERR_TO_BOOL(dagList.takeError()));
2243}
2244
/// Tests that the given net is duplicated and assigned to the requested
/// number of logical devices.
2247TEST_F(PartitionerTest, saturateKDevicesTest) {
2248 createSimpleModule(mod_);
2249 std::vector<DeviceInfo> devices = {{2048, "Interpreter", {}},
2250 {2048, "Interpreter", {}},
2251 {2048, "Interpreter", {}}};
2252 auto partitioner = Partitioner(&mod_, devices, false);
  // The Partitioner should create the DAG without partitioning, duplicate
  // it, and assign it to the given logical devices.
2255 DAGListTy dagList;
2256 CompilationContext cctx;
2257 cctx.saturateHost = true;
2258 cctx.saturateKDevices = 2;
2259 ASSIGN_VALUE_OR_FAIL_TEST(dagList, partitioner.partition(cctx));
2260 EXPECT_EQ(dagList.size(), 1);
2261
2262 int numOfInterpreterBackends = 0;
2263 for (auto &dag : dagList) {
2264 for (auto &node : dag.nodes) {
2265 // Verify the node is assigned to K devices.
2266 EXPECT_EQ(node->logicalDevices.size(), cctx.saturateKDevices);
2267
2268 if (node->backendName == "Interpreter") {
2269 numOfInterpreterBackends++;
2270 }
2271 }
2272 }
2273 EXPECT_EQ(numOfInterpreterBackends, 1);
2274}
2275