1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/kernels/internal/reference/pad.h"
16
17#include <stdint.h>
18
19#include <limits>
20
21#include "tensorflow/lite/c/common.h"
22#include "tensorflow/lite/kernels/internal/compatibility.h"
23#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
24#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
25#include "tensorflow/lite/kernels/internal/tensor.h"
26#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
27#include "tensorflow/lite/kernels/internal/types.h"
28#include "tensorflow/lite/kernels/kernel_util.h"
29
30namespace tflite {
31namespace ops {
32namespace builtin {
33namespace pad {
34
// Selects which Pad implementation Eval() dispatches to. Both flavors are
// registered below (as Pad and PadV2).
enum KernelType {
  // Straightforward implementations from reference_ops.
  kReference = 0,
  // Implementations from optimized_ops.
  kGenericOptimized = 1,
};
40
41struct PadContext {
42 PadContext(TfLiteContext* context, TfLiteNode* node) {
43 input = GetInput(context, node, 0);
44 paddings = GetInput(context, node, 1);
45 if (NumInputs(node) == 3) {
46 constant_values = GetOptionalInputTensor(context, node, 2);
47 } else {
48 constant_values = nullptr;
49 }
50 output = GetOutput(context, node, 0);
51 dims = NumDimensions(input);
52
53 resizing_category = ResizingCategory::kGenericResize;
54 const int paddings_total = GetTensorShape(paddings).FlatSize();
55 const int32* paddings_data = GetTensorData<int32>(paddings);
56 // Paddings will be a n,2 array, and we need to detect 4D arrays with the
57 // pattern { {0,0}, {a, b}, {c, d}, {0,0} }.
58 if (IsConstantTensor(paddings) && paddings_total == 8 &&
59 (paddings_data[0] == 0 && paddings_data[1] == 0) &&
60 (paddings_data[6] == 0 && paddings_data[7] == 0)) {
61 resizing_category = ResizingCategory::kImageStyle;
62 }
63 }
64 const TfLiteTensor* constant_values;
65 const TfLiteTensor* input;
66 const TfLiteTensor* paddings;
67 TfLiteTensor* output;
68 int dims;
69 ResizingCategory resizing_category;
70};
71
72// Resizes output array based on the input size and padding size. This function
73// is callable from both Prepare() and Eval() as long as the caller ensures the
74// paddings data is present.
75TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
76 PadContext* op_context) {
77 // Ensures the paddings array is dims x 2.
78 TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 0),
79 op_context->dims);
80 TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 1), 2);
81
82 // Ensures all the elements of the paddings is non-negative.
83 const int32* paddings_data = GetTensorData<int32>(op_context->paddings);
84
85 for (int idx = 0; idx < op_context->dims; ++idx) {
86 int before_padding = *paddings_data++;
87 int after_padding = *paddings_data++;
88
89 TF_LITE_ENSURE_MSG(context, (before_padding >= 0 && after_padding >= 0),
90 "Pad value has to be greater than equal to 0.");
91 }
92
93 // Determines the size of the output tensor.
94 TfLiteIntArray* input_size = op_context->input->dims;
95 TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
96 paddings_data = GetTensorData<int32>(op_context->paddings);
97
98 for (int idx = 0; idx < op_context->dims; ++idx) {
99 int before_padding = *paddings_data++;
100 int after_padding = *paddings_data++;
101
102 output_size->data[idx] =
103 (input_size->data[idx] + before_padding + after_padding);
104 }
105
106 return context->ResizeTensor(context, op_context->output, output_size);
107}
108
109TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
110 TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3);
111 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
112
113 PadContext op_context(context, node);
114 TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
115 op_context.output->type);
116 if (op_context.constant_values != nullptr) {
117 TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
118 op_context.constant_values->type);
119 }
120
121 // Ensure we do not exceed maximum dimension count.
122 TF_LITE_ENSURE(
123 context, op_context.dims <= reference_ops::PadKernelMaxDimensionCount());
124
125 // Exit early if paddings is a non-const tensor or the given input is an
126 // unranked input. Set output tensor to dynamic so output size can be
127 // determined in Eval.
128 if (NumDimensions(op_context.input) == 0 ||
129 !IsConstantTensor(op_context.paddings)) {
130 SetTensorToDynamic(op_context.output);
131 return kTfLiteOk;
132 }
133 return ResizeOutputTensor(context, &op_context);
134}
135
136template <typename integer_type>
137TfLiteStatus EvalInt(TfLiteContext* context, const PadContext& op_context,
138 const tflite::PadParams& op_params) {
139 integer_type pad_value;
140 if (op_context.constant_values == nullptr) {
141 // Quantized Pad requires that 0 is represented in the quantized
142 // range.
143 TF_LITE_ENSURE(context, op_context.output->params.zero_point >=
144 std::numeric_limits<integer_type>::min());
145 TF_LITE_ENSURE(context, op_context.output->params.zero_point <=
146 std::numeric_limits<integer_type>::max());
147 pad_value = static_cast<integer_type>(op_context.output->params.zero_point);
148 } else {
149 // Quantized Pad requires that 'constant_values' is represented in the
150 // same quantized range as the input and output tensors.
151 TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point,
152 op_context.constant_values->params.zero_point);
153 TF_LITE_ENSURE_EQ(context, op_context.output->params.scale,
154 op_context.constant_values->params.scale);
155 pad_value = *GetTensorData<integer_type>(op_context.constant_values);
156 }
157 const integer_type pad_value_copy = pad_value;
158 if (op_context.resizing_category == ResizingCategory::kImageStyle) {
159 optimized_ops::PadImageStyle(
160 op_params, GetTensorShape(op_context.input),
161 GetTensorData<integer_type>(op_context.input), &pad_value_copy,
162 GetTensorShape(op_context.output),
163 GetTensorData<integer_type>(op_context.output));
164 } else {
165 optimized_ops::Pad(op_params, GetTensorShape(op_context.input),
166 GetTensorData<integer_type>(op_context.input),
167 &pad_value_copy, GetTensorShape(op_context.output),
168 GetTensorData<integer_type>(op_context.output));
169 }
170
171 return kTfLiteOk;
172}
173
174template <KernelType kernel_type>
175TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
176 PadContext op_context(context, node);
177
178 if (op_context.constant_values != nullptr) {
179 // Ensure that constant_values is a scalar.
180 TF_LITE_ENSURE_EQ(context, NumElements(op_context.constant_values), 1);
181 }
182
183 // Resize the output tensor if the output tensor is dynamic.
184 if (IsDynamicTensor(op_context.output)) {
185 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
186 }
187
188 // Create before and after padding arrays that are accepted by the kernel.
189 const int32* paddings_data = GetTensorData<int32>(op_context.paddings);
190
191 TF_LITE_ENSURE(
192 context, op_context.dims <= reference_ops::PadKernelMaxDimensionCount());
193
194 tflite::PadParams op_params;
195 op_params.left_padding_count = op_context.dims;
196 op_params.right_padding_count = op_context.dims;
197
198 for (int idx = op_context.dims - 1; idx >= 0; --idx) {
199 op_params.left_padding[idx] = paddings_data[idx * 2];
200 op_params.right_padding[idx] = paddings_data[idx * 2 + 1];
201 }
202
203#define TF_LITE_PAD(type, op_name, scalar, pad_value) \
204 const scalar pad_value_copy = pad_value; \
205 \
206 type::op_name(op_params, GetTensorShape(op_context.input), \
207 GetTensorData<scalar>(op_context.input), &pad_value_copy, \
208 GetTensorShape(op_context.output), \
209 GetTensorData<scalar>(op_context.output))
210 switch (op_context.input->type) {
211 case kTfLiteFloat32: {
212 float pad_value = op_context.constant_values == nullptr
213 ? 0.f
214 : *GetTensorData<float>(op_context.constant_values);
215 if (kernel_type == kReference) {
216 if (op_context.resizing_category == ResizingCategory::kImageStyle) {
217 TF_LITE_PAD(reference_ops, PadImageStyle, float, pad_value);
218 } else {
219 TF_LITE_PAD(reference_ops, Pad, float, pad_value);
220 }
221 } else if (kernel_type == kGenericOptimized) {
222 if (op_context.resizing_category == ResizingCategory::kImageStyle) {
223 TF_LITE_PAD(optimized_ops, PadImageStyle, float, pad_value);
224 } else {
225 TF_LITE_PAD(optimized_ops, Pad, float, pad_value);
226 }
227 }
228 } break;
229 case kTfLiteUInt8: {
230 EvalInt<uint8_t>(context, op_context, op_params);
231 } break;
232 case kTfLiteInt8: {
233 EvalInt<int8_t>(context, op_context, op_params);
234 } break;
235 case kTfLiteInt16: {
236 EvalInt<int16_t>(context, op_context, op_params);
237 } break;
238 case kTfLiteInt32: {
239 int32_t pad_value =
240 op_context.constant_values == nullptr
241 ? 0
242 : *GetTensorData<int32_t>(op_context.constant_values);
243 if (kernel_type == kReference) {
244 TF_LITE_PAD(reference_ops, Pad, int32_t, pad_value);
245 } else if (kernel_type == kGenericOptimized) {
246 TF_LITE_PAD(optimized_ops, Pad, int32_t, pad_value);
247 }
248 } break;
249 case kTfLiteInt64: {
250 int64_t pad_value =
251 op_context.constant_values == nullptr
252 ? 0L
253 : *GetTensorData<int64_t>(op_context.constant_values);
254 if (kernel_type == kReference) {
255 TF_LITE_PAD(reference_ops, Pad, int64_t, pad_value);
256 } else if (kernel_type == kGenericOptimized) {
257 TF_LITE_PAD(optimized_ops, Pad, int64_t, pad_value);
258 }
259 } break;
260 default:
261 TF_LITE_KERNEL_LOG(context, "Type %s is currently not supported by Pad.",
262 TfLiteTypeGetName(op_context.input->type));
263 return kTfLiteError;
264 }
265#undef TF_LITE_PAD
266 return kTfLiteOk;
267}
268
269} // namespace pad
270
271TfLiteRegistration* Register_PAD_REF() {
272 static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
273 pad::Eval<pad::kReference>};
274 return &r;
275}
276
277TfLiteRegistration* Register_PAD_GENERIC_OPT() {
278 static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
279 pad::Eval<pad::kGenericOptimized>};
280 return &r;
281}
282
283TfLiteRegistration* Register_PAD() { return Register_PAD_GENERIC_OPT(); }
284
285// Also register Pad as PadV2.
286TfLiteRegistration* Register_PADV2_REF() {
287 static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
288 pad::Eval<pad::kReference>};
289 return &r;
290}
291
292TfLiteRegistration* Register_PADV2_GENERIC_OPT() {
293 static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare,
294 pad::Eval<pad::kGenericOptimized>};
295 return &r;
296}
297
298TfLiteRegistration* Register_PADV2() { return Register_PADV2_GENERIC_OPT(); }
299
300} // namespace builtin
301} // namespace ops
302} // namespace tflite
303