#include <ATen/Config.h>
#if AT_PARALLEL_NATIVE
#include <ATen/Parallel.h>
#include <ATen/ParallelFuture.h>
#include <ATen/PTThreadPool.h>

#ifndef C10_MOBILE
#include <c10/core/thread_pool.h>
#include <c10/util/irange.h>
#else
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#endif // C10_MOBILE

#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <exception>
#include <functional>
#include <mutex>
#include <tuple>
#include <utility>

#ifdef _OPENMP
#include <omp.h>
#endif

#if AT_MKL_ENABLED()
#include <mkl.h>
#endif

namespace at {
namespace {
// Used with _set_in_parallel_region to mark the master thread as being
// inside a parallel region while parallel primitives are executing.
thread_local bool in_parallel_region_ = false;

// Thread number (task_id) set by the parallel primitive.
thread_local int thread_num_ = 0;

void _set_in_parallel_region(bool in_region) {
  in_parallel_region_ = in_region;
}

} // namespace (anonymous)

namespace internal {
void set_thread_num(int thread_num) {
  thread_num_ = thread_num;
}
} // namespace internal

namespace {
void _unset_thread_num() {
  thread_num_ = 0;
}

#ifndef C10_MOBILE

const int NOT_SET = -1;
const int CONSUMED = -2;

// Number of threads set by the user
// NOT_SET -> positive value -> CONSUMED
// or
// NOT_SET -> CONSUMED
// Meaning:
//  - NOT_SET - pool not initialized, user value is not set
//  - positive value - pool not initialized, user value set
//  - CONSUMED - pool is initialized
std::atomic<int> num_intraop_threads{NOT_SET};

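// Maps the requested intra-op thread count to a pool size. The result is one
// less than the requested count because the calling ("master") thread also
// executes one share of the work (see _run_with_pool below).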
int _num_pool_threads(int nthreads) {
  if (nthreads == NOT_SET) {
    nthreads = intraop_default_num_threads();
  } else {
    TORCH_INTERNAL_ASSERT(nthreads > 0);
  }
  // minus one because of the master thread
  return nthreads - 1;
}

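// Lazily creates the intra-op pool on first use; the function-local static
// makes initialization thread-safe. exchange(CONSUMED) atomically claims any
// user-requested size and marks it consumed, so a later set_num_threads()
// call can detect that the pool can no longer be resized.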
TaskThreadPoolBase& _get_intraop_pool() {
  static std::shared_ptr<TaskThreadPoolBase> pool =
      ThreadPoolRegistry()->Create(
          "C10",
          /* device_id */ 0,
          /* pool_size */ _num_pool_threads(num_intraop_threads.exchange(CONSUMED)),
          /* create_new */ true); // create a separate thread pool for intra-op
  return *pool;
}

#endif // C10_MOBILE

// Run lambda function `fn` over `task_id` in [0, `range`) with threadpool.
// `fn` will be called with params: (thread_pool_task_id, task_id).
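// On the non-mobile build, task 0 runs inline on the calling thread while
// tasks 1..range-1 are dispatched to the intra-op pool; on mobile, every
// task runs through caffe2::PThreadPool and the first parameter is unused.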
void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
#ifndef C10_MOBILE
  for (const auto i : c10::irange(1, range)) {
    _get_intraop_pool().run([fn, i]() { fn((int)i, i); });
  }
  // Run the first task on the current thread directly.
  fn(0, 0);
#else
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");

  pool->run(
      // PThreadPool::run() is blocking. A std::function [const] reference to
      // this lambda cannot go out of scope before PThreadPool::run() returns.
      [&fn](const size_t task_id) {
        fn(0 /* unused */, task_id);
      }, range);
#endif // C10_MOBILE
}

// RAII guard that backs the in_parallel_region() and get_thread_num() APIs.
struct ParallelRegionGuard {
  explicit ParallelRegionGuard(int task_id) {
    internal::set_thread_num(task_id);
    _set_in_parallel_region(true);
  }

  ~ParallelRegionGuard() {
    _set_in_parallel_region(false);
    _unset_thread_num();
  }
};

} // namespace

namespace internal {

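// Splits [begin, end) into tasks of at least grain_size elements each, using
// at most get_num_threads() tasks. Worked example: begin = 0, end = 100,
// grain_size = 10 with 4 threads gives chunk_size = max(10, divup(100, 4)) = 25
// and num_tasks = divup(100, 25) = 4.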
inline std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
    int64_t begin, int64_t end, int64_t grain_size) {
  if ((end - begin) < grain_size) {
    return std::make_tuple(1, std::max((int64_t)0, end - begin));
  }
  // Choose number of tasks based on grain size and number of threads.
  size_t chunk_size = divup((end - begin), get_num_threads());
  // Make sure each task is at least grain_size size.
  chunk_size = std::max((size_t)grain_size, chunk_size);
  size_t num_tasks = divup((end - begin), chunk_size);
  return std::make_tuple(num_tasks, chunk_size);
}

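// Runs f over [begin, end), split into parallel tasks, and blocks until all
// tasks finish; the first exception raised by any task is rethrown on the
// calling thread. This is roughly what at::parallel_for reduces to once it
// decides the range is worth splitting, e.g. (out/op/in are placeholders):
//   internal::invoke_parallel(0, n, grain_size,
//       [&](int64_t b, int64_t e) {
//         for (int64_t i = b; i < e; ++i) out[i] = op(in[i]);
//       });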
void invoke_parallel(
    const int64_t begin,
    const int64_t end,
    const int64_t grain_size,
    const std::function<void(int64_t, int64_t)>& f) {
  at::internal::lazy_init_num_threads();

  size_t num_tasks = 0, chunk_size = 0;
  std::tie(num_tasks, chunk_size) =
      internal::calc_num_tasks_and_chunk_size(begin, end, grain_size);

  struct {
    std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
    std::exception_ptr eptr;
    std::mutex mutex;
    size_t remaining = 0; // guarded by mutex
    std::condition_variable cv;
  } state;

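  // Each task runs f on its chunk (if non-empty) under a ParallelRegionGuard,
  // records only the first exception via the atomic flag, and decrements the
  // shared counter; the last task to finish notifies the condition variable.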
  auto task = [f, &state, begin, end, chunk_size]
      (int /* unused */, size_t task_id) {
    int64_t local_start = begin + task_id * chunk_size;
    if (local_start < end) {
      int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
      try {
        ParallelRegionGuard guard(static_cast<int>(task_id));
        f(local_start, local_end);
      } catch (...) {
        if (!state.err_flag.test_and_set()) {
          state.eptr = std::current_exception();
        }
      }
    }
    {
      std::unique_lock<std::mutex> lk(state.mutex);
      if (--state.remaining == 0) {
        state.cv.notify_one();
      }
    }
  };
  state.remaining = num_tasks;
  _run_with_pool(std::move(task), num_tasks);

  // Wait for all tasks to finish; the predicate guards against spurious wakeups.
  {
    std::unique_lock<std::mutex> lk(state.mutex);
    state.cv.wait(lk, [&state]() { return state.remaining == 0; });
  }
  if (state.eptr) {
    std::rethrow_exception(state.eptr);
  }
}

} // namespace internal

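// With the native backend, intra-op parallelism comes from our own thread
// pool, so OpenMP and MKL are pinned to a single thread here to avoid
// oversubscribing cores with their internal threading.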
void init_num_threads() {
#ifdef _OPENMP
  omp_set_num_threads(1);
#endif

#if AT_MKL_ENABLED()
  mkl_set_num_threads(1);
#endif

#ifdef C10_MOBILE
  caffe2::pthreadpool();
#endif
}

void set_num_threads(int nthreads) {
#ifndef C10_MOBILE
  TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
  int no_value = NOT_SET;
  if (!num_intraop_threads.compare_exchange_strong(no_value, nthreads)) {
    // num_intraop_threads either stores a positive integer or CONSUMED;
    // check that the requested size matches the current one.
    int stored_nthreads = num_intraop_threads.load();
    if (stored_nthreads <= 0) {
      // plus one for the master thread
      stored_nthreads = static_cast<int>(_get_intraop_pool().size()) + 1;
    }
    if (stored_nthreads != nthreads) {
      TORCH_WARN(
          "Cannot set number of intraop threads "
          "after parallel work has started or after set_num_threads call "
          "when using native parallel backend");
    }
  }
#else
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
  pool->set_thread_count(nthreads);
#endif // C10_MOBILE
}

int get_num_threads() {
  at::internal::lazy_init_num_threads();
#ifndef C10_MOBILE
  // Avoid initializing the pool unnecessarily:
  // it cannot be resized after initialization.
  int nthreads = num_intraop_threads.load();
  if (nthreads > 0) {
    return nthreads;
  } else if (nthreads == NOT_SET) {
    return intraop_default_num_threads();
  } else {
    TORCH_INTERNAL_ASSERT(nthreads == CONSUMED);
    return static_cast<int>(_get_intraop_pool().size()) + 1;
  }
#else
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
  return in_parallel_region() ? 1 /* current thread */ : pool->get_thread_count();
#endif // C10_MOBILE
}

int get_thread_num() {
  return thread_num_;
}

bool in_parallel_region() {
#ifndef C10_MOBILE
  return in_parallel_region_ || (
      num_intraop_threads.load() == CONSUMED &&
      // Needed as intraop_launch() doesn't set in_parallel_region().
      _get_intraop_pool().inThreadPool()
  );
#else
  return in_parallel_region_;
#endif // C10_MOBILE
}

void intraop_launch(std::function<void()> func) {
#ifndef C10_MOBILE
  if (!in_parallel_region() && get_num_threads() > 1) {
    _get_intraop_pool().run(std::move(func));
  } else {
    // execute inline if we're in parallel region
    func();
  }
#else
  // TODO: caffe2::PThreadPool only provides a data-parallel API.
  // Task parallelism is not currently supported.
  func();
#endif // C10_MOBILE
}

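// Like intraop_launch(), but returns a Future that completes once func has
// run. Illustrative use (do_work is a placeholder):
//   auto fut = at::intraop_launch_future([]() { do_work(); });
//   fut->wait();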
c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
    std::function<void()> func) {
#ifndef C10_MOBILE
  auto future = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
  if (!in_parallel_region() && get_num_threads() > 1) {
    _get_intraop_pool().run(
        [func, future]() {
          func();
          future->markCompleted();
        }
    );
  } else {
    func();
    future->markCompleted();
  }
  return future;
#else
  // TODO: caffe2::PThreadPool only provides a data-parallel API.
  // Task parallelism is not currently supported.
  auto future = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
  func();
  future->markCompleted();
  return future;
#endif // C10_MOBILE
}

} // namespace at
#endif // AT_PARALLEL_NATIVE