1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Optimizer/GraphOptimizer/NodeSplitting.h" |
18 | #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h" |
19 | |
20 | #include <algorithm> |
21 | #include <numeric> |
22 | #include <unordered_map> |
23 | #include <unordered_set> |
24 | #include <vector> |
25 | |
26 | using namespace glow; |
27 | using llvm::dyn_cast; |
28 | |
29 | ///===---------------------------------------------------------------------===// |
30 | /// SplitNodeOption |
31 | ///===---------------------------------------------------------------------===// |
32 | size_t SplitNodeOptionOrthogonal::getSplitDimIdx(dim_t splitDim) const { |
33 | auto splitDimsIt = std::find(splitDims_.begin(), splitDims_.end(), splitDim); |
34 | CHECK(splitDimsIt != splitDims_.end()) |
35 | << "Split dimension '" << splitDim |
36 | << "' invalid! Not registered in this SplitNodeOption!" ; |
37 | return std::distance(splitDims_.begin(), splitDimsIt); |
38 | } |
39 | |
40 | std::vector<dim_t> SplitNodeByNumChunks::splitAlongDim(size_t dim, |
41 | dim_t dimSize) const { |
42 | size_t dimIdx = getSplitDimIdx(dim); |
43 | dim_t numChunks = numChunks_[dimIdx]; |
44 | CHECK((1 <= numChunks) && (numChunks <= dimSize)) |
45 | << "SplitNodeByNumChunks: Invalid number of chunks '" << numChunks |
46 | << "' for splitting a dimension with size '" << dimSize << "'!" ; |
47 | |
48 | dim_t chunkDiv = dimSize / numChunks; |
49 | dim_t chunkRem = dimSize % numChunks; |
50 | |
51 | // Small and big chunk sizes. |
52 | dim_t numChunksBig = chunkRem; |
53 | dim_t chunkSizeSmall = chunkDiv; |
54 | dim_t chunkSizeBig = chunkDiv + 1; |
55 | |
56 | // Split dimension. |
57 | std::vector<dim_t> chunkSizes(numChunks); |
58 | for (size_t idx = 0, end = numChunks; idx < end; ++idx) { |
59 | dim_t chunkSize = chunkSizeSmall; |
60 | if (bigChunksFirst_ && (idx < numChunksBig)) { |
61 | chunkSize = chunkSizeBig; |
62 | } |
63 | if ((!bigChunksFirst_ && (idx >= numChunks - numChunksBig))) { |
64 | chunkSize = chunkSizeBig; |
65 | } |
66 | chunkSizes[idx] = chunkSize; |
67 | } |
68 | return chunkSizes; |
69 | } |
70 | |
71 | std::vector<dim_t> SplitNodeByChunkSize::splitAlongDim(size_t dim, |
72 | dim_t dimSize) const { |
73 | size_t dimIdx = getSplitDimIdx(dim); |
74 | dim_t chunkSize = chunkSizes_[dimIdx]; |
75 | CHECK((1 <= chunkSize) && (chunkSize <= dimSize)) |
76 | << "SplitNodeByChunkSize: Invalid chunk size '" << chunkSize |
77 | << "' for splitting a dimension with size '" << dimSize << "'!" ; |
78 | |
79 | dim_t chunkDiv = dimSize / chunkSize; |
80 | dim_t chunkRem = dimSize % chunkSize; |
81 | |
82 | // Small and big chunk sizes. |
83 | dim_t numChunks = chunkRem > 0 ? chunkDiv + 1 : chunkDiv; |
84 | dim_t chunkSizeSmall = chunkRem > 0 ? chunkRem : chunkSize; |
85 | dim_t chunkSizeBig = chunkSize; |
86 | |
87 | // Split dimension. |
88 | std::vector<dim_t> chunkSizes(numChunks); |
89 | for (size_t idx = 0, end = numChunks; idx < end; ++idx) { |
90 | dim_t chunkSizeFinal = chunkSizeBig; |
91 | if (bigChunksFirst_ && (idx == numChunks - 1)) { |
92 | chunkSizeFinal = chunkSizeSmall; |
93 | } |
94 | if (!bigChunksFirst_ && (idx == 0)) { |
95 | chunkSizeFinal = chunkSizeSmall; |
96 | } |
97 | chunkSizes[idx] = chunkSizeFinal; |
98 | } |
99 | return chunkSizes; |
100 | } |
101 | |
102 | std::vector<dim_t> SplitNodeByChunkSizes::splitAlongDim(size_t dim, |
103 | dim_t dimSize) const { |
104 | size_t dimIdx = getSplitDimIdx(dim); |
105 | std::vector<dim_t> chunkSizes = chunkSizes_[dimIdx]; |
106 | size_t numChunks = chunkSizes.size(); |
107 | CHECK((1 <= numChunks) && (numChunks <= dimSize)) |
108 | << "SplitNodeByChunkSizes: Invalid number of sizes '" << numChunks |
109 | << "' for splitting a dimension with size '" << dimSize << "'!" ; |
110 | for (const auto &chunkSize : chunkSizes) { |
111 | CHECK_GT(chunkSize, 0) |
112 | << "SplitNodeByChunkSizes: Chunk size 0 is not allowed!" ; |
113 | } |
114 | return chunkSizes; |
115 | } |
116 | |
117 | std::vector<dim_t> SplitNodeByChunkWeights::splitAlongDim(size_t dim, |
118 | dim_t dimSize) const { |
119 | size_t dimIdx = getSplitDimIdx(dim); |
120 | const std::vector<float> &chunkWeights = chunkWeights_[dimIdx]; |
121 | dim_t numChunks = chunkWeights.size(); |
122 | CHECK((1 <= numChunks) && (numChunks <= dimSize)) |
123 | << "SplitNodeByChunkWeights: Invalid number of weights '" << numChunks |
124 | << "' for splitting a dimension with size '" << dimSize << "'!" ; |
125 | |
126 | // Verify that all the weights are positive and compute the weights sum. |
127 | float chunkWeightsSum = 0; |
128 | for (const auto &weight : chunkWeights) { |
129 | CHECK_GT(weight, 0.f) << "SplitNodeByChunkWeights: Chunk weight '" << weight |
130 | << "' invalid! Should be strictly positive!" ; |
131 | chunkWeightsSum += weight; |
132 | } |
133 | |
134 | // Compute individual chunk sizes such that each chunk gets at least one unit |
135 | // (empty chunks are NOT allowed). The total number of units is distributed |
136 | // such that the error between the given chunk weights and the actual chunk |
137 | // weights is minimized. |
138 | std::vector<dim_t> chunkSizes(numChunks, 1); |
139 | dim_t unitsRem = dimSize - numChunks; |
140 | while (unitsRem > 0) { |
141 | // Find chunk with maximum weight error. |
142 | float weightErrMax = std::numeric_limits<float>::lowest(); |
143 | size_t weightErrMaxIdx = 0; |
144 | for (size_t idx = 0; idx < numChunks; ++idx) { |
145 | float weightVal = |
146 | float(chunkSizes[idx]) / float(dimSize) * chunkWeightsSum; |
147 | // We use a signed error here to starve those chunks for which the actual |
148 | // weight surpassed the given weight. |
149 | float weightErr = chunkWeights[idx] - weightVal; |
150 | if (weightErr > weightErrMax) { |
151 | weightErrMax = weightErr; |
152 | weightErrMaxIdx = idx; |
153 | } |
154 | } |
155 | // Distribute unit. |
156 | chunkSizes[weightErrMaxIdx] += 1; |
157 | unitsRem--; |
158 | } |
159 | return chunkSizes; |
160 | } |
161 | |
162 | /// Utility function to split an array of slice ranges \p ranges along the given |
163 | /// dimension \p dim using the split option \p splitOption. |
164 | static std::vector<SliceRange> |
165 | splitSliceRanges(const std::vector<SliceRange> &ranges, size_t dim, |
166 | const SplitNodeOptionOrthogonal *splitOption) { |
167 | std::vector<SliceRange> outRanges; |
168 | for (const auto &range : ranges) { |
169 | |
170 | // Split dimension. |
171 | dim_t dimSize = range.getDimSize(dim); |
172 | std::vector<dim_t> chunkSizes = splitOption->splitAlongDim(dim, dimSize); |
173 | |
174 | // Check for empty chunks. |
175 | for (auto chunkSize : chunkSizes) { |
176 | CHECK_GT(chunkSize, 0) << "Chunk size 0 is not allowed!" ; |
177 | } |
178 | |
179 | // Check dimension splitting consistency. |
180 | dim_t chunkSizesSum = |
181 | std::accumulate(chunkSizes.begin(), chunkSizes.end(), (dim_t)0); |
182 | CHECK_EQ(dimSize, chunkSizesSum) |
183 | << "Inconsistent splitting of dimension " << dim << " with size " |
184 | << dimSize << " into chunks with total size " << chunkSizesSum << "!" ; |
185 | |
186 | // Split current slice range. |
187 | auto numChunks = chunkSizes.size(); |
188 | std::vector<SliceRange> splitRanges(numChunks, range); |
189 | dim_t chunkStart = range[dim].first; |
190 | for (size_t idx = 0; idx < numChunks; ++idx) { |
191 | // Current chunk size. |
192 | dim_t chunkSize = chunkSizes[idx]; |
193 | // Update chunk bounds. |
194 | splitRanges[idx][dim].first = chunkStart; |
195 | splitRanges[idx][dim].second = chunkStart + chunkSize; |
196 | chunkStart += chunkSize; |
197 | } |
198 | CHECK_EQ(splitRanges.back()[dim].second, range[dim].second) |
199 | << "Inconsistent splitting of SliceRange!" ; |
200 | |
201 | // Append split slice ranges. |
202 | outRanges.insert(outRanges.end(), splitRanges.begin(), splitRanges.end()); |
203 | } |
204 | return outRanges; |
205 | } |
206 | |
207 | ///===---------------------------------------------------------------------===// |
208 | /// CheckedSliceRangeMap |
209 | ///===---------------------------------------------------------------------===// |
210 | /// Definition of a checked slice range which provides an extra boolean flag to |
211 | /// inform whether the slice range is valid and allowed to be used. |
212 | using CheckedSliceRange = std::pair<bool, SliceRange>; |
213 | |
214 | /// Definition of a functional mapping between two slice ranges with extra |
215 | /// information about whether the mapping is allowed to be used (valid). |
216 | using CheckedSliceRangeMap = |
217 | std::function<CheckedSliceRange(const SliceRange &)>; |
218 | |
219 | /// Identity checked slice range map to use for simple identity mappings. |
220 | CheckedSliceRange CheckedSliceRangeMapIdentity(const SliceRange &range) { |
221 | return {true, range}; |
222 | } |
223 | |
224 | /// Definition of a pair with an operand index and a checked slice range map. |
225 | using OpIdxAndMap = std::pair<unsigned, CheckedSliceRangeMap>; |
226 | |
227 | /// Utility function to verify that a given slice range \p map represents an |
228 | /// exact mapping from \p mapInputRanges to \p mapOutputRanges. |
229 | static bool isMappingExact(const CheckedSliceRangeMap &map, |
230 | const std::vector<SliceRange> &mapInputRanges, |
231 | const std::vector<SliceRange> &mapOutputRanges) { |
232 | bool mapOk = true; |
233 | DCHECK_EQ(mapInputRanges.size(), mapOutputRanges.size()) |
234 | << "Slice ranges length mismatch for CheckedSliceRangeMap verification!" ; |
235 | for (size_t idx = 0, e = mapInputRanges.size(); idx < e; ++idx) { |
236 | auto checkedSliceRange = map(mapInputRanges[idx]); |
237 | mapOk = mapOk && checkedSliceRange.first; |
238 | mapOk = mapOk && (checkedSliceRange.second == mapOutputRanges[idx]); |
239 | } |
240 | return mapOk; |
241 | } |
242 | |
243 | /// Utility function to verify that a given slice range \p map when applied to |
244 | /// \p mapInputRanges produces ranges which are included in \p mapOutputRanges. |
245 | /// This is a weaker verification than \ref isMappingExact. |
246 | static bool isMappingIncluded(const CheckedSliceRangeMap &map, |
247 | const std::vector<SliceRange> &mapInputRanges, |
248 | const std::vector<SliceRange> &mapOutputRanges) { |
249 | bool mapOk = true; |
250 | DCHECK_EQ(mapInputRanges.size(), mapOutputRanges.size()) |
251 | << "Slice ranges length mismatch for CheckedSliceRangeMap verification!" ; |
252 | for (size_t idx = 0, e = mapInputRanges.size(); idx < e; ++idx) { |
253 | auto checkedSliceRange = map(mapInputRanges[idx]); |
254 | mapOk = mapOk && checkedSliceRange.first; |
255 | mapOk = |
256 | mapOk && checkedSliceRange.second.isIncludedBy(mapOutputRanges[idx]); |
257 | } |
258 | return mapOk; |
259 | } |
260 | |
261 | ///===---------------------------------------------------------------------===// |
262 | /// SplitNodeModifier |
263 | ///===---------------------------------------------------------------------===// |
264 | /// Definition of a function which modifies the split node \p splitNode after it |
265 | /// was cloned from the original node \p origNode. The input slice ranges \p |
266 | /// inputSliceRanges and the output slice ranges \p outputSliceRanges are also |
267 | /// provided by the caller to provide extra context about how the split node was |
268 | /// obtained from the original node. This function is provided to the node |
269 | /// splitting procedure as a callback and provides the mechanism of modifying |
270 | /// the split node attributes in special situations, for example when the "Pads" |
271 | /// or "Group" node attributes must be changed when splitting Convolution nodes. |
272 | using SplitNodeModifier = |
273 | std::function<void(const Node *origNode, Node *splitNode, |
274 | const std::vector<SliceRange> &inputSliceRanges, |
275 | const std::vector<SliceRange> &outputSliceRanges)>; |
276 | |
277 | /// Definition of a "nop" split node modifier which does no modifications. |
278 | void SplitNodeModifierNop(const Node *origNode, Node *splitNode, |
279 | const std::vector<SliceRange> &inputSliceRanges, |
280 | const std::vector<SliceRange> &outputSliceRanges) {} |
281 | |
282 | ///===---------------------------------------------------------------------===// |
283 | /// verifySplitParams |
284 | ///===---------------------------------------------------------------------===// |
285 | /// List of nodes for which there is a weak mapping between input and output |
286 | /// and thus a weaker verification must be performed. Such an example is the |
287 | /// Conv2D/MaxPool node when using strides larger than 1 resulting in cases |
288 | /// where the output operand does not reference the input operand entirely. |
289 | static std::vector<Kinded::Kind> weakOutToInMappingNodeKinds = { |
290 | Kinded::Kind::ConvolutionNodeKind, |
291 | Kinded::Kind::MaxPoolNodeKind, |
292 | Kinded::Kind::AvgPoolNodeKind, |
293 | }; |
294 | |
295 | /// Function to verify the split parameters. |
296 | static Error |
297 | verifySplitParams(const Node *node, dim_t splitOutputIdx, |
298 | const llvm::ArrayRef<size_t> &splitDims, |
299 | const llvm::ArrayRef<OpIdxAndMap> &inputIdxAndMaps, |
300 | const llvm::ArrayRef<OpIdxAndMap> &outputIdxAndMaps) { |
301 | |
302 | // Verify original node. |
303 | if (!node->verify()) { |
304 | llvm::errs() << node->toString() << "\n" ; |
305 | return MAKE_ERR("Invalid node given to node splitting procedure!" ); |
306 | } |
307 | |
308 | // Verify split dims. |
309 | RETURN_ERR_IF_NOT(splitDims.size() > 0, |
310 | "Empty split dimensions for splitting node!" ); |
311 | RETURN_ERR_IF_NOT(splitOutputIdx < node->getNumResults(), |
312 | "Invalid output index for splitting node!" ); |
313 | for (size_t dim = 0; dim < splitDims.size() - 1; ++dim) { |
314 | RETURN_ERR_IF_NOT(splitDims[dim] < splitDims[dim + 1], |
315 | "Invalid split dimensions for splitting node! Dimensions " |
316 | "should be given in ascending order e.g. {0,2,3}!" ); |
317 | } |
318 | for (const auto dim : splitDims) { |
319 | RETURN_ERR_IF_NOT(dim < node->getType(splitOutputIdx)->dims().size(), |
320 | "Invalid split dimension for splitting node! Dimension " |
321 | "exceeds the split output tensor shape!" ); |
322 | } |
323 | |
324 | // Verify all the input indices and maps were given. |
325 | RETURN_ERR_IF_NOT(inputIdxAndMaps.size() == node->getNumInputs(), |
326 | "Invalid number of input maps for splitting node!" ); |
327 | std::vector<bool> inputIdxMask(node->getNumInputs(), false); |
328 | for (const auto &inputIdxMap : inputIdxAndMaps) { |
329 | RETURN_ERR_IF_NOT(inputIdxMap.first < node->getNumInputs(), |
330 | "Invalid input index for input range map!" ); |
331 | inputIdxMask[inputIdxMap.first] = true; |
332 | } |
333 | RETURN_ERR_IF_NOT( |
334 | std::find(inputIdxMask.begin(), inputIdxMask.end(), false) == |
335 | inputIdxMask.end(), |
336 | "Not all input indices and maps were provided for splitting node!" ); |
337 | |
338 | // Verify all the output indices and maps were given. |
339 | RETURN_ERR_IF_NOT(outputIdxAndMaps.size() == node->getNumResults() - 1, |
340 | "Invalid number of output maps for splitting node!" ); |
341 | std::vector<bool> outputIdxMask(node->getNumResults(), false); |
342 | outputIdxMask[splitOutputIdx] = true; |
343 | for (const auto &outputIdxMap : outputIdxAndMaps) { |
344 | RETURN_ERR_IF_NOT(outputIdxMap.first < node->getNumResults(), |
345 | "Invalid output index for output range map!" ); |
346 | outputIdxMask[outputIdxMap.first] = true; |
347 | } |
348 | RETURN_ERR_IF_NOT( |
349 | std::find(outputIdxMask.begin(), outputIdxMask.end(), false) == |
350 | outputIdxMask.end(), |
351 | "Not all output indices and maps were provided for splitting node!" ); |
352 | |
353 | // Get split output range. |
354 | SliceRange splitOutputRange = SliceRange(node->getType(splitOutputIdx)); |
355 | |
356 | // Verify the input slice range maps. |
357 | for (const auto &inputIdxMap : inputIdxAndMaps) { |
358 | SliceRange inputRange = |
359 | SliceRange(node->getNthInput(inputIdxMap.first).getType()); |
360 | if (std::find(weakOutToInMappingNodeKinds.begin(), |
361 | weakOutToInMappingNodeKinds.end(), |
362 | node->getKind()) != weakOutToInMappingNodeKinds.end()) { |
363 | // Verify weak mapping. |
364 | RETURN_ERR_IF_NOT(isMappingIncluded(inputIdxMap.second, |
365 | {splitOutputRange}, {inputRange}), |
366 | "Invalid input range map for splitting node!" ); |
367 | } else { |
368 | // Verify exact mapping. |
369 | RETURN_ERR_IF_NOT( |
370 | isMappingExact(inputIdxMap.second, {splitOutputRange}, {inputRange}), |
371 | "Invalid input range map for splitting node!" ); |
372 | } |
373 | } |
374 | |
375 | // Verify the output slice range maps. |
376 | for (const auto &outputIdxMap : outputIdxAndMaps) { |
377 | SliceRange outputRange = SliceRange(node->getType(outputIdxMap.first)); |
378 | RETURN_ERR_IF_NOT( |
379 | isMappingExact(outputIdxMap.second, {splitOutputRange}, {outputRange}), |
380 | "Invalid output range map for splitting node!" ); |
381 | } |
382 | |
383 | return Error::success(); |
384 | } |
385 | |
386 | ///===---------------------------------------------------------------------===// |
387 | /// verifySplitNodes |
388 | ///===---------------------------------------------------------------------===// |
389 | /// Function to verify the split nodes. |
390 | static Expected<bool> |
391 | verifySplitNodes(const Node *node, dim_t splitOutputIdx, |
392 | const llvm::ArrayRef<SliceRange> &splitOutputSlices, |
393 | const llvm::ArrayRef<OpIdxAndMap> &inputIdxAndMaps, |
394 | const llvm::ArrayRef<OpIdxAndMap> &outputIdxAndMaps, |
395 | const SplitNodeConstraint *splitConstraint, |
396 | const SplitNodeModifier &splitNodeModifier) { |
397 | |
398 | // Create temporary nodes to make verifications without adding them to |
399 | // the graph in order to avoid the pollution of the graph with nodes |
400 | // which could be invalid or could not meet all the constraints and |
401 | // hence be later removed from the graph. |
402 | bool splitNodesCheck = true; |
403 | std::vector<Node *> splitNodes; |
404 | std::list<std::unique_ptr<SliceNode>> inputSliceNodes; |
405 | std::list<Type> inputTypes; |
406 | std::list<Type> outputTypes; |
407 | for (const auto &splitOutputSlice : splitOutputSlices) { |
408 | |
409 | // Create clone to inherit all the inputs/members of the original node. |
410 | Node *clone = node->clone(); |
411 | splitNodes.push_back(clone); |
412 | |
413 | // Detach clone from all the inputs of the original node. |
414 | for (unsigned idx = 0, e = clone->getNumInputs(); idx < e; ++idx) { |
415 | clone->setNthInput(idx, nullptr); |
416 | } |
417 | |
418 | // Gather input slice ranges for the clone. The ranges are ordered |
419 | // according to the input operand indices. |
420 | std::vector<SliceRange> inputRanges(clone->getNumInputs()); |
421 | for (const auto &inputIdxMap : inputIdxAndMaps) { |
422 | auto inputCheckedRange = inputIdxMap.second(splitOutputSlice); |
423 | splitNodesCheck = splitNodesCheck && inputCheckedRange.first; |
424 | splitNodesCheck = splitNodesCheck && !inputCheckedRange.second.isEmpty(); |
425 | inputRanges[inputIdxMap.first] = inputCheckedRange.second; |
426 | } |
427 | |
428 | // Gather output slice ranges for the clone. The ranges are ordered |
429 | // according to the output operand indices. |
430 | std::vector<SliceRange> outputRanges(clone->getNumResults()); |
431 | outputRanges[splitOutputIdx] = splitOutputSlice; |
432 | for (const auto &outputIdxMap : outputIdxAndMaps) { |
433 | auto outputCheckedRange = outputIdxMap.second(splitOutputSlice); |
434 | splitNodesCheck = splitNodesCheck && outputCheckedRange.first; |
435 | splitNodesCheck = splitNodesCheck && !outputCheckedRange.second.isEmpty(); |
436 | outputRanges[outputIdxMap.first] = outputCheckedRange.second; |
437 | } |
438 | |
439 | // Early break. |
440 | if (!splitNodesCheck) { |
441 | break; |
442 | } |
443 | |
444 | // Set clone input types. Since a node does not own its input types and |
445 | // the clone inherits the input types from the input nodes of the |
446 | // original node we create here dummy input SliceNodes and attach them |
447 | // to the clone in order to allow setting and checking the clone input |
448 | // types without modifying the input types of the original node. |
449 | for (const auto &inputIdxMap : inputIdxAndMaps) { |
450 | auto &inputRange = inputRanges[inputIdxMap.first]; |
451 | auto inputType = |
452 | Type::newShape(*(node->getNthInput(inputIdxMap.first).getType()), |
453 | inputRange.getSizes()); |
454 | inputTypes.push_back(inputType); |
455 | inputSliceNodes.push_back(std::make_unique<SliceNode>( |
456 | "inputSlice" , &(inputTypes.back()), |
457 | /* Input */ nullptr, inputRange.getStarts())); |
458 | clone->setNthInput(inputIdxMap.first, inputSliceNodes.back().get()); |
459 | } |
460 | |
461 | // Set clone split output type. The original node output type is not |
462 | // modified because the clone owns its output types. |
463 | outputTypes.push_back(Type::newShape(*node->getType(splitOutputIdx), |
464 | splitOutputSlice.getSizes())); |
465 | clone->getNthResult(splitOutputIdx).setTypeUnsafe(&outputTypes.back()); |
466 | |
467 | // Set clone output types. The original node output types are not |
468 | // modified because the clone owns its output types. |
469 | for (const auto &outputIdxMap : outputIdxAndMaps) { |
470 | auto &outputRange = outputRanges[outputIdxMap.first]; |
471 | auto outputType = Type::newShape(*node->getType(outputIdxMap.first), |
472 | outputRange.getSizes()); |
473 | outputTypes.push_back(outputType); |
474 | clone->getNthResult(outputIdxMap.first) |
475 | .setTypeUnsafe(&outputTypes.back()); |
476 | } |
477 | |
478 | // Modify clone. |
479 | splitNodeModifier(node, clone, inputRanges, outputRanges); |
480 | |
481 | // Verify clone. If the clone is invalid at this point this means there |
482 | // is a logic error in the splitting infrastructure (the input/output |
483 | // maps are not checked properly or the split node modifier is flawed) |
484 | // so we throw an error (not the same thing as returning false which is |
485 | // intended for signaling that the splitting infrastructure correctly |
486 | // identified an incorrect split configuration or the split configuration |
487 | // is not accepted by the user constraints). |
488 | if (!clone->verify()) { |
489 | // Dump some extra error context. |
490 | llvm::errs() << "Slice range description:\n" ; |
491 | for (unsigned idx = 0, e = clone->getNumInputs(); idx < e; ++idx) { |
492 | llvm::errs() << clone->getInputName(idx) << ": " |
493 | << inputRanges[idx].toString() << "\n" ; |
494 | } |
495 | for (unsigned idx = 0, e = clone->getNumResults(); idx < e; ++idx) { |
496 | llvm::errs() << clone->getOutputName(idx).str() << ": " |
497 | << outputRanges[idx].toString() << "\n" ; |
498 | } |
499 | llvm::errs() << "Node description:\n" ; |
500 | llvm::errs() << clone->toString() << "\n" ; |
501 | return MAKE_ERR("Invalid node obtained during node splitting!" ); |
502 | } |
503 | |
504 | // Early break. |
505 | if (!splitNodesCheck) { |
506 | break; |
507 | } |
508 | } |
509 | |
510 | // Check split nodes against user constraint (if any). |
511 | if (splitConstraint) { |
512 | splitNodesCheck = splitNodesCheck && (*splitConstraint)(node, splitNodes); |
513 | } |
514 | |
515 | // Explicitly destroy the temporary nodes. |
516 | for (auto *splitNode : splitNodes) { |
517 | Node::destroyNode(splitNode); |
518 | } |
519 | |
520 | return splitNodesCheck; |
521 | } |
522 | |
523 | ///===---------------------------------------------------------------------===// |
524 | /// splitAndReplaceNode |
525 | ///===---------------------------------------------------------------------===// |
526 | static Expected<std::vector<Node *>> splitAndReplaceNode( |
527 | Node *node, const SplitNodeOption *splitOption, |
528 | const SplitNodeConstraint *splitConstraint, dim_t splitOutputIdx, |
529 | const llvm::ArrayRef<OpIdxAndMap> &inputIdxAndMaps, |
530 | const llvm::ArrayRef<OpIdxAndMap> &outputIdxAndMaps = {}, |
531 | const SplitNodeModifier &splitNodeModifier = SplitNodeModifierNop) { |
532 | |
533 | // If the split output operand has no dimensions then return. |
534 | if (node->getType(splitOutputIdx)->dims().empty()) { |
535 | return std::vector<Node *>(); |
536 | } |
537 | |
538 | // The default split dims are all the dims of the split output operand. |
539 | RETURN_ERR_IF_NOT(splitOutputIdx < node->getNumResults(), |
540 | "Invalid output index for splitting node!" ); |
541 | std::vector<size_t> splitDims(node->getType(splitOutputIdx)->dims().size()); |
542 | std::iota(splitDims.begin(), splitDims.end(), 0); |
543 | |
544 | // Explicit split dims for this node. |
545 | if (splitOption) { |
546 | // We use explicit split dims only for orthogonal option. |
547 | auto *splitOptionOrthogonal = |
548 | dyn_cast<SplitNodeOptionOrthogonal>(splitOption); |
549 | if (splitOptionOrthogonal) { |
550 | splitDims = splitOptionOrthogonal->getSplitDims(); |
551 | } |
552 | } |
553 | |
554 | // Verify split parameters. |
555 | RETURN_IF_ERR(verifySplitParams(node, splitOutputIdx, splitDims, |
556 | inputIdxAndMaps, outputIdxAndMaps)); |
557 | |
558 | // ------------------------------- Split output ------------------------------ |
559 | // Initialize the split output slices with the initial output range. |
560 | SliceRange splitOutputRange = SliceRange(node->getType(splitOutputIdx)); |
561 | std::vector<SliceRange> splitOutputSlices = {splitOutputRange}; |
562 | |
563 | // If a specific split option is given then we do a targeted splitting. |
564 | // If no specific split option is given then we search a split configuration |
565 | // which meets the constraint. |
566 | if (splitOption) { |
567 | |
568 | // Orthogonal: Split along all the given dimensions using the given option. |
569 | // Non-orthogonal: Use the raw slice ranges explicitly provided. |
570 | auto *splitOptionOrthogonal = |
571 | dyn_cast<SplitNodeOptionOrthogonal>(splitOption); |
572 | if (splitOptionOrthogonal) { |
573 | for (size_t splitDim : splitDims) { |
574 | splitOutputSlices = splitSliceRanges(splitOutputSlices, splitDim, |
575 | splitOptionOrthogonal); |
576 | } |
577 | } else { |
578 | // Set raw slice ranges. |
579 | auto *splitOptionRaw = dyn_cast<SplitNodeBySliceRanges>(splitOption); |
580 | RETURN_ERR_IF_NOT(splitOptionRaw, |
581 | "Non orthogonal split option not supported!" ); |
582 | splitOutputSlices = splitOptionRaw->getSliceRanges(); |
583 | // We verify that the explicitly provided slice ranges are valid. |
584 | for (const auto &sliceRange : splitOutputSlices) { |
585 | RETURN_ERR_IF_NOT( |
586 | sliceRange.getNumDims() == splitOutputRange.getNumDims(), |
587 | "Non orthogonal slice range rank not equal to output rank!" ); |
588 | RETURN_ERR_IF_NOT( |
589 | sliceRange.isIncludedBy(splitOutputRange), |
590 | "Non orthogonal slice range exceed the output operand range!" ); |
591 | } |
592 | } |
593 | |
594 | // Verify split nodes. |
595 | bool splitNodesCheck = true; |
596 | ASSIGN_VALUE_OR_RETURN_ERR( |
597 | splitNodesCheck, |
598 | verifySplitNodes(node, splitOutputIdx, splitOutputSlices, |
599 | inputIdxAndMaps, outputIdxAndMaps, splitConstraint, |
600 | splitNodeModifier)); |
601 | |
602 | // If split nodes are invalid then we do not perform any splitting. |
603 | if (!splitNodesCheck) { |
604 | return std::vector<Node *>(); |
605 | } |
606 | |
607 | } else { |
608 | |
609 | // When no split option is given a split constraint is mandatory. |
610 | RETURN_ERR_IF_NOT(splitConstraint, "When a split option is not given then " |
611 | "a split constraint must be given!" ); |
612 | |
613 | // Start searching of a split configuration which meets the constraint by |
614 | // splitting along the given dimensions iteratively in smaller chunks. |
615 | bool splitFound = false; |
616 | for (size_t splitDim : splitDims) { |
617 | |
618 | dim_t splitDimSize = splitOutputRange.getDimSize(splitDim); |
619 | std::vector<SliceRange> splitOutputSlicesTemp; |
620 | for (dim_t dimNumChunks = 1; dimNumChunks <= splitDimSize; |
621 | dimNumChunks++) { |
622 | |
623 | // Split along current dimension in the given number of chunks. |
624 | auto splitOptionSearch = |
625 | SplitNodeByNumChunks({splitDim}, {dimNumChunks}); |
626 | splitOutputSlicesTemp = |
627 | splitSliceRanges(splitOutputSlices, splitDim, &splitOptionSearch); |
628 | |
629 | // Verify split nodes. |
630 | bool splitNodesCheck = true; |
631 | ASSIGN_VALUE_OR_RETURN_ERR( |
632 | splitNodesCheck, |
633 | verifySplitNodes(node, splitOutputIdx, splitOutputSlicesTemp, |
634 | inputIdxAndMaps, outputIdxAndMaps, splitConstraint, |
635 | splitNodeModifier)); |
636 | |
637 | // If split is found we stop searching. |
638 | if (splitNodesCheck) { |
639 | splitFound = true; |
640 | break; |
641 | } |
642 | } |
643 | |
644 | // Save the split output slices. |
645 | splitOutputSlices = splitOutputSlicesTemp; |
646 | |
647 | // If split is found we stop searching. |
648 | if (splitFound) { |
649 | break; |
650 | } |
651 | } |
652 | |
653 | // If split is not found then we do not perform any splitting. |
654 | if (!splitFound) { |
655 | return std::vector<Node *>(); |
656 | } |
657 | } |
658 | |
659 | // If with the current split parameters only one slice is obtained then no |
660 | // splitting is required since the slice is the same as the original node. |
661 | if (splitOutputSlices.size() == 1) { |
662 | return std::vector<Node *>(); |
663 | } |
664 | |
665 | // -------------------------------- Split node ------------------------------- |
666 | // Get parent function. |
667 | Function *F = node->getParent(); |
668 | RETURN_ERR_IF_NOT(F, "Cannot split a node without a parent Function!" ); |
669 | |
670 | // Allocate output tensors used for merging the partial outputs. We only merge |
671 | // the partial outputs for the output operands which are effectively used. |
672 | std::vector<NodeValue> mergedOutputs(node->getNumResults()); |
673 | for (size_t outIdx = 0, outIdxEnd = node->getNumResults(); outIdx < outIdxEnd; |
674 | outIdx++) { |
675 | if (node->getNthResult(outIdx).getNumUsers()) { |
676 | auto nodeName = |
677 | node->getName().str() + ".TouchOutput" + std::to_string(outIdx); |
678 | mergedOutputs[outIdx] = F->createTouch(nodeName, node->getType(outIdx)); |
679 | } else { |
680 | mergedOutputs[outIdx] = nullptr; |
681 | } |
682 | } |
683 | |
684 | // Create split nodes. |
685 | std::vector<Node *> splitNodes(splitOutputSlices.size(), nullptr); |
686 | for (size_t sliceIdx = 0, sliceIdxEnd = splitOutputSlices.size(); |
687 | sliceIdx < sliceIdxEnd; sliceIdx++) { |
688 | |
689 | // Current split output slice. |
690 | const auto &splitOutputSlice = splitOutputSlices[sliceIdx]; |
691 | |
692 | // Create clone to inherit all the inputs/members of the original node. |
693 | Node *clone = node->clone(); |
694 | clone->setName(node->getName().str() + ".Split" + std::to_string(sliceIdx)); |
695 | |
696 | // Gather final input slice ranges for the clone. |
697 | std::vector<SliceRange> inputRanges(clone->getNumInputs()); |
698 | for (const auto &inputIdxMap : inputIdxAndMaps) { |
699 | auto inputCheckedRange = inputIdxMap.second(splitOutputSlice); |
700 | inputRanges[inputIdxMap.first] = inputCheckedRange.second; |
701 | } |
702 | |
703 | // Gather final output slice ranges for the clone. |
704 | std::vector<SliceRange> outputRanges(clone->getNumResults()); |
705 | outputRanges[splitOutputIdx] = splitOutputSlice; |
706 | for (const auto &outputIdxMap : outputIdxAndMaps) { |
707 | auto outputCheckedRange = outputIdxMap.second(splitOutputSlice); |
708 | outputRanges[outputIdxMap.first] = outputCheckedRange.second; |
709 | } |
710 | |
711 | // Create input Slice nodes (only if necessary). |
712 | for (const auto &inputIdxMap : inputIdxAndMaps) { |
713 | auto inputIdx = inputIdxMap.first; |
714 | auto inputNodeValue = node->getNthInput(inputIdx); |
715 | auto &inputSliceRange = inputRanges[inputIdxMap.first]; |
716 | TypeRef inpTy = inputNodeValue.getType(); |
717 | Type outTy = Type::newShape(*(inpTy), inputSliceRange.getSizes()); |
718 | if (outTy.isEqual(inpTy)) { |
719 | clone->setNthInput(inputIdx, inputNodeValue); |
720 | } else { |
721 | auto nodeName = |
722 | clone->getName().str() + ".SliceInput" + std::to_string(inputIdx); |
723 | auto *inputSlice = F->createSlice(nodeName, inputNodeValue, |
724 | inputSliceRange.getStarts(), &outTy); |
725 | clone->setNthInput(inputIdx, inputSlice); |
726 | } |
727 | } |
728 | |
729 | // Set clone split output type. The original node output type is not |
730 | // modified because the clone owns its output types. |
731 | TypeRef splitOutputType = F->getParent()->uniqueTypeWithNewShape( |
732 | node->getType(splitOutputIdx), splitOutputSlice.getSizes()); |
733 | clone->getNthResult(splitOutputIdx).setTypeUnsafe(splitOutputType); |
734 | |
735 | // Set clone output types. The original node output types are not |
736 | // modified because the clone owns its output types. |
737 | for (const auto &outputIdxMap : outputIdxAndMaps) { |
738 | auto &outputRange = outputRanges[outputIdxMap.first]; |
739 | TypeRef outputType = F->getParent()->uniqueTypeWithNewShape( |
740 | node->getType(outputIdxMap.first), outputRange.getSizes()); |
741 | clone->getNthResult(outputIdxMap.first).setTypeUnsafe(outputType); |
742 | } |
743 | |
744 | // Modify clone. |
745 | splitNodeModifier(node, clone, inputRanges, outputRanges); |
746 | |
747 | // Verify clone. |
748 | RETURN_ERR_IF_NOT(clone->verify(), |
749 | "Invalid node obtained during node splitting!" ); |
750 | |
751 | // Add clone to function. |
752 | F->addNode(clone); |
753 | |
754 | // Add clone to vector. |
755 | splitNodes[sliceIdx] = clone; |
756 | |
757 | // Merge the partial outputs of this clone (only if used). |
758 | for (size_t outIdx = 0, outIdxEnd = node->getNumResults(); |
759 | outIdx < outIdxEnd; outIdx++) { |
760 | if (mergedOutputs[outIdx]) { |
761 | auto nodeName = |
762 | clone->getName().str() + ".MergeOutput" + std::to_string(outIdx); |
763 | mergedOutputs[outIdx] = F->createInsertTensor( |
764 | nodeName, mergedOutputs[outIdx], clone->getNthResult(outIdx), |
765 | outputRanges[outIdx].getStarts()); |
766 | } |
767 | } |
768 | } |
769 | |
770 | // Replace all the node outputs with the merged outputs (only if used). |
771 | for (size_t outIdx = 0, outIdxEnd = node->getNumResults(); outIdx < outIdxEnd; |
772 | outIdx++) { |
773 | if (mergedOutputs[outIdx]) { |
774 | node->getNthResult(outIdx).replaceAllUsesOfWith(mergedOutputs[outIdx]); |
775 | } |
776 | } |
777 | |
778 | // Erase original node. |
779 | F->eraseNode(node); |
780 | |
781 | return splitNodes; |
782 | } |
783 | |
784 | namespace { |
785 | ///===---------------------------------------------------------------------===// |
786 | /// Conv |
787 | ///===---------------------------------------------------------------------===// |
788 | /// Structure which contains a valid flag \p check, a dimension range \p range |
789 | /// and a dimension padding \p pads. |
790 | struct CheckedRangeAndPads { |
791 | bool check{false}; |
792 | DimRange range; |
793 | DimPads pads; |
794 | }; |
795 | |
796 | /// Structure which contains a valid flag \p check a dimension range \p range. |
797 | struct CheckedRange { |
798 | bool check{false}; |
799 | DimRange range; |
800 | }; |
801 | } // namespace |
802 | |
803 | static CheckedRangeAndPads |
804 | getConvInputCheckedRangeAndPads(const DimRange &outputSliceRange, |
805 | const DimRange &inputRange, dim_t kernel, |
806 | dim_t stride, DimPads pads, dim_t dilation) { |
807 | |
808 | CHECK_LT(outputSliceRange.first, outputSliceRange.second) |
809 | << "Invalid output slice range!" ; |
810 | CHECK_LT(inputRange.first, inputRange.second) << "Invalid input range!" ; |
811 | CHECK_EQ(inputRange.first, 0) << "Input range must start with 0!" ; |
812 | CHECK_GE(kernel, 1) << "Invalid kernel size!" ; |
813 | CHECK_GE(stride, 1) << "Invalid stride size!" ; |
814 | CHECK_GE(dilation, 1) << "Invalid dilation size!" ; |
815 | |
816 | // Get padded input range. |
817 | dim_t inputStartPadded = inputRange.first + pads.first; |
818 | dim_t inputStopPadded = inputRange.second + pads.first; |
819 | |
820 | // Get padded input slice range. |
821 | dim_t inputSliceStartPadded = outputSliceRange.first * stride; |
822 | dim_t inputSliceStopPadded = |
823 | (outputSliceRange.second - 1) * stride + dilation * (kernel - 1) + 1; |
824 | |
825 | // Verify input slice range bounds. |
826 | dim_t inputSliceStopPaddedMax = pads.first + inputRange.second + pads.second; |
827 | CHECK_LE(inputSliceStopPadded, inputSliceStopPaddedMax) |
828 | << "Input slice range out of bounds!" ; |
829 | |
830 | // Get intersection. |
831 | dim_t intersectStartPadded = |
832 | std::max(inputStartPadded, inputSliceStartPadded); |
833 | dim_t intersectStopPadded = std::min(inputStopPadded, inputSliceStopPadded); |
834 | |
835 | // Get checked input range. |
836 | bool allowed = (intersectStartPadded < intersectStopPadded); |
837 | dim_t inputSliceStart = intersectStartPadded - pads.first; |
838 | dim_t inputSliceStop = |
839 | intersectStopPadded >= pads.first ? intersectStopPadded - pads.first : 0; |
840 | |
841 | // Get start pad. |
842 | dim_t inputSliceStartPad = 0; |
843 | if (inputSliceStartPadded < inputStartPadded) { |
844 | inputSliceStartPad = inputStartPadded - inputSliceStartPadded; |
845 | } |
846 | |
847 | // Get stop pad. |
848 | dim_t inputSliceStopPad = 0; |
849 | if (inputSliceStopPadded > inputStopPadded) { |
850 | inputSliceStopPad = inputSliceStopPadded - inputStopPadded; |
851 | } |
852 | |
853 | DimRange inputSliceRange = {inputSliceStart, inputSliceStop}; |
854 | DimPads inputSlicePads = {inputSliceStartPad, inputSliceStopPad}; |
855 | return CheckedRangeAndPads{allowed, inputSliceRange, inputSlicePads}; |
856 | } |
857 | |
858 | static CheckedRange |
859 | getConvInputChannelCheckedRange(const DimRange &outputSliceRange, |
860 | const DimRange &inputRange, dim_t inputChannels, |
861 | dim_t outputChannels, dim_t group) { |
862 | |
863 | CHECK_EQ(inputChannels % group, 0) |
864 | << "Input channels must be divisible by group!" ; |
865 | CHECK_EQ(outputChannels % group, 0) |
866 | << "Output channels must be divisible by group!" ; |
867 | |
868 | dim_t inputChannelsPerGroup = inputChannels / group; |
869 | dim_t outputChannelsPerGroup = outputChannels / group; |
870 | dim_t outputSliceChannels = SliceRange({outputSliceRange}).getDimSize(0); |
871 | |
872 | // Output slice range start/stop group index (inclusive). |
873 | dim_t outputSliceRangeStartGroupIdx = |
874 | outputSliceRange.first / outputChannelsPerGroup; |
875 | dim_t outputSliceRangeStopGroupIdx = |
876 | (outputSliceRange.second - 1) / outputChannelsPerGroup; |
877 | |
878 | bool allowed = false; |
879 | if (outputSliceChannels <= outputChannelsPerGroup) { |
880 | // If the output slice range spans fully or partially one group then both |
881 | // ends of the range must be part of the same group. |
882 | allowed = (outputSliceRangeStartGroupIdx == outputSliceRangeStopGroupIdx); |
883 | } else { |
884 | // If the output slice range spans multiple groups then both ends of the |
885 | // range must be aligned to outputChannelsPerGroup. |
886 | allowed = SliceRange({outputSliceRange}) |
887 | .isDimRangeAligned(0, outputChannelsPerGroup); |
888 | } |
889 | |
890 | // Compute input slice range as a multiple of groups. |
891 | DimRange inputSliceRange; |
892 | inputSliceRange.first = outputSliceRangeStartGroupIdx * inputChannelsPerGroup; |
893 | inputSliceRange.second = |
894 | (outputSliceRangeStopGroupIdx + 1) * inputChannelsPerGroup; |
895 | return CheckedRange{allowed, inputSliceRange}; |
896 | } |
897 | |
898 | ///===---------------------------------------------------------------------===// |
899 | /// Conv2D |
900 | ///===---------------------------------------------------------------------===// |
901 | template <typename ConvNodeTy, typename Shape> |
902 | static std::vector<OpIdxAndMap> |
903 | getConv2DInputIdxAndMaps(const ConvNodeTy *node) { |
904 | |
905 | ShapeHW kernels = ShapeHW(node->getKernels()); |
906 | ShapeHW strides = ShapeHW(node->getStrides()); |
907 | PaddingTLBR pads(node->getPads()); |
908 | unsigned_t group = node->getGroup(); |
909 | ShapeHW dilations = ShapeHW(node->getDilation()); |
910 | DimPads padsTB = {pads.top, pads.bottom}; |
911 | DimPads padsLR = {pads.left, pads.right}; |
912 | |
913 | SliceRange inputRange = SliceRange(node->getInput().getType()); |
914 | SliceRange filterRange = SliceRange(node->getFilter().getType()); |
915 | SliceRange outputRange = SliceRange(node->getResult().getType()); |
916 | |
917 | // Output slice to input slice range map. |
918 | CheckedSliceRangeMap inputSliceRangeMap = |
919 | [=](const SliceRange &outputSliceRange) -> CheckedSliceRange { |
920 | // Get range and pads for dimension H. |
921 | auto checkedRangeAndPadsH = getConvInputCheckedRangeAndPads( |
922 | outputSliceRange[Shape::DimH], inputRange[Shape::DimH], kernels.height, |
923 | strides.height, padsTB, dilations.height); |
924 | |
925 | // Get range and pads for dimension W. |
926 | auto checkedRangeAndPadsW = getConvInputCheckedRangeAndPads( |
927 | outputSliceRange[Shape::DimW], inputRange[Shape::DimW], kernels.width, |
928 | strides.width, padsLR, dilations.width); |
929 | |
930 | // Get range for dimension C. |
931 | auto checkedRangeC = getConvInputChannelCheckedRange( |
932 | outputSliceRange[Shape::DimC], inputRange[Shape::DimC], |
933 | inputRange.getDimSize(Shape::DimC), outputRange.getDimSize(Shape::DimC), |
934 | group); |
935 | |
936 | std::vector<DimRange> inputDimRanges(4); |
937 | inputDimRanges[Shape::DimN] = outputSliceRange[Shape::DimN]; |
938 | inputDimRanges[Shape::DimH] = checkedRangeAndPadsH.range; |
939 | inputDimRanges[Shape::DimW] = checkedRangeAndPadsW.range; |
940 | inputDimRanges[Shape::DimC] = checkedRangeC.range; |
941 | bool allowed = checkedRangeAndPadsH.check && checkedRangeAndPadsW.check && |
942 | checkedRangeC.check; |
943 | return {allowed, SliceRange(inputDimRanges)}; |
944 | }; |
945 | |
946 | // Output slice to filter slice range map. |
947 | CheckedSliceRangeMap filterSliceRangeMap = |
948 | [=](const SliceRange &outputSliceRange) -> CheckedSliceRange { |
949 | std::vector<DimRange> filterDimRanges(4); |
950 | filterDimRanges[Shape::DimN] = outputSliceRange[Shape::DimC]; |
951 | filterDimRanges[Shape::DimH] = filterRange[Shape::DimH]; |
952 | filterDimRanges[Shape::DimW] = filterRange[Shape::DimW]; |
953 | filterDimRanges[Shape::DimC] = filterRange[Shape::DimC]; |
954 | return {true, SliceRange(filterDimRanges)}; |
955 | }; |
956 | |
957 | // Output slice to bias slice range map. |
958 | CheckedSliceRangeMap biasSliceRangeMap = |
959 | [=](const SliceRange &outputSliceRange) -> CheckedSliceRange { |
960 | return {true, SliceRange({outputSliceRange[Shape::DimC]})}; |
961 | }; |
962 | |
963 | // Return input indices and maps. |
964 | if (std::is_same<ConvNodeTy, ConvolutionNode>::value) { |
965 | return {{ConvolutionNode::InputIdx, inputSliceRangeMap}, |
966 | {ConvolutionNode::FilterIdx, filterSliceRangeMap}, |
967 | {ConvolutionNode::BiasIdx, biasSliceRangeMap}}; |
968 | } else if (std::is_same<ConvNodeTy, |
969 | ChannelwiseQuantizedConvolutionNode>::value) { |
970 | return { |
971 | {ChannelwiseQuantizedConvolutionNode::InputIdx, inputSliceRangeMap}, |
972 | {ChannelwiseQuantizedConvolutionNode::FilterIdx, filterSliceRangeMap}, |
973 | {ChannelwiseQuantizedConvolutionNode::BiasIdx, biasSliceRangeMap}, |
974 | {ChannelwiseQuantizedConvolutionNode::FilterScalesIdx, |
975 | biasSliceRangeMap}, |
976 | {ChannelwiseQuantizedConvolutionNode::FilterOffsetsIdx, |
977 | biasSliceRangeMap}, |
978 | {ChannelwiseQuantizedConvolutionNode::BiasScalesIdx, biasSliceRangeMap}, |
979 | {ChannelwiseQuantizedConvolutionNode::BiasOffsetsIdx, |
980 | biasSliceRangeMap}}; |
981 | } |
982 | llvm_unreachable("Invalid Convolution node type!" ); |
983 | } |
984 | |
985 | template <typename ConvNodeTy, typename Shape> |
986 | void Conv2DSplitNodeModifier(const Node *origNode, Node *splitNode, |
987 | const std::vector<SliceRange> &inputSliceRanges, |
988 | const std::vector<SliceRange> &outputSliceRanges) { |
989 | auto *convOrigNode = dyn_cast<ConvNodeTy>(origNode); |
990 | auto *convSplitNode = dyn_cast<ConvNodeTy>(splitNode); |
991 | if (!(convOrigNode && convSplitNode)) { |
992 | return; |
993 | } |
994 | |
995 | ShapeHW kernels = ShapeHW(convOrigNode->getKernels()); |
996 | ShapeHW strides = ShapeHW(convOrigNode->getStrides()); |
997 | PaddingTLBR pads(convOrigNode->getPads()); |
998 | ShapeHW dilations = ShapeHW(convOrigNode->getDilation()); |
999 | DimPads padsTB = {pads.top, pads.bottom}; |
1000 | DimPads padsLR = {pads.left, pads.right}; |
1001 | |
1002 | // Get paddings for split node. |
1003 | auto outputSliceRange = outputSliceRanges[ConvNodeTy::ResultIdx]; |
1004 | auto inputRange = SliceRange(convOrigNode->getInput().getType()); |
1005 | auto checkedRangeAndPadsH = getConvInputCheckedRangeAndPads( |
1006 | outputSliceRange[Shape::DimH], inputRange[Shape::DimH], kernels.height, |
1007 | strides.height, padsTB, dilations.height); |
1008 | auto checkedRangeAndPadsW = getConvInputCheckedRangeAndPads( |
1009 | outputSliceRange[Shape::DimW], inputRange[Shape::DimW], kernels.width, |
1010 | strides.width, padsLR, dilations.width); |
1011 | DimPads splitPadsTB = checkedRangeAndPadsH.pads; |
1012 | DimPads splitPadsLR = checkedRangeAndPadsW.pads; |
1013 | |
1014 | // Modify paddings for split node. |
1015 | std::vector<unsigned_t> splitPads = { |
1016 | static_cast<unsigned_t>(splitPadsTB.first), |
1017 | static_cast<unsigned_t>(splitPadsLR.first), |
1018 | static_cast<unsigned_t>(splitPadsTB.second), |
1019 | static_cast<unsigned_t>(splitPadsLR.second)}; |
1020 | convSplitNode->setPads(splitPads); |
1021 | |
1022 | // Modify group for split node. |
1023 | dim_t outputChannels = |
1024 | SliceRange(convOrigNode->getType(ConvNodeTy::ResultIdx)) |
1025 | .getDimSize(Shape::DimC); |
1026 | dim_t outputSliceChannels = |
1027 | SliceRange(convSplitNode->getType(ConvNodeTy::ResultIdx)) |
1028 | .getDimSize(Shape::DimC); |
1029 | auto group = convOrigNode->getGroup(); |
1030 | |
1031 | CHECK_EQ(outputChannels % group, 0) |
1032 | << "Output channels must be divisible by group!" ; |
1033 | dim_t outputChannelsPerGroup = outputChannels / group; |
1034 | |
1035 | if (outputSliceChannels <= outputChannelsPerGroup) { |
1036 | // If the output slice range spans fully or partially one group then we |
1037 | // set the group to 1. |
1038 | convSplitNode->setGroup(1); |
1039 | } else { |
1040 | // If the output slice range spans more than a group then it must span a |
1041 | // multiple of outputChannelsPerGroup. |
1042 | CHECK_EQ(outputSliceChannels % outputChannelsPerGroup, 0) |
1043 | << "Output slice channels must be divisible by the output channels per " |
1044 | "group!" ; |
1045 | dim_t splitGroup = outputSliceChannels / outputChannelsPerGroup; |
1046 | convSplitNode->setGroup(static_cast<unsigned_t>(splitGroup)); |
1047 | } |
1048 | } |
1049 | |
1050 | ///===---------------------------------------------------------------------===// |
1051 | /// Pool |
1052 | ///===---------------------------------------------------------------------===// |
1053 | static CheckedRangeAndPads |
1054 | getPoolInputCheckedRangeAndPads(const DimRange &outputSliceRange, |
1055 | const DimRange &inputRange, dim_t kernel, |
1056 | dim_t stride, DimPads pads) { |
1057 | return getConvInputCheckedRangeAndPads(outputSliceRange, inputRange, kernel, |
1058 | stride, pads, /* dilation */ 1); |
1059 | } |
1060 | |
1061 | template <class PoolNode, typename Shape> |
1062 | static std::vector<OpIdxAndMap> getPoolInputIdxAndMaps(const PoolNode *node) { |
1063 | |
1064 | ShapeHW kernels = ShapeHW(node->getKernels()); |
1065 | ShapeHW strides = ShapeHW(node->getStrides()); |
1066 | PaddingTLBR pads(node->getPads()); |
1067 | DimPads padsTB = {pads.top, pads.bottom}; |
1068 | DimPads padsLR = {pads.left, pads.right}; |
1069 | |
1070 | // Output slice to input slice range map. |
1071 | SliceRange inputRange = SliceRange(node->getInput().getType()); |
1072 | CheckedSliceRangeMap inputSliceRangeMap = |
1073 | [=](const SliceRange &outputSliceRange) -> CheckedSliceRange { |
1074 | // Get range and pads for dimension H. |
1075 | auto checkedRangeAndPadsH = getPoolInputCheckedRangeAndPads( |
1076 | outputSliceRange[Shape::DimH], inputRange[Shape::DimH], kernels.height, |
1077 | strides.height, padsTB); |
1078 | |
1079 | // Get range and pads for dimension W. |
1080 | auto checkedRangeAndPadsW = getPoolInputCheckedRangeAndPads( |
1081 | outputSliceRange[Shape::DimW], inputRange[Shape::DimW], kernels.width, |
1082 | strides.width, padsLR); |
1083 | |
1084 | std::vector<DimRange> inputDimRanges(4); |
1085 | inputDimRanges[Shape::DimN] = outputSliceRange[Shape::DimN]; |
1086 | inputDimRanges[Shape::DimH] = checkedRangeAndPadsH.range; |
1087 | inputDimRanges[Shape::DimW] = checkedRangeAndPadsW.range; |
1088 | inputDimRanges[Shape::DimC] = outputSliceRange[Shape::DimC]; |
1089 | bool allowed = checkedRangeAndPadsH.check && checkedRangeAndPadsW.check; |
1090 | return {allowed, SliceRange(inputDimRanges)}; |
1091 | }; |
1092 | |
1093 | // Return input index and map. |
1094 | return {{PoolNode::InputIdx, inputSliceRangeMap}}; |
1095 | } |
1096 | |
1097 | template <class PoolNode, typename Shape> |
1098 | void PoolSplitNodeModifier(const Node *origNode, Node *splitNode, |
1099 | const std::vector<SliceRange> &inputSliceRanges, |
1100 | const std::vector<SliceRange> &outputSliceRanges) { |
1101 | auto *poolOrigNode = dyn_cast<PoolNode>(origNode); |
1102 | auto *poolSplitNode = dyn_cast<PoolNode>(splitNode); |
1103 | if (!(poolOrigNode && poolSplitNode)) { |
1104 | return; |
1105 | } |
1106 | |
1107 | ShapeHW kernels = ShapeHW(poolOrigNode->getKernels()); |
1108 | ShapeHW strides = ShapeHW(poolOrigNode->getStrides()); |
1109 | PaddingTLBR pads(poolOrigNode->getPads()); |
1110 | DimPads padsTB = {pads.top, pads.bottom}; |
1111 | DimPads padsLR = {pads.left, pads.right}; |
1112 | |
1113 | // Get paddings for split node. |
1114 | auto outputSliceRange = outputSliceRanges[PoolNode::ResultIdx]; |
1115 | auto inputRange = SliceRange(poolOrigNode->getInput().getType()); |
1116 | auto checkedRangeAndPadsH = getPoolInputCheckedRangeAndPads( |
1117 | outputSliceRange[Shape::DimH], inputRange[Shape::DimH], kernels.height, |
1118 | strides.height, padsTB); |
1119 | auto checkedRangeAndPadsW = getPoolInputCheckedRangeAndPads( |
1120 | outputSliceRange[Shape::DimW], inputRange[Shape::DimW], kernels.width, |
1121 | strides.width, padsLR); |
1122 | DimPads splitPadsTB = checkedRangeAndPadsH.pads; |
1123 | DimPads splitPadsLR = checkedRangeAndPadsW.pads; |
1124 | |
1125 | // Modify paddings for split node. |
1126 | std::vector<unsigned_t> splitPads = { |
1127 | static_cast<unsigned_t>(splitPadsTB.first), |
1128 | static_cast<unsigned_t>(splitPadsLR.first), |
1129 | static_cast<unsigned_t>(splitPadsTB.second), |
1130 | static_cast<unsigned_t>(splitPadsLR.second)}; |
1131 | poolSplitNode->setPads(splitPads); |
1132 | } |
1133 | |
1134 | ///===---------------------------------------------------------------------===// |
1135 | /// FullyConnected |
1136 | ///===---------------------------------------------------------------------===// |
1137 | static std::vector<OpIdxAndMap> |
1138 | getFullyConnectedInputIdxAndMaps(const FullyConnectedNode *node) { |
1139 | // Output slice to input slice range map. |
1140 | CheckedSliceRangeMap inputSliceRangeMap = |
1141 | [=](const SliceRange &outputSliceRange) -> CheckedSliceRange { |
1142 | SliceRange inputRange = SliceRange(node->getInput().getType()); |
1143 | inputRange[ShapeHW::DimH] = outputSliceRange[ShapeHW::DimH]; |
1144 | return {true, inputRange}; |
1145 | }; |
1146 | |
1147 | // Output slice to weights slice range map. |
1148 | CheckedSliceRangeMap weightsSliceRangeMap = |
1149 | [=](const SliceRange &outputSliceRange) -> CheckedSliceRange { |
1150 | SliceRange weightsRange = SliceRange(node->getWeights().getType()); |
1151 | weightsRange[ShapeHW::DimW] = outputSliceRange[ShapeHW::DimW]; |
1152 | return {true, weightsRange}; |
1153 | }; |
1154 | |
1155 | // Output slice to bias slice range map. |
1156 | CheckedSliceRangeMap biasSliceRangeMap = |
1157 | [=](const SliceRange &outputSliceRange) -> CheckedSliceRange { |
1158 | return {true, SliceRange({outputSliceRange[ShapeHW::DimW]})}; |
1159 | }; |
1160 | |
1161 | // Return input index and map. |
1162 | return {{FullyConnectedNode::InputIdx, inputSliceRangeMap}, |
1163 | {FullyConnectedNode::WeightsIdx, weightsSliceRangeMap}, |
1164 | {FullyConnectedNode::BiasIdx, biasSliceRangeMap}}; |
1165 | } |
1166 | |
1167 | ///===---------------------------------------------------------------------===// |
1168 | /// MatMul |
1169 | ///===---------------------------------------------------------------------===// |
1170 | static std::vector<OpIdxAndMap> |
1171 | getMatMulInputIdxAndMaps(const MatMulNode *node) { |
1172 | // Output slice to LHS slice range map. |
1173 | CheckedSliceRangeMap lhsSliceRangeMap = |
1174 | [=](const SliceRange &outputSliceRange) -> CheckedSliceRange { |
1175 | SliceRange lhsRange = SliceRange(node->getLHS().getType()); |
1176 | lhsRange[ShapeHW::DimH] = outputSliceRange[ShapeHW::DimH]; |
1177 | return {true, lhsRange}; |
1178 | }; |
1179 | |
1180 | // Output slice to RHS slice range map. |
1181 | CheckedSliceRangeMap rhsSliceRangeMap = |
1182 | [=](const SliceRange &outputSliceRange) -> CheckedSliceRange { |
1183 | SliceRange rhsRange = SliceRange(node->getRHS().getType()); |
1184 | rhsRange[ShapeHW::DimW] = outputSliceRange[ShapeHW::DimW]; |
1185 | return {true, rhsRange}; |
1186 | }; |
1187 | |
1188 | // Return input index and map. |
1189 | return {{MatMulNode::LHSIdx, lhsSliceRangeMap}, |
1190 | {MatMulNode::RHSIdx, rhsSliceRangeMap}}; |
1191 | } |
1192 | |
1193 | ///===---------------------------------------------------------------------===// |
1194 | /// BatchMatMul |
1195 | ///===---------------------------------------------------------------------===// |
1196 | static std::vector<OpIdxAndMap> |
1197 | getBatchMatMulInputIdxAndMaps(const BatchMatMulNode *node) { |
1198 | // Output slice to LHS slice range map. |
1199 | CheckedSliceRangeMap lhsSliceRangeMap = |
1200 | [=](const SliceRange &outputSliceRange) -> CheckedSliceRange { |
1201 | SliceRange lhsRange = SliceRange(node->getLHS().getType()); |
1202 | lhsRange[ShapeNHW::DimN] = outputSliceRange[ShapeNHW::DimN]; |
1203 | lhsRange[ShapeNHW::DimH] = outputSliceRange[ShapeNHW::DimH]; |
1204 | return {true, lhsRange}; |
1205 | }; |
1206 | |
1207 | // Output slice to RHS slice range map. |
1208 | CheckedSliceRangeMap rhsSliceRangeMap = |
1209 | [=](const SliceRange &outputSliceRange) -> CheckedSliceRange { |
1210 | SliceRange rhsRange = SliceRange(node->getRHS().getType()); |
1211 | rhsRange[ShapeNHW::DimN] = outputSliceRange[ShapeNHW::DimN]; |
1212 | rhsRange[ShapeNHW::DimW] = outputSliceRange[ShapeNHW::DimW]; |
1213 | return {true, rhsRange}; |
1214 | }; |
1215 | |
1216 | // Return input index and map. |
1217 | return {{BatchMatMulNode::LHSIdx, lhsSliceRangeMap}, |
1218 | {BatchMatMulNode::RHSIdx, rhsSliceRangeMap}}; |
1219 | } |
1220 | |
1221 | ///===---------------------------------------------------------------------===// |
1222 | /// BatchedAdd |
1223 | ///===---------------------------------------------------------------------===// |
1224 | static std::vector<OpIdxAndMap> |
1225 | getBatchedAddInputIdxAndMaps(const BatchedAddNode *node) { |
1226 | |
1227 | // Output slice to Batch slice range map. |
1228 | CheckedSliceRangeMap batchSliceRangeMap = |
1229 | [=](const SliceRange &outputSliceRange) -> CheckedSliceRange { |
1230 | return {true, outputSliceRange}; |
1231 | }; |
1232 | |
1233 | // Output slice to Slice slice range map. |
1234 | CheckedSliceRangeMap sliceSliceRangeMap = |
1235 | [=](const SliceRange &outputSliceRange) -> CheckedSliceRange { |
1236 | size_t numOutDims = outputSliceRange.getNumDims(); |
1237 | return {true, outputSliceRange.extractRanges(1, numOutDims - 1)}; |
1238 | }; |
1239 | |
1240 | // Return input index and map. |
1241 | return {{BatchedAddNode::BatchIdx, batchSliceRangeMap}, |
1242 | {BatchedAddNode::SliceIdx, sliceSliceRangeMap}}; |
1243 | } |
1244 | |
1245 | ///===---------------------------------------------------------------------===// |
1246 | /// Transpose |
1247 | ///===---------------------------------------------------------------------===// |
1248 | static std::vector<OpIdxAndMap> |
1249 | getTransposeInputIdxAndMaps(const TransposeNode *node) { |
1250 | |
1251 | // Transpose shuffle. |
1252 | std::vector<unsigned_t> nodeShuffle = node->getShuffle(); |
1253 | std::vector<size_t> shuffle(nodeShuffle.size()); |
1254 | for (size_t idx = 0, e = nodeShuffle.size(); idx < e; ++idx) { |
1255 | shuffle[idx] = static_cast<size_t>(nodeShuffle[idx]); |
1256 | } |
1257 | |
1258 | // Output slice to Input slice range map. |
1259 | CheckedSliceRangeMap inputSliceRangeMap = |
1260 | [=](const SliceRange &outputSliceRange) -> CheckedSliceRange { |
1261 | return {true, outputSliceRange.shuffleRanges(shuffle, /*invert*/ true)}; |
1262 | }; |
1263 | |
1264 | // Return input index and map. |
1265 | return {{TransposeNode::InputIdx, inputSliceRangeMap}}; |
1266 | } |
1267 | |
1268 | ///===---------------------------------------------------------------------===// |
1269 | /// splitNode |
1270 | ///===---------------------------------------------------------------------===// |
1271 | Expected<std::vector<Node *>> |
1272 | glow::splitNode(Node *node, const SplitNodeOption *splitOption, |
1273 | const SplitNodeConstraint *splitConstraint) { |
1274 | |
1275 | // We can do the splitting if at least the option or the constraint is given. |
1276 | RETURN_ERR_IF_NOT( |
1277 | splitOption || splitConstraint, |
1278 | "At least the split option or the split constraint must be given!" ); |
1279 | |
1280 | switch (node->getKind()) { |
1281 | |
1282 | case Kinded::Kind::ConvolutionNodeKind: { |
1283 | return splitAndReplaceNode( |
1284 | node, splitOption, splitConstraint, ConvolutionNode::ResultIdx, |
1285 | getConv2DInputIdxAndMaps<ConvolutionNode, ShapeNHWC>( |
1286 | dyn_cast<ConvolutionNode>(node)), |
1287 | {}, Conv2DSplitNodeModifier<ConvolutionNode, ShapeNHWC>); |
1288 | } |
1289 | |
1290 | case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind: { |
1291 | return splitAndReplaceNode( |
1292 | node, splitOption, splitConstraint, |
1293 | ChannelwiseQuantizedConvolutionNode::ResultIdx, |
1294 | getConv2DInputIdxAndMaps<ChannelwiseQuantizedConvolutionNode, |
1295 | ShapeNHWC>( |
1296 | dyn_cast<ChannelwiseQuantizedConvolutionNode>(node)), |
1297 | {}, |
1298 | Conv2DSplitNodeModifier<ChannelwiseQuantizedConvolutionNode, |
1299 | ShapeNHWC>); |
1300 | } |
1301 | |
1302 | case Kinded::Kind::MaxPoolNodeKind: { |
1303 | // The current definition of the MaxPool node does not allow splitting |
1304 | // of the second output operand 'Argmax' which contains flattened |
1305 | // indices whose values will be altered if processed in smaller chunks. |
1306 | // We allow splitting only if the 'Argmax' node value has no users. |
1307 | if (node->getNthResult(MaxPoolNode::ArgmaxIdx).getNumUsers() != 0) { |
1308 | break; |
1309 | } |
1310 | return splitAndReplaceNode( |
1311 | node, splitOption, splitConstraint, MaxPoolNode::ResultIdx, |
1312 | getPoolInputIdxAndMaps<MaxPoolNode, ShapeNHWC>( |
1313 | dyn_cast<MaxPoolNode>(node)), |
1314 | {{MaxPoolNode::ArgmaxIdx, CheckedSliceRangeMapIdentity}}, |
1315 | PoolSplitNodeModifier<MaxPoolNode, ShapeNHWC>); |
1316 | } |
1317 | |
1318 | case Kinded::Kind::AvgPoolNodeKind: { |
1319 | return splitAndReplaceNode( |
1320 | node, splitOption, splitConstraint, AvgPoolNode::ResultIdx, |
1321 | getPoolInputIdxAndMaps<AvgPoolNode, ShapeNHWC>( |
1322 | dyn_cast<AvgPoolNode>(node)), |
1323 | {}, PoolSplitNodeModifier<AvgPoolNode, ShapeNHWC>); |
1324 | } |
1325 | |
1326 | case Kinded::Kind::FullyConnectedNodeKind: { |
1327 | return splitAndReplaceNode( |
1328 | node, splitOption, splitConstraint, FullyConnectedNode::ResultIdx, |
1329 | getFullyConnectedInputIdxAndMaps(dyn_cast<FullyConnectedNode>(node))); |
1330 | } |
1331 | |
1332 | case Kinded::Kind::MatMulNodeKind: { |
1333 | return splitAndReplaceNode( |
1334 | node, splitOption, splitConstraint, MatMulNode::ResultIdx, |
1335 | getMatMulInputIdxAndMaps(dyn_cast<MatMulNode>(node))); |
1336 | } |
1337 | |
1338 | case Kinded::Kind::BatchMatMulNodeKind: { |
1339 | return splitAndReplaceNode( |
1340 | node, splitOption, splitConstraint, BatchMatMulNode::ResultIdx, |
1341 | getBatchMatMulInputIdxAndMaps(dyn_cast<BatchMatMulNode>(node))); |
1342 | } |
1343 | |
1344 | case Kinded::Kind::BatchedAddNodeKind: { |
1345 | return splitAndReplaceNode( |
1346 | node, splitOption, splitConstraint, BatchedAddNode::ResultIdx, |
1347 | getBatchedAddInputIdxAndMaps(dyn_cast<BatchedAddNode>(node))); |
1348 | } |
1349 | |
1350 | case Kinded::Kind::TransposeNodeKind: { |
1351 | return splitAndReplaceNode( |
1352 | node, splitOption, splitConstraint, TransposeNode::ResultIdx, |
1353 | getTransposeInputIdxAndMaps(dyn_cast<TransposeNode>(node))); |
1354 | } |
1355 | |
1356 | case Kinded::Kind::AddNodeKind: |
1357 | case Kinded::Kind::MulNodeKind: |
1358 | case Kinded::Kind::SubNodeKind: |
1359 | case Kinded::Kind::DivNodeKind: |
1360 | case Kinded::Kind::FmodNodeKind: |
1361 | case Kinded::Kind::MaxNodeKind: |
1362 | case Kinded::Kind::MinNodeKind: |
1363 | case Kinded::Kind::CmpLTENodeKind: |
1364 | case Kinded::Kind::CmpLTNodeKind: |
1365 | case Kinded::Kind::CmpEQNodeKind: |
1366 | case Kinded::Kind::PowNodeKind: { |
1367 | DCHECK_EQ(node->getNumInputs(), 2) << "Binary operator invalid!" ; |
1368 | DCHECK_EQ(node->getNumResults(), 1) << "Binary operator invalid!" ; |
1369 | return splitAndReplaceNode( |
1370 | node, splitOption, splitConstraint, ArithmeticNode::ResultIdx, |
1371 | {{ArithmeticNode::LHSIdx, CheckedSliceRangeMapIdentity}, |
1372 | {ArithmeticNode::RHSIdx, CheckedSliceRangeMapIdentity}}); |
1373 | } |
1374 | |
1375 | case Kinded::Kind::ReluNodeKind: |
1376 | case Kinded::Kind::LeakyReluNodeKind: |
1377 | case Kinded::Kind::ClipNodeKind: |
1378 | case Kinded::Kind::TanhNodeKind: |
1379 | case Kinded::Kind::SigmoidNodeKind: |
1380 | case Kinded::Kind::LogNodeKind: |
1381 | case Kinded::Kind::ExpNodeKind: |
1382 | case Kinded::Kind::QuantizeNodeKind: |
1383 | case Kinded::Kind::RescaleQuantizedNodeKind: |
1384 | case Kinded::Kind::DequantizeNodeKind: |
1385 | case Kinded::Kind::ConvertToNodeKind: { |
1386 | DCHECK_EQ(node->getNumInputs(), 1) << "Unary operator invalid!" ; |
1387 | DCHECK_EQ(node->getNumResults(), 1) << "Unary operator invalid!" ; |
1388 | return splitAndReplaceNode(node, splitOption, splitConstraint, |
1389 | /*splitOutputIdx*/ 0, |
1390 | {{0, CheckedSliceRangeMapIdentity}}); |
1391 | } |
1392 | |
1393 | default: |
1394 | VLOG(1) << "Splitting node type '" << node->getKindName() |
1395 | << "' is not supported!\n" ; |
1396 | break; |
1397 | } |
1398 | |
1399 | return std::vector<Node *>(); |
1400 | } |
1401 | |
1402 | Expected<std::vector<Node *>> |
1403 | glow::splitNode(Node *node, const SplitNodeOption &splitOption) { |
1404 | return splitNode(node, &splitOption, nullptr); |
1405 | } |
1406 | |
1407 | Expected<std::vector<Node *>> |
1408 | glow::splitNode(Node *node, const SplitNodeConstraint &splitConstraint) { |
1409 | return splitNode(node, nullptr, &splitConstraint); |
1410 | } |
1411 | |
1412 | ///===---------------------------------------------------------------------===// |
1413 | /// splitNodes |
1414 | ///===---------------------------------------------------------------------===// |
1415 | Expected<SplitNodeMap> |
1416 | glow::splitNodes(Function *F, const SplitNodeOptionMap &splitOptionMap, |
1417 | const SplitNodeConstraintMap &splitConstraintMap) { |
1418 | // Create split map. |
1419 | SplitNodeMap splitMap; |
1420 | |
1421 | // Since we will be transforming the original list of nodes, reverse iterate. |
1422 | auto &nodes = F->getNodes(); |
1423 | for (auto it = nodes.rbegin(), e = nodes.rend(); it != e;) { |
1424 | Node *node = &*(it++); |
1425 | |
1426 | // Find explicit split option for current node (if any). |
1427 | const SplitNodeOption *splitOption = nullptr; |
1428 | auto splitOptionIt = splitOptionMap.find(node); |
1429 | if (splitOptionIt != splitOptionMap.end()) { |
1430 | splitOption = splitOptionIt->second; |
1431 | } |
1432 | |
1433 | // Find explicit split constraint for current node (if any). |
1434 | const SplitNodeConstraint *splitConstraint = nullptr; |
1435 | auto splitConstraintIt = splitConstraintMap.find(node); |
1436 | if (splitConstraintIt != splitConstraintMap.end()) { |
1437 | splitConstraint = splitConstraintIt->second; |
1438 | } |
1439 | |
1440 | // Split current node if at least the option or the constraint is given. |
1441 | if (splitOption || splitConstraint) { |
1442 | ASSIGN_VALUE_OR_RETURN_ERR(splitMap[node], |
1443 | splitNode(node, splitOption, splitConstraint)); |
1444 | } |
1445 | } |
1446 | |
1447 | // Verify function after splitting nodes. |
1448 | RETURN_ERR_IF_NOT(F->verify(), "Function is not valid after node splitting!" ); |
1449 | return splitMap; |
1450 | } |
1451 | |
1452 | Expected<SplitNodeMap> glow::splitNodes(Function *F, |
1453 | const SplitNodeOption &splitOption) { |
1454 | // Since we will be transforming the original list of nodes, reverse iterate. |
1455 | SplitNodeMap splitMap; |
1456 | auto &nodes = F->getNodes(); |
1457 | for (auto it = nodes.rbegin(), e = nodes.rend(); it != e;) { |
1458 | Node *node = &*(it++); |
1459 | const SplitNodeConstraint *splitConstraint = nullptr; |
1460 | ASSIGN_VALUE_OR_RETURN_ERR(splitMap[node], |
1461 | splitNode(node, &splitOption, splitConstraint)); |
1462 | } |
1463 | // Verify function after splitting nodes. |
1464 | RETURN_ERR_IF_NOT(F->verify(), "Function is not valid after node splitting!" ); |
1465 | return splitMap; |
1466 | } |
1467 | |
1468 | Expected<SplitNodeMap> |
1469 | glow::splitNodes(Function *F, const SplitNodeConstraint &splitConstraint) { |
1470 | // Since we will be transforming the original list of nodes, reverse iterate. |
1471 | SplitNodeMap splitMap; |
1472 | auto &nodes = F->getNodes(); |
1473 | for (auto it = nodes.rbegin(), e = nodes.rend(); it != e;) { |
1474 | Node *node = &*(it++); |
1475 | const SplitNodeOption *splitOption = nullptr; |
1476 | ASSIGN_VALUE_OR_RETURN_ERR(splitMap[node], |
1477 | splitNode(node, splitOption, &splitConstraint)); |
1478 | } |
1479 | // Verify function after splitting nodes. |
1480 | RETURN_ERR_IF_NOT(F->verify(), "Function is not valid after node splitting!" ); |
1481 | return splitMap; |
1482 | } |
1483 | |
1484 | ///===---------------------------------------------------------------------===// |
1485 | /// splitNodeRecursively |
1486 | ///===---------------------------------------------------------------------===// |
1487 | static Error |
1488 | splitNodeRecursivelyMain(SplitNodeMap &splitMap, Node *node, |
1489 | const SplitNodeOption *splitOption, |
1490 | const SplitNodeConstraint *splitConstraint, |
1491 | unsigned maxDepth, bool singleUseOnly) { |
1492 | |
1493 | // Early return if depth is 0. |
1494 | if (maxDepth == 0) { |
1495 | return Error::success(); |
1496 | } |
1497 | |
1498 | // We can do the splitting if at least the option or the constraint is given. |
1499 | RETURN_ERR_IF_NOT( |
1500 | splitOption || splitConstraint, |
1501 | "At least the split option or the split constraint must be given!" ); |
1502 | |
1503 | // Split starting node. |
1504 | std::vector<Node *> splitNodesCurr; |
1505 | ASSIGN_VALUE_OR_RETURN_ERR(splitNodesCurr, |
1506 | splitNode(node, splitOption, splitConstraint)); |
1507 | |
1508 | // If starting node was NOT split then return, otherwise add nodes to map. |
1509 | if (!splitNodesCurr.size()) { |
1510 | return Error::success(); |
1511 | } else { |
1512 | splitMap[node] = splitNodesCurr; |
1513 | } |
1514 | |
1515 | // Early return if depth is 1. |
1516 | if (maxDepth == 1) { |
1517 | return Error::success(); |
1518 | } |
1519 | |
1520 | // Iterate through all of the input operands and split them. |
1521 | unsigned inpNum = splitNodesCurr.front()->getNumInputs(); |
1522 | for (unsigned inpIdx = 0; inpIdx < inpNum; ++inpIdx) { |
1523 | |
1524 | // Find parent node value and ranges for the current input. |
1525 | std::vector<SliceRange> inputRanges; |
1526 | inputRanges.reserve(splitNodesCurr.size()); |
1527 | NodeValue sliceInputNodeValue = nullptr; |
1528 | for (const Node *splitNode : splitNodesCurr) { |
1529 | // Get parent node value and range. |
1530 | NodeValue inputNV = splitNode->getNthInput(inpIdx); |
1531 | if (auto *sliceNode = dyn_cast<SliceNode>(inputNV)) { |
1532 | inputNV = sliceNode->getInput(); |
1533 | inputRanges.push_back(SliceRange(sliceNode)); |
1534 | } else { |
1535 | inputRanges.push_back(SliceRange(inputNV.getType())); |
1536 | } |
1537 | // Verify that parent node value is common. |
1538 | if (!sliceInputNodeValue) { |
1539 | sliceInputNodeValue = inputNV; |
1540 | } else { |
1541 | RETURN_ERR_IF_NOT( |
1542 | sliceInputNodeValue == inputNV, |
1543 | "Input slices do not have a common parent node value!" ); |
1544 | } |
1545 | } |
1546 | |
1547 | // If the node value which is sliced has other consumers than the SliceNodes |
1548 | // inserted during splitting then we do not split the node which produces |
1549 | // that node value. |
1550 | if (singleUseOnly && |
1551 | sliceInputNodeValue.getNumUsers() > splitNodesCurr.size()) { |
1552 | continue; |
1553 | } |
1554 | |
1555 | // Check that we split the 1st output operand of the parent node. |
1556 | if (sliceInputNodeValue.getResNo() != SplitNodeOutputIdx) { |
1557 | continue; |
1558 | } |
1559 | |
1560 | // Split the input node of the SliceNodes using same slice ranges. |
1561 | auto splitInputOption = SplitNodeBySliceRanges(inputRanges); |
1562 | Node *splitInputNode = sliceInputNodeValue.getNode(); |
1563 | RETURN_IF_ERR(splitNodeRecursivelyMain(splitMap, splitInputNode, |
1564 | &splitInputOption, splitConstraint, |
1565 | maxDepth - 1, singleUseOnly)); |
1566 | |
1567 | // Remove Slice and Insert nodes between adjacent split nodes. |
1568 | if (splitMap.count(splitInputNode)) { |
1569 | auto &splitNodesNext = splitMap[splitInputNode]; |
1570 | assert(splitNodesCurr.size() == splitNodesNext.size() && |
1571 | "Mismatch for number of split Nodes!" ); |
1572 | // Create short circuit between inputs and outputs of adjacent nodes. |
1573 | for (size_t idx = 0, len = splitNodesCurr.size(); idx < len; ++idx) { |
1574 | NodeValue splitNodeNextOut = |
1575 | splitNodesNext[idx]->getNthResult(SplitNodeOutputIdx); |
1576 | NodeValue splitNodeCurrInp = splitNodesCurr[idx]->getNthInput(inpIdx); |
1577 | assert( |
1578 | splitNodeNextOut.getType()->isEqual(splitNodeCurrInp.getType()) && |
1579 | "Mismatch between input/output type when doing short-circuit!" ); |
1580 | assert(splitNodeNextOut.getNumUsers() == 1 && |
1581 | "Split node output value has more than one use!" ); |
1582 | assert(splitNodeCurrInp.getNumUsers() == 1 && |
1583 | "Split node input value has more than one use!" ); |
1584 | splitNodeCurrInp.replaceAllUsesOfWith(splitNodeNextOut); |
1585 | } |
1586 | } |
1587 | } |
1588 | |
1589 | return Error::success(); |
1590 | } |
1591 | |
1592 | Expected<SplitNodeMap> |
1593 | glow::splitNodeRecursively(Node *node, const SplitNodeOption *splitOption, |
1594 | const SplitNodeConstraint *splitConstraint, |
1595 | unsigned maxDepth, bool singleUseOnly) { |
1596 | Function *F = node->getParent(); |
1597 | RETURN_ERR_IF_NOT(F, "Cannot split a node without a parent Function!" ); |
1598 | |
1599 | // Create split map. |
1600 | SplitNodeMap splitMap; |
1601 | |
1602 | // Check if this node has single use only before splitting it. |
1603 | if (singleUseOnly && |
1604 | node->getNthResult(SplitNodeOutputIdx).getNumUsers() > 1) { |
1605 | return splitMap; |
1606 | } |
1607 | |
1608 | // Split node recursively. |
1609 | RETURN_IF_ERR(splitNodeRecursivelyMain( |
1610 | splitMap, node, splitOption, splitConstraint, maxDepth, singleUseOnly)); |
1611 | |
1612 | // Perform DCE to cleanup. |
1613 | auto cctx = glow::CompilationContext(); |
1614 | glow::runDCEPass(F, cctx); |
1615 | |
1616 | // Verify function after splitting nodes. |
1617 | RETURN_ERR_IF_NOT(F->verify(), |
1618 | "Function is not valid after recursive node splitting!" ); |
1619 | return splitMap; |
1620 | } |
1621 | |
1622 | Expected<SplitNodeMap> |
1623 | glow::splitNodeRecursively(Node *node, const SplitNodeOption &splitOption, |
1624 | unsigned maxDepth, bool singleUseOnly) { |
1625 | return splitNodeRecursively(node, &splitOption, nullptr, maxDepth, |
1626 | singleUseOnly); |
1627 | } |
1628 | |
1629 | Expected<SplitNodeMap> |
1630 | glow::splitNodeRecursively(Node *node, |
1631 | const SplitNodeConstraint &splitConstraint, |
1632 | unsigned maxDepth, bool singleUseOnly) { |
1633 | return splitNodeRecursively(node, nullptr, &splitConstraint, maxDepth, |
1634 | singleUseOnly); |
1635 | } |
1636 | |