/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/diag_op.h"

#include <algorithm>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Generate the diagonal tensor with the diagonal set to the input tensor.
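//
// For example, a rank-1 input of shape [4] with values [1, 2, 3, 4] produces
// a rank-2 output of shape [4, 4] whose main diagonal holds the input values
// and whose off-diagonal entries are zero:
//
//   [[1, 0, 0, 0],
//    [0, 2, 0, 0],
//    [0, 0, 3, 0],
//    [0, 0, 0, 4]]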
template <typename Device, typename T>
class DiagOp : public OpKernel {
 public:
  explicit DiagOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& diagonal = context->input(0);
    const int num_dims = diagonal.dims();
    OP_REQUIRES(
        context, 0 != num_dims,
        errors::InvalidArgument("Input must be at least rank 1, got 0"));
    TensorShape out_shape;
    for (int i = 0; i < num_dims; ++i) {
      out_shape.AddDim(diagonal.dim_size(i));
    }
    for (int i = 0; i < num_dims; ++i) {
      out_shape.AddDim(diagonal.dim_size(i));
    }
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, out_shape, &output_tensor));
    functor::DiagFunctor<Device, T> diagFunc;
    Status s =
        diagFunc(context, diagonal.NumElements(), diagonal.flat<T>().data(),
                 output_tensor->flat<T>().data());
    OP_REQUIRES_OK(context, s);
  }
};

// Extract the diagonal from the input tensor; the inverse of the Diag op
// above.
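//
// For example, a rank-2 input of shape [4, 4] yields a rank-1 output of
// shape [4] holding the main diagonal of the input; all off-diagonal entries
// are ignored. More generally, an input of shape [s1,..., sk, s1,..., sk]
// yields an output of shape [s1,..., sk].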
template <typename Device, typename T>
class DiagPartOp : public OpKernel {
 public:
  explicit DiagPartOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& tensor = context->input(0);
    const int num_dims = tensor.dims();
    const int out_dims = num_dims / 2;
    OP_REQUIRES(context, 0 == num_dims % 2,
                errors::InvalidArgument(
                    "The rank of the tensor should be even and positive, "
                    "got shape ",
                    tensor.shape().DebugString()));
    for (int i = 0; i < out_dims; i++) {
      OP_REQUIRES(
          context, tensor.dim_size(i) == tensor.dim_size(i + out_dims),
          errors::InvalidArgument("Invalid shape ",
                                  tensor.shape().DebugString(), ": dimensions ",
                                  i, " and ", i + out_dims, " do not match."));
    }

    TensorShape out_shape;
    for (int i = 0; i < out_dims; ++i) {
      out_shape.AddDim(tensor.dim_size(i));
    }

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
    functor::DiagPartFunctor<Device, T> diagPartFunc;
    Status s = diagPartFunc(context, out_shape.num_elements(),
                            tensor.flat<T>().data(), output->flat<T>().data());
    OP_REQUIRES_OK(context, s);
  }
};

// Implementation of the functor specialization for CPU.
//
// According to the diagonal definition,
// `output[i1,..., ik, i1,..., ik] = input[i1,..., ik]`.
//
// Let the shape of the input be [s1,..., sk]; then any offset into the
// input's buffer can be represented by a coordinate [i1,..., ik], where
// `index = i1*(s2*...*sk) + i2*(s3*...*sk) +... + ik`.
//
// Let new_index be the offset into the output's buffer for the coordinate
// [i1,..., ik, i1,..., ik]. Then we have
// `new_index = i1*(s2*...*sk*s1*...*sk) + i2*(s3*...*sk*s1*...*sk) +... +
//              ik*(s1*...*sk) + i1*(s2*...*sk) + i2*(s3*...*sk) +... + ik
//            = (i1*(s2*...*sk) + i2*(s3*...*sk) +... + ik) * (1 + s1*...*sk)
//            = index * (1 + s1*...*sk)`
//
// Letting `size = s1*...*sk`, we finally have `new_index = index * (1 + size)`,
// which is the transfer function used below.
// This trick keeps the implementation clear and easy to parallelize.
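//
// As a quick sanity check: for an input of shape [2, 2] (so size = 4), the
// element at index = 2, i.e. coordinate [1, 0], maps to
// new_index = 2 * (1 + 4) = 10, which in the output of shape [2, 2, 2, 2]
// is exactly coordinate [1, 0, 1, 0] (1*8 + 0*4 + 1*2 + 0*1 = 10).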
namespace functor {
template <typename T>
struct DiagFunctor<CPUDevice, T> {
  EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context,
                                        const int64_t size, const T* in,
                                        T* out) {
    // This shard is responsible for writing the output values in the index
    // range [start*size, limit*size).
    auto subDiag = [in, out, size](int64_t start, int64_t limit) {
      std::fill(out + size * start, out + size * limit, T());
      for (int64_t index = start; index < limit; ++index) {
        out[(1 + size) * index] = in[index];
      }
    };

    // Here, 5 is an empirical factor of cost_per_unit; each unit of work
    // touches a full output row of `size` elements, hence 5 * size below.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, size, 5 * size,
          subDiag);
    return OkStatus();
  }
};

template <typename T>
struct DiagPartFunctor<CPUDevice, T> {
  EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context,
                                        const int64_t size, const T* in,
                                        T* out) {
    // This shard is responsible for extracting the diagonal values in the
    // index range [start, limit).
    auto subDiagPart = [in, out, size](int64_t start, int64_t limit) {
      for (int64_t index = start; index < limit; ++index) {
        out[index] = in[(1 + size) * index];
      }
    };

    // Here, 5 is an empirical factor of cost_per_unit; each unit of work
    // reads and writes a single element, hence a constant cost.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, size, 5,
          subDiagPart);
    return OkStatus();
  }
};
}  // namespace functor

// Register the CPU kernels.
#define REGISTER_DIAGOP(T)                                    \
  REGISTER_KERNEL_BUILDER(                                    \
      Name("Diag").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DiagOp<CPUDevice, T>)

TF_CALL_double(REGISTER_DIAGOP);
TF_CALL_float(REGISTER_DIAGOP);
TF_CALL_int32(REGISTER_DIAGOP);
TF_CALL_int64(REGISTER_DIAGOP);
TF_CALL_COMPLEX_TYPES(REGISTER_DIAGOP);
TF_CALL_half(REGISTER_DIAGOP);
#undef REGISTER_DIAGOP

#define REGISTER_DIAGPARTOP(T)                                    \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("DiagPart").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DiagPartOp<CPUDevice, T>)

TF_CALL_double(REGISTER_DIAGPARTOP);
TF_CALL_float(REGISTER_DIAGPARTOP);
TF_CALL_int32(REGISTER_DIAGPARTOP);
TF_CALL_int64(REGISTER_DIAGPARTOP);
TF_CALL_COMPLEX_TYPES(REGISTER_DIAGPARTOP);
TF_CALL_half(REGISTER_DIAGPARTOP);
#undef REGISTER_DIAGPARTOP

// Register the GPU kernels.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Forward declarations of the functor specializations for GPU.
namespace functor {
extern template struct DiagFunctor<GPUDevice, double>;
extern template struct DiagFunctor<GPUDevice, float>;
extern template struct DiagFunctor<GPUDevice, int32>;
extern template struct DiagFunctor<GPUDevice, int64_t>;
extern template struct DiagFunctor<GPUDevice, complex64>;
extern template struct DiagFunctor<GPUDevice, complex128>;
extern template struct DiagFunctor<GPUDevice, Eigen::half>;
}  // namespace functor

#define REGISTER_DIAGOP_GPU(T)                                \
  REGISTER_KERNEL_BUILDER(                                    \
      Name("Diag").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DiagOp<GPUDevice, T>)

TF_CALL_double(REGISTER_DIAGOP_GPU);
TF_CALL_float(REGISTER_DIAGOP_GPU);
TF_CALL_int32(REGISTER_DIAGOP_GPU);
TF_CALL_int64(REGISTER_DIAGOP_GPU);
TF_CALL_COMPLEX_TYPES(REGISTER_DIAGOP_GPU);
TF_CALL_half(REGISTER_DIAGOP_GPU);
#undef REGISTER_DIAGOP_GPU

// Forward declarations of the functor specializations for GPU.
namespace functor {
extern template struct DiagPartFunctor<GPUDevice, double>;
extern template struct DiagPartFunctor<GPUDevice, float>;
extern template struct DiagPartFunctor<GPUDevice, int32>;
extern template struct DiagPartFunctor<GPUDevice, int64_t>;
extern template struct DiagPartFunctor<GPUDevice, complex64>;
extern template struct DiagPartFunctor<GPUDevice, complex128>;
extern template struct DiagPartFunctor<GPUDevice, Eigen::half>;
}  // namespace functor

#define REGISTER_DIAGPARTOP_GPU(T)                                \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("DiagPart").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      DiagPartOp<GPUDevice, T>)

TF_CALL_double(REGISTER_DIAGPARTOP_GPU);
TF_CALL_float(REGISTER_DIAGPARTOP_GPU);
TF_CALL_int32(REGISTER_DIAGPARTOP_GPU);
TF_CALL_int64(REGISTER_DIAGPARTOP_GPU);
TF_CALL_COMPLEX_TYPES(REGISTER_DIAGPARTOP_GPU);
TF_CALL_half(REGISTER_DIAGPARTOP_GPU);
#undef REGISTER_DIAGPARTOP_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow