1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_ |
17 | |
18 | #include <algorithm> |
19 | #include <cstddef> |
20 | #include <memory> |
21 | #include <queue> |
22 | #include <string> |
23 | #include <unordered_map> |
24 | #include <vector> |
25 | |
26 | // clang-format off |
27 | // Required for IS_MOBILE_PLATFORM |
28 | #include "tensorflow/core/framework/shape_inference.h" |
29 | #include "tensorflow/core/framework/tensor_shape.h" |
30 | #include "tensorflow/core/platform/platform.h" |
31 | // clang-format on |
32 | |
33 | #include "absl/types/variant.h" |
34 | #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" |
35 | #include "tensorflow/core/common_runtime/device.h" |
36 | #include "tensorflow/core/common_runtime/eager/eager_executor.h" |
37 | #include "tensorflow/core/common_runtime/eager/tensor_handle_data.h" |
38 | #include "tensorflow/core/common_runtime/function.h" |
39 | #if !defined(IS_MOBILE_PLATFORM) |
40 | #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h" |
41 | #endif // IS_MOBILE_PLATFORM |
42 | #include "tensorflow/core/framework/tensor.h" |
43 | |
44 | #include "tensorflow/core/lib/core/stringpiece.h" |
45 | |
46 | #include "tensorflow/core/platform/mutex.h" |
47 | #include "tensorflow/core/platform/thread_annotations.h" |
48 | |
49 | namespace tensorflow { |
50 | |
51 | class EagerContext; |
52 | |
// Associates a Tensor and a Device, used in the eager runtime. Internal version
// of the TFE_TensorHandle struct and the python EagerTensor class
// (unrelated to python TensorHandle).
class TensorHandle : public ImmediateExecutionTensorHandle {
  // Constructors are private; use the Create*Handle factory functions below.
  // TensorHandle for dtype != DT_RESOURCE
  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
               Device* resource_device, EagerContext* ctx);
  // TensorHandle for dtype == DT_RESOURCE
  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
               EagerContext* ctx);
  // Handle with no tensor yet (non-ready); used by CreateEmptyLocalHandle.
  TensorHandle(Device* d, Device* op_device, Device* resource_device,
               tensorflow::DataType dtype, EagerContext* ctx);

#if !defined(IS_MOBILE_PLATFORM)
  // Used by CreateUnshapedRemoteHandle.
  TensorHandle(int64_t op_id, int32_t output_num, const string& remote_task,
               tensorflow::DataType dtype, Device* device, EagerContext* ctx,
               const bool unknown_device);
  // Used by CreateLazyRemoteHandle.
  TensorHandle(int64_t op_id, int32_t output_num, tensorflow::DataType dtype,
               Device* device, const bool is_ready, EagerContext* ctx);
#endif  // IS_MOBILE_PLATFORM

 public:
  // TensorHandle with no assigned device
  static TensorHandle* CreateLocalHandle(const tensorflow::Tensor& t);
  static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                         Device* op_device, EagerContext* ctx);
  static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                         Device* op_device,
                                         Device* resource_device,
                                         EagerContext* ctx);
  static TensorHandle* CreateEmptyLocalHandle(Device* d, Device* op_device,
                                              Device* resource_device,
                                              tensorflow::DataType dtype,
                                              EagerContext* ctx);

  // Create a handle which packs the given handles of the same dtype and shape.
  // If handles are on different devices, assign the packed handle to a
  // CompositeDevice.
  //
  // The new tensor handle shares ownership of the given handle: their reference
  // count will be increased by one after a call to `CreatePackedHandle`.
  // TODO(b/170414377): Use `TensorHandlePtr` instead.
  static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                   const tensorflow::DataType dtype,
                                   const tensorflow::TensorShape& shape,
                                   const string& device_name, EagerContext* ctx,
                                   TensorHandle** packed_handle);
  static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                   EagerContext* ctx,
                                   TensorHandle** packed_handle);

#if !defined(IS_MOBILE_PLATFORM)
  // An unshaped remote handle refers to a tensor on a remote worker. It's not
  // ready until the shape is set. It controls the lifetime of the remote
  // tensor.
  static TensorHandle* CreateUnshapedRemoteHandle(
      int64_t op_id, int32_t output_num, const string& remote_task,
      tensorflow::DataType dtype, Device* d, EagerContext* ctx,
      const bool unknown_device = false);
  // A lazy remote handle refers to a tensor on a remote worker. The lifetime of
  // the remote tensor is controlled by the remote worker, but not by the lazy
  // remote handle. Lazy handles are normally created on a default function
  // device.
  static TensorHandle* CreateLazyRemoteHandle(int64_t op_id, int32_t output_num,
                                              tensorflow::DataType dtype,
                                              Device* d, const bool is_ready,
                                              EagerContext* ctx);
#endif  // IS_MOBILE_PLATFORM

  // ImmediateExecutionTensorHandle interface overrides.
  void Release() override;

  tensorflow::DataType DataType() const override;
  Status Shape(tensorflow::PartialTensorShape* shape) const override;
  Status NumDims(int* num_dims) const override;
  Status NumElements(int64_t* num_elements) const override;
  Status Dim(int dim_index, int64_t* dim) const override;

  const char* DeviceName(Status* status) const override;
  const char* BackingDeviceName(Status* status) const override;
  const char* DeviceType(Status* status) const override;
  int DeviceId(Status* status) const override;
  AbstractTensorInterface* Resolve(Status* status) override;

  ImmediateExecutionTensorHandle* Copy() override;

  // Subclasses may return True to instruct the string formatter
  // to use SummarizeValue instead of the NumPy formatter.
  bool PreferCustomSummarizer() const override {
    return dtype == DT_VARIANT || dtype == DT_RESOURCE;
  }

  // Return the Tensor from the default device.
  Status Tensor(const tensorflow::Tensor** t) const;
  // Return the Tensor from the specified device which could be either the
  // default device or a local mirror. The device pointer should be nullptr if
  // requesting the HostCPU.
  Status TensorFromDevice(const Device* d, const tensorflow::Tensor** t) const;

  // Return the TensorValue from the specified device which could be either the
  // default device or a local mirror. The device pointer should be nullptr if
  // requesting the HostCPU.
  Status TensorValue(const Device* d, tensorflow::TensorValue* t);

  Device* device() const { return device_; }
  Device* op_device() const { return op_device_; }
  Device* resource_device() const { return resource_device_; }
  int64_t resource_remote_device_incarnation() const {
    return resource_remote_device_incarnation_;
  }

  // If the devices are unknown at creation time, block until the actual devices
  // are set (data is ready).
  Status WaitUnknownDevice() const;

  Device* DeviceOrHostCPU(const EagerContext& ctx) const;

  Status Shape(tensorflow::TensorShape* shape);

  Status Unprotect(const Device* d);

  // Checks if a mirror tensor exists for the specified device. Mirrors are only
  // maintained for local devices, like CPUs & GPUs. Note a mirror may be empty,
  // as it is still to be set by an async operation.
  bool HasLocalMirror(const Device* d) const;
  // Add an empty mirror placeholder for the specified device. The expectation
  // is this will be populated by a call to SetTensor.
  Status AddEmptyLocalMirror(const Device* d);
  // Add a local mirror. This will fail if an empty local mirror was previously
  // added. For that case, SetTensor should be used instead.
  Status AddLocalMirror(tensorflow::Tensor&& tensor, const Device* d);

#if !defined(IS_MOBILE_PLATFORM)
  bool HasRemoteMirror(const Device* d, uint64 context_view_id) const;
  bool HasResourceShapeMirror(const Device* d, uint64 context_view_id) const;

  Status AddUnshapedRemoteMirror(const Device* d, int64_t op_id, int output_num,
                                 const string& remote_task, EagerContext* ctx);
  Status AddResourceShapeMirror(const Device* d, int64_t op_id, int output_num,
                                EagerContext* ctx);

  // Return the op_id and output num if the handle refers to a remote tensor.
  // If wait_until_ready is true, block until the remote tensor is ready on the
  // given remote worker.
  Status RemoteAddress(const Device* d, const bool wait_until_ready,
                       int64_t* op_id, int32* output_num) const;

  // Called on an async remote tensor once its shape has been determined. This
  // transitions the tensor handle from a non-ready to a ready state by
  // replacing the backing data abstraction to allow for the shape to be
  // queried. This covers the case where the shape is only known after
  // creating a TensorHandle (e.g. a remote output of a remote function).
  // This method or Poison must be called exactly once for remote tensors that
  // were created without a known shape.
  Status SetRemoteShape(const TensorShape& shape, const Device* d,
                        uint64 context_view_id);
  // If op_device is not empty, reset the devices of a remote tensor which is
  // created without known devices (e.g. function outputs).
  Status SetRemoteShapeAndDevice(const TensorShape& shape, const Device* d,
                                 uint64 context_view_id, string op_device);

  // Poisons either this handle or a remote mirror with error `status`.
  // Poisoning means that the handle will become ready and methods trying
  // to access the remote shape will return this error `status`.
  // Exactly one of SetRemoteShape or PoisonRemote methods must be called on an
  // unshaped handle on a remote device.
  void PoisonRemote(Status status, const Device* d, uint64 context_view_id);
#endif  // IS_MOBILE_PLATFORM

  // Sets the `tensor` for this async non-ready handle making it ready.
  // This method or Poison must be called exactly once for non-ready async
  // handles to make them ready.
  Status SetTensor(tensorflow::Tensor&& tensor, const Device* d);

  // Poisons either this handle or a local mirror with error `status`.
  // Poisoning means that the handle will become ready and methods trying
  // to access the actual tensor or shape will return this error `status`.
  // Exactly one of SetTensor or Poison methods must be called on a non-ready
  // tensor for a specific device.
  void Poison(Status status, const Device* d);

  // TODO(b/154282629): Consider moving it to EagerContext.
  // Copies to the tensor on the given device `d`, or to host iff `d` is null.
  Status CopyToDevice(const EagerContext& ctx, tensorflow::Device* d,
                      tensorflow::Tensor* output) const;

  Status InferenceShape(
      shape_inference::InferenceContext* const inference_context,
      shape_inference::ShapeHandle* shape_handle);
  void SetInferenceShape(
      shape_inference::InferenceContext* const inference_context,
      const shape_inference::ShapeHandle& shape_handle);
  Status CopyInferenceShape(TensorHandle* other);

  // dtype for the handle. It must be the same as t.dtype() once the handle is
  // ready.
  const tensorflow::DataType dtype;

  enum HandleType { LOCAL = 0, PACKED = 1, REMOTE = 2 };

  HandleType Type() const;
  string TypeString() const;

  void SetResourceHandleDtypeAndShape(
      std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);

  // If this TensorHandle is 1) a local tensor, and 2) a resource handle,
  // return data types and shapes of the underlying resource.
  Status GetResourceHandleDtypesAndShapes(
      std::vector<DtypeAndPartialTensorShape>* result);

  // Returns the number of packed handles. 0 if the handle type is not PACKED.
  int NumPackedHandles() const;
  // It's called on a packed TensorHandle. Extract a handle with the given
  // index.
  Status ExtractPackedHandle(const int index, TensorHandle** handle) const;

  // For LLVM style RTTI.
  static bool classof(const AbstractTensorHandle* ptr) {
    return ptr->getKind() == kEager;
  }

 private:
  friend class PackedTensorHandleTest;

  // Constructor for a PACKED handle; used by CreatePackedHandle.
  TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
               const tensorflow::DataType dtype,
               const tensorflow::TensorShape& shape, EagerContext* ctx);

  ~TensorHandle() override;

  // The TensorHandleData can either represent a local or remote tensor handle.
  // Further, it can be in a non-ready state. It would become ready with a call
  // to either SetTensor or SetRemoteShape which replaces the underlying data
  // with a ready version of the tensor handle data.
  bool IsReady() const;
  Status WaitReady(const char* caller) const;

  tensorflow::Device* device_;

  // Device in which the op producing this tensor was executed. Equals to
  // device_ for constant tensors.
  // Can be nullptr if the op producing this tensor was a function executed
  // with function library runtime.
  tensorflow::Device* op_device_;

  // If the tensor dtype is DT_RESOURCE, resource_device_ holds the device
  // backing the resource. Else resource_device_ is nullptr.
  tensorflow::Device* resource_device_;
  // Incarnation ID of the resource device if it locates on a remote device, or
  // 0 if it locates on a local device.
  int64_t resource_remote_device_incarnation_;

  // If true, the handle refers to a remote tensor which is created without
  // known devices. The actual devices are set by SetRemoteShape. The devices
  // should be accessed once the handle is ready.
  const bool unknown_device_ = false;

  // Guards the mirror maps below.
  mutable mutex mu_;

  // Map of local mirrors. This can include both ready and non-ready mirrors.
  std::unordered_map<const tensorflow::Device*, LocalTensorHandleData>
      local_mirrors_ TF_GUARDED_BY(mu_);
#if !defined(IS_MOBILE_PLATFORM)
  // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
  // variable is ready, since we could get the shape locally without remote copy
  // then.
  std::unordered_map<string, RemoteTensorHandleData> resource_shape_mirrors_
      TF_GUARDED_BY(mu_);
  // TODO(gjn): Is std::unordered_map the most optimal choice here? Perhaps
  // this should be a fixed size map.
  std::unordered_map<string, RemoteTensorHandleData> remote_mirrors_
      TF_GUARDED_BY(mu_);
#endif  // IS_MOBILE_PLATFORM

  // `ctx` is only guaranteed to be set if the handle is not "ready". This is
  // typically true when the handle was produced during async execution.
  // `ctx` object is not owned and should outlive this handle.
  //
  // TODO(b/150614042): Reference count EagerContext to ensure that 'device_' of
  // a TensorHandle does not outlive the EagerContext from which it came?
  EagerContext* const ctx_;

  // Does not need synchronization because it can be accessed only after
  // WaitReady() has returned. At that point, is_poisoned_ is immutable.
  Status is_poisoned_;

  // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
  // refers to a remote resource handle, we store data types and shapes for
  // the underlying resource.
  std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;

  // A handle data which refers to multiple TensorHandles of the same dtype and
  // shape.
  class PackedTensorHandleData {
   public:
    // Initialize handle data from list of tensor handles.
    // Ownership of the tensor handles is shared between the
    // `PackedTensorHandleData` and the caller (the reference count for the
    // given handles is incremented).
    // TODO(b/170414377): Use `TensorHandlePtr` instead.
    PackedTensorHandleData(std::vector<TensorHandle*>&& handles,
                           const TensorShape& shape);

    ~PackedTensorHandleData();

    Status Shape(TensorShape* shape) const;
    Status NumDims(int* num_dims) const;
    Status Dim(int dim_index, int64_t* dim) const;
    Status NumElements(int64_t* num_elements) const;
    Status Unprotect();
    bool IsReady() const;
    Status WaitReady(const char* caller) const;
    void Poison(Status status);
    string DebugString() const;

    // Number of packed handles.
    int NumPackedHandles() const;
    // Extract a handle on the given index.
    Status ExtractPackedHandle(const int index, TensorHandle** handle) const;

   private:
    // TODO(b/170414377): Use `TensorHandlePtr` instead.
    const std::vector<TensorHandle*> handles_;
    const TensorShape shape_;

    mutable mutex mu_;
    Status is_poisoned_ TF_GUARDED_BY(mu_);
  };

  // Does not need synchronization because it can be accessed only after
  // WaitReady() has returned. At that point, data_ is immutable.
#if !defined(IS_MOBILE_PLATFORM)
  absl::variant<LocalTensorHandleData, PackedTensorHandleData,
                RemoteTensorHandleData>
      data_;
#else
  absl::variant<LocalTensorHandleData, PackedTensorHandleData> data_;
#endif  // IS_MOBILE_PLATFORM

  PartialTensorShape inference_shape_;
};
394 | |
// Returns the device backing the resource referred to by `handle`, or nullptr
// if there is no such device.
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);
397 | |
// NOTE(review): This class is empty and adds nothing over its base class, so
// it appears vestigial. If no code elsewhere depends on the name, consider
// removing it -- TODO confirm against the rest of the codebase.
class TensorHandleInterface : public ImmediateExecutionTensorHandle {
 public:
};
401 | |
402 | template <typename T> |
403 | inline TensorHandle* TensorHandleFromInterface(T* handle) { |
404 | return down_cast<TensorHandle*>(handle); |
405 | } |
406 | |
407 | } // namespace tensorflow |
408 | |
409 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_ |
410 | |