1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/array_ops.cc.
17
#include <limits>
#include <memory>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
30
31namespace tensorflow {
32
// Device type aliases used to select the PackOp template specialization in
// the kernel registrations below.
typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
typedef Eigen::GpuDevice GPUDevice;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
37
38// --------------------------------------------------------------------------
39template <typename Device, typename T>
40class PackOp : public OpKernel {
41 public:
42 typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
43 ConstMatrixVector;
44
45 explicit PackOp(OpKernelConstruction* context) : OpKernel(context) {
46 OP_REQUIRES_OK(context, context->GetAttr("axis", &axis_));
47 }
48
49 void Compute(OpKernelContext* c) override {
50 const int num = num_inputs();
51 const Tensor& first_input = c->input(0);
52
53 int expanded_num_dims = first_input.dims() + 1;
54 int axis = axis_;
55 if (axis < 0) axis += expanded_num_dims;
56
57 OP_REQUIRES(c, 0 <= axis && axis < expanded_num_dims,
58 errors::InvalidArgument("axis = ", axis_, " not in [",
59 -expanded_num_dims, ", ",
60 expanded_num_dims, ")"));
61
62 TensorShape output_shape(first_input.shape());
63 output_shape.InsertDim(axis, num);
64
65 // In the num = 1 case, just reshape the input
66 if (num == 1) {
67 Tensor output;
68 CHECK(output.CopyFrom(first_input, output_shape));
69 c->set_output(0, output);
70 return;
71 }
72
73 // Allocate output
74 Tensor* output;
75 OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
76
77 int64_t before_dim = 1;
78 for (int i = 0; i < axis; ++i) {
79 before_dim *= output_shape.dim_size(i);
80 }
81
82 int64_t after_dim = 1;
83 for (int i = axis + 1; i < output_shape.dims(); ++i) {
84 after_dim *= output_shape.dim_size(i);
85 }
86
87 const int64_t axis_dim = output_shape.dim_size(axis);
88
89 const int64_t output_size = output->NumElements();
90 auto output_flat = output->shaped<T, 2>({before_dim, after_dim * axis_dim});
91
92 // Except for shapes, pack is a special case of concat, so we reuse the
93 // same computational kernels.
94 ConstMatrixVector inputs_flat;
95 inputs_flat.reserve(num);
96 for (int i = 0; i < num; ++i) {
97 const Tensor& input = c->input(i);
98 OP_REQUIRES(c, first_input.shape().IsSameSize(input.shape()),
99 errors::InvalidArgument(
100 "Shapes of all inputs must match: values[0].shape = ",
101 first_input.shape().DebugString(), " != values[", i,
102 "].shape = ", input.shape().DebugString()));
103
104 inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
105 input.shaped<T, 2>({before_dim, after_dim})));
106 }
107 if (output_size > 0) {
108#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
109 if (std::is_same<Device, GPUDevice>::value) {
110 ConcatGPU<T>(c, inputs_flat, output, &output_flat);
111 return;
112 }
113#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
114 ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
115 }
116 }
117
118 private:
119 int axis_;
120};
121
// Registers PackOp on the CPU device for a single element type.
#define REGISTER_PACK(type)                                      \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Pack").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      PackOp<CPUDevice, type>)

// All standard and quantized element types get a CPU kernel.
TF_CALL_ALL_TYPES(REGISTER_PACK);
TF_CALL_QUANTIZED_TYPES(REGISTER_PACK);

#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION)
// Primarily used for SavedModel support on mobile.
REGISTER_PACK(tstring);
#endif  // defined(IS_MOBILE_PLATFORM) &&
        // !defined(SUPPORT_SELECTIVE_REGISTRATION)

#undef REGISTER_PACK
137
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Registers PackOp on the GPU device for a single element type.
#define REGISTER_GPU(type)                                       \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Pack").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PackOp<GPUDevice, type>)

TF_CALL_bfloat16(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
TF_CALL_int16(REGISTER_GPU);
TF_CALL_uint32(REGISTER_GPU);
TF_CALL_uint64(REGISTER_GPU);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU);
#undef REGISTER_GPU

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
// Note it deliberately instantiates the CPUDevice variant: the data lives in
// host memory, so the CPU concat path does the work.
REGISTER_KERNEL_BUILDER(Name("Pack")
                            .Device(DEVICE_GPU)
                            .HostMemory("values")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PackOp<CPUDevice, int32>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
164
165} // namespace tensorflow
166