1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_RUNTIME_THREAD_POOL_EXECUTOR_H
17#define GLOW_RUNTIME_THREAD_POOL_EXECUTOR_H
18
#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
22
23#include "NetworkExecutionState.h"
24#include "folly/Synchronized.h"
25#include "folly/executors/CPUThreadPoolExecutor.h"
26#include "glow/Runtime/Executor/Executor.h"
27
28namespace glow {
29namespace runtime {
30
/// Forward declaration. NOTE(review): no declaration in this header refers to
/// ExecutionState -- confirm whether it is still needed.
class ExecutionState;
32
/// A minimal counting barrier: callers adjust a shared count of threads, and
/// wait() lets a caller hold until every thread has left a certain section of
/// code.
class InflightBarrier final {
public:
  /// Raise the count of threads in the barrier by \p incr (1 if omitted).
  void increment(unsigned incr = 1);

  /// Lower the count of threads in the barrier by \p decr (1 if omitted).
  void decrement(unsigned decr = 1);

  /// Hold the caller until the barrier count reaches zero. This call may
  /// block.
  void wait();

  /// \returns the barrier's current count.
  unsigned count();

private:
  /// Number of threads currently inside the barrier; starts at zero.
  unsigned count_ = 0;
  /// Mutex used when accessing count_.
  std::mutex mtx_;
  /// Condition variable used to implement wait().
  std::condition_variable cv_;
};
58
59/// This implementation of the Executor interface uses a thread pool to
60/// handle and process multiple concurrent execution runs.
61class ThreadPoolExecutor final : public Executor {
62public:
63 /// Constructor.
64 explicit ThreadPoolExecutor(const DeviceManagerMapTy &deviceManagers,
65 unsigned numWorkers = kNumWorkers,
66 const std::string &name = "");
67
68 /// Setup context pool for new network.
69 void createPool(const DAGNode *root, unsigned poolSize, bool enableP2P,
70 bool enableDRT) override;
71
72 /// Free the context pool for specified network.
73 void freePool(const DAGNode *root) override;
74
75 /// See Executor::run. A particular invocation is specified completely by
76 /// the triple (roots, bindings, runId).
77 void run(const DAGNode *root, std::unique_ptr<ExecutionContext> context,
78 RunIdentifierTy runId, ResultCBTy cb) override;
79
80 ~ThreadPoolExecutor() override { shutdown(); }
81
82 void shutdown() override;
83
84private:
85 /// Execute the DAG node specified by \p node within the run corresponding to
86 /// \p state.
87 void executeDAGNode(NetworkExecutionState *executionState, DAGNode *node);
88
89 /// Handle the result returned asynchronously by the DeviceManager.
90 /// \p executionState is tracks the state of the run that the node that
91 /// finished executing belongs to, \p err is the Error returned by the
92 /// DeviceManager, \p ctx is the ExecutionContext that contains the outputs
93 /// produced by \p node during the run.
94 ///
95 /// The main purpose of this function is to help move computation off of the
96 /// DeviceManager thread pool on onto the one owned by this class.
97 void handleDeviceManagerResult(NetworkExecutionState *executionState,
98 Error err,
99 std::unique_ptr<ExecutionContext> ctx,
100 const DAGNode *node);
101
102 /// The default number of workers in the thread pool.
103 constexpr static unsigned kNumWorkers = 3;
104 /// The thread pool used to drive execution.
105 folly::CPUThreadPoolExecutor threadPool_;
106
107 /// Map of networkExecutionState pools for each network.
108 folly::Synchronized<std::unordered_map<
109 const DAGNode *, std::unique_ptr<NetworkExecutionStatePool>>>
110 states_;
111
112 /// Barrier for making sure all asynchronous requests made to the
113 /// DeviceManager return before allowing destruction of the executor.
114 InflightBarrier inflightBarrier_;
115 /// Whether the executor is currently shutting down or not.
116 std::atomic<bool> shuttingDown_{false};
117
118 /// Map of available DeviceManagers.
119 const DeviceManagerMapTy &deviceManagers_;
120};
121
122} // namespace runtime
123} // namespace glow
124#endif
125