/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "BackendTestUtils.h"

#include "glow/Base/Type.h"
#include "glow/Graph/Graph.h"
#include "glow/Graph/Node.h"
#include "glow/Graph/Nodes.h"
#include "glow/Graph/PlaceholderBindings.h"
#include "glow/IR/IR.h"
#include "glow/Optimizer/GraphOptimizer/FunctionPassPipeline.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
#include "glow/Optimizer/Lower/Lower.h"

#include "gtest/gtest.h"

using namespace glow;

class GraphFold : public GraphOptz {};

/// A helper predicate to check if the provided node has the same address as
/// the pre-defined address provided in the constructor. This is useful if you
/// need to check that a given node is still in the graph. In general, it is
/// not safe to use std::find(begin_it, end_it, value) and compare the nodes by
/// value, because the node provided as the last parameter of std::find (i.e.
/// the value reference) may have been removed by an optimization and can no
/// longer be dereferenced. Comparing the addresses of the nodes, however, is
/// fine. Thus, one can use the following form instead:
///   std::find_if(begin_it, end_it, IsSameNodeAddress(node_address))
struct IsSameNodeAddress {
  const Node *nodeAddress_;
  IsSameNodeAddress(const Node *nodeAddress) : nodeAddress_(nodeAddress) {}
  bool operator()(const Node &n) const { return &n == nodeAddress_; }
};

/// \returns true if the Function \p F contains the Node \p N.
static bool functionContainsNode(const Function *F, const Node *N) {
  return std::find_if(F->getNodes().begin(), F->getNodes().end(),
                      IsSameNodeAddress(N)) != F->getNodes().end();
}

/// Optimize the function \p F with \p cctx. \returns the optimized function.
/// If \p configs is empty then the whole default optimization pipeline is run.
/// Otherwise only the passes listed in \p configs are used.
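/// Example usage (mirroring how the tests below call it; the second form runs
/// only the listed passes):
///   optimizedF_ = optimizeFunctionForTest(F_);
///   optimizedF_ = optimizeFunctionForTest(
///       F_, {FunctionPassID::SinkCode, getDCEPassConfig()});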
static Function *
optimizeFunctionForTest(Function *F,
                        std::initializer_list<FunctionPassConfig> configs = {},
                        const CompilationContext &cctx = CompilationContext()) {
  auto *G = F->clone(F->getName().str() + "_optimized");
  if (configs.size() == 0) {
    ::glow::optimize(G, cctx);
    return G;
  }
  FunctionPassManager FPM("TestFPM", configs);
  FPM.run(G, cctx);
  return G;
}

/// \returns the first node in a function which has the specified name.
template <typename NodeT = Node>
static const NodeT *findFunctionNodeByName(const Function *F,
                                           const llvm::StringRef name) {
  return llvm::dyn_cast<NodeT>(
      std::find_if(F->getNodes().begin(), F->getNodes().end(),
                   [=](auto &N) { return N.getName() == name; }));
}

TEST_F(GraphOptz, OptimizeClipFunnel) {
  auto *A =
      mod_.createPlaceholder(ElemKind::FloatTy, {100, 16}, "input", false);
  Node *K = A;
  float min = 0.0;
  float max = 1000.0;
  for (int i = 0; i < 10; ++i) {
    min += 1.0;
    max -= 1.0;
    K = F_->createClip("clip", K, min, max);
  }
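  // After the loop min == 10 and max == 990; the funnel of ten Clips should
  // merge into a single Clip over the intersection [10, 990] of the ranges.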
  F_->createSave("ret", K);

  EXPECT_EQ(F_->getNodes().size(), 11);

  optimizedF_ = optimizeFunctionForTest(F_);
  EXPECT_EQ(optimizedF_->getNodes().size(), 2);

  // Find clip node in the optimized graph.
  Node *newClip = A;
  for (auto &N : optimizedF_->getNodes()) {
    if (N.getKind() == Kinded::Kind::ClipNodeKind) {
      newClip = llvm::dyn_cast<ClipNode>(&N);
    }
  }
  EXPECT_TRUE(llvm::isa<ClipNode>(newClip));
  ClipNode *c = llvm::dyn_cast<ClipNode>(newClip);
  EXPECT_EQ(min, c->getMin());
  EXPECT_EQ(max, c->getMax());

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(A)->getHandle().randomize(-1000, 1000, mod_.getPRNG());
  bindings_.get(A)->getHandle().raw(0) = -1000;
  checkNumericalEquivalence();
}

TEST_F(GraphOptz, DCE) {
  Node *K = mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input",
                                   false);

  for (int i = 0; i < 40; i++) {
    K = F_->createRELU("relu", K);
    // Add a graph structure that diverges and converges, to catch algorithms
    // that perform a naive recursive scan.
    K = F_->createAdd("arith", K, K);
  }

  // Check that we know how many nodes we've created.
  EXPECT_EQ(F_->getNodes().size(), 80);

  // Optimize all of the dead code.
  ::glow::optimize(F_, CompilationMode::Infer);

  // All of the nodes are gone.
  EXPECT_EQ(F_->getNodes().size(), 0);
  EXPECT_EQ(mod_.getConstants().size(), 0);
}

/// Check that predicated instructions are DCE'ed like
/// regular instructions.
TEST_F(GraphOptz, DCEwithPredicate) {
  Node *K = mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input",
                                   false);
  Node *predicatedBatch =
      mod_.createPlaceholder(ElemKind::FloatTy, {4}, "predicate", true);
  for (int i = 0; i < 40; i++) {
    K = F_->createRELU("relu", K);
    K->setPredicate(predicatedBatch);
    // Add a graph structure that diverges and converges, to catch algorithms
    // that perform a naive recursive scan.
    K = F_->createAdd("arith", K, K);
    K->setPredicate(predicatedBatch);
  }

  // Check that we know how many nodes we've created.
  EXPECT_EQ(F_->getNodes().size(), 80);

  // Optimize all of the dead code.
  ::glow::optimize(F_, CompilationMode::Infer);

  // All of the nodes are gone.
  EXPECT_EQ(F_->getNodes().size(), 0);
  EXPECT_EQ(mod_.getConstants().size(), 0);
}

TEST_F(GraphOptz, liveCodeNotEliminated) {
  Node *K = mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input",
                                   false);
  auto *Ex = mod_.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "Ex", false);

  for (int i = 0; i < 40; i++) {
    K = F_->createRELU("relu", K);
    K = F_->createAdd("arith", K, K);
  }
  K = F_->createSoftMax("Regression", K, Ex);
  F_->createSave("ret", K);

  // Check that we know how many nodes we've created.
  EXPECT_EQ(F_->getNodes().size(), 82);

  // This should not optimize code because none is dead.
  ::glow::optimize(F_, CompilationMode::Infer);

  // Nothing got optimized.
  EXPECT_EQ(F_->getNodes().size(), 82);
  EXPECT_EQ(mod_.getPlaceholders().size(), 3);
}

/// Skip Reshape sinking below BatchNorm when inapplicable.
TEST_F(GraphOptz, SkipReshapeSinkBatchNorm) {
  auto *A = mod_.createPlaceholder(ElemKind::FloatTy, {32, 64}, "A", false);
  Node *RS = F_->createReshape("reshape", A, {32, 64, 1});
  Node *BN =
      F_->createBatchNormalization(bindings_, "batch", RS, 1, 0.0001, 0.9);
  F_->createSave("ret", BN);

  optimizedF_ = optimizeFunctionForTest(F_);
  EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false, /* skipName */ true),
            optimizedF_->toString(/* skipUsersForStorage */ false,
                                  /* skipName */ true));
}

// Conv->Reshape->BatchNorm is optimized to Conv->Reshape after sinking Reshape
// below BatchNorm. Reshape transforms [N][H][W][C] to [N][W][H][C].
TEST_F(GraphOptz, optimizeBatchNormAfterConvAndReshapeNHWC) {
  auto *A =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
  Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
  Node *RS = F_->createReshape("reshape", CV, {1, 20, 10, 16});
  Node *BN =
      F_->createBatchNormalization(bindings_, "batch", RS, 3, 0.0001, 0.9);
  F_->createSave("ret", BN);

  EXPECT_EQ(F_->getNodes().size(), 4);
  ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
  optimizedF_ = optimizeFunctionForTest(F_);
  EXPECT_EQ(optimizedF_->getNodes().size(), 3);

  ASSERT_EQ(A->getNumUsers(), 2);
  Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
                                 [CV](auto &it) { return it.getUser() == CV; })
                    ->getUser();
  EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
  ASSERT_EQ(newCV->getNumUsers(), 1);
  Node *reshape = newCV->getUsers().begin()->getUser();
  EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

// Conv->Reshape->BatchNorm is optimized to Conv->Reshape after sinking Reshape
// below BatchNorm. Reshape flattens [N][H][W][C] to [N][HxW][C].
TEST_F(GraphOptz, optimizeBatchNormAfterConvAndReshapeNHWC2) {
  auto *A =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
  Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
  Node *RS = F_->createReshape("reshape", CV, {1, 200, 16});
  Node *BN =
      F_->createBatchNormalization(bindings_, "batch", RS, 2, 0.0001, 0.9);
  F_->createSave("ret", BN);

  EXPECT_EQ(F_->getNodes().size(), 4);
  ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
  optimizedF_ = optimizeFunctionForTest(F_);
  EXPECT_EQ(optimizedF_->getNodes().size(), 3);

  ASSERT_EQ(A->getNumUsers(), 2);
  Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
                                 [CV](auto &it) { return it.getUser() == CV; })
                    ->getUser();
  EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
  ASSERT_EQ(newCV->getNumUsers(), 1);
  Node *reshape = newCV->getUsers().begin()->getUser();
  EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

// BatchNorm is not folded into Conv because the Reshape changes the channel
// dimension, which prevents the optimization. Reshape transforms [N][H][W][C]
// to [N][H][W/2][C*2].
TEST_F(GraphOptz, optimizeBatchNormAfterConvAndReshapeNHWCneg) {
  auto *A =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
  Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
  Node *RS = F_->createReshape("reshape", CV, {1, 10, 10, 32});
  Node *BN =
      F_->createBatchNormalization(bindings_, "batch", RS, 3, 0.0001, 0.9);
  F_->createSave("ret", BN);

  EXPECT_EQ(F_->getNodes().size(), 4);
  ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
  optimizedF_ = optimizeFunctionForTest(F_);
  EXPECT_EQ(optimizedF_->getNodes().size(), 4);

  ASSERT_EQ(A->getNumUsers(), 2);
  Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
                                 [CV](auto &it) { return it.getUser() == CV; })
                    ->getUser();
  EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
  ASSERT_EQ(newCV->getNumUsers(), 1);
  Node *reshape = newCV->getUsers().begin()->getUser();
  EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));
  Node *bn = reshape->getUsers().begin()->getUser();
  EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(bn));

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

// Sink Reshape below BatchNorm: multi-user testcase.
TEST_F(GraphOptz, sinkReshapeBelowBatchNormMultiUser) {
  auto *in =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 40, 8}, "input", false);
  auto *RS = F_->createReshape("reshape", in, {1, 20, 20, 8});
  auto *BN =
      F_->createBatchNormalization(bindings_, "batch", RS, 3, 0.0001, 0.9);
  auto *save = F_->createSave("ret", BN);
  F_->createSave("extra_user", RS);

  optimizedF_ = optimizeFunctionForTest(F_);
  EXPECT_EQ(F_->getNodes().size(), 4);
  EXPECT_EQ(optimizedF_->getNodes().size(), 5);

  auto *saveOpt =
      findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
  auto *reshapeOpt = llvm::dyn_cast<ReshapeNode>(saveOpt->getInput());
  ASSERT_TRUE(reshapeOpt);
  ASSERT_TRUE(llvm::isa<BatchNormalizationNode>(reshapeOpt->getInput()));

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

// Sink Reshape below BatchNorm: quantized testcase.
TEST_F(GraphOptz, sinkReshapeBelowBatchNormQuantized) {
  auto *in = mod_.createPlaceholder(ElemKind::Int8QTy, {1, 10, 40, 3}, 1.5f, 0,
                                    "input", false);
  auto *params =
      mod_.createPlaceholder(ElemKind::Int8QTy, {3}, 1.0f, 0, "params", false);
  auto *RS = F_->createReshape("reshape", in, {1, 20, 20, 3});
  auto *bnOutTy = mod_.uniqueType(ElemKind::Int8QTy, {1, 20, 20, 3}, 2.7f, 0);
  auto *BN = F_->createBatchNormalization("batch", bnOutTy, RS, params, params,
                                          params, params, 3);
  auto *save = F_->createSave("ret", BN);

  optimizedF_ = optimizeFunctionForTest(F_);
  EXPECT_EQ(F_->getNodes().size(), 3);
  EXPECT_EQ(optimizedF_->getNodes().size(), 3);
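  // The node count is unchanged (Reshape, BatchNorm, Save); only their order
  // changes, and the sink must carry the quantized types through intact, which
  // the type checks below verify.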

  auto *saveOpt =
      findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
  auto *reshapeOpt = llvm::dyn_cast<ReshapeNode>(saveOpt->getInput());
  ASSERT_TRUE(reshapeOpt);
  auto *bnOpt = llvm::dyn_cast<BatchNormalizationNode>(reshapeOpt->getInput());
  ASSERT_TRUE(bnOpt);
  EXPECT_TRUE(
      BN->getInput().getType()->isEqual(*bnOpt->getInput().getType(), true));
  EXPECT_TRUE(
      BN->getResult().getType()->isEqual(*bnOpt->getResult().getType(), true));
}

// Conv->Reshape->BatchNorm. Sink Reshape below BatchNorm. Check that BatchNorm
// does not fold into Conv.
TEST_F(GraphOptz, sinkReshapeBelowBatchNormAndDoNotFuseConvBatchNorm) {
  // Skip this test for now since Glow doesn't fully support
  // Convolution in NCHW layout.
  GTEST_SKIP();

  auto *A =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 3, 10, 20}, "A", false);
  Node *CV = F_->createConv(bindings_, "conv", A, /* outChannels */ 16,
                            /* kernel */ 5, /* stride */ 1, /* pad */ 2,
                            /* group */ 1, /* dilation */ {1, 1},
                            /* layout */ ConvolutionLayout::NCHW);
  Node *RS = F_->createReshape("reshape", CV, {1, 10, 16, 20});
  Node *BN =
      F_->createBatchNormalization(bindings_, "batch", RS, 1, 0.0001, 0.9);
  F_->createSave("ret", BN);

  EXPECT_EQ(F_->getNodes().size(), 4);
  ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
  optimizedF_ = optimizeFunctionForTest(F_);
  EXPECT_EQ(optimizedF_->getNodes().size(), 4);

  ASSERT_EQ(A->getNumUsers(), 2);
  Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
                                 [CV](auto &it) { return it.getUser() == CV; })
                    ->getUser();

  EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
  ASSERT_EQ(newCV->getNumUsers(), 1);
  Node *bn = newCV->getUsers().begin()->getUser();
  EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(bn));
  Node *reshape = bn->getUsers().begin()->getUser();
  EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

TEST_F(GraphOptz, optimizeBatchNormAfterConv) {
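  // Per output channel c a BatchNorm after a Conv computes
  //   y = gamma[c] * (conv(x) - mean[c]) / sqrt(var[c] + eps) + beta[c],
  // which is affine in the Conv output, so the optimizer can fold it away by
  // rescaling the filter and adjusting the bias, leaving just Conv->Save.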
  auto *A =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
  Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
  Node *BN =
      F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
  F_->createSave("ret", BN);

  EXPECT_EQ(F_->getNodes().size(), 3);
  ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
  optimizedF_ = optimizeFunctionForTest(F_);
  EXPECT_EQ(optimizedF_->getNodes().size(), 2);

  ASSERT_EQ(A->getNumUsers(), 2);
  Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
                                 [CV](auto &it) { return it.getUser() == CV; })
                    ->getUser();
  EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
  ASSERT_EQ(newCV->getNumUsers(), 1);
  Node *save = newCV->getUsers().begin()->getUser();
  EXPECT_TRUE(llvm::isa<SaveNode>(save));

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}
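
// A BatchNorm with constant parameters is a per-channel affine map:
//   scale = gamma / sqrt(var + eps),  bias = beta - mean * scale.
// The optimizeRedundantBatchNorm* tests below pick parameters for which
// scale == 1 and bias == 0, so the BatchNorm is an identity and is removed:
//   1) gamma=1,  var=1,   eps=0: scale=1;       beta=0,  mean=0  -> bias=0
//   2) gamma=1,  var=1,   eps=0: scale=1;       beta=33, mean=33 -> bias=0
//   3) var=1-eps, so var+eps=1:  scale=1;       beta=33, mean=33 -> bias=0
//   4) gamma=15, var=225, eps=0: scale=15/15=1; beta=-3, mean=-3 -> bias=0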

void optimizeRedundantBatchNormTest(
    glow::Module &mod_, glow::Function *F_, glow::Function *&optimizedF_,
    glow::PlaceholderBindings &bindings_, llvm::ArrayRef<float> varV,
    llvm::ArrayRef<float> meanV, llvm::ArrayRef<float> gammaV,
    llvm::ArrayRef<float> betaV, const float eps) {
  auto *A =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);

  auto *var = mod_.createConstant(ElemKind::FloatTy, {3}, "var");
  auto *mean = mod_.createConstant(ElemKind::FloatTy, {3}, "mean");
  auto *beta = mod_.createConstant(ElemKind::FloatTy, {3}, "beta");
  auto *gamma = mod_.createConstant(ElemKind::FloatTy, {3}, "gamma");

  // (X - mean) * (1.0 / sqrt(var + eps)) * gamma + beta
  var->getPayloadMutable().getHandle<float>() = varV;
  mean->getPayloadMutable().getHandle<float>() = meanV;
  beta->getPayloadMutable().getHandle<float>() = betaV;
  gamma->getPayloadMutable().getHandle<float>() = gammaV;
  Node *BN = F_->createBatchNormalization("batch", A->getType(), A, beta, gamma,
                                          mean, var, 3, eps);
  Node *LRN = F_->createLocalResponseNormalization("LRN", BN);
  F_->createSave("ret", LRN);

  EXPECT_EQ(F_->getNodes().size(), 3);
  ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
  optimizedF_ = optimizeFunctionForTest(F_);
  EXPECT_EQ(optimizedF_->getNodes().size(), 2);

  ASSERT_EQ(A->getNumUsers(), 2);
  Node *LRN1 = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
                                [BN](auto &it) { return it.getUser() == BN; })
                   ->getUser();
  ASSERT_TRUE(llvm::isa<LocalResponseNormalizationNode>(LRN1));
  ASSERT_EQ(LRN1->getNumUsers(), 1);
  Node *save = LRN1->getUsers().begin()->getUser();
  EXPECT_TRUE(llvm::isa<SaveNode>(save));

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
}

TEST_F(GraphOptz, optimizeRedundantBatchNorm1) {
  optimizeRedundantBatchNormTest(mod_, F_, optimizedF_, bindings_, {1., 1., 1.},
                                 {0., 0., 0.}, {1., 1., 1.}, {0., 0., 0.}, 0.0);
  checkNumericalEquivalence();
}

TEST_F(GraphOptz, optimizeRedundantBatchNorm2) {
  optimizeRedundantBatchNormTest(mod_, F_, optimizedF_, bindings_, {1., 1., 1.},
                                 {33., 33., 33.}, {1., 1., 1.}, {33., 33., 33.},
                                 0.0);
  checkNumericalEquivalence();
}

TEST_F(GraphOptz, optimizeRedundantBatchNorm3) {
  const float eps = 0.000001;
  optimizeRedundantBatchNormTest(
      mod_, F_, optimizedF_, bindings_, {1.0f - eps, 1.0f - eps, 1.0f - eps},
      {33., 33., 33.}, {1., 1., 1.}, {33., 33., 33.}, eps);
  checkNumericalEquivalence();
}

TEST_F(GraphOptz, optimizeRedundantBatchNorm4) {
  optimizeRedundantBatchNormTest(mod_, F_, optimizedF_, bindings_,
                                 {225., 225., 225.}, {-3., -3., -3.},
                                 {15., 15., 15.}, {-3., -3., -3.}, 0.0);
  checkNumericalEquivalence();
}

/// Verify that the Conv-BatchNorm merging optimization is not impacted by
/// multiple users on the filter/bias.
TEST_F(GraphOptz, optimizeBatchNormAfterConvMultiple) {
  Placeholder *A =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
  ConvolutionNode *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
  BatchNormalizationNode *BN =
      F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
  F_->createSave("ret", BN);

  // Adding these saves means the filter and bias have multiple uses. This
  // should not impact the Conv-BatchNorm merging optimization.
  F_->createSave("saveFilter", CV->getFilter());
  F_->createSave("saveBias", CV->getBias());

  // Three Saves, one Conv, and one BatchNorm.
  EXPECT_EQ(F_->getNodes().size(), 5);

  ::glow::convertPlaceholdersToConstants(F_, bindings_, {});

  // Conv's Filter and Bias, plus BN's Scale, Bias, Mean, and Var.
  EXPECT_EQ(mod_.getConstants().size(), 6);

  optimizedF_ = optimizeFunctionForTest(F_);

  // BatchNorm should have been merged into the Conv.
  EXPECT_EQ(optimizedF_->getNodes().size(), 4);

  // The Filter and Bias should have been duplicated so that the Conv-BN
  // optimization does not modify the filter/bias being saved, giving 4
  // Constants. Additionally, the BN's Scale, Bias, Mean, and Var should be
  // eliminated by the optimization.
  EXPECT_EQ(mod_.getConstants().size(), 8);

  ASSERT_EQ(A->getNumUsers(), 2);
  Node *newCV = A->getUsers().back().getUser();
  EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
  ASSERT_EQ(newCV->getNumUsers(), 1);
  Node *save = newCV->getUsers().begin()->getUser();
  EXPECT_TRUE(llvm::isa<SaveNode>(save));

  EXPECT_EQ(
      countNodeKind(optimizedF_, Kinded::Kind::BatchNormalizationNodeKind), 0);

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

TEST_F(GraphOptz, optimizeBatchNormAfterConvFP16) {
  auto *A =
      mod_.createPlaceholder(ElemKind::Float16Ty, {1, 10, 20, 3}, "A", false);
  Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
  Node *BN =
      F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
  F_->createSave("ret", BN);

  EXPECT_EQ(F_->getNodes().size(), 3);

  ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
  optimizedF_ = optimizeFunctionForTest(F_);

  EXPECT_EQ(optimizedF_->getNodes().size(), 2);

  ASSERT_EQ(A->getNumUsers(), 2);

  bool optimizedPathExists{false};
  for (const auto &path : A->getUsers()) {
    auto cv = path.getUser();
    EXPECT_TRUE(llvm::isa<ConvolutionNode>(cv));
    ASSERT_EQ(cv->getNumUsers(), 1);
    auto next = cv->getUsers().begin()->getUser();
    optimizedPathExists |= llvm::isa<SaveNode>(next);
  }

  EXPECT_TRUE(optimizedPathExists);

  bindings_.allocate(A)->getHandle<float16_t>().randomize(-1.0, 1.0,
                                                          mod_.getPRNG());

  checkNumericalEquivalence();
}

/// Check that transpose constant folding is done before the BatchNorm
/// optimization, which allows merging the BatchNorm into a Convolution with
/// transposed weights.
TEST_F(GraphOptz, optimizeBatchNormAfterConvWithTransposedWeights) {
  auto *input =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "input", false);
  auto *filter =
      mod_.createPlaceholder(ElemKind::FloatTy, {16, 3, 5, 5}, "filter", false);
  auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {16}, "bias", false);

  auto *TN = F_->createTranspose("transpose", filter, NCHW2NHWC);
  auto *CV = F_->createConv("conv", input, TN, bias,
                            mod_.uniqueType(ElemKind::FloatTy, {1, 10, 20, 16}),
                            5, 1, 2, 1);
  auto *BN =
      F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
  F_->createSave("ret", BN);

  // Initialize to ensure that constant tensors are not optimized out.
  bindings_.allocate(filter)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  bindings_.allocate(bias)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());

  EXPECT_EQ(F_->getNodes().size(), 4);
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchNormalizationNodeKind), 1);

  ::glow::convertPlaceholdersToConstants(F_, bindings_, {input});
  optimizedF_ = optimizeFunctionForTest(F_);

  EXPECT_EQ(optimizedF_->getNodes().size(), 2);
  EXPECT_EQ(
      countNodeKind(optimizedF_, Kinded::Kind::BatchNormalizationNodeKind), 0);

  bindings_.allocate(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

/// Check that reshape constant folding is done before the BatchNorm
/// optimization, where the Reshape is the result of the Transpose-to-Reshape
/// optimization, which allows merging the BatchNorm into a Convolution with
/// transposed weights.
TEST_F(GraphOptz, optimizeBatchNormAfterConvWithReshapeConst) {
  auto *input =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "input", false);
  auto *filter =
      mod_.createPlaceholder(ElemKind::FloatTy, {5, 5, 3, 1}, "filter", false);
  auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "bias", false);

  auto *TN = F_->createTranspose("transpose", filter, HWCN2NHWC);
  auto *CV = F_->createConv("conv", input, TN, bias,
                            mod_.uniqueType(ElemKind::FloatTy, {1, 10, 20, 1}),
                            5, 1, 2, 1);
  auto *BN =
      F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
  F_->createSave("ret", BN);

  // Initialize to ensure that constant tensors are not optimized out.
  bindings_.allocate(filter)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  bindings_.allocate(bias)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());

  EXPECT_EQ(F_->getNodes().size(), 4);
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchNormalizationNodeKind), 1);

  ::glow::convertPlaceholdersToConstants(F_, bindings_, {input});
  optimizedF_ = optimizeFunctionForTest(F_);

  EXPECT_EQ(optimizedF_->getNodes().size(), 2);
  EXPECT_EQ(
      countNodeKind(optimizedF_, Kinded::Kind::BatchNormalizationNodeKind), 0);

  bindings_.allocate(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

/// Check that the batch normalization optimization is
/// not blocked by predicates and that it preserves them.
TEST_F(GraphOptz, optimizeBatchNormAfterConvWithPred) {
  Node *A =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
  Node *pred1 =
      mod_.createPlaceholder(ElemKind::FloatTy, {1}, "predicate", false);
  Node *pred2 =
      mod_.createPlaceholder(ElemKind::FloatTy, {1}, "predicate", false);
  Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
  CV->setPredicate(pred1);
  Node *BN =
      F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
  BN->setPredicate(pred2);
  F_->createSave("ret", BN);

  EXPECT_EQ(F_->getNodes().size(), 3);

  ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
  ::glow::optimize(F_, CompilationMode::Infer);
  EXPECT_EQ(F_->getNodes().size(), 2);

  ASSERT_EQ(A->getNumUsers(), 1);
  Node *newCV = A->getUsers().begin()->getUser();
  EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
  ASSERT_TRUE(newCV->hasPredicate());
  EXPECT_EQ(newCV->getPredicate().getNode(), pred2);
  ASSERT_EQ(newCV->getNumUsers(), 1);
  Node *save = newCV->getUsers().begin()->getUser();
  EXPECT_TRUE(llvm::isa<SaveNode>(save));
}

/// Testing merge of a single-user arithmetic operation chain (Sub, Mul, Add,
/// Div) into a BatchNorm.
TEST_F(GraphOptz, MergeBatchNormalizationWithArithmeticChainTest) {
  // Inputs.
  auto *input =
      mod_.createPlaceholder(ElemKind::FloatTy, {3, 2, 2, 4}, "input", false);
  auto *var = mod_.createConstant(ElemKind::FloatTy, {4}, "var");
  auto *mean = mod_.createConstant(ElemKind::FloatTy, {4}, "mean");
  auto *beta = mod_.createConstant(ElemKind::FloatTy, {4}, "beta");
  auto *gamma = mod_.createConstant(ElemKind::FloatTy, {4}, "gamma");

  Node *subC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "subC");
  Node *mulC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "mulC");
  Node *addC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "addC");
  Node *divC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "divC");

  // Fill tensors to check boundary values after the transformation.
  std::vector<float> betaV = {1., 2., 3., 7.};
  std::vector<float> gammaV = {4., 5., 6., 7.};

  var->getPayloadMutable().getHandle<float>() = {1., 1., 1., 1.};
  mean->getPayloadMutable().getHandle<float>() = {0., 0., 0., 0.};
  beta->getPayloadMutable().getHandle<float>() = betaV;
  gamma->getPayloadMutable().getHandle<float>() = gammaV;

  // For at least one node (Sub) make the values differ per channel, to test
  // the folding better.
  const std::vector<float> subV = {1., 2., 3., 4.};
  const float mulV = 4., addV = 3., divV = 2.;
  auto subH = llvm::cast<Constant>(subC)->getHandle<float>();
  subH = {1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.,
          1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.,
          1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.};

  llvm::cast<Constant>(mulC)->getHandle<float>().clear(mulV);
  llvm::cast<Constant>(addC)->getHandle<float>().clear(addV);
  llvm::cast<Constant>(divC)->getHandle<float>().clear(divV);

  BatchNormalizationNode *bn = F_->createBatchNormalization(
      "batch", input->getType(), input, beta, gamma, mean, var, 3);

  auto *sub = F_->createSub("sub", bn, subC);
  auto *mul = F_->createMul("mul", sub, mulC);
  auto *add = F_->createAdd("add", addC, mul);
  auto *div = F_->createDiv("div", add, divC);
  auto *res = F_->createSave("save", div);

  // Compile.
  EXPECT_EQ(F_->getNodes().size(), 6);
  ::glow::convertPlaceholdersToConstants(F_, bindings_, {input});
  optimizedF_ = optimizeFunctionForTest(F_);
  EXPECT_EQ(optimizedF_->getNodes().size(), 2);

  Constant *cs, *cb;

  auto *opt_res = findFunctionNodeByName<SaveNode>(optimizedF_, res->getName());

  auto *newBn = llvm::dyn_cast<BatchNormalizationNode>(opt_res->getInput());
  ASSERT_TRUE(newBn);

  cs = llvm::dyn_cast<Constant>(newBn->getScale());
  cb = llvm::dyn_cast<Constant>(newBn->getBias());
  ASSERT_TRUE(cs);
  ASSERT_TRUE(cb);
  ASSERT_TRUE(cs->getType()->isFPType());
  ASSERT_TRUE(cb->getType()->isFPType());

  auto hs = cs->getHandle<float>();
  auto hb = cb->getHandle<float>();

  // Verify that scale and offset are computed correctly.
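  // The chain ((bn(x) - subC) * mulC + addC) / divC is affine per channel, so
  // it folds into the BatchNorm's parameters:
  //   scale' = gamma * mulV / divV
  //   bias'  = ((beta - subV) * mulV + addV) / divV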
  for (dim_t i = 0; i < 4; i++) {
    const float expScale = gammaV[i] * mulV / divV;
    const float expBias = ((betaV[i] - subV[i]) * mulV + addV) / divV;
    EXPECT_EQ(expScale, hs.raw(i));
    EXPECT_EQ(expBias, hb.raw(i));
  }

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

/// Testing folding of a single-user arithmetic operation chain (Sub, Mul, Add,
/// Div) after a Conv into a BatchNorm.
TEST_F(GraphOptz, FoldArithmeticChainAfterConvIntoBatchNorm) {
  Node *subC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "subC");
  Node *mulC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "mulC");
  Node *addC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "addC");
  Node *divC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "divC");

  // Start with identity values.
  std::vector<float> betaV = {0., 0., 0.};
  std::vector<float> gammaV = {1., 1., 1.};

  // For at least one node make the values differ per channel, to test the
  // folding better (ideally all should have different values).
  const std::vector<float> subV = {1., 2., 3.};
  const float mulV = 4., addV = 3., divV = 2.;
  llvm::cast<Constant>(mulC)->getHandle<float>().clear(mulV);
  llvm::cast<Constant>(addC)->getHandle<float>().clear(addV);
  llvm::cast<Constant>(divC)->getHandle<float>().clear(divV);
  auto subH = llvm::cast<Constant>(subC)->getHandle<float>();
  subH = {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3,
          1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3,
          1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3};

  auto *input =
      mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 2, 3}, "input", false);
  auto *filter =
      mod_.createPlaceholder(ElemKind::FloatTy, {3, 2, 2, 3}, "filter", false);
  auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {3}, "bias", false);
  bindings_.allocate(bias)->zero();

  ConvolutionNode *CV = F_->createConv(
      "Conv", input, filter, bias,
      mod_.uniqueType(ElemKind::FloatTy, {2, 3, 3, 3}), 2, 1, 1, 1);

  auto *sub = F_->createSub("sub", CV, subC);
  auto *mul = F_->createMul("mul", sub, mulC);
  auto *add = F_->createAdd("add", addC, mul);
  auto *div = F_->createDiv("div", add, divC);
  auto *res = F_->createSave("save", div);

  // Compile.
  EXPECT_EQ(F_->getNodes().size(), 6);
  ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
  optimizedF_ = optimizeFunctionForTest(F_);
  EXPECT_EQ(optimizedF_->getNodes().size(), 3);
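  // Conv, BatchNormalization and Save remain: the Sub/Mul/Add/Div chain has
  // been folded into a single BatchNorm placed right after the Conv.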

  auto *opt_res = findFunctionNodeByName<SaveNode>(optimizedF_, res->getName());

  Constant *cs, *cb;

  auto *bn = llvm::dyn_cast<BatchNormalizationNode>(opt_res->getInput());
  ASSERT_TRUE(bn);

  cs = llvm::dyn_cast<Constant>(bn->getScale());
  cb = llvm::dyn_cast<Constant>(bn->getBias());

  ASSERT_TRUE(cs);
  ASSERT_TRUE(cb);
  ASSERT_TRUE(cs->getType()->isFPType());
  ASSERT_TRUE(cb->getType()->isFPType());

  auto hs = cs->getHandle<float>();
  auto hb = cb->getHandle<float>();

  // Verify that scale and offset are computed correctly.
  for (dim_t i = 0; i < 3; i++) {
    const float expectedScale = gammaV[i] * (mulV / divV);
    const float expectedBias = ((betaV[i] - subV[i]) * mulV + addV) / divV;
    EXPECT_EQ(expectedScale, hs.raw(i));
    EXPECT_EQ(expectedBias, hb.raw(i));
  }
  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  bindings_.get(filter)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  bindings_.get(bias)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

/// Check CSE will not merge two nodes that have all the same inputs but
/// different predicates.
TEST_F(GraphOptz, cseRespectsPredicates) {
  Placeholder *in = mod_.createPlaceholder(ElemKind::FloatTy, {5}, "in", false);
  Placeholder *pred1 =
      mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
  Placeholder *pred2 =
      mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);

  Node *RN1 = F_->createRELU("relu1", in);
  RN1->setPredicate(pred1);
  SaveNode *save1 = F_->createSave("save1", RN1);
  save1->setPredicate(pred1);

  Node *RN2 = F_->createRELU("relu2", in);
  RN2->setPredicate(pred2);
  SaveNode *save2 = F_->createSave("save2", RN2);
  save2->setPredicate(pred2);

  // Two RELUs and two Saves.
  EXPECT_EQ(F_->getNodes().size(), 4);
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 2);
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2);

  ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
  optimizedF_ = optimizeFunctionForTest(F_);

  // Two RELUs and two Saves should still be there.
  EXPECT_EQ(F_->getNodes().size(), 4);
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 2);
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2);

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

TEST_F(GraphOptz, optimizeBatchNormAfterConvButConvReused) {
  Placeholder *A =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
  Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
  Node *BN =
      F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
  F_->createSave("ret", BN);
  F_->createSave("convSave", CV);

  EXPECT_EQ(F_->getNodes().size(), 4);
  optimizedF_ = optimizeFunctionForTest(F_);
  // Make sure the structure of the graph did not change, since the convolution
  // node is used more than once.
  EXPECT_EQ(optimizedF_->getNodes().size(), 4);
  auto convIt =
      std::find_if(optimizedF_->getNodes().begin(),
                   optimizedF_->getNodes().end(), [](const Node &node) -> bool {
                     return llvm::isa<ConvolutionNode>(node);
                   });
  ASSERT_NE(convIt, optimizedF_->getNodes().end());
  auto batchNormIt =
      std::find_if(optimizedF_->getNodes().begin(),
                   optimizedF_->getNodes().end(), [](const Node &node) -> bool {
                     return (llvm::isa<BatchNormalizationNode>(node));
                   });
  ConvolutionNode *conv = llvm::dyn_cast<ConvolutionNode>(convIt);
  BatchNormalizationNode *batchNorm =
      llvm::dyn_cast<BatchNormalizationNode>(batchNormIt);

  EXPECT_EQ(*conv, *CV);
  EXPECT_EQ(batchNorm->getInput().getNode(), conv);
  EXPECT_EQ(conv->getInput().getNode(), A);

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

TEST_F(GraphOptz, optimizeBatchNormAfterConvButVarReused) {
  auto *A =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);

  ConvolutionNode *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
  Node *BN =
      F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
  auto *retSaveNode = F_->createSave("ret", BN);
  auto *filterSaveNode = F_->createSave("filter", CV->getFilter());

  EXPECT_EQ(F_->getNodes().size(), 4);
  optimizedF_ = optimizeFunctionForTest(F_);
  ASSERT_EQ(A->getNumUsers(), 2);

  auto *optimizedF_ret =
      findFunctionNodeByName<SaveNode>(optimizedF_, retSaveNode->getName());
  auto *optimizedF_filterSave =
      findFunctionNodeByName<SaveNode>(optimizedF_, filterSaveNode->getName());

  // Make sure the structure of the graph did not change.
  EXPECT_EQ(optimizedF_->getNodes().size(), 4);
  EXPECT_TRUE(llvm::isa<Placeholder>(optimizedF_filterSave->getInput()));
  auto *varFilter =
      llvm::dyn_cast<Placeholder>(optimizedF_filterSave->getInput());
  EXPECT_EQ(varFilter, CV->getFilter());
  EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(optimizedF_ret->getInput()));

  BatchNormalizationNode *batchNorm =
      llvm::dyn_cast<BatchNormalizationNode>(optimizedF_ret->getInput());
  ASSERT_TRUE(batchNorm);
  auto *newCVNode =
      llvm::dyn_cast<ConvolutionNode>(batchNorm->getInput().getNode());
  ASSERT_TRUE(newCVNode);
  EXPECT_EQ(newCVNode->getInput().getNode(), CV->getInput().getNode());
  EXPECT_EQ(newCVNode->getInput().getNode(), A);

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

TEST_F(GraphOptz, transposeConstant) {
  auto *A =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
  bindings_.allocate(A)->getHandle().randomize(-7.0, 12.0, mod_.getPRNG());
  Tensor transposedA;
  bindings_.get(A)->transpose(&transposedA, {0, 3, 1, 2});
  Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
  SaveNode *save = F_->createSave("ret", T);
  EXPECT_EQ(F_->getNodes().size(), 2);

  ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
  ::glow::optimize(F_, CompilationMode::Infer);
  ASSERT_EQ(F_->getNodes().size(), 1);
  EXPECT_EQ(&*F_->getNodes().begin(), save);
  Constant *optimizedA = llvm::dyn_cast<Constant>(save->getInput().getNode());
  ASSERT_NE(optimizedA, nullptr);
  // Check that A has been properly transposed.
  EXPECT_TRUE(optimizedA->getPayload().isEqual(transposedA));
}

/// Check that the Transpose is merged with the Constant in a sequence
/// Transpose(Quantize(Constant)).
TEST_F(GraphOptz, transposeQuantizeConstant) {
  auto *qTy = mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 3}, 0.2, 0);
  auto *input = F_->getParent()->createConstant(ElemKind::FloatTy,
                                                {1, 10, 20, 3}, "input");
  auto *Q = F_->createQuantize("quantize", input, qTy);
  auto *T = F_->createTranspose("transpose", Q, NHWC2NCHW);
  auto *S = F_->createSave("save", T);

  // Skip ConstantFolding as it would have the same result as this
  // optimization.
  CompilationContext cctx;
  cctx.optimizationOpts.enableConstantFolding = false;

  EXPECT_EQ(F_->getNodes().size(), 3);
  ::glow::optimize(F_, cctx);
  EXPECT_EQ(F_->getNodes().size(), 2);

  // The Constant and the Quantize should have the new shape.
  auto *newQ = llvm::dyn_cast<QuantizeNode>(S->getInput());
  ASSERT_TRUE(newQ);
  EXPECT_TRUE(newQ->getResult().dims().equals({1, 3, 10, 20}));
  auto *newC = llvm::dyn_cast<Constant>(newQ->getInput());
  ASSERT_TRUE(newC);
  EXPECT_TRUE(newC->getType()->dims().equals({1, 3, 10, 20}));
}

/// Check that the removal of transposes still happens when
/// predicates are involved.
TEST_F(GraphOptz, transposeConstantWithPredicate) {
  auto *A =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
  auto *pred = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
  bindings_.allocate(A)->getHandle().randomize(-7.0, 12.0, mod_.getPRNG());
  Tensor transposedA;
  bindings_.get(A)->transpose(&transposedA, {0, 3, 1, 2});
  // Arguably, if the transpose doesn't happen because the predicate is false,
  // the value of A should be unchanged. However, the semantics of our
  // predicates is that they can be ignored and the program would still
  // be correct, thus this optimization is still legal.
  Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
  T->setPredicate(pred);
  SaveNode *save = F_->createSave("ret", T);
  save->setPredicate(pred);
  EXPECT_EQ(F_->getNodes().size(), 2);

  ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
  ::glow::optimize(F_, CompilationMode::Infer);
  ASSERT_EQ(F_->getNodes().size(), 1);
  EXPECT_EQ(&*F_->getNodes().begin(), save);
  // We should have kept the predicate on the save node.
  ASSERT_EQ(pred->getNumUsers(), 1);
  EXPECT_EQ(pred->getUsers().begin()->getUser(), save);
  Constant *optimizedA = llvm::dyn_cast<Constant>(save->getInput().getNode());
  ASSERT_NE(optimizedA, nullptr);
  // Check that A has been properly transposed.
  EXPECT_TRUE(optimizedA->getPayload().isEqual(transposedA));
}

TEST_F(GraphOptz, BatchNormAfterConvNotOptimizeForTrain) {
  Placeholder *A =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
  Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
  Node *BN =
      F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
  F_->createSave("ret", BN);

  EXPECT_EQ(F_->getNodes().size(), 3);

  optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
  ::glow::optimize(optimizedF_, CompilationMode::Train);
  EXPECT_EQ(optimizedF_->getNodes().size(), 3);

  ASSERT_EQ(A->getNumUsers(), 2);
  Node *curCV = A->getUsers().begin()->getUser();
  EXPECT_EQ(curCV, CV);
  ASSERT_EQ(curCV->getNumUsers(), 1);
  Node *curBN = curCV->getUsers().begin()->getUser();
  EXPECT_EQ(curBN, BN);
  ASSERT_EQ(curBN->getNumUsers(), 1);
  Node *save = curBN->getUsers().begin()->getUser();
  EXPECT_TRUE(llvm::isa<SaveNode>(save));

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

TEST_F(GraphOptz, batchNormAfterConvNotOptimizeWhenMoreThanOneUseOfConv) {
  Node *A =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);

  Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
  Node *BN =
      F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
  SaveNode *convSave = F_->createSave("ret", CV);
  SaveNode *ret = F_->createSave("ret", BN);

  EXPECT_EQ(F_->getNodes().size(), 4);

  ::glow::optimize(F_, CompilationMode::Infer);
  // Make sure the structure of the graph did not change, since the convolution
  // node is used more than once.
  EXPECT_EQ(F_->getNodes().size(), 4);
  ASSERT_TRUE(llvm::isa<ConvolutionNode>(convSave->getInput()));
  ConvolutionNode *conv = llvm::dyn_cast<ConvolutionNode>(convSave->getInput());
  EXPECT_EQ(conv, CV);
  EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(ret->getInput()));
  BatchNormalizationNode *batchNorm =
      llvm::dyn_cast<BatchNormalizationNode>(ret->getInput());
  EXPECT_EQ(batchNorm, BN);
  EXPECT_EQ(batchNorm->getInput().getNode(), CV);
  EXPECT_EQ(conv->getInput().getNode(), A);
}

enum class TestSinkTransposeNodesKind {
  BatchNormalization,
  Relu,
  LeakyRelu,
  Clip,
  Sigmoid,
  Tanh,
  Quantize,
};

class GraphOptzSinkTransposeBelowParametrized
    : public GraphOptz,
      public ::testing::WithParamInterface<TestSinkTransposeNodesKind> {
public:
  NodeValue getNodeFromInput(TestSinkTransposeNodesKind testNode, Node *T) {
    switch (testNode) {
    case TestSinkTransposeNodesKind::BatchNormalization: {
      return F_->createBatchNormalization(bindings_, "batch", T, 1, 0.0001, 0.9)
          ->getResult();
    }
    case TestSinkTransposeNodesKind::Relu: {
      return F_->createRELU("relu", T)->getResult();
    }
    case TestSinkTransposeNodesKind::LeakyRelu: {
      return F_->createLeakyRELU("leaky_relu", T, 0.1)->getResult();
    }
    case TestSinkTransposeNodesKind::Clip: {
      return F_->createClip("clip", T, 0.0, 6.0)->getResult();
    }
    case TestSinkTransposeNodesKind::Sigmoid: {
      return F_->createSigmoid("sigmoid", T)->getResult();
    }
    case TestSinkTransposeNodesKind::Tanh: {
      return F_->createTanh("tanh", T)->getResult();
    }
    case TestSinkTransposeNodesKind::Quantize: {
      return F_
          ->createQuantize(
              "quantize", T,
              mod_.uniqueType(ElemKind::Int8QTy, T->dims(0), 0.03, 5))
          ->getResult();
    }
    }
    LOG(DFATAL) << "Cannot reach here.";
    return NodeValue(); // Prevents a compilation warning.
  }
};

TEST_P(GraphOptzSinkTransposeBelowParametrized,
       TestSinkTransposeForDifferentCases) {
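  // Sinking rewrites Op(Transpose(A)) into Transpose(Op(A)) for these
  // element-wise ops, migrating the Transpose towards the output. For
  // BatchNormalization the channel index is remapped through the transpose's
  // shuffle; the other ops need no attribute changes.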
  const dim_t origDims[] = {1, 5, 10, 15};
  const dim_t transposedDims[] = {1, 15, 5, 10};
  auto *A = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
  Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
  auto IN = getNodeFromInput(GetParam(), T);
  SaveNode *O = F_->createSave("ret", IN);

  EXPECT_EQ(F_->getNodes().size(), 3);
  EXPECT_EQ(IN.dims(), llvm::makeArrayRef(transposedDims));

  optimizedF_ = optimizeFunctionForTest(F_);
  O = llvm::dyn_cast<SaveNode>(std::find_if(
      optimizedF_->getNodes().begin(), optimizedF_->getNodes().end(),
      [](const auto &N) { return N.getKind() == Kinded::Kind::SaveNodeKind; }));

  // Expecting Transpose->Output rather than N->Output.
  auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
  ASSERT_NE(transpose, nullptr);
  Node *N = transpose->getInput();
  ASSERT_TRUE(N);
  // Test correct input.
  if (GetParam() == TestSinkTransposeNodesKind::BatchNormalization) {
    ASSERT_EQ(BatchNormalizationNode::InputIdx, 0);
  } else {
    ASSERT_EQ(N->getNumInputs(), 1);
  }
  // Check that the dimensions of the input and output have been
  // updated to compensate for the absence of the transpose.
  EXPECT_EQ(transpose->getInput().dims(), llvm::makeArrayRef(origDims));
  EXPECT_EQ(N->getNthInput(0).dims(), llvm::makeArrayRef(origDims));
  EXPECT_EQ(F_->getNodes().size(), 3);

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

TEST_P(GraphOptzSinkTransposeBelowParametrized,
       TestSinkTransposeWithPredicateForDifferentCases) {
  if (GetParam() == TestSinkTransposeNodesKind::Quantize) {
    // Quantize does not work with the generic test for predicates.
    return;
  }
  const dim_t origDims[] = {1, 5, 10, 15};
  const dim_t transposedDims[] = {1, 15, 5, 10};
  Node *A = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
  Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
  Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
  Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
  Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
  T->setPredicate(pred1);
  Node *IN = getNodeFromInput(GetParam(), T);
  IN->setPredicate(pred2);
  SaveNode *O = F_->createSave("ret", IN);
  O->setPredicate(pred3);

  EXPECT_EQ(F_->getNodes().size(), 3);
  EXPECT_EQ(IN->getNthResult(0).dims(), llvm::makeArrayRef(transposedDims));

  ::glow::optimize(F_, CompilationMode::Infer);

  EXPECT_EQ(O->getPredicate().getNode(), pred3);
  // Expecting Transpose->Output rather than N->Output.
  auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
  ASSERT_NE(transpose, nullptr);
  EXPECT_EQ(transpose->getPredicate().getNode(), pred2);
  Node *N = transpose->getInput();
  ASSERT_TRUE(N);
  EXPECT_EQ(N->getPredicate().getNode(), pred2);

  // Test correct input.
  if (GetParam() == TestSinkTransposeNodesKind::BatchNormalization) {
    ASSERT_EQ(BatchNormalizationNode::InputIdx, 0);
  } else {
    ASSERT_EQ(N->getNumInputs(), 1);
  }

  // Check that the dimensions of the input and output have been
  // updated to compensate for the absence of the transpose.
  EXPECT_EQ(transpose->getInput().dims(), llvm::makeArrayRef(origDims));
  EXPECT_EQ(N->getNthInput(0).dims(), llvm::makeArrayRef(origDims));
  EXPECT_EQ(F_->getNodes().size(), 3);
}

GLOW_INSTANTIATE_TEST_SUITE_P(
    TestSinkTranspose, GraphOptzSinkTransposeBelowParametrized,
    ::testing::Values(TestSinkTransposeNodesKind::BatchNormalization,
                      TestSinkTransposeNodesKind::Relu,
                      TestSinkTransposeNodesKind::LeakyRelu,
                      TestSinkTransposeNodesKind::Clip,
                      TestSinkTransposeNodesKind::Sigmoid,
                      TestSinkTransposeNodesKind::Tanh,
                      TestSinkTransposeNodesKind::Quantize));

TEST_F(GraphOptz, SinkTransposeBelowDequantize) {
  auto *in =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input", false);
  auto *quantize = F_->createQuantize(
      "quantize", in, mod_.uniqueType(ElemKind::Int8QTy, in->dims(), 0.01, 2));
  auto *tile = F_->createTile("tile", quantize, 3, 0);
  auto *transpose = F_->createTranspose("transpose", tile, NHWC2NCHW);
  auto *deq = F_->createDequantize("dequantize", transpose, ElemKind::FloatTy);
  SaveNode *O = F_->createSave("out", deq);

  optimizedF_ = optimizeFunctionForTest(F_);

  EXPECT_EQ(F_->getNodes().size(), 5);
  EXPECT_EQ(optimizedF_->getNodes().size(), 5);

  auto *optOut = findFunctionNodeByName<SaveNode>(optimizedF_, O->getName());
  EXPECT_TRUE(llvm::isa<TransposeNode>(optOut->getInput().getNode()));

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}

TEST_F(GraphOptz, SinkTransposeBelowPRelu) {
  auto *input =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input", false);
  auto *slope =
      mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "slope", false);
  auto *OT = mod_.uniqueType(ElemKind::FloatTy, {1, 5, 10, 15});
  auto *prelu = F_->createPRELU("prelu", input, slope, OT);
  auto *transpose = F_->createTranspose("transpose", prelu, NHWC2NCHW);
  SaveNode *O = F_->createSave("out", transpose);

  optimizedF_ = optimizeFunctionForTest(F_);

  EXPECT_EQ(F_->getNodes().size(), 3);
  EXPECT_EQ(optimizedF_->getNodes().size(), 3);

  auto *optOut = findFunctionNodeByName<SaveNode>(optimizedF_, O->getName());
  EXPECT_TRUE(llvm::isa<TransposeNode>(optOut->getInput().getNode()));

  bindings_.allocate(mod_.getPlaceholders());
  bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  bindings_.get(slope)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
  checkNumericalEquivalence();
}
1273 | |
1274 | TEST_F(GraphOptz, SinkTransposeBelowTile) { |
1275 | auto *in = |
1276 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input" , false); |
1277 | auto *transpose = F_->createTranspose("transpose" , in, NHWC2NCHW); |
1278 | auto *tile = F_->createTile("tile" , transpose, 4, 1); |
1279 | auto *save = F_->createSave("save" , tile); |
1280 | |
1281 | optimizedF_ = optimizeFunctionForTest( |
1282 | F_, {FunctionPassID::SinkCode, getDCEPassConfig()}); |
1283 | |
1284 | EXPECT_EQ(F_->getNodes().size(), 3); |
1285 | EXPECT_EQ(optimizedF_->getNodes().size(), 3); |
1286 | |
1287 | auto *saveOpt = |
1288 | findFunctionNodeByName<SaveNode>(optimizedF_, save->getName()); |
1289 | auto *transposeOpt = llvm::dyn_cast<TransposeNode>(saveOpt->getInput()); |
1290 | ASSERT_TRUE(transposeOpt); |
1291 | EXPECT_EQ(transposeOpt->getShuffle(), transpose->getShuffle()); |
1292 | auto *tileOpt = llvm::dyn_cast<TileNode>(transposeOpt->getInput()); |
1293 | ASSERT_TRUE(tileOpt); |
1294 | EXPECT_EQ(tileOpt->getAxis(), 3); |
1295 | EXPECT_EQ(tileOpt->getCount(), 4); |
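// Sinking the Transpose below the Tile remaps the tile axis through the
// NHWC2NCHW shuffle: axis 1 in NCHW (the channel dim) becomes axis 3 in
// NHWC, which is what the checks above verify.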
1296 | |
1297 | bindings_.allocate(mod_.getPlaceholders()); |
1298 | bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
1299 | checkNumericalEquivalence(); |
1300 | } |
1301 | |
1302 | TEST_F(GraphOptz, HoistTransposeAboveTile) { |
1303 | auto *in = |
1304 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input" , false); |
1305 | auto *tile = F_->createTile("tile" , in, 4, 3); |
1306 | auto *transpose = F_->createTranspose("transpose" , tile, NHWC2NCHW); |
1307 | auto *save = F_->createSave("save" , transpose); |
1308 | |
1309 | optimizedF_ = optimizeFunctionForTest(F_); |
1310 | |
1311 | EXPECT_EQ(F_->getNodes().size(), 3); |
1312 | EXPECT_EQ(optimizedF_->getNodes().size(), 3); |
1313 | |
1314 | auto *saveOpt = |
1315 | findFunctionNodeByName<SaveNode>(optimizedF_, save->getName()); |
1316 | auto *tileOpt = llvm::dyn_cast<TileNode>(saveOpt->getInput()); |
1317 | ASSERT_TRUE(tileOpt); |
1318 | EXPECT_EQ(tileOpt->getAxis(), 1); |
1319 | EXPECT_EQ(tileOpt->getCount(), 4); |
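// Hoisting the Transpose above the Tile remaps the tile axis the other way:
// axis 3 in NHWC (the channel dim) becomes axis 1 in NCHW.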
1320 | auto *transposeOpt = llvm::dyn_cast<TransposeNode>(tileOpt->getInput()); |
1321 | ASSERT_TRUE(transposeOpt); |
1322 | EXPECT_EQ(transposeOpt->getShuffle(), transpose->getShuffle()); |
1323 | |
1324 | bindings_.allocate(mod_.getPlaceholders()); |
1325 | bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
1326 | checkNumericalEquivalence(); |
1327 | } |
1328 | |
/// Check that a Transpose is sunk below a RescaleQuantized node; this
/// enables, for example, folding the Rescale into the Convolution.
1330 | TEST_F(GraphOptz, sinkTransposeBelowRescale) { |
1331 | // Inputs. |
1332 | const dim_t origDims[] = {1, 5, 10, 15}; |
1333 | const dim_t transposedDims[] = {1, 15, 5, 10}; |
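// NHWC2NCHW maps {N, H, W, C} = {1, 5, 10, 15} to
// {N, C, H, W} = {1, 15, 5, 10}.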
1334 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, origDims, 0.1, 0, |
1335 | "input" , false); |
1336 | auto *filter = mod_.createPlaceholder(ElemKind::Int8QTy, {15, 1, 1, 15}, 0.1, |
1337 | 0, "filter" , false); |
1338 | auto *bias = |
1339 | mod_.createPlaceholder(ElemKind::Int32QTy, {15}, 0.01, 0, "bias" , false); |
1340 | |
1341 | // Graph. |
1342 | ConvolutionNode *conv = |
1343 | F_->createConv("conv" , input, filter, bias, input->getType(), {1, 1}, |
1344 | {1, 1}, {0, 0, 0, 0}, 1); |
1345 | |
1346 | auto *T = F_->createTranspose("transpose" , conv, NHWC2NCHW); |
1347 | auto *RT = mod_.uniqueType(ElemKind::Int8QTy, T->getResult().dims(), 0.2, 0); |
1348 | auto *R = F_->createRescaleQuantized("rescale" , T, RT); |
1349 | SaveNode *O = F_->createSave("ret" , R); |
1350 | |
1351 | EXPECT_EQ(F_->getNodes().size(), 4); |
1352 | EXPECT_EQ(RT->dims(), llvm::makeArrayRef(transposedDims)); |
1353 | |
1354 | ::glow::optimize(F_, CompilationMode::Infer); |
1355 | |
1356 | // Expecting Transpose->Output rather than Rescale->Output. |
1357 | auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput()); |
1358 | ASSERT_NE(transpose, nullptr); |
1359 | ASSERT_TRUE(llvm::isa<ConvolutionNode>(transpose->getInput())); |
1360 | auto &convTRInput = transpose->getInput(); |
// Check that the dimensions of the input and output have been
// updated to compensate for the absence of the transpose.
1363 | EXPECT_EQ(convTRInput.dims(), llvm::makeArrayRef(origDims)); |
1364 | EXPECT_EQ(convTRInput.getNode()->getNthInput(0).dims(), |
1365 | llvm::makeArrayRef(origDims)); |
1366 | EXPECT_EQ(F_->getNodes().size(), 3); |
1367 | } |
1368 | |
1369 | TEST_F(GraphOptz, cancelTwoTransposes) { |
1370 | const dim_t origDims[] = {1, 5, 10, 15}; |
1371 | Placeholder *A = |
1372 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input" , false); |
1373 | Node *T1 = F_->createTranspose("transpose" , A, NCHW2NHWC); |
1374 | Node *T2 = F_->createTranspose("transpose" , T1, NHWC2NCHW); |
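// NCHW2NHWC and NHWC2NCHW are inverse shuffles, so the pair folds to the
// identity and both transposes can be removed.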
1375 | ReluNode *K = F_->createRELU("relu" , T2); |
1376 | SaveNode *save = F_->createSave("ret" , K); |
1377 | |
1378 | EXPECT_EQ(K->getInput().dims(), llvm::makeArrayRef(origDims)); |
1379 | EXPECT_EQ(F_->getNodes().size(), 4); |
1380 | |
1381 | optimizedF_ = optimizeFunctionForTest(F_); |
1382 | |
1383 | EXPECT_EQ(optimizedF_->getNodes().size(), 2); |
1384 | |
1385 | for (auto &N : optimizedF_->getNodes()) { |
1386 | if (N.getKind() == Kinded::Kind::SaveNodeKind) { |
1387 | save = llvm::dyn_cast<SaveNode>(&N); |
1388 | } |
1389 | } |
1390 | |
1391 | ReluNode *relu = llvm::dyn_cast<ReluNode>(save->getInput()); |
1392 | ASSERT_TRUE(relu); |
1393 | EXPECT_EQ(relu->getResult().dims(), llvm::makeArrayRef(origDims)); |
1394 | EXPECT_EQ(relu->getInput().getNode(), A); |
1395 | |
1396 | bindings_.allocate(mod_.getPlaceholders()); |
1397 | bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
1398 | |
1399 | checkNumericalEquivalence(); |
1400 | } |
1401 | |
/// Make sure the predicates don't get in the way of the
/// transpose(transpose) => identity optimization and that
/// they are preserved.
1405 | TEST_F(GraphOptz, cancelTwoTransposesWithPredicate) { |
1406 | const dim_t origDims[] = {1, 5, 10, 15}; |
1407 | Node *A = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input" , false); |
1408 | Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1409 | Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1410 | Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1411 | Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1412 | Node *T1 = F_->createTranspose("transpose" , A, NCHW2NHWC); |
1413 | T1->setPredicate(pred1); |
1414 | Node *T2 = F_->createTranspose("transpose" , T1, NHWC2NCHW); |
1415 | T2->setPredicate(pred2); |
1416 | ReluNode *K = F_->createRELU("relu" , T2); |
1417 | K->setPredicate(pred3); |
1418 | SaveNode *save = F_->createSave("ret" , K); |
1419 | save->setPredicate(pred4); |
1420 | |
1421 | EXPECT_EQ(K->getInput().dims(), llvm::makeArrayRef(origDims)); |
1422 | EXPECT_EQ(F_->getNodes().size(), 4); |
1423 | |
1424 | ::glow::optimize(F_, CompilationMode::Infer); |
1425 | |
1426 | EXPECT_EQ(F_->getNodes().size(), 2); |
1427 | EXPECT_EQ(save->getPredicate().getNode(), pred4); |
1428 | ReluNode *relu = llvm::dyn_cast<ReluNode>(save->getInput()); |
1429 | ASSERT_TRUE(relu); |
1430 | EXPECT_EQ(relu->getPredicate().getNode(), pred3); |
1431 | EXPECT_EQ(relu->getResult().dims(), llvm::makeArrayRef(origDims)); |
1432 | EXPECT_EQ(relu->getInput().getNode(), A); |
1433 | } |
1434 | |
1435 | TEST_F(GraphOptz, removeIdentityTranspose) { |
1436 | const dim_t origDims[] = {1, 5, 10, 15}; |
1437 | Placeholder *A = |
1438 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input" , false); |
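// A {0, 1, 2, 3} shuffle is the identity permutation, so the transpose
// below is a no-op and should be removed.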
1439 | TransposeNode *T = F_->createTranspose("transpose" , A, {0, 1, 2, 3}); |
1440 | ReluNode *K = F_->createRELU("relu" , T); |
1441 | F_->createSave("ret" , K); |
1442 | |
1443 | EXPECT_EQ(F_->getNodes().size(), 3); |
1444 | EXPECT_EQ(K->getInput().getNode(), T); |
1445 | |
1446 | ::glow::optimize(F_, CompilationMode::Infer); |
1447 | |
1448 | EXPECT_EQ(F_->getNodes().size(), 2); |
1449 | EXPECT_EQ(K->getInput().getNode(), A); |
// Make sure we didn't mess up the dimensions of the
// variable while eliminating the transpose.
1452 | EXPECT_EQ(A->dims(), llvm::makeArrayRef(origDims)); |
1453 | } |
1454 | |
1455 | /// Check that the predicates don't get in the way of |
1456 | /// the identity transpose removal, while still being |
1457 | /// preserved. |
1458 | TEST_F(GraphOptz, removeIdentityTransposeWithPredicate) { |
1459 | const dim_t origDims[] = {1, 5, 10, 15}; |
1460 | Placeholder *A = |
1461 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input" , false); |
1462 | Placeholder *pred1 = |
1463 | mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1464 | Placeholder *pred2 = |
1465 | mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1466 | Placeholder *pred3 = |
1467 | mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1468 | TransposeNode *T = F_->createTranspose("transpose" , A, {0, 1, 2, 3}); |
1469 | T->setPredicate(pred1); |
1470 | ReluNode *K = F_->createRELU("relu" , T); |
1471 | K->setPredicate(pred2); |
1472 | SaveNode *save = F_->createSave("ret" , K); |
1473 | save->setPredicate(pred3); |
1474 | |
1475 | EXPECT_EQ(F_->getNodes().size(), 3); |
1476 | EXPECT_EQ(K->getInput().getNode(), T); |
1477 | |
1478 | ::glow::optimize(F_, CompilationMode::Infer); |
1479 | EXPECT_EQ(F_->getNodes().size(), 2); |
1480 | EXPECT_EQ(save->getPredicate().getNode(), pred3); |
1481 | EXPECT_EQ(save->getInput().getNode(), K); |
1482 | EXPECT_EQ(K->getInput().getNode(), A); |
1483 | EXPECT_EQ(K->getPredicate().getNode(), pred2); |
// Make sure we didn't mess up the dimensions of the
// variable while eliminating the transpose.
1486 | EXPECT_EQ(A->dims(), llvm::makeArrayRef(origDims)); |
1487 | } |
1488 | |
1489 | /// Check that consecutive non-inverse transposes are merged |
1490 | /// into an equivalent single transpose node. |
1491 | TEST_F(GraphOptz, mergeNonInverseTransposes) { |
1492 | const dim_t origDims[] = {1, 5, 10, 15}; |
1493 | const dim_t finalDims[] = {5, 1, 15, 10}; |
1494 | |
1495 | Placeholder *A = |
1496 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input" , false); |
1497 | TransposeNode *T1 = F_->createTranspose("transpose" , A, {0, 3, 2, 1}); |
1498 | TransposeNode *T2 = F_->createTranspose("transpose" , T1, {0, 2, 3, 1}); |
1499 | TransposeNode *T3 = F_->createTranspose("transpose" , T2, {1, 0, 3, 2}); |
1500 | TransposeNode *T4 = F_->createTranspose("transpose" , T3, {3, 1, 2, 0}); |
1501 | |
// Intermediate dims after each transpose:
1503 | // Initial : {1, 5, 10, 15} |
1504 | // After T1: {1, 15, 10, 5} |
1505 | // After T2: {1, 10, 5, 15} |
1506 | // After T3: {10, 1, 15, 5} |
1507 | // After T4: {5, 1, 15, 10} |
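// The four shuffles compose into the single shuffle {1, 0, 3, 2}, which
// maps {1, 5, 10, 15} directly to {5, 1, 15, 10}.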
1508 | |
1509 | SaveNode *save = F_->createSave("ret" , T4); |
1510 | |
1511 | EXPECT_EQ(F_->getNodes().size(), 5); |
1512 | |
1513 | optimizedF_ = optimizeFunctionForTest(F_); |
1514 | // Find save node in the optimized graph. |
1515 | for (auto &N : optimizedF_->getNodes()) { |
1516 | if (N.getKind() == Kinded::Kind::SaveNodeKind) { |
1517 | save = llvm::dyn_cast<SaveNode>(&N); |
1518 | } |
1519 | } |
1520 | // Get the last transpose node in the optimized graph. |
1521 | auto *TR = llvm::dyn_cast<TransposeNode>(save->getInput()); |
1522 | ASSERT_NE(TR, nullptr); |
1523 | |
1524 | EXPECT_EQ(optimizedF_->getNodes().size(), 2); |
1525 | EXPECT_EQ(TR->getResult().dims(), llvm::makeArrayRef(finalDims)); |
1526 | EXPECT_EQ(A->getNthResult(0).dims(), llvm::makeArrayRef(origDims)); |
1527 | EXPECT_EQ(TR->getInput().getNode(), A); |
1528 | |
1529 | bindings_.allocate(mod_.getPlaceholders()); |
1530 | bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
1531 | checkNumericalEquivalence(); |
1532 | } |
1533 | |
1534 | TEST_F(GraphOptz, sinkTransposeBelowArithmeticNodes) { |
1535 | const dim_t origDims[] = {1, 5, 10, 15}; |
1536 | Node *A1 = |
1537 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1" , false); |
1538 | Node *A2 = |
1539 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2" , false); |
1540 | Node *T1 = F_->createTranspose("transpose1" , A1, NHWC2NCHW); |
1541 | Node *T2 = F_->createTranspose("transpose2" , A2, NHWC2NCHW); |
1542 | Node *K = F_->createAdd("arith" , T1, T2); |
1543 | Node *P = F_->createPow("pow" , K, T2); |
1544 | SaveNode *O = F_->createSave("ret" , P); |
1545 | |
1546 | EXPECT_EQ(F_->getNodes().size(), 5); |
1547 | |
1548 | ::glow::optimize(F_, CompilationMode::Infer); |
1549 | |
1550 | // Expecting Transpose->Output rather than Add->Output. |
1551 | auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput()); |
1552 | ASSERT_NE(transpose, nullptr); |
1553 | auto *pow = llvm::dyn_cast<PowNode>(transpose->getInput()); |
1554 | ASSERT_TRUE(pow); |
1555 | auto *add = llvm::dyn_cast<AddNode>(pow->getLHS()); |
1556 | ASSERT_TRUE(add); |
// Check that the dimensions of the input and output have been
// updated to compensate for the absence of the transpose.
1559 | EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims)); |
1560 | EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims)); |
1561 | EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims)); |
1562 | EXPECT_EQ(add->getLHS().getNode(), A1); |
1563 | EXPECT_EQ(add->getRHS().getNode(), A2); |
1564 | |
1565 | EXPECT_EQ(F_->getNodes().size(), 4); |
1566 | } |
1567 | |
1568 | /// Check that Transpose node is sunk below arithmetic nodes when one of the |
1569 | /// operands is a Constant. |
1570 | TEST_F(GraphOptz, sinkTransposeBelowArithmeticNodesWithConstantOperand) { |
1571 | const dim_t origDims[] = {1, 5, 10, 15}; |
1572 | const dim_t transposedDims[] = {1, 15, 5, 10}; |
1573 | |
1574 | // Create one subgraph in which the Constant is the LHS operand of the Add. |
1575 | Constant *C1 = mod_.createConstant(ElemKind::FloatTy, transposedDims, "C1" ); |
1576 | // Initialize the payload before optimization so that it can be copied to the |
1577 | // new Constant that will be created by the GraphOptimizer. |
1578 | C1->getHandle().randomize(-1, 1, mod_.getPRNG()); |
1579 | |
1580 | auto *P1 = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "P1" , false); |
1581 | auto *T1 = F_->createTranspose("T1" , P1, NHWC2NCHW); |
1582 | auto *A1 = F_->createAdd("A1" , C1, T1); |
1583 | SaveNode *S1 = F_->createSave("S1" , A1); |
1584 | |
// Create one subgraph in which the Constant is the RHS operand of the Add.
1586 | Constant *C2 = mod_.createConstant(ElemKind::FloatTy, transposedDims, "C2" ); |
1587 | // Initialize the payload before optimization so that it can be copied to the |
1588 | // new Constant that will be created by the GraphOptimizer. |
1589 | C2->getHandle().randomize(-1, 1, mod_.getPRNG()); |
1590 | |
1591 | auto *P2 = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "P2" , false); |
1592 | auto *T2 = F_->createTranspose("T2" , P2, NHWC2NCHW); |
1593 | auto *A2 = F_->createAdd("A2" , T2, C2); |
1594 | SaveNode *S2 = F_->createSave("S2" , A2); |
1595 | |
1596 | EXPECT_EQ(F_->getNodes().size(), 6); |
1597 | |
1598 | optimizedF_ = optimizeFunctionForTest(F_); |
1599 | |
1600 | // Find the SaveNodes of the optimized graph. |
1601 | for (auto &N : optimizedF_->getNodes()) { |
1602 | if (N.getKind() == Kinded::Kind::SaveNodeKind) { |
1603 | if (N.getName() == S1->getName()) { |
1604 | S1 = llvm::dyn_cast<SaveNode>(&N); |
1605 | } |
1606 | |
1607 | if (N.getName() == S2->getName()) { |
1608 | S2 = llvm::dyn_cast<SaveNode>(&N); |
1609 | } |
1610 | } |
1611 | } |
1612 | |
1613 | // Expecting Transpose->Output rather than Add->Output. |
1614 | auto *transpose = llvm::dyn_cast<TransposeNode>(S1->getInput()); |
1615 | ASSERT_NE(transpose, nullptr); |
1616 | auto *add = llvm::dyn_cast<AddNode>(transpose->getInput()); |
1617 | ASSERT_TRUE(add); |
// Check that the dimensions of the input and output of the add have been
// updated to compensate for the absence of the transpose.
1620 | EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims)); |
1621 | EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims)); |
1622 | EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims)); |
1623 | EXPECT_EQ(add->getRHS().getNode(), P1); |
1624 | |
1625 | // Repeat checks for other subgraph. |
1626 | transpose = llvm::dyn_cast<TransposeNode>(S2->getInput()); |
1627 | ASSERT_NE(transpose, nullptr); |
1628 | add = llvm::dyn_cast<AddNode>(transpose->getInput()); |
1629 | ASSERT_TRUE(add); |
1630 | EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims)); |
1631 | EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims)); |
1632 | EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims)); |
1633 | EXPECT_EQ(add->getLHS().getNode(), P2); |
1634 | |
1635 | EXPECT_EQ(optimizedF_->getNodes().size(), 6); |
1636 | |
1637 | // Check that the original and optimized functions are numerically equivalent. |
1638 | // This indirectly checks that the Constant has been transposed properly. |
1639 | bindings_.allocate(mod_.getPlaceholders()); |
1640 | bindings_.get(P1)->getHandle().randomize(-1, 1, mod_.getPRNG()); |
1641 | bindings_.get(P2)->getHandle().randomize(-1, 1, mod_.getPRNG()); |
1642 | |
1643 | checkNumericalEquivalence(); |
1644 | } |
1645 | |
/// Test sinking a Transpose below an Add whose operands have the same element
/// type and shape, but different scale and offset.
1648 | TEST_F(GraphOptz, sinkQuantTransposeBelowArithmeticNodesWithConstantOperand) { |
1649 | const dim_t origDims[] = {1, 5, 10, 15}; |
1650 | const dim_t transposedDims[] = {1, 15, 5, 10}; |
1651 | |
// Create a graph where an Add takes a Constant as LHS and a Transpose as
// RHS. LHS and RHS have different scales and offsets.
1654 | Constant *lhsC = |
1655 | mod_.createConstant(ElemKind::Int8QTy, transposedDims, 0.2, 0, "C1" ); |
1656 | lhsC->getHandle<int8_t>().randomize(-128, 127, mod_.getPRNG()); |
1657 | |
1658 | auto *inputP = |
1659 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "Input" , false); |
1660 | auto *qTy = mod_.uniqueType(ElemKind::Int8QTy, origDims, 0.3, 2); |
1661 | auto *quant = F_->createQuantize("Quant" , inputP, qTy); |
1662 | auto *rhsT = F_->createTranspose("RHS" , quant, NHWC2NCHW); |
1663 | auto *addQ = F_->createAdd("Add" , lhsC, rhsT); |
1664 | SaveNode *save = F_->createSave("Save" , addQ); |
1665 | |
1666 | EXPECT_EQ(F_->getNodes().size(), 4); |
1667 | |
1668 | optimizedF_ = optimizeFunctionForTest(F_); |
1669 | |
1670 | // Expecting Transpose->Output rather than Add->Output. |
1671 | const auto *saveOpt = |
1672 | findFunctionNodeByName<SaveNode>(optimizedF_, save->getName()); |
1673 | auto *transpose = llvm::dyn_cast<TransposeNode>(saveOpt->getInput()); |
1674 | ASSERT_NE(transpose, nullptr); |
1675 | auto *add = llvm::dyn_cast<AddNode>(transpose->getInput()); |
1676 | ASSERT_TRUE(add); |
// Check that the dimensions of the input and output of the add have been
// updated to compensate for the absence of the transpose.
1679 | EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims)); |
1680 | EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims)); |
1681 | EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims)); |
1682 | quant = llvm::dyn_cast<QuantizeNode>(add->getRHS().getNode()); |
1683 | ASSERT_TRUE(quant); |
1684 | EXPECT_EQ(quant->getInput().getNode(), inputP); |
1685 | EXPECT_EQ(optimizedF_->getNodes().size(), 4); |
1686 | |
1687 | // Check that the original and optimized functions are numerically equivalent. |
1688 | // This indirectly checks that the Constant has been transposed properly. |
1689 | bindings_.allocate(mod_.getPlaceholders()); |
1690 | bindings_.get(inputP)->getHandle().randomize(-128, 127, mod_.getPRNG()); |
1691 | |
1692 | checkNumericalEquivalence(); |
1693 | } |
1694 | |
/// Check that the predicates are properly preserved while doing
/// the add(transpose, transpose) => transpose(add) optimization.
1697 | TEST_F(GraphOptz, sinkTransposeBelowArithmeticNodesWithPredicate) { |
1698 | const dim_t origDims[] = {1, 5, 10, 15}; |
1699 | Node *A1 = |
1700 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1" , false); |
1701 | Node *A2 = |
1702 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2" , false); |
1703 | Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1704 | Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1705 | Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1706 | Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1707 | Node *T1 = F_->createTranspose("transpose1" , A1, NHWC2NCHW); |
1708 | T1->setPredicate(pred1); |
1709 | Node *T2 = F_->createTranspose("transpose2" , A2, NHWC2NCHW); |
1710 | T2->setPredicate(pred2); |
1711 | Node *K = F_->createAdd("arith" , T1, T2); |
1712 | K->setPredicate(pred3); |
1713 | SaveNode *O = F_->createSave("ret" , K); |
1714 | O->setPredicate(pred4); |
1715 | |
1716 | EXPECT_EQ(F_->getNodes().size(), 4); |
1717 | |
1718 | ::glow::optimize(F_, CompilationMode::Infer); |
1719 | |
1720 | EXPECT_EQ(O->getPredicate().getNode(), pred4); |
1721 | // Expecting Transpose->Output rather than Add->Output. |
1722 | auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput()); |
1723 | ASSERT_NE(transpose, nullptr); |
1724 | EXPECT_EQ(transpose->getPredicate().getNode(), pred3); |
1725 | auto *add = llvm::dyn_cast<AddNode>(transpose->getInput()); |
1726 | ASSERT_TRUE(add); |
1727 | EXPECT_EQ(add->getPredicate().getNode(), pred3); |
// Check that the dimensions of the input and output have been
// updated to compensate for the absence of the transpose.
1730 | EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims)); |
1731 | EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims)); |
1732 | EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims)); |
1733 | EXPECT_EQ(add->getLHS().getNode(), A1); |
1734 | EXPECT_EQ(add->getRHS().getNode(), A2); |
1735 | |
1736 | EXPECT_EQ(F_->getNodes().size(), 3); |
1737 | } |
1738 | |
1739 | TEST_F(GraphOptz, sinkReluBelowConcatNodes) { |
1740 | const dim_t origDims[] = {1, 5, 10, 15}; |
1741 | const dim_t origDimsConcat[] = {1, 10, 10, 15}; |
1742 | Node *A1 = |
1743 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1" , false); |
1744 | Node *A2 = |
1745 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2" , false); |
1746 | Node *R1 = F_->createRELU("relu1" , A1); |
1747 | Node *R2 = F_->createRELU("relu2" , A2); |
1748 | Node *CN = F_->createConcat("concat" , {R1, R2}, 1); |
1749 | SaveNode *O = F_->createSave("ret" , CN); |
1750 | |
1751 | EXPECT_EQ(F_->getNodes().size(), 4); |
1752 | |
1753 | ::glow::optimize(F_, CompilationMode::Infer); |
1754 | |
1755 | // Expecting RELU->Output rather than Concat->Output. |
1756 | auto *relu = llvm::dyn_cast<ReluNode>(O->getInput()); |
1757 | ASSERT_NE(relu, nullptr); |
1758 | auto *concat = llvm::dyn_cast<ConcatNode>(relu->getInput()); |
1759 | ASSERT_TRUE(concat); |
// Check that the Concat node has the expected input and output
// dimensions after the relu nodes were sunk below it.
1762 | EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat)); |
1763 | EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims)); |
1764 | EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims)); |
1765 | EXPECT_EQ(concat->getInputs()[0].getNode(), A1); |
1766 | EXPECT_EQ(concat->getInputs()[1].getNode(), A2); |
1767 | |
1768 | EXPECT_EQ(F_->getNodes().size(), 3); |
1769 | } |
1770 | |
1771 | /// Check that the predicates are properly preserved while doing |
1772 | /// the sinking of relu nodes. |
1773 | TEST_F(GraphOptz, sinkReluBelowConcatNodesWithPredicate) { |
1774 | const dim_t origDims[] = {1, 5, 10, 15}; |
1775 | const dim_t origDimsConcat[] = {1, 10, 10, 15}; |
1776 | Node *A1 = |
1777 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1" , false); |
1778 | Node *A2 = |
1779 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2" , false); |
1780 | Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1781 | Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1782 | Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1783 | Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1784 | Node *R1 = F_->createRELU("relu1" , A1); |
1785 | R1->setPredicate(pred1); |
1786 | Node *R2 = F_->createRELU("relu2" , A2); |
1787 | R2->setPredicate(pred2); |
1788 | Node *CN = F_->createConcat("concat" , {R1, R2}, 1); |
1789 | CN->setPredicate(pred3); |
1790 | SaveNode *O = F_->createSave("ret" , CN); |
1791 | O->setPredicate(pred4); |
1792 | |
1793 | EXPECT_EQ(F_->getNodes().size(), 4); |
1794 | |
1795 | ::glow::optimize(F_, CompilationMode::Infer); |
1796 | |
1797 | // Expecting RELU->Output rather than Concat->Output. |
1798 | EXPECT_EQ(O->getPredicate().getNode(), pred4); |
1799 | auto *relu = llvm::dyn_cast<ReluNode>(O->getInput()); |
1800 | ASSERT_NE(relu, nullptr); |
1801 | EXPECT_EQ(relu->getPredicate().getNode(), pred3); |
1802 | auto *concat = llvm::dyn_cast<ConcatNode>(relu->getInput()); |
1803 | ASSERT_TRUE(concat); |
1804 | EXPECT_EQ(concat->getPredicate().getNode(), pred3); |
// Check that the Concat node has the expected input and output
// dimensions after the relu nodes were sunk below it.
1807 | EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat)); |
1808 | EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims)); |
1809 | EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims)); |
1810 | EXPECT_EQ(concat->getInputs()[0].getNode(), A1); |
1811 | EXPECT_EQ(concat->getInputs()[1].getNode(), A2); |
1812 | |
1813 | EXPECT_EQ(F_->getNodes().size(), 3); |
1814 | } |
1815 | |
1816 | TEST_F(GraphOptz, sinkTransposeBelowConcatNodes) { |
1817 | const dim_t origDims[] = {1, 5, 10, 15}; |
1818 | const dim_t origDimsConcat[] = {1, 5, 20, 15}; |
1819 | Node *A1 = |
1820 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1" , false); |
1821 | Node *A2 = |
1822 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2" , false); |
1823 | Node *T1 = F_->createTranspose("transpose" , A1, NCHW2NHWC); |
1824 | Node *T2 = F_->createTranspose("transpose" , A2, NCHW2NHWC); |
1825 | Node *CN = F_->createConcat("concat" , {T1, T2}, 1); |
1826 | SaveNode *O = F_->createSave("ret" , CN); |
1827 | |
1828 | EXPECT_EQ(F_->getNodes().size(), 4); |
1829 | |
1830 | ::glow::optimize(F_, CompilationMode::Infer); |
1831 | |
// Expecting Transpose->Output rather than Concat->Output.
1833 | auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput()); |
1834 | ASSERT_NE(transpose, nullptr); |
1835 | auto *concat = llvm::dyn_cast<ConcatNode>(transpose->getInput()); |
1836 | ASSERT_TRUE(concat); |
// Check that the dimensions of the input and output have been
// updated to compensate for the absence of the transpose.
1839 | EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat)); |
1840 | EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims)); |
1841 | EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims)); |
1842 | EXPECT_EQ(concat->getInputs()[0].getNode(), A1); |
1843 | EXPECT_EQ(concat->getInputs()[1].getNode(), A2); |
1844 | |
1845 | EXPECT_EQ(F_->getNodes().size(), 3); |
1846 | } |
1847 | |
/// Check that the predicates are properly preserved while doing
/// the concat(transpose, transpose) => transpose(concat) optimization.
1850 | TEST_F(GraphOptz, sinkTransposeBelowConcatWithPredicate) { |
1851 | const dim_t origDims[] = {1, 5, 10, 15}; |
1852 | const dim_t origDimsConcat[] = {1, 5, 20, 15}; |
1853 | Node *A1 = |
1854 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1" , false); |
1855 | Node *A2 = |
1856 | mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2" , false); |
1857 | Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1858 | Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1859 | Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1860 | Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred" , false); |
1861 | Node *T1 = F_->createTranspose("transpose" , A1, NCHW2NHWC); |
1862 | T1->setPredicate(pred1); |
1863 | Node *T2 = F_->createTranspose("transpose" , A2, NCHW2NHWC); |
1864 | T2->setPredicate(pred2); |
1865 | Node *CN = F_->createConcat("concat" , {T1, T2}, 1); |
1866 | CN->setPredicate(pred3); |
1867 | SaveNode *O = F_->createSave("ret" , CN); |
1868 | O->setPredicate(pred4); |
1869 | |
1870 | EXPECT_EQ(F_->getNodes().size(), 4); |
1871 | |
1872 | ::glow::optimize(F_, CompilationMode::Infer); |
1873 | |
1874 | EXPECT_EQ(O->getPredicate().getNode(), pred4); |
// Expecting Transpose->Output rather than Concat->Output.
1876 | auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput()); |
1877 | ASSERT_NE(transpose, nullptr); |
1878 | EXPECT_EQ(transpose->getPredicate().getNode(), pred3); |
1879 | auto *concat = llvm::dyn_cast<ConcatNode>(transpose->getInput()); |
1880 | ASSERT_TRUE(concat); |
1881 | EXPECT_EQ(concat->getPredicate().getNode(), pred3); |
// Check that the dimensions of the input and output have been
// updated to compensate for the absence of the transpose.
1884 | EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat)); |
1885 | EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims)); |
1886 | EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims)); |
1887 | EXPECT_EQ(concat->getInputs()[0].getNode(), A1); |
1888 | EXPECT_EQ(concat->getInputs()[1].getNode(), A2); |
1889 | |
1890 | EXPECT_EQ(F_->getNodes().size(), 3); |
1891 | } |
1892 | |
1893 | TEST_F(GraphOptz, sinkTransposeBelowPad) { |
1894 | // The shape of the graph before the optimization. |
1895 | const dim_t inputDims[] = {1, 5, 10, 15}; |
1896 | const dim_t outTransposeDims[] = {1, 10, 15, 5}; |
1897 | const dim_t outPadDims[] = {5, 18, 25, 11}; |
1898 | // Padding before the optimization. |
1899 | int pads[] = {0, 2, 3, 1, 4, 6, 7, 5}; |
1900 | |
1901 | // The shape of the graph after the optimization. |
1902 | const dim_t outPadDimsAfterOptim[] = {5, 11, 18, 25}; |
1903 | const dim_t outTransposeDimsAfterOptims[] = {5, 18, 25, 11}; |
1904 | // Padding after the optimization. |
1905 | int padsAfterOptim[] = {0, 1, 2, 3, 4, 5, 6, 7}; |
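// Sinking the Transpose below the Pad permutes the pads with the inverse of
// the NCHW2NHWC shuffle (i.e., NHWC2NCHW): the NHWC begin/end pads
// {0, 2, 3, 1} / {4, 6, 7, 5} become the NCHW pads {0, 1, 2, 3} /
// {4, 5, 6, 7}.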
1906 | |
1907 | // Create the initial graph. |
1908 | Node *A = |
1909 | mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input" , false); |
1910 | auto outTy = mod_.uniqueType(ElemKind::FloatTy, outPadDims); |
1911 | TransposeNode *T = F_->createTranspose("transpose" , A, NCHW2NHWC); |
1912 | Node *P = F_->createPad("pad" , T, outTy, PaddingMode::CONSTANT, pads, 23.f); |
1913 | EXPECT_EQ(T->getResult().dims(), llvm::makeArrayRef(outTransposeDims)); |
1914 | SaveNode *O = F_->createSave("ret" , P); |
1915 | |
1916 | EXPECT_EQ(F_->getNodes().size(), 3); |
1917 | |
1918 | ::glow::optimize(F_, CompilationMode::Infer); |
1919 | |
1920 | // Check the graph structure and additional properties after optimization. |
1921 | auto *trans = llvm::dyn_cast<TransposeNode>(O->getInput()); |
1922 | ASSERT_NE(trans, nullptr); |
1923 | EXPECT_EQ(trans->getResult().dims(), |
1924 | llvm::makeArrayRef(outTransposeDimsAfterOptims)); |
1925 | auto *pad = llvm::dyn_cast<PadNode>(trans->getInput().getNode()); |
1926 | ASSERT_NE(pad, nullptr); |
1927 | |
1928 | EXPECT_EQ(pad->getPads(), llvm::makeArrayRef(padsAfterOptim)); |
1929 | EXPECT_EQ(pad->getResult().dims(), llvm::makeArrayRef(outPadDimsAfterOptim)); |
1930 | |
1931 | EXPECT_EQ(F_->getNodes().size(), 3); |
1932 | } |
1933 | |
1934 | TEST_F(GraphOptz, sinkTransposeBelowRelu) { |
1935 | // Define a type with custom alignments. |
1936 | Type typeWithAlignments(ElemKind::FloatTy, {2, 3, 4, 5}, {1, 1, 32, 1}); |
1937 | Type transposedTypeWithAlignments(ElemKind::FloatTy, {2, 4, 5, 3}, |
1938 | {1, 1, 32, 1}); |
1939 | auto modTyWithAlignments = mod_.uniqueType(typeWithAlignments); |
1940 | auto modTransposedTyWithAlignments = |
1941 | mod_.uniqueType(transposedTypeWithAlignments); |
1942 | auto *A1 = mod_.createPlaceholder(modTyWithAlignments, "input1" , false); |
1943 | auto *T1 = F_->createTranspose("transpose" , A1, NCHW2NHWC); |
1944 | T1->setType(0, modTransposedTyWithAlignments); |
1945 | auto *RN = F_->createRELU("relu" , T1); |
1946 | SaveNode *O = F_->createSave("ret" , RN); |
1947 | |
1948 | EXPECT_EQ(F_->getNodes().size(), 3); |
1949 | |
1950 | ::glow::optimize(F_, CompilationMode::Infer); |
1951 | |
// Expecting Transpose->Output rather than Relu->Output, because the
// Transpose was sunk.
1954 | auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput()); |
1955 | ASSERT_NE(transpose, nullptr); |
1956 | auto *relu = llvm::dyn_cast<ReluNode>(transpose->getInput()); |
1957 | ASSERT_TRUE(relu); |
1958 | // Check that alignments are preserved by optimizations. |
1959 | ASSERT_TRUE(relu->getInput().getType()->isEqual(modTyWithAlignments)); |
1960 | ASSERT_TRUE(transpose->getInput().getType()->isEqual(modTyWithAlignments)); |
1961 | ASSERT_TRUE( |
1962 | transpose->getResult().getType()->isEqual(modTransposedTyWithAlignments)); |
1963 | |
1964 | EXPECT_EQ(F_->getNodes().size(), 3); |
1965 | ASSERT_TRUE(F_->verify()); |
1966 | } |
1967 | |
1968 | TEST_F(GraphOptz, mergeConcatNodes) { |
1969 | Node *A1 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input1" , |
1970 | false); |
1971 | Node *A2 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input2" , |
1972 | false); |
1973 | Node *A3 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input3" , |
1974 | false); |
1975 | Node *A4 = |
1976 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 5, 15}, "input4" , false); |
1977 | Node *A5 = |
1978 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 5, 15}, "input5" , false); |
1979 | |
1980 | Node *CN1 = F_->createConcat("concat1" , {A1, A2}, 1); |
1981 | Node *CN2 = F_->createConcat("concat2" , {A1, CN1}, 1); |
1982 | Node *CN3 = F_->createConcat("concat3" , {A4, A5}, 2); |
1983 | Node *CN4 = F_->createConcat("concat4" , {A3, CN2, CN3}, 1); |
1984 | Node *O = F_->createSave("ret" , CN4); |
1985 | |
1986 | EXPECT_EQ(F_->getNodes().size(), 5); |
1987 | |
1988 | ::glow::optimize(F_, CompilationMode::Train); |
1989 | |
1990 | // It is expected that the optimization transforms |
1991 | // concat4(1, A3, concat2(1, A1, concat1(1, A1, A2)), concat3(2, A4, A5)) |
1992 | // into |
1993 | // concat4(1, A3, A1, A1, A2, concat3(2, A4, A5)) |
1994 | |
1995 | EXPECT_TRUE(llvm::isa<SaveNode>(O)); |
1996 | |
1997 | auto *CN = |
1998 | llvm::dyn_cast<ConcatNode>(llvm::dyn_cast<SaveNode>(O)->getInput()); |
1999 | EXPECT_TRUE(CN); |
2000 | |
2001 | // The merged ConcatNode should have 5 inputs. |
2002 | EXPECT_EQ(CN->getInputs().size(), 5); |
2003 | |
2004 | // CN1 should be merged into a new CN2 and later into a new CN4 and removed by |
2005 | // the optimizations. |
2006 | EXPECT_FALSE(functionContainsNode(F_, CN1)); |
2007 | |
2008 | // CN2 should be merged into a new CN4 and removed by the optimizations. |
2009 | EXPECT_FALSE(functionContainsNode(F_, CN2)); |
2010 | |
2011 | // CN3 should not be merged into CN4 and should not be removed, |
2012 | // because CN4 and CN3 have a different dimension parameter. |
2013 | EXPECT_TRUE(functionContainsNode(F_, CN3)); |
2014 | |
2015 | // The CN4 concat node should be replaced by a merged concat node. |
2016 | EXPECT_FALSE(functionContainsNode(F_, CN4)); |
2017 | |
2018 | EXPECT_EQ(F_->getNodes().size(), 3); |
2019 | } |
2020 | |
2021 | TEST_F(GraphOptz, CSE) { |
2022 | Node *A1 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input1" , |
2023 | false); |
2024 | Node *A2 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input2" , |
2025 | false); |
2026 | |
2027 | Node *CN1 = F_->createConcat("concat1" , {A1, A2}, 1); |
2028 | Node *CN2 = F_->createConcat("concat2" , {A1, A2}, 1); |
2029 | Node *CN3 = F_->createConcat("concat3" , {CN1, CN2}, 2); |
2030 | Node *O = F_->createSave("ret" , CN3); |
2031 | |
2032 | EXPECT_EQ(F_->getNodes().size(), 4); |
2033 | |
2034 | ::glow::optimize(F_, CompilationMode::Train); |
2035 | |
2036 | EXPECT_TRUE(llvm::isa<SaveNode>(O)); |
2037 | |
2038 | auto *CN = |
2039 | llvm::dyn_cast<ConcatNode>(llvm::dyn_cast<SaveNode>(O)->getInput()); |
2040 | EXPECT_TRUE(CN); |
2041 | |
2042 | // The merged ConcatNode should have 2 inputs. |
2043 | EXPECT_EQ(CN->getInputs().size(), 2); |
2044 | |
2045 | // CN1 should not be removed. |
2046 | EXPECT_TRUE(functionContainsNode(F_, CN1)); |
2047 | |
2048 | // CSE should replace CN2 by CN1 and remove CN2. |
2049 | EXPECT_FALSE(functionContainsNode(F_, CN2)); |
2050 | |
2051 | EXPECT_EQ(F_->getNodes().size(), 3); |
2052 | } |
2053 | |
2054 | TEST_F(GraphOptz, SliceOfSplatNode) { |
2055 | Type t(ElemKind::FloatTy, {1000, 1000, 1000}); |
2056 | Node *Z = F_->createSplat("zero" , &t, 0.); |
2057 | Node *S = F_->createSlice("slice" , Z, {5, 15, 42}, {99, 88, 77}); |
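// Slicing a splat yields another splat whose dims are end - start:
// {99 - 5, 88 - 15, 77 - 42} = {94, 73, 35}.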
2058 | Node *O = F_->createSave("ret" , S); |
2059 | |
2060 | EXPECT_EQ(F_->getNodes().size(), 3); |
2061 | |
2062 | ::glow::optimize(F_, CompilationMode::Train); |
2063 | |
2064 | EXPECT_EQ(F_->getNodes().size(), 2); |
2065 | |
2066 | EXPECT_TRUE(llvm::isa<SaveNode>(O)); |
2067 | |
2068 | auto *CN = llvm::dyn_cast<SplatNode>(llvm::dyn_cast<SaveNode>(O)->getInput()); |
2069 | ASSERT_TRUE(CN); |
2070 | |
2071 | EXPECT_TRUE(CN->getResult().getType()->dims().equals({94, 73, 35})); |
2072 | } |
2073 | |
2074 | /// Test Clip(Splat(args)) -> Splat(args'). |
2075 | TEST_F(GraphOptz, ClipOfSplatNode) { |
2076 | Type T(ElemKind::FloatTy, {10, 10}); |
2077 | SplatNode *splat = F_->createSplat("zero" , &T, 5); |
2078 | ClipNode *clipMin = F_->createClip("clip" , splat, 10, 15); |
2079 | ClipNode *clipMax = F_->createClip("clip" , splat, 0, 2); |
2080 | ClipNode *clipSame = F_->createClip("clip" , splat, 0, 10); |
2081 | SaveNode *saveMin = F_->createSave("saveMin" , clipMin); |
2082 | SaveNode *saveMax = F_->createSave("saveMax" , clipMax); |
2083 | SaveNode *saveSame = F_->createSave("saveSame" , clipSame); |
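// Clip(Splat(v)) folds to Splat(clamp(v, min, max)): the splat value 5
// clamps to 10 for [10, 15] and to 2 for [0, 2]; for [0, 10] it is already
// in range, so the original splat is reused.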
2084 | |
2085 | // Start with one splat, three clips, three saves. |
2086 | EXPECT_EQ(F_->getNodes().size(), 7); |
2087 | |
2088 | ::glow::optimize(F_, CompilationMode::Infer); |
2089 | |
2090 | // We will end up with three Splats and three saves. |
2091 | EXPECT_EQ(F_->getNodes().size(), 6); |
2092 | |
2093 | SplatNode *splatMin = llvm::dyn_cast<SplatNode>(saveMin->getInput()); |
2094 | ASSERT_TRUE(splatMin); |
2095 | EXPECT_EQ(splatMin->getValue(), 10); |
2096 | |
2097 | SplatNode *splatMax = llvm::dyn_cast<SplatNode>(saveMax->getInput()); |
2098 | ASSERT_TRUE(splatMax); |
2099 | EXPECT_EQ(splatMax->getValue(), 2); |
2100 | |
2101 | ASSERT_EQ(saveSame->getInput().getNode(), splat); |
2102 | EXPECT_EQ(splat->getValue(), 5); |
2103 | } |
2104 | |
2105 | TEST_F(GraphOptz, ZeroArithmetic) { |
// Tests the identities: [0 + X = X] [0 * X = 0] [0 / X = 0] [X - 0 = X]
2107 | |
2108 | auto *input = |
2109 | mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input" , true); |
2110 | |
2111 | // This builds the expression: ((0 / I) + (0 + I) + (0 * I)) - 0 |
2112 | |
2113 | auto *zero = F_->createSplat("zero" , input->getType(), 0.); |
2114 | |
2115 | auto *div = F_->createDiv("div" , zero, input); // -> zero |
2116 | |
2117 | auto *add = F_->createAdd("add" , zero, input); // -> input |
2118 | |
2119 | auto *mul = F_->createMul("mul" , zero, input); // -> zero |
2120 | |
2121 | auto *add3 = F_->createAdd("add" , div, add); |
2122 | |
2123 | add3 = F_->createAdd("add" , add3, mul); |
2124 | |
2125 | auto *sub = F_->createSub("sub" , add3, zero); // -> input |
2126 | |
2127 | SaveNode *O = F_->createSave("ret" , sub); |
2128 | |
2129 | // The expression evaluates to "I". |
2130 | |
2131 | EXPECT_EQ(F_->getNodes().size(), 8); |
2132 | |
2133 | ::glow::optimize(F_, CompilationMode::Infer); |
2134 | |
2135 | EXPECT_EQ(F_->getNodes().size(), 1); |
2136 | |
2137 | EXPECT_EQ(O->getInput().getNode(), input); |
2138 | |
2139 | optimizedF_ = optimizeFunctionForTest(F_); |
2140 | |
2141 | bindings_.allocate(mod_.getPlaceholders()); |
2142 | bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
2143 | |
2144 | checkNumericalEquivalence(); |
2145 | } |
2146 | |
// Similar to ZeroArithmetic, but tests that nodes with multiple results are
// correctly handled (i.e. that the correct output is selected after
// optimizing away an arithmetic identity).
2150 | TEST_F(GraphOptz, ZeroArithmeticMultiResNode) { |
2151 | auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {10}, "input" , true); |
2152 | auto *topK = F_->createTopK("topK" , input, /*k=*/5); |
2153 | auto *zero = F_->createSplat("zero" , topK->getValues().getType(), 0.); |
2154 | auto *add = F_->createAdd("add" , topK->getValues(), zero); |
2155 | auto *sub = F_->createSub("sub" , topK->getValues(), zero); |
2156 | |
2157 | SaveNode *AS = F_->createSave("ret" , add); |
2158 | SaveNode *SS = F_->createSave("ret" , sub); |
2159 | |
2160 | // There should be 6 nodes: 2 Saves, Add, Sub, Splat and TopK. |
2161 | EXPECT_EQ(F_->getNodes().size(), 6); |
2162 | |
2163 | optimizedF_ = optimizeFunctionForTest(F_); |
2164 | |
2165 | // Now there should only be 3 nodes: TopK and 2 Saves. |
2166 | EXPECT_EQ(optimizedF_->getNodes().size(), 3); |
2167 | |
2168 | auto *OAS = findFunctionNodeByName<SaveNode>(optimizedF_, AS->getName()); |
2169 | auto *OSS = findFunctionNodeByName<SaveNode>(optimizedF_, SS->getName()); |
2170 | auto *OTopK = findFunctionNodeByName<TopKNode>(optimizedF_, topK->getName()); |
2171 | |
// Since the operations represented by the arithmetic nodes are no-ops,
2173 | // the input to both SaveNodes should be the Values result of TopKNode. |
2174 | EXPECT_EQ(OAS->getInput(), OTopK->getValues()); |
2175 | EXPECT_EQ(OSS->getInput(), OTopK->getValues()); |
2176 | |
2177 | // Check numerical equivalence. |
2178 | bindings_.allocate(mod_.getPlaceholders()); |
2179 | bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
2180 | |
2181 | checkNumericalEquivalence(); |
2182 | } |
2183 | |
2184 | /// A test that verifies that arithmetic simplification works correctly when |
2185 | /// the parents need to be simplified prior to the node itself. |
2186 | TEST_F(GraphOptz, ZeroArithmeticParentsMustBeSimplifiedFirst) { |
2187 | auto *input1 = |
2188 | mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input1" , true); |
2189 | auto *input2 = |
2190 | mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input2" , true); |
2191 | |
2192 | // This builds the expression: ((0 * I1) * (0 * I2)) = 0 |
2193 | // It should be simplified to simply the splat zero node being saved. |
2194 | |
2195 | SplatNode *zero = F_->createSplat("zero" , input1->getType(), 0.); |
2196 | |
2197 | MulNode *mul1 = F_->createMul("mul1" , zero, input1); // -> 0 |
2198 | MulNode *mul2 = F_->createMul("mul2" , zero, input2); // -> 0 |
2199 | |
2200 | MulNode *mul3 = F_->createMul("mul3" , mul1, mul2); // -> 0 |
2201 | |
2202 | SaveNode *O = F_->createSave("ret" , mul3); |
2203 | |
2204 | // Expect 1 splat, 3 muls, 1 save. |
2205 | EXPECT_EQ(F_->getNodes().size(), 5); |
2206 | |
2207 | ::glow::optimize(F_, CompilationMode::Infer); |
2208 | |
2209 | // Expect all muls to be optimized away, with 1 splat and 1 save left. |
2210 | EXPECT_EQ(F_->getNodes().size(), 2); |
2211 | EXPECT_TRUE(functionContainsNode(F_, O)); |
2212 | EXPECT_TRUE(functionContainsNode(F_, zero)); |
2213 | EXPECT_EQ(O->getInput().getNode(), zero); |
2214 | } |
2215 | |
2216 | /// Tests opts for the identities: [1 * X = X] [X / 1 = X] |
2217 | TEST_F(GraphOptz, ArithmeticIdentitiesOne) { |
2218 | auto *input = |
2219 | mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input" , true); |
2220 | |
2221 | // This builds the expression: (I / 1) * 1: |
2222 | SplatNode *one = F_->createSplat("one" , input->getType(), 1.); |
2223 | DivNode *div = F_->createDiv("div" , input, one); |
2224 | MulNode *mul = F_->createMul("mul" , div, one); |
2225 | SaveNode *save = F_->createSave("ret" , mul); |
2226 | |
2227 | // Splat, Div, Mul, Save. |
2228 | EXPECT_EQ(F_->getNodes().size(), 4); |
// Save the optimized function for future comparison.
2230 | optimizedF_ = optimizeFunctionForTest(F_); |
2231 | |
// The expression evaluates to "I", so the Save is the only node left.
2233 | EXPECT_EQ(optimizedF_->getNodes().size(), 1); |
2234 | SaveNode *SN = |
2235 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName())); |
2236 | ASSERT_TRUE(functionContainsNode(optimizedF_, SN)); |
2237 | ASSERT_NE(SN, nullptr); |
2238 | |
2239 | // Save node should just save the input. |
2240 | EXPECT_TRUE(SN->getInput().getNode() == input); |
2241 | |
2242 | bindings_.allocate(mod_.getPlaceholders()); |
2243 | bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
2244 | |
2245 | checkNumericalEquivalence(); |
2246 | } |
2247 | |
/// Reverse the intrusive list of nodes. This custom implementation is
/// required because std::reverse cannot be used with LLVM's intrusive lists.
2250 | static void reverse(NodesList &L) { |
2251 | if (L.empty()) |
2252 | return; |
2253 | // Last element of the list before reversal. |
2254 | auto &last = L.back(); |
2255 | // Take element from the beginning and move it right after the old last |
2256 | // element. Do it until the old last element becomes the first element. |
2257 | while (true) { |
2258 | auto &first = L.front(); |
2259 | // Finish when the old last element becomes the new front element. |
2260 | if (&first == &last) { |
2261 | break; |
2262 | } |
2263 | L.remove(first); |
2264 | L.insert(++last.getIterator(), &first); |
2265 | } |
2266 | } |
2267 | |
2268 | TEST(GraphOptzTest, SliceOfSplatNodeChain) { |
2269 | for (int shouldReverse = 0; shouldReverse <= 1; shouldReverse++) { |
2270 | Module mod; |
2271 | Function *F = mod.createFunction("foo" ); |
2272 | |
2273 | Type t(ElemKind::FloatTy, {1000, 1000, 1000}); |
2274 | Node *Z = F->createSplat("zero" , &t, 0.); |
2275 | Node *S1 = F->createSlice("slice1" , Z, {5, 15, 42}, {99, 88, 77}); |
2276 | Node *S2 = F->createSlice("slice2" , S1, {1, 1, 1}, {2, 3, 4}); |
2277 | F->createSave("ret" , S2); |
2278 | |
2279 | if (shouldReverse) { |
2280 | auto &nodes = F->getNodes(); |
2281 | reverse(nodes); |
2282 | } |
2283 | |
2284 | EXPECT_EQ(F->getNodes().size(), 4); |
2285 | |
2286 | CompilationContext cctx; |
2287 | cctx.compMode = CompilationMode::Train; |
2288 | // Do not perform any compile-time constant folding. |
2289 | cctx.optimizationOpts.enableConstantFolding = false; |
2290 | ::glow::optimize(F, cctx); |
2291 | |
// This test illustrates some inconsistency in the optimization:
// chains of slices of a splat are not guaranteed to be fully optimized.
2294 | EXPECT_EQ(F->getNodes().size(), shouldReverse ? 3 : 2); |
2295 | } |
2296 | } |
2297 | |
2298 | TEST_F(GraphOptz, ReshapeNoop) { |
2299 | const dim_t shape[] = {10, 20, 30}; |
2300 | Type t(ElemKind::FloatTy, shape); |
2301 | auto *Z = F_->createSplat("zero" , &t, 0.); |
2302 | auto *R = F_->createReshape("reshape" , Z, shape); |
2303 | auto *O = F_->createSave("ret" , R); |
2304 | |
2305 | EXPECT_EQ(F_->getNodes().size(), 3); |
2306 | |
2307 | ::glow::optimize(F_, CompilationMode::Train); |
2308 | |
2309 | EXPECT_EQ(F_->getNodes().size(), 2); |
2310 | |
2311 | auto *SN = llvm::dyn_cast<SplatNode>(llvm::dyn_cast<SaveNode>(O)->getInput()); |
2312 | EXPECT_TRUE(SN); |
2313 | |
2314 | EXPECT_TRUE(SN->getResult().getType()->dims().equals(shape)); |
2315 | } |
2316 | |
2317 | /// Test the Reshape(Splat(args)) -> Splat(args') transformation. |
2318 | /// Including a positive and a negative test case. In the positive case, |
2319 | /// the optimization will take place for the splat node (Z2) that has only one |
2320 | /// use. In the negative case, the optimization will not happen as the splat |
2321 | /// node (Z1) has more than one use. |
2322 | TEST_F(GraphOptz, ReshapeAfterSplat) { |
2323 | const dim_t shape[] = {10, 20, 30}; |
2324 | const dim_t reshape[] = {1, 6000}; |
2325 | Type t1(ElemKind::FloatTy, shape); |
2326 | Type t2(ElemKind::FloatTy, reshape); |
2327 | Node *input = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape, |
2328 | "input" , true); |
2329 | auto *Z1 = F_->createSplat("zero1" , &t1, 1.5); |
2330 | auto *A1 = F_->createAdd("add1" , Z1->getResult().getType(), input, Z1); |
2331 | auto *R1 = F_->createReshape("reshape1" , Z1, reshape); |
2332 | // Z1 is used by R1 and A1. |
2333 | // The reshape optimization will thus NOT be able to remove this reshape node |
2334 | // (R1). |
2335 | auto *R2 = F_->createReshape("reshape2" , A1, reshape); |
2336 | auto *A2 = F_->createAdd("add" , R1->getResult().getType(), R1, R2); |
2337 | auto *Z2 = F_->createSplat("zero2" , &t1, 2.5); |
2338 | auto *R3 = F_->createReshape("reshape3" , Z2, reshape); |
2339 | // Z2 is only used by R3. |
2340 | // The Z2,R3 nodes will be replaced by a new splat node with the shape of R3. |
2341 | auto *A3 = F_->createAdd("add" , A2->getResult().getType(), A2, R3); |
2342 | auto *O = F_->createSave("ret" , A3); |
2343 | |
2344 | // Before optimization, we have 9 nodes in the graph. |
2345 | EXPECT_EQ(F_->getNodes().size(), 9); |
2346 | |
2347 | cctx_.compMode = CompilationMode::Infer; |
2348 | // Do not perform any compile-time constant folding. |
2349 | cctx_.optimizationOpts.enableConstantFolding = false; |
2350 | ::glow::optimize(F_, cctx_); |
2351 | |
// After optimization, we expect to see only 8 nodes, as Z2,R3 are replaced
// by a new splat node.
2354 | EXPECT_EQ(F_->getNodes().size(), 8); |
2355 | |
// The second input of A3 should be a splat node with the shape of R3.
2357 | auto *newA3 = llvm::dyn_cast<AddNode>(O->getInput()); |
2358 | ASSERT_TRUE(newA3); |
2359 | auto *SN = llvm::dyn_cast<SplatNode>(newA3->getRHS()); |
2360 | EXPECT_TRUE(SN); |
2361 | EXPECT_TRUE(SN->getResult().getType()->dims().equals(reshape)); |
2362 | |
2363 | // R1 should still be in the graph. |
2364 | EXPECT_TRUE(functionContainsNode(F_, R1)); |
2365 | |
2366 | // R3 and Z2 should not be in the graph any more. |
2367 | EXPECT_FALSE(functionContainsNode(F_, R3)); |
2368 | EXPECT_FALSE(functionContainsNode(F_, Z2)); |
2369 | } |
2370 | |
2371 | /// Test the Reshape(Reshape(x)) -> Reshape(x) transformation. |
2372 | TEST_F(GraphOptz, ReshapeReshapeOpt) { |
2373 | const dim_t shape[] = {10, 20}; |
2374 | const dim_t reshape1[] = {200, 1}; |
2375 | const dim_t reshape2[] = {200}; |
2376 | Node *input = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape, |
2377 | "input" , true); |
2378 | auto *R1 = F_->createReshape("reshape1" , input, reshape1); |
2379 | auto *R2 = F_->createReshape("reshape2" , R1, reshape2); |
2380 | auto *O = F_->createSave("ret" , R2); |
2381 | |
2382 | // Before optimization, we have 2 Reshapes and a Save. |
2383 | EXPECT_EQ(F_->getNodes().size(), 3); |
2384 | |
2385 | ::glow::optimize(F_, CompilationMode::Infer); |
2386 | |
2387 | // After optimization, we expect to see only 1 Reshape and a Save. |
2388 | EXPECT_EQ(F_->getNodes().size(), 2); |
2389 | |
2390 | // Save should have the new Reshape as input. |
2391 | auto *RN = llvm::dyn_cast<ReshapeNode>(O->getInput()); |
2392 | ASSERT_TRUE(RN); |
2393 | // The new Reshape should have the same shape as the original second Reshape. |
2394 | EXPECT_TRUE(RN->getResult().getType()->dims().equals(reshape2)); |
2395 | |
2396 | // R1 and R2 should not be in the graph any more; they were replaced by a |
2397 | // single new reshape. |
2398 | EXPECT_FALSE(functionContainsNode(F_, R1)); |
2399 | EXPECT_FALSE(functionContainsNode(F_, R2)); |
2400 | } |
2401 | |
2402 | TEST_F(GraphOptz, DCEPublicVars) { |
2403 | mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input" , true); |
2404 | |
2405 | EXPECT_EQ(mod_.getPlaceholders().size(), 1); |
2406 | |
2407 | // Optimize all of the dead code. |
2408 | ::glow::optimize(F_, CompilationMode::Infer); |
2409 | |
2410 | // Public nodes should not be deleted. |
2411 | EXPECT_EQ(mod_.getPlaceholders().size(), 1); |
2412 | } |
2413 | |
2414 | TEST_F(GraphOptz, foldQuantizeIntoConstant) { |
2415 | auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {4}, "input" , true); |
2416 | *bindings_.allocate(input) = {10, 10, 10, 10}; |
2417 | auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0); |
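// With scale 2 and offset 0, the float value 10 quantizes to
// round(10 / 2) + 0 = 5, which the checks below expect.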
2418 | |
2419 | auto *Q = F_->createQuantize("quantize" , input, qType); |
2420 | auto *S = F_->createSave("save" , Q); |
2421 | |
2422 | EXPECT_EQ(2, F_->getNodes().size()); |
2423 | ::glow::convertPlaceholdersToConstants(F_, bindings_, {S->getPlaceholder()}); |
2424 | |
2425 | // 'optimize' doesn't merge quantize nodes into Constant. |
2426 | ::glow::optimize(F_, CompilationMode::Infer); |
2427 | EXPECT_EQ(2, F_->getNodes().size()); |
2428 | |
// 'convertQuantizedConstants' merges quantize nodes into Constant.
2430 | CompilationContext cctx; |
2431 | ::glow::convertQuantizedConstants(F_, cctx); |
2432 | EXPECT_EQ(1, F_->getNodes().size()); |
2433 | |
2434 | auto quantizedInput = llvm::cast<Constant>(S->getInput()); |
2435 | auto quantizedValues = quantizedInput->getHandle<int8_t>(); |
2436 | for (unsigned i = 0; i < 4; ++i) { |
2437 | EXPECT_EQ(5, quantizedValues.raw(i)); |
2438 | } |
2439 | } |
2440 | |
2441 | TEST_F(GraphOptz, foldQuantizeIntoConstantMultipleUsages) { |
2442 | auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {4}, "input" , true); |
2443 | *bindings_.allocate(input) = {10, 10, 10, 10}; |
2444 | auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0); |
2445 | |
2446 | auto *Q = F_->createQuantize("quantize" , input, qType); |
2447 | F_->createSave("save" , Q); |
2448 | auto clonedF = F_->clone("cloned" ); |
2449 | |
2450 | EXPECT_EQ(2, clonedF->getNodes().size()); |
2451 | ::glow::convertPlaceholdersToConstants(clonedF, bindings_, {}); |
2452 | CompilationContext cctx; |
2453 | ::glow::convertQuantizedConstants(clonedF, cctx); |
2454 | |
// The original function F_ should not be affected.
2456 | EXPECT_EQ(2, F_->getNodes().size()); |
2457 | |
2458 | // Check original var. |
2459 | for (unsigned i = 0; i < 4; ++i) { |
2460 | EXPECT_EQ(10, bindings_.get(input)->getHandle().raw(i)); |
2461 | } |
2462 | |
2463 | // Quantization node was merged into input var. |
2464 | EXPECT_EQ(1, clonedF->getNodes().size()); |
2465 | auto *save = llvm::dyn_cast<SaveNode>(&clonedF->getNodes().front()); |
2466 | ASSERT_TRUE(save); |
2467 | auto quantizedInput = llvm::cast<Constant>(save->getInput()); |
2468 | auto quantizedValues = quantizedInput->getHandle<int8_t>(); |
2469 | for (unsigned i = 0; i < 4; ++i) { |
2470 | EXPECT_EQ(5, quantizedValues.raw(i)); |
2471 | } |
2472 | } |
2473 | |
2474 | /// Search for a unique Save node in input graph \p F and return it. |
2475 | /// Fails in case there is no Save node or more than one detected. |
2476 | static SaveNode *getUniqueSaveNode(Function *F) { |
2477 | SaveNode *foundSaveNode = nullptr; |
2478 | for (auto &node : F->getNodes()) { |
2479 | if (auto *s = llvm::dyn_cast<SaveNode>(&node)) { |
2480 | EXPECT_EQ(foundSaveNode, nullptr); |
2481 | foundSaveNode = s; |
2482 | } |
2483 | } |
2484 | EXPECT_NE(foundSaveNode, nullptr); |
2485 | return foundSaveNode; |
2486 | } |
2487 | |
2488 | /// Mock backend that requests the pre-quantization of constants. |
2489 | class MockBackendPrequantizeConst : public MockBackend { |
2490 | bool shouldPreQuantizeConstants() const override { return true; } |
2491 | bool isOpSupported(const NodeInfo &) const override { return true; } |
2492 | Expected<bool> |
2493 | transformPostLowering(Function *F, CompilationContext &, |
2494 | const glow::runtime::DeviceInfo *) const override { |
2495 | // Check the IR. |
2496 | EXPECT_EQ(F->getNodes().size(), 1); |
2497 | auto *save = getUniqueSaveNode(F); |
2498 | EXPECT_TRUE(llvm::isa<Constant>(save->getInput())); |
2499 | |
2500 | return false; |
2501 | } |
2502 | }; |
/// Mock backend that requests that constants not be pre-quantized.
2504 | class MockBackendNotPrequantizeConst : public MockBackend { |
2505 | bool shouldPreQuantizeConstants() const override { return false; } |
2506 | bool isOpSupported(const NodeInfo &) const override { return true; } |
2507 | Expected<bool> |
2508 | transformPostLowering(Function *F, CompilationContext &, |
2509 | const glow::runtime::DeviceInfo *) const override { |
2510 | // Check the IR. |
2511 | EXPECT_EQ(F->getNodes().size(), 2); |
2512 | auto *save = getUniqueSaveNode(F); |
2513 | auto *quant = llvm::dyn_cast<QuantizeNode>(save->getInput()); |
2514 | EXPECT_TRUE(quant); |
2515 | EXPECT_TRUE(llvm::isa<Constant>(quant->getInput())); |
2516 | |
2517 | return false; |
2518 | } |
2519 | }; |
2520 | |
2521 | /// Test the actual constant quantization for backends. |
2522 | template <typename Backend> |
2523 | void testFoldQuantizeIntoConstant(Module &mod_, Function *F_) { |
2524 | auto *input = mod_.createConstant(ElemKind::FloatTy, {4}, "input" ); |
2525 | input->getHandle<float>() = {10, 10, 10, 10}; |
2526 | auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0); |
2527 | auto *Q = F_->createQuantize("quantize" , input, qType); |
2528 | auto *save = F_->createSave("save" , Q); |
2529 | |
2530 | CompilationContext cctx; |
2531 | auto B = Backend(); |
// Note: whether or not the Quantize is folded into the Constant before
// post-lowering is checked in <backend>::transformPostLowering().
2534 | EXIT_ON_ERR(::glow::optimizeFunction(F_, B, cctx)); |
2535 | |
2536 | // Check the IR (the constant must have been quantized). |
2537 | EXPECT_EQ(F_->getNodes().size(), 1); |
2538 | EXPECT_TRUE(llvm::isa<Constant>(save->getInput())); |
2539 | } |
2540 | |
/// Check that constant quantization happens before post-lowering when the
/// backend requests pre-quantized constants.
2542 | TEST_F(GraphOptz, foldQuantizeIntoConstantBeforePostLowering) { |
2543 | testFoldQuantizeIntoConstant<MockBackendPrequantizeConst>(mod_, F_); |
2544 | } |
2545 | |
/// Check that constant quantization happens after post-lowering when the
/// backend does not request pre-quantized constants.
2547 | TEST_F(GraphOptz, foldQuantizeIntoConstantAfterPostLowering) { |
2548 | testFoldQuantizeIntoConstant<MockBackendNotPrequantizeConst>(mod_, F_); |
2549 | } |
2550 | |
2551 | /// Check that the Quantize(Splat) -> Splat' optimization works. |
2552 | TEST_F(GraphOptz, foldQuantizeIntoSplat) { |
2553 | TypeRef fType = mod_.uniqueType(ElemKind::FloatTy, {4}); |
2554 | TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0); |
2555 | |
2556 | const float splatVal = 6.0; |
2557 | SplatNode *SN = F_->createSplat("splat" , fType, splatVal); |
2558 | |
2559 | QuantizeNode *Q = F_->createQuantize("quantize" , SN, qType); |
2560 | SaveNode *S = F_->createSave("save" , Q); |
2561 | |
2562 | // Splat, quantize, save. |
2563 | EXPECT_EQ(3, F_->getNodes().size()); |
2564 | |
2565 | ::glow::optimize(F_, CompilationMode::Infer); |
2566 | |
2567 | // Quantization node was merged into input splat. |
2568 | EXPECT_EQ(2, F_->getNodes().size()); |
2569 | |
// A new quantized Splat should exist with the same value; SplatNode stores
// its value in the float domain, so only the result type changes.
2571 | SplatNode *newSN = llvm::dyn_cast<SplatNode>(S->getInput()); |
2572 | ASSERT_TRUE(newSN); |
2573 | EXPECT_EQ(splatVal, newSN->getValue()); |
2574 | EXPECT_EQ(qType, newSN->getResult().getType()); |
2575 | } |
2576 | |
2577 | /// Check that the Dequantize(Splat) -> Splat' optimization works. |
2578 | TEST_F(GraphOptz, foldDequantizeIntoSplat) { |
2579 | TypeRef fType = mod_.uniqueType(ElemKind::FloatTy, {4}); |
2580 | TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0); |
2581 | |
2582 | const float splatVal = 6.0; |
2583 | SplatNode *SN = F_->createSplat("splat" , qType, splatVal); |
2584 | |
2585 | DequantizeNode *Q = F_->createDequantize("dequantize" , SN, ElemKind::FloatTy); |
2586 | SaveNode *S = F_->createSave("save" , Q); |
2587 | |
2588 | // Splat, dequantize, save. |
2589 | EXPECT_EQ(3, F_->getNodes().size()); |
2590 | |
2591 | ::glow::optimize(F_, CompilationMode::Infer); |
2592 | |
2593 | // Dequantization node was merged into input splat. |
2594 | EXPECT_EQ(2, F_->getNodes().size()); |
2595 | |
// A new float Splat should exist with the same value.
2597 | SplatNode *newSN = llvm::dyn_cast<SplatNode>(S->getInput()); |
2598 | ASSERT_TRUE(newSN); |
2599 | EXPECT_EQ(splatVal, newSN->getValue()); |
2600 | EXPECT_EQ(fType, newSN->getResult().getType()); |
2601 | } |
2602 | |
2603 | /// Check that the Quantize(Splat) -> Splat' optimization works when the Splat |
2604 | /// has multiple users. |
2605 | TEST_F(GraphOptz, foldQuantizeIntoSplatMultipleUsers) { |
2606 | TypeRef fType = mod_.uniqueType(ElemKind::FloatTy, {4}); |
2607 | TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0); |
2608 | |
2609 | SplatNode *SN = F_->createSplat("splat" , fType, 6.0); |
2610 | |
2611 | QuantizeNode *Q = F_->createQuantize("quantize" , SN, qType); |
2612 | SaveNode *SQ = F_->createSave("saveQ" , Q); |
2613 | SaveNode *SF = F_->createSave("saveF" , SN); |
2614 | |
2615 | // Splat, quantize, 2 saves. |
2616 | EXPECT_EQ(4, F_->getNodes().size()); |
2617 | |
2618 | ::glow::optimize(F_, CompilationMode::Infer); |
2619 | |
2620 | // Quantization node was merged into input splat creating a new quantized |
2621 | // splat, but the original float splat still exists. |
2622 | EXPECT_EQ(4, F_->getNodes().size()); |
2623 | |
2624 | // New quantized splat should exist with same value. |
2625 | SplatNode *newSN = llvm::dyn_cast<SplatNode>(SQ->getInput()); |
2626 | ASSERT_TRUE(newSN); |
2627 | EXPECT_EQ(SN->getValue(), newSN->getValue()); |
2628 | EXPECT_EQ(qType, newSN->getResult().getType()); |
2629 | |
2630 | // Original float splat should still exist. |
2631 | EXPECT_EQ(llvm::dyn_cast<SplatNode>(SF->getInput()), SN); |
2632 | } |
2633 | |
2634 | /// Check that an unnecessary rescale gets removed. |
2635 | TEST_F(GraphOptz, removeUnnecessaryRescale) { |
2636 | TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.03f, 5); |
2637 | Placeholder *input = |
2638 | mod_.createPlaceholder(qType, "input" , /* isTrainable */ true); |
2639 | RescaleQuantizedNode *RQ = |
2640 | F_->createRescaleQuantized("rescale" , input, qType); |
2641 | SaveNode *save = F_->createSave("ret" , RQ); |
2642 | |
2643 | // RescaleQuantized and Save. |
2644 | EXPECT_EQ(F_->getNodes().size(), 2); |
2645 | |
2646 | ::glow::optimize(F_, CompilationMode::Infer); |
2647 | |
2648 | // Only Save should be left, which saves the Placeholder directly with |
2649 | // unchanged quantization parameters. |
2650 | EXPECT_EQ(F_->getNodes().size(), 1); |
2651 | EXPECT_EQ(save->getInput().getNode(), input); |
2652 | EXPECT_EQ(save->getInput().getType(), qType); |
2653 | } |
2654 | |
/// Check that a Rescale gets correctly merged into a following Dequantize
/// node.
2656 | TEST_F(GraphOptz, mergeRescaleIntoDequantize) { |
// Check that the Rescale is merged into the following Dequantize.
2658 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11, |
2659 | "input" , true); |
2660 | auto *qType = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.03f, 5); |
2661 | auto *R = F_->createRescaleQuantized("rescale" , input, qType); |
2662 | auto *D = F_->createDequantize("dequantize" , R, ElemKind::FloatTy); |
2663 | F_->createSave("ret" , D); |
2664 | |
2665 | EXPECT_EQ(F_->getNodes().size(), 3); |
2666 | ::glow::optimize(F_, CompilationMode::Infer); |
2667 | |
2668 | // Only 2 nodes should remain (Dequantize -> Save) |
2669 | EXPECT_EQ(F_->getNodes().size(), 2); |
2670 | |
2671 | // Check the graph structure |
2672 | auto *SN = F_->getNodeByName("ret_save" ); |
2673 | EXPECT_NE(nullptr, SN); |
2674 | auto *S = llvm::dyn_cast<SaveNode>(SN); |
2675 | EXPECT_NE(nullptr, S); |
2676 | auto *newDN = S->getInput().getNode(); |
2677 | EXPECT_NE(nullptr, newDN); |
2678 | EXPECT_NE(nullptr, llvm::dyn_cast<DequantizeNode>(newDN)); |
2679 | } |
2680 | |
2681 | TEST_F(GraphOptz, quantizeToRescale) { |
2682 | // Check that we are combining quantization-dequantization pairs. |
2683 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11, |
2684 | "input" , true); |
2685 | |
2686 | auto *D = F_->createDequantize("dequantize" , input, ElemKind::FloatTy); |
2687 | |
2688 | auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.03, 5); |
2689 | auto *Q = F_->createQuantize("quantize" , D, qType); |
2690 | |
2691 | F_->createSave("ret" , Q); |
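// Dequantize followed by Quantize with different parameters is equivalent
// to a single RescaleQuantized, so the pair should collapse into one node
// feeding the Save.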
2692 | |
2693 | EXPECT_EQ(F_->getNodes().size(), 3); |
2694 | |
2695 | ::glow::optimize(F_, CompilationMode::Infer); |
2696 | EXPECT_EQ(F_->getNodes().size(), 2); |
2697 | } |
2698 | |
2699 | TEST_F(GraphOptz, MaxOfQuantizedSplat) { |
2700 | const dim_t size = 5; |
2701 | const float scale = 1; |
// offset == -128 maps the int8 range [-128, 127] to the fp range [0, 255],
// so every representable value is non-negative and Max with Splat(0.0) is
// an identity.
2704 | const int32_t offset = -128; |
2705 | |
2706 | auto splatTy = mod_.uniqueType(ElemKind::Int8QTy, {size}, scale, offset); |
2707 | auto *splat = F_->createSplat("splat" , splatTy, 0.0); |
2708 | |
2709 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {size}, scale, offset, |
2710 | "input" , true); |
2711 | |
2712 | auto *max = F_->createMax("max" , splat, input); |
2713 | F_->createSave("save" , max); |
2714 | EXPECT_EQ(F_->getNodes().size(), 3); |
2715 | |
2716 | ::glow::optimize(F_, CompilationMode::Infer); |
2717 | // Splat and Max should be gone. |
2718 | EXPECT_EQ(F_->getNodes().size(), 1); |
2719 | } |
2720 | |
2721 | TEST_F(GraphOptz, FuseRescaleIntoArithmetic) { |
// This test ensures that the Rescales are fused into the arithmetic nodes.
2723 | auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 1, 0); |
2724 | auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 2, 1); |
2725 | |
2726 | Placeholder *LHS = |
2727 | mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.4, 0, "LHS" , true); |
2728 | Placeholder *RHS = |
2729 | mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.3, 0, "RHS" , true); |
2730 | |
2731 | AddNode *add = F_->createAdd("qAdd" , opOutTy, LHS, RHS); |
2732 | RescaleQuantizedNode *rescaleAdd = |
2733 | F_->createRescaleQuantized("rsAdd" , add, rescaleOutTy); |
2734 | SaveNode *addSave = F_->createSave("saveAdd" , rescaleAdd); |
2735 | |
2736 | SubNode *sub = F_->createSub("qSub" , opOutTy, LHS, RHS); |
2737 | RescaleQuantizedNode *rescaleSub = |
2738 | F_->createRescaleQuantized("rsSub" , sub, rescaleOutTy); |
2739 | SaveNode *subSave = F_->createSave("saveSub" , rescaleSub); |
2740 | |
2741 | DivNode *div = F_->createDiv("qDiv" , opOutTy, LHS, RHS); |
2742 | RescaleQuantizedNode *rescaleDiv = |
2743 | F_->createRescaleQuantized("rsDiv" , div, rescaleOutTy); |
2744 | SaveNode *divSave = F_->createSave("saveDiv" , rescaleDiv); |
2745 | |
2746 | MulNode *mul = F_->createMul("qMul" , opOutTy, LHS, RHS); |
2747 | RescaleQuantizedNode *rescaleMul = |
2748 | F_->createRescaleQuantized("rsMul" , mul, rescaleOutTy); |
2749 | SaveNode *mulSave = F_->createSave("saveMul" , rescaleMul); |
2750 | |
2751 | MinNode *min = F_->createMin("qMin" , opOutTy, LHS, RHS); |
2752 | RescaleQuantizedNode *rescaleMin = |
2753 | F_->createRescaleQuantized("rsMin" , min, rescaleOutTy); |
2754 | SaveNode *minSave = F_->createSave("saveMin" , rescaleMin); |
2755 | |
2756 | MaxNode *max = F_->createMax("qMax" , opOutTy, LHS, RHS); |
2757 | RescaleQuantizedNode *rescaleMax = |
2758 | F_->createRescaleQuantized("rsMax" , max, rescaleOutTy); |
2759 | SaveNode *maxSave = F_->createSave("saveMax" , rescaleMax); |
2760 | |
// All Rescales must be fused into the arithmetic operations above, leaving
// the six arithmetic nodes and their six Saves.
2762 | ::glow::optimize(F_, CompilationMode::Infer); |
2763 | |
2764 | EXPECT_EQ(F_->getNodes().size(), 12); |
2765 | |
2766 | EXPECT_EQ(addSave->getInput().getType(), rescaleOutTy); |
2767 | EXPECT_EQ(subSave->getInput().getType(), rescaleOutTy); |
2768 | EXPECT_EQ(mulSave->getInput().getType(), rescaleOutTy); |
2769 | EXPECT_EQ(divSave->getInput().getType(), rescaleOutTy); |
2770 | EXPECT_EQ(minSave->getInput().getType(), rescaleOutTy); |
2771 | EXPECT_EQ(maxSave->getInput().getType(), rescaleOutTy); |
2772 | } |
2773 | |
2774 | /// Check that the Rescale(MatMul) -> MatMul' optimization works correctly. |
2775 | TEST_F(GraphOptz, FuseRescaleUpIntoMatMul) { |
// This test ensures that the Rescale is fused into the MatMul.
2777 | auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 1, 0); |
2778 | auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 2, 1); |
2779 | |
2780 | Placeholder *LHS = mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.4, 0, |
2781 | "LHS" , /* isTrainable */ false); |
2782 | Placeholder *RHS = mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.3, 0, |
2783 | "RHS" , /* isTrainable */ false); |
2784 | |
2785 | MatMulNode *MMN = F_->createMatMul("matmul" , opOutTy, LHS, RHS); |
2786 | RescaleQuantizedNode *rescaleMMN = |
2787 | F_->createRescaleQuantized("rsMMN" , MMN, rescaleOutTy); |
2788 | SaveNode *saveMMN = F_->createSave("saveMMN" , rescaleMMN); |
2789 | |
2790 | // MatMul, Rescale, Save. |
2791 | EXPECT_EQ(F_->getNodes().size(), 3); |
2792 | |
// The Rescale must be fused into the MatMul above.
2794 | ::glow::optimize(F_, CompilationMode::Infer); |
2795 | |
2796 | // Rescale merged up into the MatMul. |
2797 | EXPECT_EQ(F_->getNodes().size(), 2); |
2798 | |
2799 | MatMulNode *newMMN = llvm::dyn_cast<MatMulNode>(saveMMN->getInput()); |
2800 | ASSERT_TRUE(newMMN); |
2801 | EXPECT_EQ(newMMN->getResult().getType(), rescaleOutTy); |
2802 | } |
2803 | |
2804 | /// Check that the Rescale(SparseLengthsWeightedSum) -> |
2805 | /// SparseLengthsWeightedSum' optimization works correctly. |
2806 | TEST_F(GraphOptz, FuseRescaleUpIntoSparseLengthsWeightedSum) { |
// This test ensures that the Rescale is fused into the
// SparseLengthsWeightedSum.
2808 | TypeRef rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 1); |
2809 | |
2810 | Placeholder *data = |
2811 | mod_.createPlaceholder(ElemKind::Int8QTy, {3}, 0.5, 0, "data" , |
2812 | /* isTrainable */ false); |
2813 | Placeholder *weights = mod_.createPlaceholder( |
2814 | ElemKind::Int8QTy, {8}, 0.5, 0, "weights" , /* isTrainable */ false); |
2815 | Placeholder *indices = |
2816 | mod_.createPlaceholder(ElemKind::Int64ITy, {8}, "indices" , |
2817 | /* isTrainable */ false); |
2818 | Placeholder *lengths = |
2819 | mod_.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths" , |
2820 | /* isTrainable */ false); |
2821 | |
2822 | SparseLengthsWeightedSumNode *SLWS = F_->createSparseLengthsWeightedSum( |
2823 | "SLWS" , data, weights, indices, lengths); |
2824 | RescaleQuantizedNode *rescaleSLWS = |
2825 | F_->createRescaleQuantized("rsSLWS" , SLWS, rescaleOutTy); |
2826 | SaveNode *saveSLWS = F_->createSave("saveSLWS" , rescaleSLWS); |
2827 | |
2828 | // SparseLengthsWeightedSum, Rescale, Save. |
2829 | EXPECT_EQ(F_->getNodes().size(), 3); |
2830 | |
// The Rescale must be fused into the SparseLengthsWeightedSum above.
2832 | ::glow::optimize(F_, CompilationMode::Infer); |
2833 | |
2834 | // Rescale merged up into the SparseLengthsWeightedSum. |
2835 | EXPECT_EQ(F_->getNodes().size(), 2); |
2836 | |
2837 | SparseLengthsWeightedSumNode *newSLWS = |
2838 | llvm::dyn_cast<SparseLengthsWeightedSumNode>(saveSLWS->getInput()); |
2839 | ASSERT_TRUE(newSLWS); |
2840 | EXPECT_EQ(newSLWS->getResult().getType(), rescaleOutTy); |
2841 | } |
2842 | |
2843 | TEST_F(GraphOptz, fuseRescaleIntoConv) { |
// This test ensures that the Rescales are fused into the Convolution.
2845 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {1, 10, 20, 3}, 0.5, |
2846 | 10, "input" , true); |
2847 | auto *filter = mod_.createPlaceholder(ElemKind::Int8QTy, {16, 5, 5, 3}, 0.5, |
2848 | 10, "filter" , true); |
2849 | auto *bias = |
2850 | mod_.createPlaceholder(ElemKind::Int8QTy, {16}, 0.5, 10, "bias" , true); |
2851 | |
2852 | auto *rInput = F_->createRescaleQuantized( |
2853 | "rescale" , input, |
2854 | mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 3}, 0.1, -25)); |
2855 | auto *rFilter = F_->createRescaleQuantized( |
2856 | "rescale" , filter, |
2857 | mod_.uniqueType(ElemKind::Int8QTy, {16, 5, 5, 3}, 0.2, 0)); |
2858 | auto *rBias = F_->createRescaleQuantized( |
2859 | "rescale" , bias, mod_.uniqueType(ElemKind::Int8QTy, {16}, 0.3, 25)); |
2860 | auto *CV = F_->createConv( |
2861 | "conv" , rInput, rFilter, rBias, |
2862 | mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 16}, 0.7, -3), 5, 1, 2, 1); |
2863 | auto *rCV = F_->createRescaleQuantized( |
2864 | "rescale" , CV, |
2865 | mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 16}, 0.4, 37)); |
2866 | F_->createSave("save" , rCV); |
2867 | |
// All four Rescales must be fused into the Convolution, leaving only the
// Convolution and the Save.
2869 | EXPECT_EQ(F_->getNodes().size(), 6); |
2870 | ::glow::optimize(F_, CompilationMode::Infer); |
2871 | EXPECT_EQ(F_->getNodes().size(), 2); |
2872 | } |
2873 | |
/// This test ensures that if there is a Pad node as input of a Convolution
/// node, the Pad gets merged into the Convolution.
/// Note that the Pad is merged into the convolution only when it is
/// compatible with the convolution padding:
2878 | /// - Resulting padding after merge is positive |
2879 | /// - Padding only concerns spatial dimensions |
2880 | /// - Padding has mode 'constant' with value 0.f |
2881 | void fusePadIntoConvTest(glow::Module &mod_, glow::Function *F_, |
2882 | llvm::ArrayRef<dim_t> inputDims, |
2883 | llvm::ArrayRef<int> pads, unsigned_t convKernelSize, |
2884 | llvm::ArrayRef<unsigned_t> convPads, |
2885 | unsigned_t convStride, unsigned_t convNumKernels) { |
2886 | auto *input = |
2887 | mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input" , true); |
2888 | |
2889 | // Pad |
2890 | dim_t inputWithPadDims[4]; |
2891 | for (int i = 0; i < 4; i++) { |
2892 | inputWithPadDims[i] = dim_t(ssize_t(inputDims[i]) + pads[i] + pads[4 + i]); |
2893 | } |
2894 | dim_t outputConvDims[4] = { |
2895 | inputWithPadDims[0], |
2896 | inputWithPadDims[1] + convPads[0] + convPads[2] - (convKernelSize - 1), |
2897 | inputWithPadDims[2] + convPads[1] + convPads[3] - (convKernelSize - 1), |
2898 | convNumKernels}; |
2899 | |
2900 | auto outTy = mod_.uniqueType(ElemKind::FloatTy, inputWithPadDims); |
2901 | Node *P = |
2902 | F_->createPad("pad" , input, outTy, PaddingMode::CONSTANT, pads, 0.f); |
2903 | |
2904 | // Convolution |
2905 | dim_t filterDims[] = {convNumKernels, convKernelSize, convKernelSize, |
2906 | inputDims[3]}; |
2907 | auto *F = |
2908 | mod_.createPlaceholder(ElemKind::FloatTy, filterDims, "filter" , true); |
2909 | auto *B = |
2910 | mod_.createPlaceholder(ElemKind::FloatTy, {convNumKernels}, "bias" , true); |
2911 | auto *CV = F_->createConv( |
2912 | "conv" , P, F, B, mod_.uniqueType(ElemKind::FloatTy, outputConvDims), |
2913 | {convKernelSize, convKernelSize}, {convStride, convStride}, convPads, 1); |
2914 | |
2915 | SaveNode *O = F_->createSave("save" , CV); |
2916 | |
2917 | // The pad node must be merged into convolution. |
2918 | EXPECT_EQ(F_->getNodes().size(), 3); |
2919 | ::glow::optimize(F_, CompilationMode::Infer); |
2920 | EXPECT_EQ(F_->getNodes().size(), 2); |
2921 | |
2922 | // Check the graph structure and additional properties after optimization. |
2923 | auto *conv = llvm::dyn_cast<ConvolutionNode>(O->getInput()); |
2924 | ASSERT_NE(conv, nullptr); |
2925 | EXPECT_EQ(conv->getResult().dims(), llvm::ArrayRef<dim_t>(outputConvDims)); |
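// The 8-element pads array holds {N, H, W, C} 'before' values followed by
// {N, H, W, C} 'after' values, while the convolution pads are
// {top, left, bottom, right}; the indexing below adds the matching H/W
// components of the two.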
2926 | unsigned_t expectedPads[4]; |
2927 | for (int i = 0; i < 2; i++) { |
2928 | for (int j = 0; j < 2; j++) { |
2929 | expectedPads[2 * i + j] = |
2930 | unsigned_t(int(convPads[2 * i + j]) + pads[4 * i + (1 + j)]); |
2931 | } |
2932 | } |
2933 | EXPECT_EQ(conv->getPads(), llvm::makeArrayRef(expectedPads)); |
2934 | } |
2935 | |
2936 | TEST_F(GraphOptz, fusePadIntoConv) { |
2937 | fusePadIntoConvTest(mod_, F_, {1, 6, 14, 3} /* inputDims */, |
2938 | {0, 1, 2, 0, 0, 3, 4, 0} /* pads */, |
2939 | 5 /* convKernelSize */, {0, 0, 0, 0} /* convPads */, |
2940 | 1 /* convStride */, 16 /* convNumKernels */); |
2941 | } |
2942 | |
2943 | TEST_F(GraphOptz, fusePadIntoConvNeg1) { |
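// Resulting pads: top 3 + (-1) = 2, left 0 + 2 = 2, bottom 2 + 3 = 5,
// right 5 + (-2) = 3; all non-negative, so the fold still applies.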
2944 | fusePadIntoConvTest(mod_, F_, {1, 6, 14, 3} /* inputDims */, |
2945 | {0, -1, 2, 0, 0, 3, -2, 0} /* pads */, |
2946 | 5 /* convKernelSize */, {3, 0, 2, 5} /* convPads */, |
2947 | 1 /* convStride */, 16 /* convNumKernels */); |
2948 | } |
2949 | |
2950 | TEST_F(GraphOptz, fusePadIntoConvNeg2) { |
2951 | fusePadIntoConvTest(mod_, F_, {1, 6, 14, 3} /* inputDims */, |
2952 | {0, 1, -2, 0, 0, -3, 4, 0} /* pads */, |
2953 | 5 /* convKernelSize */, {0, 2, 5, 7} /* convPads */, |
2954 | 1 /* convStride */, 16 /* convNumKernels */); |
2955 | } |
2956 | |
/// This test checks that a lowered LeakyRelu is correctly folded:
/// Max(A, Mul(A, Splat)) -> PRelu(Splat)
/// For alpha in (0, 1), max(x, alpha * x) is x for x >= 0 and alpha * x
/// otherwise, which is exactly LeakyRelu with slope alpha.
2959 | TEST_F(GraphFold, foldLeakyReluFromSplat) { |
2960 | std::vector<dim_t> dims = {5, 2}; |
2961 | |
2962 | auto *input = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input" , true); |
2963 | |
2964 | const float leakyAlpha = 0.05f; |
2965 | auto OutTy = mod_.uniqueType(ElemKind::FloatTy, dims); |
2966 | SplatNode *splatNode = F_->createSplat("splat" , OutTy, leakyAlpha); |
2967 | MulNode *mulNode = F_->createMul("mul" , input, splatNode); |
2968 | MaxNode *maxNode = F_->createMax("max" , input, mulNode); |
2969 | SaveNode *output = F_->createSave("save" , maxNode); |
2970 | |
2971 | EXPECT_EQ(4, F_->getNodes().size()); |
2972 | |
2973 | CompilationContext cctx; |
2974 | ::glow::fold(F_, cctx); |
2975 | |
2976 | // Check the resulting graph after folding. |
2977 | EXPECT_EQ(3, F_->getNodes().size()); |
2978 | auto *newPReluNode = llvm::dyn_cast<PReluNode>(output->getInput()); |
2979 | ASSERT_TRUE(newPReluNode); |
2980 | auto *newSplatNode = llvm::dyn_cast<SplatNode>(newPReluNode->getSlope()); |
2981 | ASSERT_TRUE(newSplatNode); |
2982 | EXPECT_EQ(leakyAlpha, newSplatNode->getValue()); |
2983 | EXPECT_EQ(input, newPReluNode->getInput()); |
2984 | } |
2985 | |
/// This test checks that a lowered LeakyRelu is correctly folded:
/// Max(A, Mul(A, broadcasted Const)) -> PRelu(Splat)
2988 | TEST_F(GraphFold, foldLeakyReluFromConst) { |
2989 | std::vector<dim_t> dims = {5, 2}; |
2990 | auto *input = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input" , true); |
2991 | |
2992 | const float leakyAlpha = 0.99f; |
2993 | auto *alphaConst = mod_.createConstant(ElemKind::FloatTy, {1}, "alphaConst" ); |
2994 | alphaConst->getHandle() = {leakyAlpha}; |
2995 | ReshapeNode *reshapeNode = F_->createReshape("reshape" , alphaConst, {1, 1}); |
2996 | TileNode *tileNode1 = F_->createTile("tile1" , reshapeNode, 2, 1); |
2997 | TileNode *tileNode2 = F_->createTile("tile2" , tileNode1, 5, 0); |
2998 | MulNode *mulNode = F_->createMul("mul" , input, tileNode2); |
2999 | MaxNode *maxNode = F_->createMax("max" , input, mulNode); |
3000 | SaveNode *output = F_->createSave("save" , maxNode); |
3001 | |
3002 | EXPECT_EQ(6, F_->getNodes().size()); |
3003 | |
3004 | CompilationContext cctx; |
3005 | ::glow::fold(F_, cctx); |
3006 | |
3007 | // Check the resulting graph after folding. Reshape must have been merged into |
3008 | // the constant and LeakyRelu must have been folded. |
3009 | EXPECT_EQ(3, F_->getNodes().size()); |
3010 | auto *newPReluNode = llvm::dyn_cast<PReluNode>(output->getInput()); |
3011 | ASSERT_TRUE(newPReluNode); |
3012 | auto *newSplatNode = llvm::dyn_cast<SplatNode>(newPReluNode->getSlope()); |
3013 | ASSERT_TRUE(newSplatNode); |
3014 | EXPECT_EQ(leakyAlpha, newSplatNode->getValue()); |
3015 | EXPECT_EQ(input, newPReluNode->getInput()); |
3016 | } |
3017 | |
3018 | /// Test optimization of Convolution nodes with small input tensors by reducing |
3019 | /// filters and removing redundant padding. |
3020 | TEST_F(GraphFold, optimizeSmallConv) { |
3021 | auto *input = |
3022 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 2, 16}, "input" , true); |
3023 | auto filter = |
3024 | mod_.createConstant(ElemKind::FloatTy, {16, 5, 5, 16}, "filter" ); |
3025 | auto bias = mod_.createConstant(ElemKind::FloatTy, {16}, "bias" ); |
3026 | |
3027 | filter->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
3028 | bias->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
3029 | |
3030 | auto *outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 1, 1, 16}); |
3031 | auto *CN = F_->createConv("conv" , input, filter, bias, outTy, {5, 5}, {2, 2}, |
3032 | {2, 1, 1, 2}, 1); |
3033 | auto *save = F_->createSave("save" , CN); |
3034 | |
3035 | EXPECT_EQ(2, F_->getNodes().size()); |
3036 | optimizedF_ = optimizeFunctionForTest(F_); |
3037 | EXPECT_EQ(2, optimizedF_->getNodes().size()); |
3038 | |
3039 | const auto *optSave = |
3040 | findFunctionNodeByName<SaveNode>(optimizedF_, save->getName()); |
3041 | |
3042 | auto *newCN = llvm::dyn_cast<ConvolutionNode>(optSave->getInput()); |
3043 | ASSERT_TRUE(newCN); |
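// With a 2x2 input and pads {2, 1, 1, 2}, only a 2x2 window of the 5x5
// filter ever overlaps real (non-padded) data, so the filter shrinks to
// 2x2 and the padding and stride become redundant.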
3044 | // Kernel should be reduced. |
3045 | EXPECT_TRUE(isUniformArray(newCN->getKernels(), 2u)); |
3046 | // Padding should be removed. |
3047 | EXPECT_TRUE(isUniformArray(newCN->getPads(), 0u)); |
3048 | // Stride should be canonicalized to 1. |
3049 | EXPECT_TRUE(isUniformArray(newCN->getStrides(), 1u)); |
3050 | |
3051 | bindings_.allocate(mod_.getPlaceholders()); |
3052 | bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
3053 | checkNumericalEquivalence(); |
3054 | } |
3055 | |
3056 | TEST_F(GraphOptz, GatherToSliceOpt) { |
3057 | auto *LHS = mod_.createPlaceholder(ElemKind::Int32ITy, {16, 3}, "LHS" , false); |
3058 | auto *RHS = mod_.createConstant(ElemKind::Int32ITy, {}, "RHS" ); |
3059 | RHS->getPayloadMutable().getHandle<int32_t>() = {1}; |
3060 | |
3061 | auto *gather = F_->createGather("gather" , LHS, RHS, 1); |
3062 | auto *save = F_->createSave("save" , gather); |
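// Gathering a single scalar index along dim 1 of the {16, 3} input selects
// one column; the optimizer should express this as a Slice followed by a
// Reshape down to {16}.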
3063 | |
3064 | optimizedF_ = optimizeFunctionForTest(F_); |
3065 | |
3066 | auto *saveOpt = |
3067 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName())); |
3068 | ASSERT_TRUE(saveOpt); |
3069 | auto *reshapeN = llvm::dyn_cast<ReshapeNode>(saveOpt->getInput()); |
3070 | ASSERT_TRUE(reshapeN); |
3071 | EXPECT_EQ(reshapeN->getResult().dims().size(), 1); |
3072 | EXPECT_EQ(reshapeN->getResult().dims()[0], 16); |
3073 | |
3074 | bindings_.allocate(LHS)->getHandle<int32_t>().randomize(-128, 127, |
3075 | mod_.getPRNG()); |
3076 | checkNumericalEquivalence(); |
3077 | } |
3078 | |
3079 | /// Fold a Convolution dilated manually using Transpose, SpaceToDepth and |
3080 | /// DepthToSpace nodes into a single Convolution node. Pattern: |
3081 | /// NHWC2CHWN -> S2D -> CHWN2NHWC -> Conv -> NHWC2CHWN -> D2S -> CHWN2NHWC |
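/// Roughly, the transposes and SpaceToDepth stack the 2x2 block phases of
/// the image along the batch dimension; a dense Conv over those phases is
/// equivalent to a dilation-2 Conv on the original image, so the whole
/// pattern folds into a single Conv with dilation 2.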
3082 | TEST_F(GraphFold, foldDilatedConv) { |
3083 | auto *input = |
3084 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 10, 16}, "input" , true); |
3085 | |
3086 | auto *T1 = F_->createTranspose("t1" , input, NHWC2CHWN, "NHWC" ); |
3087 | auto *S2D = F_->createSpaceToDepth("s2d" , T1, 2); |
3088 | auto *T2 = F_->createTranspose("t2" , S2D, CHWN2NHWC, "NHWC" ); |
3089 | auto *CN = F_->createConv(bindings_, "conv" , T2, 16, 3, 1, 0, 16, {1, 1}); |
3090 | auto *T3 = F_->createTranspose("t3" , CN, NHWC2CHWN, "NHWC" ); |
3091 | auto *D2S = F_->createDepthToSpace("d2s" , T3, 2); |
3092 | auto *T4 = F_->createTranspose("t4" , D2S, CHWN2NHWC, "NHWC" ); |
3093 | auto *save = F_->createSave("save" , T4); |
3094 | |
3095 | // To spice things up, add additional users for some nodes. The pattern should |
3096 | // still be recognized. |
3097 | F_->createSave("save_t1" , T1); |
3098 | F_->createSave("save_s2d" , S2D); |
3099 | F_->createSave("save_t2" , T2); |
3100 | |
3101 | EXPECT_EQ(13, F_->getNodes().size()); |
3102 | optimizedF_ = optimizeFunctionForTest(F_); |
3103 | EXPECT_EQ(8, optimizedF_->getNodes().size()); |
3104 | |
3105 | const auto *optSave = |
3106 | findFunctionNodeByName<SaveNode>(optimizedF_, save->getName()); |
3107 | |
3108 | auto *newCN = llvm::dyn_cast<ConvolutionNode>(optSave->getInput()); |
3109 | ASSERT_TRUE(newCN); |
3110 | EXPECT_TRUE(isUniformArray(newCN->getDilation(), 2u)); |
3111 | |
3112 | bindings_.allocate(mod_.getPlaceholders()); |
3113 | bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
3114 | checkNumericalEquivalence(); |
3115 | } |
3116 | |
3117 | /// Fold a Convolution dilated manually using Transpose, SpaceToDepth and |
3118 | /// DepthToSpace nodes into a single Convolution node. Pattern: |
3119 | /// NHWC2CHWN -> S2D -> CHWN2NHWC -> Conv -> NHWC2CHWN -> D2S -> CHWN2NHWC |
3120 | /// Test for ChannelwiseQuantizedConvolution. |
3121 | TEST_F(GraphFold, foldDilatedConv_ChannelwiseQuantized) { |
3122 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {1, 10, 10, 16}, 1.f, |
3123 | 0, "input" , true); |
3124 | |
3125 | auto *filterF = |
3126 | mod_.createConstant(ElemKind::FloatTy, {16, 3, 3, 16}, "filterF" ); |
3127 | filterF->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0, |
3128 | mod_.getPRNG()); |
3129 | auto *biasF = mod_.createConstant(ElemKind::FloatTy, {16}, "biasF" ); |
3130 | biasF->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0, |
3131 | mod_.getPRNG()); |
3132 | |
3133 | auto *T1 = F_->createTranspose("t1" , input, NHWC2CHWN, "NHWC" ); |
3134 | auto *S2D = F_->createSpaceToDepth("s2d" , T1, 2); |
3135 | auto *T2 = F_->createTranspose("t2" , S2D, CHWN2NHWC, "NHWC" ); |
3136 | auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {4, 3, 3, 16}, 1.f, 0); |
3137 | auto *CN = F_->createChannelwiseQuantizedConv( |
3138 | "conv" , T2, filterF, biasF, nullptr, nullptr, nullptr, nullptr, outTy, |
3139 | {3, 3}, {1, 1}, {0, 0, 0, 0}, 1, {1, 1}, true, true, |
3140 | quantization::Schema::Asymmetric, ElemKind::Int8QTy, ElemKind::Int32QTy); |
3141 | auto *T3 = F_->createTranspose("t3" , CN, NHWC2CHWN, "NHWC" ); |
3142 | auto *D2S = F_->createDepthToSpace("d2s" , T3, 2); |
3143 | auto *T4 = F_->createTranspose("t4" , D2S, CHWN2NHWC, "NHWC" ); |
3144 | auto *save = F_->createSave("save" , T4); |
3145 | |
3146 | EXPECT_EQ(10, F_->getNodes().size()); |
3147 | optimizedF_ = optimizeFunctionForTest(F_); |
3148 | EXPECT_EQ(2, optimizedF_->getNodes().size()); |
3149 | |
3150 | const auto *optSave = |
3151 | findFunctionNodeByName<SaveNode>(optimizedF_, save->getName()); |
3152 | |
3153 | auto *newCN = |
3154 | llvm::dyn_cast<ChannelwiseQuantizedConvolutionNode>(optSave->getInput()); |
3155 | ASSERT_TRUE(newCN); |
3156 | EXPECT_TRUE(isUniformArray(newCN->getDilation(), 2u)); |
3157 | |
3158 | bindings_.allocate(mod_.getPlaceholders()); |
3159 | bindings_.get(input)->getHandle<int8_t>().randomize(-128, 127, |
3160 | mod_.getPRNG()); |
3161 | checkNumericalEquivalence(); |
3162 | } |
3163 | |
3164 | /// Testing folding of Reshape->Transpose->Reshape into ChannelShuffle. |
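/// Here {3, 136, 28, 28} is reshaped to {3, 4, 34, 28, 28}, the two split
/// dimensions are transposed, and the result is flattened back, which is a
/// ChannelShuffle with group 4 along axis (kernel) 1.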
3165 | TEST_F(GraphFold, foldChannelShuffle) { |
3166 | const dim_t inputDims[] = {3, 136, 28, 28}; |
3167 | |
3168 | Node *K = |
3169 | mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input" , false); |
3170 | K = F_->createReshape("CS_reshape1" , K, {3, 4, 34, 28, 28}); |
3171 | K = F_->createTranspose("CS_transpose" , K, {0, 2, 1, 3, 4}); |
3172 | K = F_->createReshape("CS_reshape2" , K, {3, 136, 28, 28}); |
3173 | auto *save = F_->createSave("ret" , K); |
3174 | |
3175 | EXPECT_EQ(F_->getNodes().size(), 4); |
3176 | |
3177 | // Fold RN->TR->RN into ChannelShuffle |
3178 | CompilationContext cctx; |
3179 | ::glow::fold(F_, cctx); |
3180 | |
3181 | ASSERT_EQ(F_->getNodes().size(), 2); |
3182 | |
3183 | // Check for ChannelShuffle node. |
3184 | auto *CS = llvm::dyn_cast<ChannelShuffleNode>(save->getInput().getNode()); |
3185 | ASSERT_NE(nullptr, CS); |
3186 | |
3187 | // Ensure ChannelShuffle node has the same dimensions as the input. |
3188 | EXPECT_EQ(CS->getResult().dims(), llvm::makeArrayRef(inputDims)); |
3189 | |
3190 | // Ensure Group and Kernel are as expected. |
3191 | EXPECT_EQ(CS->getGroup(), 4); |
3192 | EXPECT_EQ(CS->getKernel(), 1); |
3193 | } |
3194 | |
3195 | TEST_F(GraphFold, NoFoldChannelShuffle) { |
3196 | auto Float = ElemKind::FloatTy; |
3197 | auto *P = mod_.createPlaceholder(Float, {10, 8928}, "P" , false); |
3198 | auto *R1 = F_->createReshape("R1" , P, {10, 186, 48}); |
3199 | auto *TR = F_->createTranspose("TR" , R1, {0, 2, 1}); |
3200 | auto *R2 = F_->createReshape("R2" , TR, {480, 186}); |
3201 | auto *save = F_->createSave("save" , R2); |
3202 | |
3203 | EXPECT_EQ(F_->getNodes().size(), 4); |
3204 | |
3205 | CompilationContext cctx; |
3206 | ::glow::fold(F_, cctx); |
3207 | |
3208 | EXPECT_EQ(F_->getNodes().size(), 4); |
3209 | EXPECT_FALSE(llvm::isa<ChannelShuffleNode>(save->getInput())); |
3210 | } |
3211 | |
3212 | class MockBackendWithFusion : public MockBackend { |
3213 | bool supportsFusedActivation(Node *parent, Node *activation) const override { |
3214 | switch (parent->getKind()) { |
3215 | case Kinded::Kind::ConvolutionNodeKind: |
3216 | switch (activation->getKind()) { |
3217 | case Kinded::Kind::ReluNodeKind: |
3218 | case Kinded::Kind::ClipNodeKind: |
3219 | case Kinded::Kind::SigmoidNodeKind: |
3220 | case Kinded::Kind::TanhNodeKind: |
3221 | case Kinded::Kind::LeakyReluNodeKind: |
3222 | return true; |
3223 | default: |
3224 | return false; |
3225 | } |
3226 | default: |
3227 | return false; |
3228 | } |
3229 | } |
3230 | }; |
3231 | |
3232 | #define CONV_ACTIVATION_TEST(ACTIVATION_, CREATOR_, ...) \ |
3233 | TEST_F(GraphFold, FoldConv##ACTIVATION_##Activation) { \ |
3234 | auto *A = \ |
3235 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false); \ |
3236 | ConvolutionNode *CV = \ |
3237 | F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1); \ |
3238 | auto *AN = F_->CREATOR_(__VA_ARGS__); \ |
3239 | SaveNode *SN = F_->createSave("ret", AN); \ |
3240 | \ |
3241 | EXPECT_EQ(F_->getNodes().size(), 3); \ |
3242 | \ |
3243 | CompilationContext cctx; \ |
3244 | auto B = MockBackendWithFusion(); \ |
3245 | ::glow::fold(F_, cctx, &B); \ |
3246 | \ |
3247 | ConvolutionNode *fusedCV = \ |
3248 | llvm::dyn_cast<ConvolutionNode>(SN->getInput()); \ |
3249 | ASSERT_TRUE(fusedCV); \ |
3250 | EXPECT_EQ(fusedCV->getFusedActivation(), FusedActivation::ACTIVATION_); \ |
3251 | } |
3252 | |
3253 | CONV_ACTIVATION_TEST(RELU, createRELU, "Relu" , CV); |
3254 | CONV_ACTIVATION_TEST(CLIP, createClip, "Clip" , CV, 0.0, 1.0); |
3255 | CONV_ACTIVATION_TEST(SIGMOID, createSigmoid, "Sigmoid" , CV); |
3256 | CONV_ACTIVATION_TEST(TANH, createTanh, "Tanh" , CV); |
3257 | CONV_ACTIVATION_TEST(LEAKY_RELU, createLeakyRELU, "LeakyRelu" , CV, 1.0); |
3258 | |
3259 | #undef CONV_ACTIVATION_TEST |
3260 | |
/// This test ensures that if there is a RescaleNode whose input has multiple
/// users, the input is not cloned, as cloning would duplicate the node.
3263 | TEST_F(GraphOptz, MultipleUsersRescaleCombineNoOpt) { |
3264 | auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 1, 0); |
3265 | auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 2, 1); |
3266 | |
3267 | Node *LHS = |
3268 | mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.4, 0, "LHS" , true); |
3269 | Node *RHS = |
3270 | mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.3, 0, "RHS" , true); |
3271 | |
3272 | AddNode *AN = F_->createAdd("qAdd" , opOutTy, LHS, RHS); |
3273 | RescaleQuantizedNode *RQN = |
3274 | F_->createRescaleQuantized("rsAdd" , AN, rescaleOutTy); |
3275 | SaveNode *saveRQN = F_->createSave("saveRQN" , RQN); |
3276 | SaveNode *saveAN = F_->createSave("saveAN" , AN); |
3277 | |
3278 | EXPECT_EQ(F_->getNodes().size(), 4); |
3279 | |
3280 | ::glow::optimize(F_, CompilationMode::Infer); |
3281 | |
3282 | // The graph should be unchanged. |
3283 | EXPECT_EQ(F_->getNodes().size(), 4); |
3284 | EXPECT_EQ(saveRQN->getInput().getNode(), RQN); |
3285 | EXPECT_EQ(RQN->getInput().getNode(), AN); |
3286 | EXPECT_EQ(saveAN->getInput().getNode(), AN); |
3287 | EXPECT_EQ(AN->getLHS().getNode(), LHS); |
3288 | EXPECT_EQ(AN->getRHS().getNode(), RHS); |
3289 | } |
3290 | |
3291 | /// This test ensures that fusing of rescale into MatMul is done. |
3292 | TEST_F(GraphOptz, FuseRescaleIntoMatMul) { |
3293 | auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 1, 0); |
3294 | auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 2, 1); |
3295 | |
3296 | Placeholder *LHS = |
3297 | mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.4, 0, "LHS" , true); |
3298 | Placeholder *RHS = |
3299 | mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.3, 0, "RHS" , true); |
3300 | |
3301 | RescaleQuantizedNode *LHSR = |
3302 | F_->createRescaleQuantized("rs1" , LHS, rescaleOutTy); |
3303 | RescaleQuantizedNode *RHSR = |
3304 | F_->createRescaleQuantized("rs2" , RHS, rescaleOutTy); |
3305 | MatMulNode *MN = F_->createMatMul("qMatMul" , opOutTy, LHSR, RHSR); |
3306 | SaveNode *SN = F_->createSave("save" , MN); |
3307 | |
3308 | // All rescales must be fused into arithmetic operations above. |
3309 | ::glow::optimize(F_, CompilationMode::Infer); |
3310 | |
3311 | // Only the MatMul and Save should be left. |
3312 | EXPECT_EQ(F_->getNodes().size(), 2); |
3313 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::RescaleQuantizedNodeKind), 0); |
3314 | |
3315 | MatMulNode *newMN = llvm::dyn_cast<MatMulNode>(SN->getInput()); |
3316 | ASSERT_TRUE(newMN); |
3317 | Placeholder *LPH = llvm::dyn_cast<Placeholder>(newMN->getLHS()); |
3318 | EXPECT_EQ(LPH, LHS); |
3319 | Placeholder *RPH = llvm::dyn_cast<Placeholder>(newMN->getRHS()); |
3320 | EXPECT_EQ(RPH, RHS); |
3321 | } |
3322 | |
3323 | TEST_F(GraphOptz, sinkRescaledQuantizedNode) { |
3324 | // Check that we eliminate rescale nodes by sinking them into other |
3325 | // operators. |
3326 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11, |
3327 | "input" , true); |
3328 | |
3329 | // slice -> rescale -> reshape -> rescale -> transpose -> maxpool -> save. |
3330 | auto *slice = F_->createSlice("slice" , input, {0, 0}, {2, 4}); |
3331 | auto *rescale = F_->createRescaleQuantized( |
3332 | "rescale" , slice, mod_.uniqueType(ElemKind::Int8QTy, {2, 4}, 0.4, 10)); |
3333 | auto *reshape = F_->createReshape("reshape" , rescale, {1, 2, 2, 2}); |
3334 | auto *rescale2 = F_->createRescaleQuantized( |
3335 | "rescale" , reshape, |
3336 | mod_.uniqueType(ElemKind::Int8QTy, {1, 2, 2, 2}, 0.3, 9)); |
3337 | auto *transpose = F_->createTranspose("transpose" , rescale2, {0, 2, 3, 1}); |
3338 | auto *maxpool = |
3339 | F_->createMaxPool("maxpool" , transpose, {2, 2}, {1, 1}, {0, 0, 0, 0}); |
3340 | auto *save = F_->createSave("ret" , maxpool->getResult()); |
3341 | |
3342 | EXPECT_EQ(F_->getNodes().size(), 7); |
3343 | ::glow::optimize(F_, CompilationMode::Infer); |
3344 | EXPECT_EQ(F_->getNodes().size(), 6); |
3345 | // Check that rescale sank all the way down to the save node. |
3346 | EXPECT_TRUE(llvm::dyn_cast<RescaleQuantizedNode>(save->getInput())); |
3347 | } |
3348 | |
3349 | TEST_F(GraphOptz, mergeRescaleWithArithmeticNode) { |
3350 | // Check that Arithmetic operations can be merged with the Rescale. |
3351 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11, |
3352 | "input" , true); |
3353 | |
3354 | auto *rescale1 = F_->createRescaleQuantized( |
3355 | "rescale" , input, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.4, 11)); |
3356 | auto *add = F_->createAdd("add" , rescale1, rescale1); |
3357 | auto *rescale2 = F_->createRescaleQuantized( |
3358 | "rescale" , add, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.3, 11)); |
3359 | auto *sub = F_->createSub("sub" , rescale2, rescale2); |
3360 | auto *rescale3 = F_->createRescaleQuantized( |
3361 | "rescale" , sub, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.2, 11)); |
3362 | auto *mul = F_->createMul("mul" , rescale3, rescale3); |
3363 | auto *rescale4 = F_->createRescaleQuantized( |
3364 | "rescale" , mul, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.1, 11)); |
3365 | auto *div = F_->createDiv("div" , rescale4, rescale4); |
3366 | F_->createSave("save" , div); |
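// Each arithmetic node can absorb the rescaled types of its operands and
// result, so only Add, Sub, Mul, Div and the Save should remain.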
3367 | |
3368 | EXPECT_EQ(F_->getNodes().size(), 9); |
3369 | ::glow::optimize(F_, CompilationMode::Infer); |
3370 | EXPECT_EQ(F_->getNodes().size(), 5); |
3371 | } |
3372 | |
3373 | /// Check that Relu can be merged with Rescale. |
3374 | TEST_F(GraphOptz, mergeRescaleWithRelu) { |
3375 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11, |
3376 | "input" , false); |
3377 | |
3378 | auto *rescale1 = F_->createRescaleQuantized( |
3379 | "rescale" , input, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.4, 11)); |
3380 | auto *relu = F_->createRELU("relu" , rescale1); |
3381 | F_->createSave("save" , relu); |
3382 | |
3383 | // Rescale, RELU, Save nodes. |
3384 | EXPECT_EQ(F_->getNodes().size(), 3); |
3385 | |
3386 | ::glow::optimize(F_, CompilationMode::Infer); |
3387 | |
3388 | // RELU, Save nodes left; Rescale merged into RELU. |
3389 | EXPECT_EQ(F_->getNodes().size(), 2); |
3390 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::RescaleQuantizedNodeKind), 0); |
3391 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 1); |
3392 | } |
3393 | |
3394 | // Check that we are able to merge some small matmuls into a larger one. |
3395 | TEST_F(GraphOptz, mergeMatMulNodes) { |
3396 | Node *input = |
3397 | mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input" , true); |
3398 | Node *weight = |
3399 | mod_.createPlaceholder(ElemKind::FloatTy, {10, 10}, "weight" , true); |
3400 | |
3401 | // Split the input to a bunch of small slices. |
3402 | std::vector<NodeValue> inputs; |
3403 | for (dim_t i = 0; i < 10; i++) { |
3404 | auto *K = F_->createSlice("extract" , input, {i, 0, 0}, {i + 1, 10, 10}); |
3405 | auto *R = F_->createReshape("reshape" , K, {10, 10}); |
3406 | auto *MM = F_->createMatMul("mm" , R, weight); |
3407 | inputs.push_back(MM); |
3408 | } |
3409 | |
3410 | auto *cc = F_->createConcat("merge" , inputs, 0); |
3411 | F_->createSave("save" , cc); |
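// All ten MatMuls share the same RHS, so their {10, 10} LHS operands can
// be concatenated into a single {100, 10} matrix and multiplied once.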
3412 | |
3413 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 10); |
3414 | ::glow::optimize(F_, CompilationMode::Infer); |
3415 | |
3416 | // Check that all of the matmuls are merged into a single matmul node. |
3417 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 1); |
3418 | } |
3419 | |
3420 | // Check that we are able to merge batched adds. |
3421 | TEST_F(GraphOptz, mergeBANodes) { |
3422 | Node *input = |
3423 | mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input" , true); |
3424 | Node *slice = |
3425 | mod_.createPlaceholder(ElemKind::FloatTy, {10, 10}, "weight" , true); |
3426 | |
3427 | // Split the input to a bunch of small slices. |
3428 | std::vector<NodeValue> inputs; |
3429 | for (dim_t i = 0; i < 10; i++) { |
3430 | auto *K = F_->createSlice("extract" , input, {i, 0, 0}, {i + 1, 10, 10}); |
3431 | auto *MM = F_->createBatchedAdd("BA" , K, slice); |
3432 | inputs.push_back(MM); |
3433 | } |
3434 | |
3435 | auto *cc = F_->createConcat("merge" , inputs, 0); |
3436 | F_->createSave("save" , cc); |
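// All ten BatchedAdds share the same slice operand and their batch inputs
// are consecutive slices of the same tensor, so they can be merged into a
// single BatchedAdd over the full {10, 10, 10} batch.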
3437 | |
3438 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 10); |
3439 | ::glow::optimize(F_, CompilationMode::Infer); |
3440 | |
3441 | // Check that all of the batched-adds are merged into a single node. |
3442 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 1); |
3443 | } |
3444 | |
3445 | /// Check that EliminateNoop optimization pass removes nodes which don't do |
3446 | /// anything useful. |
3447 | TEST_F(GraphOptz, eliminateNoop) { |
3448 | std::vector<dim_t> shape = {1, 2, 2, 3}; |
3449 | Placeholder *input1 = mod_.createPlaceholder(ElemKind::Int8QTy, shape, 0.004, |
3450 | 0, "input" , false); |
3451 | Placeholder *input2 = mod_.createPlaceholder(ElemKind::Int8QTy, shape, 0.004, |
3452 | 0, "input" , false); |
3453 | auto *cond = mod_.createConstant(ElemKind::BoolTy, shape, "input1" ); |
3454 | cond->getHandle<bool>() = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; |
3455 | |
3456 | auto *select = F_->createSelect("select" , cond, input1, input2); |
3457 | auto *slice = F_->createSlice("slice" , select, {0, 0, 0, 0}, shape); |
3458 | auto *tile = F_->createTile("tile" , slice, 1, 1); |
3459 | auto *pad = F_->createPad("pad" , tile, tile->getResult().getType(), 0, |
3460 | {0, 0, 0, 0, 0, 0, 0, 0}, 0); |
3461 | auto *avgPool = F_->createAvgPool("avgpool" , pad, 1, 1, 0); |
3462 | auto *maxPool = F_->createMaxPool("maxpool" , avgPool, 1, 1, 0); |
3463 | |
3464 | F_->createSave("save" , maxPool->getResult()); |
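// Every node above is an identity: the Select condition is all-true, the
// Slice spans the whole tensor, the Tile count is 1, the Pad adds nothing,
// and the 1x1 stride-1 pools with no padding are pass-throughs.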
3465 | |
3466 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SelectNodeKind), 1); |
3467 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 1); |
3468 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TileNodeKind), 1); |
3469 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::PadNodeKind), 1); |
3470 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AvgPoolNodeKind), 1); |
3471 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MaxPoolNodeKind), 1); |
3472 | |
3473 | optimizedF_ = optimizeFunctionForTest(F_); |
3474 | |
3475 | // Check that all nodes except for Save are eliminated. |
3476 | EXPECT_EQ(optimizedF_->getNodes().size(), 1); |
3477 | |
3478 | bindings_.allocate(mod_.getPlaceholders()); |
3479 | bindings_.get(input1)->getHandle<int8_t>().randomize(-1.0, 1.0, |
3480 | mod_.getPRNG()); |
3481 | bindings_.get(input2)->getHandle<int8_t>().randomize(-1.0, 1.0, |
3482 | mod_.getPRNG()); |
3483 | |
3484 | checkNumericalEquivalence(); |
3485 | } |
3486 | |
3487 | // Check that we are able to replace |
3488 | // Add(I, tile(B)) with -> BatchedAdd(I, B). |
3489 | TEST_F(GraphOptz, FoldTileAddIntoBatchedAdd) { |
3490 | auto *batch = |
3491 | mod_.createPlaceholder(ElemKind::FloatTy, {3, 1, 2}, "batch" , false); |
3492 | auto *added = mod_.createConstant(ElemKind::FloatTy, {1, 1, 2}, "added" ); |
3493 | auto *addedTiled = F_->createTile("addedTiled" , added, 3, 0); |
3494 | auto *add = F_->createAdd("add" , batch, addedTiled); |
3495 | auto *save = F_->createSave("save" , add); |
3496 | auto *output = save->getPlaceholder(); |
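// Add(batch, Tile(added)) broadcasts the {1, 1, 2} constant across the
// batch dimension, which is exactly the semantics of
// BatchedAdd(batch, added).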
3497 | |
3498 | bindings_.allocate(batch)->getHandle() = {2, 2, 3, 3, 4, 4}; |
3499 | added->getPayloadMutable().getHandle() = {1, 1}; |
3500 | bindings_.allocate(output); |
3501 | |
3502 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TileNodeKind), 1); |
3503 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 1); |
3504 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 0); |
3505 | |
3506 | ASSERT_TRUE(F_->verify()); |
3507 | |
3508 | // Currently the FoldTileAddIntoBatchedAdd opt which we're testing here is not |
3509 | // part of the default optimization pipeline. Create a local version of the |
3510 | // pipeline with that pass included. |
3511 | auto p = createDefaultGraphOptimizationPassPipeline(); |
3512 | p->pushFront({FunctionPassID::FoldTileAddIntoBatchedAdd}); |
3513 | FunctionPassManager FPM("opt" , std::move(p)); |
3514 | FPM.run(F_, CompilationContext()); |
3515 | ASSERT_TRUE(F_->verify()); |
3516 | |
// Check that the Tile node and the Add node are replaced by
// a BatchedAdd node.
3519 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TileNodeKind), 0); |
3520 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 0); |
3521 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 1); |
3522 | |
// Verify the correctness of the inputs to the BatchedAdd operator. The
// correctness of the BatchedAdd operator itself is verified by the
// operator's unit tests.
3526 | Tensor expectedBatch(ElemKind::FloatTy, {3, 1, 2}); |
3527 | expectedBatch.getHandle() = {2, 2, 3, 3, 4, 4}; |
3528 | Tensor expectedSlice(ElemKind::FloatTy, {1, 2}); |
3529 | expectedSlice.getHandle() = {1, 1}; |
3530 | for (auto &node : F_->getNodes()) { |
3531 | auto *recvdBANode = llvm::dyn_cast<BatchedAddNode>(&node); |
3532 | if (!recvdBANode) { |
3533 | continue; |
3534 | } |
3535 | auto *recvdBatch = llvm::dyn_cast<Placeholder>(recvdBANode->getBatch()); |
3536 | ASSERT_TRUE(recvdBatch); |
3537 | auto *recvdSlice = llvm::dyn_cast<Constant>(recvdBANode->getSlice()); |
3538 | ASSERT_TRUE(recvdSlice); |
3539 | EXPECT_TRUE(recvdBatch->dims().equals({3, 1, 2})); |
3540 | EXPECT_TRUE(recvdSlice->dims().equals({1, 2})); |
3541 | EXPECT_TRUE(bindings_.get(recvdBatch)->isEqual(expectedBatch)); |
3542 | EXPECT_TRUE(recvdSlice->getPayload().isEqual(expectedSlice)); |
3543 | break; |
3544 | } |
3545 | } |
3546 | |
3547 | /// Test Concat(Slice, ..., Slice) opt works correctly. If \p reverseOrder then |
3548 | /// the optimization is inapplicable and should not occur. |
3549 | static void testConcatElim(Module &mod, Function *F, Function *&optimizedF, |
3550 | PlaceholderBindings &bindings, bool reverseOrder) { |
3551 | Placeholder *input = |
3552 | mod.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input" , true); |
3553 | bindings.allocate(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG()); |
3554 | |
3555 | // Split the input to a bunch of small slices. |
3556 | std::array<NodeValue, 10> inputs; |
3557 | for (dim_t i = 0; i < 10; i++) { |
3558 | dim_t idx = reverseOrder ? 9 - i : i; |
3559 | inputs[i] = |
3560 | F->createSlice("extract" , input, {idx, 0, 0}, {idx + 1, 10, 10}); |
3561 | } |
3562 | |
3563 | auto *cc = F->createConcat("merge" , inputs, 0); |
3564 | F->createSave("save" , cc); |
3565 | |
3566 | EXPECT_EQ(countNodeKind(F, Kinded::Kind::SliceNodeKind), 10); |
3567 | |
3568 | optimizedF = optimizeFunctionForTest(F); |
3569 | |
3570 | // Check that either the concat and slices are gone if the optimization was |
3571 | // applicable, or otherwise that they're still there. |
3572 | EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::ConcatNodeKind), |
3573 | reverseOrder ? 1 : 0); |
3574 | EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::SliceNodeKind), |
3575 | reverseOrder ? 10 : 0); |
3576 | } |
3577 | |
3578 | // Check that we are able to eliminate concat nodes. |
3579 | TEST_F(GraphOptz, concatElim) { |
3580 | testConcatElim(mod_, F_, optimizedF_, bindings_, /* reverseOrder */ false); |
3581 | checkNumericalEquivalence(0.0f); |
3582 | } |
3583 | |
3584 | // Check that when the order of the Slices is reversed no optimization kicks in. |
3585 | TEST_F(GraphOptz, concatElimReverseOrder) { |
3586 | testConcatElim(mod_, F_, optimizedF_, bindings_, /* reverseOrder */ true); |
3587 | checkNumericalEquivalence(0.0f); |
3588 | } |
3589 | |
/// Check that we are able to eliminate concat nodes with redundant arithmetic
/// ops in the way.
3592 | TEST_F(GraphOptz, concatArithElim) { |
3593 | auto *input = |
3594 | mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input" , true); |
3595 | bindings_.allocate(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
3596 | |
3597 | Type t(ElemKind::FloatTy, {1, 10, 10}); |
3598 | Node *one = F_->createSplat("one" , &t, 1.0); |
3599 | Node *zero = F_->createSplat("zero" , &t, 0.0); |
3600 | |
3601 | // Split the input to a bunch of small slices. |
3602 | std::vector<NodeValue> inputs; |
3603 | for (dim_t i = 0; i < 10; i++) { |
3604 | auto *K = F_->createSlice("extract" , input, {i, 0, 0}, {i + 1, 10, 10}); |
// Interpose identity arithmetic (add/sub of zero, mul/div by one) between
// some slices and the concat; the optimization must see through these
// no-ops.
3607 | Node *N = K; |
3608 | switch (i) { |
3609 | case 0: |
3610 | N = F_->createAdd("add0" , K, zero); |
3611 | break; |
3612 | case 1: |
3613 | N = F_->createSub("sub0" , K, zero); |
3614 | break; |
3615 | case 2: |
3616 | N = F_->createAdd("add_0" , zero, K); |
3617 | break; |
3618 | case 3: |
3619 | N = F_->createMul("mul1" , K, one); |
3620 | break; |
3621 | case 4: |
3622 | N = F_->createDiv("div1" , K, one); |
3623 | break; |
3624 | case 5: |
3625 | N = F_->createMul("mul_1" , one, K); |
3626 | break; |
3627 | default: |
3628 | break; |
3629 | } |
3630 | inputs.push_back(N); |
3631 | } |
3632 | |
3633 | auto *cc = F_->createConcat("merge" , inputs, 0); |
3634 | F_->createSave("save" , cc); |
3635 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 10); |
3636 | optimizedF_ = optimizeFunctionForTest(F_); |
3637 | |
3638 | // Check that the concat node is gone. |
3639 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 0); |
3640 | checkNumericalEquivalence(0.0f); |
3641 | } |
3642 | |
3643 | /// Check that we are able to eliminate concat followed by slices on axis |
3644 | /// \p dim under certain conditions. |
3645 | static void testConcatSliceElim(Module &mod, Function *F, Function *&optimizedF, |
3646 | PlaceholderBindings &bindings, size_t dim) { |
3647 | constexpr size_t N = 5; |
3648 | std::array<NodeValue, N> inputs; |
3649 | std::vector<dim_t> inShape = {10, 20}; |
3650 | inShape.insert(inShape.begin() + dim, 0); |
3651 | for (dim_t i = 0; i < N; i++) { |
3652 | inShape[dim] = 1 + i; |
3653 | auto *P = mod.createPlaceholder(ElemKind::FloatTy, inShape, "in" , true); |
3654 | bindings.allocate(P)->getHandle().randomize(-1.0, 1.0, mod.getPRNG()); |
3655 | inputs[i] = P; |
3656 | } |
3657 | auto *CN = F->createConcat("merge" , inputs, dim); |
3658 | |
3659 | // Split the concat to a bunch of slices of the same shape as the concat |
3660 | // inputs and on the same axis. |
3661 | std::vector<dim_t> startShape = {0, 0, 0}; |
3662 | std::vector<dim_t> endShape = {10, 20}; |
3663 | endShape.insert(endShape.begin() + dim, 0); |
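// Input i has extent i + 1 along 'dim', so inside the concat it occupies
// the range [i * (i + 1) / 2, (i + 1) * (i + 2) / 2); slicing exactly on
// these boundaries lets each Slice(Concat) collapse back to its input.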
3664 | for (dim_t i = 0; i < N; i++) { |
3665 | startShape[dim] = (i * (i + 1)) / 2; |
3666 | endShape[dim] = ((i + 1) * (i + 2)) / 2; |
3667 | auto *SN = F->createSlice("extract" , CN, startShape, endShape); |
3668 | F->createSave("save" , SN); |
3669 | } |
3670 | |
3671 | // We created a concat followed by N slices of its results. |
3672 | EXPECT_EQ(countNodeKind(F, Kinded::Kind::SliceNodeKind), N); |
3673 | EXPECT_EQ(countNodeKind(F, Kinded::Kind::ConcatNodeKind), 1); |
3674 | |
3675 | optimizedF = optimizeFunctionForTest(F); |
3676 | |
3677 | // Check that the concat and slices are gone. |
3678 | EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::ConcatNodeKind), 0); |
3679 | EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::SliceNodeKind), 0); |
3680 | } |
3681 | |
3682 | TEST_F(GraphOptz, concatSliceElimInnerDim) { |
3683 | testConcatSliceElim(mod_, F_, optimizedF_, bindings_, 0); |
3684 | checkNumericalEquivalence(0.0f); |
3685 | } |
3686 | |
3687 | TEST_F(GraphOptz, concatSliceElimMiddleDim) { |
3688 | testConcatSliceElim(mod_, F_, optimizedF_, bindings_, 1); |
3689 | checkNumericalEquivalence(0.0f); |
3690 | } |
3691 | |
3692 | TEST_F(GraphOptz, concatSliceElimOuterDim) { |
3693 | testConcatSliceElim(mod_, F_, optimizedF_, bindings_, 2); |
3694 | checkNumericalEquivalence(0.0f); |
3695 | } |
3696 | |
/// Check the interaction between Slices(Concat) and Concat(Slices)
/// optimizations to make sure they work nicely together. Builds
/// Concat(Slices(Concat)) and expects a single Concat after optimizations.
3700 | TEST_F(GraphOptz, concatSliceElimMultiConcat) { |
3701 | std::array<NodeValue, 4> inputs; |
3702 | for (size_t i = 0; i < 4; i++) { |
3703 | auto *P = mod_.createPlaceholder(ElemKind::FloatTy, {2, 4}, |
3704 | "in_" + std::to_string(i), false); |
3705 | bindings_.allocate(P)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
3706 | inputs[i] = P; |
3707 | } |
3708 | auto *CN0 = F_->createConcat("merge0" , inputs, /* axis */ 1); |
3709 | |
3710 | auto *SN0 = F_->createSlice("slice0" , CN0, {0, 0}, {2, 4}); |
3711 | auto *SN1 = F_->createSlice("slice1" , CN0, {0, 4}, {2, 8}); |
3712 | auto *SN2 = F_->createSlice("slice2" , CN0, {0, 8}, {2, 12}); |
3713 | auto *SN3 = F_->createSlice("slice3" , CN0, {0, 12}, {2, 16}); |
3714 | |
3715 | auto *CN1 = F_->createConcat("merge1" , {SN1, SN0, SN3, SN2}, /* axis */ 1); |
3716 | F_->createSave("save" , CN1); |
3717 | |
3718 | // We created a concat followed by 4 slices of its results followed by another |
3719 | // concat. |
3720 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConcatNodeKind), 2); |
3721 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 4); |
3722 | |
3723 | optimizedF_ = optimizeFunctionForTest(F_); |
3724 | |
3725 | // Check that one concat and slices are gone. |
3726 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1); |
3727 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SliceNodeKind), 0); |
3728 | |
3729 | checkNumericalEquivalence(0.0f); |
3730 | } |
3731 | |
3732 | // Check the transformation Concat(Reshape(x) * N) -> Reshape(Concat(x * N)). |
3733 | TEST_F(GraphOptz, concatReshapes) { |
3734 | const dim_t shape1[] = {2, 5, 2, 1, 20}; |
3735 | const dim_t shape2[] = {10, 2, 2, 10}; |
3736 | const dim_t shape3[] = {5, 80}; |
3737 | llvm::SmallVector<NodeValue, 10> inputs1; |
3738 | llvm::SmallVector<NodeValue, 10> inputs2; |
3739 | for (size_t i = 0; i < 10; i++) { |
// 10 reshape nodes that transform from {2,5,2,1,20} to {10,2,2,10}, and a
// ConcatNode that concatenates the reshape outputs on dimension 1. The
// optimization kicks in, as the size of the trailing dimensions of the
// original ConcatNode (before opt) is 20, and the size of its leading
// dimensions (before opt) is 10.
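// Concretely: {2,5,2,1,20} factors around its dimension 2 as leading
// 2*5 = 10 and trailing 1*20 = 20, matching the concat's leading size 10
// and trailing size 2*10 = 20, so the concat can be hoisted to that axis.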
3745 | Node *var = F_->getParent()->createPlaceholder( |
3746 | ElemKind::FloatTy, shape1, "input" + std::to_string(i), true); |
3747 | auto *RN = F_->createReshape("reshape" + std::to_string(i), var, shape2); |
3748 | inputs1.push_back(RN); |
3749 | } |
3750 | auto *concatNode1 = F_->createConcat("concat" , inputs1, 1); |
3751 | for (size_t i = 0; i < 10; i++) { |
// 10 reshape nodes that transform from {5,80} to {10,2,2,10}, and a
// ConcatNode that concatenates the reshape outputs on dimension 1. The
// optimization would NOT kick in, as no dimension of the original shape
// gives the same leading/trailing sizes as the original concat node.
3757 | Node *var = F_->getParent()->createPlaceholder( |
3758 | ElemKind::FloatTy, shape3, "input" + std::to_string(i), true); |
3759 | auto *RN = F_->createReshape("reshape" + std::to_string(i), var, shape2); |
3760 | inputs2.push_back(RN); |
3761 | } |
3762 | auto *concatNode2 = F_->createConcat("concat" , inputs2, 1); |
3763 | auto outputShape = concatNode1->getResult().dims(); |
// Clear the references held in the input vectors; otherwise the user count
// of those nodes would always be positive, preventing DCE from removing
// them.
3766 | inputs1.clear(); |
3767 | inputs2.clear(); |
3768 | |
3769 | auto *addNode = F_->createAdd("add" , concatNode1, concatNode2); |
3770 | auto *O = F_->createSave("ret" , addNode); |
3771 | |
3772 | EXPECT_EQ(F_->getNodes().size(), 24); |
3773 | |
3774 | ::glow::optimize(F_, CompilationMode::Infer); |
3775 | |
3776 | // After optimization, we expect to see only 15 nodes. All 10 of the |
3777 | // reshapes that were the inputs to the first original concat node |
3778 | // (concatNode1) are removed, and a single new reshape is added after the |
3779 | // new concat. |
3780 | EXPECT_EQ(F_->getNodes().size(), 15); |
3781 | |
3782 | // concatNode1 should not exist any more. |
3783 | EXPECT_FALSE(functionContainsNode(F_, concatNode1)); |
3784 | // concatNode2 should still exist. |
3785 | EXPECT_TRUE(functionContainsNode(F_, concatNode2)); |
3786 | |
// The first input of addNode should now be a Reshape node, with the same
// result shape as concatNode1.
3789 | auto *newAddNode = llvm::dyn_cast<AddNode>(O->getInput()); |
3790 | ASSERT_TRUE(newAddNode); |
3791 | auto *newRN = llvm::dyn_cast<ReshapeNode>(newAddNode->getLHS()); |
3792 | ASSERT_TRUE(newRN); |
3793 | EXPECT_TRUE(newRN->getResult().getType()->dims().equals(outputShape)); |
3794 | |
3795 | // The input of newRN should be a ConcatNode now. |
3796 | auto *newCN = llvm::dyn_cast<ConcatNode>(newRN->getInput()); |
3797 | ASSERT_TRUE(newCN); |
3798 | } |
3799 | |
// Make sure we do not try to optimize concat2(dim1, concat1(dim2, X, Y), Z)
// -> concat(dim1, X, Y, Z) when concat1 has multiple users.
3803 | TEST_F(GraphOptz, ConcatSimplificationNegative) { |
3804 | const dim_t dim1[] = {1, 4, 4, 4}; |
3805 | const dim_t dim2[] = {1, 4, 4, 8}; |
3806 | auto *in1 = mod_.createPlaceholder(ElemKind::FloatTy, dim1, "in1" , false); |
3807 | auto *in2 = mod_.createPlaceholder(ElemKind::FloatTy, dim1, "in2" , false); |
3808 | auto *in3 = mod_.createPlaceholder(ElemKind::FloatTy, dim2, "in3" , false); |
3809 | |
3810 | auto *cnc1 = F_->createConcat("cnc1" , {in1, in2}, 3); |
3811 | auto *add1 = F_->createAdd("add1" , in3, cnc1); |
3812 | auto *cnc2 = F_->createConcat("cnc2" , {add1, cnc1}, 3); |
3813 | F_->createSave("ret" , cnc2); |
3814 | EXPECT_EQ(F_->getNodes().size(), 4); |
3815 | ::glow::optimize(F_, CompilationMode::Infer); |
3816 | EXPECT_EQ(F_->getNodes().size(), 4); |
3817 | for (auto &n : F_->getNodes()) { |
3818 | if (auto *tcnc = llvm::dyn_cast<ConcatNode>(&n)) { |
3819 | EXPECT_EQ(tcnc->getNumInputs(), 2); |
3820 | } |
3821 | } |
3822 | } |
3823 | |
3824 | /// Check that Variable CSE works correctly, combining small Variables that |
3825 | /// have the same data. |
3826 | TEST_F(GraphOptz, VarsCSE) { |
// Create three variables that are Private, are not trainable, and have no
// writers. The first two variables have the same data, and so should be
// combined via variable CSE. The third variable differs in the last value,
// and so should not be combined.
3831 | auto *input1 = mod_.createConstant(ElemKind::FloatTy, {10}, "input1" ); |
3832 | auto *input2 = mod_.createConstant(ElemKind::FloatTy, {10}, "input2" ); |
3833 | auto *input3 = mod_.createConstant(ElemKind::FloatTy, {10}, "input3" ); |
3834 | input1->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; |
3835 | input2->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; |
3836 | input3->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, -1}; |
3837 | |
3838 | // Input them each to different nodes, so node CSE does not change them. |
3839 | auto *TN = F_->createTanh("tanh" , input1); |
3840 | auto *SN = F_->createSigmoid("sigmoid" , input2); |
3841 | auto *RN = F_->createRELU("relu" , input3); |
3842 | auto *CN = F_->createConcat("concat" , {TN, SN, RN}, /* axis */ 0); |
3843 | F_->createSave("ret" , CN); |
3844 | |
3845 | // Initially there are three variables: inputs 1, 2, and 3 (the save uses a |
3846 | // placeholder). |
3847 | EXPECT_EQ(mod_.getConstants().size(), 3); |
3848 | |
3849 | cctx_.compMode = CompilationMode::Infer; |
3850 | // Do not perform any compile-time constant folding. |
3851 | cctx_.optimizationOpts.enableConstantFolding = false; |
3852 | ::glow::optimize(F_, cctx_); |
3853 | |
3854 | // Now only two variables are left; input1 and input2 have been combined, |
3855 | // but input3 has not. |
3856 | EXPECT_EQ(mod_.getConstants().size(), 2); |
3857 | |
3858 | // Verify that only one of input1 and input2 exists, and that input3 still |
3859 | // exists. |
3860 | Constant *varOneOrTwo = nullptr; |
3861 | bool foundVarThree = false; |
3862 | for (auto *V : mod_.getConstants()) { |
3863 | if (V == input1 || V == input2) { |
3864 | EXPECT_TRUE(varOneOrTwo == nullptr); |
3865 | varOneOrTwo = V; |
3866 | } else if (V == input3) { |
3867 | foundVarThree = true; |
3868 | } |
3869 | } |
3870 | EXPECT_TRUE(varOneOrTwo != nullptr); |
3871 | EXPECT_TRUE(foundVarThree); |
3872 | |
3873 | // Verify that the users of the inputs are updated correctly. |
3874 | EXPECT_TRUE(TN->getInput().getNode() == varOneOrTwo); |
3875 | EXPECT_TRUE(SN->getInput().getNode() == varOneOrTwo); |
3876 | EXPECT_TRUE(RN->getInput().getNode() == input3); |
3877 | |
3878 | // Verify that whichever input1/input2 is left over has two users TN and SN. |
3879 | EXPECT_TRUE(varOneOrTwo->getUsers().size() == 2); |
3880 | for (auto &U : varOneOrTwo->getUsers()) { |
3881 | auto *N = U.getUser(); |
3882 | EXPECT_TRUE(N == TN || N == SN); |
3883 | } |
3884 | |
3885 | // Verify that input3 only has a single user RN. |
3886 | ASSERT_TRUE(input3->getUsers().size() == 1); |
3887 | EXPECT_TRUE(input3->getUsers().begin()->getUser() == RN); |
3888 | } |
3889 | |
3890 | TEST_F(GraphOptz, VarsCSENaN) { |
// Create two variables that are Private, are not trainable, have no writers,
// and include NaNs. The two variables have the same data, and so should be
// combined via variable CSE. In particular, the NaN values should not
// prevent the variables from being combined.
3895 | auto *input1 = mod_.createConstant(ElemKind::FloatTy, {5}, "input1" ); |
3896 | auto *input2 = mod_.createConstant(ElemKind::FloatTy, {5}, "input2" ); |
3897 | input1->getHandle() = {0, NAN, 2, NAN, 4}; |
3898 | input2->getHandle() = {0, NAN, 2, NAN, 4}; |
3899 | |
3900 | // Input them each to different nodes, so node CSE does not change them. |
3901 | auto *TN = F_->createTanh("tanh" , input1); |
3902 | auto *SN = F_->createSigmoid("sigmoid" , input2); |
3903 | auto *CN = F_->createConcat("concat" , {TN, SN}, /* axis */ 0); |
3904 | F_->createSave("ret" , CN); |
3905 | |
3906 | // Initially there are two variables: inputs 1 and 2 (the save uses a |
3907 | // placeholder). |
3908 | EXPECT_EQ(mod_.getConstants().size(), 2); |
3909 | |
3910 | cctx_.compMode = CompilationMode::Infer; |
3911 | // Do not perform any compile-time constant folding. |
3912 | cctx_.optimizationOpts.enableConstantFolding = false; |
3913 | ::glow::optimize(F_, cctx_); |
3914 | |
// Now only one variable is left; input1 and input2 have been combined.
3916 | EXPECT_EQ(mod_.getConstants().size(), 1); |
3917 | |
3918 | // Verify that only one of input1 and input2 exists. |
3919 | Constant *varOneOrTwo = nullptr; |
3920 | for (auto *V : mod_.getConstants()) { |
3921 | if (V == input1 || V == input2) { |
3922 | EXPECT_TRUE(varOneOrTwo == nullptr); |
3923 | varOneOrTwo = V; |
3924 | } |
3925 | } |
3926 | EXPECT_TRUE(varOneOrTwo != nullptr); |
3927 | |
3928 | // Verify that the users of the inputs are updated correctly. |
3929 | EXPECT_TRUE(TN->getInput().getNode() == varOneOrTwo); |
3930 | EXPECT_TRUE(SN->getInput().getNode() == varOneOrTwo); |
3931 | |
3932 | // Verify that whichever input1/input2 is left over has two users TN and SN. |
3933 | EXPECT_TRUE(varOneOrTwo->getUsers().size() == 2); |
3934 | for (auto &U : varOneOrTwo->getUsers()) { |
3935 | auto *N = U.getUser(); |
3936 | EXPECT_TRUE(N == TN || N == SN); |
3937 | } |
3938 | } |
3939 | |
3940 | // Verify that constant input canonicalization works correctly when the |
3941 | // arithmetic nodes have multiple users. |
3942 | TEST_F(GraphOptz, simplifyArithmeticMultipleUsers) { |
3943 | Node *I1 = |
3944 | mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input1" , false); |
3945 | |
3946 | Type t(ElemKind::FloatTy, {10, 10, 10}); |
3947 | Node *SN = F_->createSplat("one" , &t, 1.0); |
3948 | |
3949 | // The splat is a constant input to add1 and add2, and is their LHS input. |
3950 | // We expect canonicalization to occur during optimization, moving the splat |
3951 | // to the RHS for both. Note that add1 has multiple users: add2 and save1. |
3952 | Node *AN1 = F_->createAdd("add1" , SN, I1); |
3953 | Node *AN2 = F_->createAdd("add2" , SN, AN1); |
3954 | SaveNode *SN1 = F_->createSave("save1" , AN1); |
3955 | SaveNode *SN2 = F_->createSave("save2" , AN2); |
3956 | |
3957 | // Five nodes in total: one splat, two adds, and two saves. |
3958 | EXPECT_EQ(F_->getNodes().size(), 5); |
3959 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SplatNodeKind), 1); |
3960 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 2); |
3961 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2); |
3962 | |
3963 | // input1 has a single user before optimization. |
3964 | EXPECT_EQ(I1->getUsers().size(), 1); |
3965 | |
3966 | // Simplify nodes will canonicalize add1 and add2, and should replace all |
3967 | // their users, without otherwise adding new nodes to the graph/changing the |
3968 | // overall structure. |
3969 | ::glow::optimize(F_, CompilationMode::Infer); |
3970 | |
3971 | // We should have the same five nodes: one splat, two adds, and two saves. |
3972 | EXPECT_EQ(F_->getNodes().size(), 5); |
3973 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SplatNodeKind), 1); |
3974 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 2); |
3975 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2); |
3976 | |
3977 | // Verify that both add nodes were canonicalized, and that the graph's shape |
3978 | // is the same as prior to optimization other than canonicalization. |
3979 | AddNode *newAN1 = llvm::dyn_cast<AddNode>(SN1->getInput().getNode()); |
3980 | ASSERT_TRUE(newAN1 != nullptr); |
3981 | EXPECT_TRUE(llvm::isa<Placeholder>(newAN1->getLHS())); |
3982 | EXPECT_TRUE(llvm::isa<SplatNode>(newAN1->getRHS())); |
3983 | |
3984 | AddNode *newAN2 = llvm::dyn_cast<AddNode>(SN2->getInput().getNode()); |
3985 | ASSERT_TRUE(newAN2 != nullptr); |
3986 | EXPECT_TRUE(llvm::isa<AddNode>(newAN2->getLHS())); |
3987 | EXPECT_TRUE(llvm::isa<SplatNode>(newAN2->getRHS())); |
3988 | |
3989 | EXPECT_EQ(newAN1, newAN2->getLHS()); |
3990 | |
3991 | // input1 should still have a single user after optimization. |
3992 | EXPECT_EQ(I1->getUsers().size(), 1); |
3993 | } |
3994 | |
3995 | /// Test that a concat with a single input is replaced by the input. |
3996 | TEST_F(GraphOptz, eliminateSingleConcat) { |
3997 | Node *input = mod_.createPlaceholder(ElemKind::FloatTy, {10}, "input" , false); |
3998 | |
3999 | ConcatNode *CN = F_->createConcat("concat1" , {input}, 0); |
4000 | SaveNode *SN = F_->createSave("ret" , CN); |
4001 | |
4002 | // The ConcatNode and SaveNode. |
4003 | EXPECT_EQ(F_->getNodes().size(), 2); |
4004 | |
4005 | ::glow::optimize(F_, CompilationMode::Infer); |
4006 | |
4007 | // Just the SaveNode should be left. |
4008 | EXPECT_EQ(F_->getNodes().size(), 1); |
4009 | ASSERT_TRUE(functionContainsNode(F_, SN)); |
4010 | |
4011 | // Save node should just save the input. |
4012 | EXPECT_TRUE(SN->getInput().getNode() == input); |
4013 | } |
4014 | |
4015 | /// Test that a reshape of a private variable with one use has the reshape |
4016 | /// merged into the variable. |
4017 | TEST_F(GraphOptz, ReshapeConstantOneUse) { |
4018 | const dim_t shape[] = {10, 20}; |
4019 | const dim_t reshape1[] = {200, 1}; |
4020 | const dim_t reshape2[] = {200}; |
4021 | Constant *input = |
4022 | F_->getParent()->createConstant(ElemKind::FloatTy, shape, "input" ); |
4023 | input->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
4024 | |
4025 | auto *R1 = F_->createReshape("reshape1" , input, reshape1); |
4026 | auto *R2 = F_->createReshape("reshape2" , R1, reshape2); |
4027 | auto *O = F_->createSave("ret" , R2); |
4028 | |
4029 | // Before optimization, we have 2 Reshapes and a Save. |
4030 | EXPECT_EQ(F_->getNodes().size(), 3); |
4031 | |
4032 | // Skip ConstantFolding as it would have the same result as this opt. |
4033 | cctx_.optimizationOpts.enableConstantFolding = false; |
4034 | ::glow::optimize(F_, cctx_); |
4035 | |
4036 | // After optimization, we expect to see just a Save. |
4037 | EXPECT_EQ(F_->getNodes().size(), 1); |
4038 | |
4039 | // Save should have the new Variable as input. |
4040 | auto *V = llvm::dyn_cast<Constant>(O->getInput()); |
4041 | ASSERT_TRUE(V); |
4042 | // The new Variable should have the same shape as the original second |
4043 | // Reshape. |
4044 | EXPECT_TRUE(V->getType()->dims().equals(reshape2)); |
4045 | } |
4046 | |
4047 | /// Test that reshape node is merged into Constant in a sequence |
4048 | /// Reshape(Quantize(Constant)). |
4049 | TEST_F(GraphOptz, ReshapeQuantizeConstant) { |
4050 | const dim_t shape[] = {10, 20}; |
4051 | const dim_t newShape[] = {200, 1}; |
4052 | |
4053 | auto *qTy = mod_.uniqueType(ElemKind::Int8QTy, shape, 0.2, 0); |
4054 | |
4055 | auto *input = |
4056 | F_->getParent()->createConstant(ElemKind::FloatTy, shape, "input" ); |
4057 | auto *Q = F_->createQuantize("quantize" , input, qTy); |
4058 | auto *R = F_->createReshape("reshape" , Q, newShape); |
4059 | auto *S = F_->createSave("ret" , R); |
4060 | |
4061 | // Skip ConstantFolding as it would have the same result as this opt. |
4062 | CompilationContext cctx; |
4063 | cctx.optimizationOpts.enableConstantFolding = false; |
4064 | |
4065 | EXPECT_EQ(F_->getNodes().size(), 3); |
4066 | ::glow::optimize(F_, cctx); |
4067 | EXPECT_EQ(F_->getNodes().size(), 2); |
4068 | |
4069 | // Constant and Quantize should have new shape. |
4070 | auto *newQ = llvm::dyn_cast<QuantizeNode>(S->getInput()); |
4071 | ASSERT_TRUE(newQ); |
4072 | EXPECT_TRUE(newQ->getResult().dims().equals(newShape)); |
4073 | auto *newC = llvm::dyn_cast<Constant>(newQ->getInput()); |
4074 | ASSERT_TRUE(newC); |
4075 | EXPECT_TRUE(newC->getType()->dims().equals(newShape)); |
4076 | } |
4077 | |
4078 | /// Test that Transpose is optimized into Reshape when it moves no data. |
4079 | TEST_F(GraphOptz, transposeIntoReshapeOptim) { |
4080 | auto *batch = |
4081 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 3, 2, 4}, "batch" , false); |
4082 | Node *T = F_->createTranspose("transpose" , batch, {1, 2, 0, 3}); |
4083 | SaveNode *O = F_->createSave("ret" , T); |
4084 | |
4085 | EXPECT_EQ(F_->getNodes().size(), 2); |
4086 | |
4087 | ::glow::optimize(F_, CompilationMode::Infer); |
4088 | EXPECT_EQ(F_->getNodes().size(), 2); |
4089 | |
// TransposeNode is optimized into ReshapeNode.
4091 | auto *reshape = llvm::dyn_cast<ReshapeNode>(O->getInput().getNode()); |
4092 | ASSERT_NE(reshape, nullptr); |
4093 | } |
4094 | |
4095 | /// Test that transpose is merged into matmul. |
4096 | TEST_F(GraphOptz, mergeTransposeIntoMatMul) { |
4097 | auto *input = |
4098 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 2, 3}, "input" , false); |
4099 | auto *weights = |
4100 | F_->getParent()->createConstant(ElemKind::FloatTy, {6, 1}, "weights" ); |
4101 | |
4102 | weights->getHandle() = {0, 1, 2, 3, 4, 5}; |
4103 | float newWeightsRef[] = {0, 2, 4, 1, 3, 5}; |
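// The NHWC2NCHW transpose followed by the flattening reshape reorders the
// six input elements as [x0, x3, x1, x4, x2, x5]; merging the transpose
// away therefore permutes the weight rows to {0, 2, 4, 1, 3, 5}.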
4104 | |
4105 | auto *TN = F_->createTranspose("transpose" , input, NHWC2NCHW); |
4106 | auto *RN = F_->createReshape("reshape" , TN, {1, 6}); |
4107 | auto *MMN = F_->createMatMul("matmul" , RN, weights); |
4108 | auto *SN = F_->createSave("ret" , MMN); |
4109 | |
4110 | // Transpose + Reshape + MatMul + Save. |
4111 | EXPECT_EQ(F_->getNodes().size(), 4); |
4112 | |
4113 | ::glow::optimize(F_, CompilationMode::Infer); |
4114 | |
4115 | // Reshape + MatMul + Save. |
4116 | EXPECT_EQ(F_->getNodes().size(), 3); |
4117 | |
4118 | // Check reordered weights. |
4119 | auto *newMMN = llvm::dyn_cast<MatMulNode>(SN->getInput()); |
4120 | ASSERT_TRUE(newMMN != nullptr); |
4121 | auto *newW = llvm::dyn_cast<Constant>(newMMN->getRHS()); |
4122 | ASSERT_TRUE(newW != nullptr); |
4123 | for (unsigned i = 0; i < 6; ++i) { |
4124 | EXPECT_EQ(newWeightsRef[i], newW->getHandle().raw(i)); |
4125 | } |
4126 | } |
4127 | |
4128 | /// Test that transpose is merged into FullyConnected. |
4129 | TEST_F(GraphOptz, mergeTransposeIntoFC) { |
4130 | auto *input = |
4131 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 2, 3}, "input" , false); |
4132 | auto *weights = |
4133 | F_->getParent()->createConstant(ElemKind::FloatTy, {6, 1}, "weights" ); |
4134 | auto *bias = F_->getParent()->createConstant(ElemKind::FloatTy, {1}, "bias" ); |
4135 | |
4136 | weights->getHandle() = {0, 1, 2, 3, 4, 5}; |
4137 | float newWeightsRef[] = {0, 2, 4, 1, 3, 5}; |
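// Same situation as in mergeTransposeIntoMatMul above: merging away the
// transpose permutes the weight rows to {0, 2, 4, 1, 3, 5}.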
4138 | |
4139 | auto *TN = F_->createTranspose("transpose" , input, NHWC2NCHW); |
4140 | auto *RN = F_->createReshape("reshape" , TN, {1, 6}); |
4141 | auto *FCN = F_->createFullyConnected("fc" , RN, weights, bias); |
4142 | auto *SN = F_->createSave("ret" , FCN); |
4143 | |
4144 | // Transpose + Reshape + FC + Save. |
4145 | EXPECT_EQ(F_->getNodes().size(), 4); |
4146 | |
4147 | ::glow::optimize(F_, CompilationMode::Infer); |
4148 | |
4149 | // Reshape + FC + Save. |
4150 | EXPECT_EQ(F_->getNodes().size(), 3); |
4151 | |
4152 | // Check reordered weights. |
4153 | auto *newFCN = llvm::dyn_cast<FullyConnectedNode>(SN->getInput()); |
4154 | ASSERT_TRUE(newFCN != nullptr); |
4155 | auto *newW = llvm::dyn_cast<Constant>(newFCN->getWeights()); |
4156 | ASSERT_TRUE(newW != nullptr); |
4157 | for (unsigned i = 0; i < 6; ++i) { |
4158 | EXPECT_EQ(newWeightsRef[i], newW->getHandle().raw(i)); |
4159 | } |
4160 | } |
4161 | |
4162 | TEST_F(GraphOptz, ConvertPlaceholdersToConstants) { |
4163 | auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input1" , true); |
4164 | auto *input2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input2" , true); |
4165 | auto *input3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input3" , true); |
4166 | auto *save1 = F_->createSave("save1" , input1); |
4167 | auto *save2 = F_->createSave("save2" , input2); |
4168 | auto *save3 = F_->createSave("save3" , input3); |
4169 | |
4170 | // No variables, six PHs (3 inputs, 3 saves). |
4171 | EXPECT_EQ(mod_.getConstants().size(), 0); |
4172 | EXPECT_EQ(mod_.getPlaceholders().size(), 6); |
4173 | |
// Allocate two of the three inputs, but mark input2 as non-constant.
4176 | bindings_.allocate(input1); |
4177 | bindings_.allocate(input2); |
4178 | // Don't allocate input3; keep it as a placeholder instead. |
4179 | ::glow::convertPlaceholdersToConstants(F_, bindings_, {input2}); |
4180 | |
4181 | // input1 becomes a variable. |
4182 | EXPECT_EQ(mod_.getConstants().size(), 1); |
4183 | EXPECT_EQ(mod_.getPlaceholders().size(), 6); |
4184 | |
4185 | EXPECT_TRUE(llvm::isa<Constant>(save1->getInput())); |
4186 | EXPECT_TRUE(llvm::isa<Placeholder>(save2->getInput())); |
4187 | EXPECT_TRUE(llvm::isa<Placeholder>(save3->getInput())); |
4188 | } |
4189 | |
4190 | TEST_F(GraphOptz, optimizeConversion_i32_i64_i32) { |
4191 | auto *i32 = mod_.uniqueType(ElemKind::Int32ITy, {1}); |
4192 | auto *i64 = mod_.uniqueType(ElemKind::Int64ITy, {1}); |
4193 | |
4194 | auto *A = mod_.createPlaceholder(i32, "A" , false); |
4195 | auto *B = F_->createConvertTo("B" , A, i64); |
4196 | auto *C = F_->createConvertTo("C" , B, i32); |
4197 | auto *S = F_->createSave("S" , C); |
4198 | |
4199 | ::glow::optimize(F_, CompilationMode::Infer); |
4200 | |
4201 | // All casting is optimized away, only left with Save of Placeholder. |
4202 | EXPECT_EQ(F_->getNodes().size(), 1); |
4203 | EXPECT_TRUE(llvm::isa<Placeholder>(S->getInput())); |
4204 | } |
4205 | |
4206 | TEST_F(GraphOptz, optimizeSameTypeConversions) { |
4207 | auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input1" , true); |
4208 | auto *input2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input2" , true); |
4209 | auto *conv1 = F_->createConvertTo("cast1" , input1, ElemKind::FloatTy); |
4210 | auto *conv2 = F_->createConvertTo("cast2" , input2, ElemKind::Float16Ty); |
4211 | auto *save1 = F_->createSave("save1" , conv1); |
4212 | auto *save2 = F_->createSave("save1" , conv2); |
4213 | |
4214 | // convert_to1 + save1 + convert_to2 + save2 nodes. |
4215 | EXPECT_EQ(F_->getNodes().size(), 4); |
4216 | EXPECT_TRUE(llvm::isa<ConvertToNode>(save1->getInput())); |
4217 | |
4218 | ::glow::optimize(F_, CompilationMode::Infer); |
4219 | |
4220 | // save1 + convert_to2 + save2 nodes. |
4221 | EXPECT_EQ(F_->getNodes().size(), 3); |
4222 | // convert_to1 node should be eliminated, because it converts the node into |
4223 | // the same type. |
4224 | EXPECT_TRUE(llvm::isa<Placeholder>(save1->getInput())); |
// convert_to2 node should not be eliminated, because it converts the node
// into a different type.
4227 | EXPECT_TRUE(llvm::isa<ConvertToNode>(save2->getInput())); |
4228 | EXPECT_EQ(save2->getInput(), NodeValue(conv2)); |
4229 | } |
4230 | |
4231 | TEST_F(GraphOptz, optimizeConvertingBetweenFused) { |
4232 | // Call with dims {5, 2}, which will actually create a constant with {5, 10} |
4233 | // for scale/offset per row. |
4234 | Constant *C = createRandomFusedRowwiseQuantizedConstant( |
4235 | mod_, {5, 2}, "fused" , /* useFusedFP16 */ false); |
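// In fused rowwise quantization each row stores the uint8 data followed by
// a scale and an offset: float (4 + 4 bytes) for UInt8FusedQTy, float16
// (2 + 2 bytes) for UInt8FusedFP16QTy.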
// Converting to fused FP16 halves the scale/offset storage (4 fewer bytes
// per row), so the shape moves from {5, 10} to {5, 6}.
4239 | auto *CN = F_->createConvertTo("convert" , C, newOT); |
4240 | auto *SN = F_->createSave("save" , CN); |
4241 | |
4242 | ::glow::optimize(F_, CompilationMode::Infer); |
4243 | |
4244 | // Convert should be eliminated and just the save of the Constant left. |
4245 | EXPECT_EQ(F_->getNodes().size(), 1); |
4246 | Constant *convertedC = llvm::dyn_cast<Constant>(SN->getInput()); |
4247 | ASSERT_TRUE(convertedC); |
4248 | EXPECT_EQ(convertedC->getElementType(), ElemKind::UInt8FusedFP16QTy); |
4249 | } |
4250 | |
4251 | TEST_F(GraphOptz, dceBeforeOptimizeTranpose) { |
4252 | auto *input1 = mod_.createConstant(ElemKind::FloatTy, {5, 10}, "input1" ); |
4253 | // Create an unused node. |
4254 | F_->createAdd("add" , input1, input1); |
4255 | auto *transposedInput1 = F_->createTranspose("transpose" , input1, {1, 0}); |
4256 | auto *save1 = F_->createSave("save1" , transposedInput1); |
4257 | |
4258 | // add + transpose + save. |
4259 | EXPECT_EQ(F_->getNodes().size(), 3); |
4260 | |
4261 | ::glow::optimize(F_, CompilationMode::Infer); |
4262 | |
4263 | // A single node: save. |
4264 | EXPECT_EQ(F_->getNodes().size(), 1); |
4265 | // transpose should be eliminated and replaced by the transposed constant. |
4266 | EXPECT_TRUE(llvm::isa<Constant>(save1->getInput())); |
4267 | } |
4268 | |
/// Test that Transpose is sunk below ChannelShuffle and cancels with an
/// inverse transpose below the ChannelShuffle. This test models a pattern
/// that has been observed in shufflenet during graph optimization.
4272 | TEST_F(GraphOptz, sinkTransposeBelowChannelShuffleNodesAndEliminate) { |
4273 | const dim_t inputDims[] = {3, 28, 28, 136}; |
4274 | |
4275 | Node *K = |
4276 | mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input" , false); |
4277 | K = F_->createTranspose("unnecessary_transpose_1" , K, {0, 3, 1, 2}); |
4278 | K = F_->createChannelShuffle("channel_shuffle" , K, 4, 1); |
4279 | K = F_->createTranspose("unnecessary_transpose_2" , K, {0, 2, 3, 1}); |
4280 | auto *save = F_->createSave("ret" , K); |
4281 | |
4282 | EXPECT_EQ(F_->getNodes().size(), 4); |
4283 | |
4284 | // Optimize away the unnecessary transposes. |
4285 | optimize(F_, CompilationMode::Infer); |
4286 | |
4287 | // Ensure the two unnecessary transposes are gone. |
4288 | ASSERT_EQ(F_->getNodes().size(), 2); |
4289 | |
4290 | // Check that the channel shuffle node is still there. |
4291 | auto *CSN = llvm::dyn_cast<ChannelShuffleNode>(save->getInput().getNode()); |
4292 | ASSERT_NE(nullptr, CSN); |
4293 | |
4294 | // Ensure ChannelShuffle node has the same dimensions as the input. |
4295 | EXPECT_EQ(CSN->getResult().dims(), llvm::makeArrayRef(inputDims)); |
4296 | |
// Ensure Group is unchanged and Kernel now refers to the NHWC channel
// dimension (3) rather than the NCHW one (1).
4298 | EXPECT_EQ(CSN->getGroup(), 4); |
4299 | EXPECT_EQ(CSN->getKernel(), 3); |
4300 | } |
4301 | |
4302 | /// Test BatchNorm sinking below Slice. |
4303 | TEST_F(GraphOptz, sinkBatchNormBelowSlice) { |
4304 | auto *inputTy = mod_.uniqueType(ElemKind::FloatTy, {1, 10, 10, 3}); |
4305 | auto *slicedTy1 = mod_.uniqueType(ElemKind::FloatTy, {1, 8, 8, 3}); |
4306 | auto *slicedTy2 = mod_.uniqueType(ElemKind::FloatTy, {1, 6, 6, 1}); |
4307 | |
4308 | auto *input = mod_.createPlaceholder(inputTy, "input" , false); |
4309 | auto *BN = F_->createBatchNormalization(bindings_, "batchnorm" , input, 3, |
4310 | 0.0001, 0.9); |
4311 | auto *SN1 = F_->createSlice("slice1" , BN, {0, 1, 1, 0}, slicedTy1); |
4312 | auto *SN2 = F_->createSlice("slice2" , SN1, {0, 1, 1, 1}, slicedTy2); |
4313 | auto *save = F_->createSave("save" , SN2); |
4314 | |
4315 | EXPECT_EQ(F_->getNodes().size(), 4); |
4316 | ::glow::convertPlaceholdersToConstants(F_, bindings_, {}); |
4317 | optimizedF_ = optimizeFunctionForTest(F_); |
4318 | EXPECT_EQ(optimizedF_->getNodes().size(), 4); |
4319 | |
// BatchNorm should have sunk below the first Slice, but not below the
// second one, as the second slice changes the channel dimension.
4322 | auto *newSave = |
4323 | findFunctionNodeByName<SaveNode>(optimizedF_, save->getName()); |
4324 | ASSERT_TRUE(newSave); |
4325 | auto *newSN2 = llvm::dyn_cast<SliceNode>(newSave->getInput()); |
4326 | ASSERT_TRUE(newSN2); |
4327 | auto *newBN = llvm::dyn_cast<BatchNormalizationNode>(newSN2->getInput()); |
4328 | ASSERT_TRUE(newBN); |
4329 | ASSERT_EQ(newBN->getResult().dims(), slicedTy1->dims()); |
4330 | ASSERT_TRUE(llvm::isa<SliceNode>(newBN->getInput())); |
4331 | |
4332 | bindings_.allocate(mod_.getPlaceholders()); |
4333 | bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
4334 | checkNumericalEquivalence(); |
4335 | } |
4336 | |
4337 | /// Test that convertPlaceholdersToConstants works properly with quantized |
4338 | /// types. |
4339 | TEST_F(GraphOptz, QuantizedFC) { |
4340 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 1.0, 0, |
4341 | "input" , false); |
4342 | auto *weights = mod_.createPlaceholder(ElemKind::Int8QTy, {32, 32}, 1.0, 0, |
4343 | "weights" , false); |
4344 | auto *bias = |
4345 | mod_.createPlaceholder(ElemKind::Int32QTy, {32}, 1.0, 0, "bias" , false); |
4346 | auto *output = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 1.0, 0, |
4347 | "output" , false); |
4348 | |
4349 | auto *fc = F_->createFullyConnected("fc" , input, weights, bias); |
4350 | F_->createSave("save" , fc, output); |
4351 | |
4352 | bindings_.allocate(input); |
4353 | bindings_.allocate(weights); |
4354 | bindings_.allocate(bias); |
4355 | bindings_.allocate(output); |
4356 | |
4357 | glow::convertPlaceholdersToConstants(F_, bindings_, {input, output}); |
4358 | // Two constants: weight and bias |
4359 | EXPECT_EQ(mod_.getConstants().size(), 2); |
// All four placeholders still exist in the module. The old weight and bias
// placeholders just aren't hooked up to the Function F_ anymore.
4362 | EXPECT_EQ(mod_.getPlaceholders().size(), 4); |
4363 | } |
4364 | |
4365 | /// Test batchedReduceMean optimization using AvgPool. |
4366 | TEST_F(GraphOptz, convertReduceMean2AvgPool) { |
4367 | const dim_t dims[] = {2, 2, 2, 2}; |
4368 | |
4369 | Node *A = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input" , false); |
4370 | Node *R = F_->createBatchedReduceMean("reduce.mean" , A, {2, 3}); |
4371 | |
4372 | SaveNode *O = F_->createSave("ret" , R); |
4373 | |
4374 | EXPECT_EQ(F_->getNodes().size(), 2); |
4375 | |
4376 | ::glow::optimize(F_, CompilationMode::Infer); |
4377 | |
// The ReduceMean is replaced with an AvgPool over the reduced dims, which
// adds 2 transpose nodes (into and out of the pool's layout) and one
// reshape node.
4379 | EXPECT_EQ(F_->getNodes().size(), 5); |
4380 | |
4381 | // Expecting reshape output rather than ReduceMean. |
4382 | auto *RN = llvm::dyn_cast<ReshapeNode>(O->getInput()); |
4383 | ASSERT_NE(RN, nullptr); |
4384 | |
4385 | // Expecting Transpose node before Reshape node. |
4386 | auto *TN = llvm::dyn_cast<TransposeNode>(RN->getInput()); |
4387 | ASSERT_NE(TN, nullptr); |
4388 | |
4389 | // Expecting AvgPool node before Transpose node. |
4390 | auto *APN = llvm::dyn_cast<AvgPoolNode>(TN->getInput()); |
4391 | ASSERT_NE(APN, nullptr); |
4392 | } |
4393 | |
4394 | /// Test Broadcasted RHS BatchMatMul is converted correctly to a single MatMul. |
4395 | TEST_F(GraphOptz, convertBroadcastedBatchMatMulToMatMul) { |
4396 | auto *lhs = |
4397 | mod_.createPlaceholder(ElemKind::FloatTy, {6, 10, 4}, "lhs" , false); |
4398 | auto *rhs = mod_.createConstant(ElemKind::FloatTy, {4, 8}, "rhs" ); |
4399 | rhs->getPayloadMutable().getHandle().randomize(-10, 10, mod_.getPRNG()); |
4400 | auto *BMMN = F_->createBatchMatMul("BMM" , lhs, rhs); |
4401 | F_->createSave("save" , BMMN); |
4402 | |
4403 | // Start with a BatchMatMul, not a MatMul. |
4404 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchMatMulNodeKind), 1); |
4405 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 0); |
4406 | |
4407 | optimizedF_ = optimizeFunctionForTest(F_); |
4408 | |
4409 | // Optimization should replace the BatchMatMul with a single MatMul. |
4410 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::MatMulNodeKind), 1); |
4411 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::BatchMatMulNodeKind), 0); |
4412 | |
4413 | bindings_.allocate(lhs)->getHandle().randomize(-10, 10, mod_.getPRNG()); |
4414 | |
4415 | checkNumericalEquivalence(0.f); |
4416 | } |
4417 | |
4418 | /// Test Broadcasted RHS BatchMatMul is converted correctly to a single MatMul, |
4419 | /// where RHS is broadcasted in multiple dimensions. |
4420 | TEST_F(GraphOptz, convertMultiBroadcastedBatchMatMulToMatMul) { |
4421 | auto *lhs = |
4422 | mod_.createPlaceholder(ElemKind::FloatTy, {5, 10, 4}, "lhs" , false); |
4423 | auto *rhs = mod_.createConstant(ElemKind::FloatTy, {1, 1, 6}, "rhs" ); |
4424 | rhs->getPayloadMutable().getHandle().randomize(-10, 10, mod_.getPRNG()); |
4425 | auto *BN = F_->createBroadcast("broadcast" , rhs, {5, 4, 6}, /* axis */ 0); |
4426 | auto *BMMN = F_->createBatchMatMul("BMM" , lhs, BN); |
4427 | F_->createSave("save" , BMMN); |
4428 | |
4429 | // Start with a BatchMatMul, not a MatMul, as well as a broadcast. |
4430 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchMatMulNodeKind), 1); |
4431 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 0); |
4432 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BroadcastNodeKind), 1); |
4433 | |
4434 | optimizedF_ = optimizeFunctionForTest( |
4435 | F_, {FunctionPassID::ConvertBroadcastedBatchMatMul, getDCEPassConfig()}); |
4436 | |
// Optimization should replace the BatchMatMul with a single MatMul, while
// leaving a Broadcast in place.
4439 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::MatMulNodeKind), 1); |
4440 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::BatchMatMulNodeKind), 0); |
4441 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::BroadcastNodeKind), 1); |
4442 | |
4443 | bindings_.allocate(lhs)->getHandle().randomize(-10, 10, mod_.getPRNG()); |
4444 | |
4445 | checkNumericalEquivalence(0.f); |
4446 | } |
4447 | |
4448 | TEST_F(GraphOptz, dceQuantization) { |
4449 | auto *lhs = |
4450 | mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5}, 0.3, 15, "lhs" , false); |
4451 | auto *weights = |
4452 | mod_.createConstant(ElemKind::Int8QTy, {3, 5}, 0.3, 15, "weights" ); |
4453 | |
4454 | auto *add = F_->createAdd("add" , lhs, weights); |
4455 | auto *t1 = mod_.uniqueType(ElemKind::Int8QTy, {3, 5}, 0.2, 0); |
4456 | auto *rs1 = F_->createRescaleQuantized("rs1" , add, t1); |
4457 | auto *t2 = mod_.uniqueType(ElemKind::Int8QTy, {3, 5}, 0.1, 1); |
4458 | auto *rs2 = F_->createRescaleQuantized("rs2" , rs1, t2); |
4459 | F_->createSave("save" , rs2); |
4460 | |
4461 | ::glow::optimize(F_, CompilationMode::Infer); |
4462 | |
4463 | EXPECT_EQ(F_->getNodes().size(), 2); |
4464 | } |
4465 | |
4466 | TEST_F(GraphOptz, nopRelu) { |
4467 | auto *in = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5}, 0.3, -128, "lhs" , |
4468 | false); |
4469 | |
4470 | auto *relu = F_->createRELU("relu" , in); |
4471 | F_->createSave("save" , relu); |
4472 | |
4473 | optimizedF_ = optimizeFunctionForTest(F_); |
4474 | |
4475 | EXPECT_EQ(optimizedF_->getNodes().size(), 1); |
4476 | |
4477 | bindings_.allocate(mod_.getPlaceholders()); |
4478 | bindings_.get(in)->getHandle<int8_t>().randomize(-4, 4, mod_.getPRNG()); |
4479 | |
4480 | checkNumericalEquivalence(); |
4481 | } |
4482 | |
4483 | template <typename ElemTy> |
4484 | static void setConstValue(Constant *C, ElemTy value) { |
4485 | Handle<ElemTy> TH = C->getPayload().getHandle<ElemTy>(); |
4486 | TH.clear(value); |
4487 | } |
4488 | |
4489 | TEST_F(GraphOptz, constantFoldSingleNode) { |
4490 | auto *const1 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const1" ); |
4491 | auto *const2 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const2" ); |
4492 | auto *ph1 = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "input1" , |
4493 | /* isTrainable */ false); |
4494 | setConstValue(const1, 1.0f); |
4495 | setConstValue(const2, 2.0f); |
4496 | auto *splat2 = F_->createSplat( |
4497 | "splat2" , mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 2.0f); |
4498 | auto *splat3 = F_->createSplat( |
4499 | "splat3" , mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 3.0f); |
4500 | |
4501 | auto *add1 = F_->createAdd("add" , const1, const2); |
4502 | auto *mul1 = F_->createMul("mul1" , add1, splat2); |
4503 | auto *mul2 = F_->createMul("mul2" , mul1, splat3); |
4504 | auto *SN1 = F_->createSave("save" , mul2); |
4505 | auto *add3 = F_->createAdd("add" , const1, ph1); |
4506 | auto *SN2 = F_->createSave("save" , add3); |
4507 | |
4508 | // Perform constant folding for a specific node. |
4509 | std::vector<Constant *> constResults = |
4510 | constantFold(SN1->getInput().getNode()); |
4511 | |
4512 | ASSERT_EQ(constResults.size(), 1); |
4513 | SN1->getInput().replaceAllUsesOfWith(constResults[0]); |
4514 | // Second save should be unaffected. |
4515 | EXPECT_FALSE(llvm::isa<Constant>(SN2->getInput())); |
4516 | // First save should have been constant folded. |
4517 | EXPECT_TRUE(llvm::isa<Constant>(SN1->getInput())); |
4518 | Constant *C = llvm::dyn_cast<Constant>(SN1->getInput()); |
4519 | auto CH = C->getHandle(); |
// The expected result should be: ((1+2) * 2 * 3) = 18.
4521 | EXPECT_EQ(CH.at({0, 0}), 18.0f); |
4522 | EXPECT_EQ(CH.at({0, 1}), 18.0f); |
4523 | EXPECT_EQ(CH.at({1, 0}), 18.0f); |
4524 | EXPECT_EQ(CH.at({1, 1}), 18.0f); |
4525 | } |
4526 | |
4527 | /// Verify that we can specify what splats should be materialized to constants |
4528 | /// based on their users via optimizationOpts.materializeSplatsUsedBySet. |
4529 | TEST_F(GraphOptz, constantFoldSpecificSplat) { |
4530 | Placeholder *PH = mod_.createPlaceholder(ElemKind::FloatTy, {1, 1}, "input" , |
4531 | /* isTrainable */ false); |
4532 | SplatNode *splat1 = F_->createSplat( |
4533 | "splat1" , mod_.uniqueType(ElemKind::FloatTy, {1, 1}), 1.0f); |
4534 | AddNode *add = F_->createAdd("add" , PH, splat1); |
4535 | SplatNode *splat2 = F_->createSplat( |
4536 | "splat2" , mod_.uniqueType(ElemKind::FloatTy, {1, 1}), 2.0f); |
4537 | MulNode *mul = F_->createMul("mul" , add, splat2); |
4538 | SaveNode *save = F_->createSave("save" , mul); |
4539 | |
4540 | // Signal to materialize the splat used by Add, but not by Mul. |
4541 | cctx_.optimizationOpts.materializeSplatsUsedBySet.insert( |
4542 | Kinded::Kind::AddNodeKind); |
4543 | |
4544 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
4545 | |
4546 | ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_); |
4547 | runDCEPass(optimizedF_, cctx_); |
4548 | |
4549 | ASSERT_EQ(record.size(), 1); |
4550 | SaveNode *SN = record.begin()->second; |
4551 | SplatNode *foldSplat1 = llvm::dyn_cast<SplatNode>(SN->getInput()); |
4552 | ASSERT_TRUE(foldSplat1); |
4553 | EXPECT_EQ(foldSplat1->getValue(), 1.0f); |
4554 | |
4555 | // Verify one splat left in the optimized Function, and a new Constant. |
4556 | EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SplatNodeKind)); |
4557 | const SaveNode *optSave = |
4558 | findFunctionNodeByName<SaveNode>(optimizedF_, save->getName()); |
4559 | MulNode *optMul = llvm::dyn_cast<MulNode>(optSave->getInput()); |
4560 | ASSERT_TRUE(optMul); |
4561 | SplatNode *optSplat2 = llvm::dyn_cast<SplatNode>(optMul->getRHS()); |
4562 | ASSERT_TRUE(optSplat2); |
4563 | EXPECT_EQ(optSplat2->getValue(), 2.0f); |
4564 | AddNode *optAdd = llvm::dyn_cast<AddNode>(optMul->getLHS()); |
4565 | ASSERT_TRUE(optAdd); |
4566 | EXPECT_EQ(optAdd->getLHS().getNode(), PH); |
4567 | Constant *optSplatConst1 = llvm::dyn_cast<Constant>(optAdd->getRHS()); |
4568 | ASSERT_TRUE(optSplatConst1); |
4569 | EXPECT_EQ(optSplatConst1->getPayload().getHandle().at({0, 0}), 1.0f); |
4570 | } |
4571 | |
4572 | /// Test that we correctly record a single constant folding subgraph that has a |
4573 | /// single output. |
4574 | TEST_F(GraphOptz, constantFoldWithRecordSingleChain) { |
4575 | Placeholder *I = |
4576 | mod_.createPlaceholder(ElemKind::Float16Ty, {2, 100}, "input" , |
4577 | /* isTrainable */ false); |
4578 | Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weight" ); |
4579 | ClipNode *clipW = F_->createClip("clip" , W, -5.f, 5.f); |
4580 | ConvertToNode *convertW = |
4581 | F_->createConvertTo("conv" , clipW, ElemKind::Float16Ty); |
4582 | TransposeNode *transposeW = |
4583 | F_->createTranspose("transpose" , convertW, {1, 0}); |
4584 | MatMulNode *MM = F_->createMatMul("matmul" , I, transposeW); |
4585 | SaveNode *save = F_->createSave("save" , MM); |
4586 | Placeholder *O = save->getPlaceholder(); |
4587 | bindings_.allocate(O); |
4588 | |
4589 | ASSERT_TRUE(F_->verify()); |
4590 | |
4591 | Tensor *IT = bindings_.allocate(I); |
4592 | IT->getHandle<float16_t>().randomize(-10, 10, mod_.getPRNG()); |
4593 | W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG()); |
4594 | |
4595 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
4596 | |
4597 | ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_); |
4598 | |
4599 | runDCEPass(optimizedF_, cctx_); |
4600 | |
4601 | ASSERT_EQ(record.size(), 1); |
4602 | SaveNode *SN = record.begin()->second; |
4603 | Function *constFoldF = SN->getParent(); |
4604 | |
4605 | // Expect to find a chain of Nodes based on Nodes above. Note that the clip is |
4606 | // lowered for the Interpreter backend which performs constant folding. |
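// (Clip is lowered to Min(Max(x, splat(min)), splat(max)), hence the
// Splat/Max/Min nodes counted below.)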
4607 | EXPECT_EQ(2, countNodeKind(constFoldF, Kinded::Kind::SplatNodeKind)); |
4608 | EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::MaxNodeKind)); |
4609 | EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::MinNodeKind)); |
4610 | EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::ConvertToNodeKind)); |
4611 | EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::TransposeNodeKind)); |
4612 | |
4613 | // Skip optimizations -- we just want to run them as is (otherwise we'll |
4614 | // constant fold them inside the optimization pipeline). |
4615 | cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldF); |
4616 | cctx_.optimizationOpts.onlyLowerFuns.insert(F_); |
4617 | cctx_.optimizationOpts.onlyLowerFuns.insert(optimizedF_); |
4618 | |
4619 | // Don't strip the module as we want to compare the Constant values below. |
4620 | EE_.setSkipModuleStrip(true); |
4621 | |
4622 | EE_.compile(cctx_); |
4623 | alreadyCompiled_ = true; |
4624 | |
4625 | bindings_.allocate(mod_.getPlaceholders()); |
4626 | |
4627 | // Run the constant folding chain to check that we have the same constant used |
4628 | // by the optimized Function. |
4629 | EE_.run(bindings_, constFoldF->getName()); |
4630 | Tensor *rerunT = bindings_.get(SN->getPlaceholder()); |
4631 | ASSERT_TRUE(rerunT); |
4632 | auto optimizedConstants = optimizedF_->findConstants(); |
4633 | ASSERT_EQ(optimizedConstants.size(), 1); |
4634 | EXPECT_TRUE( |
4635 | (*optimizedConstants.begin())->getPayload().isEqual(*rerunT, 0.f)); |
4636 | |
4637 | // Remove the temporary constant folding Functions and their Placeholders. |
4638 | cleanupConstantFolding(mod_, record, &bindings_); |
4639 | |
4640 | // Now compile/run/compare F_ and optimizedF_. |
4641 | checkNumericalEquivalence(0.f); |
4642 | } |
4643 | |
/// Test that we correctly record two constant folding subgraphs, each with a
/// single output.
4646 | TEST_F(GraphOptz, constantFoldWithRecordMultiChain) { |
4647 | Placeholder *I = |
4648 | mod_.createPlaceholder(ElemKind::Float16Ty, {2, 100}, "input" , |
4649 | /* isTrainable */ false); |
4650 | Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weight" ); |
4651 | ClipNode *clipW = F_->createClip("clip" , W, -5.f, 5.f); |
4652 | ConvertToNode *convertW = |
4653 | F_->createConvertTo("conv" , clipW, ElemKind::Float16Ty); |
4654 | TransposeNode *transposeW = |
4655 | F_->createTranspose("transpose" , convertW, {1, 0}); |
4656 | MatMulNode *MM = F_->createMatMul("matmul" , I, transposeW); |
4657 | SaveNode *saveMM = F_->createSave("save_mm" , MM); |
4658 | Placeholder *MMP = saveMM->getPlaceholder(); |
4659 | bindings_.allocate(MMP); |
4660 | |
4661 | SigmoidNode *sigmoidW = F_->createSigmoid("sig" , convertW); |
4662 | SaveNode *saveSig = F_->createSave("save_sig" , sigmoidW); |
4663 | Placeholder *sigP = saveSig->getPlaceholder(); |
4664 | bindings_.allocate(sigP); |
4665 | |
4666 | ASSERT_TRUE(F_->verify()); |
4667 | |
4668 | Tensor *IT = bindings_.allocate(I); |
4669 | IT->getHandle<float16_t>().randomize(-10, 10, mod_.getPRNG()); |
4670 | W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG()); |
4671 | |
4672 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
4673 | |
4674 | ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_); |
4675 | |
4676 | runDCEPass(optimizedF_, cctx_); |
4677 | |
4678 | ASSERT_EQ(record.size(), 2); |
4679 | SaveNode *sigSN = record.begin()->second; |
4680 | SaveNode *transSN = std::next(record.begin())->second; |
4681 | if (llvm::isa<SigmoidNode>(transSN->getInput())) { |
4682 | std::swap(sigSN, transSN); |
4683 | } |
4684 | |
4685 | Function *constFoldSig = sigSN->getParent(); |
4686 | Function *constFoldTrans = transSN->getParent(); |
4687 | |
4688 | // Expect to find a chain of Nodes based on Nodes above. Note that the clip is |
4689 | // lowered for the Interpreter backend which performs constant folding. |
4690 | EXPECT_EQ(2, countNodeKind(constFoldTrans, Kinded::Kind::SplatNodeKind)); |
4691 | EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::MaxNodeKind)); |
4692 | EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::MinNodeKind)); |
4693 | EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::ConvertToNodeKind)); |
4694 | EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::TransposeNodeKind)); |
4695 | |
4696 | EXPECT_EQ(2, countNodeKind(constFoldSig, Kinded::Kind::SplatNodeKind)); |
4697 | EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::MaxNodeKind)); |
4698 | EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::MinNodeKind)); |
4699 | EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::ConvertToNodeKind)); |
4700 | EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::SigmoidNodeKind)); |
4701 | |
4702 | // Skip optimizations -- we just want to run them as is (otherwise we'll |
4703 | // constant fold them inside the optimization pipeline). |
4704 | cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldTrans); |
4705 | cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldSig); |
4706 | cctx_.optimizationOpts.onlyLowerFuns.insert(F_); |
4707 | cctx_.optimizationOpts.onlyLowerFuns.insert(optimizedF_); |
4708 | |
4709 | // Don't strip the module as we want to compare the Constant values below. |
4710 | EE_.setSkipModuleStrip(true); |
4711 | |
4712 | EE_.compile(cctx_); |
4713 | alreadyCompiled_ = true; |
4714 | |
4715 | bindings_.allocate(mod_.getPlaceholders()); |
4716 | |
4717 | // Run the constant folding chain to check that we have the same constant used |
4718 | // by the optimized Function. |
4719 | EE_.run(bindings_, constFoldTrans->getName()); |
4720 | EE_.run(bindings_, constFoldSig->getName()); |
4721 | |
// Find the correct PHs for each of the constant folding chains.
4723 | Tensor *rerunTransT = bindings_.get(transSN->getPlaceholder()); |
4724 | Tensor *rerunSigT = bindings_.get(sigSN->getPlaceholder()); |
4725 | ASSERT_TRUE(rerunTransT); |
4726 | ASSERT_TRUE(rerunSigT); |
4727 | |
4728 | auto optimizedConstants = optimizedF_->findConstants(); |
4729 | ASSERT_EQ(optimizedConstants.size(), 2); |
4730 | Constant *transC = *optimizedConstants.begin(); |
4731 | Constant *sigC = *std::next(optimizedConstants.begin()); |
4732 | // If we have the constants backwards then swap them. Note that we know |
4733 | // sigC must be directly saved, while transC is input to a MatMulNode. |
4734 | ASSERT_EQ(transC->getNumUsers(), 1); |
4735 | if (llvm::isa<SaveNode>(transC->getUsers().begin()->getUser())) { |
4736 | std::swap(transC, sigC); |
4737 | } |
4738 | EXPECT_TRUE(transC->getPayload().isEqual(*rerunTransT, 0.f)); |
4739 | EXPECT_TRUE(sigC->getPayload().isEqual(*rerunSigT, 0.f)); |
4740 | |
4741 | // Remove the temporary constant folding Functions and their Placeholders. |
4742 | cleanupConstantFolding(mod_, record, &bindings_); |
4743 | |
4744 | // Now compile/run/compare F_ and optimizedF_. |
4745 | checkNumericalEquivalence(0.f); |
4746 | } |
4747 | |
4748 | /// Test that we correctly record a single constant folding subgraph that has |
4749 | /// two outputs. |
4750 | TEST_F(GraphOptz, constantFoldWithRecordSingleChainMultiOutput) { |
4751 | Constant *W = mod_.createConstant(ElemKind::FloatTy, {100}, "weight" ); |
4752 | SigmoidNode *sigmoidW = F_->createSigmoid("sig" , W); |
4753 | ConvertToNode *convertW = |
4754 | F_->createConvertTo("conv" , sigmoidW, ElemKind::Float16Ty); |
4755 | TopKNode *TK = F_->createTopK("topk" , convertW, 5); |
4756 | |
4757 | SaveNode *indicesSave = F_->createSave("save_indices" , TK->getIndices()); |
4758 | Placeholder *indicesP = indicesSave->getPlaceholder(); |
4759 | bindings_.allocate(indicesP); |
4760 | |
4761 | Placeholder *I = mod_.createPlaceholder(ElemKind::Float16Ty, {5}, "input" , |
4762 | /* isTrainable */ false); |
4763 | AddNode *add = F_->createAdd("add" , I, TK->getValues()); |
4764 | SaveNode *addSave = F_->createSave("save_add" , add); |
4765 | Placeholder *addP = addSave->getPlaceholder(); |
4766 | bindings_.allocate(addP); |
4767 | |
4768 | ASSERT_TRUE(F_->verify()); |
4769 | |
4770 | Tensor *IT = bindings_.allocate(I); |
4771 | IT->getHandle<float16_t>().randomize(-10, 10, mod_.getPRNG()); |
4772 | W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG()); |
4773 | |
4774 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
4775 | |
4776 | ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_); |
4777 | |
4778 | runDCEPass(optimizedF_, cctx_); |
4779 | |
4780 | ASSERT_EQ(record.size(), 2); |
4781 | SaveNode *indicesSN = record.begin()->second; |
4782 | SaveNode *addSN = std::next(record.begin())->second; |
4783 | |
// Find the correct PHs for each of the constant folding chains.
4785 | if (indicesSN->getInput().getResNo() != TopKNode::IndicesIdx) { |
4786 | std::swap(indicesSN, addSN); |
4787 | } |
4788 | |
4789 | // Expect that the two constants that we folded are from the same Function, |
4790 | // and that the two saves use the two different outputs from a topk. |
4791 | EXPECT_EQ(indicesSN->getParent(), addSN->getParent()); |
4792 | ASSERT_TRUE(llvm::isa<TopKNode>(addSN->getInput())); |
4793 | ASSERT_TRUE(llvm::isa<TopKNode>(indicesSN->getInput())); |
4794 | EXPECT_EQ(addSN->getInput().getNode(), indicesSN->getInput().getNode()); |
4795 | |
4796 | Function *constFoldF = addSN->getParent(); |
4797 | |
4798 | // Expect to find a chain of Nodes based on Nodes above. |
4799 | EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::TopKNodeKind)); |
4800 | EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::SigmoidNodeKind)); |
4801 | EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::ConvertToNodeKind)); |
4802 | |
4803 | // Skip optimizations -- we just want to run them as is (otherwise we'll |
4804 | // constant fold them inside the optimization pipeline). |
4805 | cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldF); |
4806 | cctx_.optimizationOpts.onlyLowerFuns.insert(F_); |
4807 | cctx_.optimizationOpts.onlyLowerFuns.insert(optimizedF_); |
4808 | |
4809 | // Don't strip the module as we want to compare the Constant values below. |
4810 | EE_.setSkipModuleStrip(true); |
4811 | |
4812 | EE_.compile(cctx_); |
4813 | alreadyCompiled_ = true; |
4814 | |
4815 | bindings_.allocate(mod_.getPlaceholders()); |
4816 | |
4817 | // Run the constant folding chain to check that we have the same constant used |
4818 | // by the optimized Function. |
4819 | EE_.run(bindings_, constFoldF->getName()); |
4820 | |
4821 | Tensor *rerunAddT = bindings_.get(addSN->getPlaceholder()); |
4822 | Tensor *rerunIndicesT = bindings_.get(indicesSN->getPlaceholder()); |
4823 | ASSERT_TRUE(rerunAddT); |
4824 | ASSERT_TRUE(rerunIndicesT); |
4825 | |
4826 | auto optimizedConstants = optimizedF_->findConstants(); |
4827 | ASSERT_EQ(optimizedConstants.size(), 2); |
4828 | Constant *addC = *optimizedConstants.begin(); |
4829 | Constant *indicesC = *std::next(optimizedConstants.begin()); |
4830 | |
4831 | // If we have the constants backwards then swap them. Note that we know |
4832 | // indicesC must be directly saved, while addC is input to an AddNode. |
4833 | ASSERT_EQ(addC->getNumUsers(), 1); |
4834 | if (llvm::isa<SaveNode>(addC->getUsers().begin()->getUser())) { |
4835 | std::swap(addC, indicesC); |
4836 | } |
4837 | EXPECT_TRUE(addC->getPayload().isEqual(*rerunAddT, 0.f)); |
4838 | EXPECT_TRUE(indicesC->getPayload().isEqual(*rerunIndicesT, 0.f)); |
4839 | |
4840 | // Remove the temporary constant folding Functions and their Placeholders. |
4841 | cleanupConstantFolding(mod_, record, &bindings_); |
4842 | |
4843 | // Now compile/run/compare F_ and optimizedF_. |
4844 | checkNumericalEquivalence(0.f); |
4845 | } |
4846 | |
/// Test that the constant folding record Function includes all ops, i.e.
/// that they're not optimized away when the recorded constant folding
/// function is itself optimized.
4850 | TEST_F(GraphOptz, constantFoldOnlyLower) { |
4851 | Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weight" ); |
4852 | ConvertToNode *convertW = F_->createConvertTo("conv" , W, ElemKind::Float16Ty); |
4853 | SaveNode *save = F_->createSave("save" , convertW); |
4854 | Placeholder *O = save->getPlaceholder(); |
4855 | bindings_.allocate(O); |
4856 | |
4857 | ASSERT_TRUE(F_->verify()); |
4858 | |
4859 | W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG()); |
4860 | |
4861 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
4862 | |
4863 | ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_); |
4864 | |
4865 | ASSERT_EQ(record.size(), 1); |
4866 | SaveNode *SN = record.begin()->second; |
4867 | Function *constFoldF = SN->getParent(); |
4868 | |
4869 | // Expect to find a Save and the ConvertTo still, i.e. it shouldn't have been |
4870 | // folded into the Constant as part of the OptimizeConversions pass. |
4871 | EXPECT_EQ(2, constFoldF->getNodes().size()); |
4872 | EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::ConvertToNodeKind)); |
4873 | EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::SaveNodeKind)); |
4874 | } |
4875 | |
4876 | TEST_F(GraphOptz, constantFoldWholeFunction) { |
4877 | auto *const1 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const1" ); |
4878 | auto *const2 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const2" ); |
4879 | auto *const3 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const3" ); |
4880 | auto *const4 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const4" ); |
4881 | auto *ph1 = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "input1" , |
4882 | /* isTrainable */ false); |
4883 | setConstValue(const1, 1.0f); |
4884 | setConstValue(const2, 2.0f); |
4885 | setConstValue(const3, 3.0f); |
4886 | setConstValue(const4, 4.0f); |
4887 | auto *splat2 = F_->createSplat( |
4888 | "splat2" , mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 2.0f); |
4889 | auto *splat3 = F_->createSplat( |
4890 | "splat2" , mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 3.0f); |
4891 | auto *splat4 = F_->createSplat( |
4892 | "splat2" , mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 4.0f); |
4893 | |
4894 | auto *add1 = F_->createAdd("add" , const1, const2); |
4895 | auto *mul1 = F_->createMul("mul1" , add1, splat2); |
4896 | auto *mul2 = F_->createMul("mul2" , mul1, splat3); |
4897 | auto *sub = F_->createSub("sub" , mul2, const3); |
4898 | auto *add2 = F_->createAdd("add2" , sub, const4); |
4899 | auto *mul3 = F_->createMul("mul3" , add2, splat4); |
4900 | // Check compile-time constant folding for nodes with multiple results. |
4901 | auto *topK = F_->createTopK("topK" , mul3, 2); |
4902 | auto *SN1_0 = F_->createSave("save" , topK->getValues()); |
4903 | auto *SN1_1 = F_->createSave("save" , topK->getIndices()); |
4904 | auto *add3 = F_->createAdd("add" , const1, ph1); |
4905 | auto *SN2 = F_->createSave("save" , add3); |
4906 | |
4907 | // Perform constant folding for a whole function. |
4908 | ::glow::optimize(F_, CompilationMode::Infer); |
4909 | |
4910 | EXPECT_EQ(F_->getNodes().size(), 4); |
// Second save should be unaffected, as its input is not a constant
// expression.
4912 | EXPECT_FALSE(llvm::isa<Constant>(SN2->getInput())); |
4913 | // First save should have been constant folded. |
4914 | EXPECT_TRUE(llvm::isa<Constant>(SN1_0->getInput())); |
4915 | EXPECT_TRUE(llvm::isa<Constant>(SN1_1->getInput())); |
4916 | Constant *C = llvm::dyn_cast<Constant>(SN1_0->getInput()); |
4917 | auto CH = C->getHandle(); |
4918 | // The expected result should be: (((1+2) * 2 * 3 - 3) + 4) * 4 = 76 |
4919 | EXPECT_EQ(CH.at({0, 0}), 76.0f); |
4920 | EXPECT_EQ(CH.at({0, 1}), 76.0f); |
4921 | EXPECT_EQ(CH.at({1, 0}), 76.0f); |
4922 | EXPECT_EQ(CH.at({1, 1}), 76.0f); |
4923 | } |
4924 | |
/// Test constant folding for operators which are lowered in the Interpreter
/// backend.
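/// Tile is lowered by the Interpreter (presumably into InsertTensor nodes),
/// so the folding must succeed through the lowered graph.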
4927 | TEST_F(GraphOptz, constantFoldWithLowering) { |
4928 | auto *input = mod_.createConstant(ElemKind::FloatTy, {1, 6}, "input" ); |
4929 | input->getHandle() = {5, 4, 3, 2, 1, 0}; |
4930 | auto *TN = F_->createTile("tile" , input, 5, 0); |
4931 | auto *SN = F_->createSave("ret" , TN); |
4932 | |
4933 | // Perform constant folding. |
4934 | EXPECT_EQ(F_->getNodes().size(), 2); |
4935 | ::glow::optimize(F_, CompilationMode::Infer); |
4936 | |
4937 | // Tile with its input should be folded into a single Constant node. |
4938 | EXPECT_EQ(F_->getNodes().size(), 1); |
4939 | ASSERT_TRUE(llvm::isa<Constant>(SN->getInput())); |
4940 | } |
4941 | |
4942 | /// Test Splitting FC into multiple FCs. |
4943 | TEST_F(GraphOptz, SplitFCIntoMultipleOps) { |
4944 | auto *input = |
4945 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 32}, "input" , false); |
4946 | bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0, |
4947 | mod_.getPRNG()); |
4948 | auto *weights = mod_.createConstant(ElemKind::FloatTy, {32, 850}, "weights" ); |
4949 | weights->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
4950 | auto *bias = mod_.createConstant(ElemKind::FloatTy, {850}, "bias" ); |
4951 | bias->getHandle().randomize(0.0, 0.5, mod_.getPRNG()); |
4952 | auto *output = |
4953 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 850}, "output" , false); |
4954 | bindings_.allocate(output); |
4955 | |
4956 | auto *fc = F_->createFullyConnected("fc" , input, weights, bias); |
4957 | auto *save = F_->createSave("save" , fc, output); |
4958 | |
4959 | ::glow::optimize(F_, CompilationMode::Infer); |
4960 | |
// This is F_ but without the vertical FC split below.
4962 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
4963 | |
4964 | EXPECT_TRUE(::glow::executeVerticalFCWeightsSplit(F_, |
4965 | /*numOfChunks*/ 12, |
4966 | /*minKToSplit*/ 800)); |
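// The split is expected to fire because the split dimension of the weights
// (850 outputs) exceeds minKToSplit (800); semantics of minKToSplit inferred
// from this test's setup.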
4967 | runDCEPass(F_, cctx_); |
4968 | |
4969 | // 24 Slices: 12 from bias and 12 from weights. |
4970 | EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind)); |
4971 | |
4972 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
4973 | |
4974 | // 12 newly created FCs. |
4975 | EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind)); |
4976 | |
4977 | auto *concatNode = llvm::dyn_cast<ConcatNode>(save->getInput()); |
4978 | ASSERT_TRUE(concatNode); |
4979 | // 12 FCs are connected to the concat node. |
4980 | EXPECT_EQ(12, concatNode->getInputs().size()); |
4981 | |
4982 | // Check all splitted FCs. |
4983 | for (unsigned i = 0; i < 12; ++i) { |
4984 | auto *fc = llvm::dyn_cast<FullyConnectedNode>(concatNode->getNthInput(i)); |
4985 | ASSERT_TRUE(fc); |
// The 850 outputs split into 12 chunks of ceil(850 / 12) = 71 for the first
// 11 FCs, with the last FC taking the remaining 850 - 11 * 71 = 69.
4987 | if (i == 11) { |
4988 | EXPECT_TRUE(fc->getResult().dims().equals({2, 69})); |
4989 | EXPECT_TRUE(fc->getBias().dims().equals({69})); |
4990 | EXPECT_TRUE(fc->getWeights().dims().equals({32, 69})); |
4991 | } else { |
4992 | EXPECT_TRUE(fc->getResult().dims().equals({2, 71})); |
4993 | EXPECT_TRUE(fc->getBias().dims().equals({71})); |
4994 | EXPECT_TRUE(fc->getWeights().dims().equals({32, 71})); |
4995 | } |
4996 | } |
4997 | |
4998 | checkNumericalEquivalence(); |
4999 | } |
5000 | |
/// Test model-parallel splitting of a two-layer FC + ReLU network.
5002 | TEST_F(GraphOptz, ParallelizeGraph_FC_ModelParallel) { |
5003 | auto *input = |
5004 | mod_.createPlaceholder(ElemKind::FloatTy, {8, 3}, "input" , false); |
5005 | bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0, |
5006 | mod_.getPRNG()); |
5007 | auto *weights1 = mod_.createConstant(ElemKind::FloatTy, {3, 150}, "weights" ); |
5008 | weights1->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
5009 | auto *bias1 = mod_.createConstant(ElemKind::FloatTy, {150}, "bias" ); |
5010 | bias1->getHandle().randomize(0.0, 0.5, mod_.getPRNG()); |
5011 | auto *weights2 = |
5012 | mod_.createConstant(ElemKind::FloatTy, {150, 150}, "weights" ); |
5013 | weights2->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
5014 | auto *bias2 = mod_.createConstant(ElemKind::FloatTy, {150}, "bias" ); |
5015 | bias2->getHandle().randomize(0.0, 0.5, mod_.getPRNG()); |
5016 | auto *output = |
5017 | mod_.createPlaceholder(ElemKind::FloatTy, {8, 150}, "output" , false); |
5018 | bindings_.allocate(output); |
5019 | |
5020 | auto *fc1 = F_->createFullyConnected("fc1" , input, weights1, bias1); |
5021 | auto *relu1 = F_->createRELU("relu1" , fc1); |
5022 | |
5023 | auto *fc2 = F_->createFullyConnected("fc2" , relu1, weights2, bias2); |
5024 | auto *relu2 = F_->createRELU("relu2" , fc2); |
5025 | F_->createSave("save" , relu2, output); |
5026 | |
5027 | ::glow::optimize(F_, CompilationMode::Infer); |
5028 | |
5029 | // This is F_ but without the parallel transformation below. |
5030 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5031 | |
5032 | // Perform parallel transformation on F_. |
5033 | llvm::DenseMap<Node *, size_t> numChunks; |
5034 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5035 | numChunks[fc1] = 2; |
5036 | numChunks[relu1] = 2; |
5037 | numChunks[fc2] = 2; |
5038 | numChunks[relu2] = 2; |
5039 | parOpts[fc1] = ParallelTransformKind::Model; |
5040 | parOpts[relu1] = ParallelTransformKind::Model; |
5041 | parOpts[fc2] = ParallelTransformKind::Model; |
5042 | parOpts[relu2] = ParallelTransformKind::Model; |
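// parallelizeOps returns (via Expected) a map from each parallelized node to
// the ConcatNode that merges its chunks; checking its size verifies that
// every requested op was actually split.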
5043 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5044 | ASSIGN_VALUE_OR_FAIL_TEST(replacedMap, |
5045 | ::glow::parallelizeOps(F_, numChunks, parOpts)); |
5046 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5047 | |
5048 | runDCEPass(F_, cctx_); |
5049 | |
5050 | EXPECT_EQ(4, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind)); |
5051 | EXPECT_EQ(4, countNodeKind(F_, Kinded::Kind::ReluNodeKind)); |
5052 | |
5053 | checkNumericalEquivalence(); |
5054 | } |
5055 | |
/// Test Splitting FC into multiple FCs, special case of splitting 866 outputs
/// 8 ways with an alignment of 64, which is a corner case for alignment and
/// should produce only 7 splits.
5059 | TEST_F(GraphOptz, ParallelizeGraph_FC_ModelParallel_Split866by8) { |
5060 | auto *input = |
5061 | mod_.createPlaceholder(ElemKind::FloatTy, {8, 3}, "input" , false); |
5062 | bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0, |
5063 | mod_.getPRNG()); |
5064 | auto *weights1 = mod_.createConstant(ElemKind::FloatTy, {3, 866}, "weights" ); |
5065 | weights1->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
5066 | auto *bias1 = mod_.createConstant(ElemKind::FloatTy, {866}, "bias" ); |
5067 | bias1->getHandle().randomize(0.0, 0.5, mod_.getPRNG()); |
5068 | auto *output = |
5069 | mod_.createPlaceholder(ElemKind::FloatTy, {8, 866}, "output" , false); |
5070 | bindings_.allocate(output); |
5071 | |
5072 | auto *fc1 = F_->createFullyConnected("fc1" , input, weights1, bias1); |
5073 | auto *relu1 = F_->createRELU("relu1" , fc1); |
5074 | |
5075 | F_->createSave("save" , relu1, output); |
5076 | |
5077 | ::glow::optimize(F_, CompilationMode::Infer); |
5078 | |
5079 | // This is F_ but without the parallel transformation below. |
5080 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5081 | |
5082 | // Perform parallel transformation on F_. |
5083 | llvm::DenseMap<Node *, size_t> numChunks; |
5084 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5085 | numChunks[fc1] = 8; |
5086 | numChunks[relu1] = 8; |
5087 | parOpts[fc1] = ParallelTransformKind::Model; |
5088 | parOpts[relu1] = ParallelTransformKind::Model; |
5089 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5090 | ASSIGN_VALUE_OR_FAIL_TEST( |
5091 | replacedMap, |
5092 | ::glow::parallelizeOps(F_, numChunks, parOpts, /*numOfChunks*/ 1, |
5093 | /*modelParallelSplitAlignment*/ 64)); |
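// Expected chunking: ceil(866 / 8) = 109, aligned up to a multiple of 64
// gives 128-wide chunks; six of them cover 768 outputs and the remaining 98
// form the seventh, so only 7 splits are produced instead of 8.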
5094 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5095 | |
5096 | runDCEPass(F_, cctx_); |
5097 | |
5098 | EXPECT_EQ(7, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind)); |
5099 | EXPECT_EQ(7, countNodeKind(F_, Kinded::Kind::ReluNodeKind)); |
5100 | |
5101 | // Check all splitted FCs. |
5102 | auto *concatNode = replacedMap[fc1]; |
5103 | for (unsigned i = 0; i < 7; ++i) { |
5104 | auto *fc = llvm::dyn_cast<FullyConnectedNode>(concatNode->getNthInput(i)); |
5105 | ASSERT_TRUE(fc); |
// 8 x 128 for the first 6 FCs and 8 x 98 for the last one.
5107 | if (i == 6) { |
5108 | EXPECT_TRUE(fc->getResult().dims().equals({8, 98})); |
5109 | EXPECT_TRUE(fc->getBias().dims().equals({98})); |
5110 | EXPECT_TRUE(fc->getWeights().dims().equals({3, 98})); |
5111 | } else { |
5112 | EXPECT_TRUE(fc->getResult().dims().equals({8, 128})); |
5113 | EXPECT_TRUE(fc->getBias().dims().equals({128})); |
5114 | EXPECT_TRUE(fc->getWeights().dims().equals({3, 128})); |
5115 | } |
5116 | } |
5117 | |
5118 | checkNumericalEquivalence(); |
5119 | } |
5120 | |
/// Test Splitting FC into multiple FCs, special case of splitting 140 outputs
/// 3 ways with an alignment of 64; should split into chunks of 64, 64, and 12.
5123 | TEST_F(GraphOptz, ParallelizeGraph_FC_ModelParallel_Split140by3) { |
5124 | auto *input = |
5125 | mod_.createPlaceholder(ElemKind::FloatTy, {8, 3}, "input" , false); |
5126 | bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0, |
5127 | mod_.getPRNG()); |
5128 | auto *weights1 = mod_.createConstant(ElemKind::FloatTy, {3, 140}, "weights" ); |
5129 | weights1->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
5130 | auto *bias1 = mod_.createConstant(ElemKind::FloatTy, {140}, "bias" ); |
5131 | bias1->getHandle().randomize(0.0, 0.5, mod_.getPRNG()); |
5132 | auto *output = |
5133 | mod_.createPlaceholder(ElemKind::FloatTy, {8, 140}, "output" , false); |
5134 | bindings_.allocate(output); |
5135 | |
5136 | auto *fc1 = F_->createFullyConnected("fc1" , input, weights1, bias1); |
5137 | auto *relu1 = F_->createRELU("relu1" , fc1); |
5138 | |
5139 | F_->createSave("save" , relu1, output); |
5140 | |
5141 | ::glow::optimize(F_, CompilationMode::Infer); |
5142 | |
5143 | // This is F_ but without the parallel transformation below. |
5144 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5145 | |
5146 | // Perform parallel transformation on F_. |
5147 | llvm::DenseMap<Node *, size_t> numChunks; |
5148 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5149 | numChunks[fc1] = 3; |
5150 | parOpts[fc1] = ParallelTransformKind::Model; |
5151 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5152 | ASSIGN_VALUE_OR_FAIL_TEST( |
5153 | replacedMap, |
5154 | ::glow::parallelizeOps(F_, numChunks, parOpts, /*numOfChunks*/ 1, |
5155 | /*modelParallelSplitAlignment*/ 64)); |
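// Expected chunking: ceil(140 / 3) = 47, aligned up to a multiple of 64 gives
// 64-wide chunks; two of them cover 128 outputs and the remaining 12 form the
// third chunk.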
5156 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5157 | runDCEPass(F_, cctx_); |
5158 | EXPECT_EQ(3, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind)); |
5159 | |
5160 | // Check all splitted FCs. |
5161 | auto *concatNode = replacedMap[fc1]; |
5162 | auto *fc_split0 = |
5163 | llvm::dyn_cast<FullyConnectedNode>(concatNode->getNthInput(0)); |
5164 | auto *fc_split1 = |
5165 | llvm::dyn_cast<FullyConnectedNode>(concatNode->getNthInput(1)); |
5166 | auto *fc_split2 = |
5167 | llvm::dyn_cast<FullyConnectedNode>(concatNode->getNthInput(2)); |
5168 | ASSERT_TRUE(fc_split0); |
5169 | ASSERT_TRUE(fc_split1); |
5170 | ASSERT_TRUE(fc_split2); |
5171 | EXPECT_TRUE(fc_split0->getResult().dims().equals({8, 64})); |
5172 | EXPECT_TRUE(fc_split0->getBias().dims().equals({64})); |
5173 | EXPECT_TRUE(fc_split0->getWeights().dims().equals({3, 64})); |
5174 | EXPECT_TRUE(fc_split1->getResult().dims().equals({8, 64})); |
5175 | EXPECT_TRUE(fc_split1->getBias().dims().equals({64})); |
5176 | EXPECT_TRUE(fc_split1->getWeights().dims().equals({3, 64})); |
5177 | EXPECT_TRUE(fc_split2->getResult().dims().equals({8, 12})); |
5178 | EXPECT_TRUE(fc_split2->getBias().dims().equals({12})); |
5179 | EXPECT_TRUE(fc_split2->getWeights().dims().equals({3, 12})); |
5180 | |
5181 | checkNumericalEquivalence(); |
5182 | } |
5183 | |
/// Test Splitting MatMul into multiple MatMuls (Data parallel).
5185 | TEST_F(GraphOptz, SplitMatMulIntoMultipleOps_Data) { |
5186 | auto *input1 = |
5187 | mod_.createPlaceholder(ElemKind::FloatTy, {12, 32}, "input1" , false); |
5188 | bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0, |
5189 | mod_.getPRNG()); |
5190 | auto *input2 = |
5191 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 32}, "input2" , false); |
5192 | bindings_.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0, |
5193 | mod_.getPRNG()); |
5194 | auto *output = |
5195 | mod_.createPlaceholder(ElemKind::FloatTy, {12, 32}, "output" , false); |
5196 | bindings_.allocate(output); |
5197 | |
5198 | auto *mm = F_->createMatMul("mm" , input1, input2); |
5199 | auto *save = F_->createSave("save" , mm, output); |
5200 | |
5201 | ::glow::optimize(F_, CompilationMode::Infer); |
5202 | |
5203 | // This is F_ but without the parallel transformation below. |
5204 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5205 | |
5206 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5207 | parOpts[mm] = ParallelTransformKind::Data; |
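// Data parallelism splits the MatMul along the result's first (batch)
// dimension, so only the 12 rows of the LHS get sliced into {1, 32} chunks;
// the RHS is shared by all chunks.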
5208 | |
5209 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5210 | ASSIGN_VALUE_OR_FAIL_TEST( |
5211 | replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), |
5212 | parOpts, 12)); |
5213 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5214 | runDCEPass(F_, cctx_); |
5215 | |
// 12 Slices from the LHS.
5217 | EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::SliceNodeKind)); |
5218 | |
5219 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
5220 | |
5221 | // 12 newly created MatMuls. |
5222 | EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::MatMulNodeKind)); |
5223 | |
5224 | auto *concatNode = llvm::dyn_cast<ConcatNode>(save->getInput()); |
5225 | ASSERT_TRUE(concatNode); |
// 12 MatMuls are connected to the concat node.
5227 | EXPECT_EQ(12, concatNode->getInputs().size()); |
5228 | |
5229 | for (unsigned i = 0; i < 12; ++i) { |
5230 | auto *mmInput = llvm::dyn_cast<MatMulNode>(concatNode->getNthInput(i)); |
5231 | ASSERT_TRUE(mmInput); |
5232 | EXPECT_TRUE(mmInput->getResult().dims().equals({1, 32})); |
5233 | } |
5234 | |
5235 | checkNumericalEquivalence(); |
5236 | } |
5237 | |
/// Test Splitting MatMul into multiple MatMuls (Model parallel).
5239 | TEST_F(GraphOptz, SplitMatMulIntoMultipleOps_Model) { |
5240 | auto *input1 = |
5241 | mod_.createPlaceholder(ElemKind::FloatTy, {12, 48}, "input1" , false); |
5242 | bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0, |
5243 | mod_.getPRNG()); |
5244 | auto *input2 = |
5245 | mod_.createPlaceholder(ElemKind::FloatTy, {48, 48}, "input2" , false); |
5246 | bindings_.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0, |
5247 | mod_.getPRNG()); |
5248 | auto *output = |
5249 | mod_.createPlaceholder(ElemKind::FloatTy, {12, 48}, "output" , false); |
5250 | bindings_.allocate(output); |
5251 | |
5252 | auto *mm = F_->createMatMul("mm" , input1, input2); |
5253 | auto *save = F_->createSave("save" , mm, output); |
5254 | |
5255 | ::glow::optimize(F_, CompilationMode::Infer); |
5256 | |
5257 | // This is F_ but without the parallel transformation below. |
5258 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5259 | |
5260 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5261 | parOpts[mm] = ParallelTransformKind::Model; |
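// Model parallelism splits the MatMul along the result's second dimension,
// so the 48 columns of the RHS are sliced into 12 chunks of 4; the LHS is
// shared by all chunks.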
5262 | |
5263 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5264 | ASSIGN_VALUE_OR_FAIL_TEST( |
5265 | replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), |
5266 | parOpts, 12)); |
5267 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5268 | runDCEPass(F_, cctx_); |
5269 | |
// 12 Slices from the RHS.
5271 | EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::SliceNodeKind)); |
5272 | |
5273 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
5274 | |
5275 | // 12 newly created MatMuls. |
5276 | EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::MatMulNodeKind)); |
5277 | |
5278 | auto *concatNode = llvm::dyn_cast<ConcatNode>(save->getInput()); |
5279 | ASSERT_TRUE(concatNode); |
// 12 MatMuls are connected to the concat node.
5281 | EXPECT_EQ(12, concatNode->getInputs().size()); |
5282 | |
5283 | for (unsigned i = 0; i < 12; ++i) { |
5284 | auto *mmInput = llvm::dyn_cast<MatMulNode>(concatNode->getNthInput(i)); |
5285 | ASSERT_TRUE(mmInput); |
5286 | EXPECT_TRUE(mmInput->getResult().dims().equals({12, 4})); |
5287 | } |
5288 | |
5289 | checkNumericalEquivalence(); |
5290 | } |
5291 | |
5292 | /// Test Splitting Add into multiple Adds. |
5293 | TEST_F(GraphOptz, ParallelizeGraph_Add) { |
5294 | auto *input1 = |
5295 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input1" , false); |
5296 | bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0, |
5297 | mod_.getPRNG()); |
5298 | auto *input2 = |
5299 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input2" , false); |
5300 | bindings_.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0, |
5301 | mod_.getPRNG()); |
5302 | auto *output = |
5303 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output" , false); |
5304 | bindings_.allocate(output); |
5305 | |
5306 | auto *add1 = F_->createAdd("add1" , input1, input2); |
5307 | auto *add2 = F_->createAdd("add2" , add1, add1); |
5308 | F_->createSave("save" , add2, output); |
5309 | |
5310 | ::glow::optimize(F_, CompilationMode::Infer); |
5311 | |
5312 | // This is F_ but without the parallel transformation below. |
5313 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5314 | |
5315 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5316 | parOpts[add1] = ParallelTransformKind::Data; |
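// Only add1 is parallelized: both of its operands are sliced into 12
// row-chunks, while add2 keeps consuming the single merged Concat result.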
5317 | |
5318 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5319 | ASSIGN_VALUE_OR_FAIL_TEST( |
5320 | replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), |
5321 | parOpts, 12)); |
5322 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5323 | runDCEPass(F_, cctx_); |
5324 | |
5325 | // We now have 12 Adds from add1, as well as the original add2 which is |
5326 | // unchanged. |
5327 | EXPECT_EQ(13, countNodeKind(F_, Kinded::Kind::AddNodeKind)); |
5328 | |
// Both inputs of each of the 12 Adds are sliced.
5330 | EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind)); |
5331 | |
5332 | // One concat to bring all of the parallelized sliced Adds together. |
5333 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
5334 | |
5335 | checkNumericalEquivalence(); |
5336 | } |
5337 | |
5338 | /// Test Splitting Add into multiple Adds along different axes. |
5339 | static void testParallelizeGraphAddModel(PlaceholderBindings &bindings, |
5340 | Module &mod, Function *F, |
5341 | Function *&optF, |
5342 | CompilationContext &cctx, |
5343 | ParallelTransformKind parKind) { |
5344 | auto *input1 = mod.createPlaceholder(ElemKind::FloatTy, {16, 17, 18, 19, 20}, |
5345 | "input1" , false); |
5346 | bindings.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0, |
5347 | mod.getPRNG()); |
5348 | auto *input2 = mod.createPlaceholder(ElemKind::FloatTy, {16, 17, 18, 19, 20}, |
5349 | "input2" , false); |
5350 | bindings.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0, |
5351 | mod.getPRNG()); |
5352 | auto *output = mod.createPlaceholder(ElemKind::FloatTy, {16, 17, 18, 19, 20}, |
5353 | "output" , false); |
5354 | bindings.allocate(output); |
5355 | |
5356 | auto *add1 = F->createAdd("add1" , input1, input2); |
5357 | auto *add2 = F->createAdd("add2" , add1, add1); |
5358 | F->createSave("save" , add2, output); |
5359 | |
5360 | ::glow::optimize(F, CompilationMode::Infer); |
5361 | |
// This is F but without the parallel transformation below.
5363 | optF = F->clone(F->getName().str() + "_optimized" ); |
5364 | |
5365 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5366 | parOpts[add1] = parKind; |
5367 | |
5368 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5369 | ASSIGN_VALUE_OR_FAIL_TEST( |
5370 | replacedMap, |
5371 | ::glow::parallelizeOps(F, llvm::DenseMap<Node *, size_t>(), parOpts, 12)); |
5372 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5373 | runDCEPass(F, cctx); |
5374 | |
5375 | // We now have 12 Adds from add1, as well as the original add2 which is |
5376 | // unchanged. |
5377 | EXPECT_EQ(13, countNodeKind(F, Kinded::Kind::AddNodeKind)); |
5378 | |
// Both inputs of each of the 12 Adds are sliced.
5380 | EXPECT_EQ(24, countNodeKind(F, Kinded::Kind::SliceNodeKind)); |
5381 | |
5382 | // One concat to bring all of the parallelized sliced Adds together. |
5383 | EXPECT_EQ(1, countNodeKind(F, Kinded::Kind::ConcatNodeKind)); |
5384 | } |
5385 | |
5386 | TEST_F(GraphOptz, ParallelizeGraph_Add_Model_Axis1) { |
5387 | testParallelizeGraphAddModel(bindings_, mod_, F_, optimizedF_, cctx_, |
5388 | ParallelTransformKind::Model_Axis1); |
5389 | checkNumericalEquivalence(0.f); |
5390 | } |
5391 | |
5392 | TEST_F(GraphOptz, ParallelizeGraph_Add_Model_Axis3) { |
5393 | testParallelizeGraphAddModel(bindings_, mod_, F_, optimizedF_, cctx_, |
5394 | ParallelTransformKind::Model_Axis3); |
5395 | checkNumericalEquivalence(0.f); |
5396 | } |
5397 | |
5398 | TEST_F(GraphOptz, ParallelizeGraph_Add_Model_Axis4) { |
5399 | testParallelizeGraphAddModel(bindings_, mod_, F_, optimizedF_, cctx_, |
5400 | ParallelTransformKind::Model_Axis4); |
5401 | checkNumericalEquivalence(0.f); |
5402 | } |
5403 | |
5404 | /// Test Splitting Sub into multiple Subs. |
5405 | TEST_F(GraphOptz, ParallelizeGraph_Sub) { |
5406 | auto *input1 = |
5407 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input1" , false); |
5408 | bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0, |
5409 | mod_.getPRNG()); |
5410 | auto *input2 = |
5411 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input2" , false); |
5412 | bindings_.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0, |
5413 | mod_.getPRNG()); |
5414 | auto *output = |
5415 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output" , false); |
5416 | bindings_.allocate(output); |
5417 | |
5418 | auto *sub1 = F_->createSub("sub1" , input1, input2); |
5419 | auto *sub2 = F_->createSub("sub2" , sub1, sub1); |
5420 | F_->createSave("save" , sub2, output); |
5421 | |
5422 | ::glow::optimize(F_, CompilationMode::Infer); |
5423 | |
5424 | // This is F_ but without the parallel transformation below. |
5425 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5426 | |
5427 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5428 | parOpts[sub1] = ParallelTransformKind::Data; |
5429 | |
5430 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5431 | ASSIGN_VALUE_OR_FAIL_TEST( |
5432 | replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), |
5433 | parOpts, 12)); |
5434 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5435 | runDCEPass(F_, cctx_); |
5436 | |
5437 | // We now have 12 Subs from sub1, as well as the original sub2 which is |
5438 | // unchanged. |
5439 | EXPECT_EQ(13, countNodeKind(F_, Kinded::Kind::SubNodeKind)); |
5440 | |
// Both inputs of each of the 12 Subs are sliced.
5442 | EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind)); |
5443 | |
5444 | // One concat to bring all of the parallelized sliced Subs together. |
5445 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
5446 | |
5447 | checkNumericalEquivalence(); |
5448 | } |
5449 | |
5450 | /// Test Splitting Pow into multiple Pows. |
5451 | TEST_F(GraphOptz, ParallelizeGraph_Pow) { |
5452 | auto *input1 = |
5453 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input1" , false); |
5454 | bindings_.allocate(input1)->getHandle<float>().randomize(1.0, 2.0, |
5455 | mod_.getPRNG()); |
5456 | auto *input2 = |
5457 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input2" , false); |
5458 | bindings_.allocate(input2)->getHandle<float>().randomize(0.0, 5.0, |
5459 | mod_.getPRNG()); |
5460 | auto *output = |
5461 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output" , false); |
5462 | bindings_.allocate(output); |
5463 | |
5464 | auto *Pow1 = F_->createPow("Pow1" , input1, input2); |
5465 | F_->createSave("save" , Pow1, output); |
5466 | |
5467 | ::glow::optimize(F_, CompilationMode::Infer); |
5468 | |
5469 | // This is F_ but without the parallel transformation below. |
5470 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5471 | |
5472 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5473 | parOpts[Pow1] = ParallelTransformKind::Data; |
5474 | |
5475 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5476 | ASSIGN_VALUE_OR_FAIL_TEST( |
5477 | replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), |
5478 | parOpts, 12)); |
5479 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5480 | runDCEPass(F_, cctx_); |
5481 | |
// We now have 12 Pows from Pow1.
5483 | EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::PowNodeKind)); |
5484 | |
// Both inputs of each of the 12 Pows are sliced.
5486 | EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind)); |
5487 | |
5488 | // One concat to bring all of the parallelized sliced Pows together. |
5489 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
5490 | |
5491 | checkNumericalEquivalence(); |
5492 | } |
5493 | |
5494 | /// Test Splitting Max into multiple Maxs. |
5495 | TEST_F(GraphOptz, ParallelizeGraph_Max) { |
5496 | auto *input1 = |
5497 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input1" , false); |
5498 | bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0, |
5499 | mod_.getPRNG()); |
5500 | auto *input2 = |
5501 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input2" , false); |
5502 | bindings_.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0, |
5503 | mod_.getPRNG()); |
5504 | auto *output = |
5505 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output" , false); |
5506 | bindings_.allocate(output); |
5507 | |
5508 | auto *Max1 = F_->createMax("Max1" , input1, input2); |
5509 | F_->createSave("save" , Max1, output); |
5510 | |
5511 | ::glow::optimize(F_, CompilationMode::Infer); |
5512 | |
5513 | // This is F_ but without the parallel transformation below. |
5514 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5515 | |
5516 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5517 | parOpts[Max1] = ParallelTransformKind::Data; |
5518 | |
5519 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5520 | ASSIGN_VALUE_OR_FAIL_TEST( |
5521 | replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), |
5522 | parOpts, 12)); |
5523 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5524 | runDCEPass(F_, cctx_); |
5525 | |
// We now have 12 Maxs from Max1.
5527 | EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::MaxNodeKind)); |
5528 | |
// Both inputs of each of the 12 Maxs are sliced.
5530 | EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind)); |
5531 | |
5532 | // One concat to bring all of the parallelized sliced Maxs together. |
5533 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
5534 | |
5535 | checkNumericalEquivalence(); |
5536 | } |
5537 | |
5538 | /// Test Splitting Min into multiple Mins. |
5539 | TEST_F(GraphOptz, ParallelizeGraph_Min) { |
5540 | auto *input1 = |
5541 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input1" , false); |
5542 | bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0, |
5543 | mod_.getPRNG()); |
5544 | auto *input2 = |
5545 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input2" , false); |
5546 | bindings_.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0, |
5547 | mod_.getPRNG()); |
5548 | auto *output = |
5549 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output" , false); |
5550 | bindings_.allocate(output); |
5551 | |
5552 | auto *Min1 = F_->createMin("Min1" , input1, input2); |
5553 | F_->createSave("save" , Min1, output); |
5554 | |
5555 | ::glow::optimize(F_, CompilationMode::Infer); |
5556 | |
5557 | // This is F_ but without the parallel transformation below. |
5558 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5559 | |
5560 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5561 | parOpts[Min1] = ParallelTransformKind::Data; |
5562 | |
5563 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5564 | ASSIGN_VALUE_OR_FAIL_TEST( |
5565 | replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), |
5566 | parOpts, 12)); |
5567 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5568 | runDCEPass(F_, cctx_); |
5569 | |
// We now have 12 Mins from Min1.
5571 | EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::MinNodeKind)); |
5572 | |
// Both inputs of each of the 12 Mins are sliced.
5574 | EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind)); |
5575 | |
5576 | // One concat to bring all of the parallelized sliced Mins together. |
5577 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
5578 | |
5579 | checkNumericalEquivalence(); |
5580 | } |
5581 | |
5582 | /// Test Splitting BatchedReduceMean into multiple BatchedReduceMeans. |
5583 | TEST_F(GraphOptz, ParallelizeGraph_BatchedReduceMean) { |
5584 | auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {32, 16, 2048}, |
5585 | "input1" , false); |
5586 | bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0, |
5587 | mod_.getPRNG()); |
5588 | auto *output = |
5589 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output" , false); |
5590 | bindings_.allocate(output); |
5591 | |
5592 | auto *BatchedReduceMean1 = |
5593 | F_->createBatchedReduceMean("BatchedReduceMean1" , input1, {1}); |
5594 | F_->createSave("save" , BatchedReduceMean1, output); |
5595 | |
5596 | ::glow::optimize(F_, CompilationMode::Infer); |
5597 | |
5598 | // This is F_ but without the parallel transformation below. |
5599 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5600 | |
5601 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5602 | parOpts[BatchedReduceMean1] = ParallelTransformKind::Data; |
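// The reduction is over dimension 1, while data parallelism slices dimension
// 0, so each chunk can reduce its own rows independently.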
5603 | |
5604 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5605 | ASSIGN_VALUE_OR_FAIL_TEST( |
5606 | replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), |
5607 | parOpts, 12)); |
5608 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5609 | runDCEPass(F_, cctx_); |
5610 | |
// We now have 12 BatchedReduceMeans from BatchedReduceMean1.
5612 | EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::BatchedReduceMeanNodeKind)); |
5613 | |
// The input of each of the 12 BatchedReduceMeans is sliced.
5615 | EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::SliceNodeKind)); |
5616 | |
5617 | // One concat to bring all of the parallelized sliced BatchedReduceMeans |
5618 | // together. |
5619 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
5620 | |
5621 | checkNumericalEquivalence(); |
5622 | } |
5623 | |
/// Test Splitting BatchedReduceMean into multiple BatchedReduceMeans.
/// Failure case where the reduction is over the first dimension.
5626 | TEST_F(GraphOptz, ParallelizeGraph_BatchedReduceMean_failure) { |
5627 | auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {32, 16, 2048}, |
5628 | "input1" , false); |
5629 | bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0, |
5630 | mod_.getPRNG()); |
5631 | auto *output = |
5632 | mod_.createPlaceholder(ElemKind::FloatTy, {16, 2048}, "output" , false); |
5633 | bindings_.allocate(output); |
5634 | |
5635 | auto *BatchedReduceMean1 = |
5636 | F_->createBatchedReduceMean("BatchedReduceMean1" , input1, {0}); |
5637 | F_->createSave("save" , BatchedReduceMean1, output); |
5638 | |
5639 | ::glow::optimize(F_, CompilationMode::Infer); |
5640 | |
5641 | // This is F_ but without the parallel transformation below. |
5642 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5643 | |
5644 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5645 | parOpts[BatchedReduceMean1] = ParallelTransformKind::Data; |
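// Here the reduction is over dimension 0, the same axis data parallelism
// would slice, so parallelizeOps is expected to leave the node untouched.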
5646 | |
5647 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5648 | ASSIGN_VALUE_OR_FAIL_TEST( |
5649 | replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), |
5650 | parOpts, 12)); |
5651 | EXPECT_EQ(replacedMap.size(), 0); // Nothing changes |
5652 | runDCEPass(F_, cctx_); |
5653 | |
// We now have only 1 BatchedReduceMean since the parallelization was skipped.
EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::BatchedReduceMeanNodeKind));

// No Concats.
5658 | EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
5659 | |
5660 | checkNumericalEquivalence(); |
5661 | } |
5662 | |
5663 | /// Test Splitting Transpose into multiple Transposes. |
5664 | TEST_F(GraphOptz, ParallelizeGraph_Transpose) { |
5665 | auto *input = |
5666 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 151, 64}, "input" , false); |
5667 | bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0, |
5668 | mod_.getPRNG()); |
5669 | auto *output = |
5670 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 64, 151}, "output" , false); |
5671 | bindings_.allocate(output); |
5672 | |
5673 | auto *trans1 = F_->createTranspose("trans1" , input, {0, 2, 1}); |
5674 | F_->createSave("save" , trans1, output); |
5675 | |
5676 | ::glow::optimize(F_, CompilationMode::Infer); |
5677 | |
5678 | // This is F_ but without the parallel transformation below. |
5679 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5680 | |
5681 | llvm::DenseMap<Node *, size_t> numChunks; |
5682 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5683 | numChunks[trans1] = 2; |
5684 | parOpts[trans1] = ParallelTransformKind::Data; |
5685 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5686 | ASSIGN_VALUE_OR_FAIL_TEST(replacedMap, |
5687 | ::glow::parallelizeOps(F_, numChunks, parOpts)); |
5688 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5689 | |
5690 | runDCEPass(F_, cctx_); |
5691 | |
5692 | EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::TransposeNodeKind)); |
5693 | |
5694 | checkNumericalEquivalence(); |
5695 | } |
5696 | |
5697 | /// Test Splitting Transpose into multiple Transposes. |
5698 | TEST_F(GraphOptz, ParallelizeGraph_Transpose3D_210) { |
5699 | auto *input = |
5700 | mod_.createPlaceholder(ElemKind::FloatTy, {4, 15, 23}, "input" , false); |
5701 | bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0, |
5702 | mod_.getPRNG()); |
5703 | auto *output = |
5704 | mod_.createPlaceholder(ElemKind::FloatTy, {23, 15, 4}, "output" , false); |
5705 | bindings_.allocate(output); |
5706 | |
5707 | auto *trans1 = F_->createTranspose("trans1" , input, {2, 1, 0}); |
5708 | F_->createSave("save" , trans1, output); |
5709 | |
5710 | ::glow::optimize(F_, CompilationMode::Infer); |
5711 | |
5712 | // This is F_ but without the parallel transformation below. |
5713 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5714 | |
5715 | llvm::DenseMap<Node *, size_t> numChunks; |
5716 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5717 | numChunks[trans1] = 8; |
5718 | parOpts[trans1] = ParallelTransformKind::Data; |
5719 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5720 | ASSIGN_VALUE_OR_FAIL_TEST(replacedMap, |
5721 | ::glow::parallelizeOps(F_, numChunks, parOpts)); |
5722 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5723 | |
5724 | runDCEPass(F_, cctx_); |
5725 | |
5726 | EXPECT_EQ(8, countNodeKind(F_, Kinded::Kind::TransposeNodeKind)); |
5727 | |
5728 | checkNumericalEquivalence(); |
5729 | } |
5730 | |
5731 | /// Test Splitting Transpose into multiple Transposes. |
5732 | TEST_F(GraphOptz, ParallelizeGraph_Transpose3D_120) { |
5733 | auto *input = |
5734 | mod_.createPlaceholder(ElemKind::FloatTy, {15, 8, 23}, "input" , false); |
5735 | bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0, |
5736 | mod_.getPRNG()); |
5737 | auto *output = |
5738 | mod_.createPlaceholder(ElemKind::FloatTy, {8, 23, 15}, "output" , false); |
5739 | bindings_.allocate(output); |
5740 | |
5741 | auto *trans1 = F_->createTranspose("trans1" , input, {1, 2, 0}); |
5742 | F_->createSave("save" , trans1, output); |
5743 | |
5744 | ::glow::optimize(F_, CompilationMode::Infer); |
5745 | |
5746 | // This is F_ but without the parallel transformation below. |
5747 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5748 | |
5749 | llvm::DenseMap<Node *, size_t> numChunks; |
5750 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5751 | numChunks[trans1] = 8; |
5752 | parOpts[trans1] = ParallelTransformKind::Data; |
5753 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5754 | ASSIGN_VALUE_OR_FAIL_TEST(replacedMap, |
5755 | ::glow::parallelizeOps(F_, numChunks, parOpts)); |
5756 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5757 | |
5758 | runDCEPass(F_, cctx_); |
5759 | |
5760 | EXPECT_EQ(8, countNodeKind(F_, Kinded::Kind::TransposeNodeKind)); |
5761 | |
5762 | checkNumericalEquivalence(); |
5763 | } |
5764 | |
/// Test Splitting Select into multiple Selects (Data parallel).
5766 | TEST_F(GraphOptz, ParallelizeGraphData_Select) { |
5767 | auto *sel1_lhs = |
5768 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "sel1_lhs" , false); |
5769 | bindings_.allocate(sel1_lhs)->getHandle<float>().randomize(-1.0, 1.0, |
5770 | mod_.getPRNG()); |
5771 | auto *sel1_rhs = |
5772 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "sel1_rhs" , false); |
5773 | bindings_.allocate(sel1_rhs)->getHandle<float>().randomize(-1.0, 1.0, |
5774 | mod_.getPRNG()); |
5775 | auto *sel1_cond = |
5776 | mod_.createPlaceholder(ElemKind::BoolTy, {32, 2048}, "sel1_cond" , false); |
5777 | bindings_.allocate(sel1_cond)->getHandle<bool>().randomize(0, 1, |
5778 | mod_.getPRNG()); |
5779 | auto *sel2_rhs = |
5780 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "sel2_rhs" , false); |
5781 | bindings_.allocate(sel2_rhs)->getHandle<float>().randomize(-1.0, 1.0, |
5782 | mod_.getPRNG()); |
5783 | auto *sel2_cond = |
5784 | mod_.createPlaceholder(ElemKind::BoolTy, {32, 2048}, "sel2_cond" , false); |
5785 | bindings_.allocate(sel2_cond)->getHandle<bool>().randomize(0, 1, |
5786 | mod_.getPRNG()); |
5787 | auto *output = |
5788 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output" , false); |
5789 | bindings_.allocate(output); |
5790 | |
5791 | auto *sel1 = F_->createSelect("sel1" , sel1_cond, sel1_lhs, sel1_rhs); |
5792 | auto *sel2 = F_->createSelect("sel2" , sel2_cond, sel1, sel2_rhs); |
5793 | F_->createSave("save" , sel2, output); |
5794 | |
5795 | ::glow::optimize(F_, CompilationMode::Infer); |
5796 | |
5797 | // This is F_ but without the parallel transformation below. |
5798 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5799 | |
5800 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5801 | parOpts[sel1] = ParallelTransformKind::Data; |
5802 | |
5803 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5804 | ASSIGN_VALUE_OR_FAIL_TEST( |
5805 | replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), |
5806 | parOpts, 12)); |
5807 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5808 | runDCEPass(F_, cctx_); |
5809 | |
5810 | // We now have 12 Selects from sel1, as well as the original sel2 which is |
5811 | // unchanged. |
5812 | EXPECT_EQ(13, countNodeKind(F_, Kinded::Kind::SelectNodeKind)); |
5813 | |
// All three inputs of each of the 12 Selects are sliced.
5815 | EXPECT_EQ(36, countNodeKind(F_, Kinded::Kind::SliceNodeKind)); |
5816 | |
// One concat to bring all of the parallelized sliced Selects together.
5818 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
5819 | |
5820 | checkNumericalEquivalence(); |
5821 | } |
5822 | |
/// Test Splitting Select into multiple Selects (Model parallel).
5824 | TEST_F(GraphOptz, ParallelizeGraphModel_Select) { |
5825 | auto *sel1_lhs = |
5826 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "sel1_lhs" , false); |
5827 | bindings_.allocate(sel1_lhs)->getHandle<float>().randomize(-1.0, 1.0, |
5828 | mod_.getPRNG()); |
5829 | auto *sel1_rhs = |
5830 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "sel1_rhs" , false); |
5831 | bindings_.allocate(sel1_rhs)->getHandle<float>().randomize(-1.0, 1.0, |
5832 | mod_.getPRNG()); |
5833 | auto *sel1_cond = |
5834 | mod_.createPlaceholder(ElemKind::BoolTy, {32, 2048}, "sel1_cond" , false); |
5835 | bindings_.allocate(sel1_cond)->getHandle<bool>().randomize(0, 1, |
5836 | mod_.getPRNG()); |
5837 | auto *sel2_rhs = |
5838 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "sel2_rhs" , false); |
5839 | bindings_.allocate(sel2_rhs)->getHandle<float>().randomize(-1.0, 1.0, |
5840 | mod_.getPRNG()); |
5841 | auto *sel2_cond = |
5842 | mod_.createPlaceholder(ElemKind::BoolTy, {32, 2048}, "sel2_cond" , false); |
5843 | bindings_.allocate(sel2_cond)->getHandle<bool>().randomize(0, 1, |
5844 | mod_.getPRNG()); |
5845 | auto *output = |
5846 | mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output" , false); |
5847 | bindings_.allocate(output); |
5848 | |
5849 | auto *sel1 = F_->createSelect("sel1" , sel1_cond, sel1_lhs, sel1_rhs); |
5850 | auto *sel2 = F_->createSelect("sel2" , sel2_cond, sel1, sel2_rhs); |
5851 | F_->createSave("save" , sel2, output); |
5852 | |
5853 | ::glow::optimize(F_, CompilationMode::Infer); |
5854 | |
5855 | // This is F_ but without the parallel transformation below. |
5856 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5857 | |
5858 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5859 | parOpts[sel1] = ParallelTransformKind::Model; |
5860 | |
5861 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5862 | ASSIGN_VALUE_OR_FAIL_TEST( |
5863 | replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), |
5864 | parOpts, 12)); |
5865 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5866 | runDCEPass(F_, cctx_); |
5867 | |
5868 | // We now have 12 Selects from sel1, as well as the original sel2 which is |
5869 | // unchanged. |
5870 | EXPECT_EQ(13, countNodeKind(F_, Kinded::Kind::SelectNodeKind)); |
5871 | |
// All three inputs of each of the 12 Selects are sliced.
5873 | EXPECT_EQ(36, countNodeKind(F_, Kinded::Kind::SliceNodeKind)); |
5874 | |
// One concat to bring all of the parallelized sliced Selects together.
5876 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
5877 | |
5878 | checkNumericalEquivalence(); |
5879 | } |
5880 | |
5881 | /// Test Splitting Reshape into multiple Reshapes. |
5882 | TEST_F(GraphOptz, ParallelizeData_Reshape) { |
5883 | auto *input1 = |
5884 | mod_.createPlaceholder(ElemKind::FloatTy, {3, 64}, "input1" , false); |
5885 | bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0, |
5886 | mod_.getPRNG()); |
5887 | auto *output = |
5888 | mod_.createPlaceholder(ElemKind::FloatTy, {3, 8, 8}, "output" , false); |
5889 | bindings_.allocate(output); |
5890 | |
5891 | auto *rs = F_->createReshape("reshape1" , input1, {3, 8, 8}); |
5892 | F_->createSave("save" , rs, output); |
5893 | |
5894 | ::glow::optimize(F_, CompilationMode::Infer); |
5895 | |
5896 | // This is F_ but without the parallel transformation below. |
5897 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5898 | |
5899 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5900 | parOpts[rs] = ParallelTransformKind::Data; |
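// The reshape preserves the batch dimension (3), so each of the 3 chunks can
// reshape one {1, 64} slice into {1, 8, 8} independently.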
5901 | |
5902 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5903 | ASSIGN_VALUE_OR_FAIL_TEST( |
5904 | replacedMap, |
5905 | ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3)); |
5906 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5907 | runDCEPass(F_, cctx_); |
5908 | |
// We now have 3 Reshapes.
5910 | EXPECT_EQ(3, countNodeKind(F_, Kinded::Kind::ReshapeNodeKind)); |
5911 | |
5912 | // One concat to bring all of the parallelized sliced Reshapes together. |
5913 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
5914 | |
5915 | checkNumericalEquivalence(); |
5916 | } |
5917 | |
/// Test Splitting Reshape into multiple Reshapes when the batch dimension
/// changes. Splitting is not allowed when the input or output batch dimension
/// is not divisible by the number of parallel chunks.
5921 | TEST_F(GraphOptz, ParallelizeData_Reshape_badcase) { |
5922 | auto *input1 = |
5923 | mod_.createPlaceholder(ElemKind::FloatTy, {4, 48}, "input1" , false); |
5924 | bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0, |
5925 | mod_.getPRNG()); |
5926 | auto *output = |
5927 | mod_.createPlaceholder(ElemKind::FloatTy, {24, 8}, "output" , false); |
5928 | bindings_.allocate(output); |
5929 | |
5930 | auto *rs = F_->createReshape("reshape1" , input1, {24, 8}); |
5931 | F_->createSave("save" , rs, output); |
5932 | |
5933 | ::glow::optimize(F_, CompilationMode::Infer); |
5934 | |
5935 | // This is F_ but without the parallel transformation below. |
5936 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5937 | |
5938 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5939 | parOpts[rs] = ParallelTransformKind::Data; |
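// The reshape changes the batch dimension from 4 to 24, and the input batch
// (4) is not divisible by the 3 requested chunks, so no split should happen.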
5940 | |
5941 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5942 | ASSIGN_VALUE_OR_FAIL_TEST( |
5943 | replacedMap, |
5944 | ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3)); |
5945 | EXPECT_EQ(replacedMap.size(), 0); // Nothing gets replaced |
5946 | runDCEPass(F_, cctx_); |
5947 | |
// We still have only 1 Reshape, as nothing should have been split.
5949 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ReshapeNodeKind)); |
5950 | |
5951 | checkNumericalEquivalence(); |
5952 | } |
5953 | |
5954 | /// Test Splitting AdaptiveAvgPool into multiple AdaptiveAvgPools. |
5955 | TEST_F(GraphOptz, ParallelizeData_AdaptiveAvgPool) { |
5956 | auto *input1 = |
5957 | mod_.createPlaceholder(ElemKind::FloatTy, {3, 5, 5, 8}, "input1" , false); |
5958 | bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0, |
5959 | mod_.getPRNG()); |
5960 | auto *output = |
5961 | mod_.createPlaceholder(ElemKind::FloatTy, {3, 1, 1, 8}, "output" , false); |
5962 | bindings_.allocate(output); |
5963 | |
5964 | auto outTy = mod_.uniqueType(ElemKind::FloatTy, {3, 1, 1, 8}); |
5965 | |
5966 | auto *aap = F_->createAdaptiveAvgPool("AdaptiveAvgPool1" , input1, outTy); |
5967 | F_->createSave("save" , aap, output); |
5968 | |
5969 | ::glow::optimize(F_, CompilationMode::Infer); |
5970 | |
5971 | // This is F_ but without the parallel transformation below. |
5972 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
5973 | |
5974 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
5975 | parOpts[aap] = ParallelTransformKind::Data; |
5976 | |
5977 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
5978 | ASSIGN_VALUE_OR_FAIL_TEST( |
5979 | replacedMap, |
5980 | ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3)); |
5981 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
5982 | runDCEPass(F_, cctx_); |
5983 | |
// We now have 3 AdaptiveAvgPools.
5985 | EXPECT_EQ(3, countNodeKind(F_, Kinded::Kind::AdaptiveAvgPoolNodeKind)); |
5986 | |
5987 | // One concat to bring all of the parallelized sliced AdaptiveAvgPools |
5988 | // together. |
5989 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
5990 | |
5991 | checkNumericalEquivalence(); |
5992 | } |
5993 | |
5994 | /// Test Splitting RoIAlign into multiple RoIAligns. |
5995 | TEST_F(GraphOptz, ParallelizeData_RoIAlign) { |
5996 | auto *input1 = |
5997 | mod_.createPlaceholder(ElemKind::FloatTy, {4, 5, 5, 8}, "input1" , false); |
5998 | bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0, |
5999 | mod_.getPRNG()); |
6000 | auto *boxes = mod_.createPlaceholder(ElemKind::FloatTy, {6, 4}, "roi" , false); |
6001 | bindings_.allocate(boxes)->getHandle<float>() = { |
6002 | 0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3, 0, 0, 3, 3}; |
auto *batchIndices =
mod_.createPlaceholder(ElemKind::Int64ITy, {6}, "batchIndices" , false);
6005 | bindings_.allocate(batchIndices) |
6006 | ->getHandle<int64_t>() |
6007 | .randomize(0, 3, mod_.getPRNG()); |
6008 | |
6009 | auto *output = |
6010 | mod_.createPlaceholder(ElemKind::FloatTy, {6, 1, 1, 8}, "output" , false); |
6011 | bindings_.allocate(output); |
6012 | |
6013 | auto *aap = F_->createROIAlign("ROIAlign" , input1, boxes, batchIndices, 1, 1, |
6014 | 0, 1, false); |
6015 | F_->createSave("save" , aap, output); |
6016 | |
6017 | ::glow::optimize(F_, CompilationMode::Infer); |
6018 | |
6019 | // This is F_ but without the parallel transformation below. |
6020 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
6021 | |
6022 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
6023 | parOpts[aap] = ParallelTransformKind::Data; |
6024 | |
6025 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
6026 | ASSIGN_VALUE_OR_FAIL_TEST( |
6027 | replacedMap, |
6028 | ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3)); |
6029 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
6030 | runDCEPass(F_, cctx_); |
6031 | |
// We now have 3 RoIAligns.
6033 | EXPECT_EQ(3, countNodeKind(F_, Kinded::Kind::ROIAlignNodeKind)); |
6034 | |
6035 | // One concat to bring all of the parallelized sliced RoIAligns |
6036 | // together. |
6037 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
6038 | |
6039 | checkNumericalEquivalence(); |
6040 | } |
6041 | |
6042 | /// Test Splitting MaxPool into multiple MaxPools. |
6043 | TEST_F(GraphOptz, ParallelizeData_MaxPool) { |
6044 | auto *input1 = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5, 5, 8}, 1.0, 0, |
6045 | "input1" , false); |
6046 | bindings_.allocate(input1)->getHandle<int8_t>().randomize(-1.0, 1.0, |
6047 | mod_.getPRNG()); |
6048 | |
6049 | auto *maxp = F_->createMaxPool("MaxPool1" , input1, 5, 1, 0); |
6050 | F_->createSave("save" , maxp->getResult()); |
6051 | |
6052 | ::glow::optimize(F_, CompilationMode::Infer); |
6053 | |
6054 | // This is F_ but without the parallel transformation below. |
6055 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
6056 | |
6057 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
6058 | parOpts[maxp] = ParallelTransformKind::Data; |
6059 | |
6060 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
6061 | ASSIGN_VALUE_OR_FAIL_TEST( |
6062 | replacedMap, |
6063 | ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3)); |
6064 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
6065 | runDCEPass(F_, cctx_); |
6066 | |
// We now have 3 MaxPools.
6068 | EXPECT_EQ(3, countNodeKind(F_, Kinded::Kind::MaxPoolNodeKind)); |
6069 | |
6070 | // One concat to bring all of the parallelized sliced MaxPools |
6071 | // together. |
6072 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
6073 | |
6074 | checkNumericalEquivalence(); |
6075 | } |
6076 | |
6077 | /// Test Splitting ChannelwiseQuantizedConvolution into multiple |
6078 | /// ChannelwiseQuantizedConvolutions. |
6079 | TEST_F(GraphOptz, ParallelizeData_ChannelwiseQuantizedConvolution) { |
6080 | auto *input1 = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5, 5, 8}, 1.0, 0, |
6081 | "input1" , false); |
6082 | bindings_.allocate(input1)->getHandle<int8_t>().randomize(-4, 4, |
6083 | mod_.getPRNG()); |
6084 | auto *filter = |
6085 | mod_.createConstant(ElemKind::FloatTy, {12, 1, 1, 8}, "weights" ); |
6086 | filter->getPayloadMutable().getHandle().randomize(-10, 10, mod_.getPRNG()); |
6087 | auto *bias = mod_.createConstant(ElemKind::FloatTy, {12}, "bias" ); |
6088 | bias->getPayloadMutable().getHandle().randomize(-1, 1, mod_.getPRNG()); |
6089 | auto *output = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5, 5, 12}, 1.0, |
6090 | 0, "output" , false); |
6091 | bindings_.allocate(output); |
6092 | auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {3, 5, 5, 12}, 1.0, 0); |
6093 | |
6094 | auto *cqc = F_->createChannelwiseQuantizedConv( |
6095 | "ChannelwiseQuantizedConvolution1" , input1, filter, bias, nullptr, |
6096 | nullptr, nullptr, nullptr, outTy, {1, 1}, {1, 1}, {0, 0, 0, 0}, 1); |
6097 | F_->createSave("save" , cqc, output); |
6098 | |
6099 | ::glow::optimize(F_, CompilationMode::Infer); |
6100 | |
6101 | // This is F_ but without the parallel transformation below. |
6102 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
6103 | |
6104 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
6105 | parOpts[cqc] = ParallelTransformKind::Data; |
6106 | |
6107 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
6108 | ASSIGN_VALUE_OR_FAIL_TEST( |
6109 | replacedMap, |
6110 | ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3)); |
6111 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
6112 | runDCEPass(F_, cctx_); |
6113 | |
// We now have 3 ChannelwiseQuantizedConvolutions.
6115 | EXPECT_EQ(3, countNodeKind( |
6116 | F_, Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind)); |
6117 | |
6118 | // One concat to bring all of the parallelized sliced |
6119 | // ChannelwiseQuantizedConvolutions together. |
6120 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
6121 | |
6122 | checkNumericalEquivalence(); |
6123 | } |
6124 | |
6125 | /// Test Splitting Convolution into multiple Convolutions. |
6126 | TEST_F(GraphOptz, ParallelizeData_Convolution) { |
6127 | auto *input1 = |
6128 | mod_.createPlaceholder(ElemKind::FloatTy, {3, 5, 5, 4}, "input1" , false); |
6129 | bindings_.allocate(input1)->getHandle<float>().randomize(-1, 1, |
6130 | mod_.getPRNG()); |
6131 | auto *filter = |
6132 | mod_.createConstant(ElemKind::FloatTy, {6, 1, 1, 2}, "weights" ); |
6133 | filter->getPayloadMutable().getHandle().randomize(-1, 1, mod_.getPRNG()); |
6134 | auto *bias = mod_.createConstant(ElemKind::FloatTy, {6}, "bias" ); |
6135 | bias->getPayloadMutable().getHandle().randomize(-.1, .1, mod_.getPRNG()); |
6136 | auto *output = |
6137 | mod_.createPlaceholder(ElemKind::FloatTy, {3, 5, 5, 6}, "output" , false); |
6138 | bindings_.allocate(output); |
6139 | auto outTy = mod_.uniqueType(ElemKind::FloatTy, {3, 5, 5, 6}); |
6140 | |
6141 | auto *conv = |
6142 | F_->createConv("Convolution1" , input1, filter, bias, outTy, 1, 1, 0, 2); |
6143 | F_->createSave("save" , conv, output); |
6144 | |
6145 | ::glow::optimize(F_, CompilationMode::Infer); |
6146 | |
6147 | // This is F_ but without the parallel transformation below. |
6148 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
6149 | |
6150 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
6151 | parOpts[conv] = ParallelTransformKind::Data; |
6152 | |
6153 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
6154 | ASSIGN_VALUE_OR_FAIL_TEST( |
6155 | replacedMap, |
6156 | ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3)); |
6157 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
6158 | runDCEPass(F_, cctx_); |
6159 | |
// We now have 3 Convolutions.
6161 | EXPECT_EQ(3, countNodeKind(F_, Kinded::Kind::ConvolutionNodeKind)); |
6162 | |
6163 | // One concat to bring all of the parallelized sliced |
6164 | // Convolutions together. |
6165 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
6166 | |
6167 | checkNumericalEquivalence(); |
6168 | } |
6169 | |
6170 | /// Test Splitting RowwiseQuantizedFullyConnected into multiple |
6171 | /// RowwiseQuantizedFullyConnected nodes. |
6172 | TEST_F(GraphOptz, ParallelizeData_RowwiseQuantizedFullyConnected) { |
6173 | auto *input1 = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 8}, 1.0, 0, |
6174 | "input1" , false); |
6175 | bindings_.allocate(input1)->getHandle<int8_t>().randomize(-4, 4, |
6176 | mod_.getPRNG()); |
6177 | auto *weights = |
6178 | mod_.createConstant(ElemKind::Int8QTy, {12, 8}, 1.0, 0, "weights" ); |
6179 | weights->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127, |
6180 | mod_.getPRNG()); |
6181 | auto *scales = mod_.createConstant(ElemKind::FloatTy, {12}, "scales" ); |
6182 | scales->getPayloadMutable().getHandle().randomize(0.01, 0.1, mod_.getPRNG()); |
6183 | auto *offsets = mod_.createConstant(ElemKind::Int32ITy, {12}, "offsets" ); |
6184 | offsets->getPayloadMutable().getHandle<int32_t>().randomize(0, 10, |
6185 | mod_.getPRNG()); |
6186 | |
6187 | auto *bias = mod_.createConstant(ElemKind::Int8QTy, {12}, 1.0, 0, "bias" ); |
6188 | bias->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127, |
6189 | mod_.getPRNG()); |
6190 | auto *output = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 12}, 1.0, 0, |
6191 | "output" , false); |
6192 | bindings_.allocate(output); |
6193 | auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {3, 12}, 1.0, 0); |
6194 | |
6195 | auto *rqfc = F_->createRowwiseQuantizedFullyConnected( |
6196 | "RowwiseQuantizedFullyConnected1" , input1, weights, scales, offsets, bias, |
6197 | outTy); |
6198 | F_->createSave("save" , rqfc, output); |
6199 | |
6200 | ::glow::optimize(F_, CompilationMode::Infer); |
6201 | |
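  // Data-parallelize the RowwiseQuantizedFullyConnected into 3 chunks along
  // the batch dimension.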
6202 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
6203 | parOpts[rqfc] = ParallelTransformKind::Data; |
6204 | |
6205 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
6206 | ASSIGN_VALUE_OR_FAIL_TEST( |
6207 | replacedMap, |
6208 | ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 3)); |
6209 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
6210 | runDCEPass(F_, cctx_); |
6211 | |
6212 | // We now have 3 RowwiseQuantizedFullyConnecteds |
6213 | EXPECT_EQ(3, countNodeKind( |
6214 | F_, Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind)); |
6215 | |
6216 | // One concat to bring all of the parallelized sliced |
6217 | // RowwiseQuantizedFullyConnecteds together. |
6218 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
6219 | } |
6220 | |
6221 | /// Test Splitting Convolution into multiple Convolutions. |
6222 | TEST_F(GraphOptz, ParallelizeGraph_Convolution_Model_Axis3) { |
6223 | auto *input1 = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5, 5, 8}, 1.0, 0, |
6224 | "input1" , false); |
6225 | bindings_.allocate(input1)->getHandle<int8_t>().randomize(-4, 4, |
6226 | mod_.getPRNG()); |
6227 | auto *filter = mod_.createPlaceholder(ElemKind::Int8QTy, {12, 1, 1, 8}, 0.1, |
6228 | 0, "weights" , false); |
6229 | auto *bias = |
6230 | mod_.createPlaceholder(ElemKind::Int32QTy, {12}, 0.01, 0, "bias" , false); |
6231 | |
6232 | auto *output = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5, 5, 12}, 1.0, |
6233 | 0, "output" , false); |
6234 | bindings_.allocate(output); |
6235 | auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {3, 5, 5, 12}, 1.0, 0); |
6236 | |
6237 | auto *c = F_->createConv("Convolution1" , input1, filter, bias, outTy, {1, 1}, |
6238 | {1, 1}, {0, 0, 0, 0}, 1); |
6239 | F_->createSave("save" , c, output); |
6240 | |
6241 | ::glow::optimize(F_, CompilationMode::Infer); |
6242 | |
6243 | // This is F_ but without the parallel transformation below. |
6244 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
6245 | |
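  // Model-parallelize the Convolution along axis 3, the output-channel axis
  // of NHWC, splitting the 12 output channels across 12 convolutions.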
6246 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
6247 | parOpts[c] = ParallelTransformKind::Model_Axis3; |
6248 | |
6249 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
6250 | ASSIGN_VALUE_OR_FAIL_TEST( |
6251 | replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), |
6252 | parOpts, 12)); |
6253 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
6254 | runDCEPass(F_, cctx_); |
6255 | |
6256 | // We now have 12 Convolutions |
6257 | EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::ConvolutionNodeKind)); |
6258 | |
  // One concat to bring all of the parallelized sliced
  // Convolutions together.
6261 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
6262 | |
6263 | checkNumericalEquivalence(0.f); |
6264 | } |
6265 | |
6266 | /// Test Splitting Convolution3D into multiple Convolution3Ds. |
6267 | TEST_F(GraphOptz, ParallelizeGraph_Convolution3D_Model_Axis4) { |
6268 | auto *input1 = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5, 5, 5, 8}, 1.0, |
6269 | 0, "input1" , false); |
6270 | bindings_.allocate(input1)->getHandle<int8_t>().randomize(-4, 4, |
6271 | mod_.getPRNG()); |
6272 | auto *filter = mod_.createPlaceholder(ElemKind::Int8QTy, {12, 1, 1, 1, 8}, |
6273 | 0.1, 0, "weights" , false); |
6274 | auto *bias = |
6275 | mod_.createPlaceholder(ElemKind::Int32QTy, {12}, 0.01, 0, "bias" , false); |
6276 | |
6277 | auto *output = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5, 5, 5, 12}, |
6278 | 1.0, 0, "output" , false); |
6279 | bindings_.allocate(output); |
6280 | auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {3, 5, 5, 5, 12}, 1.0, 0); |
6281 | |
6282 | auto *c3d = F_->createConv3D("Convolution3D1" , input1, filter, bias, outTy, |
6283 | {1, 1, 1}, {1, 1, 1}, {0, 0, 0, 0, 0, 0}, 1); |
6284 | F_->createSave("save" , c3d, output); |
6285 | |
6286 | ::glow::optimize(F_, CompilationMode::Infer); |
6287 | |
6288 | // This is F_ but without the parallel transformation below. |
6289 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
6290 | |
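  // Model-parallelize the Convolution3D along axis 4, the output-channel
  // axis, giving one convolution per output channel.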
6291 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
6292 | parOpts[c3d] = ParallelTransformKind::Model_Axis4; |
6293 | |
6294 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
6295 | ASSIGN_VALUE_OR_FAIL_TEST( |
6296 | replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), |
6297 | parOpts, 12)); |
6298 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
6299 | runDCEPass(F_, cctx_); |
6300 | |
6301 | // We now have 12 Convolution3Ds |
6302 | EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::Convolution3DNodeKind)); |
6303 | |
  // One concat to bring all of the parallelized sliced
  // Convolution3Ds together.
6306 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
6307 | |
6308 | checkNumericalEquivalence(0.f); |
6309 | } |
6310 | |
6311 | /// Test Splitting AvgPool into multiple AvgPools. |
6312 | TEST_F(GraphOptz, ParallelizeGraph_AvgPool_Model_Axis4) { |
6313 | auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {3, 5, 5, 5, 8}, |
6314 | "input1" , false); |
6315 | bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0, |
6316 | mod_.getPRNG()); |
6317 | auto *output = mod_.createPlaceholder(ElemKind::FloatTy, {3, 1, 1, 1, 8}, |
6318 | "output" , false); |
6319 | bindings_.allocate(output); |
6320 | |
6321 | auto *ap = F_->createAvgPool("AvgPool1" , input1, {5, 5, 5}, {1, 1, 1}, |
6322 | {0, 0, 0, 0, 0, 0}, ConvolutionLayout::NTHWC); |
6323 | F_->createSave("save" , ap, output); |
6324 | |
6325 | ::glow::optimize(F_, CompilationMode::Infer); |
6326 | |
6327 | // This is F_ but without the parallel transformation below. |
6328 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
6329 | |
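  // Model-parallelize the AvgPool along axis 4 (channels), giving one AvgPool
  // per channel.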
6330 | llvm::DenseMap<Node *, ParallelTransformKind> parOpts; |
6331 | parOpts[ap] = ParallelTransformKind::Model_Axis4; |
6332 | |
6333 | std::unordered_map<Node *, ConcatNode *> replacedMap; |
6334 | ASSIGN_VALUE_OR_FAIL_TEST( |
6335 | replacedMap, |
6336 | ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(), parOpts, 8)); |
6337 | EXPECT_EQ(replacedMap.size(), parOpts.size()); |
6338 | runDCEPass(F_, cctx_); |
6339 | |
6340 | // We now have 8 AvgPools |
6341 | EXPECT_EQ(8, countNodeKind(F_, Kinded::Kind::AvgPoolNodeKind)); |
6342 | |
6343 | // One concat to bring all of the parallelized sliced AvgPools |
6344 | // together. |
6345 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind)); |
6346 | |
6347 | checkNumericalEquivalence(0.f); |
6348 | } |
6349 | |
/// Test that Add after ConvTranspose is folded into Bias add when the actual
/// Add is a broadcast of the bias. \p RnL selects which side of the Add the
/// broadcast constant appears on.
6352 | static void foldConvTransposeAddIntoBiasAdd(PlaceholderBindings &bindings, |
6353 | Module &mod, Function *F, |
6354 | Function *&optF, bool RnL) { |
6355 | dim_t batch = 2; |
6356 | dim_t inC = 2; |
6357 | dim_t outC = 5; |
6358 | dim_t inH = 3; |
6359 | dim_t inW = 3; |
6360 | unsigned_t kernel = 3; |
6361 | std::vector<uint32_t> pads = {0, 0, 0, 0}; |
6362 | std::vector<uint32_t> stride = {1, 1}; |
6363 | |
  auto *input = mod.createPlaceholder(ElemKind::FloatTy,
                                      {batch, inH, inW, inC}, "input", false);
6366 | auto *filter = mod.createPlaceholder( |
6367 | ElemKind::FloatTy, {outC, kernel, kernel, inC}, "filter" , false); |
6368 | |
6369 | auto *bias = mod.createConstant(ElemKind::FloatTy, {outC}, "bias" ); |
6370 | bias->getPayloadMutable().getHandle<float>() = {1, 3, 5, 7, 9}; |
6371 | |
6372 | std::pair<dim_t, dim_t> outHW = calculateConvTransposeOutputDims( |
6373 | inH, inW, {kernel, kernel}, stride, pads); |
6374 | auto outTy = mod.uniqueType(ElemKind::FloatTy, |
6375 | {batch, outHW.first, outHW.second, outC}); |
6376 | |
6377 | ConvTransposeNode *CTN = |
6378 | F->createConvTranspose("ConvTranspose" , input, filter, bias, outTy, |
6379 | {kernel, kernel}, stride, {0, 0, 0, 0}, 1); |
6380 | |
6381 | auto *CN = mod.createConstant(ElemKind::FloatTy, |
6382 | {batch, outHW.first, outHW.second, outC}, "c1" ); |
6383 | auto *AN = RnL ? F->createAdd("add" , CN, CTN) : F->createAdd("add" , CTN, CN); |
6384 | |
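  // The constant repeats {1, 2, 3, 4, 5} across every output position, i.e.
  // it is a broadcast of a per-output-channel vector, which is what allows
  // the Add to fold into the ConvTranspose bias.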
6385 | CN->getPayloadMutable().getHandle<float>() = { |
6386 | 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, |
6387 | 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, |
6388 | 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, |
6389 | 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, |
6390 | 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, |
6391 | 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, |
6392 | 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, |
6393 | 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, |
6394 | 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, |
6395 | 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, |
6396 | 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}; |
6397 | |
6398 | SaveNode *save = F->createSave("save" , AN); |
6399 | bindings.allocate(save->getPlaceholder()); |
6400 | |
6401 | EXPECT_EQ(F->getNodes().size(), 3); |
6402 | optF = optimizeFunctionForTest(F); |
6403 | EXPECT_EQ(optF->getNodes().size(), 2); |
6404 | |
6405 | const SaveNode *optSave = |
6406 | findFunctionNodeByName<SaveNode>(optF, save->getName()); |
6407 | |
  ConvTransposeNode *optCN =
      llvm::dyn_cast<ConvTransposeNode>(optSave->getInput());
  ASSERT_TRUE(optCN);

  Constant *optBias = llvm::dyn_cast<Constant>(optCN->getBias());
  ASSERT_TRUE(optBias);
6414 | |
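  // The folded bias should be the broadcast constant {1, 2, 3, 4, 5} plus the
  // original bias {1, 3, 5, 7, 9}, element-wise.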
6415 | auto BH = optBias->getPayload().getHandle(); |
6416 | EXPECT_EQ(BH.raw(0), 1 + 1); |
6417 | EXPECT_EQ(BH.raw(1), 2 + 3); |
6418 | EXPECT_EQ(BH.raw(2), 3 + 5); |
6419 | EXPECT_EQ(BH.raw(3), 4 + 7); |
6420 | EXPECT_EQ(BH.raw(4), 5 + 9); |
6421 | |
6422 | bindings.allocate(mod.getPlaceholders()); |
6423 | bindings.get(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG()); |
6424 | bindings.get(filter)->getHandle().randomize(-1.0, 1.0, mod.getPRNG()); |
6425 | } |
6426 | |
/// Test that Add after ConvTranspose is folded into Bias add when the actual
/// Add is a broadcast of the bias.
6429 | TEST_F(GraphOptz, FoldConvTransposeAddIntoBiasAddRHS) { |
6430 | foldConvTransposeAddIntoBiasAdd(bindings_, mod_, F_, optimizedF_, false); |
6431 | checkNumericalEquivalence(); |
6432 | } |
6433 | TEST_F(GraphOptz, FoldConvTransposeAddIntoBiasAddLHS) { |
6434 | foldConvTransposeAddIntoBiasAdd(bindings_, mod_, F_, optimizedF_, true); |
6435 | checkNumericalEquivalence(); |
6436 | } |
6437 | |
6438 | /// Test that MatMul + Add is folded into FullyConnected. |
6439 | TEST_F(GraphOptz, FoldMatMulAddIntoFullyConnected) { |
6440 | |
6441 | auto *input = |
6442 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 3}, "input" , false); |
6443 | auto *weights = |
6444 | mod_.createPlaceholder(ElemKind::FloatTy, {3, 5}, "weights" , false); |
6445 | auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5}, "bias" , false); |
6446 | |
6447 | MatMulNode *matmul = F_->createMatMul("matmul" , input, weights); |
6448 | AddNode *add = F_->createAdd("add" , matmul, bias); |
6449 | F_->createSave("save" , add); |
6450 | EXPECT_EQ(3, F_->getNodes().size()); |
6451 | |
  // The folding should replace the MatMul + Add with a FullyConnected and a
  // Reshape to make the Bias 1D.
6454 | CompilationContext cctx; |
6455 | ::glow::fold(F_, cctx); |
6456 | EXPECT_EQ(3, F_->getNodes().size()); |
6457 | EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::AddNodeKind)); |
6458 | EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::MatMulNodeKind)); |
6459 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind)); |
6460 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ReshapeNodeKind)); |
6461 | } |
6462 | |
6463 | /// Test that batched MatMul + Add is folded into batched FullyConnected. |
6464 | /// This optimization takes place only if the Bias is constant and the |
6465 | /// bias data repeats for all the batches. |
6466 | TEST_F(GraphOptz, FoldMatMulAddIntoFullyConnectedBatched) { |
6467 | |
6468 | auto *input = |
6469 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 3}, "input" , false); |
6470 | auto *weights = |
6471 | mod_.createPlaceholder(ElemKind::FloatTy, {3, 5}, "weights" , false); |
6472 | auto *bias = mod_.createConstant(ElemKind::FloatTy, {2, 5}, "bias" ); |
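  // Fill the bias so that the same row repeats for both batch entries; the
  // fold only fires when the bias data repeats for all batches.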
6473 | auto biasH = bias->getPayloadMutable().getHandle<float>(); |
6474 | biasH = {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}; |
6475 | |
6476 | MatMulNode *matmul = F_->createMatMul("matmul" , input, weights); |
6477 | AddNode *add = F_->createAdd("add" , matmul, bias); |
6478 | F_->createSave("save" , add); |
6479 | EXPECT_EQ(3, F_->getNodes().size()); |
6480 | |
  // The folding should replace the MatMul + Add with a FullyConnected, plus a
  // Slice and a Reshape to make the Bias 1D.
6483 | CompilationContext cctx; |
6484 | ::glow::fold(F_, cctx); |
6485 | EXPECT_EQ(4, F_->getNodes().size()); |
6486 | EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::AddNodeKind)); |
6487 | EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::MatMulNodeKind)); |
6488 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind)); |
6489 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::SliceNodeKind)); |
6490 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ReshapeNodeKind)); |
6491 | } |
6492 | |
6493 | /// Test that MatMul is converted to FullyConnected for Int8QTy. |
6494 | TEST_F(GraphOptz, ConvertMatMulToFullyConnected_Int8QTy) { |
6495 | |
6496 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {1, 3}, 0.1f, -13, |
6497 | "input" , false); |
6498 | auto *weights = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5}, 0.2f, 15, |
6499 | "weights" , false); |
6500 | MatMulNode *matmul = F_->createMatMul("matmul" , input, weights); |
6501 | F_->createSave("save" , matmul); |
6502 | EXPECT_EQ(2, F_->getNodes().size()); |
6503 | |
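  // Run only the ConvertMatMulToFullyConnected pass (plus DCE) so the
  // conversion is exercised in isolation.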
6504 | optimizedF_ = optimizeFunctionForTest( |
6505 | F_, {FunctionPassID::ConvertMatMulToFullyConnected, getDCEPassConfig()}); |
6506 | |
6507 | EXPECT_EQ(2, optimizedF_->getNodes().size()); |
6508 | EXPECT_EQ(1, |
6509 | countNodeKind(optimizedF_, Kinded::Kind::FullyConnectedNodeKind)); |
6510 | } |
6511 | |
6512 | /// Test that MatMul is converted to FullyConnected for FloatTy. |
6513 | TEST_F(GraphOptz, ConvertMatMulToFullyConnected_FloatTy) { |
6514 | |
6515 | auto *input = |
6516 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 3}, "input" , false); |
6517 | auto *weights = |
6518 | mod_.createPlaceholder(ElemKind::FloatTy, {3, 5}, "weights" , false); |
6519 | MatMulNode *matmul = F_->createMatMul("matmul" , input, weights); |
6520 | F_->createSave("save" , matmul); |
6521 | EXPECT_EQ(2, F_->getNodes().size()); |
6522 | |
6523 | optimizedF_ = optimizeFunctionForTest( |
6524 | F_, {FunctionPassID::ConvertMatMulToFullyConnected, getDCEPassConfig()}); |
6525 | |
6526 | EXPECT_EQ(2, optimizedF_->getNodes().size()); |
6527 | EXPECT_EQ(1, |
6528 | countNodeKind(optimizedF_, Kinded::Kind::FullyConnectedNodeKind)); |
6529 | } |
6530 | |
6531 | /// Test that FoldSlicesIntoConstants pass works as expected. |
6532 | TEST_F(GraphOptz, FoldSlicesIntoConstantsTest) { |
6533 | Constant *C = mod_.createConstant(ElemKind::FloatTy, {3, 4}, "C" ); |
6534 | auto CH = C->getPayloadMutable().getHandle<float>(); |
6535 | CH = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; |
6536 | |
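  // Slice out the left two and right two columns of C; both Slices of the
  // Constant should be folded away into new Constants.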
6537 | SliceNode *S1 = F_->createSlice("s1" , C, {0, 0}, {3, 2}); |
6538 | SliceNode *S2 = F_->createSlice("s2" , C, {0, 2}, {3, 4}); |
6539 | SaveNode *SN1 = F_->createSave("save1" , S1); |
6540 | SaveNode *SN2 = F_->createSave("save2" , S2); |
6541 | |
6542 | optimizedF_ = optimizeFunctionForTest( |
6543 | F_, {FunctionPassID::FoldSlicesIntoConstants, getDCEPassConfig()}); |
6544 | |
6545 | SaveNode *optSN1 = |
6546 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN1->getName())); |
6547 | SaveNode *optSN2 = |
6548 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN2->getName())); |
6549 | ASSERT_TRUE(optSN1); |
6550 | ASSERT_TRUE(optSN2); |
6551 | |
6552 | Constant *C1 = llvm::dyn_cast<Constant>(optSN1->getInput()); |
6553 | ASSERT_TRUE(C1); |
6554 | auto H1 = C1->getPayloadMutable().getHandle(); |
6555 | Constant *C2 = llvm::dyn_cast<Constant>(optSN2->getInput()); |
6556 | ASSERT_TRUE(C2); |
6557 | auto H2 = C2->getPayloadMutable().getHandle(); |
6558 | for (dim_t i = 0, e = 3; i < e; i++) { |
6559 | for (dim_t j = 0, e = 2; j < e; j++) { |
6560 | EXPECT_EQ(H1.at({i, j}), CH.at({i, j})); |
6561 | EXPECT_EQ(H2.at({i, j}), CH.at({i, j + 2})); |
6562 | } |
6563 | } |
6564 | } |
6565 | |
6566 | /// Test that RaiseClipsAboveShapeNodes pass works as expected. |
6567 | TEST_F(GraphOptz, RaiseClipsAboveShapeNodesTest) { |
6568 | Placeholder *input = |
6569 | mod_.createPlaceholder(ElemKind::FloatTy, {256, 64}, "input" , false); |
6570 | |
6571 | ReshapeNode *RN1 = F_->createReshape("reshape1" , input, {4, 128, 32}); |
6572 | ReshapeNode *RN2 = F_->createReshape("reshape2" , RN1, {64, 256}); |
6573 | TransposeNode *TN = F_->createTranspose("transpose" , RN2, {1, 0}); |
6574 | SliceNode *SN = F_->createSlice("slice" , TN, {64, 0}, {256, 64}); |
6575 | TileNode *TiN = F_->createTile("tile" , SN, 2, 0); |
6576 | ClipNode *CN = F_->createClip("clip" , TiN, -0.1, 0.1); |
6577 | SaveNode *save1 = F_->createSave("save1" , RN1); |
6578 | SaveNode *save2 = F_->createSave("save2" , CN); |
6579 | |
6580 | optimizedF_ = |
6581 | optimizeFunctionForTest(F_, {FunctionPassID::RaiseClipsAboveShapeNodes}); |
6582 | |
6583 | auto *optSave1 = |
6584 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save1->getName())); |
6585 | ASSERT_TRUE(optSave1); |
6586 | auto *optSave2 = |
6587 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save2->getName())); |
6588 | ASSERT_TRUE(optSave2); |
6589 | |
  // save1 should still be saving the untouched Reshape RN1, which has the
  // Placeholder input feeding into it, because RN1 has multiple users.
6592 | auto *optRN1 = llvm::dyn_cast<ReshapeNode>(optSave1->getInput().getNode()); |
6593 | ASSERT_TRUE(optRN1); |
6594 | EXPECT_EQ(input, optRN1->getInput().getNode()); |
6595 | |
  // The Clip CN that save2 originally saved should have been pushed up above
  // SN, TiN, TN, and RN2.
6598 | TileNode *newTiN = llvm::dyn_cast<TileNode>(optSave2->getInput()); |
6599 | ASSERT_TRUE(newTiN); |
6600 | EXPECT_EQ(newTiN->getCount(), TiN->getCount()); |
6601 | SliceNode *newSN = llvm::dyn_cast<SliceNode>(newTiN->getInput()); |
6602 | ASSERT_TRUE(newSN); |
6603 | EXPECT_EQ(newSN->getStart(), SN->getStart()); |
6604 | TransposeNode *newTN = llvm::dyn_cast<TransposeNode>(newSN->getInput()); |
6605 | ASSERT_TRUE(newTN); |
6606 | EXPECT_EQ(newTN->getShuffle(), TN->getShuffle()); |
6607 | ReshapeNode *newRN2 = llvm::dyn_cast<ReshapeNode>(newTN->getInput()); |
6608 | ASSERT_TRUE(newRN2); |
6609 | ClipNode *newCN = llvm::dyn_cast<ClipNode>(newRN2->getInput()); |
6610 | ASSERT_TRUE(newCN); |
6611 | EXPECT_EQ(newCN->getMin(), CN->getMin()); |
6612 | EXPECT_EQ(newCN->getMax(), CN->getMax()); |
6613 | |
6614 | bindings_.allocate(mod_.getPlaceholders()); |
6615 | bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
6616 | checkNumericalEquivalence(); |
6617 | } |
6618 | |
6619 | static void testOptimizeDequantizeClip(PlaceholderBindings &bindings, |
6620 | Module &mod, Function *F, |
6621 | Function *&optF, |
6622 | bool enableQuantParamChanges) { |
6623 | Placeholder *input = |
6624 | mod.createPlaceholder(ElemKind::FloatTy, {20, 20}, "input" , false); |
6625 | |
6626 | const auto qParams = quantization::chooseQuantizationParams({-0.1, 0.1}); |
6627 | |
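  // Quantize then immediately Dequantize, followed by a Clip. The Clip range
  // is either wider than the quantized range (a no-op) or clips at 0, which
  // requires updating the quantization parameters.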
6628 | QuantizeNode *QN = |
6629 | F->createQuantize("quantize" , input, |
6630 | mod.uniqueType(ElemKind::Int8QTy, {20, 20}, |
6631 | qParams.scale, qParams.offset)); |
6632 | DequantizeNode *DN = F->createDequantize("dequantize" , QN, ElemKind::FloatTy); |
6633 | ClipNode *CN = |
6634 | F->createClip("clip" , DN, enableQuantParamChanges ? 0 : -100, 100); |
6635 | SaveNode *SN = F->createSave("save" , CN); |
6636 | |
6637 | CompilationContext cctx; |
6638 | cctx.optimizationOpts.enableQuantParamChanges = true; |
6639 | optF = optimizeFunctionForTest( |
6640 | F, {FunctionPassID::OptimizeQuantizeClip, getDCEPassConfig()}, cctx); |
6641 | |
6642 | EXPECT_EQ(countNodeKind(optF, Kinded::Kind::ClipNodeKind), 0); |
6643 | |
6644 | SaveNode *optSN = |
6645 | llvm::dyn_cast<SaveNode>(optF->getNodeByName(SN->getName())); |
6646 | ASSERT_TRUE(optSN); |
6647 | |
6648 | // Now check that the quantization params have been correctly updated for QN, |
6649 | // and that CN has been eliminated. |
6650 | DequantizeNode *optDN = |
6651 | llvm::dyn_cast<DequantizeNode>(optSN->getInput().getNode()); |
6652 | ASSERT_TRUE(optDN); |
6653 | const auto qMinMax = optDN->getInput().getType()->getQuantizedValueRange(); |
6654 | // Min is either from Clip or Quant range depending on enableQuantParamChanges |
6655 | EXPECT_NEAR(qMinMax.first, enableQuantParamChanges ? 0 : -0.1, 1E-3); |
6656 | EXPECT_NEAR(qMinMax.second, 0.1, 1E-3); // Max from Quant range |
6657 | |
6658 | bindings.allocate(mod.getPlaceholders()); |
6659 | bindings.get(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG()); |
6660 | } |
6661 | |
6662 | /// Test that OptimizeQuantizeClip pass works as expected for Clip(Dequantize) |
6663 | /// when the quantization parameters are allowed to change. |
6664 | TEST_F(GraphOptz, OptimizeDequantizeClipTest_QuantParamChanges) { |
6665 | testOptimizeDequantizeClip(bindings_, mod_, F_, optimizedF_, |
6666 | /* enableQuantParamChanges */ true); |
6667 | checkNumericalEquivalence(0.0005); |
6668 | } |
6669 | |
6670 | /// Test that OptimizeQuantizeClip pass works as expected for Clip(Dequantize) |
6671 | /// when the quantization parameters are not allowed to change. |
6672 | TEST_F(GraphOptz, OptimizeDequantizeClipTest_NoQuantParamChanges) { |
6673 | testOptimizeDequantizeClip(bindings_, mod_, F_, optimizedF_, |
6674 | /* enableQuantParamChanges */ false); |
6675 | checkNumericalEquivalence(); |
6676 | } |
6677 | |
6678 | static void testOptimizeClipQuantize(PlaceholderBindings &bindings, Module &mod, |
6679 | Function *F, Function *&optF, |
6680 | bool enableQuantParamChanges) { |
6681 | Placeholder *input = |
6682 | mod.createPlaceholder(ElemKind::FloatTy, {20, 20}, "input" , false); |
6683 | |
6684 | const auto qParams = quantization::chooseQuantizationParams({-0.1, 0.1}); |
6685 | |
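  // Here the Clip precedes the Quantize; eliminating it may fold the clip
  // minimum into the quantization parameters when changes are allowed.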
6686 | ClipNode *CN = |
6687 | F->createClip("clip" , input, enableQuantParamChanges ? 0 : -100, 100); |
6688 | QuantizeNode *QN = |
6689 | F->createQuantize("quantize" , CN, |
6690 | mod.uniqueType(ElemKind::Int8QTy, {20, 20}, |
6691 | qParams.scale, qParams.offset)); |
6692 | DequantizeNode *DN = F->createDequantize("dequantize" , QN, ElemKind::FloatTy); |
6693 | SaveNode *SN = F->createSave("save" , DN); |
6694 | |
6695 | CompilationContext cctx; |
6696 | cctx.optimizationOpts.enableQuantParamChanges = enableQuantParamChanges; |
6697 | optF = optimizeFunctionForTest( |
6698 | F, {FunctionPassID::OptimizeQuantizeClip, getDCEPassConfig()}, cctx); |
6699 | |
6700 | EXPECT_EQ(countNodeKind(optF, Kinded::Kind::ClipNodeKind), 0); |
6701 | |
6702 | SaveNode *optSN = |
6703 | llvm::dyn_cast<SaveNode>(optF->getNodeByName(SN->getName())); |
6704 | ASSERT_TRUE(optSN); |
6705 | |
6706 | // Now check that the quantization params have been correctly updated for QN, |
6707 | // and that CN has been eliminated. |
6708 | DequantizeNode *optDN = |
6709 | llvm::dyn_cast<DequantizeNode>(optSN->getInput().getNode()); |
6710 | ASSERT_TRUE(optDN); |
6711 | const auto qMinMax = optDN->getInput().getType()->getQuantizedValueRange(); |
6712 | // Min is either from Clip or Quant range depending on enableQuantParamChanges |
6713 | EXPECT_NEAR(qMinMax.first, enableQuantParamChanges ? 0 : -0.1, 1E-3); |
6714 | EXPECT_NEAR(qMinMax.second, 0.1, 1E-3); // Max always from Quant range |
6715 | |
6716 | bindings.allocate(mod.getPlaceholders()); |
6717 | bindings.get(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG()); |
6718 | } |
6719 | |
6720 | /// Test that OptimizeQuantizeClip pass works as expected for Clip(Quantize) |
6721 | /// when the quantization parameters are allowed to change. |
6722 | TEST_F(GraphOptz, OptimizeClipQuantizeTest_QuantParamChanges) { |
6723 | testOptimizeClipQuantize(bindings_, mod_, F_, optimizedF_, |
6724 | /* enableQuantParamChanges */ true); |
6725 | checkNumericalEquivalence(0.0005); |
6726 | } |
6727 | |
6728 | /// Test that OptimizeQuantizeClip pass works as expected for Clip(Quantize) |
6729 | /// when the quantization parameters are not allowed to change. |
6730 | TEST_F(GraphOptz, OptimizeClipQuantizeTest_NoQuantParamChanges) { |
6731 | testOptimizeClipQuantize(bindings_, mod_, F_, optimizedF_, |
6732 | /* enableQuantParamChanges */ false); |
6733 | checkNumericalEquivalence(); |
6734 | } |
6735 | |
6736 | /// Test Quantize(ConvertTo(Node)) -> Quantize(Node), where Quantize is int8. |
6737 | TEST_F(GraphOptz, OptimizeOutIntermediateConversionsTest) { |
6738 | Placeholder *input = |
6739 | mod_.createPlaceholder(ElemKind::FloatTy, {20, 20}, "input" , false); |
6740 | |
6741 | const auto qParams = quantization::chooseQuantizationParams({-0.1, 0.1}); |
6742 | |
6743 | ConvertToNode *CN = F_->createConvertTo("conv" , input, ElemKind::Float16Ty); |
6744 | QuantizeNode *QN = |
6745 | F_->createQuantize("quantize" , CN, |
6746 | mod_.uniqueType(ElemKind::Int8QTy, {20, 20}, |
6747 | qParams.scale, qParams.offset)); |
6748 | DequantizeNode *DN = |
6749 | F_->createDequantize("dequantize" , QN, ElemKind::FloatTy); |
6750 | F_->createSave("save" , DN); |
6751 | |
6752 | optimizedF_ = optimizeFunctionForTest( |
6753 | F_, |
6754 | {FunctionPassID::OptimizeOutIntermediateConversions, getDCEPassConfig()}); |
6755 | |
6756 | // Now check that the ConvertToNode has been eliminated. |
6757 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConvertToNodeKind), 0); |
6758 | |
6759 | bindings_.allocate(mod_.getPlaceholders()); |
6760 | bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG()); |
6761 | checkNumericalEquivalence(); |
6762 | } |
6763 | |
6764 | /// Test Clip(Relu(Clip)) -> Clip'. |
6765 | TEST_F(GraphOptz, ClipReluClipElimTest) { |
6766 | Placeholder *input = |
6767 | mod_.createPlaceholder(ElemKind::FloatTy, {64, 64}, "input" , false); |
6768 | ClipNode *CN1 = F_->createClip("CN1" , input, -10, 30); |
6769 | ReluNode *RN = F_->createRELU("RN" , CN1); |
6770 | ClipNode *CN2 = F_->createClip("CN2" , RN, -5, 20); |
6771 | SaveNode *SN = F_->createSave("save" , CN2); |
6772 | |
6773 | // Start with 2 clips, a relu, and a save. |
6774 | EXPECT_EQ(F_->getNodes().size(), 4); |
6775 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ClipNodeKind), 2); |
6776 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 1); |
6777 | |
6778 | optimizedF_ = optimizeFunctionForTest(F_); |
6779 | |
6780 | // Remove one of the clips and the relu. |
6781 | EXPECT_EQ(optimizedF_->getNodes().size(), 2); |
6782 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ClipNodeKind), 1); |
6783 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ReluNodeKind), 0); |
6784 | |
6785 | SaveNode *optSN = |
6786 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName())); |
6787 | ASSERT_TRUE(optSN); |
6788 | |
6789 | // We combined all of the ranges into the single Clip. |
6790 | ClipNode *optCN = llvm::dyn_cast<ClipNode>(optSN->getInput()); |
6791 | ASSERT_TRUE(optCN); |
6792 | EXPECT_EQ(optCN->getMin(), 0); |
6793 | EXPECT_EQ(optCN->getMax(), 20); |
6794 | |
6795 | bindings_.allocate(input)->getHandle().randomize(-50.0, 5.0, mod_.getPRNG()); |
6796 | checkNumericalEquivalence(); |
6797 | } |
6798 | |
6799 | /// Test that we can find a non-quantized relu and fuse it up into a quant FC. |
6800 | TEST_F(GraphOptz, OptimizeQuantFCFloatReluTest) { |
6801 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 1.0, 0, |
6802 | "input" , false); |
6803 | auto *weights = |
6804 | mod_.createConstant(ElemKind::Int8QTy, {32, 32}, 1.0, 0, "weights" ); |
6805 | auto *bias = mod_.createConstant(ElemKind::Int32QTy, {32}, 1.0, 0, "bias" ); |
6806 | |
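  // Build quantized FC -> Dequantize -> float Relu; the pass should sink the
  // Relu into the quantized domain above the Dequantize.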
6807 | auto *FC = F_->createFullyConnected("fc" , input, weights, bias); |
6808 | auto *DN = F_->createDequantize("dq" , FC, ElemKind::FloatTy); |
6809 | auto *RN = F_->createRELU("relu" , DN); |
6810 | auto *SN = F_->createSave("save" , RN); |
6811 | |
6812 | optimizedF_ = optimizeFunctionForTest( |
6813 | F_, {FunctionPassID::OptimizeQuantFCFloatRelu, getDCEPassConfig()}); |
6814 | |
6815 | SaveNode *optSN = |
6816 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName())); |
6817 | ASSERT_TRUE(optSN); |
6818 | |
6819 | DequantizeNode *optDN = llvm::dyn_cast<DequantizeNode>(optSN->getInput()); |
6820 | ASSERT_TRUE(optDN); |
6821 | ReluNode *optRN = llvm::dyn_cast<ReluNode>(optDN->getInput()); |
6822 | ASSERT_TRUE(optRN); |
6823 | auto rangeRN = optRN->getResult().getType()->getQuantizedValueRange(); |
6824 | EXPECT_EQ(rangeRN.first, 0.0f); |
6825 | FullyConnectedNode *optFC = |
6826 | llvm::dyn_cast<FullyConnectedNode>(optRN->getInput()); |
6827 | ASSERT_TRUE(optFC); |
6828 | auto rangeFC = optFC->getResult().getType()->getQuantizedValueRange(); |
6829 | EXPECT_EQ(rangeRN.second, rangeFC.second); |
6830 | |
6831 | bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127, |
6832 | mod_.getPRNG()); |
6833 | weights->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127, |
6834 | mod_.getPRNG()); |
6835 | bias->getPayloadMutable().getHandle<int32_t>().randomize(-128, 127, |
6836 | mod_.getPRNG()); |
6837 | checkNumericalEquivalence(); |
6838 | } |
6839 | |
6840 | /// Test that we can find a non-quantized relu and fuse it up into a quant FC |
6841 | /// even when setting dummy qparams to true. |
6842 | TEST_F(GraphOptz, OptimizeDummyQuantFCFloatReluTest) { |
6843 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 1.0, 0, |
6844 | "input" , false); |
6845 | auto *weights = |
6846 | mod_.createConstant(ElemKind::Int8QTy, {32, 32}, 1.0, 0, "weights" ); |
6847 | auto *bias = mod_.createConstant(ElemKind::Int32QTy, {32}, 1.0, 0, "bias" ); |
6848 | auto *addW = |
6849 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 32}, "addw" , false); |
6850 | auto *FC = F_->createFullyConnected("fc" , input, weights, bias); |
6851 | auto *DN = F_->createDequantize("dq" , FC, ElemKind::FloatTy); |
6852 | auto *RN = F_->createRELU("relu" , DN); |
6853 | auto *AN = F_->createAdd("add" , RN, addW); |
6854 | auto *SN = F_->createSave("save" , AN); |
6855 | |
6856 | CompilationContext cctx; |
6857 | cctx.precisionConfig.loadUniquedDummyQParams = true; |
6858 | optimizedF_ = optimizeFunctionForTest( |
6859 | F_, {FunctionPassID::OptimizeQuantFCFloatRelu, getDCEPassConfig()}, cctx); |
6860 | |
6861 | SaveNode *optSN = |
6862 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName())); |
6863 | ASSERT_TRUE(optSN); |
6864 | |
6865 | AddNode *optAN = llvm::dyn_cast<AddNode>(optSN->getInput()); |
6866 | ASSERT_TRUE(optAN); |
6867 | DequantizeNode *optDN = llvm::dyn_cast<DequantizeNode>(optAN->getLHS()); |
6868 | ASSERT_TRUE(optDN); |
6869 | ReluNode *optRN = llvm::dyn_cast<ReluNode>(optDN->getInput()); |
6870 | ASSERT_TRUE(optRN); |
6871 | auto rangeRN = optRN->getResult().getType()->getQuantizedValueRange(); |
6872 | FullyConnectedNode *optFC = |
6873 | llvm::dyn_cast<FullyConnectedNode>(optRN->getInput()); |
6874 | ASSERT_TRUE(optFC); |
6875 | auto rangeFC = optFC->getResult().getType()->getQuantizedValueRange(); |
6876 | EXPECT_EQ(rangeRN.first, rangeFC.first); |
6877 | EXPECT_EQ(rangeRN.second, rangeFC.second); |
6878 | |
6879 | bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127, |
6880 | mod_.getPRNG()); |
6881 | bindings_.allocate(addW)->getHandle<float>().randomize(-128, 127, |
6882 | mod_.getPRNG()); |
6883 | weights->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127, |
6884 | mod_.getPRNG()); |
6885 | bias->getPayloadMutable().getHandle<int32_t>().randomize(-128, 127, |
6886 | mod_.getPRNG()); |
6887 | checkNumericalEquivalence(); |
6888 | } |
6889 | |
6890 | /// Test that we can find a non-quantized relu and fuse it up into a series of |
6891 | /// concatenated quant FCs. |
6892 | TEST_F(GraphOptz, OptimizeConcatQuantFCFloatReluTest) { |
6893 | std::array<NodeValue, 5> DQs; |
6894 | for (size_t i = 0; i < 5; i++) { |
6895 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, |
6896 | 1.0 / (i + 1), 0, "input" , false); |
6897 | auto *weights = |
6898 | mod_.createConstant(ElemKind::Int8QTy, {32, 32}, 1.0, 0, "weights" ); |
6899 | auto *bias = mod_.createConstant(ElemKind::Int32QTy, {32}, 1.0, 0, "bias" ); |
6900 | |
6901 | auto *FC = F_->createFullyConnected("fc" , input, weights, bias); |
6902 | DQs[i] = F_->createDequantize("dq" , FC, ElemKind::FloatTy)->getResult(); |
6903 | |
6904 | bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127, |
6905 | mod_.getPRNG()); |
6906 | weights->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127, |
6907 | mod_.getPRNG()); |
6908 | bias->getPayloadMutable().getHandle<int32_t>().randomize(-128, 127, |
6909 | mod_.getPRNG()); |
6910 | } |
6911 | |
6912 | auto *CN = F_->createConcat("concat" , DQs, 0); |
6913 | auto *RN = F_->createRELU("relu" , CN); |
6914 | auto *SN = F_->createSave("save" , RN); |
6915 | |
6916 | optimizedF_ = optimizeFunctionForTest( |
6917 | F_, {FunctionPassID::OptimizeQuantFCFloatRelu, getDCEPassConfig()}); |
6918 | |
6919 | SaveNode *optSN = |
6920 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName())); |
6921 | ASSERT_TRUE(optSN); |
6922 | ConcatNode *optCN = llvm::dyn_cast<ConcatNode>(optSN->getInput()); |
6923 | ASSERT_TRUE(optCN); |
6924 | EXPECT_EQ(optCN->getInputs().size(), 5); |
6925 | |
6926 | for (const NodeValue &NV : optCN->getInputs()) { |
6927 | DequantizeNode *optDN = llvm::dyn_cast<DequantizeNode>(NV); |
6928 | ASSERT_TRUE(optDN); |
6929 | ReluNode *optRN = llvm::dyn_cast<ReluNode>(optDN->getInput()); |
6930 | ASSERT_TRUE(optRN); |
6931 | auto rangeRN = optRN->getResult().getType()->getQuantizedValueRange(); |
6932 | EXPECT_EQ(rangeRN.first, 0.0f); |
6933 | FullyConnectedNode *optFC = |
6934 | llvm::dyn_cast<FullyConnectedNode>(optRN->getInput()); |
6935 | ASSERT_TRUE(optFC); |
6936 | auto rangeFC = optFC->getResult().getType()->getQuantizedValueRange(); |
6937 | EXPECT_EQ(rangeRN.second, rangeFC.second); |
6938 | } |
6939 | |
6940 | checkNumericalEquivalence(); |
6941 | } |
6942 | |
6943 | /// Test that we can find a concat with all dequantize inputs and a quantize at |
6944 | /// its output, and then replace quant/dequants with rescales. |
6945 | TEST_F(GraphOptz, OptimizeDequantConcatQuant) { |
6946 | std::array<NodeValue, 5> DQs; |
6947 | std::array<Placeholder *, 5> inputs; |
6948 | for (size_t i = 0; i < 5; i++) { |
6949 | inputs[i] = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, |
6950 | 0.3 / (i + 1), 5, "input" , false); |
6951 | DQs[i] = |
6952 | F_->createDequantize("dq" , inputs[i], ElemKind::FloatTy)->getResult(); |
6953 | |
6954 | bindings_.allocate(inputs[i])->getHandle<int8_t>().randomize( |
6955 | -128, 127, mod_.getPRNG()); |
6956 | } |
6957 | |
6958 | auto *CN = F_->createConcat("concat" , DQs, 0); |
6959 | constexpr float scale = 0.3; |
6960 | constexpr int32_t offset = 5; |
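  // Quantize the Concat result with the same scale/offset as inputs[0], so
  // that input's Rescale can be elided entirely.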
6961 | auto *RN = F_->createQuantize("quantize" , CN, |
6962 | mod_.uniqueType(ElemKind::Int8QTy, |
6963 | CN->getResult().dims(), scale, |
6964 | offset)); |
6965 | auto *SN = F_->createSave("save" , RN); |
6966 | |
6967 | optimizedF_ = optimizeFunctionForTest( |
6968 | F_, {FunctionPassID::OptimizeConcatQuantization, getDCEPassConfig()}); |
6969 | |
6970 | SaveNode *optSN = |
6971 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName())); |
6972 | ASSERT_TRUE(optSN); |
6973 | ConcatNode *optCN = llvm::dyn_cast<ConcatNode>(optSN->getInput()); |
6974 | ASSERT_TRUE(optCN); |
6975 | EXPECT_EQ(optCN->getInputs().size(), 5); |
6976 | |
6977 | for (size_t i = 0, e = optCN->getInputs().size(); i < e; i++) { |
6978 | const NodeValue NV = optCN->getInputs()[i]; |
6979 | if (i == 0) { |
6980 | EXPECT_EQ(inputs[i], NV.getNode()); |
6981 | EXPECT_EQ(inputs[i]->getOutput().getType()->getScale(), scale); |
6982 | EXPECT_EQ(inputs[i]->getOutput().getType()->getOffset(), offset); |
6983 | } else { |
6984 | RescaleQuantizedNode *optRN = llvm::dyn_cast<RescaleQuantizedNode>(NV); |
6985 | ASSERT_TRUE(optRN); |
6986 | EXPECT_EQ(optRN->getResult().getType()->getScale(), scale); |
6987 | EXPECT_EQ(optRN->getResult().getType()->getOffset(), offset); |
6988 | EXPECT_EQ(inputs[i], optRN->getInput().getNode()); |
6989 | } |
6990 | } |
6991 | checkNumericalEquivalence(); |
6992 | } |
6993 | |
6994 | /// Test that if we have a Concat with all Dequantize inputs with the same |
6995 | /// scale/offset/kind that we can sink the Dequantizes below the Concat. |
6996 | TEST_F(GraphOptz, SinkDequantizeBelowConcatTest) { |
6997 | const float scale = 0.06; |
6998 | const int32_t offset = -15; |
6999 | std::array<NodeValue, 5> inputs; |
7000 | for (dim_t i = 0; i < 5; i++) { |
7001 | Placeholder *input = mod_.createPlaceholder(ElemKind::Int8QTy, {i + 1, 100}, |
7002 | scale, offset, "input" , false); |
7003 | bindings_.allocate(input)->getHandle<int8_t>().randomize(-100, 100, |
7004 | mod_.getPRNG()); |
7005 | DequantizeNode *dequantize = |
7006 | F_->createDequantize("dequantize" , input, ElemKind::Float16Ty); |
7007 | inputs[i] = dequantize->getResult(); |
7008 | } |
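  // All Dequantize inputs share the same input scale/offset/kind, so
  // SinkConversions can replace them with a single Dequantize below the
  // Concat.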
7009 | ConcatNode *concat = F_->createConcat("concat" , inputs, 0); |
7010 | SaveNode *SN = F_->createSave("ret" , concat); |
7011 | |
7012 | optimizedF_ = optimizeFunctionForTest( |
7013 | F_, {FunctionPassID::SinkConversions, getDCEPassConfig()}); |
7014 | |
7015 | // Concat, dequantize, save. |
7016 | EXPECT_EQ(optimizedF_->getNodes().size(), 3); |
7017 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::DequantizeNodeKind), 1); |
7018 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1); |
7019 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1); |
7020 | |
7021 | SaveNode *optSN = |
7022 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName())); |
7023 | ASSERT_TRUE(optSN); |
7024 | DequantizeNode *optDequantize = |
7025 | llvm::dyn_cast<DequantizeNode>(optSN->getInput()); |
7026 | ASSERT_TRUE(optDequantize); |
7027 | NodeValue input = optDequantize->getInput(); |
7028 | EXPECT_EQ(scale, input.getType()->getScale()); |
7029 | EXPECT_EQ(offset, input.getType()->getOffset()); |
7030 | EXPECT_EQ(ElemKind::Int8QTy, input.getType()->getElementType()); |
7031 | |
7033 | checkNumericalEquivalence(); |
7034 | } |
7035 | |
/// Test that if we have a Concat with all Quantize inputs with the same
/// scale/offset/kind that we can sink the Quantizes below the Concat.
7038 | TEST_F(GraphOptz, SinkQuantizeBelowConcatTest) { |
7039 | const float scale = 0.06; |
7040 | const int32_t offset = -15; |
7041 | std::array<NodeValue, 5> inputs; |
7042 | for (dim_t i = 0; i < 5; i++) { |
7043 | Placeholder *input = mod_.createPlaceholder(ElemKind::Float16Ty, |
7044 | {i + 1, 100}, "input" , false); |
7045 | bindings_.allocate(input)->getHandle<float16_t>().randomize(-100, 100, |
7046 | mod_.getPRNG()); |
7047 | const TypeRef QTy = mod_.uniqueType( |
7048 | ElemKind::Int8QTy, input->getOutput().dims(), scale, offset); |
7049 | QuantizeNode *quantize = F_->createQuantize("quantize" , input, QTy); |
7050 | inputs[i] = quantize->getResult(); |
7051 | } |
7052 | ConcatNode *concat = F_->createConcat("concat" , inputs, 0); |
7053 | SaveNode *SN = F_->createSave("ret" , concat); |
7054 | |
7055 | optimizedF_ = optimizeFunctionForTest( |
7056 | F_, {FunctionPassID::SinkConversions, getDCEPassConfig()}); |
7057 | |
7058 | // Concat, quantize, save. |
7059 | EXPECT_EQ(optimizedF_->getNodes().size(), 3); |
7060 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::QuantizeNodeKind), 1); |
7061 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1); |
7062 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1); |
7063 | |
7064 | SaveNode *optSN = |
7065 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName())); |
7066 | ASSERT_TRUE(optSN); |
7067 | QuantizeNode *optQuantize = llvm::dyn_cast<QuantizeNode>(optSN->getInput()); |
7068 | ASSERT_TRUE(optQuantize); |
7069 | EXPECT_EQ(scale, optQuantize->getResult().getType()->getScale()); |
7070 | EXPECT_EQ(offset, optQuantize->getResult().getType()->getOffset()); |
7071 | EXPECT_EQ(ElemKind::Int8QTy, |
7072 | optQuantize->getResult().getType()->getElementType()); |
7073 | |
7075 | checkNumericalEquivalence(); |
7076 | } |
7077 | |
/// Test that if we have a Concat with all Tanh inputs,
/// we can sink the Tanhs below the Concat.
7080 | TEST_F(GraphOptz, SinkTanhBelowConcatTest) { |
7081 | std::array<NodeValue, 5> inputs; |
7082 | for (dim_t i = 0; i < 5; i++) { |
7083 | Placeholder *input = mod_.createPlaceholder(ElemKind::Float16Ty, |
7084 | {i + 1, 100}, "input" , false); |
7085 | bindings_.allocate(input)->getHandle<float16_t>().randomize(-100, 100, |
7086 | mod_.getPRNG()); |
7087 | TanhNode *tanh = F_->createTanh("tanh" , input); |
7088 | inputs[i] = tanh->getResult(); |
7089 | } |
7090 | ConcatNode *concat = F_->createConcat("concat" , inputs, 0); |
7091 | SaveNode *SN = F_->createSave("ret" , concat); |
7092 | EXPECT_EQ(F_->getNodes().size(), 7); |
7093 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TanhNodeKind), 5); |
7094 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConcatNodeKind), 1); |
7095 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 1); |
7096 | |
7097 | CompilationContext cctx; |
7098 | cctx.optimizationOpts.sinkTanhBelowConcat = true; |
7099 | |
7100 | optimizedF_ = optimizeFunctionForTest( |
7101 | F_, {FunctionPassID::SinkConversions, getDCEPassConfig()}, cctx); |
7102 | |
  // Concat, tanh, save.
7104 | EXPECT_EQ(optimizedF_->getNodes().size(), 3); |
7105 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::TanhNodeKind), 1); |
7106 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1); |
7107 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1); |
7108 | |
7109 | SaveNode *optSN = |
7110 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName())); |
7111 | ASSERT_TRUE(optSN); |
7112 | TanhNode *optTanh = llvm::dyn_cast<TanhNode>(optSN->getInput()); |
7113 | ASSERT_TRUE(optTanh); |
7114 | NodeValue input = optTanh->getInput(); |
7115 | EXPECT_EQ(ElemKind::Float16Ty, input.getType()->getElementType()); |
7116 | |
7117 | checkNumericalEquivalence(); |
7118 | } |
7119 | |
/// Test that if we have a Concat with all ConvertTo inputs,
/// we can sink the ConvertTos below the Concat.
7122 | TEST_F(GraphOptz, SinkConvertToBelowConcatTest) { |
7123 | std::array<NodeValue, 5> inputs; |
7124 | for (dim_t i = 0; i < 5; i++) { |
7125 | Placeholder *input = mod_.createPlaceholder(ElemKind::Float16Ty, |
7126 | {i + 1, 100}, "input" , false); |
7127 | bindings_.allocate(input)->getHandle<float16_t>().randomize(-100, 100, |
7128 | mod_.getPRNG()); |
7129 | ConvertToNode *convertTo = |
7130 | F_->createConvertTo("convertToFP32" , input, ElemKind::FloatTy); |
7131 | inputs[i] = convertTo->getResult(); |
7132 | } |
7133 | ConcatNode *concat = F_->createConcat("concat" , inputs, 0); |
7134 | SaveNode *SN = F_->createSave("ret" , concat); |
7135 | EXPECT_EQ(F_->getNodes().size(), 7); |
7136 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConvertToNodeKind), 5); |
7137 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConcatNodeKind), 1); |
7138 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 1); |
7139 | |
7140 | CompilationContext cctx; |
7141 | |
7142 | optimizedF_ = optimizeFunctionForTest( |
7143 | F_, {FunctionPassID::SinkConversions, getDCEPassConfig()}, cctx); |
7144 | |
  // Concat, convertTo, save.
7146 | EXPECT_EQ(optimizedF_->getNodes().size(), 3); |
7147 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConvertToNodeKind), 1); |
7148 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1); |
7149 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1); |
7150 | |
7151 | SaveNode *optSN = |
7152 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName())); |
7153 | ASSERT_TRUE(optSN); |
7154 | ConvertToNode *optConvertTo = |
7155 | llvm::dyn_cast<ConvertToNode>(optSN->getInput()); |
7156 | ASSERT_TRUE(optConvertTo); |
7157 | NodeValue input = optConvertTo->getInput(); |
7158 | EXPECT_EQ(ElemKind::Float16Ty, input.getType()->getElementType()); |
7159 | |
7160 | checkNumericalEquivalence(); |
7161 | } |
7162 | |
7163 | /// Test Clip(Relu) -> Clip'. |
7164 | TEST_F(GraphOptz, ClipReluTest) { |
7165 | Placeholder *input = |
7166 | mod_.createPlaceholder(ElemKind::Float16Ty, {64, 64}, "input" , false); |
7167 | ReluNode *RN = F_->createRELU("RN" , input); |
7168 | ClipNode *CN = F_->createClip("CN" , RN, -5, 20); |
7169 | SaveNode *SN = F_->createSave("save" , CN); |
7170 | |
7171 | // Start with a clip, a relu, and a save. |
7172 | EXPECT_EQ(F_->getNodes().size(), 3); |
7173 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ClipNodeKind), 1); |
7174 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 1); |
7175 | |
7176 | optimizedF_ = optimizeFunctionForTest(F_); |
7177 | |
7178 | // Removed the relu |
7179 | EXPECT_EQ(optimizedF_->getNodes().size(), 2); |
7180 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ClipNodeKind), 1); |
7181 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ReluNodeKind), 0); |
7182 | |
7183 | SaveNode *optSN = |
7184 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName())); |
7185 | ASSERT_TRUE(optSN); |
7186 | |
7187 | // We have the same max for clip as before, but 0 for min due to the Relu. |
7188 | ClipNode *optCN = llvm::dyn_cast<ClipNode>(optSN->getInput()); |
7189 | ASSERT_TRUE(optCN); |
7190 | EXPECT_EQ(optCN->getMin(), 0); |
7191 | EXPECT_EQ(optCN->getMax(), 20); |
7192 | |
7193 | bindings_.allocate(input)->getHandle<float16_t>().randomize(-50.0, 5.0, |
7194 | mod_.getPRNG()); |
7195 | checkNumericalEquivalence(); |
7196 | } |
7197 | |
/// Test that if we have a Concat with some Dequantize inputs and a Quantize
/// after the Concat, we can move the Quantize above the Concat and eliminate
/// the Dequantizes.
7201 | TEST_F(GraphOptz, SinkConcatBelowQuantize) { |
7202 | const float scale = 0.06; |
7203 | const int32_t offset = -15; |
7204 | std::array<NodeValue, 3> inputs; |
7205 | |
7206 | // Concat input 0: Dequant(PH) |
7207 | const TypeRef in0QTy = |
7208 | mod_.uniqueType(ElemKind::Int8QTy, {1, 3}, scale, offset); |
7209 | Placeholder *input0 = mod_.createPlaceholder(in0QTy, "input" , false); |
7210 | inputs[0] = |
7211 | F_->createDequantize("deq" , input0, ElemKind::Float16Ty)->getResult(); |
7212 | |
7213 | // Concat input 1: Dequant(Add(PH, PH)) |
7214 | const TypeRef in1QTy = |
7215 | mod_.uniqueType(ElemKind::Int8QTy, {5, 3}, scale, offset + 1); |
7216 | Placeholder *input1 = mod_.createPlaceholder(in1QTy, "input" , false); |
7217 | AddNode *add = F_->createAdd("add" , input1, input1); |
7218 | inputs[1] = |
7219 | F_->createDequantize("deq" , add, ElemKind::Float16Ty)->getResult(); |
7220 | |
7221 | // Concat input 2: PH |
7222 | Placeholder *input2 = |
7223 | mod_.createPlaceholder(ElemKind::Float16Ty, {10, 3}, "input_fp" , false); |
7224 | inputs[2] = input2->getOutput(); |
7225 | |
7226 | // Concat all 3 together, all FP16. |
7227 | ConcatNode *concat = F_->createConcat("concat" , inputs, 0); |
7228 | |
7229 | // Now quantize the result of the concat. |
7230 | const TypeRef QTy = mod_.uniqueType( |
7231 | ElemKind::Int8QTy, concat->getResult().dims(), scale, offset); |
7232 | QuantizeNode *QN = F_->createQuantize("quantize" , concat, QTy); |
7233 | SaveNode *SN = F_->createSave("ret" , QN); |
7234 | |
7235 | optimizedF_ = optimizeFunctionForTest( |
7236 | F_, |
7237 | {FunctionPassID::SinkConcatBelowQuantize, |
7238 | {FunctionPassID::OptimizeQuantization, ConvergenceMode::UntilFixedPoint}, |
7239 | getDCEPassConfig()}); |
7240 | |
7241 | EXPECT_EQ(optimizedF_->getNodes().size(), 4); |
7242 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1); |
7243 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind), 1); |
7244 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::QuantizeNodeKind), 1); |
7245 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1); |
7246 | |
7247 | SaveNode *optSN = |
7248 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName())); |
7249 | ASSERT_TRUE(optSN); |
7250 | |
7251 | // Concat should be directly connected to save, with same quantization |
7252 | // parameters as the quantize which used to follow it. |
7253 | ConcatNode *optCN = llvm::dyn_cast<ConcatNode>(optSN->getInput()); |
7254 | ASSERT_TRUE(optCN); |
7255 | ASSERT_EQ(ElemKind::Int8QTy, optCN->getResult().getType()->getElementType()); |
7256 | EXPECT_EQ(scale, optCN->getResult().getType()->getScale()); |
7257 | EXPECT_EQ(offset, optCN->getResult().getType()->getOffset()); |
7258 | |
7259 | ASSERT_EQ(optCN->getInputs().size(), 3); |
7260 | |
  // No Rescale here for the PH, since its scale/offset already match the
  // Quantize's and so the Rescale is optimized away.
7263 | EXPECT_EQ(optCN->getInputs()[0], input0->getOutput()); |
7264 | |
  // No Rescale here either, because it is fused into optAN; check that
  // optAN's result uses that scale/offset.
7267 | AddNode *optAN = llvm::dyn_cast<AddNode>(optCN->getInputs()[1]); |
7268 | ASSERT_TRUE(optAN); |
7269 | ASSERT_EQ(ElemKind::Int8QTy, optAN->getResult().getType()->getElementType()); |
7270 | EXPECT_EQ(scale, optAN->getResult().getType()->getScale()); |
7271 | EXPECT_EQ(offset, optAN->getResult().getType()->getOffset()); |
7272 | EXPECT_EQ(optAN->getLHS(), input1->getOutput()); |
7273 | EXPECT_EQ(optAN->getRHS(), input1->getOutput()); |
7274 | |
7275 | // Must quantize this input since the PH is float16. |
7276 | QuantizeNode *optQN = llvm::dyn_cast<QuantizeNode>(optCN->getInputs()[2]); |
7277 | ASSERT_TRUE(optQN); |
7278 | ASSERT_EQ(ElemKind::Int8QTy, optQN->getResult().getType()->getElementType()); |
7279 | EXPECT_EQ(scale, optQN->getResult().getType()->getScale()); |
7280 | EXPECT_EQ(offset, optQN->getResult().getType()->getOffset()); |
7281 | EXPECT_EQ(optQN->getInput(), input2->getOutput()); |
7282 | |
7283 | bindings_.allocate(input0)->getHandle<int8_t>().randomize(-50, 50, |
7284 | mod_.getPRNG()); |
7285 | bindings_.allocate(input1)->getHandle<int8_t>().randomize(-50, 50, |
7286 | mod_.getPRNG()); |
7287 | bindings_.allocate(input2)->getHandle<float16_t>().randomize(-10, 10, |
7288 | mod_.getPRNG()); |
  checkNumericalEquivalence();
}
7290 | |
7291 | TEST_F(GraphOptz, EliminateSliceConcatTest) { |
7292 | auto *src1 = |
7293 | mod_.createPlaceholder(ElemKind::FloatTy, {10, 70}, "src1" , false); |
7294 | auto *src2 = |
7295 | mod_.createPlaceholder(ElemKind::FloatTy, {10, 80}, "src2" , false); |
7296 | auto *A = F_->createSlice("A" , src1, {0, 0}, {10, 10}); |
7297 | auto *B = F_->createSlice("B" , src1, {0, 10}, {10, 20}); |
7298 | auto *C = F_->createSlice("C" , src1, {0, 20}, {10, 30}); |
  // Interleaved Slices with different sources should not be merged.
7300 | auto *E = F_->createSlice("E" , src1, {0, 30}, {10, 40}); |
7301 | auto *F = F_->createSlice("F" , src2, {0, 30}, {10, 40}); |
7302 | auto *G = F_->createSlice("G" , src1, {0, 40}, {10, 50}); |
7303 | auto *H = F_->createSlice("H" , src2, {0, 40}, {10, 50}); |
7304 | |
7305 | auto *D = mod_.createPlaceholder(ElemKind::FloatTy, {10, 50}, "D" , false); |
7306 | auto *R = F_->createRELU("Relu" , C); |
7307 | auto *CN = F_->createConcat("Concat" , {A, B, D, E, F, G, H}, 1); |
7308 | F_->createSave("save1" , CN); |
7309 | F_->createSave("save2" , R); |
7310 | |
7311 | EXPECT_EQ(F_->getNodes().size(), 11); |
7312 | |
7313 | optimizedF_ = optimizeFunctionForTest( |
7314 | F_, {FunctionPassID::EliminateSliceConcat, getDCEPassConfig()}); |
7315 | |
7316 | EXPECT_EQ(optimizedF_->getNodes().size(), 10); |
7317 | |
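  // Only A and B merge: they are contiguous Slices of src1 that are adjacent
  // in the Concat. E/G and F/H interleave two sources, so they stay separate,
  // leaving 6 Concat inputs of which 5 are Slices.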
7318 | int numSlicesToConcat = 0; |
7319 | for (const auto &node : optimizedF_->getNodes()) { |
7320 | auto *newCN = llvm::dyn_cast<ConcatNode>(&node); |
7321 | if (!newCN) { |
7322 | continue; |
7323 | } |
7324 | EXPECT_EQ(newCN->getInputs().size(), 6); |
7325 | for (const auto &concatInput : newCN->getInputs()) { |
7326 | auto *SN = llvm::dyn_cast<SliceNode>(concatInput.getNode()); |
7327 | if (SN) { |
7328 | numSlicesToConcat++; |
7329 | } |
7330 | } |
7331 | } |
7332 | EXPECT_EQ(numSlicesToConcat, 5); |
7333 | |
7334 | bindings_.allocate(src1)->getHandle<float>().randomize(-10.0, 10.0, |
7335 | mod_.getPRNG()); |
7336 | bindings_.allocate(src2)->getHandle<float>().randomize(-10.0, 10.0, |
7337 | mod_.getPRNG()); |
7338 | bindings_.allocate(D)->getHandle<float>().randomize(-10.0, 10.0, |
7339 | mod_.getPRNG()); |
7340 | checkNumericalEquivalence(); |
7341 | } |
7342 | |
7343 | TEST_F(GraphOptz, EliminateSliceConcatWithReshapeTest) { |
7344 | auto *src = |
7345 | mod_.createPlaceholder(ElemKind::FloatTy, {4, 5, 4}, "src" , false); |
7346 | auto *A = F_->createSlice("A" , src, {0, 0, 0}, {1, 5, 4}); |
7347 | auto *B = F_->createSlice("B" , src, {1, 0, 0}, {2, 5, 4}); |
7348 | auto *C = F_->createSlice("C" , src, {2, 0, 0}, {3, 5, 4}); |
7349 | auto *CN1 = F_->createConcat("Concat1" , {A, B, C}, 1); |
7350 | |
7351 | auto *E = F_->createSlice("E" , src, {0, 0, 0}, {4, 5, 1}); |
7352 | auto *F = F_->createSlice("F" , src, {0, 0, 1}, {4, 5, 2}); |
7353 | auto *G = F_->createSlice("G" , src, {0, 0, 2}, {4, 5, 3}); |
7354 | auto *H = F_->createSlice("H" , src, {0, 0, 3}, {4, 5, 4}); |
7355 | auto *CN2 = F_->createConcat("Concat2" , {E, F, G, H}, 1); |
7356 | |
7357 | F_->createSave("save1" , CN1); |
7358 | F_->createSave("save2" , CN2); |
7359 | |
7360 | EXPECT_EQ(F_->getNodes().size(), 11); |
7361 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 7); |
7362 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConcatNodeKind), 2); |
7363 | |
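  // CN1 stitches dim-0 slices back together along dim 1, while CN2 stitches
  // dim-2 slices along dim 1; EliminateSliceConcat is expected to rewrite
  // both patterns using Reshapes (plus a Transpose for the CN2 case).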
7364 | optimizedF_ = optimizeFunctionForTest( |
7365 | F_, {FunctionPassID::EliminateSliceConcat, getDCEPassConfig()}); |
7366 | |
7367 | EXPECT_EQ(optimizedF_->getNodes().size(), 9); |
7368 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SliceNodeKind), 2); |
7369 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 2); |
7370 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ReshapeNodeKind), 2); |
7371 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::TransposeNodeKind), 1); |
7372 | |
7373 | bindings_.allocate(src)->getHandle<float>().randomize(-10.0, 10.0, |
7374 | mod_.getPRNG()); |
7375 | checkNumericalEquivalence(0.f); |
7376 | } |
7377 | |
7378 | // Check the merging of Sub(const, BN(x, scale, bias)) into BN. |
7379 | TEST_F(GraphOptz, FoldArithmeticChainIntoBatchNormQuant) { |
7380 | auto *subC = mod_.createConstant(ElemKind::FloatTy, {1, 1, 1, 1}, "subC" ); |
7381 | auto *var = mod_.createConstant(ElemKind::FloatTy, {1}, "var" ); |
7382 | auto *mean = mod_.createConstant(ElemKind::FloatTy, {1}, "mean" ); |
7383 | auto *beta = mod_.createConstant(ElemKind::FloatTy, {1}, "beta" ); |
7384 | auto *gamma = mod_.createConstant(ElemKind::FloatTy, {1}, "gamma" ); |
  float v = 0.3f, m = 0.4f, b = 0.7f, g = -0.5f, c = 0.1f;
7386 | // (X - mean) * (1.0 / sqrt(var + eps)) * gamma + beta |
7387 | var->getPayloadMutable().getHandle<float>() = {v}; |
7388 | mean->getPayloadMutable().getHandle<float>() = {m}; |
7389 | beta->getPayloadMutable().getHandle<float>() = {b}; |
7390 | gamma->getPayloadMutable().getHandle<float>() = {g}; |
7391 | subC->getPayloadMutable().getHandle<float>() = {c}; |
7392 | auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 1, 1}, "input" , |
7393 | false, "NHWC" ); |
7394 | |
7395 | auto *BN = F_->createBatchNormalization("batch" , input->getType(), input, |
7396 | beta, gamma, mean, var); |
7397 | auto *sub = F_->createSub("sub" , subC, BN); |
7398 | auto *res = F_->createSave("save" , sub); |
7399 | // Compile. |
7400 | EXPECT_EQ(F_->getNodes().size(), 3); |
7401 | ::glow::convertPlaceholdersToConstants(F_, bindings_, {}); |
7402 | |
7403 | optimizedF_ = optimizeFunctionForTest(F_, {}, cctx_); |
7404 | EXPECT_EQ(optimizedF_->getNodes().size(), 2); |
7405 | |
  auto *opt_res = findFunctionNodeByName<SaveNode>(optimizedF_, res->getName());
  ASSERT_TRUE(opt_res);
  auto *opt_bn = llvm::dyn_cast<BatchNormalizationNode>(opt_res->getInput());
  ASSERT_TRUE(opt_bn);
7409 | // Verify that scale and offset are computed correctly. |
  Constant *bnScale = llvm::dyn_cast<Constant>(opt_bn->getScale().getNode());
  Constant *bnBias = llvm::dyn_cast<Constant>(opt_bn->getBias().getNode());
  ASSERT_TRUE(bnScale);
  ASSERT_TRUE(bnBias);
  auto bnBiasVals = bnBias->getHandle<float>().raw(0);
  auto bnScaleVals = bnScale->getHandle<float>().raw(0);
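  // Sub(c, BN(x)) = c - (g * (x - m) / sqrt(v + eps) + b)
  //               = (-g) * (x - m) / sqrt(v + eps) + (c - b),
  // so the folded BN must have scale == -g and bias == c - b.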
7414 | EXPECT_EQ(bnBiasVals, c - b); |
7415 | EXPECT_EQ(bnScaleVals, -g); |
7416 | } |
7417 | |
/// Test that EliminateSliceConcat makes no optimization when the axes of
/// concatenation and slicing are not adjacent.
7420 | TEST_F(GraphOptz, EliminateSliceConcatWithReshapeTestNoChange) { |
7421 | auto *src = |
7422 | mod_.createPlaceholder(ElemKind::FloatTy, {4, 5, 4}, "src" , false); |
7423 | auto *A = F_->createSlice("A" , src, {0, 0, 0}, {1, 5, 4}); |
7424 | auto *B = F_->createSlice("B" , src, {1, 0, 0}, {2, 5, 4}); |
7425 | auto *C = F_->createSlice("C" , src, {2, 0, 0}, {3, 5, 4}); |
7426 | auto *CN = F_->createConcat("Concat" , {A, B, C}, 2); |
7427 | |
7428 | F_->createSave("save" , CN); |
7429 | |
7430 | EXPECT_EQ(F_->getNodes().size(), 5); |
7431 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 3); |
7432 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConcatNodeKind), 1); |
7433 | |
7434 | optimizedF_ = optimizeFunctionForTest( |
7435 | F_, {FunctionPassID::EliminateSliceConcat, getDCEPassConfig()}); |
7436 | |
7437 | EXPECT_EQ(optimizedF_->getNodes().size(), 5); |
7438 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SliceNodeKind), 3); |
7439 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1); |
7440 | EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false, |
7441 | /* skipName */ true), |
7442 | optimizedF_->toString(/* skipUsersForStorage */ false, |
7443 | /* skipName */ true)); |
7444 | } |
7445 | |
7446 | /// Verify that when we want to prevent constant folding it doesn't occur. |
7447 | TEST_F(GraphOptz, constantFoldPreventedNoop) { |
7448 | auto *const1 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const1" ); |
7449 | auto *const2 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const2" ); |
7450 | auto *ph1 = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "input1" , |
7451 | /* isTrainable */ false); |
7452 | setConstValue(const1, 1.0f); |
7453 | setConstValue(const2, 2.0f); |
7454 | auto *splat2 = F_->createSplat( |
7455 | "splat2" , mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 2.0f); |
7456 | auto *splat3 = F_->createSplat( |
7457 | "splat3" , mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 3.0f); |
7458 | |
7459 | auto *add1 = F_->createAdd("add" , const1, const2); |
7460 | auto *mul1 = F_->createMul("mul1" , add1, splat2); |
7461 | auto *mul2 = F_->createMul("mul2" , mul1, splat3); |
7462 | F_->createSave("save" , mul2); |
7463 | auto *add3 = F_->createAdd("add" , const1, ph1); |
7464 | F_->createSave("save" , add3); |
7465 | |
7466 | ConstantModificationPreventer constModPreventer(mod_, cctx_); |
7467 | constModPreventer.activate(); |
7468 | EXPECT_FALSE(cctx_.optimizationOpts.enableConstantFolding); |
7469 | |
7470 | // Check that both Constants are protected and no change is made to the |
7471 | // Function during optimization. |
7472 | EXPECT_EQ(constModPreventer.getMapping().size(), 2); |
7473 | optimizedF_ = optimizeFunctionForTest(F_); |
7474 | EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false, |
7475 | /* skipName */ true), |
7476 | optimizedF_->toString(/* skipUsersForStorage */ false, |
7477 | /* skipName */ true)); |
7478 | |
7479 | // Now deactivate the constModPreventer and check we can const fold still. |
7480 | constModPreventer.deactivateAndCleanup(); |
7481 | EXPECT_TRUE(cctx_.optimizationOpts.enableConstantFolding); |
7482 | mod_.eraseFunction(optimizedF_); |
7483 | optimizedF_ = optimizeFunctionForTest(F_); |
7484 | |
7485 | // After constant folding, left with just two Saves, one Add. |
7486 | EXPECT_EQ(optimizedF_->getNodes().size(), 3); |
7487 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind), 1); |
7488 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 2); |
7489 | |
7490 | bindings_.allocate(ph1)->getHandle<float>().randomize(-10.0, 10.0, |
7491 | mod_.getPRNG()); |
7492 | checkNumericalEquivalence(); |
7493 | } |
7494 | |
/// Test that a Conv2D is correctly lowered to FC for a single batch.
7496 | TEST_F(GraphOptz, lowerConv2DToFCSingleBatch) { |
7497 | Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, |
7498 | "input" , /* isTrainable */ false); |
7499 | bindings_.allocate(input)->getHandle<float>().randomize(-10, 10, |
7500 | mod_.getPRNG()); |
7501 | |
7502 | Constant *filter = |
7503 | mod_.createConstant(ElemKind::FloatTy, {8, 1, 1, 4}, "filter" ); |
7504 | filter->getPayloadMutable().getHandle<float>().randomize(-10, 10, |
7505 | mod_.getPRNG()); |
7506 | |
7507 | Constant *bias = mod_.createConstant(ElemKind::FloatTy, {8}, "bias" ); |
7508 | bias->getPayloadMutable().getHandle<float>().randomize(-10, 10, |
7509 | mod_.getPRNG()); |
7510 | |
7511 | auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 2, 3, 8}); |
7512 | auto *conv = F_->createConv("conv" , input, filter, bias, outTy, {1, 1}, |
7513 | {1, 1}, {0, 0, 0, 0}, 1); |
7514 | SaveNode *save = F_->createSave("save" , conv); |
7515 | bindings_.allocate(save->getPlaceholder()); |
7516 | |
7517 | // Backup function in optimizedF_. |
7518 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
7519 | |
7520 | // Lower Convolution. |
7521 | EXPECT_TRUE(isConvolutionSameAsFullyConnected(conv)); |
7522 | EXPECT_TRUE(glow::lowerNode(F_, conv, cctx_)); |
7523 | runDCEPass(F_, cctx_); |
7524 | EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::ConvolutionNodeKind)); |
7525 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind)); |
7526 | |
7527 | // Now compile/run/compare F_ and optimizedF_. |
7528 | checkNumericalEquivalence(1e-6); |
7529 | } |
7530 | |
/// Test that a Conv2D is correctly lowered to FC for multiple batches.
7532 | TEST_F(GraphOptz, lowerConv2DToFCMultiBatch) { |
7533 | Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 3, 4}, |
7534 | "input" , /* isTrainable */ false); |
7535 | bindings_.allocate(input)->getHandle<float>().randomize(-10, 10, |
7536 | mod_.getPRNG()); |
7537 | |
7538 | Constant *filter = |
7539 | mod_.createConstant(ElemKind::FloatTy, {8, 1, 1, 4}, "filter" ); |
7540 | filter->getPayloadMutable().getHandle<float>().randomize(-10, 10, |
7541 | mod_.getPRNG()); |
7542 | |
7543 | Constant *bias = mod_.createConstant(ElemKind::FloatTy, {8}, "bias" ); |
7544 | bias->getPayloadMutable().getHandle<float>().randomize(-10, 10, |
7545 | mod_.getPRNG()); |
7546 | |
7547 | auto outTy = mod_.uniqueType(ElemKind::FloatTy, {2, 2, 3, 8}); |
7548 | auto *conv = F_->createConv("conv" , input, filter, bias, outTy, {1, 1}, |
7549 | {1, 1}, {0, 0, 0, 0}, 1); |
7550 | SaveNode *save = F_->createSave("save" , conv); |
7551 | bindings_.allocate(save->getPlaceholder()); |
7552 | |
7553 | // Backup function in optimizedF_. |
7554 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
7555 | |
7556 | // Lower Convolution. |
7557 | EXPECT_TRUE(isConvolutionSameAsFullyConnected(conv)); |
7558 | EXPECT_TRUE(glow::lowerNode(F_, conv, cctx_)); |
7559 | runDCEPass(F_, cctx_); |
7560 | EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::ConvolutionNodeKind)); |
7561 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind)); |
7562 | |
7563 | // Now compile/run/compare F_ and optimizedF_. |
7564 | checkNumericalEquivalence(1e-6); |
7565 | } |
7566 | |
7567 | /// Test that Mul and Add can be folded into LayerNorm. |
7568 | TEST_F(GraphOptz, foldMulAddIntoLayerNorm) { |
7569 | auto *input = |
7570 | mod_.createPlaceholder(ElemKind::FloatTy, {2, 4, 10, 20}, "in" , false); |
7571 | |
7572 | Tensor scaleT(ElemKind::FloatTy, {10, 20}); |
7573 | scaleT.getHandle().randomize(0.0f, 1.0f, mod_.getPRNG()); |
7574 | Constant *scaleC = mod_.createConstant("scale" , std::move(scaleT)); |
7575 | SplatNode *biasS = F_->createSplat("bias" , scaleC->getType(), 1.5f); |
7576 | |
7577 | auto *LN = F_->createLayerNormalization("LN" , input->getType(), input, scaleC, |
7578 | biasS, 1e-5); |
7579 | |
7580 | SplatNode *splat = F_->createSplat("splat" , scaleC->getType(), 0.5f); |
7581 | MulNode *MN = |
7582 | F_->createNodeWithBroadcast<MulNode>("mul" , /* axis */ -1, LN, splat); |
7583 | |
7584 | Tensor addT(ElemKind::FloatTy, {1, 1, 10, 20}); |
7585 | addT.getHandle().randomize(-1.0f, 1.0f, mod_.getPRNG()); |
7586 | Constant *addC = mod_.createConstant("addC" , std::move(addT)); |
7587 | AddNode *AN = |
7588 | F_->createNodeWithBroadcast<AddNode>("add" , /* axis */ -1, MN, addC); |
7589 | |
  // This MulNode has a Placeholder as RHS (the ConstantModificationPreventer
  // below swaps the Constant for a Placeholder) and shouldn't be fused into
  // the LayerNorm.
7591 | Tensor mulT(ElemKind::FloatTy, {1, 1, 10, 20}); |
7592 | mulT.getHandle().randomize(0.0f, 1.0f, mod_.getPRNG()); |
7593 | Constant *mulC = mod_.createConstant("mulC" , std::move(mulT)); |
7594 | MN = F_->createNodeWithBroadcast<MulNode>("mul_not_fuse" , /* axis */ -1, AN, |
7595 | mulC); |
7596 | F_->createSave("save" , MN); |
7597 | |
7598 | ConstantModificationPreventer constModPreventer(mod_, cctx_); |
7599 | constModPreventer.activate(); |
7600 | optimizedF_ = optimizeFunctionForTest(F_, {}, cctx_); |
7601 | // Now do const folding with constants swapped back in. |
7602 | constModPreventer.deactivateAndCleanup(); |
7603 | ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_); |
7604 | runDCEPass(optimizedF_, cctx_); |
7605 | |
7606 | // Because Muls and Add are folded in, they should not exist anymore, nor |
7607 | // should Broadcasts that expand them to match the output of LN. |
7608 | EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MulNodeKind)); |
7609 | EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind)); |
7610 | EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::BroadcastNodeKind)); |
7611 | |
7612 | // Remove the temporary constant folding Functions and their Placeholders |
7613 | // so that they don't participate in 'checkNumericalEquivalence'. |
7614 | cleanupConstantFolding(mod_, record, &bindings_); |
7615 | |
7616 | // Now compile/run/compare F_ and optimizedF_. |
7617 | bindings_.allocate(input)->getHandle().randomize(0.0f, 1.0f, mod_.getPRNG()); |
7618 | checkNumericalEquivalence(1.2e-7); |
7619 | } |
7620 | |
7621 | /// Test that Mul and Add can be folded into LayerNorm when the leading dims are |
7622 | /// all one. |
7623 | TEST_F(GraphOptz, foldMulAddIntoLayerNormNoBatch) { |
7624 | auto *input = |
7625 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 10, 20}, "in" , false); |
7626 | |
7627 | Tensor scaleT(ElemKind::FloatTy, {10, 20}); |
7628 | scaleT.getHandle().randomize(0.0f, 1.0f, mod_.getPRNG()); |
7629 | Constant *scaleC = mod_.createConstant("scale" , std::move(scaleT)); |
7630 | SplatNode *biasS = F_->createSplat("bias" , scaleC->getType(), 1.5f); |
7631 | |
7632 | auto *LN = F_->createLayerNormalization("LN" , input->getType(), input, scaleC, |
7633 | biasS, 1e-5); |
7634 | |
7635 | SplatNode *splat = F_->createSplat("splat" , scaleC->getType(), 0.5f); |
7636 | MulNode *MN = |
7637 | F_->createNodeWithBroadcast<MulNode>("mul" , /* axis */ -1, LN, splat); |
7638 | |
7639 | Tensor addT(ElemKind::FloatTy, {1, 1, 10, 20}); |
7640 | addT.getHandle().randomize(-1.0f, 1.0f, mod_.getPRNG()); |
7641 | Constant *addC = mod_.createConstant("addC" , std::move(addT)); |
7642 | AddNode *AN = |
7643 | F_->createNodeWithBroadcast<AddNode>("add" , /* axis */ -1, MN, addC); |
7644 | F_->createSave("save" , AN); |
7645 | |
7646 | optimizedF_ = optimizeFunctionForTest(F_); |
7647 | |
7648 | // Because Mul and Add are folded in, they should not exist anymore, nor |
7649 | // should tiles that expand them to match the output of LN. |
7650 | EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MulNodeKind)); |
7651 | EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind)); |
7652 | EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::TileNodeKind)); |
7653 | |
7654 | // Now compile/run/compare F_ and optimizedF_. |
7655 | bindings_.allocate(input)->getHandle().randomize(0.0f, 1.0f, mod_.getPRNG()); |
7656 | checkNumericalEquivalence(1e-6); |
7657 | } |
7658 | |
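/// Test that Transpose(Quantize(Constant)) is rewritten as a Quantize of the
/// transposed Constant, and that custom type alignments are preserved.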
7659 | TEST_F(GraphOptz, transposeQuantizeConstantWithAlignment) { |
7660 | // Define a type with custom alignments. |
7661 | Type typeWithAlignments(ElemKind::FloatTy, {2, 3, 4, 5}, {1, 1, 32, 1}); |
7662 | Type quantTypeWithAlignments(ElemKind::Int8QTy, {2, 3, 4, 5}, {1, 1, 32, 1}, |
7663 | 1.0, 0); |
7664 | Type transposedQuantTypeWithAlignments(ElemKind::Int8QTy, {2, 4, 5, 3}, |
7665 | {1, 1, 32, 1}, 1.0, 0); |
7666 | auto modTyWithAlignments = mod_.uniqueType(typeWithAlignments); |
7667 | auto modQuantTransposedTyWithAlignments = |
7668 | mod_.uniqueType(transposedQuantTypeWithAlignments); |
7669 | auto modQuantTyWithAlignments = mod_.uniqueType(quantTypeWithAlignments); |
7670 | auto *I = mod_.createConstant(modTyWithAlignments, "input1" ); |
7671 | auto *Q = F_->createQuantize("quantize" , I, modQuantTyWithAlignments); |
7672 | auto *T = F_->createTranspose("transpose" , Q, NCHW2NHWC); |
7673 | T->setType(TransposeNode::ResultIdx, modQuantTransposedTyWithAlignments); |
7674 | SaveNode *S = F_->createSave("ret" , T); |
7675 | |
7676 | // Skip ConstantFolding as it would have the same result as this opt. |
7677 | CompilationContext cctx; |
7678 | cctx.optimizationOpts.enableConstantFolding = false; |
7679 | |
7680 | EXPECT_EQ(F_->getNodes().size(), 3); |
7681 | ::glow::optimize(F_, cctx); |
7682 | EXPECT_EQ(F_->getNodes().size(), 2); |
7683 | |
7684 | // Constant and Quantize should have new shape. |
7685 | auto *newQ = llvm::dyn_cast<QuantizeNode>(S->getInput()); |
7686 | ASSERT_TRUE(newQ); |
7687 | EXPECT_TRUE(newQ->getResult().dims().equals({2, 4, 5, 3})); |
7688 | auto *newC = llvm::dyn_cast<Constant>(newQ->getInput()); |
7689 | ASSERT_TRUE(newC); |
7690 | EXPECT_TRUE(newC->getType()->dims().equals({2, 4, 5, 3})); |
7691 | |
7692 | // Check that alignments are preserved by optimizations. |
7693 | auto expectedNewTy = mod_.uniqueTypeWithNewShape( |
7694 | modTyWithAlignments, modQuantTransposedTyWithAlignments); |
7695 | EXPECT_TRUE(newQ->getInput().getType()->isEqual(expectedNewTy)); |
7696 | |
7697 | EXPECT_TRUE(F_->verify()); |
7698 | } |
7699 | |
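/// Test that Quantize(Swish(Dequantize(x))) is replaced by a quantized Swish
/// that consumes x directly.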
7700 | TEST_F(GraphOptz, DequantSwishQuantOpt) { |
7701 | const dim_t origDims[] = {1, 5, 10, 15}; |
7702 | Placeholder *A = mod_.createPlaceholder(ElemKind::Int8QTy, origDims, 0.039, 0, |
7703 | "input" , false); |
7704 | DequantizeNode *DN = F_->createDequantize("deq" , A, ElemKind::Float16Ty); |
7705 | SwishNode *swish = F_->createSwish("swish" , DN); |
7706 | QuantizeNode *QN = |
7707 | F_->createQuantize("quant" , swish, ElemKind::Int8QTy, 0.0204, -114); |
7708 | DequantizeNode *finalDN = |
7709 | F_->createDequantize("deq_final" , QN, ElemKind::Float16Ty); |
7710 | F_->createSave("ret" , finalDN); |
7711 | |
7712 | optimizedF_ = optimizeFunctionForTest( |
7713 | F_, {FunctionPassID::QuantizeSwish, getDCEPassConfig()}); |
7714 | |
7715 | // Swish, Dequant, Save |
7716 | EXPECT_EQ(optimizedF_->getNodes().size(), 3); |
7717 | |
7718 | SaveNode *save = nullptr; |
7719 | for (auto &N : optimizedF_->getNodes()) { |
7720 | if (N.getKind() == Kinded::Kind::SaveNodeKind) { |
7721 | save = llvm::dyn_cast<SaveNode>(&N); |
7722 | break; |
7723 | } |
7724 | } |
7725 | ASSERT_TRUE(save); |
7726 | |
7727 | DequantizeNode *dequantizeOpt = |
7728 | llvm::dyn_cast<DequantizeNode>(save->getInput()); |
7729 | ASSERT_TRUE(dequantizeOpt); |
7730 | |
7731 | SwishNode *swishOpt = llvm::dyn_cast<SwishNode>(dequantizeOpt->getInput()); |
7732 | ASSERT_TRUE(swishOpt); |
7733 | EXPECT_EQ(swishOpt->getInput(), A->getOutput()); |
7734 | EXPECT_EQ(swishOpt->getResult().getType(), QN->getResult().getType()); |
7735 | |
7736 | bindings_.allocate(mod_.getPlaceholders()); |
7737 | bindings_.get(A)->getHandle<int8_t>().randomize(-128, 127, mod_.getPRNG()); |
7738 | |
7739 | checkNumericalEquivalence(0.025f); |
7740 | } |
7741 | |
7742 | /// Test the conversion of FullyConnected to 1x1 Convolution. |
7743 | TEST_F(GraphOptz, ConvertFullyConnectedToConvolutionOpt) { |
7744 | |
7745 | const std::vector<dim_t> inpDims = {3, 5}; |
7746 | const std::vector<dim_t> weightsDims = {5, 7}; |
7747 | const std::vector<dim_t> biasDims = {7}; |
7748 | |
7749 | // Create graph. |
7750 | Placeholder *input = |
7751 | mod_.createPlaceholder(ElemKind::FloatTy, inpDims, "input" , false); |
7752 | Placeholder *weights = |
7753 | mod_.createPlaceholder(ElemKind::FloatTy, weightsDims, "weights" , false); |
7754 | Placeholder *bias = |
7755 | mod_.createPlaceholder(ElemKind::FloatTy, biasDims, "bias" , false); |
7756 | FullyConnectedNode *FCN = |
7757 | F_->createFullyConnected("fc" , input, weights, bias); |
7758 | F_->createSave("save" , FCN); |
7759 | |
7760 | // Optimize graph. |
7761 | optimizedF_ = optimizeFunctionForTest( |
7762 | F_, |
7763 | {FunctionPassID::ConvertFullyConnectedToConvolution, getDCEPassConfig()}); |
7764 | |
7765 | // Check optimized graph. |
7766 | EXPECT_EQ(optimizedF_->getNodes().size(), 6); |
7767 | SaveNode *save = nullptr; |
7768 | for (auto &N : optimizedF_->getNodes()) { |
7769 | if (N.getKind() == Kinded::Kind::SaveNodeKind) { |
7770 | save = llvm::dyn_cast<SaveNode>(&N); |
7771 | break; |
7772 | } |
7773 | } |
7774 | ASSERT_TRUE(save); |
7775 | ReshapeNode *reshapeOut = llvm::dyn_cast<ReshapeNode>(save->getInput()); |
7776 | ASSERT_TRUE(reshapeOut); |
7777 | ConvolutionNode *conv = |
7778 | llvm::dyn_cast<ConvolutionNode>(reshapeOut->getInput()); |
7779 | ASSERT_TRUE(conv); |
7780 | ReshapeNode *reshapeFilter = llvm::dyn_cast<ReshapeNode>(conv->getFilter()); |
7781 | ASSERT_TRUE(reshapeFilter); |
7782 | TransposeNode *transpFilter = |
7783 | llvm::dyn_cast<TransposeNode>(reshapeFilter->getInput()); |
7784 | ASSERT_TRUE(transpFilter); |
7785 | ReshapeNode *reshapeInput = llvm::dyn_cast<ReshapeNode>(conv->getInput()); |
7786 | ASSERT_TRUE(reshapeInput); |
7787 | |
7788 | // Check numerical equivalence. |
7789 | bindings_.allocate(mod_.getPlaceholders()); |
7790 | bindings_.get(input)->getHandle<float>().randomize(-1, 1, mod_.getPRNG()); |
7791 | bindings_.get(weights)->getHandle<float>().randomize(-1, 1, mod_.getPRNG()); |
7792 | bindings_.get(bias)->getHandle<float>().randomize(-1, 1, mod_.getPRNG()); |
7793 | checkNumericalEquivalence(1e-8); |
7794 | } |
7795 | |
/// Test that when we have Concat({X, Quantize(Clip)}), we don't optimize it to
/// Concat({X, Quantize'}): Quantize' would have different quantization
/// parameters than X, and the Concat requires all of its inputs to share them.
7799 | TEST_F(GraphOptz, DisallowChangeQuantParamWithConcatInput) { |
7800 | Placeholder *PH1 = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 0.3, 5, |
7801 | "input" , false); |
7802 | bindings_.allocate(PH1)->getHandle<int8_t>().randomize(-128, 127, |
7803 | mod_.getPRNG()); |
7804 | Placeholder *PH2 = |
7805 | mod_.createPlaceholder(ElemKind::Float16Ty, {1, 32}, "input" , false); |
7806 | bindings_.allocate(PH2)->getHandle<float16_t>().randomize(-40.f, 40.f, |
7807 | mod_.getPRNG()); |
7808 | |
7809 | ClipNode *clip = F_->createClip("clip" , PH2, 0.f, 1000.f); |
7810 | QuantizeNode *quant = F_->createQuantize( |
7811 | "quantize" , clip, mod_.uniqueType(ElemKind::Int8QTy, {1, 32}, 0.3, 5)); |
7812 | |
7813 | ConcatNode *CN = F_->createConcat("concat" , {PH1, quant}, 0); |
7814 | F_->createSave("save" , CN); |
7815 | |
7816 | optimizedF_ = optimizeFunctionForTest(F_); |
7817 | |
  // Expect the graph didn't change at all: merging the Clip into the Quantize
  // would produce a Quantize' with different quantization parameters, and the
  // Quantize is consumed by a Concat which requires the quantization
  // parameters to stay the same across all inputs.
7822 | EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false, |
7823 | /* skipName */ true), |
7824 | optimizedF_->toString(/* skipUsersForStorage */ false, |
7825 | /* skipName */ true)); |
7826 | |
7827 | checkNumericalEquivalence(); |
7828 | } |
7829 | |
/// Test that an AdaptiveAvgPool with 1x1 OFM is correctly lowered to AvgPool.
7831 | TEST_F(GraphOptz, lower1x1AdaptiveAvgPoolToAvgPool) { |
7832 | Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 3, 4}, |
7833 | "input" , /* isTrainable */ false); |
7834 | bindings_.allocate(input)->getHandle<float>().randomize(-10, 10, |
7835 | mod_.getPRNG()); |
7836 | |
7837 | auto outTy = mod_.uniqueType(ElemKind::FloatTy, {2, 1, 1, 4}); |
7838 | auto *pool = F_->createAdaptiveAvgPool("avg" , input, outTy); |
7839 | SaveNode *save = F_->createSave("save" , pool); |
7840 | bindings_.allocate(save->getPlaceholder()); |
7841 | |
7842 | // Backup function in optimizedF_. |
7843 | optimizedF_ = F_->clone(F_->getName().str() + "_optimized" ); |
7844 | |
7845 | // Lower |
7846 | EXPECT_TRUE(glow::lowerNode(F_, pool, cctx_)); |
7847 | runDCEPass(F_, cctx_); |
7848 | EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::AdaptiveAvgPoolNodeKind)); |
7849 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::AvgPoolNodeKind)); |
7850 | |
7851 | // Now compile/run/compare F_ and optimizedF_. |
7852 | checkNumericalEquivalence(1e-6); |
7853 | } |
7854 | |
/// Skip the Clip-Quantize optimization when loadUniquedDummyQParams is set.
7856 | TEST_F(GraphOptz, SkipDummyQParamOpts) { |
7857 | Placeholder *A = mod_.createPlaceholder(ElemKind::FloatTy, {5}, "A" , false); |
7858 | ClipNode *CN = F_->createClip("clip" , A, -1000.f, 1000.f); |
7859 | QuantizeNode *QN = F_->createQuantize( |
7860 | "quantize" , CN, mod_.uniqueType(ElemKind::Int8QTy, {5}, 0.3, 5)); |
7861 | F_->createSave("ret" , QN); |
7862 | |
7863 | CompilationContext cctx; |
7864 | cctx.precisionConfig.loadUniquedDummyQParams = true; |
7865 | |
7866 | optimizedF_ = optimizeFunctionForTest(F_, {}, cctx); |
7867 | EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false, /* skipName */ true), |
7868 | optimizedF_->toString(/* skipUsersForStorage */ false, |
7869 | /* skipName */ true)); |
7870 | } |
7871 | |
/// Test that Min -> Max is correctly folded into Clip.
7873 | TEST_F(GraphOptz, foldMinMaxToClipTest) { |
7874 | Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 5}, |
7875 | "input" , /* isTrainable */ false); |
7876 | bindings_.allocate(input)->getHandle<float>().randomize(-10, 10, |
7877 | mod_.getPRNG()); |
7878 | |
7879 | auto *minFirstSplat = F_->createSplat("min_first_splat" , input->getType(), 5); |
7880 | auto *maxFirstSplat = |
7881 | F_->createSplat("max_first_splat" , input->getType(), -2); |
7882 | auto *minFirst = F_->createMin("min_first" , input, minFirstSplat); |
7883 | auto *maxFirst = F_->createMax("max_first" , maxFirstSplat, minFirst); |
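  // min(x, 5) followed by max(-2, .) is equivalent to Clip(x, -2, 5).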
7884 | |
7885 | auto *minSecondSplat = F_->createSplat( |
7886 | "min_second_splat" , |
7887 | F_->getParent()->uniqueTypeWithNewShape(input->getType(), {3, 1, 1}), 3); |
7888 | auto *maxSecondSplat = |
7889 | F_->createSplat("max_second_splat" , input->getType(), 1); |
7890 | auto *maxSecond = F_->createMax("max_second" , maxFirst, maxSecondSplat); |
7891 | auto *minSecond = F_->createNodeWithBroadcast<MinNode>( |
7892 | "min_second" , /* axis */ -1, maxSecond, minSecondSplat); |
7893 | SaveNode *save = F_->createSave("save" , minSecond); |
7894 | bindings_.allocate(save->getPlaceholder()); |
7895 | |
  // Need to run OptimizeArithmeticNodes first to move the constant operands of
  // commutative nodes to the RHS.
7898 | optimizedF_ = optimizeFunctionForTest( |
7899 | F_, {FunctionPassID::OptimizeArithmeticNodes, |
7900 | FunctionPassID::FoldMinMaxToClip, getDCEPassConfig()}); |
7901 | |
7902 | EXPECT_EQ(4, optimizedF_->getNodes().size()); |
7903 | EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MinNodeKind)); |
7904 | EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MaxNodeKind)); |
7905 | |
  // Get the SaveNode in optimizedF_.
  save = llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName("save_save" ));
  ASSERT_TRUE(save);
  // Check min and max of the second ClipNode.
  ClipNode *CN = llvm::dyn_cast<ClipNode>(save->getInput().getNode());
  ASSERT_TRUE(CN);
  EXPECT_EQ(1, CN->getMin());
  EXPECT_EQ(3, CN->getMax());

  // There's a BroadcastNode between the first and the second ClipNode.
  BroadcastNode *BN = llvm::dyn_cast<BroadcastNode>(CN->getInput().getNode());
  ASSERT_TRUE(BN);
  // Check min and max of the first ClipNode.
  CN = llvm::dyn_cast<ClipNode>(BN->getInput().getNode());
  ASSERT_TRUE(CN);
  EXPECT_EQ(-2, CN->getMin());
  EXPECT_EQ(5, CN->getMax());
7919 | |
7920 | checkNumericalEquivalence(); |
7921 | } |
7922 | |
/// Test that the Min -> Max fold pass does not break with a Reshape LHS input.
7924 | TEST_F(GraphOptz, foldMinMaxToClipReshapeNoBroadcastTest) { |
7925 | Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 100}, |
7926 | "input" , /* isTrainable */ false); |
7927 | bindings_.allocate(input)->getHandle<float>().randomize(-10, 10, |
7928 | mod_.getPRNG()); |
7929 | |
7930 | auto *reshape = F_->createReshape("reshape" , input, {100, 1}); |
7931 | const TypeRef T = reshape->getResult().getType(); |
7932 | |
7933 | auto *maxSplat = F_->createSplat("max_splat" , T, -2); |
7934 | auto *minSplat = F_->createSplat("min_splat" , T, 5); |
7935 | auto *max = F_->createMax("max" , reshape, maxSplat); |
7936 | auto *min = F_->createMin("min" , max, minSplat); |
7937 | SaveNode *save = F_->createSave("save" , min); |
7938 | bindings_.allocate(save->getPlaceholder()); |
7939 | |
7940 | optimizedF_ = optimizeFunctionForTest( |
7941 | F_, {FunctionPassID::FoldMinMaxToClip, getDCEPassConfig()}); |
7942 | |
7943 | EXPECT_EQ(3, optimizedF_->getNodes().size()); |
7944 | EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MinNodeKind)); |
7945 | EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MaxNodeKind)); |
7946 | |
7947 | save = llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName("save_save" )); |
7948 | ASSERT_TRUE(save); |
7949 | auto *CN = llvm::dyn_cast<ClipNode>(save->getInput().getNode()); |
7950 | ASSERT_TRUE(CN); |
7951 | EXPECT_EQ(-2, CN->getMin()); |
7952 | EXPECT_EQ(5, CN->getMax()); |
7953 | auto *RN = llvm::dyn_cast<ReshapeNode>(CN->getInput()); |
7954 | ASSERT_TRUE(RN); |
7955 | EXPECT_TRUE(RN->getResult().getType()->isEqual(T)); |
7956 | EXPECT_EQ(RN->getInput(), input->getOutput()); |
7957 | |
7958 | checkNumericalEquivalence(); |
7959 | } |
7960 | |
7961 | /// Check that we replace a Node with 0.f scale in fp16 with a splat correctly. |
7962 | TEST_F(GraphOptz, ReplaceZeroScaleFP16QuantOpt) { |
7963 | auto *LHS = mod_.createPlaceholder(ElemKind::FloatTy, {20, 30}, "LHS" , false); |
  auto *RHSQ = mod_.createPlaceholder(ElemKind::Int8QTy, {20, 30}, 0.1f, 10,
                                      "RHS" , false);
7966 | |
7967 | // scale = 1e-9 underflows fp16 and so this opt applies. |
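  // (The smallest positive fp16 subnormal is 2^-24 ~= 5.96e-8, so 1e-9
  // rounds to 0 when converted to fp16.)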
7968 | auto *LHSQTy = mod_.uniqueType(ElemKind::Int8QTy, {20, 30}, 1e-9, 10); |
7969 | auto *LHSQ = F_->createQuantize("LHSQ" , LHS, LHSQTy); |
7970 | |
7971 | auto *A = F_->createAdd("add" , RHSQ->getOutput().getType(), LHSQ, RHSQ); |
7972 | auto *Q = F_->createDequantize("deq" , A, ElemKind::FloatTy); |
7973 | F_->createSave("save" , Q); |
7974 | |
7975 | optimizedF_ = optimizeFunctionForTest( |
7976 | F_, {FunctionPassID::ReplaceZeroScaleFP16QuantNodes, getDCEPassConfig()}); |
7977 | |
7978 | SaveNode *save = nullptr; |
7979 | for (auto &N : optimizedF_->getNodes()) { |
7980 | if (N.getKind() == Kinded::Kind::SaveNodeKind) { |
7981 | save = llvm::dyn_cast<SaveNode>(&N); |
7982 | break; |
7983 | } |
7984 | } |
7985 | ASSERT_TRUE(save); |
7986 | |
7987 | DequantizeNode *optQ = llvm::dyn_cast<DequantizeNode>(save->getInput()); |
7988 | ASSERT_TRUE(optQ); |
7989 | AddNode *optA = llvm::dyn_cast<AddNode>(optQ->getInput()); |
  ASSERT_TRUE(optA);
7991 | |
7992 | SplatNode *splat = llvm::dyn_cast<SplatNode>(optA->getLHS()); |
7993 | ASSERT_TRUE(splat); |
7994 | EXPECT_EQ(splat->getValue(), 0.f); |
7995 | const TypeRef optLHSQTy = splat->getResult().getType(); |
7996 | EXPECT_EQ(optLHSQTy->getScale(), 1.f); |
7997 | EXPECT_EQ(optLHSQTy->getOffset(), 0); |
7998 | EXPECT_EQ(optLHSQTy->getElementType(), LHSQTy->getElementType()); |
7999 | EXPECT_EQ(optLHSQTy->dims(), LHSQTy->dims()); |
8000 | |
8001 | bindings_.allocate(LHS)->getHandle<float>().randomize(-10.f, 10.f, |
8002 | mod_.getPRNG()); |
8003 | bindings_.allocate(RHSQ)->getHandle<int8_t>().randomize(-128, 127, |
8004 | mod_.getPRNG()); |
8005 | |
8006 | checkNumericalEquivalence(0.f); |
8007 | } |
8008 | |
/// Same as GraphOptz, but numerical equivalence checks run on the CPU backend
/// instead of the Interpreter.
8011 | class GraphOptzOnCPU : public GraphOptz { |
8012 | public: |
8013 | GraphOptzOnCPU() : GraphOptz("CPU" ) {} |
8014 | #ifndef GLOW_WITH_CPU |
8015 | virtual void checkNumericalEquivalence(float allowedError = 0.0001) override { |
8016 | LOG(INFO) << "Skipping numerical equivalence check as the CPU backend is " |
8017 | "not built." ; |
8018 | } |
8019 | #endif /* GLOW_WITH_CPU */ |
8020 | }; |
8021 | |
/// Check that we replace a Constant with 0.f scale in fp16 with a splat
/// correctly.
8023 | TEST_F(GraphOptzOnCPU, ReplaceZeroScaleFP16QuantConstOpt) { |
8024 | auto *input = |
8025 | mod_.createPlaceholder(ElemKind::Int8QTy, {1, 1}, 1.0, 0, "input" , false); |
8026 | // scale = 1e-9 underflows fp16 and so this opt applies. |
8027 | auto *weights = |
8028 | mod_.createConstant(ElemKind::Int8QTy, {1, 1}, 1e-9, 0, "weights" ); |
8029 | weights->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127, |
8030 | mod_.getPRNG()); |
8031 | auto *MM = F_->createMatMul("matmul" , input, weights); |
8032 | auto *DQ = F_->createDequantize("dq" , MM, ElemKind::FloatTy); |
8033 | F_->createSave("save" , DQ); |
8034 | |
8035 | optimizedF_ = optimizeFunctionForTest( |
8036 | F_, {FunctionPassID::ReplaceZeroScaleFP16QuantNodes, getDCEPassConfig()}); |
8037 | |
8038 | SaveNode *save = nullptr; |
8039 | for (auto &N : optimizedF_->getNodes()) { |
8040 | if (N.getKind() == Kinded::Kind::SaveNodeKind) { |
8041 | save = llvm::dyn_cast<SaveNode>(&N); |
8042 | break; |
8043 | } |
8044 | } |
8045 | ASSERT_TRUE(save); |
8046 | |
8047 | auto *optDQ = llvm::dyn_cast<DequantizeNode>(save->getInput()); |
8048 | ASSERT_TRUE(optDQ); |
8049 | auto *optMM = llvm::dyn_cast<MatMulNode>(optDQ->getInput()); |
8050 | ASSERT_TRUE(optMM); |
8051 | |
8052 | SplatNode *splat = llvm::dyn_cast<SplatNode>(optMM->getRHS()); |
8053 | ASSERT_TRUE(splat); |
8054 | EXPECT_EQ(splat->getValue(), 0.f); |
8055 | const TypeRef splatQTy = splat->getResult().getType(); |
8056 | EXPECT_EQ(splatQTy->getScale(), 1.f); |
8057 | EXPECT_EQ(splatQTy->getOffset(), 0); |
8058 | EXPECT_EQ(splatQTy->getElementType(), weights->getOutput().getElementType()); |
8059 | EXPECT_EQ(splatQTy->dims(), weights->getOutput().dims()); |
8060 | |
8061 | bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127, |
8062 | mod_.getPRNG()); |
8063 | checkNumericalEquivalence(0.f); |
8064 | } |
8065 | |
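/// Test that Clips spanning the full FP16 range are eliminated when
/// clipQuantRangeToFP16 is enabled, while a Clip with a narrower range is
/// kept.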
8066 | TEST_F(GraphOptz, TestEliminateClipsOutsideFP16Range) { |
8067 | Placeholder *A = mod_.createPlaceholder(ElemKind::Float16Ty, {5}, "A" , false); |
8068 | ClipNode *CN1 = F_->createClipMinMaxFP16("clip1" , A); |
8069 | ClipNode *CN2 = F_->createClip("clip2" , A, kMinFP16, kMaxFP16 - 1.f); |
8070 | QuantizeNode *QN1 = F_->createQuantize( |
8071 | "q1" , CN1, mod_.uniqueType(ElemKind::Int8QTy, {5}, 0.3, 5)); |
8072 | QuantizeNode *QN2 = F_->createQuantize( |
8073 | "q2" , CN2, mod_.uniqueType(ElemKind::Int8QTy, {5}, 0.3, 5)); |
8074 | AddNode *AN = F_->createAdd("add" , QN1, QN2); |
8075 | DequantizeNode *DN = F_->createDequantize("dq" , AN, ElemKind::Float16Ty); |
8076 | ClipNode *CN3 = F_->createClipMinMaxFP16("clip3" , DN); |
8077 | F_->createSave("ret" , CN3); |
8078 | |
8079 | CompilationContext cctx; |
8080 | cctx.precisionConfig.clipQuantRangeToFP16 = true; |
8081 | |
8082 | optimizedF_ = optimizeFunctionForTest( |
8083 | F_, {FunctionPassID::EliminateClipsOutsideFP16Range, getDCEPassConfig()}, |
8084 | cctx); |
8085 | |
8086 | SaveNode *save = nullptr; |
8087 | for (auto &N : optimizedF_->getNodes()) { |
8088 | if (N.getKind() == Kinded::Kind::SaveNodeKind) { |
8089 | save = llvm::dyn_cast<SaveNode>(&N); |
8090 | break; |
8091 | } |
8092 | } |
8093 | ASSERT_TRUE(save); |
8094 | |
8095 | auto *optDQ = llvm::dyn_cast<DequantizeNode>(save->getInput()); |
8096 | ASSERT_TRUE(optDQ); |
8097 | auto *optAN = llvm::dyn_cast<AddNode>(optDQ->getInput()); |
8098 | ASSERT_TRUE(optAN); |
8099 | |
8100 | auto *optQN1 = llvm::dyn_cast<QuantizeNode>(optAN->getLHS()); |
8101 | ASSERT_TRUE(optQN1); |
8102 | EXPECT_EQ(optQN1->getInput(), A->getOutput()); |
8103 | |
8104 | auto *optQN2 = llvm::dyn_cast<QuantizeNode>(optAN->getRHS()); |
8105 | ASSERT_TRUE(optQN2); |
8106 | auto *optCN2 = llvm::dyn_cast<ClipNode>(optQN2->getInput()); |
8107 | ASSERT_TRUE(optCN2); |
8108 | EXPECT_EQ(optCN2->getMin(), CN2->getMin()); |
8109 | EXPECT_EQ(optCN2->getMax(), CN2->getMax()); |
8110 | EXPECT_EQ(optCN2->getInput(), A->getOutput()); |
8111 | |
8112 | bindings_.allocate(A)->getHandle<float16_t>().randomize(-128, 127, |
8113 | mod_.getPRNG()); |
8114 | checkNumericalEquivalence(0.f); |
8115 | } |
8116 | |
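/// Test that updateQuantReluTypes rewrites a quantized Relu's output type so
/// that its range starts at zero while keeping the max of its input's range.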
8117 | TEST_F(GraphOptz, TestUpdateQuantReluTypes) { |
8118 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 0.11, -1, |
8119 | "input" , false); |
8120 | auto *weights = mod_.createPlaceholder(ElemKind::Int8QTy, {32, 32}, 0.2, 3, |
8121 | "weights" , false); |
8122 | auto *bias = |
8123 | mod_.createPlaceholder(ElemKind::Int32QTy, {32}, 0.01, 2, "bias" , false); |
8124 | auto *addW = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 0.3, -4, |
8125 | "addw" , false); |
8126 | |
8127 | auto *fc = F_->createFullyConnected("fc" , input, weights, bias); |
8128 | auto *qRelu = F_->createRELU("relu" , fc->getResult()); |
8129 | auto *qAdd = F_->createAdd("add" , qRelu, addW); |
8130 | F_->createSave("save" , qAdd); |
8131 | |
8132 | updateQuantReluTypes(F_); |
8133 | |
8134 | const auto fcRange = fc->getResult().getType()->getQuantizedValueRange(); |
8135 | const auto reluRange = qRelu->getResult().getType()->getQuantizedValueRange(); |
8136 | EXPECT_NE(reluRange.first, fcRange.first); |
8137 | EXPECT_EQ(reluRange.first, 0); |
8138 | EXPECT_EQ(reluRange.second, fcRange.second); |
8139 | } |
8140 | |
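/// Same as TestUpdateQuantReluTypes, but also verifies that the updated Relu
/// type is propagated to the chain of shape users (Concat, Reshape) below it.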
8141 | TEST_F(GraphOptz, TestUpdateQuantReluTypesChained) { |
8142 | auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 0.11, -1, |
8143 | "input" , false); |
8144 | auto *weights = mod_.createPlaceholder(ElemKind::Int8QTy, {32, 32}, 0.2, 3, |
8145 | "weights" , false); |
8146 | auto *bias = |
8147 | mod_.createPlaceholder(ElemKind::Int32QTy, {32}, 0.01, 2, "bias" , false); |
8148 | auto *addW = |
8149 | mod_.createPlaceholder(ElemKind::Int8QTy, {128}, 0.3, -4, "addw" , false); |
8150 | |
8151 | auto *fc = F_->createFullyConnected("fc" , input, weights, bias); |
8152 | auto *qRelu = F_->createRELU("relu" , fc->getResult()); |
8153 | auto *qConcat = F_->createConcat("concat" , {qRelu, qRelu}, 0); |
8154 | auto *qReshape = F_->createReshape("reshape" , qConcat, {128}); |
8155 | auto *qAdd = F_->createAdd("add" , qReshape, addW); |
8156 | F_->createSave("save" , qAdd); |
8157 | |
8158 | updateQuantReluTypes(F_); |
8159 | |
8160 | const auto fcRange = fc->getResult().getType()->getQuantizedValueRange(); |
8161 | const auto reluRange = qRelu->getResult().getType()->getQuantizedValueRange(); |
8162 | EXPECT_NE(reluRange.first, fcRange.first); |
8163 | EXPECT_EQ(reluRange.first, 0); |
8164 | EXPECT_EQ(reluRange.second, fcRange.second); |
8165 | |
8166 | // Check that the relu's type now also matches that of the chain of shape |
8167 | // users after it. |
8168 | const TypeRef qReluTy = qRelu->getResult().getType(); |
8169 | EXPECT_EQ(qReluTy->getScale(), qConcat->getResult().getType()->getScale()); |
8170 | EXPECT_EQ(qReluTy->getOffset(), qConcat->getResult().getType()->getOffset()); |
8171 | EXPECT_EQ(qReluTy->getScale(), qReshape->getResult().getType()->getScale()); |
8172 | EXPECT_EQ(qReluTy->getOffset(), qReshape->getResult().getType()->getOffset()); |
8173 | } |
8174 | |
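/// Test that a Reshape is sunk below a Quantize, so the Quantize consumes the
/// original (unreshaped) tensor.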
8175 | TEST_F(GraphOptz, SinkReshapeBelowQuantize) { |
8176 | auto *I = mod_.createPlaceholder(ElemKind::FloatTy, {32, 64}, "A" , false); |
8177 | auto *RN = F_->createReshape("reshape" , I, {32, 64, 1}); |
8178 | auto *QN = F_->createQuantize("quantize" , RN, ElemKind::Int8QTy, 0.2f, 1); |
8179 | auto *SN = F_->createSave("ret" , QN); |
8180 | |
8181 | optimizedF_ = optimizeFunctionForTest( |
8182 | F_, {FunctionPassID::SinkCode, getDCEPassConfig()}); |
8183 | |
8184 | auto *optSN = |
8185 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName())); |
8186 | ASSERT_TRUE(optSN); |
8187 | auto *optRN = llvm::dyn_cast<ReshapeNode>(optSN->getInput()); |
8188 | ASSERT_TRUE(optRN); |
8189 | EXPECT_EQ(optRN->getResult().getElementType(), ElemKind::Int8QTy); |
8190 | EXPECT_EQ(optRN->getResult().getScale(), 0.2f); |
8191 | EXPECT_EQ(optRN->getResult().getOffset(), 1); |
8192 | EXPECT_EQ(optRN->getResult().dims(), RN->getResult().dims()); |
8193 | auto *optQN = llvm::dyn_cast<QuantizeNode>(optRN->getInput()); |
8194 | ASSERT_TRUE(optQN); |
8195 | EXPECT_EQ(optQN->getResult().getElementType(), ElemKind::Int8QTy); |
8196 | EXPECT_EQ(optQN->getResult().getScale(), 0.2f); |
8197 | EXPECT_EQ(optQN->getResult().getOffset(), 1); |
8198 | EXPECT_EQ(optQN->getInput().getNode(), I); |
8199 | |
8200 | bindings_.allocate(I)->getHandle<float>().randomize(-30, 30, mod_.getPRNG()); |
8201 | checkNumericalEquivalence(0.f); |
8202 | } |
8203 | |
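/// Test that a Reshape is sunk below a ConvertTo.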
8204 | TEST_F(GraphOptz, SinkReshapeBelowConvertTo) { |
8205 | auto *I = mod_.createPlaceholder(ElemKind::FloatTy, {32, 64}, "A" , false); |
8206 | auto *RN = F_->createReshape("reshape" , I, {32, 64, 1}); |
8207 | auto *CN = F_->createConvertTo("convert" , RN, ElemKind::Float16Ty); |
8208 | auto *SN = F_->createSave("ret" , CN); |
8209 | |
8210 | optimizedF_ = optimizeFunctionForTest( |
8211 | F_, {FunctionPassID::SinkCode, getDCEPassConfig()}); |
8212 | |
8213 | auto *optSN = |
8214 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName())); |
8215 | ASSERT_TRUE(optSN); |
8216 | auto *optRN = llvm::dyn_cast<ReshapeNode>(optSN->getInput()); |
8217 | ASSERT_TRUE(optRN); |
8218 | EXPECT_EQ(optRN->getResult().getElementType(), ElemKind::Float16Ty); |
8219 | EXPECT_EQ(optRN->getResult().dims(), RN->getResult().dims()); |
8220 | auto *optCN = llvm::dyn_cast<ConvertToNode>(optRN->getInput()); |
8221 | ASSERT_TRUE(optCN); |
8222 | EXPECT_EQ(optCN->getResult().getElementType(), ElemKind::Float16Ty); |
8223 | EXPECT_EQ(optCN->getInput().getNode(), I); |
8224 | |
8225 | bindings_.allocate(I)->getHandle<float>().randomize(-30, 30, mod_.getPRNG()); |
8226 | checkNumericalEquivalence(0.f); |
8227 | } |
8228 | |
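/// Test that a Reshape is sunk below a chain of unary eltwise ops (Abs, Sin,
/// Clip, Tanh).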
8229 | TEST_F(GraphOptz, SinkReshapeBelowUnaryEltwiseOps) { |
8230 | const dim_t dimsIn[] = {10, 10}; |
8231 | const dim_t dimsOut[] = {5, 5, 4}; |
8232 | |
8233 | auto *in = mod_.createPlaceholder(ElemKind::FloatTy, dimsIn, "in" , false); |
8234 | auto *RN = F_->createReshape("reshape" , in, dimsOut); |
8235 | auto *AN = F_->createAbs("abs" , RN); |
8236 | auto *SN = F_->createSin("sin" , AN); |
8237 | auto *CN = F_->createClip("clip" , SN, -4.f, 5.f); |
8238 | auto *TN = F_->createTanh("tanh" , CN); |
8239 | auto *save = F_->createSave("ret" , TN); |
8240 | |
8241 | optimizedF_ = optimizeFunctionForTest(F_); |
8242 | |
8243 | auto *optSave = |
8244 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName())); |
8245 | ASSERT_TRUE(optSave); |
8246 | auto *optRN = llvm::dyn_cast<ReshapeNode>(optSave->getInput()); |
8247 | ASSERT_TRUE(optRN); |
8248 | EXPECT_EQ(optRN->getResult().dims(), llvm::makeArrayRef(dimsOut)); |
8249 | auto *optTN = llvm::dyn_cast<TanhNode>(optRN->getInput()); |
8250 | ASSERT_TRUE(optTN); |
8251 | EXPECT_EQ(optTN->getResult().dims(), llvm::makeArrayRef(dimsIn)); |
8252 | auto *optCN = llvm::dyn_cast<ClipNode>(optTN->getInput()); |
8253 | ASSERT_TRUE(optCN); |
8254 | EXPECT_FLOAT_EQ(optCN->getMin(), CN->getMin()); |
8255 | EXPECT_FLOAT_EQ(optCN->getMax(), CN->getMax()); |
8256 | EXPECT_EQ(optCN->getResult().dims(), llvm::makeArrayRef(dimsIn)); |
8257 | auto *optSN = llvm::dyn_cast<SinNode>(optCN->getInput()); |
8258 | ASSERT_TRUE(optSN); |
8259 | EXPECT_EQ(optSN->getResult().dims(), llvm::makeArrayRef(dimsIn)); |
8260 | auto *optAN = llvm::dyn_cast<AbsNode>(optSN->getInput()); |
8261 | ASSERT_TRUE(optAN); |
8262 | EXPECT_EQ(optAN->getResult().dims(), llvm::makeArrayRef(dimsIn)); |
8263 | |
8264 | bindings_.allocate(in)->getHandle<float>().randomize(-30.f, 30.f, |
8265 | mod_.getPRNG()); |
8266 | checkNumericalEquivalence(0.f); |
8267 | } |
8268 | |
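/// Test that a Dequantize followed by a ConvertTo is collapsed into a single
/// Dequantize directly to the final element type.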
8269 | TEST_F(GraphOptz, OptConvertToDequantize) { |
8270 | auto *I = |
8271 | mod_.createPlaceholder(ElemKind::Int8QTy, {32, 64}, 0.2f, 1, "A" , false); |
8272 | auto *DN = F_->createDequantize("deq" , I, ElemKind::Float16Ty); |
8273 | auto *CN = F_->createConvertTo("convert" , DN, ElemKind::FloatTy); |
8274 | auto *SN = F_->createSave("ret" , CN); |
8275 | |
8276 | optimizedF_ = optimizeFunctionForTest( |
8277 | F_, |
8278 | {FunctionPassID::OptimizeOutIntermediateConversions, getDCEPassConfig()}); |
8279 | |
8280 | auto *optSN = |
8281 | llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName())); |
8282 | ASSERT_TRUE(optSN); |
8283 | auto *optDN = llvm::dyn_cast<DequantizeNode>(optSN->getInput()); |
8284 | ASSERT_TRUE(optDN); |
8285 | EXPECT_EQ(optDN->getResult().getElementType(), ElemKind::FloatTy); |
8286 | EXPECT_EQ(optDN->getResult().dims(), DN->getResult().dims()); |
8287 | EXPECT_EQ(optDN->getInput().getNode(), I); |
8288 | |
8289 | bindings_.allocate(I)->getHandle<int8_t>().randomize(-128, 127, |
8290 | mod_.getPRNG()); |
8291 | checkNumericalEquivalence(0.007f); |
8292 | } |
8293 | |
8294 | /// Test that Exp+ReduceSum+Div is replaced with SoftMax. |
8295 | TEST_F(GraphOptz, FoldExpSumDivIntoSoftmax) { |
8296 | Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 10}, |
8297 | "input" , /* isTrainable */ false); |
8298 | bindings_.allocate(input)->getHandle<float>().randomize(-10, 10, |
8299 | mod_.getPRNG()); |
8300 | auto *exp = F_->createExp("exp" , input); |
8301 | auto *reduceSum = F_->createBatchedReduceAdd("reduce_sum" , exp, {1}); |
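  // exp(x) divided by its row-wise sum is exactly softmax(x) along axis 1.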
8302 | auto *div = F_->createNodeWithBroadcast<DivNode>("div" , 1, exp, reduceSum); |
8303 | F_->createSave("save" , div); |
8304 | |
8305 | EXPECT_EQ(5, F_->getNodes().size()); |
8306 | |
8307 | optimizedF_ = optimizeFunctionForTest( |
8308 | F_, {FunctionPassID::FoldExpSumDivIntoSoftmax, getDCEPassConfig()}); |
8309 | |
8310 | EXPECT_EQ(2, optimizedF_->getNodes().size()); |
8311 | |
8312 | EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::ExpNodeKind)); |
8313 | EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::DivNodeKind)); |
8314 | EXPECT_EQ(0, |
8315 | countNodeKind(optimizedF_, Kinded::Kind::BatchedReduceAddNodeKind)); |
8316 | EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SoftMaxNodeKind)); |
8317 | |
8318 | checkNumericalEquivalence(1e-7f); |
8319 | } |
8320 | |
/// Test that an identity Relu is removed.
8322 | TEST_F(GraphOptz, RemoveIdentityRelu) { |
8323 | |
8324 | Placeholder *input = mod_.createPlaceholder( |
8325 | ElemKind::Int8QTy, {20}, 0.123f, -128, "input" , /* isTrainable */ false); |
8326 | bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127, |
8327 | mod_.getPRNG()); |
  auto *relu = F_->createRELU("relu" , input);
8329 | F_->createSave("save" , relu); |
8330 | |
8331 | EXPECT_EQ(2, F_->getNodes().size()); |
8332 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ReluNodeKind)); |
8333 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::SaveNodeKind)); |
8334 | |
8335 | optimizedF_ = optimizeFunctionForTest( |
8336 | F_, {FunctionPassID::RemoveIdentityRelu, getDCEPassConfig()}); |
8337 | |
8338 | EXPECT_EQ(1, optimizedF_->getNodes().size()); |
8339 | EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::ReluNodeKind)); |
8340 | EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind)); |
8341 | |
8342 | checkNumericalEquivalence(0); |
8343 | } |
8344 | |
/// Test that an identity Clip is removed.
8346 | TEST_F(GraphOptz, RemoveIdentityClip) { |
8347 | |
8348 | Placeholder *input = |
8349 | mod_.createPlaceholder(ElemKind::Int8QTy, {20}, 0.023529412f, -128, |
8350 | "input" , /* isTrainable */ false); |
8351 | bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127, |
8352 | mod_.getPRNG()); |
  auto *clip = F_->createClip("clip" , input, 0.0f, 6.0f);
8354 | F_->createSave("save" , clip); |
8355 | |
8356 | EXPECT_EQ(2, F_->getNodes().size()); |
8357 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ClipNodeKind)); |
8358 | EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::SaveNodeKind)); |
8359 | |
8360 | optimizedF_ = optimizeFunctionForTest( |
8361 | F_, {FunctionPassID::RemoveIdentityClip, getDCEPassConfig()}); |
8362 | |
8363 | EXPECT_EQ(1, optimizedF_->getNodes().size()); |
8364 | EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::ClipNodeKind)); |
8365 | EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind)); |
8366 | |
8367 | checkNumericalEquivalence(0); |
8368 | } |
8369 | |
8370 | /// Test that an identity ResizeNearest is removed. |
8371 | TEST_F(GraphOptz, OptimizeIdentityResizeNearest) { |
8372 | Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 33, 33, 1}, |
8373 | "input" , /* isTrainable */ false); |
8374 | bindings_.allocate(input)->getHandle<float>().randomize(-10, 10, |
8375 | mod_.getPRNG()); |
8376 | auto *resize = F_->createResizeNearest("resize" , input, {1, 1, 1, 1}); |
8377 | F_->createSave("save" , resize); |
8378 | EXPECT_EQ(2, F_->getNodes().size()); |
8379 | optimizedF_ = optimizeFunctionForTest( |
8380 | F_, {FunctionPassID::OptimizeResize, getDCEPassConfig()}); |
8381 | EXPECT_EQ(1, optimizedF_->getNodes().size()); |
8382 | EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind)); |
8383 | checkNumericalEquivalence(1e-7f); |
8384 | } |
8385 | |
/// Test that a ResizeNearest whose scales include integer factors is split
/// into a Tile and a residual ResizeNearest.
8387 | TEST_F(GraphOptz, OptimizeResizeNearest) { |
8388 | Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 33, 1}, |
8389 | "input" , /* isTrainable */ false); |
8390 | bindings_.allocate(input)->getHandle<float>().randomize(-10, 10, |
8391 | mod_.getPRNG()); |
8392 | auto *resize = F_->createResizeNearest("resize" , input, {1, 2, 7.787879, 1}); |
8393 | F_->createSave("save" , resize); |
8394 | EXPECT_EQ(2, F_->getNodes().size()); |
8395 | optimizedF_ = optimizeFunctionForTest( |
8396 | F_, {FunctionPassID::OptimizeResize, getDCEPassConfig()}); |
8397 | EXPECT_EQ(3, optimizedF_->getNodes().size()); |
8398 | EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::TileNodeKind)); |
8399 | EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::ResizeNearestNodeKind)); |
8400 | EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind)); |
8401 | checkNumericalEquivalence(1e-7f); |
8402 | } |
8403 | |
8404 | /// Test that an identity ResizeBilinear is removed. |
8405 | TEST_F(GraphOptz, OptimizeIdentityResizeBilinear) { |
8406 | Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 33, 33, 1}, |
8407 | "input" , /* isTrainable */ false); |
8408 | bindings_.allocate(input)->getHandle<float>().randomize(-10, 10, |
8409 | mod_.getPRNG()); |
8410 | auto *resize = F_->createResizeBilinear("resize" , input, {1, 1, 1, 1}); |
8411 | F_->createSave("save" , resize); |
8412 | EXPECT_EQ(2, F_->getNodes().size()); |
8413 | optimizedF_ = optimizeFunctionForTest( |
8414 | F_, {FunctionPassID::OptimizeResize, getDCEPassConfig()}); |
8415 | EXPECT_EQ(1, optimizedF_->getNodes().size()); |
8416 | EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind)); |
8417 | checkNumericalEquivalence(1e-7f); |
8418 | } |
8419 | |
/// Test that a ResizeBilinear whose scales include integer factors is split
/// into a Tile and a residual ResizeBilinear.
8421 | TEST_F(GraphOptz, OptimizeResizeBilinear) { |
8422 | Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 33, 1}, |
8423 | "input" , /* isTrainable */ false); |
8424 | bindings_.allocate(input)->getHandle<float>().randomize(-10, 10, |
8425 | mod_.getPRNG()); |
8426 | auto *resize = F_->createResizeBilinear("resize" , input, {1, 2, 7.787879, 1}); |
8427 | F_->createSave("save" , resize); |
8428 | EXPECT_EQ(2, F_->getNodes().size()); |
8429 | optimizedF_ = optimizeFunctionForTest( |
8430 | F_, {FunctionPassID::OptimizeResize, getDCEPassConfig()}); |
8431 | EXPECT_EQ(3, optimizedF_->getNodes().size()); |
8432 | EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::TileNodeKind)); |
8433 | EXPECT_EQ(1, |
8434 | countNodeKind(optimizedF_, Kinded::Kind::ResizeBilinearNodeKind)); |
8435 | EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind)); |
8436 | checkNumericalEquivalence(1e-7f); |
8437 | } |
8438 | |
/// Test that the Splat used as the Big operand of an InsertTensor is replaced
/// with a Touch node when the Small operand overwrites it entirely.
8441 | TEST_F(GraphOptz, OptimizeInsertTensorBigSplat) { |
8442 | Type bigTy(ElemKind::FloatTy, {10}); |
8443 | SplatNode *big = F_->createSplat("splat" , &bigTy, 0); |
8444 | Placeholder *small = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input" , |
8445 | /* isTrainable */ false); |
8446 | bindings_.allocate(small)->getHandle<float>().randomize(-10, 10, |
8447 | mod_.getPRNG()); |
8448 | auto *insert = F_->createInsertTensor("insert" , big, small, |
8449 | /* start */ {0}, |
8450 | /* count */ 10, |
8451 | /* axis */ 0); |
8452 | F_->createSave("save" , insert); |
8453 | EXPECT_EQ(3, F_->getNodes().size()); |
8454 | optimizedF_ = optimizeFunctionForTest( |
8455 | F_, {FunctionPassID::OptimizeInsert, getDCEPassConfig()}); |
8456 | EXPECT_EQ(3, optimizedF_->getNodes().size()); |
8457 | EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::TouchNodeKind)); |
8458 | EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::InsertTensorNodeKind)); |
8459 | EXPECT_EQ(1, countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind)); |
8460 | checkNumericalEquivalence(1e-7f); |
8461 | } |
8462 | |
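/// Test that a Transpose with multiple Quantize users is sunk below each of
/// them.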
8463 | TEST_F(GraphOptz, sinkQuantizeTransposeMultiUser) { |
8464 | auto *input = |
8465 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "input" , |
8466 | /* isTrainable */ false); |
8467 | auto *T = F_->createTranspose("transpose" , input, NHWC2NCHW); |
8468 | auto *Q1 = F_->createQuantize("q1" , T, ElemKind::Int8QTy, 0.11, 1); |
8469 | auto *Q2 = F_->createQuantize("q2" , T, ElemKind::Int8QTy, 0.12, 2); |
8470 | auto *S1 = F_->createSave("save1" , Q1); |
8471 | auto *S2 = F_->createSave("save2" , Q2); |
8472 | |
8473 | optimizedF_ = optimizeFunctionForTest(F_); |
8474 | |
8475 | auto *optS1 = findFunctionNodeByName<SaveNode>(optimizedF_, S1->getName()); |
8476 | auto *optS2 = findFunctionNodeByName<SaveNode>(optimizedF_, S2->getName()); |
8477 | |
8478 | // Check that transpose has been sunk below quantize now for both. |
8479 | EXPECT_TRUE(llvm::isa<TransposeNode>(optS1->getInput())); |
8480 | EXPECT_TRUE(llvm::isa<TransposeNode>(optS2->getInput())); |
8481 | |
8482 | bindings_.allocate(input)->getHandle<float>().randomize(-10, 10, |
8483 | mod_.getPRNG()); |
8484 | checkNumericalEquivalence(0.f); |
8485 | } |
8486 | |
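/// Test that a Transpose is not sunk below its Quantize user when it also has
/// a non-Quantize user.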
8487 | TEST_F(GraphOptz, skipSinkQuantizeTransposeMultiUser) { |
8488 | auto *input = |
8489 | mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "input" , |
8490 | /* isTrainable */ false); |
8491 | auto *T = F_->createTranspose("transpose" , input, NHWC2NCHW); |
8492 | auto *Q = F_->createQuantize("quant" , T, ElemKind::Int8QTy, 0.11, 1); |
8493 | F_->createSave("save1" , Q); |
8494 | F_->createSave("save2" , T); |
8495 | |
8496 | optimizedF_ = optimizeFunctionForTest(F_); |
8497 | |
8498 | // Verify the graph hasn't changed. |
8499 | EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false, /* skipName */ true), |
8500 | optimizedF_->toString(/* skipUsersForStorage */ false, |
8501 | /* skipName */ true)); |
8502 | } |
8503 | |
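/// Test that MatMuls sharing the same RHS are merged on the LHS even when one
/// MatMul on that RHS cannot participate due to a dependency chain.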
8504 | TEST_F(GraphOptz, MergeMatMulsOnLHSWhenSkippingOne) { |
8505 | Placeholder *LHS1 = |
8506 | mod_.createPlaceholder(ElemKind::FloatTy, {10, 10}, "LHS1" , false); |
8507 | Placeholder *LHS2 = |
8508 | mod_.createPlaceholder(ElemKind::FloatTy, {30, 10}, "LHS2" , false); |
8509 | Placeholder *LHS3 = |
8510 | mod_.createPlaceholder(ElemKind::FloatTy, {20, 10}, "LHS3" , false); |
8511 | Placeholder *RHS = |
8512 | mod_.createPlaceholder(ElemKind::FloatTy, {10, 15}, "RHS" , false); |
8513 | bindings_.allocate(LHS1)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG()); |
8514 | bindings_.allocate(LHS2)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG()); |
8515 | bindings_.allocate(LHS3)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG()); |
8516 | bindings_.allocate(RHS)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG()); |
8517 | |
8518 | // Chain a bunch of nodes together for LHS2 to prevent dependency analysis |
8519 | // from allowing merging for MM2 below. |
8520 | Node *sigLHS2 = LHS2; |
8521 | for (size_t i = 0, e = 7; i < e; i++) { |
8522 | sigLHS2 = F_->createSigmoid("s_lhs2" , sigLHS2); |
8523 | } |
8524 | |
8525 | Node *MM1 = F_->createMatMul("mm1" , LHS1, RHS); |
8526 | Node *MM2 = F_->createMatMul("mm2" , sigLHS2, RHS); |
8527 | Node *MM3 = F_->createMatMul("mm3" , LHS3, RHS); |
8528 | |
8529 | F_->createSave("save1" , MM1); |
8530 | F_->createSave("save2" , MM2); |
8531 | F_->createSave("save3" , MM3); |
8532 | ASSERT_TRUE(F_->verify()); |
8533 | |
8534 | optimizedF_ = optimizeFunctionForTest( |
8535 | F_, {FunctionPassID::MergeMatMulOnLHS, getDCEPassConfig()}); |
8536 | ASSERT_TRUE(optimizedF_->verify()); |
8537 | |
8538 | // Expect three matmuls -> two matmuls, because mm1 and mm3 were merged. |
8539 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 3); |
8540 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::MatMulNodeKind), 2); |
8541 | |
8542 | checkNumericalEquivalence(0.f); |
8543 | } |
8544 | |
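/// Test that MatMuls sharing the same LHS are merged on the RHS even when one
/// MatMul on that LHS cannot participate due to a dependency chain.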
8545 | TEST_F(GraphOptz, MergeMatMulsOnRHSWhenSkippingOne) { |
8546 | Placeholder *LHS = |
8547 | mod_.createPlaceholder(ElemKind::FloatTy, {40, 10}, "LHS" , false); |
8548 | Placeholder *RHS1 = |
8549 | mod_.createPlaceholder(ElemKind::FloatTy, {10, 15}, "RHS1" , false); |
8550 | Placeholder *RHS2 = |
8551 | mod_.createPlaceholder(ElemKind::FloatTy, {10, 20}, "RHS2" , false); |
8552 | Placeholder *RHS3 = |
8553 | mod_.createPlaceholder(ElemKind::FloatTy, {10, 30}, "RHS3" , false); |
8554 | bindings_.allocate(LHS)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG()); |
8555 | bindings_.allocate(RHS1)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG()); |
8556 | bindings_.allocate(RHS2)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG()); |
8557 | bindings_.allocate(RHS3)->getHandle().randomize(-1.f, 1.f, mod_.getPRNG()); |
8558 | |
8559 | // Chain a bunch of nodes together for RHS2 to prevent dependency analysis |
8560 | // from allowing merging for MM2 below. |
8561 | Node *sigRHS2 = RHS2; |
8562 | for (size_t i = 0, e = 7; i < e; i++) { |
8563 | sigRHS2 = F_->createSigmoid("s_rhs2" , sigRHS2); |
8564 | } |
8565 | |
8566 | Node *MM1 = F_->createMatMul("mm1" , LHS, RHS1); |
8567 | Node *MM2 = F_->createMatMul("mm2" , LHS, sigRHS2); |
8568 | Node *MM3 = F_->createMatMul("mm3" , LHS, RHS3); |
8569 | |
8570 | F_->createSave("save1" , MM1); |
8571 | F_->createSave("save2" , MM2); |
8572 | F_->createSave("save3" , MM3); |
8573 | ASSERT_TRUE(F_->verify()); |
8574 | |
8575 | optimizedF_ = optimizeFunctionForTest( |
8576 | F_, {FunctionPassID::MergeMatMulOnRHS, getDCEPassConfig()}); |
8577 | ASSERT_TRUE(optimizedF_->verify()); |
8578 | |
8579 | // Expect three matmuls -> two matmuls, because mm1 and mm3 were merged. |
8580 | EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 3); |
8581 | EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::MatMulNodeKind), 2); |
8582 | |
8583 | checkNumericalEquivalence(0.f); |
8584 | } |
8585 | |