/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
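
// SparseSlice op kernels: given a sparse tensor in COO form (indices, values,
// dense shape) plus `start` and `size` vectors, produce the sparse sub-tensor
// covered by the slice window.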

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_slice_op.h"

#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

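// CPU specialization of SparseSliceFunctor. Builds a sparse::SparseTensor
// from the input indices/values/shape, slices it synchronously on the host,
// and emits the resulting indices, values, and dense shape as outputs 0-2.
// The `done` callback is accepted only to match the shared functor signature.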
template <typename T>
struct SparseSliceFunctor<CPUDevice, T> {
  void operator()(OpKernelContext* context, const Tensor& input_indices,
                  const Tensor& input_values, const Tensor& input_shape,
                  const Tensor& input_start, const Tensor& input_size,
                  typename AsyncOpKernel::DoneCallback done) const {
    (void)done;  // Unused (only used in the GPU implementation).
    const int input_dims = input_shape.NumElements();

    sparse::SparseTensor sparse_tensor;
    TensorShape sparse_tensor_shape;
    OP_REQUIRES_OK(context,
                   TensorShapeBase<TensorShape>::BuildTensorShapeBase(
                       input_shape.vec<int64_t>(), &sparse_tensor_shape));
    OP_REQUIRES_OK(context,
                   sparse::SparseTensor::Create(input_indices, input_values,
                                                sparse_tensor_shape,
                                                &sparse_tensor));

    const gtl::ArraySlice<int64_t> start(input_start.flat<int64_t>().data(),
                                         input_dims);
    const gtl::ArraySlice<int64_t> size(input_size.flat<int64_t>().data(),
                                        input_dims);

    const StatusOr<sparse::SparseTensor> output_or =
        sparse::SparseTensor::Slice<T>(sparse_tensor, start, size);
    OP_REQUIRES_OK(context, output_or.status());
    auto output = output_or.value();

    context->set_output(0, output.indices());
    context->set_output(1, output.values());

    TensorShape output_shape;
    OP_REQUIRES_OK(context, TensorShapeBase<TensorShape>::BuildTensorShapeBase(
                                output.shape(), &output_shape));

    TensorShape allocated_shape;
    OP_REQUIRES_OK(context, TensorShapeBase<TensorShape>::BuildTensorShapeBase(
                                {output_shape.dims()}, &allocated_shape));

    Tensor* shape = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, allocated_shape, &shape));
    for (int dim = 0; dim < output_shape.dims(); ++dim) {
      shape->vec<int64_t>()(dim) = output_shape.dim_size(dim);
    }
  }
};

}  // namespace functor

namespace {

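// Shared implementation for the synchronous CPU kernel and the asynchronous
// GPU kernel. Validates the ranks and lengths of the five inputs (indices,
// values, shape, start, size), then dispatches to the device-specific
// SparseSliceFunctor. When called from the synchronous path, `done` defaults
// to a no-op callback so that OP_REQUIRES_ASYNC can be used unconditionally.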
template <typename Device, typename T>
void SparseSliceOpImpl(OpKernelContext* context,
                       typename AsyncOpKernel::DoneCallback done = nullptr) {
  // Note that setting this empty lambda as the default parameter value
  // directly can cause strange compiler/linker errors, so we do it like this
  // instead.
  if (!done) {
    done = [] {};
  }

  const Tensor& input_indices = context->input(0);
  const Tensor& input_values = context->input(1);
  const Tensor& input_shape = context->input(2);
  const Tensor& input_start = context->input(3);
  const Tensor& input_size = context->input(4);

  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
                    errors::InvalidArgument(
                        "Input indices should be a matrix but received shape ",
                        input_indices.shape().DebugString()),
                    done);
  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(input_values.shape()),
                    errors::InvalidArgument(
                        "Input values should be a vector but received shape ",
                        input_values.shape().DebugString()),
                    done);
  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(input_shape.shape()),
                    errors::InvalidArgument(
                        "Input shape should be a vector but received shape ",
                        input_shape.shape().DebugString()),
                    done);
  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(input_start.shape()),
                    errors::InvalidArgument(
                        "Input start should be a vector but received shape ",
                        input_start.shape().DebugString()),
                    done);
  OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(input_size.shape()),
                    errors::InvalidArgument(
                        "Input size should be a vector but received shape ",
                        input_size.shape().DebugString()),
                    done);

  const int input_dims = input_shape.NumElements();
  OP_REQUIRES_ASYNC(context, input_dims == input_start.NumElements(),
                    errors::InvalidArgument(
                        "Expected start to be a vector of length ", input_dims,
                        " but got length ", input_start.NumElements()),
                    done);

  OP_REQUIRES_ASYNC(context, input_dims == input_size.NumElements(),
                    errors::InvalidArgument(
                        "Expected size to be a vector of length ", input_dims,
                        " but got length ", input_size.NumElements()),
                    done);

  functor::SparseSliceFunctor<Device, T>()(context, input_indices,
                                           input_values, input_shape,
                                           input_start, input_size, done);
}

}  // namespace

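// Synchronous CPU kernel for the SparseSlice op.
//
// Illustrative semantics (a hand-worked example, not from this file): slicing
// the 3x4 sparse tensor
//   indices = [[0, 0], [1, 2]], values = [1, 2], dense_shape = [3, 4]
// with start = [0, 0] and size = [1, 4] keeps only the entries whose
// coordinates fall inside the slice window and re-bases them against `start`:
//   indices = [[0, 0]], values = [1], dense_shape = [1, 4].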
template <typename Device, typename T>
class SparseSliceOp : public OpKernel {
 public:
  explicit SparseSliceOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    SparseSliceOpImpl<Device, T>(context);
  }
};

#define REGISTER_KERNELS(type)                                          \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("SparseSlice").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SparseSliceOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

typedef Eigen::GpuDevice GPUDevice;

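// Asynchronous GPU kernel for the SparseSlice op. The GPU functor completes
// asynchronously, so this kernel derives from AsyncOpKernel and forwards the
// `done` callback through SparseSliceOpImpl to the functor.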
template <typename T>
class SparseSliceGPUOp : public AsyncOpKernel {
 public:
  explicit SparseSliceGPUOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    SparseSliceOpImpl<GPUDevice, T>(context, done);
  }
};

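// The shape, start, and size inputs and the output_shape output are small
// metadata vectors that the implementation reads and writes on the host, so
// they are pinned to host memory even for the GPU kernel.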
#define REGISTER_KERNELS(type)                            \
  REGISTER_KERNEL_BUILDER(Name("SparseSlice")             \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("shape")        \
                              .HostMemory("start")        \
                              .HostMemory("size")         \
                              .HostMemory("output_shape") \
                              .TypeConstraint<type>("T"), \
                          SparseSliceGPUOp<type>)

TF_CALL_POD_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow