/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/tsl/platform/threadpool.h"

#define EIGEN_USE_THREADS

#include "absl/types/optional.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/tsl/platform/blocking_counter.h"
#include "tensorflow/tsl/platform/denormal.h"
#include "tensorflow/tsl/platform/logging.h"
#include "tensorflow/tsl/platform/mutex.h"
#include "tensorflow/tsl/platform/numa.h"
#include "tensorflow/tsl/platform/setround.h"
#include "tensorflow/tsl/platform/tracing.h"

namespace tsl {
// TODO(aminim): remove after tensorflow/core/platform/context.h migration.
using tensorflow::Context;
using tensorflow::ContextKind;
using tensorflow::WithContext;

namespace thread {

struct EigenEnvironment {
  typedef Thread EnvThread;
  struct TaskImpl {
    std::function<void()> f;
    Context context;
    uint64 trace_id;
  };
  struct Task {
    std::unique_ptr<TaskImpl> f;
  };

  Env* const env_;
  const ThreadOptions thread_options_;
  const string name_;

  EigenEnvironment(Env* env, const ThreadOptions& thread_options,
                   const string& name)
      : env_(env), thread_options_(thread_options), name_(name) {}

  EnvThread* CreateThread(std::function<void()> f) {
    return env_->StartThread(thread_options_, name_, [=]() {
      // Set the processor flag to flush denormals to zero.
      port::ScopedFlushDenormal flush;
      // Set the processor rounding mode to ROUND TO NEAREST.
      tsl::port::ScopedSetRound round(FE_TONEAREST);
      if (thread_options_.numa_node != port::kNUMANoAffinity) {
        port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
      }
      f();
    });
  }

  Task CreateTask(std::function<void()> f) {
    uint64 id = 0;
    if (tracing::EventCollector::IsEnabled()) {
      id = tracing::GetUniqueArg();
      tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
    }
    return Task{
        std::unique_ptr<TaskImpl>(new TaskImpl{
            std::move(f),
            Context(ContextKind::kThread),
            id,
        }),
    };
  }

  void ExecuteTask(const Task& t) {
    WithContext wc(t.f->context);
    tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
                                 t.f->trace_id);
    t.f->f();
  }
};
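
// Eigen's ThreadPoolTempl drives this environment through three hooks:
// CreateThread() spawns each worker, CreateTask() wraps a scheduled closure
// (capturing the caller's Context and a trace id), and ExecuteTask() runs the
// wrapped closure on a worker. A minimal wiring sketch (the thread count and
// pool name below are illustrative):
//
//   Eigen::ThreadPoolTempl<EigenEnvironment> pool(
//       /*num_threads=*/4, /*allow_spinning=*/true,
//       EigenEnvironment(Env::Default(), ThreadOptions(), "tf_example"));
//   pool.Schedule([] { /* work */ });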

ThreadPool::ThreadPool(Env* env, const string& name, int num_threads)
    : ThreadPool(env, ThreadOptions(), name, num_threads, true, nullptr) {}

ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
                       const string& name, int num_threads)
    : ThreadPool(env, thread_options, name, num_threads, true, nullptr) {}

ThreadPool::ThreadPool(Env* env, const ThreadOptions& thread_options,
                       const string& name, int num_threads,
                       bool low_latency_hint, Eigen::Allocator* allocator) {
  CHECK_GE(num_threads, 1);
  eigen_threadpool_.reset(new Eigen::ThreadPoolTempl<EigenEnvironment>(
      num_threads, low_latency_hint,
      EigenEnvironment(env, thread_options, "tf_" + name)));
  underlying_threadpool_ = eigen_threadpool_.get();
  threadpool_device_.reset(new Eigen::ThreadPoolDevice(underlying_threadpool_,
                                                       num_threads, allocator));
}

ThreadPool::ThreadPool(thread::ThreadPoolInterface* user_threadpool) {
  underlying_threadpool_ = user_threadpool;
  threadpool_device_.reset(new Eigen::ThreadPoolDevice(
      underlying_threadpool_, underlying_threadpool_->NumThreads(), nullptr));
}

ThreadPool::~ThreadPool() {}

void ThreadPool::Schedule(std::function<void()> fn) {
  CHECK(fn != nullptr);
  underlying_threadpool_->Schedule(std::move(fn));
}
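
// A minimal usage sketch (assuming a default Env). Schedule() enqueues the
// closure on the underlying Eigen pool and returns immediately:
//
//   tsl::thread::ThreadPool pool(Env::Default(), "example", /*num_threads=*/4);
//   pool.Schedule([] { LOG(INFO) << "running on a pool thread"; });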

int ThreadPool::NumShardsUsedByFixedBlockSizeScheduling(
    const int64_t total, const int64_t block_size) {
  if (block_size <= 0 || total <= 1 || total <= block_size ||
      NumThreads() == 1) {
    return 1;
  }
  return (total + block_size - 1) / block_size;
}
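
// Example: with total = 10 and block_size = 4 (and more than one thread in
// the pool), this returns (10 + 4 - 1) / 4 = 3 shards, covering the ranges
// [0, 4), [4, 8), and [8, 10).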

int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
    const int64_t block_size, const int64_t total) {
  return NumShardsUsedByFixedBlockSizeScheduling(total, block_size);
}

void ThreadPool::ParallelFor(int64_t total,
                             const SchedulingParams& scheduling_params,
                             const std::function<void(int64_t, int64_t)>& fn) {
  switch (scheduling_params.strategy()) {
    case SchedulingStrategy::kAdaptive: {
      if (scheduling_params.cost_per_unit().has_value()) {
        ParallelFor(total, *scheduling_params.cost_per_unit(), fn);
      }
      break;
    }
    case SchedulingStrategy::kFixedBlockSize: {
      if (scheduling_params.block_size().has_value()) {
        ParallelForFixedBlockSizeScheduling(
            total, *scheduling_params.block_size(), fn);
      }
      break;
    }
  }
}
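
// Illustrative calls for the two scheduling strategies (the cost and block
// size values are arbitrary placeholders):
//
//   pool.ParallelFor(total,
//                    SchedulingParams(SchedulingStrategy::kAdaptive,
//                                     /*cost_per_unit=*/1000,
//                                     /*block_size=*/absl::nullopt),
//                    fn);
//   pool.ParallelFor(total,
//                    SchedulingParams(SchedulingStrategy::kFixedBlockSize,
//                                     /*cost_per_unit=*/absl::nullopt,
//                                     /*block_size=*/128),
//                    fn);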

void ThreadPool::TransformRangeConcurrently(
    const int64_t block_size, const int64_t total,
    const std::function<void(int64_t, int64_t)>& fn) {
  ParallelFor(total,
              SchedulingParams(SchedulingStrategy::kFixedBlockSize,
                               absl::nullopt /* cost_per_unit */, block_size),
              fn);
}

// This functionality is similar to parallelFor, except that reasoning about
// the number of shards used is significantly easier.
void ThreadPool::ParallelForFixedBlockSizeScheduling(
    const int64_t total, const int64_t block_size,
    const std::function<void(int64_t, int64_t)>& fn) {
  const int num_shards_used =
      NumShardsUsedByFixedBlockSizeScheduling(total, block_size);
  if (num_shards_used == 1) {
    fn(0, total);
    return;
  }

  // Adapted from Eigen's parallelFor implementation.
  BlockingCounter counter(num_shards_used);
  std::function<void(int64_t, int64_t)> handle_range =
      [=, &handle_range, &counter, &fn](int64_t first, int64_t last) {
        while (last - first > block_size) {
          // Find something near the midpoint which is a multiple of the block
          // size.
          const int64_t mid = first + ((last - first) / 2 + block_size - 1) /
                                          block_size * block_size;
          Schedule([=, &handle_range]() { handle_range(mid, last); });
          last = mid;
        }
        // Single block or less, execute directly.
        fn(first, last);
        counter.DecrementCount();  // The shard is done.
      };
  if (num_shards_used <= NumThreads()) {
    // Avoid a thread hop by running the root of the tree and one block on the
    // main thread.
    handle_range(0, total);
  } else {
    // Execute the root in the thread pool to avoid running work on more than
    // numThreads() threads.
    Schedule([=, &handle_range]() { handle_range(0, total); });
  }
  counter.Wait();
}
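
// For example, with total = 16 and block_size = 4, handle_range(0, 16) first
// offloads handle_range(8, 16) to the pool, then handle_range(4, 8), and
// finally runs fn(0, 4) itself; each offloaded range keeps splitting the same
// way until only a single block (or less) remains, and counter.Wait() returns
// once all num_shards_used blocks have run.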

void ThreadPool::ParallelFor(int64_t total, int64_t cost_per_unit,
                             const std::function<void(int64_t, int64_t)>& fn) {
  CHECK_GE(total, 0);
  CHECK_EQ(total, (int64_t)(Eigen::Index)total);
  threadpool_device_->parallelFor(
      total, Eigen::TensorOpCost(0, 0, cost_per_unit),
      [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
}
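
// cost_per_unit is an estimate of the cost of a single unit of work; Eigen's
// parallelFor uses the resulting TensorOpCost to decide how finely to shard
// the range. A sketch of a call (the cost value is an assumed estimate):
//
//   pool.ParallelFor(total, /*cost_per_unit=*/500,
//                    [&](int64_t first, int64_t last) {
//                      for (int64_t i = first; i < last; ++i) {
//                        // per-element work
//                      }
//                    });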

void ThreadPool::ParallelForWithWorkerId(
    int64_t total, int64_t cost_per_unit,
    const std::function<void(int64_t, int64_t, int)>& fn) {
  CHECK_GE(total, 0);
  CHECK_EQ(total, (int64_t)(Eigen::Index)total);

  threadpool_device_->parallelFor(total,
                                  Eigen::TensorOpCost(0, 0, cost_per_unit),
                                  [this, &fn](int64_t start, int64_t limit) {
                                    // ParallelFor may use the current thread
                                    // to do some work synchronously. When
                                    // calling CurrentThreadId() from outside
                                    // of the thread pool, we get -1, so we
                                    // can shift every id up by 1.
                                    int id = CurrentThreadId() + 1;
                                    fn(start, limit, id);
                                  });
}
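
// The id passed to fn is therefore in [0, NumThreads()]: 0 means the block ran
// on a thread that is not part of the pool (e.g. the caller's thread), and
// 1..NumThreads() identify pool worker threads.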

void ThreadPool::ParallelForWithWorkerId(
    int64_t total, const SchedulingParams& scheduling_params,
    const std::function<void(int64_t, int64_t, int)>& fn) {
  ParallelFor(total, scheduling_params,
              [this, &fn](int64_t start, int64_t limit) {
                // We may use the current thread to do some work synchronously.
                // When calling CurrentThreadId() from outside of the thread
                // pool, we get -1, so we can shift every id up by 1.
                int id = CurrentThreadId() + 1;
                fn(start, limit, id);
              });
}

int ThreadPool::NumThreads() const {
  return underlying_threadpool_->NumThreads();
}

int ThreadPool::CurrentThreadId() const {
  return underlying_threadpool_->CurrentThreadId();
}

void ThreadPool::ScheduleWithHint(std::function<void()> fn, int start,
                                  int limit) {
  underlying_threadpool_->ScheduleWithHint(std::move(fn), start, limit);
}
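
// The [start, limit) pair is a placement hint: Eigen's ScheduleWithHint tries
// to enqueue the closure on one of the worker queues in that range rather than
// on an arbitrary worker.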

void ThreadPool::SetStealPartitions(
    const std::vector<std::pair<unsigned, unsigned>>& partitions) {
  // ThreadPool::SetStealPartitions is only called in the constructor of
  // RunHandlerPool::Impl, which currently instantiates ThreadPool using a
  // constructor that does not take user_threadpool. Thus we assume
  // eigen_threadpool_ is not null here.
  DCHECK(eigen_threadpool_ != nullptr);
  eigen_threadpool_->SetStealPartitions(partitions);
}

Eigen::ThreadPoolInterface* ThreadPool::AsEigenThreadPool() const {
  DCHECK(underlying_threadpool_ != nullptr);
  return underlying_threadpool_;
}
}  // namespace thread
}  // namespace tsl