/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "glow/Runtime/Executor/ThreadPoolExecutor.h"
#include "glow/Backends/DeviceManager.h"
#include "glow/ExecutionContext/ExecutionContext.h"
#include "glow/Runtime/ErrorReporter.h"

#include <queue>
#include <unordered_set>

#include "llvm/Support/FormatVariadic.h"
#include <glog/logging.h>

namespace glow {
namespace runtime {

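// InflightBarrier usage within this file: run() and handleDeviceManagerResult()
// increment the barrier before dispatching a node to a DeviceManager, the
// completion paths decrement it, and shutdown() waits on it so that executor
// state is not destroyed while DeviceManager callbacks are still in flight.
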
void InflightBarrier::decrement(unsigned decr) {
  std::unique_lock<std::mutex> lock(mtx_);
  DCHECK_GE(count_, decr) << "Barrier decrement cannot exceed the count!";
  count_ -= decr;

  // If count_ has hit zero, wake up all threads that are waiting.
  if (count_ == 0) {
    cv_.notify_all();
  }
}

void InflightBarrier::increment(unsigned incr) {
  std::unique_lock<std::mutex> lock(mtx_);
  count_ += incr;
}

unsigned InflightBarrier::count() {
  std::unique_lock<std::mutex> lock(mtx_);
  return count_;
}

void InflightBarrier::wait() {
  std::unique_lock<std::mutex> lock(mtx_);
  // If count_ is not 0, wait until a signal is received that it is.
  // The second argument below is a predicate that returns true when
  // it is safe to wake up. It preserves correctness in the case of
  // spurious wakeups.
  cv_.wait(lock, [&] { return count_ == 0; });
}

ThreadPoolExecutor::ThreadPoolExecutor(const DeviceManagerMapTy &deviceManagers,
                                       unsigned numWorkers,
                                       const std::string &name)
    : threadPool_(numWorkers,
                  std::make_shared<folly::NamedThreadFactory>(name)),
      deviceManagers_(deviceManagers) {}

void ThreadPoolExecutor::shutdown() {
  // Prevent more requests from being processed.
  shuttingDown_ = true;

  // Wait for all inflight DeviceManager::runFunction() calls to return and be
  // processed before starting to destroy state that is used in
  // handleDeviceManagerResult().
  inflightBarrier_.wait();

  threadPool_.stop();
  threadPool_.join();
}

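// Runs the DAG rooted at root using the given context. The callback cb is
// invoked with runId and the accumulated Error once all work dispatched for
// this run has completed (whether successfully or not).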
void ThreadPoolExecutor::run(const DAGNode *root,
                             std::unique_ptr<ExecutionContext> context,
                             RunIdentifierTy runId, ResultCBTy cb) {
  DCHECK(cb != nullptr);

  TRACE_EVENT_SCOPE(context->getTraceContext(), TraceLevel::RUNTIME,
                    "ThreadPoolExecutor::run");

  if (context->getTraceContext()) {
    auto tid = threads::getThreadId();
    if (!context->getTraceContext()->getThreadNames().count(tid)) {
      context->getTraceContext()->setThreadName(tid, "ThreadPoolExecutor");
    }
  }

  // Don't process new requests if the executor is shutting down.
  if (shuttingDown_) {
    cb(runId,
       MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_REQUEST_REFUSED,
                "ThreadPoolExecutor is shutting down"),
       std::move(context));
    return;
  }

  // If the root is null, there is nothing to do. Give the context back to the
  // caller so it can be reused.
  if (!root) {
    cb(runId, Error::success(), std::move(context));
    return;
  }

  auto numChildren = (root->children).size();
  // Mark the child nodes as "inflight" (i.e. currently executing). This must
  // be done here instead of inside executeDAGNode() so that a node can be
  // executed while placeholders are being propagated for the next node
  // without the callback for that node deleting the execution state.
  inflightBarrier_.increment(numChildren);

  auto *traceContext = context->getTraceContext();

  // Get and bind state.
  auto currentState = states_.rlock()->at(root)->getNextNetworkExecutionState();
  TRACE_EVENT_BEGIN(traceContext, TraceLevel::RUNTIME,
                    "bind network execution state");
  currentState->bind(std::move(context), std::move(cb), runId);
  TRACE_EVENT_END(traceContext, TraceLevel::RUNTIME,
                  "bind network execution state");

  currentState->incrementInflightNodes(numChildren);

  // End the trace block before calling executeDAGNode() which can trigger the
  // result cb. Once the result cb is called, it's no longer safe to access the
  // trace context.
  TRACE_EVENT_SCOPE_END();
  for (auto const &node : root->children) {
    // Run with cached state
    executeDAGNode(currentState, node);
  }
}

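// Dispatches node to a DeviceManager on behalf of executionState. The target
// device comes from node->getNextDevice() unless the node's context is bound
// to a specific DeviceManager, and the DeviceManager's completion callback is
// re-queued onto this executor's thread pool so that result handling does not
// run on the DeviceManager's thread.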
void ThreadPoolExecutor::executeDAGNode(NetworkExecutionState *executionState,
                                        DAGNode *node) {
  std::string traceScopeStr;
  bool tracingEnabled =
      !!executionState->getRawResultContextPtr()->getTraceContext();
  if (tracingEnabled) {
    traceScopeStr = llvm::formatv("ThreadPoolExecutor::executeDAGNode {0:x}",
                                  executionState->getRawResultContextPtr())
                        .str();
  }

  TRACE_EVENT_SCOPE(executionState->getRawResultContextPtr()->getTraceContext(),
                    TraceLevel::RUNTIME, traceScopeStr);

  if (executionState->getErrorContainer().containsErr()) {
    // Mark the node as no longer executing.
    executionState->decrementInflightNodes();
    inflightBarrier_.decrement();
    return;
  }

  // Get the PlaceholderBindings containing all of the inputs for the node.
  std::unique_ptr<ExecutionContext> nodeCtx =
      executionState->getUniqueNodeContextPtr(node);

  // Trace child node creation (to be able to identify function execution
  // origin).
  std::string traceNodeChildCreateStr;
  if (tracingEnabled) {
    traceNodeChildCreateStr = llvm::formatv(
        "ThreadPoolExecutor::executeDAGNode child node {0:x}", nodeCtx.get());
  }

  TRACE_EVENT_BEGIN(executionState->getRawResultContextPtr()->getTraceContext(),
                    TraceLevel::RUNTIME, traceNodeChildCreateStr);
  // Get the DeviceManager that can run the node.
  auto currentDevice = node->getNextDevice();
  auto deviceManagerIt = deviceManagers_.find(currentDevice);

  if (deviceManagerIt == deviceManagers_.end()) {
    // Mark the node as no longer executing.
    executionState->getErrorContainer().set(
        MAKE_ERR(ErrorValue::ErrorCode::RUNTIME_DEVICE_NOT_FOUND,
                 "Cannot find the DeviceManager specified."));
    executionState->decrementInflightNodes();
    inflightBarrier_.decrement();
    return;
  }
  DeviceManager *deviceManager = deviceManagerIt->second.get();
  // If the context already has a DeviceManager bound, use that one instead.
  if (nodeCtx->getBoundDeviceManager()) {
    deviceManager = nodeCtx->getBoundDeviceManager();
  }

  // End the trace block before calling deviceManager->runFunction which can
  // trigger the result cb in a different thread. Once the result cb is called,
  // it's no longer safe to access the trace context.
  TRACE_EVENT_END(executionState->getRawResultContextPtr()->getTraceContext(),
                  TraceLevel::RUNTIME, traceNodeChildCreateStr);
  TRACE_EVENT_SCOPE_END();
  // Run the node using the DeviceManager.
  deviceManager->runFunction(
      node->getNextName(currentDevice), std::move(nodeCtx),
      [this, executionState, currentDevice,
       node](RunIdentifierTy id, Error err,
             std::unique_ptr<ExecutionContext> resultCtx) {
        TRACE_EVENT_LOG_ID(resultCtx->getTraceContext(), TraceLevel::REQUEST,
                           "handle result queuing", TraceEvent::AsyncBeginType,
                           TraceEvent::now(), id);

        // Immediately move the handling of the result onto this run's executor
        // to avoid doing work on the DeviceManager thread.
        threadPool_.add([this, executionState, node, err = std::move(err),
                         currentDevice, id,
                         ctx = std::move(resultCtx)]() mutable {
          TRACE_EVENT_LOG_ID(ctx->getTraceContext(), TraceLevel::REQUEST,
                             "handle result queuing", TraceEvent::AsyncEndType,
                             TraceEvent::now(), id);

          node->markFinished(currentDevice);
          this->handleDeviceManagerResult(executionState, std::move(err),
                                          std::move(ctx), node);
        });
      });
}

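// Handles the result of running node on a DeviceManager: records err for the
// run, schedules any children whose parents have all finished, returns the
// node's context to executionState, and, when the last inflight node
// completes, invokes the run's callback and returns the execution state to
// its pool.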
void ThreadPoolExecutor::handleDeviceManagerResult(
    NetworkExecutionState *executionState, Error err,
    std::unique_ptr<ExecutionContext> ctx, const DAGNode *node) {
  TraceContext *traceContext = ctx->getTraceContext();
  if (traceContext) {
    TRACE_EVENT_BEGIN(traceContext, TraceLevel::RUNTIME,
                      "ThreadPoolExecutor::handleResult");
  }

  auto runWasSuccess = !err;
  // Check for a fatal (non-recoverable) error before err is moved into the
  // run's error container below; after the move, err no longer holds the
  // error value.
  bool runHadFatalError =
      err.peekErrorValue() && err.peekErrorValue()->isFatalError();
  std::string fatalErrorMsg =
      runHadFatalError ? err.peekErrorValue()->logToString() : std::string();

  // Set the result code for the run.
  executionState->getErrorContainer().set(std::move(err));

  // If the DeviceManager executed the node, propagate its output Placeholders
  // to its children or the result PlaceholderBindings as appropriate.
  if (runWasSuccess) {
    for (auto &child : node->children) {
      // Execute any child that has no parent nodes left to execute.
      bool childReadyToExecute =
          executionState->incrementNodeParentsDone(child);
      if (childReadyToExecute) {
        // Mark the node as "inflight" (i.e. currently executing).
        executionState->incrementInflightNodes();
        inflightBarrier_.increment();
        executeDAGNode(executionState, child);
      }
    }
  } else if (runHadFatalError) {
    auto reporters = ErrorReporterRegistry::ErrorReporters();
    if (reporters) {
      reporters->report(fatalErrorMsg);
    }
    LOG(FATAL) << "Non-recoverable device error: " << fatalErrorMsg;
  }

  // Return intermediateContext to executionState.
  executionState->returnUniqueNodeContextPtr(node, std::move(ctx));

  // This needs to happen before decrementInflightNodes(). Otherwise a race
  // condition can happen where two threads call into this function at the same
  // time. Once decrementInflightNodes() is called, only the thread that gets
  // noNodesInflight == true can access executionState.
  if (traceContext) {
    TRACE_EVENT_END(traceContext, TraceLevel::RUNTIME,
                    "ThreadPoolExecutor::handleResult");
    executionState->insertIntoTraceContext(traceContext);
  }

  // Now, check if all nodes in the graph are done. If so, the callback can be
  // called and all state associated with the run can be erased.
  bool noNodesInflight = executionState->decrementInflightNodes();

  if (noNodesInflight) {
    // If there are no nodes inflight, all nodes are done: call the callback
    // and erase the state information. Because inputs and outputs are
    // redirected to use the provided tensors, no output transfer is needed
    // here; once pinned memory is used, outputs will be transferred
    // (previously: executionState->transferOutputs()).
    ResultCBTy cb = executionState->getCallback();
    DCHECK(cb != nullptr);

    // Get what we need from the executionState and return it to the pool.
    auto runId = executionState->getRunId();
    auto err = executionState->getErrorContainer().get();
    auto resultCtx = executionState->getUniqueResultContextPtr();
    states_.rlock()
        ->at(executionState->getRoot())
        ->returnNetworkExecutionState(executionState);

    cb(runId, std::move(err), std::move(resultCtx));
  }

  // Decrement the inflight barrier for the executor keeping track of all
  // outstanding DeviceManager::runFunction() calls. This must be done here
  // instead of right after executionState->decrementInflightNodes() so that
  // ~ThreadPoolExecutor does not delete executor state before this function
  // is done using it (e.g. when erasing the ExecutionState object for a
  // run).
  inflightBarrier_.decrement();
}

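// Builds a pool of poolSize NetworkExecutionStates for the DAG rooted at
// root. When DRT or P2P is enabled, each node's candidate devices are
// collected up front and successive states are given per-node device
// assignments in a round-robin fashion.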
void ThreadPoolExecutor::createPool(const DAGNode *root, unsigned poolSize,
                                    bool enableP2P, bool enableDRT) {
  std::unordered_map<DAGNode *, DeviceIDTy> assignment;

  // For static assignment we need to track which devices each node is
  // assigned to.
  std::unordered_map<DAGNode *, std::vector<DeviceIDTy>> assignments;
  std::unordered_map<DAGNode *, unsigned> currentAssignment;
  if (enableP2P || enableDRT) {
    // Walk the nodes and get assignments.
    std::queue<DAGNode *> remaining;
    for (auto node : root->children) {
      remaining.push(node);
    }
    while (remaining.size()) {
      auto node = remaining.front();
      remaining.pop();
      // Add any new children to the queue.
      for (auto child : node->children) {
        auto it = assignments.find(child);
        if (it == assignments.end()) {
          remaining.push(child);
        }
      }
      std::vector<DeviceIDTy> assignment;
      for (auto dev : node->deviceRuntimeInfos) {
        assignment.push_back(dev.first);
      }
      assignments[node] = assignment;
      currentAssignment[node] = 0;
    }
  }

  std::unique_ptr<NetworkExecutionStatePool> pool =
      glow::make_unique<NetworkExecutionStatePool>();
  for (unsigned i = 0; i < poolSize; i++) {
    auto newState =
        glow::make_unique<NetworkExecutionState>(root, enableDRT, enableP2P);
    // If static assignment is enabled (enableDRT or enableP2P), calculate the
    // device assignments for this executionState. For now we assign devices
    // in a round-robin pattern per node; see the example below.
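    // Example (a sketch of the round-robin below): if a node's assignment
    // vector is [d0, d1, d2], successive execution states in this pool are
    // assigned d1, d2, d0, d1, ... for that node.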
    if (enableDRT || enableP2P) {
      for (auto it : currentAssignment) {
        auto &nodeAssignments = assignments.at(it.first);
        auto newAssignmentIdx = (it.second + 1) % nodeAssignments.size();
        auto newAssignment = nodeAssignments[newAssignmentIdx];
        assignment[it.first] = newAssignment;
        currentAssignment[it.first] = newAssignmentIdx;
      }
    }
    newState->init(deviceManagers_, assignment);
    pool->addNewState(std::move(newState));
  }

  states_.wlock()->emplace(root, std::move(pool));
}
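
// Note: run() looks the DAG root up in states_ with at(), so createPool() must
// be called for a root before any run() that uses it, and freePool() removes
// that entry once the network is no longer needed.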

void ThreadPoolExecutor::freePool(const DAGNode *root) {
  states_.wlock()->erase(root);
}

} // namespace runtime
} // namespace glow