/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "BackendTestUtils.h"

#include "glow/Base/TensorSerialization.h"
#include "glow/Converter/TypeAToTypeBFunctionConverter.h"
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Exporter/ONNXModelWriter.h"
#include "glow/Flags/Flags.h"
#include "glow/Graph/Graph.h"
#include "glow/Partitioner/Partitioner.h"
#include "glow/Runtime/DeferredWeightLoader.h"
#include "glow/Runtime/HostManager/HostManager.h"
#include "lib/Onnxifi/Base.h"

#include <algorithm>
#include <cmath>
#include <future>
#include <random>

#include "gtest/gtest.h"

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileSystem.h"

constexpr size_t MAX_MEMORY = 64e+9;

using namespace glow;

namespace {
llvm::cl::OptionCategory recSysTestCat("RecSys Category");

llvm::cl::opt<bool> enableStaticPlaceholderOpt(
    "enable-static-placeholder", llvm::cl::desc("Enable Static Placeholder."),
    llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(recSysTestCat));

llvm::cl::opt<unsigned> miniBatchOpt("mini-batch", llvm::cl::desc("Minibatch."),
                                     llvm::cl::Optional, llvm::cl::init(8),
                                     llvm::cl::cat(recSysTestCat));

llvm::cl::opt<unsigned> concurrentReqestsOpt(
    "concurrent-count", llvm::cl::desc("Number of concurrent requests."),
    llvm::cl::Optional, llvm::cl::init(1), llvm::cl::cat(recSysTestCat));

llvm::cl::opt<unsigned>
    repsOpt("reps", llvm::cl::desc("Number of benchmark repetitions."),
            llvm::cl::Optional, llvm::cl::init(1),
            llvm::cl::cat(recSysTestCat));

llvm::cl::opt<unsigned> embeddingDimOpt("embedding-dim",
                                        llvm::cl::desc("Embedding dim."),
                                        llvm::cl::Optional, llvm::cl::init(64),
                                        llvm::cl::cat(recSysTestCat));

llvm::cl::opt<unsigned> denseDimOpt("dense-dim", llvm::cl::desc("Dense dim."),
                                    llvm::cl::Optional, llvm::cl::init(800),
                                    llvm::cl::cat(recSysTestCat));

llvm::cl::opt<unsigned> numHiddenBottomMLPLayersOpt(
    "num-hidden-bottom-mlp-layers",
    llvm::cl::desc("Number of hidden bottom MLP layers."), llvm::cl::Optional,
    llvm::cl::init(3), llvm::cl::cat(recSysTestCat));

llvm::cl::list<unsigned> bottomMLPIntermediateDimsOpt(
    "bottom-mlp-intermediate-dims",
    llvm::cl::desc(
        "Comma-separated list of intermediate dim for each of the bottom MLP "
        "hidden layers and output layer. Will wrap around to the start of the "
        "list and reuse dimensions if less than the number of layers. If "
        "unprovided, default is 1024."),
    llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated,
    llvm::cl::cat(recSysTestCat));

llvm::cl::opt<unsigned>
    numHiddenTopMLPLayersOpt("num-hidden-top-mlp-layers",
                             llvm::cl::desc("Number of hidden top MLP layers."),
                             llvm::cl::Optional, llvm::cl::init(3),
                             llvm::cl::cat(recSysTestCat));

llvm::cl::list<unsigned> topMLPIntermediateDimsOpt(
    "top-mlp-intermediate-dims",
    llvm::cl::desc(
        "Comma-separated list of intermediate dim for each of the top MLP "
        "hidden layers and output layer. Will wrap around to the start of the "
        "list and reuse dimensions if less than the number of layers. If "
        "unprovided, default is 1024."),
    llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated,
    llvm::cl::cat(recSysTestCat));
llvm::cl::list<unsigned> lengthsMinMaxOpt(
    "lengths-min-max",
    llvm::cl::desc("Comma-separated [min, max) values to be used when "
                   "generating random lengths inputs for SLS/SLWS. If left "
                   "unspecified, will use [90, 111)."),
    llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated,
    llvm::cl::cat(recSysTestCat));
110
111llvm::cl::opt<unsigned> randomSeedContentOpt(
112 "random-seed-content",
113 llvm::cl::desc(
114 "Seed for the random data generation for indices and weights tensor"),
115 llvm::cl::Optional, llvm::cl::init(2001), llvm::cl::cat(recSysTestCat));
116
117llvm::cl::opt<unsigned> randomSeedLengthsOpt(
118 "random-seed-lengths",
119 llvm::cl::desc("Seed for the random data generation for lengths tensor"),
120 llvm::cl::Optional, llvm::cl::init(2001), llvm::cl::cat(recSysTestCat));
121
122llvm::cl::list<unsigned> tableSizesOpt(
123 "embedding-table-sizes",
124 llvm::cl::desc("Comma-separated list of embedding table sizes."),
125 llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated,
126 llvm::cl::cat(recSysTestCat));
127
128llvm::cl::list<unsigned> tableCountsOpt(
129 "embedding-table-counts",
130 llvm::cl::desc("Comma-separated list of embedding table counts, "
131 "corresponding to a count for each size listed in "
132 "embedding-table-sizes."),
133 llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated,
134 llvm::cl::cat(recSysTestCat));
135
136llvm::cl::opt<unsigned> deviceMemCapacityOpt(
137 "device-mem-capacity",
138 llvm::cl::desc("Device memory capacity in kB. Default is dependent on the "
139 "test in order to potentially force partitioning."),
140 llvm::cl::Optional, llvm::cl::init(0), llvm::cl::cat(recSysTestCat));
141
142llvm::cl::opt<unsigned> numDevicesOpt(
143 "num-devices", llvm::cl::desc("Number of devices to use for partitioning."),
144 llvm::cl::Optional, llvm::cl::init(2), llvm::cl::cat(recSysTestCat));
145
146llvm::cl::opt<unsigned> partitioningNumDevicesOpt(
147 "partitioning-num-devices",
148 llvm::cl::desc(
149 "Number of devices to override sparseNNPartitioningNumCards."),
150 llvm::cl::Optional, llvm::cl::init(1), llvm::cl::cat(recSysTestCat));
151
152llvm::cl::opt<std::string> traceDir(
153 "trace-dir",
154 llvm::cl::desc("Directory used to store Glow trace events files. If not "
155 "used, tracing is not enabled."),
156 llvm::cl::Optional, llvm::cl::cat(recSysTestCat));
157
158llvm::cl::opt<bool> dumpBinaryResults(
159 "dump-binary-results",
160 llvm::cl::desc("Dump raw binary Tensor results after execution."),
161 llvm::cl::init(false), llvm::cl::cat(recSysTestCat));
162
163llvm::cl::opt<bool> dumpModelInputs(
164 "dump-model-inputs",
165 llvm::cl::desc(
166 "Dump model and inputs into format that repro binary can run."),
167 llvm::cl::init(false), llvm::cl::cat(recSysTestCat));
168
169llvm::cl::opt<bool> dumpFinalGraph(
170 "dump-final-graph",
171 llvm::cl::desc(
172 "Call dumpDag on each Function passed to the backend for compilation."),
173 llvm::cl::init(false), llvm::cl::cat(recSysTestCat));
174
175llvm::cl::opt<bool> saturateHost("saturate-host",
176 llvm::cl::desc("Enable host saturation."),
177 llvm::cl::init(false),
178 llvm::cl::cat(recSysTestCat));
179
180llvm::cl::opt<bool> fuseScaleOffsetFp32Opt(
181 "glow_global_fused_scale_offset_fp32",
182 llvm::cl::desc(
183 "Enable converting scale/offset in sls's input data from fp16 to fp32"),
184 llvm::cl::init(false), llvm::cl::cat(recSysTestCat));
185
186llvm::cl::opt<bool> skipCorrectnessCheck(
187 "skip_correctness_check",
188 llvm::cl::desc("Skip correctness check with Interpreter backend"),
189 llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(recSysTestCat));
190} // namespace
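
// Example invocation (for illustration only; the binary name and chosen
// values are hypothetical, while the flags are the options defined above):
//
//   ./RecommendationSystemTest --mini-batch=16 --embedding-dim=32 \
//       --embedding-table-sizes=8000,6000 --embedding-table-counts=2,3 \
//       --reps=5 --concurrent-count=2 --trace-dir=/tmp/traces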
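/// Test-only deferred weight loader used below. Weights and their names are
/// registered up front via addWeight()/addName(); the runtime then walks them
/// by calling loadNextWeight() and reading getName()/getTensor(), with an
/// empty name signalling that all weights have been consumed. position_
/// starts at -1 so the first loadNextWeight() call exposes weight 0. Note
/// that pointers returned by addWeight() are only valid until the next
/// addWeight() call, since the underlying std::vector may reallocate.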
class TestDeferredWeightLoader : public DeferredWeightLoader {
public:
  Error loadNextWeight() override {
    position_++;
    return Error::success();
  }
  Error setSrc(void *loaderObject) override { return Error::success(); }

  Tensor *addWeight(TypeRef ty) {
    weights_.push_back(Tensor(ty));
    return &weights_.back();
  }

  void addName(std::string name) { names_.push_back(name); }
  void setTypeInfo(std::map<std::string, Type> info) override {}

  std::string getName() override {
    if (position_ >= int(names_.size())) {
      return "";
    }
    return names_[position_];
  }

  Tensor *getTensor() override {
    if (position_ >= int(weights_.size())) {
      return nullptr;
    }
    return &weights_[position_];
  }

private:
  std::vector<Tensor> weights_{};
  std::vector<std::string> names_{};
  int position_{-1};
};

/// Fills the tensor \p H with some stable random data with the seed \p seed
/// and the range [-scale .. scale].
static void fillStableRandomData(Handle<float> H, size_t seed,
                                 float scale = 1) {
  for (size_t i = 0, e = H.size(); i < e; i++) {
    H.raw(i) = scale * (float((int(i * 1921 + seed) % 100) - 50) / 50);
  }
}

/// Fills the tensor \p H with some stable random integers with the seed \p
/// seed and the range [\p min, \p max).
template <typename T>
static void fillStableRandomIndex(Handle<T> H, size_t seed, size_t min = 0,
                                  size_t max = 10) {
  for (size_t i = 0, e = H.size(); i < e; i++) {
    H.raw(i) = min + (int(i * 1921 + seed) % (max - min));
  }
}
template void fillStableRandomIndex(Handle<int64_t> Handle, size_t seed,
                                    size_t min, size_t max);
template void fillStableRandomIndex(Handle<int32_t> Handle, size_t seed,
                                    size_t min, size_t max);
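
// For example, fillStableRandomIndex(H, /*seed*/ 5, /*min*/ 90, /*max*/ 111)
// fills H with deterministic values in [90, 111); repeated runs (and
// different backends) see identical "random" inputs.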

/// \returns the sum of all elements in \p H.
static size_t sumOfElements(Handle<int32_t> H) {
  size_t sum = 0;
  for (size_t i = 0, e = H.size(); i < e; i++) {
    sum += H.raw(i);
  }
  return sum;
}

/// Tests a simplified Recommendation System model.
///
/// The RecSys model has four components:
///   * An initial Multilayer Perceptron acting on the inputs.
///   * Some number of Sparse Features: SparseLengthsSum nodes acting on
///     embedding tables (see https://caffe2.ai/docs/sparse-operations.html).
///   * An interaction layer bringing together the output of the bottom MLP
///     and the sparse features.
///   * A final MLP acting on the result of the interaction.
///
/// The final result is a float indicating the strength of the recommendation.
///
///
///                      +------+
///                      |Output|
///                      +--^---+
///                         |
///                     +---+---+
///                     |  TOP  |
///                     |       |
///                     |  MLP  |
///                     +---^---+
///                         |
///                         |
///             +-----------+----+
///             |  Interaction  <---------+
///     +------->               <---+     |
///     |       +---^-----^-----+   |     |
///     |           |     |         |     |
/// +---+---+     +-+-+ +-+-+     +-+-+ +-+-+
/// | Bottom|     |SLS| |SLS|     |SLS| |SLS|
/// |       |     +---+ +---+     +---+ +---+
/// |  MLP  |          Sparse Features
/// +---^---+
///     |
/// +---+---+
/// | Input |
/// +-------+
///
class RecommendationSystemTest : public BackendTest {
public:
  RecommendationSystemTest() : BackendTest(/* deviceMemory */ MAX_MEMORY) {}

protected:
  ExecutionContext context_;
  PlaceholderBindings *bindings_;
  PrecisionConfiguration precConfig_;
  PrecisionConfiguration precConfigForInterpreter_;

  // Test Config:
  dim_t miniBatch;
  dim_t embeddingDim;
  dim_t denseDim;
  std::vector<dim_t> tableSizes;
  std::vector<dim_t> bottomMLPIntermediateDims;
  std::vector<dim_t> topMLPIntermediateDims;
  size_t lengthsMin;
  size_t lengthsMax;

  // Used to configure correct precision settings:
  bool quantizeSLWSData{false};
  bool quantizeFC{false};
  bool convertToFP16{false};
  bool useFP16SLWS{false};
  bool useFP16AccumSLWS{false};

  bool convertFusedToFP16{false};
  bool convert4or8BitFusedToFP32{false};

  // Used to enable static placeholder:
  bool enableStaticPlaceholder{false};

  // Whether to use SLWS with gather of weights, instead of SLS.
  bool gatherWeights{false};

  // Used to disable Interpreter deferred weight loading, because we run
  // FBA and Interpreter tests sequentially.
  bool isInterpreter{false};

  // Partitioner config:
  uint64_t deviceMemCapacity;
  size_t numDevices;
  bool useSparseNNPartitioning{false};
  bool sparseNNPartitioningAddSLSConcats{false};
  int32_t sparseNNPartitioningNumCards{1};
  int64_t sparseNNPartitioningSLSKbytes{1000};
  int32_t sparseNNPartitioningNumCoresSLS{1};
  int32_t sparseNNPartitioningNumCoresOther{1};

  // Result from executing the unpartitioned model on the backend being tested.
  Tensor *resultTensor{nullptr};

  /// Helper that \returns intermediate dims given a provided list of dims \p
  /// providedIntermediateDims and the number of layers needed \p numLayers. If
  /// the provided list is empty then all dims will be set to
  /// \p defaultIntermediateDim. If the size of \p providedIntermediateDims is
  /// less than \p numLayers then it will wrap around and reuse
  /// \p providedIntermediateDims until \p numLayers are added to the returned
  /// vector.
  static std::vector<dim_t>
  getIntermediateDims(llvm::ArrayRef<unsigned> providedIntermediateDims,
                      unsigned numLayers, dim_t defaultIntermediateDim = 1024) {
    std::vector<dim_t> destIntermediateDims;
    std::vector<dim_t> dims(providedIntermediateDims.begin(),
                            providedIntermediateDims.end());
    if (dims.empty()) {
      dims.push_back(defaultIntermediateDim);
    }
    const size_t numProvidedDimsTop = dims.size();
    // Note: Add one extra intermediate dim, which is used by the output layer
    // of the MLP. The input layer is set based on its own input.
    for (dim_t i = 0, e = numLayers + 1; i < e; i++) {
      destIntermediateDims.push_back(dims[i % numProvidedDimsTop]);
    }
    return destIntermediateDims;
  }
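
  // For example, providedIntermediateDims = {512, 256} with numLayers = 3
  // yields {512, 256, 512, 256}: numLayers + 1 entries, wrapping around the
  // provided list, where the extra entry feeds the MLP's output layer.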

  void SetUp() override {
    bindings_ = context_.getPlaceholderBindings();

    /// Test configuration, tweak here:
    miniBatch = miniBatchOpt;
    embeddingDim = embeddingDimOpt;
    denseDim = denseDimOpt;
    lengthsMin = 90;
    lengthsMax = 111;

    if (!tableSizesOpt.empty()) {
      if (!tableCountsOpt.empty()) {
        CHECK_EQ(tableSizesOpt.size(), tableCountsOpt.size())
            << "Embedding table sizes and counts must be same length.";
        for (size_t i = 0, e = tableSizesOpt.size(); i < e; i++) {
          for (size_t j = 0, f = tableCountsOpt[i]; j < f; j++) {
            tableSizes.push_back(tableSizesOpt[i]);
          }
        }
      } else {
        tableSizes =
            std::vector<dim_t>(tableSizesOpt.begin(), tableSizesOpt.end());
      }
      // Stable randomization of the order of the tables.
      std::shuffle(tableSizes.begin(), tableSizes.end(), std::mt19937());
    } else {
      tableSizes = {8000, 6000, 7000, 9000, 12000,
                    8000, 6000, 7000, 9000, 12000};
    }

    // Set up the bottom and top MLP intermediate dimensions.
    bottomMLPIntermediateDims = getIntermediateDims(
        bottomMLPIntermediateDimsOpt, numHiddenBottomMLPLayersOpt);
    topMLPIntermediateDims = getIntermediateDims(topMLPIntermediateDimsOpt,
                                                 numHiddenTopMLPLayersOpt);

    if (!lengthsMinMaxOpt.empty()) {
      assert(lengthsMinMaxOpt.size() == 2 &&
             "If min and max are used, must be 2 values provided");
      lengthsMin = lengthsMinMaxOpt[0];
      lengthsMax = lengthsMinMaxOpt[1];
      assert(lengthsMinMaxOpt[0] < lengthsMinMaxOpt[1] && "Min must be < max");
    }

    // Create TraceContext if trace file path is provided.
    if (!traceDir.empty()) {
      context_.setTraceContext(
          glow::make_unique<TraceContext>(TraceEvent::TraceLevel::STANDARD));
    }

    // If device memory capacity is unset via command line, use 32MB by
    // default.
    deviceMemCapacity =
        (int64_t)1024 *
        ((deviceMemCapacityOpt != 0) ? deviceMemCapacityOpt : 1024 * 32);

    numDevices = numDevicesOpt;
  }

  // Dump inputs into an ONNX file that the repro binary can run.
  void dumpInputs() {
    std::stringstream ss;
    ss << "input_0.onnx";
    std::ofstream of(ss.str(), std::ios::binary);
    auto *resultPHBindings = context_.getPlaceholderBindings();
    ONNX_NAMESPACE::GraphProto inputG;
    for (auto &pair : resultPHBindings->pairs()) {
      auto *t = inputG.add_initializer();
      auto *PH = pair.first;
      const auto &inputTensor = pair.second;
      ONNXModelWriter::writeTensor(inputTensor, t,
                                   /*useGlowCustomOps*/ true);
      t->set_name(PH->getName().str());
    }
    std::string buffer;
    inputG.SerializeToString(&buffer);
    of << buffer;
  }

  // Dump outputs into an ONNX file that the repro binary can run.
  void dumpOutputs() {
    std::stringstream ss;
    ss << "output_0.onnx";
    std::ofstream of(ss.str(), std::ios::binary);
    ONNX_NAMESPACE::GraphProto outputG;
    auto *t = outputG.add_initializer();
    ONNXModelWriter::writeTensor(*resultTensor, t,
                                 /*useGlowCustomOps*/ true);
    t->set_name("save");
    std::string buffer;
    outputG.SerializeToString(&buffer);
    of << buffer;
  }

  void TearDown() override {
    if (dumpBinaryResults) {
      ASSERT_TRUE(resultTensor) << "Could not dump result tensor, was nullptr";
      llvm::SmallString<64> path;
      auto tempFileRes =
          llvm::sys::fs::createTemporaryFile("result", "bin", path);
      if (tempFileRes.value() != 0) {
        FAIL() << "Failed to create temp file to write into.";
      }
      std::cout
          << "Dumping binary results of "
          << ::testing::UnitTest::GetInstance()->current_test_info()->name()
          << " to " << path.data() << std::endl;
      TensorSerializationOptions opts;
      opts.withType = false;
      dumpTensorToBinaryFile(*resultTensor, path, opts);
    }

    if (dumpModelInputs) {
      dumpInputs();
    }

    resultTensor = nullptr;
    bindings_->clear();

    auto *traceContext = context_.getTraceContext();

    if (traceContext) {
      // If traceContext exists, that means trace data was collected and needs
      // to be dumped to a file.

      // Get the test case and test names. They will be used to name the file.
      const ::testing::TestInfo *const testInfo =
          ::testing::UnitTest::GetInstance()->current_test_info();
      std::string testName(testInfo->name());
      std::string testCaseName(testInfo->test_case_name());

      // Replace all '/' in the test case and test names with '-' to preclude
      // errors related to directories not existing.
      for (auto &c : testName) {
        if (c == '/') {
          c = '-';
        }
      }

      for (auto &c : testCaseName) {
        if (c == '/') {
          c = '-';
        }
      }

      auto traceFileName =
          strFormat("%s/%s-%s.json", traceDir.getValue().c_str(),
                    testName.c_str(), testCaseName.c_str());
      traceContext->dump(traceFileName);
    }
  }

  /// Creates a Multi-layer perceptron network consisting of start & end FCs
  /// with \p intermediateLayers hidden layers.
  ///   * All weights and biases are random.
  ///   * All internal activations are RELU.
  ///   * Parent node \p N_ has output dimension \p inputDim.
  ///   * Hidden layers have dimensions taken from \p intDims.
  ///   * Output layer has output dimension \p outputDim.
  static NodeValue createMLP(Module &mod, Function *F_, Node *N_,
                             dim_t inputDim, llvm::ArrayRef<dim_t> intDims,
                             dim_t outputDim, dim_t intermediateLayers) {
    assert(intermediateLayers > 0);

    const dim_t firstIntDim = intDims[0];

    // Type object for the internal layers.
    // Note: dimension argument is a placeholder and will get filled out by
    // each createRandomizedConstant invocation.
    auto internalType = mod.uniqueType(ElemKind::FloatTy, {1});

    /// Initial
    auto *initial_bias = createRandomizedConstant(
        mod, internalType, {firstIntDim}, "initial_bias");
    auto *initial_weight = createRandomizedConstant(
        mod, internalType, {inputDim, firstIntDim}, "initial_weight");

    FullyConnectedNode *initial_layer = F_->createFullyConnected(
        "dense", N_, initial_weight,
        initial_bias); // Output is size {MB, intermediate dim}
    NodeValue last = F_->createRELU("relu1", initial_layer);

    /// Intermediate
    for (unsigned i = 0; i < intermediateLayers; ++i) {
      // The current intermediate dimension is based on the previous FC's
      // result's trailing dimension. Thus we set the current FC's trailing
      // weight dim equal to the next FC's intermediate dimension.
      const dim_t intDim = intDims[i + 1];
      auto *intermediate_bias = createRandomizedConstant(
          mod, internalType, {intDim}, "intermediate_bias");
      auto *intermediate_weight = createRandomizedConstant(
          mod, internalType, {last.dims()[1], intDim}, "intermediate_weight");

      FullyConnectedNode *intermediate_layer = F_->createFullyConnected(
          "dense", last, intermediate_weight,
          intermediate_bias); // Output is size {MB, intDims[i]}
      last = F_->createRELU("relu2", intermediate_layer);
    }

    /// End
    auto *end_bias =
        createRandomizedConstant(mod, internalType, {outputDim}, "end_bias");
    auto *end_weight = createRandomizedConstant(
        mod, internalType, {last.dims()[1], outputDim}, "end_weight");

    FullyConnectedNode *end_layer = F_->createFullyConnected(
        "dense", last, end_weight, end_bias); // Output is size {MB, embDim}

    auto *RN = F_->createRELU("relu3", end_layer);

    return RN->getResult();
  }
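
  // For example, with the default options (dense-dim = 800, intermediate dims
  // defaulting to 1024, 3 hidden layers, embedding-dim = 64), the bottom MLP
  // built by createMLP is the FC chain
  // 800 -> 1024 -> 1024 -> 1024 -> 1024 -> 64, with a RELU after every FC.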

  /// Creates a rowwise quantized Multi-layer perceptron network consisting of
  /// start & end FCs with \p intermediateLayers hidden layers.
  ///   * All weights and biases are random. Weights are Int8Q (rowwise),
  ///     biases are Int32.
  ///   * All activations are RELU; the final RELU's output is dequantized
  ///     back to float.
  ///   * Parent node \p N_ has output dimension \p inputDim in float.
  ///   * Hidden layers have dimensions taken from \p intDims, in Int8Q
  ///     (rowwise).
  ///   * Output layer has output dimension \p outputDim in float.
  ///
  /// Quantized MLPs use RowwiseQuantizedFullyConnected Nodes, which expect:
  ///   * weights to be Float32 and convert to Int8 fused rowwise quantized
  ///     Tensors internally
  ///   * Biases are Int32 quantized.
  static NodeValue createQuantizedMLP(Module &mod, Function *F_, NodeValue N_,
                                      dim_t inputDim,
                                      llvm::ArrayRef<dim_t> intDims,
                                      dim_t outputDim,
                                      dim_t intermediateLayers) {
    // Must have intermediate layers.
    assert(intermediateLayers > 0);

    const dim_t minibatchSize = N_.dims()[0];
    const dim_t firstIntDim = intDims[0];

    // Type objects for the internal types.
    // Note: dimension argument is a placeholder and will get filled out by
    // each createRandomizedConstant invocation.
    auto internalTypeF = mod.uniqueType(ElemKind::FloatTy, {1});
    auto internalTypeQ = mod.uniqueType(ElemKind::Int8QTy, {1}, 1, 0);
    auto internalBiasType = mod.uniqueType(ElemKind::Int32QTy, {1}, 1e-11, 0);

    auto *start = F_->createQuantize(
        "mlp_quant", N_, mod.uniqueTypeWithNewShape(internalTypeQ, N_.dims()));

    /// Initial.
    auto *initial_bias = createRandomizedConstant(
        mod, internalBiasType, {firstIntDim}, "initial_bias");
    auto *initial_weight = createRandomizedConstant(
        mod, internalTypeF, {inputDim, firstIntDim}, "initial_weight");

    // Output is size {MB, intermediateDim}
    quantization::Schema rowwiseQuantSchema = useSymmetricRowwiseQuantFC
                                                  ? quantization::Symmetric
                                                  : quantization::Asymmetric;
    Node *initial_layer = F_->createRowwiseQuantizedFullyConnected(
        "dense", start, initial_weight, initial_bias,
        mod.uniqueTypeWithNewShape(internalTypeQ, {minibatchSize, firstIntDim}),
        rowwiseQuantSchema,
        /* transposeWeight */ true);

    NodeValue last = F_->createRELU("initial_relu", initial_layer);

    /// Intermediate
    for (unsigned i = 0; i < intermediateLayers; ++i) {
      // The current intermediate dimension is based on the previous FC's
      // result's trailing dimension. Thus we set the current FC's trailing
      // weight dim equal to the next FC's intermediate dimension.
      const dim_t intDim = intDims[i + 1];
      auto *intermediate_bias = createRandomizedConstant(
          mod, internalBiasType, {intDim}, "intermediate_bias");
      auto *intermediate_weight = createRandomizedConstant(
          mod, internalTypeF, {last.dims()[1], intDim}, "intermediate_weight");

      Node *intermediate_layer = F_->createRowwiseQuantizedFullyConnected(
          "dense", last, intermediate_weight, intermediate_bias,
          mod.uniqueType(ElemKind::Int8QTy, {minibatchSize, intDim}, 1.0, 0),
          rowwiseQuantSchema,
          /* transposeWeight */ true); // Output is size {MB, intDims[i]}
      last = F_->createRELU("intermediate_relu", intermediate_layer);
    }

    /// End
    auto *end_bias = createRandomizedConstant(mod, internalBiasType,
                                              {outputDim}, "end_bias");
    auto *end_weight = createRandomizedConstant(
        mod, internalTypeF, {last.dims()[1], outputDim}, "end_weight");

    // Output is size {MB, embDim}
    auto *end_layer = F_->createRowwiseQuantizedFullyConnected(
        "dense", last, end_weight, end_bias,
        mod.uniqueTypeWithNewShape(internalTypeQ, {minibatchSize, outputDim}),
        rowwiseQuantSchema,
        /* transposeWeight */ true);

    auto *RN = F_->createRELU("relu", end_layer);
    auto *DQN = F_->createDequantize("mlp_dequant", RN, ElemKind::FloatTy);

    return DQN->getResult();
  }

  /// Creates a number of Sparse tables (FP32 or Int8Q), the Indices lookup and
  /// the SparseLengthsSum Node tying it together.
  void createSparseEmbeddings(Module &mod, PlaceholderBindings &bindings_,
                              Function *F_, TestDeferredWeightLoader &loader,
                              llvm::ArrayRef<Placeholder *> lengths,
                              llvm::ArrayRef<dim_t> embSizes, dim_t embDim,
                              std::vector<NodeValue> &embeddings) {
    auto internalTypeF = mod.uniqueType(ElemKind::FloatTy, {1});

    for (unsigned int i = 0; i < lengths.size(); i++) {
      fillStableRandomIndex(
          bindings_.allocate(lengths[i])->getHandle<int32_t>(),
          randomSeedLengthsOpt, lengthsMin, lengthsMax);

      dim_t sum =
          sumOfElements(bindings_.get(lengths[i])->getHandle<int32_t>());
      auto *indices = mod.createPlaceholder(
          ElemKind::Int64ITy, {sum}, "indices" + std::to_string(i), false);
      fillStableRandomIndex(bindings_.allocate(indices)->getHandle<int64_t>(),
                            randomSeedContentOpt, 0, embSizes[i]);

      // output is size {MB, embDim}
      if (quantizeSLWSData) {
        Storage *data;
        if (!isInterpreter && enableStaticPlaceholder) {
          Placeholder *ph = createFusedRowwiseQuantizedPlaceholder(
              mod, {embSizes[i], embDim}, "data" + std::to_string(i),
              useFP16SLWS);

          ph->setStatic(true);
          auto *tensor = loader.addWeight(ph->getType());
          auto fData = Tensor(ElemKind::FloatTy, {embSizes[i], embDim});
          fData.getHandle<float>().randomize(UINT8_MIN, UINT8_MAX,
                                             mod.getPRNG());
          loader.addName("data" + std::to_string(i));

          bindings_.allocate(ph);
          updateInputPlaceholders(bindings_, {ph}, {tensor});

          data = ph;

          Tensor rwqData(ElemKind::UInt8FusedQTy,
                         {embSizes[i], embDim + 2 * (dim_t)sizeof(float)},
                         data->getType()->getScale(),
                         data->getType()->getOffset());

          quantization::tensorFusedRowwiseQuantization<float>(fData, rwqData);
          tensor->assign(&rwqData);
        } else {
          data = createRandomFusedRowwiseQuantizedConstant(
              mod, {embSizes[i], embDim}, "data" + std::to_string(i),
              useFP16SLWS);
        }

        embeddings[i] = F_->createFusedRowwiseQuantizedSparseLengthsSum(
            "RQSLWS" + std::to_string(i), data, indices, lengths[i],
            useFP16AccumSLWS);
        // Convert back to Float if we used Float16 here. Optimizer will
        // eliminate if necessary.
        if (useFP16SLWS) {
          embeddings[i] = F_->createConvertTo(
              "convert_" + embeddings[i].getNode()->getName().str(),
              embeddings[i], ElemKind::FloatTy);
        }
      } else {
        Storage *data;
        if (!isInterpreter && enableStaticPlaceholder) {
          Placeholder *ph =
              mod.createPlaceholder(ElemKind::FloatTy, {embSizes[i], embDim},
                                    "data" + std::to_string(i), false);
          ph->setStatic(true);
          auto *tensor = loader.addWeight(ph->getType());
          tensor->getHandle<float>().initXavier(tensor->getType().size() * 2,
                                                mod.getPRNG());
          loader.addName("data" + std::to_string(i));

          bindings_.allocate(ph);
          updateInputPlaceholders(bindings_, {ph}, {tensor});
          data = ph;
        } else {
          data = createRandomizedConstant(mod, internalTypeF,
                                          {embSizes[i], embDim},
                                          "data" + std::to_string(i));
        }

        embeddings[i] = F_->createSparseLengthsSum("sls" + std::to_string(i),
                                                   data, indices, lengths[i]);
      }
    }
  }
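
  // To summarize the four paths above: each embedding table is either
  // fused-rowwise-quantized or kept in FP32, and its contents either live in
  // a static Placeholder fed through the deferred weight loader (when
  // enableStaticPlaceholder is set and we are not on the Interpreter) or are
  // baked into the graph as a Constant.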

  /// Creates a number of Sparse tables (FP32 or Int8Q), the Indices lookup and
  /// the SparseLengthsWeightedSum Node tying it together.
  /// TODO: we need to quantize the data tensors for deferred weight loading.
  void createSparseWeightedGatherEmbeddings(
      Module &mod, PlaceholderBindings &bindings_, Function *F_,
      TestDeferredWeightLoader &loader, llvm::ArrayRef<Placeholder *> lengths,
      llvm::ArrayRef<dim_t> tableSizes, dim_t embeddingDim,
      std::vector<NodeValue> &embeddings, uint32_t weightsSize = 1000) {
    for (size_t i = 0; i < lengths.size(); i++) {
      fillStableRandomIndex(
          bindings_.allocate(lengths[i])->getHandle<int32_t>(),
          randomSeedLengthsOpt, lengthsMin, lengthsMax);

      dim_t sum =
          sumOfElements(bindings_.get(lengths[i])->getHandle<int32_t>());
      auto *indices = mod.createPlaceholder(
          ElemKind::Int64ITy, {sum}, "indices" + std::to_string(i), false);
      fillStableRandomIndex(bindings_.allocate(indices)->getHandle<int64_t>(),
                            randomSeedContentOpt, 0, tableSizes[i]);

      // Should be able to pass weights - fix later. Currently, just a
      // randomized constant.
      Constant *weightsConst = createRandomizedConstant(
          mod, mod.uniqueType(ElemKind::FloatTy, {weightsSize}), {weightsSize},
          "weights" + std::to_string(i));

      auto *weightIndices =
          mod.createPlaceholder(ElemKind::Int32ITy, {sum},
                                "weight_indices" + std::to_string(i), false);
      fillStableRandomIndex(
          bindings_.allocate(weightIndices)->getHandle<int32_t>(),
          randomSeedContentOpt, 0, weightsSize - 1);

      auto *weights = F_->createGather("weight_gather" + std::to_string(i),
                                       weightsConst, weightIndices, 0);

      // output is size {MB, embeddingDim_}
      if (quantizeSLWSData) {
        Storage *data;
        if (!isInterpreter && enableStaticPlaceholder) {
          Placeholder *ph = createFusedRowwiseQuantizedPlaceholder(
              mod, {tableSizes[i], embeddingDim}, "data" + std::to_string(i),
              useFP16SLWS);
          ph->setStatic(true);
          auto *tensor = loader.addWeight(ph->getType());
          tensor->getHandle<uint8_t>().randomize(UINT8_MIN, UINT8_MAX,
                                                 mod.getPRNG());

          loader.addName("data" + std::to_string(i));

          bindings_.allocate(ph);
          updateInputPlaceholders(bindings_, {ph}, {tensor});

          data = ph;
        } else {
          data = createRandomFusedRowwiseQuantizedConstant(
              mod, {tableSizes[i], embeddingDim}, "data" + std::to_string(i),
              useFP16SLWS);
        }

        embeddings[i] = F_->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
            "RQSLWS" + std::to_string(i), data, weights, indices, lengths[i],
            useFP16AccumSLWS);
        // Convert back to Float if we used Float16 here. Optimizer will
        // eliminate if necessary.
        if (useFP16SLWS) {
          embeddings[i] = F_->createConvertTo(
              "convert_" + embeddings[i].getNode()->getName().str(),
              embeddings[i], ElemKind::FloatTy);
        }
      } else {
        Storage *data;
        if (!isInterpreter && enableStaticPlaceholder) {
          Placeholder *ph = mod.createPlaceholder(
              ElemKind::FloatTy, {tableSizes[i], embeddingDim},
              "data" + std::to_string(i), false);
          ph->setStatic(true);
          auto *tensor = loader.addWeight(ph->getType());
          tensor->getHandle<float>().initXavier(tensor->getType().size() * 2,
                                                mod.getPRNG());
          loader.addName("data" + std::to_string(i));

          bindings_.allocate(ph);
          updateInputPlaceholders(bindings_, {ph}, {tensor});
          data = ph;
        } else {
          data = createRandomizedConstant(
              mod,
              mod.uniqueType(ElemKind::FloatTy, {tableSizes[i], embeddingDim}),
              {tableSizes[i], embeddingDim}, "data" + std::to_string(i));
        }

        embeddings[i] = F_->createSparseLengthsWeightedSum(
            "slws" + std::to_string(i), data, weights, indices, lengths[i]);
      }
    }
  }

  /// Builds a simple graph, \returns the Tensor output of the graph.
  Tensor *createSimpleRecSysGraph(Module &mod, PlaceholderBindings &bindings,
                                  Function *F, TestDeferredWeightLoader &loader,
                                  llvm::ArrayRef<dim_t> embSizes,
                                  dim_t embDim) {
    EXPECT_EQ(tableSizes.size(), embSizes.size());

    // Create the tables.
    std::vector<Placeholder *> lengths(tableSizes.size());
    for (unsigned int i = 0; i < lengths.size(); i++) {
      lengths[i] = mod.createPlaceholder(ElemKind::Int32ITy, {miniBatch},
                                         "SL" + std::to_string(i), false);
    }

    auto *denseData = mod.createPlaceholder(ElemKind::FloatTy,
                                            {miniBatch, denseDim}, "denseData",
                                            false); // denseDim can be anything

    // First Dense embedding
    fillStableRandomData(bindings.allocate(denseData)->getHandle(),
                         randomSeedContentOpt, 0.001);
    NodeValue bottomMLP;
    if (quantizeFC) {
      bottomMLP = createQuantizedMLP(mod, F, denseData, denseData->dims()[1],
                                     bottomMLPIntermediateDims, embDim,
                                     numHiddenBottomMLPLayersOpt);
    } else {
      bottomMLP = createMLP(mod, F, denseData, denseData->dims()[1],
                            bottomMLPIntermediateDims, embDim,
                            numHiddenBottomMLPLayersOpt);
    }

    // Sparse Embeddings
    std::vector<NodeValue> embeddings(lengths.size());
    if (gatherWeights) {
      createSparseWeightedGatherEmbeddings(mod, bindings, F, loader, lengths,
                                           embSizes, embDim, embeddings);
    } else {
      createSparseEmbeddings(mod, bindings, F, loader, lengths, embSizes,
                             embDim, embeddings);
    }

    // Interacting sparse and dense
    embeddings.push_back(bottomMLP);
    std::cout << "Number of embeddings concatenated: " << embeddings.size()
              << std::endl;
    auto *CN = F->createConcat("concat", embeddings,
                               1); // Output is size {MB, embDim*n}
    auto *reshaped =
        F->createReshape("reshape", CN,
                         {bottomMLP.dims()[0], (dim_t)embeddings.size(),
                          embDim}); // {MB, n, embDim}
    auto *transposed =
        F->createTranspose("transpose", reshaped, {0, 2, 1}); // {MB, embDim, n}
    auto *dot = F->createBatchMatMul("dot_products", reshaped,
                                     transposed); // {MB, n, n}
    auto *reshapeDot = F->createReshape(
        "reshapeDot", dot,
        {bottomMLP.dims()[0],
         (dim_t)(embeddings.size() * embeddings.size())}); // {MB, n^2}
    NodeValue interact = F->createConcat("interact", {reshapeDot, bottomMLP},
                                         1); // {MB, n^2 + embDim}

    // MLP at the top
    Node *topMLP;
    if (quantizeFC) {
      topMLP = createQuantizedMLP(mod, F, interact, interact.dims()[1],
                                  topMLPIntermediateDims,
                                  /* outputDim */ 1, numHiddenTopMLPLayersOpt);
    } else {
      topMLP = createMLP(mod, F, interact, interact.dims()[1],
                         topMLPIntermediateDims,
                         /* outputDim */ 1, numHiddenTopMLPLayersOpt);
    }

    // Output
    auto *save = F->createSave("save", topMLP);

    return bindings.allocate(save->getPlaceholder());
  }
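
  // Worked shape example with the default options (10 tables,
  // embedding-dim = 64, mini-batch = 8): n = 11 embeddings (10 SLS outputs
  // plus the bottom MLP), so concat -> {8, 704}, reshape -> {8, 11, 64},
  // batched dot products -> {8, 11, 11}, flatten -> {8, 121}, and the final
  // concat with the bottom MLP -> {8, 185} feeding the top MLP.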

  /// Set up the precision configuration. This will be used for all
  /// compilations which are compared to (Interpreter/Partitioned).
  void setupPrecisionConfig() {
    if (convertToFP16) {
      precConfig_.convertToFP16 = convertToFP16;
      precConfig_.convertFusedToFP16 = convertFusedToFP16;
      precConfig_.convert4BitFusedToFP32 = convert4or8BitFusedToFP32;
      precConfig_.convert8BitFusedToFP32 = convert4or8BitFusedToFP32;
      // Note: never convert RWQ-SLWS here; the creator for
      // precisionForNonDataSLWS already created the node with the correct
      // precision.
      precConfig_.precisionModeKindSet.insert(
          Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind);
      precConfig_.precisionModeKindSet.insert(
          Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind);
    }
    if (fuseScaleOffsetFp32Opt) {
      precConfig_.convert4BitFusedToFP32 = fuseScaleOffsetFp32Opt;
      precConfig_.convert8BitFusedToFP32 = fuseScaleOffsetFp32Opt;
    }
  }

  /// Set up the precision configuration for the Interpreter.
  void setupPrecisionConfigforInterpreter() {
    if (convertToFP16) {
      precConfigForInterpreter_.convertToFP16 = convertToFP16;
      precConfigForInterpreter_.convertFusedToFP16 = convertToFP16;
      // Note: never convert RWQ-SLWS here; the creator for
      // precisionForNonDataSLWS already created the node with the correct
      // precision.
      precConfigForInterpreter_.precisionModeKindSet.insert(
          Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind);
      precConfigForInterpreter_.precisionModeKindSet.insert(
          Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind);
    }
    if (fuseScaleOffsetFp32Opt) {
      precConfigForInterpreter_.convert4BitFusedToFP32 = fuseScaleOffsetFp32Opt;
      precConfigForInterpreter_.convert8BitFusedToFP32 = fuseScaleOffsetFp32Opt;
    }
  }
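  // Each rep dispatches `concurrent-count` requests of `mini-batch` items, so
  // per-request runtime is t / concurrentReqestsOpt and QPS is
  // miniBatchOpt * concurrentReqestsOpt / t. The summary line reports the
  // median and min runtimes (and corresponding QPS) across all reps.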
  void printPerfSummary(std::vector<double> times) {
    std::cout << "_,benchName,concurrent-count,runtime,QPS" << std::endl;
    for (auto t : times) {
      auto qps = miniBatchOpt / t * concurrentReqestsOpt;
      std::cout << "BenchResult,RecommendationSystemTest,"
                << (unsigned)concurrentReqestsOpt << ","
                << t / concurrentReqestsOpt << "," << qps << std::endl;
    }
    double min = *(std::min_element(times.begin(), times.end()));
    dim_t midElt = times.size() / 2;
    std::nth_element(times.begin(), times.begin() + midElt, times.end());
    double median = times[midElt];
    double medianRuntime = median / ((double)concurrentReqestsOpt);
    double minRuntime = min / ((double)concurrentReqestsOpt);
    std::cout << "_,benchName,reps,concurrent-count,medianRuntime,minRuntime,"
                 "medianQPS,maxQPS"
              << std::endl;
    std::cout << "BenchSummary,RecommendationSystemTest," << (unsigned)repsOpt
              << "," << (unsigned)concurrentReqestsOpt << "," << medianRuntime
              << "," << minRuntime << "," << miniBatchOpt / medianRuntime << ","
              << miniBatchOpt / minRuntime << std::endl;
  }

  void testRecSys(bool checkConcat = false) {
    assert((!useFP16AccumSLWS || useFP16SLWS) &&
           "Can only use FP16 accumulation when using FP16 precision.");
    isInterpreter = false;
    setupPrecisionConfig();

    // Generate the network.
    std::unique_ptr<Module> mod(new Module);
    TestDeferredWeightLoader loader;

    F_ = mod->createFunction("main");
    resultTensor = createSimpleRecSysGraph(*mod.get(), *bindings_, F_, loader,
                                           tableSizes, embeddingDim);

    Placeholder *concatPH = nullptr;
    if (checkConcat) {
      // Add an observer node after concat.
      auto *CN = F_->getNodeByName("concat");
      auto *saveConcat = F_->createSave("after_concat_data", CN);
      concatPH = saveConcat->getPlaceholder();
    }
    if (dumpModelInputs) {
      // Dump the model into a zip file that the repro binary can run.
      glow::onnxifi::saveOnnxifiModel(F_);
    }
    auto configs =
        runtime::generateDeviceConfigs(1, getBackendName(), MAX_MEMORY);
    std::unique_ptr<HostManager> hostManager(
        new HostManager(std::move(configs)));

    DeferredLoader()->registerLoader(&loader);

    CompilationContext cctx;
    if (enableStaticPlaceholder) {
      cctx.optimizationOpts.foldStaticPlaceholderConversions = true;
    }
    cctx.precisionConfig = precConfig_;
    cctx.deferredWeightLoader = &loader;
    cctx.dumpFinalGraph = dumpFinalGraph;
    EXIT_ON_ERR(hostManager->addNetwork(std::move(mod), cctx));

    // Run the graph.
    std::vector<double> times(repsOpt);
    for (size_t i = 0; i < repsOpt; i++) {
      auto start = std::chrono::high_resolution_clock::now();
      dispatchInference("main", hostManager.get(), context_,
                        concurrentReqestsOpt);
      auto end = std::chrono::high_resolution_clock::now();
      auto duration = std::chrono::duration<double>(end - start).count();
      times[i] = duration;
    }

    printPerfSummary(times);

    // NaNs are a sign of something gone wrong. Always verify there aren't any
    // in the result.
    auto resultTensorH = resultTensor->getHandle();
    for (size_t i = 0, e = resultTensorH.size(); i < e; i++) {
      EXPECT_FALSE(std::isnan(resultTensorH.raw(i)));
    }

    if (checkConcat) {
      // Get result and verify.
      EXPECT_EQ(resultTensor->size(), miniBatch);

      auto *concatT = bindings_->get(concatPH);
      auto concatH = concatT->getHandle();
      // Check that intermediate concat results didn't overflow.
      std::cout << "Intermediate concats" << std::endl;
      concatH.dump();
      for (int i = 0, e = concatH.size(); i < e; ++i) {
        EXPECT_LE(fabs(concatH.raw(i)), 100);
      }

      std::cout << "Result of prediction" << std::endl;
      std::cout << resultTensorH.size() << std::endl;
      resultTensorH.dump();
      for (int i = 0, e = resultTensorH.size(); i < e; ++i) {
        EXPECT_GE(resultTensorH.raw(i), 0.0);
      }
    }

    if (dumpModelInputs) {
      dumpOutputs();
    }

    // Undeploy the network.
    CHECK(!ERR_TO_BOOL(hostManager->removeNetwork("main")))
        << "Could not remove the network";
    // Free memory.
    hostManager.reset();
    mod.reset();

    // Compare against the Interpreter if we're not already executing on it.
    if (!skipCorrectnessCheck && getBackendName() != "Interpreter") {
      compareAgainstInterpreter();
    } else {
      std::cout << "Skip correctness check with Interpreter backend"
                << std::endl;
    }
  }

  /// Run on the Interpreter and compare the result to previous result.
  void compareAgainstInterpreter() {
    isInterpreter = true;
    setupPrecisionConfigforInterpreter();

    ExecutionContext contextI;
    // Create a new module for the interpreter run.
    std::unique_ptr<Module> modI(new Module);
    TestDeferredWeightLoader loaderI;
    auto *IF = modI->createFunction("main");
    PlaceholderBindings *bindingsI = contextI.getPlaceholderBindings();
    Tensor *resultIT = createSimpleRecSysGraph(*modI, *bindingsI, IF, loaderI,
                                               tableSizes, embeddingDim);
    bindingsI->allocate(modI->getPlaceholders());

    // Set device memory to 64GB to prevent partitioning. We are using the
    // Interpreter's result just as a reference result to compare against.
    auto configs = generateDeviceConfigs(1, "Interpreter", MAX_MEMORY);
    std::unique_ptr<HostManager> hostManager(
        new HostManager(std::move(configs)));

    DeferredLoader()->registerLoader(&loaderI);

    // Use the same precision transformation for compilation.
    CompilationContext cctx;
    cctx.precisionConfig = precConfigForInterpreter_;
    cctx.deferredWeightLoader = &loaderI;
    EXIT_ON_ERR(hostManager->addNetwork(std::move(modI), cctx));
    dispatchInference("main", hostManager.get(), contextI,
                      concurrentReqestsOpt);

    assert(resultTensor && "Must run and set resultTensor before comparing "
                           "against the Interpreter.");
    EXPECT_TRUE(resultIT->isEqual(*resultTensor, 0.005));
  }

  /// Create partitions to run and compare results.
  void testPartitionedRecSys(size_t numDevices, size_t memSize,
                             ExecutionContext &context) {
    isInterpreter = false;
    // Result tensors are reused below, so create a local copy.
    Tensor referenceResultT = resultTensor->clone();
    // Generate configs and create a new HostManager for testing partitioning.
    auto configs = generateDeviceConfigs(numDevices, getBackendName(), memSize);
    std::unique_ptr<HostManager> hostManager(
        new HostManager(std::move(configs)));

    // Create a new module and placeholderBindings to run on the partitioning
    // HostManager.
    PlaceholderBindings bindingsP;
    std::unique_ptr<Module> modP(new Module);
    TestDeferredWeightLoader loaderP;
    // The HostManager will consume the unique_ptr below, so grab a raw
    // pointer to the module first so we can verify partitioning.
    Module *rawModule = modP.get();
    auto *funcP = modP->createFunction("main");
    createSimpleRecSysGraph(*modP, bindingsP, funcP, loaderP, tableSizes,
                            embeddingDim);

    assert(memSize > 0 && "Must set partitionerPerDeviceMemCapacity > 0.");
    assert(numDevices > 0 && "Must set partitionerNumDevices > 0.");
    std::cout << numDevices << " devices of size " << memSize << "\n";

    DeferredLoader()->registerLoader(&loaderP);

    // Use the same precision transformation for compilation.
    CompilationContext cctx;
    if (enableStaticPlaceholder) {
      cctx.optimizationOpts.foldStaticPlaceholderConversions = true;
    }
    cctx.precisionConfig = precConfig_;
    cctx.deferredWeightLoader = &loaderP;
    cctx.optimizationOpts.useSparseNNPartitioningScheme =
        useSparseNNPartitioning;
    cctx.optimizationOpts.sparseNNPartitioningAddSLSConcats =
        sparseNNPartitioningAddSLSConcats;
    cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards =
        sparseNNPartitioningNumCards;
    cctx.optimizationOpts.sparseNNPartitioningSchemeSLSTableKBytesPerCard =
        sparseNNPartitioningSLSKbytes;
    cctx.optimizationOpts.sparseNNPartitioningSchemeNumCoresSLS =
        sparseNNPartitioningNumCoresSLS;
    cctx.optimizationOpts.sparseNNPartitioningSchemeNumCoresOther =
        sparseNNPartitioningNumCoresOther;
    cctx.dumpFinalGraph = dumpFinalGraph;
    cctx.saturateHost = saturateHost;
    EXIT_ON_ERR(hostManager->addNetwork(std::move(modP), cctx));
    std::cout << "Partitions = " << rawModule->getFunctions().size()
              << std::endl;

    // Run the partitioned graph and compare the results.
    auto &bindings = *context.getPlaceholderBindings();
    bindings.clear();
    bindings.allocate(rawModule->getPlaceholders());
    bindingsP.allocate(rawModule->getPlaceholders());
    for (const auto &PH : bindingsP.pairs()) {
      bindingsP.copyToTarget(PH.first->getName(), bindings);
    }

    dispatchInference("main", hostManager.get(), context, concurrentReqestsOpt);

    Tensor *resultTensorP =
        bindings.get(bindings.getPlaceholderByNameSlow("save"));
    if (enableStaticPlaceholder) {
      EXPECT_TRUE(referenceResultT.isEqual(*resultTensorP, 0.005));
    } else {
      EXPECT_TRUE(referenceResultT.isEqual(*resultTensorP));
    }
  }

  /// Test SparseLengthsSum independently.
  void testSLSQuant() {
    isInterpreter = false;
    std::unique_ptr<Module> mod(new Module);
    TestDeferredWeightLoader loader;
    F_ = mod->createFunction("main");
    std::vector<Placeholder *> sparseLengths(1);
    sparseLengths[0] =
        mod->createPlaceholder(ElemKind::Int32ITy, {miniBatch}, "SL0", false);

    std::vector<NodeValue> embeddings(sparseLengths.size());
    createSparseEmbeddings(*mod.get(), *bindings_, F_, loader, sparseLengths,
                           tableSizes, embeddingDim, embeddings);

    auto *save = F_->createSave("save", embeddings[0]);
    Tensor *resultTensorLocal = bindings_->allocate(save->getPlaceholder());

    DeferredLoader()->registerLoader(&loader);

    // Use the same precision transformation for compilation.
    CompilationContext cctx;
    if (enableStaticPlaceholder) {
      cctx.optimizationOpts.foldStaticPlaceholderConversions = true;
    }
    cctx.precisionConfig = precConfig_;
    cctx.deferredWeightLoader = &loader;
    auto configs = generateDeviceConfigs(1, getBackendName(), MAX_MEMORY);
    std::unique_ptr<HostManager> hostManager(
        new HostManager(std::move(configs)));
    EXIT_ON_ERR(hostManager->addNetwork(std::move(mod), cctx));

    // Run the graph.
    dispatchInference("main", hostManager.get(), context_,
                      concurrentReqestsOpt);

    // TODO: for now we only check the output dimension; contents are ignored.
    EXPECT_EQ(resultTensorLocal->size(), miniBatch * embeddingDim);
    resultTensorLocal->getHandle().dump();
  }
};

/// Standard Tests
/// These tests have the following options:
///   * quantizeSLWSData enables Int8 Fused Rowwise Quantization for the Sparse
///     Embeddings (Int8 quantized values with float scale and offset).
///   * quantizeFC enables Int8 Fused Rowwise Quantization for FC weights and
///     activations inside the MLPs.
///   * convertToFP16 walks the graph at the end of constructing the graph and
///     converts all FP32 nodes & tensors to FP16, meaning the graph will use
///     FP16 for internal weights, biases and activations (when not already
///     Int8 quantized). Inputs and outputs are still FP32 but are immediately
///     dropped to FP16 precision at the beginning of the graph.
///   * useFP16SLWS represents whether to use Float16 for non-data
///     inputs/outputs for SLWS and SLS Nodes, and for data per-row scale and
///     offset.
///   * useFP16AccumSLWS represents whether to use Float16 accumulation for
///     SLWS and SLS Nodes. Note this should only be used if useFP16SLWS.
1285
1286/// Everything in FP32.
1287TEST_P(RecommendationSystemTest, RecSys_FP32) {
1288 CHECK_IF_ENABLED();
1289
1290 quantizeSLWSData = false;
1291 useFP16SLWS = false;
1292 useFP16AccumSLWS = false;
1293 quantizeFC = false;
1294 convertToFP16 = false;
1295
1296 testRecSys();
1297}
1298
1299// RecSys_FP32 with deferred weight loading.
1300TEST_P(RecommendationSystemTest, RecSys_FP32_Deferred) {
1301 CHECK_IF_ENABLED();
1302
1303 quantizeSLWSData = false;
1304 useFP16SLWS = false;
1305 useFP16AccumSLWS = false;
1306 quantizeFC = false;
1307 convertToFP16 = true;
1308 enableStaticPlaceholder = true;
1309 convertFusedToFP16 = false;
1310 convert4or8BitFusedToFP32 = true;
1311
1312 testRecSys();
1313}
1314
1315/// Rowwise quantize the SLWS and FC; everything else in FP32.
1316TEST_P(RecommendationSystemTest, RecSys_RWQuantized_SLWS_FC) {
1317 CHECK_IF_ENABLED();
1318
1319 quantizeSLWSData = true;
1320 useFP16SLWS = false;
1321 useFP16AccumSLWS = false;
1322 quantizeFC = true;
1323 convertToFP16 = false;
1324
1325 testRecSys();
1326}
1327
1328// RecSys_RWQuantized_SLWS_FC with deferred weight loading.
1329TEST_P(RecommendationSystemTest, RecSys_RWQuantized_SLWS_FC_Deferred) {
1330 CHECK_IF_ENABLED();
1331
1332 quantizeSLWSData = true;
1333 useFP16SLWS = false;
1334 useFP16AccumSLWS = false;
1335 quantizeFC = true;
1336
1337 enableStaticPlaceholder = true;
1338 convertToFP16 = true;
1339 convertFusedToFP16 = false;
1340 convert4or8BitFusedToFP32 = true;
1341
1342 testRecSys();
1343}
1344
1345/// Rowwise quantize the SLWS; everything else in FP32.
1346TEST_P(RecommendationSystemTest, RecSys_RWQuantized_SLWS) {
1347 CHECK_IF_ENABLED();
1348
1349 quantizeSLWSData = true;
1350 useFP16SLWS = false;
1351 useFP16AccumSLWS = false;
1352 quantizeFC = false;
1353 convertToFP16 = false;
1354
1355 testRecSys();
1356}
1357
1358// RecSys_RWQuantized_SLWS with deferred weight loading.
1359TEST_P(RecommendationSystemTest, RecSys_RWQuantized_SLWS_Deferred) {
1360 CHECK_IF_ENABLED();
1361
1362 quantizeSLWSData = true;
1363 useFP16SLWS = false;
1364 useFP16AccumSLWS = false;
1365 quantizeFC = false;
1366
1367 enableStaticPlaceholder = true;
1368 convertToFP16 = true;
1369 convertFusedToFP16 = false;
1370 convert4or8BitFusedToFP32 = true;
1371
1372 testRecSys();
1373}
1374
1375/// Rowwise quantize the SLWS and FC; everything else in FP16.
1376TEST_P(RecommendationSystemTest, RecSys_RWQuantized_SLWS_FC_FP16) {
1377 CHECK_IF_ENABLED();
1378
1379 quantizeSLWSData = true;
1380 useFP16SLWS = false;
1381 useFP16AccumSLWS = false;
1382 quantizeFC = true;
1383 convertToFP16 = true;
1384 convertFusedToFP16 = true;
1385
1386 testRecSys();
1387}
1388
1389/// Rowwise quantize the SLWS; everything else in FP16.
1390TEST_P(RecommendationSystemTest, RecSys_RWQuantized_SLWS_FP16) {
1391 CHECK_IF_ENABLED();
1392
1393 quantizeSLWSData = true;
1394 useFP16SLWS = false;
1395 useFP16AccumSLWS = false;
1396 quantizeFC = false;
1397 convertToFP16 = true;
1398 convertFusedToFP16 = true;
1399
1400 testRecSys();
1401}
1402
1403// RecSys_RWQuantized_SLWS_FP16 with deferred weight loading.
1404TEST_P(RecommendationSystemTest, RecSys_RWQuantized_SLWS_FP16_Deferred) {
1405 CHECK_IF_ENABLED();
1406
1407 quantizeSLWSData = true;
1408 useFP16SLWS = false;
1409 useFP16AccumSLWS = false;
1410 quantizeFC = false;
1411
1412 enableStaticPlaceholder = true;
1413 convertToFP16 = true;
1414 convertFusedToFP16 = false;
1415 convert4or8BitFusedToFP32 = true;
1416
1417 testRecSys();
1418}
1419
1420/// Rowwise quantize the SLWS, with FP16 for scales/bias, and other
1421/// inputs/outputs in FP16. Everything else in FP32.
1422TEST_P(RecommendationSystemTest, RecSys_RWQuantizedFP16_SLWS) {
1423 CHECK_IF_ENABLED();
1424
1425 quantizeSLWSData = true;
1426 useFP16SLWS = true;
1427 useFP16AccumSLWS = false;
1428 quantizeFC = false;
1429 convertToFP16 = false;
1430
1431 testRecSys();
1432}
1433
1434/// Rowwise quantize the SLWS, with FP16 for scales/bias, and other
1435/// inputs/outputs in FP16, and use FP16 accumulation. Everything else in FP32.
1436TEST_P(RecommendationSystemTest, RecSys_RWQuantizedFP16AccumFP16_SLWS) {
1437 CHECK_IF_ENABLED();
1438
1439 quantizeSLWSData = true;
1440 useFP16SLWS = true;
1441 useFP16AccumSLWS = true;
1442 quantizeFC = false;
1443 convertToFP16 = false;
1444
1445 testRecSys();
1446}
1447
1448/// Rowwise quantize the SLWS, with FP16 for scales/bias, and other
1449/// inputs/outputs in FP16. Everything else in FP16.
1450TEST_P(RecommendationSystemTest, RecSys_RWQuantizedFP16_SLWS_FP16) {
1451 CHECK_IF_ENABLED();
1452
1453 quantizeSLWSData = true;
1454 useFP16SLWS = true;
1455 useFP16AccumSLWS = false;
1456 quantizeFC = false;
1457 convertToFP16 = true;
1458 convertFusedToFP16 = true;
1459
1460 testRecSys();
1461}
1462
1463/// Rowwise quantize the SLWS, with FP16 for scales/bias, and other
1464/// inputs/outputs in FP16, and use FP16 accumulation. Everything else in FP16.
1465TEST_P(RecommendationSystemTest, RecSys_RWQuantizedFP16AccumFP16_SLWS_FP16) {
1466 CHECK_IF_ENABLED();
1467
1468 quantizeSLWSData = true;
1469 useFP16SLWS = true;
1470 useFP16AccumSLWS = true;
1471 quantizeFC = false;
1472 convertToFP16 = true;
1473 convertFusedToFP16 = true;
1474
1475 testRecSys();
1476}
1477
1478/// Partitioning Tests
1479/// These tests have the same options as the above, but also partition the
1480/// created graph into segments and walk the dag. The test then compares output
1481/// for the partitioned and unpartitioned runs.
1482
1483TEST_P(RecommendationSystemTest, RecSys_FP32_Partitioned) {
1484 CHECK_IF_ENABLED();
1485
1486 quantizeSLWSData = false;
1487 useFP16SLWS = false;
1488 useFP16AccumSLWS = false;
1489 quantizeFC = false;
1490 convertToFP16 = false;
1491
1492 testRecSys();
1493
1494 // If the memory capacity was not set on the command line, then double the
1495 // default value for this test.
1496 if (deviceMemCapacityOpt == 0) {
1497 deviceMemCapacity *= 2; // Double memory for this test
1498 }
1499
1500 testPartitionedRecSys(numDevices, deviceMemCapacity, context_);
1501}
1502
1503// RecSys_FP32_Partitioned with deferred weight loading.
1504TEST_P(RecommendationSystemTest, RecSys_FP32_Partitioned_Deferred) {
1505 CHECK_IF_ENABLED();
1506
1507 quantizeSLWSData = false;
1508 useFP16SLWS = false;
1509 useFP16AccumSLWS = false;
1510 quantizeFC = false;
1511
1512 enableStaticPlaceholder = true;
1513 convertToFP16 = true;
1514 convertFusedToFP16 = false;
1515 convert4or8BitFusedToFP32 = true;
1516
1517 testRecSys();
1518
1519 // If the memory capacity was not set on the command line, then double the
1520 // default value for this test.
1521 if (deviceMemCapacityOpt == 0) {
1522 deviceMemCapacity *= 2; // Double memory for this test
1523 }
1524
1525 testPartitionedRecSys(numDevices, deviceMemCapacity, context_);
1526}
1527
1528TEST_P(RecommendationSystemTest, RecSys_Partitioned_RWQuantized_SLWS) {
1529 CHECK_IF_ENABLED();
1530
1531 quantizeSLWSData = true;
1532 useFP16SLWS = false;
1533 useFP16AccumSLWS = false;
1534 quantizeFC = false;
1535 convertToFP16 = false;
1536
1537 testRecSys();
1538
1539 // If the memory capacity was not set on the command line, then double the
1540 // default value for this test.
1541 if (deviceMemCapacityOpt == 0) {
1542 deviceMemCapacity *= 2; // Double memory for this test
1543 }
1544
1545 testPartitionedRecSys(numDevices, deviceMemCapacity, context_);
1546}
1547
1548// RecSys_Partitioned_RWQuantized_SLWS with deferred weight loading.
1549TEST_P(RecommendationSystemTest, RecSys_Partitioned_RWQuantized_SLWS_Deferred) {
1550 CHECK_IF_ENABLED();
1551
1552 quantizeSLWSData = true;
1553 useFP16SLWS = false;
1554 useFP16AccumSLWS = false;
1555 quantizeFC = false;
1556
1557 enableStaticPlaceholder = true;
1558 convertToFP16 = true;
1559 convertFusedToFP16 = false;
1560 convert4or8BitFusedToFP32 = true;
1561
1562 testRecSys();
1563
1564 // If the memory capacity was not set on the command line, then double the
1565 // default value for this test.
1566 if (deviceMemCapacityOpt == 0) {
1567 deviceMemCapacity *= 2; // Double memory for this test
1568 }
1569
1570 testPartitionedRecSys(numDevices, deviceMemCapacity, context_);
1571}
1572
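/// Rowwise quantize the SLWS and quantize the FC. Everything else in FP32.
/// Also run partitioned and compare results.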
1573TEST_P(RecommendationSystemTest, RecSys_Partitioned_RWQuantized_SLWS_FC) {
1574 CHECK_IF_ENABLED();
1575
1576 quantizeSLWSData = true;
1577 useFP16SLWS = false;
1578 useFP16AccumSLWS = false;
1579 quantizeFC = true;
1580 convertToFP16 = false;
1581
1582 testRecSys();
1583
1584 testPartitionedRecSys(numDevices, deviceMemCapacity, context_);
1585}
1586
/// RecSys_Partitioned_RWQuantized_SLWS_FC with deferred weight loading.
1588TEST_P(RecommendationSystemTest,
1589 RecSys_Partitioned_RWQuantized_SLWS_FC_Deferred) {
1590 CHECK_IF_ENABLED();
1591
1592 quantizeSLWSData = true;
1593 useFP16SLWS = false;
1594 useFP16AccumSLWS = false;
1595 quantizeFC = true;
1596
1597 enableStaticPlaceholder = true;
1598 convertToFP16 = true;
1599 convertFusedToFP16 = false;
1600 convert4or8BitFusedToFP32 = true;
1601
1602 testRecSys();
1603
1604 testPartitionedRecSys(numDevices, deviceMemCapacity, context_);
1605}
1606
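/// Rowwise quantize the SLWS. Everything else in FP16. Also run partitioned
/// and compare results.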
1607TEST_P(RecommendationSystemTest, RecSys_Partitioned_RWQuantized_SLWS_FP16) {
1608 CHECK_IF_ENABLED();
1609
1610 quantizeSLWSData = true;
1611 useFP16SLWS = false;
1612 useFP16AccumSLWS = false;
1613 quantizeFC = false;
1614 convertToFP16 = true;
1615 convertFusedToFP16 = true;
1616
1617 testRecSys();
1618
1619 testPartitionedRecSys(numDevices, deviceMemCapacity, context_);
1620}
1621
/// RecSys_Partitioned_RWQuantized_SLWS_FP16 with deferred weight loading.
1623TEST_P(RecommendationSystemTest,
1624 RecSys_Partitioned_RWQuantized_SLWS_FP16_Deferred) {
1625 CHECK_IF_ENABLED();
1626
1627 quantizeSLWSData = true;
1628 useFP16SLWS = false;
1629 useFP16AccumSLWS = false;
1630 quantizeFC = false;
1631
1632 enableStaticPlaceholder = true;
1633 convertToFP16 = true;
1634 convertFusedToFP16 = false;
1635 convert4or8BitFusedToFP32 = true;
1636
1637 testRecSys();
1638
1639 testPartitionedRecSys(numDevices, deviceMemCapacity, context_);
1640}
1641
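/// Rowwise quantize the SLWS and quantize the FC. Everything else in FP16.
/// Also run partitioned and compare results.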
1642TEST_P(RecommendationSystemTest, RecSys_Partitioned_RWQuantized_SLWS_FC_FP16) {
1643 CHECK_IF_ENABLED();
1644
1645 quantizeSLWSData = true;
1646 useFP16SLWS = false;
1647 useFP16AccumSLWS = false;
1648 quantizeFC = true;
1649 convertToFP16 = true;
1650 convertFusedToFP16 = true;
1651
1652 testRecSys();
1653
1654 testPartitionedRecSys(numDevices, deviceMemCapacity, context_);
1655}
1656
1657/// Rowwise quantize the SLWS, with FP16 for scales/bias, and other
1658/// inputs/outputs in FP16. Everything else in FP32. Also run partitioned and
1659/// compare results.
1660TEST_P(RecommendationSystemTest, RecSys_Partitioned_RWQuantizedFP16_SLWS) {
1661 CHECK_IF_ENABLED();
1662
1663 quantizeSLWSData = true;
1664 useFP16SLWS = true;
1665 useFP16AccumSLWS = false;
1666 quantizeFC = false;
1667 convertToFP16 = false;
1668
1669 testRecSys();
1670
1671 // If the memory capacity was not set on the command line, then double the
1672 // default value for this test.
1673 if (deviceMemCapacityOpt == 0) {
1674 deviceMemCapacity *= 2; // Double memory for this test
1675 }
1676
1677 testPartitionedRecSys(numDevices, deviceMemCapacity, context_);
1678}
1679
/// RecSys_Partitioned_RWQuantizedFP16_SLWS with deferred weight loading.
1681TEST_P(RecommendationSystemTest,
1682 RecSys_Partitioned_RWQuantizedFP16_SLWS_Deferred) {
1683 CHECK_IF_ENABLED();
1684
1685 quantizeSLWSData = true;
1686 useFP16SLWS = true;
1687 useFP16AccumSLWS = false;
1688 quantizeFC = false;
1689
1690 enableStaticPlaceholder = true;
1691 convertToFP16 = true;
1692 convertFusedToFP16 = false;
1693 convert4or8BitFusedToFP32 = true;
1694
1695 testRecSys();
1696
1697 // If the memory capacity was not set on the command line, then double the
1698 // default value for this test.
1699 if (deviceMemCapacityOpt == 0) {
1700 deviceMemCapacity *= 2; // Double memory for this test
1701 }
1702
1703 testPartitionedRecSys(numDevices, deviceMemCapacity, context_);
1704}
1705
1706/// Rowwise quantize the SLWS, with FP16 for scales/bias, and other
1707/// inputs/outputs in FP16, and use FP16 accumulation. Everything else in FP32.
1708/// Also run partitioned and compare results.
1709TEST_P(RecommendationSystemTest,
1710 RecSys_Partitioned_RWQuantizedFP16AccumFP16_SLWS) {
1711 CHECK_IF_ENABLED();
1712
1713 quantizeSLWSData = true;
1714 useFP16SLWS = true;
1715 useFP16AccumSLWS = true;
1716 quantizeFC = false;
1717 convertToFP16 = false;
1718
1719 testRecSys();
1720
1721 testPartitionedRecSys(numDevices, deviceMemCapacity, context_);
1722}
1723
1724/// Rowwise quantize the SLWS, with FP16 for scales/bias, and other
1725/// inputs/outputs in FP16. Everything else in FP16. Also run partitioned and
1726/// compare results.
1727TEST_P(RecommendationSystemTest, RecSys_Partitioned_RWQuantizedFP16_SLWS_FP16) {
1728 CHECK_IF_ENABLED();
1729
1730 quantizeSLWSData = true;
1731 useFP16SLWS = true;
1732 useFP16AccumSLWS = false;
1733 quantizeFC = false;
1734 convertToFP16 = true;
1735 convertFusedToFP16 = true;
1736
1737 testRecSys();
1738
1739 testPartitionedRecSys(numDevices, deviceMemCapacity, context_);
1740}
1741
1742/// Rowwise quantize the SLWS, with FP16 for scales/bias, and other
1743/// inputs/outputs in FP16, and use FP16 accumulation. Everything else in FP16.
1744/// Also run partitioned and compare results.
1745TEST_P(RecommendationSystemTest,
1746 RecSys_Partitioned_RWQuantizedFP16AccumFP16_SLWS_FP16) {
1747 CHECK_IF_ENABLED();
1748
1749 quantizeSLWSData = true;
1750 useFP16SLWS = true;
1751 useFP16AccumSLWS = true;
1752 quantizeFC = false;
1753 convertToFP16 = true;
1754 convertFusedToFP16 = true;
1755
1756 testRecSys();
1757
1758 testPartitionedRecSys(numDevices, deviceMemCapacity, context_);
1759}
1760
1761/// Rowwise quantize the SLWS, with FP16 for scales/bias, and other
1762/// inputs/outputs in FP16, and use FP16 accumulation. Everything else in FP16.
1763/// Also run partitioned using SparseNN partitioning and compare results.
1764TEST_P(RecommendationSystemTest,
1765 RecSys_Partitioned_RWQuantizedFP16AccumFP16_SLWS_FP16_SNN_Partitioning) {
1766 CHECK_IF_ENABLED();
1767
1768 quantizeSLWSData = true;
1769 useFP16SLWS = true;
1770 useFP16AccumSLWS = true;
1771 quantizeFC = false;
1772 convertToFP16 = true;
1773 convertFusedToFP16 = true;
1774
  // Options for SparseNN partitioning.
  useSparseNNPartitioning = true;
  // Add concats for the SLS outputs to aid partitioning.
  sparseNNPartitioningAddSLSConcats = true;
  // Number of cards to partition the SLS tables across.
  sparseNNPartitioningNumCards = partitioningNumDevicesOpt;
  // Per-card budget of SLS table data, in kilobytes.
  sparseNNPartitioningSLSKbytes = 1000000;
  // Core counts reserved for the SLS partitions and for everything else.
  sparseNNPartitioningNumCoresSLS = 6;
  sparseNNPartitioningNumCoresOther = 4;
1782
1783 testRecSys();
1784
1785 testPartitionedRecSys(numDevices, deviceMemCapacity, context_);
1786}
1787
1788/// Test SLS independently, with no other layers being run.
1789TEST_P(RecommendationSystemTest, RecSys_SLS_Only) {
1790 CHECK_IF_ENABLED();
1791
1792 quantizeSLWSData = true;
1793
  // setupPrecisionConfig() is normally called inside testRecSys(), which this
  // test bypasses, so call it directly here.
1795 setupPrecisionConfig();
1796
1797 testSLSQuant();
1798}
1799
/// RecSys_SLS_Only with deferred weight loading.
1801TEST_P(RecommendationSystemTest, RecSys_SLS_Only_Deferred) {
1802 CHECK_IF_ENABLED();
1803
1804 quantizeSLWSData = true;
1805
1806 enableStaticPlaceholder = true;
1807 convertFusedToFP16 = false;
1808 convert4or8BitFusedToFP32 = true;
1809
  // setupPrecisionConfig() is normally called inside testRecSys(), which this
  // test bypasses, so call it directly here.
1811 setupPrecisionConfig();
1812
1813 testSLSQuant();
1814}
1815
1816/// Test gathering weights for SLWS.
1817TEST_P(RecommendationSystemTest, RecSys_FP32_Gather_Weights) {
1818 CHECK_IF_ENABLED();
1819
1820 quantizeSLWSData = false;
1821 useFP16SLWS = false;
1822 useFP16AccumSLWS = false;
1823 quantizeFC = false;
1824 convertToFP16 = false;
1825
1826 gatherWeights = true;
1827
1828 testRecSys();
1829}
1830
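/// Instantiate the parameterized tests above for each enabled backend.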
1831INSTANTIATE_BACKEND_TEST(RecommendationSystemTest);
1832