1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h" |
18 | #include <folly/String.h> |
19 | |
20 | #include "glow/Backend/Backend.h" |
21 | #include "glow/Converter/Float16Converter.h" |
22 | #include "glow/Converter/FusedRowwiseConverter.h" |
23 | #include "glow/Converter/TypeAToTypeBFunctionConverter.h" |
24 | #include "glow/Flags/Flags.h" |
25 | #include "glow/Graph/Graph.h" |
26 | #include "glow/Graph/Log.h" |
27 | #include "glow/Graph/Node.h" |
28 | #include "glow/Graph/Nodes.h" |
29 | #include "glow/Graph/PlaceholderBindings.h" |
30 | #include "glow/Graph/TensorLayout.h" |
31 | #include "glow/Graph/Utils.h" |
32 | #include "glow/Graph/VerifierHelper.h" |
33 | #include "glow/Optimizer/GraphOptimizer/FunctionPassPipeline.h" |
34 | #include "glow/Optimizer/GraphOptimizer/FunctionPasses.h" |
35 | #include "glow/Optimizer/Lower/Lower.h" |
36 | #include "glow/PassManager/PassManager.h" |
37 | #include "glow/Quantization/Base/Base.h" |
38 | #include "glow/Quantization/Quantization.h" |
39 | #include "glow/Runtime/RuntimeTypes.h" |
40 | |
41 | #include "llvm/Support/Casting.h" |
42 | #include "llvm/Support/CommandLine.h" |
43 | |
44 | #include <algorithm> |
45 | #include <numeric> |
46 | #include <unordered_map> |
47 | #include <unordered_set> |
48 | #include <vector> |
49 | |
// Utility macro to continue loop if given condition is not met.
// This is intended to improve code readability and size.
#define CONTINUE_IF_NOT(cond) \
  if (!(cond)) { \
    continue; \
  }

// Command-line category grouping all graph-optimization options below.
llvm::cl::OptionCategory graphOptCat("Graph Optimizations Options");
// Upper bound (in number of tensor elements) on Constants considered for
// deduplication; 0 (the default) means no size limit.
llvm::cl::opt<unsigned> constDedupSizeOpt(
    "const-dedup-size",
    llvm::cl::desc(
        "Max number of elements allowed for deduplicating Constants. "
        "A value equal to 0 means no limit. Default is 0."),
    llvm::cl::Optional, llvm::cl::init(0), llvm::cl::cat(graphOptCat));
64 | |
65 | using namespace glow; |
66 | using llvm::cast; |
67 | using llvm::dyn_cast; |
68 | using llvm::isa; |
69 | |
70 | static bool shouldDeleteNode(Node *N) { |
71 | // In general, nodes who have side effects are retained. |
72 | if (N->hasSideEffects()) { |
73 | return false; |
74 | } |
75 | |
76 | // Don't delete nodes that have users. |
77 | if (N->hasUsers()) { |
78 | return false; |
79 | } |
80 | |
81 | return true; |
82 | } |
83 | |
// Installs a scope guard that, on destruction, undoes the Constant ->
// Placeholder swap performed by activate() and restores the original
// constant-folding setting. Construction itself does not swap anything;
// the preventer starts out dismissed.
ConstantModificationPreventer::ConstantModificationPreventer(
    Module &mod, CompilationContext &cctx)
    : ScopeGuard([&]() {
        // Ensure we cleanup Placeholder-Constant swap if necessary.
        auto &PHs = mod_.getPlaceholders();
        for (auto &pair : tmpPHToConstMap_) {
          Placeholder *tmpPH = pair.first;
          Constant *C = pair.second;
          // Point all users of the temporary Placeholder back at the
          // original Constant, then drop the Placeholder from the module.
          tmpPH->getOutput().replaceAllUsesOfWith(C->getOutput());
          mod_.erasePlaceholder(std::find(PHs.begin(), PHs.end(), tmpPH));
        }
        // Restore the constant-folding flag saved at construction time.
        cctx_.optimizationOpts.enableConstantFolding =
            origEnableConstantFolding_;
      }),
      mod_(mod), cctx_(cctx),
      origEnableConstantFolding_(cctx.optimizationOpts.enableConstantFolding) {
  // By default dismiss until explicitly activated.
  dismissed_ = true;
}
103 | |
// Swaps every Constant in the module for a temporary Placeholder so that
// optimization passes cannot modify Constant payloads while this preventer
// is active. The scope guard set up in the constructor reverses the swap.
void ConstantModificationPreventer::activate() {
  dismissed_ = false;
  // Prevent Constant modification by temporarily replacing them with PHs.
  for (Constant *C : mod_.getConstants()) {
    // Note: These temp Placeholders are more like static Placeholders, but we
    // don't want to set them as static here because optimizations may kick in
    // to modify the type of the static Placeholder (see
    // cctx.optimizationOpts.foldStaticPlaceholderConversions).
    Placeholder *tmpPH = mod_.createPlaceholder(
        C->getType(), C->getName().str() + "_SWAP_CONST_FOLD",
        /* isTrainable */ false, C->getLayout());
    // Record the swap so the scope guard can undo it later.
    tmpPHToConstMap_[tmpPH] = C;
    cctx_.optimizationOpts.tempPHsForConstants.insert(tmpPH);
    C->getOutput().replaceAllUsesOfWith(tmpPH->getOutput());
  }
  // Disable constant folding temporarily; restored later by the scope guard.
  cctx_.optimizationOpts.enableConstantFolding = false;
}
122 | |
123 | /// Helper that \returns true if all functions in \p mod are loaded. |
124 | static bool areAllFunctionsLoaded(Module *mod) { |
125 | for (auto *MF : mod->getFunctions()) { |
126 | if (MF->getState() < FunctionState::FuncLoaded) { |
127 | return false; |
128 | } |
129 | } |
130 | return true; |
131 | } |
132 | |
133 | /// Helper that \returns the shuffle that inverts \p shuffle. For example, if |
134 | /// \p shuffle is {3, 0, 1, 2}, then this function returns {1, 2, 3, 0}. |
135 | static llvm::SmallVector<unsigned_t, max_tensor_dimensions> |
136 | invertShuffle(llvm::ArrayRef<unsigned_t> shuffle) { |
137 | llvm::SmallVector<unsigned_t, max_tensor_dimensions> invertedShuffle; |
138 | invertedShuffle.resize(shuffle.size()); |
139 | |
140 | for (size_t i = 0; i < shuffle.size(); ++i) { |
141 | invertedShuffle[shuffle[i]] = i; |
142 | } |
143 | |
144 | return invertedShuffle; |
145 | } |
146 | |
/// Add a TransposeNode after \p C in \p F that has the same shuffle as \p TR.
/// This function assumes that the type of \p C and the output type of \p TR are
/// the same. \returns the newly added TransposeNode.
static TransposeNode *insertMatchingTransposeAfterConstant(Function *F,
                                                           Constant *C,
                                                           TransposeNode *TR) {
  const auto *CT = C->getOutput().getType();
  const auto *TRT = TR->getResult().getType();
  // Precondition: the Constant's type must equal TR's result type in shape
  // and stride; scale/offset are allowed to differ for quantized types.
  DCHECK(CT->isEqual(*TRT, /* allowDifferentShape */ false,
                     /* allowDifferentStride */ false,
                     /* allowDifferentScaleOffset */ true));

  auto &T = C->getPayload();

  // In order for a new Transpose node with the same shuffle as TR to be created
  // at the output of the Constant, a new Constant with the same dimension as
  // the input of TR should be created. Note that the original scale and offset
  // should be kept for quantized types.
  auto newConstTy =
      F->getParent()->uniqueTypeWithNewShape(CT, TR->getInput().dims());
  auto *NC = F->getParent()->createConstant(newConstTy,
                                            C->getName().str() + ".transposed");

  // The payload of the original Constant C has the same type as the
  // output of TR. In order to preserve correctness, this payload must
  // be transposed using the inverse of the shuffle of TR and stored
  // into the payload of the new Constant.
  //
  // Another way to think of this is that we are inserting two
  // Transposes that are inverses of each other back to back after the original
  // Constant. The shuffle of the second Transpose must match that of TR.
  // In order to preserve correctness, the shuffle
  // of the first Transpose must be the inverse of that shuffle of the
  // second Transpose. The statement below statically computes this
  // first Transpose.
  T.transpose(&NC->getPayloadMutable(), invertShuffle(TR->getShuffle()));

  // Create Transpose on the LHS that has the same shuffle as TR.
  return F->createTranspose("transpose", NC, TR->getShuffle());
}
187 | |
/// A no-op pass: leaves \p F untouched and reports that nothing changed.
bool EmptyPass::run(Function *F, const CompilationContext &cctx) {
  return false;
}
191 | |
/// Dead-code elimination: iteratively removes nodes that have no users and no
/// side effects, then (when allowed) deletes unused Constants from the module.
/// \returns true if any node was removed.
bool DCE::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());

  auto &nodes = F->getNodes();

  // Iterators of nodes scheduled for deletion. Erasure is batched per sweep
  // because erasing while scanning would invalidate the active iterator.
  std::vector<NodesList::iterator> erasedNodes{};

  bool changed = false;

  // Remove unused nodes. Iterate to a fixed point: erasing a node may leave
  // one of its inputs without users, making that input deletable too.
  while (true) {
    bool changedLocally = false;
    for (auto it = nodes.begin(), e = nodes.end(); it != e;) {
      if (!shouldDeleteNode(&*it)) {
        ++it;
        continue;
      }

      erasedNodes.push_back(it);
      ++it;
      changedLocally = true;
      changed = true;
    }

    // Erase the collected nodes, most recently collected first.
    while (!erasedNodes.empty()) {
      auto it = erasedNodes.back();
      F->eraseNode(it);
      erasedNodes.pop_back();
    }

    if (!changedLocally) {
      break;
    }
  }

  // Don't remove unused Constants since many may be temporarily unused during
  // optimizations.
  if (cctx.optimizationOpts.delayAndRecordConstantModification) {
    return changed;
  }

  // Constants may still be referenced by functions that are not fully loaded;
  // skip Constant cleanup until every function in the module is loaded.
  if (!areAllFunctionsLoaded(F->getParent())) {
    return changed;
  }

  // Delete unused Constants.
  deleteUnusedConstants(*F->getParent());

  return changed;
}
242 | |
243 | void glow::deleteUnusedConstants(Module &mod) { |
244 | auto &consts = mod.getConstants(); |
245 | std::vector<ConstList::iterator> erasedConsts{}; |
246 | for (auto it = consts.begin(), e = consts.end(); it != e;) { |
247 | if (!shouldDeleteNode(*it)) { |
248 | ++it; |
249 | continue; |
250 | } |
251 | erasedConsts.push_back(it); |
252 | ++it; |
253 | } |
254 | |
255 | while (!erasedConsts.empty()) { |
256 | auto it = erasedConsts.back(); |
257 | mod.eraseConstant(it); |
258 | erasedConsts.pop_back(); |
259 | } |
260 | } |
261 | |
262 | /// \returns true if the \p shuffle corresponds to an identity operation, false |
263 | /// otherwise. |
264 | static bool isIdentityShuffle(llvm::ArrayRef<unsigned> shuffle) { |
265 | for (size_t i = 0, e = shuffle.size(); i < e; i++) { |
266 | if (shuffle[i] != i) { |
267 | return false; |
268 | } |
269 | } |
270 | return true; |
271 | } |
272 | |
273 | /// \returns True if the node \p N always evaluates to \p val. |
274 | bool isSplatOfVal(Node *N, float val) { |
275 | SplatNode *Z = dyn_cast<SplatNode>(N); |
276 | if (!Z) { |
277 | return false; |
278 | } |
279 | return (Z->getValue() == val); |
280 | } |
281 | |
/// \returns True if the node returns a constant value.
/// Note: only Splat nodes are recognized as constant producers here.
bool isConstant(Node *N) { return isa<SplatNode>(N); }
284 | |
/// Applies local algebraic simplifications to \p node in \p F.
/// \returns the new simplified NodeValue or the original node's first result.
static NodeValue simplifyNode(Node *node, Function *F) {
// Simplify commutative nodes by moving the constant operator to the right-hand
// side.
// Example: C + X => X + C
// Note: the replacement keeps the original result type so quantization
// parameters are preserved.
#define COMMUTE_CONST_TO_RHS(NodeKind) \
  if (auto *NN = dyn_cast<NodeKind##Node>(node)) \
    if (isConstant(NN->getLHS()) && !isConstant(NN->getRHS())) { \
      return F->create##NodeKind(NN->getName(), NN->getResult().getType(), \
                                 NN->getRHS(), NN->getLHS()); \
    }

  COMMUTE_CONST_TO_RHS(Add)
  COMMUTE_CONST_TO_RHS(Mul)
  COMMUTE_CONST_TO_RHS(Max)
  COMMUTE_CONST_TO_RHS(Min)
#undef COMMUTE_CONST_TO_RHS

  if (auto *AN = dyn_cast<AddNode>(node)) {
    // X + 0 => X
    if (isSplatOfVal(AN->getRHS(), 0)) {
      return AN->getLHS();
    }
  }

  if (auto *MN = dyn_cast<MulNode>(node)) {
    // X * 0 => 0
    if (isSplatOfVal(MN->getRHS(), 0)) {
      return MN->getRHS();
    }
    // X * 1 => X
    if (isSplatOfVal(MN->getRHS(), 1)) {
      return MN->getLHS();
    }
  }

  if (auto *DN = dyn_cast<DivNode>(node)) {
    // 0 / X => 0
    if (isSplatOfVal(DN->getLHS(), 0)) {
      return DN->getLHS();
    }
    // X / 1 => X
    if (isSplatOfVal(DN->getRHS(), 1)) {
      return DN->getLHS();
    }
  }

  // X - 0 => X
  if (auto *SN = dyn_cast<SubNode>(node)) {
    if (isSplatOfVal(SN->getRHS(), 0)) {
      return SN->getLHS();
    }
  }

  // No simplification applied; return the node unchanged.
  return node;
}
341 | |
/// Sink Transpose below ChannelShuffle node.
/// If \p CS's input is produced by a Transpose, moves the Transpose below the
/// ChannelShuffle (remapping the shuffle kernel axis accordingly) and rewires
/// all users of \p CS. \returns true if the rewrite was performed.
static bool sinkTranposeBelowChannelShuffle(Function *F,
                                            ChannelShuffleNode *CS) {
  auto *TR = dyn_cast<TransposeNode>(CS->getInput());
  if (!TR) {
    return false;
  }

  // Create a new ChannelShuffle with kernel parameter transposed by the
  // sinking TR's shuffle because that Transpose will now be moved below this
  // ChannelShuffle operator.
  auto *newCS =
      F->createChannelShuffle(CS->getName(), TR->getInput(), CS->getGroup(),
                              TR->getShuffle()[CS->getKernel()]);

  // Create a copy of sinkingTR and insert after newChannelShuffle.
  auto *newTR = F->createTranspose(TR->getName(), newCS, TR->getShuffle(),
                                   TR->getLayout());

  CS->getResult().replaceAllUsesOfWith(newTR);

  return true;
}
365 | |
/// Given \p CN from \p F, determines if all inputs are ConvertTo that use the
/// same scale/offset/kind, and if so creates and \returns a new concat with all
/// inputs as the inputs from the ConvertTo inputs. Otherwise \returns nullptr.
static ConcatNode *setupConvertToSinkBelowConcat(Function *F, ConcatNode *CN) {
  // Check if all inputs are ConvertTo.
  std::vector<ConvertToNode *> inputNodes;
  inputNodes.reserve(CN->getInputs().size());
  for (auto &concatInput : CN->getInputs()) {
    auto *CT = dyn_cast<ConvertToNode>(concatInput);
    if (!CT) {
      return nullptr;
    }
    inputNodes.push_back(CT);
  }

  // Gather all inputs of the nodes in inputNodes here.
  std::vector<NodeValue> newInputs;
  newInputs.reserve(inputNodes.size());
  newInputs.push_back(inputNodes[0]->getInput());

  // Get the CN's first input's result and input types to check against all
  // other CN inputs.
  const TypeRef firstResultTy = inputNodes[0]->getResult().getType();
  const TypeRef firstInputTy = inputNodes[0]->getInput().getType();

  // Check that all inputs have the same output and input element type.
  for (size_t i = 1, e = inputNodes.size(); i < e; i++) {
    const TypeRef currResultTy = inputNodes[i]->getResult().getType();
    const TypeRef currInputTy = inputNodes[i]->getInput().getType();
    if (currResultTy->getElementType() != firstResultTy->getElementType()) {
      return nullptr;
    }
    // For quantized results, scale/offset must also agree across all inputs.
    if (firstResultTy->isQuantizedType()) {
      if (currResultTy->getScale() != firstResultTy->getScale() ||
          currResultTy->getOffset() != firstResultTy->getOffset()) {
        return nullptr;
      }
    }
    if (currInputTy->getElementType() != firstInputTy->getElementType()) {
      return nullptr;
    }
    // Same check on the ConvertTo inputs when they are quantized.
    if (firstInputTy->isQuantizedType()) {
      if (currInputTy->getScale() != firstInputTy->getScale() ||
          currInputTy->getOffset() != firstInputTy->getOffset()) {
        return nullptr;
      }
    }
    newInputs.push_back(inputNodes[i]->getInput());
  }

  // Create and return a new ConcatNode with newInputs.
  return F->createConcat(CN->getName(), newInputs, CN->getDim());
}
419 | |
/// Given \p CN from \p F, determines if all inputs are either Quantize or
/// Dequantize (depending on \p QuantNodeClass) that use the same
/// scale/offset/kind, and if so creates and \returns a new concat with all
/// inputs as the inputs from the Quantize or Dequantize inputs. Otherwise
/// \returns nullptr.
template <class QuantNodeClass>
static ConcatNode *setupQuantDequantSinkBelowConcat(Function *F,
                                                    ConcatNode *CN) {
  constexpr bool isQuant = std::is_same<QuantizeNode, QuantNodeClass>::value;
  constexpr bool isDeq = std::is_same<DequantizeNode, QuantNodeClass>::value;
  static_assert(isQuant || isDeq, "setupQuantDequantSinkBelowConcat() only "
                                  "supports Quantize/Dequantize nodes.");
  // Check if all inputs are Quantize/Dequantize with the same input
  // scale/offset/ElemKind.
  std::vector<QuantNodeClass *> qNodes;
  qNodes.reserve(CN->getInputs().size());
  for (auto &concatInput : CN->getInputs()) {
    QuantNodeClass *Q = dyn_cast<QuantNodeClass>(concatInput);
    if (!Q) {
      return nullptr;
    }
    qNodes.push_back(Q);
  }

  // Gather all inputs of the nodes in qNodes here.
  std::vector<NodeValue> newInputs;
  newInputs.reserve(qNodes.size());
  newInputs.push_back(qNodes[0]->getInput());

  // Check the CN's first input's type to check against all other inputs. Use
  // the output of Quantize or input of Dequantize (i.e. whichever side of the
  // conversion carries the quantized type).
  const TypeRef firstTy = isQuant ? qNodes[0]->getResult().getType()
                                  : qNodes[0]->getInput().getType();

  // Check that all inputs have the same scale/offset/type.
  for (size_t i = 1, e = qNodes.size(); i < e; i++) {
    const TypeRef currTy = isQuant ? qNodes[i]->getResult().getType()
                                   : qNodes[i]->getInput().getType();
    if (currTy->getScale() != firstTy->getScale() ||
        currTy->getOffset() != firstTy->getOffset() ||
        currTy->getElementType() != firstTy->getElementType()) {
      return nullptr;
    }
    newInputs.push_back(qNodes[i]->getInput());
  }

  // Create and return a new ConcatNode with newInputs.
  return F->createConcat(CN->getName(), newInputs, CN->getDim());
}
469 | |
470 | /// Given \p CN from \p F, determines if all inputs are Tanh |
471 | /// and if so creates and \returns a new concat with all |
472 | /// inputs as the inputs from the Tanh inputs. Otherwise |
473 | /// \returns nullptr. |
474 | static ConcatNode *setupTanhSinkBelowConcat(Function *F, ConcatNode *CN) { |
475 | // Check if all inputs are Tanh |
476 | std::vector<TanhNode *> tanhNodes; |
477 | tanhNodes.reserve(CN->getInputs().size()); |
478 | for (auto &concatInput : CN->getInputs()) { |
479 | TanhNode *T = dyn_cast<TanhNode>(concatInput); |
480 | if (!T) { |
481 | return nullptr; |
482 | } |
483 | tanhNodes.push_back(T); |
484 | } |
485 | |
486 | // Gather all inputs of the nodes in tanhNodes. |
487 | std::vector<NodeValue> newInputs; |
488 | newInputs.reserve(tanhNodes.size()); |
489 | for (size_t i = 0, e = tanhNodes.size(); i < e; i++) { |
490 | newInputs.emplace_back(tanhNodes[i]->getInput()); |
491 | } |
492 | // Create and return a new ConcatNode with newInputs. |
493 | return F->createConcat(CN->getName(), newInputs, CN->getDim()); |
494 | } |
495 | |
/// Sinks element-wise conversion nodes (Dequantize, Quantize, ConvertTo, and
/// optionally Tanh) below Concat nodes when all of the Concat's inputs share
/// the same conversion, so the conversion runs once on the concatenated
/// result. \returns true if the graph was changed.
bool SinkConversions::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  auto &nodes = F->getNodes();
  // For each node:
  for (auto &N : nodes) {
    ConcatNode *CN = dyn_cast<ConcatNode>(&N);
    if (!CN) {
      continue;
    }
    // The first input's kind selects which sink is attempted; the setup
    // helpers re-verify that ALL inputs match.
    const Node *firstNode = CN->getInputs().front().getNode();

    // Sink Dequantize below Concat nodes.
    if (firstNode->getKind() == Kinded::Kind::DequantizeNodeKind) {
      ConcatNode *newCN =
          setupQuantDequantSinkBelowConcat<DequantizeNode>(F, CN);
      if (!newCN) {
        continue;
      }

      // Dequantize the concatenated result to the original output type.
      DequantizeNode *newDequantize =
          F->createDequantize(CN->getName().str() + "_dequantize", newCN,
                              CN->getResult().getType());

      CN->getResult().replaceAllUsesOfWith(newDequantize->getResult());
      changed = true;
      continue;
    }

    // Sink Quantize below Concat nodes.
    if (firstNode->getKind() == Kinded::Kind::QuantizeNodeKind) {
      ConcatNode *newCN = setupQuantDequantSinkBelowConcat<QuantizeNode>(F, CN);
      if (!newCN) {
        continue;
      }

      // Build a quantized type with the new concat's shape but the original
      // Quantize's element kind/scale/offset.
      const TypeRef QTy =
          llvm::cast<QuantizeNode>(firstNode)->getResult().getType();
      const TypeRef concatQTy = F->getParent()->uniqueType(
          QTy->getElementType(), newCN->getResult().dims(), QTy->getScale(),
          QTy->getOffset());
      QuantizeNode *newQuantize = F->createQuantize(
          CN->getName().str() + "_quantize", newCN, concatQTy);

      CN->getResult().replaceAllUsesOfWith(newQuantize->getResult());
      changed = true;
      continue;
    }

    // Sink ConvertTo below Concat nodes.
    if (firstNode->getKind() == Kinded::Kind::ConvertToNodeKind) {
      ConcatNode *newCN = setupConvertToSinkBelowConcat(F, CN);
      if (!newCN) {
        continue;
      }
      auto *newConvertTo =
          F->createConvertTo(CN->getName().str() + "_convert_to", newCN,
                             CN->getResult().getType());
      CN->getResult().replaceAllUsesOfWith(newConvertTo->getResult());
      changed = true;
      continue;
    }

    // Sink Tanh below Concat nodes (opt-in via compilation context).
    if (cctx.optimizationOpts.sinkTanhBelowConcat) {
      if (firstNode->getKind() == Kinded::Kind::TanhNodeKind) {
        ConcatNode *newCN = setupTanhSinkBelowConcat(F, CN);
        if (!newCN) {
          continue;
        }

        // Apply a single Tanh to the concatenated pre-activation values.
        const TypeRef TTy =
            llvm::cast<TanhNode>(firstNode)->getResult().getType();
        const TypeRef concatTy = F->getParent()->uniqueType(
            TTy->getElementType(), newCN->getResult().dims());
        TanhNode *newTanh =
            F->createTanh(CN->getName().str() + "_tanh", concatTy, newCN);

        CN->getResult().replaceAllUsesOfWith(newTanh->getResult());
        changed = true;
        continue;
      }
    }
  }

  return changed;
}
583 | |
/// Sink Quantize(Concat(...)) -> Concat(Quantize(...)). This allows for
/// concatenating less data, and if there are some inputs that are already
/// quantized and are being dequantized just for the concat then we can skip
/// this conversion.
bool SinkConcatBelowQuantize::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  auto &nodes = F->getNodes();
  // For each node:
  for (auto &N : nodes) {
    QuantizeNode *QN = dyn_cast<QuantizeNode>(&N);
    if (!QN) {
      continue;
    }

    // Only rewrite when the Concat feeds this Quantize exclusively; other
    // users would still need the unquantized concat result.
    ConcatNode *CN = dyn_cast<ConcatNode>(QN->getInput());
    if (!CN || CN->getNumUsers() > 1) {
      continue;
    }

    // For all inputs to the current CN, add quantize nodes to them all using
    // the same scale/offset as QN and put the quantize nodes in newQuantInputs.
    std::vector<NodeValue> newQuantInputs;
    for (const NodeValue &inCN : CN->getInputs()) {
      TypeRef newOutTy = F->getParent()->uniqueTypeWithNewShape(
          QN->getResult().getType(), inCN.dims());
      QuantizeNode *quantInCN = F->createQuantize(
          inCN.getNode()->getName().str() + "_quant", inCN, newOutTy);
      newQuantInputs.push_back(quantInCN);
    }

    // Create a new CN with the quantized inputs and replace QN with it.
    ConcatNode *newCN =
        F->createConcat(CN->getName(), newQuantInputs, CN->getDim());
    QN->getResult().replaceAllUsesOfWith(newCN->getResult());
    changed = true;
  }

  return changed;
}
624 | |
625 | /// If \p N is a TransposeNode with all of the same node kind of users, then |
626 | /// \returns that TransposeNode, else \returns nullptr. For example, if \p N is |
627 | /// a TransposeNode with two QuantizeNode users, this will return the |
628 | /// TransposeNode, but if it had one QuantizeNode and one MatMul node then it |
629 | /// will return nullptr. |
630 | static TransposeNode *getTransposeNodeWithAllSameUserKind(Node *N) { |
631 | auto *TN = dyn_cast<TransposeNode>(N); |
632 | if (!TN) { |
633 | return nullptr; |
634 | } |
635 | if (TN->getNumUsers() <= 1) { |
636 | return TN; |
637 | } |
638 | auto firstKind = N->getUsers().front().getUser()->getKind(); |
639 | for (auto &U : N->getUsers()) { |
640 | if (U.getUser()->getKind() != firstKind) { |
641 | return nullptr; |
642 | } |
643 | } |
644 | return TN; |
645 | } |
646 | |
647 | /// Code Sinking. |
648 | bool SinkCode::run(Function *F, const CompilationContext &cctx) { |
649 | LOG_SCOPE(F->getLogContext(), getName()); |
650 | bool changed = false; |
651 | auto &nodes = F->getNodes(); |
652 | // For each node: |
653 | for (auto &N : nodes) { |
654 | auto *node = &N; |
655 | |
656 | // Sink Reshape/Transpose below BatchNormalization. |
657 | if (auto *BN = dyn_cast<BatchNormalizationNode>(node)) { |
658 | |
659 | // Sink Reshape below BatchNormalization. |
660 | if (auto *RS = dyn_cast<ReshapeNode>(BN->getInput())) { |
661 | auto inDims = RS->getInput().dims(); |
662 | auto outDims = RS->getResult().dims(); |
663 | unsigned_t newChannelIdx; |
664 | |
665 | // Skip sinking if: 1) the input was less than 3 dimensions, |
666 | // because we need spatial dimensions in addition to batch |
667 | // and channel or 2) if it is 3D data because the reshapes |
668 | // are deliberately introduced to phrase 3D BatchNormalization |
669 | // as a 2D one. |
670 | if (RS->getInput().dims().size() < 3 || |
671 | RS->getInput().dims().size() == 5) { |
672 | continue; |
673 | } |
674 | |
675 | // Reshape should not change the BatchNorm ChannelIdx dimensions. |
676 | // Only NH[W]C and NCH[W] are allowed. |
677 | if (BN->getChannelIdx() == outDims.size() - 1) { |
678 | if (inDims[inDims.size() - 1] != outDims[outDims.size() - 1]) { |
679 | continue; |
680 | } |
681 | newChannelIdx = inDims.size() - 1; |
682 | } else if (BN->getChannelIdx() == 1) { |
683 | // Note: index '1' maps to C in NCH[W] layout. |
684 | if (inDims[1] != outDims[1]) { |
685 | continue; |
686 | } |
687 | newChannelIdx = 1; |
688 | } else { |
689 | continue; |
690 | } |
691 | |
692 | // Reshape should not change the batch dimension. |
693 | if (inDims[0] != outDims[0]) { |
694 | continue; |
695 | } |
696 | |
697 | auto bnOutTy = F->getParent()->uniqueTypeWithNewShape( |
698 | BN->getResult().getType(), RS->getInput().getType()); |
699 | auto rsInputType = RS->getInput().getType(); |
700 | glow::TypeRef outTy = F->getParent()->uniqueTypeWithNewShape( |
701 | bnOutTy, rsInputType->dims()); |
702 | auto *newBN = F->createBatchNormalization( |
703 | BN->getName(), outTy, RS->getInput(), BN->getBias(), BN->getScale(), |
704 | BN->getMean(), BN->getVar(), newChannelIdx, BN->getEpsilon(), |
705 | BN->getMomentum()); |
706 | auto *newRS = F->createReshape(RS->getName(), newBN, |
707 | RS->getResult().dims(), RS->getLayout()); |
708 | BN->getResult().replaceAllUsesOfWith(newRS); |
709 | changed = true; |
710 | continue; |
711 | } |
712 | |
713 | // Sink Transpose below batch normalization nodes: |
714 | if (auto *TR = dyn_cast<TransposeNode>(BN->getInput())) { |
715 | |
716 | // Figure out where we transposed the channel index for batch |
717 | // normalization. |
718 | unsigned_t idx = BN->getChannelIdx(); |
719 | unsigned_t newChannelIdx = TR->getShuffle()[idx]; |
720 | |
721 | auto bnOutTy = BN->getResult().getType(); |
722 | auto trInputType = TR->getInput().getType(); |
723 | glow::TypeRef outTy = F->getParent()->uniqueTypeWithNewShape( |
724 | bnOutTy, trInputType->dims()); |
725 | |
726 | auto *NewBN = F->createBatchNormalization( |
727 | BN->getName(), outTy, TR->getInput(), BN->getBias(), BN->getScale(), |
728 | BN->getMean(), BN->getVar(), newChannelIdx, BN->getEpsilon(), |
729 | BN->getMomentum()); |
730 | NewBN->setPredicate(node->getPredicate()); |
731 | auto *newTR = F->createTranspose(TR->getName(), NewBN, TR->getShuffle(), |
732 | TR->getLayout()); |
733 | newTR->setPredicate(node->getPredicate()); |
734 | |
735 | BN->getResult().replaceAllUsesOfWith(newTR); |
736 | changed = true; |
737 | continue; |
738 | } |
739 | } |
740 | |
741 | if (auto *RL = dyn_cast<ReluNode>(node)) { |
742 | // Sink Transpose below batch RELU nodes. |
743 | if (auto *TR = dyn_cast<TransposeNode>(RL->getInput())) { |
744 | // Keep the same quantization parameters for ReLU output, but |
745 | // change the shape to appropriate value. |
746 | auto reluOutTy = F->getParent()->uniqueTypeWithNewShape( |
747 | RL->getResult().getType(), TR->getInput().getType()); |
748 | auto *NRL = F->createRELU(RL->getName(), TR->getInput(), reluOutTy); |
749 | NRL->setPredicate(node->getPredicate()); |
750 | auto *newTR = F->createTranspose(TR->getName(), NRL, TR->getShuffle(), |
751 | TR->getLayout()); |
752 | newTR->setPredicate(node->getPredicate()); |
753 | RL->getResult().replaceAllUsesOfWith(newTR); |
754 | changed = true; |
755 | continue; |
756 | } |
757 | |
758 | // Sink Clip below RELU nodes. |
759 | if (ClipNode *CN = dyn_cast<ClipNode>(RL->getInput())) { |
760 | assert(!RL->getResult().getType()->isQuantizedType() && |
761 | "Relu(Clip) means Relu should not be quantized." ); |
762 | ReluNode *newRL = F->createRELU(RL->getName(), CN->getInput()); |
763 | ClipNode *newCN = |
764 | F->createClip(CN->getName(), newRL->getResult(), |
765 | std::max(CN->getMin(), 0.0f), CN->getMax()); |
766 | RL->getResult().replaceAllUsesOfWith(newCN); |
767 | changed = true; |
768 | continue; |
769 | } |
770 | } |
771 | |
772 | // Sink Transpose below Clip nodes. |
773 | if (auto *CL = dyn_cast<ClipNode>(node)) { |
774 | auto *TR = dyn_cast<TransposeNode>(CL->getInput()); |
775 | |
776 | if (!TR) { |
777 | continue; |
778 | } |
779 | |
780 | // Keep the same quantization parameters for Clip output, but |
781 | // change the shape to appropriate value. |
782 | auto clipOutTy = F->getParent()->uniqueTypeWithNewShape( |
783 | CL->getResult().getType(), TR->getInput().getType()); |
784 | auto *NCL = F->createClip(CL->getName(), TR->getInput(), clipOutTy, |
785 | CL->getMin(), CL->getMax()); |
786 | NCL->setPredicate(node->getPredicate()); |
787 | auto *newTR = F->createTranspose(TR->getName(), NCL, TR->getShuffle()); |
788 | newTR->setPredicate(node->getPredicate()); |
789 | CL->getResult().replaceAllUsesOfWith(newTR); |
790 | changed = true; |
791 | continue; |
792 | } |
793 | |
794 | // Sink Transpose below LeakyRelu nodes. |
795 | if (auto *LR = dyn_cast<LeakyReluNode>(node)) { |
796 | auto *TR = dyn_cast<TransposeNode>(LR->getInput()); |
797 | if (!TR) { |
798 | continue; |
799 | } |
800 | auto newLROutTy = F->getParent()->uniqueTypeWithNewShape( |
801 | LR->getResult().getType(), TR->getInput().getType()); |
802 | auto *newLR = F->createLeakyRELU(LR->getName(), newLROutTy, |
803 | TR->getInput(), LR->getAlpha()); |
804 | newLR->setPredicate(node->getPredicate()); |
805 | auto *newTR = F->createTranspose(TR->getName(), newLR, TR->getShuffle()); |
806 | newTR->setPredicate(node->getPredicate()); |
807 | LR->getResult().replaceAllUsesOfWith(newTR); |
808 | changed = true; |
809 | continue; |
810 | } |
811 | |
812 | // Sink Transpose below PRelu with Splat. |
813 | if (auto *PN = dyn_cast<PReluNode>(node)) { |
814 | auto *TR = dyn_cast<TransposeNode>(PN->getInput()); |
815 | if (!TR) { |
816 | continue; |
817 | } |
818 | auto *SN = dyn_cast<SplatNode>(PN->getSlope()); |
819 | if (!SN) { |
820 | continue; |
821 | } |
822 | auto newSNOutTy = F->getParent()->uniqueTypeWithNewShape( |
823 | SN->getResult().getType(), TR->getInput().getType()); |
824 | auto newPNOutTy = F->getParent()->uniqueTypeWithNewShape( |
825 | PN->getResult().getType(), TR->getInput().getType()); |
826 | auto *newSN = F->createSplat(SN->getName(), newSNOutTy, SN->getValue()); |
827 | auto *newPN = |
828 | F->createPRELU(PN->getName(), TR->getInput(), newSN, newPNOutTy); |
829 | auto *newTR = F->createTranspose(TR->getName(), newPN, TR->getShuffle()); |
830 | newPN->setPredicate(node->getPredicate()); |
831 | newTR->setPredicate(node->getPredicate()); |
832 | PN->getResult().replaceAllUsesOfWith(newTR); |
833 | changed = true; |
834 | continue; |
835 | } |
836 | |
837 | // Sink Transpose below Sigmoid nodes. |
838 | if (auto *SI = dyn_cast<SigmoidNode>(node)) { |
839 | auto *TR = dyn_cast<TransposeNode>(SI->getInput()); |
840 | |
841 | if (!TR) { |
842 | continue; |
843 | } |
844 | |
845 | auto *NSI = F->createSigmoid(SI->getName(), TR->getInput()); |
846 | NSI->setPredicate(node->getPredicate()); |
847 | auto *newTR = F->createTranspose(TR->getName(), NSI, TR->getShuffle(), |
848 | TR->getLayout()); |
849 | newTR->setPredicate(node->getPredicate()); |
850 | SI->getResult().replaceAllUsesOfWith(newTR); |
851 | changed = true; |
852 | continue; |
853 | } |
854 | |
855 | // Sink Transpose below Tile nodes. |
856 | if (auto *TN = dyn_cast<TileNode>(node)) { |
857 | auto *TR = dyn_cast<TransposeNode>(TN->getInput()); |
858 | |
859 | if (!TR) { |
860 | continue; |
861 | } |
862 | |
863 | auto *newTN = F->createTile(TN->getName(), TR->getInput(), TN->getCount(), |
864 | TR->getShuffle()[TN->getAxis()]); |
865 | newTN->setPredicate(node->getPredicate()); |
866 | auto *newTR = F->createTranspose(TR->getName(), newTN, TR->getShuffle(), |
867 | TR->getLayout()); |
868 | newTR->setPredicate(node->getPredicate()); |
869 | TN->getResult().replaceAllUsesOfWith(newTR); |
870 | changed = true; |
871 | continue; |
872 | } |
873 | |
874 | // Sink Transpose below Pad nodes. |
875 | if (auto *padNode = dyn_cast<PadNode>(node)) { |
876 | auto *transposeNode = dyn_cast<TransposeNode>(padNode->getInput()); |
877 | |
878 | if (!transposeNode) { |
879 | continue; |
880 | } |
881 | |
882 | // The transpose shuffle specifies the source dimension. |
883 | // When sinking Transpose below Pad, shuffle describes the target |
884 | // dimension. |
885 | auto shuffle = transposeNode->getShuffle(); |
886 | |
887 | // Shuffle the Pad output type and the padding attribute. |
888 | auto outPadType = padNode->getResult().getType(); |
889 | auto outPadShape = outPadType->dims(); |
890 | auto pads = padNode->getPads(); |
891 | size_t numDims = outPadShape.size(); |
892 | std::vector<dim_t> newOutPadShape(numDims); |
893 | std::vector<int> newPads(2 * numDims); |
894 | for (size_t i = 0; i < outPadShape.size(); i++) { |
895 | newOutPadShape[shuffle[i]] = outPadShape[i]; |
896 | newPads[shuffle[i]] = pads[i]; |
897 | newPads[shuffle[i] + numDims] = pads[i + numDims]; |
898 | } |
899 | |
900 | // New pad |
901 | auto newOutPadType = |
902 | F->getParent()->uniqueTypeWithNewShape(outPadType, newOutPadShape); |
903 | auto *NewPadNode = F->createPad( |
904 | padNode->getName(), transposeNode->getInput(), newOutPadType, |
905 | padNode->getMode(), newPads, padNode->getValue()); |
906 | NewPadNode->setPredicate(node->getPredicate()); |
907 | auto *newTransposeNode = |
908 | F->createTranspose(transposeNode->getName(), NewPadNode, shuffle); |
909 | newTransposeNode->setPredicate(node->getPredicate()); |
910 | padNode->getResult().replaceAllUsesOfWith(newTransposeNode); |
911 | changed = true; |
912 | continue; |
913 | } |
914 | |
915 | // Sink Transpose below Tanh nodes. |
916 | if (auto *TN = dyn_cast<TanhNode>(node)) { |
917 | auto *TR = dyn_cast<TransposeNode>(TN->getInput()); |
918 | |
919 | if (!TR) { |
920 | continue; |
921 | } |
922 | |
923 | auto *NTN = F->createTanh(TN->getName(), TR->getInput()); |
924 | NTN->setPredicate(node->getPredicate()); |
925 | auto *newTR = F->createTranspose(TR->getName(), NTN, TR->getShuffle(), |
926 | TR->getLayout()); |
927 | newTR->setPredicate(node->getPredicate()); |
928 | TN->getResult().replaceAllUsesOfWith(newTR); |
929 | changed = true; |
930 | continue; |
931 | } |
932 | |
933 | // Remove 'identity' transpose operations. |
934 | if (auto *TR = dyn_cast<TransposeNode>(node)) { |
935 | auto mask = TR->getShuffle(); |
936 | |
937 | if (isIdentityShuffle(mask)) { |
938 | TR->getResult().replaceAllUsesOfWith(TR->getInput()); |
939 | changed = true; |
940 | continue; |
941 | } |
942 | } |
943 | |
944 | // Merge consecutive Transpose operations. |
945 | if (auto *TR1 = dyn_cast<TransposeNode>(node)) { |
946 | auto *TR2 = dyn_cast<TransposeNode>(TR1->getInput()); |
947 | |
948 | if (!TR2) { |
949 | continue; |
950 | } |
951 | |
952 | auto mask1 = TR1->getShuffle(); |
953 | auto mask2 = TR2->getShuffle(); |
954 | assert(mask1.size() == mask2.size() && "Invalid mask size" ); |
955 | |
956 | llvm::SmallVector<unsigned_t, max_tensor_dimensions> newMask; |
957 | newMask.resize(mask2.size()); |
958 | |
959 | for (size_t i = 0, end = mask2.size(); i < end; i++) { |
960 | newMask[i] = mask2[mask1[i]]; |
961 | } |
962 | |
963 | auto *newTR = F->createTranspose("tranpose" , TR2->getInput(), newMask); |
964 | TR1->getResult().replaceAllUsesOfWith(newTR->getResult()); |
965 | changed = true; |
966 | continue; |
967 | } |
968 | |
969 | if (auto *CS = dyn_cast<ChannelShuffleNode>(node)) { |
970 | // Sink Transpose below ChannelShuffle. |
971 | if (sinkTranposeBelowChannelShuffle(F, CS)) { |
972 | changed = true; |
973 | continue; |
974 | } |
975 | } |
976 | |
977 | // Sink Transpose below Arithmetic nodes. |
978 | if (node->isArithmetic()) { |
979 | TransposeNode *LTR = |
980 | dyn_cast<TransposeNode>(node->getNthInput(ArithmeticNode::LHSIdx)); |
981 | TransposeNode *RTR = |
982 | dyn_cast<TransposeNode>(node->getNthInput(ArithmeticNode::RHSIdx)); |
983 | |
984 | if (!LTR || !RTR) { |
        // If one of the sides is a splat, it can be seen as
        // transpose (splat'). Similarly, if one of the sides is a Constant,
        // it can be seen as transpose (Constant').
988 | if (isa<SplatNode>(node->getNthInput(ArithmeticNode::LHSIdx)) && RTR) { |
989 | // Build splat' for LHS. |
990 | auto *SN = |
991 | dyn_cast<SplatNode>(node->getNthInput(ArithmeticNode::LHSIdx)); |
992 | auto *NS = F->createSplat("splat" , RTR->getInput().getType(), |
993 | SN->getValue()); |
994 | LTR = F->createTranspose("transpose" , NS, RTR->getShuffle(), |
995 | RTR->getLayout()); |
996 | changed = true; |
997 | } else if (isa<SplatNode>(node->getNthInput(ArithmeticNode::RHSIdx)) && |
998 | LTR) { |
999 | // Build splat' for RHS. |
1000 | auto *SN = |
1001 | dyn_cast<SplatNode>(node->getNthInput(ArithmeticNode::RHSIdx)); |
1002 | auto *NS = F->createSplat("splat" , LTR->getInput().getType(), |
1003 | SN->getValue()); |
1004 | RTR = F->createTranspose("transpose" , NS, LTR->getShuffle(), |
1005 | LTR->getLayout()); |
1006 | changed = true; |
1007 | } else if (isa<Constant>(node->getNthInput(ArithmeticNode::LHSIdx)) && |
1008 | RTR) { |
          // Build Constant' for LHS.
1010 | auto *C = cast<Constant>(node->getNthInput(ArithmeticNode::LHSIdx)); |
1011 | LTR = insertMatchingTransposeAfterConstant(F, C, RTR); |
1012 | changed = true; |
1013 | } else if (isa<Constant>(node->getNthInput(ArithmeticNode::RHSIdx)) && |
1014 | LTR) { |
          // Build Constant' for RHS.
1016 | auto *C = cast<Constant>(node->getNthInput(ArithmeticNode::RHSIdx)); |
1017 | RTR = insertMatchingTransposeAfterConstant(F, C, LTR); |
1018 | changed = true; |
1019 | } else { |
1020 | continue; |
1021 | } |
1022 | } |
      // The masks of the transposes on both sides must match.
1024 | if (LTR->getShuffle() != RTR->getShuffle()) { |
1025 | continue; |
1026 | } |
1027 | |
1028 | Node *newAN = nullptr; |
1029 | |
1030 | #define ARITHMETIC_CASE(NODE_NAME_) \ |
1031 | case glow::Kinded::Kind::NODE_NAME_##NodeKind: \ |
1032 | newAN = \ |
1033 | F->create##NODE_NAME_(node->getName(), \ |
1034 | F->getParent()->uniqueTypeWithNewShape( \ |
1035 | node->getType(ArithmeticNode::ResultIdx), \ |
1036 | LTR->getInput().getType()), \ |
1037 | LTR->getInput(), RTR->getInput()); \ |
1038 | break; |
1039 | |
1040 | #define BOOLEAN_OP_CASE(NODE_NAME_) \ |
1041 | case glow::Kinded::Kind::NODE_NAME_##NodeKind: \ |
1042 | newAN = F->create##NODE_NAME_(node->getName(), LTR->getInput(), \ |
1043 | RTR->getInput()); \ |
1044 | break; |
1045 | |
1046 | switch (node->getKind()) { |
1047 | ARITHMETIC_CASE(Add); |
1048 | ARITHMETIC_CASE(Mul); |
1049 | ARITHMETIC_CASE(Sub); |
1050 | ARITHMETIC_CASE(Div); |
1051 | ARITHMETIC_CASE(Fmod); |
1052 | ARITHMETIC_CASE(Max); |
1053 | ARITHMETIC_CASE(Min); |
1054 | ARITHMETIC_CASE(Pow); |
1055 | BOOLEAN_OP_CASE(CmpLTE); |
1056 | BOOLEAN_OP_CASE(CmpEQ); |
1057 | default: |
1058 | llvm_unreachable("Unhandled node" ); |
1059 | } |
1060 | #undef BOOLEAN_OP_CASE |
1061 | #undef ARITHMETIC_CASE |
1062 | |
1063 | newAN->setPredicate(node->getPredicate()); |
1064 | changed = true; |
1065 | auto *newTR = F->createTranspose(LTR->getName(), newAN, LTR->getShuffle(), |
1066 | LTR->getLayout()); |
1067 | newTR->setPredicate(node->getPredicate()); |
1068 | node->getNthResult(ArithmeticNode::ResultIdx).replaceAllUsesOfWith(newTR); |
1069 | } |
1070 | |
1071 | if (auto *Q = dyn_cast<QuantizeNode>(node)) { |
1072 | // Sink TransposeNode below QuantizedNode. |
1073 | if (auto *TR = getTransposeNodeWithAllSameUserKind(Q->getInput())) { |
1074 | auto newQType = F->getParent()->uniqueTypeWithNewShape( |
1075 | Q->getResult().getType(), TR->getInput().dims()); |
1076 | auto *newQ = F->createQuantize(Q->getName(), TR->getInput(), newQType); |
1077 | auto *newTR = F->createTranspose(TR->getName(), newQ, TR->getShuffle()); |
1078 | Q->getResult().replaceAllUsesOfWith(newTR); |
1079 | changed = true; |
1080 | continue; |
1081 | } |
1082 | |
1083 | // Sink Reshape below Quantize. |
1084 | if (auto *RN = dyn_cast<ReshapeNode>(Q->getInput())) { |
1085 | auto newQType = F->getParent()->uniqueTypeWithNewShape( |
1086 | Q->getResult().getType(), RN->getInput().dims()); |
1087 | auto *newQ = F->createQuantize(Q->getName(), RN->getInput(), newQType); |
1088 | auto *newRN = F->createReshape(RN->getName(), newQ, |
1089 | RN->getResult().dims(), RN->getLayout()); |
1090 | Q->getResult().replaceAllUsesOfWith(newRN->getResult()); |
1091 | changed = true; |
1092 | continue; |
1093 | } |
1094 | } |
1095 | |
1096 | // Sink Reshape below ConvertTo. |
1097 | if (auto *CN = dyn_cast<ConvertToNode>(node)) { |
1098 | auto *RN = dyn_cast<ReshapeNode>(CN->getInput()); |
1099 | if (!RN) { |
1100 | continue; |
1101 | } |
1102 | auto *newCN = F->createConvertTo(CN->getName(), RN->getInput(), |
1103 | CN->getResult().getElementType()); |
1104 | auto *newRN = F->createReshape(RN->getName(), newCN, |
1105 | RN->getResult().dims(), RN->getLayout()); |
1106 | CN->getResult().replaceAllUsesOfWith(newRN->getResult()); |
1107 | changed = true; |
1108 | continue; |
1109 | } |
1110 | |
1111 | // Sink TransposeNode below DequantizedNode. |
1112 | // If it doesn't work out it will be re-sinked later. |
1113 | if (auto *D = dyn_cast<DequantizeNode>(node)) { |
1114 | auto *TR = dyn_cast<TransposeNode>(D->getInput()); |
1115 | if (!TR) { |
1116 | continue; |
1117 | } |
1118 | |
1119 | auto newDType = F->getParent()->uniqueTypeWithNewShape( |
1120 | D->getResult().getType(), TR->getInput().dims()); |
1121 | auto *newD = F->createDequantize(D->getName(), TR->getInput(), newDType); |
1122 | auto *newTR = F->createTranspose(TR->getName(), newD, TR->getShuffle()); |
1123 | D->getResult().replaceAllUsesOfWith(newTR); |
1124 | changed = true; |
1125 | } |
1126 | |
1127 | // Sink Transpose below RescaleQuantized. |
1128 | // Potentially exposes opportunity to be combined up with Convolution. |
1129 | // If it doesn't work out it will be re-sinked later. |
1130 | if (auto *RQ = dyn_cast<RescaleQuantizedNode>(node)) { |
1131 | auto *TR = dyn_cast<TransposeNode>(RQ->getInput()); |
1132 | if (!TR) { |
1133 | continue; |
1134 | } |
1135 | |
1136 | auto newRQType = F->getParent()->uniqueTypeWithNewShape( |
1137 | RQ->getResult().getType(), TR->getInput().getType()); |
1138 | auto *newRQ = |
1139 | F->createRescaleQuantized(RQ->getName(), TR->getInput(), newRQType); |
1140 | auto *newTR = F->createTranspose(TR->getName(), newRQ, TR->getShuffle(), |
1141 | TR->getLayout()); |
1142 | RQ->getResult().replaceAllUsesOfWith(newTR); |
1143 | changed = true; |
1144 | } |
1145 | |
1146 | if (auto *CN = dyn_cast<ConcatNode>(node)) { |
1147 | const Node *firstNode = CN->getInputs().front().getNode(); |
1148 | // Sink RELU below batch concat nodes. |
1149 | if (firstNode->getKind() == Kinded::Kind::ReluNodeKind) { |
1150 | llvm::SmallVector<NodeValue, 6> CNInputs; |
1151 | for (auto &input : CN->getInputs()) { |
1152 | auto *inputRL = dyn_cast<ReluNode>(input); |
1153 | if (!inputRL) { |
1154 | break; |
1155 | } |
1156 | CNInputs.push_back(inputRL->getInput()); |
1157 | } |
1158 | |
1159 | if (CNInputs.size() == CN->getNumInputs()) { |
1160 | auto *newCN = F->createConcat(CN->getName(), CNInputs, CN->getDim()); |
1161 | newCN->setPredicate(node->getPredicate()); |
1162 | auto name = CN->getNthInput(0).getNode()->getName(); |
1163 | auto *newRL = F->createRELU(name, newCN, CN->getResult().getType()); |
1164 | newRL->setPredicate(node->getPredicate()); |
1165 | CN->getResult().replaceAllUsesOfWith(newRL); |
1166 | changed = true; |
1167 | } |
1168 | continue; |
1169 | } |
1170 | |
1171 | // Sink Transpose below concat nodes. |
1172 | if (firstNode->getKind() == Kinded::Kind::TransposeNodeKind) { |
1173 | llvm::SmallVector<NodeValue, 6> transVector; |
1174 | auto inputIter = CN->getInputs().begin(); |
1175 | auto *firstInput = dyn_cast<TransposeNode>(*inputIter); |
1176 | if (!firstInput) { |
1177 | continue; |
1178 | } |
1179 | |
1180 | transVector.push_back(firstInput->getInput()); |
1181 | auto shuffle = firstInput->getShuffle(); |
1182 | // If the shuffle masks don't agree or not all inputs are Transpose then |
1183 | // bail out. |
1184 | for (++inputIter; inputIter != CN->getInputs().end(); ++inputIter) { |
1185 | auto *tTR = dyn_cast<TransposeNode>(*inputIter); |
1186 | if (!tTR || tTR->getShuffle() != shuffle) { |
1187 | break; |
1188 | } |
1189 | transVector.push_back(tTR->getInput()); |
1190 | } |
1191 | |
1192 | if (transVector.size() != CN->getNumInputs()) { |
1193 | continue; |
1194 | } |
1195 | |
1196 | // Figure out where we transposed the channel index for batch |
1197 | // normalization. |
1198 | unsigned_t idx = CN->getDim(); |
1199 | unsigned_t newChannelIdx = shuffle[idx]; |
1200 | |
1201 | auto *newCN = |
1202 | F->createConcat(CN->getName(), transVector, newChannelIdx); |
1203 | newCN->setPredicate(node->getPredicate()); |
1204 | auto *newTR = F->createTranspose(firstInput->getName(), newCN, |
1205 | firstInput->getShuffle(), |
1206 | firstInput->getLayout()); |
1207 | newTR->setPredicate(node->getPredicate()); |
1208 | CN->getResult().replaceAllUsesOfWith(newTR); |
1209 | changed = true; |
1210 | continue; |
1211 | } |
1212 | } |
1213 | } // For all nodes in the graph. |
1214 | |
1215 | // Transformations to sink nodes below Slice. Outlined into a separate loop to |
1216 | // prevent Transpose/Slice sinking to affect them. |
1217 | for (auto &N : nodes) { |
1218 | auto *node = &N; |
1219 | // Sink BatchNorm below Slice. |
1220 | if (auto *SN = dyn_cast<SliceNode>(node)) { |
1221 | auto *BN = dyn_cast<BatchNormalizationNode>(SN->getInput()); |
1222 | if (!BN || !BN->hasOneUse()) { |
1223 | continue; |
1224 | } |
1225 | |
1226 | // Don't support sinking below Slice which affects depth. |
1227 | if (SN->getInput().dims()[BN->getChannelIdx()] != |
1228 | SN->getResult().dims()[BN->getChannelIdx()]) { |
1229 | continue; |
1230 | } |
1231 | |
1232 | auto newSNType = F->getParent()->uniqueTypeWithNewShape( |
1233 | BN->getInput().getType(), SN->getResult().dims()); |
1234 | auto *newSN = F->createSlice(SN->getName(), BN->getInput(), |
1235 | SN->getStart(), newSNType); |
1236 | auto *newBN = F->createBatchNormalization( |
1237 | BN->getName(), SN->getResult().getType(), newSN, BN->getBias(), |
1238 | BN->getScale(), BN->getMean(), BN->getVar(), BN->getChannelIdx(), |
1239 | BN->getEpsilon(), BN->getMomentum()); |
1240 | SN->getResult().replaceAllUsesOfWith(newBN); |
1241 | changed = true; |
1242 | } |
1243 | } |
1244 | |
1245 | return changed; |
1246 | } |
1247 | |
1248 | /// Code Hoisting. |
1249 | bool HoistCode::run(Function *F, const CompilationContext &cctx) { |
1250 | LOG_SCOPE(F->getLogContext(), getName()); |
1251 | bool changed = false; |
1252 | auto &nodes = F->getNodes(); |
1253 | // For each node: |
1254 | for (auto &N : nodes) { |
1255 | auto *node = &N; |
1256 | |
1257 | // Hoist Transpose above Tile nodes. |
1258 | if (auto *TR = dyn_cast<TransposeNode>(node)) { |
1259 | auto *TN = dyn_cast<TileNode>(TR->getInput()); |
1260 | |
1261 | if (!TN) { |
1262 | continue; |
1263 | } |
1264 | |
1265 | auto *newTR = F->createTranspose(TR->getName(), TN->getInput(), |
1266 | TR->getShuffle(), TR->getLayout()); |
1267 | newTR->setPredicate(node->getPredicate()); |
1268 | auto *newTN = |
1269 | F->createTile(TN->getName(), newTR, TN->getCount(), |
1270 | invertShuffle(TR->getShuffle())[TN->getAxis()]); |
1271 | newTN->setPredicate(node->getPredicate()); |
1272 | TR->getResult().replaceAllUsesOfWith(newTN); |
1273 | changed = true; |
1274 | continue; |
1275 | } |
1276 | } |
1277 | |
1278 | return changed; |
1279 | } |
1280 | |
1281 | /// Reshape Sinking. |
1282 | bool SinkReshapes::run(Function *F, const CompilationContext &cctx) { |
1283 | LOG_SCOPE(F->getLogContext(), getName()); |
1284 | bool changed = false; |
1285 | auto &nodes = F->getNodes(); |
1286 | // For each node: |
1287 | for (auto &N : nodes) { |
1288 | auto *node = &N; |
1289 | |
1290 | // Sink Reshape below eltwise nodes. |
1291 | if (!node->isDataParallel() || node->hasSideEffects()) { |
1292 | continue; |
1293 | } |
1294 | |
1295 | // Unary eltwise nodes. |
1296 | if (node->getNumInputs() != 1 || node->getNumResults() != 1) { |
1297 | continue; |
1298 | } |
1299 | |
1300 | auto *RS = dyn_cast<ReshapeNode>(node->getNthInput(0)); |
1301 | if (!RS) { |
1302 | continue; |
1303 | } |
1304 | |
1305 | // Create new eltwise node. |
1306 | auto in = RS->getInput(); |
1307 | auto out = node->getNthResult(0); |
1308 | auto newTy = |
1309 | F->getParent()->uniqueTypeWithNewShape(out.getType(), in.dims()); |
1310 | auto *newN = F->addNode(node->clone()); |
1311 | newN->setNthInput(0, in); |
1312 | newN->setTypeUnsafe(0, newTy); |
1313 | newN->setPredicate(node->getPredicate()); |
1314 | |
1315 | // Create new Reshape. |
1316 | auto *newRS = F->createReshape(RS->getName(), newN, |
1317 | RS->getResult().getType()->dims()); |
1318 | newRS->setPredicate(node->getPredicate()); |
1319 | out.replaceAllUsesOfWith(newRS->getResult()); |
1320 | |
1321 | changed = true; |
1322 | } |
1323 | return changed; |
1324 | } |
1325 | |
/// Remove unnecessary padding and reduce filters for Convolution nodes with
/// small input tensors. A Convolution is considered "small" when its spatial
/// output is 1x1; in that case the stride can often be canonicalized to 1 and
/// the filter regions that only overlap the padded border can be sliced away.
bool OptimizeSmallConv::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  for (auto &N : F->getNodes()) {
    auto *CN = dyn_cast<ConvolutionNode>(&N);
    if (!CN) {
      continue;
    }

    // Consider a Convolution "small", if its output is 1x1.
    // The transformation doesn't support dilation.
    auto dilation = CN->getDilation();
    ShapeNHWC odim(CN->getResult().dims());
    if (odim.h > 1 || odim.w > 1 || !isUniformArray(dilation, 1u)) {
      continue;
    }

    // Dealing with stride=1 Convolution nodes is generally easier, so try
    // to canonicalize stride to 1 if possible. This is legal when the output
    // with stride {1, 1} (same kernels/pads) would still be 1x1, i.e. a single
    // filter application already covers the whole input.
    ShapeNHWC idim(CN->getInput().dims());
    std::vector<unsigned_t> strides(CN->getStrides());
    std::vector<unsigned_t> kernels(CN->getKernels());
    std::vector<unsigned_t> pads(CN->getPads());
    auto newOutHW = calculateConvPoolOutputDims(idim.h, idim.w, kernels, {1, 1},
                                                pads, dilation);
    if (newOutHW.first == 1 && newOutHW.second == 1) {
      strides = {1, 1};
    }

    // Slice off redundant filter parts. With a 1x1 output and stride 1, the
    // filter rows/columns that only ever overlap the padded border can be
    // dropped from the (constant) weights, the kernel shrunk to the input
    // size, and the padding removed entirely.
    auto filters = CN->getFilter();
    auto *C = dyn_cast<Constant>(filters);
    if (C && isUniformArray(llvm::makeArrayRef(strides), 1u) &&
        !isUniformArray(llvm::makeArrayRef(pads), 0u)) {
      ShapeNHWC fdim(filters.dims());
      PaddingTLBR p(llvm::makeArrayRef(pads));
      // Keep only the filter window that overlaps real (unpadded) input.
      dim_t start[] = {0u, p.top, p.left, 0u};
      dim_t end[] = {fdim.n, fdim.h - p.bottom, fdim.w - p.right, fdim.c};
      auto *SN = F->createSlice(C->getName(), C, start, end);
      filters = SN->getResult();
      kernels = {unsigned_t(idim.h), unsigned_t(idim.w)};
      pads = {0, 0, 0, 0};
    }

    // Check if this node needs any changes.
    if (filters == CN->getFilter() &&
        llvm::makeArrayRef(strides) == CN->getStrides()) {
      continue;
    }

    // Create a replacement Convolution with the adjusted filters, kernels,
    // strides and pads; the output type is unchanged.
    auto *newCN =
        F->createConv(CN->getName(), CN->getInput(), filters, CN->getBias(),
                      CN->getResult().getType(), kernels, strides, pads,
                      CN->getGroup(), dilation);

    CN->getResult().replaceAllUsesOfWith(newCN->getResult());
    changed = true;
  }

  return changed;
}
1389 | |
/// Fold a Convolution dilated manually using Transpose, SpaceToDepth and
/// DepthToSpace nodes into a single Convolution node. Pattern:
/// NHWC2CHWN -> S2D -> CHWN2NHWC -> Conv -> NHWC2CHWN -> D2S -> CHWN2NHWC
/// The whole chain is replaced by one Convolution whose dilation equals the
/// SpaceToDepth/DepthToSpace block size.
bool FoldDilatedConv::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  for (auto &N : F->getNodes()) {
    // Do matching starting from the last node of the pattern and go backwards.

    // 1. Transpose CHWN2NHWC.
    // Do not match this Transpose, since it has likely sank and got separated
    // from the rest of the nodes in the pattern. Instead, we'll generate a
    // reverse Transpose after new Convolution for them to be optimized out.

    // 2. DepthToSpace (represented as Reshape + Transpose[6 dims] + Reshape).
    // Ignore the last Reshape for the same reasons as Transpose above and start
    // matching from 6 dim Transpose.
    auto *D2S = dyn_cast<TransposeNode>(&N);
    if (!D2S || !D2S->hasOneUse() ||
        D2S->getShuffle() != llvm::makeArrayRef(D2S_DCR)) {
      continue;
    }
    // The DepthToSpace block size (and hence the folded dilation factor) is
    // the 4th dimension of the 6-dim transpose input.
    unsigned_t block = D2S->getInput().dims()[3];
    NodeValue output = D2S->getResult();

    auto *RN = dyn_cast<ReshapeNode>(D2S->getInput());
    if (!RN || !RN->hasOneUse()) {
      continue;
    }
    // The Reshape must split the last dimension into (block, block, rest).
    llvm::ArrayRef<dim_t> idim = RN->getInput().dims();
    llvm::SmallVector<dim_t, 6> odim = {
        idim[0], idim[1], idim[2], block, block, idim[3] / (block * block)};
    if (RN->getResult().dims() != llvm::makeArrayRef(odim)) {
      continue;
    }

    // 3. Transpose NHWC2CHWN.
    auto *T1 = dyn_cast<TransposeNode>(RN->getInput());
    if (!T1 || !T1->hasOneUse() ||
        T1->getShuffle() != llvm::makeArrayRef(NHWC2CHWN)) {
      continue;
    }

    // 4. Convolution.
    // Extract parameters uniformly from either a float Convolution or a
    // channelwise-quantized one; only one of CN/CQCN will be non-null.
    llvm::StringRef name;
    llvm::ArrayRef<unsigned_t> kernels, strides, pads, dilation;
    NodeValue convInput, convResult;
    auto getConvParams = [&](auto *N) -> bool {
      if (!N || !N->hasOneUse()) {
        return false;
      }
      name = N->getName();
      kernels = N->getKernels();
      strides = N->getStrides();
      pads = N->getPads();
      dilation = N->getDilation();
      convInput = N->getInput();
      convResult = N->getResult();
      return true;
    };
    auto *CN = dyn_cast<ConvolutionNode>(T1->getInput());
    auto *CQCN = dyn_cast<ChannelwiseQuantizedConvolutionNode>(T1->getInput());
    if (!getConvParams(CN) && !getConvParams(CQCN)) {
      continue;
    }
    // Only fold a conv that is itself undilated, unit-stride and unpadded.
    if (!isUniformArray(strides, 1u) || !isUniformArray(pads, 0u) ||
        !isUniformArray(dilation, 1u)) {
      continue;
    }

    // 5. Transpose CHWN2NHWC.
    auto *T2 = dyn_cast<TransposeNode>(convInput);
    if (!T2 || T2->getShuffle() != llvm::makeArrayRef(CHWN2NHWC)) {
      continue;
    }

    // 6. SpaceToDepth. Its block size must match the DepthToSpace one.
    auto *S2D = dyn_cast<SpaceToDepthNode>(T2->getInput());
    if (!S2D || S2D->getBlockSize() != block) {
      continue;
    }
    NodeValue input = S2D->getInput();

    // 7. Transpose NHWC2CHWN.
    // Can potentially be changed/removed by merging with other Transpose nodes,
    // so don't match it. Instead, will generate a reverse Transpose later.

    // Create CHWN2NHWC -> Conv -> NHWC2CHWN -> Reshape.
    // Reshape and Transposes are created to cancel the ones we did not match.
    auto *newT1 =
        F->createTranspose(name.str() + "_chwn2nhwc", input, CHWN2NHWC);

    // The new Convolution uses dilation {block, block} in place of the
    // SpaceToDepth/DepthToSpace trick.
    auto trOutDims = newT1->getResult().dims();
    auto outHW = calculateConvPoolOutputDims(
        trOutDims[1], trOutDims[2], kernels, strides, pads, {block, block});
    auto convOutTy = F->getParent()->uniqueTypeWithNewShape(
        convResult.getType(),
        {trOutDims[0], outHW.first, outHW.second, convResult.dims()[3]});
    Node *newCN;
    if (CN) {
      newCN = F->createConv(name, newT1->getResult(), CN->getFilter(),
                            CN->getBias(), convOutTy, kernels, strides, pads,
                            CN->getGroup(), {block, block});
    } else if (CQCN) {
      newCN = F->createChannelwiseQuantizedConv(
          name, newT1->getResult(), CQCN->getFilter(), CQCN->getBias(),
          CQCN->getFilterScales(), CQCN->getFilterOffsets(),
          CQCN->getBiasScales(), CQCN->getBiasOffsets(), convOutTy, kernels,
          strides, pads, CQCN->getGroup(), {block, block}, false, false);
    } else {
      llvm_unreachable("Convolution must be in the pattern");
    }

    auto *newT2 =
        F->createTranspose(name.str() + "_nhwc2chwn", newCN, NHWC2CHWN);

    // Reverse Reshape to cancel against the unmatched DepthToSpace tail.
    idim = newT2->getResult().dims();
    odim = {idim[0], idim[1] / block, block, idim[2] / block, block, idim[3]};
    auto *newRN =
        F->createReshape(name.str() + "_reshape", newT2->getResult(), odim);

    output.replaceAllUsesOfWith(newRN->getResult());
    changed = true;
  }

  return changed;
}
1517 | |
1518 | /// \returns True if node A may depend on the result of B. The relationship |
1519 | /// between the nodes does not have to be direct. For example, A can depend on |
1520 | /// X which depends on B. In that case the method needs to return True. |
1521 | /// Check the use-def dependency up to a depth of \p depth. |
1522 | static bool mayDepend(Node *A, Node *B, unsigned depth = 6) { |
1523 | // We define the identify as a dependency. |
1524 | if (A == B) { |
1525 | return true; |
1526 | } |
1527 | |
1528 | // A does not depend on anything. |
1529 | if (A->getNumInputs() == 0) { |
1530 | return false; |
1531 | } |
1532 | |
1533 | // B has no users. Nothing can depend on it. |
1534 | if (B->getNumResults() == 0) { |
1535 | return false; |
1536 | } |
1537 | |
1538 | // We can't continue the search. Assume that the nodes depend on one another. |
1539 | if (depth == 0) { |
1540 | return true; |
1541 | } |
1542 | |
1543 | // Check all inputs of A. None of them may depend on B. |
1544 | for (int i = 0, e = A->getNumInputs(); i < e; i++) { |
1545 | auto *input = A->getNthInput(i).getNode(); |
1546 | // The inputs of A must not depend on B. |
1547 | if (mayDepend(input, B, depth - 1)) { |
1548 | return true; |
1549 | } |
1550 | } |
1551 | |
1552 | // We checked all inputs of A and none of them depend on B. |
1553 | return false; |
1554 | } |
1555 | |
1556 | /// \returns True if the node \p N depends on any of the values in \p list, or |
1557 | /// if any of the values in list depend on \p N. |
1558 | static bool mayDependOnAny(llvm::ArrayRef<NodeValue> list, Node *N) { |
1559 | for (auto &ll : list) { |
1560 | if (mayDepend(ll.getNode(), N) || mayDepend(N, ll.getNode())) { |
1561 | return true; |
1562 | } |
1563 | } |
1564 | |
1565 | return false; |
1566 | } |
1567 | |
/// Helper function to merge matmuls in \p F.
/// If \p mergeOnLHS is true, it merges LHS operands of matmuls that have the
/// same RHS. If \p mergeOnLHS is false, it merges RHS operands of matmuls that
/// have the same LHS. \returns true if any matmuls are merged in the Function
/// \p F.
static bool mergeMatMuls(Function *F, bool mergeOnLHS) {
  bool changed = false;
  auto &nodes = F->getNodes();

  // A map to record the list of matrix multipliers that use each node
  // value either as a right-hand-side user or a left-hand-side user.
  // NOTE(review): keys are Node*, which implicitly assumes the shared operand
  // is the node's first result -- confirm for multi-output producers.
  llvm::DenseMap<Node *, std::vector<MatMulNode *>> matrixUsers;

  // Collect the list of nodes that are used by the matrix multiplier.
  for (auto &node : nodes) {
    if (auto *MM = dyn_cast<MatMulNode>(&node)) {
      // Do not try to merge quantized matrix multiplications because their
      // quantized parameters may not match. Until we implement the logic to
      // match the scale and offset just avoid the optimization.
      if (MM->getResult().getType()->isQuantizedType()) {
        continue;
      }

      // Group each matmul by the operand that must be shared: the LHS when
      // merging RHS operands, and the RHS when merging LHS operands.
      if (!mergeOnLHS) {
        matrixUsers[MM->getLHS().getNode()].push_back(MM);
      } else {
        matrixUsers[MM->getRHS().getNode()].push_back(MM);
      }
    }
  }

  // Merge matrices.
  for (auto &it : matrixUsers) {
    auto &MMs = it.second;

    // Collects the LHS or RHS values to merge.
    std::vector<NodeValue> lhsOrRhs;

    // For each matmul that depends on the matrix.
    std::unordered_set<MatMulNode *> skippedMMs;
    std::string firstMMName;
    std::string firstMatrixName;
    for (auto *MM : MMs) {
      auto I = mergeOnLHS ? MM->getLHS() : MM->getRHS();
      // The operands to the matrix multiplier should not depend on one another
      // or else we won't be able to get rid of the original matrix
      // multiplication.
      if (mayDependOnAny(lhsOrRhs, I.getNode())) {
        skippedMMs.insert(MM);
        continue;
      }
      lhsOrRhs.push_back(I);
      // Remember the first merged matmul/operand names; they name the new
      // Concat and MatMul nodes below.
      if (firstMMName.empty()) {
        firstMMName = MM->getName().str();
      }
      if (firstMatrixName.empty()) {
        firstMatrixName = I.getNode()->getName().str();
      }
    }

    // We need to have at least two matrices to merge.
    if (lhsOrRhs.size() < 2) {
      continue;
    }

    // Merge the matmul: concatenate the collected operands along the
    // non-shared dimension and multiply once against the shared matrix.
    auto *CC = F->createConcat(firstMatrixName + "_merge", lhsOrRhs,
                               mergeOnLHS ? 0 : 1);
    auto *MM =
        F->createMatMul(firstMMName + "_bigMatMul", mergeOnLHS ? CC : it.first,
                        mergeOnLHS ? it.first : CC);

    // Slice the output so that other nodes can consume each slice separately.
    // 'start' tracks the running offset of each original matmul's rows (or
    // columns) inside the big result.
    dim_t O = MM->getResult().dims()[mergeOnLHS ? 1 : 0];
    dim_t start = 0;
    for (auto *origMM : MMs) {
      if (skippedMMs.count(origMM)) {
        continue;
      }
      dim_t H = origMM->getResult().dims()[mergeOnLHS ? 0 : 1];
      auto startIndices = mergeOnLHS ? std::array<dim_t, 2>({start, 0})
                                     : std::array<dim_t, 2>({0, start});
      auto endIndices = mergeOnLHS ? std::array<dim_t, 2>({start + H, O})
                                   : std::array<dim_t, 2>({O, start + H});
      auto *ex = F->createSlice(origMM->getName().str() + "_extract", MM,
                                startIndices, endIndices);
      start += H;
      origMM->getResult().replaceAllUsesOfWith(ex);
      changed = true;
    }
  }
  return changed;
}
1661 | |
// Merge two or more matrix multiplications that share the same RHS into a
// single large matmul by concatenating their LHS operands. The large matmul
// is more likely to utilize the hardware. The result of the big matmul is
// then sliced back into the original per-matmul results.
//
// ____ _________ _________
// ---- | | | | M| A * C |
// M| A | T| B | * K| C | = |---------|
// ---- , | | | | T| B * C |
// K ---- --------- ---------
// K R R
bool MergeMatMulOnLHS::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  return mergeMatMuls(F, /* mergeOnLHS */ true);
}
1676 | |
// Merge two or more matrix multiplications that share the same LHS into a
// single large matmul by concatenating their RHS operands. The large matmul
// is more likely to utilize the hardware. The result of the big matmul is
// then sliced back into the original per-matmul results.
//
// ____ _________ _________
// ---- | | | | | A | A |
// M| A | * K| B | , K| C | = M| * | * |
// ---- | | | | | B | C |
// K ---- --------- ---------
// R S R S
bool MergeMatMulOnRHS::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  return mergeMatMuls(F, /* mergeOnLHS */ false);
}
1691 | |
/// Fold an explicit Pad node into a following Convolution by adding the pad
/// amounts to the convolution's own 'pads' attribute:
///   Conv(Pad(X)) -> Conv(X) with adjusted padding.
/// Applies only when the Pad is a zero-valued CONSTANT pad restricted to the
/// spatial (H/W) dimensions, the Pad has a single user, and the combined
/// padding stays non-negative (Conv supports positive padding only).
bool MergePadIntoConvolution::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  for (auto &node : F->getNodes()) {
    auto *CN = dyn_cast<ConvolutionNode>(&node);
    if (!CN) {
      continue;
    }

    auto *PN = dyn_cast<PadNode>(CN->getInput());
    if (!PN) {
      continue;
    }

    // Convolution only supports padding with 0 constant
    if ((PN->getMode() != PaddingMode::CONSTANT) || (PN->getValue() != 0.f)) {
      continue;
    }

    // The convolution needs to be the unique user
    if (!PN->hasOneUse()) {
      continue;
    }

    // Compute the new padding.
    // Note: - convolution only supports positive padding
    //       - the convolution takes NHWC input tensors.
    bool canMerge = true;
    auto padPads = PN->getPads();
    auto convPads = CN->getPads();

    // For now, there is a different interpretation of the ONNX spec for
    // Pad and Convolution. The 'pads' array won't have the same size because
    // only spatial dimensions are specified for the convolution while all
    // dimensions are specified for Pad.

    // The merge can apply only if no padding is requested for non spatial
    // dimensions. Pad stores all begin paddings first, then all end paddings;
    // for a 4D NHWC tensor the non-spatial entries are N (indices 0 and 4)
    // and C (indices 3 and 7).
    if ((padPads[0] != 0) || (padPads[3] != 0) || (padPads[4] != 0) ||
        (padPads[7] != 0)) {
      continue;
    }

    // Compute new spatial padding.
    const int H_INDEX = 1;
    std::vector<unsigned_t> newConvPads(4);
    auto numDims = PN->getResult().dims().size();
    for (size_t i = 0; i < 2; i++) {
      // Two pad integers per dimension (begin and end padding).
      for (size_t j = 0; j < 2; j++) {
        // Pad keeps begin/end entries 'numDims' apart; Conv keeps them
        // 2 apart per spatial dim ([beginH, beginW, endH, endW]).
        int newConvPadSigned =
            padPads[(i + H_INDEX) + j * numDims] + int(convPads[i + j * 2]);
        if (newConvPadSigned < 0) {
          // A negative combined pad cannot be expressed on the Conv.
          canMerge = false;
          break;
        }
        newConvPads[i + j * 2] = unsigned_t(newConvPadSigned);
      }
    }
    if (!canMerge) {
      continue;
    }

    // New Convolution taking the Pad's input directly, with the summed pads.
    auto *newCN = F->createConv(CN->getName(), PN->getInput(), CN->getFilter(),
                                CN->getBias(), CN->getResult().getType(),
                                CN->getKernels(), CN->getStrides(), newConvPads,
                                CN->getGroup(), CN->getDilation());
    newCN->setFusedActivation(CN->getFusedActivation());
    newCN->setFusedActivationArgs(CN->getFusedActivationArgs());

    CN->getResult().replaceAllUsesOfWith(newCN);
    changed = true;
  }

  return changed;
}
1769 | |
/// Merge Transpose into MatMul or FC.
/// MatMul/FC(Reshape(Transpose(X)), Weights) ->
/// -> MatMul/FC(Reshape(X), reordered Weights)
/// Common sequence while using NCHW as input layout, because GLOW convolution
/// layout is NHWC:
/// Transpose([N, H, W, C]) -> [N, C, H, W]
/// Reshape([N, C, H, W]) -> [N, C * H * W]
/// MatMul/FC([N, C * H * W], [C * H * W, K]) -> [N, K]
/// The Transpose is absorbed by permuting the constant weights once at
/// compile time instead of transposing the activation at run time.
bool MergeTransposeIntoMatMulOrFC::run(Function *F,
                                       const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  for (auto &node : F->getNodes()) {
    auto *MMN = dyn_cast<MatMulNode>(&node);
    auto *FCN = dyn_cast<FullyConnectedNode>(&node);
    Constant *W;
    ReshapeNode *RN;

    // Node is either MatMul or FC.
    if (MMN) {
      W = dyn_cast<Constant>(MMN->getRHS());
      RN = dyn_cast<ReshapeNode>(MMN->getLHS());
    } else if (FCN) {
      W = dyn_cast<Constant>(FCN->getWeights());
      RN = dyn_cast<ReshapeNode>(FCN->getInput());
    } else {
      continue;
    }

    // Weights node (or MatMul RHS) is constant.
    if (!W) {
      continue;
    }

    // Linearizing Reshape precedes MatMul/FC.
    // The first dimension must be kept unchanged, the others are linearized.
    if (!RN || !RN->hasOneUse() ||
        RN->getInput().dims()[0] != RN->getDims()[0]) {
      continue;
    }

    // Transpose precedes Reshape.
    // The first dimension must be kept unchanged, the others can be shuffled
    // in any way.
    auto *TN = dyn_cast<TransposeNode>(RN->getInput());
    if (!TN || !TN->hasOneUse() || TN->getShuffle()[0] != 0) {
      continue;
    }

    // MatMul/FC weights tensor is 2D. De-linearize the first dimension
    // according to Transpose output layout (original shape) and input layout
    // (reordered shape). Then we can do weights reordering by simply
    // transposing the tensor from original shape to reordered shape.
    //
    // Example for [N, H, W, C] -> [N, C, H, W] transpose (common case):
    // De-linearized original shape: [C * H * W, K] -> [C, H, W, K]
    // De-linearized reordered shape: [C * H * W, K] -> [H, W, C, K]
    // Reorder weights by transposing them: [C, H, W, K] -> [H, W, C, K]
    ShapeVector shape, newShape;
    llvm::SmallVector<unsigned_t, max_tensor_dimensions> shuffle;
    shuffle.resize(TN->getShuffle().size() - 1);
    for (size_t i = 1; i < TN->getShuffle().size(); i++) {
      shape.push_back(TN->getResult().getType()->dims()[i]);
      newShape.push_back(TN->getInput().getType()->dims()[i]);
      // Build the inverse permutation over the non-batch dims, shifted down
      // by one because dim 0 stays in place.
      shuffle[TN->getShuffle()[i] - 1] = i - 1;
    }
    // The weights' second (output) dimension is carried along unchanged.
    shape.push_back(W->dims()[1]);
    newShape.push_back(W->dims()[1]);
    shuffle.push_back(TN->getShuffle().size() - 1);
    auto reshapedWTy =
        F->getParent()->uniqueTypeWithNewShape(W->getType(), shape);
    auto reshapedNewWTy =
        F->getParent()->uniqueTypeWithNewShape(W->getType(), newShape);

    // New reordered weights.
    auto *newW = F->getParent()->createConstant(W->getType(), W->getName(),
                                                W->getLayout());
    // Unowned tensor views over the old/new payloads using the de-linearized
    // shapes; transposing between the views performs the actual reordering.
    Tensor reshapedSrc(W->getPayload().getUnsafePtr(), reshapedWTy);
    Tensor reshapedDst(newW->getPayload().getUnsafePtr(), reshapedNewWTy);
    reshapedSrc.transpose(&reshapedDst, shuffle);

    // New Reshape and MatMul/FC.
    auto *newRN =
        F->createReshape(RN->getName(), TN->getInput(), RN->getDims());
    if (MMN) {
      auto *newMMN = F->createMatMul(MMN->getName(), MMN->getResult().getType(),
                                     newRN, newW);
      MMN->getResult().replaceAllUsesOfWith(newMMN);
    } else if (FCN) {
      auto *newFCN =
          F->createFullyConnected(FCN->getName(), newRN, newW, FCN->getBias(),
                                  FCN->getResult().getType());
      FCN->getResult().replaceAllUsesOfWith(newFCN);
    } else {
      llvm_unreachable("Unexpected node kind" );
    }

    changed = true;
  }

  return changed;
}
1872 | |
1873 | /// \returns True if the two slices \p A and \p B access consecutive spacial |
1874 | /// regions on the \p dim dimension. For example Slice(0..10) Slice(10..50) |
1875 | /// are consecutive but Slice(0..10) Slice(20..30) are not. |
1876 | static bool areSlicesConsecutive(SliceNode *A, SliceNode *B, unsigned_t dim) { |
1877 | // The slices must extract from the same input. |
1878 | if (A->getInput() != B->getInput()) { |
1879 | return false; |
1880 | } |
1881 | |
1882 | auto aStart = A->getStart(); |
1883 | auto bStart = B->getStart(); |
1884 | |
1885 | assert(aStart.size() > dim && "Invalid dimension" ); |
1886 | |
1887 | for (size_t i = 0, e = aStart.size(); i < e; i++) { |
1888 | if (i == dim) { |
1889 | auto resSize = A->getResult().dims(); |
1890 | // This is the stride (the delta between the two slices on the requested |
1891 | // dimension). |
1892 | auto delta = bStart[i] - aStart[i]; |
1893 | // The distance between the two slices must be identical to the size of |
1894 | // the result. |
1895 | if (resSize[dim] != delta) { |
1896 | return false; |
1897 | } |
1898 | |
1899 | continue; |
1900 | } |
1901 | |
1902 | // The non-consecutive dimensions must be identical. |
1903 | if (aStart[i] != bStart[i]) { |
1904 | return false; |
1905 | } |
1906 | } |
1907 | |
1908 | return true; |
1909 | } |
1910 | |
1911 | /// \returns True if the two slices \p A and \p B access consecutive spacial |
1912 | /// regions along some dimension. The dimension is stored in \p dim. |
1913 | /// For example, Slice((0, 0)..(1, 10)) Slice((1, 0)..(2, 10)) are consecutive |
1914 | /// along dim=0. |
1915 | static bool findConsecutiveSliceDim(SliceNode *A, SliceNode *B, int *dim) { |
1916 | // The slices must extract from the same input. |
1917 | if (A->getInput() != B->getInput()) { |
1918 | return false; |
1919 | } |
1920 | |
1921 | for (size_t i = 0, e = A->getStart().size(); i < e; i++) { |
1922 | if (areSlicesConsecutive(A, B, i)) { |
1923 | *dim = i; |
1924 | return true; |
1925 | } |
1926 | } |
1927 | |
1928 | return false; |
1929 | } |
1930 | |
/// Convert a BatchMatMul whose RHS is broadcast along the batch dimension
/// (a Tile/Broadcast whose input has dims()[0] == 1) into a single flat
/// MatMul: the LHS batches are stacked into one 2D matrix, multiplied once
/// against the un-tiled RHS, and the result is reshaped back to
/// {numBatches, N, P}.
bool ConvertBroadcastedBatchMatMul::run(Function *F,
                                        const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  for (auto &node : F->getNodes()) {
    BatchMatMulNode *BMMN = dyn_cast<BatchMatMulNode>(&node);
    if (!BMMN) {
      continue;
    }

    NodeValue LHS = BMMN->getLHS();
    NodeValue RHS = BMMN->getRHS();

    // If RHS is a Tile/Broadcast along axis 0 and the input's dims()[0] == 1,
    // then the RHS is fully broadcasted and we can perform the optimization.
    TileNode *TN = dyn_cast<TileNode>(RHS);
    BroadcastNode *BN = dyn_cast<BroadcastNode>(RHS);
    if (!TN && !BN) {
      continue;
    }
    const unsigned_t axis = TN ? TN->getAxis() : BN->getAxis();
    const dim_t dim0 = TN ? TN->getInput().dims()[0] : BN->getInput().dims()[0];
    if (axis != 0 || dim0 != 1) {
      continue;
    }

    // If this is a Broadcast, check if the first dimension is the only one
    // that's tiled. If so, then we can treat this as the same as a
    // Tile. Otherwise we must keep around a Broadcast for everything but the
    // first dimension.
    NodeValue singleTileNV;
    if (BN) {
      ShapeVector newBNDims(BN->getResult().dims().begin(),
                            BN->getResult().dims().end());
      newBNDims[0] = 1;
      if (!BN->getInput().dims().equals(newBNDims)) {
        // Still broadcasting on other dims: re-broadcast to the same shape
        // but with a singleton batch dimension.
        BroadcastNode *newBN = F->createBroadcast(BN->getName(), BN->getInput(),
                                                  newBNDims, /* axis */ 0);
        singleTileNV = newBN->getResult();
      } else {
        // This Broadcast is equivalent to a Tile.
        singleTileNV = BN->getInput();
      }
    } else {
      singleTileNV = TN->getInput();
    }

    // Can now convert the broadcasted BatchMatMul to a MatMul.
    // LHS = {numBatches, N, M}
    // RHS = {M, P}
    // Multiply each LHS matrix {N, M} by RHS {M, P} to get final matrix
    // {numBatches, N, P}
    const dim_t numBatches = LHS.dims()[0];
    const dim_t N = LHS.dims()[1];
    const dim_t M = LHS.dims()[2];
    const dim_t P = RHS.dims()[2];
    auto name = BMMN->getName();

    // Reshape the LHS to be a two-dimensional matrix, where each batch is
    // essentially concatenated onto itself in the 0th dimension.
    ReshapeNode *reshapeLHS =
        F->createReshape(name.str() + ".reshapeLHS" , LHS, {numBatches * N, M});
    // Squeeze out the first dimension of the original Tile's input.
    ReshapeNode *squeezedRHS =
        F->createSqueeze(name.str() + ".squeezedRHS" , singleTileNV, {0});

    // Perform a normal matmul, implementing the batch matmul.
    MatMulNode *MMN = F->createMatMul(name, reshapeLHS, squeezedRHS);

    assert(MMN->getResult().dims()[0] == (numBatches * N) &&
           "Incorrect resulting dimension for batch matmul" );
    assert(MMN->getResult().dims()[1] == P &&
           "Incorrect resulting dimension for batch matmul" );

    // Reshape the result back to the expected batch output shape, with the
    // first dimension the number of batches.
    ReshapeNode *finalReshape = F->createReshape(name.str() + ".reshapeResult" ,
                                                 MMN, {numBatches, N, P});
    BMMN->getResult().replaceAllUsesOfWith(finalReshape);
    changed = true;
  }
  return changed;
}
2014 | |
/// Find a sequence of slices in \p input that span the whole input along
/// dimension \p dimension, starting at offset 0 and ending at the input's
/// full extent with no gaps.
/// \returns True if a group of slices that span the whole input was found.
/// The order of the slices is recorded in \p order.
static bool findSlicesThatSpanInput(llvm::ArrayRef<SliceNode *> input,
                                    unsigned_t dimension,
                                    std::vector<SliceNode *> &order) {
  // This is the 'last' slice to be found in the sequence of slices.
  SliceNode *lastSlice = nullptr;

  // Find the 'first' slice in the sequence.
  for (SliceNode *SN : input) {
    auto start = SN->getStart();

    // Invalid dimension.
    if (start.size() <= dimension) {
      return false;
    }

    // Check if this slice extract the first element.
    if (start[dimension] == 0) {
      // We found the first element.
      lastSlice = SN;
      order.push_back(lastSlice);
      break;
    }
  }

  // We could not find a 'first' slice.
  if (!lastSlice) {
    return false;
  }

  // Now that we've found the first slice in the sequence, try to order the
  // rest of the slices after the first one. Each pass re-scans all
  // candidates, so this is quadratic in the number of slices; the groups
  // handled here are expected to be small.
  bool addedSlice = true;
  while (addedSlice) {
    addedSlice = false;

    // For each slice:
    for (SliceNode *SN : input) {
      // Ignore slices of invalid types. Ignore shapes for now, that's checked
      // next while ignoring the axis dimension.
      if (!lastSlice->getResult().getType()->isEqual(
              *SN->getResult().getType(),
              /* allowDifferentShape */ true)) {
        continue;
      }

      // Check if shapes match except for the axis dimension.
      bool skip = false;
      for (size_t i = 0, e = lastSlice->getResult().dims().size(); i < e; ++i) {
        if (i != dimension &&
            lastSlice->getResult().dims()[i] != SN->getResult().dims()[i]) {
          skip = true;
          break;
        }
      }
      if (skip) {
        continue;
      }

      // Check if SN comes after the last slice in the sequence.
      if (areSlicesConsecutive(lastSlice, SN, dimension)) {
        // Add the consecutive slice and schedule another iteration.
        lastSlice = SN;
        order.push_back(lastSlice);
        addedSlice = true;
        continue;
      }
    }
  } // While adding new slices.

  // Check that the last slice completes the tensor.
  auto startCoor = lastSlice->getStart();
  auto resDim = lastSlice->getResult().getType()->dims();
  auto inDim = lastSlice->getInput().getType()->dims();

  // Check if for all dimensions, the size of the result tensor plus the start
  // coordinate matches the size of the tensor.
  for (int i = 0, e = startCoor.size(); i < e; i++) {
    if (startCoor[i] + resDim[i] != inDim[i]) {
      return false;
    }
  }

  // Report success if we found at least two slices that extract from the
  // input.
  return order.size() > 1;
}
2104 | |
2105 | /// Merge multiple batched add nodes into a large batched-add node. |
2106 | bool MergeBatchedAdd::run(Function *F, const CompilationContext &cctx) { |
2107 | LOG_SCOPE(F->getLogContext(), getName()); |
2108 | bool changed = false; |
2109 | auto &nodes = F->getNodes(); |
2110 | |
2111 | // We index the batched add nodes by the slice operand. |
2112 | llvm::DenseMap<Node *, std::vector<BatchedAddNode *>> rightBAUsers; |
2113 | |
2114 | // Collect all of the batched add nodes and index them by the 'slice' |
2115 | // operand. |
2116 | for (auto &node : nodes) { |
2117 | if (auto *BA = dyn_cast<BatchedAddNode>(&node)) { |
2118 | rightBAUsers[BA->getSlice().getNode()].push_back(BA); |
2119 | } |
2120 | } |
2121 | |
2122 | // For each 'slice' that batched add nodes access: |
2123 | for (auto &it : rightBAUsers) { |
2124 | auto &BAs = it.second; |
2125 | |
2126 | // Collects the left-hand-side operands that the batched-adds add into. We |
2127 | // only collect 'slice' nodes. |
2128 | std::vector<SliceNode *> slices; |
2129 | |
2130 | for (auto *BA : BAs) { |
2131 | if (auto *S = dyn_cast<SliceNode>(BA->getBatch().getNode())) { |
2132 | slices.push_back(S); |
2133 | } |
2134 | } |
2135 | |
2136 | // Check if the slice nodes that we've collected cover a whole tensor. |
2137 | std::vector<SliceNode *> order; |
2138 | bool found = findSlicesThatSpanInput(slices, 0, order); |
2139 | |
2140 | if (!found) { |
2141 | continue; |
2142 | } |
2143 | |
2144 | // We found a sequence of batched-add-slice that cover the input tensor. |
2145 | // We can transform the graph and create one big batched-add. |
2146 | std::vector<Node *> newSlices; |
2147 | assert(order.size() > 1 && "order must contain at least 2 SliceNodes." ); |
2148 | SliceNode *S = llvm::cast<SliceNode>(order[0]); |
2149 | auto *mergedBA = F->createBatchedAdd("mergedBA" , S->getInput(), it.first); |
2150 | |
2151 | // Create the new slices. These slices will replace the original scalar |
2152 | // batched-add nodes. |
2153 | for (auto *orig : order) { |
2154 | newSlices.push_back(F->createSlice(orig->getName(), mergedBA, |
2155 | orig->getStart(), |
2156 | orig->getResult().getType())); |
2157 | } |
2158 | |
2159 | // Replace the original individual batched adds with corresponding slices |
2160 | // from the new merged batch add. |
2161 | for (auto *BA : BAs) { |
2162 | for (int i = 0, e = order.size(); i < e; i++) { |
2163 | if (BA->getBatch().getNode() == order[i]) { |
2164 | BA->getResult().replaceAllUsesOfWith(newSlices[i]); |
2165 | changed = true; |
2166 | break; |
2167 | } |
2168 | } |
2169 | } |
2170 | |
2171 | } // for each batched-add group. |
2172 | return changed; |
2173 | } |
2174 | |
/// Optimize ReduceMean configuration with AvgPool if possible: last two axes
/// in a 4D input must be reduced. The mean over H and W is computed by an
/// AvgPool whose kernel covers the full H x W extent, then reshaped to drop
/// the reduced axes.
bool OptimizeReduceMean::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  auto &nodes = F->getNodes();

  // For each node:
  for (auto &node : nodes) {
    if (auto *RM = dyn_cast<BatchedReduceMeanNode>(&node)) {

      // Input shape must be 4D.
      if (RM->getBatch().dims().size() != 4) {
        continue;
      }

      // Last two axes must be reduced.
      auto axes = RM->getAxes();
      if (axes.size() != 2 || std::count(axes.begin(), axes.end(), 2) != 1 ||
          std::count(axes.begin(), axes.end(), 3) != 1) {
        continue;
      }

      // RM is already shaped to have the required output shape.
      NodeValue in = RM->getBatch();

      // Pool over the full spatial extent with unit stride and no padding so
      // the pool's average equals the mean over the last two axes.
      std::vector<unsigned_t> kernels = {static_cast<unsigned_t>(in.dims()[2]),
                                         static_cast<unsigned_t>(in.dims()[3])};
      std::vector<unsigned_t> strides = {1, 1};
      std::vector<unsigned_t> pads = {0, 0, 0, 0};

      // TODO: Fix bad assumption? See issue 3499, for now workaround it.
      // In Glow, AvgPool expects NHWC.
      auto *TR1 = F->createTranspose(
          RM->getName().str() + ".transposeNCHW2NHWC" , in, NCHW2NHWC, "NHWC" );
      auto *AP = F->createAvgPool(RM->getName().str() + ".avgPool" , TR1,
                                  kernels, strides, pads);
      // Carry the ReduceMean's quantization parameters over to the pool
      // output so the replacement is numerically equivalent.
      if (AP->getResult().getType()->isQuantizedType()) {
        auto TypeAP = F->getParent()->uniqueTypeWithNewQuantParams(
            AP->getResult().getType(), RM->getResult().getType());
        AP->getResult().setType(TypeAP);
      }
      auto *TR2 = F->createTranspose(
          RM->getName().str() + ".transposeNHWC2NCHW" , AP, NHWC2NCHW, "NCHW" );

      // AvgPool keeps original shape. Add reshape to match expected output.
      std::vector<dim_t> shape = TR2->getResult().dims();

      ShapeVector shapeAxes(axes.begin(), axes.end());

      // Axes must be sorted for correct erase.
      std::sort(shapeAxes.rbegin(), shapeAxes.rend());
      for (const auto &axis : shapeAxes) {
        shape.erase(shape.begin() + axis);
      }

      auto *RN = F->createReshape(RM->getName().str() + ".reshape" , TR2, shape);

      RM->getResult().replaceAllUsesOfWith(RN);
      changed = true;
      continue;
    }
  } // For all nodes in the graph.

  return changed;
}
2241 | |
2242 | /// \returns a uniquely used Constant with the same contents as \p node. If \p |
2243 | /// node is not a Constant then \returns a nullptr. If \node is a Constant which |
2244 | /// has a single use, \p node is returned. If \node is a Constant which has |
2245 | /// multiple uses, then \returns a new duplicate Constant that has the same |
2246 | /// contents as \p node contained in \p M. |
2247 | static Constant *getUniquelyUsedConstant(Module *M, Node &node) { |
2248 | Constant *constant = dyn_cast<Constant>(&node); |
2249 | if (!constant) { |
2250 | return nullptr; |
2251 | } |
2252 | |
2253 | if (constant->hasOneUse() && areAllFunctionsLoaded(M)) { |
2254 | return constant; |
2255 | } |
2256 | |
2257 | // If constant has more than one use, duplicate it and return the duplicate. |
2258 | auto *NC = M->createConstant(constant->getType(), constant->getName(), |
2259 | constant->getLayout()); |
2260 | NC->getPayloadMutable().assign(&constant->getPayload()); |
2261 | return NC; |
2262 | } |
2263 | |
/// Normalize the weight of \p CV with what \p BN is doing, given containing
/// Module \p M. Folds the batch normalization's affine transform into the
/// convolution's filter and bias constants in place.
/// \returns whether or not the normalization was possible.
template <typename ElemTy>
bool normalizeWeights(Module *M, ConvolutionNode &CV,
                      BatchNormalizationNode &BN) {
  static_assert(
      std::is_floating_point<ElemTy>::value ||
          std::is_same<float16_t,
                       typename std::remove_cv<ElemTy>::type>::value ||
          std::is_same<bfloat16_t,
                       typename std::remove_cv<ElemTy>::type>::value,
      "This implementation is for floating-point values only" );

  // Get filter/bias constants we may mutate in place (duplicated first if
  // they have other users).
  Constant *filterC = getUniquelyUsedConstant(M, *CV.getFilter().getNode());
  Constant *cbiasC = getUniquelyUsedConstant(M, *CV.getBias().getNode());

  if (!filterC || !cbiasC) {
    return false;
  }

  // Perform normalization when Convolution layout is NHWC and BatchNorm
  // ChannelIdx points to C.
  if (BN.getChannelIdx() != 3) {
    return false;
  }

  // Set the new filter and bias on CV if necessary.
  if (filterC != CV.getFilter().getNode()) {
    CV.getParent()->getLogContext()->logNodeInputChange(
        CV, CV.getNthInput(ConvolutionNode::FilterIdx), filterC);
    CV.setNthInput(ConvolutionNode::FilterIdx, filterC);
  }
  if (cbiasC != CV.getBias().getNode()) {
    CV.getParent()->getLogContext()->logNodeInputChange(
        CV, CV.getNthInput(ConvolutionNode::BiasIdx), cbiasC);
    CV.setNthInput(ConvolutionNode::BiasIdx, cbiasC);
  }

  // First, BN computation can be phrased as follows:
  //
  // (X - mean) * (1.0 / sqrt(var + eps)) * bn_scale + bias
  //
  // Thus, we can rewrite bn_scale as:
  //  X * bn_scale * 1.0 / (sqrt(var + eps)) +
  //    (bias - mean * (1.0 / sqrt(var + eps)) * bn_scale)
  //
  // Thus, can just have the affine transform:
  //
  //  X * A + B
  //
  //  where
  //
  //  A = bn_scale * 1.0 / (sqrt(running_var + eps))
  //  B =  (bias - mean * (1.0 / sqrt(var + eps)) * bn_scale)
  //
  // Now, we have that the computation made is the following:
  //
  // ((X `conv` W) + b) * A + B
  //
  // Then, we can simply fuse this as follows:
  //
  // (X `conv` (W * A)) + b * A + B
  //
  // which is simply
  //
  // (X `conv` Q) + C
  //
  // where
  //
  // Q = W * A
  // C = b * A + B

  // NOTE(review): cast<> asserts that the BN operands are Constants; the
  // caller does not explicitly check this before calling — confirm that the
  // surrounding pass guarantees it.
  Constant *scaleC = cast<Constant>(BN.getScale());
  Constant *biasC = cast<Constant>(BN.getBias());
  Constant *meanC = cast<Constant>(BN.getMean());
  Constant *var = cast<Constant>(BN.getVar());

  auto filterH = filterC->getHandle<ElemTy>();

  auto cbiasH = cbiasC->getHandle<ElemTy>();

  auto scaleH = scaleC->getHandle<ElemTy>();
  auto biasH = biasC->getHandle<ElemTy>();
  auto meanH = meanC->getHandle<ElemTy>();
  auto varH = var->getHandle<ElemTy>();

  // Update the filter/bias constants of the Conv node.
  auto epsilon = BN.getEpsilon();
  for (size_t i = 0, e = filterH.size(); i < e; i++) {
    // Dimension zero is the 'channel' dimension. If we ever change the
    // layout of the filter then we need to change this optimization.
    dim_t channelId = filterH.getDimForPtr(0, i);
    float value = varH.at({channelId});
    float stdvar = 1.0f / std::sqrt(value + epsilon);
    float gamma = scaleH.at({channelId});
    float A = gamma * stdvar;
    // Q = W * A, computed in float and cast back to the element type.
    filterH.raw(i) = ElemTy(float(filterH.raw(i)) * A);
  }

  for (size_t i = 0, e = cbiasH.size(); i < e; i++) {
    // Dimension zero is the 'channel' dimension. If we ever change the
    // layout of the filter then we need to change this optimization.
    dim_t channelId = cbiasH.getDimForPtr(0, i);
    float mu = meanH.at({channelId});
    float value = varH.at({channelId});
    float stdvar = 1.0f / std::sqrt(value + epsilon);
    float gamma = scaleH.at({channelId});
    float beta = biasH.at({channelId});
    float A = gamma * stdvar;
    float B = beta - mu * A;
    // C = b * A + B.
    cbiasH.raw(i) = ElemTy(float(cbiasH.raw(i)) * A + B);
  }
  return true;
}
2378 | |
2379 | /// Gets Constant or returns nullptr if input is not Constant. |
2380 | /// Skips QuantizeNode if present. |
2381 | static Constant *getConstant(const NodeValue &NV) { |
2382 | Node *N = NV.getNode(); |
2383 | if (isa<QuantizeNode>(N)) { |
2384 | N = N->getNthInput(QuantizeNode::InputIdx); |
2385 | } |
2386 | return dyn_cast<Constant>(N); |
2387 | } |
2388 | |
/// Eliminate or fold BatchNormalization nodes: a BN whose parameters make it
/// an identity transform is removed outright; otherwise a BN that follows a
/// single-use float Convolution is folded into the convolution's filter and
/// bias via normalizeWeights().
bool OptimizeBatchNorm::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  auto &nodes = F->getNodes();
  auto *M = F->getParent();

  // For each node:
  for (auto &node : nodes) {
    auto *BN = dyn_cast<BatchNormalizationNode>(&node);
    if (!BN) {
      continue;
    }

    // Remove BN if mean,var,eps,scale,beta values make it redundant as per
    // expression (X - mean) * (1.0 / sqrt(var + eps)) * scale + bias.
    float scale, bias, mean, var;
    auto *scaleC = getConstant(BN->getScale());
    auto *biasC = getConstant(BN->getBias());
    auto *meanC = getConstant(BN->getMean());
    auto *varC = getConstant(BN->getVar());
    if (scaleC && biasC && meanC && varC &&
        isUniformConstant<float>(*scaleC, scale) &&
        isUniformConstant<float>(*biasC, bias) &&
        isUniformConstant<float>(*meanC, mean) &&
        isUniformConstant<float>(*varC, var)) {
      float eps = BN->getEpsilon();
      // Relaxed redundancy check based on reduced BN expression so that A
      // is 1.0 and B is 0.0 in Y = A*X + B where,
      // A = scale * (1.0 / (sqrt(var + eps))
      // B = (bias - mean * (1.0 / sqrt(var + eps)) * scale)
      // The exact float comparison is deliberate: only a bit-exact identity
      // transform is dropped here.
      if (bias == mean && (std::sqrt(var + eps) == scale)) {
        BN->getResult().replaceAllUsesOfWith(BN->getInput());
        changed = true;
        continue;
      }
    }

    // Merge the Batch Normalization operation into the convolution that comes
    // before it by updating the weights of the filter and bias.
    auto *CV = dyn_cast<ConvolutionNode>(BN->getInput());
    if (!CV) {
      continue;
    }

    // We can't modify conv operators that have multiple users.
    if (!CV->hasOneUse()) {
      continue;
    }

    // Dispatch on the conv's floating-point element type; quantized convs are
    // not handled by this fold.
    bool normalizationHappened = false;
    switch (CV->getElementType(ConvolutionNode::ResultIdx)) {
    case ElemKind::FloatTy:
      normalizationHappened = normalizeWeights<float>(M, *CV, *BN);
      break;
    case ElemKind::Float16Ty:
      normalizationHappened = normalizeWeights<float16_t>(M, *CV, *BN);
      break;
    case ElemKind::BFloat16Ty:
      normalizationHappened = normalizeWeights<bfloat16_t>(M, *CV, *BN);
      break;
    default:
      llvm_unreachable("Type not supported" );
    }

    if (!normalizationHappened) {
      continue;
    }

    // Take the predicate of what was expected for the output.
    CV->setPredicate(BN->getPredicate());
    BN->getResult().replaceAllUsesOfWith(CV);
    changed = true;
    continue;
  } // For all nodes in the graph.
  return changed;
}
2465 | |
2466 | /// If \p node has uses, and all of them have one user node, return this user |
2467 | /// node. Otherwise, return nullptr. |
2468 | static Node *getOnlyUser(Node &node) { |
2469 | if (!node.hasUsers()) { |
2470 | // No users. |
2471 | return nullptr; |
2472 | } |
2473 | Node *first = node.getUsers().front().getUser(); |
2474 | for (auto &U : node.getUsers()) { |
2475 | if (U.getUser() != first) { |
2476 | // Multiple users. |
2477 | return nullptr; |
2478 | } |
2479 | } |
2480 | // One user. |
2481 | return first; |
2482 | } |
2483 | |
2484 | /// Checks that \p tile sub-tensors along \p chId axis repeat. |
2485 | /// It also extracts the repeated dimension values into \p result. |
2486 | static bool isConstBroadcasted(std::vector<float> &result, const Constant &tile, |
2487 | int32_t chId) { |
2488 | // TODO: This limitation can be lifted, but that is for simplicity. |
2489 | if (tile.getType()->dims().size() != 4) { |
2490 | return false; |
2491 | } |
2492 | // TODO: We can also support quantized constants if there is a need in it. |
2493 | if (tile.getType()->getElementType() != ElemKind::FloatTy) { |
2494 | return false; |
2495 | } |
2496 | auto handle = tile.getPayload().getHandle<float>(); |
2497 | glow::dim_t n, h, w, c; |
2498 | for (c = 0; c < tile.getType()->dims()[chId]; c++) { |
2499 | std::vector<glow::dim_t> dims = {c, 0, 0, 0}; |
2500 | std::rotate(dims.begin(), dims.begin() + dims.size() - chId, dims.end()); |
2501 | const float expected = handle.at(llvm::ArrayRef<glow::dim_t>(dims)); |
2502 | for (n = 0; n < tile.getType()->dims()[(chId + 1) % 4]; n++) { |
2503 | for (h = 0; h < tile.getType()->dims()[(chId + 2) % 4]; h++) { |
2504 | for (w = 0; w < tile.getType()->dims()[(chId + 3) % 4]; w++) { |
2505 | std::vector<glow::dim_t> dimsE = {c, n, h, w}; |
2506 | std::rotate(dimsE.begin(), dimsE.begin() + dimsE.size() - chId, |
2507 | dimsE.end()); |
2508 | if (handle.at(llvm::ArrayRef<glow::dim_t>(dims)) != expected) { |
2509 | return false; |
2510 | } |
2511 | } |
2512 | } |
2513 | } |
2514 | result[c] = expected; |
2515 | } |
2516 | return true; |
2517 | } |
2518 | |
2519 | /// Collects the longest chain of arithmetic operations with constants starting |
2520 | /// from \p start. Updates \p scale and \p bias as it collects the operands. |
2521 | /// \p returns the last node in the chain. |
2522 | static NodeValue collectArithmeticChain(Function *F, NodeValue start, |
2523 | Constant &scale, Constant &bias, |
2524 | int32_t chIdx) { |
2525 | |
2526 | Node *user = getOnlyUser(*start.getNode()); |
2527 | NodeValue chainEnd = start; |
2528 | |
2529 | auto isSupportedForMerge = [](const Node *n) { |
2530 | return isa<MulNode>(n) || isa<AddNode>(n) || isa<SubNode>(n) || |
2531 | isa<DivNode>(n); |
2532 | }; |
2533 | |
2534 | while (user && isSupportedForMerge(user)) { |
2535 | // Paranoid |
2536 | assert(user->isArithmetic() && "Not an arithmetic node!" ); |
2537 | |
2538 | auto lhs = user->getNthInput(ArithmeticNode::LHSIdx); |
2539 | auto rhs = user->getNthInput(ArithmeticNode::RHSIdx); |
2540 | |
2541 | // Paranoid. |
2542 | assert(((lhs == chainEnd) || (rhs == chainEnd)) && "Not a user?" ); |
2543 | |
2544 | auto out = user->getNthResult(ArithmeticNode::ResultIdx); |
2545 | |
2546 | // Quantized arithmetic operations may change scale of result, we don't want |
2547 | // to deal with that. May be supported later if needed. |
2548 | if (lhs.getType() != out.getType() || rhs.getType() != out.getType()) { |
2549 | break; |
2550 | } |
2551 | |
2552 | // Only take this one if its other argument is a constant. |
2553 | // TODO: We can also support Splat here if needed. |
2554 | auto *c = dyn_cast<Constant>(lhs == chainEnd ? rhs : lhs); |
2555 | if (!c) { |
2556 | break; |
2557 | } |
2558 | |
2559 | const dim_t numChannels = c->dims()[chIdx]; |
2560 | |
2561 | std::vector<float> toMerge(c->dims()[chIdx]); |
2562 | if (!isConstBroadcasted(toMerge, *c, chIdx)) { |
2563 | break; |
2564 | } |
2565 | |
2566 | auto biasH = bias.getPayloadMutable().getHandle(); |
2567 | auto scaleH = scale.getPayloadMutable().getHandle(); |
2568 | |
2569 | for (dim_t i = 0; i < numChannels; i++) { |
2570 | if (isa<DivNode>(user)) { |
2571 | scaleH.raw(i) /= toMerge[i]; |
2572 | biasH.raw(i) /= toMerge[i]; |
2573 | } else if (isa<MulNode>(user)) { |
2574 | scaleH.raw(i) *= toMerge[i]; |
2575 | biasH.raw(i) *= toMerge[i]; |
2576 | } else if (isa<SubNode>(user)) { |
2577 | if (chainEnd == rhs) { |
2578 | scaleH.raw(i) *= -1; |
2579 | biasH.raw(i) = toMerge[i] - biasH.raw(i); |
2580 | } else { |
2581 | biasH.raw(i) -= toMerge[i]; |
2582 | } |
2583 | } else if (isa<AddNode>(user)) { |
2584 | biasH.raw(i) += toMerge[i]; |
2585 | } else { |
2586 | llvm_unreachable("Unsupported type!" ); |
2587 | } |
2588 | } |
2589 | chainEnd = user->getNthResult(ArithmeticNode::ResultIdx); |
2590 | user = getOnlyUser(*user); |
2591 | } |
2592 | return chainEnd; |
2593 | } |
2594 | |
2595 | /// Find the longest chain of Mul/Sub/Add/Div under a Convolution node that |
2596 | /// operate on Constant and fold them into a new BatchNormalization node. |
2597 | bool FoldArithmeticChainUnderConvIntoBN::run(Function *F, |
2598 | const CompilationContext &cctx) { |
2599 | LOG_SCOPE(F->getLogContext(), getName()); |
2600 | bool changed = false; |
2601 | |
2602 | for (auto &node : F->getNodes()) { |
2603 | auto *CN = dyn_cast<ConvolutionNode>(&node); |
2604 | if (!CN) { |
2605 | continue; |
2606 | } |
2607 | auto bias = CN->getBias(); |
2608 | // Conv is in NHWC format - channel is dim 3. |
2609 | int32_t chIdx = 3; |
2610 | |
2611 | // TODO: Support quantized constants if needed. |
2612 | if (bias.getType()->getElementType() != ElemKind::FloatTy) { |
2613 | continue; |
2614 | } |
2615 | |
2616 | // Provide collectArithmeticChain w/ bias/scale that have identity values |
2617 | // as we are creating new BN consisted of the arithmetic nodes that the |
2618 | // function will find. |
2619 | auto *newScale = F->getParent()->createConstant( |
2620 | bias.getType(), CN->getName().str() + "_BN.scale" ); |
2621 | auto *newBias = F->getParent()->createConstant( |
2622 | bias.getType(), CN->getName().str() + "_BN.bias" ); |
2623 | |
2624 | newScale->getPayloadMutable().getHandle<float>().clear(1.f); |
2625 | newBias->getPayloadMutable().getHandle<float>().clear(0.f); |
2626 | |
2627 | // Collect the chain and compute the new scale and bias. |
2628 | NodeValue chainEnd = |
2629 | collectArithmeticChain(F, CN->getResult(), *newScale, *newBias, chIdx); |
2630 | if (chainEnd == CN->getResult()) { |
2631 | F->getParent()->eraseConstant(newScale); |
2632 | F->getParent()->eraseConstant(newBias); |
2633 | continue; |
2634 | } |
2635 | |
2636 | // Compute the shape of batch normalization constants (array of |
2637 | // {depth} elements). |
2638 | glow::dim_t size = newScale->getPayloadMutable().getHandle<float>().size(); |
2639 | auto depthTy = |
2640 | F->getParent()->uniqueTypeWithNewShape(bias.getType(), {size}); |
2641 | |
2642 | Tensor varianceT(depthTy); |
2643 | varianceT.init(glow::Tensor::InitKind::Broadcast, 1.0f, F->getPRNG()); |
2644 | auto variance = F->getParent()->createConstant( |
2645 | CN->getName().str() + "_BN.var" , varianceT); |
2646 | |
2647 | Tensor meanT(depthTy); |
2648 | meanT.zero(); |
2649 | auto mean = |
2650 | F->getParent()->createConstant(CN->getName().str() + "_BN.mean" , meanT); |
2651 | |
2652 | // Create a BN with new parameters. |
2653 | auto *nBN = |
2654 | F->createBatchNormalization("BatchNorm" , chainEnd.getType(), &node, |
2655 | newBias, newScale, mean, variance, 3, 0, 0); |
2656 | chainEnd.replaceAllUsesOfWith(nBN); |
2657 | changed = true; |
2658 | } |
2659 | return changed; |
2660 | } |
2661 | |
/// For each BatchNormalization node in \p F, find the longest chain of
/// Mul/Sub/Add/Div operations with constants that use it and merge all those
/// operations into the BatchNormalization.
bool FoldBatchNormalizationWithArithmeticChain::run(
    Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  for (auto &node : F->getNodes()) {
    auto *BN = dyn_cast<BatchNormalizationNode>(&node);
    if (!BN) {
      continue;
    }

    // Expecting constant as const folding took place already.
    auto *scaleC = dyn_cast<Constant>(BN->getScale());
    auto *biasC = dyn_cast<Constant>(BN->getBias());
    if (!scaleC || !biasC) {
      continue;
    }

    // TODO: Support quantized constants if needed.
    if (scaleC->getType()->getElementType() != ElemKind::FloatTy ||
        biasC->getType()->getElementType() != ElemKind::FloatTy) {
      continue;
    }

    // Clone the scale/bias payloads into fresh constants: the chain collector
    // mutates them in place, and the originals must stay intact in case no
    // chain is found (or other users still reference them).
    auto *newScaleC =
        F->getParent()->createConstant(scaleC->getType(), scaleC->getName());
    Tensor scaleT = scaleC->getPayload().getUnowned();
    newScaleC->assign(&scaleT);

    auto *newBiasC =
        F->getParent()->createConstant(biasC->getType(), biasC->getName());
    Tensor biasT = biasC->getPayload().getUnowned();
    newBiasC->assign(&biasT);

    // Collect the chain and compute the new scale and bias.
    NodeValue chainEnd = collectArithmeticChain(F, BN->getResult(), *newScaleC,
                                                *newBiasC, BN->getChannelIdx());
    if (chainEnd == BN->getResult()) {
      // Nothing was merged; remove the unused constant copies.
      F->getParent()->eraseConstant(newScaleC);
      F->getParent()->eraseConstant(newBiasC);
      continue;
    }

    // Re-quantize the new constants if the BN's inputs were quantized.
    // NOTE(review): scaleC/biasC above came from dyn_cast<Constant>, which
    // bails out when the input is a QuantizeNode, so these branches look
    // unreachable from here — confirm whether a look-through helper (e.g.
    // getConstant, used elsewhere in this file) was intended instead.
    Node *newScaleN = newScaleC, *newBiasN = newBiasC;
    if (isa<QuantizeNode>(BN->getScale())) {
      newScaleN = F->createQuantize(newScaleN->getName(), newScaleN,
                                    BN->getScale().getType());
    }
    if (isa<QuantizeNode>(BN->getBias())) {
      newBiasN = F->createQuantize(newBiasN->getName(), newBiasN,
                                   BN->getBias().getType());
    }

    // Create a BN with new parameters.
    auto *newBN = F->createBatchNormalization(
        BN->getName(), chainEnd.getType(), BN->getInput(), newBiasN, newScaleN,
        BN->getMean(), BN->getVar(), BN->getChannelIdx(), BN->getEpsilon(),
        BN->getMomentum());

    // Reroute users of the chain's last node to the new BN; the detached
    // chain becomes dead and is cleaned up by DCE.
    chainEnd.replaceAllUsesOfWith(newBN);
    changed = true;
  }

  return changed;
}
2729 | |
/// Fold MatMul + Add into FullyConnected. This is useful for backends which
/// have an atomic implementation for the FullyConnected node. It is also needed
/// for ONNX which does not have a representation for the FullyConnected node.
bool FoldMatMulAddIntoFullyConnected::run(Function *F,
                                          const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  for (auto &node : F->getNodes()) {
    auto *addNode = dyn_cast<AddNode>(&node);
    if (!addNode) {
      continue;
    }

    // Check for MatMul node being either RHS or LHS.
    // Note: if both operands are MatMuls, the LHS one is folded and the RHS
    // one is treated as the bias operand.
    auto *matMulNode_LHS = dyn_cast<MatMulNode>(addNode->getLHS());
    auto *matMulNode_RHS = dyn_cast<MatMulNode>(addNode->getRHS());
    auto *matMulNode = matMulNode_LHS ? matMulNode_LHS : matMulNode_RHS;
    NodeValue bias = matMulNode_LHS ? addNode->getRHS() : addNode->getLHS();
    if (!matMulNode) {
      continue;
    }

    // Folding is allowed only if MatMul has one use.
    if (!matMulNode->getResult().hasOneUse()) {
      continue;
    }

    // The corresponding batch/length of the FullyConnected Bias operand.
    assert(bias.dims().size() == 2 && "Bias should be 2D!");
    auto biasBatch = bias.dims()[0];
    auto biasLength = bias.dims()[1];
    if (biasBatch == 1) {
      // If bias is not batched then reshape to 1D.
      bias = F->createReshape(bias.getNode()->getName().str() + ".reshape",
                              bias, {bias.getType()->size()});
    } else {
      // If bias is batched then we must verify that the bias data
      // is same for all batches. For this the bias must be a Constant.
      auto *biasC = llvm::dyn_cast<Constant>(bias.getNode());
      if (!biasC) {
        continue;
      }
      // isTiled(0) checks that all rows repeat, so row 0 can stand in for
      // the whole batched bias.
      if (!biasC->getPayload().isTiled(0)) {
        continue;
      }
      // Slice batched 2D bias and reshape to 1D.
      bias = F->createSlice(bias.getNode()->getName().str() + ".slice", bias,
                            {0, 0}, {1, biasLength});
      bias = F->createReshape(bias.getNode()->getName().str() + ".reshape",
                              bias, {biasLength});
    }

    // Create a new FullyConnected node replacing both the MatMul and the Add;
    // it takes over the Add's result type.
    auto *newFC = F->createFullyConnected(
        matMulNode->getName(), matMulNode->getLHS(), matMulNode->getRHS(), bias,
        addNode->getResult().getType());
    addNode->getResult().replaceAllUsesOfWith(newFC);
    changed = true;
  }

  return changed;
}
2792 | |
2793 | /// Convert MatMul into FullyConnected with null bias. This pass is used if |
2794 | /// we have an optimized implementation for FullyConnected but NOT for MatMul. |
2795 | /// Make sure you run this pass after the FoldMatMulAddIntoFullyConnected pass |
2796 | /// otherwise a MatMul followed by Add will be converted into a FullyConnected |
2797 | /// followed by Add and NOT a single FullyConnected instance. |
2798 | bool ConvertMatMulToFullyConnected::run(Function *F, |
2799 | const CompilationContext &cctx) { |
2800 | bool changed = false; |
2801 | for (auto &node : F->getNodes()) { |
2802 | auto *matMulNode = dyn_cast<MatMulNode>(&node); |
2803 | if (!matMulNode) { |
2804 | continue; |
2805 | } |
2806 | |
2807 | // Create null bias. |
2808 | Constant *bias = nullptr; |
2809 | std::vector<dim_t> biasDims = {matMulNode->getResult().dims().back()}; |
2810 | std::string biasName = matMulNode->getName().str() + "bias" ; |
2811 | if (matMulNode->getResult().getType()->isQuantizedType()) { |
2812 | // Create null bias with offset 0 and a scale equal to the product |
2813 | // between LHS scale and RHS scale. |
2814 | float biasScale = matMulNode->getLHS().getType()->getScale() * |
2815 | matMulNode->getRHS().getType()->getScale(); |
2816 | int32_t biasOffset = 0; |
2817 | ElemKind biasPrec = cctx.precisionConfig.quantConfig.precisionBias; |
2818 | bias = F->getParent()->createConstant(biasPrec, biasDims, biasScale, |
2819 | biasOffset, biasName); |
2820 | bias->getPayloadMutable().zero(); |
2821 | } else { |
2822 | // Create null FLOAT bias. |
2823 | bias = |
2824 | F->getParent()->createConstant(ElemKind::FloatTy, biasDims, biasName); |
2825 | bias->getPayloadMutable().zero(); |
2826 | } |
2827 | |
2828 | // Create a new FullyConnected node with null bias. |
2829 | auto *newFC = F->createFullyConnected( |
2830 | matMulNode->getName(), matMulNode->getLHS(), matMulNode->getRHS(), bias, |
2831 | matMulNode->getResult().getType()); |
2832 | matMulNode->getResult().replaceAllUsesOfWith(newFC); |
2833 | changed = true; |
2834 | } |
2835 | |
2836 | return changed; |
2837 | } |
2838 | |
// Fold Add after ConvTranspose into ConvTranspose's bias, if such Add was a
// broadcasted Add. Examine by looking into Tensor repetitions. Fold this:
//
//  CONST1   Input
//      \      |
//  CONST2  ConvTranspose
//       \   /
//        Add
//         |
//       Output
//
// into this:
//
//  CONST1  (CONST2 SQUEEZED)
//     |     /
//  Input   ADD
//      \   /
//  ConvTranspose
//       |
//     Output
//
// Optimizations are going to take care of folding CONST1/CONST2/ADD
// into one const bias.
bool ConvTransposeBiasAddFold::run(Function *F,
                                   const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  for (auto &node : F->getNodes()) {

    auto *AN = dyn_cast<AddNode>(&node);
    if (!AN) {
      continue;
    }

    // Check for Transpose node being either RHS or LHS.
    auto *DN_L = dyn_cast<ConvTransposeNode>(AN->getLHS());
    auto *DN_R = dyn_cast<ConvTransposeNode>(AN->getRHS());
    auto *DN = DN_L ? DN_L : DN_R;
    // Exactly one of the two Add operands must be a ConvTranspose (XOR of
    // the null checks).
    if (!(!DN_R ^ !DN_L)) {
      continue;
    }
    // The other operand must be a 4D constant (CONST2 in the diagram above).
    auto *biasTile = dyn_cast<Constant>(DN_L ? AN->getRHS() : AN->getLHS());
    if (!biasTile || (biasTile->dims().size() != 4)) {
      continue;
    }
    // The ConvTranspose's existing bias must be a constant so it can later
    // be const-folded with the new Add.
    auto *bias = dyn_cast<Constant>(DN->getBias());
    if (!bias) {
      continue;
    }

    // Check if Add is a broadcasted Add; extract the per-channel values
    // (channel is dim 3, NHWC layout).
    std::vector<float> origConst(biasTile->dims()[3]);
    if (!isConstBroadcasted(origConst, *biasTile, 3)) {
      continue;
    }

    // Expect Bias Add so allocate a new bias to fill as do checking.
    auto *newBias = F->getParent()->createConstant(
        ElemKind::FloatTy, {biasTile->dims()[3]}, biasTile->getName());
    newBias->getHandle() = origConst;

    // Add the squeezed constant onto the existing bias and rewire the graph
    // so the ConvTranspose absorbs the Add.
    auto *add = F->createAdd(bias->getName(), bias, newBias);
    DN->setNthInput(ConvTransposeNode::BiasIdx, add);
    AN->getResult().replaceAllUsesOfWith(DN);

    changed = true;
  } // For all nodes in the graph.

  return changed;
}
2909 | |
2910 | /// \returns true if all dimensions of the \p input tensors are the same |
2911 | /// except for the provided \p dimension, otherwise return false. |
2912 | static bool checkConcatNodeUniformDims(llvm::ArrayRef<NodeValue> inputs, |
2913 | unsigned_t dimension) { |
2914 | for (size_t i = 1; i < inputs.size(); i++) { |
2915 | for (size_t j = 0; j < inputs[0].dims().size(); j++) { |
2916 | if (j == dimension) { |
2917 | continue; |
2918 | } |
2919 | if (inputs[0].dims()[j] != inputs[i].dims()[j]) { |
2920 | return false; |
2921 | } |
2922 | } |
2923 | } |
2924 | return true; |
2925 | } |
2926 | |
2927 | /// Given a tensor's dims \p firstDims and the desired leading/trailing dims |
2928 | /// sizes \p leadingDimsProdOriginalConcatNode, \p |
2929 | /// trailingDimsProdOriginalConcatNode. \returns the dimension, at which the |
2930 | /// trailing/leading dimensions match the desired sizes, otherwise returns -1. |
2931 | /// Example: Given a tensor <1,2,3,4,5>, and a desired trailing dimensions |
2932 | /// size of 20, and a desired leading dimensions size of 2, this function will |
2933 | /// return dimension 1 as the trailing dimensions after it are <4,5>, which |
2934 | /// matches the size 20, and the leading dimensions are <1,2>, which matches |
2935 | /// the size 2. |
2936 | static ssize_t findMatchingConcatDimForSameTrailingAndLeadingDims( |
2937 | llvm::ArrayRef<dim_t> firstDims, size_t leadingDimsProdOriginalConcatNode, |
2938 | size_t trailingDimsProdOriginalConcatNode) { |
2939 | size_t trailingDimsProdCurNode = 1; |
2940 | for (ssize_t i = firstDims.size() - 1; i >= 0; i--) { |
2941 | if (trailingDimsProdCurNode == trailingDimsProdOriginalConcatNode) { |
2942 | size_t leadingDimsProdCurNode = 1; |
2943 | for (ssize_t j = 0; j < i; j++) { |
2944 | leadingDimsProdCurNode *= firstDims[j]; |
2945 | } |
2946 | if (leadingDimsProdCurNode == leadingDimsProdOriginalConcatNode) { |
2947 | return i; |
2948 | } |
2949 | } |
2950 | trailingDimsProdCurNode *= firstDims[i]; |
2951 | } |
2952 | return -1; |
2953 | } |
2954 | |
2955 | /// Given input tensors \p inputs and a original ConcatNode \p origConcatN, |
2956 | /// try to find out if there is a dimension in the input tensors, with which |
2957 | /// we can meet two requirements: |
2958 | /// 1) Input tensors are concatenate-able along this dimension. |
2959 | /// 2) The trailing/leading dimensions sizes after/before this dimension in |
2960 | /// the input tensors, are of the same size as the trailing/leading |
2961 | /// dimensions of the input of the original Concat node after/before the |
2962 | /// concatenation dimension. It is required, because they ensure that the |
2963 | /// payload of the new concat node should be the same as the payload of |
2964 | /// the original concat node, and also won't affect the data order of the |
2965 | /// entire tensor. |
2966 | /// \returns this dimension if found, otherwise -1. |
2967 | static int |
2968 | findConcatDimForSameTrailingAndLeadingDims(llvm::ArrayRef<NodeValue> inputs, |
2969 | ConcatNode *originalConcatNode) { |
2970 | // For the purpose of the optimiztion |
2971 | // Concat(Reshape(X)*N)->Reshape(Concat(N*X)), we want to make sure the new |
2972 | // ConcatNode can concatenate on the trailing/leading dimensions which are |
2973 | // of the same size of those of the original Concate node. |
2974 | |
2975 | auto firstDims = inputs.front().dims(); |
2976 | auto origConcatNInputDims = originalConcatNode->getInputs().front().dims(); |
2977 | // The sizes of the trailing/leading dimensions of the original ConcatNode, |
2978 | // which are being concatenated. This sizes are simply the products of |
2979 | // dimensions following/before the dimension used for concatenation. |
2980 | dim_t trailingDimsProdOriginalConcatNode = 1; |
2981 | dim_t leadingDimsProdOriginalConcatNode = 1; |
2982 | for (size_t i = 0; i < origConcatNInputDims.size(); ++i) { |
2983 | if (i < originalConcatNode->getDim()) { |
2984 | leadingDimsProdOriginalConcatNode *= origConcatNInputDims[i]; |
2985 | } else if (i > originalConcatNode->getDim()) { |
2986 | trailingDimsProdOriginalConcatNode *= origConcatNInputDims[i]; |
2987 | } |
2988 | } |
2989 | |
2990 | // Try to find the dimension in the first input such that the |
2991 | // trailing/leading dimensions sizes are the same as the sizes of the |
2992 | // trailing/leading dimensions based on the concatenation dimension used by |
2993 | // the original ConcatNode. |
2994 | ssize_t dim = findMatchingConcatDimForSameTrailingAndLeadingDims( |
2995 | firstDims, leadingDimsProdOriginalConcatNode, |
2996 | trailingDimsProdOriginalConcatNode); |
2997 | if (dim == -1) { |
2998 | return -1; |
2999 | } |
3000 | |
3001 | // Now we have found the dimension, we need to check if all inputs can be |
3002 | // concatenated along this dimension. |
3003 | if (!checkConcatNodeUniformDims(inputs, dim)) { |
3004 | return -1; |
3005 | } |
3006 | return dim; |
3007 | } |
3008 | |
3009 | /// Given the inputs \p originalConcatInputs of one Concat Nodes, \returns |
3010 | /// true if they are all ReshapeNode, and the input tensors of these input |
3011 | /// nodes have same number of dimensions, otherwise returns false. |
3012 | static bool |
3013 | tryToGetNewConcatInputs(NodeValueArrayRef originalConcatInputs, |
3014 | llvm::SmallVectorImpl<NodeValue> &newConcatInputs) { |
3015 | // Go through the input nodes of CN, check if they are all ReshapeNode, |
3016 | // and if the input tensors of these input nodes have same number of |
3017 | // dimensions. |
3018 | for (auto &I : originalConcatInputs) { |
3019 | if (auto *R = dyn_cast<ReshapeNode>(I)) { |
3020 | if (newConcatInputs.empty() || newConcatInputs.front().dims().size() == |
3021 | R->getInput().dims().size()) { |
3022 | newConcatInputs.push_back(R->getInput()); |
3023 | continue; |
3024 | } |
3025 | } |
3026 | return false; |
3027 | } |
3028 | return true; |
3029 | } |
3030 | |
3031 | /// Concat(Reshape(x) * N) -> Reshape(Concat(x * N)). |
3032 | /// \returns a new simplified Concat node or nullptr. |
3033 | static NodeValue tryToOptimizeConcatOfRehapes(Function *F, ConcatNode *CN) { |
3034 | llvm::SmallVector<NodeValue, 16> newConcatInputs; |
3035 | // The inputs of the collected input reshape nodes. They will be used as |
3036 | // inputs for the new Concat node if possible. |
3037 | if (!tryToGetNewConcatInputs(CN->getInputs(), newConcatInputs)) { |
3038 | return NodeValue(nullptr); |
3039 | } |
3040 | |
3041 | // Try to concatenate along the same size trailing/leading dimensions as of |
3042 | // the original Concat node. |
3043 | auto dim = findConcatDimForSameTrailingAndLeadingDims(newConcatInputs, CN); |
3044 | if (dim == -1) { |
3045 | return NodeValue(nullptr); |
3046 | } |
3047 | auto *newCN = F->createConcat(CN->getName(), newConcatInputs, dim); |
3048 | return F->createReshape( |
3049 | CN->getInputs().front().getNode()->getName(), newCN, |
3050 | CN->getResult().dims(), |
3051 | CanonicalTensorLayout::getInstance().getNthResultLayoutRequirements( |
3052 | CN, ConcatNode::ResultIdx)); |
3053 | } |
3054 | |
/// Simplify concat node.
/// \returns a new simplified Concat node or nullptr.
static NodeValue simplifyConcatNode(Function *F, ConcatNode *CN,
                                    const CompilationContext &cctx) {
  /// concat(dim1, concat(dim2, X, Y), Z) -> concat(dim1, X, Y, Z),
  /// but only if dim1 == dim2

  LOG_SCOPE(F->getLogContext(), "simplifyConcatNode")

  if (!cctx.optimizationOpts.skipConcatMerging) {
    auto inputs = CN->getInputs();
    // Check if any of the inputs are ConcatNode.
    llvm::SmallVector<NodeValue, 16> newInputs;
    bool merged = false;
    for (auto &input : inputs) {
      newInputs.push_back(input);
      auto *CNI = dyn_cast<ConcatNode>(input);
      // Bail if it is not a ConcatNode or it is a concat node with a different
      // dimension.
      if (!CNI || CNI->getDim() != CN->getDim()) {
        continue;
      }

      // Preventing this from kicking in.
      // Otherwise we will end up sequence of concats with more and more inputs,
      // without really eliminating any.
      if (CNI->getResult().getNumUsers() > 1) {
        continue;
      }

      merged = true;
      // Replace current input by its own inputs, i.e. merge them into the
      // parent concat node.
      newInputs.pop_back();
      newInputs.append(CNI->getInputs().begin(), CNI->getInputs().end());
    }
    if (merged) {
      // Return a new simplified Concat node.
      return F->createConcat(CN->getName(), newInputs, CN->getDim());
    }
  }

  // If all of the inputs to the concat are extracted from the same input in
  // the right order then we can just use the extract-input instead of the
  // concat. Concat(Slice(X, 0..10), Slice(X, 10..20)) -> X.
  {
    std::vector<SliceNode *> order;
    std::vector<SliceNode *> slices;
    // Collect all of the inputs that are SliceNode.
    for (auto &I : CN->getInputs()) {
      if (auto *S = dyn_cast<SliceNode>(I.getNode())) {
        slices.push_back(S);
      }
    }
    // Check if the slices span the input value.
    bool found = findSlicesThatSpanInput(slices, CN->getDim(), order);
    if (found && order.size() == slices.size()) {
      // Check that the ordered Slices that span the input are in order.
      bool ordered = true;
      for (size_t i = 0, e = slices.size(); i < e; i++) {
        if (order[i] != slices[i]) {
          ordered = false;
          break;
        }
      }

      auto orig = order[0]->getInput();
      // The original value that we extract from must be of the same shape as
      // the concat. This type check also guards the case where some concat
      // inputs were not Slices (they are simply not collected above).
      if (ordered && CN->getResult().getType() == orig.getType()) {
        return orig;
      }
    }
  }

  // Try the optimization Concat(Reshape(x) * N) -> Reshape(Concat(x * N)).
  if (auto transformedConcatNode = tryToOptimizeConcatOfRehapes(F, CN)) {
    return transformedConcatNode;
  }

  // If the concat has a single input, replace the concat with that input.
  if (CN->getNumInputs() == 1) {
    return CN->getInputs()[0];
  }

  return NodeValue(nullptr);
}
3142 | |
3143 | /// If all of the outputs of \p CN are essentially piped from the inputs of the |
3144 | /// concat (i.e. same shape, axis, order) then we can get rid of the slices and |
3145 | /// concat. \returns true if this optimization is successful and changes the |
3146 | /// Function. |
3147 | static bool combineConcatSlices(ConcatNode *CN) { |
3148 | auto inputsToCN = CN->getInputs(); |
3149 | std::vector<SliceNode *> slices; |
3150 | std::vector<SliceNode *> orderedSlices; |
3151 | for (auto &user : CN->getUsers()) { |
3152 | if (SliceNode *SN = dyn_cast<SliceNode>(user.getUser())) { |
3153 | slices.push_back(SN); |
3154 | } |
3155 | } |
3156 | |
3157 | // Check if the slices span the input value. |
3158 | bool found = findSlicesThatSpanInput(slices, CN->getDim(), orderedSlices); |
3159 | if (!found || orderedSlices.size() != slices.size() || |
3160 | orderedSlices.size() != inputsToCN.size()) { |
3161 | return false; |
3162 | } |
3163 | |
3164 | // Now verify that all of the inputs to CN have the same shape as all of the |
3165 | // slices for the result of CN. |
3166 | for (size_t i = 0, e = orderedSlices.size(); i < e; ++i) { |
3167 | if (orderedSlices[i]->getResult().dims() != inputsToCN[i].dims()) { |
3168 | return false; |
3169 | } |
3170 | } |
3171 | |
3172 | // We can now replace all of the inputs to the concat to the result of |
3173 | // each slice. |
3174 | for (size_t i = 0, e = inputsToCN.size(); i < e; ++i) { |
3175 | orderedSlices[i]->getResult().replaceAllUsesOfWith(inputsToCN[i]); |
3176 | } |
3177 | return true; |
3178 | } |
3179 | |
3180 | /// Eliminate Concat-Slice patterns which are unnecessary. E.g.: |
3181 | /// NodeA NodeB NodeA NodeB |
3182 | /// \ / | | |
3183 | /// ConcatC | | |
3184 | /// / \ -----> | | |
3185 | /// SliceD SliceE | | |
3186 | /// | | | | |
3187 | /// NodeF NodeG NodeF NodeG |
3188 | bool EliminateConcatSlice::run(Function *F, const CompilationContext &cctx) { |
3189 | LOG_SCOPE(F->getLogContext(), getName()); |
3190 | bool changed = false; |
3191 | auto &nodes = F->getNodes(); |
3192 | |
3193 | // For each node: |
3194 | for (auto &node : nodes) { |
3195 | auto *CN = dyn_cast<ConcatNode>(&node); |
3196 | if (!CN) { |
3197 | continue; |
3198 | } |
3199 | if (combineConcatSlices(CN)) { |
3200 | changed = true; |
3201 | continue; |
3202 | } |
3203 | } |
3204 | return changed; |
3205 | } |
3206 | |
/// Eliminate Slice-Concat patterns which are unnecessary. E.g.:
/// --- NodeSrc ----                          - NodeSrc -
///   /    |       |                          /         |
/// SliceA SliceB SliceC     ----->       SliceAB     SliceC
///    \    /      |                         |           |
/// NodeE-ConcatABE NodeD            NodeE-ConcatABE   NodeD
bool EliminateSliceConcat::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  auto &nodes = F->getNodes();

  for (auto it = nodes.rbegin(), e = nodes.rend(); it != e; it++) {
    Node &node = *it;
    auto *CN = dyn_cast<ConcatNode>(&node);
    if (!CN) {
      continue;
    }
    // avoid 1) merging through operators other than Slices
    // e.g. Slice(A)-Other-Slice(B), A and B are consecutive
    // 2) merging Slices from different sources
    // e.g. Slice(A1)-Slice(B1)-Slice(A2)-Slice(B2), A1 and A2, B1 and B2 are
    // consecutive respectively
    //
    // Store consecutive slices along *any* dimension. If the consecutive
    // slices' dimension is the same as the concat, the concat can be removed.
    // If the slices' dimension is different from the concat, the nodes
    // can be replaced with slice+reshape OR slice+transpose+reshape.
    // The extra transpose is needed when the consecutive dimension is
    // after the concat dimension.
    std::vector<
        std::pair<int /* dimension of slicing */, std::vector<SliceNode *>>>
        consecutiveSlices;
    std::vector<SliceNode *> currConsecutiveSlices;
    SliceNode *lastSN = nullptr;
    int lastDim = -1;
    for (auto &concatInput : CN->getInputs()) {
      auto *SN = dyn_cast<SliceNode>(concatInput.getNode());
      // slices with multiple users will not be considered
      if (!SN || SN->getResult().getNumUsers() > 1) {
        // A non-Slice (or shared Slice) input breaks the current run of
        // consecutive slices; flush whatever was collected so far.
        if (currConsecutiveSlices.size()) {
          consecutiveSlices.emplace_back(lastDim, currConsecutiveSlices);
          currConsecutiveSlices.clear();
        }
        lastDim = -1;
        lastSN = nullptr;
        continue;
      }
      // slices with different sources will not be considered
      int dim = -1;
      if (lastSN && (lastSN->getInput() != SN->getInput() ||
                     !findConsecutiveSliceDim(lastSN, SN, &dim))) {
        // The current slice does not continue the previous run: close the
        // current group and start a new one with this slice.
        consecutiveSlices.emplace_back(lastDim, currConsecutiveSlices);
        currConsecutiveSlices.clear();
      }
      lastDim = dim;
      lastSN = SN;
      currConsecutiveSlices.emplace_back(SN);
    }
    // Flush the trailing group, if any.
    if (currConsecutiveSlices.size()) {
      consecutiveSlices.emplace_back(lastDim, currConsecutiveSlices);
    }

    // Mapping from old Slices to new Nodes where a Node can either be
    // i) a merged Slice
    // ii) a merged Slice + Reshape
    // iii) a merged Slice + Transpose + Reshape
    std::unordered_map<SliceNode *, Node *> oldSlicesToNewNodes;

    for (const auto &slicePairs : consecutiveSlices) {
      unsigned_t slicesDim = slicePairs.first;
      auto &slices = slicePairs.second;

      // Nothing to merge in singleton groups (these may also carry the
      // sentinel dimension -1).
      if (slices.size() <= 1) {
        continue;
      }
      if (slicesDim != CN->getDim() &&
          ((slicesDim != CN->getDim() + 1 && slicesDim != CN->getDim() - 1) ||
           slices[0]->getResult().dims()[slicesDim] != 1)) {
        // Optimizations are possible only if:
        // 1) slices consecutive dimension is the same as concat dimension, or
        // 2) slices consecutive dimension is adjacent to the concat dimension,
        // and the size of each slice along the consecutive dimension is 1.
        // NOTE: Checking the slicesDim dimension of 0th slice is
        // sufficient, as opposed to checking every slice. If the slices
        // can be concatenated (and they're being concatenated along a
        // different dimension), then each slicesDim dim must be equal.
        continue;
      }
      if ((slicesDim == CN->getDim() + 1 && slices.size() <= 3) ||
          (slicesDim == CN->getDim() - 1 && slices.size() <= 2)) {
        // Optimization does not decrease the number of nodes.
        continue;
      }

      SliceNode *firstSlice = slices.front();
      auto *srcNode = firstSlice->getInput().getNode();
      std::vector<dim_t> endDims;

      // The merged slice spans from the first slice's start to the last
      // slice's end in every dimension.
      for (size_t i = 0, e2 = firstSlice->getResult().dims().size(); i < e2;
           i++) {
        endDims.emplace_back(slices.back()->getStart()[i] +
                             slices.back()->getResult().dims()[i]);
      }
      Node *newNode = nullptr;
      auto *newSlice = F->createSlice(firstSlice->getName(), srcNode,
                                      firstSlice->getStart(), endDims);

      // Create a reshape node based on consecutive slice dimension and
      // concat dimension.
      if (slicesDim == CN->getDim() + 1 || slicesDim == CN->getDim() - 1) {
        // Fold the (unit-sized) consecutive dimension into the concat
        // dimension.
        auto outputDimVec = newSlice->getResult().dims().vec();
        outputDimVec[CN->getDim()] *= outputDimVec[slicesDim];
        outputDimVec[slicesDim] = 1;
        auto outputDims = llvm::makeArrayRef(outputDimVec);

        Node *inputToReshape = nullptr;
        if (slicesDim == CN->getDim() + 1) {
          // The consecutive dimension comes after the concat dimension, so
          // the two must be swapped with a Transpose before reshaping.
          std::vector<unsigned_t> shuffle(outputDimVec.size());
          std::iota(shuffle.begin(), shuffle.end(), 0);
          std::swap(shuffle[slicesDim], shuffle[CN->getDim()]);
          inputToReshape = F->createTranspose(
              newSlice->getName().str() + "_Transpose", newSlice, shuffle);
        } else {
          inputToReshape = newSlice;
        }
        newNode = F->createReshape(newSlice->getName().str() + "_Reshape",
                                   inputToReshape, outputDims);
      } else {
        newNode = newSlice;
      }

      for (auto *slice : slices) {
        oldSlicesToNewNodes[slice] = newNode;
      }

      changed = true;
    }
    if (!oldSlicesToNewNodes.size()) {
      continue;
    }
    // Replace the input Slices to CN with the merged Nodes.
    std::vector<NodeValue> newConcatInputs;
    const Node *lastNewNode = nullptr;
    for (const auto &concatInput : CN->getInputs()) {
      auto *SN = dyn_cast<SliceNode>(concatInput.getNode());
      if (!SN || !oldSlicesToNewNodes.count(SN)) {
        newConcatInputs.emplace_back(concatInput);
      } else {
        auto *newNode = oldSlicesToNewNodes[SN];
        // One merged node stands in for a whole run of adjacent slices;
        // append it only once.
        if (newNode != lastNewNode) {
          lastNewNode = newNode;
          newConcatInputs.emplace_back(newNode);
        }
      }
    }
    if (newConcatInputs.size() != CN->getInputs().size()) {
      auto *newConcat =
          F->createConcat(CN->getName(), newConcatInputs, CN->getDim());
      CN->getResult().replaceAllUsesOfWith(newConcat);
    }
  }

  return changed;
}
3371 | |
3372 | /// Optimize Concat nodes. |
3373 | bool OptimizeConcatNodes::run(Function *F, const CompilationContext &cctx) { |
3374 | LOG_SCOPE(F->getLogContext(), getName()); |
3375 | bool changed = false; |
3376 | auto &nodes = F->getNodes(); |
3377 | |
3378 | // For each node: |
3379 | for (auto &node : nodes) { |
3380 | auto *CN = dyn_cast<ConcatNode>(&node); |
3381 | if (!CN) { |
3382 | continue; |
3383 | } |
3384 | NodeValue newCN = simplifyConcatNode(F, CN, cctx); |
3385 | if (newCN.getNode()) { |
3386 | CN->getResult().replaceAllUsesOfWith(newCN); |
3387 | changed = true; |
3388 | continue; |
3389 | } |
3390 | } |
3391 | return changed; |
3392 | } |
3393 | |
/// Fold Slices into Constants. This will create new Constants if necessary.
bool FoldSlicesIntoConstants::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  auto &nodes = F->getNodes();

  // For each node:
  for (auto &node : nodes) {
    auto *SN = dyn_cast<SliceNode>(&node);
    if (!SN) {
      continue;
    }
    // Only Slices whose input is a Constant can be folded.
    auto *C = dyn_cast<Constant>(SN->getInput());
    if (!C) {
      continue;
    }

    // Create new slice of the Constant.
    Tensor outT = Tensor(SN->getResult().getType());

    // Copy the sliced region of the Constant's payload into outT,
    // dispatching on the element type.
    ElemKind k = outT.getElementType();
#define TYPED_INSERT(TY, TYPEKIND)                                             \
  if (k == TYPEKIND) {                                                         \
    auto OH = outT.getHandle<TY>();                                            \
    auto IH = C->getPayloadMutable().getHandle<TY>();                          \
    IH.extractTensors(OH, SN->getStart());                                     \
  }

    TYPED_INSERT(float, ElemKind::FloatTy);
    TYPED_INSERT(float16_t, ElemKind::Float16Ty);
    TYPED_INSERT(bfloat16_t, ElemKind::BFloat16Ty);
    TYPED_INSERT(int8_t, ElemKind::Int8QTy);
    TYPED_INSERT(int16_t, ElemKind::Int16QTy);
    TYPED_INSERT(int32_t, ElemKind::Int32QTy);
    TYPED_INSERT(int32_t, ElemKind::Int32ITy);
    TYPED_INSERT(int64_t, ElemKind::Int64ITy);
    TYPED_INSERT(bool, ElemKind::BoolTy);
#undef TYPED_INSERT

    // Create a new Constant NC to hold the sliced result.
    auto *NC = F->getParent()->createConstant(C->getName(), std::move(outT));
    // Connect all Slice users with the new Slice.
    SN->getResult().replaceAllUsesOfWith(NC);
    changed = true;
  }

  return changed;
}
3442 | |
3443 | /// Simplify and canonicalize arithmetic nodes by detecting simple arithmetic |
3444 | /// identities. |
3445 | bool OptimizeArithmeticNodes::run(Function *F, const CompilationContext &cctx) { |
3446 | LOG_SCOPE(F->getLogContext(), getName()); |
3447 | bool changed = false; |
3448 | // A worklist that contains the nodes to process. |
3449 | std::vector<Node *> worklist; |
3450 | |
3451 | // Add all of the arithmetic nodes to the worklist, with a node's |
3452 | // dependencies added after itself so they are processed before the node. |
3453 | GraphPreOrderVisitor visitor(*F); |
3454 | worklist.reserve(visitor.getPreOrder().size()); |
3455 | for (auto *N : visitor.getPreOrder()) { |
3456 | if (N->isArithmetic()) { |
3457 | worklist.push_back(N); |
3458 | } |
3459 | } |
3460 | while (!worklist.empty()) { |
3461 | Node *N = worklist.back(); |
3462 | assert(N->isArithmetic() && "Must be an Arithmetic node." ); |
3463 | worklist.pop_back(); |
3464 | |
3465 | auto SNV = simplifyNode(N, F); |
3466 | if (SNV.getNode() != N) { |
3467 | N->getNthResult(ArithmeticNode::ResultIdx).replaceAllUsesOfWith(SNV); |
3468 | changed = true; |
3469 | |
3470 | auto *SN = SNV.getNode(); |
3471 | |
3472 | // The simplified node could be further simplified. Note that the |
3473 | // simplified node might not be arithmetic; it could be a splat. |
3474 | if (SN->isArithmetic()) { |
3475 | worklist.push_back(SN); |
3476 | } |
3477 | |
3478 | // The simplified node's operands could be further simplified as well. |
3479 | // Push them after the node so they are processed before the node. |
3480 | for (size_t i = 0, e = SN->getNumInputs(); i < e; i++) { |
3481 | Node *input = SN->getNthInput(i).getNode(); |
3482 | if (input->isArithmetic()) { |
3483 | worklist.push_back(input); |
3484 | } |
3485 | } |
3486 | continue; |
3487 | } |
3488 | } |
3489 | return changed; |
3490 | } |
3491 | |
/// Statically transpose Constants.
bool TransposeConstants::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  auto &nodes = F->getNodes();
  bool changed = false;
  for (auto &node : nodes) {
    auto *TN = dyn_cast<TransposeNode>(&node);
    if (!TN) {
      continue;
    }
    // Handle both Transpose(Constant) and Transpose(Quantize(Constant)).
    auto *Q = dyn_cast<QuantizeNode>(TN->getInput());
    auto *C = dyn_cast<Constant>(Q ? Q->getInput() : TN->getInput());
    // Bail if there is no Constant, or if the intermediate Quantize has other
    // users that still need its untransposed result.
    if (!C || (Q && !Q->hasOneUse())) {
      continue;
    }
    // Create a new Constant NC to hold the transposed result.
    auto cTy = F->getParent()->uniqueTypeWithNewShape(
        C->getType(), TN->getResult().getType());
    auto *NC =
        F->getParent()->createConstant(cTy, C->getName(), TN->getLayout());
    // Transpose the value of C into NC.
    genericTranspose(&C->getPayload(), &NC->getPayloadMutable(),
                     TN->getShuffle());
    // Keep the payload's type in sync with the Constant's transposed type.
    NC->getPayloadMutable().setType(NC->getType());
    // Create a new transposed Quantize, if needed.
    NodeValue NN(NC);
    if (Q) {
      auto qTy = F->getParent()->uniqueTypeWithNewShape(
          Q->getResult().getType(), TN->getResult().dims());
      auto *NQ = F->createQuantize(Q->getName(), NC, qTy);
      NN = NodeValue(NQ);
    }
    // Rewrite uses of TN to reference NC or NQ.
    TN->getResult().replaceAllUsesOfWith(NN);
    changed = true;
  }
  return changed;
}
3530 | |
3531 | namespace { |
3532 | |
/// A helper type for hashing Node pointers when they are used as keys in hash
/// maps.
3535 | struct NodeHasher { |
3536 | size_t operator()(Node *N) const { return N->getHash(); } |
3537 | }; |
3538 | |
3539 | /// A helper type implementing the Node equality predicate that can be used |
3540 | /// when Node pointers are used as keys in hash maps. |
3541 | struct NodeEq { |
3542 | bool operator()(const Node *lhs, const Node *rhs) const { |
3543 | return lhs->isEqual(*rhs); |
3544 | } |
3545 | }; |
3546 | |
/// Array with node kinds which are excepted (skipped) from CSE optimization,
/// i.e. node kinds for which the output is not necessarily identical even when
/// the nodes compare equal. Such nodes include Touch nodes which allocate
/// buffers without initialization or nodes which generate random data.
static const std::vector<Kinded::Kind> CSENodeExceptions = {
    Kinded::Kind::TouchNodeKind,
};
3554 | |
/// This visitor is used to walk the graph and perform a common subexpression
/// evaluation.
struct CSEVisitor : NodeWalker {
  // Mapping from the original node to its canonical representation under CSE.
  std::unordered_map<Node *, Node *, NodeHasher, NodeEq> cseNodes_;
  // Set of visited nodes.
  std::unordered_set<Node *> visitedNodes_;

  /// This callback is called before visiting the children of \p N.
  void pre(Node *parent, Node *N) override {
    // Put the node into a visited set to make sure it is visited
    // only once.
    visitedNodes_.insert(N);
  }

  /// This callback is called after visiting the children of \p N.
  /// It means that all of its dependencies are processed already.
  void post(Node *parent, Node *N) override {
    // Try to find a node equivalent to the current one.
    auto FoundI = cseNodes_.find(N);
    if (FoundI == cseNodes_.end()) {
      // No node CSE-equivalent to the current one has been seen yet.
      // Remember this node, so that the next occurrence can be
      // replaced by this one.
      cseNodes_.insert({N, N});
      assert(cseNodes_.find(N) != cseNodes_.end());
      return;
    }
    Node *foundN = FoundI->second;

    // Same node cannot be visited.
    assert(N != foundN);

    // Replace current node by a found node, which is
    // equivalent to it.
    assert(N->isEqual(*foundN));

    // Skip CSE node exceptions: kinds whose results are not guaranteed to be
    // identical even when the nodes themselves compare equal.
    if (std::find(CSENodeExceptions.begin(), CSENodeExceptions.end(),
                  N->getKind()) != CSENodeExceptions.end()) {
      return;
    }

    // Replace node: redirect every result of N to the corresponding result
    // of the canonical node.
    for (unsigned i = 0; i < N->getNumResults(); i++) {
      NodeValue FV(foundN, i);
      N->getNthResult(i).replaceAllUsesOfWith(FV);
    }
    // TODO: Erase N during CSE? If we don't do it here,
    // DCE will remove it later anyways.
  }

  /// Make sure that each node is processed only once.
  bool shouldVisit(Node *parent, Node *N) override {
    return visitedNodes_.count(N) == 0;
  }
};
3612 | |
3613 | /// A helper type for hashing Constant pointers when they are used as keys in |
3614 | /// hash maps for deduplication. The hash is based on the type of the Constant |
3615 | /// (element type, dimensions), as well as a constant number of elements from |
/// the backing Tensor to balance collisions with hash calculation time.
3617 | struct ConstsHasherDedup { |
3618 | size_t operator()(Constant *V) const { |
3619 | auto hash = llvm::hash_value(V->getType()); |
3620 | auto &T = V->getPayload(); |
3621 | // Only use the first 8 elements in the hash. It's likely that if two |
3622 | // tensors have different content they will diverge quickly. Fall back to |
3623 | // full equality check in ConstsEqDedup. |
3624 | constexpr dim_t maxNumEls = 8; |
3625 | dim_t numEls = std::min((dim_t)T.getType().size(), maxNumEls); |
3626 | dim_t bufSize = T.getType().getElementSize() * numEls; |
3627 | auto *data = T.getUnsafePtr(); |
3628 | for (size_t i = 0; i < bufSize; i++) { |
3629 | hash = llvm::hash_combine(hash, data[i]); |
3630 | } |
3631 | return hash; |
3632 | } |
3633 | }; |
3634 | |
3635 | /// A helper type implementing the Constants equality predicate that can be |
3636 | /// used when Constant pointers are used as keys in hash maps for |
3637 | /// deduplication. |
3638 | struct ConstsEqDedup { |
3639 | bool operator()(const Constant *lhs, const Constant *rhs) const { |
3640 | // Only consider Constants for deduplication if they have the same type. |
3641 | if (lhs->getType() != rhs->getType()) { |
3642 | return false; |
3643 | } |
3644 | // Only dedup Constants if they're bit exact matches. |
3645 | return lhs->getPayload().isBitwiseEqual(rhs->getPayload()); |
3646 | } |
3647 | }; |
3648 | |
3649 | } // namespace |
3650 | |
3651 | /// Deduplicates Constants in the Module \p M. Applicable Constants for |
3652 | /// deduplication must have the same data. \returns whether any Constants were |
3653 | /// deduplicated. |
3654 | static bool deduplicateConstants(Module *M) { |
3655 | // Map from Constants to other Constants that are equivalent for purposes of |
3656 | // deduplication. |
3657 | std::unordered_map<Constant *, Constant *, ConstsHasherDedup, ConstsEqDedup> |
3658 | duplicateConstants; |
3659 | |
3660 | bool changed = false; |
3661 | for (auto &C : M->getConstants()) { |
3662 | // Only perform deduplication of consts with given max number of elements. |
3663 | size_t maxNumEls = constDedupSizeOpt; |
3664 | size_t numEls = C->getType()->size(); |
3665 | if ((maxNumEls != 0) && (numEls > maxNumEls)) { |
3666 | continue; |
3667 | } |
3668 | |
3669 | // Try to find a Constant that has the same data as the current one. |
3670 | auto foundI = duplicateConstants.find(C); |
3671 | if (foundI == duplicateConstants.end()) { |
3672 | // No node equivalent to the current one has been seen yet. Remember |
3673 | // this Constant, so that the next occurrence can be replaced by this |
3674 | // one. |
3675 | duplicateConstants.emplace(C, C); |
3676 | assert(duplicateConstants.find(C) != duplicateConstants.end()); |
3677 | continue; |
3678 | } |
3679 | Constant *foundC = foundI->second; |
3680 | assert(C != foundC && "Constants should not be visited multiple times." ); |
3681 | |
3682 | // Replace current Constant by a found Constant, which is equivalent to |
3683 | // it. |
3684 | C->getOutput().replaceAllUsesOfWith(foundC); |
3685 | changed = true; |
3686 | } |
3687 | return changed; |
3688 | } |
3689 | |
3690 | /// Common Subexpression Elimination. |
3691 | bool CSE::run(Function *F, const CompilationContext &cctx) { |
3692 | LOG_SCOPE(F->getLogContext(), getName()); |
3693 | CSEVisitor visitor; |
3694 | |
3695 | bool changed = false; |
3696 | if (cctx.optimizationOpts.enableConstantDeduplication) { |
3697 | changed |= deduplicateConstants(F->getParent()); |
3698 | } |
3699 | |
3700 | // Perform CSE on all nodes. |
3701 | for (auto &N : F->getNodes()) { |
3702 | N.visit(nullptr, &visitor); |
3703 | } |
3704 | // TODO: Change Visitors to return whether they modified the Function they |
3705 | // are contained in. For now conservatively set changed to true; |
3706 | changed = true; |
3707 | return changed; |
3708 | } |
3709 | |
3710 | /// Fold Nodes into SplatNodes. |
3711 | bool OptimizeSplat::run(Function *F, const CompilationContext &cctx) { |
3712 | LOG_SCOPE(F->getLogContext(), getName()); |
3713 | bool changed = false; |
3714 | for (Node &node : F->getNodes()) { |
3715 | // Slice(Splat(args)) -> Splat(args') |
3716 | if (SliceNode *sliceNode = dyn_cast<SliceNode>(&node)) { |
3717 | SplatNode *splatNode = dyn_cast<SplatNode>(sliceNode->getInput()); |
3718 | if (!splatNode) { |
3719 | continue; |
3720 | } |
3721 | SplatNode *newSplatNode = |
3722 | F->createSplat(sliceNode->getName(), sliceNode->getResult().getType(), |
3723 | splatNode->getValue()); |
3724 | sliceNode->getResult().replaceAllUsesOfWith(newSplatNode); |
3725 | changed = true; |
3726 | continue; |
3727 | } |
3728 | |
3729 | // Clip(Splat(args)) -> Splat(args') |
3730 | if (ClipNode *clipNode = dyn_cast<ClipNode>(&node)) { |
3731 | SplatNode *splatNode = dyn_cast<SplatNode>(clipNode->getInput()); |
3732 | if (!splatNode) { |
3733 | continue; |
3734 | } |
3735 | const float newSplatVal = |
3736 | std::min(std::max(splatNode->getValue(), clipNode->getMin()), |
3737 | clipNode->getMax()); |
3738 | |
3739 | SplatNode *newSplatNode = nullptr; |
3740 | if (newSplatVal == splatNode->getValue()) { |
3741 | // No need to crate a new Splat. |
3742 | newSplatNode = splatNode; |
3743 | } else { |
3744 | newSplatNode = F->createSplat( |
3745 | splatNode->getName().str() + clipNode->getName().str(), |
3746 | splatNode->getResult().getType(), newSplatVal); |
3747 | } |
3748 | |
3749 | clipNode->getResult().replaceAllUsesOfWith(newSplatNode->getResult()); |
3750 | changed = true; |
3751 | continue; |
3752 | } |
3753 | } |
3754 | return changed; |
3755 | } |
3756 | |
/// Replace a Gather whose indices are a single scalar Constant with an
/// equivalent Slice + Reshape pair.
bool GatherToSlice::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;

  for (auto &node : F->getNodes()) {
    auto *GN = dyn_cast<GatherNode>(&node);
    if (!GN) {
      continue;
    }

    auto data = GN->getData();
    auto *indices = dyn_cast<Constant>(GN->getIndices());

    // Only handling scalar constant index value
    if (!indices || indices->getPayload().size() != 1) {
      continue;
    }

    // Read the scalar index out of the Constant, dispatching on its
    // element type.
    dim_t index = 0;
    size_t axis = GN->getBatchDims();
    auto elementKind = indices->getElementType();
    if (elementKind == ElemKind::Int64ITy) {
      index = (size_t)indices->getHandle<int64_t>().raw(0);
    } else if (elementKind == ElemKind::Int32ITy) {
      index = (size_t)indices->getHandle<int32_t>().raw(0);
    } else {
      llvm_unreachable("GatherToSlice: Unsupported indices type");
    }

    // Build the slice range: [index, index + 1) along the gather axis,
    // the full extent along every other dimension.
    std::vector<dim_t> start;
    std::vector<dim_t> end;
    for (size_t i = 0; i < data.dims().size(); ++i) {
      if (i == axis) {
        start.push_back(index);
        end.push_back(index + 1);
      } else {
        start.push_back(0);
        end.push_back(data.dims()[i]);
      }
    }

    // The Reshape adjusts the sliced result (which keeps a unit dimension
    // along the axis) to the original Gather's output shape.
    auto name = GN->getName();
    auto *SN = F->createSlice(name, data, start, end);
    auto *RN = F->createReshape(name, SN, GN->getResult().dims());

    GN->getResult().replaceAllUsesOfWith(RN->getResult());
    changed = true;
  }

  return changed;
}
3808 | |
/// Optimize TransposeNode into ReshapeNode when it actually moves no data.
bool OptimizeTransposeIntoReshape::run(Function *F,
                                       const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;

  for (auto &node : F->getNodes()) {
    auto *TR = dyn_cast<TransposeNode>(&node);
    if (!TR) {
      continue;
    }
    auto inputNode = TR->getInput();
    auto inputDims = inputNode.dims();
    auto outputDims = TR->getResult().dims();
    // The transformation is not possible if alignments different from 1 are
    // used for any dimension.
    if (!inputNode.getType()->isEqual(F->getParent()->uniqueTypeWithNewShape(
            inputNode.getType(), inputDims))) {
      continue;
    }
    if (!TR->getResult().getType()->isEqual(
            F->getParent()->uniqueTypeWithNewShape(TR->getResult().getType(),
                                                   outputDims))) {
      continue;
    }
    // Transpose moves no data if input/output dimensions match after they both
    // drop dimensions of size 1. E.g. transposing [1 5 1 15] into [5 15 1 1]
    // produces vectors (1, 3) for both dimensions so optimization is executed.
    auto shuffle = TR->getShuffle();
    ShapeVector inDims;
    ShapeVector outDims;
    // Collect, for each side, the input-dimension indices of every non-unit
    // dimension (the output side is mapped through the shuffle so both lists
    // use the same index space).
    for (size_t i = 0; i < inputDims.size(); i++) {
      if (inputDims[i] != 1) {
        inDims.push_back(i);
      }
      if (outputDims[i] != 1) {
        outDims.push_back(shuffle[i]);
      }
    }
    // Equal sequences mean the non-unit dimensions keep their relative order,
    // so the transpose only rearranges unit dimensions and moves no data.
    if (inDims != outDims) {
      continue;
    }
    auto *RS =
        F->createReshape(TR->getName(), inputNode, outputDims, TR->getLayout());
    TR->getResult().replaceAllUsesOfWith(RS);
    changed = true;
  }

  return changed;
}
3859 | |
/// Eliminate nodes which don't do anything useful.
bool EliminateNoop::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;

  // \returns true if \p node is a no-op; on success \p input and \p output
  // are set to the values that should be stitched together in its place.
  auto isNoop = [](const Node &node, NodeValue &input,
                   NodeValue &output) -> bool {
    // Auto-select input/output, if there is just one. For other cases it
    // will be handled on per-operator basis below.
    if (node.getNumInputs() == 1) {
      input = node.getNthInput(0);
    }
    if (node.getNumResults() == 1) {
      output = node.getNthResult(0);
    }

    // For some nodes it's enough just to compare input and output types to
    // determine if they are noop.
    if (isa<PadNode>(&node) || isa<SliceNode>(&node) || isa<TileNode>(&node) ||
        isa<BroadcastNode>(&node)) {
      return input.getType() == output.getType();
    }

    // Operator-specific analysis.
    if (auto *APN = dyn_cast<AvgPoolNode>(&node)) {
      // Pooling with unit kernels, unit strides and zero padding is an
      // identity.
      input = APN->getInput();
      return isUniformArray(APN->getKernels(), 1u) &&
             isUniformArray(APN->getStrides(), 1u) &&
             isUniformArray(APN->getPads(), 0u);
    } else if (auto *MPN = dyn_cast<MaxPoolNode>(&node)) {
      input = MPN->getInput();
      output = MPN->getResult();
      return isUniformArray(MPN->getKernels(), 1u) &&
             isUniformArray(MPN->getStrides(), 1u) &&
             isUniformArray(MPN->getPads(), 0u);
    } else if (auto *SN = dyn_cast<SelectNode>(&node)) {
      // A Select whose condition is a uniform constant always picks the same
      // operand.
      auto *cond = getConstant(SN->getCond());
      bool val = false;
      if (cond && isUniformConstant<bool>(*cond, val)) {
        input = val ? SN->getLHS() : SN->getRHS();
        return true;
      }
    }

    return false;
  };

  for (auto &node : F->getNodes()) {
    NodeValue input, output;
    if (isNoop(node, input, output)) {
      assert(input != NodeValue() && output != NodeValue() &&
             "Sanity check that input and output are set");
      output.replaceAllUsesOfWith(input);
      changed = true;
    }
  }

  return changed;
}
3919 | |
3920 | /// Optimize reshape nodes. |
3921 | bool OptimizeReshape::run(Function *F, const CompilationContext &cctx) { |
3922 | LOG_SCOPE(F->getLogContext(), getName()); |
3923 | bool changed = false; |
3924 | for (auto &node : F->getNodes()) { |
3925 | auto *reshapeNode = dyn_cast<ReshapeNode>(&node); |
3926 | if (!reshapeNode) { |
3927 | continue; |
3928 | } |
3929 | auto inputNode = reshapeNode->getNthInput(ReshapeNode::InputIdx); |
3930 | // Eliminate ReshapeNode when the input is already the correct shape. |
3931 | if (inputNode.dims() == reshapeNode->getResult().dims()) { |
3932 | reshapeNode->getResult().replaceAllUsesOfWith(inputNode); |
3933 | continue; |
3934 | } |
3935 | // Reshape(Splat(args)) -> Splat(args'). |
3936 | auto *splatNode = dyn_cast<SplatNode>(inputNode); |
3937 | if (splatNode && splatNode->hasOneUse()) { |
3938 | // Splat node with more than one use can not be transformed, otherwise |
3939 | // we would increase the number of splats, which may lead to increased |
3940 | // memory consumption during the execution of the NN model. |
3941 | auto *newSplatNode = F->createSplat(splatNode->getName(), |
3942 | reshapeNode->getResult().getType(), |
3943 | splatNode->getValue()); |
3944 | reshapeNode->getResult().replaceAllUsesOfWith(newSplatNode); |
3945 | changed = true; |
3946 | continue; |
3947 | } |
3948 | // Reshape(Reshape(x)) -> Reshape(x). |
3949 | auto *reshapeNodeInput = dyn_cast<ReshapeNode>(inputNode); |
3950 | if (reshapeNodeInput && reshapeNodeInput->hasOneUse()) { |
3951 | auto *newReshape = F->createReshape( |
3952 | reshapeNode->getName(), reshapeNodeInput->getInput(), |
3953 | reshapeNode->getResult().dims(), reshapeNode->getLayout()); |
3954 | reshapeNode->getResult().replaceAllUsesOfWith(newReshape); |
3955 | changed = true; |
3956 | continue; |
3957 | } |
3958 | // Reshape(Constant) -> Constant' or |
3959 | // Reshape(Quantize(Constant)) -> Quantize'(Constant'). |
3960 | auto *Q = dyn_cast<QuantizeNode>(inputNode); |
3961 | auto *C = dyn_cast<Constant>(Q ? Q->getInput() : inputNode); |
3962 | if (C && (!Q || Q->hasOneUse())) { |
3963 | // Create a new Constant with the dims of the reshape. |
3964 | auto layout = |
3965 | CanonicalTensorLayout::getInstance().getNthResultLayoutRequirements( |
3966 | reshapeNode, ReshapeNode::ResultIndices::ResultIdx); |
3967 | auto cTy = F->getParent()->uniqueTypeWithNewShape( |
3968 | C->getType(), reshapeNode->getResult().getType()); |
3969 | auto *newC = F->getParent()->createConstant(cTy, C->getName(), layout); |
3970 | // Create an unowned view of the original tensor with the correct shape, |
3971 | // and assign it to the new Constant. |
3972 | Tensor reshapedT = C->getPayload().getUnowned(reshapeNode->getDims()); |
3973 | newC->assign(&reshapedT); |
3974 | // Create a new Quantize with the dims of the reshape, if needed. |
3975 | NodeValue newN(newC); |
3976 | if (Q) { |
3977 | auto qTy = F->getParent()->uniqueTypeWithNewShape( |
3978 | Q->getResult().getType(), reshapeNode->getResult().dims()); |
3979 | auto *newQ = F->createQuantize(Q->getName(), newC, qTy); |
3980 | newN = NodeValue(newQ); |
3981 | } |
3982 | reshapeNode->getResult().replaceAllUsesOfWith(newN); |
3983 | changed = true; |
3984 | continue; |
3985 | } |
3986 | } |
3987 | return changed; |
3988 | } |
3989 | |
/// Helper to optimize Resize nodes.
template <typename ResizeNodeType>
static bool optimizeResize(Function *F, const CompilationContext &cctx) {
  bool changed = false;
  for (auto &node : F->getNodes()) {
    auto *resizeNode = dyn_cast<ResizeNodeType>(&node);
    CONTINUE_IF_NOT(resizeNode);
    // Remove identity resize (same input and output type).
    auto inpType = resizeNode->getInput().getType();
    auto outType = resizeNode->getResult().getType();
    if (inpType->isEqual(outType)) {
      resizeNode->getResult().replaceAllUsesOfWith(resizeNode->getInput());
      changed = true;
      continue;
    }
    // Dimensions which are resized from unitary sizes should use Tile nodes.
    // We pull out Tile nodes from the Resize output such that the Resize node
    // operates on smaller sizes thus reducing the complexity. We create Tile
    // nodes in the decreasing order of the dimensions to increase locality.
    auto inpDims = resizeNode->getInput().dims();
    auto outDims = resizeNode->getResult().dims();
    std::vector<dim_t> newOutDims = outDims.vec();
    std::vector<unsigned_t> axes;
    std::vector<unsigned_t> tiles;
    // Find all dimensions resized from size 1; these become Tile axes and
    // are shrunk back to 1 in the intermediate Resize output.
    for (ssize_t idx = outDims.size() - 1; idx >= 0; idx--) {
      if ((inpDims[idx] == 1) && (outDims[idx] > 1)) {
        newOutDims[idx] = 1;
        axes.push_back(idx);
        tiles.push_back(outDims[idx]);
      }
    }
    // Nothing to pull out if no dimension is resized from a unitary size.
    CONTINUE_IF_NOT(axes.size());
    auto newOutType =
        F->getParent()->uniqueTypeWithNewShape(outType, newOutDims);
    // Create Resize node (skipped entirely when the Tiles alone reproduce
    // the output).
    NodeValue newOut = resizeNode->getInput();
    if (!inpType->isEqual(newOutType)) {
      // The std::is_same conditions are compile-time constants; only the
      // branch matching ResizeNodeType is ever taken.
      if (std::is_same<ResizeNodeType, ResizeNearestNode>::value) {
        newOut = F->createResizeNearest(node.getName(), newOut, newOutType);
      } else if (std::is_same<ResizeNodeType, ResizeBilinearNode>::value) {
        newOut = F->createResizeBilinear(node.getName(), newOut, newOutType);
      } else {
        llvm_unreachable("Resize node type not supported!");
      }
    }
    // Create Tile nodes.
    newOut =
        F->createTile(node.getName().str() + "." + "Tile", newOut, tiles, axes);
    resizeNode->getResult().replaceAllUsesOfWith(newOut);
    changed = true;
    continue;
  }
  return changed;
}
4044 | |
4045 | /// Optimize Resize nodes. |
4046 | bool OptimizeResize::run(Function *F, const CompilationContext &cctx) { |
4047 | LOG_SCOPE(F->getLogContext(), getName()); |
4048 | bool changed = false; |
4049 | changed |= optimizeResize<ResizeNearestNode>(F, cctx); |
4050 | changed |= optimizeResize<ResizeBilinearNode>(F, cctx); |
4051 | return changed; |
4052 | } |
4053 | |
4054 | /// Optimize Insert nodes. |
4055 | bool OptimizeInsert::run(Function *F, const CompilationContext &cctx) { |
4056 | LOG_SCOPE(F->getLogContext(), getName()); |
4057 | bool changed = false; |
4058 | for (auto &node : F->getNodes()) { |
4059 | auto *insertNode = dyn_cast<InsertTensorNode>(&node); |
4060 | CONTINUE_IF_NOT(insertNode); |
4061 | // When the "Big" tensor is a Splat which is entirely filled by the |
4062 | // "Small" tensor then we replace the Splat with a Touch node to remove |
4063 | // the initialization overhead of the Splat which is not needed. |
4064 | NodeValue big = insertNode->getBig(); |
4065 | NodeValue small = insertNode->getSmall(); |
4066 | auto bigDims = big.dims().vec(); |
4067 | auto smallDims = small.dims().vec(); |
4068 | smallDims[insertNode->getAxis()] *= insertNode->getCount(); |
4069 | CONTINUE_IF_NOT(isUniformArray(insertNode->getStart(), dim_t(0)) && |
4070 | dyn_cast<SplatNode>(big) && (bigDims == smallDims)); |
4071 | NodeValue touch = |
4072 | F->createTouch(node.getName().str() + "." + "Touch" , big.getType()); |
4073 | node.setNthInput(InsertTensorNode::BigIdx, touch); |
4074 | changed = true; |
4075 | continue; |
4076 | } |
4077 | return changed; |
4078 | } |
4079 | |
/// Optimize: Max(Splat(), otherInput) or Max(otherInput, Splat()) for
/// quantized operations.
/// Splat and Max can be eliminated if Splat value cannot impact the result.
/// For example, Max and Splat can be removed if splat value is smaller
/// than quantization range [min, max]. The same reasoning also eliminates a
/// quantized Relu whose output type cannot represent values below zero.
/// \returns if anything was changed in the given function.
static bool optimizeQuantizedMaxSplat(Function *F) {
  LOG_SCOPE(F->getLogContext(), "optimizeQuantizedMaxSplat" )

  bool changed = false;
  // The following optimizations need to be performed after all
  // quantize/dequantize/rescale optimizations are done.
  for (auto &node : F->getNodes()) {
    // Potentially nop quantized Max can be eliminated.
    // Likely MaxNode has same types for LHS/RHS and Result, make sure
    // it's the case.
    if (auto *MN = dyn_cast<MaxNode>(&node)) {
      if (!MN->getResult().getType()->isQuantizedType() ||
          MN->getResult().getType() != MN->getLHS().getType() ||
          MN->getResult().getType() != MN->getRHS().getType()) {
        continue;
      }

      // Check for input Splat node.
      if (!isa<SplatNode>(MN->getLHS()) && !isa<SplatNode>(MN->getRHS())) {
        continue;
      }

      // Identify which operand is the Splat and which is the surviving input.
      Node *splatNode =
          isa<SplatNode>(MN->getLHS()) ? MN->getLHS() : MN->getRHS();
      Node *otherInput =
          isa<SplatNode>(MN->getLHS()) ? MN->getRHS() : MN->getLHS();

      // If splat value is smaller than values that can be covered by
      // quantization [min,max] range then just remove MaxNode operation:
      // every representable value already satisfies value >= splatValue.
      float splatValue = (dyn_cast<SplatNode>(splatNode))->getValue();
      float min = MN->getResult().getType()->getQuantizedValueRange().first;
      if (splatValue <= min) {
        changed = true;
        MN->getResult().replaceAllUsesOfWith(otherInput);
      }
      continue;
    }
    // Potentially nop quantized ReLU can be eliminated.
    if (auto *RN = dyn_cast<ReluNode>(&node)) {
      if (!RN->getResult().getType()->isQuantizedType() ||
          RN->getResult().getType() != RN->getInput().getType()) {
        continue;
      }

      Node *input = RN->getInput();

      // If zero is smaller or equal than values that can be covered by
      // quantization [min,max] range then just remove ReluNode operation:
      // the quantized type cannot hold negatives for Relu to clamp.
      float min = RN->getResult().getType()->getQuantizedValueRange().first;
      if (0.0f <= min) {
        changed = true;
        RN->getResult().replaceAllUsesOfWith(input);
      }
      continue;
    }
  }
  return changed;
}
4144 | |
/// \returns A value representing the given \p constant converted
/// to the destination type \p dstTy. If the conversion is not
/// possible, this method returns NodeValue().
/// The conversion may mutate \p constant in place (when it has a single use)
/// or create a fresh Constant in \p mod; callers must treat the returned
/// NodeValue as the only valid handle afterwards.
static NodeValue convertConstant(Module &mod, Constant &constant,
                                 TypeRef dstTy) {
  // Sort out the easy case first: no conversion needed at all.
  if (constant.getType() == dstTy) {
    return constant.getOutput();
  }
  // Returns a Constant that is safe to mutate: either `constant` itself (if
  // uniquely used) or a single-use copy of it, with its node type already
  // switched to dstTy. The payload tensor still holds the OLD element type
  // and must be converted by the caller of this lambda.
  auto modifyConstantTyAndGet = [&]() -> Constant & {
    Constant *oneUseCst = getUniquelyUsedConstant(&mod, constant);
    assert(oneUseCst &&
           "We should always be able to get a constant from a constant!" );
    // This type setting updates the type of the node.
    // The underlying tensor still needs to be converted after this call.
    oneUseCst->setType(Storage::OutputIdx, dstTy);
    return *oneUseCst;
  };
  const Tensor &tensor = constant.getPayload();
  // Dispatch on (source element kind, destination element kind).
  switch (tensor.getElementType()) {
  case ElemKind::FloatTy:
  case ElemKind::Float16Ty:
  case ElemKind::BFloat16Ty:
    switch (dstTy->getElementType()) {
    case ElemKind::FloatTy:
    case ElemKind::Float16Ty:
    case ElemKind::BFloat16Ty: {
      // Plain conversion:
      // {FloatTy, Float16Ty, BFloat16Ty} -> {FloatTy, Float16Ty, BFloat16Ty}.
      Constant &constantToBeModified = modifyConstantTyAndGet();
      constantToBeModified.getPayloadMutable().convertToType(
          dstTy->getElementType());
      return constantToBeModified.getOutput();
    }
    case ElemKind::Int64QTy:
    case ElemKind::Int32QTy:
    case ElemKind::Int16QTy:
    case ElemKind::Int8QTy: {
      // Quantization: {FloatTy, Float16Ty, BFloat16Ty} -> Quantized type.
      Constant &constantToBeModified = modifyConstantTyAndGet();
      TensorQuantizationParams params{dstTy->getScale(), dstTy->getOffset()};
      Tensor &tensorToBeModified = constantToBeModified.getPayloadMutable();
      // Right now we only quantize fp32 value.
      // Add an assert on that, so that if it changes, we adapt the
      // following code. Adapting the code would required to
      // teach quantizeTensor how to deal with Float16Ty.
      assert(tensor.getType().isFPType() &&
             "Type quantization not implemented" );
      tensorToBeModified = quantization::quantizeTensor(
          tensorToBeModified, params, dstTy->getElementType());
      return constantToBeModified.getOutput();
    }
    case ElemKind::Int32ITy: {
      // Plain conversion: {FloatTy} -> {Int32ITy}.
      Constant &constantToBeModified = modifyConstantTyAndGet();
      constantToBeModified.getPayloadMutable().convertToType(
          dstTy->getElementType());
      return constantToBeModified.getOutput();
    }
    default:
      // Unsupported combinations fall through here, e.g.:
      // Quantization: {FloatTy, Float16Ty, BFloat16Ty} -> Int[16|32]QTy.
      // Plain conversion: {FloatTy, Float16Ty, BFloat16Ty} -> Int64ITy.
      return NodeValue();
    }
  case ElemKind::UInt8FusedQTy: {
    // Fused rowwise-quantized data: only the fused-FP16 variant is supported
    // as a destination; a brand-new Constant is created for it.
    if (dstTy->getElementType() != ElemKind::UInt8FusedFP16QTy) {
      return NodeValue();
    }
    auto *NC =
        mod.createConstant(dstTy, constant.getName(), constant.getLayout());
    NC->getPayloadMutable() =
        tensor.getCopyConvertedToType(dstTy->getElementType());
    return NC->getOutput();
  }
  case ElemKind::Int64ITy:
  case ElemKind::Int32ITy:
    switch (dstTy->getElementType()) {
    case ElemKind::Int32ITy:
    case ElemKind::Int64ITy: {
      // Plain conversion: {Int64ITy, Int32ITy} -> {Int64ITy, Int32ITy}.
      Constant &constantToBeModified = modifyConstantTyAndGet();
      constantToBeModified.getPayloadMutable().convertToType(
          dstTy->getElementType());
      return constantToBeModified.getOutput();
    }
    case ElemKind::FloatTy: {
      // Plain conversion: {Int64ITy, Int32ITy} -> FloatTy.
      Constant &constantToBeModified = modifyConstantTyAndGet();
      constantToBeModified.getPayloadMutable().convertToType(
          dstTy->getElementType());
      return constantToBeModified.getOutput();
    }

    default:
      return NodeValue();
    }
  default:
    // For now we don't see other quantize, dequantize, or rescale nodes
    // directly attached to constants.
    // Thus don't add code that will never be executed.
    // Dequantization: Int[8|16|32]QTy -> {FloatTy, Float16Ty, BFloat16Ty,
    // Int64I}. Rescale: Int[8|16|32]QTy -> Int[8|16|32]QTy. Plain conversion:
    // Int64ITy -> {FloatTy, Float16Ty, BFloat16Ty}. Quantization: Int64ITy ->
    // Int[8|16|32]QTy.
    return NodeValue();
  }
}
4251 | |
4252 | /// Compute number of significant bits that are used to represent data of type |
4253 | /// \p kind. For FP, it is the number of bits in mantissa, for integers it's the |
4254 | /// number of bits except sign bit. |
4255 | /// \p returns number of significant bits of \p kind. |
4256 | /// TODO: Currently, for all supported types wider mantissa also means wider |
4257 | /// exponent. If we add a type for which this is not true, we should check both |
4258 | /// mantissa and exponent. |
4259 | static size_t numSignificantBits(ElemKind kind) { |
4260 | switch (kind) { |
4261 | case ElemKind::BoolTy: |
4262 | return std::numeric_limits<bool>::digits; |
4263 | case ElemKind::Int8QTy: |
4264 | return std::numeric_limits<int8_t>::digits; |
4265 | case ElemKind::UInt8QTy: |
4266 | case ElemKind::UInt8FusedQTy: |
4267 | return std::numeric_limits<uint8_t>::digits; |
4268 | case ElemKind::Float16Ty: |
4269 | // Custom type with layout 0 00000 0000000000. |
4270 | return 10; |
4271 | case ElemKind::BFloat16Ty: |
4272 | // bfloat16 has 8 significant bits. |
4273 | return 8; |
4274 | case ElemKind::Int16QTy: |
4275 | return std::numeric_limits<int16_t>::digits; |
4276 | case ElemKind::FloatTy: |
4277 | return std::numeric_limits<float>::digits; |
4278 | case ElemKind::Float64Ty: |
4279 | return std::numeric_limits<double>::digits; |
4280 | case ElemKind::Int32QTy: |
4281 | case ElemKind::Int32ITy: |
4282 | return std::numeric_limits<int32_t>::digits; |
4283 | case ElemKind::Int64ITy: |
4284 | return std::numeric_limits<int64_t>::digits; |
4285 | default: |
4286 | // Avoid compiler warning. |
4287 | break; |
4288 | } |
4289 | llvm_unreachable("Unknown type!" ); |
4290 | } |
4291 | |
4292 | /// Returns true if casting value from \p srcTy to \p destTy may change it. As |
4293 | /// implication of this, casting value from \p srcTy to \p destTy and back may |
4294 | /// produce different value than before cast. |
4295 | static bool isValueChangingCast(TypeRef srcTy, TypeRef destTy) { |
4296 | // FP-to-Int conversion may lead to loss of fraction, so it's not NOOP. |
4297 | if (srcTy->isFPType() && !destTy->isFPType()) { |
4298 | return true; |
4299 | } |
4300 | // Narrowing transform (e.g. int64 to int32) may lead to loss of |
4301 | // significant senior bits, so it's not NOOP. |
4302 | ElemKind srcElKind = srcTy->getElementType(); |
4303 | ElemKind convElKind = destTy->getElementType(); |
4304 | if (numSignificantBits(srcElKind) > numSignificantBits(convElKind)) { |
4305 | return true; |
4306 | } |
4307 | return false; |
4308 | } |
4309 | |
4310 | /// Optimize away redundant ClipNodes. |
4311 | /// We basically turn "Clip(Clip(Clip(A)))" to "Clip(A)". |
4312 | bool OptimizeClips::run(Function *F, const CompilationContext &cctx) { |
4313 | LOG_SCOPE(F->getLogContext(), getName()); |
4314 | |
4315 | bool changed = false; |
4316 | for (Node &node : F->getNodes()) { |
4317 | ClipNode *clip = dyn_cast<ClipNode>(&node); |
4318 | if (!clip) { |
4319 | continue; |
4320 | } |
4321 | float min = clip->getMin(); |
4322 | float max = clip->getMax(); |
4323 | if (auto *clipPrev = dyn_cast<ClipNode>(clip->getInput().getNode())) { |
4324 | float minPrev = clipPrev->getMin(); |
4325 | float maxPrev = clipPrev->getMax(); |
4326 | auto *newClip = |
4327 | F->createClip(clipPrev->getName(), clipPrev->getInput().getNode(), |
4328 | std::max(minPrev, min), std::min(maxPrev, max)); |
4329 | clip->getResult().replaceAllUsesOfWith(newClip); |
4330 | changed = true; |
4331 | continue; |
4332 | } |
4333 | |
4334 | // We can fold Clip(Relu) -> Clip' |
4335 | if (ReluNode *relu = dyn_cast<ReluNode>(clip->getInput())) { |
4336 | const float newMin = std::max(0.0f, min); |
4337 | ClipNode *newClip = F->createClip(clip->getName().str() + "_relu" , |
4338 | relu->getInput(), newMin, max); |
4339 | clip->getResult().replaceAllUsesOfWith(newClip->getResult()); |
4340 | changed = true; |
4341 | continue; |
4342 | } |
4343 | } |
4344 | |
4345 | return changed; |
4346 | } |
4347 | |
4348 | /// \returns whether \p N used used by any Nodes with side effects. |
4349 | static bool isUsedByNodeWithSideEffects(Node *N) { |
4350 | for (const auto &user : N->getUsers()) { |
4351 | if (user.getUser()->hasSideEffects()) { |
4352 | return true; |
4353 | } |
4354 | } |
4355 | return false; |
4356 | } |
4357 | |
4358 | /// Helper that \returns whether \p NV cannot have its output quantization |
4359 | /// parameters changed. For example, Concats require all inputs to have the same |
4360 | /// quantization parameters, so we cannot change the quantization parameters of |
4361 | /// a Node if it is input into a Concat. |
4362 | static bool disallowQuantParamChange(const NodeValue &NV) { |
4363 | for (auto &user : NV.getUsers()) { |
4364 | if (isa<ConcatNode>(user.getUser())) { |
4365 | return true; |
4366 | } |
4367 | } |
4368 | return false; |
4369 | } |
4370 | |
4371 | /// This is a specialized pass to use where we assume that quantized ranges are |
4372 | /// all inside the FP16 range. This means that if we have any clips outside the |
4373 | /// FP16 range we can safely remove them if adjacent to a quantized op. |
4374 | bool EliminateClipsOutsideFP16Range::run(Function *F, |
4375 | const CompilationContext &cctx) { |
4376 | LOG_SCOPE(F->getLogContext(), getName()); |
4377 | |
4378 | if (!cctx.precisionConfig.clipQuantRangeToFP16) { |
4379 | return false; |
4380 | } |
4381 | |
4382 | bool changed = false; |
4383 | for (Node &node : F->getNodes()) { |
4384 | // Clip(Dequantize(Node)) -> Dequantize(Node) |
4385 | if (ClipNode *clip = dyn_cast<ClipNode>(&node)) { |
4386 | DequantizeNode *DQN = dyn_cast<DequantizeNode>(clip->getInput()); |
4387 | if (!DQN) { |
4388 | continue; |
4389 | } |
4390 | |
4391 | // Can only eliminate the clip if its outside the FP16 range. |
4392 | if (clip->getMin() > kMinFP16 || clip->getMax() < kMaxFP16) { |
4393 | continue; |
4394 | } |
4395 | |
4396 | // We can safely skip the Clip at this point. |
4397 | clip->getResult().replaceAllUsesOfWith(DQN->getResult()); |
4398 | changed = true; |
4399 | continue; |
4400 | } |
4401 | |
4402 | // Quantize(Clip(Node)) -> Quantize(Node) |
4403 | if (QuantizeNode *QN = dyn_cast<QuantizeNode>(&node)) { |
4404 | ClipNode *clip = dyn_cast<ClipNode>(QN->getInput()); |
4405 | if (!clip) { |
4406 | continue; |
4407 | } |
4408 | |
4409 | // Can only eliminate the clip if its outside the FP16 range. |
4410 | if (clip->getMin() > kMinFP16 || clip->getMax() < kMaxFP16) { |
4411 | continue; |
4412 | } |
4413 | |
4414 | // We can safely skip the Clip at this point. |
4415 | QN->setNthInput(QuantizeNode::InputIdx, clip->getInput()); |
4416 | changed = true; |
4417 | continue; |
4418 | } |
4419 | } |
4420 | |
4421 | return changed; |
4422 | } |
4423 | |
/// When quantized operators and Clips are used together, we can often merge the
/// Clip range and the Quantized range and remove the Clip.
/// Handles the patterns Clip(Dequantize(Node)) and Quantize(Clip(Node)).
/// \returns true if the function \p F was modified.
bool OptimizeQuantizeClip::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());

  // All of the optimizations here depend on the quantization parameters. If
  // we've loaded dummy qparams then none should be performed.
  if (cctx.precisionConfig.loadUniquedDummyQParams) {
    return false;
  }

  bool changed = false;

  // Change a quantized result type qResult to account for the range from clip.
  // Returns true when the Clip can be elided: either the Clip's range already
  // covers qResult's quantized range, or qResult's type was narrowed to the
  // intersection of the two ranges.
  auto updateQuantizeNodeType = [](Function *F, const CompilationContext &cctx,
                                   NodeValue qResult, ClipNode *clip,
                                   bool skipIfQuantParamChange,
                                   bool allowQParamChange) {
    const auto qMinMax = qResult.getType()->getQuantizedValueRange();
    const float newMin = std::max(clip->getMin(), qMinMax.first);
    const float newMax = std::min(clip->getMax(), qMinMax.second);

    // If the quantization parameters do not change then we can always elide the
    // Clip and do not need to change the type of qResult.
    if (newMin == qMinMax.first && newMax == qMinMax.second) {
      return true;
    }

    // Some users (e.g. Concat) require all inputs to share quantization
    // parameters; in that case we must not touch qResult's type.
    if (disallowQuantParamChange(qResult)) {
      return false;
    }

    // At this point the quantization parameters must be changing, so if we do
    // not allow for that then return false.
    if (!allowQParamChange || skipIfQuantParamChange) {
      return false;
    }

    // Replace the old quantized type with the new type with different
    // min/max.
    const TypeRef oldTy = qResult.getType();
    const auto qParams = quantization::chooseQuantizationParams(
        {newMin, newMax}, cctx.precisionConfig.quantConfig.schema,
        oldTy->getElementType());
    const TypeRef newTy = F->getParent()->uniqueType(
        oldTy->getElementType(), oldTy->dims(), qParams.scale, qParams.offset);
    qResult.getNode()->setType(qResult.getResNo(), newTy);
    return true;
  };

  for (Node &node : F->getNodes()) {
    // Clip(Dequantize(Node)) -> Dequantize(Node)
    if (ClipNode *clip = dyn_cast<ClipNode>(&node)) {
      DequantizeNode *DQN = dyn_cast<DequantizeNode>(clip->getInput());
      if (!DQN) {
        continue;
      }

      // Cannot perform this optimization if there are multiple users of DQN or
      // DQN's input, as otherwise they'd have incorrect quantization params.
      NodeValue qResult = DQN->getInput();
      const bool skipIfQuantParamChange =
          DQN->getNumUsers() != 1 || qResult.getNode()->getNumUsers() != 1;

      // Try to update the quantize's type, otherwise skip this one.
      if (!updateQuantizeNodeType(
              F, cctx, qResult, clip, skipIfQuantParamChange,
              cctx.optimizationOpts.enableQuantParamChanges)) {
        continue;
      }

      // Now we skip the Clip since the node prior to DQN has included the
      // Clip's range in its quantization parameters.
      clip->getResult().replaceAllUsesOfWith(DQN->getResult());
      changed = true;
      continue;
    }

    // Quantize(Clip(Node)) -> Quantize(Node)
    if (QuantizeNode *QN = dyn_cast<QuantizeNode>(&node)) {
      ClipNode *clip = dyn_cast<ClipNode>(QN->getInput());
      if (!clip) {
        continue;
      }

      // Cannot set the type of quantized nodes if they're used by a Node with
      // side effects, as they may be expecting a specific type.
      const bool skipIfQuantParamChange = isUsedByNodeWithSideEffects(QN);

      // Try to update the quantize's type, otherwise skip this one.
      if (!updateQuantizeNodeType(
              F, cctx, QN->getResult(), clip, skipIfQuantParamChange,
              cctx.optimizationOpts.enableQuantParamChanges)) {
        continue;
      }

      // Now we can skip the Clip since the QN has accounted for the Clip's
      // range in its quantization parameters.
      QN->setNthInput(QuantizeNode::InputIdx, clip->getInput());
      changed = true;
      continue;
    }
  }

  return changed;
}
4530 | |
/// Optimize away ConvertToNode.
/// This basically turns "conversion(conversion A to B) to C"
/// into noop if all of the conditions below are met:
/// - the type of A and C are the same;
/// - A->B is not a FP-to-Int conversion;
/// - A->B is not a narrowing conversion.
/// Also removes identity conversions and folds conversions of Constants
/// directly into the constant payload.
bool OptimizeConversions::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());

  bool changed = false;
  for (auto &node : F->getNodes()) {
    if (auto *CN = llvm::dyn_cast<ConvertToNode>(&node)) {

      // Eliminate no-op conversion (input and result types identical).
      if (CN->getInput().getType() == CN->getResult().getType()) {
        CN->getResult().replaceAllUsesOfWith(CN->getInput());
        changed = true;
        continue;
      }

      // Perform conversion of constants. convertConstant returns NodeValue()
      // when this pair of element kinds is unsupported.
      if (auto *BN = llvm::dyn_cast<Constant>(CN->getInput())) {
        auto newConst =
            convertConstant(*F->getParent(), *BN, CN->getResult().getType());
        if (newConst == NodeValue()) {
          continue;
        }
        CN->getResult().replaceAllUsesOfWith(newConst, F);
        changed = true;
        continue;
      }

      // Simplify a chain of conversions A -> B -> C to A -> C, unless A -> B
      // is a narrowing cast.
      if (auto *BN = llvm::dyn_cast<ConvertToNode>(CN->getInput())) {
        auto AN = BN->getInput();

        // Do not optimize away narrowing casts: skipping B is only safe when
        // A -> B could not have changed any value.
        if (!isValueChangingCast(AN.getType(), BN->getResult().getType())) {
          auto *newCast =
              F->createConvertTo(CN->getName(), AN, CN->getResult().getType());
          CN->getResult().replaceAllUsesOfWith(newCast);
          changed = true;
          continue;
        }
      }
    }
  }
  return changed;
}
4581 | |
4582 | /// Optimize patterns of Int8 quantization/dequantization with ConvertTo. This |
4583 | /// may have numerical differences but since Int8 has a small range it's likely |
4584 | /// fine. This is opt in by a backend. |
4585 | bool OptimizeOutIntermediateConversions::run(Function *F, |
4586 | const CompilationContext &cctx) { |
4587 | LOG_SCOPE(F->getLogContext(), getName()); |
4588 | |
4589 | bool changed = false; |
4590 | for (auto &node : F->getNodes()) { |
4591 | // Quantize(ConvertTo(Node)) -> Quantize(Node), where Quantize is int8 |
4592 | if (QuantizeNode *QN = llvm::dyn_cast<QuantizeNode>(&node)) { |
4593 | if (QN->getResult().getType()->getElementType() != ElemKind::Int8QTy) { |
4594 | continue; |
4595 | } |
4596 | |
4597 | ConvertToNode *CN = llvm::dyn_cast<ConvertToNode>(QN->getInput()); |
4598 | if (!CN) { |
4599 | continue; |
4600 | } |
4601 | |
4602 | QN->setNthInput(QuantizeNode::InputIdx, CN->getInput()); |
4603 | changed = true; |
4604 | continue; |
4605 | } |
4606 | |
4607 | // ConvertTo(Dequantize(Node)) -> Dequantize(Node), where Dequantize is int8 |
4608 | if (ConvertToNode *CN = llvm::dyn_cast<ConvertToNode>(&node)) { |
4609 | DequantizeNode *DN = llvm::dyn_cast<DequantizeNode>(CN->getInput()); |
4610 | if (!DN || |
4611 | DN->getInput().getType()->getElementType() != ElemKind::Int8QTy) { |
4612 | continue; |
4613 | } |
4614 | |
4615 | // Create new Dequantize node, dequantizing directly to the kind of the |
4616 | // ConverTo that originally consumed it. |
4617 | DequantizeNode *newDN = F->createDequantize( |
4618 | DN->getName(), DN->getInput(), CN->getResult().getElementType()); |
4619 | CN->getResult().replaceAllUsesOfWith(newDN->getResult()); |
4620 | changed = true; |
4621 | continue; |
4622 | } |
4623 | } |
4624 | |
4625 | return changed; |
4626 | } |
4627 | |
4628 | // Look for float Relus that we can fuse up into quantized FCs. This is either |
4629 | // with a Dequantize between them, or a Concat with multiple FCs being |
4630 | // dequantized and concatenated together. |
4631 | bool OptimizeQuantFCFloatRelu::run(Function *F, |
4632 | const CompilationContext &cctx) { |
4633 | LOG_SCOPE(F->getLogContext(), getName()); |
4634 | |
4635 | // This opt implies there to be changes to quantization, because we create an |
4636 | // int relu that was previously float. |
4637 | if (!cctx.optimizationOpts.enableQuantParamChanges) { |
4638 | return false; |
4639 | } |
4640 | |
4641 | bool changed = false; |
4642 | for (auto &node : F->getNodes()) { |
4643 | auto *relu = llvm::dyn_cast<ReluNode>(&node); |
4644 | // Look for Float relus to start. |
4645 | if (!relu || |
4646 | !isFloatElemKind(relu->getResult().getType()->getElementType())) { |
4647 | continue; |
4648 | } |
4649 | |
4650 | // Now look for dequantize nodes. We may need to move this above a Concat. |
4651 | // Check if necessary. |
4652 | std::vector<FullyConnectedNode *> nodesToFuse; |
4653 | if (auto *CN = llvm::dyn_cast<ConcatNode>(relu->getInput())) { |
4654 | if (CN->getNumUsers() != 1) { |
4655 | continue; |
4656 | } |
4657 | |
4658 | // Check if all the concat inputs are dequantized FCs. |
4659 | for (const NodeValue &NV : CN->getInputs()) { |
4660 | auto *DQ = llvm::dyn_cast<DequantizeNode>(NV); |
4661 | if (!DQ || DQ->getNumUsers() != 1) { |
4662 | break; |
4663 | } |
4664 | auto *FC = llvm::dyn_cast<FullyConnectedNode>(DQ->getInput()); |
4665 | if (!FC || FC->getNumUsers() != 1) { |
4666 | break; |
4667 | } |
4668 | nodesToFuse.push_back(FC); |
4669 | } |
4670 | if (nodesToFuse.size() != CN->getInputs().size()) { |
4671 | continue; |
4672 | } |
4673 | } else if (auto *DQ = llvm::dyn_cast<DequantizeNode>(relu->getInput())) { |
4674 | if (DQ->getNumUsers() != 1) { |
4675 | continue; |
4676 | } |
4677 | |
4678 | auto *FC = llvm::dyn_cast<FullyConnectedNode>(DQ->getInput()); |
4679 | if (!FC || FC->getNumUsers() != 1) { |
4680 | break; |
4681 | } |
4682 | nodesToFuse.push_back(FC); |
4683 | } else { |
4684 | continue; |
4685 | } |
4686 | |
4687 | // Did not find any quantized FCs to fuse, so continue. |
4688 | if (!nodesToFuse.size()) { |
4689 | continue; |
4690 | } |
4691 | |
4692 | // If the quant FC is used by nodes with side effects then skip, since we |
4693 | // may be changing the user's input qparam type. |
4694 | bool skip = false; |
4695 | for (FullyConnectedNode *FC : nodesToFuse) { |
4696 | if (isUsedByNodeWithSideEffects(FC)) { |
4697 | skip = true; |
4698 | break; |
4699 | } |
4700 | } |
4701 | if (skip) { |
4702 | continue; |
4703 | } |
4704 | |
4705 | // Now add quantized relus onto all of the FCs. |
4706 | for (FullyConnectedNode *FC : nodesToFuse) { |
4707 | const TypeRef FCTy = FC->getResult().getType(); |
4708 | TypeRef qReluTy = nullptr; |
4709 | if (cctx.precisionConfig.loadUniquedDummyQParams) { |
4710 | // Reuse the FC type, since we don't know its actual qparams during AOT |
4711 | // optimization. When we the actual type later before deploying, we will |
4712 | // do a final processing pass to set min to 0 via updateReluTypes(). |
4713 | qReluTy = FCTy; |
4714 | } else { |
4715 | // Use the same type as the FC for the Relu but with 0 as min. |
4716 | const auto qParams = quantization::chooseQuantizationParams( |
4717 | {0, FCTy->getQuantizedValueRange().second}, |
4718 | cctx.precisionConfig.quantConfig.schema, FCTy->getElementType()); |
4719 | qReluTy = |
4720 | F->getParent()->uniqueType(FCTy->getElementType(), FCTy->dims(), |
4721 | qParams.scale, qParams.offset); |
4722 | } |
4723 | ReluNode *qRelu = F->createRELU(relu->getName().str() + "_quant" , |
4724 | FC->getResult(), qReluTy); |
4725 | FC->getResult().typeUnsafeReplaceAllUsesOfWith(qRelu->getResult(), F, |
4726 | qRelu); |
4727 | } |
4728 | |
4729 | // Now we can get rid of the relu. |
4730 | relu->getResult().replaceAllUsesOfWith(relu->getInput()); |
4731 | changed = true; |
4732 | continue; |
4733 | } |
4734 | |
4735 | return changed; |
4736 | } |
4737 | |
/// Look for Concats with all Dequantization as input and Quantization as
/// output, and change the Quantization/Dequantization into a rescale.
/// The float Concat then becomes a quantized Concat producing the Quantize's
/// type directly.
bool OptimizeConcatQuantization::run(Function *F,
                                     const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());

  bool changed = false;
  for (auto &node : F->getNodes()) {
    auto *CN = dyn_cast<ConcatNode>(&node);
    if (!CN) {
      continue;
    }

    // Look for a single Quantize user.
    if (CN->getUsers().size() != 1) {
      continue;
    }
    auto *QN = dyn_cast<QuantizeNode>((*CN->getUsers().begin()).getUser());
    if (!QN) {
      continue;
    }

    // Gather/check all of the inputs are DequantizeNodes.
    std::vector<DequantizeNode *> DNs;
    DNs.reserve(CN->getInputs().size());
    for (const NodeValue &NV : CN->getInputs()) {
      auto *DN = dyn_cast<DequantizeNode>(NV);
      if (!DN || DN->getNumUsers() != 1) {
        break;
      }
      DNs.push_back(DN);
    }

    // If not all CN inputs are Dequantizes then skip.
    if (DNs.size() != CN->getInputs().size()) {
      continue;
    }

    // Now create Rescales instead of Dequantizes for all CN inputs, targeting
    // the scale/offset the downstream Quantize would have produced.
    std::vector<NodeValue> newConcatInputs;
    newConcatInputs.reserve(DNs.size());
    TypeRef QNTy = QN->getResult().getType();
    for (DequantizeNode *DN : DNs) {
      if (DN->getInput().getType()->getScale() == QNTy->getScale() &&
          DN->getInput().getType()->getOffset() == QNTy->getOffset()) {
        // Don't need to rescale as it already has the right scale/offset.
        newConcatInputs.push_back(DN->getInput());
      } else {
        // Same scale/offset as the Quantize output, but this input's shape.
        TypeRef newTy = F->getParent()->uniqueTypeWithNewShape(
            QNTy, DN->getResult().dims());
        auto *RS = F->createRescaleQuantized(DN->getName().str() + "_rescale" ,
                                             DN->getInput(), newTy);
        newConcatInputs.push_back(RS->getResult());
      }
    }

    auto *newCN = F->createConcat(CN->getName(), newConcatInputs, CN->getDim());

    // Now we can get rid of the Quantize after the CN.
    QN->getResult().replaceAllUsesOfWith(newCN->getResult());
    changed = true;
    continue;
  }

  return changed;
}
4804 | |
4805 | /// \returns a cloned version of node \p N, but with each of the cloned node's |
4806 | /// output types set to the corresponding type in \p types. The new node is |
4807 | /// added to Function \p F. |
4808 | /// \pre types.size() == N->getNumResults() |
4809 | static Node *cloneNodeWithNewTypes(Function *F, Node *N, |
4810 | llvm::ArrayRef<TypeRef> types) { |
4811 | assert(N->getNumResults() == types.size() && |
4812 | "Number of types must equal number of results of the node." ); |
4813 | |
4814 | Node *newNode = F->addNode(N->clone()); |
4815 | for (size_t i = 0; i < types.size(); i++) { |
4816 | newNode->setType(i, types[i]); |
4817 | } |
4818 | |
4819 | return newNode; |
4820 | } |
4821 | |
/// SFINAE helper: resolves to a std::enable_if whose ::type is \p U only when
/// \p T and \p U are the same type; otherwise substitution fails and the
/// overload is removed from the candidate set.
template <class T, class U>
using enable_if_same_t = std::enable_if<std::is_same<T, U>::value, U>;
/// Expands to the header of a static function template that participates in
/// overload resolution only when T is exactly NODE_NAME_##Node. The expansion
/// ends with the return type (the node class) followed by `static`, so the
/// use sites below continue with `* createNode(...)` to form a
/// pointer-returning function. Used to build one createNode()/createNewPool()
/// overload per node kind.
#define FUNCTION_ENABLE_IF_TEMPLATE(NODE_NAME_)                                \
  template <class T, typename... Args>                                         \
  typename enable_if_same_t<T, NODE_NAME_##Node>::type static
4827 | |
/// Family of createNode() overloads: each forwards its arguments to the
/// matching Function::createXXX factory, selected at compile time via the
/// enable_if in FUNCTION_ENABLE_IF_TEMPLATE. This lets the template helpers
/// below (createNewPool, combineDownRescaleToArithmeticNode) create the right
/// node kind generically.
FUNCTION_ENABLE_IF_TEMPLATE(AvgPool) * createNode(Function &F, Args... args) {
  return F.createAvgPool(args...);
}
FUNCTION_ENABLE_IF_TEMPLATE(MaxPool) * createNode(Function &F, Args... args) {
  return F.createMaxPool(args...);
}
FUNCTION_ENABLE_IF_TEMPLATE(Add)
*createNode(Function &F, Args... args) { return F.createAdd(args...); }
FUNCTION_ENABLE_IF_TEMPLATE(Sub)
*createNode(Function &F, Args... args) { return F.createSub(args...); }
FUNCTION_ENABLE_IF_TEMPLATE(Mul)
*createNode(Function &F, Args... args) { return F.createMul(args...); }
FUNCTION_ENABLE_IF_TEMPLATE(Div)
*createNode(Function &F, Args... args) { return F.createDiv(args...); }
FUNCTION_ENABLE_IF_TEMPLATE(Min)
*createNode(Function &F, Args... args) { return F.createMin(args...); }
FUNCTION_ENABLE_IF_TEMPLATE(Max)
*createNode(Function &F, Args... args) { return F.createMax(args...); }
FUNCTION_ENABLE_IF_TEMPLATE(MatMul)
*createNode(Function &F, Args... args) { return F.createMatMul(args...); }
4848 | |
4849 | FUNCTION_ENABLE_IF_TEMPLATE(AvgPool) * |
4850 | createNewPool(Function &F, T *PN, RescaleQuantizedNode *rescale) { |
4851 | return createNode<T>(F, PN->getName(), rescale->getInput(), PN->getKernels(), |
4852 | PN->getStrides(), PN->getPads(), NCHW, |
4853 | PN->getCountIncludePads()); |
4854 | } |
/// Creates a new MaxPool with the same pooling parameters as \p PN but
/// consuming the input of \p rescale instead of the rescale's output. The
/// Argmax result's element type of the original pool is preserved.
/// NOTE(review): no layout argument is forwarded here, so createMaxPool's
/// default layout applies — confirm pools reaching this path use that layout.
FUNCTION_ENABLE_IF_TEMPLATE(MaxPool) *
createNewPool(Function &F, T *PN, RescaleQuantizedNode *rescale) {
  return createNode<T>(F, PN->getName(), rescale->getInput(), PN->getKernels(),
                       PN->getStrides(), PN->getPads(),
                       PN->getArgmax().getElementType());
}
4861 | |
4862 | /// Sink Rescale down with Pooling node. |
4863 | /// PoolingNode(Rescale(X)) -> Rescale(PoolingNode(X)). |
4864 | /// Apply this transformation for AvgPool and MaxPool. |
4865 | template <typename T> |
4866 | static bool sinkDownRescaleToPoolingNode(Function &F, T *PN) { |
4867 | LOG_SCOPE(F.getLogContext(), "sinkDownRescaleToPoolingNode" ) |
4868 | |
4869 | bool changed = false; |
4870 | |
4871 | if (auto *rescale = dyn_cast<RescaleQuantizedNode>(PN->getInput())) { |
4872 | T *newPN = createNewPool(F, PN, rescale); |
4873 | auto rescaleOutTy = F.getParent()->uniqueTypeWithNewShape( |
4874 | rescale->getResult().getType(), PN->getResult().getType()); |
4875 | auto *newRescale = F.createRescaleQuantized( |
4876 | rescale->getName(), newPN->getResult(), rescaleOutTy); |
4877 | PN->getResult().replaceAllUsesOfWith(newRescale); |
4878 | for (size_t i = 1; i < PN->getNumResults(); i++) { |
4879 | PN->getNthResult(i).replaceAllUsesOfWith(newPN->getNthResult(i)); |
4880 | } |
4881 | changed = true; |
4882 | } |
4883 | |
4884 | return changed; |
4885 | } |
4886 | |
4887 | /// Combine Rescale down with Arithmetic node. |
4888 | /// ArithmeticNode(Rescale(X), Rescale(Y)) -> ArithmeticNode(X, Y). |
4889 | /// ArithmeticNode(Rescale(X), Y) -> ArithmeticNode(X, Y). |
4890 | /// ArithmeticNode(X, Rescale(Y)) -> ArithmeticNode(X, Y). |
4891 | /// Apply this optimization for Add, Sub, Mul, Div, Min, Max. |
4892 | template <typename T> |
4893 | static bool combineDownRescaleToArithmeticNode(Function &F, T *AN) { |
4894 | LOG_SCOPE(F.getLogContext(), "combineDownRescaleToArithmeticNode" ) |
4895 | |
4896 | bool changed = false; |
4897 | |
4898 | if (auto *rescale = dyn_cast<RescaleQuantizedNode>(AN->getLHS())) { |
4899 | T *newAN = createNode<T>(F, AN->getName(), AN->getResult().getType(), |
4900 | rescale->getInput(), AN->getRHS()); |
4901 | AN->getResult().replaceAllUsesOfWith(newAN); |
4902 | AN = newAN; |
4903 | changed = true; |
4904 | } |
4905 | if (auto *rescale = dyn_cast<RescaleQuantizedNode>(AN->getRHS())) { |
4906 | T *newAN = createNode<T>(F, AN->getName(), AN->getResult().getType(), |
4907 | AN->getLHS(), rescale->getInput()); |
4908 | AN->getResult().replaceAllUsesOfWith(newAN); |
4909 | changed = true; |
4910 | } |
4911 | |
4912 | return changed; |
4913 | } |
4914 | |
/// Sink Rescale nodes down when possible.
/// \returns if anything was changed in the given function.
/// Each branch below matches one consumer kind; branches are mutually
/// exclusive per node and each ends with `continue`. New nodes created here
/// are appended to \p F and the old ones become dead (cleaned up by DCE).
static bool sinkRescaleQuantizedNode(Function *F,
                                     const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), "sinkRescaleQuantizedNode");
  bool changed = false;
  for (auto &node : F->getNodes()) {
    // Sink Rescale below Reshape node.
    // Reshape(Rescale(X)) -> Rescale(Reshape(X)).
    if (auto *reshape = dyn_cast<ReshapeNode>(&node)) {
      auto *rescale = dyn_cast<RescaleQuantizedNode>(reshape->getInput());
      if (!rescale) {
        continue;
      }

      // Reshape the un-rescaled input, then rescale to the original output
      // type (same quant params, reshaped dims).
      auto *newReshape =
          F->createReshape(reshape->getName(), rescale->getInput(),
                           reshape->getResult().dims(), reshape->getLayout());
      auto *newRescale = F->createRescaleQuantized(
          rescale->getName(), newReshape, reshape->getResult().getType());
      reshape->getResult().replaceAllUsesOfWith(newRescale);

      changed = true;
      continue;
    }

    // Sink Rescale below Slice node.
    // Slice(Rescale(X)) -> Rescale(Slice(X)).
    if (auto *slice = dyn_cast<SliceNode>(&node)) {
      auto *rescale = dyn_cast<RescaleQuantizedNode>(slice->getInput());
      if (!rescale) {
        continue;
      }

      // Slice keeps the pre-rescale quant params (new shape), then a single
      // Rescale restores the original output type on the smaller tensor.
      auto sliceOutTy = F->getParent()->uniqueTypeWithNewShape(
          rescale->getInput().getType(), slice->getResult().getType());
      auto *newSlice = F->createSlice(slice->getName(), rescale->getInput(),
                                      slice->getStart(), sliceOutTy);
      auto *newRescale = F->createRescaleQuantized(
          rescale->getName(), newSlice, slice->getResult().getType());
      slice->getResult().replaceAllUsesOfWith(newRescale);

      changed = true;
      continue;
    }

    // Sink Rescale below Transpose node.
    // Transpose(Rescale(X)) -> Rescale(Transpose(X)).
    if (auto *transpose = dyn_cast<TransposeNode>(&node)) {
      auto *rescale = dyn_cast<RescaleQuantizedNode>(transpose->getInput());
      if (!rescale) {
        continue;
      }

      auto *newTranspose =
          F->createTranspose(transpose->getName(), rescale->getInput(),
                             transpose->getShuffle(), transpose->getLayout());
      // The sunk Rescale keeps the original rescale's quant params but takes
      // the transposed shape.
      auto rescaleOutTy = F->getParent()->uniqueTypeWithNewShape(
          rescale->getResult().getType(), transpose->getResult().getType());
      auto *newRescale = F->createRescaleQuantized(rescale->getName(),
                                                   newTranspose, rescaleOutTy);
      transpose->getResult().replaceAllUsesOfWith(newRescale);

      changed = true;
      continue;
    }

    if (auto *PN = dyn_cast<AvgPoolNode>(&node)) {
      // AvgPool input and output scale/bias may differ.
      // NOTE(review): sinking is applied only when quant-param changes are
      // disallowed — confirm this gating is intentional.
      if (!cctx.optimizationOpts.enableQuantParamChanges) {
        changed |= sinkDownRescaleToPoolingNode<AvgPoolNode>(*F, PN);
      }
      continue;
    }

    if (auto *PN = dyn_cast<MaxPoolNode>(&node)) {
      changed |= sinkDownRescaleToPoolingNode<MaxPoolNode>(*F, PN);
      continue;
    }

    // Combine Rescale down with FullyConnected node.
    // FullyConnected(Rescale(X)) -> FullyConnected(X).
    if (auto *FC = dyn_cast<FullyConnectedNode>(&node)) {
      auto *rescale = dyn_cast<RescaleQuantizedNode>(FC->getInput());
      if (!rescale) {
        continue;
      }

      auto *newFC = F->createFullyConnected(FC->getName(), rescale->getInput(),
                                            FC->getWeights(), FC->getBias(),
                                            FC->getResult().getType());
      FC->getResult().replaceAllUsesOfWith(newFC);

      changed = true;
      continue;
    }

    // Combine Rescale down with Convolution node.
    // Convolution(Rescale(X), F, B) -> Convolution(X, F, B).
    // Convolution(X, Rescale(F), B) -> Convolution(X, F, B).
    // Convolution(X, F, Rescale(B)) -> Convolution(X, F, B).
    // ... and different combinations.
    if (auto *CN = dyn_cast<ConvolutionNode>(&node)) {
      auto *rescaleX = dyn_cast<RescaleQuantizedNode>(CN->getInput());
      auto *rescaleF = dyn_cast<RescaleQuantizedNode>(CN->getFilter());
      auto *rescaleB = dyn_cast<RescaleQuantizedNode>(CN->getBias());
      // For each operand, take the rescale's input when present, otherwise
      // keep the operand as is.
      auto newX = rescaleX ? rescaleX->getInput() : CN->getInput();
      auto newF = rescaleF ? rescaleF->getInput() : CN->getFilter();
      auto newB = rescaleB ? rescaleB->getInput() : CN->getBias();
      if (rescaleX || rescaleF || rescaleB) {
        auto *newCN = F->createConv(CN->getName(), newX, newF, newB,
                                    CN->getResult().getType(), CN->getKernels(),
                                    CN->getStrides(), CN->getPads(),
                                    CN->getGroup(), CN->getDilation());
        // Preserve any activation fused into the original convolution.
        newCN->setFusedActivation(CN->getFusedActivation());
        newCN->setFusedActivationArgs(CN->getFusedActivationArgs());

        CN->getResult().replaceAllUsesOfWith(newCN);
        changed = true;
      }
      continue;
    }

    if (auto *AN = dyn_cast<AddNode>(&node)) {
      changed |= combineDownRescaleToArithmeticNode<AddNode>(*F, AN);
      continue;
    }
    if (auto *AN = dyn_cast<SubNode>(&node)) {
      changed |= combineDownRescaleToArithmeticNode<SubNode>(*F, AN);
      continue;
    }
    if (auto *AN = dyn_cast<MulNode>(&node)) {
      changed |= combineDownRescaleToArithmeticNode<MulNode>(*F, AN);
      continue;
    }
    if (auto *AN = dyn_cast<DivNode>(&node)) {
      changed |= combineDownRescaleToArithmeticNode<DivNode>(*F, AN);
      continue;
    }
    if (auto *AN = dyn_cast<MinNode>(&node)) {
      changed |= combineDownRescaleToArithmeticNode<MinNode>(*F, AN);
      continue;
    }
    if (auto *AN = dyn_cast<MaxNode>(&node)) {
      changed |= combineDownRescaleToArithmeticNode<MaxNode>(*F, AN);
      continue;
    }

    // Combine Rescale down with Relu node.
    // ReluNode(Rescale(in)) -> ReluNode(in).
    if (auto *RN = dyn_cast<ReluNode>(&node)) {
      if (auto *rescale = dyn_cast<RescaleQuantizedNode>(RN->getInput())) {
        auto *newRN = F->createRELU(RN->getName(), rescale->getInput(),
                                    RN->getResult().getType());
        RN->getResult().replaceAllUsesOfWith(newRN);
        changed = true;
      }
      continue;
    }

    // MatMul is treated like a two-operand arithmetic node (LHS/RHS).
    if (auto *MN = dyn_cast<MatMulNode>(&node)) {
      changed |= combineDownRescaleToArithmeticNode<MatMulNode>(*F, MN);
      continue;
    }
  }

  return changed;
}
5083 | |
/// Eliminate node sequences that are related to quantization.
/// \returns if anything was changed in the given function.
bool OptimizeQuantization::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());
  bool changed = false;
  // A worklist that contains the nodes to process.
  std::vector<Node *> worklist;

  // Add all of the interesting nodes to the worklist.
  for (auto &node : F->getNodes()) {
    if (isa<QuantizeNode>(node) || isa<DequantizeNode>(node) ||
        isa<RescaleQuantizedNode>(node)) {
      worklist.push_back(&node);
    }
  }

  while (!worklist.empty()) {
    // Take a node from the worklist.
    Node *node = worklist.back();
    worklist.pop_back();

    if (auto *Q = dyn_cast<QuantizeNode>(node)) {
      if (auto *DQ = dyn_cast<DequantizeNode>(Q->getInput())) {
        // Quantize(Dequantize(X)) -> RescaleQuantized(X)
        // If the quantization-dequantization sequence does not change the
        // type then we can simply drop them without adding a requantization
        // node.
        changed = true;
        if (DQ->getInput().getType() == Q->getResult().getType()) {
          Q->getResult().replaceAllUsesOfWith(DQ->getInput());
          continue;
        }

        auto *RS = F->createRescaleQuantized(Q->getName(), DQ->getInput(),
                                             Q->getResult().getType());
        Q->getResult().replaceAllUsesOfWith(RS);

        // We may be able to optimize this rescale node. Remember to visit
        // this new node and try to optimize it later.
        worklist.push_back(RS);
        continue;
      }

      if (auto *SN = dyn_cast<SplatNode>(Q->getInput())) {
        // Quantize(Splat) -> Splat'
        // A Splat can be emitted directly with the quantized output type.
        changed = true;
        SplatNode *newSN = F->createSplat(
            SN->getName(), Q->getResult().getType(), SN->getValue());
        Q->getResult().replaceAllUsesOfWith(newSN);
        continue;
      }
    }

    if (auto *DQ = dyn_cast<DequantizeNode>(node)) {
      if (auto *Q = dyn_cast<QuantizeNode>(DQ->getInput())) {
        // Dequantize(Quantize(X)) -> X
        changed = true;
        DQ->getResult().replaceAllUsesOfWith(Q->getInput());
        continue;
      }
      // Fold the rescale into the following Dequantize.
      // Dequantize(rescale) -> Dequantize()
      if (auto *RS = dyn_cast<RescaleQuantizedNode>(DQ->getInput())) {
        changed = true;
        auto *newRS = F->createDequantize(DQ->getName(), RS->getInput(),
                                          DQ->getResult().getType());
        DQ->getResult().replaceAllUsesOfWith(newRS);

        // We may be able to optimize this rescale node. Remember to visit
        // this new node and try to optimize it later.
        worklist.push_back(newRS);
        continue;
      }
      if (auto *SN = dyn_cast<SplatNode>(DQ->getInput())) {
        // Dequantize(Splat) -> Splat'
        changed = true;
        SplatNode *newSN = F->createSplat(
            SN->getName(), DQ->getResult().getType(), SN->getValue());
        DQ->getResult().replaceAllUsesOfWith(newSN);
        continue;
      }
    }

    if (auto *RS = dyn_cast<RescaleQuantizedNode>(node)) {
      // All cases below tend to change the output scale/bias of a Node. This
      // may change the numerics of the op (even the range is narrower and so it
      // should be more accurate).
      if (!cctx.optimizationOpts.enableQuantParamChanges) {
        continue;
      }

      if (RS->getInput().getType() == RS->getResult().getType()) {
        // If rescale does not change the type, then simply drop it.
        changed = true;
        RS->getResult().replaceAllUsesOfWith(RS->getInput());
        continue;
      }

      // All optimizations in this scope below combine a Rescale up into or
      // above the Rescale's input X. If X has multiple users then this merging
      // will duplicate X, just with a different output scale/offset. If X is
      // not a Splat then this is likely not desired, as it means a
      // computational node (e.g. Add) is duplicated.
      if (!RS->getInput().hasOneUse() && !isa<SplatNode>(RS->getInput())) {
        continue;
      }

      // Combine the rescale node up into its parent node.
      // Rescale(Node()) -> 'Node().
      bool addNewNodeToWorklist = false;
      switch (RS->getInput().getNode()->getKind()) {
      case Kinded::Kind::RescaleQuantizedNodeKind:
      case Kinded::Kind::QuantizeNodeKind:
        addNewNodeToWorklist = true;
        // Intentional fallthrough: these two kinds are merged exactly like the
        // kinds below, but the merged clone may itself be optimizable again.
      case Kinded::Kind::SplatNodeKind:
      case Kinded::Kind::AddNodeKind:
      case Kinded::Kind::SubNodeKind:
      case Kinded::Kind::MulNodeKind:
      case Kinded::Kind::DivNodeKind:
      case Kinded::Kind::FmodNodeKind:
      case Kinded::Kind::MinNodeKind:
      case Kinded::Kind::MatMulNodeKind:
      case Kinded::Kind::ConvolutionNodeKind:
      case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind:
      case Kinded::Kind::FullyConnectedNodeKind:
      case Kinded::Kind::SparseLengthsWeightedSumNodeKind: {
        changed = true;
        // Re-emit the parent with the Rescale's output type, bypassing the
        // Rescale entirely.
        Node *newNode =
            cloneNodeWithNewTypes(F, RS->getInput(), RS->getResult().getType());
        RS->getResult().replaceAllUsesOfWith(newNode);
        if (addNewNodeToWorklist) {
          worklist.push_back(newNode);
        }
        continue;
      }
      default:;
      }

      if (auto *MN = dyn_cast<MaxNode>(RS->getInput())) {
        // Rescale(MAX(X, Y)) -> MAX(Rescale(X), Rescale(Y)).
        // It's okay to rescale the operands because even if the output range
        // is smaller then truncation would have happened during the rescale.
        // On values that are outside of the range we just moved the
        // truncation to a different location.
        changed = true;
        auto name = RS->getName();
        auto *L = F->createRescaleQuantized(name, MN->getLHS(),
                                            RS->getResult().getType());
        auto *R = F->createRescaleQuantized(name, MN->getRHS(),
                                            RS->getResult().getType());
        auto *newMN = F->createMax(MN->getName(), L, R);
        worklist.push_back(L);
        worklist.push_back(R);
        RS->getResult().replaceAllUsesOfWith(newMN);
        continue;
      }
    } // Handle RescaleQuantizedNode
  } // For each item in the worklist.

  // This pass is based on real qparams, so skip this opt if using dummies.
  if (!cctx.precisionConfig.loadUniquedDummyQParams) {
    changed |= optimizeQuantizedMaxSplat(F);
  }

  // If nothing has changed then sink rescale quantization nodes.
  if (!changed) {
    changed = sinkRescaleQuantizedNode(F, cctx);
  }
  return changed;
}
5254 | |
5255 | void glow::convertQuantizedConstants(Function *F, CompilationContext &cctx) { |
5256 | for (auto &node : F->getNodes()) { |
5257 | auto *Q = dyn_cast<QuantizeNode>(&node); |
5258 | if (!Q) { |
5259 | continue; |
5260 | } |
5261 | auto *C = dyn_cast<Constant>(Q->getInput()); |
5262 | if (!C) { |
5263 | continue; |
5264 | } |
5265 | |
5266 | // Quantize(Constant) -> Constant |
5267 | // Note, it does not really matter how many usages this Constant has. |
5268 | // Quantized graph will use optimized Constant and other functions will |
5269 | // refer to the floating point original Constant. |
5270 | NodeValue NC = |
5271 | convertConstant(*F->getParent(), *C, Q->getResult().getType()); |
5272 | if (NC == NodeValue()) { |
5273 | continue; |
5274 | } |
5275 | Q->getResult().replaceAllUsesOfWith(NC); |
5276 | } |
5277 | |
5278 | // Perform Dead Code Elimination. |
5279 | runDCEPass(F, cctx); |
5280 | } |
5281 | |
5282 | void glow::convertPlaceholdersToConstants(Function *F, |
5283 | const PlaceholderBindings &bindings, |
5284 | llvm::ArrayRef<Placeholder *> phs) { |
5285 | LOG_SCOPE(F->getLogContext(), "convertPlaceholdersToConstants" ) |
5286 | |
5287 | auto *M = F->getParent(); |
5288 | for (auto &PH : F->findPlaceholders()) { |
5289 | if (std::find(phs.begin(), phs.end(), PH) != phs.end()) { |
5290 | continue; |
5291 | } |
5292 | auto *tensor = bindings.get(PH); |
5293 | if (!tensor) { |
5294 | continue; |
5295 | } |
5296 | auto *constant = M->createConstant(PH->getName(), *tensor, PH->getLayout()); |
5297 | PH->getOutput().replaceAllUsesOfWith(constant, F); |
5298 | } |
5299 | } |
5300 | |
5301 | /// \returns True if the \p node sub-tree corresponds to a scalar |
5302 | /// (Constant or Splat) and return the float value in \p retFloat. |
5303 | static bool getFloatScalar(Node *node, float *retFloat) { |
5304 | // Iterate across potential Tile Nodes (implied by broadcasting if any). |
5305 | auto *n = node; |
5306 | while (auto *TN = dyn_cast<TileNode>(n)) { |
5307 | n = TN->getInput(); |
5308 | } |
5309 | |
5310 | // After potential Tile nodes, it should be a singleton constant scalar node |
5311 | // with any shape corresponding to one single element. |
5312 | if (auto *constNode = dyn_cast<Constant>(n)) { |
5313 | if ((constNode->getType()->getElementType() != ElemKind::FloatTy) || |
5314 | (constNode->getType()->size() != 1)) { |
5315 | return false; |
5316 | } |
5317 | auto valueH = constNode->getHandle<float>(); |
5318 | std::vector<dim_t> coord(constNode->getType()->dims().size(), 0); |
5319 | *retFloat = valueH.at(coord); |
5320 | return true; |
5321 | } |
5322 | if (auto *splatNode = dyn_cast<SplatNode>(n)) { |
5323 | *retFloat = splatNode->getValue(); |
5324 | return true; |
5325 | } |
5326 | |
5327 | return false; |
5328 | } |
5329 | |
5330 | /// Fold leakyRelu operations expressed as a sub-graph Max(A, Mul(A, scalar)) |
5331 | /// and replace it by PRelu(Splat). |
5332 | bool FoldLeakyRelu::run(Function *F, const CompilationContext &cctx) { |
5333 | LOG_SCOPE(F->getLogContext(), getName()); |
5334 | bool changed = false; |
5335 | auto &nodes = F->getNodes(); |
5336 | for (auto &node : nodes) { |
5337 | auto *maxNode = dyn_cast<MaxNode>(&node); |
5338 | if (!maxNode) { |
5339 | continue; |
5340 | } |
5341 | NodeValue otherMaxOperand; |
5342 | MulNode *mulNode; |
5343 | if ((mulNode = dyn_cast<MulNode>(maxNode->getRHS()))) { |
5344 | otherMaxOperand = maxNode->getLHS(); |
5345 | } else if ((mulNode = dyn_cast<MulNode>(maxNode->getLHS()))) { |
5346 | otherMaxOperand = maxNode->getRHS(); |
5347 | } else { |
5348 | continue; |
5349 | } |
5350 | NodeValue otherMulOperand; |
5351 | float value; |
5352 | if (getFloatScalar(mulNode->getRHS(), &value)) { |
5353 | otherMulOperand = maxNode->getLHS(); |
5354 | } else if (getFloatScalar(mulNode->getLHS(), &value)) { |
5355 | otherMulOperand = maxNode->getRHS(); |
5356 | } else { |
5357 | continue; |
5358 | } |
5359 | if ((value <= 1.0f) && (otherMulOperand == otherMaxOperand)) { |
5360 | // The sub-tree is a Leaky-Relu, express it as a PRelu. |
5361 | auto *splat = F->createSplat(maxNode->getName(), |
5362 | mulNode->getResult().getType(), value); |
5363 | auto *PRelu = F->createPRELU(maxNode->getName(), otherMaxOperand, splat); |
5364 | maxNode->getResult().replaceAllUsesOfWith(PRelu); |
5365 | changed = true; |
5366 | continue; |
5367 | } |
5368 | } |
5369 | return changed; |
5370 | } |
5371 | |
/// Parameters that are used to define ChannelShuffle operators.
struct ChannelShuffleParams {
  // Number of groups the shuffled dimension was split into.
  size_t group;
  // Index of the dimension that was split (the shuffled/channel dimension).
  size_t kernel;
};
5377 | |
5378 | /// Compute the original parameters to the ChannelShuffle operator (represented |
5379 | /// as ReshapeNode->TransposeNode->ReshapeNode) for which \p node is the leading |
5380 | /// ReshapeNode. \returns The original ChannelShuffle parameters if possible and |
5381 | /// empty Optional otherwise. |
5382 | static llvm::Optional<ChannelShuffleParams> |
5383 | getChannelShuffleParams(const ReshapeNode &node) { |
5384 | auto resM = llvm::Optional<ChannelShuffleParams>(); |
5385 | |
5386 | llvm::ArrayRef<dim_t> inputDims = node.getInput().dims(); |
5387 | llvm::ArrayRef<dim_t> resultDims = node.getDims(); |
5388 | |
5389 | // Check that there is one more output dimension than input dimension. |
5390 | if (resultDims.size() != inputDims.size() + 1) { |
5391 | return resM; |
5392 | } |
5393 | |
5394 | // Find the first output dimension that doesn't match its corresponding input |
5395 | // dimension. |
5396 | ChannelShuffleParams params; |
5397 | bool found = false; |
5398 | for (size_t i = 0, e = resultDims.size(); i < e - 1; ++i) { |
5399 | if (inputDims[i] != resultDims[i]) { |
5400 | params.kernel = i; |
5401 | params.group = resultDims[i]; |
5402 | found = true; |
5403 | break; |
5404 | } |
5405 | } |
5406 | |
5407 | // Double check the property that the mismatched output found dimension and |
5408 | // its successor together evenly multiply to the input dimension they |
5409 | // mismatched on. |
5410 | if (found && resultDims[params.kernel] * resultDims[params.kernel + 1] == |
5411 | inputDims[params.kernel]) { |
5412 | resM = params; |
5413 | } |
5414 | |
5415 | return resM; |
5416 | } |
5417 | |
5418 | // Fold Reshape->Transpose->Reshape into ChannelShuffle when applicable. |
5419 | bool FoldChannelShuffle::run(Function *F, const CompilationContext &cctx) { |
5420 | LOG_SCOPE(F->getLogContext(), getName()); |
5421 | |
5422 | bool changed = false; |
5423 | auto &nodes = F->getNodes(); |
5424 | for (auto &node : nodes) { |
5425 | auto *RN2 = dyn_cast<ReshapeNode>(&node); |
5426 | if (!RN2) { |
5427 | continue; |
5428 | } |
5429 | |
5430 | auto *TR = dyn_cast<TransposeNode>(RN2->getInput()); |
5431 | if (!TR) { |
5432 | continue; |
5433 | } |
5434 | |
5435 | auto *RN1 = dyn_cast<ReshapeNode>(TR->getInput()); |
5436 | if (!RN1) { |
5437 | continue; |
5438 | } |
5439 | |
5440 | // Check that the input and output shapes match: |
5441 | if (RN1->getInput().getType() != RN2->getResult().getType()) { |
5442 | continue; |
5443 | } |
5444 | |
5445 | // Compute the original parameters to ChannelShuffle. |
5446 | auto paramsM = getChannelShuffleParams(*RN1); |
5447 | if (!paramsM.hasValue()) { |
5448 | continue; |
5449 | } |
5450 | |
5451 | // Create a new ChannelShuffle with kernel parameter tranposed by the |
5452 | // TR's shuffle. |
5453 | auto *newCS = F->createChannelShuffle("channel_shuffle" , RN1->getInput(), |
5454 | paramsM->group, paramsM->kernel); |
5455 | RN2->getResult().replaceAllUsesOfWith(newCS); |
5456 | changed = true; |
5457 | } |
5458 | return changed; |
5459 | } |
5460 | |
5461 | // Fold Tile -> Add into BatchedAdd wherever applicable. |
5462 | bool FoldTileAddIntoBatchedAdd::run(Function *F, |
5463 | const CompilationContext &cctx) { |
5464 | LOG_SCOPE(F->getLogContext(), getName()); |
5465 | |
5466 | bool changed = false; |
5467 | for (const auto &node : F->getNodes()) { |
5468 | const auto *addNode = dyn_cast<AddNode>(&node); |
5469 | if (!addNode) { |
5470 | continue; |
5471 | } |
5472 | |
5473 | NodeValue batchNode, addedNode; |
5474 | const auto &LHS = addNode->getLHS(); |
5475 | const auto &RHS = addNode->getRHS(); |
5476 | const TileNode *tileNode = nullptr; |
5477 | |
5478 | // Check if LHS is a tile. |
5479 | if ((tileNode = dyn_cast<TileNode>(LHS))) { |
5480 | batchNode = RHS; |
5481 | addedNode = tileNode->getInput(); |
5482 | } |
5483 | // Check if RHS is a tile. |
5484 | else if ((tileNode = dyn_cast<TileNode>(RHS))) { |
5485 | batchNode = LHS; |
5486 | addedNode = tileNode->getInput(); |
5487 | } |
5488 | // If neither LHS or RHS is a tile, nothing to do. |
5489 | else { |
5490 | continue; |
5491 | } |
5492 | |
5493 | // If the tiling of the added node is not along the 0th axis, |
5494 | // 'Add' cannot be replaced with 'BatchedAdd'. |
5495 | if (tileNode->getAxis() != 0) { |
5496 | continue; |
5497 | } |
5498 | |
5499 | auto oldDims = addedNode.dims(); |
5500 | // If the 0th dimension of the added node is not 1, |
5501 | // then reducing dimension via reshaping is more complicated. |
5502 | // Hence, Add will not be replaced with BatchedAdd. |
5503 | if (oldDims.size() == 0 || oldDims[0] != 1) { |
5504 | continue; |
5505 | } |
5506 | |
5507 | // Reshape the added node to create a slice for the batched add |
5508 | // such that its dim size is one less than that of the batch. |
5509 | const auto newDims = oldDims.take_back(oldDims.size() - 1); |
5510 | auto *slice = F->createReshape(tileNode->getName().str() + "_reshape" , |
5511 | addedNode, newDims); |
5512 | |
5513 | // Create a new batched add node to replace existing add node. |
5514 | auto *newBA = F->createBatchedAdd(addNode->getName().str() + "_batched_add" , |
5515 | batchNode, slice); |
5516 | addNode->getResult().replaceAllUsesOfWith(newBA); |
5517 | changed = true; |
5518 | } |
5519 | return changed; |
5520 | } |
5521 | |
5522 | /// Raise ClipNodes above shaping ops, e.g. Reshape, Transpose, Slice. Other |
5523 | /// passes will sink Clips to try to eliminate redundant ones. This pass should |
5524 | /// happen after sinking of Clips in order to try to get Clips to directly |
5525 | /// consume compute Nodes outputs. |
5526 | bool RaiseClipsAboveShapeNodes::run(Function *F, |
5527 | const CompilationContext &cctx) { |
5528 | LOG_SCOPE(F->getLogContext(), getName()); |
5529 | bool changed = false; |
5530 | |
5531 | // Keep track of what nodes will have oneLessUser due to DCE eventually. As an |
5532 | // example, we know that after we replace all users of OrigSlice with NewClip, |
5533 | // then OrigSlice and OrigClip are dead, and so Input1 will have one less user |
5534 | // after DCE. |
5535 | // Input Input |
5536 | // | / \ |
5537 | // OrigSlice NewClip OrigSlice <--| |
5538 | // | --> | | |-- (These two are now dead.) |
5539 | // OrigClip NewSlice OrigClip <--| |
5540 | // | | |
5541 | // Save Save |
5542 | std::unordered_set<Node *> oneLessUser; |
5543 | |
5544 | for (auto &N : F->getNodes()) { |
5545 | ClipNode *CN = dyn_cast<ClipNode>(&N); |
5546 | if (!CN) { |
5547 | continue; |
5548 | } |
5549 | |
5550 | // If the Clip's input has multiple users then do not raise the Clip, as |
5551 | // otherwise this will impact other Nodes. We subtract off an extra user |
5552 | // here if we know one user will be eliminated due to DCE eventually (see |
5553 | // above pic). |
5554 | unsigned numUsers = CN->getInput().getNode()->getNumUsers(); |
5555 | if (oneLessUser.count(CN->getInput().getNode())) { |
5556 | numUsers -= 1; |
5557 | } |
5558 | if (numUsers != 1) { |
5559 | continue; |
5560 | } |
5561 | |
5562 | // Sink Reshape below Clip. |
5563 | if (ReshapeNode *RN = dyn_cast<ReshapeNode>(CN->getInput())) { |
5564 | ClipNode *newCN = F->createClip(CN->getName(), RN->getInput(), |
5565 | CN->getMin(), CN->getMax()); |
5566 | ReshapeNode *newRN = F->createReshape(RN->getName(), newCN->getResult(), |
5567 | RN->getDims(), RN->getLayout()); |
5568 | CN->getResult().replaceAllUsesOfWith(newRN->getResult()); |
5569 | oneLessUser.insert(RN->getInput().getNode()); |
5570 | changed = true; |
5571 | continue; |
5572 | } |
5573 | |
5574 | // Sink Transpose below Clip. |
5575 | if (TransposeNode *TN = dyn_cast<TransposeNode>(CN->getInput())) { |
5576 | ClipNode *newCN = F->createClip(CN->getName(), TN->getInput(), |
5577 | CN->getMin(), CN->getMax()); |
5578 | TransposeNode *newTN = F->createTranspose( |
5579 | TN->getName(), newCN->getResult(), TN->getShuffle(), TN->getLayout()); |
5580 | CN->getResult().replaceAllUsesOfWith(newTN->getResult()); |
5581 | oneLessUser.insert(TN->getInput().getNode()); |
5582 | changed = true; |
5583 | continue; |
5584 | } |
5585 | |
5586 | // Sink Slice below Clip. |
5587 | if (SliceNode *SN = dyn_cast<SliceNode>(CN->getInput())) { |
5588 | ClipNode *newCN = F->createClip(CN->getName(), SN->getInput(), |
5589 | CN->getMin(), CN->getMax()); |
5590 | SliceNode *newSN = |
5591 | F->createSlice(SN->getName(), newCN->getResult(), SN->getStart(), |
5592 | SN->getResult().getType()); |
5593 | CN->getResult().replaceAllUsesOfWith(newSN->getResult()); |
5594 | oneLessUser.insert(SN->getInput().getNode()); |
5595 | changed = true; |
5596 | continue; |
5597 | } |
5598 | |
5599 | // Sink Tile below Clip. |
5600 | if (TileNode *TN = dyn_cast<TileNode>(CN->getInput())) { |
5601 | ClipNode *newCN = F->createClip(CN->getName(), TN->getInput(), |
5602 | CN->getMin(), CN->getMax()); |
5603 | TileNode *newTN = F->createTile(TN->getName(), newCN->getResult(), |
5604 | TN->getCount(), TN->getAxis()); |
5605 | CN->getResult().replaceAllUsesOfWith(newTN->getResult()); |
5606 | oneLessUser.insert(TN->getInput().getNode()); |
5607 | changed = true; |
5608 | continue; |
5609 | } |
5610 | } // For all nodes in the graph. |
5611 | |
5612 | return changed; |
5613 | } |
5614 | |
/// Fold ElemKind conversion nodes (ConvertTo, Quantize) into
/// single-user Placeholders. Note that this changes the semantics
/// of the IO of the Function and so must be done carefully, i.e. should always
/// be opt-in and done alongside conversion of corresponding Tensors in
/// PlaceholderBindings. If
/// cctx.optimizationOpts.foldElemKindConversionIntoIO is not set this will
/// only change Placeholders marked as static.
bool FoldElemKindConversionIntoInputs::run(Function *F,
                                           const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());

  bool changed = false;
  auto &nodes = F->getNodes();

  for (auto it = nodes.begin(), e = nodes.end(); it != e; it++) {
    Node *N = &*it;
    // Handle conversion of inputs (conversion of Placeholders):
    ConvertToNode *CTN = llvm::dyn_cast<ConvertToNode>(N);
    QuantizeNode *QN = llvm::dyn_cast<QuantizeNode>(N);
    if (CTN || QN) {
      NodeValue in = CTN ? CTN->getInput() : QN->getInput();
      // Only fold when the conversion's input is a Placeholder used solely by
      // this conversion; any other user would observe the type change.
      Placeholder *P = llvm::dyn_cast<Placeholder>(in);
      if (!P || P->getUsers().size() != 1) {
        continue;
      }
      // If foldElemKindConversionIntoIO is not set and this is not a static
      // placeholder then skip.
      if (!cctx.optimizationOpts.foldElemKindConversionIntoIO &&
          !P->isStatic()) {
        continue;
      }

      // We have a conversion of a single-use placeholder to some other type, so
      // it is safe to do the requested conversion.
      NodeValue res = CTN ? CTN->getResult() : QN->getResult();

      // Convert the type of the Placeholder to the conversion type. If target
      // type is fused call setTypeUnsafe because the shape can change in this
      // case.
      if (isFusedQuantizedElemKind(res.getElementType())) {
        P->setTypeUnsafe(Storage::OutputIdx, res.getType());
      } else {
        P->setType(Storage::OutputIdx, res.getType());
      }

      // Replace all uses of the original ConvertTo to the Placeholder.
      res.replaceAllUsesOfWith(P);

      changed = true;
      continue;
    }
  }
  return changed;
}
5669 | |
5670 | /// Fold ElemKind conversion nodes (ConvertTo, Dequantize) into SaveNodes. Note |
5671 | /// that this changes the semantics of the IO of the Function and so must be |
5672 | /// done carefully, i.e. should always be opt-in and done alongside conversion |
5673 | /// of corresponding Tensors in PlaceholderBindings. |
5674 | bool FoldElemKindConversionIntoOutputs::run(Function *F, |
5675 | const CompilationContext &cctx) { |
5676 | LOG_SCOPE(F->getLogContext(), getName()); |
5677 | |
5678 | std::unordered_set<SaveNode *> deadSaves; |
5679 | |
5680 | bool changed = false; |
5681 | // Since we will be adding in new SaveNodes, reverse iterate to be safe. |
5682 | auto &nodes = F->getNodes(); |
5683 | for (auto it = nodes.rbegin(), e = nodes.rend(); it != e; it++) { |
5684 | Node *N = &*it; |
5685 | |
5686 | // Handle conversion of outputs (SaveNodes + Placeholders): |
5687 | if (SaveNode *SN = llvm::dyn_cast<SaveNode>(N)) { |
5688 | if (!SN) { |
5689 | continue; |
5690 | } |
5691 | if (SN->getPlaceholder()->getUsers().size() != 1) { |
5692 | continue; |
5693 | } |
5694 | ConvertToNode *CTN = llvm::dyn_cast<ConvertToNode>(SN->getInput()); |
5695 | DequantizeNode *DQN = llvm::dyn_cast<DequantizeNode>(SN->getInput()); |
5696 | if (!CTN && !DQN) { |
5697 | continue; |
5698 | } |
5699 | NodeValue in = CTN ? CTN->getInput() : DQN->getInput(); |
5700 | |
5701 | // Set the type of the Placeholder to be same the conversion's input. |
5702 | SN->getPlaceholder()->setType(Storage::OutputIdx, in.getType()); |
5703 | |
5704 | // Create a new SaveNode directly using the conversion's input. |
5705 | F->createSave(SN->getName(), in, SN->getPlaceholder()); |
5706 | |
5707 | // Queue up deleting the original SaveNode as it won't be deleted via DCE. |
5708 | deadSaves.insert(SN); |
5709 | changed = true; |
5710 | continue; |
5711 | } |
5712 | } |
5713 | |
5714 | // Delete all the dead saves. |
5715 | for (SaveNode *SN : deadSaves) { |
5716 | F->eraseNode(SN); |
5717 | } |
5718 | |
5719 | return changed; |
5720 | } |
5721 | |
/// Broadcasts are implemented via 1) Reshape followed by a series of Tiles 2)
/// BroadcastNode. This helper unwinds Broadcast operation -- it \returns the
/// original Node before the broadcasting \p N if the broadcast takes place
/// between the 0th dimension to \p endDim. Otherwise, \p return \p N.
static NodeValue unwindBroadcast(NodeValue N, unsigned_t endDim) {
  if (auto *BN = dyn_cast<BroadcastNode>(N)) {
    const auto newShape = BN->getTargetDim();
    const auto axis = BN->getAxis();
    const auto &origDims = BN->getInput().dims();

    // The input's dims placed at `axis` must account for every trailing
    // target dimension; otherwise this is not the left-padded broadcast
    // pattern we can unwind.
    if (origDims.size() + axis != newShape.size()) {
      return N;
    }

    // From endDim onward the broadcast must be a no-op: each target dim must
    // come straight from the input unchanged. Only dims before endDim may
    // have been expanded.
    for (dim_t i = endDim; i < newShape.size(); i++) {
      if (!(i >= axis && origDims[i - axis] == newShape[i])) {
        return N;
      }
    }

    return BN->getInput();
  }

  // All non-BroadcastNode broadcasts must Tile at least once.
  if (!isa<TileNode>(N)) {
    return N;
  }

  // Walk backwards through the chain of Tiles implementing the broadcast.
  while (TileNode *TN = dyn_cast<TileNode>(N)) {
    // Check that the axis of the current Tile is inside of the expected
    // provided endDim.
    if (TN->getAxis() >= endDim) {
      return N;
    }
    // Applicable only if original dim is 1 in the Broadcast's Tile.
    if (TN->getInput().dims()[TN->getAxis()] != 1) {
      return N;
    }
    N = TN->getInput();
  }
  // A broadcast lowered as Reshape + Tiles starts with a Reshape; unwind
  // through it as well to reach the pre-broadcast value.
  if (ReshapeNode *RN = dyn_cast<ReshapeNode>(N)) {
    return RN->getInput();
  }

  return N;
}
5768 | |
/// Looks for supported arithmetic ops following LayerNorm and folds them into
/// the scale/bias of the LayerNorm if the scale/bias are single-use
/// Splat/Constant, as we can then later on constant fold them in.
bool FoldLayerNormArithmetic::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());

  bool changed = false;
  for (auto &N : F->getNodes()) {
    if (!isa<MulNode>(&N) && !isa<AddNode>(&N)) {
      continue;
    }
    // Currently only support floating point, as otherwise there will be
    // quantization parameter mismatches.
    if (!isFloatElemKind(N.getElementType(ArithmeticNode::ResultIdx))) {
      continue;
    }

    // Check if the Mul/Add is consuming an LN.
    LayerNormalizationNode *LN =
        dyn_cast<LayerNormalizationNode>(N.getNthInput(ArithmeticNode::LHSIdx));
    if (!LN) {
      continue;
    }

    // Check if the RHS is a Splat, or Constant, or temp PH (to be Constant
    // later). It may have been broadcasted to the correct shape.
    // endDim here is the number of leading dims the broadcast may expand,
    // i.e. the rank difference between the LN result and its scale.
    NodeValue RHS = unwindBroadcast(N.getNthInput(ArithmeticNode::RHSIdx),
                                    LN->getResult().dims().size() -
                                        LN->getScale().dims().size());

    auto *P = dyn_cast<Placeholder>(RHS);
    if (!isa<SplatNode>(RHS) && !isa<Constant>(RHS) &&
        !(P && cctx.optimizationOpts.tempPHsForConstants.count(P))) {
      continue;
    }

    // Make sure the RHS that we want to merge into the LN's Scale/Bias have the
    // same number of elements, since we're about to reshape them to match.
    if (RHS.getType()->size() != LN->getScale().getType()->size()) {
      continue;
    }

    // RHS may have already been fused with a Reshape to get ready for
    // Tiling. Reshape RHS back here if necessary.
    if (RHS.dims() != LN->getScale().dims()) {
      RHS = F->createReshape(RHS.getNode()->getName().str() + "_squeezed", RHS,
                             LN->getScale().dims());
    }

    if (MulNode *MN = dyn_cast<MulNode>(&N)) {
      // Merge the Mul into a new Scale and Bias, multiplying by the original
      // Mul RHS that followed the LayerNorm. Both scale and bias must be
      // multiplied since LN's output is scale * norm + bias.
      MulNode *newScale =
          F->createMul(LN->getScale().getNode()->getName().str() + "_fuse_" +
                           MN->getName().data(),
                       LN->getScale(), RHS);
      MulNode *newBias = F->createMul(LN->getBias().getNode()->getName().str() +
                                          "_fuse_" + MN->getName().data(),
                                      LN->getBias(), RHS);
      LN->getScale().replaceAllUsesOfWith(newScale->getResult(), F, newScale);
      LN->getBias().replaceAllUsesOfWith(newBias->getResult(), F, newBias);
      MN->getResult().replaceAllUsesOfWith(LN->getResult());
      changed = true;
      continue;
    }

    if (AddNode *AN = dyn_cast<AddNode>(&N)) {
      // Merge the Add into a new Bias, adding the original Add RHS that
      // followed the LayerNorm.
      AddNode *newBias = F->createAdd(LN->getBias().getNode()->getName().str() +
                                          "_fuse_" + AN->getName().data(),
                                      LN->getBias(), RHS);
      LN->getBias().replaceAllUsesOfWith(newBias->getResult(), F, newBias);
      AN->getResult().replaceAllUsesOfWith(LN->getResult());
      changed = true;
      continue;
    }
  }

  return changed;
}
5850 | |
5851 | /// Looks for an activation directly following \p N from \p F that the backend |
5852 | /// \p B supports for fusion. |
5853 | template <class T> bool fuseActivation(T *N, Function *F, const Backend *B) { |
5854 | if (!N || N->hasFusedActivation() || !N->getResult().hasOneUse()) { |
5855 | return false; |
5856 | } |
5857 | |
5858 | // We know there is one result user so we can just deref the first result. |
5859 | Node *activation = (*N->getResult().getUsers().begin()).getUser(); |
5860 | if (!B || !B->supportsFusedActivation(N, activation)) { |
5861 | return false; |
5862 | } |
5863 | |
5864 | NodeValue activationNV; |
5865 | switch (activation->getKind()) { |
5866 | case Kinded::Kind::ReluNodeKind: |
5867 | activationNV = cast<ReluNode>(activation)->getResult(); |
5868 | N->setFusedActivation(FusedActivation::RELU); |
5869 | break; |
5870 | case Kinded::Kind::ClipNodeKind: |
5871 | activationNV = cast<ClipNode>(activation)->getResult(); |
5872 | N->setFusedActivation(FusedActivation::CLIP); |
5873 | N->setFusedActivationArgs({cast<ClipNode>(activation)->getMin(), |
5874 | cast<ClipNode>(activation)->getMax()}); |
5875 | break; |
5876 | case Kinded::Kind::SigmoidNodeKind: |
5877 | activationNV = cast<SigmoidNode>(activation)->getResult(); |
5878 | N->setFusedActivation(FusedActivation::SIGMOID); |
5879 | break; |
5880 | case Kinded::Kind::TanhNodeKind: |
5881 | activationNV = cast<TanhNode>(activation)->getResult(); |
5882 | N->setFusedActivation(FusedActivation::TANH); |
5883 | break; |
5884 | case Kinded::Kind::LeakyReluNodeKind: |
5885 | activationNV = cast<LeakyReluNode>(activation)->getResult(); |
5886 | N->setFusedActivation(FusedActivation::LEAKY_RELU); |
5887 | N->setFusedActivationArgs({cast<LeakyReluNode>(activation)->getAlpha()}); |
5888 | break; |
5889 | default: |
5890 | return false; |
5891 | } |
5892 | |
5893 | // Modify the node output type to that of the activation. |
5894 | if (!(activationNV.getType()->isEqual(N->getResult().getType()))) { |
5895 | N->getResult().setType(activationNV.getType()); |
5896 | } |
5897 | |
5898 | activationNV.replaceAllUsesOfWith(N->getResult()); |
5899 | return true; |
5900 | } |
5901 | |
5902 | static bool foldActivations(Function *F, const CompilationContext &cctx, |
5903 | const Backend *B) { |
5904 | bool changed = false; |
5905 | for (auto &node : F->getNodes()) { |
5906 | if (fuseActivation(dyn_cast<ConvolutionNode>(&node), F, B)) { |
5907 | changed = true; |
5908 | continue; |
5909 | } |
5910 | if (fuseActivation(dyn_cast<ChannelwiseQuantizedConvolutionNode>(&node), F, |
5911 | B)) { |
5912 | changed = true; |
5913 | continue; |
5914 | } |
5915 | } |
5916 | return changed; |
5917 | } |
5918 | |
5919 | void glow::fold(Function *F, const CompilationContext &cctx, const Backend *B) { |
5920 | LOG_SCOPE(F->getLogContext(), "glow::fold" ) |
5921 | |
5922 | FunctionPassManager FPM("FoldFPM" , createDefaultFoldPassPipeline()); |
5923 | FPM.run(F, cctx); |
5924 | |
5925 | foldActivations(F, cctx, B); |
5926 | } |
5927 | |
5928 | void glow::optimize(Function *F, const CompilationContext &cctx, |
5929 | const Backend &B) { |
5930 | LOG_SCOPE(F->getLogContext(), "glow::optimize" ) |
5931 | |
5932 | FunctionPassManager FPM("TargetDependentGraphOptzFPM" , |
5933 | B.getOptimizationPipeline(), &B); |
5934 | FPM.run(F, cctx); |
5935 | } |
5936 | |
5937 | void glow::optimize(Function *F, const CompilationContext &cctx) { |
5938 | LOG_SCOPE(F->getLogContext(), "glow::optimize" ) |
5939 | |
5940 | // Indicates if the given function is completely loaded. A temporary |
5941 | // workaround until #3213 is complete. |
5942 | F->setState(FunctionState::FuncLoaded); |
5943 | |
5944 | FunctionPassManager FPM("TargetIndependentGraphOptzFPM" , |
5945 | createDefaultGraphOptimizationPassPipeline()); |
5946 | FPM.run(F, cctx); |
5947 | } |
5948 | |
5949 | void glow::optimize(Function *F, CompilationMode mode) { |
5950 | CompilationContext cctx; |
5951 | cctx.compMode = mode; |
5952 | optimize(F, cctx); |
5953 | } |
5954 | |
/// Helper to pass over all Nodes in \p F and set FP16 accumulation to true for
/// those Nodes in the SLS family which support and need it. \p precConfig
/// contains the black/whitelist for skipping.
static void setFP16AccumSLS(Function *F,
                            const PrecisionConfiguration &precConfig) {
  // Iterate from original end to beginning to avoid processing new Nodes added
  // during the pass.
  auto nodeIt = F->getNodes().end();
  auto stopIt = F->getNodes().begin();
  do {
    --nodeIt;
    Node &node = *nodeIt;
    // Only update allowed nodes based on black/whitelist.
    const bool inSet = precConfig.precisionModeKindSet.count(node.getKind());
    const bool allowConversion = precConfig.useSetAsWhitelist ? inSet : !inSet;
    if (!allowConversion) {
      continue;
    }

// Expands to a switch case that enables FP16 accumulation for the given SLS
// node kind, but only when its result is already Float16. The `continue`
// statements target the enclosing do-while loop.
#define CASE_SET_SLS_FP16_ACCUM(NODE_)                                         \
  case Kinded::Kind::NODE_##NodeKind: {                                        \
    NODE_##Node *SLS = llvm::cast<NODE_##Node>(&node);                         \
    if (SLS->getResult().getElementType() != ElemKind::Float16Ty) {            \
      continue;                                                                \
    }                                                                          \
    SLS->setUseFP16Accumulation(true);                                         \
    continue;                                                                  \
  }

    switch (node.getKind()) {
      CASE_SET_SLS_FP16_ACCUM(RowwiseQuantizedSparseLengthsWeightedSum);
      CASE_SET_SLS_FP16_ACCUM(FusedRowwiseQuantizedSparseLengthsWeightedSum);
      CASE_SET_SLS_FP16_ACCUM(FusedRowwiseQuantizedSparseLengthsSum);
      CASE_SET_SLS_FP16_ACCUM(EmbeddingBagByteRowwiseOffsets);
    default:
      continue;
    }
  } while (nodeIt != stopIt);
}
5994 | |
5995 | /// Look for Dequantize -> Swish -> Quantize, replace it with a quantized Swish. |
5996 | bool QuantizeSwish::run(Function *F, const CompilationContext &cctx) { |
5997 | LOG_SCOPE(F->getLogContext(), getName()); |
5998 | |
5999 | if (!cctx.optimizationOpts.enableQuantParamChanges) { |
6000 | return false; |
6001 | } |
6002 | |
6003 | bool changed = false; |
6004 | for (auto &N : F->getNodes()) { |
6005 | auto *SN = dyn_cast<SwishNode>(&N); |
6006 | if (!SN || SN->getNumUsers() != 1) { |
6007 | continue; |
6008 | } |
6009 | |
6010 | QuantizeNode *QN = |
6011 | dyn_cast<QuantizeNode>((*SN->getUsers().begin()).getUser()); |
6012 | if (!QN) { |
6013 | continue; |
6014 | } |
6015 | |
6016 | DequantizeNode *DN = dyn_cast<DequantizeNode>(SN->getInput()); |
6017 | if (!DN) { |
6018 | continue; |
6019 | } |
6020 | |
6021 | SwishNode *newSN = |
6022 | F->createSwish(SN->getName().str() + "_int" , DN->getInput(), |
6023 | QN->getResult().getType()); |
6024 | QN->getResult().replaceAllUsesOfWith(newSN); |
6025 | changed = true; |
6026 | } |
6027 | return changed; |
6028 | } |
6029 | |
/// Fold Exp + ReduceSum + Div into Softmax
///          IN
///           |
///          Exp              IN
///         /   \              |
///        | ReduceSum --> Softmax
///         \   /              |
///          Div              OUT
///           |
///          OUT
bool FoldExpSumDivIntoSoftmax::run(Function *F,
                                   const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());

  bool changed = false;
  for (auto &N : F->getNodes()) {
    // The Exp must feed exactly the Div and the ReduceSum, i.e. two users.
    auto *EN = dyn_cast<ExpNode>(&N);
    if (!EN || EN->getNumUsers() != 2) {
      continue;
    }

    DivNode *DN = nullptr;
    BatchedReduceAddNode *RSN = nullptr;

    auto *user1 = EN->getUsers().front().getUser();
    auto *user2 = EN->getUsers().back().getUser();

    // The two users must be a Div and a ReduceSum, in either order.
    if (isa<DivNode>(user1) && isa<BatchedReduceAddNode>(user2)) {
      DN = cast<DivNode>(user1);
      RSN = cast<BatchedReduceAddNode>(user2);
    } else if (isa<DivNode>(user2) && isa<BatchedReduceAddNode>(user1)) {
      DN = cast<DivNode>(user2);
      RSN = cast<BatchedReduceAddNode>(user1);
    } else {
      continue;
    }

    if (RSN->getNumUsers() != 1) {
      continue;
    }

    // Expect the reduced sum to be broadcast back up before dividing:
    // Exp -> ReduceSum -> Broadcast -> Div.
    auto *broadcastNode = getOnlyUser(*RSN);
    if (broadcastNode == nullptr) {
      continue;
    }
    auto *tempDN = getOnlyUser(*broadcastNode);
    // Ensure that the inputs to the DivNode are Exp and ReduceSum.
    if (DN != tempDN) {
      continue;
    }

    // Build the "selected" operand SoftMax requires; its shape is the input
    // shape with the innermost dim collapsed to 1. NOTE(review): the Constant
    // is left uninitialized here — presumably SoftMax ignores `selected`
    // during inference; confirm before relying on its contents.
    auto axes = EN->getInput().dims().vec();
    axes.back() = 1;
    auto *CN = F->getParent()->createConstant(glow::ElemKind::Int64ITy, axes,
                                              "selected");

    auto *SM = F->createSoftMax("softmax", EN->getInput(), CN);
    DN->getResult().replaceAllUsesOfWith(SM);
    changed = true;
  }
  return changed;
}
6092 | |
6093 | /// Local utility to remove identity Relu if fused into \p node. |
6094 | /// \returns true or false whether the Relu was removed or not. |
6095 | template <class NodeTy> static bool removeFusedIdentityRelu(Node *node) { |
6096 | auto *RN = dyn_cast<NodeTy>(node); |
6097 | if (!RN || RN->getFusedActivation() != FusedActivation::RELU) { |
6098 | return false; |
6099 | } |
6100 | // The output type must be quantized. |
6101 | auto outTy = RN->getResult().getType(); |
6102 | if (!outTy->isQuantizedType()) { |
6103 | return false; |
6104 | } |
6105 | // The quantized 0.0f for Relu must match the min of the output type. |
6106 | auto outRange = quantization::getQuantizedRange(outTy->getElementType()); |
6107 | if (outTy->getOffset() != outRange.first) { |
6108 | return false; |
6109 | } |
6110 | // Remove fused Relu. |
6111 | RN->setFusedActivation(FusedActivation::NONE); |
6112 | RN->setFusedActivationArgs({}); |
6113 | return true; |
6114 | } |
6115 | |
6116 | bool RemoveIdentityRelu::run(Function *F, const CompilationContext &cctx) { |
6117 | LOG_SCOPE(F->getLogContext(), getName()); |
6118 | bool changed = false; |
6119 | |
6120 | // Remove standalone Relu. |
6121 | for (auto &N : F->getNodes()) { |
6122 | auto *RN = dyn_cast<ReluNode>(&N); |
6123 | if (!RN) { |
6124 | continue; |
6125 | } |
6126 | |
6127 | // The input and output types must be quantized. |
6128 | auto inpTy = RN->getInput().getType(); |
6129 | auto outTy = RN->getResult().getType(); |
6130 | if (!(inpTy->isQuantizedType() && outTy->isQuantizedType())) { |
6131 | continue; |
6132 | } |
6133 | |
6134 | // The quantized 0.0f for Relu must match the min of the output type. |
6135 | auto outRange = quantization::getQuantizedRange(outTy->getElementType()); |
6136 | if (outTy->getOffset() != outRange.first) { |
6137 | continue; |
6138 | } |
6139 | |
6140 | // Remove Relu if input and output types are the same. |
6141 | // Otherwise change it with a RescaleQuantized. |
6142 | if (inpTy->isEqual(outTy)) { |
6143 | RN->getResult().replaceAllUsesOfWith(RN->getInput()); |
6144 | } else { |
6145 | // TODO: Uncomment this once #5729 gets fixed. |
6146 | // auto *rescale = |
6147 | // F->createRescaleQuantized(RN->getName(), RN->getInput(), outTy); |
6148 | // RN->getResult().replaceAllUsesOfWith(rescale); |
6149 | } |
6150 | changed = true; |
6151 | } |
6152 | |
6153 | // Remove fused Relu. |
6154 | for (auto &N : F->getNodes()) { |
6155 | changed |= removeFusedIdentityRelu<ConvolutionNode>(&N); |
6156 | changed |= removeFusedIdentityRelu<ChannelwiseQuantizedConvolutionNode>(&N); |
6157 | } |
6158 | |
6159 | return changed; |
6160 | } |
6161 | |
/// Local utility to remove identity Clip if fused into \p node.
/// \returns true or false whether the Clip was removed or not.
template <class NodeTy> static bool removeFusedIdentityClip(Node *node) {
  auto *CN = dyn_cast<NodeTy>(node);
  if (!CN || CN->getFusedActivation() != FusedActivation::CLIP) {
    return false;
  }
  // The output type must be quantized.
  auto outTy = CN->getResult().getType();
  if (!outTy->isQuantizedType()) {
    return false;
  }
  // The quantized min/max for Clip must match the min/max of the output type.
  TensorQuantizationParams outTQP{outTy->getScale(), outTy->getOffset()};
  auto fMin = CN->getFusedActivationArgs()[0];
  auto fMax = CN->getFusedActivationArgs()[1];
  auto qMin = quantization::quantize(fMin, outTQP, outTy->getElementType());
  auto qMax = quantization::quantize(fMax, outTQP, outTy->getElementType());
  auto outRange = quantization::getQuantizedRange(outTy->getElementType());
  if (!(qMin == outRange.first && qMax == outRange.second)) {
    return false;
  }
  // Remove fused Clip.
  CN->setFusedActivation(FusedActivation::NONE);
  CN->setFusedActivationArgs({});
  return true;
}
6189 | |
6190 | bool RemoveIdentityClip::run(Function *F, const CompilationContext &cctx) { |
6191 | LOG_SCOPE(F->getLogContext(), getName()); |
6192 | bool changed = false; |
6193 | |
6194 | // Remove standalone Clip. |
6195 | for (auto &N : F->getNodes()) { |
6196 | auto *CN = dyn_cast<ClipNode>(&N); |
6197 | if (!CN) { |
6198 | continue; |
6199 | } |
6200 | |
6201 | // The input and output types must be quantized. |
6202 | auto inpTy = CN->getInput().getType(); |
6203 | auto outTy = CN->getResult().getType(); |
6204 | if (!(inpTy->isQuantizedType() && outTy->isQuantizedType())) { |
6205 | continue; |
6206 | } |
6207 | |
6208 | // The quantized min/max for Clip must match the min/max of the output type. |
6209 | TensorQuantizationParams outTQP{outTy->getScale(), outTy->getOffset()}; |
6210 | auto fMin = CN->getMin(); |
6211 | auto fMax = CN->getMax(); |
6212 | auto qMin = quantization::quantize(fMin, outTQP, outTy->getElementType()); |
6213 | auto qMax = quantization::quantize(fMax, outTQP, outTy->getElementType()); |
6214 | auto outRange = quantization::getQuantizedRange(outTy->getElementType()); |
6215 | if (!(qMin == outRange.first && qMax == outRange.second)) { |
6216 | continue; |
6217 | } |
6218 | |
6219 | // Remove Clip if input and output types are the same. |
6220 | // Otherwise change it with a RescaleQuantized. |
6221 | if (inpTy->isEqual(outTy)) { |
6222 | CN->getResult().replaceAllUsesOfWith(CN->getInput()); |
6223 | } else { |
6224 | // TODO: Uncomment this once #5729 gets fixed. |
6225 | // auto *rescale = |
6226 | // F->createRescaleQuantized(CN->getName(), CN->getInput(), outTy); |
6227 | // CN->getResult().replaceAllUsesOfWith(rescale); |
6228 | } |
6229 | changed = true; |
6230 | } |
6231 | |
6232 | // Remove fused Clip. |
6233 | for (auto &N : F->getNodes()) { |
6234 | changed |= removeFusedIdentityClip<ConvolutionNode>(&N); |
6235 | changed |= removeFusedIdentityClip<ChannelwiseQuantizedConvolutionNode>(&N); |
6236 | } |
6237 | |
6238 | return changed; |
6239 | } |
6240 | |
6241 | /// Convert a FullyConnected node to a 1x1 Convolution. |
6242 | bool ConvertFullyConnectedToConvolution::run(Function *F, |
6243 | const CompilationContext &cctx) { |
6244 | LOG_SCOPE(F->getLogContext(), getName()); |
6245 | |
6246 | bool changed = false; |
6247 | for (auto &N : F->getNodes()) { |
6248 | auto *FCN = dyn_cast<FullyConnectedNode>(&N); |
6249 | if (!FCN) { |
6250 | continue; |
6251 | } |
6252 | |
6253 | NodeValue output = FCN->getResult(); |
6254 | NodeValue input = FCN->getInput(); |
6255 | NodeValue filter = FCN->getWeights(); |
6256 | NodeValue bias = FCN->getBias(); |
6257 | |
6258 | // Reshape input from 2D to 4D. |
6259 | auto inpDims = ShapeHW(input.getType()->dims()); |
6260 | std::vector<dim_t> inpDimsCN = {inpDims.height, 1, 1, inpDims.width}; |
6261 | input = F->createReshape(FCN->getName(), input, inpDimsCN); |
6262 | |
6263 | // Transpose filter and reshape from 2D to 4D. |
6264 | filter = F->createTranspose(FCN->getName(), filter, {1, 0}); |
6265 | auto filterDims = ShapeHW(filter.getType()->dims()); |
6266 | std::vector<dim_t> filterDimsCN = {filterDims.height, 1, 1, |
6267 | filterDims.width}; |
6268 | filter = F->createReshape(FCN->getName(), filter, filterDimsCN); |
6269 | |
6270 | // Create Conv2D node with same output type but 4D shape. |
6271 | auto outDims = ShapeHW(output.getType()->dims()); |
6272 | std::vector<dim_t> outDimsCN = {outDims.height, 1, 1, outDims.width}; |
6273 | auto outTyCN = |
6274 | F->getParent()->uniqueTypeWithNewShape(output.getType(), outDimsCN); |
6275 | NodeValue outputCN = |
6276 | F->createConv(FCN->getName(), input, filter, bias, outTyCN, |
6277 | /* kernels */ {1, 1}, |
6278 | /* strides */ {1, 1}, |
6279 | /* pads */ {0, 0, 0, 0}, |
6280 | /* group */ 1); |
6281 | |
6282 | // Reshape the 4D output back to its original 2D shape. |
6283 | outputCN = |
6284 | F->createReshape(FCN->getName(), outputCN, output.getType()->dims()); |
6285 | FCN->getResult().replaceAllUsesOfWith(outputCN); |
6286 | changed = true; |
6287 | } |
6288 | return changed; |
6289 | } |
6290 | |
/// Fold Min -> Max or Max -> Min to Clip
/// We need to place this function after `OptimizeArithmeticNodes` in which the
/// constant operator in commutative nodes is moved to the RHS, so that we can
/// assume SplatNode is on the RHS of MinNode (MaxNode).
/// Only if it's the following structure we do this optimization
///
///        someOtherInput   SplatNode
///                  \       /
///   SplatNode     MinNode (MaxNode)
///          \       /
///         MaxNode (MinNode)
///             |
///            Out
///
/// ClipNode's min will be the SplatNode (maxSN) connected to the MaxNode.
/// ClipNode's max will be the SplatNode (minSN) connected to the MinNode.
bool FoldMinMaxToClip::run(Function *F, const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());

  bool changed = false;
  for (auto &N : F->getNodes()) {
    MaxNode *maxNode = dyn_cast<MaxNode>(&N);
    MinNode *minNode = dyn_cast<MinNode>(&N);
    NodeValue otherInput;
    NodeValue resultNV;
    SplatNode *minSN = nullptr;
    SplatNode *maxSN = nullptr;

    // We assume SplatNode is on the RHS.
    // If currect node is MinNode
    //   - try casting the input to MaxNode and SplatNode.
    //   - If LHS of MinNode is MaxNode
    //     - try casting RHS of maxNode to SplatNode.
    // Note: every operand is unwound through any Broadcast first, since the
    // Splat or the other input may have been broadcast to the result shape.
    if (minNode) {
      resultNV = minNode->getResult();
      maxNode = dyn_cast<MaxNode>(
          unwindBroadcast(minNode->getLHS(), minNode->getLHS().dims().size()));
      minSN = dyn_cast<SplatNode>(
          unwindBroadcast(minNode->getRHS(), minNode->getRHS().dims().size()));

      if (maxNode) {
        maxSN = dyn_cast<SplatNode>(unwindBroadcast(
            maxNode->getRHS(), maxNode->getRHS().dims().size()));
        otherInput =
            unwindBroadcast(maxNode->getLHS(), maxNode->getLHS().dims().size());
      }
    } else if (maxNode) { // vice versa for MaxNode
      resultNV = maxNode->getResult();
      minNode = dyn_cast<MinNode>(
          unwindBroadcast(maxNode->getLHS(), maxNode->getLHS().dims().size()));
      maxSN = dyn_cast<SplatNode>(
          unwindBroadcast(maxNode->getRHS(), maxNode->getRHS().dims().size()));

      if (minNode) {
        minSN = dyn_cast<SplatNode>(unwindBroadcast(
            minNode->getRHS(), minNode->getRHS().dims().size()));
        otherInput =
            unwindBroadcast(minNode->getLHS(), minNode->getLHS().dims().size());
      }
    }

    // If any of these is nullptr, it means the structure is not the same as
    // above example.
    if (!(minNode && maxNode && minSN && maxSN)) {
      continue;
    }

    // If minSN is smaller than maxSN, which is really weird because every entry
    // of the result will be a same number. In this case, we don't fold them.
    if (minSN->getValue() < maxSN->getValue()) {
      DLOG(INFO) << "Catch a combination of MinNode and MaxNode while MinSplat "
                    "is smaller than MaxSplat, which would make all entries of "
                    "the result tensor to be the same!";
      continue;
    }

    // The second node is broadcasted and we need to broadcast the otherInput.
    if (otherInput.dims() != resultNV.dims()) {
      otherInput = F->createBroadcast(
          otherInput.getNode()->getName().str() + ".broadcast", otherInput,
          resultNV.dims(), resultNV.dims().size() - otherInput.dims().size());
    }

    // MinSN is the SplatNode input of MinNode while MaxSN is the SplatNode
    // input of MaxNode. MinSN should be greater than MaxSN so we put MaxSN as
    // the min of ClipNode.
    auto minValue = maxSN->getValue();
    auto maxValue = minSN->getValue();
    ClipNode *CN = F->createClip(resultNV.getNode()->getName(), otherInput,
                                 resultNV.getType(),
                                 /* min */ minValue, /* max */ maxValue);
    resultNV.replaceAllUsesOfWith(CN->getResult());
    changed = true;
  }

  return changed;
}
6388 | |
/// Look for qparams with scale (when casted to fp16) == 0 and replace them with
/// zero Splats.
bool ReplaceZeroScaleFP16QuantNodes::run(Function *F,
                                         const CompilationContext &cctx) {
  LOG_SCOPE(F->getLogContext(), getName());

  // Cannot run this opt if we're using dummy qparams.
  if (cctx.precisionConfig.loadUniquedDummyQParams) {
    return false;
  }

  // Replaces \p resNV with a zero Splat (scale 1, offset 0) when its scale
  // would flush to zero in FP16. \returns true if a replacement happened.
  auto processNV = [](Function *F, NodeValue resNV) {
    const TypeRef resTy = resNV.getType();
    // Only plain quantized types apply; fused types carry qparams per row.
    if (resTy->isFusedQuantizedType() || !resTy->isQuantizedType()) {
      return false;
    }

    // Check if we have a scale that's below the minimum allowed FP16 val.
    if (resTy->getScale() >= kMinScaleFP16) {
      return false;
    }

    // Skip if used by node with side effects, since we cannot change the
    // qparams in such cases.
    if (isUsedByNodeWithSideEffects(resNV.getNode())) {
      return false;
    }

    // This NodeValue has scale = 0.f, which means the equivalent float result
    // will always be equal to 0.f. So create a splat with value 0.f, scale 1.f,
    // offset 0 instead.
    auto splatTy = F->getParent()->uniqueType(resTy->getElementType(),
                                              resTy->dims(), 1.f, 0);
    auto *SN = F->createSplat(resNV.getNode()->getName().str() + ".splatted",
                              splatTy, 0.f);

    // Note: must use type unsafe replace because we've changed the qparams.
    resNV.typeUnsafeReplaceAllUsesOfWith(SN->getResult(), F);
    return true;
  };

  bool changed = false;
  // Since we will be adding in new SplatNodes, reverse iterate to be safe.
  auto &nodes = F->getNodes();
  for (auto it = nodes.rbegin(), e = nodes.rend(); it != e; it++) {
    Node *N = &*it;
    // For now only support nodes with single outputs.
    if (N->getNumResults() != 1) {
      continue;
    }
    NodeValue resNV = N->getNthResult(0);
    changed |= processNV(F, resNV);
  }
  // Also process the graph's inputs and constants, whose qparams can be zero
  // too.
  for (Placeholder *PH : F->findPlaceholders()) {
    changed |= processNV(F, PH->getOutput());
  }
  for (Constant *C : F->findConstants()) {
    changed |= processNV(F, C->getOutput());
  }

  return changed;
}
6451 | |
6452 | /// This funciton uses TypeAToTypeBFunctionConverter to do a whole graph |
6453 | /// demotion of Index type from INT64 to INT32. |
6454 | static void transformIndexTypeDemotion(const Backend &B, Function *F, |
6455 | CompilationContext &cctx) { |
6456 | |
6457 | // Does a coarse check to make sure none of the indices potentially can |
6458 | // overflow 32 bit. For now we just give up on the whole optimization, since |
6459 | // this is probably a corner case. |
6460 | for (auto &n : F->getNodes()) { |
6461 | for (int i = 0, nOutputs = n.getNumResults(); i < nOutputs; ++i) { |
6462 | if (n.getNthResult(i).getType()->actualSize() >= |
6463 | size_t(std::numeric_limits<int32_t>::max())) { |
6464 | return; |
6465 | } |
6466 | } |
6467 | } |
6468 | |
6469 | PrecisionConfiguration precConfig; |
6470 | if (B.canDoIndexTypeDemotion(ElemKind::Int64ITy, ElemKind::Int32ITy, |
6471 | precConfig) && |
6472 | cctx.optimizationOpts.enableTypeDemotion) { |
6473 | precConfig.precisionModeKindSet.insert(Kinded::Kind::TraceEventNodeKind); |
6474 | TypeAToTypeBFunctionConverter converter(*F, ElemKind::Int64ITy, |
6475 | ElemKind::Int32ITy, precConfig); |
6476 | converter.convert(); |
6477 | } |
6478 | } |
6479 | |
6480 | void glow::transformForPrecisionMode(const Backend &B, Function *F, |
6481 | CompilationContext &cctx) { |
6482 | LOG_SCOPE(F->getLogContext(), "transformForPrecisionMode" ) |
6483 | const PrecisionConfiguration &precConfig = cctx.precisionConfig; |
6484 | |
6485 | switch (precConfig.quantMode) { |
6486 | case QuantizationMode::Profile: { |
6487 | assert(cctx.bindings); |
6488 | |
6489 | LOG_SCOPE(F->getLogContext(), "glow::profileQuantization" ) |
6490 | |
6491 | glow::profileQuantization(*cctx.bindings, F, precConfig.profConfig); |
6492 | break; |
6493 | } |
6494 | |
6495 | case QuantizationMode::Quantize: { |
6496 | LOG_SCOPE(F->getLogContext(), "quantization::quantizeFunction" ) |
6497 | |
6498 | quantization::quantizeFunction(F, precConfig.quantConfig, B, |
6499 | *cctx.loweredInfoMap, |
6500 | precConfig.precisionModeKindSet); |
6501 | break; |
6502 | } |
6503 | |
6504 | case QuantizationMode::None: { |
6505 | break; |
6506 | } |
6507 | } |
6508 | |
6509 | if (precConfig.convertToFP16) { |
6510 | LOG_SCOPE(F->getLogContext(), "glow::convertFunctionToFloat16" ) |
6511 | convertFunctionToFloat16(F, precConfig); |
6512 | FunctionPassManager FPM("FP16GraphOptzFPM" , |
6513 | createFP16GraphOptimizationPassPipeline()); |
6514 | FPM.run(F, cctx); |
6515 | } |
6516 | |
6517 | // By default, FP16 SLS accumulation is not enabled. |
6518 | // If requested, Force all ops in the SLS family to use FP16 accumulation. |
6519 | if (precConfig.forceFP16AccumSLS) { |
6520 | setFP16AccumSLS(F, precConfig); |
6521 | } |
6522 | |
6523 | // Convert UInt4FusedFP16QTy/UInt8FusedFP16QTy to UInt8FusedQTy. |
6524 | if (precConfig.convert4BitFusedToFP32 || precConfig.convert8BitFusedToFP32) { |
6525 | LOG_SCOPE(F->getLogContext(), "glow::convertFunctionToFP32ScaleOffset" ); |
6526 | convertFunctionToFP32ScaleOffset(F, precConfig); |
6527 | } |
6528 | |
6529 | // In FusedRowwiseQSLWS, convert its indices from Int32(if there is any) to |
6530 | // Int64. |
6531 | if (precConfig.convertIndicesToInt64) { |
6532 | LOG_SCOPE(F->getLogContext(), "glow::convertFunctionIndicesToInt64" ); |
6533 | convertFunctionIndicesToInt64(F, precConfig); |
6534 | } |
6535 | } |
6536 | |
6537 | Error glow::optimizeFunctionBeforeLowering(Function *F, |
6538 | CompilationContext &cctx) { |
6539 | LOG_SCOPE(F->getLogContext(), "glow::optimizeFunctionBeforeLowering" ) |
6540 | |
6541 | // If we only want to lower the Function, do nothing here. |
6542 | if (cctx.optimizationOpts.onlyLowerFuns.count(F)) { |
6543 | return Error::success(); |
6544 | } |
6545 | |
6546 | // Verify the function pre-optimization/lowering. |
6547 | assert(F->verify() && "Function must be valid" ); |
6548 | |
6549 | // Verify that the CompilationContext is set up correctly. |
6550 | RETURN_IF_ERR(cctx.verify()); |
6551 | |
6552 | // Fold low-level operators into higher-level operators. |
6553 | // This is useful when compiling an input model where some high-level |
6554 | // operators have been lowered (this can be for instance a side effect of |
6555 | // model converters, like converters from Tensorflow to ONNX). In this |
6556 | // situation, such folding can then enable more optimizations and also improve |
6557 | // the performance backends that support natively such high-level operators. |
6558 | ::glow::fold(F, cctx); |
6559 | |
6560 | // Optimize the graph. Only runs optimizations that are target-independent. |
6561 | ::glow::optimize(F, cctx); |
6562 | return Error::success(); |
6563 | } |
6564 | |
/// Error message to print when there is a graph hash checking error, i.e. when
/// the graph hash imported from a profile YAML file does not match the hash of
/// the graph currently being quantized. The two %s placeholders are filled by
/// strFormat with the profile hash string and the current hash string.
static const char *graphPreLowerHashCheckErrMsg =
    R"RAW(Graph hash mismatch!
%s
%s
Potential causes:
1. The profile YAML file was produced with an older version of the Glow tools
while the quantization of the model is performed with a newer version.
2. The profile YAML file was produced for a different model than the model used
for quantization.
3. The profile YAML file was produced for the same model but for a different
batch size. If the profile was generated using the 'image-classifier' Glow
tool you can select the batch size of the model during profiling using the
'minibatch' option. During quantization you can choose the batch size of the
model by choosing for each input placeholder the tensor size using the
'model-input' option.
)RAW";
6582 | |
// NOTE: When updating this function, please also update the documentation in
// docs/GraphOptimizationPipeline.md
Error glow::optimizeFunction(Function *F, const Backend &B,
                             CompilationContext &cctx,
                             const glow::runtime::DeviceInfo *devInfo) {
  LOG_SCOPE(F->getLogContext(), "glow::optimizeFunction")

  // If requested only lower the Function and early return.
  if (cctx.optimizationOpts.onlyLowerFuns.count(F)) {
    ::glow::lower(F, cctx, &B);
    // Cleanup from lowering via DCE.
    runDCEPass(F, cctx);

    // Even on this short path the backend must accept every remaining node.
    if (!B.verify(*F, cctx.verboseCompile)) {
      return MAKE_ERR(
          ErrorValue::ErrorCode::COMPILE_UNSUPPORTED_NODE_AFTER_OPTIMIZE,
          "Unsupported node(s) found after only-lowering path for Function " +
              F->getName().str() + " for backend " + B.getBackendName());
    }
    return Error::success();
  }

  RETURN_IF_ERR(optimizeFunctionBeforeLowering(F, cctx));

  // Graph hash check:
  // - During PROFILING store the hash of the graph at this point of the
  // optimization pipeline. The hash will be exported in the YAML profile.
  // - During QUANTIZATION the hash imported from the YAML profile is used to
  // verify the hash of the graph. This is helpful to catch mismatches between
  // the graph used during profiling/quantization.
  const PrecisionConfiguration &precConfig = cctx.precisionConfig;
  if (precConfig.quantMode == QuantizationMode::Profile) {
    cctx.info.graphPreLowerHash = F->getHash();
  } else if (precConfig.quantMode == QuantizationMode::Quantize) {
    const auto &quantConfig = cctx.precisionConfig.quantConfig;
    if (quantConfig.checkGraphPreLowerHash) {
      auto profileHash = quantConfig.graphPreLowerHash;
      auto currentHash = F->getHash();
      auto profileHashStr =
          strFormat("Profile graph hash: 0x%" PRIX64, (uint64_t)(profileHash));
      auto currentHashStr =
          strFormat("Current graph hash: 0x%" PRIX64, (uint64_t)(currentHash));
      // graphPreLowerHashCheckErrMsg explains the likely causes to the user.
      RETURN_ERR_IF_NOT(profileHash == currentHash,
                        strFormat(graphPreLowerHashCheckErrMsg,
                                  profileHashStr.c_str(),
                                  currentHashStr.c_str()));
    }
  }

  // Lower the graph into a sequence of low-level linear algebra operations.
  if (precConfig.quantMode == QuantizationMode::Profile) {
    // When profiling, pass a nullptr for the backend, signaling that all nodes
    // should be lowered. loweredInfoMap logs what is lowered from what for
    // later use when creating quantization infos. Also pass the precision mode
    // kind set as nodes to not lower, specified higher up in the stack.
    ::glow::lower(F, cctx, /* backend */ nullptr,
                  precConfig.precisionModeKindSet);
  } else {
    // Lower based on the backend's preferences.
    ::glow::lower(F, cctx, &B);
  }

  // Transforms the graph by demoting i64 to i32.
  transformIndexTypeDemotion(B, F, cctx);

  // Transform given precision mode; may quantize, convert to fp16, or
  // instrument with profiling nodes. This must be done after lowering.
  transformForPrecisionMode(B, F, cctx);

  // Optimize the quantized graph because quantization nodes should be
  // optimized before folding Activation into Conv.
  ::glow::optimize(F, cctx);

  // Fold activations before lowering to enable cases which would not fuse
  // after lowering. This concerns particularly convolution&relu since relu
  // will be lowered to max(0, x).
  foldActivations(F, cctx, &B);

  // Lower once more, in case precision transform has introduced operators
  // that need to be lowered, e.g., Clip.
  ::glow::lower(F, cctx, &B);

  // Optimize the graph again now that we have a lowered representation.
  ::glow::optimize(F, cctx);

  // If requested fold ElemKind conversion Nodes into static Placeholders,
  // inputs, and outputs (Placeholders and SaveNodes).
  if (cctx.optimizationOpts.foldStaticPlaceholderConversions ||
      cctx.optimizationOpts.foldElemKindConversionIntoIO) {
    std::unique_ptr<FunctionPassPipeline> pipeline =
        glow::make_unique<FunctionPassPipeline>();
    pipeline->pushBack({FunctionPassID::FoldElemKindConversionIntoInputs});

    // Folding into outputs only happens under the IO flag.
    if (cctx.optimizationOpts.foldElemKindConversionIntoIO) {
      pipeline->pushBack({FunctionPassID::FoldElemKindConversionIntoOutputs});
    }
    FunctionPassManager FPM("FoldElemKindConversionIntoIO",
                            std::move(pipeline));
    // Re-optimize only if the folding passes actually changed the graph.
    if (FPM.run(F, cctx)) {
      ::glow::optimize(F, cctx);
    }
  }

  if (B.shouldPreQuantizeConstants()) {
    // Do the actual float ->fix-point conversion of constant tensors before
    // Post-lowering.
    ::glow::convertQuantizedConstants(F, cctx);
  }

  // Allow the backend to transform the graph after lowering.
  RETURN_IF_EXPECTED_IS_ERR(B.transformPostLowering(F, cctx, devInfo));

  if (!B.shouldPreQuantizeConstants()) {
    // Do the actual float ->fix-point conversion of constant tensors after
    // Post-lowering.
    ::glow::convertQuantizedConstants(F, cctx);
  }

  // Optimize the graph again after the backend transformation.
  // In particular, DCE is very likely to be useful.
  ::glow::optimize(F, cctx, B);

  // We already started using backend specific verification when the function
  // state became lowered. Do one more verification pass to make sure
  // everything is in order and to bail if it is not.
  if (cctx.optimizationOpts.delayAndRecordConstantModification ||
      cctx.optimizationOpts.skipBackendSupportCheck) {
    // Only do verification without checking the backend's support if
    // requested, or if we are disallowing constant modification (since this
    // flag may have prevented a backend from supporting some Nodes, and may
    // be supported after constant folding finishes). Expect the caller to
    // verify that Nodes are supported by the backend later on in such cases.
    RETURN_ERR_IF_NOT(F->verify(&B),
                      "Verification after optimization failed for Function " +
                          F->getName().str() + " and Backend " +
                          B.getBackendName());
  } else if (!B.verify(*F, cctx.verboseCompile)) {
    return MAKE_ERR(
        ErrorValue::ErrorCode::COMPILE_UNSUPPORTED_NODE_AFTER_OPTIMIZE,
        "Unsupported node(s) found after optimizing Function " +
            F->getName().str() + " for backend " + B.getBackendName());
  }
  return Error::success();
}
6727 | |
/// Vertically split the weights of every FullyConnectedNode in \p F into
/// \p numOfChunks chunks, replacing each FC with per-chunk FCs whose results
/// are concatenated back together along dimension 1. FCs whose weights'
/// second dimension is smaller than \p minKToSplit are left untouched.
/// \returns true if any FC was rewritten.
bool glow::executeVerticalFCWeightsSplit(Function *F, unsigned numOfChunks,
                                         unsigned minKToSplit) {
  DCHECK(numOfChunks > 0) << "numOfChunks must be a positive number, given: "
                          << numOfChunks;
  DCHECK(minKToSplit > 0) << "minKToSplit must be a positive number, given: "
                          << minKToSplit;

  bool changed = false;
  // NOTE(review): new Slice/FC/Concat nodes are created while iterating; this
  // assumes appending nodes invalidates neither `it` nor `e` — confirm
  // against the nodes list semantics.
  for (auto it = F->getNodes().begin(), e = F->getNodes().end(); it != e;
       ++it) {
    auto *FC = dyn_cast<FullyConnectedNode>(it);
    if (!FC) {
      continue;
    }

    // NOTE(review): per the layout comment below, weights dims are [K;N], so
    // dims()[1] is the output dimension N even though the local is named K —
    // confirm minKToSplit is intended to bound this dimension.
    size_t K = FC->getWeights().dims()[1];
    if (K < minKToSplit) {
      continue;
    }

    auto input = FC->getInput();
    auto weights = FC->getWeights();
    auto bias = FC->getBias();

    // Per-chunk element count along the split dimension (ceiling division).
    dim_t elemPerChunk = (bias.dims()[0] + numOfChunks - 1) / numOfChunks;
    dim_t sliceStart = 0;
    std::vector<NodeValue> fcs(numOfChunks);

    // Split weights across second dimension into numOfChunks pieces.
    // Input dimension is [M;K] and kept untouched.
    // Bias dimension is [N], split into chunks.
    // Weight dimension is [K;N], split into numOfChunks chunks,
    // [K;N/numOfChunks] each.
    // Last chunk might require special handling in case
    // N is not divisible by numOfChunks.
    auto *fcType = F->getParent()->uniqueTypeWithNewShape(
        FC->getResult().getType(), {FC->getResult().dims()[0], elemPerChunk});

    for (unsigned i = 0; i < numOfChunks; ++i) {
      // Last chunk might need special handling if bias dimension
      // is not divisible by numOfChunks.
      if (i == numOfChunks - 1 && bias.dims()[0] % numOfChunks != 0) {
        elemPerChunk = bias.dims()[0] - (numOfChunks - 1) * elemPerChunk;
        fcType = F->getParent()->uniqueTypeWithNewShape(
            FC->getResult().getType(),
            {FC->getResult().dims()[0], elemPerChunk});
      }

      // Slice this chunk's columns of the weights and elements of the bias,
      // then create the per-chunk FC.
      auto *weightSlice = F->createSlice(
          "weight_slice." + weights.getNode()->getName().str(), weights,
          {0, sliceStart}, {weights.dims()[0], sliceStart + elemPerChunk});
      auto *biasSlice =
          F->createSlice("bias_slice." + bias.getNode()->getName().str(), bias,
                         {sliceStart}, {sliceStart + elemPerChunk});
      fcs[i] = F->createFullyConnected("fc_slice." + FC->getName().str(), input,
                                       weightSlice->getResult(),
                                       biasSlice->getResult(), fcType);
      sliceStart += elemPerChunk;
    }

    // Stitch the partial results back together and replace the original FC.
    auto *concat =
        F->createConcat("concat." + FC->getName().str(), fcs, /*dimension*/ 1);
    FC->getResult().replaceAllUsesOfWith(concat);
    changed = true;
  }

  return changed;
}
6796 | |
/// Variant of parallelizeAndReplaceNode() that additionally wraps every chunk
/// in a Clip before the final Concat. Splits input \p inputBatchIdx of
/// \p curNode along \p splitDims[\p inputBatchIdx] into \p numOfChunksNode
/// equal chunks, and result \p resultIdx along \p resultDim. \returns the
/// replacing ConcatNode, or nullptr when the input/output batch sizes are not
/// divisible by \p numOfChunksNode. Only \p modelParallelSplitAlignment == 1
/// is supported.
static Expected<ConcatNode *> parallelizeAndReplaceReshapeNode(
    Function *F, Node *curNode, dim_t numOfChunksNode, dim_t inputBatchIdx,
    dim_t resultIdx, llvm::ArrayRef<int> splitDims, size_t resultDim,
    dim_t modelParallelSplitAlignment = 1) {
  const int inputIdx = splitDims[inputBatchIdx];
  RETURN_ERR_IF_NOT(inputIdx >= 0, "Input batch idx must be split");
  RETURN_ERR_IF_NOT(modelParallelSplitAlignment == 1,
                    "modelParallelSplitAlignment must be 1");
  const dim_t batchSize = curNode->getNthInput(inputBatchIdx).dims()[inputIdx];
  // We can only apply this parallelization when the input/output batch sizes
  // can be divided by numOfChunksNode.
  if ((curNode->getNthResult(0).dims()[0] % numOfChunksNode != 0) ||
      (batchSize % numOfChunksNode != 0)) {
    return nullptr;
  }
  const dim_t elemPerChunk = batchSize / numOfChunksNode;

  RETURN_ERR_IF_NOT(
      batchSize >= numOfChunksNode,
      strFormat("Invalid parallelization; batchSize %lu must be "
                ">= numOfChunksNode %lu for node %s with kind %s",
                (unsigned long)batchSize, (unsigned long)numOfChunksNode,
                curNode->getName().str().c_str(), curNode->getKindName()));

  std::vector<NodeValue> newNodes(numOfChunksNode);
  for (dim_t i = 0; i < numOfChunksNode; ++i) {
    // Calculate the out type of this chunk.
    const dim_t sliceStart = i * elemPerChunk;
    const dim_t sliceEnd =
        (i < numOfChunksNode - 1) ? sliceStart + elemPerChunk : batchSize;
    VLOG(1) << "\tChunk " << i << ": start: " << sliceStart
            << " end: " << sliceEnd << "\n";
    auto outDims = curNode->dims(resultIdx).vec();
    VLOG(1) << "original out dims: [" << folly::join(", ", outDims) << "]";
    RETURN_ERR_IF_NOT(resultDim < outDims.size(),
                      "outDims access out of range");
    outDims[resultDim] /= numOfChunksNode;
    VLOG(1) << "modified out dims: [" << folly::join(", ", outDims) << "]";

    // Clone the original Node, so that it keeps all of the inputs/members of
    // the original Node. Then modify the output type so that its new shape is
    // correct, and below change the inputs to the sliced inputs.
    Node *clone = curNode->clone();
    clone->getNthResult(resultIdx).setTypeUnsafe(
        F->getParent()->uniqueTypeWithNewShape(curNode->getType(resultIdx),
                                               outDims));
    F->addNode(clone);

    // Loop over all of the inputs and slice those inputs that need to be
    // sliced, and set them on the clone.
    for (int j = 0, e = curNode->getNumInputs(); j < e; j++) {
      int dim = splitDims[j];
      if (dim == -1) {
        continue;
      }

      NodeValue currInput = curNode->getNthInput(j);
      auto sliceDimsStart = std::vector<dim_t>(currInput.dims().size(), 0);
      RETURN_ERR_IF_NOT(dim < sliceDimsStart.size(),
                        "sliceDimsStart access out of range");
      sliceDimsStart[dim] = sliceStart;
      auto sliceDimsEnd = currInput.dims().vec();
      RETURN_ERR_IF_NOT(dim < sliceDimsEnd.size(),
                        "sliceDimsEnd access out of range");
      sliceDimsEnd[dim] = sliceEnd;
      VLOG(1) << "start: [" << folly::join(", ", sliceDimsStart) << "]";
      VLOG(1) << "end: [" << folly::join(", ", sliceDimsEnd) << "]";
      VLOG(1) << "Input name: " << currInput.getNode()->getName().str() << "\n";

      auto *inputSlice =
          F->createSlice("dp_slice." + currInput.getNode()->getName().str() +
                             "." + std::to_string(i),
                         currInput, sliceDimsStart, sliceDimsEnd);
      clone->setNthInput(j, inputSlice);

      // NOTE(review): assigned inside the input loop; relies on at least one
      // input being split, which the inputIdx >= 0 check above guarantees.
      newNodes[i] = NodeValue(clone, resultIdx);
    }
  }

  std::vector<NodeValue> newNodesClip(numOfChunksNode);
  int cnt = 0;
  // Add extra Clip ops.
  // TODO: need to remove the Clip ops once the FC inputs can be put on SRAM
  for (auto &node : newNodes) {
    ClipNode *newCN = F->createClip(node.getNode()->getName().str() + "_clip_",
                                    node.getNode(), kMinFP16, kMaxFP16);
    newNodesClip[cnt] = newCN;
    cnt++;
  }

  // Now that we have split the node into many, concat all of the pieces back
  // together and replace the original by the concat.
  VLOG(1) << "Creating Concat";
  auto *concat = F->createConcat("concat." + curNode->getName().str(),
                                 newNodesClip, resultDim);
  curNode->getNthResult(resultIdx).replaceAllUsesOfWith(concat);
  return concat;
}
6895 | |
6896 | /// Helper to parallelize a node \p curNode from \p F into \p numOfChunksNode |
6897 | /// Nodes by slicing its inputs, creating clones of it and changing the inputs |
6898 | /// of the clones to the slices, and then concatenating all of the clones |
6899 | /// together and replacing \p curNode with the concat. \p inputBatchIdx is the |
6900 | /// input idx from \p curNode that will be split (there may be more than one |
6901 | /// input to split, but their splitDim should all have the same size). |
/// \p splitDims represents what dimension to split for each of the inputs to
6903 | /// \p curNode. \p resultDim is the dimension on which we are splitting and then |
6904 | /// concatenating the results. \p resultIdx represents the result index from |
6905 | /// \p curNode that is being split and later concatenated. The size of the |
6906 | /// splits will be increased to a multiple of \p modelParallelSplitAlignment, if |
6907 | /// possible. If the result after aligning the splits is that the new aligned |
6908 | /// splits are larger than the original requested num splits, then the number of |
6909 | /// resulting splits may be less than requested. \returns an Expected of the |
6910 | /// ConcatNode that is created and replaces \p curNode, or otherwise an Error if |
6911 | /// parallelization had some issue. |
static Expected<ConcatNode *>
parallelizeAndReplaceNode(Function *F, Node *curNode, dim_t numOfChunksNode,
                          dim_t inputBatchIdx, dim_t resultIdx,
                          llvm::ArrayRef<int> splitDims, size_t resultDim,
                          dim_t modelParallelSplitAlignment = 1) {
  const int inputIdx = splitDims[inputBatchIdx];
  CHECK_GE(inputIdx, 0) << "Input batch idx must be split";
  const dim_t batchSize = curNode->getNthInput(inputBatchIdx).dims()[inputIdx];
  const dim_t elemPerChunk = batchSize / numOfChunksNode;
  const dim_t remain = batchSize % numOfChunksNode;
  // This alignment will create aligned splits. So for example, if we're
  // splitting 190 by 3, then without alignment it would be {64, 63, 63}.
  // With alignment of 64, it will be {64, 64, 62}
  const dim_t alignedElemPerChunk =
      ((elemPerChunk + (modelParallelSplitAlignment - 1)) /
       modelParallelSplitAlignment) *
      modelParallelSplitAlignment;

  RETURN_ERR_IF_NOT(
      batchSize >= numOfChunksNode,
      "Invalid parallelization; batchSize " + std::to_string(batchSize) +
          " must be >= numOfChunksNode " + std::to_string(numOfChunksNode) +
          " for node " + curNode->getName().str() + " with kind " +
          curNode->getKindName());

  // Potentially modify numOfChunksNode, if the aligned size times current
  // numOfChunksNode exceeds the total size.
  if (modelParallelSplitAlignment > 1) {
    numOfChunksNode =
        (batchSize + alignedElemPerChunk - 1) / alignedElemPerChunk;
  }

  std::vector<NodeValue> newNodes(numOfChunksNode);
  for (dim_t i = 0; i < numOfChunksNode; ++i) {
    // Calculate the out type of this chunk.
    dim_t sliceStart, sliceEnd;

    if (modelParallelSplitAlignment > 1) {
      // If we are using aligned splits, then slice by multiples of the
      // alignment, leaving the rest to the last split. The last split is
      // necessarily smaller than the other splits.
      sliceStart = i * alignedElemPerChunk;
      sliceEnd = (i < numOfChunksNode - 1) ? sliceStart + alignedElemPerChunk
                                           : batchSize;
    } else {
      // Otherwise, distribute elements evenly across the splits and sprinkle
      // the remainder evenly as well.
      sliceStart = i * elemPerChunk + std::min(i, remain);
      sliceEnd = sliceStart + elemPerChunk + ((i < remain) ? 1 : 0);
    }
    VLOG(1) << "\tChunk " << i << ": start: " << sliceStart
            << " end: " << sliceEnd << "\n";
    auto outDims = curNode->dims(resultIdx).vec();
    outDims[resultDim] = (sliceEnd - sliceStart);
    for (auto outDim : outDims) {
      VLOG(1) << "outDim: " << outDim << "\n";
    }

    // Clone the original Node, so that it keeps all of the inputs/members of
    // the original Node. Then modify the output type so that its new shape is
    // correct, and below change the inputs to the sliced inputs.
    Node *clone = curNode->clone();
    clone->getNthResult(resultIdx).setTypeUnsafe(
        F->getParent()->uniqueTypeWithNewShape(curNode->getType(resultIdx),
                                               outDims));
    F->addNode(clone);

    // Loop over all of the inputs and slice those inputs that need to be
    // sliced, and set them on the clone.
    for (int j = 0, e = curNode->getNumInputs(); j < e; j++) {
      int dim = splitDims[j];
      if (dim == -1) {
        continue;
      }

      NodeValue currInput = curNode->getNthInput(j);
      auto sliceDimsStart = std::vector<dim_t>(currInput.dims().size(), 0);
      sliceDimsStart[dim] = sliceStart;
      auto sliceDimsEnd = currInput.dims().vec();
      sliceDimsEnd[dim] = sliceEnd;
      VLOG(1) << "start: ";
      for (auto sliceDimStart : sliceDimsStart) {
        VLOG(1) << sliceDimStart << "\n";
      }
      VLOG(1) << "end: ";
      for (auto sliceDimEnd : sliceDimsEnd) {
        VLOG(1) << sliceDimEnd << "\n";
      }
      VLOG(1) << "Input name: " << currInput.getNode()->getName().str() << "\n";

      auto *inputSlice =
          F->createSlice("dp_slice." + currInput.getNode()->getName().str() +
                             "." + std::to_string(i),
                         currInput, sliceDimsStart, sliceDimsEnd);
      clone->setNthInput(j, inputSlice);

      // NOTE(review): assigned inside the input loop; relies on at least one
      // input being split, which the CHECK_GE above guarantees.
      newNodes[i] = NodeValue(clone, resultIdx);
    }
  }

  // Now that we have split the node into many, concat all of the pieces back
  // together and replace the original by the concat.
  VLOG(1) << "Creating Concat";
  auto *concat = F->createConcat("concat." + curNode->getName().str(), newNodes,
                                 resultDim);
  curNode->getNthResult(resultIdx).replaceAllUsesOfWith(concat);
  return concat;
}
7020 | |
7021 | /// Specialized helper for parallelizing ConcatNode \p CN from \p F into |
7022 | /// \p numOfChunks. |
7023 | static Expected<ConcatNode *> |
7024 | parallelizeAndReplaceConcat(Function *F, ConcatNode *CN, dim_t numOfChunks) { |
7025 | auto in = CN->getInputs(); |
7026 | |
7027 | const dim_t inputSize = in.size(); |
7028 | const dim_t elemPerChunk = inputSize / numOfChunks; |
7029 | const dim_t remain = inputSize % numOfChunks; |
7030 | |
7031 | RETURN_ERR_IF_NOT( |
7032 | elemPerChunk > 0, |
7033 | "When parallelizing a Concat, inputSize must be larger than numOfChunks" ); |
7034 | |
7035 | auto startIt = in.begin(); |
7036 | std::vector<NodeValue> finalConcatInputs; |
7037 | for (dim_t i = 0; i < numOfChunks; ++i) { |
7038 | const dim_t sliceStart = i * elemPerChunk + std::min(i, remain); |
7039 | const dim_t sliceEnd = sliceStart + elemPerChunk + ((i < remain) ? 1 : 0); |
7040 | |
7041 | // Slice out the original Concat chunk's inputs, create a new Concat with |
7042 | // the slice, and then add the new Concat to the final Concat's inputs. |
7043 | std::vector<NodeValue> newInputs(startIt + sliceStart, startIt + sliceEnd); |
7044 | ConcatNode *concatSlice = F->createConcat(CN->getName().str() + "_slice" , |
7045 | newInputs, CN->getDim()); |
7046 | finalConcatInputs.push_back(concatSlice); |
7047 | } |
7048 | ConcatNode *finalConcat = F->createConcat(CN->getName().str() + "_merge" , |
7049 | finalConcatInputs, CN->getDim()); |
7050 | CN->getResult().replaceAllUsesOfWith(finalConcat); |
7051 | return finalConcat; |
7052 | } |
7053 | |
/// Breaks out of the enclosing switch when the input of the \p NodeKind node
/// bound to `curNode` has rank <= \p Axis; otherwise marks the input for
/// splitting along \p Axis in `splitDims` and parallelizes `curNode` into
/// `curNumOfChunks` chunks via parallelizeAndReplaceNode(), assigning the
/// replacing Concat to `CN`. Expects `F`, `curNode`, `curNumOfChunks`,
/// `splitDims`, `CN` and `modelParallelSplitAlignment` to be in scope at the
/// expansion site.
#define SPLIT_ELTWISE_UNARY_OP_HELPER(NodeKind, Axis)                          \
  if (curNode->getNthInput(NodeKind##Node::InputIdx).dims().size() <= Axis) {  \
    break;                                                                     \
  }                                                                            \
  splitDims[NodeKind##Node::InputIdx] = Axis;                                  \
  ASSIGN_VALUE_OR_RETURN_ERR(                                                  \
      CN, parallelizeAndReplaceNode(                                           \
              F, curNode, curNumOfChunks, NodeKind##Node::InputIdx,            \
              NodeKind##Node::ResultIdx, splitDims, /*resultDim*/ Axis,        \
              modelParallelSplitAlignment));

/// Same as SPLIT_ELTWISE_UNARY_OP_HELPER, but for binary eltwise nodes: both
/// the LHS and RHS inputs are marked for splitting along \p Axis (the rank
/// guard checks the LHS only).
#define SPLIT_ELTWISE_BINARY_OP_HELPER(NodeKind, Axis)                         \
  if (curNode->getNthInput(NodeKind##Node::LHSIdx).dims().size() <= Axis) {    \
    break;                                                                     \
  }                                                                            \
  splitDims[NodeKind##Node::LHSIdx] = Axis;                                    \
  splitDims[NodeKind##Node::RHSIdx] = Axis;                                    \
  ASSIGN_VALUE_OR_RETURN_ERR(                                                  \
      CN, parallelizeAndReplaceNode(                                           \
              F, curNode, curNumOfChunks, NodeKind##Node::LHSIdx,              \
              NodeKind##Node::ResultIdx, splitDims, /*resultDim*/ Axis,        \
              modelParallelSplitAlignment));
7076 | |
7077 | Expected<std::unordered_map<Node *, ConcatNode *>> glow::parallelizeOps( |
7078 | Function *F, const llvm::DenseMap<Node *, size_t> &numOfChunksMap, |
7079 | const llvm::DenseMap<Node *, ParallelTransformKind> &parOpts, |
7080 | size_t numOfChunks, size_t modelParallelSplitAlignment) { |
7081 | // Since we will be transforming the original list of nodes, reverse iterate. |
7082 | auto &nodes = F->getNodes(); |
7083 | size_t numProcessedNodes = 0; |
7084 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
7085 | for (auto it = nodes.rbegin(), e = nodes.rend(); it != e; it++) { |
7086 | Node *curNode = &*it; |
7087 | size_t curNumOfChunks = numOfChunks; |
7088 | auto numOfChunksIt = numOfChunksMap.find(curNode); |
7089 | if (numOfChunksIt != numOfChunksMap.end()) { |
7090 | curNumOfChunks = numOfChunksIt->second; |
7091 | } |
7092 | |
7093 | ParallelTransformKind parTransformMode = ParallelTransformKind::None; |
7094 | auto parOptsIt = parOpts.find(curNode); |
7095 | if (parOptsIt != parOpts.end()) { |
7096 | parTransformMode = parOptsIt->second; |
7097 | ++numProcessedNodes; |
7098 | } |
7099 | |
7100 | VLOG(1) << "Attempting to Parallelizing Node: " << curNode->getName().str() |
7101 | << "\n" ; |
7102 | |
7103 | ConcatNode *CN = nullptr; |
7104 | |
7105 | // Use this vector to communicate what dims to split to |
7106 | // parallelizeAndReplaceNode(). -1 represents not splitting at all. |
7107 | llvm::SmallVector<int, 3> splitDims(curNode->getNumInputs(), -1); |
7108 | |
7109 | // Set model parallelization axis |
7110 | // Default model parallelization is along axis = 1, hence the default value. |
7111 | dim_t modelParAxis = 1; |
7112 | switch (parTransformMode) { |
7113 | #define MODEL_AXIS_CASE(_N) \ |
7114 | case ParallelTransformKind::Model_Axis##_N: \ |
7115 | modelParAxis = _N; \ |
7116 | parTransformMode = ParallelTransformKind::Model; \ |
7117 | break; |
7118 | MODEL_AXIS_CASE(1) |
7119 | MODEL_AXIS_CASE(2) |
7120 | MODEL_AXIS_CASE(3) |
7121 | MODEL_AXIS_CASE(4) |
7122 | MODEL_AXIS_CASE(5) |
7123 | default: |
7124 | break; |
7125 | } |
7126 | |
7127 | switch (parTransformMode) { |
7128 | case ParallelTransformKind::Data: { |
7129 | switch (curNode->getKind()) { |
7130 | case Kinded::Kind::FullyConnectedNodeKind: { |
7131 | splitDims[FullyConnectedNode::InputIdx] = 0; |
7132 | ASSIGN_VALUE_OR_RETURN_ERR( |
7133 | CN, parallelizeAndReplaceNode( |
7134 | F, curNode, curNumOfChunks, FullyConnectedNode::InputIdx, |
7135 | FullyConnectedNode::ResultIdx, splitDims, 0)); |
7136 | break; |
7137 | } |
7138 | case Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind: { |
7139 | splitDims[RowwiseQuantizedFullyConnectedNode::InputIdx] = 0; |
7140 | ASSIGN_VALUE_OR_RETURN_ERR( |
7141 | CN, |
7142 | parallelizeAndReplaceNode( |
7143 | F, curNode, curNumOfChunks, |
7144 | RowwiseQuantizedFullyConnectedNode::InputIdx, |
7145 | RowwiseQuantizedFullyConnectedNode::ResultIdx, splitDims, 0)); |
7146 | break; |
7147 | } |
7148 | case Kinded::Kind::MatMulNodeKind: { |
7149 | splitDims[MatMulNode::LHSIdx] = 0; |
7150 | ASSIGN_VALUE_OR_RETURN_ERR( |
7151 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7152 | MatMulNode::LHSIdx, |
7153 | MatMulNode::ResultIdx, splitDims, 0)); |
7154 | break; |
7155 | } |
7156 | case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind: { |
7157 | splitDims[ChannelwiseQuantizedConvolutionNode::InputIdx] = 0; |
7158 | ASSIGN_VALUE_OR_RETURN_ERR( |
7159 | CN, |
7160 | parallelizeAndReplaceNode( |
7161 | F, curNode, curNumOfChunks, |
7162 | ChannelwiseQuantizedConvolutionNode::InputIdx, |
7163 | ChannelwiseQuantizedConvolutionNode::ResultIdx, splitDims, 0)); |
7164 | break; |
7165 | } |
7166 | case Kinded::Kind::ConvolutionNodeKind: { |
7167 | splitDims[ConvolutionNode::InputIdx] = 0; |
7168 | ASSIGN_VALUE_OR_RETURN_ERR( |
7169 | CN, parallelizeAndReplaceNode( |
7170 | F, curNode, curNumOfChunks, ConvolutionNode::InputIdx, |
7171 | ConvolutionNode::ResultIdx, splitDims, 0)); |
7172 | break; |
7173 | } |
7174 | case Kinded::Kind::AdaptiveAvgPoolNodeKind: { |
7175 | splitDims[AdaptiveAvgPoolNode::InputIdx] = 0; |
7176 | ASSIGN_VALUE_OR_RETURN_ERR( |
7177 | CN, parallelizeAndReplaceNode( |
7178 | F, curNode, curNumOfChunks, AdaptiveAvgPoolNode::InputIdx, |
7179 | AdaptiveAvgPoolNode::ResultIdx, splitDims, 0)); |
7180 | break; |
7181 | } |
7182 | case Kinded::Kind::ROIAlignNodeKind: { |
7183 | splitDims[ROIAlignNode::BoxesIdx] = 0; |
7184 | splitDims[ROIAlignNode::BatchIndicesIdx] = 0; |
7185 | ASSIGN_VALUE_OR_RETURN_ERR( |
7186 | CN, parallelizeAndReplaceNode( |
7187 | F, curNode, curNumOfChunks, ROIAlignNode::BoxesIdx, |
7188 | ROIAlignNode::ResultIdx, splitDims, 0)); |
7189 | break; |
7190 | } |
7191 | case Kinded::Kind::MaxPoolNodeKind: { |
7192 | splitDims[MaxPoolNode::InputIdx] = 0; |
7193 | ASSIGN_VALUE_OR_RETURN_ERR( |
7194 | CN, parallelizeAndReplaceNode( |
7195 | F, curNode, curNumOfChunks, MaxPoolNode::InputIdx, |
7196 | MaxPoolNode::ResultIdx, splitDims, 0)); |
7197 | break; |
7198 | } |
7199 | case Kinded::Kind::ReshapeNodeKind: { |
7200 | splitDims[ReshapeNode::InputIdx] = 0; |
7201 | const dim_t batchSize = |
7202 | curNode->getNthInput(ReshapeNode::InputIdx).dims()[0]; |
7203 | if (batchSize != curNode->getNthResult(0).dims()[0]) { |
7204 | if (glow::flags::SparseNNParallelizeReshapeOnBatchDim) { |
7205 | ASSIGN_VALUE_OR_RETURN_ERR( |
7206 | CN, parallelizeAndReplaceReshapeNode( |
7207 | F, curNode, curNumOfChunks, ReshapeNode::InputIdx, |
7208 | ReshapeNode::ResultIdx, splitDims, 0)); |
7209 | } else { |
7210 | // Do nothing if reshape applies to the first batch dimension |
7211 | LOG(INFO) << "Reshape changes batch dimension; Disabling data " |
7212 | "parallel split" ; |
7213 | } |
7214 | break; |
7215 | } |
7216 | ASSIGN_VALUE_OR_RETURN_ERR( |
7217 | CN, parallelizeAndReplaceNode( |
7218 | F, curNode, curNumOfChunks, ReshapeNode::InputIdx, |
7219 | ReshapeNode::ResultIdx, splitDims, 0)); |
7220 | break; |
7221 | } |
7222 | case Kinded::Kind::AddNodeKind: { |
7223 | splitDims[AddNode::LHSIdx] = 0; |
7224 | splitDims[AddNode::RHSIdx] = 0; |
7225 | ASSIGN_VALUE_OR_RETURN_ERR( |
7226 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7227 | AddNode::LHSIdx, AddNode::ResultIdx, |
7228 | splitDims, 0)); |
7229 | break; |
7230 | } |
7231 | case Kinded::Kind::SubNodeKind: { |
7232 | splitDims[SubNode::LHSIdx] = 0; |
7233 | splitDims[SubNode::RHSIdx] = 0; |
7234 | ASSIGN_VALUE_OR_RETURN_ERR( |
7235 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7236 | SubNode::LHSIdx, SubNode::ResultIdx, |
7237 | splitDims, 0)); |
7238 | break; |
7239 | } |
7240 | case Kinded::Kind::BatchMatMulNodeKind: { |
7241 | splitDims[BatchMatMulNode::LHSIdx] = 0; |
7242 | splitDims[BatchMatMulNode::RHSIdx] = 0; |
7243 | ASSIGN_VALUE_OR_RETURN_ERR( |
7244 | CN, parallelizeAndReplaceNode( |
7245 | F, curNode, curNumOfChunks, BatchMatMulNode::LHSIdx, |
7246 | BatchMatMulNode::ResultIdx, splitDims, 0)); |
7247 | break; |
7248 | } |
7249 | case Kinded::Kind::MulNodeKind: { |
7250 | splitDims[AddNode::LHSIdx] = 0; |
7251 | splitDims[AddNode::RHSIdx] = 0; |
7252 | ASSIGN_VALUE_OR_RETURN_ERR( |
7253 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7254 | MulNode::LHSIdx, MulNode::ResultIdx, |
7255 | splitDims, 0)); |
7256 | break; |
7257 | } |
7258 | case Kinded::Kind::PowNodeKind: { |
7259 | splitDims[PowNode::LHSIdx] = 0; |
7260 | splitDims[PowNode::RHSIdx] = 0; |
7261 | ASSIGN_VALUE_OR_RETURN_ERR( |
7262 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7263 | PowNode::LHSIdx, PowNode::ResultIdx, |
7264 | splitDims, 0)); |
7265 | break; |
7266 | } |
7267 | case Kinded::Kind::SelectNodeKind: { |
7268 | splitDims[SelectNode::LHSIdx] = 0; |
7269 | splitDims[SelectNode::RHSIdx] = 0; |
7270 | splitDims[SelectNode::CondIdx] = 0; |
7271 | ASSIGN_VALUE_OR_RETURN_ERR( |
7272 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7273 | SelectNode::LHSIdx, |
7274 | SelectNode::ResultIdx, splitDims, 0)); |
7275 | break; |
7276 | } |
7277 | case Kinded::Kind::ExpNodeKind: { |
7278 | splitDims[ExpNode::InputIdx] = 0; |
7279 | ASSIGN_VALUE_OR_RETURN_ERR( |
7280 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7281 | ExpNode::InputIdx, ExpNode::ResultIdx, |
7282 | splitDims, 0)); |
7283 | break; |
7284 | } |
7285 | case Kinded::Kind::SigmoidNodeKind: { |
7286 | splitDims[SigmoidNode::InputIdx] = 0; |
7287 | ASSIGN_VALUE_OR_RETURN_ERR( |
7288 | CN, parallelizeAndReplaceNode( |
7289 | F, curNode, curNumOfChunks, SigmoidNode::InputIdx, |
7290 | SigmoidNode::ResultIdx, splitDims, 0)); |
7291 | break; |
7292 | } |
7293 | case Kinded::Kind::SoftMaxNodeKind: { |
7294 | splitDims[SoftMaxNode::InputIdx] = 0; |
7295 | ASSIGN_VALUE_OR_RETURN_ERR( |
7296 | CN, parallelizeAndReplaceNode( |
7297 | F, curNode, curNumOfChunks, SoftMaxNode::InputIdx, |
7298 | SoftMaxNode::ResultIdx, splitDims, 0)); |
7299 | break; |
7300 | } |
7301 | case Kinded::Kind::LogSoftMaxNodeKind: { |
7302 | splitDims[LogSoftMaxNode::InputIdx] = 0; |
7303 | ASSIGN_VALUE_OR_RETURN_ERR( |
7304 | CN, parallelizeAndReplaceNode( |
7305 | F, curNode, curNumOfChunks, LogSoftMaxNode::InputIdx, |
7306 | LogSoftMaxNode::ResultIdx, splitDims, 0)); |
7307 | break; |
7308 | } |
7309 | case Kinded::Kind::TanhNodeKind: { |
7310 | splitDims[TanhNode::InputIdx] = 0; |
7311 | ASSIGN_VALUE_OR_RETURN_ERR( |
7312 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7313 | TanhNode::InputIdx, |
7314 | TanhNode::ResultIdx, splitDims, 0)); |
7315 | break; |
7316 | } |
7317 | case Kinded::Kind::SwishNodeKind: { |
7318 | splitDims[SwishNode::InputIdx] = 0; |
7319 | ASSIGN_VALUE_OR_RETURN_ERR( |
7320 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7321 | SwishNode::InputIdx, |
7322 | SwishNode::ResultIdx, splitDims, 0)); |
7323 | break; |
7324 | } |
7325 | case Kinded::Kind::MaxNodeKind: { |
7326 | splitDims[MaxNode::LHSIdx] = 0; |
7327 | splitDims[MaxNode::RHSIdx] = 0; |
7328 | ASSIGN_VALUE_OR_RETURN_ERR( |
7329 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7330 | MaxNode::LHSIdx, MaxNode::ResultIdx, |
7331 | splitDims, 0)); |
7332 | break; |
7333 | } |
7334 | case Kinded::Kind::MinNodeKind: { |
7335 | splitDims[MinNode::LHSIdx] = 0; |
7336 | splitDims[MinNode::RHSIdx] = 0; |
7337 | ASSIGN_VALUE_OR_RETURN_ERR( |
7338 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7339 | MinNode::LHSIdx, MinNode::ResultIdx, |
7340 | splitDims, 0)); |
7341 | break; |
7342 | } |
7343 | case Kinded::Kind::TransposeNodeKind: { |
7344 | auto shuffleVec = cast<TransposeNode>(curNode)->getShuffle(); |
7345 | unsigned_t inputDim = shuffleVec[0]; |
7346 | splitDims[TransposeNode::InputIdx] = inputDim; |
7347 | ASSIGN_VALUE_OR_RETURN_ERR( |
7348 | CN, parallelizeAndReplaceNode( |
7349 | F, curNode, curNumOfChunks, TransposeNode::InputIdx, |
7350 | TransposeNode::ResultIdx, splitDims, 0)); |
7351 | break; |
7352 | } |
7353 | case Kinded::Kind::ReluNodeKind: { |
7354 | splitDims[ReluNode::InputIdx] = 0; |
7355 | ASSIGN_VALUE_OR_RETURN_ERR( |
7356 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7357 | ReluNode::InputIdx, |
7358 | ReluNode::ResultIdx, splitDims, 0)); |
7359 | break; |
7360 | } |
7361 | case Kinded::Kind::GeluNodeKind: { |
7362 | splitDims[GeluNode::InputIdx] = 0; |
7363 | ASSIGN_VALUE_OR_RETURN_ERR( |
7364 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7365 | GeluNode::InputIdx, |
7366 | GeluNode::ResultIdx, splitDims, 0)); |
7367 | break; |
7368 | } |
7369 | case Kinded::Kind::ClipNodeKind: { |
7370 | splitDims[ClipNode::InputIdx] = 0; |
7371 | ASSIGN_VALUE_OR_RETURN_ERR( |
7372 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7373 | ClipNode::InputIdx, |
7374 | ClipNode::ResultIdx, splitDims, 0)); |
7375 | break; |
7376 | } |
7377 | case Kinded::Kind::TileNodeKind: { |
7378 | TileNode *TN = llvm::dyn_cast<TileNode>(curNode); |
7379 | RETURN_ERR_IF_NOT( |
7380 | TN->getAxis() != 0, |
7381 | "Tile node cannot be split on axis 0 which is being replicated" ); |
7382 | splitDims[TileNode::InputIdx] = 0; |
7383 | ASSIGN_VALUE_OR_RETURN_ERR( |
7384 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7385 | TileNode::InputIdx, |
7386 | TileNode::ResultIdx, splitDims, 0)); |
7387 | break; |
7388 | } |
7389 | case Kinded::Kind::BatchedReduceAddNodeKind: { |
7390 | BatchedReduceAddNode *BR = llvm::cast<BatchedReduceAddNode>(curNode); |
7391 | splitDims[BatchedReduceAddNode::BatchIdx] = |
7392 | (BR->getAxis() == 0) ? 1 : 0; |
7393 | ASSIGN_VALUE_OR_RETURN_ERR( |
7394 | CN, parallelizeAndReplaceNode( |
7395 | F, curNode, curNumOfChunks, BatchedReduceAddNode::BatchIdx, |
7396 | BatchedReduceAddNode::ResultIdx, splitDims, 0)); |
7397 | break; |
7398 | } |
7399 | case Kinded::Kind::BatchedReduceMeanNodeKind: { |
7400 | auto *BR = llvm::cast<BatchedReduceMeanNode>(curNode); |
7401 | const auto &BRaxes = BR->getAxes(); |
7402 | if (std::find(BRaxes.begin(), BRaxes.end(), 0) != BRaxes.end()) { |
7403 | LOG(INFO) << "BatchedReduceMean along the first dimension not " |
7404 | "parallelized. Current node: " |
7405 | << BR->getDebugDesc(); |
7406 | } else { |
7407 | splitDims[BatchedReduceMeanNode::BatchIdx] = 0; |
7408 | ASSIGN_VALUE_OR_RETURN_ERR( |
7409 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7410 | BatchedReduceMeanNode::BatchIdx, |
7411 | BatchedReduceMeanNode::ResultIdx, |
7412 | splitDims, 0)); |
7413 | } |
7414 | break; |
7415 | } |
7416 | case Kinded::Kind::ConcatNodeKind: { |
7417 | ConcatNode *concat = llvm::cast<ConcatNode>(curNode); |
7418 | RETURN_ERR_IF_NOT(concat->getDim() == 0, |
7419 | "Expected to Data parallelize for concat on dim 0" ); |
7420 | ASSIGN_VALUE_OR_RETURN_ERR( |
7421 | CN, parallelizeAndReplaceConcat(F, concat, curNumOfChunks)); |
7422 | break; |
7423 | } |
7424 | case Kinded::Kind::LayerNormalizationNodeKind: { |
7425 | splitDims[LayerNormalizationNode::InputIdx] = 0; |
7426 | ASSIGN_VALUE_OR_RETURN_ERR( |
7427 | CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks, |
7428 | LayerNormalizationNode::InputIdx, |
7429 | LayerNormalizationNode::ResultIdx, |
7430 | splitDims, 0)); |
7431 | break; |
7432 | } |
7433 | case Kinded::Kind::QuantizeNodeKind: { |
7434 | splitDims[QuantizeNode::InputIdx] = 0; |
7435 | ASSIGN_VALUE_OR_RETURN_ERR( |
7436 | CN, parallelizeAndReplaceNode( |
7437 | F, curNode, curNumOfChunks, QuantizeNode::InputIdx, |
7438 | QuantizeNode::ResultIdx, splitDims, 0)); |
7439 | break; |
7440 | } |
7441 | case Kinded::Kind::DequantizeNodeKind: { |
7442 | splitDims[DequantizeNode::InputIdx] = 0; |
7443 | ASSIGN_VALUE_OR_RETURN_ERR( |
7444 | CN, parallelizeAndReplaceNode( |
7445 | F, curNode, curNumOfChunks, DequantizeNode::InputIdx, |
7446 | DequantizeNode::ResultIdx, splitDims, 0)); |
7447 | break; |
7448 | } |
7449 | case Kinded::Kind::RescaleQuantizedNodeKind: { |
7450 | splitDims[RescaleQuantizedNode::InputIdx] = 0; |
7451 | ASSIGN_VALUE_OR_RETURN_ERR( |
7452 | CN, parallelizeAndReplaceNode( |
7453 | F, curNode, curNumOfChunks, RescaleQuantizedNode::InputIdx, |
7454 | RescaleQuantizedNode::ResultIdx, splitDims, 0)); |
7455 | break; |
7456 | } |
7457 | case Kinded::Kind::ConvertToNodeKind: { |
7458 | splitDims[ConvertToNode::InputIdx] = 0; |
7459 | ASSIGN_VALUE_OR_RETURN_ERR( |
7460 | CN, parallelizeAndReplaceNode( |
7461 | F, curNode, curNumOfChunks, ConvertToNode::InputIdx, |
7462 | ConvertToNode::ResultIdx, splitDims, 0)); |
7463 | break; |
7464 | } |
7465 | default: |
7466 | VLOG(1) << "Attempted to parallelize op type " << curNode->getKindName() |
7467 | << "not yet supported" |
7468 | << "\n" ; |
7469 | break; |
7470 | } |
7471 | break; |
7472 | } |
7473 | |
7474 | case ParallelTransformKind::Model: { |
7475 | switch (curNode->getKind()) { |
7476 | case Kinded::Kind::FullyConnectedNodeKind: { |
7477 | if (modelParAxis != 1) { |
7478 | break; |
7479 | } |
7480 | splitDims[FullyConnectedNode::WeightsIdx] = 1; |
7481 | splitDims[FullyConnectedNode::BiasIdx] = 0; |
7482 | ASSIGN_VALUE_OR_RETURN_ERR( |
7483 | CN, parallelizeAndReplaceNode( |
7484 | F, curNode, curNumOfChunks, FullyConnectedNode::WeightsIdx, |
7485 | FullyConnectedNode::ResultIdx, splitDims, /*resultDim*/ 1, |
7486 | modelParallelSplitAlignment)); |
7487 | break; |
7488 | } |
7489 | case Kinded::Kind::MatMulNodeKind: { |
7490 | if (modelParAxis != 1) { |
7491 | break; |
7492 | } |
7493 | splitDims[MatMulNode::RHSIdx] = 1; |
7494 | ASSIGN_VALUE_OR_RETURN_ERR( |
7495 | CN, parallelizeAndReplaceNode( |
7496 | F, curNode, curNumOfChunks, MatMulNode::RHSIdx, |
7497 | MatMulNode::ResultIdx, splitDims, /*resultDim*/ 1, |
7498 | modelParallelSplitAlignment)); |
7499 | break; |
7500 | } |
7501 | case Kinded::Kind::ReluNodeKind: { |
7502 | SPLIT_ELTWISE_UNARY_OP_HELPER(Relu, modelParAxis); |
7503 | break; |
7504 | } |
7505 | case Kinded::Kind::ClipNodeKind: { |
7506 | SPLIT_ELTWISE_UNARY_OP_HELPER(Clip, modelParAxis); |
7507 | break; |
7508 | } |
7509 | case Kinded::Kind::SelectNodeKind: { |
7510 | if (modelParAxis != 1) { |
7511 | break; |
7512 | } |
7513 | auto *SL = llvm::cast<SelectNode>(curNode); |
7514 | if (SL->getNthInput(SelectNode::LHSIdx).dims().size() < 2) { |
7515 | break; |
7516 | } |
7517 | splitDims[SelectNode::LHSIdx] = 1; |
7518 | splitDims[SelectNode::RHSIdx] = 1; |
7519 | splitDims[SelectNode::CondIdx] = 1; |
7520 | ASSIGN_VALUE_OR_RETURN_ERR( |
7521 | CN, parallelizeAndReplaceNode( |
7522 | F, curNode, curNumOfChunks, SelectNode::LHSIdx, |
7523 | SelectNode::ResultIdx, splitDims, |
7524 | /*resultDim*/ 1, modelParallelSplitAlignment)); |
7525 | break; |
7526 | } |
7527 | case Kinded::Kind::AddNodeKind: { |
7528 | SPLIT_ELTWISE_BINARY_OP_HELPER(Add, modelParAxis); |
7529 | break; |
7530 | } |
7531 | case Kinded::Kind::TileNodeKind: { |
7532 | if (modelParAxis != 1) { |
7533 | break; |
7534 | } |
7535 | if (curNode->getNthInput(TileNode::InputIdx).dims().size() < 2) { |
7536 | break; |
7537 | } |
7538 | TileNode *TN = llvm::dyn_cast<TileNode>(curNode); |
7539 | RETURN_ERR_IF_NOT( |
7540 | TN->getAxis() != 1, |
7541 | "Tile node cannot be split on axis 1 which is being replicated" ); |
7542 | splitDims[TileNode::InputIdx] = 1; |
7543 | ASSIGN_VALUE_OR_RETURN_ERR( |
7544 | CN, parallelizeAndReplaceNode( |
7545 | F, curNode, curNumOfChunks, TileNode::InputIdx, |
7546 | TileNode::ResultIdx, splitDims, |
7547 | /*resultDim*/ 1, modelParallelSplitAlignment)); |
7548 | break; |
7549 | } |
7550 | case Kinded::Kind::ConcatNodeKind: { |
7551 | if (modelParAxis != 1) { |
7552 | break; |
7553 | } |
7554 | ConcatNode *concat = llvm::cast<ConcatNode>(curNode); |
7555 | RETURN_ERR_IF_NOT(concat->getDim() == 1, |
7556 | "Expected to Model parallelize for concat on dim 1" ); |
7557 | ASSIGN_VALUE_OR_RETURN_ERR( |
7558 | CN, parallelizeAndReplaceConcat(F, concat, curNumOfChunks)); |
7559 | break; |
7560 | } |
7561 | case Kinded::Kind::QuantizeNodeKind: { |
7562 | SPLIT_ELTWISE_UNARY_OP_HELPER(Quantize, modelParAxis); |
7563 | break; |
7564 | } |
7565 | case Kinded::Kind::DequantizeNodeKind: { |
7566 | SPLIT_ELTWISE_UNARY_OP_HELPER(Dequantize, modelParAxis); |
7567 | break; |
7568 | } |
7569 | case Kinded::Kind::BatchedReduceMeanNodeKind: { |
7570 | if (curNode->getNthInput(BatchedReduceMeanNode::BatchIdx) |
7571 | .dims() |
7572 | .size() <= modelParAxis) { |
7573 | break; |
7574 | } |
7575 | splitDims[BatchedReduceMeanNode::BatchIdx] = modelParAxis; |
7576 | ASSIGN_VALUE_OR_RETURN_ERR( |
7577 | CN, parallelizeAndReplaceNode( |
7578 | F, curNode, curNumOfChunks, BatchedReduceMeanNode::BatchIdx, |
7579 | BatchedReduceMeanNode::ResultIdx, splitDims, |
7580 | /*resultDim*/ modelParAxis, modelParallelSplitAlignment)); |
7581 | break; |
7582 | } |
7583 | case Kinded::Kind::RescaleQuantizedNodeKind: { |
7584 | SPLIT_ELTWISE_UNARY_OP_HELPER(RescaleQuantized, modelParAxis); |
7585 | break; |
7586 | } |
7587 | case Kinded::Kind::BatchNormalizationNodeKind: { |
7588 | auto *BN = llvm::cast<BatchNormalizationNode>(curNode); |
7589 | |
7590 | if (modelParAxis != BN->getChannelIdx()) { |
7591 | break; |
7592 | } |
7593 | if (BN->getInput().dims().size() <= modelParAxis) { |
7594 | break; |
7595 | } |
7596 | splitDims[BatchNormalizationNode::InputIdx] = modelParAxis; |
7597 | splitDims[BatchNormalizationNode::ScaleIdx] = 0; |
7598 | splitDims[BatchNormalizationNode::BiasIdx] = 0; |
7599 | splitDims[BatchNormalizationNode::MeanIdx] = 0; |
7600 | splitDims[BatchNormalizationNode::VarIdx] = 0; |
7601 | |
7602 | ASSIGN_VALUE_OR_RETURN_ERR( |
7603 | CN, |
7604 | parallelizeAndReplaceNode( |
7605 | F, curNode, curNumOfChunks, BatchNormalizationNode::InputIdx, |
7606 | BatchNormalizationNode::ResultIdx, splitDims, |
7607 | /*resultDim*/ modelParAxis, modelParallelSplitAlignment)); |
7608 | break; |
7609 | } |
7610 | case Kinded::Kind::ResizeNearestNodeKind: { |
7611 | SPLIT_ELTWISE_UNARY_OP_HELPER(ResizeNearest, modelParAxis); |
7612 | break; |
7613 | } |
7614 | case Kinded::Kind::ConvolutionNodeKind: { |
7615 | if (modelParAxis != 3) { |
7616 | break; |
7617 | } |
7618 | splitDims[ConvolutionNode::FilterIdx] = 0; |
7619 | splitDims[ConvolutionNode::BiasIdx] = 0; |
7620 | ASSIGN_VALUE_OR_RETURN_ERR( |
7621 | CN, parallelizeAndReplaceNode( |
7622 | F, curNode, curNumOfChunks, ConvolutionNode::FilterIdx, |
7623 | ConvolutionNode::ResultIdx, splitDims, |
7624 | /*resultDim*/ 3, modelParallelSplitAlignment)); |
7625 | break; |
7626 | } |
7627 | case Kinded::Kind::Convolution3DNodeKind: { |
7628 | if (modelParAxis != 4) { |
7629 | break; |
7630 | } |
7631 | splitDims[Convolution3DNode::FilterIdx] = 0; |
7632 | splitDims[Convolution3DNode::BiasIdx] = 0; |
7633 | ASSIGN_VALUE_OR_RETURN_ERR( |
7634 | CN, parallelizeAndReplaceNode( |
7635 | F, curNode, curNumOfChunks, Convolution3DNode::FilterIdx, |
7636 | Convolution3DNode::ResultIdx, splitDims, |
7637 | /*resultDim*/ 4, modelParallelSplitAlignment)); |
7638 | break; |
7639 | } |
7640 | case Kinded::Kind::AvgPoolNodeKind: { |
7641 | auto *APN = llvm::cast<AvgPoolNode>(curNode); |
7642 | if (APN->getLayout() != 2) { |
7643 | break; |
7644 | } |
7645 | if (modelParAxis != 4) { |
7646 | break; |
7647 | } |
7648 | SPLIT_ELTWISE_UNARY_OP_HELPER(AvgPool, 4); |
7649 | break; |
7650 | } |
7651 | default: |
7652 | VLOG(1) << "Attempted to parallelize op type " << curNode->getKindName() |
7653 | << "not yet supported" |
7654 | << "\n" ; |
7655 | break; |
7656 | } |
7657 | break; |
7658 | } |
7659 | |
7660 | default: |
7661 | break; |
7662 | } |
7663 | |
7664 | if (CN) { |
7665 | replacedMap[curNode] = CN; |
7666 | } |
7667 | } |
7668 | |
7669 | // Because we transformed Node types unsafely, make sure all types of the |
7670 | // Function still are valid. |
7671 | RETURN_ERR_IF_NOT(F->verify(), "Verification issue post parallelization" ); |
7672 | |
7673 | RETURN_ERR_IF_NOT(numProcessedNodes == parOpts.size(), |
7674 | "Not all Nodes specified in parOpts were processed." ); |
7675 | |
7676 | return replacedMap; |
7677 | } |
7678 | |
7679 | void glow::updateQuantReluTypes(Function *F) { |
7680 | // A worklist that contains the nodes to process. |
7681 | std::vector<Node *> worklist; |
7682 | auto needsQuantTyUpdate = [](const Node *N) { |
7683 | return isa<ConcatNode>(N) || isa<SliceNode>(N) || isa<ReshapeNode>(N) || |
7684 | isa<TileNode>(N) || isa<BroadcastNode>(N) || isa<TransposeNode>(N); |
7685 | }; |
7686 | |
7687 | for (Node &N : F->getNodes()) { |
7688 | // Look for quantized Relus that have negative min, and update their min to |
7689 | // be zero. |
7690 | auto *RN = llvm::dyn_cast<ReluNode>(&N); |
7691 | if (!RN || !RN->getResult().getType()->isQuantizedType() || |
7692 | isUsedByNodeWithSideEffects(RN)) { |
7693 | continue; |
7694 | } |
7695 | const TypeRef RNTy = RN->getResult().getType(); |
7696 | |
7697 | const auto qRange = RNTy->getQuantizedValueRange(); |
7698 | if (qRange.first >= 0) { |
7699 | continue; |
7700 | } |
7701 | const auto qParams = quantization::chooseQuantizationParams( |
7702 | {0, qRange.second}, quantization::Asymmetric, RNTy->getElementType()); |
7703 | const TypeRef qReluTy = F->getParent()->uniqueType( |
7704 | RNTy->getElementType(), RNTy->dims(), qParams.scale, qParams.offset); |
7705 | RN->setType(ReluNode::ResultIdx, qReluTy); |
7706 | |
7707 | // Now look for any users of the Relu which set their type based directly on |
7708 | // the type of the Relu, and update them as well. These tend to be shape |
7709 | // changes such as Concat, Slice, Reshape, Tile, Broadcast, Transpose, etc. |
7710 | for (auto &user : RN->getUsers()) { |
7711 | auto *U = user.getUser(); |
7712 | if (needsQuantTyUpdate(U)) { |
7713 | worklist.push_back(U); |
7714 | } |
7715 | } |
7716 | } |
7717 | |
7718 | // Now we need to update all nodes following the relus which directly took |
7719 | // their output types from the relu. |
7720 | while (!worklist.empty()) { |
7721 | Node *N = worklist.back(); |
7722 | assert(needsQuantTyUpdate(N) && "Unsupported node for quant update." ); |
7723 | worklist.pop_back(); |
7724 | |
7725 | // Look for other users that also need updates and add to worklist. |
7726 | for (auto &user : N->getUsers()) { |
7727 | auto *U = user.getUser(); |
7728 | if (needsQuantTyUpdate(U)) { |
7729 | worklist.push_back(U); |
7730 | } |
7731 | } |
7732 | |
7733 | // Note: We can unconditionally get the 0th result because all nodes we |
7734 | // currently support to update have a single output. |
7735 | assert(N->getNumResults() == 1 && "Unsupported multi-output Node" ); |
7736 | constexpr unsigned resultIdx = 0; |
7737 | const TypeRef T = N->getNthResult(resultIdx).getType(); |
7738 | |
7739 | // We must still be in the quantized domain, because we are only following |
7740 | // chains starting from quantized relus down through shape nodes. |
7741 | assert(T->isQuantizedType()); |
7742 | |
7743 | // This likely represents an issue because it means e.g. a Reshape will |
7744 | // change the scale/bias, but continue for now and assume a verifier will |
7745 | // catch the issue if it is one. |
7746 | if (isUsedByNodeWithSideEffects(N)) { |
7747 | continue; |
7748 | } |
7749 | |
7750 | // Update the output type just like we did for the original relu. |
7751 | const auto qRange = T->getQuantizedValueRange(); |
7752 | if (qRange.first >= 0) { |
7753 | continue; |
7754 | } |
7755 | const auto qParams = quantization::chooseQuantizationParams( |
7756 | {0, qRange.second}, quantization::Asymmetric, T->getElementType()); |
7757 | const TypeRef qReluTy = F->getParent()->uniqueType( |
7758 | T->getElementType(), T->dims(), qParams.scale, qParams.offset); |
7759 | N->setType(resultIdx, qReluTy); |
7760 | } |
7761 | } |
7762 | |