1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/nn_ops.cc. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #include "tensorflow/core/kernels/pad_op.h" |
21 | |
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
25 | |
26 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
27 | #include "tensorflow/core/framework/op.h" |
28 | #include "tensorflow/core/framework/op_kernel.h" |
29 | #include "tensorflow/core/framework/register_types.h" |
30 | #include "tensorflow/core/framework/tensor.h" |
31 | #include "tensorflow/core/framework/tensor_shape.h" |
32 | #include "tensorflow/core/framework/tensor_types.h" |
33 | #include "tensorflow/core/framework/types.h" |
34 | #include "tensorflow/core/platform/logging.h" |
35 | #include "tensorflow/core/platform/types.h" |
36 | |
37 | namespace tensorflow { |
38 | |
39 | typedef Eigen::ThreadPoolDevice CPUDevice; |
40 | typedef Eigen::GpuDevice GPUDevice; |
41 | |
42 | template <typename Device, typename T, typename Tpadding> |
43 | class PadOp : public OpKernel { |
44 | public: |
45 | explicit PadOp(OpKernelConstruction* context) : OpKernel(context) {} |
46 | |
47 | void Compute(OpKernelContext* context) override { |
48 | const Tensor& in0 = context->input(0); |
49 | const Tensor& in1 = context->input(1); |
50 | const int dims = in0.dims(); |
51 | static const int kMinDims = 0; |
52 | static const int kMaxDims = 8; |
53 | OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims, |
54 | errors::Unimplemented("inputs rank not in [" , kMinDims, "," , |
55 | kMaxDims, "]: " , dims)); |
56 | OP_REQUIRES( |
57 | context, |
58 | TensorShapeUtils::IsMatrix(in1.shape()) && in1.dim_size(1) == 2, |
59 | errors::InvalidArgument("paddings must be a matrix with 2 columns: " , |
60 | in1.shape().DebugString())); |
61 | OP_REQUIRES( |
62 | context, dims == in1.dim_size(0), |
63 | errors::InvalidArgument( |
64 | "The first dimension of paddings must be the rank of inputs" , |
65 | in1.shape().DebugString(), " " , in0.shape().DebugString())); |
66 | |
67 | T pad_value = T(); |
68 | if (context->num_inputs() == 3) { |
69 | const Tensor& constant_values = context->input(2); |
70 | OP_REQUIRES( |
71 | context, TensorShapeUtils::IsScalar(constant_values.shape()), |
72 | errors::InvalidArgument("constant_values must be a scalar. Found: " , |
73 | constant_values.shape().DebugString())); |
74 | pad_value = context->input(2).scalar<T>()(); |
75 | } |
76 | |
77 | // Compute the shape of the output tensor, and allocate it. |
78 | TensorShape output_shape; |
79 | typename TTypes<Tpadding>::ConstMatrix paddings = in1.matrix<Tpadding>(); |
80 | for (int d = 0; d < dims; ++d) { |
81 | const Tpadding before_d = |
82 | paddings(d, 0); // Pad before existing elements. |
83 | const Tpadding after_d = paddings(d, 1); // Pad after existing elements. |
84 | OP_REQUIRES(context, before_d >= 0 && after_d >= 0, |
85 | errors::InvalidArgument("Paddings must be non-negative: " , |
86 | before_d, " " , after_d)); |
87 | const int64_t size_d = in0.dim_size(d); |
88 | OP_REQUIRES_OK( |
89 | context, output_shape.AddDimWithStatus(before_d + size_d + after_d)); |
90 | } |
91 | |
92 | // If there is no padding to be done, forward the input to output. |
93 | if (output_shape.num_elements() == in0.NumElements()) { |
94 | // When num_elements == 0, shape may have changed. |
95 | Tensor out; |
96 | CHECK(out.CopyFrom(in0, output_shape)); |
97 | context->set_output(0, out); |
98 | return; |
99 | } |
100 | |
101 | TensorShape collapsed_input_shape; |
102 | TensorShape collapsed_output_shape; |
103 | Tensor collapsed_paddings; |
104 | if (dims > 1 && CollapseAdjacentNonPaddedDimensions( |
105 | in0.shape(), in1, output_shape, &collapsed_input_shape, |
106 | &collapsed_paddings, &collapsed_output_shape)) { |
107 | Tensor collapsed_input; |
108 | CHECK(collapsed_input.CopyFrom(in0, collapsed_input_shape)); |
109 | Tensor collapsed_output; |
110 | AllocatorAttributes alloc_attrs; |
111 | alloc_attrs.set_on_host(context->input_memory_type(0) == HOST_MEMORY); |
112 | OP_REQUIRES_OK(context, |
113 | context->allocate_temp(collapsed_input.dtype(), |
114 | collapsed_output_shape, |
115 | &collapsed_output, alloc_attrs)); |
116 | const Tensor& collapsed_paddings_ref = collapsed_paddings; |
117 | typename TTypes<Tpadding>::ConstMatrix collapsed_paddings_matrix = |
118 | collapsed_paddings_ref.matrix<Tpadding>(); |
119 | |
120 | OperateWithVariableRank(context, collapsed_input_shape.dims(), |
121 | collapsed_input, collapsed_paddings_matrix, |
122 | pad_value, &collapsed_output); |
123 | |
124 | Tensor output; |
125 | CHECK(output.CopyFrom(collapsed_output, output_shape)); |
126 | context->set_output(0, output); |
127 | } else { |
128 | Tensor* output = nullptr; |
129 | OP_REQUIRES_OK(context, |
130 | context->allocate_output(0, output_shape, &output)); |
131 | OperateWithVariableRank(context, dims, in0, paddings, pad_value, output); |
132 | } |
133 | } |
134 | |
135 | private: |
136 | // Collapses adjacent dimensions that are not padded to one dimension for |
137 | // speed. Returns true if any two dimensions are collapsed. For example, |
138 | // |
139 | // Pad(input_shape=[8, 28, 28, 3], |
140 | // paddings=[[0, 0], [0, 0], [0, 0], [0, 1]] |
141 | // is equivalent to |
142 | // Pad(input_shape=[6272, 3], |
143 | // paddings=[[0, 0], [0, 1]]) |
144 | // |
145 | // input_shape: the original input shape. |
146 | // paddings_as_tensor: the original paddings. |
147 | // output_shape: the original output shape. |
148 | // collapsed_input_shape: the input shape after collapsing. |
149 | // collapsed_paddings_as_tensor: the paddings after collapsing. |
150 | // collapsed_output_shape: the output shape after collapsing. |
151 | static bool CollapseAdjacentNonPaddedDimensions( |
152 | const TensorShape& input_shape, const Tensor& paddings_as_tensor, |
153 | const TensorShape& output_shape, TensorShape* collapsed_input_shape, |
154 | Tensor* collapsed_paddings_as_tensor, |
155 | TensorShape* collapsed_output_shape) { |
156 | bool collapsed = false; |
157 | typename TTypes<Tpadding>::ConstMatrix paddings = |
158 | paddings_as_tensor.matrix<Tpadding>(); |
159 | std::vector<std::pair<int, int>> collapsed_paddings; |
160 | int i = 0; |
161 | while (i < paddings.dimension(0)) { |
162 | if (paddings(i, 0) != 0 || paddings(i, 1) != 0) { |
163 | // If padded, copy the original dimension over. |
164 | collapsed_input_shape->InsertDim(collapsed_input_shape->dims(), |
165 | input_shape.dim_size(i)); |
166 | collapsed_output_shape->InsertDim(collapsed_output_shape->dims(), |
167 | output_shape.dim_size(i)); |
168 | collapsed_paddings.push_back({paddings(i, 0), paddings(i, 1)}); |
169 | ++i; |
170 | } else { |
171 | // If not padded, find the next dimension that is padded and collapse |
172 | // all dimensions in between to one dimension. |
173 | int64_t collapsed_input_dim_size = input_shape.dim_size(i); |
174 | int64_t collapsed_output_dim_size = output_shape.dim_size(i); |
175 | ++i; |
176 | while (i < paddings.dimension(0) && paddings(i, 0) == 0 && |
177 | paddings(i, 1) == 0) { |
178 | collapsed = true; |
179 | collapsed_input_dim_size *= input_shape.dim_size(i); |
180 | collapsed_output_dim_size *= output_shape.dim_size(i); |
181 | ++i; |
182 | } |
183 | collapsed_input_shape->InsertDim(collapsed_input_shape->dims(), |
184 | collapsed_input_dim_size); |
185 | collapsed_output_shape->InsertDim(collapsed_output_shape->dims(), |
186 | collapsed_output_dim_size); |
187 | collapsed_paddings.push_back({0, 0}); |
188 | } |
189 | } |
190 | |
191 | // Copy collapsed_paddings to collapsed_paddings_as_tensor. |
192 | *collapsed_paddings_as_tensor = Tensor( |
193 | paddings_as_tensor.dtype(), |
194 | TensorShape({static_cast<int64_t>(collapsed_paddings.size()), 2})); |
195 | auto collapsed_paddings_as_matrix = |
196 | collapsed_paddings_as_tensor->matrix<Tpadding>(); |
197 | for (size_t i = 0; i < collapsed_paddings.size(); ++i) { |
198 | collapsed_paddings_as_matrix(i, 0) = collapsed_paddings[i].first; |
199 | collapsed_paddings_as_matrix(i, 1) = collapsed_paddings[i].second; |
200 | } |
201 | return collapsed; |
202 | } |
203 | |
204 | void OperateWithVariableRank(OpKernelContext* context, int fixed_dims, |
205 | const Tensor& input, |
206 | typename TTypes<Tpadding>::ConstMatrix paddings, |
207 | T pad_value, Tensor* output) { |
208 | // Invoke the dims-specific implementation. |
209 | switch (fixed_dims) { |
210 | case 0: |
211 | Operate<0>(context, input.tensor<T, 0>(), paddings, pad_value, output); |
212 | break; |
213 | case 1: |
214 | // TODO(irving): Once Pad doesn't need a scalar special case, |
215 | // change flat to tensor. That is, once !allow_legacy_scalars(). |
216 | Operate<1>(context, input.flat<T>(), paddings, pad_value, output); |
217 | break; |
218 | case 2: |
219 | Operate<2>(context, input.tensor<T, 2>(), paddings, pad_value, output); |
220 | break; |
221 | case 3: |
222 | Operate<3>(context, input.tensor<T, 3>(), paddings, pad_value, output); |
223 | break; |
224 | case 4: |
225 | Operate<4>(context, input.tensor<T, 4>(), paddings, pad_value, output); |
226 | break; |
227 | case 5: |
228 | Operate<5>(context, input.tensor<T, 5>(), paddings, pad_value, output); |
229 | break; |
230 | case 6: |
231 | Operate<6>(context, input.tensor<T, 6>(), paddings, pad_value, output); |
232 | break; |
233 | default: |
234 | OP_REQUIRES(context, false, |
235 | errors::InvalidArgument("Only ranks up to 6 supported: " , |
236 | input.shape().DebugString())); |
237 | } |
238 | } |
239 | |
240 | template <int Dims> |
241 | void Operate(OpKernelContext* context, |
242 | typename TTypes<T, Dims>::ConstTensor input, |
243 | typename TTypes<Tpadding>::ConstMatrix paddings, T pad_value, |
244 | Tensor* output) { |
245 | CHECK_EQ(Dims, paddings.dimension(0)); |
246 | CHECK_EQ(2, paddings.dimension(1)); |
247 | Eigen::array<Eigen::IndexPair<Tpadding>, Dims> paddings_array; |
248 | for (int i = 0; i < Dims; ++i) { |
249 | paddings_array[i] = {paddings(i, 0), paddings(i, 1)}; |
250 | } |
251 | functor::Pad<Device, T, Tpadding, Dims> functor; |
252 | functor(context->eigen_device<Device>(), output->tensor<T, Dims>(), input, |
253 | paddings_array, pad_value); |
254 | } |
255 | }; |
256 | |
// Registers PadOp on CPU for element type `type`: both ops ("Pad" and "PadV2")
// times both paddings types (int32 and int64). The "paddings" input, and for
// PadV2 also the "constant_values" input, are declared as host memory.
// NOTE: the kernel class expressions below are stringified by
// REGISTER_KERNEL_BUILDER, so their exact spelling (e.g. `int64`) is kept.
#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(Name("Pad")                               \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int32>("Tpaddings")   \
                              .HostMemory("paddings"),              \
                          PadOp<CPUDevice, type, int32>);           \
  REGISTER_KERNEL_BUILDER(Name("Pad")                               \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64_t>("Tpaddings") \
                              .HostMemory("paddings"),              \
                          PadOp<CPUDevice, type, int64>);           \
  REGISTER_KERNEL_BUILDER(Name("PadV2")                             \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int32>("Tpaddings")   \
                              .HostMemory("paddings")               \
                              .HostMemory("constant_values"),       \
                          PadOp<CPUDevice, type, int32>);           \
  REGISTER_KERNEL_BUILDER(Name("PadV2")                             \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .TypeConstraint<int64_t>("Tpaddings") \
                              .HostMemory("paddings")               \
                              .HostMemory("constant_values"),       \
                          PadOp<CPUDevice, type, int64>);

// Expand the CPU registrations for every type covered by TF_CALL_POD_TYPES,
// TF_CALL_QUANTIZED_TYPES and TF_CALL_tstring.
TF_CALL_POD_TYPES(REGISTER_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNEL);
TF_CALL_tstring(REGISTER_KERNEL);
#undef REGISTER_KERNEL
289 | |
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
// Forward declarations of the functor specializations for GPU.
// Each DECLARE_GPU_SPEC(T, Dims) declares an explicit specialization of
// functor::Pad<GPUDevice, T, Tpadding, Dims>::operator() for int32 and int64
// paddings, plus an `extern template` declaration so this translation unit
// does not instantiate them itself. The definitions are presumably provided
// by the GPU (.cu.cc) translation unit; confirm there when adding ranks.
namespace functor {
#define DECLARE_GPU_SPEC(T, Dims)                                         \
  template <>                                                             \
  void Pad<GPUDevice, T, int32, Dims>::operator()(                        \
      const GPUDevice& d, typename TTypes<T, Dims>::Tensor output,        \
      typename TTypes<T, Dims>::ConstTensor input,                        \
      Eigen::array<Eigen::IndexPair<int32>, Dims> paddings, T pad_value); \
  extern template struct Pad<GPUDevice, T, int32, Dims>;                  \
  template <>                                                             \
  void Pad<GPUDevice, T, int64_t, Dims>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, Dims>::Tensor output,        \
      typename TTypes<T, Dims>::ConstTensor input,                        \
      Eigen::array<Eigen::IndexPair<int64_t>, Dims> paddings,             \
      T pad_value);                                                       \
  extern template struct Pad<GPUDevice, T, int64_t, Dims>;

// Ranks 0-6, matching the ranks PadOp::OperateWithVariableRank dispatches.
#define DECLARE_GPU_SPECS(T) \
  DECLARE_GPU_SPEC(T, 0);    \
  DECLARE_GPU_SPEC(T, 1);    \
  DECLARE_GPU_SPEC(T, 2);    \
  DECLARE_GPU_SPEC(T, 3);    \
  DECLARE_GPU_SPEC(T, 4);    \
  DECLARE_GPU_SPEC(T, 5);    \
  DECLARE_GPU_SPEC(T, 6);

TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPECS);
TF_CALL_int8(DECLARE_GPU_SPECS);
TF_CALL_uint8(DECLARE_GPU_SPECS);
}  // namespace functor

// Registration of the GPU implementations: for element type T, both ops
// ("Pad" and "PadV2") times both paddings types (int32 and int64). The
// "paddings" input, and for PadV2 also "constant_values", stay in host memory.
// Every registration is terminated with ';', matching the CPU macro above.
#define REGISTER_GPU_KERNEL(T)                                      \
  REGISTER_KERNEL_BUILDER(Name("Pad")                               \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<T>("T")               \
                              .TypeConstraint<int32>("Tpaddings")   \
                              .HostMemory("paddings"),              \
                          PadOp<GPUDevice, T, int32>);              \
  REGISTER_KERNEL_BUILDER(Name("Pad")                               \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<T>("T")               \
                              .TypeConstraint<int64_t>("Tpaddings") \
                              .HostMemory("paddings"),              \
                          PadOp<GPUDevice, T, int64>);              \
  REGISTER_KERNEL_BUILDER(Name("PadV2")                             \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<T>("T")               \
                              .TypeConstraint<int32>("Tpaddings")   \
                              .HostMemory("paddings")               \
                              .HostMemory("constant_values"),       \
                          PadOp<GPUDevice, T, int32>);              \
  REGISTER_KERNEL_BUILDER(Name("PadV2")                             \
                              .Device(DEVICE_GPU)                   \
                              .TypeConstraint<T>("T")               \
                              .TypeConstraint<int64_t>("Tpaddings") \
                              .HostMemory("paddings")               \
                              .HostMemory("constant_values"),       \
                          PadOp<GPUDevice, T, int64>);

TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_int8(REGISTER_GPU_KERNEL);
TF_CALL_uint8(REGISTER_GPU_KERNEL);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
// These registrations run the CPU implementation (PadOp<CPUDevice, ...>) on
// the host-resident tensors.
REGISTER_KERNEL_BUILDER(Name("Pad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Tpaddings")
                            .HostMemory("input")
                            .HostMemory("paddings")
                            .HostMemory("output"),
                        PadOp<CPUDevice, int32, int32>);
REGISTER_KERNEL_BUILDER(Name("Pad")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int64_t>("Tpaddings")
                            .HostMemory("input")
                            .HostMemory("paddings")
                            .HostMemory("output"),
                        PadOp<CPUDevice, int32, int64>);
REGISTER_KERNEL_BUILDER(Name("PadV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int32>("Tpaddings")
                            .HostMemory("input")
                            .HostMemory("paddings")
                            .HostMemory("constant_values")
                            .HostMemory("output"),
                        PadOp<CPUDevice, int32, int32>);
REGISTER_KERNEL_BUILDER(Name("PadV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .TypeConstraint<int64_t>("Tpaddings")
                            .HostMemory("input")
                            .HostMemory("paddings")
                            .HostMemory("constant_values")
                            .HostMemory("output"),
                        PadOp<CPUDevice, int32, int64>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
393 | |
394 | |
395 | } // end namespace tensorflow |
396 | |