/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <tuple>

#include "tests/test_thread.hpp"

std::ostream &operator<<(std::ostream &os, const thr_ctx_t &ctx) {
    if (ctx.max_concurrency == default_thr_ctx.max_concurrency)
        os << "auto:";
    else
        os << ctx.max_concurrency << ":";

    if (ctx.core_type == default_thr_ctx.core_type)
        os << "auto:";
    else
        os << ctx.core_type << ":";

    if (ctx.nthr_per_core == default_thr_ctx.nthr_per_core)
        os << "auto";
    else
        os << ctx.nthr_per_core;

    return os;
}
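// The printed form is "<max_concurrency>:<core_type>:<nthr_per_core>", with
// "auto" for any field left at its default value, e.g. a context that only
// sets max_concurrency = 4 prints as "4:auto:auto".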

#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
void *thr_ctx_t::get_interop_obj() const {
    return dnnl::testing::get_threadpool(*this);
}
#else
void *thr_ctx_t::get_interop_obj() const {
    return nullptr;
}
#endif
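// Note: for non-threadpool runtimes the context has no threadpool object to
// expose, hence the nullptr above.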

#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL

#include <mutex>
#include <unordered_map>

#ifdef _WIN32
#include <windows.h>
#else
#include <stdlib.h>
#endif

#include "oneapi/dnnl/dnnl_threadpool_iface.hpp"
#include "src/common/counting_barrier.hpp"

#if !defined(DNNL_TEST_THREADPOOL_USE_TBB)

#include "src/cpu/platform.hpp"
namespace dnnl {
namespace testing {
namespace {
inline int read_num_threads_from_env() {
    const char *env_num_threads = nullptr;
    const char *env_var_name = "OMP_NUM_THREADS";
#ifdef _WIN32
    // This is only required to avoid using _CRT_SECURE_NO_WARNINGS
    const size_t buf_size = 12;
    char buf[buf_size];
    size_t val_size = GetEnvironmentVariable(env_var_name, buf, buf_size);
    if (val_size > 0 && val_size < buf_size) env_num_threads = buf;
#else // ifdef _WIN32
    env_num_threads = ::getenv(env_var_name);
#endif

    int num_threads = 0;
    if (env_num_threads) {
        char *endp;
        int nt = strtol(env_num_threads, &endp, 10);
        if (*endp == '\0') num_threads = nt;
    }
    if (num_threads <= 0) {
        num_threads = (int)dnnl::impl::cpu::platform::get_max_threads_to_use();
    }
    return num_threads;
}
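// The thread count comes from OMP_NUM_THREADS when it is set to a plain
// positive integer (e.g. OMP_NUM_THREADS=8 -> 8); otherwise it falls back to
// platform::get_max_threads_to_use().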
} // namespace
} // namespace testing
} // namespace dnnl
#endif // !defined(DNNL_TEST_THREADPOOL_USE_TBB)

#if defined(DNNL_TEST_THREADPOOL_USE_EIGEN)

#include <memory>
#include "Eigen/Core"
#include "unsupported/Eigen/CXX11/ThreadPool"

#if EIGEN_WORLD_VERSION + 10 * EIGEN_MAJOR_VERSION < 33
#define STR_(x) #x
#define STR(x) STR_(x)
#pragma message("EIGEN_WORLD_VERSION " STR(EIGEN_WORLD_VERSION))
#pragma message("EIGEN_MAJOR_VERSION " STR(EIGEN_MAJOR_VERSION))
#error Unsupported Eigen version (need 3.3.x or higher)
#endif

#if EIGEN_MINOR_VERSION >= 90
using EigenThreadPool = Eigen::ThreadPool;
#else
using EigenThreadPool = Eigen::NonBlockingThreadPool;
#endif

namespace dnnl {
namespace testing {

class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
private:
    std::unique_ptr<EigenThreadPool> tp_;

public:
    explicit threadpool_t(int num_threads = 0) {
        if (num_threads <= 0) num_threads = read_num_threads_from_env();
        tp_.reset(new EigenThreadPool(num_threads));
    }
    int get_num_threads() const override { return tp_->NumThreads(); }
    bool get_in_parallel() const override {
        return tp_->CurrentThreadId() != -1;
    }
    uint64_t get_flags() const override { return ASYNCHRONOUS; }
    void parallel_for(int n, const std::function<void(int, int)> &fn) override {
        int nthr = get_num_threads();
        int njobs = std::min(n, nthr);

        for (int i = 0; i < njobs; i++) {
            tp_->Schedule([i, n, njobs, fn]() {
                int start, end;
                impl::balance211(n, njobs, i, start, end);
                for (int j = start; j < end; j++)
                    fn(j, n);
            });
        }
    }
};
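// In parallel_for above, impl::balance211(n, njobs, i, start, end) hands
// worker i a contiguous chunk of the iteration space; e.g. n = 10 jobs over
// njobs = 4 workers yields the ranges [0,3), [3,6), [6,8) and [8,10).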

} // namespace testing
} // namespace dnnl

#elif defined(DNNL_TEST_THREADPOOL_USE_TBB)
#include "tbb/parallel_for.h"
#include "tbb/task_arena.h"

namespace dnnl {
namespace testing {

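// TBB-based threadpool: the thread count is whatever the enclosing
// tbb::task_arena provides, so the num_threads constructor argument is
// ignored; work partitioning is delegated to tbb::parallel_for with a
// static_partitioner.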
class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
public:
    explicit threadpool_t(int num_threads) { (void)num_threads; }
    int get_num_threads() const override {
        return tbb::this_task_arena::max_concurrency();
    }
    bool get_in_parallel() const override { return false; }
    uint64_t get_flags() const override { return 0; }
    void parallel_for(int n, const std::function<void(int, int)> &fn) override {
        tbb::parallel_for(
                0, n, [&](int i) { fn(i, n); }, tbb::static_partitioner());
    }
};

} // namespace testing
} // namespace dnnl

#else

#include <atomic>
#include <thread>
#include <vector>
#include <condition_variable>

namespace dnnl {
namespace testing {

// Naive synchronous threadpool:
// - Only a single parallel_for is executed at a time.
// - Recursive parallel_for results in sequential execution.
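// Synchronization protocol used below (one parallel_for at a time):
// - The master thread initializes the counting barrier, publishes the task in
//   tasks_[master_sense_] and sets its go_flag.
// - Workers wake up on the matching sense, execute their balance211 chunk and
//   notify the barrier.
// - The master waits on the barrier, clears the go_flag and flips
//   master_sense_ so the next submission uses the other task slot.
// - A null task function is the shutdown signal sent from the destructor.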
class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
public:
    using task_func = std::function<void(int, int)>;

    explicit threadpool_t(int num_threads = 0) {
        if (num_threads <= 0) num_threads = read_num_threads_from_env();
        num_threads_ = num_threads;
        master_sense_ = 0;

        for (int i = 0; i < 2; i++) {
            tasks_[i].go_flag.store(0);
            tasks_[i].fn = nullptr;
            tasks_[i].n = 0;
        }

        barrier_init();
        workers_.reset(new std::vector<worker_data>(num_threads_));
        for (int i = 0; i < num_threads_; i++) {
            auto wd = &workers_->at(i);
            wd->thread_id = i;
            wd->tp = this;
            wd->thread.reset(new std::thread(worker_loop, &workers_->at(i)));
        }
        barrier_wait();
    }

    virtual ~threadpool_t() {
        std::unique_lock<std::mutex> l(master_mutex_);
        barrier_init();
        task_submit(nullptr, 0);
        for (int i = 0; i < num_threads_; i++)
            workers_->at(i).thread->join();
        barrier_wait();
    }

    virtual int get_num_threads() const { return num_threads_; }

    virtual bool get_in_parallel() const { return worker_self() != nullptr; }

    virtual uint64_t get_flags() const { return 0; }

    virtual void parallel_for(int n, const task_func &fn) {
        if (worker_self() != nullptr)
            task_execute(0, 1, &fn, n);
        else {
            std::unique_lock<std::mutex> l(master_mutex_);
            barrier_init();
            task_submit(&fn, n);
            barrier_wait();
        }
    }

private:
    int num_threads_;
    std::mutex master_mutex_;
    std::mutex master_submit_mutex_;

    struct worker_data {
        int thread_id;
        threadpool_t *tp;
        std::condition_variable cv;
        std::unique_ptr<std::thread> thread;
    };
    std::unique_ptr<std::vector<worker_data>> workers_;
    static thread_local worker_data *worker_self_;
    worker_data *worker_self() const {
        return worker_self_ != nullptr && worker_self_->tp == this
                ? worker_self_
                : nullptr;
    }

    struct task_data {
        std::atomic<int> go_flag;
        const task_func *fn;
        int n;
    };
    int master_sense_;
    task_data tasks_[2];

    dnnl::impl::counting_barrier_t barrier_;

    void barrier_init() { barrier_.init(num_threads_); }

    void barrier_wait() {
        barrier_.wait();
        tasks_[master_sense_].go_flag.store(0);
        master_sense_ = !master_sense_;
    }

    void barrier_notify(int worker_sense) { barrier_.notify(); }

    void task_submit(const task_func *fn, int n) {
        std::lock_guard<std::mutex> l(master_submit_mutex_);
        tasks_[master_sense_].fn = fn;
        tasks_[master_sense_].n = n;
        tasks_[master_sense_].go_flag.store(1);
        for (int i = 0; i < num_threads_; i++) {
            workers_->at(i).cv.notify_one();
        }
    }

    void task_execute(int ithr, int nthr, const task_func *fn, int n) {
        if (fn != nullptr && n > 0) {
            int start, end;
            impl::balance211(n, nthr, ithr, start, end);
            for (int i = start; i < end; i++)
                (*fn)(i, n);
        }
    }

    static void worker_loop(worker_data *wd) {
        worker_self_ = wd;
        int worker_sense = 0;

        wd->tp->barrier_notify(worker_sense);

        bool time_to_exit = false;
        std::unique_lock<std::mutex> l(wd->tp->master_submit_mutex_);

        do {
            worker_sense = !worker_sense;
            auto *t = &wd->tp->tasks_[worker_sense];
            wd->tp->workers_->at(wd->thread_id).cv.wait(l, [t]() {
                return t->go_flag.load() != 0;
            });
            wd->tp->task_execute(
                    wd->thread_id, wd->tp->num_threads_, t->fn, t->n);
            time_to_exit = t->fn == nullptr;
            wd->tp->barrier_notify(worker_sense);
        } while (!time_to_exit);
    }
};

thread_local threadpool_t::worker_data *threadpool_t::worker_self_ = nullptr;

} // namespace testing
} // namespace dnnl
#endif

namespace dnnl {

namespace testing {
// Threadpool singleton
dnnl::threadpool_interop::threadpool_iface *get_threadpool(
        const thr_ctx_t &ctx) {
    // The global default threadpool is returned when the thread context is
    // the default one.
    static std::unordered_map<int, dnnl::testing::threadpool_t> tp_map;
    auto ret_val = tp_map.find(ctx.max_concurrency);
    if (ret_val != tp_map.end()) return &(ret_val->second);
    auto res = tp_map.emplace(std::piecewise_construct,
            std::forward_as_tuple(ctx.max_concurrency),
            std::forward_as_tuple(ctx.max_concurrency));
    if (!res.second) {
        fprintf(stderr, "get_threadpool failed to create a threadpool\n");
        exit(1);
    }
    return &(res.first->second);
}
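// Illustrative usage (a sketch, not part of this translation unit): a test
// would typically fetch the cached pool for its context and attach it to a
// stream, e.g.
//     auto *tp = dnnl::testing::get_threadpool(ctx);
//     auto stream = dnnl::threadpool_interop::make_stream(engine, tp);
// where `ctx` is a thr_ctx_t and `engine` is an already created CPU engine.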

} // namespace testing

// Implement a dummy threadpool_utils protocol here so that it is picked up
// by parallel*() calls from the tests.
namespace impl {
namespace testing_threadpool_utils {
void activate_threadpool(dnnl::threadpool_interop::threadpool_iface *tp) {}
void deactivate_threadpool() {}
dnnl::threadpool_interop::threadpool_iface *get_active_threadpool() {
    return testing::get_threadpool();
}

// Here we return 0 so that parallel*() calls use the default number of
// threads in the threadpool.
int get_max_concurrency() {
    return 0;
}

} // namespace testing_threadpool_utils

} // namespace impl
} // namespace dnnl

#endif