1 | #pragma once |
2 | |
3 | #include <c10/util/Exception.h> |
4 | #include <c10/util/SmallVector.h> |
5 | |
6 | namespace at { |
7 | |
8 | template <class F> |
9 | inline void parallel_for( |
10 | const int64_t begin, |
11 | const int64_t end, |
12 | const int64_t grain_size, |
13 | const F& f) { |
14 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0); |
15 | if (begin >= end) { |
16 | return; |
17 | } |
18 | |
19 | #ifdef INTRA_OP_PARALLEL |
20 | at::internal::lazy_init_num_threads(); |
21 | const auto numiter = end - begin; |
22 | const bool use_parallel = |
23 | (numiter > grain_size && numiter > 1 && !at::in_parallel_region() && |
24 | at::get_num_threads() > 1); |
25 | if (!use_parallel) { |
26 | internal::ThreadIdGuard tid_guard(0); |
27 | f(begin, end); |
28 | return; |
29 | } |
30 | |
31 | internal::invoke_parallel(begin, end, grain_size, f); |
32 | #else |
33 | internal::ThreadIdGuard tid_guard(0); |
34 | f(begin, end); |
35 | #endif |
36 | } |
37 | |
38 | template <class scalar_t, class F, class SF> |
39 | inline scalar_t parallel_reduce( |
40 | const int64_t begin, |
41 | const int64_t end, |
42 | const int64_t grain_size, |
43 | const scalar_t ident, |
44 | const F& f, |
45 | const SF& sf) { |
46 | TORCH_CHECK(grain_size >= 0); |
47 | if (begin >= end) { |
48 | return ident; |
49 | } |
50 | |
51 | #ifdef INTRA_OP_PARALLEL |
52 | at::internal::lazy_init_num_threads(); |
53 | const auto max_threads = at::get_num_threads(); |
54 | const bool use_parallel = |
55 | ((end - begin) > grain_size && !at::in_parallel_region() && |
56 | max_threads > 1); |
57 | if (!use_parallel) { |
58 | internal::ThreadIdGuard tid_guard(0); |
59 | return f(begin, end, ident); |
60 | } |
61 | |
62 | c10::SmallVector<scalar_t, 64> results(max_threads, ident); |
63 | internal::invoke_parallel( |
64 | begin, |
65 | end, |
66 | grain_size, |
67 | [&](const int64_t my_begin, const int64_t my_end) { |
68 | const auto tid = at::get_thread_num(); |
69 | results[tid] = f(my_begin, my_end, ident); |
70 | }); |
71 | |
72 | scalar_t result = ident; |
73 | for (auto partial_result : results) { |
74 | result = sf(result, partial_result); |
75 | } |
76 | return result; |
77 | #else |
78 | internal::ThreadIdGuard tid_guard(0); |
79 | return f(begin, end, ident); |
80 | #endif |
81 | } |
82 | |
83 | } // namespace at |
84 | |