1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_ |
18 | |
19 | #include <limits.h> |
20 | #include <vector> |
21 | |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/partial_tensor_shape.h" |
24 | #include "tensorflow/core/framework/register_types.h" |
25 | #include "tensorflow/core/framework/resource_mgr.h" |
26 | #include "tensorflow/core/framework/tensor.h" |
27 | #include "tensorflow/core/framework/tensor_shape.h" |
28 | #include "tensorflow/core/framework/types.h" |
29 | #include "tensorflow/core/kernels/aggregate_ops.h" |
30 | #include "tensorflow/core/kernels/fill_functor.h" |
31 | #include "tensorflow/core/lib/core/errors.h" |
32 | #include "tensorflow/core/platform/logging.h" |
33 | #include "tensorflow/core/platform/types.h" |
34 | |
35 | namespace tensorflow { |
36 | |
37 | typedef Eigen::ThreadPoolDevice CPUDevice; |
38 | typedef Eigen::GpuDevice GPUDevice; |
39 | |
40 | namespace tensor_array { |
41 | |
// Full implementations are in tensor_array.cc. The unspecialized templates
// below fail with an InvalidArgument error; the macros that follow declare
// explicit specializations for each supported (Device, T) pair, which are
// defined in tensor_array.cc.
43 | template <typename Device, typename T> |
44 | Status AddToTensor(OpKernelContext* ctx, Tensor* sum, const Tensor* current, |
45 | const Tensor* add) { |
46 | return errors::InvalidArgument( |
47 | "tensor_array::AddToTensor type not supported: " , |
48 | DataTypeString(DataTypeToEnum<T>::value)); |
49 | } |
50 | |
51 | #define TENSOR_ARRAY_WRITE_OR_ADD(Device, T) \ |
52 | template <> \ |
53 | Status AddToTensor<Device, T>(OpKernelContext * ctx, Tensor * sum, \ |
54 | const Tensor* current, const Tensor* add); |
55 | |
56 | #define TENSOR_ARRAY_WRITE_OR_ADD_CPU(T) TENSOR_ARRAY_WRITE_OR_ADD(CPUDevice, T) |
57 | TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_CPU) |
58 | #undef TENSOR_ARRAY_WRITE_OR_ADD_CPU |
59 | |
60 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
61 | |
62 | #define TENSOR_ARRAY_WRITE_OR_ADD_GPU(T) TENSOR_ARRAY_WRITE_OR_ADD(GPUDevice, T) |
63 | TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU); |
64 | TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU); |
65 | #undef TENSOR_ARRAY_WRITE_OR_ADD_GPU |
66 | |
67 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
68 | |
69 | #undef TENSOR_ARRAY_WRITE_OR_ADD |
70 | |
71 | template <typename Device, typename T> |
72 | Status TensorSetZero(OpKernelContext* ctx, Tensor* value) { |
73 | return errors::InvalidArgument( |
74 | "tensor_array::TensorSetZero type not supported: " , |
75 | DataTypeString(DataTypeToEnum<T>::value)); |
76 | } |
77 | |
78 | #define TENSOR_ARRAY_SET_ZERO(Device, T) \ |
79 | template <> \ |
80 | Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value); |
81 | |
82 | #define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T) |
83 | TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU); |
84 | TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU); |
85 | #undef TENSOR_ARRAY_SET_ZERO_CPU |
86 | |
87 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
88 | |
89 | #define TENSOR_ARRAY_SET_ZERO_GPU(T) TENSOR_ARRAY_SET_ZERO(GPUDevice, T) |
90 | TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU); |
91 | TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_SET_ZERO_GPU); |
92 | #undef TENSOR_ARRAY_SET_ZERO_GPU |
93 | |
94 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
95 | |
96 | #undef TENSOR_ARRAY_SET_ZERO |
97 | |
98 | } // namespace tensor_array |
99 | |
100 | // The TensorArray object keeps an array of Tensors. It allows reading from the |
101 | // array and writing to the array. |
102 | // |
103 | // Important properties: |
// * Writing to a particular index in the TensorArray is normally allowed at
//   most once. As a special case, writes with the flag
//   multiple_writes_aggregate may target the same index multiple times; in
//   that case, the written values are summed.
108 | // * Multiple reads are supported. |
109 | // * Deep copies of Tensors are rarely made. The only time they are made is |
110 | // when WriteOrAggregate is called at least twice on the same index with the |
111 | // flag multiple_writes_aggregate = True. |
// * Reading from and writing to the array are protected by a mutex.
//   All operations on a TensorArray are thread-safe.
114 | // * A TensorArray may be preemptively closed, which releases all |
115 | // memory associated with it. |
116 | // |
// These properties together allow the TensorArray to work as a
// functional object and make gradient computation easy. For
// example:
// * Write-Once semantics mean the gradient of a TensorArray Read never has
//   to worry about which of multiple writes to that index the gradient
//   value is meant for.
123 | // * Read-Many semantics (when using clear_after_read=false) allow the |
124 | // TensorArray to be read, packed, or concatenated multiple times; |
125 | // and the gradient operations use the multiple_writes_aggregate |
126 | // flag to aggregate the backprop writes. Multiple backprop writes to |
127 | // the same index are partial gradients corresponding to the |
128 | // multiple reads of that index in the forward phase. |
129 | // |
130 | class TensorArray : public ResourceBase { |
131 | public: |
132 | static std::atomic<int64_t> tensor_array_counter; |
133 | |
134 | // Construct a TensorArray for holding Tensors of type 'dtype' with |
135 | // 'N' elements. While the underlying storage is a std::vector and |
  // can hold more than INT_MAX entries, in practice we do not expect
137 | // users to construct this many Tensors for storage in a TensorArray. |
138 | TensorArray(const string& key, const DataType& dtype, const Tensor& handle, |
139 | int32_t N, const PartialTensorShape& element_shape, |
140 | bool identical_element_shapes, bool dynamic_size, |
141 | bool multiple_writes_aggregate, bool is_grad, int32_t marked_size, |
142 | bool clear_after_read) |
143 | : key_(key), |
144 | dtype_(dtype), |
145 | handle_(handle), |
146 | closed_(false), |
147 | dynamic_size_(dynamic_size), |
148 | multiple_writes_aggregate_(multiple_writes_aggregate), |
149 | gradients_disallowed_(false), |
150 | clear_after_read_(clear_after_read), |
151 | is_grad_(is_grad), |
152 | marked_size_(marked_size), |
153 | element_shape_(element_shape), |
154 | identical_element_shapes_(identical_element_shapes), |
155 | tensors_(N) {} |
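
  // For example, a kernel might construct a fixed-size, forward TensorArray
  // of 10 float elements roughly as follows ('key' and 'handle' being
  // hypothetical values owned by the kernel):
  //
  //   auto* ta = new TensorArray(
  //       key, DT_FLOAT, handle, /*N=*/10, PartialTensorShape(),
  //       /*identical_element_shapes=*/false, /*dynamic_size=*/false,
  //       /*multiple_writes_aggregate=*/false, /*is_grad=*/false,
  //       /*marked_size=*/-1, /*clear_after_read=*/true);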
156 | |
157 | // Write Tensor 'value' to index 'index'. |
158 | // |
159 | // Preconditions: |
160 | // * The TensorArray is not closed |
161 | // * If the array has dynamic size: |
162 | // The index is >= 0 |
163 | // Otherwise: |
164 | // The index is in [0, N) where N == Size() |
165 | // * The dtype of the Tensor in 'value' matches the TensorArray's dtype. |
166 | // * If multiple_writes_aggregate is false: |
167 | // The Tensor at 'index' has not yet been written to. |
168 | // * If multiple_writes_aggregate is true: |
169 | // The Tensor at 'index' has the same shape as value. |
170 | // |
171 | // Side effects: |
172 | // * On the first write to 'index': |
173 | // - The underlying Tensor in 'value' has a new reference to it. |
174 | // - The index 'index' is marked as written. |
175 | // * If multiple_writes_aggregate is false, subsequent writes to 'index' |
176 | // raise an InvalidArgument error. |
177 | // * If multiple_writes_aggregate is true, subsequent writes to 'index': |
178 | // - The underlying Tensors in 'value' and from the first write |
179 | // are released and a local Tensor is created. |
180 | // - Index 'index' is also marked as local_copy. |
181 | // - The gradients_disallowed flag is set true (GradientsAllowed() |
182 | // will now return false). |
183 | // |
  // Note: 'value' is passed as a pointer because its underlying Tensor's
  // shape is accessed; it is not otherwise modified.
186 | template <typename Device, typename T> |
187 | Status WriteOrAggregate(OpKernelContext* ctx, const int32_t index, |
188 | const Tensor* value) { |
189 | mutex_lock l(mu_); |
190 | return LockedWriteOrAggregate<Device, T>(ctx, index, value); |
191 | } |
192 | |
193 | template <typename Device, typename T> |
194 | Status WriteOrAggregateMany(OpKernelContext* ctx, |
195 | const std::vector<int32>& indices, |
196 | std::vector<Tensor>* values) { |
197 | mutex_lock l(mu_); |
198 | int32_t i = 0; |
199 | for (const int32_t ix : indices) { |
200 | Status s = LockedWriteOrAggregate<Device, T>(ctx, ix, &(*values)[i]); |
201 | ++i; |
202 | TF_RETURN_IF_ERROR(s); |
203 | } |
204 | return OkStatus(); |
205 | } |
206 | |
207 | // Read from index 'index' into Tensor 'value'. |
208 | // |
209 | // Preconditions: |
210 | // * The TensorArray is not closed |
211 | // * The index is in [0, N) |
212 | // * The Tensor at 'index' has been written to. |
213 | // * The Tensor at 'index' has not been read from with flag |
214 | // clear_after_read = true. |
215 | // |
216 | // Side effects: |
217 | // * If clear_after_read is true, the reference to the underlying |
218 | // Tensor is deleted. |
219 | // * The reference to the underlying Tensor at 'index' is copied to |
220 | // the returned '*value'. |
221 | // * The index is marked as read (it cannot be rewritten to). |
222 | template <typename Device, typename T> |
223 | Status Read(OpKernelContext* ctx, const int32_t index, Tensor* value) { |
224 | mutex_lock l(mu_); |
225 | return LockedRead<Device, T>(ctx, index, value); |
226 | } |
227 | |
228 | template <typename Device, typename T> |
229 | Status ReadMany(OpKernelContext* ctx, const std::vector<int32>& indices, |
230 | std::vector<Tensor>* values) { |
231 | mutex_lock l(mu_); |
232 | values->clear(); |
233 | values->resize(indices.size()); |
234 | int32_t i = 0; |
235 | for (const int32_t ix : indices) { |
236 | Status s = LockedRead<Device, T>(ctx, ix, &(*values)[i]); |
237 | ++i; |
238 | if (!s.ok()) return s; |
239 | } |
240 | return OkStatus(); |
241 | } |
242 | |
243 | DataType ElemType() const { return dtype_; } |
244 | |
245 | PartialTensorShape ElemShape() { |
246 | mutex_lock l(mu_); |
247 | return element_shape_; |
248 | } |
249 | |
250 | Status SetElemShape(const PartialTensorShape& candidate) { |
251 | mutex_lock l(mu_); |
    PartialTensorShape new_element_shape;
    Status s = element_shape_.MergeWith(candidate, &new_element_shape);
    if (!s.ok()) {
      return s;
    }
    element_shape_ = new_element_shape;
258 | return OkStatus(); |
259 | } |
260 | |
261 | string DebugString() const override { |
262 | mutex_lock l(mu_); |
263 | CHECK(!closed_); |
    return strings::StrCat("TensorArray[", tensors_.size(), "]");
265 | } |
266 | |
267 | bool IsClosed() { |
268 | mutex_lock l(mu_); |
269 | return closed_; |
270 | } |
271 | |
272 | // Return the size of the TensorArray. |
273 | Status Size(int32* size) { |
274 | mutex_lock l(mu_); |
275 | TF_RETURN_IF_ERROR(LockedReturnIfClosed()); |
276 | *size = tensors_.size(); |
277 | return OkStatus(); |
278 | } |
279 | |
280 | // Record the size of the TensorArray after an unpack or split. |
281 | Status SetMarkedSize(int32_t size) { |
282 | mutex_lock l(mu_); |
283 | TF_RETURN_IF_ERROR(LockedReturnIfClosed()); |
284 | if (!is_grad_) { |
285 | marked_size_ = size; |
286 | } |
287 | return OkStatus(); |
288 | } |
289 | |
290 | // Return the marked size of the TensorArray. |
291 | Status MarkedSize(int32* size) { |
292 | mutex_lock l(mu_); |
293 | TF_RETURN_IF_ERROR(LockedReturnIfClosed()); |
294 | *size = marked_size_; |
295 | return OkStatus(); |
296 | } |
297 | |
  // Return the size that should be used by a pack or concat op.
299 | Status PackOrConcatSize(int32* size) { |
300 | mutex_lock l(mu_); |
301 | TF_RETURN_IF_ERROR(LockedReturnIfClosed()); |
302 | *size = is_grad_ ? marked_size_ : tensors_.size(); |
303 | return OkStatus(); |
304 | } |
305 | |
306 | // Once a TensorArray is being used for gradient calculations, it |
307 | // should be marked as no longer resizeable. |
308 | void DisableDynamicSize() { |
309 | mutex_lock l(mu_); |
310 | dynamic_size_ = false; |
311 | } |
312 | |
313 | bool HasDynamicSize() { |
314 | mutex_lock l(mu_); |
315 | return dynamic_size_; |
316 | } |
317 | |
318 | bool GradientsAllowed() { |
319 | mutex_lock l(mu_); |
320 | return !gradients_disallowed_; |
321 | } |
322 | |
323 | bool HasIdenticalElementShapes() const { return identical_element_shapes_; } |
324 | |
325 | // Copy the TensorShapes from another TensorArray into this one. |
  // If `shape_to_prepend` is set, expands the rank of the copied shapes by
  // prepending the passed-in shape prefix to the shape values in `rhs`.
328 | // The sizes of the two TensorArrays must match and this one |
329 | // may not have any entries filled in. This performs a "soft copy", |
330 | // essentially filling the current TensorArray with virtual |
331 | // zero-tensors, which will be replaced by future aggregate writes, |
332 | // or instantiated by future reads. Requires a non-const pointer |
333 | // to the rhs to access its mutex. |
334 | Status CopyShapesFrom(TensorArray* rhs, const TensorShape* shape_to_prepend); |
335 | |
336 | // Clear the TensorArray, including any Tensor references, and mark as closed. |
337 | void ClearAndMarkClosed() { |
338 | mutex_lock l(mu_); |
339 | tensors_.clear(); |
340 | closed_ = true; |
341 | } |
342 | |
343 | mutex* mu() { return &mu_; } |
344 | Tensor* handle() { return &handle_; } |
345 | |
346 | ResourceHandle resource_handle(OpKernelContext* ctx) { |
347 | return ctx->step_container()->MakeResourceHandle<TensorArray>( |
348 | key_, *ctx->device()); |
349 | } |
350 | |
351 | private: |
352 | Status LockedWrite(OpKernelContext* ctx, const int32_t index, Tensor* value) |
353 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
354 | |
355 | template <typename Device, typename T> |
356 | Status LockedWriteOrAggregate(OpKernelContext* ctx, const int32_t index, |
357 | const Tensor* value) |
358 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
359 | |
360 | template <typename Device, typename T> |
361 | Status LockedRead(OpKernelContext* ctx, const int32_t index, Tensor* value) |
362 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); |
363 | |
364 | Status LockedReturnIfClosed() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
365 | if (closed_) { |
366 | return errors::InvalidArgument("TensorArray " , handle_.vec<tstring>()(1), |
367 | " has already been closed." ); |
368 | } |
369 | return OkStatus(); |
370 | } |
371 | |
372 | const string key_; |
373 | |
374 | const DataType dtype_; |
375 | Tensor handle_; |
376 | |
377 | mutable mutex mu_; |
378 | |
  // True iff the TensorArray has been closed and tensors_ cleared.
380 | bool closed_ TF_GUARDED_BY(mu_); |
381 | |
382 | // Writes are allowed to grow the array. |
383 | bool dynamic_size_; |
384 | |
  // Multiple writes to the same index will result in summation of the
  // values (used by backprop).
387 | const bool multiple_writes_aggregate_; |
388 | |
  // If multiple writes were aggregated to the same index (permitted by the
  // attribute multiple_writes_aggregate), then gradients are disallowed.
391 | bool gradients_disallowed_ TF_GUARDED_BY(mu_); |
392 | |
393 | // After a read at an index, clear away its Tensor to release memory. |
394 | const bool clear_after_read_; |
395 | |
396 | // True iff this is a gradient tensor array. |
397 | const bool is_grad_; |
398 | |
399 | // The size of the TensorArray after a (legacy) unpack or split is performed. |
400 | // -1 if there has been no unpack or split performed on the TensorArray. |
401 | int32 marked_size_; |
402 | |
  // The shape of each element in the TensorArray; it may be partially known
  // or not known at all.
405 | PartialTensorShape element_shape_ TF_GUARDED_BY(mu_); |
406 | |
407 | // Whether all elements in the TensorArray have identical shapes. |
408 | // This allows certain behaviors, like dynamically checking for |
409 | // consistent shapes on write, and being able to fill in properly |
410 | // shaped zero tensors on stack -- even if the initial element_shape |
411 | // was not fully defined. |
412 | const bool identical_element_shapes_; |
413 | |
  // TensorAndState is used to keep track of the Tensors stored in the
  // TensorArray, along with their shapes and flags recording whether they
  // have been written, read, or cleared.
417 | struct TensorAndState { |
418 | TensorAndState() |
419 | : written(false), read(false), cleared(false), local_copy(false) {} |
420 | Tensor tensor; |
421 | TensorShape shape; |
422 | bool written; // True if a Tensor has been written to the index. |
423 | bool read; // True if a Tensor has been written to and read from the index. |
424 | bool cleared; // True if a tensor has been read with |
425 | // clear_after_read = true; |
426 | |
427 | // Used by writes when multiple_writes_aggregate is true. In this |
428 | // case, the first time a value is written, it is a shallow copy. |
429 | // The second time a value is written, it is aggregated. However, |
430 | // in this case a new Tensor must be constructed to hold the |
431 | // aggregated value. This flag marks that such a Tensor is being |
432 | // used. All future writes will aggregate to the existing local Tensor. |
433 | bool local_copy; |
434 | }; |
435 | // The list of underlying Tensors and states. |
436 | std::vector<TensorAndState> tensors_ TF_GUARDED_BY(mu_); |
437 | }; |
438 | |
439 | template <typename Device, typename T> |
440 | Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx, |
441 | const int32_t index, |
442 | const Tensor* value) { |
443 | TF_RETURN_IF_ERROR(LockedReturnIfClosed()); |
444 | size_t index_size = static_cast<size_t>(index); |
445 | if (index < 0 || (!dynamic_size_ && index_size >= tensors_.size())) { |
446 | return errors::InvalidArgument( |
447 | "TensorArray " , handle_.vec<tstring>()(1), ": Tried to write to index " , |
448 | index, " but array is not resizeable and size is: " , tensors_.size()); |
449 | } |
450 | if (dynamic_size_) { |
451 | // We must grow the internal TensorArray |
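    // Reserving 2x the required capacity keeps repeated appends amortized
    // O(1), mirroring std::vector's usual geometric growth.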
452 | if (index_size >= tensors_.capacity()) { |
453 | tensors_.reserve(2 * (index_size + 1)); |
454 | } |
455 | if (index_size >= tensors_.size()) { |
456 | tensors_.resize(index_size + 1); |
457 | } |
458 | } |
459 | TensorAndState& t = tensors_[index]; |
460 | |
461 | if (value->dtype() != dtype_) { |
462 | return errors::InvalidArgument( |
463 | "TensorArray " , handle_.vec<tstring>()(1), |
464 | ": Could not write to TensorArray index " , index, |
465 | " because the value dtype is " , DataTypeString(value->dtype()), |
466 | " but TensorArray dtype is " , DataTypeString(dtype_), "." ); |
467 | } |
468 | if (!element_shape_.IsCompatibleWith(value->shape())) { |
469 | return errors::InvalidArgument( |
470 | "TensorArray " , handle_.vec<tstring>()(1), |
471 | ": Could not write to TensorArray index " , index, |
472 | " because the value shape is " , value->shape().DebugString(), |
473 | " which is incompatible with the TensorArray's inferred element " |
474 | "shape: " , |
475 | element_shape_.DebugString(), " (consider setting infer_shape=False)." ); |
476 | } else if (identical_element_shapes_ && !element_shape_.IsFullyDefined()) { |
477 | element_shape_ = PartialTensorShape(value->shape().dim_sizes()); |
478 | } |
479 | |
480 | if (t.read) { |
481 | return errors::InvalidArgument("TensorArray " , handle_.vec<tstring>()(1), |
482 | ": Could not write to TensorArray index " , |
483 | index, " because it has already been read." ); |
484 | } |
485 | |
486 | if (!multiple_writes_aggregate_ && t.written) { |
487 | return errors::InvalidArgument("TensorArray " , handle_.vec<tstring>()(1), |
488 | ": Could not write to TensorArray index " , |
489 | index, |
490 | " because it has already been written to." ); |
491 | } |
492 | |
493 | if (t.written) { |
494 | DCHECK(multiple_writes_aggregate_); |
495 | |
496 | // Check that value shape matches t.shape |
497 | if (value->shape() != t.shape) { |
498 | return errors::InvalidArgument( |
499 | "TensorArray " , handle_.vec<tstring>()(1), |
500 | ": Could not aggregate to TensorArray index " , index, |
501 | " because the existing shape is " , t.shape.DebugString(), |
502 | " but the new input shape is " , value->shape().DebugString(), "." ); |
503 | } |
504 | |
505 | if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) { |
      // If the tensor is uninitialized (or empty) but written == true, then
      // only a shape was stored, which stands for zeros. So all we must do
      // in this case is copy the reference over and return early.
509 | t.tensor = *value; |
510 | return OkStatus(); |
511 | } |
512 | |
513 | Tensor* existing_t = &t.tensor; |
514 | |
515 | if (t.local_copy) { |
516 | Status s = tensor_array::AddToTensor<Device, T>(ctx, existing_t, |
517 | existing_t, value); |
518 | TF_RETURN_IF_ERROR(s); |
519 | } else { |
520 | Tensor local_tensor; |
521 | TF_RETURN_IF_ERROR( |
522 | ctx->allocate_temp(dtype_, existing_t->shape(), &local_tensor)); |
523 | Status s = tensor_array::AddToTensor<Device, T>(ctx, &local_tensor, |
524 | existing_t, value); |
525 | TF_RETURN_IF_ERROR(s); |
526 | t.tensor = local_tensor; |
527 | t.local_copy = true; |
528 | } |
529 | |
530 | // We've aggregated the values, so disallow backprop on this |
531 | // TensorArray. |
532 | gradients_disallowed_ = true; |
533 | } else { |
534 | t.tensor = *value; |
535 | t.shape = value->shape(); |
536 | t.written = true; |
537 | } |
538 | return OkStatus(); |
539 | } |
540 | |
541 | template <typename Device, typename T> |
542 | Status TensorArray::LockedRead(OpKernelContext* ctx, const int32_t index, |
543 | Tensor* value) { |
544 | TF_RETURN_IF_ERROR(LockedReturnIfClosed()); |
545 | if ((index < 0) || |
546 | (!is_grad_ && (static_cast<size_t>(index) >= tensors_.size()))) { |
547 | return errors::InvalidArgument("Tried to read from index " , index, |
548 | " but array size is: " , tensors_.size()); |
549 | } |
550 | size_t index_t = static_cast<size_t>(index); |
551 | if ((is_grad_ && (index_t >= tensors_.size() || !tensors_[index].written)) || |
552 | (!is_grad_ && (index_t < tensors_.size() && !tensors_[index].written))) { |
553 | // Special case returning zeros if this is a gradient read that happens |
554 | // after a stop_gradients call with dynamic forward TensorArrays. |
555 | // There is sometimes a race condition where the gradient is not |
556 | // written due to stop_gradients, but is later read. |
557 | TensorShape element_shape; |
558 | if (is_grad_ && index_t < tensors_.size() && |
559 | tensors_[index].shape.dims() > 0) { |
560 | // A gradient TensorArray has more specific gradient information |
561 | // available for each entry. A forward TensorArray must rely on |
562 | // the global element_shape_ to fill in zeros on read. |
563 | element_shape = tensors_[index].shape; |
564 | } else if (!element_shape_.IsFullyDefined()) { |
565 | return errors::InvalidArgument( |
566 | "TensorArray " , handle_.vec<tstring>()(1), |
567 | ": Could not read from TensorArray index " , index, |
568 | ". Furthermore, the element shape is not fully defined: " , |
569 | element_shape_.DebugString(), |
570 | ". It is possible you are working with a resizeable TensorArray and " |
571 | "stop_gradients is not allowing the gradients to be written. If you " |
572 | "set the full " |
573 | "element_shape property on the forward TensorArray, the proper " |
574 | "all-zeros tensor " |
575 | "will be returned instead of incurring this error." ); |
576 | } else { |
577 | element_shape_.AsTensorShape(&element_shape); // Always succeeds. |
578 | } |
579 | if (index_t >= tensors_.size()) { |
580 | // Fill in tensors_ up to index to have known shape. |
581 | size_t old_tensors_size = tensors_.size(); |
      tensors_.resize(index_t + 1);
      for (size_t i = old_tensors_size; i < index_t + 1; ++i) {
584 | tensors_[i].shape = element_shape; |
585 | tensors_[i].written = true; |
586 | } |
587 | } else { |
588 | tensors_[index].shape = element_shape; |
589 | tensors_[index].written = true; |
590 | } |
591 | } |
592 | |
593 | TensorAndState& t = tensors_[index]; |
594 | |
595 | if (t.cleared) { |
596 | return errors::InvalidArgument("TensorArray " , handle_.vec<tstring>()(1), |
597 | ": Could not read index " , index, |
598 | " twice because it was cleared after a " |
599 | "previous read (perhaps try setting " |
600 | "clear_after_read = false?)." ); |
601 | } |
602 | |
603 | if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) { |
    // We stored just a shape, but no value. In that case, create and
    // return zeros of the appropriate shape.
606 | TF_RETURN_IF_ERROR(ctx->allocate_temp(dtype_, t.shape, &t.tensor)); |
607 | if (t.shape.num_elements() > 0) { |
608 | Status s = tensor_array::TensorSetZero<Device, T>(ctx, &t.tensor); |
609 | if (!s.ok()) return s; |
610 | } |
611 | } |
612 | |
  // Data is available inside the tensor; copy the reference over.
614 | *value = t.tensor; |
615 | |
616 | if (clear_after_read_) { |
617 | t.tensor = Tensor(); |
618 | t.cleared = true; |
619 | } |
620 | t.read = true; |
621 | return OkStatus(); |
622 | } |
623 | |
624 | } // namespace tensorflow |
625 | |
626 | #endif // TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_ |
627 | |