1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include "tensorflow/core/framework/device_base.h" |
19 | |
20 | #include <algorithm> |
21 | #include <vector> |
22 | |
23 | #include "absl/container/flat_hash_set.h" |
24 | #include "absl/synchronization/notification.h" |
25 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
26 | #include "tensorflow/core/util/work_sharder.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | DeviceBase::~DeviceBase() { |
31 | for (auto& temp : eigen_cpu_devices_) { |
32 | delete temp; |
33 | } |
34 | eigen_cpu_devices_.clear(); |
35 | } |
36 | |
37 | Status DeviceContext::CopyDeviceTensorToCPUSync(const Tensor* device_tensor, |
38 | StringPiece tensor_name, |
39 | Device* device, |
40 | Tensor* cpu_tensor) { |
41 | absl::Notification n; |
42 | Status status; |
43 | CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor, |
44 | [&](const Status& s) { |
45 | status = s; |
46 | n.Notify(); |
47 | }); |
48 | n.WaitForNotification(); |
49 | return status; |
50 | } |
51 | |
52 | Status DeviceContext::CopyCPUTensorToDeviceSync(const Tensor* cpu_tensor, |
53 | Device* device, |
54 | Tensor* device_tensor) const { |
55 | absl::Notification n; |
56 | Status status; |
57 | CopyCPUTensorToDevice(cpu_tensor, device, device_tensor, |
58 | [&](const Status& s) { |
59 | status = s; |
60 | n.Notify(); |
61 | }); |
62 | n.WaitForNotification(); |
63 | return status; |
64 | } |
65 | |
66 | const DeviceAttributes& DeviceBase::attributes() const { |
67 | LOG(FATAL) << "DeviceBase does not implement attributes()" ; // Crash OK |
68 | std::abort(); |
69 | } |
70 | |
71 | const string& DeviceBase::name() const { |
72 | LOG(FATAL) << "DeviceBase does not implement name()" ; // Crash OK |
73 | std::abort(); |
74 | } |
75 | |
76 | const DeviceNameUtils::ParsedName& DeviceBase::parsed_name() const { |
77 | LOG(FATAL) << "DeviceBase does not implement parsed_name()" ; // Crash OK |
78 | std::abort(); |
79 | } |
80 | |
81 | void DeviceBase::set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) { |
82 | // Eigen::ThreadPoolDevice is a very cheap struct (two pointers and |
83 | // an int). Therefore, we can afford a pre-allocated array of |
84 | // Eigen::ThreadPoolDevice. Here, we ensure that |
85 | // Eigen::ThreadPoolDevices in eigen_cpu_devices_ has increasingly |
86 | // larger numThreads. |
87 | for (int i = 1; i <= d->numThreads(); ++i) { |
88 | eigen_cpu_devices_.push_back(new Eigen::ThreadPoolDevice( |
89 | d->getPool(), i /* numThreads() */, d->allocator())); |
90 | } |
91 | } |
92 | |
93 | const Eigen::ThreadPoolDevice* DeviceBase::eigen_cpu_device() { |
94 | // Based on GetPerThreadMaxParallelism(), we return a different |
95 | // pre-allocated Eigen::ThreadPoolDevice. All these ThreadPoolDevice |
96 | // use the same underlying threadpool. But they use different |
97 | // nominal numThreads() hoping that the user of the returned |
98 | // Eigen::ThreadPoolDevice may not aggressively occupy all the |
99 | // threads in the underlying threadpool. |
100 | const int parallelism = std::max<int>( |
101 | 1, |
102 | std::min<int>(GetPerThreadMaxParallelism(), eigen_cpu_devices_.size())); |
103 | return eigen_cpu_devices_[parallelism - 1]; |
104 | } |
105 | |
106 | namespace { |
107 | |
108 | absl::flat_hash_set<std::string>* GetSymbolicDeviceList() { |
109 | static absl::flat_hash_set<std::string>* symbolic_device_list = |
110 | new absl::flat_hash_set<std::string>(); |
111 | return symbolic_device_list; |
112 | } |
113 | |
114 | } // namespace |
115 | |
116 | void AddSymbolicExecutionDevice(const absl::string_view device_name) { |
117 | GetSymbolicDeviceList()->insert(std::string(device_name)); |
118 | } |
119 | |
120 | bool IsSymbolicExecutionDevice(const absl::string_view device_name) { |
121 | absl::flat_hash_set<std::string>* symbolic_devices = GetSymbolicDeviceList(); |
122 | if (symbolic_devices->contains(device_name)) { |
123 | return true; |
124 | } else { |
125 | return false; |
126 | } |
127 | } |
128 | |
129 | } // namespace tensorflow |
130 | |