1 | #include <ATen/Config.h> |
2 | #if AT_PARALLEL_OPENMP || AT_PARALLEL_NATIVE || AT_PARALLEL_NATIVE_TBB |
3 | #include <ATen/Parallel.h> |
4 | #include <ATen/PTThreadPool.h> |
5 | #include <ATen/ThreadLocalState.h> |
6 | |
7 | #include <atomic> |
8 | |
9 | namespace at { |
10 | |
11 | namespace { |
12 | const int NOT_SET = -1; |
13 | const int CONSUMED = -2; |
14 | |
15 | // Number of inter-op threads set by the user; |
16 | // NOT_SET -> positive value -> CONSUMED |
17 | // (CONSUMED - thread pool is initialized) |
18 | // or |
19 | // NOT_SET -> CONSUMED |
20 | std::atomic<int> num_interop_threads{NOT_SET}; |
21 | |
22 | // thread pool global instance is hidden, |
23 | // users should use at::launch and get/set_num_interop_threads interface |
24 | TaskThreadPoolBase& get_pool() { |
25 | static std::shared_ptr<TaskThreadPoolBase> pool = |
26 | ThreadPoolRegistry()->Create( |
27 | "C10" , |
28 | /* device_id */ 0, |
29 | /* pool_size */ num_interop_threads.exchange(CONSUMED), |
30 | /* create_new */ true); |
31 | return *pool; |
32 | } |
33 | |
34 | // Factory function for ThreadPoolRegistry |
35 | std::shared_ptr<TaskThreadPoolBase> create_c10_threadpool( |
36 | int device_id, |
37 | int pool_size, |
38 | bool create_new) { |
39 | // For now, the only accepted device id is 0 |
40 | TORCH_CHECK(device_id == 0); |
41 | // Create new thread pool |
42 | TORCH_CHECK(create_new); |
43 | return std::make_shared<PTThreadPool>(pool_size); |
44 | } |
45 | |
46 | } // namespace |
47 | |
// Register create_c10_threadpool in ThreadPoolRegistry under the key "C10";
// this is the key get_pool() passes to ThreadPoolRegistry()->Create().
C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool);
49 | |
50 | void set_num_interop_threads(int nthreads) { |
51 | TORCH_CHECK(nthreads > 0, "Expected positive number of threads" ); |
52 | |
53 | int no_value = NOT_SET; |
54 | TORCH_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads), |
55 | "Error: cannot set number of interop threads after parallel work " |
56 | "has started or set_num_interop_threads called" ); |
57 | } |
58 | |
59 | int get_num_interop_threads() { |
60 | at::internal::lazy_init_num_threads(); |
61 | int nthreads = num_interop_threads.load(); |
62 | if (nthreads > 0) { |
63 | return nthreads; |
64 | } else if (nthreads == NOT_SET) { |
65 | // return default value |
66 | return TaskThreadPoolBase::defaultNumThreads(); |
67 | } else { |
68 | return get_pool().size(); |
69 | } |
70 | } |
71 | |
72 | namespace internal { |
73 | void launch_no_thread_state(std::function<void()> fn) { |
74 | #if AT_EXPERIMENTAL_SINGLE_THREAD_POOL |
75 | intraop_launch(std::move(fn)); |
76 | #else |
77 | get_pool().run(std::move(fn)); |
78 | #endif |
79 | } |
80 | } // namespace internal |
81 | |
82 | void launch(std::function<void()> func) { |
83 | // NOLINTNEXTLINE(modernize-avoid-bind) |
84 | internal::launch_no_thread_state(std::bind([]( |
85 | std::function<void()> f, ThreadLocalState thread_locals) { |
86 | ThreadLocalStateGuard guard(std::move(thread_locals)); |
87 | f(); |
88 | }, |
89 | std::move(func), |
90 | ThreadLocalState() |
91 | )); |
92 | } |
93 | |
94 | } // namespace at |
95 | #endif |
96 | |