1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Runtime/HostManager/HostManager.h" |
18 | #include "glow/Backends/DeviceManager.h" |
19 | #include "glow/Exporter/ONNXModelWriter.h" |
20 | #include "glow/Flags/Flags.h" |
21 | #include "glow/Graph/PlaceholderBindings.h" |
22 | #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h" |
23 | #include "glow/Partitioner/Partitioner.h" |
24 | #include "glow/Runtime/DeferredWeightLoader.h" |
25 | #include "glow/Runtime/DeviceHealthMonitor.h" |
26 | #include "glow/Runtime/ErrorReporter.h" |
27 | #include "glow/Runtime/Executor/ThreadPoolExecutor.h" |
28 | #include "glow/Runtime/Provisioner/Provisioner.h" |
29 | #include "glow/Runtime/RequestData.h" |
30 | #include "glow/Runtime/RuntimeTypes.h" |
31 | #include "glow/Support/Support.h" |
32 | |
33 | #include "llvm/Support/CommandLine.h" |
34 | #include "llvm/Support/FileSystem.h" |
35 | #include "llvm/Support/FormatVariadic.h" |
36 | |
37 | #include <glog/logging.h> |
38 | |
39 | #include "folly/String.h" |
40 | #include "folly/executors/CPUThreadPoolExecutor.h" |
41 | |
42 | #include <algorithm> |
43 | #include <future> |
44 | #include <queue> |
45 | #include <shared_mutex> |
46 | |
/// Upper bound on a partition's input count used when peer-to-peer (P2P)
/// transfers are enabled; applied to DeviceInfo::inputCountMax in addNetwork.
constexpr uint64_t P2PInputLimit = 256;
using namespace glow;
using namespace runtime;
50 | |
namespace {
/// Command-line category grouping all HostManager-related options.
llvm::cl::OptionCategory hostManagerCat("HostManager Options" );

/// Optional YAML file of backend-specific compilation options. When set, its
/// contents replace any backend-specific options already present in the
/// CompilationContext (see addNetwork / addNetworkFX).
llvm::cl::opt<std::string> loadBackendSpecificOptionsOpt(
    "load-backend-specific-opts" ,
    llvm::cl::desc("Load backend-specific options for compilation." ),
    llvm::cl::value_desc("options.yaml" ), llvm::cl::Optional,
    llvm::cl::cat(hostManagerCat));
} // namespace
60 | |
namespace glow {

#if FACEBOOK_INTERNAL
/// Internal-only DAG-level optimizer run after partitioning; defined in an
/// internal translation unit. See the call site in HostManager::addNetwork.
Error optimizeDAG(DAGListTy &nodeList, const Provisioner &provisioner,
                  Module &mod, const std::vector<DeviceInfo> &devices,
                  CompilationContext &cctx,
                  ConstantFoldingRecordMap &constFoldRecord);
/// Source-control revision hash of the build; logged when adding networks.
extern const char *revisionHash;
#endif /* FACEBOOK_INTERNAL */
} // namespace glow
71 | |
/// The device configs file used for Runtime.
llvm::cl::opt<std::string> loadDeviceConfigsFileOpt(
    "load-device-configs" ,
    llvm::cl::desc("Load device configs used in Runtime" ),
    llvm::cl::value_desc("configs.yaml" ), llvm::cl::Optional,
    llvm::cl::cat(hostManagerCat));

/// The value that should be used for device initialization timeout, default:
/// 5000 milliseconds.
/// Writes directly into glow::runtime::flags::DeviceInitTimeoutMs (external
/// storage), which HostManager::init reads when waiting on device init.
llvm::cl::opt<unsigned, /* ExternalStorage */ true> deviceInitTimeout(
    "device_init_timeout_ms" ,
    llvm::cl::desc("Set device init timout in milliseconds" ),
    llvm::cl::Optional,
    llvm::cl::location(glow::runtime::flags::DeviceInitTimeoutMs),
    llvm::cl::cat(hostManagerCat));
87 | |
88 | HostManager::HostManager() : HostManager(HostConfig{}) {} |
89 | |
/// Constructs a HostManager with host-level configuration \p hostConfig but
/// no devices. Publishes the configured queue limit to the stats exporter.
HostManager::HostManager(const HostConfig &hostConfig)
    : config_(hostConfig),
      statsExporterRegistry_(StatsExporterRegistry::Stats()) {
  statsExporterRegistry_->setCounter(kMaxQueueSize, hostConfig.maxQueueSize);
}
95 | |
/// Constructs a HostManager that owns devices described by \p deviceConfigs,
/// using a default HostConfig; delegates to the two-argument constructor.
HostManager::HostManager(
    std::vector<std::unique_ptr<DeviceConfig>> deviceConfigs)
    : HostManager(std::move(deviceConfigs), HostConfig{}) {}
99 | |
/// Constructs a HostManager with both device configurations and host-level
/// configuration. Device creation/initialization happens in init(); a failure
/// there is reported and terminates the process (REPORT_AND_EXIT_ON_ERR).
HostManager::HostManager(
    std::vector<std::unique_ptr<DeviceConfig>> deviceConfigs,
    const HostConfig &hostConfig)
    : config_(hostConfig),
      statsExporterRegistry_(StatsExporterRegistry::Stats()) {
  // TODO: move all initialization out of constructor.

  REPORT_AND_EXIT_ON_ERR(init(std::move(deviceConfigs)));
  statsExporterRegistry_->setCounter(kMaxQueueSize, hostConfig.maxQueueSize);
}
110 | |
111 | Expected<DAG *> HostManager::getNetworkDAG(llvm::StringRef network) { |
112 | auto it = networks_.find(network.str()); |
113 | if (it == networks_.end()) { |
114 | return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_ERROR, "Network not found." ); |
115 | } |
116 | return &it->second.dag; |
117 | } |
118 | |
119 | Error HostManager::startDeviceTrace() { |
120 | LOG(INFO) << "start device tracing" << std::endl; |
121 | for (auto &dev : devices_) { |
122 | Error err = dev.second->startDeviceTrace(hostTraceContext_.get()); |
123 | RETURN_IF_ERR(err); |
124 | } |
125 | return Error::success(); |
126 | } |
127 | |
128 | Error HostManager::stopDeviceTrace() { |
129 | |
130 | auto *traceContext = hostTraceContext_.get(); |
131 | if (!traceContext) { |
132 | LOG(INFO) << "No HostManager TraceContext registered, skipping call to " |
133 | "stopDeviceTrace" ; |
134 | return Error::success(); |
135 | } else { |
136 | LOG(INFO) << "stop device tracing" ; |
137 | } |
138 | for (auto &dev : devices_) { |
139 | Error err = dev.second->stopDeviceTrace(traceContext); |
140 | RETURN_IF_ERR(err); |
141 | } |
142 | return Error::success(); |
143 | } |
144 | |
145 | Error HostManager::init(std::vector<std::unique_ptr<DeviceConfig>> configs) { |
146 | static std::once_flag monitorFlag; |
147 | std::call_once(monitorFlag, []() { |
148 | auto monitors = DeviceHealthMonitorRegistry::Monitors(); |
149 | if (monitors) { |
150 | monitors->start(); |
151 | } |
152 | }); |
153 | |
154 | DeviceIDTy deviceCount = 0; |
155 | for (auto &config : configs) { |
156 | if (!config->hasName()) { |
157 | config->name = "device_" + std::to_string(deviceCount); |
158 | } |
159 | |
160 | devices_[deviceCount] = std::unique_ptr<DeviceManager>( |
161 | DeviceManager::createDeviceManager(*config)); |
162 | |
163 | std::promise<Error> devPromise; |
164 | auto devFuture = devPromise.get_future(); |
165 | auto *dev = devices_[deviceCount].get(); |
166 | threadPool_.submit([&devPromise, dev] { |
167 | auto err = dev->init(); |
168 | devPromise.set_value(std::move(err)); |
169 | }); |
170 | if (devFuture.wait_for(std::chrono::milliseconds( |
171 | flags::DeviceInitTimeoutMs)) != std::future_status::timeout) { |
172 | RETURN_IF_ERR(devFuture.get()); |
173 | } else { |
174 | // Device initialization is taking longer than expected, return an error. |
175 | return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_ERROR, |
176 | "Timeout encountered when initializing device: " + |
177 | std::string(config->name)); |
178 | } |
179 | availableDevices_.push_back(deviceCount); |
180 | deviceCount++; |
181 | } |
182 | #ifdef FACEBOOK_INTERNAL |
183 | LOG(INFO) << "Initialized " << deviceCount << " device(s)" ; |
184 | #endif |
185 | |
186 | provisioner_.reset(new Provisioner(devices_)); |
187 | executor_.reset( |
188 | new ThreadPoolExecutor(devices_, config_.executorThreads, "HostManager" )); |
189 | exportMemoryCounters(); |
190 | if (flags::AvailableDevices.length()) { |
191 | std::vector<unsigned> devices; |
192 | folly::split<char, std::string, unsigned>(',', flags::AvailableDevices, |
193 | devices, |
194 | /* ignoreEmpty */ true); |
195 | std::vector<runtime::DeviceIDTy> convertedDevs(devices.begin(), |
196 | devices.end()); |
197 | setAvailableDevices(convertedDevs); |
198 | } |
199 | // If no HostManager is registered yet, register this one. |
200 | if (!ManagerRegistry()->getHostManager()) { |
201 | ManagerRegistry()->registerHostManager(this); |
202 | } |
203 | |
204 | return Error::success(); |
205 | } |
206 | |
207 | void HostManager::setAvailableDevices(const std::vector<DeviceIDTy> &devices) { |
208 | // Validate new device list. |
209 | availableDevices_.clear(); |
210 | std::vector<DeviceIDTy> mapping; |
211 | std::vector<DeviceManager *> availableDevices; |
212 | // Grab a lock to prevent devices_ getting changed concurrently. |
213 | std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_); |
214 | for (auto dev : devices) { |
215 | auto it = devices_.find(dev); |
216 | if (it != devices_.end()) { |
217 | availableDevices_.push_back(dev); |
218 | availableDevices.push_back(devices_[dev].get()); |
219 | mapping.push_back(it->first); |
220 | } |
221 | } |
222 | // Update the provisioner. |
223 | provisioner_->updateAvailableDevices(availableDevices, mapping); |
224 | } |
225 | |
226 | void HostManager::exportMemoryCounters() { |
227 | uint64_t maxMem = 0; |
228 | uint64_t availableMem = 0; |
229 | for (auto &dev : devices_) { |
230 | maxMem += dev.second->getMaximumMemory(); |
231 | availableMem += dev.second->getAvailableMemory(); |
232 | } |
233 | statsExporterRegistry_->setCounter(kDeviceMemoryUsed, maxMem - availableMem); |
234 | statsExporterRegistry_->setCounter(kDeviceMemoryAvailable, availableMem); |
235 | statsExporterRegistry_->setCounter(kDeviceMemoryMax, maxMem); |
236 | } |
237 | |
/// Destructor: best-effort teardown. clearHost()'s error is deliberately
/// discarded (ERR_TO_VOID) since a destructor cannot propagate it; counters
/// are re-exported so stats reflect the released device memory.
HostManager::~HostManager() {
  LOG(INFO) << "Destroying host manager..." ;
  ERR_TO_VOID(clearHost());
  exportMemoryCounters();
}
243 | |
244 | void HostManager::cleanupAddNetwork(llvm::ArrayRef<std::string> names) { |
245 | for (auto &name : names) { |
246 | processingNetworks_.erase(name); |
247 | } |
248 | exportMemoryCounters(); |
249 | } |
250 | |
251 | Error HostManager::addNetwork(std::unique_ptr<Module> module, |
252 | CompilationContext &cctx) { |
253 | #ifdef FACEBOOK_INTERNAL |
254 | LOG(INFO) << "Adding Glow network built with revision hash: " << revisionHash; |
255 | #endif /* FACEBOOK_INTERNAL */ |
256 | VLOG(1) << "addNetwork" ; |
257 | ScopeGuard debugDumpDAGGuard([&]() { |
258 | if (cctx.dumpFinalGraph) { |
259 | for (Function *F : module->getFunctions()) { |
260 | auto fname = strFormat("%sfinal_graph_dbg_err_%s.dot" , |
261 | cctx.dumpGraphPath.c_str(), F->getName().data()); |
262 | LOG(INFO) << "Dumping final graph due to error to " << fname; |
263 | F->dumpDAG(fname); |
264 | } |
265 | } |
266 | }); |
267 | |
268 | /// If specified in the cctx, this will prevent Constants from being modified |
269 | /// until the current scope ends or the preventer is dismissed. Does so by |
270 | /// swapping in temporary Placeholders instead of Constants. |
271 | ConstantModificationPreventer constModPreventer(*module, cctx); |
272 | if (cctx.optimizationOpts.delayAndRecordConstantModification) { |
273 | constModPreventer.activate(); |
274 | } |
275 | |
276 | std::vector<std::string> names; |
277 | { |
278 | std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_); |
279 | auto functions = module->getFunctions(); |
280 | for (auto &F : functions) { |
281 | std::string name = F->getName().str(); |
282 | auto it = networks_.find(name); |
283 | if (it != networks_.end() || |
284 | processingNetworks_.find(name) != processingNetworks_.end()) { |
285 | cleanupAddNetwork(names); |
286 | return MAKE_ERR( |
287 | ErrorValue::ErrorCode::RUNTIME_ERROR, |
288 | "Failed to add network: already have a function called " + name); |
289 | } |
290 | // Add the network to processingNetworks_ so we know it's being worked on. |
291 | processingNetworks_.insert(name); |
292 | names.push_back(name); |
293 | } |
294 | } |
295 | |
296 | // Issue a warning when loading backend specific options from the command line |
297 | // and the compile context also contains backend specific options. |
298 | if (!loadBackendSpecificOptionsOpt.empty()) { |
299 | if (cctx.backendOpts.backendSpecificOpts.size() != 0) { |
300 | VLOG_EVERY_N(1, 1000) << "Warning: backendSpecificOpts is set via the " |
301 | "HostManager, ignoring previously set options." ; |
302 | } |
303 | cctx.backendOpts.backendSpecificOpts = |
304 | deserializeStrStrMapFromYaml(loadBackendSpecificOptionsOpt); |
305 | } else { |
306 | auto ctxLoadBackendSpecificOpt = |
307 | cctx.backendOpts.backendSpecificOpts.find("loadBackendSpecificOptions" ); |
308 | |
309 | if (ctxLoadBackendSpecificOpt != |
310 | cctx.backendOpts.backendSpecificOpts.end()) { |
311 | cctx.backendOpts.backendSpecificOpts = |
312 | deserializeStrStrMapFromYaml(ctxLoadBackendSpecificOpt->second); |
313 | } |
314 | } |
315 | |
316 | std::vector<DeviceInfo> deviceInfo; |
317 | { |
318 | std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_); |
319 | for (auto &device : availableDevices_) { |
320 | DeviceInfo info = devices_[device]->getDeviceInfo(); |
321 | info.availableMemory = devices_[device]->getAvailableMemory(); |
322 | info.backendName = devices_[device]->getBackendName().str(); |
323 | info.nonSupportedNodes = |
324 | devices_[device]->getParamByName("nonSupportedNodes" ).str(); |
325 | info.supportedNodes = |
326 | devices_[device]->getParamByName("supportedNodes" ).str(); |
327 | // If p2p is enabled update the inputCount limit. |
328 | if (cctx.enableP2P) { |
329 | info.inputCountMax = P2PInputLimit; |
330 | } |
331 | deviceInfo.push_back(info); |
332 | } |
333 | } |
334 | |
335 | // Optimize Functions only if we don't have any backendSpecificNodeInfo, |
336 | // because if we do then the Functions were already optimized and Nodes had |
337 | // extra info mapped to them, so we don't want to mutate the Function. Also |
338 | // skip optimizations if we're loading an AOT optimized model. |
339 | const bool skipOptimizations = |
340 | cctx.loadingAOTModel || !cctx.backendOpts.backendSpecificNodeInfo.empty(); |
341 | |
342 | // Perform a round of target-independent graph optimizations. This helps the |
343 | // partitioner to do its job more efficiently. |
344 | if (!skipOptimizations) { |
345 | for (Function *F : module->getFunctions()) { |
346 | auto err = optimizeFunctionBeforeLowering(F, cctx); |
347 | if (err) { |
348 | { |
349 | std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_); |
350 | cleanupAddNetwork(names); |
351 | } |
352 | RETURN_ERR(err); |
353 | } |
354 | } |
355 | } |
356 | VLOG(1) << "Before partitioner" ; |
357 | Partitioner partitioner(module.get(), deviceInfo, skipOptimizations); |
358 | auto backendName = devices_[0]->getBackendName(); |
359 | const auto &backend = provisioner_->getBackend(backendName); |
360 | auto contextCount = backend.getContextCount(cctx); |
361 | partitioner.setContextCount(contextCount); |
362 | DAGListTy nodeList; |
363 | auto result = partitioner.partition(cctx); |
364 | VLOG(1) << "After partitioner" ; |
365 | if (result) { |
366 | nodeList = std::move(result.get()); |
367 | } else { |
368 | std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_); |
369 | cleanupAddNetwork(names); |
370 | RETURN_ERR(result.takeError()); |
371 | } |
372 | VLOG(1) << "Before quantmode" ; |
373 | if (cctx.precisionConfig.quantMode == QuantizationMode::Profile) { |
374 | // Since for profiling the provisioner will be reset, we only allow one |
375 | // network in one HM. |
376 | if (networks_.size() > 0) { |
377 | return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_ERROR, |
378 | "For quantization profiling flow, there can't be other " |
379 | "registered networks before this one" ); |
380 | } |
381 | // For profiling, we use CPU backend. Overwrite Provisioner and Executor |
382 | // to force the network is compiled and run in profilingBackend. backend. |
383 | size_t devicesNum = devices_.size(); |
384 | for (size_t i = 0; i < devicesNum; i++) { |
385 | auto name = devices_[i]->getDeviceConfig().name; |
386 | auto config = glow::make_unique<DeviceConfig>(profilingBackend, name); |
387 | devices_[i] = std::unique_ptr<DeviceManager>( |
388 | DeviceManager::createDeviceManager(*config)); |
389 | RETURN_IF_ERR(devices_[i]->init()); |
390 | } |
391 | provisioner_.reset(new Provisioner(devices_)); |
392 | executor_.reset(new ThreadPoolExecutor(devices_, config_.executorThreads)); |
393 | } |
394 | |
395 | VLOG(1) << "Before replace dummy TQPs" ; |
396 | // Now that we've partitioned and optimized, do some verification based on the |
397 | // dummy mode we're using, if any. |
398 | if (cctx.precisionConfig.replaceDummyTQPs || |
399 | cctx.precisionConfig.loadUniquedDummyQParams) { |
400 | RETURN_IF_ERR(module->verifyDummyQParams( |
401 | cctx.precisionConfig.loadUniquedDummyQParams)); |
402 | } |
403 | |
404 | // If we are loading an AOT model where we are replacing dummy TQPs, then we |
405 | // may need to update Relu output types on FCs, since they should be set to |
406 | // use zero as min but the correct qparams could not be calculated AOT. |
407 | if (cctx.loadingAOTModel && cctx.precisionConfig.replaceDummyTQPs) { |
408 | LOG(INFO) << "Updating quantized Relu types given real TQPs" ; |
409 | for (Function *F : module->getFunctions()) { |
410 | updateQuantReluTypes(F); |
411 | } |
412 | } |
413 | |
414 | VLOG(1) << "Before constant folding" ; |
415 | // If we prevented constant modification then run constant folding with |
416 | // recording now. Record so that if we are going to serialize we can embed the |
417 | // constant folding subgraphs in the Glow ONNX model. |
418 | ConstantFoldingRecordMap record; |
419 | if (cctx.optimizationOpts.delayAndRecordConstantModification) { |
420 | constModPreventer.deactivateAndCleanup(); |
421 | |
422 | RETURN_ERR_IF_NOT(nodeList.size() == 1, "Expect only one DAG." ); |
423 | const auto &dag = *nodeList.begin(); |
424 | for (auto &dagNode : dag.nodes) { |
425 | Function *F = module->getFunction(dagNode->name); |
426 | RETURN_ERR_IF_NOT( |
427 | F, strFormat("Function %s not found" , dagNode->name.data())); |
428 | |
429 | ConstantFoldingRecordMap currRecord = constantFoldAndRecord(F, cctx); |
430 | record.insert(currRecord.begin(), currRecord.end()); |
431 | runDCEPass(F, cctx); |
432 | |
433 | // Verify the Function is valid after constant folding takes place. |
434 | Backend &B = provisioner_->getBackend(dagNode->backendName); |
435 | RETURN_ERR_IF_NOT( |
436 | B.verify(*F, cctx.verboseCompile), |
437 | "Unsupported node(s) found after delayed constant folding Function " + |
438 | F->getName().str() + " for backend " + B.getBackendName()); |
439 | } |
440 | } |
441 | VLOG(1) << "Before loading AOT" ; |
442 | if (!cctx.loadingAOTModel) { |
443 | if (cctx.callDAGOptimizer) { |
444 | #if FACEBOOK_INTERNAL |
445 | auto optDagErr = optimizeDAG(nodeList, *provisioner_, *module, deviceInfo, |
446 | cctx, record); |
447 | if (optDagErr) { |
448 | std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_); |
449 | cleanupAddNetwork(names); |
450 | RETURN_ERR(optDagErr); |
451 | } |
452 | #endif /* FACEBOOK_INTERNAL */ |
453 | } else { |
454 | // If not using the DAG optimizer, iterate over the DAGs and call |
455 | // transformPostOptPipeline() on the Functions. |
456 | VLOG(1) << "No DAG optimizer" ; |
457 | for (const auto &dag : nodeList) { |
458 | for (auto &dagNode : dag.nodes) { |
459 | Function *F = module->getFunction(dagNode->name); |
460 | RETURN_ERR_IF_NOT( |
461 | F, strFormat("Function %s not found" , dagNode->name.data())); |
462 | |
463 | if (cctx.optimizationOpts.onlyLowerFuns.count(F)) { |
464 | continue; |
465 | } |
466 | |
467 | Backend &B = provisioner_->getBackend(dagNode->backendName); |
468 | RETURN_IF_EXPECTED_IS_ERR(B.transformPostOptPipeline(F, cctx)); |
469 | |
470 | RETURN_ERR_IF_NOT( |
471 | B.verify(*F, cctx.verboseCompile), |
472 | "Unsupported node(s) found after transformPostOptPipeline() " + |
473 | F->getName().str() + " for backend " + B.getBackendName()); |
474 | } |
475 | } |
476 | } |
477 | } |
478 | |
479 | VLOG(1) << "Before serialize compile DAG" ; |
480 | // If requested, serialize the resulting DAG that was just optimized and |
481 | // partitioned. |
482 | if (cctx.serializeCompiledDAG) { |
483 | std::string loc; |
484 | char *envSpecifiedSerializationPath = getenv("GLOW_DAG_SERIALIZATION_LOC" ); |
485 | if (!envSpecifiedSerializationPath) { |
486 | loc = nodeList.begin()->root->name + ".onnxtxt" ; |
487 | } else { |
488 | loc = std::string(envSpecifiedSerializationPath); |
489 | } |
490 | |
491 | LOG(INFO) << "Serializing final compiled DAG to " << loc; |
492 | { |
493 | llvm::StringMap<std::string> ; |
494 | if (cctx.precisionConfig.originNameToTQPMap) { |
495 | RETURN_IF_ERR(ONNXModelWriter::insertLoaderNameUniqueOffsetMetadata( |
496 | extraMetadataProps, *cctx.precisionConfig.originNameToTQPMap)); |
497 | } |
498 | if (cctx.precisionConfig.clipQuantRangeToFP16) { |
499 | extraMetadataProps[clipQuantRangeToFP16Key] = "1" ; |
500 | } |
501 | Error writeErr = Error::empty(); |
502 | // Note: If cctx.skipProvisioning then we want to serialize all meta info |
503 | // as we are likely doing AOT optimization. Otherwise do not provide the |
504 | // meta info as the model does not need to be reloaded. |
505 | ONNXModelWriter onnxWR( |
506 | loc, nodeList, 7, 9, &writeErr, |
507 | /* textMode */ true, |
508 | /* zipMode */ cctx.useZipModeForSerializeCompiledDAG, |
509 | /* includeConstantData */ cctx.saveConstantInSerializeCompiledDAG, |
510 | extraMetadataProps, record, cctx.backendOpts.backendSpecificNodeInfo, |
511 | cctx.skipProvisioning ? &cctx.loadedPHNames : nullptr, |
512 | cctx.skipProvisioning ? &cctx.staticPlaceholderTypesForAOT : nullptr, |
513 | cctx.returnGlowSerializedModelStr |
514 | ? cctx.glowAOTSerializationModelStrPtr.get() |
515 | : nullptr); |
516 | RETURN_IF_ERR(writeErr); |
517 | } |
518 | |
519 | // If we're using AOT DAG optimizer then skip provisioning. |
520 | if (cctx.skipProvisioning || |
521 | (cctx.callDAGOptimizer && cctx.useDAGOptimizerAOTMode)) { |
522 | LOG(INFO) << "Host manager skipping provisioning" ; |
523 | { |
524 | std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_); |
525 | cleanupAddNetwork(names); |
526 | } |
527 | debugDumpDAGGuard.dismiss(); |
528 | cleanupConstantFolding(*module, record); |
529 | if (cctx.dumpFinalGraph) { |
530 | for (Function *F : module->getFunctions()) { |
531 | auto fname = |
532 | strFormat("%sfinal_graph_aot_%s.dot" , cctx.dumpGraphPath.c_str(), |
533 | F->getName().data()); |
534 | LOG(INFO) << "Dumping final graph to " << fname; |
535 | F->dumpDAG(fname); |
536 | } |
537 | } |
538 | return Error::success(); |
539 | } |
540 | } |
541 | |
542 | // Now that we've serialized the model if requested, cleanup the temporary |
543 | // Functions and PHs used for constant folding. |
544 | cleanupConstantFolding(*module, record); |
545 | VLOG(1) << "Before provisioning" ; |
546 | auto err = provisioner_->provision(nodeList, *module, cctx); |
547 | if (err) { |
548 | if (err.peekErrorValue()->isFatalError()) { |
549 | statsExporterRegistry_->setCounter(kDeviceFatalError, 1); |
550 | } |
551 | { |
552 | std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_); |
553 | cleanupAddNetwork(names); |
554 | } |
555 | RETURN_ERR(err); |
556 | } |
557 | debugDumpDAGGuard.dismiss(); |
558 | VLOG(1) << "Calculation of maxActiveRequests" ; |
559 | { |
560 | std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_); |
561 | /// Calculate networkMaxActive requests. Then update |
562 | /// config_.maxActiveRequests This will be maxActiveRequestsPerInstance * |
563 | /// instanceCount * minReplications or config_.maxActiveRequests whichever |
564 | /// is smaller. |
565 | |
566 | // Find the minimum on device replication. |
567 | unsigned minReplications{1}; |
568 | for (auto &node : nodeList) { |
569 | for (auto &dag : node.nodes) { |
570 | minReplications = std::min(dag->replicationCount, minReplications); |
571 | } |
572 | } |
573 | unsigned product{0}; |
574 | if (nodeList.size() && nodeList[0].nodes.size()) { |
575 | product = nodeList[0].nodes[0]->instanceCount * |
576 | cctx.maxActiveRequestsPerInstance * minReplications; |
577 | } else { |
578 | return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_ERROR, |
579 | "NodeList is empty." ); |
580 | } |
581 | unsigned maxActiveRequests = config_.maxActiveRequests; |
582 | config_.maxActiveRequests = std::min(product, maxActiveRequests); |
583 | |
584 | // Create pool of cachedExecutionStates. |
585 | for (auto &node : nodeList) { |
586 | // Note: currently getNextNetworkExecutionState assumes that pool size is |
587 | // >= currentInFlight requests, so we set pool size to maxActiveRequests. |
588 | executor_->createPool(node.root.get(), config_.maxActiveRequests, |
589 | cctx.enableP2P, cctx.enableDRT); |
590 | } |
591 | } |
592 | // Clear constants contents from the module then put it in a |
593 | // shared_ptr to be shared between all of the networks created from each |
594 | // function in the module. |
595 | auto targetBackendName = std::string(devices_[0]->getBackendName()); |
596 | const auto &targetBackend = provisioner_->getBackend(targetBackendName); |
597 | if (targetBackend.shouldStripModule() && !cctx.skipModuleStrip) { |
598 | module->strip(); |
599 | } |
600 | VLOG(1) << "Cleanup" ; |
601 | auto sharedModule = std::shared_ptr<Module>(std::move(module)); |
602 | { |
603 | std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_); |
604 | for (auto &node : nodeList) { |
605 | #if FACEBOOK_INTERNAL |
606 | LOG(INFO) << "Successfully compiled and provisioned " << node.root->name; |
607 | #endif |
608 | auto &networkData = networks_[(node.root)->name]; |
609 | networkData.dag = std::move(node); |
610 | networkData.module = sharedModule; |
611 | } |
612 | cleanupAddNetwork(names); |
613 | } |
614 | VLOG(1) << "After cleanup" ; |
615 | return Error::success(); |
616 | } |
617 | |
#if FACEBOOK_INTERNAL
/// Adds a network whose partitioning (\p networks) and FX IR (\p FXIR, with
/// weights supplied via \p constants) were produced ahead of time: skips
/// Glow-side optimization/partitioning and provisions directly. Mirrors
/// addNetwork for name reservation, option loading, and pool creation.
Error HostManager::addNetworkFX(
    std::unique_ptr<Module> module, CompilationContext &cctx,
    DAGListTy &networks, const folly::dynamic &FXIR,
    const llvm::StringMap<const void *> &constants) {

  LOG(INFO) << "Adding Glow network built with revision hash: " << revisionHash;
  VLOG(1) << "addNetwork";

  // Reserve every Function name in processingNetworks_ so a concurrent add of
  // the same name fails fast.
  std::vector<std::string> names;
  {
    std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
    auto functions = module->getFunctions();
    for (auto &F : functions) {
      const auto name = F->getName().str();
      auto it = networks_.find(name);
      if (it != networks_.end() ||
          processingNetworks_.find(name) != processingNetworks_.end()) {
        cleanupAddNetwork(names);
        return MAKE_ERR(
            ErrorValue::ErrorCode::RUNTIME_ERROR,
            "Failed to add network: already have a function called " + name);
      }
      // Add the network to processingNetworks_ so we know it's being worked on.
      processingNetworks_.insert(name);
      names.push_back(name);
    }
  }

  // Issue a warning when loading backend specific options from the command line
  // and the compile context also contains backend specific options.
  if (!loadBackendSpecificOptionsOpt.empty()) {
    if (cctx.backendOpts.backendSpecificOpts.size() != 0) {
      VLOG_EVERY_N(1, 1000) << "Warning: backendSpecificOpts is set via the "
                               "HostManager, ignoring previously set options.";
    }
    cctx.backendOpts.backendSpecificOpts =
        deserializeStrStrMapFromYaml(loadBackendSpecificOptionsOpt);
  } else {
    auto ctxLoadBackendSpecificOpt =
        cctx.backendOpts.backendSpecificOpts.find("loadBackendSpecificOptions");

    if (ctxLoadBackendSpecificOpt !=
        cctx.backendOpts.backendSpecificOpts.end()) {
      cctx.backendOpts.backendSpecificOpts =
          deserializeStrStrMapFromYaml(ctxLoadBackendSpecificOpt->second);
    }
  }

  // Snapshot per-device capability info (kept for parity with addNetwork).
  std::vector<DeviceInfo> deviceInfo;
  {
    std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
    for (auto &device : availableDevices_) {
      DeviceInfo info = devices_[device]->getDeviceInfo();
      info.availableMemory = devices_[device]->getAvailableMemory();
      info.backendName = devices_[device]->getBackendName();
      info.nonSupportedNodes =
          devices_[device]->getParamByName("nonSupportedNodes");
      info.supportedNodes = devices_[device]->getParamByName("supportedNodes");
      // If p2p is enabled update the inputCount limit.
      if (cctx.enableP2P) {
        info.inputCountMax = P2PInputLimit;
      }
      deviceInfo.push_back(info);
    }
  }

  VLOG(1) << "Before provisioning";
  auto err =
      provisioner_->provisionFX(networks, *module, FXIR, constants, cctx);
  if (err) {
    if (err.peekErrorValue()->isFatalError()) {
      statsExporterRegistry_->setCounter(kDeviceFatalError, 1);
    }
    {
      std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
      cleanupAddNetwork(names);
    }
    RETURN_ERR(err);
  }

  VLOG(1) << "Calculation of maxActiveRequests";
  {
    std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
    /// Calculate networkMaxActive requests. Then update
    /// config_.maxActiveRequests This will be maxActiveRequestsPerInstance *
    /// instanceCount * minReplications or config_.maxActiveRequests whichever
    /// is smaller.

    // Find the minimum on device replication.
    unsigned minReplications{1};
    for (auto &node : networks) {
      for (auto &dag : node.nodes) {
        minReplications = std::min(dag->replicationCount, minReplications);
      }
    }
    unsigned product{0};
    if (networks.size() && networks[0].nodes.size()) {
      product = networks[0].nodes[0]->instanceCount *
                cctx.maxActiveRequestsPerInstance * minReplications;
    } else {
      // Fix: release the reserved names before returning (the lock is already
      // held here); previously this path leaked them in processingNetworks_,
      // permanently blocking a retry with the same function names.
      cleanupAddNetwork(names);
      return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_ERROR,
                      "NodeList is empty.");
    }
    unsigned maxActiveRequests = config_.maxActiveRequests;
    config_.maxActiveRequests = std::min(product, maxActiveRequests);

    // Create pool of cachedExecutionStates.
    for (auto &node : networks) {
      // Note: currently getNextNetworkExecutionState assumes that pool size is
      // >= currentInFlight requests, so we set pool size to maxActiveRequests.
      executor_->createPool(node.root.get(), config_.maxActiveRequests,
                            cctx.enableP2P, cctx.enableDRT);
    }
  }
  // Clear constants contents from the module then put it in a
  // shared_ptr to be shared between all of the networks created from each
  // function in the module.
  auto targetBackendName = std::string(devices_[0]->getBackendName());
  const auto &targetBackend = provisioner_->getBackend(targetBackendName);
  if (targetBackend.shouldStripModule() && !cctx.skipModuleStrip) {
    module->strip();
  }
  VLOG(1) << "Cleanup";
  auto sharedModule = std::shared_ptr<Module>(std::move(module));
  {
    std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
    for (auto &node : networks) {
      LOG(INFO) << "Successfully compiled and provisioned " << node.root->name;
      auto &networkData = networks_[(node.root)->name];
      networkData.dag = std::move(node);
      networkData.module = sharedModule;
    }
    cleanupAddNetwork(names);
  }
  VLOG(1) << "After cleanup";
  return Error::success();
}
#endif
757 | |
758 | std::unordered_map<std::string, std::vector<DeviceIDTy>> |
759 | HostManager::getDevicePartitionMapping(llvm::StringRef network) { |
760 | std::unordered_map<std::string, std::vector<DeviceIDTy>> mapping; |
761 | auto it = networks_.find(network.str()); |
762 | if (it != networks_.end()) { |
763 | auto &nodeList = it->second.dag.nodes; |
764 | for (auto &node : nodeList) { |
765 | std::vector<DeviceIDTy> devices; |
766 | for (auto &dev : node->deviceRuntimeInfos) { |
767 | devices.push_back(dev.first); |
768 | } |
769 | mapping[node->name] = devices; |
770 | } |
771 | } |
772 | return mapping; |
773 | } |
774 | |
Error HostManager::removeNetwork(llvm::StringRef networkName) {
  // Exclusive lock: we mutate networks_ and tear down device-side state
  // below, so no concurrent reader or writer may observe the network
  // half-removed.
  std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
  auto networkIterator = networks_.find(networkName.str());
  if (networkIterator == networks_.end()) {
    // Removing a network that was never added (or was already removed) is
    // treated as a no-op, not an error.
    return Error::success();
  }

  if (processingNetworks_.find(networkName.str()) !=
      processingNetworks_.end()) {
    // Return an error, the network is in an incomplete state likely because
    // it is still being added by a different call.
    return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_NET_BUSY,
                    llvm::formatv("Cannot remove the network {0}, as it is "
                                  "currently being modified.",
                                  networkName)
                        .str());
  }

  // Issue an error as there are outstanding runs for the network
  if (networkIterator->second.refcount != 0) {
    return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_NET_BUSY,
                    llvm::formatv("Cannot remove the network {0}, as there are "
                                  "still outstanding runs",
                                  networkName)
                        .str());
  }

  // Collect (rather than early-return on) eviction/removal errors so that
  // cleanup proceeds for every node; OneErrOnly keeps the first error seen.
  OneErrOnly err;
  auto &nodes = networkIterator->second.dag.nodes;
  // Free the pool of executionStates.
  executor_->freePool(networkIterator->second.dag.root.get());
  for (auto &node : nodes) {
    // Evict the compiled function from every device it was provisioned on.
    for (auto device : node->deviceRuntimeInfos) {
      Error evictErr = provisioner_->evictFunction(
          node->name, devices_[device.first].get(), node->replicationCount);
      err.set(std::move(evictErr));
    }
    // Also remove compiledFunction from Provisioner.
    err.set(provisioner_->removeFunction(node->name));
  }
  networks_.erase(networkIterator);
  // Refresh exported device-memory stats now that functions were evicted.
  exportMemoryCounters();
  RETURN_ERR(err.get());
}
819 | |
820 | bool HostManager::networkAdded(llvm::StringRef networkName) { |
821 | std::shared_lock<std::shared_timed_mutex> networkLock(networkLock_); |
822 | return networks_.find(networkName.str()) != networks_.end(); |
823 | } |
824 | |
Error HostManager::clearHost() {
  // shutdown the executor, blocking on any current inflight and prevent new
  // requests from being serviced.
  executor_->shutdown();

  DCHECK_EQ(activeRequestCount_, 0)
      << "All requests should be finished when shutting down HostManager.";

  // Remove all networks from the host and device(s).
  // NOTE(review): networks_.size() is read here without holding networkLock_
  // (removeNetwork acquires it internally per call) — presumably safe because
  // the executor is already shut down so no concurrent adds/runs are expected;
  // confirm no other thread can call addNetwork during shutdown.
  while (networks_.size() != 0) {
    RETURN_IF_ERR(removeNetwork(networks_.begin()->first));
  }

  // Now it's safe to stop the DeviceManagers.
  // (The lock is taken only after removeNetwork calls above, since
  // removeNetwork itself takes a unique_lock on the same mutex.)
  std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
  // Collect stop() failures from all devices but keep stopping the rest;
  // OneErrOnly retains the first error encountered.
  OneErrOnly errContainer;
  for (auto &it : devices_) {
    errContainer.set(it.second->stop());
  }
  // Zero out counters.
  statsExporterRegistry_->setCounter(kDeviceMemoryUsed, 0);
  statsExporterRegistry_->setCounter(kDeviceMemoryAvailable, 0);
  statsExporterRegistry_->setCounter(kDeviceMemoryMax, 0);

  RETURN_ERR(errContainer.get());
}
851 | |
Error HostManager::runNetworkBlocking(llvm::StringRef networkName,
                                      PlaceholderBindings &bindings) {
  // runNetwork requires an ExecutionContext that OWNS its bindings, but the
  // caller owns \p bindings. We temporarily wrap the caller's reference in a
  // unique_ptr; the callback below must release() it before the context is
  // destroyed, otherwise the context would delete the caller's stack object.
  std::unique_ptr<PlaceholderBindings> phBindings(&bindings);
  std::unique_ptr<ExecutionContext> context =
      glow::make_unique<ExecutionContext>(std::move(phBindings));
  std::promise<void> runPromise;
  auto fut = runPromise.get_future();
  // Heap-allocated so the callback can both construct the Error and signal
  // (via non-null pointer) that it ran; checked with DCHECK_NOTNULL below.
  std::unique_ptr<Error> runErr;
  runNetwork(
      networkName, std::move(context),
      [&runPromise, &runErr](runtime::RunIdentifierTy, Error err,
                             std::unique_ptr<ExecutionContext> contextPtr) {
        // Don't delete ph bindings since they were created from a passed in
        // reference.
        std::unique_ptr<PlaceholderBindings> phBind =
            contextPtr->movePlaceholderBindings();
        phBind.release();

        runErr = glow::make_unique<Error>(std::move(err));
        // Signal only after runErr is populated.
        runPromise.set_value();
      });

  // Block until the asynchronous run completes.
  fut.wait();
  return std::move(*DCHECK_NOTNULL(runErr.get()));
}
877 | |
878 | Error HostManager::runNetworkBlocking( |
879 | llvm::StringRef networkName, std::unique_ptr<ExecutionContext> &context) { |
880 | std::promise<void> runPromise; |
881 | auto fut = runPromise.get_future(); |
882 | Error runErr = Error::empty(); |
883 | std::unique_ptr<ExecutionContext> tempContext; |
884 | |
885 | runNetwork(networkName, std::move(context), |
886 | [&runPromise, &runErr, |
887 | &tempContext](runtime::RunIdentifierTy, Error err, |
888 | std::unique_ptr<ExecutionContext> resultCtxt) { |
889 | runErr = std::move(err); |
890 | tempContext = std::move(resultCtxt); |
891 | runPromise.set_value(); |
892 | }); |
893 | |
894 | fut.wait(); |
895 | context = std::move(tempContext); |
896 | return runErr; |
897 | } |
898 | |
/// Pops the highest-priority queued request (if any) and hands it to the
/// executor; when the run finishes, the completion callback decrements the
/// network refcount, records stats, invokes the user callback, and then
/// recursively dispatches the next queued request so the active-request slot
/// is reused.
void HostManager::dispatchNextRun() {
  int requestId = -1;
  llvm::Optional<InferRequest> pRequest;
  // Reader lock on networks_ held across the executor_->run() lookup below.
  std::shared_lock<std::shared_timed_mutex> networkLock(networkLock_);
  {
    // hmm this lock is hot but I still have it as a unique lock because
    // we always need to pop inferQueue and inferQueue is not thread safe
    std::unique_lock<std::shared_timed_mutex> queueLock(inferQueueLock_);
    if (inferQueue_.size()) {
      // Get the next request, unfortunately priority_queue only
      // provides a const ref to the top element, since we need to move
      // it we first cast it to remove the const.
      pRequest = std::move(const_cast<InferRequest &>(inferQueue_.top()));
      requestId = static_cast<int>(pRequest->requestID);
      inferQueue_.pop();
    } else {
      // Decrement the activeRequest counter so new requests can
      // launched.
      --activeRequestCount_;
      return;
    }
  }

  assert(pRequest.hasValue());
  InferRequest request = std::move(pRequest.getValue());
  auto startTime = TraceEvent::now();
  auto requestReceived = request.startTime;
  executor_->run(
      networks_[request.networkName].dag.root.get(), std::move(request.context),
      request.requestID,
      // The callback copies/moves what it needs (name, user callback,
      // timestamps) since the local InferRequest dies when this scope exits.
      [this, callback = request.callback, name = request.networkName, startTime,
       requestReceived](RunIdentifierTy runID, Error err,
                        std::unique_ptr<ExecutionContext> context) mutable {
        {
          // Drop the refcount taken in runNetwork(); the network may have
          // been removed while the run was in flight, hence the find().
          std::shared_lock<std::shared_timed_mutex> netLock(networkLock_);
          auto it = networks_.find(name);
          if (it != networks_.end()) {
            it->second.refcount--;
          }
        }

        updateExecutionStats(startTime, context, name, err);
        // Update request runtime.
        auto requestData = ::glow::runtime::RequestData::get();
        if (requestData) {
          uint64_t end = TraceEvent::now();
          requestData->startTime = requestReceived;
          requestData->stopTime = end;
        }

        // Deliver the result, then reuse this active slot for the next
        // queued request (or decrement activeRequestCount_ if none).
        callback(runID, std::move(err), std::move(context));
        dispatchNextRun();
      });
}
953 | |
/// Enqueues an asynchronous run of \p networkName with \p context at
/// \p priority. Returns a unique run id immediately; \p callback is invoked
/// with the result (or a RUNTIME_NET_NOT_FOUND / RUNTIME_REQUEST_REFUSED
/// error) when the run completes or is rejected. Takes a refcount on the
/// network for the lifetime of the request; dispatchNextRun() releases it.
RunIdentifierTy
HostManager::runNetwork(llvm::StringRef networkName,
                        std::unique_ptr<ExecutionContext> context,
                        ResultCBTy callback, uint64_t priority) {
  DCHECK(callback != nullptr);

  TRACE_EVENT_SCOPE_NAMED(context->getTraceContext(), TraceLevel::RUNTIME,
                          "HostManager::runNetwork", traceBlock);
  auto currentRun = totalRequestCount_++;
  traceBlock.addArg("glowRequestId", llvm::formatv("{0}", currentRun).str());
  uint64_t requestReceived = TraceEvent::now();
  size_t queueSize = 0;

  NetworkData *network = nullptr;
  {
    std::shared_lock<std::shared_timed_mutex> networkLock(networkLock_);
    auto it = networks_.find(networkName.str());
    if (it != networks_.end()) {
      network = &it->second;
      // Pin the network so removeNetwork refuses to evict it while this
      // request is queued/in flight.
      network->refcount++;
    }

    if (network == nullptr) {
      TRACE_EVENT_SCOPE_END_NAMED(traceBlock);
      callback(
          currentRun,
          MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_NET_NOT_FOUND,
                   llvm::formatv("Function {0} not found", networkName).str()),
          std::move(context));
      return currentRun;
    }
    // Put the request in the queue.
    {
      // NOTE(review): the size check uses a shared (read) lock, while the
      // push below re-acquires a unique lock — between the two, concurrent
      // callers could each pass the check and push, briefly exceeding
      // maxQueueSize. Presumably an accepted soft limit; confirm.
      std::shared_lock<std::shared_timed_mutex> lock(inferQueueLock_);
      queueSize = inferQueue_.size();
      if (queueSize >= config_.maxQueueSize) {
        // The queue is full, return an error.
        network->refcount--;
        TRACE_EVENT_SCOPE_END_NAMED(traceBlock);
        callback(
            currentRun,
            MAKE_ERR(
                ErrorValue::ErrorCode::RUNTIME_REQUEST_REFUSED,
                strFormat(
                    "The number of allowed queued requests has been exceeded. "
                    "queued requests: %lu allowed requests: %zu",
                    queueSize, config_.maxQueueSize)),
            std::move(context));
        return currentRun;
      }
    }
    reportCurrentQueueSize(queueSize);
    // Setup the request
    InferRequest queuedRequest(networkName.str(), std::move(context), callback,
                               priority, currentRun, requestReceived);
    {
      std::unique_lock<std::shared_timed_mutex> lock(inferQueueLock_);
      TRACE_EVENT_SCOPE_END_NAMED(traceBlock);
      inferQueue_.push(std::move(queuedRequest));
    }
  }

  // If we haven't reached maxActiveRequests kick off next request.
  // Optimistically claim a slot; if over the limit, give it back — the
  // request stays queued and will be picked up when a running request
  // finishes and calls dispatchNextRun().
  size_t activeRequestCount = activeRequestCount_++;
  if (activeRequestCount < config_.maxActiveRequests) {
    dispatchNextRun();
    return currentRun;
  }
  activeRequestCount_--;
  return currentRun;
}
1025 | |
1026 | /// Helper to report current queue size |
1027 | void HostManager::reportCurrentQueueSize(int32_t queueSize) { |
1028 | statsExporterRegistry_->setCounter( |
1029 | kCurrentQueueSize10k, static_cast<float>(queueSize) / |
1030 | static_cast<float>(config_.maxQueueSize) * |
1031 | 100000); |
1032 | } |
1033 | |
1034 | /// Helper to update execution stats |
1035 | void HostManager::updateExecutionStats( |
1036 | uint64_t startTime, std::unique_ptr<ExecutionContext> &context, |
1037 | llvm::StringRef networkName, const Error &error) { |
1038 | auto duration = TraceEvent::now() - startTime; |
1039 | auto updateCountersFn = [&](llvm::StringRef s) { |
1040 | statsExporterRegistry_->addTimeSeriesValue( |
1041 | ("glow.execution_duration_e2e." + s).str(), duration); |
1042 | statsExporterRegistry_->incrementCounter( |
1043 | ("glow.requests_processed." + s).str()); |
1044 | if (error.peekErrorValue()) { |
1045 | statsExporterRegistry_->incrementCounter( |
1046 | ("glow.requests_failed." + s).str()); |
1047 | } else { |
1048 | statsExporterRegistry_->incrementCounter( |
1049 | ("glow.requests_succeeded." + s).str()); |
1050 | } |
1051 | }; |
1052 | updateCountersFn(networkName); |
1053 | updateCountersFn("global" ); |
1054 | } |
1055 | |
1056 | /// Helper to get the parameters in DeviceConfig from \p str. The \p str has |
1057 | /// multiple lines, and each line with this format : "str1" : "str2". |
1058 | static llvm::StringMap<std::string> getBackendParams(std::string &str) { |
1059 | llvm::StringMap<std::string> ret{}; |
1060 | std::string s; |
1061 | std::istringstream f(str.c_str()); |
1062 | while (getline(f, s, '\n')) { |
1063 | // Abstract the mapping from each line's string: |
1064 | // ""str1" : "str2"" => ret["str1"] = "str2"; |
1065 | size_t pos1, pos2, pos3, pos4; |
1066 | pos1 = s.find('"'); |
1067 | assert(pos1 != std::string::npos && "invalid string format" ); |
1068 | pos2 = s.find('"', pos1 + 1); |
1069 | assert(pos2 != std::string::npos && "invalid string format" ); |
1070 | pos3 = s.find('"', pos2 + 1); |
1071 | assert(pos3 != std::string::npos && "invalid string format" ); |
1072 | pos4 = s.find('"', pos3 + 1); |
1073 | assert(pos4 != std::string::npos && "invalid string format" ); |
1074 | ret[s.substr(pos1 + 1, pos2 - pos1 - 1)] = |
1075 | s.substr(pos3 + 1, pos4 - pos3 - 1); |
1076 | } |
1077 | return ret; |
1078 | } |
1079 | |
1080 | /// If the device config file \p loadDeviceDoncfigsFile available, load \p |
1081 | /// configs from the file. Otherwise, create \p numDevices number of devices |
1082 | /// based on \p backendName. |
1083 | std::vector<std::unique_ptr<runtime::DeviceConfig>> |
1084 | runtime::generateDeviceConfigs(unsigned int numDevices, |
1085 | llvm::StringRef backendName, size_t memSize) { |
1086 | std::vector<std::unique_ptr<runtime::DeviceConfig>> configs; |
1087 | if (!loadDeviceConfigsFromFile(configs, memSize)) { |
1088 | // If there is no device config file, use numDevices to generate the |
1089 | // configs. |
1090 | std::vector<unsigned> available_device_ids; |
1091 | if (glow::flags::ScanDevices) { |
1092 | const auto &factories = |
1093 | FactoryRegistry<std::string, Backend>::factories(); |
1094 | auto it = factories.find(backendName.str()); |
1095 | if (it != factories.end()) { |
1096 | available_device_ids = it->second->scanDeviceIDs(); |
1097 | } |
1098 | CHECK_GE(available_device_ids.size(), 0) << "No devices found." ; |
1099 | CHECK_GE(available_device_ids.size(), numDevices) |
1100 | << "Not enough devices found." ; |
1101 | } |
1102 | for (unsigned int i = 0; i < numDevices; ++i) { |
1103 | auto config = glow::make_unique<runtime::DeviceConfig>(backendName); |
1104 | config->setDeviceMemory(memSize); |
1105 | if (glow::flags::ScanDevices) { |
1106 | config->deviceID = available_device_ids.back(); |
1107 | available_device_ids.pop_back(); |
1108 | } else { |
1109 | config->deviceID = i; |
1110 | } |
1111 | configs.push_back(std::move(config)); |
1112 | } |
1113 | } |
1114 | return configs; |
1115 | } |
1116 | |
1117 | bool runtime::loadDeviceConfigsFromFile( |
1118 | std::vector<std::unique_ptr<runtime::DeviceConfig>> &configs, |
1119 | size_t memSize) { |
1120 | if (loadDeviceConfigsFileOpt.empty()) { |
1121 | return false; |
1122 | } |
1123 | |
1124 | std::vector<DeviceConfigHelper> lists; |
1125 | lists = deserializeDeviceConfigFromYaml(loadDeviceConfigsFileOpt); |
1126 | for (unsigned int i = 0; i < lists.size(); ++i) { |
1127 | std::string configBackendName = lists[i].backendName_; |
1128 | std::string name = lists[i].name_; |
1129 | auto parameters = getBackendParams(lists[i].parameters_.str); |
1130 | auto config = glow::make_unique<runtime::DeviceConfig>(configBackendName, |
1131 | name, parameters); |
1132 | config->deviceID = i; |
1133 | config->setDeviceMemory(memSize); |
1134 | configs.push_back(std::move(config)); |
1135 | } |
1136 | return true; |
1137 | } |
1138 | |
/// Returns the Backend registered under \p backendName, delegating to the
/// Provisioner which owns the backend instances.
Backend &HostManager::getBackend(llvm::StringRef backendName) const {
  return provisioner_->getBackend(backendName);
}
1142 | |
/// Returns the Provisioner's default backend, or an Error if none can be
/// determined (hence the Expected return).
Expected<Backend *> HostManager::getBackend() const {
  return provisioner_->getBackend();
}
1146 | |
/// Returns ownership of the Provisioner's map from function name to its
/// serialized (compiled) representation.
std::unique_ptr<
    std::unordered_map<std::string, std::unique_ptr<BlockStreamBase>>>
HostManager::getAllSerializedFunctions() {
  return provisioner_->getAllSerializedFunctionsMap();
}
1152 | |
/// Returns the HostManager last registered via registerHostManager (may be
/// null if none was registered). The registry does not own the pointer.
HostManager *HostManagerRegistry::getHostManager() { return hostManager_; }
1154 | |
/// Records \p hostManager as the process-wide instance retrievable through
/// getHostManager(). Ownership stays with the caller.
void HostManagerRegistry::registerHostManager(HostManager *hostManager) {
  hostManager_ = hostManager;
}
1158 | |
/// Returns the lazily-constructed singleton HostManagerRegistry (function-
/// local static, so initialization is thread-safe).
std::shared_ptr<HostManagerRegistry> glow::runtime::ManagerRegistry() {
  static auto hostManager = std::make_shared<HostManagerRegistry>();
  return hostManager;
}
1163 | |