1 | #include <ATen/Config.h> |
2 | #if AT_PARALLEL_NATIVE |
3 | #include <ATen/Parallel.h> |
4 | #include <ATen/ParallelFuture.h> |
5 | #include <ATen/PTThreadPool.h> |
6 | |
7 | #ifndef C10_MOBILE |
8 | #include <c10/core/thread_pool.h> |
9 | #include <c10/util/irange.h> |
10 | #else |
11 | #include <caffe2/utils/threadpool/pthreadpool-cpp.h> |
12 | #endif // C10_MOBILE |
13 | |
14 | #include <atomic> |
15 | #include <utility> |
16 | |
17 | #ifdef _OPENMP |
18 | #include <omp.h> |
19 | #endif |
20 | |
21 | #if AT_MKL_ENABLED() |
22 | #include <mkl.h> |
23 | #endif |
24 | |
25 | namespace at { |
namespace {
// Marks the current thread as executing inside a parallel primitive.
// Written only through _set_in_parallel_region().
thread_local bool in_parallel_region_{false};

// Task index (task_id) assigned to the current thread by a parallel
// primitive; reset to 0 when the primitive's task finishes.
thread_local int thread_num_{0};

// Flag (or unflag) the calling thread as being inside a parallel region.
void _set_in_parallel_region(bool entering) {
  in_parallel_region_ = entering;
}

} // namespace (anonymous)
39 | |
40 | namespace internal { |
41 | void set_thread_num(int thread_num) { |
42 | thread_num_ = thread_num; |
43 | } |
44 | } |
45 | |
46 | namespace { |
47 | void _unset_thread_num() { |
48 | thread_num_ = 0; |
49 | } |
50 | |
51 | #ifndef C10_MOBILE |
52 | |
// Sentinel states for num_intraop_threads (any positive value is a user request).
constexpr int NOT_SET = -1;
constexpr int CONSUMED = -2;

// Intra-op thread count requested by the user. Allowed transitions:
//   NOT_SET -> positive value -> CONSUMED
//   NOT_SET -> CONSUMED
// where
//   NOT_SET        : pool not created yet, no user value recorded
//   positive value : pool not created yet, user value recorded
//   CONSUMED       : pool has been created (its size is now fixed)
std::atomic<int> num_intraop_threads{NOT_SET};
65 | |
66 | int _num_pool_threads(int nthreads) { |
67 | if (nthreads == NOT_SET) { |
68 | nthreads = intraop_default_num_threads(); |
69 | } else { |
70 | TORCH_INTERNAL_ASSERT(nthreads > 0); |
71 | } |
72 | // minus one because of the master thread |
73 | return nthreads - 1; |
74 | } |
75 | |
76 | TaskThreadPoolBase& _get_intraop_pool() { |
77 | static std::shared_ptr<TaskThreadPoolBase> pool = |
78 | ThreadPoolRegistry()->Create( |
79 | "C10" , |
80 | /* device_id */ 0, |
81 | /* pool_size */ _num_pool_threads(num_intraop_threads.exchange(CONSUMED)), |
82 | /* create_new */ true); // create a separate thread pool for intra-op |
83 | return *pool; |
84 | } |
85 | |
86 | #endif // C10_MOBILE |
87 | |
88 | // Run lambda function `fn` over `task_id` in [0, `range`) with threadpool. |
89 | // `fn` will be called with params: (thread_pool_task_id, task_id). |
90 | void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) { |
91 | #ifndef C10_MOBILE |
92 | for (const auto i : c10::irange(1, range)) { |
93 | _get_intraop_pool().run([fn, i]() { fn((int)i, i); }); |
94 | } |
95 | // Run the first task on the current thread directly. |
96 | fn(0, 0); |
97 | #else |
98 | caffe2::PThreadPool* const pool = caffe2::pthreadpool(); |
99 | TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!" ); |
100 | |
101 | pool->run( |
102 | // PThreadPool::run() is blocking. A std::function [const] reference to |
103 | // this lambda cannot go out of scope before PThreadPool::run() returns. |
104 | [&fn](const size_t task_id) { |
105 | fn(0 /* unused */, task_id); |
106 | }, range); |
107 | #endif // C10_MOBILE |
108 | } |
109 | |
110 | // RAII guard helps to support in_parallel_region() and get_thread_num() API. |
111 | struct ParallelRegionGuard { |
112 | ParallelRegionGuard(int task_id) { |
113 | internal::set_thread_num(task_id); |
114 | _set_in_parallel_region(true); |
115 | } |
116 | |
117 | ~ParallelRegionGuard() { |
118 | _set_in_parallel_region(false); |
119 | _unset_thread_num(); |
120 | } |
121 | }; |
122 | |
123 | } // namespace |
124 | |
125 | namespace internal { |
126 | |
127 | inline std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size( |
128 | int64_t begin, int64_t end, int64_t grain_size) { |
129 | if ((end - begin) < grain_size) { |
130 | return std::make_tuple(1, std::max((int64_t)0, end - begin)); |
131 | } |
132 | // Choose number of tasks based on grain size and number of threads. |
133 | size_t chunk_size = divup((end - begin), get_num_threads()); |
134 | // Make sure each task is at least grain_size size. |
135 | chunk_size = std::max((size_t)grain_size, chunk_size); |
136 | size_t num_tasks = divup((end - begin), chunk_size); |
137 | return std::make_tuple(num_tasks, chunk_size); |
138 | } |
139 | |
140 | void invoke_parallel( |
141 | const int64_t begin, |
142 | const int64_t end, |
143 | const int64_t grain_size, |
144 | const std::function<void(int64_t, int64_t)>& f) { |
145 | at::internal::lazy_init_num_threads(); |
146 | |
147 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
148 | size_t num_tasks, chunk_size; |
149 | std::tie(num_tasks, chunk_size) = |
150 | internal::calc_num_tasks_and_chunk_size(begin, end, grain_size); |
151 | |
152 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
153 | struct { |
154 | std::atomic_flag err_flag = ATOMIC_FLAG_INIT; |
155 | std::exception_ptr eptr; |
156 | std::mutex mutex; |
157 | volatile size_t remaining; |
158 | std::condition_variable cv; |
159 | } state; |
160 | |
161 | auto task = [f, &state, begin, end, chunk_size] |
162 | (int /* unused */, size_t task_id) { |
163 | int64_t local_start = begin + task_id * chunk_size; |
164 | if (local_start < end) { |
165 | int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start)); |
166 | try { |
167 | ParallelRegionGuard guard(task_id); |
168 | f(local_start, local_end); |
169 | } catch (...) { |
170 | if (!state.err_flag.test_and_set()) { |
171 | state.eptr = std::current_exception(); |
172 | } |
173 | } |
174 | } |
175 | { |
176 | std::unique_lock<std::mutex> lk(state.mutex); |
177 | if (--state.remaining == 0) { |
178 | state.cv.notify_one(); |
179 | } |
180 | } |
181 | }; |
182 | state.remaining = num_tasks; |
183 | _run_with_pool(std::move(task), num_tasks); |
184 | |
185 | // Wait for all tasks to finish. |
186 | { |
187 | std::unique_lock<std::mutex> lk(state.mutex); |
188 | if (state.remaining != 0) { |
189 | state.cv.wait(lk); |
190 | } |
191 | } |
192 | if (state.eptr) { |
193 | std::rethrow_exception(state.eptr); |
194 | } |
195 | } |
196 | |
197 | } // namespace internal |
198 | |
199 | void init_num_threads() { |
200 | #ifdef _OPENMP |
201 | omp_set_num_threads(1); |
202 | #endif |
203 | |
204 | #if AT_MKL_ENABLED() |
205 | mkl_set_num_threads(1); |
206 | #endif |
207 | |
208 | #ifdef C10_MOBILE |
209 | caffe2::pthreadpool(); |
210 | #endif |
211 | } |
212 | |
213 | void set_num_threads(int nthreads) { |
214 | #ifndef C10_MOBILE |
215 | TORCH_CHECK(nthreads > 0, "Expected positive number of threads" ); |
216 | int no_value = NOT_SET; |
217 | if (!num_intraop_threads.compare_exchange_strong(no_value, nthreads)) { |
218 | // num_intraop_threads either stores a positive integer or CONSUMED, |
219 | // check that requested size is the same as the current one |
220 | int stored_nthreads = num_intraop_threads.load(); |
221 | if (stored_nthreads <= 0) { |
222 | // plus one because of master thread |
223 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
224 | stored_nthreads = _get_intraop_pool().size() + 1; |
225 | } |
226 | if (stored_nthreads != nthreads) { |
227 | TORCH_WARN( |
228 | "Cannot set number of intraop threads " |
229 | "after parallel work has started or after set_num_threads call " |
230 | "when using native parallel backend" ); |
231 | } |
232 | } |
233 | #else |
234 | caffe2::PThreadPool* const pool = caffe2::pthreadpool(); |
235 | TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!" ); |
236 | pool->set_thread_count(nthreads); |
237 | #endif // C10_MOBILE |
238 | } |
239 | |
240 | int get_num_threads() { |
241 | at::internal::lazy_init_num_threads(); |
242 | #ifndef C10_MOBILE |
243 | // not initializing pool unnecessarily, |
244 | // because pool cannot be resized after initialization |
245 | int nthreads = num_intraop_threads.load(); |
246 | if (nthreads > 0) { |
247 | return nthreads; |
248 | } else if (nthreads == NOT_SET) { |
249 | return intraop_default_num_threads(); |
250 | } else { |
251 | TORCH_INTERNAL_ASSERT(nthreads == CONSUMED); |
252 | // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) |
253 | return _get_intraop_pool().size() + 1; |
254 | } |
255 | #else |
256 | caffe2::PThreadPool* const pool = caffe2::pthreadpool(); |
257 | TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!" ) |
258 | return in_parallel_region() ? 1 /* current thread */ : pool->get_thread_count(); |
259 | #endif // C10_MOBILE |
260 | } |
261 | |
262 | int get_thread_num() { |
263 | return thread_num_; |
264 | } |
265 | |
266 | bool in_parallel_region() { |
267 | #ifndef C10_MOBILE |
268 | return in_parallel_region_ || ( |
269 | num_intraop_threads.load() == CONSUMED && |
270 | // Needed as intraop_launch() doesn't set in_parallel_region(). |
271 | _get_intraop_pool().inThreadPool() |
272 | ); |
273 | #else |
274 | return in_parallel_region_; |
275 | #endif // C10_MOBILE |
276 | } |
277 | |
278 | void intraop_launch(std::function<void()> func) { |
279 | #ifndef C10_MOBILE |
280 | if (!in_parallel_region() && get_num_threads() > 1) { |
281 | _get_intraop_pool().run(std::move(func)); |
282 | } else { |
283 | // execute inline if we're in parallel region |
284 | func(); |
285 | } |
286 | #else |
287 | // TODO: caffe2::PThreadPool only provides a data-parallel API. |
288 | // Task parallelism is not currently supported. |
289 | func(); |
290 | #endif // C10_MOBILE |
291 | } |
292 | |
293 | c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future( |
294 | std::function<void()> func) { |
295 | #ifndef C10_MOBILE |
296 | auto future = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get()); |
297 | if (!in_parallel_region() && get_num_threads() > 1) { |
298 | _get_intraop_pool().run( |
299 | [func, future]() { |
300 | func(); |
301 | future->markCompleted(); |
302 | } |
303 | ); |
304 | } else { |
305 | func(); |
306 | future->markCompleted(); |
307 | } |
308 | return future; |
309 | #else |
310 | // TODO: caffe2::PThreadPool only provides a data-parallel API. |
311 | // Task parallelism is not currently supported. |
312 | auto future = c10::make_intrusive<c10::ivalue::Future>(c10::dynT<NoneType>()); |
313 | func(); |
314 | future->markCompleted(); |
315 | return future; |
316 | #endif // C10_MOBILE |
317 | } |
318 | |
319 | } // namespace at |
320 | #endif |
321 | |