1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "BackendTestUtils.h" |
18 | |
19 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
20 | #include "glow/Graph/Graph.h" |
21 | #include "glow/Graph/PlaceholderBindings.h" |
22 | #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h" |
23 | #include "glow/Support/Random.h" |
24 | |
25 | #include "gtest/gtest.h" |
26 | |
27 | #include "llvm/ADT/STLExtras.h" |
28 | #include "llvm/Support/CommandLine.h" |
29 | #include "llvm/Support/Signals.h" |
30 | |
31 | #include <functional> |
32 | |
33 | using namespace glow; |
34 | using llvm::cast; |
35 | |
/// Parameter type of the parameterized tests here that pass three integer
/// parameters via a single ::testing::Combine() into
/// GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(). The leading
/// string is the name of the backend under test.
using ThreeIntTupleConfig = std::tuple<std::string, std::tuple<int, int, int>>;

/// Same as ThreeIntTupleConfig, for tests that sweep four integer parameters.
using FourIntTupleConfig =
    std::tuple<std::string, std::tuple<int, int, int, int>>;

/// Unpack \p CONFIG (a ThreeIntTupleConfig) into \p BACKEND_NAME and the three
/// integer parameters \p PARAM1, \p PARAM2 and \p PARAM3. Expands to several
/// statements (the last one already terminated by a semicolon) and declares
/// the local `threeIntTupleParams`, so it can be used at most once per scope.
#define SET_BACKEND_KIND_AND_THREE_INT_PARAMS(CONFIG, BACKEND_NAME, PARAM1,     \
                                              PARAM2, PARAM3)                   \
  std::tuple<int, int, int> threeIntTupleParams;                               \
  std::tie(BACKEND_NAME, threeIntTupleParams) = CONFIG;                        \
  std::tie(PARAM1, PARAM2, PARAM3) = threeIntTupleParams;

/// Unpack \p CONFIG (a FourIntTupleConfig) into \p BACKEND_NAME and the four
/// integer parameters \p PARAM1 through \p PARAM4. Like the three-parameter
/// variant, it declares a local (`fourIntTupleParams`) and so can be used at
/// most once per scope.
#define SET_BACKEND_KIND_AND_FOUR_INT_PARAMS(CONFIG, BACKEND_NAME, PARAM1,      \
                                             PARAM2, PARAM3, PARAM4)           \
  std::tuple<int, int, int, int> fourIntTupleParams;                           \
  std::tie(BACKEND_NAME, fourIntTupleParams) = CONFIG;                         \
  std::tie(PARAM1, PARAM2, PARAM3, PARAM4) = fourIntTupleParams;
54 | |
55 | //===--------------------------------------------------------------------===// |
56 | // Convolution Parameter Sweep Tests |
57 | //===--------------------------------------------------------------------===// |
58 | |
59 | /// Create a simple network that has a single fp convolution. |
60 | static FunctionTensorPair |
61 | createAndInitConvNet(glow::PlaceholderBindings &bindings, |
62 | glow::ExecutionEngine &EE, dim_t size, dim_t convDepth, |
63 | dim_t kernel, dim_t stride, dim_t pad) { |
64 | PseudoRNG PRNG; |
65 | auto &mod = EE.getModule(); |
66 | Function *F = mod.createFunction("main" ); |
67 | auto *var = mod.createPlaceholder(ElemKind::FloatTy, |
68 | {1, size, size, convDepth}, "var" , false); |
69 | bindings.allocate(var)->getHandle().initXavier(1, PRNG); |
70 | |
71 | auto *conv = |
72 | F->createConv(bindings, "conv" , var, convDepth, kernel, stride, pad, 1); |
73 | bindings.get(cast<Placeholder>(conv->getFilter()))->getHandle().clear(0.1); |
74 | bindings.get(cast<Placeholder>(conv->getBias()))->getHandle().clear(0.1); |
75 | auto *result = F->createSave("ret" , conv); |
76 | auto *resultTensor = bindings.allocate(result->getPlaceholder()); |
77 | convertPlaceholdersToConstants(F, bindings, {var, result->getPlaceholder()}); |
78 | |
79 | return std::make_pair(F, resultTensor); |
80 | } |
81 | |
82 | /// Helper to test sweeping across a variety of configurations of a convolution |
83 | /// by comparing the results to the Interpreter given some \p allowedError. |
84 | /// \p config contains the backend to compare the Interpreter against, plus the |
85 | /// specific configuration to run for this test. \p interpElemKind and \p |
86 | /// backendElemKind are the element kinds to use for the Interpreter and |
87 | /// backend, respectively. |
88 | static void testParamSweepConv(ThreeIntTupleConfig config, |
89 | ElemKind interpElemKind, |
90 | ElemKind backendElemKind, float allowedError) { |
91 | std::string backend; |
92 | size_t size, depth, kernel; |
93 | SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, size, depth, kernel) |
94 | |
95 | LOG(INFO) << "Testing Conv with size: " << size << "; depth: " << depth |
96 | << "; kernel: " << kernel << "\n" ; |
97 | |
98 | auto boundF = std::bind(createAndInitConvNet, std::placeholders::_1, |
99 | std::placeholders::_2, size, depth, kernel, |
100 | /* stride */ 1, /* pad */ 0); |
101 | compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind, |
102 | allowedError, parCloneCountOpt); |
103 | } |
104 | |
105 | DECLARE_STATELESS_BACKEND_TEST(ConvSweepTest, ThreeIntTupleConfig); |
106 | |
107 | GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST( |
108 | SweepTest, ConvSweepTest, |
109 | ::testing::Combine(/* size */ ::testing::Values(5, 7, 15), |
110 | /* depth */ ::testing::Values(8, 64), |
111 | /* kernel */ ::testing::Values(1, 3))); |
112 | |
113 | /// Compare backend against the interpreter in Float. |
114 | TEST_P(ConvSweepTest, ConvTest_Float) { |
115 | CHECK_IF_ENABLED(); |
116 | testParamSweepConv(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, 0.0001f); |
117 | } |
118 | |
119 | /// Compare backend against the interpreter in Int8. |
120 | TEST_P(ConvSweepTest, ConvTest_Int8) { |
121 | CHECK_IF_ENABLED(); |
122 | testParamSweepConv(GetParam(), ElemKind::FloatTy, ElemKind::Int8QTy, 0.045f); |
123 | } |
124 | |
125 | /// Compare backend against the interpreter in FP16. |
126 | TEST_P(ConvSweepTest, ConvTest_Float16) { |
127 | CHECK_IF_ENABLED(); |
128 | testParamSweepConv(GetParam(), ElemKind::FloatTy, ElemKind::Float16Ty, |
129 | 0.005f); |
130 | } |
131 | |
132 | /// Compare backend against the interpreter in FP16. |
133 | TEST_P(ConvSweepTest, ConvTest_BFloat16) { |
134 | CHECK_IF_ENABLED(); |
135 | testParamSweepConv(GetParam(), ElemKind::FloatTy, ElemKind::BFloat16Ty, |
136 | 0.005f); |
137 | } |
138 | |
139 | //===--------------------------------------------------------------------===// |
140 | // BatchMatMul Parameter Sweep Tests |
141 | //===--------------------------------------------------------------------===// |
142 | |
143 | /// Create a simple network that has a single fp batch mat mul. |
144 | static FunctionTensorPair |
145 | createAndInitBatchMatMulNet(glow::PlaceholderBindings &bindings, |
146 | glow::ExecutionEngine &EE, dim_t N, dim_t A, |
147 | dim_t Z, dim_t B) { |
148 | PseudoRNG PRNG; |
149 | auto &mod = EE.getModule(); |
150 | Function *F = mod.createFunction("main" ); |
151 | auto *LHS = mod.createPlaceholder(ElemKind::FloatTy, {N, A, Z}, "LHS" , false); |
152 | auto *RHS = mod.createPlaceholder(ElemKind::FloatTy, {N, Z, B}, "RHS" , false); |
153 | bindings.allocate(LHS)->getHandle().initXavier(10, PRNG); |
154 | bindings.allocate(RHS)->getHandle().initXavier(10, PRNG); |
155 | |
156 | auto *R = F->createBatchMatMul("BMM" , LHS, RHS); |
157 | |
158 | auto *save = F->createSave("save" , R); |
159 | auto *resultTensor = bindings.allocate(save->getPlaceholder()); |
160 | |
161 | return std::make_pair(F, resultTensor); |
162 | } |
163 | |
164 | /// Helper to test sweeping across a variety of configurations of a BatchMatMul |
165 | /// by comparing the results to the Interpreter given some \p allowedError. |
166 | /// \p config contains the backend to compare the Interpreter against, plus the |
167 | /// specific configuration to run for this test. \p interpElemKind and \p |
168 | /// backendElemKind are the element kinds to use for the Interpreter and |
169 | /// backend, respectively. |
170 | static void testParamSweepBatchMatMul(ThreeIntTupleConfig config, |
171 | ElemKind interpElemKind, |
172 | ElemKind backendElemKind, |
173 | float allowedError) { |
174 | std::string backend; |
175 | size_t N, A, Z; |
176 | SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, N, A, Z); |
177 | size_t B = A; |
178 | |
179 | LOG(INFO) << "\n\tTesting BatchMatMul with N: " << N << "; A: " << A |
180 | << "; Z: " << Z << "; B: " << B << "\n" ; |
181 | |
182 | // Multiplying LHS {N, A, Z} by RHS {N, Z, B} to get result {N, A, B}. |
183 | auto boundF = std::bind(createAndInitBatchMatMulNet, std::placeholders::_1, |
184 | std::placeholders::_2, N, A, Z, B); |
185 | compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind, |
186 | allowedError, parCloneCountOpt); |
187 | } |
188 | |
189 | DECLARE_STATELESS_BACKEND_TEST(BatchMatMulSweepTest, ThreeIntTupleConfig); |
190 | |
191 | GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST( |
192 | SweepTest, BatchMatMulSweepTest, |
193 | ::testing::Combine(/* N */ ::testing::Values(1, 4, 16, 24), |
194 | /* A */ ::testing::Range(10, 16), |
195 | /* Z */ ::testing::Values(32, 64, 128, 256))); |
196 | |
197 | /// Compare backend against the interpreter in Float. |
198 | TEST_P(BatchMatMulSweepTest, BatchMatMulTest_Float) { |
199 | CHECK_IF_ENABLED(); |
200 | testParamSweepBatchMatMul(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, |
201 | 0.0001f); |
202 | } |
203 | |
204 | /// Compare backend against the interpreter in Int8. |
205 | TEST_P(BatchMatMulSweepTest, BatchMatMulTest_Int8) { |
206 | CHECK_IF_ENABLED(); |
207 | testParamSweepBatchMatMul(GetParam(), ElemKind::FloatTy, ElemKind::Int8QTy, |
208 | 0.06f); |
209 | } |
210 | |
211 | /// Compare backend against the interpreter in FP16. |
212 | TEST_P(BatchMatMulSweepTest, BatchMatMulTest_Float16) { |
213 | CHECK_IF_ENABLED(); |
214 | testParamSweepBatchMatMul(GetParam(), ElemKind::FloatTy, ElemKind::Float16Ty, |
215 | 0.005f); |
216 | } |
217 | |
218 | /// Compare backend against the interpreter in FP16. |
219 | TEST_P(BatchMatMulSweepTest, BatchMatMulTest_BFloat16) { |
220 | CHECK_IF_ENABLED(); |
221 | testParamSweepBatchMatMul(GetParam(), ElemKind::FloatTy, ElemKind::BFloat16Ty, |
222 | 0.005f); |
223 | } |
224 | |
225 | //===--------------------------------------------------------------------===// |
226 | // FullyConnected Parameter Sweep Tests |
227 | //===--------------------------------------------------------------------===// |
228 | |
229 | /// Create a simple network that has a single fp FC. |
230 | static FunctionTensorPair |
231 | createAndInitFCNet(glow::PlaceholderBindings &bindings, |
232 | glow::ExecutionEngine &EE, dim_t A, dim_t Z, dim_t B) { |
233 | PseudoRNG PRNG; |
234 | auto &mod = EE.getModule(); |
235 | Function *F = mod.createFunction("main" ); |
236 | auto *IP = mod.createPlaceholder(ElemKind::FloatTy, {A, Z}, "input" , false); |
237 | auto *WC = mod.createConstant(ElemKind::FloatTy, {Z, B}, "weights" ); |
238 | auto *BC = mod.createConstant(ElemKind::FloatTy, {B}, "bias" ); |
239 | bindings.allocate(IP)->getHandle().randomize(-0.2, 0.2, mod.getPRNG()); |
240 | BC->getPayloadMutable().getHandle().randomize(0, 0.000005, mod.getPRNG()); |
241 | WC->getPayloadMutable().getHandle().randomize(-0.4, 0.4, mod.getPRNG()); |
242 | |
243 | auto *FC = F->createFullyConnected("FC" , IP, WC, BC); |
244 | auto *save = F->createSave("save" , FC); |
245 | auto *resultTensor = bindings.allocate(save->getPlaceholder()); |
246 | |
247 | return std::make_pair(F, resultTensor); |
248 | } |
249 | |
250 | /// Helper to test sweeping across a variety of configurations of a FC by |
251 | /// comparing the results to the Interpreter given some \p allowedError. |
252 | /// \p config contains the backend to compare the Interpreter against, plus the |
253 | /// specific configuration to run for this test. \p interpElemKind and \p |
254 | /// backendElemKind are the element kinds to use for the Interpreter and |
255 | /// backend, respectively. |
256 | static void testParamSweepFC(ThreeIntTupleConfig config, |
257 | ElemKind interpElemKind, ElemKind backendElemKind, |
258 | float allowedError) { |
259 | std::string backend; |
260 | size_t A, Z, B; |
261 | SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, A, Z, B); |
262 | |
263 | LOG(INFO) << "\n\tTesting FC with A: " << A << "; Z: " << Z << "; B: " << B |
264 | << "\n" ; |
265 | |
266 | auto boundF = std::bind(createAndInitFCNet, std::placeholders::_1, |
267 | std::placeholders::_2, A, Z, B); |
268 | compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind, |
269 | allowedError, parCloneCountOpt); |
270 | } |
271 | |
272 | DECLARE_STATELESS_BACKEND_TEST(FCSweepTest, ThreeIntTupleConfig); |
273 | |
274 | GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST( |
275 | SweepTest, FCSweepTest, |
276 | ::testing::Combine( |
277 | /* A */ ::testing::Values(1, 4, 16, 64), |
278 | /* Z */ ::testing::Values(16, 128, 256, 512, 1024, 2048, 4096), |
279 | /* B */ ::testing::Values(1, 48, 64, 256, 1024))); |
280 | |
281 | /// Compare backend against the interpreter in Float. |
282 | TEST_P(FCSweepTest, FCTest_Float) { |
283 | CHECK_IF_ENABLED(); |
284 | testParamSweepFC(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, 0.0001f); |
285 | } |
286 | |
287 | /// Compare backend against the interpreter in Int8. |
288 | TEST_P(FCSweepTest, FCTest_Int8) { |
289 | CHECK_IF_ENABLED(); |
290 | testParamSweepFC(GetParam(), ElemKind::FloatTy, ElemKind::Int8QTy, 0.065f); |
291 | } |
292 | |
293 | /// Compare backend against the interpreter in FP16. |
294 | TEST_P(FCSweepTest, FCTest_Float16) { |
295 | CHECK_IF_ENABLED(); |
296 | testParamSweepFC(GetParam(), ElemKind::FloatTy, ElemKind::Float16Ty, 0.005f); |
297 | } |
298 | |
299 | /// Compare backend against the interpreter in BFloat16. |
300 | TEST_P(FCSweepTest, FCTest_BFloat16) { |
301 | CHECK_IF_ENABLED(); |
302 | testParamSweepFC(GetParam(), ElemKind::FloatTy, ElemKind::BFloat16Ty, 0.005f); |
303 | } |
304 | |
305 | //===--------------------------------------------------------------------===// |
306 | // Concat Parameter Sweep Tests |
307 | //===--------------------------------------------------------------------===// |
308 | |
309 | /// Create a simple network that has a single fp Concat. |
310 | static FunctionTensorPair |
311 | createAndInitConcatNet(glow::PlaceholderBindings &bindings, |
312 | glow::ExecutionEngine &EE, size_t numInputs, |
313 | size_t numDims, size_t maxLength, size_t axis) { |
314 | PseudoRNG PRNG; |
315 | auto &mod = EE.getModule(); |
316 | Function *F = mod.createFunction("main" ); |
317 | |
318 | // Make leading dimensions smaller than trailing. Reduces size of tests and is |
319 | // also in line with typical tests. |
320 | std::vector<dim_t> dims(numDims, maxLength); |
321 | for (size_t i = 0; i < numDims; i++) { |
322 | dims[numDims - 1 - i] /= std::pow(2, i); |
323 | } |
324 | |
325 | std::vector<NodeValue> inputs(numInputs); |
326 | for (size_t i = 0; i < numInputs; i++) { |
327 | auto *IP = mod.createPlaceholder(ElemKind::FloatTy, dims, "input" , false); |
328 | bindings.allocate(IP)->getHandle().randomize(-0.2, 0.2, mod.getPRNG()); |
329 | assert(IP); |
330 | inputs[i] = IP->getOutput(); |
331 | } |
332 | |
333 | auto *concat = F->createConcat("concat" , inputs, axis); |
334 | auto *save = F->createSave("save" , concat); |
335 | auto *resultTensor = bindings.allocate(save->getPlaceholder()); |
336 | |
337 | return std::make_pair(F, resultTensor); |
338 | } |
339 | |
340 | /// Helper to test sweeping across a variety of configurations of a Concat by |
341 | /// comparing the results to the Interpreter given some \p allowedError. |
342 | /// \p config contains the backend to compare the Interpreter against, plus the |
343 | /// specific configuration to run for this test. \p interpElemKind and \p |
344 | /// backendElemKind are the element kinds to use for the Interpreter and |
345 | /// backend, respectively. |
346 | static void testParamSweepConcat(FourIntTupleConfig config, |
347 | ElemKind interpElemKind, |
348 | ElemKind backendElemKind, float allowedError) { |
349 | std::string backend; |
350 | size_t numInputs, numDims, maxLength, axis; |
351 | SET_BACKEND_KIND_AND_FOUR_INT_PARAMS(config, backend, numInputs, numDims, |
352 | maxLength, axis); |
353 | // Exit if axis outside of numDims. |
354 | if (axis >= numDims) { |
355 | return; |
356 | } |
357 | |
358 | LOG(INFO) << "\n\tTesting Concat with numInputs: " << numInputs |
359 | << "; numDims: " << numDims << "; maxLength: " << maxLength |
360 | << "; axis: " << axis << "\n" ; |
361 | |
362 | auto boundF = |
363 | std::bind(createAndInitConcatNet, std::placeholders::_1, |
364 | std::placeholders::_2, numInputs, numDims, maxLength, axis); |
365 | compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind, |
366 | allowedError, parCloneCountOpt); |
367 | } |
368 | |
369 | DECLARE_STATELESS_BACKEND_TEST(ConcatSweepTest, FourIntTupleConfig); |
370 | |
371 | GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST( |
372 | SweepTest, ConcatSweepTest, |
373 | ::testing::Combine(/* numInputs */ ::testing::Values(1, 2, 4, 8, 16, 32, 64, |
374 | 128, 192, 256), |
375 | /* numDims */ ::testing::Range(1, 4), |
376 | /* maxLength */ ::testing::Values(16, 32, 64, 128), |
377 | /* axis */ ::testing::Range(0, 3))); |
378 | |
379 | /// Compare backend against the interpreter in Float. |
380 | TEST_P(ConcatSweepTest, ConcatTest_Float) { |
381 | CHECK_IF_ENABLED(); |
382 | testParamSweepConcat(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, 0.0f); |
383 | } |
384 | |
385 | /// Compare backend against the interpreter in Int8. Note that we do not use the |
386 | /// same ElemKind for the Interpreter; this is because the backend will |
387 | /// quantize/dequantize the input/result anyway, so the comparison wouldn't be |
388 | /// purely on data movement. |
389 | TEST_P(ConcatSweepTest, ConcatTest_Int8) { |
390 | CHECK_IF_ENABLED(); |
391 | testParamSweepConcat(GetParam(), ElemKind::FloatTy, ElemKind::Int8QTy, |
392 | 0.002f); |
393 | } |
394 | |
395 | /// Compare backend against the interpreter in Float16. Note that we do not use |
396 | /// the same ElemKind for the Interpreter; this is because the backend will |
397 | /// down/up convert the input/result anyway, so the comparison wouldn't be |
398 | /// purely on data movement. |
399 | TEST_P(ConcatSweepTest, ConcatTest_Float16) { |
400 | CHECK_IF_ENABLED(); |
401 | testParamSweepConcat(GetParam(), ElemKind::FloatTy, ElemKind::Float16Ty, |
402 | 0.0001f); |
403 | } |
404 | |
405 | /// Compare backend against the interpreter in BFloat16. Note that we do not use |
406 | /// the same ElemKind for the Interpreter; this is because the backend will |
407 | /// down/up convert the input/result anyway, so the comparison wouldn't be |
408 | /// purely on data movement. |
409 | TEST_P(ConcatSweepTest, ConcatTest_BFloat16) { |
410 | CHECK_IF_ENABLED(); |
411 | testParamSweepConcat(GetParam(), ElemKind::FloatTy, ElemKind::BFloat16Ty, |
412 | 0.0001f); |
413 | } |
414 | |
415 | //===--------------------------------------------------------------------===// |
416 | // SLWS Parameter Sweep Tests |
417 | //===--------------------------------------------------------------------===// |
418 | |
419 | /// Create a simple network that has a single fp SLWS. |
420 | static FunctionTensorPair |
421 | createAndInitSLWSNet(glow::PlaceholderBindings &bindings, |
422 | glow::ExecutionEngine &EE, dim_t embeddingRows, |
423 | dim_t embeddingDim, dim_t numLengths, bool rowwiseQuantize, |
424 | bool fused, bool FP16, bool accumFP16) { |
425 | PseudoRNG PRNG; |
426 | auto &mod = EE.getModule(); |
427 | Function *F = mod.createFunction("main" ); |
428 | |
429 | // Initialize lengths according to the number provided by the test. Note that |
430 | // we arbitrarily set them between [80,120]. |
431 | auto *lengths = |
432 | mod.createPlaceholder(ElemKind::Int32ITy, {numLengths}, "lengths" , false); |
433 | auto LH = bindings.allocate(lengths)->getHandle<int32_t>(); |
434 | LH.randomize(80, 120, mod.getPRNG()); |
435 | |
436 | // Get the sum of the lengths to then use as the size for indices and weights. |
437 | dim_t sumOfLengths = 0; |
438 | for (const int32_t &e : LH) { |
439 | sumOfLengths += e; |
440 | } |
441 | |
442 | // Initialize indices to size of sum of lengths. Randomly set them to point |
443 | // somewhere inside the embedding. |
444 | auto *indices = mod.createPlaceholder(ElemKind::Int64ITy, {sumOfLengths}, |
445 | "indices" , false); |
446 | bindings.allocate(indices)->getHandle<int64_t>().randomize( |
447 | 0, embeddingRows - 1, mod.getPRNG()); |
448 | |
449 | // Xavier initialize the weights with the correct data type. |
450 | Constant *weights; |
451 | if (FP16) { |
452 | weights = |
453 | mod.createConstant(ElemKind::Float16Ty, {sumOfLengths}, "weights" ); |
454 | weights->getPayloadMutable().getHandle<float16_t>().initXavier( |
455 | weights->getType()->size() * 2, mod.getPRNG()); |
456 | } else { |
457 | weights = mod.createConstant(ElemKind::FloatTy, {sumOfLengths}, "weights" ); |
458 | weights->getPayloadMutable().getHandle<float>().initXavier( |
459 | weights->getType()->size() * 2, mod.getPRNG()); |
460 | } |
461 | |
462 | // Create the embedding; non-RWQ versions will simply create a Constant with |
463 | // it, while RWQ versions will use its data to create a RWQ Constant |
464 | // internally in the Node constructor. |
465 | Tensor embeddingT(ElemKind::FloatTy, {embeddingRows, embeddingDim}); |
466 | embeddingT.getHandle().initXavier(embeddingT.size() * 2, mod.getPRNG()); |
467 | |
468 | // Create the SLWS based on provided options. |
469 | Node *SLWS; |
470 | if (!rowwiseQuantize) { |
471 | auto *embeddingC = mod.createConstant("embedding" , std::move(embeddingT)); |
472 | SLWS = F->createSparseLengthsWeightedSum("SLWS" , embeddingC, weights, |
473 | indices, lengths); |
474 | } else { |
475 | if (fused) { |
476 | const ElemKind precision = |
477 | FP16 ? ElemKind::UInt8FusedFP16QTy : ElemKind::UInt8FusedQTy; |
478 | SLWS = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum( |
479 | "FRQSLWS" , embeddingT, weights, indices, lengths, precision, |
480 | accumFP16); |
481 | } else { |
482 | const ElemKind precision = FP16 ? ElemKind::Float16Ty : ElemKind::FloatTy; |
483 | SLWS = F->createRowwiseQuantizedSparseLengthsWeightedSum( |
484 | "RQSLWS" , embeddingT, weights, indices, lengths, |
485 | quantization::Schema::Asymmetric, precision, accumFP16); |
486 | } |
487 | } |
488 | auto *save = F->createSave("save" , SLWS); |
489 | auto *resultTensor = bindings.allocate(save->getPlaceholder()); |
490 | |
491 | return std::make_pair(F, resultTensor); |
492 | } |
493 | |
494 | /// Helper to test sweeping across a variety of configurations of a SLWS by |
495 | /// comparing the results to the Interpreter given some \p allowedError. |
496 | /// \p config contains the backend to compare the Interpreter against, plus the |
497 | /// specific configuration to run for this test. \p interpElemKind and \p |
498 | /// backendElemKind are the element kinds to use for the Interpreter and |
499 | /// backend, respectively. Pass in options for the test \p rowwiseQuantize, |
500 | /// \p fused, \p FP16, and \p accumFP16. |
501 | static void testParamSweepSLWS(ThreeIntTupleConfig config, |
502 | ElemKind interpElemKind, |
503 | ElemKind backendElemKind, float allowedError, |
504 | bool rowwiseQuantize, bool fused, bool FP16, |
505 | bool accumFP16) { |
506 | std::string backend; |
507 | size_t embeddingRows, embeddingDim, numLengths; |
508 | SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, embeddingRows, |
509 | embeddingDim, numLengths); |
510 | |
511 | LOG(INFO) << "\n\tTesting SLWS with embeddingRows: " << embeddingRows |
512 | << "; embeddingDim: " << embeddingDim |
513 | << "; numLengths: " << numLengths << "\n" ; |
514 | |
515 | auto boundF = std::bind(createAndInitSLWSNet, std::placeholders::_1, |
516 | std::placeholders::_2, embeddingRows, embeddingDim, |
517 | numLengths, rowwiseQuantize, fused, FP16, accumFP16); |
518 | compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind, |
519 | allowedError, parCloneCountOpt); |
520 | } |
521 | |
522 | DECLARE_STATELESS_BACKEND_TEST(SLWSSweepTest, ThreeIntTupleConfig); |
523 | |
524 | GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST( |
525 | SweepTest, SLWSSweepTest, |
526 | ::testing::Combine( |
527 | /* embeddingRows */ ::testing::Values(100, 1000, 10000, 100000), |
528 | /* embeddingDim */ ::testing::Values(32, 64, 96, 128), |
529 | /* numLengths */ ::testing::Values(16, 32, 64, 128, 256))); |
530 | |
531 | /// Compare backend against the interpreter. |
532 | TEST_P(SLWSSweepTest, SLWS_Float) { |
533 | CHECK_IF_ENABLED(); |
534 | testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, |
535 | 0.000001f, |
536 | /* rowwiseQuantize */ false, |
537 | /* fused */ false, /* FP16 */ false, |
538 | /* accumFP16 */ false); |
539 | } |
540 | |
541 | /// Compare backend against the interpreter in Float. |
542 | TEST_P(SLWSSweepTest, RWQSLWS_Float) { |
543 | CHECK_IF_ENABLED(); |
544 | testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, |
545 | 0.000001f, |
546 | /* rowwiseQuantize */ true, |
547 | /* fused */ false, /* FP16 */ false, |
548 | /* accumFP16 */ false); |
549 | } |
550 | |
551 | /// Compare backend against the interpreter in Float. |
552 | TEST_P(SLWSSweepTest, FRWQSLWS_Float) { |
553 | CHECK_IF_ENABLED(); |
554 | testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, |
555 | 0.000001f, |
556 | /* rowwiseQuantize */ true, |
557 | /* fused */ true, /* FP16 */ false, |
558 | /* accumFP16 */ false); |
559 | } |
560 | |
561 | /// Compare backend against the interpreter in Float. |
562 | TEST_P(SLWSSweepTest, RWQSLWS_Float16) { |
563 | // Note: not currently enabled for any open-source backends, as only the |
564 | // Interpreter supports this. |
565 | CHECK_IF_ENABLED(); |
566 | testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, |
567 | 0.000001f, |
568 | /* rowwiseQuantize */ true, |
569 | /* fused */ false, /* FP16 */ true, |
570 | /* accumFP16 */ false); |
571 | } |
572 | |
573 | /// Compare backend against the interpreter in Float. |
574 | TEST_P(SLWSSweepTest, FRWQSLWS_Float16) { |
575 | // Note: not currently enabled for any open-source backends, as only the |
576 | // Interpreter supports this. |
577 | CHECK_IF_ENABLED(); |
578 | testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, |
579 | 0.000001f, |
580 | /* rowwiseQuantize */ true, |
581 | /* fused */ true, /* FP16 */ true, |
582 | /* accumFP16 */ false); |
583 | } |
584 | |
585 | /// Compare backend against the interpreter in Float. |
586 | TEST_P(SLWSSweepTest, RWQSLWS_Float16_AccumFloat16) { |
587 | // Note: not currently enabled for any open-source backends, as only the |
588 | // Interpreter supports this. |
589 | CHECK_IF_ENABLED(); |
590 | testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, |
591 | 0.000001f, |
592 | /* rowwiseQuantize */ true, |
593 | /* fused */ false, /* FP16 */ true, |
594 | /* accumFP16 */ true); |
595 | } |
596 | |
597 | /// Compare backend against the interpreter in Float. |
598 | TEST_P(SLWSSweepTest, FRWQSLWS_Float16_AccumFloat16) { |
599 | // Note: not currently enabled for any open-source backends, as only the |
600 | // Interpreter supports this. |
601 | CHECK_IF_ENABLED(); |
602 | testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, |
603 | 0.000001f, |
604 | /* rowwiseQuantize */ true, |
605 | /* fused */ true, /* FP16 */ true, |
606 | /* accumFP16 */ true); |
607 | } |
608 | |
609 | int main(int argc, char **argv) { |
610 | ::testing::InitGoogleTest(&argc, argv); |
611 | llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); |
612 | llvm::cl::ParseCommandLineOptions(argc, argv); |
613 | return RUN_ALL_TESTS(); |
614 | } |
615 | |