1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_DATA_UNBOUNDED_THREAD_POOL_H_ |
16 | #define TENSORFLOW_CORE_DATA_UNBOUNDED_THREAD_POOL_H_ |
17 | |
18 | #include <deque> |
19 | #include <memory> |
20 | #include <vector> |
21 | |
22 | #include "tensorflow/core/framework/thread_factory.h" |
23 | #include "tensorflow/core/lib/core/notification.h" |
24 | #include "tensorflow/core/lib/core/threadpool_interface.h" |
25 | #include "tensorflow/core/platform/env.h" |
26 | #include "tensorflow/core/platform/unbounded_work_queue.h" |
27 | |
28 | namespace tensorflow { |
29 | namespace data { |
30 | |
31 | // An `UnboundedThreadPool` provides a mechanism for temporally multiplexing a |
32 | // potentially large number of "logical" threads onto a smaller number of |
33 | // "physical" threads. The multiplexing is achieved by using an |
34 | // `UnboundedWorkQueue`. |
35 | class UnboundedThreadPool : public thread::ThreadPoolInterface { |
36 | public: |
37 | UnboundedThreadPool(Env* env, const string& thread_name) |
38 | : unbounded_work_queue_(env, thread_name) {} |
39 | UnboundedThreadPool(Env* env, const string& thread_name, |
40 | const ThreadOptions& thread_options) |
41 | : unbounded_work_queue_(env, thread_name, thread_options) {} |
42 | ~UnboundedThreadPool() override = default; |
43 | |
44 | // Returns an implementation of `ThreadFactory` that can be used to create |
45 | // logical threads in this pool. |
46 | std::shared_ptr<ThreadFactory> get_thread_factory(); |
47 | |
48 | void Schedule(std::function<void()> fn) override; |
49 | int NumThreads() const override; |
50 | int CurrentThreadId() const override; |
51 | |
52 | private: |
53 | class LogicalThreadFactory; |
54 | class LogicalThreadWrapper; |
55 | |
56 | void ScheduleOnWorkQueue(std::function<void()> fn, |
57 | std::shared_ptr<Notification> done); |
58 | |
59 | UnboundedWorkQueue unbounded_work_queue_; |
60 | }; |
61 | |
62 | } // namespace data |
63 | } // namespace tensorflow |
64 | |
65 | #endif // TENSORFLOW_CORE_DATA_UNBOUNDED_THREAD_POOL_H_ |
66 | |