1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // A Device is something that can perform computations as part of a |
17 | // model. Devices can be local (running computations on this machine) or |
18 | // remote (contacting a device local to another machine via an RPC to |
19 | // do the work). Devices are registered in a DeviceSet, which is also |
20 | // responsible for the Device <-> id mapping. |
21 | // |
22 | // Device names |
23 | // * Every Device should have a unique name with the format: |
24 | // /job:___/replica:___/task:___/device:(GPU|CPU):___ |
25 | // An example name would be "/job:train/replica:0/task:3/device:GPU:2". |
26 | // * Task numbers are within the specified replica, so there are as |
27 | // many "task zeros" as replicas. |
28 | |
29 | #ifndef TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_ |
30 | #define TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_ |
31 | |
32 | #include <memory> |
33 | #include <string> |
34 | |
35 | #include "tensorflow/core/framework/allocator.h" |
36 | #include "tensorflow/core/framework/control_flow.h" |
37 | #include "tensorflow/core/framework/device_attributes.pb.h" |
38 | #include "tensorflow/core/framework/device_base.h" |
39 | #include "tensorflow/core/framework/graph.pb.h" |
40 | #include "tensorflow/core/framework/op_kernel.h" |
41 | #include "tensorflow/core/framework/op_segment.h" |
42 | #include "tensorflow/core/framework/resource_mgr.h" |
43 | #include "tensorflow/core/framework/types.h" |
44 | #include "tensorflow/core/graph/graph.h" |
45 | #include "tensorflow/core/graph/types.h" |
46 | #include "tensorflow/core/platform/errors.h" |
47 | #include "tensorflow/core/platform/macros.h" |
48 | #include "tensorflow/core/platform/status.h" |
49 | #include "tensorflow/core/platform/types.h" |
50 | #include "tensorflow/core/util/device_name_utils.h" |
51 | |
52 | namespace tensorflow { |
53 | |
54 | class Device : public DeviceBase { |
55 | public: |
56 | // Callback type that takes a Status and returns void. |
57 | typedef std::function<void(const Status&)> DoneCallback; |
58 | |
59 | Device(Env* env, const DeviceAttributes& device_attributes); |
60 | ~Device() override; |
61 | |
62 | // Full name of this device (see top comment). |
63 | const std::string& name() const override { return device_attributes_.name(); } |
64 | |
65 | // Parsed name of this device |
66 | const DeviceNameUtils::ParsedName& parsed_name() const { |
67 | return parsed_name_; |
68 | } |
69 | |
70 | // Describes what kind of device this is. This is intended to be |
71 | // human-readable and not computer-parsed, except that two devices |
72 | // with the same device_type() are expected to perform similarly |
73 | // (both from a computation and communication perspective). |
74 | const std::string& device_type() const { |
75 | return device_attributes_.device_type(); |
76 | } |
77 | |
78 | // Returns an aggregation of device attributes. |
79 | const DeviceAttributes& attributes() const override { |
80 | return device_attributes_; |
81 | } |
82 | |
83 | // Performs the actual compute function. |
84 | // |
85 | // Subclasses may override this function if they wish to perform |
86 | // some initialization before each compute. |
87 | virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) { |
88 | op_kernel->Compute(context); |
89 | } |
90 | |
91 | // Asynchronous kernel's compute. |
92 | virtual void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, |
93 | AsyncOpKernel::DoneCallback done) { |
94 | op_kernel->ComputeAsync(context, std::move(done)); |
95 | } |
96 | |
97 | // Blocks until all operations queued on the device at the time of |
98 | // the call have completed. Returns any error pending on the device |
99 | // at completion. |
100 | virtual Status Sync() = 0; |
101 | |
102 | // Calls the given callback when all operations queued on the device at the |
103 | // time of the call have completed. The callback is passed any error pending |
104 | // on the device at completion. |
105 | // TODO(b/112409994): Consolidate these two APIs, removing the synchronous |
106 | // version. |
107 | virtual void Sync(const DoneCallback& done); |
108 | |
109 | // On session completion, the executor may call Device::Sync() depending on |
110 | // flag settings. Override this to return false for devices that don't allow |
111 | // such calls. Instead, these devices must use other mechanisms (such as |
112 | // num_deferred_ops) to ensure the device has finished processing necessary |
113 | // work at session completion. In addition, for these devices, RefreshStatus |
114 | // must be called at session completion to retrieve execution result status. |
115 | // |
116 | // Devices that override this function must also implement RefreshStatus. |
117 | virtual bool AllowsSyncOnCompletion() const { return true; } |
118 | |
119 | // This is used in conjunction with AllowsSyncOnCompletion to allow the |
120 | // executor to get execution result status at session completion. |
121 | // |
122 | // For supported devices, this call returns the underlying device stream's |
123 | // current status in a non-blocking way, without using blocking calls such as |
124 | // Stream::BlockHostUntilDone or Device::Sync. When applicable, the device |
125 | // status is also updated with the retrieved stream status. |
126 | virtual Status RefreshStatus() { |
127 | return errors::Unimplemented( |
128 | "RefreshStatus is not supported on this device." ); |
129 | } |
130 | |
131 | // Optionally modify the device's GraphDef before execution. |
132 | // |
133 | // This method should be considered experimental and is supplied to enable |
134 | // prototyping of TensorFlow device implementations that need to modify |
135 | // the GraphDef before execution. |
136 | // |
137 | // 'graph' supplies the partition of the graph assigned to this |
138 | // device. |
139 | virtual Status MaybeRewriteGraph(std::unique_ptr<Graph>* /*graph*/) { |
140 | return OkStatus(); |
141 | } |
142 | |
143 | // Sets `out_context` a new DeviceContext* for executing a graph, or nullptr |
144 | // if the device does not support contexts. Returns an error status if any |
145 | // error occurred while trying to create a context, otherwise OK. |
146 | // |
147 | // The caller takes ownership of one reference on the output DeviceContext*, |
148 | // and should call Unref(). |
149 | virtual Status TryGetDeviceContext(DeviceContext** out_context) { |
150 | *out_context = nullptr; |
151 | return OkStatus(); |
152 | } |
153 | |
154 | // Returns the op segment of this device. The caller can reuse op |
155 | // kernels registered for the same session running on this device. |
156 | OpSegment* op_segment() { return &op_seg_; } |
157 | |
158 | // Returns the resource manager associated w/ this device. |
159 | virtual ResourceMgr* resource_manager() { return rmgr_; } |
160 | |
161 | // Summarizes the status of this Device, for debugging. |
162 | std::string DebugString() const { return device_attributes_.DebugString(); } |
163 | |
164 | // Assembles the parameter components into a complete DeviceAttributes value. |
165 | static DeviceAttributes BuildDeviceAttributes( |
166 | const std::string& name, DeviceType device, Bytes memory_limit, |
167 | const DeviceLocality& locality, const std::string& physical_device_desc); |
168 | |
169 | static DeviceAttributes BuildDeviceAttributes( |
170 | const std::string& name, DeviceType device, Bytes memory_limit, |
171 | const DeviceLocality& locality) { |
172 | // Pass in an empty string as physical device name. |
173 | return BuildDeviceAttributes(name, device, memory_limit, locality, "" ); |
174 | } |
175 | |
176 | // Updates `attributes()`, indicating the XLA global ID associated with this |
177 | // device. This ID is unique across clients in a multi-client setup. For TPUs |
178 | // this does not happen until the TPU system has been initialized. |
179 | void set_xla_global_id(int64_t id) override { |
180 | device_attributes_.set_xla_global_id(id); |
181 | } |
182 | |
183 | // Clears the resource manager associated with this device. |
184 | void ClearResourceMgr() { rmgr_->Clear(); } |
185 | |
186 | virtual bool IsLocal() const { return true; } |
187 | |
188 | // Informs if this Device can be used as a caller in RemoteCall operation. |
189 | virtual bool IsRemoteCallAllowed() const; |
190 | |
191 | protected: |
192 | void DeleteResourceMgr() { |
193 | delete rmgr_; |
194 | rmgr_ = nullptr; |
195 | } |
196 | |
197 | private: |
198 | DeviceAttributes device_attributes_; |
199 | DeviceNameUtils::ParsedName parsed_name_; |
200 | |
201 | // op_seg_ maps session handle and op name to OpKernel objects. |
202 | OpSegment op_seg_; |
203 | |
204 | // Resources associated w/ this device. E.g., shared variables, etc. |
205 | ResourceMgr* rmgr_ = nullptr; |
206 | |
207 | TF_DISALLOW_COPY_AND_ASSIGN(Device); |
208 | }; |
209 | |
210 | } // namespace tensorflow |
211 | |
212 | #endif // TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_ |
213 | |