1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#define EIGEN_USE_THREADS
17
18#include "tensorflow/core/common_runtime/local_device.h"
19
20#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21#include "tensorflow/core/common_runtime/process_state.h"
22#include "tensorflow/core/common_runtime/process_util.h"
23#include "tensorflow/core/lib/core/threadpool.h"
24#include "tensorflow/core/platform/byte_order.h"
25#include "tensorflow/core/platform/cpu_feature_guard.h"
26#include "tensorflow/core/platform/cpu_info.h"
27#include "tensorflow/core/platform/logging.h"
28#include "tensorflow/core/platform/numa.h"
29#include "tensorflow/core/platform/types.h"
30#include "tensorflow/core/public/session_options.h"
31#include "tensorflow/core/util/env_var.h"
32
33namespace tensorflow {
34namespace {
35
36bool OverrideGlobalThreadPoolFromEnvironment() {
37 static const bool override_global_threadpool = [] {
38 bool flag;
39 auto status = ReadBoolFromEnvVar("TF_OVERRIDE_GLOBAL_THREADPOOL",
40 /*default_val=*/false, &flag);
41 if (!status.ok()) {
42 LOG(ERROR) << "OverrideGlobalThreadPool: " << status.error_message();
43 return false;
44 }
45 return flag;
46 }();
47 return override_global_threadpool;
48}
49
50} // namespace
51
52/* static */
53bool LocalDevice::use_global_threadpool_ = true;
54mutex LocalDevice::global_tp_mu_;
55gtl::InlinedVector<LocalDevice::EigenThreadPoolInfo*, 4>
56 LocalDevice::global_tp_info_;
57
58struct LocalDevice::EigenThreadPoolInfo {
59 // Wrapper so we can provide the CPUAllocator to Eigen for use
60 // when ops need extra tmp memory.
61 class EigenAllocator : public Eigen::Allocator {
62 public:
63 explicit EigenAllocator(tensorflow::Allocator* a) : allocator_(a) {}
64 void* allocate(size_t num_bytes) const override {
65 return allocator_->AllocateRaw(64, num_bytes);
66 }
67 void deallocate(void* buffer) const override {
68 allocator_->DeallocateRaw(buffer);
69 }
70 tensorflow::Allocator* allocator_;
71 };
72
73 explicit EigenThreadPoolInfo(const SessionOptions& options, int numa_node,
74 Allocator* allocator) {
75 // Use session setting if specified.
76 int32_t intra_op_parallelism_threads =
77 options.config.intra_op_parallelism_threads();
78 // If no session setting, use environment setting.
79 if (intra_op_parallelism_threads == 0) {
80 static int env_num_threads = NumIntraOpThreadsFromEnvironment();
81 intra_op_parallelism_threads = env_num_threads;
82 // If no session setting or environment, compute a reasonable default.
83 if (intra_op_parallelism_threads == 0) {
84 intra_op_parallelism_threads = port::MaxParallelism(numa_node);
85 }
86 }
87 ThreadOptions thread_opts;
88 thread_opts.numa_node = numa_node;
89 eigen_worker_threads_.num_threads = intra_op_parallelism_threads;
90 eigen_worker_threads_.workers = new thread::ThreadPool(
91 options.env, thread_opts, strings::StrCat("numa_", numa_node, "_Eigen"),
92 intra_op_parallelism_threads,
93 !options.config.experimental().disable_thread_spinning(),
94 /*allocator=*/nullptr);
95 Eigen::ThreadPoolInterface* threadpool =
96 eigen_worker_threads_.workers->AsEigenThreadPool();
97 if (allocator != nullptr) {
98 eigen_allocator_.reset(new EigenAllocator(allocator));
99 }
100 eigen_device_.reset(new Eigen::ThreadPoolDevice(
101 threadpool, eigen_worker_threads_.num_threads, eigen_allocator_.get()));
102 }
103
104 ~EigenThreadPoolInfo() {
105 eigen_device_.reset();
106 delete eigen_worker_threads_.workers;
107 }
108
109 DeviceBase::CpuWorkerThreads eigen_worker_threads_;
110 std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
111 std::unique_ptr<EigenAllocator> eigen_allocator_;
112};
113
114LocalDevice::LocalDevice(const SessionOptions& options,
115 const DeviceAttributes& attributes)
116 : Device(options.env, attributes), owned_tp_info_(nullptr) {
117 // Log info messages if TensorFlow is not compiled with instructions that
118 // could speed up performance and are available on the current CPU.
119 port::InfoAboutUnusedCPUFeatures();
120 LocalDevice::EigenThreadPoolInfo* tp_info;
121
122 if (OverrideGlobalThreadPoolFromEnvironment()) {
123 set_use_global_threadpool(false);
124 }
125
126 if (use_global_threadpool_) {
127 mutex_lock l(global_tp_mu_);
128 if (options.config.experimental().use_numa_affinity()) {
129 int numa_node = attributes.locality().numa_node();
130 int num_numa_nodes = port::NUMANumNodes();
131 DCHECK_LT(numa_node, num_numa_nodes);
132 Allocator* numa_allocator =
133 ProcessState::singleton()->GetCPUAllocator(numa_node);
134 while (numa_node >= global_tp_info_.size()) {
135 global_tp_info_.push_back(nullptr);
136 }
137 if (!global_tp_info_[numa_node]) {
138 global_tp_info_[numa_node] = new LocalDevice::EigenThreadPoolInfo(
139 options, numa_node, numa_allocator);
140 }
141 tp_info = global_tp_info_[numa_node];
142 } else {
143 if (global_tp_info_.empty()) {
144 global_tp_info_.push_back(new LocalDevice::EigenThreadPoolInfo(
145 options, port::kNUMANoAffinity, nullptr));
146 }
147 tp_info = global_tp_info_[0];
148 }
149 } else {
150 // Each LocalDevice owns a separate ThreadPoolDevice for numerical
151 // computations.
152 // TODO(tucker): NUMA for these too?
153 owned_tp_info_.reset(new LocalDevice::EigenThreadPoolInfo(
154 options, port::kNUMANoAffinity, nullptr));
155 tp_info = owned_tp_info_.get();
156 }
157 set_tensorflow_cpu_worker_threads(&tp_info->eigen_worker_threads_);
158 set_eigen_cpu_device(tp_info->eigen_device_.get());
159}
160
161LocalDevice::~LocalDevice() {}
162
163} // namespace tensorflow
164