/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_split_op.h"

#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

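// CPU specialization of SparseSplitFunctor: builds a SparseTensor from the
// input indices/values/shape, splits it along `axis` into `num_split` slices,
// and emits 3 * num_split outputs: slice i's indices at output [i], its
// values at [i + num_split], and its dense shape at [i + 2 * num_split].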
template <typename T>
struct SparseSplitFunctor<CPUDevice, T> {
  void operator()(OpKernelContext* context, const Tensor& input_indices,
                  const Tensor& input_values, const TensorShape& dense_shape,
                  const int64_t axis, const int num_split,
                  typename AsyncOpKernel::DoneCallback done) {
    (void)done;  // Unused; the done callback is only needed by the GPU path.
    sparse::SparseTensor sparse_tensor;
    OP_REQUIRES_OK(context,
                   sparse::SparseTensor::Create(input_indices, input_values,
                                                dense_shape, &sparse_tensor));

    std::vector<sparse::SparseTensor> outputs;
    OP_REQUIRES_OK(context, sparse::SparseTensor::Split<T>(
                                sparse_tensor, axis, num_split, &outputs));

    for (int slice_index = 0; slice_index < num_split; ++slice_index) {
      context->set_output(slice_index, outputs[slice_index].indices());
      context->set_output(slice_index + num_split,
                          outputs[slice_index].values());
      Tensor* shape = nullptr;
      OP_REQUIRES_OK(context, context->allocate_output(
                                  slice_index + 2 * num_split,
                                  {outputs[slice_index].dims()}, &shape));
      auto output_shape = outputs[slice_index].shape();
      for (int dim = 0; dim < outputs[slice_index].dims(); ++dim) {
        shape->vec<int64_t>()(dim) = output_shape[dim];
      }
    }
  }
};
63
64} // namespace functor
65
66namespace {
67
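// Shared implementation for the CPU and GPU kernels. Validates the four
// inputs (split_dim scalar, indices matrix, values vector, shape vector),
// normalizes a negative axis, checks that num_split is in range, rebuilds the
// dense shape with overflow checks, and then dispatches to the device
// functor. `done` is only meaningful for the asynchronous GPU path; the CPU
// functor ignores it.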
template <typename Device, typename T>
void SparseSplitOpImpl(OpKernelContext* context, int num_split,
                       AsyncOpKernel::DoneCallback done = nullptr) {
  // Note that using an empty lambda directly as the default parameter value
  // can cause strange compiler/linker errors, so we do it like this instead.
  if (!done) {
    done = [] {};
  }

  const Tensor& input_axis = context->input(0);
  const Tensor& input_indices = context->input(1);
  const Tensor& input_values = context->input(2);
  const Tensor& input_shape = context->input(3);

  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsScalar(input_axis.shape()),
                    errors::InvalidArgument(
                        "Input axis should be a scalar but received shape ",
                        input_axis.shape().DebugString()),
                    done);
  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
                    errors::InvalidArgument(
                        "Input indices should be a matrix but received shape ",
                        input_indices.shape().DebugString()),
                    done);
  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(input_values.shape()),
                    errors::InvalidArgument(
                        "Input values should be a vector but received shape ",
                        input_values.shape().DebugString()),
                    done);
  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(input_shape.shape()),
                    errors::InvalidArgument(
                        "Input shape should be a vector but received shape ",
                        input_shape.shape().DebugString()),
                    done);

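  // A negative split_dim counts from the end of the dense shape, e.g. -1
  // refers to the last dimension.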
  const int64_t axis_input = input_axis.scalar<int64_t>()();
  const int64_t input_rank = input_shape.vec<int64_t>().size();
  const int64_t axis = (axis_input < 0) ? input_rank + axis_input : axis_input;

  OP_REQUIRES_ASYNC(
      context, axis >= 0 && axis < input_rank,
      errors::InvalidArgument("Input axis should be in range [", -input_rank,
                              ", ", input_rank, "), got ", axis_input),
      done);

  OP_REQUIRES_ASYNC(
      context, num_split >= 1 && num_split <= input_shape.vec<int64_t>()(axis),
      errors::InvalidArgument("Input num_split should be between 1 "
                              "and the splitting dimension size (",
                              input_shape.vec<int64_t>()(axis), "), got ",
                              num_split),
      done);

  // Build the dense shape one dimension at a time so that AddDimWithStatus
  // can reject invalid or overflowing dimension sizes.
  TensorShape dense_shape;
  const auto input_shape_flat = input_shape.flat<int64_t>();
  for (int i = 0; i < input_shape.NumElements(); i++) {
    OP_REQUIRES_OK_ASYNC(
        context, dense_shape.AddDimWithStatus(input_shape_flat(i)), done);
  }

  functor::SparseSplitFunctor<Device, T>()(context, input_indices,
                                           input_values, dense_shape, axis,
                                           num_split, done);
}

}  // namespace

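// Synchronous CPU kernel: reads the num_split attribute at construction time
// and forwards each Compute call to the shared implementation above.
//
// For illustration (mirroring the SparseSplit op documentation): splitting a
// sparse tensor with dense shape [2, 7] along axis 1 with num_split = 2
// yields slices with dense shapes [2, 4] and [2, 3]; when the split dimension
// is not evenly divisible, the leading slices receive the extra elements.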
template <typename T>
class SparseSplitOp : public OpKernel {
 public:
  explicit SparseSplitOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_split", &num_split_));
  }

  void Compute(OpKernelContext* context) override {
    SparseSplitOpImpl<CPUDevice, T>(context, num_split_);
  }

 private:
  int num_split_;
};

#define REGISTER_KERNELS(type)                                          \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("SparseSplit").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SparseSplitOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

typedef Eigen::GpuDevice GPUDevice;

// The GPU implementation is async because it requires waiting for a
// host->device memcpy before the output is allocated (similar to
// SegmentSumGPUOp).
template <typename T>
class SparseSplitGPUOp : public AsyncOpKernel {
 public:
  explicit SparseSplitGPUOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("num_split", &num_split_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    SparseSplitOpImpl<GPUDevice, T>(context, num_split_, done);
  }

 private:
  int num_split_;
};

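// The split_dim and shape inputs are placed in host memory because the shared
// implementation above reads them on the host during validation; output_shape
// is likewise registered as host memory.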
#define REGISTER_KERNELS(type)                            \
  REGISTER_KERNEL_BUILDER(Name("SparseSplit")             \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("split_dim")    \
                              .HostMemory("shape")        \
                              .HostMemory("output_shape") \
                              .TypeConstraint<type>("T"), \
                          SparseSplitGPUOp<type>)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow