1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
16
17#include <stdint.h>
18
19#include "tensorflow/lite/c/builtin_op_data.h"
20#include "tensorflow/lite/c/common.h"
21#include "tensorflow/lite/kernels/internal/compatibility.h"
22#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
23#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
24#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
25#include "tensorflow/lite/kernels/internal/tensor.h"
26#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
27#include "tensorflow/lite/kernels/internal/types.h"
28#include "tensorflow/lite/kernels/kernel_util.h"
29
30namespace tflite {
31namespace ops {
32namespace builtin {
33namespace resize_nearest_neighbor {
34
// RESIZE_NEAREST_NEIGHBOR is registered in three kernel variants, selected at
// compile time through the Eval<> template argument below.
enum KernelType {
  kReference = 0,         // Always uses reference_ops.
  kGenericOptimized = 1,  // Uses optimized_ops for uint8 outputs.
  kNeonOptimized = 2,     // Same optimized_ops path; registered under USE_NEON.
};

// Positions of the tensors within the node's input and output lists.
constexpr int kInputTensor{0};   // 4-D tensor to be resized.
constexpr int kSizeTensor{1};    // 1-D int32 tensor with the two new extents.
constexpr int kOutputTensor{0};  // Resized result.
45
46TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
47 const TfLiteTensor* input,
48 const TfLiteTensor* size,
49 TfLiteTensor* output) {
50 TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
51 output_size->data[0] = input->dims->data[0];
52 const int32* size_data = GetTensorData<int32>(size);
53 output_size->data[1] = size_data[0];
54 output_size->data[2] = size_data[1];
55 output_size->data[3] = input->dims->data[3];
56 return context->ResizeTensor(context, output, output_size);
57}
58
59TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
60 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
61 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
62
63 const TfLiteTensor* input;
64 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
65 const TfLiteTensor* size;
66 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
67 TfLiteTensor* output;
68 TF_LITE_ENSURE_OK(context,
69 GetOutputSafe(context, node, kOutputTensor, &output));
70
71 // Our current implementations relies on the input being 4D,
72 // and the size being 1D tensor with exactly 2 elements.
73 TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
74 TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
75 TF_LITE_ENSURE_TYPES_EQ(context, size->type, kTfLiteInt32);
76 TF_LITE_ENSURE_EQ(context, size->dims->data[0], 2);
77
78 output->type = input->type;
79
80 if (!IsConstantTensor(size)) {
81 SetTensorToDynamic(output);
82 return kTfLiteOk;
83 }
84 return ResizeOutputTensor(context, input, size, output);
85}
86
87template <KernelType kernel_type>
88TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
89 auto* params =
90 reinterpret_cast<TfLiteResizeNearestNeighborParams*>(node->builtin_data);
91
92 const TfLiteTensor* input;
93 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
94 TfLiteTensor* output;
95 TF_LITE_ENSURE_OK(context,
96 GetOutputSafe(context, node, kOutputTensor, &output));
97 const TfLiteTensor* size;
98 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
99
100 if (IsDynamicTensor(output)) {
101 TF_LITE_ENSURE_OK(context,
102 ResizeOutputTensor(context, input, size, output));
103 }
104
105 tflite::ResizeNearestNeighborParams op_params;
106 op_params.align_corners = params->align_corners;
107 op_params.half_pixel_centers = params->half_pixel_centers;
108
109 if (output->type == kTfLiteFloat32) {
110 reference_ops::ResizeNearestNeighbor(
111 op_params, GetTensorShape(input), GetTensorData<int32>(input),
112 GetTensorShape(size), GetTensorData<int32>(size),
113 GetTensorShape(output), GetTensorData<int32>(output));
114 } else if (output->type == kTfLiteUInt8) {
115 if (kernel_type == kReference) {
116 reference_ops::ResizeNearestNeighbor(
117 op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
118 GetTensorShape(size), GetTensorData<int32>(size),
119 GetTensorShape(output), GetTensorData<uint8_t>(output));
120 }
121 if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
122 optimized_ops::ResizeNearestNeighbor(
123 op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
124 GetTensorShape(size), GetTensorData<int32>(size),
125 GetTensorShape(output), GetTensorData<uint8_t>(output));
126 }
127 } else if (output->type == kTfLiteInt8) {
128 reference_ops::ResizeNearestNeighbor(
129 op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
130 GetTensorShape(size), GetTensorData<int32>(size),
131 GetTensorShape(output), GetTensorData<int8_t>(output));
132 } else if (output->type == kTfLiteInt16) {
133 reference_ops::ResizeNearestNeighbor(
134 op_params, GetTensorShape(input), GetTensorData<int16_t>(input),
135 GetTensorShape(size), GetTensorData<int32>(size),
136 GetTensorShape(output), GetTensorData<int16_t>(output));
137 } else {
138 TF_LITE_KERNEL_LOG(
139 context, "Output type is %s, requires float, uint8, int8 or int16.",
140 TfLiteTypeGetName(output->type));
141 return kTfLiteError;
142 }
143
144 return kTfLiteOk;
145}
146
147} // namespace resize_nearest_neighbor
148
149TfLiteRegistration* Register_RESIZE_NEAREST_NEIGHBOR_REF() {
150 static TfLiteRegistration r = {
151 nullptr, nullptr, resize_nearest_neighbor::Prepare,
152 resize_nearest_neighbor::Eval<resize_nearest_neighbor::kReference>};
153 return &r;
154}
155
156TfLiteRegistration* Register_RESIZE_NEAREST_NEIGHBOR_GENERIC_OPT() {
157 static TfLiteRegistration r = {
158 nullptr, nullptr, resize_nearest_neighbor::Prepare,
159 resize_nearest_neighbor::Eval<
160 resize_nearest_neighbor::kGenericOptimized>};
161 return &r;
162}
163
164TfLiteRegistration* Register_RESIZE_NEAREST_NEIGHBOR_NEON_OPT() {
165 static TfLiteRegistration r = {
166 nullptr, nullptr, resize_nearest_neighbor::Prepare,
167 resize_nearest_neighbor::Eval<resize_nearest_neighbor::kNeonOptimized>};
168 return &r;
169}
170
171TfLiteRegistration* Register_RESIZE_NEAREST_NEIGHBOR() {
172#ifdef USE_NEON
173 return Register_RESIZE_NEAREST_NEIGHBOR_NEON_OPT();
174#else
175 return Register_RESIZE_NEAREST_NEIGHBOR_GENERIC_OPT();
176#endif
177}
178
179} // namespace builtin
180} // namespace ops
181} // namespace tflite
182