/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/tensor_array.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/aggregate_ops_cpu.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace tensor_array {

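// Full specializations of tensor_array::AddToTensor (declared in
// tensor_array.h) for every supported (device, type) pair. Each
// specialization forwards to functor::Add2Functor, computing
// sum = current + add elementwise; this is the primitive behind
// TensorArray's aggregating writes to an already-written index.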
#define TENSOR_ARRAY_WRITE_OR_ADD(Device, T)                                \
  template <>                                                               \
  Status AddToTensor<Device, T>(OpKernelContext * ctx, Tensor * sum,        \
                                const Tensor* current, const Tensor* add) { \
    functor::Add2Functor<Device, T> add_functor;                            \
    add_functor(ctx->template eigen_device<Device>(), sum->flat<T>(),       \
                current->flat<T>(), add->flat<T>());                        \
    return OkStatus();                                                      \
  }

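// For example, TENSOR_ARRAY_WRITE_OR_ADD(CPUDevice, float) stamps out the
// definition of the specialization
//
//   template <>
//   Status AddToTensor<CPUDevice, float>(OpKernelContext* ctx, Tensor* sum,
//                                        const Tensor* current,
//                                        const Tensor* add);
//
// and the TF_CALL_* invocations below repeat this for every registered
// numeric dtype.
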
#define TENSOR_ARRAY_WRITE_OR_ADD_CPU(T) TENSOR_ARRAY_WRITE_OR_ADD(CPUDevice, T)
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_CPU)
#undef TENSOR_ARRAY_WRITE_OR_ADD_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define TENSOR_ARRAY_WRITE_OR_ADD_GPU(T) TENSOR_ARRAY_WRITE_OR_ADD(GPUDevice, T)
TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
#undef TENSOR_ARRAY_WRITE_OR_ADD_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef TENSOR_ARRAY_WRITE_OR_ADD

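// Full specializations of tensor_array::TensorSetZero for every supported
// (device, type) pair, forwarding to functor::SetZeroFunctor. TensorArray
// uses this to materialize all-zero tensors of the recorded shape, e.g. for
// indices marked written by CopyShapesFrom below but never given a value.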
#define TENSOR_ARRAY_SET_ZERO(Device, T)                                      \
  template <>                                                                 \
  Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value) {    \
    functor::SetZeroFunctor<Device, T> set_zero_functor;                      \
    set_zero_functor(ctx->template eigen_device<Device>(), value->flat<T>()); \
    return OkStatus();                                                        \
  }

#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define TENSOR_ARRAY_SET_ZERO_GPU(T) TENSOR_ARRAY_SET_ZERO(GPUDevice, T)
TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
#undef TENSOR_ARRAY_SET_ZERO_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef TENSOR_ARRAY_SET_ZERO

}  // namespace tensor_array

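// Process-wide counter used to generate unique identifiers for newly
// created TensorArray resources.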
std::atomic<int64_t> TensorArray::tensor_array_counter{0};

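// Copies per-index shape metadata (not tensor values) from `rhs` into this
// TensorArray, optionally prepending `shape_to_prepend` to each copied
// shape, and marks the copied indices as written. Both arrays are locked
// for the duration of the copy, and both must be the same size and open.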
Status TensorArray::CopyShapesFrom(TensorArray* rhs,
                                   const TensorShape* shape_to_prepend) {
  mutex_lock l(mu_);
  mutex_lock l_rhs(rhs->mu_);
  TF_RETURN_IF_ERROR(LockedReturnIfClosed());
  TF_RETURN_IF_ERROR(rhs->LockedReturnIfClosed());
  if (tensors_.size() != rhs->tensors_.size()) {
    return errors::InvalidArgument(
        "TensorArray sizes do not match during CopyShapesFrom: ",
        handle_.vec<tstring>()(1), " has size ", tensors_.size(), " but rhs ",
        rhs->handle_.vec<tstring>()(1), " has size ", rhs->tensors_.size());
  }
  for (std::size_t i = 0; i < tensors_.size(); ++i) {
    // Skip "soft copy" of indices which have not been written.
    if (!rhs->tensors_[i].written) continue;

    // Copy the shape over.
    if (shape_to_prepend) {
      tensors_[i].shape = *shape_to_prepend;
      tensors_[i].shape.AppendShape(rhs->tensors_[i].shape);
    } else {
      tensors_[i].shape = rhs->tensors_[i].shape;
    }
    // Mark as written. A read of an index that is written but not yet read
    // and not cleared returns zeros of the recorded shape; future
    // aggregating writes will use this shape only for validation.
    tensors_[i].written = true;
  }

  return OkStatus();
}

}  // namespace tensorflow