1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #include <algorithm> |
17 | #include <array> |
18 | #include <cstdlib> |
19 | #include <fstream> |
20 | #include <future> |
21 | #include <random> |
22 | #include <string> |
23 | |
24 | #include "Bench.h" |
25 | |
26 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
27 | #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h" |
28 | |
29 | using namespace glow; |
30 | |
31 | /* |
32 | * This class implements an SLS microbenchmark. There are a number of |
33 | * parallel FusedRowwiseQuantizedSparseLengthsWeightedSum, |
34 | * FusedRowwiseQuantizedSparseLengthsSum, SparseLengthsWeightedSum, or |
35 | * SparseLengthsSum nodes which are created. |
36 | * |
37 | * Microbenchmarks are generally useful for understanding performance |
 * through targeted experimentation and are not representative of
39 | * end-to-end workloads. |
40 | */ |
41 | |
// Command-line options for this benchmark (passed via the GLOW_OPTS
// environment variable; see benchParseGlowOpts in main).
llvm::cl::OptionCategory SLSBenchCat("SLSBench Category");
// When set, serialize the compiled DAG to ONNX text format (forwarded to
// CompilationContext::serializeCompiledDAG in setup()).
llvm::cl::opt<bool> dumpOnnx("dump_onnx",
                             llvm::cl::desc("dump onnx text format for model"),
                             llvm::cl::Optional, llvm::cl::init(false),
                             llvm::cl::cat(SLSBenchCat));
47 | |
// Which flavor of SLS node the benchmark instantiates (see addSLSNode).
enum SLSKind {
  NONQUANTIZED_UNWEIGHTED, // SparseLengthsSum
  NONQUANTIZED_WEIGHTED,   // SparseLengthsWeightedSum
  QUANTIZED_UNWEIGHTED,    // FusedRowwiseQuantizedSparseLengthsSum
  QUANTIZED_WEIGHTED       // FusedRowwiseQuantizedSparseLengthsWeightedSum
};
54 | |
/// One benchmark configuration: the shape of the SLS nodes to build plus the
/// run/launch parameters. Filled in by parseArgs() from the command line.
struct SLSParam {
  dim_t batchSize;             // number of lengths (segments) per SLS node
  dim_t numReps;               // number of timed repetitions
  dim_t numAsyncLaunches;      // concurrent requests per repetition
  std::string backendStr;      // backend name to run on
  std::string devId;           // optional device id ("" if unset)
  dim_t numIndicesPerBatchMin; // per-segment lengths are drawn uniformly
  dim_t numIndicesPerBatchMax; //   from [min, max]
  dim_t numIndicesPerBatchPad; // cap on indices per segment (placeholder size
                               // is numIndicesPerBatchPad * batchSize)
  dim_t numTableEntries;       // rows in the embedding table
  dim_t numElementsPerRow;     // elements per embedding row
  dim_t numSLSNodes;           // how many parallel SLS nodes to create
  SLSKind slsKind;             // which SLS node flavor (see SLSKind)
  bool isSorted;               // sort indices within each segment
  bool addClip;                // append a Clip node after the SLS output
  bool useFP16Accumulation;    // FP16 accumulation for fused-quantized nodes
  ElemKind fusedDtype;         // fused-quantized table type (UInt8/UInt4 fused)
  ElemKind dtype;              // dtype of non-quantized tables and weights
  bool convertFusedToFP32;     // convert fused scale/offset to FP32 at compile
};
75 | |
76 | std::string getSLSDescription(SLSParam param) { |
77 | std::string SLSStr = |
78 | (param.slsKind == NONQUANTIZED_UNWEIGHTED) ? std::string("SLS" ) |
79 | : (param.slsKind == NONQUANTIZED_WEIGHTED) ? std::string("SLWS" ) |
80 | : (param.slsKind == QUANTIZED_UNWEIGHTED) ? std::string("RWQLSS" ) |
81 | : std::string("RWQLSWS" ); |
82 | |
83 | return strFormat( |
84 | "%s__%zu_%zu__%zu__%zu__%zu" , SLSStr.c_str(), |
85 | (size_t)param.numIndicesPerBatchMin, (size_t)param.numIndicesPerBatchMax, |
86 | (size_t)param.numIndicesPerBatchPad, (size_t)param.numTableEntries, |
87 | (size_t)param.numElementsPerRow); |
88 | } |
89 | |
90 | class SLSBench : public Benchmark { |
91 | std::unique_ptr<runtime::HostManager> hostManager_; |
92 | std::vector<std::unique_ptr<ExecutionContext>> contexts_; |
93 | std::vector<std::vector<Tensor>> indicesReal_; |
94 | std::vector<std::vector<Tensor>> weightsReal_; |
95 | dim_t batchSize_; |
96 | dim_t asyncLaunchSize_; |
97 | std::string backendStr_; |
98 | std::vector<SLSParam> params_; |
99 | bool convertFusedToFP32_; |
100 | std::string devId_; |
101 | |
102 | public: |
103 | SLSBench(dim_t batchSize_, dim_t asyncLaunchSize_, std::string backendStr_, |
104 | std::vector<SLSParam> params_, bool convertFusedToFP32, |
105 | std::string devId_ = std::string("" )) |
106 | : batchSize_(batchSize_), asyncLaunchSize_(asyncLaunchSize_), |
107 | backendStr_(backendStr_), params_(params_), |
108 | convertFusedToFP32_(convertFusedToFP32), devId_(devId_) {} |
109 | |
110 | double countSLSGbytes(SLSParam param) const { |
111 | |
112 | dim_t elementSize = 2; |
113 | if (param.dtype == ElemKind::FloatTy) { |
114 | elementSize = 4; |
115 | } |
116 | |
117 | dim_t scaleSize = 2; |
118 | if (param.convertFusedToFP32) { |
119 | scaleSize = 4; |
120 | } |
121 | |
122 | // This is approximate when numIndicesPerBatchMin != numIndicesPerBatchMax. |
123 | const double avgIndicesPerBatch = |
124 | (double)(param.numIndicesPerBatchMin + param.numIndicesPerBatchMax) / |
125 | 2.0; |
126 | |
127 | // Embedding data |
128 | double input_gbytes = 0.0; |
129 | if ((param.slsKind == NONQUANTIZED_WEIGHTED) || |
130 | (param.slsKind == NONQUANTIZED_UNWEIGHTED)) { |
131 | input_gbytes += (param.numSLSNodes * batchSize_ * avgIndicesPerBatch * |
132 | (param.numElementsPerRow * elementSize)) / |
133 | 1e9; |
134 | } else { // Quantized |
135 | if (param.fusedDtype == ElemKind::UInt8FusedFP16QTy) { |
136 | input_gbytes += (param.numSLSNodes * batchSize_ * avgIndicesPerBatch * |
137 | (param.numElementsPerRow + 2 * scaleSize)) / |
138 | 1e9; |
139 | } else { // Int4 |
140 | input_gbytes += (param.numSLSNodes * batchSize_ * avgIndicesPerBatch * |
141 | ((param.numElementsPerRow + 1) / 2 + 2 * scaleSize)) / |
142 | 1e9; |
143 | } |
144 | } |
145 | |
146 | // + indices |
147 | input_gbytes += (param.numSLSNodes * batchSize_ * avgIndicesPerBatch * |
148 | sizeof(int32_t)) / |
149 | 1e9; |
150 | |
151 | // + weights |
152 | if ((param.slsKind == QUANTIZED_WEIGHTED) || |
153 | (param.slsKind == NONQUANTIZED_WEIGHTED)) { |
154 | input_gbytes += |
155 | (param.numSLSNodes * batchSize_ * avgIndicesPerBatch * elementSize) / |
156 | 1e9; |
157 | } |
158 | |
159 | // + lengths |
160 | input_gbytes += (param.numSLSNodes * batchSize_ * sizeof(int32_t)) / 1e9; |
161 | |
162 | double output_gbytes = (param.numSLSNodes * batchSize_ * |
163 | (param.numElementsPerRow * elementSize)) / |
164 | 1e9; |
165 | |
166 | return input_gbytes + output_gbytes; |
167 | } |
168 | |
169 | void addSLSNode(std::unique_ptr<Module> &mod, Function *fn, SLSParam param) { |
170 | |
171 | // Constant needed for Non-quantized case |
172 | Tensor dataConstantTensor; |
173 | if ((param.slsKind == NONQUANTIZED_WEIGHTED) || |
174 | (param.slsKind == NONQUANTIZED_UNWEIGHTED)) { |
175 | dataConstantTensor = |
176 | Tensor(param.dtype, {param.numTableEntries, param.numElementsPerRow}); |
177 | } else { |
178 | // If RWQ then we need to account for per-row scale/offset in the shape. |
179 | int64_t numBytePerRow = param.numElementsPerRow; |
180 | if (param.fusedDtype == ElemKind::UInt4FusedFP16QTy) { |
181 | // For 4bit tables the number of bytes should be halved (rounded up). |
182 | numBytePerRow = (numBytePerRow + 1) / 2; |
183 | } |
184 | const dim_t numTotalColumns = numBytePerRow + 2 * sizeof(float16_t); |
185 | dataConstantTensor = Tensor( |
186 | param.fusedDtype, {param.numTableEntries, numTotalColumns}, 1.0, 0); |
187 | } |
188 | Constant *dataConstant = mod->createConstant("SLSData" , dataConstantTensor); |
189 | |
190 | // Create placeholders for weights, indices and lengths |
191 | const dim_t maxNumIndicesWeights = param.numIndicesPerBatchPad * batchSize_; |
192 | auto *weights = mod->createPlaceholder(param.dtype, {maxNumIndicesWeights}, |
193 | "weights" , false); |
194 | |
195 | auto *indices = mod->createPlaceholder(ElemKind::Int64ITy, |
196 | {maxNumIndicesWeights}, "indices" , |
197 | /* isTrainable */ false); |
198 | |
199 | auto *lengths = |
200 | mod->createPlaceholder(ElemKind::Int32ITy, {batchSize_}, "lengths" , |
201 | /* isTrainable */ false); |
202 | |
203 | size_t totalLengthsSum = 0; |
204 | size_t totalNumLengths = 0; |
205 | for (dim_t i = 0; i < asyncLaunchSize_; i++) { |
206 | auto lengthsHandle = contexts_[i] |
207 | ->getPlaceholderBindings() |
208 | ->allocate(lengths) |
209 | ->getHandle<int32_t>(); |
210 | |
211 | // Generate lengths across a uniform distribution. |
212 | lengthsHandle.randomize(param.numIndicesPerBatchMin, |
213 | param.numIndicesPerBatchMax, mod->getPRNG()); |
214 | dim_t lengthsSum = 0; |
215 | for (size_t j = 0, e = lengthsHandle.size(); j < e; j++) { |
216 | auto &nextLength = lengthsHandle.raw(j); |
217 | if (lengthsSum == maxNumIndicesWeights) { |
218 | // If we have maxed out the maximum allowed indices then zero out the |
219 | // rest of the lengths. |
220 | nextLength = 0; |
221 | continue; |
222 | } else if (lengthsSum + nextLength > maxNumIndicesWeights) { |
223 | // If the next length will equal or overflow the maximum allowed |
224 | // indices then fill it up totally. |
225 | nextLength = maxNumIndicesWeights - lengthsSum; |
226 | } |
227 | lengthsSum += nextLength; |
228 | totalNumLengths += 1; |
229 | } |
230 | totalLengthsSum += lengthsSum; |
231 | |
232 | // Create and sort indices |
233 | Tensor indicesReal(ElemKind::Int64ITy, {lengthsSum}); |
234 | indicesReal.getHandle<int64_t>().randomize(0, param.numTableEntries, |
235 | mod->getPRNG()); |
236 | // Sort each segment |
237 | if (param.isSorted) { |
238 | int64_t *indicesRealPtr = (int64_t *)indicesReal.getUnsafePtr(); |
239 | for (size_t j = 0, e = lengthsHandle.size(); j < e; j++) { |
240 | const size_t curLength = lengthsHandle.raw(j); |
241 | std::sort(indicesRealPtr, indicesRealPtr + curLength); |
242 | indicesRealPtr += curLength; |
243 | } |
244 | } |
245 | indicesReal_[i].push_back(std::move(indicesReal)); |
246 | |
247 | // Create weights |
248 | if (param.dtype == ElemKind::FloatTy) { |
249 | Tensor weightsReal(ElemKind::FloatTy, {lengthsSum}); |
250 | weightsReal.getHandle<float>().clear(1.0f); |
251 | weightsReal_[i].push_back(std::move(weightsReal)); |
252 | } else if (param.dtype == ElemKind::Float16Ty) { |
253 | Tensor weightsReal(ElemKind::Float16Ty, {lengthsSum}); |
254 | weightsReal.getHandle<float16_t>().clear(1.0f); |
255 | weightsReal_[i].push_back(std::move(weightsReal)); |
256 | } |
257 | |
258 | Tensor indicesPartial(indicesReal_[i].back().getUnsafePtr(), |
259 | indices->getType(), |
260 | indicesReal_[i].back().getSizeInBytes()); |
261 | |
262 | contexts_[i]->getPlaceholderBindings()->insert(indices, |
263 | std::move(indicesPartial)); |
264 | |
265 | Tensor weightsPartial(weightsReal_[i].back().getUnsafePtr(), |
266 | weights->getType(), |
267 | weightsReal_[i].back().getSizeInBytes()); |
268 | contexts_[i]->getPlaceholderBindings()->insert(weights, |
269 | std::move(weightsPartial)); |
270 | } // i |
271 | |
272 | // Calculate the average length based on all of the lengths generated. |
273 | const double avgLength = (double)totalLengthsSum / (double)totalNumLengths; |
274 | |
275 | // Create SLS node, optional clip node, and save node |
276 | const LengthsMode LM = |
277 | avgLength == 1.f ? LengthsMode::AllOne : LengthsMode::Variable; |
278 | Node *R = nullptr; |
279 | if (param.slsKind == QUANTIZED_UNWEIGHTED) { |
280 | R = fn->createFusedRowwiseQuantizedSparseLengthsSum( |
281 | getSLSDescription(param), dataConstant, indices, lengths, |
282 | param.useFP16Accumulation, LM, avgLength); |
283 | } else if (param.slsKind == QUANTIZED_WEIGHTED) { |
284 | R = fn->createFusedRowwiseQuantizedSparseLengthsWeightedSum( |
285 | getSLSDescription(param), dataConstant, weights, indices, lengths, |
286 | param.useFP16Accumulation, LM, avgLength); |
287 | } else if (param.slsKind == NONQUANTIZED_WEIGHTED) { |
288 | R = fn->createSparseLengthsWeightedSum(getSLSDescription(param), |
289 | dataConstant, weights, indices, |
290 | lengths, LM, avgLength); |
291 | } else { // NonquantizedUnweighted |
292 | R = fn->createSparseLengthsSum(getSLSDescription(param), dataConstant, |
293 | indices, lengths, LM, avgLength); |
294 | } |
295 | SaveNode *S = nullptr; |
296 | if (param.addClip) { |
297 | auto *clp = fn->createClip("clip" , R, -65504.0f, 65504.0f); |
298 | S = fn->createSave("save" , clp); |
299 | } else { |
300 | S = fn->createSave("save" , R); |
301 | } |
302 | |
303 | // for each context, add output bindings |
304 | for (dim_t i = 0; i < asyncLaunchSize_; i++) { |
305 | contexts_[i]->getPlaceholderBindings()->allocate(S->getPlaceholder()); |
306 | } |
307 | } |
308 | |
309 | void setup() override { |
310 | |
311 | // Create execution contexts here |
312 | for (dim_t i = 0; i < asyncLaunchSize_; i++) { |
313 | std::unique_ptr<ExecutionContext> context(new ExecutionContext); |
314 | contexts_.push_back(std::move(context)); |
315 | } |
316 | |
317 | // Setup host manager |
318 | std::vector<std::unique_ptr<runtime::DeviceConfig>> configs; |
319 | auto config = glow::make_unique<runtime::DeviceConfig>(backendStr_.c_str()); |
320 | if (devId_ != "" ) { |
321 | config->parameters["DeviceID" ] = devId_.c_str(); |
322 | } |
323 | configs.push_back(std::move(config)); |
324 | hostManager_ = glow::make_unique<runtime::HostManager>(std::move(configs)); |
325 | |
326 | // Create a function |
327 | std::unique_ptr<Module> mod(new Module); |
328 | auto fn = mod->createFunction("singleNode" ); |
329 | |
330 | // Keep tensors around so they aren't deleted |
331 | indicesReal_.resize(asyncLaunchSize_); |
332 | weightsReal_.resize(asyncLaunchSize_); |
333 | |
334 | // Add SLS nodes |
335 | for (auto ¶m : params_) { |
336 | for (dim_t i = 0; i < param.numSLSNodes; i++) { |
337 | addSLSNode(mod, fn, param); |
338 | } |
339 | } |
340 | |
341 | fn->dumpDAG("slsbench.dot" ); |
342 | CompilationContext ctx; |
343 | ctx.dumpFinalGraph = true; |
344 | ctx.serializeCompiledDAG = dumpOnnx; |
345 | |
346 | if (convertFusedToFP32_) { |
347 | ctx.precisionConfig.convert4BitFusedToFP32 = true; |
348 | ctx.precisionConfig.convert8BitFusedToFP32 = true; |
349 | } |
350 | |
351 | EXIT_ON_ERR(hostManager_->addNetwork(std::move(mod), ctx)); |
352 | } |
353 | |
354 | void run() override { |
355 | std::vector<std::unique_ptr<ExecutionContext>> localContexts( |
356 | asyncLaunchSize_); |
357 | std::vector<std::promise<void>> promises(asyncLaunchSize_); |
358 | std::vector<std::future<void>> futures; |
359 | |
360 | // Launch a number of independent requests |
361 | int i = 0; |
362 | for (auto &promise : promises) { |
363 | futures.push_back(promise.get_future()); |
364 | hostManager_->runNetwork( |
365 | "singleNode" , std::move(contexts_[i]), |
366 | [&localContexts, &promise, |
367 | i](runtime::RunIdentifierTy, Error err, |
368 | std::unique_ptr<ExecutionContext> contextPtr) { |
369 | EXIT_ON_ERR(std::move(err)); |
370 | localContexts[i] = std::move(contextPtr); |
371 | promise.set_value(); |
372 | }); |
373 | i++; |
374 | } |
375 | for (auto &fut : futures) { |
376 | fut.wait(); |
377 | } |
378 | for (dim_t j = 0; j < asyncLaunchSize_; j++) { |
379 | contexts_[j] = std::move(localContexts[j]); |
380 | } |
381 | } |
382 | |
383 | void teardown() override {} |
384 | |
385 | double gbytes() const { |
386 | double total = 0.0; |
387 | for (auto ¶m : params_) { |
388 | total += countSLSGbytes(param); |
389 | } |
390 | return total; |
391 | } |
392 | }; |
393 | |
// Indices of arguments
// argv positions of the optional trailing command-line arguments; an argument
// is present only when argc is greater than its index (see parseArgs).
#define ROWWISE_QUANT 14 // quantizationDtypeStr
#define ACCUM_TYPE 15    // useFP16AccumulationStr
#define DEVICE_ID 16     // dev_id
398 | |
399 | SLSParam parseArgs(int argc, char *argv[]) { |
400 | SLSParam param; |
401 | param.batchSize = atoi(argv[1]); |
402 | llvm::StringRef numIndicesPerBatchStr(argv[2]); |
403 | auto split = numIndicesPerBatchStr.split(':'); |
404 | if (split.second == "" ) { |
405 | ASSIGN_VALUE_OR_FATAL(param.numIndicesPerBatchMin, getIntFromStr(argv[2])); |
406 | param.numIndicesPerBatchMax = param.numIndicesPerBatchMin; |
407 | } else { |
408 | ASSIGN_VALUE_OR_FATAL(param.numIndicesPerBatchMin, |
409 | getIntFromStr(split.first)); |
410 | ASSIGN_VALUE_OR_FATAL(param.numIndicesPerBatchMax, |
411 | getIntFromStr(split.second)); |
412 | CHECK_LE(param.numIndicesPerBatchMin, param.numIndicesPerBatchMax); |
413 | } |
414 | ASSIGN_VALUE_OR_FATAL(param.numIndicesPerBatchPad, getIntFromStr(argv[3])); |
415 | CHECK_LE(param.numIndicesPerBatchMax, param.numIndicesPerBatchPad); |
416 | ASSIGN_VALUE_OR_FATAL(param.numTableEntries, getIntFromStr(argv[4])); |
417 | ASSIGN_VALUE_OR_FATAL(param.numElementsPerRow, getIntFromStr(argv[5])); |
418 | ASSIGN_VALUE_OR_FATAL(param.numReps, getIntFromStr(argv[6])); |
419 | ASSIGN_VALUE_OR_FATAL(param.numAsyncLaunches, getIntFromStr(argv[7])); |
420 | ASSIGN_VALUE_OR_FATAL(param.numSLSNodes, getIntFromStr(argv[8])); |
421 | printf("batchSize %zu\n" , (size_t)param.batchSize); |
422 | printf("numIndicesPerBatchMin %zu\n" , (size_t)param.numIndicesPerBatchMin); |
423 | printf("numIndicesPerBatchMax %zu\n" , (size_t)param.numIndicesPerBatchMax); |
424 | printf("numIndicesPerBatchPad %zu\n" , (size_t)param.numIndicesPerBatchPad); |
425 | printf("numTableEntries %zu\n" , (size_t)param.numTableEntries); |
426 | printf("numElementsPerRow %zu\n" , (size_t)param.numElementsPerRow); |
427 | printf("numReps %zu\n" , (size_t)param.numReps); |
428 | printf("numAsyncLaunches %zu\n" , (size_t)param.numAsyncLaunches); |
429 | printf("numSLSNodes %zu\n" , (size_t)param.numSLSNodes); |
430 | printf("slsKind %s\n" , argv[9]); |
431 | if (std::string(argv[9]) == "NonquantizedUnweighted" ) { |
432 | param.slsKind = NONQUANTIZED_UNWEIGHTED; |
433 | } else if (std::string(argv[9]) == "NonquantizedWeighted" ) { |
434 | param.slsKind = NONQUANTIZED_WEIGHTED; |
435 | } else if (std::string(argv[9]) == "QuantizedUnweighted" ) { |
436 | param.slsKind = QUANTIZED_UNWEIGHTED; |
437 | } else if (std::string(argv[9]) == "QuantizedWeighted" ) { |
438 | param.slsKind = QUANTIZED_WEIGHTED; |
439 | } else { |
440 | llvm_unreachable("Invalid SLS Kind" ); |
441 | } |
442 | printf("sortedStr %s\n" , argv[10]); |
443 | if (std::string(argv[10]) == "Sorted" ) { |
444 | param.isSorted = true; |
445 | } else if (std::string(argv[10]) == "Unsorted" ) { |
446 | param.isSorted = false; |
447 | } else { |
448 | llvm_unreachable("Invalid sortedStr" ); |
449 | } |
450 | printf("backendStr %s\n" , argv[11]); |
451 | param.backendStr = std::string(argv[11]); |
452 | printf("dtypeStr %s\n" , argv[12]); |
453 | if (std::string(argv[12]) == "Float16" ) { |
454 | param.dtype = ElemKind::Float16Ty; |
455 | } else if (std::string(argv[12]) == "Float32" ) { |
456 | param.dtype = ElemKind::FloatTy; |
457 | } else { |
458 | llvm_unreachable("Invalid dtype" ); |
459 | } |
460 | printf("addClipStr %s\n" , argv[13]); |
461 | if (std::string(argv[13]) == "True" ) { |
462 | param.addClip = true; |
463 | } else if (std::string(argv[13]) == "False" ) { |
464 | param.addClip = false; |
465 | } else { |
466 | llvm_unreachable("Invalid addClipStr" ); |
467 | } |
468 | param.convertFusedToFP32 = false; |
469 | if (argc > ROWWISE_QUANT) { |
470 | printf("fusedDtype %s\n" , argv[ROWWISE_QUANT]); |
471 | if (std::string(argv[ROWWISE_QUANT]) == "Int8" ) { |
472 | param.fusedDtype = ElemKind::UInt8FusedFP16QTy; |
473 | } else if (std::string(argv[ROWWISE_QUANT]) == "Int8_Fp32" ) { |
474 | param.fusedDtype = ElemKind::UInt8FusedFP16QTy; |
475 | param.convertFusedToFP32 = true; |
476 | } else if (std::string(argv[ROWWISE_QUANT]) == "Int4" ) { |
477 | param.fusedDtype = ElemKind::UInt4FusedFP16QTy; |
478 | } else if (std::string(argv[ROWWISE_QUANT]) == "Int4_Fp32" ) { |
479 | param.fusedDtype = ElemKind::UInt4FusedFP16QTy; |
480 | param.convertFusedToFP32 = true; |
481 | } else { |
482 | llvm_unreachable("Invalid Quantization datatype" ); |
483 | } |
484 | } else { |
485 | param.fusedDtype = ElemKind::UInt8FusedFP16QTy; |
486 | } |
487 | if (argc > ACCUM_TYPE) { |
488 | printf("useFP16Accumulation %s\n" , argv[ACCUM_TYPE]); |
489 | if (std::string(argv[ACCUM_TYPE]) == "True" ) { |
490 | param.useFP16Accumulation = true; |
491 | } else if (std::string(argv[ACCUM_TYPE]) == "False" ) { |
492 | param.useFP16Accumulation = false; |
493 | } else { |
494 | llvm_unreachable("Invalid useFP16Accumulation" ); |
495 | } |
496 | } else { |
497 | param.useFP16Accumulation = false; |
498 | } |
499 | if (argc > DEVICE_ID) { |
500 | printf("devId %s\n" , argv[DEVICE_ID]); |
501 | param.devId = std::string(argv[DEVICE_ID]); |
502 | } else { |
503 | param.devId = std::string("" ); |
504 | } |
505 | printf("\n\n" ); |
506 | return param; |
507 | } |
508 | |
509 | int main(int argc, char *argv[]) { |
510 | |
511 | printf("SLS Microbenchmark\n" ); |
512 | printf("Usage: SLSBench batchSize(Int) " |
513 | "[numIndicesPerBatch(Int) | " |
514 | "numIndicesPerBatchMin(Int):numIndicesPerBatchMax(Int)] " |
515 | "numIndicesPerBatchPad(Int) numTableEntries(Int) " |
516 | "numElementsPerRow(int) numReps(Int) " |
517 | "numAsyncLaunches(Int) numSLSNodes(Int) " |
518 | "slsKindStr(\"QuantizedWeighted\"|\"QuantizedUnweighted\"|" |
519 | "\"NonquantizedWeighted\"|" |
520 | "\"NonquantizedUnweighted\") " |
521 | "sortedStr(\"Sorted\"|\"Unsorted\") backendStr(String) " |
522 | "dtypeStr(\"Float16\"|\"Float32\") " |
523 | "addClipStr(\"True\"|\"False\")\nQuantized only options: " |
524 | "quantizationDtypeStr(\"Int8\"|\"Int4\") " |
525 | "useFP16AccumulationStr(\"True\"|\"False\") \n" |
526 | "Optional: dev_id(Int)\n" ); |
527 | printf("\n" ); |
528 | printf("Standard Glow command-line options may be passed via the GLOW_OPTS " |
529 | "environment variable\n" ); |
530 | benchParseGlowOpts(argc, argv); |
531 | |
532 | std::vector<SLSParam> params; |
533 | std::string ; |
534 | std::string runPrefix; |
535 | |
536 | // Using a config file |
537 | if (argc == 2) { |
538 | auto fname = std::string(argv[1]); |
539 | std::ifstream fin(fname.c_str()); |
540 | if (!fin) { |
541 | std::cout << "Could not open file: " << fname << std::endl; |
542 | exit(0); |
543 | } |
544 | std::string line; |
545 | while (getline(fin, line)) { |
546 | std::array<char, 1024> buf; |
547 | char *saveptr = nullptr; |
548 | std::vector<char *> argVec; |
549 | strcpy(buf.data(), line.c_str()); |
550 | char *ptr = strtok_r(buf.data(), " " , &saveptr); |
551 | while (ptr != nullptr) { |
552 | argVec.push_back(ptr); |
553 | ptr = strtok_r(nullptr, " " , &saveptr); |
554 | } |
555 | SLSParam param = parseArgs(argVec.size(), argVec.data()); |
556 | params.push_back(param); |
557 | runHeader = std::string("_,benchName,_,filename" ); |
558 | runPrefix = std::string(strFormat("SLSBench,SW,%s" , fname.c_str())); |
559 | } |
560 | } |
561 | // Using command line |
562 | else if (argc == 14 || argc == 15 || argc == 16 || argc == 17 || argc == 18) { |
563 | SLSParam param = parseArgs(argc, argv); |
564 | params.push_back(param); |
565 | |
566 | runHeader = std::string( |
567 | "_,benchName,_,batchSize,numIndicesPerBatchMin:numIndicesPerBatchMax," |
568 | "numIndicesPerBatchPad,numTableEntries,numElementsPerRow,numReps," |
569 | "numAsyncLaunches,numSLSNodes,slsKindStr,backendStr,dtypeStr," |
570 | "addClipStr,quantizationDtypeStr,useFP16AccumulationStr" ); |
571 | runPrefix = std::string(strFormat( |
572 | "SLSBench,SW,%zu,%zu:%zu,%zu,%zu,%zu,%zu,%zu,%zu,%s,%s,%s,%s,%s,%s,%s" , |
573 | (size_t)param.batchSize, (size_t)param.numIndicesPerBatchMin, |
574 | (size_t)param.numIndicesPerBatchMax, |
575 | (size_t)param.numIndicesPerBatchPad, (size_t)param.numTableEntries, |
576 | (size_t)param.numElementsPerRow, (size_t)param.numReps, |
577 | (size_t)param.numAsyncLaunches, (size_t)param.numSLSNodes, argv[9], |
578 | argv[10], argv[11], argv[12], argv[13], argv[14], argv[15])); |
579 | } else { |
580 | llvm_unreachable("Invalid command line" ); |
581 | } |
582 | |
583 | SLSParam param = params.front(); |
584 | SLSBench b(param.batchSize, param.numAsyncLaunches, param.backendStr, params, |
585 | param.convertFusedToFP32, param.devId); |
586 | auto times = bench(&b, param.numReps); |
587 | |
588 | printf("%s,runtime,gbytesPerSec\n" , runHeader.c_str()); |
589 | for (auto t : times) { |
590 | printf("BenchResult,%s,%f,%f\n" , runPrefix.c_str(), |
591 | t / param.numAsyncLaunches, b.gbytes() * param.numAsyncLaunches / t); |
592 | } |
593 | double min = *(std::min_element(times.begin(), times.end())); |
594 | dim_t midElt = times.size() / 2; |
595 | std::nth_element(times.begin(), times.begin() + midElt, times.end()); |
596 | double median = times[midElt]; |
597 | double medianRuntime = median / ((double)param.numAsyncLaunches); |
598 | double minRuntime = min / ((double)param.numAsyncLaunches); |
599 | printf("%s,medianRuntime,minRuntime,medianGbytesPerSec,maxGbytesPerSec\n" , |
600 | runHeader.c_str()); |
601 | printf("BenchSummary,%s,%f,%f,%f,%f\n" , runPrefix.c_str(), medianRuntime, |
602 | minRuntime, b.gbytes() / medianRuntime, b.gbytes() / minRuntime); |
603 | } |
604 | |