/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/run_handler.h"

#include <algorithm>
#include <cmath>
#include <list>
#include <memory>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/run_handler_util.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace {
// LINT.IfChange
static constexpr int32_t kMaxConcurrentHandlers = 128;
// LINT.ThenChange(//tensorflow/core/framework/run_handler_test.cc)

typedef typename internal::RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;

}  // namespace

namespace internal {
RunHandlerEnvironment::RunHandlerEnvironment(
    Env* env, const ThreadOptions& thread_options, const string& name)
    : env_(env), thread_options_(thread_options), name_(name) {}

RunHandlerEnvironment::EnvThread* RunHandlerEnvironment::CreateThread(
    std::function<void()> f, const std::string& thread_name) {
  return env_->StartThread(thread_options_, thread_name, [=]() {
    // Set the processor flag to flush denormals to zero.
    port::ScopedFlushDenormal flush;
    // Set the processor rounding mode to ROUND TO NEAREST.
    port::ScopedSetRound round(FE_TONEAREST);
    if (thread_options_.numa_node != port::kNUMANoAffinity) {
      port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
    }
    f();
  });
}

RunHandlerEnvironment::Task RunHandlerEnvironment::CreateTask(
    std::function<void()> f) {
  uint64 id = 0;
  if (tracing::EventCollector::IsEnabled()) {
    id = tracing::GetUniqueArg();
    tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
  }
  return Task{
      std::unique_ptr<TaskImpl>(new TaskImpl{
          std::move(f),
          Context(ContextKind::kThread),
          id,
      }),
  };
}

void RunHandlerEnvironment::ExecuteTask(const Task& t) {
  WithContext wc(t.f->context);
  tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
                               t.f->trace_id);
  t.f->f();
}

void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
                  int max_sleep_micros) {
  {
    mutex_lock l(*mutex);
    CHECK_EQ(waiter->next, waiter);  // Crash OK.
    CHECK_EQ(waiter->prev, waiter);  // Crash OK.

    // Add the waiter to the LIFO queue.
    waiter->prev = queue_head;
    waiter->next = queue_head->next;
    waiter->next->prev = waiter;
    waiter->prev->next = waiter;
  }
  {
    mutex_lock l(waiter->mu);
    // Wait on the condition variable.
    waiter->cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
  }

  mutex_lock l(*mutex);
  // Remove the waiter from the LIFO queue. Note that even when a waiter wakes
  // up due to a notification, we cannot conclude that it is no longer in the
  // queue: a thread preempted right before notifying may resume after the
  // waiter has been re-added.
  if (waiter->next != waiter) {
    CHECK(waiter->prev != waiter);  // Crash OK.
    waiter->next->prev = waiter->prev;
    waiter->prev->next = waiter->next;
    waiter->next = waiter;
    waiter->prev = waiter;
  } else {
    CHECK_EQ(waiter->prev, waiter);  // Crash OK.
  }
}
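
// Illustrative sketch (not part of the implementation): a worker thread that
// finds no runnable work parks itself with a bounded sleep, e.g.
//
//   thread_local Waiter waiter;
//   WaitOnWaiter(&waiter, &queue_waiters, &waiters_mu,
//                /*max_sleep_micros=*/250);
//
// while ThreadWorkSource::EnqueueTask() pops one waiter off the same LIFO list
// and calls notify_one() on it. Because a notification can be missed, liveness
// relies on the bounded sleep rather than on every wake-up being observed.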

ThreadWorkSource::ThreadWorkSource()
    : non_blocking_work_sharding_factor_(
          static_cast<int32>(ParamFromEnvWithDefault(
              "TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES", 1))),
      non_blocking_work_queues_(non_blocking_work_sharding_factor_),
      blocking_inflight_(0),
      non_blocking_inflight_(0),
      traceme_id_(0),
      version_(0),
      sub_thread_pool_waiter_(nullptr) {
  queue_waiters_.next = &queue_waiters_;
  queue_waiters_.prev = &queue_waiters_;
  for (int i = 0; i < NonBlockingWorkShardingFactor(); ++i) {
    non_blocking_work_queues_.emplace_back(new NonBlockingQueue());
  }
}

ThreadWorkSource::~ThreadWorkSource() {
  for (int i = 0; i < non_blocking_work_queues_.size(); ++i) {
    delete non_blocking_work_queues_[i];
  }
}

Task ThreadWorkSource::EnqueueTask(Task t, bool is_blocking) {
  mutex* mu = nullptr;
  Queue* task_queue = nullptr;
  thread_local int64_t closure_counter = 0;

  if (!is_blocking) {
    int queue_index = ++closure_counter % non_blocking_work_sharding_factor_;
    task_queue = &(non_blocking_work_queues_[queue_index]->queue);
    mu = &non_blocking_work_queues_[queue_index]->queue_op_mu;
  } else {
    task_queue = &blocking_work_queue_;
    mu = &blocking_queue_op_mu_;
  }

  {
    mutex_lock l(*mu);
    // For a given queue, only one thread can call PushFront.
    t = task_queue->PushFront(std::move(t));
  }

  Waiter* w = nullptr;
  static const bool use_sub_thread_pool =
      ParamFromEnvBoolWithDefault("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);

  Waiter* waiter_queue;
  mutex* waiter_queue_mu;
  if (use_sub_thread_pool) {
    // When sub thread pools are used, free threads wait on the sub thread
    // pool waiting queues, so wake up threads from those queues. The waiting
    // queues are defined in RunHandlerPool.
    // Get the waiter_queue and the corresponding mutex. Note that the thread
    // work source may change afterwards if a new request arrives or an old
    // request finishes.
    tf_shared_lock lock(run_handler_waiter_mu_);
    waiter_queue = sub_thread_pool_waiter_;
    waiter_queue_mu = sub_thread_pool_waiter_mu_;
  } else {
    waiter_queue = &queue_waiters_;
    waiter_queue_mu = &waiters_mu_;
  }
  {
    mutex_lock l(*waiter_queue_mu);
    if (waiter_queue->next != waiter_queue) {
      // Remove a waiter from the LIFO queue.
      w = waiter_queue->next;

      CHECK(w->prev != w);  // Crash OK.
      CHECK(w->next != w);  // Crash OK.

      w->next->prev = w->prev;
      w->prev->next = w->next;

      // Use `w->next == w` to indicate that the waiter has been removed
      // from the queue.
      w->next = w;
      w->prev = w;
    }
  }
  if (w != nullptr) {
    // We call notify_one() without holding any lock, so we can miss
    // notifications. The wake-up logic is best effort, and a thread will wake
    // up within a short period of time if a notification is missed.
    w->cv.notify_one();
  }
  VLOG(3) << "Added " << (is_blocking ? "inter" : "intra") << " work from "
          << traceme_id_.load(std::memory_order_relaxed);
  return t;
}

Task ThreadWorkSource::PopBlockingTask() {
  return blocking_work_queue_.PopBack();
}

Task ThreadWorkSource::PopNonBlockingTask(int start_index,
                                          bool search_from_all_queue) {
  Task t;
  unsigned sharding_factor = NonBlockingWorkShardingFactor();
  for (unsigned j = 0; j < sharding_factor; ++j) {
    t = non_blocking_work_queues_[(start_index + j) % sharding_factor]
            ->queue.PopBack();
    if (t.f) {
      return t;
    }
    if (!search_from_all_queue) {
      break;
    }
  }
  return t;
}
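
// Illustrative sketch (assumes TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES is
// raised above its default of 1): with a sharding factor of 2, intra-op
// closures enqueued by EnqueueTask() alternate between the two RunQueues, and
// PopNonBlockingTask(start_index, /*search_from_all_queue=*/true) scans both
// shards starting at start_index % 2, which spreads contention across callers
// with different thread ids.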

void ThreadWorkSource::WaitForWork(int max_sleep_micros) {
  thread_local Waiter waiter;
  WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros);
}

int ThreadWorkSource::TaskQueueSize(bool is_blocking) {
  if (is_blocking) {
    return blocking_work_queue_.Size();
  } else {
    unsigned total_size = 0;
    for (int i = 0; i < non_blocking_work_sharding_factor_; ++i) {
      total_size += non_blocking_work_queues_[i]->queue.Size();
    }
    return total_size;
  }
}

int64_t ThreadWorkSource::GetTracemeId() {
  return traceme_id_.load(std::memory_order_relaxed);
}

void ThreadWorkSource::SetTracemeId(int64_t value) { traceme_id_ = value; }

void ThreadWorkSource::SetWaiter(uint64 version, Waiter* waiter, mutex* mutex) {
  {
    tf_shared_lock lock(run_handler_waiter_mu_);
    // Most requests won't change their sub thread pool on recomputation, so
    // avoid taking the exclusive lock in that case to reduce contention.
    if (sub_thread_pool_waiter_ == waiter) {
      return;
    }
    // If the stored version is newer, there is no need to update.
    if (version_ > version) {
      return;
    }
  }

  mutex_lock l(run_handler_waiter_mu_);
  sub_thread_pool_waiter_ = waiter;
  sub_thread_pool_waiter_mu_ = mutex;
  version_ = version;
}

int64_t ThreadWorkSource::GetInflightTaskCount(bool is_blocking) {
  std::atomic<int64_t>* counter =
      is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
  return counter->load(std::memory_order_relaxed);
}

void ThreadWorkSource::IncrementInflightTaskCount(bool is_blocking) {
  std::atomic<int64_t>* counter =
      is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
  counter->fetch_add(1, std::memory_order_relaxed);
}

void ThreadWorkSource::DecrementInflightTaskCount(bool is_blocking) {
  std::atomic<int64_t>* counter =
      is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
  counter->fetch_sub(1, std::memory_order_relaxed);
}

unsigned ThreadWorkSource::NonBlockingWorkShardingFactor() {
  return non_blocking_work_sharding_factor_;
}

std::string ThreadWorkSource::ToString() {
  return strings::StrCat("traceme_id = ", GetTracemeId(),
                         ", inter queue size = ", TaskQueueSize(true),
                         ", inter inflight = ", GetInflightTaskCount(true),
                         ", intra queue size = ", TaskQueueSize(false),
                         ", intra inflight = ", GetInflightTaskCount(false));
}
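
// Example output (illustrative values only):
//   "traceme_id = 42, inter queue size = 0, inter inflight = 1,
//    intra queue size = 3, intra inflight = 2"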

RunHandlerThreadPool::RunHandlerThreadPool(
    int num_blocking_threads, int num_non_blocking_threads, Env* env,
    const ThreadOptions& thread_options, const string& name,
    Eigen::MaxSizeVector<mutex>* waiters_mu,
    Eigen::MaxSizeVector<Waiter>* queue_waiters)
    : num_threads_(num_blocking_threads + num_non_blocking_threads),
      num_blocking_threads_(num_blocking_threads),
      num_non_blocking_threads_(num_non_blocking_threads),
      thread_data_(num_threads_),
      env_(env, thread_options, name),
      name_(name),
      waiters_mu_(waiters_mu),
      queue_waiters_(queue_waiters),
      use_sub_thread_pool_(ParamFromEnvBoolWithDefault(
          "TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)),
      num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault(
          "TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL",
          std::vector<int>({num_blocking_threads / 2,
                            num_blocking_threads - num_blocking_threads / 2}))),
      sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault(
          "TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
          std::vector<double>({0, 0.4}))),
      sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
          "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
          std::vector<double>({0.4, 1}))) {
  thread_data_.resize(num_threads_);
  VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
          << num_blocking_threads_ << " blocking threads and "
          << num_non_blocking_threads_ << " non-blocking threads.";
}
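
// Worked example of the defaults above (illustrative; the env vars override
// these values): with 8 blocking threads and sub thread pools enabled, the two
// sub pools get 4 and 4 blocking threads. Threads in sub pool 0 first search
// requests in [0%, 40%) of the sorted active-request list and threads in sub
// pool 1 first search [40%, 100%]; either group falls back to scanning all
// requests if its own range has no runnable task (see WorkerLoop below).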

RunHandlerThreadPool::~RunHandlerThreadPool() {
  VLOG(1) << "Exiting RunHandlerThreadPool " << name_;

  cancelled_ = true;
  for (size_t i = 0; i < thread_data_.size(); ++i) {
    {
      mutex_lock l(thread_data_[i].mu);
      thread_data_[i].sources_not_empty.notify_all();
    }
    thread_data_[i].thread.reset();
  }
}

void RunHandlerThreadPool::Start() {
  cancelled_ = false;
  int num_blocking_threads = num_blocking_threads_;
  for (int i = 0; i < num_threads_; i++) {
    int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
    for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) {
      if (i < num_threads_in_sub_thread_pool_[j]) {
        sub_thread_pool_id = j;
        break;
      }
    }
    thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
    const bool is_blocking_thread = (i < num_blocking_threads);
    // The blocking threads handle both the inter- and intra-op workload;
    // non-blocking threads handle the intra-op workload only, and the sub
    // thread pools are only provided for blocking threads.
    // Name the threads accordingly.
    thread_data_[i].thread.reset(env_.CreateThread(
        [this, is_blocking_thread, i, sub_thread_pool_id]() {
          WorkerLoop(i, is_blocking_thread);
        },
        is_blocking_thread
            ? strings::StrCat(name_, "_blocking_thread_", sub_thread_pool_id)
            : strings::StrCat(name_, "_non_blocking_thread")));
  }
}

void RunHandlerThreadPool::StartOneThreadForTesting() {
  cancelled_ = false;
  thread_data_[0].sub_thread_pool_id = 0;
  thread_data_[0].thread.reset(
      env_.CreateThread([this]() { WorkerLoop(0, true); }, name_));
}

void RunHandlerThreadPool::AddWorkToQueue(ThreadWorkSource* tws,
                                          bool is_blocking,
                                          std::function<void()> fn) {
  Task t = env_.CreateTask(std::move(fn));
  t = tws->EnqueueTask(std::move(t), is_blocking);
  if (t.f) {
    VLOG(3) << "Running " << (is_blocking ? "inter" : "intra") << " work for "
            << tws->GetTracemeId();
    env_.ExecuteTask(t);
  }
}

// TODO(donglin): Change the task-stealing order to round-robin so that if an
// attempt to steal a task from request i fails, the next attempt targets the
// next request in arrival-time order. This approach may provide better
// performance due to less lock contention. The drawback is that the profiler
// output will be a bit harder to read.
void RunHandlerThreadPool::SetThreadWorkSources(
    int tid, int start_request_idx, uint64 version,
    const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
  mutex_lock l(thread_data_[tid].mu);
  if (version > thread_data_[tid].new_version) {
    thread_data_[tid].new_version = version;
  } else {
    // A newer version has already been set. No need to update.
    return;
  }
  thread_data_[tid].new_thread_work_sources->resize(0);
  if (use_sub_thread_pool_) {
    for (int i = 0; i < thread_work_sources.size(); ++i) {
      thread_data_[tid].new_thread_work_sources->emplace_back(
          thread_work_sources[i]);
    }
  } else {
    thread_data_[tid].new_thread_work_sources->emplace_back(
        thread_work_sources[start_request_idx]);
    // The number of shards for the queue. Threads in each shard prioritize
    // different thread_work_sources. Increasing the number of shards can
    // decrease contention in the queue. For example, when num_shards == 1,
    // thread_work_sources are ordered as start_request_idx, 0, 1, 2, 3, 4 ...
    // for all threads. When num_shards == 2, thread_work_sources are ordered
    // as start_request_idx, 0, 2, 4 ..., 1, 3, 5 ... for half of the threads
    // and start_request_idx, 1, 3, 5 ..., 0, 2, 4 ... for the other half.
    static const int num_shards =
        ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
    int token = tid % num_shards;
    for (int i = 0; i < num_shards; ++i) {
      for (int j = token; j < thread_work_sources.size(); j += num_shards) {
        if (j != start_request_idx) {
          thread_data_[tid].new_thread_work_sources->emplace_back(
              thread_work_sources[j]);
        }
      }
      token = (token + 1) % num_shards;
    }
    thread_data_[tid].sources_not_empty.notify_all();
  }
}

RunHandlerThreadPool::PerThread* RunHandlerThreadPool::GetPerThread() {
  thread_local RunHandlerThreadPool::PerThread per_thread_;
  RunHandlerThreadPool::PerThread* pt = &per_thread_;
  return pt;
}

int RunHandlerThreadPool::CurrentThreadId() const {
  const PerThread* pt = const_cast<RunHandlerThreadPool*>(this)->GetPerThread();
  if (pt->pool == this) {
    return pt->thread_id;
  } else {
    return -1;
  }
}

int RunHandlerThreadPool::NumThreads() const { return num_threads_; }

int RunHandlerThreadPool::NumBlockingThreads() const {
  return num_blocking_threads_;
}

int RunHandlerThreadPool::NumNonBlockingThreads() const {
  return num_non_blocking_threads_;
}

RunHandlerThreadPool::ThreadData::ThreadData()
    : new_version(0),
      current_index(0),
      new_thread_work_sources(
          new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
              ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
                                      kMaxConcurrentHandlers)))),
      current_version(0),
      current_thread_work_sources(
          new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
              ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
                                      kMaxConcurrentHandlers)))) {}

Task RunHandlerThreadPool::FindTask(
    int searching_range_start, int searching_range_end, int thread_id,
    int sub_thread_pool_id, int max_blocking_inflight,
    bool may_steal_blocking_work,
    const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
    bool* task_from_blocking_queue, ThreadWorkSource** tws) {
  Task t;
  int current_index = thread_data_[thread_id].current_index;
  *task_from_blocking_queue = false;

  for (int i = 0; i < searching_range_end - searching_range_start; ++i) {
    if (current_index >= searching_range_end ||
        current_index < searching_range_start) {
      current_index = searching_range_start;
    }
    *tws = thread_work_sources[current_index];
    ++current_index;

    // For a blocking thread, search for blocking tasks first.
    if (may_steal_blocking_work &&
        (*tws)->GetInflightTaskCount(true) < max_blocking_inflight) {
      t = (*tws)->PopBlockingTask();
      if (t.f) {
        *task_from_blocking_queue = true;
        break;
      }
    }

    // Search for non-blocking tasks.
    t = (*tws)->PopNonBlockingTask(thread_id, true);
    if (t.f) {
      break;
    }
  }
  thread_data_[thread_id].current_index = current_index;
  return t;
}
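
// Worked example (illustrative): with searching range [2, 5) and a persisted
// current_index of 4, the thread examines thread_work_sources[4], then wraps
// to [2] and [3]. For each source it first tries a blocking (inter-op) task,
// but only while that source has fewer than max_blocking_inflight blocking
// tasks in flight, and otherwise falls back to a non-blocking (intra-op) task.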

// Main worker thread loop.
void RunHandlerThreadPool::WorkerLoop(int thread_id,
                                      bool may_steal_blocking_work) {
  PerThread* pt = GetPerThread();
  pt->pool = this;
  pt->thread_id = thread_id;
  static constexpr int32_t kMaxBlockingInflight = 10;

  while (!cancelled_) {
    Task t;
    ThreadWorkSource* tws = nullptr;
    bool task_from_blocking_queue = true;
    int sub_thread_pool_id;
    // Get the current thread work sources.
    {
      mutex_lock l(thread_data_[thread_id].mu);
      if (thread_data_[thread_id].current_version <
          thread_data_[thread_id].new_version) {
        thread_data_[thread_id].current_version =
            thread_data_[thread_id].new_version;
        thread_data_[thread_id].current_thread_work_sources.swap(
            thread_data_[thread_id].new_thread_work_sources);
      }
    }
    Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
        thread_data_[thread_id].current_thread_work_sources.get();
    if (use_sub_thread_pool_) {
      sub_thread_pool_id = thread_data_[thread_id].sub_thread_pool_id;
      int active_requests = thread_work_sources->size();
      if (may_steal_blocking_work) {
        // Each thread first looks for tasks from requests that belong to its
        // own sub thread pool.
        int search_range_start =
            active_requests *
            sub_thread_pool_start_request_percentage_[sub_thread_pool_id];
        int search_range_end =
            active_requests *
            sub_thread_pool_end_request_percentage_[sub_thread_pool_id];
        search_range_end =
            std::min(active_requests,
                     std::max(search_range_end, search_range_start + 1));

        t = FindTask(search_range_start, search_range_end, thread_id,
                     sub_thread_pool_id, kMaxBlockingInflight,
                     /*may_steal_blocking_work=*/true, *thread_work_sources,
                     &task_from_blocking_queue, &tws);
        if (!t.f) {
          // Search all requests if the thread cannot find a task from the
          // requests that belong to its own sub thread pool.
          t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
                       kMaxBlockingInflight,
                       /*may_steal_blocking_work=*/true, *thread_work_sources,
                       &task_from_blocking_queue, &tws);
        }
      } else {
        // Non-blocking threads always search across all pending requests.
        t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
                     kMaxBlockingInflight,
                     /*may_steal_blocking_work=*/false, *thread_work_sources,
                     &task_from_blocking_queue, &tws);
      }
    } else {
      // TODO(chaox): Refactor the following code to share the logic with
      // FindTask.
      for (int i = 0; i < thread_work_sources->size(); ++i) {
        tws = (*thread_work_sources)[i];
        // We want a smallish number of inter-op threads since otherwise there
        // will be contention in PropagateOutputs. This is a best-effort
        // policy.
        if (may_steal_blocking_work &&
            tws->GetInflightTaskCount(true) < kMaxBlockingInflight) {
          t = tws->PopBlockingTask();
          if (t.f) {
            break;
          }
        }
        if (i == 0) {
          // Always look for any work from the "primary" work source. This way
          // when we wake up a thread for a new closure we are guaranteed it
          // can be worked on.
          t = tws->PopNonBlockingTask(thread_id, true);
          if (t.f) {
            task_from_blocking_queue = false;
            break;
          }
        } else {
          t = tws->PopNonBlockingTask(thread_id, false);
          if (t.f) {
            task_from_blocking_queue = false;
            break;
          }
        }
      }
    }
    if (t.f) {
      profiler::TraceMe activity(
          [=] {
            return strings::StrCat(task_from_blocking_queue ? "inter" : "intra",
                                   " #id = ", tws->GetTracemeId(), " ",
                                   thread_id, "#");
          },
          profiler::TraceMeLevel::kInfo);
      VLOG(2) << "Running " << (task_from_blocking_queue ? "inter" : "intra")
              << " work from " << tws->GetTracemeId();
      tws->IncrementInflightTaskCount(task_from_blocking_queue);
      env_.ExecuteTask(t);
      tws->DecrementInflightTaskCount(task_from_blocking_queue);
    } else {
      profiler::TraceMe activity(
          [=] {
            return strings::StrCat("Sleeping#thread_id=", thread_id, "#");
          },
          profiler::TraceMeLevel::kInfo);
      if (VLOG_IS_ON(4)) {
        for (int i = 0; i < thread_work_sources->size(); ++i) {
          VLOG(4) << "source id " << i << " "
                  << (*thread_work_sources)[i]->ToString();
        }
      }
      if (use_sub_thread_pool_) {
        WaitForWorkInSubThreadPool(may_steal_blocking_work, sub_thread_pool_id);
      } else {
        WaitForWork(may_steal_blocking_work, thread_id, kMaxBlockingInflight);
      }
    }
  }
}

void RunHandlerThreadPool::WaitForWorkInSubThreadPool(bool is_blocking,
                                                      int sub_thread_pool_id) {
  const int kMaxSleepMicros = 250;

  // The non-blocking thread will just sleep.
  if (!is_blocking) {
    Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
    return;
  }

  thread_local Waiter waiter;
  WaitOnWaiter(&waiter, &(*queue_waiters_)[sub_thread_pool_id],
               &(*waiters_mu_)[sub_thread_pool_id], kMaxSleepMicros);
}
673
674void RunHandlerThreadPool::WaitForWork(bool is_blocking, int thread_id,
675 int32_t max_blocking_inflight) {
676 const int kMaxSleepMicros = 250;
677
678 // The non-blocking thread will just sleep.
679 if (!is_blocking) {
680 Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
681 return;
682 }
683
684 ThreadWorkSource* tws = nullptr;
685 {
686 mutex_lock l(thread_data_[thread_id].mu);
687 if (thread_data_[thread_id].new_version >
688 thread_data_[thread_id].current_version) {
689 thread_data_[thread_id].current_thread_work_sources.swap(
690 thread_data_[thread_id].new_thread_work_sources);
691 thread_data_[thread_id].current_version =
692 thread_data_[thread_id].new_version;
693 }
694 Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
695 thread_data_[thread_id].current_thread_work_sources.get();
696 while (!cancelled_ && thread_work_sources->empty()) {
697 // Wait until there is new request
698 thread_data_[thread_id].sources_not_empty.wait(l);
699 if (thread_data_[thread_id].new_version >
700 thread_data_[thread_id].current_version) {
701 thread_data_[thread_id].current_thread_work_sources.swap(
702 thread_data_[thread_id].new_thread_work_sources);
703 thread_data_[thread_id].current_version =
704 thread_data_[thread_id].new_version;
705 thread_work_sources =
706 thread_data_[thread_id].current_thread_work_sources.get();
707 }
708 }
709 if (cancelled_) {
710 return;
711 }
712 tws = (*thread_work_sources)[0];
713 }
714
715 if (tws->GetInflightTaskCount(true) >= max_blocking_inflight) {
716 // Sleep to reduce contention in PropagateOutputs
717 Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
718 }
719 tws->WaitForWork(kMaxSleepMicros);
720}
721
}  // namespace internal

// Contains the concrete implementation of RunHandler. The externally visible
// RunHandler class simply forwards the work to this one.
class RunHandler::Impl {
 public:
  explicit Impl(RunHandlerPool::Impl* pool_impl);

  ~Impl() {}

  thread::ThreadPoolInterface* thread_pool_interface() {
    return thread_pool_interface_.get();
  }

  // Returns the time (in microseconds since the Unix epoch) at which the
  // handler was requested via RunHandlerPool::Get().
  uint64 start_time_us() const { return start_time_us_; }
  int64_t step_id() const { return step_id_; }
  void ScheduleInterOpClosure(std::function<void()> fn);
  void ScheduleIntraOpClosure(std::function<void()> fn);

  void Reset(int64_t step_id,
             const RunOptions::Experimental::RunHandlerPoolOptions& options);

  RunHandlerPool::Impl* pool_impl() { return pool_impl_; }

  internal::ThreadWorkSource* tws() { return &tws_; }

  int64_t priority() { return options_.priority(); }

 private:
  class ThreadPoolInterfaceWrapper : public thread::ThreadPoolInterface {
   public:
    explicit ThreadPoolInterfaceWrapper(Impl* run_handler_impl)
        : run_handler_impl_(run_handler_impl) {}
    ~ThreadPoolInterfaceWrapper() override {}
    void Schedule(std::function<void()> fn) override;
    int NumThreads() const override;
    int CurrentThreadId() const override;

   private:
    RunHandler::Impl* run_handler_impl_ = nullptr;
  };

  RunHandlerPool::Impl* pool_impl_;  // NOT OWNED.
  uint64 start_time_us_;
  int64_t step_id_;
  std::unique_ptr<thread::ThreadPoolInterface> thread_pool_interface_;
  internal::ThreadWorkSource tws_;
  RunOptions::Experimental::RunHandlerPoolOptions options_;
};

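// Typical usage of the pool (illustrative sketch based on the public
// RunHandlerPool/RunHandler API declared in run_handler.h; error handling
// omitted):
//
//   RunHandlerPool pool(/*num_inter_op_threads=*/16,
//                       /*num_intra_op_threads=*/16);
//   auto handler = pool.Get(step_id, /*timeout_in_ms=*/0, options);
//   handler->ScheduleInterOpClosure([]() { /* inter-op work */ });
//   thread::ThreadPoolInterface* intra = handler->AsIntraThreadPoolInterface();
//   // Destroying `handler` returns it to the pool (see ~RunHandler()).
//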
// Contains shared state across all run handlers present in the pool. Also
// responsible for pool management decisions.
// This class is thread-safe.
class RunHandlerPool::Impl {
 public:
  explicit Impl(int num_inter_op_threads, int num_intra_op_threads)
      : max_handlers_(static_cast<int32>(ParamFromEnvWithDefault(
            "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", kMaxConcurrentHandlers))),
        waiters_mu_(
            ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
        queue_waiters_(
            ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
        run_handler_thread_pool_(new internal::RunHandlerThreadPool(
            num_inter_op_threads, num_intra_op_threads, Env::Default(),
            ThreadOptions(), "tf_run_handler_pool", &waiters_mu_,
            &queue_waiters_)),
        iterations_(0),
        version_(0),
        sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
            "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
            std::vector<double>({1}))) {
    VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
    free_handlers_.reserve(max_handlers_);
    handlers_.reserve(max_handlers_);
    for (int i = 0; i < max_handlers_; ++i) {
      handlers_.emplace_back(new RunHandler::Impl(this));
      free_handlers_.push_back(handlers_.back().get());
    }
    queue_waiters_.resize(
        ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
    waiters_mu_.resize(
        ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
    for (auto& queue_waiter : queue_waiters_) {
      queue_waiter.next = &queue_waiter;
      queue_waiter.prev = &queue_waiter;
    }
    run_handler_thread_pool_->Start();
  }

  ~Impl() {
    // Sanity check that all handlers have been returned to the pool before
    // destruction.
    DCHECK_EQ(handlers_.size(), max_handlers_);
    DCHECK_EQ(free_handlers_.size(), handlers_.size());
    DCHECK_EQ(sorted_active_handlers_.size(), 0);
    // Stop the threads in run_handler_thread_pool_ before freeing other
    // pointers. Otherwise a thread may try to access a pointer after the
    // pointer has been freed.
    run_handler_thread_pool_.reset();
  }

  internal::RunHandlerThreadPool* run_handler_thread_pool() {
    return run_handler_thread_pool_.get();
  }

  bool has_free_handler() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return !free_handlers_.empty();
  }

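  // Worked example of the priority-ordered insertion in Get() below
  // (illustrative): if the active handlers currently have priorities {3, 1}
  // and a new handler arrives with priority 2, it is inserted between them, so
  // sorted_active_handlers_ holds priorities {3, 2, 1} and thread_work_sources
  // is rebuilt in that same order before RecomputePoolStats() runs.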
  std::unique_ptr<RunHandler> Get(
      int64_t step_id, int64_t timeout_in_ms,
      const RunOptions::Experimental::RunHandlerPoolOptions& options)
      TF_LOCKS_EXCLUDED(mu_) {
    thread_local std::unique_ptr<
        Eigen::MaxSizeVector<internal::ThreadWorkSource*>>
        thread_work_sources =
            std::unique_ptr<Eigen::MaxSizeVector<internal::ThreadWorkSource*>>(
                new Eigen::MaxSizeVector<internal::ThreadWorkSource*>(
                    static_cast<int32>(ParamFromEnvWithDefault(
                        "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
                        kMaxConcurrentHandlers))));
    uint64 version;
    int num_active_requests;
    RunHandler::Impl* handler_impl;
    {
      mutex_lock l(mu_);
      if (!has_free_handler()) {
        profiler::TraceMe activity(
            [&] {
              return strings::StrCat("WaitingForHandler#step_id=", step_id,
                                     "#");
            },
            profiler::TraceMeLevel::kInfo);
        TRACESTRING(
            strings::StrCat("RunHandlerPool::Impl::Get waiting for a handler "
                            "with timeout in milliseconds ",
                            timeout_in_ms));
        if (timeout_in_ms == 0) {
          mu_.Await(Condition(this, &Impl::has_free_handler));
        } else if (!mu_.AwaitWithDeadline(
                       Condition(this, &Impl::has_free_handler),
                       EnvTime::NowNanos() + timeout_in_ms * 1000 * 1000)) {
          return nullptr;
        }
      }
      // Remove the last entry from free_handlers_ and insert it into
      // sorted_active_handlers_ according to its priority.
      handler_impl = free_handlers_.back();
      handler_impl->Reset(step_id, options);
      free_handlers_.pop_back();

      num_active_requests = sorted_active_handlers_.size() + 1;
      thread_work_sources->resize(num_active_requests);
      int priority = options.priority();
      auto it = sorted_active_handlers_.cbegin();
      bool new_handler_inserted = false;
      for (int i = 0; i < num_active_requests; ++i) {
        if (!new_handler_inserted && (it == sorted_active_handlers_.cend() ||
                                      priority > (*it)->priority())) {
          sorted_active_handlers_.insert(it, handler_impl);
          new_handler_inserted = true;
          // Point to the newly added handler.
          --it;
        }
        (*thread_work_sources)[i] = (*it)->tws();
        ++it;
      }
      version = ++version_;
    }
    RecomputePoolStats(num_active_requests, version, *thread_work_sources);
    return WrapUnique<RunHandler>(new RunHandler(handler_impl));
  }

  void ReleaseHandler(RunHandler::Impl* handler) TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    DCHECK_GT(sorted_active_handlers_.size(), 0);

    CHECK_EQ(handler->tws()->TaskQueueSize(true), 0);   // Crash OK.
    CHECK_EQ(handler->tws()->TaskQueueSize(false), 0);  // Crash OK.

    uint64 now = tensorflow::EnvTime::NowMicros();
    double elapsed = (now - handler->start_time_us()) / 1000.0;
    time_hist_.Add(elapsed);

    // Erase from and update sorted_active_handlers_. Add the handler to the
    // end of free_handlers_.
    auto iter = std::find(sorted_active_handlers_.begin(),
                          sorted_active_handlers_.end(), handler);
    DCHECK(iter != sorted_active_handlers_.end())
        << "Unexpected handler: " << handler
        << " is being requested for release";

    // Remove this handler from this list and add it to the list of free
    // handlers.
    sorted_active_handlers_.erase(iter);
    free_handlers_.push_back(handler);
    DCHECK_LE(free_handlers_.size(), max_handlers_);
    LogInfo();

    // We do not recompute pool stats during release. The side effect is that
    // there may be empty thread work sources in the queue. However, any new
    // request will trigger recomputation.
  }

  std::vector<int64_t> GetActiveHandlerPrioritiesForTesting()
      TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock l(mu_);
    std::vector<int64_t> ret;
    for (const auto& handler_impl : sorted_active_handlers_) {
      ret.push_back(handler_impl->priority());
    }
    return ret;
  }

 private:
  void RecomputePoolStats(
      int num_active_requests, uint64 version,
      const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
          thread_work_sources);

  void LogInfo() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Maximum number of handlers pre-created during pool construction time. The
  // number has been chosen expecting each handler might at least want 1
  // inter-op thread for execution (during compute-intensive workloads like
  // inference).
  const int max_handlers_;

  Eigen::MaxSizeVector<mutex> waiters_mu_;
  Eigen::MaxSizeVector<internal::Waiter> queue_waiters_;

  std::unique_ptr<internal::RunHandlerThreadPool> run_handler_thread_pool_;
  // Thread-compatible part, used only under the lock in RunHandlerPool.
  // Handlers are sorted by start time.
  // TODO(azaks): sort by the remaining latency budget.
  // TODO(chaox): Consider other data structures for maintaining the sorted
  // active handlers if the searching overhead (currently O(n)) becomes the
  // bottleneck.
  std::list<RunHandler::Impl*> sorted_active_handlers_ TF_GUARDED_BY(mu_);
  std::vector<RunHandler::Impl*> free_handlers_ TF_GUARDED_BY(mu_);
  std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ TF_GUARDED_BY(mu_);

  // Histogram of elapsed runtime of every handler (in ms).
  histogram::Histogram time_hist_ TF_GUARDED_BY(mu_);

  int64_t iterations_ TF_GUARDED_BY(mu_);
  mutex mu_;
  int64_t version_ TF_GUARDED_BY(mu_);
  const std::vector<double> sub_thread_pool_end_request_percentage_;
};

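// Worked example of the sub-thread-pool assignment below (illustrative;
// assumes TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE is configured
// as {0.4, 1} rather than the single-pool default of {1}): with 5 active
// requests, requests 0 and 1 get the waiter of sub thread pool 0 and requests
// 2 through 4 get the waiter of sub thread pool 1, so newly enqueued work
// wakes threads from the matching sub pool.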
void RunHandlerPool::Impl::RecomputePoolStats(
    int num_active_requests, uint64 version,
    const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
        thread_work_sources) {
  if (num_active_requests == 0) return;

  int sub_thread_pool_id = 0;
  for (int i = 0; i < num_active_requests; ++i) {
    while (
        sub_thread_pool_id <
            sub_thread_pool_end_request_percentage_.size() - 1 &&
        i >= num_active_requests *
                 sub_thread_pool_end_request_percentage_[sub_thread_pool_id]) {
      sub_thread_pool_id++;
    }
    thread_work_sources[i]->SetWaiter(version,
                                      &queue_waiters_[sub_thread_pool_id],
                                      &waiters_mu_[sub_thread_pool_id]);
  }

  int num_threads = run_handler_thread_pool()->NumThreads();
  int num_blocking_threads = run_handler_thread_pool()->NumBlockingThreads();
  int num_non_blocking_threads = num_threads - num_blocking_threads;

  std::vector<int> request_idx_list = ChooseRequestsWithExponentialDistribution(
      num_active_requests, num_blocking_threads);
  for (int i = 0; i < num_blocking_threads; ++i) {
    VLOG(2) << "Set work for tid=" << i
            << " with start_request_idx=" << request_idx_list[i];
    run_handler_thread_pool()->SetThreadWorkSources(
        i, request_idx_list[i], version, thread_work_sources);
  }

  request_idx_list = ChooseRequestsWithExponentialDistribution(
      num_active_requests, num_non_blocking_threads);
  for (int i = 0; i < num_non_blocking_threads; ++i) {
    VLOG(2) << "Set work for tid=" << (i + num_blocking_threads)
            << " with start_request_idx=" << request_idx_list[i];
    run_handler_thread_pool()->SetThreadWorkSources(
        i + num_blocking_threads, request_idx_list[i], version,
        thread_work_sources);
  }
}

void RunHandlerPool::Impl::LogInfo() {
  if (iterations_++ % 50000 == 10 && VLOG_IS_ON(1)) {
    int num_active_requests = sorted_active_handlers_.size();
    VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
    VLOG(1) << "Active session runs: " << num_active_requests;
    uint64 now = tensorflow::Env::Default()->NowMicros();
    string times_str = "";
    string ids_str = "";
    auto it = sorted_active_handlers_.cbegin();
    for (int i = 0; i < num_active_requests; ++i) {
      if (i > 0) {
        times_str += " ";
        ids_str += " ";
      }

      times_str +=
          strings::StrCat((now - (*it)->start_time_us()) / 1000.0, " ms.");
      ids_str += strings::StrCat((*it)->tws()->GetTracemeId());
      ++it;
    }
    VLOG(1) << "Elapsed times are: " << times_str;
    VLOG(1) << "Step ids are: " << ids_str;
  }
}

// It is important that CurrentThreadId() returns a value in
// [0, NumThreads()).
int RunHandler::Impl::ThreadPoolInterfaceWrapper::NumThreads() const {
  return run_handler_impl_->pool_impl_->run_handler_thread_pool()->NumThreads();
}

int RunHandler::Impl::ThreadPoolInterfaceWrapper::CurrentThreadId() const {
  return run_handler_impl_->pool_impl_->run_handler_thread_pool()
      ->CurrentThreadId();
}

void RunHandler::Impl::ThreadPoolInterfaceWrapper::Schedule(
    std::function<void()> fn) {
  return run_handler_impl_->ScheduleIntraOpClosure(std::move(fn));
}

RunHandler::Impl::Impl(RunHandlerPool::Impl* pool_impl)
    : pool_impl_(pool_impl) {
  thread_pool_interface_.reset(new ThreadPoolInterfaceWrapper(this));
  Reset(0, RunOptions::Experimental::RunHandlerPoolOptions());
}

void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
  VLOG(3) << "Scheduling inter work for " << tws()->GetTracemeId();
  pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), true,
                                                        std::move(fn));
}

void RunHandler::Impl::ScheduleIntraOpClosure(std::function<void()> fn) {
  VLOG(3) << "Scheduling intra work for " << tws()->GetTracemeId();
  pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), false,
                                                        std::move(fn));
}

void RunHandler::Impl::Reset(
    int64_t step_id,
    const RunOptions::Experimental::RunHandlerPoolOptions& options) {
  start_time_us_ = tensorflow::Env::Default()->NowMicros();
  step_id_ = step_id;
  options_ = options;
  tws_.SetTracemeId(step_id);
}

RunHandlerPool::RunHandlerPool(int num_inter_op_threads)
    : impl_(new Impl(num_inter_op_threads, 0)) {}

RunHandlerPool::RunHandlerPool(int num_inter_op_threads,
                               int num_intra_op_threads)
    : impl_(new Impl(num_inter_op_threads, num_intra_op_threads)) {}

RunHandlerPool::~RunHandlerPool() {}

std::unique_ptr<RunHandler> RunHandlerPool::Get(
    int64_t step_id, int64_t timeout_in_ms,
    const RunOptions::Experimental::RunHandlerPoolOptions& options) {
  return impl_->Get(step_id, timeout_in_ms, options);
}

std::vector<int64_t> RunHandlerPool::GetActiveHandlerPrioritiesForTesting()
    const {
  return impl_->GetActiveHandlerPrioritiesForTesting();
}

RunHandler::RunHandler(Impl* impl) : impl_(impl) {}

void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) {
  impl_->ScheduleInterOpClosure(std::move(fn));
}

thread::ThreadPoolInterface* RunHandler::AsIntraThreadPoolInterface() {
  return impl_->thread_pool_interface();
}

RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }

}  // namespace tensorflow