1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "BackendTestUtils.h"
18
19#include "glow/ExecutionEngine/ExecutionEngine.h"
20#include "glow/Graph/Graph.h"
21#include "glow/Graph/PlaceholderBindings.h"
22#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
23#include "glow/Support/Random.h"
24
25#include "gtest/gtest.h"
26
27#include "llvm/ADT/STLExtras.h"
28#include "llvm/Support/CommandLine.h"
29#include "llvm/Support/Signals.h"
30
31#include <functional>
32
33using namespace glow;
34using llvm::cast;
35
/// Parameter types used by the parameterized tests in this file. Each one
/// pairs a backend name with the tuple of integer sweep values produced by the
/// single ::testing::Combine() that is passed into
/// GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(): three values for
/// ThreeIntTupleConfig and four values for FourIntTupleConfig.
using ThreeIntTupleConfig = std::tuple<std::string, std::tuple<int, int, int>>;
using FourIntTupleConfig =
    std::tuple<std::string, std::tuple<int, int, int, int>>;

/// Unpack the ThreeIntTupleConfig \p CONFIG into the previously declared
/// lvalues \p BACKEND_NAME, \p PARAM1, \p PARAM2 and \p PARAM3. The temporary
/// tuple lives in its own brace scope, so the macro can be expanded several
/// times in one function without redeclaring it. The expansion is a compound
/// statement: it is valid with or without a trailing semicolon, but must not
/// be used as the unbraced body of an `if` that is followed by `else`.
#define SET_BACKEND_KIND_AND_THREE_INT_PARAMS(CONFIG, BACKEND_NAME, PARAM1,    \
                                              PARAM2, PARAM3)                  \
  {                                                                            \
    std::tuple<int, int, int> threeIntTupleParams;                             \
    std::tie(BACKEND_NAME, threeIntTupleParams) = CONFIG;                      \
    std::tie(PARAM1, PARAM2, PARAM3) = threeIntTupleParams;                    \
  }

/// Unpack the FourIntTupleConfig \p CONFIG into the previously declared
/// lvalues \p BACKEND_NAME, \p PARAM1, \p PARAM2, \p PARAM3 and \p PARAM4.
/// Same scoping and usage rules as SET_BACKEND_KIND_AND_THREE_INT_PARAMS.
#define SET_BACKEND_KIND_AND_FOUR_INT_PARAMS(CONFIG, BACKEND_NAME, PARAM1,     \
                                             PARAM2, PARAM3, PARAM4)           \
  {                                                                            \
    std::tuple<int, int, int, int> fourIntTupleParams;                         \
    std::tie(BACKEND_NAME, fourIntTupleParams) = CONFIG;                       \
    std::tie(PARAM1, PARAM2, PARAM3, PARAM4) = fourIntTupleParams;             \
  }
54
55//===--------------------------------------------------------------------===//
56// Convolution Parameter Sweep Tests
57//===--------------------------------------------------------------------===//
58
59/// Create a simple network that has a single fp convolution.
60static FunctionTensorPair
61createAndInitConvNet(glow::PlaceholderBindings &bindings,
62 glow::ExecutionEngine &EE, dim_t size, dim_t convDepth,
63 dim_t kernel, dim_t stride, dim_t pad) {
64 PseudoRNG PRNG;
65 auto &mod = EE.getModule();
66 Function *F = mod.createFunction("main");
67 auto *var = mod.createPlaceholder(ElemKind::FloatTy,
68 {1, size, size, convDepth}, "var", false);
69 bindings.allocate(var)->getHandle().initXavier(1, PRNG);
70
71 auto *conv =
72 F->createConv(bindings, "conv", var, convDepth, kernel, stride, pad, 1);
73 bindings.get(cast<Placeholder>(conv->getFilter()))->getHandle().clear(0.1);
74 bindings.get(cast<Placeholder>(conv->getBias()))->getHandle().clear(0.1);
75 auto *result = F->createSave("ret", conv);
76 auto *resultTensor = bindings.allocate(result->getPlaceholder());
77 convertPlaceholdersToConstants(F, bindings, {var, result->getPlaceholder()});
78
79 return std::make_pair(F, resultTensor);
80}
81
82/// Helper to test sweeping across a variety of configurations of a convolution
83/// by comparing the results to the Interpreter given some \p allowedError.
84/// \p config contains the backend to compare the Interpreter against, plus the
85/// specific configuration to run for this test. \p interpElemKind and \p
86/// backendElemKind are the element kinds to use for the Interpreter and
87/// backend, respectively.
88static void testParamSweepConv(ThreeIntTupleConfig config,
89 ElemKind interpElemKind,
90 ElemKind backendElemKind, float allowedError) {
91 std::string backend;
92 size_t size, depth, kernel;
93 SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, size, depth, kernel)
94
95 LOG(INFO) << "Testing Conv with size: " << size << "; depth: " << depth
96 << "; kernel: " << kernel << "\n";
97
98 auto boundF = std::bind(createAndInitConvNet, std::placeholders::_1,
99 std::placeholders::_2, size, depth, kernel,
100 /* stride */ 1, /* pad */ 0);
101 compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind,
102 allowedError, parCloneCountOpt);
103}
104
105DECLARE_STATELESS_BACKEND_TEST(ConvSweepTest, ThreeIntTupleConfig);
106
107GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(
108 SweepTest, ConvSweepTest,
109 ::testing::Combine(/* size */ ::testing::Values(5, 7, 15),
110 /* depth */ ::testing::Values(8, 64),
111 /* kernel */ ::testing::Values(1, 3)));
112
113/// Compare backend against the interpreter in Float.
114TEST_P(ConvSweepTest, ConvTest_Float) {
115 CHECK_IF_ENABLED();
116 testParamSweepConv(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, 0.0001f);
117}
118
119/// Compare backend against the interpreter in Int8.
120TEST_P(ConvSweepTest, ConvTest_Int8) {
121 CHECK_IF_ENABLED();
122 testParamSweepConv(GetParam(), ElemKind::FloatTy, ElemKind::Int8QTy, 0.045f);
123}
124
125/// Compare backend against the interpreter in FP16.
126TEST_P(ConvSweepTest, ConvTest_Float16) {
127 CHECK_IF_ENABLED();
128 testParamSweepConv(GetParam(), ElemKind::FloatTy, ElemKind::Float16Ty,
129 0.005f);
130}
131
132/// Compare backend against the interpreter in FP16.
133TEST_P(ConvSweepTest, ConvTest_BFloat16) {
134 CHECK_IF_ENABLED();
135 testParamSweepConv(GetParam(), ElemKind::FloatTy, ElemKind::BFloat16Ty,
136 0.005f);
137}
138
139//===--------------------------------------------------------------------===//
140// BatchMatMul Parameter Sweep Tests
141//===--------------------------------------------------------------------===//
142
143/// Create a simple network that has a single fp batch mat mul.
144static FunctionTensorPair
145createAndInitBatchMatMulNet(glow::PlaceholderBindings &bindings,
146 glow::ExecutionEngine &EE, dim_t N, dim_t A,
147 dim_t Z, dim_t B) {
148 PseudoRNG PRNG;
149 auto &mod = EE.getModule();
150 Function *F = mod.createFunction("main");
151 auto *LHS = mod.createPlaceholder(ElemKind::FloatTy, {N, A, Z}, "LHS", false);
152 auto *RHS = mod.createPlaceholder(ElemKind::FloatTy, {N, Z, B}, "RHS", false);
153 bindings.allocate(LHS)->getHandle().initXavier(10, PRNG);
154 bindings.allocate(RHS)->getHandle().initXavier(10, PRNG);
155
156 auto *R = F->createBatchMatMul("BMM", LHS, RHS);
157
158 auto *save = F->createSave("save", R);
159 auto *resultTensor = bindings.allocate(save->getPlaceholder());
160
161 return std::make_pair(F, resultTensor);
162}
163
164/// Helper to test sweeping across a variety of configurations of a BatchMatMul
165/// by comparing the results to the Interpreter given some \p allowedError.
166/// \p config contains the backend to compare the Interpreter against, plus the
167/// specific configuration to run for this test. \p interpElemKind and \p
168/// backendElemKind are the element kinds to use for the Interpreter and
169/// backend, respectively.
170static void testParamSweepBatchMatMul(ThreeIntTupleConfig config,
171 ElemKind interpElemKind,
172 ElemKind backendElemKind,
173 float allowedError) {
174 std::string backend;
175 size_t N, A, Z;
176 SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, N, A, Z);
177 size_t B = A;
178
179 LOG(INFO) << "\n\tTesting BatchMatMul with N: " << N << "; A: " << A
180 << "; Z: " << Z << "; B: " << B << "\n";
181
182 // Multiplying LHS {N, A, Z} by RHS {N, Z, B} to get result {N, A, B}.
183 auto boundF = std::bind(createAndInitBatchMatMulNet, std::placeholders::_1,
184 std::placeholders::_2, N, A, Z, B);
185 compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind,
186 allowedError, parCloneCountOpt);
187}
188
189DECLARE_STATELESS_BACKEND_TEST(BatchMatMulSweepTest, ThreeIntTupleConfig);
190
191GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(
192 SweepTest, BatchMatMulSweepTest,
193 ::testing::Combine(/* N */ ::testing::Values(1, 4, 16, 24),
194 /* A */ ::testing::Range(10, 16),
195 /* Z */ ::testing::Values(32, 64, 128, 256)));
196
197/// Compare backend against the interpreter in Float.
198TEST_P(BatchMatMulSweepTest, BatchMatMulTest_Float) {
199 CHECK_IF_ENABLED();
200 testParamSweepBatchMatMul(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
201 0.0001f);
202}
203
204/// Compare backend against the interpreter in Int8.
205TEST_P(BatchMatMulSweepTest, BatchMatMulTest_Int8) {
206 CHECK_IF_ENABLED();
207 testParamSweepBatchMatMul(GetParam(), ElemKind::FloatTy, ElemKind::Int8QTy,
208 0.06f);
209}
210
211/// Compare backend against the interpreter in FP16.
212TEST_P(BatchMatMulSweepTest, BatchMatMulTest_Float16) {
213 CHECK_IF_ENABLED();
214 testParamSweepBatchMatMul(GetParam(), ElemKind::FloatTy, ElemKind::Float16Ty,
215 0.005f);
216}
217
218/// Compare backend against the interpreter in FP16.
219TEST_P(BatchMatMulSweepTest, BatchMatMulTest_BFloat16) {
220 CHECK_IF_ENABLED();
221 testParamSweepBatchMatMul(GetParam(), ElemKind::FloatTy, ElemKind::BFloat16Ty,
222 0.005f);
223}
224
225//===--------------------------------------------------------------------===//
226// FullyConnected Parameter Sweep Tests
227//===--------------------------------------------------------------------===//
228
229/// Create a simple network that has a single fp FC.
230static FunctionTensorPair
231createAndInitFCNet(glow::PlaceholderBindings &bindings,
232 glow::ExecutionEngine &EE, dim_t A, dim_t Z, dim_t B) {
233 PseudoRNG PRNG;
234 auto &mod = EE.getModule();
235 Function *F = mod.createFunction("main");
236 auto *IP = mod.createPlaceholder(ElemKind::FloatTy, {A, Z}, "input", false);
237 auto *WC = mod.createConstant(ElemKind::FloatTy, {Z, B}, "weights");
238 auto *BC = mod.createConstant(ElemKind::FloatTy, {B}, "bias");
239 bindings.allocate(IP)->getHandle().randomize(-0.2, 0.2, mod.getPRNG());
240 BC->getPayloadMutable().getHandle().randomize(0, 0.000005, mod.getPRNG());
241 WC->getPayloadMutable().getHandle().randomize(-0.4, 0.4, mod.getPRNG());
242
243 auto *FC = F->createFullyConnected("FC", IP, WC, BC);
244 auto *save = F->createSave("save", FC);
245 auto *resultTensor = bindings.allocate(save->getPlaceholder());
246
247 return std::make_pair(F, resultTensor);
248}
249
250/// Helper to test sweeping across a variety of configurations of a FC by
251/// comparing the results to the Interpreter given some \p allowedError.
252/// \p config contains the backend to compare the Interpreter against, plus the
253/// specific configuration to run for this test. \p interpElemKind and \p
254/// backendElemKind are the element kinds to use for the Interpreter and
255/// backend, respectively.
256static void testParamSweepFC(ThreeIntTupleConfig config,
257 ElemKind interpElemKind, ElemKind backendElemKind,
258 float allowedError) {
259 std::string backend;
260 size_t A, Z, B;
261 SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, A, Z, B);
262
263 LOG(INFO) << "\n\tTesting FC with A: " << A << "; Z: " << Z << "; B: " << B
264 << "\n";
265
266 auto boundF = std::bind(createAndInitFCNet, std::placeholders::_1,
267 std::placeholders::_2, A, Z, B);
268 compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind,
269 allowedError, parCloneCountOpt);
270}
271
272DECLARE_STATELESS_BACKEND_TEST(FCSweepTest, ThreeIntTupleConfig);
273
274GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(
275 SweepTest, FCSweepTest,
276 ::testing::Combine(
277 /* A */ ::testing::Values(1, 4, 16, 64),
278 /* Z */ ::testing::Values(16, 128, 256, 512, 1024, 2048, 4096),
279 /* B */ ::testing::Values(1, 48, 64, 256, 1024)));
280
281/// Compare backend against the interpreter in Float.
282TEST_P(FCSweepTest, FCTest_Float) {
283 CHECK_IF_ENABLED();
284 testParamSweepFC(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, 0.0001f);
285}
286
287/// Compare backend against the interpreter in Int8.
288TEST_P(FCSweepTest, FCTest_Int8) {
289 CHECK_IF_ENABLED();
290 testParamSweepFC(GetParam(), ElemKind::FloatTy, ElemKind::Int8QTy, 0.065f);
291}
292
293/// Compare backend against the interpreter in FP16.
294TEST_P(FCSweepTest, FCTest_Float16) {
295 CHECK_IF_ENABLED();
296 testParamSweepFC(GetParam(), ElemKind::FloatTy, ElemKind::Float16Ty, 0.005f);
297}
298
299/// Compare backend against the interpreter in BFloat16.
300TEST_P(FCSweepTest, FCTest_BFloat16) {
301 CHECK_IF_ENABLED();
302 testParamSweepFC(GetParam(), ElemKind::FloatTy, ElemKind::BFloat16Ty, 0.005f);
303}
304
305//===--------------------------------------------------------------------===//
306// Concat Parameter Sweep Tests
307//===--------------------------------------------------------------------===//
308
309/// Create a simple network that has a single fp Concat.
310static FunctionTensorPair
311createAndInitConcatNet(glow::PlaceholderBindings &bindings,
312 glow::ExecutionEngine &EE, size_t numInputs,
313 size_t numDims, size_t maxLength, size_t axis) {
314 PseudoRNG PRNG;
315 auto &mod = EE.getModule();
316 Function *F = mod.createFunction("main");
317
318 // Make leading dimensions smaller than trailing. Reduces size of tests and is
319 // also in line with typical tests.
320 std::vector<dim_t> dims(numDims, maxLength);
321 for (size_t i = 0; i < numDims; i++) {
322 dims[numDims - 1 - i] /= std::pow(2, i);
323 }
324
325 std::vector<NodeValue> inputs(numInputs);
326 for (size_t i = 0; i < numInputs; i++) {
327 auto *IP = mod.createPlaceholder(ElemKind::FloatTy, dims, "input", false);
328 bindings.allocate(IP)->getHandle().randomize(-0.2, 0.2, mod.getPRNG());
329 assert(IP);
330 inputs[i] = IP->getOutput();
331 }
332
333 auto *concat = F->createConcat("concat", inputs, axis);
334 auto *save = F->createSave("save", concat);
335 auto *resultTensor = bindings.allocate(save->getPlaceholder());
336
337 return std::make_pair(F, resultTensor);
338}
339
340/// Helper to test sweeping across a variety of configurations of a Concat by
341/// comparing the results to the Interpreter given some \p allowedError.
342/// \p config contains the backend to compare the Interpreter against, plus the
343/// specific configuration to run for this test. \p interpElemKind and \p
344/// backendElemKind are the element kinds to use for the Interpreter and
345/// backend, respectively.
346static void testParamSweepConcat(FourIntTupleConfig config,
347 ElemKind interpElemKind,
348 ElemKind backendElemKind, float allowedError) {
349 std::string backend;
350 size_t numInputs, numDims, maxLength, axis;
351 SET_BACKEND_KIND_AND_FOUR_INT_PARAMS(config, backend, numInputs, numDims,
352 maxLength, axis);
353 // Exit if axis outside of numDims.
354 if (axis >= numDims) {
355 return;
356 }
357
358 LOG(INFO) << "\n\tTesting Concat with numInputs: " << numInputs
359 << "; numDims: " << numDims << "; maxLength: " << maxLength
360 << "; axis: " << axis << "\n";
361
362 auto boundF =
363 std::bind(createAndInitConcatNet, std::placeholders::_1,
364 std::placeholders::_2, numInputs, numDims, maxLength, axis);
365 compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind,
366 allowedError, parCloneCountOpt);
367}
368
369DECLARE_STATELESS_BACKEND_TEST(ConcatSweepTest, FourIntTupleConfig);
370
371GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(
372 SweepTest, ConcatSweepTest,
373 ::testing::Combine(/* numInputs */ ::testing::Values(1, 2, 4, 8, 16, 32, 64,
374 128, 192, 256),
375 /* numDims */ ::testing::Range(1, 4),
376 /* maxLength */ ::testing::Values(16, 32, 64, 128),
377 /* axis */ ::testing::Range(0, 3)));
378
379/// Compare backend against the interpreter in Float.
380TEST_P(ConcatSweepTest, ConcatTest_Float) {
381 CHECK_IF_ENABLED();
382 testParamSweepConcat(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, 0.0f);
383}
384
385/// Compare backend against the interpreter in Int8. Note that we do not use the
386/// same ElemKind for the Interpreter; this is because the backend will
387/// quantize/dequantize the input/result anyway, so the comparison wouldn't be
388/// purely on data movement.
389TEST_P(ConcatSweepTest, ConcatTest_Int8) {
390 CHECK_IF_ENABLED();
391 testParamSweepConcat(GetParam(), ElemKind::FloatTy, ElemKind::Int8QTy,
392 0.002f);
393}
394
395/// Compare backend against the interpreter in Float16. Note that we do not use
396/// the same ElemKind for the Interpreter; this is because the backend will
397/// down/up convert the input/result anyway, so the comparison wouldn't be
398/// purely on data movement.
399TEST_P(ConcatSweepTest, ConcatTest_Float16) {
400 CHECK_IF_ENABLED();
401 testParamSweepConcat(GetParam(), ElemKind::FloatTy, ElemKind::Float16Ty,
402 0.0001f);
403}
404
405/// Compare backend against the interpreter in BFloat16. Note that we do not use
406/// the same ElemKind for the Interpreter; this is because the backend will
407/// down/up convert the input/result anyway, so the comparison wouldn't be
408/// purely on data movement.
409TEST_P(ConcatSweepTest, ConcatTest_BFloat16) {
410 CHECK_IF_ENABLED();
411 testParamSweepConcat(GetParam(), ElemKind::FloatTy, ElemKind::BFloat16Ty,
412 0.0001f);
413}
414
415//===--------------------------------------------------------------------===//
416// SLWS Parameter Sweep Tests
417//===--------------------------------------------------------------------===//
418
419/// Create a simple network that has a single fp SLWS.
420static FunctionTensorPair
421createAndInitSLWSNet(glow::PlaceholderBindings &bindings,
422 glow::ExecutionEngine &EE, dim_t embeddingRows,
423 dim_t embeddingDim, dim_t numLengths, bool rowwiseQuantize,
424 bool fused, bool FP16, bool accumFP16) {
425 PseudoRNG PRNG;
426 auto &mod = EE.getModule();
427 Function *F = mod.createFunction("main");
428
429 // Initialize lengths according to the number provided by the test. Note that
430 // we arbitrarily set them between [80,120].
431 auto *lengths =
432 mod.createPlaceholder(ElemKind::Int32ITy, {numLengths}, "lengths", false);
433 auto LH = bindings.allocate(lengths)->getHandle<int32_t>();
434 LH.randomize(80, 120, mod.getPRNG());
435
436 // Get the sum of the lengths to then use as the size for indices and weights.
437 dim_t sumOfLengths = 0;
438 for (const int32_t &e : LH) {
439 sumOfLengths += e;
440 }
441
442 // Initialize indices to size of sum of lengths. Randomly set them to point
443 // somewhere inside the embedding.
444 auto *indices = mod.createPlaceholder(ElemKind::Int64ITy, {sumOfLengths},
445 "indices", false);
446 bindings.allocate(indices)->getHandle<int64_t>().randomize(
447 0, embeddingRows - 1, mod.getPRNG());
448
449 // Xavier initialize the weights with the correct data type.
450 Constant *weights;
451 if (FP16) {
452 weights =
453 mod.createConstant(ElemKind::Float16Ty, {sumOfLengths}, "weights");
454 weights->getPayloadMutable().getHandle<float16_t>().initXavier(
455 weights->getType()->size() * 2, mod.getPRNG());
456 } else {
457 weights = mod.createConstant(ElemKind::FloatTy, {sumOfLengths}, "weights");
458 weights->getPayloadMutable().getHandle<float>().initXavier(
459 weights->getType()->size() * 2, mod.getPRNG());
460 }
461
462 // Create the embedding; non-RWQ versions will simply create a Constant with
463 // it, while RWQ versions will use its data to create a RWQ Constant
464 // internally in the Node constructor.
465 Tensor embeddingT(ElemKind::FloatTy, {embeddingRows, embeddingDim});
466 embeddingT.getHandle().initXavier(embeddingT.size() * 2, mod.getPRNG());
467
468 // Create the SLWS based on provided options.
469 Node *SLWS;
470 if (!rowwiseQuantize) {
471 auto *embeddingC = mod.createConstant("embedding", std::move(embeddingT));
472 SLWS = F->createSparseLengthsWeightedSum("SLWS", embeddingC, weights,
473 indices, lengths);
474 } else {
475 if (fused) {
476 const ElemKind precision =
477 FP16 ? ElemKind::UInt8FusedFP16QTy : ElemKind::UInt8FusedQTy;
478 SLWS = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
479 "FRQSLWS", embeddingT, weights, indices, lengths, precision,
480 accumFP16);
481 } else {
482 const ElemKind precision = FP16 ? ElemKind::Float16Ty : ElemKind::FloatTy;
483 SLWS = F->createRowwiseQuantizedSparseLengthsWeightedSum(
484 "RQSLWS", embeddingT, weights, indices, lengths,
485 quantization::Schema::Asymmetric, precision, accumFP16);
486 }
487 }
488 auto *save = F->createSave("save", SLWS);
489 auto *resultTensor = bindings.allocate(save->getPlaceholder());
490
491 return std::make_pair(F, resultTensor);
492}
493
494/// Helper to test sweeping across a variety of configurations of a SLWS by
495/// comparing the results to the Interpreter given some \p allowedError.
496/// \p config contains the backend to compare the Interpreter against, plus the
497/// specific configuration to run for this test. \p interpElemKind and \p
498/// backendElemKind are the element kinds to use for the Interpreter and
499/// backend, respectively. Pass in options for the test \p rowwiseQuantize,
500/// \p fused, \p FP16, and \p accumFP16.
501static void testParamSweepSLWS(ThreeIntTupleConfig config,
502 ElemKind interpElemKind,
503 ElemKind backendElemKind, float allowedError,
504 bool rowwiseQuantize, bool fused, bool FP16,
505 bool accumFP16) {
506 std::string backend;
507 size_t embeddingRows, embeddingDim, numLengths;
508 SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, embeddingRows,
509 embeddingDim, numLengths);
510
511 LOG(INFO) << "\n\tTesting SLWS with embeddingRows: " << embeddingRows
512 << "; embeddingDim: " << embeddingDim
513 << "; numLengths: " << numLengths << "\n";
514
515 auto boundF = std::bind(createAndInitSLWSNet, std::placeholders::_1,
516 std::placeholders::_2, embeddingRows, embeddingDim,
517 numLengths, rowwiseQuantize, fused, FP16, accumFP16);
518 compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind,
519 allowedError, parCloneCountOpt);
520}
521
522DECLARE_STATELESS_BACKEND_TEST(SLWSSweepTest, ThreeIntTupleConfig);
523
524GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(
525 SweepTest, SLWSSweepTest,
526 ::testing::Combine(
527 /* embeddingRows */ ::testing::Values(100, 1000, 10000, 100000),
528 /* embeddingDim */ ::testing::Values(32, 64, 96, 128),
529 /* numLengths */ ::testing::Values(16, 32, 64, 128, 256)));
530
531/// Compare backend against the interpreter.
532TEST_P(SLWSSweepTest, SLWS_Float) {
533 CHECK_IF_ENABLED();
534 testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
535 0.000001f,
536 /* rowwiseQuantize */ false,
537 /* fused */ false, /* FP16 */ false,
538 /* accumFP16 */ false);
539}
540
541/// Compare backend against the interpreter in Float.
542TEST_P(SLWSSweepTest, RWQSLWS_Float) {
543 CHECK_IF_ENABLED();
544 testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
545 0.000001f,
546 /* rowwiseQuantize */ true,
547 /* fused */ false, /* FP16 */ false,
548 /* accumFP16 */ false);
549}
550
551/// Compare backend against the interpreter in Float.
552TEST_P(SLWSSweepTest, FRWQSLWS_Float) {
553 CHECK_IF_ENABLED();
554 testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
555 0.000001f,
556 /* rowwiseQuantize */ true,
557 /* fused */ true, /* FP16 */ false,
558 /* accumFP16 */ false);
559}
560
561/// Compare backend against the interpreter in Float.
562TEST_P(SLWSSweepTest, RWQSLWS_Float16) {
563 // Note: not currently enabled for any open-source backends, as only the
564 // Interpreter supports this.
565 CHECK_IF_ENABLED();
566 testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
567 0.000001f,
568 /* rowwiseQuantize */ true,
569 /* fused */ false, /* FP16 */ true,
570 /* accumFP16 */ false);
571}
572
573/// Compare backend against the interpreter in Float.
574TEST_P(SLWSSweepTest, FRWQSLWS_Float16) {
575 // Note: not currently enabled for any open-source backends, as only the
576 // Interpreter supports this.
577 CHECK_IF_ENABLED();
578 testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
579 0.000001f,
580 /* rowwiseQuantize */ true,
581 /* fused */ true, /* FP16 */ true,
582 /* accumFP16 */ false);
583}
584
585/// Compare backend against the interpreter in Float.
586TEST_P(SLWSSweepTest, RWQSLWS_Float16_AccumFloat16) {
587 // Note: not currently enabled for any open-source backends, as only the
588 // Interpreter supports this.
589 CHECK_IF_ENABLED();
590 testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
591 0.000001f,
592 /* rowwiseQuantize */ true,
593 /* fused */ false, /* FP16 */ true,
594 /* accumFP16 */ true);
595}
596
597/// Compare backend against the interpreter in Float.
598TEST_P(SLWSSweepTest, FRWQSLWS_Float16_AccumFloat16) {
599 // Note: not currently enabled for any open-source backends, as only the
600 // Interpreter supports this.
601 CHECK_IF_ENABLED();
602 testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
603 0.000001f,
604 /* rowwiseQuantize */ true,
605 /* fused */ true, /* FP16 */ true,
606 /* accumFP16 */ true);
607}
608
/// Entry point of the sweep-test binary: initializes googletest, installs the
/// LLVM stack-trace handler, parses the command line with LLVM's option
/// parser, and then runs every registered test.
/// \returns the result of RUN_ALL_TESTS().
int main(int argc, char **argv) {
  // googletest goes first and receives argc by address, so it can consume the
  // flags it recognizes before LLVM's parser sees the command line.
  ::testing::InitGoogleTest(&argc, argv);
  // Print a stack trace if the process is terminated by an error signal.
  llvm::sys::PrintStackTraceOnErrorSignal(argv[0]);
  // Parse the (possibly trimmed) command line for llvm::cl options.
  llvm::cl::ParseCommandLineOptions(argc, argv);
  return RUN_ALL_TESTS();
}
615