1 | #define FBGEMM_EXPORTS |
2 | #include <algorithm> |
3 | #include <iterator> |
4 | #include <numeric> |
5 | #include <type_traits> |
6 | |
7 | #include "fbgemm/QuantUtils.h" |
8 | |
9 | #include <cpuinfo.h> |
10 | |
11 | #include "fbgemm/Fbgemm.h" |
12 | |
13 | #include "fbgemm/Types.h" |
14 | |
15 | namespace fbgemm { |
16 | |
17 | using namespace std; |
18 | |
// Use fp16_min as the small scale cutoff because we don't want to use scales in
// fp16 subnormal range. This is to be consistent with Glow and FakeLowP
// implementation for NNPI.
// 6.1e-5f approximates the smallest normal fp16 value (2^-14 ~= 6.1035e-5).
constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f;
23 | |
24 | float TensorQuantizationParams::Min() const { |
25 | return Dequantize(0, *this); |
26 | } |
27 | |
28 | float TensorQuantizationParams::Max() const { |
29 | return Dequantize((1 << precision) - 1, *this); |
30 | } |
31 | |
// Chooses an affine quantization (scale, zero_point) that maps the real
// interval [min, max] onto the integer range [qmin, qmax].
//
// min/max: observed real-valued range; it is first extended to contain 0 so
//     that the real value 0 is exactly representable as a quantized value.
// qmin/qmax: target integer range (e.g. 0..255 for uint8).
// preserve_sparsity: when the range straddles 0, forces symmetric
//     quantization so real 0 maps to the midpoint code and zeros stay zeros.
// force_scale_power_of_two: rounds the scale to the nearest power of two
//     (enables shift-based rescaling in callers).
TensorQuantizationParams ChooseQuantizationParams(
    float min,
    float max,
    int32_t qmin,
    int32_t qmax,
    bool preserve_sparsity,
    bool force_scale_power_of_two) {
  if (min < 0 && max > 0 && preserve_sparsity) {
    // Symmetrize the range around 0: both sides share one scale, chosen large
    // enough to cover whichever side is proportionally wider.
    int symmetric_qmin = -((qmax - qmin) / 2 + 1);
    int symmetric_qmax = (qmax - qmin) / 2;
    double max_scale =
        std::max(fabs(min / symmetric_qmin), fabs(max / symmetric_qmax));
    min = max_scale * symmetric_qmin;
    max = max_scale * symmetric_qmax;
  }

  // We extend the [min, max] interval to ensure that it contains 0.
  // Otherwise, we would not meet the requirement that 0 be an exactly
  // representable value.
  min = std::min(min, 0.f);
  max = std::max(max, 0.f);

  // Use double precision for intermediate computation but use single precision
  // in final number to reflect the actual number used during quantization.
  float scale = (static_cast<double>(max) - min) / (qmax - qmin);
  // If scale is 0 or too small so its reciprocal is infinity, we arbitrary
  // adjust the scale to 0.1 . We want to avoid scale's reciprocal being
  // infinity because some of fbgemm code pre-computes scale's reciprocal to do
  // multiplication instead of division in the time critical part of code.
  if (scale == 0.0f || isinf(1.0f / scale)) {
    scale = 0.1;
  }
  assert(scale > 0);

  if (force_scale_power_of_two) {
    // Round down to a power of two for scale < 1, round up for scale >= 1.
    if (scale < 1) {
      scale = 1.0 / (1 << static_cast<int>(floor(log2(1.0 / scale))));
    } else {
      scale = 1 << static_cast<int>(ceil(log2(scale)));
    }
  }

  // Cut off small scale
  if (scale < SMALL_SCALE_THRESHOLD) {
    float org_scale = scale;
    scale = SMALL_SCALE_THRESHOLD;
    // Adjust the min and max based on the new scale
    if (min == 0.0f) {
      max = SMALL_SCALE_THRESHOLD * (qmax - qmin);
    } else if (max == 0.0f) {
      min = -SMALL_SCALE_THRESHOLD * (qmax - qmin);
    } else {
      // Stretch both endpoints proportionally so the widened range matches
      // the clamped scale.
      float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
      min *= amplifier;
      max *= amplifier;
    }
  }

  // Zero-point computation.
  // First the initial floating-point computation. The zero-point can be
  // determined from solving an affine equation for any known pair
  // (real value, corresponding quantized value).
  // We know two such pairs: (rmin, qmin) and (rmax, qmax).
  // The arithmetic error on the zero point computed from either pair
  // will be roughly machine_epsilon * (sum of absolute values of terms)
  // so we want to use the variant that adds the smaller terms.
  double zero_point_from_min = qmin - min / static_cast<double>(scale);
  double zero_point_from_max = qmax - max / static_cast<double>(scale);
  double zero_point_from_min_error =
      std::abs(qmin) + std::abs(min / static_cast<double>(scale));
  double zero_point_from_max_error =
      std::abs(qmax) + std::abs(max / static_cast<double>(scale));
  double initial_zero_point =
      zero_point_from_min_error < zero_point_from_max_error
      ? zero_point_from_min
      : zero_point_from_max;

  // Note: preserve_sparsity here means symmetric quantization.
  // for symmetric quantization, we force zero_point
  // to be a middle value between qmin and qmax.
  // If either min or max is 0, then we just use 0 as zero_point.
  if (min < 0 && max > 0 && preserve_sparsity) {
    initial_zero_point = static_cast<double>(qmin + qmax) / 2;
  }

  // Now we need to nudge the zero point to be an integer
  // (our zero points are integer, and this is motivated by the requirement
  // to be able to represent the real value "0" exactly as a quantized value,
  // which is required in multiple places, for example in Im2col with zero
  // padding).
  int32_t nudged_zero_point = 0;
  if (initial_zero_point < qmin) {
    nudged_zero_point = qmin;
  } else if (initial_zero_point > qmax) {
    nudged_zero_point = qmax;
  } else {
    nudged_zero_point = nearbyint(initial_zero_point);
  }

  TensorQuantizationParams result;
  result.scale = scale;
  result.zero_point = nudged_zero_point;
  return result;
}
136 | |
// Decomposes real_multiplier into an integer fixed-point multiplier and a
// right shift such that
//   real_multiplier ~= quantized_multiplier / 2^right_shift,
// where quantized_multiplier uses requantization_multiplier_precision bits.
// real_multiplier must be non-zero; negative multipliers skip normalization
// (the while loops only run for positive values).
void ChooseRequantizationMultiplier(
    float real_multiplier,
    int32_t* quantized_multiplier,
    int* right_shift,
    int requantization_multiplier_precision) {
  assert(real_multiplier != 0.f);

  // Assuming requantization_multiplier_precision_ = 31,
  // the default right shift is 31 when the real multiplier is already
  // in interval [1/2, 1).
  // Multiplying a 32-bit signed integer with all 31 bits except the sign bit
  // is used followed by 31-bit right shift implements multiplying with a real
  // number in [1/2, 1).
  // We want to utilize all 31 bits except the sign bit in the 32-bit signed
  // integer to get the best accuracy.
  int s = 31;

  // We want to bring the real multiplier into the interval [1/2, 1).
  // We can do so by multiplying it by two, and recording how many times
  // we multiplied by two so that we can compensate that by a right
  // shift by the same amount.
  if (real_multiplier > 0.f) {
    while (real_multiplier < 0.5f) {
      real_multiplier *= 2.f;
      s++;
    }
    while (real_multiplier > 1.f) {
      real_multiplier /= 2.f;
      s--;
    }
  }
  // Now that the real multiplier is in [1/2, 1), we convert it
  // into a fixed-point number.
  int64_t q = nearbyint(
      real_multiplier * (1ll << (requantization_multiplier_precision - 1)));
  assert(q <= (1ll << (requantization_multiplier_precision - 1)));
  // Handle the special case when the real multiplier was so close to 1
  // that its fixed-point approximation was undistinguishable from 1.
  // We handle this by dividing it by two, and remembering to decrement
  // the right shift amount.
  if (q == (1ll << (requantization_multiplier_precision - 1))) {
    q /= 2;
    s--;
  }
  assert(s >= 0);
  assert(q >= 0);
  assert(q <= numeric_limits<int32_t>::max());
  *quantized_multiplier = static_cast<int32_t>(q);
  *right_shift = s;
  assert(s < 64);
}
188 | |
189 | //////////////////////////////////////////////////////////////////////////////// |
190 | // Utility functions |
191 | |
// Scalar Quantize<T, LEGACY> over a contiguous buffer. The [0, len) range is
// split across num_threads via fbgemmPartition1D; each caller quantizes only
// its own slice. LEGACY is forwarded to the per-element Quantize kernel
// (rounding-behavior flag declared in QuantUtils.h).
#define FBGEMM_SPECIALIZED_QUANTIZE(T, LEGACY)                      \
  template <>                                                       \
  FBGEMM_API void Quantize<T, LEGACY>(                              \
      const float* src,                                             \
      T* dst,                                                       \
      const int64_t len,                                            \
      const TensorQuantizationParams& qparams,                      \
      int thread_id,                                                \
      int num_threads) {                                            \
    int64_t i_begin, i_end;                                         \
    fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \
    for (int64_t i = i_begin; i < i_end; ++i) {                     \
      dst[i] = Quantize<T, LEGACY>(src[i], qparams);                \
    }                                                               \
  }
FBGEMM_SPECIALIZED_QUANTIZE(uint16_t, true)
FBGEMM_SPECIALIZED_QUANTIZE(int16_t, true)
FBGEMM_SPECIALIZED_QUANTIZE(int32_t, true)
FBGEMM_SPECIALIZED_QUANTIZE(uint16_t, false)
FBGEMM_SPECIALIZED_QUANTIZE(int16_t, false)
FBGEMM_SPECIALIZED_QUANTIZE(int32_t, false)
#undef FBGEMM_SPECIALIZED_QUANTIZE
214 | |
// Quantize<T, LEGACY> for 8-bit types with an AVX2 fast path: when the CPU
// has AVX2 + FMA and qparams.precision == 8, each thread's slice goes through
// the vectorized QuantizeAvx2 kernel; otherwise it falls back to the scalar
// per-element Quantize.
#define FBGEMM_SPECIALIZED_QUANTIZE_AVX2(T, LEGACY)                     \
  template <>                                                           \
  FBGEMM_API void Quantize<T, LEGACY>(                                  \
      const float* src,                                                 \
      T* dst,                                                           \
      int64_t len,                                                      \
      const TensorQuantizationParams& qparams,                          \
      int thread_id,                                                    \
      int num_threads) {                                                \
    bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \
    bool fma_support = cpuinfo_has_x86_fma3();                          \
    int64_t i_begin, i_end;                                             \
    fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end);     \
    if (avx2_support && fma_support && qparams.precision == 8) {        \
      /* fast path */                                                   \
      QuantizeAvx2<T, LEGACY>(                                          \
          &src[i_begin], &dst[i_begin], i_end - i_begin, qparams);      \
    } else {                                                            \
      for (int64_t i = i_begin; i < i_end; ++i) {                       \
        dst[i] = Quantize<T, LEGACY>(src[i], qparams);                  \
      }                                                                 \
    }                                                                   \
  }

FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t, true)
FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t, true)
FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t, false)
FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t, false)
#undef FBGEMM_SPECIALIZED_QUANTIZE_AVX2
244 | |
245 | #define FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2(T) \ |
246 | template <> \ |
247 | FBGEMM_API void FusedQuantizeDequantize<T>( \ |
248 | const float* src, \ |
249 | float* dst, \ |
250 | int64_t len, \ |
251 | const TensorQuantizationParams& qparams, \ |
252 | int thread_id, \ |
253 | int num_threads, \ |
254 | float noise_ratio) { \ |
255 | bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \ |
256 | bool fma_support = cpuinfo_has_x86_fma3(); \ |
257 | int64_t i_begin, i_end; \ |
258 | fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \ |
259 | if (avx2_support && fma_support && qparams.precision == 8) { \ |
260 | /* fast path */ \ |
261 | FusedQuantizeDequantizeAvx2<T>( \ |
262 | &src[i_begin], &dst[i_begin], i_end - i_begin, qparams); \ |
263 | } else if (noise_ratio <= 0.0f) { \ |
264 | for (int64_t i = i_begin; i < i_end; ++i) { \ |
265 | dst[i] = FusedQuantizeDequantize<T>(src[i], qparams); \ |
266 | } \ |
267 | } else { \ |
268 | throw std::runtime_error("Failed to initialize cpuinfo!"); \ |
269 | } \ |
270 | } |
271 | |
272 | FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2(int8_t) |
273 | FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2(uint8_t) |
274 | #undef FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2 |
275 | |
276 | #define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(T) \ |
277 | template <> \ |
278 | FBGEMM_API void QuantizeGroupwise<T, layout_t::KCX>( \ |
279 | const float* src, \ |
280 | int N, \ |
281 | int C, \ |
282 | int X, \ |
283 | int G, \ |
284 | const float* scales, \ |
285 | const std::int32_t* zero_points, \ |
286 | T* dst) { \ |
287 | assert(C % G == 0); \ |
288 | int C_per_G = C / G; \ |
289 | for (int i = 0; i < N; ++i) { \ |
290 | for (int g = 0; g < G; ++g) { \ |
291 | float scale = scales[g]; \ |
292 | int32_t zero_point = zero_points[g]; \ |
293 | for (int c = 0; c < C / G; ++c) { \ |
294 | for (int x = 0; x < X; ++x) { \ |
295 | dst[(i * C + g * C_per_G + c) * X + x] = Quantize<T>( \ |
296 | src[(i * C + g * C_per_G + c) * X + x], \ |
297 | zero_point, \ |
298 | scale, \ |
299 | 8 * sizeof(T)); \ |
300 | } \ |
301 | } \ |
302 | } \ |
303 | } \ |
304 | } |
305 | FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(int8_t) |
306 | FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(int32_t) |
307 | #undef FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX |
308 | |
309 | template <> |
310 | FBGEMM_API void QuantizeGroupwise<uint8_t, layout_t::KCX>( |
311 | const float* src, |
312 | int K, |
313 | int C, |
314 | int X, |
315 | int G, |
316 | const float* scales, |
317 | const std::int32_t* zero_points, |
318 | uint8_t* dst) { |
319 | assert(C % G == 0); |
320 | int C_per_G = C / G; |
321 | fbgemm::TensorQuantizationParams qparams; |
322 | qparams.precision = 8 * sizeof(uint8_t); |
323 | bool takeFastPath = |
324 | cpuinfo_initialize() && fbgemmHasAvx2Support() && cpuinfo_has_x86_fma3(); |
325 | |
326 | for (int i = 0; i < K; ++i) { |
327 | for (int g = 0; g < G; ++g) { |
328 | qparams.scale = scales[g]; |
329 | qparams.zero_point = zero_points[g]; |
330 | if (takeFastPath) { |
331 | QuantizeAvx2( |
332 | src + (i * C + g * C_per_G) * X, |
333 | dst + (i * C + g * C_per_G) * X, |
334 | C_per_G * X, |
335 | qparams); |
336 | } else { |
337 | for (int c = 0; c < C / G; ++c) { |
338 | for (int x = 0; x < X; ++x) { |
339 | dst[(i * C + g * C_per_G + c) * X + x] = Quantize<uint8_t>( |
340 | src[(i * C + g * C_per_G + c) * X + x], |
341 | qparams.zero_point, |
342 | qparams.scale, |
343 | qparams.precision); |
344 | } |
345 | } |
346 | } |
347 | } |
348 | } |
349 | } |
350 | |
351 | #define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(T) \ |
352 | template <> \ |
353 | FBGEMM_API void QuantizeGroupwise<T, layout_t::KXC>( \ |
354 | const float* src, \ |
355 | int K, \ |
356 | int C, \ |
357 | int X, \ |
358 | int G, \ |
359 | const float* scales, \ |
360 | const std::int32_t* zero_points, \ |
361 | T* dst) { \ |
362 | assert(C % G == 0); \ |
363 | int C_per_G = C / G; \ |
364 | for (int i = 0; i < K; ++i) { \ |
365 | for (int x = 0; x < X; ++x) { \ |
366 | for (int g = 0; g < G; ++g) { \ |
367 | float scale = scales[g]; \ |
368 | int32_t zero_point = zero_points[g]; \ |
369 | for (int c = 0; c < C / G; ++c) { \ |
370 | dst[(i * X + x) * C + g * C_per_G + c] = Quantize<T>( \ |
371 | src[(i * X + x) * C + g * C_per_G + c], \ |
372 | zero_point, \ |
373 | scale, \ |
374 | 8 * sizeof(T)); \ |
375 | } \ |
376 | } \ |
377 | } \ |
378 | } \ |
379 | } |
380 | FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(int8_t) |
381 | FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(uint8_t) |
382 | FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(int32_t) |
383 | #undef FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC |
384 | |
385 | //////////////////////////////////////////////////////////////////////////////// |
386 | // Requantization (pure fixed-point) |
387 | |
// Computes round(a * b / 2^right_shift) using 64-bit intermediate arithmetic.
// The product is formed in 64 bits (no overflow for 32-bit inputs), a nudge
// of 2^(right_shift - 1) is added, and the sum is shifted right
// arithmetically; ties round toward +infinity.
//
// Fix: guard right_shift == 0 — the nudge expression `1ll << (right_shift - 1)`
// would be undefined behavior (shift by -1); a zero shift needs no rounding.
//
// NOTE(review): despite the name, this function does not saturate the result
// to a 32-bit range — callers receive the full 64-bit value.
int64_t SaturatingRoundingMulWithShift(int32_t a, int32_t b, int right_shift) {
  int64_t a_64(a);
  int64_t b_64(b);
  int64_t ab_64 = a_64 * b_64;

  if (right_shift <= 0) {
    return ab_64;
  }
  int64_t nudge = 1ll << (right_shift - 1);
  return (ab_64 + nudge) >> right_shift;
}
396 | |
// Scalar Requantize<T> over a buffer: converts int32 accumulators in src to
// the narrower type T using the given requantization parameters. The [0, len)
// range is split across num_threads; each caller handles only its slice.
#define FBGEMM_SPECIALIZED_REQUANTIZE(T)                            \
  template <>                                                       \
  FBGEMM_API void Requantize<T>(                                    \
      const int32_t* src,                                           \
      T* dst,                                                       \
      const int64_t len,                                            \
      const RequantizationParams& params,                           \
      int thread_id,                                                \
      int num_threads) {                                            \
    int64_t i_begin, i_end;                                         \
    fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \
    for (int64_t i = i_begin; i < i_end; ++i) {                     \
      dst[i] = Requantize<T>(src[i], params);                       \
    }                                                               \
  }
FBGEMM_SPECIALIZED_REQUANTIZE(uint16_t)
FBGEMM_SPECIALIZED_REQUANTIZE(int32_t)
#undef FBGEMM_SPECIALIZED_REQUANTIZE
415 | |
416 | template <> |
417 | FBGEMM_API void Requantize<uint8_t>( |
418 | const int32_t* src, |
419 | uint8_t* dst, |
420 | const int64_t len, |
421 | const RequantizationParams& params, |
422 | int thread_id, |
423 | int num_threads) { |
424 | int64_t i_begin, i_end; |
425 | fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); |
426 | if (params.target_qparams.precision == 8 && cpuinfo_initialize() && |
427 | fbgemmHasAvx2Support()) { |
428 | RequantizeAvx2(&src[i_begin], &dst[i_begin], i_end - i_begin, params); |
429 | } else { |
430 | for (int64_t i = i_begin; i < i_end; ++i) { |
431 | dst[i] = Requantize<uint8_t>(src[i], params); |
432 | } |
433 | } |
434 | } |
435 | |
436 | template <typename T> |
437 | FBGEMM_API void RequantizeFixedPoint( |
438 | const std::int32_t* src, |
439 | T* dst, |
440 | int64_t len, |
441 | const RequantizationParams& params, |
442 | int thread_id, |
443 | int num_threads) { |
444 | int64_t i_begin, i_end; |
445 | fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); |
446 | if (std::is_same<T, uint8_t>::value && params.target_qparams.precision == 8 && |
447 | cpuinfo_initialize() && fbgemmHasAvx2Support()) { |
448 | RequantizeFixedPointAvx2( |
449 | &src[i_begin], &dst[i_begin], i_end - i_begin, params); |
450 | } else { |
451 | for (int64_t i = i_begin; i < i_end; ++i) { |
452 | dst[i] = RequantizeFixedPoint<T>(src[i], params); |
453 | } |
454 | } |
455 | } |
456 | |
// Scalar RequantizeFixedPoint<T> specializations for types with no
// vectorized kernel; work is split across num_threads as above.
#define FBGEMM_SPECIALIZED_REQUANTIZE(T)                            \
  template <>                                                       \
  FBGEMM_API void RequantizeFixedPoint<T>(                          \
      const int32_t* src,                                           \
      T* dst,                                                       \
      const int64_t len,                                            \
      const RequantizationParams& params,                           \
      int thread_id,                                                \
      int num_threads) {                                            \
    int64_t i_begin, i_end;                                         \
    fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \
    for (int64_t i = i_begin; i < i_end; ++i) {                     \
      dst[i] = RequantizeFixedPoint<T>(src[i], params);             \
    }                                                               \
  }
FBGEMM_SPECIALIZED_REQUANTIZE(uint16_t)
FBGEMM_SPECIALIZED_REQUANTIZE(int32_t)
#undef FBGEMM_SPECIALIZED_REQUANTIZE
475 | |
476 | template <> |
477 | FBGEMM_API void RequantizeFixedPoint<uint8_t>( |
478 | const int32_t* src, |
479 | uint8_t* dst, |
480 | const int64_t len, |
481 | const RequantizationParams& params, |
482 | int thread_id, |
483 | int num_threads) { |
484 | int64_t i_begin, i_end; |
485 | fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); |
486 | |
487 | if (params.target_qparams.precision == 8 && cpuinfo_initialize() && |
488 | fbgemmHasAvx2Support()) { |
489 | RequantizeFixedPointAvx2( |
490 | &src[i_begin], &dst[i_begin], i_end - i_begin, params); |
491 | } else { |
492 | for (int64_t i = i_begin; i < i_end; ++i) { |
493 | dst[i] = RequantizeFixedPoint<uint8_t>(src[i], params); |
494 | } |
495 | } |
496 | } |
497 | |
// Reference (scalar) implementation of fused n-bit rowwise quantization.
//
// Each output row stores ceil(input_columns / num_elem_per_byte) bytes of
// packed bit_rate-bit codes followed by two float16 values: scale, then bias
// (the row minimum). bit_rate is expected to divide 8, so that
// num_elem_per_byte = 8 / bit_rate codes pack into one byte — assumption
// inherited from the integer division below; TODO confirm callers only pass
// 2, 4 or 8.
template <typename InputType>
void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef(
    int bit_rate,
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
  static_assert(
      std::is_same<InputType, float>() || std::is_same<InputType, float16>(),
      "Only float and float16 types are allowed." );
  int num_elem_per_byte = 8 / bit_rate;
  // Row stride of the output: packed data bytes plus the fp16 scale/bias pair.
  int output_columns =
      (input_columns + num_elem_per_byte - 1) / num_elem_per_byte +
      2 * sizeof(float16);
  std::vector<float> input_row_float(input_columns);
  for (size_t row = 0; row < input_rows; ++row) {
    const InputType* input_row = input + row * input_columns;
    std::uint8_t* output_row = output + row * output_columns;
    // scale/bias live immediately after the packed data bytes of this row.
    float16* output_row_scale_bias = reinterpret_cast<float16*>(
        output_row +
        (input_columns + num_elem_per_byte - 1) / num_elem_per_byte);

    // NOTE: this can be optimized, however we don't care much about performance
    // for reference implementation.
    for (int col = 0; col < input_columns; ++col) {
      if (std::is_same<InputType, float>()) {
        input_row_float[col] = input_row[col];
      } else {
        input_row_float[col] = cpu_half2float(input_row[col]);
      }
    }

    float minimum_element =
        *std::min_element(input_row_float.begin(), input_row_float.end());
    float maximum_element =
        *std::max_element(input_row_float.begin(), input_row_float.end());
    // Truncate since bias will be represented by fp16. Keep higher precision
    // max untouched.
    float16 minimum_element_fp16 = cpu_float2half_rn(minimum_element);
    minimum_element = cpu_half2float(minimum_element_fp16);
    const float range = maximum_element - minimum_element;

    // Round the scale through fp16 so the value used for quantization matches
    // the value actually stored with the row.
    float scale = range == 0 ? 1.0f : range / ((1 << bit_rate) - 1);
    float16 scale_fp16 = cpu_float2half_rn(scale);
    scale = cpu_half2float(scale_fp16);
    if (scale == 0) {
      // Corner case handling when maximum_element == minimum_element
      // Any scale would work because X - minimum_element will be 0 for all X
      scale = 1.0f;
    }
    float inverse_scale = 1.0f / scale;
    if (std::isinf(inverse_scale)) {
      scale = 1.0f;
      inverse_scale = 1.0f;
    }

    output_row_scale_bias[0] = cpu_float2half_rn(scale);
    output_row_scale_bias[1] = minimum_element_fp16;
    for (int col = 0; col < input_columns; ++col) {
      float X = input_row_float[col];
      // Quantize and clamp to [0, 2^bit_rate - 1].
      std::uint8_t quantized = std::max(
          0,
          std::min<int>(
              std::lrintf((X - minimum_element) * inverse_scale),
              (1 << bit_rate) - 1));
      // First code of a byte assigns; later codes OR into their bit offset.
      if (col % num_elem_per_byte == 0) {
        output_row[col / num_elem_per_byte] = quantized;
      } else {
        output_row[col / num_elem_per_byte] |=
            (quantized << ((col % num_elem_per_byte) * bit_rate));
      }
    }
  } // for each row
}
572 | |
// Quantizes float or float16 input to the fused n-bit rowwise format, using
// the AVX2 kernels for bit rates 2, 4 and 8 when available and the reference
// implementation otherwise.
template <typename InputType>
void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
    int bit_rate,
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
  // Currently we can only quantize if the number of input columns
  // is a multiple of the number of elements per byte (rows must pack into
  // whole bytes).

  int num_elem_per_byte = 8 / bit_rate;
  if (input_columns % num_elem_per_byte != 0) {
    throw std::runtime_error("Unsupported number of columns" );
  }

  if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
    switch (bit_rate) {
      case 2:
        FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<InputType, 2>(
            input, input_rows, input_columns, output);
        break;
      case 4:
        FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<InputType, 4>(
            input, input_rows, input_columns, output);
        break;
      case 8:
        FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<InputType, 8>(
            input, input_rows, input_columns, output);
        break;
      default:
        // No specialized kernel for this bit rate.
        FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<InputType>(
            bit_rate, input, input_rows, input_columns, output);
    }
  } else {
    FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<InputType>(
        bit_rate, input, input_rows, input_columns, output);
  }
}
611 | |
// Reference (scalar) implementation of fused 8-bit rowwise quantization.
// Each output row stores input_columns uint8 codes followed by two floats:
// scale (range / 255) and bias (the row minimum).
template <typename InputType>
void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef(
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
  // Keeps inverse_scale finite when the row is constant (range == 0).
  constexpr float kEpsilon = 1e-8f;

  int output_columns = input_columns + 2 * sizeof(float);
  std::vector<float> input_row_float(input_columns);
  for (size_t row = 0; row < input_rows; ++row) {
    const InputType* input_row = input + row * input_columns;
    std::uint8_t* output_row = output + row * output_columns;
    // scale/bias live immediately after the data bytes of this row.
    float* output_row_scale_bias =
        reinterpret_cast<float*>(output_row + input_columns);

    // Widen float16 inputs to float once per row.
    for (int col = 0; col < input_columns; ++col) {
      if (std::is_same<InputType, float>()) {
        input_row_float[col] = input_row[col];
      } else {
        input_row_float[col] = cpu_half2float(input_row[col]);
      }
    }

    float minimum_element =
        *std::min_element(input_row_float.begin(), input_row_float.end());
    float maximum_element =
        *std::max_element(input_row_float.begin(), input_row_float.end());
    float range = maximum_element - minimum_element;

    output_row_scale_bias[0] = range / 255.0f;
    output_row_scale_bias[1] = minimum_element;
    const auto inverse_scale = 255.0f / (range + kEpsilon);
    for (int col = 0; col < input_columns; ++col) {
      output_row[col] =
          std::lrintf((input_row_float[col] - minimum_element) * inverse_scale);
    }
  } // for each row
}
651 | |
652 | template <typename InputType> |
653 | void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat( |
654 | const InputType* input, |
655 | size_t input_rows, |
656 | int input_columns, |
657 | std::uint8_t* output) { |
658 | if (cpuinfo_initialize() && fbgemmHasAvx2Support()) { |
659 | FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2<InputType>( |
660 | input, input_rows, input_columns, output); |
661 | } else { |
662 | FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef<InputType>( |
663 | input, input_rows, input_columns, output); |
664 | } |
665 | } |
666 | |
// Reference (scalar) implementation of fused n-bit rowwise dequantization.
// Each input row holds packed bit_rate-bit codes followed by two float16
// values (scale, then bias); output is out = scale * code + bias, as float
// or float16.
// NOTE(review): output_columns assumes every input row is fully packed, i.e.
// the original column count was a multiple of num_elem_per_byte — consistent
// with the check in the quantize dispatch; confirm for other producers.
template <typename OutputType>
void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
    int bit_rate,
    const uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output) {
  static_assert(
      std::is_same<OutputType, float>() || std::is_same<OutputType, float16>(),
      "Only float and float16 types are allowed." );
  int num_elem_per_byte = 8 / bit_rate;
  int output_columns =
      (input_columns - 2 * sizeof(float16)) * num_elem_per_byte;

  for (size_t row = 0; row < input_rows; ++row) {
    const std::uint8_t* input_row = input + row * input_columns;
    // scale/bias are stored after the packed data bytes of the row.
    const float16* input_row_scale_bias = reinterpret_cast<const float16*>(
        input_row +
        (output_columns + num_elem_per_byte - 1) / num_elem_per_byte);
    float scale = cpu_half2float(input_row_scale_bias[0]);
    float bias = cpu_half2float(input_row_scale_bias[1]);
    OutputType* output_row = output + row * output_columns;

    for (int col = 0; col < output_columns; ++col) {
      // Extract the bit_rate-bit code for this column from its byte.
      std::uint8_t quantized = input_row[col / num_elem_per_byte];
      quantized >>= (col % num_elem_per_byte) * bit_rate;
      quantized &= (1 << bit_rate) - 1;
      float output_value = scale * quantized + bias;
      if (std::is_same<OutputType, float>()) {
        output_row[col] = output_value;
      } else {
        output_row[col] = cpu_float2half_rn(output_value);
      }
    }
  }
}
703 | |
704 | template <typename OutputType> |
705 | void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf( |
706 | int bit_rate, |
707 | const uint8_t* input, |
708 | size_t input_rows, |
709 | int input_columns, |
710 | OutputType* output) { |
711 | if (cpuinfo_initialize() && fbgemmHasAvx2Support()) { |
712 | switch (bit_rate) { |
713 | case 2: |
714 | FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<OutputType, 2>( |
715 | input, input_rows, input_columns, output); |
716 | break; |
717 | case 4: |
718 | FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<OutputType, 4>( |
719 | input, input_rows, input_columns, output); |
720 | break; |
721 | case 8: |
722 | FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<OutputType, 8>( |
723 | input, input_rows, input_columns, output); |
724 | break; |
725 | default: |
726 | FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<OutputType>( |
727 | bit_rate, input, input_rows, input_columns, output); |
728 | } |
729 | } else { |
730 | FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<OutputType>( |
731 | bit_rate, input, input_rows, input_columns, output); |
732 | } |
733 | } |
734 | |
// Reference (scalar) implementation of fused 8-bit rowwise dequantization:
// out = code * scale + bias, where scale and bias are the two floats stored
// after the data bytes of each input row.
template <typename OutputType>
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
    const std::uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output) {
  // Each input row is output_columns codes plus a trailing (scale, bias) pair.
  int output_columns = input_columns - 2 * sizeof(float);

  for (size_t row = 0; row < input_rows; ++row) {
    const std::uint8_t* input_row = input + row * input_columns;
    const float* input_row_scale_bias =
        reinterpret_cast<const float*>(input_row + output_columns);
    OutputType* output_row = output + row * output_columns;

    for (int col = 0; col < output_columns; ++col) {
      float output_value =
          input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1];
      if (std::is_same<OutputType, float>()) {
        output_row[col] = output_value;
      } else {
        output_row[col] = cpu_float2half_rn(output_value);
      }
    }
  }
}
760 | |
761 | template <typename OutputType> |
762 | void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( |
763 | const std::uint8_t* input, |
764 | size_t input_rows, |
765 | int input_columns, |
766 | OutputType* output) { |
767 | if (cpuinfo_initialize() && fbgemmHasAvx2Support()) { |
768 | Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2<OutputType>( |
769 | input, input_rows, input_columns, output); |
770 | } else { |
771 | Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<OutputType>( |
772 | input, input_rows, input_columns, output); |
773 | } |
774 | } |
775 | |
// Explicit instantiations of the rowwise (de)quantization entry points above
// for float and float16, so their definitions are emitted (and exported) from
// this translation unit.
#define INSTANTIATE_QuantizationFunctions(type)                              \
  template FBGEMM_API void                                                   \
  FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<type>(                     \
      int bit_rate,                                                          \
      const type* input,                                                     \
      size_t input_rows,                                                     \
      int input_columns,                                                     \
      std::uint8_t* output);                                                 \
  template FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<type>( \
      int bit_rate,                                                          \
      const type* input,                                                     \
      size_t input_rows,                                                     \
      int input_columns,                                                     \
      std::uint8_t* output);                                                 \
  template FBGEMM_API void                                                   \
  FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<type>(                     \
      int bit_rate,                                                          \
      const uint8_t* input,                                                  \
      size_t input_rows,                                                     \
      int input_columns,                                                     \
      type* output);                                                         \
  template FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf<type>( \
      int bit_rate,                                                          \
      const uint8_t* input,                                                  \
      size_t input_rows,                                                     \
      int input_columns,                                                     \
      type* output);                                                         \
  template FBGEMM_API void                                                   \
  FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef<type>(                    \
      const type* input,                                                     \
      size_t input_rows,                                                     \
      int input_columns,                                                     \
      std::uint8_t* output);                                                 \
  template FBGEMM_API void                                                   \
  FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<type>(                       \
      const type* input,                                                     \
      size_t input_rows,                                                     \
      int input_columns,                                                     \
      std::uint8_t* output);                                                 \
  template FBGEMM_API void                                                   \
  Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<type>(                    \
      const uint8_t* input,                                                  \
      size_t input_rows,                                                     \
      int input_columns,                                                     \
      type* output);                                                         \
  template FBGEMM_API void                                                   \
  Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<type>(                       \
      const uint8_t* input,                                                  \
      size_t input_rows,                                                     \
      int input_columns,                                                     \
      type* output);

// clang-format off
INSTANTIATE_QuantizationFunctions(float)
INSTANTIATE_QuantizationFunctions(float16)
// clang-format on

#undef INSTANTIATE_QuantizationFunctions
834 | |
835 | } // namespace fbgemm |
836 | |