/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/tensor_array.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/aggregate_ops_cpu.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace tensor_array {

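// The TENSOR_ARRAY_WRITE_OR_ADD macro below stamps out explicit
// specializations of AddToTensor for each supported (Device, T) pair. These
// are used when writes to an already-written TensorArray index are
// aggregated: the existing value and the incoming value are summed
// element-wise via Add2Functor. As a rough sketch, the CPUDevice/float
// expansion looks like:
//
//   template <>
//   Status AddToTensor<CPUDevice, float>(OpKernelContext* ctx, Tensor* sum,
//                                        const Tensor* current,
//                                        const Tensor* add) {
//     functor::Add2Functor<CPUDevice, float> add_functor;
//     add_functor(ctx->eigen_device<CPUDevice>(), sum->flat<float>(),
//                 current->flat<float>(), add->flat<float>());
//     return OkStatus();
//   }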
#define TENSOR_ARRAY_WRITE_OR_ADD(Device, T)                                 \
  template <>                                                                \
  Status AddToTensor<Device, T>(OpKernelContext * ctx, Tensor * sum,         \
                                const Tensor* current, const Tensor* add) {  \
    functor::Add2Functor<Device, T> add_functor;                             \
    add_functor(ctx->template eigen_device<Device>(), sum->flat<T>(),        \
                current->flat<T>(), add->flat<T>());                         \
    return OkStatus();                                                       \
  }

#define TENSOR_ARRAY_WRITE_OR_ADD_CPU(T) TENSOR_ARRAY_WRITE_OR_ADD(CPUDevice, T)
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_CPU)
#undef TENSOR_ARRAY_WRITE_OR_ADD_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define TENSOR_ARRAY_WRITE_OR_ADD_GPU(T) TENSOR_ARRAY_WRITE_OR_ADD(GPUDevice, T)
TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
#undef TENSOR_ARRAY_WRITE_OR_ADD_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef TENSOR_ARRAY_WRITE_OR_ADD

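// The TENSOR_ARRAY_SET_ZERO macro below generates the matching TensorSetZero
// specializations, which fill a tensor with zeros via SetZeroFunctor. The
// TensorArray implementation uses these to materialize zero-valued tensors of
// the recorded shape, e.g. for indices whose shape is known but whose value
// has not actually been produced (see the comment in CopyShapesFrom below).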
#define TENSOR_ARRAY_SET_ZERO(Device, T)                                      \
  template <>                                                                 \
  Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value) {    \
    functor::SetZeroFunctor<Device, T> set_zero_functor;                      \
    set_zero_functor(ctx->template eigen_device<Device>(), value->flat<T>()); \
    return OkStatus();                                                        \
  }

#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define TENSOR_ARRAY_SET_ZERO_GPU(T) TENSOR_ARRAY_SET_ZERO(GPUDevice, T)
TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
#undef TENSOR_ARRAY_SET_ZERO_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#undef TENSOR_ARRAY_SET_ZERO

}  // namespace tensor_array

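// Process-wide counter shared by all TensorArray instances; the kernels that
// create TensorArrays use it to generate unique names for their resource
// handles (an assumption based on callers outside this file).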
std::atomic<int64_t> TensorArray::tensor_array_counter{0};

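// Copies the per-index shape metadata from `rhs` into this TensorArray,
// optionally prepending `shape_to_prepend` to each copied shape. Only indices
// already written in `rhs` are touched; those entries are then marked written
// locally so that later reads can return correctly shaped zeros, as described
// in the comment inside the loop below.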
Status TensorArray::CopyShapesFrom(TensorArray* rhs,
                                   const TensorShape* shape_to_prepend) {
  mutex_lock l(mu_);
  mutex_lock l_rhs(rhs->mu_);
  TF_RETURN_IF_ERROR(LockedReturnIfClosed());
  TF_RETURN_IF_ERROR(rhs->LockedReturnIfClosed());
  if (tensors_.size() != rhs->tensors_.size()) {
    return errors::InvalidArgument(
        "TensorArray sizes do not match during CopyShapesFrom: ",
        handle_.vec<tstring>()(1), " has size ", tensors_.size(), " but rhs ",
        rhs->handle_.vec<tstring>()(1), " has size ", rhs->tensors_.size());
  }
  for (std::size_t i = 0; i < tensors_.size(); ++i) {
    // Skip "soft copy" of indices which have not been written.
    if (!rhs->tensors_[i].written) continue;

    // Copy the shape over.
    if (shape_to_prepend) {
      tensors_[i].shape = *shape_to_prepend;
      tensors_[i].shape.AppendShape(rhs->tensors_[i].shape);
    } else {
      tensors_[i].shape = rhs->tensors_[i].shape;
    }
    // Mark as written. Reads will know that if written is true, read is
    // false, and cleared is false, they should return zeros of the
    // appropriate shape. Future aggregating writes will only use the shape
    // for validation.
    tensors_[i].written = true;
  }

  return OkStatus();
}

}  // namespace tensorflow