1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Partitioner/Partitioner.h"
18
19#include "folly/String.h"
20#include "glow/Flags/Flags.h"
21#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
22#include "glow/Partitioner/PartitionerOptimizer.h"
23#include "glow/Partitioner/PartitionerUtils.h"
24#include "glow/Partitioner/PartitionerValidation.h"
25#include "glow/Support/Support.h"
26
27#include "llvm/ADT/SmallSet.h"
28#include "llvm/Support/CommandLine.h"
29#include "llvm/Support/raw_ostream.h"
30#include <unordered_map>
31
32#include <fstream>
33
34namespace glow {
35static llvm::cl::opt<bool, /* ExternalStorage */ true>
36 GlowEnableLoadBalancedPartitioningOpt(
37 "partitioner_enable_load_balance",
38 llvm::cl::desc(
39 "Enable a partitioner pass to optimize for "
40 "load balance in addition to memory capacity constraints"),
41 llvm::cl::location(glow::flags::EnableLoadBalancedPartitioning));
42} // namespace glow
43
44/// -log-partition - Command line option to dump Partitioner logs.
45static llvm::cl::OptionCategory PartitionerCat("Glow Partitioner Options");
46static llvm::cl::opt<bool, /* ExternalStorage */ true>
47 logPartition("log-partition",
48 llvm::cl::desc("Enable logging partition info"),
49 llvm::cl::location(glow::flags::LogPartition),
50 llvm::cl::cat(PartitionerCat));
51
/// -dump-partition - Command line option to dump the graph of each partition
/// by calling F->dumpDAG().
54static llvm::cl::opt<bool, /* ExternalStorage */ true>
55 dumpPartition("dump-partition",
                  llvm::cl::desc("Enable dumping the graph of each partition"),
57 llvm::cl::location(glow::flags::DumpPartition),
58 llvm::cl::cat(PartitionerCat));
59
60using namespace glow;
61using llvm::isa;
62
// Sort the std::pair<Function *, uint64_t> by the second element, from min to
// max.
65bool sortMinMemory(const std::pair<Function *, uint64_t> &a,
66 const std::pair<Function *, uint64_t> &b) {
67 return a.second < b.second;
68}
69
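/// Initialize partitioner state: record the total size of the module's
/// constants, reset the logical device ID counter, and detect whether the
/// given devices span more than one backend kind.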
70void Partitioner::init() {
71 memSize_ = module_->getConstantsSize();
72 logicalDeviceID_ = 0;
73 multiBackendNames_ = false;
74 for (size_t i = 1, e = deviceInfo_.size(); i < e; i++) {
75 if (deviceInfo_[i].backendName != deviceInfo_[0].backendName) {
76 multiBackendNames_ = true;
77 break;
78 }
79 }
80}
81
82Error Partitioner::finalize(const DAGListTy &partitions,
83 const NodeToFunctionMap &mapping) {
84
  // NOTE: We cannot validate the functions after partitioning here. The
  // validation needs the backend-specific verifier. Tensor layouts, for
  // example, might have gone from canonical form to backend-specific form.
88
89 if (logPartition) {
90 LOG(INFO) << "The number of partitions is : "
91 << mapping.getPartitions().size();
92 logPartitionInfo(mapping);
93 }
94
95 // Dump the graph of each function after partitioning.
96 if (dumpPartition) {
97 LOG(INFO) << "Dumping partitioning DAG to DAG.dot file.";
98 dumpDAG("DAG.dot", partitions);
99 for (const auto &node : partitions[0].nodes) {
100 Function *subF = module_->getFunction(node->name);
101 if (!subF) {
        // If we fail, dump the partition info for debugging.
103 logPartitionInfo(mapping);
104 return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR,
105 "Invalid function name " + node->name);
106 }
107 subF->dumpDAG("partitionLogicalID" +
108 std::to_string(node->logicalDevices[0]) + "__" +
109 subF->getName().str() + "__" + node->backendName + ".dot");
110 }
111 }
112 return Error::success();
113}
114
115Partitioner::Partitioner(Module *parent, const std::vector<DeviceInfo> &devices,
116 const std::vector<Backend *> &backends, bool optimized)
117 : module_(parent), deviceInfo_(devices), backends_(backends),
118 optimized_(optimized) {
119 init();
120}
121
122Partitioner::Partitioner(Module *parent, const std::vector<DeviceInfo> &devices,
123 bool optimized, PartitionConfig partitionConfig)
124 : module_(parent), deviceInfo_(devices), optimized_(optimized),
125 partitionConfig_(partitionConfig) {
126 init();
127}
128
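/// Select the function with the largest memory requirement (the constants
/// size passed in via \p memSize plus the sizes of its unique input
/// placeholders) as the representative function, and update \p memSize to
/// that maximum.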
129Function *Partitioner::selectRepFunc(Module *parent, uint64_t &memSize) {
130 auto funcList = parent->getFunctions();
131 Function *ret = nullptr;
132 uint64_t maxMemSize = 0;
133 for (Function *F : funcList) {
134 uint64_t curSize = memSize;
135
    // The set of placeholders (inputs only) whose sizes have already been
    // counted.
138 std::set<llvm::StringRef> pSet;
139
140 for (auto &node : F->getNodes()) {
141 int n = node.getNumInputs();
142 if (node.getKind() == Kinded::Kind::SaveNodeKind) {
        // SaveNode is special; skip it so its output placeholder is not
        // counted.
144 continue;
145 }
146 for (int i = 0; i < n; i++) {
147 Placeholder *in =
148 llvm::dyn_cast<Placeholder>(node.getNthInput(i).getNode());
149 if (in && pSet.find(in->getName()) == pSet.end()) {
150 auto ty = in->getType();
151 curSize += ty->getSizeInBytes();
152 pSet.insert(in->getName());
153 }
154 }
155 }
    // Find the function with the largest required memory as the
    // representative function.
158 if (!ret || curSize > maxMemSize) {
159 ret = F;
160 maxMemSize = curSize;
161 }
162 }
163 memSize = maxMemSize;
164 return ret;
165}
166
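/// Adjust the node-to-partition assignment in \p partitions: reduce the
/// communication cost between partitions and combine partitions where the
/// \p availableMemory limit allows.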
167void Partitioner::partitionsAdjust(NodeToFunctionMap &partitions,
168 uint64_t availableMemory) {
169 // For each partition, create a node set.
170 FunctionToNodesMap nodesSet;
171 for (auto it = partitions.begin(); it != partitions.end(); ++it) {
172 nodesSet[(*it).second].insert((*it).first);
173 }
174
175 // Optimize the communication cost.
176 optimizeCommunicationCost(partitions, nodesSet, module_, availableMemory);
177
178 // Combine the current partitions if necessary.
179 partitionsCombine(partitions, nodesSet, module_, availableMemory);
180}
181
182/// Assign nodes to partitions and return the mapping.
183NodeToFunctionMap Partitioner::selectPartitions(Function *F,
184 uint64_t availableMemory,
185 llvm::StringRef backendName) {
186 NodeToFunctionMap mapping;
187 BFSLevel bfs = getBFSLevel(F);
188 size_t level = bfs.size();
189
190 // Step 1 : get the initial cut based on BFS levels and availableMemory.
191 int color = 0;
192 Function *newF;
193 newF = F->getParent()->createFunction(std::string(F->getName()) + "_part" +
194 std::to_string(++color));
195 mapping.createPartition(newF, backendName);
196 NodesSet currentPartition;
197 GraphMemInfo graphMem;
198 graphMem.contextCount = contextCount_;
199
200 for (int i = level - 1; i >= 0; i--) {
201 for (size_t j = 0, e = bfs[i].size(); j < e; j++) {
202 Node *N = bfs[i][j];
203 graphMem = updateGraphMemInfoByAddingNode(currentPartition, graphMem, N);
204 // If after adding node N, the memory usage of this partition exceeds the
205 // device memory limitations, N can't be added into the current partition
206 // and a new partition is created.
207 if (graphMem.getTotalMemSize() > availableMemory) {
208 newF = F->getParent()->createFunction(
209 std::string(F->getName()) + "_part" + std::to_string(++color));
210 mapping.createPartition(newF, backendName);
211 currentPartition.clear();
212 graphMem =
213 updateGraphMemInfoByAddingNode(currentPartition, GraphMemInfo{}, N);
214 }
215 currentPartition.insert(N);
216 mapping.add(N, newF);
217 graphMem.contextCount = contextCount_;
218 mapping.setGraphMemInfo(newF, graphMem);
219 }
220 }
221
222 // Step 2 : adjust the partition based on performance.
223 partitionsAdjust(mapping, availableMemory);
224
225 return mapping;
226}
227
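/// Replicate the partitions' logical device assignments so that all (or, if
/// specified, \p availableLogicalDevices) devices are used. Each DAG node
/// gets additional logical device IDs and its instanceCount set to the number
/// of duplications.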
228void Partitioner::saturateHost(unsigned logicalDeviceCount,
229 const DAGListTy &partitions,
230 size_t availableLogicalDevices) {
231 DCHECK(availableLogicalDevices <= deviceInfo_.size())
      << "Requested number of logical devices must be less than or equal to "
         "the number of available devices.";
  // If not specified, use the number of available physical devices.
235 if (availableLogicalDevices == 0 ||
236 availableLogicalDevices > deviceInfo_.size()) {
237 availableLogicalDevices = deviceInfo_.size();
238 }
239 unsigned duplications = availableLogicalDevices / logicalDeviceCount;
240 if (duplications < 2) {
241 return;
242 }
243 // Add additional logical devices to each node.
244 for (auto &network : partitions) {
245 for (auto &node : network.nodes) {
246 // Set instanceCount.
247 node->instanceCount = duplications;
248 // Build list of new logical devices to add to node.
249 std::vector<unsigned> newDevices;
250 for (auto logical : node->logicalDevices) {
251 // To ensure we do not have a logicalID collision we use the following
252 // scheme. We have an iterator starting at 1 for each duplication pass.
253 // The new ID we add is calculated as follows:
254 // (iterator * logicalDeviceCount) + initialLogicalID
255 for (unsigned i = 1; i < duplications; i++) {
256 newDevices.push_back(logical + (i * logicalDeviceCount));
257 }
258 }
259 // Append the new logical devices to the node's logical device vector.
260 node->logicalDevices.insert(node->logicalDevices.end(),
261 newDevices.begin(), newDevices.end());
262 }
263 }
264}
265
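/// Assign every node in \p F to the first backend in \p backends that
/// supports it and split \p F along the resulting backend boundaries in BFS
/// order. When profiling, all partitions are assigned to the profiling
/// backend and a DAG is generated, since the flow stops after this step.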
266Expected<DAGListTy> Partitioner::backendBasedPartition(
267 FunctionToBackendNameMap &funcToBackend, Function *F,
268 std::vector<Backend *> &backends, CompilationContext &cctx) {
269 NodeToFunctionMap mapping;
270 llvm::DenseMap<Node *, std::string> nodeToBackendName;
271
272 // For each node find a backend that supports it.
273 for (auto &N : F->getNodes()) {
274 for (auto &backend : backends) {
275 // Find the first backend that supports this node. The order of backends
276 // is important. The check flow is :
277
      // Step 1: If the node's kind is in this backend's pre-defined
      // non-supported kinds set, it cannot be assigned to this backend.
      // Continue.
280 const auto &nonSupportedNodesKinds =
281 backendMap_[backend->getBackendName()].nonSupportedNodesKinds;
282 if (nonSupportedNodesKinds.count(N.getKind())) {
283 // This op is on the pre-defined non-supported op list:
284 continue;
285 }
      // Step 2: If the pre-defined supported kinds set is empty, all node
      // kinds may be assigned to this backend. If it is not empty, check
      // whether this node's kind is in the set; if not, continue.
290 const auto &supportedNodesKinds =
291 backendMap_[backend->getBackendName()].supportedNodesKinds;
292 if (!supportedNodesKinds.empty() &&
293 !supportedNodesKinds.count(N.getKind())) {
        // This op is not on the pre-defined supported op list:
295 continue;
296 }
      // Step 3: Check whether the node is actually supported by this backend.
      // If so, assign it to this backend and break; otherwise continue.
      // TODO: the logic here needs to be improved.
300 if (backend->shouldLower(&N) || backend->isOpSupported(N)) {
301 // Put this node into a partition for this backend.
302 nodeToBackendName[&N] = backend->getBackendName();
303 break;
304 }
305 }
306 if (nodeToBackendName.find(&N) == nodeToBackendName.end()) {
307 logPartitionInfo(mapping);
308 return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR,
309 "Node is not supported by any of the provided backends");
310 }
311 }
312
313 BFSLevel bfs = getBFSLevel(F);
314 size_t level = bfs.size();
315 int color = 0;
316 Function *newF;
317 newF = F->getParent()->createFunction(std::string(F->getName()) + "_part" +
318 std::to_string(++color));
319 auto backendName = nodeToBackendName[bfs[level - 1][0]];
320 if (cctx.precisionConfig.quantMode == QuantizationMode::Profile) {
    // When profiling, every partition is assigned to the profiling backend.
323 mapping.createPartition(newF, profilingBackend);
324 funcToBackend[newF] = profilingBackend;
325 } else {
326 mapping.createPartition(newF, backendName);
327 funcToBackend[newF] = backendName;
328 }
329 for (int i = level - 1; i >= 0; i--) {
330 for (size_t j = 0, e = bfs[i].size(); j < e; j++) {
331 Node *N = bfs[i][j];
332 auto bk = nodeToBackendName[N];
333 if (bk != backendName) {
334 backendName = bk;
335 newF = F->getParent()->createFunction(
336 std::string(F->getName()) + "_part" + std::to_string(++color));
337 if (cctx.precisionConfig.quantMode == QuantizationMode::Profile) {
          // When profiling, every partition is assigned to the profiling
          // backend.
340 mapping.createPartition(newF, profilingBackend);
341 funcToBackend[newF] = profilingBackend;
342 } else {
343 mapping.createPartition(newF, backendName);
344 funcToBackend[newF] = backendName;
345 }
346 }
347 mapping.add(N, newF);
348 }
349 }
350
351 std::vector<Function *> funcs;
352 funcs.push_back(F);
  // When profiling, the partitioning flow stops after backendBasedPartition,
  // so the DAG needs to be generated here. Otherwise there is no need to
  // generate the DAG.
  bool genDAG =
      cctx.precisionConfig.quantMode == QuantizationMode::Profile;
359 if (genDAG) {
360 DeviceIDTy logicalDeviceID = 0;
361 for (auto &func : mapping.getPartitions()) {
362 mapping.appendLogicalDeviceID(func, logicalDeviceID++);
363 }
364 }
365 return doPartitioning(F->getName(), funcs, module_, mapping, genDAG,
366 cctx.backendOpts.backendSpecificNodeInfo);
367}
368
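/// Build \p backendMap (backend name -> BackendInfo) from deviceInfo_.
/// Pre-created backends are reused when available; otherwise new backends are
/// created and owned by \p backendsHolder. \p backends collects one backend
/// pointer per backend name.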
369void Partitioner::genBackendMap(
370 std::map<std::string, BackendInfo> &backendMap,
371 std::vector<std::unique_ptr<Backend>> &backendsHolder,
372 std::vector<Backend *> &backends) {
373 // If the backends are created already, we use them directly.
374 bool hasBackends = backends_.size() != 0;
375 if (hasBackends) {
    DCHECK(backends_.size() == deviceInfo_.size())
        << "The number of backends does not match the number of devices.";
378 }
379
380 int n = 0;
381 for (size_t i = 0, e = deviceInfo_.size(); i < e; i++) {
382 std::string backendName = deviceInfo_[i].backendName;
383 if (hasBackends) {
384 DCHECK(backends_[i]->getBackendName() == backendName)
385 << "Backend Type mismatch.";
386 }
387 if (backendMap.find(backendName) == backendMap.end()) {
388 BackendInfo backendInfo;
389 backendInfo.num = 1;
      // We assume that devices of the same type have the same available
      // memory size.
      // TODO: improve the algorithm to handle different memory sizes.
393 backendInfo.memSize = deviceInfo_[i].availableMemory;
394 backendInfo.inputCountMax = deviceInfo_[i].inputCountMax;
395 backendInfo.peakDramBw = deviceInfo_[i].peakDramBw;
396 backendInfo.peakSramBw = deviceInfo_[i].peakSramBw;
397 backendInfo.sramCapacity = deviceInfo_[i].sramCapacity;
398 backendInfo.peakCompute = deviceInfo_[i].peakCompute;
399 backendInfo.nonSupportedNodesKinds =
400 generateNodeKindsSet(deviceInfo_[i].nonSupportedNodes);
401 backendInfo.supportedNodesKinds =
402 generateNodeKindsSet(deviceInfo_[i].supportedNodes);
403 if (hasBackends) {
404 backendInfo.backend = backends_[i];
405 } else {
406 backendsHolder.emplace_back(createBackend(backendName));
407 backendInfo.backend = backendsHolder[n++].get();
408 }
409 backendMap[backendName] = backendInfo;
410 backends.push_back(backendMap[backendName].backend);
411 } else {
412 backendMap[backendName].num += 1;
      // Since we currently assume a single value per backend, keep the
      // maximum.
414 backendMap[backendName].memSize = std::max(
415 backendMap[backendName].memSize, deviceInfo_[i].availableMemory);
416 }
417 }
418}
419
420const DeviceInfo &
421Partitioner::getDeviceInfoForBackend(llvm::StringRef backendName) {
422 for (DeviceInfo &devInfo : deviceInfo_) {
423 if (devInfo.backendName == backendName)
424 return devInfo;
425 }
426 llvm_unreachable("Each backend should have at least one device");
427}
428
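/// Create a trivial DAG for the case where no partitioning is needed: each
/// function in the module gets a dummy root node plus a single child node
/// that runs the whole function on \p backendName. Functions are optimized
/// first unless they are already optimized.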
429Expected<DAGListTy> Partitioner::createDAGWithoutPartition(
430 llvm::StringRef backendName, std::map<std::string, BackendInfo> &backendMap,
431 CompilationContext &cctx) {
432 DAGListTy partitions;
433 const DeviceIDTy logDevice = 0;
434 for (auto F : module_->getFunctions()) {
435 if (!optimized_) {
436 auto backend = backendMap[backendName.str()].backend;
437 RETURN_IF_ERR(::glow::optimizeFunction(
438 F, *backend, cctx, &getDeviceInfoForBackend(backendName)));
439 }
440 std::unique_ptr<DAGNode> DAG0 = glow::make_unique<DAGNode>();
441 DAG0->logicalDevices = {logDevice};
442 DAG0->name = F->getName().str();
443 DAG0->module = module_;
444 std::unique_ptr<DAGNode> DAG1 = glow::make_unique<DAGNode>();
445 DAG1->logicalDevices = {logDevice};
446 DAG1->name = F->getName().str();
447 DAG1->backendName = backendName.str();
448 DAG1->parents.push_back(DAG0.get());
449 DAG0->children.push_back(DAG1.get());
450 DAG1->replicationCount = cctx.replicationCount;
451 DAGNodePtrVec nodes;
452 nodes.push_back(std::move(DAG1));
453 partitions.push_back({std::move(DAG0), std::move(nodes)});
454 }
455 if (cctx.saturateHost) {
456 // Saturate the Host.
457 saturateHost(1, partitions, cctx.saturateKDevices);
458 }
459
460 NodeToFunctionMap mapping;
461 for (auto func : module_->getFunctions()) {
462 mapping.createPartition(func, backendName);
463 mapping.setGraphMemInfo(func, getFunctionMemory(func));
464
465 // Use the same hard-coded logical device ID as used for the DAG itself.
466 mapping.appendLogicalDeviceID(func, logDevice);
467 }
468
469 RETURN_IF_ERR(finalize(partitions, mapping));
470
471 return std::move(partitions);
472}
473
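/// Partition the representative function across \p numDevices devices using
/// roofline compute-time estimates so that each partition gets a roughly
/// equal share of the total estimated runtime, subject to per-device memory
/// limits. Falls back to heterogeneousPartition() when multiple backend kinds
/// are present.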
474Expected<DAGListTy> Partitioner::loadBalancedPartition(CompilationContext &cctx,
475 size_t numDevices) {
476
477 if (multiBackendNames_) {
    VLOG(1) << "Load-balanced partitioning cannot be applied with multiple "
               "backend types. Calling heterogeneous partitioning instead.";
480 return heterogeneousPartition(cctx);
481 }
482 F_ = selectRepFunc(module_, memSize_);
483 std::string origName(F_->getName().data());
484 DAGListTy partitions;
485 std::vector<Backend *> backends;
486 genBackendMap(backendMap_, backendHolder_, backends);
487 auto backendName = backends[0]->getBackendName();
488
489 if (memSize_ < backendMap_[backendName].memSize) {
490 // No partition is needed. Create DAGNode and return. This root is always a
491 // dummy function.
492 if (logPartition) {
493 LOG(INFO) << "The model is too small for applying partition.\n"
494 << "Model size : " << memSize_ << "\n"
495 << "Backend Name : " << backendName << "\n"
496 << "Device memory: " << backendMap_[backendName].memSize
497 << "\n";
498 }
499 return createDAGWithoutPartition(backendName, backendMap_, cctx);
500 }
501
  // Step 1: Get the minimal number of partitions from auto-partition.
503 uint64_t availableMemory = backendMap_[backendName].memSize;
504 if (!optimized_) {
505 RETURN_IF_ERR(::glow::optimizeFunction(F_, *(backends[0]), cctx));
506 }
507 NodeToFunctionMap mapping =
508 selectPartitions(F_, availableMemory, backendName);
509 logicalDeviceID_ = assignLogicalDeviceID(mapping, backendMap_);
510
511 if (logicalDeviceID_ > numDevices) {
512 numDevices = logicalDeviceID_;
513 }
514 // Step 2:
515 // Currently, the load balanced partitioner disregards the input mapping
516 // and only uses the numPartitions input from previous partitioning passes
517 // But we take this in to leave open the option of using the previous mapping
518 // at a later point.
519 // The main idea here is to use the roofline estimates to load balance
520 // partitions. At this point, we stick to one partition per device, so
521 // we ensure that we only have edges from nodes in smaller partition ids to
522 // nodes in larger partition ids to ensure an acyclic DAGNode graph.
523 //
524 // The overall algorithm is as follows:
525 // Iterate through all operators in breadth-first fashion.
526 // For each operator do:
527 // (a) Find the maximum partition id of each input node.
  // (b) Assign the operator to this partition if the memory constraints are
  //     satisfied and the total sum of operator runtimes assigned to the
  //     partition does not exceed a 1/numPartitions fraction of the overall
  //     roofline runtime.
  // (c) If the memory constraint isn't satisfied, try to put the operator
  //     into successively higher partitions until the constraints are
  //     satisfied. If no such partition can be found for this operator,
  //     throw an error.
536
537 // Initialize runtimes and memory availability per device
538 std::vector<float> deviceTime(numDevices, 0);
539 std::vector<size_t> memoryAvailable(numDevices, availableMemory);
540 std::vector<NodesSet> nodesInPartitions(numDevices);
541 std::vector<GraphMemInfo> graphMem(numDevices, GraphMemInfo{});
542 std::vector<Function *> partitionFuncs(numDevices);
543
544 // Compute total roofline time
545 NodeToFunctionMap partitionMap;
546 float totalRooflineTime = 0;
547 for (auto &n : F_->getNodes()) {
548 totalRooflineTime +=
549 getNodeComputeTime(&n, backendMap_[deviceInfo_[0].backendName]);
550 }
551
552 float timePerPartition = totalRooflineTime / numDevices;
553
554 // Get the BFS levels
555 Function *newF;
556 BFSLevel bfs = getBFSLevel(F_);
557 size_t level = bfs.size();
558
559 // Create the functions and push them into the mapping
560 for (DeviceIDTy curPartition = 0; curPartition < numDevices; curPartition++) {
561 std::string funcName =
562 std::string(F_->getName()) + "_part" + std::to_string(curPartition + 1);
563 if (F_->getParent()->hasFunction(funcName)) {
564 newF = F_->getParent()->getFunction(funcName);
565 F_->getParent()->eraseFunction(newF);
566 }
567 newF = F_->getParent()->createFunction(funcName);
568 partitionMap.createPartition(newF, backendName);
569 partitionMap.appendLogicalDeviceID(newF, curPartition);
570 partitionFuncs[curPartition] = newF;
571 }
572
573 // Go through operators level by level
574 for (int i = level - 1; i >= 0; i--) {
575 for (size_t j = 0, e = bfs[i].size(); j < e; j++) {
576 Node *N = bfs[i][j];
577
578 // Find the maximum partition id of the inputs to the node
579 DeviceIDTy maxLogicalDeviceId = 0;
580 for (auto &I : getInputs(N)) {
581 Function *inpF = partitionMap[I];
582 auto logicalDeviceIds = partitionMap.getLogicalDeviceIDList(inpF);
583 DCHECK(logicalDeviceIds.size() == 1);
584 auto logicalDeviceId = logicalDeviceIds[0];
585 if (logicalDeviceId > maxLogicalDeviceId) {
586 maxLogicalDeviceId = logicalDeviceId;
587 }
588 }
589
590 auto curOpTime =
591 getNodeComputeTime(N, backendMap_[deviceInfo_[0].backendName]);
592 auto curOpMemory = getNodeMemUsage(N);
593
594 // Find a partition to put this node into
595 DeviceIDTy curPartition = maxLogicalDeviceId;
596 const float allowedLoadImbalanceFraction = 0.5f;
597 for (; curPartition < numDevices; curPartition++) {
        // Put the op in the current partition if
        // (a) memory constraints and load balance constraints are not
        //     violated, or
        // (b) this is the last partition and memory capacity isn't exceeded.
        // The allowedLoadImbalanceFraction in the load balance case avoids
        // edge cases where load balance is violated only by a small amount
        // and moving to the next partition would result in a significant
        // runtime imbalance. Hence, if the violation is less than
        // allowedLoadImbalanceFraction of the operator cost, we prefer to
        // keep it in the current partition.
607 bool loadBalanceValid = deviceTime[curPartition] +
608 curOpTime * allowedLoadImbalanceFraction <
609 timePerPartition;
610 bool memValid = memoryAvailable[curPartition] >= curOpMemory;
611
612 if (memValid && (loadBalanceValid || curPartition == numDevices - 1)) {
613 // valid, put the node in the current partition
614 Function *curF = partitionFuncs[curPartition];
615 partitionMap.add(N, curF);
616 deviceTime[curPartition] += curOpTime;
617 memoryAvailable[curPartition] -= curOpMemory;
618 graphMem[curPartition] = updateGraphMemInfoByAddingNode(
619 nodesInPartitions[curPartition], graphMem[curPartition], N);
620 nodesInPartitions[curPartition].insert(N);
621 partitionMap.setGraphMemInfo(curF, graphMem[curPartition]);
622 break;
623 }
624 }
625
626 // Throw error if we were not able to put this node into any partition
627 if (curPartition >= numDevices) {
628 logPartitionInfo(partitionMap);
629 return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR,
630 "Load balance partition error");
631 }
632 }
633 }
634 for (size_t i = 0; i < numDevices; i++) {
635 VLOG(1) << "Partition #" << i << " has estimated runtime " << deviceTime[i];
636 }
637 // Check if the memory usage meets the device memory limitation.
638 RETURN_IF_ERR(memoryUsageValidation(partitionMap, backendMap_));
639
  // assignLogicalDeviceID adds all partitions to their logical devices; clear
  // the existing assignments first to prevent duplication.
642 partitionMap.clearLogicalDeviceID();
643 logicalDeviceID_ = assignLogicalDeviceID(partitionMap, backendMap_);
644 RETURN_IF_ERR(logicalDevicesValidation(partitionMap, backendMap_));
645 RETURN_IF_ERR(resourceCountValidation(partitionMap, backendMap_));
646
647 partitions =
648 doPartitioning(origName, {F_}, module_, partitionMap, /* saveDAG */ true,
649 cctx.backendOpts.backendSpecificNodeInfo);
650 module_->eraseFunction(F_);
651
652 if (cctx.saturateHost &&
653 partitionMap.getPartitions().size() < deviceInfo_.size()) {
654 saturateHost(logicalDeviceID_, partitions, cctx.saturateKDevices);
655 }
656
657 RETURN_IF_ERR(finalize(partitions, partitionMap));
658
659 return std::move(partitions);
660}
661
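/// Partitioning flow used for quantization profiling: run
/// backendBasedPartition to keep the mapping between quantized and original
/// tensors, then optimize each resulting function for the profiling backend.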
662Expected<DAGListTy>
663Partitioner::quantizationProfilingPartition(CompilationContext &cctx) {
664 // For quantization profiling flow, currently we assume there is only 1
665 // function in a module.
666 if (module_->getFunctions().size() != 1) {
667 return MAKE_ERR(
668 ErrorValue::ErrorCode::PARTITIONER_ERROR,
669 strFormat(
670 "Invalid : %lu functions in a module. In quantization profiling "
671 "partition flow, the module can only contain 1 function",
672 module_->getFunctions().size()));
673 }
674
675 // Quantization profiling flow is run under CPU backend, so we don't really
676 // need the concrete partition. The backendBasedPartition is necessary since
677 // we need the mapping between quantized tensor and original tensor.
678 DAGListTy partitions;
679 std::vector<Backend *> backends;
680 genBackendMap(backendMap_, backendHolder_, backends);
681 F_ = selectRepFunc(module_, memSize_);
682
683 FunctionToBackendNameMap funcToBackend;
684 ASSIGN_VALUE_OR_RETURN_ERR(
685 partitions, backendBasedPartition(funcToBackend, F_, backends, cctx));
686 module_->eraseFunction(F_);
687 std::unique_ptr<Backend> backend(createBackend(profilingBackend));
688 for (Function *subF : module_->getFunctions()) {
689 DCHECK(subF->verify()) << "Conversion led to invalid function";
690 if (!optimized_) {
691 RETURN_IF_ERR(::glow::optimizeFunction(subF, *backend, cctx));
692 }
693 }
694 if (logPartition) {
695 LOG(INFO)
        << "Profiling a model to be partitioned across different backends. "
           "Each sub-network will be optimized and run on the CPU backend.\n";
698 }
699 return std::move(partitions);
700}
701
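/// The main heterogeneous partitioning flow: split the representative
/// function by backend type, partition each backend-specific function under
/// its memory limit, assign logical device IDs, build the DAG, and optionally
/// saturate the host.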
702Expected<DAGListTy>
703Partitioner::heterogeneousPartition(CompilationContext &cctx) {
704 DAGListTy partitions;
705 // Prepare the mapping between BackendName and BackendInfo.
706 std::vector<Backend *> backends;
707 genBackendMap(backendMap_, backendHolder_, backends);
708
709 // Step 0: Find the representative function for running partitioning
710 // algorithm.
711 F_ = selectRepFunc(module_, memSize_);
712
  // Step 1 : do the partition based on backend type.
714 FunctionToBackendNameMap funcToBackend;
715 std::string origName(F_->getName().data());
716 if (backends.size() == 1) {
    // Only one type of backend; no backend-name-based partition is needed.
718 auto backendName = backends[0]->getBackendName();
719 funcToBackend[F_] = backendName;
720
721 if (memSize_ < backendMap_[backendName].memSize) {
      // No partition is needed. Create DAGNode and return. This root is
      // always a dummy function.
724 if (logPartition) {
725 LOG(INFO) << "The model is too small for applying partition.\n"
726 << "Model size : " << memSize_ << "\n"
727 << "Backend Name : " << backendName << "\n"
728 << "Device memory: " << backendMap_[backendName].memSize
729 << "\n";
730 }
731 return createDAGWithoutPartition(backendName, backendMap_, cctx);
732 }
    // NOTE: the following error detection will be removed once multiple
    // functions in a module are supported.
735 if (module_->getFunctions().size() != 1) {
736 return MAKE_ERR(
737 ErrorValue::ErrorCode::PARTITIONER_ERROR,
738 strFormat("Invalid : %lu functions in a module. Now in heterogeneous "
739 "partition flow, the module can only contain 1 function",
740 module_->getFunctions().size()));
741 }
742 } else {
    // NOTE: the following error detection will be removed once multiple
    // functions in a module are supported.
745 if (module_->getFunctions().size() != 1) {
746 return MAKE_ERR(
747 ErrorValue::ErrorCode::PARTITIONER_ERROR,
748 strFormat(
              "Invalid : %lu functions in a module. Now in heterogeneous "
              "partition flow, the module can only contain 1 function",
751 module_->getFunctions().size()));
752 }
753 ASSIGN_VALUE_OR_RETURN_ERR(
754 partitions, backendBasedPartition(funcToBackend, F_, backends, cctx));
755 module_->eraseFunction(F_);
756 }
757
  // Step 2 : optimize each function based on its backend type and apply the
  // partition algorithm.
760 NodeToFunctionMap mapping;
761 std::vector<Function *> funcs;
762 for (auto i = funcToBackend.begin(); i != funcToBackend.end(); ++i) {
763 auto *func = i->first;
764 auto *backend = backendMap_[i->second].backend;
765 auto availMem = backendMap_[i->second].memSize;
766 funcs.push_back(func);
767 DCHECK(func->verify()) << "Conversion led to invalid function";
768 // Step 2.1 : optimize a function if it has not been optimized yet.
769 if (!optimized_) {
770 RETURN_IF_ERR(::glow::optimizeFunction(
771 func, *backend, cctx,
772 &getDeviceInfoForBackend(backend->getBackendName())));
773 }
774
    // Step 2.2 : apply the graph partitioning algorithm to find the partition.
776 NodeToFunctionMap partitionMap =
777 selectPartitions(func, availMem, i->second);
778 mapping.insert(partitionMap);
779 }
780
781 // Check if the memory usage meets the device memory limitation.
782 RETURN_IF_ERR(memoryUsageValidation(mapping, backendMap_));
783
  // Step 3 : assign each partition a logical device ID. Partitions with the
  // same logical device ID will be assigned to the same physical device.
787 logicalDeviceID_ = assignLogicalDeviceID(mapping, backendMap_);
788
  // Check if the number of logical devices is less than the number of given
  // physical devices.
791 RETURN_IF_ERR(logicalDevicesValidation(mapping, backendMap_));
792
793 // Step 4 : do the real partitioning for the function list.
794 partitions =
795 doPartitioning(origName, funcs, module_, mapping, /* saveDAG */ true,
796 cctx.backendOpts.backendSpecificNodeInfo);
797
798 // Step 5 : Post-partition optimization - Adjust the logicalDevice for each
799 // DAGNode.
800 if (cctx.saturateHost && backends.size() == 1 &&
801 mapping.getPartitions().size() < deviceInfo_.size()) {
802 // Attempt to saturate the host when there is only one type of backend.
803 // Passing in the count of logical devices. Since logicalId starts at 0 we
804 // add one.
805 saturateHost(logicalDeviceID_, partitions, cctx.saturateKDevices);
806 }
807
808 // Step 6 : clean up and verify the generated new functions.
809 for (auto i = funcToBackend.begin(); i != funcToBackend.end(); ++i) {
810 module_->eraseFunction(i->first);
811 }
812
813 RETURN_IF_ERR(finalize(partitions, mapping));
814
815 return std::move(partitions);
816}
817
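/// Partition according to the user-provided \p partitionConfig: create the
/// named partitions, assign nodes to them (unmapped nodes go to the single
/// unused partition), validate memory, logical device, and resource usage,
/// then build and validate the DAG.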
818Expected<DAGListTy>
819Partitioner::partitionFromConfig(const PartitionConfig &partitionConfig,
820 CompilationContext &cctx) {
821 DAGListTy partitions;
822 // Prepare the mapping between BackendName and BackendInfo.
823 std::vector<Backend *> backends;
824 genBackendMap(backendMap_, backendHolder_, backends);
825 Function *F = module_->getFunction(partitionConfig.funcName);
826 if (!F) {
827 return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR,
                    strFormat("Can't find function %s in current module.",
                              partitionConfig.funcName.c_str()));
830 }
831
832 DCHECK(
833 partitionConfig.numOfPartitions == partitionConfig.backendNames.size() &&
834 partitionConfig.numOfPartitions == partitionConfig.partitionNames.size())
835 << "Invalid user-defined partition config.";
836
837 if (partitionConfig.backendHints.size()) {
838 DCHECK(partitionConfig.numOfPartitions ==
839 partitionConfig.backendHints.size())
840 << "Invalid user-defined partition config (backendHints).";
841 }
842
843 NodeToFunctionMap partitionMap;
844 std::vector<Function *> funcList;
845 std::unordered_set<size_t> unused;
846 std::vector<NodesSet> nodesSets(partitionConfig.numOfPartitions);
847 // Create partitions based on the given number and names.
848 for (size_t i = 0; i < partitionConfig.numOfPartitions; i++) {
849 Function *newF = module_->createFunction(partitionConfig.partitionNames[i]);
850 funcList.push_back(newF);
851 partitionMap.createPartition(newF, partitionConfig.backendNames[i]);
852 unused.insert(i);
853 }
854
  // Map the nodes to the partitions.
856 std::vector<Node *> unMapped;
857 for (auto &node : F->getNodes()) {
858 auto iter = partitionConfig.nodeToPartition.find(node.getName());
859 if (iter == partitionConfig.nodeToPartition.end()) {
      // If a node in F is not in the node-to-partition mapping, put it into
      // the unMapped list.
862 unMapped.push_back(&node);
863 } else {
864 size_t partitionID = iter->second;
865 DCHECK(partitionID < partitionConfig.numOfPartitions)
866 << "Invalid partition id :" << partitionID;
867 partitionMap.add(&node, funcList[partitionID]);
868 unused.erase(partitionID);
869 nodesSets[partitionID].insert(&node);
870 }
871 }
872
  // If there is an unused partition and there are unmapped nodes, map those
  // nodes to the unused partition.
875 if (unMapped.size()) {
876 DCHECK_EQ(unused.size(), 1) << "There must be exactly 1 unused partition.";
877 auto partitionID = *(unused.begin());
878 for (auto &node : unMapped) {
879 partitionMap.add(node, funcList[partitionID]);
880 nodesSets[partitionID].insert(node);
881 }
882 }
883
884 // Set backend hints if they exist
885 if (partitionConfig.backendHints.size()) {
886 for (size_t i = 0; i < partitionConfig.numOfPartitions; i++) {
887 auto func = funcList[i];
888 partitionMap.setBackendHints(func, partitionConfig.backendHints[i]);
889 }
890 }
891
892 // Validate memory usage.
893 for (size_t i = 0; i < partitionConfig.numOfPartitions; i++) {
894 GraphMemInfo cost = getGraphMemInfo(nodesSets[i], contextCount_);
895 partitionMap.setGraphMemInfo(funcList[i], cost);
896 }
897 RETURN_IF_ERR(memoryUsageValidation(partitionMap, backendMap_));
898
  // If logical device assignments are provided, use them; otherwise assign
  // them automatically.
900 if (partitionConfig.logicalIDs.size()) {
901 DCHECK(partitionConfig.numOfPartitions ==
902 partitionConfig.logicalIDs.size());
903 for (size_t i = 0; i < partitionConfig.numOfPartitions; i++) {
904 auto func = funcList[i];
905 for (auto logicalDevice : partitionConfig.logicalIDs[i]) {
906 partitionMap.appendLogicalDeviceID(func, logicalDevice);
907 }
908 }
909
910 } else {
    // Assign logical device IDs automatically.
912 logicalDeviceID_ = assignLogicalDeviceID(partitionMap, backendMap_);
913 }
  // Add replication counts from the config to the partition map, if provided.
915 for (auto &replicationAssignment : partitionConfig.replicationCount) {
916 auto func = funcList.at(replicationAssignment.first);
917 partitionMap.addReplicationCount(func, replicationAssignment.second);
918 }
919
920 RETURN_IF_ERR(logicalDevicesValidation(partitionMap, backendMap_));
921 RETURN_IF_ERR(resourceCountValidation(partitionMap, backendMap_));
922
923 // Do partition.
924 partitions = doPartitioning(F->getName(), {F}, module_, partitionMap,
925 /* saveDAG */ true,
926 cctx.backendOpts.backendSpecificNodeInfo);
927 module_->eraseFunction(F);
928
929 // DAG validation.
930 RETURN_IF_ERR(dagValidation(partitions[0]));
931
932 // Verify the function.
933 for (size_t i = 0; i < partitionConfig.numOfPartitions; i++) {
934 auto func = funcList[i];
935 DCHECK(func->verify()) << "Conversion led to invalid function";
936 }
937
938 RETURN_IF_ERR(finalize(partitions, partitionMap));
939
940 return std::move(partitions);
941}
942
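/// Set up a module that was already partitioned by the caller
/// (cctx.prepartitionedConfig): each existing Function becomes one partition.
/// Optimize the functions if needed, copy in the per-function options from
/// the config, validate memory and logical device usage, and build the DAG
/// without cloning the functions.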
943Expected<DAGListTy>
944Partitioner::setupPrepartitionedModule(CompilationContext &cctx) {
945 const PrePartitionedConfig &config = *cctx.prepartitionedConfig;
946
947 RETURN_ERR_IF_NOT(
948 !multiBackendNames_,
949 "Do not support multiple backend kinds in prepartitioned flow.");
950
951 // Prepare the mapping between BackendName and BackendInfo.
952 std::vector<Backend *> backends;
953 genBackendMap(backendMap_, backendHolder_, backends);
954
955 const std::vector<Function *> &funcs = config.funcs;
956
957 Backend *B = backends[0];
958 auto backendName = B->getBackendName();
959
960 // Optimize all Functions if necessary.
961 if (!optimized_) {
962 for (Function *F : funcs) {
963 RETURN_IF_ERR(::glow::optimizeFunction(
964 F, *B, cctx, &getDeviceInfoForBackend(backendName)));
965 }
966 }
967
968 NodeToFunctionMap partitionMap;
  // Create one partition per pre-partitioned Function.
970 for (size_t i = 0, e = funcs.size(); i < e; i++) {
971 partitionMap.createPartition(funcs[i], deviceInfo_[0].backendName);
972 }
973
  // Map the nodes to the partitions.
975 for (Function *F : funcs) {
976 for (auto &node : F->getNodes()) {
977 partitionMap.add(&node, F);
978 }
979 }
980
981 // Validate memory usage.
982 for (Function *F : funcs) {
983 partitionMap.setGraphMemInfo(F, getFunctionMemory(F));
984 }
985 RETURN_IF_ERR(memoryUsageValidation(partitionMap, backendMap_));
986
  // Use the provided logical device assignments.
988 DCHECK(funcs.size() == config.logicalIDs.size());
989 for (size_t i = 0; i < funcs.size(); i++) {
990 Function *F = funcs[i];
991 for (auto logicalDevice : config.logicalIDs[i]) {
992 partitionMap.appendLogicalDeviceID(F, logicalDevice);
993 }
994 }
995 RETURN_IF_ERR(logicalDevicesValidation(partitionMap, backendMap_));
996 RETURN_IF_ERR(resourceCountValidation(partitionMap, backendMap_));
997
998 // Copy in or validate all members of the PPC.
999 RETURN_ERR_IF_NOT(
1000 funcs.size() == config.backendSpecificOpts.size(),
1001 "Number of Functions must equal number of backendSpecificOpts");
1002 RETURN_ERR_IF_NOT(funcs.size() == config.backendHints.size(),
1003 "Number of Functions must equal number of backendHints");
1004 RETURN_ERR_IF_NOT(funcs.size() == config.replicationCounts.size(),
                    "Number of Functions must equal number of "
                    "replicationCounts");
1006 RETURN_ERR_IF_NOT(
1007 funcs.size() == config.backendNames.size() || config.backendNames.empty(),
1008 "If there are backendNames specified, there must be one per Function");
1009 for (size_t i = 0, e = funcs.size(); i < e; i++) {
1010 Function *F = funcs[i];
1011 partitionMap.setBackendSpecificOpts(F, config.backendSpecificOpts[i]);
1012 partitionMap.setBackendHints(F, config.backendHints[i]);
1013 partitionMap.addReplicationCount(F, config.replicationCounts[i]);
1014 if (!config.backendNames.empty()) {
1015 RETURN_ERR_IF_NOT(backendName == config.backendNames[i],
1016 "Mismatch on backendName for partition");
1017 }
1018 }
1019
1020 // Do partition.
1021 DAGListTy partitions = doPartitioning(
1022 config.funcName, funcs, module_, partitionMap,
1023 /* saveDAG */ true, cctx.backendOpts.backendSpecificNodeInfo,
1024 /* skipCloning */ true);
1025
1026 // DAG validation.
1027 RETURN_IF_ERR(dagValidation(partitions[0]));
1028
1029 // Verify the function.
1030 for (Function *F : funcs) {
1031 DCHECK(F->verify()) << "Conversion led to invalid function";
1032 }
1033
1034 RETURN_IF_ERR(finalize(partitions, partitionMap));
1035
1036 if (cctx.saturateHost) {
1037 // Use the config's logical IDs to determine how many cards it's using.
1038 llvm::SmallSet<DeviceIDTy, 6> allLogicalIDs;
1039 for (const auto &IDs : config.logicalIDs) {
1040 for (const auto &id : IDs) {
1041 allLogicalIDs.insert(id);
1042 }
1043 }
1044 saturateHost(allLogicalIDs.size(), partitions, cctx.saturateKDevices);
1045 }
1046
1047 return std::move(partitions);
1048}
1049
// Do a search starting at an SLS node to split any concats/tanh that will
// be included in the SLS partition.
1052static void splitConcatTanhFromNode(Function *F, Node *node,
1053 int concatSplitSize,
1054 const KindSet &pairSLSWithNodeKinds,
1055 bool concatTanhSinkApplied) {
1056 auto users = node->getUsers();
1057 for (auto &j : users) {
1058 Node *user = j.getUser();
1059 auto shouldPairWithSls = pairSLSWithNodeKinds.count(user->getKind());
1060 if (!shouldPairWithSls) {
1061 continue;
1062 }
1063
1064 if (auto *CN = llvm::dyn_cast<ConcatNode>(user)) {
1065 if (concatTanhSinkApplied) {
1066 auto concatUsers = CN->getUsers();
1067 // Skip splitting concats which don't go into a tanh sink or are small
1068 if (concatUsers.empty() ||
1069 concatUsers.begin()->getUser()->getKind() !=
1070 glow::Kinded::Kind::TanhNodeKind ||
1071 CN->getNumInputs() <= concatSplitSize) {
1072 continue;
1073 }
1074 auto tanhNode =
1075 llvm::dyn_cast<TanhNode>(concatUsers.begin()->getUser());
1076 auto dim = CN->getDim();
1077 // Split the concat into smaller concats and create a tanh sink for each
1078 // split
1079 std::vector<NodeValue> concats;
1080 for (size_t i = 0, n = CN->getNumInputs(); i < n;
1081 i += concatSplitSize) {
1082 auto begin = CN->getInputs().begin() + i;
1083 auto length = i + concatSplitSize < n ? concatSplitSize : n - i;
1084
1085 std::vector<NodeValue> concatInputs(begin, begin + length);
1086 auto *concat = F->createConcat(CN->getName().str() + "_part_" +
1087 std::to_string(i / n),
1088 concatInputs, dim);
1089 auto *tanh = F->createTanh(CN->getName().str() + "_tanh_part_" +
1090 std::to_string(i / n),
1091 concat->getResult());
1092 concats.emplace_back(tanh->getResult());
1093 }
1094 // Combine split up concats
1095 auto *newConcat =
1096 F->createConcat(CN->getName().str() + "_combined", concats, dim);
1097 tanhNode->getResult().replaceAllUsesOfWith(newConcat->getResult());
1098 F->eraseNode(CN);
1099 F->eraseNode(tanhNode);
1100 } else {
1101 // Skip splitting concats which don't have all tanh inputs or are small
1102 if (!checkNodeInputsAllKind(user, glow::Kinded::Kind::TanhNodeKind) ||
1103 CN->getNumInputs() <= concatSplitSize) {
1104 continue;
1105 }
1106 auto dim = CN->getDim();
1107 // Split the concat into smaller concats
1108 std::vector<NodeValue> concats;
1109 for (size_t i = 0, n = CN->getNumInputs(); i < n;
1110 i += concatSplitSize) {
1111 auto begin = CN->getInputs().begin() + i;
1112 auto length = i + concatSplitSize < n ? concatSplitSize : n - i;
1113
1114 std::vector<NodeValue> concatInputs(begin, begin + length);
1115 auto *concat = F->createConcat(CN->getName().str() + "_part_" +
1116 std::to_string(i / n),
1117 concatInputs, dim);
1118 concats.emplace_back(concat->getResult());
1119 }
1120 // Combine split-up concats
1121 auto *newConcat =
1122 F->createConcat(CN->getName().str() + "_combined", concats, dim);
1123 CN->getResult().replaceAllUsesOfWith(newConcat->getResult());
1124 F->eraseNode(CN);
1125 }
1126 } else {
1127 splitConcatTanhFromNode(F, user, concatSplitSize, pairSLSWithNodeKinds,
1128 concatTanhSinkApplied);
1129 }
1130 }
1131}
1132
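// For every SLS-like node in \p F, split any large Concat (and its Tanh sink,
// when present) that would otherwise be pulled whole into the SLS partition,
// using concatSplitSize inputs per split piece.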
1133static void splitConcatTanh(Function *F, int concatSplitSize,
1134 std::vector<std::string> pairSLSWith,
1135 bool concatTanhSinkApplied) {
1136 const std::unordered_map<std::string, glow::Kinded::Kind> nameToNodeKind = {
1137 {"Concat", glow::Kinded::Kind::ConcatNodeKind},
1138 {"LayerNorm", glow::Kinded::Kind::LayerNormalizationNodeKind},
1139 {"Tile", glow::Kinded::Kind::TileNodeKind},
1140 {"Tanh", glow::Kinded::Kind::TanhNodeKind}};
1141 for (auto &node : F->getNodes()) {
1142 switch (node.getKind()) {
1143
1144#define SPLIT_CONCAT_TANH_CASE(NODE_NAME_) \
1145 case Kinded::Kind::NODE_NAME_##Kind: { \
1146 auto SLS = llvm::cast<NODE_NAME_>(&node); \
1147 KindSet pairSLSWithNodeKinds; \
1148 for (auto &s : pairSLSWith) { \
1149 if (nameToNodeKind.find(s) == nameToNodeKind.end() || \
1150 pairSLSWithNodeKinds.count(nameToNodeKind.at(s))) { \
1151 continue; \
1152 } \
1153 if (s == "Tile") { \
1154 if (SLS->getResult().dims()[0] == 1) { \
1155 pairSLSWithNodeKinds.insert(nameToNodeKind.at(s)); \
1156 } \
1157 } else { \
1158 pairSLSWithNodeKinds.insert(nameToNodeKind.at(s)); \
1159 } \
1160 } \
1161 splitConcatTanhFromNode(F, SLS, concatSplitSize, pairSLSWithNodeKinds, \
1162 concatTanhSinkApplied); \
1163 } \
1164 continue;
1165
1166 SPLIT_CONCAT_TANH_CASE(FusedRowwiseQuantizedSparseLengthsWeightedSumNode);
1167 SPLIT_CONCAT_TANH_CASE(FusedRowwiseQuantizedSparseLengthsSumNode);
1168 SPLIT_CONCAT_TANH_CASE(RowwiseQuantizedSparseLengthsWeightedSumNode);
1169 SPLIT_CONCAT_TANH_CASE(SparseLengthsSumNode);
1170 SPLIT_CONCAT_TANH_CASE(SparseLengthsWeightedSumNode);
1171 SPLIT_CONCAT_TANH_CASE(EmbeddingBagNode);
1172 SPLIT_CONCAT_TANH_CASE(EmbeddingBagByteRowwiseOffsetsNode);
1173#undef SPLIT_CONCAT_TANH_CASE
1174
1175 default:
1176 continue;
1177 }
1178 }
1179}
1180
// Do a search starting at an SLS output to capture any Clip,
// LayerNormalization, Tile, or Tanh nodes that follow it.
1183static void
1184expandFrontier(Node *node, const NodeValue &value,
1185 std::unordered_set<NodeValue> &frontier,
1186 std::unordered_set<Node *> &traversedNodes,
1187 const std::map<glow::Kinded::Kind, size_t> &pairSlsWithNodeKinds,
1188 bool concatTanhSinkApplied) {
1189 traversedNodes.insert(node);
1190 bool covered = true;
1191 auto users = node->getUsers();
1192 for (auto j = users.begin(), f = users.end(); j != f; ++j) {
1193 Node *user = (*j).getUser();
1194 if (ClipNode *CN = llvm::dyn_cast<ClipNode>(user)) {
1195 expandFrontier(user, CN->getResult(), frontier, traversedNodes,
1196 pairSlsWithNodeKinds, concatTanhSinkApplied);
1197 } else {
1198 auto it = pairSlsWithNodeKinds.find(user->getKind());
1199 if (it != pairSlsWithNodeKinds.end()) {
1200
1201 if (it->first == glow::Kinded::Kind::ConcatNodeKind) {
1202 auto concatUsers = user->getUsers();
1203 // If tanh sink was applied, only include concats which go into tanh
1204 // sink
1205 if (concatTanhSinkApplied && !concatUsers.empty() &&
1206 concatUsers.begin()->getUser()->getKind() ==
1207 glow::Kinded::Kind::TanhNodeKind) {
1208 expandFrontier(user, user->getNthResult(it->second), frontier,
1209 traversedNodes, pairSlsWithNodeKinds,
1210 concatTanhSinkApplied);
1211 }
1212 // If tanh sink was not applied, only include concats whose inputs are
1213 // all tanh
1214 else if (!concatTanhSinkApplied &&
1215 checkNodeInputsAllKind(user,
1216 glow::Kinded::Kind::TanhNodeKind)) {
1217 expandFrontier(user, user->getNthResult(it->second), frontier,
1218 traversedNodes, pairSlsWithNodeKinds,
1219 concatTanhSinkApplied);
1220 } else {
1221 covered = false;
1222 }
1223 } else {
1224 expandFrontier(user, user->getNthResult(it->second), frontier,
1225 traversedNodes, pairSlsWithNodeKinds,
1226 concatTanhSinkApplied);
1227 }
1228 } else {
1229 covered = false;
1230 }
1231 }
1232 }
1233 if (!covered) {
1234 frontier.insert(value);
1235 }
1236}
1237
/// Helper function for the SparseNN partitioning scheme. Checks each
/// kind of SLS table and appends its metadata to \p slsTables.
1240template <typename SLSType>
1241static Error appendSLSTable(SLSType *SLS, std::vector<SLSTableInfo> &slsTables,
1242 bool doPerfModelBalance, Backend *backend,
1243 const std::vector<std::string> &pairSLSWith,
1244 bool concatTanhSinkApplied) {
1245 uint64_t cost = 1;
1246 uint64_t numBytesInTable =
1247 (uint64_t)SLS->getData().getType()->getSizeInBytes();
1248
1249 // If average length is available, then compute cost using perf model
1250 if (doPerfModelBalance) {
1251 double cost_d;
1252 ASSIGN_VALUE_OR_RETURN_ERR(cost_d, backend->estimateNodeCost(SLS));
1253 cost = (uint64_t)cost_d;
1254 }
1255 auto slsResult = SLS->getResult();
1256 const std::unordered_map<std::string, std::pair<glow::Kinded::Kind, size_t>>
1257 nameToNodeKind{
1258 {"Concat",
1259 {glow::Kinded::Kind::ConcatNodeKind, ConcatNode::ResultIdx}},
1260 {"LayerNorm",
1261 {glow::Kinded::Kind::LayerNormalizationNodeKind,
1262 LayerNormalizationNode::ResultIdx}},
1263 {"Tile", {glow::Kinded::Kind::TileNodeKind, TileNode::ResultIdx}},
1264 {"Tanh", {glow::Kinded::Kind::TanhNodeKind, TanhNode::ResultIdx}},
1265 };
1266 std::map<glow::Kinded::Kind, size_t> pairSlsWithNodeKinds;
1267 for (auto &s : pairSLSWith) {
1268 if (nameToNodeKind.find(s) == nameToNodeKind.end() ||
1269 pairSlsWithNodeKinds.find(nameToNodeKind.at(s).first) !=
1270 pairSlsWithNodeKinds.end()) {
1271 continue;
1272 }
1273 // Only expand SLS w/ tile for user embeddings
1274 if (s == "Tile") {
1275 // The first dimension = 1 corresponds to user embeddings, so we expand w/
1276 // Tile
1277 if (slsResult.dims()[0] == 1) {
1278 pairSlsWithNodeKinds.insert(nameToNodeKind.at(s));
1279 }
1280 } else {
1281 pairSlsWithNodeKinds.insert(nameToNodeKind.at(s));
1282 }
1283 }
1284 std::unordered_set<NodeValue> frontier;
1285 std::unordered_set<Node *> neighbors;
1286 expandFrontier(SLS, slsResult, frontier, neighbors, pairSlsWithNodeKinds,
1287 concatTanhSinkApplied);
1288
1289 // neighbors contains only successors; add all predecessors too.
1290 std::unordered_set<Node *> addedSLSNeighbors;
1291 std::queue<Node *> preds;
1292 for (auto *N : neighbors) {
1293 preds.push(N);
1294 }
1295 preds.push(SLS);
1296 auto hasConcat = pairSlsWithNodeKinds.find(Kinded::Kind::ConcatNodeKind) !=
1297 pairSlsWithNodeKinds.end();
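  // Walk backwards from the SLS node and its collected successors through all
  // of their inputs, adding every predecessor to the neighbor set.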
1298 while (!preds.empty()) {
1299 auto *cur = preds.front();
1300 if (cur != SLS) {
1301 neighbors.insert(cur);
1302 // Sum up the total sizes of SLS nodes under the same concat since they'll
1303 // all be in the same partition
1304 if (hasConcat && isSLSNode(cur) &&
1305 addedSLSNeighbors.find(cur) == addedSLSNeighbors.end()) {
1306 addedSLSNeighbors.insert(cur);
1307 switch (cur->getKind()) {
1308#define ADD_SLS_NB_NODE_SIZE_CASE(NODE_NAME_) \
1309 case Kinded::Kind::NODE_NAME_##Kind: { \
1310 auto SLS = llvm::cast<NODE_NAME_>(cur); \
1311 numBytesInTable += (uint64_t)SLS->getData().getType()->getSizeInBytes(); \
1312 } \
1313 continue;
1314
1315 ADD_SLS_NB_NODE_SIZE_CASE(
1316 FusedRowwiseQuantizedSparseLengthsWeightedSumNode);
1317 ADD_SLS_NB_NODE_SIZE_CASE(FusedRowwiseQuantizedSparseLengthsSumNode);
1318 ADD_SLS_NB_NODE_SIZE_CASE(
1319 RowwiseQuantizedSparseLengthsWeightedSumNode);
1320 ADD_SLS_NB_NODE_SIZE_CASE(SparseLengthsSumNode);
1321 ADD_SLS_NB_NODE_SIZE_CASE(SparseLengthsWeightedSumNode);
1322 ADD_SLS_NB_NODE_SIZE_CASE(EmbeddingBagNode);
1323 ADD_SLS_NB_NODE_SIZE_CASE(EmbeddingBagByteRowwiseOffsetsNode);
1324#undef ADD_SLS_NB_NODE_SIZE_CASE
1325 default:
1326 continue;
1327 }
1328 }
1329 }
1330 preds.pop();
1331 for (auto *N : getInputs(cur)) {
1332 preds.push(N);
1333 }
1334 }
1335
1336 slsTables.push_back(
1337 {SLS, neighbors, frontier, numBytesInTable, 0, slsResult, cost});
1338 return Error::success();
1339}
1340
1341// Check if the input for \p targetNode is a SplatNode with more than one
1342// user, and if so clone the splat node into \p F and set it to be the new
1343// input of \p targetNode.
1344static void cloneSplatInputIfNecessary(Node *targetNode, Function *F) {
1345 for (int inp = 0, e = targetNode->getNumInputs(); inp < e; inp++) {
1346 auto input = targetNode->getNthInput(inp);
1347 SplatNode *splatInput = llvm::dyn_cast<SplatNode>(input.getNode());
1348 if (!splatInput || splatInput->getNumUsers() <= 1) {
1349 continue;
1350 }
1351 SplatNode *splatInputClone =
1352 F->addNode(llvm::cast<SplatNode>(splatInput->clone()));
1353 targetNode->setNthInput(inp, splatInputClone->getResult());
1354 }
1355}
1356
1357// Insert Split->Concat at barrier between SLS and Non-SLS partitions
1358static Error
1359sparseNNInsertSplitConcat(Function *F,
1360 std::vector<std::unordered_set<NodeValue>> &frontiers,
1361 PartitionConfig &partitionConfig) {
1362
  // Walk through the SLS tables and check that all the results can be
  // concatenated.
1364 std::vector<std::vector<NodeValue>> concatInputs(frontiers.size());
1365 // Insert concat and slice nodes and assign them to partitions
1366 for (size_t p = 0; p < frontiers.size(); p++) {
1367 auto &frontier = frontiers[p];
1368
1369 if (frontier.size() == 0) {
1370 continue;
1371 }
1372 auto &templateResult = *frontier.begin();
1373 auto templateDims = templateResult.dims();
1374 auto templateConcatDim = templateDims.size() - 1;
1375
1376 for (auto &tableResult : frontier) {
1377 auto tableDims = tableResult.dims();
1378 RETURN_ERR_IF_NOT(tableDims.size() == templateDims.size(),
1379 strFormat("SLS concat addition encountered tensors "
1380 "with differing dimensions (%zu vs %zu)",
1381 (size_t)tableDims.size(),
1382 (size_t)templateDims.size()));
1383 for (dim_t otherDim = 0; otherDim < templateConcatDim; otherDim++) {
1384 RETURN_ERR_IF_NOT(tableDims[otherDim] == templateDims[otherDim],
1385 strFormat("SLS concat addition encountered tensors "
1386 "with differing dimension (%zu vs %zu)",
1387 (size_t)tableDims[otherDim],
1388 (size_t)templateDims[otherDim]));
1389 }
1390 RETURN_ERR_IF_NOT(tableResult.getType()->getElementType() ==
1391 templateResult.getType()->getElementType(),
1392 "SLS concat addition encountered tensors with "
1393 "differing ElementType");
1394 concatInputs[p].push_back(tableResult);
1395 }
1396
1397 if (concatInputs[p].size() > 1) {
1398
1399 // Insert concat
1400 auto *deviceConcat = F->createConcat("concat_dev_" + std::to_string(p),
1401 concatInputs[p], templateConcatDim);
1402 partitionConfig.nodeToPartition[deviceConcat->getName()] = p;
1403
1404 // Insert slices
1405 std::vector<dim_t> splits(concatInputs[p].size());
1406 for (dim_t i = 0; i < concatInputs[p].size(); i++) {
1407 auto inputDim = concatInputs[p][i].dims();
1408 splits[i] = inputDim[templateConcatDim];
1409 }
1410 std::vector<SliceNode *> splitOutputs;
1411 F->createSplit("split_dev" + std::to_string(p), deviceConcat,
1412 splits.size(), templateConcatDim, splits, splitOutputs);
1413 for (dim_t i = 0; i < concatInputs[p].size(); i++) {
1414 assert(i < splitOutputs.size());
1415 concatInputs[p][i].replaceAllUsesOfWith(splitOutputs[i]);
1416 deviceConcat->setNthInput(i, concatInputs[p][i]);
1417 partitionConfig.nodeToPartition[splitOutputs[i]->getName()] =
1418 partitionConfig.numOfPartitions - 1;
1419 }
1420 }
1421 }
1422 return Error::success();
1423};
1424
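/// SparseNN partitioning scheme: collect all SLS tables together with the
/// nodes paired with them, distribute them across the configured number of
/// SparseNN cards, and keep the remaining (non-SLS) nodes in a separate
/// partition.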
1425Expected<DAGListTy> Partitioner::partitionSparseNN(CompilationContext &cctx) {
1426 VLOG(1) << "Doing SparseNN partitioning" << std::endl;
1427 PartitionConfig partitionConfig;
1428 partitionConfig.numOfPartitions = 0;
1429
  // Find the first function containing an SLS node.
1431 Function *F = nullptr;
1432 for (Function *currF : module_->getFunctions()) {
1433 for (auto &node : currF->getNodes()) {
1434 if (node.getKind() ==
1435 glow::Kinded::Kind::
1436 FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind ||
1437 node.getKind() == glow::Kinded::Kind::
1438 FusedRowwiseQuantizedSparseLengthsSumNodeKind ||
1439 node.getKind() ==
1440 glow::Kinded::Kind::
1441 RowwiseQuantizedSparseLengthsWeightedSumNodeKind ||
1442 node.getKind() == glow::Kinded::Kind::SparseLengthsSumNodeKind ||
1443 node.getKind() ==
1444 glow::Kinded::Kind::SparseLengthsWeightedSumNodeKind ||
1445 node.getKind() == glow::Kinded::Kind::EmbeddingBagNodeKind ||
1446 node.getKind() ==
1447 glow::Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind) {
1448 F = currF;
1449 break;
1450 }
1451 }
1452 if (F) {
1453 break;
1454 }
1455 }
1456
  // If no function contains an SLS node, return an error.
1458 if (!F) {
1459 return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR,
1460 "Did not find a partition with an SLS node");
1461 }
1462
1463 if (deviceInfo_.size() <
1464 cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards) {
1465 return MAKE_ERR(
1466 ErrorValue::ErrorCode::PARTITIONER_ERROR,
1467 strFormat("Not enough devices to partition. Num Devices is %zu and Num "
1468 "SparseNN Cards Needed is %u",
1469 deviceInfo_.size(),
1470 cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards));
1471 }
1472
1473 // Otherwise partition this function
1474 partitionConfig.funcName = F->getName().str();
1475
  // Prepare the backend map, then optimize the function if it has not been
  // optimized yet.
  std::vector<Backend *> backends;
  genBackendMap(backendMap_, backendHolder_, backends);
  if (!optimized_) {
1481 RETURN_IF_ERR(::glow::optimizeFunction(F, *(backends[0]), cctx));
1482 }
1483
  // Now we may want to duplicate Splat input nodes in case they have been
  // CSE'd (common subexpression elimination) into a single SplatNode. If two
  // SLWS nodes that share a Splat input are placed into different partitions,
  // the shared node forces a dependence from whichever partition it lands in
  // onto the other partition, and that extra dependency can even create a
  // circular dependency in the final graph. After partitioning, when each
  // partition is optimized individually, the duplicated Splats may be merged
  // again inside their partition.
  //
  // Fix this by iterating over the Function, finding Splat input nodes with
  // multiple users, and creating a new Splat (by cloning) for each user.
  for (auto &node : F->getNodes()) {
    cloneSplatInputIfNecessary(&node, F);
  }

  // Create list of SLS Tables
  std::vector<SLSTableInfo> slsTables;
  partitionConfig.funcName = std::string(F->getName());
  VLOG(1) << "Function: " << std::string(F->getName()) << std::endl;

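  // pairSLSWith collects the node kinds that should be placed in the same
  // partition as the SLS node feeding them. Tile and LayerNorm can be added
  // through dedicated options, and listing Concat additionally triggers the
  // concat splitting below.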
  std::vector<std::string> pairSLSWith;
  folly::split<char, std::string, std::string>(
      ',', cctx.optimizationOpts.sparseNNPartitioningPairSLSWith, pairSLSWith,
      /*ignoreEmpty*/ true);
  if (cctx.optimizationOpts.sparseNNPartitioningPairTileWithSLS) {
    pairSLSWith.emplace_back("Tile");
  }
  if (cctx.optimizationOpts.sparseNNPartitioningPairLNWithSLS) {
    pairSLSWith.emplace_back("LayerNorm");
  }
  bool concatTanhSinkApplied = cctx.optimizationOpts.sinkTanhBelowConcat;
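  // When Concat is paired with SLS, split Concat (and any Tanh sunk below it)
  // nodes based on the configured split size so that the resulting pieces can
  // be co-located with the SLS tables feeding them.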
  if (std::find(pairSLSWith.begin(), pairSLSWith.end(), "Concat") !=
      pairSLSWith.end()) {
    auto splitConcatSize =
        cctx.optimizationOpts.sparseNNPartitioningConcatSplitSize;
    splitConcatTanh(F, splitConcatSize, pairSLSWith, concatTanhSinkApplied);
  }
  const bool doPerfModelBalance =
      cctx.optimizationOpts.sparseNNPartitioningBalancePerfModel;
  size_t totalSLSTableSizes = 0;
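  // Walk the function and record every supported SLS-family node in
  // slsTables. appendSLSTable records the table's size in bytes and, based on
  // its arguments, the perf-model balancing information and any paired
  // neighbor nodes from pairSLSWith.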
  for (auto &node : F->getNodes()) {
    switch (node.getKind()) {

#define APPEND_TABLE_CASE(NODE_NAME_) \
  case Kinded::Kind::NODE_NAME_##Kind: \
    RETURN_IF_ERR(appendSLSTable<NODE_NAME_>( \
        llvm::cast<NODE_NAME_>(&node), slsTables, doPerfModelBalance, \
        backends[0], pairSLSWith, concatTanhSinkApplied)); \
    totalSLSTableSizes += slsTables.back().numBytesInTable; \
    continue;

      APPEND_TABLE_CASE(FusedRowwiseQuantizedSparseLengthsWeightedSumNode);
      APPEND_TABLE_CASE(FusedRowwiseQuantizedSparseLengthsSumNode);
      APPEND_TABLE_CASE(RowwiseQuantizedSparseLengthsWeightedSumNode);
      APPEND_TABLE_CASE(SparseLengthsSumNode);
      APPEND_TABLE_CASE(SparseLengthsWeightedSumNode);
      APPEND_TABLE_CASE(EmbeddingBagNode);
      APPEND_TABLE_CASE(EmbeddingBagByteRowwiseOffsetsNode);
#undef APPEND_TABLE_CASE

    default:
      continue;
    }
  }
  LOG(INFO) << "Total size of all " << slsTables.size()
            << " SLS embedding tables: " << totalSLSTableSizes;

  // Now determine all nodes that fit in the NonSLS partition, so we know its
  // total size and can better judge how much space is left for SLS
  // partitions.
  std::unordered_set<const Node *> slsPartitionNodes;
  for (auto &slsTable : slsTables) {
    slsPartitionNodes.insert(slsTable.node);
    for (const Node *N : slsTable.neighbors) {
      slsPartitionNodes.insert(N);
    }
  }

  NodesSet nonSLSPartitionNodes;
  for (auto &node : F->getNodes()) {
    if (!slsPartitionNodes.count(&node)) {
      nonSLSPartitionNodes.insert(&node);
    }
  }

  // Calculate how much space the NonSLS partition takes up, and compare that
  // to how much memory the device has to determine the allowed SLS partition
  // size.
  const uint64_t nonSLSPartitionSize =
      getGraphMemInfo(nonSLSPartitionNodes, contextCount_).getTotalMemSize();
  const uint64_t totalDeviceMemory = deviceInfo_[0].availableMemory;
  RETURN_ERR_IF_NOT(nonSLSPartitionSize < totalDeviceMemory,
                    strFormat("nonSLSPartitionSize %lu must be less than "
                              "totalDeviceMemory %lu for backend %s",
                              nonSLSPartitionSize, totalDeviceMemory,
                              deviceInfo_[0].backendName.c_str()));
  const uint64_t allowedSLSMemBytes = totalDeviceMemory - nonSLSPartitionSize;

  // Create table of devices
  std::vector<SLSDeviceInfo> slsDevices;
  std::vector<std::unordered_set<NodeValue>> frontierValues;
  unsigned int snnNumCards =
      cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards;

  LOG(INFO) << "totalDeviceMemory=" << totalDeviceMemory
            << ", nonSLSPartitionSize=" << nonSLSPartitionSize
            << ", allowedSLSMemBytes=" << allowedSLSMemBytes
            << ", snnNumCards=" << snnNumCards;

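  // Try the requested number of sparse cards first; if the tables do not fit,
  // retry with progressively larger counts, restricted to divisors of the
  // total device count (likely so the partitions can later be replicated
  // evenly across all devices when saturating the host).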
  bool partitionSucceeded = false;
  std::vector<unsigned int> factors;
  factors.push_back(snnNumCards);
  for (unsigned int i = snnNumCards + 1, e = deviceInfo_.size(); i <= e; ++i) {
    if (deviceInfo_.size() % i == 0) {
      factors.push_back(i);
    }
  }
  auto it = std::lower_bound(factors.begin(), factors.end(), snnNumCards);
  for (unsigned i = std::distance(factors.begin(), it); i < factors.size();
       i++) {
    snnNumCards = factors[i];
    LOG(INFO) << "Trying " << snnNumCards << " sparse partitions.";
    // Reset the per-device memory budgets and frontier sets for this attempt.
    slsDevices.clear();
    for (unsigned int device = 0; device < snnNumCards; device++) {
      slsDevices.push_back({device, allowedSLSMemBytes, 0});
    }
    frontierValues.clear();
    frontierValues.resize(slsDevices.size());

    // Now assign SLS Nodes to devices
    if (ERR_TO_BOOL(assignSlsTablesToDevices(slsTables, slsDevices,
                                             frontierValues, contextCount_))) {
      LOG(INFO) << "Failed to partition SLS tables, falling back to the "
                   "greedy algorithm.";
      if (!ERR_TO_BOOL(assignSlsTablesToDevicesGreedy(
              slsTables, slsDevices, frontierValues, contextCount_))) {
        partitionSucceeded = true;
      }
    } else {
      partitionSucceeded = true;
    }

    if (partitionSucceeded) {
      LOG(INFO) << "Found a valid SparseNN partition solution with "
                << snnNumCards << " sparse partitions.";
      break;
    } else {
      LOG(WARNING) << "Cannot find a valid SparseNN partition solution with "
                   << snnNumCards << " sparse partitions.";
    }
  }

  if (!partitionSucceeded) {
    return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR,
                    "SLS Balancing Partitioning Error: Not enough memory");
  }

  VLOG(1) << "Final table assignments: ";
  printSlsTableInfo(slsTables);

  // Fill up the last partition with NonSLS nodes.
  for (auto *node : nonSLSPartitionNodes) {
    partitionConfig.nodeToPartition[node->getName()] = snnNumCards;
  }

  // Create manual partition
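  // One partition per SLS device, plus a final NonSLS partition that holds
  // all remaining nodes and is assigned every logical device ID used by the
  // SLS partitions.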
  partitionConfig.numOfPartitions = slsDevices.size() + 1;
  std::vector<unsigned int> allLogicalIDs;

  // Add SLS Partitions
  for (size_t p = 0; p < slsDevices.size(); p++) {
    partitionConfig.partitionNames.push_back(std::string("SLSPartition_") +
                                             std::to_string(p));
    partitionConfig.backendNames.push_back(deviceInfo_[p].backendName);
    partitionConfig.logicalIDs.push_back({(unsigned int)p});
    BackendHints backendHints;
    backendHints.executionUnits =
        cctx.optimizationOpts.sparseNNPartitioningSchemeNumCoresSLS;
    partitionConfig.backendHints.push_back(backendHints);
    allLogicalIDs.push_back(p);
  }

  // Add last partition
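  // The NonSLS partition spans all of the SLS logical devices and gets the
  // "Other" core budget, in contrast to the per-partition SLS core budget
  // used above.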
  partitionConfig.partitionNames.push_back(std::string("NonSLSPartition_"));
  partitionConfig.backendNames.push_back(deviceInfo_[0].backendName);
  partitionConfig.logicalIDs.push_back(allLogicalIDs);
  BackendHints backendHints;
  backendHints.executionUnits =
      cctx.optimizationOpts.sparseNNPartitioningSchemeNumCoresOther;
  partitionConfig.backendHints.push_back(backendHints);

  // Map SLS nodes to their partitions
  for (auto &table : slsTables) {
    partitionConfig.nodeToPartition[table.node->getName()] = table.deviceId;
    for (Node *N : table.neighbors) {
      partitionConfig.nodeToPartition[N->getName()] = table.deviceId;
    }
  }

  // Insert Split->Concat at barrier between SLS and Non-SLS partitions
  if (cctx.optimizationOpts.sparseNNPartitioningAddSLSConcats) {
    RETURN_IF_ERR(
        sparseNNInsertSplitConcat(F, frontierValues, partitionConfig));
  }

  VLOG(1) << " Finished SparseNN partitioning" << std::endl;
  VLOG(1) << " PartitionConfig ::: funcName = " << partitionConfig.funcName
          << "\n";
  VLOG(1) << " PartitionConfig ::: numOfPartitions = "
          << partitionConfig.numOfPartitions << "\n";
  VLOG(1) << " PartitionConfig ::: partitionNames = ";
  for (unsigned i = 0; i < partitionConfig.numOfPartitions; i++) {
    VLOG(1) << partitionConfig.partitionNames[i] << " ";
  }
  VLOG(1) << "\n";
  VLOG(1) << " PartitionConfig ::: logicalIDs = ";
  for (unsigned i = 0; i < partitionConfig.numOfPartitions; i++) {
    for (auto &id : partitionConfig.logicalIDs[i]) {
      VLOG(1) << id << " ";
    }
    VLOG(1) << "\n";
  }

  DAGListTy partitions;
  ASSIGN_VALUE_OR_RETURN_ERR(partitions,
                             partitionFromConfig(partitionConfig, cctx));
  if (cctx.saturateHost) {
    saturateHost(snnNumCards, partitions, cctx.saturateKDevices);
  }
  return std::move(partitions);
}

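// Dispatch to the appropriate partitioning flow. Priority order: a
// prepartitioned module, a user-supplied partition config, the SparseNN
// scheme, quantization-profiling partitioning, load-balanced partitioning,
// and finally the default heterogeneous flow.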
Expected<DAGListTy> Partitioner::partition(CompilationContext &cctx) {
  if (cctx.prepartitionedConfig &&
      cctx.prepartitionedConfig->funcs.size() != 0) {
    VLOG(1) << "Using prepartitioned config";
    return setupPrepartitionedModule(cctx);
  }

  if (cctx.partitionConfig) {
    VLOG(1) << "Using partition config";
    partitionConfig_ = *cctx.partitionConfig;
  }

  if (partitionConfig_.enabled()) {
    // Call user-defined partition flow.
    return partitionFromConfig(partitionConfig_, cctx);
  }

  if (!multiBackendNames_ &&
      cctx.optimizationOpts.useSparseNNPartitioningScheme) {
    VLOG(1) << "Using SNN Partition Scheme";
    return partitionSparseNN(cctx);
  }

  if (cctx.precisionConfig.quantMode == QuantizationMode::Profile) {
    // Call quantization profiling partition flow.
    VLOG(1) << "Using QuantProfile Partition";
    return quantizationProfilingPartition(cctx);
  }

  if (!multiBackendNames_ && glow::flags::EnableLoadBalancedPartitioning) {
    // Call load-balance partition flow.
    VLOG(1) << "Using Load balance Partition";
    return loadBalancedPartition(cctx);
  }

  // Call heterogeneous partition flow.
  VLOG(1) << "Using Heterogeneous Partition";
  return heterogeneousPartition(cctx);
}
