1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Optimizer/GraphOptimizer/NodeSplitting.h"
18#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
19
20#include <algorithm>
21#include <numeric>
22#include <unordered_map>
23#include <unordered_set>
24#include <vector>
25
26using namespace glow;
27using llvm::dyn_cast;
28
29///===---------------------------------------------------------------------===//
30/// SplitNodeOption
31///===---------------------------------------------------------------------===//
32size_t SplitNodeOptionOrthogonal::getSplitDimIdx(dim_t splitDim) const {
33 auto splitDimsIt = std::find(splitDims_.begin(), splitDims_.end(), splitDim);
34 CHECK(splitDimsIt != splitDims_.end())
35 << "Split dimension '" << splitDim
36 << "' invalid! Not registered in this SplitNodeOption!";
37 return std::distance(splitDims_.begin(), splitDimsIt);
38}
39
40std::vector<dim_t> SplitNodeByNumChunks::splitAlongDim(size_t dim,
41 dim_t dimSize) const {
42 size_t dimIdx = getSplitDimIdx(dim);
43 dim_t numChunks = numChunks_[dimIdx];
44 CHECK((1 <= numChunks) && (numChunks <= dimSize))
45 << "SplitNodeByNumChunks: Invalid number of chunks '" << numChunks
46 << "' for splitting a dimension with size '" << dimSize << "'!";
47
48 dim_t chunkDiv = dimSize / numChunks;
49 dim_t chunkRem = dimSize % numChunks;
50
51 // Small and big chunk sizes.
52 dim_t numChunksBig = chunkRem;
53 dim_t chunkSizeSmall = chunkDiv;
54 dim_t chunkSizeBig = chunkDiv + 1;
55
56 // Split dimension.
57 std::vector<dim_t> chunkSizes(numChunks);
58 for (size_t idx = 0, end = numChunks; idx < end; ++idx) {
59 dim_t chunkSize = chunkSizeSmall;
60 if (bigChunksFirst_ && (idx < numChunksBig)) {
61 chunkSize = chunkSizeBig;
62 }
63 if ((!bigChunksFirst_ && (idx >= numChunks - numChunksBig))) {
64 chunkSize = chunkSizeBig;
65 }
66 chunkSizes[idx] = chunkSize;
67 }
68 return chunkSizes;
69}
70
71std::vector<dim_t> SplitNodeByChunkSize::splitAlongDim(size_t dim,
72 dim_t dimSize) const {
73 size_t dimIdx = getSplitDimIdx(dim);
74 dim_t chunkSize = chunkSizes_[dimIdx];
75 CHECK((1 <= chunkSize) && (chunkSize <= dimSize))
76 << "SplitNodeByChunkSize: Invalid chunk size '" << chunkSize
77 << "' for splitting a dimension with size '" << dimSize << "'!";
78
79 dim_t chunkDiv = dimSize / chunkSize;
80 dim_t chunkRem = dimSize % chunkSize;
81
82 // Small and big chunk sizes.
83 dim_t numChunks = chunkRem > 0 ? chunkDiv + 1 : chunkDiv;
84 dim_t chunkSizeSmall = chunkRem > 0 ? chunkRem : chunkSize;
85 dim_t chunkSizeBig = chunkSize;
86
87 // Split dimension.
88 std::vector<dim_t> chunkSizes(numChunks);
89 for (size_t idx = 0, end = numChunks; idx < end; ++idx) {
90 dim_t chunkSizeFinal = chunkSizeBig;
91 if (bigChunksFirst_ && (idx == numChunks - 1)) {
92 chunkSizeFinal = chunkSizeSmall;
93 }
94 if (!bigChunksFirst_ && (idx == 0)) {
95 chunkSizeFinal = chunkSizeSmall;
96 }
97 chunkSizes[idx] = chunkSizeFinal;
98 }
99 return chunkSizes;
100}
101
102std::vector<dim_t> SplitNodeByChunkSizes::splitAlongDim(size_t dim,
103 dim_t dimSize) const {
104 size_t dimIdx = getSplitDimIdx(dim);
105 std::vector<dim_t> chunkSizes = chunkSizes_[dimIdx];
106 size_t numChunks = chunkSizes.size();
107 CHECK((1 <= numChunks) && (numChunks <= dimSize))
108 << "SplitNodeByChunkSizes: Invalid number of sizes '" << numChunks
109 << "' for splitting a dimension with size '" << dimSize << "'!";
110 for (const auto &chunkSize : chunkSizes) {
111 CHECK_GT(chunkSize, 0)
112 << "SplitNodeByChunkSizes: Chunk size 0 is not allowed!";
113 }
114 return chunkSizes;
115}
116
117std::vector<dim_t> SplitNodeByChunkWeights::splitAlongDim(size_t dim,
118 dim_t dimSize) const {
119 size_t dimIdx = getSplitDimIdx(dim);
120 const std::vector<float> &chunkWeights = chunkWeights_[dimIdx];
121 dim_t numChunks = chunkWeights.size();
122 CHECK((1 <= numChunks) && (numChunks <= dimSize))
123 << "SplitNodeByChunkWeights: Invalid number of weights '" << numChunks
124 << "' for splitting a dimension with size '" << dimSize << "'!";
125
126 // Verify that all the weights are positive and compute the weights sum.
127 float chunkWeightsSum = 0;
128 for (const auto &weight : chunkWeights) {
129 CHECK_GT(weight, 0.f) << "SplitNodeByChunkWeights: Chunk weight '" << weight
130 << "' invalid! Should be strictly positive!";
131 chunkWeightsSum += weight;
132 }
133
134 // Compute individual chunk sizes such that each chunk gets at least one unit
135 // (empty chunks are NOT allowed). The total number of units is distributed
136 // such that the error between the given chunk weights and the actual chunk
137 // weights is minimized.
138 std::vector<dim_t> chunkSizes(numChunks, 1);
139 dim_t unitsRem = dimSize - numChunks;
140 while (unitsRem > 0) {
141 // Find chunk with maximum weight error.
142 float weightErrMax = std::numeric_limits<float>::lowest();
143 size_t weightErrMaxIdx = 0;
144 for (size_t idx = 0; idx < numChunks; ++idx) {
145 float weightVal =
146 float(chunkSizes[idx]) / float(dimSize) * chunkWeightsSum;
147 // We use a signed error here to starve those chunks for which the actual
148 // weight surpassed the given weight.
149 float weightErr = chunkWeights[idx] - weightVal;
150 if (weightErr > weightErrMax) {
151 weightErrMax = weightErr;
152 weightErrMaxIdx = idx;
153 }
154 }
155 // Distribute unit.
156 chunkSizes[weightErrMaxIdx] += 1;
157 unitsRem--;
158 }
159 return chunkSizes;
160}
161
162/// Utility function to split an array of slice ranges \p ranges along the given
163/// dimension \p dim using the split option \p splitOption.
164static std::vector<SliceRange>
165splitSliceRanges(const std::vector<SliceRange> &ranges, size_t dim,
166 const SplitNodeOptionOrthogonal *splitOption) {
167 std::vector<SliceRange> outRanges;
168 for (const auto &range : ranges) {
169
170 // Split dimension.
171 dim_t dimSize = range.getDimSize(dim);
172 std::vector<dim_t> chunkSizes = splitOption->splitAlongDim(dim, dimSize);
173
174 // Check for empty chunks.
175 for (auto chunkSize : chunkSizes) {
176 CHECK_GT(chunkSize, 0) << "Chunk size 0 is not allowed!";
177 }
178
179 // Check dimension splitting consistency.
180 dim_t chunkSizesSum =
181 std::accumulate(chunkSizes.begin(), chunkSizes.end(), (dim_t)0);
182 CHECK_EQ(dimSize, chunkSizesSum)
183 << "Inconsistent splitting of dimension " << dim << " with size "
184 << dimSize << " into chunks with total size " << chunkSizesSum << "!";
185
186 // Split current slice range.
187 auto numChunks = chunkSizes.size();
188 std::vector<SliceRange> splitRanges(numChunks, range);
189 dim_t chunkStart = range[dim].first;
190 for (size_t idx = 0; idx < numChunks; ++idx) {
191 // Current chunk size.
192 dim_t chunkSize = chunkSizes[idx];
193 // Update chunk bounds.
194 splitRanges[idx][dim].first = chunkStart;
195 splitRanges[idx][dim].second = chunkStart + chunkSize;
196 chunkStart += chunkSize;
197 }
198 CHECK_EQ(splitRanges.back()[dim].second, range[dim].second)
199 << "Inconsistent splitting of SliceRange!";
200
201 // Append split slice ranges.
202 outRanges.insert(outRanges.end(), splitRanges.begin(), splitRanges.end());
203 }
204 return outRanges;
205}
206
207///===---------------------------------------------------------------------===//
208/// CheckedSliceRangeMap
209///===---------------------------------------------------------------------===//
210/// Definition of a checked slice range which provides an extra boolean flag to
211/// inform whether the slice range is valid and allowed to be used.
212using CheckedSliceRange = std::pair<bool, SliceRange>;
213
214/// Definition of a functional mapping between two slice ranges with extra
215/// information about whether the mapping is allowed to be used (valid).
216using CheckedSliceRangeMap =
217 std::function<CheckedSliceRange(const SliceRange &)>;
218
219/// Identity checked slice range map to use for simple identity mappings.
220CheckedSliceRange CheckedSliceRangeMapIdentity(const SliceRange &range) {
221 return {true, range};
222}
223
224/// Definition of a pair with an operand index and a checked slice range map.
225using OpIdxAndMap = std::pair<unsigned, CheckedSliceRangeMap>;
226
227/// Utility function to verify that a given slice range \p map represents an
228/// exact mapping from \p mapInputRanges to \p mapOutputRanges.
229static bool isMappingExact(const CheckedSliceRangeMap &map,
230 const std::vector<SliceRange> &mapInputRanges,
231 const std::vector<SliceRange> &mapOutputRanges) {
232 bool mapOk = true;
233 DCHECK_EQ(mapInputRanges.size(), mapOutputRanges.size())
234 << "Slice ranges length mismatch for CheckedSliceRangeMap verification!";
235 for (size_t idx = 0, e = mapInputRanges.size(); idx < e; ++idx) {
236 auto checkedSliceRange = map(mapInputRanges[idx]);
237 mapOk = mapOk && checkedSliceRange.first;
238 mapOk = mapOk && (checkedSliceRange.second == mapOutputRanges[idx]);
239 }
240 return mapOk;
241}
242
243/// Utility function to verify that a given slice range \p map when applied to
244/// \p mapInputRanges produces ranges which are included in \p mapOutputRanges.
245/// This is a weaker verification than \ref isMappingExact.
246static bool isMappingIncluded(const CheckedSliceRangeMap &map,
247 const std::vector<SliceRange> &mapInputRanges,
248 const std::vector<SliceRange> &mapOutputRanges) {
249 bool mapOk = true;
250 DCHECK_EQ(mapInputRanges.size(), mapOutputRanges.size())
251 << "Slice ranges length mismatch for CheckedSliceRangeMap verification!";
252 for (size_t idx = 0, e = mapInputRanges.size(); idx < e; ++idx) {
253 auto checkedSliceRange = map(mapInputRanges[idx]);
254 mapOk = mapOk && checkedSliceRange.first;
255 mapOk =
256 mapOk && checkedSliceRange.second.isIncludedBy(mapOutputRanges[idx]);
257 }
258 return mapOk;
259}
260
261///===---------------------------------------------------------------------===//
262/// SplitNodeModifier
263///===---------------------------------------------------------------------===//
264/// Definition of a function which modifies the split node \p splitNode after it
265/// was cloned from the original node \p origNode. The input slice ranges \p
266/// inputSliceRanges and the output slice ranges \p outputSliceRanges are also
267/// provided by the caller to provide extra context about how the split node was
268/// obtained from the original node. This function is provided to the node
269/// splitting procedure as a callback and provides the mechanism of modifying
270/// the split node attributes in special situations, for example when the "Pads"
271/// or "Group" node attributes must be changed when splitting Convolution nodes.
272using SplitNodeModifier =
273 std::function<void(const Node *origNode, Node *splitNode,
274 const std::vector<SliceRange> &inputSliceRanges,
275 const std::vector<SliceRange> &outputSliceRanges)>;
276
277/// Definition of a "nop" split node modifier which does no modifications.
278void SplitNodeModifierNop(const Node *origNode, Node *splitNode,
279 const std::vector<SliceRange> &inputSliceRanges,
280 const std::vector<SliceRange> &outputSliceRanges) {}
281
282///===---------------------------------------------------------------------===//
283/// verifySplitParams
284///===---------------------------------------------------------------------===//
285/// List of nodes for which there is a weak mapping between input and output
286/// and thus a weaker verification must be performed. Such an example is the
287/// Conv2D/MaxPool node when using strides larger than 1 resulting in cases
288/// where the output operand does not reference the input operand entirely.
289static std::vector<Kinded::Kind> weakOutToInMappingNodeKinds = {
290 Kinded::Kind::ConvolutionNodeKind,
291 Kinded::Kind::MaxPoolNodeKind,
292 Kinded::Kind::AvgPoolNodeKind,
293};
294
295/// Function to verify the split parameters.
296static Error
297verifySplitParams(const Node *node, dim_t splitOutputIdx,
298 const llvm::ArrayRef<size_t> &splitDims,
299 const llvm::ArrayRef<OpIdxAndMap> &inputIdxAndMaps,
300 const llvm::ArrayRef<OpIdxAndMap> &outputIdxAndMaps) {
301
302 // Verify original node.
303 if (!node->verify()) {
304 llvm::errs() << node->toString() << "\n";
305 return MAKE_ERR("Invalid node given to node splitting procedure!");
306 }
307
308 // Verify split dims.
309 RETURN_ERR_IF_NOT(splitDims.size() > 0,
310 "Empty split dimensions for splitting node!");
311 RETURN_ERR_IF_NOT(splitOutputIdx < node->getNumResults(),
312 "Invalid output index for splitting node!");
313 for (size_t dim = 0; dim < splitDims.size() - 1; ++dim) {
314 RETURN_ERR_IF_NOT(splitDims[dim] < splitDims[dim + 1],
315 "Invalid split dimensions for splitting node! Dimensions "
316 "should be given in ascending order e.g. {0,2,3}!");
317 }
318 for (const auto dim : splitDims) {
319 RETURN_ERR_IF_NOT(dim < node->getType(splitOutputIdx)->dims().size(),
320 "Invalid split dimension for splitting node! Dimension "
321 "exceeds the split output tensor shape!");
322 }
323
324 // Verify all the input indices and maps were given.
325 RETURN_ERR_IF_NOT(inputIdxAndMaps.size() == node->getNumInputs(),
326 "Invalid number of input maps for splitting node!");
327 std::vector<bool> inputIdxMask(node->getNumInputs(), false);
328 for (const auto &inputIdxMap : inputIdxAndMaps) {
329 RETURN_ERR_IF_NOT(inputIdxMap.first < node->getNumInputs(),
330 "Invalid input index for input range map!");
331 inputIdxMask[inputIdxMap.first] = true;
332 }
333 RETURN_ERR_IF_NOT(
334 std::find(inputIdxMask.begin(), inputIdxMask.end(), false) ==
335 inputIdxMask.end(),
336 "Not all input indices and maps were provided for splitting node!");
337
338 // Verify all the output indices and maps were given.
339 RETURN_ERR_IF_NOT(outputIdxAndMaps.size() == node->getNumResults() - 1,
340 "Invalid number of output maps for splitting node!");
341 std::vector<bool> outputIdxMask(node->getNumResults(), false);
342 outputIdxMask[splitOutputIdx] = true;
343 for (const auto &outputIdxMap : outputIdxAndMaps) {
344 RETURN_ERR_IF_NOT(outputIdxMap.first < node->getNumResults(),
345 "Invalid output index for output range map!");
346 outputIdxMask[outputIdxMap.first] = true;
347 }
348 RETURN_ERR_IF_NOT(
349 std::find(outputIdxMask.begin(), outputIdxMask.end(), false) ==
350 outputIdxMask.end(),
351 "Not all output indices and maps were provided for splitting node!");
352
353 // Get split output range.
354 SliceRange splitOutputRange = SliceRange(node->getType(splitOutputIdx));
355
356 // Verify the input slice range maps.
357 for (const auto &inputIdxMap : inputIdxAndMaps) {
358 SliceRange inputRange =
359 SliceRange(node->getNthInput(inputIdxMap.first).getType());
360 if (std::find(weakOutToInMappingNodeKinds.begin(),
361 weakOutToInMappingNodeKinds.end(),
362 node->getKind()) != weakOutToInMappingNodeKinds.end()) {
363 // Verify weak mapping.
364 RETURN_ERR_IF_NOT(isMappingIncluded(inputIdxMap.second,
365 {splitOutputRange}, {inputRange}),
366 "Invalid input range map for splitting node!");
367 } else {
368 // Verify exact mapping.
369 RETURN_ERR_IF_NOT(
370 isMappingExact(inputIdxMap.second, {splitOutputRange}, {inputRange}),
371 "Invalid input range map for splitting node!");
372 }
373 }
374
375 // Verify the output slice range maps.
376 for (const auto &outputIdxMap : outputIdxAndMaps) {
377 SliceRange outputRange = SliceRange(node->getType(outputIdxMap.first));
378 RETURN_ERR_IF_NOT(
379 isMappingExact(outputIdxMap.second, {splitOutputRange}, {outputRange}),
380 "Invalid output range map for splitting node!");
381 }
382
383 return Error::success();
384}
385
386///===---------------------------------------------------------------------===//
387/// verifySplitNodes
388///===---------------------------------------------------------------------===//
389/// Function to verify the split nodes.
390static Expected<bool>
391verifySplitNodes(const Node *node, dim_t splitOutputIdx,
392 const llvm::ArrayRef<SliceRange> &splitOutputSlices,
393 const llvm::ArrayRef<OpIdxAndMap> &inputIdxAndMaps,
394 const llvm::ArrayRef<OpIdxAndMap> &outputIdxAndMaps,
395 const SplitNodeConstraint *splitConstraint,
396 const SplitNodeModifier &splitNodeModifier) {
397
398 // Create temporary nodes to make verifications without adding them to
399 // the graph in order to avoid the pollution of the graph with nodes
400 // which could be invalid or could not meet all the constraints and
401 // hence be later removed from the graph.
402 bool splitNodesCheck = true;
403 std::vector<Node *> splitNodes;
404 std::list<std::unique_ptr<SliceNode>> inputSliceNodes;
405 std::list<Type> inputTypes;
406 std::list<Type> outputTypes;
407 for (const auto &splitOutputSlice : splitOutputSlices) {
408
409 // Create clone to inherit all the inputs/members of the original node.
410 Node *clone = node->clone();
411 splitNodes.push_back(clone);
412
413 // Detach clone from all the inputs of the original node.
414 for (unsigned idx = 0, e = clone->getNumInputs(); idx < e; ++idx) {
415 clone->setNthInput(idx, nullptr);
416 }
417
418 // Gather input slice ranges for the clone. The ranges are ordered
419 // according to the input operand indices.
420 std::vector<SliceRange> inputRanges(clone->getNumInputs());
421 for (const auto &inputIdxMap : inputIdxAndMaps) {
422 auto inputCheckedRange = inputIdxMap.second(splitOutputSlice);
423 splitNodesCheck = splitNodesCheck && inputCheckedRange.first;
424 splitNodesCheck = splitNodesCheck && !inputCheckedRange.second.isEmpty();
425 inputRanges[inputIdxMap.first] = inputCheckedRange.second;
426 }
427
428 // Gather output slice ranges for the clone. The ranges are ordered
429 // according to the output operand indices.
430 std::vector<SliceRange> outputRanges(clone->getNumResults());
431 outputRanges[splitOutputIdx] = splitOutputSlice;
432 for (const auto &outputIdxMap : outputIdxAndMaps) {
433 auto outputCheckedRange = outputIdxMap.second(splitOutputSlice);
434 splitNodesCheck = splitNodesCheck && outputCheckedRange.first;
435 splitNodesCheck = splitNodesCheck && !outputCheckedRange.second.isEmpty();
436 outputRanges[outputIdxMap.first] = outputCheckedRange.second;
437 }
438
439 // Early break.
440 if (!splitNodesCheck) {
441 break;
442 }
443
444 // Set clone input types. Since a node does not own its input types and
445 // the clone inherits the input types from the input nodes of the
446 // original node we create here dummy input SliceNodes and attach them
447 // to the clone in order to allow setting and checking the clone input
448 // types without modifying the input types of the original node.
449 for (const auto &inputIdxMap : inputIdxAndMaps) {
450 auto &inputRange = inputRanges[inputIdxMap.first];
451 auto inputType =
452 Type::newShape(*(node->getNthInput(inputIdxMap.first).getType()),
453 inputRange.getSizes());
454 inputTypes.push_back(inputType);
455 inputSliceNodes.push_back(std::make_unique<SliceNode>(
456 "inputSlice", &(inputTypes.back()),
457 /* Input */ nullptr, inputRange.getStarts()));
458 clone->setNthInput(inputIdxMap.first, inputSliceNodes.back().get());
459 }
460
461 // Set clone split output type. The original node output type is not
462 // modified because the clone owns its output types.
463 outputTypes.push_back(Type::newShape(*node->getType(splitOutputIdx),
464 splitOutputSlice.getSizes()));
465 clone->getNthResult(splitOutputIdx).setTypeUnsafe(&outputTypes.back());
466
467 // Set clone output types. The original node output types are not
468 // modified because the clone owns its output types.
469 for (const auto &outputIdxMap : outputIdxAndMaps) {
470 auto &outputRange = outputRanges[outputIdxMap.first];
471 auto outputType = Type::newShape(*node->getType(outputIdxMap.first),
472 outputRange.getSizes());
473 outputTypes.push_back(outputType);
474 clone->getNthResult(outputIdxMap.first)
475 .setTypeUnsafe(&outputTypes.back());
476 }
477
478 // Modify clone.
479 splitNodeModifier(node, clone, inputRanges, outputRanges);
480
481 // Verify clone. If the clone is invalid at this point this means there
482 // is a logic error in the splitting infrastructure (the input/output
483 // maps are not checked properly or the split node modifier is flawed)
484 // so we throw an error (not the same thing as returning false which is
485 // intended for signaling that the splitting infrastructure correctly
486 // identified an incorrect split configuration or the split configuration
487 // is not accepted by the user constraints).
488 if (!clone->verify()) {
489 // Dump some extra error context.
490 llvm::errs() << "Slice range description:\n";
491 for (unsigned idx = 0, e = clone->getNumInputs(); idx < e; ++idx) {
492 llvm::errs() << clone->getInputName(idx) << ": "
493 << inputRanges[idx].toString() << "\n";
494 }
495 for (unsigned idx = 0, e = clone->getNumResults(); idx < e; ++idx) {
496 llvm::errs() << clone->getOutputName(idx).str() << ": "
497 << outputRanges[idx].toString() << "\n";
498 }
499 llvm::errs() << "Node description:\n";
500 llvm::errs() << clone->toString() << "\n";
501 return MAKE_ERR("Invalid node obtained during node splitting!");
502 }
503
504 // Early break.
505 if (!splitNodesCheck) {
506 break;
507 }
508 }
509
510 // Check split nodes against user constraint (if any).
511 if (splitConstraint) {
512 splitNodesCheck = splitNodesCheck && (*splitConstraint)(node, splitNodes);
513 }
514
515 // Explicitly destroy the temporary nodes.
516 for (auto *splitNode : splitNodes) {
517 Node::destroyNode(splitNode);
518 }
519
520 return splitNodesCheck;
521}
522
523///===---------------------------------------------------------------------===//
524/// splitAndReplaceNode
525///===---------------------------------------------------------------------===//
526static Expected<std::vector<Node *>> splitAndReplaceNode(
527 Node *node, const SplitNodeOption *splitOption,
528 const SplitNodeConstraint *splitConstraint, dim_t splitOutputIdx,
529 const llvm::ArrayRef<OpIdxAndMap> &inputIdxAndMaps,
530 const llvm::ArrayRef<OpIdxAndMap> &outputIdxAndMaps = {},
531 const SplitNodeModifier &splitNodeModifier = SplitNodeModifierNop) {
532
533 // If the split output operand has no dimensions then return.
534 if (node->getType(splitOutputIdx)->dims().empty()) {
535 return std::vector<Node *>();
536 }
537
538 // The default split dims are all the dims of the split output operand.
539 RETURN_ERR_IF_NOT(splitOutputIdx < node->getNumResults(),
540 "Invalid output index for splitting node!");
541 std::vector<size_t> splitDims(node->getType(splitOutputIdx)->dims().size());
542 std::iota(splitDims.begin(), splitDims.end(), 0);
543
544 // Explicit split dims for this node.
545 if (splitOption) {
546 // We use explicit split dims only for orthogonal option.
547 auto *splitOptionOrthogonal =
548 dyn_cast<SplitNodeOptionOrthogonal>(splitOption);
549 if (splitOptionOrthogonal) {
550 splitDims = splitOptionOrthogonal->getSplitDims();
551 }
552 }
553
554 // Verify split parameters.
555 RETURN_IF_ERR(verifySplitParams(node, splitOutputIdx, splitDims,
556 inputIdxAndMaps, outputIdxAndMaps));
557
558 // ------------------------------- Split output ------------------------------
559 // Initialize the split output slices with the initial output range.
560 SliceRange splitOutputRange = SliceRange(node->getType(splitOutputIdx));
561 std::vector<SliceRange> splitOutputSlices = {splitOutputRange};
562
563 // If a specific split option is given then we do a targeted splitting.
564 // If no specific split option is given then we search a split configuration
565 // which meets the constraint.
566 if (splitOption) {
567
568 // Orthogonal: Split along all the given dimensions using the given option.
569 // Non-orthogonal: Use the raw slice ranges explicitly provided.
570 auto *splitOptionOrthogonal =
571 dyn_cast<SplitNodeOptionOrthogonal>(splitOption);
572 if (splitOptionOrthogonal) {
573 for (size_t splitDim : splitDims) {
574 splitOutputSlices = splitSliceRanges(splitOutputSlices, splitDim,
575 splitOptionOrthogonal);
576 }
577 } else {
578 // Set raw slice ranges.
579 auto *splitOptionRaw = dyn_cast<SplitNodeBySliceRanges>(splitOption);
580 RETURN_ERR_IF_NOT(splitOptionRaw,
581 "Non orthogonal split option not supported!");
582 splitOutputSlices = splitOptionRaw->getSliceRanges();
583 // We verify that the explicitly provided slice ranges are valid.
584 for (const auto &sliceRange : splitOutputSlices) {
585 RETURN_ERR_IF_NOT(
586 sliceRange.getNumDims() == splitOutputRange.getNumDims(),
587 "Non orthogonal slice range rank not equal to output rank!");
588 RETURN_ERR_IF_NOT(
589 sliceRange.isIncludedBy(splitOutputRange),
590 "Non orthogonal slice range exceed the output operand range!");
591 }
592 }
593
594 // Verify split nodes.
595 bool splitNodesCheck = true;
596 ASSIGN_VALUE_OR_RETURN_ERR(
597 splitNodesCheck,
598 verifySplitNodes(node, splitOutputIdx, splitOutputSlices,
599 inputIdxAndMaps, outputIdxAndMaps, splitConstraint,
600 splitNodeModifier));
601
602 // If split nodes are invalid then we do not perform any splitting.
603 if (!splitNodesCheck) {
604 return std::vector<Node *>();
605 }
606
607 } else {
608
609 // When no split option is given a split constraint is mandatory.
610 RETURN_ERR_IF_NOT(splitConstraint, "When a split option is not given then "
611 "a split constraint must be given!");
612
613 // Start searching of a split configuration which meets the constraint by
614 // splitting along the given dimensions iteratively in smaller chunks.
615 bool splitFound = false;
616 for (size_t splitDim : splitDims) {
617
618 dim_t splitDimSize = splitOutputRange.getDimSize(splitDim);
619 std::vector<SliceRange> splitOutputSlicesTemp;
620 for (dim_t dimNumChunks = 1; dimNumChunks <= splitDimSize;
621 dimNumChunks++) {
622
623 // Split along current dimension in the given number of chunks.
624 auto splitOptionSearch =
625 SplitNodeByNumChunks({splitDim}, {dimNumChunks});
626 splitOutputSlicesTemp =
627 splitSliceRanges(splitOutputSlices, splitDim, &splitOptionSearch);
628
629 // Verify split nodes.
630 bool splitNodesCheck = true;
631 ASSIGN_VALUE_OR_RETURN_ERR(
632 splitNodesCheck,
633 verifySplitNodes(node, splitOutputIdx, splitOutputSlicesTemp,
634 inputIdxAndMaps, outputIdxAndMaps, splitConstraint,
635 splitNodeModifier));
636
637 // If split is found we stop searching.
638 if (splitNodesCheck) {
639 splitFound = true;
640 break;
641 }
642 }
643
644 // Save the split output slices.
645 splitOutputSlices = splitOutputSlicesTemp;
646
647 // If split is found we stop searching.
648 if (splitFound) {
649 break;
650 }
651 }
652
653 // If split is not found then we do not perform any splitting.
654 if (!splitFound) {
655 return std::vector<Node *>();
656 }
657 }
658
659 // If with the current split parameters only one slice is obtained then no
660 // splitting is required since the slice is the same as the original node.
661 if (splitOutputSlices.size() == 1) {
662 return std::vector<Node *>();
663 }
664
665 // -------------------------------- Split node -------------------------------
666 // Get parent function.
667 Function *F = node->getParent();
668 RETURN_ERR_IF_NOT(F, "Cannot split a node without a parent Function!");
669
670 // Allocate output tensors used for merging the partial outputs. We only merge
671 // the partial outputs for the output operands which are effectively used.
672 std::vector<NodeValue> mergedOutputs(node->getNumResults());
673 for (size_t outIdx = 0, outIdxEnd = node->getNumResults(); outIdx < outIdxEnd;
674 outIdx++) {
675 if (node->getNthResult(outIdx).getNumUsers()) {
676 auto nodeName =
677 node->getName().str() + ".TouchOutput" + std::to_string(outIdx);
678 mergedOutputs[outIdx] = F->createTouch(nodeName, node->getType(outIdx));
679 } else {
680 mergedOutputs[outIdx] = nullptr;
681 }
682 }
683
684 // Create split nodes.
685 std::vector<Node *> splitNodes(splitOutputSlices.size(), nullptr);
686 for (size_t sliceIdx = 0, sliceIdxEnd = splitOutputSlices.size();
687 sliceIdx < sliceIdxEnd; sliceIdx++) {
688
689 // Current split output slice.
690 const auto &splitOutputSlice = splitOutputSlices[sliceIdx];
691
692 // Create clone to inherit all the inputs/members of the original node.
693 Node *clone = node->clone();
694 clone->setName(node->getName().str() + ".Split" + std::to_string(sliceIdx));
695
696 // Gather final input slice ranges for the clone.
697 std::vector<SliceRange> inputRanges(clone->getNumInputs());
698 for (const auto &inputIdxMap : inputIdxAndMaps) {
699 auto inputCheckedRange = inputIdxMap.second(splitOutputSlice);
700 inputRanges[inputIdxMap.first] = inputCheckedRange.second;
701 }
702
703 // Gather final output slice ranges for the clone.
704 std::vector<SliceRange> outputRanges(clone->getNumResults());
705 outputRanges[splitOutputIdx] = splitOutputSlice;
706 for (const auto &outputIdxMap : outputIdxAndMaps) {
707 auto outputCheckedRange = outputIdxMap.second(splitOutputSlice);
708 outputRanges[outputIdxMap.first] = outputCheckedRange.second;
709 }
710
711 // Create input Slice nodes (only if necessary).
712 for (const auto &inputIdxMap : inputIdxAndMaps) {
713 auto inputIdx = inputIdxMap.first;
714 auto inputNodeValue = node->getNthInput(inputIdx);
715 auto &inputSliceRange = inputRanges[inputIdxMap.first];
716 TypeRef inpTy = inputNodeValue.getType();
717 Type outTy = Type::newShape(*(inpTy), inputSliceRange.getSizes());
718 if (outTy.isEqual(inpTy)) {
719 clone->setNthInput(inputIdx, inputNodeValue);
720 } else {
721 auto nodeName =
722 clone->getName().str() + ".SliceInput" + std::to_string(inputIdx);
723 auto *inputSlice = F->createSlice(nodeName, inputNodeValue,
724 inputSliceRange.getStarts(), &outTy);
725 clone->setNthInput(inputIdx, inputSlice);
726 }
727 }
728
729 // Set clone split output type. The original node output type is not
730 // modified because the clone owns its output types.
731 TypeRef splitOutputType = F->getParent()->uniqueTypeWithNewShape(
732 node->getType(splitOutputIdx), splitOutputSlice.getSizes());
733 clone->getNthResult(splitOutputIdx).setTypeUnsafe(splitOutputType);
734
735 // Set clone output types. The original node output types are not
736 // modified because the clone owns its output types.
737 for (const auto &outputIdxMap : outputIdxAndMaps) {
738 auto &outputRange = outputRanges[outputIdxMap.first];
739 TypeRef outputType = F->getParent()->uniqueTypeWithNewShape(
740 node->getType(outputIdxMap.first), outputRange.getSizes());
741 clone->getNthResult(outputIdxMap.first).setTypeUnsafe(outputType);
742 }
743
744 // Modify clone.
745 splitNodeModifier(node, clone, inputRanges, outputRanges);
746
747 // Verify clone.
748 RETURN_ERR_IF_NOT(clone->verify(),
749 "Invalid node obtained during node splitting!");
750
751 // Add clone to function.
752 F->addNode(clone);
753
754 // Add clone to vector.
755 splitNodes[sliceIdx] = clone;
756
757 // Merge the partial outputs of this clone (only if used).
758 for (size_t outIdx = 0, outIdxEnd = node->getNumResults();
759 outIdx < outIdxEnd; outIdx++) {
760 if (mergedOutputs[outIdx]) {
761 auto nodeName =
762 clone->getName().str() + ".MergeOutput" + std::to_string(outIdx);
763 mergedOutputs[outIdx] = F->createInsertTensor(
764 nodeName, mergedOutputs[outIdx], clone->getNthResult(outIdx),
765 outputRanges[outIdx].getStarts());
766 }
767 }
768 }
769
770 // Replace all the node outputs with the merged outputs (only if used).
771 for (size_t outIdx = 0, outIdxEnd = node->getNumResults(); outIdx < outIdxEnd;
772 outIdx++) {
773 if (mergedOutputs[outIdx]) {
774 node->getNthResult(outIdx).replaceAllUsesOfWith(mergedOutputs[outIdx]);
775 }
776 }
777
778 // Erase original node.
779 F->eraseNode(node);
780
781 return splitNodes;
782}
783
784namespace {
785///===---------------------------------------------------------------------===//
786/// Conv
787///===---------------------------------------------------------------------===//
788/// Structure which contains a valid flag \p check, a dimension range \p range
789/// and a dimension padding \p pads.
790struct CheckedRangeAndPads {
791 bool check{false};
792 DimRange range;
793 DimPads pads;
794};
795
796/// Structure which contains a valid flag \p check a dimension range \p range.
797struct CheckedRange {
798 bool check{false};
799 DimRange range;
800};
801} // namespace
802
/// Maps the output slice range \p outputSliceRange of one spatial dimension of
/// a convolution-like operation to the part of the input it reads.
/// \p inputRange is the full (unpadded) input range of that dimension and must
/// start at 0; \p kernel, \p stride, \p pads (start/end padding) and
/// \p dilation are the operation parameters for that dimension.
/// \returns the input slice range expressed in unpadded input coordinates,
/// the padding to apply at the start/end of that slice, and a flag which is
/// false when the kernel windows of the output slice cover only padding (no
/// actual input element).
static CheckedRangeAndPads
getConvInputCheckedRangeAndPads(const DimRange &outputSliceRange,
                                const DimRange &inputRange, dim_t kernel,
                                dim_t stride, DimPads pads, dim_t dilation) {

  CHECK_LT(outputSliceRange.first, outputSliceRange.second)
      << "Invalid output slice range!";
  CHECK_LT(inputRange.first, inputRange.second) << "Invalid input range!";
  CHECK_EQ(inputRange.first, 0) << "Input range must start with 0!";
  CHECK_GE(kernel, 1) << "Invalid kernel size!";
  CHECK_GE(stride, 1) << "Invalid stride size!";
  CHECK_GE(dilation, 1) << "Invalid dilation size!";

  // Get padded input range. "Padded" coordinates are indices into the input
  // with the start padding prepended, so the real input begins at pads.first.
  dim_t inputStartPadded = inputRange.first + pads.first;
  dim_t inputStopPadded = inputRange.second + pads.first;

  // Get padded input slice range (stop is exclusive): the first window starts
  // at first * stride and the last window, starting at (second - 1) * stride,
  // spans dilation * (kernel - 1) + 1 elements.
  dim_t inputSliceStartPadded = outputSliceRange.first * stride;
  dim_t inputSliceStopPadded =
      (outputSliceRange.second - 1) * stride + dilation * (kernel - 1) + 1;

  // Verify input slice range bounds: windows may not read past the end
  // padding.
  dim_t inputSliceStopPaddedMax = pads.first + inputRange.second + pads.second;
  CHECK_LE(inputSliceStopPadded, inputSliceStopPaddedMax)
      << "Input slice range out of bounds!";

  // Get intersection of the windows with the real (unpadded) input.
  dim_t intersectStartPadded =
      std::max(inputStartPadded, inputSliceStartPadded);
  dim_t intersectStopPadded = std::min(inputStopPadded, inputSliceStopPadded);

  // Get checked input range. The slice is only allowed when the intersection
  // is non-empty. The stop conversion is guarded because the windows can end
  // inside the start padding; the range is then not allowed anyway, and the
  // guard only avoids an unsigned wrap-around.
  bool allowed = (intersectStartPadded < intersectStopPadded);
  dim_t inputSliceStart = intersectStartPadded - pads.first;
  dim_t inputSliceStop =
      intersectStopPadded >= pads.first ? intersectStopPadded - pads.first : 0;

  // Get start pad: how far the first window reaches before the real input.
  dim_t inputSliceStartPad = 0;
  if (inputSliceStartPadded < inputStartPadded) {
    inputSliceStartPad = inputStartPadded - inputSliceStartPadded;
  }

  // Get stop pad: how far the last window reaches past the real input.
  dim_t inputSliceStopPad = 0;
  if (inputSliceStopPadded > inputStopPadded) {
    inputSliceStopPad = inputSliceStopPadded - inputStopPadded;
  }

  DimRange inputSliceRange = {inputSliceStart, inputSliceStop};
  DimPads inputSlicePads = {inputSliceStartPad, inputSliceStopPad};
  return CheckedRangeAndPads{allowed, inputSliceRange, inputSlicePads};
}
857
858static CheckedRange
859getConvInputChannelCheckedRange(const DimRange &outputSliceRange,
860 const DimRange &inputRange, dim_t inputChannels,
861 dim_t outputChannels, dim_t group) {
862
863 CHECK_EQ(inputChannels % group, 0)
864 << "Input channels must be divisible by group!";
865 CHECK_EQ(outputChannels % group, 0)
866 << "Output channels must be divisible by group!";
867
868 dim_t inputChannelsPerGroup = inputChannels / group;
869 dim_t outputChannelsPerGroup = outputChannels / group;
870 dim_t outputSliceChannels = SliceRange({outputSliceRange}).getDimSize(0);
871
872 // Output slice range start/stop group index (inclusive).
873 dim_t outputSliceRangeStartGroupIdx =
874 outputSliceRange.first / outputChannelsPerGroup;
875 dim_t outputSliceRangeStopGroupIdx =
876 (outputSliceRange.second - 1) / outputChannelsPerGroup;
877
878 bool allowed = false;
879 if (outputSliceChannels <= outputChannelsPerGroup) {
880 // If the output slice range spans fully or partially one group then both
881 // ends of the range must be part of the same group.
882 allowed = (outputSliceRangeStartGroupIdx == outputSliceRangeStopGroupIdx);
883 } else {
884 // If the output slice range spans multiple groups then both ends of the
885 // range must be aligned to outputChannelsPerGroup.
886 allowed = SliceRange({outputSliceRange})
887 .isDimRangeAligned(0, outputChannelsPerGroup);
888 }
889
890 // Compute input slice range as a multiple of groups.
891 DimRange inputSliceRange;
892 inputSliceRange.first = outputSliceRangeStartGroupIdx * inputChannelsPerGroup;
893 inputSliceRange.second =
894 (outputSliceRangeStopGroupIdx + 1) * inputChannelsPerGroup;
895 return CheckedRange{allowed, inputSliceRange};
896}
897
898///===---------------------------------------------------------------------===//
899/// Conv2D
900///===---------------------------------------------------------------------===//
901template <typename ConvNodeTy, typename Shape>
902static std::vector<OpIdxAndMap>
903getConv2DInputIdxAndMaps(const ConvNodeTy *node) {
904
905 ShapeHW kernels = ShapeHW(node->getKernels());
906 ShapeHW strides = ShapeHW(node->getStrides());
907 PaddingTLBR pads(node->getPads());
908 unsigned_t group = node->getGroup();
909 ShapeHW dilations = ShapeHW(node->getDilation());
910 DimPads padsTB = {pads.top, pads.bottom};
911 DimPads padsLR = {pads.left, pads.right};
912
913 SliceRange inputRange = SliceRange(node->getInput().getType());
914 SliceRange filterRange = SliceRange(node->getFilter().getType());
915 SliceRange outputRange = SliceRange(node->getResult().getType());
916
917 // Output slice to input slice range map.
918 CheckedSliceRangeMap inputSliceRangeMap =
919 [=](const SliceRange &outputSliceRange) -> CheckedSliceRange {
920 // Get range and pads for dimension H.
921 auto checkedRangeAndPadsH = getConvInputCheckedRangeAndPads(
922 outputSliceRange[Shape::DimH], inputRange[Shape::DimH], kernels.height,
923 strides.height, padsTB, dilations.height);
924
925 // Get range and pads for dimension W.
926 auto checkedRangeAndPadsW = getConvInputCheckedRangeAndPads(
927 outputSliceRange[Shape::DimW], inputRange[Shape::DimW], kernels.width,
928 strides.width, padsLR, dilations.width);
929
930 // Get range for dimension C.
931 auto checkedRangeC = getConvInputChannelCheckedRange(
932 outputSliceRange[Shape::DimC], inputRange[Shape::DimC],
933 inputRange.getDimSize(Shape::DimC), outputRange.getDimSize(Shape::DimC),
934 group);
935
936 std::vector<DimRange> inputDimRanges(4);
937 inputDimRanges[Shape::DimN] = outputSliceRange[Shape::DimN];
938 inputDimRanges[Shape::DimH] = checkedRangeAndPadsH.range;
939 inputDimRanges[Shape::DimW] = checkedRangeAndPadsW.range;
940 inputDimRanges[Shape::DimC] = checkedRangeC.range;
941 bool allowed = checkedRangeAndPadsH.check && checkedRangeAndPadsW.check &&
942 checkedRangeC.check;
943 return {allowed, SliceRange(inputDimRanges)};
944 };
945
946 // Output slice to filter slice range map.
947 CheckedSliceRangeMap filterSliceRangeMap =
948 [=](const SliceRange &outputSliceRange) -> CheckedSliceRange {
949 std::vector<DimRange> filterDimRanges(4);
950 filterDimRanges[Shape::DimN] = outputSliceRange[Shape::DimC];
951 filterDimRanges[Shape::DimH] = filterRange[Shape::DimH];
952 filterDimRanges[Shape::DimW] = filterRange[Shape::DimW];
953 filterDimRanges[Shape::DimC] = filterRange[Shape::DimC];
954 return {true, SliceRange(filterDimRanges)};
955 };
956
957 // Output slice to bias slice range map.
958 CheckedSliceRangeMap biasSliceRangeMap =
959 [=](const SliceRange &outputSliceRange) -> CheckedSliceRange {
960 return {true, SliceRange({outputSliceRange[Shape::DimC]})};
961 };
962
963 // Return input indices and maps.
964 if (std::is_same<ConvNodeTy, ConvolutionNode>::value) {
965 return {{ConvolutionNode::InputIdx, inputSliceRangeMap},
966 {ConvolutionNode::FilterIdx, filterSliceRangeMap},
967 {ConvolutionNode::BiasIdx, biasSliceRangeMap}};
968 } else if (std::is_same<ConvNodeTy,
969 ChannelwiseQuantizedConvolutionNode>::value) {
970 return {
971 {ChannelwiseQuantizedConvolutionNode::InputIdx, inputSliceRangeMap},
972 {ChannelwiseQuantizedConvolutionNode::FilterIdx, filterSliceRangeMap},
973 {ChannelwiseQuantizedConvolutionNode::BiasIdx, biasSliceRangeMap},
974 {ChannelwiseQuantizedConvolutionNode::FilterScalesIdx,
975 biasSliceRangeMap},
976 {ChannelwiseQuantizedConvolutionNode::FilterOffsetsIdx,
977 biasSliceRangeMap},
978 {ChannelwiseQuantizedConvolutionNode::BiasScalesIdx, biasSliceRangeMap},
979 {ChannelwiseQuantizedConvolutionNode::BiasOffsetsIdx,
980 biasSliceRangeMap}};
981 }
982 llvm_unreachable("Invalid Convolution node type!");
983}
984
985template <typename ConvNodeTy, typename Shape>
986void Conv2DSplitNodeModifier(const Node *origNode, Node *splitNode,
987 const std::vector<SliceRange> &inputSliceRanges,
988 const std::vector<SliceRange> &outputSliceRanges) {
989 auto *convOrigNode = dyn_cast<ConvNodeTy>(origNode);
990 auto *convSplitNode = dyn_cast<ConvNodeTy>(splitNode);
991 if (!(convOrigNode && convSplitNode)) {
992 return;
993 }
994
995 ShapeHW kernels = ShapeHW(convOrigNode->getKernels());
996 ShapeHW strides = ShapeHW(convOrigNode->getStrides());
997 PaddingTLBR pads(convOrigNode->getPads());
998 ShapeHW dilations = ShapeHW(convOrigNode->getDilation());
999 DimPads padsTB = {pads.top, pads.bottom};
1000 DimPads padsLR = {pads.left, pads.right};
1001
1002 // Get paddings for split node.
1003 auto outputSliceRange = outputSliceRanges[ConvNodeTy::ResultIdx];
1004 auto inputRange = SliceRange(convOrigNode->getInput().getType());
1005 auto checkedRangeAndPadsH = getConvInputCheckedRangeAndPads(
1006 outputSliceRange[Shape::DimH], inputRange[Shape::DimH], kernels.height,
1007 strides.height, padsTB, dilations.height);
1008 auto checkedRangeAndPadsW = getConvInputCheckedRangeAndPads(
1009 outputSliceRange[Shape::DimW], inputRange[Shape::DimW], kernels.width,
1010 strides.width, padsLR, dilations.width);
1011 DimPads splitPadsTB = checkedRangeAndPadsH.pads;
1012 DimPads splitPadsLR = checkedRangeAndPadsW.pads;
1013
1014 // Modify paddings for split node.
1015 std::vector<unsigned_t> splitPads = {
1016 static_cast<unsigned_t>(splitPadsTB.first),
1017 static_cast<unsigned_t>(splitPadsLR.first),
1018 static_cast<unsigned_t>(splitPadsTB.second),
1019 static_cast<unsigned_t>(splitPadsLR.second)};
1020 convSplitNode->setPads(splitPads);
1021
1022 // Modify group for split node.
1023 dim_t outputChannels =
1024 SliceRange(convOrigNode->getType(ConvNodeTy::ResultIdx))
1025 .getDimSize(Shape::DimC);
1026 dim_t outputSliceChannels =
1027 SliceRange(convSplitNode->getType(ConvNodeTy::ResultIdx))
1028 .getDimSize(Shape::DimC);
1029 auto group = convOrigNode->getGroup();
1030
1031 CHECK_EQ(outputChannels % group, 0)
1032 << "Output channels must be divisible by group!";
1033 dim_t outputChannelsPerGroup = outputChannels / group;
1034
1035 if (outputSliceChannels <= outputChannelsPerGroup) {
1036 // If the output slice range spans fully or partially one group then we
1037 // set the group to 1.
1038 convSplitNode->setGroup(1);
1039 } else {
1040 // If the output slice range spans more than a group then it must span a
1041 // multiple of outputChannelsPerGroup.
1042 CHECK_EQ(outputSliceChannels % outputChannelsPerGroup, 0)
1043 << "Output slice channels must be divisible by the output channels per "
1044 "group!";
1045 dim_t splitGroup = outputSliceChannels / outputChannelsPerGroup;
1046 convSplitNode->setGroup(static_cast<unsigned_t>(splitGroup));
1047 }
1048}
1049
1050///===---------------------------------------------------------------------===//
1051/// Pool
1052///===---------------------------------------------------------------------===//
1053static CheckedRangeAndPads
1054getPoolInputCheckedRangeAndPads(const DimRange &outputSliceRange,
1055 const DimRange &inputRange, dim_t kernel,
1056 dim_t stride, DimPads pads) {
1057 return getConvInputCheckedRangeAndPads(outputSliceRange, inputRange, kernel,
1058 stride, pads, /* dilation */ 1);
1059}
1060
1061template <class PoolNode, typename Shape>
1062static std::vector<OpIdxAndMap> getPoolInputIdxAndMaps(const PoolNode *node) {
1063
1064 ShapeHW kernels = ShapeHW(node->getKernels());
1065 ShapeHW strides = ShapeHW(node->getStrides());
1066 PaddingTLBR pads(node->getPads());
1067 DimPads padsTB = {pads.top, pads.bottom};
1068 DimPads padsLR = {pads.left, pads.right};
1069
1070 // Output slice to input slice range map.
1071 SliceRange inputRange = SliceRange(node->getInput().getType());
1072 CheckedSliceRangeMap inputSliceRangeMap =
1073 [=](const SliceRange &outputSliceRange) -> CheckedSliceRange {
1074 // Get range and pads for dimension H.
1075 auto checkedRangeAndPadsH = getPoolInputCheckedRangeAndPads(
1076 outputSliceRange[Shape::DimH], inputRange[Shape::DimH], kernels.height,
1077 strides.height, padsTB);
1078
1079 // Get range and pads for dimension W.
1080 auto checkedRangeAndPadsW = getPoolInputCheckedRangeAndPads(
1081 outputSliceRange[Shape::DimW], inputRange[Shape::DimW], kernels.width,
1082 strides.width, padsLR);
1083
1084 std::vector<DimRange> inputDimRanges(4);
1085 inputDimRanges[Shape::DimN] = outputSliceRange[Shape::DimN];
1086 inputDimRanges[Shape::DimH] = checkedRangeAndPadsH.range;
1087 inputDimRanges[Shape::DimW] = checkedRangeAndPadsW.range;
1088 inputDimRanges[Shape::DimC] = outputSliceRange[Shape::DimC];
1089 bool allowed = checkedRangeAndPadsH.check && checkedRangeAndPadsW.check;
1090 return {allowed, SliceRange(inputDimRanges)};
1091 };
1092
1093 // Return input index and map.
1094 return {{PoolNode::InputIdx, inputSliceRangeMap}};
1095}
1096
1097template <class PoolNode, typename Shape>
1098void PoolSplitNodeModifier(const Node *origNode, Node *splitNode,
1099 const std::vector<SliceRange> &inputSliceRanges,
1100 const std::vector<SliceRange> &outputSliceRanges) {
1101 auto *poolOrigNode = dyn_cast<PoolNode>(origNode);
1102 auto *poolSplitNode = dyn_cast<PoolNode>(splitNode);
1103 if (!(poolOrigNode && poolSplitNode)) {
1104 return;
1105 }
1106
1107 ShapeHW kernels = ShapeHW(poolOrigNode->getKernels());
1108 ShapeHW strides = ShapeHW(poolOrigNode->getStrides());
1109 PaddingTLBR pads(poolOrigNode->getPads());
1110 DimPads padsTB = {pads.top, pads.bottom};
1111 DimPads padsLR = {pads.left, pads.right};
1112
1113 // Get paddings for split node.
1114 auto outputSliceRange = outputSliceRanges[PoolNode::ResultIdx];
1115 auto inputRange = SliceRange(poolOrigNode->getInput().getType());
1116 auto checkedRangeAndPadsH = getPoolInputCheckedRangeAndPads(
1117 outputSliceRange[Shape::DimH], inputRange[Shape::DimH], kernels.height,
1118 strides.height, padsTB);
1119 auto checkedRangeAndPadsW = getPoolInputCheckedRangeAndPads(
1120 outputSliceRange[Shape::DimW], inputRange[Shape::DimW], kernels.width,
1121 strides.width, padsLR);
1122 DimPads splitPadsTB = checkedRangeAndPadsH.pads;
1123 DimPads splitPadsLR = checkedRangeAndPadsW.pads;
1124
1125 // Modify paddings for split node.
1126 std::vector<unsigned_t> splitPads = {
1127 static_cast<unsigned_t>(splitPadsTB.first),
1128 static_cast<unsigned_t>(splitPadsLR.first),
1129 static_cast<unsigned_t>(splitPadsTB.second),
1130 static_cast<unsigned_t>(splitPadsLR.second)};
1131 poolSplitNode->setPads(splitPads);
1132}
1133
1134///===---------------------------------------------------------------------===//
1135/// FullyConnected
1136///===---------------------------------------------------------------------===//
1137static std::vector<OpIdxAndMap>
1138getFullyConnectedInputIdxAndMaps(const FullyConnectedNode *node) {
1139 // Output slice to input slice range map.
1140 CheckedSliceRangeMap inputSliceRangeMap =
1141 [=](const SliceRange &outputSliceRange) -> CheckedSliceRange {
1142 SliceRange inputRange = SliceRange(node->getInput().getType());
1143 inputRange[ShapeHW::DimH] = outputSliceRange[ShapeHW::DimH];
1144 return {true, inputRange};
1145 };
1146
1147 // Output slice to weights slice range map.
1148 CheckedSliceRangeMap weightsSliceRangeMap =
1149 [=](const SliceRange &outputSliceRange) -> CheckedSliceRange {
1150 SliceRange weightsRange = SliceRange(node->getWeights().getType());
1151 weightsRange[ShapeHW::DimW] = outputSliceRange[ShapeHW::DimW];
1152 return {true, weightsRange};
1153 };
1154
1155 // Output slice to bias slice range map.
1156 CheckedSliceRangeMap biasSliceRangeMap =
1157 [=](const SliceRange &outputSliceRange) -> CheckedSliceRange {
1158 return {true, SliceRange({outputSliceRange[ShapeHW::DimW]})};
1159 };
1160
1161 // Return input index and map.
1162 return {{FullyConnectedNode::InputIdx, inputSliceRangeMap},
1163 {FullyConnectedNode::WeightsIdx, weightsSliceRangeMap},
1164 {FullyConnectedNode::BiasIdx, biasSliceRangeMap}};
1165}
1166
1167///===---------------------------------------------------------------------===//
1168/// MatMul
1169///===---------------------------------------------------------------------===//
1170static std::vector<OpIdxAndMap>
1171getMatMulInputIdxAndMaps(const MatMulNode *node) {
1172 // Output slice to LHS slice range map.
1173 CheckedSliceRangeMap lhsSliceRangeMap =
1174 [=](const SliceRange &outputSliceRange) -> CheckedSliceRange {
1175 SliceRange lhsRange = SliceRange(node->getLHS().getType());
1176 lhsRange[ShapeHW::DimH] = outputSliceRange[ShapeHW::DimH];
1177 return {true, lhsRange};
1178 };
1179
1180 // Output slice to RHS slice range map.
1181 CheckedSliceRangeMap rhsSliceRangeMap =
1182 [=](const SliceRange &outputSliceRange) -> CheckedSliceRange {
1183 SliceRange rhsRange = SliceRange(node->getRHS().getType());
1184 rhsRange[ShapeHW::DimW] = outputSliceRange[ShapeHW::DimW];
1185 return {true, rhsRange};
1186 };
1187
1188 // Return input index and map.
1189 return {{MatMulNode::LHSIdx, lhsSliceRangeMap},
1190 {MatMulNode::RHSIdx, rhsSliceRangeMap}};
1191}
1192
1193///===---------------------------------------------------------------------===//
1194/// BatchMatMul
1195///===---------------------------------------------------------------------===//
1196static std::vector<OpIdxAndMap>
1197getBatchMatMulInputIdxAndMaps(const BatchMatMulNode *node) {
1198 // Output slice to LHS slice range map.
1199 CheckedSliceRangeMap lhsSliceRangeMap =
1200 [=](const SliceRange &outputSliceRange) -> CheckedSliceRange {
1201 SliceRange lhsRange = SliceRange(node->getLHS().getType());
1202 lhsRange[ShapeNHW::DimN] = outputSliceRange[ShapeNHW::DimN];
1203 lhsRange[ShapeNHW::DimH] = outputSliceRange[ShapeNHW::DimH];
1204 return {true, lhsRange};
1205 };
1206
1207 // Output slice to RHS slice range map.
1208 CheckedSliceRangeMap rhsSliceRangeMap =
1209 [=](const SliceRange &outputSliceRange) -> CheckedSliceRange {
1210 SliceRange rhsRange = SliceRange(node->getRHS().getType());
1211 rhsRange[ShapeNHW::DimN] = outputSliceRange[ShapeNHW::DimN];
1212 rhsRange[ShapeNHW::DimW] = outputSliceRange[ShapeNHW::DimW];
1213 return {true, rhsRange};
1214 };
1215
1216 // Return input index and map.
1217 return {{BatchMatMulNode::LHSIdx, lhsSliceRangeMap},
1218 {BatchMatMulNode::RHSIdx, rhsSliceRangeMap}};
1219}
1220
1221///===---------------------------------------------------------------------===//
1222/// BatchedAdd
1223///===---------------------------------------------------------------------===//
1224static std::vector<OpIdxAndMap>
1225getBatchedAddInputIdxAndMaps(const BatchedAddNode *node) {
1226
1227 // Output slice to Batch slice range map.
1228 CheckedSliceRangeMap batchSliceRangeMap =
1229 [=](const SliceRange &outputSliceRange) -> CheckedSliceRange {
1230 return {true, outputSliceRange};
1231 };
1232
1233 // Output slice to Slice slice range map.
1234 CheckedSliceRangeMap sliceSliceRangeMap =
1235 [=](const SliceRange &outputSliceRange) -> CheckedSliceRange {
1236 size_t numOutDims = outputSliceRange.getNumDims();
1237 return {true, outputSliceRange.extractRanges(1, numOutDims - 1)};
1238 };
1239
1240 // Return input index and map.
1241 return {{BatchedAddNode::BatchIdx, batchSliceRangeMap},
1242 {BatchedAddNode::SliceIdx, sliceSliceRangeMap}};
1243}
1244
1245///===---------------------------------------------------------------------===//
1246/// Transpose
1247///===---------------------------------------------------------------------===//
1248static std::vector<OpIdxAndMap>
1249getTransposeInputIdxAndMaps(const TransposeNode *node) {
1250
1251 // Transpose shuffle.
1252 std::vector<unsigned_t> nodeShuffle = node->getShuffle();
1253 std::vector<size_t> shuffle(nodeShuffle.size());
1254 for (size_t idx = 0, e = nodeShuffle.size(); idx < e; ++idx) {
1255 shuffle[idx] = static_cast<size_t>(nodeShuffle[idx]);
1256 }
1257
1258 // Output slice to Input slice range map.
1259 CheckedSliceRangeMap inputSliceRangeMap =
1260 [=](const SliceRange &outputSliceRange) -> CheckedSliceRange {
1261 return {true, outputSliceRange.shuffleRanges(shuffle, /*invert*/ true)};
1262 };
1263
1264 // Return input index and map.
1265 return {{TransposeNode::InputIdx, inputSliceRangeMap}};
1266}
1267
1268///===---------------------------------------------------------------------===//
1269/// splitNode
1270///===---------------------------------------------------------------------===//
1271Expected<std::vector<Node *>>
1272glow::splitNode(Node *node, const SplitNodeOption *splitOption,
1273 const SplitNodeConstraint *splitConstraint) {
1274
1275 // We can do the splitting if at least the option or the constraint is given.
1276 RETURN_ERR_IF_NOT(
1277 splitOption || splitConstraint,
1278 "At least the split option or the split constraint must be given!");
1279
1280 switch (node->getKind()) {
1281
1282 case Kinded::Kind::ConvolutionNodeKind: {
1283 return splitAndReplaceNode(
1284 node, splitOption, splitConstraint, ConvolutionNode::ResultIdx,
1285 getConv2DInputIdxAndMaps<ConvolutionNode, ShapeNHWC>(
1286 dyn_cast<ConvolutionNode>(node)),
1287 {}, Conv2DSplitNodeModifier<ConvolutionNode, ShapeNHWC>);
1288 }
1289
1290 case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind: {
1291 return splitAndReplaceNode(
1292 node, splitOption, splitConstraint,
1293 ChannelwiseQuantizedConvolutionNode::ResultIdx,
1294 getConv2DInputIdxAndMaps<ChannelwiseQuantizedConvolutionNode,
1295 ShapeNHWC>(
1296 dyn_cast<ChannelwiseQuantizedConvolutionNode>(node)),
1297 {},
1298 Conv2DSplitNodeModifier<ChannelwiseQuantizedConvolutionNode,
1299 ShapeNHWC>);
1300 }
1301
1302 case Kinded::Kind::MaxPoolNodeKind: {
1303 // The current definition of the MaxPool node does not allow splitting
1304 // of the second output operand 'Argmax' which contains flattened
1305 // indices whose values will be altered if processed in smaller chunks.
1306 // We allow splitting only if the 'Argmax' node value has no users.
1307 if (node->getNthResult(MaxPoolNode::ArgmaxIdx).getNumUsers() != 0) {
1308 break;
1309 }
1310 return splitAndReplaceNode(
1311 node, splitOption, splitConstraint, MaxPoolNode::ResultIdx,
1312 getPoolInputIdxAndMaps<MaxPoolNode, ShapeNHWC>(
1313 dyn_cast<MaxPoolNode>(node)),
1314 {{MaxPoolNode::ArgmaxIdx, CheckedSliceRangeMapIdentity}},
1315 PoolSplitNodeModifier<MaxPoolNode, ShapeNHWC>);
1316 }
1317
1318 case Kinded::Kind::AvgPoolNodeKind: {
1319 return splitAndReplaceNode(
1320 node, splitOption, splitConstraint, AvgPoolNode::ResultIdx,
1321 getPoolInputIdxAndMaps<AvgPoolNode, ShapeNHWC>(
1322 dyn_cast<AvgPoolNode>(node)),
1323 {}, PoolSplitNodeModifier<AvgPoolNode, ShapeNHWC>);
1324 }
1325
1326 case Kinded::Kind::FullyConnectedNodeKind: {
1327 return splitAndReplaceNode(
1328 node, splitOption, splitConstraint, FullyConnectedNode::ResultIdx,
1329 getFullyConnectedInputIdxAndMaps(dyn_cast<FullyConnectedNode>(node)));
1330 }
1331
1332 case Kinded::Kind::MatMulNodeKind: {
1333 return splitAndReplaceNode(
1334 node, splitOption, splitConstraint, MatMulNode::ResultIdx,
1335 getMatMulInputIdxAndMaps(dyn_cast<MatMulNode>(node)));
1336 }
1337
1338 case Kinded::Kind::BatchMatMulNodeKind: {
1339 return splitAndReplaceNode(
1340 node, splitOption, splitConstraint, BatchMatMulNode::ResultIdx,
1341 getBatchMatMulInputIdxAndMaps(dyn_cast<BatchMatMulNode>(node)));
1342 }
1343
1344 case Kinded::Kind::BatchedAddNodeKind: {
1345 return splitAndReplaceNode(
1346 node, splitOption, splitConstraint, BatchedAddNode::ResultIdx,
1347 getBatchedAddInputIdxAndMaps(dyn_cast<BatchedAddNode>(node)));
1348 }
1349
1350 case Kinded::Kind::TransposeNodeKind: {
1351 return splitAndReplaceNode(
1352 node, splitOption, splitConstraint, TransposeNode::ResultIdx,
1353 getTransposeInputIdxAndMaps(dyn_cast<TransposeNode>(node)));
1354 }
1355
1356 case Kinded::Kind::AddNodeKind:
1357 case Kinded::Kind::MulNodeKind:
1358 case Kinded::Kind::SubNodeKind:
1359 case Kinded::Kind::DivNodeKind:
1360 case Kinded::Kind::FmodNodeKind:
1361 case Kinded::Kind::MaxNodeKind:
1362 case Kinded::Kind::MinNodeKind:
1363 case Kinded::Kind::CmpLTENodeKind:
1364 case Kinded::Kind::CmpLTNodeKind:
1365 case Kinded::Kind::CmpEQNodeKind:
1366 case Kinded::Kind::PowNodeKind: {
1367 DCHECK_EQ(node->getNumInputs(), 2) << "Binary operator invalid!";
1368 DCHECK_EQ(node->getNumResults(), 1) << "Binary operator invalid!";
1369 return splitAndReplaceNode(
1370 node, splitOption, splitConstraint, ArithmeticNode::ResultIdx,
1371 {{ArithmeticNode::LHSIdx, CheckedSliceRangeMapIdentity},
1372 {ArithmeticNode::RHSIdx, CheckedSliceRangeMapIdentity}});
1373 }
1374
1375 case Kinded::Kind::ReluNodeKind:
1376 case Kinded::Kind::LeakyReluNodeKind:
1377 case Kinded::Kind::ClipNodeKind:
1378 case Kinded::Kind::TanhNodeKind:
1379 case Kinded::Kind::SigmoidNodeKind:
1380 case Kinded::Kind::LogNodeKind:
1381 case Kinded::Kind::ExpNodeKind:
1382 case Kinded::Kind::QuantizeNodeKind:
1383 case Kinded::Kind::RescaleQuantizedNodeKind:
1384 case Kinded::Kind::DequantizeNodeKind:
1385 case Kinded::Kind::ConvertToNodeKind: {
1386 DCHECK_EQ(node->getNumInputs(), 1) << "Unary operator invalid!";
1387 DCHECK_EQ(node->getNumResults(), 1) << "Unary operator invalid!";
1388 return splitAndReplaceNode(node, splitOption, splitConstraint,
1389 /*splitOutputIdx*/ 0,
1390 {{0, CheckedSliceRangeMapIdentity}});
1391 }
1392
1393 default:
1394 VLOG(1) << "Splitting node type '" << node->getKindName()
1395 << "' is not supported!\n";
1396 break;
1397 }
1398
1399 return std::vector<Node *>();
1400}
1401
1402Expected<std::vector<Node *>>
1403glow::splitNode(Node *node, const SplitNodeOption &splitOption) {
1404 return splitNode(node, &splitOption, nullptr);
1405}
1406
1407Expected<std::vector<Node *>>
1408glow::splitNode(Node *node, const SplitNodeConstraint &splitConstraint) {
1409 return splitNode(node, nullptr, &splitConstraint);
1410}
1411
1412///===---------------------------------------------------------------------===//
1413/// splitNodes
1414///===---------------------------------------------------------------------===//
1415Expected<SplitNodeMap>
1416glow::splitNodes(Function *F, const SplitNodeOptionMap &splitOptionMap,
1417 const SplitNodeConstraintMap &splitConstraintMap) {
1418 // Create split map.
1419 SplitNodeMap splitMap;
1420
1421 // Since we will be transforming the original list of nodes, reverse iterate.
1422 auto &nodes = F->getNodes();
1423 for (auto it = nodes.rbegin(), e = nodes.rend(); it != e;) {
1424 Node *node = &*(it++);
1425
1426 // Find explicit split option for current node (if any).
1427 const SplitNodeOption *splitOption = nullptr;
1428 auto splitOptionIt = splitOptionMap.find(node);
1429 if (splitOptionIt != splitOptionMap.end()) {
1430 splitOption = splitOptionIt->second;
1431 }
1432
1433 // Find explicit split constraint for current node (if any).
1434 const SplitNodeConstraint *splitConstraint = nullptr;
1435 auto splitConstraintIt = splitConstraintMap.find(node);
1436 if (splitConstraintIt != splitConstraintMap.end()) {
1437 splitConstraint = splitConstraintIt->second;
1438 }
1439
1440 // Split current node if at least the option or the constraint is given.
1441 if (splitOption || splitConstraint) {
1442 ASSIGN_VALUE_OR_RETURN_ERR(splitMap[node],
1443 splitNode(node, splitOption, splitConstraint));
1444 }
1445 }
1446
1447 // Verify function after splitting nodes.
1448 RETURN_ERR_IF_NOT(F->verify(), "Function is not valid after node splitting!");
1449 return splitMap;
1450}
1451
1452Expected<SplitNodeMap> glow::splitNodes(Function *F,
1453 const SplitNodeOption &splitOption) {
1454 // Since we will be transforming the original list of nodes, reverse iterate.
1455 SplitNodeMap splitMap;
1456 auto &nodes = F->getNodes();
1457 for (auto it = nodes.rbegin(), e = nodes.rend(); it != e;) {
1458 Node *node = &*(it++);
1459 const SplitNodeConstraint *splitConstraint = nullptr;
1460 ASSIGN_VALUE_OR_RETURN_ERR(splitMap[node],
1461 splitNode(node, &splitOption, splitConstraint));
1462 }
1463 // Verify function after splitting nodes.
1464 RETURN_ERR_IF_NOT(F->verify(), "Function is not valid after node splitting!");
1465 return splitMap;
1466}
1467
1468Expected<SplitNodeMap>
1469glow::splitNodes(Function *F, const SplitNodeConstraint &splitConstraint) {
1470 // Since we will be transforming the original list of nodes, reverse iterate.
1471 SplitNodeMap splitMap;
1472 auto &nodes = F->getNodes();
1473 for (auto it = nodes.rbegin(), e = nodes.rend(); it != e;) {
1474 Node *node = &*(it++);
1475 const SplitNodeOption *splitOption = nullptr;
1476 ASSIGN_VALUE_OR_RETURN_ERR(splitMap[node],
1477 splitNode(node, splitOption, &splitConstraint));
1478 }
1479 // Verify function after splitting nodes.
1480 RETURN_ERR_IF_NOT(F->verify(), "Function is not valid after node splitting!");
1481 return splitMap;
1482}
1483
1484///===---------------------------------------------------------------------===//
1485/// splitNodeRecursively
1486///===---------------------------------------------------------------------===//
1487static Error
1488splitNodeRecursivelyMain(SplitNodeMap &splitMap, Node *node,
1489 const SplitNodeOption *splitOption,
1490 const SplitNodeConstraint *splitConstraint,
1491 unsigned maxDepth, bool singleUseOnly) {
1492
1493 // Early return if depth is 0.
1494 if (maxDepth == 0) {
1495 return Error::success();
1496 }
1497
1498 // We can do the splitting if at least the option or the constraint is given.
1499 RETURN_ERR_IF_NOT(
1500 splitOption || splitConstraint,
1501 "At least the split option or the split constraint must be given!");
1502
1503 // Split starting node.
1504 std::vector<Node *> splitNodesCurr;
1505 ASSIGN_VALUE_OR_RETURN_ERR(splitNodesCurr,
1506 splitNode(node, splitOption, splitConstraint));
1507
1508 // If starting node was NOT split then return, otherwise add nodes to map.
1509 if (!splitNodesCurr.size()) {
1510 return Error::success();
1511 } else {
1512 splitMap[node] = splitNodesCurr;
1513 }
1514
1515 // Early return if depth is 1.
1516 if (maxDepth == 1) {
1517 return Error::success();
1518 }
1519
1520 // Iterate through all of the input operands and split them.
1521 unsigned inpNum = splitNodesCurr.front()->getNumInputs();
1522 for (unsigned inpIdx = 0; inpIdx < inpNum; ++inpIdx) {
1523
1524 // Find parent node value and ranges for the current input.
1525 std::vector<SliceRange> inputRanges;
1526 inputRanges.reserve(splitNodesCurr.size());
1527 NodeValue sliceInputNodeValue = nullptr;
1528 for (const Node *splitNode : splitNodesCurr) {
1529 // Get parent node value and range.
1530 NodeValue inputNV = splitNode->getNthInput(inpIdx);
1531 if (auto *sliceNode = dyn_cast<SliceNode>(inputNV)) {
1532 inputNV = sliceNode->getInput();
1533 inputRanges.push_back(SliceRange(sliceNode));
1534 } else {
1535 inputRanges.push_back(SliceRange(inputNV.getType()));
1536 }
1537 // Verify that parent node value is common.
1538 if (!sliceInputNodeValue) {
1539 sliceInputNodeValue = inputNV;
1540 } else {
1541 RETURN_ERR_IF_NOT(
1542 sliceInputNodeValue == inputNV,
1543 "Input slices do not have a common parent node value!");
1544 }
1545 }
1546
1547 // If the node value which is sliced has other consumers than the SliceNodes
1548 // inserted during splitting then we do not split the node which produces
1549 // that node value.
1550 if (singleUseOnly &&
1551 sliceInputNodeValue.getNumUsers() > splitNodesCurr.size()) {
1552 continue;
1553 }
1554
1555 // Check that we split the 1st output operand of the parent node.
1556 if (sliceInputNodeValue.getResNo() != SplitNodeOutputIdx) {
1557 continue;
1558 }
1559
1560 // Split the input node of the SliceNodes using same slice ranges.
1561 auto splitInputOption = SplitNodeBySliceRanges(inputRanges);
1562 Node *splitInputNode = sliceInputNodeValue.getNode();
1563 RETURN_IF_ERR(splitNodeRecursivelyMain(splitMap, splitInputNode,
1564 &splitInputOption, splitConstraint,
1565 maxDepth - 1, singleUseOnly));
1566
1567 // Remove Slice and Insert nodes between adjacent split nodes.
1568 if (splitMap.count(splitInputNode)) {
1569 auto &splitNodesNext = splitMap[splitInputNode];
1570 assert(splitNodesCurr.size() == splitNodesNext.size() &&
1571 "Mismatch for number of split Nodes!");
1572 // Create short circuit between inputs and outputs of adjacent nodes.
1573 for (size_t idx = 0, len = splitNodesCurr.size(); idx < len; ++idx) {
1574 NodeValue splitNodeNextOut =
1575 splitNodesNext[idx]->getNthResult(SplitNodeOutputIdx);
1576 NodeValue splitNodeCurrInp = splitNodesCurr[idx]->getNthInput(inpIdx);
1577 assert(
1578 splitNodeNextOut.getType()->isEqual(splitNodeCurrInp.getType()) &&
1579 "Mismatch between input/output type when doing short-circuit!");
1580 assert(splitNodeNextOut.getNumUsers() == 1 &&
1581 "Split node output value has more than one use!");
1582 assert(splitNodeCurrInp.getNumUsers() == 1 &&
1583 "Split node input value has more than one use!");
1584 splitNodeCurrInp.replaceAllUsesOfWith(splitNodeNextOut);
1585 }
1586 }
1587 }
1588
1589 return Error::success();
1590}
1591
1592Expected<SplitNodeMap>
1593glow::splitNodeRecursively(Node *node, const SplitNodeOption *splitOption,
1594 const SplitNodeConstraint *splitConstraint,
1595 unsigned maxDepth, bool singleUseOnly) {
1596 Function *F = node->getParent();
1597 RETURN_ERR_IF_NOT(F, "Cannot split a node without a parent Function!");
1598
1599 // Create split map.
1600 SplitNodeMap splitMap;
1601
1602 // Check if this node has single use only before splitting it.
1603 if (singleUseOnly &&
1604 node->getNthResult(SplitNodeOutputIdx).getNumUsers() > 1) {
1605 return splitMap;
1606 }
1607
1608 // Split node recursively.
1609 RETURN_IF_ERR(splitNodeRecursivelyMain(
1610 splitMap, node, splitOption, splitConstraint, maxDepth, singleUseOnly));
1611
1612 // Perform DCE to cleanup.
1613 auto cctx = glow::CompilationContext();
1614 glow::runDCEPass(F, cctx);
1615
1616 // Verify function after splitting nodes.
1617 RETURN_ERR_IF_NOT(F->verify(),
1618 "Function is not valid after recursive node splitting!");
1619 return splitMap;
1620}
1621
1622Expected<SplitNodeMap>
1623glow::splitNodeRecursively(Node *node, const SplitNodeOption &splitOption,
1624 unsigned maxDepth, bool singleUseOnly) {
1625 return splitNodeRecursively(node, &splitOption, nullptr, maxDepth,
1626 singleUseOnly);
1627}
1628
1629Expected<SplitNodeMap>
1630glow::splitNodeRecursively(Node *node,
1631 const SplitNodeConstraint &splitConstraint,
1632 unsigned maxDepth, bool singleUseOnly) {
1633 return splitNodeRecursively(node, nullptr, &splitConstraint, maxDepth,
1634 singleUseOnly);
1635}
1636