1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#define EIGEN_USE_THREADS
17
18#include "tensorflow/core/kernels/sparse_reorder_op.h"
19
20#include <algorithm>
21#include <numeric>
22#include <unordered_map>
23#include <utility>
24
25#include "tensorflow/core/framework/op_kernel.h"
26#include "tensorflow/core/framework/register_types.h"
27#include "tensorflow/core/framework/tensor.h"
28#include "tensorflow/core/framework/tensor_util.h"
29#include "tensorflow/core/framework/types.h"
30#include "tensorflow/core/lib/gtl/inlined_vector.h"
31#include "tensorflow/core/util/sparse/sparse_tensor.h"
32
33namespace tensorflow {
34
35using CPUDevice = Eigen::ThreadPoolDevice;
36
37#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
38using GPUDevice = Eigen::GpuDevice;
39#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
40
41namespace functor {
42
43template <typename T>
44struct SparseReorderFunctor<CPUDevice, T> {
45 void operator()(OpKernelContext* context, const Tensor& input_ind,
46 const Tensor& input_val, const Tensor& input_shape_in) {
47 gtl::ArraySlice<int64_t> input_shape(input_shape_in.vec<int64_t>().data(),
48 input_shape_in.NumElements());
49
50 gtl::InlinedVector<int64_t, 8> std_order(input_shape.size());
51 std::iota(std_order.begin(), std_order.end(), 0);
52
53 // Check if the sparse tensor is already ordered correctly
54 sparse::SparseTensor input_sp;
55 OP_REQUIRES_OK(
56 context, sparse::SparseTensor::Create(input_ind, input_val, input_shape,
57 std_order, &input_sp));
58
59 if (input_sp.IndicesValid().ok()) {
60 context->set_output(0, input_sp.indices());
61 context->set_output(1, input_sp.values());
62 } else {
63 // Deep-copy the input Tensors, then reorder in-place
64 sparse::SparseTensor reordered_sp;
65 OP_REQUIRES_OK(context,
66 sparse::SparseTensor::Create(tensor::DeepCopy(input_ind),
67 tensor::DeepCopy(input_val),
68 input_shape, &reordered_sp));
69 reordered_sp.Reorder<T>(std_order);
70 context->set_output(0, reordered_sp.indices());
71 context->set_output(1, reordered_sp.values());
72 }
73 }
74};
75
76} // namespace functor
77
78template <typename Device, typename T>
79class SparseReorderOp : public OpKernel {
80 public:
81 explicit SparseReorderOp(OpKernelConstruction* context) : OpKernel(context) {}
82
83 void Compute(OpKernelContext* context) override {
84 const Tensor& input_ind = context->input(0);
85 OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_ind.shape()),
86 errors::InvalidArgument(
87 "Input indices should be a matrix but received shape ",
88 input_ind.shape().DebugString()));
89
90 const Tensor& input_val = context->input(1);
91 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_val.shape()),
92 errors::InvalidArgument(
93 "Input values should be a vector but received shape ",
94 input_val.shape().DebugString()));
95
96 const Tensor& input_shape_in = context->input(2);
97 OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()),
98 errors::InvalidArgument(
99 "Input shape should be a vector but received shape ",
100 input_shape_in.shape().DebugString()));
101
102 functor::SparseReorderFunctor<Device, T>()(context, input_ind, input_val,
103 input_shape_in);
104 }
105};
106
// Registers the CPU implementation of SparseReorder for a single dtype "T".
#define REGISTER_KERNELS(type)                                            \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("SparseReorder").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SparseReorderOp<CPUDevice, type>)

// Instantiate the CPU kernel for every TensorFlow dtype.
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
114
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Registers the GPU implementation of SparseReorder for a single dtype "T".
// Only compiled in when building with CUDA or ROCm support.
#define REGISTER_GPU_KERNELS(type)                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("SparseReorder").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SparseReorderOp<GPUDevice, type>)

// Instantiate the GPU kernel for floating/number types, all integral
// types, and bool.
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_GPU_KERNELS);
REGISTER_GPU_KERNELS(bool);
#undef REGISTER_GPU_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
128
129} // namespace tensorflow
130