1#pragma once
2
3#include <c10/util/Exception.h>
4#include <c10/util/SmallVector.h>
5
6namespace at {
7
8template <class F>
9inline void parallel_for(
10 const int64_t begin,
11 const int64_t end,
12 const int64_t grain_size,
13 const F& f) {
14 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0);
15 if (begin >= end) {
16 return;
17 }
18
19#ifdef INTRA_OP_PARALLEL
20 at::internal::lazy_init_num_threads();
21 const auto numiter = end - begin;
22 const bool use_parallel =
23 (numiter > grain_size && numiter > 1 && !at::in_parallel_region() &&
24 at::get_num_threads() > 1);
25 if (!use_parallel) {
26 internal::ThreadIdGuard tid_guard(0);
27 f(begin, end);
28 return;
29 }
30
31 internal::invoke_parallel(begin, end, grain_size, f);
32#else
33 internal::ThreadIdGuard tid_guard(0);
34 f(begin, end);
35#endif
36}
37
38template <class scalar_t, class F, class SF>
39inline scalar_t parallel_reduce(
40 const int64_t begin,
41 const int64_t end,
42 const int64_t grain_size,
43 const scalar_t ident,
44 const F& f,
45 const SF& sf) {
46 TORCH_CHECK(grain_size >= 0);
47 if (begin >= end) {
48 return ident;
49 }
50
51#ifdef INTRA_OP_PARALLEL
52 at::internal::lazy_init_num_threads();
53 const auto max_threads = at::get_num_threads();
54 const bool use_parallel =
55 ((end - begin) > grain_size && !at::in_parallel_region() &&
56 max_threads > 1);
57 if (!use_parallel) {
58 internal::ThreadIdGuard tid_guard(0);
59 return f(begin, end, ident);
60 }
61
62 c10::SmallVector<scalar_t, 64> results(max_threads, ident);
63 internal::invoke_parallel(
64 begin,
65 end,
66 grain_size,
67 [&](const int64_t my_begin, const int64_t my_end) {
68 const auto tid = at::get_thread_num();
69 results[tid] = f(my_begin, my_end, ident);
70 });
71
72 scalar_t result = ident;
73 for (auto partial_result : results) {
74 result = sf(result, partial_result);
75 }
76 return result;
77#else
78 internal::ThreadIdGuard tid_guard(0);
79 return f(begin, end, ident);
80#endif
81}
82
83} // namespace at
84