/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#include <limits>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
typedef Eigen::GpuDevice GPUDevice;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// --------------------------------------------------------------------------
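// PackOp packs a list of `num` rank-R tensors that all share the same shape
// into a single rank-(R+1) tensor, inserting a new dimension of size `num`
// at position `axis` (this is the kernel behind tf.stack).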
template <typename Device, typename T>
class PackOp : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit PackOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* c) override {
    const int num = num_inputs();
    const Tensor& first_input = c->input(0);

    int expanded_num_dims = first_input.dims() + 1;
    int axis = axis_;
    if (axis < 0) axis += expanded_num_dims;

    OP_REQUIRES(c, 0 <= axis && axis < expanded_num_dims,
                errors::InvalidArgument("axis = ", axis_, " not in [",
                                        -expanded_num_dims, ", ",
                                        expanded_num_dims, ")"));

    TensorShape output_shape(first_input.shape());
    output_shape.InsertDim(axis, num);
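    // For example, packing num = 4 tensors of shape [2, 3] at axis = 1
    // yields an output of shape [2, 4, 3].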

    // In the num = 1 case, just reshape the input.
    if (num == 1) {
      Tensor output;
      CHECK(output.CopyFrom(first_input, output_shape));
      c->set_output(0, output);
      return;
    }

    // Allocate output.
    Tensor* output;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));

    int64_t before_dim = 1;
    for (int i = 0; i < axis; ++i) {
      before_dim *= output_shape.dim_size(i);
    }

    int64_t after_dim = 1;
    for (int i = axis + 1; i < output_shape.dims(); ++i) {
      after_dim *= output_shape.dim_size(i);
    }

    const int64_t axis_dim = output_shape.dim_size(axis);

    const int64_t output_size = output->NumElements();
    auto output_flat =
        output->shaped<T, 2>({before_dim, after_dim * axis_dim});

    // Except for shapes, pack is a special case of concat, so we reuse the
    // same computational kernels.
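    // Concretely, with num = 4 inputs of shape [2, 3] and axis = 1:
    // before_dim = 2, axis_dim = 4, after_dim = 3. Each input is viewed as a
    // [2, 3] matrix and the output as a [2, 12] matrix, so concatenating the
    // input matrices along their second dimension writes each input into its
    // slot of the packed output.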
    ConstMatrixVector inputs_flat;
    inputs_flat.reserve(num);
    for (int i = 0; i < num; ++i) {
      const Tensor& input = c->input(i);
      OP_REQUIRES(c, first_input.shape().IsSameSize(input.shape()),
                  errors::InvalidArgument(
                      "Shapes of all inputs must match: values[0].shape = ",
                      first_input.shape().DebugString(), " != values[", i,
                      "].shape = ", input.shape().DebugString()));

      inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
          input.shaped<T, 2>({before_dim, after_dim})));
    }
    if (output_size > 0) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
      if (std::is_same<Device, GPUDevice>::value) {
        ConcatGPU<T>(c, inputs_flat, output, &output_flat);
        return;
      }
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

 private:
  int axis_;
};

#define REGISTER_PACK(type)                                      \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Pack").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      PackOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_PACK);
TF_CALL_QUANTIZED_TYPES(REGISTER_PACK);

#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION)
// Primarily used for SavedModel support on mobile.
REGISTER_PACK(tstring);
#endif  // defined(IS_MOBILE_PLATFORM) &&
        // !defined(SUPPORT_SELECTIVE_REGISTRATION)

#undef REGISTER_PACK

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(type)                                       \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Pack").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PackOp<GPUDevice, type>)

TF_CALL_bfloat16(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
TF_CALL_int16(REGISTER_GPU);
TF_CALL_uint32(REGISTER_GPU);
TF_CALL_uint64(REGISTER_GPU);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU);
#undef REGISTER_GPU

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Pack")
                            .Device(DEVICE_GPU)
                            .HostMemory("values")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PackOp<CPUDevice, int32>);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow
166 | |