1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
18#include <folly/String.h>
19
20#include "glow/Backend/Backend.h"
21#include "glow/Converter/Float16Converter.h"
22#include "glow/Converter/FusedRowwiseConverter.h"
23#include "glow/Converter/TypeAToTypeBFunctionConverter.h"
24#include "glow/Flags/Flags.h"
25#include "glow/Graph/Graph.h"
26#include "glow/Graph/Log.h"
27#include "glow/Graph/Node.h"
28#include "glow/Graph/Nodes.h"
29#include "glow/Graph/PlaceholderBindings.h"
30#include "glow/Graph/TensorLayout.h"
31#include "glow/Graph/Utils.h"
32#include "glow/Graph/VerifierHelper.h"
33#include "glow/Optimizer/GraphOptimizer/FunctionPassPipeline.h"
34#include "glow/Optimizer/GraphOptimizer/FunctionPasses.h"
35#include "glow/Optimizer/Lower/Lower.h"
36#include "glow/PassManager/PassManager.h"
37#include "glow/Quantization/Base/Base.h"
38#include "glow/Quantization/Quantization.h"
39#include "glow/Runtime/RuntimeTypes.h"
40
41#include "llvm/Support/Casting.h"
42#include "llvm/Support/CommandLine.h"
43
44#include <algorithm>
45#include <numeric>
46#include <unordered_map>
47#include <unordered_set>
48#include <vector>
49
50// Utility macro to continue loop if given condition is not met.
51// This is intended to improve code readability and size.
52#define CONTINUE_IF_NOT(cond) \
53 if (!(cond)) { \
54 continue; \
55 }
56
57llvm::cl::OptionCategory graphOptCat("Graph Optimizations Options");
58llvm::cl::opt<unsigned> constDedupSizeOpt(
59 "const-dedup-size",
60 llvm::cl::desc(
61 "Max number of elements allowed for deduplicating Constants. "
62 "A value equal to 0 means no limit. Default is 0."),
63 llvm::cl::Optional, llvm::cl::init(0), llvm::cl::cat(graphOptCat));
64
65using namespace glow;
66using llvm::cast;
67using llvm::dyn_cast;
68using llvm::isa;
69
70static bool shouldDeleteNode(Node *N) {
71 // In general, nodes who have side effects are retained.
72 if (N->hasSideEffects()) {
73 return false;
74 }
75
76 // Don't delete nodes that have users.
77 if (N->hasUsers()) {
78 return false;
79 }
80
81 return true;
82}
83
84ConstantModificationPreventer::ConstantModificationPreventer(
85 Module &mod, CompilationContext &cctx)
86 : ScopeGuard([&]() {
87 // Ensure we cleanup Placeholder-Constant swap if necessary.
88 auto &PHs = mod_.getPlaceholders();
89 for (auto &pair : tmpPHToConstMap_) {
90 Placeholder *tmpPH = pair.first;
91 Constant *C = pair.second;
92 tmpPH->getOutput().replaceAllUsesOfWith(C->getOutput());
93 mod_.erasePlaceholder(std::find(PHs.begin(), PHs.end(), tmpPH));
94 }
95 cctx_.optimizationOpts.enableConstantFolding =
96 origEnableConstantFolding_;
97 }),
98 mod_(mod), cctx_(cctx),
99 origEnableConstantFolding_(cctx.optimizationOpts.enableConstantFolding) {
100 // By default dismiss until explicitly activated.
101 dismissed_ = true;
102}
103
104void ConstantModificationPreventer::activate() {
105 dismissed_ = false;
106 // Prevent Constant modification by temporarily replacing them with PHs.
107 for (Constant *C : mod_.getConstants()) {
108 // Note: These temp Placeholders are more like static Placeholders, but we
109 // don't want to set them as static here because optimizations may kick in
110 // to modify the type of the static Placeholder (see
111 // cctx.optimizationOpts.foldStaticPlaceholderConversions).
112 Placeholder *tmpPH = mod_.createPlaceholder(
113 C->getType(), C->getName().str() + "_SWAP_CONST_FOLD",
114 /* isTrainable */ false, C->getLayout());
115 tmpPHToConstMap_[tmpPH] = C;
116 cctx_.optimizationOpts.tempPHsForConstants.insert(tmpPH);
117 C->getOutput().replaceAllUsesOfWith(tmpPH->getOutput());
118 }
119 // Disable constant folding temporarily; restored later by the scope guard.
120 cctx_.optimizationOpts.enableConstantFolding = false;
121}
122
123/// Helper that \returns true if all functions in \p mod are loaded.
124static bool areAllFunctionsLoaded(Module *mod) {
125 for (auto *MF : mod->getFunctions()) {
126 if (MF->getState() < FunctionState::FuncLoaded) {
127 return false;
128 }
129 }
130 return true;
131}
132
133/// Helper that \returns the shuffle that inverts \p shuffle. For example, if
134/// \p shuffle is {3, 0, 1, 2}, then this function returns {1, 2, 3, 0}.
135static llvm::SmallVector<unsigned_t, max_tensor_dimensions>
136invertShuffle(llvm::ArrayRef<unsigned_t> shuffle) {
137 llvm::SmallVector<unsigned_t, max_tensor_dimensions> invertedShuffle;
138 invertedShuffle.resize(shuffle.size());
139
140 for (size_t i = 0; i < shuffle.size(); ++i) {
141 invertedShuffle[shuffle[i]] = i;
142 }
143
144 return invertedShuffle;
145}
146
147/// Add a TranposeNode after \p C in \p F that has the same shuffle as \p TR.
148/// This funcion assumes that the type of \p C and the output type of \p TR are
149/// the same. \returns the newly added TransposeNode.
150static TransposeNode *insertMatchingTransposeAfterConstant(Function *F,
151 Constant *C,
152 TransposeNode *TR) {
153 const auto *CT = C->getOutput().getType();
154 const auto *TRT = TR->getResult().getType();
155 DCHECK(CT->isEqual(*TRT, /* allowDifferentShape */ false,
156 /* allowDifferentStride */ false,
157 /* allowDifferentScaleOffset */ true));
158
159 auto &T = C->getPayload();
160
161 // In order for a new Transpose node with the same shuffle as TR to be created
162 // at the output of the Constant, a new Constant with the same dimension as
163 // the input of TR should be created. Note that the original scale and offset
164 // should be kept for quantized types.
165 auto newConstTy =
166 F->getParent()->uniqueTypeWithNewShape(CT, TR->getInput().dims());
167 auto *NC = F->getParent()->createConstant(newConstTy,
168 C->getName().str() + ".transposed");
169
170 // The payload of the original Constant C has the same type as the
171 // output of TR. In order to preserve correctness, this payload must
172 // be transposed using the inverse of the shuffle of TR and stored
173 // into the payload of the new Constant.
174 //
175 // Another way to think of this is that we are inserting two
176 // Transposes that are inverses of each other back to back after the original
177 // Constant. The shuffle of the second Transpose must match that of TR.
178 // In order to preserve correctness, the shuffle
179 // of the first Transpose must be the inverse of that shuffle of the
180 // second Transpose. The statement below statically computes this
181 // first Transpose.
182 T.transpose(&NC->getPayloadMutable(), invertShuffle(TR->getShuffle()));
183
184 // Create Transpose on the LHS that has the same shuffle as TR.
185 return F->createTranspose("transpose", NC, TR->getShuffle());
186}
187
188bool EmptyPass::run(Function *F, const CompilationContext &cctx) {
189 return false;
190}
191
192bool DCE::run(Function *F, const CompilationContext &cctx) {
193 LOG_SCOPE(F->getLogContext(), getName());
194
195 auto &nodes = F->getNodes();
196
197 std::vector<NodesList::iterator> erasedNodes{};
198
199 bool changed = false;
200
201 // Remove unused nodes.
202 while (true) {
203 bool changedLocally = false;
204 for (auto it = nodes.begin(), e = nodes.end(); it != e;) {
205 if (!shouldDeleteNode(&*it)) {
206 ++it;
207 continue;
208 }
209
210 erasedNodes.push_back(it);
211 ++it;
212 changedLocally = true;
213 changed = true;
214 }
215
216 while (!erasedNodes.empty()) {
217 auto it = erasedNodes.back();
218 F->eraseNode(it);
219 erasedNodes.pop_back();
220 }
221
222 if (!changedLocally) {
223 break;
224 }
225 }
226
227 // Don't remove unused Constants since many may be temporarily unused during
228 // optimizations.
229 if (cctx.optimizationOpts.delayAndRecordConstantModification) {
230 return changed;
231 }
232
233 if (!areAllFunctionsLoaded(F->getParent())) {
234 return changed;
235 }
236
237 // Delete unused Constants.
238 deleteUnusedConstants(*F->getParent());
239
240 return changed;
241}
242
243void glow::deleteUnusedConstants(Module &mod) {
244 auto &consts = mod.getConstants();
245 std::vector<ConstList::iterator> erasedConsts{};
246 for (auto it = consts.begin(), e = consts.end(); it != e;) {
247 if (!shouldDeleteNode(*it)) {
248 ++it;
249 continue;
250 }
251 erasedConsts.push_back(it);
252 ++it;
253 }
254
255 while (!erasedConsts.empty()) {
256 auto it = erasedConsts.back();
257 mod.eraseConstant(it);
258 erasedConsts.pop_back();
259 }
260}
261
262/// \returns true if the \p shuffle corresponds to an identity operation, false
263/// otherwise.
264static bool isIdentityShuffle(llvm::ArrayRef<unsigned> shuffle) {
265 for (size_t i = 0, e = shuffle.size(); i < e; i++) {
266 if (shuffle[i] != i) {
267 return false;
268 }
269 }
270 return true;
271}
272
273/// \returns True if the node \p N always evaluates to \p val.
274bool isSplatOfVal(Node *N, float val) {
275 SplatNode *Z = dyn_cast<SplatNode>(N);
276 if (!Z) {
277 return false;
278 }
279 return (Z->getValue() == val);
280}
281
282/// \returns True if the node returns a constant value.
283bool isConstant(Node *N) { return isa<SplatNode>(N); }
284
285/// \returns the new simplified NodeValue or the original node's first result.
286static NodeValue simplifyNode(Node *node, Function *F) {
287// Simplify commutative nodes by moving the constant operator to the right-hand
288// side.
289// Example: C + X => X + C
290#define COMMUTE_CONST_TO_RHS(NodeKind) \
291 if (auto *NN = dyn_cast<NodeKind##Node>(node)) \
292 if (isConstant(NN->getLHS()) && !isConstant(NN->getRHS())) { \
293 return F->create##NodeKind(NN->getName(), NN->getResult().getType(), \
294 NN->getRHS(), NN->getLHS()); \
295 }
296
297 COMMUTE_CONST_TO_RHS(Add)
298 COMMUTE_CONST_TO_RHS(Mul)
299 COMMUTE_CONST_TO_RHS(Max)
300 COMMUTE_CONST_TO_RHS(Min)
301#undef COMMUTE_CONST_TO_RHS
302
303 if (auto *AN = dyn_cast<AddNode>(node)) {
304 // X + 0 => X
305 if (isSplatOfVal(AN->getRHS(), 0)) {
306 return AN->getLHS();
307 }
308 }
309
310 if (auto *MN = dyn_cast<MulNode>(node)) {
311 // X * 0 => 0
312 if (isSplatOfVal(MN->getRHS(), 0)) {
313 return MN->getRHS();
314 }
315 // X * 1 => X
316 if (isSplatOfVal(MN->getRHS(), 1)) {
317 return MN->getLHS();
318 }
319 }
320
321 if (auto *DN = dyn_cast<DivNode>(node)) {
322 // 0 / X => 0
323 if (isSplatOfVal(DN->getLHS(), 0)) {
324 return DN->getLHS();
325 }
326 // X / 1 => X
327 if (isSplatOfVal(DN->getRHS(), 1)) {
328 return DN->getLHS();
329 }
330 }
331
332 // X - 0 => X
333 if (auto *SN = dyn_cast<SubNode>(node)) {
334 if (isSplatOfVal(SN->getRHS(), 0)) {
335 return SN->getLHS();
336 }
337 }
338
339 return node;
340}
341
342/// Sink Transpose below ChannelShuffle node.
343static bool sinkTranposeBelowChannelShuffle(Function *F,
344 ChannelShuffleNode *CS) {
345 auto *TR = dyn_cast<TransposeNode>(CS->getInput());
346 if (!TR) {
347 return false;
348 }
349
350 // Create a new ChannelShuffle with kernel parameter transposed by the
351 // sinking TR's shuffle because that Transpose will now be moved below this
352 // ChannelShuffle operator.
353 auto *newCS =
354 F->createChannelShuffle(CS->getName(), TR->getInput(), CS->getGroup(),
355 TR->getShuffle()[CS->getKernel()]);
356
357 // Create a copy of sinkingTR and insert after newChannelShuffle.
358 auto *newTR = F->createTranspose(TR->getName(), newCS, TR->getShuffle(),
359 TR->getLayout());
360
361 CS->getResult().replaceAllUsesOfWith(newTR);
362
363 return true;
364}
365
366/// Given \p CN from \p F, determines if all inputs are ConvertTo that use the
367/// same scale/offset/kind, and if so creates and \returns a new concat with all
368/// inputs as the inputs from the ConvertTo inputs. Otherwise \returns nullptr.
369static ConcatNode *setupConvertToSinkBelowConcat(Function *F, ConcatNode *CN) {
370 // Check if all inputs are ConvertTo.
371 std::vector<ConvertToNode *> inputNodes;
372 inputNodes.reserve(CN->getInputs().size());
373 for (auto &concatInput : CN->getInputs()) {
374 auto *CT = dyn_cast<ConvertToNode>(concatInput);
375 if (!CT) {
376 return nullptr;
377 }
378 inputNodes.push_back(CT);
379 }
380
381 // Gather all inputs of the nodes in inputNodes here.
382 std::vector<NodeValue> newInputs;
383 newInputs.reserve(inputNodes.size());
384 newInputs.push_back(inputNodes[0]->getInput());
385
386 // Get the CN's first input's result and input types to check against all
387 // other CN inputs.
388 const TypeRef firstResultTy = inputNodes[0]->getResult().getType();
389 const TypeRef firstInputTy = inputNodes[0]->getInput().getType();
390
391 // Check that all inputs have the same output and input element type.
392 for (size_t i = 1, e = inputNodes.size(); i < e; i++) {
393 const TypeRef currResultTy = inputNodes[i]->getResult().getType();
394 const TypeRef currInputTy = inputNodes[i]->getInput().getType();
395 if (currResultTy->getElementType() != firstResultTy->getElementType()) {
396 return nullptr;
397 }
398 if (firstResultTy->isQuantizedType()) {
399 if (currResultTy->getScale() != firstResultTy->getScale() ||
400 currResultTy->getOffset() != firstResultTy->getOffset()) {
401 return nullptr;
402 }
403 }
404 if (currInputTy->getElementType() != firstInputTy->getElementType()) {
405 return nullptr;
406 }
407 if (firstInputTy->isQuantizedType()) {
408 if (currInputTy->getScale() != firstInputTy->getScale() ||
409 currInputTy->getOffset() != firstInputTy->getOffset()) {
410 return nullptr;
411 }
412 }
413 newInputs.push_back(inputNodes[i]->getInput());
414 }
415
416 // Create and return a new ConcatNode with newInputs.
417 return F->createConcat(CN->getName(), newInputs, CN->getDim());
418}
419
420/// Given \p CN from \p F, determines if all inputs are either Quantize or
421/// Dequantize (depending on \p QuantNodeClass) that use the same
422/// scale/offset/kind, and if so creates and \returns a new concat with all
423/// inputs as the inputs from the Quantize or Dequantize inputs. Otherwise
424/// \returns nullptr.
425template <class QuantNodeClass>
426static ConcatNode *setupQuantDequantSinkBelowConcat(Function *F,
427 ConcatNode *CN) {
428 constexpr bool isQuant = std::is_same<QuantizeNode, QuantNodeClass>::value;
429 constexpr bool isDeq = std::is_same<DequantizeNode, QuantNodeClass>::value;
430 static_assert(isQuant || isDeq, "setupQuantDequantSinkBelowConcat() only "
431 "supports Quantize/Dequantize nodes.");
432 // Check if all inputs are Quantize with the same input
433 // scale/offset/ElemKind.
434 std::vector<QuantNodeClass *> qNodes;
435 qNodes.reserve(CN->getInputs().size());
436 for (auto &concatInput : CN->getInputs()) {
437 QuantNodeClass *Q = dyn_cast<QuantNodeClass>(concatInput);
438 if (!Q) {
439 return nullptr;
440 }
441 qNodes.push_back(Q);
442 }
443
444 // Gather all inputs of the nodes in qNodes here.
445 std::vector<NodeValue> newInputs;
446 newInputs.reserve(qNodes.size());
447 newInputs.push_back(qNodes[0]->getInput());
448
449 // Check the CN's first input's type to check against all other inputs. Use
450 // the output of Quantize or input of Dequantize.
451 const TypeRef firstTy = isQuant ? qNodes[0]->getResult().getType()
452 : qNodes[0]->getInput().getType();
453
454 // Check that all inputs have the same scale/offset/type.
455 for (size_t i = 1, e = qNodes.size(); i < e; i++) {
456 const TypeRef currTy = isQuant ? qNodes[i]->getResult().getType()
457 : qNodes[i]->getInput().getType();
458 if (currTy->getScale() != firstTy->getScale() ||
459 currTy->getOffset() != firstTy->getOffset() ||
460 currTy->getElementType() != firstTy->getElementType()) {
461 return nullptr;
462 }
463 newInputs.push_back(qNodes[i]->getInput());
464 }
465
466 // Create and return a new ConcatNode with newInputs.
467 return F->createConcat(CN->getName(), newInputs, CN->getDim());
468}
469
470/// Given \p CN from \p F, determines if all inputs are Tanh
471/// and if so creates and \returns a new concat with all
472/// inputs as the inputs from the Tanh inputs. Otherwise
473/// \returns nullptr.
474static ConcatNode *setupTanhSinkBelowConcat(Function *F, ConcatNode *CN) {
475 // Check if all inputs are Tanh
476 std::vector<TanhNode *> tanhNodes;
477 tanhNodes.reserve(CN->getInputs().size());
478 for (auto &concatInput : CN->getInputs()) {
479 TanhNode *T = dyn_cast<TanhNode>(concatInput);
480 if (!T) {
481 return nullptr;
482 }
483 tanhNodes.push_back(T);
484 }
485
486 // Gather all inputs of the nodes in tanhNodes.
487 std::vector<NodeValue> newInputs;
488 newInputs.reserve(tanhNodes.size());
489 for (size_t i = 0, e = tanhNodes.size(); i < e; i++) {
490 newInputs.emplace_back(tanhNodes[i]->getInput());
491 }
492 // Create and return a new ConcatNode with newInputs.
493 return F->createConcat(CN->getName(), newInputs, CN->getDim());
494}
495
496bool SinkConversions::run(Function *F, const CompilationContext &cctx) {
497 LOG_SCOPE(F->getLogContext(), getName());
498 bool changed = false;
499 auto &nodes = F->getNodes();
500 // For each node:
501 for (auto &N : nodes) {
502 ConcatNode *CN = dyn_cast<ConcatNode>(&N);
503 if (!CN) {
504 continue;
505 }
506 const Node *firstNode = CN->getInputs().front().getNode();
507
508 // Sink Dequantize below Concat nodes.
509 if (firstNode->getKind() == Kinded::Kind::DequantizeNodeKind) {
510 ConcatNode *newCN =
511 setupQuantDequantSinkBelowConcat<DequantizeNode>(F, CN);
512 if (!newCN) {
513 continue;
514 }
515
516 DequantizeNode *newDequantize =
517 F->createDequantize(CN->getName().str() + "_dequantize", newCN,
518 CN->getResult().getType());
519
520 CN->getResult().replaceAllUsesOfWith(newDequantize->getResult());
521 changed = true;
522 continue;
523 }
524
525 // Sink Quantize below Concat nodes.
526 if (firstNode->getKind() == Kinded::Kind::QuantizeNodeKind) {
527 ConcatNode *newCN = setupQuantDequantSinkBelowConcat<QuantizeNode>(F, CN);
528 if (!newCN) {
529 continue;
530 }
531
532 const TypeRef QTy =
533 llvm::cast<QuantizeNode>(firstNode)->getResult().getType();
534 const TypeRef concatQTy = F->getParent()->uniqueType(
535 QTy->getElementType(), newCN->getResult().dims(), QTy->getScale(),
536 QTy->getOffset());
537 QuantizeNode *newQuantize = F->createQuantize(
538 CN->getName().str() + "_quantize", newCN, concatQTy);
539
540 CN->getResult().replaceAllUsesOfWith(newQuantize->getResult());
541 changed = true;
542 continue;
543 }
544
545 // Sink ConvertTo below Concat nodes.
546 if (firstNode->getKind() == Kinded::Kind::ConvertToNodeKind) {
547 ConcatNode *newCN = setupConvertToSinkBelowConcat(F, CN);
548 if (!newCN) {
549 continue;
550 }
551 auto *newConvertTo =
552 F->createConvertTo(CN->getName().str() + "_convert_to", newCN,
553 CN->getResult().getType());
554 CN->getResult().replaceAllUsesOfWith(newConvertTo->getResult());
555 changed = true;
556 continue;
557 }
558
559 // Sink Tanh below Concat nodes.
560 if (cctx.optimizationOpts.sinkTanhBelowConcat) {
561 if (firstNode->getKind() == Kinded::Kind::TanhNodeKind) {
562 ConcatNode *newCN = setupTanhSinkBelowConcat(F, CN);
563 if (!newCN) {
564 continue;
565 }
566
567 const TypeRef TTy =
568 llvm::cast<TanhNode>(firstNode)->getResult().getType();
569 const TypeRef concatTy = F->getParent()->uniqueType(
570 TTy->getElementType(), newCN->getResult().dims());
571 TanhNode *newTanh =
572 F->createTanh(CN->getName().str() + "_tanh", concatTy, newCN);
573
574 CN->getResult().replaceAllUsesOfWith(newTanh->getResult());
575 changed = true;
576 continue;
577 }
578 }
579 }
580
581 return changed;
582}
583
584/// Sink Quantize(Concat(...)) -> Concat(Quantize(...)). This allows for
585/// concatenating less data, and if there are some inputs that are already
586/// quantized and are being dequantized just for the concat then we can skip
587/// this conversion.
588bool SinkConcatBelowQuantize::run(Function *F, const CompilationContext &cctx) {
589 LOG_SCOPE(F->getLogContext(), getName());
590 bool changed = false;
591 auto &nodes = F->getNodes();
592 // For each node:
593 for (auto &N : nodes) {
594 QuantizeNode *QN = dyn_cast<QuantizeNode>(&N);
595 if (!QN) {
596 continue;
597 }
598
599 ConcatNode *CN = dyn_cast<ConcatNode>(QN->getInput());
600 if (!CN || CN->getNumUsers() > 1) {
601 continue;
602 }
603
604 // For all inputs to the current CN, add quantize nodes to them all using
605 // the same scale/offset as QN and put the quantize nodes in newQuantInputs.
606 std::vector<NodeValue> newQuantInputs;
607 for (const NodeValue &inCN : CN->getInputs()) {
608 TypeRef newOutTy = F->getParent()->uniqueTypeWithNewShape(
609 QN->getResult().getType(), inCN.dims());
610 QuantizeNode *quantInCN = F->createQuantize(
611 inCN.getNode()->getName().str() + "_quant", inCN, newOutTy);
612 newQuantInputs.push_back(quantInCN);
613 }
614
615 // Create a new CN with the quantized inputs and replace QN with it.
616 ConcatNode *newCN =
617 F->createConcat(CN->getName(), newQuantInputs, CN->getDim());
618 QN->getResult().replaceAllUsesOfWith(newCN->getResult());
619 changed = true;
620 }
621
622 return changed;
623}
624
625/// If \p N is a TransposeNode with all of the same node kind of users, then
626/// \returns that TransposeNode, else \returns nullptr. For example, if \p N is
627/// a TransposeNode with two QuantizeNode users, this will return the
628/// TransposeNode, but if it had one QuantizeNode and one MatMul node then it
629/// will return nullptr.
630static TransposeNode *getTransposeNodeWithAllSameUserKind(Node *N) {
631 auto *TN = dyn_cast<TransposeNode>(N);
632 if (!TN) {
633 return nullptr;
634 }
635 if (TN->getNumUsers() <= 1) {
636 return TN;
637 }
638 auto firstKind = N->getUsers().front().getUser()->getKind();
639 for (auto &U : N->getUsers()) {
640 if (U.getUser()->getKind() != firstKind) {
641 return nullptr;
642 }
643 }
644 return TN;
645}
646
647/// Code Sinking.
648bool SinkCode::run(Function *F, const CompilationContext &cctx) {
649 LOG_SCOPE(F->getLogContext(), getName());
650 bool changed = false;
651 auto &nodes = F->getNodes();
652 // For each node:
653 for (auto &N : nodes) {
654 auto *node = &N;
655
656 // Sink Reshape/Transpose below BatchNormalization.
657 if (auto *BN = dyn_cast<BatchNormalizationNode>(node)) {
658
659 // Sink Reshape below BatchNormalization.
660 if (auto *RS = dyn_cast<ReshapeNode>(BN->getInput())) {
661 auto inDims = RS->getInput().dims();
662 auto outDims = RS->getResult().dims();
663 unsigned_t newChannelIdx;
664
665 // Skip sinking if: 1) the input was less than 3 dimensions,
666 // because we need spatial dimensions in addition to batch
667 // and channel or 2) if it is 3D data because the reshapes
668 // are deliberately introduced to phrase 3D BatchNormalization
669 // as a 2D one.
670 if (RS->getInput().dims().size() < 3 ||
671 RS->getInput().dims().size() == 5) {
672 continue;
673 }
674
675 // Reshape should not change the BatchNorm ChannelIdx dimensions.
676 // Only NH[W]C and NCH[W] are allowed.
677 if (BN->getChannelIdx() == outDims.size() - 1) {
678 if (inDims[inDims.size() - 1] != outDims[outDims.size() - 1]) {
679 continue;
680 }
681 newChannelIdx = inDims.size() - 1;
682 } else if (BN->getChannelIdx() == 1) {
683 // Note: index '1' maps to C in NCH[W] layout.
684 if (inDims[1] != outDims[1]) {
685 continue;
686 }
687 newChannelIdx = 1;
688 } else {
689 continue;
690 }
691
692 // Reshape should not change the batch dimension.
693 if (inDims[0] != outDims[0]) {
694 continue;
695 }
696
697 auto bnOutTy = F->getParent()->uniqueTypeWithNewShape(
698 BN->getResult().getType(), RS->getInput().getType());
699 auto rsInputType = RS->getInput().getType();
700 glow::TypeRef outTy = F->getParent()->uniqueTypeWithNewShape(
701 bnOutTy, rsInputType->dims());
702 auto *newBN = F->createBatchNormalization(
703 BN->getName(), outTy, RS->getInput(), BN->getBias(), BN->getScale(),
704 BN->getMean(), BN->getVar(), newChannelIdx, BN->getEpsilon(),
705 BN->getMomentum());
706 auto *newRS = F->createReshape(RS->getName(), newBN,
707 RS->getResult().dims(), RS->getLayout());
708 BN->getResult().replaceAllUsesOfWith(newRS);
709 changed = true;
710 continue;
711 }
712
713 // Sink Transpose below batch normalization nodes:
714 if (auto *TR = dyn_cast<TransposeNode>(BN->getInput())) {
715
716 // Figure out where we transposed the channel index for batch
717 // normalization.
718 unsigned_t idx = BN->getChannelIdx();
719 unsigned_t newChannelIdx = TR->getShuffle()[idx];
720
721 auto bnOutTy = BN->getResult().getType();
722 auto trInputType = TR->getInput().getType();
723 glow::TypeRef outTy = F->getParent()->uniqueTypeWithNewShape(
724 bnOutTy, trInputType->dims());
725
726 auto *NewBN = F->createBatchNormalization(
727 BN->getName(), outTy, TR->getInput(), BN->getBias(), BN->getScale(),
728 BN->getMean(), BN->getVar(), newChannelIdx, BN->getEpsilon(),
729 BN->getMomentum());
730 NewBN->setPredicate(node->getPredicate());
731 auto *newTR = F->createTranspose(TR->getName(), NewBN, TR->getShuffle(),
732 TR->getLayout());
733 newTR->setPredicate(node->getPredicate());
734
735 BN->getResult().replaceAllUsesOfWith(newTR);
736 changed = true;
737 continue;
738 }
739 }
740
741 if (auto *RL = dyn_cast<ReluNode>(node)) {
742 // Sink Transpose below batch RELU nodes.
743 if (auto *TR = dyn_cast<TransposeNode>(RL->getInput())) {
744 // Keep the same quantization parameters for ReLU output, but
745 // change the shape to appropriate value.
746 auto reluOutTy = F->getParent()->uniqueTypeWithNewShape(
747 RL->getResult().getType(), TR->getInput().getType());
748 auto *NRL = F->createRELU(RL->getName(), TR->getInput(), reluOutTy);
749 NRL->setPredicate(node->getPredicate());
750 auto *newTR = F->createTranspose(TR->getName(), NRL, TR->getShuffle(),
751 TR->getLayout());
752 newTR->setPredicate(node->getPredicate());
753 RL->getResult().replaceAllUsesOfWith(newTR);
754 changed = true;
755 continue;
756 }
757
758 // Sink Clip below RELU nodes.
759 if (ClipNode *CN = dyn_cast<ClipNode>(RL->getInput())) {
760 assert(!RL->getResult().getType()->isQuantizedType() &&
761 "Relu(Clip) means Relu should not be quantized.");
762 ReluNode *newRL = F->createRELU(RL->getName(), CN->getInput());
763 ClipNode *newCN =
764 F->createClip(CN->getName(), newRL->getResult(),
765 std::max(CN->getMin(), 0.0f), CN->getMax());
766 RL->getResult().replaceAllUsesOfWith(newCN);
767 changed = true;
768 continue;
769 }
770 }
771
772 // Sink Transpose below Clip nodes.
773 if (auto *CL = dyn_cast<ClipNode>(node)) {
774 auto *TR = dyn_cast<TransposeNode>(CL->getInput());
775
776 if (!TR) {
777 continue;
778 }
779
780 // Keep the same quantization parameters for Clip output, but
781 // change the shape to appropriate value.
782 auto clipOutTy = F->getParent()->uniqueTypeWithNewShape(
783 CL->getResult().getType(), TR->getInput().getType());
784 auto *NCL = F->createClip(CL->getName(), TR->getInput(), clipOutTy,
785 CL->getMin(), CL->getMax());
786 NCL->setPredicate(node->getPredicate());
787 auto *newTR = F->createTranspose(TR->getName(), NCL, TR->getShuffle());
788 newTR->setPredicate(node->getPredicate());
789 CL->getResult().replaceAllUsesOfWith(newTR);
790 changed = true;
791 continue;
792 }
793
794 // Sink Transpose below LeakyRelu nodes.
795 if (auto *LR = dyn_cast<LeakyReluNode>(node)) {
796 auto *TR = dyn_cast<TransposeNode>(LR->getInput());
797 if (!TR) {
798 continue;
799 }
800 auto newLROutTy = F->getParent()->uniqueTypeWithNewShape(
801 LR->getResult().getType(), TR->getInput().getType());
802 auto *newLR = F->createLeakyRELU(LR->getName(), newLROutTy,
803 TR->getInput(), LR->getAlpha());
804 newLR->setPredicate(node->getPredicate());
805 auto *newTR = F->createTranspose(TR->getName(), newLR, TR->getShuffle());
806 newTR->setPredicate(node->getPredicate());
807 LR->getResult().replaceAllUsesOfWith(newTR);
808 changed = true;
809 continue;
810 }
811
812 // Sink Transpose below PRelu with Splat.
813 if (auto *PN = dyn_cast<PReluNode>(node)) {
814 auto *TR = dyn_cast<TransposeNode>(PN->getInput());
815 if (!TR) {
816 continue;
817 }
818 auto *SN = dyn_cast<SplatNode>(PN->getSlope());
819 if (!SN) {
820 continue;
821 }
822 auto newSNOutTy = F->getParent()->uniqueTypeWithNewShape(
823 SN->getResult().getType(), TR->getInput().getType());
824 auto newPNOutTy = F->getParent()->uniqueTypeWithNewShape(
825 PN->getResult().getType(), TR->getInput().getType());
826 auto *newSN = F->createSplat(SN->getName(), newSNOutTy, SN->getValue());
827 auto *newPN =
828 F->createPRELU(PN->getName(), TR->getInput(), newSN, newPNOutTy);
829 auto *newTR = F->createTranspose(TR->getName(), newPN, TR->getShuffle());
830 newPN->setPredicate(node->getPredicate());
831 newTR->setPredicate(node->getPredicate());
832 PN->getResult().replaceAllUsesOfWith(newTR);
833 changed = true;
834 continue;
835 }
836
837 // Sink Transpose below Sigmoid nodes.
838 if (auto *SI = dyn_cast<SigmoidNode>(node)) {
839 auto *TR = dyn_cast<TransposeNode>(SI->getInput());
840
841 if (!TR) {
842 continue;
843 }
844
845 auto *NSI = F->createSigmoid(SI->getName(), TR->getInput());
846 NSI->setPredicate(node->getPredicate());
847 auto *newTR = F->createTranspose(TR->getName(), NSI, TR->getShuffle(),
848 TR->getLayout());
849 newTR->setPredicate(node->getPredicate());
850 SI->getResult().replaceAllUsesOfWith(newTR);
851 changed = true;
852 continue;
853 }
854
855 // Sink Transpose below Tile nodes.
856 if (auto *TN = dyn_cast<TileNode>(node)) {
857 auto *TR = dyn_cast<TransposeNode>(TN->getInput());
858
859 if (!TR) {
860 continue;
861 }
862
863 auto *newTN = F->createTile(TN->getName(), TR->getInput(), TN->getCount(),
864 TR->getShuffle()[TN->getAxis()]);
865 newTN->setPredicate(node->getPredicate());
866 auto *newTR = F->createTranspose(TR->getName(), newTN, TR->getShuffle(),
867 TR->getLayout());
868 newTR->setPredicate(node->getPredicate());
869 TN->getResult().replaceAllUsesOfWith(newTR);
870 changed = true;
871 continue;
872 }
873
874 // Sink Transpose below Pad nodes.
875 if (auto *padNode = dyn_cast<PadNode>(node)) {
876 auto *transposeNode = dyn_cast<TransposeNode>(padNode->getInput());
877
878 if (!transposeNode) {
879 continue;
880 }
881
882 // The transpose shuffle specifies the source dimension.
883 // When sinking Transpose below Pad, shuffle describes the target
884 // dimension.
885 auto shuffle = transposeNode->getShuffle();
886
887 // Shuffle the Pad output type and the padding attribute.
888 auto outPadType = padNode->getResult().getType();
889 auto outPadShape = outPadType->dims();
890 auto pads = padNode->getPads();
891 size_t numDims = outPadShape.size();
892 std::vector<dim_t> newOutPadShape(numDims);
893 std::vector<int> newPads(2 * numDims);
894 for (size_t i = 0; i < outPadShape.size(); i++) {
895 newOutPadShape[shuffle[i]] = outPadShape[i];
896 newPads[shuffle[i]] = pads[i];
897 newPads[shuffle[i] + numDims] = pads[i + numDims];
898 }
899
900 // New pad
901 auto newOutPadType =
902 F->getParent()->uniqueTypeWithNewShape(outPadType, newOutPadShape);
903 auto *NewPadNode = F->createPad(
904 padNode->getName(), transposeNode->getInput(), newOutPadType,
905 padNode->getMode(), newPads, padNode->getValue());
906 NewPadNode->setPredicate(node->getPredicate());
907 auto *newTransposeNode =
908 F->createTranspose(transposeNode->getName(), NewPadNode, shuffle);
909 newTransposeNode->setPredicate(node->getPredicate());
910 padNode->getResult().replaceAllUsesOfWith(newTransposeNode);
911 changed = true;
912 continue;
913 }
914
915 // Sink Transpose below Tanh nodes.
916 if (auto *TN = dyn_cast<TanhNode>(node)) {
917 auto *TR = dyn_cast<TransposeNode>(TN->getInput());
918
919 if (!TR) {
920 continue;
921 }
922
923 auto *NTN = F->createTanh(TN->getName(), TR->getInput());
924 NTN->setPredicate(node->getPredicate());
925 auto *newTR = F->createTranspose(TR->getName(), NTN, TR->getShuffle(),
926 TR->getLayout());
927 newTR->setPredicate(node->getPredicate());
928 TN->getResult().replaceAllUsesOfWith(newTR);
929 changed = true;
930 continue;
931 }
932
933 // Remove 'identity' transpose operations.
934 if (auto *TR = dyn_cast<TransposeNode>(node)) {
935 auto mask = TR->getShuffle();
936
937 if (isIdentityShuffle(mask)) {
938 TR->getResult().replaceAllUsesOfWith(TR->getInput());
939 changed = true;
940 continue;
941 }
942 }
943
944 // Merge consecutive Transpose operations.
945 if (auto *TR1 = dyn_cast<TransposeNode>(node)) {
946 auto *TR2 = dyn_cast<TransposeNode>(TR1->getInput());
947
948 if (!TR2) {
949 continue;
950 }
951
952 auto mask1 = TR1->getShuffle();
953 auto mask2 = TR2->getShuffle();
954 assert(mask1.size() == mask2.size() && "Invalid mask size");
955
956 llvm::SmallVector<unsigned_t, max_tensor_dimensions> newMask;
957 newMask.resize(mask2.size());
958
959 for (size_t i = 0, end = mask2.size(); i < end; i++) {
960 newMask[i] = mask2[mask1[i]];
961 }
962
963 auto *newTR = F->createTranspose("tranpose", TR2->getInput(), newMask);
964 TR1->getResult().replaceAllUsesOfWith(newTR->getResult());
965 changed = true;
966 continue;
967 }
968
969 if (auto *CS = dyn_cast<ChannelShuffleNode>(node)) {
970 // Sink Transpose below ChannelShuffle.
971 if (sinkTranposeBelowChannelShuffle(F, CS)) {
972 changed = true;
973 continue;
974 }
975 }
976
977 // Sink Transpose below Arithmetic nodes.
978 if (node->isArithmetic()) {
979 TransposeNode *LTR =
980 dyn_cast<TransposeNode>(node->getNthInput(ArithmeticNode::LHSIdx));
981 TransposeNode *RTR =
982 dyn_cast<TransposeNode>(node->getNthInput(ArithmeticNode::RHSIdx));
983
984 if (!LTR || !RTR) {
985 // If one of the sides is a splat, it can be seen as
986 // transpose (splat'). Similarly, if one of the sides is a Constant,
987 // it can be seen as tranpose (Constant').
988 if (isa<SplatNode>(node->getNthInput(ArithmeticNode::LHSIdx)) && RTR) {
989 // Build splat' for LHS.
990 auto *SN =
991 dyn_cast<SplatNode>(node->getNthInput(ArithmeticNode::LHSIdx));
992 auto *NS = F->createSplat("splat", RTR->getInput().getType(),
993 SN->getValue());
994 LTR = F->createTranspose("transpose", NS, RTR->getShuffle(),
995 RTR->getLayout());
996 changed = true;
997 } else if (isa<SplatNode>(node->getNthInput(ArithmeticNode::RHSIdx)) &&
998 LTR) {
999 // Build splat' for RHS.
1000 auto *SN =
1001 dyn_cast<SplatNode>(node->getNthInput(ArithmeticNode::RHSIdx));
1002 auto *NS = F->createSplat("splat", LTR->getInput().getType(),
1003 SN->getValue());
1004 RTR = F->createTranspose("transpose", NS, LTR->getShuffle(),
1005 LTR->getLayout());
1006 changed = true;
1007 } else if (isa<Constant>(node->getNthInput(ArithmeticNode::LHSIdx)) &&
1008 RTR) {
1009 // Build Constant' for for LHS.
1010 auto *C = cast<Constant>(node->getNthInput(ArithmeticNode::LHSIdx));
1011 LTR = insertMatchingTransposeAfterConstant(F, C, RTR);
1012 changed = true;
1013 } else if (isa<Constant>(node->getNthInput(ArithmeticNode::RHSIdx)) &&
1014 LTR) {
1015 // Build Constant' for for RHS.
1016 auto *C = cast<Constant>(node->getNthInput(ArithmeticNode::RHSIdx));
1017 RTR = insertMatchingTransposeAfterConstant(F, C, LTR);
1018 changed = true;
1019 } else {
1020 continue;
1021 }
1022 }
1023 // The masks of the transposes on both sizes must match.
1024 if (LTR->getShuffle() != RTR->getShuffle()) {
1025 continue;
1026 }
1027
1028 Node *newAN = nullptr;
1029
1030#define ARITHMETIC_CASE(NODE_NAME_) \
1031 case glow::Kinded::Kind::NODE_NAME_##NodeKind: \
1032 newAN = \
1033 F->create##NODE_NAME_(node->getName(), \
1034 F->getParent()->uniqueTypeWithNewShape( \
1035 node->getType(ArithmeticNode::ResultIdx), \
1036 LTR->getInput().getType()), \
1037 LTR->getInput(), RTR->getInput()); \
1038 break;
1039
1040#define BOOLEAN_OP_CASE(NODE_NAME_) \
1041 case glow::Kinded::Kind::NODE_NAME_##NodeKind: \
1042 newAN = F->create##NODE_NAME_(node->getName(), LTR->getInput(), \
1043 RTR->getInput()); \
1044 break;
1045
1046 switch (node->getKind()) {
1047 ARITHMETIC_CASE(Add);
1048 ARITHMETIC_CASE(Mul);
1049 ARITHMETIC_CASE(Sub);
1050 ARITHMETIC_CASE(Div);
1051 ARITHMETIC_CASE(Fmod);
1052 ARITHMETIC_CASE(Max);
1053 ARITHMETIC_CASE(Min);
1054 ARITHMETIC_CASE(Pow);
1055 BOOLEAN_OP_CASE(CmpLTE);
1056 BOOLEAN_OP_CASE(CmpEQ);
1057 default:
1058 llvm_unreachable("Unhandled node");
1059 }
1060#undef BOOLEAN_OP_CASE
1061#undef ARITHMETIC_CASE
1062
1063 newAN->setPredicate(node->getPredicate());
1064 changed = true;
1065 auto *newTR = F->createTranspose(LTR->getName(), newAN, LTR->getShuffle(),
1066 LTR->getLayout());
1067 newTR->setPredicate(node->getPredicate());
1068 node->getNthResult(ArithmeticNode::ResultIdx).replaceAllUsesOfWith(newTR);
1069 }
1070
1071 if (auto *Q = dyn_cast<QuantizeNode>(node)) {
1072 // Sink TransposeNode below QuantizedNode.
1073 if (auto *TR = getTransposeNodeWithAllSameUserKind(Q->getInput())) {
1074 auto newQType = F->getParent()->uniqueTypeWithNewShape(
1075 Q->getResult().getType(), TR->getInput().dims());
1076 auto *newQ = F->createQuantize(Q->getName(), TR->getInput(), newQType);
1077 auto *newTR = F->createTranspose(TR->getName(), newQ, TR->getShuffle());
1078 Q->getResult().replaceAllUsesOfWith(newTR);
1079 changed = true;
1080 continue;
1081 }
1082
1083 // Sink Reshape below Quantize.
1084 if (auto *RN = dyn_cast<ReshapeNode>(Q->getInput())) {
1085 auto newQType = F->getParent()->uniqueTypeWithNewShape(
1086 Q->getResult().getType(), RN->getInput().dims());
1087 auto *newQ = F->createQuantize(Q->getName(), RN->getInput(), newQType);
1088 auto *newRN = F->createReshape(RN->getName(), newQ,
1089 RN->getResult().dims(), RN->getLayout());
1090 Q->getResult().replaceAllUsesOfWith(newRN->getResult());
1091 changed = true;
1092 continue;
1093 }
1094 }
1095
1096 // Sink Reshape below ConvertTo.
1097 if (auto *CN = dyn_cast<ConvertToNode>(node)) {
1098 auto *RN = dyn_cast<ReshapeNode>(CN->getInput());
1099 if (!RN) {
1100 continue;
1101 }
1102 auto *newCN = F->createConvertTo(CN->getName(), RN->getInput(),
1103 CN->getResult().getElementType());
1104 auto *newRN = F->createReshape(RN->getName(), newCN,
1105 RN->getResult().dims(), RN->getLayout());
1106 CN->getResult().replaceAllUsesOfWith(newRN->getResult());
1107 changed = true;
1108 continue;
1109 }
1110
1111 // Sink TransposeNode below DequantizedNode.
1112 // If it doesn't work out it will be re-sinked later.
1113 if (auto *D = dyn_cast<DequantizeNode>(node)) {
1114 auto *TR = dyn_cast<TransposeNode>(D->getInput());
1115 if (!TR) {
1116 continue;
1117 }
1118
1119 auto newDType = F->getParent()->uniqueTypeWithNewShape(
1120 D->getResult().getType(), TR->getInput().dims());
1121 auto *newD = F->createDequantize(D->getName(), TR->getInput(), newDType);
1122 auto *newTR = F->createTranspose(TR->getName(), newD, TR->getShuffle());
1123 D->getResult().replaceAllUsesOfWith(newTR);
1124 changed = true;
1125 }
1126
1127 // Sink Transpose below RescaleQuantized.
1128 // Potentially exposes opportunity to be combined up with Convolution.
1129 // If it doesn't work out it will be re-sinked later.
1130 if (auto *RQ = dyn_cast<RescaleQuantizedNode>(node)) {
1131 auto *TR = dyn_cast<TransposeNode>(RQ->getInput());
1132 if (!TR) {
1133 continue;
1134 }
1135
1136 auto newRQType = F->getParent()->uniqueTypeWithNewShape(
1137 RQ->getResult().getType(), TR->getInput().getType());
1138 auto *newRQ =
1139 F->createRescaleQuantized(RQ->getName(), TR->getInput(), newRQType);
1140 auto *newTR = F->createTranspose(TR->getName(), newRQ, TR->getShuffle(),
1141 TR->getLayout());
1142 RQ->getResult().replaceAllUsesOfWith(newTR);
1143 changed = true;
1144 }
1145
1146 if (auto *CN = dyn_cast<ConcatNode>(node)) {
1147 const Node *firstNode = CN->getInputs().front().getNode();
1148 // Sink RELU below batch concat nodes.
1149 if (firstNode->getKind() == Kinded::Kind::ReluNodeKind) {
1150 llvm::SmallVector<NodeValue, 6> CNInputs;
1151 for (auto &input : CN->getInputs()) {
1152 auto *inputRL = dyn_cast<ReluNode>(input);
1153 if (!inputRL) {
1154 break;
1155 }
1156 CNInputs.push_back(inputRL->getInput());
1157 }
1158
1159 if (CNInputs.size() == CN->getNumInputs()) {
1160 auto *newCN = F->createConcat(CN->getName(), CNInputs, CN->getDim());
1161 newCN->setPredicate(node->getPredicate());
1162 auto name = CN->getNthInput(0).getNode()->getName();
1163 auto *newRL = F->createRELU(name, newCN, CN->getResult().getType());
1164 newRL->setPredicate(node->getPredicate());
1165 CN->getResult().replaceAllUsesOfWith(newRL);
1166 changed = true;
1167 }
1168 continue;
1169 }
1170
1171 // Sink Transpose below concat nodes.
1172 if (firstNode->getKind() == Kinded::Kind::TransposeNodeKind) {
1173 llvm::SmallVector<NodeValue, 6> transVector;
1174 auto inputIter = CN->getInputs().begin();
1175 auto *firstInput = dyn_cast<TransposeNode>(*inputIter);
1176 if (!firstInput) {
1177 continue;
1178 }
1179
1180 transVector.push_back(firstInput->getInput());
1181 auto shuffle = firstInput->getShuffle();
1182 // If the shuffle masks don't agree or not all inputs are Transpose then
1183 // bail out.
1184 for (++inputIter; inputIter != CN->getInputs().end(); ++inputIter) {
1185 auto *tTR = dyn_cast<TransposeNode>(*inputIter);
1186 if (!tTR || tTR->getShuffle() != shuffle) {
1187 break;
1188 }
1189 transVector.push_back(tTR->getInput());
1190 }
1191
1192 if (transVector.size() != CN->getNumInputs()) {
1193 continue;
1194 }
1195
1196 // Figure out where we transposed the channel index for batch
1197 // normalization.
1198 unsigned_t idx = CN->getDim();
1199 unsigned_t newChannelIdx = shuffle[idx];
1200
1201 auto *newCN =
1202 F->createConcat(CN->getName(), transVector, newChannelIdx);
1203 newCN->setPredicate(node->getPredicate());
1204 auto *newTR = F->createTranspose(firstInput->getName(), newCN,
1205 firstInput->getShuffle(),
1206 firstInput->getLayout());
1207 newTR->setPredicate(node->getPredicate());
1208 CN->getResult().replaceAllUsesOfWith(newTR);
1209 changed = true;
1210 continue;
1211 }
1212 }
1213 } // For all nodes in the graph.
1214
1215 // Transformations to sink nodes below Slice. Outlined into a separate loop to
1216 // prevent Transpose/Slice sinking to affect them.
1217 for (auto &N : nodes) {
1218 auto *node = &N;
1219 // Sink BatchNorm below Slice.
1220 if (auto *SN = dyn_cast<SliceNode>(node)) {
1221 auto *BN = dyn_cast<BatchNormalizationNode>(SN->getInput());
1222 if (!BN || !BN->hasOneUse()) {
1223 continue;
1224 }
1225
1226 // Don't support sinking below Slice which affects depth.
1227 if (SN->getInput().dims()[BN->getChannelIdx()] !=
1228 SN->getResult().dims()[BN->getChannelIdx()]) {
1229 continue;
1230 }
1231
1232 auto newSNType = F->getParent()->uniqueTypeWithNewShape(
1233 BN->getInput().getType(), SN->getResult().dims());
1234 auto *newSN = F->createSlice(SN->getName(), BN->getInput(),
1235 SN->getStart(), newSNType);
1236 auto *newBN = F->createBatchNormalization(
1237 BN->getName(), SN->getResult().getType(), newSN, BN->getBias(),
1238 BN->getScale(), BN->getMean(), BN->getVar(), BN->getChannelIdx(),
1239 BN->getEpsilon(), BN->getMomentum());
1240 SN->getResult().replaceAllUsesOfWith(newBN);
1241 changed = true;
1242 }
1243 }
1244
1245 return changed;
1246}
1247
1248/// Code Hoisting.
1249bool HoistCode::run(Function *F, const CompilationContext &cctx) {
1250 LOG_SCOPE(F->getLogContext(), getName());
1251 bool changed = false;
1252 auto &nodes = F->getNodes();
1253 // For each node:
1254 for (auto &N : nodes) {
1255 auto *node = &N;
1256
1257 // Hoist Transpose above Tile nodes.
1258 if (auto *TR = dyn_cast<TransposeNode>(node)) {
1259 auto *TN = dyn_cast<TileNode>(TR->getInput());
1260
1261 if (!TN) {
1262 continue;
1263 }
1264
1265 auto *newTR = F->createTranspose(TR->getName(), TN->getInput(),
1266 TR->getShuffle(), TR->getLayout());
1267 newTR->setPredicate(node->getPredicate());
1268 auto *newTN =
1269 F->createTile(TN->getName(), newTR, TN->getCount(),
1270 invertShuffle(TR->getShuffle())[TN->getAxis()]);
1271 newTN->setPredicate(node->getPredicate());
1272 TR->getResult().replaceAllUsesOfWith(newTN);
1273 changed = true;
1274 continue;
1275 }
1276 }
1277
1278 return changed;
1279}
1280
1281/// Reshape Sinking.
1282bool SinkReshapes::run(Function *F, const CompilationContext &cctx) {
1283 LOG_SCOPE(F->getLogContext(), getName());
1284 bool changed = false;
1285 auto &nodes = F->getNodes();
1286 // For each node:
1287 for (auto &N : nodes) {
1288 auto *node = &N;
1289
1290 // Sink Reshape below eltwise nodes.
1291 if (!node->isDataParallel() || node->hasSideEffects()) {
1292 continue;
1293 }
1294
1295 // Unary eltwise nodes.
1296 if (node->getNumInputs() != 1 || node->getNumResults() != 1) {
1297 continue;
1298 }
1299
1300 auto *RS = dyn_cast<ReshapeNode>(node->getNthInput(0));
1301 if (!RS) {
1302 continue;
1303 }
1304
1305 // Create new eltwise node.
1306 auto in = RS->getInput();
1307 auto out = node->getNthResult(0);
1308 auto newTy =
1309 F->getParent()->uniqueTypeWithNewShape(out.getType(), in.dims());
1310 auto *newN = F->addNode(node->clone());
1311 newN->setNthInput(0, in);
1312 newN->setTypeUnsafe(0, newTy);
1313 newN->setPredicate(node->getPredicate());
1314
1315 // Create new Reshape.
1316 auto *newRS = F->createReshape(RS->getName(), newN,
1317 RS->getResult().getType()->dims());
1318 newRS->setPredicate(node->getPredicate());
1319 out.replaceAllUsesOfWith(newRS->getResult());
1320
1321 changed = true;
1322 }
1323 return changed;
1324}
1325
1326/// Remove unnecessary padding and reduce filters for Convolution nodes with
1327/// small input tensors.
1328bool OptimizeSmallConv::run(Function *F, const CompilationContext &cctx) {
1329 LOG_SCOPE(F->getLogContext(), getName());
1330 bool changed = false;
1331 for (auto &N : F->getNodes()) {
1332 auto *CN = dyn_cast<ConvolutionNode>(&N);
1333 if (!CN) {
1334 continue;
1335 }
1336
1337 // Consider a Convolution "small", if its output is 1x1.
1338 // The transformation doesn't support dilation.
1339 auto dilation = CN->getDilation();
1340 ShapeNHWC odim(CN->getResult().dims());
1341 if (odim.h > 1 || odim.w > 1 || !isUniformArray(dilation, 1u)) {
1342 continue;
1343 }
1344
1345 // Dealing with stride=1 Convoltuion nodes is generally easier, so try
1346 // to canonicalize stride to 1 if possible.
1347 ShapeNHWC idim(CN->getInput().dims());
1348 std::vector<unsigned_t> strides(CN->getStrides());
1349 std::vector<unsigned_t> kernels(CN->getKernels());
1350 std::vector<unsigned_t> pads(CN->getPads());
1351 auto newOutHW = calculateConvPoolOutputDims(idim.h, idim.w, kernels, {1, 1},
1352 pads, dilation);
1353 if (newOutHW.first == 1 && newOutHW.second == 1) {
1354 strides = {1, 1};
1355 }
1356
1357 // Slice off redundant filter parts.
1358 auto filters = CN->getFilter();
1359 auto *C = dyn_cast<Constant>(filters);
1360 if (C && isUniformArray(llvm::makeArrayRef(strides), 1u) &&
1361 !isUniformArray(llvm::makeArrayRef(pads), 0u)) {
1362 ShapeNHWC fdim(filters.dims());
1363 PaddingTLBR p(llvm::makeArrayRef(pads));
1364 dim_t start[] = {0u, p.top, p.left, 0u};
1365 dim_t end[] = {fdim.n, fdim.h - p.bottom, fdim.w - p.right, fdim.c};
1366 auto *SN = F->createSlice(C->getName(), C, start, end);
1367 filters = SN->getResult();
1368 kernels = {unsigned_t(idim.h), unsigned_t(idim.w)};
1369 pads = {0, 0, 0, 0};
1370 }
1371
1372 // Check if this node needs any changes.
1373 if (filters == CN->getFilter() &&
1374 llvm::makeArrayRef(strides) == CN->getStrides()) {
1375 continue;
1376 }
1377
1378 auto *newCN =
1379 F->createConv(CN->getName(), CN->getInput(), filters, CN->getBias(),
1380 CN->getResult().getType(), kernels, strides, pads,
1381 CN->getGroup(), dilation);
1382
1383 CN->getResult().replaceAllUsesOfWith(newCN->getResult());
1384 changed = true;
1385 }
1386
1387 return changed;
1388}
1389
1390/// Fold a Convolution dilated manually using Transpose, SpaceToDepth and
1391/// DepthToSpace nodes into a single Convolution node. Pattern:
1392/// NHWC2CHWN -> S2D -> CHWN2NHWC -> Conv -> NHWC2CHWN -> D2S -> CHWN2NHWC
1393bool FoldDilatedConv::run(Function *F, const CompilationContext &cctx) {
1394 LOG_SCOPE(F->getLogContext(), getName());
1395 bool changed = false;
1396 for (auto &N : F->getNodes()) {
1397 // Do matching starting from the last node of the pattern and go backwards.
1398
1399 // 1. Transpose CHWN2NHWC.
1400 // Do not match this Transpose, since it has likely sank and got separated
1401 // from the rest of the nodes in the pattern. Instead, we'll generate a
1402 // reverse Transpose after new Convolution for them to be optimized out.
1403
1404 // 2. DepthToSpace (represented as Reshape + Transpose[6 dims] + Reshape).
1405 // Ignore the last Reshape for the same reasons as Transpose above and start
1406 // matching from 6 dim Transpose.
1407 auto *D2S = dyn_cast<TransposeNode>(&N);
1408 if (!D2S || !D2S->hasOneUse() ||
1409 D2S->getShuffle() != llvm::makeArrayRef(D2S_DCR)) {
1410 continue;
1411 }
1412 unsigned_t block = D2S->getInput().dims()[3];
1413 NodeValue output = D2S->getResult();
1414
1415 auto *RN = dyn_cast<ReshapeNode>(D2S->getInput());
1416 if (!RN || !RN->hasOneUse()) {
1417 continue;
1418 }
1419 llvm::ArrayRef<dim_t> idim = RN->getInput().dims();
1420 llvm::SmallVector<dim_t, 6> odim = {
1421 idim[0], idim[1], idim[2], block, block, idim[3] / (block * block)};
1422 if (RN->getResult().dims() != llvm::makeArrayRef(odim)) {
1423 continue;
1424 }
1425
1426 // 3. Transpose NHWC2CHWN.
1427 auto *T1 = dyn_cast<TransposeNode>(RN->getInput());
1428 if (!T1 || !T1->hasOneUse() ||
1429 T1->getShuffle() != llvm::makeArrayRef(NHWC2CHWN)) {
1430 continue;
1431 }
1432
1433 // 4. Convolution.
1434 llvm::StringRef name;
1435 llvm::ArrayRef<unsigned_t> kernels, strides, pads, dilation;
1436 NodeValue convInput, convResult;
1437 auto getConvParams = [&](auto *N) -> bool {
1438 if (!N || !N->hasOneUse()) {
1439 return false;
1440 }
1441 name = N->getName();
1442 kernels = N->getKernels();
1443 strides = N->getStrides();
1444 pads = N->getPads();
1445 dilation = N->getDilation();
1446 convInput = N->getInput();
1447 convResult = N->getResult();
1448 return true;
1449 };
1450 auto *CN = dyn_cast<ConvolutionNode>(T1->getInput());
1451 auto *CQCN = dyn_cast<ChannelwiseQuantizedConvolutionNode>(T1->getInput());
1452 if (!getConvParams(CN) && !getConvParams(CQCN)) {
1453 continue;
1454 }
1455 if (!isUniformArray(strides, 1u) || !isUniformArray(pads, 0u) ||
1456 !isUniformArray(dilation, 1u)) {
1457 continue;
1458 }
1459
1460 // 5. Transpose CHWN2NHWC.
1461 auto *T2 = dyn_cast<TransposeNode>(convInput);
1462 if (!T2 || T2->getShuffle() != llvm::makeArrayRef(CHWN2NHWC)) {
1463 continue;
1464 }
1465
1466 // 6. SpaceToDepth.
1467 auto *S2D = dyn_cast<SpaceToDepthNode>(T2->getInput());
1468 if (!S2D || S2D->getBlockSize() != block) {
1469 continue;
1470 }
1471 NodeValue input = S2D->getInput();
1472
1473 // 7. Transpose NHWC2CHWN.
1474 // Can potentially be changed/removed by merging with other Transpose nodes,
1475 // so don't match it. Instead, will generate a reverse Transpose later.
1476
1477 // Create CHWN2NHWC -> Conv -> NHWC2CHWN -> Reshape.
1478 // Reshape and Transposes are created to cancel the ones we did not match.
1479 auto *newT1 =
1480 F->createTranspose(name.str() + "_chwn2nhwc", input, CHWN2NHWC);
1481
1482 auto trOutDims = newT1->getResult().dims();
1483 auto outHW = calculateConvPoolOutputDims(
1484 trOutDims[1], trOutDims[2], kernels, strides, pads, {block, block});
1485 auto convOutTy = F->getParent()->uniqueTypeWithNewShape(
1486 convResult.getType(),
1487 {trOutDims[0], outHW.first, outHW.second, convResult.dims()[3]});
1488 Node *newCN;
1489 if (CN) {
1490 newCN = F->createConv(name, newT1->getResult(), CN->getFilter(),
1491 CN->getBias(), convOutTy, kernels, strides, pads,
1492 CN->getGroup(), {block, block});
1493 } else if (CQCN) {
1494 newCN = F->createChannelwiseQuantizedConv(
1495 name, newT1->getResult(), CQCN->getFilter(), CQCN->getBias(),
1496 CQCN->getFilterScales(), CQCN->getFilterOffsets(),
1497 CQCN->getBiasScales(), CQCN->getBiasOffsets(), convOutTy, kernels,
1498 strides, pads, CQCN->getGroup(), {block, block}, false, false);
1499 } else {
1500 llvm_unreachable("Convolution must be in the pattern");
1501 }
1502
1503 auto *newT2 =
1504 F->createTranspose(name.str() + "_nhwc2chwn", newCN, NHWC2CHWN);
1505
1506 idim = newT2->getResult().dims();
1507 odim = {idim[0], idim[1] / block, block, idim[2] / block, block, idim[3]};
1508 auto *newRN =
1509 F->createReshape(name.str() + "_reshape", newT2->getResult(), odim);
1510
1511 output.replaceAllUsesOfWith(newRN->getResult());
1512 changed = true;
1513 }
1514
1515 return changed;
1516}
1517
1518/// \returns True if node A may depend on the result of B. The relationship
1519/// between the nodes does not have to be direct. For example, A can depend on
1520/// X which depends on B. In that case the method needs to return True.
1521/// Check the use-def dependency up to a depth of \p depth.
1522static bool mayDepend(Node *A, Node *B, unsigned depth = 6) {
1523 // We define the identify as a dependency.
1524 if (A == B) {
1525 return true;
1526 }
1527
1528 // A does not depend on anything.
1529 if (A->getNumInputs() == 0) {
1530 return false;
1531 }
1532
1533 // B has no users. Nothing can depend on it.
1534 if (B->getNumResults() == 0) {
1535 return false;
1536 }
1537
1538 // We can't continue the search. Assume that the nodes depend on one another.
1539 if (depth == 0) {
1540 return true;
1541 }
1542
1543 // Check all inputs of A. None of them may depend on B.
1544 for (int i = 0, e = A->getNumInputs(); i < e; i++) {
1545 auto *input = A->getNthInput(i).getNode();
1546 // The inputs of A must not depend on B.
1547 if (mayDepend(input, B, depth - 1)) {
1548 return true;
1549 }
1550 }
1551
1552 // We checked all inputs of A and none of them depend on B.
1553 return false;
1554}
1555
1556/// \returns True if the node \p N depends on any of the values in \p list, or
1557/// if any of the values in list depend on \p N.
1558static bool mayDependOnAny(llvm::ArrayRef<NodeValue> list, Node *N) {
1559 for (auto &ll : list) {
1560 if (mayDepend(ll.getNode(), N) || mayDepend(N, ll.getNode())) {
1561 return true;
1562 }
1563 }
1564
1565 return false;
1566}
1567
1568/// Helper function to merge matmuls in \p F.
1569/// If \p mergeOnLHS is true, it merges LHS operands of matmuls that have the
1570/// same RHS. If \p mergeOnLHS is false, it merges RHS operands of matmuls that
1571/// have the same LHS. \returns true if any matmuls are merged in the Function
1572/// \p F.
1573static bool mergeMatMuls(Function *F, bool mergeOnLHS) {
1574 bool changed = false;
1575 auto &nodes = F->getNodes();
1576
1577 // A map to record the list of matrix multipliers that use each node
1578 // value either as a right-hand-side user or a left-hand-user.
1579 llvm::DenseMap<Node *, std::vector<MatMulNode *>> matrixUsers;
1580
1581 // Collect the list of nodes that are used by the matrix multiplier.
1582 for (auto &node : nodes) {
1583 if (auto *MM = dyn_cast<MatMulNode>(&node)) {
1584 // Do not try to merge quantized matrix multiplications because their
1585 // quantized parameters may not match. Until we implement the logic to
1586 // match the scale and offset just avoid the optimization.
1587 if (MM->getResult().getType()->isQuantizedType()) {
1588 continue;
1589 }
1590
1591 if (!mergeOnLHS) {
1592 matrixUsers[MM->getLHS().getNode()].push_back(MM);
1593 } else {
1594 matrixUsers[MM->getRHS().getNode()].push_back(MM);
1595 }
1596 }
1597 }
1598
1599 // Merge matrices.
1600 for (auto &it : matrixUsers) {
1601 auto &MMs = it.second;
1602
1603 // Collects the LHS or RHS values to merge.
1604 std::vector<NodeValue> lhsOrRhs;
1605
1606 // For each matmul that depends on the matrix.
1607 std::unordered_set<MatMulNode *> skippedMMs;
1608 std::string firstMMName;
1609 std::string firstMatrixName;
1610 for (auto *MM : MMs) {
1611 auto I = mergeOnLHS ? MM->getLHS() : MM->getRHS();
1612 // The operands to the matrix multiplier should not depend on one another
1613 // or else we won't be able to get rid of the original matrix
1614 // multiplication.
1615 if (mayDependOnAny(lhsOrRhs, I.getNode())) {
1616 skippedMMs.insert(MM);
1617 continue;
1618 }
1619 lhsOrRhs.push_back(I);
1620 if (firstMMName.empty()) {
1621 firstMMName = MM->getName().str();
1622 }
1623 if (firstMatrixName.empty()) {
1624 firstMatrixName = I.getNode()->getName().str();
1625 }
1626 }
1627
1628 // We need to have at least two matrices to merge.
1629 if (lhsOrRhs.size() < 2) {
1630 continue;
1631 }
1632
1633 // Merge the matmul:
1634 auto *CC = F->createConcat(firstMatrixName + "_merge", lhsOrRhs,
1635 mergeOnLHS ? 0 : 1);
1636 auto *MM =
1637 F->createMatMul(firstMMName + "_bigMatMul", mergeOnLHS ? CC : it.first,
1638 mergeOnLHS ? it.first : CC);
1639
1640 // Slice the output so that other nodes can consume each slice separately.
1641 dim_t O = MM->getResult().dims()[mergeOnLHS ? 1 : 0];
1642 dim_t start = 0;
1643 for (auto *origMM : MMs) {
1644 if (skippedMMs.count(origMM)) {
1645 continue;
1646 }
1647 dim_t H = origMM->getResult().dims()[mergeOnLHS ? 0 : 1];
1648 auto startIndices = mergeOnLHS ? std::array<dim_t, 2>({start, 0})
1649 : std::array<dim_t, 2>({0, start});
1650 auto endIndices = mergeOnLHS ? std::array<dim_t, 2>({start + H, O})
1651 : std::array<dim_t, 2>({O, start + H});
1652 auto *ex = F->createSlice(origMM->getName().str() + "_extract", MM,
1653 startIndices, endIndices);
1654 start += H;
1655 origMM->getResult().replaceAllUsesOfWith(ex);
1656 changed = true;
1657 }
1658 }
1659 return changed;
1660}
1661
1662// Merge several two or more multiple matrix multiplications that share the same
1663// RHS into a single large matmul. The large matmul is more likely to utilize
1664// the hardware. The result of the big matmul is the concatenated results.
1665//
1666// ____ _________ _________
1667// ---- | | | | M| A * C |
1668// M| A | T| B | * K| C | = |---------|
1669// ---- , | | | | T| B * C |
1670// K ---- --------- ---------
1671// K R R
1672bool MergeMatMulOnLHS::run(Function *F, const CompilationContext &cctx) {
1673 LOG_SCOPE(F->getLogContext(), getName());
1674 return mergeMatMuls(F, /* mergeOnLHS */ true);
1675}
1676
1677// Merge several two or more multiple matrix multiplications that share the same
1678// LHS into a single large matmul. The large matmul is more likely to utilize
1679// the hardware. The result of the big matmul is the concatenated results.
1680//
1681// ____ _________ _________
1682// ---- | | | | | A | A |
1683// M| A | * K| B | , K| C | = M| * | * |
1684// ---- | | | | | B | C |
1685// K ---- --------- ---------
1686// R S R S
1687bool MergeMatMulOnRHS::run(Function *F, const CompilationContext &cctx) {
1688 LOG_SCOPE(F->getLogContext(), getName());
1689 return mergeMatMuls(F, /* mergeOnLHS */ false);
1690}
1691
1692bool MergePadIntoConvolution::run(Function *F, const CompilationContext &cctx) {
1693 LOG_SCOPE(F->getLogContext(), getName());
1694 bool changed = false;
1695 for (auto &node : F->getNodes()) {
1696 auto *CN = dyn_cast<ConvolutionNode>(&node);
1697 if (!CN) {
1698 continue;
1699 }
1700
1701 auto *PN = dyn_cast<PadNode>(CN->getInput());
1702 if (!PN) {
1703 continue;
1704 }
1705
1706 // Convolution only supports padding with 0 constant
1707 if ((PN->getMode() != PaddingMode::CONSTANT) || (PN->getValue() != 0.f)) {
1708 continue;
1709 }
1710
1711 // The convolution needs to be the unique user
1712 if (!PN->hasOneUse()) {
1713 continue;
1714 }
1715
1716 // Compute the new padding.
1717 // Note: - convolution only supports positive padding
1718 // - the convolution takes NHWC input tensors.
1719 bool canMerge = true;
1720 auto padPads = PN->getPads();
1721 auto convPads = CN->getPads();
1722
1723 // For now, there is a different interpretation of the ONNX spec for
1724 // Pad and Convolution. The 'pads' array won't have the same size because
1725 // only spatial dimensions are specified for the convolution while all
1726 // dimensions are specified for Pad.
1727
1728 // The merge can apply only if no padding is requested for non spatial
1729 // dimensions.
1730 if ((padPads[0] != 0) || (padPads[3] != 0) || (padPads[4] != 0) ||
1731 (padPads[7] != 0)) {
1732 continue;
1733 }
1734
1735 // Compute new spatial padding.
1736 const int H_INDEX = 1;
1737 std::vector<unsigned_t> newConvPads(4);
1738 auto numDims = PN->getResult().dims().size();
1739 for (size_t i = 0; i < 2; i++) {
1740 // Two pad integers per dimension (begin and end padding).
1741 for (size_t j = 0; j < 2; j++) {
1742 int newConvPadSigned =
1743 padPads[(i + H_INDEX) + j * numDims] + int(convPads[i + j * 2]);
1744 if (newConvPadSigned < 0) {
1745 canMerge = false;
1746 break;
1747 }
1748 newConvPads[i + j * 2] = unsigned_t(newConvPadSigned);
1749 }
1750 }
1751 if (!canMerge) {
1752 continue;
1753 }
1754
1755 // New Convolution
1756 auto *newCN = F->createConv(CN->getName(), PN->getInput(), CN->getFilter(),
1757 CN->getBias(), CN->getResult().getType(),
1758 CN->getKernels(), CN->getStrides(), newConvPads,
1759 CN->getGroup(), CN->getDilation());
1760 newCN->setFusedActivation(CN->getFusedActivation());
1761 newCN->setFusedActivationArgs(CN->getFusedActivationArgs());
1762
1763 CN->getResult().replaceAllUsesOfWith(newCN);
1764 changed = true;
1765 }
1766
1767 return changed;
1768}
1769
1770/// Merge Transpose into MatMul or FC.
1771/// MatMul/FC(Reshape(Transpose(X)), Weights) ->
1772/// -> MatMul/FC(Reshape(X), reordered Weights)
1773/// Common sequence while using NCHW as input layout, because GLOW convolution
1774/// layout is NHWC:
1775/// Transpose([N, H, W, C]) -> [N, C, H, W]
1776/// Reshape([N, C, H, W]) -> [N, C * H * W]
1777/// MatMul/FC([N, C * H * W], [C * H * W, K]) -> [N, K]
1778bool MergeTransposeIntoMatMulOrFC::run(Function *F,
1779 const CompilationContext &cctx) {
1780 LOG_SCOPE(F->getLogContext(), getName());
1781 bool changed = false;
1782 for (auto &node : F->getNodes()) {
1783 auto *MMN = dyn_cast<MatMulNode>(&node);
1784 auto *FCN = dyn_cast<FullyConnectedNode>(&node);
1785 Constant *W;
1786 ReshapeNode *RN;
1787
1788 // Node is either MatMul or FC.
1789 if (MMN) {
1790 W = dyn_cast<Constant>(MMN->getRHS());
1791 RN = dyn_cast<ReshapeNode>(MMN->getLHS());
1792 } else if (FCN) {
1793 W = dyn_cast<Constant>(FCN->getWeights());
1794 RN = dyn_cast<ReshapeNode>(FCN->getInput());
1795 } else {
1796 continue;
1797 }
1798
1799 // Weights node (or MatMul RHS) is constant.
1800 if (!W) {
1801 continue;
1802 }
1803
1804 // Linearizing Reshape precedes MatMul/FC.
1805 // The first dimension must be kept unchanged, the others are linearized.
1806 if (!RN || !RN->hasOneUse() ||
1807 RN->getInput().dims()[0] != RN->getDims()[0]) {
1808 continue;
1809 }
1810
1811 // Transpose precedes Reshape.
1812 // The first dimension must be kept unchanged, the others can be shuffled
1813 // in any way.
1814 auto *TN = dyn_cast<TransposeNode>(RN->getInput());
1815 if (!TN || !TN->hasOneUse() || TN->getShuffle()[0] != 0) {
1816 continue;
1817 }
1818
1819 // MatMul/FC weights tensor is 2D. De-linearize the first dimension
1820 // according to Transpose output layout (original shape) and input layout
1821 // (reordered shape). Then we can do weights reordering by simply
1822 // transposing the tensor from original shape to reordered shape.
1823 //
1824 // Example for [N, H, W, C] -> [N, C, H, W] transpose (common case):
1825 // De-linearized original shape: [C * H * W, K] -> [C, H, W, K]
1826 // De-linearized reordered shape: [C * H * W, K] -> [H, W, C, K]
1827 // Reorder weights by transposing them: [C, H, W, K] -> [H, W, C, K]
1828 ShapeVector shape, newShape;
1829 llvm::SmallVector<unsigned_t, max_tensor_dimensions> shuffle;
1830 shuffle.resize(TN->getShuffle().size() - 1);
1831 for (size_t i = 1; i < TN->getShuffle().size(); i++) {
1832 shape.push_back(TN->getResult().getType()->dims()[i]);
1833 newShape.push_back(TN->getInput().getType()->dims()[i]);
1834 shuffle[TN->getShuffle()[i] - 1] = i - 1;
1835 }
1836 shape.push_back(W->dims()[1]);
1837 newShape.push_back(W->dims()[1]);
1838 shuffle.push_back(TN->getShuffle().size() - 1);
1839 auto reshapedWTy =
1840 F->getParent()->uniqueTypeWithNewShape(W->getType(), shape);
1841 auto reshapedNewWTy =
1842 F->getParent()->uniqueTypeWithNewShape(W->getType(), newShape);
1843
1844 // New reordered weights.
1845 auto *newW = F->getParent()->createConstant(W->getType(), W->getName(),
1846 W->getLayout());
1847 Tensor reshapedSrc(W->getPayload().getUnsafePtr(), reshapedWTy);
1848 Tensor reshapedDst(newW->getPayload().getUnsafePtr(), reshapedNewWTy);
1849 reshapedSrc.transpose(&reshapedDst, shuffle);
1850
1851 // New Reshape and MatMul/FC.
1852 auto *newRN =
1853 F->createReshape(RN->getName(), TN->getInput(), RN->getDims());
1854 if (MMN) {
1855 auto *newMMN = F->createMatMul(MMN->getName(), MMN->getResult().getType(),
1856 newRN, newW);
1857 MMN->getResult().replaceAllUsesOfWith(newMMN);
1858 } else if (FCN) {
1859 auto *newFCN =
1860 F->createFullyConnected(FCN->getName(), newRN, newW, FCN->getBias(),
1861 FCN->getResult().getType());
1862 FCN->getResult().replaceAllUsesOfWith(newFCN);
1863 } else {
1864 llvm_unreachable("Unexpected node kind");
1865 }
1866
1867 changed = true;
1868 }
1869
1870 return changed;
1871}
1872
1873/// \returns True if the two slices \p A and \p B access consecutive spacial
1874/// regions on the \p dim dimension. For example Slice(0..10) Slice(10..50)
1875/// are consecutive but Slice(0..10) Slice(20..30) are not.
1876static bool areSlicesConsecutive(SliceNode *A, SliceNode *B, unsigned_t dim) {
1877 // The slices must extract from the same input.
1878 if (A->getInput() != B->getInput()) {
1879 return false;
1880 }
1881
1882 auto aStart = A->getStart();
1883 auto bStart = B->getStart();
1884
1885 assert(aStart.size() > dim && "Invalid dimension");
1886
1887 for (size_t i = 0, e = aStart.size(); i < e; i++) {
1888 if (i == dim) {
1889 auto resSize = A->getResult().dims();
1890 // This is the stride (the delta between the two slices on the requested
1891 // dimension).
1892 auto delta = bStart[i] - aStart[i];
1893 // The distance between the two slices must be identical to the size of
1894 // the result.
1895 if (resSize[dim] != delta) {
1896 return false;
1897 }
1898
1899 continue;
1900 }
1901
1902 // The non-consecutive dimensions must be identical.
1903 if (aStart[i] != bStart[i]) {
1904 return false;
1905 }
1906 }
1907
1908 return true;
1909}
1910
1911/// \returns True if the two slices \p A and \p B access consecutive spacial
1912/// regions along some dimension. The dimension is stored in \p dim.
1913/// For example, Slice((0, 0)..(1, 10)) Slice((1, 0)..(2, 10)) are consecutive
1914/// along dim=0.
1915static bool findConsecutiveSliceDim(SliceNode *A, SliceNode *B, int *dim) {
1916 // The slices must extract from the same input.
1917 if (A->getInput() != B->getInput()) {
1918 return false;
1919 }
1920
1921 for (size_t i = 0, e = A->getStart().size(); i < e; i++) {
1922 if (areSlicesConsecutive(A, B, i)) {
1923 *dim = i;
1924 return true;
1925 }
1926 }
1927
1928 return false;
1929}
1930
1931bool ConvertBroadcastedBatchMatMul::run(Function *F,
1932 const CompilationContext &cctx) {
1933 LOG_SCOPE(F->getLogContext(), getName());
1934 bool changed = false;
1935 for (auto &node : F->getNodes()) {
1936 BatchMatMulNode *BMMN = dyn_cast<BatchMatMulNode>(&node);
1937 if (!BMMN) {
1938 continue;
1939 }
1940
1941 NodeValue LHS = BMMN->getLHS();
1942 NodeValue RHS = BMMN->getRHS();
1943
1944 // If RHS is a Tile/Broadcast along axis 0 and the input's dims()[0] == 1,
1945 // then the RHS is fully broadcasted and we can perform the optimization.
1946 TileNode *TN = dyn_cast<TileNode>(RHS);
1947 BroadcastNode *BN = dyn_cast<BroadcastNode>(RHS);
1948 if (!TN && !BN) {
1949 continue;
1950 }
1951 const unsigned_t axis = TN ? TN->getAxis() : BN->getAxis();
1952 const dim_t dim0 = TN ? TN->getInput().dims()[0] : BN->getInput().dims()[0];
1953 if (axis != 0 || dim0 != 1) {
1954 continue;
1955 }
1956
1957 // If this is a Broadcast, check if the first dimension is the only one
1958 // that's tiled. If so, then we can treat this as the same as a
1959 // Tile. Otherwise we must keep around a Broadcast for everything but the
1960 // first dimension.
1961 NodeValue singleTileNV;
1962 if (BN) {
1963 ShapeVector newBNDims(BN->getResult().dims().begin(),
1964 BN->getResult().dims().end());
1965 newBNDims[0] = 1;
1966 if (!BN->getInput().dims().equals(newBNDims)) {
1967 BroadcastNode *newBN = F->createBroadcast(BN->getName(), BN->getInput(),
1968 newBNDims, /* axis */ 0);
1969 singleTileNV = newBN->getResult();
1970 } else {
1971 // This Broadcast is equivalent to a Tile.
1972 singleTileNV = BN->getInput();
1973 }
1974 } else {
1975 singleTileNV = TN->getInput();
1976 }
1977
1978 // Can now convert the broadcasted BatchMatMul to a MatMul.
1979 // LHS = {numBatches, N, M}
1980 // RHS = {M, P}
1981 // Multiply each LHS matrix {N, M} by RHS {M, P} to get final matrix
1982 // {numBatches, N, P}
1983 const dim_t numBatches = LHS.dims()[0];
1984 const dim_t N = LHS.dims()[1];
1985 const dim_t M = LHS.dims()[2];
1986 const dim_t P = RHS.dims()[2];
1987 auto name = BMMN->getName();
1988
1989 // Reshape the LHS to be a two-dimensional matrix, where each batch is
1990 // essentially concatenated onto itself in the 0th dimension.
1991 ReshapeNode *reshapeLHS =
1992 F->createReshape(name.str() + ".reshapeLHS", LHS, {numBatches * N, M});
1993 // Squeeze out the first dimension of the original Tile's input.
1994 ReshapeNode *squeezedRHS =
1995 F->createSqueeze(name.str() + ".squeezedRHS", singleTileNV, {0});
1996
1997 // Perform a normal matmul, implementing the batch matmul.
1998 MatMulNode *MMN = F->createMatMul(name, reshapeLHS, squeezedRHS);
1999
2000 assert(MMN->getResult().dims()[0] == (numBatches * N) &&
2001 "Incorrect resulting dimension for batch matmul");
2002 assert(MMN->getResult().dims()[1] == P &&
2003 "Incorrect resulting dimension for batch matmul");
2004
2005 // Reshape the result back to the expected batch output shape, with the
2006 // first dimension the number of batches.
2007 ReshapeNode *finalReshape = F->createReshape(name.str() + ".reshapeResult",
2008 MMN, {numBatches, N, P});
2009 BMMN->getResult().replaceAllUsesOfWith(finalReshape);
2010 changed = true;
2011 }
2012 return changed;
2013}
2014
2015/// Find a sequence of slices in \p input that span the whole input.
2016/// \returns True if a group of slices that span the whole input was found.
2017/// The order of the slices is recorded in \p order.
2018static bool findSlicesThatSpanInput(llvm::ArrayRef<SliceNode *> input,
2019 unsigned_t dimension,
2020 std::vector<SliceNode *> &order) {
2021 // This is the 'last' slice to be found in the sequence of slices.
2022 SliceNode *lastSlice = nullptr;
2023
2024 // Find the 'first' slice in the sequence.
2025 for (SliceNode *SN : input) {
2026 auto start = SN->getStart();
2027
2028 // Invalid dimension.
2029 if (start.size() <= dimension) {
2030 return false;
2031 }
2032
2033 // Check if this slice extract the first element.
2034 if (start[dimension] == 0) {
2035 // We found the first element.
2036 lastSlice = SN;
2037 order.push_back(lastSlice);
2038 break;
2039 }
2040 }
2041
2042 // We could not find a 'first' slice.
2043 if (!lastSlice) {
2044 return false;
2045 }
2046
2047 // Now that we've found the first slice in the sequence, try to order the
2048 // rest of the slices after the first one.
2049 bool addedSlice = true;
2050 while (addedSlice) {
2051 addedSlice = false;
2052
2053 // For each slice:
2054 for (SliceNode *SN : input) {
2055 // Ignore slices of invalid types. Ignore shapes for now, that's checked
2056 // next while ignoring the axis dimension.
2057 if (!lastSlice->getResult().getType()->isEqual(
2058 *SN->getResult().getType(),
2059 /* allowDifferentShape */ true)) {
2060 continue;
2061 }
2062
2063 // Check if shapes match except for the axis dimension.
2064 bool skip = false;
2065 for (size_t i = 0, e = lastSlice->getResult().dims().size(); i < e; ++i) {
2066 if (i != dimension &&
2067 lastSlice->getResult().dims()[i] != SN->getResult().dims()[i]) {
2068 skip = true;
2069 break;
2070 }
2071 }
2072 if (skip) {
2073 continue;
2074 }
2075
2076 // Check if SN comes after the last slice in the sequence.
2077 if (areSlicesConsecutive(lastSlice, SN, dimension)) {
2078 // Add the consecutive slice and schedule another iteration.
2079 lastSlice = SN;
2080 order.push_back(lastSlice);
2081 addedSlice = true;
2082 continue;
2083 }
2084 }
2085 } // While adding new slices.
2086
2087 // Check that the last slice completes the tensor.
2088 auto startCoor = lastSlice->getStart();
2089 auto resDim = lastSlice->getResult().getType()->dims();
2090 auto inDim = lastSlice->getInput().getType()->dims();
2091
2092 // Check if for all dimensions, the size of the result tensor plus the start
2093 // coordinate matches the size of the tensor.
2094 for (int i = 0, e = startCoor.size(); i < e; i++) {
2095 if (startCoor[i] + resDim[i] != inDim[i]) {
2096 return false;
2097 }
2098 }
2099
2100 // Report success if we found at least two slices that extract from the
2101 // input.
2102 return order.size() > 1;
2103}
2104
2105/// Merge multiple batched add nodes into a large batched-add node.
2106bool MergeBatchedAdd::run(Function *F, const CompilationContext &cctx) {
2107 LOG_SCOPE(F->getLogContext(), getName());
2108 bool changed = false;
2109 auto &nodes = F->getNodes();
2110
2111 // We index the batched add nodes by the slice operand.
2112 llvm::DenseMap<Node *, std::vector<BatchedAddNode *>> rightBAUsers;
2113
2114 // Collect all of the batched add nodes and index them by the 'slice'
2115 // operand.
2116 for (auto &node : nodes) {
2117 if (auto *BA = dyn_cast<BatchedAddNode>(&node)) {
2118 rightBAUsers[BA->getSlice().getNode()].push_back(BA);
2119 }
2120 }
2121
2122 // For each 'slice' that batched add nodes access:
2123 for (auto &it : rightBAUsers) {
2124 auto &BAs = it.second;
2125
2126 // Collects the left-hand-side operands that the batched-adds add into. We
2127 // only collect 'slice' nodes.
2128 std::vector<SliceNode *> slices;
2129
2130 for (auto *BA : BAs) {
2131 if (auto *S = dyn_cast<SliceNode>(BA->getBatch().getNode())) {
2132 slices.push_back(S);
2133 }
2134 }
2135
2136 // Check if the slice nodes that we've collected cover a whole tensor.
2137 std::vector<SliceNode *> order;
2138 bool found = findSlicesThatSpanInput(slices, 0, order);
2139
2140 if (!found) {
2141 continue;
2142 }
2143
2144 // We found a sequence of batched-add-slice that cover the input tensor.
2145 // We can transform the graph and create one big batched-add.
2146 std::vector<Node *> newSlices;
2147 assert(order.size() > 1 && "order must contain at least 2 SliceNodes.");
2148 SliceNode *S = llvm::cast<SliceNode>(order[0]);
2149 auto *mergedBA = F->createBatchedAdd("mergedBA", S->getInput(), it.first);
2150
2151 // Create the new slices. These slices will replace the original scalar
2152 // batched-add nodes.
2153 for (auto *orig : order) {
2154 newSlices.push_back(F->createSlice(orig->getName(), mergedBA,
2155 orig->getStart(),
2156 orig->getResult().getType()));
2157 }
2158
2159 // Replace the original individual batched adds with corresponding slices
2160 // from the new merged batch add.
2161 for (auto *BA : BAs) {
2162 for (int i = 0, e = order.size(); i < e; i++) {
2163 if (BA->getBatch().getNode() == order[i]) {
2164 BA->getResult().replaceAllUsesOfWith(newSlices[i]);
2165 changed = true;
2166 break;
2167 }
2168 }
2169 }
2170
2171 } // for each batched-add group.
2172 return changed;
2173}
2174
2175/// Optimize ReduceMean configuration with AvgPool if possible: last two axes
2176/// in a 4D input must be reduced.
2177bool OptimizeReduceMean::run(Function *F, const CompilationContext &cctx) {
2178 LOG_SCOPE(F->getLogContext(), getName());
2179 bool changed = false;
2180 auto &nodes = F->getNodes();
2181
2182 // For each node:
2183 for (auto &node : nodes) {
2184 if (auto *RM = dyn_cast<BatchedReduceMeanNode>(&node)) {
2185
2186 // Input shape must be 4D.
2187 if (RM->getBatch().dims().size() != 4) {
2188 continue;
2189 }
2190
2191 // Last two axes must be reduced.
2192 auto axes = RM->getAxes();
2193 if (axes.size() != 2 || std::count(axes.begin(), axes.end(), 2) != 1 ||
2194 std::count(axes.begin(), axes.end(), 3) != 1) {
2195 continue;
2196 }
2197
2198 // RM is already shaped to have the required output shape.
2199 NodeValue in = RM->getBatch();
2200
2201 std::vector<unsigned_t> kernels = {static_cast<unsigned_t>(in.dims()[2]),
2202 static_cast<unsigned_t>(in.dims()[3])};
2203 std::vector<unsigned_t> strides = {1, 1};
2204 std::vector<unsigned_t> pads = {0, 0, 0, 0};
2205
2206 // TODO: Fix bad assumption? See issue 3499, for now workaround it.
2207 // In Glow, AvgPool expects NHWC.
2208 auto *TR1 = F->createTranspose(
2209 RM->getName().str() + ".transposeNCHW2NHWC", in, NCHW2NHWC, "NHWC");
2210 auto *AP = F->createAvgPool(RM->getName().str() + ".avgPool", TR1,
2211 kernels, strides, pads);
2212 if (AP->getResult().getType()->isQuantizedType()) {
2213 auto TypeAP = F->getParent()->uniqueTypeWithNewQuantParams(
2214 AP->getResult().getType(), RM->getResult().getType());
2215 AP->getResult().setType(TypeAP);
2216 }
2217 auto *TR2 = F->createTranspose(
2218 RM->getName().str() + ".transposeNHWC2NCHW", AP, NHWC2NCHW, "NCHW");
2219
2220 // AvgPool keeps original shape. Add reshape to match expected output.
2221 std::vector<dim_t> shape = TR2->getResult().dims();
2222
2223 ShapeVector shapeAxes(axes.begin(), axes.end());
2224
2225 // Axes must be sorted for correct erase.
2226 std::sort(shapeAxes.rbegin(), shapeAxes.rend());
2227 for (const auto &axis : shapeAxes) {
2228 shape.erase(shape.begin() + axis);
2229 }
2230
2231 auto *RN = F->createReshape(RM->getName().str() + ".reshape", TR2, shape);
2232
2233 RM->getResult().replaceAllUsesOfWith(RN);
2234 changed = true;
2235 continue;
2236 }
2237 } // For all nodes in the graph.
2238
2239 return changed;
2240}
2241
2242/// \returns a uniquely used Constant with the same contents as \p node. If \p
2243/// node is not a Constant then \returns a nullptr. If \node is a Constant which
2244/// has a single use, \p node is returned. If \node is a Constant which has
2245/// multiple uses, then \returns a new duplicate Constant that has the same
2246/// contents as \p node contained in \p M.
2247static Constant *getUniquelyUsedConstant(Module *M, Node &node) {
2248 Constant *constant = dyn_cast<Constant>(&node);
2249 if (!constant) {
2250 return nullptr;
2251 }
2252
2253 if (constant->hasOneUse() && areAllFunctionsLoaded(M)) {
2254 return constant;
2255 }
2256
2257 // If constant has more than one use, duplicate it and return the duplicate.
2258 auto *NC = M->createConstant(constant->getType(), constant->getName(),
2259 constant->getLayout());
2260 NC->getPayloadMutable().assign(&constant->getPayload());
2261 return NC;
2262}
2263
2264/// Normalize the weight of \p CV with what \p BN is doing, given containing
2265/// Module \p M. \returns whether or not the normalization was possible.
2266template <typename ElemTy>
2267bool normalizeWeights(Module *M, ConvolutionNode &CV,
2268 BatchNormalizationNode &BN) {
2269 static_assert(
2270 std::is_floating_point<ElemTy>::value ||
2271 std::is_same<float16_t,
2272 typename std::remove_cv<ElemTy>::type>::value ||
2273 std::is_same<bfloat16_t,
2274 typename std::remove_cv<ElemTy>::type>::value,
2275 "This implementation is for floating-point values only");
2276
2277 Constant *filterC = getUniquelyUsedConstant(M, *CV.getFilter().getNode());
2278 Constant *cbiasC = getUniquelyUsedConstant(M, *CV.getBias().getNode());
2279
2280 if (!filterC || !cbiasC) {
2281 return false;
2282 }
2283
2284 // Perform normalization when Convolution layout is NHWC and BatchNorm
2285 // ChannelIdx points to C.
2286 if (BN.getChannelIdx() != 3) {
2287 return false;
2288 }
2289
2290 // Set the new filter and bias on CV if necessary.
2291 if (filterC != CV.getFilter().getNode()) {
2292 CV.getParent()->getLogContext()->logNodeInputChange(
2293 CV, CV.getNthInput(ConvolutionNode::FilterIdx), filterC);
2294 CV.setNthInput(ConvolutionNode::FilterIdx, filterC);
2295 }
2296 if (cbiasC != CV.getBias().getNode()) {
2297 CV.getParent()->getLogContext()->logNodeInputChange(
2298 CV, CV.getNthInput(ConvolutionNode::BiasIdx), cbiasC);
2299 CV.setNthInput(ConvolutionNode::BiasIdx, cbiasC);
2300 }
2301
2302 // First, BN computation can be phrased as follows:
2303 //
2304 // (X - mean) * (1.0 / sqrt(var + eps)) * bn_scale + bias
2305 //
2306 // Thus, we can rewrite bn_scale as:
2307 // X * bn_scale * 1.0 / (sqrt(var + eps)) +
2308 // (bias - mean * (1.0 / sqrt(var + eps)) * bn_scale)
2309 //
2310 // Thus, can just have the affine transform:
2311 //
2312 // X * A + B
2313 //
2314 // where
2315 //
2316 // A = bn_scale * 1.0 / (sqrt(running_var + eps))
2317 // B = (bias - mean * (1.0 / sqrt(var + eps)) * bn_scale)
2318 //
2319 // Now, we have that the computation made is the following:
2320 //
2321 // ((X `conv` W) + b) * A + B
2322 //
2323 // Then, we can simply fuse this as follows:
2324 //
2325 // (X `conv` (W * A)) + b * A + B
2326 //
2327 // which is simply
2328 //
2329 // (X `conv` Q) + C
2330 //
2331 // where
2332 //
2333 // Q = W * A
2334 // C = b * A + B
2335
2336 Constant *scaleC = cast<Constant>(BN.getScale());
2337 Constant *biasC = cast<Constant>(BN.getBias());
2338 Constant *meanC = cast<Constant>(BN.getMean());
2339 Constant *var = cast<Constant>(BN.getVar());
2340
2341 auto filterH = filterC->getHandle<ElemTy>();
2342
2343 auto cbiasH = cbiasC->getHandle<ElemTy>();
2344
2345 auto scaleH = scaleC->getHandle<ElemTy>();
2346 auto biasH = biasC->getHandle<ElemTy>();
2347 auto meanH = meanC->getHandle<ElemTy>();
2348 auto varH = var->getHandle<ElemTy>();
2349
2350 // Update the filter/bias constants of the Conv node.
2351 auto epsilon = BN.getEpsilon();
2352 for (size_t i = 0, e = filterH.size(); i < e; i++) {
2353 // Dimension zero is the 'channel' dimension. If we ever change the
2354 // layout of the filter then we need to change this optimization.
2355 dim_t channelId = filterH.getDimForPtr(0, i);
2356 float value = varH.at({channelId});
2357 float stdvar = 1.0f / std::sqrt(value + epsilon);
2358 float gamma = scaleH.at({channelId});
2359 float A = gamma * stdvar;
2360 filterH.raw(i) = ElemTy(float(filterH.raw(i)) * A);
2361 }
2362
2363 for (size_t i = 0, e = cbiasH.size(); i < e; i++) {
2364 // Dimension zero is the 'channel' dimension. If we ever change the
2365 // layout of the filter then we need to change this optimization.
2366 dim_t channelId = cbiasH.getDimForPtr(0, i);
2367 float mu = meanH.at({channelId});
2368 float value = varH.at({channelId});
2369 float stdvar = 1.0f / std::sqrt(value + epsilon);
2370 float gamma = scaleH.at({channelId});
2371 float beta = biasH.at({channelId});
2372 float A = gamma * stdvar;
2373 float B = beta - mu * A;
2374 cbiasH.raw(i) = ElemTy(float(cbiasH.raw(i)) * A + B);
2375 }
2376 return true;
2377}
2378
2379/// Gets Constant or returns nullptr if input is not Constant.
2380/// Skips QuantizeNode if present.
2381static Constant *getConstant(const NodeValue &NV) {
2382 Node *N = NV.getNode();
2383 if (isa<QuantizeNode>(N)) {
2384 N = N->getNthInput(QuantizeNode::InputIdx);
2385 }
2386 return dyn_cast<Constant>(N);
2387}
2388
2389bool OptimizeBatchNorm::run(Function *F, const CompilationContext &cctx) {
2390 LOG_SCOPE(F->getLogContext(), getName());
2391 bool changed = false;
2392 auto &nodes = F->getNodes();
2393 auto *M = F->getParent();
2394
2395 // For each node:
2396 for (auto &node : nodes) {
2397 auto *BN = dyn_cast<BatchNormalizationNode>(&node);
2398 if (!BN) {
2399 continue;
2400 }
2401
2402 // Remove BN if mean,var,eps,scale,beta values make it redundant as per
2403 // expression (X - mean) * (1.0 / sqrt(var + eps)) * scale + bias.
2404 float scale, bias, mean, var;
2405 auto *scaleC = getConstant(BN->getScale());
2406 auto *biasC = getConstant(BN->getBias());
2407 auto *meanC = getConstant(BN->getMean());
2408 auto *varC = getConstant(BN->getVar());
2409 if (scaleC && biasC && meanC && varC &&
2410 isUniformConstant<float>(*scaleC, scale) &&
2411 isUniformConstant<float>(*biasC, bias) &&
2412 isUniformConstant<float>(*meanC, mean) &&
2413 isUniformConstant<float>(*varC, var)) {
2414 float eps = BN->getEpsilon();
2415 // Relaxed redundancy check based on reduced BN expression so that A
2416 // is 1.0 and B is 0.0 in Y = A*X + B where,
2417 // A = scale * (1.0 / (sqrt(var + eps))
2418 // B = (bias - mean * (1.0 / sqrt(var + eps)) * scale)
2419 if (bias == mean && (std::sqrt(var + eps) == scale)) {
2420 BN->getResult().replaceAllUsesOfWith(BN->getInput());
2421 changed = true;
2422 continue;
2423 }
2424 }
2425
2426 // Merge the Batch Normalization operation into the convolution that comes
2427 // before it by updating the weights of the filter and bias.
2428 auto *CV = dyn_cast<ConvolutionNode>(BN->getInput());
2429 if (!CV) {
2430 continue;
2431 }
2432
2433 // We can't modify conv operators that have multiple users.
2434 if (!CV->hasOneUse()) {
2435 continue;
2436 }
2437
2438 bool normalizationHappened = false;
2439 switch (CV->getElementType(ConvolutionNode::ResultIdx)) {
2440 case ElemKind::FloatTy:
2441 normalizationHappened = normalizeWeights<float>(M, *CV, *BN);
2442 break;
2443 case ElemKind::Float16Ty:
2444 normalizationHappened = normalizeWeights<float16_t>(M, *CV, *BN);
2445 break;
2446 case ElemKind::BFloat16Ty:
2447 normalizationHappened = normalizeWeights<bfloat16_t>(M, *CV, *BN);
2448 break;
2449 default:
2450 llvm_unreachable("Type not supported");
2451 }
2452
2453 if (!normalizationHappened) {
2454 continue;
2455 }
2456
2457 // Take the predicate of what was expected for the output.
2458 CV->setPredicate(BN->getPredicate());
2459 BN->getResult().replaceAllUsesOfWith(CV);
2460 changed = true;
2461 continue;
2462 } // For all nodes in the graph.
2463 return changed;
2464}
2465
2466/// If \p node has uses, and all of them have one user node, return this user
2467/// node. Otherwise, return nullptr.
2468static Node *getOnlyUser(Node &node) {
2469 if (!node.hasUsers()) {
2470 // No users.
2471 return nullptr;
2472 }
2473 Node *first = node.getUsers().front().getUser();
2474 for (auto &U : node.getUsers()) {
2475 if (U.getUser() != first) {
2476 // Multiple users.
2477 return nullptr;
2478 }
2479 }
2480 // One user.
2481 return first;
2482}
2483
2484/// Checks that \p tile sub-tensors along \p chId axis repeat.
2485/// It also extracts the repeated dimension values into \p result.
2486static bool isConstBroadcasted(std::vector<float> &result, const Constant &tile,
2487 int32_t chId) {
2488 // TODO: This limitation can be lifted, but that is for simplicity.
2489 if (tile.getType()->dims().size() != 4) {
2490 return false;
2491 }
2492 // TODO: We can also support quantized constants if there is a need in it.
2493 if (tile.getType()->getElementType() != ElemKind::FloatTy) {
2494 return false;
2495 }
2496 auto handle = tile.getPayload().getHandle<float>();
2497 glow::dim_t n, h, w, c;
2498 for (c = 0; c < tile.getType()->dims()[chId]; c++) {
2499 std::vector<glow::dim_t> dims = {c, 0, 0, 0};
2500 std::rotate(dims.begin(), dims.begin() + dims.size() - chId, dims.end());
2501 const float expected = handle.at(llvm::ArrayRef<glow::dim_t>(dims));
2502 for (n = 0; n < tile.getType()->dims()[(chId + 1) % 4]; n++) {
2503 for (h = 0; h < tile.getType()->dims()[(chId + 2) % 4]; h++) {
2504 for (w = 0; w < tile.getType()->dims()[(chId + 3) % 4]; w++) {
2505 std::vector<glow::dim_t> dimsE = {c, n, h, w};
2506 std::rotate(dimsE.begin(), dimsE.begin() + dimsE.size() - chId,
2507 dimsE.end());
2508 if (handle.at(llvm::ArrayRef<glow::dim_t>(dims)) != expected) {
2509 return false;
2510 }
2511 }
2512 }
2513 }
2514 result[c] = expected;
2515 }
2516 return true;
2517}
2518
2519/// Collects the longest chain of arithmetic operations with constants starting
2520/// from \p start. Updates \p scale and \p bias as it collects the operands.
2521/// \p returns the last node in the chain.
2522static NodeValue collectArithmeticChain(Function *F, NodeValue start,
2523 Constant &scale, Constant &bias,
2524 int32_t chIdx) {
2525
2526 Node *user = getOnlyUser(*start.getNode());
2527 NodeValue chainEnd = start;
2528
2529 auto isSupportedForMerge = [](const Node *n) {
2530 return isa<MulNode>(n) || isa<AddNode>(n) || isa<SubNode>(n) ||
2531 isa<DivNode>(n);
2532 };
2533
2534 while (user && isSupportedForMerge(user)) {
2535 // Paranoid
2536 assert(user->isArithmetic() && "Not an arithmetic node!");
2537
2538 auto lhs = user->getNthInput(ArithmeticNode::LHSIdx);
2539 auto rhs = user->getNthInput(ArithmeticNode::RHSIdx);
2540
2541 // Paranoid.
2542 assert(((lhs == chainEnd) || (rhs == chainEnd)) && "Not a user?");
2543
2544 auto out = user->getNthResult(ArithmeticNode::ResultIdx);
2545
2546 // Quantized arithmetic operations may change scale of result, we don't want
2547 // to deal with that. May be supported later if needed.
2548 if (lhs.getType() != out.getType() || rhs.getType() != out.getType()) {
2549 break;
2550 }
2551
2552 // Only take this one if its other argument is a constant.
2553 // TODO: We can also support Splat here if needed.
2554 auto *c = dyn_cast<Constant>(lhs == chainEnd ? rhs : lhs);
2555 if (!c) {
2556 break;
2557 }
2558
2559 const dim_t numChannels = c->dims()[chIdx];
2560
2561 std::vector<float> toMerge(c->dims()[chIdx]);
2562 if (!isConstBroadcasted(toMerge, *c, chIdx)) {
2563 break;
2564 }
2565
2566 auto biasH = bias.getPayloadMutable().getHandle();
2567 auto scaleH = scale.getPayloadMutable().getHandle();
2568
2569 for (dim_t i = 0; i < numChannels; i++) {
2570 if (isa<DivNode>(user)) {
2571 scaleH.raw(i) /= toMerge[i];
2572 biasH.raw(i) /= toMerge[i];
2573 } else if (isa<MulNode>(user)) {
2574 scaleH.raw(i) *= toMerge[i];
2575 biasH.raw(i) *= toMerge[i];
2576 } else if (isa<SubNode>(user)) {
2577 if (chainEnd == rhs) {
2578 scaleH.raw(i) *= -1;
2579 biasH.raw(i) = toMerge[i] - biasH.raw(i);
2580 } else {
2581 biasH.raw(i) -= toMerge[i];
2582 }
2583 } else if (isa<AddNode>(user)) {
2584 biasH.raw(i) += toMerge[i];
2585 } else {
2586 llvm_unreachable("Unsupported type!");
2587 }
2588 }
2589 chainEnd = user->getNthResult(ArithmeticNode::ResultIdx);
2590 user = getOnlyUser(*user);
2591 }
2592 return chainEnd;
2593}
2594
2595/// Find the longest chain of Mul/Sub/Add/Div under a Convolution node that
2596/// operate on Constant and fold them into a new BatchNormalization node.
2597bool FoldArithmeticChainUnderConvIntoBN::run(Function *F,
2598 const CompilationContext &cctx) {
2599 LOG_SCOPE(F->getLogContext(), getName());
2600 bool changed = false;
2601
2602 for (auto &node : F->getNodes()) {
2603 auto *CN = dyn_cast<ConvolutionNode>(&node);
2604 if (!CN) {
2605 continue;
2606 }
2607 auto bias = CN->getBias();
2608 // Conv is in NHWC format - channel is dim 3.
2609 int32_t chIdx = 3;
2610
2611 // TODO: Support quantized constants if needed.
2612 if (bias.getType()->getElementType() != ElemKind::FloatTy) {
2613 continue;
2614 }
2615
2616 // Provide collectArithmeticChain w/ bias/scale that have identity values
2617 // as we are creating new BN consisted of the arithmetic nodes that the
2618 // function will find.
2619 auto *newScale = F->getParent()->createConstant(
2620 bias.getType(), CN->getName().str() + "_BN.scale");
2621 auto *newBias = F->getParent()->createConstant(
2622 bias.getType(), CN->getName().str() + "_BN.bias");
2623
2624 newScale->getPayloadMutable().getHandle<float>().clear(1.f);
2625 newBias->getPayloadMutable().getHandle<float>().clear(0.f);
2626
2627 // Collect the chain and compute the new scale and bias.
2628 NodeValue chainEnd =
2629 collectArithmeticChain(F, CN->getResult(), *newScale, *newBias, chIdx);
2630 if (chainEnd == CN->getResult()) {
2631 F->getParent()->eraseConstant(newScale);
2632 F->getParent()->eraseConstant(newBias);
2633 continue;
2634 }
2635
2636 // Compute the shape of batch normalization constants (array of
2637 // {depth} elements).
2638 glow::dim_t size = newScale->getPayloadMutable().getHandle<float>().size();
2639 auto depthTy =
2640 F->getParent()->uniqueTypeWithNewShape(bias.getType(), {size});
2641
2642 Tensor varianceT(depthTy);
2643 varianceT.init(glow::Tensor::InitKind::Broadcast, 1.0f, F->getPRNG());
2644 auto variance = F->getParent()->createConstant(
2645 CN->getName().str() + "_BN.var", varianceT);
2646
2647 Tensor meanT(depthTy);
2648 meanT.zero();
2649 auto mean =
2650 F->getParent()->createConstant(CN->getName().str() + "_BN.mean", meanT);
2651
2652 // Create a BN with new parameters.
2653 auto *nBN =
2654 F->createBatchNormalization("BatchNorm", chainEnd.getType(), &node,
2655 newBias, newScale, mean, variance, 3, 0, 0);
2656 chainEnd.replaceAllUsesOfWith(nBN);
2657 changed = true;
2658 }
2659 return changed;
2660}
2661
2662/// For each BatchNormalization node in \p F, find the longest chain of
2663/// Mul/Sub/Add/Div operations with constants that use it and merge all those
2664/// operations into the BatchNormalization.
2665bool FoldBatchNormalizationWithArithmeticChain::run(
2666 Function *F, const CompilationContext &cctx) {
2667 LOG_SCOPE(F->getLogContext(), getName());
2668 bool changed = false;
2669 for (auto &node : F->getNodes()) {
2670 auto *BN = dyn_cast<BatchNormalizationNode>(&node);
2671 if (!BN) {
2672 continue;
2673 }
2674
2675 // Expecting constant as const folding took place already.
2676 auto *scaleC = dyn_cast<Constant>(BN->getScale());
2677 auto *biasC = dyn_cast<Constant>(BN->getBias());
2678 if (!scaleC || !biasC) {
2679 continue;
2680 }
2681
2682 // TODO: Support quantized constants if needed.
2683 if (scaleC->getType()->getElementType() != ElemKind::FloatTy ||
2684 biasC->getType()->getElementType() != ElemKind::FloatTy) {
2685 continue;
2686 }
2687
2688 auto *newScaleC =
2689 F->getParent()->createConstant(scaleC->getType(), scaleC->getName());
2690 Tensor scaleT = scaleC->getPayload().getUnowned();
2691 newScaleC->assign(&scaleT);
2692
2693 auto *newBiasC =
2694 F->getParent()->createConstant(biasC->getType(), biasC->getName());
2695 Tensor biasT = biasC->getPayload().getUnowned();
2696 newBiasC->assign(&biasT);
2697
2698 // Collect the chain and compute the new scale and bias.
2699 NodeValue chainEnd = collectArithmeticChain(F, BN->getResult(), *newScaleC,
2700 *newBiasC, BN->getChannelIdx());
2701 if (chainEnd == BN->getResult()) {
2702 F->getParent()->eraseConstant(newScaleC);
2703 F->getParent()->eraseConstant(newBiasC);
2704 continue;
2705 }
2706
2707 Node *newScaleN = newScaleC, *newBiasN = newBiasC;
2708 if (isa<QuantizeNode>(BN->getScale())) {
2709 newScaleN = F->createQuantize(newScaleN->getName(), newScaleN,
2710 BN->getScale().getType());
2711 }
2712 if (isa<QuantizeNode>(BN->getBias())) {
2713 newBiasN = F->createQuantize(newBiasN->getName(), newBiasN,
2714 BN->getBias().getType());
2715 }
2716
2717 // Create a BN with new parameters.
2718 auto *newBN = F->createBatchNormalization(
2719 BN->getName(), chainEnd.getType(), BN->getInput(), newBiasN, newScaleN,
2720 BN->getMean(), BN->getVar(), BN->getChannelIdx(), BN->getEpsilon(),
2721 BN->getMomentum());
2722
2723 chainEnd.replaceAllUsesOfWith(newBN);
2724 changed = true;
2725 }
2726
2727 return changed;
2728}
2729
2730/// Fold MatMul + Add into FullyConnected. This is useful for backends which
2731/// have an atomic implementation for the FullyConnected node. It is also needed
2732/// for ONNX which does not have a representation for the FullyConnected node.
2733bool FoldMatMulAddIntoFullyConnected::run(Function *F,
2734 const CompilationContext &cctx) {
2735 LOG_SCOPE(F->getLogContext(), getName());
2736 bool changed = false;
2737 for (auto &node : F->getNodes()) {
2738 auto *addNode = dyn_cast<AddNode>(&node);
2739 if (!addNode) {
2740 continue;
2741 }
2742
2743 // Check for MatMul node being either RHS or LHS.
2744 auto *matMulNode_LHS = dyn_cast<MatMulNode>(addNode->getLHS());
2745 auto *matMulNode_RHS = dyn_cast<MatMulNode>(addNode->getRHS());
2746 auto *matMulNode = matMulNode_LHS ? matMulNode_LHS : matMulNode_RHS;
2747 NodeValue bias = matMulNode_LHS ? addNode->getRHS() : addNode->getLHS();
2748 if (!matMulNode) {
2749 continue;
2750 }
2751
2752 // Folding is allowed only if MatMul has one use.
2753 if (!matMulNode->getResult().hasOneUse()) {
2754 continue;
2755 }
2756
2757 // The corresponding batch/length of the FullyConnected Bias operand.
2758 assert(bias.dims().size() == 2 && "Bias should be 2D!");
2759 auto biasBatch = bias.dims()[0];
2760 auto biasLength = bias.dims()[1];
2761 if (biasBatch == 1) {
2762 // If bias is not batched then reshape to 1D.
2763 bias = F->createReshape(bias.getNode()->getName().str() + ".reshape",
2764 bias, {bias.getType()->size()});
2765 } else {
2766 // If bias is batched then we must verify that the bias data
2767 // is same for all batches. For this the bias must be a Constant.
2768 auto *biasC = llvm::dyn_cast<Constant>(bias.getNode());
2769 if (!biasC) {
2770 continue;
2771 }
2772 if (!biasC->getPayload().isTiled(0)) {
2773 continue;
2774 }
2775 // Slice batched 2D bias and reshape to 1D.
2776 bias = F->createSlice(bias.getNode()->getName().str() + ".slice", bias,
2777 {0, 0}, {1, biasLength});
2778 bias = F->createReshape(bias.getNode()->getName().str() + ".reshape",
2779 bias, {biasLength});
2780 }
2781
2782 // Create a new FullyConnected node.
2783 auto *newFC = F->createFullyConnected(
2784 matMulNode->getName(), matMulNode->getLHS(), matMulNode->getRHS(), bias,
2785 addNode->getResult().getType());
2786 addNode->getResult().replaceAllUsesOfWith(newFC);
2787 changed = true;
2788 }
2789
2790 return changed;
2791}
2792
2793/// Convert MatMul into FullyConnected with null bias. This pass is used if
2794/// we have an optimized implementation for FullyConnected but NOT for MatMul.
2795/// Make sure you run this pass after the FoldMatMulAddIntoFullyConnected pass
2796/// otherwise a MatMul followed by Add will be converted into a FullyConnected
2797/// followed by Add and NOT a single FullyConnected instance.
2798bool ConvertMatMulToFullyConnected::run(Function *F,
2799 const CompilationContext &cctx) {
2800 bool changed = false;
2801 for (auto &node : F->getNodes()) {
2802 auto *matMulNode = dyn_cast<MatMulNode>(&node);
2803 if (!matMulNode) {
2804 continue;
2805 }
2806
2807 // Create null bias.
2808 Constant *bias = nullptr;
2809 std::vector<dim_t> biasDims = {matMulNode->getResult().dims().back()};
2810 std::string biasName = matMulNode->getName().str() + "bias";
2811 if (matMulNode->getResult().getType()->isQuantizedType()) {
2812 // Create null bias with offset 0 and a scale equal to the product
2813 // between LHS scale and RHS scale.
2814 float biasScale = matMulNode->getLHS().getType()->getScale() *
2815 matMulNode->getRHS().getType()->getScale();
2816 int32_t biasOffset = 0;
2817 ElemKind biasPrec = cctx.precisionConfig.quantConfig.precisionBias;
2818 bias = F->getParent()->createConstant(biasPrec, biasDims, biasScale,
2819 biasOffset, biasName);
2820 bias->getPayloadMutable().zero();
2821 } else {
2822 // Create null FLOAT bias.
2823 bias =
2824 F->getParent()->createConstant(ElemKind::FloatTy, biasDims, biasName);
2825 bias->getPayloadMutable().zero();
2826 }
2827
2828 // Create a new FullyConnected node with null bias.
2829 auto *newFC = F->createFullyConnected(
2830 matMulNode->getName(), matMulNode->getLHS(), matMulNode->getRHS(), bias,
2831 matMulNode->getResult().getType());
2832 matMulNode->getResult().replaceAllUsesOfWith(newFC);
2833 changed = true;
2834 }
2835
2836 return changed;
2837}
2838
2839// Fold Add after ConvTranspose into ConvTranspose's bias, if such Add was a
2840// broadcasted Add. Examine by looking into Tensor repetitions. Fold this:
2841//
2842// CONST1 Input
2843// \ |
2844// CONST2 ConvTranspose
2845// \ /
2846// Add
2847// |
2848// Output
2849//
2850// into this:
2851//
2852// CONST1 (CONST2 SQUEEZED)
2853// | /
2854// Input ADD
2855// \ /
2856// ConvTranspose
2857// |
2858// Output
2859//
2860// Optimizations are going to take care of folding CONST1/CONST2/ADD
2861// into one const bias.
2862bool ConvTransposeBiasAddFold::run(Function *F,
2863 const CompilationContext &cctx) {
2864 LOG_SCOPE(F->getLogContext(), getName());
2865 bool changed = false;
2866 for (auto &node : F->getNodes()) {
2867
2868 auto *AN = dyn_cast<AddNode>(&node);
2869 if (!AN) {
2870 continue;
2871 }
2872
2873 // Check for Transpose node being either RHS or LHS.
2874 auto *DN_L = dyn_cast<ConvTransposeNode>(AN->getLHS());
2875 auto *DN_R = dyn_cast<ConvTransposeNode>(AN->getRHS());
2876 auto *DN = DN_L ? DN_L : DN_R;
2877 if (!(!DN_R ^ !DN_L)) {
2878 continue;
2879 }
2880 auto *biasTile = dyn_cast<Constant>(DN_L ? AN->getRHS() : AN->getLHS());
2881 if (!biasTile || (biasTile->dims().size() != 4)) {
2882 continue;
2883 }
2884 auto *bias = dyn_cast<Constant>(DN->getBias());
2885 if (!bias) {
2886 continue;
2887 }
2888
2889 // Check if Add is a broadcasted Add.
2890 std::vector<float> origConst(biasTile->dims()[3]);
2891 if (!isConstBroadcasted(origConst, *biasTile, 3)) {
2892 continue;
2893 }
2894
2895 // Expect Bias Add so allocate a new bias to fill as do checking.
2896 auto *newBias = F->getParent()->createConstant(
2897 ElemKind::FloatTy, {biasTile->dims()[3]}, biasTile->getName());
2898 newBias->getHandle() = origConst;
2899
2900 auto *add = F->createAdd(bias->getName(), bias, newBias);
2901 DN->setNthInput(ConvTransposeNode::BiasIdx, add);
2902 AN->getResult().replaceAllUsesOfWith(DN);
2903
2904 changed = true;
2905 } // For all nodes in the graph.
2906
2907 return changed;
2908}
2909
2910/// \returns true if all dimensions of the \p input tensors are the same
2911/// except for the provided \p dimension, otherwise return false.
2912static bool checkConcatNodeUniformDims(llvm::ArrayRef<NodeValue> inputs,
2913 unsigned_t dimension) {
2914 for (size_t i = 1; i < inputs.size(); i++) {
2915 for (size_t j = 0; j < inputs[0].dims().size(); j++) {
2916 if (j == dimension) {
2917 continue;
2918 }
2919 if (inputs[0].dims()[j] != inputs[i].dims()[j]) {
2920 return false;
2921 }
2922 }
2923 }
2924 return true;
2925}
2926
2927/// Given a tensor's dims \p firstDims and the desired leading/trailing dims
2928/// sizes \p leadingDimsProdOriginalConcatNode, \p
2929/// trailingDimsProdOriginalConcatNode. \returns the dimension, at which the
2930/// trailing/leading dimensions match the desired sizes, otherwise returns -1.
2931/// Example: Given a tensor <1,2,3,4,5>, and a desired trailing dimensions
2932/// size of 20, and a desired leading dimensions size of 2, this function will
2933/// return dimension 1 as the trailing dimensions after it are <4,5>, which
2934/// matches the size 20, and the leading dimensions are <1,2>, which matches
2935/// the size 2.
2936static ssize_t findMatchingConcatDimForSameTrailingAndLeadingDims(
2937 llvm::ArrayRef<dim_t> firstDims, size_t leadingDimsProdOriginalConcatNode,
2938 size_t trailingDimsProdOriginalConcatNode) {
2939 size_t trailingDimsProdCurNode = 1;
2940 for (ssize_t i = firstDims.size() - 1; i >= 0; i--) {
2941 if (trailingDimsProdCurNode == trailingDimsProdOriginalConcatNode) {
2942 size_t leadingDimsProdCurNode = 1;
2943 for (ssize_t j = 0; j < i; j++) {
2944 leadingDimsProdCurNode *= firstDims[j];
2945 }
2946 if (leadingDimsProdCurNode == leadingDimsProdOriginalConcatNode) {
2947 return i;
2948 }
2949 }
2950 trailingDimsProdCurNode *= firstDims[i];
2951 }
2952 return -1;
2953}
2954
2955/// Given input tensors \p inputs and a original ConcatNode \p origConcatN,
2956/// try to find out if there is a dimension in the input tensors, with which
2957/// we can meet two requirements:
2958/// 1) Input tensors are concatenate-able along this dimension.
2959/// 2) The trailing/leading dimensions sizes after/before this dimension in
2960/// the input tensors, are of the same size as the trailing/leading
2961/// dimensions of the input of the original Concat node after/before the
2962/// concatenation dimension. It is required, because they ensure that the
2963/// payload of the new concat node should be the same as the payload of
2964/// the original concat node, and also won't affect the data order of the
2965/// entire tensor.
2966/// \returns this dimension if found, otherwise -1.
2967static int
2968findConcatDimForSameTrailingAndLeadingDims(llvm::ArrayRef<NodeValue> inputs,
2969 ConcatNode *originalConcatNode) {
2970 // For the purpose of the optimiztion
2971 // Concat(Reshape(X)*N)->Reshape(Concat(N*X)), we want to make sure the new
2972 // ConcatNode can concatenate on the trailing/leading dimensions which are
2973 // of the same size of those of the original Concate node.
2974
2975 auto firstDims = inputs.front().dims();
2976 auto origConcatNInputDims = originalConcatNode->getInputs().front().dims();
2977 // The sizes of the trailing/leading dimensions of the original ConcatNode,
2978 // which are being concatenated. This sizes are simply the products of
2979 // dimensions following/before the dimension used for concatenation.
2980 dim_t trailingDimsProdOriginalConcatNode = 1;
2981 dim_t leadingDimsProdOriginalConcatNode = 1;
2982 for (size_t i = 0; i < origConcatNInputDims.size(); ++i) {
2983 if (i < originalConcatNode->getDim()) {
2984 leadingDimsProdOriginalConcatNode *= origConcatNInputDims[i];
2985 } else if (i > originalConcatNode->getDim()) {
2986 trailingDimsProdOriginalConcatNode *= origConcatNInputDims[i];
2987 }
2988 }
2989
2990 // Try to find the dimension in the first input such that the
2991 // trailing/leading dimensions sizes are the same as the sizes of the
2992 // trailing/leading dimensions based on the concatenation dimension used by
2993 // the original ConcatNode.
2994 ssize_t dim = findMatchingConcatDimForSameTrailingAndLeadingDims(
2995 firstDims, leadingDimsProdOriginalConcatNode,
2996 trailingDimsProdOriginalConcatNode);
2997 if (dim == -1) {
2998 return -1;
2999 }
3000
3001 // Now we have found the dimension, we need to check if all inputs can be
3002 // concatenated along this dimension.
3003 if (!checkConcatNodeUniformDims(inputs, dim)) {
3004 return -1;
3005 }
3006 return dim;
3007}
3008
3009/// Given the inputs \p originalConcatInputs of one Concat Nodes, \returns
3010/// true if they are all ReshapeNode, and the input tensors of these input
3011/// nodes have same number of dimensions, otherwise returns false.
3012static bool
3013tryToGetNewConcatInputs(NodeValueArrayRef originalConcatInputs,
3014 llvm::SmallVectorImpl<NodeValue> &newConcatInputs) {
3015 // Go through the input nodes of CN, check if they are all ReshapeNode,
3016 // and if the input tensors of these input nodes have same number of
3017 // dimensions.
3018 for (auto &I : originalConcatInputs) {
3019 if (auto *R = dyn_cast<ReshapeNode>(I)) {
3020 if (newConcatInputs.empty() || newConcatInputs.front().dims().size() ==
3021 R->getInput().dims().size()) {
3022 newConcatInputs.push_back(R->getInput());
3023 continue;
3024 }
3025 }
3026 return false;
3027 }
3028 return true;
3029}
3030
3031/// Concat(Reshape(x) * N) -> Reshape(Concat(x * N)).
3032/// \returns a new simplified Concat node or nullptr.
3033static NodeValue tryToOptimizeConcatOfRehapes(Function *F, ConcatNode *CN) {
3034 llvm::SmallVector<NodeValue, 16> newConcatInputs;
3035 // The inputs of the collected input reshape nodes. They will be used as
3036 // inputs for the new Concat node if possible.
3037 if (!tryToGetNewConcatInputs(CN->getInputs(), newConcatInputs)) {
3038 return NodeValue(nullptr);
3039 }
3040
3041 // Try to concatenate along the same size trailing/leading dimensions as of
3042 // the original Concat node.
3043 auto dim = findConcatDimForSameTrailingAndLeadingDims(newConcatInputs, CN);
3044 if (dim == -1) {
3045 return NodeValue(nullptr);
3046 }
3047 auto *newCN = F->createConcat(CN->getName(), newConcatInputs, dim);
3048 return F->createReshape(
3049 CN->getInputs().front().getNode()->getName(), newCN,
3050 CN->getResult().dims(),
3051 CanonicalTensorLayout::getInstance().getNthResultLayoutRequirements(
3052 CN, ConcatNode::ResultIdx));
3053}
3054
3055/// Simplify concat node.
3056/// \returns a new simplified Concat node or nullptr.
3057static NodeValue simplifyConcatNode(Function *F, ConcatNode *CN,
3058 const CompilationContext &cctx) {
3059 /// concat(dim1, concat(dim2, X, Y), Z) -> concat(dim1, X, Y, Z),
3060 /// but only if dim1 == dim2
3061
3062 LOG_SCOPE(F->getLogContext(), "simplifyConcatNode")
3063
3064 if (!cctx.optimizationOpts.skipConcatMerging) {
3065 auto inputs = CN->getInputs();
3066 // Check if any of the inputs are ConcatNode.
3067 llvm::SmallVector<NodeValue, 16> newInputs;
3068 bool merged = false;
3069 for (auto &input : inputs) {
3070 newInputs.push_back(input);
3071 auto *CNI = dyn_cast<ConcatNode>(input);
3072 // Bail if it is not a ConcatNode or it is a concat node with a diffrent
3073 // dimension.
3074 if (!CNI || CNI->getDim() != CN->getDim()) {
3075 continue;
3076 }
3077
3078 // Preventing this from kicking in.
3079 // Otherwise we will end up sequence of concats with more and more inputs,
3080 // without really eliminating any.
3081 if (CNI->getResult().getNumUsers() > 1) {
3082 continue;
3083 }
3084
3085 merged = true;
3086 // Replace current input by its own inputs, i.e. merge them into the
3087 // parent concat node.
3088 newInputs.pop_back();
3089 newInputs.append(CNI->getInputs().begin(), CNI->getInputs().end());
3090 }
3091 if (merged) {
3092 // Return a new simplified Concat node.
3093 return F->createConcat(CN->getName(), newInputs, CN->getDim());
3094 }
3095 }
3096
3097 // If all of the inputs to the concat are extracted from the same input in
3098 // the right order then we can just use the extract-input instead of the
3099 // concat. Concat(Slice(X, 0..10), Slice(X, 10..20)) -> X.
3100 {
3101 std::vector<SliceNode *> order;
3102 std::vector<SliceNode *> slices;
3103 // Collect all of the inputs that are SliceNode.
3104 for (auto &I : CN->getInputs()) {
3105 if (auto *S = dyn_cast<SliceNode>(I.getNode())) {
3106 slices.push_back(S);
3107 }
3108 }
3109 // Check if the slices span the input value.
3110 bool found = findSlicesThatSpanInput(slices, CN->getDim(), order);
3111 if (found && order.size() == slices.size()) {
3112 // Check that the ordered Slices that span the input are in order.
3113 bool ordered = true;
3114 for (size_t i = 0, e = slices.size(); i < e; i++) {
3115 if (order[i] != slices[i]) {
3116 ordered = false;
3117 break;
3118 }
3119 }
3120
3121 auto orig = order[0]->getInput();
3122 // The original value that we extract from must be of the same shape as
3123 // the concat.
3124 if (ordered && CN->getResult().getType() == orig.getType()) {
3125 return orig;
3126 }
3127 }
3128 }
3129
3130 // Try the optimization Concat(Reshape(x) * N) -> Reshape(Concat(x * N)).
3131 if (auto transformedConcatNode = tryToOptimizeConcatOfRehapes(F, CN)) {
3132 return transformedConcatNode;
3133 }
3134
3135 // If the concat has a single input, replace the concat with that input.
3136 if (CN->getNumInputs() == 1) {
3137 return CN->getInputs()[0];
3138 }
3139
3140 return NodeValue(nullptr);
3141}
3142
3143/// If all of the outputs of \p CN are essentially piped from the inputs of the
3144/// concat (i.e. same shape, axis, order) then we can get rid of the slices and
3145/// concat. \returns true if this optimization is successful and changes the
3146/// Function.
3147static bool combineConcatSlices(ConcatNode *CN) {
3148 auto inputsToCN = CN->getInputs();
3149 std::vector<SliceNode *> slices;
3150 std::vector<SliceNode *> orderedSlices;
3151 for (auto &user : CN->getUsers()) {
3152 if (SliceNode *SN = dyn_cast<SliceNode>(user.getUser())) {
3153 slices.push_back(SN);
3154 }
3155 }
3156
3157 // Check if the slices span the input value.
3158 bool found = findSlicesThatSpanInput(slices, CN->getDim(), orderedSlices);
3159 if (!found || orderedSlices.size() != slices.size() ||
3160 orderedSlices.size() != inputsToCN.size()) {
3161 return false;
3162 }
3163
3164 // Now verify that all of the inputs to CN have the same shape as all of the
3165 // slices for the result of CN.
3166 for (size_t i = 0, e = orderedSlices.size(); i < e; ++i) {
3167 if (orderedSlices[i]->getResult().dims() != inputsToCN[i].dims()) {
3168 return false;
3169 }
3170 }
3171
3172 // We can now replace all of the inputs to the concat to the result of
3173 // each slice.
3174 for (size_t i = 0, e = inputsToCN.size(); i < e; ++i) {
3175 orderedSlices[i]->getResult().replaceAllUsesOfWith(inputsToCN[i]);
3176 }
3177 return true;
3178}
3179
3180/// Eliminate Concat-Slice patterns which are unnecessary. E.g.:
3181/// NodeA NodeB NodeA NodeB
3182/// \ / | |
3183/// ConcatC | |
3184/// / \ -----> | |
3185/// SliceD SliceE | |
3186/// | | | |
3187/// NodeF NodeG NodeF NodeG
3188bool EliminateConcatSlice::run(Function *F, const CompilationContext &cctx) {
3189 LOG_SCOPE(F->getLogContext(), getName());
3190 bool changed = false;
3191 auto &nodes = F->getNodes();
3192
3193 // For each node:
3194 for (auto &node : nodes) {
3195 auto *CN = dyn_cast<ConcatNode>(&node);
3196 if (!CN) {
3197 continue;
3198 }
3199 if (combineConcatSlices(CN)) {
3200 changed = true;
3201 continue;
3202 }
3203 }
3204 return changed;
3205}
3206
3207/// Eliminate Slice-Concat patterns which are unnecessary. E.g.:
3208/// -- NodeSrc --- -NodeSrc-
3209/// / | | / |
3210/// SliceA SliceB SliceC -----> SliceAB SliceC
3211/// \ / | | |
3212/// NodeE - ConcatABE NodeD NodeE - ConcatABE NodeD
3213bool EliminateSliceConcat::run(Function *F, const CompilationContext &cctx) {
3214 LOG_SCOPE(F->getLogContext(), getName());
3215 bool changed = false;
3216 auto &nodes = F->getNodes();
3217
3218 for (auto it = nodes.rbegin(), e = nodes.rend(); it != e; it++) {
3219 Node &node = *it;
3220 auto *CN = dyn_cast<ConcatNode>(&node);
3221 if (!CN) {
3222 continue;
3223 }
3224 // avoid 1) merging through operators other than Slices
3225 // e.g. Slice(A)-Other-Slice(B), A and B are consecutive
3226 // 2) merging Slices from different sources
3227 // e.g. Slice(A1)-Slice(B1)-Slice(A2)-Slice(B2), A1 and A2, B1 and B2 are
3228 // consecutive respectively
3229 //
3230 // Store consecutive slices along *any* dimension. If the consecutive
3231 // slices' dimension is the same as the concat, the concat can be removed.
3232 // If the slices' dimension is different from the concat, the nodes
3233 // can be replaced with slice+reshape OR slice+transpose+reshape.
3234 // The extra transpose is needed when the consecutive dimension is
3235 // after the concat dimension.
3236 std::vector<
3237 std::pair<int /* dimension of slicing */, std::vector<SliceNode *>>>
3238 consecutiveSlices;
3239 std::vector<SliceNode *> currConsecutiveSlices;
3240 SliceNode *lastSN = nullptr;
3241 int lastDim = -1;
3242 for (auto &concatInput : CN->getInputs()) {
3243 auto *SN = dyn_cast<SliceNode>(concatInput.getNode());
3244 // slices with multiple users will not be considered
3245 if (!SN || SN->getResult().getNumUsers() > 1) {
3246 if (currConsecutiveSlices.size()) {
3247 consecutiveSlices.emplace_back(lastDim, currConsecutiveSlices);
3248 currConsecutiveSlices.clear();
3249 }
3250 lastDim = -1;
3251 lastSN = nullptr;
3252 continue;
3253 }
3254 // slices with different sources will not be considered
3255 int dim = -1;
3256 if (lastSN && (lastSN->getInput() != SN->getInput() ||
3257 !findConsecutiveSliceDim(lastSN, SN, &dim))) {
3258 consecutiveSlices.emplace_back(lastDim, currConsecutiveSlices);
3259 currConsecutiveSlices.clear();
3260 }
3261 lastDim = dim;
3262 lastSN = SN;
3263 currConsecutiveSlices.emplace_back(SN);
3264 }
3265 if (currConsecutiveSlices.size()) {
3266 consecutiveSlices.emplace_back(lastDim, currConsecutiveSlices);
3267 }
3268
3269 // Mapping from old Slices to new Nodes where a Node can either be
3270 // i) a merged Slice
3271 // ii) a merged Slice + Reshape
3272 // iii) a merged Slice + Transpose + Reshape
3273 std::unordered_map<SliceNode *, Node *> oldSlicesToNewNodes;
3274
3275 for (const auto &slicePairs : consecutiveSlices) {
3276 unsigned_t slicesDim = slicePairs.first;
3277 auto &slices = slicePairs.second;
3278
3279 if (slices.size() <= 1) {
3280 continue;
3281 }
3282 if (slicesDim != CN->getDim() &&
3283 ((slicesDim != CN->getDim() + 1 && slicesDim != CN->getDim() - 1) ||
3284 slices[0]->getResult().dims()[slicesDim] != 1)) {
3285 // Optimizations are possible only if:
3286 // 1) slices consecutive dimension is the same as concat dimension, or
3287 // 2) slices consecutive dimension is adjacent to the concat dimension,
3288 // and the size of each slice along the consecutive dimension is 1.
3289 // NOTE: Checking the slicesDim dimension of 0th slice is
3290 // sufficient, as opposed to checking every slice. If the slices
3291 // can be concatenated (and they're being concatenated along a
3292 // different dimension), then each slicesDim dim must be equal.
3293 continue;
3294 }
3295 if ((slicesDim == CN->getDim() + 1 && slices.size() <= 3) ||
3296 (slicesDim == CN->getDim() - 1 && slices.size() <= 2)) {
3297 // Optimization does not decrease the number of nodes.
3298 continue;
3299 }
3300
3301 SliceNode *firstSlice = slices.front();
3302 auto *srcNode = firstSlice->getInput().getNode();
3303 std::vector<dim_t> endDims;
3304
3305 for (size_t i = 0, e2 = firstSlice->getResult().dims().size(); i < e2;
3306 i++) {
3307 endDims.emplace_back(slices.back()->getStart()[i] +
3308 slices.back()->getResult().dims()[i]);
3309 }
3310 Node *newNode = nullptr;
3311 auto *newSlice = F->createSlice(firstSlice->getName(), srcNode,
3312 firstSlice->getStart(), endDims);
3313
3314 // Create a reshape node based on consecutive slice dimension and
3315 // concat dimension.
3316 if (slicesDim == CN->getDim() + 1 || slicesDim == CN->getDim() - 1) {
3317 auto outputDimVec = newSlice->getResult().dims().vec();
3318 outputDimVec[CN->getDim()] *= outputDimVec[slicesDim];
3319 outputDimVec[slicesDim] = 1;
3320 auto outputDims = llvm::makeArrayRef(outputDimVec);
3321
3322 Node *inputToReshape = nullptr;
3323 if (slicesDim == CN->getDim() + 1) {
3324 std::vector<unsigned_t> shuffle(outputDimVec.size());
3325 std::iota(shuffle.begin(), shuffle.end(), 0);
3326 std::swap(shuffle[slicesDim], shuffle[CN->getDim()]);
3327 inputToReshape = F->createTranspose(
3328 newSlice->getName().str() + "_Transpose", newSlice, shuffle);
3329 } else {
3330 inputToReshape = newSlice;
3331 }
3332 newNode = F->createReshape(newSlice->getName().str() + "_Reshape",
3333 inputToReshape, outputDims);
3334 } else {
3335 newNode = newSlice;
3336 }
3337
3338 for (auto *slice : slices) {
3339 oldSlicesToNewNodes[slice] = newNode;
3340 }
3341
3342 changed = true;
3343 }
3344 if (!oldSlicesToNewNodes.size()) {
3345 continue;
3346 }
3347 // Replace the input Slices to CN with the merged Nodes.
3348 std::vector<NodeValue> newConcatInputs;
3349 const Node *lastNewNode = nullptr;
3350 for (const auto &concatInput : CN->getInputs()) {
3351 auto *SN = dyn_cast<SliceNode>(concatInput.getNode());
3352 if (!SN || !oldSlicesToNewNodes.count(SN)) {
3353 newConcatInputs.emplace_back(concatInput);
3354 } else {
3355 auto *newNode = oldSlicesToNewNodes[SN];
3356 if (newNode != lastNewNode) {
3357 lastNewNode = newNode;
3358 newConcatInputs.emplace_back(newNode);
3359 }
3360 }
3361 }
3362 if (newConcatInputs.size() != CN->getInputs().size()) {
3363 auto *newConcat =
3364 F->createConcat(CN->getName(), newConcatInputs, CN->getDim());
3365 CN->getResult().replaceAllUsesOfWith(newConcat);
3366 }
3367 }
3368
3369 return changed;
3370}
3371
3372/// Optimize Concat nodes.
3373bool OptimizeConcatNodes::run(Function *F, const CompilationContext &cctx) {
3374 LOG_SCOPE(F->getLogContext(), getName());
3375 bool changed = false;
3376 auto &nodes = F->getNodes();
3377
3378 // For each node:
3379 for (auto &node : nodes) {
3380 auto *CN = dyn_cast<ConcatNode>(&node);
3381 if (!CN) {
3382 continue;
3383 }
3384 NodeValue newCN = simplifyConcatNode(F, CN, cctx);
3385 if (newCN.getNode()) {
3386 CN->getResult().replaceAllUsesOfWith(newCN);
3387 changed = true;
3388 continue;
3389 }
3390 }
3391 return changed;
3392}
3393
3394/// Fold Slices into Constants. This will create new Constants if necessary.
3395bool FoldSlicesIntoConstants::run(Function *F, const CompilationContext &cctx) {
3396 LOG_SCOPE(F->getLogContext(), getName());
3397 bool changed = false;
3398 auto &nodes = F->getNodes();
3399
3400 // For each node:
3401 for (auto &node : nodes) {
3402 auto *SN = dyn_cast<SliceNode>(&node);
3403 if (!SN) {
3404 continue;
3405 }
3406 auto *C = dyn_cast<Constant>(SN->getInput());
3407 if (!C) {
3408 continue;
3409 }
3410
3411 // Create new slice of the Constant.
3412 Tensor outT = Tensor(SN->getResult().getType());
3413
3414 ElemKind k = outT.getElementType();
3415#define TYPED_INSERT(TY, TYPEKIND) \
3416 if (k == TYPEKIND) { \
3417 auto OH = outT.getHandle<TY>(); \
3418 auto IH = C->getPayloadMutable().getHandle<TY>(); \
3419 IH.extractTensors(OH, SN->getStart()); \
3420 }
3421
3422 TYPED_INSERT(float, ElemKind::FloatTy);
3423 TYPED_INSERT(float16_t, ElemKind::Float16Ty);
3424 TYPED_INSERT(bfloat16_t, ElemKind::BFloat16Ty);
3425 TYPED_INSERT(int8_t, ElemKind::Int8QTy);
3426 TYPED_INSERT(int16_t, ElemKind::Int16QTy);
3427 TYPED_INSERT(int32_t, ElemKind::Int32QTy);
3428 TYPED_INSERT(int32_t, ElemKind::Int32ITy);
3429 TYPED_INSERT(int64_t, ElemKind::Int64ITy);
3430 TYPED_INSERT(bool, ElemKind::BoolTy);
3431#undef TYPED_INSERT
3432
3433 // Create a new Constant NC to hold the sliced result.
3434 auto *NC = F->getParent()->createConstant(C->getName(), std::move(outT));
3435 // Connect all Slice users with the new Slice.
3436 SN->getResult().replaceAllUsesOfWith(NC);
3437 changed = true;
3438 }
3439
3440 return changed;
3441}
3442
3443/// Simplify and canonicalize arithmetic nodes by detecting simple arithmetic
3444/// identities.
3445bool OptimizeArithmeticNodes::run(Function *F, const CompilationContext &cctx) {
3446 LOG_SCOPE(F->getLogContext(), getName());
3447 bool changed = false;
3448 // A worklist that contains the nodes to process.
3449 std::vector<Node *> worklist;
3450
3451 // Add all of the arithmetic nodes to the worklist, with a node's
3452 // dependencies added after itself so they are processed before the node.
3453 GraphPreOrderVisitor visitor(*F);
3454 worklist.reserve(visitor.getPreOrder().size());
3455 for (auto *N : visitor.getPreOrder()) {
3456 if (N->isArithmetic()) {
3457 worklist.push_back(N);
3458 }
3459 }
3460 while (!worklist.empty()) {
3461 Node *N = worklist.back();
3462 assert(N->isArithmetic() && "Must be an Arithmetic node.");
3463 worklist.pop_back();
3464
3465 auto SNV = simplifyNode(N, F);
3466 if (SNV.getNode() != N) {
3467 N->getNthResult(ArithmeticNode::ResultIdx).replaceAllUsesOfWith(SNV);
3468 changed = true;
3469
3470 auto *SN = SNV.getNode();
3471
3472 // The simplified node could be further simplified. Note that the
3473 // simplified node might not be arithmetic; it could be a splat.
3474 if (SN->isArithmetic()) {
3475 worklist.push_back(SN);
3476 }
3477
3478 // The simplified node's operands could be further simplified as well.
3479 // Push them after the node so they are processed before the node.
3480 for (size_t i = 0, e = SN->getNumInputs(); i < e; i++) {
3481 Node *input = SN->getNthInput(i).getNode();
3482 if (input->isArithmetic()) {
3483 worklist.push_back(input);
3484 }
3485 }
3486 continue;
3487 }
3488 }
3489 return changed;
3490}
3491
3492/// Statically transpose Constants.
3493bool TransposeConstants::run(Function *F, const CompilationContext &cctx) {
3494 LOG_SCOPE(F->getLogContext(), getName());
3495 auto &nodes = F->getNodes();
3496 bool changed = false;
3497 for (auto &node : nodes) {
3498 auto *TN = dyn_cast<TransposeNode>(&node);
3499 if (!TN) {
3500 continue;
3501 }
3502 auto *Q = dyn_cast<QuantizeNode>(TN->getInput());
3503 auto *C = dyn_cast<Constant>(Q ? Q->getInput() : TN->getInput());
3504 if (!C || (Q && !Q->hasOneUse())) {
3505 continue;
3506 }
3507 // Create a new Constant NC to hold the transposed result.
3508 auto cTy = F->getParent()->uniqueTypeWithNewShape(
3509 C->getType(), TN->getResult().getType());
3510 auto *NC =
3511 F->getParent()->createConstant(cTy, C->getName(), TN->getLayout());
3512 // Transpose the value of C into NC.
3513 genericTranspose(&C->getPayload(), &NC->getPayloadMutable(),
3514 TN->getShuffle());
3515 NC->getPayloadMutable().setType(NC->getType());
3516 // Create a new transposed Quantize, if needed.
3517 NodeValue NN(NC);
3518 if (Q) {
3519 auto qTy = F->getParent()->uniqueTypeWithNewShape(
3520 Q->getResult().getType(), TN->getResult().dims());
3521 auto *NQ = F->createQuantize(Q->getName(), NC, qTy);
3522 NN = NodeValue(NQ);
3523 }
3524 // Rewrite uses of TN to reference NC or NQ.
3525 TN->getResult().replaceAllUsesOfWith(NN);
3526 changed = true;
3527 }
3528 return changed;
3529}
3530
3531namespace {
3532
3533/// A helper type for hasing Node pointers when they are used as keys in hash
3534/// maps.
3535struct NodeHasher {
3536 size_t operator()(Node *N) const { return N->getHash(); }
3537};
3538
3539/// A helper type implementing the Node equality predicate that can be used
3540/// when Node pointers are used as keys in hash maps.
3541struct NodeEq {
3542 bool operator()(const Node *lhs, const Node *rhs) const {
3543 return lhs->isEqual(*rhs);
3544 }
3545};
3546
3547/// Array with node kinds which are excepted from CSE optimization, for example
3548/// node kinds for which the output is not necessarily identical when the nodes
3549/// are identical. Such nodes include Touch nodes which allocate buffers without
3550/// initialization or nodes which generate random data.
3551static const std::vector<Kinded::Kind> CSENodeExceptions = {
3552 Kinded::Kind::TouchNodeKind,
3553};
3554
3555/// This visitor is used to walk the graph and perform a common subexpression
3556/// evaluation.
3557struct CSEVisitor : NodeWalker {
3558 // Mapping from the original node to its canonical representation under CSE.
3559 std::unordered_map<Node *, Node *, NodeHasher, NodeEq> cseNodes_;
3560 // Set of visited nodes.
3561 std::unordered_set<Node *> visitedNodes_;
3562
3563 /// This callback is called before visiting the children of \p N.
3564 void pre(Node *parent, Node *N) override {
3565 // Put the node into a visited set to make sure it is visited
3566 // only once.
3567 visitedNodes_.insert(N);
3568 }
3569
3570 /// This callback is called after visiting the children of \p N.
3571 /// It means that all of its dependencies are processed already.
3572 void post(Node *parent, Node *N) override {
3573 // Try to find a node equivalent to the current one.
3574 auto FoundI = cseNodes_.find(N);
3575 if (FoundI == cseNodes_.end()) {
3576 // No node CSE-equivalent to the current one has been seen yet.
3577 // Remember this node, so that the next occurrence can be
3578 // replaced by this one.
3579 cseNodes_.insert({N, N});
3580 assert(cseNodes_.find(N) != cseNodes_.end());
3581 return;
3582 }
3583 Node *foundN = FoundI->second;
3584
3585 // Same node cannot be visited.
3586 assert(N != foundN);
3587
3588 // Replace current node by a found node, which is
3589 // equivalent to it.
3590 assert(N->isEqual(*foundN));
3591
3592 // Skip CSE node exceptions.
3593 if (std::find(CSENodeExceptions.begin(), CSENodeExceptions.end(),
3594 N->getKind()) != CSENodeExceptions.end()) {
3595 return;
3596 }
3597
3598 // Replace node.
3599 for (unsigned i = 0; i < N->getNumResults(); i++) {
3600 NodeValue FV(foundN, i);
3601 N->getNthResult(i).replaceAllUsesOfWith(FV);
3602 }
3603 // TODO: Erase N during CSE? If we don't do it here,
3604 // DCE will remove it later anyways.
3605 }
3606
3607 /// Make sure that each node is processed only once.
3608 bool shouldVisit(Node *parent, Node *N) override {
3609 return visitedNodes_.count(N) == 0;
3610 }
3611};
3612
3613/// A helper type for hashing Constant pointers when they are used as keys in
3614/// hash maps for deduplication. The hash is based on the type of the Constant
3615/// (element type, dimensions), as well as a constant number of elements from
3616/// the backing Tensor to balance collisions with hash calclulation time.
3617struct ConstsHasherDedup {
3618 size_t operator()(Constant *V) const {
3619 auto hash = llvm::hash_value(V->getType());
3620 auto &T = V->getPayload();
3621 // Only use the first 8 elements in the hash. It's likely that if two
3622 // tensors have different content they will diverge quickly. Fall back to
3623 // full equality check in ConstsEqDedup.
3624 constexpr dim_t maxNumEls = 8;
3625 dim_t numEls = std::min((dim_t)T.getType().size(), maxNumEls);
3626 dim_t bufSize = T.getType().getElementSize() * numEls;
3627 auto *data = T.getUnsafePtr();
3628 for (size_t i = 0; i < bufSize; i++) {
3629 hash = llvm::hash_combine(hash, data[i]);
3630 }
3631 return hash;
3632 }
3633};
3634
3635/// A helper type implementing the Constants equality predicate that can be
3636/// used when Constant pointers are used as keys in hash maps for
3637/// deduplication.
3638struct ConstsEqDedup {
3639 bool operator()(const Constant *lhs, const Constant *rhs) const {
3640 // Only consider Constants for deduplication if they have the same type.
3641 if (lhs->getType() != rhs->getType()) {
3642 return false;
3643 }
3644 // Only dedup Constants if they're bit exact matches.
3645 return lhs->getPayload().isBitwiseEqual(rhs->getPayload());
3646 }
3647};
3648
3649} // namespace
3650
3651/// Deduplicates Constants in the Module \p M. Applicable Constants for
3652/// deduplication must have the same data. \returns whether any Constants were
3653/// deduplicated.
3654static bool deduplicateConstants(Module *M) {
3655 // Map from Constants to other Constants that are equivalent for purposes of
3656 // deduplication.
3657 std::unordered_map<Constant *, Constant *, ConstsHasherDedup, ConstsEqDedup>
3658 duplicateConstants;
3659
3660 bool changed = false;
3661 for (auto &C : M->getConstants()) {
3662 // Only perform deduplication of consts with given max number of elements.
3663 size_t maxNumEls = constDedupSizeOpt;
3664 size_t numEls = C->getType()->size();
3665 if ((maxNumEls != 0) && (numEls > maxNumEls)) {
3666 continue;
3667 }
3668
3669 // Try to find a Constant that has the same data as the current one.
3670 auto foundI = duplicateConstants.find(C);
3671 if (foundI == duplicateConstants.end()) {
3672 // No node equivalent to the current one has been seen yet. Remember
3673 // this Constant, so that the next occurrence can be replaced by this
3674 // one.
3675 duplicateConstants.emplace(C, C);
3676 assert(duplicateConstants.find(C) != duplicateConstants.end());
3677 continue;
3678 }
3679 Constant *foundC = foundI->second;
3680 assert(C != foundC && "Constants should not be visited multiple times.");
3681
3682 // Replace current Constant by a found Constant, which is equivalent to
3683 // it.
3684 C->getOutput().replaceAllUsesOfWith(foundC);
3685 changed = true;
3686 }
3687 return changed;
3688}
3689
3690/// Common Subexpression Elimination.
3691bool CSE::run(Function *F, const CompilationContext &cctx) {
3692 LOG_SCOPE(F->getLogContext(), getName());
3693 CSEVisitor visitor;
3694
3695 bool changed = false;
3696 if (cctx.optimizationOpts.enableConstantDeduplication) {
3697 changed |= deduplicateConstants(F->getParent());
3698 }
3699
3700 // Perform CSE on all nodes.
3701 for (auto &N : F->getNodes()) {
3702 N.visit(nullptr, &visitor);
3703 }
3704 // TODO: Change Visitors to return whether they modified the Function they
3705 // are contained in. For now conservatively set changed to true;
3706 changed = true;
3707 return changed;
3708}
3709
3710/// Fold Nodes into SplatNodes.
3711bool OptimizeSplat::run(Function *F, const CompilationContext &cctx) {
3712 LOG_SCOPE(F->getLogContext(), getName());
3713 bool changed = false;
3714 for (Node &node : F->getNodes()) {
3715 // Slice(Splat(args)) -> Splat(args')
3716 if (SliceNode *sliceNode = dyn_cast<SliceNode>(&node)) {
3717 SplatNode *splatNode = dyn_cast<SplatNode>(sliceNode->getInput());
3718 if (!splatNode) {
3719 continue;
3720 }
3721 SplatNode *newSplatNode =
3722 F->createSplat(sliceNode->getName(), sliceNode->getResult().getType(),
3723 splatNode->getValue());
3724 sliceNode->getResult().replaceAllUsesOfWith(newSplatNode);
3725 changed = true;
3726 continue;
3727 }
3728
3729 // Clip(Splat(args)) -> Splat(args')
3730 if (ClipNode *clipNode = dyn_cast<ClipNode>(&node)) {
3731 SplatNode *splatNode = dyn_cast<SplatNode>(clipNode->getInput());
3732 if (!splatNode) {
3733 continue;
3734 }
3735 const float newSplatVal =
3736 std::min(std::max(splatNode->getValue(), clipNode->getMin()),
3737 clipNode->getMax());
3738
3739 SplatNode *newSplatNode = nullptr;
3740 if (newSplatVal == splatNode->getValue()) {
3741 // No need to crate a new Splat.
3742 newSplatNode = splatNode;
3743 } else {
3744 newSplatNode = F->createSplat(
3745 splatNode->getName().str() + clipNode->getName().str(),
3746 splatNode->getResult().getType(), newSplatVal);
3747 }
3748
3749 clipNode->getResult().replaceAllUsesOfWith(newSplatNode->getResult());
3750 changed = true;
3751 continue;
3752 }
3753 }
3754 return changed;
3755}
3756
3757bool GatherToSlice::run(Function *F, const CompilationContext &cctx) {
3758 LOG_SCOPE(F->getLogContext(), getName());
3759 bool changed = false;
3760
3761 for (auto &node : F->getNodes()) {
3762 auto *GN = dyn_cast<GatherNode>(&node);
3763 if (!GN) {
3764 continue;
3765 }
3766
3767 auto data = GN->getData();
3768 auto *indices = dyn_cast<Constant>(GN->getIndices());
3769
3770 // Only handling scalar constant index value
3771 if (!indices || indices->getPayload().size() != 1) {
3772 continue;
3773 }
3774
3775 dim_t index = 0;
3776 size_t axis = GN->getBatchDims();
3777 auto elementKind = indices->getElementType();
3778 if (elementKind == ElemKind::Int64ITy) {
3779 index = (size_t)indices->getHandle<int64_t>().raw(0);
3780 } else if (elementKind == ElemKind::Int32ITy) {
3781 index = (size_t)indices->getHandle<int32_t>().raw(0);
3782 } else {
3783 llvm_unreachable("GatherToSlice: Unsupported indices type");
3784 }
3785
3786 std::vector<dim_t> start;
3787 std::vector<dim_t> end;
3788 for (size_t i = 0; i < data.dims().size(); ++i) {
3789 if (i == axis) {
3790 start.push_back(index);
3791 end.push_back(index + 1);
3792 } else {
3793 start.push_back(0);
3794 end.push_back(data.dims()[i]);
3795 }
3796 }
3797
3798 auto name = GN->getName();
3799 auto *SN = F->createSlice(name, data, start, end);
3800 auto *RN = F->createReshape(name, SN, GN->getResult().dims());
3801
3802 GN->getResult().replaceAllUsesOfWith(RN->getResult());
3803 changed = true;
3804 }
3805
3806 return changed;
3807}
3808
3809/// Optimize TransposeNode into ReshapeNode when it actually moves no data.
3810bool OptimizeTransposeIntoReshape::run(Function *F,
3811 const CompilationContext &cctx) {
3812 LOG_SCOPE(F->getLogContext(), getName());
3813 bool changed = false;
3814
3815 for (auto &node : F->getNodes()) {
3816 auto *TR = dyn_cast<TransposeNode>(&node);
3817 if (!TR) {
3818 continue;
3819 }
3820 auto inputNode = TR->getInput();
3821 auto inputDims = inputNode.dims();
3822 auto outputDims = TR->getResult().dims();
3823 // The transformation is not possible if alignments different from 1 are
3824 // used for any dimension.
3825 if (!inputNode.getType()->isEqual(F->getParent()->uniqueTypeWithNewShape(
3826 inputNode.getType(), inputDims))) {
3827 continue;
3828 }
3829 if (!TR->getResult().getType()->isEqual(
3830 F->getParent()->uniqueTypeWithNewShape(TR->getResult().getType(),
3831 outputDims))) {
3832 continue;
3833 }
3834 // Transpose moves no data if input/output dimensions match after they both
3835 // drop dimensions of size 1. E.g. transposing [1 5 1 15] into [5 15 1 1]
3836 // produces vectors (1, 3) for both dimensions so optimization is executed.
3837 auto shuffle = TR->getShuffle();
3838 ShapeVector inDims;
3839 ShapeVector outDims;
3840 for (size_t i = 0; i < inputDims.size(); i++) {
3841 if (inputDims[i] != 1) {
3842 inDims.push_back(i);
3843 }
3844 if (outputDims[i] != 1) {
3845 outDims.push_back(shuffle[i]);
3846 }
3847 }
3848 if (inDims != outDims) {
3849 continue;
3850 }
3851 auto *RS =
3852 F->createReshape(TR->getName(), inputNode, outputDims, TR->getLayout());
3853 TR->getResult().replaceAllUsesOfWith(RS);
3854 changed = true;
3855 }
3856
3857 return changed;
3858}
3859
3860/// Eliminate nodes which don't do anything useful.
3861bool EliminateNoop::run(Function *F, const CompilationContext &cctx) {
3862 LOG_SCOPE(F->getLogContext(), getName());
3863 bool changed = false;
3864
3865 auto isNoop = [](const Node &node, NodeValue &input,
3866 NodeValue &output) -> bool {
3867 // Auto-select input/output, if there is just one. For other cases it
3868 // will be handled on per-operator basis below.
3869 if (node.getNumInputs() == 1) {
3870 input = node.getNthInput(0);
3871 }
3872 if (node.getNumResults() == 1) {
3873 output = node.getNthResult(0);
3874 }
3875
3876 // For some nodes it's enough just to compare input and output types to
3877 // determine if they are noop.
3878 if (isa<PadNode>(&node) || isa<SliceNode>(&node) || isa<TileNode>(&node) ||
3879 isa<BroadcastNode>(&node)) {
3880 return input.getType() == output.getType();
3881 }
3882
3883 // Operator-specific analysis.
3884 if (auto *APN = dyn_cast<AvgPoolNode>(&node)) {
3885 input = APN->getInput();
3886 return isUniformArray(APN->getKernels(), 1u) &&
3887 isUniformArray(APN->getStrides(), 1u) &&
3888 isUniformArray(APN->getPads(), 0u);
3889 } else if (auto *MPN = dyn_cast<MaxPoolNode>(&node)) {
3890 input = MPN->getInput();
3891 output = MPN->getResult();
3892 return isUniformArray(MPN->getKernels(), 1u) &&
3893 isUniformArray(MPN->getStrides(), 1u) &&
3894 isUniformArray(MPN->getPads(), 0u);
3895 } else if (auto *SN = dyn_cast<SelectNode>(&node)) {
3896 auto *cond = getConstant(SN->getCond());
3897 bool val = false;
3898 if (cond && isUniformConstant<bool>(*cond, val)) {
3899 input = val ? SN->getLHS() : SN->getRHS();
3900 return true;
3901 }
3902 }
3903
3904 return false;
3905 };
3906
3907 for (auto &node : F->getNodes()) {
3908 NodeValue input, output;
3909 if (isNoop(node, input, output)) {
3910 assert(input != NodeValue() && output != NodeValue() &&
3911 "Sanity check that input and output are set");
3912 output.replaceAllUsesOfWith(input);
3913 changed = true;
3914 }
3915 }
3916
3917 return changed;
3918}
3919
3920/// Optimize reshape nodes.
3921bool OptimizeReshape::run(Function *F, const CompilationContext &cctx) {
3922 LOG_SCOPE(F->getLogContext(), getName());
3923 bool changed = false;
3924 for (auto &node : F->getNodes()) {
3925 auto *reshapeNode = dyn_cast<ReshapeNode>(&node);
3926 if (!reshapeNode) {
3927 continue;
3928 }
3929 auto inputNode = reshapeNode->getNthInput(ReshapeNode::InputIdx);
3930 // Eliminate ReshapeNode when the input is already the correct shape.
3931 if (inputNode.dims() == reshapeNode->getResult().dims()) {
3932 reshapeNode->getResult().replaceAllUsesOfWith(inputNode);
3933 continue;
3934 }
3935 // Reshape(Splat(args)) -> Splat(args').
3936 auto *splatNode = dyn_cast<SplatNode>(inputNode);
3937 if (splatNode && splatNode->hasOneUse()) {
3938 // Splat node with more than one use can not be transformed, otherwise
3939 // we would increase the number of splats, which may lead to increased
3940 // memory consumption during the execution of the NN model.
3941 auto *newSplatNode = F->createSplat(splatNode->getName(),
3942 reshapeNode->getResult().getType(),
3943 splatNode->getValue());
3944 reshapeNode->getResult().replaceAllUsesOfWith(newSplatNode);
3945 changed = true;
3946 continue;
3947 }
3948 // Reshape(Reshape(x)) -> Reshape(x).
3949 auto *reshapeNodeInput = dyn_cast<ReshapeNode>(inputNode);
3950 if (reshapeNodeInput && reshapeNodeInput->hasOneUse()) {
3951 auto *newReshape = F->createReshape(
3952 reshapeNode->getName(), reshapeNodeInput->getInput(),
3953 reshapeNode->getResult().dims(), reshapeNode->getLayout());
3954 reshapeNode->getResult().replaceAllUsesOfWith(newReshape);
3955 changed = true;
3956 continue;
3957 }
3958 // Reshape(Constant) -> Constant' or
3959 // Reshape(Quantize(Constant)) -> Quantize'(Constant').
3960 auto *Q = dyn_cast<QuantizeNode>(inputNode);
3961 auto *C = dyn_cast<Constant>(Q ? Q->getInput() : inputNode);
3962 if (C && (!Q || Q->hasOneUse())) {
3963 // Create a new Constant with the dims of the reshape.
3964 auto layout =
3965 CanonicalTensorLayout::getInstance().getNthResultLayoutRequirements(
3966 reshapeNode, ReshapeNode::ResultIndices::ResultIdx);
3967 auto cTy = F->getParent()->uniqueTypeWithNewShape(
3968 C->getType(), reshapeNode->getResult().getType());
3969 auto *newC = F->getParent()->createConstant(cTy, C->getName(), layout);
3970 // Create an unowned view of the original tensor with the correct shape,
3971 // and assign it to the new Constant.
3972 Tensor reshapedT = C->getPayload().getUnowned(reshapeNode->getDims());
3973 newC->assign(&reshapedT);
3974 // Create a new Quantize with the dims of the reshape, if needed.
3975 NodeValue newN(newC);
3976 if (Q) {
3977 auto qTy = F->getParent()->uniqueTypeWithNewShape(
3978 Q->getResult().getType(), reshapeNode->getResult().dims());
3979 auto *newQ = F->createQuantize(Q->getName(), newC, qTy);
3980 newN = NodeValue(newQ);
3981 }
3982 reshapeNode->getResult().replaceAllUsesOfWith(newN);
3983 changed = true;
3984 continue;
3985 }
3986 }
3987 return changed;
3988}
3989
3990/// Helper to optimize Resize nodes.
3991template <typename ResizeNodeType>
3992static bool optimizeResize(Function *F, const CompilationContext &cctx) {
3993 bool changed = false;
3994 for (auto &node : F->getNodes()) {
3995 auto *resizeNode = dyn_cast<ResizeNodeType>(&node);
3996 CONTINUE_IF_NOT(resizeNode);
3997 // Remove identity resize (same input and output type).
3998 auto inpType = resizeNode->getInput().getType();
3999 auto outType = resizeNode->getResult().getType();
4000 if (inpType->isEqual(outType)) {
4001 resizeNode->getResult().replaceAllUsesOfWith(resizeNode->getInput());
4002 changed = true;
4003 continue;
4004 }
4005 // Dimensions which are resized from unitary sizes should use Tile nodes.
4006 // We pull out Tile nodes from the Resize output such that the Resize node
4007 // operates on smaller sizes thus reducing the complexity. We create Tile
4008 // nodes in the decreasing order of the dimensions to increase locality.
4009 auto inpDims = resizeNode->getInput().dims();
4010 auto outDims = resizeNode->getResult().dims();
4011 std::vector<dim_t> newOutDims = outDims.vec();
4012 std::vector<unsigned_t> axes;
4013 std::vector<unsigned_t> tiles;
4014 for (ssize_t idx = outDims.size() - 1; idx >= 0; idx--) {
4015 if ((inpDims[idx] == 1) && (outDims[idx] > 1)) {
4016 newOutDims[idx] = 1;
4017 axes.push_back(idx);
4018 tiles.push_back(outDims[idx]);
4019 }
4020 }
4021 CONTINUE_IF_NOT(axes.size());
4022 auto newOutType =
4023 F->getParent()->uniqueTypeWithNewShape(outType, newOutDims);
4024 // Create Resize node.
4025 NodeValue newOut = resizeNode->getInput();
4026 if (!inpType->isEqual(newOutType)) {
4027 if (std::is_same<ResizeNodeType, ResizeNearestNode>::value) {
4028 newOut = F->createResizeNearest(node.getName(), newOut, newOutType);
4029 } else if (std::is_same<ResizeNodeType, ResizeBilinearNode>::value) {
4030 newOut = F->createResizeBilinear(node.getName(), newOut, newOutType);
4031 } else {
4032 llvm_unreachable("Resize node type not supported!");
4033 }
4034 }
4035 // Create Tile nodes.
4036 newOut =
4037 F->createTile(node.getName().str() + "." + "Tile", newOut, tiles, axes);
4038 resizeNode->getResult().replaceAllUsesOfWith(newOut);
4039 changed = true;
4040 continue;
4041 }
4042 return changed;
4043}
4044
4045/// Optimize Resize nodes.
4046bool OptimizeResize::run(Function *F, const CompilationContext &cctx) {
4047 LOG_SCOPE(F->getLogContext(), getName());
4048 bool changed = false;
4049 changed |= optimizeResize<ResizeNearestNode>(F, cctx);
4050 changed |= optimizeResize<ResizeBilinearNode>(F, cctx);
4051 return changed;
4052}
4053
4054/// Optimize Insert nodes.
4055bool OptimizeInsert::run(Function *F, const CompilationContext &cctx) {
4056 LOG_SCOPE(F->getLogContext(), getName());
4057 bool changed = false;
4058 for (auto &node : F->getNodes()) {
4059 auto *insertNode = dyn_cast<InsertTensorNode>(&node);
4060 CONTINUE_IF_NOT(insertNode);
4061 // When the "Big" tensor is a Splat which is entirely filled by the
4062 // "Small" tensor then we replace the Splat with a Touch node to remove
4063 // the initialization overhead of the Splat which is not needed.
4064 NodeValue big = insertNode->getBig();
4065 NodeValue small = insertNode->getSmall();
4066 auto bigDims = big.dims().vec();
4067 auto smallDims = small.dims().vec();
4068 smallDims[insertNode->getAxis()] *= insertNode->getCount();
4069 CONTINUE_IF_NOT(isUniformArray(insertNode->getStart(), dim_t(0)) &&
4070 dyn_cast<SplatNode>(big) && (bigDims == smallDims));
4071 NodeValue touch =
4072 F->createTouch(node.getName().str() + "." + "Touch", big.getType());
4073 node.setNthInput(InsertTensorNode::BigIdx, touch);
4074 changed = true;
4075 continue;
4076 }
4077 return changed;
4078}
4079
4080/// Optimize: Max(Splat(), otherInput) or Max(otherInput, Splat()) for
4081/// quantized operations.
4082/// Splat and Max can be eliminated if Splat value cannot impact the result.
4083/// For example, Max and Splat can be removed if splat value is smaller
4084/// than quantization range [min, max].
4085/// \returns if anything was changed in the given function.
4086static bool optimizeQuantizedMaxSplat(Function *F) {
4087 LOG_SCOPE(F->getLogContext(), "optimizeQuantizedMaxSplat")
4088
4089 bool changed = false;
4090 // The following optimizations need to be performed after all
4091 // quantize/dequantize/rescale optimizations are done.
4092 for (auto &node : F->getNodes()) {
4093 // Potentially nop quantized Max can be eliminated.
4094 // Likely MaxNode has same types for LHS/RHS and Result, make sure
4095 // it's the case.
4096 if (auto *MN = dyn_cast<MaxNode>(&node)) {
4097 if (!MN->getResult().getType()->isQuantizedType() ||
4098 MN->getResult().getType() != MN->getLHS().getType() ||
4099 MN->getResult().getType() != MN->getRHS().getType()) {
4100 continue;
4101 }
4102
4103 // Check for input Splat node.
4104 if (!isa<SplatNode>(MN->getLHS()) && !isa<SplatNode>(MN->getRHS())) {
4105 continue;
4106 }
4107
4108 Node *splatNode =
4109 isa<SplatNode>(MN->getLHS()) ? MN->getLHS() : MN->getRHS();
4110 Node *otherInput =
4111 isa<SplatNode>(MN->getLHS()) ? MN->getRHS() : MN->getLHS();
4112
4113 // If splat value is smaller than values that can be covered by
4114 // quantization [min,max] range then just remove MaxNode operation.
4115 float splatValue = (dyn_cast<SplatNode>(splatNode))->getValue();
4116 float min = MN->getResult().getType()->getQuantizedValueRange().first;
4117 if (splatValue <= min) {
4118 changed = true;
4119 MN->getResult().replaceAllUsesOfWith(otherInput);
4120 }
4121 continue;
4122 }
4123 // Potentially nop quantized ReLU can be eliminated.
4124 if (auto *RN = dyn_cast<ReluNode>(&node)) {
4125 if (!RN->getResult().getType()->isQuantizedType() ||
4126 RN->getResult().getType() != RN->getInput().getType()) {
4127 continue;
4128 }
4129
4130 Node *input = RN->getInput();
4131
4132 // If zero is smaller or equal than values that can be covered by
4133 // quantization [min,max] range then just remove ReluNode operation.
4134 float min = RN->getResult().getType()->getQuantizedValueRange().first;
4135 if (0.0f <= min) {
4136 changed = true;
4137 RN->getResult().replaceAllUsesOfWith(input);
4138 }
4139 continue;
4140 }
4141 }
4142 return changed;
4143}
4144
4145/// \returns A value representing the given \p constant converted
4146/// to the destination type \p dstTy. If the conversion is not
4147/// possible, this method returns NodeValue().
4148static NodeValue convertConstant(Module &mod, Constant &constant,
4149 TypeRef dstTy) {
4150 // Sort out the easy case first.
4151 if (constant.getType() == dstTy) {
4152 return constant.getOutput();
4153 }
4154 auto modifyConstantTyAndGet = [&]() -> Constant & {
4155 Constant *oneUseCst = getUniquelyUsedConstant(&mod, constant);
4156 assert(oneUseCst &&
4157 "We should always be able to get a constant from a constant!");
4158 // This type setting updates the type of the node.
4159 // The underlying tensor still needs to be converted after this call.
4160 oneUseCst->setType(Storage::OutputIdx, dstTy);
4161 return *oneUseCst;
4162 };
4163 const Tensor &tensor = constant.getPayload();
4164 switch (tensor.getElementType()) {
4165 case ElemKind::FloatTy:
4166 case ElemKind::Float16Ty:
4167 case ElemKind::BFloat16Ty:
4168 switch (dstTy->getElementType()) {
4169 case ElemKind::FloatTy:
4170 case ElemKind::Float16Ty:
4171 case ElemKind::BFloat16Ty: {
4172 // Plain conversion:
4173 // {FloatTy, Float16Ty, BFloat16Ty} -> {FloatTy, Float16Ty, BFloat16Ty}.
4174 Constant &constantToBeModified = modifyConstantTyAndGet();
4175 constantToBeModified.getPayloadMutable().convertToType(
4176 dstTy->getElementType());
4177 return constantToBeModified.getOutput();
4178 }
4179 case ElemKind::Int64QTy:
4180 case ElemKind::Int32QTy:
4181 case ElemKind::Int16QTy:
4182 case ElemKind::Int8QTy: {
4183 // Quantization: {FloatTy, Float16Ty, BFloat16Ty} -> Quantized type.
4184 Constant &constantToBeModified = modifyConstantTyAndGet();
4185 TensorQuantizationParams params{dstTy->getScale(), dstTy->getOffset()};
4186 Tensor &tensorToBeModified = constantToBeModified.getPayloadMutable();
4187 // Right now we only quantize fp32 value.
4188 // Add an assert on that, so that if it changes, we adapt the
4189 // following code. Adapting the code would required to
4190 // teach quantizeTensor how to deal with Float16Ty.
4191 assert(tensor.getType().isFPType() &&
4192 "Type quantization not implemented");
4193 tensorToBeModified = quantization::quantizeTensor(
4194 tensorToBeModified, params, dstTy->getElementType());
4195 return constantToBeModified.getOutput();
4196 }
4197 case ElemKind::Int32ITy: {
4198 // Plain conversion: {FloatTy} -> {Int32ITy}.
4199 Constant &constantToBeModified = modifyConstantTyAndGet();
4200 constantToBeModified.getPayloadMutable().convertToType(
4201 dstTy->getElementType());
4202 return constantToBeModified.getOutput();
4203 }
4204 default:
4205 // Quantization: {FloatTy, Float16Ty, BFloat16Ty} -> Int[16|32]QTy.
4206 // Plain conversion: {FloatTy, Float16Ty, BFloat16Ty} -> Int64ITy.
4207 return NodeValue();
4208 }
4209 case ElemKind::UInt8FusedQTy: {
4210 if (dstTy->getElementType() != ElemKind::UInt8FusedFP16QTy) {
4211 return NodeValue();
4212 }
4213 auto *NC =
4214 mod.createConstant(dstTy, constant.getName(), constant.getLayout());
4215 NC->getPayloadMutable() =
4216 tensor.getCopyConvertedToType(dstTy->getElementType());
4217 return NC->getOutput();
4218 }
4219 case ElemKind::Int64ITy:
4220 case ElemKind::Int32ITy:
4221 switch (dstTy->getElementType()) {
4222 case ElemKind::Int32ITy:
4223 case ElemKind::Int64ITy: {
4224 // Plain conversion: {Int64ITy, Int32ITy} -> {Int64ITy, Int32ITy}.
4225 Constant &constantToBeModified = modifyConstantTyAndGet();
4226 constantToBeModified.getPayloadMutable().convertToType(
4227 dstTy->getElementType());
4228 return constantToBeModified.getOutput();
4229 }
4230 case ElemKind::FloatTy: {
4231 Constant &constantToBeModified = modifyConstantTyAndGet();
4232 constantToBeModified.getPayloadMutable().convertToType(
4233 dstTy->getElementType());
4234 return constantToBeModified.getOutput();
4235 }
4236
4237 default:
4238 return NodeValue();
4239 }
4240 default:
4241 // For now we don't see other quantize, dequantize, or rescale nodes
4242 // directly attached to constants.
4243 // Thus don't add code that will never be executed.
4244 // Dequantization: Int[8|16|32]QTy -> {FloatTy, Float16Ty, BFloat16Ty,
4245 // Int64I}. Rescale: Int[8|16|32]QTy -> Int[8|16|32]QTy. Plain conversion:
4246 // Int64ITy -> {FloatTy, Float16Ty, BFloat16Ty}. Quantization: Int64ITy ->
4247 // Int[8|16|32]QTy.
4248 return NodeValue();
4249 }
4250}
4251
4252/// Compute number of significant bits that are used to represent data of type
4253/// \p kind. For FP, it is the number of bits in mantissa, for integers it's the
4254/// number of bits except sign bit.
4255/// \p returns number of significant bits of \p kind.
4256/// TODO: Currently, for all supported types wider mantissa also means wider
4257/// exponent. If we add a type for which this is not true, we should check both
4258/// mantissa and exponent.
4259static size_t numSignificantBits(ElemKind kind) {
4260 switch (kind) {
4261 case ElemKind::BoolTy:
4262 return std::numeric_limits<bool>::digits;
4263 case ElemKind::Int8QTy:
4264 return std::numeric_limits<int8_t>::digits;
4265 case ElemKind::UInt8QTy:
4266 case ElemKind::UInt8FusedQTy:
4267 return std::numeric_limits<uint8_t>::digits;
4268 case ElemKind::Float16Ty:
4269 // Custom type with layout 0 00000 0000000000.
4270 return 10;
4271 case ElemKind::BFloat16Ty:
4272 // bfloat16 has 8 significant bits.
4273 return 8;
4274 case ElemKind::Int16QTy:
4275 return std::numeric_limits<int16_t>::digits;
4276 case ElemKind::FloatTy:
4277 return std::numeric_limits<float>::digits;
4278 case ElemKind::Float64Ty:
4279 return std::numeric_limits<double>::digits;
4280 case ElemKind::Int32QTy:
4281 case ElemKind::Int32ITy:
4282 return std::numeric_limits<int32_t>::digits;
4283 case ElemKind::Int64ITy:
4284 return std::numeric_limits<int64_t>::digits;
4285 default:
4286 // Avoid compiler warning.
4287 break;
4288 }
4289 llvm_unreachable("Unknown type!");
4290}
4291
4292/// Returns true if casting value from \p srcTy to \p destTy may change it. As
4293/// implication of this, casting value from \p srcTy to \p destTy and back may
4294/// produce different value than before cast.
4295static bool isValueChangingCast(TypeRef srcTy, TypeRef destTy) {
4296 // FP-to-Int conversion may lead to loss of fraction, so it's not NOOP.
4297 if (srcTy->isFPType() && !destTy->isFPType()) {
4298 return true;
4299 }
4300 // Narrowing transform (e.g. int64 to int32) may lead to loss of
4301 // significant senior bits, so it's not NOOP.
4302 ElemKind srcElKind = srcTy->getElementType();
4303 ElemKind convElKind = destTy->getElementType();
4304 if (numSignificantBits(srcElKind) > numSignificantBits(convElKind)) {
4305 return true;
4306 }
4307 return false;
4308}
4309
4310/// Optimize away redundant ClipNodes.
4311/// We basically turn "Clip(Clip(Clip(A)))" to "Clip(A)".
4312bool OptimizeClips::run(Function *F, const CompilationContext &cctx) {
4313 LOG_SCOPE(F->getLogContext(), getName());
4314
4315 bool changed = false;
4316 for (Node &node : F->getNodes()) {
4317 ClipNode *clip = dyn_cast<ClipNode>(&node);
4318 if (!clip) {
4319 continue;
4320 }
4321 float min = clip->getMin();
4322 float max = clip->getMax();
4323 if (auto *clipPrev = dyn_cast<ClipNode>(clip->getInput().getNode())) {
4324 float minPrev = clipPrev->getMin();
4325 float maxPrev = clipPrev->getMax();
4326 auto *newClip =
4327 F->createClip(clipPrev->getName(), clipPrev->getInput().getNode(),
4328 std::max(minPrev, min), std::min(maxPrev, max));
4329 clip->getResult().replaceAllUsesOfWith(newClip);
4330 changed = true;
4331 continue;
4332 }
4333
4334 // We can fold Clip(Relu) -> Clip'
4335 if (ReluNode *relu = dyn_cast<ReluNode>(clip->getInput())) {
4336 const float newMin = std::max(0.0f, min);
4337 ClipNode *newClip = F->createClip(clip->getName().str() + "_relu",
4338 relu->getInput(), newMin, max);
4339 clip->getResult().replaceAllUsesOfWith(newClip->getResult());
4340 changed = true;
4341 continue;
4342 }
4343 }
4344
4345 return changed;
4346}
4347
4348/// \returns whether \p N used used by any Nodes with side effects.
4349static bool isUsedByNodeWithSideEffects(Node *N) {
4350 for (const auto &user : N->getUsers()) {
4351 if (user.getUser()->hasSideEffects()) {
4352 return true;
4353 }
4354 }
4355 return false;
4356}
4357
4358/// Helper that \returns whether \p NV cannot have its output quantization
4359/// parameters changed. For example, Concats require all inputs to have the same
4360/// quantization parameters, so we cannot change the quantization parameters of
4361/// a Node if it is input into a Concat.
4362static bool disallowQuantParamChange(const NodeValue &NV) {
4363 for (auto &user : NV.getUsers()) {
4364 if (isa<ConcatNode>(user.getUser())) {
4365 return true;
4366 }
4367 }
4368 return false;
4369}
4370
4371/// This is a specialized pass to use where we assume that quantized ranges are
4372/// all inside the FP16 range. This means that if we have any clips outside the
4373/// FP16 range we can safely remove them if adjacent to a quantized op.
4374bool EliminateClipsOutsideFP16Range::run(Function *F,
4375 const CompilationContext &cctx) {
4376 LOG_SCOPE(F->getLogContext(), getName());
4377
4378 if (!cctx.precisionConfig.clipQuantRangeToFP16) {
4379 return false;
4380 }
4381
4382 bool changed = false;
4383 for (Node &node : F->getNodes()) {
4384 // Clip(Dequantize(Node)) -> Dequantize(Node)
4385 if (ClipNode *clip = dyn_cast<ClipNode>(&node)) {
4386 DequantizeNode *DQN = dyn_cast<DequantizeNode>(clip->getInput());
4387 if (!DQN) {
4388 continue;
4389 }
4390
4391 // Can only eliminate the clip if its outside the FP16 range.
4392 if (clip->getMin() > kMinFP16 || clip->getMax() < kMaxFP16) {
4393 continue;
4394 }
4395
4396 // We can safely skip the Clip at this point.
4397 clip->getResult().replaceAllUsesOfWith(DQN->getResult());
4398 changed = true;
4399 continue;
4400 }
4401
4402 // Quantize(Clip(Node)) -> Quantize(Node)
4403 if (QuantizeNode *QN = dyn_cast<QuantizeNode>(&node)) {
4404 ClipNode *clip = dyn_cast<ClipNode>(QN->getInput());
4405 if (!clip) {
4406 continue;
4407 }
4408
4409 // Can only eliminate the clip if its outside the FP16 range.
4410 if (clip->getMin() > kMinFP16 || clip->getMax() < kMaxFP16) {
4411 continue;
4412 }
4413
4414 // We can safely skip the Clip at this point.
4415 QN->setNthInput(QuantizeNode::InputIdx, clip->getInput());
4416 changed = true;
4417 continue;
4418 }
4419 }
4420
4421 return changed;
4422}
4423
4424/// When quantized operators and Clips are used together, we can often merge the
4425/// Clip range and the Quantized range and remove the Clip.
4426bool OptimizeQuantizeClip::run(Function *F, const CompilationContext &cctx) {
4427 LOG_SCOPE(F->getLogContext(), getName());
4428
4429 // All of the optimizations here depend on the quantization parameters. If
4430 // we've loaded dummy qparams then none should be performed.
4431 if (cctx.precisionConfig.loadUniquedDummyQParams) {
4432 return false;
4433 }
4434
4435 bool changed = false;
4436
4437 // Change a quantized result type qResult to account for the range from clip.
4438 auto updateQuantizeNodeType = [](Function *F, const CompilationContext &cctx,
4439 NodeValue qResult, ClipNode *clip,
4440 bool skipIfQuantParamChange,
4441 bool allowQParamChange) {
4442 const auto qMinMax = qResult.getType()->getQuantizedValueRange();
4443 const float newMin = std::max(clip->getMin(), qMinMax.first);
4444 const float newMax = std::min(clip->getMax(), qMinMax.second);
4445
4446 // If the quantization parameters do not change then we can always elide the
4447 // Clip and do not need to change the type of qResult.
4448 if (newMin == qMinMax.first && newMax == qMinMax.second) {
4449 return true;
4450 }
4451
4452 if (disallowQuantParamChange(qResult)) {
4453 return false;
4454 }
4455
4456 // At this point the quantization parameters must be changing, so if we do
4457 // not allow for that then return false.
4458 if (!allowQParamChange || skipIfQuantParamChange) {
4459 return false;
4460 }
4461
4462 // Replace the old quantized type with the new type with different
4463 // min/max.
4464 const TypeRef oldTy = qResult.getType();
4465 const auto qParams = quantization::chooseQuantizationParams(
4466 {newMin, newMax}, cctx.precisionConfig.quantConfig.schema,
4467 oldTy->getElementType());
4468 const TypeRef newTy = F->getParent()->uniqueType(
4469 oldTy->getElementType(), oldTy->dims(), qParams.scale, qParams.offset);
4470 qResult.getNode()->setType(qResult.getResNo(), newTy);
4471 return true;
4472 };
4473
4474 for (Node &node : F->getNodes()) {
4475 // Clip(Dequantize(Node)) -> Dequantize(Node)
4476 if (ClipNode *clip = dyn_cast<ClipNode>(&node)) {
4477 DequantizeNode *DQN = dyn_cast<DequantizeNode>(clip->getInput());
4478 if (!DQN) {
4479 continue;
4480 }
4481
4482 // Cannot perform this optimization if there are multiple users of DQN or
4483 // DQN's input, as otherwise they'd have incorrect quantization params.
4484 NodeValue qResult = DQN->getInput();
4485 const bool skipIfQuantParamChange =
4486 DQN->getNumUsers() != 1 || qResult.getNode()->getNumUsers() != 1;
4487
4488 // Try to update the quantize's type, otherwise skip this one.
4489 if (!updateQuantizeNodeType(
4490 F, cctx, qResult, clip, skipIfQuantParamChange,
4491 cctx.optimizationOpts.enableQuantParamChanges)) {
4492 continue;
4493 }
4494
4495 // Now we skip the Clip since the node prior to DQN has included the
4496 // Clip's range in its quantization parameters.
4497 clip->getResult().replaceAllUsesOfWith(DQN->getResult());
4498 changed = true;
4499 continue;
4500 }
4501
4502 // Quantize(Clip(Node)) -> Quantize(Node)
4503 if (QuantizeNode *QN = dyn_cast<QuantizeNode>(&node)) {
4504 ClipNode *clip = dyn_cast<ClipNode>(QN->getInput());
4505 if (!clip) {
4506 continue;
4507 }
4508
4509 // Cannot set the type of quantized nodes if they're used by a Node with
4510 // side effects, as they may be expecting a specific type.
4511 const bool skipIfQuantParamChange = isUsedByNodeWithSideEffects(QN);
4512
4513 // Try to update the quantize's type, otherwise skip this one.
4514 if (!updateQuantizeNodeType(
4515 F, cctx, QN->getResult(), clip, skipIfQuantParamChange,
4516 cctx.optimizationOpts.enableQuantParamChanges)) {
4517 continue;
4518 }
4519
4520 // Now we can skip the Clip since the QN has accounted for the Clip's
4521 // range in its quantization parameters.
4522 QN->setNthInput(QuantizeNode::InputIdx, clip->getInput());
4523 changed = true;
4524 continue;
4525 }
4526 }
4527
4528 return changed;
4529}
4530
4531/// Optimize away ConvertToNode.
4532/// This basically turns "conversion(conversion A to B) to C"
4533/// into noop if all of the conditions below are met:
4534/// - the type of A and C are the same;
4535/// - A->B is not a FP-to-Int conversion;
4536/// - A->B is not a narrowing conversion.
4537bool OptimizeConversions::run(Function *F, const CompilationContext &cctx) {
4538 LOG_SCOPE(F->getLogContext(), getName());
4539
4540 bool changed = false;
4541 for (auto &node : F->getNodes()) {
4542 if (auto *CN = llvm::dyn_cast<ConvertToNode>(&node)) {
4543
4544 // Eliminate no-op conversion.
4545 if (CN->getInput().getType() == CN->getResult().getType()) {
4546 CN->getResult().replaceAllUsesOfWith(CN->getInput());
4547 changed = true;
4548 continue;
4549 }
4550
4551 // Perform conversion of constants.
4552 if (auto *BN = llvm::dyn_cast<Constant>(CN->getInput())) {
4553 auto newConst =
4554 convertConstant(*F->getParent(), *BN, CN->getResult().getType());
4555 if (newConst == NodeValue()) {
4556 continue;
4557 }
4558 CN->getResult().replaceAllUsesOfWith(newConst, F);
4559 changed = true;
4560 continue;
4561 }
4562
4563 // Simplify a chain of conversions A -> B -> C to A -> C, unless A -> B
4564 // is a narrowing cast.
4565 if (auto *BN = llvm::dyn_cast<ConvertToNode>(CN->getInput())) {
4566 auto AN = BN->getInput();
4567
4568 // Do not optimize away narrowing casts.
4569 if (!isValueChangingCast(AN.getType(), BN->getResult().getType())) {
4570 auto *newCast =
4571 F->createConvertTo(CN->getName(), AN, CN->getResult().getType());
4572 CN->getResult().replaceAllUsesOfWith(newCast);
4573 changed = true;
4574 continue;
4575 }
4576 }
4577 }
4578 }
4579 return changed;
4580}
4581
4582/// Optimize patterns of Int8 quantization/dequantization with ConvertTo. This
4583/// may have numerical differences but since Int8 has a small range it's likely
4584/// fine. This is opt in by a backend.
4585bool OptimizeOutIntermediateConversions::run(Function *F,
4586 const CompilationContext &cctx) {
4587 LOG_SCOPE(F->getLogContext(), getName());
4588
4589 bool changed = false;
4590 for (auto &node : F->getNodes()) {
4591 // Quantize(ConvertTo(Node)) -> Quantize(Node), where Quantize is int8
4592 if (QuantizeNode *QN = llvm::dyn_cast<QuantizeNode>(&node)) {
4593 if (QN->getResult().getType()->getElementType() != ElemKind::Int8QTy) {
4594 continue;
4595 }
4596
4597 ConvertToNode *CN = llvm::dyn_cast<ConvertToNode>(QN->getInput());
4598 if (!CN) {
4599 continue;
4600 }
4601
4602 QN->setNthInput(QuantizeNode::InputIdx, CN->getInput());
4603 changed = true;
4604 continue;
4605 }
4606
4607 // ConvertTo(Dequantize(Node)) -> Dequantize(Node), where Dequantize is int8
4608 if (ConvertToNode *CN = llvm::dyn_cast<ConvertToNode>(&node)) {
4609 DequantizeNode *DN = llvm::dyn_cast<DequantizeNode>(CN->getInput());
4610 if (!DN ||
4611 DN->getInput().getType()->getElementType() != ElemKind::Int8QTy) {
4612 continue;
4613 }
4614
4615 // Create new Dequantize node, dequantizing directly to the kind of the
4616 // ConverTo that originally consumed it.
4617 DequantizeNode *newDN = F->createDequantize(
4618 DN->getName(), DN->getInput(), CN->getResult().getElementType());
4619 CN->getResult().replaceAllUsesOfWith(newDN->getResult());
4620 changed = true;
4621 continue;
4622 }
4623 }
4624
4625 return changed;
4626}
4627
4628// Look for float Relus that we can fuse up into quantized FCs. This is either
4629// with a Dequantize between them, or a Concat with multiple FCs being
4630// dequantized and concatenated together.
4631bool OptimizeQuantFCFloatRelu::run(Function *F,
4632 const CompilationContext &cctx) {
4633 LOG_SCOPE(F->getLogContext(), getName());
4634
4635 // This opt implies there to be changes to quantization, because we create an
4636 // int relu that was previously float.
4637 if (!cctx.optimizationOpts.enableQuantParamChanges) {
4638 return false;
4639 }
4640
4641 bool changed = false;
4642 for (auto &node : F->getNodes()) {
4643 auto *relu = llvm::dyn_cast<ReluNode>(&node);
4644 // Look for Float relus to start.
4645 if (!relu ||
4646 !isFloatElemKind(relu->getResult().getType()->getElementType())) {
4647 continue;
4648 }
4649
4650 // Now look for dequantize nodes. We may need to move this above a Concat.
4651 // Check if necessary.
4652 std::vector<FullyConnectedNode *> nodesToFuse;
4653 if (auto *CN = llvm::dyn_cast<ConcatNode>(relu->getInput())) {
4654 if (CN->getNumUsers() != 1) {
4655 continue;
4656 }
4657
4658 // Check if all the concat inputs are dequantized FCs.
4659 for (const NodeValue &NV : CN->getInputs()) {
4660 auto *DQ = llvm::dyn_cast<DequantizeNode>(NV);
4661 if (!DQ || DQ->getNumUsers() != 1) {
4662 break;
4663 }
4664 auto *FC = llvm::dyn_cast<FullyConnectedNode>(DQ->getInput());
4665 if (!FC || FC->getNumUsers() != 1) {
4666 break;
4667 }
4668 nodesToFuse.push_back(FC);
4669 }
4670 if (nodesToFuse.size() != CN->getInputs().size()) {
4671 continue;
4672 }
4673 } else if (auto *DQ = llvm::dyn_cast<DequantizeNode>(relu->getInput())) {
4674 if (DQ->getNumUsers() != 1) {
4675 continue;
4676 }
4677
4678 auto *FC = llvm::dyn_cast<FullyConnectedNode>(DQ->getInput());
4679 if (!FC || FC->getNumUsers() != 1) {
4680 break;
4681 }
4682 nodesToFuse.push_back(FC);
4683 } else {
4684 continue;
4685 }
4686
4687 // Did not find any quantized FCs to fuse, so continue.
4688 if (!nodesToFuse.size()) {
4689 continue;
4690 }
4691
4692 // If the quant FC is used by nodes with side effects then skip, since we
4693 // may be changing the user's input qparam type.
4694 bool skip = false;
4695 for (FullyConnectedNode *FC : nodesToFuse) {
4696 if (isUsedByNodeWithSideEffects(FC)) {
4697 skip = true;
4698 break;
4699 }
4700 }
4701 if (skip) {
4702 continue;
4703 }
4704
4705 // Now add quantized relus onto all of the FCs.
4706 for (FullyConnectedNode *FC : nodesToFuse) {
4707 const TypeRef FCTy = FC->getResult().getType();
4708 TypeRef qReluTy = nullptr;
4709 if (cctx.precisionConfig.loadUniquedDummyQParams) {
4710 // Reuse the FC type, since we don't know its actual qparams during AOT
4711 // optimization. When we the actual type later before deploying, we will
4712 // do a final processing pass to set min to 0 via updateReluTypes().
4713 qReluTy = FCTy;
4714 } else {
4715 // Use the same type as the FC for the Relu but with 0 as min.
4716 const auto qParams = quantization::chooseQuantizationParams(
4717 {0, FCTy->getQuantizedValueRange().second},
4718 cctx.precisionConfig.quantConfig.schema, FCTy->getElementType());
4719 qReluTy =
4720 F->getParent()->uniqueType(FCTy->getElementType(), FCTy->dims(),
4721 qParams.scale, qParams.offset);
4722 }
4723 ReluNode *qRelu = F->createRELU(relu->getName().str() + "_quant",
4724 FC->getResult(), qReluTy);
4725 FC->getResult().typeUnsafeReplaceAllUsesOfWith(qRelu->getResult(), F,
4726 qRelu);
4727 }
4728
4729 // Now we can get rid of the relu.
4730 relu->getResult().replaceAllUsesOfWith(relu->getInput());
4731 changed = true;
4732 continue;
4733 }
4734
4735 return changed;
4736}
4737
4738/// Look for Concats with all Dequantization as input and Quantization as
4739/// output, and change the Quantization/Dequantization into a rescale.
4740bool OptimizeConcatQuantization::run(Function *F,
4741 const CompilationContext &cctx) {
4742 LOG_SCOPE(F->getLogContext(), getName());
4743
4744 bool changed = false;
4745 for (auto &node : F->getNodes()) {
4746 auto *CN = dyn_cast<ConcatNode>(&node);
4747 if (!CN) {
4748 continue;
4749 }
4750
4751 // Look for a single Quantize user.
4752 if (CN->getUsers().size() != 1) {
4753 continue;
4754 }
4755 auto *QN = dyn_cast<QuantizeNode>((*CN->getUsers().begin()).getUser());
4756 if (!QN) {
4757 continue;
4758 }
4759
4760 // Gather/check all of the inputs are DequantizeNodes.
4761 std::vector<DequantizeNode *> DNs;
4762 DNs.reserve(CN->getInputs().size());
4763 for (const NodeValue &NV : CN->getInputs()) {
4764 auto *DN = dyn_cast<DequantizeNode>(NV);
4765 if (!DN || DN->getNumUsers() != 1) {
4766 break;
4767 }
4768 DNs.push_back(DN);
4769 }
4770
4771 // If not all CN inputs are Dequantizes then skip.
4772 if (DNs.size() != CN->getInputs().size()) {
4773 continue;
4774 }
4775
4776 // Now create Rescales instead of Dequantizes for all CN inputs.
4777 std::vector<NodeValue> newConcatInputs;
4778 newConcatInputs.reserve(DNs.size());
4779 TypeRef QNTy = QN->getResult().getType();
4780 for (DequantizeNode *DN : DNs) {
4781 if (DN->getInput().getType()->getScale() == QNTy->getScale() &&
4782 DN->getInput().getType()->getOffset() == QNTy->getOffset()) {
4783 // Don't need to rescale as it already has the right scale/offset.
4784 newConcatInputs.push_back(DN->getInput());
4785 } else {
4786 TypeRef newTy = F->getParent()->uniqueTypeWithNewShape(
4787 QNTy, DN->getResult().dims());
4788 auto *RS = F->createRescaleQuantized(DN->getName().str() + "_rescale",
4789 DN->getInput(), newTy);
4790 newConcatInputs.push_back(RS->getResult());
4791 }
4792 }
4793
4794 auto *newCN = F->createConcat(CN->getName(), newConcatInputs, CN->getDim());
4795
4796 // Now we can get rid of the Quantize after the CN.
4797 QN->getResult().replaceAllUsesOfWith(newCN->getResult());
4798 changed = true;
4799 continue;
4800 }
4801
4802 return changed;
4803}
4804
4805/// \returns a cloned version of node \p N, but with each of the cloned node's
4806/// output types set to the corresponding type in \p types. The new node is
4807/// added to Function \p F.
4808/// \pre types.size() == N->getNumResults()
4809static Node *cloneNodeWithNewTypes(Function *F, Node *N,
4810 llvm::ArrayRef<TypeRef> types) {
4811 assert(N->getNumResults() == types.size() &&
4812 "Number of types must equal number of results of the node.");
4813
4814 Node *newNode = F->addNode(N->clone());
4815 for (size_t i = 0; i < types.size(); i++) {
4816 newNode->setType(i, types[i]);
4817 }
4818
4819 return newNode;
4820}
4821
4822template <class T, class U>
4823using enable_if_same_t = std::enable_if<std::is_same<T, U>::value, U>;
4824#define FUNCTION_ENABLE_IF_TEMPLATE(NODE_NAME_) \
4825 template <class T, typename... Args> \
4826 typename enable_if_same_t<T, NODE_NAME_##Node>::type static
4827
4828FUNCTION_ENABLE_IF_TEMPLATE(AvgPool) * createNode(Function &F, Args... args) {
4829 return F.createAvgPool(args...);
4830}
4831FUNCTION_ENABLE_IF_TEMPLATE(MaxPool) * createNode(Function &F, Args... args) {
4832 return F.createMaxPool(args...);
4833}
4834FUNCTION_ENABLE_IF_TEMPLATE(Add)
4835*createNode(Function &F, Args... args) { return F.createAdd(args...); }
4836FUNCTION_ENABLE_IF_TEMPLATE(Sub)
4837*createNode(Function &F, Args... args) { return F.createSub(args...); }
4838FUNCTION_ENABLE_IF_TEMPLATE(Mul)
4839*createNode(Function &F, Args... args) { return F.createMul(args...); }
4840FUNCTION_ENABLE_IF_TEMPLATE(Div)
4841*createNode(Function &F, Args... args) { return F.createDiv(args...); }
4842FUNCTION_ENABLE_IF_TEMPLATE(Min)
4843*createNode(Function &F, Args... args) { return F.createMin(args...); }
4844FUNCTION_ENABLE_IF_TEMPLATE(Max)
4845*createNode(Function &F, Args... args) { return F.createMax(args...); }
4846FUNCTION_ENABLE_IF_TEMPLATE(MatMul)
4847*createNode(Function &F, Args... args) { return F.createMatMul(args...); }
4848
4849FUNCTION_ENABLE_IF_TEMPLATE(AvgPool) *
4850 createNewPool(Function &F, T *PN, RescaleQuantizedNode *rescale) {
4851 return createNode<T>(F, PN->getName(), rescale->getInput(), PN->getKernels(),
4852 PN->getStrides(), PN->getPads(), NCHW,
4853 PN->getCountIncludePads());
4854}
4855FUNCTION_ENABLE_IF_TEMPLATE(MaxPool) *
4856 createNewPool(Function &F, T *PN, RescaleQuantizedNode *rescale) {
4857 return createNode<T>(F, PN->getName(), rescale->getInput(), PN->getKernels(),
4858 PN->getStrides(), PN->getPads(),
4859 PN->getArgmax().getElementType());
4860}
4861
4862/// Sink Rescale down with Pooling node.
4863/// PoolingNode(Rescale(X)) -> Rescale(PoolingNode(X)).
4864/// Apply this transformation for AvgPool and MaxPool.
4865template <typename T>
4866static bool sinkDownRescaleToPoolingNode(Function &F, T *PN) {
4867 LOG_SCOPE(F.getLogContext(), "sinkDownRescaleToPoolingNode")
4868
4869 bool changed = false;
4870
4871 if (auto *rescale = dyn_cast<RescaleQuantizedNode>(PN->getInput())) {
4872 T *newPN = createNewPool(F, PN, rescale);
4873 auto rescaleOutTy = F.getParent()->uniqueTypeWithNewShape(
4874 rescale->getResult().getType(), PN->getResult().getType());
4875 auto *newRescale = F.createRescaleQuantized(
4876 rescale->getName(), newPN->getResult(), rescaleOutTy);
4877 PN->getResult().replaceAllUsesOfWith(newRescale);
4878 for (size_t i = 1; i < PN->getNumResults(); i++) {
4879 PN->getNthResult(i).replaceAllUsesOfWith(newPN->getNthResult(i));
4880 }
4881 changed = true;
4882 }
4883
4884 return changed;
4885}
4886
4887/// Combine Rescale down with Arithmetic node.
4888/// ArithmeticNode(Rescale(X), Rescale(Y)) -> ArithmeticNode(X, Y).
4889/// ArithmeticNode(Rescale(X), Y) -> ArithmeticNode(X, Y).
4890/// ArithmeticNode(X, Rescale(Y)) -> ArithmeticNode(X, Y).
4891/// Apply this optimization for Add, Sub, Mul, Div, Min, Max.
4892template <typename T>
4893static bool combineDownRescaleToArithmeticNode(Function &F, T *AN) {
4894 LOG_SCOPE(F.getLogContext(), "combineDownRescaleToArithmeticNode")
4895
4896 bool changed = false;
4897
4898 if (auto *rescale = dyn_cast<RescaleQuantizedNode>(AN->getLHS())) {
4899 T *newAN = createNode<T>(F, AN->getName(), AN->getResult().getType(),
4900 rescale->getInput(), AN->getRHS());
4901 AN->getResult().replaceAllUsesOfWith(newAN);
4902 AN = newAN;
4903 changed = true;
4904 }
4905 if (auto *rescale = dyn_cast<RescaleQuantizedNode>(AN->getRHS())) {
4906 T *newAN = createNode<T>(F, AN->getName(), AN->getResult().getType(),
4907 AN->getLHS(), rescale->getInput());
4908 AN->getResult().replaceAllUsesOfWith(newAN);
4909 changed = true;
4910 }
4911
4912 return changed;
4913}
4914
4915/// Sink Rescale nodes down when possible.
4916/// \returns if anything was changed in the given function.
4917static bool sinkRescaleQuantizedNode(Function *F,
4918 const CompilationContext &cctx) {
4919 LOG_SCOPE(F->getLogContext(), "sinkRescaleQuantizedNode");
4920 bool changed = false;
4921 for (auto &node : F->getNodes()) {
4922 // Sink Rescale below Reshape node.
4923 // Reshape(Rescale(X)) -> Rescale(Reshape(X)).
4924 if (auto *reshape = dyn_cast<ReshapeNode>(&node)) {
4925 auto *rescale = dyn_cast<RescaleQuantizedNode>(reshape->getInput());
4926 if (!rescale) {
4927 continue;
4928 }
4929
4930 auto *newReshape =
4931 F->createReshape(reshape->getName(), rescale->getInput(),
4932 reshape->getResult().dims(), reshape->getLayout());
4933 auto *newRescale = F->createRescaleQuantized(
4934 rescale->getName(), newReshape, reshape->getResult().getType());
4935 reshape->getResult().replaceAllUsesOfWith(newRescale);
4936
4937 changed = true;
4938 continue;
4939 }
4940
4941 // Sink Rescale below Slice node.
4942 // Slice(Rescale(X)) -> Rescale(Slice(X)).
4943 if (auto *slice = dyn_cast<SliceNode>(&node)) {
4944 auto *rescale = dyn_cast<RescaleQuantizedNode>(slice->getInput());
4945 if (!rescale) {
4946 continue;
4947 }
4948
4949 auto sliceOutTy = F->getParent()->uniqueTypeWithNewShape(
4950 rescale->getInput().getType(), slice->getResult().getType());
4951 auto *newSlice = F->createSlice(slice->getName(), rescale->getInput(),
4952 slice->getStart(), sliceOutTy);
4953 auto *newRescale = F->createRescaleQuantized(
4954 rescale->getName(), newSlice, slice->getResult().getType());
4955 slice->getResult().replaceAllUsesOfWith(newRescale);
4956
4957 changed = true;
4958 continue;
4959 }
4960
4961 // Sink Rescale below Transpose node.
4962 // Transpose(Rescale(X)) -> Rescale(Transpose(X)).
4963 if (auto *transpose = dyn_cast<TransposeNode>(&node)) {
4964 auto *rescale = dyn_cast<RescaleQuantizedNode>(transpose->getInput());
4965 if (!rescale) {
4966 continue;
4967 }
4968
4969 auto *newTranspose =
4970 F->createTranspose(transpose->getName(), rescale->getInput(),
4971 transpose->getShuffle(), transpose->getLayout());
4972 auto rescaleOutTy = F->getParent()->uniqueTypeWithNewShape(
4973 rescale->getResult().getType(), transpose->getResult().getType());
4974 auto *newRescale = F->createRescaleQuantized(rescale->getName(),
4975 newTranspose, rescaleOutTy);
4976 transpose->getResult().replaceAllUsesOfWith(newRescale);
4977
4978 changed = true;
4979 continue;
4980 }
4981
4982 if (auto *PN = dyn_cast<AvgPoolNode>(&node)) {
4983 // AvgPool input and output scale/bias may differ.
4984 if (!cctx.optimizationOpts.enableQuantParamChanges) {
4985 changed |= sinkDownRescaleToPoolingNode<AvgPoolNode>(*F, PN);
4986 }
4987 continue;
4988 }
4989
4990 if (auto *PN = dyn_cast<MaxPoolNode>(&node)) {
4991 changed |= sinkDownRescaleToPoolingNode<MaxPoolNode>(*F, PN);
4992 continue;
4993 }
4994
4995 // Combine Rescale down with FullyConnected node.
4996 // FullyConnected(Rescale(X)) -> FullyConnected(X).
4997 if (auto *FC = dyn_cast<FullyConnectedNode>(&node)) {
4998 auto *rescale = dyn_cast<RescaleQuantizedNode>(FC->getInput());
4999 if (!rescale) {
5000 continue;
5001 }
5002
5003 auto *newFC = F->createFullyConnected(FC->getName(), rescale->getInput(),
5004 FC->getWeights(), FC->getBias(),
5005 FC->getResult().getType());
5006 FC->getResult().replaceAllUsesOfWith(newFC);
5007
5008 changed = true;
5009 continue;
5010 }
5011
5012 // Combine Rescale down with Convolution node.
5013 // Convolution(Rescale(X), F, B) -> Convolution(X, F, B).
5014 // Convolution(X, Rescale(F), B) -> Convolution(X, F, B).
5015 // Convolution(X, F, Rescale(B)) -> Convolution(X, F, B).
5016 // ... and different combinations.
5017 if (auto *CN = dyn_cast<ConvolutionNode>(&node)) {
5018 auto *rescaleX = dyn_cast<RescaleQuantizedNode>(CN->getInput());
5019 auto *rescaleF = dyn_cast<RescaleQuantizedNode>(CN->getFilter());
5020 auto *rescaleB = dyn_cast<RescaleQuantizedNode>(CN->getBias());
5021 auto newX = rescaleX ? rescaleX->getInput() : CN->getInput();
5022 auto newF = rescaleF ? rescaleF->getInput() : CN->getFilter();
5023 auto newB = rescaleB ? rescaleB->getInput() : CN->getBias();
5024 if (rescaleX || rescaleF || rescaleB) {
5025 auto *newCN = F->createConv(CN->getName(), newX, newF, newB,
5026 CN->getResult().getType(), CN->getKernels(),
5027 CN->getStrides(), CN->getPads(),
5028 CN->getGroup(), CN->getDilation());
5029 newCN->setFusedActivation(CN->getFusedActivation());
5030 newCN->setFusedActivationArgs(CN->getFusedActivationArgs());
5031
5032 CN->getResult().replaceAllUsesOfWith(newCN);
5033 changed = true;
5034 }
5035 continue;
5036 }
5037
5038 if (auto *AN = dyn_cast<AddNode>(&node)) {
5039 changed |= combineDownRescaleToArithmeticNode<AddNode>(*F, AN);
5040 continue;
5041 }
5042 if (auto *AN = dyn_cast<SubNode>(&node)) {
5043 changed |= combineDownRescaleToArithmeticNode<SubNode>(*F, AN);
5044 continue;
5045 }
5046 if (auto *AN = dyn_cast<MulNode>(&node)) {
5047 changed |= combineDownRescaleToArithmeticNode<MulNode>(*F, AN);
5048 continue;
5049 }
5050 if (auto *AN = dyn_cast<DivNode>(&node)) {
5051 changed |= combineDownRescaleToArithmeticNode<DivNode>(*F, AN);
5052 continue;
5053 }
5054 if (auto *AN = dyn_cast<MinNode>(&node)) {
5055 changed |= combineDownRescaleToArithmeticNode<MinNode>(*F, AN);
5056 continue;
5057 }
5058 if (auto *AN = dyn_cast<MaxNode>(&node)) {
5059 changed |= combineDownRescaleToArithmeticNode<MaxNode>(*F, AN);
5060 continue;
5061 }
5062
5063 // Combine Rescale down with Relu node.
5064 // ReluNode(Rescale(in)) -> ReluNode(in).
5065 if (auto *RN = dyn_cast<ReluNode>(&node)) {
5066 if (auto *rescale = dyn_cast<RescaleQuantizedNode>(RN->getInput())) {
5067 auto *newRN = F->createRELU(RN->getName(), rescale->getInput(),
5068 RN->getResult().getType());
5069 RN->getResult().replaceAllUsesOfWith(newRN);
5070 changed = true;
5071 }
5072 continue;
5073 }
5074
5075 if (auto *MN = dyn_cast<MatMulNode>(&node)) {
5076 changed |= combineDownRescaleToArithmeticNode<MatMulNode>(*F, MN);
5077 continue;
5078 }
5079 }
5080
5081 return changed;
5082}
5083
5084/// Eliminate node sequences that are related to quantization.
5085/// \returns if anything was changed in the given function.
5086bool OptimizeQuantization::run(Function *F, const CompilationContext &cctx) {
5087 LOG_SCOPE(F->getLogContext(), getName());
5088 bool changed = false;
5089 // A worklist that contains the nodes to process.
5090 std::vector<Node *> worklist;
5091
5092 // Add all of the interesting nodes to the worklist.
5093 for (auto &node : F->getNodes()) {
5094 if (isa<QuantizeNode>(node) || isa<DequantizeNode>(node) ||
5095 isa<RescaleQuantizedNode>(node)) {
5096 worklist.push_back(&node);
5097 }
5098 }
5099
5100 while (!worklist.empty()) {
5101 // Take a node from the worklist.
5102 Node *node = worklist.back();
5103 worklist.pop_back();
5104
5105 if (auto *Q = dyn_cast<QuantizeNode>(node)) {
5106 if (auto *DQ = dyn_cast<DequantizeNode>(Q->getInput())) {
5107 // Quantize(Dequantize(X)) -> RescaleQuantized(X)
5108 // If the quantization-dequantization sequence does not change the
5109 // type then we can simply drop them without adding a requantization
5110 // node.
5111 changed = true;
5112 if (DQ->getInput().getType() == Q->getResult().getType()) {
5113 Q->getResult().replaceAllUsesOfWith(DQ->getInput());
5114 continue;
5115 }
5116
5117 auto *RS = F->createRescaleQuantized(Q->getName(), DQ->getInput(),
5118 Q->getResult().getType());
5119 Q->getResult().replaceAllUsesOfWith(RS);
5120
5121 // We may be able to optimize this rescale node. Remember to visit
5122 // this new node and try to optimize it later.
5123 worklist.push_back(RS);
5124 continue;
5125 }
5126
5127 if (auto *SN = dyn_cast<SplatNode>(Q->getInput())) {
5128 // Quantize(Splat) -> Splat'
5129 changed = true;
5130 SplatNode *newSN = F->createSplat(
5131 SN->getName(), Q->getResult().getType(), SN->getValue());
5132 Q->getResult().replaceAllUsesOfWith(newSN);
5133 continue;
5134 }
5135 }
5136
5137 if (auto *DQ = dyn_cast<DequantizeNode>(node)) {
5138 if (auto *Q = dyn_cast<QuantizeNode>(DQ->getInput())) {
5139 // Dequantize(Quantize(X)) -> X
5140 changed = true;
5141 DQ->getResult().replaceAllUsesOfWith(Q->getInput());
5142 continue;
5143 }
5144 // Fold the rescale into the following Dequantize.
5145 // Dequantize(rescale) -> Dequantize()
5146 if (auto *RS = dyn_cast<RescaleQuantizedNode>(DQ->getInput())) {
5147 changed = true;
5148 auto *newRS = F->createDequantize(DQ->getName(), RS->getInput(),
5149 DQ->getResult().getType());
5150 DQ->getResult().replaceAllUsesOfWith(newRS);
5151
5152 // We may be able to optimize this rescale node. Remember to visit
5153 // this new node and try to optimize it later.
5154 worklist.push_back(newRS);
5155 continue;
5156 }
5157 if (auto *SN = dyn_cast<SplatNode>(DQ->getInput())) {
5158 // Dequantize(Splat) -> Splat'
5159 changed = true;
5160 SplatNode *newSN = F->createSplat(
5161 SN->getName(), DQ->getResult().getType(), SN->getValue());
5162 DQ->getResult().replaceAllUsesOfWith(newSN);
5163 continue;
5164 }
5165 }
5166
5167 if (auto *RS = dyn_cast<RescaleQuantizedNode>(node)) {
5168 // All cases below tend to change the output scale/bias of a Node. This
5169 // may change the numerics of the op (even the range is narrower and so it
5170 // should be more accurate).
5171 if (!cctx.optimizationOpts.enableQuantParamChanges) {
5172 continue;
5173 }
5174
5175 if (RS->getInput().getType() == RS->getResult().getType()) {
5176 // If rescale does not change the type, then simply drop it.
5177 changed = true;
5178 RS->getResult().replaceAllUsesOfWith(RS->getInput());
5179 continue;
5180 }
5181
5182 // All optimizations in this scope below combine a Rescale up into or
5183 // above the Rescale's input X. If X has multiple users then this merging
5184 // will duplicate X, just with a different output scale/offset. If X is
5185 // not a Splat then this is likely not desired, as it means a
5186 // computational node (e.g. Add) is duplicated.
5187 if (!RS->getInput().hasOneUse() && !isa<SplatNode>(RS->getInput())) {
5188 continue;
5189 }
5190
5191 // Combine the rescale node up into its parent node.
5192 // Rescale(Node()) -> 'Node().
5193 bool addNewNodeToWorklist = false;
5194 switch (RS->getInput().getNode()->getKind()) {
5195 case Kinded::Kind::RescaleQuantizedNodeKind:
5196 case Kinded::Kind::QuantizeNodeKind:
5197 addNewNodeToWorklist = true;
5198 case Kinded::Kind::SplatNodeKind:
5199 case Kinded::Kind::AddNodeKind:
5200 case Kinded::Kind::SubNodeKind:
5201 case Kinded::Kind::MulNodeKind:
5202 case Kinded::Kind::DivNodeKind:
5203 case Kinded::Kind::FmodNodeKind:
5204 case Kinded::Kind::MinNodeKind:
5205 case Kinded::Kind::MatMulNodeKind:
5206 case Kinded::Kind::ConvolutionNodeKind:
5207 case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind:
5208 case Kinded::Kind::FullyConnectedNodeKind:
5209 case Kinded::Kind::SparseLengthsWeightedSumNodeKind: {
5210 changed = true;
5211 Node *newNode =
5212 cloneNodeWithNewTypes(F, RS->getInput(), RS->getResult().getType());
5213 RS->getResult().replaceAllUsesOfWith(newNode);
5214 if (addNewNodeToWorklist) {
5215 worklist.push_back(newNode);
5216 }
5217 continue;
5218 }
5219 default:;
5220 }
5221
5222 if (auto *MN = dyn_cast<MaxNode>(RS->getInput())) {
5223 // Rescale(MAX(X, Y)) -> MAX(Rescale(X), Rescale(Y)).
5224 // It's okay to rescale the operands because even if the output range
5225 // is smaller then truncation would have happened during the rescale.
5226 // On values that are outside of the range we just moved the
5227 // truncation to a different location.
5228 changed = true;
5229 auto name = RS->getName();
5230 auto *L = F->createRescaleQuantized(name, MN->getLHS(),
5231 RS->getResult().getType());
5232 auto *R = F->createRescaleQuantized(name, MN->getRHS(),
5233 RS->getResult().getType());
5234 auto *newMN = F->createMax(MN->getName(), L, R);
5235 worklist.push_back(L);
5236 worklist.push_back(R);
5237 RS->getResult().replaceAllUsesOfWith(newMN);
5238 continue;
5239 }
5240 } // Handle RescaleQuantizedNode
5241 } // For each item in the worklist.
5242
5243 // This pass is based on real qparams, so skip this opt if using dummies.
5244 if (!cctx.precisionConfig.loadUniquedDummyQParams) {
5245 changed |= optimizeQuantizedMaxSplat(F);
5246 }
5247
5248 // If nothing has changed then sink rescale quantization nodes.
5249 if (!changed) {
5250 changed = sinkRescaleQuantizedNode(F, cctx);
5251 }
5252 return changed;
5253}
5254
5255void glow::convertQuantizedConstants(Function *F, CompilationContext &cctx) {
5256 for (auto &node : F->getNodes()) {
5257 auto *Q = dyn_cast<QuantizeNode>(&node);
5258 if (!Q) {
5259 continue;
5260 }
5261 auto *C = dyn_cast<Constant>(Q->getInput());
5262 if (!C) {
5263 continue;
5264 }
5265
5266 // Quantize(Constant) -> Constant
5267 // Note, it does not really matter how many usages this Constant has.
5268 // Quantized graph will use optimized Constant and other functions will
5269 // refer to the floating point original Constant.
5270 NodeValue NC =
5271 convertConstant(*F->getParent(), *C, Q->getResult().getType());
5272 if (NC == NodeValue()) {
5273 continue;
5274 }
5275 Q->getResult().replaceAllUsesOfWith(NC);
5276 }
5277
5278 // Perform Dead Code Elimination.
5279 runDCEPass(F, cctx);
5280}
5281
5282void glow::convertPlaceholdersToConstants(Function *F,
5283 const PlaceholderBindings &bindings,
5284 llvm::ArrayRef<Placeholder *> phs) {
5285 LOG_SCOPE(F->getLogContext(), "convertPlaceholdersToConstants")
5286
5287 auto *M = F->getParent();
5288 for (auto &PH : F->findPlaceholders()) {
5289 if (std::find(phs.begin(), phs.end(), PH) != phs.end()) {
5290 continue;
5291 }
5292 auto *tensor = bindings.get(PH);
5293 if (!tensor) {
5294 continue;
5295 }
5296 auto *constant = M->createConstant(PH->getName(), *tensor, PH->getLayout());
5297 PH->getOutput().replaceAllUsesOfWith(constant, F);
5298 }
5299}
5300
5301/// \returns True if the \p node sub-tree corresponds to a scalar
5302/// (Constant or Splat) and return the float value in \p retFloat.
5303static bool getFloatScalar(Node *node, float *retFloat) {
5304 // Iterate across potential Tile Nodes (implied by broadcasting if any).
5305 auto *n = node;
5306 while (auto *TN = dyn_cast<TileNode>(n)) {
5307 n = TN->getInput();
5308 }
5309
5310 // After potential Tile nodes, it should be a singleton constant scalar node
5311 // with any shape corresponding to one single element.
5312 if (auto *constNode = dyn_cast<Constant>(n)) {
5313 if ((constNode->getType()->getElementType() != ElemKind::FloatTy) ||
5314 (constNode->getType()->size() != 1)) {
5315 return false;
5316 }
5317 auto valueH = constNode->getHandle<float>();
5318 std::vector<dim_t> coord(constNode->getType()->dims().size(), 0);
5319 *retFloat = valueH.at(coord);
5320 return true;
5321 }
5322 if (auto *splatNode = dyn_cast<SplatNode>(n)) {
5323 *retFloat = splatNode->getValue();
5324 return true;
5325 }
5326
5327 return false;
5328}
5329
5330/// Fold leakyRelu operations expressed as a sub-graph Max(A, Mul(A, scalar))
5331/// and replace it by PRelu(Splat).
5332bool FoldLeakyRelu::run(Function *F, const CompilationContext &cctx) {
5333 LOG_SCOPE(F->getLogContext(), getName());
5334 bool changed = false;
5335 auto &nodes = F->getNodes();
5336 for (auto &node : nodes) {
5337 auto *maxNode = dyn_cast<MaxNode>(&node);
5338 if (!maxNode) {
5339 continue;
5340 }
5341 NodeValue otherMaxOperand;
5342 MulNode *mulNode;
5343 if ((mulNode = dyn_cast<MulNode>(maxNode->getRHS()))) {
5344 otherMaxOperand = maxNode->getLHS();
5345 } else if ((mulNode = dyn_cast<MulNode>(maxNode->getLHS()))) {
5346 otherMaxOperand = maxNode->getRHS();
5347 } else {
5348 continue;
5349 }
5350 NodeValue otherMulOperand;
5351 float value;
5352 if (getFloatScalar(mulNode->getRHS(), &value)) {
5353 otherMulOperand = maxNode->getLHS();
5354 } else if (getFloatScalar(mulNode->getLHS(), &value)) {
5355 otherMulOperand = maxNode->getRHS();
5356 } else {
5357 continue;
5358 }
5359 if ((value <= 1.0f) && (otherMulOperand == otherMaxOperand)) {
5360 // The sub-tree is a Leaky-Relu, express it as a PRelu.
5361 auto *splat = F->createSplat(maxNode->getName(),
5362 mulNode->getResult().getType(), value);
5363 auto *PRelu = F->createPRELU(maxNode->getName(), otherMaxOperand, splat);
5364 maxNode->getResult().replaceAllUsesOfWith(PRelu);
5365 changed = true;
5366 continue;
5367 }
5368 }
5369 return changed;
5370}
5371
5372/// Parameters that are used to define ChannelShuffle operators.
5373struct ChannelShuffleParams {
5374 size_t group;
5375 size_t kernel;
5376};
5377
5378/// Compute the original parameters to the ChannelShuffle operator (represented
5379/// as ReshapeNode->TransposeNode->ReshapeNode) for which \p node is the leading
5380/// ReshapeNode. \returns The original ChannelShuffle parameters if possible and
5381/// empty Optional otherwise.
5382static llvm::Optional<ChannelShuffleParams>
5383getChannelShuffleParams(const ReshapeNode &node) {
5384 auto resM = llvm::Optional<ChannelShuffleParams>();
5385
5386 llvm::ArrayRef<dim_t> inputDims = node.getInput().dims();
5387 llvm::ArrayRef<dim_t> resultDims = node.getDims();
5388
5389 // Check that there is one more output dimension than input dimension.
5390 if (resultDims.size() != inputDims.size() + 1) {
5391 return resM;
5392 }
5393
5394 // Find the first output dimension that doesn't match its corresponding input
5395 // dimension.
5396 ChannelShuffleParams params;
5397 bool found = false;
5398 for (size_t i = 0, e = resultDims.size(); i < e - 1; ++i) {
5399 if (inputDims[i] != resultDims[i]) {
5400 params.kernel = i;
5401 params.group = resultDims[i];
5402 found = true;
5403 break;
5404 }
5405 }
5406
5407 // Double check the property that the mismatched output found dimension and
5408 // its successor together evenly multiply to the input dimension they
5409 // mismatched on.
5410 if (found && resultDims[params.kernel] * resultDims[params.kernel + 1] ==
5411 inputDims[params.kernel]) {
5412 resM = params;
5413 }
5414
5415 return resM;
5416}
5417
5418// Fold Reshape->Transpose->Reshape into ChannelShuffle when applicable.
5419bool FoldChannelShuffle::run(Function *F, const CompilationContext &cctx) {
5420 LOG_SCOPE(F->getLogContext(), getName());
5421
5422 bool changed = false;
5423 auto &nodes = F->getNodes();
5424 for (auto &node : nodes) {
5425 auto *RN2 = dyn_cast<ReshapeNode>(&node);
5426 if (!RN2) {
5427 continue;
5428 }
5429
5430 auto *TR = dyn_cast<TransposeNode>(RN2->getInput());
5431 if (!TR) {
5432 continue;
5433 }
5434
5435 auto *RN1 = dyn_cast<ReshapeNode>(TR->getInput());
5436 if (!RN1) {
5437 continue;
5438 }
5439
5440 // Check that the input and output shapes match:
5441 if (RN1->getInput().getType() != RN2->getResult().getType()) {
5442 continue;
5443 }
5444
5445 // Compute the original parameters to ChannelShuffle.
5446 auto paramsM = getChannelShuffleParams(*RN1);
5447 if (!paramsM.hasValue()) {
5448 continue;
5449 }
5450
5451 // Create a new ChannelShuffle with kernel parameter tranposed by the
5452 // TR's shuffle.
5453 auto *newCS = F->createChannelShuffle("channel_shuffle", RN1->getInput(),
5454 paramsM->group, paramsM->kernel);
5455 RN2->getResult().replaceAllUsesOfWith(newCS);
5456 changed = true;
5457 }
5458 return changed;
5459}
5460
5461// Fold Tile -> Add into BatchedAdd wherever applicable.
5462bool FoldTileAddIntoBatchedAdd::run(Function *F,
5463 const CompilationContext &cctx) {
5464 LOG_SCOPE(F->getLogContext(), getName());
5465
5466 bool changed = false;
5467 for (const auto &node : F->getNodes()) {
5468 const auto *addNode = dyn_cast<AddNode>(&node);
5469 if (!addNode) {
5470 continue;
5471 }
5472
5473 NodeValue batchNode, addedNode;
5474 const auto &LHS = addNode->getLHS();
5475 const auto &RHS = addNode->getRHS();
5476 const TileNode *tileNode = nullptr;
5477
5478 // Check if LHS is a tile.
5479 if ((tileNode = dyn_cast<TileNode>(LHS))) {
5480 batchNode = RHS;
5481 addedNode = tileNode->getInput();
5482 }
5483 // Check if RHS is a tile.
5484 else if ((tileNode = dyn_cast<TileNode>(RHS))) {
5485 batchNode = LHS;
5486 addedNode = tileNode->getInput();
5487 }
5488 // If neither LHS or RHS is a tile, nothing to do.
5489 else {
5490 continue;
5491 }
5492
5493 // If the tiling of the added node is not along the 0th axis,
5494 // 'Add' cannot be replaced with 'BatchedAdd'.
5495 if (tileNode->getAxis() != 0) {
5496 continue;
5497 }
5498
5499 auto oldDims = addedNode.dims();
5500 // If the 0th dimension of the added node is not 1,
5501 // then reducing dimension via reshaping is more complicated.
5502 // Hence, Add will not be replaced with BatchedAdd.
5503 if (oldDims.size() == 0 || oldDims[0] != 1) {
5504 continue;
5505 }
5506
5507 // Reshape the added node to create a slice for the batched add
5508 // such that its dim size is one less than that of the batch.
5509 const auto newDims = oldDims.take_back(oldDims.size() - 1);
5510 auto *slice = F->createReshape(tileNode->getName().str() + "_reshape",
5511 addedNode, newDims);
5512
5513 // Create a new batched add node to replace existing add node.
5514 auto *newBA = F->createBatchedAdd(addNode->getName().str() + "_batched_add",
5515 batchNode, slice);
5516 addNode->getResult().replaceAllUsesOfWith(newBA);
5517 changed = true;
5518 }
5519 return changed;
5520}
5521
5522/// Raise ClipNodes above shaping ops, e.g. Reshape, Transpose, Slice. Other
5523/// passes will sink Clips to try to eliminate redundant ones. This pass should
5524/// happen after sinking of Clips in order to try to get Clips to directly
5525/// consume compute Nodes outputs.
5526bool RaiseClipsAboveShapeNodes::run(Function *F,
5527 const CompilationContext &cctx) {
5528 LOG_SCOPE(F->getLogContext(), getName());
5529 bool changed = false;
5530
5531 // Keep track of what nodes will have oneLessUser due to DCE eventually. As an
5532 // example, we know that after we replace all users of OrigSlice with NewClip,
5533 // then OrigSlice and OrigClip are dead, and so Input1 will have one less user
5534 // after DCE.
5535 // Input Input
5536 // | / \
5537 // OrigSlice NewClip OrigSlice <--|
5538 // | --> | | |-- (These two are now dead.)
5539 // OrigClip NewSlice OrigClip <--|
5540 // | |
5541 // Save Save
5542 std::unordered_set<Node *> oneLessUser;
5543
5544 for (auto &N : F->getNodes()) {
5545 ClipNode *CN = dyn_cast<ClipNode>(&N);
5546 if (!CN) {
5547 continue;
5548 }
5549
5550 // If the Clip's input has multiple users then do not raise the Clip, as
5551 // otherwise this will impact other Nodes. We subtract off an extra user
5552 // here if we know one user will be eliminated due to DCE eventually (see
5553 // above pic).
5554 unsigned numUsers = CN->getInput().getNode()->getNumUsers();
5555 if (oneLessUser.count(CN->getInput().getNode())) {
5556 numUsers -= 1;
5557 }
5558 if (numUsers != 1) {
5559 continue;
5560 }
5561
5562 // Sink Reshape below Clip.
5563 if (ReshapeNode *RN = dyn_cast<ReshapeNode>(CN->getInput())) {
5564 ClipNode *newCN = F->createClip(CN->getName(), RN->getInput(),
5565 CN->getMin(), CN->getMax());
5566 ReshapeNode *newRN = F->createReshape(RN->getName(), newCN->getResult(),
5567 RN->getDims(), RN->getLayout());
5568 CN->getResult().replaceAllUsesOfWith(newRN->getResult());
5569 oneLessUser.insert(RN->getInput().getNode());
5570 changed = true;
5571 continue;
5572 }
5573
5574 // Sink Transpose below Clip.
5575 if (TransposeNode *TN = dyn_cast<TransposeNode>(CN->getInput())) {
5576 ClipNode *newCN = F->createClip(CN->getName(), TN->getInput(),
5577 CN->getMin(), CN->getMax());
5578 TransposeNode *newTN = F->createTranspose(
5579 TN->getName(), newCN->getResult(), TN->getShuffle(), TN->getLayout());
5580 CN->getResult().replaceAllUsesOfWith(newTN->getResult());
5581 oneLessUser.insert(TN->getInput().getNode());
5582 changed = true;
5583 continue;
5584 }
5585
5586 // Sink Slice below Clip.
5587 if (SliceNode *SN = dyn_cast<SliceNode>(CN->getInput())) {
5588 ClipNode *newCN = F->createClip(CN->getName(), SN->getInput(),
5589 CN->getMin(), CN->getMax());
5590 SliceNode *newSN =
5591 F->createSlice(SN->getName(), newCN->getResult(), SN->getStart(),
5592 SN->getResult().getType());
5593 CN->getResult().replaceAllUsesOfWith(newSN->getResult());
5594 oneLessUser.insert(SN->getInput().getNode());
5595 changed = true;
5596 continue;
5597 }
5598
5599 // Sink Tile below Clip.
5600 if (TileNode *TN = dyn_cast<TileNode>(CN->getInput())) {
5601 ClipNode *newCN = F->createClip(CN->getName(), TN->getInput(),
5602 CN->getMin(), CN->getMax());
5603 TileNode *newTN = F->createTile(TN->getName(), newCN->getResult(),
5604 TN->getCount(), TN->getAxis());
5605 CN->getResult().replaceAllUsesOfWith(newTN->getResult());
5606 oneLessUser.insert(TN->getInput().getNode());
5607 changed = true;
5608 continue;
5609 }
5610 } // For all nodes in the graph.
5611
5612 return changed;
5613}
5614
5615/// Fold ElemKind conversion nodes (ConvertTo, Quantize) into
5616/// single-user Placeholders. Note that this changes the semantics
5617/// of the IO of the Function and so must be done carefully, i.e. should always
5618/// be opt-in and done alongside conversion of corresponding Tensors in
5619/// PlaceholderBindings. If
5620/// cctx.optimizationOpts.foldStaticPlaceholderConversions is set this will
5621/// only change Placeholders marked as static.
5622bool FoldElemKindConversionIntoInputs::run(Function *F,
5623 const CompilationContext &cctx) {
5624 LOG_SCOPE(F->getLogContext(), getName());
5625
5626 bool changed = false;
5627 auto &nodes = F->getNodes();
5628
5629 for (auto it = nodes.begin(), e = nodes.end(); it != e; it++) {
5630 Node *N = &*it;
5631 // Handle conversion of inputs (conversion of Placeholders):
5632 ConvertToNode *CTN = llvm::dyn_cast<ConvertToNode>(N);
5633 QuantizeNode *QN = llvm::dyn_cast<QuantizeNode>(N);
5634 if (CTN || QN) {
5635 NodeValue in = CTN ? CTN->getInput() : QN->getInput();
5636 Placeholder *P = llvm::dyn_cast<Placeholder>(in);
5637 if (!P || P->getUsers().size() != 1) {
5638 continue;
5639 }
5640 // If foldElemKindConversionIntoIO is not set and this is not a static
5641 // placeholder then skip.
5642 if (!cctx.optimizationOpts.foldElemKindConversionIntoIO &&
5643 !P->isStatic()) {
5644 continue;
5645 }
5646
5647 // We have a conversion of a single-use placeholder to some other type, so
5648 // it is safe to do the requested conversion.
5649 NodeValue res = CTN ? CTN->getResult() : QN->getResult();
5650
5651 // Convert the type of the Placeholder to the conversion type. If target
5652 // type is fused call setTypeUnsafe because the shape can change in this
5653 // case.
5654 if (isFusedQuantizedElemKind(res.getElementType())) {
5655 P->setTypeUnsafe(Storage::OutputIdx, res.getType());
5656 } else {
5657 P->setType(Storage::OutputIdx, res.getType());
5658 }
5659
5660 // Replace all uses of the original ConvertTo to the Placeholder.
5661 res.replaceAllUsesOfWith(P);
5662
5663 changed = true;
5664 continue;
5665 }
5666 }
5667 return changed;
5668}
5669
5670/// Fold ElemKind conversion nodes (ConvertTo, Dequantize) into SaveNodes. Note
5671/// that this changes the semantics of the IO of the Function and so must be
5672/// done carefully, i.e. should always be opt-in and done alongside conversion
5673/// of corresponding Tensors in PlaceholderBindings.
5674bool FoldElemKindConversionIntoOutputs::run(Function *F,
5675 const CompilationContext &cctx) {
5676 LOG_SCOPE(F->getLogContext(), getName());
5677
5678 std::unordered_set<SaveNode *> deadSaves;
5679
5680 bool changed = false;
5681 // Since we will be adding in new SaveNodes, reverse iterate to be safe.
5682 auto &nodes = F->getNodes();
5683 for (auto it = nodes.rbegin(), e = nodes.rend(); it != e; it++) {
5684 Node *N = &*it;
5685
5686 // Handle conversion of outputs (SaveNodes + Placeholders):
5687 if (SaveNode *SN = llvm::dyn_cast<SaveNode>(N)) {
5688 if (!SN) {
5689 continue;
5690 }
5691 if (SN->getPlaceholder()->getUsers().size() != 1) {
5692 continue;
5693 }
5694 ConvertToNode *CTN = llvm::dyn_cast<ConvertToNode>(SN->getInput());
5695 DequantizeNode *DQN = llvm::dyn_cast<DequantizeNode>(SN->getInput());
5696 if (!CTN && !DQN) {
5697 continue;
5698 }
5699 NodeValue in = CTN ? CTN->getInput() : DQN->getInput();
5700
5701 // Set the type of the Placeholder to be same the conversion's input.
5702 SN->getPlaceholder()->setType(Storage::OutputIdx, in.getType());
5703
5704 // Create a new SaveNode directly using the conversion's input.
5705 F->createSave(SN->getName(), in, SN->getPlaceholder());
5706
5707 // Queue up deleting the original SaveNode as it won't be deleted via DCE.
5708 deadSaves.insert(SN);
5709 changed = true;
5710 continue;
5711 }
5712 }
5713
5714 // Delete all the dead saves.
5715 for (SaveNode *SN : deadSaves) {
5716 F->eraseNode(SN);
5717 }
5718
5719 return changed;
5720}
5721
5722/// Broadcasts are implemented via 1) Reshape followed by a series of Tiles 2)
5723/// BroadcastNode. This helper unwinds Broadcast operation -- it \returns the
5724/// original Node before the broadcasting \p N if the broadcast takes place
5725/// between the 0th dimension to \p endDim. Otherwise, \p return \p N.
5726static NodeValue unwindBroadcast(NodeValue N, unsigned_t endDim) {
5727 if (auto *BN = dyn_cast<BroadcastNode>(N)) {
5728 const auto newShape = BN->getTargetDim();
5729 const auto axis = BN->getAxis();
5730 const auto &origDims = BN->getInput().dims();
5731
5732 if (origDims.size() + axis != newShape.size()) {
5733 return N;
5734 }
5735
5736 for (dim_t i = endDim; i < newShape.size(); i++) {
5737 if (!(i >= axis && origDims[i - axis] == newShape[i])) {
5738 return N;
5739 }
5740 }
5741
5742 return BN->getInput();
5743 }
5744
5745 // All non-BroadcastNode broadcasts must Tile at least once.
5746 if (!isa<TileNode>(N)) {
5747 return N;
5748 }
5749
5750 while (TileNode *TN = dyn_cast<TileNode>(N)) {
5751 // Check that the axis of the current Tile is inside of the expected
5752 // provided endDim.
5753 if (TN->getAxis() >= endDim) {
5754 return N;
5755 }
5756 // Applicable only if original dim is 1 in the Broadcast's Tile.
5757 if (TN->getInput().dims()[TN->getAxis()] != 1) {
5758 return N;
5759 }
5760 N = TN->getInput();
5761 }
5762 if (ReshapeNode *RN = dyn_cast<ReshapeNode>(N)) {
5763 return RN->getInput();
5764 }
5765
5766 return N;
5767}
5768
5769/// Looks for supported arithmetic ops following LayerNorm and folds them into
5770/// the scale/bias of the LayerNorm if the scale/bias are single-use
5771/// Splat/Constant, as we can then later on constant fold them in.
5772bool FoldLayerNormArithmetic::run(Function *F, const CompilationContext &cctx) {
5773 LOG_SCOPE(F->getLogContext(), getName());
5774
5775 bool changed = false;
5776 for (auto &N : F->getNodes()) {
5777 if (!isa<MulNode>(&N) && !isa<AddNode>(&N)) {
5778 continue;
5779 }
5780 // Currently only support floating point, as otherwise there will be
5781 // quantization parameter mismatches.
5782 if (!isFloatElemKind(N.getElementType(ArithmeticNode::ResultIdx))) {
5783 continue;
5784 }
5785
5786 // Check if the Mul/Add is consuming an LN.
5787 LayerNormalizationNode *LN =
5788 dyn_cast<LayerNormalizationNode>(N.getNthInput(ArithmeticNode::LHSIdx));
5789 if (!LN) {
5790 continue;
5791 }
5792
5793 // Check if the RHS is a Splat, or Constant, or temp PH (to be Constant
5794 // later). It may have been broadcasted to the correct shape.
5795 NodeValue RHS = unwindBroadcast(N.getNthInput(ArithmeticNode::RHSIdx),
5796 LN->getResult().dims().size() -
5797 LN->getScale().dims().size());
5798
5799 auto *P = dyn_cast<Placeholder>(RHS);
5800 if (!isa<SplatNode>(RHS) && !isa<Constant>(RHS) &&
5801 !(P && cctx.optimizationOpts.tempPHsForConstants.count(P))) {
5802 continue;
5803 }
5804
5805 // Make sure the RHS that we want to merge into the LN's Scale/Bias have the
5806 // same number of elements, since we're about to reshape them to match.
5807 if (RHS.getType()->size() != LN->getScale().getType()->size()) {
5808 continue;
5809 }
5810
5811 // RHS may have already been fused with a Reshape to get ready for
5812 // Tiling. Reshape RHS back here if necessary.
5813 if (RHS.dims() != LN->getScale().dims()) {
5814 RHS = F->createReshape(RHS.getNode()->getName().str() + "_squeezed", RHS,
5815 LN->getScale().dims());
5816 }
5817
5818 if (MulNode *MN = dyn_cast<MulNode>(&N)) {
5819 // Merge the Mul into a new Scale and Bias, multiplying by the original
5820 // Mul RHS that followed the LayerNorm.
5821 MulNode *newScale =
5822 F->createMul(LN->getScale().getNode()->getName().str() + "_fuse_" +
5823 MN->getName().data(),
5824 LN->getScale(), RHS);
5825 MulNode *newBias = F->createMul(LN->getBias().getNode()->getName().str() +
5826 "_fuse_" + MN->getName().data(),
5827 LN->getBias(), RHS);
5828 LN->getScale().replaceAllUsesOfWith(newScale->getResult(), F, newScale);
5829 LN->getBias().replaceAllUsesOfWith(newBias->getResult(), F, newBias);
5830 MN->getResult().replaceAllUsesOfWith(LN->getResult());
5831 changed = true;
5832 continue;
5833 }
5834
5835 if (AddNode *AN = dyn_cast<AddNode>(&N)) {
5836 // Merge the Add into a new Bias, adding the original Add RHS that
5837 // followed the LayerNorm.
5838 AddNode *newBias = F->createAdd(LN->getBias().getNode()->getName().str() +
5839 "_fuse_" + AN->getName().data(),
5840 LN->getBias(), RHS);
5841 LN->getBias().replaceAllUsesOfWith(newBias->getResult(), F, newBias);
5842 AN->getResult().replaceAllUsesOfWith(LN->getResult());
5843 changed = true;
5844 continue;
5845 }
5846 }
5847
5848 return changed;
5849}
5850
5851/// Looks for an activation directly following \p N from \p F that the backend
5852/// \p B supports for fusion.
5853template <class T> bool fuseActivation(T *N, Function *F, const Backend *B) {
5854 if (!N || N->hasFusedActivation() || !N->getResult().hasOneUse()) {
5855 return false;
5856 }
5857
5858 // We know there is one result user so we can just deref the first result.
5859 Node *activation = (*N->getResult().getUsers().begin()).getUser();
5860 if (!B || !B->supportsFusedActivation(N, activation)) {
5861 return false;
5862 }
5863
5864 NodeValue activationNV;
5865 switch (activation->getKind()) {
5866 case Kinded::Kind::ReluNodeKind:
5867 activationNV = cast<ReluNode>(activation)->getResult();
5868 N->setFusedActivation(FusedActivation::RELU);
5869 break;
5870 case Kinded::Kind::ClipNodeKind:
5871 activationNV = cast<ClipNode>(activation)->getResult();
5872 N->setFusedActivation(FusedActivation::CLIP);
5873 N->setFusedActivationArgs({cast<ClipNode>(activation)->getMin(),
5874 cast<ClipNode>(activation)->getMax()});
5875 break;
5876 case Kinded::Kind::SigmoidNodeKind:
5877 activationNV = cast<SigmoidNode>(activation)->getResult();
5878 N->setFusedActivation(FusedActivation::SIGMOID);
5879 break;
5880 case Kinded::Kind::TanhNodeKind:
5881 activationNV = cast<TanhNode>(activation)->getResult();
5882 N->setFusedActivation(FusedActivation::TANH);
5883 break;
5884 case Kinded::Kind::LeakyReluNodeKind:
5885 activationNV = cast<LeakyReluNode>(activation)->getResult();
5886 N->setFusedActivation(FusedActivation::LEAKY_RELU);
5887 N->setFusedActivationArgs({cast<LeakyReluNode>(activation)->getAlpha()});
5888 break;
5889 default:
5890 return false;
5891 }
5892
5893 // Modify the node output type to that of the activation.
5894 if (!(activationNV.getType()->isEqual(N->getResult().getType()))) {
5895 N->getResult().setType(activationNV.getType());
5896 }
5897
5898 activationNV.replaceAllUsesOfWith(N->getResult());
5899 return true;
5900}
5901
5902static bool foldActivations(Function *F, const CompilationContext &cctx,
5903 const Backend *B) {
5904 bool changed = false;
5905 for (auto &node : F->getNodes()) {
5906 if (fuseActivation(dyn_cast<ConvolutionNode>(&node), F, B)) {
5907 changed = true;
5908 continue;
5909 }
5910 if (fuseActivation(dyn_cast<ChannelwiseQuantizedConvolutionNode>(&node), F,
5911 B)) {
5912 changed = true;
5913 continue;
5914 }
5915 }
5916 return changed;
5917}
5918
5919void glow::fold(Function *F, const CompilationContext &cctx, const Backend *B) {
5920 LOG_SCOPE(F->getLogContext(), "glow::fold")
5921
5922 FunctionPassManager FPM("FoldFPM", createDefaultFoldPassPipeline());
5923 FPM.run(F, cctx);
5924
5925 foldActivations(F, cctx, B);
5926}
5927
5928void glow::optimize(Function *F, const CompilationContext &cctx,
5929 const Backend &B) {
5930 LOG_SCOPE(F->getLogContext(), "glow::optimize")
5931
5932 FunctionPassManager FPM("TargetDependentGraphOptzFPM",
5933 B.getOptimizationPipeline(), &B);
5934 FPM.run(F, cctx);
5935}
5936
5937void glow::optimize(Function *F, const CompilationContext &cctx) {
5938 LOG_SCOPE(F->getLogContext(), "glow::optimize")
5939
5940 // Indicates if the given function is completely loaded. A temporary
5941 // workaround until #3213 is complete.
5942 F->setState(FunctionState::FuncLoaded);
5943
5944 FunctionPassManager FPM("TargetIndependentGraphOptzFPM",
5945 createDefaultGraphOptimizationPassPipeline());
5946 FPM.run(F, cctx);
5947}
5948
5949void glow::optimize(Function *F, CompilationMode mode) {
5950 CompilationContext cctx;
5951 cctx.compMode = mode;
5952 optimize(F, cctx);
5953}
5954
5955/// Helper to pass over all Nodes in \p F and set FP16 accumulation to true for
5956/// those Nodes in the SLS family which support and need it. \p precConfig
5957/// contains the black/whitelist for skipping.
5958static void setFP16AccumSLS(Function *F,
5959 const PrecisionConfiguration &precConfig) {
5960 // Iterate from original end to beginning to avoid processing new Nodes added
5961 // during the pass.
5962 auto nodeIt = F->getNodes().end();
5963 auto stopIt = F->getNodes().begin();
5964 do {
5965 --nodeIt;
5966 Node &node = *nodeIt;
5967 // Only update allowed nodes based on black/whitelist.
5968 const bool inSet = precConfig.precisionModeKindSet.count(node.getKind());
5969 const bool allowConversion = precConfig.useSetAsWhitelist ? inSet : !inSet;
5970 if (!allowConversion) {
5971 continue;
5972 }
5973
5974#define CASE_SET_SLS_FP16_ACCUM(NODE_) \
5975 case Kinded::Kind::NODE_##NodeKind: { \
5976 NODE_##Node *SLS = llvm::cast<NODE_##Node>(&node); \
5977 if (SLS->getResult().getElementType() != ElemKind::Float16Ty) { \
5978 continue; \
5979 } \
5980 SLS->setUseFP16Accumulation(true); \
5981 continue; \
5982 }
5983
5984 switch (node.getKind()) {
5985 CASE_SET_SLS_FP16_ACCUM(RowwiseQuantizedSparseLengthsWeightedSum);
5986 CASE_SET_SLS_FP16_ACCUM(FusedRowwiseQuantizedSparseLengthsWeightedSum);
5987 CASE_SET_SLS_FP16_ACCUM(FusedRowwiseQuantizedSparseLengthsSum);
5988 CASE_SET_SLS_FP16_ACCUM(EmbeddingBagByteRowwiseOffsets);
5989 default:
5990 continue;
5991 }
5992 } while (nodeIt != stopIt);
5993}
5994
5995/// Look for Dequantize -> Swish -> Quantize, replace it with a quantized Swish.
5996bool QuantizeSwish::run(Function *F, const CompilationContext &cctx) {
5997 LOG_SCOPE(F->getLogContext(), getName());
5998
5999 if (!cctx.optimizationOpts.enableQuantParamChanges) {
6000 return false;
6001 }
6002
6003 bool changed = false;
6004 for (auto &N : F->getNodes()) {
6005 auto *SN = dyn_cast<SwishNode>(&N);
6006 if (!SN || SN->getNumUsers() != 1) {
6007 continue;
6008 }
6009
6010 QuantizeNode *QN =
6011 dyn_cast<QuantizeNode>((*SN->getUsers().begin()).getUser());
6012 if (!QN) {
6013 continue;
6014 }
6015
6016 DequantizeNode *DN = dyn_cast<DequantizeNode>(SN->getInput());
6017 if (!DN) {
6018 continue;
6019 }
6020
6021 SwishNode *newSN =
6022 F->createSwish(SN->getName().str() + "_int", DN->getInput(),
6023 QN->getResult().getType());
6024 QN->getResult().replaceAllUsesOfWith(newSN);
6025 changed = true;
6026 }
6027 return changed;
6028}
6029
6030/// Fold Exp + ReduceSum + Div into Softmax
6031/// IN
6032/// |
6033/// Exp IN
6034/// / \ |
6035/// | ReduceSum --> Softmax
6036/// \ / |
6037/// Div OUT
6038/// |
6039/// OUT
6040bool FoldExpSumDivIntoSoftmax::run(Function *F,
6041 const CompilationContext &cctx) {
6042 LOG_SCOPE(F->getLogContext(), getName());
6043
6044 bool changed = false;
6045 for (auto &N : F->getNodes()) {
6046 auto *EN = dyn_cast<ExpNode>(&N);
6047 if (!EN || EN->getNumUsers() != 2) {
6048 continue;
6049 }
6050
6051 DivNode *DN = nullptr;
6052 BatchedReduceAddNode *RSN = nullptr;
6053
6054 auto *user1 = EN->getUsers().front().getUser();
6055 auto *user2 = EN->getUsers().back().getUser();
6056
6057 if (isa<DivNode>(user1) && isa<BatchedReduceAddNode>(user2)) {
6058 DN = cast<DivNode>(user1);
6059 RSN = cast<BatchedReduceAddNode>(user2);
6060 } else if (isa<DivNode>(user2) && isa<BatchedReduceAddNode>(user1)) {
6061 DN = cast<DivNode>(user2);
6062 RSN = cast<BatchedReduceAddNode>(user1);
6063 } else {
6064 continue;
6065 }
6066
6067 if (RSN->getNumUsers() != 1) {
6068 continue;
6069 }
6070
6071 auto *broadcastNode = getOnlyUser(*RSN);
6072 if (broadcastNode == nullptr) {
6073 continue;
6074 }
6075 auto *tempDN = getOnlyUser(*broadcastNode);
6076 // Ensure that the inputs to the DivNode are Exp and ReduceSum.
6077 if (DN != tempDN) {
6078 continue;
6079 }
6080
6081 auto axes = EN->getInput().dims().vec();
6082 axes.back() = 1;
6083 auto *CN = F->getParent()->createConstant(glow::ElemKind::Int64ITy, axes,
6084 "selected");
6085
6086 auto *SM = F->createSoftMax("softmax", EN->getInput(), CN);
6087 DN->getResult().replaceAllUsesOfWith(SM);
6088 changed = true;
6089 }
6090 return changed;
6091}
6092
6093/// Local utility to remove identity Relu if fused into \p node.
6094/// \returns true or false whether the Relu was removed or not.
6095template <class NodeTy> static bool removeFusedIdentityRelu(Node *node) {
6096 auto *RN = dyn_cast<NodeTy>(node);
6097 if (!RN || RN->getFusedActivation() != FusedActivation::RELU) {
6098 return false;
6099 }
6100 // The output type must be quantized.
6101 auto outTy = RN->getResult().getType();
6102 if (!outTy->isQuantizedType()) {
6103 return false;
6104 }
6105 // The quantized 0.0f for Relu must match the min of the output type.
6106 auto outRange = quantization::getQuantizedRange(outTy->getElementType());
6107 if (outTy->getOffset() != outRange.first) {
6108 return false;
6109 }
6110 // Remove fused Relu.
6111 RN->setFusedActivation(FusedActivation::NONE);
6112 RN->setFusedActivationArgs({});
6113 return true;
6114}
6115
6116bool RemoveIdentityRelu::run(Function *F, const CompilationContext &cctx) {
6117 LOG_SCOPE(F->getLogContext(), getName());
6118 bool changed = false;
6119
6120 // Remove standalone Relu.
6121 for (auto &N : F->getNodes()) {
6122 auto *RN = dyn_cast<ReluNode>(&N);
6123 if (!RN) {
6124 continue;
6125 }
6126
6127 // The input and output types must be quantized.
6128 auto inpTy = RN->getInput().getType();
6129 auto outTy = RN->getResult().getType();
6130 if (!(inpTy->isQuantizedType() && outTy->isQuantizedType())) {
6131 continue;
6132 }
6133
6134 // The quantized 0.0f for Relu must match the min of the output type.
6135 auto outRange = quantization::getQuantizedRange(outTy->getElementType());
6136 if (outTy->getOffset() != outRange.first) {
6137 continue;
6138 }
6139
6140 // Remove Relu if input and output types are the same.
6141 // Otherwise change it with a RescaleQuantized.
6142 if (inpTy->isEqual(outTy)) {
6143 RN->getResult().replaceAllUsesOfWith(RN->getInput());
6144 } else {
6145 // TODO: Uncomment this once #5729 gets fixed.
6146 // auto *rescale =
6147 // F->createRescaleQuantized(RN->getName(), RN->getInput(), outTy);
6148 // RN->getResult().replaceAllUsesOfWith(rescale);
6149 }
6150 changed = true;
6151 }
6152
6153 // Remove fused Relu.
6154 for (auto &N : F->getNodes()) {
6155 changed |= removeFusedIdentityRelu<ConvolutionNode>(&N);
6156 changed |= removeFusedIdentityRelu<ChannelwiseQuantizedConvolutionNode>(&N);
6157 }
6158
6159 return changed;
6160}
6161
6162/// Local utility to remove identity Clip if fused into \p node.
6163/// \returns true or false whether the Clip was removed or not.
6164template <class NodeTy> static bool removeFusedIdentityClip(Node *node) {
6165 auto *CN = dyn_cast<NodeTy>(node);
6166 if (!CN || CN->getFusedActivation() != FusedActivation::CLIP) {
6167 return false;
6168 }
6169 // The output type must be quantized.
6170 auto outTy = CN->getResult().getType();
6171 if (!outTy->isQuantizedType()) {
6172 return false;
6173 }
6174 // The quantized min/max for Clip must match the min/max of the output type.
6175 TensorQuantizationParams outTQP{outTy->getScale(), outTy->getOffset()};
6176 auto fMin = CN->getFusedActivationArgs()[0];
6177 auto fMax = CN->getFusedActivationArgs()[1];
6178 auto qMin = quantization::quantize(fMin, outTQP, outTy->getElementType());
6179 auto qMax = quantization::quantize(fMax, outTQP, outTy->getElementType());
6180 auto outRange = quantization::getQuantizedRange(outTy->getElementType());
6181 if (!(qMin == outRange.first && qMax == outRange.second)) {
6182 return false;
6183 }
6184 // Remove fused Relu.
6185 CN->setFusedActivation(FusedActivation::NONE);
6186 CN->setFusedActivationArgs({});
6187 return true;
6188}
6189
6190bool RemoveIdentityClip::run(Function *F, const CompilationContext &cctx) {
6191 LOG_SCOPE(F->getLogContext(), getName());
6192 bool changed = false;
6193
6194 // Remove standalone Clip.
6195 for (auto &N : F->getNodes()) {
6196 auto *CN = dyn_cast<ClipNode>(&N);
6197 if (!CN) {
6198 continue;
6199 }
6200
6201 // The input and output types must be quantized.
6202 auto inpTy = CN->getInput().getType();
6203 auto outTy = CN->getResult().getType();
6204 if (!(inpTy->isQuantizedType() && outTy->isQuantizedType())) {
6205 continue;
6206 }
6207
6208 // The quantized min/max for Clip must match the min/max of the output type.
6209 TensorQuantizationParams outTQP{outTy->getScale(), outTy->getOffset()};
6210 auto fMin = CN->getMin();
6211 auto fMax = CN->getMax();
6212 auto qMin = quantization::quantize(fMin, outTQP, outTy->getElementType());
6213 auto qMax = quantization::quantize(fMax, outTQP, outTy->getElementType());
6214 auto outRange = quantization::getQuantizedRange(outTy->getElementType());
6215 if (!(qMin == outRange.first && qMax == outRange.second)) {
6216 continue;
6217 }
6218
6219 // Remove Clip if input and output types are the same.
6220 // Otherwise change it with a RescaleQuantized.
6221 if (inpTy->isEqual(outTy)) {
6222 CN->getResult().replaceAllUsesOfWith(CN->getInput());
6223 } else {
6224 // TODO: Uncomment this once #5729 gets fixed.
6225 // auto *rescale =
6226 // F->createRescaleQuantized(CN->getName(), CN->getInput(), outTy);
6227 // CN->getResult().replaceAllUsesOfWith(rescale);
6228 }
6229 changed = true;
6230 }
6231
6232 // Remove fused Clip.
6233 for (auto &N : F->getNodes()) {
6234 changed |= removeFusedIdentityClip<ConvolutionNode>(&N);
6235 changed |= removeFusedIdentityClip<ChannelwiseQuantizedConvolutionNode>(&N);
6236 }
6237
6238 return changed;
6239}
6240
6241/// Convert a FullyConnected node to a 1x1 Convolution.
6242bool ConvertFullyConnectedToConvolution::run(Function *F,
6243 const CompilationContext &cctx) {
6244 LOG_SCOPE(F->getLogContext(), getName());
6245
6246 bool changed = false;
6247 for (auto &N : F->getNodes()) {
6248 auto *FCN = dyn_cast<FullyConnectedNode>(&N);
6249 if (!FCN) {
6250 continue;
6251 }
6252
6253 NodeValue output = FCN->getResult();
6254 NodeValue input = FCN->getInput();
6255 NodeValue filter = FCN->getWeights();
6256 NodeValue bias = FCN->getBias();
6257
6258 // Reshape input from 2D to 4D.
6259 auto inpDims = ShapeHW(input.getType()->dims());
6260 std::vector<dim_t> inpDimsCN = {inpDims.height, 1, 1, inpDims.width};
6261 input = F->createReshape(FCN->getName(), input, inpDimsCN);
6262
6263 // Transpose filter and reshape from 2D to 4D.
6264 filter = F->createTranspose(FCN->getName(), filter, {1, 0});
6265 auto filterDims = ShapeHW(filter.getType()->dims());
6266 std::vector<dim_t> filterDimsCN = {filterDims.height, 1, 1,
6267 filterDims.width};
6268 filter = F->createReshape(FCN->getName(), filter, filterDimsCN);
6269
6270 // Create Conv2D node with same output type but 4D shape.
6271 auto outDims = ShapeHW(output.getType()->dims());
6272 std::vector<dim_t> outDimsCN = {outDims.height, 1, 1, outDims.width};
6273 auto outTyCN =
6274 F->getParent()->uniqueTypeWithNewShape(output.getType(), outDimsCN);
6275 NodeValue outputCN =
6276 F->createConv(FCN->getName(), input, filter, bias, outTyCN,
6277 /* kernels */ {1, 1},
6278 /* strides */ {1, 1},
6279 /* pads */ {0, 0, 0, 0},
6280 /* group */ 1);
6281
6282 // Reshape the 4D output back to its original 2D shape.
6283 outputCN =
6284 F->createReshape(FCN->getName(), outputCN, output.getType()->dims());
6285 FCN->getResult().replaceAllUsesOfWith(outputCN);
6286 changed = true;
6287 }
6288 return changed;
6289}
6290
6291/// Fold Min -> Max or Max -> Min to Clip
6292/// We need to place this function after `OptimizeArithmeticNodes` in which the
6293/// constant operator in commutative nodes is moved to the RHS, so that we can
6294/// assume SplatNode is on the RHS of MinNode (MaxNode).
6295/// Only if it's the following structure we do this optimization
6296///
6297/// someOtherInput SplatNode
6298/// \ /
6299/// SplatNode MinNode (MaxNode)
6300/// \ /
6301/// MaxNode (MinNode)
6302/// |
6303/// Out
6304///
6305/// ClipNode's min will be the SplatNode (maxSN) connected to the MaxNode.
6306/// ClipNode's max will be the SplatNode (minSN) connected to the MinNode.
6307bool FoldMinMaxToClip::run(Function *F, const CompilationContext &cctx) {
6308 LOG_SCOPE(F->getLogContext(), getName());
6309
6310 bool changed = false;
6311 for (auto &N : F->getNodes()) {
6312 MaxNode *maxNode = dyn_cast<MaxNode>(&N);
6313 MinNode *minNode = dyn_cast<MinNode>(&N);
6314 NodeValue otherInput;
6315 NodeValue resultNV;
6316 SplatNode *minSN = nullptr;
6317 SplatNode *maxSN = nullptr;
6318
6319 // We assume SplatNode is on the RHS.
6320 // If currect node is MinNode
6321 // - try casting the input to MaxNode and SplatNode.
6322 // - If LHS of MinNode is MaxNode
6323 // - try casting RHS of maxNode to SplatNode.
6324 if (minNode) {
6325 resultNV = minNode->getResult();
6326 maxNode = dyn_cast<MaxNode>(
6327 unwindBroadcast(minNode->getLHS(), minNode->getLHS().dims().size()));
6328 minSN = dyn_cast<SplatNode>(
6329 unwindBroadcast(minNode->getRHS(), minNode->getRHS().dims().size()));
6330
6331 if (maxNode) {
6332 maxSN = dyn_cast<SplatNode>(unwindBroadcast(
6333 maxNode->getRHS(), maxNode->getRHS().dims().size()));
6334 otherInput =
6335 unwindBroadcast(maxNode->getLHS(), maxNode->getLHS().dims().size());
6336 }
6337 } else if (maxNode) { // vice versa for MaxNode
6338 resultNV = maxNode->getResult();
6339 minNode = dyn_cast<MinNode>(
6340 unwindBroadcast(maxNode->getLHS(), maxNode->getLHS().dims().size()));
6341 maxSN = dyn_cast<SplatNode>(
6342 unwindBroadcast(maxNode->getRHS(), maxNode->getRHS().dims().size()));
6343
6344 if (minNode) {
6345 minSN = dyn_cast<SplatNode>(unwindBroadcast(
6346 minNode->getRHS(), minNode->getRHS().dims().size()));
6347 otherInput =
6348 unwindBroadcast(minNode->getLHS(), minNode->getLHS().dims().size());
6349 }
6350 }
6351
6352 // If any of these is nullptr, it means the structure is not the same as
6353 // above example.
6354 if (!(minNode && maxNode && minSN && maxSN)) {
6355 continue;
6356 }
6357
6358 // If minSN is smaller than maxSN, which is really weird because every entry
6359 // of the result will be a same number. In this case, we don't fold them.
6360 if (minSN->getValue() < maxSN->getValue()) {
6361 DLOG(INFO) << "Catch a combination of MinNode and MaxNode while MinSplat "
6362 "is smaller than MaxSplat, which would make all entries of "
6363 "the result tensor to be the same!";
6364 continue;
6365 }
6366
6367 // The second node is broadcasted and we need to broadcast the otherInput.
6368 if (otherInput.dims() != resultNV.dims()) {
6369 otherInput = F->createBroadcast(
6370 otherInput.getNode()->getName().str() + ".broadcast", otherInput,
6371 resultNV.dims(), resultNV.dims().size() - otherInput.dims().size());
6372 }
6373
6374 // MinSN is the SplatNode input of MinNode while MaxSN is the SplatNode
6375 // input of MaxNode. MinSN should be greater than MaxSN so we put MaxSN as
6376 // the min of ClipNode.
6377 auto minValue = maxSN->getValue();
6378 auto maxValue = minSN->getValue();
6379 ClipNode *CN = F->createClip(resultNV.getNode()->getName(), otherInput,
6380 resultNV.getType(),
6381 /* min */ minValue, /* max */ maxValue);
6382 resultNV.replaceAllUsesOfWith(CN->getResult());
6383 changed = true;
6384 }
6385
6386 return changed;
6387}
6388
6389/// Look for qparams with scale (when casted to fp16) == 0 and replace them with
6390/// zero Splats.
6391bool ReplaceZeroScaleFP16QuantNodes::run(Function *F,
6392 const CompilationContext &cctx) {
6393 LOG_SCOPE(F->getLogContext(), getName());
6394
6395 // Cannot run this opt if we're using dummy qparams.
6396 if (cctx.precisionConfig.loadUniquedDummyQParams) {
6397 return false;
6398 }
6399
6400 auto processNV = [](Function *F, NodeValue resNV) {
6401 const TypeRef resTy = resNV.getType();
6402 if (resTy->isFusedQuantizedType() || !resTy->isQuantizedType()) {
6403 return false;
6404 }
6405
6406 // Check if we have a scale that's below the minimum allowed FP16 val.
6407 if (resTy->getScale() >= kMinScaleFP16) {
6408 return false;
6409 }
6410
6411 // Skip if used by node with side effects, since we cannot change the
6412 // qparams in such cases.
6413 if (isUsedByNodeWithSideEffects(resNV.getNode())) {
6414 return false;
6415 }
6416
6417 // This NodeValue has scale = 0.f, which means the equivalent float result
6418 // will always be equal to 0.f. So create a splat with value 0.f, scale 1.f,
6419 // offset 0 instead.
6420 auto splatTy = F->getParent()->uniqueType(resTy->getElementType(),
6421 resTy->dims(), 1.f, 0);
6422 auto *SN = F->createSplat(resNV.getNode()->getName().str() + ".splatted",
6423 splatTy, 0.f);
6424
6425 // Note: must use type unsafe replace because we've changed the qparams.
6426 resNV.typeUnsafeReplaceAllUsesOfWith(SN->getResult(), F);
6427 return true;
6428 };
6429
6430 bool changed = false;
6431 // Since we will be adding in new SplatNodes, reverse iterate to be safe.
6432 auto &nodes = F->getNodes();
6433 for (auto it = nodes.rbegin(), e = nodes.rend(); it != e; it++) {
6434 Node *N = &*it;
6435 // For now only support nodes with single outputs.
6436 if (N->getNumResults() != 1) {
6437 continue;
6438 }
6439 NodeValue resNV = N->getNthResult(0);
6440 changed |= processNV(F, resNV);
6441 }
6442 for (Placeholder *PH : F->findPlaceholders()) {
6443 changed |= processNV(F, PH->getOutput());
6444 }
6445 for (Constant *C : F->findConstants()) {
6446 changed |= processNV(F, C->getOutput());
6447 }
6448
6449 return changed;
6450}
6451
6452/// This funciton uses TypeAToTypeBFunctionConverter to do a whole graph
6453/// demotion of Index type from INT64 to INT32.
6454static void transformIndexTypeDemotion(const Backend &B, Function *F,
6455 CompilationContext &cctx) {
6456
6457 // Does a coarse check to make sure none of the indices potentially can
6458 // overflow 32 bit. For now we just give up on the whole optimization, since
6459 // this is probably a corner case.
6460 for (auto &n : F->getNodes()) {
6461 for (int i = 0, nOutputs = n.getNumResults(); i < nOutputs; ++i) {
6462 if (n.getNthResult(i).getType()->actualSize() >=
6463 size_t(std::numeric_limits<int32_t>::max())) {
6464 return;
6465 }
6466 }
6467 }
6468
6469 PrecisionConfiguration precConfig;
6470 if (B.canDoIndexTypeDemotion(ElemKind::Int64ITy, ElemKind::Int32ITy,
6471 precConfig) &&
6472 cctx.optimizationOpts.enableTypeDemotion) {
6473 precConfig.precisionModeKindSet.insert(Kinded::Kind::TraceEventNodeKind);
6474 TypeAToTypeBFunctionConverter converter(*F, ElemKind::Int64ITy,
6475 ElemKind::Int32ITy, precConfig);
6476 converter.convert();
6477 }
6478}
6479
6480void glow::transformForPrecisionMode(const Backend &B, Function *F,
6481 CompilationContext &cctx) {
6482 LOG_SCOPE(F->getLogContext(), "transformForPrecisionMode")
6483 const PrecisionConfiguration &precConfig = cctx.precisionConfig;
6484
6485 switch (precConfig.quantMode) {
6486 case QuantizationMode::Profile: {
6487 assert(cctx.bindings);
6488
6489 LOG_SCOPE(F->getLogContext(), "glow::profileQuantization")
6490
6491 glow::profileQuantization(*cctx.bindings, F, precConfig.profConfig);
6492 break;
6493 }
6494
6495 case QuantizationMode::Quantize: {
6496 LOG_SCOPE(F->getLogContext(), "quantization::quantizeFunction")
6497
6498 quantization::quantizeFunction(F, precConfig.quantConfig, B,
6499 *cctx.loweredInfoMap,
6500 precConfig.precisionModeKindSet);
6501 break;
6502 }
6503
6504 case QuantizationMode::None: {
6505 break;
6506 }
6507 }
6508
6509 if (precConfig.convertToFP16) {
6510 LOG_SCOPE(F->getLogContext(), "glow::convertFunctionToFloat16")
6511 convertFunctionToFloat16(F, precConfig);
6512 FunctionPassManager FPM("FP16GraphOptzFPM",
6513 createFP16GraphOptimizationPassPipeline());
6514 FPM.run(F, cctx);
6515 }
6516
6517 // By default, FP16 SLS accumulation is not enabled.
6518 // If requested, Force all ops in the SLS family to use FP16 accumulation.
6519 if (precConfig.forceFP16AccumSLS) {
6520 setFP16AccumSLS(F, precConfig);
6521 }
6522
6523 // Convert UInt4FusedFP16QTy/UInt8FusedFP16QTy to UInt8FusedQTy.
6524 if (precConfig.convert4BitFusedToFP32 || precConfig.convert8BitFusedToFP32) {
6525 LOG_SCOPE(F->getLogContext(), "glow::convertFunctionToFP32ScaleOffset");
6526 convertFunctionToFP32ScaleOffset(F, precConfig);
6527 }
6528
6529 // In FusedRowwiseQSLWS, convert its indices from Int32(if there is any) to
6530 // Int64.
6531 if (precConfig.convertIndicesToInt64) {
6532 LOG_SCOPE(F->getLogContext(), "glow::convertFunctionIndicesToInt64");
6533 convertFunctionIndicesToInt64(F, precConfig);
6534 }
6535}
6536
6537Error glow::optimizeFunctionBeforeLowering(Function *F,
6538 CompilationContext &cctx) {
6539 LOG_SCOPE(F->getLogContext(), "glow::optimizeFunctionBeforeLowering")
6540
6541 // If we only want to lower the Function, do nothing here.
6542 if (cctx.optimizationOpts.onlyLowerFuns.count(F)) {
6543 return Error::success();
6544 }
6545
6546 // Verify the function pre-optimization/lowering.
6547 assert(F->verify() && "Function must be valid");
6548
6549 // Verify that the CompilationContext is set up correctly.
6550 RETURN_IF_ERR(cctx.verify());
6551
6552 // Fold low-level operators into higher-level operators.
6553 // This is useful when compiling an input model where some high-level
6554 // operators have been lowered (this can be for instance a side effect of
6555 // model converters, like converters from Tensorflow to ONNX). In this
6556 // situation, such folding can then enable more optimizations and also improve
6557 // the performance backends that support natively such high-level operators.
6558 ::glow::fold(F, cctx);
6559
6560 // Optimize the graph. Only runs optimizations that are target-independent.
6561 ::glow::optimize(F, cctx);
6562 return Error::success();
6563}
6564
6565/// Error message to print when there is a graph hash checking error.
6566static const char *graphPreLowerHashCheckErrMsg =
6567 R"RAW(Graph hash mismatch!
6568%s
6569%s
6570Potential causes:
65711. The profile YAML file was produced with an older version of the Glow tools
6572 while the quantization of the model is performed with a newer version.
65732. The profile YAML file was produced for a different model than the model used
6574 for quantization.
65753. The profile YAML file was produced for the same model but for a different
6576 batch size. If the profile was generated using the 'image-classifier' Glow
6577 tool you can select the batch size of the model during profiling using the
6578 'minibatch' option. During quantization you can choose the batch size of the
6579 model by choosing for each input placeholder the tensor size using the
6580 'model-input' option.
6581)RAW";
6582
6583// NOTE: When updating this function, please also update the documentation in
6584// docs/GraphOptimizationPipeline.md
6585Error glow::optimizeFunction(Function *F, const Backend &B,
6586 CompilationContext &cctx,
6587 const glow::runtime::DeviceInfo *devInfo) {
6588 LOG_SCOPE(F->getLogContext(), "glow::optimizeFunction")
6589
6590 // If requested only lower the Function and early return.
6591 if (cctx.optimizationOpts.onlyLowerFuns.count(F)) {
6592 ::glow::lower(F, cctx, &B);
6593 // Cleanup from lowering via DCE.
6594 runDCEPass(F, cctx);
6595
6596 if (!B.verify(*F, cctx.verboseCompile)) {
6597 return MAKE_ERR(
6598 ErrorValue::ErrorCode::COMPILE_UNSUPPORTED_NODE_AFTER_OPTIMIZE,
6599 "Unsupported node(s) found after only-lowering path for Function " +
6600 F->getName().str() + " for backend " + B.getBackendName());
6601 }
6602 return Error::success();
6603 }
6604
6605 RETURN_IF_ERR(optimizeFunctionBeforeLowering(F, cctx));
6606
6607 // Graph hash check:
6608 // - During PROFILING store the hash of the graph at this point of the
6609 // optimization pipeline. The hash will be exported in the YAML profile.
6610 // - During QUANTIZATION the hash imported from the YAML profile is used to
6611 // verify the hash of the graph. This is helpful to catch mismatches between
6612 // the graph used during profiling/quantization.
6613 const PrecisionConfiguration &precConfig = cctx.precisionConfig;
6614 if (precConfig.quantMode == QuantizationMode::Profile) {
6615 cctx.info.graphPreLowerHash = F->getHash();
6616 } else if (precConfig.quantMode == QuantizationMode::Quantize) {
6617 const auto &quantConfig = cctx.precisionConfig.quantConfig;
6618 if (quantConfig.checkGraphPreLowerHash) {
6619 auto profileHash = quantConfig.graphPreLowerHash;
6620 auto currentHash = F->getHash();
6621 auto profileHashStr =
6622 strFormat("Profile graph hash: 0x%" PRIX64, (uint64_t)(profileHash));
6623 auto currentHashStr =
6624 strFormat("Current graph hash: 0x%" PRIX64, (uint64_t)(currentHash));
6625 RETURN_ERR_IF_NOT(profileHash == currentHash,
6626 strFormat(graphPreLowerHashCheckErrMsg,
6627 profileHashStr.c_str(),
6628 currentHashStr.c_str()));
6629 }
6630 }
6631
6632 // Lower the graph into a sequence of low-level linear algebra operations.
6633 if (precConfig.quantMode == QuantizationMode::Profile) {
6634 // When profiling, pass a nullptr for the backend, signaling that all nodes
6635 // should be lowered. loweredInfoMap logs what is lowered from what for
6636 // later use when creating quantization infos. Also pass the precision mode
6637 // kind set as nodes to not lower, specified higher up in the stack.
6638 ::glow::lower(F, cctx, /* backend */ nullptr,
6639 precConfig.precisionModeKindSet);
6640 } else {
6641 // Lower based on the backend's preferences.
6642 ::glow::lower(F, cctx, &B);
6643 }
6644
6645 // Transforms the graph by demoting i64 to i32.
6646 transformIndexTypeDemotion(B, F, cctx);
6647
6648 // Transform given precision mode; may quantize, convert to fp16, or
6649 // instrument with profiling nodes. This must be done after lowering.
6650 transformForPrecisionMode(B, F, cctx);
6651
6652 // Optimize the quantized graph because quantization nodes should be optimized
6653 // before folding Activation into Conv.
6654 ::glow::optimize(F, cctx);
6655
6656 // Fold activations before lowering to enable cases which would not fuse after
6657 // lowering. This concerns particularly convolution&relu since relu will be
6658 // lowered to max(0, x).
6659 foldActivations(F, cctx, &B);
6660
6661 // Lower once more, in case precision transform has introduced operators that
6662 // need to be lowered, e.g., Clip.
6663 ::glow::lower(F, cctx, &B);
6664
6665 // Optimize the graph again now that we have a lowered representation.
6666 ::glow::optimize(F, cctx);
6667
6668 // If requested fold ElemKind conversion Nodes into static Placeholders,
6669 // inputs, and outputs (Placeholders and SaveNodes).
6670 if (cctx.optimizationOpts.foldStaticPlaceholderConversions ||
6671 cctx.optimizationOpts.foldElemKindConversionIntoIO) {
6672 std::unique_ptr<FunctionPassPipeline> pipeline =
6673 glow::make_unique<FunctionPassPipeline>();
6674 pipeline->pushBack({FunctionPassID::FoldElemKindConversionIntoInputs});
6675
6676 if (cctx.optimizationOpts.foldElemKindConversionIntoIO) {
6677 pipeline->pushBack({FunctionPassID::FoldElemKindConversionIntoOutputs});
6678 }
6679 FunctionPassManager FPM("FoldElemKindConversionIntoIO",
6680 std::move(pipeline));
6681 if (FPM.run(F, cctx)) {
6682 ::glow::optimize(F, cctx);
6683 }
6684 }
6685
6686 if (B.shouldPreQuantizeConstants()) {
6687 // Do the actual float ->fix-point conversion of constant tensors before
6688 // Post-lowering.
6689 ::glow::convertQuantizedConstants(F, cctx);
6690 }
6691
6692 // Allow the backend to transform the graph after lowering.
6693 RETURN_IF_EXPECTED_IS_ERR(B.transformPostLowering(F, cctx, devInfo));
6694
6695 if (!B.shouldPreQuantizeConstants()) {
6696 // Do the actual float ->fix-point conversion of constant tensors after
6697 // Post-lowering.
6698 ::glow::convertQuantizedConstants(F, cctx);
6699 }
6700
6701 // Optimize the graph again after the backend transformation.
6702 // In particular, DCE is very likely to be useful.
6703 ::glow::optimize(F, cctx, B);
6704
6705 // We already started using backend specific verification when the function
6706 // state became lowered. Do one more verification pass to make sure everything
6707 // is in order and to bail if it is not.
6708 if (cctx.optimizationOpts.delayAndRecordConstantModification ||
6709 cctx.optimizationOpts.skipBackendSupportCheck) {
6710 // Only do verification without checking the backend's support if requested,
6711 // or if we are disallowing constant modification (since this flag may have
6712 // prevented a backend from supporting some Nodes, and may be supported
6713 // after constant folding finishes). Expect the caller to verify that Nodes
6714 // are supported by the backend later on in such cases.
6715 RETURN_ERR_IF_NOT(F->verify(&B),
6716 "Verification after optimization failed for Function " +
6717 F->getName().str() + " and Backend " +
6718 B.getBackendName());
6719 } else if (!B.verify(*F, cctx.verboseCompile)) {
6720 return MAKE_ERR(
6721 ErrorValue::ErrorCode::COMPILE_UNSUPPORTED_NODE_AFTER_OPTIMIZE,
6722 "Unsupported node(s) found after optimizing Function " +
6723 F->getName().str() + " for backend " + B.getBackendName());
6724 }
6725 return Error::success();
6726}
6727
6728bool glow::executeVerticalFCWeightsSplit(Function *F, unsigned numOfChunks,
6729 unsigned minKToSplit) {
6730 DCHECK(numOfChunks > 0) << "numOfChunks must be a positive number, given: "
6731 << numOfChunks;
6732 DCHECK(minKToSplit > 0) << "minKToSplit must be a positive number, given: "
6733 << minKToSplit;
6734
6735 bool changed = false;
6736 for (auto it = F->getNodes().begin(), e = F->getNodes().end(); it != e;
6737 ++it) {
6738 auto *FC = dyn_cast<FullyConnectedNode>(it);
6739 if (!FC) {
6740 continue;
6741 }
6742
6743 size_t K = FC->getWeights().dims()[1];
6744 if (K < minKToSplit) {
6745 continue;
6746 }
6747
6748 auto input = FC->getInput();
6749 auto weights = FC->getWeights();
6750 auto bias = FC->getBias();
6751
6752 dim_t elemPerChunk = (bias.dims()[0] + numOfChunks - 1) / numOfChunks;
6753 dim_t sliceStart = 0;
6754 std::vector<NodeValue> fcs(numOfChunks);
6755
6756 // Split weights across second dimension into numOfChunks pieces.
6757 // Input dimension is [M;K] and kept untouched.
6758 // Bias dimension is [N], split into chunks.
6759 // Weight dimension is [K;N], split into numOfChunks chunks,
6760 // [K;N/numOfChunks] each.
6761 // Last chunk might require special handling in case
6762 // N is not divisible by numOfChunks.
6763 auto *fcType = F->getParent()->uniqueTypeWithNewShape(
6764 FC->getResult().getType(), {FC->getResult().dims()[0], elemPerChunk});
6765
6766 for (unsigned i = 0; i < numOfChunks; ++i) {
6767 // Last chunk might need special handling if bias dimension
6768 // is not divisible by numOfChunks.
6769 if (i == numOfChunks - 1 && bias.dims()[0] % numOfChunks != 0) {
6770 elemPerChunk = bias.dims()[0] - (numOfChunks - 1) * elemPerChunk;
6771 fcType = F->getParent()->uniqueTypeWithNewShape(
6772 FC->getResult().getType(),
6773 {FC->getResult().dims()[0], elemPerChunk});
6774 }
6775
6776 auto *weightSlice = F->createSlice(
6777 "weight_slice." + weights.getNode()->getName().str(), weights,
6778 {0, sliceStart}, {weights.dims()[0], sliceStart + elemPerChunk});
6779 auto *biasSlice =
6780 F->createSlice("bias_slice." + bias.getNode()->getName().str(), bias,
6781 {sliceStart}, {sliceStart + elemPerChunk});
6782 fcs[i] = F->createFullyConnected("fc_slice." + FC->getName().str(), input,
6783 weightSlice->getResult(),
6784 biasSlice->getResult(), fcType);
6785 sliceStart += elemPerChunk;
6786 }
6787
6788 auto *concat =
6789 F->createConcat("concat." + FC->getName().str(), fcs, /*dimension*/ 1);
6790 FC->getResult().replaceAllUsesOfWith(concat);
6791 changed = true;
6792 }
6793
6794 return changed;
6795}
6796
6797static Expected<ConcatNode *> parallelizeAndReplaceReshapeNode(
6798 Function *F, Node *curNode, dim_t numOfChunksNode, dim_t inputBatchIdx,
6799 dim_t resultIdx, llvm::ArrayRef<int> splitDims, size_t resultDim,
6800 dim_t modelParallelSplitAlignment = 1) {
6801 const int inputIdx = splitDims[inputBatchIdx];
6802 RETURN_ERR_IF_NOT(inputIdx >= 0, "Input batch idx must be split");
6803 RETURN_ERR_IF_NOT(modelParallelSplitAlignment == 1,
6804 "modelParallelSplitAlignment must be 1");
6805 const dim_t batchSize = curNode->getNthInput(inputBatchIdx).dims()[inputIdx];
6806 // We can only apply this parallelization when the input/output batch sizes
6807 // can be divided by numOfChunksNode
6808 if ((curNode->getNthResult(0).dims()[0] % numOfChunksNode != 0) ||
6809 (batchSize % numOfChunksNode != 0)) {
6810 return nullptr;
6811 }
6812 const dim_t elemPerChunk = batchSize / numOfChunksNode;
6813
6814 RETURN_ERR_IF_NOT(
6815 batchSize >= numOfChunksNode,
6816 strFormat("Invalid parallelization; batchSize %lu must be "
6817 ">= numOfChunksNode %lu for node %s with kind %s",
6818 (unsigned long)batchSize, (unsigned long)numOfChunksNode,
6819 curNode->getName().str().c_str(), curNode->getKindName()));
6820
6821 std::vector<NodeValue> newNodes(numOfChunksNode);
6822 for (dim_t i = 0; i < numOfChunksNode; ++i) {
6823 // Calculate the out type of this chunk.
6824 const dim_t sliceStart = i * elemPerChunk;
6825 const dim_t sliceEnd =
6826 (i < numOfChunksNode - 1) ? sliceStart + elemPerChunk : batchSize;
6827 VLOG(1) << "\tChunk " << i << ": start: " << sliceStart
6828 << " end: " << sliceEnd << "\n";
6829 auto outDims = curNode->dims(resultIdx).vec();
6830 VLOG(1) << "original out dims: [" << folly::join(", ", outDims) << "]";
6831 RETURN_ERR_IF_NOT(resultDim < outDims.size(),
6832 "outDims access out of range");
6833 outDims[resultDim] /= numOfChunksNode;
6834 VLOG(1) << "modified out dims: [" << folly::join(", ", outDims) << "]";
6835
6836 // Clone the original Node, so that it keeps all of the inputs/members of
6837 // the original Node. Then modify the output type so that its new shape is
6838 // correct, and below change the inputs to the sliced inputs.
6839 Node *clone = curNode->clone();
6840 clone->getNthResult(resultIdx).setTypeUnsafe(
6841 F->getParent()->uniqueTypeWithNewShape(curNode->getType(resultIdx),
6842 outDims));
6843 F->addNode(clone);
6844
6845 // Loop over all of the inputs and slice those inputs that need to be
6846 // sliced, and set them on the clone.
6847 for (int j = 0, e = curNode->getNumInputs(); j < e; j++) {
6848 int dim = splitDims[j];
6849 if (dim == -1) {
6850 continue;
6851 }
6852
6853 NodeValue currInput = curNode->getNthInput(j);
6854 auto sliceDimsStart = std::vector<dim_t>(currInput.dims().size(), 0);
6855 RETURN_ERR_IF_NOT(dim < sliceDimsStart.size(),
6856 "sliceDimsStart access out of range");
6857 sliceDimsStart[dim] = sliceStart;
6858 auto sliceDimsEnd = currInput.dims().vec();
6859 RETURN_ERR_IF_NOT(dim < sliceDimsEnd.size(),
6860 "sliceDimsEnd access out of range");
6861 sliceDimsEnd[dim] = sliceEnd;
6862 VLOG(1) << "start: [" << folly::join(", ", sliceDimsStart) << "]";
6863 VLOG(1) << "end: [" << folly::join(", ", sliceDimsEnd) << "]";
6864 VLOG(1) << "Input name: " << currInput.getNode()->getName().str() << "\n";
6865
6866 auto *inputSlice =
6867 F->createSlice("dp_slice." + currInput.getNode()->getName().str() +
6868 "." + std::to_string(i),
6869 currInput, sliceDimsStart, sliceDimsEnd);
6870 clone->setNthInput(j, inputSlice);
6871
6872 newNodes[i] = NodeValue(clone, resultIdx);
6873 }
6874 }
6875
6876 std::vector<NodeValue> newNodesClip(numOfChunksNode);
6877 int cnt = 0;
6878 // Add extra Clip ops
6879 // TODO: need to remove the Clip ops once the FC inputs can be put on SRAM
6880 for (auto &node : newNodes) {
6881 ClipNode *newCN = F->createClip(node.getNode()->getName().str() + "_clip_",
6882 node.getNode(), kMinFP16, kMaxFP16);
6883 newNodesClip[cnt] = newCN;
6884 cnt++;
6885 }
6886
6887 // Now that we have split the node into many, concat all of the pieces back
6888 // together and replace the original by the concat.
6889 VLOG(1) << "Creating Concat";
6890 auto *concat = F->createConcat("concat." + curNode->getName().str(),
6891 newNodesClip, resultDim);
6892 curNode->getNthResult(resultIdx).replaceAllUsesOfWith(concat);
6893 return concat;
6894}
6895
6896/// Helper to parallelize a node \p curNode from \p F into \p numOfChunksNode
6897/// Nodes by slicing its inputs, creating clones of it and changing the inputs
6898/// of the clones to the slices, and then concatenating all of the clones
6899/// together and replacing \p curNode with the concat. \p inputBatchIdx is the
6900/// input idx from \p curNode that will be split (there may be more than one
6901/// input to split, but their splitDim should all have the same size).
6902/// \p splitDim represents what dimension to split for each of the inputs to
6903/// \p curNode. \p resultDim is the dimension on which we are splitting and then
6904/// concatenating the results. \p resultIdx represents the result index from
6905/// \p curNode that is being split and later concatenated. The size of the
6906/// splits will be increased to a multiple of \p modelParallelSplitAlignment, if
6907/// possible. If the result after aligning the splits is that the new aligned
6908/// splits are larger than the original requested num splits, then the number of
6909/// resulting splits may be less than requested. \returns an Expected of the
6910/// ConcatNode that is created and replaces \p curNode, or otherwise an Error if
6911/// parallelization had some issue.
6912static Expected<ConcatNode *>
6913parallelizeAndReplaceNode(Function *F, Node *curNode, dim_t numOfChunksNode,
6914 dim_t inputBatchIdx, dim_t resultIdx,
6915 llvm::ArrayRef<int> splitDims, size_t resultDim,
6916 dim_t modelParallelSplitAlignment = 1) {
6917 const int inputIdx = splitDims[inputBatchIdx];
6918 CHECK_GE(inputIdx, 0) << "Input batch idx must be split";
6919 const dim_t batchSize = curNode->getNthInput(inputBatchIdx).dims()[inputIdx];
6920 const dim_t elemPerChunk = batchSize / numOfChunksNode;
6921 const dim_t remain = batchSize % numOfChunksNode;
6922 // This alignment will create aligned splits. So for example, if we're
6923 // splitting 190 by 3, then without alignment it would be {64, 63, 63}.
6924 // With alignment of 64, it will be {64, 64, 62}
6925 const dim_t alignedElemPerChunk =
6926 ((elemPerChunk + (modelParallelSplitAlignment - 1)) /
6927 modelParallelSplitAlignment) *
6928 modelParallelSplitAlignment;
6929
6930 RETURN_ERR_IF_NOT(
6931 batchSize >= numOfChunksNode,
6932 "Invalid parallelization; batchSize " + std::to_string(batchSize) +
6933 " must be >= numOfChunksNode " + std::to_string(numOfChunksNode) +
6934 " for node " + curNode->getName().str() + " with kind " +
6935 curNode->getKindName());
6936
6937 // Potentially modify numOfChunksNode, if the aligned size times current
6938 // numOfChunksNode exceeds the total size
6939 if (modelParallelSplitAlignment > 1) {
6940 numOfChunksNode =
6941 (batchSize + alignedElemPerChunk - 1) / alignedElemPerChunk;
6942 }
6943
6944 std::vector<NodeValue> newNodes(numOfChunksNode);
6945 for (dim_t i = 0; i < numOfChunksNode; ++i) {
6946 // Calculate the out type of this chunk.
6947 dim_t sliceStart, sliceEnd;
6948
6949 if (modelParallelSplitAlignment > 1) {
6950 // If we are using aligned splits, then slice by multiples of the
6951 // alignment, leaving the rest to the last split. The last split is
6952 // necessarily smaller than the other splits.
6953 sliceStart = i * alignedElemPerChunk;
6954 sliceEnd = (i < numOfChunksNode - 1) ? sliceStart + alignedElemPerChunk
6955 : batchSize;
6956 } else {
6957 // Otherwise, distribute elements evenly across the splits and sprinkle
6958 // the remainder evenly as well
6959 sliceStart = i * elemPerChunk + std::min(i, remain);
6960 sliceEnd = sliceStart + elemPerChunk + ((i < remain) ? 1 : 0);
6961 }
6962 VLOG(1) << "\tChunk " << i << ": start: " << sliceStart
6963 << " end: " << sliceEnd << "\n";
6964 auto outDims = curNode->dims(resultIdx).vec();
6965 outDims[resultDim] = (sliceEnd - sliceStart);
6966 for (auto outDim : outDims) {
6967 VLOG(1) << "outDim: " << outDim << "\n";
6968 }
6969
6970 // Clone the original Node, so that it keeps all of the inputs/members of
6971 // the original Node. Then modify the output type so that its new shape is
6972 // correct, and below change the inputs the sliced inputs.
6973 Node *clone = curNode->clone();
6974 clone->getNthResult(resultIdx).setTypeUnsafe(
6975 F->getParent()->uniqueTypeWithNewShape(curNode->getType(resultIdx),
6976 outDims));
6977 F->addNode(clone);
6978
6979 // Loop over all of the inputs and slice those inputs that need to be
6980 // sliced, and set them on the clone.
6981 for (int j = 0, e = curNode->getNumInputs(); j < e; j++) {
6982 int dim = splitDims[j];
6983 if (dim == -1) {
6984 continue;
6985 }
6986
6987 NodeValue currInput = curNode->getNthInput(j);
6988 auto sliceDimsStart = std::vector<dim_t>(currInput.dims().size(), 0);
6989 sliceDimsStart[dim] = sliceStart;
6990 auto sliceDimsEnd = currInput.dims().vec();
6991 sliceDimsEnd[dim] = sliceEnd;
6992 VLOG(1) << "start: ";
6993 for (auto sliceDimStart : sliceDimsStart) {
6994 VLOG(1) << sliceDimStart << "\n";
6995 }
6996 VLOG(1) << "end: ";
6997 for (auto sliceDimEnd : sliceDimsEnd) {
6998 VLOG(1) << sliceDimEnd << "\n";
6999 }
7000 VLOG(1) << "Input name: " << currInput.getNode()->getName().str() << "\n";
7001
7002 auto *inputSlice =
7003 F->createSlice("dp_slice." + currInput.getNode()->getName().str() +
7004 "." + std::to_string(i),
7005 currInput, sliceDimsStart, sliceDimsEnd);
7006 clone->setNthInput(j, inputSlice);
7007
7008 newNodes[i] = NodeValue(clone, resultIdx);
7009 }
7010 }
7011
7012 // Now that we have split the node into many, concat all of the pieces back
7013 // together and replace the original by the concat.
7014 VLOG(1) << "Creating Concat";
7015 auto *concat = F->createConcat("concat." + curNode->getName().str(), newNodes,
7016 resultDim);
7017 curNode->getNthResult(resultIdx).replaceAllUsesOfWith(concat);
7018 return concat;
7019}
7020
7021/// Specialized helper for parallelizing ConcatNode \p CN from \p F into
7022/// \p numOfChunks.
7023static Expected<ConcatNode *>
7024parallelizeAndReplaceConcat(Function *F, ConcatNode *CN, dim_t numOfChunks) {
7025 auto in = CN->getInputs();
7026
7027 const dim_t inputSize = in.size();
7028 const dim_t elemPerChunk = inputSize / numOfChunks;
7029 const dim_t remain = inputSize % numOfChunks;
7030
7031 RETURN_ERR_IF_NOT(
7032 elemPerChunk > 0,
7033 "When parallelizing a Concat, inputSize must be larger than numOfChunks");
7034
7035 auto startIt = in.begin();
7036 std::vector<NodeValue> finalConcatInputs;
7037 for (dim_t i = 0; i < numOfChunks; ++i) {
7038 const dim_t sliceStart = i * elemPerChunk + std::min(i, remain);
7039 const dim_t sliceEnd = sliceStart + elemPerChunk + ((i < remain) ? 1 : 0);
7040
7041 // Slice out the original Concat chunk's inputs, create a new Concat with
7042 // the slice, and then add the new Concat to the final Concat's inputs.
7043 std::vector<NodeValue> newInputs(startIt + sliceStart, startIt + sliceEnd);
7044 ConcatNode *concatSlice = F->createConcat(CN->getName().str() + "_slice",
7045 newInputs, CN->getDim());
7046 finalConcatInputs.push_back(concatSlice);
7047 }
7048 ConcatNode *finalConcat = F->createConcat(CN->getName().str() + "_merge",
7049 finalConcatInputs, CN->getDim());
7050 CN->getResult().replaceAllUsesOfWith(finalConcat);
7051 return finalConcat;
7052}
7053
7054#define SPLIT_ELTWISE_UNARY_OP_HELPER(NodeKind, Axis) \
7055 if (curNode->getNthInput(NodeKind##Node::InputIdx).dims().size() <= Axis) { \
7056 break; \
7057 } \
7058 splitDims[NodeKind##Node::InputIdx] = Axis; \
7059 ASSIGN_VALUE_OR_RETURN_ERR( \
7060 CN, parallelizeAndReplaceNode( \
7061 F, curNode, curNumOfChunks, NodeKind##Node::InputIdx, \
7062 NodeKind##Node::ResultIdx, splitDims, /*resultDim*/ Axis, \
7063 modelParallelSplitAlignment));
7064
7065#define SPLIT_ELTWISE_BINARY_OP_HELPER(NodeKind, Axis) \
7066 if (curNode->getNthInput(NodeKind##Node::LHSIdx).dims().size() <= Axis) { \
7067 break; \
7068 } \
7069 splitDims[NodeKind##Node::LHSIdx] = Axis; \
7070 splitDims[NodeKind##Node::RHSIdx] = Axis; \
7071 ASSIGN_VALUE_OR_RETURN_ERR( \
7072 CN, parallelizeAndReplaceNode( \
7073 F, curNode, curNumOfChunks, NodeKind##Node::LHSIdx, \
7074 NodeKind##Node::ResultIdx, splitDims, /*resultDim*/ Axis, \
7075 modelParallelSplitAlignment));
7076
7077Expected<std::unordered_map<Node *, ConcatNode *>> glow::parallelizeOps(
7078 Function *F, const llvm::DenseMap<Node *, size_t> &numOfChunksMap,
7079 const llvm::DenseMap<Node *, ParallelTransformKind> &parOpts,
7080 size_t numOfChunks, size_t modelParallelSplitAlignment) {
7081 // Since we will be transforming the original list of nodes, reverse iterate.
7082 auto &nodes = F->getNodes();
7083 size_t numProcessedNodes = 0;
7084 std::unordered_map<Node *, ConcatNode *> replacedMap;
7085 for (auto it = nodes.rbegin(), e = nodes.rend(); it != e; it++) {
7086 Node *curNode = &*it;
7087 size_t curNumOfChunks = numOfChunks;
7088 auto numOfChunksIt = numOfChunksMap.find(curNode);
7089 if (numOfChunksIt != numOfChunksMap.end()) {
7090 curNumOfChunks = numOfChunksIt->second;
7091 }
7092
7093 ParallelTransformKind parTransformMode = ParallelTransformKind::None;
7094 auto parOptsIt = parOpts.find(curNode);
7095 if (parOptsIt != parOpts.end()) {
7096 parTransformMode = parOptsIt->second;
7097 ++numProcessedNodes;
7098 }
7099
7100 VLOG(1) << "Attempting to Parallelizing Node: " << curNode->getName().str()
7101 << "\n";
7102
7103 ConcatNode *CN = nullptr;
7104
7105 // Use this vector to communicate what dims to split to
7106 // parallelizeAndReplaceNode(). -1 represents not splitting at all.
7107 llvm::SmallVector<int, 3> splitDims(curNode->getNumInputs(), -1);
7108
7109 // Set model parallelization axis
7110 // Default model parallelization is along axis = 1, hence the default value.
7111 dim_t modelParAxis = 1;
7112 switch (parTransformMode) {
7113#define MODEL_AXIS_CASE(_N) \
7114 case ParallelTransformKind::Model_Axis##_N: \
7115 modelParAxis = _N; \
7116 parTransformMode = ParallelTransformKind::Model; \
7117 break;
7118 MODEL_AXIS_CASE(1)
7119 MODEL_AXIS_CASE(2)
7120 MODEL_AXIS_CASE(3)
7121 MODEL_AXIS_CASE(4)
7122 MODEL_AXIS_CASE(5)
7123 default:
7124 break;
7125 }
7126
7127 switch (parTransformMode) {
7128 case ParallelTransformKind::Data: {
7129 switch (curNode->getKind()) {
7130 case Kinded::Kind::FullyConnectedNodeKind: {
7131 splitDims[FullyConnectedNode::InputIdx] = 0;
7132 ASSIGN_VALUE_OR_RETURN_ERR(
7133 CN, parallelizeAndReplaceNode(
7134 F, curNode, curNumOfChunks, FullyConnectedNode::InputIdx,
7135 FullyConnectedNode::ResultIdx, splitDims, 0));
7136 break;
7137 }
7138 case Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind: {
7139 splitDims[RowwiseQuantizedFullyConnectedNode::InputIdx] = 0;
7140 ASSIGN_VALUE_OR_RETURN_ERR(
7141 CN,
7142 parallelizeAndReplaceNode(
7143 F, curNode, curNumOfChunks,
7144 RowwiseQuantizedFullyConnectedNode::InputIdx,
7145 RowwiseQuantizedFullyConnectedNode::ResultIdx, splitDims, 0));
7146 break;
7147 }
7148 case Kinded::Kind::MatMulNodeKind: {
7149 splitDims[MatMulNode::LHSIdx] = 0;
7150 ASSIGN_VALUE_OR_RETURN_ERR(
7151 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7152 MatMulNode::LHSIdx,
7153 MatMulNode::ResultIdx, splitDims, 0));
7154 break;
7155 }
7156 case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind: {
7157 splitDims[ChannelwiseQuantizedConvolutionNode::InputIdx] = 0;
7158 ASSIGN_VALUE_OR_RETURN_ERR(
7159 CN,
7160 parallelizeAndReplaceNode(
7161 F, curNode, curNumOfChunks,
7162 ChannelwiseQuantizedConvolutionNode::InputIdx,
7163 ChannelwiseQuantizedConvolutionNode::ResultIdx, splitDims, 0));
7164 break;
7165 }
7166 case Kinded::Kind::ConvolutionNodeKind: {
7167 splitDims[ConvolutionNode::InputIdx] = 0;
7168 ASSIGN_VALUE_OR_RETURN_ERR(
7169 CN, parallelizeAndReplaceNode(
7170 F, curNode, curNumOfChunks, ConvolutionNode::InputIdx,
7171 ConvolutionNode::ResultIdx, splitDims, 0));
7172 break;
7173 }
7174 case Kinded::Kind::AdaptiveAvgPoolNodeKind: {
7175 splitDims[AdaptiveAvgPoolNode::InputIdx] = 0;
7176 ASSIGN_VALUE_OR_RETURN_ERR(
7177 CN, parallelizeAndReplaceNode(
7178 F, curNode, curNumOfChunks, AdaptiveAvgPoolNode::InputIdx,
7179 AdaptiveAvgPoolNode::ResultIdx, splitDims, 0));
7180 break;
7181 }
7182 case Kinded::Kind::ROIAlignNodeKind: {
7183 splitDims[ROIAlignNode::BoxesIdx] = 0;
7184 splitDims[ROIAlignNode::BatchIndicesIdx] = 0;
7185 ASSIGN_VALUE_OR_RETURN_ERR(
7186 CN, parallelizeAndReplaceNode(
7187 F, curNode, curNumOfChunks, ROIAlignNode::BoxesIdx,
7188 ROIAlignNode::ResultIdx, splitDims, 0));
7189 break;
7190 }
7191 case Kinded::Kind::MaxPoolNodeKind: {
7192 splitDims[MaxPoolNode::InputIdx] = 0;
7193 ASSIGN_VALUE_OR_RETURN_ERR(
7194 CN, parallelizeAndReplaceNode(
7195 F, curNode, curNumOfChunks, MaxPoolNode::InputIdx,
7196 MaxPoolNode::ResultIdx, splitDims, 0));
7197 break;
7198 }
7199 case Kinded::Kind::ReshapeNodeKind: {
7200 splitDims[ReshapeNode::InputIdx] = 0;
7201 const dim_t batchSize =
7202 curNode->getNthInput(ReshapeNode::InputIdx).dims()[0];
7203 if (batchSize != curNode->getNthResult(0).dims()[0]) {
7204 if (glow::flags::SparseNNParallelizeReshapeOnBatchDim) {
7205 ASSIGN_VALUE_OR_RETURN_ERR(
7206 CN, parallelizeAndReplaceReshapeNode(
7207 F, curNode, curNumOfChunks, ReshapeNode::InputIdx,
7208 ReshapeNode::ResultIdx, splitDims, 0));
7209 } else {
7210 // Do nothing if reshape applies to the first batch dimension
7211 LOG(INFO) << "Reshape changes batch dimension; Disabling data "
7212 "parallel split";
7213 }
7214 break;
7215 }
7216 ASSIGN_VALUE_OR_RETURN_ERR(
7217 CN, parallelizeAndReplaceNode(
7218 F, curNode, curNumOfChunks, ReshapeNode::InputIdx,
7219 ReshapeNode::ResultIdx, splitDims, 0));
7220 break;
7221 }
7222 case Kinded::Kind::AddNodeKind: {
7223 splitDims[AddNode::LHSIdx] = 0;
7224 splitDims[AddNode::RHSIdx] = 0;
7225 ASSIGN_VALUE_OR_RETURN_ERR(
7226 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7227 AddNode::LHSIdx, AddNode::ResultIdx,
7228 splitDims, 0));
7229 break;
7230 }
7231 case Kinded::Kind::SubNodeKind: {
7232 splitDims[SubNode::LHSIdx] = 0;
7233 splitDims[SubNode::RHSIdx] = 0;
7234 ASSIGN_VALUE_OR_RETURN_ERR(
7235 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7236 SubNode::LHSIdx, SubNode::ResultIdx,
7237 splitDims, 0));
7238 break;
7239 }
7240 case Kinded::Kind::BatchMatMulNodeKind: {
7241 splitDims[BatchMatMulNode::LHSIdx] = 0;
7242 splitDims[BatchMatMulNode::RHSIdx] = 0;
7243 ASSIGN_VALUE_OR_RETURN_ERR(
7244 CN, parallelizeAndReplaceNode(
7245 F, curNode, curNumOfChunks, BatchMatMulNode::LHSIdx,
7246 BatchMatMulNode::ResultIdx, splitDims, 0));
7247 break;
7248 }
7249 case Kinded::Kind::MulNodeKind: {
7250 splitDims[AddNode::LHSIdx] = 0;
7251 splitDims[AddNode::RHSIdx] = 0;
7252 ASSIGN_VALUE_OR_RETURN_ERR(
7253 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7254 MulNode::LHSIdx, MulNode::ResultIdx,
7255 splitDims, 0));
7256 break;
7257 }
7258 case Kinded::Kind::PowNodeKind: {
7259 splitDims[PowNode::LHSIdx] = 0;
7260 splitDims[PowNode::RHSIdx] = 0;
7261 ASSIGN_VALUE_OR_RETURN_ERR(
7262 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7263 PowNode::LHSIdx, PowNode::ResultIdx,
7264 splitDims, 0));
7265 break;
7266 }
7267 case Kinded::Kind::SelectNodeKind: {
7268 splitDims[SelectNode::LHSIdx] = 0;
7269 splitDims[SelectNode::RHSIdx] = 0;
7270 splitDims[SelectNode::CondIdx] = 0;
7271 ASSIGN_VALUE_OR_RETURN_ERR(
7272 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7273 SelectNode::LHSIdx,
7274 SelectNode::ResultIdx, splitDims, 0));
7275 break;
7276 }
7277 case Kinded::Kind::ExpNodeKind: {
7278 splitDims[ExpNode::InputIdx] = 0;
7279 ASSIGN_VALUE_OR_RETURN_ERR(
7280 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7281 ExpNode::InputIdx, ExpNode::ResultIdx,
7282 splitDims, 0));
7283 break;
7284 }
7285 case Kinded::Kind::SigmoidNodeKind: {
7286 splitDims[SigmoidNode::InputIdx] = 0;
7287 ASSIGN_VALUE_OR_RETURN_ERR(
7288 CN, parallelizeAndReplaceNode(
7289 F, curNode, curNumOfChunks, SigmoidNode::InputIdx,
7290 SigmoidNode::ResultIdx, splitDims, 0));
7291 break;
7292 }
7293 case Kinded::Kind::SoftMaxNodeKind: {
7294 splitDims[SoftMaxNode::InputIdx] = 0;
7295 ASSIGN_VALUE_OR_RETURN_ERR(
7296 CN, parallelizeAndReplaceNode(
7297 F, curNode, curNumOfChunks, SoftMaxNode::InputIdx,
7298 SoftMaxNode::ResultIdx, splitDims, 0));
7299 break;
7300 }
7301 case Kinded::Kind::LogSoftMaxNodeKind: {
7302 splitDims[LogSoftMaxNode::InputIdx] = 0;
7303 ASSIGN_VALUE_OR_RETURN_ERR(
7304 CN, parallelizeAndReplaceNode(
7305 F, curNode, curNumOfChunks, LogSoftMaxNode::InputIdx,
7306 LogSoftMaxNode::ResultIdx, splitDims, 0));
7307 break;
7308 }
7309 case Kinded::Kind::TanhNodeKind: {
7310 splitDims[TanhNode::InputIdx] = 0;
7311 ASSIGN_VALUE_OR_RETURN_ERR(
7312 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7313 TanhNode::InputIdx,
7314 TanhNode::ResultIdx, splitDims, 0));
7315 break;
7316 }
7317 case Kinded::Kind::SwishNodeKind: {
7318 splitDims[SwishNode::InputIdx] = 0;
7319 ASSIGN_VALUE_OR_RETURN_ERR(
7320 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7321 SwishNode::InputIdx,
7322 SwishNode::ResultIdx, splitDims, 0));
7323 break;
7324 }
7325 case Kinded::Kind::MaxNodeKind: {
7326 splitDims[MaxNode::LHSIdx] = 0;
7327 splitDims[MaxNode::RHSIdx] = 0;
7328 ASSIGN_VALUE_OR_RETURN_ERR(
7329 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7330 MaxNode::LHSIdx, MaxNode::ResultIdx,
7331 splitDims, 0));
7332 break;
7333 }
7334 case Kinded::Kind::MinNodeKind: {
7335 splitDims[MinNode::LHSIdx] = 0;
7336 splitDims[MinNode::RHSIdx] = 0;
7337 ASSIGN_VALUE_OR_RETURN_ERR(
7338 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7339 MinNode::LHSIdx, MinNode::ResultIdx,
7340 splitDims, 0));
7341 break;
7342 }
7343 case Kinded::Kind::TransposeNodeKind: {
7344 auto shuffleVec = cast<TransposeNode>(curNode)->getShuffle();
7345 unsigned_t inputDim = shuffleVec[0];
7346 splitDims[TransposeNode::InputIdx] = inputDim;
7347 ASSIGN_VALUE_OR_RETURN_ERR(
7348 CN, parallelizeAndReplaceNode(
7349 F, curNode, curNumOfChunks, TransposeNode::InputIdx,
7350 TransposeNode::ResultIdx, splitDims, 0));
7351 break;
7352 }
7353 case Kinded::Kind::ReluNodeKind: {
7354 splitDims[ReluNode::InputIdx] = 0;
7355 ASSIGN_VALUE_OR_RETURN_ERR(
7356 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7357 ReluNode::InputIdx,
7358 ReluNode::ResultIdx, splitDims, 0));
7359 break;
7360 }
7361 case Kinded::Kind::GeluNodeKind: {
7362 splitDims[GeluNode::InputIdx] = 0;
7363 ASSIGN_VALUE_OR_RETURN_ERR(
7364 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7365 GeluNode::InputIdx,
7366 GeluNode::ResultIdx, splitDims, 0));
7367 break;
7368 }
7369 case Kinded::Kind::ClipNodeKind: {
7370 splitDims[ClipNode::InputIdx] = 0;
7371 ASSIGN_VALUE_OR_RETURN_ERR(
7372 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7373 ClipNode::InputIdx,
7374 ClipNode::ResultIdx, splitDims, 0));
7375 break;
7376 }
7377 case Kinded::Kind::TileNodeKind: {
7378 TileNode *TN = llvm::dyn_cast<TileNode>(curNode);
7379 RETURN_ERR_IF_NOT(
7380 TN->getAxis() != 0,
7381 "Tile node cannot be split on axis 0 which is being replicated");
7382 splitDims[TileNode::InputIdx] = 0;
7383 ASSIGN_VALUE_OR_RETURN_ERR(
7384 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7385 TileNode::InputIdx,
7386 TileNode::ResultIdx, splitDims, 0));
7387 break;
7388 }
7389 case Kinded::Kind::BatchedReduceAddNodeKind: {
7390 BatchedReduceAddNode *BR = llvm::cast<BatchedReduceAddNode>(curNode);
7391 splitDims[BatchedReduceAddNode::BatchIdx] =
7392 (BR->getAxis() == 0) ? 1 : 0;
7393 ASSIGN_VALUE_OR_RETURN_ERR(
7394 CN, parallelizeAndReplaceNode(
7395 F, curNode, curNumOfChunks, BatchedReduceAddNode::BatchIdx,
7396 BatchedReduceAddNode::ResultIdx, splitDims, 0));
7397 break;
7398 }
7399 case Kinded::Kind::BatchedReduceMeanNodeKind: {
7400 auto *BR = llvm::cast<BatchedReduceMeanNode>(curNode);
7401 const auto &BRaxes = BR->getAxes();
7402 if (std::find(BRaxes.begin(), BRaxes.end(), 0) != BRaxes.end()) {
7403 LOG(INFO) << "BatchedReduceMean along the first dimension not "
7404 "parallelized. Current node: "
7405 << BR->getDebugDesc();
7406 } else {
7407 splitDims[BatchedReduceMeanNode::BatchIdx] = 0;
7408 ASSIGN_VALUE_OR_RETURN_ERR(
7409 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7410 BatchedReduceMeanNode::BatchIdx,
7411 BatchedReduceMeanNode::ResultIdx,
7412 splitDims, 0));
7413 }
7414 break;
7415 }
7416 case Kinded::Kind::ConcatNodeKind: {
7417 ConcatNode *concat = llvm::cast<ConcatNode>(curNode);
7418 RETURN_ERR_IF_NOT(concat->getDim() == 0,
7419 "Expected to Data parallelize for concat on dim 0");
7420 ASSIGN_VALUE_OR_RETURN_ERR(
7421 CN, parallelizeAndReplaceConcat(F, concat, curNumOfChunks));
7422 break;
7423 }
7424 case Kinded::Kind::LayerNormalizationNodeKind: {
7425 splitDims[LayerNormalizationNode::InputIdx] = 0;
7426 ASSIGN_VALUE_OR_RETURN_ERR(
7427 CN, parallelizeAndReplaceNode(F, curNode, curNumOfChunks,
7428 LayerNormalizationNode::InputIdx,
7429 LayerNormalizationNode::ResultIdx,
7430 splitDims, 0));
7431 break;
7432 }
7433 case Kinded::Kind::QuantizeNodeKind: {
7434 splitDims[QuantizeNode::InputIdx] = 0;
7435 ASSIGN_VALUE_OR_RETURN_ERR(
7436 CN, parallelizeAndReplaceNode(
7437 F, curNode, curNumOfChunks, QuantizeNode::InputIdx,
7438 QuantizeNode::ResultIdx, splitDims, 0));
7439 break;
7440 }
7441 case Kinded::Kind::DequantizeNodeKind: {
7442 splitDims[DequantizeNode::InputIdx] = 0;
7443 ASSIGN_VALUE_OR_RETURN_ERR(
7444 CN, parallelizeAndReplaceNode(
7445 F, curNode, curNumOfChunks, DequantizeNode::InputIdx,
7446 DequantizeNode::ResultIdx, splitDims, 0));
7447 break;
7448 }
7449 case Kinded::Kind::RescaleQuantizedNodeKind: {
7450 splitDims[RescaleQuantizedNode::InputIdx] = 0;
7451 ASSIGN_VALUE_OR_RETURN_ERR(
7452 CN, parallelizeAndReplaceNode(
7453 F, curNode, curNumOfChunks, RescaleQuantizedNode::InputIdx,
7454 RescaleQuantizedNode::ResultIdx, splitDims, 0));
7455 break;
7456 }
7457 case Kinded::Kind::ConvertToNodeKind: {
7458 splitDims[ConvertToNode::InputIdx] = 0;
7459 ASSIGN_VALUE_OR_RETURN_ERR(
7460 CN, parallelizeAndReplaceNode(
7461 F, curNode, curNumOfChunks, ConvertToNode::InputIdx,
7462 ConvertToNode::ResultIdx, splitDims, 0));
7463 break;
7464 }
7465 default:
7466 VLOG(1) << "Attempted to parallelize op type " << curNode->getKindName()
7467 << "not yet supported"
7468 << "\n";
7469 break;
7470 }
7471 break;
7472 }
7473
7474 case ParallelTransformKind::Model: {
7475 switch (curNode->getKind()) {
7476 case Kinded::Kind::FullyConnectedNodeKind: {
7477 if (modelParAxis != 1) {
7478 break;
7479 }
7480 splitDims[FullyConnectedNode::WeightsIdx] = 1;
7481 splitDims[FullyConnectedNode::BiasIdx] = 0;
7482 ASSIGN_VALUE_OR_RETURN_ERR(
7483 CN, parallelizeAndReplaceNode(
7484 F, curNode, curNumOfChunks, FullyConnectedNode::WeightsIdx,
7485 FullyConnectedNode::ResultIdx, splitDims, /*resultDim*/ 1,
7486 modelParallelSplitAlignment));
7487 break;
7488 }
7489 case Kinded::Kind::MatMulNodeKind: {
7490 if (modelParAxis != 1) {
7491 break;
7492 }
7493 splitDims[MatMulNode::RHSIdx] = 1;
7494 ASSIGN_VALUE_OR_RETURN_ERR(
7495 CN, parallelizeAndReplaceNode(
7496 F, curNode, curNumOfChunks, MatMulNode::RHSIdx,
7497 MatMulNode::ResultIdx, splitDims, /*resultDim*/ 1,
7498 modelParallelSplitAlignment));
7499 break;
7500 }
7501 case Kinded::Kind::ReluNodeKind: {
7502 SPLIT_ELTWISE_UNARY_OP_HELPER(Relu, modelParAxis);
7503 break;
7504 }
7505 case Kinded::Kind::ClipNodeKind: {
7506 SPLIT_ELTWISE_UNARY_OP_HELPER(Clip, modelParAxis);
7507 break;
7508 }
7509 case Kinded::Kind::SelectNodeKind: {
7510 if (modelParAxis != 1) {
7511 break;
7512 }
7513 auto *SL = llvm::cast<SelectNode>(curNode);
7514 if (SL->getNthInput(SelectNode::LHSIdx).dims().size() < 2) {
7515 break;
7516 }
7517 splitDims[SelectNode::LHSIdx] = 1;
7518 splitDims[SelectNode::RHSIdx] = 1;
7519 splitDims[SelectNode::CondIdx] = 1;
7520 ASSIGN_VALUE_OR_RETURN_ERR(
7521 CN, parallelizeAndReplaceNode(
7522 F, curNode, curNumOfChunks, SelectNode::LHSIdx,
7523 SelectNode::ResultIdx, splitDims,
7524 /*resultDim*/ 1, modelParallelSplitAlignment));
7525 break;
7526 }
7527 case Kinded::Kind::AddNodeKind: {
7528 SPLIT_ELTWISE_BINARY_OP_HELPER(Add, modelParAxis);
7529 break;
7530 }
7531 case Kinded::Kind::TileNodeKind: {
7532 if (modelParAxis != 1) {
7533 break;
7534 }
7535 if (curNode->getNthInput(TileNode::InputIdx).dims().size() < 2) {
7536 break;
7537 }
7538 TileNode *TN = llvm::dyn_cast<TileNode>(curNode);
7539 RETURN_ERR_IF_NOT(
7540 TN->getAxis() != 1,
7541 "Tile node cannot be split on axis 1 which is being replicated");
7542 splitDims[TileNode::InputIdx] = 1;
7543 ASSIGN_VALUE_OR_RETURN_ERR(
7544 CN, parallelizeAndReplaceNode(
7545 F, curNode, curNumOfChunks, TileNode::InputIdx,
7546 TileNode::ResultIdx, splitDims,
7547 /*resultDim*/ 1, modelParallelSplitAlignment));
7548 break;
7549 }
7550 case Kinded::Kind::ConcatNodeKind: {
7551 if (modelParAxis != 1) {
7552 break;
7553 }
7554 ConcatNode *concat = llvm::cast<ConcatNode>(curNode);
7555 RETURN_ERR_IF_NOT(concat->getDim() == 1,
7556 "Expected to Model parallelize for concat on dim 1");
7557 ASSIGN_VALUE_OR_RETURN_ERR(
7558 CN, parallelizeAndReplaceConcat(F, concat, curNumOfChunks));
7559 break;
7560 }
7561 case Kinded::Kind::QuantizeNodeKind: {
7562 SPLIT_ELTWISE_UNARY_OP_HELPER(Quantize, modelParAxis);
7563 break;
7564 }
7565 case Kinded::Kind::DequantizeNodeKind: {
7566 SPLIT_ELTWISE_UNARY_OP_HELPER(Dequantize, modelParAxis);
7567 break;
7568 }
7569 case Kinded::Kind::BatchedReduceMeanNodeKind: {
7570 if (curNode->getNthInput(BatchedReduceMeanNode::BatchIdx)
7571 .dims()
7572 .size() <= modelParAxis) {
7573 break;
7574 }
7575 splitDims[BatchedReduceMeanNode::BatchIdx] = modelParAxis;
7576 ASSIGN_VALUE_OR_RETURN_ERR(
7577 CN, parallelizeAndReplaceNode(
7578 F, curNode, curNumOfChunks, BatchedReduceMeanNode::BatchIdx,
7579 BatchedReduceMeanNode::ResultIdx, splitDims,
7580 /*resultDim*/ modelParAxis, modelParallelSplitAlignment));
7581 break;
7582 }
7583 case Kinded::Kind::RescaleQuantizedNodeKind: {
7584 SPLIT_ELTWISE_UNARY_OP_HELPER(RescaleQuantized, modelParAxis);
7585 break;
7586 }
7587 case Kinded::Kind::BatchNormalizationNodeKind: {
7588 auto *BN = llvm::cast<BatchNormalizationNode>(curNode);
7589
7590 if (modelParAxis != BN->getChannelIdx()) {
7591 break;
7592 }
7593 if (BN->getInput().dims().size() <= modelParAxis) {
7594 break;
7595 }
7596 splitDims[BatchNormalizationNode::InputIdx] = modelParAxis;
7597 splitDims[BatchNormalizationNode::ScaleIdx] = 0;
7598 splitDims[BatchNormalizationNode::BiasIdx] = 0;
7599 splitDims[BatchNormalizationNode::MeanIdx] = 0;
7600 splitDims[BatchNormalizationNode::VarIdx] = 0;
7601
7602 ASSIGN_VALUE_OR_RETURN_ERR(
7603 CN,
7604 parallelizeAndReplaceNode(
7605 F, curNode, curNumOfChunks, BatchNormalizationNode::InputIdx,
7606 BatchNormalizationNode::ResultIdx, splitDims,
7607 /*resultDim*/ modelParAxis, modelParallelSplitAlignment));
7608 break;
7609 }
7610 case Kinded::Kind::ResizeNearestNodeKind: {
7611 SPLIT_ELTWISE_UNARY_OP_HELPER(ResizeNearest, modelParAxis);
7612 break;
7613 }
7614 case Kinded::Kind::ConvolutionNodeKind: {
7615 if (modelParAxis != 3) {
7616 break;
7617 }
7618 splitDims[ConvolutionNode::FilterIdx] = 0;
7619 splitDims[ConvolutionNode::BiasIdx] = 0;
7620 ASSIGN_VALUE_OR_RETURN_ERR(
7621 CN, parallelizeAndReplaceNode(
7622 F, curNode, curNumOfChunks, ConvolutionNode::FilterIdx,
7623 ConvolutionNode::ResultIdx, splitDims,
7624 /*resultDim*/ 3, modelParallelSplitAlignment));
7625 break;
7626 }
7627 case Kinded::Kind::Convolution3DNodeKind: {
7628 if (modelParAxis != 4) {
7629 break;
7630 }
7631 splitDims[Convolution3DNode::FilterIdx] = 0;
7632 splitDims[Convolution3DNode::BiasIdx] = 0;
7633 ASSIGN_VALUE_OR_RETURN_ERR(
7634 CN, parallelizeAndReplaceNode(
7635 F, curNode, curNumOfChunks, Convolution3DNode::FilterIdx,
7636 Convolution3DNode::ResultIdx, splitDims,
7637 /*resultDim*/ 4, modelParallelSplitAlignment));
7638 break;
7639 }
7640 case Kinded::Kind::AvgPoolNodeKind: {
7641 auto *APN = llvm::cast<AvgPoolNode>(curNode);
7642 if (APN->getLayout() != 2) {
7643 break;
7644 }
7645 if (modelParAxis != 4) {
7646 break;
7647 }
7648 SPLIT_ELTWISE_UNARY_OP_HELPER(AvgPool, 4);
7649 break;
7650 }
7651 default:
7652 VLOG(1) << "Attempted to parallelize op type " << curNode->getKindName()
7653 << "not yet supported"
7654 << "\n";
7655 break;
7656 }
7657 break;
7658 }
7659
7660 default:
7661 break;
7662 }
7663
7664 if (CN) {
7665 replacedMap[curNode] = CN;
7666 }
7667 }
7668
7669 // Because we transformed Node types unsafely, make sure all types of the
7670 // Function still are valid.
7671 RETURN_ERR_IF_NOT(F->verify(), "Verification issue post parallelization");
7672
7673 RETURN_ERR_IF_NOT(numProcessedNodes == parOpts.size(),
7674 "Not all Nodes specified in parOpts were processed.");
7675
7676 return replacedMap;
7677}
7678
7679void glow::updateQuantReluTypes(Function *F) {
7680 // A worklist that contains the nodes to process.
7681 std::vector<Node *> worklist;
7682 auto needsQuantTyUpdate = [](const Node *N) {
7683 return isa<ConcatNode>(N) || isa<SliceNode>(N) || isa<ReshapeNode>(N) ||
7684 isa<TileNode>(N) || isa<BroadcastNode>(N) || isa<TransposeNode>(N);
7685 };
7686
7687 for (Node &N : F->getNodes()) {
7688 // Look for quantized Relus that have negative min, and update their min to
7689 // be zero.
7690 auto *RN = llvm::dyn_cast<ReluNode>(&N);
7691 if (!RN || !RN->getResult().getType()->isQuantizedType() ||
7692 isUsedByNodeWithSideEffects(RN)) {
7693 continue;
7694 }
7695 const TypeRef RNTy = RN->getResult().getType();
7696
7697 const auto qRange = RNTy->getQuantizedValueRange();
7698 if (qRange.first >= 0) {
7699 continue;
7700 }
7701 const auto qParams = quantization::chooseQuantizationParams(
7702 {0, qRange.second}, quantization::Asymmetric, RNTy->getElementType());
7703 const TypeRef qReluTy = F->getParent()->uniqueType(
7704 RNTy->getElementType(), RNTy->dims(), qParams.scale, qParams.offset);
7705 RN->setType(ReluNode::ResultIdx, qReluTy);
7706
7707 // Now look for any users of the Relu which set their type based directly on
7708 // the type of the Relu, and update them as well. These tend to be shape
7709 // changes such as Concat, Slice, Reshape, Tile, Broadcast, Transpose, etc.
7710 for (auto &user : RN->getUsers()) {
7711 auto *U = user.getUser();
7712 if (needsQuantTyUpdate(U)) {
7713 worklist.push_back(U);
7714 }
7715 }
7716 }
7717
7718 // Now we need to update all nodes following the relus which directly took
7719 // their output types from the relu.
7720 while (!worklist.empty()) {
7721 Node *N = worklist.back();
7722 assert(needsQuantTyUpdate(N) && "Unsupported node for quant update.");
7723 worklist.pop_back();
7724
7725 // Look for other users that also need updates and add to worklist.
7726 for (auto &user : N->getUsers()) {
7727 auto *U = user.getUser();
7728 if (needsQuantTyUpdate(U)) {
7729 worklist.push_back(U);
7730 }
7731 }
7732
7733 // Note: We can unconditionally get the 0th result because all nodes we
7734 // currently support to update have a single output.
7735 assert(N->getNumResults() == 1 && "Unsupported multi-output Node");
7736 constexpr unsigned resultIdx = 0;
7737 const TypeRef T = N->getNthResult(resultIdx).getType();
7738
7739 // We must still be in the quantized domain, because we are only following
7740 // chains starting from quantized relus down through shape nodes.
7741 assert(T->isQuantizedType());
7742
7743 // This likely represents an issue because it means e.g. a Reshape will
7744 // change the scale/bias, but continue for now and assume a verifier will
7745 // catch the issue if it is one.
7746 if (isUsedByNodeWithSideEffects(N)) {
7747 continue;
7748 }
7749
7750 // Update the output type just like we did for the original relu.
7751 const auto qRange = T->getQuantizedValueRange();
7752 if (qRange.first >= 0) {
7753 continue;
7754 }
7755 const auto qParams = quantization::chooseQuantizationParams(
7756 {0, qRange.second}, quantization::Asymmetric, T->getElementType());
7757 const TypeRef qReluTy = F->getParent()->uniqueType(
7758 T->getElementType(), T->dims(), qParams.scale, qParams.offset);
7759 N->setType(resultIdx, qReluTy);
7760 }
7761}
7762