1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "BackendTestUtils.h"
17
18#include "glow/Base/Type.h"
19#include "glow/Graph/Graph.h"
20#include "glow/Graph/Node.h"
21#include "glow/Graph/Nodes.h"
22#include "glow/Graph/PlaceholderBindings.h"
23#include "glow/IR/IR.h"
24#include "glow/Optimizer/GraphOptimizer/FunctionPassPipeline.h"
25#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
26#include "glow/Optimizer/Lower/Lower.h"
27
28#include "gtest/gtest.h"
29
30using namespace glow;
31
32class GraphFold : public GraphOptz {};
33
/// A helper predicate to check if the provided node has the same address as a
/// pre-defined address provided in the constructor. This is useful if you need
/// to check that a given node is still in the graph. In general, it is not
/// safe to use std::find(begin_it, end_it, value) and compare the nodes by
/// value, because the node provided as the last parameter of std::find (i.e.
/// the value reference) may have been removed by some optimization and can no
/// longer be dereferenced. Comparing the addresses of the nodes, however, is
/// safe. Thus, one can use the following form instead:
/// std::find_if(begin_it, end_it, IsSameNodeAddress(node_address))
43struct IsSameNodeAddress {
44 const Node *nodeAddress_;
45 IsSameNodeAddress(const Node *nodeAddress) : nodeAddress_(nodeAddress) {}
46 bool operator()(const Node &n) const { return &n == nodeAddress_; }
47};
48
49/// \returns true if the Function \p F contains the Node \p N.
50static bool functionContainsNode(const Function *F, const Node *N) {
51 return std::find_if(F->getNodes().begin(), F->getNodes().end(),
52 IsSameNodeAddress(N)) != F->getNodes().end();
53}
54
/// Optimize the function \p F with \p cctx. \returns the optimized function.
/// If \p configs is empty then the whole default optimization pipeline is run.
/// Otherwise only a pipeline built from \p configs is used.
58static Function *
59optimizeFunctionForTest(Function *F,
60 std::initializer_list<FunctionPassConfig> configs = {},
61 const CompilationContext &cctx = CompilationContext()) {
62 auto *G = F->clone(F->getName().str() + "_optimized");
63 if (configs.size() == 0) {
64 ::glow::optimize(G, cctx);
65 return G;
66 }
67 FunctionPassManager FPM("TestFPM", configs);
68 FPM.run(G, cctx);
69 return G;
70}
71
/// \returns the first node in a function which has the specified name.
73template <typename NodeT = Node>
74static const NodeT *findFunctionNodeByName(const Function *F,
75 const llvm::StringRef name) {
76 return llvm::dyn_cast<NodeT>(
77 std::find_if(F->getNodes().begin(), F->getNodes().end(),
78 [=](auto &N) { return N.getName() == name; }));
79}
80
81TEST_F(GraphOptz, OptimizeClipFunnel) {
82 auto *A =
83 mod_.createPlaceholder(ElemKind::FloatTy, {100, 16}, "input", false);
84 Node *K = A;
85 float min = 0.0;
86 float max = 1000.0;
87 for (int i = 0; i < 10; ++i) {
88 min += 1.0;
89 max -= 1.0;
90 K = F_->createClip("clip", K, min, max);
91 }
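  // After the loop min == 10 and max == 990; the chain of ten Clips should
  // collapse into a single Clip with exactly that range.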
92 F_->createSave("ret", K);
93
94 EXPECT_EQ(F_->getNodes().size(), 11);
95
96 optimizedF_ = optimizeFunctionForTest(F_);
97 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
98
99 // Find clip node in the optimized graph.
100 Node *newClip = A;
101 for (auto &N : optimizedF_->getNodes()) {
102 if (N.getKind() == Kinded::Kind::ClipNodeKind) {
103 newClip = llvm::dyn_cast<ClipNode>(&N);
104 }
105 }
106 EXPECT_TRUE(llvm::isa<ClipNode>(newClip));
107 ClipNode *c = llvm::dyn_cast<ClipNode>(newClip);
108 EXPECT_EQ(min, c->getMin());
109 EXPECT_EQ(max, c->getMax());
110
111 bindings_.allocate(mod_.getPlaceholders());
112 bindings_.get(A)->getHandle().randomize(-1000, 1000, mod_.getPRNG());
113 bindings_.get(A)->getHandle().raw(0) = -1000;
114 checkNumericalEquivalence();
115}
116
117TEST_F(GraphOptz, DCE) {
118 Node *K = mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input",
119 false);
120
121 for (int i = 0; i < 40; i++) {
122 K = F_->createRELU("relu", K);
123 // Add a graph structure that diverges and converges, to catch algorithms
    // that perform a dumb recursive scan.
125 K = F_->createAdd("arith", K, K);
126 }
127
128 // Check that we know how many nodes we've created.
129 EXPECT_EQ(F_->getNodes().size(), 80);
130
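  // Nothing in the chain is ever saved, so all of these nodes are dead.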
131 // Optimize all of the dead code.
132 ::glow::optimize(F_, CompilationMode::Infer);
133
134 // All of the nodes are gone.
135 EXPECT_EQ(F_->getNodes().size(), 0);
136 EXPECT_EQ(mod_.getConstants().size(), 0);
137}
138
139/// Check that predicated instructions are DCE'ed like
140/// regular instructions.
141TEST_F(GraphOptz, DCEwithPredicate) {
142 Node *K = mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input",
143 false);
144 Node *predicatedBatch =
145 mod_.createPlaceholder(ElemKind::FloatTy, {4}, "predicate", true);
146 for (int i = 0; i < 40; i++) {
147 K = F_->createRELU("relu", K);
148 K->setPredicate(predicatedBatch);
149 // Add a graph structure that diverges and converges, to catch algorithms
    // that perform a dumb recursive scan.
151 K = F_->createAdd("arith", K, K);
152 K->setPredicate(predicatedBatch);
153 }
154
155 // Check that we know how many nodes we've created.
156 EXPECT_EQ(F_->getNodes().size(), 80);
157
158 // Optimize all of the dead code.
159 ::glow::optimize(F_, CompilationMode::Infer);
160
161 // All of the nodes are gone.
162 EXPECT_EQ(F_->getNodes().size(), 0);
163 EXPECT_EQ(mod_.getConstants().size(), 0);
164}
165
166TEST_F(GraphOptz, liveCodeNotEliminated) {
167 Node *K = mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input",
168 false);
169 auto *Ex = mod_.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "Ex", false);
170
171 for (int i = 0; i < 40; i++) {
172 K = F_->createRELU("relu", K);
173 K = F_->createAdd("arith", K, K);
174 }
175 K = F_->createSoftMax("Regression", K, Ex);
176 F_->createSave("ret", K);
177
178 // Check that we know how many nodes we've created.
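  // 40 iterations x (Relu + Add) + SoftMax + Save = 82 nodes.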
179 EXPECT_EQ(F_->getNodes().size(), 82);
180
181 // This should not optimize code because none is dead.
182 ::glow::optimize(F_, CompilationMode::Infer);
183
184 // Nothing got optimized.
185 EXPECT_EQ(F_->getNodes().size(), 82);
186 EXPECT_EQ(mod_.getPlaceholders().size(), 3);
187}
188
189/// Skip Reshape sinking below BatchNorm when inapplicable.
190TEST_F(GraphOptz, SkipReshapeSinkBatchNorm) {
191 auto *A = mod_.createPlaceholder(ElemKind::FloatTy, {32, 64}, "A", false);
192 Node *RS = F_->createReshape("reshape", A, {32, 64, 1});
193 Node *BN =
194 F_->createBatchNormalization(bindings_, "batch", RS, 1, 0.0001, 0.9);
195 F_->createSave("ret", BN);
196
197 optimizedF_ = optimizeFunctionForTest(F_);
198 EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false, /* skipName */ true),
199 optimizedF_->toString(/* skipUsersForStorage */ false,
200 /* skipName */ true));
201}
202
203// Conv->Reshape->BatchNorm is optimized to Conv->Reshape after sinking Reshape
204// below BatchNorm. Reshape transforms [N][H][W][C] to [N][W][H][C].
205TEST_F(GraphOptz, optimizeBatchNormAfterConvAndReshapeNHWC) {
206 auto *A =
207 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
208 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
209 Node *RS = F_->createReshape("reshape", CV, {1, 20, 10, 16});
210 Node *BN =
211 F_->createBatchNormalization(bindings_, "batch", RS, 3, 0.0001, 0.9);
212 F_->createSave("ret", BN);
213
214 EXPECT_EQ(F_->getNodes().size(), 4);
215 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
216 optimizedF_ = optimizeFunctionForTest(F_);
217 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
218
219 ASSERT_EQ(A->getNumUsers(), 2);
220 Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
221 [CV](auto &it) { return it.getUser() == CV; })
222 ->getUser();
223 EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
224 ASSERT_EQ(newCV->getNumUsers(), 1);
225 Node *reshape = newCV->getUsers().begin()->getUser();
226 EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));
227
228 bindings_.allocate(mod_.getPlaceholders());
229 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
230 checkNumericalEquivalence();
231}
232
233// Conv->Reshape->BatchNorm is optimized to Conv->Reshape after sinking Reshape
234// below BatchNorm. Reshape flattens [N][H][W][C] to [N][HxW][C].
235TEST_F(GraphOptz, optimizeBatchNormAfterConvAndReshapeNHWC2) {
236 auto *A =
237 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
238 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
239 Node *RS = F_->createReshape("reshape", CV, {1, 200, 16});
240 Node *BN =
241 F_->createBatchNormalization(bindings_, "batch", RS, 2, 0.0001, 0.9);
242 F_->createSave("ret", BN);
243
244 EXPECT_EQ(F_->getNodes().size(), 4);
245 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
246 optimizedF_ = optimizeFunctionForTest(F_);
247 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
248
249 ASSERT_EQ(A->getNumUsers(), 2);
250 Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
251 [CV](auto &it) { return it.getUser() == CV; })
252 ->getUser();
253 EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
254 ASSERT_EQ(newCV->getNumUsers(), 1);
255 Node *reshape = newCV->getUsers().begin()->getUser();
256 EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));
257
258 bindings_.allocate(mod_.getPlaceholders());
259 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
260 checkNumericalEquivalence();
261}
262
// BatchNorm is not folded into the Conv because the Reshape changes the
// channel dimension, which prevents the optimization. The Reshape transforms
// [N][H][W][C] to [N][H][W/2][C*2].
266TEST_F(GraphOptz, optimizeBatchNormAfterConvAndReshapeNHWCneg) {
267 auto *A =
268 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
269 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
270 Node *RS = F_->createReshape("reshape", CV, {1, 10, 10, 32});
271 Node *BN =
272 F_->createBatchNormalization(bindings_, "batch", RS, 3, 0.0001, 0.9);
273 F_->createSave("ret", BN);
274
275 EXPECT_EQ(F_->getNodes().size(), 4);
276 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
277 optimizedF_ = optimizeFunctionForTest(F_);
278 EXPECT_EQ(optimizedF_->getNodes().size(), 4);
279
280 ASSERT_EQ(A->getNumUsers(), 2);
281 Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
282 [CV](auto &it) { return it.getUser() == CV; })
283 ->getUser();
284 EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
285 ASSERT_EQ(newCV->getNumUsers(), 1);
286 Node *reshape = newCV->getUsers().begin()->getUser();
287 EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));
288 Node *bn = reshape->getUsers().begin()->getUser();
289 EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(bn));
290
291 bindings_.allocate(mod_.getPlaceholders());
292 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
293 checkNumericalEquivalence();
294}
295
296// Sink Reshape below BatchNorm: multi-user testcase.
297TEST_F(GraphOptz, sinkReshapeBelowBatchNormMultiUser) {
298 auto *in =
299 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 40, 8}, "input", false);
300 auto *RS = F_->createReshape("reshape", in, {1, 20, 20, 8});
301 auto *BN =
302 F_->createBatchNormalization(bindings_, "batch", RS, 3, 0.0001, 0.9);
303 auto *save = F_->createSave("ret", BN);
304 F_->createSave("extra_user", RS);
305
306 optimizedF_ = optimizeFunctionForTest(F_);
307 EXPECT_EQ(F_->getNodes().size(), 4);
308 EXPECT_EQ(optimizedF_->getNodes().size(), 5);
309
310 auto *saveOpt =
311 findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
312 auto *reshapeOpt = llvm::dyn_cast<ReshapeNode>(saveOpt->getInput());
313 ASSERT_TRUE(reshapeOpt);
314 ASSERT_TRUE(llvm::isa<BatchNormalizationNode>(reshapeOpt->getInput()));
315
316 bindings_.allocate(mod_.getPlaceholders());
317 bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
318 checkNumericalEquivalence();
319}
320
321// Sink Reshape below BatchNorm: quantized testcase.
322TEST_F(GraphOptz, sinkReshapeBelowBatchNormQuantized) {
323 auto *in = mod_.createPlaceholder(ElemKind::Int8QTy, {1, 10, 40, 3}, 1.5f, 0,
324 "input", false);
325 auto *params =
326 mod_.createPlaceholder(ElemKind::Int8QTy, {3}, 1.0f, 0, "params", false);
327 auto *RS = F_->createReshape("reshape", in, {1, 20, 20, 3});
328 auto *bnOutTy = mod_.uniqueType(ElemKind::Int8QTy, {1, 20, 20, 3}, 2.7f, 0);
329 auto *BN = F_->createBatchNormalization("batch", bnOutTy, RS, params, params,
330 params, params, 3);
331 auto *save = F_->createSave("ret", BN);
332
333 optimizedF_ = optimizeFunctionForTest(F_);
334 EXPECT_EQ(F_->getNodes().size(), 3);
335 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
336
337 auto *saveOpt =
338 findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
339 auto *reshapeOpt = llvm::dyn_cast<ReshapeNode>(saveOpt->getInput());
340 ASSERT_TRUE(reshapeOpt);
341 auto *bnOpt = llvm::dyn_cast<BatchNormalizationNode>(reshapeOpt->getInput());
342 ASSERT_TRUE(bnOpt);
343 EXPECT_TRUE(
344 BN->getInput().getType()->isEqual(*bnOpt->getInput().getType(), true));
345 EXPECT_TRUE(
346 BN->getResult().getType()->isEqual(*bnOpt->getResult().getType(), true));
347}
348
349// Conv->Reshape->BatchNorm. Sink Reshape below BatchNorm. Check that BatchNorm
// does not fold into the Conv.
351TEST_F(GraphOptz, sinkReshapeBelowBatchNormAndDoNotFuseConvBatchNorm) {
  // Skip this test for now since Glow doesn't fully support
  // Convolution with the NCHW layout.
354 GTEST_SKIP();
355
356 auto *A =
357 mod_.createPlaceholder(ElemKind::FloatTy, {1, 3, 10, 20}, "A", false);
358 Node *CV = F_->createConv(bindings_, "conv", A, /* outChannels */ 16,
359 /* kernel */ 5, /* stride */ 1, /* pad */ 2,
360 /* group */ 1, /* dilation */ {1, 1},
361 /* layout */ ConvolutionLayout::NCHW);
362 Node *RS = F_->createReshape("reshape", CV, {1, 10, 16, 20});
363 Node *BN =
364 F_->createBatchNormalization(bindings_, "batch", RS, 1, 0.0001, 0.9);
365 F_->createSave("ret", BN);
366
367 EXPECT_EQ(F_->getNodes().size(), 4);
368 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
369 optimizedF_ = optimizeFunctionForTest(F_);
370 EXPECT_EQ(optimizedF_->getNodes().size(), 4);
371
372 ASSERT_EQ(A->getNumUsers(), 2);
373 Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
374 [CV](auto &it) { return it.getUser() == CV; })
375 ->getUser();
376
377 EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
378 ASSERT_EQ(newCV->getNumUsers(), 1);
379 Node *bn = newCV->getUsers().begin()->getUser();
380 EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(bn));
381 Node *reshape = bn->getUsers().begin()->getUser();
382 EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));
383
384 bindings_.allocate(mod_.getPlaceholders());
385 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
386 checkNumericalEquivalence();
387}
388
389TEST_F(GraphOptz, optimizeBatchNormAfterConv) {
390 auto *A =
391 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
392 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
393 Node *BN =
394 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
395 F_->createSave("ret", BN);
396
397 EXPECT_EQ(F_->getNodes().size(), 3);
398 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
399 optimizedF_ = optimizeFunctionForTest(F_);
400 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
401
402 ASSERT_EQ(A->getNumUsers(), 2);
403 Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
404 [CV](auto &it) { return it.getUser() == CV; })
405 ->getUser();
406 EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
407 ASSERT_EQ(newCV->getNumUsers(), 1);
408 Node *save = newCV->getUsers().begin()->getUser();
409 EXPECT_TRUE(llvm::isa<SaveNode>(save));
410
411 bindings_.allocate(mod_.getPlaceholders());
412 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
413 checkNumericalEquivalence();
414}
415
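/// Builds BatchNorm(A) -> LRN -> Save using the supplied \p varV, \p meanV,
/// \p gammaV, \p betaV and \p eps, optimizes a clone of \p F_, and expects the
/// BatchNorm to be optimized away (the callers pass parameters that make it an
/// identity).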
416void optimizeRedundantBatchNormTest(
417 glow::Module &mod_, glow::Function *F_, glow::Function *&optimizedF_,
418 glow::PlaceholderBindings &bindings_, llvm::ArrayRef<float> varV,
419 llvm::ArrayRef<float> meanV, llvm::ArrayRef<float> gammaV,
420 llvm::ArrayRef<float> betaV, const float eps) {
421 auto *A =
422 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
423
424 auto *var = mod_.createConstant(ElemKind::FloatTy, {3}, "var");
425 auto *mean = mod_.createConstant(ElemKind::FloatTy, {3}, "mean");
426 auto *beta = mod_.createConstant(ElemKind::FloatTy, {3}, "beta");
427 auto *gamma = mod_.createConstant(ElemKind::FloatTy, {3}, "gamma");
428
429 // (X - mean) * (1.0 / sqrt(var + eps)) * gamma + beta
430 var->getPayloadMutable().getHandle<float>() = varV;
431 mean->getPayloadMutable().getHandle<float>() = meanV;
432 beta->getPayloadMutable().getHandle<float>() = betaV;
433 gamma->getPayloadMutable().getHandle<float>() = gammaV;
434 Node *BN = F_->createBatchNormalization("batch", A->getType(), A, beta, gamma,
435 mean, var, 3, eps);
436 Node *LRN = F_->createLocalResponseNormalization("LRN", BN);
437 F_->createSave("ret", LRN);
438
439 EXPECT_EQ(F_->getNodes().size(), 3);
440 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
441 optimizedF_ = optimizeFunctionForTest(F_);
442 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
443
444 ASSERT_EQ(A->getNumUsers(), 2);
445 Node *LRN1 = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
446 [BN](auto &it) { return it.getUser() == BN; })
447 ->getUser();
448 ASSERT_TRUE(llvm::isa<LocalResponseNormalizationNode>(LRN1));
449 ASSERT_EQ(LRN1->getNumUsers(), 1);
450 Node *save = LRN1->getUsers().begin()->getUser();
451 EXPECT_TRUE(llvm::isa<SaveNode>(save));
452
453 bindings_.allocate(mod_.getPlaceholders());
454 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
455}
456
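/// var == 1, mean == 0, gamma == 1, beta == 0:
/// (X - 0) * (1 / sqrt(1 + 0)) * 1 + 0 == X, so the BatchNorm is an identity.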
457TEST_F(GraphOptz, optimizeRedundantBatchNorm1) {
458 optimizeRedundantBatchNormTest(mod_, F_, optimizedF_, bindings_, {1., 1., 1.},
459 {0., 0., 0.}, {1., 1., 1.}, {0., 0., 0.}, 0.0);
460 checkNumericalEquivalence();
461}
462
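/// gamma / sqrt(var + eps) == 1 and beta == mean == 33:
/// (X - 33) * 1 + 33 == X, so the BatchNorm is an identity.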
463TEST_F(GraphOptz, optimizeRedundantBatchNorm2) {
464 optimizeRedundantBatchNormTest(mod_, F_, optimizedF_, bindings_, {1., 1., 1.},
465 {33., 33., 33.}, {1., 1., 1.}, {33., 33., 33.},
466 0.0);
467 checkNumericalEquivalence();
468}
469
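/// var + eps == 1 (up to rounding), so the scale is 1 and, with beta == mean,
/// the BatchNorm is again an identity.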
470TEST_F(GraphOptz, optimizeRedundantBatchNorm3) {
471 const float eps = 0.000001;
472 optimizeRedundantBatchNormTest(
473 mod_, F_, optimizedF_, bindings_, {1.0f - eps, 1.0f - eps, 1.0f - eps},
474 {33., 33., 33.}, {1., 1., 1.}, {33., 33., 33.}, eps);
475 checkNumericalEquivalence();
476}
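
/// gamma / sqrt(var) == 15 / 15 == 1 and beta == mean == -3:
/// (X + 3) * 1 - 3 == X, so the BatchNorm is an identity.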
477TEST_F(GraphOptz, optimizeRedundantBatchNorm4) {
478 optimizeRedundantBatchNormTest(mod_, F_, optimizedF_, bindings_,
479 {225., 225., 225.}, {-3., -3., -3.},
480 {15., 15., 15.}, {-3., -3., -3.}, 0.0);
481 checkNumericalEquivalence();
482}
483
484/// Verify that the Conv-BatchNorm merging optimization is not impacted by
485/// multiple users on the filter/bias.
486TEST_F(GraphOptz, optimizeBatchNormAfterConvMultiple) {
487 Placeholder *A =
488 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
489 ConvolutionNode *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
490 BatchNormalizationNode *BN =
491 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
492 F_->createSave("ret", BN);
493
494 // Adding these saves means the filter and bias have multiple uses. This
495 // should not impact the Conv-BatchNorm merging optimization.
496 F_->createSave("saveFilter", CV->getFilter());
497 F_->createSave("saveBias", CV->getBias());
498
499 // Three Saves, one Conv, and one BatchNorm.
500 EXPECT_EQ(F_->getNodes().size(), 5);
501
502 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
503
504 // Conv's Filter and Bias, plus BN's Scale, Bias, Mean, and Var.
505 EXPECT_EQ(mod_.getConstants().size(), 6);
506
507 optimizedF_ = optimizeFunctionForTest(F_);
508
509 // BatchNorm should have been merged into the Conv.
510 EXPECT_EQ(optimizedF_->getNodes().size(), 4);
511
  // Filter and Bias should have been duplicated so that the Conv-BN
  // optimization does not modify the filter/bias being saved, giving 4
  // filter/bias Constants. The BN's Scale, Bias, Mean, and Var are no longer
  // used by the optimized function, but they stay in the module because the
  // original function still references them, for 8 Constants in total.
516 EXPECT_EQ(mod_.getConstants().size(), 8);
517
518 ASSERT_EQ(A->getNumUsers(), 2);
519 Node *newCV = A->getUsers().back().getUser();
520 EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
521 ASSERT_EQ(newCV->getNumUsers(), 1);
522 Node *save = newCV->getUsers().begin()->getUser();
523 EXPECT_TRUE(llvm::isa<SaveNode>(save));
524
525 EXPECT_EQ(
526 countNodeKind(optimizedF_, Kinded::Kind::BatchNormalizationNodeKind), 0);
527
528 bindings_.allocate(mod_.getPlaceholders());
529 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
530 checkNumericalEquivalence();
531}
532
533TEST_F(GraphOptz, optimizeBatchNormAfterConvFP16) {
534 auto *A =
535 mod_.createPlaceholder(ElemKind::Float16Ty, {1, 10, 20, 3}, "A", false);
536 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
537 Node *BN =
538 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
539 F_->createSave("ret", BN);
540
541 EXPECT_EQ(F_->getNodes().size(), 3);
542
543 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
544 optimizedF_ = optimizeFunctionForTest(F_);
545
546 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
547
548 ASSERT_EQ(A->getNumUsers(), 2);
549
550 bool optimizedPathExists{false};
551 for (const auto &path : A->getUsers()) {
552 auto cv = path.getUser();
553 EXPECT_TRUE(llvm::isa<ConvolutionNode>(cv));
554 ASSERT_EQ(cv->getNumUsers(), 1);
555 auto next = cv->getUsers().begin()->getUser();
556 optimizedPathExists |= llvm::isa<SaveNode>(next);
557 }
558
559 EXPECT_TRUE(optimizedPathExists);
560
561 bindings_.allocate(A)->getHandle<float16_t>().randomize(-1.0, 1.0,
562 mod_.getPRNG());
563
564 checkNumericalEquivalence();
565}
566
/// Check that transpose constant folding is done before the BatchNorm
/// optimization, which allows merging the BatchNorm into a Convolution with
/// transposed weights.
569TEST_F(GraphOptz, optimizeBatchNormAfterConvWithTransposedWeights) {
570 auto *input =
571 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "input", false);
572 auto *filter =
573 mod_.createPlaceholder(ElemKind::FloatTy, {16, 3, 5, 5}, "filter", false);
574 auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {16}, "bias", false);
575
576 auto *TN = F_->createTranspose("transpose", filter, NCHW2NHWC);
577 auto *CV = F_->createConv("conv", input, TN, bias,
578 mod_.uniqueType(ElemKind::FloatTy, {1, 10, 20, 16}),
579 5, 1, 2, 1);
580 auto *BN =
581 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
582 F_->createSave("ret", BN);
583
584 // Initialize to ensure that constant tensors are not optimized out.
585 bindings_.allocate(filter)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
586 bindings_.allocate(bias)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
587
588 EXPECT_EQ(F_->getNodes().size(), 4);
589 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchNormalizationNodeKind), 1);
590
591 ::glow::convertPlaceholdersToConstants(F_, bindings_, {input});
592 optimizedF_ = optimizeFunctionForTest(F_);
593
594 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
595 EXPECT_EQ(
596 countNodeKind(optimizedF_, Kinded::Kind::BatchNormalizationNodeKind), 0);
597
598 bindings_.allocate(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
599 checkNumericalEquivalence();
600}
601
/// Check that reshape constant folding is done before the BatchNorm
/// optimization. Here the Reshape is the result of the Transpose-to-Reshape
/// optimization, which allows merging the BatchNorm into a Convolution with
/// transposed weights.
605TEST_F(GraphOptz, optimizeBatchNormAfterConvWithReshapeConst) {
606 auto *input =
607 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "input", false);
608 auto *filter =
609 mod_.createPlaceholder(ElemKind::FloatTy, {5, 5, 3, 1}, "filter", false);
610 auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "bias", false);
611
612 auto *TN = F_->createTranspose("transpose", filter, HWCN2NHWC);
613 auto *CV = F_->createConv("conv", input, TN, bias,
614 mod_.uniqueType(ElemKind::FloatTy, {1, 10, 20, 1}),
615 5, 1, 2, 1);
616 auto *BN =
617 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
618 F_->createSave("ret", BN);
619
620 // Initialize to ensure that constant tensors are not optimized out.
621 bindings_.allocate(filter)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
622 bindings_.allocate(bias)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
623
624 EXPECT_EQ(F_->getNodes().size(), 4);
625 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchNormalizationNodeKind), 1);
626
627 ::glow::convertPlaceholdersToConstants(F_, bindings_, {input});
628 optimizedF_ = optimizeFunctionForTest(F_);
629
630 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
631 EXPECT_EQ(
632 countNodeKind(optimizedF_, Kinded::Kind::BatchNormalizationNodeKind), 0);
633
634 bindings_.allocate(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
635 checkNumericalEquivalence();
636}
637
638/// Check that the batch normalization optimization is
639/// not blocked by predicates and that it preserves them.
640TEST_F(GraphOptz, optimizeBatchNormAfterConvWithPred) {
641 Node *A =
642 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
643 Node *pred1 =
644 mod_.createPlaceholder(ElemKind::FloatTy, {1}, "predicate", false);
645 Node *pred2 =
646 mod_.createPlaceholder(ElemKind::FloatTy, {1}, "predicate", false);
647 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
648 CV->setPredicate(pred1);
649 Node *BN =
650 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
651 BN->setPredicate(pred2);
652 F_->createSave("ret", BN);
653
654 EXPECT_EQ(F_->getNodes().size(), 3);
655
656 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
657 ::glow::optimize(F_, CompilationMode::Infer);
658 EXPECT_EQ(F_->getNodes().size(), 2);
659
660 ASSERT_EQ(A->getNumUsers(), 1);
661 Node *newCV = A->getUsers().begin()->getUser();
662 EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
663 ASSERT_TRUE(newCV->hasPredicate());
664 EXPECT_EQ(newCV->getPredicate().getNode(), pred2);
665 ASSERT_EQ(newCV->getNumUsers(), 1);
666 Node *save = newCV->getUsers().begin()->getUser();
667 EXPECT_TRUE(llvm::isa<SaveNode>(save));
668}
669
/// Test merging a single-user arithmetic operation chain (Sub, Mul, Add, Div)
/// into a BatchNorm.
672TEST_F(GraphOptz, MergeBatchNormalizationWithArithmeticChainTest) {
673 // Inputs.
674 auto *input =
675 mod_.createPlaceholder(ElemKind::FloatTy, {3, 2, 2, 4}, "input", false);
676 auto *var = mod_.createConstant(ElemKind::FloatTy, {4}, "var");
677 auto *mean = mod_.createConstant(ElemKind::FloatTy, {4}, "mean");
678 auto *beta = mod_.createConstant(ElemKind::FloatTy, {4}, "beta");
679 auto *gamma = mod_.createConstant(ElemKind::FloatTy, {4}, "gamma");
680
681 Node *subC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "subC");
682 Node *mulC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "mulC");
683 Node *addC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "addC");
684 Node *divC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "divC");
685
686 // Fill tensors to check boundary values after the transformation.
687 std::vector<float> betaV = {1., 2., 3., 7.};
688 std::vector<float> gammaV = {4., 5., 6., 7.};
689
690 var->getPayloadMutable().getHandle<float>() = {1., 1., 1., 1.};
691 mean->getPayloadMutable().getHandle<float>() = {0., 0., 0., 0.};
692 beta->getPayloadMutable().getHandle<float>() = betaV;
693 gamma->getPayloadMutable().getHandle<float>() = gammaV;
694
  // For at least one node (Sub) make the values within a channel different, to
  // test the folding better.
697 const std::vector<float> subV = {1, 2., 3., 4.};
698 const float mulV = 4., addV = 3., divV = 2.;
699 auto subH = llvm::cast<Constant>(subC)->getHandle<float>();
700 subH = {1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.,
701 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.,
702 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.};
703
704 llvm::cast<Constant>(mulC)->getHandle<float>().clear(mulV);
705 llvm::cast<Constant>(addC)->getHandle<float>().clear(addV);
706 llvm::cast<Constant>(divC)->getHandle<float>().clear(divV);
707
708 BatchNormalizationNode *bn = F_->createBatchNormalization(
709 "batch", input->getType(), input, beta, gamma, mean, var, 3);
710
711 auto *sub = F_->createSub("sub", bn, subC);
712 auto *mul = F_->createMul("mul", sub, mulC);
713 auto *add = F_->createAdd("add", addC, mul);
714 auto *div = F_->createDiv("div", add, divC);
715 auto *res = F_->createSave("save", div);
716
717 // Compile.
718 EXPECT_EQ(F_->getNodes().size(), 6);
719 ::glow::convertPlaceholdersToConstants(F_, bindings_, {input});
720 optimizedF_ = optimizeFunctionForTest(F_);
721 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
722
723 Constant *cs, *cb;
724
725 auto *opt_res = findFunctionNodeByName<SaveNode>(optimizedF_, res->getName());
726
727 auto *newBn = llvm::dyn_cast<BatchNormalizationNode>(opt_res->getInput());
728 ASSERT_TRUE(newBn);
729
730 cs = llvm::dyn_cast<Constant>(newBn->getScale());
731 cb = llvm::dyn_cast<Constant>(newBn->getBias());
732 ASSERT_TRUE(cs);
733 ASSERT_TRUE(cb);
734 ASSERT_TRUE(cs->getType()->isFPType());
735 ASSERT_TRUE(cb->getType()->isFPType());
736
737 auto hs = cs->getHandle<float>();
738 auto hb = cb->getHandle<float>();
739
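  // With mean == 0 and var == 1 the BatchNorm computes gamma * X + beta, so
  // folding ((BN(X) - sub) * mul + add) / div gives a new per-channel scale of
  // gamma * mul / div and a new bias of ((beta - sub) * mul + add) / div.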
740 // Verify that scale and offset are computed correctly.
741 for (dim_t i = 0; i < 4; i++) {
742 const float expScale = gammaV[i] * mulV / divV;
743 const float expBias = ((betaV[i] - subV[i]) * mulV + addV) / divV;
744 EXPECT_EQ(expScale, hs.raw(i));
745 EXPECT_EQ(expBias, hb.raw(i));
746 }
747
748 bindings_.allocate(mod_.getPlaceholders());
749 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
750 checkNumericalEquivalence();
751}
752
/// Test folding a single-user arithmetic operation chain (Sub, Mul, Add, Div)
/// that follows a Conv into a BatchNorm.
755TEST_F(GraphOptz, FoldArithmeticChainAfterConvIntoBatchNorm) {
756 Node *subC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "subC");
757 Node *mulC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "mulC");
758 Node *addC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "addC");
759 Node *divC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "divC");
760
761 // Start with identity values.
762 std::vector<float> betaV = {0., 0., 0.};
763 std::vector<float> gammaV = {1., 1., 1.};
764
  // For at least one node make the values within a channel different, to test
  // the folding better (ideally all should have different values).
767 const std::vector<float> subV = {1, 2., 3.};
768 const float mulV = 4., addV = 3., divV = 2.;
769 llvm::cast<Constant>(mulC)->getHandle<float>().clear(mulV);
770 llvm::cast<Constant>(addC)->getHandle<float>().clear(addV);
771 llvm::cast<Constant>(divC)->getHandle<float>().clear(divV);
772 auto subH = llvm::cast<Constant>(subC)->getHandle<float>();
773 subH = {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3,
774 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3,
775 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3};
776
777 auto *input =
778 mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 2, 3}, "input", false);
779 auto filter =
780 mod_.createPlaceholder(ElemKind::FloatTy, {3, 2, 2, 3}, "filter", false);
781 auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {3}, "bias", false);
782 bindings_.allocate(bias)->zero();
783
784 ConvolutionNode *CV = F_->createConv(
785 "Conv", input, filter, bias,
786 mod_.uniqueType(ElemKind::FloatTy, {2, 3, 3, 3}), 2, 1, 1, 1);
787
788 auto *sub = F_->createSub("sub", CV, subC);
789 auto *mul = F_->createMul("mul", sub, mulC);
790 auto *add = F_->createAdd("add", addC, mul);
791 auto *div = F_->createDiv("div", add, divC);
792 auto *res = F_->createSave("save", div);
793
794 // Compile.
795 EXPECT_EQ(F_->getNodes().size(), 6);
796 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
797 optimizedF_ = optimizeFunctionForTest(F_);
798 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
799
800 auto *opt_res = findFunctionNodeByName<SaveNode>(optimizedF_, res->getName());
801
802 Constant *cs, *cb;
803
804 auto *bn = llvm::dyn_cast<BatchNormalizationNode>(opt_res->getInput());
805 ASSERT_TRUE(bn);
806
807 cs = llvm::dyn_cast<Constant>(bn->getScale());
808 cb = llvm::dyn_cast<Constant>(bn->getBias());
809
810 ASSERT_TRUE(cs);
811 ASSERT_TRUE(cb);
812 ASSERT_TRUE(cs->getType()->isFPType());
813 ASSERT_TRUE(cb->getType()->isFPType());
814
815 auto hs = cs->getHandle<float>();
816 auto hb = cb->getHandle<float>();
817
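  // The Conv output acts as a BatchNorm with gamma == 1 and beta == 0, so the
  // folded chain yields a per-channel scale of mul / div and a bias of
  // ((0 - sub) * mul + add) / div.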
818 // Verify that scale and offset are computed correctly.
819 for (dim_t i = 0; i < 3; i++) {
820 const float expectedScale = gammaV[i] * (mulV / divV);
821 const float expectedBias = ((betaV[i] - subV[i]) * mulV + addV) / divV;
822 EXPECT_EQ(expectedScale, hs.raw(i));
823 EXPECT_EQ(expectedBias, hb.raw(i));
824 }
825 bindings_.allocate(mod_.getPlaceholders());
826 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
827 bindings_.get(filter)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
828 bindings_.get(bias)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
829 checkNumericalEquivalence();
830}
831
/// Check that CSE will not merge two nodes that have all the same inputs but
/// different predicates.
834TEST_F(GraphOptz, cseRespectsPredicates) {
835 Placeholder *in = mod_.createPlaceholder(ElemKind::FloatTy, {5}, "in", false);
836 Placeholder *pred1 =
837 mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
838 Placeholder *pred2 =
839 mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
840
841 Node *RN1 = F_->createRELU("relu1", in);
842 RN1->setPredicate(pred1);
843 SaveNode *save1 = F_->createSave("save1", RN1);
844 save1->setPredicate(pred1);
845
846 Node *RN2 = F_->createRELU("relu2", in);
847 RN2->setPredicate(pred2);
848 SaveNode *save2 = F_->createSave("save2", RN2);
849 save2->setPredicate(pred2);
850
851 // Two RELUS and two Saves.
852 EXPECT_EQ(F_->getNodes().size(), 4);
853 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 2);
854 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2);
855
856 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
857 optimizedF_ = optimizeFunctionForTest(F_);
858
859 // Two RELUS and two Saves should still be there.
860 EXPECT_EQ(F_->getNodes().size(), 4);
861 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 2);
862 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2);
863
864 bindings_.allocate(mod_.getPlaceholders());
865 bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
866 checkNumericalEquivalence();
867}
868
869TEST_F(GraphOptz, optimizeBatchNormAfterConvButConvReused) {
870 Placeholder *A =
871 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
872 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
873 Node *BN =
874 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
875 F_->createSave("ret", BN);
876 F_->createSave("convSave", CV);
877
878 EXPECT_EQ(F_->getNodes().size(), 4);
879 optimizedF_ = optimizeFunctionForTest(F_);
880 // Make sure the structure of the graph did not change, since the convolution
881 // node is used more than once.
882 EXPECT_EQ(optimizedF_->getNodes().size(), 4);
883 auto convIt =
884 std::find_if(optimizedF_->getNodes().begin(),
885 optimizedF_->getNodes().end(), [](const Node &node) -> bool {
886 return llvm::isa<ConvolutionNode>(node);
887 });
888 ASSERT_NE(convIt, optimizedF_->getNodes().end());
889 auto batchNormIt =
890 std::find_if(optimizedF_->getNodes().begin(),
891 optimizedF_->getNodes().end(), [](const Node &node) -> bool {
892 return (llvm::isa<BatchNormalizationNode>(node));
893 });
894 ConvolutionNode *conv = llvm::dyn_cast<ConvolutionNode>(convIt);
895 BatchNormalizationNode *batchNorm =
896 llvm::dyn_cast<BatchNormalizationNode>(batchNormIt);
897
898 EXPECT_EQ(*conv, *CV);
899 EXPECT_EQ(batchNorm->getInput().getNode(), conv);
900 EXPECT_EQ(conv->getInput().getNode(), A);
901
902 bindings_.allocate(mod_.getPlaceholders());
903 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
904 checkNumericalEquivalence();
905}
906
907TEST_F(GraphOptz, optimizeBatchNormAfterConvButVarReused) {
908 auto *A =
909 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
910
911 ConvolutionNode *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
912 Node *BN =
913 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
914 auto *retSaveNode = F_->createSave("ret", BN);
915 auto *filterSaveNode = F_->createSave("filter", CV->getFilter());
916
917 EXPECT_EQ(F_->getNodes().size(), 4);
918 optimizedF_ = optimizeFunctionForTest(F_);
919 ASSERT_EQ(A->getNumUsers(), 2);
920
921 auto *optimizedF_ret =
922 findFunctionNodeByName<SaveNode>(optimizedF_, retSaveNode->getName());
923 auto *optimizedF_filterSave =
924 findFunctionNodeByName<SaveNode>(optimizedF_, filterSaveNode->getName());
925
926 // Make sure the structure of the graph did not change.
927 EXPECT_EQ(optimizedF_->getNodes().size(), 4);
928 EXPECT_TRUE(llvm::isa<Placeholder>(optimizedF_filterSave->getInput()));
929 auto *varFilter =
930 llvm::dyn_cast<Placeholder>(optimizedF_filterSave->getInput());
931 EXPECT_EQ(varFilter, CV->getFilter());
932 EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(optimizedF_ret->getInput()));
933
934 BatchNormalizationNode *batchNorm =
935 llvm::dyn_cast<BatchNormalizationNode>(optimizedF_ret->getInput());
936 ASSERT_TRUE(batchNorm);
937 auto *newCVNode =
938 llvm::dyn_cast<ConvolutionNode>(batchNorm->getInput().getNode());
939 ASSERT_TRUE(newCVNode);
940 EXPECT_EQ(newCVNode->getInput().getNode(), CV->getInput().getNode());
941 EXPECT_EQ(newCVNode->getInput().getNode(), A);
942
943 bindings_.allocate(mod_.getPlaceholders());
944 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
945 checkNumericalEquivalence();
946}
947
948TEST_F(GraphOptz, transposeConstant) {
949 auto *A =
950 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
951 bindings_.allocate(A)->getHandle().randomize(-7.0, 12.0, mod_.getPRNG());
952 Tensor transposedA;
953 bindings_.get(A)->transpose(&transposedA, {0, 3, 1, 2});
954 Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
955 SaveNode *save = F_->createSave("ret", T);
956 EXPECT_EQ(F_->getNodes().size(), 2);
957
958 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
959 ::glow::optimize(F_, CompilationMode::Infer);
960 ASSERT_EQ(F_->getNodes().size(), 1);
961 EXPECT_EQ(&*F_->getNodes().begin(), save);
962 Constant *optimizedA = llvm::dyn_cast<Constant>(save->getInput().getNode());
963 ASSERT_NE(optimizedA, nullptr);
964 // Check that A has been properly transposed.
965 EXPECT_TRUE(optimizedA->getPayload().isEqual(transposedA));
966}
967
968/// Check that the Transpose is merged with Constant in a sequence
969/// Transpose(Quantize(Constant)).
970TEST_F(GraphOptz, transposeQuantizeConstant) {
971 auto *qTy = mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 3}, 0.2, 0);
972 auto *input = F_->getParent()->createConstant(ElemKind::FloatTy,
973 {1, 10, 20, 3}, "input");
974 auto *Q = F_->createQuantize("quantize", input, qTy);
975 auto *T = F_->createTranspose("transpose", Q, NHWC2NCHW);
976 auto *S = F_->createSave("save", T);
977
978 // Skip ConstantFolding as it would have the same result as this opt.
979 CompilationContext cctx;
980 cctx.optimizationOpts.enableConstantFolding = false;
981
982 EXPECT_EQ(F_->getNodes().size(), 3);
983 ::glow::optimize(F_, cctx);
984 EXPECT_EQ(F_->getNodes().size(), 2);
985
986 // Constant and Quantize should have new shape.
987 auto *newQ = llvm::dyn_cast<QuantizeNode>(S->getInput());
988 ASSERT_TRUE(newQ);
989 EXPECT_TRUE(newQ->getResult().dims().equals({1, 3, 10, 20}));
990 auto *newC = llvm::dyn_cast<Constant>(newQ->getInput());
991 ASSERT_TRUE(newC);
992 EXPECT_TRUE(newC->getType()->dims().equals({1, 3, 10, 20}));
993}
994
/// Check that the removal of transposes still happens when
996/// predicates are involved.
997TEST_F(GraphOptz, transposeConstantWithPredicate) {
998 auto *A =
999 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
1000 auto *pred = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1001 bindings_.allocate(A)->getHandle().randomize(-7.0, 12.0, mod_.getPRNG());
1002 Tensor transposedA;
1003 bindings_.get(A)->transpose(&transposedA, {0, 3, 1, 2});
  // Arguably, if the transpose doesn't happen because the predicate is false,
  // the value of A should be unchanged. However, the semantics of our
  // predicates are that they can be ignored and the program would still be
  // correct, thus this optimization is still legal.
1008 Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
1009 T->setPredicate(pred);
1010 SaveNode *save = F_->createSave("ret", T);
1011 save->setPredicate(pred);
1012 EXPECT_EQ(F_->getNodes().size(), 2);
1013
1014 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
1015 ::glow::optimize(F_, CompilationMode::Infer);
1016 ASSERT_EQ(F_->getNodes().size(), 1);
1017 EXPECT_EQ(&*F_->getNodes().begin(), save);
1018 // We should have kept the predicate on the save node.
1019 ASSERT_EQ(pred->getNumUsers(), 1);
1020 EXPECT_EQ(pred->getUsers().begin()->getUser(), save);
1021 Constant *optimizedA = llvm::dyn_cast<Constant>(save->getInput().getNode());
1022 ASSERT_NE(optimizedA, nullptr);
1023 // Check that A has been properly transposed.
1024 EXPECT_TRUE(optimizedA->getPayload().isEqual(transposedA));
1025}
1026
1027TEST_F(GraphOptz, BatchNormAfterConvNotOptimizeForTrain) {
1028 Placeholder *A =
1029 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
1030 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
1031 Node *BN =
1032 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
1033 F_->createSave("ret", BN);
1034
1035 EXPECT_EQ(F_->getNodes().size(), 3);
1036
1037 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
1038 ::glow::optimize(optimizedF_, CompilationMode::Train);
1039 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
1040
1041 ASSERT_EQ(A->getNumUsers(), 2);
1042 Node *curCV = A->getUsers().begin()->getUser();
1043 EXPECT_EQ(curCV, CV);
1044 ASSERT_EQ(curCV->getNumUsers(), 1);
1045 Node *curBN = curCV->getUsers().begin()->getUser();
1046 EXPECT_EQ(curBN, BN);
1047 ASSERT_EQ(curBN->getNumUsers(), 1);
1048 Node *save = curBN->getUsers().begin()->getUser();
1049 EXPECT_TRUE(llvm::isa<SaveNode>(save));
1050
1051 bindings_.allocate(mod_.getPlaceholders());
1052 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1053 checkNumericalEquivalence();
1054}
1055
1056TEST_F(GraphOptz, batchNormAfterConvNotOptimizeWhenMoreThanOneUseOfConv) {
1057 Node *A =
1058 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
1059
1060 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
1061 Node *BN =
1062 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
1063 SaveNode *convSave = F_->createSave("ret", CV);
1064 SaveNode *ret = F_->createSave("ret", BN);
1065
1066 EXPECT_EQ(F_->getNodes().size(), 4);
1067
1068 ::glow::optimize(F_, CompilationMode::Infer);
1069 // Make sure the structure of the graph did not change, since the convolution
1070 // node is used more than once.
1071 EXPECT_EQ(F_->getNodes().size(), 4);
1072 ASSERT_TRUE(llvm::isa<ConvolutionNode>(convSave->getInput()));
1073 ConvolutionNode *conv = llvm::dyn_cast<ConvolutionNode>(convSave->getInput());
1074 EXPECT_EQ(conv, CV);
1075 EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(ret->getInput()));
1076 BatchNormalizationNode *batchNorm =
1077 llvm::dyn_cast<BatchNormalizationNode>(ret->getInput());
1078 EXPECT_EQ(batchNorm, BN);
1079 EXPECT_EQ(batchNorm->getInput().getNode(), CV);
1080 EXPECT_EQ(conv->getInput().getNode(), A);
1081}
1082
1083enum class TestSinkTransposeNodesKind {
1084 BatchNormalization,
1085 Relu,
1086 LeakyRelu,
1087 Clip,
1088 Sigmoid,
1089 Tanh,
1090 Quantize,
1091};
1092
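/// Parameterized fixture that builds the requested node kind on top of a
/// Transpose so that sinking the Transpose below each of these ops can be
/// tested uniformly.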
1093class GraphOptzSinkTransposeBelowParametrized
1094 : public GraphOptz,
1095 public ::testing::WithParamInterface<TestSinkTransposeNodesKind> {
1096public:
1097 NodeValue getNodeFromInput(TestSinkTransposeNodesKind testNode, Node *T) {
1098 switch (testNode) {
1099 case TestSinkTransposeNodesKind::BatchNormalization: {
1100 return F_->createBatchNormalization(bindings_, "batch", T, 1, 0.0001, 0.9)
1101 ->getResult();
1102 }
1103 case TestSinkTransposeNodesKind::Relu: {
1104 return F_->createRELU("relu", T)->getResult();
1105 }
1106 case TestSinkTransposeNodesKind::LeakyRelu: {
1107 return F_->createLeakyRELU("leaky_relu", T, 0.1)->getResult();
1108 }
1109 case TestSinkTransposeNodesKind::Clip: {
1110 return F_->createClip("clip", T, 0.0, 6.0)->getResult();
1111 }
1112 case TestSinkTransposeNodesKind::Sigmoid: {
1113 return F_->createSigmoid("sigmoid", T)->getResult();
1114 }
1115 case TestSinkTransposeNodesKind::Tanh: {
1116 return F_->createTanh("tanh", T)->getResult();
1117 }
1118 case TestSinkTransposeNodesKind::Quantize: {
1119 return F_
1120 ->createQuantize(
1121 "quantize", T,
1122 mod_.uniqueType(ElemKind::Int8QTy, T->dims(0), 0.03, 5))
1123 ->getResult();
1124 }
1125 }
1126 LOG(DFATAL) << "Cannot reach here.";
1127 return NodeValue(); // Prevents a compilation warning.
1128 }
1129};
1130
1131TEST_P(GraphOptzSinkTransposeBelowParametrized,
1132 TestSinkTransposeForDifferentCases) {
1133 const dim_t origDims[] = {1, 5, 10, 15};
1134 const dim_t transposedDims[] = {1, 15, 5, 10};
1135 auto *A = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1136 Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
1137 auto IN = getNodeFromInput(GetParam(), T);
1138 SaveNode *O = F_->createSave("ret", IN);
1139
1140 EXPECT_EQ(F_->getNodes().size(), 3);
1141 EXPECT_EQ(IN.dims(), llvm::makeArrayRef(transposedDims));
1142
1143 optimizedF_ = optimizeFunctionForTest(F_);
1144 O = llvm::dyn_cast<SaveNode>(std::find_if(
1145 optimizedF_->getNodes().begin(), optimizedF_->getNodes().end(),
1146 [](const auto &N) { return N.getKind() == Kinded::Kind::SaveNodeKind; }));
1147
1148 // Expecting Transpose->Output rather than N->Output.
1149 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1150 ASSERT_NE(transpose, nullptr);
1151 Node *N = transpose->getInput();
1152 ASSERT_TRUE(N);
1153 // Test correct input.
1154 if (GetParam() == TestSinkTransposeNodesKind::BatchNormalization) {
1155 ASSERT_EQ(BatchNormalizationNode::InputIdx, 0);
1156 } else {
1157 ASSERT_EQ(N->getNumInputs(), 1);
1158 }
1159 // Check that the dimensions of the input and output have been
  // updated to compensate for the absence of the transpose.
1161 EXPECT_EQ(transpose->getInput().dims(), llvm::makeArrayRef(origDims));
1162 EXPECT_EQ(N->getNthInput(0).dims(), llvm::makeArrayRef(origDims));
1163 EXPECT_EQ(F_->getNodes().size(), 3);
1164
1165 bindings_.allocate(mod_.getPlaceholders());
1166 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1167 checkNumericalEquivalence();
1168}
1169
1170TEST_P(GraphOptzSinkTransposeBelowParametrized,
1171 TestSinkTransposeWithPredicateForDifferentCases) {
1172 if (GetParam() == TestSinkTransposeNodesKind::Quantize) {
    // Quantize does not work with the generic predicate test.
1174 return;
1175 }
1176 const dim_t origDims[] = {1, 5, 10, 15};
1177 const dim_t transposedDims[] = {1, 15, 5, 10};
1178 Node *A = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1179 Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1180 Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1181 Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1182 Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
1183 T->setPredicate(pred1);
1184 Node *IN = getNodeFromInput(GetParam(), T);
1185 IN->setPredicate(pred2);
1186 SaveNode *O = F_->createSave("ret", IN);
1187 O->setPredicate(pred3);
1188
1189 EXPECT_EQ(F_->getNodes().size(), 3);
1190 EXPECT_EQ(IN->getNthResult(0).dims(), llvm::makeArrayRef(transposedDims));
1191
1192 ::glow::optimize(F_, CompilationMode::Infer);
1193
1194 EXPECT_EQ(O->getPredicate().getNode(), pred3);
1195 // Expecting Transpose->Output rather than N->Output.
1196 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1197 ASSERT_NE(transpose, nullptr);
1198 EXPECT_EQ(transpose->getPredicate().getNode(), pred2);
1199 Node *N = transpose->getInput();
1200 ASSERT_TRUE(N);
1201 EXPECT_EQ(N->getPredicate().getNode(), pred2);
1202
1203 // Test correct input.
1204 if (GetParam() == TestSinkTransposeNodesKind::BatchNormalization) {
1205 ASSERT_EQ(BatchNormalizationNode::InputIdx, 0);
1206 } else {
1207 ASSERT_EQ(N->getNumInputs(), 1);
1208 }
1209
1210 // Check that the dimensions of the input and output have been
  // updated to compensate for the absence of the transpose.
1212 EXPECT_EQ(transpose->getInput().dims(), llvm::makeArrayRef(origDims));
1213 EXPECT_EQ(N->getNthInput(0).dims(), llvm::makeArrayRef(origDims));
1214 EXPECT_EQ(F_->getNodes().size(), 3);
1215}
1216
1217GLOW_INSTANTIATE_TEST_SUITE_P(
1218 TestSinkTranspose, GraphOptzSinkTransposeBelowParametrized,
1219 ::testing::Values(TestSinkTransposeNodesKind::BatchNormalization,
1220 TestSinkTransposeNodesKind::Relu,
1221 TestSinkTransposeNodesKind::LeakyRelu,
1222 TestSinkTransposeNodesKind::Clip,
1223 TestSinkTransposeNodesKind::Sigmoid,
1224 TestSinkTransposeNodesKind::Tanh,
1225 TestSinkTransposeNodesKind::Quantize));
1226
1227TEST_F(GraphOptz, SinkTransposeBelowDequantize) {
1228 auto *in =
1229 mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input", false);
1230 auto *quantize = F_->createQuantize(
1231 "quantize", in, mod_.uniqueType(ElemKind::Int8QTy, in->dims(), 0.01, 2));
1232 auto *tile = F_->createTile("tile", quantize, 3, 0);
1233 auto *transpose = F_->createTranspose("transpose", tile, NHWC2NCHW);
1234 auto *deq = F_->createDequantize("dequantize", transpose, ElemKind::FloatTy);
1235 SaveNode *O = F_->createSave("out", deq);
1236
1237 optimizedF_ = optimizeFunctionForTest(F_);
1238
1239 EXPECT_EQ(F_->getNodes().size(), 5);
1240 EXPECT_EQ(optimizedF_->getNodes().size(), 5);
1241
1242 auto *optOut = findFunctionNodeByName<SaveNode>(optimizedF_, O->getName());
1243 EXPECT_TRUE(llvm::isa<TransposeNode>(optOut->getInput().getNode()));
1244
1245 bindings_.allocate(mod_.getPlaceholders());
1246 bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1247 checkNumericalEquivalence();
1248}
1249
1250TEST_F(GraphOptz, SinkTransposeBelowPRelu) {
1251 auto *input =
1252 mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input", false);
1253 auto *slope =
1254 mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "slope", false);
1255 auto *OT = mod_.uniqueType(ElemKind::FloatTy, {1, 5, 10, 15});
1256 auto *prelu = F_->createPRELU("prelu", input, slope, OT);
1257 auto *transpose = F_->createTranspose("transpose", prelu, NHWC2NCHW);
1258 SaveNode *O = F_->createSave("out", transpose);
1259
1260 optimizedF_ = optimizeFunctionForTest(F_);
1261
1262 EXPECT_EQ(F_->getNodes().size(), 3);
1263 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
1264
1265 auto *optOut = findFunctionNodeByName<SaveNode>(optimizedF_, O->getName());
1266 EXPECT_TRUE(llvm::isa<TransposeNode>(optOut->getInput().getNode()));
1267
1268 bindings_.allocate(mod_.getPlaceholders());
1269 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1270 bindings_.get(slope)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1271 checkNumericalEquivalence();
1272}
1273
1274TEST_F(GraphOptz, SinkTransposeBelowTile) {
1275 auto *in =
1276 mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input", false);
1277 auto *transpose = F_->createTranspose("transpose", in, NHWC2NCHW);
1278 auto *tile = F_->createTile("tile", transpose, 4, 1);
1279 auto *save = F_->createSave("save", tile);
1280
1281 optimizedF_ = optimizeFunctionForTest(
1282 F_, {FunctionPassID::SinkCode, getDCEPassConfig()});
1283
1284 EXPECT_EQ(F_->getNodes().size(), 3);
1285 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
1286
1287 auto *saveOpt =
1288 findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
1289 auto *transposeOpt = llvm::dyn_cast<TransposeNode>(saveOpt->getInput());
1290 ASSERT_TRUE(transposeOpt);
1291 EXPECT_EQ(transposeOpt->getShuffle(), transpose->getShuffle());
1292 auto *tileOpt = llvm::dyn_cast<TileNode>(transposeOpt->getInput());
1293 ASSERT_TRUE(tileOpt);
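  // Tiling along axis 1 (C in NCHW) corresponds to axis 3 (C in NHWC) once the
  // Tile runs before the Transpose.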
1294 EXPECT_EQ(tileOpt->getAxis(), 3);
1295 EXPECT_EQ(tileOpt->getCount(), 4);
1296
1297 bindings_.allocate(mod_.getPlaceholders());
1298 bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1299 checkNumericalEquivalence();
1300}
1301
1302TEST_F(GraphOptz, HoistTransposeAboveTile) {
1303 auto *in =
1304 mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input", false);
1305 auto *tile = F_->createTile("tile", in, 4, 3);
1306 auto *transpose = F_->createTranspose("transpose", tile, NHWC2NCHW);
1307 auto *save = F_->createSave("save", transpose);
1308
1309 optimizedF_ = optimizeFunctionForTest(F_);
1310
1311 EXPECT_EQ(F_->getNodes().size(), 3);
1312 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
1313
1314 auto *saveOpt =
1315 findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
1316 auto *tileOpt = llvm::dyn_cast<TileNode>(saveOpt->getInput());
1317 ASSERT_TRUE(tileOpt);
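  // Tiling along axis 3 (C in NHWC) corresponds to axis 1 (C in NCHW) once the
  // Tile runs after the Transpose.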
1318 EXPECT_EQ(tileOpt->getAxis(), 1);
1319 EXPECT_EQ(tileOpt->getCount(), 4);
1320 auto *transposeOpt = llvm::dyn_cast<TransposeNode>(tileOpt->getInput());
1321 ASSERT_TRUE(transposeOpt);
1322 EXPECT_EQ(transposeOpt->getShuffle(), transpose->getShuffle());
1323
1324 bindings_.allocate(mod_.getPlaceholders());
1325 bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1326 checkNumericalEquivalence();
1327}
1328
/// Sink Transpose below Rescale. This enables, for example, folding the
/// Rescale into the Convolution.
1330TEST_F(GraphOptz, sinkTransposeBelowRescale) {
1331 // Inputs.
1332 const dim_t origDims[] = {1, 5, 10, 15};
1333 const dim_t transposedDims[] = {1, 15, 5, 10};
1334 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, origDims, 0.1, 0,
1335 "input", false);
1336 auto *filter = mod_.createPlaceholder(ElemKind::Int8QTy, {15, 1, 1, 15}, 0.1,
1337 0, "filter", false);
1338 auto *bias =
1339 mod_.createPlaceholder(ElemKind::Int32QTy, {15}, 0.01, 0, "bias", false);
1340
1341 // Graph.
1342 ConvolutionNode *conv =
1343 F_->createConv("conv", input, filter, bias, input->getType(), {1, 1},
1344 {1, 1}, {0, 0, 0, 0}, 1);
1345
1346 auto *T = F_->createTranspose("transpose", conv, NHWC2NCHW);
1347 auto *RT = mod_.uniqueType(ElemKind::Int8QTy, T->getResult().dims(), 0.2, 0);
1348 auto *R = F_->createRescaleQuantized("rescale", T, RT);
1349 SaveNode *O = F_->createSave("ret", R);
1350
1351 EXPECT_EQ(F_->getNodes().size(), 4);
1352 EXPECT_EQ(RT->dims(), llvm::makeArrayRef(transposedDims));
1353
1354 ::glow::optimize(F_, CompilationMode::Infer);
1355
1356 // Expecting Transpose->Output rather than Rescale->Output.
1357 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1358 ASSERT_NE(transpose, nullptr);
1359 ASSERT_TRUE(llvm::isa<ConvolutionNode>(transpose->getInput()));
1360 auto &convTRInput = transpose->getInput();
1361 // Check that the dimensions of the input and output have been
  // updated to compensate for the absence of the transpose.
1363 EXPECT_EQ(convTRInput.dims(), llvm::makeArrayRef(origDims));
1364 EXPECT_EQ(convTRInput.getNode()->getNthInput(0).dims(),
1365 llvm::makeArrayRef(origDims));
1366 EXPECT_EQ(F_->getNodes().size(), 3);
1367}
1368
1369TEST_F(GraphOptz, cancelTwoTransposes) {
1370 const dim_t origDims[] = {1, 5, 10, 15};
1371 Placeholder *A =
1372 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1373 Node *T1 = F_->createTranspose("transpose", A, NCHW2NHWC);
1374 Node *T2 = F_->createTranspose("transpose", T1, NHWC2NCHW);
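  // NCHW2NHWC followed by NHWC2NCHW is an identity shuffle, so the pair of
  // transposes should cancel out.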
1375 ReluNode *K = F_->createRELU("relu", T2);
1376 SaveNode *save = F_->createSave("ret", K);
1377
1378 EXPECT_EQ(K->getInput().dims(), llvm::makeArrayRef(origDims));
1379 EXPECT_EQ(F_->getNodes().size(), 4);
1380
1381 optimizedF_ = optimizeFunctionForTest(F_);
1382
1383 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
1384
1385 for (auto &N : optimizedF_->getNodes()) {
1386 if (N.getKind() == Kinded::Kind::SaveNodeKind) {
1387 save = llvm::dyn_cast<SaveNode>(&N);
1388 }
1389 }
1390
1391 ReluNode *relu = llvm::dyn_cast<ReluNode>(save->getInput());
1392 ASSERT_TRUE(relu);
1393 EXPECT_EQ(relu->getResult().dims(), llvm::makeArrayRef(origDims));
1394 EXPECT_EQ(relu->getInput().getNode(), A);
1395
1396 bindings_.allocate(mod_.getPlaceholders());
1397 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1398
1399 checkNumericalEquivalence();
1400}
1401
/// Make sure the predicates don't get in the way of the
/// transpose(transpose) => identity optimization and that they are
/// preserved.
1405TEST_F(GraphOptz, cancelTwoTransposesWithPredicate) {
1406 const dim_t origDims[] = {1, 5, 10, 15};
1407 Node *A = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1408 Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1409 Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1410 Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1411 Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1412 Node *T1 = F_->createTranspose("transpose", A, NCHW2NHWC);
1413 T1->setPredicate(pred1);
1414 Node *T2 = F_->createTranspose("transpose", T1, NHWC2NCHW);
1415 T2->setPredicate(pred2);
1416 ReluNode *K = F_->createRELU("relu", T2);
1417 K->setPredicate(pred3);
1418 SaveNode *save = F_->createSave("ret", K);
1419 save->setPredicate(pred4);
1420
1421 EXPECT_EQ(K->getInput().dims(), llvm::makeArrayRef(origDims));
1422 EXPECT_EQ(F_->getNodes().size(), 4);
1423
1424 ::glow::optimize(F_, CompilationMode::Infer);
1425
1426 EXPECT_EQ(F_->getNodes().size(), 2);
1427 EXPECT_EQ(save->getPredicate().getNode(), pred4);
1428 ReluNode *relu = llvm::dyn_cast<ReluNode>(save->getInput());
1429 ASSERT_TRUE(relu);
1430 EXPECT_EQ(relu->getPredicate().getNode(), pred3);
1431 EXPECT_EQ(relu->getResult().dims(), llvm::makeArrayRef(origDims));
1432 EXPECT_EQ(relu->getInput().getNode(), A);
1433}
1434
1435TEST_F(GraphOptz, removeIdentityTranspose) {
1436 const dim_t origDims[] = {1, 5, 10, 15};
1437 Placeholder *A =
1438 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1439 TransposeNode *T = F_->createTranspose("transpose", A, {0, 1, 2, 3});
1440 ReluNode *K = F_->createRELU("relu", T);
1441 F_->createSave("ret", K);
1442
1443 EXPECT_EQ(F_->getNodes().size(), 3);
1444 EXPECT_EQ(K->getInput().getNode(), T);
1445
1446 ::glow::optimize(F_, CompilationMode::Infer);
1447
1448 EXPECT_EQ(F_->getNodes().size(), 2);
1449 EXPECT_EQ(K->getInput().getNode(), A);
1450 // Make sure we didn't mess up with the dimensions of the
1451 // variable while eliminating the transpose.
1452 EXPECT_EQ(A->dims(), llvm::makeArrayRef(origDims));
1453}
1454
1455/// Check that the predicates don't get in the way of
1456/// the identity transpose removal, while still being
1457/// preserved.
1458TEST_F(GraphOptz, removeIdentityTransposeWithPredicate) {
1459 const dim_t origDims[] = {1, 5, 10, 15};
1460 Placeholder *A =
1461 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1462 Placeholder *pred1 =
1463 mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1464 Placeholder *pred2 =
1465 mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1466 Placeholder *pred3 =
1467 mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1468 TransposeNode *T = F_->createTranspose("transpose", A, {0, 1, 2, 3});
1469 T->setPredicate(pred1);
1470 ReluNode *K = F_->createRELU("relu", T);
1471 K->setPredicate(pred2);
1472 SaveNode *save = F_->createSave("ret", K);
1473 save->setPredicate(pred3);
1474
1475 EXPECT_EQ(F_->getNodes().size(), 3);
1476 EXPECT_EQ(K->getInput().getNode(), T);
1477
1478 ::glow::optimize(F_, CompilationMode::Infer);
1479 EXPECT_EQ(F_->getNodes().size(), 2);
1480 EXPECT_EQ(save->getPredicate().getNode(), pred3);
1481 EXPECT_EQ(save->getInput().getNode(), K);
1482 EXPECT_EQ(K->getInput().getNode(), A);
1483 EXPECT_EQ(K->getPredicate().getNode(), pred2);
1484 // Make sure we didn't mess up with the dimensions of the
1485 // variable while eliminating the transpose.
1486 EXPECT_EQ(A->dims(), llvm::makeArrayRef(origDims));
1487}
1488
1489/// Check that consecutive non-inverse transposes are merged
1490/// into an equivalent single transpose node.
1491TEST_F(GraphOptz, mergeNonInverseTransposes) {
1492 const dim_t origDims[] = {1, 5, 10, 15};
1493 const dim_t finalDims[] = {5, 1, 15, 10};
1494
1495 Placeholder *A =
1496 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1497 TransposeNode *T1 = F_->createTranspose("transpose", A, {0, 3, 2, 1});
1498 TransposeNode *T2 = F_->createTranspose("transpose", T1, {0, 2, 3, 1});
1499 TransposeNode *T3 = F_->createTranspose("transpose", T2, {1, 0, 3, 2});
1500 TransposeNode *T4 = F_->createTranspose("transpose", T3, {3, 1, 2, 0});
1501
  // Intermediate dims after each transpose:
1503 // Initial : {1, 5, 10, 15}
1504 // After T1: {1, 15, 10, 5}
1505 // After T2: {1, 10, 5, 15}
1506 // After T3: {10, 1, 15, 5}
1507 // After T4: {5, 1, 15, 10}
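  // These four shuffles compose to a single transpose with shuffle
  // {1, 0, 3, 2}, which maps {1, 5, 10, 15} directly to {5, 1, 15, 10}.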
1508
1509 SaveNode *save = F_->createSave("ret", T4);
1510
1511 EXPECT_EQ(F_->getNodes().size(), 5);
1512
1513 optimizedF_ = optimizeFunctionForTest(F_);
1514 // Find save node in the optimized graph.
1515 for (auto &N : optimizedF_->getNodes()) {
1516 if (N.getKind() == Kinded::Kind::SaveNodeKind) {
1517 save = llvm::dyn_cast<SaveNode>(&N);
1518 }
1519 }
1520 // Get the last transpose node in the optimized graph.
1521 auto *TR = llvm::dyn_cast<TransposeNode>(save->getInput());
1522 ASSERT_NE(TR, nullptr);
1523
1524 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
1525 EXPECT_EQ(TR->getResult().dims(), llvm::makeArrayRef(finalDims));
1526 EXPECT_EQ(A->getNthResult(0).dims(), llvm::makeArrayRef(origDims));
1527 EXPECT_EQ(TR->getInput().getNode(), A);
1528
1529 bindings_.allocate(mod_.getPlaceholders());
1530 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1531 checkNumericalEquivalence();
1532}
1533
1534TEST_F(GraphOptz, sinkTransposeBelowArithmeticNodes) {
1535 const dim_t origDims[] = {1, 5, 10, 15};
1536 Node *A1 =
1537 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1538 Node *A2 =
1539 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1540 Node *T1 = F_->createTranspose("transpose1", A1, NHWC2NCHW);
1541 Node *T2 = F_->createTranspose("transpose2", A2, NHWC2NCHW);
1542 Node *K = F_->createAdd("arith", T1, T2);
1543 Node *P = F_->createPow("pow", K, T2);
1544 SaveNode *O = F_->createSave("ret", P);
1545
1546 EXPECT_EQ(F_->getNodes().size(), 5);
1547
1548 ::glow::optimize(F_, CompilationMode::Infer);
1549
1550 // Expecting Transpose->Output rather than Add->Output.
1551 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1552 ASSERT_NE(transpose, nullptr);
1553 auto *pow = llvm::dyn_cast<PowNode>(transpose->getInput());
1554 ASSERT_TRUE(pow);
1555 auto *add = llvm::dyn_cast<AddNode>(pow->getLHS());
1556 ASSERT_TRUE(add);
  // Check that the dimensions of the input and output have been updated to
  // compensate for the absence of the transpose.
1559 EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims));
1560 EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims));
1561 EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims));
1562 EXPECT_EQ(add->getLHS().getNode(), A1);
1563 EXPECT_EQ(add->getRHS().getNode(), A2);
1564
1565 EXPECT_EQ(F_->getNodes().size(), 4);
1566}
1567
1568/// Check that Transpose node is sunk below arithmetic nodes when one of the
1569/// operands is a Constant.
1570TEST_F(GraphOptz, sinkTransposeBelowArithmeticNodesWithConstantOperand) {
1571 const dim_t origDims[] = {1, 5, 10, 15};
1572 const dim_t transposedDims[] = {1, 15, 5, 10};
1573
1574 // Create one subgraph in which the Constant is the LHS operand of the Add.
1575 Constant *C1 = mod_.createConstant(ElemKind::FloatTy, transposedDims, "C1");
1576 // Initialize the payload before optimization so that it can be copied to the
1577 // new Constant that will be created by the GraphOptimizer.
1578 C1->getHandle().randomize(-1, 1, mod_.getPRNG());
1579
1580 auto *P1 = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "P1", false);
1581 auto *T1 = F_->createTranspose("T1", P1, NHWC2NCHW);
1582 auto *A1 = F_->createAdd("A1", C1, T1);
1583 SaveNode *S1 = F_->createSave("S1", A1);
1584
  // Create one subgraph in which the Constant is the RHS operand of the Add.
1586 Constant *C2 = mod_.createConstant(ElemKind::FloatTy, transposedDims, "C2");
1587 // Initialize the payload before optimization so that it can be copied to the
1588 // new Constant that will be created by the GraphOptimizer.
1589 C2->getHandle().randomize(-1, 1, mod_.getPRNG());
1590
1591 auto *P2 = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "P2", false);
1592 auto *T2 = F_->createTranspose("T2", P2, NHWC2NCHW);
1593 auto *A2 = F_->createAdd("A2", T2, C2);
1594 SaveNode *S2 = F_->createSave("S2", A2);
1595
1596 EXPECT_EQ(F_->getNodes().size(), 6);
1597
1598 optimizedF_ = optimizeFunctionForTest(F_);
1599
1600 // Find the SaveNodes of the optimized graph.
1601 for (auto &N : optimizedF_->getNodes()) {
1602 if (N.getKind() == Kinded::Kind::SaveNodeKind) {
1603 if (N.getName() == S1->getName()) {
1604 S1 = llvm::dyn_cast<SaveNode>(&N);
1605 }
1606
1607 if (N.getName() == S2->getName()) {
1608 S2 = llvm::dyn_cast<SaveNode>(&N);
1609 }
1610 }
1611 }
1612
1613 // Expecting Transpose->Output rather than Add->Output.
1614 auto *transpose = llvm::dyn_cast<TransposeNode>(S1->getInput());
1615 ASSERT_NE(transpose, nullptr);
1616 auto *add = llvm::dyn_cast<AddNode>(transpose->getInput());
1617 ASSERT_TRUE(add);
  // Check that the dimensions of the input and output of the add have been
  // updated to compensate for the absence of the transpose.
1620 EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims));
1621 EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims));
1622 EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims));
1623 EXPECT_EQ(add->getRHS().getNode(), P1);
1624
1625 // Repeat checks for other subgraph.
1626 transpose = llvm::dyn_cast<TransposeNode>(S2->getInput());
1627 ASSERT_NE(transpose, nullptr);
1628 add = llvm::dyn_cast<AddNode>(transpose->getInput());
1629 ASSERT_TRUE(add);
1630 EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims));
1631 EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims));
1632 EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims));
1633 EXPECT_EQ(add->getLHS().getNode(), P2);
1634
1635 EXPECT_EQ(optimizedF_->getNodes().size(), 6);
1636
1637 // Check that the original and optimized functions are numerically equivalent.
1638 // This indirectly checks that the Constant has been transposed properly.
1639 bindings_.allocate(mod_.getPlaceholders());
1640 bindings_.get(P1)->getHandle().randomize(-1, 1, mod_.getPRNG());
1641 bindings_.get(P2)->getHandle().randomize(-1, 1, mod_.getPRNG());
1642
1643 checkNumericalEquivalence();
1644}
1645
/// Test sinking a Transpose below an Add whose operands have the same element
/// type and shape, but different scale and offset.
1648TEST_F(GraphOptz, sinkQuantTransposeBelowArithmeticNodesWithConstantOperand) {
1649 const dim_t origDims[] = {1, 5, 10, 15};
1650 const dim_t transposedDims[] = {1, 15, 5, 10};
1651
  // Create a graph in which an Add takes a Constant as LHS and a Transpose as
  // RHS. LHS and RHS have different scales and offsets.
1654 Constant *lhsC =
1655 mod_.createConstant(ElemKind::Int8QTy, transposedDims, 0.2, 0, "C1");
1656 lhsC->getHandle<int8_t>().randomize(-128, 127, mod_.getPRNG());
1657
1658 auto *inputP =
1659 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "Input", false);
1660 auto *qTy = mod_.uniqueType(ElemKind::Int8QTy, origDims, 0.3, 2);
1661 auto *quant = F_->createQuantize("Quant", inputP, qTy);
1662 auto *rhsT = F_->createTranspose("RHS", quant, NHWC2NCHW);
1663 auto *addQ = F_->createAdd("Add", lhsC, rhsT);
1664 SaveNode *save = F_->createSave("Save", addQ);
1665
1666 EXPECT_EQ(F_->getNodes().size(), 4);
1667
1668 optimizedF_ = optimizeFunctionForTest(F_);
1669
1670 // Expecting Transpose->Output rather than Add->Output.
1671 const auto *saveOpt =
1672 findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
1673 auto *transpose = llvm::dyn_cast<TransposeNode>(saveOpt->getInput());
1674 ASSERT_NE(transpose, nullptr);
1675 auto *add = llvm::dyn_cast<AddNode>(transpose->getInput());
1676 ASSERT_TRUE(add);
  // Check that the dimensions of the input and output of the add have been
  // updated to compensate for the absence of the transpose.
1679 EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims));
1680 EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims));
1681 EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims));
1682 quant = llvm::dyn_cast<QuantizeNode>(add->getRHS().getNode());
1683 ASSERT_TRUE(quant);
1684 EXPECT_EQ(quant->getInput().getNode(), inputP);
1685 EXPECT_EQ(optimizedF_->getNodes().size(), 4);
1686
1687 // Check that the original and optimized functions are numerically equivalent.
1688 // This indirectly checks that the Constant has been transposed properly.
1689 bindings_.allocate(mod_.getPlaceholders());
1690 bindings_.get(inputP)->getHandle().randomize(-128, 127, mod_.getPRNG());
1691
1692 checkNumericalEquivalence();
1693}
1694
1695/// Check that the predicates are properly preserved while doing
1696/// the add(transpose, transpose) => transpose(add).
1697TEST_F(GraphOptz, sinkTransposeBelowArithmeticNodesWithPredicate) {
1698 const dim_t origDims[] = {1, 5, 10, 15};
1699 Node *A1 =
1700 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1701 Node *A2 =
1702 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1703 Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1704 Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1705 Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1706 Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1707 Node *T1 = F_->createTranspose("transpose1", A1, NHWC2NCHW);
1708 T1->setPredicate(pred1);
1709 Node *T2 = F_->createTranspose("transpose2", A2, NHWC2NCHW);
1710 T2->setPredicate(pred2);
1711 Node *K = F_->createAdd("arith", T1, T2);
1712 K->setPredicate(pred3);
1713 SaveNode *O = F_->createSave("ret", K);
1714 O->setPredicate(pred4);
1715
1716 EXPECT_EQ(F_->getNodes().size(), 4);
1717
1718 ::glow::optimize(F_, CompilationMode::Infer);
1719
1720 EXPECT_EQ(O->getPredicate().getNode(), pred4);
1721 // Expecting Transpose->Output rather than Add->Output.
1722 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1723 ASSERT_NE(transpose, nullptr);
1724 EXPECT_EQ(transpose->getPredicate().getNode(), pred3);
1725 auto *add = llvm::dyn_cast<AddNode>(transpose->getInput());
1726 ASSERT_TRUE(add);
1727 EXPECT_EQ(add->getPredicate().getNode(), pred3);
  // Check that the dimensions of the input and output have been updated to
  // compensate for the absence of the transpose.
1730 EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims));
1731 EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims));
1732 EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims));
1733 EXPECT_EQ(add->getLHS().getNode(), A1);
1734 EXPECT_EQ(add->getRHS().getNode(), A2);
1735
1736 EXPECT_EQ(F_->getNodes().size(), 3);
1737}
1738
1739TEST_F(GraphOptz, sinkReluBelowConcatNodes) {
1740 const dim_t origDims[] = {1, 5, 10, 15};
1741 const dim_t origDimsConcat[] = {1, 10, 10, 15};
1742 Node *A1 =
1743 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1744 Node *A2 =
1745 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1746 Node *R1 = F_->createRELU("relu1", A1);
1747 Node *R2 = F_->createRELU("relu2", A2);
1748 Node *CN = F_->createConcat("concat", {R1, R2}, 1);
1749 SaveNode *O = F_->createSave("ret", CN);
1750
1751 EXPECT_EQ(F_->getNodes().size(), 4);
1752
1753 ::glow::optimize(F_, CompilationMode::Infer);
1754
1755 // Expecting RELU->Output rather than Concat->Output.
1756 auto *relu = llvm::dyn_cast<ReluNode>(O->getInput());
1757 ASSERT_NE(relu, nullptr);
1758 auto *concat = llvm::dyn_cast<ConcatNode>(relu->getInput());
1759 ASSERT_TRUE(concat);
  // Check that the dimensions of the concat inputs and result are correct
  // after sinking the relu below the concat.
1762 EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat));
1763 EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims));
1764 EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims));
1765 EXPECT_EQ(concat->getInputs()[0].getNode(), A1);
1766 EXPECT_EQ(concat->getInputs()[1].getNode(), A2);
1767
1768 EXPECT_EQ(F_->getNodes().size(), 3);
1769}
1770
1771/// Check that the predicates are properly preserved while doing
1772/// the sinking of relu nodes.
1773TEST_F(GraphOptz, sinkReluBelowConcatNodesWithPredicate) {
1774 const dim_t origDims[] = {1, 5, 10, 15};
1775 const dim_t origDimsConcat[] = {1, 10, 10, 15};
1776 Node *A1 =
1777 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1778 Node *A2 =
1779 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1780 Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1781 Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1782 Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1783 Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1784 Node *R1 = F_->createRELU("relu1", A1);
1785 R1->setPredicate(pred1);
1786 Node *R2 = F_->createRELU("relu2", A2);
1787 R2->setPredicate(pred2);
1788 Node *CN = F_->createConcat("concat", {R1, R2}, 1);
1789 CN->setPredicate(pred3);
1790 SaveNode *O = F_->createSave("ret", CN);
1791 O->setPredicate(pred4);
1792
1793 EXPECT_EQ(F_->getNodes().size(), 4);
1794
1795 ::glow::optimize(F_, CompilationMode::Infer);
1796
1797 // Expecting RELU->Output rather than Concat->Output.
1798 EXPECT_EQ(O->getPredicate().getNode(), pred4);
1799 auto *relu = llvm::dyn_cast<ReluNode>(O->getInput());
1800 ASSERT_NE(relu, nullptr);
1801 EXPECT_EQ(relu->getPredicate().getNode(), pred3);
1802 auto *concat = llvm::dyn_cast<ConcatNode>(relu->getInput());
1803 ASSERT_TRUE(concat);
1804 EXPECT_EQ(concat->getPredicate().getNode(), pred3);
  // Check that the dimensions of the concat inputs and result are correct
  // after sinking the relu below the concat.
1807 EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat));
1808 EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims));
1809 EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims));
1810 EXPECT_EQ(concat->getInputs()[0].getNode(), A1);
1811 EXPECT_EQ(concat->getInputs()[1].getNode(), A2);
1812
1813 EXPECT_EQ(F_->getNodes().size(), 3);
1814}
1815
1816TEST_F(GraphOptz, sinkTransposeBelowConcatNodes) {
1817 const dim_t origDims[] = {1, 5, 10, 15};
1818 const dim_t origDimsConcat[] = {1, 5, 20, 15};
1819 Node *A1 =
1820 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1821 Node *A2 =
1822 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1823 Node *T1 = F_->createTranspose("transpose", A1, NCHW2NHWC);
1824 Node *T2 = F_->createTranspose("transpose", A2, NCHW2NHWC);
1825 Node *CN = F_->createConcat("concat", {T1, T2}, 1);
1826 SaveNode *O = F_->createSave("ret", CN);
1827
1828 EXPECT_EQ(F_->getNodes().size(), 4);
1829
1830 ::glow::optimize(F_, CompilationMode::Infer);
1831
  // Expecting Transpose->Output rather than Concat->Output.
1833 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1834 ASSERT_NE(transpose, nullptr);
1835 auto *concat = llvm::dyn_cast<ConcatNode>(transpose->getInput());
1836 ASSERT_TRUE(concat);
  // Check that the dimensions of the input and output have been updated to
  // compensate for the absence of the transpose.
1839 EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat));
1840 EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims));
1841 EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims));
1842 EXPECT_EQ(concat->getInputs()[0].getNode(), A1);
1843 EXPECT_EQ(concat->getInputs()[1].getNode(), A2);
1844
1845 EXPECT_EQ(F_->getNodes().size(), 3);
1846}
1847
/// Check that the predicates are properly preserved while doing
/// the concat(transpose, transpose) => transpose(concat).
1850TEST_F(GraphOptz, sinkTransposeBelowConcatWithPredicate) {
1851 const dim_t origDims[] = {1, 5, 10, 15};
1852 const dim_t origDimsConcat[] = {1, 5, 20, 15};
1853 Node *A1 =
1854 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1855 Node *A2 =
1856 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1857 Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1858 Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1859 Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1860 Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1861 Node *T1 = F_->createTranspose("transpose", A1, NCHW2NHWC);
1862 T1->setPredicate(pred1);
1863 Node *T2 = F_->createTranspose("transpose", A2, NCHW2NHWC);
1864 T2->setPredicate(pred2);
1865 Node *CN = F_->createConcat("concat", {T1, T2}, 1);
1866 CN->setPredicate(pred3);
1867 SaveNode *O = F_->createSave("ret", CN);
1868 O->setPredicate(pred4);
1869
1870 EXPECT_EQ(F_->getNodes().size(), 4);
1871
1872 ::glow::optimize(F_, CompilationMode::Infer);
1873
1874 EXPECT_EQ(O->getPredicate().getNode(), pred4);
  // Expecting Transpose->Output rather than Concat->Output.
1876 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1877 ASSERT_NE(transpose, nullptr);
1878 EXPECT_EQ(transpose->getPredicate().getNode(), pred3);
1879 auto *concat = llvm::dyn_cast<ConcatNode>(transpose->getInput());
1880 ASSERT_TRUE(concat);
1881 EXPECT_EQ(concat->getPredicate().getNode(), pred3);
  // Check that the dimensions of the input and output have been updated to
  // compensate for the absence of the transpose.
1884 EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat));
1885 EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims));
1886 EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims));
1887 EXPECT_EQ(concat->getInputs()[0].getNode(), A1);
1888 EXPECT_EQ(concat->getInputs()[1].getNode(), A2);
1889
1890 EXPECT_EQ(F_->getNodes().size(), 3);
1891}
1892
1893TEST_F(GraphOptz, sinkTransposeBelowPad) {
1894 // The shape of the graph before the optimization.
1895 const dim_t inputDims[] = {1, 5, 10, 15};
1896 const dim_t outTransposeDims[] = {1, 10, 15, 5};
1897 const dim_t outPadDims[] = {5, 18, 25, 11};
1898 // Padding before the optimization.
1899 int pads[] = {0, 2, 3, 1, 4, 6, 7, 5};
1900
1901 // The shape of the graph after the optimization.
1902 const dim_t outPadDimsAfterOptim[] = {5, 11, 18, 25};
1903 const dim_t outTransposeDimsAfterOptims[] = {5, 18, 25, 11};
1904 // Padding after the optimization.
1905 int padsAfterOptim[] = {0, 1, 2, 3, 4, 5, 6, 7};
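  // The original pads are expressed in the NHWC layout produced by the
  // Transpose. Once the Transpose is sunk below the Pad, the Pad applies
  // directly to the NCHW input, so each begin/end pad is re-indexed through
  // the shuffle: begin {0, 2, 3, 1} -> {0, 1, 2, 3}, end {4, 6, 7, 5} ->
  // {4, 5, 6, 7}.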
1906
1907 // Create the initial graph.
1908 Node *A =
1909 mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
1910 auto outTy = mod_.uniqueType(ElemKind::FloatTy, outPadDims);
1911 TransposeNode *T = F_->createTranspose("transpose", A, NCHW2NHWC);
1912 Node *P = F_->createPad("pad", T, outTy, PaddingMode::CONSTANT, pads, 23.f);
1913 EXPECT_EQ(T->getResult().dims(), llvm::makeArrayRef(outTransposeDims));
1914 SaveNode *O = F_->createSave("ret", P);
1915
1916 EXPECT_EQ(F_->getNodes().size(), 3);
1917
1918 ::glow::optimize(F_, CompilationMode::Infer);
1919
1920 // Check the graph structure and additional properties after optimization.
1921 auto *trans = llvm::dyn_cast<TransposeNode>(O->getInput());
1922 ASSERT_NE(trans, nullptr);
1923 EXPECT_EQ(trans->getResult().dims(),
1924 llvm::makeArrayRef(outTransposeDimsAfterOptims));
1925 auto *pad = llvm::dyn_cast<PadNode>(trans->getInput().getNode());
1926 ASSERT_NE(pad, nullptr);
1927
1928 EXPECT_EQ(pad->getPads(), llvm::makeArrayRef(padsAfterOptim));
1929 EXPECT_EQ(pad->getResult().dims(), llvm::makeArrayRef(outPadDimsAfterOptim));
1930
1931 EXPECT_EQ(F_->getNodes().size(), 3);
1932}
1933
1934TEST_F(GraphOptz, sinkTransposeBelowRelu) {
1935 // Define a type with custom alignments.
1936 Type typeWithAlignments(ElemKind::FloatTy, {2, 3, 4, 5}, {1, 1, 32, 1});
1937 Type transposedTypeWithAlignments(ElemKind::FloatTy, {2, 4, 5, 3},
1938 {1, 1, 32, 1});
1939 auto modTyWithAlignments = mod_.uniqueType(typeWithAlignments);
1940 auto modTransposedTyWithAlignments =
1941 mod_.uniqueType(transposedTypeWithAlignments);
1942 auto *A1 = mod_.createPlaceholder(modTyWithAlignments, "input1", false);
1943 auto *T1 = F_->createTranspose("transpose", A1, NCHW2NHWC);
1944 T1->setType(0, modTransposedTyWithAlignments);
1945 auto *RN = F_->createRELU("relu", T1);
1946 SaveNode *O = F_->createSave("ret", RN);
1947
1948 EXPECT_EQ(F_->getNodes().size(), 3);
1949
1950 ::glow::optimize(F_, CompilationMode::Infer);
1951
  // Expecting Transpose->Output rather than Relu->Output, because the
  // Transpose was sunk below the Relu.
1954 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1955 ASSERT_NE(transpose, nullptr);
1956 auto *relu = llvm::dyn_cast<ReluNode>(transpose->getInput());
1957 ASSERT_TRUE(relu);
1958 // Check that alignments are preserved by optimizations.
1959 ASSERT_TRUE(relu->getInput().getType()->isEqual(modTyWithAlignments));
1960 ASSERT_TRUE(transpose->getInput().getType()->isEqual(modTyWithAlignments));
1961 ASSERT_TRUE(
1962 transpose->getResult().getType()->isEqual(modTransposedTyWithAlignments));
1963
1964 EXPECT_EQ(F_->getNodes().size(), 3);
1965 ASSERT_TRUE(F_->verify());
1966}
1967
1968TEST_F(GraphOptz, mergeConcatNodes) {
1969 Node *A1 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input1",
1970 false);
1971 Node *A2 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input2",
1972 false);
1973 Node *A3 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input3",
1974 false);
1975 Node *A4 =
1976 mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 5, 15}, "input4", false);
1977 Node *A5 =
1978 mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 5, 15}, "input5", false);
1979
1980 Node *CN1 = F_->createConcat("concat1", {A1, A2}, 1);
1981 Node *CN2 = F_->createConcat("concat2", {A1, CN1}, 1);
1982 Node *CN3 = F_->createConcat("concat3", {A4, A5}, 2);
1983 Node *CN4 = F_->createConcat("concat4", {A3, CN2, CN3}, 1);
1984 Node *O = F_->createSave("ret", CN4);
1985
1986 EXPECT_EQ(F_->getNodes().size(), 5);
1987
1988 ::glow::optimize(F_, CompilationMode::Train);
1989
1990 // It is expected that the optimization transforms
1991 // concat4(1, A3, concat2(1, A1, concat1(1, A1, A2)), concat3(2, A4, A5))
1992 // into
1993 // concat4(1, A3, A1, A1, A2, concat3(2, A4, A5))
1994
1995 EXPECT_TRUE(llvm::isa<SaveNode>(O));
1996
1997 auto *CN =
1998 llvm::dyn_cast<ConcatNode>(llvm::dyn_cast<SaveNode>(O)->getInput());
1999 EXPECT_TRUE(CN);
2000
2001 // The merged ConcatNode should have 5 inputs.
2002 EXPECT_EQ(CN->getInputs().size(), 5);
2003
2004 // CN1 should be merged into a new CN2 and later into a new CN4 and removed by
2005 // the optimizations.
2006 EXPECT_FALSE(functionContainsNode(F_, CN1));
2007
2008 // CN2 should be merged into a new CN4 and removed by the optimizations.
2009 EXPECT_FALSE(functionContainsNode(F_, CN2));
2010
2011 // CN3 should not be merged into CN4 and should not be removed,
2012 // because CN4 and CN3 have a different dimension parameter.
2013 EXPECT_TRUE(functionContainsNode(F_, CN3));
2014
2015 // The CN4 concat node should be replaced by a merged concat node.
2016 EXPECT_FALSE(functionContainsNode(F_, CN4));
2017
2018 EXPECT_EQ(F_->getNodes().size(), 3);
2019}
2020
2021TEST_F(GraphOptz, CSE) {
2022 Node *A1 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input1",
2023 false);
2024 Node *A2 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input2",
2025 false);
2026
2027 Node *CN1 = F_->createConcat("concat1", {A1, A2}, 1);
2028 Node *CN2 = F_->createConcat("concat2", {A1, A2}, 1);
2029 Node *CN3 = F_->createConcat("concat3", {CN1, CN2}, 2);
2030 Node *O = F_->createSave("ret", CN3);
2031
2032 EXPECT_EQ(F_->getNodes().size(), 4);
2033
2034 ::glow::optimize(F_, CompilationMode::Train);
2035
2036 EXPECT_TRUE(llvm::isa<SaveNode>(O));
2037
2038 auto *CN =
2039 llvm::dyn_cast<ConcatNode>(llvm::dyn_cast<SaveNode>(O)->getInput());
2040 EXPECT_TRUE(CN);
2041
2042 // The merged ConcatNode should have 2 inputs.
2043 EXPECT_EQ(CN->getInputs().size(), 2);
2044
2045 // CN1 should not be removed.
2046 EXPECT_TRUE(functionContainsNode(F_, CN1));
2047
2048 // CSE should replace CN2 by CN1 and remove CN2.
2049 EXPECT_FALSE(functionContainsNode(F_, CN2));
2050
2051 EXPECT_EQ(F_->getNodes().size(), 3);
2052}
2053
2054TEST_F(GraphOptz, SliceOfSplatNode) {
2055 Type t(ElemKind::FloatTy, {1000, 1000, 1000});
2056 Node *Z = F_->createSplat("zero", &t, 0.);
2057 Node *S = F_->createSlice("slice", Z, {5, 15, 42}, {99, 88, 77});
2058 Node *O = F_->createSave("ret", S);
2059
2060 EXPECT_EQ(F_->getNodes().size(), 3);
2061
2062 ::glow::optimize(F_, CompilationMode::Train);
2063
2064 EXPECT_EQ(F_->getNodes().size(), 2);
2065
2066 EXPECT_TRUE(llvm::isa<SaveNode>(O));
2067
2068 auto *CN = llvm::dyn_cast<SplatNode>(llvm::dyn_cast<SaveNode>(O)->getInput());
2069 ASSERT_TRUE(CN);
2070
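  // Slicing a splat produces a smaller splat whose dims are end - start:
  // {99 - 5, 88 - 15, 77 - 42} = {94, 73, 35}.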
2071 EXPECT_TRUE(CN->getResult().getType()->dims().equals({94, 73, 35}));
2072}
2073
2074/// Test Clip(Splat(args)) -> Splat(args').
2075TEST_F(GraphOptz, ClipOfSplatNode) {
2076 Type T(ElemKind::FloatTy, {10, 10});
2077 SplatNode *splat = F_->createSplat("zero", &T, 5);
2078 ClipNode *clipMin = F_->createClip("clip", splat, 10, 15);
2079 ClipNode *clipMax = F_->createClip("clip", splat, 0, 2);
2080 ClipNode *clipSame = F_->createClip("clip", splat, 0, 10);
2081 SaveNode *saveMin = F_->createSave("saveMin", clipMin);
2082 SaveNode *saveMax = F_->createSave("saveMax", clipMax);
2083 SaveNode *saveSame = F_->createSave("saveSame", clipSame);
2084
2085 // Start with one splat, three clips, three saves.
2086 EXPECT_EQ(F_->getNodes().size(), 7);
2087
2088 ::glow::optimize(F_, CompilationMode::Infer);
2089
2090 // We will end up with three Splats and three saves.
2091 EXPECT_EQ(F_->getNodes().size(), 6);
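  // Clipping a splat folds into a new splat with the clamped value:
  // clip(5, 10, 15) -> 10 and clip(5, 0, 2) -> 2, while clip(5, 0, 10) keeps
  // the original splat since 5 already lies inside the range.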
2092
2093 SplatNode *splatMin = llvm::dyn_cast<SplatNode>(saveMin->getInput());
2094 ASSERT_TRUE(splatMin);
2095 EXPECT_EQ(splatMin->getValue(), 10);
2096
2097 SplatNode *splatMax = llvm::dyn_cast<SplatNode>(saveMax->getInput());
2098 ASSERT_TRUE(splatMax);
2099 EXPECT_EQ(splatMax->getValue(), 2);
2100
2101 ASSERT_EQ(saveSame->getInput().getNode(), splat);
2102 EXPECT_EQ(splat->getValue(), 5);
2103}
2104
2105TEST_F(GraphOptz, ZeroArithmetic) {
2106 // Tests the identities: [0 + X = X] [0 * X = 0] [0 / X = 0] [ X - 0 = X]
2107
2108 auto *input =
2109 mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input", true);
2110
2111 // This builds the expression: ((0 / I) + (0 + I) + (0 * I)) - 0
2112
2113 auto *zero = F_->createSplat("zero", input->getType(), 0.);
2114
2115 auto *div = F_->createDiv("div", zero, input); // -> zero
2116
2117 auto *add = F_->createAdd("add", zero, input); // -> input
2118
2119 auto *mul = F_->createMul("mul", zero, input); // -> zero
2120
2121 auto *add3 = F_->createAdd("add", div, add);
2122
2123 add3 = F_->createAdd("add", add3, mul);
2124
2125 auto *sub = F_->createSub("sub", add3, zero); // -> input
2126
2127 SaveNode *O = F_->createSave("ret", sub);
2128
2129 // The expression evaluates to "I".
2130
2131 EXPECT_EQ(F_->getNodes().size(), 8);
2132
2133 ::glow::optimize(F_, CompilationMode::Infer);
2134
2135 EXPECT_EQ(F_->getNodes().size(), 1);
2136
2137 EXPECT_EQ(O->getInput().getNode(), input);
2138
2139 optimizedF_ = optimizeFunctionForTest(F_);
2140
2141 bindings_.allocate(mod_.getPlaceholders());
2142 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
2143
2144 checkNumericalEquivalence();
2145}
2146
// Similar to ZeroArithmetic, but tests that nodes with multiple results are
// correctly handled (i.e. that the correct output is selected after optimizing
// away an arithmetic identity).
2150TEST_F(GraphOptz, ZeroArithmeticMultiResNode) {
2151 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {10}, "input", true);
2152 auto *topK = F_->createTopK("topK", input, /*k=*/5);
2153 auto *zero = F_->createSplat("zero", topK->getValues().getType(), 0.);
2154 auto *add = F_->createAdd("add", topK->getValues(), zero);
2155 auto *sub = F_->createSub("sub", topK->getValues(), zero);
2156
2157 SaveNode *AS = F_->createSave("ret", add);
2158 SaveNode *SS = F_->createSave("ret", sub);
2159
2160 // There should be 6 nodes: 2 Saves, Add, Sub, Splat and TopK.
2161 EXPECT_EQ(F_->getNodes().size(), 6);
2162
2163 optimizedF_ = optimizeFunctionForTest(F_);
2164
2165 // Now there should only be 3 nodes: TopK and 2 Saves.
2166 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
2167
2168 auto *OAS = findFunctionNodeByName<SaveNode>(optimizedF_, AS->getName());
2169 auto *OSS = findFunctionNodeByName<SaveNode>(optimizedF_, SS->getName());
2170 auto *OTopK = findFunctionNodeByName<TopKNode>(optimizedF_, topK->getName());
2171
  // Since the operations represented by the arithmetic nodes are no-ops,
  // the input to both SaveNodes should be the Values result of TopKNode.
2174 EXPECT_EQ(OAS->getInput(), OTopK->getValues());
2175 EXPECT_EQ(OSS->getInput(), OTopK->getValues());
2176
2177 // Check numerical equivalence.
2178 bindings_.allocate(mod_.getPlaceholders());
2179 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
2180
2181 checkNumericalEquivalence();
2182}
2183
2184/// A test that verifies that arithmetic simplification works correctly when
2185/// the parents need to be simplified prior to the node itself.
2186TEST_F(GraphOptz, ZeroArithmeticParentsMustBeSimplifiedFirst) {
2187 auto *input1 =
2188 mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input1", true);
2189 auto *input2 =
2190 mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input2", true);
2191
2192 // This builds the expression: ((0 * I1) * (0 * I2)) = 0
2193 // It should be simplified to simply the splat zero node being saved.
2194
2195 SplatNode *zero = F_->createSplat("zero", input1->getType(), 0.);
2196
2197 MulNode *mul1 = F_->createMul("mul1", zero, input1); // -> 0
2198 MulNode *mul2 = F_->createMul("mul2", zero, input2); // -> 0
2199
2200 MulNode *mul3 = F_->createMul("mul3", mul1, mul2); // -> 0
2201
2202 SaveNode *O = F_->createSave("ret", mul3);
2203
2204 // Expect 1 splat, 3 muls, 1 save.
2205 EXPECT_EQ(F_->getNodes().size(), 5);
2206
2207 ::glow::optimize(F_, CompilationMode::Infer);
2208
2209 // Expect all muls to be optimized away, with 1 splat and 1 save left.
2210 EXPECT_EQ(F_->getNodes().size(), 2);
2211 EXPECT_TRUE(functionContainsNode(F_, O));
2212 EXPECT_TRUE(functionContainsNode(F_, zero));
2213 EXPECT_EQ(O->getInput().getNode(), zero);
2214}
2215
2216/// Tests opts for the identities: [1 * X = X] [X / 1 = X]
2217TEST_F(GraphOptz, ArithmeticIdentitiesOne) {
2218 auto *input =
2219 mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input", true);
2220
2221 // This builds the expression: (I / 1) * 1:
2222 SplatNode *one = F_->createSplat("one", input->getType(), 1.);
2223 DivNode *div = F_->createDiv("div", input, one);
2224 MulNode *mul = F_->createMul("mul", div, one);
2225 SaveNode *save = F_->createSave("ret", mul);
2226
2227 // Splat, Div, Mul, Save.
2228 EXPECT_EQ(F_->getNodes().size(), 4);
  // Save the optimized function for later comparison.
2230 optimizedF_ = optimizeFunctionForTest(F_);
2231
2232 // The expression evaluates to "I", so Save is only node left.
2233 EXPECT_EQ(optimizedF_->getNodes().size(), 1);
2234 SaveNode *SN =
2235 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName()));
2236 ASSERT_TRUE(functionContainsNode(optimizedF_, SN));
2237 ASSERT_NE(SN, nullptr);
2238
2239 // Save node should just save the input.
2240 EXPECT_TRUE(SN->getInput().getNode() == input);
2241
2242 bindings_.allocate(mod_.getPlaceholders());
2243 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
2244
2245 checkNumericalEquivalence();
2246}
2247
2248/// Reverse the intrusive list of nodes. This custom implementation is required,
2249/// because std::reverse cannot be used with LLVM's intrusive lists.
2250static void reverse(NodesList &L) {
2251 if (L.empty())
2252 return;
2253 // Last element of the list before reversal.
2254 auto &last = L.back();
2255 // Take element from the beginning and move it right after the old last
2256 // element. Do it until the old last element becomes the first element.
2257 while (true) {
2258 auto &first = L.front();
2259 // Finish when the old last element becomes the new front element.
2260 if (&first == &last) {
2261 break;
2262 }
2263 L.remove(first);
2264 L.insert(++last.getIterator(), &first);
2265 }
2266}
2267
2268TEST(GraphOptzTest, SliceOfSplatNodeChain) {
2269 for (int shouldReverse = 0; shouldReverse <= 1; shouldReverse++) {
2270 Module mod;
2271 Function *F = mod.createFunction("foo");
2272
2273 Type t(ElemKind::FloatTy, {1000, 1000, 1000});
2274 Node *Z = F->createSplat("zero", &t, 0.);
2275 Node *S1 = F->createSlice("slice1", Z, {5, 15, 42}, {99, 88, 77});
2276 Node *S2 = F->createSlice("slice2", S1, {1, 1, 1}, {2, 3, 4});
2277 F->createSave("ret", S2);
2278
2279 if (shouldReverse) {
2280 auto &nodes = F->getNodes();
2281 reverse(nodes);
2282 }
2283
2284 EXPECT_EQ(F->getNodes().size(), 4);
2285
2286 CompilationContext cctx;
2287 cctx.compMode = CompilationMode::Train;
2288 // Do not perform any compile-time constant folding.
2289 cctx.optimizationOpts.enableConstantFolding = false;
2290 ::glow::optimize(F, cctx);
2291
    // This test illustrates some inconsistency in the optimization:
    // chains of slices of a splat are not guaranteed to be fully optimized.
2294 EXPECT_EQ(F->getNodes().size(), shouldReverse ? 3 : 2);
2295 }
2296}
2297
2298TEST_F(GraphOptz, ReshapeNoop) {
2299 const dim_t shape[] = {10, 20, 30};
2300 Type t(ElemKind::FloatTy, shape);
2301 auto *Z = F_->createSplat("zero", &t, 0.);
2302 auto *R = F_->createReshape("reshape", Z, shape);
2303 auto *O = F_->createSave("ret", R);
2304
2305 EXPECT_EQ(F_->getNodes().size(), 3);
2306
2307 ::glow::optimize(F_, CompilationMode::Train);
2308
2309 EXPECT_EQ(F_->getNodes().size(), 2);
2310
2311 auto *SN = llvm::dyn_cast<SplatNode>(llvm::dyn_cast<SaveNode>(O)->getInput());
2312 EXPECT_TRUE(SN);
2313
2314 EXPECT_TRUE(SN->getResult().getType()->dims().equals(shape));
2315}
2316
2317/// Test the Reshape(Splat(args)) -> Splat(args') transformation.
2318/// Including a positive and a negative test case. In the positive case,
2319/// the optimization will take place for the splat node (Z2) that has only one
2320/// use. In the negative case, the optimization will not happen as the splat
2321/// node (Z1) has more than one use.
2322TEST_F(GraphOptz, ReshapeAfterSplat) {
2323 const dim_t shape[] = {10, 20, 30};
2324 const dim_t reshape[] = {1, 6000};
2325 Type t1(ElemKind::FloatTy, shape);
2326 Type t2(ElemKind::FloatTy, reshape);
2327 Node *input = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape,
2328 "input", true);
2329 auto *Z1 = F_->createSplat("zero1", &t1, 1.5);
2330 auto *A1 = F_->createAdd("add1", Z1->getResult().getType(), input, Z1);
2331 auto *R1 = F_->createReshape("reshape1", Z1, reshape);
2332 // Z1 is used by R1 and A1.
2333 // The reshape optimization will thus NOT be able to remove this reshape node
2334 // (R1).
2335 auto *R2 = F_->createReshape("reshape2", A1, reshape);
2336 auto *A2 = F_->createAdd("add", R1->getResult().getType(), R1, R2);
2337 auto *Z2 = F_->createSplat("zero2", &t1, 2.5);
2338 auto *R3 = F_->createReshape("reshape3", Z2, reshape);
2339 // Z2 is only used by R3.
2340 // The Z2,R3 nodes will be replaced by a new splat node with the shape of R3.
2341 auto *A3 = F_->createAdd("add", A2->getResult().getType(), A2, R3);
2342 auto *O = F_->createSave("ret", A3);
2343
2344 // Before optimization, we have 9 nodes in the graph.
2345 EXPECT_EQ(F_->getNodes().size(), 9);
2346
2347 cctx_.compMode = CompilationMode::Infer;
2348 // Do not perform any compile-time constant folding.
2349 cctx_.optimizationOpts.enableConstantFolding = false;
2350 ::glow::optimize(F_, cctx_);
2351
  // After optimization, we expect to see only 8 nodes, as Z2 and R3 are
  // replaced by a single new splat node.
2354 EXPECT_EQ(F_->getNodes().size(), 8);
2355
  // The second input of A3 should be a splat node with the shape of R3.
2357 auto *newA3 = llvm::dyn_cast<AddNode>(O->getInput());
2358 ASSERT_TRUE(newA3);
2359 auto *SN = llvm::dyn_cast<SplatNode>(newA3->getRHS());
2360 EXPECT_TRUE(SN);
2361 EXPECT_TRUE(SN->getResult().getType()->dims().equals(reshape));
2362
2363 // R1 should still be in the graph.
2364 EXPECT_TRUE(functionContainsNode(F_, R1));
2365
2366 // R3 and Z2 should not be in the graph any more.
2367 EXPECT_FALSE(functionContainsNode(F_, R3));
2368 EXPECT_FALSE(functionContainsNode(F_, Z2));
2369}
2370
2371/// Test the Reshape(Reshape(x)) -> Reshape(x) transformation.
2372TEST_F(GraphOptz, ReshapeReshapeOpt) {
2373 const dim_t shape[] = {10, 20};
2374 const dim_t reshape1[] = {200, 1};
2375 const dim_t reshape2[] = {200};
2376 Node *input = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape,
2377 "input", true);
2378 auto *R1 = F_->createReshape("reshape1", input, reshape1);
2379 auto *R2 = F_->createReshape("reshape2", R1, reshape2);
2380 auto *O = F_->createSave("ret", R2);
2381
2382 // Before optimization, we have 2 Reshapes and a Save.
2383 EXPECT_EQ(F_->getNodes().size(), 3);
2384
2385 ::glow::optimize(F_, CompilationMode::Infer);
2386
2387 // After optimization, we expect to see only 1 Reshape and a Save.
2388 EXPECT_EQ(F_->getNodes().size(), 2);
2389
2390 // Save should have the new Reshape as input.
2391 auto *RN = llvm::dyn_cast<ReshapeNode>(O->getInput());
2392 ASSERT_TRUE(RN);
2393 // The new Reshape should have the same shape as the original second Reshape.
2394 EXPECT_TRUE(RN->getResult().getType()->dims().equals(reshape2));
2395
2396 // R1 and R2 should not be in the graph any more; they were replaced by a
2397 // single new reshape.
2398 EXPECT_FALSE(functionContainsNode(F_, R1));
2399 EXPECT_FALSE(functionContainsNode(F_, R2));
2400}
2401
2402TEST_F(GraphOptz, DCEPublicVars) {
2403 mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
2404
2405 EXPECT_EQ(mod_.getPlaceholders().size(), 1);
2406
2407 // Optimize all of the dead code.
2408 ::glow::optimize(F_, CompilationMode::Infer);
2409
2410 // Public nodes should not be deleted.
2411 EXPECT_EQ(mod_.getPlaceholders().size(), 1);
2412}
2413
2414TEST_F(GraphOptz, foldQuantizeIntoConstant) {
2415 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {4}, "input", true);
2416 *bindings_.allocate(input) = {10, 10, 10, 10};
2417 auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2418
2419 auto *Q = F_->createQuantize("quantize", input, qType);
2420 auto *S = F_->createSave("save", Q);
2421
2422 EXPECT_EQ(2, F_->getNodes().size());
2423 ::glow::convertPlaceholdersToConstants(F_, bindings_, {S->getPlaceholder()});
2424
  // 'optimize' does not merge Quantize nodes into Constants.
2426 ::glow::optimize(F_, CompilationMode::Infer);
2427 EXPECT_EQ(2, F_->getNodes().size());
2428
  // 'convertQuantizedConstants' merges Quantize nodes into Constants.
2430 CompilationContext cctx;
2431 ::glow::convertQuantizedConstants(F_, cctx);
2432 EXPECT_EQ(1, F_->getNodes().size());
2433
2434 auto quantizedInput = llvm::cast<Constant>(S->getInput());
2435 auto quantizedValues = quantizedInput->getHandle<int8_t>();
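  // 10.0 quantized with scale 2 and offset 0 gives round(10 / 2) + 0 = 5.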
2436 for (unsigned i = 0; i < 4; ++i) {
2437 EXPECT_EQ(5, quantizedValues.raw(i));
2438 }
2439}
2440
2441TEST_F(GraphOptz, foldQuantizeIntoConstantMultipleUsages) {
2442 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {4}, "input", true);
2443 *bindings_.allocate(input) = {10, 10, 10, 10};
2444 auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2445
2446 auto *Q = F_->createQuantize("quantize", input, qType);
2447 F_->createSave("save", Q);
2448 auto clonedF = F_->clone("cloned");
2449
2450 EXPECT_EQ(2, clonedF->getNodes().size());
2451 ::glow::convertPlaceholdersToConstants(clonedF, bindings_, {});
2452 CompilationContext cctx;
2453 ::glow::convertQuantizedConstants(clonedF, cctx);
2454
2455 // F_ function should not be affected.
2456 EXPECT_EQ(2, F_->getNodes().size());
2457
2458 // Check original var.
2459 for (unsigned i = 0; i < 4; ++i) {
2460 EXPECT_EQ(10, bindings_.get(input)->getHandle().raw(i));
2461 }
2462
2463 // Quantization node was merged into input var.
2464 EXPECT_EQ(1, clonedF->getNodes().size());
2465 auto *save = llvm::dyn_cast<SaveNode>(&clonedF->getNodes().front());
2466 ASSERT_TRUE(save);
2467 auto quantizedInput = llvm::cast<Constant>(save->getInput());
2468 auto quantizedValues = quantizedInput->getHandle<int8_t>();
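  // 10.0 quantized with scale 2 and offset 0 gives round(10 / 2) + 0 = 5.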
2469 for (unsigned i = 0; i < 4; ++i) {
2470 EXPECT_EQ(5, quantizedValues.raw(i));
2471 }
2472}
2473
2474/// Search for a unique Save node in input graph \p F and return it.
2475/// Fails in case there is no Save node or more than one detected.
2476static SaveNode *getUniqueSaveNode(Function *F) {
2477 SaveNode *foundSaveNode = nullptr;
2478 for (auto &node : F->getNodes()) {
2479 if (auto *s = llvm::dyn_cast<SaveNode>(&node)) {
2480 EXPECT_EQ(foundSaveNode, nullptr);
2481 foundSaveNode = s;
2482 }
2483 }
2484 EXPECT_NE(foundSaveNode, nullptr);
2485 return foundSaveNode;
2486}
2487
2488/// Mock backend that requests the pre-quantization of constants.
2489class MockBackendPrequantizeConst : public MockBackend {
2490 bool shouldPreQuantizeConstants() const override { return true; }
2491 bool isOpSupported(const NodeInfo &) const override { return true; }
2492 Expected<bool>
2493 transformPostLowering(Function *F, CompilationContext &,
2494 const glow::runtime::DeviceInfo *) const override {
2495 // Check the IR.
2496 EXPECT_EQ(F->getNodes().size(), 1);
2497 auto *save = getUniqueSaveNode(F);
2498 EXPECT_TRUE(llvm::isa<Constant>(save->getInput()));
2499
2500 return false;
2501 }
2502};
2503/// Mock backend that requests the non pre-quantization of constants.
2504class MockBackendNotPrequantizeConst : public MockBackend {
2505 bool shouldPreQuantizeConstants() const override { return false; }
2506 bool isOpSupported(const NodeInfo &) const override { return true; }
2507 Expected<bool>
2508 transformPostLowering(Function *F, CompilationContext &,
2509 const glow::runtime::DeviceInfo *) const override {
2510 // Check the IR.
2511 EXPECT_EQ(F->getNodes().size(), 2);
2512 auto *save = getUniqueSaveNode(F);
2513 auto *quant = llvm::dyn_cast<QuantizeNode>(save->getInput());
2514 EXPECT_TRUE(quant);
2515 EXPECT_TRUE(llvm::isa<Constant>(quant->getInput()));
2516
2517 return false;
2518 }
2519};
2520
2521/// Test the actual constant quantization for backends.
2522template <typename Backend>
2523void testFoldQuantizeIntoConstant(Module &mod_, Function *F_) {
2524 auto *input = mod_.createConstant(ElemKind::FloatTy, {4}, "input");
2525 input->getHandle<float>() = {10, 10, 10, 10};
2526 auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2527 auto *Q = F_->createQuantize("quantize", input, qType);
2528 auto *save = F_->createSave("save", Q);
2529
2530 CompilationContext cctx;
2531 auto B = Backend();
  // Note: the check that the Quantize is (or is not) folded into the Constant
  // before post-lowering is done in <backend>::transformPostLowering().
2534 EXIT_ON_ERR(::glow::optimizeFunction(F_, B, cctx));
2535
2536 // Check the IR (the constant must have been quantized).
2537 EXPECT_EQ(F_->getNodes().size(), 1);
2538 EXPECT_TRUE(llvm::isa<Constant>(save->getInput()));
2539}
2540
/// Check that the backend's constant quantization is done before post-lowering.
2542TEST_F(GraphOptz, foldQuantizeIntoConstantBeforePostLowering) {
2543 testFoldQuantizeIntoConstant<MockBackendPrequantizeConst>(mod_, F_);
2544}
2545
/// Check that the backend's constant quantization is done after post-lowering.
2547TEST_F(GraphOptz, foldQuantizeIntoConstantAfterPostLowering) {
2548 testFoldQuantizeIntoConstant<MockBackendNotPrequantizeConst>(mod_, F_);
2549}
2550
2551/// Check that the Quantize(Splat) -> Splat' optimization works.
2552TEST_F(GraphOptz, foldQuantizeIntoSplat) {
2553 TypeRef fType = mod_.uniqueType(ElemKind::FloatTy, {4});
2554 TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2555
2556 const float splatVal = 6.0;
2557 SplatNode *SN = F_->createSplat("splat", fType, splatVal);
2558
2559 QuantizeNode *Q = F_->createQuantize("quantize", SN, qType);
2560 SaveNode *S = F_->createSave("save", Q);
2561
2562 // Splat, quantize, save.
2563 EXPECT_EQ(3, F_->getNodes().size());
2564
2565 ::glow::optimize(F_, CompilationMode::Infer);
2566
2567 // Quantization node was merged into input splat.
2568 EXPECT_EQ(2, F_->getNodes().size());
2569
2570 // New quantized splat should exist with same value.
2571 SplatNode *newSN = llvm::dyn_cast<SplatNode>(S->getInput());
2572 ASSERT_TRUE(newSN);
2573 EXPECT_EQ(splatVal, newSN->getValue());
2574 EXPECT_EQ(qType, newSN->getResult().getType());
2575}
2576
2577/// Check that the Dequantize(Splat) -> Splat' optimization works.
2578TEST_F(GraphOptz, foldDequantizeIntoSplat) {
2579 TypeRef fType = mod_.uniqueType(ElemKind::FloatTy, {4});
2580 TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2581
2582 const float splatVal = 6.0;
2583 SplatNode *SN = F_->createSplat("splat", qType, splatVal);
2584
2585 DequantizeNode *Q = F_->createDequantize("dequantize", SN, ElemKind::FloatTy);
2586 SaveNode *S = F_->createSave("save", Q);
2587
2588 // Splat, dequantize, save.
2589 EXPECT_EQ(3, F_->getNodes().size());
2590
2591 ::glow::optimize(F_, CompilationMode::Infer);
2592
2593 // Dequantization node was merged into input splat.
2594 EXPECT_EQ(2, F_->getNodes().size());
2595
2596 // New quantized splat should exist with same value.
2597 SplatNode *newSN = llvm::dyn_cast<SplatNode>(S->getInput());
2598 ASSERT_TRUE(newSN);
2599 EXPECT_EQ(splatVal, newSN->getValue());
2600 EXPECT_EQ(fType, newSN->getResult().getType());
2601}
2602
2603/// Check that the Quantize(Splat) -> Splat' optimization works when the Splat
2604/// has multiple users.
2605TEST_F(GraphOptz, foldQuantizeIntoSplatMultipleUsers) {
2606 TypeRef fType = mod_.uniqueType(ElemKind::FloatTy, {4});
2607 TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2608
2609 SplatNode *SN = F_->createSplat("splat", fType, 6.0);
2610
2611 QuantizeNode *Q = F_->createQuantize("quantize", SN, qType);
2612 SaveNode *SQ = F_->createSave("saveQ", Q);
2613 SaveNode *SF = F_->createSave("saveF", SN);
2614
2615 // Splat, quantize, 2 saves.
2616 EXPECT_EQ(4, F_->getNodes().size());
2617
2618 ::glow::optimize(F_, CompilationMode::Infer);
2619
2620 // Quantization node was merged into input splat creating a new quantized
2621 // splat, but the original float splat still exists.
2622 EXPECT_EQ(4, F_->getNodes().size());
2623
2624 // New quantized splat should exist with same value.
2625 SplatNode *newSN = llvm::dyn_cast<SplatNode>(SQ->getInput());
2626 ASSERT_TRUE(newSN);
2627 EXPECT_EQ(SN->getValue(), newSN->getValue());
2628 EXPECT_EQ(qType, newSN->getResult().getType());
2629
2630 // Original float splat should still exist.
2631 EXPECT_EQ(llvm::dyn_cast<SplatNode>(SF->getInput()), SN);
2632}
2633
2634/// Check that an unnecessary rescale gets removed.
2635TEST_F(GraphOptz, removeUnnecessaryRescale) {
2636 TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.03f, 5);
2637 Placeholder *input =
2638 mod_.createPlaceholder(qType, "input", /* isTrainable */ true);
2639 RescaleQuantizedNode *RQ =
2640 F_->createRescaleQuantized("rescale", input, qType);
2641 SaveNode *save = F_->createSave("ret", RQ);
2642
2643 // RescaleQuantized and Save.
2644 EXPECT_EQ(F_->getNodes().size(), 2);
2645
2646 ::glow::optimize(F_, CompilationMode::Infer);
2647
2648 // Only Save should be left, which saves the Placeholder directly with
2649 // unchanged quantization parameters.
2650 EXPECT_EQ(F_->getNodes().size(), 1);
2651 EXPECT_EQ(save->getInput().getNode(), input);
2652 EXPECT_EQ(save->getInput().getType(), qType);
2653}
2654
/// Check that a Rescale gets correctly merged into a following Dequantize node.
2656TEST_F(GraphOptz, mergeRescaleIntoDequantize) {
  // A Rescale feeding into a Dequantize should be merged into the Dequantize.
2658 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
2659 "input", true);
2660 auto *qType = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.03f, 5);
2661 auto *R = F_->createRescaleQuantized("rescale", input, qType);
2662 auto *D = F_->createDequantize("dequantize", R, ElemKind::FloatTy);
2663 F_->createSave("ret", D);
2664
2665 EXPECT_EQ(F_->getNodes().size(), 3);
2666 ::glow::optimize(F_, CompilationMode::Infer);
2667
2668 // Only 2 nodes should remain (Dequantize -> Save)
2669 EXPECT_EQ(F_->getNodes().size(), 2);
2670
2671 // Check the graph structure
2672 auto *SN = F_->getNodeByName("ret_save");
2673 EXPECT_NE(nullptr, SN);
2674 auto *S = llvm::dyn_cast<SaveNode>(SN);
2675 EXPECT_NE(nullptr, S);
2676 auto *newDN = S->getInput().getNode();
2677 EXPECT_NE(nullptr, newDN);
2678 EXPECT_NE(nullptr, llvm::dyn_cast<DequantizeNode>(newDN));
2679}
2680
2681TEST_F(GraphOptz, quantizeToRescale) {
2682 // Check that we are combining quantization-dequantization pairs.
2683 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
2684 "input", true);
2685
2686 auto *D = F_->createDequantize("dequantize", input, ElemKind::FloatTy);
2687
2688 auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.03, 5);
2689 auto *Q = F_->createQuantize("quantize", D, qType);
2690
2691 F_->createSave("ret", Q);
2692
2693 EXPECT_EQ(F_->getNodes().size(), 3);
2694
2695 ::glow::optimize(F_, CompilationMode::Infer);
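  // The Dequantize/Quantize pair with different quantization parameters is
  // expected to collapse into a single RescaleQuantized, leaving only the
  // Rescale and the Save.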
2696 EXPECT_EQ(F_->getNodes().size(), 2);
2697}
2698
2699TEST_F(GraphOptz, MaxOfQuantizedSplat) {
2700 const dim_t size = 5;
2701 const float scale = 1;
  // offset == -128 guarantees that the floating-point range contains no values
  // less than 0.
2704 const int32_t offset = -128;
2705
2706 auto splatTy = mod_.uniqueType(ElemKind::Int8QTy, {size}, scale, offset);
2707 auto *splat = F_->createSplat("splat", splatTy, 0.0);
2708
2709 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {size}, scale, offset,
2710 "input", true);
2711
2712 auto *max = F_->createMax("max", splat, input);
2713 F_->createSave("save", max);
2714 EXPECT_EQ(F_->getNodes().size(), 3);
2715
2716 ::glow::optimize(F_, CompilationMode::Infer);
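  // With this quantization every representable value is >= 0 in floating
  // point, so max(splat(0), input) is an identity.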
2717 // Splat and Max should be gone.
2718 EXPECT_EQ(F_->getNodes().size(), 1);
2719}
2720
2721TEST_F(GraphOptz, FuseRescaleIntoArithmetic) {
  // This test ensures that the Rescale nodes are fused into the arithmetic
  // operations.
2723 auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 1, 0);
2724 auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 2, 1);
2725
2726 Placeholder *LHS =
2727 mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.4, 0, "LHS", true);
2728 Placeholder *RHS =
2729 mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.3, 0, "RHS", true);
2730
2731 AddNode *add = F_->createAdd("qAdd", opOutTy, LHS, RHS);
2732 RescaleQuantizedNode *rescaleAdd =
2733 F_->createRescaleQuantized("rsAdd", add, rescaleOutTy);
2734 SaveNode *addSave = F_->createSave("saveAdd", rescaleAdd);
2735
2736 SubNode *sub = F_->createSub("qSub", opOutTy, LHS, RHS);
2737 RescaleQuantizedNode *rescaleSub =
2738 F_->createRescaleQuantized("rsSub", sub, rescaleOutTy);
2739 SaveNode *subSave = F_->createSave("saveSub", rescaleSub);
2740
2741 DivNode *div = F_->createDiv("qDiv", opOutTy, LHS, RHS);
2742 RescaleQuantizedNode *rescaleDiv =
2743 F_->createRescaleQuantized("rsDiv", div, rescaleOutTy);
2744 SaveNode *divSave = F_->createSave("saveDiv", rescaleDiv);
2745
2746 MulNode *mul = F_->createMul("qMul", opOutTy, LHS, RHS);
2747 RescaleQuantizedNode *rescaleMul =
2748 F_->createRescaleQuantized("rsMul", mul, rescaleOutTy);
2749 SaveNode *mulSave = F_->createSave("saveMul", rescaleMul);
2750
2751 MinNode *min = F_->createMin("qMin", opOutTy, LHS, RHS);
2752 RescaleQuantizedNode *rescaleMin =
2753 F_->createRescaleQuantized("rsMin", min, rescaleOutTy);
2754 SaveNode *minSave = F_->createSave("saveMin", rescaleMin);
2755
2756 MaxNode *max = F_->createMax("qMax", opOutTy, LHS, RHS);
2757 RescaleQuantizedNode *rescaleMax =
2758 F_->createRescaleQuantized("rsMax", max, rescaleOutTy);
2759 SaveNode *maxSave = F_->createSave("saveMax", rescaleMax);
2760
2761 // All rescales must be fused into arithmetic operations above.
2762 ::glow::optimize(F_, CompilationMode::Infer);
2763
2764 EXPECT_EQ(F_->getNodes().size(), 12);
2765
2766 EXPECT_EQ(addSave->getInput().getType(), rescaleOutTy);
2767 EXPECT_EQ(subSave->getInput().getType(), rescaleOutTy);
2768 EXPECT_EQ(mulSave->getInput().getType(), rescaleOutTy);
2769 EXPECT_EQ(divSave->getInput().getType(), rescaleOutTy);
2770 EXPECT_EQ(minSave->getInput().getType(), rescaleOutTy);
2771 EXPECT_EQ(maxSave->getInput().getType(), rescaleOutTy);
2772}
2773
2774/// Check that the Rescale(MatMul) -> MatMul' optimization works correctly.
2775TEST_F(GraphOptz, FuseRescaleUpIntoMatMul) {
  // This test ensures that the Rescale node is fused into the MatMul.
2777 auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 1, 0);
2778 auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 2, 1);
2779
2780 Placeholder *LHS = mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.4, 0,
2781 "LHS", /* isTrainable */ false);
2782 Placeholder *RHS = mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.3, 0,
2783 "RHS", /* isTrainable */ false);
2784
2785 MatMulNode *MMN = F_->createMatMul("matmul", opOutTy, LHS, RHS);
2786 RescaleQuantizedNode *rescaleMMN =
2787 F_->createRescaleQuantized("rsMMN", MMN, rescaleOutTy);
2788 SaveNode *saveMMN = F_->createSave("saveMMN", rescaleMMN);
2789
2790 // MatMul, Rescale, Save.
2791 EXPECT_EQ(F_->getNodes().size(), 3);
2792
  // The Rescale must be fused into the MatMul above.
2794 ::glow::optimize(F_, CompilationMode::Infer);
2795
2796 // Rescale merged up into the MatMul.
2797 EXPECT_EQ(F_->getNodes().size(), 2);
2798
2799 MatMulNode *newMMN = llvm::dyn_cast<MatMulNode>(saveMMN->getInput());
2800 ASSERT_TRUE(newMMN);
2801 EXPECT_EQ(newMMN->getResult().getType(), rescaleOutTy);
2802}
2803
2804/// Check that the Rescale(SparseLengthsWeightedSum) ->
2805/// SparseLengthsWeightedSum' optimization works correctly.
2806TEST_F(GraphOptz, FuseRescaleUpIntoSparseLengthsWeightedSum) {
  // This test ensures that the Rescale node is fused into the
  // SparseLengthsWeightedSum.
2808 TypeRef rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 1);
2809
2810 Placeholder *data =
2811 mod_.createPlaceholder(ElemKind::Int8QTy, {3}, 0.5, 0, "data",
2812 /* isTrainable */ false);
2813 Placeholder *weights = mod_.createPlaceholder(
2814 ElemKind::Int8QTy, {8}, 0.5, 0, "weights", /* isTrainable */ false);
2815 Placeholder *indices =
2816 mod_.createPlaceholder(ElemKind::Int64ITy, {8}, "indices",
2817 /* isTrainable */ false);
2818 Placeholder *lengths =
2819 mod_.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
2820 /* isTrainable */ false);
2821
2822 SparseLengthsWeightedSumNode *SLWS = F_->createSparseLengthsWeightedSum(
2823 "SLWS", data, weights, indices, lengths);
2824 RescaleQuantizedNode *rescaleSLWS =
2825 F_->createRescaleQuantized("rsSLWS", SLWS, rescaleOutTy);
2826 SaveNode *saveSLWS = F_->createSave("saveSLWS", rescaleSLWS);
2827
2828 // SparseLengthsWeightedSum, Rescale, Save.
2829 EXPECT_EQ(F_->getNodes().size(), 3);
2830
  // The Rescale must be fused into the SparseLengthsWeightedSum above.
2832 ::glow::optimize(F_, CompilationMode::Infer);
2833
2834 // Rescale merged up into the SparseLengthsWeightedSum.
2835 EXPECT_EQ(F_->getNodes().size(), 2);
2836
2837 SparseLengthsWeightedSumNode *newSLWS =
2838 llvm::dyn_cast<SparseLengthsWeightedSumNode>(saveSLWS->getInput());
2839 ASSERT_TRUE(newSLWS);
2840 EXPECT_EQ(newSLWS->getResult().getType(), rescaleOutTy);
2841}
2842
2843TEST_F(GraphOptz, fuseRescaleIntoConv) {
  // This test ensures that the Rescale nodes are fused into the Convolution.
2845 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {1, 10, 20, 3}, 0.5,
2846 10, "input", true);
2847 auto *filter = mod_.createPlaceholder(ElemKind::Int8QTy, {16, 5, 5, 3}, 0.5,
2848 10, "filter", true);
2849 auto *bias =
2850 mod_.createPlaceholder(ElemKind::Int8QTy, {16}, 0.5, 10, "bias", true);
2851
2852 auto *rInput = F_->createRescaleQuantized(
2853 "rescale", input,
2854 mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 3}, 0.1, -25));
2855 auto *rFilter = F_->createRescaleQuantized(
2856 "rescale", filter,
2857 mod_.uniqueType(ElemKind::Int8QTy, {16, 5, 5, 3}, 0.2, 0));
2858 auto *rBias = F_->createRescaleQuantized(
2859 "rescale", bias, mod_.uniqueType(ElemKind::Int8QTy, {16}, 0.3, 25));
2860 auto *CV = F_->createConv(
2861 "conv", rInput, rFilter, rBias,
2862 mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 16}, 0.7, -3), 5, 1, 2, 1);
2863 auto *rCV = F_->createRescaleQuantized(
2864 "rescale", CV,
2865 mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 16}, 0.4, 37));
2866 F_->createSave("save", rCV);
2867
2868 // All rescales must be fused into convolution.
2869 EXPECT_EQ(F_->getNodes().size(), 6);
2870 ::glow::optimize(F_, CompilationMode::Infer);
2871 EXPECT_EQ(F_->getNodes().size(), 2);
2872}
2873
2874/// This test ensures that if there is a Pad node as input of a Convolution
/// node, the Pad gets merged into the Convolution.
/// Note that the Pad is merged into the Convolution only when it is compatible
/// with the convolution padding:
2878/// - Resulting padding after merge is positive
2879/// - Padding only concerns spatial dimensions
2880/// - Padding has mode 'constant' with value 0.f
2881void fusePadIntoConvTest(glow::Module &mod_, glow::Function *F_,
2882 llvm::ArrayRef<dim_t> inputDims,
2883 llvm::ArrayRef<int> pads, unsigned_t convKernelSize,
2884 llvm::ArrayRef<unsigned_t> convPads,
2885 unsigned_t convStride, unsigned_t convNumKernels) {
2886 auto *input =
2887 mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", true);
2888
2889 // Pad
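  // The \p pads array holds the per-dimension "before" paddings {N, H, W, C}
  // followed by the per-dimension "after" paddings {N, H, W, C}.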
2890 dim_t inputWithPadDims[4];
2891 for (int i = 0; i < 4; i++) {
2892 inputWithPadDims[i] = dim_t(ssize_t(inputDims[i]) + pads[i] + pads[4 + i]);
2893 }
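  // Expected convolution output dims: for the stride-1 convolutions used by
  // the callers, each spatial size is the padded input size plus the conv
  // pads minus (convKernelSize - 1).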
2894 dim_t outputConvDims[4] = {
2895 inputWithPadDims[0],
2896 inputWithPadDims[1] + convPads[0] + convPads[2] - (convKernelSize - 1),
2897 inputWithPadDims[2] + convPads[1] + convPads[3] - (convKernelSize - 1),
2898 convNumKernels};
2899
2900 auto outTy = mod_.uniqueType(ElemKind::FloatTy, inputWithPadDims);
2901 Node *P =
2902 F_->createPad("pad", input, outTy, PaddingMode::CONSTANT, pads, 0.f);
2903
2904 // Convolution
2905 dim_t filterDims[] = {convNumKernels, convKernelSize, convKernelSize,
2906 inputDims[3]};
2907 auto *F =
2908 mod_.createPlaceholder(ElemKind::FloatTy, filterDims, "filter", true);
2909 auto *B =
2910 mod_.createPlaceholder(ElemKind::FloatTy, {convNumKernels}, "bias", true);
2911 auto *CV = F_->createConv(
2912 "conv", P, F, B, mod_.uniqueType(ElemKind::FloatTy, outputConvDims),
2913 {convKernelSize, convKernelSize}, {convStride, convStride}, convPads, 1);
2914
2915 SaveNode *O = F_->createSave("save", CV);
2916
2917 // The pad node must be merged into convolution.
2918 EXPECT_EQ(F_->getNodes().size(), 3);
2919 ::glow::optimize(F_, CompilationMode::Infer);
2920 EXPECT_EQ(F_->getNodes().size(), 2);
2921
2922 // Check the graph structure and additional properties after optimization.
2923 auto *conv = llvm::dyn_cast<ConvolutionNode>(O->getInput());
2924 ASSERT_NE(conv, nullptr);
2925 EXPECT_EQ(conv->getResult().dims(), llvm::ArrayRef<dim_t>(outputConvDims));
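  // After the merge, the convolution pads {top, left, bottom, right} must have
  // absorbed the spatial (H/W) entries of the Pad node's before/after paddings.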
2926 unsigned_t expectedPads[4];
2927 for (int i = 0; i < 2; i++) {
2928 for (int j = 0; j < 2; j++) {
2929 expectedPads[2 * i + j] =
2930 unsigned_t(int(convPads[2 * i + j]) + pads[4 * i + (1 + j)]);
2931 }
2932 }
2933 EXPECT_EQ(conv->getPads(), llvm::makeArrayRef(expectedPads));
2934}
2935
2936TEST_F(GraphOptz, fusePadIntoConv) {
2937 fusePadIntoConvTest(mod_, F_, {1, 6, 14, 3} /* inputDims */,
2938 {0, 1, 2, 0, 0, 3, 4, 0} /* pads */,
2939 5 /* convKernelSize */, {0, 0, 0, 0} /* convPads */,
2940 1 /* convStride */, 16 /* convNumKernels */);
2941}
2942
2943TEST_F(GraphOptz, fusePadIntoConvNeg1) {
2944 fusePadIntoConvTest(mod_, F_, {1, 6, 14, 3} /* inputDims */,
2945 {0, -1, 2, 0, 0, 3, -2, 0} /* pads */,
2946 5 /* convKernelSize */, {3, 0, 2, 5} /* convPads */,
2947 1 /* convStride */, 16 /* convNumKernels */);
2948}
2949
2950TEST_F(GraphOptz, fusePadIntoConvNeg2) {
2951 fusePadIntoConvTest(mod_, F_, {1, 6, 14, 3} /* inputDims */,
2952 {0, 1, -2, 0, 0, -3, 4, 0} /* pads */,
2953 5 /* convKernelSize */, {0, 2, 5, 7} /* convPads */,
2954 1 /* convStride */, 16 /* convNumKernels */);
2955}
2956
/// This test checks that a lowered LeakyRelu is correctly folded:
2958/// Max(A, Mult(A, Splat)) -> PRelu(Splat)
2959TEST_F(GraphFold, foldLeakyReluFromSplat) {
2960 std::vector<dim_t> dims = {5, 2};
2961
2962 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input", true);
2963
2964 const float leakyAlpha = 0.05f;
2965 auto OutTy = mod_.uniqueType(ElemKind::FloatTy, dims);
2966 SplatNode *splatNode = F_->createSplat("splat", OutTy, leakyAlpha);
2967 MulNode *mulNode = F_->createMul("mul", input, splatNode);
2968 MaxNode *maxNode = F_->createMax("max", input, mulNode);
2969 SaveNode *output = F_->createSave("save", maxNode);
2970
2971 EXPECT_EQ(4, F_->getNodes().size());
2972
2973 CompilationContext cctx;
2974 ::glow::fold(F_, cctx);
2975
2976 // Check the resulting graph after folding.
2977 EXPECT_EQ(3, F_->getNodes().size());
2978 auto *newPReluNode = llvm::dyn_cast<PReluNode>(output->getInput());
2979 ASSERT_TRUE(newPReluNode);
2980 auto *newSplatNode = llvm::dyn_cast<SplatNode>(newPReluNode->getSlope());
2981 ASSERT_TRUE(newSplatNode);
2982 EXPECT_EQ(leakyAlpha, newSplatNode->getValue());
2983 EXPECT_EQ(input, newPReluNode->getInput());
2984}
2985
/// This test checks that a lowered LeakyRelu is correctly folded:
2987/// Max(A, Mult(A, broadcasted Const)) -> PRelu(Splat)
2988TEST_F(GraphFold, foldLeakyReluFromConst) {
2989 std::vector<dim_t> dims = {5, 2};
2990 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input", true);
2991
2992 const float leakyAlpha = 0.99f;
2993 auto *alphaConst = mod_.createConstant(ElemKind::FloatTy, {1}, "alphaConst");
2994 alphaConst->getHandle() = {leakyAlpha};
2995 ReshapeNode *reshapeNode = F_->createReshape("reshape", alphaConst, {1, 1});
2996 TileNode *tileNode1 = F_->createTile("tile1", reshapeNode, 2, 1);
2997 TileNode *tileNode2 = F_->createTile("tile2", tileNode1, 5, 0);
2998 MulNode *mulNode = F_->createMul("mul", input, tileNode2);
2999 MaxNode *maxNode = F_->createMax("max", input, mulNode);
3000 SaveNode *output = F_->createSave("save", maxNode);
3001
3002 EXPECT_EQ(6, F_->getNodes().size());
3003
3004 CompilationContext cctx;
3005 ::glow::fold(F_, cctx);
3006
3007 // Check the resulting graph after folding. Reshape must have been merged into
3008 // the constant and LeakyRelu must have been folded.
3009 EXPECT_EQ(3, F_->getNodes().size());
3010 auto *newPReluNode = llvm::dyn_cast<PReluNode>(output->getInput());
3011 ASSERT_TRUE(newPReluNode);
3012 auto *newSplatNode = llvm::dyn_cast<SplatNode>(newPReluNode->getSlope());
3013 ASSERT_TRUE(newSplatNode);
3014 EXPECT_EQ(leakyAlpha, newSplatNode->getValue());
3015 EXPECT_EQ(input, newPReluNode->getInput());
3016}
3017
3018/// Test optimization of Convolution nodes with small input tensors by reducing
3019/// filters and removing redundant padding.
3020TEST_F(GraphFold, optimizeSmallConv) {
3021 auto *input =
3022 mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 2, 16}, "input", true);
3023 auto filter =
3024 mod_.createConstant(ElemKind::FloatTy, {16, 5, 5, 16}, "filter");
3025 auto bias = mod_.createConstant(ElemKind::FloatTy, {16}, "bias");
3026
3027 filter->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
3028 bias->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
3029
3030 auto *outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 1, 1, 16});
3031 auto *CN = F_->createConv("conv", input, filter, bias, outTy, {5, 5}, {2, 2},
3032 {2, 1, 1, 2}, 1);
3033 auto *save = F_->createSave("save", CN);
3034
3035 EXPECT_EQ(2, F_->getNodes().size());
3036 optimizedF_ = optimizeFunctionForTest(F_);
3037 EXPECT_EQ(2, optimizedF_->getNodes().size());
3038
3039 const auto *optSave =
3040 findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
3041
3042 auto *newCN = llvm::dyn_cast<ConvolutionNode>(optSave->getInput());
3043 ASSERT_TRUE(newCN);
3044 // Kernel should be reduced.
3045 EXPECT_TRUE(isUniformArray(newCN->getKernels(), 2u));
3046 // Padding should be removed.
3047 EXPECT_TRUE(isUniformArray(newCN->getPads(), 0u));
3048 // Stride should be canonicalized to 1.
3049 EXPECT_TRUE(isUniformArray(newCN->getStrides(), 1u));
3050
3051 bindings_.allocate(mod_.getPlaceholders());
3052 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
3053 checkNumericalEquivalence();
3054}
3055
3056TEST_F(GraphOptz, GatherToSliceOpt) {
3057 auto *LHS = mod_.createPlaceholder(ElemKind::Int32ITy, {16, 3}, "LHS", false);
3058 auto *RHS = mod_.createConstant(ElemKind::Int32ITy, {}, "RHS");
3059 RHS->getPayloadMutable().getHandle<int32_t>() = {1};
3060
3061 auto *gather = F_->createGather("gather", LHS, RHS, 1);
3062 auto *save = F_->createSave("save", gather);
3063
3064 optimizedF_ = optimizeFunctionForTest(F_);
3065
3066 auto *saveOpt =
3067 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName()));
3068 ASSERT_TRUE(saveOpt);
3069 auto *reshapeN = llvm::dyn_cast<ReshapeNode>(saveOpt->getInput());
3070 ASSERT_TRUE(reshapeN);
3071 EXPECT_EQ(reshapeN->getResult().dims().size(), 1);
3072 EXPECT_EQ(reshapeN->getResult().dims()[0], 16);
3073
3074 bindings_.allocate(LHS)->getHandle<int32_t>().randomize(-128, 127,
3075 mod_.getPRNG());
3076 checkNumericalEquivalence();
3077}
3078
3079/// Fold a Convolution dilated manually using Transpose, SpaceToDepth and
3080/// DepthToSpace nodes into a single Convolution node. Pattern:
3081/// NHWC2CHWN -> S2D -> CHWN2NHWC -> Conv -> NHWC2CHWN -> D2S -> CHWN2NHWC
3082TEST_F(GraphFold, foldDilatedConv) {
3083 auto *input =
3084 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 10, 16}, "input", true);
3085
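  // Manually emulate a dilated convolution: the transposes plus
  // SpaceToDepth/DepthToSpace rearrange the data so that the stride-1
  // convolution below acts like a convolution with dilation 2 on the
  // original layout, which is what the fold should recover.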
3086 auto *T1 = F_->createTranspose("t1", input, NHWC2CHWN, "NHWC");
3087 auto *S2D = F_->createSpaceToDepth("s2d", T1, 2);
3088 auto *T2 = F_->createTranspose("t2", S2D, CHWN2NHWC, "NHWC");
3089 auto *CN = F_->createConv(bindings_, "conv", T2, 16, 3, 1, 0, 16, {1, 1});
3090 auto *T3 = F_->createTranspose("t3", CN, NHWC2CHWN, "NHWC");
3091 auto *D2S = F_->createDepthToSpace("d2s", T3, 2);
3092 auto *T4 = F_->createTranspose("t4", D2S, CHWN2NHWC, "NHWC");
3093 auto *save = F_->createSave("save", T4);
3094
3095 // To spice things up, add additional users for some nodes. The pattern should
3096 // still be recognized.
3097 F_->createSave("save_t1", T1);
3098 F_->createSave("save_s2d", S2D);
3099 F_->createSave("save_t2", T2);
3100
3101 EXPECT_EQ(13, F_->getNodes().size());
3102 optimizedF_ = optimizeFunctionForTest(F_);
3103 EXPECT_EQ(8, optimizedF_->getNodes().size());
3104
3105 const auto *optSave =
3106 findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
3107
3108 auto *newCN = llvm::dyn_cast<ConvolutionNode>(optSave->getInput());
3109 ASSERT_TRUE(newCN);
3110 EXPECT_TRUE(isUniformArray(newCN->getDilation(), 2u));
3111
3112 bindings_.allocate(mod_.getPlaceholders());
3113 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
3114 checkNumericalEquivalence();
3115}
3116
3117/// Fold a Convolution dilated manually using Transpose, SpaceToDepth and
3118/// DepthToSpace nodes into a single Convolution node. Pattern:
3119/// NHWC2CHWN -> S2D -> CHWN2NHWC -> Conv -> NHWC2CHWN -> D2S -> CHWN2NHWC
3120/// Test for ChannelwiseQuantizedConvolution.
3121TEST_F(GraphFold, foldDilatedConv_ChannelwiseQuantized) {
3122 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {1, 10, 10, 16}, 1.f,
3123 0, "input", true);
3124
3125 auto *filterF =
3126 mod_.createConstant(ElemKind::FloatTy, {16, 3, 3, 16}, "filterF");
3127 filterF->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0,
3128 mod_.getPRNG());
3129 auto *biasF = mod_.createConstant(ElemKind::FloatTy, {16}, "biasF");
3130 biasF->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0,
3131 mod_.getPRNG());
3132
3133 auto *T1 = F_->createTranspose("t1", input, NHWC2CHWN, "NHWC");
3134 auto *S2D = F_->createSpaceToDepth("s2d", T1, 2);
3135 auto *T2 = F_->createTranspose("t2", S2D, CHWN2NHWC, "NHWC");
3136 auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {4, 3, 3, 16}, 1.f, 0);
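  // The nullptr arguments leave the per-channel filter/bias scales and offsets
  // unspecified so they are computed when the float filter and bias are
  // quantized (quantizeFilter/quantizeBias are true).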
3137 auto *CN = F_->createChannelwiseQuantizedConv(
3138 "conv", T2, filterF, biasF, nullptr, nullptr, nullptr, nullptr, outTy,
3139 {3, 3}, {1, 1}, {0, 0, 0, 0}, 1, {1, 1}, true, true,
3140 quantization::Schema::Asymmetric, ElemKind::Int8QTy, ElemKind::Int32QTy);
3141 auto *T3 = F_->createTranspose("t3", CN, NHWC2CHWN, "NHWC");
3142 auto *D2S = F_->createDepthToSpace("d2s", T3, 2);
3143 auto *T4 = F_->createTranspose("t4", D2S, CHWN2NHWC, "NHWC");
3144 auto *save = F_->createSave("save", T4);
3145
3146 EXPECT_EQ(10, F_->getNodes().size());
3147 optimizedF_ = optimizeFunctionForTest(F_);
3148 EXPECT_EQ(2, optimizedF_->getNodes().size());
3149
3150 const auto *optSave =
3151 findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
3152
3153 auto *newCN =
3154 llvm::dyn_cast<ChannelwiseQuantizedConvolutionNode>(optSave->getInput());
3155 ASSERT_TRUE(newCN);
3156 EXPECT_TRUE(isUniformArray(newCN->getDilation(), 2u));
3157
3158 bindings_.allocate(mod_.getPlaceholders());
3159 bindings_.get(input)->getHandle<int8_t>().randomize(-128, 127,
3160 mod_.getPRNG());
3161 checkNumericalEquivalence();
3162}
3163
3164/// Testing folding of Reshape->Transpose->Reshape into ChannelShuffle.
3165TEST_F(GraphFold, foldChannelShuffle) {
3166 const dim_t inputDims[] = {3, 136, 28, 28};
3167
3168 Node *K =
3169 mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
3170 K = F_->createReshape("CS_reshape1", K, {3, 4, 34, 28, 28});
3171 K = F_->createTranspose("CS_transpose", K, {0, 2, 1, 3, 4});
3172 K = F_->createReshape("CS_reshape2", K, {3, 136, 28, 28});
3173 auto *save = F_->createSave("ret", K);
3174
3175 EXPECT_EQ(F_->getNodes().size(), 4);
3176
3177 // Fold RN->TR->RN into ChannelShuffle
3178 CompilationContext cctx;
3179 ::glow::fold(F_, cctx);
3180
3181 ASSERT_EQ(F_->getNodes().size(), 2);
3182
3183 // Check for ChannelShuffle node.
3184 auto *CS = llvm::dyn_cast<ChannelShuffleNode>(save->getInput().getNode());
3185 ASSERT_NE(nullptr, CS);
3186
3187 // Ensure ChannelShuffle node has the same dimensions as the input.
3188 EXPECT_EQ(CS->getResult().dims(), llvm::makeArrayRef(inputDims));
3189
3190 // Ensure Group and Kernel are as expected.
3191 EXPECT_EQ(CS->getGroup(), 4);
3192 EXPECT_EQ(CS->getKernel(), 1);
3193}
3194
3195TEST_F(GraphFold, NoFoldChannelShuffle) {
3196 auto Float = ElemKind::FloatTy;
3197 auto *P = mod_.createPlaceholder(Float, {10, 8928}, "P", false);
3198 auto *R1 = F_->createReshape("R1", P, {10, 186, 48});
3199 auto *TR = F_->createTranspose("TR", R1, {0, 2, 1});
3200 auto *R2 = F_->createReshape("R2", TR, {480, 186});
3201 auto *save = F_->createSave("save", R2);
3202
3203 EXPECT_EQ(F_->getNodes().size(), 4);
3204
3205 CompilationContext cctx;
3206 ::glow::fold(F_, cctx);
3207
3208 EXPECT_EQ(F_->getNodes().size(), 4);
3209 EXPECT_FALSE(llvm::isa<ChannelShuffleNode>(save->getInput()));
3210}
3211
3212class MockBackendWithFusion : public MockBackend {
3213 bool supportsFusedActivation(Node *parent, Node *activation) const override {
3214 switch (parent->getKind()) {
3215 case Kinded::Kind::ConvolutionNodeKind:
3216 switch (activation->getKind()) {
3217 case Kinded::Kind::ReluNodeKind:
3218 case Kinded::Kind::ClipNodeKind:
3219 case Kinded::Kind::SigmoidNodeKind:
3220 case Kinded::Kind::TanhNodeKind:
3221 case Kinded::Kind::LeakyReluNodeKind:
3222 return true;
3223 default:
3224 return false;
3225 }
3226 default:
3227 return false;
3228 }
3229 }
3230};
3231
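/// Builds Conv -> ACTIVATION_ -> Save, runs ::glow::fold with a backend that
/// supports fused activations, and checks that the activation was fused into
/// the Convolution.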
3232#define CONV_ACTIVATION_TEST(ACTIVATION_, CREATOR_, ...) \
3233 TEST_F(GraphFold, FoldConv##ACTIVATION_##Activation) { \
3234 auto *A = \
3235 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false); \
3236 ConvolutionNode *CV = \
3237 F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1); \
3238 auto *AN = F_->CREATOR_(__VA_ARGS__); \
3239 SaveNode *SN = F_->createSave("ret", AN); \
3240 \
3241 EXPECT_EQ(F_->getNodes().size(), 3); \
3242 \
3243 CompilationContext cctx; \
3244 auto B = MockBackendWithFusion(); \
3245 ::glow::fold(F_, cctx, &B); \
3246 \
3247 ConvolutionNode *fusedCV = \
3248 llvm::dyn_cast<ConvolutionNode>(SN->getInput()); \
3249 ASSERT_TRUE(fusedCV); \
3250 EXPECT_EQ(fusedCV->getFusedActivation(), FusedActivation::ACTIVATION_); \
3251 }
3252
3253CONV_ACTIVATION_TEST(RELU, createRELU, "Relu", CV);
3254CONV_ACTIVATION_TEST(CLIP, createClip, "Clip", CV, 0.0, 1.0);
3255CONV_ACTIVATION_TEST(SIGMOID, createSigmoid, "Sigmoid", CV);
3256CONV_ACTIVATION_TEST(TANH, createTanh, "Tanh", CV);
3257CONV_ACTIVATION_TEST(LEAKY_RELU, createLeakyRELU, "LeakyRelu", CV, 1.0);
3258
3259#undef CONV_ACTIVATION_TEST
3260
/// This test ensures that if there is a RescaleNode whose input has multiple
/// users, the input is not cloned, as this would duplicate the node.
3263TEST_F(GraphOptz, MultipleUsersRescaleCombineNoOpt) {
3264 auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 1, 0);
3265 auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 2, 1);
3266
3267 Node *LHS =
3268 mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.4, 0, "LHS", true);
3269 Node *RHS =
3270 mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.3, 0, "RHS", true);
3271
3272 AddNode *AN = F_->createAdd("qAdd", opOutTy, LHS, RHS);
3273 RescaleQuantizedNode *RQN =
3274 F_->createRescaleQuantized("rsAdd", AN, rescaleOutTy);
3275 SaveNode *saveRQN = F_->createSave("saveRQN", RQN);
3276 SaveNode *saveAN = F_->createSave("saveAN", AN);
3277
3278 EXPECT_EQ(F_->getNodes().size(), 4);
3279
3280 ::glow::optimize(F_, CompilationMode::Infer);
3281
3282 // The graph should be unchanged.
3283 EXPECT_EQ(F_->getNodes().size(), 4);
3284 EXPECT_EQ(saveRQN->getInput().getNode(), RQN);
3285 EXPECT_EQ(RQN->getInput().getNode(), AN);
3286 EXPECT_EQ(saveAN->getInput().getNode(), AN);
3287 EXPECT_EQ(AN->getLHS().getNode(), LHS);
3288 EXPECT_EQ(AN->getRHS().getNode(), RHS);
3289}
3290
3291/// This test ensures that fusing of rescale into MatMul is done.
3292TEST_F(GraphOptz, FuseRescaleIntoMatMul) {
3293 auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 1, 0);
3294 auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 2, 1);
3295
3296 Placeholder *LHS =
3297 mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.4, 0, "LHS", true);
3298 Placeholder *RHS =
3299 mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.3, 0, "RHS", true);
3300
3301 RescaleQuantizedNode *LHSR =
3302 F_->createRescaleQuantized("rs1", LHS, rescaleOutTy);
3303 RescaleQuantizedNode *RHSR =
3304 F_->createRescaleQuantized("rs2", RHS, rescaleOutTy);
3305 MatMulNode *MN = F_->createMatMul("qMatMul", opOutTy, LHSR, RHSR);
3306 SaveNode *SN = F_->createSave("save", MN);
3307
  // Both Rescales must be fused into the MatMul above.
3309 ::glow::optimize(F_, CompilationMode::Infer);
3310
3311 // Only the MatMul and Save should be left.
3312 EXPECT_EQ(F_->getNodes().size(), 2);
3313 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::RescaleQuantizedNodeKind), 0);
3314
3315 MatMulNode *newMN = llvm::dyn_cast<MatMulNode>(SN->getInput());
3316 ASSERT_TRUE(newMN);
3317 Placeholder *LPH = llvm::dyn_cast<Placeholder>(newMN->getLHS());
3318 EXPECT_EQ(LPH, LHS);
3319 Placeholder *RPH = llvm::dyn_cast<Placeholder>(newMN->getRHS());
3320 EXPECT_EQ(RPH, RHS);
3321}
3322
3323TEST_F(GraphOptz, sinkRescaledQuantizedNode) {
3324 // Check that we eliminate rescale nodes by sinking them into other
3325 // operators.
3326 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
3327 "input", true);
3328
3329 // slice -> rescale -> reshape -> rescale -> transpose -> maxpool -> save.
3330 auto *slice = F_->createSlice("slice", input, {0, 0}, {2, 4});
3331 auto *rescale = F_->createRescaleQuantized(
3332 "rescale", slice, mod_.uniqueType(ElemKind::Int8QTy, {2, 4}, 0.4, 10));
3333 auto *reshape = F_->createReshape("reshape", rescale, {1, 2, 2, 2});
3334 auto *rescale2 = F_->createRescaleQuantized(
3335 "rescale", reshape,
3336 mod_.uniqueType(ElemKind::Int8QTy, {1, 2, 2, 2}, 0.3, 9));
3337 auto *transpose = F_->createTranspose("transpose", rescale2, {0, 2, 3, 1});
3338 auto *maxpool =
3339 F_->createMaxPool("maxpool", transpose, {2, 2}, {1, 1}, {0, 0, 0, 0});
3340 auto *save = F_->createSave("ret", maxpool->getResult());
3341
3342 EXPECT_EQ(F_->getNodes().size(), 7);
3343 ::glow::optimize(F_, CompilationMode::Infer);
3344 EXPECT_EQ(F_->getNodes().size(), 6);
3345 // Check that rescale sank all the way down to the save node.
3346 EXPECT_TRUE(llvm::dyn_cast<RescaleQuantizedNode>(save->getInput()));
3347}
3348
3349TEST_F(GraphOptz, mergeRescaleWithArithmeticNode) {
3350 // Check that Arithmetic operations can be merged with the Rescale.
3351 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
3352 "input", true);
3353
3354 auto *rescale1 = F_->createRescaleQuantized(
3355 "rescale", input, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.4, 11));
3356 auto *add = F_->createAdd("add", rescale1, rescale1);
3357 auto *rescale2 = F_->createRescaleQuantized(
3358 "rescale", add, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.3, 11));
3359 auto *sub = F_->createSub("sub", rescale2, rescale2);
3360 auto *rescale3 = F_->createRescaleQuantized(
3361 "rescale", sub, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.2, 11));
3362 auto *mul = F_->createMul("mul", rescale3, rescale3);
3363 auto *rescale4 = F_->createRescaleQuantized(
3364 "rescale", mul, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.1, 11));
3365 auto *div = F_->createDiv("div", rescale4, rescale4);
3366 F_->createSave("save", div);
3367
3368 EXPECT_EQ(F_->getNodes().size(), 9);
3369 ::glow::optimize(F_, CompilationMode::Infer);
3370 EXPECT_EQ(F_->getNodes().size(), 5);
3371}
3372
3373/// Check that Relu can be merged with Rescale.
3374TEST_F(GraphOptz, mergeRescaleWithRelu) {
3375 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
3376 "input", false);
3377
3378 auto *rescale1 = F_->createRescaleQuantized(
3379 "rescale", input, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.4, 11));
3380 auto *relu = F_->createRELU("relu", rescale1);
3381 F_->createSave("save", relu);
3382
3383 // Rescale, RELU, Save nodes.
3384 EXPECT_EQ(F_->getNodes().size(), 3);
3385
3386 ::glow::optimize(F_, CompilationMode::Infer);
3387
3388 // RELU, Save nodes left; Rescale merged into RELU.
3389 EXPECT_EQ(F_->getNodes().size(), 2);
3390 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::RescaleQuantizedNodeKind), 0);
3391 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 1);
3392}
3393
3394// Check that we are able to merge some small matmuls into a larger one.
3395TEST_F(GraphOptz, mergeMatMulNodes) {
3396 Node *input =
3397 mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input", true);
3398 Node *weight =
3399 mod_.createPlaceholder(ElemKind::FloatTy, {10, 10}, "weight", true);
3400
3401 // Split the input to a bunch of small slices.
3402 std::vector<NodeValue> inputs;
3403 for (dim_t i = 0; i < 10; i++) {
3404 auto *K = F_->createSlice("extract", input, {i, 0, 0}, {i + 1, 10, 10});
3405 auto *R = F_->createReshape("reshape", K, {10, 10});
3406 auto *MM = F_->createMatMul("mm", R, weight);
3407 inputs.push_back(MM);
3408 }
3409
3410 auto *cc = F_->createConcat("merge", inputs, 0);
3411 F_->createSave("save", cc);
3412
3413 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 10);
3414 ::glow::optimize(F_, CompilationMode::Infer);
3415
3416 // Check that all of the matmuls are merged into a single matmul node.
3417 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 1);
3418}
3419
3420// Check that we are able to merge batched adds.
3421TEST_F(GraphOptz, mergeBANodes) {
3422 Node *input =
3423 mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input", true);
3424 Node *slice =
3425 mod_.createPlaceholder(ElemKind::FloatTy, {10, 10}, "weight", true);
3426
3427 // Split the input to a bunch of small slices.
3428 std::vector<NodeValue> inputs;
3429 for (dim_t i = 0; i < 10; i++) {
3430 auto *K = F_->createSlice("extract", input, {i, 0, 0}, {i + 1, 10, 10});
3431 auto *MM = F_->createBatchedAdd("BA", K, slice);
3432 inputs.push_back(MM);
3433 }
3434
3435 auto *cc = F_->createConcat("merge", inputs, 0);
3436 F_->createSave("save", cc);
3437
3438 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 10);
3439 ::glow::optimize(F_, CompilationMode::Infer);
3440
3441 // Check that all of the batched-adds are merged into a single node.
3442 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 1);
3443}
3444
3445/// Check that EliminateNoop optimization pass removes nodes which don't do
3446/// anything useful.
3447TEST_F(GraphOptz, eliminateNoop) {
3448 std::vector<dim_t> shape = {1, 2, 2, 3};
3449 Placeholder *input1 = mod_.createPlaceholder(ElemKind::Int8QTy, shape, 0.004,
3450 0, "input", false);
3451 Placeholder *input2 = mod_.createPlaceholder(ElemKind::Int8QTy, shape, 0.004,
3452 0, "input", false);
3453 auto *cond = mod_.createConstant(ElemKind::BoolTy, shape, "input1");
3454 cond->getHandle<bool>() = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
3455
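  // Each node below is a no-op: the Select condition is all-true, the Slice
  // spans the whole tensor, the Tile count is 1, the Pad adds zero padding,
  // and both pooling nodes use 1x1 kernels with stride 1 and no padding.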
3456 auto *select = F_->createSelect("select", cond, input1, input2);
3457 auto *slice = F_->createSlice("slice", select, {0, 0, 0, 0}, shape);
3458 auto *tile = F_->createTile("tile", slice, 1, 1);
3459 auto *pad = F_->createPad("pad", tile, tile->getResult().getType(), 0,
3460 {0, 0, 0, 0, 0, 0, 0, 0}, 0);
3461 auto *avgPool = F_->createAvgPool("avgpool", pad, 1, 1, 0);
3462 auto *maxPool = F_->createMaxPool("maxpool", avgPool, 1, 1, 0);
3463
3464 F_->createSave("save", maxPool->getResult());
3465
3466 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SelectNodeKind), 1);
3467 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 1);
3468 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TileNodeKind), 1);
3469 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::PadNodeKind), 1);
3470 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AvgPoolNodeKind), 1);
3471 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MaxPoolNodeKind), 1);
3472
3473 optimizedF_ = optimizeFunctionForTest(F_);
3474
3475 // Check that all nodes except for Save are eliminated.
3476 EXPECT_EQ(optimizedF_->getNodes().size(), 1);
3477
3478 bindings_.allocate(mod_.getPlaceholders());
3479 bindings_.get(input1)->getHandle<int8_t>().randomize(-1.0, 1.0,
3480 mod_.getPRNG());
3481 bindings_.get(input2)->getHandle<int8_t>().randomize(-1.0, 1.0,
3482 mod_.getPRNG());
3483
3484 checkNumericalEquivalence();
3485}
3486
3487// Check that we are able to replace
3488// Add(I, tile(B)) with -> BatchedAdd(I, B).
3489TEST_F(GraphOptz, FoldTileAddIntoBatchedAdd) {
3490 auto *batch =
3491 mod_.createPlaceholder(ElemKind::FloatTy, {3, 1, 2}, "batch", false);
3492 auto *added = mod_.createConstant(ElemKind::FloatTy, {1, 1, 2}, "added");
3493 auto *addedTiled = F_->createTile("addedTiled", added, 3, 0);
3494 auto *add = F_->createAdd("add", batch, addedTiled);
3495 auto *save = F_->createSave("save", add);
3496 auto *output = save->getPlaceholder();
3497
3498 bindings_.allocate(batch)->getHandle() = {2, 2, 3, 3, 4, 4};
3499 added->getPayloadMutable().getHandle() = {1, 1};
3500 bindings_.allocate(output);
3501
3502 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TileNodeKind), 1);
3503 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 1);
3504 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 0);
3505
3506 ASSERT_TRUE(F_->verify());
3507
3508 // Currently the FoldTileAddIntoBatchedAdd opt which we're testing here is not
3509 // part of the default optimization pipeline. Create a local version of the
3510 // pipeline with that pass included.
3511 auto p = createDefaultGraphOptimizationPassPipeline();
3512 p->pushFront({FunctionPassID::FoldTileAddIntoBatchedAdd});
3513 FunctionPassManager FPM("opt", std::move(p));
3514 FPM.run(F_, CompilationContext());
3515 ASSERT_TRUE(F_->verify());
3516
  // Check that the Tile node and the Add node are replaced by
  // a BatchedAdd node.
3519 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TileNodeKind), 0);
3520 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 0);
3521 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 1);
3522
3523 // Verify the correctness of the input to BatchedAdd operator.
3524 // The correctness of BatchedAdd operator itself is verified
3525 // by operator's unit tests.
3526 Tensor expectedBatch(ElemKind::FloatTy, {3, 1, 2});
3527 expectedBatch.getHandle() = {2, 2, 3, 3, 4, 4};
3528 Tensor expectedSlice(ElemKind::FloatTy, {1, 2});
3529 expectedSlice.getHandle() = {1, 1};
3530 for (auto &node : F_->getNodes()) {
3531 auto *recvdBANode = llvm::dyn_cast<BatchedAddNode>(&node);
3532 if (!recvdBANode) {
3533 continue;
3534 }
3535 auto *recvdBatch = llvm::dyn_cast<Placeholder>(recvdBANode->getBatch());
3536 ASSERT_TRUE(recvdBatch);
3537 auto *recvdSlice = llvm::dyn_cast<Constant>(recvdBANode->getSlice());
3538 ASSERT_TRUE(recvdSlice);
3539 EXPECT_TRUE(recvdBatch->dims().equals({3, 1, 2}));
3540 EXPECT_TRUE(recvdSlice->dims().equals({1, 2}));
3541 EXPECT_TRUE(bindings_.get(recvdBatch)->isEqual(expectedBatch));
3542 EXPECT_TRUE(recvdSlice->getPayload().isEqual(expectedSlice));
3543 break;
3544 }
3545}
3546
3547/// Test Concat(Slice, ..., Slice) opt works correctly. If \p reverseOrder then
3548/// the optimization is inapplicable and should not occur.
3549static void testConcatElim(Module &mod, Function *F, Function *&optimizedF,
3550 PlaceholderBindings &bindings, bool reverseOrder) {
3551 Placeholder *input =
3552 mod.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input", true);
3553 bindings.allocate(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
3554
3555 // Split the input to a bunch of small slices.
3556 std::array<NodeValue, 10> inputs;
3557 for (dim_t i = 0; i < 10; i++) {
3558 dim_t idx = reverseOrder ? 9 - i : i;
3559 inputs[i] =
3560 F->createSlice("extract", input, {idx, 0, 0}, {idx + 1, 10, 10});
3561 }
3562
3563 auto *cc = F->createConcat("merge", inputs, 0);
3564 F->createSave("save", cc);
3565
3566 EXPECT_EQ(countNodeKind(F, Kinded::Kind::SliceNodeKind), 10);
3567
3568 optimizedF = optimizeFunctionForTest(F);
3569
3570 // Check that either the concat and slices are gone if the optimization was
3571 // applicable, or otherwise that they're still there.
3572 EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::ConcatNodeKind),
3573 reverseOrder ? 1 : 0);
3574 EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::SliceNodeKind),
3575 reverseOrder ? 10 : 0);
3576}
3577
3578// Check that we are able to eliminate concat nodes.
3579TEST_F(GraphOptz, concatElim) {
3580 testConcatElim(mod_, F_, optimizedF_, bindings_, /* reverseOrder */ false);
3581 checkNumericalEquivalence(0.0f);
3582}
3583
3584// Check that when the order of the Slices is reversed no optimization kicks in.
3585TEST_F(GraphOptz, concatElimReverseOrder) {
3586 testConcatElim(mod_, F_, optimizedF_, bindings_, /* reverseOrder */ true);
3587 checkNumericalEquivalence(0.0f);
3588}
3589
3590/// Check that we are able to eliminate concat nodes with redundant arithmetic
/// ops in the way.
3592TEST_F(GraphOptz, concatArithElim) {
3593 auto *input =
3594 mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input", true);
3595 bindings_.allocate(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
3596
3597 Type t(ElemKind::FloatTy, {1, 10, 10});
3598 Node *one = F_->createSplat("one", &t, 1.0);
3599 Node *zero = F_->createSplat("zero", &t, 0.0);
3600
3601 // Split the input to a bunch of small slices.
3602 std::vector<NodeValue> inputs;
3603 for (dim_t i = 0; i < 10; i++) {
3604 auto *K = F_->createSlice("extract", input, {i, 0, 0}, {i + 1, 10, 10});
3605 // Insert the nodes in reverse order to make sure that we can catch
3606 // non-consecutive graph-order slices.
3607 Node *N = K;
3608 switch (i) {
3609 case 0:
3610 N = F_->createAdd("add0", K, zero);
3611 break;
3612 case 1:
3613 N = F_->createSub("sub0", K, zero);
3614 break;
3615 case 2:
3616 N = F_->createAdd("add_0", zero, K);
3617 break;
3618 case 3:
3619 N = F_->createMul("mul1", K, one);
3620 break;
3621 case 4:
3622 N = F_->createDiv("div1", K, one);
3623 break;
3624 case 5:
3625 N = F_->createMul("mul_1", one, K);
3626 break;
3627 default:
3628 break;
3629 }
3630 inputs.push_back(N);
3631 }
3632
3633 auto *cc = F_->createConcat("merge", inputs, 0);
3634 F_->createSave("save", cc);
3635 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 10);
3636 optimizedF_ = optimizeFunctionForTest(F_);
3637
3638 // Check that the concat node is gone.
3639 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 0);
3640 checkNumericalEquivalence(0.0f);
3641}
3642
3643/// Check that we are able to eliminate concat followed by slices on axis
3644/// \p dim under certain conditions.
3645static void testConcatSliceElim(Module &mod, Function *F, Function *&optimizedF,
3646 PlaceholderBindings &bindings, size_t dim) {
3647 constexpr size_t N = 5;
3648 std::array<NodeValue, N> inputs;
3649 std::vector<dim_t> inShape = {10, 20};
3650 inShape.insert(inShape.begin() + dim, 0);
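  // Each input i has size (i + 1) along the concat axis \p dim and {10, 20}
  // on the remaining axes.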
3651 for (dim_t i = 0; i < N; i++) {
3652 inShape[dim] = 1 + i;
3653 auto *P = mod.createPlaceholder(ElemKind::FloatTy, inShape, "in", true);
3654 bindings.allocate(P)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
3655 inputs[i] = P;
3656 }
3657 auto *CN = F->createConcat("merge", inputs, dim);
3658
3659 // Split the concat to a bunch of slices of the same shape as the concat
3660 // inputs and on the same axis.
3661 std::vector<dim_t> startShape = {0, 0, 0};
3662 std::vector<dim_t> endShape = {10, 20};
3663 endShape.insert(endShape.begin() + dim, 0);
3664 for (dim_t i = 0; i < N; i++) {
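    // Input i occupies the half-open range [i*(i+1)/2, (i+1)*(i+2)/2) along
    // the concat axis, i.e. the cumulative sizes of the previous inputs.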
3665 startShape[dim] = (i * (i + 1)) / 2;
3666 endShape[dim] = ((i + 1) * (i + 2)) / 2;
3667 auto *SN = F->createSlice("extract", CN, startShape, endShape);
3668 F->createSave("save", SN);
3669 }
3670
3671 // We created a concat followed by N slices of its results.
3672 EXPECT_EQ(countNodeKind(F, Kinded::Kind::SliceNodeKind), N);
3673 EXPECT_EQ(countNodeKind(F, Kinded::Kind::ConcatNodeKind), 1);
3674
3675 optimizedF = optimizeFunctionForTest(F);
3676
3677 // Check that the concat and slices are gone.
3678 EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::ConcatNodeKind), 0);
3679 EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::SliceNodeKind), 0);
3680}
3681
3682TEST_F(GraphOptz, concatSliceElimInnerDim) {
3683 testConcatSliceElim(mod_, F_, optimizedF_, bindings_, 0);
3684 checkNumericalEquivalence(0.0f);
3685}
3686
3687TEST_F(GraphOptz, concatSliceElimMiddleDim) {
3688 testConcatSliceElim(mod_, F_, optimizedF_, bindings_, 1);
3689 checkNumericalEquivalence(0.0f);
3690}
3691
3692TEST_F(GraphOptz, concatSliceElimOuterDim) {
3693 testConcatSliceElim(mod_, F_, optimizedF_, bindings_, 2);
3694 checkNumericalEquivalence(0.0f);
3695}
3696
/// Check the interaction between the Slices(Concat) and Concat(Slices)
/// optimizations to make sure they work nicely together. Builds
/// Concat(Slices(Concat)) and expects a single Concat after optimizations.
3700TEST_F(GraphOptz, concatSliceElimMultiConcat) {
3701 std::array<NodeValue, 4> inputs;
3702 for (size_t i = 0; i < 4; i++) {
3703 auto *P = mod_.createPlaceholder(ElemKind::FloatTy, {2, 4},
3704 "in_" + std::to_string(i), false);
3705 bindings_.allocate(P)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
3706 inputs[i] = P;
3707 }
3708 auto *CN0 = F_->createConcat("merge0", inputs, /* axis */ 1);
3709
3710 auto *SN0 = F_->createSlice("slice0", CN0, {0, 0}, {2, 4});
3711 auto *SN1 = F_->createSlice("slice1", CN0, {0, 4}, {2, 8});
3712 auto *SN2 = F_->createSlice("slice2", CN0, {0, 8}, {2, 12});
3713 auto *SN3 = F_->createSlice("slice3", CN0, {0, 12}, {2, 16});
3714
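  // Re-concatenate the slices of CN0 in a shuffled order; the optimizer should
  // collapse this into a single Concat of the original inputs.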
3715 auto *CN1 = F_->createConcat("merge1", {SN1, SN0, SN3, SN2}, /* axis */ 1);
3716 F_->createSave("save", CN1);
3717
3718 // We created a concat followed by 4 slices of its results followed by another
3719 // concat.
3720 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConcatNodeKind), 2);
3721 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 4);
3722
3723 optimizedF_ = optimizeFunctionForTest(F_);
3724
3725 // Check that one concat and slices are gone.
3726 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1);
3727 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SliceNodeKind), 0);
3728
3729 checkNumericalEquivalence(0.0f);
3730}
3731
3732// Check the transformation Concat(Reshape(x) * N) -> Reshape(Concat(x * N)).
3733TEST_F(GraphOptz, concatReshapes) {
3734 const dim_t shape1[] = {2, 5, 2, 1, 20};
3735 const dim_t shape2[] = {10, 2, 2, 10};
3736 const dim_t shape3[] = {5, 80};
3737 llvm::SmallVector<NodeValue, 10> inputs1;
3738 llvm::SmallVector<NodeValue, 10> inputs2;
3739 for (size_t i = 0; i < 10; i++) {
3740 // 10 reshape nodes that transform from {2,5,2,1,20} to {10,2,2,10}.
3741 // And a ConcatNode concatenates the outputs of reshape at 2nd dim.
3742 // The optimization would kick in, as the size of trailing dimensions of
3743 // original ConcatNode (before opt) is 20, and the size of leading
3744 // dimensions of original ConcatNode (before opt) is 10.
3745 Node *var = F_->getParent()->createPlaceholder(
3746 ElemKind::FloatTy, shape1, "input" + std::to_string(i), true);
3747 auto *RN = F_->createReshape("reshape" + std::to_string(i), var, shape2);
3748 inputs1.push_back(RN);
3749 }
3750 auto *concatNode1 = F_->createConcat("concat", inputs1, 1);
3751 for (size_t i = 0; i < 10; i++) {
    // 10 reshape nodes that transform from {5,80} to {10,2,2,10}.
3753 // And a ConcatNode concatenates the outputs of reshape at 2nd dim.
3754 // The optimization would NOT kick in, as we cannot find the dim that
3755 // makes the leading/trailing dims same as in the case of the original
3756 // concat node.
3757 Node *var = F_->getParent()->createPlaceholder(
3758 ElemKind::FloatTy, shape3, "input" + std::to_string(i), true);
3759 auto *RN = F_->createReshape("reshape" + std::to_string(i), var, shape2);
3760 inputs2.push_back(RN);
3761 }
3762 auto *concatNode2 = F_->createConcat("concat", inputs2, 1);
3763 auto outputShape = concatNode1->getResult().dims();
  // Need to clear the vectors holding the Reshape node references; otherwise
  // the user count of those nodes would always be positive, making them
  // unable to be removed by DCE.
3766 inputs1.clear();
3767 inputs2.clear();
3768
3769 auto *addNode = F_->createAdd("add", concatNode1, concatNode2);
3770 auto *O = F_->createSave("ret", addNode);
3771
3772 EXPECT_EQ(F_->getNodes().size(), 24);
3773
3774 ::glow::optimize(F_, CompilationMode::Infer);
3775
3776 // After optimization, we expect to see only 15 nodes. All 10 of the
3777 // reshapes that were the inputs to the first original concat node
3778 // (concatNode1) are removed, and a single new reshape is added after the
3779 // new concat.
3780 EXPECT_EQ(F_->getNodes().size(), 15);
3781
3782 // concatNode1 should not exist any more.
3783 EXPECT_FALSE(functionContainsNode(F_, concatNode1));
3784 // concatNode2 should still exist.
3785 EXPECT_TRUE(functionContainsNode(F_, concatNode2));
3786
3787 // The first input of addNode should be a Reshape node now, with the same
3788 // result shape of concatNode1.
3789 auto *newAddNode = llvm::dyn_cast<AddNode>(O->getInput());
3790 ASSERT_TRUE(newAddNode);
3791 auto *newRN = llvm::dyn_cast<ReshapeNode>(newAddNode->getLHS());
3792 ASSERT_TRUE(newRN);
3793 EXPECT_TRUE(newRN->getResult().getType()->dims().equals(outputShape));
3794
3795 // The input of newRN should be a ConcatNode now.
3796 auto *newCN = llvm::dyn_cast<ConcatNode>(newRN->getInput());
3797 ASSERT_TRUE(newCN);
3798}
3799
// Make sure we do not try to optimize concat2(dim1, concat1(dim2, X, Y), Z)
// -> concat(dim1, X, Y, Z) when concat1 has multiple users.
3803TEST_F(GraphOptz, ConcatSimplificationNegative) {
3804 const dim_t dim1[] = {1, 4, 4, 4};
3805 const dim_t dim2[] = {1, 4, 4, 8};
3806 auto *in1 = mod_.createPlaceholder(ElemKind::FloatTy, dim1, "in1", false);
3807 auto *in2 = mod_.createPlaceholder(ElemKind::FloatTy, dim1, "in2", false);
3808 auto *in3 = mod_.createPlaceholder(ElemKind::FloatTy, dim2, "in3", false);
3809
3810 auto *cnc1 = F_->createConcat("cnc1", {in1, in2}, 3);
3811 auto *add1 = F_->createAdd("add1", in3, cnc1);
3812 auto *cnc2 = F_->createConcat("cnc2", {add1, cnc1}, 3);
3813 F_->createSave("ret", cnc2);
3814 EXPECT_EQ(F_->getNodes().size(), 4);
3815 ::glow::optimize(F_, CompilationMode::Infer);
3816 EXPECT_EQ(F_->getNodes().size(), 4);
3817 for (auto &n : F_->getNodes()) {
3818 if (auto *tcnc = llvm::dyn_cast<ConcatNode>(&n)) {
3819 EXPECT_EQ(tcnc->getNumInputs(), 2);
3820 }
3821 }
3822}
3823
3824/// Check that Variable CSE works correctly, combining small Variables that
3825/// have the same data.
3826TEST_F(GraphOptz, VarsCSE) {
3827 // Create three variables that are Private, are not trainable, and have no
3828 // writers. The first two variables have the same data, and so should be
3829 // combined via variable CSE. The third variable differs by the last value,
3830 // and so should not be combined.
3831 auto *input1 = mod_.createConstant(ElemKind::FloatTy, {10}, "input1");
3832 auto *input2 = mod_.createConstant(ElemKind::FloatTy, {10}, "input2");
3833 auto *input3 = mod_.createConstant(ElemKind::FloatTy, {10}, "input3");
3834 input1->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
3835 input2->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
3836 input3->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, -1};
3837
3838 // Input them each to different nodes, so node CSE does not change them.
3839 auto *TN = F_->createTanh("tanh", input1);
3840 auto *SN = F_->createSigmoid("sigmoid", input2);
3841 auto *RN = F_->createRELU("relu", input3);
3842 auto *CN = F_->createConcat("concat", {TN, SN, RN}, /* axis */ 0);
3843 F_->createSave("ret", CN);
3844
3845 // Initially there are three variables: inputs 1, 2, and 3 (the save uses a
3846 // placeholder).
3847 EXPECT_EQ(mod_.getConstants().size(), 3);
3848
3849 cctx_.compMode = CompilationMode::Infer;
3850 // Do not perform any compile-time constant folding.
3851 cctx_.optimizationOpts.enableConstantFolding = false;
3852 ::glow::optimize(F_, cctx_);
3853
3854 // Now only two variables are left; input1 and input2 have been combined,
3855 // but input3 has not.
3856 EXPECT_EQ(mod_.getConstants().size(), 2);
3857
3858 // Verify that only one of input1 and input2 exists, and that input3 still
3859 // exists.
3860 Constant *varOneOrTwo = nullptr;
3861 bool foundVarThree = false;
3862 for (auto *V : mod_.getConstants()) {
3863 if (V == input1 || V == input2) {
3864 EXPECT_TRUE(varOneOrTwo == nullptr);
3865 varOneOrTwo = V;
3866 } else if (V == input3) {
3867 foundVarThree = true;
3868 }
3869 }
3870 EXPECT_TRUE(varOneOrTwo != nullptr);
3871 EXPECT_TRUE(foundVarThree);
3872
3873 // Verify that the users of the inputs are updated correctly.
3874 EXPECT_TRUE(TN->getInput().getNode() == varOneOrTwo);
3875 EXPECT_TRUE(SN->getInput().getNode() == varOneOrTwo);
3876 EXPECT_TRUE(RN->getInput().getNode() == input3);
3877
3878 // Verify that whichever input1/input2 is left over has two users TN and SN.
3879 EXPECT_TRUE(varOneOrTwo->getUsers().size() == 2);
3880 for (auto &U : varOneOrTwo->getUsers()) {
3881 auto *N = U.getUser();
3882 EXPECT_TRUE(N == TN || N == SN);
3883 }
3884
3885 // Verify that input3 only has a single user RN.
3886 ASSERT_TRUE(input3->getUsers().size() == 1);
3887 EXPECT_TRUE(input3->getUsers().begin()->getUser() == RN);
3888}
3889
3890TEST_F(GraphOptz, VarsCSENaN) {
  // Create two variables that are Private, are not trainable, have no writers
  // and include NaNs. The two variables have the same data, and so should be
  // combined via variable CSE. In particular, the NaN values should not
  // prevent the variables from being combined.
3895 auto *input1 = mod_.createConstant(ElemKind::FloatTy, {5}, "input1");
3896 auto *input2 = mod_.createConstant(ElemKind::FloatTy, {5}, "input2");
3897 input1->getHandle() = {0, NAN, 2, NAN, 4};
3898 input2->getHandle() = {0, NAN, 2, NAN, 4};
3899
3900 // Input them each to different nodes, so node CSE does not change them.
3901 auto *TN = F_->createTanh("tanh", input1);
3902 auto *SN = F_->createSigmoid("sigmoid", input2);
3903 auto *CN = F_->createConcat("concat", {TN, SN}, /* axis */ 0);
3904 F_->createSave("ret", CN);
3905
3906 // Initially there are two variables: inputs 1 and 2 (the save uses a
3907 // placeholder).
3908 EXPECT_EQ(mod_.getConstants().size(), 2);
3909
3910 cctx_.compMode = CompilationMode::Infer;
3911 // Do not perform any compile-time constant folding.
3912 cctx_.optimizationOpts.enableConstantFolding = false;
3913 ::glow::optimize(F_, cctx_);
3914
  // Now only one variable is left; input1 and input2 have been combined.
3916 EXPECT_EQ(mod_.getConstants().size(), 1);
3917
3918 // Verify that only one of input1 and input2 exists.
3919 Constant *varOneOrTwo = nullptr;
3920 for (auto *V : mod_.getConstants()) {
3921 if (V == input1 || V == input2) {
3922 EXPECT_TRUE(varOneOrTwo == nullptr);
3923 varOneOrTwo = V;
3924 }
3925 }
3926 EXPECT_TRUE(varOneOrTwo != nullptr);
3927
3928 // Verify that the users of the inputs are updated correctly.
3929 EXPECT_TRUE(TN->getInput().getNode() == varOneOrTwo);
3930 EXPECT_TRUE(SN->getInput().getNode() == varOneOrTwo);
3931
3932 // Verify that whichever input1/input2 is left over has two users TN and SN.
3933 EXPECT_TRUE(varOneOrTwo->getUsers().size() == 2);
3934 for (auto &U : varOneOrTwo->getUsers()) {
3935 auto *N = U.getUser();
3936 EXPECT_TRUE(N == TN || N == SN);
3937 }
3938}
3939
3940// Verify that constant input canonicalization works correctly when the
3941// arithmetic nodes have multiple users.
3942TEST_F(GraphOptz, simplifyArithmeticMultipleUsers) {
3943 Node *I1 =
3944 mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input1", false);
3945
3946 Type t(ElemKind::FloatTy, {10, 10, 10});
3947 Node *SN = F_->createSplat("one", &t, 1.0);
3948
3949 // The splat is a constant input to add1 and add2, and is their LHS input.
3950 // We expect canonicalization to occur during optimization, moving the splat
3951 // to the RHS for both. Note that add1 has multiple users: add2 and save1.
3952 Node *AN1 = F_->createAdd("add1", SN, I1);
3953 Node *AN2 = F_->createAdd("add2", SN, AN1);
3954 SaveNode *SN1 = F_->createSave("save1", AN1);
3955 SaveNode *SN2 = F_->createSave("save2", AN2);
3956
3957 // Five nodes in total: one splat, two adds, and two saves.
3958 EXPECT_EQ(F_->getNodes().size(), 5);
3959 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SplatNodeKind), 1);
3960 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 2);
3961 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2);
3962
3963 // input1 has a single user before optimization.
3964 EXPECT_EQ(I1->getUsers().size(), 1);
3965
3966 // Simplify nodes will canonicalize add1 and add2, and should replace all
3967 // their users, without otherwise adding new nodes to the graph/changing the
3968 // overall structure.
3969 ::glow::optimize(F_, CompilationMode::Infer);
3970
3971 // We should have the same five nodes: one splat, two adds, and two saves.
3972 EXPECT_EQ(F_->getNodes().size(), 5);
3973 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SplatNodeKind), 1);
3974 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 2);
3975 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2);
3976
3977 // Verify that both add nodes were canonicalized, and that the graph's shape
3978 // is the same as prior to optimization other than canonicalization.
3979 AddNode *newAN1 = llvm::dyn_cast<AddNode>(SN1->getInput().getNode());
3980 ASSERT_TRUE(newAN1 != nullptr);
3981 EXPECT_TRUE(llvm::isa<Placeholder>(newAN1->getLHS()));
3982 EXPECT_TRUE(llvm::isa<SplatNode>(newAN1->getRHS()));
3983
3984 AddNode *newAN2 = llvm::dyn_cast<AddNode>(SN2->getInput().getNode());
3985 ASSERT_TRUE(newAN2 != nullptr);
3986 EXPECT_TRUE(llvm::isa<AddNode>(newAN2->getLHS()));
3987 EXPECT_TRUE(llvm::isa<SplatNode>(newAN2->getRHS()));
3988
3989 EXPECT_EQ(newAN1, newAN2->getLHS());
3990
3991 // input1 should still have a single user after optimization.
3992 EXPECT_EQ(I1->getUsers().size(), 1);
3993}
3994
3995/// Test that a concat with a single input is replaced by the input.
3996TEST_F(GraphOptz, eliminateSingleConcat) {
3997 Node *input = mod_.createPlaceholder(ElemKind::FloatTy, {10}, "input", false);
3998
3999 ConcatNode *CN = F_->createConcat("concat1", {input}, 0);
4000 SaveNode *SN = F_->createSave("ret", CN);
4001
4002 // The ConcatNode and SaveNode.
4003 EXPECT_EQ(F_->getNodes().size(), 2);
4004
4005 ::glow::optimize(F_, CompilationMode::Infer);
4006
4007 // Just the SaveNode should be left.
4008 EXPECT_EQ(F_->getNodes().size(), 1);
4009 ASSERT_TRUE(functionContainsNode(F_, SN));
4010
4011 // Save node should just save the input.
4012 EXPECT_TRUE(SN->getInput().getNode() == input);
4013}
4014
4015/// Test that a reshape of a private variable with one use has the reshape
4016/// merged into the variable.
4017TEST_F(GraphOptz, ReshapeConstantOneUse) {
4018 const dim_t shape[] = {10, 20};
4019 const dim_t reshape1[] = {200, 1};
4020 const dim_t reshape2[] = {200};
4021 Constant *input =
4022 F_->getParent()->createConstant(ElemKind::FloatTy, shape, "input");
4023 input->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
4024
4025 auto *R1 = F_->createReshape("reshape1", input, reshape1);
4026 auto *R2 = F_->createReshape("reshape2", R1, reshape2);
4027 auto *O = F_->createSave("ret", R2);
4028
4029 // Before optimization, we have 2 Reshapes and a Save.
4030 EXPECT_EQ(F_->getNodes().size(), 3);
4031
4032 // Skip ConstantFolding as it would have the same result as this opt.
4033 cctx_.optimizationOpts.enableConstantFolding = false;
4034 ::glow::optimize(F_, cctx_);
4035
4036 // After optimization, we expect to see just a Save.
4037 EXPECT_EQ(F_->getNodes().size(), 1);
4038
4039 // Save should have the new Variable as input.
4040 auto *V = llvm::dyn_cast<Constant>(O->getInput());
4041 ASSERT_TRUE(V);
4042 // The new Variable should have the same shape as the original second
4043 // Reshape.
4044 EXPECT_TRUE(V->getType()->dims().equals(reshape2));
4045}
4046
4047/// Test that reshape node is merged into Constant in a sequence
4048/// Reshape(Quantize(Constant)).
4049TEST_F(GraphOptz, ReshapeQuantizeConstant) {
4050 const dim_t shape[] = {10, 20};
4051 const dim_t newShape[] = {200, 1};
4052
4053 auto *qTy = mod_.uniqueType(ElemKind::Int8QTy, shape, 0.2, 0);
4054
4055 auto *input =
4056 F_->getParent()->createConstant(ElemKind::FloatTy, shape, "input");
4057 auto *Q = F_->createQuantize("quantize", input, qTy);
4058 auto *R = F_->createReshape("reshape", Q, newShape);
4059 auto *S = F_->createSave("ret", R);
4060
4061 // Skip ConstantFolding as it would have the same result as this opt.
4062 CompilationContext cctx;
4063 cctx.optimizationOpts.enableConstantFolding = false;
4064
4065 EXPECT_EQ(F_->getNodes().size(), 3);
4066 ::glow::optimize(F_, cctx);
4067 EXPECT_EQ(F_->getNodes().size(), 2);
4068
4069 // Constant and Quantize should have new shape.
4070 auto *newQ = llvm::dyn_cast<QuantizeNode>(S->getInput());
4071 ASSERT_TRUE(newQ);
4072 EXPECT_TRUE(newQ->getResult().dims().equals(newShape));
4073 auto *newC = llvm::dyn_cast<Constant>(newQ->getInput());
4074 ASSERT_TRUE(newC);
4075 EXPECT_TRUE(newC->getType()->dims().equals(newShape));
4076}
4077
4078/// Test that Transpose is optimized into Reshape when it moves no data.
4079TEST_F(GraphOptz, transposeIntoReshapeOptim) {
4080 auto *batch =
4081 mod_.createPlaceholder(ElemKind::FloatTy, {1, 3, 2, 4}, "batch", false);
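  // The permutation {1, 2, 0, 3} only relocates the unit-sized batch
  // dimension, so the element order in memory is unchanged and the Transpose
  // can be rewritten as a Reshape.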
4082 Node *T = F_->createTranspose("transpose", batch, {1, 2, 0, 3});
4083 SaveNode *O = F_->createSave("ret", T);
4084
4085 EXPECT_EQ(F_->getNodes().size(), 2);
4086
4087 ::glow::optimize(F_, CompilationMode::Infer);
4088 EXPECT_EQ(F_->getNodes().size(), 2);
4089
4090 // TransposeNode is Optimized into ReshapeNode.
4091 auto *reshape = llvm::dyn_cast<ReshapeNode>(O->getInput().getNode());
4092 ASSERT_NE(reshape, nullptr);
4093}
4094
4095/// Test that transpose is merged into matmul.
4096TEST_F(GraphOptz, mergeTransposeIntoMatMul) {
4097 auto *input =
4098 mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 2, 3}, "input", false);
4099 auto *weights =
4100 F_->getParent()->createConstant(ElemKind::FloatTy, {6, 1}, "weights");
4101
4102 weights->getHandle() = {0, 1, 2, 3, 4, 5};
4103 float newWeightsRef[] = {0, 2, 4, 1, 3, 5};
4104
4105 auto *TN = F_->createTranspose("transpose", input, NHWC2NCHW);
4106 auto *RN = F_->createReshape("reshape", TN, {1, 6});
4107 auto *MMN = F_->createMatMul("matmul", RN, weights);
4108 auto *SN = F_->createSave("ret", MMN);
4109
4110 // Transpose + Reshape + MatMul + Save.
4111 EXPECT_EQ(F_->getNodes().size(), 4);
4112
4113 ::glow::optimize(F_, CompilationMode::Infer);
4114
4115 // Reshape + MatMul + Save.
4116 EXPECT_EQ(F_->getNodes().size(), 3);
4117
4118 // Check reordered weights.
4119 auto *newMMN = llvm::dyn_cast<MatMulNode>(SN->getInput());
4120 ASSERT_TRUE(newMMN != nullptr);
4121 auto *newW = llvm::dyn_cast<Constant>(newMMN->getRHS());
4122 ASSERT_TRUE(newW != nullptr);
4123 for (unsigned i = 0; i < 6; ++i) {
4124 EXPECT_EQ(newWeightsRef[i], newW->getHandle().raw(i));
4125 }
4126}
4127
4128/// Test that transpose is merged into FullyConnected.
4129TEST_F(GraphOptz, mergeTransposeIntoFC) {
4130 auto *input =
4131 mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 2, 3}, "input", false);
4132 auto *weights =
4133 F_->getParent()->createConstant(ElemKind::FloatTy, {6, 1}, "weights");
4134 auto *bias = F_->getParent()->createConstant(ElemKind::FloatTy, {1}, "bias");
4135
4136 weights->getHandle() = {0, 1, 2, 3, 4, 5};
4137 float newWeightsRef[] = {0, 2, 4, 1, 3, 5};
4138
4139 auto *TN = F_->createTranspose("transpose", input, NHWC2NCHW);
4140 auto *RN = F_->createReshape("reshape", TN, {1, 6});
4141 auto *FCN = F_->createFullyConnected("fc", RN, weights, bias);
4142 auto *SN = F_->createSave("ret", FCN);
4143
4144 // Transpose + Reshape + FC + Save.
4145 EXPECT_EQ(F_->getNodes().size(), 4);
4146
4147 ::glow::optimize(F_, CompilationMode::Infer);
4148
4149 // Reshape + FC + Save.
4150 EXPECT_EQ(F_->getNodes().size(), 3);
4151
4152 // Check reordered weights.
4153 auto *newFCN = llvm::dyn_cast<FullyConnectedNode>(SN->getInput());
4154 ASSERT_TRUE(newFCN != nullptr);
4155 auto *newW = llvm::dyn_cast<Constant>(newFCN->getWeights());
4156 ASSERT_TRUE(newW != nullptr);
4157 for (unsigned i = 0; i < 6; ++i) {
4158 EXPECT_EQ(newWeightsRef[i], newW->getHandle().raw(i));
4159 }
4160}
4161
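/// Test that convertPlaceholdersToConstants converts allocated Placeholders
/// into Constants, while skipping those explicitly excluded and those that
/// have no backing tensor allocated.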
4162TEST_F(GraphOptz, ConvertPlaceholdersToConstants) {
4163 auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input1", true);
4164 auto *input2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input2", true);
4165 auto *input3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input3", true);
4166 auto *save1 = F_->createSave("save1", input1);
4167 auto *save2 = F_->createSave("save2", input2);
4168 auto *save3 = F_->createSave("save3", input3);
4169
  // No Constants, six Placeholders (3 inputs, 3 saves).
4171 EXPECT_EQ(mod_.getConstants().size(), 0);
4172 EXPECT_EQ(mod_.getPlaceholders().size(), 6);
4173
  // Allocate two of the three inputs, but mark input2 as non-constant.
4176 bindings_.allocate(input1);
4177 bindings_.allocate(input2);
4178 // Don't allocate input3; keep it as a placeholder instead.
4179 ::glow::convertPlaceholdersToConstants(F_, bindings_, {input2});
4180
  // Only input1 becomes a Constant.
4182 EXPECT_EQ(mod_.getConstants().size(), 1);
4183 EXPECT_EQ(mod_.getPlaceholders().size(), 6);
4184
4185 EXPECT_TRUE(llvm::isa<Constant>(save1->getInput()));
4186 EXPECT_TRUE(llvm::isa<Placeholder>(save2->getInput()));
4187 EXPECT_TRUE(llvm::isa<Placeholder>(save3->getInput()));
4188}
4189
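/// Test that a round-trip conversion i32 -> i64 -> i32 is fully optimized
/// away, leaving only the Save of the Placeholder.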
4190TEST_F(GraphOptz, optimizeConversion_i32_i64_i32) {
4191 auto *i32 = mod_.uniqueType(ElemKind::Int32ITy, {1});
4192 auto *i64 = mod_.uniqueType(ElemKind::Int64ITy, {1});
4193
4194 auto *A = mod_.createPlaceholder(i32, "A", false);
4195 auto *B = F_->createConvertTo("B", A, i64);
4196 auto *C = F_->createConvertTo("C", B, i32);
4197 auto *S = F_->createSave("S", C);
4198
4199 ::glow::optimize(F_, CompilationMode::Infer);
4200
4201 // All casting is optimized away, only left with Save of Placeholder.
4202 EXPECT_EQ(F_->getNodes().size(), 1);
4203 EXPECT_TRUE(llvm::isa<Placeholder>(S->getInput()));
4204}
4205
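/// Test that a ConvertTo node converting to the same element type is
/// eliminated, while a conversion to a different type is kept.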
4206TEST_F(GraphOptz, optimizeSameTypeConversions) {
4207 auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input1", true);
4208 auto *input2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input2", true);
4209 auto *conv1 = F_->createConvertTo("cast1", input1, ElemKind::FloatTy);
4210 auto *conv2 = F_->createConvertTo("cast2", input2, ElemKind::Float16Ty);
4211 auto *save1 = F_->createSave("save1", conv1);
4212 auto *save2 = F_->createSave("save1", conv2);
4213
4214 // convert_to1 + save1 + convert_to2 + save2 nodes.
4215 EXPECT_EQ(F_->getNodes().size(), 4);
4216 EXPECT_TRUE(llvm::isa<ConvertToNode>(save1->getInput()));
4217
4218 ::glow::optimize(F_, CompilationMode::Infer);
4219
4220 // save1 + convert_to2 + save2 nodes.
4221 EXPECT_EQ(F_->getNodes().size(), 3);
4222 // convert_to1 node should be eliminated, because it converts the node into
4223 // the same type.
4224 EXPECT_TRUE(llvm::isa<Placeholder>(save1->getInput()));
  // convert_to2 node should not be eliminated, because it converts the node
4226 // into a different type.
4227 EXPECT_TRUE(llvm::isa<ConvertToNode>(save2->getInput()));
4228 EXPECT_EQ(save2->getInput(), NodeValue(conv2));
4229}
4230
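/// Test that converting a fused rowwise-quantized Constant to a fused FP16
/// type is folded directly into the Constant.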
4231TEST_F(GraphOptz, optimizeConvertingBetweenFused) {
4232 // Call with dims {5, 2}, which will actually create a constant with {5, 10}
4233 // for scale/offset per row.
4234 Constant *C = createRandomFusedRowwiseQuantizedConstant(
4235 mod_, {5, 2}, "fused", /* useFusedFP16 */ false);
  // Converting to fused FP16 means we need 4 fewer bytes in total for
  // scale/offset, so we move from {5, 10} to {5, 6}.
4238 auto newOT = mod_.uniqueType(ElemKind::UInt8FusedFP16QTy, {5, 6}, 1.0, 0);
4239 auto *CN = F_->createConvertTo("convert", C, newOT);
4240 auto *SN = F_->createSave("save", CN);
4241
4242 ::glow::optimize(F_, CompilationMode::Infer);
4243
4244 // Convert should be eliminated and just the save of the Constant left.
4245 EXPECT_EQ(F_->getNodes().size(), 1);
4246 Constant *convertedC = llvm::dyn_cast<Constant>(SN->getInput());
4247 ASSERT_TRUE(convertedC);
4248 EXPECT_EQ(convertedC->getElementType(), ElemKind::UInt8FusedFP16QTy);
4249}
4250
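/// Test that the unused Add is removed by DCE and the Transpose of the
/// Constant is folded into a transposed Constant, leaving only the Save.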
4251TEST_F(GraphOptz, dceBeforeOptimizeTranpose) {
4252 auto *input1 = mod_.createConstant(ElemKind::FloatTy, {5, 10}, "input1");
4253 // Create an unused node.
4254 F_->createAdd("add", input1, input1);
4255 auto *transposedInput1 = F_->createTranspose("transpose", input1, {1, 0});
4256 auto *save1 = F_->createSave("save1", transposedInput1);
4257
4258 // add + transpose + save.
4259 EXPECT_EQ(F_->getNodes().size(), 3);
4260
4261 ::glow::optimize(F_, CompilationMode::Infer);
4262
4263 // A single node: save.
4264 EXPECT_EQ(F_->getNodes().size(), 1);
4265 // transpose should be eliminated and replaced by the transposed constant.
4266 EXPECT_TRUE(llvm::isa<Constant>(save1->getInput()));
4267}
4268
4269/// Test that Transpose is sunk below ChannelShuffle and cancels with an
4270/// inverse transpose below the ChannelShuffle. This test models a pattern
/// that has been observed in ShuffleNet during graph optimization.
4272TEST_F(GraphOptz, sinkTransposeBelowChannelShuffleNodesAndEliminate) {
4273 const dim_t inputDims[] = {3, 28, 28, 136};
4274
4275 Node *K =
4276 mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
4277 K = F_->createTranspose("unnecessary_transpose_1", K, {0, 3, 1, 2});
4278 K = F_->createChannelShuffle("channel_shuffle", K, 4, 1);
4279 K = F_->createTranspose("unnecessary_transpose_2", K, {0, 2, 3, 1});
4280 auto *save = F_->createSave("ret", K);
4281
4282 EXPECT_EQ(F_->getNodes().size(), 4);
4283
4284 // Optimize away the unnecessary transposes.
4285 optimize(F_, CompilationMode::Infer);
4286
4287 // Ensure the two unnecessary transposes are gone.
4288 ASSERT_EQ(F_->getNodes().size(), 2);
4289
4290 // Check that the channel shuffle node is still there.
4291 auto *CSN = llvm::dyn_cast<ChannelShuffleNode>(save->getInput().getNode());
4292 ASSERT_NE(nullptr, CSN);
4293
4294 // Ensure ChannelShuffle node has the same dimensions as the input.
4295 EXPECT_EQ(CSN->getResult().dims(), llvm::makeArrayRef(inputDims));
4296
4297 // Ensure Group and Kernel are as expected.
4298 EXPECT_EQ(CSN->getGroup(), 4);
4299 EXPECT_EQ(CSN->getKernel(), 3);
4300}
4301
4302/// Test BatchNorm sinking below Slice.
4303TEST_F(GraphOptz, sinkBatchNormBelowSlice) {
4304 auto *inputTy = mod_.uniqueType(ElemKind::FloatTy, {1, 10, 10, 3});
4305 auto *slicedTy1 = mod_.uniqueType(ElemKind::FloatTy, {1, 8, 8, 3});
4306 auto *slicedTy2 = mod_.uniqueType(ElemKind::FloatTy, {1, 6, 6, 1});
4307
4308 auto *input = mod_.createPlaceholder(inputTy, "input", false);
4309 auto *BN = F_->createBatchNormalization(bindings_, "batchnorm", input, 3,
4310 0.0001, 0.9);
4311 auto *SN1 = F_->createSlice("slice1", BN, {0, 1, 1, 0}, slicedTy1);
4312 auto *SN2 = F_->createSlice("slice2", SN1, {0, 1, 1, 1}, slicedTy2);
4313 auto *save = F_->createSave("save", SN2);
4314
4315 EXPECT_EQ(F_->getNodes().size(), 4);
4316 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
4317 optimizedF_ = optimizeFunctionForTest(F_);
4318 EXPECT_EQ(optimizedF_->getNodes().size(), 4);
4319
4320 // BatchNorm should have sunk below the first Slice, but not the second one,
  // as it changes the channel dimension.
4322 auto *newSave =
4323 findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
4324 ASSERT_TRUE(newSave);
4325 auto *newSN2 = llvm::dyn_cast<SliceNode>(newSave->getInput());
4326 ASSERT_TRUE(newSN2);
4327 auto *newBN = llvm::dyn_cast<BatchNormalizationNode>(newSN2->getInput());
4328 ASSERT_TRUE(newBN);
4329 ASSERT_EQ(newBN->getResult().dims(), slicedTy1->dims());
4330 ASSERT_TRUE(llvm::isa<SliceNode>(newBN->getInput()));
4331
4332 bindings_.allocate(mod_.getPlaceholders());
4333 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
4334 checkNumericalEquivalence();
4335}
4336
4337/// Test that convertPlaceholdersToConstants works properly with quantized
4338/// types.
4339TEST_F(GraphOptz, QuantizedFC) {
4340 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 1.0, 0,
4341 "input", false);
4342 auto *weights = mod_.createPlaceholder(ElemKind::Int8QTy, {32, 32}, 1.0, 0,
4343 "weights", false);
4344 auto *bias =
4345 mod_.createPlaceholder(ElemKind::Int32QTy, {32}, 1.0, 0, "bias", false);
4346 auto *output = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 1.0, 0,
4347 "output", false);
4348
4349 auto *fc = F_->createFullyConnected("fc", input, weights, bias);
4350 F_->createSave("save", fc, output);
4351
4352 bindings_.allocate(input);
4353 bindings_.allocate(weights);
4354 bindings_.allocate(bias);
4355 bindings_.allocate(output);
4356
4357 glow::convertPlaceholdersToConstants(F_, bindings_, {input, output});
  // Two Constants: weights and bias.
4359 EXPECT_EQ(mod_.getConstants().size(), 2);
4360 // All four placeholders still exist in the module. The old weight and bias
  // placeholders just aren't hooked up to the Graph F_.
4362 EXPECT_EQ(mod_.getPlaceholders().size(), 4);
4363}
4364
4365/// Test batchedReduceMean optimization using AvgPool.
4366TEST_F(GraphOptz, convertReduceMean2AvgPool) {
4367 const dim_t dims[] = {2, 2, 2, 2};
4368
4369 Node *A = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input", false);
4370 Node *R = F_->createBatchedReduceMean("reduce.mean", A, {2, 3});
4371
4372 SaveNode *O = F_->createSave("ret", R);
4373
4374 EXPECT_EQ(F_->getNodes().size(), 2);
4375
4376 ::glow::optimize(F_, CompilationMode::Infer);
4377
4378 // Optimization adds 2 transpose nodes and one reshape node.
4379 EXPECT_EQ(F_->getNodes().size(), 5);
4380
4381 // Expecting reshape output rather than ReduceMean.
4382 auto *RN = llvm::dyn_cast<ReshapeNode>(O->getInput());
4383 ASSERT_NE(RN, nullptr);
4384
4385 // Expecting Transpose node before Reshape node.
4386 auto *TN = llvm::dyn_cast<TransposeNode>(RN->getInput());
4387 ASSERT_NE(TN, nullptr);
4388
4389 // Expecting AvgPool node before Transpose node.
4390 auto *APN = llvm::dyn_cast<AvgPoolNode>(TN->getInput());
4391 ASSERT_NE(APN, nullptr);
4392}
4393
4394/// Test Broadcasted RHS BatchMatMul is converted correctly to a single MatMul.
4395TEST_F(GraphOptz, convertBroadcastedBatchMatMulToMatMul) {
4396 auto *lhs =
4397 mod_.createPlaceholder(ElemKind::FloatTy, {6, 10, 4}, "lhs", false);
4398 auto *rhs = mod_.createConstant(ElemKind::FloatTy, {4, 8}, "rhs");
4399 rhs->getPayloadMutable().getHandle().randomize(-10, 10, mod_.getPRNG());
4400 auto *BMMN = F_->createBatchMatMul("BMM", lhs, rhs);
4401 F_->createSave("save", BMMN);
4402
4403 // Start with a BatchMatMul, not a MatMul.
4404 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchMatMulNodeKind), 1);
4405 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 0);
4406
4407 optimizedF_ = optimizeFunctionForTest(F_);
4408
4409 // Optimization should replace the BatchMatMul with a single MatMul.
4410 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::MatMulNodeKind), 1);
4411 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::BatchMatMulNodeKind), 0);
4412
4413 bindings_.allocate(lhs)->getHandle().randomize(-10, 10, mod_.getPRNG());
4414
4415 checkNumericalEquivalence(0.f);
4416}
4417
4418/// Test Broadcasted RHS BatchMatMul is converted correctly to a single MatMul,
4419/// where RHS is broadcasted in multiple dimensions.
4420TEST_F(GraphOptz, convertMultiBroadcastedBatchMatMulToMatMul) {
4421 auto *lhs =
4422 mod_.createPlaceholder(ElemKind::FloatTy, {5, 10, 4}, "lhs", false);
4423 auto *rhs = mod_.createConstant(ElemKind::FloatTy, {1, 1, 6}, "rhs");
4424 rhs->getPayloadMutable().getHandle().randomize(-10, 10, mod_.getPRNG());
4425 auto *BN = F_->createBroadcast("broadcast", rhs, {5, 4, 6}, /* axis */ 0);
4426 auto *BMMN = F_->createBatchMatMul("BMM", lhs, BN);
4427 F_->createSave("save", BMMN);
4428
4429 // Start with a BatchMatMul, not a MatMul, as well as a broadcast.
4430 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchMatMulNodeKind), 1);
4431 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 0);
4432 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BroadcastNodeKind), 1);
4433
4434 optimizedF_ = optimizeFunctionForTest(
4435 F_, {FunctionPassID::ConvertBroadcastedBatchMatMul, getDCEPassConfig()});
4436
  // Optimization should replace the BatchMatMul with a single MatMul, while
  // leaving the Broadcast in place.
4439 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::MatMulNodeKind), 1);
4440 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::BatchMatMulNodeKind), 0);
4441 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::BroadcastNodeKind), 1);
4442
4443 bindings_.allocate(lhs)->getHandle().randomize(-10, 10, mod_.getPRNG());
4444
4445 checkNumericalEquivalence(0.f);
4446}
4447
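/// Test that the consecutive RescaleQuantized nodes following the quantized
/// Add are optimized away, so that only two nodes remain.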
4448TEST_F(GraphOptz, dceQuantization) {
4449 auto *lhs =
4450 mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5}, 0.3, 15, "lhs", false);
4451 auto *weights =
4452 mod_.createConstant(ElemKind::Int8QTy, {3, 5}, 0.3, 15, "weights");
4453
4454 auto *add = F_->createAdd("add", lhs, weights);
4455 auto *t1 = mod_.uniqueType(ElemKind::Int8QTy, {3, 5}, 0.2, 0);
4456 auto *rs1 = F_->createRescaleQuantized("rs1", add, t1);
4457 auto *t2 = mod_.uniqueType(ElemKind::Int8QTy, {3, 5}, 0.1, 1);
4458 auto *rs2 = F_->createRescaleQuantized("rs2", rs1, t2);
4459 F_->createSave("save", rs2);
4460
4461 ::glow::optimize(F_, CompilationMode::Infer);
4462
4463 EXPECT_EQ(F_->getNodes().size(), 2);
4464}
4465
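/// Test that a Relu whose int8 input (offset -128) can never be negative is
/// a no-op and is removed.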
4466TEST_F(GraphOptz, nopRelu) {
4467 auto *in = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5}, 0.3, -128, "lhs",
4468 false);
4469
4470 auto *relu = F_->createRELU("relu", in);
4471 F_->createSave("save", relu);
4472
4473 optimizedF_ = optimizeFunctionForTest(F_);
4474
4475 EXPECT_EQ(optimizedF_->getNodes().size(), 1);
4476
4477 bindings_.allocate(mod_.getPlaceholders());
4478 bindings_.get(in)->getHandle<int8_t>().randomize(-4, 4, mod_.getPRNG());
4479
4480 checkNumericalEquivalence();
4481}
4482
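/// Helper that sets all elements of Constant \p C to \p value.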
4483template <typename ElemTy>
4484static void setConstValue(Constant *C, ElemTy value) {
4485 Handle<ElemTy> TH = C->getPayload().getHandle<ElemTy>();
4486 TH.clear(value);
4487}
4488
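/// Test constant folding of the subgraph feeding a single Save via
/// constantFold(), while the rest of the graph remains untouched.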
4489TEST_F(GraphOptz, constantFoldSingleNode) {
4490 auto *const1 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const1");
4491 auto *const2 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const2");
4492 auto *ph1 = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "input1",
4493 /* isTrainable */ false);
4494 setConstValue(const1, 1.0f);
4495 setConstValue(const2, 2.0f);
4496 auto *splat2 = F_->createSplat(
4497 "splat2", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 2.0f);
4498 auto *splat3 = F_->createSplat(
4499 "splat3", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 3.0f);
4500
4501 auto *add1 = F_->createAdd("add", const1, const2);
4502 auto *mul1 = F_->createMul("mul1", add1, splat2);
4503 auto *mul2 = F_->createMul("mul2", mul1, splat3);
4504 auto *SN1 = F_->createSave("save", mul2);
4505 auto *add3 = F_->createAdd("add", const1, ph1);
4506 auto *SN2 = F_->createSave("save", add3);
4507
4508 // Perform constant folding for a specific node.
4509 std::vector<Constant *> constResults =
4510 constantFold(SN1->getInput().getNode());
4511
4512 ASSERT_EQ(constResults.size(), 1);
4513 SN1->getInput().replaceAllUsesOfWith(constResults[0]);
4514 // Second save should be unaffected.
4515 EXPECT_FALSE(llvm::isa<Constant>(SN2->getInput()));
4516 // First save should have been constant folded.
4517 EXPECT_TRUE(llvm::isa<Constant>(SN1->getInput()));
4518 Constant *C = llvm::dyn_cast<Constant>(SN1->getInput());
4519 auto CH = C->getHandle();
  // The expected result should be: ((1 + 2) * 2 * 3) = 18
4521 EXPECT_EQ(CH.at({0, 0}), 18.0f);
4522 EXPECT_EQ(CH.at({0, 1}), 18.0f);
4523 EXPECT_EQ(CH.at({1, 0}), 18.0f);
4524 EXPECT_EQ(CH.at({1, 1}), 18.0f);
4525}
4526
4527/// Verify that we can specify what splats should be materialized to constants
4528/// based on their users via optimizationOpts.materializeSplatsUsedBySet.
4529TEST_F(GraphOptz, constantFoldSpecificSplat) {
4530 Placeholder *PH = mod_.createPlaceholder(ElemKind::FloatTy, {1, 1}, "input",
4531 /* isTrainable */ false);
4532 SplatNode *splat1 = F_->createSplat(
4533 "splat1", mod_.uniqueType(ElemKind::FloatTy, {1, 1}), 1.0f);
4534 AddNode *add = F_->createAdd("add", PH, splat1);
4535 SplatNode *splat2 = F_->createSplat(
4536 "splat2", mod_.uniqueType(ElemKind::FloatTy, {1, 1}), 2.0f);
4537 MulNode *mul = F_->createMul("mul", add, splat2);
4538 SaveNode *save = F_->createSave("save", mul);
4539
4540 // Signal to materialize the splat used by Add, but not by Mul.
4541 cctx_.optimizationOpts.materializeSplatsUsedBySet.insert(
4542 Kinded::Kind::AddNodeKind);
4543
4544 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4545
4546 ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_);
4547 runDCEPass(optimizedF_, cctx_);
4548
4549 ASSERT_EQ(record.size(), 1);
4550 SaveNode *SN = record.begin()->second;
4551 SplatNode *foldSplat1 = llvm::dyn_cast<SplatNode>(SN->getInput());
4552 ASSERT_TRUE(foldSplat1);
4553 EXPECT_EQ(foldSplat1->getValue(), 1.0f);
4554
4555 // Verify one splat left in the optimized Function, and a new Constant.
4556 EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SplatNodeKind));
4557 const SaveNode *optSave =
4558 findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
4559 MulNode *optMul = llvm::dyn_cast<MulNode>(optSave->getInput());
4560 ASSERT_TRUE(optMul);
4561 SplatNode *optSplat2 = llvm::dyn_cast<SplatNode>(optMul->getRHS());
4562 ASSERT_TRUE(optSplat2);
4563 EXPECT_EQ(optSplat2->getValue(), 2.0f);
4564 AddNode *optAdd = llvm::dyn_cast<AddNode>(optMul->getLHS());
4565 ASSERT_TRUE(optAdd);
4566 EXPECT_EQ(optAdd->getLHS().getNode(), PH);
4567 Constant *optSplatConst1 = llvm::dyn_cast<Constant>(optAdd->getRHS());
4568 ASSERT_TRUE(optSplatConst1);
4569 EXPECT_EQ(optSplatConst1->getPayload().getHandle().at({0, 0}), 1.0f);
4570}
4571
4572/// Test that we correctly record a single constant folding subgraph that has a
4573/// single output.
4574TEST_F(GraphOptz, constantFoldWithRecordSingleChain) {
4575 Placeholder *I =
4576 mod_.createPlaceholder(ElemKind::Float16Ty, {2, 100}, "input",
4577 /* isTrainable */ false);
4578 Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weight");
4579 ClipNode *clipW = F_->createClip("clip", W, -5.f, 5.f);
4580 ConvertToNode *convertW =
4581 F_->createConvertTo("conv", clipW, ElemKind::Float16Ty);
4582 TransposeNode *transposeW =
4583 F_->createTranspose("transpose", convertW, {1, 0});
4584 MatMulNode *MM = F_->createMatMul("matmul", I, transposeW);
4585 SaveNode *save = F_->createSave("save", MM);
4586 Placeholder *O = save->getPlaceholder();
4587 bindings_.allocate(O);
4588
4589 ASSERT_TRUE(F_->verify());
4590
4591 Tensor *IT = bindings_.allocate(I);
4592 IT->getHandle<float16_t>().randomize(-10, 10, mod_.getPRNG());
4593 W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG());
4594
4595 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4596
4597 ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_);
4598
4599 runDCEPass(optimizedF_, cctx_);
4600
4601 ASSERT_EQ(record.size(), 1);
4602 SaveNode *SN = record.begin()->second;
4603 Function *constFoldF = SN->getParent();
4604
4605 // Expect to find a chain of Nodes based on Nodes above. Note that the clip is
4606 // lowered for the Interpreter backend which performs constant folding.
4607 EXPECT_EQ(2, countNodeKind(constFoldF, Kinded::Kind::SplatNodeKind));
4608 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::MaxNodeKind));
4609 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::MinNodeKind));
4610 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::ConvertToNodeKind));
4611 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::TransposeNodeKind));
4612
4613 // Skip optimizations -- we just want to run them as is (otherwise we'll
4614 // constant fold them inside the optimization pipeline).
4615 cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldF);
4616 cctx_.optimizationOpts.onlyLowerFuns.insert(F_);
4617 cctx_.optimizationOpts.onlyLowerFuns.insert(optimizedF_);
4618
4619 // Don't strip the module as we want to compare the Constant values below.
4620 EE_.setSkipModuleStrip(true);
4621
4622 EE_.compile(cctx_);
4623 alreadyCompiled_ = true;
4624
4625 bindings_.allocate(mod_.getPlaceholders());
4626
4627 // Run the constant folding chain to check that we have the same constant used
4628 // by the optimized Function.
4629 EE_.run(bindings_, constFoldF->getName());
4630 Tensor *rerunT = bindings_.get(SN->getPlaceholder());
4631 ASSERT_TRUE(rerunT);
4632 auto optimizedConstants = optimizedF_->findConstants();
4633 ASSERT_EQ(optimizedConstants.size(), 1);
4634 EXPECT_TRUE(
4635 (*optimizedConstants.begin())->getPayload().isEqual(*rerunT, 0.f));
4636
4637 // Remove the temporary constant folding Functions and their Placeholders.
4638 cleanupConstantFolding(mod_, record, &bindings_);
4639
4640 // Now compile/run/compare F_ and optimizedF_.
4641 checkNumericalEquivalence(0.f);
4642}
4643
/// Test that we correctly record two constant folding subgraphs, each with
4645/// a single output.
4646TEST_F(GraphOptz, constantFoldWithRecordMultiChain) {
4647 Placeholder *I =
4648 mod_.createPlaceholder(ElemKind::Float16Ty, {2, 100}, "input",
4649 /* isTrainable */ false);
4650 Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weight");
4651 ClipNode *clipW = F_->createClip("clip", W, -5.f, 5.f);
4652 ConvertToNode *convertW =
4653 F_->createConvertTo("conv", clipW, ElemKind::Float16Ty);
4654 TransposeNode *transposeW =
4655 F_->createTranspose("transpose", convertW, {1, 0});
4656 MatMulNode *MM = F_->createMatMul("matmul", I, transposeW);
4657 SaveNode *saveMM = F_->createSave("save_mm", MM);
4658 Placeholder *MMP = saveMM->getPlaceholder();
4659 bindings_.allocate(MMP);
4660
4661 SigmoidNode *sigmoidW = F_->createSigmoid("sig", convertW);
4662 SaveNode *saveSig = F_->createSave("save_sig", sigmoidW);
4663 Placeholder *sigP = saveSig->getPlaceholder();
4664 bindings_.allocate(sigP);
4665
4666 ASSERT_TRUE(F_->verify());
4667
4668 Tensor *IT = bindings_.allocate(I);
4669 IT->getHandle<float16_t>().randomize(-10, 10, mod_.getPRNG());
4670 W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG());
4671
4672 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4673
4674 ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_);
4675
4676 runDCEPass(optimizedF_, cctx_);
4677
4678 ASSERT_EQ(record.size(), 2);
4679 SaveNode *sigSN = record.begin()->second;
4680 SaveNode *transSN = std::next(record.begin())->second;
4681 if (llvm::isa<SigmoidNode>(transSN->getInput())) {
4682 std::swap(sigSN, transSN);
4683 }
4684
4685 Function *constFoldSig = sigSN->getParent();
4686 Function *constFoldTrans = transSN->getParent();
4687
4688 // Expect to find a chain of Nodes based on Nodes above. Note that the clip is
4689 // lowered for the Interpreter backend which performs constant folding.
4690 EXPECT_EQ(2, countNodeKind(constFoldTrans, Kinded::Kind::SplatNodeKind));
4691 EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::MaxNodeKind));
4692 EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::MinNodeKind));
4693 EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::ConvertToNodeKind));
4694 EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::TransposeNodeKind));
4695
4696 EXPECT_EQ(2, countNodeKind(constFoldSig, Kinded::Kind::SplatNodeKind));
4697 EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::MaxNodeKind));
4698 EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::MinNodeKind));
4699 EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::ConvertToNodeKind));
4700 EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::SigmoidNodeKind));
4701
4702 // Skip optimizations -- we just want to run them as is (otherwise we'll
4703 // constant fold them inside the optimization pipeline).
4704 cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldTrans);
4705 cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldSig);
4706 cctx_.optimizationOpts.onlyLowerFuns.insert(F_);
4707 cctx_.optimizationOpts.onlyLowerFuns.insert(optimizedF_);
4708
4709 // Don't strip the module as we want to compare the Constant values below.
4710 EE_.setSkipModuleStrip(true);
4711
4712 EE_.compile(cctx_);
4713 alreadyCompiled_ = true;
4714
4715 bindings_.allocate(mod_.getPlaceholders());
4716
4717 // Run the constant folding chain to check that we have the same constant used
4718 // by the optimized Function.
4719 EE_.run(bindings_, constFoldTrans->getName());
4720 EE_.run(bindings_, constFoldSig->getName());
4721
  // Find the correct PHs for each of the constant foldings we performed.
4723 Tensor *rerunTransT = bindings_.get(transSN->getPlaceholder());
4724 Tensor *rerunSigT = bindings_.get(sigSN->getPlaceholder());
4725 ASSERT_TRUE(rerunTransT);
4726 ASSERT_TRUE(rerunSigT);
4727
4728 auto optimizedConstants = optimizedF_->findConstants();
4729 ASSERT_EQ(optimizedConstants.size(), 2);
4730 Constant *transC = *optimizedConstants.begin();
4731 Constant *sigC = *std::next(optimizedConstants.begin());
4732 // If we have the constants backwards then swap them. Note that we know
4733 // sigC must be directly saved, while transC is input to a MatMulNode.
4734 ASSERT_EQ(transC->getNumUsers(), 1);
4735 if (llvm::isa<SaveNode>(transC->getUsers().begin()->getUser())) {
4736 std::swap(transC, sigC);
4737 }
4738 EXPECT_TRUE(transC->getPayload().isEqual(*rerunTransT, 0.f));
4739 EXPECT_TRUE(sigC->getPayload().isEqual(*rerunSigT, 0.f));
4740
4741 // Remove the temporary constant folding Functions and their Placeholders.
4742 cleanupConstantFolding(mod_, record, &bindings_);
4743
4744 // Now compile/run/compare F_ and optimizedF_.
4745 checkNumericalEquivalence(0.f);
4746}
4747
4748/// Test that we correctly record a single constant folding subgraph that has
4749/// two outputs.
4750TEST_F(GraphOptz, constantFoldWithRecordSingleChainMultiOutput) {
4751 Constant *W = mod_.createConstant(ElemKind::FloatTy, {100}, "weight");
4752 SigmoidNode *sigmoidW = F_->createSigmoid("sig", W);
4753 ConvertToNode *convertW =
4754 F_->createConvertTo("conv", sigmoidW, ElemKind::Float16Ty);
4755 TopKNode *TK = F_->createTopK("topk", convertW, 5);
4756
4757 SaveNode *indicesSave = F_->createSave("save_indices", TK->getIndices());
4758 Placeholder *indicesP = indicesSave->getPlaceholder();
4759 bindings_.allocate(indicesP);
4760
4761 Placeholder *I = mod_.createPlaceholder(ElemKind::Float16Ty, {5}, "input",
4762 /* isTrainable */ false);
4763 AddNode *add = F_->createAdd("add", I, TK->getValues());
4764 SaveNode *addSave = F_->createSave("save_add", add);
4765 Placeholder *addP = addSave->getPlaceholder();
4766 bindings_.allocate(addP);
4767
4768 ASSERT_TRUE(F_->verify());
4769
4770 Tensor *IT = bindings_.allocate(I);
4771 IT->getHandle<float16_t>().randomize(-10, 10, mod_.getPRNG());
4772 W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG());
4773
4774 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4775
4776 ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_);
4777
4778 runDCEPass(optimizedF_, cctx_);
4779
4780 ASSERT_EQ(record.size(), 2);
4781 SaveNode *indicesSN = record.begin()->second;
4782 SaveNode *addSN = std::next(record.begin())->second;
4783
  // Find the correct PHs for each of the constant foldings we performed.
4785 if (indicesSN->getInput().getResNo() != TopKNode::IndicesIdx) {
4786 std::swap(indicesSN, addSN);
4787 }
4788
4789 // Expect that the two constants that we folded are from the same Function,
4790 // and that the two saves use the two different outputs from a topk.
4791 EXPECT_EQ(indicesSN->getParent(), addSN->getParent());
4792 ASSERT_TRUE(llvm::isa<TopKNode>(addSN->getInput()));
4793 ASSERT_TRUE(llvm::isa<TopKNode>(indicesSN->getInput()));
4794 EXPECT_EQ(addSN->getInput().getNode(), indicesSN->getInput().getNode());
4795
4796 Function *constFoldF = addSN->getParent();
4797
4798 // Expect to find a chain of Nodes based on Nodes above.
4799 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::TopKNodeKind));
4800 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::SigmoidNodeKind));
4801 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::ConvertToNodeKind));
4802
4803 // Skip optimizations -- we just want to run them as is (otherwise we'll
4804 // constant fold them inside the optimization pipeline).
4805 cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldF);
4806 cctx_.optimizationOpts.onlyLowerFuns.insert(F_);
4807 cctx_.optimizationOpts.onlyLowerFuns.insert(optimizedF_);
4808
4809 // Don't strip the module as we want to compare the Constant values below.
4810 EE_.setSkipModuleStrip(true);
4811
4812 EE_.compile(cctx_);
4813 alreadyCompiled_ = true;
4814
4815 bindings_.allocate(mod_.getPlaceholders());
4816
4817 // Run the constant folding chain to check that we have the same constant used
4818 // by the optimized Function.
4819 EE_.run(bindings_, constFoldF->getName());
4820
4821 Tensor *rerunAddT = bindings_.get(addSN->getPlaceholder());
4822 Tensor *rerunIndicesT = bindings_.get(indicesSN->getPlaceholder());
4823 ASSERT_TRUE(rerunAddT);
4824 ASSERT_TRUE(rerunIndicesT);
4825
4826 auto optimizedConstants = optimizedF_->findConstants();
4827 ASSERT_EQ(optimizedConstants.size(), 2);
4828 Constant *addC = *optimizedConstants.begin();
4829 Constant *indicesC = *std::next(optimizedConstants.begin());
4830
4831 // If we have the constants backwards then swap them. Note that we know
4832 // indicesC must be directly saved, while addC is input to an AddNode.
4833 ASSERT_EQ(addC->getNumUsers(), 1);
4834 if (llvm::isa<SaveNode>(addC->getUsers().begin()->getUser())) {
4835 std::swap(addC, indicesC);
4836 }
4837 EXPECT_TRUE(addC->getPayload().isEqual(*rerunAddT, 0.f));
4838 EXPECT_TRUE(indicesC->getPayload().isEqual(*rerunIndicesT, 0.f));
4839
4840 // Remove the temporary constant folding Functions and their Placeholders.
4841 cleanupConstantFolding(mod_, record, &bindings_);
4842
4843 // Now compile/run/compare F_ and optimizedF_.
4844 checkNumericalEquivalence(0.f);
4845}
4846
/// Test that the recorded constant folding Function retains all of its ops,
/// i.e. they are not optimized away when the constant folding Function itself
/// is optimized.
4850TEST_F(GraphOptz, constantFoldOnlyLower) {
4851 Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weight");
4852 ConvertToNode *convertW = F_->createConvertTo("conv", W, ElemKind::Float16Ty);
4853 SaveNode *save = F_->createSave("save", convertW);
4854 Placeholder *O = save->getPlaceholder();
4855 bindings_.allocate(O);
4856
4857 ASSERT_TRUE(F_->verify());
4858
4859 W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG());
4860
4861 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4862
4863 ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_);
4864
4865 ASSERT_EQ(record.size(), 1);
4866 SaveNode *SN = record.begin()->second;
4867 Function *constFoldF = SN->getParent();
4868
4869 // Expect to find a Save and the ConvertTo still, i.e. it shouldn't have been
4870 // folded into the Constant as part of the OptimizeConversions pass.
4871 EXPECT_EQ(2, constFoldF->getNodes().size());
4872 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::ConvertToNodeKind));
4873 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::SaveNodeKind));
4874}
4875
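/// Test constant folding applied to a whole Function through the regular
/// optimization pipeline, including a TopK node with multiple results.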
4876TEST_F(GraphOptz, constantFoldWholeFunction) {
4877 auto *const1 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const1");
4878 auto *const2 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const2");
4879 auto *const3 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const3");
4880 auto *const4 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const4");
4881 auto *ph1 = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "input1",
4882 /* isTrainable */ false);
4883 setConstValue(const1, 1.0f);
4884 setConstValue(const2, 2.0f);
4885 setConstValue(const3, 3.0f);
4886 setConstValue(const4, 4.0f);
4887 auto *splat2 = F_->createSplat(
4888 "splat2", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 2.0f);
  auto *splat3 = F_->createSplat(
      "splat3", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 3.0f);
  auto *splat4 = F_->createSplat(
      "splat4", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 4.0f);
4893
4894 auto *add1 = F_->createAdd("add", const1, const2);
4895 auto *mul1 = F_->createMul("mul1", add1, splat2);
4896 auto *mul2 = F_->createMul("mul2", mul1, splat3);
4897 auto *sub = F_->createSub("sub", mul2, const3);
4898 auto *add2 = F_->createAdd("add2", sub, const4);
4899 auto *mul3 = F_->createMul("mul3", add2, splat4);
4900 // Check compile-time constant folding for nodes with multiple results.
4901 auto *topK = F_->createTopK("topK", mul3, 2);
4902 auto *SN1_0 = F_->createSave("save", topK->getValues());
4903 auto *SN1_1 = F_->createSave("save", topK->getIndices());
4904 auto *add3 = F_->createAdd("add", const1, ph1);
4905 auto *SN2 = F_->createSave("save", add3);
4906
4907 // Perform constant folding for a whole function.
4908 ::glow::optimize(F_, CompilationMode::Infer);
4909
4910 EXPECT_EQ(F_->getNodes().size(), 4);
  // Second save should be unaffected, as its input depends on a Placeholder.
4912 EXPECT_FALSE(llvm::isa<Constant>(SN2->getInput()));
4913 // First save should have been constant folded.
4914 EXPECT_TRUE(llvm::isa<Constant>(SN1_0->getInput()));
4915 EXPECT_TRUE(llvm::isa<Constant>(SN1_1->getInput()));
4916 Constant *C = llvm::dyn_cast<Constant>(SN1_0->getInput());
4917 auto CH = C->getHandle();
4918 // The expected result should be: (((1+2) * 2 * 3 - 3) + 4) * 4 = 76
4919 EXPECT_EQ(CH.at({0, 0}), 76.0f);
4920 EXPECT_EQ(CH.at({0, 1}), 76.0f);
4921 EXPECT_EQ(CH.at({1, 0}), 76.0f);
4922 EXPECT_EQ(CH.at({1, 1}), 76.0f);
4923}
4924
/// Test constant folding for operators which are lowered in the Interpreter
/// backend.
4927TEST_F(GraphOptz, constantFoldWithLowering) {
4928 auto *input = mod_.createConstant(ElemKind::FloatTy, {1, 6}, "input");
4929 input->getHandle() = {5, 4, 3, 2, 1, 0};
4930 auto *TN = F_->createTile("tile", input, 5, 0);
4931 auto *SN = F_->createSave("ret", TN);
4932
4933 // Perform constant folding.
4934 EXPECT_EQ(F_->getNodes().size(), 2);
4935 ::glow::optimize(F_, CompilationMode::Infer);
4936
4937 // Tile with its input should be folded into a single Constant node.
4938 EXPECT_EQ(F_->getNodes().size(), 1);
4939 ASSERT_TRUE(llvm::isa<Constant>(SN->getInput()));
4940}
4941
4942/// Test Splitting FC into multiple FCs.
4943TEST_F(GraphOptz, SplitFCIntoMultipleOps) {
4944 auto *input =
4945 mod_.createPlaceholder(ElemKind::FloatTy, {2, 32}, "input", false);
4946 bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
4947 mod_.getPRNG());
4948 auto *weights = mod_.createConstant(ElemKind::FloatTy, {32, 850}, "weights");
4949 weights->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
4950 auto *bias = mod_.createConstant(ElemKind::FloatTy, {850}, "bias");
4951 bias->getHandle().randomize(0.0, 0.5, mod_.getPRNG());
4952 auto *output =
4953 mod_.createPlaceholder(ElemKind::FloatTy, {2, 850}, "output", false);
4954 bindings_.allocate(output);
4955
4956 auto *fc = F_->createFullyConnected("fc", input, weights, bias);
4957 auto *save = F_->createSave("save", fc, output);
4958
4959 ::glow::optimize(F_, CompilationMode::Infer);
4960
  // This is F_ but without the vertical FC split transformation below.
4962 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4963
4964 EXPECT_TRUE(::glow::executeVerticalFCWeightsSplit(F_,
4965 /*numOfChunks*/ 12,
4966 /*minKToSplit*/ 800));
4967 runDCEPass(F_, cctx_);
4968
4969 // 24 Slices: 12 from bias and 12 from weights.
4970 EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
4971
4972 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
4973
4974 // 12 newly created FCs.
4975 EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
4976
4977 auto *concatNode = llvm::dyn_cast<ConcatNode>(save->getInput());
4978 ASSERT_TRUE(concatNode);
4979 // 12 FCs are connected to the concat node.
4980 EXPECT_EQ(12, concatNode->getInputs().size());
4981
  // Check all split FCs.
4983 for (unsigned i = 0; i < 12; ++i) {
4984 auto *fc = llvm::dyn_cast<FullyConnectedNode>(concatNode->getNthInput(i));
4985 ASSERT_TRUE(fc);
    // 2 x 71 for the first 11 FCs and 2 x 69 for the last one.
4987 if (i == 11) {
4988 EXPECT_TRUE(fc->getResult().dims().equals({2, 69}));
4989 EXPECT_TRUE(fc->getBias().dims().equals({69}));
4990 EXPECT_TRUE(fc->getWeights().dims().equals({32, 69}));
4991 } else {
4992 EXPECT_TRUE(fc->getResult().dims().equals({2, 71}));
4993 EXPECT_TRUE(fc->getBias().dims().equals({71}));
4994 EXPECT_TRUE(fc->getWeights().dims().equals({32, 71}));
4995 }
4996 }
4997
4998 checkNumericalEquivalence();
4999}
5000
/// Test splitting FCs into multiple FCs using model parallelism.
5002TEST_F(GraphOptz, ParallelizeGraph_FC_ModelParallel) {
5003 auto *input =
5004 mod_.createPlaceholder(ElemKind::FloatTy, {8, 3}, "input", false);
5005 bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
5006 mod_.getPRNG());
5007 auto *weights1 = mod_.createConstant(ElemKind::FloatTy, {3, 150}, "weights");
5008 weights1->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
5009 auto *bias1 = mod_.createConstant(ElemKind::FloatTy, {150}, "bias");
5010 bias1->getHandle().randomize(0.0, 0.5, mod_.getPRNG());
5011 auto *weights2 =
5012 mod_.createConstant(ElemKind::FloatTy, {150, 150}, "weights");
5013 weights2->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
5014 auto *bias2 = mod_.createConstant(ElemKind::FloatTy, {150}, "bias");
5015 bias2->getHandle().randomize(0.0, 0.5, mod_.getPRNG());
5016 auto *output =
5017 mod_.createPlaceholder(ElemKind::FloatTy, {8, 150}, "output", false);
5018 bindings_.allocate(output);
5019
5020 auto *fc1 = F_->createFullyConnected("fc1", input, weights1, bias1);
5021 auto *relu1 = F_->createRELU("relu1", fc1);
5022
5023 auto *fc2 = F_->createFullyConnected("fc2", relu1, weights2, bias2);
5024 auto *relu2 = F_->createRELU("relu2", fc2);
5025 F_->createSave("save", relu2, output);
5026
5027 ::glow::optimize(F_, CompilationMode::Infer);
5028
5029 // This is F_ but without the parallel transformation below.
5030 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5031
5032 // Perform parallel transformation on F_.
5033 llvm::DenseMap<Node *, size_t> numChunks;
5034 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5035 numChunks[fc1] = 2;
5036 numChunks[relu1] = 2;
5037 numChunks[fc2] = 2;
5038 numChunks[relu2] = 2;
5039 parOpts[fc1] = ParallelTransformKind::Model;
5040 parOpts[relu1] = ParallelTransformKind::Model;
5041 parOpts[fc2] = ParallelTransformKind::Model;
5042 parOpts[relu2] = ParallelTransformKind::Model;
5043 std::unordered_map<Node *, ConcatNode *> replacedMap;
5044 ASSIGN_VALUE_OR_FAIL_TEST(replacedMap,
5045 ::glow::parallelizeOps(F_, numChunks, parOpts));
5046 EXPECT_EQ(replacedMap.size(), parOpts.size());
5047
5048 runDCEPass(F_, cctx_);
5049
5050 EXPECT_EQ(4, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
5051 EXPECT_EQ(4, countNodeKind(F_, Kinded::Kind::ReluNodeKind));
5052
5053 checkNumericalEquivalence();
5054}
5055
/// Test splitting an FC into multiple FCs, special case of splitting 866 into
/// 8 chunks with an alignment of 64; this alignment corner case should produce
/// only 7 splits.
5059TEST_F(GraphOptz, ParallelizeGraph_FC_ModelParallel_Split866by8) {
5060 auto *input =
5061 mod_.createPlaceholder(ElemKind::FloatTy, {8, 3}, "input", false);
5062 bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
5063 mod_.getPRNG());
5064 auto *weights1 = mod_.createConstant(ElemKind::FloatTy, {3, 866}, "weights");
5065 weights1->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
5066 auto *bias1 = mod_.createConstant(ElemKind::FloatTy, {866}, "bias");
5067 bias1->getHandle().randomize(0.0, 0.5, mod_.getPRNG());
5068 auto *output =
5069 mod_.createPlaceholder(ElemKind::FloatTy, {8, 866}, "output", false);
5070 bindings_.allocate(output);
5071
5072 auto *fc1 = F_->createFullyConnected("fc1", input, weights1, bias1);
5073 auto *relu1 = F_->createRELU("relu1", fc1);
5074
5075 F_->createSave("save", relu1, output);
5076
5077 ::glow::optimize(F_, CompilationMode::Infer);
5078
5079 // This is F_ but without the parallel transformation below.
5080 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5081
5082 // Perform parallel transformation on F_.
5083 llvm::DenseMap<Node *, size_t> numChunks;
5084 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5085 numChunks[fc1] = 8;
5086 numChunks[relu1] = 8;
5087 parOpts[fc1] = ParallelTransformKind::Model;
5088 parOpts[relu1] = ParallelTransformKind::Model;
5089 std::unordered_map<Node *, ConcatNode *> replacedMap;
5090 ASSIGN_VALUE_OR_FAIL_TEST(
5091 replacedMap,
5092 ::glow::parallelizeOps(F_, numChunks, parOpts, /*numOfChunks*/ 1,
5093 /*modelParallelSplitAlignment*/ 64));
5094 EXPECT_EQ(replacedMap.size(), parOpts.size());
5095
5096 runDCEPass(F_, cctx_);
5097
5098 EXPECT_EQ(7, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
5099 EXPECT_EQ(7, countNodeKind(F_, Kinded::Kind::ReluNodeKind));
5100
  // Check all split FCs.
5102 auto *concatNode = replacedMap[fc1];
5103 for (unsigned i = 0; i < 7; ++i) {
5104 auto *fc = llvm::dyn_cast<FullyConnectedNode>(concatNode->getNthInput(i));
5105 ASSERT_TRUE(fc);
    // 8 x 128 for the first 6 FCs and 8 x 98 for the last one.
5107 if (i == 6) {
5108 EXPECT_TRUE(fc->getResult().dims().equals({8, 98}));
5109 EXPECT_TRUE(fc->getBias().dims().equals({98}));
5110 EXPECT_TRUE(fc->getWeights().dims().equals({3, 98}));
5111 } else {
5112 EXPECT_TRUE(fc->getResult().dims().equals({8, 128}));
5113 EXPECT_TRUE(fc->getBias().dims().equals({128}));
5114 EXPECT_TRUE(fc->getWeights().dims().equals({3, 128}));
5115 }
5116 }
5117
5118 checkNumericalEquivalence();
5119}
5120
/// Test splitting an FC into multiple FCs, special case of splitting 140 into
/// 3 chunks with an alignment of 64. Should split as 64, 64, 12.
5123TEST_F(GraphOptz, ParallelizeGraph_FC_ModelParallel_Split140by3) {
5124 auto *input =
5125 mod_.createPlaceholder(ElemKind::FloatTy, {8, 3}, "input", false);
5126 bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
5127 mod_.getPRNG());
5128 auto *weights1 = mod_.createConstant(ElemKind::FloatTy, {3, 140}, "weights");
5129 weights1->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
5130 auto *bias1 = mod_.createConstant(ElemKind::FloatTy, {140}, "bias");
5131 bias1->getHandle().randomize(0.0, 0.5, mod_.getPRNG());
5132 auto *output =
5133 mod_.createPlaceholder(ElemKind::FloatTy, {8, 140}, "output", false);
5134 bindings_.allocate(output);
5135
5136 auto *fc1 = F_->createFullyConnected("fc1", input, weights1, bias1);
5137 auto *relu1 = F_->createRELU("relu1", fc1);
5138
5139 F_->createSave("save", relu1, output);
5140
5141 ::glow::optimize(F_, CompilationMode::Infer);
5142
5143 // This is F_ but without the parallel transformation below.
5144 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5145
5146 // Perform parallel transformation on F_.
5147 llvm::DenseMap<Node *, size_t> numChunks;
5148 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5149 numChunks[fc1] = 3;
5150 parOpts[fc1] = ParallelTransformKind::Model;
5151 std::unordered_map<Node *, ConcatNode *> replacedMap;
5152 ASSIGN_VALUE_OR_FAIL_TEST(
5153 replacedMap,
5154 ::glow::parallelizeOps(F_, numChunks, parOpts, /*numOfChunks*/ 1,
5155 /*modelParallelSplitAlignment*/ 64));
5156 EXPECT_EQ(replacedMap.size(), parOpts.size());
5157 runDCEPass(F_, cctx_);
5158 EXPECT_EQ(3, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
5159
  // Check all split FCs.
5161 auto *concatNode = replacedMap[fc1];
5162 auto *fc_split0 =
5163 llvm::dyn_cast<FullyConnectedNode>(concatNode->getNthInput(0));
5164 auto *fc_split1 =
5165 llvm::dyn_cast<FullyConnectedNode>(concatNode->getNthInput(1));
5166 auto *fc_split2 =
5167 llvm::dyn_cast<FullyConnectedNode>(concatNode->getNthInput(2));
5168 ASSERT_TRUE(fc_split0);
5169 ASSERT_TRUE(fc_split1);
5170 ASSERT_TRUE(fc_split2);
5171 EXPECT_TRUE(fc_split0->getResult().dims().equals({8, 64}));
5172 EXPECT_TRUE(fc_split0->getBias().dims().equals({64}));
5173 EXPECT_TRUE(fc_split0->getWeights().dims().equals({3, 64}));
5174 EXPECT_TRUE(fc_split1->getResult().dims().equals({8, 64}));
5175 EXPECT_TRUE(fc_split1->getBias().dims().equals({64}));
5176 EXPECT_TRUE(fc_split1->getWeights().dims().equals({3, 64}));
5177 EXPECT_TRUE(fc_split2->getResult().dims().equals({8, 12}));
5178 EXPECT_TRUE(fc_split2->getBias().dims().equals({12}));
5179 EXPECT_TRUE(fc_split2->getWeights().dims().equals({3, 12}));
5180
5181 checkNumericalEquivalence();
5182}
5183
5184/// Test Splitting MatMul into multiple MatMuls
5185TEST_F(GraphOptz, SplitMatMulIntoMultipleOps_Data) {
5186 auto *input1 =
5187 mod_.createPlaceholder(ElemKind::FloatTy, {12, 32}, "input1", false);
5188 bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
5189 mod_.getPRNG());
5190 auto *input2 =
5191 mod_.createPlaceholder(ElemKind::FloatTy, {32, 32}, "input2", false);
5192 bindings_.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0,
5193 mod_.getPRNG());
5194 auto *output =
5195 mod_.createPlaceholder(ElemKind::FloatTy, {12, 32}, "output", false);
5196 bindings_.allocate(output);
5197
5198 auto *mm = F_->createMatMul("mm", input1, input2);
5199 auto *save = F_->createSave("save", mm, output);
5200
5201 ::glow::optimize(F_, CompilationMode::Infer);
5202
5203 // This is F_ but without the parallel transformation below.
5204 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5205
5206 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5207 parOpts[mm] = ParallelTransformKind::Data;
5208
5209 std::unordered_map<Node *, ConcatNode *> replacedMap;
5210 ASSIGN_VALUE_OR_FAIL_TEST(
5211 replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(),
5212 parOpts, 12));
5213 EXPECT_EQ(replacedMap.size(), parOpts.size());
5214 runDCEPass(F_, cctx_);
5215
5216 // 12 Slices from LHS
5217 EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
5218
5219 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
5220
5221 // 12 newly created MatMuls.
5222 EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::MatMulNodeKind));
5223
5224 auto *concatNode = llvm::dyn_cast<ConcatNode>(save->getInput());
5225 ASSERT_TRUE(concatNode);
  // 12 MatMuls are connected to the concat node.
5227 EXPECT_EQ(12, concatNode->getInputs().size());
5228
5229 for (unsigned i = 0; i < 12; ++i) {
5230 auto *mmInput = llvm::dyn_cast<MatMulNode>(concatNode->getNthInput(i));
5231 ASSERT_TRUE(mmInput);
5232 EXPECT_TRUE(mmInput->getResult().dims().equals({1, 32}));
5233 }
5234
5235 checkNumericalEquivalence();
5236}
5237
5238/// Test Splitting MatMul into multiple MatMuls
5239TEST_F(GraphOptz, SplitMatMulIntoMultipleOps_Model) {
5240 auto *input1 =
5241 mod_.createPlaceholder(ElemKind::FloatTy, {12, 48}, "input1", false);
5242 bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
5243 mod_.getPRNG());
5244 auto *input2 =
5245 mod_.createPlaceholder(ElemKind::FloatTy, {48, 48}, "input2", false);
5246 bindings_.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0,
5247 mod_.getPRNG());
5248 auto *output =
5249 mod_.createPlaceholder(ElemKind::FloatTy, {12, 48}, "output", false);
5250 bindings_.allocate(output);
5251
5252 auto *mm = F_->createMatMul("mm", input1, input2);
5253 auto *save = F_->createSave("save", mm, output);
5254
5255 ::glow::optimize(F_, CompilationMode::Infer);
5256
5257 // This is F_ but without the parallel transformation below.
5258 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5259
5260 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5261 parOpts[mm] = ParallelTransformKind::Model;
5262
5263 std::unordered_map<Node *, ConcatNode *> replacedMap;
5264 ASSIGN_VALUE_OR_FAIL_TEST(
5265 replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(),
5266 parOpts, 12));
5267 EXPECT_EQ(replacedMap.size(), parOpts.size());
5268 runDCEPass(F_, cctx_);
5269
5270 // 12 Slices from RHS
5271 EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
5272
5273 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
5274
5275 // 12 newly created MatMuls.
5276 EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::MatMulNodeKind));
5277
5278 auto *concatNode = llvm::dyn_cast<ConcatNode>(save->getInput());
5279 ASSERT_TRUE(concatNode);
  // 12 MatMuls are connected to the concat node.
5281 EXPECT_EQ(12, concatNode->getInputs().size());
5282
5283 for (unsigned i = 0; i < 12; ++i) {
5284 auto *mmInput = llvm::dyn_cast<MatMulNode>(concatNode->getNthInput(i));
5285 ASSERT_TRUE(mmInput);
5286 EXPECT_TRUE(mmInput->getResult().dims().equals({12, 4}));
5287 }
5288
5289 checkNumericalEquivalence();
5290}
5291
5292/// Test Splitting Add into multiple Adds.
5293TEST_F(GraphOptz, ParallelizeGraph_Add) {
5294 auto *input1 =
5295 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input1", false);
5296 bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
5297 mod_.getPRNG());
5298 auto *input2 =
5299 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input2", false);
5300 bindings_.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0,
5301 mod_.getPRNG());
5302 auto *output =
5303 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output", false);
5304 bindings_.allocate(output);
5305
5306 auto *add1 = F_->createAdd("add1", input1, input2);
5307 auto *add2 = F_->createAdd("add2", add1, add1);
5308 F_->createSave("save", add2, output);
5309
5310 ::glow::optimize(F_, CompilationMode::Infer);
5311
5312 // This is F_ but without the parallel transformation below.
5313 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5314
5315 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5316 parOpts[add1] = ParallelTransformKind::Data;
5317
5318 std::unordered_map<Node *, ConcatNode *> replacedMap;
5319 ASSIGN_VALUE_OR_FAIL_TEST(
5320 replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(),
5321 parOpts, 12));
5322 EXPECT_EQ(replacedMap.size(), parOpts.size());
5323 runDCEPass(F_, cctx_);
5324
5325 // We now have 12 Adds from add1, as well as the original add2 which is
5326 // unchanged.
5327 EXPECT_EQ(13, countNodeKind(F_, Kinded::Kind::AddNodeKind));
5328
  // Both inputs of each of the 12 Adds are sliced.
5330 EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
5331
5332 // One concat to bring all of the parallelized sliced Adds together.
5333 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
5334
5335 checkNumericalEquivalence();
5336}
5337
5338/// Test Splitting Add into multiple Adds along different axes.
5339static void testParallelizeGraphAddModel(PlaceholderBindings &bindings,
5340 Module &mod, Function *F,
5341 Function *&optF,
5342 CompilationContext &cctx,
5343 ParallelTransformKind parKind) {
5344 auto *input1 = mod.createPlaceholder(ElemKind::FloatTy, {16, 17, 18, 19, 20},
5345 "input1", false);
5346 bindings.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
5347 mod.getPRNG());
5348 auto *input2 = mod.createPlaceholder(ElemKind::FloatTy, {16, 17, 18, 19, 20},
5349 "input2", false);
5350 bindings.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0,
5351 mod.getPRNG());
5352 auto *output = mod.createPlaceholder(ElemKind::FloatTy, {16, 17, 18, 19, 20},
5353 "output", false);
5354 bindings.allocate(output);
5355
5356 auto *add1 = F->createAdd("add1", input1, input2);
5357 auto *add2 = F->createAdd("add2", add1, add1);
5358 F->createSave("save", add2, output);
5359
5360 ::glow::optimize(F, CompilationMode::Infer);
5361
  // This is F but without the parallel transformation below.
5363 optF = F->clone(F->getName().str() + "_optimized");
5364
5365 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5366 parOpts[add1] = parKind;
5367
5368 std::unordered_map<Node *, ConcatNode *> replacedMap;
5369 ASSIGN_VALUE_OR_FAIL_TEST(
5370 replacedMap,
5371 ::glow::parallelizeOps(F, llvm::DenseMap<Node *, size_t>(), parOpts, 12));
5372 EXPECT_EQ(replacedMap.size(), parOpts.size());
5373 runDCEPass(F, cctx);
5374
5375 // We now have 12 Adds from add1, as well as the original add2 which is
5376 // unchanged.
5377 EXPECT_EQ(13, countNodeKind(F, Kinded::Kind::AddNodeKind));
5378
  // Both inputs of each of the 12 Adds are sliced.
5380 EXPECT_EQ(24, countNodeKind(F, Kinded::Kind::SliceNodeKind));
5381
5382 // One concat to bring all of the parallelized sliced Adds together.
5383 EXPECT_EQ(1, countNodeKind(F, Kinded::Kind::ConcatNodeKind));
5384}
5385
5386TEST_F(GraphOptz, ParallelizeGraph_Add_Model_Axis1) {
5387 testParallelizeGraphAddModel(bindings_, mod_, F_, optimizedF_, cctx_,
5388 ParallelTransformKind::Model_Axis1);
5389 checkNumericalEquivalence(0.f);
5390}
5391
5392TEST_F(GraphOptz, ParallelizeGraph_Add_Model_Axis3) {
5393 testParallelizeGraphAddModel(bindings_, mod_, F_, optimizedF_, cctx_,
5394 ParallelTransformKind::Model_Axis3);
5395 checkNumericalEquivalence(0.f);
5396}
5397
5398TEST_F(GraphOptz, ParallelizeGraph_Add_Model_Axis4) {
5399 testParallelizeGraphAddModel(bindings_, mod_, F_, optimizedF_, cctx_,
5400 ParallelTransformKind::Model_Axis4);
5401 checkNumericalEquivalence(0.f);
5402}
5403
5404/// Test Splitting Sub into multiple Subs.
5405TEST_F(GraphOptz, ParallelizeGraph_Sub) {
5406 auto *input1 =
5407 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input1", false);
5408 bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
5409 mod_.getPRNG());
5410 auto *input2 =
5411 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input2", false);
5412 bindings_.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0,
5413 mod_.getPRNG());
5414 auto *output =
5415 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output", false);
5416 bindings_.allocate(output);
5417
5418 auto *sub1 = F_->createSub("sub1", input1, input2);
5419 auto *sub2 = F_->createSub("sub2", sub1, sub1);
5420 F_->createSave("save", sub2, output);
5421
5422 ::glow::optimize(F_, CompilationMode::Infer);
5423
5424 // This is F_ but without the parallel transformation below.
5425 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5426
5427 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5428 parOpts[sub1] = ParallelTransformKind::Data;
5429
5430 std::unordered_map<Node *, ConcatNode *> replacedMap;
5431 ASSIGN_VALUE_OR_FAIL_TEST(
5432 replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(),
5433 parOpts, 12));
5434 EXPECT_EQ(replacedMap.size(), parOpts.size());
5435 runDCEPass(F_, cctx_);
5436
5437 // We now have 12 Subs from sub1, as well as the original sub2 which is
5438 // unchanged.
5439 EXPECT_EQ(13, countNodeKind(F_, Kinded::Kind::SubNodeKind));
5440
5441  // Each input of the 12 Subs is sliced.
5442 EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
5443
5444 // One concat to bring all of the parallelized sliced Subs together.
5445 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
5446
5447 checkNumericalEquivalence();
5448}
5449
5450/// Test Splitting Pow into multiple Pows.
5451TEST_F(GraphOptz, ParallelizeGraph_Pow) {
5452 auto *input1 =
5453 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input1", false);
5454 bindings_.allocate(input1)->getHandle<float>().randomize(1.0, 2.0,
5455 mod_.getPRNG());
5456 auto *input2 =
5457 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input2", false);
5458 bindings_.allocate(input2)->getHandle<float>().randomize(0.0, 5.0,
5459 mod_.getPRNG());
5460 auto *output =
5461 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output", false);
5462 bindings_.allocate(output);
5463
5464 auto *Pow1 = F_->createPow("Pow1", input1, input2);
5465 F_->createSave("save", Pow1, output);
5466
5467 ::glow::optimize(F_, CompilationMode::Infer);
5468
5469 // This is F_ but without the parallel transformation below.
5470 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5471
5472 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5473 parOpts[Pow1] = ParallelTransformKind::Data;
5474
5475 std::unordered_map<Node *, ConcatNode *> replacedMap;
5476 ASSIGN_VALUE_OR_FAIL_TEST(
5477 replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(),
5478 parOpts, 12));
5479 EXPECT_EQ(replacedMap.size(), parOpts.size());
5480 runDCEPass(F_, cctx_);
5481
5482 // We now have 12 Pows from Pow1
5483 EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::PowNodeKind));
5484
5485  // Each input of the 12 Pows is sliced.
5486 EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
5487
5488 // One concat to bring all of the parallelized sliced Pows together.
5489 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
5490
5491 checkNumericalEquivalence();
5492}
5493
5494/// Test Splitting Max into multiple Maxs.
5495TEST_F(GraphOptz, ParallelizeGraph_Max) {
5496 auto *input1 =
5497 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input1", false);
5498 bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
5499 mod_.getPRNG());
5500 auto *input2 =
5501 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input2", false);
5502 bindings_.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0,
5503 mod_.getPRNG());
5504 auto *output =
5505 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output", false);
5506 bindings_.allocate(output);
5507
5508 auto *Max1 = F_->createMax("Max1", input1, input2);
5509 F_->createSave("save", Max1, output);
5510
5511 ::glow::optimize(F_, CompilationMode::Infer);
5512
5513 // This is F_ but without the parallel transformation below.
5514 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5515
5516 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5517 parOpts[Max1] = ParallelTransformKind::Data;
5518
5519 std::unordered_map<Node *, ConcatNode *> replacedMap;
5520 ASSIGN_VALUE_OR_FAIL_TEST(
5521 replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(),
5522 parOpts, 12));
5523 EXPECT_EQ(replacedMap.size(), parOpts.size());
5524 runDCEPass(F_, cctx_);
5525
5526 // We now have 12 Maxs from Max1
5527 EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::MaxNodeKind));
5528
5529  // Each input of the 12 Maxs is sliced.
5530 EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
5531
5532 // One concat to bring all of the parallelized sliced Maxs together.
5533 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
5534
5535 checkNumericalEquivalence();
5536}
5537
5538/// Test Splitting Min into multiple Mins.
5539TEST_F(GraphOptz, ParallelizeGraph_Min) {
5540 auto *input1 =
5541 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input1", false);
5542 bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
5543 mod_.getPRNG());
5544 auto *input2 =
5545 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input2", false);
5546 bindings_.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0,
5547 mod_.getPRNG());
5548 auto *output =
5549 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output", false);
5550 bindings_.allocate(output);
5551
5552 auto *Min1 = F_->createMin("Min1", input1, input2);
5553 F_->createSave("save", Min1, output);
5554
5555 ::glow::optimize(F_, CompilationMode::Infer);
5556
5557 // This is F_ but without the parallel transformation below.
5558 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5559
5560 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5561 parOpts[Min1] = ParallelTransformKind::Data;
5562
5563 std::unordered_map<Node *, ConcatNode *> replacedMap;
5564 ASSIGN_VALUE_OR_FAIL_TEST(
5565 replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(),
5566 parOpts, 12));
5567 EXPECT_EQ(replacedMap.size(), parOpts.size());
5568 runDCEPass(F_, cctx_);
5569
5570 // We now have 12 Mins from Min1
5571 EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::MinNodeKind));
5572
5573  // Each input of the 12 Mins is sliced.
5574 EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
5575
5576 // One concat to bring all of the parallelized sliced Mins together.
5577 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
5578
5579 checkNumericalEquivalence();
5580}
5581
5582/// Test Splitting BatchedReduceMean into multiple BatchedReduceMeans.
5583TEST_F(GraphOptz, ParallelizeGraph_BatchedReduceMean) {
5584 auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {32, 16, 2048},
5585 "input1", false);
5586 bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
5587 mod_.getPRNG());
5588 auto *output =
5589 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output", false);
5590 bindings_.allocate(output);
5591
5592 auto *BatchedReduceMean1 =
5593 F_->createBatchedReduceMean("BatchedReduceMean1", input1, {1});
5594 F_->createSave("save", BatchedReduceMean1, output);
5595
5596 ::glow::optimize(F_, CompilationMode::Infer);
5597
5598 // This is F_ but without the parallel transformation below.
5599 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5600
5601 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5602 parOpts[BatchedReduceMean1] = ParallelTransformKind::Data;
5603
5604 std::unordered_map<Node *, ConcatNode *> replacedMap;
5605 ASSIGN_VALUE_OR_FAIL_TEST(
5606 replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(),
5607 parOpts, 12));
5608 EXPECT_EQ(replacedMap.size(), parOpts.size());
5609 runDCEPass(F_, cctx_);
5610
5611 // We now have 12 BatchedReduceMeans from BatchedReduceMean1
5612 EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::BatchedReduceMeanNodeKind));
5613
5614  // The input of each of the 12 BatchedReduceMeans is sliced.
5615 EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
5616
5617 // One concat to bring all of the parallelized sliced BatchedReduceMeans
5618 // together.
5619 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
5620
5621 checkNumericalEquivalence();
5622}
5623
5624/// Test Splitting BatchedReduceMean into multiple BatchedReduceMeans.
5625/// Failure case with first dimension in reduction
5626TEST_F(GraphOptz, ParallelizeGraph_BatchedReduceMean_failure) {
5627 auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {32, 16, 2048},
5628 "input1", false);
5629 bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
5630 mod_.getPRNG());
5631 auto *output =
5632 mod_.createPlaceholder(ElemKind::FloatTy, {16, 2048}, "output", false);
5633 bindings_.allocate(output);
5634
5635 auto *BatchedReduceMean1 =
5636 F_->createBatchedReduceMean("BatchedReduceMean1", input1, {0});
5637 F_->createSave("save", BatchedReduceMean1, output);
5638
5639 ::glow::optimize(F_, CompilationMode::Infer);
5640
5641 // This is F_ but without the parallel transformation below.
5642 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5643
5644 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5645 parOpts[BatchedReduceMean1] = ParallelTransformKind::Data;
5646
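  // Data parallelism splits along dimension 0, which here is the reduction
  // dimension, so the parallelization is expected to be skipped.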
5647 std::unordered_map<Node *, ConcatNode *> replacedMap;
5648 ASSIGN_VALUE_OR_FAIL_TEST(
5649 replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(),
5650 parOpts, 12));
5651 EXPECT_EQ(replacedMap.size(), 0); // Nothing changes
5652 runDCEPass(F_, cctx_);
5653
5654  // We still have only 1 BatchedReduceMean since parallelization did not apply
5655 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::BatchedReduceMeanNodeKind));
5656
5657 // No concats
5658 EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
5659
5660 checkNumericalEquivalence();
5661}
5662
5663/// Test Splitting Transpose into multiple Transposes.
5664TEST_F(GraphOptz, ParallelizeGraph_Transpose) {
5665 auto *input =
5666 mod_.createPlaceholder(ElemKind::FloatTy, {32, 151, 64}, "input", false);
5667 bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
5668 mod_.getPRNG());
5669 auto *output =
5670 mod_.createPlaceholder(ElemKind::FloatTy, {32, 64, 151}, "output", false);
5671 bindings_.allocate(output);
5672
5673 auto *trans1 = F_->createTranspose("trans1", input, {0, 2, 1});
5674 F_->createSave("save", trans1, output);
5675
5676 ::glow::optimize(F_, CompilationMode::Infer);
5677
5678 // This is F_ but without the parallel transformation below.
5679 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5680
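  // Here the chunk count comes from the per-node numChunks map (2 chunks for
  // trans1) rather than from a global default passed to parallelizeOps.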
5681 llvm::DenseMap<Node *, size_t> numChunks;
5682 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5683 numChunks[trans1] = 2;
5684 parOpts[trans1] = ParallelTransformKind::Data;
5685 std::unordered_map<Node *, ConcatNode *> replacedMap;
5686 ASSIGN_VALUE_OR_FAIL_TEST(replacedMap,
5687 ::glow::parallelizeOps(F_, numChunks, parOpts));
5688 EXPECT_EQ(replacedMap.size(), parOpts.size());
5689
5690 runDCEPass(F_, cctx_);
5691
5692 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::TransposeNodeKind));
5693
5694 checkNumericalEquivalence();
5695}
5696
5697/// Test Splitting Transpose into multiple Transposes.
5698TEST_F(GraphOptz, ParallelizeGraph_Transpose3D_210) {
5699 auto *input =
5700 mod_.createPlaceholder(ElemKind::FloatTy, {4, 15, 23}, "input", false);
5701 bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
5702 mod_.getPRNG());
5703 auto *output =
5704 mod_.createPlaceholder(ElemKind::FloatTy, {23, 15, 4}, "output", false);
5705 bindings_.allocate(output);
5706
5707 auto *trans1 = F_->createTranspose("trans1", input, {2, 1, 0});
5708 F_->createSave("save", trans1, output);
5709
5710 ::glow::optimize(F_, CompilationMode::Infer);
5711
5712 // This is F_ but without the parallel transformation below.
5713 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5714
5715 llvm::DenseMap<Node *, size_t> numChunks;
5716 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5717 numChunks[trans1] = 8;
5718 parOpts[trans1] = ParallelTransformKind::Data;
5719 std::unordered_map<Node *, ConcatNode *> replacedMap;
5720 ASSIGN_VALUE_OR_FAIL_TEST(replacedMap,
5721 ::glow::parallelizeOps(F_, numChunks, parOpts));
5722 EXPECT_EQ(replacedMap.size(), parOpts.size());
5723
5724 runDCEPass(F_, cctx_);
5725
5726 EXPECT_EQ(8, countNodeKind(F_, Kinded::Kind::TransposeNodeKind));
5727
5728 checkNumericalEquivalence();
5729}
5730
5731/// Test Splitting Transpose into multiple Transposes.
5732TEST_F(GraphOptz, ParallelizeGraph_Transpose3D_120) {
5733 auto *input =
5734 mod_.createPlaceholder(ElemKind::FloatTy, {15, 8, 23}, "input", false);
5735 bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
5736 mod_.getPRNG());
5737 auto *output =
5738 mod_.createPlaceholder(ElemKind::FloatTy, {8, 23, 15}, "output", false);
5739 bindings_.allocate(output);
5740
5741 auto *trans1 = F_->createTranspose("trans1", input, {1, 2, 0});
5742 F_->createSave("save", trans1, output);
5743
5744 ::glow::optimize(F_, CompilationMode::Infer);
5745
5746 // This is F_ but without the parallel transformation below.
5747 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5748
5749 llvm::DenseMap<Node *, size_t> numChunks;
5750 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5751 numChunks[trans1] = 8;
5752 parOpts[trans1] = ParallelTransformKind::Data;
5753 std::unordered_map<Node *, ConcatNode *> replacedMap;
5754 ASSIGN_VALUE_OR_FAIL_TEST(replacedMap,
5755 ::glow::parallelizeOps(F_, numChunks, parOpts));
5756 EXPECT_EQ(replacedMap.size(), parOpts.size());
5757
5758 runDCEPass(F_, cctx_);
5759
5760 EXPECT_EQ(8, countNodeKind(F_, Kinded::Kind::TransposeNodeKind));
5761
5762 checkNumericalEquivalence();
5763}
5764
5765/// Test Splitting Select into multiple Selects.
5766TEST_F(GraphOptz, ParallelizeGraphData_Select) {
5767 auto *sel1_lhs =
5768 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "sel1_lhs", false);
5769 bindings_.allocate(sel1_lhs)->getHandle<float>().randomize(-1.0, 1.0,
5770 mod_.getPRNG());
5771 auto *sel1_rhs =
5772 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "sel1_rhs", false);
5773 bindings_.allocate(sel1_rhs)->getHandle<float>().randomize(-1.0, 1.0,
5774 mod_.getPRNG());
5775 auto *sel1_cond =
5776 mod_.createPlaceholder(ElemKind::BoolTy, {32, 2048}, "sel1_cond", false);
5777 bindings_.allocate(sel1_cond)->getHandle<bool>().randomize(0, 1,
5778 mod_.getPRNG());
5779 auto *sel2_rhs =
5780 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "sel2_rhs", false);
5781 bindings_.allocate(sel2_rhs)->getHandle<float>().randomize(-1.0, 1.0,
5782 mod_.getPRNG());
5783 auto *sel2_cond =
5784 mod_.createPlaceholder(ElemKind::BoolTy, {32, 2048}, "sel2_cond", false);
5785 bindings_.allocate(sel2_cond)->getHandle<bool>().randomize(0, 1,
5786 mod_.getPRNG());
5787 auto *output =
5788 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output", false);
5789 bindings_.allocate(output);
5790
5791 auto *sel1 = F_->createSelect("sel1", sel1_cond, sel1_lhs, sel1_rhs);
5792 auto *sel2 = F_->createSelect("sel2", sel2_cond, sel1, sel2_rhs);
5793 F_->createSave("save", sel2, output);
5794
5795 ::glow::optimize(F_, CompilationMode::Infer);
5796
5797 // This is F_ but without the parallel transformation below.
5798 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5799
5800 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5801 parOpts[sel1] = ParallelTransformKind::Data;
5802
5803 std::unordered_map<Node *, ConcatNode *> replacedMap;
5804 ASSIGN_VALUE_OR_FAIL_TEST(
5805 replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(),
5806 parOpts, 12));
5807 EXPECT_EQ(replacedMap.size(), parOpts.size());
5808 runDCEPass(F_, cctx_);
5809
5810 // We now have 12 Selects from sel1, as well as the original sel2 which is
5811 // unchanged.
5812 EXPECT_EQ(13, countNodeKind(F_, Kinded::Kind::SelectNodeKind));
5813
5814  // Each of the 3 inputs of the 12 Selects is sliced.
5815 EXPECT_EQ(36, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
5816
5817 // One concat to bring all of the parallelized sliced Select together.
5818 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
5819
5820 checkNumericalEquivalence();
5821}
5822
5823/// Test Splitting Select into multiple Selects.
5824TEST_F(GraphOptz, ParallelizeGraphModel_Select) {
5825 auto *sel1_lhs =
5826 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "sel1_lhs", false);
5827 bindings_.allocate(sel1_lhs)->getHandle<float>().randomize(-1.0, 1.0,
5828 mod_.getPRNG());
5829 auto *sel1_rhs =
5830 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "sel1_rhs", false);
5831 bindings_.allocate(sel1_rhs)->getHandle<float>().randomize(-1.0, 1.0,
5832 mod_.getPRNG());
5833 auto *sel1_cond =
5834 mod_.createPlaceholder(ElemKind::BoolTy, {32, 2048}, "sel1_cond", false);
5835 bindings_.allocate(sel1_cond)->getHandle<bool>().randomize(0, 1,
5836 mod_.getPRNG());
5837 auto *sel2_rhs =
5838 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "sel2_rhs", false);
5839 bindings_.allocate(sel2_rhs)->getHandle<float>().randomize(-1.0, 1.0,
5840 mod_.getPRNG());
5841 auto *sel2_cond =
5842 mod_.createPlaceholder(ElemKind::BoolTy, {32, 2048}, "sel2_cond", false);
5843 bindings_.allocate(sel2_cond)->getHandle<bool>().randomize(0, 1,
5844 mod_.getPRNG());
5845 auto *output =
5846 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output", false);
5847 bindings_.allocate(output);
5848
5849 auto *sel1 = F_->createSelect("sel1", sel1_cond, sel1_lhs, sel1_rhs);
5850 auto *sel2 = F_->createSelect("sel2", sel2_cond, sel1, sel2_rhs);
5851 F_->createSave("save", sel2, output);
5852
5853 ::glow::optimize(F_, CompilationMode::Infer);
5854
5855 // This is F_ but without the parallel transformation below.
5856 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5857
5858 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5859 parOpts[sel1] = ParallelTransformKind::Model;
5860
5861 std::unordered_map<Node *, ConcatNode *> replacedMap;
5862 ASSIGN_VALUE_OR_FAIL_TEST(
5863 replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(),
5864 parOpts, 12));
5865 EXPECT_EQ(replacedMap.size(), parOpts.size());
5866 runDCEPass(F_, cctx_);
5867
5868 // We now have 12 Selects from sel1, as well as the original sel2 which is
5869 // unchanged.
5870 EXPECT_EQ(13, countNodeKind(F_, Kinded::Kind::SelectNodeKind));
5871
5872  // Each of the 3 inputs of the 12 Selects is sliced.
5873 EXPECT_EQ(36, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
5874
5875 // One concat to bring all of the parallelized sliced Select together.
5876 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
5877
5878 checkNumericalEquivalence();
5879}
5880
5881/// Test Splitting Reshape into multiple Reshapes.
5882TEST_F(GraphOptz, ParallelizeData_Reshape) {
5883 auto *input1 =
5884 mod_.createPlaceholder(ElemKind::FloatTy, {3, 64}, "input1", false);
5885 bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
5886 mod_.getPRNG());
5887 auto *output =
5888 mod_.createPlaceholder(ElemKind::FloatTy, {3, 8, 8}, "output", false);
5889 bindings_.allocate(output);
5890
5891 auto *rs = F_->createReshape("reshape1", input1, {3, 8, 8});
5892 F_->createSave("save", rs, output);
5893
5894 ::glow::optimize(F_, CompilationMode::Infer);
5895
5896 // This is F_ but without the parallel transformation below.
5897 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5898
5899 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5900 parOpts[rs] = ParallelTransformKind::Data;
5901
5902 std::unordered_map<Node *, ConcatNode *> replacedMap;
5903 ASSIGN_VALUE_OR_FAIL_TEST(
5904 replacedMap,
5905 ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3));
5906 EXPECT_EQ(replacedMap.size(), parOpts.size());
5907 runDCEPass(F_, cctx_);
5908
5909 // We now have 3 Reshapes
5910 EXPECT_EQ(3, countNodeKind(F_, Kinded::Kind::ReshapeNodeKind));
5911
5912 // One concat to bring all of the parallelized sliced Reshapes together.
5913 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
5914
5915 checkNumericalEquivalence();
5916}
5917
5918/// Test Splitting Reshape into multiple Reshapes when the batch
5919/// dimension changes. This is not allowed when the input or output batch
5920/// dimension is not divisible by the number of parallel chunks.
5921TEST_F(GraphOptz, ParallelizeData_Reshape_badcase) {
5922 auto *input1 =
5923 mod_.createPlaceholder(ElemKind::FloatTy, {4, 48}, "input1", false);
5924 bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
5925 mod_.getPRNG());
5926 auto *output =
5927 mod_.createPlaceholder(ElemKind::FloatTy, {24, 8}, "output", false);
5928 bindings_.allocate(output);
5929
5930 auto *rs = F_->createReshape("reshape1", input1, {24, 8});
5931 F_->createSave("save", rs, output);
5932
5933 ::glow::optimize(F_, CompilationMode::Infer);
5934
5935 // This is F_ but without the parallel transformation below.
5936 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5937
5938 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5939 parOpts[rs] = ParallelTransformKind::Data;
5940
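  // The input batch dimension (4) is not divisible by the 3 requested chunks,
  // so the Reshape is expected to be left untouched.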
5941 std::unordered_map<Node *, ConcatNode *> replacedMap;
5942 ASSIGN_VALUE_OR_FAIL_TEST(
5943 replacedMap,
5944 ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3));
5945 EXPECT_EQ(replacedMap.size(), 0); // Nothing gets replaced
5946 runDCEPass(F_, cctx_);
5947
5948 // We now have only 1 Reshape as nothing should have split
5949 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ReshapeNodeKind));
5950
5951 checkNumericalEquivalence();
5952}
5953
5954/// Test Splitting AdaptiveAvgPool into multiple AdaptiveAvgPools.
5955TEST_F(GraphOptz, ParallelizeData_AdaptiveAvgPool) {
5956 auto *input1 =
5957 mod_.createPlaceholder(ElemKind::FloatTy, {3, 5, 5, 8}, "input1", false);
5958 bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
5959 mod_.getPRNG());
5960 auto *output =
5961 mod_.createPlaceholder(ElemKind::FloatTy, {3, 1, 1, 8}, "output", false);
5962 bindings_.allocate(output);
5963
5964 auto outTy = mod_.uniqueType(ElemKind::FloatTy, {3, 1, 1, 8});
5965
5966 auto *aap = F_->createAdaptiveAvgPool("AdaptiveAvgPool1", input1, outTy);
5967 F_->createSave("save", aap, output);
5968
5969 ::glow::optimize(F_, CompilationMode::Infer);
5970
5971 // This is F_ but without the parallel transformation below.
5972 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5973
5974 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
5975 parOpts[aap] = ParallelTransformKind::Data;
5976
5977 std::unordered_map<Node *, ConcatNode *> replacedMap;
5978 ASSIGN_VALUE_OR_FAIL_TEST(
5979 replacedMap,
5980 ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3));
5981 EXPECT_EQ(replacedMap.size(), parOpts.size());
5982 runDCEPass(F_, cctx_);
5983
5984 // We now have 3 AdaptiveAvgPools
5985 EXPECT_EQ(3, countNodeKind(F_, Kinded::Kind::AdaptiveAvgPoolNodeKind));
5986
5987 // One concat to bring all of the parallelized sliced AdaptiveAvgPools
5988 // together.
5989 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
5990
5991 checkNumericalEquivalence();
5992}
5993
5994/// Test Splitting RoIAlign into multiple RoIAligns.
5995TEST_F(GraphOptz, ParallelizeData_RoIAlign) {
5996 auto *input1 =
5997 mod_.createPlaceholder(ElemKind::FloatTy, {4, 5, 5, 8}, "input1", false);
5998 bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
5999 mod_.getPRNG());
6000 auto *boxes = mod_.createPlaceholder(ElemKind::FloatTy, {6, 4}, "roi", false);
6001 bindings_.allocate(boxes)->getHandle<float>() = {
6002 0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3};
6003 auto *batchIndices =
6004      mod_.createPlaceholder(ElemKind::Int64ITy, {6}, "batchIndices", false);
6005 bindings_.allocate(batchIndices)
6006 ->getHandle<int64_t>()
6007 .randomize(0, 3, mod_.getPRNG());
6008
6009 auto *output =
6010 mod_.createPlaceholder(ElemKind::FloatTy, {6, 1, 1, 8}, "output", false);
6011 bindings_.allocate(output);
6012
6013 auto *aap = F_->createROIAlign("ROIAlign", input1, boxes, batchIndices, 1, 1,
6014 0, 1, false);
6015 F_->createSave("save", aap, output);
6016
6017 ::glow::optimize(F_, CompilationMode::Infer);
6018
6019 // This is F_ but without the parallel transformation below.
6020 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
6021
6022 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
6023 parOpts[aap] = ParallelTransformKind::Data;
6024
6025 std::unordered_map<Node *, ConcatNode *> replacedMap;
6026 ASSIGN_VALUE_OR_FAIL_TEST(
6027 replacedMap,
6028 ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3));
6029 EXPECT_EQ(replacedMap.size(), parOpts.size());
6030 runDCEPass(F_, cctx_);
6031
6032 // We now have 3 RoIAligns
6033 EXPECT_EQ(3, countNodeKind(F_, Kinded::Kind::ROIAlignNodeKind));
6034
6035 // One concat to bring all of the parallelized sliced RoIAligns
6036 // together.
6037 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
6038
6039 checkNumericalEquivalence();
6040}
6041
6042/// Test Splitting MaxPool into multiple MaxPools.
6043TEST_F(GraphOptz, ParallelizeData_MaxPool) {
6044 auto *input1 = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5, 5, 8}, 1.0, 0,
6045 "input1", false);
6046 bindings_.allocate(input1)->getHandle<int8_t>().randomize(-1.0, 1.0,
6047 mod_.getPRNG());
6048
6049 auto *maxp = F_->createMaxPool("MaxPool1", input1, 5, 1, 0);
6050 F_->createSave("save", maxp->getResult());
6051
6052 ::glow::optimize(F_, CompilationMode::Infer);
6053
6054 // This is F_ but without the parallel transformation below.
6055 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
6056
6057 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
6058 parOpts[maxp] = ParallelTransformKind::Data;
6059
6060 std::unordered_map<Node *, ConcatNode *> replacedMap;
6061 ASSIGN_VALUE_OR_FAIL_TEST(
6062 replacedMap,
6063 ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3));
6064 EXPECT_EQ(replacedMap.size(), parOpts.size());
6065 runDCEPass(F_, cctx_);
6066
6067 // We now have 3 MaxPools
6068 EXPECT_EQ(3, countNodeKind(F_, Kinded::Kind::MaxPoolNodeKind));
6069
6070 // One concat to bring all of the parallelized sliced MaxPools
6071 // together.
6072 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
6073
6074 checkNumericalEquivalence();
6075}
6076
6077/// Test Splitting ChannelwiseQuantizedConvolution into multiple
6078/// ChannelwiseQuantizedConvolutions.
6079TEST_F(GraphOptz, ParallelizeData_ChannelwiseQuantizedConvolution) {
6080 auto *input1 = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5, 5, 8}, 1.0, 0,
6081 "input1", false);
6082 bindings_.allocate(input1)->getHandle<int8_t>().randomize(-4, 4,
6083 mod_.getPRNG());
6084 auto *filter =
6085 mod_.createConstant(ElemKind::FloatTy, {12, 1, 1, 8}, "weights");
6086 filter->getPayloadMutable().getHandle().randomize(-10, 10, mod_.getPRNG());
6087 auto *bias = mod_.createConstant(ElemKind::FloatTy, {12}, "bias");
6088 bias->getPayloadMutable().getHandle().randomize(-1, 1, mod_.getPRNG());
6089 auto *output = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5, 5, 12}, 1.0,
6090 0, "output", false);
6091 bindings_.allocate(output);
6092 auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {3, 5, 5, 12}, 1.0, 0);
6093
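  // The nullptr arguments leave the per-channel filter/bias scales and offsets
  // to be derived from the float filter and bias constants (an assumption about
  // createChannelwiseQuantizedConv's default quantization behavior).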
6094 auto *cqc = F_->createChannelwiseQuantizedConv(
6095 "ChannelwiseQuantizedConvolution1", input1, filter, bias, nullptr,
6096 nullptr, nullptr, nullptr, outTy, {1, 1}, {1, 1}, {0, 0, 0, 0}, 1);
6097 F_->createSave("save", cqc, output);
6098
6099 ::glow::optimize(F_, CompilationMode::Infer);
6100
6101 // This is F_ but without the parallel transformation below.
6102 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
6103
6104 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
6105 parOpts[cqc] = ParallelTransformKind::Data;
6106
6107 std::unordered_map<Node *, ConcatNode *> replacedMap;
6108 ASSIGN_VALUE_OR_FAIL_TEST(
6109 replacedMap,
6110 ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3));
6111 EXPECT_EQ(replacedMap.size(), parOpts.size());
6112 runDCEPass(F_, cctx_);
6113
6114 // We now have 3 ChannelwiseQuantizedConvolutions
6115 EXPECT_EQ(3, countNodeKind(
6116 F_, Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind));
6117
6118 // One concat to bring all of the parallelized sliced
6119 // ChannelwiseQuantizedConvolutions together.
6120 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
6121
6122 checkNumericalEquivalence();
6123}
6124
6125/// Test Splitting Convolution into multiple Convolutions.
6126TEST_F(GraphOptz, ParallelizeData_Convolution) {
6127 auto *input1 =
6128 mod_.createPlaceholder(ElemKind::FloatTy, {3, 5, 5, 4}, "input1", false);
6129 bindings_.allocate(input1)->getHandle<float>().randomize(-1, 1,
6130 mod_.getPRNG());
6131 auto *filter =
6132 mod_.createConstant(ElemKind::FloatTy, {6, 1, 1, 2}, "weights");
6133 filter->getPayloadMutable().getHandle().randomize(-1, 1, mod_.getPRNG());
6134 auto *bias = mod_.createConstant(ElemKind::FloatTy, {6}, "bias");
6135 bias->getPayloadMutable().getHandle().randomize(-.1, .1, mod_.getPRNG());
6136 auto *output =
6137 mod_.createPlaceholder(ElemKind::FloatTy, {3, 5, 5, 6}, "output", false);
6138 bindings_.allocate(output);
6139 auto outTy = mod_.uniqueType(ElemKind::FloatTy, {3, 5, 5, 6});
6140
6141 auto *conv =
6142 F_->createConv("Convolution1", input1, filter, bias, outTy, 1, 1, 0, 2);
6143 F_->createSave("save", conv, output);
6144
6145 ::glow::optimize(F_, CompilationMode::Infer);
6146
6147 // This is F_ but without the parallel transformation below.
6148 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
6149
6150 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
6151 parOpts[conv] = ParallelTransformKind::Data;
6152
6153 std::unordered_map<Node *, ConcatNode *> replacedMap;
6154 ASSIGN_VALUE_OR_FAIL_TEST(
6155 replacedMap,
6156 ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3));
6157 EXPECT_EQ(replacedMap.size(), parOpts.size());
6158 runDCEPass(F_, cctx_);
6159
6160 // We now have 3 Convolutions
6161 EXPECT_EQ(3, countNodeKind(F_, Kinded::Kind::ConvolutionNodeKind));
6162
6163 // One concat to bring all of the parallelized sliced
6164 // Convolutions together.
6165 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
6166
6167 checkNumericalEquivalence();
6168}
6169
6170/// Test Splitting RowwiseQuantizedFullyConnected into multiple
6171/// RowwiseQuantizedFullyConnected nodes.
6172TEST_F(GraphOptz, ParallelizeData_RowwiseQuantizedFullyConnected) {
6173 auto *input1 = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 8}, 1.0, 0,
6174 "input1", false);
6175 bindings_.allocate(input1)->getHandle<int8_t>().randomize(-4, 4,
6176 mod_.getPRNG());
6177 auto *weights =
6178 mod_.createConstant(ElemKind::Int8QTy, {12, 8}, 1.0, 0, "weights");
6179 weights->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127,
6180 mod_.getPRNG());
6181 auto *scales = mod_.createConstant(ElemKind::FloatTy, {12}, "scales");
6182 scales->getPayloadMutable().getHandle().randomize(0.01, 0.1, mod_.getPRNG());
6183 auto *offsets = mod_.createConstant(ElemKind::Int32ITy, {12}, "offsets");
6184 offsets->getPayloadMutable().getHandle<int32_t>().randomize(0, 10,
6185 mod_.getPRNG());
6186
6187 auto *bias = mod_.createConstant(ElemKind::Int8QTy, {12}, 1.0, 0, "bias");
6188 bias->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127,
6189 mod_.getPRNG());
6190 auto *output = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 12}, 1.0, 0,
6191 "output", false);
6192 bindings_.allocate(output);
6193 auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {3, 12}, 1.0, 0);
6194
6195 auto *rqfc = F_->createRowwiseQuantizedFullyConnected(
6196 "RowwiseQuantizedFullyConnected1", input1, weights, scales, offsets, bias,
6197 outTy);
6198 F_->createSave("save", rqfc, output);
6199
6200 ::glow::optimize(F_, CompilationMode::Infer);
6201
6202 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
6203 parOpts[rqfc] = ParallelTransformKind::Data;
6204
6205 std::unordered_map<Node *, ConcatNode *> replacedMap;
6206 ASSIGN_VALUE_OR_FAIL_TEST(
6207 replacedMap,
6208 ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3));
6209 EXPECT_EQ(replacedMap.size(), parOpts.size());
6210 runDCEPass(F_, cctx_);
6211
6212 // We now have 3 RowwiseQuantizedFullyConnecteds
6213 EXPECT_EQ(3, countNodeKind(
6214 F_, Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind));
6215
6216 // One concat to bring all of the parallelized sliced
6217 // RowwiseQuantizedFullyConnecteds together.
6218 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
6219}
6220
6221/// Test Splitting Convolution into multiple Convolutions.
6222TEST_F(GraphOptz, ParallelizeGraph_Convolution_Model_Axis3) {
6223 auto *input1 = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5, 5, 8}, 1.0, 0,
6224 "input1", false);
6225 bindings_.allocate(input1)->getHandle<int8_t>().randomize(-4, 4,
6226 mod_.getPRNG());
6227 auto *filter = mod_.createPlaceholder(ElemKind::Int8QTy, {12, 1, 1, 8}, 0.1,
6228 0, "weights", false);
6229 auto *bias =
6230 mod_.createPlaceholder(ElemKind::Int32QTy, {12}, 0.01, 0, "bias", false);
6231
6232 auto *output = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5, 5, 12}, 1.0,
6233 0, "output", false);
6234 bindings_.allocate(output);
6235 auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {3, 5, 5, 12}, 1.0, 0);
6236
6237 auto *c = F_->createConv("Convolution1", input1, filter, bias, outTy, {1, 1},
6238 {1, 1}, {0, 0, 0, 0}, 1);
6239 F_->createSave("save", c, output);
6240
6241 ::glow::optimize(F_, CompilationMode::Infer);
6242
6243 // This is F_ but without the parallel transformation below.
6244 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
6245
6246 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
6247 parOpts[c] = ParallelTransformKind::Model_Axis3;
6248
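  // Axis 3 is the output-channel dimension (size 12), so splitting into 12
  // chunks yields one Convolution per output channel.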
6249 std::unordered_map<Node *, ConcatNode *> replacedMap;
6250 ASSIGN_VALUE_OR_FAIL_TEST(
6251 replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(),
6252 parOpts, 12));
6253 EXPECT_EQ(replacedMap.size(), parOpts.size());
6254 runDCEPass(F_, cctx_);
6255
6256 // We now have 12 Convolutions
6257 EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::ConvolutionNodeKind));
6258
6259  // One concat to bring all of the parallelized sliced Convolutions
6260  // together.
6261 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
6262
6263 checkNumericalEquivalence(0.f);
6264}
6265
6266/// Test Splitting Convolution3D into multiple Convolution3Ds.
6267TEST_F(GraphOptz, ParallelizeGraph_Convolution3D_Model_Axis4) {
6268 auto *input1 = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5, 5, 5, 8}, 1.0,
6269 0, "input1", false);
6270 bindings_.allocate(input1)->getHandle<int8_t>().randomize(-4, 4,
6271 mod_.getPRNG());
6272 auto *filter = mod_.createPlaceholder(ElemKind::Int8QTy, {12, 1, 1, 1, 8},
6273 0.1, 0, "weights", false);
6274 auto *bias =
6275 mod_.createPlaceholder(ElemKind::Int32QTy, {12}, 0.01, 0, "bias", false);
6276
6277 auto *output = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5, 5, 5, 12},
6278 1.0, 0, "output", false);
6279 bindings_.allocate(output);
6280 auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {3, 5, 5, 5, 12}, 1.0, 0);
6281
6282 auto *c3d = F_->createConv3D("Convolution3D1", input1, filter, bias, outTy,
6283 {1, 1, 1}, {1, 1, 1}, {0, 0, 0, 0, 0, 0}, 1);
6284 F_->createSave("save", c3d, output);
6285
6286 ::glow::optimize(F_, CompilationMode::Infer);
6287
6288 // This is F_ but without the parallel transformation below.
6289 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
6290
6291 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
6292 parOpts[c3d] = ParallelTransformKind::Model_Axis4;
6293
6294 std::unordered_map<Node *, ConcatNode *> replacedMap;
6295 ASSIGN_VALUE_OR_FAIL_TEST(
6296 replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(),
6297 parOpts, 12));
6298 EXPECT_EQ(replacedMap.size(), parOpts.size());
6299 runDCEPass(F_, cctx_);
6300
6301 // We now have 12 Convolution3Ds
6302 EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::Convolution3DNodeKind));
6303
6304  // One concat to bring all of the parallelized sliced Convolution3Ds
6305  // together.
6306 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
6307
6308 checkNumericalEquivalence(0.f);
6309}
6310
6311/// Test Splitting AvgPool into multiple AvgPools.
6312TEST_F(GraphOptz, ParallelizeGraph_AvgPool_Model_Axis4) {
6313 auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {3, 5, 5, 5, 8},
6314 "input1", false);
6315 bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
6316 mod_.getPRNG());
6317 auto *output = mod_.createPlaceholder(ElemKind::FloatTy, {3, 1, 1, 1, 8},
6318 "output", false);
6319 bindings_.allocate(output);
6320
6321 auto *ap = F_->createAvgPool("AvgPool1", input1, {5, 5, 5}, {1, 1, 1},
6322 {0, 0, 0, 0, 0, 0}, ConvolutionLayout::NTHWC);
6323 F_->createSave("save", ap, output);
6324
6325 ::glow::optimize(F_, CompilationMode::Infer);
6326
6327 // This is F_ but without the parallel transformation below.
6328 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
6329
6330 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
6331 parOpts[ap] = ParallelTransformKind::Model_Axis4;
6332
6333 std::unordered_map<Node *, ConcatNode *> replacedMap;
6334 ASSIGN_VALUE_OR_FAIL_TEST(
6335 replacedMap,
6336 ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 8));
6337 EXPECT_EQ(replacedMap.size(), parOpts.size());
6338 runDCEPass(F_, cctx_);
6339
6340 // We now have 8 AvgPools
6341 EXPECT_EQ(8, countNodeKind(F_, Kinded::Kind::AvgPoolNodeKind));
6342
6343 // One concat to bring all of the parallelized sliced AvgPools
6344 // together.
6345 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
6346
6347 checkNumericalEquivalence(0.f);
6348}
6349
6350/// Test that Add after ConvTranspose is folded into the Bias add when the
6351/// actual Add is a broadcast of the bias. Test \p RnL (right or left) side add.
6352static void foldConvTransposeAddIntoBiasAdd(PlaceholderBindings &bindings,
6353 Module &mod, Function *F,
6354 Function *&optF, bool RnL) {
6355 dim_t batch = 2;
6356 dim_t inC = 2;
6357 dim_t outC = 5;
6358 dim_t inH = 3;
6359 dim_t inW = 3;
6360 unsigned_t kernel = 3;
6361 std::vector<uint32_t> pads = {0, 0, 0, 0};
6362 std::vector<uint32_t> stride = {1, 1};
6363
6364 auto *input = mod.createPlaceholder(ElemKind::FloatTy, {2, inH, inW, inC},
6365 "input", false);
6366 auto *filter = mod.createPlaceholder(
6367 ElemKind::FloatTy, {outC, kernel, kernel, inC}, "filter", false);
6368
6369 auto *bias = mod.createConstant(ElemKind::FloatTy, {outC}, "bias");
6370 bias->getPayloadMutable().getHandle<float>() = {1, 3, 5, 7, 9};
6371
6372 std::pair<dim_t, dim_t> outHW = calculateConvTransposeOutputDims(
6373 inH, inW, {kernel, kernel}, stride, pads);
6374 auto outTy = mod.uniqueType(ElemKind::FloatTy,
6375 {batch, outHW.first, outHW.second, outC});
6376
6377 ConvTransposeNode *CTN =
6378 F->createConvTranspose("ConvTranspose", input, filter, bias, outTy,
6379 {kernel, kernel}, stride, {0, 0, 0, 0}, 1);
6380
6381 auto *CN = mod.createConstant(ElemKind::FloatTy,
6382 {batch, outHW.first, outHW.second, outC}, "c1");
6383 auto *AN = RnL ? F->createAdd("add", CN, CTN) : F->createAdd("add", CTN, CN);
6384
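  // The added constant repeats {1, 2, 3, 4, 5} for every (batch, H, W)
  // position, i.e. it is a broadcast of a per-output-channel vector, which is
  // what allows it to be folded into the ConvTranspose bias.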
6385 CN->getPayloadMutable().getHandle<float>() = {
6386 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3,
6387 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1,
6388 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4,
6389 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2,
6390 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5,
6391 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3,
6392 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1,
6393 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4,
6394 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2,
6395 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5,
6396 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5};
6397
6398 SaveNode *save = F->createSave("save", AN);
6399 bindings.allocate(save->getPlaceholder());
6400
6401 EXPECT_EQ(F->getNodes().size(), 3);
6402 optF = optimizeFunctionForTest(F);
6403 EXPECT_EQ(optF->getNodes().size(), 2);
6404
6405 const SaveNode *optSave =
6406 findFunctionNodeByName<SaveNode>(optF, save->getName());
6407
6408 ConvTransposeNode *optCN =
6409 llvm::dyn_cast<ConvTransposeNode>(optSave->getInput());
6410 EXPECT_TRUE(optCN);
6411
6412 Constant *optBias = llvm::dyn_cast<Constant>(optCN->getBias());
6413 EXPECT_TRUE(optBias);
6414
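  // The folded bias should be the original bias {1, 3, 5, 7, 9} plus the
  // broadcast vector {1, 2, 3, 4, 5}.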
6415 auto BH = optBias->getPayload().getHandle();
6416 EXPECT_EQ(BH.raw(0), 1 + 1);
6417 EXPECT_EQ(BH.raw(1), 2 + 3);
6418 EXPECT_EQ(BH.raw(2), 3 + 5);
6419 EXPECT_EQ(BH.raw(3), 4 + 7);
6420 EXPECT_EQ(BH.raw(4), 5 + 9);
6421
6422 bindings.allocate(mod.getPlaceholders());
6423 bindings.get(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
6424 bindings.get(filter)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
6425}
6426
6427/// Test that Add after ConvTranspose is folded into the Bias add when the
6428/// actual Add is a broadcast of the bias.
6429TEST_F(GraphOptz, FoldConvTransposeAddIntoBiasAddRHS) {
6430 foldConvTransposeAddIntoBiasAdd(bindings_, mod_, F_, optimizedF_, false);
6431 checkNumericalEquivalence();
6432}
6433TEST_F(GraphOptz, FoldConvTransposeAddIntoBiasAddLHS) {
6434 foldConvTransposeAddIntoBiasAdd(bindings_, mod_, F_, optimizedF_, true);
6435 checkNumericalEquivalence();
6436}
6437
6438/// Test that MatMul + Add is folded into FullyConnected.
6439TEST_F(GraphOptz, FoldMatMulAddIntoFullyConnected) {
6440
6441 auto *input =
6442 mod_.createPlaceholder(ElemKind::FloatTy, {1, 3}, "input", false);
6443 auto *weights =
6444 mod_.createPlaceholder(ElemKind::FloatTy, {3, 5}, "weights", false);
6445 auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5}, "bias", false);
6446
6447 MatMulNode *matmul = F_->createMatMul("matmul", input, weights);
6448 AddNode *add = F_->createAdd("add", matmul, bias);
6449 F_->createSave("save", add);
6450 EXPECT_EQ(3, F_->getNodes().size());
6451
6452  // The folding should replace the MatMul + Add with a FullyConnected and a
6453  // Reshape of the Bias to 1D.
6454 CompilationContext cctx;
6455 ::glow::fold(F_, cctx);
6456 EXPECT_EQ(3, F_->getNodes().size());
6457 EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::AddNodeKind));
6458 EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::MatMulNodeKind));
6459 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
6460 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ReshapeNodeKind));
6461}
6462
6463/// Test that batched MatMul + Add is folded into batched FullyConnected.
6464/// This optimization takes place only if the Bias is constant and the
6465/// bias data repeats for all the batches.
6466TEST_F(GraphOptz, FoldMatMulAddIntoFullyConnectedBatched) {
6467
6468 auto *input =
6469 mod_.createPlaceholder(ElemKind::FloatTy, {2, 3}, "input", false);
6470 auto *weights =
6471 mod_.createPlaceholder(ElemKind::FloatTy, {3, 5}, "weights", false);
6472 auto *bias = mod_.createConstant(ElemKind::FloatTy, {2, 5}, "bias");
6473 auto biasH = bias->getPayloadMutable().getHandle<float>();
6474 biasH = {1, 2, 3, 4, 5, 1, 2, 3, 4, 5};
6475
6476 MatMulNode *matmul = F_->createMatMul("matmul", input, weights);
6477 AddNode *add = F_->createAdd("add", matmul, bias);
6478 F_->createSave("save", add);
6479 EXPECT_EQ(3, F_->getNodes().size());
6480
6481  // The folding should replace the MatMul + Add with a FullyConnected, plus
6482  // a Slice and a Reshape that turn the batched Bias into a 1D Bias.
6483 CompilationContext cctx;
6484 ::glow::fold(F_, cctx);
6485 EXPECT_EQ(4, F_->getNodes().size());
6486 EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::AddNodeKind));
6487 EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::MatMulNodeKind));
6488 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
6489 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
6490 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ReshapeNodeKind));
6491}
6492
6493/// Test that MatMul is converted to FullyConnected for Int8QTy.
6494TEST_F(GraphOptz, ConvertMatMulToFullyConnected_Int8QTy) {
6495
6496 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {1, 3}, 0.1f, -13,
6497 "input", false);
6498 auto *weights = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5}, 0.2f, 15,
6499 "weights", false);
6500 MatMulNode *matmul = F_->createMatMul("matmul", input, weights);
6501 F_->createSave("save", matmul);
6502 EXPECT_EQ(2, F_->getNodes().size());
6503
6504 optimizedF_ = optimizeFunctionForTest(
6505 F_, {FunctionPassID::ConvertMatMulToFullyConnected, getDCEPassConfig()});
6506
6507 EXPECT_EQ(2, optimizedF_->getNodes().size());
6508 EXPECT_EQ(1,
6509 countNodeKind(optimizedF_, Kinded::Kind::FullyConnectedNodeKind));
6510}
6511
6512/// Test that MatMul is converted to FullyConnected for FloatTy.
6513TEST_F(GraphOptz, ConvertMatMulToFullyConnected_FloatTy) {
6514
6515 auto *input =
6516 mod_.createPlaceholder(ElemKind::FloatTy, {1, 3}, "input", false);
6517 auto *weights =
6518 mod_.createPlaceholder(ElemKind::FloatTy, {3, 5}, "weights", false);
6519 MatMulNode *matmul = F_->createMatMul("matmul", input, weights);
6520 F_->createSave("save", matmul);
6521 EXPECT_EQ(2, F_->getNodes().size());
6522
6523 optimizedF_ = optimizeFunctionForTest(
6524 F_, {FunctionPassID::ConvertMatMulToFullyConnected, getDCEPassConfig()});
6525
6526 EXPECT_EQ(2, optimizedF_->getNodes().size());
6527 EXPECT_EQ(1,
6528 countNodeKind(optimizedF_, Kinded::Kind::FullyConnectedNodeKind));
6529}
6530
6531/// Test that FoldSlicesIntoConstants pass works as expected.
6532TEST_F(GraphOptz, FoldSlicesIntoConstantsTest) {
6533 Constant *C = mod_.createConstant(ElemKind::FloatTy, {3, 4}, "C");
6534 auto CH = C->getPayloadMutable().getHandle<float>();
6535 CH = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
6536
6537 SliceNode *S1 = F_->createSlice("s1", C, {0, 0}, {3, 2});
6538 SliceNode *S2 = F_->createSlice("s2", C, {0, 2}, {3, 4});
6539 SaveNode *SN1 = F_->createSave("save1", S1);
6540 SaveNode *SN2 = F_->createSave("save2", S2);
6541
6542 optimizedF_ = optimizeFunctionForTest(
6543 F_, {FunctionPassID::FoldSlicesIntoConstants, getDCEPassConfig()});
6544
6545 SaveNode *optSN1 =
6546 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN1->getName()));
6547 SaveNode *optSN2 =
6548 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN2->getName()));
6549 ASSERT_TRUE(optSN1);
6550 ASSERT_TRUE(optSN2);
6551
6552 Constant *C1 = llvm::dyn_cast<Constant>(optSN1->getInput());
6553 ASSERT_TRUE(C1);
6554 auto H1 = C1->getPayloadMutable().getHandle();
6555 Constant *C2 = llvm::dyn_cast<Constant>(optSN2->getInput());
6556 ASSERT_TRUE(C2);
6557 auto H2 = C2->getPayloadMutable().getHandle();
6558 for (dim_t i = 0, e = 3; i < e; i++) {
6559 for (dim_t j = 0, e = 2; j < e; j++) {
6560 EXPECT_EQ(H1.at({i, j}), CH.at({i, j}));
6561 EXPECT_EQ(H2.at({i, j}), CH.at({i, j + 2}));
6562 }
6563 }
6564}
6565
6566/// Test that RaiseClipsAboveShapeNodes pass works as expected.
6567TEST_F(GraphOptz, RaiseClipsAboveShapeNodesTest) {
6568 Placeholder *input =
6569 mod_.createPlaceholder(ElemKind::FloatTy, {256, 64}, "input", false);
6570
6571 ReshapeNode *RN1 = F_->createReshape("reshape1", input, {4, 128, 32});
6572 ReshapeNode *RN2 = F_->createReshape("reshape2", RN1, {64, 256});
6573 TransposeNode *TN = F_->createTranspose("transpose", RN2, {1, 0});
6574 SliceNode *SN = F_->createSlice("slice", TN, {64, 0}, {256, 64});
6575 TileNode *TiN = F_->createTile("tile", SN, 2, 0);
6576 ClipNode *CN = F_->createClip("clip", TiN, -0.1, 0.1);
6577 SaveNode *save1 = F_->createSave("save1", RN1);
6578 SaveNode *save2 = F_->createSave("save2", CN);
6579
6580 optimizedF_ =
6581 optimizeFunctionForTest(F_, {FunctionPassID::RaiseClipsAboveShapeNodes});
6582
6583 auto *optSave1 =
6584 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save1->getName()));
6585 ASSERT_TRUE(optSave1);
6586 auto *optSave2 =
6587 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save2->getName()));
6588 ASSERT_TRUE(optSave2);
6589
6590  // save1 should only have a single untouched Reshape RN1 as its input, which
6591  // is fed by the input Placeholder, because RN1 has multiple users.
6592 auto *optRN1 = llvm::dyn_cast<ReshapeNode>(optSave1->getInput().getNode());
6593 ASSERT_TRUE(optRN1);
6594 EXPECT_EQ(input, optRN1->getInput().getNode());
6595
6596 // save2 should have CN it originally saved pushed up above SN, TiN, TN, and
6597 // RN2.
6598 TileNode *newTiN = llvm::dyn_cast<TileNode>(optSave2->getInput());
6599 ASSERT_TRUE(newTiN);
6600 EXPECT_EQ(newTiN->getCount(), TiN->getCount());
6601 SliceNode *newSN = llvm::dyn_cast<SliceNode>(newTiN->getInput());
6602 ASSERT_TRUE(newSN);
6603 EXPECT_EQ(newSN->getStart(), SN->getStart());
6604 TransposeNode *newTN = llvm::dyn_cast<TransposeNode>(newSN->getInput());
6605 ASSERT_TRUE(newTN);
6606 EXPECT_EQ(newTN->getShuffle(), TN->getShuffle());
6607 ReshapeNode *newRN2 = llvm::dyn_cast<ReshapeNode>(newTN->getInput());
6608 ASSERT_TRUE(newRN2);
6609 ClipNode *newCN = llvm::dyn_cast<ClipNode>(newRN2->getInput());
6610 ASSERT_TRUE(newCN);
6611 EXPECT_EQ(newCN->getMin(), CN->getMin());
6612 EXPECT_EQ(newCN->getMax(), CN->getMax());
6613
6614 bindings_.allocate(mod_.getPlaceholders());
6615 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
6616 checkNumericalEquivalence();
6617}
6618
6619static void testOptimizeDequantizeClip(PlaceholderBindings &bindings,
6620 Module &mod, Function *F,
6621 Function *&optF,
6622 bool enableQuantParamChanges) {
6623 Placeholder *input =
6624 mod.createPlaceholder(ElemKind::FloatTy, {20, 20}, "input", false);
6625
6626 const auto qParams = quantization::chooseQuantizationParams({-0.1, 0.1});
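  // Choose int8 quantization parameters covering the range [-0.1, 0.1]. The
  // Clip below either spans this whole range ([-100, 100]) or narrows only its
  // minimum (to 0), depending on enableQuantParamChanges.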
6627
6628 QuantizeNode *QN =
6629 F->createQuantize("quantize", input,
6630 mod.uniqueType(ElemKind::Int8QTy, {20, 20},
6631 qParams.scale, qParams.offset));
6632 DequantizeNode *DN = F->createDequantize("dequantize", QN, ElemKind::FloatTy);
6633 ClipNode *CN =
6634 F->createClip("clip", DN, enableQuantParamChanges ? 0 : -100, 100);
6635 SaveNode *SN = F->createSave("save", CN);
6636
6637 CompilationContext cctx;
6638 cctx.optimizationOpts.enableQuantParamChanges = true;
6639 optF = optimizeFunctionForTest(
6640 F, {FunctionPassID::OptimizeQuantizeClip, getDCEPassConfig()}, cctx);
6641
6642 EXPECT_EQ(countNodeKind(optF, Kinded::Kind::ClipNodeKind), 0);
6643
6644 SaveNode *optSN =
6645 llvm::dyn_cast<SaveNode>(optF->getNodeByName(SN->getName()));
6646 ASSERT_TRUE(optSN);
6647
6648 // Now check that the quantization params have been correctly updated for QN,
6649 // and that CN has been eliminated.
6650 DequantizeNode *optDN =
6651 llvm::dyn_cast<DequantizeNode>(optSN->getInput().getNode());
6652 ASSERT_TRUE(optDN);
6653 const auto qMinMax = optDN->getInput().getType()->getQuantizedValueRange();
6654 // Min is either from Clip or Quant range depending on enableQuantParamChanges
6655 EXPECT_NEAR(qMinMax.first, enableQuantParamChanges ? 0 : -0.1, 1E-3);
6656 EXPECT_NEAR(qMinMax.second, 0.1, 1E-3); // Max from Quant range
6657
6658 bindings.allocate(mod.getPlaceholders());
6659 bindings.get(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
6660}
6661
6662/// Test that OptimizeQuantizeClip pass works as expected for Clip(Dequantize)
6663/// when the quantization parameters are allowed to change.
6664TEST_F(GraphOptz, OptimizeDequantizeClipTest_QuantParamChanges) {
6665 testOptimizeDequantizeClip(bindings_, mod_, F_, optimizedF_,
6666 /* enableQuantParamChanges */ true);
6667 checkNumericalEquivalence(0.0005);
6668}
6669
6670/// Test that OptimizeQuantizeClip pass works as expected for Clip(Dequantize)
6671/// when the quantization parameters are not allowed to change.
6672TEST_F(GraphOptz, OptimizeDequantizeClipTest_NoQuantParamChanges) {
6673 testOptimizeDequantizeClip(bindings_, mod_, F_, optimizedF_,
6674 /* enableQuantParamChanges */ false);
6675 checkNumericalEquivalence();
6676}
6677
6678static void testOptimizeClipQuantize(PlaceholderBindings &bindings, Module &mod,
6679 Function *F, Function *&optF,
6680 bool enableQuantParamChanges) {
6681 Placeholder *input =
6682 mod.createPlaceholder(ElemKind::FloatTy, {20, 20}, "input", false);
6683
6684 const auto qParams = quantization::chooseQuantizationParams({-0.1, 0.1});
6685
6686 ClipNode *CN =
6687 F->createClip("clip", input, enableQuantParamChanges ? 0 : -100, 100);
6688 QuantizeNode *QN =
6689 F->createQuantize("quantize", CN,
6690 mod.uniqueType(ElemKind::Int8QTy, {20, 20},
6691 qParams.scale, qParams.offset));
6692 DequantizeNode *DN = F->createDequantize("dequantize", QN, ElemKind::FloatTy);
6693 SaveNode *SN = F->createSave("save", DN);
6694
6695 CompilationContext cctx;
6696 cctx.optimizationOpts.enableQuantParamChanges = enableQuantParamChanges;
6697 optF = optimizeFunctionForTest(
6698 F, {FunctionPassID::OptimizeQuantizeClip, getDCEPassConfig()}, cctx);
6699
6700 EXPECT_EQ(countNodeKind(optF, Kinded::Kind::ClipNodeKind), 0);
6701
6702 SaveNode *optSN =
6703 llvm::dyn_cast<SaveNode>(optF->getNodeByName(SN->getName()));
6704 ASSERT_TRUE(optSN);
6705
6706 // Now check that the quantization params have been correctly updated for QN,
6707 // and that CN has been eliminated.
6708 DequantizeNode *optDN =
6709 llvm::dyn_cast<DequantizeNode>(optSN->getInput().getNode());
6710 ASSERT_TRUE(optDN);
6711 const auto qMinMax = optDN->getInput().getType()->getQuantizedValueRange();
6712 // Min is either from Clip or Quant range depending on enableQuantParamChanges
6713 EXPECT_NEAR(qMinMax.first, enableQuantParamChanges ? 0 : -0.1, 1E-3);
6714 EXPECT_NEAR(qMinMax.second, 0.1, 1E-3); // Max always from Quant range
6715
6716 bindings.allocate(mod.getPlaceholders());
6717 bindings.get(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
6718}
6719
6720/// Test that OptimizeQuantizeClip pass works as expected for Clip(Quantize)
6721/// when the quantization parameters are allowed to change.
6722TEST_F(GraphOptz, OptimizeClipQuantizeTest_QuantParamChanges) {
6723 testOptimizeClipQuantize(bindings_, mod_, F_, optimizedF_,
6724 /* enableQuantParamChanges */ true);
6725 checkNumericalEquivalence(0.0005);
6726}
6727
6728/// Test that OptimizeQuantizeClip pass works as expected for Clip(Quantize)
6729/// when the quantization parameters are not allowed to change.
6730TEST_F(GraphOptz, OptimizeClipQuantizeTest_NoQuantParamChanges) {
6731 testOptimizeClipQuantize(bindings_, mod_, F_, optimizedF_,
6732 /* enableQuantParamChanges */ false);
6733 checkNumericalEquivalence();
6734}
6735
6736/// Test Quantize(ConvertTo(Node)) -> Quantize(Node), where Quantize is int8.
6737TEST_F(GraphOptz, OptimizeOutIntermediateConversionsTest) {
6738 Placeholder *input =
6739 mod_.createPlaceholder(ElemKind::FloatTy, {20, 20}, "input", false);
6740
6741 const auto qParams = quantization::chooseQuantizationParams({-0.1, 0.1});
6742
6743 ConvertToNode *CN = F_->createConvertTo("conv", input, ElemKind::Float16Ty);
6744 QuantizeNode *QN =
6745 F_->createQuantize("quantize", CN,
6746 mod_.uniqueType(ElemKind::Int8QTy, {20, 20},
6747 qParams.scale, qParams.offset));
6748 DequantizeNode *DN =
6749 F_->createDequantize("dequantize", QN, ElemKind::FloatTy);
6750 F_->createSave("save", DN);
6751
6752 optimizedF_ = optimizeFunctionForTest(
6753 F_,
6754 {FunctionPassID::OptimizeOutIntermediateConversions, getDCEPassConfig()});
6755
6756 // Now check that the ConvertToNode has been eliminated.
6757 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConvertToNodeKind), 0);
6758
6759 bindings_.allocate(mod_.getPlaceholders());
6760 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
6761 checkNumericalEquivalence();
6762}
6763
6764/// Test Clip(Relu(Clip)) -> Clip'.
6765TEST_F(GraphOptz, ClipReluClipElimTest) {
6766 Placeholder *input =
6767 mod_.createPlaceholder(ElemKind::FloatTy, {64, 64}, "input", false);
6768 ClipNode *CN1 = F_->createClip("CN1", input, -10, 30);
6769 ReluNode *RN = F_->createRELU("RN", CN1);
6770 ClipNode *CN2 = F_->createClip("CN2", RN, -5, 20);
6771 SaveNode *SN = F_->createSave("save", CN2);
6772
6773 // Start with 2 clips, a relu, and a save.
6774 EXPECT_EQ(F_->getNodes().size(), 4);
6775 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ClipNodeKind), 2);
6776 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 1);
6777
6778 optimizedF_ = optimizeFunctionForTest(F_);
6779
6780 // Remove one of the clips and the relu.
6781 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
6782 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ClipNodeKind), 1);
6783 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ReluNodeKind), 0);
6784
6785 SaveNode *optSN =
6786 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
6787 ASSERT_TRUE(optSN);
6788
6789 // We combined all of the ranges into the single Clip.
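  // The resulting range is the intersection of [-10, 30], the Relu's [0, inf),
  // and [-5, 20], i.e. [0, 20].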
6790 ClipNode *optCN = llvm::dyn_cast<ClipNode>(optSN->getInput());
6791 ASSERT_TRUE(optCN);
6792 EXPECT_EQ(optCN->getMin(), 0);
6793 EXPECT_EQ(optCN->getMax(), 20);
6794
6795 bindings_.allocate(input)->getHandle().randomize(-50.0, 5.0, mod_.getPRNG());
6796 checkNumericalEquivalence();
6797}
6798
6799/// Test that we can find a non-quantized relu and fuse it up into a quant FC.
6800TEST_F(GraphOptz, OptimizeQuantFCFloatReluTest) {
6801 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 1.0, 0,
6802 "input", false);
6803 auto *weights =
6804 mod_.createConstant(ElemKind::Int8QTy, {32, 32}, 1.0, 0, "weights");
6805 auto *bias = mod_.createConstant(ElemKind::Int32QTy, {32}, 1.0, 0, "bias");
6806
6807 auto *FC = F_->createFullyConnected("fc", input, weights, bias);
6808 auto *DN = F_->createDequantize("dq", FC, ElemKind::FloatTy);
6809 auto *RN = F_->createRELU("relu", DN);
6810 auto *SN = F_->createSave("save", RN);
6811
6812 optimizedF_ = optimizeFunctionForTest(
6813 F_, {FunctionPassID::OptimizeQuantFCFloatRelu, getDCEPassConfig()});
6814
6815 SaveNode *optSN =
6816 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
6817 ASSERT_TRUE(optSN);
6818
6819 DequantizeNode *optDN = llvm::dyn_cast<DequantizeNode>(optSN->getInput());
6820 ASSERT_TRUE(optDN);
6821 ReluNode *optRN = llvm::dyn_cast<ReluNode>(optDN->getInput());
6822 ASSERT_TRUE(optRN);
6823 auto rangeRN = optRN->getResult().getType()->getQuantizedValueRange();
6824 EXPECT_EQ(rangeRN.first, 0.0f);
6825 FullyConnectedNode *optFC =
6826 llvm::dyn_cast<FullyConnectedNode>(optRN->getInput());
6827 ASSERT_TRUE(optFC);
6828 auto rangeFC = optFC->getResult().getType()->getQuantizedValueRange();
6829 EXPECT_EQ(rangeRN.second, rangeFC.second);
6830
6831 bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127,
6832 mod_.getPRNG());
6833 weights->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127,
6834 mod_.getPRNG());
6835 bias->getPayloadMutable().getHandle<int32_t>().randomize(-128, 127,
6836 mod_.getPRNG());
6837 checkNumericalEquivalence();
6838}
6839
6840/// Test that we can find a non-quantized relu and fuse it up into a quant FC
6841/// even when setting dummy qparams to true.
6842TEST_F(GraphOptz, OptimizeDummyQuantFCFloatReluTest) {
6843 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 1.0, 0,
6844 "input", false);
6845 auto *weights =
6846 mod_.createConstant(ElemKind::Int8QTy, {32, 32}, 1.0, 0, "weights");
6847 auto *bias = mod_.createConstant(ElemKind::Int32QTy, {32}, 1.0, 0, "bias");
6848 auto *addW =
6849 mod_.createPlaceholder(ElemKind::FloatTy, {2, 32}, "addw", false);
6850 auto *FC = F_->createFullyConnected("fc", input, weights, bias);
6851 auto *DN = F_->createDequantize("dq", FC, ElemKind::FloatTy);
6852 auto *RN = F_->createRELU("relu", DN);
6853 auto *AN = F_->createAdd("add", RN, addW);
6854 auto *SN = F_->createSave("save", AN);
6855
6856 CompilationContext cctx;
6857 cctx.precisionConfig.loadUniquedDummyQParams = true;
6858 optimizedF_ = optimizeFunctionForTest(
6859 F_, {FunctionPassID::OptimizeQuantFCFloatRelu, getDCEPassConfig()}, cctx);
6860
6861 SaveNode *optSN =
6862 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
6863 ASSERT_TRUE(optSN);
6864
6865 AddNode *optAN = llvm::dyn_cast<AddNode>(optSN->getInput());
6866 ASSERT_TRUE(optAN);
6867 DequantizeNode *optDN = llvm::dyn_cast<DequantizeNode>(optAN->getLHS());
6868 ASSERT_TRUE(optDN);
6869 ReluNode *optRN = llvm::dyn_cast<ReluNode>(optDN->getInput());
6870 ASSERT_TRUE(optRN);
6871 auto rangeRN = optRN->getResult().getType()->getQuantizedValueRange();
6872 FullyConnectedNode *optFC =
6873 llvm::dyn_cast<FullyConnectedNode>(optRN->getInput());
6874 ASSERT_TRUE(optFC);
6875 auto rangeFC = optFC->getResult().getType()->getQuantizedValueRange();
6876 EXPECT_EQ(rangeRN.first, rangeFC.first);
6877 EXPECT_EQ(rangeRN.second, rangeFC.second);
6878
6879 bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127,
6880 mod_.getPRNG());
6881 bindings_.allocate(addW)->getHandle<float>().randomize(-128, 127,
6882 mod_.getPRNG());
6883 weights->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127,
6884 mod_.getPRNG());
6885 bias->getPayloadMutable().getHandle<int32_t>().randomize(-128, 127,
6886 mod_.getPRNG());
6887 checkNumericalEquivalence();
6888}
6889
6890/// Test that we can find a non-quantized relu and fuse it up into a series of
6891/// concatenated quant FCs.
6892TEST_F(GraphOptz, OptimizeConcatQuantFCFloatReluTest) {
6893 std::array<NodeValue, 5> DQs;
6894 for (size_t i = 0; i < 5; i++) {
6895 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32},
6896 1.0 / (i + 1), 0, "input", false);
6897 auto *weights =
6898 mod_.createConstant(ElemKind::Int8QTy, {32, 32}, 1.0, 0, "weights");
6899 auto *bias = mod_.createConstant(ElemKind::Int32QTy, {32}, 1.0, 0, "bias");
6900
6901 auto *FC = F_->createFullyConnected("fc", input, weights, bias);
6902 DQs[i] = F_->createDequantize("dq", FC, ElemKind::FloatTy)->getResult();
6903
6904 bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127,
6905 mod_.getPRNG());
6906 weights->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127,
6907 mod_.getPRNG());
6908 bias->getPayloadMutable().getHandle<int32_t>().randomize(-128, 127,
6909 mod_.getPRNG());
6910 }
6911
6912 auto *CN = F_->createConcat("concat", DQs, 0);
6913 auto *RN = F_->createRELU("relu", CN);
6914 auto *SN = F_->createSave("save", RN);
6915
6916 optimizedF_ = optimizeFunctionForTest(
6917 F_, {FunctionPassID::OptimizeQuantFCFloatRelu, getDCEPassConfig()});
6918
6919 SaveNode *optSN =
6920 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
6921 ASSERT_TRUE(optSN);
6922 ConcatNode *optCN = llvm::dyn_cast<ConcatNode>(optSN->getInput());
6923 ASSERT_TRUE(optCN);
6924 EXPECT_EQ(optCN->getInputs().size(), 5);
6925
6926 for (const NodeValue &NV : optCN->getInputs()) {
6927 DequantizeNode *optDN = llvm::dyn_cast<DequantizeNode>(NV);
6928 ASSERT_TRUE(optDN);
6929 ReluNode *optRN = llvm::dyn_cast<ReluNode>(optDN->getInput());
6930 ASSERT_TRUE(optRN);
6931 auto rangeRN = optRN->getResult().getType()->getQuantizedValueRange();
6932 EXPECT_EQ(rangeRN.first, 0.0f);
6933 FullyConnectedNode *optFC =
6934 llvm::dyn_cast<FullyConnectedNode>(optRN->getInput());
6935 ASSERT_TRUE(optFC);
6936 auto rangeFC = optFC->getResult().getType()->getQuantizedValueRange();
6937 EXPECT_EQ(rangeRN.second, rangeFC.second);
6938 }
6939
6940 checkNumericalEquivalence();
6941}
6942
6943/// Test that we can find a concat with all dequantize inputs and a quantize at
6944/// its output, and then replace quant/dequants with rescales.
6945TEST_F(GraphOptz, OptimizeDequantConcatQuant) {
6946 std::array<NodeValue, 5> DQs;
6947 std::array<Placeholder *, 5> inputs;
6948 for (size_t i = 0; i < 5; i++) {
6949 inputs[i] = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32},
6950 0.3 / (i + 1), 5, "input", false);
6951 DQs[i] =
6952 F_->createDequantize("dq", inputs[i], ElemKind::FloatTy)->getResult();
6953
6954 bindings_.allocate(inputs[i])->getHandle<int8_t>().randomize(
6955 -128, 127, mod_.getPRNG());
6956 }
6957
6958 auto *CN = F_->createConcat("concat", DQs, 0);
6959 constexpr float scale = 0.3;
6960 constexpr int32_t offset = 5;
6961 auto *RN = F_->createQuantize("quantize", CN,
6962 mod_.uniqueType(ElemKind::Int8QTy,
6963 CN->getResult().dims(), scale,
6964 offset));
6965 auto *SN = F_->createSave("save", RN);
6966
6967 optimizedF_ = optimizeFunctionForTest(
6968 F_, {FunctionPassID::OptimizeConcatQuantization, getDCEPassConfig()});
6969
6970 SaveNode *optSN =
6971 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
6972 ASSERT_TRUE(optSN);
6973 ConcatNode *optCN = llvm::dyn_cast<ConcatNode>(optSN->getInput());
6974 ASSERT_TRUE(optCN);
6975 EXPECT_EQ(optCN->getInputs().size(), 5);
6976
6977 for (size_t i = 0, e = optCN->getInputs().size(); i < e; i++) {
6978 const NodeValue NV = optCN->getInputs()[i];
6979 if (i == 0) {
6980 EXPECT_EQ(inputs[i], NV.getNode());
6981 EXPECT_EQ(inputs[i]->getOutput().getType()->getScale(), scale);
6982 EXPECT_EQ(inputs[i]->getOutput().getType()->getOffset(), offset);
6983 } else {
6984 RescaleQuantizedNode *optRN = llvm::dyn_cast<RescaleQuantizedNode>(NV);
6985 ASSERT_TRUE(optRN);
6986 EXPECT_EQ(optRN->getResult().getType()->getScale(), scale);
6987 EXPECT_EQ(optRN->getResult().getType()->getOffset(), offset);
6988 EXPECT_EQ(inputs[i], optRN->getInput().getNode());
6989 }
6990 }
6991 checkNumericalEquivalence();
6992}
6993
6994/// Test that if we have a Concat with all Dequantize inputs with the same
6995/// scale/offset/kind that we can sink the Dequantizes below the Concat.
6996TEST_F(GraphOptz, SinkDequantizeBelowConcatTest) {
6997 const float scale = 0.06;
6998 const int32_t offset = -15;
6999 std::array<NodeValue, 5> inputs;
7000 for (dim_t i = 0; i < 5; i++) {
7001 Placeholder *input = mod_.createPlaceholder(ElemKind::Int8QTy, {i + 1, 100},
7002 scale, offset, "input", false);
7003 bindings_.allocate(input)->getHandle<int8_t>().randomize(-100, 100,
7004 mod_.getPRNG());
7005 DequantizeNode *dequantize =
7006 F_->createDequantize("dequantize", input, ElemKind::Float16Ty);
7007 inputs[i] = dequantize->getResult();
7008 }
7009 ConcatNode *concat = F_->createConcat("concat", inputs, 0);
7010 SaveNode *SN = F_->createSave("ret", concat);
7011
7012 optimizedF_ = optimizeFunctionForTest(
7013 F_, {FunctionPassID::SinkConversions, getDCEPassConfig()});
7014
7015 // Concat, dequantize, save.
7016 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
7017 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::DequantizeNodeKind), 1);
7018 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1);
7019 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1);
7020
7021 SaveNode *optSN =
7022 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
7023 ASSERT_TRUE(optSN);
7024 DequantizeNode *optDequantize =
7025 llvm::dyn_cast<DequantizeNode>(optSN->getInput());
7026 ASSERT_TRUE(optDequantize);
7027 NodeValue input = optDequantize->getInput();
7028 EXPECT_EQ(scale, input.getType()->getScale());
7029 EXPECT_EQ(offset, input.getType()->getOffset());
7030 EXPECT_EQ(ElemKind::Int8QTy, input.getType()->getElementType());
7031
7032 // Find dequantize node in the optimized graph.
7033 checkNumericalEquivalence();
7034}
7035
/// Test that if we have a Concat with all Quantize inputs with the same
/// scale/offset/kind that we can sink the Quantizes below the Concat.
7038TEST_F(GraphOptz, SinkQuantizeBelowConcatTest) {
7039 const float scale = 0.06;
7040 const int32_t offset = -15;
7041 std::array<NodeValue, 5> inputs;
7042 for (dim_t i = 0; i < 5; i++) {
7043 Placeholder *input = mod_.createPlaceholder(ElemKind::Float16Ty,
7044 {i + 1, 100}, "input", false);
7045 bindings_.allocate(input)->getHandle<float16_t>().randomize(-100, 100,
7046 mod_.getPRNG());
7047 const TypeRef QTy = mod_.uniqueType(
7048 ElemKind::Int8QTy, input->getOutput().dims(), scale, offset);
7049 QuantizeNode *quantize = F_->createQuantize("quantize", input, QTy);
7050 inputs[i] = quantize->getResult();
7051 }
7052 ConcatNode *concat = F_->createConcat("concat", inputs, 0);
7053 SaveNode *SN = F_->createSave("ret", concat);
7054
7055 optimizedF_ = optimizeFunctionForTest(
7056 F_, {FunctionPassID::SinkConversions, getDCEPassConfig()});
7057
7058 // Concat, quantize, save.
7059 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
7060 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::QuantizeNodeKind), 1);
7061 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1);
7062 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1);
7063
7064 SaveNode *optSN =
7065 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
7066 ASSERT_TRUE(optSN);
7067 QuantizeNode *optQuantize = llvm::dyn_cast<QuantizeNode>(optSN->getInput());
7068 ASSERT_TRUE(optQuantize);
7069 EXPECT_EQ(scale, optQuantize->getResult().getType()->getScale());
7070 EXPECT_EQ(offset, optQuantize->getResult().getType()->getOffset());
7071 EXPECT_EQ(ElemKind::Int8QTy,
7072 optQuantize->getResult().getType()->getElementType());
7073
7074 // Find quantize node in the optimized graph.
7075 checkNumericalEquivalence();
7076}
7077
7078/// Test that if we have a Concat with all Tanh inputs,
/// we can sink the Tanhs below the Concat.
7080TEST_F(GraphOptz, SinkTanhBelowConcatTest) {
7081 std::array<NodeValue, 5> inputs;
7082 for (dim_t i = 0; i < 5; i++) {
7083 Placeholder *input = mod_.createPlaceholder(ElemKind::Float16Ty,
7084 {i + 1, 100}, "input", false);
7085 bindings_.allocate(input)->getHandle<float16_t>().randomize(-100, 100,
7086 mod_.getPRNG());
7087 TanhNode *tanh = F_->createTanh("tanh", input);
7088 inputs[i] = tanh->getResult();
7089 }
7090 ConcatNode *concat = F_->createConcat("concat", inputs, 0);
7091 SaveNode *SN = F_->createSave("ret", concat);
7092 EXPECT_EQ(F_->getNodes().size(), 7);
7093 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TanhNodeKind), 5);
7094 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConcatNodeKind), 1);
7095 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 1);
7096
7097 CompilationContext cctx;
7098 cctx.optimizationOpts.sinkTanhBelowConcat = true;
7099
7100 optimizedF_ = optimizeFunctionForTest(
7101 F_, {FunctionPassID::SinkConversions, getDCEPassConfig()}, cctx);
7102
  // Concat, tanh, save.
7104 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
7105 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::TanhNodeKind), 1);
7106 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1);
7107 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1);
7108
7109 SaveNode *optSN =
7110 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
7111 ASSERT_TRUE(optSN);
7112 TanhNode *optTanh = llvm::dyn_cast<TanhNode>(optSN->getInput());
7113 ASSERT_TRUE(optTanh);
7114 NodeValue input = optTanh->getInput();
7115 EXPECT_EQ(ElemKind::Float16Ty, input.getType()->getElementType());
7116
7117 checkNumericalEquivalence();
7118}
7119
7120/// Test that if we have a Concat with all ConvertTo inputs,
/// we can sink the ConvertTos below the Concat.
7122TEST_F(GraphOptz, SinkConvertToBelowConcatTest) {
7123 std::array<NodeValue, 5> inputs;
7124 for (dim_t i = 0; i < 5; i++) {
7125 Placeholder *input = mod_.createPlaceholder(ElemKind::Float16Ty,
7126 {i + 1, 100}, "input", false);
7127 bindings_.allocate(input)->getHandle<float16_t>().randomize(-100, 100,
7128 mod_.getPRNG());
7129 ConvertToNode *convertTo =
7130 F_->createConvertTo("convertToFP32", input, ElemKind::FloatTy);
7131 inputs[i] = convertTo->getResult();
7132 }
7133 ConcatNode *concat = F_->createConcat("concat", inputs, 0);
7134 SaveNode *SN = F_->createSave("ret", concat);
7135 EXPECT_EQ(F_->getNodes().size(), 7);
7136 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConvertToNodeKind), 5);
7137 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConcatNodeKind), 1);
7138 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 1);
7139
7140 CompilationContext cctx;
7141
7142 optimizedF_ = optimizeFunctionForTest(
7143 F_, {FunctionPassID::SinkConversions, getDCEPassConfig()}, cctx);
7144
  // Concat, convertTo, save.
7146 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
7147 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConvertToNodeKind), 1);
7148 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1);
7149 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1);
7150
7151 SaveNode *optSN =
7152 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
7153 ASSERT_TRUE(optSN);
7154 ConvertToNode *optConvertTo =
7155 llvm::dyn_cast<ConvertToNode>(optSN->getInput());
7156 ASSERT_TRUE(optConvertTo);
7157 NodeValue input = optConvertTo->getInput();
7158 EXPECT_EQ(ElemKind::Float16Ty, input.getType()->getElementType());
7159
7160 checkNumericalEquivalence();
7161}
7162
7163/// Test Clip(Relu) -> Clip'.
7164TEST_F(GraphOptz, ClipReluTest) {
7165 Placeholder *input =
7166 mod_.createPlaceholder(ElemKind::Float16Ty, {64, 64}, "input", false);
7167 ReluNode *RN = F_->createRELU("RN", input);
7168 ClipNode *CN = F_->createClip("CN", RN, -5, 20);
7169 SaveNode *SN = F_->createSave("save", CN);
7170
7171 // Start with a clip, a relu, and a save.
7172 EXPECT_EQ(F_->getNodes().size(), 3);
7173 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ClipNodeKind), 1);
7174 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 1);
7175
7176 optimizedF_ = optimizeFunctionForTest(F_);
7177
  // The Relu is removed.
7179 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
7180 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ClipNodeKind), 1);
7181 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ReluNodeKind), 0);
7182
7183 SaveNode *optSN =
7184 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
7185 ASSERT_TRUE(optSN);
7186
7187 // We have the same max for clip as before, but 0 for min due to the Relu.
7188 ClipNode *optCN = llvm::dyn_cast<ClipNode>(optSN->getInput());
7189 ASSERT_TRUE(optCN);
7190 EXPECT_EQ(optCN->getMin(), 0);
7191 EXPECT_EQ(optCN->getMax(), 20);
7192
7193 bindings_.allocate(input)->getHandle<float16_t>().randomize(-50.0, 5.0,
7194 mod_.getPRNG());
7195 checkNumericalEquivalence();
7196}
7197
/// Test that if we have a Concat with some Dequantize inputs and a Quantize
/// after the Concat, we can move the Quantize above the Concat and eliminate
/// the Dequantizes.
7201TEST_F(GraphOptz, SinkConcatBelowQuantize) {
7202 const float scale = 0.06;
7203 const int32_t offset = -15;
7204 std::array<NodeValue, 3> inputs;
7205
7206 // Concat input 0: Dequant(PH)
7207 const TypeRef in0QTy =
7208 mod_.uniqueType(ElemKind::Int8QTy, {1, 3}, scale, offset);
7209 Placeholder *input0 = mod_.createPlaceholder(in0QTy, "input", false);
7210 inputs[0] =
7211 F_->createDequantize("deq", input0, ElemKind::Float16Ty)->getResult();
7212
7213 // Concat input 1: Dequant(Add(PH, PH))
7214 const TypeRef in1QTy =
7215 mod_.uniqueType(ElemKind::Int8QTy, {5, 3}, scale, offset + 1);
7216 Placeholder *input1 = mod_.createPlaceholder(in1QTy, "input", false);
7217 AddNode *add = F_->createAdd("add", input1, input1);
7218 inputs[1] =
7219 F_->createDequantize("deq", add, ElemKind::Float16Ty)->getResult();
7220
7221 // Concat input 2: PH
7222 Placeholder *input2 =
7223 mod_.createPlaceholder(ElemKind::Float16Ty, {10, 3}, "input_fp", false);
7224 inputs[2] = input2->getOutput();
7225
7226 // Concat all 3 together, all FP16.
7227 ConcatNode *concat = F_->createConcat("concat", inputs, 0);
7228
7229 // Now quantize the result of the concat.
7230 const TypeRef QTy = mod_.uniqueType(
7231 ElemKind::Int8QTy, concat->getResult().dims(), scale, offset);
7232 QuantizeNode *QN = F_->createQuantize("quantize", concat, QTy);
7233 SaveNode *SN = F_->createSave("ret", QN);
7234
7235 optimizedF_ = optimizeFunctionForTest(
7236 F_,
7237 {FunctionPassID::SinkConcatBelowQuantize,
7238 {FunctionPassID::OptimizeQuantization, ConvergenceMode::UntilFixedPoint},
7239 getDCEPassConfig()});
7240
7241 EXPECT_EQ(optimizedF_->getNodes().size(), 4);
7242 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1);
7243 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind), 1);
7244 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::QuantizeNodeKind), 1);
7245 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1);
7246
7247 SaveNode *optSN =
7248 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
7249 ASSERT_TRUE(optSN);
7250
7251 // Concat should be directly connected to save, with same quantization
7252 // parameters as the quantize which used to follow it.
7253 ConcatNode *optCN = llvm::dyn_cast<ConcatNode>(optSN->getInput());
7254 ASSERT_TRUE(optCN);
7255 ASSERT_EQ(ElemKind::Int8QTy, optCN->getResult().getType()->getElementType());
7256 EXPECT_EQ(scale, optCN->getResult().getType()->getScale());
7257 EXPECT_EQ(offset, optCN->getResult().getType()->getOffset());
7258
7259 ASSERT_EQ(optCN->getInputs().size(), 3);
7260
  // No rescale here for the PH since its scale/offset already match those of
  // the Quantize, so the rescale is optimized away.
7263 EXPECT_EQ(optCN->getInputs()[0], input0->getOutput());
7264
  // No rescale here because it is fused into optAN; check that optAN's result
  // uses the Quantize's scale/offset.
7267 AddNode *optAN = llvm::dyn_cast<AddNode>(optCN->getInputs()[1]);
7268 ASSERT_TRUE(optAN);
7269 ASSERT_EQ(ElemKind::Int8QTy, optAN->getResult().getType()->getElementType());
7270 EXPECT_EQ(scale, optAN->getResult().getType()->getScale());
7271 EXPECT_EQ(offset, optAN->getResult().getType()->getOffset());
7272 EXPECT_EQ(optAN->getLHS(), input1->getOutput());
7273 EXPECT_EQ(optAN->getRHS(), input1->getOutput());
7274
7275 // Must quantize this input since the PH is float16.
7276 QuantizeNode *optQN = llvm::dyn_cast<QuantizeNode>(optCN->getInputs()[2]);
7277 ASSERT_TRUE(optQN);
7278 ASSERT_EQ(ElemKind::Int8QTy, optQN->getResult().getType()->getElementType());
7279 EXPECT_EQ(scale, optQN->getResult().getType()->getScale());
7280 EXPECT_EQ(offset, optQN->getResult().getType()->getOffset());
7281 EXPECT_EQ(optQN->getInput(), input2->getOutput());
7282
7283 bindings_.allocate(input0)->getHandle<int8_t>().randomize(-50, 50,
7284 mod_.getPRNG());
7285 bindings_.allocate(input1)->getHandle<int8_t>().randomize(-50, 50,
7286 mod_.getPRNG());
7287 bindings_.allocate(input2)->getHandle<float16_t>().randomize(-10, 10,
7288 mod_.getPRNG());
  checkNumericalEquivalence();
}
7290
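/// Test that EliminateSliceConcat merges contiguous Slices of the same source
/// feeding a Concat, while interleaved Slices from different sources are left
/// unmerged.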
7291TEST_F(GraphOptz, EliminateSliceConcatTest) {
7292 auto *src1 =
7293 mod_.createPlaceholder(ElemKind::FloatTy, {10, 70}, "src1", false);
7294 auto *src2 =
7295 mod_.createPlaceholder(ElemKind::FloatTy, {10, 80}, "src2", false);
7296 auto *A = F_->createSlice("A", src1, {0, 0}, {10, 10});
7297 auto *B = F_->createSlice("B", src1, {0, 10}, {10, 20});
7298 auto *C = F_->createSlice("C", src1, {0, 20}, {10, 30});
  // Interleaved Slices with different sources shouldn't be merged.
7300 auto *E = F_->createSlice("E", src1, {0, 30}, {10, 40});
7301 auto *F = F_->createSlice("F", src2, {0, 30}, {10, 40});
7302 auto *G = F_->createSlice("G", src1, {0, 40}, {10, 50});
7303 auto *H = F_->createSlice("H", src2, {0, 40}, {10, 50});
7304
7305 auto *D = mod_.createPlaceholder(ElemKind::FloatTy, {10, 50}, "D", false);
7306 auto *R = F_->createRELU("Relu", C);
7307 auto *CN = F_->createConcat("Concat", {A, B, D, E, F, G, H}, 1);
7308 F_->createSave("save1", CN);
7309 F_->createSave("save2", R);
7310
7311 EXPECT_EQ(F_->getNodes().size(), 11);
7312
7313 optimizedF_ = optimizeFunctionForTest(
7314 F_, {FunctionPassID::EliminateSliceConcat, getDCEPassConfig()});
7315
7316 EXPECT_EQ(optimizedF_->getNodes().size(), 10);
7317
7318 int numSlicesToConcat = 0;
7319 for (const auto &node : optimizedF_->getNodes()) {
7320 auto *newCN = llvm::dyn_cast<ConcatNode>(&node);
7321 if (!newCN) {
7322 continue;
7323 }
7324 EXPECT_EQ(newCN->getInputs().size(), 6);
7325 for (const auto &concatInput : newCN->getInputs()) {
7326 auto *SN = llvm::dyn_cast<SliceNode>(concatInput.getNode());
7327 if (SN) {
7328 numSlicesToConcat++;
7329 }
7330 }
7331 }
7332 EXPECT_EQ(numSlicesToConcat, 5);
7333
7334 bindings_.allocate(src1)->getHandle<float>().randomize(-10.0, 10.0,
7335 mod_.getPRNG());
7336 bindings_.allocate(src2)->getHandle<float>().randomize(-10.0, 10.0,
7337 mod_.getPRNG());
7338 bindings_.allocate(D)->getHandle<float>().randomize(-10.0, 10.0,
7339 mod_.getPRNG());
7340 checkNumericalEquivalence();
7341}
7342
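/// Test that EliminateSliceConcat can also merge Slices whose slicing axis is
/// adjacent to the concatenation axis, introducing Reshape and Transpose nodes
/// in the process.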
7343TEST_F(GraphOptz, EliminateSliceConcatWithReshapeTest) {
7344 auto *src =
7345 mod_.createPlaceholder(ElemKind::FloatTy, {4, 5, 4}, "src", false);
7346 auto *A = F_->createSlice("A", src, {0, 0, 0}, {1, 5, 4});
7347 auto *B = F_->createSlice("B", src, {1, 0, 0}, {2, 5, 4});
7348 auto *C = F_->createSlice("C", src, {2, 0, 0}, {3, 5, 4});
7349 auto *CN1 = F_->createConcat("Concat1", {A, B, C}, 1);
7350
7351 auto *E = F_->createSlice("E", src, {0, 0, 0}, {4, 5, 1});
7352 auto *F = F_->createSlice("F", src, {0, 0, 1}, {4, 5, 2});
7353 auto *G = F_->createSlice("G", src, {0, 0, 2}, {4, 5, 3});
7354 auto *H = F_->createSlice("H", src, {0, 0, 3}, {4, 5, 4});
7355 auto *CN2 = F_->createConcat("Concat2", {E, F, G, H}, 1);
7356
7357 F_->createSave("save1", CN1);
7358 F_->createSave("save2", CN2);
7359
7360 EXPECT_EQ(F_->getNodes().size(), 11);
7361 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 7);
7362 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConcatNodeKind), 2);
7363
7364 optimizedF_ = optimizeFunctionForTest(
7365 F_, {FunctionPassID::EliminateSliceConcat, getDCEPassConfig()});
7366
7367 EXPECT_EQ(optimizedF_->getNodes().size(), 9);
7368 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SliceNodeKind), 2);
7369 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 2);
7370 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ReshapeNodeKind), 2);
7371 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::TransposeNodeKind), 1);
7372
7373 bindings_.allocate(src)->getHandle<float>().randomize(-10.0, 10.0,
7374 mod_.getPRNG());
7375 checkNumericalEquivalence(0.f);
7376}
7377
/// Check the merging of Sub(const, BN(x, scale, bias)) into BN.
7379TEST_F(GraphOptz, FoldArithmeticChainIntoBatchNormQuant) {
7380 auto *subC = mod_.createConstant(ElemKind::FloatTy, {1, 1, 1, 1}, "subC");
7381 auto *var = mod_.createConstant(ElemKind::FloatTy, {1}, "var");
7382 auto *mean = mod_.createConstant(ElemKind::FloatTy, {1}, "mean");
7383 auto *beta = mod_.createConstant(ElemKind::FloatTy, {1}, "beta");
7384 auto *gamma = mod_.createConstant(ElemKind::FloatTy, {1}, "gamma");
  float v = 0.3f, m = 0.4f, b = 0.7f, g = -0.5f, c = 0.1f;
7386 // (X - mean) * (1.0 / sqrt(var + eps)) * gamma + beta
7387 var->getPayloadMutable().getHandle<float>() = {v};
7388 mean->getPayloadMutable().getHandle<float>() = {m};
7389 beta->getPayloadMutable().getHandle<float>() = {b};
7390 gamma->getPayloadMutable().getHandle<float>() = {g};
7391 subC->getPayloadMutable().getHandle<float>() = {c};
7392 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 1, 1}, "input",
7393 false, "NHWC");
7394
7395 auto *BN = F_->createBatchNormalization("batch", input->getType(), input,
7396 beta, gamma, mean, var);
7397 auto *sub = F_->createSub("sub", subC, BN);
7398 auto *res = F_->createSave("save", sub);
7399 // Compile.
7400 EXPECT_EQ(F_->getNodes().size(), 3);
7401 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
7402
7403 optimizedF_ = optimizeFunctionForTest(F_, {}, cctx_);
7404 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
7405
  auto *opt_res = findFunctionNodeByName<SaveNode>(optimizedF_, res->getName());
  ASSERT_TRUE(opt_res);
  auto *opt_bn = llvm::dyn_cast<BatchNormalizationNode>(opt_res->getInput());
  ASSERT_TRUE(opt_bn);
  // Verify that the new scale and bias are computed correctly.
  Constant *bnScale = llvm::dyn_cast<Constant>(opt_bn->getScale().getNode());
  Constant *bnBias = llvm::dyn_cast<Constant>(opt_bn->getBias().getNode());
  ASSERT_TRUE(bnScale);
  ASSERT_TRUE(bnBias);
  auto bnBiasVals = bnBias->getHandle<float>().raw(0);
  auto bnScaleVals = bnScale->getHandle<float>().raw(0);
7414 EXPECT_EQ(bnBiasVals, c - b);
7415 EXPECT_EQ(bnScaleVals, -g);
7416}
7417
/// Test that EliminateSliceConcat makes no optimization when the axes of
/// concatenation and slicing are not adjacent.
7420TEST_F(GraphOptz, EliminateSliceConcatWithReshapeTestNoChange) {
7421 auto *src =
7422 mod_.createPlaceholder(ElemKind::FloatTy, {4, 5, 4}, "src", false);
7423 auto *A = F_->createSlice("A", src, {0, 0, 0}, {1, 5, 4});
7424 auto *B = F_->createSlice("B", src, {1, 0, 0}, {2, 5, 4});
7425 auto *C = F_->createSlice("C", src, {2, 0, 0}, {3, 5, 4});
7426 auto *CN = F_->createConcat("Concat", {A, B, C}, 2);
7427
7428 F_->createSave("save", CN);
7429
7430 EXPECT_EQ(F_->getNodes().size(), 5);
7431 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 3);
7432 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConcatNodeKind), 1);
7433
7434 optimizedF_ = optimizeFunctionForTest(
7435 F_, {FunctionPassID::EliminateSliceConcat, getDCEPassConfig()});
7436
7437 EXPECT_EQ(optimizedF_->getNodes().size(), 5);
7438 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SliceNodeKind), 3);
7439 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1);
7440 EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false,
7441 /* skipName */ true),
7442 optimizedF_->toString(/* skipUsersForStorage */ false,
7443 /* skipName */ true));
7444}
7445
7446/// Verify that when we want to prevent constant folding it doesn't occur.
7447TEST_F(GraphOptz, constantFoldPreventedNoop) {
7448 auto *const1 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const1");
7449 auto *const2 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const2");
7450 auto *ph1 = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "input1",
7451 /* isTrainable */ false);
7452 setConstValue(const1, 1.0f);
7453 setConstValue(const2, 2.0f);
7454 auto *splat2 = F_->createSplat(
7455 "splat2", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 2.0f);
7456 auto *splat3 = F_->createSplat(
7457 "splat3", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 3.0f);
7458
7459 auto *add1 = F_->createAdd("add", const1, const2);
7460 auto *mul1 = F_->createMul("mul1", add1, splat2);
7461 auto *mul2 = F_->createMul("mul2", mul1, splat3);
7462 F_->createSave("save", mul2);
7463 auto *add3 = F_->createAdd("add", const1, ph1);
7464 F_->createSave("save", add3);
7465
7466 ConstantModificationPreventer constModPreventer(mod_, cctx_);
7467 constModPreventer.activate();
7468 EXPECT_FALSE(cctx_.optimizationOpts.enableConstantFolding);
7469
7470 // Check that both Constants are protected and no change is made to the
7471 // Function during optimization.
7472 EXPECT_EQ(constModPreventer.getMapping().size(), 2);
7473 optimizedF_ = optimizeFunctionForTest(F_);
7474 EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false,
7475 /* skipName */ true),
7476 optimizedF_->toString(/* skipUsersForStorage */ false,
7477 /* skipName */ true));
7478
7479 // Now deactivate the constModPreventer and check we can const fold still.
7480 constModPreventer.deactivateAndCleanup();
7481 EXPECT_TRUE(cctx_.optimizationOpts.enableConstantFolding);
7482 mod_.eraseFunction(optimizedF_);
7483 optimizedF_ = optimizeFunctionForTest(F_);
7484
7485 // After constant folding, left with just two Saves, one Add.
7486 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
7487 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind), 1);
7488 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 2);
7489
7490 bindings_.allocate(ph1)->getHandle<float>().randomize(-10.0, 10.0,
7491 mod_.getPRNG());
7492 checkNumericalEquivalence();
7493}
7494
7495/// Test that a Conv2D is correctly lowered to FC for single batch.
7496TEST_F(GraphOptz, lowerConv2DToFCSingleBatch) {
7497 Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4},
7498 "input", /* isTrainable */ false);
7499 bindings_.allocate(input)->getHandle<float>().randomize(-10, 10,
7500 mod_.getPRNG());
7501
7502 Constant *filter =
7503 mod_.createConstant(ElemKind::FloatTy, {8, 1, 1, 4}, "filter");
7504 filter->getPayloadMutable().getHandle<float>().randomize(-10, 10,
7505 mod_.getPRNG());
7506
7507 Constant *bias = mod_.createConstant(ElemKind::FloatTy, {8}, "bias");
7508 bias->getPayloadMutable().getHandle<float>().randomize(-10, 10,
7509 mod_.getPRNG());
7510
7511 auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 2, 3, 8});
7512 auto *conv = F_->createConv("conv", input, filter, bias, outTy, {1, 1},
7513 {1, 1}, {0, 0, 0, 0}, 1);
7514 SaveNode *save = F_->createSave("save", conv);
7515 bindings_.allocate(save->getPlaceholder());
7516
7517 // Backup function in optimizedF_.
7518 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
7519
7520 // Lower Convolution.
7521 EXPECT_TRUE(isConvolutionSameAsFullyConnected(conv));
7522 EXPECT_TRUE(glow::lowerNode(F_, conv, cctx_));
7523 runDCEPass(F_, cctx_);
7524 EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::ConvolutionNodeKind));
7525 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
7526
7527 // Now compile/run/compare F_ and optimizedF_.
7528 checkNumericalEquivalence(1e-6);
7529}
7530
7531/// Test that a Conv2D is correctly lowered to FC for multi batch.
7532TEST_F(GraphOptz, lowerConv2DToFCMultiBatch) {
7533 Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 3, 4},
7534 "input", /* isTrainable */ false);
7535 bindings_.allocate(input)->getHandle<float>().randomize(-10, 10,
7536 mod_.getPRNG());
7537
7538 Constant *filter =
7539 mod_.createConstant(ElemKind::FloatTy, {8, 1, 1, 4}, "filter");
7540 filter->getPayloadMutable().getHandle<float>().randomize(-10, 10,
7541 mod_.getPRNG());
7542
7543 Constant *bias = mod_.createConstant(ElemKind::FloatTy, {8}, "bias");
7544 bias->getPayloadMutable().getHandle<float>().randomize(-10, 10,
7545 mod_.getPRNG());
7546
7547 auto outTy = mod_.uniqueType(ElemKind::FloatTy, {2, 2, 3, 8});
7548 auto *conv = F_->createConv("conv", input, filter, bias, outTy, {1, 1},
7549 {1, 1}, {0, 0, 0, 0}, 1);
7550 SaveNode *save = F_->createSave("save", conv);
7551 bindings_.allocate(save->getPlaceholder());
7552
7553 // Backup function in optimizedF_.
7554 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
7555
7556 // Lower Convolution.
7557 EXPECT_TRUE(isConvolutionSameAsFullyConnected(conv));
7558 EXPECT_TRUE(glow::lowerNode(F_, conv, cctx_));
7559 runDCEPass(F_, cctx_);
7560 EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::ConvolutionNodeKind));
7561 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
7562
7563 // Now compile/run/compare F_ and optimizedF_.
7564 checkNumericalEquivalence(1e-6);
7565}
7566
7567/// Test that Mul and Add can be folded into LayerNorm.
7568TEST_F(GraphOptz, foldMulAddIntoLayerNorm) {
7569 auto *input =
7570 mod_.createPlaceholder(ElemKind::FloatTy, {2, 4, 10, 20}, "in", false);
7571
7572 Tensor scaleT(ElemKind::FloatTy, {10, 20});
7573 scaleT.getHandle().randomize(0.0f, 1.0f, mod_.getPRNG());
7574 Constant *scaleC = mod_.createConstant("scale", std::move(scaleT));
7575 SplatNode *biasS = F_->createSplat("bias", scaleC->getType(), 1.5f);
7576
7577 auto *LN = F_->createLayerNormalization("LN", input->getType(), input, scaleC,
7578 biasS, 1e-5);
7579
7580 SplatNode *splat = F_->createSplat("splat", scaleC->getType(), 0.5f);
7581 MulNode *MN =
7582 F_->createNodeWithBroadcast<MulNode>("mul", /* axis */ -1, LN, splat);
7583
7584 Tensor addT(ElemKind::FloatTy, {1, 1, 10, 20});
7585 addT.getHandle().randomize(-1.0f, 1.0f, mod_.getPRNG());
7586 Constant *addC = mod_.createConstant("addC", std::move(addT));
7587 AddNode *AN =
7588 F_->createNodeWithBroadcast<AddNode>("add", /* axis */ -1, MN, addC);
7589
7590 // This MulNode has a Placeholder as RHS and shouldn't be fused into LayerNorm
7591 Tensor mulT(ElemKind::FloatTy, {1, 1, 10, 20});
7592 mulT.getHandle().randomize(0.0f, 1.0f, mod_.getPRNG());
7593 Constant *mulC = mod_.createConstant("mulC", std::move(mulT));
7594 MN = F_->createNodeWithBroadcast<MulNode>("mul_not_fuse", /* axis */ -1, AN,
7595 mulC);
7596 F_->createSave("save", MN);
7597
7598 ConstantModificationPreventer constModPreventer(mod_, cctx_);
7599 constModPreventer.activate();
7600 optimizedF_ = optimizeFunctionForTest(F_, {}, cctx_);
7601 // Now do const folding with constants swapped back in.
7602 constModPreventer.deactivateAndCleanup();
7603 ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_);
7604 runDCEPass(optimizedF_, cctx_);
7605
7606 // Because Muls and Add are folded in, they should not exist anymore, nor
7607 // should Broadcasts that expand them to match the output of LN.
7608 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MulNodeKind));
7609 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind));
7610 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::BroadcastNodeKind));
7611
7612 // Remove the temporary constant folding Functions and their Placeholders
7613 // so that they don't participate in 'checkNumericalEquivalence'.
7614 cleanupConstantFolding(mod_, record, &bindings_);
7615
7616 // Now compile/run/compare F_ and optimizedF_.
7617 bindings_.allocate(input)->getHandle().randomize(0.0f, 1.0f, mod_.getPRNG());
7618 checkNumericalEquivalence(1.2e-7);
7619}
7620
7621/// Test that Mul and Add can be folded into LayerNorm when the leading dims are
7622/// all one.
7623TEST_F(GraphOptz, foldMulAddIntoLayerNormNoBatch) {
7624 auto *input =
7625 mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 10, 20}, "in", false);
7626
7627 Tensor scaleT(ElemKind::FloatTy, {10, 20});
7628 scaleT.getHandle().randomize(0.0f, 1.0f, mod_.getPRNG());
7629 Constant *scaleC = mod_.createConstant("scale", std::move(scaleT));
7630 SplatNode *biasS = F_->createSplat("bias", scaleC->getType(), 1.5f);
7631
7632 auto *LN = F_->createLayerNormalization("LN", input->getType(), input, scaleC,
7633 biasS, 1e-5);
7634
7635 SplatNode *splat = F_->createSplat("splat", scaleC->getType(), 0.5f);
7636 MulNode *MN =
7637 F_->createNodeWithBroadcast<MulNode>("mul", /* axis */ -1, LN, splat);
7638
7639 Tensor addT(ElemKind::FloatTy, {1, 1, 10, 20});
7640 addT.getHandle().randomize(-1.0f, 1.0f, mod_.getPRNG());
7641 Constant *addC = mod_.createConstant("addC", std::move(addT));
7642 AddNode *AN =
7643 F_->createNodeWithBroadcast<AddNode>("add", /* axis */ -1, MN, addC);
7644 F_->createSave("save", AN);
7645
7646 optimizedF_ = optimizeFunctionForTest(F_);
7647
7648 // Because Mul and Add are folded in, they should not exist anymore, nor
7649 // should tiles that expand them to match the output of LN.
7650 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MulNodeKind));
7651 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind));
7652 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::TileNodeKind));
7653
7654 // Now compile/run/compare F_ and optimizedF_.
7655 bindings_.allocate(input)->getHandle().randomize(0.0f, 1.0f, mod_.getPRNG());
7656 checkNumericalEquivalence(1e-6);
7657}
7658
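/// Test that Transpose(Quantize(Constant)) is folded into a transposed
/// Constant followed by a Quantize, while custom type alignments are
/// preserved.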
7659TEST_F(GraphOptz, transposeQuantizeConstantWithAlignment) {
7660 // Define a type with custom alignments.
7661 Type typeWithAlignments(ElemKind::FloatTy, {2, 3, 4, 5}, {1, 1, 32, 1});
7662 Type quantTypeWithAlignments(ElemKind::Int8QTy, {2, 3, 4, 5}, {1, 1, 32, 1},
7663 1.0, 0);
7664 Type transposedQuantTypeWithAlignments(ElemKind::Int8QTy, {2, 4, 5, 3},
7665 {1, 1, 32, 1}, 1.0, 0);
7666 auto modTyWithAlignments = mod_.uniqueType(typeWithAlignments);
7667 auto modQuantTransposedTyWithAlignments =
7668 mod_.uniqueType(transposedQuantTypeWithAlignments);
7669 auto modQuantTyWithAlignments = mod_.uniqueType(quantTypeWithAlignments);
7670 auto *I = mod_.createConstant(modTyWithAlignments, "input1");
7671 auto *Q = F_->createQuantize("quantize", I, modQuantTyWithAlignments);
7672 auto *T = F_->createTranspose("transpose", Q, NCHW2NHWC);
7673 T->setType(TransposeNode::ResultIdx, modQuantTransposedTyWithAlignments);
7674 SaveNode *S = F_->createSave("ret", T);
7675
7676 // Skip ConstantFolding as it would have the same result as this opt.
7677 CompilationContext cctx;
7678 cctx.optimizationOpts.enableConstantFolding = false;
7679
7680 EXPECT_EQ(F_->getNodes().size(), 3);
7681 ::glow::optimize(F_, cctx);
7682 EXPECT_EQ(F_->getNodes().size(), 2);
7683
7684 // Constant and Quantize should have new shape.
7685 auto *newQ = llvm::dyn_cast<QuantizeNode>(S->getInput());
7686 ASSERT_TRUE(newQ);
7687 EXPECT_TRUE(newQ->getResult().dims().equals({2, 4, 5, 3}));
7688 auto *newC = llvm::dyn_cast<Constant>(newQ->getInput());
7689 ASSERT_TRUE(newC);
7690 EXPECT_TRUE(newC->getType()->dims().equals({2, 4, 5, 3}));
7691
7692 // Check that alignments are preserved by optimizations.
7693 auto expectedNewTy = mod_.uniqueTypeWithNewShape(
7694 modTyWithAlignments, modQuantTransposedTyWithAlignments);
7695 EXPECT_TRUE(newQ->getInput().getType()->isEqual(expectedNewTy));
7696
7697 EXPECT_TRUE(F_->verify());
7698}
7699
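/// Test that Dequantize -> Swish -> Quantize is replaced by a Swish operating
/// directly on the quantized input.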
7700TEST_F(GraphOptz, DequantSwishQuantOpt) {
7701 const dim_t origDims[] = {1, 5, 10, 15};
7702 Placeholder *A = mod_.createPlaceholder(ElemKind::Int8QTy, origDims, 0.039, 0,
7703 "input", false);
7704 DequantizeNode *DN = F_->createDequantize("deq", A, ElemKind::Float16Ty);
7705 SwishNode *swish = F_->createSwish("swish", DN);
7706 QuantizeNode *QN =
7707 F_->createQuantize("quant", swish, ElemKind::Int8QTy, 0.0204, -114);
7708 DequantizeNode *finalDN =
7709 F_->createDequantize("deq_final", QN, ElemKind::Float16Ty);
7710 F_->createSave("ret", finalDN);
7711
7712 optimizedF_ = optimizeFunctionForTest(
7713 F_, {FunctionPassID::QuantizeSwish, getDCEPassConfig()});
7714
7715 // Swish, Dequant, Save
7716 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
7717
7718 SaveNode *save = nullptr;
7719 for (auto &N : optimizedF_->getNodes()) {
7720 if (N.getKind() == Kinded::Kind::SaveNodeKind) {
7721 save = llvm::dyn_cast<SaveNode>(&N);
7722 break;
7723 }
7724 }
7725 ASSERT_TRUE(save);
7726
7727 DequantizeNode *dequantizeOpt =
7728 llvm::dyn_cast<DequantizeNode>(save->getInput());
7729 ASSERT_TRUE(dequantizeOpt);
7730
7731 SwishNode *swishOpt = llvm::dyn_cast<SwishNode>(dequantizeOpt->getInput());
7732 ASSERT_TRUE(swishOpt);
7733 EXPECT_EQ(swishOpt->getInput(), A->getOutput());
7734 EXPECT_EQ(swishOpt->getResult().getType(), QN->getResult().getType());
7735
7736 bindings_.allocate(mod_.getPlaceholders());
7737 bindings_.get(A)->getHandle<int8_t>().randomize(-128, 127, mod_.getPRNG());
7738
7739 checkNumericalEquivalence(0.025f);
7740}
7741
7742/// Test the conversion of FullyConnected to 1x1 Convolution.
7743TEST_F(GraphOptz, ConvertFullyConnectedToConvolutionOpt) {
7744
7745 const std::vector<dim_t> inpDims = {3, 5};
7746 const std::vector<dim_t> weightsDims = {5, 7};
7747 const std::vector<dim_t> biasDims = {7};
7748
7749 // Create graph.
7750 Placeholder *input =
7751 mod_.createPlaceholder(ElemKind::FloatTy, inpDims, "input", false);
7752 Placeholder *weights =
7753 mod_.createPlaceholder(ElemKind::FloatTy, weightsDims, "weights", false);
7754 Placeholder *bias =
7755 mod_.createPlaceholder(ElemKind::FloatTy, biasDims, "bias", false);
7756 FullyConnectedNode *FCN =
7757 F_->createFullyConnected("fc", input, weights, bias);
7758 F_->createSave("save", FCN);
7759
7760 // Optimize graph.
7761 optimizedF_ = optimizeFunctionForTest(
7762 F_,
7763 {FunctionPassID::ConvertFullyConnectedToConvolution, getDCEPassConfig()});
7764
7765 // Check optimized graph.
7766 EXPECT_EQ(optimizedF_->getNodes().size(), 6);
7767 SaveNode *save = nullptr;
7768 for (auto &N : optimizedF_->getNodes()) {
7769 if (N.getKind() == Kinded::Kind::SaveNodeKind) {
7770 save = llvm::dyn_cast<SaveNode>(&N);
7771 break;
7772 }
7773 }
7774 ASSERT_TRUE(save);
7775 ReshapeNode *reshapeOut = llvm::dyn_cast<ReshapeNode>(save->getInput());
7776 ASSERT_TRUE(reshapeOut);
7777 ConvolutionNode *conv =
7778 llvm::dyn_cast<ConvolutionNode>(reshapeOut->getInput());
7779 ASSERT_TRUE(conv);
7780 ReshapeNode *reshapeFilter = llvm::dyn_cast<ReshapeNode>(conv->getFilter());
7781 ASSERT_TRUE(reshapeFilter);
7782 TransposeNode *transpFilter =
7783 llvm::dyn_cast<TransposeNode>(reshapeFilter->getInput());
7784 ASSERT_TRUE(transpFilter);
7785 ReshapeNode *reshapeInput = llvm::dyn_cast<ReshapeNode>(conv->getInput());
7786 ASSERT_TRUE(reshapeInput);
7787
7788 // Check numerical equivalence.
7789 bindings_.allocate(mod_.getPlaceholders());
7790 bindings_.get(input)->getHandle<float>().randomize(-1, 1, mod_.getPRNG());
7791 bindings_.get(weights)->getHandle<float>().randomize(-1, 1, mod_.getPRNG());
7792 bindings_.get(bias)->getHandle<float>().randomize(-1, 1, mod_.getPRNG());
7793 checkNumericalEquivalence(1e-8);
7794}
7795
/// Test that when we have Concat({X, Quantize(Clip)}), we don't optimize it to
/// Concat({X, Quantize'}), since Quantize' would have different quantization
/// parameters than X.
7799TEST_F(GraphOptz, DisallowChangeQuantParamWithConcatInput) {
7800 Placeholder *PH1 = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 0.3, 5,
7801 "input", false);
7802 bindings_.allocate(PH1)->getHandle<int8_t>().randomize(-128, 127,
7803 mod_.getPRNG());
7804 Placeholder *PH2 =
7805 mod_.createPlaceholder(ElemKind::Float16Ty, {1, 32}, "input", false);
7806 bindings_.allocate(PH2)->getHandle<float16_t>().randomize(-40.f, 40.f,
7807 mod_.getPRNG());
7808
7809 ClipNode *clip = F_->createClip("clip", PH2, 0.f, 1000.f);
7810 QuantizeNode *quant = F_->createQuantize(
7811 "quantize", clip, mod_.uniqueType(ElemKind::Int8QTy, {1, 32}, 0.3, 5));
7812
7813 ConcatNode *CN = F_->createConcat("concat", {PH1, quant}, 0);
7814 F_->createSave("save", CN);
7815
7816 optimizedF_ = optimizeFunctionForTest(F_);
7817
  // Expect that the graph didn't change at all: merging the Clip into the
  // Quantize is disallowed because the Quantize is consumed by a Concat, which
  // requires the quantization parameters to stay the same across all inputs.
7822 EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false,
7823 /* skipName */ true),
7824 optimizedF_->toString(/* skipUsersForStorage */ false,
7825 /* skipName */ true));
7826
7827 checkNumericalEquivalence();
7828}
7829
/// Test that an AdaptiveAvgPool with 1x1 OFM is correctly lowered to AvgPool.
7831TEST_F(GraphOptz, lower1x1AdaptiveAvgPoolToAvgPool) {
7832 Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 3, 4},
7833 "input", /* isTrainable */ false);
7834 bindings_.allocate(input)->getHandle<float>().randomize(-10, 10,
7835 mod_.getPRNG());
7836
7837 auto outTy = mod_.uniqueType(ElemKind::FloatTy, {2, 1, 1, 4});
7838 auto *pool = F_->createAdaptiveAvgPool("avg", input, outTy);
7839 SaveNode *save = F_->createSave("save", pool);
7840 bindings_.allocate(save->getPlaceholder());
7841
7842 // Backup function in optimizedF_.
7843 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
7844
7845 // Lower
7846 EXPECT_TRUE(glow::lowerNode(F_, pool, cctx_));
7847 runDCEPass(F_, cctx_);
7848 EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::AdaptiveAvgPoolNodeKind));
7849 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::AvgPoolNodeKind));
7850
7851 // Now compile/run/compare F_ and optimizedF_.
7852 checkNumericalEquivalence(1e-6);
7853}
7854
/// Skip the Clip-Quantize optimization when loadUniquedDummyQParams is set.
7856TEST_F(GraphOptz, SkipDummyQParamOpts) {
7857 Placeholder *A = mod_.createPlaceholder(ElemKind::FloatTy, {5}, "A", false);
7858 ClipNode *CN = F_->createClip("clip", A, -1000.f, 1000.f);
7859 QuantizeNode *QN = F_->createQuantize(
7860 "quantize", CN, mod_.uniqueType(ElemKind::Int8QTy, {5}, 0.3, 5));
7861 F_->createSave("ret", QN);
7862
7863 CompilationContext cctx;
7864 cctx.precisionConfig.loadUniquedDummyQParams = true;
7865
7866 optimizedF_ = optimizeFunctionForTest(F_, {}, cctx);
7867 EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false, /* skipName */ true),
7868 optimizedF_->toString(/* skipUsersForStorage */ false,
7869 /* skipName */ true));
7870}
7871
/// Test that Min -> Max is correctly folded into Clip.
7873TEST_F(GraphOptz, foldMinMaxToClipTest) {
7874 Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 5},
7875 "input", /* isTrainable */ false);
7876 bindings_.allocate(input)->getHandle<float>().randomize(-10, 10,
7877 mod_.getPRNG());
7878
7879 auto *minFirstSplat = F_->createSplat("min_first_splat", input->getType(), 5);
7880 auto *maxFirstSplat =
7881 F_->createSplat("max_first_splat", input->getType(), -2);
7882 auto *minFirst = F_->createMin("min_first", input, minFirstSplat);
7883 auto *maxFirst = F_->createMax("max_first", maxFirstSplat, minFirst);
7884
7885 auto *minSecondSplat = F_->createSplat(
7886 "min_second_splat",
7887 F_->getParent()->uniqueTypeWithNewShape(input->getType(), {3, 1, 1}), 3);
7888 auto *maxSecondSplat =
7889 F_->createSplat("max_second_splat", input->getType(), 1);
7890 auto *maxSecond = F_->createMax("max_second", maxFirst, maxSecondSplat);
7891 auto *minSecond = F_->createNodeWithBroadcast<MinNode>(
7892 "min_second", /* axis */ -1, maxSecond, minSecondSplat);
7893 SaveNode *save = F_->createSave("save", minSecond);
7894 bindings_.allocate(save->getPlaceholder());
7895
  // Need to run OptimizeArithmeticNodes first to move constant operands of
  // commutative nodes to the RHS.
7898 optimizedF_ = optimizeFunctionForTest(
7899 F_, {FunctionPassID::OptimizeArithmeticNodes,
7900 FunctionPassID::FoldMinMaxToClip, getDCEPassConfig()});
7901
7902 EXPECT_EQ(4, optimizedF_->getNodes().size());
7903 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MinNodeKind));
7904 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MaxNodeKind));
7905
  // Get the SaveNode in optimizedF_.
  save = llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName("save_save"));
  ASSERT_TRUE(save);
  // Check min and max of the second ClipNode.
  ClipNode *CN = llvm::dyn_cast<ClipNode>(save->getInput().getNode());
  ASSERT_TRUE(CN);
  EXPECT_EQ(1, CN->getMin());
  EXPECT_EQ(3, CN->getMax());

  // There's a BroadcastNode in between the first and the second ClipNode.
  BroadcastNode *BN = llvm::dyn_cast<BroadcastNode>(CN->getInput().getNode());
  ASSERT_TRUE(BN);
  // Check min and max of the first ClipNode.
  CN = llvm::dyn_cast<ClipNode>(BN->getInput().getNode());
  ASSERT_TRUE(CN);
7917 EXPECT_EQ(-2, CN->getMin());
7918 EXPECT_EQ(5, CN->getMax());
7919
7920 checkNumericalEquivalence();
7921}
7922
/// Test that the Min -> Max fold pass does not break with a reshape LHS
/// input.
7924TEST_F(GraphOptz, foldMinMaxToClipReshapeNoBroadcastTest) {
7925 Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 100},
7926 "input", /* isTrainable */ false);
7927 bindings_.allocate(input)->getHandle<float>().randomize(-10, 10,
7928 mod_.getPRNG());
7929
7930 auto *reshape = F_->createReshape("reshape", input, {100, 1});
7931 const TypeRef T = reshape->getResult().getType();
7932
7933 auto *maxSplat = F_->createSplat("max_splat", T, -2);
7934 auto *minSplat = F_->createSplat("min_splat", T, 5);
7935 auto *max = F_->createMax("max", reshape, maxSplat);
7936 auto *min = F_->createMin("min", max, minSplat);
7937 SaveNode *save = F_->createSave("save", min);
7938 bindings_.allocate(save->getPlaceholder());
7939
7940 optimizedF_ = optimizeFunctionForTest(
7941 F_, {FunctionPassID::FoldMinMaxToClip, getDCEPassConfig()});
7942
7943 EXPECT_EQ(3, optimizedF_->getNodes().size());
7944 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MinNodeKind));
7945 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MaxNodeKind));
7946
7947 save = llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName("save_save"));
7948 ASSERT_TRUE(save);
7949 auto *CN = llvm::dyn_cast<ClipNode>(save->getInput().getNode());
7950 ASSERT_TRUE(CN);
7951 EXPECT_EQ(-2, CN->getMin());
7952 EXPECT_EQ(5, CN->getMax());
7953 auto *RN = llvm::dyn_cast<ReshapeNode>(CN->getInput());
7954 ASSERT_TRUE(RN);
7955 EXPECT_TRUE(RN->getResult().getType()->isEqual(T));
7956 EXPECT_EQ(RN->getInput(), input->getOutput());
7957
7958 checkNumericalEquivalence();
7959}
7960
/// Check that we correctly replace a Node whose scale underflows to 0.f in
/// fp16 with a Splat.
7962TEST_F(GraphOptz, ReplaceZeroScaleFP16QuantOpt) {
7963 auto *LHS = mod_.createPlaceholder(ElemKind::FloatTy, {20, 30}, "LHS", false);
  auto *RHSQ = mod_.createPlaceholder(ElemKind::Int8QTy, {20, 30}, 0.1f, 10,
                                      "RHS", false);
7966
7967 // scale = 1e-9 underflows fp16 and so this opt applies.
7968 auto *LHSQTy = mod_.uniqueType(ElemKind::Int8QTy, {20, 30}, 1e-9, 10);
7969 auto *LHSQ = F_->createQuantize("LHSQ", LHS, LHSQTy);
7970
7971 auto *A = F_->createAdd("add", RHSQ->getOutput().getType(), LHSQ, RHSQ);
7972 auto *Q = F_->createDequantize("deq", A, ElemKind::FloatTy);
7973 F_->createSave("save", Q);
7974
7975 optimizedF_ = optimizeFunctionForTest(
7976 F_, {FunctionPassID::ReplaceZeroScaleFP16QuantNodes, getDCEPassConfig()});
7977
7978 SaveNode *save = nullptr;
7979 for (auto &N : optimizedF_->getNodes()) {
7980 if (N.getKind() == Kinded::Kind::SaveNodeKind) {
7981 save = llvm::dyn_cast<SaveNode>(&N);
7982 break;
7983 }
7984 }
7985 ASSERT_TRUE(save);
7986
7987 DequantizeNode *optQ = llvm::dyn_cast<DequantizeNode>(save->getInput());
7988 ASSERT_TRUE(optQ);
7989 AddNode *optA = llvm::dyn_cast<AddNode>(optQ->getInput());
  ASSERT_TRUE(optA);
7991
7992 SplatNode *splat = llvm::dyn_cast<SplatNode>(optA->getLHS());
7993 ASSERT_TRUE(splat);
7994 EXPECT_EQ(splat->getValue(), 0.f);
7995 const TypeRef optLHSQTy = splat->getResult().getType();
7996 EXPECT_EQ(optLHSQTy->getScale(), 1.f);
7997 EXPECT_EQ(optLHSQTy->getOffset(), 0);
7998 EXPECT_EQ(optLHSQTy->getElementType(), LHSQTy->getElementType());
7999 EXPECT_EQ(optLHSQTy->dims(), LHSQTy->dims());
8000
8001 bindings_.allocate(LHS)->getHandle<float>().randomize(-10.f, 10.f,
8002 mod_.getPRNG());
8003 bindings_.allocate(RHSQ)->getHandle<int8_t>().randomize(-128, 127,
8004 mod_.getPRNG());
8005
8006 checkNumericalEquivalence(0.f);
8007}
8008
/// Same as GraphOptz, but when checking numerical equivalence use the CPU
/// backend instead of the Interpreter.
8011class GraphOptzOnCPU : public GraphOptz {
8012public:
8013 GraphOptzOnCPU() : GraphOptz("CPU") {}
8014#ifndef GLOW_WITH_CPU
8015 virtual void checkNumericalEquivalence(float allowedError = 0.0001) override {
8016 LOG(INFO) << "Skipping numerical equivalence check as the CPU backend is "
8017 "not built.";
8018 }
8019#endif /* GLOW_WITH_CPU */
8020};
8021
/// Check that we correctly replace a Constant whose scale underflows to 0.f in
/// fp16 with a Splat.
8023TEST_F(GraphOptzOnCPU, ReplaceZeroScaleFP16QuantConstOpt) {
8024 auto *input =
8025 mod_.createPlaceholder(ElemKind::Int8QTy, {1, 1}, 1.0, 0, "input", false);
8026 // scale = 1e-9 underflows fp16 and so this opt applies.
8027 auto *weights =
8028 mod_.createConstant(ElemKind::Int8QTy, {1, 1}, 1e-9, 0, "weights");
8029 weights->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127,
8030 mod_.getPRNG());
8031 auto *MM = F_->createMatMul("matmul", input, weights);
8032 auto *DQ = F_->createDequantize("dq", MM, ElemKind::FloatTy);
8033 F_->createSave("save", DQ);
8034
8035 optimizedF_ = optimizeFunctionForTest(
8036 F_, {FunctionPassID::ReplaceZeroScaleFP16QuantNodes, getDCEPassConfig()});
8037
8038 SaveNode *save = nullptr;
8039 for (auto &N : optimizedF_->getNodes()) {
8040 if (N.getKind() == Kinded::Kind::SaveNodeKind) {
8041 save = llvm::dyn_cast<SaveNode>(&N);
8042 break;
8043 }
8044 }
8045 ASSERT_TRUE(save);
8046
8047 auto *optDQ = llvm::dyn_cast<DequantizeNode>(save->getInput());
8048 ASSERT_TRUE(optDQ);
8049 auto *optMM = llvm::dyn_cast<MatMulNode>(optDQ->getInput());
8050 ASSERT_TRUE(optMM);
8051
8052 SplatNode *splat = llvm::dyn_cast<SplatNode>(optMM->getRHS());
8053 ASSERT_TRUE(splat);
8054 EXPECT_EQ(splat->getValue(), 0.f);
8055 const TypeRef splatQTy = splat->getResult().getType();
8056 EXPECT_EQ(splatQTy->getScale(), 1.f);
8057 EXPECT_EQ(splatQTy->getOffset(), 0);
8058 EXPECT_EQ(splatQTy->getElementType(), weights->getOutput().getElementType());
8059 EXPECT_EQ(splatQTy->dims(), weights->getOutput().dims());
8060
8061 bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127,
8062 mod_.getPRNG());
8063 checkNumericalEquivalence(0.f);
8064}
8065
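/// Test that Clips whose ranges cover the entire FP16 range are eliminated by
/// EliminateClipsOutsideFP16Range, while Clips with tighter ranges are kept.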
8066TEST_F(GraphOptz, TestEliminateClipsOutsideFP16Range) {
8067 Placeholder *A = mod_.createPlaceholder(ElemKind::Float16Ty, {5}, "A", false);
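  // clip1 covers the entire FP16 range and is expected to be eliminated;
  // clip2 has a tighter max and is expected to be kept.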
8068 ClipNode *CN1 = F_->createClipMinMaxFP16("clip1", A);
8069 ClipNode *CN2 = F_->createClip("clip2", A, kMinFP16, kMaxFP16 - 1.f);
8070 QuantizeNode *QN1 = F_->createQuantize(
8071 "q1", CN1, mod_.uniqueType(ElemKind::Int8QTy, {5}, 0.3, 5));
8072 QuantizeNode *QN2 = F_->createQuantize(
8073 "q2", CN2, mod_.uniqueType(ElemKind::Int8QTy, {5}, 0.3, 5));
8074 AddNode *AN = F_->createAdd("add", QN1, QN2);
8075 DequantizeNode *DN = F_->createDequantize("dq", AN, ElemKind::Float16Ty);
8076 ClipNode *CN3 = F_->createClipMinMaxFP16("clip3", DN);
8077 F_->createSave("ret", CN3);
8078
8079 CompilationContext cctx;
8080 cctx.precisionConfig.clipQuantRangeToFP16 = true;
8081
8082 optimizedF_ = optimizeFunctionForTest(
8083 F_, {FunctionPassID::EliminateClipsOutsideFP16Range, getDCEPassConfig()},
8084 cctx);
8085
8086 SaveNode *save = nullptr;
8087 for (auto &N : optimizedF_->getNodes()) {
8088 if (N.getKind() == Kinded::Kind::SaveNodeKind) {
8089 save = llvm::dyn_cast<SaveNode>(&N);
8090 break;
8091 }
8092 }
8093 ASSERT_TRUE(save);
8094
8095 auto *optDQ = llvm::dyn_cast<DequantizeNode>(save->getInput());
8096 ASSERT_TRUE(optDQ);
8097 auto *optAN = llvm::dyn_cast<AddNode>(optDQ->getInput());
8098 ASSERT_TRUE(optAN);
8099
8100 auto *optQN1 = llvm::dyn_cast<QuantizeNode>(optAN->getLHS());
8101 ASSERT_TRUE(optQN1);
8102 EXPECT_EQ(optQN1->getInput(), A->getOutput());
8103
8104 auto *optQN2 = llvm::dyn_cast<QuantizeNode>(optAN->getRHS());
8105 ASSERT_TRUE(optQN2);
8106 auto *optCN2 = llvm::dyn_cast<ClipNode>(optQN2->getInput());
8107 ASSERT_TRUE(optCN2);
8108 EXPECT_EQ(optCN2->getMin(), CN2->getMin());
8109 EXPECT_EQ(optCN2->getMax(), CN2->getMax());
8110 EXPECT_EQ(optCN2->getInput(), A->getOutput());
8111
8112 bindings_.allocate(A)->getHandle<float16_t>().randomize(-128, 127,
8113 mod_.getPRNG());
8114 checkNumericalEquivalence(0.f);
8115}
8116
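/// Test that updateQuantReluTypes updates a quantized Relu's output type so
/// that its range starts at zero while keeping the max of its input's range.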
8117TEST_F(GraphOptz, TestUpdateQuantReluTypes) {
8118 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 0.11, -1,
8119 "input", false);
8120 auto *weights = mod_.createPlaceholder(ElemKind::Int8QTy, {32, 32}, 0.2, 3,
8121 "weights", false);
8122 auto *bias =
8123 mod_.createPlaceholder(ElemKind::Int32QTy, {32}, 0.01, 2, "bias", false);
8124 auto *addW = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 0.3, -4,
8125 "addw", false);
8126
8127 auto *fc = F_->createFullyConnected("fc", input, weights, bias);
8128 auto *qRelu = F_->createRELU("relu", fc->getResult());
8129 auto *qAdd = F_->createAdd("add", qRelu, addW);
8130 F_->createSave("save", qAdd);
8131
8132 updateQuantReluTypes(F_);
8133
8134 const auto fcRange = fc->getResult().getType()->getQuantizedValueRange();
8135 const auto reluRange = qRelu->getResult().getType()->getQuantizedValueRange();
8136 EXPECT_NE(reluRange.first, fcRange.first);
8137 EXPECT_EQ(reluRange.first, 0);
8138 EXPECT_EQ(reluRange.second, fcRange.second);
8139}
8140
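/// Test that updateQuantReluTypes also propagates the updated Relu type
/// through the chain of shape users (Concat, Reshape) following the Relu.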
8141TEST_F(GraphOptz, TestUpdateQuantReluTypesChained) {
8142 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 0.11, -1,
8143 "input", false);
8144 auto *weights = mod_.createPlaceholder(ElemKind::Int8QTy, {32, 32}, 0.2, 3,
8145 "weights", false);
8146 auto *bias =
8147 mod_.createPlaceholder(ElemKind::Int32QTy, {32}, 0.01, 2, "bias", false);
8148 auto *addW =
8149 mod_.createPlaceholder(ElemKind::Int8QTy, {128}, 0.3, -4, "addw", false);
8150
8151 auto *fc = F_->createFullyConnected("fc", input, weights, bias);
8152 auto *qRelu = F_->createRELU("relu", fc->getResult());
8153 auto *qConcat = F_->createConcat("concat", {qRelu, qRelu}, 0);
8154 auto *qReshape = F_->createReshape("reshape", qConcat, {128});
8155 auto *qAdd = F_->createAdd("add", qReshape, addW);
8156 F_->createSave("save", qAdd);
8157
8158 updateQuantReluTypes(F_);
8159
8160 const auto fcRange = fc->getResult().getType()->getQuantizedValueRange();
8161 const auto reluRange = qRelu->getResult().getType()->getQuantizedValueRange();
8162 EXPECT_NE(reluRange.first, fcRange.first);
8163 EXPECT_EQ(reluRange.first, 0);
8164 EXPECT_EQ(reluRange.second, fcRange.second);
8165
8166 // Check that the relu's type now also matches that of the chain of shape
8167 // users after it.
8168 const TypeRef qReluTy = qRelu->getResult().getType();
8169 EXPECT_EQ(qReluTy->getScale(), qConcat->getResult().getType()->getScale());
8170 EXPECT_EQ(qReluTy->getOffset(), qConcat->getResult().getType()->getOffset());
8171 EXPECT_EQ(qReluTy->getScale(), qReshape->getResult().getType()->getScale());
8172 EXPECT_EQ(qReluTy->getOffset(), qReshape->getResult().getType()->getOffset());
8173}
8174
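/// Test that a Reshape feeding a Quantize is sunk below it, so the Quantize
/// applies directly to the input and the Reshape produces the quantized type.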
8175TEST_F(GraphOptz, SinkReshapeBelowQuantize) {
8176 auto *I = mod_.createPlaceholder(ElemKind::FloatTy, {32, 64}, "A", false);
8177 auto *RN = F_->createReshape("reshape", I, {32, 64, 1});
8178 auto *QN = F_->createQuantize("quantize", RN, ElemKind::Int8QTy, 0.2f, 1);
8179 auto *SN = F_->createSave("ret", QN);
8180
8181 optimizedF_ = optimizeFunctionForTest(
8182 F_, {FunctionPassID::SinkCode, getDCEPassConfig()});
8183
8184 auto *optSN =
8185 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
8186 ASSERT_TRUE(optSN);
8187 auto *optRN = llvm::dyn_cast<ReshapeNode>(optSN->getInput());
8188 ASSERT_TRUE(optRN);
8189 EXPECT_EQ(optRN->getResult().getElementType(), ElemKind::Int8QTy);
8190 EXPECT_EQ(optRN->getResult().getScale(), 0.2f);
8191 EXPECT_EQ(optRN->getResult().getOffset(), 1);
8192 EXPECT_EQ(optRN->getResult().dims(), RN->getResult().dims());
8193 auto *optQN = llvm::dyn_cast<QuantizeNode>(optRN->getInput());
8194 ASSERT_TRUE(optQN);
8195 EXPECT_EQ(optQN->getResult().getElementType(), ElemKind::Int8QTy);
8196 EXPECT_EQ(optQN->getResult().getScale(), 0.2f);
8197 EXPECT_EQ(optQN->getResult().getOffset(), 1);
8198 EXPECT_EQ(optQN->getInput().getNode(), I);
8199
8200 bindings_.allocate(I)->getHandle<float>().randomize(-30, 30, mod_.getPRNG());
8201 checkNumericalEquivalence(0.f);
8202}
8203
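/// Test that a Reshape feeding a ConvertTo is sunk below it, so the ConvertTo
/// applies directly to the input and the Reshape produces the converted type.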
8204TEST_F(GraphOptz, SinkReshapeBelowConvertTo) {
8205 auto *I = mod_.createPlaceholder(ElemKind::FloatTy, {32, 64}, "A", false);
8206 auto *RN = F_->createReshape("reshape", I, {32, 64, 1});
8207 auto *CN = F_->createConvertTo("convert", RN, ElemKind::Float16Ty);
8208 auto *SN = F_->createSave("ret", CN);
8209
8210 optimizedF_ = optimizeFunctionForTest(
8211 F_, {FunctionPassID::SinkCode, getDCEPassConfig()});
8212
8213 auto *optSN =
8214 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
8215 ASSERT_TRUE(optSN);
8216 auto *optRN = llvm::dyn_cast<ReshapeNode>(optSN->getInput());
8217 ASSERT_TRUE(optRN);
8218 EXPECT_EQ(optRN->getResult().getElementType(), ElemKind::Float16Ty);
8219 EXPECT_EQ(optRN->getResult().dims(), RN->getResult().dims());
8220 auto *optCN = llvm::dyn_cast<ConvertToNode>(optRN->getInput());
8221 ASSERT_TRUE(optCN);
8222 EXPECT_EQ(optCN->getResult().getElementType(), ElemKind::Float16Ty);
8223 EXPECT_EQ(optCN->getInput().getNode(), I);
8224
8225 bindings_.allocate(I)->getHandle<float>().randomize(-30, 30, mod_.getPRNG());
8226 checkNumericalEquivalence(0.f);
8227}
8228
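/// Test that a Reshape is sunk below a chain of unary eltwise ops (Abs, Sin,
/// Clip, Tanh), so the ops work on the original shape and a single Reshape
/// remains right before the Save.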
8229TEST_F(GraphOptz, SinkReshapeBelowUnaryEltwiseOps) {
8230 const dim_t dimsIn[] = {10, 10};
8231 const dim_t dimsOut[] = {5, 5, 4};
8232
8233 auto *in = mod_.createPlaceholder(ElemKind::FloatTy, dimsIn, "in", false);
8234 auto *RN = F_->createReshape("reshape", in, dimsOut);
8235 auto *AN = F_->createAbs("abs", RN);
8236 auto *SN = F_->createSin("sin", AN);
8237 auto *CN = F_->createClip("clip", SN, -4.f, 5.f);
8238 auto *TN = F_->createTanh("tanh", CN);
8239 auto *save = F_->createSave("ret", TN);
8240
8241 optimizedF_ = optimizeFunctionForTest(F_);
8242
8243 auto *optSave =
8244 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName()));
8245 ASSERT_TRUE(optSave);
8246 auto *optRN = llvm::dyn_cast<ReshapeNode>(optSave->getInput());
8247 ASSERT_TRUE(optRN);
8248 EXPECT_EQ(optRN->getResult().dims(), llvm::makeArrayRef(dimsOut));
8249 auto *optTN = llvm::dyn_cast<TanhNode>(optRN->getInput());
8250 ASSERT_TRUE(optTN);
8251 EXPECT_EQ(optTN->getResult().dims(), llvm::makeArrayRef(dimsIn));
8252 auto *optCN = llvm::dyn_cast<ClipNode>(optTN->getInput());
8253 ASSERT_TRUE(optCN);
8254 EXPECT_FLOAT_EQ(optCN->getMin(), CN->getMin());
8255 EXPECT_FLOAT_EQ(optCN->getMax(), CN->getMax());
8256 EXPECT_EQ(optCN->getResult().dims(), llvm::makeArrayRef(dimsIn));
8257 auto *optSN = llvm::dyn_cast<SinNode>(optCN->getInput());
8258 ASSERT_TRUE(optSN);
8259 EXPECT_EQ(optSN->getResult().dims(), llvm::makeArrayRef(dimsIn));
8260 auto *optAN = llvm::dyn_cast<AbsNode>(optSN->getInput());
8261 ASSERT_TRUE(optAN);
8262 EXPECT_EQ(optAN->getResult().dims(), llvm::makeArrayRef(dimsIn));
8263
8264 bindings_.allocate(in)->getHandle<float>().randomize(-30.f, 30.f,
8265 mod_.getPRNG());
8266 checkNumericalEquivalence(0.f);
8267}
8268
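/// Test that a Dequantize to Float16 followed by a ConvertTo Float is
/// optimized into a single Dequantize directly to Float.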
8269TEST_F(GraphOptz, OptConvertToDequantize) {
8270 auto *I =
8271 mod_.createPlaceholder(ElemKind::Int8QTy, {32, 64}, 0.2f, 1, "A", false);
8272 auto *DN = F_->createDequantize("deq", I, ElemKind::Float16Ty);
8273 auto *CN = F_->createConvertTo("convert", DN, ElemKind::FloatTy);
8274 auto *SN = F_->createSave("ret", CN);
8275
8276 optimizedF_ = optimizeFunctionForTest(
8277 F_,
8278 {FunctionPassID::OptimizeOutIntermediateConversions, getDCEPassConfig()});
8279
8280 auto *optSN =
8281 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
8282 ASSERT_TRUE(optSN);
8283 auto *optDN = llvm::dyn_cast<DequantizeNode>(optSN->getInput());
8284 ASSERT_TRUE(optDN);
8285 EXPECT_EQ(optDN->getResult().getElementType(), ElemKind::FloatTy);
8286 EXPECT_EQ(optDN->getResult().dims(), DN->getResult().dims());
8287 EXPECT_EQ(optDN->getInput().getNode(), I);
8288
8289 bindings_.allocate(I)->getHandle<int8_t>().randomize(-128, 127,
8290 mod_.getPRNG());
8291 checkNumericalEquivalence(0.007f);
8292}
8293
8294/// Test that Exp+ReduceSum+Div is replaced with SoftMax.
8295TEST_F(GraphOptz, FoldExpSumDivIntoSoftmax) {
8296 Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 10},
8297 "input", /* isTrainable */ false);
8298 bindings_.allocate(input)->getHandle<float>().randomize(-10, 10,
8299 mod_.getPRNG());
8300 auto *exp = F_->createExp("exp", input);
8301 auto *reduceSum = F_->createBatchedReduceAdd("reduce_sum", exp, {1});
8302 auto *div = F_->createNodeWithBroadcast<DivNode>("div", 1, exp, reduceSum);
8303 F_->createSave("save", div);
8304
8305 EXPECT_EQ(5, F_->getNodes().size());
8306
8307 optimizedF_ = optimizeFunctionForTest(
8308 F_, {FunctionPassID::FoldExpSumDivIntoSoftmax, getDCEPassConfig()});
8309
8310 EXPECT_EQ(2, optimizedF_->getNodes().size());
8311
8312 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::ExpNodeKind));
8313 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::DivNodeKind));
8314 EXPECT_EQ(0,
8315 countNodeKind(optimizedF_, Kinded::Kind::BatchedReduceAddNodeKind));
8316 EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SoftMaxNodeKind));
8317
8318 checkNumericalEquivalence(1e-7f);
8319}
8320
8321/// Test that identity Relu is removed.
8322TEST_F(GraphOptz, RemoveIdentityRelu) {
8323
8324 Placeholder *input = mod_.createPlaceholder(
8325 ElemKind::Int8QTy, {20}, 0.123f, -128, "input", /* isTrainable */ false);
8326 bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127,
8327 mod_.getPRNG());
  auto *relu = F_->createRELU("relu", input);
8329 F_->createSave("save", relu);
8330
8331 EXPECT_EQ(2, F_->getNodes().size());
8332 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ReluNodeKind));
8333 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::SaveNodeKind));
8334
8335 optimizedF_ = optimizeFunctionForTest(
8336 F_, {FunctionPassID::RemoveIdentityRelu, getDCEPassConfig()});
8337
8338 EXPECT_EQ(1, optimizedF_->getNodes().size());
8339 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::ReluNodeKind));
8340 EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind));
8341
  checkNumericalEquivalence(0.f);
8343}
8344
8345/// Test that identity Clip is removed.
8346TEST_F(GraphOptz, RemoveIdentityClip) {
8347
8348 Placeholder *input =
8349 mod_.createPlaceholder(ElemKind::Int8QTy, {20}, 0.023529412f, -128,
8350 "input", /* isTrainable */ false);
8351 bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127,
8352 mod_.getPRNG());
  auto *clip = F_->createClip("clip", input, 0.0f, 6.0f);
8354 F_->createSave("save", clip);
8355
8356 EXPECT_EQ(2, F_->getNodes().size());
8357 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ClipNodeKind));
8358 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::SaveNodeKind));
8359
8360 optimizedF_ = optimizeFunctionForTest(
8361 F_, {FunctionPassID::RemoveIdentityClip, getDCEPassConfig()});
8362
8363 EXPECT_EQ(1, optimizedF_->getNodes().size());
8364 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::ClipNodeKind));
8365 EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind));
8366
  checkNumericalEquivalence(0.f);
8368}
8369
8370/// Test that an identity ResizeNearest is removed.
8371TEST_F(GraphOptz, OptimizeIdentityResizeNearest) {
8372 Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 33, 33, 1},
8373 "input", /* isTrainable */ false);
8374 bindings_.allocate(input)->getHandle<float>().randomize(-10, 10,
8375 mod_.getPRNG());
8376 auto *resize = F_->createResizeNearest("resize", input, {1, 1, 1, 1});
8377 F_->createSave("save", resize);
8378 EXPECT_EQ(2, F_->getNodes().size());
8379 optimizedF_ = optimizeFunctionForTest(
8380 F_, {FunctionPassID::OptimizeResize, getDCEPassConfig()});
8381 EXPECT_EQ(1, optimizedF_->getNodes().size());
8382 EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind));
8383 checkNumericalEquivalence(1e-7f);
8384}
8385
/// Test that the integer scales of a ResizeNearest are converted into a Tile,
/// leaving a ResizeNearest for the remaining non-integer scale.
8387TEST_F(GraphOptz, OptimizeResizeNearest) {
8388 Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 33, 1},
8389 "input", /* isTrainable */ false);
8390 bindings_.allocate(input)->getHandle<float>().randomize(-10, 10,
8391 mod_.getPRNG());
8392 auto *resize = F_->createResizeNearest("resize", input, {1, 2, 7.787879, 1});
8393 F_->createSave("save", resize);
8394 EXPECT_EQ(2, F_->getNodes().size());
8395 optimizedF_ = optimizeFunctionForTest(
8396 F_, {FunctionPassID::OptimizeResize, getDCEPassConfig()});
8397 EXPECT_EQ(3, optimizedF_->getNodes().size());
8398 EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::TileNodeKind));
8399 EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::ResizeNearestNodeKind));
8400 EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind));
8401 checkNumericalEquivalence(1e-7f);
8402}
8403
8404/// Test that an identity ResizeBilinear is removed.
8405TEST_F(GraphOptz, OptimizeIdentityResizeBilinear) {
8406 Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 33, 33, 1},
8407 "input", /* isTrainable */ false);
8408 bindings_.allocate(input)->getHandle<float>().randomize(-10, 10,
8409 mod_.getPRNG());
8410 auto *resize = F_->createResizeBilinear("resize", input, {1, 1, 1, 1});
8411 F_->createSave("save", resize);
8412 EXPECT_EQ(2, F_->getNodes().size());
8413 optimizedF_ = optimizeFunctionForTest(
8414 F_, {FunctionPassID::OptimizeResize, getDCEPassConfig()});
8415 EXPECT_EQ(1, optimizedF_->getNodes().size());
8416 EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind));
8417 checkNumericalEquivalence(1e-7f);
8418}
8419
/// Test that the integer scales of a ResizeBilinear are converted into a
/// Tile, leaving a ResizeBilinear for the remaining non-integer scale.
8421TEST_F(GraphOptz, OptimizeResizeBilinear) {
8422 Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 33, 1},
8423 "input", /* isTrainable */ false);
8424 bindings_.allocate(input)->getHandle<float>().randomize(-10, 10,
8425 mod_.getPRNG());
8426 auto *resize = F_->createResizeBilinear("resize", input, {1, 2, 7.787879, 1});
8427 F_->createSave("save", resize);
8428 EXPECT_EQ(2, F_->getNodes().size());
8429 optimizedF_ = optimizeFunctionForTest(
8430 F_, {FunctionPassID::OptimizeResize, getDCEPassConfig()});
8431 EXPECT_EQ(3, optimizedF_->getNodes().size());
8432 EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::TileNodeKind));
8433 EXPECT_EQ(1,
8434 countNodeKind(optimizedF_, Kinded::Kind::ResizeBilinearNodeKind));
8435 EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind));
8436 checkNumericalEquivalence(1e-7f);
8437}
8438
/// Test that the Splat Big operand of an InsertTensor is replaced with a
/// Touch node when the Small operand overwrites it entirely.
8441TEST_F(GraphOptz, OptimizeInsertTensorBigSplat) {
8442 Type bigTy(ElemKind::FloatTy, {10});
8443 SplatNode *big = F_->createSplat("splat", &bigTy, 0);
8444 Placeholder *small = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input",
8445 /* isTrainable */ false);
8446 bindings_.allocate(small)->getHandle<float>().randomize(-10, 10,
8447 mod_.getPRNG());
8448 auto *insert = F_->createInsertTensor("insert", big, small,
8449 /* start */ {0},
8450 /* count */ 10,
8451 /* axis */ 0);
8452 F_->createSave("save", insert);
8453 EXPECT_EQ(3, F_->getNodes().size());
8454 optimizedF_ = optimizeFunctionForTest(
8455 F_, {FunctionPassID::OptimizeInsert, getDCEPassConfig()});
8456 EXPECT_EQ(3, optimizedF_->getNodes().size());
8457 EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::TouchNodeKind));
8458 EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::InsertTensorNodeKind));
8459 EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind));
8460 checkNumericalEquivalence(1e-7f);
8461}
8462
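/// Test that a Transpose with multiple Quantize users is sunk below each of
/// them.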
8463TEST_F(GraphOptz, sinkQuantizeTransposeMultiUser) {
8464 auto *input =
8465 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "input",
8466 /* isTrainable */ false);
8467 auto *T = F_->createTranspose("transpose", input, NHWC2NCHW);
8468 auto *Q1 = F_->createQuantize("q1", T, ElemKind::Int8QTy, 0.11, 1);
8469 auto *Q2 = F_->createQuantize("q2", T, ElemKind::Int8QTy, 0.12, 2);
8470 auto *S1 = F_->createSave("save1", Q1);
8471 auto *S2 = F_->createSave("save2", Q2);
8472
8473 optimizedF_ = optimizeFunctionForTest(F_);
8474
  auto *optS1 = findFunctionNodeByName<SaveNode>(optimizedF_, S1->getName());
  auto *optS2 = findFunctionNodeByName<SaveNode>(optimizedF_, S2->getName());
  ASSERT_TRUE(optS1);
  ASSERT_TRUE(optS2);
8477
8478 // Check that transpose has been sunk below quantize now for both.
8479 EXPECT_TRUE(llvm::isa<TransposeNode>(optS1->getInput()));
8480 EXPECT_TRUE(llvm::isa<TransposeNode>(optS2->getInput()));
8481
8482 bindings_.allocate(input)->getHandle<float>().randomize(-10, 10,
8483 mod_.getPRNG());
8484 checkNumericalEquivalence(0.f);
8485}
8486
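/// Test that a Transpose is not sunk below its Quantize user when it also has
/// a non-Quantize user, leaving the graph unchanged.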
8487TEST_F(GraphOptz, skipSinkQuantizeTransposeMultiUser) {
8488 auto *input =
8489 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "input",
8490 /* isTrainable */ false);
8491 auto *T = F_->createTranspose("transpose", input, NHWC2NCHW);
8492 auto *Q = F_->createQuantize("quant", T, ElemKind::Int8QTy, 0.11, 1);
8493 F_->createSave("save1", Q);
8494 F_->createSave("save2", T);
8495
8496 optimizedF_ = optimizeFunctionForTest(F_);
8497
8498 // Verify the graph hasn't changed.
8499 EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false, /* skipName */ true),
8500 optimizedF_->toString(/* skipUsersForStorage */ false,
8501 /* skipName */ true));
8502}
8503
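/// Test that MatMuls sharing the same RHS are merged, while a MatMul whose
/// LHS is produced by a long chain of Sigmoids is skipped.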
8504TEST_F(GraphOptz, MergeMatMulsOnLHSWhenSkippingOne) {
8505 Placeholder *LHS1 =
8506 mod_.createPlaceholder(ElemKind::FloatTy, {10, 10}, "LHS1", false);
8507 Placeholder *LHS2 =
8508 mod_.createPlaceholder(ElemKind::FloatTy, {30, 10}, "LHS2", false);
8509 Placeholder *LHS3 =
8510 mod_.createPlaceholder(ElemKind::FloatTy, {20, 10}, "LHS3", false);
8511 Placeholder *RHS =
8512 mod_.createPlaceholder(ElemKind::FloatTy, {10, 15}, "RHS", false);
8513 bindings_.allocate(LHS1)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG());
8514 bindings_.allocate(LHS2)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG());
8515 bindings_.allocate(LHS3)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG());
8516 bindings_.allocate(RHS)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG());
8517
  // Chain several Sigmoid nodes onto LHS2 so that dependency analysis
  // disallows merging MM2 below.
8520 Node *sigLHS2 = LHS2;
8521 for (size_t i = 0, e = 7; i < e; i++) {
8522 sigLHS2 = F_->createSigmoid("s_lhs2", sigLHS2);
8523 }
8524
8525 Node *MM1 = F_->createMatMul("mm1", LHS1, RHS);
8526 Node *MM2 = F_->createMatMul("mm2", sigLHS2, RHS);
8527 Node *MM3 = F_->createMatMul("mm3", LHS3, RHS);
8528
8529 F_->createSave("save1", MM1);
8530 F_->createSave("save2", MM2);
8531 F_->createSave("save3", MM3);
8532 ASSERT_TRUE(F_->verify());
8533
8534 optimizedF_ = optimizeFunctionForTest(
8535 F_, {FunctionPassID::MergeMatMulOnLHS, getDCEPassConfig()});
8536 ASSERT_TRUE(optimizedF_->verify());
8537
8538 // Expect three matmuls -> two matmuls, because mm1 and mm3 were merged.
8539 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 3);
8540 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::MatMulNodeKind), 2);
8541
8542 checkNumericalEquivalence(0.f);
8543}
8544
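/// Test that MatMuls sharing the same LHS are merged, while a MatMul whose
/// RHS is produced by a long chain of Sigmoids is skipped.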
8545TEST_F(GraphOptz, MergeMatMulsOnRHSWhenSkippingOne) {
8546 Placeholder *LHS =
8547 mod_.createPlaceholder(ElemKind::FloatTy, {40, 10}, "LHS", false);
8548 Placeholder *RHS1 =
8549 mod_.createPlaceholder(ElemKind::FloatTy, {10, 15}, "RHS1", false);
8550 Placeholder *RHS2 =
8551 mod_.createPlaceholder(ElemKind::FloatTy, {10, 20}, "RHS2", false);
8552 Placeholder *RHS3 =
8553 mod_.createPlaceholder(ElemKind::FloatTy, {10, 30}, "RHS3", false);
8554 bindings_.allocate(LHS)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG());
8555 bindings_.allocate(RHS1)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG());
8556 bindings_.allocate(RHS2)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG());
8557 bindings_.allocate(RHS3)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG());
8558
  // Chain several Sigmoid nodes onto RHS2 so that dependency analysis
  // disallows merging MM2 below.
8561 Node *sigRHS2 = RHS2;
8562 for (size_t i = 0, e = 7; i < e; i++) {
8563 sigRHS2 = F_->createSigmoid("s_rhs2", sigRHS2);
8564 }
8565
8566 Node *MM1 = F_->createMatMul("mm1", LHS, RHS1);
8567 Node *MM2 = F_->createMatMul("mm2", LHS, sigRHS2);
8568 Node *MM3 = F_->createMatMul("mm3", LHS, RHS3);
8569
8570 F_->createSave("save1", MM1);
8571 F_->createSave("save2", MM2);
8572 F_->createSave("save3", MM3);
8573 ASSERT_TRUE(F_->verify());
8574
8575 optimizedF_ = optimizeFunctionForTest(
8576 F_, {FunctionPassID::MergeMatMulOnRHS, getDCEPassConfig()});
8577 ASSERT_TRUE(optimizedF_->verify());
8578
8579 // Expect three matmuls -> two matmuls, because mm1 and mm3 were merged.
8580 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 3);
8581 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::MatMulNodeKind), 2);
8582
8583 checkNumericalEquivalence(0.f);
8584}
8585