/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sparse_concat_op.h"

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/util/overflow.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

template <typename T>
struct SparseConcatFunctor<CPUDevice, T> {
  void operator()(OpKernelContext* context, const OpInputList& inds,
                  const OpInputList& vals, const OpInputList& shapes,
                  int concat_dim) {
    const int N = inds.size();
    const TensorShape input_shape(shapes[0].vec<int64_t>());
    const int input_rank = input_shape.dims();

    // The input and output sparse tensors are assumed to be ordered along
    // increasing dimension number. But in order for concat to work properly,
    // order[0] must be concat_dim. So we will reorder the inputs to the
    // concat ordering, concatenate, then reorder back to the standard order.
    // We make a deep copy of the input tensors to ensure that the in-place
    // reorder doesn't create race conditions for other ops that may be
    // concurrently reading the indices and values tensors.
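    // As a purely illustrative example (hypothetical values): concatenating
    // two 2x2 sparse tensors along concat_dim = 1 first reorders each input
    // so that the column index is the primary sort key, then appends the
    // second input's entries with their column indices offset by 2 (the
    // first input's size along dimension 1), and finally reorders the
    // combined entries back to row-major order, yielding dense shape [2, 4].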

    gtl::InlinedVector<int64_t, 8> std_order(input_rank);
    std::iota(std_order.begin(), std_order.end(), 0);

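    // Build the concat ordering: concat_dim first, followed by the remaining
    // dimensions in increasing order.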
    std::vector<int64_t> concat_order;
    concat_order.reserve(input_rank);
    concat_order.push_back(concat_dim);
    for (int j = 0; j < input_rank; ++j) {
      if (j != concat_dim) {
        concat_order.push_back(j);
      }
    }

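    // Deep-copy each input's indices and values into a SparseTensor and
    // reorder it in place to the concat ordering.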
    std::vector<sparse::SparseTensor> sp_inputs;
    for (int i = 0; i < N; ++i) {
      const TensorShape current_shape(shapes[i].vec<int64_t>());
      sparse::SparseTensor tensor;
      OP_REQUIRES_OK(context,
                     sparse::SparseTensor::Create(
                         tensor::DeepCopy(inds[i]), tensor::DeepCopy(vals[i]),
                         current_shape, std_order, &tensor));
      sp_inputs.push_back(std::move(tensor));
      sp_inputs[i].Reorder<T>(concat_order);
    }

    sparse::SparseTensor concat = sparse::SparseTensor::Concat<T>(sp_inputs);
    concat.Reorder<T>(std_order);

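    // Outputs 0 and 1 are the concatenated indices and values; output 2 (the
    // dense shape) is allocated by the op before this functor is invoked.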
    context->set_output(0, concat.indices());
    context->set_output(1, concat.values());
  }
};

}  // namespace functor

template <typename Device, typename T>
class SparseConcatOp : public OpKernel {
 public:
  explicit SparseConcatOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("concat_dim", &concat_dim_attr_));
  }

  void Compute(OpKernelContext* context) override {
    OpInputList inds;
    OP_REQUIRES_OK(context, context->input_list("indices", &inds));
    const int N = inds.size();
    for (int i = 0; i < N; i++) {
      OP_REQUIRES(context, TensorShapeUtils::IsMatrix(inds[i].shape()),
                  errors::InvalidArgument(
                      "Input indices should be a matrix but received shape ",
                      inds[i].shape().DebugString(), " at position ", i));
    }

    OpInputList vals;
    OP_REQUIRES_OK(context, context->input_list("values", &vals));
    OP_REQUIRES(context, vals.size() == N,
                errors::InvalidArgument("Expected ", N, " input values, got ",
                                        vals.size()));
    for (int i = 0; i < N; i++) {
      OP_REQUIRES(context, TensorShapeUtils::IsVector(vals[i].shape()),
                  errors::InvalidArgument(
                      "Input values should be a vector but received shape ",
                      vals[i].shape().DebugString(), " at position ", i));
    }

    OpInputList shapes;
    OP_REQUIRES_OK(context, context->input_list("shapes", &shapes));
    OP_REQUIRES(context, shapes.size() == N,
                errors::InvalidArgument("Expected ", N, " input shapes, got ",
                                        shapes.size()));
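    // Validate each input shape and make sure its total element count does
    // not overflow int64_t.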
    bool overflow_occurred = false;
    for (int i = 0; i < N; i++) {
      int64_t new_num_elements = 1;
      OP_REQUIRES(context, TensorShapeUtils::IsVector(shapes[i].shape()),
                  errors::InvalidArgument(
                      "Input shapes should be a vector but received shape ",
                      shapes[i].shape().DebugString(), " at position ", i));
      auto input_shape_vector = shapes[i].vec<int64_t>();
      for (int j = 0; j < input_shape_vector.size(); j++) {
        new_num_elements =
            MultiplyWithoutOverflow(new_num_elements, input_shape_vector(j));
        if (new_num_elements < 0) {
          overflow_occurred = true;
          break;
        }
      }

      if (overflow_occurred) {
        break;
      }
    }

    OP_REQUIRES(
        context, !overflow_occurred,
        errors::Internal("Encountered overflow from large input shape."));

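    // Resolve a possibly negative concat_dim (counted from the end) against
    // the rank of the first input and check that it falls within
    // [0, input_rank).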
    const TensorShape input_shape(shapes[0].vec<int64_t>());
    const int input_rank = input_shape.dims();
    const int concat_dim = (concat_dim_attr_ < 0)
                               ? input_rank + concat_dim_attr_
                               : concat_dim_attr_;
    OP_REQUIRES(context, concat_dim >= 0 && concat_dim < input_rank,
                errors::InvalidArgument("Concat dimension must be in range [",
                                        -input_rank, ", ", input_rank,
                                        "), got ", concat_dim_attr_));
    TensorShape output_shape = input_shape;
    for (int i = 1; i < N; ++i) {
      const TensorShape current_shape(shapes[i].vec<int64_t>());
      OP_REQUIRES(
          context, current_shape.dims() == input_rank,
          errors::InvalidArgument(
              "Ranks of all input tensors must match: expected ", input_rank,
              " but got ", current_shape.dims(), " at position ", i));
      for (int j = 0; j < input_rank; ++j) {
        if (j != concat_dim) {
          OP_REQUIRES(
              context, input_shape.dim_size(j) == current_shape.dim_size(j),
              errors::InvalidArgument(
                  "Input shapes must match: expected ", input_shape.dim_size(j),
                  " for dimension ", j, " but got ", current_shape.dim_size(j),
                  " at position ", i));
        } else {
          output_shape.set_dim(
              j, output_shape.dim_size(j) + current_shape.dim_size(j));
        }
      }
    }

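    // Output 2 holds the dense shape of the concatenated result as a 1-D
    // int64 vector.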
    Tensor* output_shape_out = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(2, TensorShape({output_shape.dims()}),
                                          &output_shape_out));
    auto output_shape_t = output_shape_out->vec<int64_t>();
    for (int j = 0; j < output_shape.dims(); ++j) {
      output_shape_t(j) = output_shape.dim_size(j);
    }

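    // If all inputs are empty, emit empty indices and values tensors and skip
    // the functor.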
    int64_t output_nnz = 0;
    for (int i = 0; i < N; ++i) {
      output_nnz += inds[i].dim_size(0);
    }
    if (output_nnz == 0) {
      Tensor* output_inds = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, TensorShape({0, input_rank}),
                                              &output_inds));
      Tensor* output_vals = nullptr;
      OP_REQUIRES_OK(
          context, context->allocate_output(1, TensorShape({0}), &output_vals));
      return;  // No work to do
    }

    functor::SparseConcatFunctor<Device, T>()(context, inds, vals, shapes,
                                              concat_dim);
  }

 private:
  int concat_dim_attr_;
};

#define REGISTER_KERNELS(type)                                            \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("SparseConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"),  \
      SparseConcatOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

typedef Eigen::GpuDevice GPUDevice;

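// On GPU, the dense shape inputs and the output shape are kept in host memory
// because Compute() reads and writes them on the host.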
#define REGISTER_KERNELS(type)                              \
  REGISTER_KERNEL_BUILDER(Name("SparseConcat")              \
                              .Device(DEVICE_GPU)           \
                              .HostMemory("shapes")         \
                              .HostMemory("output_shape")   \
                              .TypeConstraint<type>("T"),   \
                          SparseConcatOp<GPUDevice, type>)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow