1 | #include <ATen/Config.h> |
2 | #if AT_PARALLEL_NATIVE_TBB |
3 | #include <ATen/Parallel.h> |
4 | #include <ATen/ParallelFuture.h> |
5 | #include <ATen/PTThreadPool.h> |
6 | |
7 | #include <atomic> |
8 | #include <mutex> |
9 | |
10 | #include <tbb/tbb.h> |
11 | #define TBB_PREVIEW_GLOBAL_CONTROL 1 |
12 | #include <tbb/global_control.h> |
13 | |
14 | #ifdef _OPENMP |
15 | #include <omp.h> |
16 | #endif |
17 | |
18 | #if AT_MKL_ENABLED() |
19 | #include <mkl.h> |
20 | #endif |
21 | |
22 | namespace at { |
23 | |
24 | namespace { |
25 | static thread_local tbb::task_group tg_; |
26 | thread_local int this_thread_id{0}; |
27 | |
28 | std::mutex global_thread_mutex_; |
29 | std::shared_ptr<tbb::global_control> global_thread_limit_ = nullptr; |
30 | std::atomic<int> num_intraop_threads_{-1}; |
31 | |
32 | void _internal_set_num_threads(int nthreads) { |
33 | TORCH_INTERNAL_ASSERT(nthreads > 0); |
34 | { |
35 | std::unique_lock<std::mutex> lk(global_thread_mutex_); |
36 | // This is an antipattern and we shouldn't be constraining the number of |
37 | // threads in library code. |
38 | // TODO: Think of a smarter way to leverage tbb::thread_arena to limit the |
39 | // number of slots instead of the number of threads. |
40 | global_thread_limit_ = std::make_shared<tbb::global_control>( |
41 | tbb::global_control::max_allowed_parallelism, nthreads); |
42 | num_intraop_threads_.store(nthreads); |
43 | } |
44 | } |
45 | } |
46 | |
47 | void init_num_threads() { |
48 | #ifdef _OPENMP |
49 | omp_set_num_threads(1); |
50 | #endif |
51 | |
52 | #if AT_MKL_ENABLED() |
53 | mkl_set_num_threads(1); |
54 | #endif |
55 | |
56 | int nthreads = num_intraop_threads_.load(); |
57 | if (nthreads < 0) { |
58 | nthreads = intraop_default_num_threads(); |
59 | } |
60 | _internal_set_num_threads(nthreads); |
61 | } |
62 | |
63 | void set_num_threads(int nthreads) { |
64 | TORCH_CHECK(nthreads > 0); |
65 | |
66 | _internal_set_num_threads(nthreads); |
67 | } |
68 | |
69 | int get_num_threads() { |
70 | at::internal::lazy_init_num_threads(); |
71 | return tbb::global_control::active_value( |
72 | tbb::global_control::max_allowed_parallelism); |
73 | } |
74 | |
75 | int get_thread_num() { |
76 | return this_thread_id; |
77 | } |
78 | |
79 | namespace internal { |
80 | void set_thread_num(int id) { |
81 | this_thread_id = id; |
82 | } |
83 | } |
84 | |
85 | bool in_parallel_region() { |
86 | return tbb::this_task_arena::current_thread_index() >= 0; |
87 | } |
88 | |
89 | void intraop_launch(std::function<void()> func) { |
90 | if (get_num_threads() > 1) { |
91 | tg_.run(func); |
92 | } else { |
93 | func(); |
94 | } |
95 | } |
96 | |
97 | c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future( |
98 | std::function<void()> func) { |
99 | auto future = c10::make_intrusive<c10::ivalue::Future>(NoneType::get()); |
100 | if (get_num_threads() > 1) { |
101 | tg_.run( |
102 | [func, future]() { |
103 | func(); |
104 | future->markCompleted(); |
105 | } |
106 | ); |
107 | } else { |
108 | func(); |
109 | future->markCompleted(); |
110 | } |
111 | return future; |
112 | } |
113 | |
114 | } // namespace at |
115 | #endif |
116 | |