1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <stdint.h> |
16 | |
17 | #include "tensorflow/lite/c/builtin_op_data.h" |
18 | #include "tensorflow/lite/c/common.h" |
19 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
20 | #include "tensorflow/lite/kernels/internal/optimized/neon_check.h" |
21 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
22 | // clang-format off: Clang-format thinks this header is paired. |
23 | #include "tensorflow/lite/kernels/internal/optimized/resize_bilinear.h" |
24 | // clang-format on |
25 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
26 | #include "tensorflow/lite/kernels/internal/tensor.h" |
27 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
28 | #include "tensorflow/lite/kernels/internal/types.h" |
29 | #include "tensorflow/lite/kernels/kernel_util.h" |
30 | |
31 | namespace tflite { |
32 | namespace ops { |
33 | namespace builtin { |
34 | namespace resize_bilinear { |
35 | |
// This file has two implementations of RESIZE_BILINEAR: a reference one and an
// optimized one (see KernelType below).
// Selects which implementation Eval<> dispatches to.
enum KernelType {
  kReference = 0,  // Kernels from reference_ops.
  kOptimized = 1,  // Kernels from optimized_ops (where one exists).
};

// Positions of this op's tensors in the node's input/output lists.
constexpr int kInputTensor = 0;   // Image tensor being resized.
constexpr int kSizeTensor = 1;    // int32 tensor holding the target size.
constexpr int kOutputTensor = 0;  // Resized result.
45 | |
46 | TfLiteStatus ResizeOutputTensor(TfLiteContext* context, |
47 | const TfLiteTensor* input, |
48 | const TfLiteTensor* size, |
49 | TfLiteTensor* output) { |
50 | const int32* size_data = GetTensorData<int32>(size); |
51 | // Sanity check, the up/down sampling size should always be positive. |
52 | TF_LITE_ENSURE(context, size_data[0] > 0); |
53 | TF_LITE_ENSURE(context, size_data[1] > 0); |
54 | TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); |
55 | output_size->data[0] = input->dims->data[0]; |
56 | output_size->data[1] = size_data[0]; |
57 | output_size->data[2] = size_data[1]; |
58 | output_size->data[3] = input->dims->data[3]; |
59 | return context->ResizeTensor(context, output, output_size); |
60 | } |
61 | |
62 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
63 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
64 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
65 | |
66 | const TfLiteTensor* input; |
67 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
68 | const TfLiteTensor* size; |
69 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size)); |
70 | TfLiteTensor* output; |
71 | TF_LITE_ENSURE_OK(context, |
72 | GetOutputSafe(context, node, kOutputTensor, &output)); |
73 | |
74 | // TODO(ahentz): Our current implementations rely on the inputs being 4D. |
75 | TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); |
76 | TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1); |
77 | |
78 | TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32); |
79 | // ResizeBilinear creates a float tensor even when the input is made of |
80 | // integers. |
81 | output->type = input->type; |
82 | |
83 | if (!IsConstantTensor(size)) { |
84 | SetTensorToDynamic(output); |
85 | return kTfLiteOk; |
86 | } |
87 | |
88 | // Ensure params are valid. |
89 | auto* params = |
90 | reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data); |
91 | if (params->half_pixel_centers && params->align_corners) { |
92 | TF_LITE_KERNEL_LOG( |
93 | context, "If half_pixel_centers is True, align_corners must be False." ); |
94 | return kTfLiteError; |
95 | } |
96 | |
97 | return ResizeOutputTensor(context, input, size, output); |
98 | } |
99 | |
100 | template <KernelType kernel_type> |
101 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
102 | auto* params = |
103 | reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data); |
104 | |
105 | const TfLiteTensor* input; |
106 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
107 | TfLiteTensor* output; |
108 | TF_LITE_ENSURE_OK(context, |
109 | GetOutputSafe(context, node, kOutputTensor, &output)); |
110 | const TfLiteTensor* size; |
111 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size)); |
112 | |
113 | if (IsDynamicTensor(output)) { |
114 | TF_LITE_ENSURE_OK(context, |
115 | ResizeOutputTensor(context, input, size, output)); |
116 | } |
117 | |
118 | if (output->type == kTfLiteFloat32) { |
119 | #define TF_LITE_RESIZE_BILINEAR(type, opname, datatype) \ |
120 | tflite::ResizeBilinearParams op_params; \ |
121 | op_params.align_corners = params->align_corners; \ |
122 | op_params.half_pixel_centers = params->half_pixel_centers; \ |
123 | type::opname(op_params, GetTensorShape(input), \ |
124 | GetTensorData<datatype>(input), GetTensorShape(size), \ |
125 | GetTensorData<int32>(size), GetTensorShape(output), \ |
126 | GetTensorData<datatype>(output)) |
127 | |
128 | if (kernel_type == kReference) { |
129 | TF_LITE_RESIZE_BILINEAR(reference_ops, ResizeBilinear, float); |
130 | } else if (kernel_type == kOptimized) { |
131 | TF_LITE_RESIZE_BILINEAR(optimized_ops, ResizeBilinear, float); |
132 | } |
133 | } else if (output->type == kTfLiteUInt8) { |
134 | if (kernel_type == kReference) { |
135 | TF_LITE_RESIZE_BILINEAR(reference_ops, ResizeBilinear, uint8_t); |
136 | } else if (kernel_type == kOptimized) { |
137 | TF_LITE_RESIZE_BILINEAR(optimized_ops, ResizeBilinear, uint8_t); |
138 | } |
139 | } else if (output->type == kTfLiteInt8) { |
140 | if (kernel_type == kReference) { |
141 | TF_LITE_RESIZE_BILINEAR(reference_ops, ResizeBilinearInteger, int8_t); |
142 | } else if (kernel_type == kOptimized) { |
143 | TF_LITE_RESIZE_BILINEAR(optimized_ops, ResizeBilinear, int8_t); |
144 | } |
145 | } else if (output->type == kTfLiteInt16) { |
146 | TF_LITE_RESIZE_BILINEAR(reference_ops, ResizeBilinearInteger, int16_t); |
147 | #undef TF_LITE_RESIZE_BILINEAR |
148 | } else { |
149 | TF_LITE_KERNEL_LOG(context, "Output type is %d, requires float." , |
150 | output->type); |
151 | return kTfLiteError; |
152 | } |
153 | |
154 | return kTfLiteOk; |
155 | } |
156 | |
157 | } // namespace resize_bilinear |
158 | |
159 | TfLiteRegistration* Register_RESIZE_BILINEAR_REF() { |
160 | static TfLiteRegistration r = { |
161 | nullptr, nullptr, resize_bilinear::Prepare, |
162 | resize_bilinear::Eval<resize_bilinear::kReference>}; |
163 | return &r; |
164 | } |
165 | |
166 | TfLiteRegistration* Register_RESIZE_BILINEAR() { |
167 | static TfLiteRegistration r = { |
168 | nullptr, nullptr, resize_bilinear::Prepare, |
169 | resize_bilinear::Eval<resize_bilinear::kOptimized>}; |
170 | return &r; |
171 | } |
172 | |
173 | } // namespace builtin |
174 | } // namespace ops |
175 | } // namespace tflite |
176 | |