/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef COMMON_DNNL_THREAD_HPP
#define COMMON_DNNL_THREAD_HPP

#include <algorithm>
#include <functional>
#include <mutex>

#include "utils.hpp"
#include "z_magic.hpp"

// IMPORTANT NOTICE:
// This header is special in the library since it enables all threading
// functionality in the product, including tests.
// The tests/test_thread.{c,h}pp files rely on this header by:
// * Substituting the `threadpool_utils` namespace to re-use the threadpool
//   functions and enable a second threadpool different from the library's;
// * Re-defining the `DNNL_CPU_THREADING_RUNTIME` macro value when it is
//   supposed to be `DNNL_RUNTIME_SEQ`, e.g., for the CPU_NONE configuration.
//   1. It implies that all parts of the code relying on this macro should
//      stay in this file.
//   2. It implies that there are no function bodies in the translation units
//      related to the library. The tests' threading layer uses the
//      dnnl::impl::func signatures, and if the library had these symbols
//      defined, then regardless of the redefinition, the linker would take
//      those that were compiled with the original macro value.
//
// A potential drawback could be increased binary size, but in practice linker
// optimizations keep the growth small; the newer the compiler and C++
// standard, the smaller the resulting binary.

#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
#include "counting_barrier.hpp"
#endif

#if defined(DNNL_ENABLE_ITT_TASKS)
#include "common/ittnotify.hpp"
#endif

#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ
#define DNNL_THR_SYNC 1
inline int dnnl_get_max_threads() {
    return 1;
}
inline int dnnl_in_parallel() {
    return 0;
}
inline void dnnl_thr_barrier() {}

#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
#include "omp.h"
#define DNNL_THR_SYNC 1
inline int dnnl_get_max_threads() {
    return omp_get_max_threads();
}
inline int dnnl_in_parallel() {
    return omp_in_parallel();
}
inline void dnnl_thr_barrier() {
#pragma omp barrier
}

#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB
#include "tbb/parallel_for.h"
#include "tbb/task_arena.h"
#define DNNL_THR_SYNC 0
inline int dnnl_get_max_threads() {
    return tbb::this_task_arena::max_concurrency();
}
inline int dnnl_in_parallel() {
    return 0;
}
inline void dnnl_thr_barrier() {
    assert(!"no barrier in TBB");
}

#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
#include <thread>
#include "oneapi/dnnl/dnnl_threadpool_iface.hpp"
#define DNNL_THR_SYNC 0

#include "cpu/platform.hpp"

namespace dnnl {
namespace impl {
namespace threadpool_utils {

// Each thread maintains a thread-local pointer to a threadpool which is
// 'active' for the current thread. If this pointer is nullptr, all the work
// is executed sequentially.

// Sets `tp` to be the active threadpool for the calling thread. This makes
// all subsequent calls to `get_active_threadpool()` return `tp`, thus
// enabling `parallel()` and `parallel_nd()` to submit work to `tp`.
void activate_threadpool(dnnl::threadpool_interop::threadpool_iface *tp);

// Resets the active threadpool for the calling thread to nullptr. After this
// call `parallel()` and `parallel_nd()` will execute work sequentially.
void deactivate_threadpool();

// Returns the active threadpool for the calling thread.
dnnl::threadpool_interop::threadpool_iface *get_active_threadpool();

// Returns the maximum concurrency available in the given global context.
int get_max_concurrency();

int &get_threadlocal_max_concurrency();

} // namespace threadpool_utils
} // namespace impl
} // namespace dnnl
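
// A minimal usage sketch (illustrative only, not part of the library): an
// interop layer brackets a computation with activate/deactivate so that
// `parallel()` below can submit work to the user-provided threadpool. Here
// `user_tp` stands for a hypothetical user implementation of
// dnnl::threadpool_interop::threadpool_iface:
//
//     void run_with_user_threadpool(
//             dnnl::threadpool_interop::threadpool_iface *user_tp) {
//         dnnl::impl::threadpool_utils::activate_threadpool(user_tp);
//         // ... execute work; parallel() now submits to user_tp ...
//         dnnl::impl::threadpool_utils::deactivate_threadpool();
//     }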

inline int dnnl_get_max_threads() {
    using namespace dnnl::impl::threadpool_utils;
    dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool();

    // This is the maximum number of threads oneDNN will use by default
    int max_concurrency = dnnl::impl::threadpool_utils::get_max_concurrency();

    // Use the default max_concurrency only when no tp is passed by
    // the user (e.g., at primitive creation).
    return tp ? std::max(1, tp->get_num_threads()) : max_concurrency;
}
inline int dnnl_in_parallel() {
    using namespace dnnl::impl::threadpool_utils;
    dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool();
    return tp ? tp->get_in_parallel() : 0;
}
inline void dnnl_thr_barrier() {
    assert(!"no barrier with THREADPOOL");
}
#endif

/* The purpose of this function is to provide the number of threads the library
 * is aware of when this function is invoked. Since oneDNN does not allow nested
 * parallelism, inside a parallel region the number of available threads is 1.
 * Otherwise, the number of current threads varies between threading runtimes:
 * - for OpenMP and TBB, return the max number of threads since the number of
 *   threads is held in a global object throughout the entire execution.
 * - for Threadpool, since the global object in oneDNN changes throughout
 *   execution, two situations can occur:
 *   a) if the library *is* aware of a threadpool when this function is
 *      invoked, return the number of available threads in the threadpool;
 *   b) if the library *is not* aware of a threadpool when this function is
 *      invoked, return 1 since the main thread will do the work.
 */
inline int dnnl_get_current_num_threads() {
    if (dnnl_in_parallel()) return 1;
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
    return omp_get_max_threads();
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB
    return tbb::this_task_arena::max_concurrency();
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
    using namespace dnnl::impl::threadpool_utils;
    dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool();
    return (tp) ? dnnl_get_max_threads() : 1;
#else
    return 1;
#endif
}

#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
#define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__))
#define OMP_GET_THREAD_NUM() omp_get_thread_num()
#define OMP_GET_NUM_THREADS() omp_get_num_threads()
#else
#define PRAGMA_OMP(...)
#define OMP_GET_THREAD_NUM() 0
#define OMP_GET_NUM_THREADS() 1
#endif
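
// An illustrative note on PRAGMA_OMP: with the OpenMP runtime,
// PRAGMA_OMP(parallel for) expands (via the PRAGMA_MACRO/CHAIN2 helpers from
// z_magic.hpp) to `#pragma omp parallel for`; with any other runtime it
// expands to nothing, so a sketch like the following degrades to a plain
// serial loop (`process` is a hypothetical work function):
//
//     PRAGMA_OMP(parallel for)
//     for (int i = 0; i < n; i++)
//         process(i);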

// MSVC still supports only OpenMP 2.0, which lacks the `simd` construct and
// the `collapse` clause, so both are defined away below.
#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
#define collapse(x)
#define PRAGMA_OMP_SIMD(...)
#else
#define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__))
#endif // defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)

// Process simdlen; it is supported for Clang >= 3.9, ICC >= 17.0, and
// GCC >= 6.1. There is no support on Windows.
#if (defined(__clang_major__) \
        && (__clang_major__ < 3 \
                || (__clang_major__ == 3 && __clang_minor__ < 9))) \
        || (defined(__INTEL_COMPILER) && __INTEL_COMPILER < 1700) \
        || (!defined(__INTEL_COMPILER) && !defined(__clang__) \
                && (defined(_MSC_VER) || __GNUC__ < 6 \
                        || (__GNUC__ == 6 && __GNUC_MINOR__ < 1)))
#define simdlen(x)
#endif // long simdlen if
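
// An illustrative sketch of how the fallbacks above compose: on compilers
// with simdlen support,
//
//     PRAGMA_OMP_SIMD(simdlen(16))
//     for (int i = 0; i < n; i++)
//         y[i] += x[i];
//
// becomes `#pragma omp simd simdlen(16)`; on the older compilers listed
// above the empty simdlen(x) macro erases the clause, leaving
// `#pragma omp simd`, and under MSVC the whole pragma disappears.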

namespace dnnl {
namespace impl {

inline bool dnnl_thr_syncable() {
    return DNNL_THR_SYNC == 1;
}

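// Splits `n` work items across a `team` of threads, giving each thread a
// contiguous [n_start, n_end) range; chunk sizes differ by at most one.
// A worked example (illustrative): balance211(n=10, team=4, tid, ...) gives
// n1 = 3, n2 = 2, T1 = 2, so threads tid = 0..3 get the ranges
// [0, 3), [3, 6), [6, 8), and [8, 10) respectively.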
template <typename T, typename U>
inline void balance211(T n, U team, U tid, T &n_start, T &n_end) {
    T n_min = 1;
    T &n_my = n_end;
    if (team <= 1 || n == 0) {
        n_start = 0;
        n_my = n;
    } else if (n_min == 1) {
        // team = T1 + T2
        // n = T1*n1 + T2*n2  (n1 - n2 = 1)
        T n1 = utils::div_up(n, (T)team);
        T n2 = n1 - 1;
        T T1 = n - n2 * (T)team;
        n_my = (T)tid < T1 ? n1 : n2;
        n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2;
    }

    n_end += n_start;
}

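// Splits a 2D [nx x ny] iteration space across `nthr` threads: threads are
// first partitioned into at most `nx_divider` groups, each group takes a
// contiguous slice of `nx`, and the threads within a group split `ny`, both
// via balance211(). A worked example (illustrative): nthr = 6, nx_divider = 4
// gives grp_count = 4 with two "big" groups of 2 threads and two groups of 1;
// each group covers one quarter of `nx`, and each 2-thread group splits its
// `ny` range in half.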
template <typename T, typename U>
void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, T nx, T &nx_start,
        T &nx_end, T nx_divider) {
    const T grp_count = nstl::min(nx_divider, static_cast<T>(nthr));
    const int grp_size_big = nthr / static_cast<int>(grp_count) + 1;
    const int grp_size_small = nthr / static_cast<int>(grp_count);
    const int n_grp_big = nthr % static_cast<int>(grp_count);
    const int threads_in_big_groups = n_grp_big * grp_size_big;

    const int ithr_bound_distance = ithr - threads_in_big_groups;
    T grp, grp_ithr, grp_nthr;
    if (ithr_bound_distance < 0) { // ithr in first groups
        grp = ithr / grp_size_big;
        grp_ithr = ithr % grp_size_big;
        grp_nthr = grp_size_big;
    } else { // ithr in last groups
        grp = n_grp_big + ithr_bound_distance / grp_size_small;
        grp_ithr = ithr_bound_distance % grp_size_small;
        grp_nthr = grp_size_small;
    }

    balance211(nx, grp_count, grp, nx_start, nx_end);
    balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
}

/* Functions:
 *  - parallel(nthr, f)              - executes f in parallel using at
 *                                     most nthr threads. If nthr equals
 *                                     0, dnnl_get_current_num_threads()
 *                                     threads are used
 *  - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for
 *                                     already created threads
 *  - for_nd_ext(ithr, nthr, dims..., f) - multidimensional for loop for
 *                                     already created threads that also
 *                                     passes ithr and nthr
 *  - parallel_nd(dims..., f)        - creates a parallel section and then
 *                                     calls for_nd
 *  - parallel_nd_ext(nthr, dims..., f) - creates a parallel section and then
 *                                     calls for_nd_ext
 *  - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and
 *                                     then calls for_nd (mostly for
 *                                     convenience)
 * A usage sketch follows this comment.
 */
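
// A minimal usage sketch (illustrative only; written as if called from code
// inside namespace dnnl::impl, like the rest of the library). It adds two
// M x N row-major matrices with one closure invocation per element; the
// names `a`, `b`, `c`, `M`, and `N` are hypothetical:
//
//     void add_matrices(
//             const float *a, const float *b, float *c, dim_t M, dim_t N) {
//         parallel_nd(M, N, [&](dim_t m, dim_t n) {
//             c[m * N + n] = a[m * N + n] + b[m * N + n];
//         });
//     }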

/* general parallelization */
inline int adjust_num_threads(int nthr, dim_t work_amount) {
    if (nthr == 0) nthr = dnnl_get_current_num_threads();
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
    return (work_amount == 1 || omp_in_parallel()) ? 1 : nthr;
#else
    return (int)std::min((dim_t)nthr, work_amount);
#endif
}

inline void parallel(int nthr, const std::function<void(int, int)> &f) {
    nthr = adjust_num_threads(nthr, INT64_MAX);
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ
    for (int i = 0; i < nthr; ++i) {
        f(i, nthr);
    }
#else
#if defined(DNNL_ENABLE_ITT_TASKS)
    auto task_primitive_kind = itt::primitive_task_get_current_kind();
    bool itt_enable = itt::get_itt(itt::__itt_task_level_high);
#endif
    if (nthr == 1) {
        f(0, 1);
        return;
    }
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
#pragma omp parallel num_threads(nthr)
    {
        int nthr_ = omp_get_num_threads();
        int ithr_ = omp_get_thread_num();
        assert(nthr_ == nthr);
#if defined(DNNL_ENABLE_ITT_TASKS)
        if (ithr_ && itt_enable) itt::primitive_task_start(task_primitive_kind);
#endif
        f(ithr_, nthr_);
#if defined(DNNL_ENABLE_ITT_TASKS)
        if (ithr_ && itt_enable) itt::primitive_task_end();
#endif
    }
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB
    tbb::parallel_for(
            0, nthr,
            [&](int ithr) {
#if defined(DNNL_ENABLE_ITT_TASKS)
                bool mark_task = itt::primitive_task_get_current_kind()
                        == primitive_kind::undefined;
                if (mark_task && itt_enable)
                    itt::primitive_task_start(task_primitive_kind);
#endif
                f(ithr, nthr);
#if defined(DNNL_ENABLE_ITT_TASKS)
                if (mark_task && itt_enable) itt::primitive_task_end();
#endif
            },
            tbb::static_partitioner());
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
    using namespace dnnl::impl::threadpool_utils;
    dnnl::threadpool_interop::threadpool_iface *tp = get_active_threadpool();
    if (!tp || dnnl_in_parallel()) {
        threadpool_utils::deactivate_threadpool();
        for (int ithr = 0; ithr < nthr; ithr++) {
            f(ithr, nthr);
        }
        threadpool_utils::activate_threadpool(tp);
    } else {
        bool async = tp->get_flags()
                & dnnl::threadpool_interop::threadpool_iface::ASYNCHRONOUS;
        counting_barrier_t b;
        if (async) b.init(nthr);
        tp->parallel_for(nthr, [&, tp](int ithr, int nthr) {
            bool is_master = threadpool_utils::get_active_threadpool() == tp;
            if (!is_master) {
                threadpool_utils::activate_threadpool(tp);
#if defined(DNNL_ENABLE_ITT_TASKS)
                if (itt_enable) itt::primitive_task_start(task_primitive_kind);
#endif
            }
            f(ithr, nthr);
            if (!is_master) {
#if defined(DNNL_ENABLE_ITT_TASKS)
                if (itt_enable) itt::primitive_task_end();
#endif
                threadpool_utils::deactivate_threadpool();
            }
            if (async) b.notify();
        });
        if (async) b.wait();
    }
#endif
#endif
}

/* for_nd section */
inline void for_nd(const int ithr, const int nthr, dim_t D0,
        const std::function<void(dim_t)> &f) {
    dim_t start {0}, end {0};
    balance211(D0, nthr, ithr, start, end);
    for (dim_t d0 = start; d0 < end; ++d0)
        f(d0);
}
inline void for_nd(const int ithr, const int nthr, dim_t D0, dim_t D1,
        const std::function<void(dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1;
    if (work_amount == 0) return;
    dim_t start {0}, end {0};
    balance211(work_amount, nthr, ithr, start, end);

    dim_t d0 {0}, d1 {0};
    utils::nd_iterator_init(start, d0, D0, d1, D1);
    for (dim_t iwork = start; iwork < end; ++iwork) {
        f(d0, d1);
        utils::nd_iterator_step(d0, D0, d1, D1);
    }
}
inline void for_nd(const int ithr, const int nthr, dim_t D0, dim_t D1, dim_t D2,
        const std::function<void(dim_t, dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1 * D2;
    if (work_amount == 0) return;
    dim_t start {0}, end {0};
    balance211(work_amount, nthr, ithr, start, end);

    dim_t d0 {0}, d1 {0}, d2 {0};
    utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2);
    for (dim_t iwork = start; iwork < end; ++iwork) {
        f(d0, d1, d2);
        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
    }
}
inline void for_nd(const int ithr, const int nthr, dim_t D0, dim_t D1, dim_t D2,
        dim_t D3, const std::function<void(dim_t, dim_t, dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1 * D2 * D3;
    if (work_amount == 0) return;
    dim_t start {0}, end {0};
    balance211(work_amount, nthr, ithr, start, end);

    dim_t d0 {0}, d1 {0}, d2 {0}, d3 {0};
    utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3);
    for (dim_t iwork = start; iwork < end; ++iwork) {
        f(d0, d1, d2, d3);
        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
    }
}
inline void for_nd(const int ithr, const int nthr, dim_t D0, dim_t D1, dim_t D2,
        dim_t D3, dim_t D4,
        const std::function<void(dim_t, dim_t, dim_t, dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1 * D2 * D3 * D4;
    if (work_amount == 0) return;
    dim_t start {0}, end {0};
    balance211(work_amount, nthr, ithr, start, end);

    dim_t d0 {0}, d1 {0}, d2 {0}, d3 {0}, d4 {0};
    utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
    for (dim_t iwork = start; iwork < end; ++iwork) {
        f(d0, d1, d2, d3, d4);
        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
    }
}
inline void for_nd(const int ithr, const int nthr, dim_t D0, dim_t D1, dim_t D2,
        dim_t D3, dim_t D4, dim_t D5,
        const std::function<void(dim_t, dim_t, dim_t, dim_t, dim_t, dim_t)>
                &f) {
    const dim_t work_amount = D0 * D1 * D2 * D3 * D4 * D5;
    if (work_amount == 0) return;
    dim_t start {0}, end {0};
    balance211(work_amount, nthr, ithr, start, end);

    dim_t d0 {0}, d1 {0}, d2 {0}, d3 {0}, d4 {0}, d5 {0};
    utils::nd_iterator_init(
            start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
    for (dim_t iwork = start; iwork < end; ++iwork) {
        f(d0, d1, d2, d3, d4, d5);
        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
    }
}

/* for_nd_ext section */
inline void for_nd_ext(const int ithr, const int nthr, dim_t D0,
        const std::function<void(int, int, dim_t)> &f) {
    dim_t start {0}, end {0};
    balance211(D0, nthr, ithr, start, end);
    for (dim_t d0 = start; d0 < end; ++d0)
        f(ithr, nthr, d0);
}
inline void for_nd_ext(const int ithr, const int nthr, dim_t D0, dim_t D1,
        const std::function<void(int, int, dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1;
    if (work_amount == 0) return;
    dim_t start {0}, end {0};
    balance211(work_amount, nthr, ithr, start, end);

    dim_t d0 {0}, d1 {0};
    utils::nd_iterator_init(start, d0, D0, d1, D1);
    for (dim_t iwork = start; iwork < end; ++iwork) {
        f(ithr, nthr, d0, d1);
        utils::nd_iterator_step(d0, D0, d1, D1);
    }
}
inline void for_nd_ext(const int ithr, const int nthr, dim_t D0, dim_t D1,
        dim_t D2, const std::function<void(int, int, dim_t, dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1 * D2;
    if (work_amount == 0) return;
    dim_t start {0}, end {0};
    balance211(work_amount, nthr, ithr, start, end);

    dim_t d0 {0}, d1 {0}, d2 {0};
    utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2);
    for (dim_t iwork = start; iwork < end; ++iwork) {
        f(ithr, nthr, d0, d1, d2);
        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
    }
}
inline void for_nd_ext(const int ithr, const int nthr, dim_t D0, dim_t D1,
        dim_t D2, dim_t D3,
        const std::function<void(int, int, dim_t, dim_t, dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1 * D2 * D3;
    if (work_amount == 0) return;
    dim_t start {0}, end {0};
    balance211(work_amount, nthr, ithr, start, end);

    dim_t d0 {0}, d1 {0}, d2 {0}, d3 {0};
    utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3);
    for (dim_t iwork = start; iwork < end; ++iwork) {
        f(ithr, nthr, d0, d1, d2, d3);
        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
    }
}
inline void for_nd_ext(const int ithr, const int nthr, dim_t D0, dim_t D1,
        dim_t D2, dim_t D3, dim_t D4,
        const std::function<void(int, int, dim_t, dim_t, dim_t, dim_t, dim_t)>
                &f) {
    const dim_t work_amount = D0 * D1 * D2 * D3 * D4;
    if (work_amount == 0) return;
    dim_t start {0}, end {0};
    balance211(work_amount, nthr, ithr, start, end);

    dim_t d0 {0}, d1 {0}, d2 {0}, d3 {0}, d4 {0};
    utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
    for (dim_t iwork = start; iwork < end; ++iwork) {
        f(ithr, nthr, d0, d1, d2, d3, d4);
        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
    }
}
inline void for_nd_ext(const int ithr, const int nthr, dim_t D0, dim_t D1,
        dim_t D2, dim_t D3, dim_t D4, dim_t D5,
        const std::function<void(
                int, int, dim_t, dim_t, dim_t, dim_t, dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1 * D2 * D3 * D4 * D5;
    if (work_amount == 0) return;
    dim_t start {0}, end {0};
    balance211(work_amount, nthr, ithr, start, end);

    dim_t d0 {0}, d1 {0}, d2 {0}, d3 {0}, d4 {0}, d5 {0};
    utils::nd_iterator_init(
            start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
    for (dim_t iwork = start; iwork < end; ++iwork) {
        f(ithr, nthr, d0, d1, d2, d3, d4, d5);
        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
    }
}

/* parallel_nd_ext section */
inline void parallel_nd_ext(
        int nthr, dim_t D0, const std::function<void(int, int, dim_t)> &f) {
    const dim_t work_amount = D0;
    nthr = adjust_num_threads(nthr, work_amount);
    if (nthr)
        parallel(nthr,
                [&](int ithr, int nthr) { for_nd_ext(ithr, nthr, D0, f); });
}
inline void parallel_nd_ext(int nthr, dim_t D0, dim_t D1,
        const std::function<void(int, int, dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1;
    nthr = adjust_num_threads(nthr, work_amount);
    if (nthr)
        parallel(nthr,
                [&](int ithr, int nthr) { for_nd_ext(ithr, nthr, D0, D1, f); });
}
inline void parallel_nd_ext(int nthr, dim_t D0, dim_t D1, dim_t D2,
        const std::function<void(int, int, dim_t, dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1 * D2;
    nthr = adjust_num_threads(nthr, work_amount);
    if (nthr)
        parallel(nthr, [&](int ithr, int nthr) {
            for_nd_ext(ithr, nthr, D0, D1, D2, f);
        });
}
inline void parallel_nd_ext(int nthr, dim_t D0, dim_t D1, dim_t D2, dim_t D3,
        const std::function<void(int, int, dim_t, dim_t, dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1 * D2 * D3;
    nthr = adjust_num_threads(nthr, work_amount);
    if (nthr)
        parallel(nthr, [&](int ithr, int nthr) {
            for_nd_ext(ithr, nthr, D0, D1, D2, D3, f);
        });
}
inline void parallel_nd_ext(int nthr, dim_t D0, dim_t D1, dim_t D2, dim_t D3,
        dim_t D4,
        const std::function<void(int, int, dim_t, dim_t, dim_t, dim_t, dim_t)>
                &f) {
    const dim_t work_amount = D0 * D1 * D2 * D3 * D4;
    nthr = adjust_num_threads(nthr, work_amount);
    if (nthr)
        parallel(nthr, [&](int ithr, int nthr) {
            for_nd_ext(ithr, nthr, D0, D1, D2, D3, D4, f);
        });
}
inline void parallel_nd_ext(int nthr, dim_t D0, dim_t D1, dim_t D2, dim_t D3,
        dim_t D4, dim_t D5,
        const std::function<void(
                int, int, dim_t, dim_t, dim_t, dim_t, dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1 * D2 * D3 * D4 * D5;
    nthr = adjust_num_threads(nthr, work_amount);
    if (nthr)
        parallel(nthr, [&](int ithr, int nthr) {
            for_nd_ext(ithr, nthr, D0, D1, D2, D3, D4, D5, f);
        });
}

/* parallel_nd section */
inline void parallel_nd(dim_t D0, const std::function<void(dim_t)> &f) {
    int nthr = adjust_num_threads(dnnl_get_current_num_threads(), D0);
    if (nthr)
        parallel(nthr, [&](int ithr, int nthr) { for_nd(ithr, nthr, D0, f); });
}
inline void parallel_nd(
        dim_t D0, dim_t D1, const std::function<void(dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1;
    int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount);
    if (nthr)
        parallel(nthr,
                [&](int ithr, int nthr) { for_nd(ithr, nthr, D0, D1, f); });
}
inline void parallel_nd(dim_t D0, dim_t D1, dim_t D2,
        const std::function<void(dim_t, dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1 * D2;
    int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount);
    if (nthr)
        parallel(nthr,
                [&](int ithr, int nthr) { for_nd(ithr, nthr, D0, D1, D2, f); });
}
inline void parallel_nd(dim_t D0, dim_t D1, dim_t D2, dim_t D3,
        const std::function<void(dim_t, dim_t, dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1 * D2 * D3;
    int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount);
    if (nthr)
        parallel(nthr, [&](int ithr, int nthr) {
            for_nd(ithr, nthr, D0, D1, D2, D3, f);
        });
}
inline void parallel_nd(dim_t D0, dim_t D1, dim_t D2, dim_t D3, dim_t D4,
        const std::function<void(dim_t, dim_t, dim_t, dim_t, dim_t)> &f) {
    const dim_t work_amount = D0 * D1 * D2 * D3 * D4;
    int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount);
    if (nthr)
        parallel(nthr, [&](int ithr, int nthr) {
            for_nd(ithr, nthr, D0, D1, D2, D3, D4, f);
        });
}
inline void parallel_nd(dim_t D0, dim_t D1, dim_t D2, dim_t D3, dim_t D4,
        dim_t D5,
        const std::function<void(dim_t, dim_t, dim_t, dim_t, dim_t, dim_t)>
                &f) {
    const dim_t work_amount = D0 * D1 * D2 * D3 * D4 * D5;
    int nthr = adjust_num_threads(dnnl_get_current_num_threads(), work_amount);
    if (nthr)
        parallel(nthr, [&](int ithr, int nthr) {
            for_nd(ithr, nthr, D0, D1, D2, D3, D4, D5, f);
        });
}
/* parallel_nd_in_omp section */

template <typename... Args>
void parallel_nd_in_omp(Args &&... args) {
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ
    for_nd(0, 1, utils::forward<Args>(args)...);
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
    for_nd(omp_get_thread_num(), omp_get_num_threads(),
            utils::forward<Args>(args)...);
#elif (DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB \
        || DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL)
    assert(!"parallel_nd_in_omp() is not supported by this DNNL_CPU_RUNTIME");
#endif
}
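
// A minimal usage sketch (illustrative only; meaningful under the OpenMP
// runtime): parallel_nd_in_omp() is intended to be called from inside an
// already open parallel region, e.g.:
//
//     PRAGMA_OMP(parallel)
//     {
//         // ... per-thread setup could go here ...
//         parallel_nd_in_omp(D0, D1, [&](dim_t d0, dim_t d1) {
//             // process element (d0, d1)
//         });
//     }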

} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s