/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/pad_op.h"

#include <memory>
#include <string>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T, typename Tpadding>
class PadOp : public OpKernel {
 public:
  explicit PadOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& in0 = context->input(0);
    const Tensor& in1 = context->input(1);
    const int dims = in0.dims();
    static const int kMinDims = 0;
    static const int kMaxDims = 8;
    OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims,
                errors::Unimplemented("inputs rank not in [", kMinDims, ",",
                                      kMaxDims, "]: ", dims));
    OP_REQUIRES(
        context,
        TensorShapeUtils::IsMatrix(in1.shape()) && in1.dim_size(1) == 2,
        errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
                                in1.shape().DebugString()));
    OP_REQUIRES(
        context, dims == in1.dim_size(0),
        errors::InvalidArgument(
            "The first dimension of paddings must be the rank of inputs: ",
            in1.shape().DebugString(), " ", in0.shape().DebugString()));

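    // "Pad" has two inputs, so pad_value stays at T() (zero for numeric
    // types); "PadV2" carries a third scalar input, constant_values, which
    // supplies the fill value.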
    T pad_value = T();
    if (context->num_inputs() == 3) {
      const Tensor& constant_values = context->input(2);
      OP_REQUIRES(
          context, TensorShapeUtils::IsScalar(constant_values.shape()),
          errors::InvalidArgument("constant_values must be a scalar. Found: ",
                                  constant_values.shape().DebugString()));
      pad_value = context->input(2).scalar<T>()();
    }

    // Compute the shape of the output tensor, and allocate it.
    TensorShape output_shape;
    typename TTypes<Tpadding>::ConstMatrix paddings = in1.matrix<Tpadding>();
    for (int d = 0; d < dims; ++d) {
      const Tpadding before_d =
          paddings(d, 0);  // Pad before existing elements.
      const Tpadding after_d = paddings(d, 1);  // Pad after existing elements.
      OP_REQUIRES(context, before_d >= 0 && after_d >= 0,
                  errors::InvalidArgument("Paddings must be non-negative: ",
                                          before_d, " ", after_d));
      const int64_t size_d = in0.dim_size(d);
      OP_REQUIRES_OK(
          context, output_shape.AddDimWithStatus(before_d + size_d + after_d));
    }
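    // For example, input shape [2, 3] with paddings [[1, 1], [0, 2]] yields
    // output shape [1 + 2 + 1, 0 + 3 + 2] = [4, 5].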

    // If there is no padding to be done, forward the input to output.
    if (output_shape.num_elements() == in0.NumElements()) {
      // When num_elements == 0, shape may have changed.
      Tensor out;
      CHECK(out.CopyFrom(in0, output_shape));
      context->set_output(0, out);
      return;
    }
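    // Since every padding amount is non-negative, the element counts can only
    // match if all paddings are zero or some dimension is empty; in either
    // case CopyFrom just reinterprets the input buffer under output_shape
    // without copying data.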

    TensorShape collapsed_input_shape;
    TensorShape collapsed_output_shape;
    Tensor collapsed_paddings;
    if (dims > 1 && CollapseAdjacentNonPaddedDimensions(
                        in0.shape(), in1, output_shape, &collapsed_input_shape,
                        &collapsed_paddings, &collapsed_output_shape)) {
      Tensor collapsed_input;
      CHECK(collapsed_input.CopyFrom(in0, collapsed_input_shape));
      Tensor collapsed_output;
      AllocatorAttributes alloc_attrs;
      alloc_attrs.set_on_host(context->input_memory_type(0) == HOST_MEMORY);
      OP_REQUIRES_OK(context,
                     context->allocate_temp(collapsed_input.dtype(),
                                            collapsed_output_shape,
                                            &collapsed_output, alloc_attrs));
      const Tensor& collapsed_paddings_ref = collapsed_paddings;
      typename TTypes<Tpadding>::ConstMatrix collapsed_paddings_matrix =
          collapsed_paddings_ref.matrix<Tpadding>();

      OperateWithVariableRank(context, collapsed_input_shape.dims(),
                              collapsed_input, collapsed_paddings_matrix,
                              pad_value, &collapsed_output);

      Tensor output;
      CHECK(output.CopyFrom(collapsed_output, output_shape));
      context->set_output(0, output);
    } else {
      Tensor* output = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, output_shape, &output));
      OperateWithVariableRank(context, dims, in0, paddings, pad_value, output);
    }
  }

 private:
  // Collapses adjacent dimensions that are not padded to one dimension for
  // speed. Returns true if any two dimensions are collapsed. For example,
  //
  //   Pad(input_shape=[8, 28, 28, 3],
  //       paddings=[[0, 0], [0, 0], [0, 0], [0, 1]])
  //   is equivalent to
  //   Pad(input_shape=[6272, 3],
  //       paddings=[[0, 0], [0, 1]])
  //
  // input_shape: the original input shape.
  // paddings_as_tensor: the original paddings.
  // output_shape: the original output shape.
  // collapsed_input_shape: the input shape after collapsing.
  // collapsed_paddings_as_tensor: the paddings after collapsing.
  // collapsed_output_shape: the output shape after collapsing.
  static bool CollapseAdjacentNonPaddedDimensions(
      const TensorShape& input_shape, const Tensor& paddings_as_tensor,
      const TensorShape& output_shape, TensorShape* collapsed_input_shape,
      Tensor* collapsed_paddings_as_tensor,
      TensorShape* collapsed_output_shape) {
    bool collapsed = false;
    typename TTypes<Tpadding>::ConstMatrix paddings =
        paddings_as_tensor.matrix<Tpadding>();
    std::vector<std::pair<int, int>> collapsed_paddings;
    int i = 0;
    while (i < paddings.dimension(0)) {
      if (paddings(i, 0) != 0 || paddings(i, 1) != 0) {
        // If padded, copy the original dimension over.
        collapsed_input_shape->InsertDim(collapsed_input_shape->dims(),
                                         input_shape.dim_size(i));
        collapsed_output_shape->InsertDim(collapsed_output_shape->dims(),
                                          output_shape.dim_size(i));
        collapsed_paddings.push_back({paddings(i, 0), paddings(i, 1)});
        ++i;
      } else {
        // If not padded, find the next dimension that is padded and collapse
        // all dimensions in between to one dimension.
        int64_t collapsed_input_dim_size = input_shape.dim_size(i);
        int64_t collapsed_output_dim_size = output_shape.dim_size(i);
        ++i;
        while (i < paddings.dimension(0) && paddings(i, 0) == 0 &&
               paddings(i, 1) == 0) {
          collapsed = true;
          collapsed_input_dim_size *= input_shape.dim_size(i);
          collapsed_output_dim_size *= output_shape.dim_size(i);
          ++i;
        }
        collapsed_input_shape->InsertDim(collapsed_input_shape->dims(),
                                         collapsed_input_dim_size);
        collapsed_output_shape->InsertDim(collapsed_output_shape->dims(),
                                          collapsed_output_dim_size);
        collapsed_paddings.push_back({0, 0});
      }
    }

    // Copy collapsed_paddings to collapsed_paddings_as_tensor.
    *collapsed_paddings_as_tensor = Tensor(
        paddings_as_tensor.dtype(),
        TensorShape({static_cast<int64_t>(collapsed_paddings.size()), 2}));
    auto collapsed_paddings_as_matrix =
        collapsed_paddings_as_tensor->matrix<Tpadding>();
    for (size_t i = 0; i < collapsed_paddings.size(); ++i) {
      collapsed_paddings_as_matrix(i, 0) = collapsed_paddings[i].first;
      collapsed_paddings_as_matrix(i, 1) = collapsed_paddings[i].second;
    }
    return collapsed;
  }
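  // Collapsing is valid because padding is applied independently per
  // dimension: a run of dimensions with zero padding contributes no new
  // elements, so it behaves exactly like one flattened dimension of the
  // combined size.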

  void OperateWithVariableRank(OpKernelContext* context, int fixed_dims,
                               const Tensor& input,
                               typename TTypes<Tpadding>::ConstMatrix paddings,
                               T pad_value, Tensor* output) {
    // Invoke the dims-specific implementation.
    switch (fixed_dims) {
      case 0:
        Operate<0>(context, input.tensor<T, 0>(), paddings, pad_value, output);
        break;
      case 1:
        // TODO(irving): Once Pad doesn't need a scalar special case,
        // change flat to tensor. That is, once !allow_legacy_scalars().
        Operate<1>(context, input.flat<T>(), paddings, pad_value, output);
        break;
      case 2:
        Operate<2>(context, input.tensor<T, 2>(), paddings, pad_value, output);
        break;
      case 3:
        Operate<3>(context, input.tensor<T, 3>(), paddings, pad_value, output);
        break;
      case 4:
        Operate<4>(context, input.tensor<T, 4>(), paddings, pad_value, output);
        break;
      case 5:
        Operate<5>(context, input.tensor<T, 5>(), paddings, pad_value, output);
        break;
      case 6:
        Operate<6>(context, input.tensor<T, 6>(), paddings, pad_value, output);
        break;
      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument("Only ranks up to 6 supported: ",
                                            input.shape().DebugString()));
    }
  }
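  // Note: ranks 7 and 8 pass the kMaxDims check in Compute() but reach this
  // switch only if dimension collapsing did not reduce the rank to 6 or
  // below; such inputs are rejected by the default case above.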

  template <int Dims>
  void Operate(OpKernelContext* context,
               typename TTypes<T, Dims>::ConstTensor input,
               typename TTypes<Tpadding>::ConstMatrix paddings, T pad_value,
               Tensor* output) {
    CHECK_EQ(Dims, paddings.dimension(0));
    CHECK_EQ(2, paddings.dimension(1));
    Eigen::array<Eigen::IndexPair<Tpadding>, Dims> paddings_array;
    for (int i = 0; i < Dims; ++i) {
      paddings_array[i] = {paddings(i, 0), paddings(i, 1)};
    }
    functor::Pad<Device, T, Tpadding, Dims> functor;
    functor(context->eigen_device<Device>(), output->tensor<T, Dims>(), input,
            paddings_array, pad_value);
  }
};
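
// The functor invoked above is defined in pad_op.h; on each device it
// effectively evaluates output = input.pad(paddings_array, pad_value) using
// Eigen's pad expression.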

#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("Pad")                               \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int32>("Tpaddings")   \
                              .HostMemory("paddings"),              \
                          PadOp<CPUDevice, type, int32>);           \
  REGISTER_KERNEL_BUILDER(Name("Pad")                               \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64_t>("Tpaddings") \
                              .HostMemory("paddings"),              \
                          PadOp<CPUDevice, type, int64_t>);         \
  REGISTER_KERNEL_BUILDER(Name("PadV2")                             \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int32>("Tpaddings")   \
                              .HostMemory("paddings")               \
                              .HostMemory("constant_values"),       \
                          PadOp<CPUDevice, type, int32>);           \
  REGISTER_KERNEL_BUILDER(Name("PadV2")                             \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64_t>("Tpaddings") \
                              .HostMemory("paddings")               \
                              .HostMemory("constant_values"),       \
                          PadOp<CPUDevice, type, int64_t>);

TF_CALL_POD_TYPES(REGISTER_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNEL);
TF_CALL_tstring(REGISTER_KERNEL);
#undef REGISTER_KERNEL
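
// For reference, these registrations back tf.pad in Python, e.g.
//   tf.pad(t, paddings=[[1, 1], [2, 2]])                     -> "Pad"
//   tf.pad(t, paddings=[[1, 1], [2, 2]], constant_values=7)  -> "PadV2"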

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T, Dims)                                           \
  template <>                                                               \
  void Pad<GPUDevice, T, int32, Dims>::operator()(                          \
      const GPUDevice& d, typename TTypes<T, Dims>::Tensor output,          \
      typename TTypes<T, Dims>::ConstTensor input,                          \
      Eigen::array<Eigen::IndexPair<int32>, Dims> paddings, T pad_value);   \
  extern template struct Pad<GPUDevice, T, int32, Dims>;                    \
  template <>                                                               \
  void Pad<GPUDevice, T, int64_t, Dims>::operator()(                        \
      const GPUDevice& d, typename TTypes<T, Dims>::Tensor output,          \
      typename TTypes<T, Dims>::ConstTensor input,                          \
      Eigen::array<Eigen::IndexPair<int64_t>, Dims> paddings, T pad_value); \
  extern template struct Pad<GPUDevice, T, int64_t, Dims>;

#define DECLARE_GPU_SPECS(T) \
  DECLARE_GPU_SPEC(T, 0);    \
  DECLARE_GPU_SPEC(T, 1);    \
  DECLARE_GPU_SPEC(T, 2);    \
  DECLARE_GPU_SPEC(T, 3);    \
  DECLARE_GPU_SPEC(T, 4);    \
  DECLARE_GPU_SPEC(T, 5);    \
  DECLARE_GPU_SPEC(T, 6);

TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPECS);
TF_CALL_int8(DECLARE_GPU_SPECS);
TF_CALL_uint8(DECLARE_GPU_SPECS);
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GPU_KERNEL(T)                                   \
  REGISTER_KERNEL_BUILDER(Name("Pad")                            \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<T>("T")            \
                              .TypeConstraint<int32>("Tpaddings") \
                              .HostMemory("paddings"),           \
                          PadOp<GPUDevice, T, int32>);           \
  REGISTER_KERNEL_BUILDER(Name("Pad")                            \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<T>("T")            \
                              .TypeConstraint<int64_t>("Tpaddings") \
                              .HostMemory("paddings"),           \
                          PadOp<GPUDevice, T, int64_t>);         \
  REGISTER_KERNEL_BUILDER(Name("PadV2")                          \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<T>("T")            \
                              .TypeConstraint<int32>("Tpaddings") \
                              .HostMemory("paddings")            \
                              .HostMemory("constant_values"),    \
                          PadOp<GPUDevice, T, int32>);           \
  REGISTER_KERNEL_BUILDER(Name("PadV2")                          \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<T>("T")            \
                              .TypeConstraint<int64_t>("Tpaddings") \
                              .HostMemory("paddings")            \
                              .HostMemory("constant_values"),    \
                          PadOp<GPUDevice, T, int64_t>);

TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_int8(REGISTER_GPU_KERNEL);
TF_CALL_uint8(REGISTER_GPU_KERNEL);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
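// Note that these registrations intentionally instantiate PadOp<CPUDevice,
// ...> even though they are registered for DEVICE_GPU: with every input and
// output pinned to host memory, the padding itself runs on the CPU.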
REGISTER_KERNEL_BUILDER(Name("Pad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Tpaddings")
                            .HostMemory("input")
                            .HostMemory("paddings")
                            .HostMemory("output"),
                        PadOp<CPUDevice, int32, int32>);
REGISTER_KERNEL_BUILDER(Name("Pad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int64_t>("Tpaddings")
                            .HostMemory("input")
                            .HostMemory("paddings")
                            .HostMemory("output"),
                        PadOp<CPUDevice, int32, int64_t>);
REGISTER_KERNEL_BUILDER(Name("PadV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Tpaddings")
                            .HostMemory("input")
                            .HostMemory("paddings")
                            .HostMemory("constant_values")
                            .HostMemory("output"),
                        PadOp<CPUDevice, int32, int32>);
REGISTER_KERNEL_BUILDER(Name("PadV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int64_t>("Tpaddings")
                            .HostMemory("input")
                            .HostMemory("paddings")
                            .HostMemory("constant_values")
                            .HostMemory("output"),
                        PadOp<CPUDevice, int32, int64_t>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow