1#include <ATen/Config.h>
2#if AT_PARALLEL_OPENMP || AT_PARALLEL_NATIVE || AT_PARALLEL_NATIVE_TBB
3#include <ATen/Parallel.h>
4#include <ATen/PTThreadPool.h>
5#include <ATen/ThreadLocalState.h>
6
7#include <atomic>
8
9namespace at {
10
11namespace {
// Sentinel states stored in num_interop_threads; any positive value is a
// thread count requested by the user.
constexpr int NOT_SET = -1;   // nothing requested, pool not created yet
constexpr int CONSUMED = -2;  // value handed to the pool; pool is initialized

// Number of inter-op threads requested by the user. Possible transitions:
//   NOT_SET -> <positive value> -> CONSUMED  (set_num_interop_threads, then pool creation)
//   NOT_SET -> CONSUMED                      (pool created without a prior request)
// set_num_interop_threads only succeeds from NOT_SET, so the value is frozen
// once it leaves that state.
std::atomic<int> num_interop_threads{NOT_SET};
21
22// thread pool global instance is hidden,
23// users should use at::launch and get/set_num_interop_threads interface
24TaskThreadPoolBase& get_pool() {
25 static std::shared_ptr<TaskThreadPoolBase> pool =
26 ThreadPoolRegistry()->Create(
27 "C10",
28 /* device_id */ 0,
29 /* pool_size */ num_interop_threads.exchange(CONSUMED),
30 /* create_new */ true);
31 return *pool;
32}
33
34// Factory function for ThreadPoolRegistry
35std::shared_ptr<TaskThreadPoolBase> create_c10_threadpool(
36 int device_id,
37 int pool_size,
38 bool create_new) {
39 // For now, the only accepted device id is 0
40 TORCH_CHECK(device_id == 0);
41 // Create new thread pool
42 TORCH_CHECK(create_new);
43 return std::make_shared<PTThreadPool>(pool_size);
44}
45
46} // namespace
47
48C10_REGISTER_CREATOR(ThreadPoolRegistry, C10, create_c10_threadpool);
49
50void set_num_interop_threads(int nthreads) {
51 TORCH_CHECK(nthreads > 0, "Expected positive number of threads");
52
53 int no_value = NOT_SET;
54 TORCH_CHECK(num_interop_threads.compare_exchange_strong(no_value, nthreads),
55 "Error: cannot set number of interop threads after parallel work "
56 "has started or set_num_interop_threads called");
57}
58
59int get_num_interop_threads() {
60 at::internal::lazy_init_num_threads();
61 int nthreads = num_interop_threads.load();
62 if (nthreads > 0) {
63 return nthreads;
64 } else if (nthreads == NOT_SET) {
65 // return default value
66 return TaskThreadPoolBase::defaultNumThreads();
67 } else {
68 return get_pool().size();
69 }
70}
71
72namespace internal {
73void launch_no_thread_state(std::function<void()> fn) {
74#if AT_EXPERIMENTAL_SINGLE_THREAD_POOL
75 intraop_launch(std::move(fn));
76#else
77 get_pool().run(std::move(fn));
78#endif
79}
80} // namespace internal
81
82void launch(std::function<void()> func) {
83 // NOLINTNEXTLINE(modernize-avoid-bind)
84 internal::launch_no_thread_state(std::bind([](
85 std::function<void()> f, ThreadLocalState thread_locals) {
86 ThreadLocalStateGuard guard(std::move(thread_locals));
87 f();
88 },
89 std::move(func),
90 ThreadLocalState()
91 ));
92}
93
94} // namespace at
95#endif
96