1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_DNNL_THREAD_HPP |
18 | #define COMMON_DNNL_THREAD_HPP |
19 | |
20 | #include <algorithm> |
21 | #include <functional> |
22 | #include <mutex> |
23 | |
24 | #include "utils.hpp" |
25 | #include "z_magic.hpp" |
26 | |
27 | // IMPORTANT NOTICE: |
28 | // This header is special in the library since it enables all threading |
29 | // functionality in the product including tests. |
30 | // tests/test_thread.{c,h}pp files rely on this header file by: |
31 | // * Substituting `threadpool_utils` namespace to re-use threadpool functions |
32 | // and enable a second threadpool different from the library; |
33 | // * Re-defining `DNNL_CPU_THREADING_RUNTIME` macro value when it is supposed |
34 | // to be `DNNL_RUNTIME_SEQ`, e.g., for CPU_NONE configuration. |
35 | // 1. It implies all parts of code relying on this macro should stay in the |
36 | // file. |
37 | // 2. It implies there are no function bodies in the translation units |
38 | // related to the library. Tests threading layer uses dnnl::impl::func |
39 | // signature, and if library has symbols defined, regardless of |
40 | // redefinition, it will take those that were compiled with original |
41 | // macro value. |
42 | // |
// A potential drawback could be an increased binary size, but in practice the
// impact is small thanks to linker optimizations. The newer the compiler and
// C++ standard, the smaller the resulting binary.
46 | |
47 | #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL |
48 | #include "counting_barrier.hpp" |
49 | #endif |
50 | |
51 | #if defined(DNNL_ENABLE_ITT_TASKS) |
52 | #include "common/ittnotify.hpp" |
53 | #endif |
54 | |
55 | #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ |
56 | #define DNNL_THR_SYNC 1 |
// Sequential runtime: the library never spawns threads, so at most one
// thread is ever available.
inline int dnnl_get_max_threads() {
    return 1;
}
// Sequential runtime: there are no parallel regions, so the caller can never
// be "in parallel".
inline int dnnl_in_parallel() {
    return 0;
}
// Sequential runtime: a barrier over a single thread is a no-op.
inline void dnnl_thr_barrier() {}
64 | |
65 | #elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP |
66 | #include "omp.h" |
67 | #define DNNL_THR_SYNC 1 |
// OpenMP runtime: forward to the OpenMP runtime's notion of the maximum
// team size for a parallel region.
inline int dnnl_get_max_threads() {
    return omp_get_max_threads();
}
71 | inline int dnnl_in_parallel() { |
72 | return omp_in_parallel(); |
73 | } |
// Synchronizes all threads of the innermost enclosing OpenMP parallel
// region (must be called by every thread of the team).
inline void dnnl_thr_barrier() {
#pragma omp barrier
}
77 | |
78 | #elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB |
79 | #include "tbb/parallel_for.h" |
80 | #include "tbb/task_arena.h" |
81 | #define DNNL_THR_SYNC 0 |
// TBB runtime: report the current task arena's maximum concurrency as the
// maximum number of threads oneDNN may use.
inline int dnnl_get_max_threads() {
    return tbb::this_task_arena::max_concurrency();
}
// TBB offers no portable way to tell whether the caller already executes
// inside a parallel region, so conservatively report "not in parallel".
inline int dnnl_in_parallel() {
    return 0;
}
// TBB provides no barrier primitive (hence DNNL_THR_SYNC == 0); reaching
// this function is a programming error.
inline void dnnl_thr_barrier() {
    assert(!"no barrier in TBB" );
}
91 | |
92 | #elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL |
93 | #include <thread> |
94 | #include "oneapi/dnnl/dnnl_threadpool_iface.hpp" |
95 | #define DNNL_THR_SYNC 0 |
96 | |
97 | #include "cpu/platform.hpp" |
98 | |
namespace dnnl {
namespace impl {
namespace threadpool_utils {

// Each thread maintains a thread-local pointer to a threadpool which is
// 'active' for the current thread. If this pointer is a nullptr, all the work
// is executed sequentially.

// Sets `tp` to be the active threadpool for the calling thread. This will
// make all calls to `get_active_threadpool()` to return `tp` thus enabling
// `parallel()` and `parallel_nd()` to submit work to `tp`.
void activate_threadpool(dnnl::threadpool_interop::threadpool_iface *tp);

// Resets the active threadpool for the calling thread to nullptr. After this
// call `parallel()` and `parallel_nd()` would execute work sequentially.
void deactivate_threadpool();

// Returns the active threadpool for the calling thread (nullptr when none
// has been activated).
dnnl::threadpool_interop::threadpool_iface *get_active_threadpool();

// Returns the maximum concurrency available in the given global context;
// this is the default number of threads oneDNN uses when no user threadpool
// is active (see dnnl_get_max_threads() below).
int get_max_concurrency();

// Returns a mutable reference to a thread-local max-concurrency value.
// NOTE(review): presumably backs get_max_concurrency() — confirm in the
// implementation file.
int &get_threadlocal_max_concurrency();

} // namespace threadpool_utils
} // namespace impl
} // namespace dnnl
127 | |
128 | inline int dnnl_get_max_threads() { |
129 | using namespace dnnl::impl::threadpool_utils; |
130 | dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool(); |
131 | |
132 | // This is the maximum number of threads oneDNN will use by default |
133 | int max_concurrency = dnnl::impl::threadpool_utils::get_max_concurrency(); |
134 | |
135 | // Use the default max_concurrency only when no tp is passed by |
136 | // user (e.g. primitive creation). |
137 | return tp ? std::max(1, tp->get_num_threads()) : max_concurrency; |
138 | } |
139 | inline int dnnl_in_parallel() { |
140 | using namespace dnnl::impl::threadpool_utils; |
141 | dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool(); |
142 | return tp ? tp->get_in_parallel() : 0; |
143 | } |
// The threadpool runtime provides no barrier (hence DNNL_THR_SYNC == 0);
// reaching this function is a programming error.
inline void dnnl_thr_barrier() {
    assert(!"no barrier with THREADPOOL" );
}
147 | #endif |
148 | |
149 | /* The purpose of this function is to provide the number of threads the library |
150 | * is aware of when this function is invoked. Since oneDNN does not allow nested |
151 | * parallelism, inside a parallel region the number of available threads is 1. |
152 | * Otherwise, the number of current threads varies between threading runtimes: |
153 | * - for OpenMP and TBB, return the max number of threads since the number of |
154 | * threads is held in a global object throughout the entire execution. |
155 | * - for Threadpool, since the global object in oneDNN changes throughout |
156 | * execution, two situations can occur: |
157 | * a) if the library *is* aware of a threadpool when this function is invoked, |
158 | * return the number of available threads in the threadpool; |
159 | * b) if the library *is not* aware of a threadpool when this function is |
160 | * invoked, return 1 since the main thread will do the work. |
161 | */ |
inline int dnnl_get_current_num_threads() {
    // Nested parallelism is not allowed: inside a parallel region only the
    // calling thread is available.
    if (dnnl_in_parallel()) return 1;
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
    return omp_get_max_threads();
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB
    return tbb::this_task_arena::max_concurrency();
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
    // Case (a) from the comment above: a threadpool is active, report its
    // size; case (b): no threadpool, the main thread does the work.
    using namespace dnnl::impl::threadpool_utils;
    dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool();
    return (tp) ? dnnl_get_max_threads() : 1;
#else
    return 1;
#endif
}
176 | |
177 | #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP |
178 | #define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__)) |
179 | #define OMP_GET_THREAD_NUM() omp_get_thread_num() |
180 | #define OMP_GET_NUM_THREADS() omp_get_num_threads() |
181 | #else |
182 | #define PRAGMA_OMP(...) |
183 | #define OMP_GET_THREAD_NUM() 0 |
184 | #define OMP_GET_NUM_THREADS() 1 |
185 | #endif |
186 | |
// MSVC supports only OpenMP 2.0 (no `collapse` clause and no `omp simd`)
188 | #if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER) |
189 | #define collapse(x) |
190 | #define PRAGMA_OMP_SIMD(...) |
191 | #else |
192 | #define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__)) |
193 | #endif // defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER) |
194 | |
195 | // process simdlen; it is supported for Clang >= 3.9; ICC >= 17.0; GCC >= 6.1 |
196 | // No support on Windows. |
197 | #if (defined(__clang_major__) \ |
198 | && (__clang_major__ < 3 \ |
199 | || (__clang_major__ == 3 && __clang_minor__ < 9))) \ |
200 | || (defined(__INTEL_COMPILER) && __INTEL_COMPILER < 1700) \ |
201 | || (!defined(__INTEL_COMPILER) && !defined(__clang__) \ |
202 | && (defined(_MSC_VER) || __GNUC__ < 6 \ |
203 | || (__GNUC__ == 6 && __GNUC_MINOR__ < 1))) |
204 | #define simdlen(x) |
205 | #endif // long simdlen if |
206 | |
207 | namespace dnnl { |
208 | namespace impl { |
209 | |
// Returns true when the active threading runtime supports cross-thread
// synchronization (dnnl_thr_barrier()) inside a parallel region; per the
// definitions above this holds for the SEQ and OpenMP runtimes only.
inline bool dnnl_thr_syncable() {
    return DNNL_THR_SYNC == 1;
}
213 | |
// Splits `n` work items across a `team` of workers as evenly as possible and
// returns the half-open range [n_start, n_end) owned by worker `tid`.
// Workers with smaller ids may receive one extra item: the first T1 workers
// get n1 = ceil(n / team) items each, the remaining ones get n1 - 1.
template <typename T, typename U>
inline void balance211(T n, U team, U tid, T &n_start, T &n_end) {
    if (team <= 1 || n == 0) {
        // Degenerate split: a single worker (or no work) owns everything.
        n_start = 0;
        n_end = n;
    } else {
        // team = T1 + T2 and n = T1 * n1 + T2 * n2, where n1 - n2 == 1.
        const T n1 = (n + (T)team - 1) / (T)team; // ceil(n / team)
        const T n2 = n1 - 1;
        const T T1 = n - n2 * (T)team;
        n_end = (T)tid < T1 ? n1 : n2; // chunk size for this worker
        n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2;
    }

    // Convert (start, size) into a half-open (start, end) range.
    n_end += n_start;
}
233 | |
// Splits a 2D [ny x nx] iteration space across `nthr` threads. Threads are
// first partitioned into at most `nx_divider` groups along the x dimension;
// each group then splits the y dimension among its members via balance211().
// Outputs are the half-open ranges [ny_start, ny_end) and [nx_start, nx_end)
// owned by thread `ithr`.
template <typename T, typename U>
void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, T nx, T &nx_start,
        T &nx_end, T nx_divider) {
    const T grp_count = nstl::min(nx_divider, static_cast<T>(nthr));
    // When nthr is not divisible by grp_count, the first n_grp_big groups
    // get one extra thread each (grp_size_big); the rest get grp_size_small.
    const int grp_size_big = nthr / static_cast<int>(grp_count) + 1;
    const int grp_size_small = nthr / static_cast<int>(grp_count);
    const int n_grp_big = nthr % static_cast<int>(grp_count);
    const int threads_in_big_groups = n_grp_big * grp_size_big;

    // Locate ithr's group and its position within that group.
    const int ithr_bound_distance = ithr - threads_in_big_groups;
    T grp, grp_ithr, grp_nthr;
    if (ithr_bound_distance < 0) { // ithr in first groups
        grp = ithr / grp_size_big;
        grp_ithr = ithr % grp_size_big;
        grp_nthr = grp_size_big;
    } else { // ithr in last groups
        grp = n_grp_big + ithr_bound_distance / grp_size_small;
        grp_ithr = ithr_bound_distance % grp_size_small;
        grp_nthr = grp_size_small;
    }

    // x range is shared by the whole group; y range is per group member.
    balance211(nx, grp_count, grp, nx_start, nx_end);
    balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
}
258 | |
259 | /* Functions: |
260 | * - parallel(nthr, f) - executes f in parallel using at |
261 | * most nthr threads. If nthr equals |
262 | * 0 dnnl_get_current_num_threads() threads |
263 | * is used |
264 | * - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for |
265 | * already created threads |
266 | * - for_nd_ext(ithr, nthr, dims..., f) - multidimensional for loop for |
267 | * already created threads that passes |
268 | * ithr and nthr |
269 | * - parallel_nd(dims..., f) - creates a parallel section and then |
270 | * calls for_nd |
271 | * - parallel_nd_ext(dims..., f) - creates a parallel section and then |
272 | * calls for_nd_ext |
273 | * - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and |
274 | * then calls for_nd (mostly for |
275 | * convenience) |
276 | */ |
277 | |
278 | /* general parallelization */ |
// Clamps the requested number of threads to what is useful and available.
// `nthr == 0` means "use the current default", i.e.,
// dnnl_get_current_num_threads().
inline int adjust_num_threads(int nthr, dim_t work_amount) {
    if (nthr == 0) nthr = dnnl_get_current_num_threads();
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
    // OpenMP: fall back to a single thread when there is only one work item
    // or when already inside a parallel region (no nested parallelism).
    return (work_amount == 1 || omp_in_parallel()) ? 1 : nthr;
#else
    // Other runtimes: never use more threads than there are work items.
    return (int)std::min((dim_t)nthr, work_amount);
#endif
}
287 | |
// Executes `f(ithr, nthr)` in parallel on at most `nthr` threads (0 means
// "use the current default"); `f` receives the worker's thread id and the
// actual team size.
inline void parallel(int nthr, const std::function<void(int, int)> &f) {
    nthr = adjust_num_threads(nthr, INT64_MAX);
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ
    // Sequential runtime: emulate the team by invoking every worker in turn.
    for (int i = 0; i < nthr; ++i) {
        f(i, nthr);
    }
#else
#if defined(DNNL_ENABLE_ITT_TASKS)
    // Capture the current ITT task kind so worker threads can be annotated
    // with the same primitive task as the calling (master) thread.
    auto task_primitive_kind = itt::primitive_task_get_current_kind();
    bool itt_enable = itt::get_itt(itt::__itt_task_level_high);
#endif
    // Single-thread fast path: no parallel section is created.
    if (nthr == 1) {
        f(0, 1);
        return;
    }
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
#pragma omp parallel num_threads(nthr)
    {
        int nthr_ = omp_get_num_threads();
        int ithr_ = omp_get_thread_num();
        assert(nthr_ == nthr);
#if defined(DNNL_ENABLE_ITT_TASKS)
        // The master thread (ithr_ == 0) is already annotated by the caller.
        if (ithr_ && itt_enable) itt::primitive_task_start(task_primitive_kind);
#endif
        f(ithr_, nthr_);
#if defined(DNNL_ENABLE_ITT_TASKS)
        if (ithr_ && itt_enable) itt::primitive_task_end();
#endif
    }
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB
    tbb::parallel_for(
            0, nthr,
            [&](int ithr) {
#if defined(DNNL_ENABLE_ITT_TASKS)
                // Annotate only threads that are not already inside a task
                // (the submitting thread may also execute iterations).
                bool mark_task = itt::primitive_task_get_current_kind()
                        == primitive_kind::undefined;
                if (mark_task && itt_enable)
                    itt::primitive_task_start(task_primitive_kind);
#endif
                f(ithr, nthr);
#if defined(DNNL_ENABLE_ITT_TASKS)
                if (mark_task && itt_enable) itt::primitive_task_end();
#endif
            },
            tbb::static_partitioner());
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
    using namespace dnnl::impl::threadpool_utils;
    dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool();
    if (!tp || dnnl_in_parallel()) {
        // No active threadpool, or already inside a parallel section: run
        // all iterations on the calling thread with the threadpool
        // deactivated (so nested parallel() calls inside `f` stay
        // sequential), then restore the previous state.
        threadpool_utils::deactivate_threadpool();
        for (int ithr = 0; ithr < nthr; ithr++) {
            f(ithr, nthr);
        }
        threadpool_utils::activate_threadpool(tp);
    } else {
        bool async = tp->get_flags()
                & dnnl::threadpool_interop::threadpool_iface::ASYNCHRONOUS;
        // Asynchronous threadpools return from parallel_for() immediately;
        // use a counting barrier to wait for all workers to finish.
        counting_barrier_t b;
        if (async) b.init(nthr);
        tp->parallel_for(nthr, [&, tp](int ithr, int nthr) {
            // Worker threads need `tp` activated for the duration of `f`;
            // the master thread already has it active.
            bool is_master = threadpool_utils::get_active_threadpool() == tp;
            if (!is_master) {
                threadpool_utils::activate_threadpool(tp);
#if defined(DNNL_ENABLE_ITT_TASKS)
                if (itt_enable) itt::primitive_task_start(task_primitive_kind);
#endif
            }
            f(ithr, nthr);
            if (!is_master) {
#if defined(DNNL_ENABLE_ITT_TASKS)
                if (itt_enable) itt::primitive_task_end();
#endif
                threadpool_utils::deactivate_threadpool();
            }
            if (async) b.notify();
        });
        if (async) b.wait();
    }
#endif
#endif
}
369 | |
370 | /* for_nd section */ |
371 | inline void for_nd(const int ithr, const int nthr, dim_t D0, |
372 | const std::function<void(dim_t)> &f) { |
373 | dim_t start {0}, end {0}; |
374 | balance211(D0, nthr, ithr, start, end); |
375 | for (dim_t d0 = start; d0 < end; ++d0) |
376 | f(d0); |
377 | } |
378 | inline void for_nd(const int ithr, const int nthr, dim_t D0, dim_t D1, |
379 | const std::function<void(dim_t, dim_t)> &f) { |
380 | const dim_t work_amount = D0 * D1; |
381 | if (work_amount == 0) return; |
382 | dim_t start {0}, end {0}; |
383 | balance211(work_amount, nthr, ithr, start, end); |
384 | |
385 | dim_t d0 {0}, d1 {0}; |
386 | utils::nd_iterator_init(start, d0, D0, d1, D1); |
387 | for (dim_t iwork = start; iwork < end; ++iwork) { |
388 | f(d0, d1); |
389 | utils::nd_iterator_step(d0, D0, d1, D1); |
390 | } |
391 | } |
392 | inline void for_nd(const int ithr, const int nthr, dim_t D0, dim_t D1, dim_t D2, |
393 | const std::function<void(dim_t, dim_t, dim_t)> &f) { |
394 | const dim_t work_amount = D0 * D1 * D2; |
395 | if (work_amount == 0) return; |
396 | dim_t start {0}, end {0}; |
397 | balance211(work_amount, nthr, ithr, start, end); |
398 | |
399 | dim_t d0 {0}, d1 {0}, d2 {0}; |
400 | utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2); |
401 | for (dim_t iwork = start; iwork < end; ++iwork) { |
402 | f(d0, d1, d2); |
403 | utils::nd_iterator_step(d0, D0, d1, D1, d2, D2); |
404 | } |
405 | } |
406 | inline void for_nd(const int ithr, const int nthr, dim_t D0, dim_t D1, dim_t D2, |
407 | dim_t D3, const std::function<void(dim_t, dim_t, dim_t, dim_t)> &f) { |
408 | const dim_t work_amount = D0 * D1 * D2 * D3; |
409 | if (work_amount == 0) return; |
410 | dim_t start {0}, end {0}; |
411 | balance211(work_amount, nthr, ithr, start, end); |
412 | |
413 | dim_t d0 {0}, d1 {0}, d2 {0}, d3 {0}; |
414 | utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3); |
415 | for (dim_t iwork = start; iwork < end; ++iwork) { |
416 | f(d0, d1, d2, d3); |
417 | utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); |
418 | } |
419 | } |
420 | inline void for_nd(const int ithr, const int nthr, dim_t D0, dim_t D1, dim_t D2, |
421 | dim_t D3, dim_t D4, |
422 | const std::function<void(dim_t, dim_t, dim_t, dim_t, dim_t)> &f) { |
423 | const dim_t work_amount = D0 * D1 * D2 * D3 * D4; |
424 | if (work_amount == 0) return; |
425 | dim_t start {0}, end {0}; |
426 | balance211(work_amount, nthr, ithr, start, end); |
427 | |
428 | dim_t d0 {0}, d1 {0}, d2 {0}, d3 {0}, d4 {0}; |
429 | utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); |
430 | for (dim_t iwork = start; iwork < end; ++iwork) { |
431 | f(d0, d1, d2, d3, d4); |
432 | utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); |
433 | } |
434 | } |
435 | inline void for_nd(const int ithr, const int nthr, dim_t D0, dim_t D1, dim_t D2, |
436 | dim_t D3, dim_t D4, dim_t D5, |
437 | const std::function<void(dim_t, dim_t, dim_t, dim_t, dim_t, dim_t)> |
438 | &f) { |
439 | const dim_t work_amount = D0 * D1 * D2 * D3 * D4 * D5; |
440 | if (work_amount == 0) return; |
441 | dim_t start {0}, end {0}; |
442 | balance211(work_amount, nthr, ithr, start, end); |
443 | |
444 | dim_t d0 {0}, d1 {0}, d2 {0}, d3 {0}, d4 {0}, d5 {0}; |
445 | utils::nd_iterator_init( |
446 | start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); |
447 | for (dim_t iwork = start; iwork < end; ++iwork) { |
448 | f(d0, d1, d2, d3, d4, d5); |
449 | utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); |
450 | } |
451 | } |
452 | |
453 | /* for_nd_ext section */ |
454 | inline void for_nd_ext(const int ithr, const int nthr, dim_t D0, |
455 | const std::function<void(int, int, dim_t)> &f) { |
456 | dim_t start {0}, end {0}; |
457 | balance211(D0, nthr, ithr, start, end); |
458 | for (dim_t d0 = start; d0 < end; ++d0) |
459 | f(ithr, nthr, d0); |
460 | } |
461 | inline void for_nd_ext(const int ithr, const int nthr, dim_t D0, dim_t D1, |
462 | const std::function<void(int, int, dim_t, dim_t)> &f) { |
463 | const dim_t work_amount = D0 * D1; |
464 | if (work_amount == 0) return; |
465 | dim_t start {0}, end {0}; |
466 | balance211(work_amount, nthr, ithr, start, end); |
467 | |
468 | dim_t d0 {0}, d1 {0}; |
469 | utils::nd_iterator_init(start, d0, D0, d1, D1); |
470 | for (dim_t iwork = start; iwork < end; ++iwork) { |
471 | f(ithr, nthr, d0, d1); |
472 | utils::nd_iterator_step(d0, D0, d1, D1); |
473 | } |
474 | } |
475 | inline void for_nd_ext(const int ithr, const int nthr, dim_t D0, dim_t D1, |
476 | dim_t D2, const std::function<void(int, int, dim_t, dim_t, dim_t)> &f) { |
477 | const dim_t work_amount = D0 * D1 * D2; |
478 | if (work_amount == 0) return; |
479 | dim_t start {0}, end {0}; |
480 | balance211(work_amount, nthr, ithr, start, end); |
481 | |
482 | dim_t d0 {0}, d1 {0}, d2 {0}; |
483 | utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2); |
484 | for (dim_t iwork = start; iwork < end; ++iwork) { |
485 | f(ithr, nthr, d0, d1, d2); |
486 | utils::nd_iterator_step(d0, D0, d1, D1, d2, D2); |
487 | } |
488 | } |
489 | inline void for_nd_ext(const int ithr, const int nthr, dim_t D0, dim_t D1, |
490 | dim_t D2, dim_t D3, |
491 | const std::function<void(int, int, dim_t, dim_t, dim_t, dim_t)> &f) { |
492 | const dim_t work_amount = D0 * D1 * D2 * D3; |
493 | if (work_amount == 0) return; |
494 | dim_t start {0}, end {0}; |
495 | balance211(work_amount, nthr, ithr, start, end); |
496 | |
497 | dim_t d0 {0}, d1 {0}, d2 {0}, d3 {0}; |
498 | utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3); |
499 | for (dim_t iwork = start; iwork < end; ++iwork) { |
500 | f(ithr, nthr, d0, d1, d2, d3); |
501 | utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3); |
502 | } |
503 | } |
504 | inline void for_nd_ext(const int ithr, const int nthr, dim_t D0, dim_t D1, |
505 | dim_t D2, dim_t D3, dim_t D4, |
506 | const std::function<void(int, int, dim_t, dim_t, dim_t, dim_t, dim_t)> |
507 | &f) { |
508 | const dim_t work_amount = D0 * D1 * D2 * D3 * D4; |
509 | if (work_amount == 0) return; |
510 | dim_t start {0}, end {0}; |
511 | balance211(work_amount, nthr, ithr, start, end); |
512 | |
513 | dim_t d0 {0}, d1 {0}, d2 {0}, d3 {0}, d4 {0}; |
514 | utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); |
515 | for (dim_t iwork = start; iwork < end; ++iwork) { |
516 | f(ithr, nthr, d0, d1, d2, d3, d4); |
517 | utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4); |
518 | } |
519 | } |
520 | inline void for_nd_ext(const int ithr, const int nthr, dim_t D0, dim_t D1, |
521 | dim_t D2, dim_t D3, dim_t D4, dim_t D5, |
522 | const std::function<void( |
523 | int, int, dim_t, dim_t, dim_t, dim_t, dim_t, dim_t)> &f) { |
524 | const dim_t work_amount = D0 * D1 * D2 * D3 * D4 * D5; |
525 | if (work_amount == 0) return; |
526 | dim_t start {0}, end {0}; |
527 | balance211(work_amount, nthr, ithr, start, end); |
528 | |
529 | dim_t d0 {0}, d1 {0}, d2 {0}, d3 {0}, d4 {0}, d5 {0}; |
530 | utils::nd_iterator_init( |
531 | start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); |
532 | for (dim_t iwork = start; iwork < end; ++iwork) { |
533 | f(ithr, nthr, d0, d1, d2, d3, d4, d5); |
534 | utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5); |
535 | } |
536 | } |
537 | |
538 | /* parallel_nd_ext section */ |
539 | inline void parallel_nd_ext( |
540 | int nthr, dim_t D0, const std::function<void(int, int, dim_t)> &f) { |
541 | const dim_t work_amount = D0; |
542 | nthr = adjust_num_threads(nthr, work_amount); |
543 | if (nthr) |
544 | parallel(nthr, |
545 | [&](int ithr, int nthr) { for_nd_ext(ithr, nthr, D0, f); }); |
546 | } |
547 | inline void parallel_nd_ext(int nthr, dim_t D0, dim_t D1, |
548 | const std::function<void(int, int, dim_t, dim_t)> &f) { |
549 | const dim_t work_amount = D0 * D1; |
550 | nthr = adjust_num_threads(nthr, work_amount); |
551 | if (nthr) |
552 | parallel(nthr, |
553 | [&](int ithr, int nthr) { for_nd_ext(ithr, nthr, D0, D1, f); }); |
554 | } |
555 | inline void parallel_nd_ext(int nthr, dim_t D0, dim_t D1, dim_t D2, |
556 | const std::function<void(int, int, dim_t, dim_t, dim_t)> &f) { |
557 | const dim_t work_amount = D0 * D1 * D2; |
558 | nthr = adjust_num_threads(nthr, work_amount); |
559 | if (nthr) |
560 | parallel(nthr, [&](int ithr, int nthr) { |
561 | for_nd_ext(ithr, nthr, D0, D1, D2, f); |
562 | }); |
563 | } |
564 | inline void parallel_nd_ext(int nthr, dim_t D0, dim_t D1, dim_t D2, dim_t D3, |
565 | const std::function<void(int, int, dim_t, dim_t, dim_t, dim_t)> &f) { |
566 | const dim_t work_amount = D0 * D1 * D2 * D3; |
567 | nthr = adjust_num_threads(nthr, work_amount); |
568 | if (nthr) |
569 | parallel(nthr, [&](int ithr, int nthr) { |
570 | for_nd_ext(ithr, nthr, D0, D1, D2, D3, f); |
571 | }); |
572 | } |
573 | inline void parallel_nd_ext(int nthr, dim_t D0, dim_t D1, dim_t D2, dim_t D3, |
574 | dim_t D4, |
575 | const std::function<void(int, int, dim_t, dim_t, dim_t, dim_t, dim_t)> |
576 | &f) { |
577 | const dim_t work_amount = D0 * D1 * D2 * D3 * D4; |
578 | nthr = adjust_num_threads(nthr, work_amount); |
579 | if (nthr) |
580 | parallel(nthr, [&](int ithr, int nthr) { |
581 | for_nd_ext(ithr, nthr, D0, D1, D2, D3, D4, f); |
582 | }); |
583 | } |
584 | inline void parallel_nd_ext(int nthr, dim_t D0, dim_t D1, dim_t D2, dim_t D3, |
585 | dim_t D4, dim_t D5, |
586 | const std::function<void( |
587 | int, int, dim_t, dim_t, dim_t, dim_t, dim_t, dim_t)> &f) { |
588 | const dim_t work_amount = D0 * D1 * D2 * D3 * D4 * D5; |
589 | nthr = adjust_num_threads(nthr, work_amount); |
590 | if (nthr) |
591 | parallel(nthr, [&](int ithr, int nthr) { |
592 | for_nd_ext(ithr, nthr, D0, D1, D2, D3, D4, D5, f); |
593 | }); |
594 | } |
595 | |
596 | /* parallel_nd section */ |
597 | inline void parallel_nd(dim_t D0, const std::function<void(dim_t)> &f) { |
598 | int nthr = adjust_num_threads(dnnl_get_current_num_threads(), D0); |
599 | if (nthr) |
600 | parallel(nthr, [&](int ithr, int nthr) { for_nd(ithr, nthr, D0, f); }); |
601 | } |
602 | inline void parallel_nd( |
603 | dim_t D0, dim_t D1, const std::function<void(dim_t, dim_t)> &f) { |
604 | const dim_t work_amount = D0 * D1; |
605 | int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount); |
606 | if (nthr) |
607 | parallel(nthr, |
608 | [&](int ithr, int nthr) { for_nd(ithr, nthr, D0, D1, f); }); |
609 | } |
610 | inline void parallel_nd(dim_t D0, dim_t D1, dim_t D2, |
611 | const std::function<void(dim_t, dim_t, dim_t)> &f) { |
612 | const dim_t work_amount = D0 * D1 * D2; |
613 | int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount); |
614 | if (nthr) |
615 | parallel(nthr, |
616 | [&](int ithr, int nthr) { for_nd(ithr, nthr, D0, D1, D2, f); }); |
617 | } |
618 | inline void parallel_nd(dim_t D0, dim_t D1, dim_t D2, dim_t D3, |
619 | const std::function<void(dim_t, dim_t, dim_t, dim_t)> &f) { |
620 | const dim_t work_amount = D0 * D1 * D2 * D3; |
621 | int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount); |
622 | if (nthr) |
623 | parallel(nthr, [&](int ithr, int nthr) { |
624 | for_nd(ithr, nthr, D0, D1, D2, D3, f); |
625 | }); |
626 | } |
627 | inline void parallel_nd(dim_t D0, dim_t D1, dim_t D2, dim_t D3, dim_t D4, |
628 | const std::function<void(dim_t, dim_t, dim_t, dim_t, dim_t)> &f) { |
629 | const dim_t work_amount = D0 * D1 * D2 * D3 * D4; |
630 | int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount); |
631 | if (nthr) |
632 | parallel(nthr, [&](int ithr, int nthr) { |
633 | for_nd(ithr, nthr, D0, D1, D2, D3, D4, f); |
634 | }); |
635 | } |
636 | inline void parallel_nd(dim_t D0, dim_t D1, dim_t D2, dim_t D3, dim_t D4, |
637 | dim_t D5, |
638 | const std::function<void(dim_t, dim_t, dim_t, dim_t, dim_t, dim_t)> |
639 | &f) { |
640 | const dim_t work_amount = D0 * D1 * D2 * D3 * D4 * D5; |
641 | int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount); |
642 | if (nthr) |
643 | parallel(nthr, [&](int ithr, int nthr) { |
644 | for_nd(ithr, nthr, D0, D1, D2, D3, D4, D5, f); |
645 | }); |
646 | } |
647 | /* parallel_nd_in_omp section */ |
648 | |
// Runs for_nd() from inside an already-open parallel region using the
// caller's thread id and team size (mostly for convenience). Supported only
// for the SEQ and OpenMP runtimes; TBB/threadpool builds assert.
template <typename... Args>
void parallel_nd_in_omp(Args &&... args) {
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ
    for_nd(0, 1, utils::forward<Args>(args)...);
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
    for_nd(omp_get_thread_num(), omp_get_num_threads(),
            utils::forward<Args>(args)...);
#elif (DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB \
        || DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL)
    assert(!"parallel_nd_in_omp() is not supported by this DNNL_CPU_RUNTIME" );
#endif
}
661 | |
662 | } // namespace impl |
663 | } // namespace dnnl |
664 | |
665 | #endif |
666 | |
667 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
668 | |