/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#include <limits>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
typedef Eigen::GpuDevice GPUDevice;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

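// "Concat" takes the dimension to concatenate along as an input named
// "concat_dim", while "ConcatV2" names it "axis". This enum selects which
// input name the shared kernel implementation below looks up.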
enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };

// --------------------------------------------------------------------------
template <typename Device, typename T, AxisArgumentName AxisArgName>
class ConcatBaseOp : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit ConcatBaseOp(OpKernelConstruction* c)
      : OpKernel(c),
        axis_attribute_name_(AxisArgName == NAME_IS_AXIS ? "axis"
                             : AxisArgName == NAME_IS_CONCAT_DIM
                                 ? "concat_dim"
                                 : "<invalid>") {
    int unused;
    OP_REQUIRES_OK(
        c, InputRange(axis_attribute_name_, &axis_input_index_, &unused));
    OP_REQUIRES_OK(c, InputRange("values", &values_input_start_index_,
                                 &values_input_end_index_));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& concat_dim_tensor = c->input(axis_input_index_);

    // TODO(rmlarsen): Disallow legacy use of length-1 vectors as scalars.
    OP_REQUIRES(c,
                (TensorShapeUtils::IsScalar(concat_dim_tensor.shape()) ||
                 (TensorShapeUtils::IsVector(concat_dim_tensor.shape()) &&
                  concat_dim_tensor.shape().dim_size(0) == 1)),
                errors::InvalidArgument(
                    axis_attribute_name_,
                    " tensor should be a scalar integer, but got shape ",
                    concat_dim_tensor.shape().DebugString()));
    int64_t concat_dim;
    // In the case of ConcatV2, "axis" may be int32 or int64.
    if (AxisArgName == NAME_IS_AXIS) {
      OP_REQUIRES(
          c,
          (concat_dim_tensor.dtype() == DT_INT32 ||
           concat_dim_tensor.dtype() == DT_INT64),
          errors::InvalidArgument(axis_attribute_name_,
                                  " tensor should be int32 or int64, but got ",
                                  DataTypeString(concat_dim_tensor.dtype())));
    } else {
      OP_REQUIRES(c, (concat_dim_tensor.dtype() == DT_INT32),
                  errors::InvalidArgument(
                      axis_attribute_name_, " tensor should be int32, but got ",
                      DataTypeString(concat_dim_tensor.dtype())));
    }
    if (concat_dim_tensor.dtype() == DT_INT32) {
      concat_dim =
          internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());
    } else {
      concat_dim =
          internal::SubtleMustCopy(concat_dim_tensor.scalar<int64_t>()());
    }

    const int N = values_input_end_index_ - values_input_start_index_;
    const Tensor& first_input = c->input(values_input_start_index_);
    const int input_dims = first_input.dims();
    const TensorShape& input_shape = first_input.shape();

    int32_t axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
    // concat_dim==0 allows concatenating a list of scalars into a vector.
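    // A negative concat_dim counts from the end, e.g. concat_dim = -1 with
    // input_dims = 4 normalizes to axis = 3.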
    OP_REQUIRES(c, (0 <= axis && axis < input_dims) || concat_dim == 0,
                errors::InvalidArgument(
                    "ConcatOp : Expected concatenating dimensions in the range "
                    "[",
                    -input_dims, ", ", input_dims, "), but got ", concat_dim));
    // Note that we reduce the concat of n-dimensional tensors into a two
    // dimensional concat. Assuming the dimensions of any input/output
    // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
    // the dimension indicated with size y0, we flatten it to {x, y}, where y =
    // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
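    // For example, concatenating two [2, 3, 4] inputs along axis 1 flattens
    // each input to a {2, 12} matrix (x = 2, y = 3 * 4) and the output to a
    // {2, 24} matrix; each output row is then the inputs' corresponding rows
    // copied back to back.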
    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(N);
    int64_t inputs_flat_dim0 = 1;
    for (int d = 0; d < axis; ++d) {
      inputs_flat_dim0 *= input_shape.dim_size(d);
    }
    int64_t output_concat_dim = 0;
    for (int i = 0; i < N; ++i) {
      const auto& in = c->input(values_input_start_index_ + i);
      OP_REQUIRES(
          c, in.dims() > 0,
          errors::InvalidArgument("ConcatOp : Can't concatenate scalars "
                                  "(use tf.stack instead)"));
      OP_REQUIRES(
          c, in.dims() == input_dims,
          errors::InvalidArgument(
              "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
              input_shape.DebugString(), " vs. shape[", i,
              "] = ", in.shape().DebugString()));
      for (int j = 0; j < input_dims; ++j) {
        if (j == axis) {
          continue;
        }
        OP_REQUIRES(
            c, in.dim_size(j) == input_shape.dim_size(j),
            errors::InvalidArgument("ConcatOp : Dimension ", j,
                                    " in both shapes must be equal: "
                                    "shape[0] = ",
                                    input_shape.DebugString(), " vs. shape[", i,
                                    "] = ", in.shape().DebugString()));
      }
      if (in.NumElements() > 0) {
        int64_t inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            in.template shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
      }
      // TODO(rmlarsen): Remove check once !allow_legacy_scalars()?
      output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
    }

    TensorShape output_shape(input_shape);
    // TODO(rmlarsen): Remove rank 0 case once !allow_legacy_scalars()?
    if (output_shape.dims() == 0) {
      output_shape.AddDim(output_concat_dim);
    } else {
      output_shape.set_dim(axis, output_concat_dim);
    }
    Tensor* output = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
    if (output->NumElements() > 0) {
      int64_t output_dim1 = output->NumElements() / inputs_flat_dim0;
      auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
      if (std::is_same<Device, GPUDevice>::value) {
        ConcatGPU<T>(c, inputs_flat, output, &output_flat);
        return;
      }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

 private:
  const char* const axis_attribute_name_;
  int axis_input_index_;
  int values_input_start_index_;
  int values_input_end_index_;
};

template <typename Device, typename T>
using ConcatOp = ConcatBaseOp<Device, T, NAME_IS_CONCAT_DIM>;
template <typename Device, typename T>
using ConcatV2Op = ConcatBaseOp<Device, T, NAME_IS_AXIS>;

#define REGISTER_CONCAT(type)                            \
  REGISTER_KERNEL_BUILDER(Name("Concat")                 \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("concat_dim"), \
                          ConcatOp<CPUDevice, type>)     \
  REGISTER_KERNEL_BUILDER(Name("ConcatV2")               \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("axis"),       \
                          ConcatV2Op<CPUDevice, type>)

TF_CALL_POD_STRING_TYPES(REGISTER_CONCAT);
REGISTER_CONCAT(quint8);
REGISTER_CONCAT(qint8);
REGISTER_CONCAT(quint16);
REGISTER_CONCAT(qint16);
REGISTER_CONCAT(qint32);

#undef REGISTER_CONCAT

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(type)                               \
  REGISTER_KERNEL_BUILDER(Name("Concat")                 \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("concat_dim"), \
                          ConcatOp<GPUDevice, type>)     \
  REGISTER_KERNEL_BUILDER(Name("ConcatV2")               \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("axis"),       \
                          ConcatV2Op<GPUDevice, type>)

TF_CALL_INTEGRAL_TYPES_NO_INT32(REGISTER_GPU);
TF_CALL_bfloat16(REGISTER_GPU);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU);
#undef REGISTER_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// A special DEVICE_DEFAULT kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Concat")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .HostMemory("concat_dim")
                            .HostMemory("values")
                            .HostMemory("output"),
                        ConcatOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("ConcatV2")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .HostMemory("values")
                            .HostMemory("axis")
                            .HostMemory("output"),
                        ConcatV2Op<CPUDevice, int32>);

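// Computes, for each of the N input shape vectors, the offset of that input
// within the concatenated output. Along the concat dimension the offsets are
// the running sum of the inputs' sizes; every other entry is zero.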
class ConcatOffsetOp : public OpKernel {
 public:
  explicit ConcatOffsetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& concat_dim = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(concat_dim.shape()),
        errors::InvalidArgument(
            "Concat dim tensor should be a scalar integer, but got shape ",
            concat_dim.shape().DebugString()));
    for (int i = 1; i < ctx->num_inputs(); ++i) {
      const Tensor& inp = ctx->input(i);
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(inp.shape()),
                  errors::InvalidArgument("input ", i,
                                          " should be a vector, but got shape ",
                                          inp.shape().DebugString()));
    }
    // Suppose a Concat() op needs to concatenate N tensors, each of
    // which has the same number of dimensions. Their shapes match
    // except in the concat dimension.
    //
    // E.g., say we want to concatenate 3 tensors along the 2nd
    // dimension, and their shapes are:
    //
    //   [2, 2, 5, 7]
    //   [2, 3, 5, 7]
    //   [2, 4, 5, 7]
    //
    // Here, N=3, cdim=1, dims=4. The concatenated tensor has shape
    // [2, 9, 5, 7]. We compute the cumulative sum along the 2nd
    // dimension to figure out each input's offset in the concatenated
    // output:
    //   [0, 0, 0, 0]
    //   [0, 2, 0, 0]
    //   [0, 5, 0, 0]
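    // The resulting offsets are typically consumed as the "begin" arguments
    // of Slice ops, e.g. when splitting a gradient back into per-input
    // pieces while differentiating Concat.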
    const int32_t N = ctx->num_inputs() - 1;
    const Tensor& inp0 = ctx->input(1);
    auto inp0_vec = inp0.vec<int32>();
    const int64_t cdim = internal::SubtleMustCopy(concat_dim.scalar<int32>()());
    const int64_t dims = inp0.NumElements();
    int32_t axis = cdim < 0 ? cdim + dims : cdim;
    OP_REQUIRES(ctx, FastBoundsCheck(axis, dims),
                errors::InvalidArgument("Concat dim is out of range: ", cdim,
                                        " vs. ", dims));
    int32_t offset = 0;
    for (int i = 0; i < N; ++i) {
      const Tensor& inp = ctx->input(1 + i);
      OP_REQUIRES(
          ctx, dims == inp.NumElements(),
          errors::InvalidArgument("input ", i, " should contain ", dims,
                                  " elements, but got ", inp.NumElements()));
      auto inp_vec = inp.vec<int32>();
      Tensor* out = nullptr;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out));
      auto out_vec = out->vec<int32>();
      for (int64_t j = 0; j < dims; ++j) {
        if (j == axis) {
          out_vec(j) = offset;
          offset += inp_vec(j);
        } else {
          OP_REQUIRES(ctx, (inp0_vec(j) == inp_vec(j)),
                      errors::InvalidArgument(
                          "All dimensions except ", axis, " must match. Input ",
                          i, " has shape [", inp.SummarizeValue(10),
                          "] and doesn't match input 0 with shape [",
                          inp0.SummarizeValue(10), "]."));
          out_vec(j) = 0;
        }
      }
    }
  }

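  // ConcatOffset performs only O(N * dims) integer arithmetic on
  // host-memory shape vectors, so report it as inexpensive and let the
  // executor run it inline rather than scheduling it on a separate thread.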
  bool IsExpensive() override { return false; }
};

REGISTER_KERNEL_BUILDER(Name("ConcatOffset").Device(DEVICE_CPU),
                        ConcatOffsetOp);
REGISTER_KERNEL_BUILDER(Name("ConcatOffset")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("concat_dim")
                            .HostMemory("shape")
                            .HostMemory("offset"),
                        ConcatOffsetOp);

}  // namespace tensorflow