1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/tsl/platform/threadpool.h" |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #include "absl/types/optional.h" |
21 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
22 | #include "tensorflow/core/platform/context.h" |
23 | #include "tensorflow/tsl/platform/blocking_counter.h" |
24 | #include "tensorflow/tsl/platform/denormal.h" |
25 | #include "tensorflow/tsl/platform/logging.h" |
26 | #include "tensorflow/tsl/platform/mutex.h" |
27 | #include "tensorflow/tsl/platform/numa.h" |
28 | #include "tensorflow/tsl/platform/setround.h" |
29 | #include "tensorflow/tsl/platform/tracing.h" |
30 | |
31 | namespace tsl { |
32 | // TODO(aminim): remove after tensorflow/core/platform/context.h migration. |
33 | using tensorflow::Context; |
34 | using tensorflow::ContextKind; |
35 | using tensorflow::WithContext; |
36 | |
37 | namespace thread { |
38 | |
// Adapter between Eigen's ThreadPoolTempl and the TSL platform layer: it
// creates threads through Env (honoring ThreadOptions such as NUMA affinity),
// wraps scheduled closures so they run under the scheduling-time Context, and
// emits schedule/run tracing events.
struct EigenEnvironment {
  typedef Thread EnvThread;
  struct TaskImpl {
    std::function<void()> f;  // the user closure to run
    Context context;          // Context captured when the task was created
    uint64 trace_id;          // correlates schedule/run events; 0 if tracing off
  };
  struct Task {
    std::unique_ptr<TaskImpl> f;
  };

  Env* const env_;
  const ThreadOptions thread_options_;
  const string name_;  // name given to every thread this environment creates

  EigenEnvironment(Env* env, const ThreadOptions& thread_options,
                   const string& name)
      : env_(env), thread_options_(thread_options), name_(name) {}

  // Spawns a pool thread that runs `f` after per-thread setup (denormal
  // flushing, rounding mode, optional NUMA binding).
  // NOTE(review): [=] copies `f` and captures `this` by pointer; this assumes
  // the environment object (stored inside the Eigen pool) outlives the thread.
  EnvThread* CreateThread(std::function<void()> f) {
    return env_->StartThread(thread_options_, name_, [=]() {
      // Set the processor flag to flush denormals to zero.
      port::ScopedFlushDenormal flush;
      // Set the processor rounding mode to ROUND TO NEAREST.
      tsl::port::ScopedSetRound round(FE_TONEAREST);
      if (thread_options_.numa_node != port::kNUMANoAffinity) {
        port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
      }
      f();
    });
  }

  // Packages `f` with the current thread Context and, when tracing is
  // enabled, a unique id recorded as a kScheduleClosure event so the later
  // run event can be correlated with this schedule.
  Task CreateTask(std::function<void()> f) {
    uint64 id = 0;
    if (tracing::EventCollector::IsEnabled()) {
      id = tracing::GetUniqueArg();
      tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
    }
    return Task{
        std::unique_ptr<TaskImpl>(new TaskImpl{
            std::move(f),
            Context(ContextKind::kThread),
            id,
        }),
    };
  }

  // Runs the task on the current thread: restores the Context captured at
  // creation time for the duration of the closure and records a kRunClosure
  // tracing region tagged with the task's trace id.
  void ExecuteTask(const Task& t) {
    WithContext wc(t.f->context);
    tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
                                 t.f->trace_id);
    t.f->f();
  }
};
93 | |
// Convenience constructor: default thread options, low-latency hint enabled,
// default (null) allocator.
ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
    : ThreadPool(env, ThreadOptions(), name, num_threads, true, nullptr) {}
96 | |
// Convenience constructor: caller-supplied thread options, low-latency hint
// enabled, default (null) allocator.
ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
                       const string& name, int num_threads)
    : ThreadPool(env, thread_options, name, num_threads, true, nullptr) {}
100 | |
101 | ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options, |
102 | const string& name, int num_threads, |
103 | bool low_latency_hint, Eigen::Allocator* allocator) { |
104 | CHECK_GE(num_threads, 1); |
105 | eigen_threadpool_.reset(new Eigen::ThreadPoolTempl<EigenEnvironment>( |
106 | num_threads, low_latency_hint, |
107 | EigenEnvironment(env, thread_options, "tf_" + name))); |
108 | underlying_threadpool_ = eigen_threadpool_.get(); |
109 | threadpool_device_.reset(new Eigen::ThreadPoolDevice(underlying_threadpool_, |
110 | num_threads, allocator)); |
111 | } |
112 | |
113 | ThreadPool::ThreadPool(thread::ThreadPoolInterface* user_threadpool) { |
114 | underlying_threadpool_ = user_threadpool; |
115 | threadpool_device_.reset(new Eigen::ThreadPoolDevice( |
116 | underlying_threadpool_, underlying_threadpool_->NumThreads(), nullptr)); |
117 | } |
118 | |
119 | ThreadPool::~ThreadPool() {} |
120 | |
// Enqueues `fn` on the underlying pool. Scheduling an empty closure is a
// programming error and is caught eagerly here rather than at run time.
void ThreadPool::Schedule(std::function<void()> fn) {
  CHECK(fn != nullptr);
  underlying_threadpool_->Schedule(std::move(fn));
}
125 | |
126 | int ThreadPool::NumShardsUsedByFixedBlockSizeScheduling( |
127 | const int64_t total, const int64_t block_size) { |
128 | if (block_size <= 0 || total <= 1 || total <= block_size || |
129 | NumThreads() == 1) { |
130 | return 1; |
131 | } |
132 | return (total + block_size - 1) / block_size; |
133 | } |
134 | |
// Legacy alias for NumShardsUsedByFixedBlockSizeScheduling; note the
// (block_size, total) argument order is swapped relative to the newer API.
int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
    const int64_t block_size, const int64_t total) {
  return NumShardsUsedByFixedBlockSizeScheduling(total, block_size);
}
139 | |
140 | void ThreadPool::ParallelFor(int64_t total, |
141 | const SchedulingParams& scheduling_params, |
142 | const std::function<void(int64_t, int64_t)>& fn) { |
143 | switch (scheduling_params.strategy()) { |
144 | case SchedulingStrategy::kAdaptive: { |
145 | if (scheduling_params.cost_per_unit().has_value()) { |
146 | ParallelFor(total, *scheduling_params.cost_per_unit(), fn); |
147 | } |
148 | break; |
149 | } |
150 | case SchedulingStrategy::kFixedBlockSize: { |
151 | if (scheduling_params.block_size().has_value()) { |
152 | ParallelForFixedBlockSizeScheduling( |
153 | total, *scheduling_params.block_size(), fn); |
154 | } |
155 | break; |
156 | } |
157 | } |
158 | } |
159 | |
// Legacy entry point: forwards to ParallelFor with fixed-block-size
// scheduling (no cost estimate). Kept for callers of the older API.
void ThreadPool::TransformRangeConcurrently(
    const int64_t block_size, const int64_t total,
    const std::function<void(int64_t, int64_t)>& fn) {
  ParallelFor(total,
              SchedulingParams(SchedulingStrategy::kFixedBlockSize,
                               absl::nullopt /* cost_per_unit */, block_size),
              fn);
}
168 | |
// This functionality is similar to parallelFor, except that reasoning about
// the number of shards used is significantly easier.
void ThreadPool::ParallelForFixedBlockSizeScheduling(
    const int64_t total, const int64_t block_size,
    const std::function<void(int64_t, int64_t)>& fn) {
  const int num_shards_used =
      NumShardsUsedByFixedBlockSizeScheduling(total, block_size);
  if (num_shards_used == 1) {
    // Degenerate case: run the entire range inline on the calling thread.
    fn(0, total);
    return;
  }

  // Adapted from Eigen's parallelFor implementation.
  BlockingCounter counter(num_shards_used);
  // Recursively bisects [first, last) at a block-size-aligned midpoint,
  // scheduling the upper half and iterating on the lower half, until a span
  // of at most one block remains, which is executed directly. The lambda
  // captures itself by reference so scheduled copies can recurse; the
  // captured references (handle_range, counter, fn) stay valid because
  // counter.Wait() below keeps this stack frame alive until every shard has
  // decremented the counter.
  std::function<void(int64_t, int64_t)> handle_range =
      [=, &handle_range, &counter, &fn](int64_t first, int64_t last) {
        while (last - first > block_size) {
          // Find something near the midpoint which is a multiple of block size.
          const int64_t mid = first + ((last - first) / 2 + block_size - 1) /
                                  block_size * block_size;
          Schedule([=, &handle_range]() { handle_range(mid, last); });
          last = mid;
        }
        // Single block or less, execute directly.
        fn(first, last);
        counter.DecrementCount();  // The shard is done.
      };
  if (num_shards_used <= NumThreads()) {
    // Avoid a thread hop by running the root of the tree and one block on the
    // main thread.
    handle_range(0, total);
  } else {
    // Execute the root in the thread pool to avoid running work on more than
    // numThreads() threads.
    Schedule([=, &handle_range]() { handle_range(0, total); });
  }
  // Block until all shards have completed.
  counter.Wait();
}
207 | |
208 | void ThreadPool::ParallelFor(int64_t total, int64_t cost_per_unit, |
209 | const std::function<void(int64_t, int64_t)>& fn) { |
210 | CHECK_GE(total, 0); |
211 | CHECK_EQ(total, (int64_t)(Eigen::Index)total); |
212 | threadpool_device_->parallelFor( |
213 | total, Eigen::TensorOpCost(0, 0, cost_per_unit), |
214 | [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); }); |
215 | } |
216 | |
217 | void ThreadPool::ParallelForWithWorkerId( |
218 | int64_t total, int64_t cost_per_unit, |
219 | const std::function<void(int64_t, int64_t, int)>& fn) { |
220 | CHECK_GE(total, 0); |
221 | CHECK_EQ(total, (int64_t)(Eigen::Index)total); |
222 | |
223 | threadpool_device_->parallelFor(total, |
224 | Eigen::TensorOpCost(0, 0, cost_per_unit), |
225 | [this, &fn](int64_t start, int64_t limit) { |
226 | // ParallelFor may use the current thread to |
227 | // do some work synchronously. When calling |
228 | // CurrentThreadId() from outside of the |
229 | // thread pool, we get -1, so we can shift |
230 | // every id up by 1. |
231 | int id = CurrentThreadId() + 1; |
232 | fn(start, limit, id); |
233 | }); |
234 | } |
235 | |
236 | void ThreadPool::ParallelForWithWorkerId( |
237 | int64_t total, const SchedulingParams& scheduling_params, |
238 | const std::function<void(int64_t, int64_t, int)>& fn) { |
239 | ParallelFor(total, scheduling_params, |
240 | [this, &fn](int64_t start, int64_t limit) { |
241 | // We may use the current thread to do some work synchronously. |
242 | // When calling CurrentThreadId() from outside of the thread |
243 | // pool, we get -1, so we can shift every id up by 1. |
244 | int id = CurrentThreadId() + 1; |
245 | fn(start, limit, id); |
246 | }); |
247 | } |
248 | |
// Number of worker threads in the underlying pool.
int ThreadPool::NumThreads() const {
  return underlying_threadpool_->NumThreads();
}
252 | |
// Id of the calling thread within the pool; the underlying pool reports -1
// for threads that do not belong to it (see the worker-id shifting in
// ParallelForWithWorkerId).
int ThreadPool::CurrentThreadId() const {
  return underlying_threadpool_->CurrentThreadId();
}
256 | |
// Enqueues `fn` with a placement hint, forwarded verbatim to the underlying
// pool; [start, limit) is presumably a preferred worker-id range — semantics
// are defined by the underlying pool implementation.
void ThreadPool::ScheduleWithHint(std::function<void()> fn, int start,
                                  int limit) {
  underlying_threadpool_->ScheduleWithHint(std::move(fn), start, limit);
}
261 | |
// Forwards work-stealing partition configuration to the owned Eigen pool.
void ThreadPool::SetStealPartitions(
    const std::vector<std::pair<unsigned, unsigned>>& partitions) {
  // ThreadPool::SetStealPartitions is only called in the constructor of
  // RunHandlerPool::Impl, which currently instantiates ThreadPool using a
  // constructor that does not take user_threadpool. Thus we assume
  // eigen_threadpool_ is not null here.
  DCHECK(eigen_threadpool_ != nullptr);
  eigen_threadpool_->SetStealPartitions(partitions);
}
271 | |
// Exposes the underlying pool through Eigen's interface. The returned
// pointer is non-owning and valid only while this ThreadPool is alive.
Eigen::ThreadPoolInterface* ThreadPool::AsEigenThreadPool() const {
  DCHECK(underlying_threadpool_ != nullptr);
  return underlying_threadpool_;
}
276 | } // namespace thread |
277 | } // namespace tsl |
278 | |