/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace Eigen {
struct ThreadPoolDevice;
}

namespace tensorflow {

class RunHandler;

// RunHandlerPool is a fixed size pool of pre-allocated RunHandlers
// that can be used for tracking inter-op work for a given Session::Run().
// RunHandler(s) in the pool are initially 'inactive'. A RunHandler becomes
// 'active' when its unique_ptr is returned by Get() and is being used by a
// client. It becomes 'inactive' once more when its unique_ptr gets destroyed.
//
// Expected usage:
//
// * Create a single RunHandlerPool (say run_handler_pool_).
//
// * When a Session::Run() is invoked, obtain a handler by:
//   auto handler = run_handler_pool_->Get();
//
// * Use handler for scheduling all inter-op work by:
//   handler->ScheduleInterOpClosure(closure);
//
// This class is thread safe.
class RunHandlerPool {
 public:
  explicit RunHandlerPool(int num_inter_op_threads);

  RunHandlerPool(int num_inter_op_threads, int num_intra_op_threads);
  ~RunHandlerPool();

  // Returns an inactive RunHandler from the pool.
  //
  // RunHandlers in RunHandlerPool are initially 'inactive'.
  // A RunHandler becomes 'active' when its unique_ptr is returned by Get()
  // and is being used by a client. It becomes 'inactive' once more when the
  // unique_ptr is destroyed.
  //
  // Will block unless there is an inactive handler.
  std::unique_ptr<RunHandler> Get(
      int64_t step_id = 0, int64_t timeout_in_ms = 0,
      const RunOptions::Experimental::RunHandlerPoolOptions& options =
          RunOptions::Experimental::RunHandlerPoolOptions());

  // Gets the priorities of the active handlers. The returned vector is in the
  // same order as the active handler list.
  std::vector<int64_t> GetActiveHandlerPrioritiesForTesting() const;

 private:
  class Impl;
  friend class RunHandler;

  std::unique_ptr<Impl> impl_;
};
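
// Example of the lifecycle described above (a minimal, illustrative sketch;
// the pool sizes, step id and closure below are made-up values, not defaults
// defined by this header):
//
//   RunHandlerPool pool(/*num_inter_op_threads=*/16,
//                       /*num_intra_op_threads=*/4);
//   {
//     // Get() blocks until an inactive handler becomes available.
//     std::unique_ptr<RunHandler> handler = pool.Get(/*step_id=*/42);
//     handler->ScheduleInterOpClosure([]() { /* inter-op work */ });
//   }  // The handler is destroyed here and returns to the pool as 'inactive'.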

// RunHandler can be used to schedule inter/intra-op closures to run on a
// global pool shared across all Session::Run(s). The closures are enqueued to
// a handler-specific queue, from which work is stolen in priority order (based
// on the time of the Get() call).
//
// It can only be created via RunHandlerPool::Get().
//
// This class can be used instead of directly scheduling closures on a global
// pool since it maintains a global view across all sessions and optimizes pool
// scheduling to improve (median and tail) latency.
//
// This class is thread safe.
class RunHandler {
 public:
  void ScheduleInterOpClosure(std::function<void()> fn);
  thread::ThreadPoolInterface* AsIntraThreadPoolInterface();

  ~RunHandler();

 private:
  class Impl;
  friend class RunHandlerPool::Impl;

  explicit RunHandler(Impl* impl);

  Impl* impl_;  // NOT OWNED.
};
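
// Example of scoping intra-op parallelism to a single Session::Run() (a hedged
// sketch; it assumes thread::ThreadPoolInterface is Eigen-compatible, as it is
// in TensorFlow's threadpool headers, and the handler, closure and thread
// count below are illustrative):
//
//   auto handler = run_handler_pool_->Get();
//   handler->ScheduleInterOpClosure(inter_op_closure);
//   Eigen::ThreadPoolDevice intra_device(handler->AsIntraThreadPoolInterface(),
//                                        /*num_threads=*/4);
//   // 'intra_device' can now be passed to kernels / Eigen expressions so
//   // that their intra-op work is tracked by the same RunHandler.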

namespace internal {

// TODO(azaks): Refactor with thread::ThreadPool
class RunHandlerEnvironment {
  typedef Thread EnvThread;
  struct TaskImpl {
    std::function<void()> f;
    Context context;
    uint64 trace_id;
  };
  Env* const env_;
  const ThreadOptions thread_options_;
  const string name_;

 public:
  struct Task {
    std::unique_ptr<TaskImpl> f;
  };

  RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options,
                        const string& name);

  EnvThread* CreateThread(std::function<void()> f,
                          const std::string& thread_name);

  Task CreateTask(std::function<void()> f);

  void ExecuteTask(const Task& t);
};

typedef typename RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;

// To reduce cache misses, we use a doubly-linked list of Waiter structs and
// queue them in LIFO order rather than the FIFO order used by a single
// condition variable.
struct Waiter {
  Waiter() {
    next = this;
    prev = this;
  }
  condition_variable cv;
  mutex mu;
  Waiter* next;
  Waiter* prev;
};
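
// A minimal sketch of LIFO queuing on the circular list above (illustrative
// only; the real enqueue/notify logic lives in the .cc file and holds the
// owning mutex while touching the list; 'w' and 'head' are hypothetical
// locals):
//
//   // Insert 'w' right after the sentinel 'head', so the most recently
//   // queued waiter is the first one to be woken.
//   w->next = head->next;
//   w->prev = head;
//   head->next->prev = w;
//   head->next = w;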

class ThreadWorkSource {
 public:
  ThreadWorkSource();

  ~ThreadWorkSource();

  Task EnqueueTask(Task t, bool is_blocking);

  Task PopBlockingTask();

  Task PopNonBlockingTask(int start_index, bool search_from_all_queue);

  void WaitForWork(int max_sleep_micros);

  int TaskQueueSize(bool is_blocking);

  int64_t GetTracemeId();

  void SetTracemeId(int64_t value);

  void SetWaiter(uint64 version, Waiter* waiter, mutex* mutex);

  int64_t GetInflightTaskCount(bool is_blocking);

  void IncrementInflightTaskCount(bool is_blocking);

  void DecrementInflightTaskCount(bool is_blocking);

  unsigned NonBlockingWorkShardingFactor();

  std::string ToString();

 private:
  struct NonBlockingQueue {
    mutex queue_op_mu;
    char pad[128];
    Queue queue;
  };

  int32 non_blocking_work_sharding_factor_;
  Eigen::MaxSizeVector<NonBlockingQueue*> non_blocking_work_queues_;

  std::atomic<int64_t> blocking_inflight_;
  std::atomic<int64_t> non_blocking_inflight_;

  Queue blocking_work_queue_;
  mutex blocking_queue_op_mu_;
  char pad_[128];
  mutex waiters_mu_;
  Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_);
  std::atomic<int64_t> traceme_id_;

  mutex run_handler_waiter_mu_;
  uint64 version_ TF_GUARDED_BY(run_handler_waiter_mu_);
  mutex* sub_thread_pool_waiter_mu_ TF_GUARDED_BY(run_handler_waiter_mu_);
  Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_);
};
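
// Rough shape of how a worker thread might drain a ThreadWorkSource (a hedged
// sketch; waiting, sharding and stealing details are omitted, and 'tws' /
// 'env' are illustrative locals):
//
//   Task t = tws->PopBlockingTask();
//   if (t.f) {
//     tws->IncrementInflightTaskCount(/*is_blocking=*/true);
//     env.ExecuteTask(t);
//     tws->DecrementInflightTaskCount(/*is_blocking=*/true);
//   }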

class RunHandlerThreadPool {
 public:
  struct PerThread {
    constexpr PerThread() : pool(nullptr), thread_id(-1) {}
    RunHandlerThreadPool* pool;  // Parent pool, or null for normal threads.
    int thread_id;               // Worker thread index in pool.
  };

  RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads,
                       Env* env, const ThreadOptions& thread_options,
                       const string& name,
                       Eigen::MaxSizeVector<mutex>* waiters_mu,
                       Eigen::MaxSizeVector<Waiter>* queue_waiters);

  ~RunHandlerThreadPool();

  void Start();

  void StartOneThreadForTesting();

  void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking,
                      std::function<void()> fn);

  // Set work queues from which the thread 'tid' can steal its work.
  // The request with start_request_idx will be attempted first. Other requests
  // will be attempted in FIFO order based on their arrival time.
  void SetThreadWorkSources(
      int tid, int start_request_idx, uint64 version,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources);

  PerThread* GetPerThread();

  int CurrentThreadId() const;

  int NumThreads() const;

  int NumBlockingThreads() const;

  int NumNonBlockingThreads() const;

  void WorkerLoop(int thread_id, bool may_steal_blocking_work);

  // Searches for tasks from requests in the range searching_range_start to
  // searching_range_end. If there are no tasks in the search range and
  // may_steal_blocking_work is true, then searches from all requests.
  Task FindTask(
      int searching_range_start, int searching_range_end, int thread_id,
      int sub_thread_pool_id, int max_blocking_inflight,
      bool may_steal_blocking_work,
      const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
      bool* task_from_blocking_queue, ThreadWorkSource** tws);

  void WaitForWork(bool is_blocking, int thread_id,
                   int32_t max_blocking_inflight);

  void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id);

 private:
  struct ThreadData {
    ThreadData();
    mutex mu;
    uint64 new_version;
    condition_variable sources_not_empty;
    std::unique_ptr<Thread> thread;
    int current_index;
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        new_thread_work_sources TF_GUARDED_BY(mu);

    uint64 current_version;
    // Should only be accessed by one thread.
    std::unique_ptr<Eigen::MaxSizeVector<ThreadWorkSource*>>
        current_thread_work_sources;

    int sub_thread_pool_id;
  };

  const int num_threads_;
  const int num_blocking_threads_;
  const int num_non_blocking_threads_;
  Eigen::MaxSizeVector<ThreadData> thread_data_;
  internal::RunHandlerEnvironment env_;
  std::atomic<bool> cancelled_;
  string name_;
  Eigen::MaxSizeVector<mutex>* waiters_mu_;
  Eigen::MaxSizeVector<Waiter>* queue_waiters_;

  bool use_sub_thread_pool_;
  std::vector<int> num_threads_in_sub_thread_pool_;

  // Threads in each sub thread pool will search for tasks from the given
  // start_request_percentage to end_request_percentage in a round-robin
  // fashion (see the sketch after this class).
  std::vector<double> sub_thread_pool_start_request_percentage_;
  std::vector<double> sub_thread_pool_end_request_percentage_;
};
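
// How the percentage ranges above might map to a concrete search range over
// the active requests (a hedged sketch; the exact rounding and the round-robin
// starting point are implementation details of the .cc file, and 'pool_id' /
// 'num_requests' are illustrative locals):
//
//   int range_start = static_cast<int>(
//       sub_thread_pool_start_request_percentage_[pool_id] * num_requests);
//   int range_end = static_cast<int>(
//       sub_thread_pool_end_request_percentage_[pool_id] * num_requests);
//   // Threads in sub thread pool 'pool_id' then call FindTask() over
//   // [range_start, range_end), falling back to all requests if permitted.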

}  // namespace internal

}  // end namespace tensorflow.

#endif  // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_