1#include <ATen/Config.h>
2#include <ATen/core/jit_type.h>
3#if AT_PARALLEL_OPENMP
4#include <ATen/Parallel.h>
5#include <ATen/ParallelFuture.h>
6
7#include <atomic>
8
9#if AT_MKL_ENABLED()
10#include <mkl.h>
11#endif
12
13#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
14
15namespace at {
16#if AT_MKLDNN_ENABLED()
17namespace native { namespace mkldnn {
18void clear_computation_cache();
19}} // namespace native::mkldnn
20#endif
21
22namespace {
23// Number of threads set by the user
24std::atomic<int> num_threads{-1};
25thread_local int this_thread_id{0};
26
27} // namespace
28
// Initialize the intra-op thread-pool size for the calling thread.
// Meant to run once per thread (see at::internal::lazy_init_num_threads)
// so that parallel regions have a consistent size on every thread.
void init_num_threads() {
  auto nthreads = num_threads.load();
  if (nthreads > 0) {
    // A user-requested count exists: (re)apply it on this thread.
    set_num_threads(nthreads);
  } else {
#if defined(_OPENMP) && AT_MKL_ENABLED() && !AT_MKL_SEQUENTIAL()
    // If we are using MKL and OpenMP, make sure the number of threads match.
    // Otherwise, MKL and our OpenMP-enabled functions will keep changing the
    // size of the OpenMP thread pool, resulting in worse performance (and memory
    // leaks in GCC 5.4)
    omp_set_num_threads(mkl_get_max_threads());
#elif defined(_OPENMP)
    omp_set_num_threads(intraop_default_num_threads());
#endif
  }
}
45
// Set the intra-op thread count across every threading backend in use:
// OpenMP, MKL, and the caffe2/QNNPACK pthreadpool.
// Throws (via TORCH_CHECK) if nthreads is not positive.
void set_num_threads(int nthreads) {
  TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
  // Persist the request so threads initialized later pick it up
  // in init_num_threads().
  num_threads.store(nthreads);
#ifdef _OPENMP
  omp_set_num_threads(nthreads);
#endif
#if AT_MKL_ENABLED()
  mkl_set_num_threads_local(nthreads);

  // because PyTorch uses OpenMP outside of MKL invocations
  // as well, we want this flag to be false, so that
  // threads aren't destroyed and recreated across every
  // MKL / non-MKL boundary of OpenMP usage
  // See https://github.com/pytorch/pytorch/issues/13757
  mkl_set_dynamic(false);
#endif
#ifdef USE_PTHREADPOOL
  // because PyTorch uses caffe2::pthreadpool() in QNNPACK
  caffe2::PThreadPool* const pool = caffe2::pthreadpool();
  TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
  pool->set_thread_count(nthreads);
#endif
#if AT_MKLDNN_ENABLED()
  // MKLDNN caches computation primitives keyed on the thread count;
  // invalidate them so they are rebuilt for the new size.
  at::native::mkldnn::clear_computation_cache();
#endif
}
72
// Explicitly calling omp_get_max_threads() as the size of the parallel
// region might be different in the new thread;
// Use init_num_threads() during thread initialization to ensure
// consistent size of parallel region in different threads
// Returns 1 in builds without OpenMP.
int get_num_threads() {
#ifdef _OPENMP
  at::internal::lazy_init_num_threads();
  return omp_get_max_threads();
#else
  return 1;
#endif
}
85
// Returns the id previously assigned to the calling thread via
// internal::set_thread_num(), or 0 (the default) if none was set.
int get_thread_num() {
  return this_thread_id;
}
89
namespace internal {
// Record the calling thread's id in thread-local storage so that
// get_thread_num() can report it from inside a parallel region.
void set_thread_num(int id) {
  this_thread_id = id;
}
} // namespace internal
95
// True if the caller is executing inside an active OpenMP parallel
// region; always false in builds without OpenMP.
bool in_parallel_region() {
#ifdef _OPENMP
  return omp_in_parallel();
#else
  return false;
#endif
}
103
// The OpenMP backend has no separate intra-op task pool, so a "launched"
// task is simply executed synchronously on the calling thread.
void intraop_launch(std::function<void()> func) {
  const auto task = std::move(func);
  task();
}
108
109c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
110 std::function<void()> func) {
111 func();
112 auto future = c10::make_intrusive<c10::ivalue::Future>(NoneType::get());
113 future->markCompleted();
114 return future;
115}
116
117} // namespace at
118#endif
119