1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include <algorithm>
17#include <array>
18#include <cstdlib>
19#include <fstream>
20#include <future>
21#include <random>
22#include <string>
23
24#include "Bench.h"
25
26#include "glow/ExecutionEngine/ExecutionEngine.h"
27#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
28
29using namespace glow;
30
31/*
32 * This class implements an SLS microbenchmark. There are a number of
33 * parallel FusedRowwiseQuantizedSparseLengthsWeightedSum,
34 * FusedRowwiseQuantizedSparseLengthsSum, SparseLengthsWeightedSum, or
35 * SparseLengthsSum nodes which are created.
36 *
37 * Microbenchmarks are generally useful for understanding performance
 * through targeted experimentation and are not representative of
39 * end-to-end workloads.
40 */
41
// Command-line category grouping all SLSBench-specific options.
llvm::cl::OptionCategory SLSBenchCat("SLSBench Category");
// When set, the compiled DAG is serialized to ONNX text format (consumed via
// CompilationContext::serializeCompiledDAG in setup()).
llvm::cl::opt<bool> dumpOnnx("dump_onnx",
                             llvm::cl::desc("dump onnx text format for model"),
                             llvm::cl::Optional, llvm::cl::init(false),
                             llvm::cl::cat(SLSBenchCat));
47
/// Which flavor of sparse-lengths-sum node the benchmark instantiates.
enum SLSKind {
  NONQUANTIZED_UNWEIGHTED, // SparseLengthsSum
  NONQUANTIZED_WEIGHTED,   // SparseLengthsWeightedSum
  QUANTIZED_UNWEIGHTED,    // FusedRowwiseQuantizedSparseLengthsSum
  QUANTIZED_WEIGHTED       // FusedRowwiseQuantizedSparseLengthsWeightedSum
};
54
/// Configuration for one SLS benchmark run (one line of the config file or
/// one command line). Parsed by parseArgs().
struct SLSParam {
  dim_t batchSize;               // number of segments (lengths) per SLS node
  dim_t numReps;                 // benchmark repetitions
  dim_t numAsyncLaunches;        // concurrent requests in flight per rep
  std::string backendStr;        // Glow backend name
  std::string devId;             // optional device id ("" = default)
  dim_t numIndicesPerBatchMin;   // lower bound of per-segment length
  dim_t numIndicesPerBatchMax;   // upper bound of per-segment length
  dim_t numIndicesPerBatchPad;   // max indices per batch (placeholder size)
  dim_t numTableEntries;         // rows in the embedding table
  dim_t numElementsPerRow;       // elements per embedding row
  dim_t numSLSNodes;             // parallel SLS nodes to create
  SLSKind slsKind;               // which SLS node flavor (see SLSKind)
  bool isSorted;                 // sort indices within each segment
  bool addClip;                  // append a Clip node after the SLS
  bool useFP16Accumulation;      // FP16 accumulation in quantized kernels
  ElemKind fusedDtype;           // fused storage type for quantized tables
  ElemKind dtype;                // data/weights type for non-quantized tables
  bool convertFusedToFP32;       // convert fused scale/offset to FP32
};
75
76std::string getSLSDescription(SLSParam param) {
77 std::string SLSStr =
78 (param.slsKind == NONQUANTIZED_UNWEIGHTED) ? std::string("SLS")
79 : (param.slsKind == NONQUANTIZED_WEIGHTED) ? std::string("SLWS")
80 : (param.slsKind == QUANTIZED_UNWEIGHTED) ? std::string("RWQLSS")
81 : std::string("RWQLSWS");
82
83 return strFormat(
84 "%s__%zu_%zu__%zu__%zu__%zu", SLSStr.c_str(),
85 (size_t)param.numIndicesPerBatchMin, (size_t)param.numIndicesPerBatchMax,
86 (size_t)param.numIndicesPerBatchPad, (size_t)param.numTableEntries,
87 (size_t)param.numElementsPerRow);
88}
89
90class SLSBench : public Benchmark {
91 std::unique_ptr<runtime::HostManager> hostManager_;
92 std::vector<std::unique_ptr<ExecutionContext>> contexts_;
93 std::vector<std::vector<Tensor>> indicesReal_;
94 std::vector<std::vector<Tensor>> weightsReal_;
95 dim_t batchSize_;
96 dim_t asyncLaunchSize_;
97 std::string backendStr_;
98 std::vector<SLSParam> params_;
99 bool convertFusedToFP32_;
100 std::string devId_;
101
102public:
103 SLSBench(dim_t batchSize_, dim_t asyncLaunchSize_, std::string backendStr_,
104 std::vector<SLSParam> params_, bool convertFusedToFP32,
105 std::string devId_ = std::string(""))
106 : batchSize_(batchSize_), asyncLaunchSize_(asyncLaunchSize_),
107 backendStr_(backendStr_), params_(params_),
108 convertFusedToFP32_(convertFusedToFP32), devId_(devId_) {}
109
110 double countSLSGbytes(SLSParam param) const {
111
112 dim_t elementSize = 2;
113 if (param.dtype == ElemKind::FloatTy) {
114 elementSize = 4;
115 }
116
117 dim_t scaleSize = 2;
118 if (param.convertFusedToFP32) {
119 scaleSize = 4;
120 }
121
122 // This is approximate when numIndicesPerBatchMin != numIndicesPerBatchMax.
123 const double avgIndicesPerBatch =
124 (double)(param.numIndicesPerBatchMin + param.numIndicesPerBatchMax) /
125 2.0;
126
127 // Embedding data
128 double input_gbytes = 0.0;
129 if ((param.slsKind == NONQUANTIZED_WEIGHTED) ||
130 (param.slsKind == NONQUANTIZED_UNWEIGHTED)) {
131 input_gbytes += (param.numSLSNodes * batchSize_ * avgIndicesPerBatch *
132 (param.numElementsPerRow * elementSize)) /
133 1e9;
134 } else { // Quantized
135 if (param.fusedDtype == ElemKind::UInt8FusedFP16QTy) {
136 input_gbytes += (param.numSLSNodes * batchSize_ * avgIndicesPerBatch *
137 (param.numElementsPerRow + 2 * scaleSize)) /
138 1e9;
139 } else { // Int4
140 input_gbytes += (param.numSLSNodes * batchSize_ * avgIndicesPerBatch *
141 ((param.numElementsPerRow + 1) / 2 + 2 * scaleSize)) /
142 1e9;
143 }
144 }
145
146 // + indices
147 input_gbytes += (param.numSLSNodes * batchSize_ * avgIndicesPerBatch *
148 sizeof(int32_t)) /
149 1e9;
150
151 // + weights
152 if ((param.slsKind == QUANTIZED_WEIGHTED) ||
153 (param.slsKind == NONQUANTIZED_WEIGHTED)) {
154 input_gbytes +=
155 (param.numSLSNodes * batchSize_ * avgIndicesPerBatch * elementSize) /
156 1e9;
157 }
158
159 // + lengths
160 input_gbytes += (param.numSLSNodes * batchSize_ * sizeof(int32_t)) / 1e9;
161
162 double output_gbytes = (param.numSLSNodes * batchSize_ *
163 (param.numElementsPerRow * elementSize)) /
164 1e9;
165
166 return input_gbytes + output_gbytes;
167 }
168
169 void addSLSNode(std::unique_ptr<Module> &mod, Function *fn, SLSParam param) {
170
171 // Constant needed for Non-quantized case
172 Tensor dataConstantTensor;
173 if ((param.slsKind == NONQUANTIZED_WEIGHTED) ||
174 (param.slsKind == NONQUANTIZED_UNWEIGHTED)) {
175 dataConstantTensor =
176 Tensor(param.dtype, {param.numTableEntries, param.numElementsPerRow});
177 } else {
178 // If RWQ then we need to account for per-row scale/offset in the shape.
179 int64_t numBytePerRow = param.numElementsPerRow;
180 if (param.fusedDtype == ElemKind::UInt4FusedFP16QTy) {
181 // For 4bit tables the number of bytes should be halved (rounded up).
182 numBytePerRow = (numBytePerRow + 1) / 2;
183 }
184 const dim_t numTotalColumns = numBytePerRow + 2 * sizeof(float16_t);
185 dataConstantTensor = Tensor(
186 param.fusedDtype, {param.numTableEntries, numTotalColumns}, 1.0, 0);
187 }
188 Constant *dataConstant = mod->createConstant("SLSData", dataConstantTensor);
189
190 // Create placeholders for weights, indices and lengths
191 const dim_t maxNumIndicesWeights = param.numIndicesPerBatchPad * batchSize_;
192 auto *weights = mod->createPlaceholder(param.dtype, {maxNumIndicesWeights},
193 "weights", false);
194
195 auto *indices = mod->createPlaceholder(ElemKind::Int64ITy,
196 {maxNumIndicesWeights}, "indices",
197 /* isTrainable */ false);
198
199 auto *lengths =
200 mod->createPlaceholder(ElemKind::Int32ITy, {batchSize_}, "lengths",
201 /* isTrainable */ false);
202
203 size_t totalLengthsSum = 0;
204 size_t totalNumLengths = 0;
205 for (dim_t i = 0; i < asyncLaunchSize_; i++) {
206 auto lengthsHandle = contexts_[i]
207 ->getPlaceholderBindings()
208 ->allocate(lengths)
209 ->getHandle<int32_t>();
210
211 // Generate lengths across a uniform distribution.
212 lengthsHandle.randomize(param.numIndicesPerBatchMin,
213 param.numIndicesPerBatchMax, mod->getPRNG());
214 dim_t lengthsSum = 0;
215 for (size_t j = 0, e = lengthsHandle.size(); j < e; j++) {
216 auto &nextLength = lengthsHandle.raw(j);
217 if (lengthsSum == maxNumIndicesWeights) {
218 // If we have maxed out the maximum allowed indices then zero out the
219 // rest of the lengths.
220 nextLength = 0;
221 continue;
222 } else if (lengthsSum + nextLength > maxNumIndicesWeights) {
223 // If the next length will equal or overflow the maximum allowed
224 // indices then fill it up totally.
225 nextLength = maxNumIndicesWeights - lengthsSum;
226 }
227 lengthsSum += nextLength;
228 totalNumLengths += 1;
229 }
230 totalLengthsSum += lengthsSum;
231
232 // Create and sort indices
233 Tensor indicesReal(ElemKind::Int64ITy, {lengthsSum});
234 indicesReal.getHandle<int64_t>().randomize(0, param.numTableEntries,
235 mod->getPRNG());
236 // Sort each segment
237 if (param.isSorted) {
238 int64_t *indicesRealPtr = (int64_t *)indicesReal.getUnsafePtr();
239 for (size_t j = 0, e = lengthsHandle.size(); j < e; j++) {
240 const size_t curLength = lengthsHandle.raw(j);
241 std::sort(indicesRealPtr, indicesRealPtr + curLength);
242 indicesRealPtr += curLength;
243 }
244 }
245 indicesReal_[i].push_back(std::move(indicesReal));
246
247 // Create weights
248 if (param.dtype == ElemKind::FloatTy) {
249 Tensor weightsReal(ElemKind::FloatTy, {lengthsSum});
250 weightsReal.getHandle<float>().clear(1.0f);
251 weightsReal_[i].push_back(std::move(weightsReal));
252 } else if (param.dtype == ElemKind::Float16Ty) {
253 Tensor weightsReal(ElemKind::Float16Ty, {lengthsSum});
254 weightsReal.getHandle<float16_t>().clear(1.0f);
255 weightsReal_[i].push_back(std::move(weightsReal));
256 }
257
258 Tensor indicesPartial(indicesReal_[i].back().getUnsafePtr(),
259 indices->getType(),
260 indicesReal_[i].back().getSizeInBytes());
261
262 contexts_[i]->getPlaceholderBindings()->insert(indices,
263 std::move(indicesPartial));
264
265 Tensor weightsPartial(weightsReal_[i].back().getUnsafePtr(),
266 weights->getType(),
267 weightsReal_[i].back().getSizeInBytes());
268 contexts_[i]->getPlaceholderBindings()->insert(weights,
269 std::move(weightsPartial));
270 } // i
271
272 // Calculate the average length based on all of the lengths generated.
273 const double avgLength = (double)totalLengthsSum / (double)totalNumLengths;
274
275 // Create SLS node, optional clip node, and save node
276 const LengthsMode LM =
277 avgLength == 1.f ? LengthsMode::AllOne : LengthsMode::Variable;
278 Node *R = nullptr;
279 if (param.slsKind == QUANTIZED_UNWEIGHTED) {
280 R = fn->createFusedRowwiseQuantizedSparseLengthsSum(
281 getSLSDescription(param), dataConstant, indices, lengths,
282 param.useFP16Accumulation, LM, avgLength);
283 } else if (param.slsKind == QUANTIZED_WEIGHTED) {
284 R = fn->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
285 getSLSDescription(param), dataConstant, weights, indices, lengths,
286 param.useFP16Accumulation, LM, avgLength);
287 } else if (param.slsKind == NONQUANTIZED_WEIGHTED) {
288 R = fn->createSparseLengthsWeightedSum(getSLSDescription(param),
289 dataConstant, weights, indices,
290 lengths, LM, avgLength);
291 } else { // NonquantizedUnweighted
292 R = fn->createSparseLengthsSum(getSLSDescription(param), dataConstant,
293 indices, lengths, LM, avgLength);
294 }
295 SaveNode *S = nullptr;
296 if (param.addClip) {
297 auto *clp = fn->createClip("clip", R, -65504.0f, 65504.0f);
298 S = fn->createSave("save", clp);
299 } else {
300 S = fn->createSave("save", R);
301 }
302
303 // for each context, add output bindings
304 for (dim_t i = 0; i < asyncLaunchSize_; i++) {
305 contexts_[i]->getPlaceholderBindings()->allocate(S->getPlaceholder());
306 }
307 }
308
309 void setup() override {
310
311 // Create execution contexts here
312 for (dim_t i = 0; i < asyncLaunchSize_; i++) {
313 std::unique_ptr<ExecutionContext> context(new ExecutionContext);
314 contexts_.push_back(std::move(context));
315 }
316
317 // Setup host manager
318 std::vector<std::unique_ptr<runtime::DeviceConfig>> configs;
319 auto config = glow::make_unique<runtime::DeviceConfig>(backendStr_.c_str());
320 if (devId_ != "") {
321 config->parameters["DeviceID"] = devId_.c_str();
322 }
323 configs.push_back(std::move(config));
324 hostManager_ = glow::make_unique<runtime::HostManager>(std::move(configs));
325
326 // Create a function
327 std::unique_ptr<Module> mod(new Module);
328 auto fn = mod->createFunction("singleNode");
329
330 // Keep tensors around so they aren't deleted
331 indicesReal_.resize(asyncLaunchSize_);
332 weightsReal_.resize(asyncLaunchSize_);
333
334 // Add SLS nodes
335 for (auto &param : params_) {
336 for (dim_t i = 0; i < param.numSLSNodes; i++) {
337 addSLSNode(mod, fn, param);
338 }
339 }
340
341 fn->dumpDAG("slsbench.dot");
342 CompilationContext ctx;
343 ctx.dumpFinalGraph = true;
344 ctx.serializeCompiledDAG = dumpOnnx;
345
346 if (convertFusedToFP32_) {
347 ctx.precisionConfig.convert4BitFusedToFP32 = true;
348 ctx.precisionConfig.convert8BitFusedToFP32 = true;
349 }
350
351 EXIT_ON_ERR(hostManager_->addNetwork(std::move(mod), ctx));
352 }
353
354 void run() override {
355 std::vector<std::unique_ptr<ExecutionContext>> localContexts(
356 asyncLaunchSize_);
357 std::vector<std::promise<void>> promises(asyncLaunchSize_);
358 std::vector<std::future<void>> futures;
359
360 // Launch a number of independent requests
361 int i = 0;
362 for (auto &promise : promises) {
363 futures.push_back(promise.get_future());
364 hostManager_->runNetwork(
365 "singleNode", std::move(contexts_[i]),
366 [&localContexts, &promise,
367 i](runtime::RunIdentifierTy, Error err,
368 std::unique_ptr<ExecutionContext> contextPtr) {
369 EXIT_ON_ERR(std::move(err));
370 localContexts[i] = std::move(contextPtr);
371 promise.set_value();
372 });
373 i++;
374 }
375 for (auto &fut : futures) {
376 fut.wait();
377 }
378 for (dim_t j = 0; j < asyncLaunchSize_; j++) {
379 contexts_[j] = std::move(localContexts[j]);
380 }
381 }
382
383 void teardown() override {}
384
385 double gbytes() const {
386 double total = 0.0;
387 for (auto &param : params_) {
388 total += countSLSGbytes(param);
389 }
390 return total;
391 }
392};
393
// Indices of arguments
// Positions (in argv) of the optional trailing command-line arguments; only
// present when argc is large enough (checked in parseArgs/main).
#define ROWWISE_QUANT 14
#define ACCUM_TYPE 15
#define DEVICE_ID 16
398
399SLSParam parseArgs(int argc, char *argv[]) {
400 SLSParam param;
401 param.batchSize = atoi(argv[1]);
402 llvm::StringRef numIndicesPerBatchStr(argv[2]);
403 auto split = numIndicesPerBatchStr.split(':');
404 if (split.second == "") {
405 ASSIGN_VALUE_OR_FATAL(param.numIndicesPerBatchMin, getIntFromStr(argv[2]));
406 param.numIndicesPerBatchMax = param.numIndicesPerBatchMin;
407 } else {
408 ASSIGN_VALUE_OR_FATAL(param.numIndicesPerBatchMin,
409 getIntFromStr(split.first));
410 ASSIGN_VALUE_OR_FATAL(param.numIndicesPerBatchMax,
411 getIntFromStr(split.second));
412 CHECK_LE(param.numIndicesPerBatchMin, param.numIndicesPerBatchMax);
413 }
414 ASSIGN_VALUE_OR_FATAL(param.numIndicesPerBatchPad, getIntFromStr(argv[3]));
415 CHECK_LE(param.numIndicesPerBatchMax, param.numIndicesPerBatchPad);
416 ASSIGN_VALUE_OR_FATAL(param.numTableEntries, getIntFromStr(argv[4]));
417 ASSIGN_VALUE_OR_FATAL(param.numElementsPerRow, getIntFromStr(argv[5]));
418 ASSIGN_VALUE_OR_FATAL(param.numReps, getIntFromStr(argv[6]));
419 ASSIGN_VALUE_OR_FATAL(param.numAsyncLaunches, getIntFromStr(argv[7]));
420 ASSIGN_VALUE_OR_FATAL(param.numSLSNodes, getIntFromStr(argv[8]));
421 printf("batchSize %zu\n", (size_t)param.batchSize);
422 printf("numIndicesPerBatchMin %zu\n", (size_t)param.numIndicesPerBatchMin);
423 printf("numIndicesPerBatchMax %zu\n", (size_t)param.numIndicesPerBatchMax);
424 printf("numIndicesPerBatchPad %zu\n", (size_t)param.numIndicesPerBatchPad);
425 printf("numTableEntries %zu\n", (size_t)param.numTableEntries);
426 printf("numElementsPerRow %zu\n", (size_t)param.numElementsPerRow);
427 printf("numReps %zu\n", (size_t)param.numReps);
428 printf("numAsyncLaunches %zu\n", (size_t)param.numAsyncLaunches);
429 printf("numSLSNodes %zu\n", (size_t)param.numSLSNodes);
430 printf("slsKind %s\n", argv[9]);
431 if (std::string(argv[9]) == "NonquantizedUnweighted") {
432 param.slsKind = NONQUANTIZED_UNWEIGHTED;
433 } else if (std::string(argv[9]) == "NonquantizedWeighted") {
434 param.slsKind = NONQUANTIZED_WEIGHTED;
435 } else if (std::string(argv[9]) == "QuantizedUnweighted") {
436 param.slsKind = QUANTIZED_UNWEIGHTED;
437 } else if (std::string(argv[9]) == "QuantizedWeighted") {
438 param.slsKind = QUANTIZED_WEIGHTED;
439 } else {
440 llvm_unreachable("Invalid SLS Kind");
441 }
442 printf("sortedStr %s\n", argv[10]);
443 if (std::string(argv[10]) == "Sorted") {
444 param.isSorted = true;
445 } else if (std::string(argv[10]) == "Unsorted") {
446 param.isSorted = false;
447 } else {
448 llvm_unreachable("Invalid sortedStr");
449 }
450 printf("backendStr %s\n", argv[11]);
451 param.backendStr = std::string(argv[11]);
452 printf("dtypeStr %s\n", argv[12]);
453 if (std::string(argv[12]) == "Float16") {
454 param.dtype = ElemKind::Float16Ty;
455 } else if (std::string(argv[12]) == "Float32") {
456 param.dtype = ElemKind::FloatTy;
457 } else {
458 llvm_unreachable("Invalid dtype");
459 }
460 printf("addClipStr %s\n", argv[13]);
461 if (std::string(argv[13]) == "True") {
462 param.addClip = true;
463 } else if (std::string(argv[13]) == "False") {
464 param.addClip = false;
465 } else {
466 llvm_unreachable("Invalid addClipStr");
467 }
468 param.convertFusedToFP32 = false;
469 if (argc > ROWWISE_QUANT) {
470 printf("fusedDtype %s\n", argv[ROWWISE_QUANT]);
471 if (std::string(argv[ROWWISE_QUANT]) == "Int8") {
472 param.fusedDtype = ElemKind::UInt8FusedFP16QTy;
473 } else if (std::string(argv[ROWWISE_QUANT]) == "Int8_Fp32") {
474 param.fusedDtype = ElemKind::UInt8FusedFP16QTy;
475 param.convertFusedToFP32 = true;
476 } else if (std::string(argv[ROWWISE_QUANT]) == "Int4") {
477 param.fusedDtype = ElemKind::UInt4FusedFP16QTy;
478 } else if (std::string(argv[ROWWISE_QUANT]) == "Int4_Fp32") {
479 param.fusedDtype = ElemKind::UInt4FusedFP16QTy;
480 param.convertFusedToFP32 = true;
481 } else {
482 llvm_unreachable("Invalid Quantization datatype");
483 }
484 } else {
485 param.fusedDtype = ElemKind::UInt8FusedFP16QTy;
486 }
487 if (argc > ACCUM_TYPE) {
488 printf("useFP16Accumulation %s\n", argv[ACCUM_TYPE]);
489 if (std::string(argv[ACCUM_TYPE]) == "True") {
490 param.useFP16Accumulation = true;
491 } else if (std::string(argv[ACCUM_TYPE]) == "False") {
492 param.useFP16Accumulation = false;
493 } else {
494 llvm_unreachable("Invalid useFP16Accumulation");
495 }
496 } else {
497 param.useFP16Accumulation = false;
498 }
499 if (argc > DEVICE_ID) {
500 printf("devId %s\n", argv[DEVICE_ID]);
501 param.devId = std::string(argv[DEVICE_ID]);
502 } else {
503 param.devId = std::string("");
504 }
505 printf("\n\n");
506 return param;
507}
508
509int main(int argc, char *argv[]) {
510
511 printf("SLS Microbenchmark\n");
512 printf("Usage: SLSBench batchSize(Int) "
513 "[numIndicesPerBatch(Int) | "
514 "numIndicesPerBatchMin(Int):numIndicesPerBatchMax(Int)] "
515 "numIndicesPerBatchPad(Int) numTableEntries(Int) "
516 "numElementsPerRow(int) numReps(Int) "
517 "numAsyncLaunches(Int) numSLSNodes(Int) "
518 "slsKindStr(\"QuantizedWeighted\"|\"QuantizedUnweighted\"|"
519 "\"NonquantizedWeighted\"|"
520 "\"NonquantizedUnweighted\") "
521 "sortedStr(\"Sorted\"|\"Unsorted\") backendStr(String) "
522 "dtypeStr(\"Float16\"|\"Float32\") "
523 "addClipStr(\"True\"|\"False\")\nQuantized only options: "
524 "quantizationDtypeStr(\"Int8\"|\"Int4\") "
525 "useFP16AccumulationStr(\"True\"|\"False\") \n"
526 "Optional: dev_id(Int)\n");
527 printf("\n");
528 printf("Standard Glow command-line options may be passed via the GLOW_OPTS "
529 "environment variable\n");
530 benchParseGlowOpts(argc, argv);
531
532 std::vector<SLSParam> params;
533 std::string runHeader;
534 std::string runPrefix;
535
536 // Using a config file
537 if (argc == 2) {
538 auto fname = std::string(argv[1]);
539 std::ifstream fin(fname.c_str());
540 if (!fin) {
541 std::cout << "Could not open file: " << fname << std::endl;
542 exit(0);
543 }
544 std::string line;
545 while (getline(fin, line)) {
546 std::array<char, 1024> buf;
547 char *saveptr = nullptr;
548 std::vector<char *> argVec;
549 strcpy(buf.data(), line.c_str());
550 char *ptr = strtok_r(buf.data(), " ", &saveptr);
551 while (ptr != nullptr) {
552 argVec.push_back(ptr);
553 ptr = strtok_r(nullptr, " ", &saveptr);
554 }
555 SLSParam param = parseArgs(argVec.size(), argVec.data());
556 params.push_back(param);
557 runHeader = std::string("_,benchName,_,filename");
558 runPrefix = std::string(strFormat("SLSBench,SW,%s", fname.c_str()));
559 }
560 }
561 // Using command line
562 else if (argc == 14 || argc == 15 || argc == 16 || argc == 17 || argc == 18) {
563 SLSParam param = parseArgs(argc, argv);
564 params.push_back(param);
565
566 runHeader = std::string(
567 "_,benchName,_,batchSize,numIndicesPerBatchMin:numIndicesPerBatchMax,"
568 "numIndicesPerBatchPad,numTableEntries,numElementsPerRow,numReps,"
569 "numAsyncLaunches,numSLSNodes,slsKindStr,backendStr,dtypeStr,"
570 "addClipStr,quantizationDtypeStr,useFP16AccumulationStr");
571 runPrefix = std::string(strFormat(
572 "SLSBench,SW,%zu,%zu:%zu,%zu,%zu,%zu,%zu,%zu,%zu,%s,%s,%s,%s,%s,%s,%s",
573 (size_t)param.batchSize, (size_t)param.numIndicesPerBatchMin,
574 (size_t)param.numIndicesPerBatchMax,
575 (size_t)param.numIndicesPerBatchPad, (size_t)param.numTableEntries,
576 (size_t)param.numElementsPerRow, (size_t)param.numReps,
577 (size_t)param.numAsyncLaunches, (size_t)param.numSLSNodes, argv[9],
578 argv[10], argv[11], argv[12], argv[13], argv[14], argv[15]));
579 } else {
580 llvm_unreachable("Invalid command line");
581 }
582
583 SLSParam param = params.front();
584 SLSBench b(param.batchSize, param.numAsyncLaunches, param.backendStr, params,
585 param.convertFusedToFP32, param.devId);
586 auto times = bench(&b, param.numReps);
587
588 printf("%s,runtime,gbytesPerSec\n", runHeader.c_str());
589 for (auto t : times) {
590 printf("BenchResult,%s,%f,%f\n", runPrefix.c_str(),
591 t / param.numAsyncLaunches, b.gbytes() * param.numAsyncLaunches / t);
592 }
593 double min = *(std::min_element(times.begin(), times.end()));
594 dim_t midElt = times.size() / 2;
595 std::nth_element(times.begin(), times.begin() + midElt, times.end());
596 double median = times[midElt];
597 double medianRuntime = median / ((double)param.numAsyncLaunches);
598 double minRuntime = min / ((double)param.numAsyncLaunches);
599 printf("%s,medianRuntime,minRuntime,medianGbytesPerSec,maxGbytesPerSec\n",
600 runHeader.c_str());
601 printf("BenchSummary,%s,%f,%f,%f,%f\n", runPrefix.c_str(), medianRuntime,
602 minRuntime, b.gbytes() / medianRuntime, b.gbytes() / minRuntime);
603}
604