1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_RUNTIME_PROVISIONER_H
17#define GLOW_RUNTIME_PROVISIONER_H
18
#include "glow/Backend/Backend.h"
#include "glow/Backend/BlockStreamBase.h"
#include "glow/Backends/DeviceManager.h"
#include "glow/Runtime/RuntimeTypes.h"
#include "glow/Support/Error.h"

#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
26
#if FACEBOOK_INTERNAL
// Forward-declare folly::dynamic so this header can name the type without
// including any folly header.
namespace folly {
struct dynamic;

} // namespace folly
namespace glow {
namespace runtime {
/// Type of the FX IR handed to Provisioner::provisionFX / FXNetwork.
using FXFunction = folly::dynamic;
} // namespace runtime
} // namespace glow
#endif // FACEBOOK_INTERNAL
38
39namespace glow {
40namespace runtime {
41
/// Tag identifying which concrete Network subclass a Network object is.
enum class NetworkType {
  /// Network described by an FX graph (FXNetwork).
  FX_NETWORK = 0,
  /// Network described by a Glow Module (GlowNetwork).
  GLOW_NETWORK = 1,
};
48
49/// Base struct for passing in a network to Provisioner. It contains all common
50/// elements: DagListTy, Module, and CCTX and a NetworkType denoting what
51/// subclass of network it is.
52struct Network {
53 /// Backend used for this config. It is used in
54 /// checking the type of config before casting to a derived class.
55 const NetworkType networkType;
56
57 /// Dag structure for the network to be added.
58 DAGListTy &networks;
59
60 /// Module containing PH's for the network and in some cases the network.
61 Module &module;
62
63 /// Compilation Context for the network being added.
64 CompilationContext &cctx;
65
66 Network(NetworkType netType, DAGListTy &networks, Module &module,
67 CompilationContext &cctx)
68 : networkType(netType), networks(networks), module(module), cctx(cctx) {}
69 virtual ~Network() = default;
70};
#if FACEBOOK_INTERNAL
/// Network whose per-node functions come from an FX IR instead of the Module.
/// Always tagged NetworkType::FX_NETWORK.
///
/// Both extra members are references: the FX IR and the constants map passed
/// to the constructor must stay alive for as long as this object is used.
struct FXNetwork : public Network {
  FXNetwork(DAGListTy &networks, Module &module, CompilationContext &cctx,
            const FXFunction &fxIR,
            const llvm::StringMap<const void *> &constantMap)
      : Network(NetworkType::FX_NETWORK, networks, module, cctx), FXIR(fxIR),
        constants(constantMap) {}

  /// FX IR from which each DAG node's function is retrieved.
  const FXFunction &FXIR;

  /// String-keyed map of constant data pointers supplied alongside the FX IR.
  const llvm::StringMap<const void *> &constants;
};
#endif // FACEBOOK_INTERNAL
82
83struct GlowNetwork : Network {
84 GlowNetwork(DAGListTy &networks, Module &module, CompilationContext &cctx)
85 : Network(NetworkType::GLOW_NETWORK, networks, module, cctx) {}
86};
87
88/// The Provisioner is responsible for assigning networks to an actual device.
89/// It also compiles the networks before passing the compiled functions to the
90/// device.
91class Provisioner final {
92public:
93 Provisioner(DeviceManagerMapTy &devices);
94
95 /// Traverses the DAG \p networks and compiles all the node's Functions from
96 /// \p module using \p cctx. Then add compiled functions to assigned devices.
97 ///
98 /// Pseudocode:
99 ///
100 /// generate device assignments
101 /// create map `optsMap`, `compiledFunctions`, `remainingDeviceCount`
102 ///
103 /// for each assignment
104 /// create vector functionsToCompile
105 /// create map functionMap
106 /// for each node in logical device
107 /// if Function hasn't been compiled before
108 /// add Function to `functionsToCompile`
109 /// add Function's BackendOptions to `optsMap`
110 /// set `remainingDeviceCount` for Function
111 /// else
112 /// decrease `remainingDeviceCount` for Function by 1
113 ///
114 /// call Backend::compiledFunctions with `functionsToCompile` and
115 /// `optsMap`
116 /// move compiled functions to `compiledFunctions`
117 ///
118 /// for each node in logical device
119 /// add corresponding compiled functions in `compiledFunctions` to
120 /// `functionMap`
121 /// add replications to `functionMap` using the same compiled function
122 /// with a different name
123 ///
124 /// call DeviceManager::addNetwork with `FunctionMap`
125 ///
126 /// for each node in logical device
127 /// if `remainingDeviceCount` for Function is 0
128 /// free up compilation resources
129 /// move corresponding compiled function from `compiledFunctions`
130 /// to `Provisioner::functions_`
131 Error provision(DAGListTy &networks, Module &module,
132 CompilationContext &cctx);
133
134#if FACEBOOK_INTERNAL
135 /// Traverses the DAG \p networks and:
136 /// 1. Retrieves each node's Function from the provided \p FXIR.
137 /// 2. Compiles it using the provided CompilationContext \p cctx.
138 /// 3. Assigns a device and calls addNetwork on the chosen device(s).
139 /// \returns a Error indicating if the operation was a success.
140 Error provisionFX(DAGListTy &networks, Module &module, const FXFunction &FXIR,
141 const llvm::StringMap<const void *> &constants,
142 CompilationContext &cctx);
143#endif
144 // Unified provisioning function, tries to re-use most shared logic between
145 // provision and provisionFX.
146 Error provisionNetwork(std::unique_ptr<Network> network);
147 /// Remove stored compiledFunction.
148 Error removeFunction(llvm::StringRef name);
149
150 /// Evict function from device.
151 Error evictFunction(llvm::StringRef name, DeviceManager *device,
152 unsigned replicaCount);
153
154 /// \returns a reference to the backend with name \p backendName.
155 Backend &getBackend(llvm::StringRef backendName) const;
156
157 /// \returns a reference to the Backend if only one Backend is found,
158 /// otherwise returns an Error.
159 Expected<Backend *> getBackend() const;
160
161 /// Update the list of available devices.
162 void updateAvailableDevices(const std::vector<DeviceManager *> &devices,
163 const std::vector<DeviceIDTy> &mappings) {
164 devices_ = devices;
165 deviceMappings_ = mappings;
166 }
167
168 // Extract function streams from functions_ to serializedFunctionMap_,
169 // and return a ptr of serializedFunctionMap_.
170 // Each time this function called, serializedFunctionMap_ will be regenerated.
171 std::unique_ptr<
172 std::unordered_map<std::string, std::unique_ptr<BlockStreamBase>>>
173 getAllSerializedFunctionsMap();
174
175 // Clean up all stored serializedFunctionMap_.
176 void cleanUpSerializedFunctionMap();
177
178private:
179 /// Map of backends for all devices, one backend per device type.
180 std::unordered_map<std::string, std::unique_ptr<Backend>> backends_;
181
182 /// Map of compiledFunction unique pointers. This maintains
183 /// ownership of the functions.
184 std::unordered_map<std::string, std::unique_ptr<CompiledFunction>> functions_;
185
186 /// Map of serialized function pointers, storing all serialized functions on
187 /// backends.
188 /// Only used in serialization.
189 std::unordered_map<std::string, std::unique_ptr<BlockStreamBase>>
190 serializedFunctionMap_;
191
192 /// Set of active functions - these are functions that are currently being
193 /// compiled/added to devices.
194 std::set<std::string> activeFunctions_;
195
196 /// Mapping from function name to its number of replications
197 std::unordered_map<std::string, unsigned> functionReplicaCount_;
198
199 /// Mutex for functions_ and activeFunctions_ since add/remove can be called
200 /// from multiple threads simultaneously.
201 std::mutex functionsLock_;
202
203 /// List of available DeviceManagers added during initialization.
204 std::vector<DeviceManager *> devices_;
205
206 /// Mapping from available devices to deviceID;
207 std::vector<DeviceIDTy> deviceMappings_;
208
209 /// Helper function to cleanup a provision call. On \p failure free the
210 /// compiledFunctions that were created, \p names , and remove networks
211 /// already added to devices, \p currentNetworkResidency .
212 void cleanupProvision(llvm::ArrayRef<std::string> names,
213 std::map<DeviceIDTy, std::vector<std::string>> const
214 &currentNetworkResidency,
215 bool failure = true);
216
217 /// Helper function to parse the DAG and generate logicalDevices.
218 std::map<DeviceIDTy, std::vector<DAGNode *>>
219 generateLogicalDevices(const DAGListTy &networks);
220
221 /// Helper method to check that new networks don't collide with another
222 /// network currently being added. Note: This cannot be called under a lock on
223 /// functionsLock_ as it acquires a lock internally.
224 Error checkActiveNetworks(const DAGListTy &networks,
225 std::vector<std::string> &localActiveNames);
226
227 /// This function pairs logical devices with phsyical devices, it sorts both
228 /// sets of devices by available memory and attempts to find pairings for all
229 /// of then. The output is a map between logicalDevice and PhysicalDevice.
230 /// Requires a vector of DeviceID:memorySize pairs \p logicalDeviceSize, \p
231 /// deviceMemoryMap a mapping from backendName to a list of device:memorySize
232 /// pairs for all devices of the specified backend, and \p logicalDevices a
233 /// map of logicalIDs to all associated DAGNodes.
234 Expected<std::map<DeviceIDTy, DeviceIDTy>> generateDeviceAssignments(
235 const std::vector<std::pair<DeviceIDTy, uint64_t>> &logicalDeviceSize,
236 std::map<std::string, std::vector<std::pair<DeviceIDTy, uint64_t>>>
237 &deviceMemoryMap,
238 std::map<DeviceIDTy, std::vector<DAGNode *>> &logicalDevices);
239};
240} // namespace runtime
241} // namespace glow
242
243#endif // GLOW_RUNTIME_PROVISIONER_H
244