1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
#include <algorithm>
#include <array>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <future>
#include <iostream>
#include <random>

#include "Bench.h"

#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"

#include "tests/unittests/BackendTestUtils.h"
28 | |
29 | using namespace glow; |
30 | |
31 | /* |
32 | * This class implements a GEMM/FC microbenchmark. There are a set of |
33 | * (m x k) * (k x n) = (m x n) matrix multiplications, chained together in |
34 | * multiple layers. |
35 | * |
36 | * Microbenchmarks are generally useful for understanding performance |
 * through targeted experimentation and are not representative of
38 | * end-to-end workloads. |
39 | */ |
40 | |
/// Command-line option category grouping all GemmBench-specific flags.
llvm::cl::OptionCategory GemmBenchCat("GemmBench Category" );
/// When set, the network is also compiled and run on the Interpreter backend
/// and the two "output" tensors are compared after every dispatch.
llvm::cl::opt<bool> checkCorrectness(
    "check-results" ,
    llvm::cl::desc("Check the correctness of the results against the reference "
                   "backend (Interpreter)" ),
    llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(GemmBenchCat));
/// When set, requests serialization of the compiled DAG
/// (CompilationContext::serializeCompiledDAG) during setup.
llvm::cl::opt<bool> dumpOnnx("dump_onnx" ,
                             llvm::cl::desc("dump onnx text format for model" ),
                             llvm::cl::Optional, llvm::cl::init(false),
                             llvm::cl::cat(GemmBenchCat));
51 | |
/// Parameters describing one GEMM/FC benchmark configuration:
/// a chain of (m x k) * (k x n) = (m x n) FullyConnected layers.
struct GemmParam {
  dim_t m_;                // Rows of the input / output matrix.
  dim_t n_;                // Columns of the weights / output matrix.
  dim_t k_;                // Inner (reduction) dimension.
  dim_t numLayers_;        // Number of chained FC layers.
  dim_t numReps_;          // Number of timed benchmark repetitions.
  dim_t numAsyncLaunches_; // Concurrent inference requests per dispatch.
  dim_t numSplits_;        // Vertical weight splits (applied when > 1).
  std::string backendStr_; // Name of the backend to benchmark.
  std::string devId_;      // Optional device id; empty when unspecified.
  ElemKind dtype_;         // Element type: FloatTy or Float16Ty only.
};
64 | |
65 | class GemmBench : public Benchmark { |
66 | GemmParam param_; |
67 | ExecutionContext context_; |
68 | PlaceholderBindings &bindings_; |
69 | std::unique_ptr<runtime::HostManager> hostManager_; |
70 | |
71 | // Refernce bindings and network: |
72 | ExecutionContext refContext_; |
73 | PlaceholderBindings &refBindings_; |
74 | std::unique_ptr<runtime::HostManager> refHostManager_; |
75 | |
76 | public: |
77 | explicit GemmBench(GemmParam param_) |
78 | : param_(param_), bindings_(*context_.getPlaceholderBindings()), |
79 | refBindings_(*refContext_.getPlaceholderBindings()) {} |
80 | |
81 | void addGemmNode(std::unique_ptr<Module> &mod, Function *fn, GemmParam param, |
82 | bool isRef) { |
83 | PlaceholderBindings &bindings = isRef ? refBindings_ : bindings_; |
84 | auto *input = mod->createPlaceholder(param.dtype_, {param.m_, param.k_}, |
85 | "input" , false); |
86 | if (param.dtype_ == ElemKind::Float16Ty) { |
87 | bindings.allocate(input)->getHandle<float16>().randomize(-1.f, 1.f, |
88 | mod->getPRNG()); |
89 | } else { |
90 | assert(param.dtype_ == ElemKind::FloatTy); |
91 | bindings.allocate(input)->getHandle<float>().randomize(-1.f, 1.f, |
92 | mod->getPRNG()); |
93 | } |
94 | auto *output = mod->createPlaceholder(param.dtype_, {param.m_, param.n_}, |
95 | "output" , false); |
96 | bindings.allocate(output); |
97 | Node *cur = input; |
98 | |
99 | Placeholder *ones; |
100 | if (param.k_ > param.n_) { |
101 | ones = mod->createPlaceholder( |
102 | param.dtype_, {param.m_ * (param.k_ - param.n_)}, "ones" , false); |
103 | if (param.dtype_ == ElemKind::Float16Ty) { |
104 | bindings.allocate(ones)->getHandle<float16_t>().clear(1.0); |
105 | } else if (param.dtype_ == ElemKind::FloatTy) { |
106 | bindings.allocate(ones)->getHandle<float>().clear(1.0); |
107 | } |
108 | } |
109 | |
110 | Placeholder *weights; |
111 | Placeholder *bias; |
112 | |
113 | // Create multiple layers of FC nodes |
114 | for (size_t layer = 0; layer < param.numLayers_; layer++) { |
115 | weights = |
116 | mod->createPlaceholder(param.dtype_, {param.k_, param.n_}, |
117 | "weights" + std::to_string(layer), false); |
118 | bias = mod->createPlaceholder(param.dtype_, {param.n_}, |
119 | "bias" + std::to_string(layer), false); |
120 | |
121 | if (param.dtype_ == ElemKind::Float16Ty) { |
122 | bindings.allocate(weights)->getHandle<float16_t>().randomize( |
123 | -1.f, 1.f, mod->getPRNG()); |
124 | bindings.allocate(bias)->getHandle<float16_t>().clear(32); |
125 | } else if (param.dtype_ == ElemKind::FloatTy) { |
126 | bindings.allocate(weights)->getHandle<float>().randomize( |
127 | -1.f, 1.f, mod->getPRNG()); |
128 | bindings.allocate(bias)->getHandle<float>().clear(32); |
129 | } |
130 | |
131 | Node *fc; |
132 | fc = fn->createFullyConnected("fc_" + std::to_string(layer), cur, weights, |
133 | bias); |
134 | cur = fc; |
135 | |
136 | // Handle non-square cases |
137 | if (param.k_ > param.n_ && layer < (param.numLayers_ - 1)) { |
138 | Node *reshape1 = fn->createReshape("reshape1_" + std::to_string(layer), |
139 | fc, {param.m_ * param.n_}); |
140 | Node *concat = fn->createConcat("concat_" + std::to_string(layer), |
141 | {reshape1, ones}, 0); |
142 | Node *reshape2 = fn->createReshape("reshape2_" + std::to_string(layer), |
143 | concat, {param.m_, param.k_}); |
144 | cur = reshape2; |
145 | } else if (param.k_ < param.n_ && layer < (param.numLayers_ - 1)) { |
146 | Node *slice = fn->createSlice("slice_" + std::to_string(layer), fc, |
147 | {0, 0}, {param.m_, param.k_}); |
148 | cur = slice; |
149 | } |
150 | } |
151 | fn->createSave("save1" , cur, output); |
152 | ::glow::convertPlaceholdersToConstants(fn, bindings, {input, output}); |
153 | } |
154 | |
155 | void setupInternal(bool isRef) { |
156 | // Setup host manager |
157 | std::string backendStr = isRef ? "Interpreter" : param_.backendStr_.c_str(); |
158 | std::vector<std::unique_ptr<runtime::DeviceConfig>> configs; |
159 | auto config = glow::make_unique<runtime::DeviceConfig>(backendStr.c_str()); |
160 | if (param_.devId_ != "" ) { |
161 | config->parameters["DeviceID" ] = param_.devId_.c_str(); |
162 | } |
163 | configs.push_back(std::move(config)); |
164 | if (isRef) { |
165 | refHostManager_ = |
166 | glow::make_unique<runtime::HostManager>(std::move(configs)); |
167 | } else { |
168 | hostManager_ = |
169 | glow::make_unique<runtime::HostManager>(std::move(configs)); |
170 | } |
171 | |
172 | std::unique_ptr<Module> mod(new Module); |
173 | auto fn = mod->createFunction("singleNode" ); |
174 | |
175 | addGemmNode(mod, fn, param_, isRef); |
176 | |
177 | // Split weights |
178 | if (param_.numSplits_ > 1) { |
179 | executeVerticalFCWeightsSplit(fn, param_.numSplits_, param_.n_); |
180 | } |
181 | |
182 | CompilationContext ctx; |
183 | ctx.dumpFinalGraph = true; |
184 | ctx.serializeCompiledDAG = dumpOnnx; |
185 | if (isRef) { |
186 | EXIT_ON_ERR(refHostManager_->addNetwork(std::move(mod), ctx)); |
187 | } else { |
188 | EXIT_ON_ERR(hostManager_->addNetwork(std::move(mod), ctx)); |
189 | } |
190 | } |
191 | |
192 | void checkOutput() { |
193 | // First run on the reference backend |
194 | dispatchInference("singleNode" , refHostManager_.get(), refContext_, |
195 | param_.numAsyncLaunches_, |
196 | /*useNewExecutionContext*/ true); |
197 | Tensor *refTensor = |
198 | refBindings_.get(refBindings_.getPlaceholderByNameSlow("output" )); |
199 | CHECK(refTensor) << "Reference Tensor not found" ; |
200 | |
201 | Tensor *noRefTensor = |
202 | bindings_.get(bindings_.getPlaceholderByNameSlow("output" )); |
203 | CHECK(noRefTensor) << "non-reference Tensor not found" ; |
204 | |
205 | // Compare the tensors |
206 | if (!noRefTensor->isEqual(*refTensor)) { |
207 | noRefTensor->dump(); |
208 | refTensor->dump(); |
209 | LOG(FATAL) << "Tensors don't match\n" ; |
210 | } else { |
211 | LOG(INFO) << "Tensors match\n" ; |
212 | } |
213 | } |
214 | |
215 | void setup() override { |
216 | if (checkCorrectness) { |
217 | setupInternal(/* isRef */ true); |
218 | } |
219 | setupInternal(/* isRef */ false); |
220 | } |
221 | void run() override { |
222 | dispatchInference("singleNode" , hostManager_.get(), context_, |
223 | param_.numAsyncLaunches_, |
224 | /*useNewExecutionContext*/ true); |
225 | if (checkCorrectness) { |
226 | checkOutput(); |
227 | } |
228 | } |
229 | |
230 | void teardown() override {} |
231 | |
232 | double gflops() const { |
233 | return 2.0 * param_.m_ * param_.n_ * param_.k_ * param_.numLayers_ / 1e9; |
234 | } |
235 | }; |
236 | |
237 | #define DEVICE_ID 10 |
238 | |
239 | GemmParam parseArgs(int argc, char *argv[]) { |
240 | GemmParam param; |
241 | |
242 | param.m_ = atoi(argv[1]); |
243 | param.n_ = atoi(argv[2]); |
244 | param.k_ = atoi(argv[3]); |
245 | param.numLayers_ = atoi(argv[4]); |
246 | param.numReps_ = atoi(argv[5]); |
247 | param.numAsyncLaunches_ = atoi(argv[6]); |
248 | param.numSplits_ = atoi(argv[7]); |
249 | param.backendStr_ = std::string(argv[8]); |
250 | if (std::string(argv[9]) == "Float16" ) { |
251 | param.dtype_ = ElemKind::Float16Ty; |
252 | } else if (std::string(argv[9]) == "Float32" ) { |
253 | param.dtype_ = ElemKind::FloatTy; |
254 | } else { |
255 | llvm_unreachable("Invalid dtype" ); |
256 | } |
257 | |
258 | printf("m %zu\n" , (size_t)param.m_); |
259 | printf("n %zu\n" , (size_t)param.n_); |
260 | printf("k %zu\n" , (size_t)param.k_); |
261 | printf("numLayers %zu\n" , (size_t)param.numLayers_); |
262 | printf("numReps %zu\n" , (size_t)param.numReps_); |
263 | printf("numAsyncLaunches %zu\n" , (size_t)param.numAsyncLaunches_); |
264 | printf("numSplits %zu\n" , (size_t)param.numSplits_); |
265 | printf("backendStr %s\n" , param.backendStr_.c_str()); |
266 | printf("dtypeStr %s\n" , argv[9]); |
267 | |
268 | if (argc > DEVICE_ID) { |
269 | printf("devId %s\n" , argv[DEVICE_ID]); |
270 | param.devId_ = std::string(argv[DEVICE_ID]); |
271 | } else { |
272 | param.devId_ = std::string("" ); |
273 | } |
274 | printf("\n\n" ); |
275 | return param; |
276 | } |
277 | |
278 | int main(int argc, char *argv[]) { |
279 | printf("GEMM Microbenchmark\n" ); |
280 | printf("Usage: GemmBench m(Int) n(Int) k(Int) numLayers(Int) numReps(Int) " |
281 | "numAsyncLaunches(Int) numSplits(Int) backendStr(String) " |
282 | "dtypeStr(\"Float16\"|\"Float32\") dev_id(Int)\n" ); |
283 | printf("Standard Glow command-line options may be passed via the GLOW_OPTS " |
284 | "environment variable\n" ); |
285 | benchParseGlowOpts(argc, argv); |
286 | |
287 | std::vector<GemmParam> params; |
288 | std::string ; |
289 | std::string runPrefix; |
290 | |
291 | // Using a config file |
292 | if (argc == 2) { |
293 | auto fname = std::string(argv[1]); |
294 | std::ifstream fin(fname.c_str()); |
295 | if (!fin) { |
296 | std::cout << "Could not open file: " << fname << std::endl; |
297 | exit(0); |
298 | } |
299 | std::string line; |
300 | while (getline(fin, line)) { |
301 | std::array<char, 1024> buf; |
302 | char *saveptr = nullptr; |
303 | std::vector<char *> argVec; |
304 | strcpy(buf.data(), line.c_str()); |
305 | char *ptr = strtok_r(buf.data(), " " , &saveptr); |
306 | while (ptr != nullptr) { |
307 | argVec.push_back(ptr); |
308 | ptr = strtok_r(nullptr, " " , &saveptr); |
309 | } |
310 | GemmParam param = parseArgs(argVec.size(), argVec.data()); |
311 | params.push_back(param); |
312 | runHeader = std::string("_,benchName,_,filename" ); |
313 | runPrefix = std::string(strFormat("GemmBench,SW,%s" , fname.c_str())); |
314 | } |
315 | } else if (argc == 10 || argc == 11) { |
316 | GemmParam param = parseArgs(argc, argv); |
317 | params.push_back(param); |
318 | runHeader = std::string( |
319 | "_,benchName,_,m,n,k,numLayers,numReps,numAsyncLaunches,numSplits," |
320 | "backendStr,dtypeStr\n" ); |
321 | runPrefix = std::string(strFormat( |
322 | "GemmBench,SW,%zu,%zu,%zu,%zu,%zu,%zu,%zu,%s,%s" , (size_t)param.m_, |
323 | (size_t)param.n_, (size_t)param.k_, (size_t)param.numLayers_, |
324 | (size_t)param.numReps_, (size_t)param.numAsyncLaunches_, |
325 | (size_t)param.numSplits_, argv[8], argv[9])); |
326 | } else { |
327 | llvm_unreachable("Invalid command line" ); |
328 | } |
329 | |
330 | for (auto param : params) { |
331 | GemmBench b(param); |
332 | auto times = bench(&b, param.numReps_); |
333 | |
334 | printf("%s,runtime,gflopPerSec\n" , runHeader.c_str()); |
335 | for (auto t : times) { |
336 | printf("BenchResult,%s,%f,%f\n" , runPrefix.c_str(), |
337 | t / param.numAsyncLaunches_, |
338 | b.gflops() * param.numAsyncLaunches_ / t); |
339 | } |
340 | double min = *(std::min_element(times.begin(), times.end())); |
341 | dim_t midElt = times.size() / 2; |
342 | std::nth_element(times.begin(), times.begin() + midElt, times.end()); |
343 | double median = times[midElt]; |
344 | double medianRuntime = median / ((double)param.numAsyncLaunches_); |
345 | double minRuntime = min / ((double)param.numAsyncLaunches_); |
346 | printf("%s,medianRuntime,minRuntime,medianGflopPerSec,maxGflopPerSec\n" , |
347 | runHeader.c_str()); |
348 | printf("BenchSummary,%s,%f,%f,%f,%f\n" , runPrefix.c_str(), medianRuntime, |
349 | minRuntime, b.gflops() / medianRuntime, b.gflops() / minRuntime); |
350 | } |
351 | } |
352 | |