1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "BackendTestUtils.h"
17
18#include "glow/Graph/Graph.h"
19#include "glow/Graph/Node.h"
20#include "glow/Graph/Nodes.h"
21#include "glow/Optimizer/GraphOptimizer/FunctionPassPipeline.h"
22#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
23#include "glow/Optimizer/GraphOptimizer/NodeSplitting.h"
24
25#include "gtest/gtest.h"
26
27using namespace glow;
28
/// Test fixture for node-splitting tests. It adds nothing of its own; the
/// members used by the tests below (F_, optimizedF_, bindings_, cctx_ and
/// checkNumericalEquivalence) are inherited from the GraphOptz harness.
class NodeSplitting : public GraphOptz {};
30
31bool operator==(const std::vector<dim_t> &lhs, const std::vector<dim_t> &rhs) {
32 return std::equal(lhs.begin(), lhs.end(), rhs.begin());
33}
34
35/// Test for checking the LLVM style RTTI used for split options.
36TEST(TestSplitNodeOption, CheckLLVMStyleRTTI) {
37 // Check orthogonal options.
38 auto opt1 = SplitNodeByNumChunks({0}, {1});
39 auto opt2 = SplitNodeByChunkSize({0}, {1});
40 auto opt3 = SplitNodeByChunkSizes({0}, {{1}});
41 auto opt4 = SplitNodeByChunkWeights({0}, {{1}});
42 EXPECT_EQ(opt1.getKind(),
43 SplitNodeOption::SplitNodeKind::SplitNodeByNumChunks);
44 EXPECT_EQ(opt2.getKind(),
45 SplitNodeOption::SplitNodeKind::SplitNodeByChunkSize);
46 EXPECT_EQ(opt3.getKind(),
47 SplitNodeOption::SplitNodeKind::SplitNodeByChunkSizes);
48 EXPECT_EQ(opt4.getKind(),
49 SplitNodeOption::SplitNodeKind::SplitNodeByChunkWeights);
50 std::vector<SplitNodeOption *> orthogonalOpts = {&opt1, &opt2, &opt3, &opt4};
51 for (auto opt : orthogonalOpts) {
52 EXPECT_NE(nullptr, dyn_cast<SplitNodeOptionOrthogonal>(opt));
53 EXPECT_EQ(nullptr, dyn_cast<SplitNodeBySliceRanges>(opt));
54 }
55 // Check non-orthogonal options.
56 std::vector<SliceRange> sliceRanges = {SliceRange({{0, 1}})};
57 auto opt5 = SplitNodeBySliceRanges(sliceRanges);
58 EXPECT_EQ(opt5.getKind(),
59 SplitNodeOption::SplitNodeKind::SplitNodeBySliceRanges);
60 SplitNodeOption *nonOrthogonalOpt = &opt5;
61 EXPECT_EQ(nullptr, dyn_cast<SplitNodeOptionOrthogonal>(nonOrthogonalOpt));
62 EXPECT_NE(nullptr, dyn_cast<SplitNodeBySliceRanges>(nonOrthogonalOpt));
63}
64
65/// Test for SplitNodeByNumChunks option.
66TEST(TestSplitNodeOption, SplitNodeByNumChunksOptionTest) {
67 auto opt1 = SplitNodeByNumChunks({0, 1, 2, 3}, {1, 2, 3, 4},
68 /* bigChunksFirst */ false);
69 EXPECT_EQ(opt1.splitAlongDim(0, 10), std::vector<dim_t>({10}));
70 EXPECT_EQ(opt1.splitAlongDim(1, 10), std::vector<dim_t>({5, 5}));
71 EXPECT_EQ(opt1.splitAlongDim(2, 10), std::vector<dim_t>({3, 3, 4}));
72 EXPECT_EQ(opt1.splitAlongDim(3, 10), std::vector<dim_t>({2, 2, 3, 3}));
73 EXPECT_EQ(opt1.splitAlongDim(3, 12), std::vector<dim_t>({3, 3, 3, 3}));
74
75 auto opt2 = SplitNodeByNumChunks({0, 1, 2, 3}, {1, 2, 3, 4},
76 /* bigChunksFirst */ true);
77 EXPECT_EQ(opt2.splitAlongDim(0, 10), std::vector<dim_t>({10}));
78 EXPECT_EQ(opt2.splitAlongDim(1, 10), std::vector<dim_t>({5, 5}));
79 EXPECT_EQ(opt2.splitAlongDim(2, 10), std::vector<dim_t>({4, 3, 3}));
80 EXPECT_EQ(opt2.splitAlongDim(3, 10), std::vector<dim_t>({3, 3, 2, 2}));
81 EXPECT_EQ(opt2.splitAlongDim(3, 12), std::vector<dim_t>({3, 3, 3, 3}));
82}
83
84/// Test for SplitNodeByChunkSize option.
TEST(TestSplitNodeOption, SplitNodeByChunkSizeOptionTest) {
  // With bigChunksFirst = false the smaller remainder chunk is emitted first
  // (e.g. size 10 with chunk size 3 -> {1, 3, 3, 3}).
  auto opt1 = SplitNodeByChunkSize({0, 1, 2, 3}, {3, 4, 5, 6},
                                   /* bigChunksFirst */ false);
  EXPECT_EQ(opt1.splitAlongDim(0, 10), std::vector<dim_t>({1, 3, 3, 3}));
  EXPECT_EQ(opt1.splitAlongDim(1, 10), std::vector<dim_t>({2, 4, 4}));
  EXPECT_EQ(opt1.splitAlongDim(2, 10), std::vector<dim_t>({5, 5}));
  EXPECT_EQ(opt1.splitAlongDim(3, 10), std::vector<dim_t>({4, 6}));
  // An evenly dividing size produces equal chunks.
  EXPECT_EQ(opt1.splitAlongDim(3, 18), std::vector<dim_t>({6, 6, 6}));

  // With bigChunksFirst = true the remainder chunk is emitted last.
  auto opt2 = SplitNodeByChunkSize({0, 1, 2, 3}, {3, 4, 5, 6},
                                   /* bigChunksFirst */ true);
  EXPECT_EQ(opt2.splitAlongDim(0, 10), std::vector<dim_t>({3, 3, 3, 1}));
  EXPECT_EQ(opt2.splitAlongDim(1, 10), std::vector<dim_t>({4, 4, 2}));
  EXPECT_EQ(opt2.splitAlongDim(2, 10), std::vector<dim_t>({5, 5}));
  EXPECT_EQ(opt2.splitAlongDim(3, 10), std::vector<dim_t>({6, 4}));
  EXPECT_EQ(opt2.splitAlongDim(3, 18), std::vector<dim_t>({6, 6, 6}));
}
102
103/// Test for SplitNodeByChunkSizes option.
104TEST(TestSplitNodeOption, SplitNodeByChunkSizesOptionTest) {
105 auto opt = SplitNodeByChunkSizes({0, 1, 2, 3},
106 {{1, 3, 3, 3}, {2, 4, 4}, {5, 5}, {4, 6}});
107 EXPECT_EQ(opt.splitAlongDim(0, 10), std::vector<dim_t>({1, 3, 3, 3}));
108 EXPECT_EQ(opt.splitAlongDim(1, 10), std::vector<dim_t>({2, 4, 4}));
109 EXPECT_EQ(opt.splitAlongDim(2, 10), std::vector<dim_t>({5, 5}));
110 EXPECT_EQ(opt.splitAlongDim(3, 10), std::vector<dim_t>({4, 6}));
111}
112
113/// Test for SplitNodeByChunkWeights option.
TEST(TestSplitNodeOption, SplitNodeByChunkWeightsOptionTest) {
  // Chunk sizes are proportional to the weights: a size-20 dimension with
  // weights {1, 3, 3, 3} (sum 10) yields chunks {2, 6, 6, 6}.
  auto opt1 = SplitNodeByChunkWeights(
      {0, 1, 2, 3}, {{1, 3, 3, 3}, {2, 4, 4}, {5, 5}, {4, 6}});
  EXPECT_EQ(opt1.splitAlongDim(0, 20), std::vector<dim_t>({2, 6, 6, 6}));
  EXPECT_EQ(opt1.splitAlongDim(1, 20), std::vector<dim_t>({4, 8, 8}));
  EXPECT_EQ(opt1.splitAlongDim(2, 20), std::vector<dim_t>({10, 10}));
  EXPECT_EQ(opt1.splitAlongDim(3, 20), std::vector<dim_t>({8, 12}));

  // Fractional weights summing to 1 behave like percentages of the size.
  auto opt2 = SplitNodeByChunkWeights({0}, {{0.15, 0.15, 0.2, 0.5}});
  EXPECT_EQ(opt2.splitAlongDim(0, 100), std::vector<dim_t>({15, 15, 20, 50}));

  // A tiny but non-zero weight still receives at least one element.
  auto opt3 = SplitNodeByChunkWeights({0}, {{0.00000001, 33, 66}});
  EXPECT_EQ(opt3.splitAlongDim(0, 100), std::vector<dim_t>({1, 33, 66}));
}
128
129///===---------------------------------------------------------------------===//
130/// Conv2D
131///===---------------------------------------------------------------------===//
132/// Utility function to create a simple network with a single Conv2D node using
133/// the function \p F and the bindings \p bindings.
134static Node *createConv2D(
135 Function *F, PlaceholderBindings &bindings, llvm::ArrayRef<dim_t> inputDims,
136 llvm::ArrayRef<dim_t> filterDims, llvm::ArrayRef<dim_t> biasDims,
137 llvm::ArrayRef<dim_t> outputDims, llvm::ArrayRef<unsigned_t> kernels,
138 llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
139 dim_t group, llvm::ArrayRef<unsigned_t> dilation) {
140 // Create input placeholder.
141 auto &mod = *(F->getParent());
142 auto *input =
143 mod.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
144 bindings.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
145 mod.getPRNG());
146 // Create filter constant.
147 auto *filter = mod.createConstant(ElemKind::FloatTy, filterDims, "filter");
148 filter->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0,
149 mod.getPRNG());
150 // Create bias constant.
151 auto *bias = mod.createConstant(ElemKind::FloatTy, biasDims, "bias");
152 bias->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0,
153 mod.getPRNG());
154 // Create Conv2D.
155 auto *outTy = mod.uniqueType(ElemKind::FloatTy, outputDims);
156 ConvolutionNode *conv =
157 F->createConv("conv", input, filter, bias, outTy, kernels, strides, pads,
158 group, dilation);
159 SaveNode *save = F->createSave("save", conv);
160 bindings.allocate(save->getPlaceholder());
161 return conv;
162}
163
164/// Utility function to test splitting a basic Conv2D node along the dimensions
165/// \p splitDims in the given number chunks \p numChunks. The split is done
166/// implicitly relative to the Conv2D output operand.
167static void splitConv2DBasic(Function *F, Function *&optF,
168 PlaceholderBindings &bindings,
169 CompilationContext &cctx,
170 llvm::ArrayRef<size_t> splitDims,
171 llvm::ArrayRef<dim_t> numChunks) {
172 Node *node = createConv2D(F, bindings,
173 /* inputDims */ {5, 7, 8, 2},
174 /* filterDims */ {8, 2, 2, 1},
175 /* biasDims */ {8},
176 /* outputDims */ {5, 6, 7, 8},
177 /* kernels */ {2, 2},
178 /* strides */ {1, 1},
179 /* pads */ {0, 0, 0, 0},
180 /* group */ 2,
181 /* dilation */ {1, 1});
182
183 // Save current function state as reference.
184 optF = F->clone(F->getName().str() + "_optimized");
185
186 // Split node.
187 auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
188 std::vector<Node *> splitNodes;
189 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
190
191 // Compute total number of chunks.
192 dim_t totNumChunks = 1;
193 for (auto numChunk : numChunks) {
194 totNumChunks *= numChunk;
195 }
196
197 // Check node count.
198 EXPECT_EQ(splitNodes.size(), totNumChunks);
199 EXPECT_EQ(countNodeKind(F, Kinded::Kind::ConvolutionNodeKind), totNumChunks);
200 EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), totNumChunks);
201 EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
202}
203
204/// Test splitting a Conv2D along dimension N, H, W or C.
205/// Not all the combinations are allowed when splitting along C.
#define TEST_CONV2D_BASIC_SPLIT(splitDim, numChunks)                           \
  TEST_F(NodeSplitting, Conv2D_Basic_Dim##splitDim##_Chunks##numChunks) {      \
    splitConv2DBasic(F_, optimizedF_, bindings_, cctx_,                        \
                     {ShapeNHWC::Dim##splitDim}, {numChunks});                 \
    checkNumericalEquivalence(0);                                              \
  }
// Instantiations along a single dimension. The Conv2D output is {5, 6, 7, 8}
// so N is split into at most 5 chunks, H into at most 6, W into at most 7.
// NOTE(review): C uses only 2/4/8 chunks — presumably constrained by the
// grouped convolution (group = 2); confirm against the splitting rules.
TEST_CONV2D_BASIC_SPLIT(N, 2)
TEST_CONV2D_BASIC_SPLIT(N, 3)
TEST_CONV2D_BASIC_SPLIT(N, 4)
TEST_CONV2D_BASIC_SPLIT(N, 5)
TEST_CONV2D_BASIC_SPLIT(H, 2)
TEST_CONV2D_BASIC_SPLIT(H, 3)
TEST_CONV2D_BASIC_SPLIT(H, 4)
TEST_CONV2D_BASIC_SPLIT(H, 5)
TEST_CONV2D_BASIC_SPLIT(H, 6)
TEST_CONV2D_BASIC_SPLIT(W, 2)
TEST_CONV2D_BASIC_SPLIT(W, 3)
TEST_CONV2D_BASIC_SPLIT(W, 4)
TEST_CONV2D_BASIC_SPLIT(W, 5)
TEST_CONV2D_BASIC_SPLIT(W, 6)
TEST_CONV2D_BASIC_SPLIT(W, 7)
TEST_CONV2D_BASIC_SPLIT(C, 2)
TEST_CONV2D_BASIC_SPLIT(C, 4)
TEST_CONV2D_BASIC_SPLIT(C, 8)
#undef TEST_CONV2D_BASIC_SPLIT
231
/// Test splitting a Conv2D along dimensions N, H (2 x 2 = 4 chunks).
TEST_F(NodeSplitting, Conv2D_Basic_DimNH_Chunks4) {
  splitConv2DBasic(F_, optimizedF_, bindings_, cctx_,
                   {ShapeNHWC::DimN, ShapeNHWC::DimH}, {2, 2});
  checkNumericalEquivalence(0);
}

/// Test splitting a Conv2D along dimensions N, H, W (2 x 2 x 2 = 8 chunks).
TEST_F(NodeSplitting, Conv2D_Basic_DimNHW_Chunks8) {
  splitConv2DBasic(F_, optimizedF_, bindings_, cctx_,
                   {ShapeNHWC::DimN, ShapeNHWC::DimH, ShapeNHWC::DimW},
                   {2, 2, 2});
  checkNumericalEquivalence(0);
}

/// Test splitting a Conv2D along all dimensions N, H, W, C (16 chunks).
TEST_F(NodeSplitting, Conv2D_Basic_DimNHWC_Chunks16) {
  splitConv2DBasic(
      F_, optimizedF_, bindings_, cctx_,
      {ShapeNHWC::DimN, ShapeNHWC::DimH, ShapeNHWC::DimW, ShapeNHWC::DimC},
      {2, 2, 2, 2});
  checkNumericalEquivalence(0);
}
255
256/// Utility function to test splitting a Conv2D node with non-zero padding
257/// along the dimensions \p splitDims in the given number chunks \p numChunks.
258/// The split is done implicitly relative to the Conv2D output operand.
259static void splitConv2DNonZeroPad(Function *F, Function *&optF,
260 PlaceholderBindings &bindings,
261 CompilationContext &cctx,
262 llvm::ArrayRef<size_t> splitDims,
263 llvm::ArrayRef<dim_t> numChunks) {
264 Node *node = createConv2D(F, bindings,
265 /* inputDims */ {1, 8, 9, 1},
266 /* filterDims */ {1, 2, 3, 1},
267 /* biasDims */ {1},
268 /* outputDims */ {1, 11, 10, 1},
269 /* kernels */ {2, 3},
270 /* strides */ {1, 1},
271 /* pads */ {2, 1, 3, 4},
272 /* group */ 1,
273 /* dilation */ {2, 2});
274
275 // Save current function state as reference.
276 optF = F->clone(F->getName().str() + "_optimized");
277
278 // Split node.
279 auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
280 std::vector<Node *> splitNodes;
281 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
282
283 // Compute total number of chunks.
284 dim_t totNumChunks = 1;
285 for (auto numChunk : numChunks) {
286 totNumChunks *= numChunk;
287 }
288
289 // Check node count.
290 EXPECT_EQ(splitNodes.size(), totNumChunks);
291 EXPECT_EQ(countNodeKind(F, Kinded::Kind::ConvolutionNodeKind), totNumChunks);
292 EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), totNumChunks);
293 EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
294}
295
296/// Test splitting a Conv2D with padding along dimension N, H, W or C.
#define TEST_CONV2D_NONZEROPAD_SPLIT(splitDim, numChunks)                      \
  TEST_F(NodeSplitting, Conv2D_NonZeroPad_Dim##splitDim##_Chunks##numChunks) { \
    splitConv2DNonZeroPad(F_, optimizedF_, bindings_, cctx_,                   \
                          {ShapeNHWC::Dim##splitDim}, {numChunks});            \
    checkNumericalEquivalence(0);                                              \
  }
// Only H and W are exercised here; the padded network has N = C = 1.
TEST_CONV2D_NONZEROPAD_SPLIT(H, 2)
TEST_CONV2D_NONZEROPAD_SPLIT(H, 3)
TEST_CONV2D_NONZEROPAD_SPLIT(W, 2)
TEST_CONV2D_NONZEROPAD_SPLIT(W, 3)
#undef TEST_CONV2D_NONZEROPAD_SPLIT
308
/// Test splitting a Conv2D with padding along dimensions H, W (3 x 3 = 9
/// chunks).
TEST_F(NodeSplitting, Conv2D_NonZeroPad_DimHW_Chunks9) {
  splitConv2DNonZeroPad(F_, optimizedF_, bindings_, cctx_,
                        {ShapeNHWC::DimH, ShapeNHWC::DimW}, {3, 3});
  checkNumericalEquivalence(0);
}
315
316/// Utility function to test splitting a group Conv2D node along dimension C in
317/// \p numChunks having the given number of \p inputChannels, \p outputChannels
318/// and the given \p group. The split is done implicitly relative to the Conv2D
319/// output operand.
320static void splitConv2DGrouped(Function *F, Function *&optF,
321 PlaceholderBindings &bindings,
322 CompilationContext &cctx, dim_t inputChannels,
323 dim_t outputChannels, dim_t group,
324 dim_t numChunks) {
325 dim_t filterChannels = inputChannels / group;
326 dim_t filterNum = outputChannels;
327 Node *node = createConv2D(F, bindings,
328 /* inputDims */ {1, 2, 2, inputChannels},
329 /* filterDims */ {filterNum, 2, 2, filterChannels},
330 /* biasDims */ {outputChannels},
331 /* outputDims */ {1, 1, 1, outputChannels},
332 /* kernels */ {2, 2},
333 /* strides */ {1, 1},
334 /* pads */ {0, 0, 0, 0},
335 /* group */ group,
336 /* dilation */ {1, 1});
337
338 // Save current function state as reference.
339 optF = F->clone(F->getName().str() + "_optimized");
340
341 // Split node.
342 auto splitOption = SplitNodeByNumChunks({ShapeNHWC::DimC}, {numChunks});
343 std::vector<Node *> splitNodes;
344 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
345
346 // Check node count.
347 EXPECT_EQ(splitNodes.size(), numChunks);
348 EXPECT_EQ(countNodeKind(F, Kinded::Kind::ConvolutionNodeKind), numChunks);
349 EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), numChunks);
350 EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
351}
352
/// Test splitting a grouped Conv2D along dimension C.
#define TEST_CONV2D_GROUP_SPLIT(IC, OC, G, chunks)                             \
  TEST_F(NodeSplitting,                                                        \
         Conv2D_Group_DimC_InpC##IC##_OutC##OC##_Group##G##_Chunks##chunks) {  \
    splitConv2DGrouped(F_, optimizedF_, bindings_, cctx_, IC, OC, G, chunks);  \
    checkNumericalEquivalence(0);                                              \
  }
// Sweep input channels 8, output channels 8/16, group 2/4/8, chunks 2/4/8.
TEST_CONV2D_GROUP_SPLIT(8, 8, 2, 2)
TEST_CONV2D_GROUP_SPLIT(8, 8, 2, 4)
TEST_CONV2D_GROUP_SPLIT(8, 8, 2, 8)
TEST_CONV2D_GROUP_SPLIT(8, 8, 4, 2)
TEST_CONV2D_GROUP_SPLIT(8, 8, 4, 4)
TEST_CONV2D_GROUP_SPLIT(8, 8, 4, 8)
TEST_CONV2D_GROUP_SPLIT(8, 8, 8, 2)
TEST_CONV2D_GROUP_SPLIT(8, 8, 8, 4)
TEST_CONV2D_GROUP_SPLIT(8, 8, 8, 8)
TEST_CONV2D_GROUP_SPLIT(8, 16, 2, 2)
TEST_CONV2D_GROUP_SPLIT(8, 16, 2, 4)
TEST_CONV2D_GROUP_SPLIT(8, 16, 2, 8)
TEST_CONV2D_GROUP_SPLIT(8, 16, 4, 2)
TEST_CONV2D_GROUP_SPLIT(8, 16, 4, 4)
TEST_CONV2D_GROUP_SPLIT(8, 16, 4, 8)
TEST_CONV2D_GROUP_SPLIT(8, 16, 8, 2)
TEST_CONV2D_GROUP_SPLIT(8, 16, 8, 4)
TEST_CONV2D_GROUP_SPLIT(8, 16, 8, 8)
#undef TEST_CONV2D_GROUP_SPLIT
379
380/// Test splitting an "ill-defined" Conv2D for which not all the input
381/// (including padding) is referenced by the output tensor. This happens
382/// when using a stride larger than 1. This verifies that the node
383/// splitting infrastructure uses a weaker verification of the mapping
384/// between input and output for Conv2D.
TEST_F(NodeSplitting, Conv2D_IllDefined_DimHW) {
  std::vector<size_t> splitDims = {ShapeNHWC::DimH, ShapeNHWC::DimW};
  std::vector<dim_t> numChunks = {3, 3};
  // Stride 2 means some input positions are never referenced by any output
  // position, which makes the input/output mapping "ill-defined".
  Node *node = createConv2D(F_, bindings_,
                            /* inputDims */ {1, 16, 18, 1},
                            /* filterDims */ {1, 2, 2, 1},
                            /* biasDims */ {1},
                            /* outputDims */ {1, 8, 9, 1},
                            /* kernels */ {2, 2},
                            /* strides */ {2, 2},
                            /* pads */ {1, 1, 0, 0},
                            /* group */ 1,
                            /* dilation */ {1, 1});

  // Save current function state as reference.
  optimizedF_ = F_->clone(F_->getName().str() + "_optimized");

  // Split node; this must succeed despite the weaker input/output mapping.
  auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
  std::vector<Node *> splitNodes;
  ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));

  // Check node count: one Conv2D and one InsertTensor per chunk plus a single
  // Touch node.
  dim_t totNumChunks = numChunks[0] * numChunks[1];
  EXPECT_EQ(splitNodes.size(), totNumChunks);
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConvolutionNodeKind), totNumChunks);
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind),
            totNumChunks);
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 1);
  checkNumericalEquivalence(0);
}
416
417/// Test splitting a Conv2D based on memory constraint.
TEST_F(NodeSplitting, Conv2D_MaxMem) {
  Node *node = createConv2D(F_, bindings_,
                            /* inputDims */ {5, 7, 8, 2},
                            /* filterDims */ {8, 2, 2, 1},
                            /* biasDims */ {8},
                            /* outputDims */ {5, 6, 7, 8},
                            /* kernels */ {2, 2},
                            /* strides */ {1, 1},
                            /* pads */ {0, 0, 0, 0},
                            /* group */ 2,
                            /* dilation */ {1, 1});

  // Save current function state as reference.
  optimizedF_ = F_->clone(F_->getName().str() + "_optimized");

  // Split node by memory size: require each chunk to use at most half the
  // memory of the original node, which forces at least one split.
  auto origMemSize = node->getTotMemSize();
  auto splitMaxMemSize = origMemSize / 2;
  std::vector<Node *> splitNodes;
  ASSIGN_VALUE_OR_FAIL_TEST(
      splitNodes,
      ::glow::splitNode(node, SplitNodeMaxMemConstraint(splitMaxMemSize)));

  // Check node count.
  auto totNumChunks = countNodeKind(F_, Kinded::Kind::ConvolutionNodeKind);
  EXPECT_TRUE(totNumChunks > 1);
  EXPECT_EQ(splitNodes.size(), totNumChunks);
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind),
            totNumChunks);
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 1);

  // Every split chunk must individually satisfy the memory constraint.
  for (auto *splitNode : splitNodes) {
    EXPECT_TRUE(splitNode->getTotMemSize() <= splitMaxMemSize);
  }
  checkNumericalEquivalence(0);
}
455
456/// Test splitting a Conv2D based on an impossible constraint forcing the split
457/// procedure to go through all the split configurations while verifying them.
458/// In the end no split should be performed.
TEST_F(NodeSplitting, Conv2D_NoSplit) {
  Node *node = createConv2D(F_, bindings_,
                            /* inputDims */ {5, 7, 8, 2},
                            /* filterDims */ {8, 2, 2, 1},
                            /* biasDims */ {8},
                            /* outputDims */ {5, 6, 7, 8},
                            /* kernels */ {2, 2},
                            /* strides */ {1, 1},
                            /* pads */ {0, 0, 0, 0},
                            /* group */ 2,
                            /* dilation */ {1, 1});

  // Save current function state as reference.
  optimizedF_ = F_->clone(F_->getName().str() + "_optimized");

  // Split node by memory size 0: no chunk can ever satisfy this, so the
  // splitting procedure must leave the function untouched.
  std::vector<Node *> splitNodes;
  ASSIGN_VALUE_OR_FAIL_TEST(
      splitNodes, ::glow::splitNode(node, SplitNodeMaxMemConstraint(0)));

  // Check node count: the original Conv2D is intact and no split artifacts
  // (Slice/InsertTensor/Touch) were created.
  EXPECT_EQ(splitNodes.size(), 0);
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 0);
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConvolutionNodeKind), 1);
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind), 0);
  EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 0);
  checkNumericalEquivalence(0);
}
487
488///===---------------------------------------------------------------------===//
489/// ChannelwiseQuantizedConv2D
490///===---------------------------------------------------------------------===//
491/// Utility function to create a simple network with a CWQConv2D node using
492/// the function \p F and the bindings \p bindings.
493static Node *createCWQConv2D(
494 Function *F, PlaceholderBindings &bindings, llvm::ArrayRef<dim_t> inputDims,
495 llvm::ArrayRef<dim_t> filterDims, llvm::ArrayRef<dim_t> biasDims,
496 llvm::ArrayRef<dim_t> outputDims, llvm::ArrayRef<unsigned_t> kernels,
497 llvm::ArrayRef<unsigned_t> strides, llvm::ArrayRef<unsigned_t> pads,
498 dim_t group, llvm::ArrayRef<unsigned_t> dilation) {
499 // Create quantized input placeholder.
500 auto &mod = *(F->getParent());
501 auto *inputQ = mod.createPlaceholder(ElemKind::Int8QTy, inputDims, 1.0, 0,
502 "inputQ", false);
503 bindings.allocate(inputQ)->getHandle<int8_t>().randomize(-128, 127,
504 mod.getPRNG());
505 // Create float filter constant.
506 auto *filterF = mod.createConstant(ElemKind::FloatTy, filterDims, "filterF");
507 filterF->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0,
508 mod.getPRNG());
509 // Create float bias constant.
510 auto *biasF = mod.createConstant(ElemKind::FloatTy, biasDims, "biasF");
511 biasF->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0,
512 mod.getPRNG());
513 // Create ChannelwiseQuantizedConv2D.
514 auto *outTy = mod.uniqueType(ElemKind::Int8QTy, outputDims, 1.0, 0);
515 auto *conv = F->createChannelwiseQuantizedConv(
516 "cwqconv", inputQ, filterF, biasF,
517 /* filterScales */ nullptr, /* filterOffsets */ nullptr,
518 /* biasScales */ nullptr, /* biasOffsets */ nullptr, outTy, kernels,
519 strides, pads, group, dilation,
520 /* quantizeFilter */ true, /* quantizeBias */ true);
521 SaveNode *save = F->createSave("save", conv);
522 bindings.allocate(save->getPlaceholder());
523 return conv;
524}
525
526/// Utility function to test splitting a basic CWQConv2D node along the
527/// dimensions \p splitDims in the given number chunks \p numChunks. The split
528/// is done implicitly relative to the Conv2D output operand.
529static void splitCWQConv2DBasic(Function *F, Function *&optF,
530 PlaceholderBindings &bindings,
531 CompilationContext &cctx,
532 llvm::ArrayRef<size_t> splitDims,
533 llvm::ArrayRef<dim_t> numChunks) {
534 Node *node = createCWQConv2D(F, bindings,
535 /* inputDims */ {5, 7, 8, 2},
536 /* filterDims */ {8, 2, 2, 1},
537 /* biasDims */ {8},
538 /* outputDims */ {5, 6, 7, 8},
539 /* kernels */ {2, 2},
540 /* strides */ {1, 1},
541 /* pads */ {0, 0, 0, 0},
542 /* group */ 2,
543 /* dilation */ {1, 1});
544
545 // Save current function state as reference.
546 optF = F->clone(F->getName().str() + "_optimized");
547
548 // Split node.
549 auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
550 std::vector<Node *> splitNodes;
551 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
552
553 // Compute total number of chunks.
554 dim_t totNumChunks = 1;
555 for (auto numChunk : numChunks) {
556 totNumChunks *= numChunk;
557 }
558
559 // Check node count.
560 EXPECT_EQ(splitNodes.size(), totNumChunks);
561 EXPECT_EQ(
562 countNodeKind(F, Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind),
563 totNumChunks);
564 EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), totNumChunks);
565 EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
566}
567
568/// Test splitting a CWQConv2D along dimension N, H, W or C.
569/// Not all the combinations are allowed when splitting along C.
#define TEST_CWQCONV2D_BASIC_SPLIT(splitDim, numChunks)                        \
  TEST_F(NodeSplitting, CWQConv2D_Basic_Dim##splitDim##_Chunks##numChunks) {   \
    splitCWQConv2DBasic(F_, optimizedF_, bindings_, cctx_,                     \
                        {ShapeNHWC::Dim##splitDim}, {numChunks});              \
    checkNumericalEquivalence(0);                                              \
  }
// Same single-dimension sweep as the float Conv2D tests (output {5, 6, 7, 8}).
// NOTE(review): C uses only 2/4/8 chunks — presumably constrained by the
// grouped convolution (group = 2); confirm against the splitting rules.
TEST_CWQCONV2D_BASIC_SPLIT(N, 2)
TEST_CWQCONV2D_BASIC_SPLIT(N, 3)
TEST_CWQCONV2D_BASIC_SPLIT(N, 4)
TEST_CWQCONV2D_BASIC_SPLIT(N, 5)
TEST_CWQCONV2D_BASIC_SPLIT(H, 2)
TEST_CWQCONV2D_BASIC_SPLIT(H, 3)
TEST_CWQCONV2D_BASIC_SPLIT(H, 4)
TEST_CWQCONV2D_BASIC_SPLIT(H, 5)
TEST_CWQCONV2D_BASIC_SPLIT(H, 6)
TEST_CWQCONV2D_BASIC_SPLIT(W, 2)
TEST_CWQCONV2D_BASIC_SPLIT(W, 3)
TEST_CWQCONV2D_BASIC_SPLIT(W, 4)
TEST_CWQCONV2D_BASIC_SPLIT(W, 5)
TEST_CWQCONV2D_BASIC_SPLIT(W, 6)
TEST_CWQCONV2D_BASIC_SPLIT(W, 7)
TEST_CWQCONV2D_BASIC_SPLIT(C, 2)
TEST_CWQCONV2D_BASIC_SPLIT(C, 4)
TEST_CWQCONV2D_BASIC_SPLIT(C, 8)
#undef TEST_CWQCONV2D_BASIC_SPLIT
595
/// Test splitting a CWQConv2D along dimensions N, H (2 x 2 = 4 chunks).
TEST_F(NodeSplitting, CWQConv2D_Basic_DimNH_Chunks4) {
  splitCWQConv2DBasic(F_, optimizedF_, bindings_, cctx_,
                      {ShapeNHWC::DimN, ShapeNHWC::DimH}, {2, 2});
  checkNumericalEquivalence(0);
}

/// Test splitting a CWQConv2D along dimensions N, H, W (2 x 2 x 2 = 8 chunks).
TEST_F(NodeSplitting, CWQConv2D_Basic_DimNHW_Chunks8) {
  splitCWQConv2DBasic(F_, optimizedF_, bindings_, cctx_,
                      {ShapeNHWC::DimN, ShapeNHWC::DimH, ShapeNHWC::DimW},
                      {2, 2, 2});
  checkNumericalEquivalence(0);
}

/// Test splitting a CWQConv2D along all dimensions N, H, W, C (16 chunks).
TEST_F(NodeSplitting, CWQConv2D_Basic_DimNHWC_Chunks16) {
  splitCWQConv2DBasic(
      F_, optimizedF_, bindings_, cctx_,
      {ShapeNHWC::DimN, ShapeNHWC::DimH, ShapeNHWC::DimW, ShapeNHWC::DimC},
      {2, 2, 2, 2});
  checkNumericalEquivalence(0);
}
619
620/// Utility function to test splitting a CWQConv2D node with non-zero padding
621/// along the dimensions \p splitDims in the given number chunks \p numChunks.
622/// The split is done implicitly relative to the Conv2D output operand.
static void splitCWQConv2DNonZeroPad(Function *F, Function *&optF,
                                     PlaceholderBindings &bindings,
                                     CompilationContext &cctx,
                                     llvm::ArrayRef<size_t> splitDims,
                                     llvm::ArrayRef<dim_t> numChunks) {
  // CWQConv2D with asymmetric non-zero padding and dilation.
  Node *node = createCWQConv2D(F, bindings,
                               /* inputDims */ {1, 8, 9, 1},
                               /* filterDims */ {1, 2, 3, 1},
                               /* biasDims */ {1},
                               /* outputDims */ {1, 11, 10, 1},
                               /* kernels */ {2, 3},
                               /* strides */ {1, 1},
                               /* pads */ {2, 1, 3, 4},
                               /* group */ 1,
                               /* dilation */ {2, 2});

  // Save current function state as reference.
  optF = F->clone(F->getName().str() + "_optimized");

  // Split node.
  auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
  std::vector<Node *> splitNodes;
  ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));

  // Compute total number of chunks as the product of per-dimension counts.
  dim_t totNumChunks = 1;
  for (auto numChunk : numChunks) {
    totNumChunks *= numChunk;
  }

  // Check node count: one CWQConv2D and one InsertTensor per chunk plus a
  // single Touch node.
  EXPECT_EQ(splitNodes.size(), totNumChunks);
  EXPECT_EQ(
      countNodeKind(F, Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind),
      totNumChunks);
  EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), totNumChunks);
  EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
}
661
662/// Test splitting a CWQConv2D with padding along dimension N, H, W or C.
#define TEST_CWQCONV2D_NONZEROPAD_SPLIT(splitDim, numChunks)                   \
  TEST_F(NodeSplitting,                                                        \
         CWQConv2D_NonZeroPad_Dim##splitDim##_Chunks##numChunks) {             \
    splitCWQConv2DNonZeroPad(F_, optimizedF_, bindings_, cctx_,                \
                             {ShapeNHWC::Dim##splitDim}, {numChunks});         \
    checkNumericalEquivalence(0);                                              \
  }
// Only H and W are exercised here; the padded network has N = C = 1.
TEST_CWQCONV2D_NONZEROPAD_SPLIT(H, 2)
TEST_CWQCONV2D_NONZEROPAD_SPLIT(H, 3)
TEST_CWQCONV2D_NONZEROPAD_SPLIT(W, 2)
TEST_CWQCONV2D_NONZEROPAD_SPLIT(W, 3)
#undef TEST_CWQCONV2D_NONZEROPAD_SPLIT
675
/// Test splitting a CWQConv2D with padding along dimensions H, W (3 x 3 = 9
/// chunks).
TEST_F(NodeSplitting, CWQConv2D_NonZeroPad_DimHW_Chunks9) {
  splitCWQConv2DNonZeroPad(F_, optimizedF_, bindings_, cctx_,
                           {ShapeNHWC::DimH, ShapeNHWC::DimW}, {3, 3});
  checkNumericalEquivalence(0);
}
682
683/// Utility function to test splitting a group CWQConv2D node along dimension C
684/// in \p numChunks having the given number of \p inputChannels,
685/// \p outputChannels and the given \p group. The split is done implicitly
686/// relative to the Conv2D output operand.
static void splitCWQConv2DGrouped(Function *F, Function *&optF,
                                  PlaceholderBindings &bindings,
                                  CompilationContext &cctx, dim_t inputChannels,
                                  dim_t outputChannels, dim_t group,
                                  dim_t numChunks) {
  // Each filter convolves inputChannels / group channels of the input.
  dim_t filterChannels = inputChannels / group;
  dim_t filterNum = outputChannels;
  Node *node =
      createCWQConv2D(F, bindings,
                      /* inputDims */ {1, 2, 2, inputChannels},
                      /* filterDims */ {filterNum, 2, 2, filterChannels},
                      /* biasDims */ {outputChannels},
                      /* outputDims */ {1, 1, 1, outputChannels},
                      /* kernels */ {2, 2},
                      /* strides */ {1, 1},
                      /* pads */ {0, 0, 0, 0},
                      /* group */ group,
                      /* dilation */ {1, 1});

  // Save current function state as reference.
  optF = F->clone(F->getName().str() + "_optimized");

  // Split node along the channel dimension only.
  auto splitOption = SplitNodeByNumChunks({ShapeNHWC::DimC}, {numChunks});
  std::vector<Node *> splitNodes;
  ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));

  // Check node count: one CWQConv2D and one InsertTensor per chunk plus a
  // single Touch node.
  EXPECT_EQ(splitNodes.size(), numChunks);
  EXPECT_EQ(
      countNodeKind(F, Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind),
      numChunks);
  EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), numChunks);
  EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
}
722
/// Test splitting a grouped CWQConv2D along dimension C.
#define TEST_CWQCONV2D_GROUP_SPLIT(IC, OC, G, chunks)                          \
  TEST_F(                                                                      \
      NodeSplitting,                                                           \
      CWQConv2D_Group_DimC_InpC##IC##_OutC##OC##_Group##G##_Chunks##chunks) {  \
    splitCWQConv2DGrouped(F_, optimizedF_, bindings_, cctx_, IC, OC, G,        \
                          chunks);                                             \
    checkNumericalEquivalence(0);                                              \
  }
// Sweep input channels 8, output channels 8/16, group 2/4/8, chunks 2/4/8.
TEST_CWQCONV2D_GROUP_SPLIT(8, 8, 2, 2)
TEST_CWQCONV2D_GROUP_SPLIT(8, 8, 2, 4)
TEST_CWQCONV2D_GROUP_SPLIT(8, 8, 2, 8)
TEST_CWQCONV2D_GROUP_SPLIT(8, 8, 4, 2)
TEST_CWQCONV2D_GROUP_SPLIT(8, 8, 4, 4)
TEST_CWQCONV2D_GROUP_SPLIT(8, 8, 4, 8)
TEST_CWQCONV2D_GROUP_SPLIT(8, 8, 8, 2)
TEST_CWQCONV2D_GROUP_SPLIT(8, 8, 8, 4)
TEST_CWQCONV2D_GROUP_SPLIT(8, 8, 8, 8)
TEST_CWQCONV2D_GROUP_SPLIT(8, 16, 2, 2)
TEST_CWQCONV2D_GROUP_SPLIT(8, 16, 2, 4)
TEST_CWQCONV2D_GROUP_SPLIT(8, 16, 2, 8)
TEST_CWQCONV2D_GROUP_SPLIT(8, 16, 4, 2)
TEST_CWQCONV2D_GROUP_SPLIT(8, 16, 4, 4)
TEST_CWQCONV2D_GROUP_SPLIT(8, 16, 4, 8)
TEST_CWQCONV2D_GROUP_SPLIT(8, 16, 8, 2)
TEST_CWQCONV2D_GROUP_SPLIT(8, 16, 8, 4)
TEST_CWQCONV2D_GROUP_SPLIT(8, 16, 8, 8)
#undef TEST_CWQCONV2D_GROUP_SPLIT
751
752///===---------------------------------------------------------------------===//
753/// MaxPool
754///===---------------------------------------------------------------------===//
755/// Utility function to create a simple network with a single MaxPool node using
756/// the function \p F and the bindings \p bindings.
757static Node *createMaxPool(Function *F, PlaceholderBindings &bindings,
758 llvm::ArrayRef<dim_t> inputDims,
759 llvm::ArrayRef<dim_t> outputDims,
760 llvm::ArrayRef<unsigned_t> kernels,
761 llvm::ArrayRef<unsigned_t> strides,
762 llvm::ArrayRef<unsigned_t> pads) {
763 // Create input placeholder.
764 auto &mod = *(F->getParent());
765 auto *input =
766 mod.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
767 bindings.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
768 mod.getPRNG());
769 // Create MaxPool.
770 MaxPoolNode *maxpool =
771 F->createMaxPool("maxpool", input, kernels, strides, pads);
772 SaveNode *save = F->createSave("save", maxpool->getResult());
773 bindings.allocate(save->getPlaceholder());
774 EXPECT_EQ(maxpool->getResult().getType()->dims(), outputDims);
775 return maxpool;
776}
777
778/// Utility function to test splitting a basic MaxPool node along the dimensions
779/// \p splitDims in the given number chunks \p numChunks. The split is done
780/// implicitly relative to the MaxPool output operand.
781static void splitMaxPoolBasic(Function *F, Function *&optF,
782 PlaceholderBindings &bindings,
783 CompilationContext &cctx,
784 llvm::ArrayRef<size_t> splitDims,
785 llvm::ArrayRef<dim_t> numChunks) {
786 Node *node = createMaxPool(F, bindings,
787 /* inputDims */ {3, 7, 8, 4},
788 /* outputDims */ {3, 6, 7, 4},
789 /* kernels */ {2, 2},
790 /* strides */ {1, 1},
791 /* pads */ {0, 0, 0, 0});
792
793 // Save current function state as reference.
794 optF = F->clone(F->getName().str() + "_optimized");
795
796 // Split node.
797 auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
798 std::vector<Node *> splitNodes;
799 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
800
801 // Compute total number of chunks.
802 dim_t totNumChunks = 1;
803 for (auto numChunk : numChunks) {
804 totNumChunks *= numChunk;
805 }
806
807 // Check node count.
808 EXPECT_EQ(splitNodes.size(), totNumChunks);
809 EXPECT_EQ(countNodeKind(F, Kinded::Kind::MaxPoolNodeKind), totNumChunks);
810 EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), totNumChunks);
811 EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
812}
813
/// Test splitting a MaxPool along dimension N, H, W or C.
/// Each instantiation below splits the MaxPool output into numChunks chunks
/// along the single dimension splitDim and checks numerical equivalence
/// against the unsplit reference function.
#define TEST_MAXPOOL_BASIC_SPLIT(splitDim, numChunks)                          \
  TEST_F(NodeSplitting, MaxPool_Basic_Dim##splitDim##_Chunks##numChunks) {     \
    splitMaxPoolBasic(F_, optimizedF_, bindings_, cctx_,                       \
                      {ShapeNHWC::Dim##splitDim}, {numChunks});                \
    checkNumericalEquivalence(0);                                              \
  }
TEST_MAXPOOL_BASIC_SPLIT(N, 2)
TEST_MAXPOOL_BASIC_SPLIT(N, 3)
TEST_MAXPOOL_BASIC_SPLIT(H, 2)
TEST_MAXPOOL_BASIC_SPLIT(H, 3)
TEST_MAXPOOL_BASIC_SPLIT(H, 4)
TEST_MAXPOOL_BASIC_SPLIT(H, 5)
TEST_MAXPOOL_BASIC_SPLIT(H, 6)
TEST_MAXPOOL_BASIC_SPLIT(W, 2)
TEST_MAXPOOL_BASIC_SPLIT(W, 3)
TEST_MAXPOOL_BASIC_SPLIT(W, 4)
TEST_MAXPOOL_BASIC_SPLIT(W, 5)
TEST_MAXPOOL_BASIC_SPLIT(W, 6)
TEST_MAXPOOL_BASIC_SPLIT(W, 7)
TEST_MAXPOOL_BASIC_SPLIT(C, 2)
TEST_MAXPOOL_BASIC_SPLIT(C, 3)
TEST_MAXPOOL_BASIC_SPLIT(C, 4)
#undef TEST_MAXPOOL_BASIC_SPLIT
838
839/// Utility function to test splitting a MaxPool node with non-zero padding
840/// along the dimensions \p splitDims in the given number chunks \p numChunks.
841/// The split is done implicitly relative to the MaxPool output operand.
842static void splitMaxPoolNonZeroPad(Function *F, Function *&optF,
843 PlaceholderBindings &bindings,
844 CompilationContext &cctx,
845 llvm::ArrayRef<size_t> splitDims,
846 llvm::ArrayRef<dim_t> numChunks) {
847 Node *node = createMaxPool(F, bindings,
848 /* inputDims */ {1, 4, 4, 1},
849 /* outputDims */ {1, 4, 8, 1},
850 /* kernels */ {2, 2},
851 /* strides */ {1, 1},
852 /* pads */ {0, 2, 1, 3});
853
854 // Save current function state as reference.
855 optF = F->clone(F->getName().str() + "_optimized");
856
857 // Split node.
858 auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
859 std::vector<Node *> splitNodes;
860 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
861
862 // Compute total number of chunks.
863 dim_t totNumChunks = 1;
864 for (auto numChunk : numChunks) {
865 totNumChunks *= numChunk;
866 }
867
868 // Check node count.
869 EXPECT_EQ(splitNodes.size(), totNumChunks);
870 EXPECT_EQ(countNodeKind(F, Kinded::Kind::MaxPoolNodeKind), totNumChunks);
871 EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), totNumChunks);
872 EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
873}
874
/// Test splitting a MaxPool with padding along dimension N, H, W or C.
/// Each instantiation splits the padded MaxPool output into numChunks chunks
/// along splitDim and checks numerical equivalence against the reference.
#define TEST_MAXPOOL_NONZEROPAD_SPLIT(splitDim, numChunks)                     \
  TEST_F(NodeSplitting,                                                        \
         MaxPool_NonZeroPad_Dim##splitDim##_Chunks##numChunks) {               \
    splitMaxPoolNonZeroPad(F_, optimizedF_, bindings_, cctx_,                  \
                           {ShapeNHWC::Dim##splitDim}, {numChunks});           \
    checkNumericalEquivalence(0);                                              \
  }
TEST_MAXPOOL_NONZEROPAD_SPLIT(H, 2)
TEST_MAXPOOL_NONZEROPAD_SPLIT(W, 2)
#undef TEST_MAXPOOL_NONZEROPAD_SPLIT
886
887/// Test splitting a MaxPool with padding along dimensions H, W.
888TEST_F(NodeSplitting, MaxPool_NonZeroPad_DimHW_Chunks4) {
889 splitMaxPoolNonZeroPad(F_, optimizedF_, bindings_, cctx_,
890 {ShapeNHWC::DimH, ShapeNHWC::DimW}, {2, 2});
891 checkNumericalEquivalence(0);
892}
893
894/// Test splitting an "ill-defined" MaxPool for which not all the input
895/// (including padding) is referenced by the output tensor. This happens
896/// when using a stride larger than 1. This verifies that the node
897/// splitting infrastructure uses a weaker verification of the mapping
898/// between input and output for MaxPool.
899TEST_F(NodeSplitting, MaxPool_IllDefined_DimHW) {
900 std::vector<size_t> splitDims = {ShapeNHWC::DimH, ShapeNHWC::DimW};
901 std::vector<dim_t> numChunks = {3, 3};
902 Node *node = createMaxPool(F_, bindings_,
903 /* inputDims */ {1, 16, 18, 1},
904 /* outputDims */ {1, 8, 9, 1},
905 /* kernels */ {2, 2},
906 /* strides */ {2, 2},
907 /* pads */ {1, 1, 0, 0});
908
909 // Save current function state as reference.
910 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
911
912 // Split node.
913 auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
914 std::vector<Node *> splitNodes;
915 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
916
917 // Check node count.
918 dim_t totNumChunks = numChunks[0] * numChunks[1];
919 EXPECT_EQ(splitNodes.size(), totNumChunks);
920 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MaxPoolNodeKind), totNumChunks);
921 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind),
922 totNumChunks);
923 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 1);
924 checkNumericalEquivalence(0);
925}
926
927/// Test splitting a MaxPool based on memory constraint.
928TEST_F(NodeSplitting, MaxPool_MaxMem) {
929 Node *node = createMaxPool(F_, bindings_,
930 /* inputDims */ {3, 7, 8, 4},
931 /* outputDims */ {3, 6, 7, 4},
932 /* kernels */ {2, 2},
933 /* strides */ {1, 1},
934 /* pads */ {0, 0, 0, 0});
935
936 // Save current function state as reference.
937 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
938
939 // Split node by memory size.
940 auto origMemSize = node->getTotMemSize();
941 auto splitMaxMemSize = origMemSize / 2;
942 std::vector<Node *> splitNodes;
943 ASSIGN_VALUE_OR_FAIL_TEST(
944 splitNodes,
945 ::glow::splitNode(node, SplitNodeMaxMemConstraint(splitMaxMemSize)));
946
947 // Check node count.
948 auto totNumChunks = countNodeKind(F_, Kinded::Kind::MaxPoolNodeKind);
949 EXPECT_EQ(splitNodes.size(), totNumChunks);
950 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind),
951 totNumChunks);
952 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 1);
953
954 // Check split nodes memory sizes.
955 for (auto *splitNode : splitNodes) {
956 EXPECT_TRUE(splitNode->getTotMemSize() <= splitMaxMemSize);
957 }
958 checkNumericalEquivalence(0);
959}
960
961/// Test that a MaxPool node is not split when the second output operand
962/// Argmax has users.
963TEST_F(NodeSplitting, MaxPool_Argmax_NoSplit) {
964 std::vector<dim_t> inputDims = {1, 16, 18, 1};
965 std::vector<dim_t> outputDims = {1, 8, 9, 1};
966 std::vector<unsigned_t> kernels = {2, 2};
967 std::vector<unsigned_t> strides = {2, 2};
968 std::vector<unsigned_t> pads = {1, 1, 0, 0};
969 auto *input =
970 mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
971 bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
972 mod_.getPRNG());
973 MaxPoolNode *maxpool =
974 F_->createMaxPool("maxpool", input, kernels, strides, pads);
975 SaveNode *saveResult = F_->createSave("saveResult", maxpool->getResult());
976 bindings_.allocate(saveResult->getPlaceholder());
977 SaveNode *saveArgmax = F_->createSave("saveArgmax", maxpool->getArgmax());
978 bindings_.allocate(saveArgmax->getPlaceholder());
979 std::vector<dim_t> actualOutputDims = maxpool->getResult().getType()->dims();
980 EXPECT_EQ(actualOutputDims, outputDims);
981
982 // Save current function state as reference.
983 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
984
985 // Split node.
986 auto splitOption = SplitNodeByNumChunks({ShapeNHWC::DimH}, {3});
987 std::vector<Node *> splitNodes;
988 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes,
989 ::glow::splitNode(maxpool, splitOption));
990
991 // Check node count.
992 EXPECT_EQ(splitNodes.size(), 0);
993 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 0);
994 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MaxPoolNodeKind), 1);
995 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind), 0);
996 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 0);
997 checkNumericalEquivalence(0);
998}
999
1000/// Test splitting a MaxPool based on an impossible constraint forcing the
1001/// split procedure to go through all the split configurations while verifying
1002/// them. In the end no split should be performed.
1003TEST_F(NodeSplitting, MaxPool_NoSplit) {
1004 Node *node = createMaxPool(F_, bindings_,
1005 /* inputDims */ {3, 7, 8, 4},
1006 /* outputDims */ {3, 6, 7, 4},
1007 /* kernels */ {2, 2},
1008 /* strides */ {1, 1},
1009 /* pads */ {0, 0, 0, 0});
1010
1011 // Save current function state as reference.
1012 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
1013
1014 // Split node by memory size 0.
1015 std::vector<Node *> splitNodes;
1016 ASSIGN_VALUE_OR_FAIL_TEST(
1017 splitNodes, ::glow::splitNode(node, SplitNodeMaxMemConstraint(0)));
1018
1019 // Check node count.
1020 EXPECT_EQ(splitNodes.size(), 0);
1021 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 0);
1022 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MaxPoolNodeKind), 1);
1023 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind), 0);
1024 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 0);
1025 checkNumericalEquivalence(0);
1026}
1027
1028///===---------------------------------------------------------------------===//
1029/// AvgPool
1030///===---------------------------------------------------------------------===//
1031/// Utility function to create a simple network with a single AvgPool node using
1032/// the function \p F and the bindings \p bindings.
1033static Node *createAvgPool(Function *F, PlaceholderBindings &bindings,
1034 llvm::ArrayRef<dim_t> inputDims,
1035 llvm::ArrayRef<dim_t> outputDims,
1036 llvm::ArrayRef<unsigned_t> kernels,
1037 llvm::ArrayRef<unsigned_t> strides,
1038 llvm::ArrayRef<unsigned_t> pads) {
1039 // Create input placeholder.
1040 auto &mod = *(F->getParent());
1041 auto *input =
1042 mod.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
1043 bindings.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
1044 mod.getPRNG());
1045 // Create AvgPool.
1046 AvgPoolNode *avgpool =
1047 F->createAvgPool("avgpool", input, kernels, strides, pads);
1048 SaveNode *save = F->createSave("save", avgpool->getResult());
1049 bindings.allocate(save->getPlaceholder());
1050 EXPECT_EQ(avgpool->getResult().getType()->dims(), outputDims);
1051 return avgpool;
1052}
1053
1054/// Utility function to test splitting a basic AvgPool node along the dimensions
1055/// \p splitDims in the given number chunks \p numChunks. The split is done
1056/// implicitly relative to the AvgPool output operand.
1057static void splitAvgPoolBasic(Function *F, Function *&optF,
1058 PlaceholderBindings &bindings,
1059 CompilationContext &cctx,
1060 llvm::ArrayRef<size_t> splitDims,
1061 llvm::ArrayRef<dim_t> numChunks) {
1062 Node *node = createAvgPool(F, bindings,
1063 /* inputDims */ {3, 7, 8, 4},
1064 /* outputDims */ {3, 6, 7, 4},
1065 /* kernels */ {2, 2},
1066 /* strides */ {1, 1},
1067 /* pads */ {0, 0, 0, 0});
1068
1069 // Save current function state as reference.
1070 optF = F->clone(F->getName().str() + "_optimized");
1071
1072 // Split node.
1073 auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
1074 std::vector<Node *> splitNodes;
1075 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
1076
1077 // Compute total number of chunks.
1078 dim_t totNumChunks = 1;
1079 for (auto numChunk : numChunks) {
1080 totNumChunks *= numChunk;
1081 }
1082
1083 // Check node count.
1084 EXPECT_EQ(splitNodes.size(), totNumChunks);
1085 EXPECT_EQ(countNodeKind(F, Kinded::Kind::AvgPoolNodeKind), totNumChunks);
1086 EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), totNumChunks);
1087 EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
1088}
1089
/// Test splitting a AvgPool along dimension N, H, W or C.
/// Each instantiation below splits the AvgPool output into numChunks chunks
/// along the single dimension splitDim and checks numerical equivalence
/// against the unsplit reference function.
#define TEST_AVGPOOL_BASIC_SPLIT(splitDim, numChunks)                          \
  TEST_F(NodeSplitting, AvgPool_Basic_Dim##splitDim##_Chunks##numChunks) {     \
    splitAvgPoolBasic(F_, optimizedF_, bindings_, cctx_,                       \
                      {ShapeNHWC::Dim##splitDim}, {numChunks});                \
    checkNumericalEquivalence(0);                                              \
  }
TEST_AVGPOOL_BASIC_SPLIT(N, 2)
TEST_AVGPOOL_BASIC_SPLIT(N, 3)
TEST_AVGPOOL_BASIC_SPLIT(H, 2)
TEST_AVGPOOL_BASIC_SPLIT(H, 3)
TEST_AVGPOOL_BASIC_SPLIT(H, 4)
TEST_AVGPOOL_BASIC_SPLIT(H, 5)
TEST_AVGPOOL_BASIC_SPLIT(H, 6)
TEST_AVGPOOL_BASIC_SPLIT(W, 2)
TEST_AVGPOOL_BASIC_SPLIT(W, 3)
TEST_AVGPOOL_BASIC_SPLIT(W, 4)
TEST_AVGPOOL_BASIC_SPLIT(W, 5)
TEST_AVGPOOL_BASIC_SPLIT(W, 6)
TEST_AVGPOOL_BASIC_SPLIT(W, 7)
TEST_AVGPOOL_BASIC_SPLIT(C, 2)
TEST_AVGPOOL_BASIC_SPLIT(C, 3)
TEST_AVGPOOL_BASIC_SPLIT(C, 4)
#undef TEST_AVGPOOL_BASIC_SPLIT
1114
1115/// Utility function to test splitting a AvgPool node with non-zero padding
1116/// along the dimensions \p splitDims in the given number chunks \p numChunks.
1117/// The split is done implicitly relative to the AvgPool output operand.
1118static void splitAvgPoolNonZeroPad(Function *F, Function *&optF,
1119 PlaceholderBindings &bindings,
1120 CompilationContext &cctx,
1121 llvm::ArrayRef<size_t> splitDims,
1122 llvm::ArrayRef<dim_t> numChunks) {
1123 Node *node = createAvgPool(F, bindings,
1124 /* inputDims */ {1, 4, 4, 1},
1125 /* outputDims */ {1, 4, 8, 1},
1126 /* kernels */ {2, 2},
1127 /* strides */ {1, 1},
1128 /* pads */ {0, 2, 1, 3});
1129
1130 // Save current function state as reference.
1131 optF = F->clone(F->getName().str() + "_optimized");
1132
1133 // Split node.
1134 auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
1135 std::vector<Node *> splitNodes;
1136 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
1137
1138 // Compute total number of chunks.
1139 dim_t totNumChunks = 1;
1140 for (auto numChunk : numChunks) {
1141 totNumChunks *= numChunk;
1142 }
1143
1144 // Check node count.
1145 EXPECT_EQ(splitNodes.size(), totNumChunks);
1146 EXPECT_EQ(countNodeKind(F, Kinded::Kind::AvgPoolNodeKind), totNumChunks);
1147 EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), totNumChunks);
1148 EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
1149}
1150
/// Test splitting a AvgPool with padding along dimension N, H, W or C.
/// Each instantiation splits the padded AvgPool output into numChunks chunks
/// along splitDim and checks numerical equivalence against the reference.
#define TEST_AVGPOOL_NONZEROPAD_SPLIT(splitDim, numChunks)                     \
  TEST_F(NodeSplitting,                                                        \
         AvgPool_NonZeroPad_Dim##splitDim##_Chunks##numChunks) {               \
    splitAvgPoolNonZeroPad(F_, optimizedF_, bindings_, cctx_,                  \
                           {ShapeNHWC::Dim##splitDim}, {numChunks});           \
    checkNumericalEquivalence(0);                                              \
  }
TEST_AVGPOOL_NONZEROPAD_SPLIT(H, 2)
TEST_AVGPOOL_NONZEROPAD_SPLIT(W, 2)
#undef TEST_AVGPOOL_NONZEROPAD_SPLIT
1162
1163/// Test splitting a AvgPool with padding along dimensions H, W.
1164TEST_F(NodeSplitting, AvgPool_NonZeroPad_DimHW_Chunks4) {
1165 splitAvgPoolNonZeroPad(F_, optimizedF_, bindings_, cctx_,
1166 {ShapeNHWC::DimH, ShapeNHWC::DimW}, {2, 2});
1167 checkNumericalEquivalence(0);
1168}
1169
1170/// Test splitting an "ill-defined" AvgPool for which not all the input
1171/// (including padding) is referenced by the output tensor. This happens
1172/// when using a stride larger than 1. This verifies that the node
1173/// splitting infrastructure uses a weaker verification of the mapping
1174/// between input and output for AvgPool.
1175TEST_F(NodeSplitting, AvgPool_IllDefined_DimHW) {
1176 std::vector<size_t> splitDims = {ShapeNHWC::DimH, ShapeNHWC::DimW};
1177 std::vector<dim_t> numChunks = {3, 3};
1178 Node *node = createAvgPool(F_, bindings_,
1179 /* inputDims */ {1, 16, 18, 1},
1180 /* outputDims */ {1, 8, 9, 1},
1181 /* kernels */ {2, 2},
1182 /* strides */ {2, 2},
1183 /* pads */ {1, 1, 0, 0});
1184
1185 // Save current function state as reference.
1186 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
1187
1188 // Split node.
1189 auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
1190 std::vector<Node *> splitNodes;
1191 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
1192
1193 // Check node count.
1194 dim_t totNumChunks = numChunks[0] * numChunks[1];
1195 EXPECT_EQ(splitNodes.size(), totNumChunks);
1196 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AvgPoolNodeKind), totNumChunks);
1197 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind),
1198 totNumChunks);
1199 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 1);
1200 checkNumericalEquivalence(0);
1201}
1202
1203/// Test splitting a AvgPool based on memory constraint.
1204TEST_F(NodeSplitting, AvgPool_MaxMem) {
1205 Node *node = createAvgPool(F_, bindings_,
1206 /* inputDims */ {3, 7, 8, 4},
1207 /* outputDims */ {3, 6, 7, 4},
1208 /* kernels */ {2, 2},
1209 /* strides */ {1, 1},
1210 /* pads */ {0, 0, 0, 0});
1211
1212 // Save current function state as reference.
1213 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
1214
1215 // Split node by memory size.
1216 auto origMemSize = node->getTotMemSize();
1217 auto splitMaxMemSize = origMemSize / 2;
1218 std::vector<Node *> splitNodes;
1219 ASSIGN_VALUE_OR_FAIL_TEST(
1220 splitNodes,
1221 ::glow::splitNode(node, SplitNodeMaxMemConstraint(splitMaxMemSize)));
1222
1223 // Check node count.
1224 auto totNumChunks = countNodeKind(F_, Kinded::Kind::AvgPoolNodeKind);
1225 EXPECT_EQ(splitNodes.size(), totNumChunks);
1226 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind),
1227 totNumChunks);
1228 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 1);
1229
1230 // Check split nodes memory sizes.
1231 for (auto *splitNode : splitNodes) {
1232 EXPECT_TRUE(splitNode->getTotMemSize() <= splitMaxMemSize);
1233 }
1234 checkNumericalEquivalence(0);
1235}
1236
1237/// Test splitting a AvgPool based on an impossible constraint forcing the
1238/// split procedure to go through all the split configurations while verifying
1239/// them. In the end no split should be performed.
1240TEST_F(NodeSplitting, AvgPool_NoSplit) {
1241 Node *node = createAvgPool(F_, bindings_,
1242 /* inputDims */ {3, 7, 8, 4},
1243 /* outputDims */ {3, 6, 7, 4},
1244 /* kernels */ {2, 2},
1245 /* strides */ {1, 1},
1246 /* pads */ {0, 0, 0, 0});
1247
1248 // Save current function state as reference.
1249 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
1250
1251 // Split node by memory size 0.
1252 std::vector<Node *> splitNodes;
1253 ASSIGN_VALUE_OR_FAIL_TEST(
1254 splitNodes, ::glow::splitNode(node, SplitNodeMaxMemConstraint(0)));
1255
1256 // Check node count.
1257 EXPECT_EQ(splitNodes.size(), 0);
1258 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 0);
1259 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AvgPoolNodeKind), 1);
1260 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind), 0);
1261 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 0);
1262 checkNumericalEquivalence(0);
1263}
1264
1265///===---------------------------------------------------------------------===//
1266/// FullyConnected
1267///===---------------------------------------------------------------------===//
1268/// Utility function to test splitting a FullyConnected node.
1269static void splitFullyConnected(Function *F, Function *&optF,
1270 PlaceholderBindings &bindings,
1271 CompilationContext &cctx,
1272 llvm::ArrayRef<size_t> splitDims,
1273 llvm::ArrayRef<dim_t> numChunks) {
1274 std::vector<dim_t> inputDims = {10, 13};
1275 std::vector<dim_t> weightsDims = {13, 20};
1276 std::vector<dim_t> biasDims = {20};
1277 auto &mod = *(F->getParent());
1278 auto *input =
1279 mod.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
1280 bindings.allocate(input)->getHandle<float>().randomize(-10.0, 10.0,
1281 mod.getPRNG());
1282 auto *weights =
1283 mod.createPlaceholder(ElemKind::FloatTy, weightsDims, "weights", false);
1284 bindings.allocate(weights)->getHandle<float>().randomize(-10.0, 10.0,
1285 mod.getPRNG());
1286 auto *bias =
1287 mod.createPlaceholder(ElemKind::FloatTy, biasDims, "bias", false);
1288 bindings.allocate(bias)->getHandle<float>().randomize(-10.0, 10.0,
1289 mod.getPRNG());
1290 Node *node = F->createFullyConnected("fc", input, weights, bias);
1291 SaveNode *output = F->createSave("output", node);
1292 bindings.allocate(output->getPlaceholder());
1293
1294 // Save current function state as reference.
1295 optF = F->clone(F->getName().str() + "_optimized");
1296
1297 // Split node.
1298 auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
1299 std::vector<Node *> splitNodes;
1300 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
1301
1302 // Compute total number of chunks.
1303 dim_t totNumChunks = 1;
1304 for (auto numChunk : numChunks) {
1305 totNumChunks *= numChunk;
1306 }
1307
1308 // Check node count.
1309 EXPECT_EQ(splitNodes.size(), totNumChunks);
1310 EXPECT_EQ(countNodeKind(F, Kinded::Kind::FullyConnectedNodeKind),
1311 totNumChunks);
1312 EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), totNumChunks);
1313 EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
1314}
1315
1316/// Test splitting FullyConnected along dimension H.
1317TEST_F(NodeSplitting, FullyConnected_DimH_Chunks2) {
1318 splitFullyConnected(F_, optimizedF_, bindings_, cctx_, {ShapeHW::DimH}, {2});
1319 checkNumericalEquivalence(0);
1320}
1321
1322/// Test splitting FullyConnected along dimension W.
1323TEST_F(NodeSplitting, FullyConnected_DimW_Chunks2) {
1324 splitFullyConnected(F_, optimizedF_, bindings_, cctx_, {ShapeHW::DimW}, {2});
1325 checkNumericalEquivalence(0);
1326}
1327
1328/// Test splitting FullyConnected along dimension H and W.
1329TEST_F(NodeSplitting, FullyConnected_DimHW_Chunks4) {
1330 splitFullyConnected(F_, optimizedF_, bindings_, cctx_,
1331 {ShapeHW::DimH, ShapeHW::DimW}, {2, 2});
1332 checkNumericalEquivalence(0);
1333}
1334
1335///===---------------------------------------------------------------------===//
1336/// MatMul
1337///===---------------------------------------------------------------------===//
1338/// Utility function to test splitting a MatMul node.
1339static void splitMatMul(Function *F, Function *&optF,
1340 PlaceholderBindings &bindings, CompilationContext &cctx,
1341 llvm::ArrayRef<size_t> splitDims,
1342 llvm::ArrayRef<dim_t> numChunks) {
1343 std::vector<dim_t> dimsLHS = {10, 13};
1344 std::vector<dim_t> dimsRHS = {13, 20};
1345 auto &mod = *(F->getParent());
1346 auto *LHS = mod.createPlaceholder(ElemKind::FloatTy, dimsLHS, "LHS", false);
1347 bindings.allocate(LHS)->getHandle<float>().randomize(-10.0, 10.0,
1348 mod.getPRNG());
1349 auto *RHS = mod.createPlaceholder(ElemKind::FloatTy, dimsRHS, "RHS", false);
1350 bindings.allocate(RHS)->getHandle<float>().randomize(-10.0, 10.0,
1351 mod.getPRNG());
1352 Node *node = F->createMatMul("matmul", LHS, RHS);
1353 SaveNode *output = F->createSave("output", node);
1354 bindings.allocate(output->getPlaceholder());
1355
1356 // Save current function state as reference.
1357 optF = F->clone(F->getName().str() + "_optimized");
1358
1359 // Split node.
1360 auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
1361 std::vector<Node *> splitNodes;
1362 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
1363
1364 // Compute total number of chunks.
1365 dim_t totNumChunks = 1;
1366 for (auto numChunk : numChunks) {
1367 totNumChunks *= numChunk;
1368 }
1369
1370 // Check node count.
1371 EXPECT_EQ(splitNodes.size(), totNumChunks);
1372 EXPECT_EQ(countNodeKind(F, Kinded::Kind::MatMulNodeKind), totNumChunks);
1373 EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), totNumChunks);
1374 EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
1375}
1376
1377/// Test splitting MatMul along dimension H.
1378TEST_F(NodeSplitting, MatMul_DimH_Chunks2) {
1379 splitMatMul(F_, optimizedF_, bindings_, cctx_, {ShapeHW::DimH}, {2});
1380 checkNumericalEquivalence(0);
1381}
1382
1383/// Test splitting MatMul along dimension W.
1384TEST_F(NodeSplitting, MatMul_DimW_Chunks2) {
1385 splitMatMul(F_, optimizedF_, bindings_, cctx_, {ShapeHW::DimW}, {2});
1386 checkNumericalEquivalence(0);
1387}
1388
1389/// Test splitting MatMul along dimension H and W.
1390TEST_F(NodeSplitting, MatMul_DimHW_Chunks4) {
1391 splitMatMul(F_, optimizedF_, bindings_, cctx_, {ShapeHW::DimH, ShapeHW::DimW},
1392 {2, 2});
1393 checkNumericalEquivalence(0);
1394}
1395
1396///===---------------------------------------------------------------------===//
1397/// BatchMatMul
1398///===---------------------------------------------------------------------===//
1399/// Utility function to test splitting a BatchMatMul node.
1400static void splitBatchMatMul(Function *F, Function *&optF,
1401 PlaceholderBindings &bindings,
1402 CompilationContext &cctx,
1403 llvm::ArrayRef<size_t> splitDims,
1404 llvm::ArrayRef<dim_t> numChunks) {
1405 std::vector<dim_t> dimsLHS = {2, 10, 13};
1406 std::vector<dim_t> dimsRHS = {2, 13, 20};
1407 auto &mod = *(F->getParent());
1408 auto *LHS = mod.createPlaceholder(ElemKind::FloatTy, dimsLHS, "LHS", false);
1409 bindings.allocate(LHS)->getHandle<float>().randomize(-10.0, 10.0,
1410 mod.getPRNG());
1411 auto *RHS = mod.createPlaceholder(ElemKind::FloatTy, dimsRHS, "RHS", false);
1412 bindings.allocate(RHS)->getHandle<float>().randomize(-10.0, 10.0,
1413 mod.getPRNG());
1414 Node *node = F->createBatchMatMul("batchmatmul", LHS, RHS);
1415 SaveNode *output = F->createSave("output", node);
1416 bindings.allocate(output->getPlaceholder());
1417
1418 // Save current function state as reference.
1419 optF = F->clone(F->getName().str() + "_optimized");
1420
1421 // Split node.
1422 auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
1423 std::vector<Node *> splitNodes;
1424 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
1425
1426 // Compute total number of chunks.
1427 dim_t totNumChunks = 1;
1428 for (auto numChunk : numChunks) {
1429 totNumChunks *= numChunk;
1430 }
1431
1432 // Check node count.
1433 EXPECT_EQ(splitNodes.size(), totNumChunks);
1434 EXPECT_EQ(countNodeKind(F, Kinded::Kind::BatchMatMulNodeKind), totNumChunks);
1435 EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), totNumChunks);
1436 EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
1437}
1438
1439/// Test splitting BatchMatMul along dimension N.
1440TEST_F(NodeSplitting, BatchMatMul_DimN_Chunks2) {
1441 splitBatchMatMul(F_, optimizedF_, bindings_, cctx_, {ShapeNHW::DimN}, {2});
1442 checkNumericalEquivalence(0);
1443}
1444
1445/// Test splitting BatchMatMul along dimension H.
1446TEST_F(NodeSplitting, BatchMatMul_DimH_Chunks2) {
1447 splitBatchMatMul(F_, optimizedF_, bindings_, cctx_, {ShapeNHW::DimH}, {2});
1448 checkNumericalEquivalence(0);
1449}
1450
1451/// Test splitting BatchMatMul along dimension W.
1452TEST_F(NodeSplitting, BatchMatMul_DimW_Chunks2) {
1453 splitBatchMatMul(F_, optimizedF_, bindings_, cctx_, {ShapeNHW::DimW}, {2});
1454 checkNumericalEquivalence(0);
1455}
1456
1457/// Test splitting BatchMatMul along dimension N and H.
1458TEST_F(NodeSplitting, BatchMatMul_DimNH_Chunks4) {
1459 splitBatchMatMul(F_, optimizedF_, bindings_, cctx_,
1460 {ShapeNHW::DimN, ShapeNHW::DimH}, {2, 2});
1461 checkNumericalEquivalence(0);
1462}
1463
1464/// Test splitting BatchMatMul along dimension N, H and W.
1465TEST_F(NodeSplitting, BatchMatMul_DimNHW_Chunks8) {
1466 splitBatchMatMul(F_, optimizedF_, bindings_, cctx_,
1467 {ShapeNHW::DimN, ShapeNHW::DimH, ShapeNHW::DimW}, {2, 2, 2});
1468 checkNumericalEquivalence(0);
1469}
1470
1471///===---------------------------------------------------------------------===//
1472/// BatchedAdd
1473///===---------------------------------------------------------------------===//
1474/// Utility function to test splitting a BatchedAdd node.
1475static void splitBatchedAdd(Function *F, Function *&optF,
1476 PlaceholderBindings &bindings,
1477 CompilationContext &cctx,
1478 llvm::ArrayRef<size_t> splitDims,
1479 llvm::ArrayRef<dim_t> numChunks) {
1480 std::vector<dim_t> batchDims = {2, 10, 13};
1481 std::vector<dim_t> sliceDims = {10, 13};
1482 auto &mod = *(F->getParent());
1483 auto *batch =
1484 mod.createPlaceholder(ElemKind::FloatTy, batchDims, "batch", false);
1485 bindings.allocate(batch)->getHandle<float>().randomize(-10.0, 10.0,
1486 mod.getPRNG());
1487 auto *slice =
1488 mod.createPlaceholder(ElemKind::FloatTy, sliceDims, "slice", false);
1489 bindings.allocate(slice)->getHandle<float>().randomize(-10.0, 10.0,
1490 mod.getPRNG());
1491 Node *node = F->createBatchedAdd("batchedadd", batch, slice);
1492 SaveNode *output = F->createSave("output", node);
1493 bindings.allocate(output->getPlaceholder());
1494
1495 // Save current function state as reference.
1496 optF = F->clone(F->getName().str() + "_optimized");
1497
1498 // Split node.
1499 auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
1500 std::vector<Node *> splitNodes;
1501 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
1502
1503 // Compute total number of chunks.
1504 dim_t totNumChunks = 1;
1505 for (auto numChunk : numChunks) {
1506 totNumChunks *= numChunk;
1507 }
1508
1509 // Check node count.
1510 EXPECT_EQ(splitNodes.size(), totNumChunks);
1511 EXPECT_EQ(countNodeKind(F, Kinded::Kind::BatchedAddNodeKind), totNumChunks);
1512 EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), totNumChunks);
1513 EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
1514}
1515
1516/// Test splitting BatchedAdd along dimension 0.
1517TEST_F(NodeSplitting, BatchedAdd_Dim0_Chunks2) {
1518 splitBatchedAdd(F_, optimizedF_, bindings_, cctx_, {0}, {2});
1519 checkNumericalEquivalence(0);
1520}
1521
1522/// Test splitting BatchedAdd along dimension 1.
1523TEST_F(NodeSplitting, BatchedAdd_Dim1_Chunks2) {
1524 splitBatchedAdd(F_, optimizedF_, bindings_, cctx_, {1}, {2});
1525 checkNumericalEquivalence(0);
1526}
1527
1528/// Test splitting BatchedAdd along dimension 2.
1529TEST_F(NodeSplitting, BatchedAdd_Dim2_Chunks2) {
1530 splitBatchedAdd(F_, optimizedF_, bindings_, cctx_, {2}, {2});
1531 checkNumericalEquivalence(0);
1532}
1533
1534/// Test splitting BatchedAdd along dimension 0 and 1.
1535TEST_F(NodeSplitting, BatchedAdd_Dim01_Chunks4) {
1536 splitBatchedAdd(F_, optimizedF_, bindings_, cctx_, {0, 1}, {2, 2});
1537 checkNumericalEquivalence(0);
1538}
1539
1540/// Test splitting BatchedAdd along dimension 0, 1 and 2.
1541TEST_F(NodeSplitting, BatchedAdd_Dim012_Chunks8) {
1542 splitBatchedAdd(F_, optimizedF_, bindings_, cctx_, {0, 1, 2}, {2, 2, 2});
1543 checkNumericalEquivalence(0);
1544}
1545
1546///===---------------------------------------------------------------------===//
1547/// Transpose
1548///===---------------------------------------------------------------------===//
1549/// Utility function to test splitting a Transpose node.
1550static void splitTranspose(Function *F, Function *&optF,
1551 PlaceholderBindings &bindings,
1552 CompilationContext &cctx,
1553 llvm::ArrayRef<size_t> splitDims,
1554 llvm::ArrayRef<dim_t> numChunks) {
1555 std::vector<dim_t> inputDims = {3, 5, 7};
1556 std::vector<unsigned_t> shuffle = {2, 0, 1};
1557 auto &mod = *(F->getParent());
1558 auto *input =
1559 mod.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
1560 bindings.allocate(input)->getHandle<float>().randomize(-10.0, 10.0,
1561 mod.getPRNG());
1562 Node *node = F->createTranspose("transpose", input, shuffle);
1563 SaveNode *output = F->createSave("output", node);
1564 bindings.allocate(output->getPlaceholder());
1565
1566 // Save current function state as reference.
1567 optF = F->clone(F->getName().str() + "_optimized");
1568
1569 // Split node.
1570 auto splitOption = SplitNodeByNumChunks(splitDims, numChunks);
1571 std::vector<Node *> splitNodes;
1572 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
1573
1574 // Compute total number of chunks.
1575 dim_t totNumChunks = 1;
1576 for (auto numChunk : numChunks) {
1577 totNumChunks *= numChunk;
1578 }
1579
1580 // Check node count.
1581 EXPECT_EQ(splitNodes.size(), totNumChunks);
1582 EXPECT_EQ(countNodeKind(F, Kinded::Kind::TransposeNodeKind), totNumChunks);
1583 EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), totNumChunks);
1584 EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
1585}
1586
/// Test splitting Transpose along dimension 0 into 2 chunks and verify
/// numerical equivalence with the unsplit reference.
TEST_F(NodeSplitting, Transpose_Dim0_Chunks2) {
  splitTranspose(F_, optimizedF_, bindings_, cctx_, {0}, {2});
  checkNumericalEquivalence(0);
}
1592
/// Test splitting Transpose along dimension 1 into 2 chunks and verify
/// numerical equivalence with the unsplit reference.
TEST_F(NodeSplitting, Transpose_Dim1_Chunks2) {
  splitTranspose(F_, optimizedF_, bindings_, cctx_, {1}, {2});
  checkNumericalEquivalence(0);
}
1598
/// Test splitting Transpose along dimension 2 into 2 chunks and verify
/// numerical equivalence with the unsplit reference.
TEST_F(NodeSplitting, Transpose_Dim2_Chunks2) {
  splitTranspose(F_, optimizedF_, bindings_, cctx_, {2}, {2});
  checkNumericalEquivalence(0);
}
1604
/// Test splitting Transpose along dimensions 0 and 1 (2 chunks each, 4 total)
/// and verify numerical equivalence with the unsplit reference.
TEST_F(NodeSplitting, Transpose_Dim01_Chunks4) {
  splitTranspose(F_, optimizedF_, bindings_, cctx_, {0, 1}, {2, 2});
  checkNumericalEquivalence(0);
}
1610
/// Test splitting Transpose along dimensions 0, 1 and 2 (2 chunks each,
/// 8 total) and verify numerical equivalence with the unsplit reference.
TEST_F(NodeSplitting, Transpose_Dim012_Chunks8) {
  splitTranspose(F_, optimizedF_, bindings_, cctx_, {0, 1, 2}, {2, 2, 2});
  checkNumericalEquivalence(0);
}
1616
1617///===---------------------------------------------------------------------===//
1618/// Binary Operators
1619///===---------------------------------------------------------------------===//
1620/// Test splitting binary operators.
1621TEST_F(NodeSplitting, BinaryOps) {
1622 // Create network with parallel binary operators.
1623 std::vector<dim_t> dims = {10, 10};
1624 auto *inputLHS =
1625 mod_.createPlaceholder(ElemKind::FloatTy, dims, "inputLHS", false);
1626 auto *inputRHS =
1627 mod_.createPlaceholder(ElemKind::FloatTy, dims, "inputRHS", false);
1628 bindings_.allocate(inputLHS)->getHandle<float>().randomize(1.0, 2.0,
1629 mod_.getPRNG());
1630 bindings_.allocate(inputRHS)->getHandle<float>().randomize(1.0, 2.0,
1631 mod_.getPRNG());
1632 Node *add = F_->createAdd("add", inputLHS, inputRHS);
1633 Node *mul = F_->createMul("mul", inputLHS, inputRHS);
1634 Node *sub = F_->createSub("sub", inputLHS, inputRHS);
1635 Node *div = F_->createDiv("div", inputLHS, inputRHS);
1636 Node *max = F_->createMax("max", inputLHS, inputRHS);
1637 Node *min = F_->createMin("min", inputLHS, inputRHS);
1638 Node *cmpLTE = F_->createCmpLTE("cmpLTE", inputLHS, inputRHS);
1639 Node *cmpLT = F_->createCmpLT("cmpLT", inputLHS, inputRHS);
1640 Node *cmpEQ = F_->createCmpEQ("cmpEQ", inputLHS, inputRHS);
1641 Node *pow = F_->createPow("pow", inputLHS, inputRHS);
1642 SaveNode *addSave = F_->createSave("addSave", add);
1643 SaveNode *mulSave = F_->createSave("mulSave", mul);
1644 SaveNode *subSave = F_->createSave("subSave", sub);
1645 SaveNode *divSave = F_->createSave("divSave", div);
1646 SaveNode *maxSave = F_->createSave("maxSave", max);
1647 SaveNode *minSave = F_->createSave("minSave", min);
1648 SaveNode *cmpLTESave = F_->createSave("cmpLTESave", cmpLTE);
1649 SaveNode *cmpLTSave = F_->createSave("cmpLTSave", cmpLT);
1650 SaveNode *cmpEQSave = F_->createSave("cmpEQSave", cmpEQ);
1651 SaveNode *powSave = F_->createSave("powSave", pow);
1652 bindings_.allocate(addSave->getPlaceholder());
1653 bindings_.allocate(mulSave->getPlaceholder());
1654 bindings_.allocate(subSave->getPlaceholder());
1655 bindings_.allocate(divSave->getPlaceholder());
1656 bindings_.allocate(maxSave->getPlaceholder());
1657 bindings_.allocate(minSave->getPlaceholder());
1658 bindings_.allocate(cmpLTESave->getPlaceholder());
1659 bindings_.allocate(cmpLTSave->getPlaceholder());
1660 bindings_.allocate(cmpEQSave->getPlaceholder());
1661 bindings_.allocate(powSave->getPlaceholder());
1662
1663 // Save current function state as reference.
1664 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
1665
1666 // Split nodes.
1667 auto splitOption = SplitNodeByNumChunks({0}, {2});
1668 SplitNodeMap splitMap;
1669 ASSIGN_VALUE_OR_FAIL_TEST(splitMap, ::glow::splitNodes(F_, splitOption));
1670
1671 // Check node count.
1672 EXPECT_EQ(2, splitMap[add].size());
1673 EXPECT_EQ(2, splitMap[mul].size());
1674 EXPECT_EQ(2, splitMap[sub].size());
1675 EXPECT_EQ(2, splitMap[div].size());
1676 EXPECT_EQ(2, splitMap[max].size());
1677 EXPECT_EQ(2, splitMap[min].size());
1678 EXPECT_EQ(2, splitMap[cmpLTE].size());
1679 EXPECT_EQ(2, splitMap[cmpLT].size());
1680 EXPECT_EQ(2, splitMap[cmpEQ].size());
1681 EXPECT_EQ(2, splitMap[pow].size());
1682 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::AddNodeKind));
1683 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::MulNodeKind));
1684 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::SubNodeKind));
1685 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::DivNodeKind));
1686 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::MaxNodeKind));
1687 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::MinNodeKind));
1688 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::CmpLTENodeKind));
1689 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::CmpLTNodeKind));
1690 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::CmpEQNodeKind));
1691 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::PowNodeKind));
1692 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 10 * 2 * 2);
1693 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind), 10 * 2);
1694 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 10);
1695 checkNumericalEquivalence(0);
1696}
1697
1698///===---------------------------------------------------------------------===//
1699/// Unary Operators
1700///===---------------------------------------------------------------------===//
1701/// Test splitting unary operators.
1702TEST_F(NodeSplitting, UnaryOps) {
1703 std::vector<dim_t> dims = {10, 10};
1704 auto quantizeTy = mod_.uniqueType(ElemKind::Int8QTy, dims, 1.0, 0);
1705 auto requantizeTy = mod_.uniqueType(ElemKind::Int8QTy, dims, 0.5, 0);
1706 auto convertTy = mod_.uniqueType(ElemKind::Float16Ty, dims);
1707
1708 // Create network with chained unary operators.
1709 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input", false);
1710 bindings_.allocate(input)->getHandle<float>().randomize(-10.0, 10.0,
1711 mod_.getPRNG());
1712 Node *relu = F_->createRELU("relu", input);
1713 Node *leakyRelu = F_->createLeakyRELU("leakyrelu", relu, 0.1);
1714 Node *clip = F_->createClip("clip", leakyRelu, 1.0, 10.0);
1715 Node *tanh = F_->createTanh("tanh", clip);
1716 Node *sigmoid = F_->createSigmoid("sigmoid", tanh);
1717 Node *log = F_->createLog("log", sigmoid);
1718 Node *exp = F_->createExp("exp", log);
1719 Node *quantize = F_->createQuantize("quantize", exp, quantizeTy);
1720 Node *requantize =
1721 F_->createRescaleQuantized("requantize", quantize, requantizeTy);
1722 Node *dequantize =
1723 F_->createDequantize("dequantize", requantize, ElemKind::FloatTy);
1724 Node *convert = F_->createConvertTo("convert", dequantize, convertTy);
1725 SaveNode *output = F_->createSave("output", convert);
1726 bindings_.allocate(output->getPlaceholder());
1727
1728 // Save current function state as reference.
1729 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
1730
1731 // Split nodes.
1732 auto splitOption = SplitNodeByNumChunks({0}, {2});
1733 SplitNodeMap splitMap;
1734 ASSIGN_VALUE_OR_FAIL_TEST(splitMap, ::glow::splitNodes(F_, splitOption));
1735
1736 // Check node count.
1737 EXPECT_EQ(2, splitMap[relu].size());
1738 EXPECT_EQ(2, splitMap[leakyRelu].size());
1739 EXPECT_EQ(2, splitMap[clip].size());
1740 EXPECT_EQ(2, splitMap[tanh].size());
1741 EXPECT_EQ(2, splitMap[sigmoid].size());
1742 EXPECT_EQ(2, splitMap[log].size());
1743 EXPECT_EQ(2, splitMap[exp].size());
1744 EXPECT_EQ(2, splitMap[quantize].size());
1745 EXPECT_EQ(2, splitMap[requantize].size());
1746 EXPECT_EQ(2, splitMap[dequantize].size());
1747 EXPECT_EQ(2, splitMap[convert].size());
1748 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::ReluNodeKind));
1749 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::LeakyReluNodeKind));
1750 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::ClipNodeKind));
1751 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::TanhNodeKind));
1752 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::SigmoidNodeKind));
1753 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::LogNodeKind));
1754 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::ExpNodeKind));
1755 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::QuantizeNodeKind));
1756 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::RescaleQuantizedNodeKind));
1757 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::DequantizeNodeKind));
1758 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::ConvertToNodeKind));
1759 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 11 * 2);
1760 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind), 11 * 2);
1761 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 11);
1762 checkNumericalEquivalence(0);
1763}
1764
1765///===---------------------------------------------------------------------===//
1766/// Non-Orthogonal Split
1767///===---------------------------------------------------------------------===//
1768/// Utility function to test splitting a Conv2D non-orthogonally based on given
1769/// raw \p sliceRanges.
1770static void splitConv2DNonOrthogonal(Function *F, Function *&optF,
1771 PlaceholderBindings &bindings,
1772 CompilationContext &cctx,
1773 llvm::ArrayRef<SliceRange> sliceRanges) {
1774 Node *node = createConv2D(F, bindings,
1775 /* inputDims */ {5, 4, 4, 7},
1776 /* filterDims */ {8, 3, 3, 7},
1777 /* biasDims */ {8},
1778 /* outputDims */ {5, 4, 4, 8},
1779 /* kernels */ {3, 3},
1780 /* strides */ {1, 1},
1781 /* pads */ {1, 1, 1, 1},
1782 /* group */ 1,
1783 /* dilation */ {1, 1});
1784
1785 // Save current function state as reference.
1786 optF = F->clone(F->getName().str() + "_optimized");
1787
1788 // Split node.
1789 auto splitOption = SplitNodeBySliceRanges(sliceRanges);
1790 std::vector<Node *> splitNodes;
1791 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes, ::glow::splitNode(node, splitOption));
1792
1793 // Check node count.
1794 dim_t totNumChunks = sliceRanges.size();
1795 EXPECT_EQ(splitNodes.size(), totNumChunks);
1796 EXPECT_EQ(countNodeKind(F, Kinded::Kind::ConvolutionNodeKind), totNumChunks);
1797 EXPECT_EQ(countNodeKind(F, Kinded::Kind::InsertTensorNodeKind), totNumChunks);
1798 EXPECT_EQ(countNodeKind(F, Kinded::Kind::TouchNodeKind), 1);
1799}
1800
/// Test splitting a Conv2D non-orthogonally along N dimension using two
/// overlapping batch ranges, then verify numerical equivalence.
TEST_F(NodeSplitting, Conv2D_NonOrthogonal_DimN) {
  std::vector<SliceRange> sliceRanges = {
      SliceRange({{0, 4}, {0, 4}, {0, 4}, {0, 8}}),
      SliceRange({{2, 5}, {0, 4}, {0, 4}, {0, 8}})};
  splitConv2DNonOrthogonal(F_, optimizedF_, bindings_, cctx_, sliceRanges);
  checkNumericalEquivalence(0);
}
1809
/// Test splitting a Conv2D non-orthogonally along H dimension using two
/// overlapping height ranges, then verify numerical equivalence.
TEST_F(NodeSplitting, Conv2D_NonOrthogonal_DimH) {
  std::vector<SliceRange> sliceRanges = {
      SliceRange({{0, 5}, {0, 3}, {0, 4}, {0, 8}}),
      SliceRange({{0, 5}, {2, 4}, {0, 4}, {0, 8}})};
  splitConv2DNonOrthogonal(F_, optimizedF_, bindings_, cctx_, sliceRanges);
  checkNumericalEquivalence(0);
}
1818
/// Test splitting a Conv2D non-orthogonally along W dimension using two
/// overlapping width ranges, then verify numerical equivalence.
TEST_F(NodeSplitting, Conv2D_NonOrthogonal_DimW) {
  std::vector<SliceRange> sliceRanges = {
      SliceRange({{0, 5}, {0, 4}, {0, 3}, {0, 8}}),
      SliceRange({{0, 5}, {0, 4}, {2, 4}, {0, 8}})};
  splitConv2DNonOrthogonal(F_, optimizedF_, bindings_, cctx_, sliceRanges);
  checkNumericalEquivalence(0);
}
1827
/// Test splitting a Conv2D non-orthogonally along C dimension using two
/// overlapping channel ranges, then verify numerical equivalence.
TEST_F(NodeSplitting, Conv2D_NonOrthogonal_DimC) {
  std::vector<SliceRange> sliceRanges = {
      SliceRange({{0, 5}, {0, 4}, {0, 4}, {0, 6}}),
      SliceRange({{0, 5}, {0, 4}, {0, 4}, {2, 8}})};
  splitConv2DNonOrthogonal(F_, optimizedF_, bindings_, cctx_, sliceRanges);
  checkNumericalEquivalence(0);
}
1836
/// Test splitting a Conv2D non-orthogonally along H and W dimensions using
/// four overlapping ranges (2 per dimension), then verify numerical
/// equivalence.
TEST_F(NodeSplitting, Conv2D_NonOrthogonal_DimHW) {
  std::vector<SliceRange> sliceRanges = {
      SliceRange({{0, 5}, {0, 3}, {0, 3}, {0, 8}}),
      SliceRange({{0, 5}, {0, 3}, {1, 4}, {0, 8}}),
      SliceRange({{0, 5}, {1, 4}, {0, 3}, {0, 8}}),
      SliceRange({{0, 5}, {1, 4}, {1, 4}, {0, 8}})};
  splitConv2DNonOrthogonal(F_, optimizedF_, bindings_, cctx_, sliceRanges);
  checkNumericalEquivalence(0);
}
1847
/// Test splitting a Conv2D non-orthogonally along H, W and C dimensions
/// using eight overlapping ranges (2 per dimension), then verify numerical
/// equivalence.
TEST_F(NodeSplitting, Conv2D_NonOrthogonal_DimHWC) {
  std::vector<SliceRange> sliceRanges = {
      SliceRange({{0, 5}, {0, 3}, {0, 3}, {0, 6}}),
      SliceRange({{0, 5}, {0, 3}, {1, 4}, {0, 6}}),
      SliceRange({{0, 5}, {1, 4}, {0, 3}, {0, 6}}),
      SliceRange({{0, 5}, {1, 4}, {1, 4}, {0, 6}}),
      SliceRange({{0, 5}, {0, 3}, {0, 3}, {2, 8}}),
      SliceRange({{0, 5}, {0, 3}, {1, 4}, {2, 8}}),
      SliceRange({{0, 5}, {1, 4}, {0, 3}, {2, 8}}),
      SliceRange({{0, 5}, {1, 4}, {1, 4}, {2, 8}})};
  splitConv2DNonOrthogonal(F_, optimizedF_, bindings_, cctx_, sliceRanges);
  checkNumericalEquivalence(0);
}
1862
/// Test splitting a Conv2D non-orthogonally along N, H, W and C dimensions
/// using sixteen overlapping ranges (2 per dimension), then verify numerical
/// equivalence.
TEST_F(NodeSplitting, Conv2D_NonOrthogonal_DimNHWC) {
  std::vector<SliceRange> sliceRanges = {
      SliceRange({{0, 4}, {0, 3}, {0, 3}, {0, 6}}),
      SliceRange({{0, 4}, {0, 3}, {1, 4}, {0, 6}}),
      SliceRange({{0, 4}, {1, 4}, {0, 3}, {0, 6}}),
      SliceRange({{0, 4}, {1, 4}, {1, 4}, {0, 6}}),
      SliceRange({{0, 4}, {0, 3}, {0, 3}, {2, 8}}),
      SliceRange({{0, 4}, {0, 3}, {1, 4}, {2, 8}}),
      SliceRange({{0, 4}, {1, 4}, {0, 3}, {2, 8}}),
      SliceRange({{0, 4}, {1, 4}, {1, 4}, {2, 8}}),
      SliceRange({{2, 5}, {0, 3}, {0, 3}, {0, 6}}),
      SliceRange({{2, 5}, {0, 3}, {1, 4}, {0, 6}}),
      SliceRange({{2, 5}, {1, 4}, {0, 3}, {0, 6}}),
      SliceRange({{2, 5}, {1, 4}, {1, 4}, {0, 6}}),
      SliceRange({{2, 5}, {0, 3}, {0, 3}, {2, 8}}),
      SliceRange({{2, 5}, {0, 3}, {1, 4}, {2, 8}}),
      SliceRange({{2, 5}, {1, 4}, {0, 3}, {2, 8}}),
      SliceRange({{2, 5}, {1, 4}, {1, 4}, {2, 8}})};
  splitConv2DNonOrthogonal(F_, optimizedF_, bindings_, cctx_, sliceRanges);
  checkNumericalEquivalence(0);
}
1885
1886///===---------------------------------------------------------------------===//
1887/// Recursive Split
1888///===---------------------------------------------------------------------===//
/// Utility function to split Conv2D with Relu recursively: builds a
/// Conv2D -> Relu -> Save chain, splits the Relu with \p splitOption and
/// verifies that the split propagated recursively to the Conv2D producer.
/// The unsplit clone is stored in \p optF for the numerical-equivalence
/// check. \p cctx is unused here but keeps the signature uniform with the
/// sibling split helpers.
static void splitConv2DReluRecursive(Function *F, Function *&optF,
                                     PlaceholderBindings &bindings,
                                     CompilationContext &cctx,
                                     SplitNodeOption &splitOption) {
  // Conv2D params.
  std::vector<dim_t> inputDims = {5, 8, 8, 4};
  std::vector<dim_t> filterDims = {16, 3, 3, 4};
  std::vector<dim_t> biasDims = {16};
  std::vector<dim_t> outputDims = {5, 4, 4, 16};
  std::vector<unsigned_t> kernels = {3, 3};
  std::vector<unsigned_t> strides = {2, 2};
  std::vector<unsigned_t> pads = {1, 1, 0, 0};
  dim_t group = 1;
  std::vector<unsigned_t> dilation = {1, 1};

  // Create input placeholder.
  auto &mod = *(F->getParent());
  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
  bindings.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
                                                         mod.getPRNG());
  // Create filter constant.
  auto *filter = mod.createConstant(ElemKind::FloatTy, filterDims, "filter");
  filter->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0,
                                                           mod.getPRNG());
  // Create bias constant.
  auto *bias = mod.createConstant(ElemKind::FloatTy, biasDims, "bias");
  bias->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0,
                                                         mod.getPRNG());
  // Create Conv2D.
  auto *outTy = mod.uniqueType(ElemKind::FloatTy, outputDims);
  auto *conv = F->createConv("conv", input, filter, bias, outTy, kernels,
                             strides, pads, group, dilation);
  // Create Relu.
  auto *relu = F->createRELU("relu", conv);
  // Create Save.
  SaveNode *save = F->createSave("save", relu);
  bindings.allocate(save->getPlaceholder());

  // Save current function state as reference.
  optF = F->clone(F->getName().str() + "_optimized");

  // Split the Relu recursively. A depth of 10 exceeds the chain length so
  // the recursion is effectively unbounded here.
  SplitNodeMap splitMap;
  ASSIGN_VALUE_OR_FAIL_TEST(
      splitMap,
      ::glow::splitNodeRecursively(relu, splitOption, /* maxDepth */ 10));

  // Both the Relu and its Conv2D input must appear in the split map.
  EXPECT_EQ(splitMap.size(), 2);
  EXPECT_TRUE(splitMap.count(conv));
  EXPECT_TRUE(splitMap.count(relu));
}
1942
/// Test recursively splitting Conv2D with Relu along the N dimension
/// (2 chunks) and verify numerical equivalence with the unsplit reference.
TEST_F(NodeSplitting, Conv2D_Relu_Recursive_DimN) {
  auto splitOption = SplitNodeByNumChunks({ShapeNHWC::DimN}, {2});
  splitConv2DReluRecursive(F_, optimizedF_, bindings_, cctx_, splitOption);
  checkNumericalEquivalence(0);
}
1949
/// Test recursively splitting Conv2D with Relu along the H dimension
/// (2 chunks) and verify numerical equivalence with the unsplit reference.
TEST_F(NodeSplitting, Conv2D_Relu_Recursive_DimH) {
  auto splitOption = SplitNodeByNumChunks({ShapeNHWC::DimH}, {2});
  splitConv2DReluRecursive(F_, optimizedF_, bindings_, cctx_, splitOption);
  checkNumericalEquivalence(0);
}
1956
/// Test recursively splitting Conv2D with Relu along the W dimension
/// (2 chunks) and verify numerical equivalence with the unsplit reference.
TEST_F(NodeSplitting, Conv2D_Relu_Recursive_DimW) {
  auto splitOption = SplitNodeByNumChunks({ShapeNHWC::DimW}, {2});
  splitConv2DReluRecursive(F_, optimizedF_, bindings_, cctx_, splitOption);
  checkNumericalEquivalence(0);
}
1963
/// Test recursively splitting Conv2D with Relu along the C dimension
/// (2 chunks) and verify numerical equivalence with the unsplit reference.
TEST_F(NodeSplitting, Conv2D_Relu_Recursive_DimC) {
  auto splitOption = SplitNodeByNumChunks({ShapeNHWC::DimC}, {2});
  splitConv2DReluRecursive(F_, optimizedF_, bindings_, cctx_, splitOption);
  checkNumericalEquivalence(0);
}
1970
/// Test recursively splitting Conv2D with Relu along the H and W dimensions
/// (2 chunks each, 4 total) and verify numerical equivalence.
TEST_F(NodeSplitting, Conv2D_Relu_Recursive_DimHW) {
  auto splitOption =
      SplitNodeByNumChunks({ShapeNHWC::DimH, ShapeNHWC::DimW}, {2, 2});
  splitConv2DReluRecursive(F_, optimizedF_, bindings_, cctx_, splitOption);
  checkNumericalEquivalence(0);
}
1978
/// Test recursively splitting Conv2D with Relu along the H, W and C
/// dimensions (2 chunks each, 8 total) and verify numerical equivalence.
TEST_F(NodeSplitting, Conv2D_Relu_Recursive_DimHWC) {
  auto splitOption = SplitNodeByNumChunks(
      {ShapeNHWC::DimH, ShapeNHWC::DimW, ShapeNHWC::DimC}, {2, 2, 2});
  splitConv2DReluRecursive(F_, optimizedF_, bindings_, cctx_, splitOption);
  checkNumericalEquivalence(0);
}
1986
/// Test recursively splitting Conv2D with Relu along the N, H, W and C
/// dimensions (2 chunks each, 16 total) and verify numerical equivalence.
TEST_F(NodeSplitting, Conv2D_Relu_Recursive_DimNHWC) {
  auto splitOption = SplitNodeByNumChunks(
      {ShapeNHWC::DimN, ShapeNHWC::DimH, ShapeNHWC::DimW, ShapeNHWC::DimC},
      {2, 2, 2, 2});
  splitConv2DReluRecursive(F_, optimizedF_, bindings_, cctx_, splitOption);
  checkNumericalEquivalence(0);
}
1995
/// Utility function to split Conv2D with Relu and MaxPool recursively:
/// builds a Conv2D -> Relu -> MaxPool -> Save chain, splits the MaxPool
/// with \p splitOption and verifies that the split propagated recursively
/// through the Relu down to the Conv2D. The unsplit clone is stored in
/// \p optF for the numerical-equivalence check. \p cctx is unused here but
/// keeps the signature uniform with the sibling split helpers.
static void splitConv2DReluMaxPoolRecursive(Function *F, Function *&optF,
                                            PlaceholderBindings &bindings,
                                            CompilationContext &cctx,
                                            SplitNodeOption &splitOption) {
  // Conv2D params.
  std::vector<dim_t> inputDims = {5, 16, 16, 4};
  std::vector<dim_t> filterDims = {16, 3, 3, 4};
  std::vector<dim_t> biasDims = {16};
  std::vector<dim_t> outputDims = {5, 8, 8, 16};
  std::vector<unsigned_t> kernels = {3, 3};
  std::vector<unsigned_t> strides = {2, 2};
  std::vector<unsigned_t> pads = {1, 1, 0, 0};
  dim_t group = 1;
  std::vector<unsigned_t> dilation = {1, 1};

  // MaxPool params.
  std::vector<unsigned_t> poolKernels = {3, 3};
  std::vector<unsigned_t> poolStrides = {2, 2};
  std::vector<unsigned_t> poolPads = {1, 1, 1, 1};

  // Create input placeholder.
  auto &mod = *(F->getParent());
  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
  bindings.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
                                                         mod.getPRNG());
  // Create filter constant.
  auto *filter = mod.createConstant(ElemKind::FloatTy, filterDims, "filter");
  filter->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0,
                                                           mod.getPRNG());
  // Create bias constant.
  auto *bias = mod.createConstant(ElemKind::FloatTy, biasDims, "bias");
  bias->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0,
                                                         mod.getPRNG());
  // Create Conv2D.
  auto *outTy = mod.uniqueType(ElemKind::FloatTy, outputDims);
  auto *conv = F->createConv("conv", input, filter, bias, outTy, kernels,
                             strides, pads, group, dilation);
  // Create Relu.
  auto *relu = F->createRELU("relu", conv);
  // Create MaxPool.
  auto *pool =
      F->createMaxPool("pool", relu, poolKernels, poolStrides, poolPads);
  // Create Save.
  SaveNode *save = F->createSave("save", pool->getResult());
  bindings.allocate(save->getPlaceholder());

  // Save current function state as reference.
  optF = F->clone(F->getName().str() + "_optimized");

  // Split the MaxPool recursively. A depth of 10 exceeds the chain length
  // so the recursion is effectively unbounded here.
  SplitNodeMap splitMap;
  ASSIGN_VALUE_OR_FAIL_TEST(
      splitMap,
      ::glow::splitNodeRecursively(pool, splitOption, /* maxDepth */ 10));

  // All three chained nodes must appear in the split map.
  EXPECT_EQ(splitMap.size(), 3);
  EXPECT_TRUE(splitMap.count(conv));
  EXPECT_TRUE(splitMap.count(relu));
  EXPECT_TRUE(splitMap.count(pool));
}
2058
/// Test recursively splitting Conv2D with Relu and MaxPool along the N
/// dimension (2 chunks) and verify numerical equivalence.
TEST_F(NodeSplitting, Conv2D_Relu_MaxPool_Recursive_DimN) {
  auto splitOption = SplitNodeByNumChunks({ShapeNHWC::DimN}, {2});
  splitConv2DReluMaxPoolRecursive(F_, optimizedF_, bindings_, cctx_,
                                  splitOption);
  checkNumericalEquivalence(0);
}
2066
/// Test recursively splitting Conv2D with Relu and MaxPool along the H
/// dimension (2 chunks) and verify numerical equivalence.
TEST_F(NodeSplitting, Conv2D_Relu_MaxPool_Recursive_DimH) {
  auto splitOption = SplitNodeByNumChunks({ShapeNHWC::DimH}, {2});
  splitConv2DReluMaxPoolRecursive(F_, optimizedF_, bindings_, cctx_,
                                  splitOption);
  checkNumericalEquivalence(0);
}
2074
/// Test recursively splitting Conv2D with Relu and MaxPool along the W
/// dimension (2 chunks) and verify numerical equivalence.
TEST_F(NodeSplitting, Conv2D_Relu_MaxPool_Recursive_DimW) {
  auto splitOption = SplitNodeByNumChunks({ShapeNHWC::DimW}, {2});
  splitConv2DReluMaxPoolRecursive(F_, optimizedF_, bindings_, cctx_,
                                  splitOption);
  checkNumericalEquivalence(0);
}
2082
/// Test recursively splitting Conv2D with Relu and MaxPool along the C
/// dimension (2 chunks) and verify numerical equivalence.
TEST_F(NodeSplitting, Conv2D_Relu_MaxPool_Recursive_DimC) {
  auto splitOption = SplitNodeByNumChunks({ShapeNHWC::DimC}, {2});
  splitConv2DReluMaxPoolRecursive(F_, optimizedF_, bindings_, cctx_,
                                  splitOption);
  checkNumericalEquivalence(0);
}
2090
/// Test recursively splitting Conv2D with Relu and MaxPool along the H and
/// W dimensions (2 chunks each, 4 total) and verify numerical equivalence.
TEST_F(NodeSplitting, Conv2D_Relu_MaxPool_Recursive_DimHW) {
  auto splitOption =
      SplitNodeByNumChunks({ShapeNHWC::DimH, ShapeNHWC::DimW}, {2, 2});
  splitConv2DReluMaxPoolRecursive(F_, optimizedF_, bindings_, cctx_,
                                  splitOption);
  checkNumericalEquivalence(0);
}
2099
/// Test recursively splitting Conv2D with Relu and MaxPool along the H, W
/// and C dimensions (2 chunks each, 8 total) and verify numerical
/// equivalence.
TEST_F(NodeSplitting, Conv2D_Relu_MaxPool_Recursive_DimHWC) {
  auto splitOption = SplitNodeByNumChunks(
      {ShapeNHWC::DimH, ShapeNHWC::DimW, ShapeNHWC::DimC}, {2, 2, 2});
  splitConv2DReluMaxPoolRecursive(F_, optimizedF_, bindings_, cctx_,
                                  splitOption);
  checkNumericalEquivalence(0);
}
2108
/// Test recursively splitting Conv2D with Relu and MaxPool along the N, H,
/// W and C dimensions (2 chunks each, 16 total) and verify numerical
/// equivalence.
TEST_F(NodeSplitting, Conv2D_Relu_MaxPool_Recursive_DimNHWC) {
  auto splitOption = SplitNodeByNumChunks(
      {ShapeNHWC::DimN, ShapeNHWC::DimH, ShapeNHWC::DimW, ShapeNHWC::DimC},
      {2, 2, 2, 2});
  splitConv2DReluMaxPoolRecursive(F_, optimizedF_, bindings_, cctx_,
                                  splitOption);
  checkNumericalEquivalence(0);
}
2118
2119/// Verify that the recursive splitting max depth parameter is honored.
2120TEST_F(NodeSplitting, Recursive_MaxDepth) {
2121 std::vector<dim_t> dims = {10, 10};
2122
2123 // Create network with chained unary operators.
2124 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input", false);
2125 bindings_.allocate(input)->getHandle<float>().randomize(-10.0, 10.0,
2126 mod_.getPRNG());
2127 Node *relu = F_->createRELU("relu", input);
2128 Node *clip = F_->createClip("clip", relu, 1.0, 10.0);
2129 Node *tanh = F_->createTanh("tanh", clip);
2130 Node *sigmoid = F_->createSigmoid("sigmoid", tanh);
2131 SaveNode *output = F_->createSave("output", sigmoid);
2132 bindings_.allocate(output->getPlaceholder());
2133
2134 // Save current function state as reference.
2135 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
2136
2137 // Split nodes.
2138 unsigned maxDepth = 2;
2139 auto splitOption = SplitNodeByNumChunks({0}, {2});
2140 SplitNodeMap splitMap;
2141 ASSIGN_VALUE_OR_FAIL_TEST(
2142 splitMap, ::glow::splitNodeRecursively(sigmoid, splitOption, maxDepth));
2143
2144 // Check node count.
2145 EXPECT_EQ(splitMap.size(), 2);
2146 EXPECT_EQ(0, splitMap.count(relu));
2147 EXPECT_EQ(0, splitMap.count(clip));
2148 EXPECT_EQ(1, splitMap.count(tanh));
2149 EXPECT_EQ(1, splitMap.count(sigmoid));
2150 EXPECT_EQ(2, splitMap[tanh].size());
2151 EXPECT_EQ(2, splitMap[sigmoid].size());
2152 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ReluNodeKind));
2153 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ClipNodeKind));
2154 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::TanhNodeKind));
2155 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::SigmoidNodeKind));
2156 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 2);
2157 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind), 2);
2158 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 1);
2159 checkNumericalEquivalence(0);
2160}
2161
2162/// Verify recursive splitting for nodes with single output uses.
2163TEST_F(NodeSplitting, Recursive_SingleOutputUse) {
2164 std::vector<dim_t> dims = {10, 10};
2165
2166 // Create network with chained unary operators.
2167 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input", false);
2168 bindings_.allocate(input)->getHandle<float>().randomize(-10.0, 10.0,
2169 mod_.getPRNG());
2170 Node *relu = F_->createRELU("relu", input);
2171 Node *clip = F_->createClip("clip", relu, 1.0, 10.0);
2172 Node *tanh = F_->createTanh("tanh", clip);
2173 Node *sigmoid = F_->createSigmoid("sigmoid", tanh);
2174 SaveNode *output1 = F_->createSave("output1", clip);
2175 SaveNode *output2 = F_->createSave("output2", sigmoid);
2176 bindings_.allocate(output1->getPlaceholder());
2177 bindings_.allocate(output2->getPlaceholder());
2178
2179 // Save current function state as reference.
2180 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
2181
2182 // Split nodes.
2183 bool singleUseOnly = true;
2184 unsigned maxDepth = 10;
2185 auto splitOption = SplitNodeByNumChunks({0}, {2});
2186 SplitNodeMap splitMap;
2187 ASSIGN_VALUE_OR_FAIL_TEST(
2188 splitMap, ::glow::splitNodeRecursively(sigmoid, splitOption, maxDepth,
2189 singleUseOnly));
2190
2191 // Check node count.
2192 EXPECT_EQ(splitMap.size(), 2);
2193 EXPECT_EQ(0, splitMap.count(relu));
2194 EXPECT_EQ(0, splitMap.count(clip));
2195 EXPECT_EQ(1, splitMap.count(tanh));
2196 EXPECT_EQ(1, splitMap.count(sigmoid));
2197 EXPECT_EQ(2, splitMap[tanh].size());
2198 EXPECT_EQ(2, splitMap[sigmoid].size());
2199 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ReluNodeKind));
2200 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ClipNodeKind));
2201 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::TanhNodeKind));
2202 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::SigmoidNodeKind));
2203 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 2);
2204 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind), 2);
2205 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 1);
2206 checkNumericalEquivalence(0);
2207}
2208
2209/// Verify recursive splitting for nodes with multiple output uses.
2210TEST_F(NodeSplitting, Recursive_MultipleOutputUses) {
2211 std::vector<dim_t> dims = {10, 10};
2212
2213 // Create network with chained unary operators.
2214 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input", false);
2215 bindings_.allocate(input)->getHandle<float>().randomize(-10.0, 10.0,
2216 mod_.getPRNG());
2217 Node *relu = F_->createRELU("relu", input);
2218 Node *clip = F_->createClip("clip", relu, 1.0, 10.0);
2219 Node *tanh = F_->createTanh("tanh", clip);
2220 Node *sigmoid = F_->createSigmoid("sigmoid", tanh);
2221 SaveNode *output1 = F_->createSave("output1", clip);
2222 SaveNode *output2 = F_->createSave("output2", sigmoid);
2223 bindings_.allocate(output1->getPlaceholder());
2224 bindings_.allocate(output2->getPlaceholder());
2225
2226 // Save current function state as reference.
2227 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
2228
2229 // Split nodes.
2230 bool singleUseOnly = false;
2231 unsigned maxDepth = 10;
2232 auto splitOption = SplitNodeByNumChunks({0}, {2});
2233 SplitNodeMap splitMap;
2234 ASSIGN_VALUE_OR_FAIL_TEST(
2235 splitMap, ::glow::splitNodeRecursively(sigmoid, splitOption, maxDepth,
2236 singleUseOnly));
2237
2238 // Check node count.
2239 EXPECT_EQ(splitMap.size(), 4);
2240 EXPECT_EQ(1, splitMap.count(relu));
2241 EXPECT_EQ(1, splitMap.count(clip));
2242 EXPECT_EQ(1, splitMap.count(tanh));
2243 EXPECT_EQ(1, splitMap.count(sigmoid));
2244 EXPECT_EQ(2, splitMap[relu].size());
2245 EXPECT_EQ(2, splitMap[clip].size());
2246 EXPECT_EQ(2, splitMap[tanh].size());
2247 EXPECT_EQ(2, splitMap[sigmoid].size());
2248 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::ReluNodeKind));
2249 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::ClipNodeKind));
2250 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::TanhNodeKind));
2251 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::SigmoidNodeKind));
2252 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 2);
2253 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind), 4);
2254 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 2);
2255 checkNumericalEquivalence(0);
2256}
2257
2258/// Verify that the recursive splitting stops based on constraint.
2259TEST_F(NodeSplitting, Recursive_StopConstraint) {
2260 std::vector<dim_t> dims = {10, 10};
2261
2262 // Create network with chained unary operators.
2263 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input", false);
2264 bindings_.allocate(input)->getHandle<float>().randomize(-10.0, 10.0,
2265 mod_.getPRNG());
2266 Node *relu = F_->createRELU("relu", input);
2267 Node *clip = F_->createClip("clip", relu, 1.0, 10.0);
2268 Node *tanh = F_->createTanh("tanh", clip);
2269 Node *sigmoid = F_->createSigmoid("sigmoid", tanh);
2270 SaveNode *output1 = F_->createSave("output1", clip);
2271 SaveNode *output2 = F_->createSave("output2", sigmoid);
2272 bindings_.allocate(output1->getPlaceholder());
2273 bindings_.allocate(output2->getPlaceholder());
2274
2275 // Save current function state as reference.
2276 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
2277
2278 // Split nodes.
2279 unsigned maxDepth = 10;
2280 auto splitOption = SplitNodeByNumChunks({0}, {2});
2281 SplitNodeConstraint splitConstraint =
2282 [=](const Node *origNode, const std::vector<Node *> &splitNodes) -> bool {
2283 return (origNode->getKind() != Kinded::Kind::ClipNodeKind);
2284 };
2285 SplitNodeMap splitMap;
2286 ASSIGN_VALUE_OR_FAIL_TEST(
2287 splitMap, ::glow::splitNodeRecursively(sigmoid, &splitOption,
2288 &splitConstraint, maxDepth));
2289
2290 // Check node count.
2291 EXPECT_EQ(splitMap.size(), 2);
2292 EXPECT_EQ(0, splitMap.count(relu));
2293 EXPECT_EQ(0, splitMap.count(clip));
2294 EXPECT_EQ(1, splitMap.count(tanh));
2295 EXPECT_EQ(1, splitMap.count(sigmoid));
2296 EXPECT_EQ(2, splitMap[tanh].size());
2297 EXPECT_EQ(2, splitMap[sigmoid].size());
2298 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ReluNodeKind));
2299 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ClipNodeKind));
2300 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::TanhNodeKind));
2301 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::SigmoidNodeKind));
2302 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 2);
2303 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind), 2);
2304 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 1);
2305 checkNumericalEquivalence(0);
2306}
2307
2308/// Verify that the recursive splitting stops when reaching unsupported node.
2309TEST_F(NodeSplitting, Recursive_StopUnsupportedNode) {
2310 std::vector<dim_t> inputDims = {1, 5, 5, 4};
2311 std::vector<dim_t> outputDims = {1, 10, 10, 4};
2312 auto *input =
2313 mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
2314 bindings_.allocate(input)->getHandle<float>().randomize(-10.0, 10.0,
2315 mod_.getPRNG());
2316 auto *outTy = mod_.uniqueType(ElemKind::FloatTy, outputDims);
2317 auto *resize = F_->createResizeBilinear("resize", input, outTy);
2318 Node *relu = F_->createRELU("relu", resize);
2319 SaveNode *output = F_->createSave("output", relu);
2320 bindings_.allocate(output->getPlaceholder());
2321
2322 // Save current function state as reference.
2323 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
2324
2325 // Split nodes.
2326 unsigned maxDepth = 10;
2327 auto splitOption = SplitNodeByNumChunks({1}, {2});
2328 SplitNodeMap splitMap;
2329 ASSIGN_VALUE_OR_FAIL_TEST(
2330 splitMap, ::glow::splitNodeRecursively(relu, splitOption, maxDepth));
2331
2332 // Check node count.
2333 EXPECT_EQ(splitMap.size(), 1);
2334 EXPECT_EQ(1, splitMap.count(relu));
2335 EXPECT_EQ(0, splitMap.count(resize));
2336 EXPECT_EQ(2, splitMap[relu].size());
2337 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::ReluNodeKind));
2338 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ResizeBilinearNodeKind));
2339 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 2);
2340 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind), 2);
2341 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 1);
2342 checkNumericalEquivalence(0);
2343}
2344
2345/// Verify that the recursive splitting only attempts to split the first output
2346/// operand. In this example we verify that the recursive splitting does not
2347/// attempt to split the ArgMax operand of MaxPool.
2348TEST_F(NodeSplitting, Recursive_OnlyFirstOutputOperand) {
2349 std::vector<dim_t> inputDims = {2, 8, 8, 4};
2350 std::vector<unsigned_t> kernels = {3, 3};
2351 std::vector<unsigned_t> strides = {1, 1};
2352 std::vector<unsigned_t> pads = {1, 1, 1, 1};
2353
2354 auto *input =
2355 mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
2356 bindings_.allocate(input)->getHandle<float>().randomize(-10.0, 10.0,
2357 mod_.getPRNG());
2358 auto *pool = F_->createMaxPool("pool", input, kernels, strides, pads);
2359 Node *relu = F_->createRELU("relu", pool->getArgmax());
2360 SaveNode *result = F_->createSave("result", pool->getResult());
2361 SaveNode *argmax = F_->createSave("argmax", relu);
2362 bindings_.allocate(result->getPlaceholder());
2363 bindings_.allocate(argmax->getPlaceholder());
2364
2365 // Save current function state as reference.
2366 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
2367
2368 // Split nodes.
2369 unsigned maxDepth = 10;
2370 auto splitOption = SplitNodeByNumChunks({0}, {2});
2371 SplitNodeMap splitMap;
2372 ASSIGN_VALUE_OR_FAIL_TEST(
2373 splitMap, ::glow::splitNodeRecursively(relu, splitOption, maxDepth));
2374
2375 // Check node count.
2376 EXPECT_EQ(splitMap.size(), 1);
2377 EXPECT_EQ(1, splitMap.count(relu));
2378 EXPECT_EQ(0, splitMap.count(pool));
2379 EXPECT_EQ(2, splitMap[relu].size());
2380 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::ReluNodeKind));
2381 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::MaxPoolNodeKind));
2382 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 2);
2383 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind), 2);
2384 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 1);
2385 // Note: We do not run the inference since we do not support Relu for INT64.
2386}
2387
2388/// Verify that identical Touch nodes to not get reused. The Touch node should
2389/// be excepted from CSE because it is not safe to assume that identical Touch
2390/// nodes produce same output because the output is not initialized. The
2391/// presence of the Touch node in CSE tends to increase the lifetime of some
2392/// buffers resulting in extra copy instructions (runtime overhead) and
2393/// increased buffer lifetimes (memory overhead).
2394TEST_F(NodeSplitting, DoNotShareTouchNodes) {
2395 std::vector<dim_t> inputDims = {2, 2};
2396 auto *input =
2397 mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
2398 bindings_.allocate(input)->getHandle<float>().randomize(-10.0, 10.0,
2399 mod_.getPRNG());
2400 Node *relu1 = F_->createRELU("relu1", input);
2401 Node *relu2 = F_->createRELU("relu2", relu1);
2402 SaveNode *result = F_->createSave("result", relu2);
2403 bindings_.allocate(result->getPlaceholder());
2404
2405 // Save current function state as reference.
2406 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
2407
2408 // Split nodes.
2409 auto splitOption = SplitNodeByNumChunks({0}, {2});
2410 std::vector<Node *> splitNodes1;
2411 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes1, ::glow::splitNode(relu1, splitOption));
2412 std::vector<Node *> splitNodes2;
2413 ASSIGN_VALUE_OR_FAIL_TEST(splitNodes2, ::glow::splitNode(relu2, splitOption));
2414
2415 // Optimize function (this includes CSE).
2416 ::glow::optimize(F_, cctx_);
2417
2418 // Check node count.
2419 EXPECT_EQ(splitNodes1.size(), 2);
2420 EXPECT_EQ(splitNodes2.size(), 2);
2421 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 4);
2422 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 4);
2423 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::InsertTensorNodeKind), 4);
2424
2425 // We should have 2 Touch nodes and not 1 because we should not reuse them.
2426 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TouchNodeKind), 2);
2427}
2428