/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "BackendTestUtils.h"
#include "folly/executors/CPUThreadPoolExecutor.h"
#include "glow/Backend/Backend.h"
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Exporter/ONNXModelWriter.h"
#include "glow/Flags/Flags.h"
#include "glow/Graph/Graph.h"
#include "glow/Importer/ONNXModelLoader.h"
#include "glow/Runtime/DeferredWeightLoader.h"
#include "glow/Runtime/TraceExporter.h"
#include "glow/Support/Support.h"
#include "glow/Support/ZipUtils.h"

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Signals.h"

#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include <glog/logging.h>

#include "folly/stats/Histogram.h"

#include <fstream>
#include <iostream>
#include <string>

using namespace glow;

namespace {
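// Command-line options controlling model loading, inference, and result
// checking.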
llvm::cl::OptionCategory reproTestCat("Repro Category");
llvm::cl::opt<std::string> modelPathOpt("model", llvm::cl::desc("Input models"),
                                        llvm::cl::value_desc("modelPath"),
                                        llvm::cl::Required,
                                        llvm::cl::cat(reproTestCat));
llvm::cl::opt<std::string> deferredWeightsPathOpt(
    "deferred_weights", llvm::cl::desc("Path to the deferred weights file"),
    llvm::cl::Optional, llvm::cl::init(""), llvm::cl::cat(reproTestCat));
llvm::cl::list<std::string> inputsOpt("inputs", llvm::cl::desc("Inputs"),
                                      llvm::cl::value_desc("Inputs"),
                                      llvm::cl::Optional, llvm::cl::ZeroOrMore,
                                      llvm::cl::cat(reproTestCat));
llvm::cl::list<std::string> outputsOpt("outputs", llvm::cl::desc("Outputs"),
                                       llvm::cl::value_desc("Outputs"),
                                       llvm::cl::Optional, llvm::cl::ZeroOrMore,
                                       llvm::cl::cat(reproTestCat));
llvm::cl::opt<std::string>
    inputPatternOpt("input_pattern",
                    llvm::cl::desc("Input file pattern, e.g. in_{}.onnx"),
                    llvm::cl::init(""), llvm::cl::cat(reproTestCat));
llvm::cl::opt<std::string>
    outputPatternOpt("output_pattern",
                     llvm::cl::desc("Output file pattern, e.g. out_{}.onnx"),
                     llvm::cl::init(""), llvm::cl::cat(reproTestCat));
llvm::cl::opt<unsigned> seqStartOpt(
    "seq_start", llvm::cl::desc("Start index of input/output files"),
    llvm::cl::Optional, llvm::cl::init(0), llvm::cl::cat(reproTestCat));
llvm::cl::opt<unsigned> seqLenOpt(
    "seq_len", llvm::cl::desc("Length of the input/output file sequence."),
    llvm::cl::Optional, llvm::cl::init(0), llvm::cl::cat(reproTestCat));

llvm::cl::opt<std::string> ExecutionBackend(
    "backend", llvm::cl::desc("Backend to use, e.g., Interpreter, CPU, NNPI"),
    llvm::cl::init("NNPI"), llvm::cl::cat(reproTestCat));

llvm::cl::opt<unsigned> concurrentCountOpt(
    "concurrent_count", llvm::cl::desc("Number of concurrent requests."),
    llvm::cl::Optional, llvm::cl::init(1), llvm::cl::cat(reproTestCat));

llvm::cl::opt<float> deviceMemoryOpt(
    "glow_device_memory",
    llvm::cl::desc("Size of memory for each Glow backend device"),
    llvm::cl::Optional, llvm::cl::init(256 * 1024.0 * 1024.0 * 1024.0),
    llvm::cl::cat(reproTestCat));

llvm::cl::opt<float> thresholdOpt(
    "threshold", llvm::cl::desc("Threshold for tensor numeric comparison"),
    llvm::cl::Optional, llvm::cl::init(1e-5), llvm::cl::cat(reproTestCat));

llvm::cl::opt<bool> glowDumpGraphAfterLoadOpt(
    "glow_dump_graph_after_load",
    llvm::cl::desc(
        "Dump the glow Graph into files immediately after loading from ONNX"),
    llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(reproTestCat));

llvm::cl::opt<bool> sliceConcatFp32Opt(
    "glow_slice_concat_fp32",
    llvm::cl::desc("Don't convert slice and concat ops' precision when "
                   "--glow_global_fp16 is used."),
    llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(reproTestCat));

llvm::cl::opt<bool> dumpOutputsOpt("dump_outputs",
                                   llvm::cl::desc("Dump output tensors"),
                                   llvm::cl::Optional, llvm::cl::init(true),
                                   llvm::cl::cat(reproTestCat));

llvm::cl::opt<bool> indicesInt64Opt(
    "glow_global_indices_fp64",
    llvm::cl::desc("Enable converting indices from int32 to int64"),
    llvm::cl::cat(reproTestCat));

llvm::cl::opt<bool, /* ExternalStorage */ true> enablePartialTensorOpt(
    "glow_enable_partial_tensor", llvm::cl::desc("Enable partial tensor"),
    llvm::cl::Optional, llvm::cl::location(glow::flags::EnablePartialTensors),
    llvm::cl::init(true), llvm::cl::cat(reproTestCat));

llvm::cl::opt<unsigned> itersOpt(
    "iters",
    llvm::cl::desc("Total number of requests to loop over provided input."),
    llvm::cl::Optional, llvm::cl::init(1), llvm::cl::cat(reproTestCat));

llvm::cl::alias requestCountOpt("request_count",
                                llvm::cl::desc("Alias for -iters"),
                                llvm::cl::aliasopt(itersOpt));

llvm::cl::opt<unsigned> durationMinOpt(
    "duration_min", llvm::cl::desc("Running duration limit in minutes"),
    llvm::cl::Optional, llvm::cl::init(0), llvm::cl::cat(reproTestCat));

llvm::cl::opt<bool> glowEnableDeviceTrace(
    "glow_enable_device_traces",
    llvm::cl::desc("Enable trace events from inference backend device."),
    llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(reproTestCat));

llvm::cl::opt<bool> skipCorrectnessCheck(
    "skip_correctness_check", llvm::cl::desc("Skip correctness check"),
    llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(reproTestCat));

llvm::cl::opt<std::string>
    glowDumpTraceFile("glow_dump_debug_traces_file",
                      llvm::cl::desc("Dump glow trace file"),
                      llvm::cl::Optional, llvm::cl::init(std::string("")),
                      llvm::cl::cat(reproTestCat));

llvm::cl::opt<int32_t>
    topKCompare("topk_compare",
                llvm::cl::desc("Compare the topk results against reference"),
                llvm::cl::Optional, llvm::cl::init(0),
                llvm::cl::cat(reproTestCat));

llvm::cl::opt<float> top1Threshold(
    "top1_threshold",
    llvm::cl::desc(
        "Percentage of top1 matches to reference that must be achieved"),
    llvm::cl::Optional, llvm::cl::init(0.0), llvm::cl::cat(reproTestCat));

llvm::cl::opt<bool> logTopKResultsPerExample(
    "log_topk_results_per_example",
    llvm::cl::desc("Whether to log topk results vs reference for each example"),
    llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(reproTestCat));

llvm::cl::opt<bool> cosineSimilarityStats(
    "cosine_similarity_stats",
    llvm::cl::desc("Whether to compute cosine similarity stats"),
    llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(reproTestCat));

llvm::cl::opt<float> cosineSimilarityThreshold(
    "p50_cosine_similarity_threshold",
    llvm::cl::desc(
        "Minimum p50 cosine similarity to reference that must be achieved"),
    llvm::cl::Optional, llvm::cl::init(0.0), llvm::cl::cat(reproTestCat));

llvm::cl::opt<bool> onnxLoaderZipMode(
    "zip_mode", llvm::cl::desc("Use zip mode with ONNXModelLoader"),
    llvm::cl::Optional, llvm::cl::init(true), llvm::cl::cat(reproTestCat));
llvm::cl::opt<unsigned> replicationCountOpt(
    "replication_count", llvm::cl::desc("Set the network replication count"),
    llvm::cl::Optional, llvm::cl::init(1), llvm::cl::cat(reproTestCat));

/// Explicitly show gflags help/version info, depending on \p foundHelpFlag and
/// \p foundVersionFlag. llvm shows its own help/version info when it parses.
void gflagsShowHelpVersion(bool foundHelpFlag, bool foundVersionFlag) {
  const char *binName = gflags::ProgramInvocationShortName();
  if (foundHelpFlag) {
    gflags::SetUsageMessage(
        strFormat("gflags for %s\nUSAGE: %s [options]:", binName, binName));
    gflags::ShowUsageWithFlagsRestrict(binName, /* restrict_ */ "");
    llvm::outs() << "\nLLVM CommandLine options:\n";
  }
  if (foundVersionFlag) {
    llvm::outs() << "gflags version:\n";
    const char *versionStr = gflags::VersionString();
    llvm::outs() << binName;
    if (versionStr && *versionStr) {
      llvm::outs() << " version " << versionStr;
    }
    llvm::outs() << "\n\n";
  }
}

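/// Parse the command line in \p argc / \p argv: override some gflags defaults,
/// split the arguments between gflags and llvm::cl, and parse both sets,
/// sharing only help/version between the two.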
void parseCommandLine(int argc, char **argv) {
  // Use different defaults for some flags:
  FLAGS_glow_global_fp16 = true;
  FLAGS_glow_clip_fp16 = true;
  FLAGS_glow_global_fused_scale_offset_fp16 = true;
  FLAGS_glow_global_fused_scale_offset_fp32 = false;
  FLAGS_glow_snn_partitioning_kbytes_per_card = 5000000;
  FLAGS_glow_snn_partitioning_num_cores_sls = 6;
  FLAGS_glow_snn_partitioning_num_cores_other = 6;

  llvm::sys::PrintStackTraceOnErrorSignal(argv[0]);

  // Verify there's no unexpected overlap between llvm and gflags flags.
  const auto &llvmOpts = llvm::cl::getRegisteredOptions();
  for (const auto &opt : llvmOpts) {
    static const llvm::StringSet<> allowedInBoth = {"help", "version"};
    if (allowedInBoth.count(opt.getKey())) {
      continue;
    }
    gflags::CommandLineFlagInfo dummy;
    CHECK(!gflags::GetCommandLineFlagInfo(opt.getKey().data(), &dummy))
        << "Error: Repeated flag used by both llvm and gflags: "
        << opt.getKey().data();
  }

  // Separate out llvm and gflags into their own argc/argv.
  llvm::SmallVector<char *, 40> llvmArgv, gflagsArgv;
  llvmArgv.push_back(argv[0]);
  gflagsArgv.push_back(argv[0]);
  bool foundHelpFlag = false;
  bool foundVersionFlag = false;
  for (int i = 1; i < argc; ++i) {
    llvm::StringRef flagName(argv[i]);
    // Positional args are always llvm cl args.
    if (!flagName.startswith("-")) {
      llvmArgv.push_back(argv[i]);
      continue;
    }

    // Strip off leading '-'.
    flagName = flagName.drop_while([](char c) -> bool { return c == '-'; });
    // Look for everything leading up to '=', if any.
    flagName = flagName.take_until([](char c) -> bool { return c == '='; });

    // Now check if flagName is a gflag, otherwise assume it was from llvm. If
    // help/version, always pass to llvm; we will also call gflags directly for
    // them to print before llvm parses/prints, so that both gflags and llvm
    // will print help/version.
    gflags::CommandLineFlagInfo dummy;
    if (!gflags::GetCommandLineFlagInfo(flagName.str().c_str(), &dummy) ||
        flagName == "help" || flagName == "version") {
      llvmArgv.push_back(argv[i]);
      if (flagName == "help") {
        foundHelpFlag = true;
      } else if (flagName == "version") {
        foundVersionFlag = true;
      }
    } else {
      gflagsArgv.push_back(argv[i]);
    }
  }
  int llvmArgc = static_cast<int>(llvmArgv.size());
  int gflagsArgc = static_cast<int>(gflagsArgv.size());

  // Now we can parse both llvm and gflags safely. All gflags should be
  // legitimate. All other flags will be passed to llvm, which will complain
  // about unknown ones.
  char **gflagsArgvPtr = &gflagsArgv[0];
  gflags::AllowCommandLineReparsing();
  gflags::ParseCommandLineFlags(&gflagsArgc, &gflagsArgvPtr,
                                /* remove_flags */ false);
  gflagsShowHelpVersion(foundHelpFlag, foundVersionFlag);
  llvm::cl::ParseCommandLineOptions(
      llvmArgc, &llvmArgv[0],
      " The Glow compiler\n\n"
      "Glow is a compiler for neural network accelerators.\n");

  if (top1Threshold > 0.0 && topKCompare == 0) {
    topKCompare = 1;
  }

  if (cosineSimilarityThreshold > 0.0) {
    cosineSimilarityStats = true;
  }
}

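/// Outcome of a single inference request: the returned error, the execution
/// context (kept only when accuracy checks need it), the index of the
/// input/output pair that was used, and the completion time.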
struct InferenceResult {
  Error error = Error::empty();
  std::unique_ptr<ExecutionContext> ctx;
  int index = 0;
  std::chrono::time_point<std::chrono::steady_clock> endTime;
};

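/// DeferredWeightLoader backed by a zip archive. The archive is expected to
/// contain a "weights" record holding the weight count, plus a "weight_<i>"
/// record (a serialized TensorProto) and an optional "data_<i>" record (raw
/// tensor payload) for each weight.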
class ZipFileBackedDeferredBlobLoader
    : public ::glow::runtime::DeferredWeightLoader {
public:
  explicit ZipFileBackedDeferredBlobLoader(const std::string &path) {
    zip_ = ::glow::make_unique<::glow::ZipReader>(path);
    CHECK(zip_);
    auto numWeightsStr = zip_->getRecord("weights");
    weightsToLoad_ = atoi(numWeightsStr.c_str());
    i_ = 0;
  }

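  /// Load the next "weight_<i>" TensorProto (plus its optional "data_<i>"
  /// payload) from the archive into currentTensor_, validating its type
  /// against the registered static placeholder types.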
  ::glow::Error loadNextWeight() override {
    if (weightsToLoad_ == i_) {
      llvm::outs() << "All deferred weights are loaded\n";
      currentBlobName_ = "";
      currentTensor_.reset();
      zip_.reset(nullptr);
      return ::glow::Error::success();
    }

    std::stringstream ss;
    ss << "weight_" << i_;
    largeBuffer_ = zip_->getRecord(ss.str());
    ::ONNX_NAMESPACE::TensorProto t;
    t.ParseFromString(largeBuffer_);

    currentBlobName_ = glow::legalizeName(t.name());
    auto tyIdx = typeInfo_.find(currentBlobName_);
    if (tyIdx == typeInfo_.end()) {
      return ::MAKE_ERR(
          ::glow::ErrorValue::ErrorCode::RUNTIME_ERROR,
          ::glow::strFormat(
              "Error: Blob name: %s not found in list of static placeholders.",
              currentBlobName_.c_str()));
    }
    auto ty = typeInfo_[currentBlobName_];

    ss.str("");
    ss << "data_" << i_++;
    largeBuffer_.clear();
    if (zip_->hasRecord(ss.str())) {
      largeBuffer_ = zip_->getRecord(ss.str());
      LOG(INFO) << "Read weight data " << ss.str() << " of size "
                << largeBuffer_.size();
    }
    currentTensor_.reset(new ::glow::Tensor());
    RETURN_IF_ERR(::glow::loadTensor(t, currentTensor_.get(),
                                     /*useGlowCustomOps*/ false, largeBuffer_));
    CHECK(currentTensor_->getType().isEqual(ty))
        << "Mismatched tensor type: " << currentTensor_->getType().toString()
        << " vs " << ty.toString();

    return ::glow::Error::success();
  }

  ::glow::Error setSrc(void * /*unused*/) override {
    return ::glow::Error::success();
  }

  std::string getName() override { return currentBlobName_; }

  ::glow::Tensor *getTensor() override { return currentTensor_.get(); }

  void setTypeInfo(std::map<std::string, ::glow::Type> info) override {
    typeInfo_ = std::move(info);
  }

private:
  std::unique_ptr<::glow::ZipReader> zip_;
  std::string largeBuffer_;
  std::string currentBlobName_;
  std::unique_ptr<::glow::Tensor> currentTensor_;
  size_t weightsToLoad_{0};
  size_t i_{0};
};

/// Given a float Tensor \p t, \returns a vector of pairs of entries in t with
/// the first element in the pair being the value and the second element being
/// the original index of that value in t. The vector is partially sorted such
/// that the first \p k elements are the k elements from t with the greatest
/// values.
static std::vector<std::pair<float, size_t>>
partialSortFloatTensor(const Tensor &t, size_t k) {
  std::vector<std::pair<float, size_t>> vec;
  auto handle = t.getHandle<float>();
  for (size_t i = 0; i < handle.size(); ++i) {
    vec.push_back({handle.raw(i), i});
  }
  std::partial_sort(
      vec.begin(), vec.begin() + k, vec.end(),
      [](const auto &p1, const auto &p2) { return p1.first > p2.first; });
  return vec;
}

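/// \returns the dot product of \p t1 and \p t2, which must be float tensors
/// with the same number of elements.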
static float dotProd(const Tensor &t1, const Tensor &t2) {
  CHECK(t1.getElementType() == ElemKind::FloatTy);
  CHECK(t2.getElementType() == ElemKind::FloatTy);
  auto t1H = t1.getHandle<float>();
  auto t2H = t2.getHandle<float>();
  CHECK_EQ(t1H.size(), t2H.size());
  float res = 0.0f;
  for (dim_t i = 0; i < t1H.size(); i++) {
    res += t1H.raw(i) * t2H.raw(i);
  }
  return res;
}

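/// \returns the cosine similarity between \p t1 and \p t2, dequantizing both
/// tensors to float first if \p t1 is quantized.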
static float cosineSimilarity(const Tensor &t1, const Tensor &t2) {
  auto fn = [](const Tensor &t1, const Tensor &t2) {
    return dotProd(t1, t2) /
           (std::sqrt(dotProd(t1, t1)) * std::sqrt(dotProd(t2, t2)));
  };
  if (t1.getType().isQuantizedType()) {
    auto t1Float = quantization::dequantizeTensor(t1, ElemKind::FloatTy);
    auto t2Float = quantization::dequantizeTensor(t2, ElemKind::FloatTy);
    return fn(t1Float, t2Float);
  } else {
    return fn(t1, t2);
  }
}

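/// Load the model, compile it onto the configured backend via a HostManager,
/// dispatch the requested inferences, and verify the results.
/// \returns the number of failed or mismatched inferences.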
int run() {
  int numFailed = 0;

  int numTop1Matches = 0;
  int numTopKMatches = 0;
  int numTotalTopKCompares = 0;

  folly::Histogram<float> cosineHist(/* bucketSize */ 0.1f, /* min */ 0.0f,
                                     /* max */ 1.0f);

  // Create a module and deserialize the Function from the ONNX model.
  auto mod = glow::make_unique<Module>();
  Error err = Error::empty();
  bool usingGlowCustomOps = false;
  CompilationContext cctx;
  cctx.replicationCount = replicationCountOpt;
  cctx.maxActiveRequestsPerInstance = glow::flags::MaxActiveRequestsPerInstance;
  runtime::PrePartitionedConfig PPC;
  cctx.prepartitionedConfig = &PPC;
  {
    ONNXModelLoader onnxLD(modelPathOpt, {}, {}, *mod, "test", &PPC, &err,
                           onnxLoaderZipMode,
                           &cctx.backendOpts.backendSpecificNodeInfo,
                           /* loadIntoExistingModule */ false,
                           /* disableConstFoldInLoader */ true);
    usingGlowCustomOps = onnxLD.usingGlowCustomOps();
  }
  CHECK(!ERR_TO_BOOL(std::move(err)))
      << "ONNXModelLoader failed to load model: " << modelPathOpt;
  llvm::outs() << "End onnx model load\n";

  if (glowDumpGraphAfterLoadOpt) {
    for (Function *F : mod->getFunctions()) {
      F->dumpDAG(glow::flags::DumpGraphPath + F->getName().str() + ".dot");
    }
  }

  // Build host manager and compile the module.
  PrecisionConfiguration &precConfig = cctx.precisionConfig;
  if (glow::flags::ConvertToFP16) {
    precConfig.convertToFP16 = true;
    if (sliceConcatFp32Opt) {
      precConfig.precisionModeKindSet.insert(Kinded::Kind::SliceNodeKind);
      precConfig.precisionModeKindSet.insert(Kinded::Kind::ConcatNodeKind);
    }
    llvm::outs() << "Conversion to fp16 enabled\n";
  }
  if (glow::flags::ConvertPlaceholdersToFP16) {
    precConfig.convertPlaceholdersToFP16 = true;
    llvm::outs() << "Conversion of Placeholders to fp16 enabled\n";
  }
  if (glow::flags::ConvertConstantsToFP16) {
    precConfig.convertConstantsToFP16 = true;
    llvm::outs() << "Conversion of Constants to fp16 enabled\n";
  }
  if (glow::flags::ConvertFusedScaleOffsetToFP16) {
    precConfig.convertFusedToFP16 = true;
    llvm::outs() << "Conversion of fused scales/offsets to fp16 enabled\n";
  }
  if (glow::flags::ConvertFusedScaleOffsetToFP32) {
    precConfig.convert4BitFusedToFP32 = true;
    precConfig.convert8BitFusedToFP32 = true;
    llvm::outs()
        << "Conversion of fused scales/offsets to fp32 in frwqslws enabled\n";
  }
  if (indicesInt64Opt) {
    precConfig.convertIndicesToInt64 = indicesInt64Opt;
    llvm::outs() << "Conversion of indices to int64 enabled\n";
  }
  if (glow::flags::ClipToFP16) {
    precConfig.clipFP16 = true;
    llvm::outs() << "Clipping to fp16 enabled\n";
  }
  if (glow::flags::SkipInputsOnClipToFP16) {
    precConfig.clipFP16SkipInputs = true;
    llvm::outs() << "Skipping clipping for fp16 Node inputs\n";
  }
  if (glow::flags::ForceSLSToFP16Accum) {
    precConfig.forceFP16AccumSLS = true;
    llvm::outs() << "Forcing fp16 accumulation for SLS ops enabled\n";
  }
  if (!glow::flags::EnableQuantParamChanges) {
    cctx.optimizationOpts.enableQuantParamChanges = false;
    LOG(INFO) << "Disabling quantization param changes during optimizations";
  }
  if (glow::flags::DumpGraph) {
    cctx.dumpFinalGraph = true;
    cctx.dumpGraphPath = glow::flags::DumpGraphPath;
  }
  if (glow::flags::SinkTanhBelowConcat) {
    cctx.optimizationOpts.sinkTanhBelowConcat =
        glow::flags::SinkTanhBelowConcat;
    LOG(INFO) << "Sinking tanh below concat";
  }

  if (glow::flags::UseSparseNNPartitioningScheme) {
    cctx.optimizationOpts.useSparseNNPartitioningScheme = true;
    cctx.optimizationOpts.sparseNNPartitioningAddSLSConcats =
        glow::flags::SparseNNPartitioningAddSLSConcats;
    cctx.optimizationOpts.sparseNNPartitioningBalancePerfModel =
        glow::flags::SparseNNPartitioningBalancePerfModel;
    cctx.optimizationOpts.sparseNNPartitioningPairLNWithSLS =
        glow::flags::SparseNNPartitioningPairLNWithSLS;
    cctx.optimizationOpts.sparseNNPartitioningPairTileWithSLS =
        glow::flags::SparseNNPartitioningPairTileWithSLS;
    cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards =
        glow::flags::SparseNNPartitioningSchemeNumCards;
    cctx.optimizationOpts.sparseNNPartitioningSchemeSLSTableKBytesPerCard =
        glow::flags::SparseNNPartitioningSchemeSLSTableKBytesPerCard;
    cctx.optimizationOpts.sparseNNPartitioningSchemeNumCoresSLS =
        glow::flags::SparseNNPartitioningSchemeNumCoresSLS;
    cctx.optimizationOpts.sparseNNPartitioningSchemeNumCoresOther =
        glow::flags::SparseNNPartitioningSchemeNumCoresOther;
  }

  if (!glow::flags::processBackendSpecificOpts(
          cctx.backendOpts.backendSpecificOpts,
          glow::flags::BackendSpecificOpts)) {
    return -1;
  }

  if (glow::nnpi::flags::NumParallelChunks > 1) {
    cctx.backendOpts.backendSpecificOpts["NNPINumParallelChunks"] =
        std::to_string(glow::nnpi::flags::NumParallelChunks);
  }
  if (glow::nnpi::flags::ModelParallelSplitAlignment > 1) {
    cctx.backendOpts.backendSpecificOpts["NNPIModelParallelSplitAlignment"] =
        std::to_string(glow::nnpi::flags::ModelParallelSplitAlignment);
  }
  if (glow::flags::UseDAGOptimizer) {
    cctx.callDAGOptimizer = true;
    cctx.optimizationOpts.DAGOptimizerNumParallelChunks =
        glow::flags::DAGOptimizerNumParallelChunks;
    cctx.optimizationOpts.DAGOptimizerParallelizationTaggingAlgorithm =
        glow::flags::DAGOptimizerParallelizationTaggingAlgorithm;
    cctx.optimizationOpts.DAGOptimizerPlacementTaggingAlgorithm =
        glow::flags::DAGOptimizerPlacementTaggingAlgorithm;
  }
  if (glow::flags::DelayAndRecordConstantModification) {
    cctx.optimizationOpts.delayAndRecordConstantModification = true;
  }
  if (glow::runtime::flags::EnableP2P) {
    LOG(INFO) << "Glow P2P Enabled";
    cctx.enableP2P = true;
  }
  if (glow::runtime::flags::EnableDRT) {
    LOG(INFO) << "Glow DRT Enabled";
    cctx.enableDRT = true;
  }
  if (glow::onnxifi::flags::SaveDAG) {
    LOG(INFO) << "Serializing DAG after optimization and partitioning.";
    cctx.serializeCompiledDAG = true;
  }
  if (glow::onnxifi::flags::SaveDAGWithConstants) {
    LOG(INFO) << "Serializing DAG with constants after optimization and "
                 "partitioning.";
    cctx.saveConstantInSerializeCompiledDAG = true;
  }
  if (glow::onnxifi::flags::SaveDAGInZipMode) {
    LOG(INFO) << "Serializing DAG with constants after optimization and "
                 "partitioning in Zip mode.";
    cctx.useZipModeForSerializeCompiledDAG = true;
  }

  // Load deferred weights if applicable.
  const auto &placeholderList = mod->getPlaceholders();
  glow::PlaceholderList nonStaticPlaceholderList;
  std::copy_if(placeholderList.begin(), placeholderList.end(),
               std::back_inserter(nonStaticPlaceholderList),
               [](const glow::Placeholder *p) { return !p->isStatic(); });
  if (!deferredWeightsPathOpt.empty()) {
    ::glow::runtime::DeferredLoader()->registerLoader(
        new ZipFileBackedDeferredBlobLoader(deferredWeightsPathOpt));
    // Initialize loader and set field in cctx.
    auto *loader = runtime::DeferredLoader()->getLoader();
    CHECK(loader) << "No deferred weights loader registered!";

    // Generate a map of type info for all static placeholders.
    std::map<std::string, Type> staticPlaceholderTypes;
    for (auto *PH : placeholderList) {
      if (PH->isStatic()) {
        staticPlaceholderTypes[std::string(PH->getName())] = *PH->getType();
      }
    }
    loader->setTypeInfo(std::move(staticPlaceholderTypes));
    CHECK(!loader->setSrc(nullptr));
    cctx.deferredWeightLoader = loader;
    // Signal that we want to fold convertTo and Quantize into static
    // Placeholders.
    cctx.optimizationOpts.foldStaticPlaceholderConversions = true;
  }

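  // Create one device config per requested device and build the HostManager
  // that will compile and run the network.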
  auto configs = runtime::generateDeviceConfigs(
      glow::flags::NumDevices, ExecutionBackend, deviceMemoryOpt);
  runtime::HostConfig hostConfig;
  hostConfig.maxActiveRequests = glow::flags::MaxActiveRequests;
  hostConfig.maxQueueSize = glow::flags::MaxQueueSize;
  hostConfig.executorThreads = glow::flags::ExecutorThreads;

  auto hostManager =
      glow::make_unique<runtime::HostManager>(std::move(configs), hostConfig);
  if (glow::flags::EnablePartialTensors) {
    CHECK(hostManager->getBackend(ExecutionBackend).supportsPartialTensors())
        << "Backend " << ExecutionBackend
        << " doesn't support partial tensors but enablePartialTensor is set "
           "to true.";
  }
  cctx.saturateHost = glow::flags::SaturateHost;
  EXIT_ON_ERR(hostManager->addNetwork(std::move(mod), cctx));

  // Whether to collect results and check accuracy. If we're not checking
  // accuracy, then don't load reference outputs.
  bool runAccuracyChecks =
      !skipCorrectnessCheck || topKCompare > 0 || cosineSimilarityStats;

  // Parse all input and output files ahead of inference.
  std::vector<::ONNX_NAMESPACE::GraphProto> parsedInputs;
  std::vector<::ONNX_NAMESPACE::GraphProto> parsedOutputs;
  size_t inputGroupSize = inputsOpt.size();
  if (inputGroupSize) {
    for (size_t i = 0; i < inputGroupSize; ++i) {
      llvm::outs() << "Loading input file: " << inputsOpt[i] << "\n";
      auto inputGroup = parseOnnxFile(inputsOpt[i]);
      parsedInputs.push_back(std::move(inputGroup));
      if (runAccuracyChecks) {
        llvm::outs() << "Loading output file: " << outputsOpt[i] << "\n";
        auto outputGroup = parseOnnxFile(outputsOpt[i]);
        parsedOutputs.push_back(std::move(outputGroup));
      }
    }
  } else if (!inputPatternOpt.empty() && seqLenOpt > 0) {
    inputGroupSize = seqLenOpt;
    size_t input_iter = inputPatternOpt.find("{}");
    CHECK_NE(input_iter, std::string::npos)
        << "Input pattern " << inputPatternOpt << " has to contain {}";
    for (unsigned i = 0; i < seqLenOpt; ++i) {
      std::string copy = inputPatternOpt;
      copy.replace(input_iter, 2, std::to_string(seqStartOpt + i));
      llvm::outs() << "Loading input file: " << copy << "\n";
      auto inputGroup = parseOnnxFile(copy);
      parsedInputs.push_back(std::move(inputGroup));
    }

    if (runAccuracyChecks) {
      CHECK(!outputPatternOpt.empty())
          << "Output pattern must be provided for accuracy checks";
      size_t output_iter = outputPatternOpt.find("{}");
      CHECK_NE(output_iter, std::string::npos)
          << "Output pattern " << outputPatternOpt << " has to contain {}";
      for (unsigned i = 0; i < seqLenOpt; ++i) {
        std::string copy = outputPatternOpt;
        copy.replace(output_iter, 2, std::to_string(seqStartOpt + i));
        llvm::outs() << "Loading output file: " << copy << "\n";
        auto outputGroup = parseOnnxFile(copy);
        parsedOutputs.push_back(std::move(outputGroup));
      }
    }
  }

  llvm::outs() << "\ninput pattern: " + inputPatternOpt + "\n";
  llvm::outs() << "\nseqlen: " + std::to_string(seqLenOpt) + "\n";
  llvm::outs() << "\ninputgroupsize: " + std::to_string(inputGroupSize) + "\n";

  if (parsedInputs.empty()) {
    llvm::outs() << "No inputs are provided. Exiting...\n";
    return -1;
  }

  llvm::outs() << "Starting inference\n";
  llvm::outs().flush(); // Explicitly flush to show progress.

  auto nowTime = std::chrono::steady_clock::now();
  auto endTimeDuration = nowTime + std::chrono::minutes(durationMinOpt);
  do {
    TraceContext mergedTraceContext(TraceLevel::STANDARD);
    folly::CPUThreadPoolExecutor threadPool(concurrentCountOpt);
    std::mutex mutex;
    std::condition_variable cv;
    int numTotalInferences = inputGroupSize * itersOpt;
    int numFinishedInferences = 0;

    // Figure out which placeholders are inputs.
    std::unordered_set<std::string> inputTensorNames;
    for (const auto &proto : parsedInputs[0].initializer()) {
      inputTensorNames.insert(glow::legalizeName(proto.name()));
    }

    glow::PlaceholderList inputPlaceholderList;
    std::copy_if(placeholderList.begin(), placeholderList.end(),
                 std::back_inserter(inputPlaceholderList),
                 [&](const glow::Placeholder *p) {
                   return inputTensorNames.find(p->getName().str()) !=
                          inputTensorNames.end();
                 });

    std::vector<Tensor> partialTensorPayloads;
    std::vector<PlaceholderBindings> inputBindings;
    for (const auto &inputGroup : parsedInputs) {
      PlaceholderBindings bindings;
      bindings.allocate(inputPlaceholderList);
      fillPlaceholders(
          inputGroup, &bindings,
          glow::flags::EnablePartialTensors ? &partialTensorPayloads : nullptr,
          usingGlowCustomOps);
      inputBindings.emplace_back(std::move(bindings));
    }

    bool enableGlowTrace = glow::flags::DumpDebugTraces ||
                           TraceExporterRegistry::getInstance()->shouldTrace();

    if (enableGlowTrace && glowEnableDeviceTrace) {
      // Start device traces.
      hostManager->setTraceContext(
          glow::make_unique<TraceContext>(TraceLevel::STANDARD));
      Error startErr = hostManager->startDeviceTrace();
      if (ERR_TO_BOOL(std::move(startErr))) {
        LOG(WARNING) << "Failed to start device traces";
      }
    }

    auto startTime = std::chrono::steady_clock::now();
    std::list<InferenceResult> results;
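    // Dispatch all inferences to the thread pool, cycling through the parsed
    // input groups until numTotalInferences requests have been issued.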
    for (int ioIndex = 0, numInferencesIssued = 0;
         numInferencesIssued < numTotalInferences; ++numInferencesIssued,
             ioIndex = numInferencesIssued % inputGroupSize) {

      results.emplace_back();
      auto &result = results.back();

      threadPool.add([&inputBindings, &nonStaticPlaceholderList, ioIndex,
                      &mergedTraceContext, &hostManager, &result, &cv, &mutex,
                      numTotalInferences, &numFinishedInferences,
                      runAccuracyChecks, enableGlowTrace]() {
        // Set up the inputs.
        auto ctx = glow::make_unique<ExecutionContext>();

        TraceContext *traceContext = nullptr;
        if (enableGlowTrace) {
          ctx->setTraceContext(
              glow::make_unique<TraceContext>(TraceLevel::STANDARD));
          traceContext = ctx->getTraceContext();
          traceContext->setThreadName("Request Thread");
        }
        TRACE_EVENT_SCOPE(traceContext, TraceLevel::RUNTIME,
                          "Dispatch to prep input and dispatch");

        // Set up input
        auto &bindings = *ctx->getPlaceholderBindings();
        bindings.clear();

        for (const auto &binding : inputBindings[ioIndex].pairs()) {
          auto *PH = binding.first;
          bindings.insert(PH, binding.second.getUnowned());
        }
        // Allocate for output
        bindings.allocate(nonStaticPlaceholderList);

        std::promise<void> promise;
        auto future = promise.get_future();

        TRACE_EVENT_SCOPE_END();

        hostManager->runNetwork(
            "test", std::move(ctx),
            [&promise, index = ioIndex,
             &result](runtime::RunIdentifierTy, Error err,
                      std::unique_ptr<ExecutionContext> contextPtr) mutable {
              result.error = std::move(err);
              result.ctx = std::move(contextPtr);
              result.index = index;
              result.endTime = std::chrono::steady_clock::now();
              promise.set_value();
            });

        // Wait for Glow to finish.
        future.wait();
        traceContext = result.ctx->getTraceContext();
        if (traceContext) {
          // Export to registered trace exporters.
          TraceExporterRegistry::getInstance()->exportTrace(traceContext);
          // merge() has an internal lock and is thread-safe.
          mergedTraceContext.merge(traceContext);
        }

        if (!runAccuracyChecks) {
          // If we're skipping the correctness check, throw away the context
          // to keep memory usage low.
          result.ctx.reset();
        }

        std::unique_lock<std::mutex> lock(mutex);
        if (++numFinishedInferences >= numTotalInferences) {
          lock.unlock();
          cv.notify_all();
        }
      });
    }

    // Wait for all inferences to finish.
    std::unique_lock<std::mutex> lock(mutex);
    cv.wait(lock,
            [&]() { return numFinishedInferences >= numTotalInferences; });

    auto endTime = startTime;
    llvm::outs() << "All inferences done. Checking results\n";
    for (auto &result : results) {
      if (result.endTime > endTime) {
        endTime = result.endTime;
      }

      if (result.error) {
        llvm::outs() << "Inference failed!\n";
        if (result.error.peekErrorValue()->isFatalError()) {
          std::string msg = result.error.peekErrorValue()->logToString();
          llvm::outs() << "Non-recoverable device error: " << msg << "\n";
        }
        ++numFailed;
      } else {
        ONNX_NAMESPACE::GraphProto outputG;
        std::ofstream of;
        if (dumpOutputsOpt) {
          std::stringstream ss;
          ss << "output_dump_" << result.index << ".onnx";
          of.open(ss.str(), std::ios::binary);
          CHECK(of) << "Cannot create output dump file: " << ss.str();
        }

        if (runAccuracyChecks) {
          const auto &outputGroup = parsedOutputs[result.index];
          CHECK(result.ctx);
          const auto &bindings = *result.ctx->getPlaceholderBindings();
          for (const auto &tp : outputGroup.initializer()) {
            Tensor tensorRef;
            auto error = loadTensor(tp, &tensorRef, usingGlowCustomOps);
            CHECK(!ERR_TO_BOOL(std::move(error)))
                << "Cannot load output ref tensor";
            const auto *tensor =
                bindings.get(bindings.getPlaceholderByNameSlow(tp.name()));
            CHECK(tensor) << "Missing " << tp.name()
                          << " in output placeholder";

            if (cosineSimilarityStats) {
              cosineHist.addValue(cosineSimilarity(*tensor, tensorRef));
            }

            if (topKCompare > 0) {
              numTotalTopKCompares++;
              assert(tensor->size() == tensorRef.size());
              auto sortedResults = partialSortFloatTensor(*tensor, topKCompare);
              auto sortedRefs = partialSortFloatTensor(tensorRef, topKCompare);
              assert(sortedResults.size() >= size_t(topKCompare) &&
                     sortedRefs.size() >= size_t(topKCompare));

              bool allKMatch = true;
              std::stringstream ss;
              for (auto i = 0; i < topKCompare; i++) {
                if (sortedResults[i].second == sortedRefs[i].second) {
                  if (i == 0) {
                    numTop1Matches++;
                  }
                } else {
                  allKMatch = false;
                }
                if (logTopKResultsPerExample) {
                  ss << i << ": Test result: " << sortedResults[i].second
                     << " (p=" << sortedResults[i].first
                     << ") Reference result: " << sortedRefs[i].second
                     << " (p=" << sortedRefs[i].first << ")\n";
                }
              }
              if (logTopKResultsPerExample) {
                llvm::outs() << ss.str() << "\n";
              }
              if (allKMatch) {
                numTopKMatches++;
              }
            }

            if (dumpOutputsOpt) {
              auto *t = outputG.add_initializer();
              ONNXModelWriter::writeTensor(*tensor, t, usingGlowCustomOps);
              t->set_name(tp.name());
            }

            if (!skipCorrectnessCheck) {
              bool equal = tensorRef.isEqual(*tensor, thresholdOpt, true);
              if (!equal) {
                llvm::outs() << "Verification failed at input/output pair "
                             << result.index << " for output tensor "
                             << tp.name() << "\n";
                ++numFailed;
                break;
              }
            }
          }
        }

        if (dumpOutputsOpt) {
          std::string buffer;
          outputG.SerializeToString(&buffer);
          of << buffer;
        }
      }
    }

    if (glow::flags::DumpDebugTraces) {
      if (glowEnableDeviceTrace) {
        // Stop device traces and collect events.
        Error stopErr = hostManager->stopDeviceTrace();
        if (ERR_TO_BOOL(std::move(stopErr))) {
          LOG(WARNING) << "Failed to stop device traces.";
        } else {
          mergedTraceContext.merge(hostManager->getTraceContext());
        }
      }
      llvm::SmallString<64> path;
      if (glowDumpTraceFile.empty()) {
        auto tempFileRes =
            llvm::sys::fs::createTemporaryFile("glow-trace", "json", path);
        if (tempFileRes.value() != 0) {
          LOG(ERROR) << "Failed to create temp file for Glow trace events: "
                     << tempFileRes;
        } else {
          LOG(INFO) << "Trace path=" << path.c_str();
          mergedTraceContext.dump(path);
        }
      } else {
        LOG(INFO) << "Trace path=" << glowDumpTraceFile;
        mergedTraceContext.dump(glowDumpTraceFile);
      }
    }

    if (!skipCorrectnessCheck) {
      if (numFailed == 0) {
        llvm::outs() << "All passed!\n";
      } else {
        llvm::outs() << numFailed << " inferences failed to match reference.\n";
      }
    }

    if (topKCompare > 0) {
      llvm::outs() << "Num top1 exact matches: " << numTop1Matches << "/"
                   << numTotalTopKCompares << "\n";
      llvm::outs() << "Num topK exact matches (k=" << topKCompare
                   << "): " << numTopKMatches << "/" << numTotalTopKCompares
                   << "\n";

      if (top1Threshold > 0.0) {
        float top1MatchRate = float(numTop1Matches) / numTotalTopKCompares;
        if (top1MatchRate < top1Threshold) {
          llvm::outs() << "Expected top1 match rate of at least "
                       << top1Threshold << " but only achieved "
                       << top1MatchRate << "\n";
          return numTotalTopKCompares - numTop1Matches;
        }
      }
    }

    if (cosineSimilarityStats) {
      float p50Similarity = cosineHist.getPercentileEstimate(0.5);
      llvm::outs() << "cosine similarity stats:\n"
                   << "p01: " << cosineHist.getPercentileEstimate(0.01) << "\n"
                   << "p02: " << cosineHist.getPercentileEstimate(0.02) << "\n"
                   << "p05: " << cosineHist.getPercentileEstimate(0.05) << "\n"
                   << "p10: " << cosineHist.getPercentileEstimate(0.1) << "\n"
                   << "p25: " << cosineHist.getPercentileEstimate(0.25) << "\n"
                   << "p50: " << p50Similarity << "\n"
                   << "p75: " << cosineHist.getPercentileEstimate(0.75) << "\n"
                   << "p90: " << cosineHist.getPercentileEstimate(0.90) << "\n"
                   << "p95: " << cosineHist.getPercentileEstimate(0.95) << "\n"
                   << "p98: " << cosineHist.getPercentileEstimate(0.98) << "\n"
                   << "p99: " << cosineHist.getPercentileEstimate(0.99) << "\n";
      if (cosineSimilarityThreshold > 0.0) {
        if (p50Similarity < cosineSimilarityThreshold) {
          llvm::outs() << "Expected p50 cosine similarity of at least "
                       << cosineSimilarityThreshold << " but only achieved "
                       << p50Similarity << "\n";
          return 1;
        }
      }
    }

    llvm::outs().flush();

    std::chrono::duration<double, std::milli> duration = endTime - startTime;
    std::cout << "Total inference duration (ms): " << duration.count() << "\n";
    std::cout << "Avg inference duration (ms): "
              << duration.count() / numTotalInferences << "\n";
    std::cout << "Avg inferences per second: "
              << numTotalInferences * 1000 / duration.count()
              << std::endl; // Use endl to flush the buffer
    nowTime = std::chrono::steady_clock::now();
  } while (std::chrono::duration_cast<std::chrono::seconds>(nowTime -
                                                            endTimeDuration)
               .count() < 0);

  return numFailed;
}

} // namespace

int main(int argc, char **argv) {
  google::InitGoogleLogging(argv[0]);
  google::InstallFailureSignalHandler();
  parseCommandLine(argc, argv);
  return run();
}