/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_split_op.h"

#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

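// CPU specialization: materializes a SparseTensor from the (indices, values,
// shape) triple, splits it along `axis` into `num_split` slices, and emits
// each slice as its own (indices, values, shape) triple.
//
// For illustration: splitting a 2x2 sparse tensor with entries
// {(0,0)->1, (1,1)->2} along axis 0 with num_split=2 yields two 1x2 slices,
// with indices rebased to each slice's origin:
//   slice 0: indices=[[0, 0]], values=[1], shape=[1, 2]
//   slice 1: indices=[[0, 1]], values=[2], shape=[1, 2]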
template <typename T>
struct SparseSplitFunctor<CPUDevice, T> {
  void operator()(OpKernelContext* context, const Tensor& input_indices,
                  const Tensor& input_values, const TensorShape& dense_shape,
                  const int64_t axis, const int num_split,
                  typename AsyncOpKernel::DoneCallback done) {
    (void)done;  // Unused (only used in the GPU implementation)
    sparse::SparseTensor sparse_tensor;
    OP_REQUIRES_OK(context,
                   sparse::SparseTensor::Create(input_indices, input_values,
                                                dense_shape, &sparse_tensor));

    std::vector<sparse::SparseTensor> outputs;
    OP_REQUIRES_OK(context, sparse::SparseTensor::Split<T>(
                                sparse_tensor, axis, num_split, &outputs));

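    // The op's outputs are grouped by kind rather than by slice: slots
    // [0, num_split) hold the slice indices, [num_split, 2 * num_split) the
    // slice values, and [2 * num_split, 3 * num_split) the slice shapes.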
    for (int slice_index = 0; slice_index < num_split; ++slice_index) {
      context->set_output(slice_index, outputs[slice_index].indices());
      context->set_output(slice_index + num_split,
                          outputs[slice_index].values());
      Tensor* shape = nullptr;
      OP_REQUIRES_OK(context, context->allocate_output(
                                  slice_index + 2 * num_split,
                                  {outputs[slice_index].dims()}, &shape));
      auto output_shape = outputs[slice_index].shape();
      for (int dim = 0; dim < outputs[slice_index].dims(); ++dim) {
        shape->vec<int64_t>()(dim) = output_shape[dim];
      }
    }
  }
};

}  // namespace functor

namespace {

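// Shared implementation behind the synchronous CPU kernel and the
// asynchronous GPU kernel: validates the four inputs, normalizes the axis,
// and dispatches to the device-specific functor.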
template <typename Device, typename T>
void SparseSplitOpImpl(OpKernelContext* context, int num_split,
                       AsyncOpKernel::DoneCallback done = nullptr) {
  // Using an empty lambda as the default value of `done` directly can cause
  // strange compiler/linker errors, so we default to nullptr and substitute
  // a no-op callback here instead.
  if (!done) {
    done = [] {};
  }

  const Tensor& input_axis = context->input(0);
  const Tensor& input_indices = context->input(1);
  const Tensor& input_values = context->input(2);
  const Tensor& input_shape = context->input(3);

  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsScalar(input_axis.shape()),
                    errors::InvalidArgument(
                        "Input axis should be a scalar but received shape ",
                        input_axis.shape().DebugString()),
                    done);
  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
                    errors::InvalidArgument(
                        "Input indices should be a matrix but received shape ",
                        input_indices.shape().DebugString()),
                    done);
  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(input_values.shape()),
                    errors::InvalidArgument(
                        "Input values should be a vector but received shape ",
                        input_values.shape().DebugString()),
                    done);
  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(input_shape.shape()),
                    errors::InvalidArgument(
                        "Input shape should be a vector but received shape ",
                        input_shape.shape().DebugString()),
                    done);

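  // A negative axis counts from the end (Python-style); normalize it into
  // [0, input_rank) before validating.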
  const int64_t axis_input = input_axis.scalar<int64_t>()();
  const int64_t input_rank = input_shape.vec<int64_t>().size();
  const int64_t axis = (axis_input < 0) ? input_rank + axis_input : axis_input;

  OP_REQUIRES_ASYNC(
      context, axis >= 0 && axis < input_rank,
      errors::InvalidArgument("Input axis should be in range [", -input_rank,
                              ", ", input_rank, "), got ", axis_input),
      done);

  OP_REQUIRES_ASYNC(
      context, num_split >= 1 && num_split <= input_shape.vec<int64_t>()(axis),
      errors::InvalidArgument("Input num_split should be between 1 "
                              "and the splitting dimension size (",
                              input_shape.vec<int64_t>()(axis), "), got ",
                              num_split),
      done);

  // Construct the dense shape one dimension at a time via AddDimWithStatus so
  // that an invalid or overflowing dimension size is rejected with an error
  // rather than triggering a crash.
  TensorShape dense_shape;
  const auto input_shape_flat = input_shape.flat<int64_t>();
  for (int i = 0; i < input_shape.NumElements(); i++) {
    OP_REQUIRES_OK_ASYNC(
        context, dense_shape.AddDimWithStatus(input_shape_flat(i)), done);
  }

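  // Hand off to the device functor. On CPU this runs synchronously and
  // ignores `done`; on GPU it completes asynchronously and invokes `done`
  // when finished.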
  functor::SparseSplitFunctor<Device, T>()(context, input_indices,
                                           input_values, dense_shape, axis,
                                           num_split, done);
}

}  // namespace

template <typename T>
class SparseSplitOp : public OpKernel {
 public:
  explicit SparseSplitOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_split", &num_split_));
  }

  void Compute(OpKernelContext* context) override {
    SparseSplitOpImpl<CPUDevice, T>(context, num_split_);
  }

 private:
  int num_split_;
};

#define REGISTER_KERNELS(type)                                          \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("SparseSplit").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SparseSplitOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
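
// A minimal sketch of how this kernel is reached from the Python API
// (illustrative only; the tensor literals are made up for the example):
//
//   sp = tf.sparse.SparseTensor(indices=[[0, 0], [1, 1]], values=[1, 2],
//                               dense_shape=[2, 2])
//   slices = tf.sparse.split(sp_input=sp, num_split=2, axis=0)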

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

typedef Eigen::GpuDevice GPUDevice;

// The GPU implementation is async because it must wait for a device-to-host
// memcpy before the outputs can be allocated (similar to SegmentSumGPUOp).
template <typename T>
class SparseSplitGPUOp : public AsyncOpKernel {
 public:
  explicit SparseSplitGPUOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_split", &num_split_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    SparseSplitOpImpl<GPUDevice, T>(context, num_split_, done);
  }

 private:
  int num_split_;
};

#define REGISTER_KERNELS(type)                            \
  REGISTER_KERNEL_BUILDER(Name("SparseSplit")             \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("split_dim")    \
                              .HostMemory("shape")        \
                              .HostMemory("output_shape") \
                              .TypeConstraint<type>("T"), \
                          SparseSplitGPUOp<type>)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow