1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Runtime/HostManager/HostManager.h"
18#include "glow/Backends/DeviceManager.h"
19#include "glow/Exporter/ONNXModelWriter.h"
20#include "glow/Flags/Flags.h"
21#include "glow/Graph/PlaceholderBindings.h"
22#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
23#include "glow/Partitioner/Partitioner.h"
24#include "glow/Runtime/DeferredWeightLoader.h"
25#include "glow/Runtime/DeviceHealthMonitor.h"
26#include "glow/Runtime/ErrorReporter.h"
27#include "glow/Runtime/Executor/ThreadPoolExecutor.h"
28#include "glow/Runtime/Provisioner/Provisioner.h"
29#include "glow/Runtime/RequestData.h"
30#include "glow/Runtime/RuntimeTypes.h"
31#include "glow/Support/Support.h"
32
33#include "llvm/Support/CommandLine.h"
34#include "llvm/Support/FileSystem.h"
35#include "llvm/Support/FormatVariadic.h"
36
37#include <glog/logging.h>
38
39#include "folly/String.h"
40#include "folly/executors/CPUThreadPoolExecutor.h"
41
#include <algorithm>
#include <future>
#include <memory>
#include <queue>
#include <shared_mutex>
46
/// Cap on a partition's input count when peer-to-peer (P2P) transfers are
/// enabled; applied to DeviceInfo::inputCountMax in addNetwork/addNetworkFX.
constexpr uint64_t P2PInputLimit = 256;
using namespace glow;
using namespace runtime;

namespace {
/// Command-line category grouping all HostManager options.
llvm::cl::OptionCategory hostManagerCat("HostManager Options");

/// Optional YAML file of backend-specific compilation options. When set on
/// the command line, it overrides any backend-specific options already
/// present in the CompilationContext (see addNetwork below).
llvm::cl::opt<std::string> loadBackendSpecificOptionsOpt(
    "load-backend-specific-opts",
    llvm::cl::desc("Load backend-specific options for compilation."),
    llvm::cl::value_desc("options.yaml"), llvm::cl::Optional,
    llvm::cl::cat(hostManagerCat));
} // namespace
60
namespace glow {

#if FACEBOOK_INTERNAL
/// Internal-only DAG-level optimization pass run after partitioning; defined
/// outside this file and invoked from addNetwork() when
/// cctx.callDAGOptimizer is set.
Error optimizeDAG(DAGListTy &nodeList, const Provisioner &provisioner,
                  Module &mod, const std::vector<DeviceInfo> &devices,
                  CompilationContext &cctx,
                  ConstantFoldingRecordMap &constFoldRecord);
/// Build revision identifier, logged when networks are added.
extern const char *revisionHash;
#endif /* FACEBOOK_INTERNAL */
} // namespace glow
71
72/// The device configs file used for Runtime.
73llvm::cl::opt<std::string> loadDeviceConfigsFileOpt(
74 "load-device-configs",
75 llvm::cl::desc("Load device configs used in Runtime"),
76 llvm::cl::value_desc("configs.yaml"), llvm::cl::Optional,
77 llvm::cl::cat(hostManagerCat));
78
79/// The value that should be used for device initialization timeout, default:
80/// 5000 milliseconds.
81llvm::cl::opt<unsigned, /* ExternalStorage */ true> deviceInitTimeout(
82 "device_init_timeout_ms",
83 llvm::cl::desc("Set device init timout in milliseconds"),
84 llvm::cl::Optional,
85 llvm::cl::location(glow::runtime::flags::DeviceInitTimeoutMs),
86 llvm::cl::cat(hostManagerCat));
87
88HostManager::HostManager() : HostManager(HostConfig{}) {}
89
/// Construct with \p hostConfig and no devices; publishes the configured
/// max queue size to the stats exporter.
HostManager::HostManager(const HostConfig &hostConfig)
    : config_(hostConfig),
      statsExporterRegistry_(StatsExporterRegistry::Stats()) {
  statsExporterRegistry_->setCounter(kMaxQueueSize, hostConfig.maxQueueSize);
}
95
/// Construct from \p deviceConfigs with a default HostConfig (delegates).
HostManager::HostManager(
    std::vector<std::unique_ptr<DeviceConfig>> deviceConfigs)
    : HostManager(std::move(deviceConfigs), HostConfig{}) {}
99
/// Construct from \p deviceConfigs and \p hostConfig; initializes all
/// devices via init() and publishes the configured max queue size.
HostManager::HostManager(
    std::vector<std::unique_ptr<DeviceConfig>> deviceConfigs,
    const HostConfig &hostConfig)
    : config_(hostConfig),
      statsExporterRegistry_(StatsExporterRegistry::Stats()) {
  // TODO: move all initialization out of constructor.

  // Per the macro name, a failed init is reported and the process exits —
  // construction cannot return an Error.
  REPORT_AND_EXIT_ON_ERR(init(std::move(deviceConfigs)));
  statsExporterRegistry_->setCounter(kMaxQueueSize, hostConfig.maxQueueSize);
}
110
111Expected<DAG *> HostManager::getNetworkDAG(llvm::StringRef network) {
112 auto it = networks_.find(network.str());
113 if (it == networks_.end()) {
114 return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_ERROR, "Network not found.");
115 }
116 return &it->second.dag;
117}
118
119Error HostManager::startDeviceTrace() {
120 LOG(INFO) << "start device tracing" << std::endl;
121 for (auto &dev : devices_) {
122 Error err = dev.second->startDeviceTrace(hostTraceContext_.get());
123 RETURN_IF_ERR(err);
124 }
125 return Error::success();
126}
127
128Error HostManager::stopDeviceTrace() {
129
130 auto *traceContext = hostTraceContext_.get();
131 if (!traceContext) {
132 LOG(INFO) << "No HostManager TraceContext registered, skipping call to "
133 "stopDeviceTrace";
134 return Error::success();
135 } else {
136 LOG(INFO) << "stop device tracing";
137 }
138 for (auto &dev : devices_) {
139 Error err = dev.second->stopDeviceTrace(traceContext);
140 RETURN_IF_ERR(err);
141 }
142 return Error::success();
143}
144
145Error HostManager::init(std::vector<std::unique_ptr<DeviceConfig>> configs) {
146 static std::once_flag monitorFlag;
147 std::call_once(monitorFlag, []() {
148 auto monitors = DeviceHealthMonitorRegistry::Monitors();
149 if (monitors) {
150 monitors->start();
151 }
152 });
153
154 DeviceIDTy deviceCount = 0;
155 for (auto &config : configs) {
156 if (!config->hasName()) {
157 config->name = "device_" + std::to_string(deviceCount);
158 }
159
160 devices_[deviceCount] = std::unique_ptr<DeviceManager>(
161 DeviceManager::createDeviceManager(*config));
162
163 std::promise<Error> devPromise;
164 auto devFuture = devPromise.get_future();
165 auto *dev = devices_[deviceCount].get();
166 threadPool_.submit([&devPromise, dev] {
167 auto err = dev->init();
168 devPromise.set_value(std::move(err));
169 });
170 if (devFuture.wait_for(std::chrono::milliseconds(
171 flags::DeviceInitTimeoutMs)) != std::future_status::timeout) {
172 RETURN_IF_ERR(devFuture.get());
173 } else {
174 // Device initialization is taking longer than expected, return an error.
175 return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_ERROR,
176 "Timeout encountered when initializing device: " +
177 std::string(config->name));
178 }
179 availableDevices_.push_back(deviceCount);
180 deviceCount++;
181 }
182#ifdef FACEBOOK_INTERNAL
183 LOG(INFO) << "Initialized " << deviceCount << " device(s)";
184#endif
185
186 provisioner_.reset(new Provisioner(devices_));
187 executor_.reset(
188 new ThreadPoolExecutor(devices_, config_.executorThreads, "HostManager"));
189 exportMemoryCounters();
190 if (flags::AvailableDevices.length()) {
191 std::vector<unsigned> devices;
192 folly::split<char, std::string, unsigned>(',', flags::AvailableDevices,
193 devices,
194 /* ignoreEmpty */ true);
195 std::vector<runtime::DeviceIDTy> convertedDevs(devices.begin(),
196 devices.end());
197 setAvailableDevices(convertedDevs);
198 }
199 // If no HostManager is registered yet, register this one.
200 if (!ManagerRegistry()->getHostManager()) {
201 ManagerRegistry()->registerHostManager(this);
202 }
203
204 return Error::success();
205}
206
207void HostManager::setAvailableDevices(const std::vector<DeviceIDTy> &devices) {
208 // Validate new device list.
209 availableDevices_.clear();
210 std::vector<DeviceIDTy> mapping;
211 std::vector<DeviceManager *> availableDevices;
212 // Grab a lock to prevent devices_ getting changed concurrently.
213 std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
214 for (auto dev : devices) {
215 auto it = devices_.find(dev);
216 if (it != devices_.end()) {
217 availableDevices_.push_back(dev);
218 availableDevices.push_back(devices_[dev].get());
219 mapping.push_back(it->first);
220 }
221 }
222 // Update the provisioner.
223 provisioner_->updateAvailableDevices(availableDevices, mapping);
224}
225
226void HostManager::exportMemoryCounters() {
227 uint64_t maxMem = 0;
228 uint64_t availableMem = 0;
229 for (auto &dev : devices_) {
230 maxMem += dev.second->getMaximumMemory();
231 availableMem += dev.second->getAvailableMemory();
232 }
233 statsExporterRegistry_->setCounter(kDeviceMemoryUsed, maxMem - availableMem);
234 statsExporterRegistry_->setCounter(kDeviceMemoryAvailable, availableMem);
235 statsExporterRegistry_->setCounter(kDeviceMemoryMax, maxMem);
236}
237
/// Destructor: best-effort teardown. Errors from clearing the host are
/// discarded (ERR_TO_VOID) since there is no caller to propagate them to.
HostManager::~HostManager() {
  LOG(INFO) << "Destroying host manager...";
  ERR_TO_VOID(clearHost());
  exportMemoryCounters();
}
243
244void HostManager::cleanupAddNetwork(llvm::ArrayRef<std::string> names) {
245 for (auto &name : names) {
246 processingNetworks_.erase(name);
247 }
248 exportMemoryCounters();
249}
250
/// Compile, partition, optionally serialize, and provision every Function in
/// \p module onto the available devices, registering each resulting DAG in
/// networks_ under its Function's name. \p cctx controls optimization,
/// precision, serialization, and provisioning behavior. On failure the
/// reserved names are rolled back via cleanupAddNetwork().
Error HostManager::addNetwork(std::unique_ptr<Module> module,
                              CompilationContext &cctx) {
#ifdef FACEBOOK_INTERNAL
  LOG(INFO) << "Adding Glow network built with revision hash: " << revisionHash;
#endif /* FACEBOOK_INTERNAL */
  VLOG(1) << "addNetwork";
  // On any error-return below (until dismissed), dump the final graphs for
  // debugging if requested.
  ScopeGuard debugDumpDAGGuard([&]() {
    if (cctx.dumpFinalGraph) {
      for (Function *F : module->getFunctions()) {
        auto fname = strFormat("%sfinal_graph_dbg_err_%s.dot",
                               cctx.dumpGraphPath.c_str(), F->getName().data());
        LOG(INFO) << "Dumping final graph due to error to " << fname;
        F->dumpDAG(fname);
      }
    }
  });

  /// If specified in the cctx, this will prevent Constants from being modified
  /// until the current scope ends or the preventer is dismissed. Does so by
  /// swapping in temporary Placeholders instead of Constants.
  ConstantModificationPreventer constModPreventer(*module, cctx);
  if (cctx.optimizationOpts.delayAndRecordConstantModification) {
    constModPreventer.activate();
  }

  // Reserve all Function names up front, failing if any is already
  // registered or currently being added by another call.
  std::vector<std::string> names;
  {
    std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
    auto functions = module->getFunctions();
    for (auto &F : functions) {
      std::string name = F->getName().str();
      auto it = networks_.find(name);
      if (it != networks_.end() ||
          processingNetworks_.find(name) != processingNetworks_.end()) {
        cleanupAddNetwork(names);
        return MAKE_ERR(
            ErrorValue::ErrorCode::RUNTIME_ERROR,
            "Failed to add network: already have a function called " + name);
      }
      // Add the network to processingNetworks_ so we know it's being worked on.
      processingNetworks_.insert(name);
      names.push_back(name);
    }
  }

  // Issue a warning when loading backend specific options from the command line
  // and the compile context also contains backend specific options.
  if (!loadBackendSpecificOptionsOpt.empty()) {
    if (cctx.backendOpts.backendSpecificOpts.size() != 0) {
      VLOG_EVERY_N(1, 1000) << "Warning: backendSpecificOpts is set via the "
                               "HostManager, ignoring previously set options.";
    }
    cctx.backendOpts.backendSpecificOpts =
        deserializeStrStrMapFromYaml(loadBackendSpecificOptionsOpt);
  } else {
    // No command-line override: honor an in-cctx request to load options
    // from a YAML file, if present.
    auto ctxLoadBackendSpecificOpt =
        cctx.backendOpts.backendSpecificOpts.find("loadBackendSpecificOptions");

    if (ctxLoadBackendSpecificOpt !=
        cctx.backendOpts.backendSpecificOpts.end()) {
      cctx.backendOpts.backendSpecificOpts =
          deserializeStrStrMapFromYaml(ctxLoadBackendSpecificOpt->second);
    }
  }

  // Snapshot per-device capability/memory info for the partitioner.
  std::vector<DeviceInfo> deviceInfo;
  {
    std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
    for (auto &device : availableDevices_) {
      DeviceInfo info = devices_[device]->getDeviceInfo();
      info.availableMemory = devices_[device]->getAvailableMemory();
      info.backendName = devices_[device]->getBackendName().str();
      info.nonSupportedNodes =
          devices_[device]->getParamByName("nonSupportedNodes").str();
      info.supportedNodes =
          devices_[device]->getParamByName("supportedNodes").str();
      // If p2p is enabled update the inputCount limit.
      if (cctx.enableP2P) {
        info.inputCountMax = P2PInputLimit;
      }
      deviceInfo.push_back(info);
    }
  }

  // Optimize Functions only if we don't have any backendSpecificNodeInfo,
  // because if we do then the Functions were already optimized and Nodes had
  // extra info mapped to them, so we don't want to mutate the Function. Also
  // skip optimizations if we're loading an AOT optimized model.
  const bool skipOptimizations =
      cctx.loadingAOTModel || !cctx.backendOpts.backendSpecificNodeInfo.empty();

  // Perform a round of target-independent graph optimizations. This helps the
  // partitioner to do its job more efficiently.
  if (!skipOptimizations) {
    for (Function *F : module->getFunctions()) {
      auto err = optimizeFunctionBeforeLowering(F, cctx);
      if (err) {
        {
          std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
          cleanupAddNetwork(names);
        }
        RETURN_ERR(err);
      }
    }
  }
  VLOG(1) << "Before partitioner";
  // NOTE(review): devices_[0] assumes a device with ID 0 exists; init()
  // numbers devices from 0, but verify this still holds after
  // setAvailableDevices() restricts the set.
  Partitioner partitioner(module.get(), deviceInfo, skipOptimizations);
  auto backendName = devices_[0]->getBackendName();
  const auto &backend = provisioner_->getBackend(backendName);
  auto contextCount = backend.getContextCount(cctx);
  partitioner.setContextCount(contextCount);
  DAGListTy nodeList;
  auto result = partitioner.partition(cctx);
  VLOG(1) << "After partitioner";
  if (result) {
    nodeList = std::move(result.get());
  } else {
    std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
    cleanupAddNetwork(names);
    RETURN_ERR(result.takeError());
  }
  VLOG(1) << "Before quantmode";
  if (cctx.precisionConfig.quantMode == QuantizationMode::Profile) {
    // Since for profiling the provisioner will be reset, we only allow one
    // network in one HM.
    if (networks_.size() > 0) {
      // NOTE(review): unlike the other error paths, this early return does
      // not call cleanupAddNetwork(names), so the reserved names remain in
      // processingNetworks_ — confirm whether that is intentional.
      return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_ERROR,
                      "For quantization profiling flow, there can't be other "
                      "registered networks before this one");
    }
    // For profiling, we use CPU backend. Overwrite Provisioner and Executor
    // to force the network is compiled and run in profilingBackend. backend.
    // NOTE(review): indexes devices_ by 0..devicesNum-1, assuming device IDs
    // are contiguous from zero — TODO confirm.
    size_t devicesNum = devices_.size();
    for (size_t i = 0; i < devicesNum; i++) {
      auto name = devices_[i]->getDeviceConfig().name;
      auto config = glow::make_unique<DeviceConfig>(profilingBackend, name);
      devices_[i] = std::unique_ptr<DeviceManager>(
          DeviceManager::createDeviceManager(*config));
      RETURN_IF_ERR(devices_[i]->init());
    }
    provisioner_.reset(new Provisioner(devices_));
    executor_.reset(new ThreadPoolExecutor(devices_, config_.executorThreads));
  }

  VLOG(1) << "Before replace dummy TQPs";
  // Now that we've partitioned and optimized, do some verification based on the
  // dummy mode we're using, if any.
  if (cctx.precisionConfig.replaceDummyTQPs ||
      cctx.precisionConfig.loadUniquedDummyQParams) {
    RETURN_IF_ERR(module->verifyDummyQParams(
        cctx.precisionConfig.loadUniquedDummyQParams));
  }

  // If we are loading an AOT model where we are replacing dummy TQPs, then we
  // may need to update Relu output types on FCs, since they should be set to
  // use zero as min but the correct qparams could not be calculated AOT.
  if (cctx.loadingAOTModel && cctx.precisionConfig.replaceDummyTQPs) {
    LOG(INFO) << "Updating quantized Relu types given real TQPs";
    for (Function *F : module->getFunctions()) {
      updateQuantReluTypes(F);
    }
  }

  VLOG(1) << "Before constant folding";
  // If we prevented constant modification then run constant folding with
  // recording now. Record so that if we are going to serialize we can embed the
  // constant folding subgraphs in the Glow ONNX model.
  ConstantFoldingRecordMap record;
  if (cctx.optimizationOpts.delayAndRecordConstantModification) {
    constModPreventer.deactivateAndCleanup();

    RETURN_ERR_IF_NOT(nodeList.size() == 1, "Expect only one DAG.");
    const auto &dag = *nodeList.begin();
    for (auto &dagNode : dag.nodes) {
      Function *F = module->getFunction(dagNode->name);
      RETURN_ERR_IF_NOT(
          F, strFormat("Function %s not found", dagNode->name.data()));

      ConstantFoldingRecordMap currRecord = constantFoldAndRecord(F, cctx);
      record.insert(currRecord.begin(), currRecord.end());
      runDCEPass(F, cctx);

      // Verify the Function is valid after constant folding takes place.
      Backend &B = provisioner_->getBackend(dagNode->backendName);
      RETURN_ERR_IF_NOT(
          B.verify(*F, cctx.verboseCompile),
          "Unsupported node(s) found after delayed constant folding Function " +
              F->getName().str() + " for backend " + B.getBackendName());
    }
  }
  VLOG(1) << "Before loading AOT";
  if (!cctx.loadingAOTModel) {
    if (cctx.callDAGOptimizer) {
#if FACEBOOK_INTERNAL
      auto optDagErr = optimizeDAG(nodeList, *provisioner_, *module, deviceInfo,
                                   cctx, record);
      if (optDagErr) {
        std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
        cleanupAddNetwork(names);
        RETURN_ERR(optDagErr);
      }
#endif /* FACEBOOK_INTERNAL */
    } else {
      // If not using the DAG optimizer, iterate over the DAGs and call
      // transformPostOptPipeline() on the Functions.
      VLOG(1) << "No DAG optimizer";
      for (const auto &dag : nodeList) {
        for (auto &dagNode : dag.nodes) {
          Function *F = module->getFunction(dagNode->name);
          RETURN_ERR_IF_NOT(
              F, strFormat("Function %s not found", dagNode->name.data()));

          if (cctx.optimizationOpts.onlyLowerFuns.count(F)) {
            continue;
          }

          Backend &B = provisioner_->getBackend(dagNode->backendName);
          RETURN_IF_EXPECTED_IS_ERR(B.transformPostOptPipeline(F, cctx));

          RETURN_ERR_IF_NOT(
              B.verify(*F, cctx.verboseCompile),
              "Unsupported node(s) found after transformPostOptPipeline() " +
                  F->getName().str() + " for backend " + B.getBackendName());
        }
      }
    }
  }

  VLOG(1) << "Before serialize compile DAG";
  // If requested, serialize the resulting DAG that was just optimized and
  // partitioned.
  if (cctx.serializeCompiledDAG) {
    std::string loc;
    char *envSpecifiedSerializationPath = getenv("GLOW_DAG_SERIALIZATION_LOC");
    if (!envSpecifiedSerializationPath) {
      loc = nodeList.begin()->root->name + ".onnxtxt";
    } else {
      loc = std::string(envSpecifiedSerializationPath);
    }

    LOG(INFO) << "Serializing final compiled DAG to " << loc;
    {
      llvm::StringMap<std::string> extraMetadataProps;
      if (cctx.precisionConfig.originNameToTQPMap) {
        RETURN_IF_ERR(ONNXModelWriter::insertLoaderNameUniqueOffsetMetadata(
            extraMetadataProps, *cctx.precisionConfig.originNameToTQPMap));
      }
      if (cctx.precisionConfig.clipQuantRangeToFP16) {
        extraMetadataProps[clipQuantRangeToFP16Key] = "1";
      }
      Error writeErr = Error::empty();
      // Note: If cctx.skipProvisioning then we want to serialize all meta info
      // as we are likely doing AOT optimization. Otherwise do not provide the
      // meta info as the model does not need to be reloaded.
      ONNXModelWriter onnxWR(
          loc, nodeList, 7, 9, &writeErr,
          /* textMode */ true,
          /* zipMode */ cctx.useZipModeForSerializeCompiledDAG,
          /* includeConstantData */ cctx.saveConstantInSerializeCompiledDAG,
          extraMetadataProps, record, cctx.backendOpts.backendSpecificNodeInfo,
          cctx.skipProvisioning ? &cctx.loadedPHNames : nullptr,
          cctx.skipProvisioning ? &cctx.staticPlaceholderTypesForAOT : nullptr,
          cctx.returnGlowSerializedModelStr
              ? cctx.glowAOTSerializationModelStrPtr.get()
              : nullptr);
      RETURN_IF_ERR(writeErr);
    }

    // If we're using AOT DAG optimizer then skip provisioning.
    if (cctx.skipProvisioning ||
        (cctx.callDAGOptimizer && cctx.useDAGOptimizerAOTMode)) {
      LOG(INFO) << "Host manager skipping provisioning";
      {
        std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
        cleanupAddNetwork(names);
      }
      debugDumpDAGGuard.dismiss();
      cleanupConstantFolding(*module, record);
      if (cctx.dumpFinalGraph) {
        for (Function *F : module->getFunctions()) {
          auto fname =
              strFormat("%sfinal_graph_aot_%s.dot", cctx.dumpGraphPath.c_str(),
                        F->getName().data());
          LOG(INFO) << "Dumping final graph to " << fname;
          F->dumpDAG(fname);
        }
      }
      // AOT flow ends here: the DAG was serialized but never provisioned.
      return Error::success();
    }
  }

  // Now that we've serialized the model if requested, cleanup the temporary
  // Functions and PHs used for constant folding.
  cleanupConstantFolding(*module, record);
  VLOG(1) << "Before provisioning";
  auto err = provisioner_->provision(nodeList, *module, cctx);
  if (err) {
    if (err.peekErrorValue()->isFatalError()) {
      statsExporterRegistry_->setCounter(kDeviceFatalError, 1);
    }
    {
      std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
      cleanupAddNetwork(names);
    }
    RETURN_ERR(err);
  }
  debugDumpDAGGuard.dismiss();
  VLOG(1) << "Calculation of maxActiveRequests";
  {
    std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
    /// Calculate networkMaxActive requests. Then update
    /// config_.maxActiveRequests This will be maxActiveRequestsPerInstance *
    /// instanceCount * minReplications or config_.maxActiveRequests whichever
    /// is smaller.

    // Find the minimum on device replication.
    unsigned minReplications{1};
    for (auto &node : nodeList) {
      for (auto &dag : node.nodes) {
        minReplications = std::min(dag->replicationCount, minReplications);
      }
    }
    unsigned product{0};
    if (nodeList.size() && nodeList[0].nodes.size()) {
      product = nodeList[0].nodes[0]->instanceCount *
                cctx.maxActiveRequestsPerInstance * minReplications;
    } else {
      return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_ERROR,
                      "NodeList is empty.");
    }
    unsigned maxActiveRequests = config_.maxActiveRequests;
    config_.maxActiveRequests = std::min(product, maxActiveRequests);

    // Create pool of cachedExecutionStates.
    for (auto &node : nodeList) {
      // Note: currently getNextNetworkExecutionState assumes that pool size is
      // >= currentInFlight requests, so we set pool size to maxActiveRequests.
      executor_->createPool(node.root.get(), config_.maxActiveRequests,
                            cctx.enableP2P, cctx.enableDRT);
    }
  }
  // Clear constants contents from the module then put it in a
  // shared_ptr to be shared between all of the networks created from each
  // function in the module.
  auto targetBackendName = std::string(devices_[0]->getBackendName());
  const auto &targetBackend = provisioner_->getBackend(targetBackendName);
  if (targetBackend.shouldStripModule() && !cctx.skipModuleStrip) {
    module->strip();
  }
  VLOG(1) << "Cleanup";
  auto sharedModule = std::shared_ptr<Module>(std::move(module));
  {
    std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
    for (auto &node : nodeList) {
#if FACEBOOK_INTERNAL
      LOG(INFO) << "Successfully compiled and provisioned " << node.root->name;
#endif
      auto &networkData = networks_[(node.root)->name];
      networkData.dag = std::move(node);
      networkData.module = sharedModule;
    }
    cleanupAddNetwork(names);
  }
  VLOG(1) << "After cleanup";
  return Error::success();
}
617
618#if FACEBOOK_INTERNAL
/// FX-IR variant of addNetwork: registers and provisions the
/// already-partitioned DAGs in \p networks for the Functions of \p module,
/// compiling from \p FXIR with weights supplied in \p constants. Unlike
/// addNetwork, no optimization/partitioning/serialization is performed here.
Error HostManager::addNetworkFX(
    std::unique_ptr<Module> module, CompilationContext &cctx,
    DAGListTy &networks, const folly::dynamic &FXIR,
    const llvm::StringMap<const void *> &constants) {

  LOG(INFO) << "Adding Glow network built with revision hash: " << revisionHash;
  VLOG(1) << "addNetwork";

  // Reserve all Function names up front, failing if any is already
  // registered or currently being added by another call.
  std::vector<std::string> names;
  {
    std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
    auto functions = module->getFunctions();
    for (auto &F : functions) {
      const auto name = F->getName().str();
      auto it = networks_.find(name);
      if (it != networks_.end() ||
          processingNetworks_.find(name) != processingNetworks_.end()) {
        cleanupAddNetwork(names);
        return MAKE_ERR(
            ErrorValue::ErrorCode::RUNTIME_ERROR,
            "Failed to add network: already have a function called " + name);
      }
      // Add the network to processingNetworks_ so we know it's being worked on.
      processingNetworks_.insert(name);
      names.push_back(name);
    }
  }

  // Issue a warning when loading backend specific options from the command line
  // and the compile context also contains backend specific options.
  if (!loadBackendSpecificOptionsOpt.empty()) {
    if (cctx.backendOpts.backendSpecificOpts.size() != 0) {
      VLOG_EVERY_N(1, 1000) << "Warning: backendSpecificOpts is set via the "
                               "HostManager, ignoring previously set options.";
    }
    cctx.backendOpts.backendSpecificOpts =
        deserializeStrStrMapFromYaml(loadBackendSpecificOptionsOpt);
  } else {
    // No command-line override: honor an in-cctx request to load options
    // from a YAML file, if present.
    auto ctxLoadBackendSpecificOpt =
        cctx.backendOpts.backendSpecificOpts.find("loadBackendSpecificOptions");

    if (ctxLoadBackendSpecificOpt !=
        cctx.backendOpts.backendSpecificOpts.end()) {
      cctx.backendOpts.backendSpecificOpts =
          deserializeStrStrMapFromYaml(ctxLoadBackendSpecificOpt->second);
    }
  }

  // Snapshot per-device capability/memory info.
  // NOTE(review): deviceInfo is built here but not used further in this
  // function — confirm whether it is still needed.
  std::vector<DeviceInfo> deviceInfo;
  {
    std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
    for (auto &device : availableDevices_) {
      DeviceInfo info = devices_[device]->getDeviceInfo();
      info.availableMemory = devices_[device]->getAvailableMemory();
      info.backendName = devices_[device]->getBackendName();
      info.nonSupportedNodes =
          devices_[device]->getParamByName("nonSupportedNodes");
      info.supportedNodes = devices_[device]->getParamByName("supportedNodes");
      // If p2p is enabled update the inputCount limit.
      if (cctx.enableP2P) {
        info.inputCountMax = P2PInputLimit;
      }
      deviceInfo.push_back(info);
    }
  }

  VLOG(1) << "Before provisioning";
  auto err =
      provisioner_->provisionFX(networks, *module, FXIR, constants, cctx);
  if (err) {
    if (err.peekErrorValue()->isFatalError()) {
      statsExporterRegistry_->setCounter(kDeviceFatalError, 1);
    }
    {
      std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
      cleanupAddNetwork(names);
    }
    RETURN_ERR(err);
  }

  VLOG(1) << "Calculation of maxActiveRequests";
  {
    std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
    /// Calculate networkMaxActive requests. Then update
    /// config_.maxActiveRequests This will be maxActiveRequestsPerInstance *
    /// instanceCount * minReplications or config_.maxActiveRequests whichever
    /// is smaller.

    // Find the minimum on device replication.
    unsigned minReplications{1};
    for (auto &node : networks) {
      for (auto &dag : node.nodes) {
        minReplications = std::min(dag->replicationCount, minReplications);
      }
    }
    unsigned product{0};
    if (networks.size() && networks[0].nodes.size()) {
      product = networks[0].nodes[0]->instanceCount *
                cctx.maxActiveRequestsPerInstance * minReplications;
    } else {
      return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_ERROR,
                      "NodeList is empty.");
    }
    unsigned maxActiveRequests = config_.maxActiveRequests;
    config_.maxActiveRequests = std::min(product, maxActiveRequests);

    // Create pool of cachedExecutionStates.
    for (auto &node : networks) {
      // Note: currently getNextNetworkExecutionState assumes that pool size is
      // >= currentInFlight requests, so we set pool size to maxActiveRequests.
      executor_->createPool(node.root.get(), config_.maxActiveRequests,
                            cctx.enableP2P, cctx.enableDRT);
    }
  }
  // Clear constants contents from the module then put it in a
  // shared_ptr to be shared between all of the networks created from each
  // function in the module.
  auto targetBackendName = std::string(devices_[0]->getBackendName());
  const auto &targetBackend = provisioner_->getBackend(targetBackendName);
  if (targetBackend.shouldStripModule() && !cctx.skipModuleStrip) {
    module->strip();
  }
  VLOG(1) << "Cleanup";
  auto sharedModule = std::shared_ptr<Module>(std::move(module));
  {
    std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
    for (auto &node : networks) {
      LOG(INFO) << "Successfully compiled and provisioned " << node.root->name;
      auto &networkData = networks_[(node.root)->name];
      networkData.dag = std::move(node);
      networkData.module = sharedModule;
    }
    cleanupAddNetwork(names);
  }
  VLOG(1) << "After cleanup";
  return Error::success();
}
756#endif
757
758std::unordered_map<std::string, std::vector<DeviceIDTy>>
759HostManager::getDevicePartitionMapping(llvm::StringRef network) {
760 std::unordered_map<std::string, std::vector<DeviceIDTy>> mapping;
761 auto it = networks_.find(network.str());
762 if (it != networks_.end()) {
763 auto &nodeList = it->second.dag.nodes;
764 for (auto &node : nodeList) {
765 std::vector<DeviceIDTy> devices;
766 for (auto &dev : node->deviceRuntimeInfos) {
767 devices.push_back(dev.first);
768 }
769 mapping[node->name] = devices;
770 }
771 }
772 return mapping;
773}
774
/// Remove the network \p networkName: free its execution-state pool, evict
/// its compiled functions from every device it was provisioned on, and drop
/// it from networks_. Succeeds trivially if the network is unknown; returns
/// RUNTIME_NET_BUSY if it is still being added or has outstanding runs.
Error HostManager::removeNetwork(llvm::StringRef networkName) {
  // Exclusive lock: removal mutates networks_ and talks to devices.
  std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
  auto networkIterator = networks_.find(networkName.str());
  if (networkIterator == networks_.end()) {
    return Error::success();
  }

  if (processingNetworks_.find(networkName.str()) !=
      processingNetworks_.end()) {
    // Return an error, the network is in an incomplete state likely because
    // it is still being added by a different call.
    return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_NET_BUSY,
                    llvm::formatv("Cannot remove the network {0}, as it is "
                                  "currently being modified.",
                                  networkName)
                        .str());
  }

  // Issue an error as there are outstanding runs for the network
  if (networkIterator->second.refcount != 0) {
    return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_NET_BUSY,
                    llvm::formatv("Cannot remove the network {0}, as there are "
                                  "still outstanding runs",
                                  networkName)
                        .str());
  }

  // Collect the first eviction/removal error but keep going so every device
  // and every compiled function is cleaned up regardless.
  OneErrOnly err;
  auto &nodes = networkIterator->second.dag.nodes;
  // Free the pool of executionStates.
  executor_->freePool(networkIterator->second.dag.root.get());
  for (auto &node : nodes) {
    for (auto device : node->deviceRuntimeInfos) {
      Error evictErr = provisioner_->evictFunction(
          node->name, devices_[device.first].get(), node->replicationCount);
      err.set(std::move(evictErr));
    }
    // Also remove compiledFunction from Provisioner.
    err.set(provisioner_->removeFunction(node->name));
  }
  networks_.erase(networkIterator);
  exportMemoryCounters();
  RETURN_ERR(err.get());
}
819
820bool HostManager::networkAdded(llvm::StringRef networkName) {
821 std::shared_lock<std::shared_timed_mutex> networkLock(networkLock_);
822 return networks_.find(networkName.str()) != networks_.end();
823}
824
/// Tear down all runtime state: drain and shut down the executor, remove
/// every registered network, stop each DeviceManager, and zero the exported
/// memory counters. \returns the first error encountered while stopping
/// devices, or from removing a network.
Error HostManager::clearHost() {
  // shutdown the executor, blocking on any current inflight and prevent new
  // requests from being serviced.
  executor_->shutdown();

  DCHECK_EQ(activeRequestCount_, 0)
      << "All requests should be finished when shutting down HostManager.";

  // Remove all networks from the host and device(s).
  // NOTE(review): networks_.begin() is read here without holding
  // networkLock_ (removeNetwork() acquires it internally, so it cannot be
  // held here) — relies on no concurrent adders at shutdown; confirm.
  while (networks_.size() != 0) {
    RETURN_IF_ERR(removeNetwork(networks_.begin()->first));
  }

  // Now it's safe to stop the DeviceManagers.
  std::unique_lock<std::shared_timed_mutex> networkLock(networkLock_);
  // Collect the first stop() error but still attempt to stop every device.
  OneErrOnly errContainer;
  for (auto &it : devices_) {
    errContainer.set(it.second->stop());
  }
  // Zero out counters.
  statsExporterRegistry_->setCounter(kDeviceMemoryUsed, 0);
  statsExporterRegistry_->setCounter(kDeviceMemoryAvailable, 0);
  statsExporterRegistry_->setCounter(kDeviceMemoryMax, 0);

  RETURN_ERR(errContainer.get());
}
851
/// Run network \p networkName synchronously against caller-owned
/// \p bindings. The bindings are wrapped in a unique_ptr only to satisfy the
/// ExecutionContext API; the matching release() in the callback ensures they
/// are never deleted here — ownership stays with the caller.
Error HostManager::runNetworkBlocking(llvm::StringRef networkName,
                                      PlaceholderBindings &bindings) {
  // Non-owning wrap of the caller's bindings (released below).
  std::unique_ptr<PlaceholderBindings> phBindings(&bindings);
  std::unique_ptr<ExecutionContext> context =
      glow::make_unique<ExecutionContext>(std::move(phBindings));
  std::promise<void> runPromise;
  auto fut = runPromise.get_future();
  std::unique_ptr<Error> runErr;
  runNetwork(
      networkName, std::move(context),
      [&runPromise, &runErr](runtime::RunIdentifierTy, Error err,
                             std::unique_ptr<ExecutionContext> contextPtr) {
        // Don't delete ph bindings since they were created from a passed in
        // reference.
        std::unique_ptr<PlaceholderBindings> phBind =
            contextPtr->movePlaceholderBindings();
        phBind.release();

        runErr = glow::make_unique<Error>(std::move(err));
        runPromise.set_value();
      });

  // Block until the callback has fired; runErr is guaranteed to be set by
  // then, so dereferencing it is safe.
  fut.wait();
  return std::move(*DCHECK_NOTNULL(runErr.get()));
}
877
878Error HostManager::runNetworkBlocking(
879 llvm::StringRef networkName, std::unique_ptr<ExecutionContext> &context) {
880 std::promise<void> runPromise;
881 auto fut = runPromise.get_future();
882 Error runErr = Error::empty();
883 std::unique_ptr<ExecutionContext> tempContext;
884
885 runNetwork(networkName, std::move(context),
886 [&runPromise, &runErr,
887 &tempContext](runtime::RunIdentifierTy, Error err,
888 std::unique_ptr<ExecutionContext> resultCtxt) {
889 runErr = std::move(err);
890 tempContext = std::move(resultCtxt);
891 runPromise.set_value();
892 });
893
894 fut.wait();
895 context = std::move(tempContext);
896 return runErr;
897}
898
void HostManager::dispatchNextRun() {
  // Pops the highest-priority queued request (if any) and hands it to the
  // executor. Invoked when a request is added and again from the executor's
  // completion callback below, so dispatches chain until the queue drains.
  int requestId = -1;
  // NOTE(review): requestId is assigned below but never read afterwards in
  // this function -- confirm whether it can be removed.
  llvm::Optional<InferRequest> pRequest;
  std::shared_lock<std::shared_timed_mutex> networkLock(networkLock_);
  {
    // This lock is hot, but it must be exclusive: popping mutates inferQueue_
    // and the queue is not thread safe.
    std::unique_lock<std::shared_timed_mutex> queueLock(inferQueueLock_);
    if (inferQueue_.size()) {
      // Get the next request. priority_queue only provides a const ref to
      // the top element; since we need to move it out, cast away the const
      // before pop().
      pRequest = std::move(const_cast<InferRequest &>(inferQueue_.top()));
      requestId = static_cast<int>(pRequest->requestID);
      inferQueue_.pop();
    } else {
      // Nothing queued: decrement the activeRequest counter so new requests
      // can be launched.
      --activeRequestCount_;
      return;
    }
  }

  assert(pRequest.hasValue());
  InferRequest request = std::move(pRequest.getValue());
  auto startTime = TraceEvent::now();
  auto requestReceived = request.startTime;
  executor_->run(
      networks_[request.networkName].dag.root.get(), std::move(request.context),
      request.requestID,
      [this, callback = request.callback, name = request.networkName, startTime,
       requestReceived](RunIdentifierTy runID, Error err,
                        std::unique_ptr<ExecutionContext> context) mutable {
        {
          // Release the reference taken in runNetwork() now that this run
          // has finished.
          // NOTE(review): refcount is mutated while holding only a shared
          // (reader) lock -- presumably refcount is atomic; confirm.
          std::shared_lock<std::shared_timed_mutex> netLock(networkLock_);
          auto it = networks_.find(name);
          if (it != networks_.end()) {
            it->second.refcount--;
          }
        }

        updateExecutionStats(startTime, context, name, err);
        // Update request runtime so callers can observe end-to-end latency.
        auto requestData = ::glow::runtime::RequestData::get();
        if (requestData) {
          uint64_t end = TraceEvent::now();
          requestData->startTime = requestReceived;
          requestData->stopTime = end;
        }

        callback(runID, std::move(err), std::move(context));
        // Chain to the next queued request (or decrement the active count if
        // the queue is empty).
        dispatchNextRun();
      });
}
953
RunIdentifierTy
HostManager::runNetwork(llvm::StringRef networkName,
                        std::unique_ptr<ExecutionContext> context,
                        ResultCBTy callback, uint64_t priority) {
  // Asynchronously run \p networkName with \p context. Returns a unique run
  // ID immediately; \p callback receives the result (invoked inline on the
  // failure paths below). \p priority orders requests in the infer queue.
  DCHECK(callback != nullptr);

  TRACE_EVENT_SCOPE_NAMED(context->getTraceContext(), TraceLevel::RUNTIME,
                          "HostManager::runNetwork", traceBlock);
  auto currentRun = totalRequestCount_++;
  traceBlock.addArg("glowRequestId", llvm::formatv("{0}", currentRun).str());
  uint64_t requestReceived = TraceEvent::now();
  size_t queueSize = 0;

  NetworkData *network = nullptr;
  {
    std::shared_lock<std::shared_timed_mutex> networkLock(networkLock_);
    auto it = networks_.find(networkName.str());
    if (it != networks_.end()) {
      network = &it->second;
      // Take a reference so the network cannot be removed while this request
      // is queued or running; dropped in dispatchNextRun's completion
      // callback (or below if the request is refused).
      // NOTE(review): refcount is mutated under a shared (reader) lock --
      // presumably it is atomic; confirm.
      network->refcount++;
    }

    if (network == nullptr) {
      // Unknown network: fail the request through the callback right away.
      TRACE_EVENT_SCOPE_END_NAMED(traceBlock);
      callback(
          currentRun,
          MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_NET_NOT_FOUND,
                   llvm::formatv("Function {0} not found", networkName).str()),
          std::move(context));
      return currentRun;
    }
    // Put the request in the queue.
    {
      std::shared_lock<std::shared_timed_mutex> lock(inferQueueLock_);
      queueSize = inferQueue_.size();
      // NOTE(review): the size check holds only a shared lock and the push
      // below re-acquires a separate unique lock, so concurrent callers can
      // transiently overshoot maxQueueSize -- confirm this is acceptable.
      if (queueSize >= config_.maxQueueSize) {
        // The queue is full, return an error.
        network->refcount--;
        TRACE_EVENT_SCOPE_END_NAMED(traceBlock);
        callback(
            currentRun,
            MAKE_ERR(
                ErrorValue::ErrorCode::RUNTIME_REQUEST_REFUSED,
                strFormat(
                    "The number of allowed queued requests has been exceeded. "
                    "queued requests: %lu allowed requests: %zu",
                    queueSize, config_.maxQueueSize)),
            std::move(context));
        return currentRun;
      }
    }
    reportCurrentQueueSize(queueSize);
    // Setup the request and enqueue it by priority.
    InferRequest queuedRequest(networkName.str(), std::move(context), callback,
                               priority, currentRun, requestReceived);
    {
      std::unique_lock<std::shared_timed_mutex> lock(inferQueueLock_);
      TRACE_EVENT_SCOPE_END_NAMED(traceBlock);
      inferQueue_.push(std::move(queuedRequest));
    }
  }

  // If we haven't reached maxActiveRequests kick off next request; otherwise
  // undo the increment and let a finishing request dispatch it later.
  size_t activeRequestCount = activeRequestCount_++;
  if (activeRequestCount < config_.maxActiveRequests) {
    dispatchNextRun();
    return currentRun;
  }
  activeRequestCount_--;
  return currentRun;
}
1025
1026/// Helper to report current queue size
1027void HostManager::reportCurrentQueueSize(int32_t queueSize) {
1028 statsExporterRegistry_->setCounter(
1029 kCurrentQueueSize10k, static_cast<float>(queueSize) /
1030 static_cast<float>(config_.maxQueueSize) *
1031 100000);
1032}
1033
1034/// Helper to update execution stats
1035void HostManager::updateExecutionStats(
1036 uint64_t startTime, std::unique_ptr<ExecutionContext> &context,
1037 llvm::StringRef networkName, const Error &error) {
1038 auto duration = TraceEvent::now() - startTime;
1039 auto updateCountersFn = [&](llvm::StringRef s) {
1040 statsExporterRegistry_->addTimeSeriesValue(
1041 ("glow.execution_duration_e2e." + s).str(), duration);
1042 statsExporterRegistry_->incrementCounter(
1043 ("glow.requests_processed." + s).str());
1044 if (error.peekErrorValue()) {
1045 statsExporterRegistry_->incrementCounter(
1046 ("glow.requests_failed." + s).str());
1047 } else {
1048 statsExporterRegistry_->incrementCounter(
1049 ("glow.requests_succeeded." + s).str());
1050 }
1051 };
1052 updateCountersFn(networkName);
1053 updateCountersFn("global");
1054}
1055
1056/// Helper to get the parameters in DeviceConfig from \p str. The \p str has
1057/// multiple lines, and each line with this format : "str1" : "str2".
1058static llvm::StringMap<std::string> getBackendParams(std::string &str) {
1059 llvm::StringMap<std::string> ret{};
1060 std::string s;
1061 std::istringstream f(str.c_str());
1062 while (getline(f, s, '\n')) {
1063 // Abstract the mapping from each line's string:
1064 // ""str1" : "str2"" => ret["str1"] = "str2";
1065 size_t pos1, pos2, pos3, pos4;
1066 pos1 = s.find('"');
1067 assert(pos1 != std::string::npos && "invalid string format");
1068 pos2 = s.find('"', pos1 + 1);
1069 assert(pos2 != std::string::npos && "invalid string format");
1070 pos3 = s.find('"', pos2 + 1);
1071 assert(pos3 != std::string::npos && "invalid string format");
1072 pos4 = s.find('"', pos3 + 1);
1073 assert(pos4 != std::string::npos && "invalid string format");
1074 ret[s.substr(pos1 + 1, pos2 - pos1 - 1)] =
1075 s.substr(pos3 + 1, pos4 - pos3 - 1);
1076 }
1077 return ret;
1078}
1079
1080/// If the device config file \p loadDeviceDoncfigsFile available, load \p
1081/// configs from the file. Otherwise, create \p numDevices number of devices
1082/// based on \p backendName.
1083std::vector<std::unique_ptr<runtime::DeviceConfig>>
1084runtime::generateDeviceConfigs(unsigned int numDevices,
1085 llvm::StringRef backendName, size_t memSize) {
1086 std::vector<std::unique_ptr<runtime::DeviceConfig>> configs;
1087 if (!loadDeviceConfigsFromFile(configs, memSize)) {
1088 // If there is no device config file, use numDevices to generate the
1089 // configs.
1090 std::vector<unsigned> available_device_ids;
1091 if (glow::flags::ScanDevices) {
1092 const auto &factories =
1093 FactoryRegistry<std::string, Backend>::factories();
1094 auto it = factories.find(backendName.str());
1095 if (it != factories.end()) {
1096 available_device_ids = it->second->scanDeviceIDs();
1097 }
1098 CHECK_GE(available_device_ids.size(), 0) << "No devices found.";
1099 CHECK_GE(available_device_ids.size(), numDevices)
1100 << "Not enough devices found.";
1101 }
1102 for (unsigned int i = 0; i < numDevices; ++i) {
1103 auto config = glow::make_unique<runtime::DeviceConfig>(backendName);
1104 config->setDeviceMemory(memSize);
1105 if (glow::flags::ScanDevices) {
1106 config->deviceID = available_device_ids.back();
1107 available_device_ids.pop_back();
1108 } else {
1109 config->deviceID = i;
1110 }
1111 configs.push_back(std::move(config));
1112 }
1113 }
1114 return configs;
1115}
1116
1117bool runtime::loadDeviceConfigsFromFile(
1118 std::vector<std::unique_ptr<runtime::DeviceConfig>> &configs,
1119 size_t memSize) {
1120 if (loadDeviceConfigsFileOpt.empty()) {
1121 return false;
1122 }
1123
1124 std::vector<DeviceConfigHelper> lists;
1125 lists = deserializeDeviceConfigFromYaml(loadDeviceConfigsFileOpt);
1126 for (unsigned int i = 0; i < lists.size(); ++i) {
1127 std::string configBackendName = lists[i].backendName_;
1128 std::string name = lists[i].name_;
1129 auto parameters = getBackendParams(lists[i].parameters_.str);
1130 auto config = glow::make_unique<runtime::DeviceConfig>(configBackendName,
1131 name, parameters);
1132 config->deviceID = i;
1133 config->setDeviceMemory(memSize);
1134 configs.push_back(std::move(config));
1135 }
1136 return true;
1137}
1138
// Look up the backend named \p backendName; delegates to the Provisioner.
Backend &HostManager::getBackend(llvm::StringRef backendName) const {
  return provisioner_->getBackend(backendName);
}
1142
// Get the (single) backend from the Provisioner; Expected carries the error
// case, per the Provisioner's contract.
Expected<Backend *> HostManager::getBackend() const {
  return provisioner_->getBackend();
}
1146
// Return the Provisioner's map of serialized functions, keyed by name
// (per the map's type). The caller takes ownership of the returned map.
std::unique_ptr<
    std::unordered_map<std::string, std::unique_ptr<BlockStreamBase>>>
HostManager::getAllSerializedFunctions() {
  return provisioner_->getAllSerializedFunctionsMap();
}
1152
1153HostManager *HostManagerRegistry::getHostManager() { return hostManager_; }
1154
// Record \p hostManager as the registry's current HostManager. Stores the
// raw pointer; no ownership transfer is visible here -- the caller keeps
// responsibility for the object's lifetime.
void HostManagerRegistry::registerHostManager(HostManager *hostManager) {
  hostManager_ = hostManager;
}
1158
1159std::shared_ptr<HostManagerRegistry> glow::runtime::ManagerRegistry() {
1160 static auto hostManager = std::make_shared<HostManagerRegistry>();
1161 return hostManager;
1162}
1163