1 | /* |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "HostManagerOnnxifi.h" |
18 | |
19 | #include "glow/Flags/Flags.h" |
20 | #include "glow/Runtime/DeferredWeightLoader.h" |
21 | #include "glow/Runtime/ErrorReporter.h" |
22 | #include "glow/Runtime/RequestData.h" |
23 | |
24 | #include "llvm/Support/CommandLine.h" |
25 | #include "llvm/Support/FileSystem.h" |
26 | |
27 | namespace glow { |
28 | namespace onnxifi { |
29 | |
30 | static llvm::cl::opt<int32_t, true> |
31 | GlowNumDevicesOpt("glow-num-devices" , |
32 | llvm::cl::desc("Number of devices for Glow backend" ), |
33 | llvm::cl::location(::glow::flags::NumDevices)); |
34 | |
35 | static llvm::cl::opt<bool, true> |
36 | GlowDumpDebugTracesOpt("glow-dump-debug-traces" , |
37 | llvm::cl::desc("Dump a trace of each run to /tmp" ), |
38 | llvm::cl::location(glow::flags::DumpDebugTraces)); |
39 | |
40 | static llvm::cl::opt<bool, true> GlowSaturateHostOpt( |
41 | "glow-saturate-host" , |
42 | llvm::cl::desc("Try to use all available devices on the host" ), |
43 | llvm::cl::location(glow::flags::SaturateHost)); |
44 | |
45 | static llvm::cl::opt<int32_t, true> ( |
46 | "glow_snn_partitioning_num_cards" , |
47 | llvm::cl::desc("Number of cards for SparseNNPartitioningScheme" ), |
48 | llvm::cl::location(glow::flags::SparseNNPartitioningSchemeNumCards)); |
49 | |
50 | static llvm::cl::opt<int64_t, true> |
51 | GlowSparseNNPartitioningSchemeSLSTableKBytesPerCardOpt( |
52 | "glow_snn_partitioning_kbytes_per_card" , |
53 | llvm::cl::desc("SLS KBytes per card for SparseNNPartitioningScheme" ), |
54 | llvm::cl::location( |
55 | glow::flags::SparseNNPartitioningSchemeSLSTableKBytesPerCard)); |
56 | |
57 | static llvm::cl::opt<int32_t, true> |
58 | ( |
59 | "glow_snn_partitioning_num_cores_sls" , |
60 | llvm::cl::desc( |
61 | "Number of cores for SLS for SparseNNPartitioningScheme" ), |
62 | llvm::cl::location(glow::flags::SparseNNPartitioningSchemeNumCoresSLS)); |
63 | |
64 | static llvm::cl::opt<int32_t, true> |
65 | ( |
66 | "glow_snn_partitioning_num_cores_other" , |
67 | llvm::cl::desc( |
68 | "Number of cores for other for SparseNNPartitioningScheme" ), |
69 | llvm::cl::location( |
70 | glow::flags::SparseNNPartitioningSchemeNumCoresOther)); |
71 | |
72 | static llvm::cl::opt<bool, true> GlowUseSparseNNPartitioningSchemeOpt( |
73 | "glow_use_sparsenn_partitioning_scheme" , |
74 | llvm::cl::desc("Whether to use SparseNNPartitioningScheme" ), |
75 | llvm::cl::location(glow::flags::UseSparseNNPartitioningScheme)); |
76 | |
77 | static llvm::cl::opt<bool, true> GlowSparseNNPartitioningAddSLSConcatsOpt( |
78 | "glow_sparsenn_partitioning_add_sls_concats" , |
79 | llvm::cl::desc("Add extra concats inside of SLS partitions for more " |
80 | "efficient inter-partitition transfers" ), |
81 | llvm::cl::location(glow::flags::SparseNNPartitioningAddSLSConcats)); |
82 | |
83 | static llvm::cl::opt<bool, true> GlowSparseNNPartitioningBalancePerfModelOpt( |
84 | "glow_sparsenn_partitioning_balance_perf_model" , |
85 | llvm::cl::desc("Balance SLS tables across cards using a perf model" ), |
86 | llvm::cl::location(glow::flags::SparseNNPartitioningBalancePerfModel)); |
87 | |
88 | static llvm::cl::opt<bool, true> GlowSparseNNPartitioningPairLNWithSLSOpt( |
89 | "glow_sparsenn_partitioning_pair_ln_with_sls" , |
90 | llvm::cl::desc("Place layer normalization nodes immediately following SLS " |
91 | "into SLS partition" ), |
92 | llvm::cl::location(glow::flags::SparseNNPartitioningPairLNWithSLS)); |
93 | static llvm::cl::opt<bool, true> GlowSparseNNPartitioningPairTileWithSLSOpt( |
94 | "glow_sparsenn_partitioning_pair_tile_with_sls" , |
95 | llvm::cl::desc("Place Tile nodes immediately following SLS " |
96 | "for user embeddings into SLS partition" ), |
97 | llvm::cl::location(glow::flags::SparseNNPartitioningPairTileWithSLS)); |
98 | |
99 | static llvm::cl::opt<std::string, true> GlowSparseNNPartitioningPairSLSWithOpt( |
100 | "glow_sparsenn_partitioning_pair_sls_with" , |
101 | llvm::cl::desc("Place specified nodes immediately following SLS " |
102 | "into SLS partition" ), |
103 | llvm::cl::location(glow::flags::SparseNNPartitioningPairSLSWith)); |
104 | |
105 | static llvm::cl::opt<int32_t, true> GlowSparseNNPartitioningConcatSplitSizeOpt( |
106 | "glow_sparsenn_partitioning_concat_split_size" , |
107 | llvm::cl::desc("Split concat going into tanh sink into smaller concats of " |
108 | "specified size to move into SLS partition" ), |
109 | llvm::cl::location(glow::flags::SparseNNPartitioningConcatSplitSize)); |
110 | |
111 | std::unique_ptr<runtime::HostManager> |
112 | HostManagerBackend::createHostManager(llvm::StringRef backendName) { |
113 | std::vector<std::unique_ptr<runtime::DeviceConfig>> configs; |
114 | // If GlowNumDevices is set specify that many devices, otherwise use all |
115 | // discovered devices. |
116 | if (glow::flags::NumDevices) { |
117 | for (int i = 0; i < glow::flags::NumDevices; i++) { |
118 | auto config = glow::make_unique<runtime::DeviceConfig>(backendName); |
119 | config->deviceID = i; |
120 | configs.push_back(std::move(config)); |
121 | } |
122 | } else { |
123 | configs = runtime::DeviceManager::generateDeviceConfigs( |
124 | backendName, glow::flags::ScanDevices); |
125 | } |
126 | |
127 | runtime::HostConfig hostConfig; |
128 | hostConfig.maxActiveRequests = glow::flags::MaxActiveRequests; |
129 | hostConfig.maxQueueSize = glow::flags::MaxQueueSize; |
130 | hostConfig.executorThreads = glow::flags::ExecutorThreads; |
131 | |
132 | return glow::make_unique<runtime::HostManager>(std::move(configs), |
133 | hostConfig); |
134 | } |
135 | |
136 | void HostManagerBackend::runNetwork(const Graph *graph, |
137 | std::unique_ptr<ExecutionContext> context, |
138 | runtime::ResultCBTy callback, |
139 | uint64_t priority) { |
140 | DCHECK(callback != nullptr); |
141 | |
142 | auto hostManagerGraph = static_cast<const HostManagerGraph *>(graph); |
143 | hostManager_->runNetwork(hostManagerGraph->getName(), std::move(context), |
144 | std::move(callback), priority); |
145 | } |
146 | |
/// Compiles \p module under \p cctx and registers the resulting network with
/// the HostManager.
/// \param module the loaded module to compile; ownership is transferred.
/// \param deferredBlobReader when non-null, a source of deferred (lazily
///        loaded) static weights; wired into cctx via the registered
///        DeferredWeightLoader.
/// \param cctx compilation context; many glow::flags globals are mirrored
///        into it below.
/// \param staticPlaceholderTypes name->Type map for static placeholders; if
///        empty it is populated here from the module. Consumed (moved) when a
///        deferred loader is used.
/// \returns ONNXIFI_STATUS_SUCCESS on success; ONNXIFI_STATUS_INTERNAL_ERROR
///          on recoverable failures. Non-recoverable addNetwork errors either
///          throw (AOT/skip-provisioning path) or LOG(FATAL).
onnxStatus HostManagerBackend::addNetwork(
    std::unique_ptr<Module> module, void *deferredBlobReader,
    CompilationContext &cctx,
    std::map<std::string, Type> &&staticPlaceholderTypes) {
  PrecisionConfiguration &precConfig = cctx.precisionConfig;
  cctx.maxActiveRequestsPerInstance = glow::flags::MaxActiveRequestsPerInstance;

  if (glow::flags::SkipProvisioning || deferredBlobReader) {
    // Generate a map of type date for all static placeholders. Do this
    // regardless of whether we have deferredBlobReader because we don't have
    // one for AOT but we still want to use this info for serialization.
    if (staticPlaceholderTypes.size() == 0) {
      for (auto *PH : module->getPlaceholders()) {
        if (PH->isStatic()) {
          staticPlaceholderTypes[std::string(PH->getName())] = *PH->getType();
        }
      }
    }

    // Signal that we want to fold convertTo and Quantize into static
    // Placeholders. Also want to do this for AOT optimization even if we don't
    // have a deferred blob reader present.
    cctx.optimizationOpts.foldStaticPlaceholderConversions = true;
  }

  // Copy the types into the cctx so that we have access to them regardless of
  // whether there is a deferredBlobReader.
  cctx.staticPlaceholderTypesForAOT = staticPlaceholderTypes;

  if (deferredBlobReader) {
    // Initialize loader and set field in cctx.
    auto loader = runtime::DeferredLoader()->getLoader();
    if (!loader) {
      // NOTE(review): this is an error path but logs at INFO level — consider
      // LOG(ERROR) for consistency with the other failure paths here.
      LOG(INFO) << "Blob reader provided but no loader registered!";
      return ONNXIFI_STATUS_INTERNAL_ERROR;
    }

    // The loader consumes the type map; cctx already holds its own copy
    // (staticPlaceholderTypesForAOT, set above).
    loader->setTypeInfo(std::move(staticPlaceholderTypes));
    auto err = loader->setSrc(deferredBlobReader);
    if (ERR_TO_BOOL(std::move(err))) {
      return ONNXIFI_STATUS_INTERNAL_ERROR;
    }

    cctx.deferredWeightLoader = loader;
  }

  // Mirror the global precision/conversion flags into the compilation
  // context, logging each option that is enabled.
  if (glow::flags::ConvertToFP16) {
    precConfig.convertToFP16 = glow::flags::ConvertToFP16;
    LOG(INFO) << "Conversion to fp16 enabled";
  }
  if (glow::flags::SkipBiasFp32tofp16Convert) {
    precConfig.skipBiasFp32tofp16Convert =
        glow::flags::SkipBiasFp32tofp16Convert;
    LOG(INFO) << "Skip fp16 convert for bias";
  }
  if (glow::flags::ConvertPlaceholdersToFP16) {
    precConfig.convertPlaceholdersToFP16 =
        glow::flags::ConvertPlaceholdersToFP16;
    LOG(INFO) << "Conversion of Placeholders to fp16 enabled";
  }
  if (glow::flags::ConvertConstantsToFP16) {
    precConfig.convertConstantsToFP16 = glow::flags::ConvertConstantsToFP16;
    LOG(INFO) << "Conversion of Constants to fp16 enabled";
  }
  if (glow::flags::ConvertFusedScaleOffsetToFP16) {
    precConfig.convertFusedToFP16 = glow::flags::ConvertFusedScaleOffsetToFP16;
    LOG(INFO) << "Conversion of fused scales/offsets to fp16 enabled";
  }
  if (glow::flags::ConvertFusedScaleOffsetToFP32) {
    // One flag drives both the 4-bit and 8-bit fused conversions.
    precConfig.convert4BitFusedToFP32 =
        glow::flags::ConvertFusedScaleOffsetToFP32;
    precConfig.convert8BitFusedToFP32 =
        glow::flags::ConvertFusedScaleOffsetToFP32;
    LOG(INFO) << "Conversion of fused scales/offsets to fp32 enabled";
  }
  if (glow::flags::ClipToFP16) {
    precConfig.clipFP16 = glow::flags::ClipToFP16;
    LOG(INFO) << "Clipping to fp16 enabled";
  }
  if (glow::flags::SkipInputsOnClipToFP16) {
    precConfig.clipFP16SkipInputs = glow::flags::SkipInputsOnClipToFP16;
    LOG(INFO) << "Skipping clipping for fp16 Node inputs fp16";
  }
  if (glow::flags::ForceSLSToFP16Accum) {
    precConfig.forceFP16AccumSLS = glow::flags::ForceSLSToFP16Accum;
    LOG(INFO) << "Forcing all SLS/SLWS ops to use FP16 accumulation enabled";
  }
  if (!glow::flags::EnableQuantParamChanges) {
    cctx.optimizationOpts.enableQuantParamChanges = false;
    LOG(INFO) << "Disabling quantization param changes during optimizations";
  }
  if (glow::flags::DumpCompilationLog) {
    cctx.compilationLogPrefix = "glow-onnxifi";
  }
  if (glow::flags::SinkTanhBelowConcat) {
    cctx.optimizationOpts.sinkTanhBelowConcat =
        glow::flags::SinkTanhBelowConcat;
    LOG(INFO) << "Sinking tanh below concat";
  }
  // SparseNN partitioning: copy the whole group of tuning knobs only when the
  // scheme itself is enabled.
  if (glow::flags::UseSparseNNPartitioningScheme) {
    cctx.optimizationOpts.useSparseNNPartitioningScheme = true;
    cctx.optimizationOpts.sparseNNPartitioningAddSLSConcats =
        glow::flags::SparseNNPartitioningAddSLSConcats;
    cctx.optimizationOpts.sparseNNPartitioningBalancePerfModel =
        glow::flags::SparseNNPartitioningBalancePerfModel;
    cctx.optimizationOpts.sparseNNPartitioningPairLNWithSLS =
        glow::flags::SparseNNPartitioningPairLNWithSLS;
    cctx.optimizationOpts.sparseNNPartitioningPairTileWithSLS =
        glow::flags::SparseNNPartitioningPairTileWithSLS;
    cctx.optimizationOpts.sparseNNPartitioningPairSLSWith =
        glow::flags::SparseNNPartitioningPairSLSWith;
    cctx.optimizationOpts.sparseNNPartitioningConcatSplitSize =
        glow::flags::SparseNNPartitioningConcatSplitSize;
    cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards =
        glow::flags::SparseNNPartitioningSchemeNumCards;
    cctx.optimizationOpts.sparseNNPartitioningSchemeSLSTableKBytesPerCard =
        glow::flags::SparseNNPartitioningSchemeSLSTableKBytesPerCard;
    cctx.optimizationOpts.sparseNNPartitioningSchemeNumCoresSLS =
        glow::flags::SparseNNPartitioningSchemeNumCoresSLS;
    cctx.optimizationOpts.sparseNNPartitioningSchemeNumCoresOther =
        glow::flags::SparseNNPartitioningSchemeNumCoresOther;
  }
  if (glow::flags::DumpGraph) {
    cctx.dumpFinalGraph = true;
    cctx.dumpGraphPath = glow::flags::DumpGraphPath;
  }
  if (glow::flags::UseDAGOptimizer) {
    LOG(INFO) << "Will call the DAG optimizer.";
    cctx.callDAGOptimizer = true;
    cctx.optimizationOpts.DAGOptimizerPlacementTaggingAlgorithm =
        glow::flags::DAGOptimizerPlacementTaggingAlgorithm;
    cctx.optimizationOpts.DAGOptimizerParallelizationTaggingAlgorithm =
        glow::flags::DAGOptimizerParallelizationTaggingAlgorithm;
    cctx.optimizationOpts.DAGOptimizerNumParallelChunks =
        glow::flags::DAGOptimizerNumParallelChunks;
  }
  if (glow::flags::SkipProvisioning) {
    LOG(INFO) << "Will skip provisioning (likely due to AOT opt).";
    cctx.skipProvisioning = true;
  }
  if (glow::onnxifi::flags::SaveDAG) {
    LOG(INFO) << "Serializing DAG after optimization and partitioning.";
    cctx.serializeCompiledDAG = true;
  }
  if (glow::flags::DelayAndRecordConstantModification) {
    LOG(INFO) << "Delaying constant modification until after optimizations, "
                 "including recording constant folding for DAG serialization.";
    cctx.optimizationOpts.delayAndRecordConstantModification = true;
  }
  cctx.saturateHost = glow::flags::SaturateHost;

  // Translate backend-specific option strings; failure here is a user/config
  // error surfaced as an internal error status.
  if (!glow::flags::processBackendSpecificOpts(
          cctx.backendOpts.backendSpecificOpts,
          glow::flags::BackendSpecificOpts)) {
    return ONNXIFI_STATUS_INTERNAL_ERROR;
  }
  if (glow::runtime::flags::EnableP2P) {
    LOG(INFO) << "Glow P2P Enabled";
    cctx.enableP2P = true;
  }
  if (glow::runtime::flags::EnableDRT) {
    LOG(INFO) << "Glow DRT Enabled";
    cctx.enableDRT = true;
  }

  auto err = hostManager_->addNetwork(std::move(module), cctx);

  if (err) {
    // Report the failure to any registered error reporters before deciding
    // how to surface it.
    std::string msg = err.peekErrorValue()->logToString();
    auto reporters = ErrorReporterRegistry::ErrorReporters();
    if (reporters) {
      reporters->report(msg);
    }
    const std::string errMsg =
        "Non-recoverable device error when adding network: " + msg;
    if (cctx.skipProvisioning) {
      // AOT/skip-provisioning path: throw so the caller of the AOT flow sees
      // the failure instead of aborting the process.
      LOG(ERROR) << errMsg;
      throw std::invalid_argument(strFormat(
          "Error during AOT optimization (non-provisioned addNetwork):\n%s\n",
          errMsg.c_str()));
    } else if (err.peekErrorValue()->getErrorCode() ==
               ErrorValue::ErrorCode::RUNTIME_DEFERRED_WEIGHT_ERROR) {
      // If a deferred weight error occurs, log the error but do not fatal so we
      // can try again.
      LOG(ERROR) << errMsg;
      return ONNXIFI_STATUS_INTERNAL_ERROR;
    } else {
      // Any other provisioning error is considered non-recoverable.
      LOG(FATAL) << errMsg;
    }
  }

  return ONNXIFI_STATUS_SUCCESS;
}
340 | |
341 | onnxStatus HostManagerBackend::removeNetwork(const Graph *graph) { |
342 | auto hostManagerGraph = static_cast<const HostManagerGraph *>(graph); |
343 | auto error = hostManager_->removeNetwork(hostManagerGraph->getName()); |
344 | |
345 | if (ERR_TO_BOOL(std::move(error))) { |
346 | return ONNXIFI_STATUS_INTERNAL_ERROR; |
347 | } |
348 | |
349 | return ONNXIFI_STATUS_SUCCESS; |
350 | } |
351 | |
352 | onnxStatus HostManagerGraph::initGraph( |
353 | const void *onnxModel, size_t onnxModelSize, uint32_t weightCount, |
354 | const onnxTensorDescriptorV1 *weightDescriptors, uint32_t maxSeqLength, |
355 | void *deferedBlobReader, bool loadingGlowAOT) { |
356 | |
357 | netName_ = strFormat("onnxifi_function_%lu" , makeUniqueGraphId()); |
358 | |
359 | std::unique_ptr<Module> module = glow::make_unique<Module>(); |
360 | CompilationContext cctx; |
361 | runtime::PrePartitionedConfig PPC; |
362 | cctx.prepartitionedConfig = &PPC; |
363 | OriginNameToTQPMap originNameToTQPMap; |
364 | if (glow::flags::UseTrackedDummyQuantParams) { |
365 | cctx.precisionConfig.originNameToTQPMap = &originNameToTQPMap; |
366 | cctx.precisionConfig.loadUniquedDummyQParams = true; |
367 | } |
368 | cctx.precisionConfig.clipQuantRangeToFP16 = glow::flags::ClipQuantRangeToFP16; |
369 | cctx.precisionConfig.zeroScaleFP16Clip = glow::flags::ClipZeroScaleFP16; |
370 | std::map<std::string, Type> staticPlaceholderTypes; |
371 | |
372 | std::unique_ptr<ONNXIFIModelLoader> loader; |
373 | auto loaderOrErr = ONNXIFIModelLoader::parse( |
374 | onnxModel, onnxModelSize, weightCount, weightDescriptors, *module, |
375 | netName_, cctx, &staticPlaceholderTypes, |
376 | true /*loadInputsAsPlaceholdersForOnnx*/, backendPtr_->getUseOnnx(), |
377 | /* constFoldInLoader */ false); |
378 | if (loaderOrErr) { |
379 | loader = std::move(*loaderOrErr); |
380 | } else { |
381 | LOG(ERROR) << "Error when loading model: " |
382 | << ERR_TO_STRING(loaderOrErr.takeError()); |
383 | return ONNXIFI_STATUS_INVALID_MODEL; |
384 | } |
385 | |
386 | if (!bindPlaceholders(*loader, &cctx.loadedPHNames)) { |
387 | return ONNXIFI_STATUS_INVALID_MODEL; |
388 | } |
389 | setZeroLengthSequence(maxSeqLength); |
390 | // Make sure the pool is ready to go. |
391 | for (auto &obj : onnxInputToPlaceholder_) { |
392 | tensorPool_.reserve(obj.second->getType(), 10); |
393 | } |
394 | |
395 | if (glow::onnxifi::flags::SaveModel) { |
396 | for (Function *F : module->getFunctions()) { |
397 | saveOnnxifiModel(F); |
398 | } |
399 | } |
400 | |
401 | if (glow::flags::DumpInitialLoadedGraph) { |
402 | for (Function *F : module->getFunctions()) { |
403 | auto fname = strFormat("initial_graph__%s.dot" , F->getName().data()); |
404 | LOG(INFO) << "Dumping initially loaded graph to " << fname; |
405 | F->dumpDAG(fname); |
406 | } |
407 | } |
408 | |
409 | if (loadingGlowAOT) { |
410 | LOG(INFO) << "Loading a Glow AOT optimized model." ; |
411 | cctx.loadingAOTModel = true; |
412 | } |
413 | |
414 | return static_cast<HostManagerBackend *>(backendPtr_) |
415 | ->addNetwork(std::move(module), deferedBlobReader, cctx, |
416 | std::move(staticPlaceholderTypes)); |
417 | } |
418 | |
419 | namespace { |
420 | void dumpTraces(TraceContext *traceContext) { |
421 | CHECK(traceContext); |
422 | llvm::SmallString<64> path; |
423 | auto tempFileRes = |
424 | llvm::sys::fs::createTemporaryFile("glow-trace" , "json" , path); |
425 | if (tempFileRes.value() != 0) { |
426 | LOG(ERROR) << "Failed to create temp file for Glow trace events: " |
427 | << tempFileRes; |
428 | } else { |
429 | traceContext->dump(path); |
430 | } |
431 | } |
432 | |
433 | } // namespace |
434 | |
/// Enqueue an asynchronous run of this graph.
/// \param ctx execution context holding input/output bindings (and optionally
///        a TraceContext).
/// \param outputEvent signaled with the final status once the run completes;
///        this is how completion is reported to the ONNXIFI caller.
/// \param traceEvents if tracing is active, filled with the run's trace
///        events in ONNX format.
/// \returns ONNXIFI_STATUS_SUCCESS immediately (the enqueue itself cannot
///          fail here); the run's real outcome is delivered via
///          \p outputEvent.
onnxStatus HostManagerGraph::run(std::unique_ptr<ExecutionContext> ctx,
                                 EventPtr outputEvent,
                                 onnxTraceEventList *traceEvents) {
  // Capture the caller's thread id and start time now so the end-to-end trace
  // event can be attributed to the calling thread (see below).
  auto threadId = threads::getThreadId();
  auto startTime = TraceEvent::now();

  // Attach the app-level request id (when present) as a trace attribute.
  auto *data = ::glow::runtime::RequestData::get();
  std::map<std::string, std::string> attributes;
  if (data) {
    attributes["app level request id"] =
        llvm::formatv("{0}", data->appLevelRequestId);
  }

  backendPtr_->runNetwork(
      this, std::move(ctx),
      // Completion callback: runs on a runtime thread, not the caller's.
      [outputEvent, traceEvents, threadId, startTime,
       attributes = std::move(attributes),
       this](runtime::RunIdentifierTy runId, Error err,
             std::unique_ptr<ExecutionContext> ctx) mutable {
        TRACE_EVENT_SCOPE(ctx->getTraceContext(), TraceLevel::RUNTIME,
                          "Onnxifi::callback");

        if (err) {
          // Fatal device errors are reported and then abort the process;
          // everything else is surfaced to the caller via the output event.
          if (err.peekErrorValue() && err.peekErrorValue()->isFatalError()) {
            std::string msg = err.peekErrorValue()->logToString();
            auto reporters = ErrorReporterRegistry::ErrorReporters();
            if (reporters) {
              reporters->report(msg);
            }
            LOG(FATAL) << "Non-recoverable device error when running network: "
                       << msg;
          }
          outputEvent->setMessage(ERR_TO_STRING(std::move(err)));
          outputEvent->signal(ONNXIFI_STATUS_INTERNAL_ERROR);
          return;
        }

        // End the current trace event before we convert TraceEvents to the
        // ONNX format.
        TRACE_EVENT_SCOPE_END();

        auto *traceContext = ctx->getTraceContext();
        if (traceContext) {
          // We want to log the async start event with the original caller's
          // threadId. This way, chrome UI will put the async event next to
          // the caller thread.
          traceContext->logTraceEvent("glow e2e", TraceLevel::RUNTIME,
                                      TraceEvent::BeginType, startTime,
                                      attributes, threadId, runId);
          traceContext->logTraceEvent("glow e2e", TraceLevel::RUNTIME,
                                      TraceEvent::EndType, TraceEvent::now(),
                                      attributes, threadId, runId);
          setTraceEvents(traceEvents, traceContext);
        }

        // Signal to caller that the inference is completed.
        outputEvent->signal(ONNXIFI_STATUS_SUCCESS);

        if (traceContext && glow::flags::DumpDebugTraces) {
          // Dumping traces to a file can take a while. So avoid tracesMutex_
          // while we call dumpTraces.
          std::unique_ptr<TraceContext> toDump;
          {
            std::unique_lock<std::mutex> lock(tracesMutex_);
            if (!mergedTraceContext_) {
              mergedTraceContext_ =
                  glow::make_unique<TraceContext>(TraceLevel::STANDARD);
            }
            // Accumulate this run's events; once NumDebugTracesPerDump runs
            // have been merged, take ownership of the batch for dumping
            // outside the lock.
            mergedTraceContext_->merge(traceContext);

            if (++numTracesToDump_ >= glow::flags::NumDebugTracesPerDump) {
              numTracesToDump_ = 0;
              toDump.reset(mergedTraceContext_.release());
            }
          }

          if (toDump) {
            dumpTraces(toDump.get());
          }
        }
      });

  return ONNXIFI_STATUS_SUCCESS;
}
519 | |
520 | HostManagerGraph::~HostManagerGraph() { |
521 | // Remove network from the Backend |
522 | backendPtr_->removeNetwork(this); |
523 | |
524 | if (glow::flags::DumpDebugTraces) { |
525 | std::unique_lock<std::mutex> lock(tracesMutex_); |
526 | if (mergedTraceContext_ && numTracesToDump_ > 0) { |
527 | dumpTraces(mergedTraceContext_.get()); |
528 | } |
529 | } |
530 | } |
531 | |
532 | size_t HostManagerGraph::makeUniqueGraphId() { |
533 | static std::atomic<size_t> nextId{0}; |
534 | return nextId++; |
535 | } |
536 | |
537 | } // namespace onnxifi |
538 | } // namespace glow |
539 | |