/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_concat_op.h"

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/overflow.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

template <typename T>
struct SparseConcatFunctor<CPUDevice, T> {
  void operator()(OpKernelContext* context, const OpInputList& inds,
                  const OpInputList& vals, const OpInputList& shapes,
                  int concat_dim) {
    const int N = inds.size();
    const TensorShape input_shape(shapes[0].vec<int64_t>());
    const int input_rank = input_shape.dims();

    // The input and output sparse tensors are assumed to be ordered along
    // increasing dimension number. But in order for concat to work properly,
    // order[0] must be concat_dim. So we will reorder the inputs to the
    // concat ordering, concatenate, then reorder back to the standard order.
    // We make a deep copy of the input tensors to ensure that the in-place
    // reorder doesn't create race conditions for other ops that may be
    // concurrently reading the indices and values tensors.

    gtl::InlinedVector<int64_t, 8> std_order(input_rank);
    std::iota(std_order.begin(), std_order.end(), 0);

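    // Build the dimension permutation that places concat_dim first, followed
    // by the remaining dimensions in their usual order.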
    std::vector<int64_t> concat_order;
    concat_order.reserve(input_rank);
    concat_order.push_back(concat_dim);
    for (int j = 0; j < input_rank; ++j) {
      if (j != concat_dim) {
        concat_order.push_back(j);
      }
    }

    std::vector<sparse::SparseTensor> sp_inputs;
    for (int i = 0; i < N; ++i) {
      const TensorShape current_shape(shapes[i].vec<int64_t>());
      sparse::SparseTensor tensor;
      OP_REQUIRES_OK(context,
                     sparse::SparseTensor::Create(
                         tensor::DeepCopy(inds[i]), tensor::DeepCopy(vals[i]),
                         current_shape, std_order, &tensor));
      sp_inputs.push_back(std::move(tensor));
      sp_inputs[i].Reorder<T>(concat_order);
    }

    sparse::SparseTensor concat = sparse::SparseTensor::Concat<T>(sp_inputs);
    concat.Reorder<T>(std_order);

    context->set_output(0, concat.indices());
    context->set_output(1, concat.values());
  }
};

}  // namespace functor

template <typename Device, typename T>
class SparseConcatOp : public OpKernel {
 public:
  explicit SparseConcatOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("concat_dim", &concat_dim_attr_));
  }

  void Compute(OpKernelContext* context) override {
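    // Each indices input must be a rank-2 matrix of shape [nnz, rank].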
    OpInputList inds;
    OP_REQUIRES_OK(context, context->input_list("indices", &inds));
    const int N = inds.size();
    for (int i = 0; i < N; i++) {
      OP_REQUIRES(context, TensorShapeUtils::IsMatrix(inds[i].shape()),
                  errors::InvalidArgument(
                      "Input indices should be a matrix but received shape ",
                      inds[i].shape().DebugString(), " at position ", i));
    }

    OpInputList vals;
    OP_REQUIRES_OK(context, context->input_list("values", &vals));
    OP_REQUIRES(context, vals.size() == N,
                errors::InvalidArgument("Expected ", N, " input values, got ",
                                        vals.size()));
    for (int i = 0; i < N; i++) {
      OP_REQUIRES(context, TensorShapeUtils::IsVector(vals[i].shape()),
                  errors::InvalidArgument(
                      "Input values should be a vector but received shape ",
                      vals[i].shape().DebugString(), " at position ", i));
    }

    OpInputList shapes;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes));
    OP_REQUIRES(context, shapes.size() == N,
                errors::InvalidArgument("Expected ", N, " input shapes, got ",
                                        shapes.size()));
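    // Reject any input shape whose element count would overflow int64.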
    bool overflow_occurred = false;
    for (int i = 0; i < N; i++) {
      int64_t new_num_elements = 1;
      OP_REQUIRES(context, TensorShapeUtils::IsVector(shapes[i].shape()),
                  errors::InvalidArgument(
                      "Input shapes should be a vector but received shape ",
                      shapes[i].shape().DebugString(), " at position ", i));
      auto input_shape_vector = shapes[i].vec<int64_t>();
      for (int j = 0; j < input_shape_vector.size(); j++) {
        new_num_elements =
            MultiplyWithoutOverflow(new_num_elements, input_shape_vector(j));
        if (new_num_elements < 0) {
          overflow_occurred = true;
          break;
        }
      }

      if (overflow_occurred) {
        break;
      }
    }

    OP_REQUIRES(
        context, !overflow_occurred,
        errors::Internal("Encountered overflow from large input shape."));

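    // All inputs must share the rank of the first shape; a negative
    // concat_dim is interpreted relative to that rank.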
    const TensorShape input_shape(shapes[0].vec<int64_t>());
    const int input_rank = input_shape.dims();
    const int concat_dim = (concat_dim_attr_ < 0)
                               ? input_rank + concat_dim_attr_
                               : concat_dim_attr_;
    OP_REQUIRES(context, concat_dim >= 0 && concat_dim < input_rank,
                errors::InvalidArgument("Concat dimension must be in range [",
                                        -input_rank, ", ", input_rank,
                                        "), got ", concat_dim_attr_));
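    // Compute the dense output shape: non-concat dimensions must match across
    // inputs, while sizes along concat_dim are summed.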
    TensorShape output_shape = input_shape;
    for (int i = 1; i < N; ++i) {
      const TensorShape current_shape(shapes[i].vec<int64_t>());
      OP_REQUIRES(
          context, current_shape.dims() == input_rank,
          errors::InvalidArgument(
              "Ranks of all input tensors must match: expected ", input_rank,
              " but got ", current_shape.dims(), " at position ", i));
      for (int j = 0; j < input_rank; ++j) {
        if (j != concat_dim) {
          OP_REQUIRES(
              context, input_shape.dim_size(j) == current_shape.dim_size(j),
              errors::InvalidArgument(
                  "Input shapes must match: expected ", input_shape.dim_size(j),
                  " for dimension ", j, " but got ", current_shape.dim_size(j),
                  " at position ", i));
        } else {
          output_shape.set_dim(
              j, output_shape.dim_size(j) + current_shape.dim_size(j));
        }
      }
    }

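    // Emit the dense output shape (output 2) as a 1-D int64 tensor.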
    Tensor* output_shape_out = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(2, TensorShape({output_shape.dims()}),
                                          &output_shape_out));
    auto output_shape_t = output_shape_out->vec<int64_t>();
    for (int j = 0; j < output_shape.dims(); ++j) {
      output_shape_t(j) = output_shape.dim_size(j);
    }

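    // Total number of non-zero entries across all inputs.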
    int64_t output_nnz = 0;
    for (int i = 0; i < N; ++i) {
      output_nnz += inds[i].dim_size(0);
    }
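    // If every input is empty, produce empty indices and values outputs and
    // skip the device functor entirely.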
    if (output_nnz == 0) {
      Tensor* output_inds = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, TensorShape({0, input_rank}),
                                              &output_inds));
      Tensor* output_vals = nullptr;
      OP_REQUIRES_OK(
          context, context->allocate_output(1, TensorShape({0}), &output_vals));
      return;  // No work to do
    }

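    // Hand the actual concatenation off to the device-specific functor, which
    // produces the output indices and values.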
    functor::SparseConcatFunctor<Device, T>()(context, inds, vals, shapes,
                                              concat_dim);
  }

 private:
  int concat_dim_attr_;
};

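// Register the CPU kernel for all supported element types.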
#define REGISTER_KERNELS(type)                                            \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("SparseConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      SparseConcatOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

typedef Eigen::GpuDevice GPUDevice;

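// The shapes inputs and the output_shape output stay in host memory:
// Compute() reads and writes them on the CPU even for the GPU kernel.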
#define REGISTER_KERNELS(type)                            \
  REGISTER_KERNEL_BUILDER(Name("SparseConcat")            \
                              .Device(DEVICE_GPU)         \
                              .HostMemory("shapes")       \
                              .HostMemory("output_shape") \
                              .TypeConstraint<type>("T"), \
                          SparseConcatOp<GPUDevice, type>)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow