1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include "tensorflow/core/framework/run_handler.h" |
19 | |
20 | #include <algorithm> |
21 | #include <cmath> |
22 | #include <list> |
23 | #include <memory> |
24 | |
25 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
26 | #include "tensorflow/core/framework/run_handler_util.h" |
27 | #include "tensorflow/core/lib/core/threadpool_interface.h" |
28 | #include "tensorflow/core/lib/strings/strcat.h" |
29 | #include "tensorflow/core/platform/context.h" |
30 | #include "tensorflow/core/platform/denormal.h" |
31 | #include "tensorflow/core/platform/mutex.h" |
32 | #include "tensorflow/core/platform/numa.h" |
33 | #include "tensorflow/core/platform/setround.h" |
34 | #include "tensorflow/core/platform/tracing.h" |
35 | #include "tensorflow/core/profiler/lib/traceme.h" |
36 | #include "tensorflow/core/util/ptr_util.h" |
37 | |
38 | namespace tensorflow { |
39 | namespace { |
40 | // LINT.IfChange |
41 | static constexpr int32_t kMaxConcurrentHandlers = 128; |
42 | // LINT.ThenChange(//tensorflow/core/framework/run_handler_test.cc) |
43 | |
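// Tasks are held in fixed-capacity Eigen run queues. When a queue is full,
// Queue::PushFront hands the task back to the caller instead of enqueueing
// it; callers (see RunHandlerThreadPool::AddWorkToQueue) then execute the
// returned task inline on the calling thread.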
44 | typedef typename internal::RunHandlerEnvironment::Task Task; |
45 | typedef Eigen::RunQueue<Task, 1024> Queue; |
46 | |
47 | } // namespace |
48 | |
49 | namespace internal { |
50 | RunHandlerEnvironment::RunHandlerEnvironment( |
51 | Env* env, const ThreadOptions& thread_options, const string& name) |
52 | : env_(env), thread_options_(thread_options), name_(name) {} |
53 | |
54 | RunHandlerEnvironment::EnvThread* RunHandlerEnvironment::CreateThread( |
55 | std::function<void()> f, const std::string& thread_name) { |
56 | return env_->StartThread(thread_options_, thread_name, [=]() { |
57 | // Set the processor flag to flush denormals to zero. |
58 | port::ScopedFlushDenormal flush; |
59 | // Set the processor rounding mode to ROUND TO NEAREST. |
60 | port::ScopedSetRound round(FE_TONEAREST); |
61 | if (thread_options_.numa_node != port::kNUMANoAffinity) { |
62 | port::NUMASetThreadNodeAffinity(thread_options_.numa_node); |
63 | } |
64 | f(); |
65 | }); |
66 | } |
67 | |
68 | RunHandlerEnvironment::Task RunHandlerEnvironment::CreateTask( |
69 | std::function<void()> f) { |
70 | uint64 id = 0; |
71 | if (tracing::EventCollector::IsEnabled()) { |
72 | id = tracing::GetUniqueArg(); |
73 | tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id); |
74 | } |
75 | return Task{ |
76 | std::unique_ptr<TaskImpl>(new TaskImpl{ |
77 | std::move(f), |
78 | Context(ContextKind::kThread), |
79 | id, |
80 | }), |
81 | }; |
82 | } |
83 | |
84 | void RunHandlerEnvironment::ExecuteTask(const Task& t) { |
85 | WithContext wc(t.f->context); |
86 | tracing::ScopedRegion region(tracing::EventCategory::kRunClosure, |
87 | t.f->trace_id); |
88 | t.f->f(); |
89 | } |
90 | |
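// Waiters form an intrusive, circular doubly-linked list rooted at a sentinel
// node (`queue_head`). A waiter whose `next` and `prev` point to itself is
// not enqueued. For example, with a single enqueued waiter W:
//
//   queue_head->next == W, W->next == queue_head,
//   queue_head->prev == W, W->prev == queue_head.
//
// WaitOnWaiter links `waiter` into the queue, sleeps on its condition
// variable for at most `max_sleep_micros`, and unlinks it on wake-up if it is
// still enqueued.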
91 | void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex, |
92 | int max_sleep_micros) { |
93 | { |
94 | mutex_lock l(*mutex); |
95 | CHECK_EQ(waiter->next, waiter); // Crash OK. |
96 | CHECK_EQ(waiter->prev, waiter); // Crash OK. |
97 | |
98 | // Add waiter to the LIFO queue |
99 | waiter->prev = queue_head; |
100 | waiter->next = queue_head->next; |
101 | waiter->next->prev = waiter; |
102 | waiter->prev->next = waiter; |
103 | } |
104 | { |
105 | mutex_lock l(waiter->mu); |
106 | // Wait on the condition variable |
107 | waiter->cv.wait_for(l, std::chrono::microseconds(max_sleep_micros)); |
108 | } |
109 | |
110 | mutex_lock l(*mutex); |
  // Remove waiter from the LIFO queue. Note that even when a waiter wakes up
  // due to a notification, we cannot conclude that it is no longer in the
  // queue: a thread preempted right before notifying may resume only after
  // the waiter has been re-added.
115 | if (waiter->next != waiter) { |
116 | CHECK(waiter->prev != waiter); // Crash OK. |
117 | waiter->next->prev = waiter->prev; |
118 | waiter->prev->next = waiter->next; |
119 | waiter->next = waiter; |
120 | waiter->prev = waiter; |
121 | } else { |
122 | CHECK_EQ(waiter->prev, waiter); // Crash OK. |
123 | } |
124 | } |
125 | |
126 | ThreadWorkSource::ThreadWorkSource() |
127 | : non_blocking_work_sharding_factor_( |
128 | static_cast<int32>(ParamFromEnvWithDefault( |
129 | "TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES" , 1))), |
130 | non_blocking_work_queues_(non_blocking_work_sharding_factor_), |
131 | blocking_inflight_(0), |
132 | non_blocking_inflight_(0), |
133 | traceme_id_(0), |
134 | version_(0), |
135 | sub_thread_pool_waiter_(nullptr) { |
136 | queue_waiters_.next = &queue_waiters_; |
137 | queue_waiters_.prev = &queue_waiters_; |
138 | for (int i = 0; i < NonBlockingWorkShardingFactor(); ++i) { |
139 | non_blocking_work_queues_.emplace_back(new NonBlockingQueue()); |
140 | } |
141 | } |
142 | |
143 | ThreadWorkSource::~ThreadWorkSource() { |
144 | for (int i = 0; i < non_blocking_work_queues_.size(); ++i) { |
145 | delete non_blocking_work_queues_[i]; |
146 | } |
147 | } |
148 | |
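// Non-blocking work is sharded across
// TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES queues to reduce contention; a
// thread-local counter round-robins enqueues over the shards, while blocking
// work always goes to a single queue.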
149 | Task ThreadWorkSource::EnqueueTask(Task t, bool is_blocking) { |
150 | mutex* mu = nullptr; |
151 | Queue* task_queue = nullptr; |
152 | thread_local int64_t closure_counter = 0; |
153 | |
154 | if (!is_blocking) { |
155 | int queue_index = ++closure_counter % non_blocking_work_sharding_factor_; |
156 | task_queue = &(non_blocking_work_queues_[queue_index]->queue); |
157 | mu = &non_blocking_work_queues_[queue_index]->queue_op_mu; |
158 | } else { |
159 | task_queue = &blocking_work_queue_; |
160 | mu = &blocking_queue_op_mu_; |
161 | } |
162 | |
163 | { |
164 | mutex_lock l(*mu); |
165 | // For a given queue, only one thread can call PushFront. |
166 | t = task_queue->PushFront(std::move(t)); |
167 | } |
168 | |
169 | Waiter* w = nullptr; |
170 | static const bool use_sub_thread_pool = |
      ParamFromEnvBoolWithDefault("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);
172 | |
173 | Waiter* waiter_queue; |
174 | mutex* waiter_queue_mu; |
175 | if (use_sub_thread_pool) { |
    // When multiple sub thread pools are in use, free threads wait on the
    // sub thread pool waiting queues (defined in RunHandlerPool), so threads
    // are woken up from those queues. Get the waiter_queue and corresponding
    // mutex. Note that the thread work source may change afterwards if a new
    // request arrives or an old request finishes.
183 | tf_shared_lock lock(run_handler_waiter_mu_); |
184 | waiter_queue = sub_thread_pool_waiter_; |
185 | waiter_queue_mu = sub_thread_pool_waiter_mu_; |
186 | } else { |
187 | waiter_queue = &queue_waiters_; |
188 | waiter_queue_mu = &waiters_mu_; |
189 | } |
190 | { |
191 | mutex_lock l(*waiter_queue_mu); |
192 | if (waiter_queue->next != waiter_queue) { |
193 | // Remove waiter from the LIFO queue |
194 | w = waiter_queue->next; |
195 | |
196 | CHECK(w->prev != w); // Crash OK. |
197 | CHECK(w->next != w); // Crash OK. |
198 | |
199 | w->next->prev = w->prev; |
200 | w->prev->next = w->next; |
201 | |
      // Use `w->next == w` to indicate that the waiter has been removed
      // from the queue.
204 | w->next = w; |
205 | w->prev = w; |
206 | } |
207 | } |
208 | if (w != nullptr) { |
    // We call notify_one() without holding any locks, so we can miss
    // notifications. The wake-up logic is best effort and a thread will wake
    // up within a short period of time in case a notification is missed.
212 | w->cv.notify_one(); |
213 | } |
214 | VLOG(3) << "Added " << (is_blocking ? "inter" : "intra" ) << " work from " |
215 | << traceme_id_.load(std::memory_order_relaxed); |
216 | return t; |
217 | } |
218 | |
219 | Task ThreadWorkSource::PopBlockingTask() { |
220 | return blocking_work_queue_.PopBack(); |
221 | } |
222 | |
223 | Task ThreadWorkSource::PopNonBlockingTask(int start_index, |
224 | bool search_from_all_queue) { |
225 | Task t; |
226 | unsigned sharding_factor = NonBlockingWorkShardingFactor(); |
227 | for (unsigned j = 0; j < sharding_factor; ++j) { |
228 | t = non_blocking_work_queues_[(start_index + j) % sharding_factor] |
229 | ->queue.PopBack(); |
230 | if (t.f) { |
231 | return t; |
232 | } |
233 | if (!search_from_all_queue) { |
234 | break; |
235 | } |
236 | } |
237 | return t; |
238 | } |
239 | |
240 | void ThreadWorkSource::WaitForWork(int max_sleep_micros) { |
241 | thread_local Waiter waiter; |
242 | WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros); |
243 | } |
244 | |
245 | int ThreadWorkSource::TaskQueueSize(bool is_blocking) { |
246 | if (is_blocking) { |
247 | return blocking_work_queue_.Size(); |
248 | } else { |
249 | unsigned total_size = 0; |
250 | for (int i = 0; i < non_blocking_work_sharding_factor_; ++i) { |
251 | total_size += non_blocking_work_queues_[i]->queue.Size(); |
252 | } |
253 | return total_size; |
254 | } |
255 | } |
256 | |
257 | int64_t ThreadWorkSource::GetTracemeId() { |
258 | return traceme_id_.load(std::memory_order_relaxed); |
259 | } |
260 | |
261 | void ThreadWorkSource::SetTracemeId(int64_t value) { traceme_id_ = value; } |
262 | |
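// SetWaiter uses a double-checked locking pattern: a shared lock covers the
// common case where the waiter is unchanged or the update is stale, and the
// exclusive lock is taken only when the waiter actually needs to be replaced.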
263 | void ThreadWorkSource::SetWaiter(uint64 version, Waiter* waiter, mutex* mutex) { |
264 | { |
265 | tf_shared_lock lock(run_handler_waiter_mu_); |
    // Most requests do not change their sub thread pool on recomputation.
    // As an optimization, check under a shared lock first to avoid taking
    // the exclusive lock and to reduce contention.
268 | if (sub_thread_pool_waiter_ == waiter) { |
269 | return; |
270 | } |
    // If the current version is newer, there is no need to update.
272 | if (version_ > version) { |
273 | return; |
274 | } |
275 | } |
276 | |
277 | mutex_lock l(run_handler_waiter_mu_); |
278 | sub_thread_pool_waiter_ = waiter; |
279 | sub_thread_pool_waiter_mu_ = mutex; |
280 | version_ = version; |
281 | } |
282 | |
283 | int64_t ThreadWorkSource::GetInflightTaskCount(bool is_blocking) { |
284 | std::atomic<int64_t>* counter = |
285 | is_blocking ? &blocking_inflight_ : &non_blocking_inflight_; |
286 | return counter->load(std::memory_order_relaxed); |
287 | } |
288 | |
289 | void ThreadWorkSource::IncrementInflightTaskCount(bool is_blocking) { |
290 | std::atomic<int64_t>* counter = |
291 | is_blocking ? &blocking_inflight_ : &non_blocking_inflight_; |
292 | counter->fetch_add(1, std::memory_order_relaxed); |
293 | } |
294 | |
295 | void ThreadWorkSource::DecrementInflightTaskCount(bool is_blocking) { |
296 | std::atomic<int64_t>* counter = |
297 | is_blocking ? &blocking_inflight_ : &non_blocking_inflight_; |
298 | counter->fetch_sub(1, std::memory_order_relaxed); |
299 | } |
300 | |
301 | unsigned ThreadWorkSource::NonBlockingWorkShardingFactor() { |
302 | return non_blocking_work_sharding_factor_; |
303 | } |
304 | |
305 | std::string ThreadWorkSource::ToString() { |
306 | return strings::StrCat("traceme_id = " , GetTracemeId(), |
307 | ", inter queue size = " , TaskQueueSize(true), |
308 | ", inter inflight = " , GetInflightTaskCount(true), |
309 | ", intra queue size = " , TaskQueueSize(false), |
310 | ", intra inflight = " , GetInflightTaskCount(false)); |
311 | } |
312 | |
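// RunHandlerThreadPool reads several environment variables at construction
// time; list-valued parameters are parsed by ParamFromEnvWithDefault as
// comma-separated values. An illustrative (not default) configuration:
//
//   TF_RUN_HANDLER_USE_SUB_THREAD_POOL=true
//   TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL=2,2
//   TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE=0,0.4
//   TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE=0.4,1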
313 | RunHandlerThreadPool::RunHandlerThreadPool( |
314 | int num_blocking_threads, int num_non_blocking_threads, Env* env, |
315 | const ThreadOptions& thread_options, const string& name, |
316 | Eigen::MaxSizeVector<mutex>* waiters_mu, |
317 | Eigen::MaxSizeVector<Waiter>* queue_waiters) |
318 | : num_threads_(num_blocking_threads + num_non_blocking_threads), |
319 | num_blocking_threads_(num_blocking_threads), |
320 | num_non_blocking_threads_(num_non_blocking_threads), |
321 | thread_data_(num_threads_), |
322 | env_(env, thread_options, name), |
323 | name_(name), |
324 | waiters_mu_(waiters_mu), |
325 | queue_waiters_(queue_waiters), |
      use_sub_thread_pool_(ParamFromEnvBoolWithDefault(
          "TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)),
      num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault(
          "TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL",
          std::vector<int>({num_blocking_threads / 2,
                            num_blocking_threads - num_blocking_threads / 2}))),
      sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault(
          "TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
          std::vector<double>({0, 0.4}))),
      sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
          "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
          std::vector<double>({0.4, 1}))) {
338 | thread_data_.resize(num_threads_); |
339 | VLOG(1) << "Creating RunHandlerThreadPool " << name << " with " |
340 | << num_blocking_threads_ << " blocking threads and " |
          << num_non_blocking_threads_ << " non-blocking threads.";
342 | } |
343 | |
344 | RunHandlerThreadPool::~RunHandlerThreadPool() { |
345 | VLOG(1) << "Exiting RunHandlerThreadPool " << name_; |
346 | |
347 | cancelled_ = true; |
348 | for (size_t i = 0; i < thread_data_.size(); ++i) { |
349 | { |
350 | mutex_lock l(thread_data_[i].mu); |
351 | thread_data_[i].sources_not_empty.notify_all(); |
352 | } |
353 | thread_data_[i].thread.reset(); |
354 | } |
355 | } |
356 | |
357 | void RunHandlerThreadPool::Start() { |
358 | cancelled_ = false; |
359 | int num_blocking_threads = num_blocking_threads_; |
360 | for (int i = 0; i < num_threads_; i++) { |
361 | int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1; |
362 | for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) { |
363 | if (i < num_threads_in_sub_thread_pool_[j]) { |
364 | sub_thread_pool_id = j; |
365 | break; |
366 | } |
367 | } |
368 | thread_data_[i].sub_thread_pool_id = sub_thread_pool_id; |
    const bool is_blocking_thread = (i < num_blocking_threads);
    // Blocking threads handle both inter-op and intra-op work; non-blocking
    // threads handle intra-op work only. Sub thread pools are only provided
    // for blocking threads. Name the threads accordingly.
374 | thread_data_[i].thread.reset(env_.CreateThread( |
375 | [this, is_blocking_thread, i, sub_thread_pool_id]() { |
376 | WorkerLoop(i, is_blocking_thread); |
377 | }, |
378 | is_blocking_thread |
            ? strings::StrCat(name_, "_blocking_thread_", sub_thread_pool_id)
            : strings::StrCat(name_, "_non_blocking_thread")));
381 | } |
382 | } |
383 | |
384 | void RunHandlerThreadPool::StartOneThreadForTesting() { |
385 | cancelled_ = false; |
386 | thread_data_[0].sub_thread_pool_id = 0; |
387 | thread_data_[0].thread.reset( |
388 | env_.CreateThread([this]() { WorkerLoop(0, true); }, name_)); |
389 | } |
390 | |
391 | void RunHandlerThreadPool::AddWorkToQueue(ThreadWorkSource* tws, |
392 | bool is_blocking, |
393 | std::function<void()> fn) { |
394 | Task t = env_.CreateTask(std::move(fn)); |
395 | t = tws->EnqueueTask(std::move(t), is_blocking); |
396 | if (t.f) { |
397 | VLOG(3) << "Running " << (is_blocking ? "inter" : "intra" ) << " work for " |
398 | << tws->GetTracemeId(); |
399 | env_.ExecuteTask(t); |
400 | } |
401 | } |
402 | |
// TODO(donglin): Change the task stealing order to round-robin, so that if an
// attempt to steal a task from request i fails, the next attempt targets the
// next request in arrival-time order. This approach may provide better
// performance due to reduced lock contention. The drawback is that the
// profiler output will be a bit harder to read.
408 | void RunHandlerThreadPool::SetThreadWorkSources( |
409 | int tid, int start_request_idx, uint64 version, |
410 | const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) { |
411 | mutex_lock l(thread_data_[tid].mu); |
412 | if (version > thread_data_[tid].new_version) { |
413 | thread_data_[tid].new_version = version; |
414 | } else { |
415 | // A newer version is already updated. No need to update. |
416 | return; |
417 | } |
418 | thread_data_[tid].new_thread_work_sources->resize(0); |
419 | if (use_sub_thread_pool_) { |
420 | for (int i = 0; i < thread_work_sources.size(); ++i) { |
421 | thread_data_[tid].new_thread_work_sources->emplace_back( |
422 | thread_work_sources[i]); |
423 | } |
424 | } else { |
425 | thread_data_[tid].new_thread_work_sources->emplace_back( |
426 | thread_work_sources[start_request_idx]); |
    // The number of shards for the queue. Threads in each shard will
    // prioritize different thread_work_sources. Increasing the number of
    // shards can decrease contention on the queue. For example, when
    // num_shards == 1: thread_work_sources are ordered as start_request_idx,
    // 0, 1, 2, 3, 4 ... for all threads. When num_shards == 2:
    // thread_work_sources are ordered as start_request_idx, 0, 2, 4 ... 1, 3,
    // 5 ... for half of the threads and start_request_idx, 1, 3, 5 ... 0, 2,
    // 4 ... for the other half of the threads.
435 | static const int num_shards = |
436 | ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS" , 1); |
437 | int token = tid % num_shards; |
438 | for (int i = 0; i < num_shards; ++i) { |
439 | for (int j = token; j < thread_work_sources.size(); j += num_shards) { |
440 | if (j != start_request_idx) { |
441 | thread_data_[tid].new_thread_work_sources->emplace_back( |
442 | thread_work_sources[j]); |
443 | } |
444 | } |
445 | token = (token + 1) % num_shards; |
446 | } |
447 | thread_data_[tid].sources_not_empty.notify_all(); |
448 | } |
449 | } |
450 | |
451 | RunHandlerThreadPool::PerThread* RunHandlerThreadPool::GetPerThread() { |
452 | thread_local RunHandlerThreadPool::PerThread per_thread_; |
453 | RunHandlerThreadPool::PerThread* pt = &per_thread_; |
454 | return pt; |
455 | } |
456 | |
457 | int RunHandlerThreadPool::CurrentThreadId() const { |
458 | const PerThread* pt = const_cast<RunHandlerThreadPool*>(this)->GetPerThread(); |
459 | if (pt->pool == this) { |
460 | return pt->thread_id; |
461 | } else { |
462 | return -1; |
463 | } |
464 | } |
465 | |
466 | int RunHandlerThreadPool::NumThreads() const { return num_threads_; } |
467 | |
468 | int RunHandlerThreadPool::NumBlockingThreads() const { |
469 | return num_blocking_threads_; |
470 | } |
471 | |
472 | int RunHandlerThreadPool::NumNonBlockingThreads() const { |
473 | return num_non_blocking_threads_; |
474 | } |
475 | |
476 | RunHandlerThreadPool::ThreadData::ThreadData() |
477 | : new_version(0), |
478 | current_index(0), |
479 | new_thread_work_sources( |
480 | new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>( |
              ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
482 | kMaxConcurrentHandlers)))), |
483 | current_version(0), |
484 | current_thread_work_sources( |
485 | new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>( |
              ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
487 | kMaxConcurrentHandlers)))) {} |
488 | |
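// Scans the requests in [searching_range_start, searching_range_end) for a
// task, resuming round-robin from the thread's saved position
// (thread_data_[thread_id].current_index) and wrapping around within the
// range. A blocking task is only taken from a request with fewer than
// `max_blocking_inflight` blocking tasks in flight. On success, sets
// `*task_from_blocking_queue` and `*tws` to describe the returned task.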
489 | Task RunHandlerThreadPool::FindTask( |
490 | int searching_range_start, int searching_range_end, int thread_id, |
491 | int sub_thread_pool_id, int max_blocking_inflight, |
492 | bool may_steal_blocking_work, |
493 | const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources, |
494 | bool* task_from_blocking_queue, ThreadWorkSource** tws) { |
495 | Task t; |
496 | int current_index = thread_data_[thread_id].current_index; |
497 | *task_from_blocking_queue = false; |
498 | |
499 | for (int i = 0; i < searching_range_end - searching_range_start; ++i) { |
500 | if (current_index >= searching_range_end || |
501 | current_index < searching_range_start) { |
502 | current_index = searching_range_start; |
503 | } |
504 | *tws = thread_work_sources[current_index]; |
505 | ++current_index; |
506 | |
507 | // For blocking thread, search for blocking tasks first. |
508 | if (may_steal_blocking_work && |
509 | (*tws)->GetInflightTaskCount(true) < max_blocking_inflight) { |
510 | t = (*tws)->PopBlockingTask(); |
511 | if (t.f) { |
512 | *task_from_blocking_queue = true; |
513 | break; |
514 | } |
515 | } |
516 | |
517 | // Search for non-blocking tasks. |
518 | t = (*tws)->PopNonBlockingTask(thread_id, true); |
519 | if (t.f) { |
520 | break; |
521 | } |
522 | } |
523 | thread_data_[thread_id].current_index = current_index; |
524 | return t; |
525 | } |
526 | |
527 | // Main worker thread loop. |
528 | void RunHandlerThreadPool::WorkerLoop(int thread_id, |
529 | bool may_steal_blocking_work) { |
530 | PerThread* pt = GetPerThread(); |
531 | pt->pool = this; |
532 | pt->thread_id = thread_id; |
533 | static constexpr int32_t kMaxBlockingInflight = 10; |
534 | |
535 | while (!cancelled_) { |
536 | Task t; |
537 | ThreadWorkSource* tws = nullptr; |
538 | bool task_from_blocking_queue = true; |
    int sub_thread_pool_id = 0;
540 | // Get the current thread work sources. |
541 | { |
542 | mutex_lock l(thread_data_[thread_id].mu); |
543 | if (thread_data_[thread_id].current_version < |
544 | thread_data_[thread_id].new_version) { |
545 | thread_data_[thread_id].current_version = |
546 | thread_data_[thread_id].new_version; |
547 | thread_data_[thread_id].current_thread_work_sources.swap( |
548 | thread_data_[thread_id].new_thread_work_sources); |
549 | } |
550 | } |
551 | Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources = |
552 | thread_data_[thread_id].current_thread_work_sources.get(); |
553 | if (use_sub_thread_pool_) { |
554 | sub_thread_pool_id = thread_data_[thread_id].sub_thread_pool_id; |
555 | int active_requests = thread_work_sources->size(); |
556 | if (may_steal_blocking_work) { |
        // Each thread first looks for tasks from requests that belong to
        // its sub thread pool.
559 | int search_range_start = |
560 | active_requests * |
561 | sub_thread_pool_start_request_percentage_[sub_thread_pool_id]; |
562 | int search_range_end = |
563 | active_requests * |
564 | sub_thread_pool_end_request_percentage_[sub_thread_pool_id]; |
565 | search_range_end = |
566 | std::min(active_requests, |
567 | std::max(search_range_end, search_range_start + 1)); |
568 | |
569 | t = FindTask(search_range_start, search_range_end, thread_id, |
570 | sub_thread_pool_id, kMaxBlockingInflight, |
571 | /*may_steal_blocking_work=*/true, *thread_work_sources, |
572 | &task_from_blocking_queue, &tws); |
573 | if (!t.f) { |
574 | // Search from all requests if the thread cannot find tasks from |
575 | // requests that belong to its own sub thread pool. |
576 | t = FindTask(0, active_requests, thread_id, sub_thread_pool_id, |
577 | kMaxBlockingInflight, |
578 | /*may_steal_blocking_work=*/true, *thread_work_sources, |
579 | &task_from_blocking_queue, &tws); |
580 | } |
581 | } else { |
        // Non-blocking threads always search across all pending requests.
584 | t = FindTask(0, active_requests, thread_id, sub_thread_pool_id, |
585 | kMaxBlockingInflight, |
586 | /*may_steal_blocking_work=*/false, *thread_work_sources, |
587 | &task_from_blocking_queue, &tws); |
588 | } |
589 | } else { |
590 | // TODO(chaox): Refactor the following code to share the logic with |
591 | // FindTask. |
592 | for (int i = 0; i < thread_work_sources->size(); ++i) { |
593 | tws = (*thread_work_sources)[i]; |
        // We want a smallish number of inter op threads since otherwise
        // there will be contention in PropagateOutputs. This is a
        // best-effort policy.
597 | if (may_steal_blocking_work && |
598 | tws->GetInflightTaskCount(true) < kMaxBlockingInflight) { |
599 | t = tws->PopBlockingTask(); |
600 | if (t.f) { |
601 | break; |
602 | } |
603 | } |
604 | if (i == 0) { |
605 | // Always look for any work from the "primary" work source. |
606 | // This way when we wake up a thread for a new closure we are |
607 | // guaranteed it can be worked on. |
          t = tws->PopNonBlockingTask(thread_id, true);
          if (t.f) {
            task_from_blocking_queue = false;
            break;
          }
616 | } else { |
617 | t = tws->PopNonBlockingTask(thread_id, false); |
618 | if (t.f) { |
619 | task_from_blocking_queue = false; |
620 | break; |
621 | } |
622 | } |
623 | } |
624 | } |
625 | if (t.f) { |
626 | profiler::TraceMe activity( |
627 | [=] { |
            return strings::StrCat(task_from_blocking_queue ? "inter" : "intra",
                                   " #id = ", tws->GetTracemeId(), " ",
                                   thread_id, "#");
631 | }, |
632 | profiler::TraceMeLevel::kInfo); |
633 | VLOG(2) << "Running " << (task_from_blocking_queue ? "inter" : "intra" ) |
634 | << " work from " << tws->GetTracemeId(); |
635 | tws->IncrementInflightTaskCount(task_from_blocking_queue); |
636 | env_.ExecuteTask(t); |
637 | tws->DecrementInflightTaskCount(task_from_blocking_queue); |
638 | } else { |
639 | profiler::TraceMe activity( |
640 | [=] { |
641 | return strings::StrCat("Sleeping#thread_id=" , thread_id, "#" ); |
642 | }, |
643 | profiler::TraceMeLevel::kInfo); |
644 | if (VLOG_IS_ON(4)) { |
645 | for (int i = 0; i < thread_work_sources->size(); ++i) { |
646 | VLOG(4) << "source id " << i << " " |
647 | << (*thread_work_sources)[i]->ToString(); |
648 | } |
649 | } |
650 | if (use_sub_thread_pool_) { |
651 | WaitForWorkInSubThreadPool(may_steal_blocking_work, sub_thread_pool_id); |
652 | } else { |
653 | WaitForWork(may_steal_blocking_work, thread_id, kMaxBlockingInflight); |
654 | } |
655 | } |
656 | } |
657 | } |
658 | |
659 | void RunHandlerThreadPool::WaitForWorkInSubThreadPool(bool is_blocking, |
660 | int sub_thread_pool_id) { |
661 | const int kMaxSleepMicros = 250; |
662 | |
663 | // The non-blocking thread will just sleep. |
664 | if (!is_blocking) { |
665 | Env::Default()->SleepForMicroseconds(kMaxSleepMicros); |
666 | return; |
667 | } |
668 | |
669 | thread_local Waiter waiter; |
670 | WaitOnWaiter(&waiter, &(*queue_waiters_)[sub_thread_pool_id], |
671 | &(*waiters_mu_)[sub_thread_pool_id], kMaxSleepMicros); |
672 | } |
673 | |
674 | void RunHandlerThreadPool::WaitForWork(bool is_blocking, int thread_id, |
675 | int32_t max_blocking_inflight) { |
676 | const int kMaxSleepMicros = 250; |
677 | |
678 | // The non-blocking thread will just sleep. |
679 | if (!is_blocking) { |
680 | Env::Default()->SleepForMicroseconds(kMaxSleepMicros); |
681 | return; |
682 | } |
683 | |
684 | ThreadWorkSource* tws = nullptr; |
685 | { |
686 | mutex_lock l(thread_data_[thread_id].mu); |
687 | if (thread_data_[thread_id].new_version > |
688 | thread_data_[thread_id].current_version) { |
689 | thread_data_[thread_id].current_thread_work_sources.swap( |
690 | thread_data_[thread_id].new_thread_work_sources); |
691 | thread_data_[thread_id].current_version = |
692 | thread_data_[thread_id].new_version; |
693 | } |
694 | Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources = |
695 | thread_data_[thread_id].current_thread_work_sources.get(); |
696 | while (!cancelled_ && thread_work_sources->empty()) { |
      // Wait until there is a new request.
698 | thread_data_[thread_id].sources_not_empty.wait(l); |
699 | if (thread_data_[thread_id].new_version > |
700 | thread_data_[thread_id].current_version) { |
701 | thread_data_[thread_id].current_thread_work_sources.swap( |
702 | thread_data_[thread_id].new_thread_work_sources); |
703 | thread_data_[thread_id].current_version = |
704 | thread_data_[thread_id].new_version; |
705 | thread_work_sources = |
706 | thread_data_[thread_id].current_thread_work_sources.get(); |
707 | } |
708 | } |
709 | if (cancelled_) { |
710 | return; |
711 | } |
712 | tws = (*thread_work_sources)[0]; |
713 | } |
714 | |
715 | if (tws->GetInflightTaskCount(true) >= max_blocking_inflight) { |
716 | // Sleep to reduce contention in PropagateOutputs |
717 | Env::Default()->SleepForMicroseconds(kMaxSleepMicros); |
718 | } |
719 | tws->WaitForWork(kMaxSleepMicros); |
720 | } |
721 | |
722 | } // namespace internal |
723 | |
724 | // Contains the concrete implementation of the RunHandler. |
725 | // Externally visible RunHandler class simply forwards the work to this one. |
726 | class RunHandler::Impl { |
727 | public: |
728 | explicit Impl(RunHandlerPool::Impl* pool_impl); |
729 | |
730 | ~Impl() {} |
731 | |
732 | thread::ThreadPoolInterface* thread_pool_interface() { |
733 | return thread_pool_interface_.get(); |
734 | } |
735 | |
  // Returns the time (in microseconds since the Unix epoch) at which the
  // handler was requested via RunHandlerPool::Get().
738 | uint64 start_time_us() const { return start_time_us_; } |
739 | int64_t step_id() const { return step_id_; } |
740 | void ScheduleInterOpClosure(std::function<void()> fn); |
741 | void ScheduleIntraOpClosure(std::function<void()> fn); |
742 | |
743 | void Reset(int64_t step_id, |
744 | const RunOptions::Experimental::RunHandlerPoolOptions& options); |
745 | |
746 | RunHandlerPool::Impl* pool_impl() { return pool_impl_; } |
747 | |
748 | internal::ThreadWorkSource* tws() { return &tws_; } |
749 | |
750 | int64_t priority() { return options_.priority(); } |
751 | |
752 | private: |
753 | class ThreadPoolInterfaceWrapper : public thread::ThreadPoolInterface { |
754 | public: |
755 | explicit ThreadPoolInterfaceWrapper(Impl* run_handler_impl) |
756 | : run_handler_impl_(run_handler_impl) {} |
757 | ~ThreadPoolInterfaceWrapper() override {} |
758 | void Schedule(std::function<void()> fn) override; |
759 | int NumThreads() const override; |
760 | int CurrentThreadId() const override; |
761 | |
762 | private: |
763 | RunHandler::Impl* run_handler_impl_ = nullptr; |
764 | }; |
765 | |
766 | RunHandlerPool::Impl* pool_impl_; // NOT OWNED. |
767 | uint64 start_time_us_; |
768 | int64_t step_id_; |
769 | std::unique_ptr<thread::ThreadPoolInterface> thread_pool_interface_; |
770 | internal::ThreadWorkSource tws_; |
771 | RunOptions::Experimental::RunHandlerPoolOptions options_; |
772 | }; |
773 | |
774 | // Contains shared state across all run handlers present in the pool. Also |
775 | // responsible for pool management decisions. |
776 | // This class is thread safe. |
777 | class RunHandlerPool::Impl { |
778 | public: |
779 | explicit Impl(int num_inter_op_threads, int num_intra_op_threads) |
      : max_handlers_(static_cast<int32>(ParamFromEnvWithDefault(
            "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", kMaxConcurrentHandlers))),
        waiters_mu_(
            ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
        queue_waiters_(
            ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
        run_handler_thread_pool_(new internal::RunHandlerThreadPool(
            num_inter_op_threads, num_intra_op_threads, Env::Default(),
            ThreadOptions(), "tf_run_handler_pool", &waiters_mu_,
            &queue_waiters_)),
        iterations_(0),
        version_(0),
        sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
            "TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
794 | std::vector<double>({1}))) { |
795 | VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_; |
796 | free_handlers_.reserve(max_handlers_); |
797 | handlers_.reserve(max_handlers_); |
798 | for (int i = 0; i < max_handlers_; ++i) { |
799 | handlers_.emplace_back(new RunHandler::Impl(this)); |
800 | free_handlers_.push_back(handlers_.back().get()); |
801 | } |
    queue_waiters_.resize(
        ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
    waiters_mu_.resize(
        ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
806 | for (auto& queue_waiter : queue_waiters_) { |
807 | queue_waiter.next = &queue_waiter; |
808 | queue_waiter.prev = &queue_waiter; |
809 | } |
810 | run_handler_thread_pool_->Start(); |
811 | } |
812 | |
813 | ~Impl() { |
    // Sanity check that all handlers have been returned to the pool before
    // destruction.
816 | DCHECK_EQ(handlers_.size(), max_handlers_); |
817 | DCHECK_EQ(free_handlers_.size(), handlers_.size()); |
818 | DCHECK_EQ(sorted_active_handlers_.size(), 0); |
819 | // Stop the threads in run_handler_thread_pool_ before freeing other |
820 | // pointers. Otherwise a thread may try to access a pointer after the |
821 | // pointer has been freed. |
822 | run_handler_thread_pool_.reset(); |
823 | } |
824 | |
825 | internal::RunHandlerThreadPool* run_handler_thread_pool() { |
826 | return run_handler_thread_pool_.get(); |
827 | } |
828 | |
829 | bool has_free_handler() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
830 | return !free_handlers_.empty(); |
831 | } |
832 | |
833 | std::unique_ptr<RunHandler> Get( |
834 | int64_t step_id, int64_t timeout_in_ms, |
835 | const RunOptions::Experimental::RunHandlerPoolOptions& options) |
836 | TF_LOCKS_EXCLUDED(mu_) { |
837 | thread_local std::unique_ptr< |
838 | Eigen::MaxSizeVector<internal::ThreadWorkSource*>> |
839 | thread_work_sources = |
840 | std::unique_ptr<Eigen::MaxSizeVector<internal::ThreadWorkSource*>>( |
841 | new Eigen::MaxSizeVector<internal::ThreadWorkSource*>( |
842 | static_cast<int32>(ParamFromEnvWithDefault( |
843 | "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS" , |
844 | kMaxConcurrentHandlers)))); |
845 | uint64 version; |
846 | int num_active_requests; |
847 | RunHandler::Impl* handler_impl; |
848 | { |
849 | mutex_lock l(mu_); |
850 | if (!has_free_handler()) { |
851 | profiler::TraceMe activity( |
852 | [&] { |
853 | return strings::StrCat("WaitingForHandler#step_id=" , step_id, |
854 | "#" ); |
855 | }, |
856 | profiler::TraceMeLevel::kInfo); |
      TRACESTRING(
          strings::StrCat("RunHandlerPool::Impl::Get waiting for a handler "
                          "with timeout in milliseconds: ",
                          timeout_in_ms));
861 | if (timeout_in_ms == 0) { |
862 | mu_.Await(Condition(this, &Impl::has_free_handler)); |
863 | } else if (!mu_.AwaitWithDeadline( |
864 | Condition(this, &Impl::has_free_handler), |
865 | EnvTime::NowNanos() + timeout_in_ms * 1000 * 1000)) { |
866 | return nullptr; |
867 | } |
868 | } |
869 | // Remove the last entry from free_handlers_ and add to the end of |
870 | // sorted_active_handlers_. |
871 | handler_impl = free_handlers_.back(); |
872 | handler_impl->Reset(step_id, options); |
873 | free_handlers_.pop_back(); |
874 | |
875 | num_active_requests = sorted_active_handlers_.size() + 1; |
876 | thread_work_sources->resize(num_active_requests); |
877 | int priority = options.priority(); |
878 | auto it = sorted_active_handlers_.cbegin(); |
879 | bool new_handler_inserted = false; |
880 | for (int i = 0; i < num_active_requests; ++i) { |
881 | if (!new_handler_inserted && (it == sorted_active_handlers_.cend() || |
882 | priority > (*it)->priority())) { |
883 | sorted_active_handlers_.insert(it, handler_impl); |
884 | new_handler_inserted = true; |
885 | // Point to the newly added handler. |
886 | --it; |
887 | } |
888 | (*thread_work_sources)[i] = (*it)->tws(); |
889 | ++it; |
890 | } |
891 | version = ++version_; |
892 | } |
893 | RecomputePoolStats(num_active_requests, version, *thread_work_sources); |
894 | return WrapUnique<RunHandler>(new RunHandler(handler_impl)); |
895 | } |
896 | |
897 | void ReleaseHandler(RunHandler::Impl* handler) TF_LOCKS_EXCLUDED(mu_) { |
898 | mutex_lock l(mu_); |
899 | DCHECK_GT(sorted_active_handlers_.size(), 0); |
900 | |
901 | CHECK_EQ(handler->tws()->TaskQueueSize(true), 0); // Crash OK. |
902 | CHECK_EQ(handler->tws()->TaskQueueSize(false), 0); // Crash OK. |
903 | |
904 | uint64 now = tensorflow::EnvTime::NowMicros(); |
905 | double elapsed = (now - handler->start_time_us()) / 1000.0; |
906 | time_hist_.Add(elapsed); |
907 | |
908 | // Erase from and update sorted_active_handlers_. Add it to the end of |
909 | // free_handlers_. |
910 | auto iter = std::find(sorted_active_handlers_.begin(), |
911 | sorted_active_handlers_.end(), handler); |
912 | DCHECK(iter != sorted_active_handlers_.end()) |
913 | << "Unexpected handler: " << handler |
914 | << " is being requested for release" ; |
915 | |
916 | // Remove this handler from this list and add it to the list of free |
917 | // handlers. |
918 | sorted_active_handlers_.erase(iter); |
919 | free_handlers_.push_back(handler); |
920 | DCHECK_LE(free_handlers_.size(), max_handlers_); |
921 | LogInfo(); |
922 | |
923 | // We do not recompute pool stats during release. The side effect is that |
924 | // there may be empty thread work sources in the queue. However, any new |
925 | // requests will trigger recomputation. |
926 | } |
927 | |
928 | std::vector<int64_t> GetActiveHandlerPrioritiesForTesting() |
929 | TF_LOCKS_EXCLUDED(mu_) { |
930 | mutex_lock l(mu_); |
931 | std::vector<int64_t> ret; |
932 | for (const auto& handler_impl : sorted_active_handlers_) { |
933 | ret.push_back(handler_impl->priority()); |
934 | } |
935 | return ret; |
936 | } |
937 | |
938 | private: |
939 | void RecomputePoolStats( |
940 | int num_active_requests, uint64 version, |
941 | const Eigen::MaxSizeVector<internal::ThreadWorkSource*>& |
942 | thread_work_sources); |
943 | |
944 | void LogInfo() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
945 | |
  // Maximum number of handlers pre-created when the pool is constructed. The
  // number was chosen expecting that each handler might want at least one
  // inter-op thread for execution (during compute-intensive workloads like
  // inference).
950 | const int max_handlers_; |
951 | |
952 | Eigen::MaxSizeVector<mutex> waiters_mu_; |
953 | Eigen::MaxSizeVector<internal::Waiter> queue_waiters_; |
954 | |
955 | std::unique_ptr<internal::RunHandlerThreadPool> run_handler_thread_pool_; |
  // Thread-compatible state below, accessed only while holding the
  // RunHandlerPool's lock (mu_).
  // Handlers are sorted by start time.
  // TODO(azaks): sort by the remaining latency budget.
  // TODO(chaox): Consider other data structures for maintaining the sorted
  // active handlers if the searching overhead (currently O(n)) becomes the
  // bottleneck.
962 | std::list<RunHandler::Impl*> sorted_active_handlers_ TF_GUARDED_BY(mu_); |
963 | std::vector<RunHandler::Impl*> free_handlers_ TF_GUARDED_BY(mu_); |
964 | std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ TF_GUARDED_BY(mu_); |
965 | |
966 | // Histogram of elapsed runtime of every handler (in ms). |
967 | histogram::Histogram time_hist_ TF_GUARDED_BY(mu_); |
968 | |
969 | int64_t iterations_ TF_GUARDED_BY(mu_); |
970 | mutex mu_; |
971 | int64_t version_ TF_GUARDED_BY(mu_); |
972 | const std::vector<double> sub_thread_pool_end_request_percentage_; |
973 | }; |
974 | |
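// Maps each active request to a sub thread pool based on its position in the
// priority-sorted handler list and the configured end-request percentages,
// then redistributes thread work sources across all threads. For example,
// with end percentages {0.4, 1} and 10 active requests, requests 0-3 are
// waited on by sub thread pool 0 and requests 4-9 by sub thread pool 1.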
975 | void RunHandlerPool::Impl::RecomputePoolStats( |
976 | int num_active_requests, uint64 version, |
977 | const Eigen::MaxSizeVector<internal::ThreadWorkSource*>& |
978 | thread_work_sources) { |
979 | if (num_active_requests == 0) return; |
980 | |
981 | int sub_thread_pool_id = 0; |
982 | for (int i = 0; i < num_active_requests; ++i) { |
983 | while ( |
984 | sub_thread_pool_id < |
985 | sub_thread_pool_end_request_percentage_.size() - 1 && |
986 | i >= num_active_requests * |
987 | sub_thread_pool_end_request_percentage_[sub_thread_pool_id]) { |
988 | sub_thread_pool_id++; |
989 | } |
990 | thread_work_sources[i]->SetWaiter(version, |
991 | &queue_waiters_[sub_thread_pool_id], |
992 | &waiters_mu_[sub_thread_pool_id]); |
993 | } |
994 | |
995 | int num_threads = run_handler_thread_pool()->NumThreads(); |
996 | int num_blocking_threads = run_handler_thread_pool()->NumBlockingThreads(); |
997 | int num_non_blocking_threads = num_threads - num_blocking_threads; |
998 | |
999 | std::vector<int> request_idx_list = ChooseRequestsWithExponentialDistribution( |
1000 | num_active_requests, num_blocking_threads); |
1001 | for (int i = 0; i < num_blocking_threads; ++i) { |
1002 | VLOG(2) << "Set work for tid=" << i |
1003 | << " with start_request_idx=" << request_idx_list[i]; |
1004 | run_handler_thread_pool()->SetThreadWorkSources( |
1005 | i, request_idx_list[i], version, thread_work_sources); |
1006 | } |
1007 | |
1008 | request_idx_list = ChooseRequestsWithExponentialDistribution( |
1009 | num_active_requests, num_non_blocking_threads); |
1010 | for (int i = 0; i < num_non_blocking_threads; ++i) { |
1011 | VLOG(2) << "Set work for tid=" << (i + num_blocking_threads) |
1012 | << " with start_request_idx=" << request_idx_list[i]; |
1013 | run_handler_thread_pool()->SetThreadWorkSources( |
1014 | i + num_blocking_threads, request_idx_list[i], version, |
1015 | thread_work_sources); |
1016 | } |
1017 | } |
1018 | |
1019 | void RunHandlerPool::Impl::LogInfo() { |
1020 | if (iterations_++ % 50000 == 10 && VLOG_IS_ON(1)) { |
1021 | int num_active_requests = sorted_active_handlers_.size(); |
1022 | VLOG(1) << "Printing time histogram: " << time_hist_.ToString(); |
1023 | VLOG(1) << "Active session runs: " << num_active_requests; |
1024 | uint64 now = tensorflow::Env::Default()->NowMicros(); |
    string times_str = "";
    string ids_str = "";
1027 | auto it = sorted_active_handlers_.cbegin(); |
1028 | for (int i = 0; i < num_active_requests; ++i) { |
1029 | if (i > 0) { |
        times_str += " ";
        ids_str += " ";
1032 | } |
1033 | |
      times_str +=
          strings::StrCat((now - (*it)->start_time_us()) / 1000.0, " ms.");
1036 | ids_str += strings::StrCat((*it)->tws()->GetTracemeId()); |
1037 | ++it; |
1038 | } |
1039 | VLOG(1) << "Elapsed times are: " << times_str; |
1040 | VLOG(1) << "Step ids are: " << ids_str; |
1041 | } |
1042 | } |
1043 | |
// It is important to return a value such that
// CurrentThreadId() is in [0, NumThreads).
1046 | int RunHandler::Impl::ThreadPoolInterfaceWrapper::NumThreads() const { |
1047 | return run_handler_impl_->pool_impl_->run_handler_thread_pool()->NumThreads(); |
1048 | } |
1049 | |
1050 | int RunHandler::Impl::ThreadPoolInterfaceWrapper::CurrentThreadId() const { |
1051 | return run_handler_impl_->pool_impl_->run_handler_thread_pool() |
1052 | ->CurrentThreadId(); |
1053 | } |
1054 | |
1055 | void RunHandler::Impl::ThreadPoolInterfaceWrapper::Schedule( |
1056 | std::function<void()> fn) { |
1057 | return run_handler_impl_->ScheduleIntraOpClosure(std::move(fn)); |
1058 | } |
1059 | |
1060 | RunHandler::Impl::Impl(RunHandlerPool::Impl* pool_impl) |
1061 | : pool_impl_(pool_impl) { |
1062 | thread_pool_interface_.reset(new ThreadPoolInterfaceWrapper(this)); |
1063 | Reset(0, RunOptions::Experimental::RunHandlerPoolOptions()); |
1064 | } |
1065 | |
1066 | void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) { |
1067 | VLOG(3) << "Scheduling inter work for " << tws()->GetTracemeId(); |
1068 | pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), true, |
1069 | std::move(fn)); |
1070 | } |
1071 | |
1072 | void RunHandler::Impl::ScheduleIntraOpClosure(std::function<void()> fn) { |
1073 | VLOG(3) << "Scheduling intra work for " << tws()->GetTracemeId(); |
1074 | pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), false, |
1075 | std::move(fn)); |
1076 | } |
1077 | |
1078 | void RunHandler::Impl::Reset( |
1079 | int64_t step_id, |
1080 | const RunOptions::Experimental::RunHandlerPoolOptions& options) { |
1081 | start_time_us_ = tensorflow::Env::Default()->NowMicros(); |
1082 | step_id_ = step_id; |
1083 | options_ = options; |
1084 | tws_.SetTracemeId(step_id); |
1085 | } |
1086 | |
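// A minimal usage sketch of the public API (error handling omitted; `step_id`
// and `options` are caller-provided):
//
//   RunHandlerPool pool(/*num_inter_op_threads=*/16,
//                       /*num_intra_op_threads=*/16);
//   std::unique_ptr<RunHandler> handler =
//       pool.Get(step_id, /*timeout_in_ms=*/0, options);
//   handler->ScheduleInterOpClosure([]() { /* inter-op work */ });
//   // Destroying the handler returns it to the pool.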
1087 | RunHandlerPool::RunHandlerPool(int num_inter_op_threads) |
1088 | : impl_(new Impl(num_inter_op_threads, 0)) {} |
1089 | |
1090 | RunHandlerPool::RunHandlerPool(int num_inter_op_threads, |
1091 | int num_intra_op_threads) |
1092 | : impl_(new Impl(num_inter_op_threads, num_intra_op_threads)) {} |
1093 | |
1094 | RunHandlerPool::~RunHandlerPool() {} |
1095 | |
1096 | std::unique_ptr<RunHandler> RunHandlerPool::Get( |
1097 | int64_t step_id, int64_t timeout_in_ms, |
1098 | const RunOptions::Experimental::RunHandlerPoolOptions& options) { |
1099 | return impl_->Get(step_id, timeout_in_ms, options); |
1100 | } |
1101 | |
1102 | std::vector<int64_t> RunHandlerPool::GetActiveHandlerPrioritiesForTesting() |
1103 | const { |
1104 | return impl_->GetActiveHandlerPrioritiesForTesting(); |
1105 | } |
1106 | |
1107 | RunHandler::RunHandler(Impl* impl) : impl_(impl) {} |
1108 | |
1109 | void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) { |
1110 | impl_->ScheduleInterOpClosure(std::move(fn)); |
1111 | } |
1112 | |
1113 | thread::ThreadPoolInterface* RunHandler::AsIntraThreadPoolInterface() { |
1114 | return impl_->thread_pool_interface(); |
1115 | } |
1116 | |
1117 | RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); } |
1118 | |
1119 | } // namespace tensorflow |
1120 | |