1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_RUNTIME_PROVISIONER_H |
17 | #define GLOW_RUNTIME_PROVISIONER_H |
18 | |
#include "glow/Backend/Backend.h"
#include "glow/Backend/BlockStreamBase.h"
#include "glow/Backends/DeviceManager.h"
#include "glow/Runtime/RuntimeTypes.h"
#include "glow/Support/Error.h"

#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
26 | |
#if FACEBOOK_INTERNAL
// folly::dynamic is only forward-declared: every use of FXFunction in this
// header is by const reference, so the complete type is not needed here and
// the folly headers stay out of this header's include graph.
namespace folly {
struct dynamic;
} // namespace folly

namespace glow {
namespace runtime {
/// In-memory representation of an FX IR graph, as consumed by FXNetwork and
/// Provisioner::provisionFX.
using FXFunction = folly::dynamic;
} // namespace runtime
} // namespace glow
#endif // FACEBOOK_INTERNAL
38 | |
39 | namespace glow { |
40 | namespace runtime { |
41 | |
/// Tag identifying which concrete subclass of Network a Network object is.
enum class NetworkType : int {
  FX_NETWORK = 0,   ///< FX network (see FXNetwork).
  GLOW_NETWORK = 1, ///< Glow Module network (see GlowNetwork).
};
48 | |
49 | /// Base struct for passing in a network to Provisioner. It contains all common |
50 | /// elements: DagListTy, Module, and CCTX and a NetworkType denoting what |
51 | /// subclass of network it is. |
52 | struct Network { |
53 | /// Backend used for this config. It is used in |
54 | /// checking the type of config before casting to a derived class. |
55 | const NetworkType networkType; |
56 | |
57 | /// Dag structure for the network to be added. |
58 | DAGListTy &networks; |
59 | |
60 | /// Module containing PH's for the network and in some cases the network. |
61 | Module &module; |
62 | |
63 | /// Compilation Context for the network being added. |
64 | CompilationContext &cctx; |
65 | |
66 | Network(NetworkType netType, DAGListTy &networks, Module &module, |
67 | CompilationContext &cctx) |
68 | : networkType(netType), networks(networks), module(module), cctx(cctx) {} |
69 | virtual ~Network() = default; |
70 | }; |
#if FACEBOOK_INTERNAL
/// Network tagged NetworkType::FX_NETWORK. In addition to the common Network
/// state it carries an FX IR description and a table of constants. Both are
/// held by reference, so the caller must keep them alive while this object is
/// in use.
struct FXNetwork : Network {
  /// FX IR describing the network.
  const FXFunction &FXIR;

  /// Name-keyed table of raw constant pointers (presumably constant name ->
  /// weight data; confirm against provisionFX).
  const llvm::StringMap<const void *> &constants;

  FXNetwork(DAGListTy &dag, Module &mod, CompilationContext &ctx,
            const FXFunction &fxIR,
            const llvm::StringMap<const void *> &constantMap)
      : Network(/*netType=*/NetworkType::FX_NETWORK, dag, mod, ctx),
        FXIR(fxIR), constants(constantMap) {}
};
#endif // FACEBOOK_INTERNAL
82 | |
83 | struct GlowNetwork : Network { |
84 | GlowNetwork(DAGListTy &networks, Module &module, CompilationContext &cctx) |
85 | : Network(NetworkType::GLOW_NETWORK, networks, module, cctx) {} |
86 | }; |
87 | |
88 | /// The Provisioner is responsible for assigning networks to an actual device. |
89 | /// It also compiles the networks before passing the compiled functions to the |
90 | /// device. |
91 | class Provisioner final { |
92 | public: |
93 | Provisioner(DeviceManagerMapTy &devices); |
94 | |
95 | /// Traverses the DAG \p networks and compiles all the node's Functions from |
96 | /// \p module using \p cctx. Then add compiled functions to assigned devices. |
97 | /// |
98 | /// Pseudocode: |
99 | /// |
100 | /// generate device assignments |
101 | /// create map `optsMap`, `compiledFunctions`, `remainingDeviceCount` |
102 | /// |
103 | /// for each assignment |
104 | /// create vector functionsToCompile |
105 | /// create map functionMap |
106 | /// for each node in logical device |
107 | /// if Function hasn't been compiled before |
108 | /// add Function to `functionsToCompile` |
109 | /// add Function's BackendOptions to `optsMap` |
110 | /// set `remainingDeviceCount` for Function |
111 | /// else |
112 | /// decrease `remainingDeviceCount` for Function by 1 |
113 | /// |
114 | /// call Backend::compiledFunctions with `functionsToCompile` and |
115 | /// `optsMap` |
116 | /// move compiled functions to `compiledFunctions` |
117 | /// |
118 | /// for each node in logical device |
119 | /// add corresponding compiled functions in `compiledFunctions` to |
120 | /// `functionMap` |
121 | /// add replications to `functionMap` using the same compiled function |
122 | /// with a different name |
123 | /// |
124 | /// call DeviceManager::addNetwork with `FunctionMap` |
125 | /// |
126 | /// for each node in logical device |
127 | /// if `remainingDeviceCount` for Function is 0 |
128 | /// free up compilation resources |
129 | /// move corresponding compiled function from `compiledFunctions` |
130 | /// to `Provisioner::functions_` |
131 | Error provision(DAGListTy &networks, Module &module, |
132 | CompilationContext &cctx); |
133 | |
134 | #if FACEBOOK_INTERNAL |
135 | /// Traverses the DAG \p networks and: |
136 | /// 1. Retrieves each node's Function from the provided \p FXIR. |
137 | /// 2. Compiles it using the provided CompilationContext \p cctx. |
138 | /// 3. Assigns a device and calls addNetwork on the chosen device(s). |
139 | /// \returns a Error indicating if the operation was a success. |
140 | Error provisionFX(DAGListTy &networks, Module &module, const FXFunction &FXIR, |
141 | const llvm::StringMap<const void *> &constants, |
142 | CompilationContext &cctx); |
143 | #endif |
144 | // Unified provisioning function, tries to re-use most shared logic between |
145 | // provision and provisionFX. |
146 | Error provisionNetwork(std::unique_ptr<Network> network); |
147 | /// Remove stored compiledFunction. |
148 | Error removeFunction(llvm::StringRef name); |
149 | |
150 | /// Evict function from device. |
151 | Error evictFunction(llvm::StringRef name, DeviceManager *device, |
152 | unsigned replicaCount); |
153 | |
154 | /// \returns a reference to the backend with name \p backendName. |
155 | Backend &getBackend(llvm::StringRef backendName) const; |
156 | |
157 | /// \returns a reference to the Backend if only one Backend is found, |
158 | /// otherwise returns an Error. |
159 | Expected<Backend *> getBackend() const; |
160 | |
161 | /// Update the list of available devices. |
162 | void updateAvailableDevices(const std::vector<DeviceManager *> &devices, |
163 | const std::vector<DeviceIDTy> &mappings) { |
164 | devices_ = devices; |
165 | deviceMappings_ = mappings; |
166 | } |
167 | |
168 | // Extract function streams from functions_ to serializedFunctionMap_, |
169 | // and return a ptr of serializedFunctionMap_. |
170 | // Each time this function called, serializedFunctionMap_ will be regenerated. |
171 | std::unique_ptr< |
172 | std::unordered_map<std::string, std::unique_ptr<BlockStreamBase>>> |
173 | getAllSerializedFunctionsMap(); |
174 | |
175 | // Clean up all stored serializedFunctionMap_. |
176 | void cleanUpSerializedFunctionMap(); |
177 | |
178 | private: |
179 | /// Map of backends for all devices, one backend per device type. |
180 | std::unordered_map<std::string, std::unique_ptr<Backend>> backends_; |
181 | |
182 | /// Map of compiledFunction unique pointers. This maintains |
183 | /// ownership of the functions. |
184 | std::unordered_map<std::string, std::unique_ptr<CompiledFunction>> functions_; |
185 | |
186 | /// Map of serialized function pointers, storing all serialized functions on |
187 | /// backends. |
188 | /// Only used in serialization. |
189 | std::unordered_map<std::string, std::unique_ptr<BlockStreamBase>> |
190 | serializedFunctionMap_; |
191 | |
192 | /// Set of active functions - these are functions that are currently being |
193 | /// compiled/added to devices. |
194 | std::set<std::string> activeFunctions_; |
195 | |
196 | /// Mapping from function name to its number of replications |
197 | std::unordered_map<std::string, unsigned> functionReplicaCount_; |
198 | |
199 | /// Mutex for functions_ and activeFunctions_ since add/remove can be called |
200 | /// from multiple threads simultaneously. |
201 | std::mutex functionsLock_; |
202 | |
203 | /// List of available DeviceManagers added during initialization. |
204 | std::vector<DeviceManager *> devices_; |
205 | |
206 | /// Mapping from available devices to deviceID; |
207 | std::vector<DeviceIDTy> deviceMappings_; |
208 | |
209 | /// Helper function to cleanup a provision call. On \p failure free the |
210 | /// compiledFunctions that were created, \p names , and remove networks |
211 | /// already added to devices, \p currentNetworkResidency . |
212 | void cleanupProvision(llvm::ArrayRef<std::string> names, |
213 | std::map<DeviceIDTy, std::vector<std::string>> const |
214 | ¤tNetworkResidency, |
215 | bool failure = true); |
216 | |
217 | /// Helper function to parse the DAG and generate logicalDevices. |
218 | std::map<DeviceIDTy, std::vector<DAGNode *>> |
219 | generateLogicalDevices(const DAGListTy &networks); |
220 | |
221 | /// Helper method to check that new networks don't collide with another |
222 | /// network currently being added. Note: This cannot be called under a lock on |
223 | /// functionsLock_ as it acquires a lock internally. |
224 | Error checkActiveNetworks(const DAGListTy &networks, |
225 | std::vector<std::string> &localActiveNames); |
226 | |
227 | /// This function pairs logical devices with phsyical devices, it sorts both |
228 | /// sets of devices by available memory and attempts to find pairings for all |
229 | /// of then. The output is a map between logicalDevice and PhysicalDevice. |
230 | /// Requires a vector of DeviceID:memorySize pairs \p logicalDeviceSize, \p |
231 | /// deviceMemoryMap a mapping from backendName to a list of device:memorySize |
232 | /// pairs for all devices of the specified backend, and \p logicalDevices a |
233 | /// map of logicalIDs to all associated DAGNodes. |
234 | Expected<std::map<DeviceIDTy, DeviceIDTy>> generateDeviceAssignments( |
235 | const std::vector<std::pair<DeviceIDTy, uint64_t>> &logicalDeviceSize, |
236 | std::map<std::string, std::vector<std::pair<DeviceIDTy, uint64_t>>> |
237 | &deviceMemoryMap, |
238 | std::map<DeviceIDTy, std::vector<DAGNode *>> &logicalDevices); |
239 | }; |
240 | } // namespace runtime |
241 | } // namespace glow |
242 | |
243 | #endif // GLOW_RUNTIME_PROVISIONER_H |
244 | |