1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Runtime/Provisioner/Provisioner.h"
18#include "folly/String.h"
19#include "glow/Backend/BackendUtils.h"
20#include "glow/Backend/CompiledFunction.h"
21#include "glow/Flags/Flags.h"
22#include "glow/Graph/Graph.h"
23#include "glow/Runtime/DeferredWeightLoader.h"
24#include "glow/Support/Debug.h"
25
26#include "llvm/Support/FileSystem.h"
27#include "llvm/Support/FormatVariadic.h"
28
29#include <folly/dynamic.h>
30#include <future>
31#include <map>
32#include <mutex>
33#include <queue>
34#include <set>
35#include <vector>
36
37using namespace glow;
38using namespace runtime;
39
namespace {
/// Builds the device-visible name for the \p count-th replica of the
/// function \p name. Replicas share one compiled function but must be
/// registered on the device under distinct names.
/// Takes \p name by const reference to avoid copying the string on each call.
std::string getReplicatedName(const std::string &name, unsigned count) {
  return name + "_replicated" + std::to_string(count);
}
} // namespace
45
46namespace {
47// STL sorting algorithm cannot inline predicate if it got provided as a regular
48// function.
49// Template instantiation expands std::sort with predicate type as
50// (bool)(const std::pair<DeviceIDTy, uint64_t> &,
51// const std::pair<DeviceIDTy, uint64_t> &).
52// It means any regular function with the above signature will match
53// the template instantiation, and compiler cannot inline the code of
54// one of the possible functions.
55// Declaring lambda, which has a unique type regardless its signature,
56// forces compiler to instantiate the template with a provided unique type and
57// correspondently compiler can inline the lambda code.
58auto sortMostMemory = [](const std::pair<DeviceIDTy, uint64_t> &a,
59 const std::pair<DeviceIDTy, uint64_t> &b) -> bool {
60 return a.second > b.second;
61};
62} // namespace
63
64Provisioner::Provisioner(DeviceManagerMapTy &devices) {
65 unsigned deviceMapping{0};
66 for (auto &device : devices) {
67 devices_.push_back(device.second.get());
68 deviceMappings_.push_back(deviceMapping++);
69 auto backendName = device.second->getBackendName().str();
70 if (backends_.find(backendName) == backends_.end()) {
71 std::unique_ptr<Backend> newBackend(createBackend(backendName));
72 backends_.emplace(std::string(backendName), std::move(newBackend));
73 }
74 }
75}
76
/// Verifies that none of the partitions in \p networks are already being
/// provisioned by another thread. On success every partition name is inserted
/// into activeFunctions_ and appended to \p localActiveNames so the caller
/// can release them later. On collision, any names this call already claimed
/// are rolled back before returning RUNTIME_NET_BUSY.
Error Provisioner::checkActiveNetworks(
    const DAGListTy &networks, std::vector<std::string> &localActiveNames) {

  std::lock_guard<std::mutex> networkLock(functionsLock_);
  for (auto &network : networks) {
#if FACEBOOK_INTERNAL
    LOG(INFO) << "Checking for active networks when adding: "
              << network.root->name;
#endif
    for (auto &node : network.nodes) {
      // Check to see if another thread is actively working on the same
      // networks.
      if (activeFunctions_.find(node->name) != activeFunctions_.end()) {
        // Collision: undo the claims made earlier in this same call so the
        // other thread's provisioning is unaffected.
        for (auto &name : localActiveNames) {
          activeFunctions_.erase(name);
        }
        return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_NET_BUSY,
                        llvm::formatv("Cannot add the network {0}, as it is "
                                      "currently being provisioned.",
                                      node->name)
                            .str());
      }
#if FACEBOOK_INTERNAL
      LOG(INFO) << "Adding partition name: " << node->name
                << " to activeFunctions_";
#endif
      localActiveNames.push_back(node->name);
      activeFunctions_.insert(node->name);
    }
  }
  return Error::success();
}
109
110std::map<DeviceIDTy, std::vector<DAGNode *>>
111Provisioner::generateLogicalDevices(const DAGListTy &networks) {
112 // For each network visit all the partitions (nodes) and add the node to each
113 // logical device it is assigned to.
114 std::map<DeviceIDTy, std::vector<DAGNode *>> logicalDevices;
115 for (auto &network : networks) {
116 for (auto &node : network.nodes) {
117 for (auto logical : node->logicalDevices) {
118 auto it = logicalDevices.find(logical);
119 if (it != logicalDevices.end()) {
120 it->second.push_back(node.get());
121 } else {
122 logicalDevices.emplace(logical, std::vector<DAGNode *>{node.get()});
123 }
124 }
125 }
126 }
127 return logicalDevices;
128}
129
130/// Helper method to calculate the size of each logical device, returns a
131/// vector of deviceID size pairs sorted in descending order by size.
132static std::vector<std::pair<DeviceIDTy, uint64_t>> calculateLogicalDeviceSize(
133 const std::map<DeviceIDTy, std::vector<DAGNode *>> &devices) {
134 std::vector<std::pair<DeviceIDTy, uint64_t>> logicalDeviceSize;
135 for (auto &device : devices) {
136 uint64_t sum{0};
137 for (const auto *node : device.second) {
138 sum += node->size;
139 }
140 logicalDeviceSize.push_back(std::make_pair(device.first, sum));
141 }
142 // Sort by total size in descending order.
143 std::sort(logicalDeviceSize.begin(), logicalDeviceSize.end(), sortMostMemory);
144 return logicalDeviceSize;
145}
146
/// Assigns each logical device to a physical device with enough free memory.
/// \p logicalDeviceSize lists logical devices sorted by descending size,
/// \p deviceMemoryMap maps backend name to its physical devices with their
/// available memory, and \p logicalDevices maps logical IDs to partitions.
/// On success, each node's deviceRuntimeInfos is also populated with its
/// physical assignment. Fails if a required backend has no devices or no
/// device can fit a logical device.
Expected<std::map<DeviceIDTy, DeviceIDTy>>
Provisioner::generateDeviceAssignments(
    const std::vector<std::pair<DeviceIDTy, uint64_t>> &logicalDeviceSize,
    std::map<std::string, std::vector<std::pair<DeviceIDTy, uint64_t>>>
        &deviceMemoryMap,
    std::map<DeviceIDTy, std::vector<DAGNode *>> &logicalDevices) {
  // Generate assignments, logical DeviceID to physical DeviceID.
  std::map<DeviceIDTy, DeviceIDTy> deviceAssignment;
  // Setup iterators for each backend type, initialize them to 0.
  std::map<std::string, unsigned> positions;
  for (auto &device : deviceMemoryMap) {
    positions[device.first] = 0;
  }
  // Walk through the logical devices and assign them a physical device.
  // This approach will try to evenly spread networks across devices, we first
  // sort all devices by available space and then assign in descending order.
  // Once we reach the end we resort and start over. This goes until we are
  // unable to load a network at which point we sort one more time if the first
  // device has enough space we continue, otherwise we return an error.
  // This approach is to prevent many small networks from clumping on a single
  // device.
  for (auto logicalDevice : logicalDeviceSize) {
    // First check that the requested backend kind is available.
    auto backendName = logicalDevices[logicalDevice.first][0]->backendName;
    if (deviceMemoryMap.find(backendName) == deviceMemoryMap.end()) {
      // Backend is unavailable return an error.
      return MAKE_ERR(
          ErrorValue::ErrorCode::RUNTIME_DEVICE_NOT_FOUND,
          llvm::formatv("Cannot add the network {0}, as the requested "
                        "backend: {1} is unavailable.",
                        logicalDevices[logicalDevice.first][0]->name,
                        backendName)
              .str());
    }

    auto currentPosition = positions[backendName];
    if (deviceMemoryMap[backendName][currentPosition].second >=
        logicalDevice.second) {
      // There is enough space, assign the logical device to this physical
      // device, increment the iterator and update the available memory.
      deviceAssignment.emplace(
          logicalDevice.first,
          deviceMemoryMap[backendName][currentPosition].first);
      deviceMemoryMap[backendName][currentPosition].second -=
          logicalDevice.second;

      // Check if we are at the end of the vector of devices.
      if (currentPosition == deviceMemoryMap[backendName].size() - 1) {
        // We are at the end of the vector of devices, re-sort and reset
        // position to 0.
        std::sort(deviceMemoryMap[backendName].begin(),
                  deviceMemoryMap[backendName].end(), sortMostMemory);
        positions[backendName] = 0;
      } else {
        // Increment current position by one.
        positions[backendName] = currentPosition + 1;
      }
    } else {
      // Before we assume failure we should re-sort the list to see if the
      // current largest amount of available space is enough to fit.
      std::sort(deviceMemoryMap[backendName].begin(),
                deviceMemoryMap[backendName].end(), sortMostMemory);
      if (deviceMemoryMap[backendName][0].second >= logicalDevice.second) {
        // There's a device that still has room, assign the network here.
        deviceAssignment.emplace(logicalDevice.first,
                                 deviceMemoryMap[backendName][0].first);
        deviceMemoryMap[backendName][0].second -= logicalDevice.second;

        // Since after sorting we were able to add to device 0, set the current
        // position to 1; we modulo with the number of devices in case there is
        // only 1 device.
        currentPosition = 1 % deviceMemoryMap[backendName].size();
        positions[backendName] = currentPosition;
      } else {
        // Return an error there is insufficient space for the logical device on
        // any available device.
        return MAKE_ERR(
            ErrorValue::ErrorCode::RUNTIME_OUT_OF_DEVICE_MEMORY,
            strFormat(
                "Logical Device is too large to fit in available device "
                "memory. Largest device memory: %lu, logic device size: %lu",
                deviceMemoryMap[backendName][0].second, logicalDevice.second));
      }
    }
  }

  // Update nodes in logicalDevices with their assignments.
  for (auto &assignment : deviceAssignment) {
    for (auto &node : logicalDevices[assignment.first]) {
      node->deviceRuntimeInfos[deviceMappings_[assignment.second]] =
          DeviceRuntimeInfo();
    }
  }
  return deviceAssignment;
}
242
/// Core provisioning pipeline shared by provision() and provisionFX().
/// Steps: reserve the partition names against concurrent provisioning, group
/// partitions into logical devices, assign logical devices to physical ones,
/// compile each partition with its backend, load the compiled functions onto
/// their devices, stream in deferred weights when a loader is supplied, and
/// finally publish the compiled functions into functions_. On any failure the
/// ScopeGuard evicts networks already added to devices and releases the
/// reserved names.
Error Provisioner::provisionNetwork(std::unique_ptr<Network> network) {
  VLOG(1) << "Started provisioner";
  DAGListTy &networks = network->networks;
  Module &module = network->module;
  CompilationContext &cctx = network->cctx;
  // Check that the requested networks don't collide with the names of any other
  // networks being added.
  std::vector<std::string> localActiveNames;
  RETURN_IF_ERR(checkActiveNetworks(networks, localActiveNames));

  // Mapping from function name to its compiled function. NB: compiledFunctions
  // will hold compiled function which might be used in clean up process by
  // cleanupGuard, hence this needs to be declared before cleanupGuard. We
  // probably should clean up the compiledFunctions logic to make this more
  // intuitive.
  llvm::StringMap<std::unique_ptr<CompiledFunction>> compiledFunctions;

  // If any error happens during the provision process, we will clean up the
  // compiled networks.
  std::map<DeviceIDTy, std::vector<std::string>> addedNetworks;
  ScopeGuard cleanupGuard([&localActiveNames, &addedNetworks, this]() {
    cleanupProvision(localActiveNames, addedNetworks);
  });

  // Walk the networks and group by logicalDeviceId.
  auto logicalDevices = generateLogicalDevices(networks);

  if (cctx.backendOpts.collectConstants) {
    VLOG(1) << "Warning: collectConstants is set in a Runtime compile, "
               "ignoring it.";
  }
  if (cctx.backendOpts.backendHints.SRAMPrioritization.size() != 0 ||
      cctx.backendOpts.backendHints.executionUnits) {
    VLOG(1) << "Warning: backendHints is set in a Runtime compile, "
               "ignoring it.";
  }

  // Set collectConstants to false, this is because the DeviceManager will
  // handle moving constants to the device, this way we can eliminate one
  // copy operation.
  cctx.backendOpts.collectConstants = false;

  // Calculate the size of each logical device.
  auto logicalDeviceSize = calculateLogicalDeviceSize(logicalDevices);

  // Get available memory for all devices.
  std::vector<std::pair<DeviceIDTy, uint64_t>> deviceMemory;
  for (unsigned i = 0; i < devices_.size(); i++) {
    uint64_t availableMemory = devices_[i]->getAvailableMemory();
    deviceMemory.push_back(std::make_pair(i, availableMemory));
  }

  // Get available device memory, create a map of vectors for each backend kind
  std::map<std::string, std::vector<std::pair<DeviceIDTy, uint64_t>>>
      deviceMemoryMap;
  for (unsigned i = 0; i < devices_.size(); i++) {
    uint64_t availableMemory = devices_[i]->getAvailableMemory();

    deviceMemoryMap[devices_[i]->getBackendName().str()].push_back(
        std::make_pair(i, availableMemory));
  }

  // Sort all vectors in descending order of available memory.
  for (auto &sizes : deviceMemoryMap) {
    std::sort(sizes.second.begin(), sizes.second.end(), sortMostMemory);
  }

  // Generate assignments between physical and logical devices.
  auto deviceAssignments = generateDeviceAssignments(
      logicalDeviceSize, deviceMemoryMap, logicalDevices);

  VLOG(1) << "Before device assignment";
  // Check for errors.
  if (!deviceAssignments) {
    RETURN_ERR(deviceAssignments.takeError());
  }
  auto assignments = std::move(*deviceAssignments);

  VLOG(1) << "Before compile";

  // Stores function name and the remaining logical device count for that
  // function.
  llvm::StringMap<size_t> remainingDeviceCount;
  // Mapping from function name to its backend options.
  llvm::StringMap<BackendOptions> optsMap;

  // Compile and load.
  // This is done one logical device at a time. All functions in a logical
  // device are compiled and then added to their assigned device. If a function
  // is in multiple logical devices it is stored so that it only needs to be
  // compiled once.
  if (network->networkType == NetworkType::GLOW_NETWORK) {
    for (auto &assignment : assignments) {
      auto logicalDevice = assignment.first;
      auto physicalDevice = assignment.second;
      auto deviceBackendName = logicalDevices[logicalDevice][0]->backendName;

      if (backends_.find(deviceBackendName) == backends_.end()) {
        // Return error requested device type not found.
        return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEVICE_NOT_FOUND,
                        "Unable to find device of type: " + deviceBackendName);
      }

      // Stores all the functions in a logical device.
      std::vector<glow::Function *> functionsToCompile;
      // Stores the compiled functions that will be added to physical device.
      FunctionMapTy functionMap;

      // Collect all the functions in a logical device.
      for (auto &node : logicalDevices[logicalDevice]) {
        // If the function name exist we don't need to compile it again.
        if (optsMap.count(node->name)) {
          remainingDeviceCount[node->name] -= 1;
          continue;
        }

        auto options = cctx.backendOpts;
        options.backendHints = node->backendHints;
        // Insert all options loaded in the Partitioner alongside options
        // previously inserted, with Partitioner options taking precedence in
        // case of a collision of keys.
        for (auto &it : node->backendSpecificOpts) {
          options.backendSpecificOpts[it.first] = it.second;
        }
        std::lock_guard<std::mutex> functionsLock(functionsLock_);
        Function *function = module.getFunction(node->name);

        functionsToCompile.push_back(function);
        optsMap.insert({function->getName(), options});
        functionReplicaCount_.emplace(node->name, node->replicationCount);
        remainingDeviceCount.insert(
            {node->name, node->logicalDevices.size() - 1});
      }

      // Compile all the functions in the logical device together.
      // We add a lock here because some backends are not threadsafe (CPU
      // backend).
      std::unique_lock<std::mutex> compileLock(functionsLock_);
      auto compiledOrErr = backends_[deviceBackendName]->compileFunctions(
          functionsToCompile, optsMap);
      VLOG(1) << "After compile";
      compileLock.unlock();

      // Dump graph and logs
      for (auto *function : functionsToCompile) {
        // Note: This needs to come after compile above because compile may
        // modify the Function as well.
        if (cctx.dumpFinalGraph) {
          auto fname = strFormat(
              "%sfinal_graph_%s_%s.dot", cctx.dumpGraphPath.c_str(),
              deviceBackendName.c_str(), function->getName().str().c_str());
          LOG(INFO) << "Dumping final graph to " << fname;
          function->dumpDAG(fname);
          // print stats of node
          std::map<std::string, int> opCounter;
          for (const auto &node : function->getNodes()) {
            opCounter[node.getKindName()]++;
          }
          std::ostringstream ss;
          ss << "Dump of Node stats for Function:\n";
          ss << folly::stringPrintf("%30s %13s \n", "NodeKind", "Count");
          for (const auto &p : opCounter) {
            ss << folly::stringPrintf("%30s %13d \n", p.first.c_str(),
                                      p.second);
          }
          LOG(INFO) << ss.str();
        }

        if (glow::flags::DumpCompilationLog) {
          llvm::SmallString<64> path;
          std::string prefix =
              llvm::formatv("{0}-{1}", cctx.compilationLogPrefix,
                            function->getName())
                  .str();
          auto tempFileRes =
              llvm::sys::fs::createTemporaryFile(prefix, "log", path);
          if (tempFileRes.value() != 0) {
            LOG(ERROR)
                << "Failed to create temp file for Glow compilation log: "
                << tempFileRes;
          }

          function->getLogContext()->dumpLog(path);
        }
      }

      // If err return it, else store compiled functions into compiledFunctions.
      if (!compiledOrErr) {
        RETURN_ERR(compiledOrErr.takeError());
      }
      auto compiled = std::move(*compiledOrErr);
      for (auto &compiledFunction : compiled) {

        // Deserialize compiled function from cctx.nameToFunctions
        if (cctx.backendOpts.useDeserialize) {
          std::string name = compiledFunction.first().str();
          if (cctx.nameToFunctions.find(name) == cctx.nameToFunctions.end()) {
            return MAKE_ERR(
                ErrorValue::ErrorCode::UNKNOWN,
                "Cannot find compiled function when deserializing " + name);
          }
          RETURN_IF_ERR(compiledFunction.second->deserialize(
              *(cctx.nameToFunctions.find(name)->second)));
        }
        compiledFunctions.try_emplace(compiledFunction.first(),
                                      std::move(compiledFunction.second));
      }
      // Construct functionMap for physical device.
      for (auto &node : logicalDevices[logicalDevice]) {
        RETURN_ERR_IF_NOT(compiledFunctions.count(node->name),
                          "Can't find corresponding compiled function " +
                              node->name);

        auto *compiledFunction = compiledFunctions[node->name].get();
        functionMap.emplace(node->name, compiledFunction);

        // Replicas share the same compiled function under distinct names.
        for (unsigned i = 1; i < node->replicationCount; i++) {
          auto replicatedName = getReplicatedName(node->name, i);
          functionMap.emplace(replicatedName, compiledFunction);
        }

        // Dump backend-specific IR
        if (glow::flags::DumpBackendSpecificIRJSON) {
          compiledFunction->dumpJSON(strFormat("%sbackend_specific_ir_%s.json",
                                               cctx.dumpGraphPath.c_str(),
                                               node->name.c_str()));
        }

        node->runtimeBundle = glow::make_unique<RuntimeBundle>(
            compiledFunction->getRuntimeBundle());
      }

      // Now that the functions are compiled add them to their assigned device
      // then cleanup. The addNetwork callback is awaited synchronously via the
      // promise/future pair.
      std::promise<void> addPromise;
      auto ready = addPromise.get_future();
      std::unique_ptr<Error> addErr;
      devices_[physicalDevice]->addNetwork(
          &module, functionMap,
          [&addErr, &addPromise](const Module *, Error err) {
            addErr = glow::make_unique<Error>(std::move(err));
            addPromise.set_value();
          });
      ready.wait();
      DCHECK_NOTNULL(addErr.get());
      if (*addErr.get()) {
        return std::move(*addErr.get());
      }

      // Add networks successfully loaded on device to addedNetworks, this way
      // if we fail later we can evict them.
      for (const auto &func : functionMap) {
        addedNetworks[physicalDevice].push_back(func.first);
      }
      VLOG(1) << "Added networks";

      // Free up memory no longer needed by the compiledFunction.
      for (auto &node : logicalDevices[logicalDevice]) {
        // If the compiled function still needs to be added to other device,
        // don't free the resources.
        if (remainingDeviceCount[node->name] > 0) {
          continue;
        }

        // Free compilation resources. This need to be done after add network
        // and before move on to next logical device. If
        // DisableFreeCompilationResource is true, we will not free it here.
        // This is used in scenarios like model serialization.
        auto &funtionPtr = compiledFunctions[node->name];
        if (!glow::flags::DisableFreeCompilationResource) {
          funtionPtr->freeCompilationResources();
        }

        // Move compiled functions from compiledFunctions to functions_.
        {
          std::lock_guard<std::mutex> functionsLock(functionsLock_);
          functions_.emplace(node->name, std::move(funtionPtr));
        }

        compiledFunctions.erase(node->name);
      }
    }
  } else if (network->networkType == NetworkType::FX_NETWORK) {
#if FACEBOOK_INTERNAL
    // Container for duplicated functions and map tracking remaining installs
    // for a duplicated function.
    std::map<std::string, std::unique_ptr<CompiledFunction>>
        duplicatedFunctions;
    std::map<DAGNode *, unsigned> remainingDuplications;
    for (auto &assignment : assignments) {
      auto logicalDevice = assignment.first;
      auto physicalDevice = assignment.second;
      auto deviceBackendName = logicalDevices[logicalDevice][0]->backendName;
      FunctionMapTy functionMap;
      // Container for the compiledFunctions for this logicalDevice.
      std::map<std::string, std::unique_ptr<CompiledFunction>>
          compiledFunctions;

      for (auto &node : logicalDevices[logicalDevice]) {
        // Check if this is a duplicated function that has already been
        // compiled.
        if (duplicatedFunctions.find(node->name) != duplicatedFunctions.end()) {
          functionMap.emplace(node->name,
                              duplicatedFunctions[node->name].get());
          remainingDuplications[node] -= 1;
        } else {
          // Compile and add to function map.
          auto options = cctx.backendOpts;
          options.backendHints = node->backendHints;
          // Insert all options loaded in the Partitioner alongside options
          // previously inserted, with Partitioner options taking precedence in
          // case of a collision of keys.
          for (auto &it : node->backendSpecificOpts) {
            options.backendSpecificOpts[it.first] = it.second;
          }
          if (backends_.find(deviceBackendName) == backends_.end()) {
            // Return error requested device type not found.
            return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEVICE_NOT_FOUND,
                            "Unable to find device of type: " +
                                deviceBackendName);
          }
          auto fxNetwork = static_cast<FXNetwork *>(network.get());
          auto compiledOrErr = backends_[deviceBackendName]->compileFX(
              fxNetwork->FXIR, node->name, fxNetwork->constants, options,
              &module);

          // Check to see if an error was encountered while compiling.
          if (!compiledOrErr) {
            // If an error occurred return the error.
            RETURN_ERR(compiledOrErr.takeError());
          }
          auto compiled = std::move(*compiledOrErr);

          node->runtimeBundle =
              glow::make_unique<RuntimeBundle>(compiled->getRuntimeBundle());

          functionMap.emplace(node->name, compiled.get());
          // If this function is in more than one logical device store it for
          // reuse.
          if (node->logicalDevices.size() > 1) {
            duplicatedFunctions.emplace(node->name, std::move(compiled));
            remainingDuplications[node] = node->logicalDevices.size() - 1;
          } else {
            compiledFunctions.emplace(node->name, std::move(compiled));
          }
        }
      }
      VLOG(1) << "After compile";

      // Now that the functions are compiled add them to their assigned device
      // then cleanup.
      std::promise<void> addPromise;
      auto ready = addPromise.get_future();
      std::unique_ptr<Error> addErr;
      devices_[physicalDevice]->addNetwork(
          &module, functionMap,
          [&addErr, &addPromise](const Module *, Error err) {
            addErr = glow::make_unique<Error>(std::move(err));
            addPromise.set_value();
          });
      ready.wait();
      DCHECK_NOTNULL(addErr.get());
      if (*addErr.get()) {
        return std::move(*addErr.get());
      }
      // Add networks successfully loaded on device to addedNetworks, this way
      // if we fail later we can evict them.
      for (auto &node : logicalDevices[logicalDevice]) {
        addedNetworks[physicalDevice].push_back(node->name);
      }
      VLOG(1) << "Added networks";

      // Free up memory no longer needed by the compiledFunction.
      for (auto &func : compiledFunctions) {
        func.second->freeCompilationResources();
      }
      {
        // Move compiled functions from compiledFunctions to functions_.
        std::lock_guard<std::mutex> functionsLock(functionsLock_);
        for (auto &func : compiledFunctions) {
          functions_.emplace(func.first, std::move(func.second));
        }
        // Check if any of the duplicated functions can also be moved.
        for (auto iter = remainingDuplications.begin();
             iter != remainingDuplications.end();) {
          const auto &func = *iter;
          if (func.second == 0) {
            duplicatedFunctions[func.first->name]->freeCompilationResources();
            functions_.emplace(
                func.first->name,
                std::move(duplicatedFunctions[func.first->name]));
            duplicatedFunctions.erase(func.first->name);
            iter = remainingDuplications.erase(iter);
          } else {
            ++iter;
          }
        }
      }
    }
#endif
  }
  RETURN_ERR_IF_NOT(compiledFunctions.empty(),
                    "compiledFunctions should be empty because all compiled "
                    "functions should be moved to Provisioner::function_");

  // Map from Placeholder* to DeviceManager, this is used for deferred weight
  // loading.
  std::unordered_map<Placeholder *, std::vector<unsigned>>
      placeholderToDeviceManager;
  if (cctx.deferredWeightLoader) {
    // Populate placeholderToDeviceManager map.
    for (auto &assignment : assignments) {
      for (const auto &node : logicalDevices[assignment.first]) {
        auto symbolTable = node->runtimeBundle->getSymbolTable();
        for (auto info : symbolTable) {
          if (info.second.symbolCategory ==
              glow::runtime::SymbolCategory::Placeholder) {
            auto PH = module.getPlaceholderByNameSlow(info.first);
            if (PH->isStatic()) {
              placeholderToDeviceManager[PH].push_back(assignment.second);
            }
          }
        }
      }
    }
  } else {
    // Make sure there are no static placeholders.
    for (auto PH : module.getPlaceholders()) {
      if (PH->isStatic()) {
        // NOTE(review): "Placholder" typo below is in a user-visible error
        // message; fixing it would change runtime output, so leave for a
        // dedicated change.
        return MAKE_ERR(
            ErrorValue::ErrorCode::RUNTIME_ERROR,
            llvm::formatv("Error Placholder: {0} is marked as static but no "
                          "deferredWeightLoader is provided.",
                          PH->getName())
                .str());
        ;
      }
    }
  }
  // If a deferredWeightLoader is provided, create a deferredWeightLoader and
  // load deferred weights.
  if (cctx.deferredWeightLoader) {
    const size_t totalNumDeferredWeights = placeholderToDeviceManager.size();
    LOG(INFO) << "Loading " << totalNumDeferredWeights << " deferred weights";

    auto startTime = std::chrono::steady_clock::now();
    auto loader = cctx.deferredWeightLoader;
    // Load the first weight.
    auto err = loader->loadNextWeight();
    if (err) {
      auto val = takeErrorValue(std::move(err));
      std::string msg = val->logToString();
      return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEFERRED_WEIGHT_ERROR,
                      msg);
    }
    std::string weightName = loader->getName();
    // Load weights while there are weights to be loaded.
    unsigned int weightCount = 0;
    while (weightName != "") {
      LOG(INFO) << "Loading deferred weight (" << ++weightCount << " / "
                << totalNumDeferredWeights << "): " << weightName;
      const auto PH = module.getPlaceholderByNameSlow(weightName);
      if (!PH) {
        return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEFERRED_WEIGHT_ERROR,
                        llvm::formatv("Error loading deferred weight. Name: "
                                      "{0} not found in module.",
                                      weightName)
                            .str());
      }
      // Convert the weight if needed.
      auto newTy = PH->getType();
      auto weight = loader->getTensor();
      auto oldKind = weight->getElementType();
      // Ensure we are working with a static PH.
      assert(PH->isStatic());
      if (!weight->getType().isEqual(newTy)) {
        ElemKind newK = newTy->getElementType();

        if (!isQuantizedElemKind(oldKind) && isQuantizedElemKind(newK)) {
          Tensor QT = quantization::quantizeTensor(
              *weight, {newTy->getScale(), newTy->getOffset()}, newK);
          weight->assign(&QT);
        } else {
          weight->convertToType(newK);
        }
      }
      // Transfer weight to all devices needed. Transfers run concurrently;
      // each completion is signaled through its own promise/future pair.
      std::list<Error> errors;
      std::list<std::future<void>> futures;
      for (const auto &device : placeholderToDeviceManager[PH]) {
        std::promise<void> transferPromise;
        errors.emplace_back(Error::empty());
        futures.emplace_back(transferPromise.get_future());
        devices_[device]->transferStaticPlaceholderToDevice(
            PH, weight,
            [&transferPromise, &error = errors.back()](Error err) mutable {
              error = std::move(err);
              transferPromise.set_value();
            });
      }

      for (auto &done : futures) {
        done.get();
      }

      for (auto &error : errors) {
        RETURN_IF_ERR(error);
      }

      err = loader->loadNextWeight();
      if (err) {
        auto val = takeErrorValue(std::move(err));
        std::string msg = val->logToString();
        return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEFERRED_WEIGHT_ERROR,
                        msg);
      }
      weightName = loader->getName();
      // Remove PH from map, this way we can know that we've added all static
      // PH's
      placeholderToDeviceManager.erase(PH);
    }
    if (placeholderToDeviceManager.size()) {
      return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEFERRED_WEIGHT_ERROR,
                      "Error not all static placeholders were initialized.");
    }

    std::chrono::duration<double> duration =
        std::chrono::steady_clock::now() - startTime;
    LOG(INFO) << "Done loading deferred weights in " << duration.count()
              << " seconds";
  }
  // Init alternate name states.
  for (auto &network : networks) {
    for (auto &node : network.nodes) {
      node->initAlternateState();
    }
  }

  // Success: disarm the failure-cleanup guard and release the reserved names
  // without evicting anything.
  cleanupGuard.dismiss();
  cleanupProvision(localActiveNames, {}, false);
  return Error::success();
};
785
786Error Provisioner::provision(DAGListTy &networks, Module &module,
787 CompilationContext &cctx) {
788 return provisionNetwork(
789 glow::make_unique<GlowNetwork>(networks, module, cctx));
790};
791
#if FACEBOOK_INTERNAL
/// Compiles and loads the pre-partitioned \p networks onto devices via the
/// FX IR path (FX_NETWORK): partitions are compiled from \p FXIR with the
/// weights in \p constants rather than from Functions in \p module.
/// \returns success, or the first error produced by the provisioning
/// pipeline. Fix: removed the stray semicolon after the function body.
Error Provisioner::provisionFX(DAGListTy &networks, Module &module,
                               const FXFunction &FXIR,
                               const llvm::StringMap<const void *> &constants,
                               CompilationContext &cctx) {
  return provisionNetwork(
      glow::make_unique<FXNetwork>(networks, module, cctx, FXIR, constants));
}
#endif
801
802Backend &Provisioner::getBackend(llvm::StringRef backendName) const {
803 assert(backends_.count(backendName.str()) &&
804 "No backend created by specified name.");
805 return *backends_.at(backendName.str());
806}
807
808Expected<Backend *> Provisioner::getBackend() const {
809 RETURN_ERR_IF_NOT(
810 backends_.size() == 1,
811 strFormat("Expected exactly 1 backend to be found but instead found %zu",
812 backends_.size()));
813 return backends_.begin()->second.get();
814}
815
816Error Provisioner::removeFunction(llvm::StringRef name) {
817 std::lock_guard<std::mutex> functionsLock(functionsLock_);
818 auto it = activeFunctions_.find(name.str());
819 if (it != activeFunctions_.end()) {
820 return MAKE_ERR(
821 ErrorValue::ErrorCode::RUNTIME_NET_BUSY,
822 llvm::formatv("Could not remove network: {0} as it is currently "
823 "being provisioned.",
824 name)
825 .str());
826 }
827 functions_.erase(name.str());
828 return Error::success();
829}
830
831Error Provisioner::evictFunction(llvm::StringRef name, DeviceManager *device,
832 unsigned replicaCount) {
833 std::promise<void> evictPromise;
834 OneErrOnly evictErr;
835 auto done = evictPromise.get_future();
836 device->evictNetwork(name.str(),
837 [&evictPromise, &evictErr](std::string, Error err) {
838 evictErr.set(std::move(err));
839 evictPromise.set_value();
840 });
841 done.get();
842
843 // If we are evict a main function, evict its replications as well.
844 if (replicaCount) {
845 for (unsigned i = 1; i < replicaCount; i++) {
846 auto replicaName = getReplicatedName(name.str(), i);
847 std::promise<void> evictReplicaPromise;
848 auto done = evictReplicaPromise.get_future();
849 device->evictNetwork(replicaName, [&evictReplicaPromise,
850 &evictErr](std::string, Error err) {
851 evictErr.set(std::move(err));
852 evictReplicaPromise.set_value();
853 });
854
855 done.get();
856 }
857 }
858
859 return evictErr.get();
860}
861
862void Provisioner::cleanupProvision(
863 llvm::ArrayRef<std::string> names,
864 std::map<DeviceIDTy, std::vector<std::string>> const
865 &currentNetworkResidency,
866 bool failure) {
867 std::lock_guard<std::mutex> functionLock(functionsLock_);
868 if (failure) {
869 // Remove any partitions added to devices.
870 for (auto &device : currentNetworkResidency) {
871 for (auto &network : device.second) {
872#if FACEBOOK_INTERNAL
873 LOG(INFO) << "Removing network " << network << " from device "
874 << device.first;
875#endif
876 auto replicaCountIdx = functionReplicaCount_.find(network);
877 unsigned replicaCount = 0;
878 if (replicaCountIdx != functionReplicaCount_.end()) {
879 replicaCount = replicaCountIdx->second;
880 }
881 Error evictErr =
882 evictFunction(network, devices_[device.first], replicaCount);
883 if (evictErr) {
884 LOG(ERROR) << "Unable to evict network: " << network << "\n";
885 }
886 }
887 }
888 }
889 // After we've removed the functions from the deviceManagers now free the
890 // compiledFunctions. We free after eviction to ensure the any reference the
891 // DeviceManager has to the compiledFunctions stays valid until after
892 // eviction.
893 for (auto &name : names) {
894 activeFunctions_.erase(name);
895 if (failure) {
896 // Remove any functions added before the failure.
897 functions_.erase(name);
898 }
899 }
900}
901
902void Provisioner::cleanUpSerializedFunctionMap() {
903 serializedFunctionMap_.clear();
904}
905
/// Extracts the hash suffix from a function's \p name: the text after the
/// last underscore. If the name contains no underscore, the whole name is
/// returned (find_last_of yields npos, and npos + 1 wraps to 0).
/// Takes \p name by const reference to avoid a copy; uses the char overload
/// of find_last_of since we search for a single character.
std::string getNameHash(const std::string &name) {
  return name.substr(name.find_last_of('_') + 1);
}
910
911std::unique_ptr<
912 std::unordered_map<std::string, std::unique_ptr<BlockStreamBase>>>
913Provisioner::getAllSerializedFunctionsMap() {
914 // Assume all functions in functions_ are using the same backend
915 cleanUpSerializedFunctionMap();
916 for (auto &kv : functions_) {
917 std::string name = kv.first;
918 auto data = kv.second->serialize();
919 if (data != nullptr) {
920 serializedFunctionMap_.emplace(
921 std::make_pair(getNameHash(name), std::move(data)));
922 }
923 }
924 return std::make_unique<
925 std::unordered_map<std::string, std::unique_ptr<BlockStreamBase>>>(
926 std::move(serializedFunctionMap_));
927}
928