/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements a quantized version of the resize bilinear op.

#define EIGEN_USE_THREADS

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#define USE_NEON
#define QUANTIZED_RESIZE_BILINEAR_USE_NEON
#include <arm_neon.h>
#endif

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/image_resizer_state.h"

namespace tensorflow {

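// When true, ResizeBilinear() below always falls back to the float reference
// implementation instead of the fixed-point (and NEON) paths; useful, e.g.,
// when debugging the optimized kernels.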
static constexpr bool USE_REFERENCE = false;

namespace {
// Compute the interpolation indices only once.
template <typename T_SCALE>
struct InterpolationCache {
  std::vector<int64_t> lower;  // Lower source index used in the interpolation
  std::vector<int64_t> upper;  // Upper source index used in the interpolation
  // 1-D linear interpolation scale (see:
  // https://en.wikipedia.org/wiki/Bilinear_interpolation)
  std::vector<float> lerp;
  std::vector<T_SCALE> ilerp;
};

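// A small worked example of the weight computation below (illustrative values,
// not taken from any particular call site): resizing a 2-pixel row to 4 pixels
// with the LegacyScaler and scale = 0.5 gives, for output index i = 2, a
// source coordinate in = 1.0 and hence lower = 1, upper = 1, lerp = 0.0; for
// i = 3, in = 1.5 is clamped to lower = upper = 1 with lerp = 0.5. `ilerp`
// holds the same weight in fixed point, i.e. lerp * (1 << resolution).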
template <typename T_SCALE, typename Scaler>
inline void ComputeInterpolationWeights(
    const int64_t out_size, const int64_t in_size, const float scale,
    const int resolution, InterpolationCache<T_SCALE>* interpolation) {
  const Scaler scaler;
  interpolation->lower.resize(out_size + 1);
  interpolation->upper.resize(out_size + 1);
  interpolation->lerp.resize(out_size + 1);
  interpolation->ilerp.resize(out_size + 1);

  interpolation->lower[out_size] = 0;
  interpolation->upper[out_size] = 0;
  for (int64_t i = out_size - 1; i >= 0; --i) {
    const float in = scaler(i, scale);
    const float in_f = std::floor(in);
    interpolation->lower[i] =
        std::max(static_cast<int64_t>(in_f), static_cast<int64_t>(0));
    interpolation->upper[i] =
        std::min(static_cast<int64_t>(std::ceil(in)), in_size - 1);
    interpolation->lower[i] =
        std::min(interpolation->lower[i], interpolation->upper[i]);
    interpolation->lerp[i] = in - in_f;
    interpolation->ilerp[i] =
        static_cast<T_SCALE>((in - in_f) * (1 << resolution));
  }
}

template <typename T_SCALE>
inline InterpolationCache<T_SCALE> BuildLerpCache(
    const int64_t out_size, const int64_t in_size, const float scale,
    const int index_step, const int resolution,
    const bool half_pixel_centers) {
  InterpolationCache<T_SCALE> cache;
  // Compute the cached interpolation weights on the x and y dimensions.
  if (half_pixel_centers) {
    ComputeInterpolationWeights<T_SCALE, HalfPixelScaler>(
        out_size, in_size, scale, resolution, &cache);
  } else {
    ComputeInterpolationWeights<T_SCALE, LegacyScaler>(out_size, in_size, scale,
                                                       resolution, &cache);
  }
  CHECK(index_step > 0);
  if (index_step > 1) {
    for (int i = 0; i < cache.lower.size(); ++i) {
      cache.lower[i] *= index_step;
      cache.upper[i] *= index_step;
    }
  }
  return cache;
}

/**
 * Computes the bilinear interpolation from the four surrounding points and the
 * linear interpolation weights, converting to and from float via the
 * [min, max] quantization range (reference implementation).
 */
template <typename T>
inline T ComputeLerpReference(const T in_top_left, const T in_top_right,
                              const T in_bottom_left, const T in_bottom_right,
                              const float x_lerp, const float y_lerp,
                              const float min, const float max) {
  const float top_left = QuantizedToFloat<T>(in_top_left, min, max);
  const float top_right = QuantizedToFloat<T>(in_top_right, min, max);
  const float bottom_left = QuantizedToFloat<T>(in_bottom_left, min, max);
  const float bottom_right = QuantizedToFloat<T>(in_bottom_right, min, max);
  const float top = top_left + (top_right - top_left) * x_lerp;
  const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
  const float out = top + (bottom - top) * y_lerp;
  return FloatToQuantized<T>(out, min, max);
}

template <typename T, typename T_SCALE, typename T_CALC>
inline T_CALC MulOffset(T a, T b, T_SCALE c) {
  return (static_cast<T_CALC>(a) - static_cast<T_CALC>(b)) *
         static_cast<T_CALC>(c);
}

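// Fixed-point variant of the lerp above. A minimal worked example (assumed
// values, not from the original code): with RESOLUTION = 7 the weights live in
// [0, 128]. For top_left = 10, top_right = 20, bottom_left = 30,
// bottom_right = 40 and x_lerp = y_lerp = 64 (i.e. 0.5), top = 10 * 128 +
// (20 - 10) * 64 = 1920, bottom = 30 * 128 + (40 - 30) * 64 = 4480, and
// out = 1920 + (4480 - 1920) / 128 * 64 = 3200, so the rounded result is
// (3200 + 64) / 128 = 25, matching the float result lerp(15, 35, 0.5) = 25.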
template <int RESOLUTION, typename T, typename T_SCALE, typename T_CALC>
inline T ComputeLerp(const T top_left, const T top_right, const T bottom_left,
                     const T bottom_right, const T_SCALE x_lerp,
                     const T_SCALE y_lerp) {
  constexpr T_CALC RESOLUTION_MULT = (1 << RESOLUTION);
  const T_CALC top = static_cast<T_CALC>(top_left) * RESOLUTION_MULT +
                     MulOffset<T, T_SCALE, T_CALC>(top_right, top_left, x_lerp);
  const T_CALC bottom =
      static_cast<T_CALC>(bottom_left) * RESOLUTION_MULT +
      MulOffset<T, T_SCALE, T_CALC>(bottom_right, bottom_left, x_lerp);
  const T_CALC out = top + (bottom - top) / RESOLUTION_MULT * y_lerp;
  return static_cast<T>(
      static_cast<int32>((out + RESOLUTION_MULT / 2) / RESOLUTION_MULT));
}

#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
inline uint8x8_t ToUint8x8(const quint8* v0, const quint8* v1, const quint8* v2,
                           const quint8* v3, const quint8* v4, const quint8* v5,
                           const quint8* v6, const quint8* v7) {
  static const uint8x8_t ZERO_8x8 = vmov_n_u8(0);
  uint8x8_t ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v0), ZERO_8x8, 0);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v1), ret, 1);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v2), ret, 2);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v3), ret, 3);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v4), ret, 4);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v5), ret, 5);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v6), ret, 6);
  ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v7), ret, 7);
  return ret;
}

inline int16x8_t ToInt16x8(const int16* v0, const int16* v1, const int16* v2,
                           const int16* v3, const int16* v4, const int16* v5,
                           const int16* v6, const int16* v7) {
  static const int16x8_t ZERO_16x8 = vmovq_n_s16(0);
  int16x8_t ret = vld1q_lane_s16(v0, ZERO_16x8, 0);
  ret = vld1q_lane_s16(v1, ret, 1);
  ret = vld1q_lane_s16(v2, ret, 2);
  ret = vld1q_lane_s16(v3, ret, 3);
  ret = vld1q_lane_s16(v4, ret, 4);
  ret = vld1q_lane_s16(v5, ret, 5);
  ret = vld1q_lane_s16(v6, ret, 6);
  ret = vld1q_lane_s16(v7, ret, 7);
  return ret;
}

inline int32x2_t ToInt32x2(const qint32* v0, const qint32* v1) {
  static const int32x2_t ZERO_32x2 = vmov_n_s32(0);
  const int32x2_t ret0 =
      vld1_lane_s32(reinterpret_cast<const int32*>(v0), ZERO_32x2, 0);
  const int32x2_t ret1 =
      vld1_lane_s32(reinterpret_cast<const int32*>(v1), ret0, 1);
  return ret1;
}

template <int RESOLUTION, bool X_LERP_SAME>
inline int32x2_t ComputeLerpx2(
    const qint32* top_left0, const qint32* top_right0,
    const qint32* bottom_left0, const qint32* bottom_right0,
    const qint32* top_left1, const qint32* top_right1,
    const qint32* bottom_left1, const qint32* bottom_right1,
    const int32* x_lerp, const int32x2_t y_lerpsx) {
  const int32x2_t x_lerpsx =
      X_LERP_SAME ? vld1_dup_s32(reinterpret_cast<const int32*>(x_lerp))
                  : vld1_s32(reinterpret_cast<const int32*>(x_lerp));

  const int32x2_t top_leftsx = ToInt32x2(top_left0, top_left1);
  const int32x2_t top_rightsx = ToInt32x2(top_right0, top_right1);
  const int32x2_t bottom_leftsx = ToInt32x2(bottom_left0, bottom_left1);
  const int32x2_t bottom_rightsx = ToInt32x2(bottom_right0, bottom_right1);

  const int32x2_t retval =
      ComputeLerp32x2<RESOLUTION>(top_leftsx, top_rightsx, bottom_leftsx,
                                  bottom_rightsx, x_lerpsx, y_lerpsx);
  return retval;
}

template <int RESOLUTION>
inline uint8x8_t ComputeLerpx8(
    const quint8* tl0, const quint8* tr0, const quint8* bl0, const quint8* br0,
    const int16* xlp0, const quint8* tl1, const quint8* tr1, const quint8* bl1,
    const quint8* br1, const int16* xlp1, const quint8* tl2, const quint8* tr2,
    const quint8* bl2, const quint8* br2, const int16* xlp2, const quint8* tl3,
    const quint8* tr3, const quint8* bl3, const quint8* br3, const int16* xlp3,
    const quint8* tl4, const quint8* tr4, const quint8* bl4, const quint8* br4,
    const int16* xlp4, const quint8* tl5, const quint8* tr5, const quint8* bl5,
    const quint8* br5, const int16* xlp5, const quint8* tl6, const quint8* tr6,
    const quint8* bl6, const quint8* br6, const int16* xlp6, const quint8* tl7,
    const quint8* tr7, const quint8* bl7, const quint8* br7, const int16* xlp7,
    const int16x8_t ys_lerpsx) {
  const uint8x8_t tl8x8 = ToUint8x8(tl0, tl1, tl2, tl3, tl4, tl5, tl6, tl7);
  const uint8x8_t tr8x8 = ToUint8x8(tr0, tr1, tr2, tr3, tr4, tr5, tr6, tr7);
  const uint8x8_t bl8x8 = ToUint8x8(bl0, bl1, bl2, bl3, bl4, bl5, bl6, bl7);
  const uint8x8_t br8x8 = ToUint8x8(br0, br1, br2, br3, br4, br5, br6, br7);
  const int16x8_t xs_lerpsx =
      ToInt16x8(xlp0, xlp1, xlp2, xlp3, xlp4, xlp5, xlp6, xlp7);
  return ComputeLerp8x8<RESOLUTION>(tl8x8, tr8x8, bl8x8, br8x8, xs_lerpsx,
                                    ys_lerpsx);
}

// Expand addresses at compile time to improve performance.
template <int RESOLUTION, int ID0, int CH0, int ID1, int CH1, int ID2, int CH2,
          int ID3, int CH3, int ID4, int CH4, int ID5, int CH5, int ID6,
          int CH6, int ID7, int CH7>
inline uint8x8_t ComputeLerpx8Tmpl(const quint8* const yl, const quint8* yu,
                                   const int64* xl, const int64* xu,
                                   const int16* xlp,
                                   const int16x8_t ys_lerpsx) {
  return ComputeLerpx8<RESOLUTION>(
      yl + xl[ID0] + CH0, yl + xu[ID0] + CH0, yu + xl[ID0] + CH0,
      yu + xu[ID0] + CH0, xlp + ID0, yl + xl[ID1] + CH1, yl + xu[ID1] + CH1,
      yu + xl[ID1] + CH1, yu + xu[ID1] + CH1, xlp + ID1, yl + xl[ID2] + CH2,
      yl + xu[ID2] + CH2, yu + xl[ID2] + CH2, yu + xu[ID2] + CH2, xlp + ID2,
      yl + xl[ID3] + CH3, yl + xu[ID3] + CH3, yu + xl[ID3] + CH3,
      yu + xu[ID3] + CH3, xlp + ID3, yl + xl[ID4] + CH4, yl + xu[ID4] + CH4,
      yu + xl[ID4] + CH4, yu + xu[ID4] + CH4, xlp + ID4, yl + xl[ID5] + CH5,
      yl + xu[ID5] + CH5, yu + xl[ID5] + CH5, yu + xu[ID5] + CH5, xlp + ID5,
      yl + xl[ID6] + CH6, yl + xu[ID6] + CH6, yu + xl[ID6] + CH6,
      yu + xu[ID6] + CH6, xlp + ID6, yl + xl[ID7] + CH7, yl + xu[ID7] + CH7,
      yu + xl[ID7] + CH7, yu + xu[ID7] + CH7, xlp + ID7, ys_lerpsx);
}

#endif

template <int RESOLUTION, typename T, typename T_SCALE, typename T_CALC>
inline void OutputLerpForChannels(const InterpolationCache<T_SCALE>& xs,
                                  const int64_t x, const T_SCALE ys_ilerp,
                                  const int channels, const float min,
                                  const float max, const T* ys_input_lower_ptr,
                                  const T* ys_input_upper_ptr,
                                  T* output_y_ptr) {
  const int64_t xs_lower = xs.lower[x];
  const int64_t xs_upper = xs.upper[x];
  const T_SCALE xs_ilerp = xs.ilerp[x];
  for (int c = 0; c < channels; ++c) {
    const T top_left = ys_input_lower_ptr[xs_lower + c];
    const T top_right = ys_input_lower_ptr[xs_upper + c];
    const T bottom_left = ys_input_upper_ptr[xs_lower + c];
    const T bottom_right = ys_input_upper_ptr[xs_upper + c];
    const T val = ComputeLerp<RESOLUTION, T, T_SCALE, T_CALC>(
        top_left, top_right, bottom_left, bottom_right, xs_ilerp, ys_ilerp);
    output_y_ptr[x * channels + c] = val;
  }
}

template <int RES>
inline void OutputLerp8x8x1(const InterpolationCache<int16>& xs,
                            const int64_t x_start, const int16_t ys_ilerp,
                            const float min, const float max,
                            const quint8* const ys_input_lower_ptr,
                            const quint8* const ys_input_upper_ptr,
                            quint8* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int16x8_t y_lerpsx = vmovq_n_s16(ys_ilerp);

  const uint8x8_t x0x7 =
      ComputeLerpx8Tmpl<RES, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start), x0x7);

#else
  for (int x = x_start; x < x_start + 8; ++x) {
    OutputLerpForChannels<RES, quint8, int16, int16>(
        xs, x, ys_ilerp, 1, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

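// For three channels, the 24 outputs for eight consecutive x positions are
// emitted as three packed 8-lane vectors; the <index, channel> pairs in the
// ComputeLerpx8Tmpl instantiations below spell out the interleaving, e.g. the
// first vector holds {x0c0, x0c1, x0c2, x1c0, x1c1, x1c2, x2c0, x2c1}.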
template <int RES>
inline void OutputLerp8x8x3(const InterpolationCache<int16>& xs,
                            const int64_t x_start, const int16_t ys_ilerp,
                            const float min, const float max,
                            const quint8* const ys_input_lower_ptr,
                            const quint8* const ys_input_upper_ptr,
                            quint8* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int16x8_t y_lerpsx = vmovq_n_s16(ys_ilerp);

  const uint8x8_t x0c0x2c1 =
      ComputeLerpx8Tmpl<RES, 0, 0, 0, 1, 0, 2, 1, 0, 1, 1, 1, 2, 2, 0, 2, 1>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3), x0c0x2c1);

  const uint8x8_t x2c2x5c0 =
      ComputeLerpx8Tmpl<RES, 2, 2, 3, 0, 3, 1, 3, 2, 4, 0, 4, 1, 4, 2, 5, 0>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3 + 8), x2c2x5c0);

  const uint8x8_t x5c1x7c2 =
      ComputeLerpx8Tmpl<RES, 5, 1, 5, 2, 6, 0, 6, 1, 6, 2, 7, 0, 7, 1, 7, 2>(
          ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start],
          &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx);

  vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3 + 16),
          x5c1x7c2);

#else
  for (int x = x_start; x < x_start + 8; ++x) {
    OutputLerpForChannels<RES, quint8, int16, int16>(
        xs, x, ys_ilerp, 3, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

template <int RESOLUTION>
inline void OutputLerp32x4x1(const InterpolationCache<int32>& xs,
                             const int64_t x_start, const int32_t ys_ilerp,
                             const float min, const float max,
                             const qint32* const ys_input_lower_ptr,
                             const qint32* const ys_input_upper_ptr,
                             qint32* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int64 xs_lower0 = xs.lower[x_start];
  const int64 xs_upper0 = xs.upper[x_start];
  const int32* const xs_ilerp0 = &xs.ilerp[x_start];
  const int64 xs_lower1 = xs.lower[x_start + 1];
  const int64 xs_upper1 = xs.upper[x_start + 1];
  const int64 xs_lower2 = xs.lower[x_start + 2];
  const int64 xs_upper2 = xs.upper[x_start + 2];
  const int32* const xs_ilerp2 = &xs.ilerp[x_start + 2];
  const int64 xs_lower3 = xs.lower[x_start + 3];
  const int64 xs_upper3 = xs.upper[x_start + 3];

  const int32x2_t y_lerpsx = vmov_n_s32(ys_ilerp);

  const int32x2_t x0x1 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower0, ys_input_lower_ptr + xs_upper0,
      ys_input_upper_ptr + xs_lower0, ys_input_upper_ptr + xs_upper0,
      ys_input_lower_ptr + xs_lower1, ys_input_lower_ptr + xs_upper1,
      ys_input_upper_ptr + xs_lower1, ys_input_upper_ptr + xs_upper1, xs_ilerp0,
      y_lerpsx);

  const int32x2_t x1x2 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower2, ys_input_lower_ptr + xs_upper2,
      ys_input_upper_ptr + xs_lower2, ys_input_upper_ptr + xs_upper2,
      ys_input_lower_ptr + xs_lower3, ys_input_lower_ptr + xs_upper3,
      ys_input_upper_ptr + xs_lower3, ys_input_upper_ptr + xs_upper3, xs_ilerp2,
      y_lerpsx);

  const int32x4_t x0x1x2x3 = vcombine_s32(x0x1, x1x2);

  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start), x0x1x2x3);

#else
  for (int x = x_start; x < x_start + 4; ++x) {
    OutputLerpForChannels<RESOLUTION, qint32, int32, int64_t>(
        xs, x, ys_ilerp, 1, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

template <int RESOLUTION>
inline void OutputLerp32x4x3(const InterpolationCache<int32>& xs,
                             const int64_t x_start, const int32_t ys_ilerp,
                             const float min, const float max,
                             const qint32* const ys_input_lower_ptr,
                             const qint32* const ys_input_upper_ptr,
                             qint32* output_y_ptr) {
#ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON
  const int64 xs_lower0 = xs.lower[x_start];
  const int64 xs_upper0 = xs.upper[x_start];
  const int32* const xs_ilerp0 = &xs.ilerp[x_start];
  const int64 xs_lower1 = xs.lower[x_start + 1];
  const int64 xs_upper1 = xs.upper[x_start + 1];
  const int32* const xs_ilerp1 = &xs.ilerp[x_start + 1];
  const int64 xs_lower2 = xs.lower[x_start + 2];
  const int64 xs_upper2 = xs.upper[x_start + 2];
  const int32* const xs_ilerp2 = &xs.ilerp[x_start + 2];
  const int64 xs_lower3 = xs.lower[x_start + 3];
  const int64 xs_upper3 = xs.upper[x_start + 3];
  const int32* const xs_ilerp3 = &xs.ilerp[x_start + 3];

  const int32x2_t y_lerpsx = vmov_n_s32(ys_ilerp);

  const int32x2_t x0c0x0c1 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower0, ys_input_lower_ptr + xs_upper0,
      ys_input_upper_ptr + xs_lower0, ys_input_upper_ptr + xs_upper0,
      ys_input_lower_ptr + xs_lower0 + 1, ys_input_lower_ptr + xs_upper0 + 1,
      ys_input_upper_ptr + xs_lower0 + 1, ys_input_upper_ptr + xs_upper0 + 1,
      xs_ilerp0, y_lerpsx);

  const int32x2_t x0c2x1c0 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower0 + 2, ys_input_lower_ptr + xs_upper0 + 2,
      ys_input_upper_ptr + xs_lower0 + 2, ys_input_upper_ptr + xs_upper0 + 2,
      ys_input_lower_ptr + xs_lower1, ys_input_lower_ptr + xs_upper1,
      ys_input_upper_ptr + xs_lower1, ys_input_upper_ptr + xs_upper1, xs_ilerp0,
      y_lerpsx);

  const int32x2_t x1c1x1c2 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower1 + 1, ys_input_lower_ptr + xs_upper1 + 1,
      ys_input_upper_ptr + xs_lower1 + 1, ys_input_upper_ptr + xs_upper1 + 1,
      ys_input_lower_ptr + xs_lower1 + 2, ys_input_lower_ptr + xs_upper1 + 2,
      ys_input_upper_ptr + xs_lower1 + 2, ys_input_upper_ptr + xs_upper1 + 2,
      xs_ilerp1, y_lerpsx);

  const int32x2_t x2c0x2c1 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower2, ys_input_lower_ptr + xs_upper2,
      ys_input_upper_ptr + xs_lower2, ys_input_upper_ptr + xs_upper2,
      ys_input_lower_ptr + xs_lower2 + 1, ys_input_lower_ptr + xs_upper2 + 1,
      ys_input_upper_ptr + xs_lower2 + 1, ys_input_upper_ptr + xs_upper2 + 1,
      xs_ilerp2, y_lerpsx);

  const int32x2_t x2c2x3c0 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower2 + 2, ys_input_lower_ptr + xs_upper2 + 2,
      ys_input_upper_ptr + xs_lower2 + 2, ys_input_upper_ptr + xs_upper2 + 2,
      ys_input_lower_ptr + xs_lower3, ys_input_lower_ptr + xs_upper3,
      ys_input_upper_ptr + xs_lower3, ys_input_upper_ptr + xs_upper3, xs_ilerp2,
      y_lerpsx);

  const int32x2_t x3c1x3c2 = ComputeLerpx2<RESOLUTION, true>(
      ys_input_lower_ptr + xs_lower3 + 1, ys_input_lower_ptr + xs_upper3 + 1,
      ys_input_upper_ptr + xs_lower3 + 1, ys_input_upper_ptr + xs_upper3 + 1,
      ys_input_lower_ptr + xs_lower3 + 2, ys_input_lower_ptr + xs_upper3 + 2,
      ys_input_upper_ptr + xs_lower3 + 2, ys_input_upper_ptr + xs_upper3 + 2,
      xs_ilerp3, y_lerpsx);

  const int32x4_t x0c0x0c1x0c2x1c0 = vcombine_s32(x0c0x0c1, x0c2x1c0);
  const int32x4_t x1c1x1c2x2c0x2c1 = vcombine_s32(x1c1x1c2, x2c0x2c1);
  const int32x4_t x2c2x3c0x3c1x3c2 = vcombine_s32(x2c2x3c0, x3c1x3c2);

  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3),
            x0c0x0c1x0c2x1c0);
  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3 + 4),
            x1c1x1c2x2c0x2c1);
  vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3 + 8),
            x2c2x3c0x3c1x3c2);

#else
  for (int x = x_start; x < x_start + 4; ++x) {
    OutputLerpForChannels<RESOLUTION, qint32, int32, int64_t>(
        xs, x, ys_ilerp, 3, min, max, ys_input_lower_ptr, ys_input_upper_ptr,
        output_y_ptr);
  }
#endif
}

template <typename T>
void ResizeImageReference(typename TTypes<T, 4>::ConstTensor images,
                          const int batch_size, const int64_t in_height,
                          const int64_t in_width, const int64_t out_height,
                          const int64_t out_width, const int channels,
                          const float height_scale, const float width_scale,
                          const float in_min, const float in_max,
                          const bool half_pixel_centers,
                          typename TTypes<T, 4>::Tensor* output) {
  CHECK_NOTNULL(output);

  const InterpolationCache<float> xs = BuildLerpCache<float>(
      out_width, in_width, width_scale, channels, 0, half_pixel_centers);
  const InterpolationCache<float> ys = BuildLerpCache<float>(
      out_height, in_height, height_scale, 1, 0, half_pixel_centers);

  const int64_t in_row_size = in_width * channels;
  const int64_t in_batch_num_values = in_height * in_row_size;
  const int64_t out_row_size = out_width * channels;

  const T* input_b_ptr = images.data();

  T* output_y_ptr = output->data();
  for (int b = 0; b < batch_size; ++b) {
    for (int64_t y = 0; y < out_height; ++y) {
      const T* ys_input_lower_ptr = input_b_ptr + ys.lower[y] * in_row_size;
      const T* ys_input_upper_ptr = input_b_ptr + ys.upper[y] * in_row_size;
      const float ys_lerp = ys.lerp[y];
      for (int64_t x = 0; x < out_width; ++x) {
        const int64_t xs_lower = xs.lower[x];
        const int64_t xs_upper = xs.upper[x];
        const float xs_lerp = xs.lerp[x];
        for (int c = 0; c < channels; ++c) {
          const T top_left = ys_input_lower_ptr[xs_lower + c];
          const T top_right = ys_input_lower_ptr[xs_upper + c];
          const T bottom_left = ys_input_upper_ptr[xs_lower + c];
          const T bottom_right = ys_input_upper_ptr[xs_upper + c];
          const T val = ComputeLerpReference<T>(
              top_left, top_right, bottom_left, bottom_right, xs_lerp, ys_lerp,
              in_min, in_max);
          output_y_ptr[x * channels + c] = val;
        }
      }
      output_y_ptr += out_row_size;
    }
    input_b_ptr += in_batch_num_values;
  }
}

template <typename T>
void ResizeImage(typename TTypes<T, 4>::ConstTensor images,
                 const int batch_size, const int64_t in_height,
                 const int64_t in_width, const int64_t out_height,
                 const int64_t out_width, const int channels,
                 const float height_scale, const float width_scale,
                 const float in_min, const float in_max,
                 const bool half_pixel_centers,
                 typename TTypes<T, 4>::Tensor* output) {
  ResizeImageReference<T>(images, batch_size, in_height, in_width, out_height,
                          out_width, channels, height_scale, width_scale,
                          in_min, in_max, half_pixel_centers, output);
}

template <>
void ResizeImage<qint32>(typename TTypes<qint32, 4>::ConstTensor images,
                         const int batch_size, const int64_t in_height,
                         const int64_t in_width, const int64_t out_height,
                         const int64_t out_width, const int channels,
                         const float height_scale, const float width_scale,
                         const float in_min, const float in_max,
                         const bool half_pixel_centers,
                         typename TTypes<qint32, 4>::Tensor* output) {
  // 30 is maximum resolution for signed int.
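  // (The ilerp weights are stored as int32, and 1 << 30 is the largest power
  // of two that still fits in a positive signed 32-bit value; the intermediate
  // lerp products are carried in int64_t in the scalar fallback, see the
  // OutputLerpForChannels instantiations below.)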
  constexpr int RESOLUTION = 30;
  constexpr int SIMD_STEP = 4;

  CHECK_NOTNULL(output);

  const InterpolationCache<int32> xs =
      BuildLerpCache<int32>(out_width, in_width, width_scale, channels,
                            RESOLUTION, half_pixel_centers);
  const InterpolationCache<int32> ys = BuildLerpCache<int32>(
      out_height, in_height, height_scale, 1, RESOLUTION, half_pixel_centers);

  const int64_t in_row_size = in_width * channels;
  const int64_t in_batch_num_values = in_height * in_row_size;
  const int64_t out_row_size = out_width * channels;

  const qint32* input_b_ptr = images.data();

  qint32* output_y_ptr = output->data();

  for (int b = 0; b < batch_size; ++b) {
    for (int64_t y = 0; y < out_height; ++y) {
      const qint32* ys_input_lower_ptr =
          input_b_ptr + ys.lower[y] * in_row_size;
      const qint32* ys_input_upper_ptr =
          input_b_ptr + ys.upper[y] * in_row_size;
      const int32_t ys_ilerp = ys.ilerp[y];
      // Optimized paths for channels == 1 and channels == 3, the most common
      // channel counts.
      int64_t x = 0;
      if (channels == 1) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp32x4x1<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                       ys_input_lower_ptr, ys_input_upper_ptr,
                                       output_y_ptr);
        }
      } else if (channels == 3) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp32x4x3<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                       ys_input_lower_ptr, ys_input_upper_ptr,
                                       output_y_ptr);
        }
      }
      for (; x < out_width; ++x) {
        OutputLerpForChannels<RESOLUTION, qint32, int32, int64_t>(
            xs, x, ys_ilerp, channels, in_min, in_max, ys_input_lower_ptr,
            ys_input_upper_ptr, output_y_ptr);
      }
      output_y_ptr += out_row_size;
    }
    input_b_ptr += in_batch_num_values;
  }
}

template <>
void ResizeImage<quint8>(typename TTypes<quint8, 4>::ConstTensor images,
                         const int batch_size, const int64_t in_height,
                         const int64_t in_width, const int64_t out_height,
                         const int64_t out_width, const int channels,
                         const float height_scale, const float width_scale,
                         const float in_min, const float in_max,
                         const bool half_pixel_centers,
                         typename TTypes<quint8, 4>::Tensor* output) {
  // 7 is maximum resolution for unsigned byte.
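  // (With 8-bit inputs the scalar fallback accumulates in int16:
  // 255 * (1 << 7) = 32640 still fits, whereas a shift of 8 would overflow.
  // The NEON helpers are assumed to rely on the same 16-bit headroom.)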
  constexpr int RESOLUTION = 7;
  constexpr int SIMD_STEP = 8;

  CHECK_NOTNULL(output);

  const InterpolationCache<int16> xs =
      BuildLerpCache<int16>(out_width, in_width, width_scale, channels,
                            RESOLUTION, half_pixel_centers);
  const InterpolationCache<int16> ys = BuildLerpCache<int16>(
      out_height, in_height, height_scale, 1, RESOLUTION, half_pixel_centers);

  const int64_t in_row_size = in_width * channels;
  const int64_t in_batch_num_values = in_height * in_row_size;
  const int64_t out_row_size = out_width * channels;

  const quint8* input_b_ptr = images.data();

  quint8* output_y_ptr = output->data();

  for (int b = 0; b < batch_size; ++b) {
    for (int64_t y = 0; y < out_height; ++y) {
      const quint8* ys_input_lower_ptr =
          input_b_ptr + ys.lower[y] * in_row_size;
      const quint8* ys_input_upper_ptr =
          input_b_ptr + ys.upper[y] * in_row_size;
      const int32_t ys_ilerp = ys.ilerp[y];
      // Optimized paths for channels == 1 and channels == 3, the most common
      // channel counts.
      // TODO(satok): Support a more generic NEON-optimized implementation
      // for other channel counts.
      int64_t x = 0;
      if (channels == 1) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp8x8x1<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                      ys_input_lower_ptr, ys_input_upper_ptr,
                                      output_y_ptr);
        }
      } else if (channels == 3) {
        for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) {
          OutputLerp8x8x3<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max,
                                      ys_input_lower_ptr, ys_input_upper_ptr,
                                      output_y_ptr);
        }
      }
      for (; x < out_width; ++x) {
        OutputLerpForChannels<RESOLUTION, quint8, int16, int16>(
            xs, x, ys_ilerp, channels, in_min, in_max, ys_input_lower_ptr,
            ys_input_upper_ptr, output_y_ptr);
      }
      output_y_ptr += out_row_size;
    }
    input_b_ptr += in_batch_num_values;
  }
}

template <typename T>
void ResizeBilinear(const typename TTypes<T, 4>::ConstTensor& images,
                    const float height_scale, const float width_scale,
                    const float in_min, const float in_max,
                    const bool half_pixel_centers,
                    typename TTypes<T, 4>::Tensor* output) {
  CHECK_NOTNULL(output);

  const int batch_size = images.dimension(0);
  const int64_t in_height = images.dimension(1);
  const int64_t in_width = images.dimension(2);
  const int channels = images.dimension(3);

  const int64_t out_height = output->dimension(1);
  const int64_t out_width = output->dimension(2);

  // Handle no-op resizes efficiently.
  if (out_height == in_height && out_width == in_width) {
    *output = images.template cast<T>();
    return;
  }

  if (USE_REFERENCE) {
    ResizeImageReference<T>(images, batch_size, in_height, in_width, out_height,
                            out_width, channels, height_scale, width_scale,
                            in_min, in_max, half_pixel_centers, output);
  } else {
    ResizeImage<T>(images, batch_size, in_height, in_width, out_height,
                   out_width, channels, height_scale, width_scale, in_min,
                   in_max, half_pixel_centers, output);
  }
}

}  // namespace

template <class T>
class QuantizedResizeBilinearOp : public OpKernel {
 public:
  explicit QuantizedResizeBilinearOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
    OP_REQUIRES_OK(
        context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
  }

  void Compute(OpKernelContext* context) override {
    const auto& in_min_tensor = context->input(2);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(in_min_tensor.shape()),
                errors::InvalidArgument("min must be a scalar"));
    const float in_min = in_min_tensor.flat<float>()(0);
    const auto& in_max_tensor = context->input(3);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(in_max_tensor.shape()),
                errors::InvalidArgument("max must be a scalar"));
    const float in_max = in_max_tensor.flat<float>()(0);

    ImageResizerState st(align_corners_, false);
    st.ValidateAndCreateOutput(context);

    if (!context->status().ok()) return;

    // Return if the output is empty.
    if (st.output->NumElements() == 0) return;

    typename TTypes<T, 4>::ConstTensor image_data(
        context->input(0).tensor<T, 4>());
    typename TTypes<T, 4>::Tensor output_data(st.output->tensor<T, 4>());

    ResizeBilinear<T>(image_data, st.height_scale, st.width_scale, in_min,
                      in_max, half_pixel_centers_, &output_data);
    Tensor* out_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &out_min));
    out_min->flat<float>()(0) = in_min;

    Tensor* out_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &out_max));
    out_max->flat<float>()(0) = in_max;
  }

 private:
  bool align_corners_;
  bool half_pixel_centers_;

  TF_DISALLOW_COPY_AND_ASSIGN(QuantizedResizeBilinearOp<T>);
};

#define REGISTER_CPU_KERNEL(type)                         \
  REGISTER_KERNEL_BUILDER(Name("QuantizedResizeBilinear") \
                              .Device(DEVICE_CPU)         \
                              .HostMemory("size")         \
                              .TypeConstraint<type>("T"), \
                          QuantizedResizeBilinearOp<type>)

REGISTER_CPU_KERNEL(::tensorflow::quint8);
REGISTER_CPU_KERNEL(::tensorflow::qint32);
REGISTER_CPU_KERNEL(float);

}  // namespace tensorflow