/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_

#include <algorithm>
#include <cstddef>
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/types/variant.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
#include "tensorflow/core/common_runtime/function.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
#endif  // IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/tensor.h"

#include "tensorflow/core/lib/core/stringpiece.h"

#include "tensorflow/core/platform/casts.h"  // for down_cast, used below
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {

class EagerContext;

// Associates a Tensor and a Device, used in the eager runtime. Internal version
// of the TFE_TensorHandle struct and the python EagerTensor class
// (unrelated to python TensorHandle).
class TensorHandle : public ImmediateExecutionTensorHandle {
  // TensorHandle for dtype != DT_RESOURCE
  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
               Device* resource_device, EagerContext* ctx);
  // TensorHandle for dtype == DT_RESOURCE
  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
               EagerContext* ctx);
  TensorHandle(Device* d, Device* op_device, Device* resource_device,
               tensorflow::DataType dtype, EagerContext* ctx);

#if !defined(IS_MOBILE_PLATFORM)
  TensorHandle(int64_t op_id, int32_t output_num, const string& remote_task,
               tensorflow::DataType dtype, Device* device, EagerContext* ctx,
               const bool unknown_device);
  TensorHandle(int64_t op_id, int32_t output_num, tensorflow::DataType dtype,
               Device* device, const bool is_ready, EagerContext* ctx);
#endif  // IS_MOBILE_PLATFORM

 public:
  // TensorHandle with no assigned device
  static TensorHandle* CreateLocalHandle(const tensorflow::Tensor& t);
  static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                         Device* op_device, EagerContext* ctx);
  static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                         Device* op_device,
                                         Device* resource_device,
                                         EagerContext* ctx);
  static TensorHandle* CreateEmptyLocalHandle(Device* d, Device* op_device,
                                              Device* resource_device,
                                              tensorflow::DataType dtype,
                                              EagerContext* ctx);
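
  // A minimal usage sketch (illustrative, not part of the API): assuming
  // `ctx` is a live EagerContext* and `cpu` is a local Device*, a handle can
  // be created from a tensor and released when no longer needed:
  //
  //   tensorflow::Tensor t(DT_FLOAT, TensorShape({2, 2}));
  //   TensorHandle* h =
  //       TensorHandle::CreateLocalHandle(std::move(t), cpu, cpu, ctx);
  //   ...
  //   h->Release();  // Drops this reference.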

  // Create a handle which packs the given handles of the same dtype and shape.
  // If handles are on different devices, assign the packed handle to a
  // CompositeDevice.
  //
  // The new tensor handle shares ownership of the given handles: their
  // reference count will be increased by one after a call to
  // `CreatePackedHandle`.
  // TODO(b/170414377): Use `TensorHandlePtr` instead.
  static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                   const tensorflow::DataType dtype,
                                   const tensorflow::TensorShape& shape,
                                   const string& device_name, EagerContext* ctx,
                                   TensorHandle** packed_handle);
  static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                   EagerContext* ctx,
                                   TensorHandle** packed_handle);
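
  // A hedged sketch of packing two handles of the same dtype and shape
  // (`h0`, `h1`, and `ctx` are assumed to exist; error handling elided):
  //
  //   TensorHandle* packed = nullptr;
  //   Status s = TensorHandle::CreatePackedHandle({h0, h1}, ctx, &packed);
  //   // On success, `packed` holds an extra reference on h0 and h1.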

#if !defined(IS_MOBILE_PLATFORM)
  // An unshaped remote handle refers to a tensor on a remote worker. It's not
  // ready until the shape is set. It controls the lifetime of the remote
  // tensor; see the flow sketch at the end of this block.
  static TensorHandle* CreateUnshapedRemoteHandle(
      int64_t op_id, int32_t output_num, const string& remote_task,
      tensorflow::DataType dtype, Device* d, EagerContext* ctx,
      const bool unknown_device = false);
  // A lazy remote handle refers to a tensor on a remote worker. The lifetime of
  // the remote tensor is controlled by the remote worker, but not by the lazy
  // remote handle. Lazy handles are normally created on a default function
  // device.
  static TensorHandle* CreateLazyRemoteHandle(int64_t op_id, int32_t output_num,
                                              tensorflow::DataType dtype,
                                              Device* d, const bool is_ready,
                                              EagerContext* ctx);
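
  // Illustrative flow only (`op_id`, `remote_task`, `remote_device`, and
  // `view_id` are assumptions): an unshaped remote handle is created for an
  // op output and becomes ready once SetRemoteShape reports the shape:
  //
  //   TensorHandle* h = TensorHandle::CreateUnshapedRemoteHandle(
  //       op_id, /*output_num=*/0, remote_task, DT_FLOAT, remote_device, ctx);
  //   ...
  //   Status s =
  //       h->SetRemoteShape(TensorShape({128}), remote_device, view_id);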
#endif  // IS_MOBILE_PLATFORM

  void Release() override;

  tensorflow::DataType DataType() const override;
  Status Shape(tensorflow::PartialTensorShape* shape) const override;
  Status NumDims(int* num_dims) const override;
  Status NumElements(int64_t* num_elements) const override;
  Status Dim(int dim_index, int64_t* dim) const override;

  const char* DeviceName(Status* status) const override;
  const char* BackingDeviceName(Status* status) const override;
  const char* DeviceType(Status* status) const override;
  int DeviceId(Status* status) const override;
  AbstractTensorInterface* Resolve(Status* status) override;

  ImmediateExecutionTensorHandle* Copy() override;

  // Subclasses may return true to instruct the string formatter
  // to use SummarizeValue instead of the NumPy formatter.
  bool PreferCustomSummarizer() const override {
    return dtype == DT_VARIANT || dtype == DT_RESOURCE;
  }

  // Return the Tensor from the default device.
  Status Tensor(const tensorflow::Tensor** t) const;
  // Return the Tensor from the specified device, which could be either the
  // default device or a local mirror. The device pointer should be nullptr if
  // requesting the HostCPU (see the sketch below).
  Status TensorFromDevice(const Device* d, const tensorflow::Tensor** t) const;

  // Return the TensorValue from the specified device, which could be either
  // the default device or a local mirror. The device pointer should be nullptr
  // if requesting the HostCPU.
  Status TensorValue(const Device* d, tensorflow::TensorValue* t);
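
  // Usage sketch (assumes a ready handle `h`): passing nullptr as the device
  // requests the HostCPU-resident tensor.
  //
  //   const tensorflow::Tensor* t = nullptr;
  //   Status s = h->TensorFromDevice(/*d=*/nullptr, &t);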

  Device* device() const { return device_; }
  Device* op_device() const { return op_device_; }
  Device* resource_device() const { return resource_device_; }
  int64_t resource_remote_device_incarnation() const {
    return resource_remote_device_incarnation_;
  }

  // If the devices are unknown at creation time, block until the actual
  // devices are set (data is ready).
  Status WaitUnknownDevice() const;

  Device* DeviceOrHostCPU(const EagerContext& ctx) const;

  Status Shape(tensorflow::TensorShape* shape);

  Status Unprotect(const Device* d);

  // Checks if a mirror tensor exists for the specified device. Mirrors are
  // only maintained for local devices, like CPUs & GPUs. Note a mirror may be
  // empty, as it is still to be set by an async operation.
  bool HasLocalMirror(const Device* d) const;
  // Add an empty mirror placeholder for the specified device. The expectation
  // is this will be populated by a call to SetTensor, as shown in the sketch
  // below.
  Status AddEmptyLocalMirror(const Device* d);
  // Add a local mirror. This will fail if an empty local mirror was previously
  // added. For that case, SetTensor should be used instead.
  Status AddLocalMirror(tensorflow::Tensor&& tensor, const Device* d);
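
  // Sketch of the mirror-population flow referenced above (`h`, `d`, and `t`
  // are assumptions; errors propagate via TF_RETURN_IF_ERROR):
  //
  //   TF_RETURN_IF_ERROR(h->AddEmptyLocalMirror(d));      // placeholder
  //   ...                                                 // async copy runs
  //   TF_RETURN_IF_ERROR(h->SetTensor(std::move(t), d));  // mirror is ready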

#if !defined(IS_MOBILE_PLATFORM)
  bool HasRemoteMirror(const Device* d, uint64 context_view_id) const;
  bool HasResourceShapeMirror(const Device* d, uint64 context_view_id) const;

  Status AddUnshapedRemoteMirror(const Device* d, int64_t op_id, int output_num,
                                 const string& remote_task, EagerContext* ctx);
  Status AddResourceShapeMirror(const Device* d, int64_t op_id, int output_num,
                                EagerContext* ctx);

  // Return the op_id and output num if the handle refers to a remote tensor.
  // If wait_until_ready is true, block until the remote tensor is ready on the
  // given remote worker.
  Status RemoteAddress(const Device* d, const bool wait_until_ready,
                       int64_t* op_id, int32* output_num) const;

  // Called on an async remote tensor once its shape has been determined. This
  // transitions the tensor handle from a non-ready to a ready state by
  // replacing the backing data abstraction to allow for the shape to be
  // queried.
  // This method or Poison must be called exactly once for remote tensors that
  // were created without a known shape.
  Status SetRemoteShape(const TensorShape& shape, const Device* d,
                        uint64 context_view_id);
  // If op_device is not empty, reset the devices of a remote tensor which is
  // created without known devices (e.g. function outputs).
  Status SetRemoteShapeAndDevice(const TensorShape& shape, const Device* d,
                                 uint64 context_view_id, string op_device);

  // Poisons either this handle or a remote mirror with error `status`.
  // Poisoning means that the handle will become ready and methods trying
  // to access the remote shape will return this error `status`.
  // Exactly one of the SetRemoteShape or PoisonRemote methods must be called
  // on an unshaped handle on a remote device, as the sketch below illustrates.
  void PoisonRemote(Status status, const Device* d, uint64 context_view_id);
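
  // Sketch of the contract above for an unshaped remote handle `h` (names are
  // assumptions): the producer resolves it exactly once, either
  //
  //   h->SetRemoteShape(shape, remote_device, view_id);       // success
  // or
  //   h->PoisonRemote(error_status, remote_device, view_id);  // failure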
#endif  // IS_MOBILE_PLATFORM

  // Sets the `tensor` for this async non-ready handle, making it ready.
  // This method or Poison must be called exactly once for non-ready async
  // handles to make them ready.
  Status SetTensor(tensorflow::Tensor&& tensor, const Device* d);

  // Poisons either this handle or a local mirror with error `status`.
  // Poisoning means that the handle will become ready and methods trying
  // to access the actual tensor or shape will return this error `status`.
  // Exactly one of the SetTensor or Poison methods must be called on a
  // non-ready tensor for a specific device; see the sketch below.
  void Poison(Status status, const Device* d);
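
  // Sketch of the contract above for a non-ready local handle `h` (names are
  // assumptions): exactly one of the following runs for device `d`:
  //
  //   h->SetTensor(std::move(result), d);  // success: handle becomes ready
  // or
  //   h->Poison(error_status, d);          // failure: readers see the error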

  // TODO(b/154282629): Consider moving it to EagerContext.
  // Copies the tensor to the given device `d`, or to the host iff `d` is
  // null; see the usage sketch below.
  Status CopyToDevice(const EagerContext& ctx, tensorflow::Device* d,
                      tensorflow::Tensor* output) const;
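
  // Usage sketch (`ctx`, destination `dst`, and a ready handle `h` are
  // assumptions):
  //
  //   tensorflow::Tensor out;
  //   Status s = h->CopyToDevice(*ctx, dst, &out);  // dst == nullptr => host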

  Status InferenceShape(
      shape_inference::InferenceContext* const inference_context,
      shape_inference::ShapeHandle* shape_handle);
  void SetInferenceShape(
      shape_inference::InferenceContext* const inference_context,
      const shape_inference::ShapeHandle& shape_handle);
  Status CopyInferenceShape(TensorHandle* other);

  // dtype for the handle. It must be the same as t.dtype() once the handle is
  // ready.
  const tensorflow::DataType dtype;

  enum HandleType { LOCAL = 0, PACKED = 1, REMOTE = 2 };

  HandleType Type() const;
  string TypeString() const;

  void SetResourceHandleDtypeAndShape(
      std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);

  // If this TensorHandle is 1) a local tensor, and 2) a resource handle,
  // return data types and shapes of the underlying resource.
  Status GetResourceHandleDtypesAndShapes(
      std::vector<DtypeAndPartialTensorShape>* result);

  // Returns the number of packed handles, or 0 if the handle type is not
  // PACKED.
  int NumPackedHandles() const;
  // Must be called on a packed TensorHandle: extracts the handle at the given
  // index. See the iteration sketch below.
  Status ExtractPackedHandle(const int index, TensorHandle** handle) const;
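
  // Iteration sketch referenced above (error handling elided; whether the
  // extracted handle needs an extra Ref depends on the caller's lifetime
  // needs):
  //
  //   for (int i = 0; i < h->NumPackedHandles(); ++i) {
  //     TensorHandle* component = nullptr;
  //     TF_RETURN_IF_ERROR(h->ExtractPackedHandle(i, &component));
  //     // Use `component`; it is kept alive by the packed handle `h`.
  //   }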

  // For LLVM style RTTI.
  static bool classof(const AbstractTensorHandle* ptr) {
    return ptr->getKind() == kEager;
  }
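
  // With LLVM-style RTTI, a caller holding an AbstractTensorHandle* can check
  // the kind before downcasting (a sketch; `abstract` is an assumption):
  //
  //   if (TensorHandle::classof(abstract)) {
  //     TensorHandle* h = down_cast<TensorHandle*>(abstract);
  //   }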

 private:
  friend class PackedTensorHandleTest;

  TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
               const tensorflow::DataType dtype,
               const tensorflow::TensorShape& shape, EagerContext* ctx);

  ~TensorHandle() override;

  // The TensorHandleData can either represent a local or remote tensor handle.
  // Further, it can be in a non-ready state. It would become ready with a call
  // to either SetTensor or SetRemoteShape, which replaces the underlying data
  // with a ready version of the tensor handle data.
  bool IsReady() const;
  Status WaitReady(const char* caller) const;

  tensorflow::Device* device_;

  // Device on which the op producing this tensor was executed. Equal to
  // device_ for constant tensors.
  // Can be nullptr if the op producing this tensor was a function executed
  // with the function library runtime.
  tensorflow::Device* op_device_;

  // If the tensor dtype is DT_RESOURCE, resource_device_ holds the device
  // backing the resource. Else resource_device_ is nullptr.
  tensorflow::Device* resource_device_;
  // Incarnation ID of the resource device if it resides on a remote device, or
  // 0 if it resides on a local device.
  int64_t resource_remote_device_incarnation_;

  // If true, the handle refers to a remote tensor which is created without
  // known devices. The actual devices are set by SetRemoteShape. The devices
  // should be accessed once the handle is ready.
  const bool unknown_device_ = false;

  mutable mutex mu_;

  // Map of local mirrors. This can include both ready and non-ready mirrors.
  std::unordered_map<const tensorflow::Device*, LocalTensorHandleData>
      local_mirrors_ TF_GUARDED_BY(mu_);
#if !defined(IS_MOBILE_PLATFORM)
  // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
  // variable is ready, since we could get the shape locally without remote copy
  // then.
  std::unordered_map<string, RemoteTensorHandleData> resource_shape_mirrors_
      TF_GUARDED_BY(mu_);
  // TODO(gjn): Is std::unordered_map the most optimal choice here? Perhaps
  // this should be a fixed-size map.
  std::unordered_map<string, RemoteTensorHandleData> remote_mirrors_
      TF_GUARDED_BY(mu_);
#endif  // IS_MOBILE_PLATFORM

  // `ctx` is only guaranteed to be set if the handle is not "ready". This is
  // typically true when the handle was produced during async execution.
  // The `ctx` object is not owned and should outlive this handle.
  //
  // TODO(b/150614042): Reference count EagerContext to ensure that 'device_' of
  // a TensorHandle does not outlive the EagerContext from which it came?
  EagerContext* const ctx_;

  // Does not need synchronization because it can be accessed only after
  // WaitReady() has returned. At that point, is_poisoned_ is immutable.
  Status is_poisoned_;

  // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
  // refers to a remote resource handle, we store data types and shapes for
  // the underlying resource.
  std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;

  // A handle data which refers to multiple TensorHandles of the same dtype and
  // shape.
  class PackedTensorHandleData {
   public:
    // Initialize handle data from a list of tensor handles.
    // Ownership of the tensor handles is shared between the
    // `PackedTensorHandleData` and the caller (the reference count for the
    // given handles is incremented).
    // TODO(b/170414377): Use `TensorHandlePtr` instead.
    PackedTensorHandleData(std::vector<TensorHandle*>&& handles,
                           const TensorShape& shape);

    ~PackedTensorHandleData();

    Status Shape(TensorShape* shape) const;
    Status NumDims(int* num_dims) const;
    Status Dim(int dim_index, int64_t* dim) const;
    Status NumElements(int64_t* num_elements) const;
    Status Unprotect();
    bool IsReady() const;
    Status WaitReady(const char* caller) const;
    void Poison(Status status);
    string DebugString() const;

    // Number of packed handles.
    int NumPackedHandles() const;
    // Extract the handle at the given index.
    Status ExtractPackedHandle(const int index, TensorHandle** handle) const;

   private:
    // TODO(b/170414377): Use `TensorHandlePtr` instead.
    const std::vector<TensorHandle*> handles_;
    const TensorShape shape_;

    mutable mutex mu_;
    Status is_poisoned_ TF_GUARDED_BY(mu_);
  };

  // Does not need synchronization because it can be accessed only after
  // WaitReady() has returned. At that point, data_ is immutable.
#if !defined(IS_MOBILE_PLATFORM)
  absl::variant<LocalTensorHandleData, PackedTensorHandleData,
                RemoteTensorHandleData>
      data_;
#else
  absl::variant<LocalTensorHandleData, PackedTensorHandleData> data_;
#endif  // IS_MOBILE_PLATFORM

  PartialTensorShape inference_shape_;
};

// Returns the device backing the resource referenced by `handle`, or nullptr
// if there is none.
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);

template <typename T>
inline TensorHandle* TensorHandleFromInterface(T* handle) {
  return down_cast<TensorHandle*>(handle);
}
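
// Usage sketch (assumes `ith` is an ImmediateExecutionTensorHandle* that is
// known to wrap a tensorflow::TensorHandle):
//
//   TensorHandle* h = TensorHandleFromInterface(ith);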

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_