1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Partitioner/Partitioner.h" |
18 | |
19 | #include "folly/String.h" |
20 | #include "glow/Flags/Flags.h" |
21 | #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h" |
22 | #include "glow/Partitioner/PartitionerOptimizer.h" |
23 | #include "glow/Partitioner/PartitionerUtils.h" |
24 | #include "glow/Partitioner/PartitionerValidation.h" |
25 | #include "glow/Support/Support.h" |
26 | |
27 | #include "llvm/ADT/SmallSet.h" |
28 | #include "llvm/Support/CommandLine.h" |
29 | #include "llvm/Support/raw_ostream.h" |
30 | #include <unordered_map> |
31 | |
32 | #include <fstream> |
33 | |
namespace glow {
/// -partitioner_enable_load_balance - Command line option to enable a
/// partitioner pass that optimizes for load balance in addition to memory
/// capacity constraints. The option uses external storage: the parsed value is
/// stored in glow::flags::EnableLoadBalancedPartitioning.
static llvm::cl::opt<bool, /* ExternalStorage */ true>
    GlowEnableLoadBalancedPartitioningOpt(
        "partitioner_enable_load_balance",
        llvm::cl::desc(
            "Enable a partitioner pass to optimize for "
            "load balance in addition to memory capacity constraints"),
        llvm::cl::location(glow::flags::EnableLoadBalancedPartitioning));
} // namespace glow
43 | |
/// Option category that groups the partitioner command line options below.
static llvm::cl::OptionCategory PartitionerCat("Glow Partitioner Options");

/// -log-partition - Command line option to dump Partitioner logs. The parsed
/// value is stored externally in glow::flags::LogPartition.
static llvm::cl::opt<bool, /* ExternalStorage */ true>
    logPartition("log-partition",
                 llvm::cl::desc("Enable logging partition info"),
                 llvm::cl::location(glow::flags::LogPartition),
                 llvm::cl::cat(PartitionerCat));

/// -dump-partition - Command line option to dump the graph of each partition
/// by calling F->dumpDAG(). The parsed value is stored externally in
/// glow::flags::DumpPartition.
static llvm::cl::opt<bool, /* ExternalStorage */ true>
    dumpPartition("dump-partition",
                  llvm::cl::desc("Enable dumping the graph of each partitions"),
                  llvm::cl::location(glow::flags::DumpPartition),
                  llvm::cl::cat(PartitionerCat));
59 | |
60 | using namespace glow; |
61 | using llvm::isa; |
62 | |
63 | // Sorted the std::pair<DAGNode *, uint64_t> based on the second from min to |
64 | // max. |
65 | bool sortMinMemory(const std::pair<Function *, uint64_t> &a, |
66 | const std::pair<Function *, uint64_t> &b) { |
67 | return a.second < b.second; |
68 | } |
69 | |
70 | void Partitioner::init() { |
71 | memSize_ = module_->getConstantsSize(); |
72 | logicalDeviceID_ = 0; |
73 | multiBackendNames_ = false; |
74 | for (size_t i = 1, e = deviceInfo_.size(); i < e; i++) { |
75 | if (deviceInfo_[i].backendName != deviceInfo_[0].backendName) { |
76 | multiBackendNames_ = true; |
77 | break; |
78 | } |
79 | } |
80 | } |
81 | |
82 | Error Partitioner::finalize(const DAGListTy &partitions, |
83 | const NodeToFunctionMap &mapping) { |
84 | |
85 | // NOTE: Cannot validate the functions after partitioning here. The validation |
86 | // needs the backend specific verifier. Tensor layouts, for example, might |
87 | // have gone from canonical form to backend specific form. |
88 | |
89 | if (logPartition) { |
90 | LOG(INFO) << "The number of partitions is : " |
91 | << mapping.getPartitions().size(); |
92 | logPartitionInfo(mapping); |
93 | } |
94 | |
95 | // Dump the graph of each function after partitioning. |
96 | if (dumpPartition) { |
97 | LOG(INFO) << "Dumping partitioning DAG to DAG.dot file." ; |
98 | dumpDAG("DAG.dot" , partitions); |
99 | for (const auto &node : partitions[0].nodes) { |
100 | Function *subF = module_->getFunction(node->name); |
101 | if (!subF) { |
102 | // If we fail dump partition info for debugging. |
103 | logPartitionInfo(mapping); |
104 | return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR, |
105 | "Invalid function name " + node->name); |
106 | } |
107 | subF->dumpDAG("partitionLogicalID" + |
108 | std::to_string(node->logicalDevices[0]) + "__" + |
109 | subF->getName().str() + "__" + node->backendName + ".dot" ); |
110 | } |
111 | } |
112 | return Error::success(); |
113 | } |
114 | |
/// Construct a partitioner for module \p parent targeting \p devices, reusing
/// the already created \p backends (genBackendMap expects one backend per
/// device, in the same order). \p optimized records whether the functions have
/// already been optimized; when false, the partition flows call
/// optimizeFunction themselves.
Partitioner::Partitioner(Module *parent, const std::vector<DeviceInfo> &devices,
                         const std::vector<Backend *> &backends, bool optimized)
    : module_(parent), deviceInfo_(devices), backends_(backends),
      optimized_(optimized) {
  // Initialize the remaining state (memory size, logical device id, and the
  // multi-backend flag).
  init();
}
121 | |
/// Construct a partitioner for module \p parent targeting \p devices without
/// pre-created backends; genBackendMap creates them by backend name when the
/// backend list is empty. \p optimized records whether the functions have
/// already been optimized, and \p partitionConfig is stored as the
/// partitioner's partition configuration.
Partitioner::Partitioner(Module *parent, const std::vector<DeviceInfo> &devices,
                         bool optimized, PartitionConfig partitionConfig)
    : module_(parent), deviceInfo_(devices), optimized_(optimized),
      partitionConfig_(partitionConfig) {
  // Initialize the remaining state (memory size, logical device id, and the
  // multi-backend flag).
  init();
}
128 | |
129 | Function *Partitioner::selectRepFunc(Module *parent, uint64_t &memSize) { |
130 | auto funcList = parent->getFunctions(); |
131 | Function *ret = nullptr; |
132 | uint64_t maxMemSize = 0; |
133 | for (Function *F : funcList) { |
134 | uint64_t curSize = memSize; |
135 | |
136 | // The set to keep the placeholders (only for Inputs) whose size is |
137 | // already calculated. |
138 | std::set<llvm::StringRef> pSet; |
139 | |
140 | for (auto &node : F->getNodes()) { |
141 | int n = node.getNumInputs(); |
142 | if (node.getKind() == Kinded::Kind::SaveNodeKind) { |
143 | // Special node, the placeholder should be ignored? |
144 | continue; |
145 | } |
146 | for (int i = 0; i < n; i++) { |
147 | Placeholder *in = |
148 | llvm::dyn_cast<Placeholder>(node.getNthInput(i).getNode()); |
149 | if (in && pSet.find(in->getName()) == pSet.end()) { |
150 | auto ty = in->getType(); |
151 | curSize += ty->getSizeInBytes(); |
152 | pSet.insert(in->getName()); |
153 | } |
154 | } |
155 | } |
156 | // Find the function with largest required memory as the representative |
157 | // function. |
158 | if (!ret || curSize > maxMemSize) { |
159 | ret = F; |
160 | maxMemSize = curSize; |
161 | } |
162 | } |
163 | memSize = maxMemSize; |
164 | return ret; |
165 | } |
166 | |
167 | void Partitioner::partitionsAdjust(NodeToFunctionMap &partitions, |
168 | uint64_t availableMemory) { |
169 | // For each partition, create a node set. |
170 | FunctionToNodesMap nodesSet; |
171 | for (auto it = partitions.begin(); it != partitions.end(); ++it) { |
172 | nodesSet[(*it).second].insert((*it).first); |
173 | } |
174 | |
175 | // Optimize the communication cost. |
176 | optimizeCommunicationCost(partitions, nodesSet, module_, availableMemory); |
177 | |
178 | // Combine the current partitions if necessary. |
179 | partitionsCombine(partitions, nodesSet, module_, availableMemory); |
180 | } |
181 | |
182 | /// Assign nodes to partitions and return the mapping. |
183 | NodeToFunctionMap Partitioner::selectPartitions(Function *F, |
184 | uint64_t availableMemory, |
185 | llvm::StringRef backendName) { |
186 | NodeToFunctionMap mapping; |
187 | BFSLevel bfs = getBFSLevel(F); |
188 | size_t level = bfs.size(); |
189 | |
190 | // Step 1 : get the initial cut based on BFS levels and availableMemory. |
191 | int color = 0; |
192 | Function *newF; |
193 | newF = F->getParent()->createFunction(std::string(F->getName()) + "_part" + |
194 | std::to_string(++color)); |
195 | mapping.createPartition(newF, backendName); |
196 | NodesSet currentPartition; |
197 | GraphMemInfo graphMem; |
198 | graphMem.contextCount = contextCount_; |
199 | |
200 | for (int i = level - 1; i >= 0; i--) { |
201 | for (size_t j = 0, e = bfs[i].size(); j < e; j++) { |
202 | Node *N = bfs[i][j]; |
203 | graphMem = updateGraphMemInfoByAddingNode(currentPartition, graphMem, N); |
204 | // If after adding node N, the memory usage of this partition exceeds the |
205 | // device memory limitations, N can't be added into the current partition |
206 | // and a new partition is created. |
207 | if (graphMem.getTotalMemSize() > availableMemory) { |
208 | newF = F->getParent()->createFunction( |
209 | std::string(F->getName()) + "_part" + std::to_string(++color)); |
210 | mapping.createPartition(newF, backendName); |
211 | currentPartition.clear(); |
212 | graphMem = |
213 | updateGraphMemInfoByAddingNode(currentPartition, GraphMemInfo{}, N); |
214 | } |
215 | currentPartition.insert(N); |
216 | mapping.add(N, newF); |
217 | graphMem.contextCount = contextCount_; |
218 | mapping.setGraphMemInfo(newF, graphMem); |
219 | } |
220 | } |
221 | |
222 | // Step 2 : adjust the partition based on performance. |
223 | partitionsAdjust(mapping, availableMemory); |
224 | |
225 | return mapping; |
226 | } |
227 | |
228 | void Partitioner::saturateHost(unsigned logicalDeviceCount, |
229 | const DAGListTy &partitions, |
230 | size_t availableLogicalDevices) { |
231 | DCHECK(availableLogicalDevices <= deviceInfo_.size()) |
232 | << "Requested number of logical devices must be less than or euqal " |
233 | "the number of found devices." ; |
234 | // If not specified, use number of available physical devices. |
235 | if (availableLogicalDevices == 0 || |
236 | availableLogicalDevices > deviceInfo_.size()) { |
237 | availableLogicalDevices = deviceInfo_.size(); |
238 | } |
239 | unsigned duplications = availableLogicalDevices / logicalDeviceCount; |
240 | if (duplications < 2) { |
241 | return; |
242 | } |
243 | // Add additional logical devices to each node. |
244 | for (auto &network : partitions) { |
245 | for (auto &node : network.nodes) { |
246 | // Set instanceCount. |
247 | node->instanceCount = duplications; |
248 | // Build list of new logical devices to add to node. |
249 | std::vector<unsigned> newDevices; |
250 | for (auto logical : node->logicalDevices) { |
251 | // To ensure we do not have a logicalID collision we use the following |
252 | // scheme. We have an iterator starting at 1 for each duplication pass. |
253 | // The new ID we add is calculated as follows: |
254 | // (iterator * logicalDeviceCount) + initialLogicalID |
255 | for (unsigned i = 1; i < duplications; i++) { |
256 | newDevices.push_back(logical + (i * logicalDeviceCount)); |
257 | } |
258 | } |
259 | // Append the new logical devices to the node's logical device vector. |
260 | node->logicalDevices.insert(node->logicalDevices.end(), |
261 | newDevices.begin(), newDevices.end()); |
262 | } |
263 | } |
264 | } |
265 | |
266 | Expected<DAGListTy> Partitioner::backendBasedPartition( |
267 | FunctionToBackendNameMap &funcToBackend, Function *F, |
268 | std::vector<Backend *> &backends, CompilationContext &cctx) { |
269 | NodeToFunctionMap mapping; |
270 | llvm::DenseMap<Node *, std::string> nodeToBackendName; |
271 | |
272 | // For each node find a backend that supports it. |
273 | for (auto &N : F->getNodes()) { |
274 | for (auto &backend : backends) { |
275 | // Find the first backend that supports this node. The order of backends |
276 | // is important. The check flow is : |
277 | |
278 | // Step 1: If a node is in pre-defined non-supported nodes set, it can not |
279 | // be assigned to this backend. Continue. |
280 | const auto &nonSupportedNodesKinds = |
281 | backendMap_[backend->getBackendName()].nonSupportedNodesKinds; |
282 | if (nonSupportedNodesKinds.count(N.getKind())) { |
283 | // This op is on the pre-defined non-supported op list: |
284 | continue; |
285 | } |
286 | // Step 2: If the pre-defined supported nodes set is empty, it means all |
287 | // nodes could be assigned to this backend. If the pre-defined supported |
288 | // nodes set is not empty, we check that if the node from Step 1 is in |
289 | // this set or not. If not, continue. |
290 | const auto &supportedNodesKinds = |
291 | backendMap_[backend->getBackendName()].supportedNodesKinds; |
292 | if (!supportedNodesKinds.empty() && |
293 | !supportedNodesKinds.count(N.getKind())) { |
294 | // This op is not on the pre-definded supported op list: |
295 | continue; |
296 | } |
297 | // Step 3: Check if the node is actually supported in this backend, if so, |
298 | // assign it to this backend and break. Otherwise continue. |
299 | // TODO: the logic here need to be improved. |
300 | if (backend->shouldLower(&N) || backend->isOpSupported(N)) { |
301 | // Put this node into a partition for this backend. |
302 | nodeToBackendName[&N] = backend->getBackendName(); |
303 | break; |
304 | } |
305 | } |
306 | if (nodeToBackendName.find(&N) == nodeToBackendName.end()) { |
307 | logPartitionInfo(mapping); |
308 | return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR, |
309 | "Node is not supported by any of the provided backends" ); |
310 | } |
311 | } |
312 | |
313 | BFSLevel bfs = getBFSLevel(F); |
314 | size_t level = bfs.size(); |
315 | int color = 0; |
316 | Function *newF; |
317 | newF = F->getParent()->createFunction(std::string(F->getName()) + "_part" + |
318 | std::to_string(++color)); |
319 | auto backendName = nodeToBackendName[bfs[level - 1][0]]; |
320 | if (cctx.precisionConfig.quantMode == QuantizationMode::Profile) { |
321 | // When profiling, all the partition backend is assigned to |
322 | // profilingBackend. |
323 | mapping.createPartition(newF, profilingBackend); |
324 | funcToBackend[newF] = profilingBackend; |
325 | } else { |
326 | mapping.createPartition(newF, backendName); |
327 | funcToBackend[newF] = backendName; |
328 | } |
329 | for (int i = level - 1; i >= 0; i--) { |
330 | for (size_t j = 0, e = bfs[i].size(); j < e; j++) { |
331 | Node *N = bfs[i][j]; |
332 | auto bk = nodeToBackendName[N]; |
333 | if (bk != backendName) { |
334 | backendName = bk; |
335 | newF = F->getParent()->createFunction( |
336 | std::string(F->getName()) + "_part" + std::to_string(++color)); |
337 | if (cctx.precisionConfig.quantMode == QuantizationMode::Profile) { |
338 | // When profiling, all the partition backend is assigned to be |
339 | // profilingBackend. |
340 | mapping.createPartition(newF, profilingBackend); |
341 | funcToBackend[newF] = profilingBackend; |
342 | } else { |
343 | mapping.createPartition(newF, backendName); |
344 | funcToBackend[newF] = backendName; |
345 | } |
346 | } |
347 | mapping.add(N, newF); |
348 | } |
349 | } |
350 | |
351 | std::vector<Function *> funcs; |
352 | funcs.push_back(F); |
353 | // When profiling, the partition flow will be stopped after |
354 | // backendBasedPartition. Therefore, the DAG needs to be generated. Otherwise, |
355 | // no need to generate DAG. |
356 | bool genDAG = cctx.precisionConfig.quantMode == QuantizationMode::Profile |
357 | ? true |
358 | : false; |
359 | if (genDAG) { |
360 | DeviceIDTy logicalDeviceID = 0; |
361 | for (auto &func : mapping.getPartitions()) { |
362 | mapping.appendLogicalDeviceID(func, logicalDeviceID++); |
363 | } |
364 | } |
365 | return doPartitioning(F->getName(), funcs, module_, mapping, genDAG, |
366 | cctx.backendOpts.backendSpecificNodeInfo); |
367 | } |
368 | |
369 | void Partitioner::genBackendMap( |
370 | std::map<std::string, BackendInfo> &backendMap, |
371 | std::vector<std::unique_ptr<Backend>> &backendsHolder, |
372 | std::vector<Backend *> &backends) { |
373 | // If the backends are created already, we use them directly. |
374 | bool hasBackends = backends_.size() != 0; |
375 | if (hasBackends) { |
376 | DCHECK(backends_.size() == deviceInfo_.size()) |
377 | << "number of backends and devices is not match." ; |
378 | } |
379 | |
380 | int n = 0; |
381 | for (size_t i = 0, e = deviceInfo_.size(); i < e; i++) { |
382 | std::string backendName = deviceInfo_[i].backendName; |
383 | if (hasBackends) { |
384 | DCHECK(backends_[i]->getBackendName() == backendName) |
385 | << "Backend Type mismatch." ; |
386 | } |
387 | if (backendMap.find(backendName) == backendMap.end()) { |
388 | BackendInfo backendInfo; |
389 | backendInfo.num = 1; |
390 | // We assume that for the same type of devices, the available memory size |
391 | // is the same. |
392 | // TODO : will improve the algorithm for different memory size. |
393 | backendInfo.memSize = deviceInfo_[i].availableMemory; |
394 | backendInfo.inputCountMax = deviceInfo_[i].inputCountMax; |
395 | backendInfo.peakDramBw = deviceInfo_[i].peakDramBw; |
396 | backendInfo.peakSramBw = deviceInfo_[i].peakSramBw; |
397 | backendInfo.sramCapacity = deviceInfo_[i].sramCapacity; |
398 | backendInfo.peakCompute = deviceInfo_[i].peakCompute; |
399 | backendInfo.nonSupportedNodesKinds = |
400 | generateNodeKindsSet(deviceInfo_[i].nonSupportedNodes); |
401 | backendInfo.supportedNodesKinds = |
402 | generateNodeKindsSet(deviceInfo_[i].supportedNodes); |
403 | if (hasBackends) { |
404 | backendInfo.backend = backends_[i]; |
405 | } else { |
406 | backendsHolder.emplace_back(createBackend(backendName)); |
407 | backendInfo.backend = backendsHolder[n++].get(); |
408 | } |
409 | backendMap[backendName] = backendInfo; |
410 | backends.push_back(backendMap[backendName].backend); |
411 | } else { |
412 | backendMap[backendName].num += 1; |
413 | // Since we are currently assuming one value it should be the max. |
414 | backendMap[backendName].memSize = std::max( |
415 | backendMap[backendName].memSize, deviceInfo_[i].availableMemory); |
416 | } |
417 | } |
418 | } |
419 | |
420 | const DeviceInfo & |
421 | Partitioner::getDeviceInfoForBackend(llvm::StringRef backendName) { |
422 | for (DeviceInfo &devInfo : deviceInfo_) { |
423 | if (devInfo.backendName == backendName) |
424 | return devInfo; |
425 | } |
426 | llvm_unreachable("Each backend should have at least one device" ); |
427 | } |
428 | |
429 | Expected<DAGListTy> Partitioner::createDAGWithoutPartition( |
430 | llvm::StringRef backendName, std::map<std::string, BackendInfo> &backendMap, |
431 | CompilationContext &cctx) { |
432 | DAGListTy partitions; |
433 | const DeviceIDTy logDevice = 0; |
434 | for (auto F : module_->getFunctions()) { |
435 | if (!optimized_) { |
436 | auto backend = backendMap[backendName.str()].backend; |
437 | RETURN_IF_ERR(::glow::optimizeFunction( |
438 | F, *backend, cctx, &getDeviceInfoForBackend(backendName))); |
439 | } |
440 | std::unique_ptr<DAGNode> DAG0 = glow::make_unique<DAGNode>(); |
441 | DAG0->logicalDevices = {logDevice}; |
442 | DAG0->name = F->getName().str(); |
443 | DAG0->module = module_; |
444 | std::unique_ptr<DAGNode> DAG1 = glow::make_unique<DAGNode>(); |
445 | DAG1->logicalDevices = {logDevice}; |
446 | DAG1->name = F->getName().str(); |
447 | DAG1->backendName = backendName.str(); |
448 | DAG1->parents.push_back(DAG0.get()); |
449 | DAG0->children.push_back(DAG1.get()); |
450 | DAG1->replicationCount = cctx.replicationCount; |
451 | DAGNodePtrVec nodes; |
452 | nodes.push_back(std::move(DAG1)); |
453 | partitions.push_back({std::move(DAG0), std::move(nodes)}); |
454 | } |
455 | if (cctx.saturateHost) { |
456 | // Saturate the Host. |
457 | saturateHost(1, partitions, cctx.saturateKDevices); |
458 | } |
459 | |
460 | NodeToFunctionMap mapping; |
461 | for (auto func : module_->getFunctions()) { |
462 | mapping.createPartition(func, backendName); |
463 | mapping.setGraphMemInfo(func, getFunctionMemory(func)); |
464 | |
465 | // Use the same hard-coded logical device ID as used for the DAG itself. |
466 | mapping.appendLogicalDeviceID(func, logDevice); |
467 | } |
468 | |
469 | RETURN_IF_ERR(finalize(partitions, mapping)); |
470 | |
471 | return std::move(partitions); |
472 | } |
473 | |
474 | Expected<DAGListTy> Partitioner::loadBalancedPartition(CompilationContext &cctx, |
475 | size_t numDevices) { |
476 | |
477 | if (multiBackendNames_) { |
478 | VLOG(1) << "For multi backend types, load-balanced partition can't be " |
479 | "applied. Call heterogeneous partition instead." ; |
480 | return heterogeneousPartition(cctx); |
481 | } |
482 | F_ = selectRepFunc(module_, memSize_); |
483 | std::string origName(F_->getName().data()); |
484 | DAGListTy partitions; |
485 | std::vector<Backend *> backends; |
486 | genBackendMap(backendMap_, backendHolder_, backends); |
487 | auto backendName = backends[0]->getBackendName(); |
488 | |
489 | if (memSize_ < backendMap_[backendName].memSize) { |
490 | // No partition is needed. Create DAGNode and return. This root is always a |
491 | // dummy function. |
492 | if (logPartition) { |
493 | LOG(INFO) << "The model is too small for applying partition.\n" |
494 | << "Model size : " << memSize_ << "\n" |
495 | << "Backend Name : " << backendName << "\n" |
496 | << "Device memory: " << backendMap_[backendName].memSize |
497 | << "\n" ; |
498 | } |
499 | return createDAGWithoutPartition(backendName, backendMap_, cctx); |
500 | } |
501 | |
502 | // Step 1: Get the minial number of partitions from auto-partition. |
503 | uint64_t availableMemory = backendMap_[backendName].memSize; |
504 | if (!optimized_) { |
505 | RETURN_IF_ERR(::glow::optimizeFunction(F_, *(backends[0]), cctx)); |
506 | } |
507 | NodeToFunctionMap mapping = |
508 | selectPartitions(F_, availableMemory, backendName); |
509 | logicalDeviceID_ = assignLogicalDeviceID(mapping, backendMap_); |
510 | |
511 | if (logicalDeviceID_ > numDevices) { |
512 | numDevices = logicalDeviceID_; |
513 | } |
514 | // Step 2: |
515 | // Currently, the load balanced partitioner disregards the input mapping |
516 | // and only uses the numPartitions input from previous partitioning passes |
517 | // But we take this in to leave open the option of using the previous mapping |
518 | // at a later point. |
519 | // The main idea here is to use the roofline estimates to load balance |
520 | // partitions. At this point, we stick to one partition per device, so |
521 | // we ensure that we only have edges from nodes in smaller partition ids to |
522 | // nodes in larger partition ids to ensure an acyclic DAGNode graph. |
523 | // |
524 | // The overall algorithm is as follows: |
525 | // Iterate through all operators in breadth-first fashion. |
526 | // For each operator do: |
527 | // (a) Find the maximum partition id of each input node. |
528 | // (b) Assign the operator to this partition if memory |
529 | // constraints are satisfied and the total sum of operator runtimes |
530 | // assigned to the partition exceeds 1/numPartitions fraction of |
531 | // overall roofline runtime |
532 | // (c) In case memory constraint isnt satisfied, then try to put operator |
533 | // in successively higher partitions until the conditions get satisfied. |
534 | // If we cannot find such a partition where this operator can be assigned, |
535 | // throw an error. |
536 | |
537 | // Initialize runtimes and memory availability per device |
538 | std::vector<float> deviceTime(numDevices, 0); |
539 | std::vector<size_t> memoryAvailable(numDevices, availableMemory); |
540 | std::vector<NodesSet> nodesInPartitions(numDevices); |
541 | std::vector<GraphMemInfo> graphMem(numDevices, GraphMemInfo{}); |
542 | std::vector<Function *> partitionFuncs(numDevices); |
543 | |
544 | // Compute total roofline time |
545 | NodeToFunctionMap partitionMap; |
546 | float totalRooflineTime = 0; |
547 | for (auto &n : F_->getNodes()) { |
548 | totalRooflineTime += |
549 | getNodeComputeTime(&n, backendMap_[deviceInfo_[0].backendName]); |
550 | } |
551 | |
552 | float timePerPartition = totalRooflineTime / numDevices; |
553 | |
554 | // Get the BFS levels |
555 | Function *newF; |
556 | BFSLevel bfs = getBFSLevel(F_); |
557 | size_t level = bfs.size(); |
558 | |
559 | // Create the functions and push them into the mapping |
560 | for (DeviceIDTy curPartition = 0; curPartition < numDevices; curPartition++) { |
561 | std::string funcName = |
562 | std::string(F_->getName()) + "_part" + std::to_string(curPartition + 1); |
563 | if (F_->getParent()->hasFunction(funcName)) { |
564 | newF = F_->getParent()->getFunction(funcName); |
565 | F_->getParent()->eraseFunction(newF); |
566 | } |
567 | newF = F_->getParent()->createFunction(funcName); |
568 | partitionMap.createPartition(newF, backendName); |
569 | partitionMap.appendLogicalDeviceID(newF, curPartition); |
570 | partitionFuncs[curPartition] = newF; |
571 | } |
572 | |
573 | // Go through operators level by level |
574 | for (int i = level - 1; i >= 0; i--) { |
575 | for (size_t j = 0, e = bfs[i].size(); j < e; j++) { |
576 | Node *N = bfs[i][j]; |
577 | |
578 | // Find the maximum partition id of the inputs to the node |
579 | DeviceIDTy maxLogicalDeviceId = 0; |
580 | for (auto &I : getInputs(N)) { |
581 | Function *inpF = partitionMap[I]; |
582 | auto logicalDeviceIds = partitionMap.getLogicalDeviceIDList(inpF); |
583 | DCHECK(logicalDeviceIds.size() == 1); |
584 | auto logicalDeviceId = logicalDeviceIds[0]; |
585 | if (logicalDeviceId > maxLogicalDeviceId) { |
586 | maxLogicalDeviceId = logicalDeviceId; |
587 | } |
588 | } |
589 | |
590 | auto curOpTime = |
591 | getNodeComputeTime(N, backendMap_[deviceInfo_[0].backendName]); |
592 | auto curOpMemory = getNodeMemUsage(N); |
593 | |
594 | // Find a partition to put this node into |
595 | DeviceIDTy curPartition = maxLogicalDeviceId; |
596 | const float allowedLoadImbalanceFraction = 0.5f; |
597 | for (; curPartition < numDevices; curPartition++) { |
598 | // Put the op in current partition if |
599 | // (a) memory constaints and load balance constraints are not violated, |
600 | // or (b) this is the last partition and memory capacity isnt exceeded |
601 | // The allowedLoadImbalanceFraction in the load balance case is to avoid |
602 | // edge cases where load balance is only violated by a small amount and |
603 | // moving to the next partition would result in significant imbalance in |
604 | // runtime. Hence if the violation is by less than |
605 | // allowedLoadImbalanceFraction of the operator cost, then we prefer to |
606 | // keep it in the current partition. |
607 | bool loadBalanceValid = deviceTime[curPartition] + |
608 | curOpTime * allowedLoadImbalanceFraction < |
609 | timePerPartition; |
610 | bool memValid = memoryAvailable[curPartition] >= curOpMemory; |
611 | |
612 | if (memValid && (loadBalanceValid || curPartition == numDevices - 1)) { |
613 | // valid, put the node in the current partition |
614 | Function *curF = partitionFuncs[curPartition]; |
615 | partitionMap.add(N, curF); |
616 | deviceTime[curPartition] += curOpTime; |
617 | memoryAvailable[curPartition] -= curOpMemory; |
618 | graphMem[curPartition] = updateGraphMemInfoByAddingNode( |
619 | nodesInPartitions[curPartition], graphMem[curPartition], N); |
620 | nodesInPartitions[curPartition].insert(N); |
621 | partitionMap.setGraphMemInfo(curF, graphMem[curPartition]); |
622 | break; |
623 | } |
624 | } |
625 | |
626 | // Throw error if we were not able to put this node into any partition |
627 | if (curPartition >= numDevices) { |
628 | logPartitionInfo(partitionMap); |
629 | return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR, |
630 | "Load balance partition error" ); |
631 | } |
632 | } |
633 | } |
634 | for (size_t i = 0; i < numDevices; i++) { |
635 | VLOG(1) << "Partition #" << i << " has estimated runtime " << deviceTime[i]; |
636 | } |
637 | // Check if the memory usage meets the device memory limitation. |
638 | RETURN_IF_ERR(memoryUsageValidation(partitionMap, backendMap_)); |
639 | |
640 | // assignLogicalDeviceID adds all partitions to their logical device, clear |
641 | // the existing first to prevent duplication. |
642 | partitionMap.clearLogicalDeviceID(); |
643 | logicalDeviceID_ = assignLogicalDeviceID(partitionMap, backendMap_); |
644 | RETURN_IF_ERR(logicalDevicesValidation(partitionMap, backendMap_)); |
645 | RETURN_IF_ERR(resourceCountValidation(partitionMap, backendMap_)); |
646 | |
647 | partitions = |
648 | doPartitioning(origName, {F_}, module_, partitionMap, /* saveDAG */ true, |
649 | cctx.backendOpts.backendSpecificNodeInfo); |
650 | module_->eraseFunction(F_); |
651 | |
652 | if (cctx.saturateHost && |
653 | partitionMap.getPartitions().size() < deviceInfo_.size()) { |
654 | saturateHost(logicalDeviceID_, partitions, cctx.saturateKDevices); |
655 | } |
656 | |
657 | RETURN_IF_ERR(finalize(partitions, partitionMap)); |
658 | |
659 | return std::move(partitions); |
660 | } |
661 | |
662 | Expected<DAGListTy> |
663 | Partitioner::quantizationProfilingPartition(CompilationContext &cctx) { |
664 | // For quantization profiling flow, currently we assume there is only 1 |
665 | // function in a module. |
666 | if (module_->getFunctions().size() != 1) { |
667 | return MAKE_ERR( |
668 | ErrorValue::ErrorCode::PARTITIONER_ERROR, |
669 | strFormat( |
670 | "Invalid : %lu functions in a module. In quantization profiling " |
671 | "partition flow, the module can only contain 1 function" , |
672 | module_->getFunctions().size())); |
673 | } |
674 | |
675 | // Quantization profiling flow is run under CPU backend, so we don't really |
676 | // need the concrete partition. The backendBasedPartition is necessary since |
677 | // we need the mapping between quantized tensor and original tensor. |
678 | DAGListTy partitions; |
679 | std::vector<Backend *> backends; |
680 | genBackendMap(backendMap_, backendHolder_, backends); |
681 | F_ = selectRepFunc(module_, memSize_); |
682 | |
683 | FunctionToBackendNameMap funcToBackend; |
684 | ASSIGN_VALUE_OR_RETURN_ERR( |
685 | partitions, backendBasedPartition(funcToBackend, F_, backends, cctx)); |
686 | module_->eraseFunction(F_); |
687 | std::unique_ptr<Backend> backend(createBackend(profilingBackend)); |
688 | for (Function *subF : module_->getFunctions()) { |
689 | DCHECK(subF->verify()) << "Conversion led to invalid function" ; |
690 | if (!optimized_) { |
691 | RETURN_IF_ERR(::glow::optimizeFunction(subF, *backend, cctx)); |
692 | } |
693 | } |
694 | if (logPartition) { |
695 | LOG(INFO) |
696 | << "Profiling a model to be partitioned cross different backends. Each " |
697 | "sub-network will be optimized and run on cpu backend.\n" ; |
698 | } |
699 | return std::move(partitions); |
700 | } |
701 | |
702 | Expected<DAGListTy> |
703 | Partitioner::heterogeneousPartition(CompilationContext &cctx) { |
704 | DAGListTy partitions; |
705 | // Prepare the mapping between BackendName and BackendInfo. |
706 | std::vector<Backend *> backends; |
707 | genBackendMap(backendMap_, backendHolder_, backends); |
708 | |
709 | // Step 0: Find the representative function for running partitioning |
710 | // algorithm. |
711 | F_ = selectRepFunc(module_, memSize_); |
712 | |
713 | // Step 1 : do the partition based on backends type. |
714 | FunctionToBackendNameMap funcToBackend; |
715 | std::string origName(F_->getName().data()); |
716 | if (backends.size() == 1) { |
717 | // Only one type of backends, no need to backendName based partition. |
718 | auto backendName = backends[0]->getBackendName(); |
719 | funcToBackend[F_] = backendName; |
720 | |
721 | if (memSize_ < backendMap_[backendName].memSize) { |
722 | // No partition is needed. Create DAGNode and return. This root is alway a |
723 | // dummy function. |
724 | if (logPartition) { |
725 | LOG(INFO) << "The model is too small for applying partition.\n" |
726 | << "Model size : " << memSize_ << "\n" |
727 | << "Backend Name : " << backendName << "\n" |
728 | << "Device memory: " << backendMap_[backendName].memSize |
729 | << "\n" ; |
730 | } |
731 | return createDAGWithoutPartition(backendName, backendMap_, cctx); |
732 | } |
733 | // NOTE: the following error detection will be removed once multi-functions |
734 | // in a module is supported. |
735 | if (module_->getFunctions().size() != 1) { |
736 | return MAKE_ERR( |
737 | ErrorValue::ErrorCode::PARTITIONER_ERROR, |
738 | strFormat("Invalid : %lu functions in a module. Now in heterogeneous " |
739 | "partition flow, the module can only contain 1 function" , |
740 | module_->getFunctions().size())); |
741 | } |
742 | } else { |
743 | // NOTE: the following error detection will be removed once multi-functions |
744 | // in a module is supported. |
745 | if (module_->getFunctions().size() != 1) { |
746 | return MAKE_ERR( |
747 | ErrorValue::ErrorCode::PARTITIONER_ERROR, |
748 | strFormat( |
749 | "Invalid : %lu functions in a module. Now in heterogeneous partition\ |
750 | flow, the module can only contain 1 function" , |
751 | module_->getFunctions().size())); |
752 | } |
753 | ASSIGN_VALUE_OR_RETURN_ERR( |
754 | partitions, backendBasedPartition(funcToBackend, F_, backends, cctx)); |
755 | module_->eraseFunction(F_); |
756 | } |
757 | |
758 | // Step 2 : optimize each functions based on its backend type and apply the |
759 | // partition algorithm. |
760 | NodeToFunctionMap mapping; |
761 | std::vector<Function *> funcs; |
762 | for (auto i = funcToBackend.begin(); i != funcToBackend.end(); ++i) { |
763 | auto *func = i->first; |
764 | auto *backend = backendMap_[i->second].backend; |
765 | auto availMem = backendMap_[i->second].memSize; |
766 | funcs.push_back(func); |
767 | DCHECK(func->verify()) << "Conversion led to invalid function" ; |
768 | // Step 2.1 : optimize a function if it has not been optimized yet. |
769 | if (!optimized_) { |
770 | RETURN_IF_ERR(::glow::optimizeFunction( |
771 | func, *backend, cctx, |
772 | &getDeviceInfoForBackend(backend->getBackendName()))); |
773 | } |
774 | |
775 | // Step 2.2 : apply graph partitioning algrithm to find out the partition. |
776 | NodeToFunctionMap partitionMap = |
777 | selectPartitions(func, availMem, i->second); |
778 | mapping.insert(partitionMap); |
779 | } |
780 | |
781 | // Check if the memory usage meets the device memory limitation. |
782 | RETURN_IF_ERR(memoryUsageValidation(mapping, backendMap_)); |
783 | |
784 | // Step 3 : assign each partition with a logical device id. The partitions |
785 | // with the same logical device id will be assigned into the same physical |
786 | // device. |
787 | logicalDeviceID_ = assignLogicalDeviceID(mapping, backendMap_); |
788 | |
789 | // Check if the number of logical devices is less than the given physical |
790 | // devices. |
791 | RETURN_IF_ERR(logicalDevicesValidation(mapping, backendMap_)); |
792 | |
793 | // Step 4 : do the real partitioning for the function list. |
794 | partitions = |
795 | doPartitioning(origName, funcs, module_, mapping, /* saveDAG */ true, |
796 | cctx.backendOpts.backendSpecificNodeInfo); |
797 | |
798 | // Step 5 : Post-partition optimization - Adjust the logicalDevice for each |
799 | // DAGNode. |
800 | if (cctx.saturateHost && backends.size() == 1 && |
801 | mapping.getPartitions().size() < deviceInfo_.size()) { |
802 | // Attempt to saturate the host when there is only one type of backend. |
803 | // Passing in the count of logical devices. Since logicalId starts at 0 we |
804 | // add one. |
805 | saturateHost(logicalDeviceID_, partitions, cctx.saturateKDevices); |
806 | } |
807 | |
808 | // Step 6 : clean up and verify the generated new functions. |
809 | for (auto i = funcToBackend.begin(); i != funcToBackend.end(); ++i) { |
810 | module_->eraseFunction(i->first); |
811 | } |
812 | |
813 | RETURN_IF_ERR(finalize(partitions, mapping)); |
814 | |
815 | return std::move(partitions); |
816 | } |
817 | |
818 | Expected<DAGListTy> |
819 | Partitioner::partitionFromConfig(const PartitionConfig &partitionConfig, |
820 | CompilationContext &cctx) { |
821 | DAGListTy partitions; |
822 | // Prepare the mapping between BackendName and BackendInfo. |
823 | std::vector<Backend *> backends; |
824 | genBackendMap(backendMap_, backendHolder_, backends); |
825 | Function *F = module_->getFunction(partitionConfig.funcName); |
826 | if (!F) { |
827 | return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR, |
828 | strFormat("Can't find function %s in current module." , |
829 | F->getName().str().data())); |
830 | } |
831 | |
832 | DCHECK( |
833 | partitionConfig.numOfPartitions == partitionConfig.backendNames.size() && |
834 | partitionConfig.numOfPartitions == partitionConfig.partitionNames.size()) |
835 | << "Invalid user-defined partition config." ; |
836 | |
837 | if (partitionConfig.backendHints.size()) { |
838 | DCHECK(partitionConfig.numOfPartitions == |
839 | partitionConfig.backendHints.size()) |
840 | << "Invalid user-defined partition config (backendHints)." ; |
841 | } |
842 | |
843 | NodeToFunctionMap partitionMap; |
844 | std::vector<Function *> funcList; |
845 | std::unordered_set<size_t> unused; |
846 | std::vector<NodesSet> nodesSets(partitionConfig.numOfPartitions); |
847 | // Create partitions based on the given number and names. |
848 | for (size_t i = 0; i < partitionConfig.numOfPartitions; i++) { |
849 | Function *newF = module_->createFunction(partitionConfig.partitionNames[i]); |
850 | funcList.push_back(newF); |
851 | partitionMap.createPartition(newF, partitionConfig.backendNames[i]); |
852 | unused.insert(i); |
853 | } |
854 | |
855 | // Map the nodes the the partitions. |
856 | std::vector<Node *> unMapped; |
857 | for (auto &node : F->getNodes()) { |
858 | auto iter = partitionConfig.nodeToPartition.find(node.getName()); |
859 | if (iter == partitionConfig.nodeToPartition.end()) { |
860 | // If a node in F is not in the node to partition mapping, put it into |
861 | // unMaped list. |
862 | unMapped.push_back(&node); |
863 | } else { |
864 | size_t partitionID = iter->second; |
865 | DCHECK(partitionID < partitionConfig.numOfPartitions) |
866 | << "Invalid partition id :" << partitionID; |
867 | partitionMap.add(&node, funcList[partitionID]); |
868 | unused.erase(partitionID); |
869 | nodesSets[partitionID].insert(&node); |
870 | } |
871 | } |
872 | |
873 | // If there is unused partition and unmapped nodes, map those nodes to the |
874 | // unused partition. |
875 | if (unMapped.size()) { |
876 | DCHECK_EQ(unused.size(), 1) << "There must be exactly 1 unused partition." ; |
877 | auto partitionID = *(unused.begin()); |
878 | for (auto &node : unMapped) { |
879 | partitionMap.add(node, funcList[partitionID]); |
880 | nodesSets[partitionID].insert(node); |
881 | } |
882 | } |
883 | |
884 | // Set backend hints if they exist |
885 | if (partitionConfig.backendHints.size()) { |
886 | for (size_t i = 0; i < partitionConfig.numOfPartitions; i++) { |
887 | auto func = funcList[i]; |
888 | partitionMap.setBackendHints(func, partitionConfig.backendHints[i]); |
889 | } |
890 | } |
891 | |
892 | // Validate memory usage. |
893 | for (size_t i = 0; i < partitionConfig.numOfPartitions; i++) { |
894 | GraphMemInfo cost = getGraphMemInfo(nodesSets[i], contextCount_); |
895 | partitionMap.setGraphMemInfo(funcList[i], cost); |
896 | } |
897 | RETURN_IF_ERR(memoryUsageValidation(partitionMap, backendMap_)); |
898 | |
899 | // If logical device assignments are provided use them otherwise assign them. |
900 | if (partitionConfig.logicalIDs.size()) { |
901 | DCHECK(partitionConfig.numOfPartitions == |
902 | partitionConfig.logicalIDs.size()); |
903 | for (size_t i = 0; i < partitionConfig.numOfPartitions; i++) { |
904 | auto func = funcList[i]; |
905 | for (auto logicalDevice : partitionConfig.logicalIDs[i]) { |
906 | partitionMap.appendLogicalDeviceID(func, logicalDevice); |
907 | } |
908 | } |
909 | |
910 | } else { |
911 | // Logical device ID validation. |
912 | logicalDeviceID_ = assignLogicalDeviceID(partitionMap, backendMap_); |
913 | } |
914 | // Add replication count to config if provided. |
915 | for (auto &replicationAssignment : partitionConfig.replicationCount) { |
916 | auto func = funcList.at(replicationAssignment.first); |
917 | partitionMap.addReplicationCount(func, replicationAssignment.second); |
918 | } |
919 | |
920 | RETURN_IF_ERR(logicalDevicesValidation(partitionMap, backendMap_)); |
921 | RETURN_IF_ERR(resourceCountValidation(partitionMap, backendMap_)); |
922 | |
923 | // Do partition. |
924 | partitions = doPartitioning(F->getName(), {F}, module_, partitionMap, |
925 | /* saveDAG */ true, |
926 | cctx.backendOpts.backendSpecificNodeInfo); |
927 | module_->eraseFunction(F); |
928 | |
929 | // DAG validation. |
930 | RETURN_IF_ERR(dagValidation(partitions[0])); |
931 | |
932 | // Verify the function. |
933 | for (size_t i = 0; i < partitionConfig.numOfPartitions; i++) { |
934 | auto func = funcList[i]; |
935 | DCHECK(func->verify()) << "Conversion led to invalid function" ; |
936 | } |
937 | |
938 | RETURN_IF_ERR(finalize(partitions, partitionMap)); |
939 | |
940 | return std::move(partitions); |
941 | } |
942 | |
943 | Expected<DAGListTy> |
944 | Partitioner::setupPrepartitionedModule(CompilationContext &cctx) { |
945 | const PrePartitionedConfig &config = *cctx.prepartitionedConfig; |
946 | |
947 | RETURN_ERR_IF_NOT( |
948 | !multiBackendNames_, |
949 | "Do not support multiple backend kinds in prepartitioned flow." ); |
950 | |
951 | // Prepare the mapping between BackendName and BackendInfo. |
952 | std::vector<Backend *> backends; |
953 | genBackendMap(backendMap_, backendHolder_, backends); |
954 | |
955 | const std::vector<Function *> &funcs = config.funcs; |
956 | |
957 | Backend *B = backends[0]; |
958 | auto backendName = B->getBackendName(); |
959 | |
960 | // Optimize all Functions if necessary. |
961 | if (!optimized_) { |
962 | for (Function *F : funcs) { |
963 | RETURN_IF_ERR(::glow::optimizeFunction( |
964 | F, *B, cctx, &getDeviceInfoForBackend(backendName))); |
965 | } |
966 | } |
967 | |
968 | NodeToFunctionMap partitionMap; |
969 | // Create partitions based on the given number and names. |
970 | for (size_t i = 0, e = funcs.size(); i < e; i++) { |
971 | partitionMap.createPartition(funcs[i], deviceInfo_[0].backendName); |
972 | } |
973 | |
974 | // Map the nodes the the partitions. |
975 | for (Function *F : funcs) { |
976 | for (auto &node : F->getNodes()) { |
977 | partitionMap.add(&node, F); |
978 | } |
979 | } |
980 | |
981 | // Validate memory usage. |
982 | for (Function *F : funcs) { |
983 | partitionMap.setGraphMemInfo(F, getFunctionMemory(F)); |
984 | } |
985 | RETURN_IF_ERR(memoryUsageValidation(partitionMap, backendMap_)); |
986 | |
987 | // If logical device assignments are provided use them otherwise assign them. |
988 | DCHECK(funcs.size() == config.logicalIDs.size()); |
989 | for (size_t i = 0; i < funcs.size(); i++) { |
990 | Function *F = funcs[i]; |
991 | for (auto logicalDevice : config.logicalIDs[i]) { |
992 | partitionMap.appendLogicalDeviceID(F, logicalDevice); |
993 | } |
994 | } |
995 | RETURN_IF_ERR(logicalDevicesValidation(partitionMap, backendMap_)); |
996 | RETURN_IF_ERR(resourceCountValidation(partitionMap, backendMap_)); |
997 | |
998 | // Copy in or validate all members of the PPC. |
999 | RETURN_ERR_IF_NOT( |
1000 | funcs.size() == config.backendSpecificOpts.size(), |
1001 | "Number of Functions must equal number of backendSpecificOpts" ); |
1002 | RETURN_ERR_IF_NOT(funcs.size() == config.backendHints.size(), |
1003 | "Number of Functions must equal number of backendHints" ); |
1004 | RETURN_ERR_IF_NOT(funcs.size() == config.replicationCounts.size(), |
1005 | "Number of Functions must equal" ); |
1006 | RETURN_ERR_IF_NOT( |
1007 | funcs.size() == config.backendNames.size() || config.backendNames.empty(), |
1008 | "If there are backendNames specified, there must be one per Function" ); |
1009 | for (size_t i = 0, e = funcs.size(); i < e; i++) { |
1010 | Function *F = funcs[i]; |
1011 | partitionMap.setBackendSpecificOpts(F, config.backendSpecificOpts[i]); |
1012 | partitionMap.setBackendHints(F, config.backendHints[i]); |
1013 | partitionMap.addReplicationCount(F, config.replicationCounts[i]); |
1014 | if (!config.backendNames.empty()) { |
1015 | RETURN_ERR_IF_NOT(backendName == config.backendNames[i], |
1016 | "Mismatch on backendName for partition" ); |
1017 | } |
1018 | } |
1019 | |
1020 | // Do partition. |
1021 | DAGListTy partitions = doPartitioning( |
1022 | config.funcName, funcs, module_, partitionMap, |
1023 | /* saveDAG */ true, cctx.backendOpts.backendSpecificNodeInfo, |
1024 | /* skipCloning */ true); |
1025 | |
1026 | // DAG validation. |
1027 | RETURN_IF_ERR(dagValidation(partitions[0])); |
1028 | |
1029 | // Verify the function. |
1030 | for (Function *F : funcs) { |
1031 | DCHECK(F->verify()) << "Conversion led to invalid function" ; |
1032 | } |
1033 | |
1034 | RETURN_IF_ERR(finalize(partitions, partitionMap)); |
1035 | |
1036 | if (cctx.saturateHost) { |
1037 | // Use the config's logical IDs to determine how many cards it's using. |
1038 | llvm::SmallSet<DeviceIDTy, 6> allLogicalIDs; |
1039 | for (const auto &IDs : config.logicalIDs) { |
1040 | for (const auto &id : IDs) { |
1041 | allLogicalIDs.insert(id); |
1042 | } |
1043 | } |
1044 | saturateHost(allLogicalIDs.size(), partitions, cctx.saturateKDevices); |
1045 | } |
1046 | |
1047 | return std::move(partitions); |
1048 | } |
1049 | |
1050 | // Do a search starting at an SLS node to split any concats/tanh that will |
1051 | // be included the SLS partition |
1052 | static void splitConcatTanhFromNode(Function *F, Node *node, |
1053 | int concatSplitSize, |
1054 | const KindSet &pairSLSWithNodeKinds, |
1055 | bool concatTanhSinkApplied) { |
1056 | auto users = node->getUsers(); |
1057 | for (auto &j : users) { |
1058 | Node *user = j.getUser(); |
1059 | auto shouldPairWithSls = pairSLSWithNodeKinds.count(user->getKind()); |
1060 | if (!shouldPairWithSls) { |
1061 | continue; |
1062 | } |
1063 | |
1064 | if (auto *CN = llvm::dyn_cast<ConcatNode>(user)) { |
1065 | if (concatTanhSinkApplied) { |
1066 | auto concatUsers = CN->getUsers(); |
1067 | // Skip splitting concats which don't go into a tanh sink or are small |
1068 | if (concatUsers.empty() || |
1069 | concatUsers.begin()->getUser()->getKind() != |
1070 | glow::Kinded::Kind::TanhNodeKind || |
1071 | CN->getNumInputs() <= concatSplitSize) { |
1072 | continue; |
1073 | } |
1074 | auto tanhNode = |
1075 | llvm::dyn_cast<TanhNode>(concatUsers.begin()->getUser()); |
1076 | auto dim = CN->getDim(); |
1077 | // Split the concat into smaller concats and create a tanh sink for each |
1078 | // split |
1079 | std::vector<NodeValue> concats; |
1080 | for (size_t i = 0, n = CN->getNumInputs(); i < n; |
1081 | i += concatSplitSize) { |
1082 | auto begin = CN->getInputs().begin() + i; |
1083 | auto length = i + concatSplitSize < n ? concatSplitSize : n - i; |
1084 | |
1085 | std::vector<NodeValue> concatInputs(begin, begin + length); |
1086 | auto *concat = F->createConcat(CN->getName().str() + "_part_" + |
1087 | std::to_string(i / n), |
1088 | concatInputs, dim); |
1089 | auto *tanh = F->createTanh(CN->getName().str() + "_tanh_part_" + |
1090 | std::to_string(i / n), |
1091 | concat->getResult()); |
1092 | concats.emplace_back(tanh->getResult()); |
1093 | } |
1094 | // Combine split up concats |
1095 | auto *newConcat = |
1096 | F->createConcat(CN->getName().str() + "_combined" , concats, dim); |
1097 | tanhNode->getResult().replaceAllUsesOfWith(newConcat->getResult()); |
1098 | F->eraseNode(CN); |
1099 | F->eraseNode(tanhNode); |
1100 | } else { |
1101 | // Skip splitting concats which don't have all tanh inputs or are small |
1102 | if (!checkNodeInputsAllKind(user, glow::Kinded::Kind::TanhNodeKind) || |
1103 | CN->getNumInputs() <= concatSplitSize) { |
1104 | continue; |
1105 | } |
1106 | auto dim = CN->getDim(); |
1107 | // Split the concat into smaller concats |
1108 | std::vector<NodeValue> concats; |
1109 | for (size_t i = 0, n = CN->getNumInputs(); i < n; |
1110 | i += concatSplitSize) { |
1111 | auto begin = CN->getInputs().begin() + i; |
1112 | auto length = i + concatSplitSize < n ? concatSplitSize : n - i; |
1113 | |
1114 | std::vector<NodeValue> concatInputs(begin, begin + length); |
1115 | auto *concat = F->createConcat(CN->getName().str() + "_part_" + |
1116 | std::to_string(i / n), |
1117 | concatInputs, dim); |
1118 | concats.emplace_back(concat->getResult()); |
1119 | } |
1120 | // Combine split-up concats |
1121 | auto *newConcat = |
1122 | F->createConcat(CN->getName().str() + "_combined" , concats, dim); |
1123 | CN->getResult().replaceAllUsesOfWith(newConcat->getResult()); |
1124 | F->eraseNode(CN); |
1125 | } |
1126 | } else { |
1127 | splitConcatTanhFromNode(F, user, concatSplitSize, pairSLSWithNodeKinds, |
1128 | concatTanhSinkApplied); |
1129 | } |
1130 | } |
1131 | } |
1132 | |
1133 | static void splitConcatTanh(Function *F, int concatSplitSize, |
1134 | std::vector<std::string> pairSLSWith, |
1135 | bool concatTanhSinkApplied) { |
1136 | const std::unordered_map<std::string, glow::Kinded::Kind> nameToNodeKind = { |
1137 | {"Concat" , glow::Kinded::Kind::ConcatNodeKind}, |
1138 | {"LayerNorm" , glow::Kinded::Kind::LayerNormalizationNodeKind}, |
1139 | {"Tile" , glow::Kinded::Kind::TileNodeKind}, |
1140 | {"Tanh" , glow::Kinded::Kind::TanhNodeKind}}; |
1141 | for (auto &node : F->getNodes()) { |
1142 | switch (node.getKind()) { |
1143 | |
1144 | #define SPLIT_CONCAT_TANH_CASE(NODE_NAME_) \ |
1145 | case Kinded::Kind::NODE_NAME_##Kind: { \ |
1146 | auto SLS = llvm::cast<NODE_NAME_>(&node); \ |
1147 | KindSet pairSLSWithNodeKinds; \ |
1148 | for (auto &s : pairSLSWith) { \ |
1149 | if (nameToNodeKind.find(s) == nameToNodeKind.end() || \ |
1150 | pairSLSWithNodeKinds.count(nameToNodeKind.at(s))) { \ |
1151 | continue; \ |
1152 | } \ |
1153 | if (s == "Tile") { \ |
1154 | if (SLS->getResult().dims()[0] == 1) { \ |
1155 | pairSLSWithNodeKinds.insert(nameToNodeKind.at(s)); \ |
1156 | } \ |
1157 | } else { \ |
1158 | pairSLSWithNodeKinds.insert(nameToNodeKind.at(s)); \ |
1159 | } \ |
1160 | } \ |
1161 | splitConcatTanhFromNode(F, SLS, concatSplitSize, pairSLSWithNodeKinds, \ |
1162 | concatTanhSinkApplied); \ |
1163 | } \ |
1164 | continue; |
1165 | |
1166 | SPLIT_CONCAT_TANH_CASE(FusedRowwiseQuantizedSparseLengthsWeightedSumNode); |
1167 | SPLIT_CONCAT_TANH_CASE(FusedRowwiseQuantizedSparseLengthsSumNode); |
1168 | SPLIT_CONCAT_TANH_CASE(RowwiseQuantizedSparseLengthsWeightedSumNode); |
1169 | SPLIT_CONCAT_TANH_CASE(SparseLengthsSumNode); |
1170 | SPLIT_CONCAT_TANH_CASE(SparseLengthsWeightedSumNode); |
1171 | SPLIT_CONCAT_TANH_CASE(EmbeddingBagNode); |
1172 | SPLIT_CONCAT_TANH_CASE(EmbeddingBagByteRowwiseOffsetsNode); |
1173 | #undef SPLIT_CONCAT_TANH_CASE |
1174 | |
1175 | default: |
1176 | continue; |
1177 | } |
1178 | } |
1179 | } |
1180 | |
1181 | // Do a search starting at an SLS output to capture any Clip, |
1182 | // LayerNormalization, Tile, Tanh nodes which are there |
1183 | static void |
1184 | expandFrontier(Node *node, const NodeValue &value, |
1185 | std::unordered_set<NodeValue> &frontier, |
1186 | std::unordered_set<Node *> &traversedNodes, |
1187 | const std::map<glow::Kinded::Kind, size_t> &pairSlsWithNodeKinds, |
1188 | bool concatTanhSinkApplied) { |
1189 | traversedNodes.insert(node); |
1190 | bool covered = true; |
1191 | auto users = node->getUsers(); |
1192 | for (auto j = users.begin(), f = users.end(); j != f; ++j) { |
1193 | Node *user = (*j).getUser(); |
1194 | if (ClipNode *CN = llvm::dyn_cast<ClipNode>(user)) { |
1195 | expandFrontier(user, CN->getResult(), frontier, traversedNodes, |
1196 | pairSlsWithNodeKinds, concatTanhSinkApplied); |
1197 | } else { |
1198 | auto it = pairSlsWithNodeKinds.find(user->getKind()); |
1199 | if (it != pairSlsWithNodeKinds.end()) { |
1200 | |
1201 | if (it->first == glow::Kinded::Kind::ConcatNodeKind) { |
1202 | auto concatUsers = user->getUsers(); |
1203 | // If tanh sink was applied, only include concats which go into tanh |
1204 | // sink |
1205 | if (concatTanhSinkApplied && !concatUsers.empty() && |
1206 | concatUsers.begin()->getUser()->getKind() == |
1207 | glow::Kinded::Kind::TanhNodeKind) { |
1208 | expandFrontier(user, user->getNthResult(it->second), frontier, |
1209 | traversedNodes, pairSlsWithNodeKinds, |
1210 | concatTanhSinkApplied); |
1211 | } |
1212 | // If tanh sink was not applied, only include concats whose inputs are |
1213 | // all tanh |
1214 | else if (!concatTanhSinkApplied && |
1215 | checkNodeInputsAllKind(user, |
1216 | glow::Kinded::Kind::TanhNodeKind)) { |
1217 | expandFrontier(user, user->getNthResult(it->second), frontier, |
1218 | traversedNodes, pairSlsWithNodeKinds, |
1219 | concatTanhSinkApplied); |
1220 | } else { |
1221 | covered = false; |
1222 | } |
1223 | } else { |
1224 | expandFrontier(user, user->getNthResult(it->second), frontier, |
1225 | traversedNodes, pairSlsWithNodeKinds, |
1226 | concatTanhSinkApplied); |
1227 | } |
1228 | } else { |
1229 | covered = false; |
1230 | } |
1231 | } |
1232 | } |
1233 | if (!covered) { |
1234 | frontier.insert(value); |
1235 | } |
1236 | } |
1237 | |
1238 | /// Helper function for SparseNN Partitioning scheme. Checks for each |
1239 | /// kind of SLS table and appends their metadata to the vector. |
1240 | template <typename SLSType> |
1241 | static Error appendSLSTable(SLSType *SLS, std::vector<SLSTableInfo> &slsTables, |
1242 | bool doPerfModelBalance, Backend *backend, |
1243 | const std::vector<std::string> &pairSLSWith, |
1244 | bool concatTanhSinkApplied) { |
1245 | uint64_t cost = 1; |
1246 | uint64_t numBytesInTable = |
1247 | (uint64_t)SLS->getData().getType()->getSizeInBytes(); |
1248 | |
1249 | // If average length is available, then compute cost using perf model |
1250 | if (doPerfModelBalance) { |
1251 | double cost_d; |
1252 | ASSIGN_VALUE_OR_RETURN_ERR(cost_d, backend->estimateNodeCost(SLS)); |
1253 | cost = (uint64_t)cost_d; |
1254 | } |
1255 | auto slsResult = SLS->getResult(); |
1256 | const std::unordered_map<std::string, std::pair<glow::Kinded::Kind, size_t>> |
1257 | nameToNodeKind{ |
1258 | {"Concat" , |
1259 | {glow::Kinded::Kind::ConcatNodeKind, ConcatNode::ResultIdx}}, |
1260 | {"LayerNorm" , |
1261 | {glow::Kinded::Kind::LayerNormalizationNodeKind, |
1262 | LayerNormalizationNode::ResultIdx}}, |
1263 | {"Tile" , {glow::Kinded::Kind::TileNodeKind, TileNode::ResultIdx}}, |
1264 | {"Tanh" , {glow::Kinded::Kind::TanhNodeKind, TanhNode::ResultIdx}}, |
1265 | }; |
1266 | std::map<glow::Kinded::Kind, size_t> pairSlsWithNodeKinds; |
1267 | for (auto &s : pairSLSWith) { |
1268 | if (nameToNodeKind.find(s) == nameToNodeKind.end() || |
1269 | pairSlsWithNodeKinds.find(nameToNodeKind.at(s).first) != |
1270 | pairSlsWithNodeKinds.end()) { |
1271 | continue; |
1272 | } |
1273 | // Only expand SLS w/ tile for user embeddings |
1274 | if (s == "Tile" ) { |
1275 | // The first dimension = 1 corresponds to user embeddings, so we expand w/ |
1276 | // Tile |
1277 | if (slsResult.dims()[0] == 1) { |
1278 | pairSlsWithNodeKinds.insert(nameToNodeKind.at(s)); |
1279 | } |
1280 | } else { |
1281 | pairSlsWithNodeKinds.insert(nameToNodeKind.at(s)); |
1282 | } |
1283 | } |
1284 | std::unordered_set<NodeValue> frontier; |
1285 | std::unordered_set<Node *> neighbors; |
1286 | expandFrontier(SLS, slsResult, frontier, neighbors, pairSlsWithNodeKinds, |
1287 | concatTanhSinkApplied); |
1288 | |
1289 | // neighbors contains only successors; add all predecessors too. |
1290 | std::unordered_set<Node *> addedSLSNeighbors; |
1291 | std::queue<Node *> preds; |
1292 | for (auto *N : neighbors) { |
1293 | preds.push(N); |
1294 | } |
1295 | preds.push(SLS); |
1296 | auto hasConcat = pairSlsWithNodeKinds.find(Kinded::Kind::ConcatNodeKind) != |
1297 | pairSlsWithNodeKinds.end(); |
1298 | while (!preds.empty()) { |
1299 | auto *cur = preds.front(); |
1300 | if (cur != SLS) { |
1301 | neighbors.insert(cur); |
1302 | // Sum up the total sizes of SLS nodes under the same concat since they'll |
1303 | // all be in the same partition |
1304 | if (hasConcat && isSLSNode(cur) && |
1305 | addedSLSNeighbors.find(cur) == addedSLSNeighbors.end()) { |
1306 | addedSLSNeighbors.insert(cur); |
1307 | switch (cur->getKind()) { |
1308 | #define ADD_SLS_NB_NODE_SIZE_CASE(NODE_NAME_) \ |
1309 | case Kinded::Kind::NODE_NAME_##Kind: { \ |
1310 | auto SLS = llvm::cast<NODE_NAME_>(cur); \ |
1311 | numBytesInTable += (uint64_t)SLS->getData().getType()->getSizeInBytes(); \ |
1312 | } \ |
1313 | continue; |
1314 | |
1315 | ADD_SLS_NB_NODE_SIZE_CASE( |
1316 | FusedRowwiseQuantizedSparseLengthsWeightedSumNode); |
1317 | ADD_SLS_NB_NODE_SIZE_CASE(FusedRowwiseQuantizedSparseLengthsSumNode); |
1318 | ADD_SLS_NB_NODE_SIZE_CASE( |
1319 | RowwiseQuantizedSparseLengthsWeightedSumNode); |
1320 | ADD_SLS_NB_NODE_SIZE_CASE(SparseLengthsSumNode); |
1321 | ADD_SLS_NB_NODE_SIZE_CASE(SparseLengthsWeightedSumNode); |
1322 | ADD_SLS_NB_NODE_SIZE_CASE(EmbeddingBagNode); |
1323 | ADD_SLS_NB_NODE_SIZE_CASE(EmbeddingBagByteRowwiseOffsetsNode); |
1324 | #undef ADD_SLS_NB_NODE_SIZE_CASE |
1325 | default: |
1326 | continue; |
1327 | } |
1328 | } |
1329 | } |
1330 | preds.pop(); |
1331 | for (auto *N : getInputs(cur)) { |
1332 | preds.push(N); |
1333 | } |
1334 | } |
1335 | |
1336 | slsTables.push_back( |
1337 | {SLS, neighbors, frontier, numBytesInTable, 0, slsResult, cost}); |
1338 | return Error::success(); |
1339 | } |
1340 | |
1341 | // Check if the input for \p targetNode is a SplatNode with more than one |
1342 | // user, and if so clone the splat node into \p F and set it to be the new |
1343 | // input of \p targetNode. |
1344 | static void cloneSplatInputIfNecessary(Node *targetNode, Function *F) { |
1345 | for (int inp = 0, e = targetNode->getNumInputs(); inp < e; inp++) { |
1346 | auto input = targetNode->getNthInput(inp); |
1347 | SplatNode *splatInput = llvm::dyn_cast<SplatNode>(input.getNode()); |
1348 | if (!splatInput || splatInput->getNumUsers() <= 1) { |
1349 | continue; |
1350 | } |
1351 | SplatNode *splatInputClone = |
1352 | F->addNode(llvm::cast<SplatNode>(splatInput->clone())); |
1353 | targetNode->setNthInput(inp, splatInputClone->getResult()); |
1354 | } |
1355 | } |
1356 | |
1357 | // Insert Split->Concat at barrier between SLS and Non-SLS partitions |
1358 | static Error |
1359 | sparseNNInsertSplitConcat(Function *F, |
1360 | std::vector<std::unordered_set<NodeValue>> &frontiers, |
1361 | PartitionConfig &partitionConfig) { |
1362 | |
1363 | // Walk through SLS tables and check that all the results are able to concat |
1364 | std::vector<std::vector<NodeValue>> concatInputs(frontiers.size()); |
1365 | // Insert concat and slice nodes and assign them to partitions |
1366 | for (size_t p = 0; p < frontiers.size(); p++) { |
1367 | auto &frontier = frontiers[p]; |
1368 | |
1369 | if (frontier.size() == 0) { |
1370 | continue; |
1371 | } |
1372 | auto &templateResult = *frontier.begin(); |
1373 | auto templateDims = templateResult.dims(); |
1374 | auto templateConcatDim = templateDims.size() - 1; |
1375 | |
1376 | for (auto &tableResult : frontier) { |
1377 | auto tableDims = tableResult.dims(); |
1378 | RETURN_ERR_IF_NOT(tableDims.size() == templateDims.size(), |
1379 | strFormat("SLS concat addition encountered tensors " |
1380 | "with differing dimensions (%zu vs %zu)" , |
1381 | (size_t)tableDims.size(), |
1382 | (size_t)templateDims.size())); |
1383 | for (dim_t otherDim = 0; otherDim < templateConcatDim; otherDim++) { |
1384 | RETURN_ERR_IF_NOT(tableDims[otherDim] == templateDims[otherDim], |
1385 | strFormat("SLS concat addition encountered tensors " |
1386 | "with differing dimension (%zu vs %zu)" , |
1387 | (size_t)tableDims[otherDim], |
1388 | (size_t)templateDims[otherDim])); |
1389 | } |
1390 | RETURN_ERR_IF_NOT(tableResult.getType()->getElementType() == |
1391 | templateResult.getType()->getElementType(), |
1392 | "SLS concat addition encountered tensors with " |
1393 | "differing ElementType" ); |
1394 | concatInputs[p].push_back(tableResult); |
1395 | } |
1396 | |
1397 | if (concatInputs[p].size() > 1) { |
1398 | |
1399 | // Insert concat |
1400 | auto *deviceConcat = F->createConcat("concat_dev_" + std::to_string(p), |
1401 | concatInputs[p], templateConcatDim); |
1402 | partitionConfig.nodeToPartition[deviceConcat->getName()] = p; |
1403 | |
1404 | // Insert slices |
1405 | std::vector<dim_t> splits(concatInputs[p].size()); |
1406 | for (dim_t i = 0; i < concatInputs[p].size(); i++) { |
1407 | auto inputDim = concatInputs[p][i].dims(); |
1408 | splits[i] = inputDim[templateConcatDim]; |
1409 | } |
1410 | std::vector<SliceNode *> splitOutputs; |
1411 | F->createSplit("split_dev" + std::to_string(p), deviceConcat, |
1412 | splits.size(), templateConcatDim, splits, splitOutputs); |
1413 | for (dim_t i = 0; i < concatInputs[p].size(); i++) { |
1414 | assert(i < splitOutputs.size()); |
1415 | concatInputs[p][i].replaceAllUsesOfWith(splitOutputs[i]); |
1416 | deviceConcat->setNthInput(i, concatInputs[p][i]); |
1417 | partitionConfig.nodeToPartition[splitOutputs[i]->getName()] = |
1418 | partitionConfig.numOfPartitions - 1; |
1419 | } |
1420 | } |
1421 | } |
1422 | return Error::success(); |
1423 | }; |
1424 | |
1425 | Expected<DAGListTy> Partitioner::partitionSparseNN(CompilationContext &cctx) { |
1426 | VLOG(1) << "Doing SparseNN partitioning" << std::endl; |
1427 | PartitionConfig partitionConfig; |
1428 | partitionConfig.numOfPartitions = 0; |
1429 | |
1430 | // Find the first partition with an SLS node |
1431 | Function *F = nullptr; |
1432 | for (Function *currF : module_->getFunctions()) { |
1433 | for (auto &node : currF->getNodes()) { |
1434 | if (node.getKind() == |
1435 | glow::Kinded::Kind:: |
1436 | FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind || |
1437 | node.getKind() == glow::Kinded::Kind:: |
1438 | FusedRowwiseQuantizedSparseLengthsSumNodeKind || |
1439 | node.getKind() == |
1440 | glow::Kinded::Kind:: |
1441 | RowwiseQuantizedSparseLengthsWeightedSumNodeKind || |
1442 | node.getKind() == glow::Kinded::Kind::SparseLengthsSumNodeKind || |
1443 | node.getKind() == |
1444 | glow::Kinded::Kind::SparseLengthsWeightedSumNodeKind || |
1445 | node.getKind() == glow::Kinded::Kind::EmbeddingBagNodeKind || |
1446 | node.getKind() == |
1447 | glow::Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind) { |
1448 | F = currF; |
1449 | break; |
1450 | } |
1451 | } |
1452 | if (F) { |
1453 | break; |
1454 | } |
1455 | } |
1456 | |
1457 | // If no matching functions then return empty config |
1458 | if (!F) { |
1459 | return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR, |
1460 | "Did not find a partition with an SLS node" ); |
1461 | } |
1462 | |
1463 | if (deviceInfo_.size() < |
1464 | cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards) { |
1465 | return MAKE_ERR( |
1466 | ErrorValue::ErrorCode::PARTITIONER_ERROR, |
1467 | strFormat("Not enough devices to partition. Num Devices is %zu and Num " |
1468 | "SparseNN Cards Needed is %u" , |
1469 | deviceInfo_.size(), |
1470 | cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards)); |
1471 | } |
1472 | |
1473 | // Otherwise partition this function |
1474 | partitionConfig.funcName = F->getName().str(); |
1475 | |
1476 | // First optimize the function |
1477 | std::vector<Backend *> backends; |
1478 | genBackendMap(backendMap_, backendHolder_, backends); |
1479 | // First optimize it |
1480 | if (!optimized_) { |
1481 | RETURN_IF_ERR(::glow::optimizeFunction(F, *(backends[0]), cctx)); |
1482 | } |
1483 | |
1484 | // Now we may want to duplicate Splat input nodes in case they have been |
1485 | // CSE'd (CSE stands for common subexpression elimination) into a single |
1486 | // SplatNode. This is because if two SLWS that share Splat input nodes are |
1487 | // separated to two partitions, then partitioning will force a dependence |
1488 | // from whichever partition the input node are placed to the other |
1489 | // partition. After partitioning when we optimize each partition |
1490 | // individually, they may be merged again inside the partition. Besides, the |
1491 | // potential partition dependency introduced might lead to a circular |
1492 | // dependency in the final graph. |
1493 | // |
1494 | // We fix this issue by iterating over the Function and finding Splat input |
1495 | // nodes with multiple users and just creating new Splats (by cloning) for |
1496 | // each user. |
1497 | for (auto &node : F->getNodes()) { |
1498 | cloneSplatInputIfNecessary(&node, F); |
1499 | } |
1500 | |
1501 | // Create list of SLS Tables |
1502 | std::vector<SLSTableInfo> slsTables; |
1503 | partitionConfig.funcName = std::string(F->getName()); |
1504 | VLOG(1) << "Function: " << std::string(F->getName()) << std::endl; |
1505 | |
1506 | std::vector<std::string> pairSLSWith; |
1507 | folly::split<char, std::string, std::string>( |
1508 | ',', cctx.optimizationOpts.sparseNNPartitioningPairSLSWith, pairSLSWith, |
1509 | /*ignoreEmpty*/ true); |
1510 | if (cctx.optimizationOpts.sparseNNPartitioningPairTileWithSLS) { |
1511 | pairSLSWith.emplace_back("Tile" ); |
1512 | } |
1513 | if (cctx.optimizationOpts.sparseNNPartitioningPairLNWithSLS) { |
1514 | pairSLSWith.emplace_back("LayerNorm" ); |
1515 | } |
1516 | bool concatTanhSinkApplied = cctx.optimizationOpts.sinkTanhBelowConcat; |
1517 | if (std::find(pairSLSWith.begin(), pairSLSWith.end(), "Concat" ) != |
1518 | pairSLSWith.end()) { |
1519 | auto splitConcatSize = |
1520 | cctx.optimizationOpts.sparseNNPartitioningConcatSplitSize; |
1521 | splitConcatTanh(F, splitConcatSize, pairSLSWith, concatTanhSinkApplied); |
1522 | } |
1523 | const bool doPerfModelBalance = |
1524 | cctx.optimizationOpts.sparseNNPartitioningBalancePerfModel; |
1525 | size_t totalSLSTableSizes = 0; |
1526 | for (auto &node : F->getNodes()) { |
1527 | switch (node.getKind()) { |
1528 | |
1529 | #define APPEND_TABLE_CASE(NODE_NAME_) \ |
1530 | case Kinded::Kind::NODE_NAME_##Kind: \ |
1531 | RETURN_IF_ERR(appendSLSTable<NODE_NAME_>( \ |
1532 | llvm::cast<NODE_NAME_>(&node), slsTables, doPerfModelBalance, \ |
1533 | backends[0], pairSLSWith, concatTanhSinkApplied)); \ |
1534 | totalSLSTableSizes += slsTables.back().numBytesInTable; \ |
1535 | continue; |
1536 | |
1537 | APPEND_TABLE_CASE(FusedRowwiseQuantizedSparseLengthsWeightedSumNode); |
1538 | APPEND_TABLE_CASE(FusedRowwiseQuantizedSparseLengthsSumNode); |
1539 | APPEND_TABLE_CASE(RowwiseQuantizedSparseLengthsWeightedSumNode); |
1540 | APPEND_TABLE_CASE(SparseLengthsSumNode); |
1541 | APPEND_TABLE_CASE(SparseLengthsWeightedSumNode); |
1542 | APPEND_TABLE_CASE(EmbeddingBagNode); |
1543 | APPEND_TABLE_CASE(EmbeddingBagByteRowwiseOffsetsNode); |
1544 | #undef APPEND_TABLE_CASE |
1545 | |
1546 | default: |
1547 | continue; |
1548 | } |
1549 | } |
1550 | LOG(INFO) << "Total size of all " << slsTables.size() |
1551 | << " SLS embedding tables: " << totalSLSTableSizes; |
1552 | |
1553 | // Now determine all nodes that fit in the NonSLS partition, so we know its |
1554 | // total size and can better judge how much space is left for SLS |
1555 | // partitions. |
1556 | std::unordered_set<const Node *> slsPartitionNodes; |
1557 | for (auto &slsTable : slsTables) { |
1558 | slsPartitionNodes.insert(slsTable.node); |
1559 | for (const Node *N : slsTable.neighbors) { |
1560 | slsPartitionNodes.insert(N); |
1561 | } |
1562 | } |
1563 | |
1564 | NodesSet nonSLSPartitionNodes; |
1565 | for (auto &node : F->getNodes()) { |
1566 | if (!slsPartitionNodes.count(&node)) { |
1567 | nonSLSPartitionNodes.insert(&node); |
1568 | } |
1569 | } |
1570 | |
1571 | // Calculate how much space the NonSLS partition takes up, and compare that |
1572 | // to how much memory the device has to determine the allows SLS partition |
1573 | // size. |
1574 | const uint64_t nonSLSPartitionSize = |
1575 | getGraphMemInfo(nonSLSPartitionNodes, contextCount_).getTotalMemSize(); |
1576 | const uint64_t totalDeviceMemory = deviceInfo_[0].availableMemory; |
1577 | RETURN_ERR_IF_NOT(nonSLSPartitionSize < totalDeviceMemory, |
1578 | strFormat("nonSLSPartitionSize %lu must be less than %s " |
1579 | "totalDeviceMemory %lu" , |
1580 | nonSLSPartitionSize, |
1581 | deviceInfo_[0].backendName.c_str(), |
1582 | totalDeviceMemory)); |
1583 | const uint64_t allowedSLSMemBytes = totalDeviceMemory - nonSLSPartitionSize; |
1584 | |
1585 | // Create table of devices |
1586 | std::vector<SLSDeviceInfo> slsDevices; |
1587 | std::vector<std::unordered_set<NodeValue>> frontierValues; |
1588 | unsigned int snnNumCards = |
1589 | cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards; |
1590 | |
1591 | LOG(INFO) << "totalDeviceMemory=" << totalDeviceMemory |
1592 | << ", nonSLSPartitionSize=" << nonSLSPartitionSize |
1593 | << ", allowedSLSMemBytes=" << allowedSLSMemBytes |
1594 | << ", snnNumCards=" << snnNumCards; |
1595 | |
1596 | bool partitionSucceeded = false; |
1597 | std::vector<unsigned int> factors; |
1598 | factors.push_back(snnNumCards); |
1599 | for (unsigned int i = snnNumCards + 1, e = deviceInfo_.size(); i <= e; ++i) { |
1600 | if (deviceInfo_.size() % i == 0) { |
1601 | factors.push_back(i); |
1602 | } |
1603 | } |
1604 | auto it = std::lower_bound(factors.begin(), factors.end(), snnNumCards); |
1605 | for (unsigned i = std::distance(factors.begin(), it); i < factors.size(); |
1606 | i++) { |
1607 | snnNumCards = factors[i]; |
1608 | LOG(INFO) << "Trying " << snnNumCards << " sparse partitions." ; |
1609 | // Reset some of the contexts. |
1610 | slsDevices.clear(); |
1611 | for (unsigned int device = 0; device < snnNumCards; device++) { |
1612 | slsDevices.push_back({device, allowedSLSMemBytes, 0}); |
1613 | } |
1614 | frontierValues.clear(); |
1615 | frontierValues.resize(slsDevices.size()); |
1616 | |
1617 | // Now assign SLS Nodes to devices |
1618 | if (ERR_TO_BOOL(assignSlsTablesToDevices(slsTables, slsDevices, |
1619 | frontierValues, contextCount_))) { |
1620 | LOG(INFO) << "Failed to partition SLS tables, fall back to greedy " |
1621 | "algorithm." ; |
1622 | if (!ERR_TO_BOOL(assignSlsTablesToDevicesGreedy( |
1623 | slsTables, slsDevices, frontierValues, contextCount_))) { |
1624 | partitionSucceeded = true; |
1625 | }; |
1626 | } else { |
1627 | partitionSucceeded = true; |
1628 | } |
1629 | |
1630 | if (partitionSucceeded) { |
1631 | LOG(INFO) << "Successfully got a SparseNN partition solution with " |
1632 | << snnNumCards << " sparse partitions." ; |
1633 | break; |
1634 | } else { |
1635 | LOG(WARNING) << "Cannot find a valid SparseNN partition solution with " |
1636 | << snnNumCards << " sparse partitions." ; |
1637 | } |
1638 | } |
1639 | |
1640 | if (!partitionSucceeded) { |
1641 | return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR, |
1642 | "SLS Balancing Partitioning Error: Not enough memory" ); |
1643 | } |
1644 | |
1645 | VLOG(1) << "Final table assignments: " ; |
1646 | printSlsTableInfo(slsTables); |
1647 | |
1648 | // Fill up the last partition with NonSLS nodes. |
1649 | for (auto *node : nonSLSPartitionNodes) { |
1650 | partitionConfig.nodeToPartition[node->getName()] = snnNumCards; |
1651 | } |
1652 | |
1653 | // Create manual partition |
1654 | partitionConfig.numOfPartitions = slsDevices.size() + 1; |
1655 | std::vector<unsigned int> allLogicalIDs; |
1656 | |
1657 | // Add SLS Partitions |
1658 | for (size_t p = 0; p < slsDevices.size(); p++) { |
1659 | partitionConfig.partitionNames.push_back(std::string("SLSPartition_" ) + |
1660 | std::to_string(p)); |
1661 | partitionConfig.backendNames.push_back(deviceInfo_[p].backendName); |
1662 | partitionConfig.logicalIDs.push_back({(unsigned int)p}); |
1663 | BackendHints backendHints; |
1664 | backendHints.executionUnits = |
1665 | cctx.optimizationOpts.sparseNNPartitioningSchemeNumCoresSLS; |
1666 | partitionConfig.backendHints.push_back(backendHints); |
1667 | allLogicalIDs.push_back(p); |
1668 | } |
1669 | |
1670 | // Add last partition |
1671 | partitionConfig.partitionNames.push_back(std::string("NonSLSPartition_" )); |
1672 | partitionConfig.backendNames.push_back(deviceInfo_[0].backendName); |
1673 | partitionConfig.logicalIDs.push_back(allLogicalIDs); |
1674 | BackendHints backendHints; |
1675 | backendHints.executionUnits = |
1676 | cctx.optimizationOpts.sparseNNPartitioningSchemeNumCoresOther; |
1677 | partitionConfig.backendHints.push_back(backendHints); |
1678 | |
1679 | // Map SLS nodes to their partitions |
1680 | for (auto &table : slsTables) { |
1681 | partitionConfig.nodeToPartition[table.node->getName()] = table.deviceId; |
1682 | for (Node *N : table.neighbors) { |
1683 | partitionConfig.nodeToPartition[N->getName()] = table.deviceId; |
1684 | } |
1685 | } |
1686 | |
1687 | // Insert Split->Concat at barrier between SLS and Non-SLS partitions |
1688 | if (cctx.optimizationOpts.sparseNNPartitioningAddSLSConcats) { |
1689 | RETURN_IF_ERR( |
1690 | sparseNNInsertSplitConcat(F, frontierValues, partitionConfig)); |
1691 | } |
1692 | |
1693 | VLOG(1) << " Finished SparseNN partitioning" << std::endl; |
1694 | VLOG(1) << " PartitionConfig ::: funcName = " << partitionConfig.funcName |
1695 | << "\n" ; |
1696 | VLOG(1) << " PartitionConfig ::: numOfPartitions = " |
1697 | << partitionConfig.numOfPartitions << "\n" ; |
1698 | VLOG(1) << " PartitionConfig ::: partitionNames = " ; |
1699 | for (unsigned i = 0; i < partitionConfig.numOfPartitions; i++) { |
1700 | VLOG(1) << partitionConfig.partitionNames[i] << " " ; |
1701 | } |
1702 | VLOG(1) << "\n" ; |
1703 | VLOG(1) << " PartitionConfig ::: logicalIDs = " ; |
1704 | for (unsigned i = 0; i < partitionConfig.numOfPartitions; i++) { |
1705 | for (auto &id : partitionConfig.logicalIDs[i]) { |
1706 | VLOG(1) << id << " " ; |
1707 | } |
1708 | VLOG(1) << "\n" ; |
1709 | } |
1710 | |
1711 | DAGListTy partitions; |
1712 | ASSIGN_VALUE_OR_RETURN_ERR(partitions, |
1713 | partitionFromConfig(partitionConfig, cctx)); |
1714 | if (cctx.saturateHost) { |
1715 | saturateHost(snnNumCards, partitions, cctx.saturateKDevices); |
1716 | } |
1717 | return std::move(partitions); |
1718 | } |
1719 | |
1720 | Expected<DAGListTy> Partitioner::partition(CompilationContext &cctx) { |
1721 | if (cctx.prepartitionedConfig && |
1722 | cctx.prepartitionedConfig->funcs.size() != 0) { |
1723 | VLOG(1) << "Using prepartitioned config" ; |
1724 | return setupPrepartitionedModule(cctx); |
1725 | } |
1726 | |
1727 | if (cctx.partitionConfig) { |
1728 | VLOG(1) << "Using partition config" ; |
1729 | partitionConfig_ = *cctx.partitionConfig; |
1730 | } |
1731 | |
1732 | if (partitionConfig_.enabled()) { |
1733 | // Call user-defined partition flow. |
1734 | return partitionFromConfig(partitionConfig_, cctx); |
1735 | } |
1736 | |
1737 | if (!multiBackendNames_ && |
1738 | cctx.optimizationOpts.useSparseNNPartitioningScheme) { |
1739 | VLOG(1) << "Using SNN Partition Scheme" ; |
1740 | return partitionSparseNN(cctx); |
1741 | } |
1742 | |
1743 | if (cctx.precisionConfig.quantMode == QuantizationMode::Profile) { |
1744 | // Call quantization profiling partition flow. |
1745 | VLOG(1) << "Using QuantProfile Partition" ; |
1746 | return quantizationProfilingPartition(cctx); |
1747 | } |
1748 | |
1749 | if (!multiBackendNames_ && glow::flags::EnableLoadBalancedPartitioning) { |
1750 | // Call load-balance partition flow. |
1751 | VLOG(1) << "Using Load balance Partition" ; |
1752 | return loadBalancedPartition(cctx); |
1753 | } |
1754 | |
1755 | VLOG(1) << "Using Heterogenous Partition" ; |
1756 | // Call heterogeneous partition flow. |
1757 | return heterogeneousPartition(cctx); |
1758 | } |
1759 | |