/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#include <limits>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
typedef Eigen::GpuDevice GPUDevice;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

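// The "Concat" op receives the concatenation axis as its first input, named
// "concat_dim"; "ConcatV2" receives it as its last input, named "axis". The
// enum below selects which input name the kernel looks up via InputRange.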
enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };

// --------------------------------------------------------------------------
template <typename Device, typename T, AxisArgumentName AxisArgName>
class ConcatBaseOp : public OpKernel {
 public:
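  // Each non-empty input is flattened into a rank-2 view of this type before
  // being handed to the device-specific concat implementation.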
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit ConcatBaseOp(OpKernelConstruction* c)
      : OpKernel(c),
        axis_attribute_name_(AxisArgName == NAME_IS_AXIS ? "axis"
                             : AxisArgName == NAME_IS_CONCAT_DIM
                                 ? "concat_dim"
                                 : "<invalid>") {
    int unused;
    OP_REQUIRES_OK(
        c, InputRange(axis_attribute_name_, &axis_input_index_, &unused));
    OP_REQUIRES_OK(c, InputRange("values", &values_input_start_index_,
                                 &values_input_end_index_));
  }

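  // Validates the axis and the input shapes, flattens every non-empty input
  // into a rank-2 view, and dispatches to ConcatCPU or ConcatGPU.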
  void Compute(OpKernelContext* c) override {
    const Tensor& concat_dim_tensor = c->input(axis_input_index_);

    // TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars.
    OP_REQUIRES(c,
                (TensorShapeUtils::IsScalar(concat_dim_tensor.shape()) ||
                 (TensorShapeUtils::IsVector(concat_dim_tensor.shape()) &&
                  concat_dim_tensor.shape().dim_size(0) == 1)),
                errors::InvalidArgument(
                    axis_attribute_name_,
                    " tensor should be a scalar integer, but got shape ",
                    concat_dim_tensor.shape().DebugString()));
    int64_t concat_dim;
    // In case of ConcatV2, "axis" could be int32 or int64.
    if (AxisArgName == NAME_IS_AXIS) {
      OP_REQUIRES(
          c,
          (concat_dim_tensor.dtype() == DT_INT32 ||
           concat_dim_tensor.dtype() == DT_INT64),
          errors::InvalidArgument(axis_attribute_name_,
                                  " tensor should be int32 or int64, but got ",
                                  DataTypeString(concat_dim_tensor.dtype())));
    } else {
      OP_REQUIRES(c, (concat_dim_tensor.dtype() == DT_INT32),
                  errors::InvalidArgument(
                      axis_attribute_name_, " tensor should be int32, but got ",
                      DataTypeString(concat_dim_tensor.dtype())));
    }
    if (concat_dim_tensor.dtype() == DT_INT32) {
      concat_dim =
          internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());
    } else {
      concat_dim =
          internal::SubtleMustCopy(concat_dim_tensor.scalar<int64_t>()());
    }

    const int N = values_input_end_index_ - values_input_start_index_;
    const Tensor& first_input = c->input(values_input_start_index_);
    const int input_dims = first_input.dims();
    const TensorShape& input_shape = first_input.shape();

    int32_t axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
    // concat_dim == 0 allows concatenating a list of scalars into a vector.
    OP_REQUIRES(c, (0 <= axis && axis < input_dims) || concat_dim == 0,
                errors::InvalidArgument(
                    "ConcatOp : Expected concatenating dimensions in the range "
                    "[",
                    -input_dims, ", ", input_dims, "), but got ", concat_dim));
    // Note that we reduce the concat of n-dimensional tensors into a two
    // dimensional concat. Assuming the dimensions of any input/output
    // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
    // the dimension indicated with size y0, we flatten it to {x, y}, where y =
    // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
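    // For example, concatenating inputs of shape {2, 3, 5} and {2, 4, 5}
    // along axis 1 flattens them to {2, 15} and {2, 20}; the output of shape
    // {2, 7, 5} is viewed as {2, 35}.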
    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(N);
    int64_t inputs_flat_dim0 = 1;
    for (int d = 0; d < axis; ++d) {
      inputs_flat_dim0 *= input_shape.dim_size(d);
    }
    int64_t output_concat_dim = 0;
    for (int i = 0; i < N; ++i) {
      const auto& in = c->input(values_input_start_index_ + i);
      OP_REQUIRES(
          c, in.dims() > 0,
          errors::InvalidArgument("ConcatOp : Can't concatenate scalars "
                                  "(use tf.stack instead)"));
      OP_REQUIRES(
          c, in.dims() == input_dims,
          errors::InvalidArgument(
              "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
              input_shape.DebugString(), " vs. shape[", i,
              "] = ", in.shape().DebugString()));
      for (int j = 0; j < input_dims; ++j) {
        if (j == axis) {
          continue;
        }
        OP_REQUIRES(
            c, in.dim_size(j) == input_shape.dim_size(j),
            errors::InvalidArgument("ConcatOp : Dimension ", j,
                                    " in both shapes must be equal: "
                                    "shape[0] = ",
                                    input_shape.DebugString(), " vs. shape[", i,
                                    "] = ", in.shape().DebugString()));
      }
      if (in.NumElements() > 0) {
        int64_t inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            in.template shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
      }
      // TODO(rmlarsen): Remove check once !allow_legacy_scalars()?
      output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
    }

    TensorShape output_shape(input_shape);
    // TODO(rmlarsen): Remove rank 0 case once !allow_legacy_scalars()?
    if (output_shape.dims() == 0) {
      output_shape.AddDim(output_concat_dim);
    } else {
      output_shape.set_dim(axis, output_concat_dim);
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() > 0) {
      int64_t output_dim1 = output->NumElements() / inputs_flat_dim0;
      auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
      if (std::is_same<Device, GPUDevice>::value) {
        ConcatGPU<T>(c, inputs_flat, output, &output_flat);
        return;
      }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

 private:
  const char* const axis_attribute_name_;
  int axis_input_index_;
  int values_input_start_index_;
  int values_input_end_index_;
};

template <typename Device, typename T>
using ConcatOp = ConcatBaseOp<Device, T, NAME_IS_CONCAT_DIM>;
template <typename Device, typename T>
using ConcatV2Op = ConcatBaseOp<Device, T, NAME_IS_AXIS>;

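// Registers CPU kernels for both Concat and ConcatV2 for a given element
// type. The axis argument is always kept in host memory.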
#define REGISTER_CONCAT(type)                            \
  REGISTER_KERNEL_BUILDER(Name("Concat")                 \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("concat_dim"), \
                          ConcatOp<CPUDevice, type>)     \
  REGISTER_KERNEL_BUILDER(Name("ConcatV2")               \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("axis"),       \
                          ConcatV2Op<CPUDevice, type>)

TF_CALL_POD_STRING_TYPES(REGISTER_CONCAT);
REGISTER_CONCAT(quint8);
REGISTER_CONCAT(qint8);
REGISTER_CONCAT(quint16);
REGISTER_CONCAT(qint16);
REGISTER_CONCAT(qint32);

#undef REGISTER_CONCAT

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(type)                               \
  REGISTER_KERNEL_BUILDER(Name("Concat")                 \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("concat_dim"), \
                          ConcatOp<GPUDevice, type>)     \
  REGISTER_KERNEL_BUILDER(Name("ConcatV2")               \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("axis"),       \
                          ConcatV2Op<GPUDevice, type>)

TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GPU);
TF_CALL_bfloat16(REGISTER_GPU);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU);
#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// A special DEVICE_DEFAULT kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Concat")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .HostMemory("concat_dim")
                            .HostMemory("values")
                            .HostMemory("output"),
                        ConcatOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("ConcatV2")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .HostMemory("values")
                            .HostMemory("axis")
                            .HostMemory("output"),
                        ConcatV2Op<CPUDevice, int32>);

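// Given the concat dimension and the shape vector of each input to a Concat,
// ConcatOffsetOp emits, for each input, the offset at which that input starts
// in the concatenated output. Every entry of an offset vector is zero except
// the one at the concat dimension, which is the cumulative sum of the
// preceding inputs' sizes along that dimension (see the example in Compute).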
class ConcatOffsetOp : public OpKernel {
 public:
  explicit ConcatOffsetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& concat_dim = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(concat_dim.shape()),
        errors::InvalidArgument(
            "Concat dim tensor should be a scalar integer, but got shape ",
            concat_dim.shape().DebugString()));
    for (int i = 1; i < ctx->num_inputs(); ++i) {
      const Tensor& inp = ctx->input(i);
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(inp.shape()),
                  errors::InvalidArgument("input ", i,
                                          " should be a vector, but got shape ",
                                          inp.shape().DebugString()));
    }
    // Suppose a Concat() op needs to Concatenate N tensors, each of
    // which has the same number of dimensions. Their shapes match
    // except the concat dimension.
    //
    // E.g., say, we want to concatenate 3 tensors in the 2nd
    // dimension, and their shapes are:
    //
    //  [2, 2, 5, 7]
    //  [2, 3, 5, 7]
    //  [2, 4, 5, 7]
    //
    // Here, N=3, cdim=1, dims=4. The concatenated tensor has shape
    // [2,9,5,7]. We will compute the cumulative sum along the 2nd
    // dimension to figure out each input's offset in the concatenated
    // output:
    //  [0, 0, 0, 0]
    //  [0, 2, 0, 0]
    //  [0, 5, 0, 0]
    const int32_t N = ctx->num_inputs() - 1;
    const Tensor& inp0 = ctx->input(1);
    auto inp0_vec = inp0.vec<int32>();
    const int64_t cdim = internal::SubtleMustCopy(concat_dim.scalar<int32>()());
    const int64_t dims = inp0.NumElements();
    int32_t axis = cdim < 0 ? cdim + dims : cdim;
    OP_REQUIRES(ctx, FastBoundsCheck(axis, dims),
                errors::InvalidArgument("Concat dim is out of range: ", cdim,
                                        " vs. ", dims));
    int32_t offset = 0;
    for (int i = 0; i < N; ++i) {
      const Tensor& inp = ctx->input(1 + i);
      OP_REQUIRES(
          ctx, dims == inp.NumElements(),
          errors::InvalidArgument("input ", i, " should contain ", dims,
                                  " elements, but got ", inp.NumElements()));
      auto inp_vec = inp.vec<int32>();
      Tensor* out = nullptr;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out));
      auto out_vec = out->vec<int32>();
      for (int64_t j = 0; j < dims; ++j) {
        if (j == axis) {
          out_vec(j) = offset;
          offset += inp_vec(j);
        } else {
          OP_REQUIRES(ctx, (inp0_vec(j) == inp_vec(j)),
                      errors::InvalidArgument(
                          "All dimensions except ", axis, " must match. Input ",
                          i, " has shape [", inp.SummarizeValue(10),
                          "] and doesn't match input 0 with shape [",
                          inp0.SummarizeValue(10), "]."));
          out_vec(j) = 0;
        }
      }
    }
  }

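  // ConcatOffset only does small shape arithmetic, so the executor can run it
  // inline instead of scheduling it as an expensive closure.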
  bool IsExpensive() override { return false; }
};

REGISTER_KERNEL_BUILDER(Name("ConcatOffset").Device(DEVICE_CPU),
                        ConcatOffsetOp);
REGISTER_KERNEL_BUILDER(Name("ConcatOffset")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("concat_dim")
                            .HostMemory("shape")
                            .HostMemory("offset"),
                        ConcatOffsetOp);

}  // namespace tensorflow