1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ |
18 | |
19 | #include "tensorflow/core/common_runtime/device.h" |
20 | #include "tensorflow/core/lib/core/threadpool_interface.h" |
21 | #include "tensorflow/core/util/device_name_utils.h" |
22 | |
23 | namespace tensorflow { |
24 | |
25 | // Wraps a device with a new name, delegating work to the wrapped device. |
26 | // |
27 | // This class is used to wrap local devices when using clusterspec propagation |
28 | // where the name of a particular device may change in the context of a given |
29 | // session. |
30 | class RenamedDevice : public Device { |
31 | public: |
32 | static std::unique_ptr<Device> NewRenamedDevice( |
33 | const string& new_base, Device* underlying, bool owns_underlying, |
34 | bool isolate_session_state, |
35 | thread::ThreadPoolInterface* underlying_threadpool = nullptr); |
36 | |
37 | ~RenamedDevice() override; |
38 | |
39 | const DeviceBase* UnderlyingDevice() const override { |
40 | return underlying_device_->UnderlyingDevice(); |
41 | } |
42 | DeviceBase* UnderlyingDevice() override { |
43 | return underlying_device_->UnderlyingDevice(); |
44 | } |
45 | |
46 | const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override { |
47 | if (underlying_threadpool_) { |
48 | return Device::tensorflow_cpu_worker_threads(); |
49 | } |
50 | return underlying_device_->tensorflow_cpu_worker_threads(); |
51 | } |
52 | |
53 | const DeviceBase::AcceleratorDeviceInfo* tensorflow_accelerator_device_info() |
54 | const override { |
55 | return underlying_device_->tensorflow_accelerator_device_info(); |
56 | } |
57 | |
58 | Allocator* GetAllocator(AllocatorAttributes attr) override { |
59 | return underlying_device_->GetAllocator(attr); |
60 | } |
61 | |
62 | Allocator* GetScopedAllocator(AllocatorAttributes attr, |
63 | int64_t step_id) override { |
64 | return underlying_device_->GetScopedAllocator(attr, step_id); |
65 | } |
66 | |
67 | ScopedAllocatorMgr* GetScopedAllocatorMgr() const override { |
68 | return underlying_device_->GetScopedAllocatorMgr(); |
69 | } |
70 | |
71 | const Eigen::ThreadPoolDevice* eigen_cpu_device() override { |
72 | // Use the underlying threadpool only if the underlying device supports |
73 | // eigen_cpu_device. |
74 | if (underlying_threadpool_ && underlying_device_->has_eigen_cpu_device()) { |
75 | return Device::eigen_cpu_device(); |
76 | } |
77 | return underlying_device_->eigen_cpu_device(); |
78 | } |
79 | |
80 | thread::ThreadPool* tensorflow_device_thread_pool() override { |
81 | // Use the underlying threadpool instead of tensorflow_device_thread_pool |
82 | // of the underlying device only if tensorflow_device_thread_pool is defined |
83 | // for the underlying device. |
84 | if (underlying_threadpool_ && |
85 | underlying_device_->tensorflow_device_thread_pool() != nullptr) { |
86 | return Device::tensorflow_device_thread_pool(); |
87 | } |
88 | return underlying_device_->tensorflow_device_thread_pool(); |
89 | } |
90 | |
91 | bool has_eigen_cpu_device() const override { |
92 | return underlying_device_->has_eigen_cpu_device(); |
93 | } |
94 | |
95 | |
96 | PerOpGpuDevice* MakeGpuDevice() override { |
97 | return underlying_device_->MakeGpuDevice(); |
98 | } |
99 | |
100 | Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, |
101 | DeviceContext* dc, |
102 | Allocator* allocator) override { |
103 | return underlying_device_->ReinitializeGpuDevice(context, device, dc, |
104 | allocator); |
105 | } |
106 | |
107 | Status MakeTensorFromProto(const TensorProto& tensor_proto, |
108 | const AllocatorAttributes alloc_attrs, |
109 | Tensor* tensor) override { |
110 | return underlying_device_->MakeTensorFromProto(tensor_proto, alloc_attrs, |
111 | tensor); |
112 | } |
113 | |
114 | void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor, |
115 | const DeviceContext* device_context, |
116 | StatusCallback done) override { |
117 | underlying_device_->CopyTensorInSameDevice(input_tensor, output_tensor, |
118 | device_context, std::move(done)); |
119 | } |
120 | |
121 | // Below are virtual methods defined on Device |
122 | |
123 | void Compute(OpKernel* op_kernel, OpKernelContext* context) override { |
124 | underlying_device_->Compute(op_kernel, context); |
125 | } |
126 | |
127 | void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, |
128 | AsyncOpKernel::DoneCallback done) override { |
129 | underlying_device_->ComputeAsync(op_kernel, context, std::move(done)); |
130 | } |
131 | |
132 | Status Sync() override { return underlying_device_->Sync(); } |
133 | |
134 | Status MaybeRewriteGraph(std::unique_ptr<Graph>* graph) override { |
135 | return underlying_device_->MaybeRewriteGraph(graph); |
136 | } |
137 | |
138 | Status TryGetDeviceContext(DeviceContext** out_context) override { |
139 | return underlying_device_->TryGetDeviceContext(out_context); |
140 | } |
141 | |
142 | // Returns the resource manager associated w/ this device. |
143 | ResourceMgr* resource_manager() override { |
144 | if (isolate_session_state_) { |
145 | return Device::resource_manager(); |
146 | } else { |
147 | return underlying_device_->resource_manager(); |
148 | } |
149 | } |
150 | |
151 | bool IsLocal() const override { return underlying_device_->IsLocal(); } |
152 | |
153 | bool IsRemoteCallAllowed() const override { |
154 | return underlying_device_->IsRemoteCallAllowed(); |
155 | } |
156 | |
157 | private: |
158 | RenamedDevice(Device* underlying, const DeviceAttributes& attributes, |
159 | bool owns_underlying, bool isolate_session_state, |
160 | thread::ThreadPoolInterface* underlying_threadpool); |
161 | Device* const underlying_device_; |
162 | const bool owns_underlying_device_; |
163 | const bool isolate_session_state_; |
164 | |
165 | std::unique_ptr<thread::ThreadPool> underlying_threadpool_; |
166 | // eigen_worker_threads_ is stored here so that we can pass the pointer |
167 | // of eigen_worker_threads_.workers to the parent class. |
168 | DeviceBase::CpuWorkerThreads eigen_worker_threads_; |
169 | }; |
170 | |
171 | } // namespace tensorflow |
172 | |
173 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ |
174 | |