1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef TEST_THREAD_HPP
18#define TEST_THREAD_HPP
19
#include <cstdio>
#include <cstdlib>
#include <iostream>

#include "oneapi/dnnl/dnnl_config.h"
23
// This header has to be the one that pulls in "src/common/dnnl_thread.hpp"
// (it includes it further down with `threadpool_utils` renamed), so an
// earlier inclusion of that header is rejected here.
#ifdef COMMON_DNNL_THREAD_HPP
#error "src/common/dnnl_thread.hpp" was already included
#endif

// Build without a CPU runtime: the threading runtime is expected to be SEQ.
// It is overridden below for the part of this header that includes
// "src/common/dnnl_thread.hpp" and is set back to SEQ right after that
// inclusion.
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_NONE

#if DNNL_CPU_THREADING_RUNTIME != DNNL_RUNTIME_SEQ
#error "DNNL_CPU_THREADING_RUNTIME is expected to be SEQ for GPU only configurations."
#endif

#undef DNNL_CPU_THREADING_RUNTIME

// Enable CPU threading layer for testing:
// - DPCPP: TBB
// - OCL: OpenMP
// NOTE(review): for any other DNNL_GPU_RUNTIME value the macro stays
// undefined after the #undef above - confirm no such configuration exists.
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_TBB
#elif DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_OMP
#endif

#endif
46
47// Here we define some types in global namespace to handle customized
48// threading context for creation and execution
// Plain aggregate describing a threading context. It is brace-initialized
// elsewhere in this header, so it must stay an aggregate (no constructors).
struct thr_ctx_t {
    // Number of threads to use.
    int max_concurrency;
    // Type of cores to run on (only honored by the TBB code path).
    int core_type;
    // Threads per core (only honored by the TBB code path).
    int nthr_per_core;

    // Two contexts are equal when all three knobs match.
    bool operator==(const thr_ctx_t &rhs) const {
        if (max_concurrency != rhs.max_concurrency) return false;
        if (core_type != rhs.core_type) return false;
        return nthr_per_core == rhs.nthr_per_core;
    }
    bool operator!=(const thr_ctx_t &rhs) const { return !operator==(rhs); }

    // Declared only; the definition is not part of this header.
    void *get_interop_obj() const;
};
62
63// tbb constraints on core type appear in 2021.2
64// tbb constraints on max_concurrency appear in 2020
65// we check only for 2021.2 to enable thread context knobs
// DNNL_TBB_CONSTRAINTS_ENABLED is non-zero when the TBB headers seen so far
// report an interface version of at least 12020.
// NOTE(review): TBB_INTERFACE_VERSION is only visible here if a TBB header
// was included before this point; nothing above in this header includes one
// directly - confirm that the includers (or dnnl_config.h) provide it,
// otherwise the "with constraints" path is never selected.
#ifdef TBB_INTERFACE_VERSION
#define DNNL_TBB_CONSTRAINTS_ENABLED (TBB_INTERFACE_VERSION >= 12020)
#else
#define DNNL_TBB_CONSTRAINTS_ENABLED 0
#endif

// TBB threading runtime with / without support for arena constraints.
// Both expansions are fully parenthesized so that they stay a single operand
// when combined with other operators (`!X`, `X == y`, `a || X`, ...).
#define DNNL_TBB_THREADING_WITH_CONSTRAINTS \
    ((DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB) \
            && DNNL_TBB_CONSTRAINTS_ENABLED)
#define DNNL_TBB_THREADING_WITHOUT_CONSTRAINTS \
    ((DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_TBB) \
            && !DNNL_TBB_CONSTRAINTS_ENABLED)
78
// Default threading context, i.e. the value of the knobs when nothing is
// customized. Initializer order is {max_concurrency, core_type,
// nthr_per_core}. The run-in-context helpers further down compare the
// requested context against this object to reject unsupported knobs.
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ \
        || DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL \
        || DNNL_TBB_THREADING_WITHOUT_CONSTRAINTS
const thr_ctx_t default_thr_ctx = {0, -1, 0};
#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
#include "omp.h"
// The OpenMP default captures the thread limit reported when this object is
// initialized.
const thr_ctx_t default_thr_ctx = {omp_get_max_threads(), -1, 0};
#elif DNNL_TBB_THREADING_WITH_CONSTRAINTS
#include "oneapi/tbb/task_arena.h"
// With TBB constraints every knob defaults to "automatic".
const thr_ctx_t default_thr_ctx = {tbb::task_arena::automatic,
        tbb::task_arena::automatic, tbb::task_arena::automatic};
#endif
91
// Stream insertion for thr_ctx_t. Declaration only: the definition is not
// part of this header.
std::ostream &operator<<(std::ostream &os, const thr_ctx_t &ctx);
93
// Terminates the process with exit code 1 after printing a printf-style
// diagnostic to stderr when `check` evaluates to false.
//   check   - condition that must hold; evaluated exactly once.
//   msg_fmt - printf-style format string.
//   ...     - format arguments; at least one is required because
//             __VA_ARGS__ is expanded after a comma.
// The do/while(0) wrapper makes the expansion a single statement, so the
// macro is safe in unbraced if/else. The calls are std::-qualified and rely
// on <cstdio> / <cstdlib> rather than on transitive includes.
#define THR_CTX_ASSERT(check, msg_fmt, ...) \
    do { \
        if (!(check)) { \
            std::fprintf(stderr, msg_fmt, __VA_ARGS__); \
            std::exit(1); \
        } \
    } while (0)
101
// This hack renames the namespaces used by threading functions for
// threadpool-related functions so that the calls to dnnl::impl::parallel*()
// from the test use a special testing threadpool.
//
// At the same time, the calls to dnnl::impl::parallel*() from within the
// library continue using the library version of these functions.
//
// Mechanically: while the header below is being included, every
// `threadpool_utils` token is macro-replaced by `testing_threadpool_utils`;
// the macro is removed again right after the inclusion.
#define threadpool_utils testing_threadpool_utils
#include "src/common/dnnl_thread.hpp"
#undef threadpool_utils

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_NONE
// Restore the original DNNL_CPU_THREADING_RUNTIME value.
// (SEQ is the only value accepted for this configuration by the check at
// the top of this header.)
#undef DNNL_CPU_THREADING_RUNTIME
#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_SEQ
#endif

// Sanity check: the include-order guard at the top of this header relies on
// this exact header-guard macro name.
#ifndef COMMON_DNNL_THREAD_HPP
#error "src/common/dnnl_thread.hpp" has an unexpected header guard
#endif
121
122#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
123#include "oneapi/dnnl/dnnl_threadpool_iface.hpp"
124namespace dnnl {
125
126// Original threadpool utils are used by the scoped_tp_activation_t and thus
127// need to be re-declared because of the hack above.
128namespace impl {
129namespace threadpool_utils {
130void activate_threadpool(dnnl::threadpool_interop::threadpool_iface *tp);
131void deactivate_threadpool();
132dnnl::threadpool_interop::threadpool_iface *get_active_threadpool();
133int get_max_concurrency();
134} // namespace threadpool_utils
135} // namespace impl
136
137namespace testing {
138
139dnnl::threadpool_interop::threadpool_iface *get_threadpool(
140 const thr_ctx_t &ctx = default_thr_ctx);
141
142// Sets the testing threadpool as active for the lifetime of the object.
143// Required for the tests that throw to work.
144struct scoped_tp_activation_t {
145 scoped_tp_activation_t(dnnl::threadpool_interop::threadpool_iface *tp_
146 = get_threadpool()) {
147 impl::threadpool_utils::activate_threadpool(tp_);
148 }
149 ~scoped_tp_activation_t() {
150 impl::threadpool_utils::deactivate_threadpool();
151 }
152};
153
154struct scoped_tp_deactivation_t {
155 scoped_tp_deactivation_t() {
156 impl::threadpool_utils::deactivate_threadpool();
157 }
158 ~scoped_tp_deactivation_t() {
159 // we always use the same threadpool that is returned by `get_threadpool()`
160 impl::threadpool_utils::activate_threadpool(get_threadpool());
161 }
162};
163
164} // namespace testing
165} // namespace dnnl
166#endif
167
168// These are free functions to allow running a function in a given threading context.
169// A threading context is defined by:
170// - number of threads
171// - type of cores (TBB only)
172// - threads per core (TBB only)
173
// Note: we have to differentiate creation and execution in thread
// context because of threadpool as it uses different mechanisms in
// both (in execution, tp is passed in stream)
177
178#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ \
179 || DNNL_TBB_THREADING_WITHOUT_CONSTRAINTS
180
181#define RUN_IN_THR_CTX(name) \
182 template <typename F, typename... Args_t> \
183 auto name(const thr_ctx_t &ctx, F &&f, Args_t &... args) \
184 ->decltype(f(args...)) { \
185\
186 THR_CTX_ASSERT(ctx.core_type == default_thr_ctx.core_type \
187 && ctx.max_concurrency \
188 == default_thr_ctx.max_concurrency \
189 && ctx.nthr_per_core == default_thr_ctx.nthr_per_core, \
190 "Threading knobs not supported for this runtime: %s\n", \
191 DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_SEQ \
192 ? "sequential runtime has no threading" \
193 : "TBB version is too old (>=2021.2 required)"); \
194\
195 return f(args...); \
196 }
197
198RUN_IN_THR_CTX(create_in_thr_ctx)
199RUN_IN_THR_CTX(execute_in_thr_ctx)
200#undef RUN_IN_THR_CTX
201
202#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
203#define RUN_IN_THR_CTX(name) \
204 template <typename F, typename... Args_t> \
205 auto name(const thr_ctx_t &ctx, F &&f, Args_t &... args) \
206 ->decltype(f(args...)) { \
207\
208 THR_CTX_ASSERT(ctx.core_type == default_thr_ctx.core_type, \
209 "core type %d is not supported for OMP runtime\n", \
210 ctx.core_type); \
211\
212 auto max_nthr = omp_get_max_threads(); \
213 omp_set_num_threads(ctx.max_concurrency); \
214 auto st = f(args...); \
215 omp_set_num_threads(max_nthr); \
216 return st; \
217 }
218
219RUN_IN_THR_CTX(create_in_thr_ctx)
220RUN_IN_THR_CTX(execute_in_thr_ctx)
221#undef RUN_IN_THR_CTX
222
223#elif DNNL_TBB_THREADING_WITH_CONSTRAINTS
224#include "oneapi/tbb/info.h"
225#define RUN_IN_THR_CTX(name) \
226 template <typename F, typename... Args_t> \
227 auto name(const thr_ctx_t &ctx, F &&f, Args_t &... args) \
228 ->decltype(f(args...)) { \
229 static auto core_types = tbb::info:: \
230 core_types(); /* sorted by the relative strength */ \
231\
232 if ((ctx.core_type != default_thr_ctx.core_type) \
233 && (ctx.core_type >= core_types.size())) \
234 printf("WARNING: TBB smallest core has index %lu. Using this " \
235 "instead of %d.\n", \
236 core_types.size() - 1, ctx.core_type); \
237 size_t core_type_id = ctx.core_type < core_types.size() \
238 ? ctx.core_type \
239 : core_types.size() - 1; \
240 static auto core_type = ctx.core_type == tbb::task_arena::automatic \
241 ? tbb::task_arena::automatic \
242 : core_types[core_type_id]; \
243 static auto arena = tbb::task_arena { \
244 tbb::task_arena::constraints {} \
245 .set_core_type(core_type) \
246 .set_max_threads_per_core(ctx.nthr_per_core) \
247 .set_max_concurrency(ctx.max_concurrency)}; \
248 return arena.execute([&] { return f(args...); }); \
249 }
250
251RUN_IN_THR_CTX(create_in_thr_ctx)
252RUN_IN_THR_CTX(execute_in_thr_ctx)
253#undef RUN_IN_THR_CTX
254
255#elif DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
256template <typename F, typename... Args_t>
257auto create_in_thr_ctx(const thr_ctx_t &ctx, F &&f, Args_t &... args)
258 -> decltype(f(args...)) {
259 THR_CTX_ASSERT(ctx.core_type == default_thr_ctx.core_type,
260 "core type %d is not supported for TP runtime\n", ctx.core_type);
261
262 auto tp = dnnl::testing::get_threadpool(ctx);
263 auto stp = dnnl::testing::scoped_tp_activation_t(tp);
264 return f(args...);
265}
266
267// The function f shall take an interop obj as last argument
268template <typename F, typename... Args_t>
269auto execute_in_thr_ctx(const thr_ctx_t &ctx, F &&f, Args_t &... args)
270 -> decltype(f(args...)) {
271 THR_CTX_ASSERT(ctx.core_type == default_thr_ctx.core_type,
272 "core type %d is not supported for TP runtime\n", ctx.core_type);
273 return f(args...);
274}
275
276#else
277#error __FILE__"(" __LINE__ ")" "unsupported threading runtime!"
278#endif
279
// Remove the helper macros defined by this header so that they do not leak
// into the files that include it.
// NOTE(review): ALIAS_TO_RUN_IN_THR_CTX is never defined in this header;
// this #undef looks like a leftover - confirm before removing it.
#undef ALIAS_TO_RUN_IN_THR_CTX
#undef THR_CTX_ASSERT
#undef DNNL_TBB_THREADING_WITHOUT_CONSTRAINTS
#undef DNNL_TBB_THREADING_WITH_CONSTRAINTS
#undef DNNL_TBB_CONSTRAINTS_ENABLED
285
286#endif
287