1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include <array>
17#include <cstdlib>
18#include <fstream>
19#include <future>
20#include <random>
21
22#include "Bench.h"
23
24#include "glow/ExecutionEngine/ExecutionEngine.h"
25#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
26
27#include "tests/unittests/BackendTestUtils.h"
28
29using namespace glow;
30
31/*
32 * This class implements a GEMM/FC microbenchmark. There are a set of
33 * (m x k) * (k x n) = (m x n) matrix multiplications, chained together in
34 * multiple layers.
35 *
36 * Microbenchmarks are generally useful for understanding performance
 * through targeted experimentation and are not representative of
38 * end-to-end workloads.
39 */
40
// Command-line category grouping all GemmBench-specific options.
llvm::cl::OptionCategory GemmBenchCat("GemmBench Category");
// When set, every timed run is also executed on the Interpreter backend and
// the two "output" tensors are compared (see GemmBench::checkOutput).
llvm::cl::opt<bool> checkCorrectness(
    "check-results",
    llvm::cl::desc("Check the correctness of the results against the reference "
                   "backend (Interpreter)"),
    llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(GemmBenchCat));
// When set, the compiled DAG is serialized (ONNX text format) at compile time.
llvm::cl::opt<bool> dumpOnnx("dump_onnx",
                             llvm::cl::desc("dump onnx text format for model"),
                             llvm::cl::Optional, llvm::cl::init(false),
                             llvm::cl::cat(GemmBenchCat));
51
/// One benchmark configuration, parsed from the command line (or from one
/// line of a config file) by parseArgs().
struct GemmParam {
  dim_t m_;                // Rows of the input (and of each layer's output).
  dim_t n_;                // Output columns of each FC layer.
  dim_t k_;                // Inner (reduction) dimension of each FC layer.
  dim_t numLayers_;        // Number of chained FC layers.
  dim_t numReps_;          // Number of timed repetitions.
  dim_t numAsyncLaunches_; // Concurrent inference requests per repetition.
  dim_t numSplits_;        // Vertical splits applied to each FC's weights.
  std::string backendStr_; // Backend name, e.g. "Interpreter" or "CPU".
  std::string devId_;      // Optional device id ("" when unspecified).
  ElemKind dtype_;         // Element type: Float16Ty or FloatTy.
};
64
65class GemmBench : public Benchmark {
66 GemmParam param_;
67 ExecutionContext context_;
68 PlaceholderBindings &bindings_;
69 std::unique_ptr<runtime::HostManager> hostManager_;
70
71 // Refernce bindings and network:
72 ExecutionContext refContext_;
73 PlaceholderBindings &refBindings_;
74 std::unique_ptr<runtime::HostManager> refHostManager_;
75
76public:
77 explicit GemmBench(GemmParam param_)
78 : param_(param_), bindings_(*context_.getPlaceholderBindings()),
79 refBindings_(*refContext_.getPlaceholderBindings()) {}
80
81 void addGemmNode(std::unique_ptr<Module> &mod, Function *fn, GemmParam param,
82 bool isRef) {
83 PlaceholderBindings &bindings = isRef ? refBindings_ : bindings_;
84 auto *input = mod->createPlaceholder(param.dtype_, {param.m_, param.k_},
85 "input", false);
86 if (param.dtype_ == ElemKind::Float16Ty) {
87 bindings.allocate(input)->getHandle<float16>().randomize(-1.f, 1.f,
88 mod->getPRNG());
89 } else {
90 assert(param.dtype_ == ElemKind::FloatTy);
91 bindings.allocate(input)->getHandle<float>().randomize(-1.f, 1.f,
92 mod->getPRNG());
93 }
94 auto *output = mod->createPlaceholder(param.dtype_, {param.m_, param.n_},
95 "output", false);
96 bindings.allocate(output);
97 Node *cur = input;
98
99 Placeholder *ones;
100 if (param.k_ > param.n_) {
101 ones = mod->createPlaceholder(
102 param.dtype_, {param.m_ * (param.k_ - param.n_)}, "ones", false);
103 if (param.dtype_ == ElemKind::Float16Ty) {
104 bindings.allocate(ones)->getHandle<float16_t>().clear(1.0);
105 } else if (param.dtype_ == ElemKind::FloatTy) {
106 bindings.allocate(ones)->getHandle<float>().clear(1.0);
107 }
108 }
109
110 Placeholder *weights;
111 Placeholder *bias;
112
113 // Create multiple layers of FC nodes
114 for (size_t layer = 0; layer < param.numLayers_; layer++) {
115 weights =
116 mod->createPlaceholder(param.dtype_, {param.k_, param.n_},
117 "weights" + std::to_string(layer), false);
118 bias = mod->createPlaceholder(param.dtype_, {param.n_},
119 "bias" + std::to_string(layer), false);
120
121 if (param.dtype_ == ElemKind::Float16Ty) {
122 bindings.allocate(weights)->getHandle<float16_t>().randomize(
123 -1.f, 1.f, mod->getPRNG());
124 bindings.allocate(bias)->getHandle<float16_t>().clear(32);
125 } else if (param.dtype_ == ElemKind::FloatTy) {
126 bindings.allocate(weights)->getHandle<float>().randomize(
127 -1.f, 1.f, mod->getPRNG());
128 bindings.allocate(bias)->getHandle<float>().clear(32);
129 }
130
131 Node *fc;
132 fc = fn->createFullyConnected("fc_" + std::to_string(layer), cur, weights,
133 bias);
134 cur = fc;
135
136 // Handle non-square cases
137 if (param.k_ > param.n_ && layer < (param.numLayers_ - 1)) {
138 Node *reshape1 = fn->createReshape("reshape1_" + std::to_string(layer),
139 fc, {param.m_ * param.n_});
140 Node *concat = fn->createConcat("concat_" + std::to_string(layer),
141 {reshape1, ones}, 0);
142 Node *reshape2 = fn->createReshape("reshape2_" + std::to_string(layer),
143 concat, {param.m_, param.k_});
144 cur = reshape2;
145 } else if (param.k_ < param.n_ && layer < (param.numLayers_ - 1)) {
146 Node *slice = fn->createSlice("slice_" + std::to_string(layer), fc,
147 {0, 0}, {param.m_, param.k_});
148 cur = slice;
149 }
150 }
151 fn->createSave("save1", cur, output);
152 ::glow::convertPlaceholdersToConstants(fn, bindings, {input, output});
153 }
154
155 void setupInternal(bool isRef) {
156 // Setup host manager
157 std::string backendStr = isRef ? "Interpreter" : param_.backendStr_.c_str();
158 std::vector<std::unique_ptr<runtime::DeviceConfig>> configs;
159 auto config = glow::make_unique<runtime::DeviceConfig>(backendStr.c_str());
160 if (param_.devId_ != "") {
161 config->parameters["DeviceID"] = param_.devId_.c_str();
162 }
163 configs.push_back(std::move(config));
164 if (isRef) {
165 refHostManager_ =
166 glow::make_unique<runtime::HostManager>(std::move(configs));
167 } else {
168 hostManager_ =
169 glow::make_unique<runtime::HostManager>(std::move(configs));
170 }
171
172 std::unique_ptr<Module> mod(new Module);
173 auto fn = mod->createFunction("singleNode");
174
175 addGemmNode(mod, fn, param_, isRef);
176
177 // Split weights
178 if (param_.numSplits_ > 1) {
179 executeVerticalFCWeightsSplit(fn, param_.numSplits_, param_.n_);
180 }
181
182 CompilationContext ctx;
183 ctx.dumpFinalGraph = true;
184 ctx.serializeCompiledDAG = dumpOnnx;
185 if (isRef) {
186 EXIT_ON_ERR(refHostManager_->addNetwork(std::move(mod), ctx));
187 } else {
188 EXIT_ON_ERR(hostManager_->addNetwork(std::move(mod), ctx));
189 }
190 }
191
192 void checkOutput() {
193 // First run on the reference backend
194 dispatchInference("singleNode", refHostManager_.get(), refContext_,
195 param_.numAsyncLaunches_,
196 /*useNewExecutionContext*/ true);
197 Tensor *refTensor =
198 refBindings_.get(refBindings_.getPlaceholderByNameSlow("output"));
199 CHECK(refTensor) << "Reference Tensor not found";
200
201 Tensor *noRefTensor =
202 bindings_.get(bindings_.getPlaceholderByNameSlow("output"));
203 CHECK(noRefTensor) << "non-reference Tensor not found";
204
205 // Compare the tensors
206 if (!noRefTensor->isEqual(*refTensor)) {
207 noRefTensor->dump();
208 refTensor->dump();
209 LOG(FATAL) << "Tensors don't match\n";
210 } else {
211 LOG(INFO) << "Tensors match\n";
212 }
213 }
214
215 void setup() override {
216 if (checkCorrectness) {
217 setupInternal(/* isRef */ true);
218 }
219 setupInternal(/* isRef */ false);
220 }
221 void run() override {
222 dispatchInference("singleNode", hostManager_.get(), context_,
223 param_.numAsyncLaunches_,
224 /*useNewExecutionContext*/ true);
225 if (checkCorrectness) {
226 checkOutput();
227 }
228 }
229
230 void teardown() override {}
231
232 double gflops() const {
233 return 2.0 * param_.m_ * param_.n_ * param_.k_ * param_.numLayers_ / 1e9;
234 }
235};
236
// Position in argv of the optional device-id argument (the 10th positional).
#define DEVICE_ID 10
238
239GemmParam parseArgs(int argc, char *argv[]) {
240 GemmParam param;
241
242 param.m_ = atoi(argv[1]);
243 param.n_ = atoi(argv[2]);
244 param.k_ = atoi(argv[3]);
245 param.numLayers_ = atoi(argv[4]);
246 param.numReps_ = atoi(argv[5]);
247 param.numAsyncLaunches_ = atoi(argv[6]);
248 param.numSplits_ = atoi(argv[7]);
249 param.backendStr_ = std::string(argv[8]);
250 if (std::string(argv[9]) == "Float16") {
251 param.dtype_ = ElemKind::Float16Ty;
252 } else if (std::string(argv[9]) == "Float32") {
253 param.dtype_ = ElemKind::FloatTy;
254 } else {
255 llvm_unreachable("Invalid dtype");
256 }
257
258 printf("m %zu\n", (size_t)param.m_);
259 printf("n %zu\n", (size_t)param.n_);
260 printf("k %zu\n", (size_t)param.k_);
261 printf("numLayers %zu\n", (size_t)param.numLayers_);
262 printf("numReps %zu\n", (size_t)param.numReps_);
263 printf("numAsyncLaunches %zu\n", (size_t)param.numAsyncLaunches_);
264 printf("numSplits %zu\n", (size_t)param.numSplits_);
265 printf("backendStr %s\n", param.backendStr_.c_str());
266 printf("dtypeStr %s\n", argv[9]);
267
268 if (argc > DEVICE_ID) {
269 printf("devId %s\n", argv[DEVICE_ID]);
270 param.devId_ = std::string(argv[DEVICE_ID]);
271 } else {
272 param.devId_ = std::string("");
273 }
274 printf("\n\n");
275 return param;
276}
277
278int main(int argc, char *argv[]) {
279 printf("GEMM Microbenchmark\n");
280 printf("Usage: GemmBench m(Int) n(Int) k(Int) numLayers(Int) numReps(Int) "
281 "numAsyncLaunches(Int) numSplits(Int) backendStr(String) "
282 "dtypeStr(\"Float16\"|\"Float32\") dev_id(Int)\n");
283 printf("Standard Glow command-line options may be passed via the GLOW_OPTS "
284 "environment variable\n");
285 benchParseGlowOpts(argc, argv);
286
287 std::vector<GemmParam> params;
288 std::string runHeader;
289 std::string runPrefix;
290
291 // Using a config file
292 if (argc == 2) {
293 auto fname = std::string(argv[1]);
294 std::ifstream fin(fname.c_str());
295 if (!fin) {
296 std::cout << "Could not open file: " << fname << std::endl;
297 exit(0);
298 }
299 std::string line;
300 while (getline(fin, line)) {
301 std::array<char, 1024> buf;
302 char *saveptr = nullptr;
303 std::vector<char *> argVec;
304 strcpy(buf.data(), line.c_str());
305 char *ptr = strtok_r(buf.data(), " ", &saveptr);
306 while (ptr != nullptr) {
307 argVec.push_back(ptr);
308 ptr = strtok_r(nullptr, " ", &saveptr);
309 }
310 GemmParam param = parseArgs(argVec.size(), argVec.data());
311 params.push_back(param);
312 runHeader = std::string("_,benchName,_,filename");
313 runPrefix = std::string(strFormat("GemmBench,SW,%s", fname.c_str()));
314 }
315 } else if (argc == 10 || argc == 11) {
316 GemmParam param = parseArgs(argc, argv);
317 params.push_back(param);
318 runHeader = std::string(
319 "_,benchName,_,m,n,k,numLayers,numReps,numAsyncLaunches,numSplits,"
320 "backendStr,dtypeStr\n");
321 runPrefix = std::string(strFormat(
322 "GemmBench,SW,%zu,%zu,%zu,%zu,%zu,%zu,%zu,%s,%s", (size_t)param.m_,
323 (size_t)param.n_, (size_t)param.k_, (size_t)param.numLayers_,
324 (size_t)param.numReps_, (size_t)param.numAsyncLaunches_,
325 (size_t)param.numSplits_, argv[8], argv[9]));
326 } else {
327 llvm_unreachable("Invalid command line");
328 }
329
330 for (auto param : params) {
331 GemmBench b(param);
332 auto times = bench(&b, param.numReps_);
333
334 printf("%s,runtime,gflopPerSec\n", runHeader.c_str());
335 for (auto t : times) {
336 printf("BenchResult,%s,%f,%f\n", runPrefix.c_str(),
337 t / param.numAsyncLaunches_,
338 b.gflops() * param.numAsyncLaunches_ / t);
339 }
340 double min = *(std::min_element(times.begin(), times.end()));
341 dim_t midElt = times.size() / 2;
342 std::nth_element(times.begin(), times.begin() + midElt, times.end());
343 double median = times[midElt];
344 double medianRuntime = median / ((double)param.numAsyncLaunches_);
345 double minRuntime = min / ((double)param.numAsyncLaunches_);
346 printf("%s,medianRuntime,minRuntime,medianGflopPerSec,maxGflopPerSec\n",
347 runHeader.c_str());
348 printf("BenchSummary,%s,%f,%f,%f,%f\n", runPrefix.c_str(), medianRuntime,
349 minRuntime, b.gflops() / medianRuntime, b.gflops() / minRuntime);
350 }
351}
352