1 | #include <ATen/Parallel.h> |
---|---|
2 | |
3 | #include <ATen/Config.h> |
4 | #include <ATen/PTThreadPool.h> |
5 | #include <ATen/Version.h> |
6 | |
7 | #include <sstream> |
8 | #include <thread> |
9 | |
10 | #if AT_MKL_ENABLED() |
11 | #include <mkl.h> |
12 | #endif |
13 | |
14 | #ifdef _OPENMP |
15 | #include <omp.h> |
16 | #endif |
17 | |
18 | namespace at { |
19 | |
20 | namespace { |
21 | |
// Looks up environment variable `var_name`.
// Returns its value when it is set, otherwise `def_value` (nullptr unless the
// caller supplies a fallback such as "[not set]").
const char* get_env_var(
    const char* var_name,
    const char* def_value = nullptr) {
  if (const char* found = std::getenv(var_name)) {
    return found;
  }
  // std::getenv yields nullptr for an unset variable; fall back to the default.
  return def_value;
}
27 | |
28 | size_t get_env_num_threads(const char* var_name, size_t def_value = 0) { |
29 | try { |
30 | if (auto* value = std::getenv(var_name)) { |
31 | int nthreads = c10::stoi(value); |
32 | TORCH_CHECK(nthreads > 0); |
33 | return nthreads; |
34 | } |
35 | } catch (const std::exception& e) { |
36 | std::ostringstream oss; |
37 | oss << "Invalid "<< var_name << " variable value, "<< e.what(); |
38 | TORCH_WARN(oss.str()); |
39 | } |
40 | return def_value; |
41 | } |
42 | |
43 | } // namespace |
44 | |
45 | std::string get_parallel_info() { |
46 | std::ostringstream ss; |
47 | |
48 | ss << "ATen/Parallel:\n\tat::get_num_threads() : " |
49 | << at::get_num_threads() << std::endl; |
50 | ss << "\tat::get_num_interop_threads() : " |
51 | << at::get_num_interop_threads() << std::endl; |
52 | |
53 | ss << at::get_openmp_version() << std::endl; |
54 | #ifdef _OPENMP |
55 | ss << "\tomp_get_max_threads() : "<< omp_get_max_threads() << std::endl; |
56 | #endif |
57 | |
58 | ss << at::get_mkl_version() << std::endl; |
59 | #if AT_MKL_ENABLED() |
60 | ss << "\tmkl_get_max_threads() : "<< mkl_get_max_threads() << std::endl; |
61 | #endif |
62 | |
63 | ss << at::get_mkldnn_version() << std::endl; |
64 | |
65 | ss << "std::thread::hardware_concurrency() : " |
66 | << std::thread::hardware_concurrency() << std::endl; |
67 | |
68 | ss << "Environment variables:"<< std::endl; |
69 | ss << "\tOMP_NUM_THREADS : " |
70 | << get_env_var("OMP_NUM_THREADS", "[not set]") << std::endl; |
71 | ss << "\tMKL_NUM_THREADS : " |
72 | << get_env_var("MKL_NUM_THREADS", "[not set]") << std::endl; |
73 | |
74 | ss << "ATen parallel backend: "; |
75 | #if AT_PARALLEL_OPENMP |
76 | ss << "OpenMP"; |
77 | #elif AT_PARALLEL_NATIVE |
78 | ss << "native thread pool"; |
79 | #elif AT_PARALLEL_NATIVE_TBB |
80 | ss << "native thread pool and TBB"; |
81 | #endif |
82 | #ifdef C10_MOBILE |
83 | ss << " [mobile]"; |
84 | #endif |
85 | ss << std::endl; |
86 | |
87 | #if AT_EXPERIMENTAL_SINGLE_THREAD_POOL |
88 | ss << "Experimental: single thread pool"<< std::endl; |
89 | #endif |
90 | |
91 | return ss.str(); |
92 | } |
93 | |
94 | int intraop_default_num_threads() { |
95 | #ifdef C10_MOBILE |
96 | // Intraop thread pool size should be determined by mobile cpuinfo. |
97 | // We should hook up with the logic in caffe2/utils/threadpool if we ever need |
98 | // call this API for mobile. |
99 | TORCH_CHECK(false, "Undefined intraop_default_num_threads on mobile."); |
100 | #else |
101 | size_t nthreads = get_env_num_threads("OMP_NUM_THREADS", 0); |
102 | nthreads = get_env_num_threads("MKL_NUM_THREADS", nthreads); |
103 | if (nthreads == 0) { |
104 | nthreads = TaskThreadPoolBase::defaultNumThreads(); |
105 | } |
106 | return nthreads; |
107 | #endif |
108 | } |
109 | |
110 | } // namespace at |
111 |