1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_ |
18 | |
19 | #include "tensorflow/core/lib/core/threadpool.h" |
20 | #include "tensorflow/core/lib/histogram/histogram.h" |
21 | #include "tensorflow/core/platform/context.h" |
22 | #include "tensorflow/core/platform/mutex.h" |
23 | #include "tensorflow/core/platform/thread_annotations.h" |
24 | #include "tensorflow/core/protobuf/config.pb.h" |
25 | |
// Forward declaration so that Eigen::ThreadPoolDevice can be named without
// requiring its full definition here.
namespace Eigen {
struct ThreadPoolDevice;
}  // namespace Eigen
29 | |
30 | namespace tensorflow { |
31 | |
// Declared up front because RunHandlerPool names RunHandler (as the return
// type of Get() and as a friend) before the class itself is defined.
class RunHandler;
33 | |
34 | // RunHandlerPool is a fixed size pool of pre-allocated RunHandlers |
35 | // that can be used for tracking inter-op work for a given Session::Run(). |
36 | // RunHandler(s) in the pool are initially 'inactive'. A RunHandler becomes |
37 | // 'active' when its unique_ptr is returned by Get() and is being used by a |
38 | // client. It becomes 'inactive' once more when its unique_ptr gets destroyed. |
39 | // |
40 | // Expected usage: |
41 | // |
42 | // * Create a single RunHandlerPool (say run_handler_pool_). |
43 | // |
44 | // * When a Session::Run() is invoked, obtain a handler by: |
45 | // auto handler = run_handler_pool_->Get(); |
46 | // |
47 | // * Use handler for scheduling all inter-op work by: |
48 | // handler->ScheduleInterOpClosure(closure); |
49 | // |
50 | // This class is thread safe. |
51 | class RunHandlerPool { |
52 | public: |
53 | explicit RunHandlerPool(int num_inter_op_threads); |
54 | |
55 | RunHandlerPool(int num_inter_op_threads, int num_intra_op_threads); |
56 | ~RunHandlerPool(); |
57 | |
58 | // Returns an inactive RunHandler from the pool. |
59 | // |
60 | // RunHandlers in RunHandlerPool are initially 'inactive'. |
61 | // A RunHandler becomes 'active' when its unique_ptr its returned by Get() |
62 | // and is being used by a client. It becomes 'inactive' once more when the |
63 | // unique_ptr is destroyed. |
64 | // |
65 | // Will block unless there is an inactive handler. |
66 | std::unique_ptr<RunHandler> Get( |
67 | int64_t step_id = 0, int64_t timeout_in_ms = 0, |
68 | const RunOptions::Experimental::RunHandlerPoolOptions& options = |
69 | RunOptions::Experimental::RunHandlerPoolOptions()); |
70 | |
71 | // Get the priorities for active handlers. The return result is with the same |
72 | // order of the active handler list. |
73 | std::vector<int64_t> GetActiveHandlerPrioritiesForTesting() const; |
74 | |
75 | private: |
76 | class Impl; |
77 | friend class RunHandler; |
78 | |
79 | std::unique_ptr<Impl> impl_; |
80 | }; |
81 | |
82 | // RunHandler can be used to schedule inter/intra-op closures to run on a global |
83 | // pool shared across all Session::Run(s). The closures are enqueued to a |
84 | // handler specific queue, from which the work is stolen in a priority order |
85 | // (time of the Get() call). |
86 | // |
87 | // It can only be created via RunHandlerPool::Get(). |
88 | // |
89 | // This class can be used instead of directly scheduling closures on a global |
90 | // pool since it maintains a global view across all sessions and optimizes pool |
91 | // scheduling to improve (median and tail) latency. |
92 | // |
93 | // This class is thread safe. |
94 | class RunHandler { |
95 | public: |
96 | void ScheduleInterOpClosure(std::function<void()> fn); |
97 | thread::ThreadPoolInterface* AsIntraThreadPoolInterface(); |
98 | |
99 | ~RunHandler(); |
100 | |
101 | private: |
102 | class Impl; |
103 | friend class RunHandlerPool::Impl; |
104 | |
105 | explicit RunHandler(Impl* impl); |
106 | |
107 | Impl* impl_; // NOT OWNED. |
108 | }; |
109 | |
110 | namespace internal { |
111 | |
112 | // TODO(azaks): Refactor with thread:ThreadPool |
113 | class RunHandlerEnvironment { |
114 | typedef Thread EnvThread; |
115 | struct TaskImpl { |
116 | std::function<void()> f; |
117 | Context context; |
118 | uint64 trace_id; |
119 | }; |
120 | Env* const env_; |
121 | const ThreadOptions thread_options_; |
122 | const string name_; |
123 | |
124 | public: |
125 | struct Task { |
126 | std::unique_ptr<TaskImpl> f; |
127 | }; |
128 | |
129 | RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options, |
130 | const string& name); |
131 | |
132 | EnvThread* CreateThread(std::function<void()> f, |
133 | const std::string& thread_name); |
134 | |
135 | Task CreateTask(std::function<void()> f); |
136 | |
137 | void ExecuteTask(const Task& t); |
138 | }; |
139 | |
140 | typedef typename RunHandlerEnvironment::Task Task; |
141 | typedef Eigen::RunQueue<Task, 1024> Queue; |
142 | |
// Node of a doubly-linked list of waiters. To reduce cache misses, waiters are
// queued in LIFO order rather than the FIFO order used by a single condition
// variable.
struct Waiter {
  // A fresh node is linked to itself in both directions.
  Waiter() : next(this), prev(this) {}

  condition_variable cv;
  mutex mu;
  Waiter* next;
  Waiter* prev;
};
156 | |
157 | class ThreadWorkSource { |
158 | public: |
159 | ThreadWorkSource(); |
160 | |
161 | ~ThreadWorkSource(); |
162 | |
163 | Task EnqueueTask(Task t, bool is_blocking); |
164 | |
165 | Task PopBlockingTask(); |
166 | |
167 | Task PopNonBlockingTask(int start_index, bool search_from_all_queue); |
168 | |
169 | void WaitForWork(int max_sleep_micros); |
170 | |
171 | int TaskQueueSize(bool is_blocking); |
172 | |
173 | int64_t GetTracemeId(); |
174 | |
175 | void SetTracemeId(int64_t value); |
176 | |
177 | void SetWaiter(uint64 version, Waiter* waiter, mutex* mutex); |
178 | |
179 | int64_t GetInflightTaskCount(bool is_blocking); |
180 | |
181 | void IncrementInflightTaskCount(bool is_blocking); |
182 | |
183 | void DecrementInflightTaskCount(bool is_blocking); |
184 | |
185 | unsigned NonBlockingWorkShardingFactor(); |
186 | |
187 | std::string ToString(); |
188 | |
189 | private: |
190 | struct NonBlockingQueue { |
191 | mutex queue_op_mu; |
192 | char pad[128]; |
193 | Queue queue; |
194 | }; |
195 | |
196 | int32 non_blocking_work_sharding_factor_; |
197 | Eigen::MaxSizeVector<NonBlockingQueue*> non_blocking_work_queues_; |
198 | |
199 | std::atomic<int64_t> blocking_inflight_; |
200 | std::atomic<int64_t> non_blocking_inflight_; |
201 | |
202 | Queue blocking_work_queue_; |
203 | mutex blocking_queue_op_mu_; |
204 | char pad_[128]; |
205 | mutex waiters_mu_; |
206 | Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_); |
207 | std::atomic<int64_t> traceme_id_; |
208 | |
209 | mutex run_handler_waiter_mu_; |
210 | uint64 version_ TF_GUARDED_BY(run_handler_waiter_mu_); |
211 | mutex* sub_thread_pool_waiter_mu_ TF_GUARDED_BY(run_handler_waiter_mu_); |
212 | Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_); |
213 | }; |
214 | |
215 | class RunHandlerThreadPool { |
216 | public: |
217 | struct PerThread { |
218 | constexpr PerThread() : pool(nullptr), thread_id(-1) {} |
219 | RunHandlerThreadPool* pool; // Parent pool, or null for normal threads. |
220 | int thread_id; // Worker thread index in pool. |
221 | }; |
222 | |
223 | RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads, |
224 | Env* env, const ThreadOptions& thread_options, |
225 | const string& name, |
226 | Eigen::MaxSizeVector<mutex>* waiters_mu, |
227 | Eigen::MaxSizeVector<Waiter>* queue_waiters); |
228 | |
229 | ~RunHandlerThreadPool(); |
230 | |
231 | void Start(); |
232 | |
233 | void StartOneThreadForTesting(); |
234 | |
235 | void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking, |
236 | std::function<void()> fn); |
237 | |
238 | // Set work queues from which the thread 'tid' can steal its work. |
239 | // The request with start_request_idx will be attempted first. Other requests |
240 | // will be attempted in FIFO order based on their arrival time. |
241 | void SetThreadWorkSources( |
242 | int tid, int start_request_idx, uint64 version, |
243 | const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources); |
244 | |
245 | PerThread* GetPerThread(); |
246 | |
247 | int CurrentThreadId() const; |
248 | |
249 | int NumThreads() const; |
250 | |
251 | int NumBlockingThreads() const; |
252 | |
253 | int NumNonBlockingThreads() const; |
254 | |
255 | void WorkerLoop(int thread_id, bool may_steal_blocking_work); |
256 | |
257 | // Search tasks from Requets range searching_range_start to |
258 | // searching_range_end. If there is no tasks in the search range and |
259 | // may_steal_blocking_work is true, then search from all requests. |
260 | Task FindTask( |
261 | int searching_range_start, int searching_range_end, int thread_id, |
262 | int sub_thread_pool_id, int max_blocking_inflight, |
263 | bool may_steal_blocking_work, |
264 | const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources, |
265 | bool* task_from_blocking_queue, ThreadWorkSource** tws); |
266 | |
267 | void WaitForWork(bool is_blocking, int thread_id, |
268 | int32_t max_blocking_inflight); |
269 | |
270 | void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id); |
271 | |
272 | private: |
273 | struct ThreadData { |
274 | ThreadData(); |
275 | mutex mu; |
276 | uint64 new_version; |
277 | condition_variable sources_not_empty; |
278 | std::unique_ptr<Thread> thread; |
279 | int current_index; |
280 | std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>> |
281 | new_thread_work_sources TF_GUARDED_BY(mu); |
282 | |
283 | uint64 current_version; |
284 | // Should only be accessed by one thread. |
285 | std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>> |
286 | current_thread_work_sources; |
287 | |
288 | int sub_thread_pool_id; |
289 | }; |
290 | |
291 | const int num_threads_; |
292 | const int num_blocking_threads_; |
293 | const int num_non_blocking_threads_; |
294 | Eigen::MaxSizeVector<ThreadData> thread_data_; |
295 | internal::RunHandlerEnvironment env_; |
296 | std::atomic<bool> cancelled_; |
297 | string name_; |
298 | Eigen::MaxSizeVector<mutex>* waiters_mu_; |
299 | Eigen::MaxSizeVector<Waiter>* queue_waiters_; |
300 | |
301 | bool use_sub_thread_pool_; |
302 | std::vector<int> num_threads_in_sub_thread_pool_; |
303 | |
304 | // Threads in each sub thread pool will search tasks from the given |
305 | // start_request_percentage to end_request_percentage in a round robin |
306 | // fashion. |
307 | std::vector<double> sub_thread_pool_start_request_percentage_; |
308 | std::vector<double> sub_thread_pool_end_request_percentage_; |
309 | }; |
310 | |
311 | } // namespace internal |
312 | |
313 | } // end namespace tensorflow. |
314 | |
315 | #endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_ |
316 | |