/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "glow/Partitioner/PartitionerUtils.h"
#include "glow/Backend/BackendUtils.h"
#include "glow/Flags/Flags.h"
#include "glow/Partitioner/PartitionerTypes.h"
#include "glow/Support/Support.h"

#include <unordered_set>

using llvm::isa;

namespace glow {

namespace {
/// Used to sort two Nodes by name, i.e. in n1->name < n2->name order.
auto compFunc = [](const Node *n1, Node *n2) -> bool {
  return n1->compareByName(*n2);
};
constexpr uint32_t MB = 1024 * 1024;
} // namespace

/// The nodes in function \p F are grouped into levels based on how far (the
/// longest distance) they are from the roots.
BFSLevel getBFSLevel(Function *F) {
  // The current set of nodes that need to be visited.
  std::unordered_set<Node *> cur;
  // A map between a node and its level.
  llvm::DenseMap<Node *, int> nodeLevel;

  // Get the root set (i.e. the nodes without users).
  for (auto &node : F->getNodes()) {
    if (node.getNumUsers() == 0) {
      // A root has no users.
      cur.insert(&node);
      nodeLevel[&node] = 0;
    }
  }

  // Create the node-to-level map by traversing the nodes in BFS order.
  BFSLevel bfs;
  int level = 0;
  int current = 0;
  bfs.push_back(std::vector<Node *>());
  level++;
  while (current < level) {
    std::unordered_set<Node *> nodes;
    for (std::unordered_set<Node *>::iterator it = cur.begin(); it != cur.end();
         ++it) {
      Node *N = *it;
      for (size_t j = 0, e = N->getNumInputs(); j < e; ++j) {
        Node *in = N->getNthInput(j).getNode();
        if (isa<Storage>(in)) {
          continue;
        }
        nodes.insert(in);
        nodeLevel[in] = level;
      }
    }
    if (nodes.size() > 0) {
      bfs.push_back(std::vector<Node *>());
      level++;
      cur = std::move(nodes);
    }
    current++;
  }

  // Based on the node-to-level map, group these nodes by level.
  for (llvm::DenseMap<Node *, int>::iterator it = nodeLevel.begin();
       it != nodeLevel.end(); ++it) {
    Node *in = (*it).first;
    int level = (*it).second;
    bfs[level].push_back(in);
  }

  // Sort the nodes of each level by name to make sure the node sequence is
  // the same across different runs.
  for (int i = 0; i < level; i++) {
    std::sort(bfs[i].begin(), bfs[i].end(), compFunc);
  }
  return bfs;
}
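
// Example usage (illustrative sketch; assumes a Function *F built elsewhere
// that ends in a SaveNode). Roots are the nodes without users, so SaveNodes
// land on level 0 and their transitive inputs land on later levels:
//
//   BFSLevel levels = getBFSLevel(F);
//   for (size_t lvl = 0; lvl < levels.size(); ++lvl) {
//     for (const Node *n : levels[lvl]) {
//       LOG(INFO) << "level " << lvl << ": " << n->getName().str();
//     }
//   }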

/// Given \p nodes, return a list of nodes that are not in this set but use any
/// node in this set.
std::vector<Node *> getOutUsers(const NodesSet &nodes) {
  NodesSet used;
  for (NodesSet::iterator it = nodes.begin(); it != nodes.end(); ++it) {
    Node *cur = *it;
    for (auto &U : cur->getUsers()) {
      if (nodes.count(U.getUser())) {
        continue;
      }
      used.insert(U.getUser());
    }
  }

  std::vector<Node *> ret(used.begin(), used.end());
  std::sort(ret.begin(), ret.end(), compFunc);
  return ret;
}

/// Given \p nodes, return a list of nodes that are not in this set but whose
/// inputs are all either in this set or Storage (constants/placeholders).
std::vector<Node *> getOutUsersWithOnePredecessor(const NodesSet &nodes) {
  NodesSet used;
  for (NodesSet::iterator it = nodes.begin(); it != nodes.end(); ++it) {
    Node *cur = *it;
    for (auto &U : cur->getUsers()) {
      Node *user = U.getUser();
      if (nodes.count(user)) {
        continue;
      }
      bool flag = true;
      for (size_t i = 0, e = user->getNumInputs(); i < e; i++) {
        Node *in = user->getNthInput(i).getNode();
        if (llvm::isa<Storage>(in) || nodes.count(in)) {
          continue;
        }
        flag = false;
        break;
      }
      if (flag) {
        used.insert(user);
      }
    }
  }

  std::vector<Node *> ret(used.begin(), used.end());
  std::sort(ret.begin(), ret.end(), compFunc);
  return ret;
}
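
// Example (illustrative sketch; A and C are hypothetical nodes). If C consumes
// one output of A (inside the set) and one output of a node outside the set,
// then for nodes = {A}:
//
//   std::vector<Node *> anyUse = getOutUsers(nodes);                    // contains C
//   std::vector<Node *> soleUse = getOutUsersWithOnePredecessor(nodes); // does not
//
// i.e. the second variant only returns users whose every non-Storage input is
// already produced inside \p nodes.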

/// \returns the total size in bytes of the results of \p node that have at
/// least one user outside the set \p nodes; each such result is counted once.
uint64_t getOutMemPerNode(const NodesSet &nodes, const Node *node) {
  uint64_t ret = 0;
  for (size_t i = 0, e = node->getNumResults(); i < e; i++) {
    NodeValue nodeVal = node->getNthResult(i);
    for (auto &U : nodeVal.getUsers()) {
      Node *user = U.getUser();
      if (nodes.find(const_cast<Node *>(user)) == nodes.end()) {
        ret += node->getType(i)->getSizeInBytes();
        break;
      }
    }
  }
  return ret;
}
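
// Worked example (illustrative): a node with a single float result of shape
// {8, 16} occupies 8 * 16 * 4 = 512 bytes. If that result has any user outside
// \p nodes, getOutMemPerNode() adds the 512 bytes exactly once, no matter how
// many external users there are.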

/// Given a node, \returns the set of all non-Storage nodes that produce the
/// inputs of this node.
NodesSet getInputs(const Node *node) {
  NodesSet result;
  for (size_t i = 0, e = node->getNumInputs(); i < e; i++) {
    Node *input = node->getNthInput(i).getNode();
    Storage *in = llvm::dyn_cast<Storage>(input);
    if (!in) {
      result.insert(input);
    }
  }
  return result;
}

uint64_t getNodeMemUsage(const Node *node) {
  if (node->getKind() == Kinded::Kind::SaveNodeKind) {
    return 0;
  }
  uint64_t size = 0;
  for (size_t i = 0, e = node->getNumInputs(); i < e; i++) {
    Storage *in = llvm::dyn_cast<Storage>(node->getNthInput(i).getNode());
    if (in) {
      auto ty = in->getType();
      size += ty->getSizeInBytes();
    }
  }
  return size;
}
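
// Example (illustrative sketch): for a FullyConnectedNode whose Weights and
// Bias come from Constants of 4 MB and 4 KB while its Input comes from another
// Node, getNodeMemUsage() returns 4 MB + 4 KB -- only Storage inputs are
// counted, and SaveNodes always report 0.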

float getNodeComputeTime(const Node *node, const BackendInfo &backendInfo) {
  // This code assumes all ops are BW limited from SRAM, except when the input
  // does not fit in SRAM -- then the op is DRAM BW limited.
  float peakDramBw = backendInfo.peakDramBw;
  float peakSramBw = backendInfo.peakSramBw;
  uint64_t sramCapacity = backendInfo.sramCapacity;
  float peakCompute = backendInfo.peakCompute;

  // Compute the memory-side bytes for inputs from DRAM and SRAM.
  // TODO: think about whether this is better off computed inside a Node.

  int n = node->getNumInputs();
  uint64_t sizeDram = 0;
  uint64_t sizeSram = 0;
  if (node->getKind() == Kinded::Kind::SaveNodeKind) {
    return 0.0f;
  }
  // The memory bytes for embedding table lookups are data dependent, so they
  // need to be calculated based on the number of indices accessed.
  if (node->getKind() == Kinded::Kind::SparseLengthsWeightedSumNodeKind) {
    auto *SLWSN = llvm::dyn_cast<SparseLengthsWeightedSumNode>(node);
    // Compute how many entries of the embedding table we look up.
    auto numLookups = SLWSN->getIndices().dims().front();
    // Compute how many bytes we read per lookup.
    auto tableSize = SLWSN->getData().getType()->getSizeInBytes();
    auto numRows = SLWSN->getData().dims().front();
    auto sizePerLookup = tableSize / numRows;
    // Compute the total bytes read.
    uint64_t sizeInput = numLookups * sizePerLookup;

    // Tables are usually large and reside in DRAM.
    sizeDram += sizeInput;
    // We also read the indices, weights and lengths arrays.
    sizeSram += SLWSN->getIndices().getType()->getSizeInBytes();
    sizeSram += SLWSN->getWeights().getType()->getSizeInBytes();
    sizeSram += SLWSN->getLengths().getType()->getSizeInBytes();
  } else if (node->getKind() == Kinded::Kind::SparseLengthsSumNodeKind) {
    auto *SLSN = llvm::dyn_cast<SparseLengthsSumNode>(node);
    // Compute how many entries of the embedding table we look up.
    auto numLookups = SLSN->getIndices().dims().front();
    // Compute how many bytes we read per lookup.
    auto tableSize = SLSN->getData().getType()->getSizeInBytes();
    auto numRows = SLSN->getData().dims().front();
    auto sizePerLookup = tableSize / numRows;
    // Compute the total bytes read.
    uint64_t sizeInput = numLookups * sizePerLookup;

    // Tables are usually large and reside in DRAM.
    sizeDram += sizeInput;
    // We also read the indices and lengths arrays.
    sizeSram += SLSN->getIndices().getType()->getSizeInBytes();
    sizeSram += SLSN->getLengths().getType()->getSizeInBytes();
  } else if (node->getKind() ==
             Kinded::Kind::
                 FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind) {
    auto *FRQSLWSN =
        llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(node);
    // Compute how many entries of the embedding table we look up.
    auto numLookups = FRQSLWSN->getIndices().dims().front();
    // Compute how many bytes we read per lookup.
    auto tableSize = FRQSLWSN->getData().getType()->getSizeInBytes();
    auto numRows = FRQSLWSN->getData().dims().front();
    auto sizePerLookup = tableSize / numRows;
    // Compute the total bytes read.
    uint64_t sizeInput = numLookups * sizePerLookup;

    // Tables are usually large and reside in DRAM.
    sizeDram += sizeInput;

    // We also read the indices, weights and lengths arrays.
    sizeSram += FRQSLWSN->getIndices().getType()->getSizeInBytes();
    sizeSram += FRQSLWSN->getWeights().getType()->getSizeInBytes();
    sizeSram += FRQSLWSN->getLengths().getType()->getSizeInBytes();
  } else if (node->getKind() ==
             Kinded::Kind::FusedRowwiseQuantizedSparseLengthsSumNodeKind) {
    auto *FRQSLSN =
        llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsSumNode>(node);
    // Compute how many entries of the embedding table we look up.
    auto numLookups = FRQSLSN->getIndices().dims().front();
    // Compute how many bytes we read per lookup.
    auto tableSize = FRQSLSN->getData().getType()->getSizeInBytes();
    auto numRows = FRQSLSN->getData().dims().front();
    auto sizePerLookup = tableSize / numRows;
    // Compute the total bytes read.
    uint64_t sizeInput = numLookups * sizePerLookup;

    // Tables are usually large and reside in DRAM.
    sizeDram += sizeInput;

    // We also read the indices and lengths arrays.
    sizeSram += FRQSLSN->getIndices().getType()->getSizeInBytes();
    sizeSram += FRQSLSN->getLengths().getType()->getSizeInBytes();
  } else {
    // For all other ops, iterate through all inputs and get their size in
    // bytes.
    for (int i = 0; i < n; i++) {
      auto ty = node->getNthInput(i).getType();
      uint64_t sizeInput = ty->getSizeInBytes();
      if (sizeInput > sramCapacity) {
        sizeDram += sizeInput;
      } else {
        sizeSram += sizeInput;
      }
    }
  }

  // Repeat for outputs.
  for (size_t i = 0, e = node->getNumResults(); i < e; i++) {
    auto myty = node->getType(i);
    uint64_t sizeOutput = myty->getSizeInBytes();
    if (sizeOutput > sramCapacity) {
      sizeDram += sizeOutput;
    } else {
      sizeSram += sizeOutput;
    }
  }

  // Calculate compute ops. Currently only computed for MatMul, Conv and FC.
  // TODO: think about whether this is better off computed inside a Node.
  uint64_t totalOps = 0;
  switch (node->getKind()) {
  case Kinded::Kind::MatMulNodeKind: {
    auto *MMN = llvm::dyn_cast<MatMulNode>(node);
    auto lhsDims = MMN->getLHS().dims();
    auto rhsDims = MMN->getRHS().dims();
    totalOps = 2 * lhsDims[0] * lhsDims[1] * rhsDims[1];
    break;
  }
  case Kinded::Kind::FullyConnectedNodeKind: {
    auto *FCN = llvm::dyn_cast<FullyConnectedNode>(node);
    auto inputDims = FCN->getInput().dims();
    auto wtDims = FCN->getWeights().dims();
    totalOps = 2 * inputDims[0] * inputDims[1] * wtDims[0];
    break;
  }
#ifdef GLOW_WITH_HABANA
  case Kinded::Kind::HabanaFullyConnectedNodeKind: {
    auto *FCN = llvm::dyn_cast<HabanaFullyConnectedNode>(node);
    auto inputDims = FCN->getInput().dims();
    auto wtDims = FCN->getWeights().dims();
    totalOps = 2 * inputDims[0] * inputDims[1] * wtDims[0];
    break;
  }
#endif
  case Kinded::Kind::ConvolutionNodeKind: {
    auto *CN = llvm::dyn_cast<ConvolutionNode>(node);
    auto resultDims = CN->getResult().dims();
    // Get the product of batch, output height, output width, output channels.
    totalOps = resultDims[0];
    for (size_t i = 1, e = resultDims.size(); i < e; i++) {
      totalOps *= resultDims[i];
    }
    // Multiply in kernel height, kernel width.
    auto kernelDims = CN->getKernels();
    totalOps *= kernelDims[0] * kernelDims[1];
    // Multiply in input channels/groups.
    auto inputChannels = CN->getInput().dims()[1];
    auto nGroups = CN->getGroup();
    totalOps *= (inputChannels * 1.0 / nGroups);
    break;
  }
#ifdef GLOW_WITH_HABANA
  case Kinded::Kind::HabanaConvolutionNodeKind: {
    auto *CN = llvm::dyn_cast<HabanaConvolutionNode>(node);
    auto resultDims = CN->getResult().dims();
    // Get the product of batch, output height, output width, output channels.
    totalOps = resultDims[0];
    for (size_t i = 1, e = resultDims.size(); i < e; i++) {
      totalOps *= resultDims[i];
    }
    // Multiply in kernel height, kernel width.
    auto kernelDims = CN->getKernels();
    totalOps *= kernelDims[0] * kernelDims[1];
    // Multiply in input channels/groups.
    auto inputChannels = CN->getInput().dims()[1];
    auto nGroups = CN->getGroup();
    totalOps *= (inputChannels * 1.0 / nGroups);
    break;
  }
#endif
  default:
    break;
  }

  // Compute the roofline time as the max of the compute (FLOPs), DRAM BW and
  // SRAM BW estimates. See https://bit.ly/2UdJ3mz
  // Add epsilons to guard against division by zero when peak values are
  // uninitialized.
  return std::max(totalOps * 1.0f / std::max(peakCompute, 1e-6f),
                  std::max(sizeDram * 1.0f / std::max(peakDramBw, 1e-6f),
                           sizeSram * 1.0f / std::max(peakSramBw, 1e-6f)));
}
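
// Worked example (illustrative numbers, not taken from any real backend): a
// float32 MatMul of {1024 x 1024} * {1024 x 1024} performs
//   totalOps = 2 * 1024 * 1024 * 1024 ~= 2.15e9 ops
// and moves about 3 * 1024 * 1024 * 4 bytes ~= 1.26e7 bytes of operands and
// results. Assuming everything fits in SRAM, with peakCompute = 1e13 ops/s and
// peakSramBw = 1e12 B/s the roofline estimate is
//   max(2.15e9 / 1e13, 1.26e7 / 1e12) ~= max(0.215 ms, 0.013 ms) = 0.215 ms,
// i.e. this shape is compute bound on the assumed device.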

/// Given the node set \p currNodes and its memory usage info \p info,
/// \returns the new memory usage if \p newNode is added into \p currNodes.
GraphMemInfo updateGraphMemInfoByAddingNode(const NodesSet &currNodes,
                                            const GraphMemInfo &info,
                                            Node *newNode) {
  GraphMemInfo ret = info;

  // Collect the used NodeValues (Storage nodes and outputs from the nodes
  // outside of currNodes).
  std::set<NodeValue> usedNodeValue;
  for (auto N : currNodes) {
    for (size_t i = 0, e = N->getNumInputs(); i < e; i++) {
      NodeValue nodeVal = N->getNthInput(i);
      if (currNodes.count(nodeVal.getNode()) == 0) {
        usedNodeValue.insert(nodeVal);
      }
    }
  }
  // Calculate the new outMemSize.
  NodesSet newNodes = currNodes;
  newNodes.insert(newNode);
  uint64_t newSize = 0;
  for (auto *node : newNodes) {
    if (auto *SN = llvm::dyn_cast<SaveNode>(node)) {
      // SaveNode is a special case since it has no users but always writes out.
      newSize += SN->getOutput().getType()->getSizeInBytes();
    } else {
      newSize += getOutMemPerNode(newNodes, node);
    }
  }
  ret.outMemSize = newSize;

  // The memory usage changes due to newNode's inputs:
  for (size_t i = 0, e = newNode->getNumInputs(); i < e; i++) {
    if (llvm::isa<SaveNode>(newNode) && i == SaveNode::OutputIdx) {
      continue;
    }
    NodeValue nodeVal = newNode->getNthInput(i);
    Node *N = nodeVal.getNode();

    if (usedNodeValue.count(nodeVal)) {
      // This input has been considered already, nothing to do.
      continue;
    }

    Storage *in = llvm::dyn_cast<Storage>(N);
    if (in) {
      // The node uses placeholders or constants which were not used in this
      // set before, so we need to add the memory.
      uint64_t size = in->getType()->getSizeInBytes();
      if (in->getKind() == Kinded::Kind::ConstantKind) {
        ret.constMemSize += size;
      } else {
        Placeholder *ph = llvm::dyn_cast<Placeholder>(N);
        // If the PH is static, treat it like a constant.
        if (ph->isStatic()) {
          ret.constMemSize += size;
          ret.deferredConstMemSize += size;
        } else {
          // Placeholder for an input.
          ret.inMemSize += size;
          ret.inputCount += 1;
        }
      }
      usedNodeValue.insert(nodeVal);
      continue;
    }

    if (!currNodes.count(N)) {
      // In this case, this input is not a storage type node nor does it belong
      // to this subgraph. Therefore, when creating partitions, we need to add
      // a Placeholder for the data from outside.
      ret.inMemSize += nodeVal.getType()->getSizeInBytes();
      ret.inputCount += 1;
      usedNodeValue.insert(nodeVal);
    }
  }

  for (size_t i = 0, e = newNode->getNumResults(); i < e; i++) {
    auto nodeVal = newNode->getNthResult(i);
    for (auto &U : nodeVal.getUsers()) {
      if (currNodes.count(U.getUser()) == 0) {
        // This user of nodeVal (i.e. the ith output of newNode) is not in
        // currNodes; skip it.
        continue;
      }
      // Assume newNode -> node1, where node1 belongs to the currNodes set.
      // Before newNode is added, node1's input size (from newNode) should be
      // added into inMemSize. But after newNode is added, the input size
      // should be removed.
      ret.inMemSize -= nodeVal.getType()->getSizeInBytes();
      ret.inputCount -= 1;
      break;
    }
  }

  return ret;
}
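
// Example usage (illustrative sketch; mirrors how getGraphMemInfo() below
// builds the stats incrementally -- candidateNodes is hypothetical):
//
//   NodesSet partition;
//   GraphMemInfo info;
//   for (Node *n : candidateNodes) {
//     info = updateGraphMemInfoByAddingNode(partition, info, n);
//     partition.insert(n);
//   }
//   // info.inMemSize / outMemSize / constMemSize now describe `partition`.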

GraphMemInfo getGraphMemInfo(const NodesSet &nodes, unsigned contextCount) {
  GraphMemInfo ret;
  ret.contextCount = contextCount;
  NodesSet nodeSet;
  for (NodesSet::iterator it = nodes.begin(); it != nodes.end(); ++it) {
    Node *cur = *it;
    ret = updateGraphMemInfoByAddingNode(nodeSet, ret, cur);
    nodeSet.insert(cur);
  }
  return ret;
}

GraphMemInfo getFunctionMemory(Function *func) {
  GraphMemInfo graphMem;

  for (auto cons : func->findConstants()) {
    graphMem.constMemSize += cons->getType()->getSizeInBytes();
  }

  // Gather all other functions in the module for peer resource usage counting.
  std::vector<const Function *> otherFuns;
  std::copy_if(func->getParent()->getFunctions().begin(),
               func->getParent()->getFunctions().end(),
               std::back_inserter(otherFuns),
               [func](Function *F) { return func != F; });

  // Walk through all Placeholders in the function to accumulate the input and
  // output mem sizes. These utility functions check the users of the PH to
  // determine if the PH is an input or an output.
  for (auto &place : func->findPlaceholders()) {
    if (place->isStatic()) {
      graphMem.constMemSize += place->getType()->getSizeInBytes();
      graphMem.deferredConstMemSize += place->getType()->getSizeInBytes();
    } else {
      if (isInput(place, *func)) {
        graphMem.inMemSize += place->getType()->getSizeInBytes();
        graphMem.inputCount += 1;
        // Check if this placeholder is the output of a peer function.
        if (isOutput(place, otherFuns)) {
          graphMem.inputFromPeerCount += 1;
        }
      }
      if (isOutput(place, *func)) {
        graphMem.outMemSize += place->getType()->getSizeInBytes();
      }
    }
  }

  return graphMem;
}
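
// Example (illustrative sketch; subF stands for any partitioned Function):
//
//   GraphMemInfo m = getFunctionMemory(subF);
//   LOG(INFO) << "total bytes: " << m.getTotalMemSize();
//
// getTotalMemSize() aggregates the per-category sizes accumulated above.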

std::set<Kinded::Kind> generateNodeKindsSet(llvm::StringRef names) {
  std::set<Kinded::Kind> nodeKindsSet;
  llvm::StringRef::size_type pos = names.find(',');
  while (pos != llvm::StringRef::npos) {
    nodeKindsSet.insert(getKindFromNodeName(names.substr(0, pos)));
    names = names.substr(pos + 1);
    pos = names.find(',');
  }
  if (!names.empty()) {
    nodeKindsSet.insert(getKindFromNodeName(names));
  }
  return nodeKindsSet;
}
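
// Example (illustrative): a comma-separated list of node names is mapped to
// the corresponding kinds, e.g.
//
//   auto kinds = generateNodeKindsSet("Convolution,FullyConnected");
//   // kinds == {Kinded::Kind::ConvolutionNodeKind,
//   //           Kinded::Kind::FullyConnectedNodeKind}
//
// assuming the names match what getKindFromNodeName() expects.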

void logPartitionInfo(const NodeToFunctionMap &partitions) {
  int i = 0;
  for (Function *subF : partitions.getPartitions()) {
    LOG(INFO) << "\t Partition " << i++ << ":\n"
              << "\t\t Name :\t" << subF->getName().str() << "\n"
              << "\t\t BackendKind :\t"
              << partitions.getPartitionBackendName(subF) << "\n"
              << "\t\t context count :\t"
              << partitions.getGraphMemInfo(subF).contextCount << "\n"
              << "\t\t total Memory :\t"
              << partitions.getGraphMemInfo(subF).getTotalMemSize() << "\n"
              << "\t\t\t input size:\t"
              << partitions.getGraphMemInfo(subF).inMemSize << "\n"
              << "\t\t\t input count :\t"
              << partitions.getGraphMemInfo(subF).inputCount << "\n"
              << "\t\t\t input only from peers count :\t"
              << partitions.getGraphMemInfo(subF).inputFromPeerCount << "\n"
              << "\t\t\t output size:\t"
              << partitions.getGraphMemInfo(subF).outMemSize << "\n"
              << "\t\t\t constant size:\t"
              << partitions.getGraphMemInfo(subF).constMemSize << "\n"
              << "\t\t\t\t non-deferred constant size:\t"
              << partitions.getGraphMemInfo(subF).constMemSize -
                     partitions.getGraphMemInfo(subF).deferredConstMemSize
              << "\n"
              << "\t\t\t\t deferred constant size:\t"
              << partitions.getGraphMemInfo(subF).deferredConstMemSize << "\n";
    // This may be called before logicalDevices are assigned so check before
    // printing.
    if (partitions.getLogicalDeviceIDList(subF).size()) {
      LOG(INFO) << "\t\t LogicalDeviceIDs :\t"
                << partitions.getLogicalDeviceIDList(subF)[0] << "\n";
    }
  }
}

void printSlsTableInfo(std::vector<SLSTableInfo>::iterator start,
                       std::vector<SLSTableInfo>::iterator end,
                       bool verbose_only) {
  if (start >= end) {
    return;
  }
  std::stringstream ss;
  ss << "(numBytesInTable(MB), deviceID, cost, cost/numBytesInTable) "
     << strFormat(" - %zu tables -", end - start) << "\n";
  while (start < end) {
    const auto tableSizeInMB = (float)start->numBytesInTable / MB;
    const auto costPerByte = tableSizeInMB == 0
                                 ? "nan"
                                 : std::to_string(start->cost / tableSizeInMB);
    ss << " " << tableSizeInMB << " " << start->deviceId
       << " " << start->cost << " " << costPerByte << std::endl;
    start++;
  }
  if (verbose_only) {
    VLOG(1) << ss.str();
  } else {
    LOG(INFO) << ss.str();
  }
}

void printSlsTableInfo(std::vector<SLSTableInfo> &slsTables,
                       bool verbose_only) {
  printSlsTableInfo(slsTables.begin(), slsTables.end(), verbose_only);
}

void printSlsDeviceInfo(const std::vector<SLSDeviceInfo> &slsDevices,
                        const std::vector<NodesSet> &nodesets,
                        const unsigned contextCount, bool verbose_only) {
  std::stringstream ss;
  ss << "(deviceId, used_memory(MB), free_memory(MB), cost, "
        "node_size, cost/used_memory)"
     << strFormat(" - %zu devices -", slsDevices.size()) << "\n";
  for (const auto &d : slsDevices) {
    const auto deviceId = d.deviceId;
    const auto meminfo = getGraphMemInfo(nodesets[deviceId], contextCount);
    const auto usedMem = (float)meminfo.getTotalMemSize() / MB;
    const auto availMem = (float)d.memAvailableInBytes / MB;
    const auto freeMem = availMem - usedMem;
    const auto costPerUsedMemory =
        usedMem == 0 ? "nan" : std::to_string(d.currentCost / usedMem);
    ss << " " << deviceId << " " << usedMem << " " << freeMem
       << " " << d.currentCost << " " << nodesets[deviceId].size()
       << " " << costPerUsedMemory << "\n";
  }
  if (verbose_only) {
    VLOG(1) << ss.str();
  } else {
    LOG(INFO) << ss.str();
  }
}

bool isSLSNode(const Node *node) {
  return (
      node->getKind() ==
          glow::Kinded::Kind::
              FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind ||
      node->getKind() ==
          glow::Kinded::Kind::FusedRowwiseQuantizedSparseLengthsSumNodeKind ||
      node->getKind() == glow::Kinded::Kind::
                             RowwiseQuantizedSparseLengthsWeightedSumNodeKind ||
      node->getKind() == glow::Kinded::Kind::SparseLengthsSumNodeKind ||
      node->getKind() == glow::Kinded::Kind::SparseLengthsWeightedSumNodeKind ||
      node->getKind() == glow::Kinded::Kind::EmbeddingBagNodeKind ||
      node->getKind() ==
          glow::Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind);
}

bool checkNodeInputsAllKind(const Node *node, glow::Kinded::Kind kind) {
  bool allSameKind = true;
  for (auto i = 0; i < node->getNumInputs(); i++) {
    auto nodeInput = node->getNthInput(i);
    allSameKind &= nodeInput.getNode()->getKind() == kind;
  }
  return allSameKind;
}
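
// Example (illustrative): checkNodeInputsAllKind() can test whether a node is
// fed exclusively by, say, Quantize nodes:
//
//   bool allQuantized =
//       checkNodeInputsAllKind(node, Kinded::Kind::QuantizeNodeKind);
//
// Note that a node with zero inputs trivially returns true.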

Error assignSlsTableToFirstAvailableDevice(
    SLSTableInfo &table, std::vector<SLSDeviceInfo> &slsDevices,
    std::vector<NodesSet> &nodesets,
    std::vector<std::unordered_set<NodeValue>> &frontierValues,
    const unsigned contextCount,
    std::unordered_map<Node *, size_t> &addedSLSNodes) {
  DCHECK(slsDevices.size() == nodesets.size() &&
         slsDevices.size() == frontierValues.size());
  auto addedNodeDeviceId = addedSLSNodes.find(table.node);
  if (addedNodeDeviceId != addedSLSNodes.end()) {
    table.deviceId = addedNodeDeviceId->second;
    return Error::success();
  }

  bool deviceFound = false;
  for (auto &d : slsDevices) {
    const auto deviceId = d.deviceId;
    // Calculate the memory needed if we merge the SLS node and its neighboring
    // nodes into the existing partition.
    auto nodesSetd = nodesets[deviceId];
    nodesSetd.insert(table.node);
    nodesSetd.insert(table.neighbors.begin(), table.neighbors.end());
    auto meminfo = getGraphMemInfo(nodesSetd, contextCount);
    const auto totalSize = meminfo.getTotalMemSize();
    if (d.memAvailableInBytes >= totalSize) {
      d.currentCost += (size_t)table.cost;
      table.deviceId = deviceId;
      frontierValues[deviceId].insert(table.frontier.begin(),
                                      table.frontier.end());
      for (auto &nb : table.neighbors) {
        if (isSLSNode(nb)) {
          addedSLSNodes.insert({nb, deviceId});
        }
      }
      nodesets[deviceId].swap(nodesSetd);
      deviceFound = true;
      break;
    }
  }
  if (!deviceFound) {
    return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR,
                    "SLS Balancing Partitioning Error: Not enough memory");
  }
  return Error::success();
}

Error assignSlsTablesToDevices(
    std::vector<SLSTableInfo> &slsTables,
    std::vector<SLSDeviceInfo> &slsDevices,
    std::vector<std::unordered_set<NodeValue>> &frontierValues,
    const unsigned contextCount) {
  if (slsTables.empty()) {
    LOG(INFO) << "SLS tables empty!";
    return Error::success();
  }
  // Keep a copy of the input parameters so that the ScopeGuard can restore
  // the inputs in case of error.
  std::vector<SLSTableInfo> slsTablesCopy = slsTables;
  std::vector<SLSDeviceInfo> slsDevicesCopy = slsDevices;
  std::vector<std::unordered_set<NodeValue>> frontierValuesCopy =
      frontierValues;
  ScopeGuard restoreInputsOnError([&]() {
    slsTables.swap(slsTablesCopy);
    slsDevices.swap(slsDevicesCopy);
    frontierValues.swap(frontierValuesCopy);
  });

  // Now sort the SLS tables by size, decreasing.
  VLOG(1) << "SLS tables sorted by size decreasing";
  std::sort(slsTables.begin(), slsTables.end(),
            [](const SLSTableInfo &l, const SLSTableInfo &r) {
              return l.numBytesInTable > r.numBytesInTable;
            });

  // slsTables is now sorted in decreasing order of numBytesInTable.
  // The tables in [slsTablesLeft, slsTableRight) are the large tables, i.e.
  // those with numBytesInTable > BigTableThresholdBytes.
  // Both slsTablesLeft and slsTableRight point to slsTables.begin() if no
  // large tables are found.
  auto slsTablesLeft = slsTables.begin();
  auto slsTableRight = slsTables.end();
  if (slsTablesLeft->numBytesInTable >
      glow::runtime::flags::BigTableThresholdBytes) {
    for (auto it = slsTables.begin(); it < slsTables.end(); it++) {
      if (it->numBytesInTable <= glow::runtime::flags::BigTableThresholdBytes) {
        slsTableRight = it;
        break;
      }
    }
  } else {
    // No large table found.
    slsTablesLeft = slsTables.begin();
    slsTableRight = slsTables.begin();
  }

  // We first assign the large tables to devices. After allocation, each device
  // should hold roughly the same number of bytes.
  LOG(INFO) << strFormat("Now assign %zu large tables to %zu devices.",
                         (slsTableRight - slsTablesLeft), slsDevices.size());
  // Print the large SLS tables.
  VLOG(1) << "Large tables by size decreasing: ";
  printSlsTableInfo(slsTablesLeft, slsTableRight);
  std::vector<NodesSet> nodesets(slsDevices.size());
  std::unordered_map<Node *, size_t> addedSLSNodes;
  while (slsTablesLeft < slsTableRight) {
    // Sort devices by used memory, increasing.
    std::sort(slsDevices.begin(), slsDevices.end(),
              [&nodesets, contextCount](const SLSDeviceInfo &l,
                                        const SLSDeviceInfo &r) {
                auto lTotalSize =
                    getGraphMemInfo(nodesets[l.deviceId], contextCount)
                        .getTotalMemSize();
                auto rTotalSize =
                    getGraphMemInfo(nodesets[r.deviceId], contextCount)
                        .getTotalMemSize();
                return lTotalSize < rTotalSize;
              });
    VLOG(1) << "Devices sorted by used memory increasing: ";
    printSlsDeviceInfo(slsDevices, nodesets, contextCount,
                       true /* verbose_only */);

    // Pick the first device that fits.
    auto &table = *slsTablesLeft;
    RETURN_IF_ERR(assignSlsTableToFirstAvailableDevice(
        table, slsDevices, nodesets, frontierValues, contextCount,
        addedSLSNodes));
    slsTablesLeft++;
  }
  VLOG(1) << "Done assigning large tables, devices info: ";
  printSlsDeviceInfo(slsDevices, nodesets, contextCount,
                     true /* verbose_only */);

  // Now assign the small tables.
  // First sort the tables by cost, decreasing. Each table is then assigned to
  // the device with the lowest accumulated cost.
  LOG(INFO) << strFormat("Now assign %zu small tables to %zu devices.",
                         (slsTables.end() - slsTablesLeft), slsDevices.size());
  if (slsTablesLeft < slsTables.end()) {
    std::sort(slsTablesLeft, slsTables.end(),
              [](const SLSTableInfo &l, const SLSTableInfo &r) {
                return l.cost > r.cost;
              });
  }
  VLOG(1) << "Small tables by cost decreasingly: ";
  printSlsTableInfo(slsTablesLeft, slsTables.end());

  while (slsTablesLeft < slsTables.end()) {
    // Sort devices by cost, increasing.
    std::sort(slsDevices.begin(), slsDevices.end(),
              [](const SLSDeviceInfo &l, const SLSDeviceInfo &r) {
                return l.currentCost < r.currentCost;
              });

    VLOG(1) << "Devices sorted by cost increasing: ";
    printSlsDeviceInfo(slsDevices, nodesets, contextCount,
                       true /* verbose_only */);

    // Pick the first device that fits.
    auto &table = *slsTablesLeft;
    RETURN_IF_ERR(assignSlsTableToFirstAvailableDevice(
        table, slsDevices, nodesets, frontierValues, contextCount,
        addedSLSNodes));
    slsTablesLeft++;
  }
  // Print the final device info.
  LOG(INFO) << "Done assigning small tables, final devices info: ";
  printSlsDeviceInfo(slsDevices, nodesets, contextCount,
                     false /* verbose_only */);
  restoreInputsOnError.dismiss();
  return Error::success();
}
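
// Example call site (illustrative sketch; the vectors would be populated by
// the SLS partitioning pass that uses these helpers):
//
//   std::vector<SLSTableInfo> tables = ...;   // one entry per SLS table/node
//   std::vector<SLSDeviceInfo> devices = ...; // one entry per device
//   std::vector<std::unordered_set<NodeValue>> frontiers(devices.size());
//   RETURN_IF_ERR(assignSlsTablesToDevices(tables, devices, frontiers,
//                                          /* contextCount */ 1));
//   // On success, each tables[i].deviceId holds the chosen device.
//
// assignSlsTablesToDevicesGreedy() below is the simpler single-pass variant
// that balances only by accumulated cost rather than splitting large and small
// tables into separate phases.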

Error assignSlsTablesToDevicesGreedy(
    std::vector<SLSTableInfo> &slsTables,
    std::vector<SLSDeviceInfo> &slsDevices,
    std::vector<std::unordered_set<NodeValue>> &frontierValues,
    const unsigned contextCount) {
  if (slsTables.empty()) {
    LOG(INFO) << "SLS tables empty!";
    return Error::success();
  }
  // Keep a copy of the input parameters so that the ScopeGuard can restore
  // the inputs in case of error.
  std::vector<SLSTableInfo> slsTablesCopy = slsTables;
  std::vector<SLSDeviceInfo> slsDevicesCopy = slsDevices;
  std::vector<std::unordered_set<NodeValue>> frontierValuesCopy =
      frontierValues;
  ScopeGuard restoreInputsOnError([&]() {
    slsTables.swap(slsTablesCopy);
    slsDevices.swap(slsDevicesCopy);
    frontierValues.swap(frontierValuesCopy);
  });

  // Now sort the SLS tables by size, decreasing.
  VLOG(1) << "SLS tables sorted by size decreasing" << std::endl;
  std::sort(slsTables.begin(), slsTables.end(),
            [](const SLSTableInfo &l, const SLSTableInfo &r) {
              return l.numBytesInTable > r.numBytesInTable;
            });

  // Print the SLS tables.
  printSlsTableInfo(slsTables);

  // Now assign the SLS nodes to devices.
  std::vector<NodesSet> nodesets(slsDevices.size());
  std::unordered_map<Node *, size_t> addedSLSNodes;
  for (auto &table : slsTables) {

    // Sort devices by cost, increasing.
    std::sort(slsDevices.begin(), slsDevices.end(),
              [](const SLSDeviceInfo &l, const SLSDeviceInfo &r) {
                return l.currentCost < r.currentCost;
              });

    VLOG(1) << "Devices sorted by cost increasing" << std::endl;
    printSlsDeviceInfo(slsDevices, nodesets, contextCount,
                       true /* verbose_only */);

    // Pick the first device that fits.
    RETURN_IF_ERR(assignSlsTableToFirstAvailableDevice(
        table, slsDevices, nodesets, frontierValues, contextCount,
        addedSLSNodes));
  }
  // Print the final device info.
  LOG(INFO) << "Devices sorted by cost increasing: ";
  printSlsDeviceInfo(slsDevices, nodesets, contextCount,
                     false /* verbose_only */);
  restoreInputsOnError.dismiss();
  return Error::success();
}

} // namespace glow