1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/array_ops.cc |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
21 | #define EIGEN_USE_GPU |
22 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
23 | |
24 | #include "tensorflow/core/kernels/diag_op.h" |
25 | |
26 | #include <algorithm> |
27 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
28 | #include "tensorflow/core/framework/op_kernel.h" |
29 | #include "tensorflow/core/framework/register_types.h" |
30 | #include "tensorflow/core/framework/tensor.h" |
31 | #include "tensorflow/core/framework/tensor_types.h" |
32 | #include "tensorflow/core/platform/logging.h" |
33 | #include "tensorflow/core/platform/types.h" |
34 | #include "tensorflow/core/util/work_sharder.h" |
35 | |
36 | namespace tensorflow { |
37 | |
38 | typedef Eigen::ThreadPoolDevice CPUDevice; |
39 | typedef Eigen::GpuDevice GPUDevice; |
40 | |
// Generate a diagonal tensor whose diagonal is the input tensor.
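// For example, a rank-1 input [1, 2, 3] produces the rank-2 output
// [[1, 0, 0],
//  [0, 2, 0],
//  [0, 0, 3]].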
42 | template <typename Device, typename T> |
43 | class DiagOp : public OpKernel { |
44 | public: |
45 | explicit DiagOp(OpKernelConstruction* context) : OpKernel(context) {} |
46 | |
47 | void Compute(OpKernelContext* context) override { |
48 | const Tensor& diagonal = context->input(0); |
49 | const int num_dims = diagonal.dims(); |
50 | OP_REQUIRES( |
51 | context, 0 != num_dims, |
52 | errors::InvalidArgument("Input must be at least rank 1, got 0" )); |
53 | TensorShape out_shape; |
54 | for (int i = 0; i < num_dims; ++i) { |
55 | out_shape.AddDim(diagonal.dim_size(i)); |
56 | } |
57 | for (int i = 0; i < num_dims; ++i) { |
58 | out_shape.AddDim(diagonal.dim_size(i)); |
59 | } |
60 | Tensor* output_tensor = nullptr; |
61 | OP_REQUIRES_OK(context, |
62 | context->allocate_output(0, out_shape, &output_tensor)); |
63 | functor::DiagFunctor<Device, T> diagFunc; |
64 | Status s = |
65 | diagFunc(context, diagonal.NumElements(), diagonal.flat<T>().data(), |
66 | output_tensor->flat<T>().data()); |
67 | OP_REQUIRES_OK(context, s); |
68 | } |
69 | }; |
70 | |
// Extract the diagonal of the input tensor; the inverse of DiagOp.
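// For example, the rank-2 input
// [[1, 0, 0],
//  [0, 2, 0],
//  [0, 0, 3]]
// yields the rank-1 output [1, 2, 3].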
72 | template <typename Device, typename T> |
73 | class DiagPartOp : public OpKernel { |
74 | public: |
75 | explicit DiagPartOp(OpKernelConstruction* context) : OpKernel(context) {} |
76 | |
77 | void Compute(OpKernelContext* context) override { |
78 | const Tensor& tensor = context->input(0); |
79 | const int num_dims = tensor.dims(); |
80 | const int out_dims = num_dims / 2; |
    OP_REQUIRES(context, 0 == num_dims % 2,
                errors::InvalidArgument(
                    "The rank of the tensor should be even and positive, "
                    "got shape ",
                    tensor.shape().DebugString()));
85 | for (int i = 0; i < out_dims; i++) { |
      OP_REQUIRES(
          context, tensor.dim_size(i) == tensor.dim_size(i + out_dims),
          errors::InvalidArgument("Invalid shape ",
                                  tensor.shape().DebugString(), ": dimensions ",
                                  i, " and ", i + out_dims, " do not match."));
91 | } |
92 | |
93 | TensorShape out_shape; |
94 | for (int i = 0; i < out_dims; ++i) { |
95 | out_shape.AddDim(tensor.dim_size(i)); |
96 | } |
97 | |
98 | Tensor* output = nullptr; |
99 | OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); |
100 | functor::DiagPartFunctor<Device, T> diagPartFunc; |
101 | Status s = diagPartFunc(context, out_shape.num_elements(), |
102 | tensor.flat<T>().data(), output->flat<T>().data()); |
103 | OP_REQUIRES_OK(context, s); |
104 | } |
105 | }; |
106 | |
107 | // Implementation of the functor specialization for CPU. |
108 | // |
109 | // According to the diagonal definition, |
110 | // `output[i1,..., ik, i1,..., ik] = input[i1,..., ik]`, |
111 | // |
// Let the shape of the input be [s1,..., sk]. Then any offset into the
// input can be represented by a coordinate [i1,..., ik],
// where `index = i1*(s2*...*sk) + i2*(s3*...*sk) +... + ik`
//
// Let new_index be the offset into the output at coordinate
// [i1,..., ik, i1,..., ik]. Then we have
// `new_index = i1*(s2*...*sk*s1*...*sk) + i2*(s3*...*sk*s1*...*sk) +... + \
//              ik*(s1*...*sk) + i1*(s2*...*sk) + i2*(s3*...*sk) +... + ik
//            = (i1*(s2*...*sk) + i2*(s3*...*sk) +... + ik) * (1 + s1*...*sk)
//            = index * (1 + s1*...*sk)
//
// Letting `size = s1*...*sk`, we finally have `new_index = index * (1 + size)`,
// which is the transfer function used below.
// This trick keeps the implementation simple and easy to parallelize.
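//
// For example, with a 2 x 3 input (size = 6), input index 4 has coordinate
// [1, 1]; the matching output coordinate [1, 1, 1, 1] in the 2 x 3 x 2 x 3
// output has offset 1*18 + 1*6 + 1*3 + 1*1 = 28 = 4 * (1 + 6).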
126 | namespace functor { |
127 | template <typename T> |
128 | struct DiagFunctor<CPUDevice, T> { |
129 | EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context, |
130 | const int64_t size, const T* in, |
131 | T* out) { |
    // Each shard writes the output values in the flat index range
    // [start*size, limit*size), i.e. output rows [start, limit).
134 | auto subDiag = [in, out, size](int64_t start, int64_t limit) { |
135 | std::fill(out + size * start, out + size * limit, T()); |
136 | for (int64_t index = start; index < limit; ++index) { |
137 | out[(1 + size) * index] = in[index]; |
138 | } |
139 | }; |
140 | |
    // cost_per_unit is 5 * size: each work unit fills one output row of
    // `size` elements, and 5 is an empirical per-element factor.
142 | auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); |
143 | Shard(worker_threads.num_threads, worker_threads.workers, size, 5 * size, |
144 | subDiag); |
145 | return OkStatus(); |
146 | } |
147 | }; |
148 | |
149 | template <typename T> |
150 | struct DiagPartFunctor<CPUDevice, T> { |
151 | EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context, |
152 | const int64_t size, const T* in, |
153 | T* out) { |
    // Each shard extracts the diagonal values in the output index range
    // [start, limit).
156 | auto subDiagPart = [in, out, size](int64_t start, int64_t limit) { |
157 | for (int64_t index = start; index < limit; ++index) { |
158 | out[index] = in[(1 + size) * index]; |
159 | } |
160 | }; |
161 | |
    // Here, 5 is an empirical per-element cost_per_unit factor.
163 | auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); |
164 | Shard(worker_threads.num_threads, worker_threads.workers, size, 5, |
165 | subDiagPart); |
166 | return OkStatus(); |
167 | } |
168 | }; |
169 | } // namespace functor |
170 | |
171 | // Register the CPU kernels. |
172 | #define REGISTER_DIAGOP(T) \ |
173 | REGISTER_KERNEL_BUILDER( \ |
174 | Name("Diag").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ |
175 | DiagOp<CPUDevice, T>) |
176 | |
177 | TF_CALL_double(REGISTER_DIAGOP); |
178 | TF_CALL_float(REGISTER_DIAGOP); |
179 | TF_CALL_int32(REGISTER_DIAGOP); |
180 | TF_CALL_int64(REGISTER_DIAGOP); |
181 | TF_CALL_COMPLEX_TYPES(REGISTER_DIAGOP); |
182 | TF_CALL_half(REGISTER_DIAGOP); |
183 | #undef REGISTER_DIAGOP |
184 | |
185 | #define REGISTER_DIAGPARTOP(T) \ |
186 | REGISTER_KERNEL_BUILDER( \ |
187 | Name("DiagPart").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ |
188 | DiagPartOp<CPUDevice, T>) |
189 | |
190 | TF_CALL_double(REGISTER_DIAGPARTOP); |
191 | TF_CALL_float(REGISTER_DIAGPARTOP); |
192 | TF_CALL_int32(REGISTER_DIAGPARTOP); |
193 | TF_CALL_int64(REGISTER_DIAGPARTOP); |
194 | TF_CALL_COMPLEX_TYPES(REGISTER_DIAGPARTOP); |
195 | TF_CALL_half(REGISTER_DIAGPARTOP); |
196 | #undef REGISTER_DIAGPARTOP |
197 | |
198 | // Register the GPU kernels. |
199 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
200 | |
201 | // Forward declarations of the functor specializations for GPU. |
202 | namespace functor { |
203 | extern template struct DiagFunctor<GPUDevice, double>; |
204 | extern template struct DiagFunctor<GPUDevice, float>; |
205 | extern template struct DiagFunctor<GPUDevice, int32>; |
206 | extern template struct DiagFunctor<GPUDevice, int64_t>; |
207 | extern template struct DiagFunctor<GPUDevice, complex64>; |
208 | extern template struct DiagFunctor<GPUDevice, complex128>; |
209 | } // namespace functor |
210 | |
211 | #define REGISTER_DIAGOP_GPU(T) \ |
212 | REGISTER_KERNEL_BUILDER( \ |
213 | Name("Diag").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ |
214 | DiagOp<GPUDevice, T>) |
215 | |
216 | TF_CALL_double(REGISTER_DIAGOP_GPU); |
217 | TF_CALL_float(REGISTER_DIAGOP_GPU); |
218 | TF_CALL_int32(REGISTER_DIAGOP_GPU); |
219 | TF_CALL_int64(REGISTER_DIAGOP_GPU); |
220 | TF_CALL_COMPLEX_TYPES(REGISTER_DIAGOP_GPU); |
221 | TF_CALL_half(REGISTER_DIAGOP_GPU); |
222 | #undef REGISTER_DIAGOP_GPU |
223 | |
224 | // Forward declarations of the functor specializations for GPU. |
225 | namespace functor { |
226 | extern template struct DiagPartFunctor<GPUDevice, double>; |
227 | extern template struct DiagPartFunctor<GPUDevice, float>; |
228 | extern template struct DiagPartFunctor<GPUDevice, int32>; |
229 | extern template struct DiagPartFunctor<GPUDevice, int64_t>; |
230 | extern template struct DiagPartFunctor<GPUDevice, complex64>; |
231 | extern template struct DiagPartFunctor<GPUDevice, complex128>; |
232 | extern template struct DiagPartFunctor<GPUDevice, Eigen::half>; |
233 | } // namespace functor |
234 | |
235 | #define REGISTER_DIAGPARTOP_GPU(T) \ |
236 | REGISTER_KERNEL_BUILDER( \ |
237 | Name("DiagPart").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ |
238 | DiagPartOp<GPUDevice, T>) |
239 | |
240 | TF_CALL_double(REGISTER_DIAGPARTOP_GPU); |
241 | TF_CALL_float(REGISTER_DIAGPARTOP_GPU); |
242 | TF_CALL_int32(REGISTER_DIAGPARTOP_GPU); |
243 | TF_CALL_int64(REGISTER_DIAGPARTOP_GPU); |
244 | TF_CALL_COMPLEX_TYPES(REGISTER_DIAGPARTOP_GPU); |
245 | TF_CALL_half(REGISTER_DIAGPARTOP_GPU); |
246 | #undef REGISTER_DIAGPARTOP_GPU |
247 | |
248 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
249 | |
250 | } // namespace tensorflow |
251 | |