/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/process_util.h"

#if defined(ENABLE_MKL) && defined(ENABLE_ONEDNN_OPENMP)
#ifdef _OPENMP
#include <omp.h>
#endif  // _OPENMP
#endif  // defined(ENABLE_MKL) && defined(ENABLE_ONEDNN_OPENMP)
#include <string.h>

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

namespace {

// Use environment setting if specified (init once)
int32 GetEnvNumInterOpThreads() {
  static int32_t env_num_threads = NumInterOpThreadsFromEnvironment();
  return env_num_threads;
}

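// Picks the inter-op thread count: TF_NUM_INTEROP_THREADS if set, otherwise
// the maximum parallelism available to the process (1 on Android, see below).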
int32 DefaultNumInterOpThreads() {
#ifndef __ANDROID__
  int32_t env_num_threads = GetEnvNumInterOpThreads();
  if (env_num_threads > 0) {
    return env_num_threads;
  }

  // Default to the maximum parallelism for the current process.
  return port::MaxParallelism();
#else
  // Historically, -D__ANDROID__ resulted in the inter-op threadpool not being
  // used (regardless of what was chosen here); instead, all work was done on
  // the thread(s) calling Session::Run. That's no longer the case, but we'd
  // like to avoid suddenly higher concurrency and peak resource usage (for the
  // same device shape, graph, and options) versus prior versions - as best we
  // can:
  //
  // - Single Session::Run (none concurrent), and default options:
  //   Behavior is mostly the same as before.
  //
  // - Concurrent Session::Runs, and default options:
  //   Reduced concurrency versus before.
  //
  // - Thread-pool size set explicitly (>1):
  //   Increased concurrency versus before.
  //
  // (We assume the first case is the most common.)
  return 1;
#endif
}

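// Builds the process-wide "Compute" thread pool, sized from the session's
// inter_op_parallelism_threads option, falling back to
// DefaultNumInterOpThreads() when the option is 0.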
static thread::ThreadPool* InitComputePool(const SessionOptions& options) {
  int32_t inter_op_parallelism_threads =
      options.config.inter_op_parallelism_threads();
  if (inter_op_parallelism_threads == 0) {
    inter_op_parallelism_threads = DefaultNumInterOpThreads();
  }
  return new thread::ThreadPool(
      Env::Default(), ThreadOptions(), "Compute", inter_op_parallelism_threads,
      !options.config.experimental().disable_thread_spinning(),
      /*allocator=*/nullptr);
}

}  // namespace

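// Returns the shared "Compute" pool; it is created on first use with the
// options passed on that first call and reused unchanged afterwards.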
thread::ThreadPool* ComputePool(const SessionOptions& options) {
  static thread::ThreadPool* compute_pool = InitComputePool(options);
  return compute_pool;
}

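// Reads TF_NUM_INTEROP_THREADS; returns 0 if it is unset or unparsable.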
int32 NumInterOpThreadsFromEnvironment() {
  int32_t num;
  const char* val = std::getenv("TF_NUM_INTEROP_THREADS");
  return (val && strings::safe_strto32(val, &num)) ? num : 0;
}

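// Reads TF_NUM_INTRAOP_THREADS; returns 0 if it is unset or unparsable.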
int32 NumIntraOpThreadsFromEnvironment() {
  int32_t num;
  const char* val = std::getenv("TF_NUM_INTRAOP_THREADS");
  return (val && strings::safe_strto32(val, &num)) ? num : 0;
}
#if defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_MKL)
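// Reads OMP_NUM_THREADS; returns 0 if it is unset or unparsable.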
int32 OMPThreadsFromEnvironment() {
  // 1) std::getenv is thread-safe (as long as no other function modifies the
  // host env) from C++11 onward. 2) Most of TF code (except tests and
  // experimental code) doesn't call setenv and unsetenv.
  int32 num;
  const char* val = std::getenv("OMP_NUM_THREADS");
  return (val && strings::safe_strto32(val, &num)) ? num : 0;
}

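// Picks the intra-op thread count: TF_NUM_INTRAOP_THREADS if set, otherwise
// the maximum parallelism available to the process.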
int32 DefaultNumIntraOpThreads() {
  // Use environment setting if specified (init once)
  static int env_num_threads = NumIntraOpThreadsFromEnvironment();
  if (env_num_threads > 0) {
    return env_num_threads;
  }

  // Default to the maximum parallelism for the current process.
  return port::MaxParallelism();
}
#endif  // defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_MKL)
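// Resolves the inter-op thread count in order of precedence: the session
// option, the TF_NUM_INTEROP_THREADS environment variable, an MKL/oneDNN-aware
// heuristic (when enabled), and finally DefaultNumInterOpThreads().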
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
  const int32_t inter_op = options.config.inter_op_parallelism_threads();
  if (inter_op > 0) return inter_op;
  const int32_t env_inter_op = GetEnvNumInterOpThreads();
  if (env_inter_op > 0) return env_inter_op;

#if defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_MKL)
  if (IsMKLEnabled()) {
    // MKL library executes ops in parallel using OMP threads.
    // Setting inter_op conservatively to avoid thread oversubscription that
    // could lead to severe perf degradations and OMP resource exhaustion.
    // Inter ops are set such that mkl_inter_op * mkl_intra_op <= NumCores.
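    // For example, on a 16-core machine with OMP_NUM_THREADS=8, mkl_intra_op
    // is 8 and mkl_inter_op becomes max((16 + 8 - 1) / 8, 2) = 2, so the
    // product (2 * 8) stays within the 16 available cores.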
    const int32 intra_op = options.config.intra_op_parallelism_threads();
    const int32 omp_max_threads = OMPThreadsFromEnvironment();
    const int32 mkl_intra_op =
        (omp_max_threads > 0)
            ? omp_max_threads
            : (intra_op > 0) ? intra_op : DefaultNumIntraOpThreads();
    DCHECK_GE(mkl_intra_op, 1);
    const int32 mkl_inter_op = std::max(
        (DefaultNumInterOpThreads() + mkl_intra_op - 1) / mkl_intra_op, 2);
    VLOG(0)
        << "Creating new thread pool with default inter op setting: "
        << mkl_inter_op
        << ". Tune using inter_op_parallelism_threads for best performance.";
    return mkl_inter_op;
  }
#endif  // defined(ENABLE_ONEDNN_OPENMP) && defined(ENABLE_MKL)
  return DefaultNumInterOpThreads();
}

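// Creates a fresh "Compute" thread pool (owned by the caller) sized by
// NumInterOpThreadsFromSessionOptions(), using the Env from the options.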
thread::ThreadPool* NewThreadPoolFromSessionOptions(
    const SessionOptions& options) {
  const int32_t num_threads = NumInterOpThreadsFromSessionOptions(options);
  VLOG(1) << "Session inter op parallelism threads: " << num_threads;
  return new thread::ThreadPool(
      options.env, ThreadOptions(), "Compute", num_threads,
      !options.config.experimental().disable_thread_spinning(),
      /*allocator=*/nullptr);
}

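// Schedules `closure` on the default Env's thread pool. When the event
// collector is enabled, the schedule and the subsequent run are recorded as a
// matched pair of trace events sharing a unique id.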
void SchedClosure(std::function<void()> closure) {
  if (!tracing::EventCollector::IsEnabled()) {
    return Env::Default()->SchedClosure(std::move(closure));
  }
  uint64 id = tracing::GetUniqueArg();
  tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);

  Env::Default()->SchedClosure([id, closure = std::move(closure)]() {
    tracing::ScopedRegion region(tracing::EventCategory::kRunClosure, id);
    closure();
  });
}

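// Schedules `closure` to run after `micros` microseconds on the default Env,
// without blocking the calling thread.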
void SchedNonBlockingClosureAfter(int64_t micros,
                                  std::function<void()> closure) {
  Env::Default()->SchedClosureAfter(micros, std::move(closure));
}

}  // namespace tensorflow