1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
#include <algorithm>
#include <array>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <future>
#include <iostream>
#include <random>
#include <string>

#include "Bench.h"

#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
28 | |
29 | using namespace glow; |
30 | |
/*
 * This class implements a Gather microbenchmark.
 *
 * Microbenchmarks are generally useful for understanding performance
 * through targeted experimentation and are not representative of
 * end-to-end workloads.
 */
38 | |
/// One benchmark configuration, parsed from the command line or from one
/// line of a config file (see parseArgs).
struct GatherParam {
  dim_t numReps;           // Number of timed benchmark repetitions.
  dim_t numAsyncLaunches;  // Concurrent inference requests per repetition.
  std::string backendStr;  // Backend name passed to the DeviceConfig.
  std::string devId;       // Optional device id; empty selects the default.
  dim_t numIndices;        // Indices gathered per Gather node.
  dim_t numTableEntries;   // Rows in the embedding table constant.
  dim_t numElementsPerRow; // Columns (elements) per embedding row.
  dim_t numGatherNodes;    // How many independent Gather nodes to build.
  bool isSorted;           // Whether the random indices are sorted ascending.
  ElemKind dtype;          // Element type of the table: FloatTy or Float16Ty.
};
51 | |
52 | std::string getGatherDescription(GatherParam param) { |
53 | std::string GatherStr = std::string("Gather" ); |
54 | |
55 | return strFormat("%s_%zu_%zu_%zu" , GatherStr.c_str(), |
56 | (size_t)param.numIndices, (size_t)param.numTableEntries, |
57 | (size_t)param.numElementsPerRow); |
58 | } |
59 | |
/// Benchmark harness that builds a network of Gather nodes and measures the
/// throughput of running it repeatedly through the Glow HostManager.
class GatherBench : public Benchmark {
  std::unique_ptr<runtime::HostManager> hostManager_; // Owns device + network.
  // One execution context per concurrent launch.
  std::vector<std::unique_ptr<ExecutionContext>> contexts_;
  // Backing storage for the index tensors, one vector per context. Kept
  // alive for the whole benchmark because the placeholder bindings alias
  // this memory via unowned tensors (see addGatherNode).
  std::vector<std::vector<Tensor>> indicesReal_;
  dim_t asyncLaunchSize_;           // Number of concurrent inference requests.
  std::string backendStr_;          // Backend name for the DeviceConfig.
  std::vector<GatherParam> params_; // One entry per benchmark configuration.
  std::string devId_;               // Optional device id; empty means default.

public:
  // NOTE(review): the constructor parameters deliberately shadow the members
  // (same trailing-underscore names); the mem-initializer list still binds
  // each member to the same-named parameter, so this is correct if unusual.
  GatherBench(dim_t asyncLaunchSize_, std::string backendStr_,
              std::vector<GatherParam> params_,
              std::string devId_ = std::string(""))
      : asyncLaunchSize_(asyncLaunchSize_), backendStr_(backendStr_),
        params_(params_), devId_(devId_) {}

  /// Returns the gigabytes of input read per iteration for \p param: the
  /// gathered embedding rows plus the int32 indices themselves.
  double countGatherInputGbytes(GatherParam param) const {

    // Bytes per element: 4 for Float32, otherwise 2 (the only other dtype
    // parseArgs accepts is Float16).
    dim_t elementSize = 2;
    if (param.dtype == ElemKind::FloatTy) {
      elementSize = 4;
    }

    // Embedding data.
    double input_gbytes = 0.0;
    input_gbytes += (param.numGatherNodes * param.numIndices *
                     (param.numElementsPerRow * elementSize)) /
                    1e9;

    // + Indices.
    input_gbytes +=
        (param.numGatherNodes * param.numIndices * sizeof(int32_t)) / 1e9;

    return input_gbytes;
  }

  /// Adds one Gather(+Slice+Save) subgraph described by \p param to \p fn and
  /// binds a freshly randomized index tensor for it in every context.
  void addGatherNode(std::unique_ptr<Module> &mod, Function *fn,
                     GatherParam param) {

    // Input data is non-quantized and Constant: a table filled with 1.0.
    Constant *dataConstant = nullptr;
    Tensor dataConstantTensor(param.dtype,
                              {param.numTableEntries, param.numElementsPerRow});
    if (param.dtype == ElemKind::FloatTy) {
      dataConstantTensor.getHandle<float>().clear(1.0f);
    } else {
      dataConstantTensor.getHandle<float16_t>().clear(1.0f);
    }
    dataConstant = mod->createConstant("GatherData", dataConstantTensor);

    auto *indices = mod->createPlaceholder(ElemKind::Int32ITy,
                                           {param.numIndices}, "indices",
                                           /* isTrainable */ false);

    // Per-context index data: each launch gets its own random indices.
    for (dim_t i = 0; i < asyncLaunchSize_; i++) {

      // Create and sort indices.
      Tensor indicesReal(ElemKind::Int32ITy, {param.numIndices});
      indicesReal.getHandle<int32_t>().randomize(0, param.numTableEntries - 1,
                                                 mod->getPRNG());
      // Sort each segment.
      if (param.isSorted) {
        int32_t *indicesRealPtr = (int32_t *)indicesReal.getUnsafePtr();
        std::sort(indicesRealPtr, indicesRealPtr + param.numIndices);
      }
      indicesReal_[i].push_back(std::move(indicesReal));

      // Unowned view over the tensor stored above; indicesReal_ must outlive
      // this binding (it is a member, so it does).
      Tensor indicesPartial(indicesReal_[i].back().getUnsafePtr(),
                            indices->getType(),
                            indicesReal_[i].back().getSizeInBytes());

      contexts_[i]->getPlaceholderBindings()->insert(indices,
                                                     std::move(indicesPartial));

    } // i

    // Create Gather node, then slice it and then save node. Slicing to one
    // row keeps the saved output (and its copy cost) tiny.
    Node *R = nullptr;
    R = fn->createGather(getGatherDescription(param), dataConstant, indices, 0);
    SliceNode *SN;
    SN = fn->createSlice("slice", R, {0, 0}, {1, param.numElementsPerRow});

    SaveNode *S = nullptr;
    S = fn->createSave("save", SN);

    // For each context, add output bindings.
    for (dim_t i = 0; i < asyncLaunchSize_; i++) {
      contexts_[i]->getPlaceholderBindings()->allocate(S->getPlaceholder());
    }
  }

  /// Builds the function (all Gather nodes for all params), configures the
  /// device/HostManager, and compiles the network.
  void setup() override {

    // Create execution contexts here.
    for (dim_t i = 0; i < asyncLaunchSize_; i++) {
      std::unique_ptr<ExecutionContext> context(new ExecutionContext);
      contexts_.push_back(std::move(context));
    }

    // Setup host manager.
    std::vector<std::unique_ptr<runtime::DeviceConfig>> configs;
    auto config = glow::make_unique<runtime::DeviceConfig>(backendStr_.c_str());
    if (!devId_.empty()) {
      config->parameters["DeviceID"] = devId_.c_str();
    }
    configs.push_back(std::move(config));
    hostManager_ = glow::make_unique<runtime::HostManager>(std::move(configs));

    // Create a function.
    std::unique_ptr<Module> mod(new Module);
    auto fn = mod->createFunction("singleNode");

    // Keep tensors around so they aren't deleted.
    indicesReal_.resize(asyncLaunchSize_);

    // Add Gather nodes.
    for (auto &param : params_) {
      for (dim_t i = 0; i < param.numGatherNodes; i++) {
        addGatherNode(mod, fn, param);
      }
    }

    // Dump the graph for offline inspection, then hand the module over to
    // the HostManager (which compiles it).
    fn->dumpDAG("gatherbench.dot");
    CompilationContext ctx;
    EXIT_ON_ERR(hostManager_->addNetwork(std::move(mod), ctx));
  }

  /// One timed iteration: launches asyncLaunchSize_ concurrent requests and
  /// waits for them all to complete.
  void run() override {
    std::vector<std::unique_ptr<ExecutionContext>> localContexts(
        asyncLaunchSize_);
    std::vector<std::promise<void>> promises(asyncLaunchSize_);
    std::vector<std::future<void>> futures;

    // Launch a number of independent requests. Each callback parks the
    // returned context in localContexts[i]; capturing `promise` by reference
    // is safe because `promises` outlives the waits below.
    int i = 0;
    for (auto &promise : promises) {
      futures.push_back(promise.get_future());
      hostManager_->runNetwork(
          "singleNode", std::move(contexts_[i]),
          [&localContexts, &promise,
           i](runtime::RunIdentifierTy, Error err,
              std::unique_ptr<ExecutionContext> contextPtr) {
            EXIT_ON_ERR(std::move(err));
            localContexts[i] = std::move(contextPtr);
            promise.set_value();
          });
      i++;
    }
    for (auto &fut : futures) {
      fut.wait();
    }
    // Restore the contexts (moved out above) so the next iteration can
    // reuse them and their bindings.
    for (dim_t j = 0; j < asyncLaunchSize_; j++) {
      contexts_[j] = std::move(localContexts[j]);
    }
  }

  void teardown() override {}

  /// Total input gigabytes read per iteration, summed over all configs.
  double inputgbytes() const {
    double total_in = 0.0;
    for (auto &param : params_) {
      total_in += countGatherInputGbytes(param);
    }
    return total_in;
  }
};
226 | |
227 | // Index of arguments. |
228 | #define DEVICE_ID 10 |
229 | |
230 | GatherParam parseArgs(int argc, char *argv[]) { |
231 | GatherParam param; |
232 | param.numIndices = atoi(argv[1]); |
233 | param.numTableEntries = atoi(argv[2]); |
234 | param.numElementsPerRow = atoi(argv[3]); |
235 | param.numReps = atoi(argv[4]); |
236 | param.numAsyncLaunches = atoi(argv[5]); |
237 | param.numGatherNodes = atoi(argv[6]); |
238 | printf("numIndices %zu\n" , (size_t)param.numIndices); |
239 | printf("numTableEntries %zu\n" , (size_t)param.numTableEntries); |
240 | printf("numElementsPerRow %zu\n" , (size_t)param.numElementsPerRow); |
241 | printf("numReps %zu\n" , (size_t)param.numReps); |
242 | printf("numAsyncLaunches %zu\n" , (size_t)param.numAsyncLaunches); |
243 | printf("numGatherNodes %zu\n" , (size_t)param.numGatherNodes); |
244 | printf("sortedStr %s\n" , argv[7]); |
245 | if (std::string(argv[7]) == "Sorted" ) { |
246 | param.isSorted = true; |
247 | } else if (std::string(argv[7]) == "Unsorted" ) { |
248 | param.isSorted = false; |
249 | } else { |
250 | llvm_unreachable("Invalid sortedStr" ); |
251 | } |
252 | printf("backendStr %s\n" , argv[8]); |
253 | param.backendStr = std::string(argv[8]); |
254 | printf("dtypeStr %s\n" , argv[9]); |
255 | if (std::string(argv[9]) == "Float16" ) { |
256 | param.dtype = ElemKind::Float16Ty; |
257 | } else if (std::string(argv[9]) == "Float32" ) { |
258 | param.dtype = ElemKind::FloatTy; |
259 | } else { |
260 | llvm_unreachable("Invalid dtype" ); |
261 | } |
262 | if (argc > DEVICE_ID) { |
263 | printf("devId %s\n" , argv[DEVICE_ID]); |
264 | param.devId = std::string(argv[DEVICE_ID]); |
265 | } else { |
266 | param.devId = std::string("" ); |
267 | } |
268 | printf("\n\n" ); |
269 | return param; |
270 | } |
271 | |
272 | int main(int argc, char *argv[]) { |
273 | |
274 | printf("Gather Microbenchmark\n" ); |
275 | printf("Usage: GatherBench numIndices(Int) " |
276 | "numTableEntries(Int) " |
277 | "numElementsPerRow(int) numReps(Int) " |
278 | "numAsyncLaunches(Int) numGatherNodes(Int) " |
279 | "sortedStr(\"Sorted\"|\"Unsorted\") backendStr(String) " |
280 | "dtypeStr(\"Float16\"|\"Float32\") " |
281 | "dev_id(Int)\n" ); |
282 | printf("\n" ); |
283 | |
284 | std::vector<GatherParam> params; |
285 | std::string ; |
286 | std::string runPrefix; |
287 | |
288 | // Using a config file. |
289 | if (argc == 2) { |
290 | auto fname = std::string(argv[1]); |
291 | std::ifstream fin(fname.c_str()); |
292 | if (!fin) { |
293 | std::cout << "Could not open file: " << fname << std::endl; |
294 | exit(0); |
295 | } |
296 | std::string line; |
297 | while (getline(fin, line)) { |
298 | std::array<char, 1024> buf; |
299 | char *saveptr = nullptr; |
300 | std::vector<char *> argVec; |
301 | strcpy(buf.data(), line.c_str()); |
302 | char *ptr = strtok_r(buf.data(), " " , &saveptr); |
303 | while (ptr != nullptr) { |
304 | argVec.push_back(ptr); |
305 | ptr = strtok_r(nullptr, " " , &saveptr); |
306 | } |
307 | GatherParam param = parseArgs(argVec.size(), argVec.data()); |
308 | params.push_back(param); |
309 | runHeader = std::string("_,benchName,_,filename" ); |
310 | runPrefix = std::string(strFormat("GatherBench,SW,%s" , fname.c_str())); |
311 | } |
312 | } |
313 | // Using command line. |
314 | else if (argc == 10 || argc == 11) { |
315 | GatherParam param = parseArgs(argc, argv); |
316 | params.push_back(param); |
317 | |
318 | runHeader = |
319 | std::string("_,benchName,_numIndices," |
320 | "numTableEntries," |
321 | "numElementsPerRow,numReps,numAsyncLaunches,numGatherNodes," |
322 | "sorted,backendStr,dtypeStr" ); |
323 | runPrefix = std::string( |
324 | strFormat("GatherBench,SW,%zu,%zu,%zu,%zu,%zu,%zu,%s,%s,%s" , |
325 | (size_t)param.numIndices, (size_t)param.numTableEntries, |
326 | (size_t)param.numElementsPerRow, (size_t)param.numReps, |
327 | (size_t)param.numAsyncLaunches, (size_t)param.numGatherNodes, |
328 | argv[7], argv[8], argv[9])); |
329 | } else { |
330 | llvm_unreachable("Invalid command line" ); |
331 | } |
332 | |
333 | GatherParam param = params.front(); |
334 | GatherBench b(param.numAsyncLaunches, param.backendStr, params, param.devId); |
335 | auto times = bench(&b, param.numReps); |
336 | |
337 | printf("%s,runtime, gbytesPerSec\n" , runHeader.c_str()); |
338 | for (auto t : times) { |
339 | printf("BenchResult,%s,%f,%f\n" , runPrefix.c_str(), |
340 | t / param.numAsyncLaunches, |
341 | b.inputgbytes() * param.numAsyncLaunches / t); |
342 | } |
343 | double min = *(std::min_element(times.begin(), times.end())); |
344 | dim_t midElt = times.size() / 2; |
345 | std::nth_element(times.begin(), times.begin() + midElt, times.end()); |
346 | double median = times[midElt]; |
347 | double medianRuntime = median / ((double)param.numAsyncLaunches); |
348 | double minRuntime = min / ((double)param.numAsyncLaunches); |
349 | printf("%s,medianRuntime,minRuntime," |
350 | "medianGbytesPerSec,maxGbytesPerSec\n" , |
351 | runHeader.c_str()); |
352 | printf("BenchSummary,%s,%f,%f,%f,%f\n" , runPrefix.c_str(), medianRuntime, |
353 | minRuntime, b.inputgbytes() / medianRuntime, |
354 | b.inputgbytes() / minRuntime); |
355 | } |
356 | |