/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "glow/Runtime/Executor/ThreadPoolExecutor.h"
#include "glow/Backends/DeviceManager.h"
#include "glow/ExecutionContext/ExecutionContext.h"
#include "glow/Runtime/ErrorReporter.h"

#include <queue>
#include <unordered_set>

#include "llvm/Support/FormatVariadic.h"
#include <glog/logging.h>

namespace glow {
namespace runtime {

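// InflightBarrier keeps a count of outstanding DeviceManager::runFunction()
// calls so that shutdown() can block until every in-flight call has returned
// and been fully processed.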
void InflightBarrier::decrement(unsigned decr) {
  std::unique_lock<std::mutex> lock(mtx_);
  DCHECK_GE(count_, decr)
      << "Barrier decrement cannot exceed the current count!";
  count_ -= decr;

  // If count_ has hit zero, wake up all threads that are waiting.
  if (count_ == 0) {
    cv_.notify_all();
  }
}

void InflightBarrier::increment(unsigned incr) {
  std::unique_lock<std::mutex> lock(mtx_);
  count_ += incr;
}

unsigned InflightBarrier::count() {
  std::unique_lock<std::mutex> lock(mtx_);
  return count_;
}

void InflightBarrier::wait() {
  std::unique_lock<std::mutex> lock(mtx_);
  // If count_ is not 0, wait until a signal is received that it is.
  // The second argument below is a predicate that returns true when
  // it is safe to wake up. It preserves correctness in the case of
  // spurious wakeups.
  cv_.wait(lock, [&] { return count_ == 0; });
}

ThreadPoolExecutor::ThreadPoolExecutor(const DeviceManagerMapTy &deviceManagers,
                                       unsigned numWorkers,
                                       const std::string &name)
    : threadPool_(numWorkers,
                  std::make_shared<folly::NamedThreadFactory>(name)),
      deviceManagers_(deviceManagers) {}

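// Refuse any new run() requests, then block until every in-flight
// DeviceManager::runFunction() call has been handled before stopping and
// joining the worker threads.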
void ThreadPoolExecutor::shutdown() {
  // Prevent more requests from being processed.
  shuttingDown_ = true;

  // Wait for all inflight DeviceManager::runFunction() calls to return and be
  // processed before starting to destroy state that is used in
  // handleDeviceManagerResult().
  inflightBarrier_.wait();

  threadPool_.stop();
  threadPool_.join();
}

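// Begin executing the DAG rooted at root: bind the caller's context and
// callback to a pooled NetworkExecutionState and dispatch every child of the
// root node. The callback cb fires exactly once, either immediately (refused
// request or null root) or from handleDeviceManagerResult() once the last
// node of the run has finished.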
void ThreadPoolExecutor::run(const DAGNode *root,
                             std::unique_ptr<ExecutionContext> context,
                             RunIdentifierTy runId, ResultCBTy cb) {
  DCHECK(cb != nullptr);

  TRACE_EVENT_SCOPE(context->getTraceContext(), TraceLevel::RUNTIME,
                    "ThreadPoolExecutor::run");

  if (context->getTraceContext()) {
    auto tid = threads::getThreadId();
    if (!context->getTraceContext()->getThreadNames().count(tid)) {
      context->getTraceContext()->setThreadName(tid, "ThreadPoolExecutor");
    }
  }

  // Don't process new requests if the executor is shutting down.
  if (shuttingDown_) {
    cb(runId,
       MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_REQUEST_REFUSED,
                "ThreadPoolExecutor is shutting down"),
       std::move(context));
    return;
  }

  // If the root is null, there is nothing to do. Give back the context so the
  // caller can reuse it.
  if (!root) {
    cb(runId, Error::success(), std::move(context));
    return;
  }

  auto numChildren = (root->children).size();
  // Mark the child nodes as "inflight" (i.e. currently executing). This must
  // be done here instead of inside executeDAGNode() so that a node can be
  // executed while placeholders are being propagated for the next node
  // without the callback for that node deleting the execution state.
  inflightBarrier_.increment(numChildren);

  auto *traceContext = context->getTraceContext();

  // Get and bind state.
  auto currentState = states_.rlock()->at(root)->getNextNetworkExecutionState();
  TRACE_EVENT_BEGIN(traceContext, TraceLevel::RUNTIME,
                    "bind network execution state");
  currentState->bind(std::move(context), std::move(cb), runId);
  TRACE_EVENT_END(traceContext, TraceLevel::RUNTIME,
                  "bind network execution state");

  currentState->incrementInflightNodes(numChildren);

  // End the trace block before calling executeDAGNode() which can trigger the
  // result cb. Once the result cb is called, it's no longer safe to access the
  // trace context.
  TRACE_EVENT_SCOPE_END();
  for (auto const &node : root->children) {
    // Run with cached state.
    executeDAGNode(currentState, node);
  }
}

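// Dispatch a single DAG node to a DeviceManager. If an error was already
// recorded for this run, or no DeviceManager can be found for the node, the
// node is simply marked as no longer in flight and nothing is dispatched.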
void ThreadPoolExecutor::executeDAGNode(NetworkExecutionState *executionState,
                                        DAGNode *node) {
  std::string traceScopeStr;
  bool tracingEnabled =
      !!executionState->getRawResultContextPtr()->getTraceContext();
  if (tracingEnabled) {
    traceScopeStr = llvm::formatv("ThreadPoolExecutor::executeDAGNode {0:x}",
                                  executionState->getRawResultContextPtr())
                        .str();
  }

  TRACE_EVENT_SCOPE(executionState->getRawResultContextPtr()->getTraceContext(),
                    TraceLevel::RUNTIME, traceScopeStr);

  if (executionState->getErrorContainer().containsErr()) {
    // Mark the node as no longer executing.
    executionState->decrementInflightNodes();
    inflightBarrier_.decrement();
    return;
  }

  // Get the ExecutionContext whose PlaceholderBindings contain all of the
  // inputs for the node.
  std::unique_ptr<ExecutionContext> nodeCtx =
      executionState->getUniqueNodeContextPtr(node);

  // Trace child node creation (to be able to identify function execution
  // origin).
  std::string traceNodeChildCreateStr;
  if (tracingEnabled) {
    traceNodeChildCreateStr = llvm::formatv(
        "ThreadPoolExecutor::executeDAGNode child node {0:x}", nodeCtx.get());
  }

  TRACE_EVENT_BEGIN(executionState->getRawResultContextPtr()->getTraceContext(),
                    TraceLevel::RUNTIME, traceNodeChildCreateStr);
  // Get the DeviceManager that can run the node.
  auto currentDevice = node->getNextDevice();
  auto deviceManagerIt = deviceManagers_.find(currentDevice);

  if (deviceManagerIt == deviceManagers_.end()) {
    // Mark the node as no longer executing.
    executionState->getErrorContainer().set(
        MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEVICE_NOT_FOUND,
                 "Cannot find the DeviceManager specified."));
    executionState->decrementInflightNodes();
    inflightBarrier_.decrement();
    return;
  }
  DeviceManager *deviceManager = deviceManagerIt->second.get();
  // If the context has a DeviceManager bound, use that instead.
  if (nodeCtx->getBoundDeviceManager()) {
    deviceManager = nodeCtx->getBoundDeviceManager();
  }

  // End the trace block before calling deviceManager->runFunction which can
  // trigger the result cb in a different thread. Once the result cb is called,
  // it's no longer safe to access the trace context.
  TRACE_EVENT_END(executionState->getRawResultContextPtr()->getTraceContext(),
                  TraceLevel::RUNTIME, traceNodeChildCreateStr);
  TRACE_EVENT_SCOPE_END();
  // Run the node using the DeviceManager.
  deviceManager->runFunction(
      node->getNextName(currentDevice), std::move(nodeCtx),
      [this, executionState, currentDevice,
       node](RunIdentifierTy id, Error err,
             std::unique_ptr<ExecutionContext> resultCtx) {
        TRACE_EVENT_LOG_ID(resultCtx->getTraceContext(), TraceLevel::REQUEST,
                           "handle result queuing", TraceEvent::AsyncBeginType,
                           TraceEvent::now(), id);

        // Immediately move the handling of the result onto this run's executor
        // to avoid doing work on the DeviceManager thread.
        threadPool_.add([this, executionState, node, err = std::move(err),
                         currentDevice, id,
                         ctx = std::move(resultCtx)]() mutable {
          TRACE_EVENT_LOG_ID(ctx->getTraceContext(), TraceLevel::REQUEST,
                             "handle result queuing", TraceEvent::AsyncEndType,
                             TraceEvent::now(), id);

          node->markFinished(currentDevice);
          this->handleDeviceManagerResult(executionState, std::move(err),
                                          std::move(ctx), node);
        });
      });
}

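// Handle the completion of one node: record any error, schedule children whose
// parents have all finished, and, once the last node of the run completes,
// invoke the run's callback and return the execution state to its pool.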
void ThreadPoolExecutor::handleDeviceManagerResult(
    NetworkExecutionState *executionState, Error err,
    std::unique_ptr<ExecutionContext> ctx, const DAGNode *node) {
  TraceContext *traceContext = ctx->getTraceContext();
  if (traceContext) {
    TRACE_EVENT_BEGIN(traceContext, TraceLevel::RUNTIME,
                      "ThreadPoolExecutor::handleResult");
  }

  auto runWasSuccess = !err;
  // Capture fatal-error information before err is moved into the error
  // container below; a moved-from Error no longer carries its error value.
  bool runHadFatalError =
      err.peekErrorValue() && err.peekErrorValue()->isFatalError();
  std::string fatalErrorMsg =
      runHadFatalError ? err.peekErrorValue()->logToString() : std::string();

  // Set the result code for the run.
  executionState->getErrorContainer().set(std::move(err));

  // If the DeviceManager executed the node, propagate its output Placeholders
  // to its children or the result PlaceholderBindings as appropriate.
  if (runWasSuccess) {
    for (auto &child : node->children) {
      // Execute any child that has no parent nodes left to execute.
      bool childReadyToExecute =
          executionState->incrementNodeParentsDone(child);
      if (childReadyToExecute) {
        // Mark the child as "inflight" (i.e. currently executing).
        executionState->incrementInflightNodes();
        inflightBarrier_.increment();
        executeDAGNode(executionState, child);
      }
    }
  } else if (runHadFatalError) {
    auto reporters = ErrorReporterRegistry::ErrorReporters();
    if (reporters) {
      reporters->report(fatalErrorMsg);
    }
    LOG(FATAL) << "Non-recoverable device error: " << fatalErrorMsg;
  }

  // Return intermediateContext to executionState.
  executionState->returnUniqueNodeContextPtr(node, std::move(ctx));

  // This needs to happen before decrementInflightNodes(). Otherwise a race
  // condition can happen where two threads call into this function at the same
  // time. Once decrementInflightNodes() is called, only the thread that gets
  // noNodesInflight == true can access executionState.
  if (traceContext) {
    TRACE_EVENT_END(traceContext, TraceLevel::RUNTIME,
                    "ThreadPoolExecutor::handleResult");
    executionState->insertIntoTraceContext(traceContext);
  }

  // Now, check if all nodes in the graph are done. If so, the callback can be
  // called and all state associated with the run can be erased.
  bool noNodesInflight = executionState->decrementInflightNodes();

  if (noNodesInflight) {
    // If there are no nodes inflight, all nodes are done. Call the callback
    // and erase the state information. Because inputs and outputs are
    // redirected to use the provided tensors, we do not have to transfer
    // outputs here. Once we have pinned memory, the transfer will happen here:
    // executionState->transferOutputs();
    ResultCBTy cb = executionState->getCallback();
    DCHECK(cb != nullptr);

    // Get what we need from the executionState and return it to the pool.
    auto runId = executionState->getRunId();
    auto err = executionState->getErrorContainer().get();
    auto resultCtx = executionState->getUniqueResultContextPtr();
    states_.rlock()
        ->at(executionState->getRoot())
        ->returnNetworkExecutionState(executionState);

    cb(runId, std::move(err), std::move(resultCtx));
  }

  // Decrement the inflight barrier for the executor keeping track of all
  // outstanding DeviceManager::runFunction() calls. This must be done here
  // instead of right after executionState->decrementInflightNodes() so that
  // ~ThreadPoolExecutor does not delete executor state before this function
  // is done using it (e.g. when erasing the ExecutionState object for a
  // run).
  inflightBarrier_.decrement();
}

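// Pre-allocate poolSize NetworkExecutionState objects for the DAG rooted at
// root so that run() can reuse them instead of building state per request.
// When P2P or DRT is enabled, each state also receives a static node-to-device
// assignment, rotating round-robin through the devices available to each node.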
void ThreadPoolExecutor::createPool(const DAGNode *root, unsigned poolSize,
                                    bool enableP2P, bool enableDRT) {
  std::unordered_map<DAGNode *, DeviceIDTy> assignment;

  // For static assignment, track the devices each node is assigned to.
  std::unordered_map<DAGNode *, std::vector<DeviceIDTy>> assignments;
  std::unordered_map<DAGNode *, unsigned> currentAssignment;
  if (enableP2P || enableDRT) {
    // Walk the nodes and get assignments.
    std::queue<DAGNode *> remaining;
    for (auto node : root->children) {
      remaining.push(node);
    }
    while (remaining.size()) {
      auto node = remaining.front();
      remaining.pop();
      // Add any new children to the queue.
      for (auto child : node->children) {
        auto it = assignments.find(child);
        if (it == assignments.end()) {
          remaining.push(child);
        }
      }
      // Record the devices this node can be assigned to.
      std::vector<DeviceIDTy> nodeDevices;
      for (auto dev : node->deviceRuntimeInfos) {
        nodeDevices.push_back(dev.first);
      }
      assignments[node] = nodeDevices;
      currentAssignment[node] = 0;
    }
  }

  std::unique_ptr<NetworkExecutionStatePool> pool =
      glow::make_unique<NetworkExecutionStatePool>();
  for (unsigned i = 0; i < poolSize; i++) {
    auto newState =
        glow::make_unique<NetworkExecutionState>(root, enableDRT, enableP2P);
    // If static assignment is enabled (P2P or DRT), calculate the device
    // assignments for this executionState. For now we assign devices in a
    // round-robin pattern per node.
    if (enableDRT || enableP2P) {
      for (auto it : currentAssignment) {
        auto &nodeAssignments = assignments.at(it.first);
        auto newAssignmentIdx = (it.second + 1) % nodeAssignments.size();
        auto newAssignment = nodeAssignments[newAssignmentIdx];
        assignment[it.first] = newAssignment;
        currentAssignment[it.first] = newAssignmentIdx;
      }
    }
    newState->init(deviceManagers_, assignment);
    pool->addNewState(std::move(newState));
  }

  states_.wlock()->emplace(root, std::move(pool));
}

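// Drop all pooled execution state for the DAG rooted at root.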
void ThreadPoolExecutor::freePool(const DAGNode *root) {
  states_.wlock()->erase(root);
}

} // namespace runtime
} // namespace glow