#define FBGEMM_EXPORTS
#include <algorithm>
#include <iterator>
#include <numeric>
#include <type_traits>

#include "fbgemm/QuantUtils.h"

#include <cpuinfo.h>

#include "fbgemm/Fbgemm.h"

#include "fbgemm/Types.h"

namespace fbgemm {

using namespace std;

// Use fp16_min as the small scale cutoff because we don't want to use scales in
// fp16 subnormal range. This is to be consistent with Glow and FakeLowP
// implementation for NNPI.
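// (6.1e-5 sits just below 2^-14 ~= 6.10352e-5, the smallest normal fp16
// value.)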
constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f;

float TensorQuantizationParams::Min() const {
  return Dequantize(0, *this);
}

float TensorQuantizationParams::Max() const {
  return Dequantize((1 << precision) - 1, *this);
}

TensorQuantizationParams ChooseQuantizationParams(
    float min,
    float max,
    int32_t qmin,
    int32_t qmax,
    bool preserve_sparsity,
    bool force_scale_power_of_two) {
  if (min < 0 && max > 0 && preserve_sparsity) {
    int symmetric_qmin = -((qmax - qmin) / 2 + 1);
    int symmetric_qmax = (qmax - qmin) / 2;
    double max_scale =
        std::max(fabs(min / symmetric_qmin), fabs(max / symmetric_qmax));
    min = max_scale * symmetric_qmin;
    max = max_scale * symmetric_qmax;
  }

  // We extend the [min, max] interval to ensure that it contains 0.
  // Otherwise, we would not meet the requirement that 0 be an exactly
  // representable value.
  min = std::min(min, 0.f);
  max = std::max(max, 0.f);

  // Use double precision for intermediate computation but use single precision
  // in the final number to reflect the actual number used during quantization.
  float scale = (static_cast<double>(max) - min) / (qmax - qmin);
  // If the scale is 0 or so small that its reciprocal is infinity, we
  // arbitrarily adjust the scale to 0.1. We want to avoid the scale's
  // reciprocal being infinity because some of the fbgemm code pre-computes the
  // scale's reciprocal to do multiplication instead of division in the
  // time-critical part of the code.
  if (scale == 0.0f || isinf(1.0f / scale)) {
    scale = 0.1;
  }
  assert(scale > 0);

  if (force_scale_power_of_two) {
    if (scale < 1) {
      scale = 1.0 / (1 << static_cast<int>(floor(log2(1.0 / scale))));
    } else {
      scale = 1 << static_cast<int>(ceil(log2(scale)));
    }
  }

  // Cut off small scale
  if (scale < SMALL_SCALE_THRESHOLD) {
    float org_scale = scale;
    scale = SMALL_SCALE_THRESHOLD;
    // Adjust the min and max based on the new scale
    if (min == 0.0f) {
      max = SMALL_SCALE_THRESHOLD * (qmax - qmin);
    } else if (max == 0.0f) {
      min = -SMALL_SCALE_THRESHOLD * (qmax - qmin);
    } else {
      float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
      min *= amplifier;
      max *= amplifier;
    }
  }

  // Zero-point computation.
  // First the initial floating-point computation. The zero-point can be
  // determined from solving an affine equation for any known pair
  // (real value, corresponding quantized value).
  // We know two such pairs: (rmin, qmin) and (rmax, qmax).
  // The arithmetic error on the zero point computed from either pair
  // will be roughly machine_epsilon * (sum of absolute values of terms)
  // so we want to use the variant that adds the smaller terms.
  double zero_point_from_min = qmin - min / static_cast<double>(scale);
  double zero_point_from_max = qmax - max / static_cast<double>(scale);
  double zero_point_from_min_error =
      std::abs(qmin) + std::abs(min / static_cast<double>(scale));
  double zero_point_from_max_error =
      std::abs(qmax) + std::abs(max / static_cast<double>(scale));
  double initial_zero_point =
      zero_point_from_min_error < zero_point_from_max_error
      ? zero_point_from_min
      : zero_point_from_max;

  // Note: preserve_sparsity here means symmetric quantization.
  // For symmetric quantization, we force zero_point to be the middle value
  // between qmin and qmax.
  // If either min or max is 0, then we just use 0 as zero_point.
  if (min < 0 && max > 0 && preserve_sparsity) {
    initial_zero_point = static_cast<double>(qmin + qmax) / 2;
  }

  // Now we need to nudge the zero point to be an integer
  // (our zero points are integer, and this is motivated by the requirement
  // to be able to represent the real value "0" exactly as a quantized value,
  // which is required in multiple places, for example in Im2col with zero
  // padding).
  int32_t nudged_zero_point = 0;
  if (initial_zero_point < qmin) {
    nudged_zero_point = qmin;
  } else if (initial_zero_point > qmax) {
    nudged_zero_point = qmax;
  } else {
    nudged_zero_point = nearbyint(initial_zero_point);
  }

  TensorQuantizationParams result;
  result.scale = scale;
  result.zero_point = nudged_zero_point;
  return result;
}
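
// Worked example (illustrative, not part of the library): for min = -1.0f,
// max = 2.0f, qmin = 0, qmax = 255, preserve_sparsity = false and
// force_scale_power_of_two = false, the scale is (2.0 - (-1.0)) / 255
// ~= 0.011765. zero_point_from_min = 0 - (-1.0) / scale ~= 85 has the smaller
// error term, so the zero point nudges to 85, and the real value 0.0f maps
// exactly to the quantized value 85.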

void ChooseRequantizationMultiplier(
    float real_multiplier,
    int32_t* quantized_multiplier,
    int* right_shift,
    int requantization_multiplier_precision) {
  assert(real_multiplier != 0.f);

  // Assuming requantization_multiplier_precision = 31,
  // the default right shift is 31 when the real multiplier is already
  // in the interval [1/2, 1).
  // Multiplying by a 32-bit signed integer that uses all 31 bits except the
  // sign bit, followed by a 31-bit right shift, implements multiplication by a
  // real number in [1/2, 1).
  // We want to utilize all 31 bits except the sign bit of the 32-bit signed
  // integer to get the best accuracy.
  int s = 31;

  // We want to bring the real multiplier into the interval [1/2, 1).
  // We can do so by multiplying it by two, and recording how many times
  // we multiplied by two so that we can compensate that by a right
  // shift by the same amount.
  if (real_multiplier > 0.f) {
    while (real_multiplier < 0.5f) {
      real_multiplier *= 2.f;
      s++;
    }
    while (real_multiplier > 1.f) {
      real_multiplier /= 2.f;
      s--;
    }
  }
  // Now that the real multiplier is in [1/2, 1), we convert it
  // into a fixed-point number.
  int64_t q = nearbyint(
      real_multiplier * (1ll << (requantization_multiplier_precision - 1)));
  assert(q <= (1ll << (requantization_multiplier_precision - 1)));
  // Handle the special case when the real multiplier was so close to 1
  // that its fixed-point approximation was indistinguishable from 1.
  // We handle this by dividing it by two, and remembering to decrement
  // the right shift amount.
  if (q == (1ll << (requantization_multiplier_precision - 1))) {
    q /= 2;
    s--;
  }
  assert(s >= 0);
  assert(q >= 0);
  assert(q <= numeric_limits<int32_t>::max());
  *quantized_multiplier = static_cast<int32_t>(q);
  *right_shift = s;
  assert(s < 64);
}
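
// Worked example (illustrative, not part of the library): with
// requantization_multiplier_precision = 31 and real_multiplier = 0.1875f, the
// multiplier is doubled twice to reach 0.75 in [1/2, 1), so s = 31 + 2 = 33
// and q = nearbyint(0.75 * 2^30) = 805306368; the function returns
// quantized_multiplier = 805306368 and right_shift = 33.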

////////////////////////////////////////////////////////////////////////////////
// Utility functions

#define FBGEMM_SPECIALIZED_QUANTIZE(T, LEGACY) \
  template <> \
  FBGEMM_API void Quantize<T, LEGACY>( \
      const float* src, \
      T* dst, \
      const int64_t len, \
      const TensorQuantizationParams& qparams, \
      int thread_id, \
      int num_threads) { \
    int64_t i_begin, i_end; \
    fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \
    for (int64_t i = i_begin; i < i_end; ++i) { \
      dst[i] = Quantize<T, LEGACY>(src[i], qparams); \
    } \
  }
FBGEMM_SPECIALIZED_QUANTIZE(uint16_t, true)
FBGEMM_SPECIALIZED_QUANTIZE(int16_t, true)
FBGEMM_SPECIALIZED_QUANTIZE(int32_t, true)
FBGEMM_SPECIALIZED_QUANTIZE(uint16_t, false)
FBGEMM_SPECIALIZED_QUANTIZE(int16_t, false)
FBGEMM_SPECIALIZED_QUANTIZE(int32_t, false)
#undef FBGEMM_SPECIALIZED_QUANTIZE

#define FBGEMM_SPECIALIZED_QUANTIZE_AVX2(T, LEGACY) \
  template <> \
  FBGEMM_API void Quantize<T, LEGACY>( \
      const float* src, \
      T* dst, \
      int64_t len, \
      const TensorQuantizationParams& qparams, \
      int thread_id, \
      int num_threads) { \
    bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \
    bool fma_support = cpuinfo_has_x86_fma3(); \
    int64_t i_begin, i_end; \
    fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \
    if (avx2_support && fma_support && qparams.precision == 8) { \
      /* fast path */ \
      QuantizeAvx2<T, LEGACY>( \
          &src[i_begin], &dst[i_begin], i_end - i_begin, qparams); \
    } else { \
      for (int64_t i = i_begin; i < i_end; ++i) { \
        dst[i] = Quantize<T, LEGACY>(src[i], qparams); \
      } \
    } \
  }

FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t, true)
FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t, true)
FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t, false)
FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t, false)
#undef FBGEMM_SPECIALIZED_QUANTIZE_AVX2
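
// Example usage (illustrative sketch, not part of the library): quantize a
// float buffer to uint8 on a single thread, assuming qparams.scale,
// qparams.zero_point and qparams.precision (= 8) have already been filled in:
//   std::vector<float> src = {0.f, 0.5f, 1.f};
//   std::vector<uint8_t> dst(src.size());
//   Quantize<uint8_t, true>(
//       src.data(), dst.data(), src.size(), qparams,
//       /*thread_id=*/0, /*num_threads=*/1);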

#define FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2(T) \
  template <> \
  FBGEMM_API void FusedQuantizeDequantize<T>( \
      const float* src, \
      float* dst, \
      int64_t len, \
      const TensorQuantizationParams& qparams, \
      int thread_id, \
      int num_threads, \
      float noise_ratio) { \
    bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \
    bool fma_support = cpuinfo_has_x86_fma3(); \
    int64_t i_begin, i_end; \
    fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \
    if (avx2_support && fma_support && qparams.precision == 8) { \
      /* fast path */ \
      FusedQuantizeDequantizeAvx2<T>( \
          &src[i_begin], &dst[i_begin], i_end - i_begin, qparams); \
    } else if (noise_ratio <= 0.0f) { \
      for (int64_t i = i_begin; i < i_end; ++i) { \
        dst[i] = FusedQuantizeDequantize<T>(src[i], qparams); \
      } \
    } else { \
      throw std::runtime_error("Failed to initialize cpuinfo!"); \
    } \
  }

FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2(int8_t)
FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2(uint8_t)
#undef FBGEMM_SPECIALIZED_FUSED_QUANTIZE_DEQUANTIZE_AVX2

#define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(T) \
  template <> \
  FBGEMM_API void QuantizeGroupwise<T, layout_t::KCX>( \
      const float* src, \
      int N, \
      int C, \
      int X, \
      int G, \
      const float* scales, \
      const std::int32_t* zero_points, \
      T* dst) { \
    assert(C % G == 0); \
    int C_per_G = C / G; \
    for (int i = 0; i < N; ++i) { \
      for (int g = 0; g < G; ++g) { \
        float scale = scales[g]; \
        int32_t zero_point = zero_points[g]; \
        for (int c = 0; c < C / G; ++c) { \
          for (int x = 0; x < X; ++x) { \
            dst[(i * C + g * C_per_G + c) * X + x] = Quantize<T>( \
                src[(i * C + g * C_per_G + c) * X + x], \
                zero_point, \
                scale, \
                8 * sizeof(T)); \
          } \
        } \
      } \
    } \
  }
FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(int8_t)
FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(int32_t)
#undef FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX
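
// Layout note (illustrative): in KCX layout with, e.g., N = 1, C = 4, G = 2,
// X = 3, channels 0-1 are quantized with scales[0]/zero_points[0] and channels
// 2-3 with scales[1]/zero_points[1]; element (n, c, x) is stored at offset
// (n * C + c) * X + x.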

template <>
FBGEMM_API void QuantizeGroupwise<uint8_t, layout_t::KCX>(
    const float* src,
    int K,
    int C,
    int X,
    int G,
    const float* scales,
    const std::int32_t* zero_points,
    uint8_t* dst) {
  assert(C % G == 0);
  int C_per_G = C / G;
  fbgemm::TensorQuantizationParams qparams;
  qparams.precision = 8 * sizeof(uint8_t);
  bool takeFastPath =
      cpuinfo_initialize() && fbgemmHasAvx2Support() && cpuinfo_has_x86_fma3();

  for (int i = 0; i < K; ++i) {
    for (int g = 0; g < G; ++g) {
      qparams.scale = scales[g];
      qparams.zero_point = zero_points[g];
      if (takeFastPath) {
        QuantizeAvx2(
            src + (i * C + g * C_per_G) * X,
            dst + (i * C + g * C_per_G) * X,
            C_per_G * X,
            qparams);
      } else {
        for (int c = 0; c < C / G; ++c) {
          for (int x = 0; x < X; ++x) {
            dst[(i * C + g * C_per_G + c) * X + x] = Quantize<uint8_t>(
                src[(i * C + g * C_per_G + c) * X + x],
                qparams.zero_point,
                qparams.scale,
                qparams.precision);
          }
        }
      }
    }
  }
}

#define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(T) \
  template <> \
  FBGEMM_API void QuantizeGroupwise<T, layout_t::KXC>( \
      const float* src, \
      int K, \
      int C, \
      int X, \
      int G, \
      const float* scales, \
      const std::int32_t* zero_points, \
      T* dst) { \
    assert(C % G == 0); \
    int C_per_G = C / G; \
    for (int i = 0; i < K; ++i) { \
      for (int x = 0; x < X; ++x) { \
        for (int g = 0; g < G; ++g) { \
          float scale = scales[g]; \
          int32_t zero_point = zero_points[g]; \
          for (int c = 0; c < C / G; ++c) { \
            dst[(i * X + x) * C + g * C_per_G + c] = Quantize<T>( \
                src[(i * X + x) * C + g * C_per_G + c], \
                zero_point, \
                scale, \
                8 * sizeof(T)); \
          } \
        } \
      } \
    } \
  }
FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(int8_t)
FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(uint8_t)
FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC(int32_t)
#undef FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKXC

////////////////////////////////////////////////////////////////////////////////
// Requantization (pure fixed-point)

int64_t SaturatingRoundingMulWithShift(int32_t a, int32_t b, int right_shift) {
  int64_t a_64(a);
  int64_t b_64(b);
  int64_t ab_64 = a_64 * b_64;

  int64_t nudge = 1ll << (right_shift - 1);
  return (ab_64 + nudge) >> right_shift;
}
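
// Worked example (illustrative): a = 3, b = 5, right_shift = 2 gives
// ab_64 = 15, nudge = 2 and (15 + 2) >> 2 = 4, i.e. 15 / 4 = 3.75 rounded to
// the nearest integer (ties round up for positive products).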

#define FBGEMM_SPECIALIZED_REQUANTIZE(T) \
  template <> \
  FBGEMM_API void Requantize<T>( \
      const int32_t* src, \
      T* dst, \
      const int64_t len, \
      const RequantizationParams& params, \
      int thread_id, \
      int num_threads) { \
    int64_t i_begin, i_end; \
    fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \
    for (int64_t i = i_begin; i < i_end; ++i) { \
      dst[i] = Requantize<T>(src[i], params); \
    } \
  }
FBGEMM_SPECIALIZED_REQUANTIZE(uint16_t)
FBGEMM_SPECIALIZED_REQUANTIZE(int32_t)
#undef FBGEMM_SPECIALIZED_REQUANTIZE

template <>
FBGEMM_API void Requantize<uint8_t>(
    const int32_t* src,
    uint8_t* dst,
    const int64_t len,
    const RequantizationParams& params,
    int thread_id,
    int num_threads) {
  int64_t i_begin, i_end;
  fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end);
  if (params.target_qparams.precision == 8 && cpuinfo_initialize() &&
      fbgemmHasAvx2Support()) {
    RequantizeAvx2(&src[i_begin], &dst[i_begin], i_end - i_begin, params);
  } else {
    for (int64_t i = i_begin; i < i_end; ++i) {
      dst[i] = Requantize<uint8_t>(src[i], params);
    }
  }
}

template <typename T>
FBGEMM_API void RequantizeFixedPoint(
    const std::int32_t* src,
    T* dst,
    int64_t len,
    const RequantizationParams& params,
    int thread_id,
    int num_threads) {
  int64_t i_begin, i_end;
  fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end);
  if (std::is_same<T, uint8_t>::value && params.target_qparams.precision == 8 &&
      cpuinfo_initialize() && fbgemmHasAvx2Support()) {
    RequantizeFixedPointAvx2(
        &src[i_begin], &dst[i_begin], i_end - i_begin, params);
  } else {
    for (int64_t i = i_begin; i < i_end; ++i) {
      dst[i] = RequantizeFixedPoint<T>(src[i], params);
    }
  }
}

#define FBGEMM_SPECIALIZED_REQUANTIZE(T) \
  template <> \
  FBGEMM_API void RequantizeFixedPoint<T>( \
      const int32_t* src, \
      T* dst, \
      const int64_t len, \
      const RequantizationParams& params, \
      int thread_id, \
      int num_threads) { \
    int64_t i_begin, i_end; \
    fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \
    for (int64_t i = i_begin; i < i_end; ++i) { \
      dst[i] = RequantizeFixedPoint<T>(src[i], params); \
    } \
  }
FBGEMM_SPECIALIZED_REQUANTIZE(uint16_t)
FBGEMM_SPECIALIZED_REQUANTIZE(int32_t)
#undef FBGEMM_SPECIALIZED_REQUANTIZE

template <>
FBGEMM_API void RequantizeFixedPoint<uint8_t>(
    const int32_t* src,
    uint8_t* dst,
    const int64_t len,
    const RequantizationParams& params,
    int thread_id,
    int num_threads) {
  int64_t i_begin, i_end;
  fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end);

  if (params.target_qparams.precision == 8 && cpuinfo_initialize() &&
      fbgemmHasAvx2Support()) {
    RequantizeFixedPointAvx2(
        &src[i_begin], &dst[i_begin], i_end - i_begin, params);
  } else {
    for (int64_t i = i_begin; i < i_end; ++i) {
      dst[i] = RequantizeFixedPoint<uint8_t>(src[i], params);
    }
  }
}

template <typename InputType>
void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef(
    int bit_rate,
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
  static_assert(
      std::is_same<InputType, float>() || std::is_same<InputType, float16>(),
      "Only float and float16 types are allowed.");
  int num_elem_per_byte = 8 / bit_rate;
  int output_columns =
      (input_columns + num_elem_per_byte - 1) / num_elem_per_byte +
      2 * sizeof(float16);
  std::vector<float> input_row_float(input_columns);
  for (size_t row = 0; row < input_rows; ++row) {
    const InputType* input_row = input + row * input_columns;
    std::uint8_t* output_row = output + row * output_columns;
    float16* output_row_scale_bias = reinterpret_cast<float16*>(
        output_row +
        (input_columns + num_elem_per_byte - 1) / num_elem_per_byte);

    // NOTE: this could be optimized, but we don't care much about the
    // performance of this reference implementation.
    for (int col = 0; col < input_columns; ++col) {
      if (std::is_same<InputType, float>()) {
        input_row_float[col] = input_row[col];
      } else {
        input_row_float[col] = cpu_half2float(input_row[col]);
      }
    }

    float minimum_element =
        *std::min_element(input_row_float.begin(), input_row_float.end());
    float maximum_element =
        *std::max_element(input_row_float.begin(), input_row_float.end());
    // Truncate the minimum since the bias will be stored in fp16. Keep the
    // higher-precision max untouched.
    float16 minimum_element_fp16 = cpu_float2half_rn(minimum_element);
    minimum_element = cpu_half2float(minimum_element_fp16);
    const float range = maximum_element - minimum_element;

    float scale = range == 0 ? 1.0f : range / ((1 << bit_rate) - 1);
    float16 scale_fp16 = cpu_float2half_rn(scale);
    scale = cpu_half2float(scale_fp16);
    if (scale == 0) {
      // Corner case handling when maximum_element == minimum_element
      // Any scale would work because X - minimum_element will be 0 for all X
      scale = 1.0f;
    }
    float inverse_scale = 1.0f / scale;
    if (std::isinf(inverse_scale)) {
      scale = 1.0f;
      inverse_scale = 1.0f;
    }

    output_row_scale_bias[0] = cpu_float2half_rn(scale);
    output_row_scale_bias[1] = minimum_element_fp16;
    for (int col = 0; col < input_columns; ++col) {
      float X = input_row_float[col];
      std::uint8_t quantized = std::max(
          0,
          std::min<int>(
              std::lrintf((X - minimum_element) * inverse_scale),
              (1 << bit_rate) - 1));
      if (col % num_elem_per_byte == 0) {
        output_row[col / num_elem_per_byte] = quantized;
      } else {
        output_row[col / num_elem_per_byte] |=
            (quantized << ((col % num_elem_per_byte) * bit_rate));
      }
    }
  } // for each row
}
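
// Format note (illustrative): each output row stores
// ceil(input_columns / num_elem_per_byte) packed data bytes followed by an
// fp16 scale and an fp16 bias. For example, bit_rate = 4 and
// input_columns = 10 give 5 data bytes plus 4 bytes of scale/bias, i.e.
// 9 bytes per row.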

template <typename InputType>
void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
    int bit_rate,
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
  // Currently we can only quantize rows when the number of input columns
  // is a multiple of the number of elements per byte.

  int num_elem_per_byte = 8 / bit_rate;
  if (input_columns % num_elem_per_byte != 0) {
    throw std::runtime_error("Unsupported number of columns");
  }

  if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
    switch (bit_rate) {
      case 2:
        FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<InputType, 2>(
            input, input_rows, input_columns, output);
        break;
      case 4:
        FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<InputType, 4>(
            input, input_rows, input_columns, output);
        break;
      case 8:
        FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<InputType, 8>(
            input, input_rows, input_columns, output);
        break;
      default:
        FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<InputType>(
            bit_rate, input, input_rows, input_columns, output);
    }
  } else {
    FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<InputType>(
        bit_rate, input, input_rows, input_columns, output);
  }
}
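
// Example usage (illustrative sketch, not part of the library): pack a 2 x 8
// float matrix at 4 bits per value; each output row is 8 / 2 = 4 packed bytes
// plus an fp16 scale and fp16 bias, i.e. 8 bytes:
//   std::vector<float> src(2 * 8, 0.5f);
//   std::vector<std::uint8_t> dst(2 * 8);
//   FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
//       4, src.data(), /*input_rows=*/2, /*input_columns=*/8, dst.data());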

template <typename InputType>
void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef(
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
  constexpr float kEpsilon = 1e-8f;

  int output_columns = input_columns + 2 * sizeof(float);
  std::vector<float> input_row_float(input_columns);
  for (size_t row = 0; row < input_rows; ++row) {
    const InputType* input_row = input + row * input_columns;
    std::uint8_t* output_row = output + row * output_columns;
    float* output_row_scale_bias =
        reinterpret_cast<float*>(output_row + input_columns);

    for (int col = 0; col < input_columns; ++col) {
      if (std::is_same<InputType, float>()) {
        input_row_float[col] = input_row[col];
      } else {
        input_row_float[col] = cpu_half2float(input_row[col]);
      }
    }

    float minimum_element =
        *std::min_element(input_row_float.begin(), input_row_float.end());
    float maximum_element =
        *std::max_element(input_row_float.begin(), input_row_float.end());
    float range = maximum_element - minimum_element;

    output_row_scale_bias[0] = range / 255.0f;
    output_row_scale_bias[1] = minimum_element;
    const auto inverse_scale = 255.0f / (range + kEpsilon);
    for (int col = 0; col < input_columns; ++col) {
      output_row[col] =
          std::lrintf((input_row_float[col] - minimum_element) * inverse_scale);
    }
  } // for each row
}
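
// Format note (illustrative): each output row is input_columns uint8 values
// followed by an fp32 scale and an fp32 bias, so rows are
// input_columns + 8 bytes wide; scale = (max - min) / 255 and bias = min.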

template <typename InputType>
void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output) {
  if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
    FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2<InputType>(
        input, input_rows, input_columns, output);
  } else {
    FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef<InputType>(
        input, input_rows, input_columns, output);
  }
}

template <typename OutputType>
void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
    int bit_rate,
    const uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output) {
  static_assert(
      std::is_same<OutputType, float>() || std::is_same<OutputType, float16>(),
      "Only float and float16 types are allowed.");
  int num_elem_per_byte = 8 / bit_rate;
  int output_columns =
      (input_columns - 2 * sizeof(float16)) * num_elem_per_byte;

  for (size_t row = 0; row < input_rows; ++row) {
    const std::uint8_t* input_row = input + row * input_columns;
    const float16* input_row_scale_bias = reinterpret_cast<const float16*>(
        input_row +
        (output_columns + num_elem_per_byte - 1) / num_elem_per_byte);
    float scale = cpu_half2float(input_row_scale_bias[0]);
    float bias = cpu_half2float(input_row_scale_bias[1]);
    OutputType* output_row = output + row * output_columns;

    for (int col = 0; col < output_columns; ++col) {
      std::uint8_t quantized = input_row[col / num_elem_per_byte];
      quantized >>= (col % num_elem_per_byte) * bit_rate;
      quantized &= (1 << bit_rate) - 1;
      float output_value = scale * quantized + bias;
      if (std::is_same<OutputType, float>()) {
        output_row[col] = output_value;
      } else {
        output_row[col] = cpu_float2half_rn(output_value);
      }
    }
  }
}
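
// Unpacking example (illustrative): with bit_rate = 4, column 5 of a row lives
// in data byte 5 / 2 = 2; it is extracted by shifting right by
// (5 % 2) * 4 = 4 bits and masking with 0xF, then dequantized as
// scale * quantized + bias.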

template <typename OutputType>
void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
    int bit_rate,
    const uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output) {
  if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
    switch (bit_rate) {
      case 2:
        FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<OutputType, 2>(
            input, input_rows, input_columns, output);
        break;
      case 4:
        FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<OutputType, 4>(
            input, input_rows, input_columns, output);
        break;
      case 8:
        FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<OutputType, 8>(
            input, input_rows, input_columns, output);
        break;
      default:
        FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<OutputType>(
            bit_rate, input, input_rows, input_columns, output);
    }
  } else {
    FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<OutputType>(
        bit_rate, input, input_rows, input_columns, output);
  }
}

template <typename OutputType>
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
    const std::uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output) {
  int output_columns = input_columns - 2 * sizeof(float);

  for (size_t row = 0; row < input_rows; ++row) {
    const std::uint8_t* input_row = input + row * input_columns;
    const float* input_row_scale_bias =
        reinterpret_cast<const float*>(input_row + output_columns);
    OutputType* output_row = output + row * output_columns;

    for (int col = 0; col < output_columns; ++col) {
      float output_value =
          input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1];
      if (std::is_same<OutputType, float>()) {
        output_row[col] = output_value;
      } else {
        output_row[col] = cpu_float2half_rn(output_value);
      }
    }
  }
}

template <typename OutputType>
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
    const std::uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output) {
  if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
    Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2<OutputType>(
        input, input_rows, input_columns, output);
  } else {
    Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<OutputType>(
        input, input_rows, input_columns, output);
  }
}

#define INSTANTIATE_QuantizationFunctions(type) \
  template FBGEMM_API void \
  FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<type>( \
      int bit_rate, \
      const type* input, \
      size_t input_rows, \
      int input_columns, \
      std::uint8_t* output); \
  template FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<type>( \
      int bit_rate, \
      const type* input, \
      size_t input_rows, \
      int input_columns, \
      std::uint8_t* output); \
  template FBGEMM_API void \
  FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<type>( \
      int bit_rate, \
      const uint8_t* input, \
      size_t input_rows, \
      int input_columns, \
      type* output); \
  template FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf<type>( \
      int bit_rate, \
      const uint8_t* input, \
      size_t input_rows, \
      int input_columns, \
      type* output); \
  template FBGEMM_API void \
  FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef<type>( \
      const type* input, \
      size_t input_rows, \
      int input_columns, \
      std::uint8_t* output); \
  template FBGEMM_API void \
  FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<type>( \
      const type* input, \
      size_t input_rows, \
      int input_columns, \
      std::uint8_t* output); \
  template FBGEMM_API void \
  Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<type>( \
      const uint8_t* input, \
      size_t input_rows, \
      int input_columns, \
      type* output); \
  template FBGEMM_API void \
  Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<type>( \
      const uint8_t* input, \
      size_t input_rows, \
      int input_columns, \
      type* output);

// clang-format off
INSTANTIATE_QuantizationFunctions(float)
INSTANTIATE_QuantizationFunctions(float16)
// clang-format on

#undef INSTANTIATE_QuantizationFunctions

} // namespace fbgemm