1 | #pragma once |
---|---|
2 | |
3 | #include <atomic> |
4 | #include <cstddef> |
5 | #include <exception> |
6 | |
7 | #ifdef _OPENMP |
8 | #define INTRA_OP_PARALLEL |
9 | |
10 | #include <omp.h> |
11 | #endif |
12 | |
13 | namespace at { |
14 | |
15 | #ifdef _OPENMP |
16 | namespace internal { |
17 | template <typename F> |
18 | inline void invoke_parallel( |
19 | int64_t begin, |
20 | int64_t end, |
21 | int64_t grain_size, |
22 | const F& f) { |
23 | std::atomic_flag err_flag = ATOMIC_FLAG_INIT; |
24 | std::exception_ptr eptr; |
25 | |
26 | #pragma omp parallel |
27 | { |
28 | // choose number of tasks based on grain size and number of threads |
29 | // can't use num_threads clause due to bugs in GOMP's thread pool (See |
30 | // #32008) |
31 | int64_t num_threads = omp_get_num_threads(); |
32 | if (grain_size > 0) { |
33 | num_threads = std::min(num_threads, divup((end - begin), grain_size)); |
34 | } |
35 | |
36 | int64_t tid = omp_get_thread_num(); |
37 | int64_t chunk_size = divup((end - begin), num_threads); |
38 | int64_t begin_tid = begin + tid * chunk_size; |
39 | if (begin_tid < end) { |
40 | try { |
41 | internal::ThreadIdGuard tid_guard(tid); |
42 | f(begin_tid, std::min(end, chunk_size + begin_tid)); |
43 | } catch (...) { |
44 | if (!err_flag.test_and_set()) { |
45 | eptr = std::current_exception(); |
46 | } |
47 | } |
48 | } |
49 | } |
50 | if (eptr) { |
51 | std::rethrow_exception(eptr); |
52 | } |
53 | } |
54 | } // namespace internal |
55 | #endif // _OPENMP |
56 | |
57 | } // namespace at |
58 |