1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Implements a quantized version of the resize bilinear op. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #if defined(__ARM_NEON__) || defined(__ARM_NEON) |
21 | #define USE_NEON |
22 | #define QUANTIZED_RESIZE_BILINEAR_USE_NEON |
23 | #include <arm_neon.h> |
24 | #endif |
25 | |
26 | #include "tensorflow/core/framework/op_kernel.h" |
27 | #include "tensorflow/core/framework/types.h" |
28 | #include "tensorflow/core/kernels/quantization_utils.h" |
29 | #include "tensorflow/core/platform/macros.h" |
30 | #include "tensorflow/core/util/image_resizer_state.h" |
31 | |
32 | namespace tensorflow { |
33 | |
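// When true, ResizeBilinear always uses the slow float reference
// implementation below instead of the fixed-point fast paths.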
34 | static constexpr bool USE_REFERENCE = false; |
35 | |
36 | namespace { |
37 | // Compute the interpolation indices only once. |
38 | template <typename T_SCALE> |
39 | struct InterpolationCache { |
40 | std::vector<int64_t> lower; // Lower source index used in the interpolation |
41 | std::vector<int64_t> upper; // Upper source index used in the interpolation |
42 | // 1-D linear interpolation scale (see: |
43 | // https://en.wikipedia.org/wiki/Bilinear_interpolation) |
44 | std::vector<float> lerp; |
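  // Fixed-point version of the interpolation weight, scaled by 2^resolution.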
45 | std::vector<T_SCALE> ilerp; |
46 | }; |
47 | |
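// Fills the interpolation cache for one axis. For example, resizing a
// width-4 input to width 8 (scale = 0.5) with the legacy scaler maps output
// pixel 3 to input coordinate 1.5, giving lower = 1, upper = 2, lerp = 0.5,
// and ilerp = 64 at a resolution of 7 bits.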
48 | template <typename T_SCALE, typename Scaler> |
49 | inline void ComputeInterpolationWeights( |
50 | const int64_t out_size, const int64_t in_size, const float scale, |
51 | const int resolution, InterpolationCache<T_SCALE>* interpolation) { |
52 | const Scaler scaler; |
53 | interpolation->lower.resize(out_size + 1); |
54 | interpolation->upper.resize(out_size + 1); |
55 | interpolation->lerp.resize(out_size + 1); |
56 | interpolation->ilerp.resize(out_size + 1); |
57 | |
58 | interpolation->lower[out_size] = 0; |
59 | interpolation->upper[out_size] = 0; |
60 | for (int64_t i = out_size - 1; i >= 0; --i) { |
61 | const float in = scaler(i, scale); |
62 | const float in_f = std::floor(in); |
63 | interpolation->lower[i] = |
64 | std::max(static_cast<int64_t>(in_f), static_cast<int64_t>(0)); |
65 | interpolation->upper[i] = |
66 | std::min(static_cast<int64_t>(std::ceil(in)), in_size - 1); |
67 | interpolation->lower[i] = |
68 | std::min(interpolation->lower[i], interpolation->upper[i]); |
69 | interpolation->lerp[i] = in - in_f; |
70 | interpolation->ilerp[i] = |
71 | static_cast<T_SCALE>((in - in_f) * (1 << resolution)); |
72 | } |
73 | } |
74 | |
75 | template <typename T_SCALE> |
76 | inline InterpolationCache<T_SCALE> BuildLerpCache( |
77 | const int64_t out_size, const int64_t in_size, const float scale, |
78 | const int index_step, const int resolution, const bool half_pixel_centers) { |
79 | InterpolationCache<T_SCALE> cache; |
80 | // Compute the cached interpolation weights on the x and y dimensions. |
81 | if (half_pixel_centers) { |
82 | ComputeInterpolationWeights<T_SCALE, HalfPixelScaler>( |
83 | out_size, in_size, scale, resolution, &cache); |
84 | } else { |
85 | ComputeInterpolationWeights<T_SCALE, LegacyScaler>(out_size, in_size, scale, |
86 | resolution, &cache); |
87 | } |
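  // Pre-multiply the cached indices by the element stride (the channel count
  // for the x axis) so they can be used directly as offsets into an
  // interleaved row.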
  CHECK_GT(index_step, 0);
  if (index_step > 1) {
    for (size_t i = 0; i < cache.lower.size(); ++i) {
91 | cache.lower[i] *= index_step; |
92 | cache.upper[i] *= index_step; |
93 | } |
94 | } |
95 | return cache; |
96 | } |
97 | |
98 | /** |
99 | * Computes the bilinear interpolation from the appropriate 4 float points |
100 | * and the linear interpolation weights. |
101 | */ |
102 | template <typename T> |
103 | inline T ComputeLerpReference(const T in_top_left, const T in_top_right, |
104 | const T in_bottom_left, const T in_bottom_right, |
105 | const float x_lerp, const float y_lerp, |
106 | const float min, const float max) { |
107 | const float top_left = QuantizedToFloat<T>(in_top_left, min, max); |
108 | const float top_right = QuantizedToFloat<T>(in_top_right, min, max); |
109 | const float bottom_left = QuantizedToFloat<T>(in_bottom_left, min, max); |
110 | const float bottom_right = QuantizedToFloat<T>(in_bottom_right, min, max); |
111 | const float top = top_left + (top_right - top_left) * x_lerp; |
112 | const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; |
113 | const float out = top + (bottom - top) * y_lerp; |
114 | return FloatToQuantized<T>(out, min, max); |
115 | } |
116 | |
117 | template <typename T, typename T_SCALE, typename T_CALC> |
118 | inline T_CALC MulOffset(T a, T b, T_SCALE c) { |
119 | return (static_cast<T_CALC>(a) - static_cast<T_CALC>(b)) * |
120 | static_cast<T_CALC>(c); |
121 | } |
122 | |
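// Fixed-point bilinear interpolation: the inputs are quantized values plus
// x/y weights scaled by 2^RESOLUTION, and all arithmetic happens in T_CALC.
// Dividing (bottom - top) by RESOLUTION_MULT before multiplying by y_lerp
// keeps the intermediate within T_CALC (int16 for the quint8 path). Worked
// example with RESOLUTION = 7 (RESOLUTION_MULT = 128), corners 10/20/30/40,
// and x_lerp = y_lerp = 64 (0.5 in Q7): top = 1920, bottom = 4480,
// out = 3200, and (3200 + 64) / 128 = 25, exactly the float result.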
123 | template <int RESOLUTION, typename T, typename T_SCALE, typename T_CALC> |
124 | inline T ComputeLerp(const T top_left, const T top_right, const T bottom_left, |
125 | const T bottom_right, const T_SCALE x_lerp, |
126 | const T_SCALE y_lerp) { |
127 | constexpr T_CALC RESOLUTION_MULT = (1 << RESOLUTION); |
128 | const T_CALC top = static_cast<T_CALC>(top_left) * RESOLUTION_MULT + |
129 | MulOffset<T, T_SCALE, T_CALC>(top_right, top_left, x_lerp); |
130 | const T_CALC bottom = |
131 | static_cast<T_CALC>(bottom_left) * RESOLUTION_MULT + |
132 | MulOffset<T, T_SCALE, T_CALC>(bottom_right, bottom_left, x_lerp); |
133 | const T_CALC out = top + (bottom - top) / RESOLUTION_MULT * y_lerp; |
134 | return static_cast<T>( |
135 | static_cast<int32>((out + RESOLUTION_MULT / 2) / RESOLUTION_MULT)); |
136 | } |
137 | |
138 | #ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON |
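// The helpers below gather scattered scalar values into a single NEON
// register, one lane at a time, so that several interpolations can be
// computed in parallel.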
139 | inline uint8x8_t ToUint8x8(const quint8* v0, const quint8* v1, const quint8* v2, |
140 | const quint8* v3, const quint8* v4, const quint8* v5, |
141 | const quint8* v6, const quint8* v7) { |
142 | static const uint8x8_t ZERO_8x8 = vmov_n_u8(0); |
143 | uint8x8_t ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v0), ZERO_8x8, 0); |
144 | ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v1), ret, 1); |
145 | ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v2), ret, 2); |
146 | ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v3), ret, 3); |
147 | ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v4), ret, 4); |
148 | ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v5), ret, 5); |
149 | ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v6), ret, 6); |
150 | ret = vld1_lane_u8(reinterpret_cast<const uint8*>(v7), ret, 7); |
151 | return ret; |
152 | } |
153 | |
154 | inline int16x8_t ToInt16x8(const int16* v0, const int16* v1, const int16* v2, |
155 | const int16* v3, const int16* v4, const int16* v5, |
156 | const int16* v6, const int16* v7) { |
157 | static const int16x8_t ZERO_16x8 = vmovq_n_s16(0); |
158 | int16x8_t ret = vld1q_lane_s16(v0, ZERO_16x8, 0); |
159 | ret = vld1q_lane_s16(v1, ret, 1); |
160 | ret = vld1q_lane_s16(v2, ret, 2); |
161 | ret = vld1q_lane_s16(v3, ret, 3); |
162 | ret = vld1q_lane_s16(v4, ret, 4); |
163 | ret = vld1q_lane_s16(v5, ret, 5); |
164 | ret = vld1q_lane_s16(v6, ret, 6); |
165 | ret = vld1q_lane_s16(v7, ret, 7); |
166 | return ret; |
167 | } |
168 | |
169 | inline int32x2_t ToInt32x2(const qint32* v0, const qint32* v1) { |
170 | static const int32x2_t ZERO_32x2 = vmov_n_s32(0); |
171 | const int32x2_t ret0 = |
172 | vld1_lane_s32(reinterpret_cast<const int32*>(v0), ZERO_32x2, 0); |
173 | const int32x2_t ret1 = |
174 | vld1_lane_s32(reinterpret_cast<const int32*>(v1), ret0, 1); |
175 | return ret1; |
176 | } |
177 | |
178 | template <int RESOLUTION, bool X_LERP_SAME> |
179 | inline int32x2_t ComputeLerpx2( |
180 | const qint32* top_left0, const qint32* top_right0, |
181 | const qint32* bottom_left0, const qint32* bottom_right0, |
182 | const qint32* top_left1, const qint32* top_right1, |
183 | const qint32* bottom_left1, const qint32* bottom_right1, |
184 | const int32* x_lerp, const int32x2_t y_lerpsx) { |
185 | const int32x2_t x_lerpsx = |
186 | X_LERP_SAME ? vld1_dup_s32(reinterpret_cast<const int32*>(x_lerp)) |
187 | : vld1_s32(reinterpret_cast<const int32*>(x_lerp)); |
188 | |
189 | const int32x2_t top_leftsx = ToInt32x2(top_left0, top_left1); |
190 | const int32x2_t top_rightsx = ToInt32x2(top_right0, top_right1); |
191 | const int32x2_t bottom_leftsx = ToInt32x2(bottom_left0, bottom_left1); |
192 | const int32x2_t bottom_rightsx = ToInt32x2(bottom_right0, bottom_right1); |
193 | |
194 | const int32x2_t retval = |
195 | ComputeLerp32x2<RESOLUTION>(top_leftsx, top_rightsx, bottom_leftsx, |
196 | bottom_rightsx, x_lerpsx, y_lerpsx); |
197 | return retval; |
198 | } |
199 | |
200 | template <int RESOLUTION> |
201 | inline uint8x8_t ComputeLerpx8( |
202 | const quint8* tl0, const quint8* tr0, const quint8* bl0, const quint8* br0, |
203 | const int16* xlp0, const quint8* tl1, const quint8* tr1, const quint8* bl1, |
204 | const quint8* br1, const int16* xlp1, const quint8* tl2, const quint8* tr2, |
205 | const quint8* bl2, const quint8* br2, const int16* xlp2, const quint8* tl3, |
206 | const quint8* tr3, const quint8* bl3, const quint8* br3, const int16* xlp3, |
207 | const quint8* tl4, const quint8* tr4, const quint8* bl4, const quint8* br4, |
208 | const int16* xlp4, const quint8* tl5, const quint8* tr5, const quint8* bl5, |
209 | const quint8* br5, const int16* xlp5, const quint8* tl6, const quint8* tr6, |
210 | const quint8* bl6, const quint8* br6, const int16* xlp6, const quint8* tl7, |
211 | const quint8* tr7, const quint8* bl7, const quint8* br7, const int16* xlp7, |
212 | const int16x8_t ys_lerpsx) { |
213 | const uint8x8_t tl8x8 = ToUint8x8(tl0, tl1, tl2, tl3, tl4, tl5, tl6, tl7); |
214 | const uint8x8_t tr8x8 = ToUint8x8(tr0, tr1, tr2, tr3, tr4, tr5, tr6, tr7); |
215 | const uint8x8_t bl8x8 = ToUint8x8(bl0, bl1, bl2, bl3, bl4, bl5, bl6, bl7); |
216 | const uint8x8_t br8x8 = ToUint8x8(br0, br1, br2, br3, br4, br5, br6, br7); |
217 | const int16x8_t xs_lerpsx = |
218 | ToInt16x8(xlp0, xlp1, xlp2, xlp3, xlp4, xlp5, xlp6, xlp7); |
219 | return ComputeLerp8x8<RESOLUTION>(tl8x8, tr8x8, bl8x8, br8x8, xs_lerpsx, |
220 | ys_lerpsx); |
221 | } |
222 | |
// Expand the per-lane addresses at compile time to improve performance.
224 | template <int RESOLUTION, int ID0, int CH0, int ID1, int CH1, int ID2, int CH2, |
225 | int ID3, int CH3, int ID4, int CH4, int ID5, int CH5, int ID6, |
226 | int CH6, int ID7, int CH7> |
inline uint8x8_t ComputeLerpx8Tmpl(const quint8* const yl,
                                   const quint8* const yu, const int64_t* xl,
                                   const int64_t* xu, const int16* xlp,
230 | const int16x8_t ys_lerpsx) { |
231 | return ComputeLerpx8<RESOLUTION>( |
232 | yl + xl[ID0] + CH0, yl + xu[ID0] + CH0, yu + xl[ID0] + CH0, |
233 | yu + xu[ID0] + CH0, xlp + ID0, yl + xl[ID1] + CH1, yl + xu[ID1] + CH1, |
234 | yu + xl[ID1] + CH1, yu + xu[ID1] + CH1, xlp + ID1, yl + xl[ID2] + CH2, |
235 | yl + xu[ID2] + CH2, yu + xl[ID2] + CH2, yu + xu[ID2] + CH2, xlp + ID2, |
236 | yl + xl[ID3] + CH3, yl + xu[ID3] + CH3, yu + xl[ID3] + CH3, |
237 | yu + xu[ID3] + CH3, xlp + ID3, yl + xl[ID4] + CH4, yl + xu[ID4] + CH4, |
238 | yu + xl[ID4] + CH4, yu + xu[ID4] + CH4, xlp + ID4, yl + xl[ID5] + CH5, |
239 | yl + xu[ID5] + CH5, yu + xl[ID5] + CH5, yu + xu[ID5] + CH5, xlp + ID5, |
240 | yl + xl[ID6] + CH6, yl + xu[ID6] + CH6, yu + xl[ID6] + CH6, |
241 | yu + xu[ID6] + CH6, xlp + ID6, yl + xl[ID7] + CH7, yl + xu[ID7] + CH7, |
242 | yu + xl[ID7] + CH7, yu + xu[ID7] + CH7, xlp + ID7, ys_lerpsx); |
243 | } |
244 | |
245 | #endif |
246 | |
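// Scalar path: interpolates all `channels` values of output pixel `x`. The
// cached xs.lower/xs.upper indices are pre-multiplied by the channel count
// (see BuildLerpCache), so xs_lower + c addresses channel c directly.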
247 | template <int RESOLUTION, typename T, typename T_SCALE, typename T_CALC> |
248 | inline void OutputLerpForChannels(const InterpolationCache<T_SCALE>& xs, |
249 | const int64_t x, const T_SCALE ys_ilerp, |
250 | const int channels, const float min, |
251 | const float max, const T* ys_input_lower_ptr, |
252 | const T* ys_input_upper_ptr, |
253 | T* output_y_ptr) { |
254 | const int64_t xs_lower = xs.lower[x]; |
255 | const int64_t xs_upper = xs.upper[x]; |
256 | const T_SCALE xs_ilerp = xs.ilerp[x]; |
257 | for (int c = 0; c < channels; ++c) { |
258 | const T top_left = ys_input_lower_ptr[xs_lower + c]; |
259 | const T top_right = ys_input_lower_ptr[xs_upper + c]; |
260 | const T bottom_left = ys_input_upper_ptr[xs_lower + c]; |
261 | const T bottom_right = ys_input_upper_ptr[xs_upper + c]; |
262 | const T val = ComputeLerp<RESOLUTION, T, T_SCALE, T_CALC>( |
263 | top_left, top_right, bottom_left, bottom_right, xs_ilerp, ys_ilerp); |
264 | output_y_ptr[x * channels + c] = val; |
265 | } |
266 | } |
267 | |
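// Emits 8 consecutive single-channel quint8 output pixels as one 8-lane NEON
// lerp and a single 8-byte store, falling back to the scalar path when NEON
// is unavailable.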
268 | template <int RES> |
269 | inline void OutputLerp8x8x1(const InterpolationCache<int16>& xs, |
270 | const int64_t x_start, const int16_t ys_ilerp, |
271 | const float min, const float max, |
272 | const quint8* const ys_input_lower_ptr, |
273 | const quint8* const ys_input_upper_ptr, |
274 | quint8* output_y_ptr) { |
275 | #ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON |
276 | const int16x8_t y_lerpsx = vmovq_n_s16(ys_ilerp); |
277 | |
278 | const uint8x8_t x0x7 = |
279 | ComputeLerpx8Tmpl<RES, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0>( |
280 | ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start], |
281 | &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx); |
282 | |
283 | vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start), x0x7); |
284 | |
285 | #else |
  for (int64_t x = x_start; x < x_start + 8; ++x) {
287 | OutputLerpForChannels<RES, quint8, int16, int16>( |
288 | xs, x, ys_ilerp, 1, min, max, ys_input_lower_ptr, ys_input_upper_ptr, |
289 | output_y_ptr); |
290 | } |
291 | #endif |
292 | } |
293 | |
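// Emits 8 consecutive 3-channel quint8 output pixels (24 bytes) as three
// 8-lane lerps; the ComputeLerpx8Tmpl template arguments encode the
// (pixel index, channel) pair handled by each lane.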
294 | template <int RES> |
295 | inline void OutputLerp8x8x3(const InterpolationCache<int16>& xs, |
296 | const int64_t x_start, const int16_t ys_ilerp, |
297 | const float min, const float max, |
298 | const quint8* const ys_input_lower_ptr, |
299 | const quint8* const ys_input_upper_ptr, |
300 | quint8* output_y_ptr) { |
301 | #ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON |
302 | const int16x8_t y_lerpsx = vmovq_n_s16(ys_ilerp); |
303 | |
304 | const uint8x8_t x0c0x2c1 = |
305 | ComputeLerpx8Tmpl<RES, 0, 0, 0, 1, 0, 2, 1, 0, 1, 1, 1, 2, 2, 0, 2, 1>( |
306 | ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start], |
307 | &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx); |
308 | |
309 | vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3), x0c0x2c1); |
310 | |
311 | const uint8x8_t x2c2x5c0 = |
312 | ComputeLerpx8Tmpl<RES, 2, 2, 3, 0, 3, 1, 3, 2, 4, 0, 4, 1, 4, 2, 5, 0>( |
313 | ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start], |
314 | &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx); |
315 | |
316 | vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3 + 8), x2c2x5c0); |
317 | |
318 | const uint8x8_t x5c1x7c2 = |
319 | ComputeLerpx8Tmpl<RES, 5, 1, 5, 2, 6, 0, 6, 1, 6, 2, 7, 0, 7, 1, 7, 2>( |
320 | ys_input_lower_ptr, ys_input_upper_ptr, &xs.lower[x_start], |
321 | &xs.upper[x_start], &xs.ilerp[x_start], y_lerpsx); |
322 | |
323 | vst1_u8(reinterpret_cast<uint8_t*>(output_y_ptr + x_start * 3 + 16), |
324 | x5c1x7c2); |
325 | |
326 | #else |
  for (int64_t x = x_start; x < x_start + 8; ++x) {
328 | OutputLerpForChannels<RES, quint8, int16, int16>( |
329 | xs, x, ys_ilerp, 3, min, max, ys_input_lower_ptr, ys_input_upper_ptr, |
330 | output_y_ptr); |
331 | } |
332 | #endif |
333 | } |
334 | |
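// Emits 4 consecutive single-channel qint32 output pixels as two 2-lane
// lerps combined into a single 16-byte store.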
335 | template <int RESOLUTION> |
336 | inline void OutputLerp32x4x1(const InterpolationCache<int32>& xs, |
337 | const int64_t x_start, const int32_t ys_ilerp, |
338 | const float min, const float max, |
339 | const qint32* const ys_input_lower_ptr, |
340 | const qint32* const ys_input_upper_ptr, |
341 | qint32* output_y_ptr) { |
342 | #ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON |
  const int64_t xs_lower0 = xs.lower[x_start];
  const int64_t xs_upper0 = xs.upper[x_start];
  const int32* const xs_ilerp0 = &xs.ilerp[x_start];
  const int64_t xs_lower1 = xs.lower[x_start + 1];
  const int64_t xs_upper1 = xs.upper[x_start + 1];
  const int64_t xs_lower2 = xs.lower[x_start + 2];
  const int64_t xs_upper2 = xs.upper[x_start + 2];
  const int32* const xs_ilerp2 = &xs.ilerp[x_start + 2];
  const int64_t xs_lower3 = xs.lower[x_start + 3];
  const int64_t xs_upper3 = xs.upper[x_start + 3];
353 | |
354 | const int32x2_t y_lerpsx = vmov_n_s32(ys_ilerp); |
355 | |
356 | const int32x2_t x0x1 = ComputeLerpx2<RESOLUTION, false>( |
357 | ys_input_lower_ptr + xs_lower0, ys_input_lower_ptr + xs_upper0, |
358 | ys_input_upper_ptr + xs_lower0, ys_input_upper_ptr + xs_upper0, |
359 | ys_input_lower_ptr + xs_lower1, ys_input_lower_ptr + xs_upper1, |
360 | ys_input_upper_ptr + xs_lower1, ys_input_upper_ptr + xs_upper1, xs_ilerp0, |
361 | y_lerpsx); |
362 | |
  const int32x2_t x2x3 = ComputeLerpx2<RESOLUTION, false>(
      ys_input_lower_ptr + xs_lower2, ys_input_lower_ptr + xs_upper2,
      ys_input_upper_ptr + xs_lower2, ys_input_upper_ptr + xs_upper2,
      ys_input_lower_ptr + xs_lower3, ys_input_lower_ptr + xs_upper3,
      ys_input_upper_ptr + xs_lower3, ys_input_upper_ptr + xs_upper3, xs_ilerp2,
      y_lerpsx);

  const int32x4_t x0x1x2x3 = vcombine_s32(x0x1, x2x3);
371 | |
372 | vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start), x0x1x2x3); |
373 | |
374 | #else |
  for (int64_t x = x_start; x < x_start + 4; ++x) {
376 | OutputLerpForChannels<RESOLUTION, qint32, int32, int64_t>( |
377 | xs, x, ys_ilerp, 1, min, max, ys_input_lower_ptr, ys_input_upper_ptr, |
378 | output_y_ptr); |
379 | } |
380 | #endif |
381 | } |
382 | |
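// Emits 4 consecutive 3-channel qint32 output pixels (12 values) as six
// 2-lane lerps and three 16-byte stores.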
383 | template <int RESOLUTION> |
384 | inline void OutputLerp32x4x3(const InterpolationCache<int32>& xs, |
385 | const int64_t x_start, const int32_t ys_ilerp, |
386 | const float min, const float max, |
387 | const qint32* const ys_input_lower_ptr, |
388 | const qint32* const ys_input_upper_ptr, |
389 | qint32* output_y_ptr) { |
390 | #ifdef QUANTIZED_RESIZE_BILINEAR_USE_NEON |
  const int64_t xs_lower0 = xs.lower[x_start];
  const int64_t xs_upper0 = xs.upper[x_start];
  const int32* const xs_ilerp0 = &xs.ilerp[x_start];
  const int64_t xs_lower1 = xs.lower[x_start + 1];
  const int64_t xs_upper1 = xs.upper[x_start + 1];
  const int32* const xs_ilerp1 = &xs.ilerp[x_start + 1];
  const int64_t xs_lower2 = xs.lower[x_start + 2];
  const int64_t xs_upper2 = xs.upper[x_start + 2];
  const int32* const xs_ilerp2 = &xs.ilerp[x_start + 2];
  const int64_t xs_lower3 = xs.lower[x_start + 3];
  const int64_t xs_upper3 = xs.upper[x_start + 3];
  const int32* const xs_ilerp3 = &xs.ilerp[x_start + 3];
403 | |
404 | const int32x2_t y_lerpsx = vmov_n_s32(ys_ilerp); |
405 | |
406 | const int32x2_t x0c0x0c1 = ComputeLerpx2<RESOLUTION, true>( |
407 | ys_input_lower_ptr + xs_lower0, ys_input_lower_ptr + xs_upper0, |
408 | ys_input_upper_ptr + xs_lower0, ys_input_upper_ptr + xs_upper0, |
409 | ys_input_lower_ptr + xs_lower0 + 1, ys_input_lower_ptr + xs_upper0 + 1, |
410 | ys_input_upper_ptr + xs_lower0 + 1, ys_input_upper_ptr + xs_upper0 + 1, |
411 | xs_ilerp0, y_lerpsx); |
412 | |
413 | const int32x2_t x0c2x1c0 = ComputeLerpx2<RESOLUTION, false>( |
414 | ys_input_lower_ptr + xs_lower0 + 2, ys_input_lower_ptr + xs_upper0 + 2, |
415 | ys_input_upper_ptr + xs_lower0 + 2, ys_input_upper_ptr + xs_upper0 + 2, |
416 | ys_input_lower_ptr + xs_lower1, ys_input_lower_ptr + xs_upper1, |
417 | ys_input_upper_ptr + xs_lower1, ys_input_upper_ptr + xs_upper1, xs_ilerp0, |
418 | y_lerpsx); |
419 | |
420 | const int32x2_t x1c1x1c2 = ComputeLerpx2<RESOLUTION, true>( |
421 | ys_input_lower_ptr + xs_lower1 + 1, ys_input_lower_ptr + xs_upper1 + 1, |
422 | ys_input_upper_ptr + xs_lower1 + 1, ys_input_upper_ptr + xs_upper1 + 1, |
423 | ys_input_lower_ptr + xs_lower1 + 2, ys_input_lower_ptr + xs_upper1 + 2, |
424 | ys_input_upper_ptr + xs_lower1 + 2, ys_input_upper_ptr + xs_upper1 + 2, |
425 | xs_ilerp1, y_lerpsx); |
426 | |
427 | const int32x2_t x2c0x2c1 = ComputeLerpx2<RESOLUTION, true>( |
428 | ys_input_lower_ptr + xs_lower2, ys_input_lower_ptr + xs_upper2, |
429 | ys_input_upper_ptr + xs_lower2, ys_input_upper_ptr + xs_upper2, |
430 | ys_input_lower_ptr + xs_lower2 + 1, ys_input_lower_ptr + xs_upper2 + 1, |
431 | ys_input_upper_ptr + xs_lower2 + 1, ys_input_upper_ptr + xs_upper2 + 1, |
432 | xs_ilerp2, y_lerpsx); |
433 | |
434 | const int32x2_t x2c2x3c0 = ComputeLerpx2<RESOLUTION, false>( |
435 | ys_input_lower_ptr + xs_lower2 + 2, ys_input_lower_ptr + xs_upper2 + 2, |
436 | ys_input_upper_ptr + xs_lower2 + 2, ys_input_upper_ptr + xs_upper2 + 2, |
437 | ys_input_lower_ptr + xs_lower3, ys_input_lower_ptr + xs_upper3, |
438 | ys_input_upper_ptr + xs_lower3, ys_input_upper_ptr + xs_upper3, xs_ilerp2, |
439 | y_lerpsx); |
440 | |
441 | const int32x2_t x3c1x3c2 = ComputeLerpx2<RESOLUTION, true>( |
442 | ys_input_lower_ptr + xs_lower3 + 1, ys_input_lower_ptr + xs_upper3 + 1, |
443 | ys_input_upper_ptr + xs_lower3 + 1, ys_input_upper_ptr + xs_upper3 + 1, |
444 | ys_input_lower_ptr + xs_lower3 + 2, ys_input_lower_ptr + xs_upper3 + 2, |
445 | ys_input_upper_ptr + xs_lower3 + 2, ys_input_upper_ptr + xs_upper3 + 2, |
446 | xs_ilerp3, y_lerpsx); |
447 | |
448 | const int32x4_t x0c0x0c1x0c2x1c0 = vcombine_s32(x0c0x0c1, x0c2x1c0); |
449 | const int32x4_t x1c1x1c2x2c0x2c1 = vcombine_s32(x1c1x1c2, x2c0x2c1); |
450 | const int32x4_t x2c2x3c0x3c1x3c2 = vcombine_s32(x2c2x3c0, x3c1x3c2); |
451 | |
452 | vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3), |
453 | x0c0x0c1x0c2x1c0); |
454 | vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3 + 4), |
455 | x1c1x1c2x2c0x2c1); |
456 | vst1q_s32(reinterpret_cast<int32*>(output_y_ptr + x_start * 3 + 8), |
457 | x2c2x3c0x3c1x3c2); |
458 | |
459 | #else |
  for (int64_t x = x_start; x < x_start + 4; ++x) {
461 | OutputLerpForChannels<RESOLUTION, qint32, int32, int64_t>( |
462 | xs, x, ys_ilerp, 3, min, max, ys_input_lower_ptr, ys_input_upper_ptr, |
463 | output_y_ptr); |
464 | } |
465 | #endif |
466 | } |
467 | |
468 | template <typename T> |
469 | void ResizeImageReference(typename TTypes<T, 4>::ConstTensor images, |
470 | const int batch_size, const int64_t in_height, |
471 | const int64_t in_width, const int64_t out_height, |
472 | const int64_t out_width, const int channels, |
473 | const float height_scale, const float width_scale, |
474 | const float in_min, const float in_max, |
475 | const bool half_pixel_centers, |
476 | typename TTypes<T, 4>::Tensor* output) { |
477 | CHECK_NOTNULL(output); |
478 | |
479 | const InterpolationCache<float> xs = BuildLerpCache<float>( |
480 | out_width, in_width, width_scale, channels, 0, half_pixel_centers); |
481 | const InterpolationCache<float> ys = BuildLerpCache<float>( |
482 | out_height, in_height, height_scale, 1, 0, half_pixel_centers); |
483 | |
484 | const int64_t in_row_size = in_width * channels; |
485 | const int64_t in_batch_num_values = in_height * in_row_size; |
486 | const int64_t out_row_size = out_width * channels; |
487 | |
488 | const T* input_b_ptr = images.data(); |
489 | |
490 | T* output_y_ptr = output->data(); |
491 | for (int b = 0; b < batch_size; ++b) { |
492 | for (int64_t y = 0; y < out_height; ++y) { |
493 | const T* ys_input_lower_ptr = input_b_ptr + ys.lower[y] * in_row_size; |
494 | const T* ys_input_upper_ptr = input_b_ptr + ys.upper[y] * in_row_size; |
495 | const float ys_lerp = ys.lerp[y]; |
496 | for (int64_t x = 0; x < out_width; ++x) { |
497 | const int64_t xs_lower = xs.lower[x]; |
498 | const int64_t xs_upper = xs.upper[x]; |
499 | const float xs_lerp = xs.lerp[x]; |
500 | for (int c = 0; c < channels; ++c) { |
501 | const T top_left = ys_input_lower_ptr[xs_lower + c]; |
502 | const T top_right = ys_input_lower_ptr[xs_upper + c]; |
503 | const T bottom_left = ys_input_upper_ptr[xs_lower + c]; |
504 | const T bottom_right = ys_input_upper_ptr[xs_upper + c]; |
505 | const T val = ComputeLerpReference<T>( |
506 | top_left, top_right, bottom_left, bottom_right, xs_lerp, ys_lerp, |
507 | in_min, in_max); |
508 | output_y_ptr[x * channels + c] = val; |
509 | } |
510 | } |
511 | output_y_ptr += out_row_size; |
512 | } |
513 | input_b_ptr += in_batch_num_values; |
514 | } |
515 | } |
516 | |
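// Generic fallback: types without a specialized fixed-point implementation
// (e.g. float) are resized with the reference implementation.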
517 | template <typename T> |
518 | void ResizeImage(typename TTypes<T, 4>::ConstTensor images, |
519 | const int batch_size, const int64_t in_height, |
520 | const int64_t in_width, const int64_t out_height, |
521 | const int64_t out_width, const int channels, |
522 | const float height_scale, const float width_scale, |
523 | const float in_min, const float in_max, |
524 | const bool half_pixel_centers, |
525 | typename TTypes<T, 4>::Tensor* output) { |
526 | ResizeImageReference<T>(images, batch_size, in_height, in_width, out_height, |
527 | out_width, channels, height_scale, width_scale, |
528 | in_min, in_max, half_pixel_centers, output); |
529 | } |
530 | |
531 | template <> |
532 | void ResizeImage<qint32>(typename TTypes<qint32, 4>::ConstTensor images, |
533 | const int batch_size, const int64_t in_height, |
534 | const int64_t in_width, const int64_t out_height, |
535 | const int64_t out_width, const int channels, |
536 | const float height_scale, const float width_scale, |
537 | const float in_min, const float in_max, |
538 | const bool half_pixel_centers, |
539 | typename TTypes<qint32, 4>::Tensor* output) { |
  // 30 is the maximum fixed-point resolution: the interpolation weights,
  // scaled by (1 << RESOLUTION), must still fit in a signed 32-bit integer.
541 | constexpr int RESOLUTION = 30; |
542 | constexpr int SIMD_STEP = 4; |
543 | |
544 | CHECK_NOTNULL(output); |
545 | |
546 | const InterpolationCache<int32> xs = |
547 | BuildLerpCache<int32>(out_width, in_width, width_scale, channels, |
548 | RESOLUTION, half_pixel_centers); |
549 | const InterpolationCache<int32> ys = BuildLerpCache<int32>( |
550 | out_height, in_height, height_scale, 1, RESOLUTION, half_pixel_centers); |
551 | |
552 | const int64_t in_row_size = in_width * channels; |
553 | const int64_t in_batch_num_values = in_height * in_row_size; |
554 | const int64_t out_row_size = out_width * channels; |
555 | |
556 | const qint32* input_b_ptr = images.data(); |
557 | |
558 | qint32* output_y_ptr = output->data(); |
559 | |
560 | for (int b = 0; b < batch_size; ++b) { |
561 | for (int64_t y = 0; y < out_height; ++y) { |
562 | const qint32* ys_input_lower_ptr = |
563 | input_b_ptr + ys.lower[y] * in_row_size; |
564 | const qint32* ys_input_upper_ptr = |
565 | input_b_ptr + ys.upper[y] * in_row_size; |
566 | const int32_t ys_ilerp = ys.ilerp[y]; |
      // Optimized for channels == 1 and channels == 3, the most common
      // channel counts.
569 | int64_t x = 0; |
570 | if (channels == 1) { |
571 | for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) { |
572 | OutputLerp32x4x1<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max, |
573 | ys_input_lower_ptr, ys_input_upper_ptr, |
574 | output_y_ptr); |
575 | } |
576 | } else if (channels == 3) { |
577 | for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) { |
578 | OutputLerp32x4x3<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max, |
579 | ys_input_lower_ptr, ys_input_upper_ptr, |
580 | output_y_ptr); |
581 | } |
582 | } |
583 | for (; x < out_width; ++x) { |
584 | OutputLerpForChannels<RESOLUTION, qint32, int32, int64_t>( |
585 | xs, x, ys_ilerp, channels, in_min, in_max, ys_input_lower_ptr, |
586 | ys_input_upper_ptr, output_y_ptr); |
587 | } |
588 | output_y_ptr += out_row_size; |
589 | } |
590 | input_b_ptr += in_batch_num_values; |
591 | } |
592 | } |
593 | |
594 | template <> |
595 | void ResizeImage<quint8>(typename TTypes<quint8, 4>::ConstTensor images, |
596 | const int batch_size, const int64_t in_height, |
597 | const int64_t in_width, const int64_t out_height, |
598 | const int64_t out_width, const int channels, |
599 | const float height_scale, const float width_scale, |
600 | const float in_min, const float in_max, |
601 | const bool half_pixel_centers, |
602 | typename TTypes<quint8, 4>::Tensor* output) { |
  // 7 is the maximum fixed-point resolution for 8-bit inputs: the product of
  // a uint8 difference and a 7-bit weight must still fit in the int16
  // accumulator.
604 | constexpr int RESOLUTION = 7; |
605 | constexpr int SIMD_STEP = 8; |
606 | |
607 | CHECK_NOTNULL(output); |
608 | |
609 | const InterpolationCache<int16> xs = |
610 | BuildLerpCache<int16>(out_width, in_width, width_scale, channels, |
611 | RESOLUTION, half_pixel_centers); |
612 | const InterpolationCache<int16> ys = BuildLerpCache<int16>( |
613 | out_height, in_height, height_scale, 1, RESOLUTION, half_pixel_centers); |
614 | |
615 | const int64_t in_row_size = in_width * channels; |
616 | const int64_t in_batch_num_values = in_height * in_row_size; |
617 | const int64_t out_row_size = out_width * channels; |
618 | |
619 | const quint8* input_b_ptr = images.data(); |
620 | |
621 | quint8* output_y_ptr = output->data(); |
622 | |
623 | for (int b = 0; b < batch_size; ++b) { |
624 | for (int64_t y = 0; y < out_height; ++y) { |
625 | const quint8* ys_input_lower_ptr = |
626 | input_b_ptr + ys.lower[y] * in_row_size; |
627 | const quint8* ys_input_upper_ptr = |
628 | input_b_ptr + ys.upper[y] * in_row_size; |
629 | const int32_t ys_ilerp = ys.ilerp[y]; |
      // Optimized for channels == 1 and channels == 3, the most common
      // channel counts.
      // TODO(satok): Support more generic NEON optimized implementation
      // for different channels.
634 | int64_t x = 0; |
635 | if (channels == 1) { |
636 | for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) { |
637 | OutputLerp8x8x1<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max, |
638 | ys_input_lower_ptr, ys_input_upper_ptr, |
639 | output_y_ptr); |
640 | } |
641 | } else if (channels == 3) { |
642 | for (; x < out_width - SIMD_STEP + 1; x += SIMD_STEP) { |
643 | OutputLerp8x8x3<RESOLUTION>(xs, x, ys_ilerp, in_min, in_max, |
644 | ys_input_lower_ptr, ys_input_upper_ptr, |
645 | output_y_ptr); |
646 | } |
647 | } |
648 | for (; x < out_width; ++x) { |
649 | OutputLerpForChannels<RESOLUTION, quint8, int16, int16>( |
650 | xs, x, ys_ilerp, channels, in_min, in_max, ys_input_lower_ptr, |
651 | ys_input_upper_ptr, output_y_ptr); |
652 | } |
653 | output_y_ptr += out_row_size; |
654 | } |
655 | input_b_ptr += in_batch_num_values; |
656 | } |
657 | } |
658 | |
659 | template <typename T> |
660 | void ResizeBilinear(const typename TTypes<T, 4>::ConstTensor& images, |
661 | const float height_scale, const float width_scale, |
662 | const float in_min, const float in_max, |
663 | const bool half_pixel_centers, |
664 | typename TTypes<T, 4>::Tensor* output) { |
665 | CHECK_NOTNULL(output); |
666 | |
667 | const int batch_size = images.dimension(0); |
668 | const int64_t in_height = images.dimension(1); |
669 | const int64_t in_width = images.dimension(2); |
670 | const int channels = images.dimension(3); |
671 | |
672 | const int64_t out_height = output->dimension(1); |
673 | const int64_t out_width = output->dimension(2); |
674 | |
675 | // Handle no-op resizes efficiently. |
676 | if (out_height == in_height && out_width == in_width) { |
677 | *output = images.template cast<T>(); |
678 | return; |
679 | } |
680 | |
681 | if (USE_REFERENCE) { |
682 | ResizeImageReference<T>(images, batch_size, in_height, in_width, out_height, |
683 | out_width, channels, height_scale, width_scale, |
684 | in_min, in_max, half_pixel_centers, output); |
685 | } else { |
686 | ResizeImage<T>(images, batch_size, in_height, in_width, out_height, |
687 | out_width, channels, height_scale, width_scale, in_min, |
688 | in_max, half_pixel_centers, output); |
689 | } |
690 | } |
691 | |
692 | } // namespace |
693 | |
694 | template <class T> |
695 | class QuantizedResizeBilinearOp : public OpKernel { |
696 | public: |
697 | explicit QuantizedResizeBilinearOp(OpKernelConstruction* context) |
698 | : OpKernel(context) { |
    OP_REQUIRES_OK(context,
                   context->GetAttr("align_corners", &align_corners_));
700 | OP_REQUIRES_OK( |
        context, context->GetAttr("half_pixel_centers", &half_pixel_centers_));
702 | } |
703 | |
704 | void Compute(OpKernelContext* context) override { |
705 | const auto& in_min_tensor = context->input(2); |
706 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(in_min_tensor.shape()), |
707 | errors::InvalidArgument("min must be a scalar" )); |
708 | const float in_min = in_min_tensor.flat<float>()(0); |
709 | const auto& in_max_tensor = context->input(3); |
710 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(in_max_tensor.shape()), |
711 | errors::InvalidArgument("max must be a scalar" )); |
712 | const float in_max = in_max_tensor.flat<float>()(0); |
713 | |
714 | ImageResizerState st(align_corners_, false); |
715 | st.ValidateAndCreateOutput(context); |
716 | |
717 | if (!context->status().ok()) return; |
718 | |
719 | // Return if the output is empty. |
720 | if (st.output->NumElements() == 0) return; |
721 | |
722 | typename TTypes<T, 4>::ConstTensor image_data( |
723 | context->input(0).tensor<T, 4>()); |
724 | typename TTypes<T, 4>::Tensor output_data(st.output->tensor<T, 4>()); |
725 | |
726 | ResizeBilinear<T>(image_data, st.height_scale, st.width_scale, in_min, |
727 | in_max, half_pixel_centers_, &output_data); |
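
    // Bilinear interpolation produces convex combinations of the input
    // values, so the quantization range passes through unchanged.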
728 | Tensor* out_min = nullptr; |
729 | OP_REQUIRES_OK(context, context->allocate_output(1, {}, &out_min)); |
730 | out_min->flat<float>()(0) = in_min; |
731 | |
732 | Tensor* out_max = nullptr; |
733 | OP_REQUIRES_OK(context, context->allocate_output(2, {}, &out_max)); |
734 | out_max->flat<float>()(0) = in_max; |
735 | } |
736 | |
737 | private: |
738 | bool align_corners_; |
739 | bool half_pixel_centers_; |
740 | |
741 | TF_DISALLOW_COPY_AND_ASSIGN(QuantizedResizeBilinearOp<T>); |
742 | }; |
743 | |
744 | #define REGISTER_CPU_KERNEL(type) \ |
745 | REGISTER_KERNEL_BUILDER(Name("QuantizedResizeBilinear") \ |
746 | .Device(DEVICE_CPU) \ |
747 | .HostMemory("size") \ |
748 | .TypeConstraint<type>("T"), \ |
749 | QuantizedResizeBilinearOp<type>) |
750 | |
751 | REGISTER_CPU_KERNEL(::tensorflow::quint8); |
752 | REGISTER_CPU_KERNEL(::tensorflow::qint32); |
753 | REGISTER_CPU_KERNEL(float); |
754 | |
755 | } // namespace tensorflow |
756 | |