1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#define EIGEN_USE_THREADS
17
18#include "tensorflow/core/framework/device_base.h"
19
20#include <algorithm>
21#include <vector>
22
23#include "absl/container/flat_hash_set.h"
24#include "absl/synchronization/notification.h"
25#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26#include "tensorflow/core/util/work_sharder.h"
27
28namespace tensorflow {
29
30DeviceBase::~DeviceBase() {
31 for (auto& temp : eigen_cpu_devices_) {
32 delete temp;
33 }
34 eigen_cpu_devices_.clear();
35}
36
37Status DeviceContext::CopyDeviceTensorToCPUSync(const Tensor* device_tensor,
38 StringPiece tensor_name,
39 Device* device,
40 Tensor* cpu_tensor) {
41 absl::Notification n;
42 Status status;
43 CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor,
44 [&](const Status& s) {
45 status = s;
46 n.Notify();
47 });
48 n.WaitForNotification();
49 return status;
50}
51
52Status DeviceContext::CopyCPUTensorToDeviceSync(const Tensor* cpu_tensor,
53 Device* device,
54 Tensor* device_tensor) const {
55 absl::Notification n;
56 Status status;
57 CopyCPUTensorToDevice(cpu_tensor, device, device_tensor,
58 [&](const Status& s) {
59 status = s;
60 n.Notify();
61 });
62 n.WaitForNotification();
63 return status;
64}
65
66const DeviceAttributes& DeviceBase::attributes() const {
67 LOG(FATAL) << "DeviceBase does not implement attributes()"; // Crash OK
68 std::abort();
69}
70
71const string& DeviceBase::name() const {
72 LOG(FATAL) << "DeviceBase does not implement name()"; // Crash OK
73 std::abort();
74}
75
76const DeviceNameUtils::ParsedName& DeviceBase::parsed_name() const {
77 LOG(FATAL) << "DeviceBase does not implement parsed_name()"; // Crash OK
78 std::abort();
79}
80
81void DeviceBase::set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) {
82 // Eigen::ThreadPoolDevice is a very cheap struct (two pointers and
83 // an int). Therefore, we can afford a pre-allocated array of
84 // Eigen::ThreadPoolDevice. Here, we ensure that
85 // Eigen::ThreadPoolDevices in eigen_cpu_devices_ has increasingly
86 // larger numThreads.
87 for (int i = 1; i <= d->numThreads(); ++i) {
88 eigen_cpu_devices_.push_back(new Eigen::ThreadPoolDevice(
89 d->getPool(), i /* numThreads() */, d->allocator()));
90 }
91}
92
93const Eigen::ThreadPoolDevice* DeviceBase::eigen_cpu_device() {
94 // Based on GetPerThreadMaxParallelism(), we return a different
95 // pre-allocated Eigen::ThreadPoolDevice. All these ThreadPoolDevice
96 // use the same underlying threadpool. But they use different
97 // nominal numThreads() hoping that the user of the returned
98 // Eigen::ThreadPoolDevice may not aggressively occupy all the
99 // threads in the underlying threadpool.
100 const int parallelism = std::max<int>(
101 1,
102 std::min<int>(GetPerThreadMaxParallelism(), eigen_cpu_devices_.size()));
103 return eigen_cpu_devices_[parallelism - 1];
104}
105
106namespace {
107
108absl::flat_hash_set<std::string>* GetSymbolicDeviceList() {
109 static absl::flat_hash_set<std::string>* symbolic_device_list =
110 new absl::flat_hash_set<std::string>();
111 return symbolic_device_list;
112}
113
114} // namespace
115
116void AddSymbolicExecutionDevice(const absl::string_view device_name) {
117 GetSymbolicDeviceList()->insert(std::string(device_name));
118}
119
120bool IsSymbolicExecutionDevice(const absl::string_view device_name) {
121 absl::flat_hash_set<std::string>* symbolic_devices = GetSymbolicDeviceList();
122 if (symbolic_devices->contains(device_name)) {
123 return true;
124 } else {
125 return false;
126 }
127}
128
129} // namespace tensorflow
130