/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <algorithm>
#include <array>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <future>
#include <iostream>
#include <random>
#include <string>

#include "Bench.h"

#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"

using namespace glow;
/*
 * This class implements a Gather microbenchmark.
 *
 * Microbenchmarks are generally useful for understanding performance
 * through targeted experimentation and are not representative of
 * end-to-end workloads.
 */

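// Parameters describing one Gather benchmark configuration, parsed from the
// command line or from a line of a config file.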
struct GatherParam {
  dim_t numReps;
  dim_t numAsyncLaunches;
  std::string backendStr;
  std::string devId;
  dim_t numIndices;
  dim_t numTableEntries;
  dim_t numElementsPerRow;
  dim_t numGatherNodes;
  bool isSorted;
  ElemKind dtype;
};

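// Returns a short name for a Gather configuration; e.g. (with hypothetical
// values numIndices=2000, numTableEntries=1000000, numElementsPerRow=64) it
// produces "Gather_2000_1000000_64".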
std::string getGatherDescription(GatherParam param) {
  std::string GatherStr = std::string("Gather");

  return strFormat("%s_%zu_%zu_%zu", GatherStr.c_str(),
                   (size_t)param.numIndices, (size_t)param.numTableEntries,
                   (size_t)param.numElementsPerRow);
}

class GatherBench : public Benchmark {
  std::unique_ptr<runtime::HostManager> hostManager_;
  std::vector<std::unique_ptr<ExecutionContext>> contexts_;
  std::vector<std::vector<Tensor>> indicesReal_;
  dim_t asyncLaunchSize_;
  std::string backendStr_;
  std::vector<GatherParam> params_;
  std::string devId_;

public:
  GatherBench(dim_t asyncLaunchSize_, std::string backendStr_,
              std::vector<GatherParam> params_,
              std::string devId_ = std::string(""))
      : asyncLaunchSize_(asyncLaunchSize_), backendStr_(backendStr_),
        params_(params_), devId_(devId_) {}

  double countGatherInputGbytes(GatherParam param) const {

    // Bytes per element: 2 for Float16, 4 for Float32.
    dim_t elementSize = 2;
    if (param.dtype == ElemKind::FloatTy) {
      elementSize = 4;
    }

    // Embedding data.
    double input_gbytes = 0.0;
    input_gbytes += (param.numGatherNodes * param.numIndices *
                     (param.numElementsPerRow * elementSize)) /
                    1e9;

    // + Indices.
    input_gbytes +=
        (param.numGatherNodes * param.numIndices * sizeof(int32_t)) / 1e9;

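    // For example (hypothetical values): with one Gather node, 2000 indices,
    // and 64 Float16 elements per row, the input traffic is
    // 2000 * 64 * 2 + 2000 * 4 = 264,000 bytes, i.e. ~0.000264 GB per launch.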
    return input_gbytes;
  }

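  // Adds a single Gather node to \p fn that gathers rows of a constant data
  // table at randomly generated (optionally sorted) indices, then slices and
  // saves the result. Also creates per-context index tensors and output
  // bindings.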
  void addGatherNode(std::unique_ptr<Module> &mod, Function *fn,
                     GatherParam param) {

    // Input data is non-quantized and constant.
    Tensor dataConstantTensor(param.dtype,
                              {param.numTableEntries, param.numElementsPerRow});
    if (param.dtype == ElemKind::FloatTy) {
      dataConstantTensor.getHandle<float>().clear(1.0f);
    } else {
      dataConstantTensor.getHandle<float16_t>().clear(1.0f);
    }
    Constant *dataConstant =
        mod->createConstant("GatherData", dataConstantTensor);

    auto *indices = mod->createPlaceholder(ElemKind::Int32ITy,
                                           {param.numIndices}, "indices",
                                           /* isTrainable */ false);

    for (dim_t i = 0; i < asyncLaunchSize_; i++) {

      // Create and optionally sort the indices.
      Tensor indicesReal(ElemKind::Int32ITy, {param.numIndices});
      indicesReal.getHandle<int32_t>().randomize(0, param.numTableEntries - 1,
                                                 mod->getPRNG());
      if (param.isSorted) {
        int32_t *indicesRealPtr = (int32_t *)indicesReal.getUnsafePtr();
        std::sort(indicesRealPtr, indicesRealPtr + param.numIndices);
      }
      indicesReal_[i].push_back(std::move(indicesReal));

      // Create an unowned view of the stored indices for this context.
      Tensor indicesPartial(indicesReal_[i].back().getUnsafePtr(),
                            indices->getType(),
                            indicesReal_[i].back().getSizeInBytes());

      contexts_[i]->getPlaceholderBindings()->insert(indices,
                                                     std::move(indicesPartial));

    } // i

    // Create the Gather node, then slice and save its output.
    Node *R = fn->createGather(getGatherDescription(param), dataConstant,
                               indices, 0);
    SliceNode *SN =
        fn->createSlice("slice", R, {0, 0}, {1, param.numElementsPerRow});
    SaveNode *S = fn->createSave("save", SN);

    // For each context, allocate the output bindings.
    for (dim_t i = 0; i < asyncLaunchSize_; i++) {
      contexts_[i]->getPlaceholderBindings()->allocate(S->getPlaceholder());
    }
  }

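  // Creates the execution contexts, builds one function containing all
  // requested Gather nodes, and compiles the module through the HostManager.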
  void setup() override {

    // Create execution contexts.
    for (dim_t i = 0; i < asyncLaunchSize_; i++) {
      std::unique_ptr<ExecutionContext> context(new ExecutionContext);
      contexts_.push_back(std::move(context));
    }

    // Set up the host manager.
    std::vector<std::unique_ptr<runtime::DeviceConfig>> configs;
    auto config = glow::make_unique<runtime::DeviceConfig>(backendStr_.c_str());
    if (!devId_.empty()) {
      config->parameters["DeviceID"] = devId_.c_str();
    }
    configs.push_back(std::move(config));
    hostManager_ = glow::make_unique<runtime::HostManager>(std::move(configs));

    // Create a function.
    std::unique_ptr<Module> mod(new Module);
    auto fn = mod->createFunction("singleNode");

    // Keep the index tensors alive so they aren't deleted.
    indicesReal_.resize(asyncLaunchSize_);

    // Add Gather nodes.
    for (auto &param : params_) {
      for (dim_t i = 0; i < param.numGatherNodes; i++) {
        addGatherNode(mod, fn, param);
      }
    }

    fn->dumpDAG("gatherbench.dot");
    CompilationContext ctx;
    EXIT_ON_ERR(hostManager_->addNetwork(std::move(mod), ctx));
  }

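  // Launches asyncLaunchSize_ concurrent requests through the HostManager and
  // waits for all of them to complete.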
  void run() override {
    std::vector<std::unique_ptr<ExecutionContext>> localContexts(
        asyncLaunchSize_);
    std::vector<std::promise<void>> promises(asyncLaunchSize_);
    std::vector<std::future<void>> futures;

    // Launch a number of independent requests.
    int i = 0;
    for (auto &promise : promises) {
      futures.push_back(promise.get_future());
      hostManager_->runNetwork(
          "singleNode", std::move(contexts_[i]),
          [&localContexts, &promise,
           i](runtime::RunIdentifierTy, Error err,
              std::unique_ptr<ExecutionContext> contextPtr) {
            EXIT_ON_ERR(std::move(err));
            localContexts[i] = std::move(contextPtr);
            promise.set_value();
          });
      i++;
    }
    for (auto &fut : futures) {
      fut.wait();
    }
    for (dim_t j = 0; j < asyncLaunchSize_; j++) {
      contexts_[j] = std::move(localContexts[j]);
    }
  }

  void teardown() override {}

  double inputgbytes() const {
    double total_in = 0.0;
    for (auto &param : params_) {
      total_in += countGatherInputGbytes(param);
    }
    return total_in;
  }
};

// Index of the optional device-ID argument in argv.
#define DEVICE_ID 10

GatherParam parseArgs(int argc, char *argv[]) {
  GatherParam param;
  param.numIndices = atoi(argv[1]);
  param.numTableEntries = atoi(argv[2]);
  param.numElementsPerRow = atoi(argv[3]);
  param.numReps = atoi(argv[4]);
  param.numAsyncLaunches = atoi(argv[5]);
  param.numGatherNodes = atoi(argv[6]);
  printf("numIndices %zu\n", (size_t)param.numIndices);
  printf("numTableEntries %zu\n", (size_t)param.numTableEntries);
  printf("numElementsPerRow %zu\n", (size_t)param.numElementsPerRow);
  printf("numReps %zu\n", (size_t)param.numReps);
  printf("numAsyncLaunches %zu\n", (size_t)param.numAsyncLaunches);
  printf("numGatherNodes %zu\n", (size_t)param.numGatherNodes);
  printf("sortedStr %s\n", argv[7]);
  if (std::string(argv[7]) == "Sorted") {
    param.isSorted = true;
  } else if (std::string(argv[7]) == "Unsorted") {
    param.isSorted = false;
  } else {
    llvm_unreachable("Invalid sortedStr");
  }
  printf("backendStr %s\n", argv[8]);
  param.backendStr = std::string(argv[8]);
  printf("dtypeStr %s\n", argv[9]);
  if (std::string(argv[9]) == "Float16") {
    param.dtype = ElemKind::Float16Ty;
  } else if (std::string(argv[9]) == "Float32") {
    param.dtype = ElemKind::FloatTy;
  } else {
    llvm_unreachable("Invalid dtype");
  }
  if (argc > DEVICE_ID) {
    printf("devId %s\n", argv[DEVICE_ID]);
    param.devId = std::string(argv[DEVICE_ID]);
  } else {
    param.devId = std::string("");
  }
  printf("\n\n");
  return param;
}

int main(int argc, char *argv[]) {

  printf("Gather Microbenchmark\n");
  printf("Usage: GatherBench numIndices(Int) "
         "numTableEntries(Int) "
         "numElementsPerRow(Int) numReps(Int) "
         "numAsyncLaunches(Int) numGatherNodes(Int) "
         "sortedStr(\"Sorted\"|\"Unsorted\") backendStr(String) "
         "dtypeStr(\"Float16\"|\"Float32\") "
         "devId(Int)\n");
  printf("\n");
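  // Example invocation (hypothetical values): gather 2000 indices per launch
  // from a 1,000,000-row table with 64 Float16 elements per row, 10 reps,
  // 2 async launches, 1 Gather node, sorted indices, Interpreter backend:
  //   GatherBench 2000 1000000 64 10 2 1 Sorted Interpreter Float16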

  std::vector<GatherParam> params;
  std::string runHeader;
  std::string runPrefix;

  // Using a config file.
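  // Each line of the config file is a full argument list in the same order as
  // the command line, including a leading program-name token, e.g. a
  // hypothetical line:
  //   GatherBench 2000 1000000 64 10 2 1 Sorted Interpreter Float16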
  if (argc == 2) {
    auto fname = std::string(argv[1]);
    std::ifstream fin(fname.c_str());
    if (!fin) {
      std::cout << "Could not open file: " << fname << std::endl;
      exit(0);
    }
    std::string line;
    while (getline(fin, line)) {
      std::array<char, 1024> buf;
      char *saveptr = nullptr;
      std::vector<char *> argVec;
      strcpy(buf.data(), line.c_str());
      char *ptr = strtok_r(buf.data(), " ", &saveptr);
      while (ptr != nullptr) {
        argVec.push_back(ptr);
        ptr = strtok_r(nullptr, " ", &saveptr);
      }
      GatherParam param = parseArgs(argVec.size(), argVec.data());
      params.push_back(param);
      runHeader = std::string("_,benchName,_,filename");
      runPrefix = std::string(strFormat("GatherBench,SW,%s", fname.c_str()));
    }
  }
  // Using the command line.
  else if (argc == 10 || argc == 11) {
    GatherParam param = parseArgs(argc, argv);
    params.push_back(param);

    runHeader =
        std::string("_,benchName,_,numIndices,"
                    "numTableEntries,"
                    "numElementsPerRow,numReps,numAsyncLaunches,numGatherNodes,"
                    "sorted,backendStr,dtypeStr");
    runPrefix = std::string(
        strFormat("GatherBench,SW,%zu,%zu,%zu,%zu,%zu,%zu,%s,%s,%s",
                  (size_t)param.numIndices, (size_t)param.numTableEntries,
                  (size_t)param.numElementsPerRow, (size_t)param.numReps,
                  (size_t)param.numAsyncLaunches, (size_t)param.numGatherNodes,
                  argv[7], argv[8], argv[9]));
  } else {
    llvm_unreachable("Invalid command line");
  }

  GatherParam param = params.front();
  GatherBench b(param.numAsyncLaunches, param.backendStr, params, param.devId);
  auto times = bench(&b, param.numReps);

  printf("%s,runtime,gbytesPerSec\n", runHeader.c_str());
  for (auto t : times) {
    printf("BenchResult,%s,%f,%f\n", runPrefix.c_str(),
           t / param.numAsyncLaunches,
           b.inputgbytes() * param.numAsyncLaunches / t);
  }
  double min = *(std::min_element(times.begin(), times.end()));
  dim_t midElt = times.size() / 2;
  std::nth_element(times.begin(), times.begin() + midElt, times.end());
  double median = times[midElt];
  double medianRuntime = median / ((double)param.numAsyncLaunches);
  double minRuntime = min / ((double)param.numAsyncLaunches);
  printf("%s,medianRuntime,minRuntime,"
         "medianGbytesPerSec,maxGbytesPerSec\n",
         runHeader.c_str());
  printf("BenchSummary,%s,%f,%f,%f,%f\n", runPrefix.c_str(), medianRuntime,
         minRuntime, b.inputgbytes() / medianRuntime,
         b.inputgbytes() / minRuntime);
}