1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_RUNTIME_THREAD_POOL_EXECUTOR_H |
17 | #define GLOW_RUNTIME_THREAD_POOL_EXECUTOR_H |
18 | |
#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

#include "NetworkExecutionState.h"
#include "folly/Synchronized.h"
#include "folly/executors/CPUThreadPoolExecutor.h"
#include "glow/Runtime/Executor/Executor.h"
27 | |
28 | namespace glow { |
29 | namespace runtime { |
30 | |
/// Forward declaration only; the type is not defined in this header.
/// NOTE(review): nothing else in this header refers to ExecutionState (the
/// declarations below use NetworkExecutionState) -- confirm whether this
/// forward declaration is still needed before relying on or removing it.
class ExecutionState;
32 | |
/// A simple barrier that lets a caller wait until every thread has left a
/// particular section of code. Threads are accounted for with increment() and
/// decrement(); wait() holds the caller until the count is back at zero.
class InflightBarrier final {
public:
  /// Raise the number of threads inside the barrier by \p incr (one by
  /// default).
  void increment(unsigned incr = 1);

  /// Lower the number of threads inside the barrier by \p decr (one by
  /// default).
  void decrement(unsigned decr = 1);

  /// Hold the caller until the barrier count reaches zero. This call may
  /// block.
  void wait();

  /// \returns the barrier's current count.
  unsigned count();

private:
  /// How many threads are currently inside the barrier.
  unsigned count_ = 0;
  /// Mutex used when accessing count_.
  std::mutex mtx_;
  /// Condition variable that wait() is built on.
  std::condition_variable cv_;
};
58 | |
59 | /// This implementation of the Executor interface uses a thread pool to |
60 | /// handle and process multiple concurrent execution runs. |
61 | class ThreadPoolExecutor final : public Executor { |
62 | public: |
63 | /// Constructor. |
64 | explicit ThreadPoolExecutor(const DeviceManagerMapTy &deviceManagers, |
65 | unsigned numWorkers = kNumWorkers, |
66 | const std::string &name = "" ); |
67 | |
68 | /// Setup context pool for new network. |
69 | void createPool(const DAGNode *root, unsigned poolSize, bool enableP2P, |
70 | bool enableDRT) override; |
71 | |
72 | /// Free the context pool for specified network. |
73 | void freePool(const DAGNode *root) override; |
74 | |
75 | /// See Executor::run. A particular invocation is specified completely by |
76 | /// the triple (roots, bindings, runId). |
77 | void run(const DAGNode *root, std::unique_ptr<ExecutionContext> context, |
78 | RunIdentifierTy runId, ResultCBTy cb) override; |
79 | |
80 | ~ThreadPoolExecutor() override { shutdown(); } |
81 | |
82 | void shutdown() override; |
83 | |
84 | private: |
85 | /// Execute the DAG node specified by \p node within the run corresponding to |
86 | /// \p state. |
87 | void executeDAGNode(NetworkExecutionState *executionState, DAGNode *node); |
88 | |
89 | /// Handle the result returned asynchronously by the DeviceManager. |
90 | /// \p executionState is tracks the state of the run that the node that |
91 | /// finished executing belongs to, \p err is the Error returned by the |
92 | /// DeviceManager, \p ctx is the ExecutionContext that contains the outputs |
93 | /// produced by \p node during the run. |
94 | /// |
95 | /// The main purpose of this function is to help move computation off of the |
96 | /// DeviceManager thread pool on onto the one owned by this class. |
97 | void handleDeviceManagerResult(NetworkExecutionState *executionState, |
98 | Error err, |
99 | std::unique_ptr<ExecutionContext> ctx, |
100 | const DAGNode *node); |
101 | |
102 | /// The default number of workers in the thread pool. |
103 | constexpr static unsigned kNumWorkers = 3; |
104 | /// The thread pool used to drive execution. |
105 | folly::CPUThreadPoolExecutor threadPool_; |
106 | |
107 | /// Map of networkExecutionState pools for each network. |
108 | folly::Synchronized<std::unordered_map< |
109 | const DAGNode *, std::unique_ptr<NetworkExecutionStatePool>>> |
110 | states_; |
111 | |
112 | /// Barrier for making sure all asynchronous requests made to the |
113 | /// DeviceManager return before allowing destruction of the executor. |
114 | InflightBarrier inflightBarrier_; |
115 | /// Whether the executor is currently shutting down or not. |
116 | std::atomic<bool> shuttingDown_{false}; |
117 | |
118 | /// Map of available DeviceManagers. |
119 | const DeviceManagerMapTy &deviceManagers_; |
120 | }; |
121 | |
122 | } // namespace runtime |
123 | } // namespace glow |
124 | #endif |
125 | |