1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/kernels/internal/reference/pad.h" |
16 | |
17 | #include <stdint.h> |
18 | |
19 | #include <limits> |
20 | |
21 | #include "tensorflow/lite/c/common.h" |
22 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
23 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
24 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
25 | #include "tensorflow/lite/kernels/internal/tensor.h" |
26 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
27 | #include "tensorflow/lite/kernels/internal/types.h" |
28 | #include "tensorflow/lite/kernels/kernel_util.h" |
29 | |
30 | namespace tflite { |
31 | namespace ops { |
32 | namespace builtin { |
33 | namespace pad { |
34 | |
// Selects which of this file's two Pad implementations Eval() dispatches to.
// Used as a template argument, so the enumerator values are pinned explicitly.
enum KernelType {
  kReference = 0,         // Portable kernels from reference_ops.
  kGenericOptimized = 1,  // Kernels from optimized_ops.
};
40 | |
41 | struct PadContext { |
42 | PadContext(TfLiteContext* context, TfLiteNode* node) { |
43 | input = GetInput(context, node, 0); |
44 | paddings = GetInput(context, node, 1); |
45 | if (NumInputs(node) == 3) { |
46 | constant_values = GetOptionalInputTensor(context, node, 2); |
47 | } else { |
48 | constant_values = nullptr; |
49 | } |
50 | output = GetOutput(context, node, 0); |
51 | dims = NumDimensions(input); |
52 | |
53 | resizing_category = ResizingCategory::kGenericResize; |
54 | const int paddings_total = GetTensorShape(paddings).FlatSize(); |
55 | const int32* paddings_data = GetTensorData<int32>(paddings); |
56 | // Paddings will be a n,2 array, and we need to detect 4D arrays with the |
57 | // pattern { {0,0}, {a, b}, {c, d}, {0,0} }. |
58 | if (IsConstantTensor(paddings) && paddings_total == 8 && |
59 | (paddings_data[0] == 0 && paddings_data[1] == 0) && |
60 | (paddings_data[6] == 0 && paddings_data[7] == 0)) { |
61 | resizing_category = ResizingCategory::kImageStyle; |
62 | } |
63 | } |
64 | const TfLiteTensor* constant_values; |
65 | const TfLiteTensor* input; |
66 | const TfLiteTensor* paddings; |
67 | TfLiteTensor* output; |
68 | int dims; |
69 | ResizingCategory resizing_category; |
70 | }; |
71 | |
72 | // Resizes output array based on the input size and padding size. This function |
73 | // is callable from both Prepare() and Eval() as long as the caller ensures the |
74 | // paddings data is present. |
75 | TfLiteStatus ResizeOutputTensor(TfLiteContext* context, |
76 | PadContext* op_context) { |
77 | // Ensures the paddings array is dims x 2. |
78 | TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 0), |
79 | op_context->dims); |
80 | TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 1), 2); |
81 | |
82 | // Ensures all the elements of the paddings is non-negative. |
83 | const int32* paddings_data = GetTensorData<int32>(op_context->paddings); |
84 | |
85 | for (int idx = 0; idx < op_context->dims; ++idx) { |
86 | int before_padding = *paddings_data++; |
87 | int after_padding = *paddings_data++; |
88 | |
89 | TF_LITE_ENSURE_MSG(context, (before_padding >= 0 && after_padding >= 0), |
90 | "Pad value has to be greater than equal to 0." ); |
91 | } |
92 | |
93 | // Determines the size of the output tensor. |
94 | TfLiteIntArray* input_size = op_context->input->dims; |
95 | TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); |
96 | paddings_data = GetTensorData<int32>(op_context->paddings); |
97 | |
98 | for (int idx = 0; idx < op_context->dims; ++idx) { |
99 | int before_padding = *paddings_data++; |
100 | int after_padding = *paddings_data++; |
101 | |
102 | output_size->data[idx] = |
103 | (input_size->data[idx] + before_padding + after_padding); |
104 | } |
105 | |
106 | return context->ResizeTensor(context, op_context->output, output_size); |
107 | } |
108 | |
109 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
110 | TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3); |
111 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
112 | |
113 | PadContext op_context(context, node); |
114 | TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type, |
115 | op_context.output->type); |
116 | if (op_context.constant_values != nullptr) { |
117 | TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type, |
118 | op_context.constant_values->type); |
119 | } |
120 | |
121 | // Ensure we do not exceed maximum dimension count. |
122 | TF_LITE_ENSURE( |
123 | context, op_context.dims <= reference_ops::PadKernelMaxDimensionCount()); |
124 | |
125 | // Exit early if paddings is a non-const tensor or the given input is an |
126 | // unranked input. Set output tensor to dynamic so output size can be |
127 | // determined in Eval. |
128 | if (NumDimensions(op_context.input) == 0 || |
129 | !IsConstantTensor(op_context.paddings)) { |
130 | SetTensorToDynamic(op_context.output); |
131 | return kTfLiteOk; |
132 | } |
133 | return ResizeOutputTensor(context, &op_context); |
134 | } |
135 | |
136 | template <typename integer_type> |
137 | TfLiteStatus EvalInt(TfLiteContext* context, const PadContext& op_context, |
138 | const tflite::PadParams& op_params) { |
139 | integer_type pad_value; |
140 | if (op_context.constant_values == nullptr) { |
141 | // Quantized Pad requires that 0 is represented in the quantized |
142 | // range. |
143 | TF_LITE_ENSURE(context, op_context.output->params.zero_point >= |
144 | std::numeric_limits<integer_type>::min()); |
145 | TF_LITE_ENSURE(context, op_context.output->params.zero_point <= |
146 | std::numeric_limits<integer_type>::max()); |
147 | pad_value = static_cast<integer_type>(op_context.output->params.zero_point); |
148 | } else { |
149 | // Quantized Pad requires that 'constant_values' is represented in the |
150 | // same quantized range as the input and output tensors. |
151 | TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point, |
152 | op_context.constant_values->params.zero_point); |
153 | TF_LITE_ENSURE_EQ(context, op_context.output->params.scale, |
154 | op_context.constant_values->params.scale); |
155 | pad_value = *GetTensorData<integer_type>(op_context.constant_values); |
156 | } |
157 | const integer_type pad_value_copy = pad_value; |
158 | if (op_context.resizing_category == ResizingCategory::kImageStyle) { |
159 | optimized_ops::PadImageStyle( |
160 | op_params, GetTensorShape(op_context.input), |
161 | GetTensorData<integer_type>(op_context.input), &pad_value_copy, |
162 | GetTensorShape(op_context.output), |
163 | GetTensorData<integer_type>(op_context.output)); |
164 | } else { |
165 | optimized_ops::Pad(op_params, GetTensorShape(op_context.input), |
166 | GetTensorData<integer_type>(op_context.input), |
167 | &pad_value_copy, GetTensorShape(op_context.output), |
168 | GetTensorData<integer_type>(op_context.output)); |
169 | } |
170 | |
171 | return kTfLiteOk; |
172 | } |
173 | |
174 | template <KernelType kernel_type> |
175 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
176 | PadContext op_context(context, node); |
177 | |
178 | if (op_context.constant_values != nullptr) { |
179 | // Ensure that constant_values is a scalar. |
180 | TF_LITE_ENSURE_EQ(context, NumElements(op_context.constant_values), 1); |
181 | } |
182 | |
183 | // Resize the output tensor if the output tensor is dynamic. |
184 | if (IsDynamicTensor(op_context.output)) { |
185 | TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); |
186 | } |
187 | |
188 | // Create before and after padding arrays that are accepted by the kernel. |
189 | const int32* paddings_data = GetTensorData<int32>(op_context.paddings); |
190 | |
191 | TF_LITE_ENSURE( |
192 | context, op_context.dims <= reference_ops::PadKernelMaxDimensionCount()); |
193 | |
194 | tflite::PadParams op_params; |
195 | op_params.left_padding_count = op_context.dims; |
196 | op_params.right_padding_count = op_context.dims; |
197 | |
198 | for (int idx = op_context.dims - 1; idx >= 0; --idx) { |
199 | op_params.left_padding[idx] = paddings_data[idx * 2]; |
200 | op_params.right_padding[idx] = paddings_data[idx * 2 + 1]; |
201 | } |
202 | |
203 | #define TF_LITE_PAD(type, op_name, scalar, pad_value) \ |
204 | const scalar pad_value_copy = pad_value; \ |
205 | \ |
206 | type::op_name(op_params, GetTensorShape(op_context.input), \ |
207 | GetTensorData<scalar>(op_context.input), &pad_value_copy, \ |
208 | GetTensorShape(op_context.output), \ |
209 | GetTensorData<scalar>(op_context.output)) |
210 | switch (op_context.input->type) { |
211 | case kTfLiteFloat32: { |
212 | float pad_value = op_context.constant_values == nullptr |
213 | ? 0.f |
214 | : *GetTensorData<float>(op_context.constant_values); |
215 | if (kernel_type == kReference) { |
216 | if (op_context.resizing_category == ResizingCategory::kImageStyle) { |
217 | TF_LITE_PAD(reference_ops, PadImageStyle, float, pad_value); |
218 | } else { |
219 | TF_LITE_PAD(reference_ops, Pad, float, pad_value); |
220 | } |
221 | } else if (kernel_type == kGenericOptimized) { |
222 | if (op_context.resizing_category == ResizingCategory::kImageStyle) { |
223 | TF_LITE_PAD(optimized_ops, PadImageStyle, float, pad_value); |
224 | } else { |
225 | TF_LITE_PAD(optimized_ops, Pad, float, pad_value); |
226 | } |
227 | } |
228 | } break; |
229 | case kTfLiteUInt8: { |
230 | EvalInt<uint8_t>(context, op_context, op_params); |
231 | } break; |
232 | case kTfLiteInt8: { |
233 | EvalInt<int8_t>(context, op_context, op_params); |
234 | } break; |
235 | case kTfLiteInt16: { |
236 | EvalInt<int16_t>(context, op_context, op_params); |
237 | } break; |
238 | case kTfLiteInt32: { |
239 | int32_t pad_value = |
240 | op_context.constant_values == nullptr |
241 | ? 0 |
242 | : *GetTensorData<int32_t>(op_context.constant_values); |
243 | if (kernel_type == kReference) { |
244 | TF_LITE_PAD(reference_ops, Pad, int32_t, pad_value); |
245 | } else if (kernel_type == kGenericOptimized) { |
246 | TF_LITE_PAD(optimized_ops, Pad, int32_t, pad_value); |
247 | } |
248 | } break; |
249 | case kTfLiteInt64: { |
250 | int64_t pad_value = |
251 | op_context.constant_values == nullptr |
252 | ? 0L |
253 | : *GetTensorData<int64_t>(op_context.constant_values); |
254 | if (kernel_type == kReference) { |
255 | TF_LITE_PAD(reference_ops, Pad, int64_t, pad_value); |
256 | } else if (kernel_type == kGenericOptimized) { |
257 | TF_LITE_PAD(optimized_ops, Pad, int64_t, pad_value); |
258 | } |
259 | } break; |
260 | default: |
261 | TF_LITE_KERNEL_LOG(context, "Type %s is currently not supported by Pad." , |
262 | TfLiteTypeGetName(op_context.input->type)); |
263 | return kTfLiteError; |
264 | } |
265 | #undef TF_LITE_PAD |
266 | return kTfLiteOk; |
267 | } |
268 | |
269 | } // namespace pad |
270 | |
271 | TfLiteRegistration* Register_PAD_REF() { |
272 | static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, |
273 | pad::Eval<pad::kReference>}; |
274 | return &r; |
275 | } |
276 | |
277 | TfLiteRegistration* Register_PAD_GENERIC_OPT() { |
278 | static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, |
279 | pad::Eval<pad::kGenericOptimized>}; |
280 | return &r; |
281 | } |
282 | |
283 | TfLiteRegistration* Register_PAD() { return Register_PAD_GENERIC_OPT(); } |
284 | |
285 | // Also register Pad as PadV2. |
286 | TfLiteRegistration* Register_PADV2_REF() { |
287 | static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, |
288 | pad::Eval<pad::kReference>}; |
289 | return &r; |
290 | } |
291 | |
292 | TfLiteRegistration* Register_PADV2_GENERIC_OPT() { |
293 | static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, |
294 | pad::Eval<pad::kGenericOptimized>}; |
295 | return &r; |
296 | } |
297 | |
298 | TfLiteRegistration* Register_PADV2() { return Register_PADV2_GENERIC_OPT(); } |
299 | |
300 | } // namespace builtin |
301 | } // namespace ops |
302 | } // namespace tflite |
303 | |