1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include <tuple> |
17 | |
18 | #include "tests/test_thread.hpp" |
19 | |
20 | std::ostream &operator<<(std::ostream &os, const thr_ctx_t &ctx) { |
21 | if (ctx.max_concurrency == default_thr_ctx.max_concurrency) |
22 | os << "auto:" ; |
23 | else |
24 | os << ctx.max_concurrency << ":" ; |
25 | |
26 | if (ctx.core_type == default_thr_ctx.core_type) |
27 | os << "auto:" ; |
28 | else |
29 | os << ctx.core_type << ":" ; |
30 | |
31 | if (ctx.nthr_per_core == default_thr_ctx.nthr_per_core) |
32 | os << "auto" ; |
33 | else |
34 | os << ctx.nthr_per_core; |
35 | |
36 | return os; |
37 | } |
38 | |
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL
// Returns the threadpool instance associated with this context, suitable for
// passing through oneDNN's threadpool interop API.
void *thr_ctx_t::get_interop_obj() const {
    return dnnl::testing::get_threadpool(*this);
}
#else
// Non-threadpool runtimes (OMP/TBB/sequential) expose no interop object.
void *thr_ctx_t::get_interop_obj() const {
    return nullptr;
}
#endif
48 | |
49 | #if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_THREADPOOL |
50 | |
51 | #include <mutex> |
52 | #include <unordered_map> |
53 | |
54 | #ifdef _WIN32 |
55 | #include <windows.h> |
56 | #else |
57 | #include <stdlib.h> |
58 | #endif |
59 | |
60 | #include "oneapi/dnnl/dnnl_threadpool_iface.hpp" |
61 | #include "src/common/counting_barrier.hpp" |
62 | |
63 | #if !defined(DNNL_TEST_THREADPOOL_USE_TBB) |
64 | |
65 | #include "src/cpu/platform.hpp" |
66 | namespace dnnl { |
67 | namespace testing { |
68 | namespace { |
// Reads the requested thread count from the OMP_NUM_THREADS environment
// variable. Falls back to the platform's maximum usable thread count when the
// variable is unset, does not parse as a whole number, or is non-positive.
inline int read_num_threads_from_env() {
    const char *env_num_threads = nullptr;
    const char *env_var_name = "OMP_NUM_THREADS" ;
#ifdef _WIN32
    // Use GetEnvironmentVariable instead of getenv; this is only required to
    // avoid using _CRT_SECURE_NO_WARNINGS.
    const size_t buf_size = 12;
    char buf[buf_size];
    // val_size == 0 means the variable is unset; val_size >= buf_size means
    // the value was truncated. Both cases are treated as "not set".
    size_t val_size = GetEnvironmentVariable(env_var_name, buf, buf_size);
    if (val_size > 0 && val_size < buf_size) env_num_threads = buf;
#else // ifdef _WIN32
    env_num_threads = ::getenv(env_var_name);
#endif

    int num_threads = 0;
    if (env_num_threads) {
        char *endp;
        int nt = strtol(env_num_threads, &endp, 10);
        // Accept the value only if the entire string parsed as a number.
        if (*endp == '\0') num_threads = nt;
    }
    if (num_threads <= 0) {
        num_threads = (int)dnnl::impl::cpu::platform::get_max_threads_to_use();
    }
    return num_threads;
}
93 | } // namespace |
94 | } // namespace testing |
95 | } // namespace dnnl |
96 | #endif // !defined(DNNL_TEST_THREADPOOL_USE_TBB) |
97 | |
98 | #if defined(DNNL_TEST_THREADPOOL_USE_EIGEN) |
99 | |
100 | #include <memory> |
101 | #include "Eigen/Core" |
102 | #include "unsupported/Eigen/CXX11/ThreadPool" |
103 | |
104 | #if EIGEN_WORLD_VERSION + 10 * EIGEN_MAJOR_VERSION < 33 |
105 | #define STR_(x) #x |
106 | #define STR(x) STR_(x) |
107 | #pragma message("EIGEN_WORLD_VERSION " STR(EIGEN_WORLD_VERSION)) |
108 | #pragma message("EIGEN_MAJOR_VERSION " STR(EIGEN_MAJOR_VERSION)) |
109 | #error Unsupported Eigen version (need 3.3.x or higher) |
110 | #endif |
111 | |
112 | #if EIGEN_MINOR_VERSION >= 90 |
113 | using EigenThreadPool = Eigen::ThreadPool; |
114 | #else |
115 | using EigenThreadPool = Eigen::NonBlockingThreadPool; |
116 | #endif |
117 | |
118 | namespace dnnl { |
119 | namespace testing { |
120 | |
121 | class threadpool_t : public dnnl::threadpool_interop::threadpool_iface { |
122 | private: |
123 | std::unique_ptr<EigenThreadPool> tp_; |
124 | |
125 | public: |
126 | explicit threadpool_t(int num_threads = 0) { |
127 | if (num_threads <= 0) num_threads = read_num_threads_from_env(); |
128 | tp_.reset(new EigenThreadPool(num_threads)); |
129 | } |
130 | int get_num_threads() const override { return tp_->NumThreads(); } |
131 | bool get_in_parallel() const override { |
132 | return tp_->CurrentThreadId() != -1; |
133 | } |
134 | uint64_t get_flags() const override { return ASYNCHRONOUS; } |
135 | void parallel_for(int n, const std::function<void(int, int)> &fn) override { |
136 | int nthr = get_num_threads(); |
137 | int njobs = std::min(n, nthr); |
138 | |
139 | for (int i = 0; i < njobs; i++) { |
140 | tp_->Schedule([i, n, njobs, fn]() { |
141 | int start, end; |
142 | impl::balance211(n, njobs, i, start, end); |
143 | for (int j = start; j < end; j++) |
144 | fn(j, n); |
145 | }); |
146 | } |
147 | }; |
148 | }; |
149 | |
150 | } // namespace testing |
151 | } // namespace dnnl |
152 | |
153 | #elif defined(DNNL_TEST_THREADPOOL_USE_TBB) |
154 | #include "tbb/parallel_for.h" |
155 | #include "tbb/task_arena.h" |
156 | |
157 | namespace dnnl { |
158 | namespace testing { |
159 | |
160 | class threadpool_t : public dnnl::threadpool_interop::threadpool_iface { |
161 | public: |
162 | explicit threadpool_t(int num_threads) { (void)num_threads; } |
163 | int get_num_threads() const override { |
164 | return tbb::this_task_arena::max_concurrency(); |
165 | } |
166 | bool get_in_parallel() const override { return 0; } |
167 | uint64_t get_flags() const override { return 0; } |
168 | void parallel_for(int n, const std::function<void(int, int)> &fn) override { |
169 | tbb::parallel_for( |
170 | 0, n, [&](int i) { fn(i, n); }, tbb::static_partitioner()); |
171 | } |
172 | }; |
173 | |
174 | } // namespace testing |
175 | } // namespace dnnl |
176 | |
177 | #else |
178 | |
179 | #include <atomic> |
180 | #include <thread> |
181 | #include <vector> |
182 | #include <condition_variable> |
183 | |
184 | namespace dnnl { |
185 | namespace testing { |
186 | |
// Naive synchronous threadpool:
// - Only a single parallel_for is executed at the same time.
// - Recursive parallel_for results in sequential execution.
class threadpool_t : public dnnl::threadpool_interop::threadpool_iface {
public:
    using task_func = std::function<void(int, int)>;

    // Spawns num_threads persistent workers. A non-positive num_threads
    // falls back to OMP_NUM_THREADS (or the platform maximum). Returns only
    // after every worker has checked in at the startup barrier.
    explicit threadpool_t(int num_threads = 0) {
        if (num_threads <= 0) num_threads = read_num_threads_from_env();
        num_threads_ = num_threads;
        master_sense_ = 0;

        // Two task slots are alternated (sense reversal), so a new
        // submission never reuses the slot workers may still be reading.
        for (int i = 0; i < 2; i++) {
            tasks_[i].go_flag.store(0);
            tasks_[i].fn = nullptr;
            tasks_[i].n = 0;
        }

        barrier_init();
        workers_.reset(new std::vector<worker_data>(num_threads_));
        for (int i = 0; i < num_threads_; i++) {
            auto wd = &workers_->at(i);
            wd->thread_id = i;
            wd->tp = this;
            wd->thread.reset(new std::thread(worker_loop, &workers_->at(i)));
        }
        barrier_wait(); // each worker notifies once on startup
    }

    // Wakes all workers with a null task (their exit signal) and joins them.
    virtual ~threadpool_t() {
        std::unique_lock<std::mutex> l(master_mutex_);
        barrier_init();
        task_submit(nullptr, 0);
        for (int i = 0; i < num_threads_; i++)
            workers_->at(i).thread->join();
        barrier_wait();
    }

    virtual int get_num_threads() const { return num_threads_; }

    // True when called from one of this pool's own worker threads.
    virtual bool get_in_parallel() const { return worker_self() != nullptr; }

    virtual uint64_t get_flags() const { return 0; }

    // Runs fn(i, n) for all i in [0, n). A nested call from a worker thread
    // degrades to sequential execution (ithr = 0, nthr = 1); otherwise the
    // task is broadcast to the workers and the master blocks on the barrier
    // until all of them have finished.
    virtual void parallel_for(int n, const task_func &fn) {
        if (worker_self() != nullptr)
            task_execute(0, 1, &fn, n);
        else {
            std::unique_lock<std::mutex> l(master_mutex_);
            barrier_init();
            task_submit(&fn, n);
            barrier_wait();
        }
    }

private:
    int num_threads_;
    std::mutex master_mutex_; // serializes parallel_for submissions and dtor
    std::mutex master_submit_mutex_; // guards task slots and worker wakeup

    struct worker_data {
        int thread_id;
        threadpool_t *tp;
        std::condition_variable cv;
        std::unique_ptr<std::thread> thread;
    };
    std::unique_ptr<std::vector<worker_data>> workers_;
    // Set once per worker thread; lets a thread detect whether it belongs to
    // this particular pool (see worker_self()).
    static thread_local worker_data *worker_self_;
    worker_data *worker_self() const {
        return worker_self_ != nullptr && worker_self_->tp == this
                ? worker_self_
                : nullptr;
    }

    struct task_data {
        std::atomic<int> go_flag; // 1 while the slot holds a live task
        const task_func *fn; // nullptr is the workers' exit signal
        int n;
    };
    int master_sense_; // index of the task slot the master will fill next
    task_data tasks_[2];

    dnnl::impl::counting_barrier_t barrier_;

    void barrier_init() { barrier_.init(num_threads_); }

    // Blocks until all workers have notified, then retires the active slot
    // and flips the master sense for the next submission.
    void barrier_wait() {
        barrier_.wait();
        tasks_[master_sense_].go_flag.store(0);
        master_sense_ = !master_sense_;
    }

    void barrier_notify(int worker_sense) { barrier_.notify(); }

    // Publishes (fn, n) into the active slot and wakes every worker.
    void task_submit(const task_func *fn, int n) {
        std::lock_guard<std::mutex> l(master_submit_mutex_);
        tasks_[master_sense_].fn = fn;
        tasks_[master_sense_].n = n;
        tasks_[master_sense_].go_flag.store(1);
        for (int i = 0; i < num_threads_; i++) {
            workers_->at(i).cv.notify_one();
        }
    }

    // Runs this thread's [start, end) share of the n iterations; a null fn
    // (exit signal) or an empty range is a no-op.
    void task_execute(int ithr, int nthr, const task_func *fn, int n) {
        if (fn != nullptr && n > 0) {
            int start, end;
            impl::balance211(n, nthr, ithr, start, end);
            for (int i = start; i < end; i++)
                (*fn)(i, n);
        }
    }

    static void worker_loop(worker_data *wd) {
        worker_self_ = wd;
        int worker_sense = 0;

        // Startup check-in, matched by barrier_wait() in the constructor.
        wd->tp->barrier_notify(worker_sense);

        bool time_to_exit = false;
        // NOTE(review): all workers wait on the shared master_submit_mutex_,
        // and cv.wait() reacquires it before returning, so a woken worker
        // still holds the lock while running task_execute(). Verify this
        // does not serialize the workers' chunks.
        std::unique_lock<std::mutex> l(wd->tp->master_submit_mutex_);

        do {
            worker_sense = !worker_sense;
            auto *t = &wd->tp->tasks_[worker_sense];
            wd->tp->workers_->at(wd->thread_id).cv.wait(l, [t]() {
                return t->go_flag.load() != 0;
            });
            wd->tp->task_execute(
                    wd->thread_id, wd->tp->num_threads_, t->fn, t->n);
            time_to_exit = t->fn == nullptr;
            wd->tp->barrier_notify(worker_sense);
        } while (!time_to_exit);
    }
};

thread_local threadpool_t::worker_data *threadpool_t::worker_self_ = nullptr;
324 | |
325 | } // namespace testing |
326 | } // namespace dnnl |
327 | #endif |
328 | |
329 | namespace dnnl { |
330 | |
331 | namespace testing { |
332 | // Threadpool singleton |
333 | dnnl::threadpool_interop::threadpool_iface *get_threadpool( |
334 | const thr_ctx_t &ctx) { |
335 | // global default threadpool is returned when thr context is |
336 | // default |
337 | static std::unordered_map<int, dnnl::testing::threadpool_t> tp_map; |
338 | auto ret_val = tp_map.find(ctx.max_concurrency); |
339 | if (ret_val != tp_map.end()) return &(ret_val->second); |
340 | auto res = tp_map.emplace(std::piecewise_construct, |
341 | std::forward_as_tuple(ctx.max_concurrency), |
342 | std::forward_as_tuple(ctx.max_concurrency)); |
343 | if (!res.second) { |
344 | fprintf(stderr, "get_threadpool failed to create a threadpool\n" ); |
345 | exit(1); |
346 | } |
347 | return &(res.first->second); |
348 | } |
349 | |
350 | } // namespace testing |
351 | |
352 | // Implement a dummy threadpools_utils protocol here so that it is picked up |
353 | // by parallel*() calls from the tests. |
354 | namespace impl { |
355 | namespace testing_threadpool_utils { |
// No-ops: this dummy protocol keeps no per-call threadpool state.
void activate_threadpool(dnnl::threadpool_interop::threadpool_iface *tp) {}
void deactivate_threadpool() {}
// Always hands out the threadpool for the default context (get_threadpool()
// is called with its default argument -- presumably declared in the header;
// TODO confirm against tests/test_thread.hpp).
dnnl::threadpool_interop::threadpool_iface *get_active_threadpool() {
    return testing::get_threadpool();
}

// here we return 0 so that parallel* calls use the
// default number of threads in the threadpool.
int get_max_concurrency() {
    return 0;
}
367 | |
368 | } // namespace testing_threadpool_utils |
369 | |
370 | } // namespace impl |
371 | } // namespace dnnl |
372 | |
373 | #endif |
374 | |