1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Runtime/Provisioner/Provisioner.h" |
18 | #include "folly/String.h" |
19 | #include "glow/Backend/BackendUtils.h" |
20 | #include "glow/Backend/CompiledFunction.h" |
21 | #include "glow/Flags/Flags.h" |
22 | #include "glow/Graph/Graph.h" |
23 | #include "glow/Runtime/DeferredWeightLoader.h" |
24 | #include "glow/Support/Debug.h" |
25 | |
26 | #include "llvm/Support/FileSystem.h" |
27 | #include "llvm/Support/FormatVariadic.h" |
28 | |
29 | #include <folly/dynamic.h> |
30 | #include <future> |
31 | #include <map> |
32 | #include <mutex> |
33 | #include <queue> |
34 | #include <set> |
35 | #include <vector> |
36 | |
37 | using namespace glow; |
38 | using namespace runtime; |
39 | |
namespace {
/// \returns the name given to replica number \p count of the function called
/// \p name, formed as "<name>_replicated<count>".
std::string getReplicatedName(std::string name, unsigned count) {
  name += "_replicated";
  name += std::to_string(count);
  return name;
}
} // namespace
45 | |
46 | namespace { |
47 | // STL sorting algorithm cannot inline predicate if it got provided as a regular |
48 | // function. |
49 | // Template instantiation expands std::sort with predicate type as |
50 | // (bool)(const std::pair<DeviceIDTy, uint64_t> &, |
51 | // const std::pair<DeviceIDTy, uint64_t> &). |
52 | // It means any regular function with the above signature will match |
53 | // the template instantiation, and compiler cannot inline the code of |
54 | // one of the possible functions. |
55 | // Declaring lambda, which has a unique type regardless its signature, |
56 | // forces compiler to instantiate the template with a provided unique type and |
57 | // correspondently compiler can inline the lambda code. |
58 | auto sortMostMemory = [](const std::pair<DeviceIDTy, uint64_t> &a, |
59 | const std::pair<DeviceIDTy, uint64_t> &b) -> bool { |
60 | return a.second > b.second; |
61 | }; |
62 | } // namespace |
63 | |
64 | Provisioner::Provisioner(DeviceManagerMapTy &devices) { |
65 | unsigned deviceMapping{0}; |
66 | for (auto &device : devices) { |
67 | devices_.push_back(device.second.get()); |
68 | deviceMappings_.push_back(deviceMapping++); |
69 | auto backendName = device.second->getBackendName().str(); |
70 | if (backends_.find(backendName) == backends_.end()) { |
71 | std::unique_ptr<Backend> newBackend(createBackend(backendName)); |
72 | backends_.emplace(std::string(backendName), std::move(newBackend)); |
73 | } |
74 | } |
75 | } |
76 | |
77 | Error Provisioner::checkActiveNetworks( |
78 | const DAGListTy &networks, std::vector<std::string> &localActiveNames) { |
79 | |
80 | std::lock_guard<std::mutex> networkLock(functionsLock_); |
81 | for (auto &network : networks) { |
82 | #if FACEBOOK_INTERNAL |
83 | LOG(INFO) << "Checking for active networks when adding: " |
84 | << network.root->name; |
85 | #endif |
86 | for (auto &node : network.nodes) { |
87 | // Check to see if another thread is actively working on the same |
88 | // networks. |
89 | if (activeFunctions_.find(node->name) != activeFunctions_.end()) { |
90 | for (auto &name : localActiveNames) { |
91 | activeFunctions_.erase(name); |
92 | } |
93 | return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_NET_BUSY, |
94 | llvm::formatv("Cannot add the network {0}, as it is " |
95 | "currently being provisioned." , |
96 | node->name) |
97 | .str()); |
98 | } |
99 | #if FACEBOOK_INTERNAL |
100 | LOG(INFO) << "Adding partition name: " << node->name |
101 | << " to activeFunctions_" ; |
102 | #endif |
103 | localActiveNames.push_back(node->name); |
104 | activeFunctions_.insert(node->name); |
105 | } |
106 | } |
107 | return Error::success(); |
108 | } |
109 | |
110 | std::map<DeviceIDTy, std::vector<DAGNode *>> |
111 | Provisioner::generateLogicalDevices(const DAGListTy &networks) { |
112 | // For each network visit all the partitions (nodes) and add the node to each |
113 | // logical device it is assigned to. |
114 | std::map<DeviceIDTy, std::vector<DAGNode *>> logicalDevices; |
115 | for (auto &network : networks) { |
116 | for (auto &node : network.nodes) { |
117 | for (auto logical : node->logicalDevices) { |
118 | auto it = logicalDevices.find(logical); |
119 | if (it != logicalDevices.end()) { |
120 | it->second.push_back(node.get()); |
121 | } else { |
122 | logicalDevices.emplace(logical, std::vector<DAGNode *>{node.get()}); |
123 | } |
124 | } |
125 | } |
126 | } |
127 | return logicalDevices; |
128 | } |
129 | |
130 | /// Helper method to calculate the size of each logical device, returns a |
131 | /// vector of deviceID size pairs sorted in descending order by size. |
132 | static std::vector<std::pair<DeviceIDTy, uint64_t>> calculateLogicalDeviceSize( |
133 | const std::map<DeviceIDTy, std::vector<DAGNode *>> &devices) { |
134 | std::vector<std::pair<DeviceIDTy, uint64_t>> logicalDeviceSize; |
135 | for (auto &device : devices) { |
136 | uint64_t sum{0}; |
137 | for (const auto *node : device.second) { |
138 | sum += node->size; |
139 | } |
140 | logicalDeviceSize.push_back(std::make_pair(device.first, sum)); |
141 | } |
142 | // Sort by total size in descending order. |
143 | std::sort(logicalDeviceSize.begin(), logicalDeviceSize.end(), sortMostMemory); |
144 | return logicalDeviceSize; |
145 | } |
146 | |
147 | Expected<std::map<DeviceIDTy, DeviceIDTy>> |
148 | Provisioner::generateDeviceAssignments( |
149 | const std::vector<std::pair<DeviceIDTy, uint64_t>> &logicalDeviceSize, |
150 | std::map<std::string, std::vector<std::pair<DeviceIDTy, uint64_t>>> |
151 | &deviceMemoryMap, |
152 | std::map<DeviceIDTy, std::vector<DAGNode *>> &logicalDevices) { |
153 | // Generate assignments, logical DeviceID to physical DeviceID. |
154 | std::map<DeviceIDTy, DeviceIDTy> deviceAssignment; |
155 | // Setup iterators for each backend type, intialize them to 0. |
156 | std::map<std::string, unsigned> positions; |
157 | for (auto &device : deviceMemoryMap) { |
158 | positions[device.first] = 0; |
159 | } |
160 | // Walk through the logical devices and assign them a physical device. |
161 | // This approach will try to evenly spread networks across devices, we first |
162 | // sort all devices by available space and then assign in descending order. |
163 | // Once we reach the end we resort and start over. This goes until we are |
164 | // unable to load a network at which point we sort one more time if the first |
165 | // device has enough space we continue, otherwise we return an error. |
166 | // This approach is to prevent many small networks from clumping on a single |
167 | // device. |
168 | for (auto logicalDevice : logicalDeviceSize) { |
169 | // First check that there the requested backend kind is available. |
170 | auto backendName = logicalDevices[logicalDevice.first][0]->backendName; |
171 | if (deviceMemoryMap.find(backendName) == deviceMemoryMap.end()) { |
172 | // Backend is unavailable return an error. |
173 | return MAKE_ERR( |
174 | ErrorValue::ErrorCode::RUNTIME_DEVICE_NOT_FOUND, |
175 | llvm::formatv("Cannot add the network {0}, as the requested " |
176 | "backend: {1} is unavailable." , |
177 | logicalDevices[logicalDevice.first][0]->name, |
178 | backendName) |
179 | .str()); |
180 | } |
181 | |
182 | auto currentPosition = positions[backendName]; |
183 | if (deviceMemoryMap[backendName][currentPosition].second >= |
184 | logicalDevice.second) { |
185 | // There is enough space, assign the logical device to this physical |
186 | // device, increment the iterator and update the available memory. |
187 | deviceAssignment.emplace( |
188 | logicalDevice.first, |
189 | deviceMemoryMap[backendName][currentPosition].first); |
190 | deviceMemoryMap[backendName][currentPosition].second -= |
191 | logicalDevice.second; |
192 | |
193 | // Check if we are at the end of the vector of devices. |
194 | if (currentPosition == deviceMemoryMap[backendName].size() - 1) { |
195 | // We are at the end of the vector of devices, re-sort and reset |
196 | // position to 0. |
197 | std::sort(deviceMemoryMap[backendName].begin(), |
198 | deviceMemoryMap[backendName].end(), sortMostMemory); |
199 | positions[backendName] = 0; |
200 | } else { |
201 | // Increment current position by one. |
202 | positions[backendName] = currentPosition + 1; |
203 | } |
204 | } else { |
205 | // Before we assume failure we should re-sort the list to see if the |
206 | // current largest amount of available space is enough to fit. |
207 | std::sort(deviceMemoryMap[backendName].begin(), |
208 | deviceMemoryMap[backendName].end(), sortMostMemory); |
209 | if (deviceMemoryMap[backendName][0].second >= logicalDevice.second) { |
210 | // There's a device that still has room, assign the network here. |
211 | deviceAssignment.emplace(logicalDevice.first, |
212 | deviceMemoryMap[backendName][0].first); |
213 | deviceMemoryMap[backendName][0].second -= logicalDevice.second; |
214 | |
215 | // Since after sorting we were abel to add to device 0 set the current |
216 | // position 1 we modulo with the number of devices in case there is only |
217 | // 1 device. |
218 | currentPosition = 1 % deviceMemoryMap[backendName].size(); |
219 | positions[backendName] = currentPosition; |
220 | } else { |
221 | // Return an error there is insufficient space for the logical device on |
222 | // any available device. |
223 | return MAKE_ERR( |
224 | ErrorValue::ErrorCode::RUNTIME_OUT_OF_DEVICE_MEMORY, |
225 | strFormat( |
226 | "Logical Device is too large to fit in available device " |
227 | "memory. Largest device memory: %lu, logic device size: %lu" , |
228 | deviceMemoryMap[backendName][0].second, logicalDevice.second)); |
229 | } |
230 | } |
231 | } |
232 | |
233 | // Update nodes in logicalDevices with their assignments. |
234 | for (auto &assignment : deviceAssignment) { |
235 | for (auto &node : logicalDevices[assignment.first]) { |
236 | node->deviceRuntimeInfos[deviceMappings_[assignment.second]] = |
237 | DeviceRuntimeInfo(); |
238 | } |
239 | } |
240 | return deviceAssignment; |
241 | } |
242 | |
243 | Error Provisioner::provisionNetwork(std::unique_ptr<Network> network) { |
244 | VLOG(1) << "Started provisioner" ; |
245 | DAGListTy &networks = network->networks; |
246 | Module &module = network->module; |
247 | CompilationContext &cctx = network->cctx; |
248 | // Check that the requested networks don't collide with the names of any other |
249 | // networks being added. |
250 | std::vector<std::string> localActiveNames; |
251 | RETURN_IF_ERR(checkActiveNetworks(networks, localActiveNames)); |
252 | |
253 | // Mapping from function name to its compiled function. NB: compiledFunctions |
254 | // will hold compiled function which might be used in clean up process by |
255 | // cleanupGuard, hence this needs to be declared before cleanupGuard. We |
256 | // probably should clean up the compiledFunctions logic to make this more |
257 | // intuitive. |
258 | llvm::StringMap<std::unique_ptr<CompiledFunction>> compiledFunctions; |
259 | |
260 | // If any error happens during the provison process, we will clean up the |
261 | // compiled networks. |
262 | std::map<DeviceIDTy, std::vector<std::string>> addedNetworks; |
263 | ScopeGuard cleanupGuard([&localActiveNames, &addedNetworks, this]() { |
264 | cleanupProvision(localActiveNames, addedNetworks); |
265 | }); |
266 | |
267 | // Walk the networks and group by logicalDeviceId. |
268 | auto logicalDevices = generateLogicalDevices(networks); |
269 | |
270 | if (cctx.backendOpts.collectConstants) { |
271 | VLOG(1) << "Warning: collectConstants is set in a Runtime compile, " |
272 | "ignoring it." ; |
273 | } |
274 | if (cctx.backendOpts.backendHints.SRAMPrioritization.size() != 0 || |
275 | cctx.backendOpts.backendHints.executionUnits) { |
276 | VLOG(1) << "Warning: backendHints is set in a Runtime compile, " |
277 | "ignoring it." ; |
278 | } |
279 | |
280 | // Set collectConstants to false, this is because the DeviceManager will |
281 | // handle moving constants to the device, this way we can eliminate one |
282 | // copy operation. |
283 | cctx.backendOpts.collectConstants = false; |
284 | |
285 | // Calculate the size of each logical device. |
286 | auto logicalDeviceSize = calculateLogicalDeviceSize(logicalDevices); |
287 | |
288 | // Get available memory for all devices. |
289 | std::vector<std::pair<DeviceIDTy, uint64_t>> deviceMemory; |
290 | for (unsigned i = 0; i < devices_.size(); i++) { |
291 | uint64_t availableMemory = devices_[i]->getAvailableMemory(); |
292 | deviceMemory.push_back(std::make_pair(i, availableMemory)); |
293 | } |
294 | |
295 | // Get available device memory, create a map of vectors for each backend kind |
296 | std::map<std::string, std::vector<std::pair<DeviceIDTy, uint64_t>>> |
297 | deviceMemoryMap; |
298 | for (unsigned i = 0; i < devices_.size(); i++) { |
299 | uint64_t availableMemory = devices_[i]->getAvailableMemory(); |
300 | |
301 | deviceMemoryMap[devices_[i]->getBackendName().str()].push_back( |
302 | std::make_pair(i, availableMemory)); |
303 | } |
304 | |
305 | // Sort all vectors in descending order of available memory. |
306 | for (auto &sizes : deviceMemoryMap) { |
307 | std::sort(sizes.second.begin(), sizes.second.end(), sortMostMemory); |
308 | } |
309 | |
310 | // Generate assignments between physical and logical devices. |
311 | auto deviceAssignments = generateDeviceAssignments( |
312 | logicalDeviceSize, deviceMemoryMap, logicalDevices); |
313 | |
314 | VLOG(1) << "Before device assignment" ; |
315 | // Check for errors. |
316 | if (!deviceAssignments) { |
317 | RETURN_ERR(deviceAssignments.takeError()); |
318 | } |
319 | auto assignments = std::move(*deviceAssignments); |
320 | |
321 | VLOG(1) << "Before compile" ; |
322 | |
323 | // Stores function name and the remaining logical device count for that |
324 | // function. |
325 | llvm::StringMap<size_t> remainingDeviceCount; |
326 | // Mapping from function name to its backend options. |
327 | llvm::StringMap<BackendOptions> optsMap; |
328 | |
329 | // Compile and load. |
330 | // This is done one logical device at a time. All functions in a logical |
331 | // device are compiled and then added to their assigned device. If a function |
332 | // is in multiple logical devices it is stored so that it only needs to be |
333 | // compiled once. |
334 | if (network->networkType == NetworkType::GLOW_NETWORK) { |
335 | for (auto &assignment : assignments) { |
336 | auto logicalDevice = assignment.first; |
337 | auto physicalDevice = assignment.second; |
338 | auto deviceBackendName = logicalDevices[logicalDevice][0]->backendName; |
339 | |
340 | if (backends_.find(deviceBackendName) == backends_.end()) { |
341 | // Return error requested device type not found. |
342 | return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEVICE_NOT_FOUND, |
343 | "Unable to find device of type: " + deviceBackendName); |
344 | } |
345 | |
346 | // Stores all the functions in a logical device. |
347 | std::vector<glow::Function *> functionsToCompile; |
348 | // Stores the compiled functions that will be added to physical device. |
349 | FunctionMapTy functionMap; |
350 | |
351 | // Collect all the functions in a logical device. |
352 | for (auto &node : logicalDevices[logicalDevice]) { |
353 | // If the function name exist we don't need to compile it again. |
354 | if (optsMap.count(node->name)) { |
355 | remainingDeviceCount[node->name] -= 1; |
356 | continue; |
357 | } |
358 | |
359 | auto options = cctx.backendOpts; |
360 | options.backendHints = node->backendHints; |
361 | // Insert all options loaded in the Partitioner alongside options |
362 | // previously inserted, with Partitioner options taking precedence in |
363 | // case of a collision of keys. |
364 | for (auto &it : node->backendSpecificOpts) { |
365 | options.backendSpecificOpts[it.first] = it.second; |
366 | } |
367 | std::lock_guard<std::mutex> functionsLock(functionsLock_); |
368 | Function *function = module.getFunction(node->name); |
369 | |
370 | functionsToCompile.push_back(function); |
371 | optsMap.insert({function->getName(), options}); |
372 | functionReplicaCount_.emplace(node->name, node->replicationCount); |
373 | remainingDeviceCount.insert( |
374 | {node->name, node->logicalDevices.size() - 1}); |
375 | } |
376 | |
377 | // Compile all the functions in the logical device together. |
378 | // We add a lock here because some backends are not threadsafe (CPU |
379 | // backend). |
380 | std::unique_lock<std::mutex> compileLock(functionsLock_); |
381 | auto compiledOrErr = backends_[deviceBackendName]->compileFunctions( |
382 | functionsToCompile, optsMap); |
383 | VLOG(1) << "After compile" ; |
384 | compileLock.unlock(); |
385 | |
386 | // Dump graph and logs |
387 | for (auto *function : functionsToCompile) { |
388 | // Note: This needs to come after compile above because compile may |
389 | // modify the Function as well. |
390 | if (cctx.dumpFinalGraph) { |
391 | auto fname = strFormat( |
392 | "%sfinal_graph_%s_%s.dot" , cctx.dumpGraphPath.c_str(), |
393 | deviceBackendName.c_str(), function->getName().str().c_str()); |
394 | LOG(INFO) << "Dumping final graph to " << fname; |
395 | function->dumpDAG(fname); |
396 | // print stats of node |
397 | std::map<std::string, int> opCounter; |
398 | for (const auto &node : function->getNodes()) { |
399 | opCounter[node.getKindName()]++; |
400 | } |
401 | std::ostringstream ss; |
402 | ss << "Dump of Node stats for Function:\n" ; |
403 | ss << folly::stringPrintf("%30s %13s \n" , "NodeKind" , "Count" ); |
404 | for (const auto &p : opCounter) { |
405 | ss << folly::stringPrintf("%30s %13d \n" , p.first.c_str(), |
406 | p.second); |
407 | } |
408 | LOG(INFO) << ss.str(); |
409 | } |
410 | |
411 | if (glow::flags::DumpCompilationLog) { |
412 | llvm::SmallString<64> path; |
413 | std::string prefix = |
414 | llvm::formatv("{0}-{1}" , cctx.compilationLogPrefix, |
415 | function->getName()) |
416 | .str(); |
417 | auto tempFileRes = |
418 | llvm::sys::fs::createTemporaryFile(prefix, "log" , path); |
419 | if (tempFileRes.value() != 0) { |
420 | LOG(ERROR) |
421 | << "Failed to create temp file for Glow compilation log: " |
422 | << tempFileRes; |
423 | } |
424 | |
425 | function->getLogContext()->dumpLog(path); |
426 | } |
427 | } |
428 | |
429 | // If err return it, else store compiled functions into compiledFunctions. |
430 | if (!compiledOrErr) { |
431 | RETURN_ERR(compiledOrErr.takeError()); |
432 | } |
433 | auto compiled = std::move(*compiledOrErr); |
434 | for (auto &compiledFunction : compiled) { |
435 | |
436 | // Deserialize compiled function from cctx.nameToFunctions |
437 | if (cctx.backendOpts.useDeserialize) { |
438 | std::string name = compiledFunction.first().str(); |
439 | if (cctx.nameToFunctions.find(name) == cctx.nameToFunctions.end()) { |
440 | return MAKE_ERR( |
441 | ErrorValue::ErrorCode::UNKNOWN, |
442 | "Cannot find compiled function when deserializing " + name); |
443 | } |
444 | RETURN_IF_ERR(compiledFunction.second->deserialize( |
445 | *(cctx.nameToFunctions.find(name)->second))); |
446 | } |
447 | compiledFunctions.try_emplace(compiledFunction.first(), |
448 | std::move(compiledFunction.second)); |
449 | } |
450 | // Construnct functionMap for physical device. |
451 | for (auto &node : logicalDevices[logicalDevice]) { |
452 | RETURN_ERR_IF_NOT(compiledFunctions.count(node->name), |
453 | "Can't find corresponding compiled function " + |
454 | node->name); |
455 | |
456 | auto *compiledFunction = compiledFunctions[node->name].get(); |
457 | functionMap.emplace(node->name, compiledFunction); |
458 | |
459 | for (unsigned i = 1; i < node->replicationCount; i++) { |
460 | auto replicatedName = getReplicatedName(node->name, i); |
461 | functionMap.emplace(replicatedName, compiledFunction); |
462 | } |
463 | |
464 | // Dump backend-specific IR |
465 | if (glow::flags::DumpBackendSpecificIRJSON) { |
466 | compiledFunction->dumpJSON(strFormat("%sbackend_specific_ir_%s.json" , |
467 | cctx.dumpGraphPath.c_str(), |
468 | node->name.c_str())); |
469 | } |
470 | |
471 | node->runtimeBundle = glow::make_unique<RuntimeBundle>( |
472 | compiledFunction->getRuntimeBundle()); |
473 | } |
474 | |
475 | // Now that the functions are compiled add them to their assigned device |
476 | // then cleanup. |
477 | std::promise<void> addPromise; |
478 | auto ready = addPromise.get_future(); |
479 | std::unique_ptr<Error> addErr; |
480 | devices_[physicalDevice]->addNetwork( |
481 | &module, functionMap, |
482 | [&addErr, &addPromise](const Module *, Error err) { |
483 | addErr = glow::make_unique<Error>(std::move(err)); |
484 | addPromise.set_value(); |
485 | }); |
486 | ready.wait(); |
487 | DCHECK_NOTNULL(addErr.get()); |
488 | if (*addErr.get()) { |
489 | return std::move(*addErr.get()); |
490 | } |
491 | |
492 | // Add networks successfully loaded on device to addedNetworks, this way |
493 | // if we fail later we can evict them. |
494 | for (const auto &func : functionMap) { |
495 | addedNetworks[physicalDevice].push_back(func.first); |
496 | } |
497 | VLOG(1) << "Added networks" ; |
498 | |
499 | // Free up memory no longer needed by the compiledFunction. |
500 | for (auto &node : logicalDevices[logicalDevice]) { |
501 | // If the compiled function still needs to be added to other device, |
502 | // don't free the resources. |
503 | if (remainingDeviceCount[node->name] > 0) { |
504 | continue; |
505 | } |
506 | |
507 | // Free compilation resources. This need to be done after add network |
508 | // and before move on to next logical device. If |
509 | // DisableFreeCompilationResource is true, we will not free it here. |
510 | // This is used in scenarios like model serialization. |
511 | auto &funtionPtr = compiledFunctions[node->name]; |
512 | if (!glow::flags::DisableFreeCompilationResource) { |
513 | funtionPtr->freeCompilationResources(); |
514 | } |
515 | |
516 | // Move compiled functions from compiledFunctions to functions_. |
517 | { |
518 | std::lock_guard<std::mutex> functionsLock(functionsLock_); |
519 | functions_.emplace(node->name, std::move(funtionPtr)); |
520 | } |
521 | |
522 | compiledFunctions.erase(node->name); |
523 | } |
524 | } |
525 | } else if (network->networkType == NetworkType::FX_NETWORK) { |
526 | #if FACEBOOK_INTERNAL |
527 | // Container for duplicated functions and map tracking remaining installs |
528 | // for a duplicated function. |
529 | std::map<std::string, std::unique_ptr<CompiledFunction>> |
530 | duplicatedFunctions; |
531 | std::map<DAGNode *, unsigned> remainingDuplications; |
532 | for (auto &assignment : assignments) { |
533 | auto logicalDevice = assignment.first; |
534 | auto physicalDevice = assignment.second; |
535 | auto deviceBackendName = logicalDevices[logicalDevice][0]->backendName; |
536 | FunctionMapTy functionMap; |
537 | // Container for the compiledFunctions for this logicalDevice. |
538 | std::map<std::string, std::unique_ptr<CompiledFunction>> |
539 | compiledFunctions; |
540 | |
541 | for (auto &node : logicalDevices[logicalDevice]) { |
542 | // Check if this is a duplicated function that has already been |
543 | // compiled. |
544 | if (duplicatedFunctions.find(node->name) != duplicatedFunctions.end()) { |
545 | functionMap.emplace(node->name, |
546 | duplicatedFunctions[node->name].get()); |
547 | remainingDuplications[node] -= 1; |
548 | } else { |
549 | // Compile and add to function map. |
550 | auto options = cctx.backendOpts; |
551 | options.backendHints = node->backendHints; |
552 | // Insert all options loaded in the Partitioner alongside options |
553 | // previously inserted, with Partitioner options taking precedence in |
554 | // case of a collision of keys. |
555 | for (auto &it : node->backendSpecificOpts) { |
556 | options.backendSpecificOpts[it.first] = it.second; |
557 | } |
558 | if (backends_.find(deviceBackendName) == backends_.end()) { |
559 | // Return error requested device type not found. |
560 | return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEVICE_NOT_FOUND, |
561 | "Unable to find device of type: " + |
562 | deviceBackendName); |
563 | } |
564 | auto fxNetwork = static_cast<FXNetwork *>(network.get()); |
565 | auto compiledOrErr = backends_[deviceBackendName]->compileFX( |
566 | fxNetwork->FXIR, node->name, fxNetwork->constants, options, |
567 | &module); |
568 | |
569 | // Check to see if an error was encountered while compiling. |
570 | if (!compiledOrErr) { |
571 | // If an error occured return the error. |
572 | RETURN_ERR(compiledOrErr.takeError()); |
573 | } |
574 | auto compiled = std::move(*compiledOrErr); |
575 | |
576 | node->runtimeBundle = |
577 | glow::make_unique<RuntimeBundle>(compiled->getRuntimeBundle()); |
578 | |
579 | functionMap.emplace(node->name, compiled.get()); |
580 | // If this function is in more than one logical device store it for |
581 | // reuse. |
582 | if (node->logicalDevices.size() > 1) { |
583 | duplicatedFunctions.emplace(node->name, std::move(compiled)); |
584 | remainingDuplications[node] = node->logicalDevices.size() - 1; |
585 | } else { |
586 | compiledFunctions.emplace(node->name, std::move(compiled)); |
587 | } |
588 | } |
589 | } |
590 | VLOG(1) << "After compile" ; |
591 | |
592 | // Now that the functions are compiled add them to their assigned device |
593 | // then cleanup. |
594 | std::promise<void> addPromise; |
595 | auto ready = addPromise.get_future(); |
596 | std::unique_ptr<Error> addErr; |
597 | devices_[physicalDevice]->addNetwork( |
598 | &module, functionMap, |
599 | [&addErr, &addPromise](const Module *, Error err) { |
600 | addErr = glow::make_unique<Error>(std::move(err)); |
601 | addPromise.set_value(); |
602 | }); |
603 | ready.wait(); |
604 | DCHECK_NOTNULL(addErr.get()); |
605 | if (*addErr.get()) { |
606 | return std::move(*addErr.get()); |
607 | } |
608 | // Add networks successfully loaded on device to addedNetworks, this way |
609 | // if we fail later we can evict them. |
610 | for (auto &node : logicalDevices[logicalDevice]) { |
611 | addedNetworks[physicalDevice].push_back(node->name); |
612 | } |
613 | VLOG(1) << "Added networks" ; |
614 | |
615 | // Free up memory no longer needed by the compiledFunction. |
616 | for (auto &func : compiledFunctions) { |
617 | func.second->freeCompilationResources(); |
618 | } |
619 | { |
620 | // Move compiled functions from compiledFunctions to functions_. |
621 | std::lock_guard<std::mutex> functionsLock(functionsLock_); |
622 | for (auto &func : compiledFunctions) { |
623 | functions_.emplace(func.first, std::move(func.second)); |
624 | } |
625 | // Check if any of the duplicated functions can also be moved. |
626 | for (auto iter = remainingDuplications.begin(); |
627 | iter != remainingDuplications.end();) { |
628 | const auto &func = *iter; |
629 | if (func.second == 0) { |
630 | duplicatedFunctions[func.first->name]->freeCompilationResources(); |
631 | functions_.emplace( |
632 | func.first->name, |
633 | std::move(duplicatedFunctions[func.first->name])); |
634 | duplicatedFunctions.erase(func.first->name); |
635 | iter = remainingDuplications.erase(iter); |
636 | } else { |
637 | ++iter; |
638 | } |
639 | } |
640 | } |
641 | } |
642 | #endif |
643 | } |
644 | RETURN_ERR_IF_NOT(compiledFunctions.empty(), |
645 | "compiledFunctions should be empty because all compiled " |
646 | "functions should be moved to Provisioner::function_" ); |
647 | |
648 | // Map from Placeholder* to DeviceManager, this is used for deferred weight |
649 | // loading. |
650 | std::unordered_map<Placeholder *, std::vector<unsigned>> |
651 | placeholderToDeviceManager; |
652 | if (cctx.deferredWeightLoader) { |
653 | // Populate placeholdeToDeviceManager map. |
654 | for (auto &assignment : assignments) { |
655 | for (const auto &node : logicalDevices[assignment.first]) { |
656 | auto symbolTable = node->runtimeBundle->getSymbolTable(); |
657 | for (auto info : symbolTable) { |
658 | if (info.second.symbolCategory == |
659 | glow::runtime::SymbolCategory::Placeholder) { |
660 | auto PH = module.getPlaceholderByNameSlow(info.first); |
661 | if (PH->isStatic()) { |
662 | placeholderToDeviceManager[PH].push_back(assignment.second); |
663 | } |
664 | } |
665 | } |
666 | } |
667 | } |
668 | } else { |
669 | // Make sure there are no static placeholders. |
670 | for (auto PH : module.getPlaceholders()) { |
671 | if (PH->isStatic()) { |
672 | return MAKE_ERR( |
673 | ErrorValue::ErrorCode::RUNTIME_ERROR, |
674 | llvm::formatv("Error Placholder: {0} is marked as static but no " |
675 | "deferredWeightLoader is provided." , |
676 | PH->getName()) |
677 | .str()); |
678 | ; |
679 | } |
680 | } |
681 | } |
682 | // If a deferredWeightLoader is provided, create a deferredWeightLoader and |
683 | // load deferred weights. |
684 | if (cctx.deferredWeightLoader) { |
685 | const size_t totalNumDeferredWeights = placeholderToDeviceManager.size(); |
686 | LOG(INFO) << "Loading " << totalNumDeferredWeights << " deferred weights" ; |
687 | |
688 | auto startTime = std::chrono::steady_clock::now(); |
689 | auto loader = cctx.deferredWeightLoader; |
690 | // Load the first weight. |
691 | auto err = loader->loadNextWeight(); |
692 | if (err) { |
693 | auto val = takeErrorValue(std::move(err)); |
694 | std::string msg = val->logToString(); |
695 | return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEFERRED_WEIGHT_ERROR, |
696 | msg); |
697 | } |
698 | std::string weightName = loader->getName(); |
699 | // Load weights while there are weights to be loaded. |
700 | unsigned int weightCount = 0; |
701 | while (weightName != "" ) { |
702 | LOG(INFO) << "Loading deferred weight (" << ++weightCount << " / " |
703 | << totalNumDeferredWeights << "): " << weightName; |
704 | const auto PH = module.getPlaceholderByNameSlow(weightName); |
705 | if (!PH) { |
706 | return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEFERRED_WEIGHT_ERROR, |
707 | llvm::formatv("Error loading deferred weight. Name: " |
708 | "{0} not found in module." , |
709 | weightName) |
710 | .str()); |
711 | } |
712 | // Convert the weight if needed. |
713 | auto newTy = PH->getType(); |
714 | auto weight = loader->getTensor(); |
715 | auto oldKind = weight->getElementType(); |
716 | // Ensure we are working with a static PH. |
717 | assert(PH->isStatic()); |
718 | if (!weight->getType().isEqual(newTy)) { |
719 | ElemKind newK = newTy->getElementType(); |
720 | |
721 | if (!isQuantizedElemKind(oldKind) && isQuantizedElemKind(newK)) { |
722 | Tensor QT = quantization::quantizeTensor( |
723 | *weight, {newTy->getScale(), newTy->getOffset()}, newK); |
724 | weight->assign(&QT); |
725 | } else { |
726 | weight->convertToType(newK); |
727 | } |
728 | } |
729 | // Transfer weight to all devices needed. |
730 | std::list<Error> errors; |
731 | std::list<std::future<void>> futures; |
732 | for (const auto &device : placeholderToDeviceManager[PH]) { |
733 | std::promise<void> transferPromise; |
734 | errors.emplace_back(Error::empty()); |
735 | futures.emplace_back(transferPromise.get_future()); |
736 | devices_[device]->transferStaticPlaceholderToDevice( |
737 | PH, weight, |
738 | [&transferPromise, &error = errors.back()](Error err) mutable { |
739 | error = std::move(err); |
740 | transferPromise.set_value(); |
741 | }); |
742 | } |
743 | |
744 | for (auto &done : futures) { |
745 | done.get(); |
746 | } |
747 | |
748 | for (auto &error : errors) { |
749 | RETURN_IF_ERR(error); |
750 | } |
751 | |
752 | err = loader->loadNextWeight(); |
753 | if (err) { |
754 | auto val = takeErrorValue(std::move(err)); |
755 | std::string msg = val->logToString(); |
756 | return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEFERRED_WEIGHT_ERROR, |
757 | msg); |
758 | } |
759 | weightName = loader->getName(); |
760 | // Remove PH from map, this way we can know that we've added all static |
761 | // PH's |
762 | placeholderToDeviceManager.erase(PH); |
763 | } |
764 | if (placeholderToDeviceManager.size()) { |
765 | return MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEFERRED_WEIGHT_ERROR, |
766 | "Error not all static placeholders were initialized." ); |
767 | } |
768 | |
769 | std::chrono::duration<double> duration = |
770 | std::chrono::steady_clock::now() - startTime; |
771 | LOG(INFO) << "Done loading deferred weights in " << duration.count() |
772 | << " seconds" ; |
773 | } |
774 | // Init alternate name states. |
775 | for (auto &network : networks) { |
776 | for (auto &node : network.nodes) { |
777 | node->initAlternateState(); |
778 | } |
779 | } |
780 | |
781 | cleanupGuard.dismiss(); |
782 | cleanupProvision(localActiveNames, {}, false); |
783 | return Error::success(); |
784 | }; |
785 | |
786 | Error Provisioner::provision(DAGListTy &networks, Module &module, |
787 | CompilationContext &cctx) { |
788 | return provisionNetwork( |
789 | glow::make_unique<GlowNetwork>(networks, module, cctx)); |
790 | }; |
791 | |
792 | #if FACEBOOK_INTERNAL |
793 | Error Provisioner::provisionFX(DAGListTy &networks, Module &module, |
794 | const FXFunction &FXIR, |
795 | const llvm::StringMap<const void *> &constants, |
796 | CompilationContext &cctx) { |
797 | return provisionNetwork( |
798 | glow::make_unique<FXNetwork>(networks, module, cctx, FXIR, constants)); |
799 | }; |
800 | #endif |
801 | |
802 | Backend &Provisioner::getBackend(llvm::StringRef backendName) const { |
803 | assert(backends_.count(backendName.str()) && |
804 | "No backend created by specified name." ); |
805 | return *backends_.at(backendName.str()); |
806 | } |
807 | |
808 | Expected<Backend *> Provisioner::getBackend() const { |
809 | RETURN_ERR_IF_NOT( |
810 | backends_.size() == 1, |
811 | strFormat("Expected exactly 1 backend to be found but instead found %zu" , |
812 | backends_.size())); |
813 | return backends_.begin()->second.get(); |
814 | } |
815 | |
816 | Error Provisioner::removeFunction(llvm::StringRef name) { |
817 | std::lock_guard<std::mutex> functionsLock(functionsLock_); |
818 | auto it = activeFunctions_.find(name.str()); |
819 | if (it != activeFunctions_.end()) { |
820 | return MAKE_ERR( |
821 | ErrorValue::ErrorCode::RUNTIME_NET_BUSY, |
822 | llvm::formatv("Could not remove network: {0} as it is currently " |
823 | "being provisioned." , |
824 | name) |
825 | .str()); |
826 | } |
827 | functions_.erase(name.str()); |
828 | return Error::success(); |
829 | } |
830 | |
831 | Error Provisioner::evictFunction(llvm::StringRef name, DeviceManager *device, |
832 | unsigned replicaCount) { |
833 | std::promise<void> evictPromise; |
834 | OneErrOnly evictErr; |
835 | auto done = evictPromise.get_future(); |
836 | device->evictNetwork(name.str(), |
837 | [&evictPromise, &evictErr](std::string, Error err) { |
838 | evictErr.set(std::move(err)); |
839 | evictPromise.set_value(); |
840 | }); |
841 | done.get(); |
842 | |
843 | // If we are evict a main function, evict its replications as well. |
844 | if (replicaCount) { |
845 | for (unsigned i = 1; i < replicaCount; i++) { |
846 | auto replicaName = getReplicatedName(name.str(), i); |
847 | std::promise<void> evictReplicaPromise; |
848 | auto done = evictReplicaPromise.get_future(); |
849 | device->evictNetwork(replicaName, [&evictReplicaPromise, |
850 | &evictErr](std::string, Error err) { |
851 | evictErr.set(std::move(err)); |
852 | evictReplicaPromise.set_value(); |
853 | }); |
854 | |
855 | done.get(); |
856 | } |
857 | } |
858 | |
859 | return evictErr.get(); |
860 | } |
861 | |
862 | void Provisioner::cleanupProvision( |
863 | llvm::ArrayRef<std::string> names, |
864 | std::map<DeviceIDTy, std::vector<std::string>> const |
865 | ¤tNetworkResidency, |
866 | bool failure) { |
867 | std::lock_guard<std::mutex> functionLock(functionsLock_); |
868 | if (failure) { |
869 | // Remove any partitions added to devices. |
870 | for (auto &device : currentNetworkResidency) { |
871 | for (auto &network : device.second) { |
872 | #if FACEBOOK_INTERNAL |
873 | LOG(INFO) << "Removing network " << network << " from device " |
874 | << device.first; |
875 | #endif |
876 | auto replicaCountIdx = functionReplicaCount_.find(network); |
877 | unsigned replicaCount = 0; |
878 | if (replicaCountIdx != functionReplicaCount_.end()) { |
879 | replicaCount = replicaCountIdx->second; |
880 | } |
881 | Error evictErr = |
882 | evictFunction(network, devices_[device.first], replicaCount); |
883 | if (evictErr) { |
884 | LOG(ERROR) << "Unable to evict network: " << network << "\n" ; |
885 | } |
886 | } |
887 | } |
888 | } |
889 | // After we've removed the functions from the deviceManagers now free the |
890 | // compiledFunctions. We free after eviction to ensure the any reference the |
891 | // DeviceManager has to the compiledFunctions stays valid until after |
892 | // eviction. |
893 | for (auto &name : names) { |
894 | activeFunctions_.erase(name); |
895 | if (failure) { |
896 | // Remove any functions added before the failure. |
897 | functions_.erase(name); |
898 | } |
899 | } |
900 | } |
901 | |
902 | void Provisioner::cleanUpSerializedFunctionMap() { |
903 | serializedFunctionMap_.clear(); |
904 | } |
905 | |
/// Extract the hash suffix of a function \p name, i.e. everything after the
/// last '_' character.
/// \returns the suffix; the whole \p name when it contains no '_', and an
/// empty string when '_' is the final character.
std::string getNameHash(std::string name) {
  const std::string::size_type lastUnderscore = name.rfind('_');
  if (lastUnderscore == std::string::npos) {
    return name;
  }
  return name.substr(lastUnderscore + 1);
}
910 | |
911 | std::unique_ptr< |
912 | std::unordered_map<std::string, std::unique_ptr<BlockStreamBase>>> |
913 | Provisioner::getAllSerializedFunctionsMap() { |
914 | // Assume all functions in functions_ are using the same backend |
915 | cleanUpSerializedFunctionMap(); |
916 | for (auto &kv : functions_) { |
917 | std::string name = kv.first; |
918 | auto data = kv.second->serialize(); |
919 | if (data != nullptr) { |
920 | serializedFunctionMap_.emplace( |
921 | std::make_pair(getNameHash(name), std::move(data))); |
922 | } |
923 | } |
924 | return std::make_unique< |
925 | std::unordered_map<std::string, std::unique_ptr<BlockStreamBase>>>( |
926 | std::move(serializedFunctionMap_)); |
927 | } |
928 | |