1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #include "glow/Partitioner/PartitionerUtils.h" |
17 | #include "glow/Backend/BackendUtils.h" |
18 | #include "glow/Flags/Flags.h" |
19 | #include "glow/Partitioner/PartitionerTypes.h" |
20 | #include "glow/Support/Support.h" |
21 | |
22 | #include <unordered_set> |
23 | |
24 | using llvm::isa; |
25 | |
26 | namespace glow { |
27 | |
28 | namespace { |
29 | /// Used to sort 2 Nodes by their name, i.e. n1->name < n2->name order. |
30 | auto compFunc = [](const Node *n1, Node *n2) -> bool { |
31 | return n1->compareByName(*n2); |
32 | }; |
33 | constexpr uint32_t MB = 1024 * 1024; |
34 | } // namespace |
35 | |
/// Groups the nodes in function \p F into levels based on how far (the
/// longest distance) they are from the roots, where a root is a node with
/// no users. \returns the nodes grouped by level, level 0 being the roots.
BFSLevel getBFSLevel(Function *F) {
  // The current set of nodes needs to be visited
  std::unordered_set<Node *> cur;
  // A map between a node and its level. A node reachable through several
  // paths is overwritten with the deepest level seen, i.e. its longest
  // distance from a root.
  llvm::DenseMap<Node *, int> nodeLevel;

  // Get the roots set (i.e. the nodes without users).
  for (auto &node : F->getNodes()) {
    if (node.getNumUsers() == 0) {
      // A root has no users.
      cur.insert(&node);
      nodeLevel[&node] = 0;
    }
  }

  // Create the node to level map by traversing the nodes with BFS order.
  // `level` counts how many level buckets exist; `current` is the level
  // being expanded. The loop stops when expanding a level produces no new
  // (non-Storage) input nodes.
  BFSLevel bfs;
  int level = 0;
  int current = 0;
  bfs.push_back(std::vector<Node *>());
  level++;
  while (current < level) {
    std::unordered_set<Node *> nodes;
    for (std::unordered_set<Node *>::iterator it = cur.begin(); it != cur.end();
         ++it) {
      Node *N = *it;
      for (size_t j = 0, e = N->getNumInputs(); j < e; ++j) {
        Node *in = N->getNthInput(j).getNode();
        // Storage nodes (constants/placeholders) are not part of any level.
        if (isa<Storage>(in)) {
          continue;
        }
        nodes.insert(in);
        // Overwrite any earlier, shallower assignment: keep longest distance.
        nodeLevel[in] = level;
      }
    }
    if (nodes.size() > 0) {
      bfs.push_back(std::vector<Node *>());
      level++;
      cur = std::move(nodes);
    }
    current++;
  }

  // Based on the node to level map, group these nodes by levels.
  for (llvm::DenseMap<Node *, int>::iterator it = nodeLevel.begin();
       it != nodeLevel.end(); ++it) {
    Node *in = (*it).first;
    // Note: this `level` intentionally shadows the outer counter; it is the
    // node's final BFS depth.
    int level = (*it).second;
    bfs[level].push_back(in);
  }

  // Sort the nodes of each level by name to make sure the nodes sequence are
  // the same for different run.
  for (int i = 0; i < level; i++) {
    std::sort(bfs[i].begin(), bfs[i].end(), compFunc);
  }
  return bfs;
}
96 | |
97 | /// Given \p nodes, return a list of nodes who are not in this set but use any |
98 | /// node in this set. |
99 | std::vector<Node *> getOutUsers(const NodesSet &nodes) { |
100 | NodesSet used; |
101 | for (NodesSet::iterator it = nodes.begin(); it != nodes.end(); ++it) { |
102 | Node *cur = *it; |
103 | for (auto &U : cur->getUsers()) { |
104 | if (nodes.count(U.getUser())) { |
105 | continue; |
106 | } |
107 | used.insert(U.getUser()); |
108 | } |
109 | } |
110 | |
111 | std::vector<Node *> ret(used.begin(), used.end()); |
112 | std::sort(ret.begin(), ret.end(), compFunc); |
113 | return ret; |
114 | } |
115 | |
116 | /// Given \p nodes, return a list of nodes who are not in this set but use only |
117 | /// the nodes in this set or constant. |
118 | std::vector<Node *> getOutUsersWithOnePredecessor(const NodesSet &nodes) { |
119 | NodesSet used; |
120 | for (NodesSet::iterator it = nodes.begin(); it != nodes.end(); ++it) { |
121 | Node *cur = *it; |
122 | for (auto &U : cur->getUsers()) { |
123 | Node *user = U.getUser(); |
124 | if (nodes.count(user)) { |
125 | continue; |
126 | } |
127 | bool flag = true; |
128 | for (size_t i = 0, e = user->getNumInputs(); i < e; i++) { |
129 | Node *in = user->getNthInput(i).getNode(); |
130 | if (llvm::isa<Storage>(in) || nodes.count(in)) { |
131 | continue; |
132 | } |
133 | flag = false; |
134 | break; |
135 | } |
136 | if (flag) { |
137 | used.insert(user); |
138 | } |
139 | } |
140 | } |
141 | |
142 | std::vector<Node *> ret(used.begin(), used.end()); |
143 | std::sort(ret.begin(), ret.end(), compFunc); |
144 | return ret; |
145 | } |
146 | |
147 | /// \returns the memory usage of the output caused by \p node who has users not |
148 | /// in the set \p nodes. |
149 | uint64_t getOutMemPerNode(const NodesSet &nodes, const Node *node) { |
150 | uint64_t ret = 0; |
151 | for (size_t i = 0, e = node->getNumResults(); i < e; i++) { |
152 | NodeValue nodeVal = node->getNthResult(i); |
153 | for (auto &U : nodeVal.getUsers()) { |
154 | Node *user = U.getUser(); |
155 | if (nodes.find(const_cast<Node *>(user)) == nodes.end()) { |
156 | ret += node->getType(i)->getSizeInBytes(); |
157 | break; |
158 | } |
159 | } |
160 | } |
161 | return ret; |
162 | } |
163 | |
164 | /// Given a node, \return the NodeSet of all nodes that create the results |
165 | /// for any of the inputs of this node (i.e. input of inputs) |
166 | NodesSet getInputs(const Node *node) { |
167 | NodesSet result; |
168 | for (size_t i = 0, e = node->getNumInputs(); i < e; i++) { |
169 | Node *input = node->getNthInput(i).getNode(); |
170 | Storage *in = llvm::dyn_cast<Storage>(input); |
171 | if (!in) { |
172 | result.insert(input); |
173 | } |
174 | } |
175 | return result; |
176 | } |
177 | |
178 | uint64_t getNodeMemUsage(const Node *node) { |
179 | if (node->getKind() == Kinded::Kind::SaveNodeKind) { |
180 | return 0; |
181 | } |
182 | uint64_t size = 0; |
183 | for (size_t i = 0, e = node->getNumInputs(); i < e; i++) { |
184 | Storage *in = llvm::dyn_cast<Storage>(node->getNthInput(i).getNode()); |
185 | if (in) { |
186 | auto ty = in->getType(); |
187 | size += ty->getSizeInBytes(); |
188 | } |
189 | } |
190 | return size; |
191 | } |
192 | |
193 | float getNodeComputeTime(const Node *node, const BackendInfo &backendInfo) { |
194 | // This code assumes all ops are BW limited from SRAM; except |
195 | // if the input does not fit in SRAM -- then it is DRAM BW limited |
196 | float peakDramBw = backendInfo.peakDramBw; |
197 | float peakSramBw = backendInfo.peakSramBw; |
198 | uint64_t sramCapacity = backendInfo.sramCapacity; |
199 | float peakCompute = backendInfo.peakCompute; |
200 | |
201 | // compute memory side bytes for inputs from DRAM, SRAM. |
202 | // TODO: think about whether this is better off computed inside a Node. |
203 | |
204 | int n = node->getNumInputs(); |
205 | uint64_t sizeDram = 0; |
206 | uint64_t sizeSram = 0; |
207 | if (node->getKind() == Kinded::Kind::SaveNodeKind) { |
208 | return 0.0f; |
209 | } |
210 | // The memory bytes for embedding table lookups is data dependent, |
211 | // so it needs to be calculated as per the number of indices accessed. |
212 | if (node->getKind() == Kinded::Kind::SparseLengthsWeightedSumNodeKind) { |
213 | auto *SLWSN = llvm::dyn_cast<SparseLengthsWeightedSumNode>(node); |
214 | // compute how many entries of the embedding table we look up |
215 | auto numLookups = SLWSN->getIndices().dims().front(); |
216 | // compute how many bytes we read per lookup |
217 | auto tableSize = SLWSN->getData().getType()->getSizeInBytes(); |
218 | auto numRows = SLWSN->getData().dims().front(); |
219 | auto sizePerLookup = tableSize / numRows; |
220 | // compute total bytes read |
221 | uint64_t sizeInput = numLookups * sizePerLookup; |
222 | |
223 | // tables are usually large and fit in DRAM |
224 | sizeDram += sizeInput; |
225 | // we also read the indices, weights and lengths arrays |
226 | sizeSram += SLWSN->getIndices().getType()->getSizeInBytes(); |
227 | sizeSram += SLWSN->getWeights().getType()->getSizeInBytes(); |
228 | sizeSram += SLWSN->getLengths().getType()->getSizeInBytes(); |
229 | } else if (node->getKind() == Kinded::Kind::SparseLengthsSumNodeKind) { |
230 | auto *SLSN = llvm::dyn_cast<SparseLengthsSumNode>(node); |
231 | // compute how many entries of the embedding table we look up |
232 | auto numLookups = SLSN->getIndices().dims().front(); |
233 | // compute how many bytes we read per lookup |
234 | auto tableSize = SLSN->getData().getType()->getSizeInBytes(); |
235 | auto numRows = SLSN->getData().dims().front(); |
236 | auto sizePerLookup = tableSize / numRows; |
237 | // compute total bytes read |
238 | uint64_t sizeInput = numLookups * sizePerLookup; |
239 | |
240 | // tables are usually large and fit in DRAM |
241 | sizeDram += sizeInput; |
242 | // we also read the indices and lengths arrays |
243 | sizeSram += SLSN->getIndices().getType()->getSizeInBytes(); |
244 | sizeSram += SLSN->getLengths().getType()->getSizeInBytes(); |
245 | } else if (node->getKind() == |
246 | Kinded::Kind:: |
247 | FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind) { |
248 | auto *FRQSLWSN = |
249 | llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(node); |
250 | // compute how many entries of the embedding table we look up |
251 | auto numLookups = FRQSLWSN->getIndices().dims().front(); |
252 | // compute how many bytes we read per lookup |
253 | auto tableSize = FRQSLWSN->getData().getType()->getSizeInBytes(); |
254 | auto numRows = FRQSLWSN->getData().dims().front(); |
255 | auto sizePerLookup = tableSize / numRows; |
256 | // compute total bytes read |
257 | uint64_t sizeInput = numLookups * sizePerLookup; |
258 | |
259 | // tables are usually large and fit in DRAM |
260 | sizeDram += sizeInput; |
261 | |
262 | // we also read the indices, weights and lengths arrays |
263 | sizeSram += FRQSLWSN->getIndices().getType()->getSizeInBytes(); |
264 | sizeSram += FRQSLWSN->getWeights().getType()->getSizeInBytes(); |
265 | sizeSram += FRQSLWSN->getLengths().getType()->getSizeInBytes(); |
266 | } else if (node->getKind() == |
267 | Kinded::Kind::FusedRowwiseQuantizedSparseLengthsSumNodeKind) { |
268 | auto *FRQSLSN = |
269 | llvm::dyn_cast<FusedRowwiseQuantizedSparseLengthsSumNode>(node); |
270 | // compute how many entries of the embedding table we look up |
271 | auto numLookups = FRQSLSN->getIndices().dims().front(); |
272 | // compute how many bytes we read per lookup |
273 | auto tableSize = FRQSLSN->getData().getType()->getSizeInBytes(); |
274 | auto numRows = FRQSLSN->getData().dims().front(); |
275 | auto sizePerLookup = tableSize / numRows; |
276 | // compute total bytes read |
277 | uint64_t sizeInput = numLookups * sizePerLookup; |
278 | |
279 | // tables are usually large and fit in DRAM |
280 | sizeDram += sizeInput; |
281 | |
282 | // we also read the indices and lengths arrays |
283 | sizeSram += FRQSLSN->getIndices().getType()->getSizeInBytes(); |
284 | sizeSram += FRQSLSN->getLengths().getType()->getSizeInBytes(); |
285 | } else { |
286 | // for all other ops, iterate through all inputs and get size in bytes |
287 | for (int i = 0; i < n; i++) { |
288 | auto ty = node->getNthInput(i).getType(); |
289 | uint64_t sizeInput = ty->getSizeInBytes(); |
290 | if (sizeInput > sramCapacity) { |
291 | sizeDram += sizeInput; |
292 | } else { |
293 | sizeSram += sizeInput; |
294 | } |
295 | } |
296 | } |
297 | |
298 | // Repeat for outputs |
299 | for (size_t i = 0, e = node->getNumResults(); i < e; i++) { |
300 | auto myty = node->getType(i); |
301 | uint64_t sizeOutput = myty->getSizeInBytes(); |
302 | if (sizeOutput > sramCapacity) { |
303 | sizeDram += sizeOutput; |
304 | } else { |
305 | sizeSram += sizeOutput; |
306 | } |
307 | } |
308 | |
309 | // Calculate compute ops. Currently only computed for Matmul, Conv, FC |
310 | // TODO: think about whether this is better off computed inside a Node. |
311 | uint64_t totalOps = 0; |
312 | switch (node->getKind()) { |
313 | case Kinded::Kind::MatMulNodeKind: { |
314 | auto *MMN = llvm::dyn_cast<MatMulNode>(node); |
315 | auto lhsDims = MMN->getLHS().dims(); |
316 | auto rhsDims = MMN->getRHS().dims(); |
317 | totalOps = 2 * lhsDims[0] * lhsDims[1] * rhsDims[1]; |
318 | break; |
319 | } |
320 | case Kinded::Kind::FullyConnectedNodeKind: { |
321 | auto *FCN = llvm::dyn_cast<FullyConnectedNode>(node); |
322 | auto inputDims = FCN->getInput().dims(); |
323 | auto wtDims = FCN->getWeights().dims(); |
324 | totalOps = 2 * inputDims[0] * inputDims[1] * wtDims[0]; |
325 | break; |
326 | } |
327 | #ifdef GLOW_WITH_HABANA |
328 | case Kinded::Kind::HabanaFullyConnectedNodeKind: { |
329 | auto *FCN = llvm::dyn_cast<HabanaFullyConnectedNode>(node); |
330 | auto inputDims = FCN->getInput().dims(); |
331 | auto wtDims = FCN->getWeights().dims(); |
332 | totalOps = 2 * inputDims[0] * inputDims[1] * wtDims[0]; |
333 | break; |
334 | } |
335 | #endif |
336 | case Kinded::Kind::ConvolutionNodeKind: { |
337 | auto *CN = llvm::dyn_cast<ConvolutionNode>(node); |
338 | auto resultDims = CN->getResult().dims(); |
339 | // Get the product of batch, output height, output dims, output channels |
340 | totalOps = resultDims[0]; |
341 | for (size_t i = 1, e = resultDims.size(); i < e; i++) { |
342 | totalOps *= resultDims[i]; |
343 | } |
344 | // Multiply in kernel height, kernel width |
345 | auto kernelDims = CN->getKernels(); |
346 | totalOps *= kernelDims[0] * kernelDims[1]; |
347 | // Multiply in input channels/groups |
348 | auto inputChannels = CN->getInput().dims()[1]; |
349 | auto nGroups = CN->getGroup(); |
350 | totalOps *= (inputChannels * 1.0 / nGroups); |
351 | break; |
352 | } |
353 | #ifdef GLOW_WITH_HABANA |
354 | case Kinded::Kind::HabanaConvolutionNodeKind: { |
355 | auto *CN = llvm::dyn_cast<HabanaConvolutionNode>(node); |
356 | auto resultDims = CN->getResult().dims(); |
357 | // Get the product of batch, output height, output dims, output channels |
358 | totalOps = resultDims[0]; |
359 | for (size_t i = 1, e = resultDims.size(); i < e; i++) { |
360 | totalOps *= resultDims[i]; |
361 | } |
362 | // Multiply in kernel height, kernel width |
363 | auto kernelDims = CN->getKernels(); |
364 | totalOps *= kernelDims[0] * kernelDims[1]; |
365 | // Multiply in input channels/groups |
366 | auto inputChannels = CN->getInput().dims()[1]; |
367 | auto nGroups = CN->getGroup(); |
368 | totalOps *= (inputChannels * 1.0 / nGroups); |
369 | break; |
370 | } |
371 | #endif |
372 | default: |
373 | break; |
374 | } |
375 | |
376 | // Compute compute roofline as max of flops, DRAM, SRAM BW |
377 | // See https://bit.ly/2UdJ3mz |
378 | // Add epsilons to prevent seg faults on uninitialized peak values. |
379 | return std::max(totalOps * 1.0f / std::max(peakCompute, 1e-6f), |
380 | std::max(sizeDram * 1.0f / std::max(peakDramBw, 1e-6f), |
381 | sizeSram * 1.0f / std::max(peakSramBw, 1e-6f))); |
382 | } |
383 | |
/// Given nodes set \p currNodes and its memory usage info \p info, \returns
/// the new memory usage if \p newNode is added into \p currNodes. The input
/// set and info are not modified; a corrected copy is returned.
GraphMemInfo updateGraphMemInfoByAddingNode(const NodesSet &currNodes,
                                            const GraphMemInfo &info,
                                            Node *newNode) {
  GraphMemInfo ret = info;

  // Collect the used NodeValues (Storage nodes and outputs from the nodes
  // outside of currNodes). These are inputs already accounted for in `info`,
  // so they must not be double-counted below.
  std::set<NodeValue> usedNodeValue;
  for (auto N : currNodes) {
    for (size_t i = 0, e = N->getNumInputs(); i < e; i++) {
      NodeValue nodeVal = N->getNthInput(i);
      if (currNodes.count(nodeVal.getNode()) == 0) {
        usedNodeValue.insert(nodeVal);
      }
    }
  }
  // Calculate new outMemSize: re-derive it from scratch over the enlarged
  // set, since adding newNode can both add outputs and absorb old ones.
  NodesSet newNodes = currNodes;
  newNodes.insert(newNode);
  uint64_t newSize = 0;
  for (auto *node : newNodes) {
    if (auto *SN = llvm::dyn_cast<SaveNode>(node)) {
      // SaveNode is a special case since it has no users but always writes out.
      newSize += SN->getOutput().getType()->getSizeInBytes();
    } else {
      newSize += getOutMemPerNode(newNodes, node);
    }
  }
  ret.outMemSize = newSize;

  // The memory usage changes due to newNode's inputs:
  for (size_t i = 0, e = newNode->getNumInputs(); i < e; i++) {
    // A SaveNode's Output operand is its destination placeholder, not a data
    // input; it is already counted via outMemSize above.
    if (llvm::isa<SaveNode>(newNode) && i == SaveNode::OutputIdx) {
      continue;
    }
    NodeValue nodeVal = newNode->getNthInput(i);
    Node *N = nodeVal.getNode();

    if (usedNodeValue.count(nodeVal)) {
      // This input has been considered already, nothing to do.
      continue;
    }

    Storage *in = llvm::dyn_cast<Storage>(N);
    if (in) {
      // Node uses placeholders or constants which are not used in this set
      // before, need to add the memory.
      uint64_t size = in->getType()->getSizeInBytes();
      if (in->getKind() == Kinded::Kind::ConstantKind) {
        ret.constMemSize += size;
      } else {
        // Not a Constant, so it must be a Placeholder.
        Placeholder *ph = llvm::dyn_cast<Placeholder>(N);
        // If PH is static treat like a constant.
        if (ph->isStatic()) {
          ret.constMemSize += size;
          ret.deferredConstMemSize += size;
        } else {
          // PlaceHolder for Input.
          ret.inMemSize += size;
          ret.inputCount += 1;
        }
      }
      usedNodeValue.insert(nodeVal);
      continue;
    }

    if (!currNodes.count(N)) {
      // In this case, this input is not a storage type node nor belongs
      // to this subgraph. Therefore, when creating partitions, we need to add
      // a PlaceHolder for the data from outside.
      ret.inMemSize += nodeVal.getType()->getSizeInBytes();
      ret.inputCount += 1;
      usedNodeValue.insert(nodeVal);
    }
  }

  // newNode's own outputs consumed inside currNodes were previously counted
  // as partition inputs; remove that contribution now.
  for (size_t i = 0, e = newNode->getNumResults(); i < e; i++) {
    auto nodeVal = newNode->getNthResult(i);
    for (auto &U : nodeVal.getUsers()) {
      if (currNodes.count(U.getUser()) == 0) {
        // The nodeVal (i.e. the ith output of newNode) is not used in
        // currNodes:
        continue;
      }
      // Assume newNode -> node1, where node1 belongs to currNodes set. Before
      // newNode is added, node1's input size (from newNode) should be added
      // into inMemSize. But after newNode is added, the input size should be
      // removed.
      ret.inMemSize -= nodeVal.getType()->getSizeInBytes();
      ret.inputCount -= 1;
      break;
    }
  }

  return ret;
}
482 | |
483 | GraphMemInfo getGraphMemInfo(const NodesSet &nodes, unsigned contextCount) { |
484 | GraphMemInfo ret; |
485 | ret.contextCount = contextCount; |
486 | NodesSet nodeSet; |
487 | for (NodesSet::iterator it = nodes.begin(); it != nodes.end(); ++it) { |
488 | Node *cur = *it; |
489 | ret = updateGraphMemInfoByAddingNode(nodeSet, ret, cur); |
490 | nodeSet.insert(cur); |
491 | } |
492 | return ret; |
493 | } |
494 | |
495 | GraphMemInfo getFunctionMemory(Function *func) { |
496 | GraphMemInfo graphMem; |
497 | |
498 | for (auto cons : func->findConstants()) { |
499 | graphMem.constMemSize += cons->getType()->getSizeInBytes(); |
500 | } |
501 | |
502 | // Gather all other functions in the module for peer resource usage counting. |
503 | std::vector<const Function *> otherFuns; |
504 | std::copy_if(func->getParent()->getFunctions().begin(), |
505 | func->getParent()->getFunctions().end(), |
506 | std::back_inserter(otherFuns), |
507 | [func](Function *F) { return func != F; }); |
508 | |
509 | // Walk thru all Placeholders in the function to accumulate input and |
510 | // output mem size. These utility functions check the users of the PH to |
511 | // determine if the PH is an input or an output. |
512 | for (auto &place : func->findPlaceholders()) { |
513 | if (place->isStatic()) { |
514 | graphMem.constMemSize += place->getType()->getSizeInBytes(); |
515 | graphMem.deferredConstMemSize += place->getType()->getSizeInBytes(); |
516 | } else { |
517 | if (isInput(place, *func)) { |
518 | graphMem.inMemSize += place->getType()->getSizeInBytes(); |
519 | graphMem.inputCount += 1; |
520 | // Check if this placeholder is the output of a peer function. |
521 | if (isOutput(place, otherFuns)) { |
522 | graphMem.inputFromPeerCount += 1; |
523 | } |
524 | } |
525 | if (isOutput(place, *func)) { |
526 | graphMem.outMemSize += place->getType()->getSizeInBytes(); |
527 | } |
528 | } |
529 | } |
530 | |
531 | return graphMem; |
532 | } |
533 | |
534 | std::set<Kinded::Kind> generateNodeKindsSet(llvm::StringRef names) { |
535 | std::set<Kinded::Kind> nodeKindsSet; |
536 | llvm::StringRef::size_type pos = names.find(','); |
537 | while (pos != llvm::StringRef::npos) { |
538 | nodeKindsSet.insert(getKindFromNodeName(names.substr(0, pos))); |
539 | names = names.substr(pos + 1); |
540 | pos = names.find(','); |
541 | } |
542 | if (!names.empty()) { |
543 | nodeKindsSet.insert(getKindFromNodeName(names)); |
544 | } |
545 | return nodeKindsSet; |
546 | } |
547 | |
548 | void logPartitionInfo(const NodeToFunctionMap &partitions) { |
549 | int i = 0; |
550 | for (Function *subF : partitions.getPartitions()) { |
551 | LOG(INFO) << "\t Partition " << i++ << ":\n" |
552 | << "\t\t Name :\t" << subF->getName().str() << "\n" |
553 | << "\t\t BackendKind :\t" |
554 | << partitions.getPartitionBackendName(subF) << "\n" |
555 | << "\t\t context count :\t" |
556 | << partitions.getGraphMemInfo(subF).contextCount << "\n" |
557 | << "\t\t total Memory :\t" |
558 | << partitions.getGraphMemInfo(subF).getTotalMemSize() << "\n" |
559 | << "\t\t\t input size:\t" |
560 | << partitions.getGraphMemInfo(subF).inMemSize << "\n" |
561 | << "\t\t\t input count :\t" |
562 | << partitions.getGraphMemInfo(subF).inputCount << "\n" |
563 | << "\t\t\t input only from peers count :\t" |
564 | << partitions.getGraphMemInfo(subF).inputFromPeerCount << "\n" |
565 | << "\t\t\t output size:\t" |
566 | << partitions.getGraphMemInfo(subF).outMemSize << "\n" |
567 | << "\t\t\t constant size:\t" |
568 | << partitions.getGraphMemInfo(subF).constMemSize << "\n" |
569 | << "\t\t\t\t non-deferred constant size:\t" |
570 | << partitions.getGraphMemInfo(subF).constMemSize - |
571 | partitions.getGraphMemInfo(subF).deferredConstMemSize |
572 | << "\n" |
573 | << "\t\t\t\t deferred constant size:\t" |
574 | << partitions.getGraphMemInfo(subF).deferredConstMemSize << "\n" ; |
575 | // This may be called before logicalDevices are assigned so check before |
576 | // printing. |
577 | if (partitions.getLogicalDeviceIDList(subF).size()) { |
578 | LOG(INFO) << "\t\t LogicalDeviceIDs :\t" |
579 | << partitions.getLogicalDeviceIDList(subF)[0] << "\n" ; |
580 | } |
581 | } |
582 | } |
583 | |
584 | void printSlsTableInfo(std::vector<SLSTableInfo>::iterator start, |
585 | std::vector<SLSTableInfo>::iterator end, |
586 | bool verbose_only) { |
587 | if (start >= end) { |
588 | return; |
589 | } |
590 | std::stringstream ss; |
591 | ss << "(numBytesInTable(MB), deviceID, cost, cost/numBytesInTable) " |
592 | << strFormat(" - %zu tables -" , end - start) << "\n" ; |
593 | while (start < end) { |
594 | const auto tableSizeInMB = (float)start->numBytesInTable / MB; |
595 | const auto costPerByte = tableSizeInMB == 0 |
596 | ? "nan" |
597 | : std::to_string(start->cost / tableSizeInMB); |
598 | ss << " " << tableSizeInMB << " " << start->deviceId |
599 | << " " << start->cost << " " << costPerByte << std::endl; |
600 | start++; |
601 | } |
602 | if (verbose_only) { |
603 | VLOG(1) << ss.str(); |
604 | } else { |
605 | LOG(INFO) << ss.str(); |
606 | } |
607 | } |
608 | |
609 | void printSlsTableInfo(std::vector<SLSTableInfo> &slsTables, |
610 | bool verbose_only) { |
611 | printSlsTableInfo(slsTables.begin(), slsTables.end(), verbose_only); |
612 | } |
613 | |
614 | void printSlsDeviceInfo(const std::vector<SLSDeviceInfo> &slsDevices, |
615 | const std::vector<NodesSet> &nodesets, |
616 | const unsigned contextCount, bool verbose_only) { |
617 | std::stringstream ss; |
618 | ss << "(deviceId, used_memory(MB), free_memory(MB), cost, " |
619 | "node_size, cost/used_memory)" |
620 | << strFormat(" - %zu devices -" , slsDevices.size()) << "\n" ; |
621 | for (const auto &d : slsDevices) { |
622 | const auto deviceId = d.deviceId; |
623 | const auto meminfo = getGraphMemInfo(nodesets[deviceId], contextCount); |
624 | const auto usedMem = (float)meminfo.getTotalMemSize() / MB; |
625 | const auto availMem = (float)d.memAvailableInBytes / MB; |
626 | const auto freeMem = availMem - usedMem; |
627 | const auto costPerUsedMemory = |
628 | usedMem == 0 ? "nan" : std::to_string(d.currentCost / usedMem); |
629 | ss << " " << deviceId << " " << usedMem << " " << freeMem |
630 | << " " << d.currentCost << " " << nodesets[deviceId].size() |
631 | << " " << costPerUsedMemory << "\n" ; |
632 | } |
633 | if (verbose_only) { |
634 | VLOG(1) << ss.str(); |
635 | } else { |
636 | LOG(INFO) << ss.str(); |
637 | } |
638 | } |
639 | |
640 | bool isSLSNode(const Node *node) { |
641 | return ( |
642 | node->getKind() == |
643 | glow::Kinded::Kind:: |
644 | FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind || |
645 | node->getKind() == |
646 | glow::Kinded::Kind::FusedRowwiseQuantizedSparseLengthsSumNodeKind || |
647 | node->getKind() == glow::Kinded::Kind:: |
648 | RowwiseQuantizedSparseLengthsWeightedSumNodeKind || |
649 | node->getKind() == glow::Kinded::Kind::SparseLengthsSumNodeKind || |
650 | node->getKind() == glow::Kinded::Kind::SparseLengthsWeightedSumNodeKind || |
651 | node->getKind() == glow::Kinded::Kind::EmbeddingBagNodeKind || |
652 | node->getKind() == |
653 | glow::Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind); |
654 | } |
655 | |
656 | bool checkNodeInputsAllKind(const Node *node, glow::Kinded::Kind kind) { |
657 | bool allSameKind = true; |
658 | for (auto i = 0; i < node->getNumInputs(); i++) { |
659 | auto nodeInput = node->getNthInput(i); |
660 | allSameKind &= nodeInput.getNode()->getKind() == kind; |
661 | } |
662 | return allSameKind; |
663 | } |
664 | |
/// Assigns \p table to the first device in \p slsDevices with enough free
/// memory to hold the table plus its neighbor nodes, updating that device's
/// cost, node set (\p nodesets) and boundary values (\p frontierValues).
/// SLS neighbors merged along the way are recorded in \p addedSLSNodes so a
/// later table for the same node reuses the same device. \returns an error
/// when no device has sufficient memory.
Error assignSlsTableToFirstAvailableDevice(
    SLSTableInfo &table, std::vector<SLSDeviceInfo> &slsDevices,
    std::vector<NodesSet> &nodesets,
    std::vector<std::unordered_set<NodeValue>> &frontierValues,
    const unsigned contextCount,
    std::unordered_map<Node *, size_t> &addedSLSNodes) {
  // The three per-device containers are indexed by deviceId in parallel.
  DCHECK(slsDevices.size() == nodesets.size() &&
         slsDevices.size() == frontierValues.size());
  // If this table's node was already merged into a partition as some other
  // table's neighbor, just reuse that device.
  auto addedNodeDeviceId = addedSLSNodes.find(table.node);
  if (addedNodeDeviceId != addedSLSNodes.end()) {
    table.deviceId = addedNodeDeviceId->second;
    return Error::success();
  }

  bool deviceFound = false;
  for (auto &d : slsDevices) {
    const auto deviceId = d.deviceId;
    // Calculate the memory needed if we merge SLS and its neighboring nodes
    // into existing partition. Work on a copy of the node set so nothing is
    // committed unless the device has room.
    auto nodesSetd = nodesets[deviceId];
    nodesSetd.insert(table.node);
    nodesSetd.insert(table.neighbors.begin(), table.neighbors.end());
    auto meminfo = getGraphMemInfo(nodesSetd, contextCount);
    const auto totalSize = meminfo.getTotalMemSize();
    if (d.memAvailableInBytes >= totalSize) {
      // Commit: charge the cost, record the assignment and the new frontier.
      d.currentCost += (size_t)table.cost;
      table.deviceId = deviceId;
      frontierValues[deviceId].insert(table.frontier.begin(),
                                      table.frontier.end());
      // Remember merged SLS neighbors so their own tables land here too.
      for (auto &nb : table.neighbors) {
        if (isSLSNode(nb)) {
          addedSLSNodes.insert({nb, deviceId});
        }
      }
      nodesets[deviceId].swap(nodesSetd);
      deviceFound = true;
      break;
    }
  }
  if (!deviceFound) {
    return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR,
                    "SLS Balancing Partitioning Error: Not enough memory");
  }
  return Error::success();
}
710 | |
711 | Error assignSlsTablesToDevices( |
712 | std::vector<SLSTableInfo> &slsTables, |
713 | std::vector<SLSDeviceInfo> &slsDevices, |
714 | std::vector<std::unordered_set<NodeValue>> &frontierValues, |
715 | const unsigned contextCount) { |
716 | if (slsTables.empty()) { |
717 | LOG(INFO) << "SLS tables empty!" ; |
718 | return Error::success(); |
719 | } |
720 | // Keep a copy of input parameters, so that ScopeGuard could restore |
721 | // inputs in case of error. |
722 | std::vector<SLSTableInfo> slsTablesCopy = slsTables; |
723 | std::vector<SLSDeviceInfo> slsDevicesCopy = slsDevices; |
724 | std::vector<std::unordered_set<NodeValue>> frontierValuesCopy = |
725 | frontierValues; |
726 | ScopeGuard restoreInputsOnError([&]() { |
727 | slsTables.swap(slsTablesCopy); |
728 | slsDevices.swap(slsDevicesCopy); |
729 | frontierValues.swap(frontierValuesCopy); |
730 | }); |
731 | |
732 | // Now sort SLS tables by size decreasing |
733 | VLOG(1) << "SLS tables sorted by size decreasing" ; |
734 | std::sort(slsTables.begin(), slsTables.end(), |
735 | [](const SLSTableInfo &l, const SLSTableInfo &r) { |
736 | return l.numBytesInTable > r.numBytesInTable; |
737 | }); |
738 | |
739 | // slsTables is in sorted order decreasingly by numBytesInTable. |
740 | // The tables between [slsTablesLeft, slsTableRight) are large tables that |
741 | // have numBytesInTable > BigTableThresholdBytes. |
742 | // slsTablesLeft and slsTablesRight will be both pointed to slsTables.begin() |
743 | // if we could not find any large tables. |
744 | auto slsTablesLeft = slsTables.begin(); |
745 | auto slsTableRight = slsTables.end(); |
746 | if (slsTablesLeft->numBytesInTable > |
747 | glow::runtime::flags::BigTableThresholdBytes) { |
748 | for (auto it = slsTables.begin(); it < slsTables.end(); it++) { |
749 | if (it->numBytesInTable <= glow::runtime::flags::BigTableThresholdBytes) { |
750 | slsTableRight = it; |
751 | break; |
752 | } |
753 | } |
754 | } else { |
755 | // No large table found. |
756 | slsTablesLeft = slsTables.begin(); |
757 | slsTableRight = slsTables.begin(); |
758 | } |
759 | |
760 | // We first assign large tables to devices. After allocation, each device |
761 | // should has roughly the same size. |
762 | LOG(INFO) << strFormat("Now assign %zu large tables to %zu devices." , |
763 | (slsTableRight - slsTablesLeft), slsDevices.size()); |
764 | // Print Large SLS tables |
765 | VLOG(1) << "Large tables by size decreasing: " ; |
766 | printSlsTableInfo(slsTablesLeft, slsTableRight); |
767 | std::vector<NodesSet> nodesets(slsDevices.size()); |
768 | std::unordered_map<Node *, size_t> addedSLSNodes; |
769 | while (slsTablesLeft < slsTableRight) { |
770 | // Sort devices by size increasingly. |
771 | std::sort(slsDevices.begin(), slsDevices.end(), |
772 | [&nodesets, contextCount](const SLSDeviceInfo &l, |
773 | const SLSDeviceInfo &r) { |
774 | auto lTotalSize = |
775 | getGraphMemInfo(nodesets[l.deviceId], contextCount) |
776 | .getTotalMemSize(); |
777 | auto rTotalSize = |
778 | getGraphMemInfo(nodesets[r.deviceId], contextCount) |
779 | .getTotalMemSize(); |
780 | return lTotalSize < rTotalSize; |
781 | }); |
782 | VLOG(1) << "Devices sorted by used memory increasing: " ; |
783 | printSlsDeviceInfo(slsDevices, nodesets, contextCount, |
784 | true /* verbose_only */); |
785 | |
786 | // Pick the first that fits |
787 | auto &table = *slsTablesLeft; |
788 | RETURN_IF_ERR(assignSlsTableToFirstAvailableDevice( |
789 | table, slsDevices, nodesets, frontierValues, contextCount, |
790 | addedSLSNodes)); |
791 | slsTablesLeft++; |
792 | } |
793 | VLOG(1) << "Done assigning large tables, devices info: " ; |
794 | printSlsDeviceInfo(slsDevices, nodesets, contextCount, |
795 | true /* verbose_only */); |
796 | |
797 | // Now let us assign small size tables. |
798 | // First sort tables by cost decreasingly. For each table, we would like to |
799 | // assign it to the device with lowest cost. |
800 | LOG(INFO) << strFormat("Now assign %zu small tables to %zu devices." , |
801 | (slsTables.end() - slsTablesLeft), slsDevices.size()); |
802 | if (slsTablesLeft < slsTables.end()) { |
803 | std::sort(slsTablesLeft, slsTables.end(), |
804 | [](const SLSTableInfo &l, const SLSTableInfo &r) { |
805 | return l.cost > r.cost; |
806 | }); |
807 | } |
808 | VLOG(1) << "Small tables by cost decreasingly: " ; |
809 | printSlsTableInfo(slsTablesLeft, slsTables.end()); |
810 | |
811 | while (slsTablesLeft < slsTables.end()) { |
812 | // Sort devices by cost increasingly. |
813 | std::sort(slsDevices.begin(), slsDevices.end(), |
814 | [](const SLSDeviceInfo &l, const SLSDeviceInfo &r) { |
815 | return l.currentCost < r.currentCost; |
816 | }); |
817 | |
818 | VLOG(1) << "Devices sorted by cost increasing: " ; |
819 | printSlsDeviceInfo(slsDevices, nodesets, contextCount, |
820 | true /* verbose_only */); |
821 | |
822 | // Pick the first that fits |
823 | auto &table = *slsTablesLeft; |
824 | RETURN_IF_ERR(assignSlsTableToFirstAvailableDevice( |
825 | table, slsDevices, nodesets, frontierValues, contextCount, |
826 | addedSLSNodes)); |
827 | slsTablesLeft++; |
828 | } |
829 | // Print final device info |
830 | LOG(INFO) << "Done assigning small tables, final devices info: " ; |
831 | printSlsDeviceInfo(slsDevices, nodesets, contextCount, |
832 | false /* verbose_only */); |
833 | restoreInputsOnError.dismiss(); |
834 | return Error::success(); |
835 | } |
836 | |
837 | Error assignSlsTablesToDevicesGreedy( |
838 | std::vector<SLSTableInfo> &slsTables, |
839 | std::vector<SLSDeviceInfo> &slsDevices, |
840 | std::vector<std::unordered_set<NodeValue>> &frontierValues, |
841 | const unsigned contextCount) { |
842 | if (slsTables.empty()) { |
843 | LOG(INFO) << "SLS tables empty!" ; |
844 | return Error::success(); |
845 | } |
846 | // Keep a copy of input parameters, so that ScopeGuard could restore |
847 | // inputs in case of error. |
848 | std::vector<SLSTableInfo> slsTablesCopy = slsTables; |
849 | std::vector<SLSDeviceInfo> slsDevicesCopy = slsDevices; |
850 | std::vector<std::unordered_set<NodeValue>> frontierValuesCopy = |
851 | frontierValues; |
852 | ScopeGuard restoreInputsOnError([&]() { |
853 | slsTables.swap(slsTablesCopy); |
854 | slsDevices.swap(slsDevicesCopy); |
855 | frontierValues.swap(frontierValuesCopy); |
856 | }); |
857 | |
858 | // Now sort SLS tables by size decreasing |
859 | VLOG(1) << "SLS tables sorted by size decreasing" << std::endl; |
860 | std::sort(slsTables.begin(), slsTables.end(), |
861 | [](const SLSTableInfo &l, const SLSTableInfo &r) { |
862 | return l.numBytesInTable > r.numBytesInTable; |
863 | }); |
864 | |
865 | // Print SLS tables |
866 | printSlsTableInfo(slsTables); |
867 | |
868 | // Now assign SLS Nodes to devices |
869 | std::vector<NodesSet> nodesets(slsDevices.size()); |
870 | std::unordered_map<Node *, size_t> addedSLSNodes; |
871 | for (auto &table : slsTables) { |
872 | |
873 | // Sort by cost increasing |
874 | std::sort(slsDevices.begin(), slsDevices.end(), |
875 | [](const SLSDeviceInfo &l, const SLSDeviceInfo &r) { |
876 | return l.currentCost < r.currentCost; |
877 | }); |
878 | |
879 | VLOG(1) << "Devices sorted by cost increasing" << std::endl; |
880 | printSlsDeviceInfo(slsDevices, nodesets, contextCount, |
881 | true /* verbose_only */); |
882 | |
883 | // Pick the first that fits |
884 | RETURN_IF_ERR(assignSlsTableToFirstAvailableDevice( |
885 | table, slsDevices, nodesets, frontierValues, contextCount, |
886 | addedSLSNodes)); |
887 | } |
888 | // Print final device info |
889 | LOG(INFO) << "Devices sorted by cost increasing: " ; |
890 | printSlsDeviceInfo(slsDevices, nodesets, contextCount, |
891 | false /* verbose_only */); |
892 | restoreInputsOnError.dismiss(); |
893 | return Error::success(); |
894 | } |
895 | |
896 | } // namespace glow |
897 | |