/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
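
// SparseReorder reorders the indices and values of a SparseTensor into
// canonical, row-major (lexicographic) order; it backs tf.sparse.reorder in
// the Python API. For example, a 2-D input with
//   indices = [[0, 3], [0, 1], [3, 1], [2, 0]]
// has its indices (and, in lockstep, its values) permuted to
//   indices = [[0, 1], [0, 3], [2, 0], [3, 1]]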

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_reorder_op.h"

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

using CPUDevice = Eigen::ThreadPoolDevice;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
using GPUDevice = Eigen::GpuDevice;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace functor {
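
// CPU specialization of SparseReorderFunctor. If the input indices are
// already in canonical order, the input tensors are forwarded unchanged;
// otherwise they are deep-copied and the copy is sorted.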
template <typename T>
struct SparseReorderFunctor<CPUDevice, T> {
  void operator()(OpKernelContext* context, const Tensor& input_ind,
                  const Tensor& input_val, const Tensor& input_shape_in) {
    gtl::ArraySlice<int64_t> input_shape(input_shape_in.vec<int64_t>().data(),
                                         input_shape_in.NumElements());
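
    // std_order is the canonical row-major dimension order: 0, 1, ..., rank-1.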
    gtl::InlinedVector<int64_t, 8> std_order(input_shape.size());
    std::iota(std_order.begin(), std_order.end(), 0);

    // Check whether the sparse tensor's indices are already in canonical
    // row-major order; if so, the input tensors can be forwarded unchanged.
    sparse::SparseTensor input_sp;
    OP_REQUIRES_OK(
        context, sparse::SparseTensor::Create(input_ind, input_val,
                                              input_shape, std_order,
                                              &input_sp));

    if (input_sp.IndicesValid().ok()) {
      context->set_output(0, input_sp.indices());
      context->set_output(1, input_sp.values());
    } else {
      // Indices are out of order: deep-copy the input Tensors so the
      // originals are left untouched, then reorder the copies in place
      // (a conceptual sketch of Reorder follows this struct).
      sparse::SparseTensor reordered_sp;
      OP_REQUIRES_OK(context,
                     sparse::SparseTensor::Create(tensor::DeepCopy(input_ind),
                                                  tensor::DeepCopy(input_val),
                                                  input_shape, &reordered_sp));
      reordered_sp.Reorder<T>(std_order);
      context->set_output(0, reordered_sp.indices());
      context->set_output(1, reordered_sp.values());
    }
  }
};
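
// For reference, a minimal sketch of what SparseTensor::Reorder<T> does
// conceptually (the real implementation lives in
// tensorflow/core/util/sparse/sparse_tensor.h; `ix`, `rank`, and
// `num_nonzeros` are illustrative names, not the actual API):
//
//   std::vector<int64_t> perm(num_nonzeros);
//   std::iota(perm.begin(), perm.end(), 0);
//   std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
//     // Compare index rows a and b lexicographically across dimensions.
//     for (int d = 0; d < rank; ++d) {
//       if (ix(a, d) != ix(b, d)) return ix(a, d) < ix(b, d);
//     }
//     return false;
//   });
//   // Finally, gather both indices and values through `perm`.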

}  // namespace functor
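
// OpKernel wrapper: validates the three inputs (indices matrix, values
// vector, shape vector) and dispatches to the device-specific functor.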
template <typename Device, typename T>
class SparseReorderOp : public OpKernel {
 public:
  explicit SparseReorderOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input_ind = context->input(0);
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_ind.shape()),
                errors::InvalidArgument(
                    "Input indices should be a matrix but received shape ",
                    input_ind.shape().DebugString()));

    const Tensor& input_val = context->input(1);
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_val.shape()),
                errors::InvalidArgument(
                    "Input values should be a vector but received shape ",
                    input_val.shape().DebugString()));

    const Tensor& input_shape_in = context->input(2);
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()),
                errors::InvalidArgument(
                    "Input shape should be a vector but received shape ",
                    input_shape_in.shape().DebugString()));

    functor::SparseReorderFunctor<Device, T>()(context, input_ind, input_val,
                                               input_shape_in);
  }
};
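
// Registers the CPU kernel; TF_CALL_ALL_TYPES instantiates REGISTER_KERNELS
// once for every standard TensorFlow dtype.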
#define REGISTER_KERNELS(type)                                            \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("SparseReorder").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SparseReorderOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
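
// Registers the GPU kernel for the floating-point, integral, and bool dtypes
// when building with CUDA or ROCm support.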
#define REGISTER_GPU_KERNELS(type)                                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("SparseReorder").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      SparseReorderOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_GPU_KERNELS);
REGISTER_GPU_KERNELS(bool);
#undef REGISTER_GPU_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow