1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "BackendTestUtils.h" |
18 | #include "folly/executors/CPUThreadPoolExecutor.h" |
19 | #include "glow/Backend/Backend.h" |
20 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
21 | #include "glow/Exporter/ONNXModelWriter.h" |
22 | #include "glow/Flags/Flags.h" |
23 | #include "glow/Graph/Graph.h" |
24 | #include "glow/Importer/ONNXModelLoader.h" |
25 | #include "glow/Runtime/DeferredWeightLoader.h" |
26 | #include "glow/Runtime/TraceExporter.h" |
27 | #include "glow/Support/Support.h" |
28 | #include "glow/Support/ZipUtils.h" |
29 | |
30 | #include "llvm/Support/CommandLine.h" |
31 | #include "llvm/Support/FileSystem.h" |
32 | #include "llvm/Support/Signals.h" |
33 | |
34 | #include "google/protobuf/io/coded_stream.h" |
35 | #include "google/protobuf/io/zero_copy_stream_impl.h" |
36 | #include <glog/logging.h> |
37 | |
38 | #include "folly/stats/Histogram.h" |
39 | |
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
43 | |
44 | using namespace glow; |
45 | |
46 | namespace { |
47 | llvm::cl::OptionCategory reproTestCat("Repro Category" ); |
48 | llvm::cl::opt<std::string> modelPathOpt("model" , llvm::cl::desc("Input models" ), |
49 | llvm::cl::value_desc("modelPath" ), |
50 | llvm::cl::Required, |
51 | llvm::cl::cat(reproTestCat)); |
52 | llvm::cl::opt<std::string> deferredWeightsPathOpt( |
53 | "deferred_weights" , llvm::cl::desc("Path to the deferred weights file" ), |
54 | llvm::cl::Optional, llvm::cl::init("" ), llvm::cl::cat(reproTestCat)); |
55 | llvm::cl::list<std::string> inputsOpt("inputs" , llvm::cl::desc("Inputs" ), |
56 | llvm::cl::value_desc("Inputs" ), |
57 | llvm::cl::Optional, llvm::cl::ZeroOrMore, |
58 | llvm::cl::cat(reproTestCat)); |
59 | llvm::cl::list<std::string> outputsOpt("outputs" , llvm::cl::desc("Ouptuts" ), |
60 | llvm::cl::value_desc("Ouptuts" ), |
61 | llvm::cl::Optional, llvm::cl::ZeroOrMore, |
62 | llvm::cl::cat(reproTestCat)); |
63 | llvm::cl::opt<std::string> |
64 | inputPatternOpt("input_pattern" , |
65 | llvm::cl::desc("Input file pattern. in_{}.onnx" ), |
66 | llvm::cl::init("" ), llvm::cl::cat(reproTestCat)); |
67 | llvm::cl::opt<std::string> |
68 | outputPatternOpt("output_pattern" , |
69 | llvm::cl::desc("Output file pattern. out_{}.onnx" ), |
70 | llvm::cl::init("" ), llvm::cl::cat(reproTestCat)); |
71 | llvm::cl::opt<unsigned> seqStartOpt( |
72 | "seq_start" , llvm::cl::desc("Start index of input/output files" ), |
73 | llvm::cl::Optional, llvm::cl::init(0), llvm::cl::cat(reproTestCat)); |
74 | llvm::cl::opt<unsigned> seqLenOpt( |
75 | "seq_len" , llvm::cl::desc("Lengths of the input/output file seqquence." ), |
76 | llvm::cl::Optional, llvm::cl::init(0), llvm::cl::cat(reproTestCat)); |
77 | |
78 | llvm::cl::opt<std::string> ExecutionBackend( |
79 | "backend" , llvm::cl::desc("Backend to use, e.g., Interpreter, CPU, NNPI:" ), |
80 | llvm::cl::init("NNPI" ), llvm::cl::cat(reproTestCat)); |
81 | |
82 | llvm::cl::opt<unsigned> concurrentCountOpt( |
83 | "concurrent_count" , llvm::cl::desc("Number of concurrent requests." ), |
84 | llvm::cl::Optional, llvm::cl::init(1), llvm::cl::cat(reproTestCat)); |
85 | |
86 | llvm::cl::opt<float> deviceMemoryOpt( |
87 | "glow_device_memory" , |
88 | llvm::cl::desc("Size of memory for a certain Glow backend device" ), |
89 | llvm::cl::Optional, llvm::cl::init(256 * 1024.0 * 1024.0 * 1024.0), |
90 | llvm::cl::cat(reproTestCat)); |
91 | |
92 | llvm::cl::opt<float> thresholdOpt( |
93 | "threshold" , llvm::cl::desc("theshold for tensor numeric comparison" ), |
94 | llvm::cl::Optional, llvm::cl::init(1e-5), llvm::cl::cat(reproTestCat)); |
95 | |
96 | llvm::cl::opt<bool> glowDumpGraphAfterLoadOpt( |
97 | "glow_dump_graph_after_load" , |
98 | llvm::cl::desc( |
99 | "Dump the glow Graph into files immediately after loading from ONNX" ), |
100 | llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(reproTestCat)); |
101 | |
102 | llvm::cl::opt<bool> sliceConcatFp32Opt( |
103 | "glow_slice_concat_fp32" , |
104 | llvm::cl::desc("Don't convert slice and concat ops's precision when " |
105 | "--glow_global_fp16 is used." ), |
106 | llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(reproTestCat)); |
107 | |
108 | llvm::cl::opt<bool> dumpOutputsOpt("dump_outputs" , |
109 | llvm::cl::desc("Dump output tensors" ), |
110 | llvm::cl::Optional, llvm::cl::init(true), |
111 | llvm::cl::cat(reproTestCat)); |
112 | |
113 | llvm::cl::opt<bool> indicesInt64Opt( |
114 | "glow_global_indices_fp64" , |
115 | llvm::cl::desc("Enable converting scale/offset in frwqslws's data from " |
116 | "int32 to int64" )); |
117 | |
118 | llvm::cl::opt<bool, /* ExternalStorage */ true> enablePartialTensorOpt( |
119 | "glow_enable_partial_tensor" , llvm::cl::desc("Enable partial tensor" ), |
120 | llvm::cl::Optional, llvm::cl::location(glow::flags::EnablePartialTensors), |
121 | llvm::cl::init(true), llvm::cl::cat(reproTestCat)); |
122 | |
123 | llvm::cl::opt<unsigned> itersOpt( |
124 | "iters" , |
125 | llvm::cl::desc("Total number of requests to loop over provided input." ), |
126 | llvm::cl::Optional, llvm::cl::init(1), llvm::cl::cat(reproTestCat)); |
127 | |
128 | llvm::cl::alias requestCountOpt("request_count" , |
129 | llvm::cl::desc("Alias for -iters" ), |
130 | llvm::cl::aliasopt(itersOpt)); |
131 | |
132 | llvm::cl::opt<unsigned> durationMinOpt( |
133 | "duration_min" , llvm::cl::desc("Running duration limit in minutes" ), |
134 | llvm::cl::Optional, llvm::cl::init(0), llvm::cl::cat(reproTestCat)); |
135 | |
136 | llvm::cl::opt<bool> glowEnableDeviceTrace( |
137 | "glow_enable_device_traces" , |
138 | llvm::cl::desc("Enable trace events from inference backend device." ), |
139 | llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(reproTestCat)); |
140 | |
141 | llvm::cl::opt<bool> skipCorrectnessCheck( |
142 | "skip_correctness_check" , llvm::cl::desc("Skip correctness check" ), |
143 | llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(reproTestCat)); |
144 | |
145 | llvm::cl::opt<std::string> |
146 | glowDumpTraceFile("glow_dump_debug_traces_file" , |
147 | llvm::cl::desc("Dump glow trace file" ), |
148 | llvm::cl::Optional, llvm::cl::init(std::string("" )), |
149 | llvm::cl::cat(reproTestCat)); |
150 | |
151 | llvm::cl::opt<int32_t> |
152 | topKCompare("topk_compare" , |
153 | llvm::cl::desc("Compare the topk results against reference" ), |
154 | llvm::cl::Optional, llvm::cl::init(0), |
155 | llvm::cl::cat(reproTestCat)); |
156 | |
157 | llvm::cl::opt<float> top1Threshold( |
158 | "top1_threshold" , |
159 | llvm::cl::desc( |
160 | "Percentage of top1 matches to reference that must be achieved" ), |
161 | llvm::cl::Optional, llvm::cl::init(0.0), llvm::cl::cat(reproTestCat)); |
162 | |
163 | llvm::cl::opt<bool> logTopKResultsPerExample( |
164 | "log_topk_results_per_example" , |
165 | llvm::cl::desc("Whether to log topk results vs reference for each example" ), |
166 | llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(reproTestCat)); |
167 | |
168 | llvm::cl::opt<bool> cosineSimilarityStats( |
169 | "cosine_similarity_stats" , |
170 | llvm::cl::desc("Whether to compute cosine similarity stats" ), |
171 | llvm::cl::Optional, llvm::cl::init(false), llvm::cl::cat(reproTestCat)); |
172 | |
173 | llvm::cl::opt<float> cosineSimilarityThreshold( |
174 | "p50_cosine_similarity_threshold" , |
175 | llvm::cl::desc( |
176 | "Percentage of top1 matches to reference that must be achieved" ), |
177 | llvm::cl::Optional, llvm::cl::init(0.0), llvm::cl::cat(reproTestCat)); |
178 | |
179 | llvm::cl::opt<bool> onnxLoaderZipMode( |
180 | "zip_mode" , llvm::cl::desc("zipMode to use with OnnxModelLoader" ), |
181 | llvm::cl::Optional, llvm::cl::init(true), llvm::cl::cat(reproTestCat)); |
182 | llvm::cl::opt<unsigned> replicationCountOpt( |
183 | "replication_count" , llvm::cl::desc("Set the network replication count" ), |
184 | llvm::cl::Optional, llvm::cl::init(1), llvm::cl::cat(reproTestCat)); |
185 | |
186 | /// Explicitly show gflags help/version info, depending on \p foundHelpFlag and |
187 | /// \p foundVersionFlag. llvm shows its own help/version info when it parses. |
188 | void gflagsShowHelpVersion(bool foundHelpFlag, bool foundVersionFlag) { |
189 | const char *binName = gflags::ProgramInvocationShortName(); |
190 | if (foundHelpFlag) { |
191 | gflags::SetUsageMessage( |
192 | strFormat("gflags for %s\nUSAGE: %s [options]:" , binName, binName)); |
193 | gflags::ShowUsageWithFlagsRestrict(binName, /* restrict_ */ "" ); |
194 | llvm::outs() << "\nLLVM CommandLine options:\n" ; |
195 | } |
196 | if (foundVersionFlag) { |
197 | llvm::outs() << "gflags version:\n" ; |
198 | const char *versionStr = gflags::VersionString(); |
199 | llvm::outs() << binName; |
200 | if (versionStr && *versionStr) { |
201 | llvm::outs() << " version " << versionStr; |
202 | } |
203 | llvm::outs() << "\n\n" ; |
204 | } |
205 | } |
206 | |
/// Parse \p argc/\p argv, routing each flag either to gflags or to llvm::cl.
/// Overrides some gflags defaults for this tool, verifies the two flag
/// libraries don't both claim the same flag name, then parses gflags first
/// and llvm second. Also applies post-parse fixups (top1/cosine thresholds
/// implying their corresponding checks).
void parseCommandLine(int argc, char **argv) {
  // Use different defaults for some flags:
  FLAGS_glow_global_fp16 = true;
  FLAGS_glow_clip_fp16 = true;
  FLAGS_glow_global_fused_scale_offset_fp16 = true;
  FLAGS_glow_global_fused_scale_offset_fp32 = false;
  FLAGS_glow_snn_partitioning_kbytes_per_card = 5000000;
  FLAGS_glow_snn_partitioning_num_cores_sls = 6;
  FLAGS_glow_snn_partitioning_num_cores_other = 6;

  llvm::sys::PrintStackTraceOnErrorSignal(argv[0]);

  // Verify there's no unexpected overlap in flags by llvm/gflags.
  // help/version are intentionally handled by both (see below).
  const auto &llvmOpts = llvm::cl::getRegisteredOptions();
  for (const auto &opt : llvmOpts) {
    static const llvm::StringSet<> allowedInBoth = {"help" , "version" };
    if (allowedInBoth.count(opt.getKey())) {
      continue;
    }
    gflags::CommandLineFlagInfo dummy;
    CHECK(!gflags::GetCommandLineFlagInfo(opt.getKey().data(), &dummy))
        << "Error: Repeated flag used by both llvm and gflags: "
        << opt.getKey().data();
  }

  // Separate out llvm and gflags into their own argc/argv.
  // Both vectors start with the program name so each parser sees a
  // well-formed argv.
  llvm::SmallVector<char *, 40> llvmArgv, gflagsArgv;
  llvmArgv.push_back(argv[0]);
  gflagsArgv.push_back(argv[0]);
  bool foundHelpFlag = false;
  bool foundVersionFlag = false;
  for (int i = 1; i < argc; ++i) {
    llvm::StringRef flagName(argv[i]);
    // Positional args are always llvm cl args.
    if (!flagName.startswith("-" )) {
      llvmArgv.push_back(argv[i]);
      continue;
    }

    // Strip off leading '-'.
    flagName = flagName.drop_while([](char c) -> bool { return c == '-'; });
    // Look for everything leading up to '=', if any.
    flagName = flagName.take_until([](char c) -> bool { return c == '='; });

    // Now check if flagName is a gflag, otherwise assume it was from llvm. If
    // help/version, always pass to llvm; we will also call gflags directly for
    // them to print before llvm parses/prints, so that both gflags and llvm
    // will print help/version.
    gflags::CommandLineFlagInfo dummy;
    if (!gflags::GetCommandLineFlagInfo(flagName.str().c_str(), &dummy) ||
        flagName == "help" || flagName == "version" ) {
      llvmArgv.push_back(argv[i]);
      if (flagName == "help" ) {
        foundHelpFlag = true;
      } else if (flagName == "version" ) {
        foundVersionFlag = true;
      }
    } else {
      gflagsArgv.push_back(argv[i]);
    }
  }
  int llvmArgc = static_cast<int>(llvmArgv.size());
  int gflagsArgc = static_cast<int>(gflagsArgv.size());

  // Now we can parse both llvm and gflags safely. All gflags should be
  // legitimate. All other flags will be passed to llvm, which will complain
  // about unknown ones.
  // gflags is parsed first so its help/version output precedes llvm's.
  char **gflagsArgvPtr = &gflagsArgv[0];
  gflags::AllowCommandLineReparsing();
  gflags::ParseCommandLineFlags(&gflagsArgc, &gflagsArgvPtr,
                                /* remove_flags */ false);
  gflagsShowHelpVersion(foundHelpFlag, foundVersionFlag);
  llvm::cl::ParseCommandLineOptions(
      llvmArgc, &llvmArgv[0],
      " The Glow compiler\n\n"
      "Glow is a compiler for neural network accelerators.\n" );

  // Requesting a top1 threshold implies at least a top-1 comparison.
  if (top1Threshold > 0.0 && topKCompare == 0) {
    topKCompare = 1;
  }

  // Requesting a cosine-similarity threshold implies computing the stats.
  if (cosineSimilarityThreshold > 0.0) {
    cosineSimilarityStats = true;
  }
}
292 | |
/// Result of a single inference request issued by the benchmark loop.
struct InferenceResult {
  // Outcome of the run; starts as an empty (success) Error.
  Error error = Error::empty();
  // Execution context (placeholder bindings, trace context) for the request.
  std::unique_ptr<ExecutionContext> ctx;
  // Presumably the input/output group index this request used — set by the
  // inference loop (not fully visible in this chunk).
  int index = 0;
  // Completion timestamp of the request.
  std::chrono::time_point<std::chrono::steady_clock> endTime;
};
299 | |
300 | class ZipFileBackedDeferredBlobLoader |
301 | : public ::glow::runtime::DeferredWeightLoader { |
302 | public: |
303 | explicit ZipFileBackedDeferredBlobLoader(const std::string &path) { |
304 | zip_ = ::glow::make_unique<::glow::ZipReader>(path); |
305 | CHECK(zip_); |
306 | auto numWeightsStr = zip_->getRecord("weights" ); |
307 | weightsToLoad_ = atoi(numWeightsStr.c_str()); |
308 | i_ = 0; |
309 | } |
310 | |
311 | ::glow::Error loadNextWeight() override { |
312 | if (weightsToLoad_ == i_) { |
313 | llvm::outs() << "All deferred weights are loaded\n" ; |
314 | currentBlobName_ = "" ; |
315 | currentTensor_.reset(); |
316 | zip_.reset(nullptr); |
317 | return ::glow::Error::success(); |
318 | } |
319 | |
320 | std::stringstream ss; |
321 | ss << "weight_" << i_; |
322 | largeBuffer_ = zip_->getRecord(ss.str()); |
323 | ::ONNX_NAMESPACE::TensorProto t; |
324 | t.ParseFromString(largeBuffer_); |
325 | |
326 | currentBlobName_ = glow::legalizeName(t.name()); |
327 | auto tyIdx = typeInfo_.find(currentBlobName_); |
328 | if (tyIdx == typeInfo_.end()) { |
329 | return ::MAKE_ERR( |
330 | ::glow::ErrorValue::ErrorCode::RUNTIME_ERROR, |
331 | ::glow::strFormat( |
332 | "Error: Blob name: %s not found in list of static placeholders." , |
333 | currentBlobName_.c_str())); |
334 | } |
335 | auto ty = typeInfo_[currentBlobName_]; |
336 | |
337 | ss.str("" ); |
338 | ss << "data_" << i_++; |
339 | largeBuffer_.clear(); |
340 | if (zip_->hasRecord(ss.str())) { |
341 | largeBuffer_ = zip_->getRecord(ss.str()); |
342 | LOG(INFO) << "Read weight data " << ss.str() << " of size " |
343 | << largeBuffer_.size(); |
344 | } |
345 | currentTensor_.reset(new ::glow::Tensor()); |
346 | RETURN_IF_ERR(::glow::loadTensor(t, currentTensor_.get(), |
347 | /*useGlowCustomOps*/ false, largeBuffer_)); |
348 | CHECK(currentTensor_->getType().isEqual(ty)) |
349 | << "Mismatched tensor type: " << currentTensor_->getType().toString() |
350 | << " vs " << ty.toString(); |
351 | |
352 | return ::glow::Error::success(); |
353 | } |
354 | |
355 | ::glow::Error setSrc(void * /*unused*/) override { |
356 | return ::glow::Error::success(); |
357 | } |
358 | |
359 | std::string getName() override { return currentBlobName_; } |
360 | |
361 | ::glow::Tensor *getTensor() override { return currentTensor_.get(); } |
362 | |
363 | void setTypeInfo(std::map<std::string, ::glow::Type> info) override { |
364 | typeInfo_ = std::move(info); |
365 | } |
366 | |
367 | private: |
368 | std::unique_ptr<::glow::ZipReader> zip_; |
369 | std::string largeBuffer_; |
370 | std::string currentBlobName_; |
371 | std::unique_ptr<::glow::Tensor> currentTensor_; |
372 | size_t weightsToLoad_{0}; |
373 | size_t i_{0}; |
374 | }; |
375 | |
376 | /// Given a float Tensor \p t, \returns a vector of pairs of entries in t with |
377 | /// the first element in the pair being the value and the second element being |
378 | /// the original index of that value in t. The vector is partially sorted such |
379 | /// that the first \p k elements are the k elements from t with the greatest |
380 | /// values. |
381 | static std::vector<std::pair<float, size_t>> |
382 | partialSortFloatTensor(const Tensor &t, size_t k) { |
383 | std::vector<std::pair<float, size_t>> vec; |
384 | auto handle = t.getHandle<float>(); |
385 | for (size_t i = 0; i < handle.size(); ++i) { |
386 | vec.push_back({handle.raw(i), i}); |
387 | } |
388 | std::partial_sort( |
389 | vec.begin(), vec.begin() + k, vec.end(), |
390 | [](const auto &p1, const auto &p2) { return p1.first > p2.first; }); |
391 | return vec; |
392 | } |
393 | |
394 | static float dotProd(const Tensor &t1, const Tensor &t2) { |
395 | CHECK(t1.getElementType() == ElemKind::FloatTy); |
396 | CHECK(t2.getElementType() == ElemKind::FloatTy); |
397 | auto t1H = t1.getHandle<float>(); |
398 | auto t2H = t2.getHandle<float>(); |
399 | CHECK_EQ(t1H.size(), t2H.size()); |
400 | float res = 0.0f; |
401 | for (dim_t i = 0; i < t1H.size(); i++) { |
402 | res += t1H.raw(i) * t2H.raw(i); |
403 | } |
404 | return res; |
405 | } |
406 | |
407 | static float cosineSimilarity(const Tensor &t1, const Tensor &t2) { |
408 | auto fn = [](const Tensor &t1, const Tensor &t2) { |
409 | return dotProd(t1, t2) / |
410 | (std::sqrt(dotProd(t1, t1)) * std::sqrt(dotProd(t2, t2))); |
411 | }; |
412 | if (t1.getType().isQuantizedType()) { |
413 | auto t1Float = quantization::dequantizeTensor(t1, ElemKind::FloatTy); |
414 | auto t2Float = quantization::dequantizeTensor(t2, ElemKind::FloatTy); |
415 | return fn(t1Float, t2Float); |
416 | } else { |
417 | return fn(t1, t2); |
418 | } |
419 | } |
420 | |
421 | int run() { |
422 | int numFailed = 0; |
423 | |
424 | int numTop1Matches = 0; |
425 | int numTopKMatches = 0; |
426 | int numTotalTopKCompares = 0; |
427 | |
428 | folly::Histogram<float> cosineHist(/* bucketSize */ 0.1f, /* min */ 0.0f, |
429 | /* max */ 1.0f); |
430 | |
431 | // Build the execution engine and deserialize the Function. |
432 | auto mod = glow::make_unique<Module>(); |
433 | Error err = Error::empty(); |
434 | bool usingGlowCustomOps = false; |
435 | CompilationContext cctx; |
436 | cctx.replicationCount = replicationCountOpt; |
437 | cctx.maxActiveRequestsPerInstance = glow::flags::MaxActiveRequestsPerInstance; |
438 | runtime::PrePartitionedConfig PPC; |
439 | cctx.prepartitionedConfig = &PPC; |
440 | { |
441 | ONNXModelLoader onnxLD(modelPathOpt, {}, {}, *mod, "test" , &PPC, &err, |
442 | onnxLoaderZipMode, |
443 | &cctx.backendOpts.backendSpecificNodeInfo, |
444 | /* loadIntoExistingModule */ false, |
445 | /* disableConstFoldInLoader */ true); |
446 | usingGlowCustomOps = onnxLD.usingGlowCustomOps(); |
447 | } |
448 | CHECK(!ERR_TO_BOOL(std::move(err))) |
449 | << "ONNXModelLoader failed to load model: " << modelPathOpt; |
450 | llvm::outs() << "End onnx model load\n" ; |
451 | |
452 | if (glowDumpGraphAfterLoadOpt) { |
453 | for (Function *F : mod->getFunctions()) { |
454 | F->dumpDAG(glow::flags::DumpGraphPath + F->getName().str() + ".dot" ); |
455 | } |
456 | } |
457 | |
458 | // Build host manager and compile the module. |
459 | PrecisionConfiguration &precConfig = cctx.precisionConfig; |
460 | if (glow::flags::ConvertToFP16) { |
461 | precConfig.convertToFP16 = true; |
462 | if (sliceConcatFp32Opt) { |
463 | precConfig.precisionModeKindSet.insert(Kinded::Kind::SliceNodeKind); |
464 | precConfig.precisionModeKindSet.insert(Kinded::Kind::ConcatNodeKind); |
465 | } |
466 | llvm::outs() << "Conversion to fp16 enabled\n" ; |
467 | } |
468 | if (glow::flags::ConvertPlaceholdersToFP16) { |
469 | precConfig.convertPlaceholdersToFP16 = true; |
470 | llvm::outs() << "Conversion of Placeholders to fp16 enabled\n" ; |
471 | } |
472 | if (glow::flags::ConvertConstantsToFP16) { |
473 | precConfig.convertConstantsToFP16 = true; |
474 | llvm::outs() << "Conversion of Constants to fp16 enabled\n" ; |
475 | } |
476 | if (glow::flags::ConvertFusedScaleOffsetToFP16) { |
477 | precConfig.convertFusedToFP16 = true; |
478 | llvm::outs() << "Conversion of fused scales/offsets to fp16 enabled\n" ; |
479 | } |
480 | if (glow::flags::ConvertFusedScaleOffsetToFP32) { |
481 | precConfig.convert4BitFusedToFP32 = true; |
482 | precConfig.convert8BitFusedToFP32 = true; |
483 | llvm::outs() |
484 | << "Conversion of fused scales/offsets to fp32 in frwqslws enabled\n" ; |
485 | } |
486 | if (indicesInt64Opt) { |
487 | precConfig.convertIndicesToInt64 = indicesInt64Opt; |
488 | llvm::outs() << "Conversion of indices to int64 enabled\n" ; |
489 | } |
490 | if (glow::flags::ClipToFP16) { |
491 | precConfig.clipFP16 = true; |
492 | llvm::outs() << "Clipping to fp16 enabled\n" ; |
493 | } |
494 | if (glow::flags::SkipInputsOnClipToFP16) { |
495 | precConfig.clipFP16SkipInputs = true; |
496 | llvm::outs() << "Skipping clipping for fp16 Node inputs fp16\n" ; |
497 | } |
498 | if (glow::flags::ForceSLSToFP16Accum) { |
499 | precConfig.forceFP16AccumSLS = true; |
500 | llvm::outs() << "Forcing fp16 accumulation for SLS ops enabled\n" ; |
501 | } |
502 | if (!glow::flags::EnableQuantParamChanges) { |
503 | cctx.optimizationOpts.enableQuantParamChanges = false; |
504 | LOG(INFO) << "Disabling quantization param changes during optimizations" ; |
505 | } |
506 | if (glow::flags::DumpGraph) { |
507 | cctx.dumpFinalGraph = true; |
508 | cctx.dumpGraphPath = glow::flags::DumpGraphPath; |
509 | } |
510 | if (glow::flags::SinkTanhBelowConcat) { |
511 | cctx.optimizationOpts.sinkTanhBelowConcat = |
512 | glow::flags::SinkTanhBelowConcat; |
513 | LOG(INFO) << "Sinking tanh below concat" ; |
514 | } |
515 | |
516 | if (glow::flags::UseSparseNNPartitioningScheme) { |
517 | cctx.optimizationOpts.useSparseNNPartitioningScheme = true; |
518 | cctx.optimizationOpts.sparseNNPartitioningAddSLSConcats = |
519 | glow::flags::SparseNNPartitioningAddSLSConcats; |
520 | cctx.optimizationOpts.sparseNNPartitioningBalancePerfModel = |
521 | glow::flags::SparseNNPartitioningBalancePerfModel; |
522 | cctx.optimizationOpts.sparseNNPartitioningPairLNWithSLS = |
523 | glow::flags::SparseNNPartitioningPairLNWithSLS; |
524 | cctx.optimizationOpts.sparseNNPartitioningPairTileWithSLS = |
525 | glow::flags::SparseNNPartitioningPairTileWithSLS; |
526 | cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards = |
527 | glow::flags::SparseNNPartitioningSchemeNumCards; |
528 | cctx.optimizationOpts.sparseNNPartitioningSchemeSLSTableKBytesPerCard = |
529 | glow::flags::SparseNNPartitioningSchemeSLSTableKBytesPerCard; |
530 | cctx.optimizationOpts.sparseNNPartitioningSchemeNumCoresSLS = |
531 | glow::flags::SparseNNPartitioningSchemeNumCoresSLS; |
532 | cctx.optimizationOpts.sparseNNPartitioningSchemeNumCoresOther = |
533 | glow::flags::SparseNNPartitioningSchemeNumCoresOther; |
534 | } |
535 | |
536 | if (!glow::flags::processBackendSpecificOpts( |
537 | cctx.backendOpts.backendSpecificOpts, |
538 | glow::flags::BackendSpecificOpts)) { |
539 | return -1; |
540 | } |
541 | |
542 | if (glow::nnpi::flags::NumParallelChunks > 1) { |
543 | cctx.backendOpts.backendSpecificOpts["NNPINumParallelChunks" ] = |
544 | std::to_string(glow::nnpi::flags::NumParallelChunks); |
545 | } |
546 | if (glow::nnpi::flags::ModelParallelSplitAlignment > 1) { |
547 | cctx.backendOpts.backendSpecificOpts["NNPIModelParallelSplitAlignment" ] = |
548 | std::to_string(glow::nnpi::flags::ModelParallelSplitAlignment); |
549 | } |
550 | if (glow::flags::UseDAGOptimizer) { |
551 | cctx.callDAGOptimizer = true; |
552 | cctx.optimizationOpts.DAGOptimizerNumParallelChunks = |
553 | glow::flags::DAGOptimizerNumParallelChunks; |
554 | cctx.optimizationOpts.DAGOptimizerParallelizationTaggingAlgorithm = |
555 | glow::flags::DAGOptimizerParallelizationTaggingAlgorithm; |
556 | cctx.optimizationOpts.DAGOptimizerPlacementTaggingAlgorithm = |
557 | glow::flags::DAGOptimizerPlacementTaggingAlgorithm; |
558 | } |
559 | if (glow::flags::DelayAndRecordConstantModification) { |
560 | cctx.optimizationOpts.delayAndRecordConstantModification = true; |
561 | } |
562 | if (glow::runtime::flags::EnableP2P) { |
563 | LOG(INFO) << "Glow P2P Enabled" ; |
564 | cctx.enableP2P = true; |
565 | } |
566 | if (glow::runtime::flags::EnableDRT) { |
567 | LOG(INFO) << "Glow DRT Enabled" ; |
568 | cctx.enableDRT = true; |
569 | } |
570 | if (glow::onnxifi::flags::SaveDAG) { |
571 | LOG(INFO) << "Serializing DAG after optimization and partitioning." ; |
572 | cctx.serializeCompiledDAG = true; |
573 | } |
574 | if (glow::onnxifi::flags::SaveDAGWithConstants) { |
575 | LOG(INFO) << "Serializing DAG with constants after optimization and " |
576 | "partitioning." ; |
577 | cctx.saveConstantInSerializeCompiledDAG = true; |
578 | } |
579 | if (glow::onnxifi::flags::SaveDAGInZipMode) { |
580 | LOG(INFO) << "Serializing DAG with constants after optimization and " |
581 | "partitioning in Zip mode." ; |
582 | cctx.useZipModeForSerializeCompiledDAG = true; |
583 | } |
584 | |
585 | // Load deferred weights if applicable |
586 | const auto &placeholderList = mod->getPlaceholders(); |
587 | glow::PlaceholderList nonStaticPlaceholderList; |
588 | std::copy_if(placeholderList.begin(), placeholderList.end(), |
589 | std::back_inserter(nonStaticPlaceholderList), |
590 | [](const glow::Placeholder *p) { return !p->isStatic(); }); |
591 | if (!deferredWeightsPathOpt.empty()) { |
592 | ::glow::runtime::DeferredLoader()->registerLoader( |
593 | new ZipFileBackedDeferredBlobLoader(deferredWeightsPathOpt)); |
594 | // Initialize loader and set field in cctx. |
595 | auto *loader = runtime::DeferredLoader()->getLoader(); |
596 | CHECK(loader) << "No deferred weights loader registered!" ; |
597 | |
598 | // Generate a map of type date for all static placeholders. |
599 | std::map<std::string, Type> staticPlaceholderTypes; |
600 | for (auto *PH : placeholderList) { |
601 | if (PH->isStatic()) { |
602 | staticPlaceholderTypes[std::string(PH->getName())] = *PH->getType(); |
603 | } |
604 | } |
605 | loader->setTypeInfo(std::move(staticPlaceholderTypes)); |
606 | CHECK(!loader->setSrc(nullptr)); |
607 | cctx.deferredWeightLoader = loader; |
608 | // Signal that we want to fold convertTo and Quantize into static |
609 | // Placeholders. |
610 | cctx.optimizationOpts.foldStaticPlaceholderConversions = true; |
611 | } |
612 | |
613 | auto configs = runtime::generateDeviceConfigs( |
614 | glow::flags::NumDevices, ExecutionBackend, deviceMemoryOpt); |
615 | runtime::HostConfig hostConfig; |
616 | hostConfig.maxActiveRequests = glow::flags::MaxActiveRequests; |
617 | hostConfig.maxQueueSize = glow::flags::MaxQueueSize; |
618 | hostConfig.executorThreads = glow::flags::ExecutorThreads; |
619 | |
620 | auto hostManager = |
621 | glow::make_unique<runtime::HostManager>(std::move(configs), hostConfig); |
622 | if (glow::flags::EnablePartialTensors) { |
623 | CHECK(hostManager->getBackend(ExecutionBackend).supportsPartialTensors()) |
624 | << "Backend " << ExecutionBackend |
625 | << " doesn't support partial tensor but enablePartialTensor is set to " |
626 | "true." ; |
627 | } |
628 | cctx.saturateHost = glow::flags::SaturateHost; |
629 | EXIT_ON_ERR(hostManager->addNetwork(std::move(mod), cctx)); |
630 | |
631 | // Whether to collect results and check accuracy. If we're not checking |
632 | // accuracy then don't load reference outputs |
633 | bool runAccuracyChecks = |
634 | !skipCorrectnessCheck || topKCompare > 0 || cosineSimilarityStats; |
635 | |
636 | // Parse all input and output files ahead of inference. |
637 | std::vector<::ONNX_NAMESPACE::GraphProto> parsedInputs; |
638 | std::vector<::ONNX_NAMESPACE::GraphProto> parsedOutputs; |
639 | size_t inputGroupSize = inputsOpt.size(); |
640 | if (inputGroupSize) { |
641 | for (size_t i = 0; i < inputGroupSize; ++i) { |
642 | llvm::outs() << "Loading input file: " << inputsOpt[i] << "\n" ; |
643 | auto inputGroup = parseOnnxFile(inputsOpt[i]); |
644 | parsedInputs.push_back(std::move(inputGroup)); |
645 | if (runAccuracyChecks) { |
646 | llvm::outs() << "Loading output file: " << outputsOpt[i] << "\n" ; |
647 | auto outputGroup = parseOnnxFile(outputsOpt[i]); |
648 | parsedOutputs.push_back(std::move(outputGroup)); |
649 | } |
650 | } |
651 | } else if (!inputPatternOpt.empty() && seqLenOpt > 0) { |
652 | inputGroupSize = seqLenOpt; |
653 | size_t input_iter = inputPatternOpt.find("{}" ); |
654 | CHECK_NE(input_iter, std::string::npos) |
655 | << "Input pattern " << inputPatternOpt << " has to contain {}" ; |
656 | for (unsigned i = 0; i < seqLenOpt; ++i) { |
657 | std::string copy = inputPatternOpt; |
658 | copy.replace(input_iter, 2, std::to_string(seqStartOpt + i)); |
659 | llvm::outs() << "Loading input file: " << copy << "\n" ; |
660 | auto inputGroup = parseOnnxFile(copy); |
661 | parsedInputs.push_back(std::move(inputGroup)); |
662 | } |
663 | |
664 | if (runAccuracyChecks) { |
665 | CHECK(!outputPatternOpt.empty()) |
666 | << "Output pattern must be provided for accuracy checks" ; |
667 | size_t output_iter = outputPatternOpt.find("{}" ); |
668 | CHECK_NE(output_iter, std::string::npos) |
669 | << "Output pattern " << outputPatternOpt << " has to contain {}" ; |
670 | for (unsigned i = 0; i < seqLenOpt; ++i) { |
671 | std::string copy = outputPatternOpt; |
672 | copy.replace(output_iter, 2, std::to_string(seqStartOpt + i)); |
673 | llvm::outs() << "Loading output file: " << copy << "\n" ; |
674 | auto outputGroup = parseOnnxFile(copy); |
675 | parsedOutputs.push_back(std::move(outputGroup)); |
676 | } |
677 | } |
678 | } |
679 | |
680 | llvm::outs() << "\ninput pattern: " + inputPatternOpt + "\n" ; |
681 | llvm::outs() << "\nseqlen: " + std::to_string(seqLenOpt) + "\n" ; |
682 | llvm::outs() << "\ninputgroupsize: " + std::to_string(inputGroupSize) + "\n" ; |
683 | |
684 | if (parsedInputs.empty()) { |
685 | llvm::outs() << "No inputs are provided. Exiting...\n" ; |
686 | return -1; |
687 | } |
688 | |
689 | llvm::outs() << "Starting inference\n" ; |
690 | llvm::outs().flush(); // Explicit flush to denote the progress |
691 | |
692 | auto nowTime = std::chrono::steady_clock::now(); |
693 | auto endTimeDuration = nowTime + std::chrono::minutes(durationMinOpt); |
694 | do { |
695 | TraceContext mergedTraceContext(TraceLevel::STANDARD); |
696 | folly::CPUThreadPoolExecutor threadPool(concurrentCountOpt); |
697 | std::mutex mutex; |
698 | std::condition_variable cv; |
699 | int numTotalInferences = inputGroupSize * itersOpt; |
700 | int numFinishedInferences = 0; |
701 | |
702 | // Figure out which placeholder is input. |
703 | std::unordered_set<std::string> inputTensorNames; |
704 | for (const auto &proto : parsedInputs[0].initializer()) { |
705 | inputTensorNames.insert(glow::legalizeName(proto.name())); |
706 | } |
707 | |
708 | glow::PlaceholderList inputPlaceholderList; |
709 | std::copy_if(placeholderList.begin(), placeholderList.end(), |
710 | std::back_inserter(inputPlaceholderList), |
711 | [&](const glow::Placeholder *p) { |
712 | return inputTensorNames.find(p->getName().str()) != |
713 | inputTensorNames.end(); |
714 | }); |
715 | |
716 | std::vector<Tensor> partialTensorPayloads; |
717 | std::vector<PlaceholderBindings> inputBindings; |
718 | for (const auto &inputGroup : parsedInputs) { |
719 | PlaceholderBindings bindings; |
720 | bindings.allocate(inputPlaceholderList); |
721 | fillPlaceholders( |
722 | inputGroup, &bindings, |
723 | glow::flags::EnablePartialTensors ? &partialTensorPayloads : nullptr, |
724 | usingGlowCustomOps); |
725 | inputBindings.emplace_back(std::move(bindings)); |
726 | } |
727 | |
728 | bool enableGlowTrace = glow::flags::DumpDebugTraces || |
729 | TraceExporterRegistry::getInstance()->shouldTrace(); |
730 | |
731 | if (enableGlowTrace && glowEnableDeviceTrace) { |
732 | // Start device traces. |
733 | hostManager->setTraceContext( |
734 | glow::make_unique<TraceContext>(TraceLevel::STANDARD)); |
735 | Error startErr = hostManager->startDeviceTrace(); |
736 | if (ERR_TO_BOOL(std::move(startErr))) { |
737 | LOG(WARNING) << "Failed to start device traces" ; |
738 | } |
739 | } |
740 | |
741 | auto startTime = std::chrono::steady_clock::now(); |
742 | std::list<InferenceResult> results; |
743 | for (int ioIndex = 0, numInferencesIssued = 0; |
744 | numInferencesIssued < numTotalInferences; ++numInferencesIssued, |
745 | ioIndex = numInferencesIssued % inputGroupSize) { |
746 | |
747 | results.emplace_back(); |
748 | auto &result = results.back(); |
749 | |
750 | threadPool.add([&inputBindings, &nonStaticPlaceholderList, ioIndex, |
751 | &mergedTraceContext, &hostManager, &result, &cv, &mutex, |
752 | numTotalInferences, &numFinishedInferences, |
753 | runAccuracyChecks, enableGlowTrace]() { |
754 | // Setup the inputs. |
755 | auto ctx = glow::make_unique<ExecutionContext>(); |
756 | |
757 | TraceContext *traceContext = nullptr; |
758 | if (enableGlowTrace) { |
759 | ctx->setTraceContext( |
760 | glow::make_unique<TraceContext>(TraceLevel::STANDARD)); |
761 | traceContext = ctx->getTraceContext(); |
762 | traceContext->setThreadName("Request Thread" ); |
763 | } |
764 | TRACE_EVENT_SCOPE(traceContext, TraceLevel::RUNTIME, |
765 | "Dispatch to prep input and dispatch" ); |
766 | |
767 | // Set up input |
768 | auto &bindings = *ctx->getPlaceholderBindings(); |
769 | bindings.clear(); |
770 | |
771 | for (const auto &binding : inputBindings[ioIndex].pairs()) { |
772 | auto *PH = binding.first; |
773 | bindings.insert(PH, binding.second.getUnowned()); |
774 | } |
775 | // Allocate for output |
776 | bindings.allocate(nonStaticPlaceholderList); |
777 | |
778 | std::promise<void> promise; |
779 | auto future = promise.get_future(); |
780 | |
781 | TRACE_EVENT_SCOPE_END(); |
782 | |
783 | hostManager->runNetwork( |
784 | "test" , std::move(ctx), |
785 | [&promise, index = ioIndex, |
786 | &result](runtime::RunIdentifierTy, Error err, |
787 | std::unique_ptr<ExecutionContext> contextPtr) mutable { |
788 | result.error = std::move(err); |
789 | result.ctx = std::move(contextPtr); |
790 | result.index = index; |
791 | result.endTime = std::chrono::steady_clock::now(); |
792 | promise.set_value(); |
793 | }); |
794 | |
795 | // wait for glow to finish. |
796 | future.wait(); |
797 | traceContext = result.ctx->getTraceContext(); |
798 | if (traceContext) { |
799 | // export to registered trace exporters |
800 | TraceExporterRegistry::getInstance()->exportTrace(traceContext); |
801 | // merge() has internal lock and is thread safe. |
802 | mergedTraceContext.merge(traceContext); |
803 | } |
804 | |
805 | if (!runAccuracyChecks) { |
806 | // if skipping correctness check, throw away the context to keep |
807 | // memory usage low. |
808 | result.ctx.reset(); |
809 | } |
810 | |
811 | std::unique_lock<std::mutex> lock(mutex); |
812 | if (++numFinishedInferences >= numTotalInferences) { |
813 | lock.unlock(); |
814 | cv.notify_all(); |
815 | } |
816 | }); |
817 | } |
818 | |
  // Wait for all inferences to finish.
820 | std::unique_lock<std::mutex> lock(mutex); |
821 | cv.wait(lock, |
822 | [&]() { return numFinishedInferences >= numTotalInferences; }); |
823 | |
824 | auto endTime = startTime; |
825 | llvm::outs() << "All inferences done. Checking results\n" ; |
826 | for (auto &result : results) { |
827 | if (result.endTime > endTime) { |
828 | endTime = result.endTime; |
829 | } |
830 | |
831 | if (result.error) { |
832 | llvm::outs() << "Inference failed!\n" ; |
833 | if (result.error.peekErrorValue()->isFatalError()) { |
834 | std::string msg = result.error.peekErrorValue()->logToString(); |
835 | llvm::outs() << "Non-recoverable device error: " << msg << "\n" ; |
836 | } |
837 | ++numFailed; |
838 | } else { |
839 | ONNX_NAMESPACE::GraphProto outputG; |
840 | std::ofstream of; |
841 | if (dumpOutputsOpt) { |
842 | std::stringstream ss; |
843 | ss << "output_dump_" << result.index << ".onnx" ; |
844 | of.open(ss.str(), std::ios::binary); |
845 | CHECK(of) << "Cannot create output dump file: " << ss.str(); |
846 | } |
847 | |
848 | if (runAccuracyChecks) { |
849 | const auto &outputGroup = parsedOutputs[result.index]; |
850 | CHECK(result.ctx); |
851 | const auto &bindings = *result.ctx->getPlaceholderBindings(); |
852 | for (const auto &tp : outputGroup.initializer()) { |
853 | Tensor tensorRef; |
854 | auto error = loadTensor(tp, &tensorRef, usingGlowCustomOps); |
855 | CHECK(!ERR_TO_BOOL(std::move(error))) |
856 | << "Cannot load output ref tensor" ; |
857 | const auto *tensor = |
858 | bindings.get(bindings.getPlaceholderByNameSlow(tp.name())); |
859 | CHECK(tensor) << "Missing " << tp.name() |
860 | << " in output placeholder" ; |
861 | |
862 | if (cosineSimilarityStats) { |
863 | cosineHist.addValue(cosineSimilarity(*tensor, tensorRef)); |
864 | } |
865 | |
866 | if (topKCompare > 0) { |
867 | numTotalTopKCompares++; |
868 | assert(tensor->size() == tensorRef.size()); |
869 | auto sortedResults = partialSortFloatTensor(*tensor, topKCompare); |
870 | auto sortedRefs = partialSortFloatTensor(tensorRef, topKCompare); |
871 | assert(sortedResults.size() == size_t(topKCompare) && |
872 | sortedResults.size() == size_t(topKCompare)); |
873 | |
874 | bool allKMatch = true; |
875 | std::stringstream ss; |
876 | for (auto i = 0; i < topKCompare; i++) { |
877 | if (sortedResults[i].second == sortedRefs[i].second) { |
878 | if (i == 0) { |
879 | numTop1Matches++; |
880 | } |
881 | } else { |
882 | allKMatch = false; |
883 | } |
884 | if (logTopKResultsPerExample) { |
885 | ss << i << ": Test result: " << sortedResults[i].second |
886 | << " (p=" << sortedResults[i].first |
887 | << ") Reference result: " << sortedRefs[i].second |
888 | << " (p=" << sortedRefs[i].first << ")\n" ; |
889 | } |
890 | } |
891 | if (logTopKResultsPerExample) { |
892 | llvm::outs() << ss.str() << "\n" ; |
893 | } |
894 | if (allKMatch) { |
895 | numTopKMatches++; |
896 | } |
897 | } |
898 | |
899 | if (dumpOutputsOpt) { |
900 | auto *t = outputG.add_initializer(); |
901 | ONNXModelWriter::writeTensor(*tensor, t, usingGlowCustomOps); |
902 | t->set_name(tp.name()); |
903 | } |
904 | |
905 | if (!skipCorrectnessCheck) { |
906 | bool equal = tensorRef.isEqual(*tensor, thresholdOpt, true); |
907 | if (!equal) { |
908 | llvm::outs() << "Verification failed at input/output pair " |
909 | << result.index << " for output tensor " |
910 | << tp.name() << "\n" ; |
911 | ++numFailed; |
912 | break; |
913 | } |
914 | } |
915 | } |
916 | } |
917 | |
918 | if (dumpOutputsOpt) { |
919 | std::string buffer; |
920 | outputG.SerializeToString(&buffer); |
921 | of << buffer; |
922 | } |
923 | } |
924 | } |
925 | |
926 | if (glow::flags::DumpDebugTraces) { |
927 | if (glowEnableDeviceTrace) { |
928 | // Stop device traces and collect events. |
929 | Error stopErr = hostManager->stopDeviceTrace(); |
930 | if (ERR_TO_BOOL(std::move(stopErr))) { |
931 | LOG(WARNING) << "Failed to stop device traces." ; |
932 | } else { |
933 | mergedTraceContext.merge(hostManager->getTraceContext()); |
934 | } |
935 | } |
936 | llvm::SmallString<64> path; |
937 | if (glowDumpTraceFile.empty()) { |
938 | auto tempFileRes = |
939 | llvm::sys::fs::createTemporaryFile("glow-trace" , "json" , path); |
940 | if (tempFileRes.value() != 0) { |
941 | LOG(ERROR) << "Failed to create temp file for Glow trace events: " |
942 | << tempFileRes; |
943 | } else { |
944 | LOG(INFO) << "Trace path=" << path.c_str(); |
945 | mergedTraceContext.dump(path); |
946 | } |
947 | } else { |
948 | LOG(INFO) << "Trace path=" << path.c_str(); |
949 | mergedTraceContext.dump(glowDumpTraceFile); |
950 | } |
951 | } |
952 | |
953 | if (!skipCorrectnessCheck) { |
954 | if (numFailed == 0) { |
955 | llvm::outs() << "All passed!\n" ; |
956 | } else { |
957 | llvm::outs() << numFailed << " inferences failed to match reference.\n" ; |
958 | } |
959 | } |
960 | |
961 | if (topKCompare > 0) { |
962 | llvm::outs() << "Num top1 exact matches: " << numTop1Matches << "/" |
963 | << numTotalTopKCompares << "\n" ; |
964 | llvm::outs() << "Num topK exact matches (k=" << topKCompare |
965 | << "): " << numTopKMatches << "/" << numTotalTopKCompares |
966 | << "\n" ; |
967 | |
968 | if (top1Threshold > 0.0) { |
969 | float top1MatchRate = float(numTop1Matches) / numTotalTopKCompares; |
970 | if (top1MatchRate < top1Threshold) { |
971 | llvm::outs() << "Expected top1 match rate of at least " |
972 | << top1Threshold << " but only achieved " |
973 | << top1MatchRate << "\n" ; |
974 | return numTotalTopKCompares - numTop1Matches; |
975 | } |
976 | } |
977 | } |
978 | |
979 | if (cosineSimilarityStats) { |
980 | float p50Similarity = cosineHist.getPercentileEstimate(0.5); |
981 | llvm::outs() << "cosine similarity stats:\n" |
982 | << "p01: " << cosineHist.getPercentileEstimate(0.01) << "\n" |
983 | << "p02: " << cosineHist.getPercentileEstimate(0.02) << "\n" |
984 | << "p05: " << cosineHist.getPercentileEstimate(0.05) << "\n" |
985 | << "p10: " << cosineHist.getPercentileEstimate(0.1) << "\n" |
986 | << "p25: " << cosineHist.getPercentileEstimate(0.25) << "\n" |
987 | << "p50: " << p50Similarity << "\n" |
988 | << "p75: " << cosineHist.getPercentileEstimate(0.75) << "\n" |
989 | << "p90: " << cosineHist.getPercentileEstimate(0.90) << "\n" |
990 | << "p95: " << cosineHist.getPercentileEstimate(0.95) << "\n" |
991 | << "p98: " << cosineHist.getPercentileEstimate(0.98) << "\n" |
992 | << "p99: " << cosineHist.getPercentileEstimate(0.99) << "\n" ; |
993 | if (cosineSimilarityThreshold > 0.0) { |
994 | if (p50Similarity < cosineSimilarityThreshold) { |
995 | llvm::outs() << "Expected p50 cosine similarity of at least " |
996 | << cosineSimilarityThreshold << " but only achieved " |
997 | << p50Similarity << "\n" ; |
998 | return 1; |
999 | } |
1000 | } |
1001 | } |
1002 | |
1003 | llvm::outs().flush(); |
1004 | |
1005 | std::chrono::duration<double, std::milli> duration = endTime - startTime; |
1006 | std::cout << "Total inference duration (ms): " << duration.count() << "\n" ; |
1007 | std::cout << "Avg inference duration (ms): " |
1008 | << duration.count() / numTotalInferences << "\n" ; |
1009 | std::cout << "Avg inference per second: " |
1010 | << numTotalInferences * 1000 / duration.count() |
1011 | << std::endl; // Use endl to flush the buffer |
1012 | nowTime = std::chrono::steady_clock::now(); |
1013 | } while (std::chrono::duration_cast<std::chrono::seconds>(nowTime - |
1014 | endTimeDuration) |
1015 | .count() < 0); |
1016 | |
1017 | return numFailed; |
1018 | } |
1019 | |
1020 | } // namespace |
1021 | |
1022 | int main(int argc, char **argv) { |
1023 | google::InitGoogleLogging(argv[0]); |
1024 | google::InstallFailureSignalHandler(); |
1025 | parseCommandLine(argc, argv); |
1026 | return run(); |
1027 | } |
1028 | |