1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_ |
18 | |
19 | #include <cmath> |
20 | #define EIGEN_USE_THREADS |
21 | |
// This is a set of functions that standardize how quantized values are
// interpreted as float numbers.
// All of the current implementations are for reference and have not been
// optimized. They should be implementable using fixed-point representations
// to avoid a dependency on floating-point hardware.
27 | |
28 | #if defined(__ARM_NEON__) || defined(__ARM_NEON) |
29 | #define QUANTIZATION_UTILS_USE_NEON |
30 | #include <arm_neon.h> |
31 | #endif |
32 | |
33 | #include <array> |
34 | |
35 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
36 | #define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK |
37 | #include "public/gemmlowp.h" |
38 | #include "tensorflow/core/framework/tensor.h" |
39 | #include "tensorflow/core/lib/core/threadpool.h" |
40 | |
41 | namespace tensorflow { |
42 | |
// We have to be able to detect and handle overflows in int32, so this function
// uses doubles and int64s to make sure we have enough room.
45 | template <class T> |
46 | inline int64_t FloatToQuantizedUnclamped(float input, float range_min, |
47 | float range_max) { |
48 | const int64_t lowest_quantized = |
49 | static_cast<double>(Eigen::NumTraits<T>::lowest()); |
50 | if (range_min == range_max) { |
51 | return lowest_quantized; |
52 | } |
53 | const int number_of_bits = sizeof(T) * 8; |
54 | const int64_t number_of_steps = static_cast<int64_t>(1) << number_of_bits; |
55 | const double range_adjust = (number_of_steps / (number_of_steps - 1.0)); |
56 | const double range = ((range_max - range_min) * range_adjust); |
57 | const double range_scale = (number_of_steps / range); |
58 | int64_t quantized = |
59 | (round(input * range_scale) - round(range_min * range_scale)); |
60 | quantized += lowest_quantized; |
61 | return quantized; |
62 | } |
63 | |
// Quantizing to a float "range" is a no-op, so this specialization exists only
// to satisfy template instantiation; it returns a sentinel value, and
// FloatToQuantized below handles the float case before ever calling it.
template <>
inline int64_t FloatToQuantizedUnclamped<float>(float input, float range_min,
                                                float range_max) {
  return -1;
}
69 | |
// This converts the float into the final quantized type, clamping/saturating
// any overflows or underflows.
72 | template <class T> |
73 | T FloatToQuantized(float input, float range_min, float range_max) { |
74 | if (std::is_same<T, float>::value) { |
    // Specialization for float. This is used by the float reference
    // implementation, which is useful for comparing performance between the
    // float and quantized code paths.
78 | return input; |
79 | } |
80 | int64_t quantized = FloatToQuantizedUnclamped<T>(input, range_min, range_max); |
81 | const int64_t lowest_quantized = |
82 | static_cast<int64_t>(Eigen::NumTraits<T>::lowest()); |
83 | const int64_t highest_quantized = |
84 | static_cast<int64_t>(Eigen::NumTraits<T>::highest()); |
85 | quantized = std::max(quantized, lowest_quantized); |
86 | quantized = std::min(quantized, highest_quantized); |
87 | return static_cast<T>(static_cast<int32>(quantized)); |
88 | } |
89 | |
90 | template <class T> |
91 | float QuantizedToFloat(T input, float range_min, float range_max) { |
92 | if (std::is_same<T, float>::value) { |
    // Specialization for float. This is used by the float reference
    // implementation, which is useful for comparing performance between the
    // float and quantized code paths.
96 | return input; |
97 | } |
98 | if (range_min == range_max) { |
99 | return range_min; |
100 | } |
101 | const int number_of_bits = sizeof(T) * 8; |
102 | const int64_t number_of_steps = static_cast<int64_t>(1) << number_of_bits; |
103 | const double range_adjust = (number_of_steps / (number_of_steps - 1.0)); |
104 | const double range = ((range_max - range_min) * range_adjust); |
105 | const double range_scale = (range / number_of_steps); |
106 | const int64_t lowest_quantized = |
107 | static_cast<int64_t>(Eigen::NumTraits<T>::lowest()); |
108 | const double offset_input = static_cast<double>(input) - lowest_quantized; |
  // For compatibility with DEQUANTIZE_WITH_EIGEN, we convert range_scale to
  // a float here; otherwise range_min_rounded might come out slightly
  // different.
112 | const double range_min_rounded = |
113 | std::round(range_min / static_cast<float>(range_scale)) * |
114 | static_cast<float>(range_scale); |
115 | const double result = range_min_rounded + (offset_input * range_scale); |
116 | return static_cast<float>(result); |
117 | } |
118 | |
119 | template <class T> |
120 | float FloatForOneQuantizedLevel(float range_min, float range_max) { |
121 | const int64_t highest = static_cast<int64_t>(Eigen::NumTraits<T>::highest()); |
122 | const int64_t lowest = static_cast<int64_t>(Eigen::NumTraits<T>::lowest()); |
123 | const float float_for_one_quantized_level = |
124 | (range_max - range_min) / (highest - lowest); |
125 | return float_for_one_quantized_level; |
126 | } |
127 | |
128 | template <class T1, class T2, class T3> |
129 | void QuantizationRangeForMultiplication(float min_a, float max_a, float min_b, |
130 | float max_b, float* min_c, |
131 | float* max_c) { |
132 | const float a_float_for_one_quant_level = |
133 | FloatForOneQuantizedLevel<T1>(min_a, max_a); |
134 | const float b_float_for_one_quant_level = |
135 | FloatForOneQuantizedLevel<T2>(min_b, max_b); |
136 | |
137 | const int64_t c_highest = |
138 | static_cast<int64_t>(Eigen::NumTraits<T3>::highest()); |
139 | const int64_t c_lowest = static_cast<int64_t>(Eigen::NumTraits<T3>::lowest()); |
140 | const float c_float_for_one_quant_level = |
141 | a_float_for_one_quant_level * b_float_for_one_quant_level; |
142 | |
143 | *min_c = c_float_for_one_quant_level * c_lowest; |
144 | *max_c = c_float_for_one_quant_level * c_highest; |
145 | } |
146 | |
// input_array is an Eigen tensor. q2f is a QuantizedToFloatStruct.
// This evaluates to an Eigen tensor expression, to be used like:
// auto tensor = DEQUANTIZE_WITH_EIGEN(input_tensor, q2f);
150 | #define DEQUANTIZE_WITH_EIGEN(input_array, q2f) \ |
151 | ((q2f.range_min_rounded - q2f.lowest_quantized() * q2f.range_scale) + \ |
152 | input_array.template cast<float>() * q2f.range_scale) |
153 | |
// input_array is an Eigen tensor. f2q is a FloatToQuantizedStruct.
// OutputType is the type of the output (e.g. quint8).
// This evaluates to an Eigen tensor expression, to be used like:
// auto tensor = QUANTIZE_WITH_EIGEN(input_tensor, f2q, T);
158 | #define QUANTIZE_WITH_EIGEN(input_array, f2q, OutputType) \ |
159 | ((input_array * f2q.range_scale).round() - \ |
160 | (f2q.range_min_scaled - f2q.lowest_quantized())) \ |
161 | .cwiseMax(f2q.lower_bound_float()) \ |
162 | .cwiseMin(f2q.upper_bound_float()) \ |
163 | .template cast<int32>() \ |
164 | .template cast<OutputType>() |
165 | |
166 | // For use with DEQUANTIZE_WITH_EIGEN. |
167 | template <typename T> |
168 | struct QuantizedToFloatStruct { |
169 | static constexpr int number_of_bits = sizeof(T) * 8; |
170 | static constexpr int64_t number_of_steps = static_cast<int64_t>(1) |
171 | << number_of_bits; |
172 | |
173 | static float lowest_quantized() { |
174 | return static_cast<float>(Eigen::NumTraits<T>::lowest()); |
175 | } |
176 | |
177 | QuantizedToFloatStruct(float range_min, float range_max) |
178 | : range_min(range_min), |
179 | range_scale((range_max - range_min) / (number_of_steps - 1.0)), |
180 | range_min_rounded(range_max == range_min |
181 | ? range_min |
182 | : std::round(range_min / range_scale) * |
183 | range_scale) {} |
184 | |
185 | const float range_min; |
186 | const float range_scale; |
187 | const float range_min_rounded; |
188 | }; |
189 | |
190 | // For use with QUANTIZE_WITH_EIGEN. |
191 | template <typename T> |
192 | struct FloatToQuantizedStruct { |
193 | static constexpr int number_of_bits = sizeof(T) * 8; |
194 | static constexpr int64_t number_of_steps = static_cast<int64_t>(1) |
195 | << number_of_bits; |
196 | static constexpr double range_adjust = |
197 | (number_of_steps / (number_of_steps - 1.0)); |
198 | |
199 | // Casting QInt32's lowest or highest to a float gives a float that can't be |
200 | // cast back to int32 or QInt32. Instead, use bounds that can be converted |
201 | // back to int32 without going outside the range of an int32. |
202 | static float lower_bound_float() { |
203 | return Eigen::numext::maxi( |
204 | static_cast<float>(Eigen::NumTraits<T>::lowest()), -2.147483648e+09f); |
205 | } |
206 | static float upper_bound_float() { |
207 | return Eigen::numext::mini( |
208 | static_cast<float>(Eigen::NumTraits<T>::highest()), +2.147483520e+09f); |
209 | } |
210 | |
211 | static float lowest_quantized() { |
212 | return static_cast<float>(Eigen::NumTraits<T>::lowest()); |
213 | } |
214 | |
215 | FloatToQuantizedStruct(float range_min, float range_max) |
216 | : range_min(range_min), |
217 | range_scale(range_max == range_min |
218 | ? 0.0 |
219 | : (number_of_steps - 1.0) / (range_max - range_min)), |
220 | range_min_scaled(std::round(range_min * range_scale)) {} |
221 | |
222 | const float range_min; |
223 | const float range_scale; |
224 | const float range_min_scaled; |
225 | }; |
226 | |
227 | template <class T1, class T2> |
228 | inline T2 RequantizeInNewRange(T1 input, float min_input, float max_input, |
229 | float min_new, float max_new) { |
230 | const float input_float = QuantizedToFloat<T1>(input, min_input, max_input); |
231 | return FloatToQuantized<T2>(input_float, min_new, max_new); |
232 | } |
233 | |
234 | template <class T1, class T2> |
235 | inline void RequantizeManyInNewRange(const T1* input, int64_t count, |
236 | float min_input, float max_input, |
237 | float min_output, float max_output, |
238 | T2* output) { |
  for (int64_t index = 0; index < count; ++index) {
240 | const float input_float = |
241 | QuantizedToFloat<T1>(input[index], min_input, max_input); |
242 | output[index] = FloatToQuantized<T2>(input_float, min_output, max_output); |
243 | } |
244 | } |
245 | |
246 | // Because converting 32-bit accumulated results down to eight bit is a common |
247 | // case, we have a specialized code path to handle it as efficiently as |
248 | // possible using only fixed-point math for the inner loop. |
249 | inline void RequantizeManyInNewRangeReference(const qint32* input, |
250 | int64_t count, float min_input, |
251 | float max_input, float min_output, |
252 | float max_output, |
253 | quint8* output) { |
254 | // Initially we calculate all the constants we need once, before we go into |
255 | // the inner loop. If this is updated, also update the Eigen version. |
256 | const int fp_shift = 16; |
257 | const float input_range = max_input - min_input; |
258 | const float output_range = max_output - min_output; |
259 | const float recip_output_range = |
260 | output_range == 0.0 ? 0.0 : (255.0 / output_range); |
261 | const float input_rezero = (min_input + max_input) / 2.0; |
  const int64_t range_scale_fp =
      output_range == 0.0 ? 0
                          : static_cast<int64_t>(255.0 * (1 << fp_shift) *
                                                 input_range / output_range);
266 | const int64_t input_offset_fp = |
267 | static_cast<int64_t>(input_rezero * recip_output_range * (1 << fp_shift)); |
268 | const int64_t output_offset_fp = |
269 | output_range == 0.0 |
270 | ? 0 |
271 | : std::lround((1 << fp_shift) * (min_output * 255.0) / output_range); |
272 | const int64_t rounding_delta = 1 << (fp_shift - 1); |
273 | |
274 | // Inside this loop we just do minimal adds, multiplies, and shifts, in a way |
275 | // that could be easily adapted for a SIMD implementation. It should also be |
276 | // possible to perform all the calculations in 32-bit rather than 64, but |
277 | // that's not been implemented yet. |
278 | for (int64_t index = 0; index < count; ++index) { |
279 | const int64_t input_value = static_cast<int64_t>(input[index]); |
280 | const int64_t fp_value = |
281 | ((input_value * range_scale_fp) >> 32) + input_offset_fp; |
282 | const int64_t offset_intermediate = fp_value - output_offset_fp; |
283 | const int64_t round_intermediate = offset_intermediate + rounding_delta; |
284 | int64_t quantized_int64 = round_intermediate >> fp_shift; |
285 | quantized_int64 = std::max(quantized_int64, int64_t{0}); |
286 | quantized_int64 = std::min(quantized_int64, int64_t{255}); |
287 | output[index] = static_cast<quint8>(static_cast<int32>(quantized_int64)); |
288 | } |
289 | } |
290 | |
// Another common case is converting eight-bit inputs up to thirty-two bits, so
// we have specialized fixed-point code to accelerate that. There is also a
// NEON version for ARM devices below.
294 | inline void RequantizeManyInNewRange8To32BitReference( |
295 | const quint8* input, int64_t count, float min_input, float max_input, |
296 | float min_output, float max_output, qint32* output) { |
297 | const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input); |
298 | const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input); |
299 | const int64_t code_0_int64 = |
300 | FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output); |
301 | const int64_t code_1_int64 = |
302 | FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output); |
  const int32_t mult_int32 = static_cast<int32_t>(code_1_int64 - code_0_int64);
304 | const int64_t lowest_quantized = |
305 | static_cast<int64_t>(Eigen::NumTraits<qint32>::lowest()); |
306 | const int64_t highest_quantized = |
307 | static_cast<int64_t>(Eigen::NumTraits<qint32>::highest()); |
308 | for (int64_t i = 0; i < count; ++i) { |
309 | const int64_t input_value = static_cast<int64_t>(input[i]); |
310 | int64_t output_value = code_0_int64 + (input_value * mult_int32); |
311 | output_value = std::max(output_value, lowest_quantized); |
312 | output_value = std::min(output_value, highest_quantized); |
313 | output[i] = static_cast<int32>(output_value); |
314 | } |
315 | } |
316 | |
317 | #ifdef QUANTIZATION_UTILS_USE_NEON |
// Speeds up the 32-bit-to-8-bit conversion using fixed-point arithmetic and
// NEON SIMD intrinsics for ARM platforms.
inline void RequantizeManyInNewRangeNeon(const qint32* input, int64_t count,
321 | float min_input, float max_input, |
322 | float min_output, float max_output, |
323 | quint8* output) { |
324 | // Initially we calculate all the constants we need once, before we go into |
325 | // the inner loop. If this is updated, also update the Eigen version. |
326 | const int fp_shift = 16; |
327 | |
  // Calculate range variables in advance.
  // Input range.
  const float input_range = max_input - min_input;
  // Output range.
  const float output_range = max_output - min_output;
  // Reciprocal of the output range, scaled to the 0-255 code space.
  const float recip_output_range =
      output_range == 0.0 ? 0.0 : (255.0 / output_range);
  // Midpoint of the input range, used as the input's zero position.
  const float input_rezero = (min_input + max_input) / 2.0;
  // Fixed-point scale factor from the input range to the output range.
  const int32 range_scale_fp =
      output_range == 0.0 ? 0
                          : static_cast<int32>(255.0 * (1 << (fp_shift - 16)) *
                                               input_range / output_range);
  // Fixed-point offset of the input zero position in the output space.
  const int32 input_offset_fp =
      static_cast<int32>(input_rezero * recip_output_range * (1 << fp_shift));
  // Fixed-point offset of the output minimum.
  const int32 output_offset_fp =
      output_range == 0.0
          ? 0
          : static_cast<int32>((1 << fp_shift) * (min_output * 255.0) /
                               output_range);
  const int32 rounding_delta = 1 << (fp_shift - 1);
353 | |
  // Broadcast the constants to each lane.
355 | const int32x4_t range_scale_fp_32x4 = vmovq_n_s32(range_scale_fp); |
356 | const int32x4_t input_offset_fp_32x4 = vmovq_n_s32(input_offset_fp); |
357 | const int32x4_t output_offset_fp_32x4 = vmovq_n_s32(output_offset_fp); |
358 | const int32x4_t rounding_delta_32x4 = vmovq_n_s32(rounding_delta); |
359 | |
  int64_t index = 0;
361 | // Use SIMD to requantize. |
362 | for (; index < (count - 7); index += 8) { |
363 | const int32* input_ptr = &(input->value) + index; |
364 | const int32x4_t input_value_low_32x4 = vld1q_s32(input_ptr); |
365 | const int32x4_t input_value_high_32x4 = vld1q_s32(input_ptr + 4); |
366 | const int32x4_t fp_value_low_32x4 = vaddq_s32( |
367 | input_offset_fp_32x4, |
368 | vmulq_s32(vshrq_n_s32(input_value_low_32x4, 16), range_scale_fp_32x4)); |
369 | const int32x4_t fp_value_high_32x4 = vaddq_s32( |
370 | input_offset_fp_32x4, |
371 | vmulq_s32(vshrq_n_s32(input_value_high_32x4, 16), range_scale_fp_32x4)); |
372 | const int32x4_t offset_intermediate_low_32x4 = |
373 | vsubq_s32(fp_value_low_32x4, output_offset_fp_32x4); |
374 | const int32x4_t offset_intermediate_high_32x4 = |
375 | vsubq_s32(fp_value_high_32x4, output_offset_fp_32x4); |
376 | const int32x4_t round_intermediate_low_32x4 = |
377 | vaddq_s32(offset_intermediate_low_32x4, rounding_delta_32x4); |
378 | const int32x4_t round_intermediate_high_32x4 = |
379 | vaddq_s32(offset_intermediate_high_32x4, rounding_delta_32x4); |
380 | const int16x4_t quantized_low_16x4 = |
381 | vqmovn_s32(vshrq_n_s32(round_intermediate_low_32x4, fp_shift)); |
382 | const int16x4_t quantized_high_16x4 = |
383 | vqmovn_s32(vshrq_n_s32(round_intermediate_high_32x4, fp_shift)); |
384 | const uint8x8_t quantized_8x8 = |
385 | vqmovun_s16(vcombine_s16(quantized_low_16x4, quantized_high_16x4)); |
386 | uint8* output_ptr = &(output->value) + index; |
387 | vst1_u8(output_ptr, quantized_8x8); |
388 | } |
389 | |
390 | // Requantize remaining elements in array without SIMD. |
391 | for (; index < count; ++index) { |
392 | const int32 input_value = static_cast<int32>(input[index]); |
393 | const int32 fp_value = |
394 | static_cast<int32>( |
395 | (static_cast<int32>(input_value >> 16) * (range_scale_fp))) + |
396 | input_offset_fp; |
397 | const int32 offset_intermediate = fp_value - output_offset_fp; |
398 | const int32 round_intermediate = offset_intermediate + rounding_delta; |
399 | int32 quantized_int32 = round_intermediate >> fp_shift; |
400 | quantized_int32 = std::max(quantized_int32, 0); |
401 | quantized_int32 = std::min(quantized_int32, 255); |
402 | output[index] = static_cast<quint8>(static_cast<int32>(quantized_int32)); |
403 | } |
404 | } |
405 | |
406 | template <> |
407 | inline void RequantizeManyInNewRange<qint32, quint8>( |
    const qint32* input, int64_t count, float min_input, float max_input,
409 | float min_output, float max_output, quint8* output) { |
410 | const float input_range = max_input - min_input; |
411 | const float output_range = max_output - min_output; |
412 | if ((input_range / output_range) > 16384.0f) { |
413 | // Our NEON implementation uses 32-bit math and can't handle very |
414 | // large ranges, so fall back to the reference implementation. We don't |
415 | // expect these to be common in models, so this shouldn't be a performance |
416 | // problem in practice. |
417 | RequantizeManyInNewRangeReference(input, count, min_input, max_input, |
418 | min_output, max_output, output); |
419 | } else { |
420 | RequantizeManyInNewRangeNeon(input, count, min_input, max_input, min_output, |
421 | max_output, output); |
422 | } |
423 | } |
424 | |
// NEON-accelerated 16-bit division by 2^n, rounding half away from zero.
426 | template <int POW> |
427 | inline int16x8_t Divide16x8PowRound(const int16x8_t val) { |
428 | const int16x8_t val_sign = vshrq_n_s16(val, 15); |
429 | const int16x8_t val_xor = veorq_s16(val, val_sign); |
430 | const int16x8_t val_pos = vsubq_s16(val_xor, val_sign); |
431 | const int16x8_t shifted_val_pos = vrshrq_n_s16(val_pos, POW); |
432 | const int16x8_t shifted_val_pos_xor = veorq_s16(shifted_val_pos, val_sign); |
433 | const int16x8_t shifted_val = vsubq_s16(shifted_val_pos_xor, val_sign); |
434 | return shifted_val; |
435 | } |
436 | |
// NEON-accelerated 64-bit division by 2^n, rounding half away from zero.
438 | template <int POW> |
439 | inline int64x2_t Divide64x2PowRound(const int64x2_t val) { |
440 | const int64x2_t val_sign = vshrq_n_s64(val, 63); |
441 | const int64x2_t val_xor = veorq_s64(val, val_sign); |
442 | const int64x2_t val_pos = vsubq_s64(val_xor, val_sign); |
443 | const int64x2_t shifted_val_pos = vrshrq_n_s64(val_pos, POW); |
444 | const int64x2_t shifted_val_pos_xor = veorq_s64(shifted_val_pos, val_sign); |
445 | const int64x2_t shifted_val = vsubq_s64(shifted_val_pos_xor, val_sign); |
446 | return shifted_val; |
447 | } |
448 | |
// NEON-accelerated 16-bit division by 2^n, truncating toward zero.
// CAVEAT: The input must be greater than the minimum int16 value to avoid
// underflow.
451 | template <int POW> |
452 | inline int16x8_t Divide16x8Pow(const int16x8_t val) { |
  static constexpr int16 FIRST_BIT_VAL = 0x0001;
454 | static const int16x8_t FIRST_BIT = vmovq_n_s16(FIRST_BIT_VAL); |
455 | const int16x8_t val_sign = vshrq_n_s16(val, 15); |
456 | const int16x8_t neg_offset = vandq_s16(val_sign, FIRST_BIT); |
457 | const int16x8_t val_with_offset = vsubq_s16(val, neg_offset); |
458 | const int16x8_t shifted_wo_offset = |
459 | vsraq_n_s16(neg_offset, val_with_offset, POW); |
460 | return shifted_wo_offset; |
461 | } |
462 | |
// NEON-accelerated 64-bit division by 2^n, truncating toward zero.
// CAVEAT: The input must be greater than the minimum int64 value to avoid
// underflow.
465 | template <int POW> |
466 | inline int64x2_t Divide64x2Pow(const int64x2_t val) { |
  static constexpr int64_t FIRST_BIT_VAL = 0x0000000000000001;
468 | static const int64x2_t FIRST_BIT = vmovq_n_s64(FIRST_BIT_VAL); |
469 | const int64x2_t val_sign = vshrq_n_s64(val, 63); |
470 | const int64x2_t neg_offset = vandq_s64(val_sign, FIRST_BIT); |
471 | const int64x2_t val_with_offset = vsubq_s64(val, neg_offset); |
472 | const int64x2_t shifted_wo_offset = |
473 | vsraq_n_s64(neg_offset, val_with_offset, POW); |
474 | return shifted_wo_offset; |
475 | } |
476 | |
// 32-bit x 2 NEON-accelerated lerp (linear interpolation) computation.
478 | template <int RESOLUTION> |
479 | inline int32x2_t ComputeLerp32x2(const int32x2_t top_left, |
480 | const int32x2_t top_right, |
481 | const int32x2_t bottom_left, |
482 | const int32x2_t bottom_right, |
483 | const int32x2_t x_lerp, |
484 | const int32x2_t y_lerp) { |
  static_assert(RESOLUTION < 31, "RESOLUTION must be less than 31");
486 | constexpr int32 RESOLUTION_MULT32 = (1 << RESOLUTION); |
487 | static const int32x2_t RESOLUTION_MULT32x2 = vmov_n_s32(RESOLUTION_MULT32); |
488 | |
489 | const int64x2_t top_left_x_res = vmull_s32(top_left, RESOLUTION_MULT32x2); |
490 | const int64x2_t bottom_left_x_res = |
491 | vmull_s32(bottom_left, RESOLUTION_MULT32x2); |
492 | |
493 | const int32x2_t top_right_sub_top_left = vsub_s32(top_right, top_left); |
494 | const int64x2_t top_x_res = |
495 | vmlal_s32(top_left_x_res, top_right_sub_top_left, x_lerp); |
496 | const int32x2_t bottom_right_sub_bottom_left = |
497 | vsub_s32(bottom_right, bottom_left); |
498 | const int64x2_t bottom_x_res = |
499 | vmlal_s32(bottom_left_x_res, bottom_right_sub_bottom_left, x_lerp); |
500 | |
501 | const int64x2_t bottom_sub_top_x_res = vsubq_s64(bottom_x_res, top_x_res); |
502 | const int64x2_t bottom_sub_top = |
503 | Divide64x2Pow<RESOLUTION>(bottom_sub_top_x_res); |
504 | const int32x2_t bottom_sub_top_32 = vqmovn_s64(bottom_sub_top); |
505 | const int64x2_t top_add_bottom_sub_top_mul_ylerp_x_res = |
506 | vmlal_s32(top_x_res, bottom_sub_top_32, y_lerp); |
507 | const int64x2_t retval = |
508 | Divide64x2PowRound<RESOLUTION>(top_add_bottom_sub_top_mul_ylerp_x_res); |
509 | const int32x2_t retval32 = vqmovn_s64(retval); |
510 | return retval32; |
511 | } |
512 | |
// 8-bit x 8 NEON-accelerated lerp computation.
514 | template <int RESOLUTION> |
515 | inline uint8x8_t ComputeLerp8x8(const uint8x8_t top_left8x8, |
516 | const uint8x8_t top_right8x8, |
517 | const uint8x8_t bottom_left8x8, |
518 | const uint8x8_t bottom_right8x8, |
519 | const int16x8_t x_lerp, |
520 | const int16x8_t y_lerp) { |
  static_assert(RESOLUTION < 8, "RESOLUTION must be less than 8");
522 | constexpr uint8 RESOLUTION_MULT_VAL = (1 << RESOLUTION); |
523 | static const uint8x8_t RESOLUTION_MULT = vdup_n_u8(RESOLUTION_MULT_VAL); |
524 | |
525 | const int16x8_t top_left_x_res = |
526 | vreinterpretq_s16_u16(vmull_u8(top_left8x8, RESOLUTION_MULT)); |
527 | const int16x8_t bottom_left_x_res = |
528 | vreinterpretq_s16_u16(vmull_u8(bottom_left8x8, RESOLUTION_MULT)); |
529 | |
530 | const int16x8_t top_right_sub_top_left = |
531 | vreinterpretq_s16_u16(vsubl_u8(top_right8x8, top_left8x8)); |
532 | const int16x8_t top_x_res = |
533 | vmlaq_s16(top_left_x_res, top_right_sub_top_left, x_lerp); |
534 | |
535 | const int16x8_t bottom_right_sub_bottom_left = |
536 | vreinterpretq_s16_u16(vsubl_u8(bottom_right8x8, bottom_left8x8)); |
537 | const int16x8_t bottom_x_res = |
538 | vmlaq_s16(bottom_left_x_res, bottom_right_sub_bottom_left, x_lerp); |
539 | |
540 | const int16x8_t bottom_sub_top_x_res = vsubq_s16(bottom_x_res, top_x_res); |
541 | const int16x8_t bottom_sub_top = |
542 | Divide16x8Pow<RESOLUTION>(bottom_sub_top_x_res); |
543 | const int16x8_t top_add_bottom_sub_top_mul_ylerp_x_res = |
544 | vmlaq_s16(top_x_res, bottom_sub_top, y_lerp); |
545 | const int16x8_t retval16 = |
546 | Divide16x8PowRound<RESOLUTION>(top_add_bottom_sub_top_mul_ylerp_x_res); |
547 | const uint8x8_t retval = vmovn_u16(vreinterpretq_u16_s16(retval16)); |
548 | return retval; |
549 | } |
550 | |
// Requantizes 8 x 8-bit quint8 values to 8 x 32-bit qint32 values in parallel
// using NEON.
// Returns a std::array instead of writing through a pointer to leverage return
// value optimization.
553 | inline std::array<int32x4_t, 2> Requantize8x8To32Neon( |
554 | const uint8* input_ptr, const int64x2_t input_0_64x2, |
555 | const int32x2_t input_mult_32x2) { |
556 | const uint8x8_t input_value_8x8 = vld1_u8(input_ptr); |
557 | const int16x8_t input_value_16x8 = |
558 | vreinterpretq_s16_u16(vmovl_u8(input_value_8x8)); |
559 | const int16x4_t input_value_low_16x4 = vget_low_s16(input_value_16x8); |
560 | const int16x4_t input_value_high_16x4 = vget_high_s16(input_value_16x8); |
561 | const int32x4_t input_value_low_32x4 = vmovl_s16(input_value_low_16x4); |
562 | const int32x4_t input_value_high_32x4 = vmovl_s16(input_value_high_16x4); |
563 | const int32x2_t input_value_low_low_32x2 = vget_low_s32(input_value_low_32x4); |
564 | const int32x2_t input_value_low_high_32x2 = |
565 | vget_high_s32(input_value_low_32x4); |
566 | const int32x2_t input_value_high_low_32x2 = |
567 | vget_low_s32(input_value_high_32x4); |
568 | const int32x2_t input_value_high_high_32x2 = |
569 | vget_high_s32(input_value_high_32x4); |
570 | const int64x2_t mult_result_low_low_64x2 = |
571 | vmlal_s32(input_0_64x2, input_value_low_low_32x2, input_mult_32x2); |
572 | const int64x2_t mult_result_low_high_64x2 = |
573 | vmlal_s32(input_0_64x2, input_value_low_high_32x2, input_mult_32x2); |
574 | const int64x2_t mult_result_high_low_64x2 = |
575 | vmlal_s32(input_0_64x2, input_value_high_low_32x2, input_mult_32x2); |
576 | const int64x2_t mult_result_high_high_64x2 = |
577 | vmlal_s32(input_0_64x2, input_value_high_high_32x2, input_mult_32x2); |
578 | const int32x2_t output_value_low_low_32x2 = |
579 | vqmovn_s64(mult_result_low_low_64x2); |
580 | const int32x2_t output_value_low_high_32x2 = |
581 | vqmovn_s64(mult_result_low_high_64x2); |
582 | const int32x2_t output_value_high_low_32x2 = |
583 | vqmovn_s64(mult_result_high_low_64x2); |
584 | const int32x2_t output_value_high_high_32x2 = |
585 | vqmovn_s64(mult_result_high_high_64x2); |
586 | const int32x4_t output_value_low_32x4 = |
587 | vcombine_s32(output_value_low_low_32x2, output_value_low_high_32x2); |
588 | const int32x4_t output_value_high_32x4 = |
589 | vcombine_s32(output_value_high_low_32x2, output_value_high_high_32x2); |
590 | return std::array<int32x4_t, 2>{ |
591 | {output_value_low_32x4, output_value_high_32x4}}; |
592 | } |
593 | |
// Speeds up the 8-bit-to-32-bit conversion using fixed-point arithmetic and
// NEON SIMD intrinsics for ARM platforms.
596 | template <> |
597 | inline void RequantizeManyInNewRange<quint8, qint32>( |
    const quint8* input, int64_t count, float min_input, float max_input,
599 | float min_output, float max_output, qint32* output) { |
  // Pre-calculate the zero position and multiplier.
  // Calculate the float values represented by quantized codes 0 and 1.
  const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input);
  const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input);

  // Convert the codes for 0 and 1 into int64 output codes.
  const int64_t code_0_int64 =
      FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output);
  const int64_t code_1_int64 =
      FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output);

  // Calculate the multiplier.
  const int32 mult_int32 = static_cast<int32>(code_1_int64 - code_0_int64);

  // Broadcast the zero position and multiplier to all lanes.
  const int64x2_t code_0_64x2 = vmovq_n_s64(code_0_int64);
  const int32x2_t mult_32x2 = vmov_n_s32(mult_int32);

  int64_t i = 0;
619 | |
620 | // Use SIMD to requantize array. |
621 | for (; i < (count - 7); i += 8) { |
622 | const uint8* input_ptr = &(input->value) + i; |
623 | int32* output_ptr = &(output->value) + i; |
624 | const std::array<int32x4_t, 2> output_value = |
625 | Requantize8x8To32Neon(input_ptr, code_0_64x2, mult_32x2); |
626 | vst1q_s32(output_ptr + 0, output_value[0]); |
627 | vst1q_s32(output_ptr + 4, output_value[1]); |
628 | } |
629 | |
630 | // Requantize remaining elements in array without SIMD. |
  const int64_t lowest_quantized =
      static_cast<int64_t>(Eigen::NumTraits<qint32>::lowest());
  const int64_t highest_quantized =
      static_cast<int64_t>(Eigen::NumTraits<qint32>::highest());

  for (; i < count; ++i) {
    const int64_t input_value = static_cast<int64_t>(input[i]);
    int64_t output_value = code_0_int64 + (input_value * mult_int32);
639 | output_value = std::max(output_value, lowest_quantized); |
640 | output_value = std::min(output_value, highest_quantized); |
641 | output[i] = static_cast<int32>(output_value); |
642 | } |
643 | } |
644 | |
645 | #else |
646 | |
647 | // If SIMD implementations aren't available, then use these default reference |
648 | // versions. |
649 | template <> |
650 | inline void RequantizeManyInNewRange<qint32, quint8>( |
651 | const qint32* input, int64_t count, float min_input, float max_input, |
652 | float min_output, float max_output, quint8* output) { |
653 | RequantizeManyInNewRangeReference(input, count, min_input, max_input, |
654 | min_output, max_output, output); |
655 | } |
656 | |
657 | template <> |
658 | inline void RequantizeManyInNewRange<quint8, qint32>( |
659 | const quint8* input, int64_t count, float min_input, float max_input, |
660 | float min_output, float max_output, qint32* output) { |
661 | RequantizeManyInNewRange8To32BitReference(input, count, min_input, max_input, |
662 | min_output, max_output, output); |
663 | } |
664 | |
665 | #endif |
666 | |
667 | template <int shift> |
668 | struct int64_right_shift_op { |
669 | EIGEN_DEVICE_FUNC |
670 | EIGEN_STRONG_INLINE const int64_t operator()(const int64_t a) const { |
671 | return a >> shift; |
672 | } |
673 | }; |
674 | |
675 | // See RequantizeManyInNewRange() for a non-eigen reference implementation. |
676 | template <class T1, class T2> |
677 | inline void RequantizeManyInNewRangeUsingEigen( |
678 | const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input, |
679 | float max_input, float min_output, float max_output, Tensor* output) { |
680 | auto input_array = input.flat<T1>(); |
681 | QuantizedToFloatStruct<T1> q2f(min_input, max_input); |
682 | auto input_float = DEQUANTIZE_WITH_EIGEN(input_array, q2f); |
683 | FloatToQuantizedStruct<T2> f2q(min_output, max_output); |
684 | auto input_requantized = QUANTIZE_WITH_EIGEN(input_float, f2q, T2); |
685 | |
686 | output->flat<T2>().device(device) = input_requantized; |
687 | } |
688 | |
689 | // See RequantizeManyInNewRange() for a non-eigen reference implementation. |
690 | // |
691 | // Because converting 32-bit accumulated results down to eight bit is a common |
692 | // case, we have a specialized code path to handle it as efficiently as |
693 | // possible using only fixed-point math for the inner loop. |
694 | template <> |
695 | inline void RequantizeManyInNewRangeUsingEigen<qint32, quint8>( |
696 | const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input, |
697 | float max_input, float min_output, float max_output, Tensor* output) { |
698 | // Initially we calculate all the constants we need once, before we go into |
699 | // the inner loop. If this is updated, also update the non-Eigen version. |
700 | const int fp_shift = 16; |
701 | const float input_range = max_input - min_input; |
702 | const float output_range = max_output - min_output; |
703 | const float recip_output_range = |
704 | output_range == 0.0 ? 0.0 : (255.0 / output_range); |
705 | const float input_rezero = (min_input + max_input) / 2.0; |
  const int64_t range_scale_fp =
      output_range == 0.0 ? 0
                          : static_cast<int64_t>(255.0 * (1 << fp_shift) *
                                                 input_range / output_range);
710 | const int64_t input_offset_fp = |
711 | static_cast<int64_t>(input_rezero * recip_output_range * (1 << fp_shift)); |
712 | const int64_t output_offset_fp = |
713 | output_range == 0.0 |
714 | ? 0 |
715 | : std::lround((1 << fp_shift) * (min_output * 255.0) / output_range); |
716 | const int64_t rounding_delta = 1 << (fp_shift - 1); |
717 | |
718 | // Inside this eigen expression we just do minimal adds, multiplies, and |
719 | // shifts. It should be possible to perform all the calculations in 32-bit |
720 | // rather than 64, but that's not been implemented yet. |
721 | auto input_array = input.flat<qint32>(); |
722 | auto fp_value = ((input_array.template cast<int64_t>() * range_scale_fp) |
723 | .unaryExpr(int64_right_shift_op<32>())) + |
724 | (input_offset_fp - output_offset_fp + rounding_delta); |
725 | auto intermediate = fp_value.unaryExpr(int64_right_shift_op<fp_shift>()); |
726 | auto input_requantized = intermediate.cwiseMax(int64_t{0}) |
727 | .cwiseMin(int64_t{255}) |
728 | .template cast<int32>() |
729 | .template cast<quint8>(); |
730 | output->flat<quint8>().device(device) = input_requantized; |
731 | } |
732 | |
733 | // REQUIRES: 'result->NumElements() == input.NumElements()' |
734 | template <class T> |
735 | void FloatTensorToQuantizedInPlaceUsingEigen( |
736 | const Eigen::ThreadPoolDevice& device, const Tensor& input, float min, |
737 | float max, Tensor* result) { |
738 | DCHECK_EQ(DataTypeToEnum<T>::v(), result->dtype()); |
739 | auto flat_input = input.flat<float>(); |
740 | auto flat_result = result->flat<T>(); |
741 | DCHECK_EQ(flat_input.size(), flat_result.size()); |
742 | |
743 | FloatToQuantizedStruct<T> f2q(min, max); |
744 | flat_result.device(device) = QUANTIZE_WITH_EIGEN(flat_input, f2q, T); |
745 | } |
746 | |
747 | template <class T> |
748 | void FloatTensorToQuantizedInPlace(const Tensor& input, float min, float max, |
749 | Tensor* result) { |
750 | DCHECK_EQ(DataTypeToEnum<T>::v(), result->dtype()); |
751 | auto flat_input = input.flat<float>(); |
752 | auto flat_result = result->flat<T>(); |
  const int64_t data_size = flat_input.size();
  DCHECK_EQ(data_size, flat_result.size());
  for (int64_t i = 0; i < data_size; ++i) {
756 | flat_result(i) = FloatToQuantized<T>(flat_input(i), min, max); |
757 | } |
758 | } |
759 | |
760 | template <class T> |
761 | Tensor FloatTensorToQuantized(const Tensor& input, float min, float max) { |
762 | Tensor result(DataTypeToEnum<T>::v(), input.shape()); |
763 | FloatTensorToQuantizedInPlace<T>(input, min, max, &result); |
764 | return result; |
765 | } |
766 | |
767 | // REQUIRES: 'result->NumElements() == input.NumElements()' |
768 | template <class T> |
769 | void QuantizedTensorToFloatInPlaceUsingEigen( |
770 | const Eigen::ThreadPoolDevice& device, const Tensor& input, float min, |
771 | float max, Tensor* result) { |
772 | DCHECK_EQ(DataTypeToEnum<T>::v(), input.dtype()); |
773 | auto flat_input = input.flat<T>(); |
774 | auto flat_result = result->flat<float>(); |
  const int64_t data_size = flat_input.size();
  DCHECK_EQ(data_size, flat_result.size());
777 | |
778 | QuantizedToFloatStruct<T> q2f(min, max); |
779 | flat_result.device(device) = DEQUANTIZE_WITH_EIGEN(flat_input, q2f); |
780 | } |
781 | |
782 | // REQUIRES: 'result->NumElements() == input.NumElements()' |
783 | template <class T> |
784 | void QuantizedTensorToFloatInPlace(const Tensor& input, float min, float max, |
785 | Tensor* result) { |
786 | DCHECK_EQ(DataTypeToEnum<T>::v(), input.dtype()); |
787 | auto flat_input = input.flat<T>(); |
788 | auto flat_result = result->flat<float>(); |
  const int64_t data_size = flat_input.size();
  DCHECK_EQ(data_size, flat_result.size());
  for (int64_t i = 0; i < data_size; ++i) {
792 | flat_result(i) = QuantizedToFloat<T>(flat_input(i), min, max); |
793 | } |
794 | } |
795 | |
796 | template <class T> |
797 | Tensor QuantizedTensorToFloat(const Tensor& input, float min, float max) { |
798 | Tensor result(DT_FLOAT, input.shape()); |
799 | QuantizedTensorToFloatInPlace<T>(input, min, max, &result); |
800 | return result; |
801 | } |
802 | |
803 | void GetOutputMinAndMaxForQuantizedAdd(float input_min, float input_max, |
804 | float smaller_input_min, |
805 | float smaller_input_max, |
806 | float* output_min, float* output_max); |
807 | |
808 | // Add <input> and <smaller_input>. If <smaller_input> has fewer elements than |
809 | // <input>, then it is broadcast onto <input>. |
810 | template <typename T1, typename T2, typename T3> |
811 | void QuantizedAddUsingEigen(const Eigen::ThreadPoolDevice& device, |
812 | const Tensor& input, float input_min, |
813 | float input_max, const Tensor& smaller_input, |
814 | float smaller_input_min, float smaller_input_max, |
815 | Tensor* output, float* output_min, |
816 | float* output_max) { |
817 | const auto& input_flat = input.flat<T1>(); |
818 | const auto& smaller_input_flat = smaller_input.flat<T2>(); |
819 | auto output_flat = output->flat<T3>(); |
820 | |
821 | GetOutputMinAndMaxForQuantizedAdd(input_min, input_max, smaller_input_min, |
822 | smaller_input_max, output_min, output_max); |
823 | // To do addition properly, we need to compensate for a possibly unbalanced |
824 | // zero point in the total representation. The quantized value that |
825 | // represents the real number zero needs to be subtracted before addition to |
826 | // make sure that the identity of zero + zero = zero holds. |
827 | const T3 zero_in_total_space = |
828 | FloatToQuantized<T3>(0.0f, *output_min, *output_max); |
829 | |
830 | const int64_t input_element_count = input.NumElements(); |
831 | const int64_t smaller_input_element_count = smaller_input.NumElements(); |
832 | |
833 | QuantizedToFloatStruct<T1> input_q2f(input_min, input_max); |
834 | QuantizedToFloatStruct<T2> smaller_input_q2f(smaller_input_min, |
835 | smaller_input_max); |
836 | FloatToQuantizedStruct<T3> f2q(*output_min, *output_max); |
837 | |
838 | auto smaller_input_float = |
839 | DEQUANTIZE_WITH_EIGEN(smaller_input_flat, smaller_input_q2f); |
840 | auto smaller_input_in_total_space = |
841 | QUANTIZE_WITH_EIGEN(smaller_input_float, f2q, T3); |
842 | |
843 | auto input_float = DEQUANTIZE_WITH_EIGEN(input_flat, input_q2f); |
844 | auto input_in_total_space = QUANTIZE_WITH_EIGEN(input_float, f2q, T3); |
845 | |
846 | Eigen::array<Eigen::DenseIndex, 1> bcast; |
847 | bcast[0] = input_element_count / smaller_input_element_count; |
848 | output_flat.device(device) = |
849 | input_in_total_space + |
850 | (smaller_input_in_total_space.broadcast(bcast) + zero_in_total_space); |
851 | } |
852 | |
853 | // This is a reference implementation of the bias addition for quantized |
854 | // buffers, designed to provide a clear specification for the result we |
855 | // want. We'll want to specialize this for particular hardware, and |
856 | // probably even fuse it with matrix multiplications in a lot of cases. It's |
857 | // important to show the clamping behavior we want in particular. |
858 | template <typename T1, typename T2, typename T3> |
859 | void QuantizedAdd(const Eigen::ThreadPoolDevice& device, const Tensor& input, |
860 | float input_min, float input_max, const Tensor& smaller_input, |
861 | float smaller_input_min, float smaller_input_max, |
862 | Tensor* output, float* output_min, float* output_max) { |
863 | const auto& input_flat = input.flat<T1>(); |
864 | const auto& smaller_input_flat = smaller_input.flat<T2>(); |
865 | auto output_flat = output->flat<T3>(); |
866 | |
867 | GetOutputMinAndMaxForQuantizedAdd(input_min, input_max, smaller_input_min, |
868 | smaller_input_max, output_min, output_max); |
869 | // To do addition properly, we need to compensate for a possibly unbalanced |
870 | // zero point in the total representation. The quantized value that |
871 | // represents the real number zero needs to be subtracted before addition to |
872 | // make sure that the identity of zero + zero = zero holds. |
873 | const T3 zero_in_total_space = |
874 | FloatToQuantized<T3>(0.0f, *output_min, *output_max); |
875 | |
876 | const int64_t input_element_count = input.NumElements(); |
877 | const int64_t smaller_input_element_count = smaller_input.NumElements(); |
878 | |
879 | float total_min = *output_min; |
880 | float total_max = *output_max; |
  const int64_t how_many_iterations =
      (input_element_count / smaller_input_element_count);
  for (int64_t iteration = 0; iteration < how_many_iterations; ++iteration) {
    const int64_t offset = iteration * smaller_input_element_count;
    for (int64_t c = 0; c < smaller_input_element_count; ++c) {
      const int64_t index = (offset + c);
887 | // The two numbers we're going to add can each be in very different |
888 | // ranges (e.g. the quantized value '127' may represent very different |
889 | // real numbers in both) so we need to convert them to a common range |
890 | // before we sum them. |
891 | const T1 input_value = input_flat(index); |
892 | const T3 input_in_total_space = RequantizeInNewRange<T1, T3>( |
893 | input_value, input_min, input_max, total_min, total_max); |
894 | const T2 smaller_input_value = smaller_input_flat(c); |
895 | const T3 smaller_input_in_total_space = |
896 | RequantizeInNewRange<T2, T3>(smaller_input_value, smaller_input_min, |
897 | smaller_input_max, total_min, total_max); |
898 | const T3 total_pre = input_in_total_space + smaller_input_in_total_space; |
899 | // As noted above, we need to compensate for the offset of the actual |
900 | // zero point in the space we're operating in. |
901 | const T3 total = total_pre + zero_in_total_space; |
902 | output_flat(index) = total; |
903 | } |
904 | } |
905 | } |
906 | |
907 | // See gemmlowp/internal/multi_thread_gemm.h for the semantics of Execute. |
908 | class TensorflowGemmlowpWorkersPool { |
909 | public: |
  explicit TensorflowGemmlowpWorkersPool(thread::ThreadPool* workers)
911 | : workers_(workers) {} |
912 | |
913 | ~TensorflowGemmlowpWorkersPool() { |
914 | // This workaround ensures that all worker tasks have exited methods in the |
915 | // BlockingCounter. Without this, there is a race where the context is torn |
916 | // down while the counter is in use. |
917 | counter_to_decrement_when_ready_.Reset(0); |
918 | } |
919 | |
920 | void Execute(const std::vector<gemmlowp::Task*>& tasks) { |
921 | assert(!tasks.empty()); |
922 | assert(workers_ != nullptr); |
923 | counter_to_decrement_when_ready_.Reset(tasks.size()); |
924 | for (gemmlowp::Task* task : tasks) { |
925 | workers_->Schedule([this, task]() { |
926 | // TODO(cwhipkey): get a local_allocator from a thread local storage. |
927 | gemmlowp::Allocator local_allocator; |
928 | CHECK(task != nullptr); |
929 | task->local_allocator = &local_allocator; |
930 | task->Run(); |
931 | counter_to_decrement_when_ready_.DecrementCount(); |
932 | }); |
933 | } |
934 | counter_to_decrement_when_ready_.Wait(); |
935 | for (gemmlowp::Task* task : tasks) { |
936 | delete task; |
937 | } |
938 | } |
939 | |
940 | private: |
941 | thread::ThreadPool* const workers_; |
942 | |
943 | // The BlockingCounter used to wait for the workers. |
944 | gemmlowp::BlockingCounter counter_to_decrement_when_ready_; |
945 | |
946 | TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmlowpWorkersPool); |
947 | }; |
948 | |
949 | class TensorflowGemmContext : public gemmlowp::MultiThreadGemmContextBase { |
950 | public: |
951 | TensorflowGemmContext(int num_threads, thread::ThreadPool* workers) |
952 | : workers_pool_(workers) { |
953 | set_max_num_threads(num_threads); |
954 | } |
955 | |
956 | TensorflowGemmlowpWorkersPool* workers_pool() { return &workers_pool_; } |
957 | |
958 | private: |
959 | TensorflowGemmlowpWorkersPool workers_pool_; |
960 | |
961 | TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmContext); |
962 | }; |
963 | |
964 | } // namespace tensorflow |
965 | |
966 | #endif // TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_ |
967 | |