1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
// A Device is something that can perform computations as part of a
// model. Devices can be local (run computation on this machine), or
// remote (contact a device local to another machine using an RPC to
// do the work). Devices are registered in a DeviceSet, which is also
// responsible for the Device <-> id mapping.
21//
// Device names
// * Every Device should have a unique name with the format:
//     /job:___/replica:___/task:___/device:(GPU|CPU):___
//   An example name would be "/job:train/replica:0/task:3/device:GPU:2".
// * Task numbers are within the specified replica, so there are as
//   many "task zeros" as replicas.
28
29#ifndef TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_
30#define TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_
31
32#include <memory>
33#include <string>
34
35#include "tensorflow/core/framework/allocator.h"
36#include "tensorflow/core/framework/control_flow.h"
37#include "tensorflow/core/framework/device_attributes.pb.h"
38#include "tensorflow/core/framework/device_base.h"
39#include "tensorflow/core/framework/graph.pb.h"
40#include "tensorflow/core/framework/op_kernel.h"
41#include "tensorflow/core/framework/op_segment.h"
42#include "tensorflow/core/framework/resource_mgr.h"
43#include "tensorflow/core/framework/types.h"
44#include "tensorflow/core/graph/graph.h"
45#include "tensorflow/core/graph/types.h"
46#include "tensorflow/core/platform/errors.h"
47#include "tensorflow/core/platform/macros.h"
48#include "tensorflow/core/platform/status.h"
49#include "tensorflow/core/platform/types.h"
50#include "tensorflow/core/util/device_name_utils.h"
51
52namespace tensorflow {
53
54class Device : public DeviceBase {
55 public:
56 // Callback type that takes a Status and returns void.
57 typedef std::function<void(const Status&)> DoneCallback;
58
59 Device(Env* env, const DeviceAttributes& device_attributes);
60 ~Device() override;
61
62 // Full name of this device (see top comment).
63 const std::string& name() const override { return device_attributes_.name(); }
64
65 // Parsed name of this device
66 const DeviceNameUtils::ParsedName& parsed_name() const {
67 return parsed_name_;
68 }
69
70 // Describes what kind of device this is. This is intended to be
71 // human-readable and not computer-parsed, except that two devices
72 // with the same device_type() are expected to perform similarly
73 // (both from a computation and communication perspective).
74 const std::string& device_type() const {
75 return device_attributes_.device_type();
76 }
77
78 // Returns an aggregation of device attributes.
79 const DeviceAttributes& attributes() const override {
80 return device_attributes_;
81 }
82
83 // Performs the actual compute function.
84 //
85 // Subclasses may override this function if they wish to perform
86 // some initialization before each compute.
87 virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) {
88 op_kernel->Compute(context);
89 }
90
91 // Asynchronous kernel's compute.
92 virtual void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
93 AsyncOpKernel::DoneCallback done) {
94 op_kernel->ComputeAsync(context, std::move(done));
95 }
96
97 // Blocks until all operations queued on the device at the time of
98 // the call have completed. Returns any error pending on the device
99 // at completion.
100 virtual Status Sync() = 0;
101
102 // Calls the given callback when all operations queued on the device at the
103 // time of the call have completed. The callback is passed any error pending
104 // on the device at completion.
105 // TODO(b/112409994): Consolidate these two APIs, removing the synchronous
106 // version.
107 virtual void Sync(const DoneCallback& done);
108
109 // On session completion, the executor may call Device::Sync() depending on
110 // flag settings. Override this to return false for devices that don't allow
111 // such calls. Instead, these devices must use other mechanisms (such as
112 // num_deferred_ops) to ensure the device has finished processing necessary
113 // work at session completion. In addition, for these devices, RefreshStatus
114 // must be called at session completion to retrieve execution result status.
115 //
116 // Devices that override this function must also implement RefreshStatus.
117 virtual bool AllowsSyncOnCompletion() const { return true; }
118
119 // This is used in conjunction with AllowsSyncOnCompletion to allow the
120 // executor to get execution result status at session completion.
121 //
122 // For supported devices, this call returns the underlying device stream's
123 // current status in a non-blocking way, without using blocking calls such as
124 // Stream::BlockHostUntilDone or Device::Sync. When applicable, the device
125 // status is also updated with the retrieved stream status.
126 virtual Status RefreshStatus() {
127 return errors::Unimplemented(
128 "RefreshStatus is not supported on this device.");
129 }
130
131 // Optionally modify the device's GraphDef before execution.
132 //
133 // This method should be considered experimental and is supplied to enable
134 // prototyping of TensorFlow device implementations that need to modify
135 // the GraphDef before execution.
136 //
137 // 'graph' supplies the partition of the graph assigned to this
138 // device.
139 virtual Status MaybeRewriteGraph(std::unique_ptr<Graph>* /*graph*/) {
140 return OkStatus();
141 }
142
143 // Sets `out_context` a new DeviceContext* for executing a graph, or nullptr
144 // if the device does not support contexts. Returns an error status if any
145 // error occurred while trying to create a context, otherwise OK.
146 //
147 // The caller takes ownership of one reference on the output DeviceContext*,
148 // and should call Unref().
149 virtual Status TryGetDeviceContext(DeviceContext** out_context) {
150 *out_context = nullptr;
151 return OkStatus();
152 }
153
154 // Returns the op segment of this device. The caller can reuse op
155 // kernels registered for the same session running on this device.
156 OpSegment* op_segment() { return &op_seg_; }
157
158 // Returns the resource manager associated w/ this device.
159 virtual ResourceMgr* resource_manager() { return rmgr_; }
160
161 // Summarizes the status of this Device, for debugging.
162 std::string DebugString() const { return device_attributes_.DebugString(); }
163
164 // Assembles the parameter components into a complete DeviceAttributes value.
165 static DeviceAttributes BuildDeviceAttributes(
166 const std::string& name, DeviceType device, Bytes memory_limit,
167 const DeviceLocality& locality, const std::string& physical_device_desc);
168
169 static DeviceAttributes BuildDeviceAttributes(
170 const std::string& name, DeviceType device, Bytes memory_limit,
171 const DeviceLocality& locality) {
172 // Pass in an empty string as physical device name.
173 return BuildDeviceAttributes(name, device, memory_limit, locality, "");
174 }
175
176 // Updates `attributes()`, indicating the XLA global ID associated with this
177 // device. This ID is unique across clients in a multi-client setup. For TPUs
178 // this does not happen until the TPU system has been initialized.
179 void set_xla_global_id(int64_t id) override {
180 device_attributes_.set_xla_global_id(id);
181 }
182
183 // Clears the resource manager associated with this device.
184 void ClearResourceMgr() { rmgr_->Clear(); }
185
186 virtual bool IsLocal() const { return true; }
187
188 // Informs if this Device can be used as a caller in RemoteCall operation.
189 virtual bool IsRemoteCallAllowed() const;
190
191 protected:
192 void DeleteResourceMgr() {
193 delete rmgr_;
194 rmgr_ = nullptr;
195 }
196
197 private:
198 DeviceAttributes device_attributes_;
199 DeviceNameUtils::ParsedName parsed_name_;
200
201 // op_seg_ maps session handle and op name to OpKernel objects.
202 OpSegment op_seg_;
203
204 // Resources associated w/ this device. E.g., shared variables, etc.
205 ResourceMgr* rmgr_ = nullptr;
206
207 TF_DISALLOW_COPY_AND_ASSIGN(Device);
208};
209
210} // namespace tensorflow
211
212#endif // TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_
213