/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef GLOW_RUNTIME_HOSTMANAGER_HOSTMANAGER_H
#define GLOW_RUNTIME_HOSTMANAGER_HOSTMANAGER_H

#include "glow/Backend/Backend.h"
#include "glow/Backends/DeviceManager.h"
#include "glow/Graph/Graph.h"
#include "glow/Runtime/Executor/Executor.h"
#include "glow/Runtime/Provisioner/Provisioner.h"
#include "glow/Runtime/RuntimeTypes.h"
#include "glow/Runtime/StatsExporter.h"

#include <atomic>
#include <map>
#include <mutex>
#include <queue>
#include <shared_mutex>
#include <unordered_map>
#include <vector>

#if FACEBOOK_INTERNAL
namespace folly {
struct dynamic;
}
#endif

namespace glow {
namespace runtime {
/// The HostManager serves as an entry point into the Runtime environment. It
/// provides an interface to add, run, and evict networks from the host. It
/// handles DeviceManager initialization, houses the Executor, and calls into
/// the Partitioner and Provisioner for network initialization.
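///
/// A minimal usage sketch (assumes a backend named "Interpreter" is available
/// and that the caller builds a Function named "myNetwork" in the module; all
/// names and the EXIT_ON_ERR error handling are illustrative only):
/// \code
///   auto configs = generateDeviceConfigs(1, "Interpreter");
///   HostManager hostManager(std::move(configs));
///
///   auto module = std::make_unique<Module>();
///   // ... build a Function named "myNetwork" inside the module ...
///   CompilationContext cctx;
///   EXIT_ON_ERR(hostManager.addNetwork(std::move(module), cctx));
///
///   PlaceholderBindings bindings;
///   // ... allocate and fill the network's input placeholders ...
///   EXIT_ON_ERR(hostManager.runNetworkBlocking("myNetwork", bindings));
/// \endcode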
class HostManager final {
  /// NetworkData contains data about each network in HostManager that is
  /// needed by the runtime.
  struct NetworkData {
    DAG dag{};
    /// Module that was used to create this network. Everything except
    /// placeholders and types has been removed from it.
    std::shared_ptr<Module> module{nullptr};

    /// Use an atomic refcount rather than just storing a shared_ptr, for
    /// thread safety.
    std::atomic<size_t> refcount{0};
  };
  /// Container for inference requests waiting in the queue.
  struct InferRequest {
    /// Name of the network the requested run is for.
    std::string networkName;

    /// The execution context for the request.
    std::unique_ptr<ExecutionContext> context;

    /// The user provided callback to run after execution finishes.
    ResultCBTy callback;

    /// The specified priority for the run.
    uint64_t priority;

    /// The runtime generated ID for the run request.
    uint64_t requestID;

    /// Timestamp for request creation.
    uint64_t startTime;

    /// Define the greater-than operator to allow sorting in the priority
    /// queue of requests. If the priority is the same, fall back to the order
    /// of submission.
    bool operator>(const InferRequest &inferReq) const {
      if (priority == inferReq.priority) {
        return requestID > inferReq.requestID;
      }
      return priority > inferReq.priority;
    }
    InferRequest(std::string networkName,
                 std::unique_ptr<ExecutionContext> context, ResultCBTy callback,
                 uint64_t priority, uint64_t requestID, uint64_t startTime = 0)
        : networkName{networkName}, context{std::move(context)},
          callback{callback}, priority{priority}, requestID{requestID},
          startTime{startTime} {}
  };

  /// Count of current in-flight networks being run. Atomic to allow
  /// concurrency in runNetwork.
  std::atomic<size_t> activeRequestCount_{0};

  /// Count of total requests; this is used as a run ID. Atomic to allow
  /// concurrency in runNetwork.
  std::atomic<size_t> totalRequestCount_{0};

  /// Priority queue for queued requests. This is a min-heap, so the lowest
  /// value is popped first.
  std::priority_queue<InferRequest, std::vector<InferRequest>,
                      std::greater<InferRequest>>
      inferQueue_;
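
  // Ordering sketch (illustrative values): with std::greater, the queue pops
  // the smallest element according to operator>, i.e. the request with the
  // lowest priority value, with ties broken by the earlier requestID:
  //   push {priority = 1, requestID = 7}
  //   push {priority = 0, requestID = 9}
  //   push {priority = 0, requestID = 3}
  //   pop order: {0, 3}, {0, 9}, {1, 7}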

  /// Lock for the priority queue above. Please make sure that whenever you
  /// want to access inferQueue_, you take a lock. Usage is the same as
  /// std::shared_mutex.
  std::shared_timed_mutex inferQueueLock_;

  /// Configuration parameters for this Runtime Host.
  HostConfig config_{};

  std::unique_ptr<TraceContext> hostTraceContext_;

  /// A map from a networkName to a network, which is represented by struct
  /// DAG.
  std::unordered_map<std::string, NetworkData> networks_;

  /// Mutex for networks_; since runNetwork, addNetwork, and removeNetwork can
  /// all be called concurrently, a guard is needed.
  std::shared_timed_mutex networkLock_;

  /// A map of DeviceManagers by deviceID. An ordered map is used here to
  /// allow a stable iteration order over devices.
  DeviceManagerMapTy devices_;

  /// A vector of devices available for new networks to be added to.
  std::vector<DeviceIDTy> availableDevices_;

  /// A single-threaded thread pool used by init() when initializing devices.
  ThreadPool threadPool_{1};

  /// Executor class; this handles dispatching execution requests to the
  /// appropriate device managers for an inference request.
  std::unique_ptr<Executor> executor_;

  /// The provisioner owns the compiledFunctions and handles loading functions
  /// onto the devices.
  std::unique_ptr<Provisioner> provisioner_;

  /// String const for logging the max queue size in glow.
  static constexpr const char *kMaxQueueSize = "glow.queue.max.size";

  /// String const for logging total device memory usage.
  static constexpr const char *kDeviceMemoryUsed =
      "glow.devices.used_memory.total";

  /// String const for logging total available device memory.
  static constexpr const char *kDeviceMemoryAvailable =
      "glow.devices.available_memory.total";

  /// String const for logging total maximum device memory.
  static constexpr const char *kDeviceMemoryMax =
      "glow.devices.maximum_memory.total";

  /// String const for logging device fatal errors.
  static constexpr const char *kDeviceFatalError =
      "glow.devices.fatal_compilation_error";

  /// Helper function to handle cleanup if an error occurs during addNetwork.
  /// This must be called while holding a lock on networkLock_.
  void cleanupAddNetwork(llvm::ArrayRef<std::string> names);

  /// Set of networks in the process of being added.
  std::set<std::string> processingNetworks_;

  /// Method to dispatch a new run to the executor.
  void dispatchNextRun();

  /// Method to calculate and export aggregate memory usage counters.
  void exportMemoryCounters();

  /// Queue size stat update.
  void reportCurrentQueueSize(int32_t queueSize);

  /// Execution stats update.
  void updateExecutionStats(uint64_t startTime,
                            std::unique_ptr<ExecutionContext> &context,
                            llvm::StringRef name, const Error &error);

  /// Keeps the stats exporter registry object alive until the destructor.
  std::shared_ptr<StatsExporterRegistry> statsExporterRegistry_;

  /// Default constructor.
  HostManager();

public:
  /// Constructor that takes configuration options.
  HostManager(const HostConfig &hostConfig);

  /// Constructor that takes a list of Devices to use.
  HostManager(std::vector<std::unique_ptr<DeviceConfig>> deviceConfigs);

  /// Constructor that takes both Devices and the configuration.
  HostManager(std::vector<std::unique_ptr<DeviceConfig>> deviceConfigs,
              const HostConfig &hostConfig);

  /// Adds the network to the host and does the necessary setup work. This
  /// includes partitioning, provisioning, compiling, and initializing
  /// backends. Additionally, DAGs are created for each function and stored in
  /// networks_. \returns an Error containing the results of the operation.
  /// This function consumes the \p module, so any pointers to data contained
  /// within the module should be considered invalid. The function is
  /// optimized based on \p cctx.
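  ///
  /// A minimal sketch (assumes an existing HostManager instance named
  /// hostManager and a module whose Functions have been built by the caller;
  /// EXIT_ON_ERR is illustrative error handling):
  /// \code
  ///   auto module = std::make_unique<Module>();
  ///   // ... build Functions in the module ...
  ///   CompilationContext cctx;
  ///   EXIT_ON_ERR(hostManager.addNetwork(std::move(module), cctx));
  ///   // The module has been consumed; do not use pointers into it.
  /// \endcode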
  Error addNetwork(std::unique_ptr<Module> module, CompilationContext &cctx);

  /// Adds the already partitioned FX \p FXIR network to the host and does the
  /// necessary setup work. This includes provisioning, compiling and
  /// initializing backends. Requires a DAG \p networks to be provided.
  /// \returns an Error containing the results of the operation. This function
  /// consumes the \p module so any pointers to data contained within the
  /// module should be considered invalid. The function is optimized based on
  /// \p cctx. Constants are provided with a stringmap \p constants.
#if FACEBOOK_INTERNAL
  Error addNetworkFX(std::unique_ptr<Module> module, CompilationContext &cctx,
                     DAGListTy &networks, const folly::dynamic &FXIR,
                     const llvm::StringMap<const void *> &constants);
#endif

  /// Given \p networkName removes that network from the host. This also
  /// removes the network from any backends set up to execute it.
  /// \returns an Error indicating success or failure of the operation.
  Error removeNetwork(llvm::StringRef networkName);

  /// Update the list of available devices.
  void setAvailableDevices(const std::vector<DeviceIDTy> &devices);

  /// Returns a string map containing the name and block-stream for all
  /// serialized functions.
  std::unique_ptr<
      std::unordered_map<std::string, std::unique_ptr<BlockStreamBase>>>
  getAllSerializedFunctions();

  /// For a given \p network returns all partitions of that network and the
  /// devices each partition is assigned to.
  std::unordered_map<std::string, std::vector<DeviceIDTy>>
  getDevicePartitionMapping(llvm::StringRef network);

  /// Returns true if \p networkName is already added to the host.
  bool networkAdded(llvm::StringRef networkName);

  /// Removes all networks from the host, and stops execution on all devices.
  Error clearHost();

  /// Runs the network specified by \p networkName using the provided
  /// \p context, and returns a runIdentifier which refers to the specific
  /// inference request. Calls \p callback with the results when inference is
  /// done.
  /// Note: This method is intended to be thread-safe; it will be called
  /// concurrently from multiple threads.
  /// Returns -1 if \p networkName is not found or there are too many active
  /// requests.
  /// The parameter \p priority is used to indicate queueing priority: the
  /// lowest number goes first, and in case of a tie the request that was
  /// submitted first goes first.
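  ///
  /// A minimal sketch of an asynchronous run (assumes an existing HostManager
  /// instance named hostManager; names are illustrative only):
  /// \code
  ///   auto context = std::make_unique<ExecutionContext>();
  ///   // ... fill inputs via context->getPlaceholderBindings() ...
  ///   hostManager.runNetwork(
  ///       "myNetwork", std::move(context),
  ///       [](RunIdentifierTy runId, Error err,
  ///          std::unique_ptr<ExecutionContext> resultContext) {
  ///         // Check err, then read outputs from resultContext.
  ///       });
  /// \endcode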
  RunIdentifierTy runNetwork(llvm::StringRef networkName,
                             std::unique_ptr<ExecutionContext> context,
                             ResultCBTy callback, uint64_t priority = 0);

  /// A wrapper around runNetwork that provides a blocking interface for an
  /// inference request. Runs the network provided in \p networkName using
  /// \p context. \returns an Error indicating success or failure. Upon
  /// return, regardless of success or failure, \p context will be filled with
  /// the return context from running the network.
  Error runNetworkBlocking(llvm::StringRef networkName,
                           std::unique_ptr<ExecutionContext> &context);

  /// A wrapper around runNetwork that provides a blocking interface for an
  /// inference request. Runs the network provided in \p networkName using
  /// \p bindings for placeholder bindings. \returns an Error indicating
  /// success or failure.
  Error runNetworkBlocking(llvm::StringRef networkName,
                           PlaceholderBindings &bindings);

  /// Initialize the HostManager with the given \p configs, creating one
  /// DeviceManager for each config listed.
  Error init(std::vector<std::unique_ptr<DeviceConfig>> configs);

  /// Get the network DAG for \p network if it exists.
  Expected<DAG *> getNetworkDAG(llvm::StringRef network);

  /// \returns a non-owning pointer to the TraceContext.
  TraceContext *getTraceContext() { return hostTraceContext_.get(); }

  /// Sets the TraceContext and \returns the existing value.
  std::unique_ptr<TraceContext>
  setTraceContext(std::unique_ptr<TraceContext> traceContext) {
    std::swap(hostTraceContext_, traceContext);
    return traceContext;
  }

  /// Triggers the start of tracing on all active devices. \returns an Error
  /// if this fails.
  Error startDeviceTrace();

  /// Triggers the stop of tracing on all active devices. \returns an Error if
  /// this fails.
  Error stopDeviceTrace();

  /// \returns a reference to the backend with name \p backendName owned by
  /// the Provisioner.
  Backend &getBackend(llvm::StringRef backendName) const;

  /// \returns a pointer to the Backend if exactly one Backend is found,
  /// otherwise returns an Error.
  Expected<Backend *> getBackend() const;

  /// \returns the number of devices the HostManager owns.
  size_t numDevices() const { return devices_.size(); }

  ~HostManager();

  /// String const for logging the current queue size in glow.
  static constexpr const char *kCurrentQueueSize10k =
      "glow.queue.current.occupancy.10k";
};

/// If the device config file specified in loadDeviceConfigsFileOpt is
/// available, load the configs from the file. Otherwise, create \p numDevices
/// devices based on \p backendName.
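///
/// For example (a sketch; "Interpreter" is an illustrative backend name):
/// \code
///   auto configs = generateDeviceConfigs(4, "Interpreter");
///   HostManager hostManager(std::move(configs));
/// \endcode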
std::vector<std::unique_ptr<runtime::DeviceConfig>>
generateDeviceConfigs(unsigned int numDevices, llvm::StringRef backendName,
                      size_t memSize = 0);

/// Attempts to load the user-specified DeviceConfigs file
/// \ref loadDeviceConfigsFileOpt. If the path exists then \p configs will be
/// loaded with DeviceConfigs given that file and \p memSize, and the function
/// \returns true. Otherwise \returns false with \p configs untouched.
bool loadDeviceConfigsFromFile(
    std::vector<std::unique_ptr<runtime::DeviceConfig>> &configs,
    size_t memSize);

/// Registry singleton for acquiring a HostManager.
class HostManagerRegistry final {
public:
  void registerHostManager(HostManager *hostManager);
  HostManager *getHostManager();

private:
  HostManager *hostManager_{nullptr};
};

/// Global singleton.
std::shared_ptr<HostManagerRegistry> ManagerRegistry();
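
// A minimal sketch of registry usage (the registry stores a raw pointer, so
// the registered HostManager must outlive users of the registry):
//   ManagerRegistry()->registerHostManager(&hostManager);
//   HostManager *hm = ManagerRegistry()->getHostManager();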

} // namespace runtime
} // namespace glow
#endif // GLOW_RUNTIME_HOSTMANAGER_HOSTMANAGER_H