1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8#include "fbgemm/QuantUtilsAvx2.h"
9#if defined(__x86_64__) || defined(__i386__) || \
10 (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
11#include <immintrin.h>
12#endif
13#include <algorithm> //for std::min/std::max
14#include <cassert> //for assert
15#include <cfloat> // for FLT_MAX
16#include <cmath> //for nearbyint
17#include <cstring> //for memcpy
18#include <limits> //for numeric_limits
19#include "./MaskAvx2.h"
20#include "fbgemm/Types.h"
21
22namespace fbgemm {
23
24using namespace std;
25////////////////////////////////////////////////////////////////////////////////
26// Utility functions
27
28template <typename T, bool LEGACY>
29void QuantizeAvx2(
30 const float* src,
31 T* dst,
32 int64_t len,
33 const TensorQuantizationParams& qparams) {
34#if defined(__AVX2__) && (defined(__FMA__) || defined(_MSC_VER))
35 constexpr int VLEN = 8;
36 constexpr int32_t min_val = std::numeric_limits<T>::min();
37 constexpr int32_t max_val = std::numeric_limits<T>::max();
38 // This is the largest int32 value less than int32_max
39 // that is exactly representable in float
40 constexpr int32_t int32_float_max_val =
41 std::numeric_limits<int32_t>::max() - 127;
42 int i = 0;
43 float inverse_scale = 1.f / qparams.scale;
44 __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
45 // clang-format off
46 __m256i shuffle_mask_v = _mm256_set_epi8(
47 0xff, 0xff, 0xff, 0xff,
48 0xff, 0xff, 0xff, 0xff,
49 0xff, 0xff, 0xff, 0xff,
50 0x0c, 0x08, 0x04, 0x00,
51 0xff, 0xff, 0xff, 0xff,
52 0xff, 0xff, 0xff, 0xff,
53 0xff, 0xff, 0xff, 0xff,
54 0x0c, 0x08, 0x04, 0x00);
55 // clang-format on
56 __m256i permute_mask_v =
57 _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
58 const auto zero_point_v_legacy = _mm256_set1_ps(qparams.zero_point);
59 const auto zero_point_v_non_legacy = _mm256_set1_epi32(qparams.zero_point);
60 for (; i < len / VLEN * VLEN; i += VLEN) {
61 __m256 src_v = _mm256_loadu_ps(src + i);
62 __m256 transformed_v;
63 if (LEGACY) { // static if
64 transformed_v =
65 _mm256_fmadd_ps(src_v, inverse_scale_v, zero_point_v_legacy);
66 } else {
67 transformed_v = _mm256_mul_ps(src_v, inverse_scale_v);
68 }
69 // If the floating point value is greater than int32_max,
70 // _mm256_cvtps_epi32 converts them to negative. Clip at int32_float_max_val
71 // to avoid this.
72 transformed_v =
73 _mm256_min_ps(transformed_v, _mm256_set1_ps(int32_float_max_val));
74
75 __m256i rounded_v = _mm256_cvtps_epi32(transformed_v);
76 if (!LEGACY) {
77 rounded_v = _mm256_add_epi32(rounded_v, zero_point_v_non_legacy);
78 }
79 __m256i clipped_v = _mm256_min_epi32(
80 _mm256_max_epi32(rounded_v, _mm256_set1_epi32(min_val)),
81 _mm256_set1_epi32(max_val));
82
83 // An instruction sequence to save 8 32-bit integers as 8 8-bit integers
84 clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v);
85 clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v);
86 _mm_storel_epi64(
87 reinterpret_cast<__m128i*>(dst + i), _mm256_castsi256_si128(clipped_v));
88 }
89
90 // Handle remainder using mask instructions so that
91 // the main loop and remainder loop have the same behavior
92 int64_t rem = len - i;
93 if (rem > 0) {
94 __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
95 internal::avx2_ps_or_epi32_masks[rem]));
96 // __m128i store_mask_v = _mm_load_si128(
97 // reinterpret_cast<const __m128i*>(internal::sse_epi8_masks[rem]));
98 __m256 src_v = _mm256_maskload_ps(src + i, mask_v);
99 __m256 transformed_v;
100 if (LEGACY) {
101 transformed_v =
102 _mm256_fmadd_ps(src_v, inverse_scale_v, zero_point_v_legacy);
103 } else {
104 transformed_v = _mm256_mul_ps(src_v, inverse_scale_v);
105 }
106 transformed_v =
107 _mm256_min_ps(transformed_v, _mm256_set1_ps(int32_float_max_val));
108
109 __m256i rounded_v = _mm256_cvtps_epi32(transformed_v);
110 if (!LEGACY) {
111 rounded_v = _mm256_add_epi32(rounded_v, zero_point_v_non_legacy);
112 }
113 __m256i clipped_v = _mm256_min_epi32(
114 _mm256_max_epi32(rounded_v, _mm256_set1_epi32(min_val)),
115 _mm256_set1_epi32(max_val));
116
117 // An instruction sequence to save "rem" number of 32-bit integers
118 // as "rem" number of 8-bit integers
119 clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v);
120 clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v);
121 // do not use _mm_maskmoveu_si128 instead of memcpy.
122 // asan has false positives for _mm_maskmoveu_si128 and this instruction
123 // sometimes causes segfault (root cause is unknown).
124 memcpy(dst + i, reinterpret_cast<void*>(&clipped_v), rem * sizeof(T));
125 // _mm_maskmoveu_si128(
126 // _mm256_castsi256_si128(clipped_v),
127 // store_mask_v,
128 // reinterpret_cast<char*>(dst + i));
129 }
130#endif
131}
132
133uint32_t Xor128(void) {
134 /* library-local */ static uint32_t x = 123456789;
135 /* library-local */ static uint32_t y = 362436069;
136 /* library-local */ static uint32_t z = 521288629;
137 /* library-local */ static uint32_t w = 88675123;
138 uint32_t t;
139 t = x ^ (x << 11);
140 x = y;
141 y = z;
142 z = w;
143 return w = w ^ (w >> 19) ^ (t ^ (t >> 8));
144}
145
146// Instantiate QuantizeAvx2 for known datatypes
147#define SPECIALIZE_QUANTIZEAVX2(T, LEGACY) \
148 template void QuantizeAvx2<T, LEGACY>( \
149 const float* src, \
150 T* dst, \
151 int64_t len, \
152 const TensorQuantizationParams& qparams);
153SPECIALIZE_QUANTIZEAVX2(uint8_t, true)
154SPECIALIZE_QUANTIZEAVX2(int8_t, true)
155SPECIALIZE_QUANTIZEAVX2(uint8_t, false)
156SPECIALIZE_QUANTIZEAVX2(int8_t, false)
157#undef SPECIALIZE_QUANTIZEAVX2
158
159template <typename T>
160void NO_SANITIZE("address") FusedQuantizeDequantizeAvx2(
161 const float* src,
162 float* dst,
163 int len,
164 const TensorQuantizationParams& qparams,
165 float noise_ratio) {
166 float inverse_scale = 1.f / qparams.scale;
167 constexpr int32_t min_val = std::numeric_limits<T>::min();
168 constexpr int32_t max_val = std::numeric_limits<T>::max();
169 (void)inverse_scale; // Suppress unused variable warning
170 (void)min_val; // Suppress unused variable warning
171 (void)max_val; // Suppress unused variable warning
172#if defined(__AVX2__) && (defined(__FMA__) || defined(_MSC_VER))
173
174 constexpr int VLEN = 8;
175 // This is the largest int32 value less than int32_max
176 // that is exactly representable in float
177 constexpr int32_t int32_float_max_val =
178 std::numeric_limits<int32_t>::max() - 127;
179 int i = 0;
180 uint32_t rand;
181 __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
182 __m256 scale_v = _mm256_set1_ps(qparams.scale);
183 __m256 zp_v = _mm256_set1_ps(qparams.zero_point);
184
185 for (; i < len / VLEN * VLEN; i += VLEN) {
186 // prefetch src and dst
187 _mm_prefetch(reinterpret_cast<const char*>(src + i + VLEN), _MM_HINT_T0);
188 _mm_prefetch(reinterpret_cast<const char*>(dst + i + VLEN), _MM_HINT_T0);
189
190 __m256 src_v = _mm256_loadu_ps(src + i);
191 __m256 transformed_v;
192 if (noise_ratio > 0) {
193 rand = Xor128() % 10;
194 if (rand < noise_ratio * 10) {
195 _mm256_storeu_ps(dst + i, src_v);
196 continue;
197 }
198 }
199
200 transformed_v = _mm256_mul_ps(src_v, inverse_scale_v);
201 // If the floating point value is greater than int32_max,
202 // _mm256_cvtps_epi32 converts them to negative. Clip at int32_float_max_val
203 // to avoid this.
204 transformed_v =
205 _mm256_min_ps(transformed_v, _mm256_set1_ps(int32_float_max_val));
206
207 __m256i rounded_v = _mm256_cvtps_epi32(transformed_v);
208 rounded_v =
209 _mm256_add_epi32(rounded_v, _mm256_set1_epi32(qparams.zero_point));
210 __m256i clipped_v = _mm256_min_epi32(
211 _mm256_max_epi32(rounded_v, _mm256_set1_epi32(min_val)),
212 _mm256_set1_epi32(max_val));
213
214 // convert int32 to float32
215 __m256 fp32_clipped_v = _mm256_cvtepi32_ps(clipped_v);
216 // minus zero point, multiply by scale
217 __m256 fp32_dq_sub = _mm256_sub_ps(fp32_clipped_v, zp_v);
218 __m256 fp32_dq = _mm256_mul_ps(fp32_dq_sub, scale_v);
219
220 // save fusued quantize-dequantize fp32 values into dst
221 _mm256_storeu_ps(dst + i, fp32_dq);
222 }
223
224 // Handle remainder using mask instructions so that
225 // the main loop and remainder loop have the same behavior
226 int rem = len - i;
227 if (rem > 0) {
228 __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
229 internal::avx2_ps_or_epi32_masks[rem]));
230
231 __m256 src_v = _mm256_maskload_ps(src + i, mask_v);
232 __m256 transformed_v;
233
234 if (noise_ratio > 0) {
235 rand = Xor128() % 10;
236 if (rand < noise_ratio * 10) {
237 _mm256_storeu_ps(dst + i, src_v);
238 return;
239 }
240 }
241
242 transformed_v = _mm256_mul_ps(src_v, inverse_scale_v);
243 // If the floating point value is greater than int32_max,
244 // _mm256_cvtps_epi32 converts them to negative. Clip at int32_float_max_val
245 // to avoid this.
246 transformed_v =
247 _mm256_min_ps(transformed_v, _mm256_set1_ps(int32_float_max_val));
248
249 __m256i rounded_v = _mm256_cvtps_epi32(transformed_v);
250 rounded_v =
251 _mm256_add_epi32(rounded_v, _mm256_set1_epi32(qparams.zero_point));
252
253 __m256i clipped_v = _mm256_min_epi32(
254 _mm256_max_epi32(rounded_v, _mm256_set1_epi32(min_val)),
255 _mm256_set1_epi32(max_val));
256
257 // convert int32 to float32
258 __m256 fp32_clipped_v = _mm256_cvtepi32_ps(clipped_v);
259 // minus zero point, multiply by scale
260 __m256 fp32_dq_sub =
261 _mm256_sub_ps(fp32_clipped_v, _mm256_set1_ps(qparams.zero_point));
262 __m256 fp32_dq = _mm256_mul_ps(fp32_dq_sub, _mm256_set1_ps(qparams.scale));
263
264 // store fp32 values with mask
265 _mm256_maskstore_ps(dst + i, mask_v, fp32_dq);
266 }
267#endif
268}
269
270// Instantiate QuantizeAvx2 for known datatypes
271#define SPECIALIZE_FUSEDDQAVX2(T) \
272 template void FusedQuantizeDequantizeAvx2<T>( \
273 const float* src, \
274 float* dst, \
275 int len, \
276 const TensorQuantizationParams& qparams, \
277 float noise_ratio);
278SPECIALIZE_FUSEDDQAVX2(uint8_t)
279SPECIALIZE_FUSEDDQAVX2(int8_t)
280
281#undef SPECIALIZE_FUSEDDQAVX2
282
283void FindMinMax(const float* a, float* min, float* max, int64_t len) {
284 if (len <= 0) {
285 *min = 0.0f;
286 *max = 0.0f;
287 return;
288 }
289
290 float temp_min = *a, temp_max = *a;
291 int64_t i = 0;
292
293#ifdef __AVX__
294 __m256 min_v = _mm256_set1_ps(*a), max_v = _mm256_set1_ps(*a);
295 constexpr int VLEN = 8;
296 if (len >= VLEN) {
297 for (; i < len / VLEN * VLEN; i += VLEN) {
298 min_v = _mm256_min_ps(min_v, _mm256_loadu_ps(a + i));
299 max_v = _mm256_max_ps(max_v, _mm256_loadu_ps(a + i));
300 }
301
302 float min_buf[VLEN], max_buf[VLEN];
303 _mm256_storeu_ps(min_buf, min_v);
304 _mm256_storeu_ps(max_buf, max_v);
305 for (int j = 0; j < VLEN; ++j) {
306 temp_min = std::min(temp_min, min_buf[j]);
307 temp_max = std::max(temp_max, max_buf[j]);
308 }
309 }
310#endif
311
312 for (; i < len; i++) {
313 temp_min = std::min(temp_min, a[i]);
314 temp_max = std::max(temp_max, a[i]);
315 }
316 *min = temp_min;
317 *max = temp_max;
318}
319
320////////////////////////////////////////////////////////////////////////////////
321// Requantization (with floats)
322
323#ifdef __AVX2__
324void RequantizeAvx2(
325 const int32_t* src,
326 uint8_t* dst,
327 int len,
328 const RequantizationParams& params) {
329 int32_t Bq_zero_point[] = {0};
330
331 requantizationParams_t<> reqObj = {
332 0, // Aq_zero_point
333 Bq_zero_point,
334 params.target_qparams.zero_point,
335 &params.real_multiplier,
336 nullptr, // row_offsets
337 nullptr, // col_offsets
338 nullptr, // bias
339 static_cast<std::uint32_t>(len), // ncols
340 1, // groups
341 nullptr}; // act_times_w_scale
342 requantizeOutputProcessingAvx2<
343 true, // A_SYMMETRIC
344 true, // B_SYMMETRIC
345 QuantizationGranularity::TENSOR,
346 false, // HAS_BIAS
347 false // FUSE_RELU
348 >(dst, src, {0, 1, 0, len}, len, len, reqObj);
349}
350
351void RequantizeFixedPointAvx2(
352 const int32_t* src,
353 uint8_t* dst,
354 int len,
355 const RequantizationParams& params) {
356 constexpr int VLEN = 8;
357
358 __m256i b = _mm256_set1_epi32(params.multiplier);
359
360 // AVX2 doesn't support arithmetic right shift.
361 // As a work around, we convert 64-bit multiplied results to uint64_t by
362 // adding 0x8000000000000000ULL, logical right shift, and subtract by
363 // (0x8000000000000000ULL >> right_shift).
364 __m256i pre_shift_nudge = _mm256_set1_epi64x(
365 (1ll << (params.right_shift - 1)) + 0x8000000000000000ULL);
366 __m256i post_shift_nudge = _mm256_set1_epi64x(
367 params.target_qparams.zero_point -
368 (0x8000000000000000ULL >> params.right_shift));
369
370 __m256i min_v = _mm256_set1_epi32(numeric_limits<uint8_t>::min());
371 __m256i max_v = _mm256_set1_epi32(numeric_limits<uint8_t>::max());
372
373 __m256i shuffle_mask_v = _mm256_set_epi8(
374 0xff,
375 0xff,
376 0xff,
377 0xff,
378 0xff,
379 0xff,
380 0xff,
381 0xff,
382 0xff,
383 0xff,
384 0xff,
385 0xff,
386 0x0c,
387 0x08,
388 0x04,
389 0x00,
390 0xff,
391 0xff,
392 0xff,
393 0xff,
394 0xff,
395 0xff,
396 0xff,
397 0xff,
398 0xff,
399 0xff,
400 0xff,
401 0xff,
402 0x0c,
403 0x08,
404 0x04,
405 0x00);
406 __m256i permute_mask_v =
407 _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
408
409 int i = 0;
410 for (; i < len / VLEN * VLEN; i += VLEN) {
411 __m256i a_v = _mm256_loadu_si256((const __m256i*)(src + i));
412
413 // a = a0 | a1 | a2 | a3 | a4 | a5 | a6 | a7
414 // b = b0 | b1 | b3 | b3 | b4 | b5 | b6 | b7
415 __m256i a_even_v = a_v;
416 __m256i a_odd_v = _mm256_srli_si256(a_v, 4);
417
418 __m256i ab_even_v = _mm256_mul_epi32(a_even_v, b);
419 __m256i ab_odd_v = _mm256_mul_epi32(a_odd_v, b);
420
421 __m256i even_rounded_v = _mm256_add_epi64(ab_even_v, pre_shift_nudge);
422 __m256i odd_rounded_v = _mm256_add_epi64(ab_odd_v, pre_shift_nudge);
423
424 __m256i even_result_v = _mm256_add_epi64(
425 _mm256_srli_epi64(even_rounded_v, params.right_shift),
426 post_shift_nudge);
427 __m256i odd_result_v = _mm256_add_epi64(
428 _mm256_srli_epi64(odd_rounded_v, params.right_shift), post_shift_nudge);
429 odd_result_v = _mm256_slli_si256(odd_result_v, 4);
430
431 // even_result_v has numbers we want in its even 32-bit SIMD lanes, and
432 // odd_result_v has numbers we want in its odd 32-bit SIMD lanes.
433 // Use blend to combine them.
434 __m256i result_v = _mm256_blend_epi32(even_result_v, odd_result_v, 0xaa);
435 __m256i clipped_v =
436 _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, result_v));
437
438 clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v);
439 clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v);
440 *(int64_t*)(dst + i) = _mm256_extract_epi64(clipped_v, 0);
441 }
442
443 for (; i < len; ++i) {
444 int64_t ab_64 =
445 static_cast<int64_t>(src[i]) * static_cast<int64_t>(params.multiplier);
446 int64_t nudge = 1ll << std::max(0, params.right_shift - 1);
447 int64_t quantized_down = params.target_qparams.zero_point +
448 ((ab_64 + nudge) >> params.right_shift);
449 dst[i] = std::min<int64_t>(std::max<int64_t>(quantized_down, 0l), 255l);
450 }
451}
452#endif
453
454template <
455 bool A_SYMMETRIC,
456 bool B_SYMMETRIC,
457 QuantizationGranularity Q_GRAN,
458 bool HAS_BIAS,
459 bool FUSE_RELU,
460 typename BIAS_TYPE,
461 bool DIRECT>
462void requantizeOutputProcessingAvx2(
463 uint8_t* out,
464 const int32_t* inp,
465 const block_type_t& block,
466 int ld_out,
467 int ld_in,
468 const requantizationParams_t<BIAS_TYPE>& r) {
469 // Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c
470 // using AVX2 instructions
471 int quant_param_idx = 0;
472 if (Q_GRAN == QuantizationGranularity::GROUP) {
473 int ncol_per_group = r.ncols / r.groups;
474 int g = block.col_start / ncol_per_group;
475 quant_param_idx = g;
476 }
477 __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);
478
479 // Broadcasted reciprocal of act_times_w_scale
480 __m256 act_times_w_rcp_v;
481 if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) {
482 if (is_same<BIAS_TYPE, float>::value) {
483 act_times_w_rcp_v =
484 _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]);
485 }
486 }
487
488 __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
489 __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
490
491 assert(
492 (A_SYMMETRIC == (r.A_zero_point == 0)) &&
493 "A_SYMMETRIC == true if and only if A_zero_point == 0");
494 assert(
495 (B_SYMMETRIC ==
496 ((Q_GRAN == QuantizationGranularity::TENSOR && r.B_zero_point[0] == 0) ||
497 r.row_offsets == nullptr)) &&
498 "B_SYMMETRIC == true if and only if B_zero_point == 0 "
499 "or r.row_offsets == nullptr");
500 assert(
501 (HAS_BIAS == (r.bias != nullptr)) &&
502 "HAS_BIAS == true if and only if bias != nullptr");
503
504 __m256i A_zero_point_v = _mm256_set1_epi32(r.A_zero_point);
505 __m256i C_zero_point_epi16_v = _mm256_set1_epi16(r.C_zero_point);
506 __m256i C_zero_point_epi8_v = _mm256_set1_epi8(r.C_zero_point);
507
508 __m256i permute_mask_v =
509 _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
510
511 constexpr int VLEN = 8;
512 for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
513 // Scale row_offset with Bq_zero_point
514 int32_t row_offset = 0;
515 if (B_SYMMETRIC) {
516 row_offset = 0;
517 } else if (
518 Q_GRAN == QuantizationGranularity::TENSOR ||
519 Q_GRAN == QuantizationGranularity::GROUP) {
520 row_offset =
521 r.row_offsets[i - block.row_start] * r.B_zero_point[quant_param_idx];
522 } else {
523 assert(
524 Q_GRAN == QuantizationGranularity::OUT_CHANNEL &&
525 "unknown quantization granularity");
526 }
527 __m256i row_offset_v = _mm256_set1_epi32(row_offset);
528
529 int j = block.col_start;
530 for (; j < block.col_start + (block.col_size / (VLEN * 4) * (VLEN * 4));
531 j += (VLEN * 4)) {
532 __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
533 inp + (i - block.row_start) * ld_in + (j - block.col_start)));
534 __m256i y_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
535 inp + (i - block.row_start) * ld_in + (j - block.col_start) +
536 1 * VLEN));
537 __m256i z_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
538 inp + (i - block.row_start) * ld_in + (j - block.col_start) +
539 2 * VLEN));
540 __m256i w_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
541 inp + (i - block.row_start) * ld_in + (j - block.col_start) +
542 3 * VLEN));
543
544 if (!A_SYMMETRIC) {
545 __m256i col_off_v;
546 if (DIRECT == false) {
547 col_off_v = _mm256_mullo_epi32(
548 A_zero_point_v,
549 _mm256_loadu_si256(
550 reinterpret_cast<const __m256i*>(r.col_offsets + j)));
551 } else {
552 col_off_v = _mm256_mullo_epi32(
553 A_zero_point_v,
554 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
555 r.col_offsets + j + i * block.col_size)));
556 }
557
558 x_v = _mm256_sub_epi32(x_v, col_off_v);
559
560 if (DIRECT == false) {
561 col_off_v = _mm256_mullo_epi32(
562 A_zero_point_v,
563 _mm256_loadu_si256(
564 reinterpret_cast<const __m256i*>(r.col_offsets + j + VLEN)));
565 } else {
566 col_off_v = _mm256_mullo_epi32(
567 A_zero_point_v,
568 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
569 r.col_offsets + j + VLEN + i * block.col_size)));
570 }
571
572 y_v = _mm256_sub_epi32(y_v, col_off_v);
573
574 if (DIRECT == false) {
575 col_off_v = _mm256_mullo_epi32(
576 A_zero_point_v,
577 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
578 r.col_offsets + j + 2 * VLEN)));
579 } else {
580 col_off_v = _mm256_mullo_epi32(
581 A_zero_point_v,
582 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
583 r.col_offsets + j + 2 * VLEN + i * block.col_size)));
584 }
585
586 z_v = _mm256_sub_epi32(z_v, col_off_v);
587
588 if (DIRECT == false) {
589 col_off_v = _mm256_mullo_epi32(
590 A_zero_point_v,
591 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
592 r.col_offsets + j + 3 * VLEN)));
593 } else {
594 col_off_v = _mm256_mullo_epi32(
595 A_zero_point_v,
596 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
597 r.col_offsets + j + 3 * VLEN + i * block.col_size)));
598 }
599
600 w_v = _mm256_sub_epi32(w_v, col_off_v);
601 }
602
603 if (!B_SYMMETRIC) {
604 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
605 row_offset_v = _mm256_mullo_epi32(
606 _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
607 _mm256_loadu_si256(
608 reinterpret_cast<const __m256i*>(r.B_zero_point + j)));
609 }
610 x_v = _mm256_sub_epi32(x_v, row_offset_v);
611 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
612 row_offset_v = _mm256_mullo_epi32(
613 _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
614 _mm256_loadu_si256(
615 reinterpret_cast<const __m256i*>(r.B_zero_point + j + VLEN)));
616 }
617 y_v = _mm256_sub_epi32(y_v, row_offset_v);
618 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
619 row_offset_v = _mm256_mullo_epi32(
620 _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
621 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
622 r.B_zero_point + j + 2 * VLEN)));
623 }
624 z_v = _mm256_sub_epi32(z_v, row_offset_v);
625 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
626 row_offset_v = _mm256_mullo_epi32(
627 _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
628 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
629 r.B_zero_point + j + 3 * VLEN)));
630 }
631 w_v = _mm256_sub_epi32(w_v, row_offset_v);
632 }
633 __m256 xf_v, yf_v, zf_v, wf_v;
634 if (HAS_BIAS) {
635 if (is_same<BIAS_TYPE, float>::value) {
636 __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
637 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
638 x_bias_v = _mm256_div_ps(
639 _mm256_loadu_ps(
640 reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
641 _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
642 y_bias_v = _mm256_div_ps(
643 _mm256_loadu_ps(
644 reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
645 _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
646 z_bias_v = _mm256_div_ps(
647 _mm256_loadu_ps(
648 reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
649 _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
650 w_bias_v = _mm256_div_ps(
651 _mm256_loadu_ps(
652 reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
653 _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
654 } else {
655 x_bias_v = _mm256_mul_ps(
656 _mm256_loadu_ps(
657 reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
658 act_times_w_rcp_v);
659 y_bias_v = _mm256_mul_ps(
660 _mm256_loadu_ps(
661 reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
662 act_times_w_rcp_v);
663 z_bias_v = _mm256_mul_ps(
664 _mm256_loadu_ps(
665 reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
666 act_times_w_rcp_v);
667 w_bias_v = _mm256_mul_ps(
668 _mm256_loadu_ps(
669 reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
670 act_times_w_rcp_v);
671 }
672 xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
673 yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
674 zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
675 wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
676 } else {
677 x_v = _mm256_add_epi32(
678 x_v,
679 _mm256_loadu_si256(
680 reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN)));
681 y_v = _mm256_add_epi32(
682 y_v,
683 _mm256_loadu_si256(
684 reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN)));
685 z_v = _mm256_add_epi32(
686 z_v,
687 _mm256_loadu_si256(
688 reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
689 w_v = _mm256_add_epi32(
690 w_v,
691 _mm256_loadu_si256(
692 reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
693 xf_v = _mm256_cvtepi32_ps(x_v);
694 yf_v = _mm256_cvtepi32_ps(y_v);
695 zf_v = _mm256_cvtepi32_ps(z_v);
696 wf_v = _mm256_cvtepi32_ps(w_v);
697 }
698 } else {
699 xf_v = _mm256_cvtepi32_ps(x_v);
700 yf_v = _mm256_cvtepi32_ps(y_v);
701 zf_v = _mm256_cvtepi32_ps(z_v);
702 wf_v = _mm256_cvtepi32_ps(w_v);
703 }
704
705 /*
706 * Convert int32_t input to FP32 and multiply by FP32 scale.
707 * Both operations involve statistically unbiased roundings (with
708 * default MXCSR rounding mode):
709 * - Large int32_t values can't be exactly represented as FP32.
710 * CVTDQ2PS instruction on x86 would round it according to nearest
711 * FP32 value with ties to even (assuming default MXCSR rounding
712 * mode).
713 * - Product of two FP32 values is generally not exactly
714 * representation as an FP32 value, and will be rounded to nearest
715 * FP32 value with ties to even with default MXCSR rounding mode.
716 */
717 __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;
718 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
719 x_scaled_v =
720 _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j + 0 * VLEN));
721 y_scaled_v =
722 _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + 1 * VLEN));
723 z_scaled_v =
724 _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
725 w_scaled_v =
726 _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
727 } else {
728 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
729 y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
730 z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
731 w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
732 }
733
734 /*
735 * Convert scaled FP32 result to int32_t using CVTPS2DQ instruction.
736 * CVTPS2DQ instruction rounds result according to nearest FP32 value
737 * with ties to even (assuming default MXCSR rounding mode). However,
738 * when conversion overflows, it produces INT32_MIN as a result. For
739 * large positive inputs the result of conversion can become negative,
740 * which affects the final requantization result. Note that on x86
741 * SSE2 we have e.g. int32_t(float(INT32_MAX)) == INT32_MIN! This
742 * happens because float(INT32_MAX) rounds to 2**31, which overflows
743 * int32_t when it is converted back to integer.
744 *
745 * Thankfully, we can prove that overflow never happens in this
746 * requantization scheme. The largest positive input is INT32_MAX
747 * (2**31 - 1), which turns into 2**31 when converted to float. The
748 * largest scale value is 0x1.FFFFFEp-1. When multiplied together, the
749 * result is 2147483520 (compare to INT32_MAX = 2147483647), which
750 * fits into int32_t without overflow.
751 */
752 __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
753 __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
754 __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
755 __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);
756
757 /*
758 * Standard final sequence on x86 AVX2:
759 * - Pack to int16_t and saturate
760 * - Add zero point
761 * - Pack to uint8_t and saturate
762 * - Clamp between qmin and qmax
763 */
764 __m256i xy_packed_v = _mm256_adds_epi16(
765 _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
766 __m256i zw_packed_v = _mm256_adds_epi16(
767 _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
768 __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
769 __m256i xyzw_clamped_v = _mm256_max_epu8(
770 FUSE_RELU ? C_zero_point_epi8_v : min_v,
771 _mm256_min_epu8(xyzw_packed_v, max_v));
772
773 /*
774 * xyzw_clamped_v has results in the following layout so we need to
775 * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7
776 */
777 xyzw_clamped_v =
778 _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
779
780 /*
781 * 4x CVTDQ2PS
782 * 4x MULPS
783 * 4x CVTPS2DQ
784 * 2x PACKSSDW
785 * 1x PACKUSWB
786 * 2x PADDW
787 * 1x PMAXUB
788 * 1x PMINUB
789 * 1x PERMD
790 * ---------------------
791 * 20 instructions total
792 */
793 _mm256_storeu_si256(
794 reinterpret_cast<__m256i*>(out + i * ld_out + j), xyzw_clamped_v);
795 } // j loop vectorized and unrolled 4x
796
797 for (; j < block.col_start + (block.col_size / VLEN * VLEN); j += VLEN) {
798 __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
799 inp + (i - block.row_start) * ld_in + (j - block.col_start)));
800
801 if (!A_SYMMETRIC) {
802 __m256i col_off_v;
803 if (DIRECT == false) {
804 col_off_v = _mm256_mullo_epi32(
805 A_zero_point_v,
806 _mm256_loadu_si256(
807 reinterpret_cast<const __m256i*>(r.col_offsets + j)));
808 } else {
809 col_off_v = _mm256_mullo_epi32(
810 A_zero_point_v,
811 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
812 r.col_offsets + j + i * block.col_size)));
813 }
814 x_v = _mm256_sub_epi32(x_v, col_off_v);
815 }
816
817 if (!B_SYMMETRIC) {
818 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
819 row_offset_v = _mm256_mullo_epi32(
820 _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
821 _mm256_loadu_si256(
822 reinterpret_cast<const __m256i*>(r.B_zero_point + j)));
823 }
824 x_v = _mm256_sub_epi32(x_v, row_offset_v);
825 }
826 __m256 xf_v;
827 if (HAS_BIAS) {
828 if (is_same<BIAS_TYPE, float>::value) {
829 __m256 x_bias_v;
830 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
831 x_bias_v = _mm256_div_ps(
832 _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)),
833 _mm256_loadu_ps(r.act_times_w_scale + j));
834 } else {
835 x_bias_v = _mm256_mul_ps(
836 _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)),
837 act_times_w_rcp_v);
838 }
839 xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
840 } else {
841 x_v = _mm256_add_epi32(
842 x_v,
843 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
844 xf_v = _mm256_cvtepi32_ps(x_v);
845 }
846 } else {
847 xf_v = _mm256_cvtepi32_ps(x_v);
848 }
849
850 __m256 x_scaled_v;
851 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
852 x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j));
853 } else {
854 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
855 }
856 __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
857
858 __m256i x_packed_v = _mm256_adds_epi16(
859 _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
860 C_zero_point_epi16_v);
861 x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
862 __m256i x_clamped_v = _mm256_max_epu8(
863 FUSE_RELU ? C_zero_point_epi8_v : min_v,
864 _mm256_min_epu8(x_packed_v, max_v));
865
866 /*
867 * x_clamped_v has results in the following layout so we need to
868 * permute: x0-3 garbage0-11 x4-7 garbage12-23
869 */
870 x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
871
872 /*
873 * 1x CVTDQ2PS
874 * 1x MULPS
875 * 1x CVTPS2DQ
876 * 1x PACKSSDW
877 * 1x PACKUSWB
878 * 1x PADDW
879 * 1x PMAXUB
880 * 1x PMINUB
881 * 1x PERMD
882 * ---------------------
883 * 9 instructions total
884 */
885 _mm_storel_epi64(
886 reinterpret_cast<__m128i*>(out + i * ld_out + j),
887 _mm256_castsi256_si128(x_clamped_v));
888 } // j loop vectorized
889
890 int remainder = block.col_start + block.col_size - j;
891 if (remainder > 0) {
892 __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
893 internal::avx2_ps_or_epi32_masks[remainder]));
894
895 __m256i x_v = _mm256_maskload_epi32(
896 inp + (i - block.row_start) * ld_in + (j - block.col_start), mask_v);
897
898 if (!A_SYMMETRIC) {
899 __m256i col_off_v;
900 if (DIRECT == false) {
901 col_off_v = _mm256_mullo_epi32(
902 A_zero_point_v, _mm256_maskload_epi32(r.col_offsets + j, mask_v));
903 } else {
904 col_off_v = _mm256_mullo_epi32(
905 A_zero_point_v,
906 _mm256_maskload_epi32(
907 r.col_offsets + j + i * block.col_size, mask_v));
908 }
909 x_v = _mm256_sub_epi32(x_v, col_off_v);
910 }
911
912 if (!B_SYMMETRIC) {
913 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
914 row_offset_v = _mm256_mullo_epi32(
915 _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
916 _mm256_maskload_epi32(r.B_zero_point + j, mask_v));
917 }
918 x_v = _mm256_sub_epi32(x_v, row_offset_v);
919 }
920
921 __m256 xf_v;
922 if (HAS_BIAS) {
923 if (is_same<BIAS_TYPE, float>::value) {
924 __m256 x_bias_v;
925 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
926 x_bias_v = _mm256_div_ps(
927 _mm256_maskload_ps(
928 reinterpret_cast<const float*>(r.bias + j), mask_v),
929 _mm256_maskload_ps(r.act_times_w_scale + j, mask_v));
930 } else {
931 x_bias_v = _mm256_mul_ps(
932 _mm256_maskload_ps(
933 reinterpret_cast<const float*>(r.bias + j), mask_v),
934 act_times_w_rcp_v);
935 }
936 xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
937 } else {
938 x_v = _mm256_add_epi32(
939 x_v,
940 _mm256_maskload_epi32(
941 reinterpret_cast<const int*>(r.bias + j), mask_v));
942 xf_v = _mm256_cvtepi32_ps(x_v);
943 }
944 } else {
945 xf_v = _mm256_cvtepi32_ps(x_v);
946 }
947
948 __m256 x_scaled_v;
949 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
950 x_scaled_v =
951 _mm256_mul_ps(xf_v, _mm256_maskload_ps(r.C_multiplier + j, mask_v));
952 } else {
953 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
954 }
955 __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
956
957 __m256i x_packed_v = _mm256_adds_epi16(
958 _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
959 C_zero_point_epi16_v);
960 x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
961 __m256i x_clamped_v = _mm256_max_epu8(
962 FUSE_RELU ? C_zero_point_epi8_v : min_v,
963 _mm256_min_epu8(x_packed_v, max_v));
964
965 /*
966 * x_clamped_v has results in the following layout so we need to
967 * permute: x0-3 garbage0-11 x4-7 garbage12-23
968 */
969 x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
970
971 /*
972 * 1x CVTDQ2PS
973 * 1x MULPS
974 * 1x CVTPS2DQ
975 * 1x PACKSSDW
976 * 1x PACKUSWB
977 * 1x PADDW
978 * 1x PMAXUB
979 * 1x PMINUB
980 * 1x PERMD
981 * ---------------------
982 * 9 instructions total
983 */
984 alignas(64) uint8_t x_clamped_buffer[32];
985 _mm256_store_si256(
986 reinterpret_cast<__m256i*>(x_clamped_buffer), x_clamped_v);
987 for (int k = 0; k < remainder; ++k) {
988 out[i * ld_out + j + k] = x_clamped_buffer[k];
989 }
990 } // j loop remainder
991 } // i loop
992}
993
994template <
995 bool A_SYMMETRIC,
996 bool B_SYMMETRIC,
997 QuantizationGranularity Q_GRAN,
998 bool HAS_BIAS,
999 bool FUSE_RELU>
1000void requantizeForFloatAvx2(
1001 float* out,
1002 const int32_t* inp,
1003 const block_type_t& block,
1004 int ld_out,
1005 int ld_in,
1006 const requantizationForFloatParams_t& r) {
1007 // Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c
1008 // using AVX2 instructions
1009 int quant_param_idx = 0;
1010 if (Q_GRAN == QuantizationGranularity::GROUP) {
1011 int ncol_per_group = r.ncols / r.groups;
1012 int g = block.col_start / ncol_per_group;
1013 quant_param_idx = g;
1014 }
1015 __m256 multiplier_v = _mm256_set1_ps(r.A_scale * r.B_scale[quant_param_idx]);
1016
1017 assert(
1018 (A_SYMMETRIC == (r.A_zero_point == 0)) &&
1019 "A_SYMMETRIC == true if and only if A_zero_point == 0");
1020 assert(
1021 (B_SYMMETRIC ==
1022 ((Q_GRAN == QuantizationGranularity::TENSOR && r.B_zero_point[0] == 0) ||
1023 r.row_offsets == nullptr)) &&
1024 "B_SYMMETRIC == true if and only if B_zero_point == 0 "
1025 "or r.row_offsets == nullptr");
1026 assert(
1027 (HAS_BIAS == (r.bias != nullptr)) &&
1028 "HAS_BIAS == true if and only if bias != nullptr");
1029
1030 __m256i A_zero_point_v = _mm256_set1_epi32(r.A_zero_point);
1031
1032 constexpr int VLEN = 8;
1033 for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
1034 // Scale row_offset with Bq_zero_point
1035 int32_t row_offset = 0;
1036 if (B_SYMMETRIC) {
1037 row_offset = 0;
1038 } else if (
1039 Q_GRAN == QuantizationGranularity::TENSOR ||
1040 Q_GRAN == QuantizationGranularity::GROUP) {
1041 row_offset =
1042 r.row_offsets[i - block.row_start] * r.B_zero_point[quant_param_idx];
1043 } else {
1044 assert(
1045 Q_GRAN == QuantizationGranularity::OUT_CHANNEL &&
1046 "unknown quantization granularity");
1047 }
1048 __m256i row_offset_v = _mm256_set1_epi32(row_offset);
1049
1050 int j = block.col_start;
1051 for (; j < block.col_start + (block.col_size / VLEN * VLEN); j += VLEN) {
1052 __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
1053 inp + (i - block.row_start) * ld_in + (j - block.col_start)));
1054
1055 if (!A_SYMMETRIC) {
1056 __m256i col_off_v = _mm256_mullo_epi32(
1057 A_zero_point_v,
1058 _mm256_loadu_si256(
1059 reinterpret_cast<const __m256i*>(r.col_offsets + j)));
1060 x_v = _mm256_sub_epi32(x_v, col_off_v);
1061 }
1062
1063 if (!B_SYMMETRIC) {
1064 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
1065 row_offset_v = _mm256_mullo_epi32(
1066 _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
1067 _mm256_loadu_si256(
1068 reinterpret_cast<const __m256i*>(r.B_zero_point + j)));
1069 }
1070 x_v = _mm256_sub_epi32(x_v, row_offset_v);
1071 }
1072
1073 __m256 x_scaled_v;
1074 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
1075 x_scaled_v = _mm256_mul_ps(
1076 _mm256_cvtepi32_ps(x_v),
1077 _mm256_mul_ps(
1078 _mm256_set1_ps(r.A_scale), _mm256_loadu_ps(r.B_scale + j)));
1079 } else {
1080 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
1081 }
1082
1083 if (HAS_BIAS) {
1084 x_scaled_v = _mm256_add_ps(x_scaled_v, _mm256_loadu_ps(r.bias + j));
1085 }
1086 if (FUSE_RELU) {
1087 x_scaled_v = _mm256_max_ps(_mm256_setzero_ps(), x_scaled_v);
1088 }
1089
1090 _mm256_storeu_ps(out + i * ld_out + j, x_scaled_v);
1091 } // j loop vectorized
1092
1093 int remainder = block.col_start + block.col_size - j;
1094 if (remainder > 0) {
1095 __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
1096 internal::avx2_ps_or_epi32_masks[remainder]));
1097
1098 __m256i x_v = _mm256_maskload_epi32(
1099 inp + (i - block.row_start) * ld_in + (j - block.col_start), mask_v);
1100
1101 if (!A_SYMMETRIC) {
1102 __m256i col_off_v = _mm256_mullo_epi32(
1103 A_zero_point_v, _mm256_maskload_epi32(r.col_offsets + j, mask_v));
1104 x_v = _mm256_sub_epi32(x_v, col_off_v);
1105 }
1106
1107 if (!B_SYMMETRIC) {
1108 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
1109 row_offset_v = _mm256_mullo_epi32(
1110 _mm256_set1_epi32(r.row_offsets[i - block.row_start]),
1111 _mm256_maskload_epi32(r.B_zero_point + j, mask_v));
1112 }
1113 x_v = _mm256_sub_epi32(x_v, row_offset_v);
1114 }
1115
1116 __m256 x_scaled_v;
1117 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
1118 x_scaled_v = _mm256_mul_ps(
1119 _mm256_cvtepi32_ps(x_v),
1120 _mm256_mul_ps(
1121 _mm256_set1_ps(r.A_scale),
1122 _mm256_maskload_ps(r.B_scale + j, mask_v)));
1123 } else {
1124 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
1125 }
1126
1127 if (HAS_BIAS) {
1128 x_scaled_v =
1129 _mm256_add_ps(x_scaled_v, _mm256_maskload_ps(r.bias + j, mask_v));
1130 }
1131 if (FUSE_RELU) {
1132 x_scaled_v = _mm256_max_ps(_mm256_setzero_ps(), x_scaled_v);
1133 }
1134
1135 _mm256_maskstore_ps(out + i * ld_out + j, mask_v, x_scaled_v);
1136 } // j loop remainder
1137 } // i loop
1138}
1139
1140template <
1141 bool A_SYMMETRIC,
1142 bool B_SYMMETRIC,
1143 QuantizationGranularity Q_GRAN,
1144 bool HAS_BIAS,
1145 bool FUSE_RELU,
1146 int C_PER_G,
1147 typename BIAS_TYPE>
1148void requantizeOutputProcessingGConvAvx2(
1149 uint8_t* out,
1150 const int32_t* inp,
1151 const block_type_t& block,
1152 int ld_out,
1153 int ld_in,
1154 const requantizationParams_t<BIAS_TYPE>& r) {
1155 // Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c
1156 // using AVX2 instructions
1157 int quant_param_idx = 0;
1158 if (Q_GRAN == QuantizationGranularity::GROUP) {
1159 int ncol_per_group = r.ncols / r.groups;
1160 int g = block.col_start / ncol_per_group;
1161 quant_param_idx = g;
1162 }
1163 __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);
1164
1165 // Broadcasted reciprocal of act_times_w_scale
1166 __m256 act_times_w_rcp_v;
1167 if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) {
1168 if (is_same<BIAS_TYPE, float>::value) {
1169 act_times_w_rcp_v =
1170 _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]);
1171 }
1172 }
1173 __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
1174 __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
1175
1176 assert(
1177 (A_SYMMETRIC == (r.A_zero_point == 0)) &&
1178 "A_SYMMETRIC == true if and only if A_zero_point == 0");
1179 assert(
1180 (B_SYMMETRIC ==
1181 ((Q_GRAN == QuantizationGranularity::TENSOR && r.B_zero_point[0] == 0) ||
1182 r.row_offsets == nullptr)) &&
1183 "B_SYMMETRIC == true if and only if B_zero_point == 0 "
1184 "or r.row_offsets == nullptr");
1185 assert(
1186 (HAS_BIAS == (r.bias != nullptr)) &&
1187 "HAS_BIAS == true if and only if bias != nullptr");
1188
1189 __m256i A_zero_point_v = _mm256_set1_epi32(r.A_zero_point);
1190 __m256i C_zero_point_epi16_v = _mm256_set1_epi16(r.C_zero_point);
1191 __m256i C_zero_point_epi8_v = _mm256_set1_epi8(r.C_zero_point);
1192
1193 __m256i permute_mask_v =
1194 _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
1195
1196 constexpr int VLEN = 8;
1197 for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
1198 int j = block.col_start;
1199 for (; j < block.col_start + (block.col_size / VLEN * VLEN); j += VLEN) {
1200 __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
1201 inp + (i - block.row_start) * ld_in + (j - block.col_start)));
1202
1203 if (!A_SYMMETRIC) {
1204 __m256i col_off_v = _mm256_mullo_epi32(
1205 A_zero_point_v,
1206 _mm256_loadu_si256(
1207 reinterpret_cast<const __m256i*>(r.col_offsets + j)));
1208 x_v = _mm256_sub_epi32(x_v, col_off_v);
1209 }
1210
1211 if (!B_SYMMETRIC) {
1212 __m256i row_offset_v;
1213
1214 if (C_PER_G == 2) {
1215 // When C_PER_G == 2, we need to handle 4 groups at a time to fully
1216 // utilize 32B AVX2 vector register (C_PER_G * 4 * sizeof(int32_t) ==
1217 // 32B)
1218 // Load row_offsets for 4 groups and broadcast by 2 times.
1219 row_offset_v =
1220 _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
1221 _mm256_castps128_ps256(
1222 _mm_loadu_ps(reinterpret_cast<const float*>(
1223 r.row_offsets + (i - block.row_start) * 4))),
1224 permute_mask_v)));
1225
1226 }
1227 // When C_PER_G == 4, we need to handle 2 groups at a time to fully
1228 // utilize 32B AVX2 vector register (C_PER_G * 2 * sizeof(int32_t) ==
1229 // 32B)
1230 // When C_PER_G == 8, we just need 1 group at a time on the other hand.
1231
1232 // Groups 0 and 1 when C_PER_G == 4
1233 // Group 0 when C_PER_G == 8
1234 else if (C_PER_G == 4) {
1235 // Load row_offsets for 2 groups and broadcast by 4 times each because
1236 // we have 4 channels per group.
1237 // groups 0 and 1
1238 row_offset_v = _mm256_insertf128_si256(
1239 _mm256_castsi128_si256(
1240 _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 2 + 0])),
1241 _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 2 + 1]),
1242 1);
1243 } else if (C_PER_G == 8) {
1244 row_offset_v =
1245 _mm256_set1_epi32(r.row_offsets[(i - block.row_start)]);
1246 } else {
1247 assert(C_PER_G == 16);
1248 row_offset_v =
1249 _mm256_set1_epi32(r.row_offsets[(i - block.row_start)]);
1250 }
1251
1252 __m256i B_zero_point_v = _mm256_set1_epi32(r.B_zero_point[0]);
1253 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
1254 B_zero_point_v = _mm256_loadu_si256(
1255 reinterpret_cast<const __m256i*>(r.B_zero_point + j));
1256 } else if (Q_GRAN == QuantizationGranularity::GROUP) {
1257 if (C_PER_G == 2) {
1258 B_zero_point_v =
1259 _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
1260 _mm256_castps128_ps256(
1261 _mm_loadu_ps(reinterpret_cast<const float*>(
1262 r.B_zero_point + quant_param_idx))),
1263 permute_mask_v)));
1264 } else if (C_PER_G == 4) {
1265 B_zero_point_v = _mm256_insertf128_si256(
1266 _mm256_castsi128_si256(
1267 _mm_set1_epi32(r.B_zero_point[quant_param_idx])),
1268 _mm_set1_epi32(r.B_zero_point[quant_param_idx + 1]),
1269 1);
1270 } else if (C_PER_G == 8) {
1271 B_zero_point_v = _mm256_set1_epi32(r.B_zero_point[quant_param_idx]);
1272 } else {
1273 B_zero_point_v = _mm256_set1_epi32(r.B_zero_point[quant_param_idx]);
1274 }
1275 }
1276 row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
1277 x_v = _mm256_sub_epi32(x_v, row_offset_v);
1278 }
1279 __m256 xf_v;
1280 if (HAS_BIAS) {
1281 if (is_same<BIAS_TYPE, float>::value) {
1282 __m256 x_bias_v =
1283 _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j));
1284 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
1285 x_bias_v = _mm256_div_ps(
1286 x_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j));
1287 } else if (Q_GRAN == QuantizationGranularity::GROUP) {
1288 __m256 diviser_v;
1289 if (C_PER_G == 2) {
1290 diviser_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
1291 _mm256_castps128_ps256(
1292 _mm_loadu_ps(r.act_times_w_scale + quant_param_idx)),
1293 permute_mask_v));
1294 } else if (C_PER_G == 4) {
1295 diviser_v = _mm256_insertf128_ps(
1296 _mm256_castps128_ps256(
1297 _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 0])),
1298 _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 1]),
1299 1);
1300 } else if (C_PER_G == 8) {
1301 diviser_v = _mm256_set1_ps(r.act_times_w_scale[quant_param_idx]);
1302 } else {
1303 assert(C_PER_G == 16);
1304 diviser_v = _mm256_set1_ps(r.act_times_w_scale[quant_param_idx]);
1305 }
1306 x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
1307 } else {
1308 x_bias_v = _mm256_mul_ps(x_bias_v, act_times_w_rcp_v);
1309 }
1310 xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
1311 } else {
1312 x_v = _mm256_add_epi32(
1313 x_v,
1314 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
1315 xf_v = _mm256_cvtepi32_ps(x_v);
1316 }
1317 } else {
1318 xf_v = _mm256_cvtepi32_ps(x_v);
1319 }
1320
1321 /*
1322 * Convert int32_t input to FP32 and multiply by FP32 scale.
1323 * Both operations involve statistically unbiased roundings (with
1324 * default MXCSR rounding mode):
1325 * - Large int32_t values can't be exactly represented as FP32.
1326 * CVTDQ2PS instruction on x86 would round it according to nearest
1327 * FP32 value with ties to even (assuming default MXCSR rounding
1328 * mode).
1329 * - Product of two FP32 values is generally not exactly
1330 * representation as an FP32 value, and will be rounded to nearest
1331 * FP32 value with ties to even with default MXCSR rounding mode.
1332 */
1333 __m256 x_scaled_v;
1334 if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
1335 x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j));
1336 } else if (Q_GRAN == QuantizationGranularity::GROUP) {
1337 if (C_PER_G == 2) {
1338 multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
1339 _mm256_castps128_ps256(
1340 _mm_loadu_ps(r.C_multiplier + quant_param_idx)),
1341 permute_mask_v));
1342 } else if (C_PER_G == 4) {
1343 multiplier_v = _mm256_insertf128_ps(
1344 _mm256_castps128_ps256(
1345 _mm_set1_ps(r.C_multiplier[quant_param_idx])),
1346 _mm_set1_ps(r.C_multiplier[quant_param_idx + 1]),
1347 1);
1348 } else if (C_PER_G == 8) {
1349 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);
1350 } else {
1351 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);
1352 }
1353 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
1354 } else {
1355 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
1356 }
1357
1358 /*
1359 * Convert scaled FP32 result to int32_t using CVTPS2DQ instruction.
1360 * CVTPS2DQ instruction rounds result according to nearest FP32 value
1361 * with ties to even (assuming default MXCSR rounding mode). However,
1362 * when conversion overflows, it produces INT32_MIN as a result. For
1363 * large positive inputs the result of conversion can become negative,
1364 * which affects the final requantization result. Note that on x86
1365 * SSE2 we have e.g. int32_t(float(INT32_MAX)) == INT32_MIN! This
1366 * happens because float(INT32_MAX) rounds to 2**31, which overflows
1367 * int32_t when it is converted back to integer.
1368 *
1369 * Thankfully, we can prove that overflow never happens in this
1370 * requantization scheme. The largest positive input is INT32_MAX
1371 * (2**31 - 1), which turns into 2**31 when converted to float. The
1372 * largest scale value is 0x1.FFFFFEp-1. When multiplied together, the
1373 * result is 2147483520 (compare to INT32_MAX = 2147483647), which
1374 * fits into int32_t without overflow.
1375 */
1376 __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
1377
1378 /*
1379 * Standard final sequence on x86 AVX2:
1380 * - Pack to int16_t and saturate
1381 * - Add zero point
1382 * - Pack to uint8_t and saturate
1383 * - Clamp between qmin and qmax
1384 */
1385 __m256i x_packed_v = _mm256_adds_epi16(
1386 _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
1387 C_zero_point_epi16_v);
1388 x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
1389 __m256i x_clamped_v = _mm256_max_epu8(
1390 FUSE_RELU ? C_zero_point_epi8_v : min_v,
1391 _mm256_min_epu8(x_packed_v, max_v));
1392
1393 /*
1394 * x_clamped_v has results in the following layout so we need to
1395 * permute: x0-3 garbage0-11 x4-7 garbage12-23
1396 */
1397 x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
1398
1399 /*
1400 * 1x CVTDQ2PS
1401 * 1x MULPS
1402 * 1x CVTPS2DQ
1403 * 1x PACKSSDW
1404 * 1x PACKUSWB
1405 * 1x PADDW
1406 * 1x PMAXUB
1407 * 1x PMINUB
1408 * 1x PERMD
1409 * ---------------------
1410 * 9 instructions total
1411 */
1412
1413 _mm_storel_epi64(
1414 reinterpret_cast<__m128i*>(out + i * ld_out + j),
1415 _mm256_castsi256_si128(x_clamped_v));
1416 } // j loop vectorized
1417
1418 const int remainder = block.col_start + block.col_size - j;
1419 (void)remainder; // Suppress unused variable warning
1420 assert(remainder == 0);
1421 } // i loop
1422}
1423
1424#define INSTANTIATE_REQUANTIZE_BIAS_TYPE( \
1425 A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE) \
1426 template void FBGEMM_API requantizeOutputProcessingAvx2< \
1427 A_SYM, \
1428 B_SYM, \
1429 Q_GRAN, \
1430 BIAS, \
1431 RELU, \
1432 BIAS_TYPE, \
1433 false>( \
1434 uint8_t * out, \
1435 const int32_t* inp, \
1436 const block_type_t& block, \
1437 int ld_out, \
1438 int ld_in, \
1439 const requantizationParams_t<BIAS_TYPE>& r); \
1440 template void FBGEMM_API requantizeOutputProcessingAvx2< \
1441 A_SYM, \
1442 B_SYM, \
1443 Q_GRAN, \
1444 BIAS, \
1445 RELU, \
1446 BIAS_TYPE, \
1447 true>( \
1448 uint8_t * out, \
1449 const int32_t* inp, \
1450 const block_type_t& block, \
1451 int ld_out, \
1452 int ld_in, \
1453 const requantizationParams_t<BIAS_TYPE>& r); \
1454 template void requantizeOutputProcessingGConvAvx2< \
1455 A_SYM, \
1456 B_SYM, \
1457 Q_GRAN, \
1458 BIAS, \
1459 RELU, \
1460 2, \
1461 BIAS_TYPE>( \
1462 uint8_t * out, \
1463 const int32_t* inp, \
1464 const block_type_t& block, \
1465 int ld_out, \
1466 int ld_in, \
1467 const requantizationParams_t<BIAS_TYPE>& r); \
1468 template void requantizeOutputProcessingGConvAvx2< \
1469 A_SYM, \
1470 B_SYM, \
1471 Q_GRAN, \
1472 BIAS, \
1473 RELU, \
1474 4, \
1475 BIAS_TYPE>( \
1476 uint8_t * out, \
1477 const int32_t* inp, \
1478 const block_type_t& block, \
1479 int ld_out, \
1480 int ld_in, \
1481 const requantizationParams_t<BIAS_TYPE>& r); \
1482 template void requantizeOutputProcessingGConvAvx2< \
1483 A_SYM, \
1484 B_SYM, \
1485 Q_GRAN, \
1486 BIAS, \
1487 RELU, \
1488 8, \
1489 BIAS_TYPE>( \
1490 uint8_t * out, \
1491 const int32_t* inp, \
1492 const block_type_t& block, \
1493 int ld_out, \
1494 int ld_in, \
1495 const requantizationParams_t<BIAS_TYPE>& r); \
1496 template void requantizeOutputProcessingGConvAvx2< \
1497 A_SYM, \
1498 B_SYM, \
1499 Q_GRAN, \
1500 BIAS, \
1501 RELU, \
1502 16, \
1503 BIAS_TYPE>( \
1504 uint8_t * out, \
1505 const int32_t* inp, \
1506 const block_type_t& block, \
1507 int ld_out, \
1508 int ld_in, \
1509 const requantizationParams_t<BIAS_TYPE>& r);
1510
1511#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
1512 INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, float) \
1513 INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, int32_t) \
1514 template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
1515 float* out, \
1516 const int32_t* inp, \
1517 const block_type_t& block, \
1518 int ld_out, \
1519 int ld_in, \
1520 const requantizationForFloatParams_t& r);
1521
1522#define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \
1523 INSTANTIATE_REQUANTIZE(true, B_SYM, Q_GRAN, BIAS, RELU) \
1524 INSTANTIATE_REQUANTIZE(false, B_SYM, Q_GRAN, BIAS, RELU)
1525
1526#define INSTANTIATE_B_SYM(Q_GRAN, BIAS, RELU) \
1527 INSTANTIATE_A_SYM(true, Q_GRAN, BIAS, RELU) \
1528 INSTANTIATE_A_SYM(false, Q_GRAN, BIAS, RELU)
1529
1530#define INSTANTIATE_Q_GRANS(BIAS, RELU) \
1531 INSTANTIATE_B_SYM(QuantizationGranularity::TENSOR, BIAS, RELU) \
1532 INSTANTIATE_B_SYM(QuantizationGranularity::GROUP, BIAS, RELU) \
1533 INSTANTIATE_B_SYM(QuantizationGranularity::OUT_CHANNEL, BIAS, RELU)
1534
1535#define INSTANTIATE_BIAS(RELU) \
1536 INSTANTIATE_Q_GRANS(true, RELU) \
1537 INSTANTIATE_Q_GRANS(false, RELU)
1538
1539INSTANTIATE_BIAS(true)
1540INSTANTIATE_BIAS(false)
1541
1542#undef INSTANTIATE_A_SYM
1543#undef INSTANTIATE_B_SYM
1544#undef INSTANTIATE_Q_GRANS
1545#undef INSTANTIATE_BIAS
1546
1547static inline uint16_t floatToHalf(float val) {
1548#ifdef _MSC_VER
1549 // Use _mm256_cvtps_ph/_mm256_cvtph_ps because _cvtsh_ss/_cvtss_sh don't
1550 // exist in MSVC.
1551 __m256 val_v = _mm256_set1_ps(val);
1552 __m128i val_half_v =
1553 _mm256_cvtps_ph(val_v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
1554 return static_cast<std::uint16_t>(_mm_cvtsi128_si32(val_half_v));
1555#else
1556 return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
1557#endif
1558}
1559static inline float halfToFloat(uint16_t val) {
1560#ifdef _MSC_VER
1561 return _mm256_cvtss_f32(_mm256_cvtph_ps(_mm_cvtsi32_si128(val)));
1562#else
1563 return _cvtsh_ss(val);
1564#endif
1565}
1566
1567template <typename InputType, int BIT_RATE>
1568void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
1569 const InputType* input,
1570 size_t input_rows,
1571 int input_columns,
1572 std::uint8_t* output) {
1573 static_assert(
1574 std::is_same<InputType, float>() || std::is_same<InputType, float16>(),
1575 "Only float and float16 types are allowed.");
1576 constexpr int VLEN = 8;
1577 constexpr int NUM_ELEM_PER_BYTE = 8 / BIT_RATE;
1578 int output_columns =
1579 (input_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE +
1580 2 * sizeof(std::uint16_t);
1581
1582 float* input_row_float_for_fp16;
1583 if (std::is_same<InputType, float16>()) {
1584 input_row_float_for_fp16 = static_cast<float*>(
1585 fbgemmAlignedAlloc(64, input_columns * sizeof(float)));
1586 }
1587
1588 for (size_t row = 0; row < input_rows; ++row) {
1589 const InputType* input_row = input + row * input_columns;
1590 const float* input_row_float;
1591 if (std::is_same<InputType, float>()) {
1592 // NOTE: this reinterpret_cast is only to workaround c++
1593 // type requirements -- it is not for fp16 case and `input_row` HAS to be
1594 // float* type. Remove it and use constexpr when pytorch allows C++17.
1595 input_row_float = reinterpret_cast<const float*>(input_row);
1596 } else {
1597 input_row_float = input_row_float_for_fp16;
1598 }
1599
1600 std::uint8_t* output_row = output + row * output_columns;
1601 std::uint16_t* output_row_scale_bias = reinterpret_cast<std::uint16_t*>(
1602 output_row +
1603 (input_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE);
1604
1605 float minimum_element = FLT_MAX;
1606 float maximum_element = -FLT_MAX;
1607 __m256 min_v = _mm256_set1_ps(minimum_element);
1608 __m256 max_v = _mm256_set1_ps(maximum_element);
1609
1610 int col;
1611 for (col = 0; col < input_columns / VLEN * VLEN; col += VLEN) {
1612 __m256 in_v;
1613 if (std::is_same<InputType, float>()) {
1614 in_v = _mm256_loadu_ps(input_row_float + col);
1615 } else {
1616 __m128i in_half_v =
1617 _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_row + col));
1618 in_v = _mm256_cvtph_ps(in_half_v);
1619 _mm256_store_ps(input_row_float_for_fp16 + col, in_v);
1620 }
1621
1622 min_v = _mm256_min_ps(min_v, in_v);
1623 max_v = _mm256_max_ps(max_v, in_v);
1624 }
1625 alignas(64) float min_buf[VLEN], max_buf[VLEN];
1626 _mm256_store_ps(min_buf, min_v);
1627 _mm256_store_ps(max_buf, max_v);
1628 for (int i = 0; i < VLEN; ++i) {
1629 minimum_element = std::min(minimum_element, min_buf[i]);
1630 maximum_element = std::max(maximum_element, max_buf[i]);
1631 }
1632
1633 for (; col < input_columns; ++col) {
1634 if (std::is_same<InputType, float>()) {
1635 minimum_element = std::min(minimum_element, input_row_float[col]);
1636 maximum_element = std::max(maximum_element, input_row_float[col]);
1637 } else {
1638 float element = halfToFloat(input_row[col]);
1639 input_row_float_for_fp16[col] = element;
1640 minimum_element = std::min(minimum_element, element);
1641 maximum_element = std::max(maximum_element, element);
1642 }
1643 }
1644
1645 output_row_scale_bias[1] = floatToHalf(minimum_element);
1646 minimum_element = halfToFloat(output_row_scale_bias[1]);
1647 const float range = maximum_element - minimum_element;
1648
1649 float scale = range == 0 ? 1.0f : range / ((1 << BIT_RATE) - 1);
1650 std::uint16_t scale_fp16 = floatToHalf(scale);
1651 scale = halfToFloat(scale_fp16);
1652 if (scale == 0) {
      // Corner case handling when maximum_element == minimum_element (or the
      // range is so small that the fp16-rounded scale underflows to 0).
      // Any scale would work because X - minimum_element will be (close to)
      // 0 for all X.
1656 scale = 1.0f;
1657 }
1658 float inverse_scale = 1.0f / scale;
1659 if (std::isinf(inverse_scale)) {
1660 scale = 1.0f;
1661 inverse_scale = 1.0f;
1662 }
1663
1664 output_row_scale_bias[0] = floatToHalf(scale);
1665
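    // Quantize each element as q = round((X - minimum_element) * inverse_scale)
    // clipped to [0, 2^BIT_RATE - 1]; dequantization reconstructs
    // X ~= scale * q + minimum_element.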
1666 col = 0;
1667 if (BIT_RATE == 2 || BIT_RATE == 4) {
1668 __m256i permute_mask1_v =
1669 _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
1670 __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
1671 min_v = _mm256_set1_ps(minimum_element);
1672
1673 for (; col + 4 * VLEN <= input_columns; col += 4 * VLEN) {
1674 __m256i x_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
1675 _mm256_sub_ps(_mm256_loadu_ps(input_row_float + col), min_v),
1676 inverse_scale_v));
1677 __m256i y_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
1678 _mm256_sub_ps(_mm256_loadu_ps(input_row_float + col + VLEN), min_v),
1679 inverse_scale_v));
1680 __m256i z_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
1681 _mm256_sub_ps(
1682 _mm256_loadu_ps(input_row_float + col + 2 * VLEN), min_v),
1683 inverse_scale_v));
1684 __m256i w_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
1685 _mm256_sub_ps(
1686 _mm256_loadu_ps(input_row_float + col + 3 * VLEN), min_v),
1687 inverse_scale_v));
1688
1689 // An instruction sequence to save 32 32-bit integers as 8-bit integers
1690 __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
1691 __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
1692 __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
1693 xyzw_packed_v =
1694 _mm256_permutevar8x32_epi32(xyzw_packed_v, permute_mask1_v);
1695
1696 // saturate to BIT_RATE
1697 xyzw_packed_v = _mm256_min_epu8(
1698 xyzw_packed_v,
1699 _mm256_set1_epi8(static_cast<char>((1 << BIT_RATE) - 1)));
1700
1701 if (BIT_RATE == 4) {
1702 // pack into lower 8-bit of each 16-bit
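          // Each 16-bit lane holds two quantized values a (low byte) and
          // b (high byte), both <= 15. OR-ing with the lane shifted right by 4
          // and masking with 0x00ff leaves (b << 4) | a in the low byte, i.e.
          // the later element in the high nibble.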
1703 xyzw_packed_v = _mm256_and_si256(
1704 _mm256_or_si256(
1705 xyzw_packed_v, _mm256_srli_epi16(xyzw_packed_v, 4)),
1706 _mm256_set1_epi16(0x00ff));
1707 } else {
1708 // pack into lower 8-bit of each 32-bit
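          // Each 32-bit lane holds four quantized values a, b, c, d (one per
          // byte), each <= 3. OR-ing with the lane shifted right by 6, 12, and
          // 18 and masking with 0x00ff collects them into the low byte as
          // (d << 6) | (c << 4) | (b << 2) | a.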
1709 xyzw_packed_v = _mm256_and_si256(
1710 _mm256_or_si256(
1711 _mm256_or_si256(
1712 xyzw_packed_v, _mm256_srli_epi32(xyzw_packed_v, 6)),
1713 _mm256_or_si256(
1714 _mm256_srli_epi32(xyzw_packed_v, 8 + 4),
1715 _mm256_srli_epi32(xyzw_packed_v, 2 * 8 + 2))),
1716 _mm256_set1_epi32(0x00ff));
1717 }
1718
1719 __m128i out_v;
1720 if (BIT_RATE == 4) {
1721 // avx2 doesn't have _mm256_cvtepi16_epi8
1722 out_v = _mm_packus_epi16(
1723 _mm256_castsi256_si128(xyzw_packed_v),
1724 _mm256_extractf128_si256(xyzw_packed_v, 1));
1725 _mm_storeu_si128(
1726 reinterpret_cast<__m128i*>(output_row + col / NUM_ELEM_PER_BYTE),
1727 out_v);
1728 } else {
1729 // avx2 doesn't have _mm256_cvtepi32_epi8
1730 out_v = _mm_packus_epi32(
1731 _mm256_castsi256_si128(xyzw_packed_v),
1732 _mm256_extractf128_si256(xyzw_packed_v, 1));
1733 out_v = _mm_packus_epi16(out_v, out_v);
1734 _mm_storel_epi64(
1735 reinterpret_cast<__m128i*>(output_row + col / NUM_ELEM_PER_BYTE),
1736 out_v);
1737 }
1738 }
1739 }
1740
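    // Scalar tail: quantize the remaining elements and pack NUM_ELEM_PER_BYTE
    // of them per byte, with element j of a byte placed at bit offset
    // j * BIT_RATE (matching the vectorized packing above).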
1741 for (; col < input_columns; ++col) {
1742 float X = input_row_float[col];
1743 std::uint8_t quantized = std::max(
1744 0,
1745 std::min<int>(
1746 std::lrintf((X - minimum_element) * inverse_scale),
1747 (1 << BIT_RATE) - 1));
1748 if (col % NUM_ELEM_PER_BYTE == 0) {
1749 output_row[col / NUM_ELEM_PER_BYTE] = quantized;
1750 } else {
1751 output_row[col / NUM_ELEM_PER_BYTE] |=
1752 (quantized << ((col % NUM_ELEM_PER_BYTE) * BIT_RATE));
1753 }
1754 }
1755 } // for each row
1756
1757 if (std::is_same<InputType, float16>()) {
1758 fbgemmAlignedFree(input_row_float_for_fp16);
1759 }
1760}
1761
1762template <typename InputType>
1763void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2(
1764 const InputType* input,
1765 size_t input_rows,
1766 int input_columns,
1767 std::uint8_t* output) {
1768 constexpr int VLEN = 8;
1769 constexpr float kEpsilon = 1e-8f;
1770
1771 __m256i permute_mask1_v =
1772 _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
1773 // clang-format off
1774 __m256i shuffle_mask_v = _mm256_set_epi8(
1775 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
1776 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00,
1777 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
1778 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00);
1779 // clang-format on
1780
1781 __m256i permute_mask2_v =
1782 _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
1783
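  // Each output row stores input_columns quantized bytes followed by two
  // float values: the scale (range / 255) and the bias (row minimum).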
1784 int output_columns = input_columns + 2 * sizeof(float);
  float* input_row_float_for_fp16 = nullptr;
1786 if (std::is_same<InputType, float16>()) {
1787 input_row_float_for_fp16 = static_cast<float*>(
1788 fbgemmAlignedAlloc(64, input_columns * sizeof(float)));
1789 }
1790 for (size_t row = 0; row < input_rows; ++row) {
1791 const InputType* input_row = input + row * input_columns;
1792 const float* input_row_float;
1793 if (std::is_same<InputType, float>()) {
      // NOTE: this reinterpret_cast only works around C++ type requirements --
      // this branch is only taken when InputType is float, so `input_row`
      // already points to float data. Remove the cast and use `if constexpr`
      // once PyTorch allows C++17.
1797 input_row_float = reinterpret_cast<const float*>(input_row);
1798 } else {
1799 input_row_float = input_row_float_for_fp16;
1800 }
1801 std::uint8_t* output_row = output + row * output_columns;
1802 float* output_row_scale_bias =
1803 reinterpret_cast<float*>(output_row + input_columns);
1804
1805 float minimum_element = FLT_MAX;
1806 float maximum_element = -FLT_MAX;
1807 __m256 min_v = _mm256_set1_ps(minimum_element);
1808 __m256 max_v = _mm256_set1_ps(maximum_element);
1809 int col;
1810 for (col = 0; col < input_columns / VLEN * VLEN; col += VLEN) {
1811 __m256 in_v;
1812 if (std::is_same<InputType, float>()) {
1813 in_v = _mm256_loadu_ps(input_row_float + col);
1814 } else {
1815 __m128i in_half_v =
1816 _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_row + col));
1817 in_v = _mm256_cvtph_ps(in_half_v);
1818 _mm256_store_ps(input_row_float_for_fp16 + col, in_v);
1819 }
1820 min_v = _mm256_min_ps(min_v, in_v);
1821 max_v = _mm256_max_ps(max_v, in_v);
1822 }
1823 alignas(64) float min_buf[VLEN], max_buf[VLEN];
1824 _mm256_store_ps(min_buf, min_v);
1825 _mm256_store_ps(max_buf, max_v);
1826 for (int i = 0; i < VLEN; ++i) {
1827 minimum_element = std::min(minimum_element, min_buf[i]);
1828 maximum_element = std::max(maximum_element, max_buf[i]);
1829 }
1830
1831 for (; col < input_columns; ++col) {
1832 if (std::is_same<InputType, float>()) {
1833 minimum_element = std::min(minimum_element, input_row_float[col]);
1834 maximum_element = std::max(maximum_element, input_row_float[col]);
1835 } else {
1836 float element = halfToFloat(input_row[col]);
1837 input_row_float_for_fp16[col] = element;
1838 minimum_element = std::min(minimum_element, element);
1839 maximum_element = std::max(maximum_element, element);
1840 }
1841 }
1842
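    // The scale maps [minimum_element, maximum_element] onto [0, 255];
    // kEpsilon keeps inverse_scale finite when the range is 0.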
1843 float range = maximum_element - minimum_element;
1844 output_row_scale_bias[0] = range / 255.0f;
1845 output_row_scale_bias[1] = minimum_element;
1846 const auto inverse_scale = 255.0f / (range + kEpsilon);
1847 min_v = _mm256_set1_ps(minimum_element);
1848 __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
1849
1850 for (col = 0; col < input_columns / (4 * VLEN) * (4 * VLEN);
1851 col += 4 * VLEN) {
1852 __m256i x_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
1853 _mm256_sub_ps(_mm256_loadu_ps(input_row_float + col), min_v),
1854 inverse_scale_v));
1855 __m256i y_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
1856 _mm256_sub_ps(_mm256_loadu_ps(input_row_float + col + VLEN), min_v),
1857 inverse_scale_v));
1858 __m256i z_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
1859 _mm256_sub_ps(
1860 _mm256_loadu_ps(input_row_float + col + 2 * VLEN), min_v),
1861 inverse_scale_v));
1862 __m256i w_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
1863 _mm256_sub_ps(
1864 _mm256_loadu_ps(input_row_float + col + 3 * VLEN), min_v),
1865 inverse_scale_v));
1866
1867 // An instruction sequence to save 32 32-bit integers as 8-bit integers
1868 __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
1869 __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
1870 __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
1871 xyzw_packed_v =
1872 _mm256_permutevar8x32_epi32(xyzw_packed_v, permute_mask1_v);
1873 _mm256_storeu_si256(
1874 reinterpret_cast<__m256i*>(output_row + col), xyzw_packed_v);
1875 }
1876 for (; col < input_columns / VLEN * VLEN; col += VLEN) {
1877 __m256i rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
1878 _mm256_sub_ps(_mm256_loadu_ps(input_row_float + col), min_v),
1879 inverse_scale_v));
1880
1881 // An instruction sequence to save 8 32-bit integers as 8-bit integers
1882 rounded_v = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v);
1883 rounded_v = _mm256_permutevar8x32_epi32(rounded_v, permute_mask2_v);
1884 _mm_storel_epi64(
1885 reinterpret_cast<__m128i*>(output_row + col),
1886 _mm256_castsi256_si128(rounded_v));
1887 }
1888 for (; col < input_columns; ++col) {
1889 output_row[col] =
1890 std::lrintf((input_row_float[col] - minimum_element) * inverse_scale);
1891 }
1892 } // for each row
1893 if (std::is_same<InputType, float16>()) {
1894 fbgemmAlignedFree(input_row_float_for_fp16);
1895 }
1896}
1897
1898template <typename OutputType, int BIT_RATE>
1899void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2(
1900 const std::uint8_t* input,
1901 size_t input_rows,
1902 int input_columns,
1903 OutputType* output) {
1904 static_assert(
1905 std::is_same<OutputType, float>() || std::is_same<OutputType, float16>(),
1906 "Only float and float16 types are allowed.");
1907 constexpr int VLEN = 8;
1908 constexpr int NUM_ELEM_PER_BYTE = 8 / BIT_RATE;
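  // Each input row is (input_columns - 2 * sizeof(uint16_t)) bytes of packed
  // BIT_RATE-bit values followed by an fp16 scale and an fp16 bias.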
1909 int output_columns =
1910 (input_columns - 2 * sizeof(uint16_t)) * NUM_ELEM_PER_BYTE;
1911
  // Compute a remainder for vector load
  // Since every row is followed by 2 fp16 values (scale and bias), we don't
  // need a mask at bit-rate granularity, only at 32-bit granularity.
1916 constexpr int NUM_ELEM_PER_32BIT = 32 / BIT_RATE;
1917 // multiply by 4 because we're handling 4 vlen per iteration
1918 constexpr int NUM_OF_32BIT_PER_VLOAD = VLEN * 4 / NUM_ELEM_PER_32BIT;
1919
1920 int remainder_32bit_granularity, remainder;
1921 __m128i vmask_load;
1922 __m256i vmask_store0, vmask_store1, vmask_store2, vmask_store3;
1923 if (BIT_RATE == 4 || BIT_RATE == 2) {
1924 remainder_32bit_granularity = (output_columns + NUM_ELEM_PER_32BIT - 1) /
1925 NUM_ELEM_PER_32BIT % NUM_OF_32BIT_PER_VLOAD;
1926 vmask_load = _mm_lddqu_si128(reinterpret_cast<const __m128i*>(
1927 internal::avx2_ps_or_epi32_combined_mask + NUM_OF_32BIT_PER_VLOAD +
1928 (NUM_OF_32BIT_PER_VLOAD - remainder_32bit_granularity) %
1929 NUM_OF_32BIT_PER_VLOAD));
1930 remainder = output_columns % (4 * VLEN);
1931 int remainder_ratio = 1;
1932 if (std::is_same<OutputType, float16>()) {
1933 // For fp16 we only need half of the mask.
1934 //
      // For instance, if the remainder is 2, for FP32 the masks are
      // {-1, -1, 0, ..., 0}, {0, ..., 0}, {0, ..., 0}, {0, ..., 0}
      // (8 32-bit integers for each mask)
      // while for FP16 we only need
      // {-1, 0, 0, 0}, {0, ..., 0}, {0, ..., 0}, {0, ..., 0}
      // (4 32-bit integers for each mask)
      // since we reinterpret 2 FP16 numbers as one 32-bit number.
      // NOTE: for bit_rate 4 or 2, remainders are always a multiple of 2 or 4,
      // so we do not have to worry about an odd number of FP16 numbers.
      //
      // Or, if the remainder is 30, for FP32 the masks are
1946 // {-1, ..., -1}, {-1, ..., -1}, {-1, ..., -1}, {-1, .., -1, 0, 0}
1947 // for FP16 we only need
1948 // {-1, ..., -1}, {-1, ..., -1}, {-1, ..., -1}, {-1, -1, -1, 0}
1949 remainder_ratio = 2;
1950 }
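    // vmask_store0..3 enable only the lanes that are still valid in the k-th
    // VLEN-wide chunk (k = 0..3) of the last, partial 4 * VLEN block; for fp16
    // output each 32-bit lane covers two elements, hence remainder_ratio.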
1951 vmask_store0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
1952 internal::avx2_ps_or_epi32_combined_mask +
1953 (VLEN - std::min(remainder, VLEN) / remainder_ratio % (VLEN + 1))));
1954 vmask_store1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
1955 internal::avx2_ps_or_epi32_combined_mask +
1956 (VLEN -
1957 std::max(0, std::min(remainder - VLEN, VLEN) / remainder_ratio) %
1958 (VLEN + 1))));
1959 vmask_store2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
1960 internal::avx2_ps_or_epi32_combined_mask +
1961 (VLEN -
1962 std::max(0, std::min(remainder - 2 * VLEN, VLEN) / remainder_ratio) %
1963 (VLEN + 1))));
1964 vmask_store3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
1965 internal::avx2_ps_or_epi32_combined_mask +
1966 (VLEN -
1967 std::max(0, std::min(remainder - 3 * VLEN, VLEN) / remainder_ratio) %
1968 (VLEN + 1))));
1969 }
1970
1971 for (size_t row = 0; row < input_rows; ++row) {
1972 const std::uint8_t* input_row = input + row * input_columns;
1973 const uint16_t* input_row_scale_bias = reinterpret_cast<const uint16_t*>(
1974 input_row +
1975 (output_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE);
1976 float scale = halfToFloat(input_row_scale_bias[0]);
1977 float bias = halfToFloat(input_row_scale_bias[1]);
1978 OutputType* output_row = output + row * output_columns;
    float* output_row_float = nullptr;
1980 if (std::is_same<OutputType, float>()) {
      // NOTE: this reinterpret_cast only works around C++ type requirements --
      // this branch is only taken when OutputType is float, so `output_row`
      // already points to float data. Remove the cast and use `if constexpr`
      // once PyTorch allows C++17.
1984 output_row_float = reinterpret_cast<float*>(output_row);
1985 }
1986
1987 int col = 0;
1988 if (BIT_RATE == 4 || BIT_RATE == 2) {
1989 __m256 vscale = _mm256_set1_ps(scale);
1990 __m256 vbias = _mm256_set1_ps(bias);
1991 for (; col + 4 * VLEN <= output_columns; col += 4 * VLEN) {
1992 __m256i vinq;
1993 // unpack to 8-bit integers
1994 if (BIT_RATE == 4) {
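          // Each input byte holds two 4-bit values (hi << 4) | lo. After
          // widening to 16-bit lanes, OR-ing with the lane shifted left by 4
          // and masking with 0x0f0f leaves lo in the low byte and hi in the
          // high byte of each lane.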
1995 vinq = _mm256_cvtepu8_epi16(
1996 _mm_loadu_si128(reinterpret_cast<const __m128i*>(
1997 input_row + col / NUM_ELEM_PER_BYTE)));
1998 vinq = _mm256_and_si256(
1999 _mm256_or_si256(vinq, _mm256_slli_epi32(vinq, 4)),
2000 _mm256_set1_epi16(0x0f0f));
2001 } else {
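          // Each input byte holds four 2-bit values (d << 6) | (c << 4) |
          // (b << 2) | a. After widening to 32-bit lanes, OR-ing with the lane
          // shifted left by 6, 12, and 18 and masking with 0x03030303 places
          // a, b, c, d into the four bytes of the lane.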
2002 vinq = _mm256_cvtepu8_epi32(
2003 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(
2004 input_row + col / NUM_ELEM_PER_BYTE)));
2005 vinq = _mm256_and_si256(
2006 _mm256_or_si256(
2007 _mm256_or_si256(
2008 _mm256_slli_epi32(vinq, 2 * 8 + 2),
2009 _mm256_slli_epi32(vinq, 8 + 4)),
2010 _mm256_or_si256(_mm256_slli_epi32(vinq, 6), vinq)),
2011 _mm256_set1_epi32(0x03030303));
2012 }
2013 __m256 vinq0 = _mm256_cvtepi32_ps(
2014 _mm256_cvtepi8_epi32(_mm256_castsi256_si128(vinq)));
2015 __m256 vinq1 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
2016 _mm_set1_epi64x(_mm256_extract_epi64(vinq, 1))));
2017 __m256 vinq2 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
2018 _mm_set1_epi64x(_mm256_extract_epi64(vinq, 2))));
2019 __m256 vinq3 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
2020 _mm_set1_epi64x(_mm256_extract_epi64(vinq, 3))));
2021 vinq0 = _mm256_fmadd_ps(vscale, vinq0, vbias);
2022 vinq1 = _mm256_fmadd_ps(vscale, vinq1, vbias);
2023 vinq2 = _mm256_fmadd_ps(vscale, vinq2, vbias);
2024 vinq3 = _mm256_fmadd_ps(vscale, vinq3, vbias);
2025
2026 if (std::is_same<OutputType, float>()) {
2027 _mm256_storeu_ps(output_row_float + col, vinq0);
2028 _mm256_storeu_ps(output_row_float + col + VLEN, vinq1);
2029 _mm256_storeu_ps(output_row_float + col + 2 * VLEN, vinq2);
2030 _mm256_storeu_ps(output_row_float + col + 3 * VLEN, vinq3);
2031 } else {
2032 _mm_storeu_si128(
2033 reinterpret_cast<__m128i*>(output_row + col),
2034 _mm256_cvtps_ph(
2035 vinq0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
2036 _mm_storeu_si128(
2037 reinterpret_cast<__m128i*>(output_row + col + VLEN),
2038 _mm256_cvtps_ph(
2039 vinq1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
2040 _mm_storeu_si128(
2041 reinterpret_cast<__m128i*>(output_row + col + 2 * VLEN),
2042 _mm256_cvtps_ph(
2043 vinq2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
2044 _mm_storeu_si128(
2045 reinterpret_cast<__m128i*>(output_row + col + 3 * VLEN),
2046 _mm256_cvtps_ph(
2047 vinq3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
2048 }
2049 }
2050
2051 if (remainder) {
2052 __m256i vinq;
2053 if (BIT_RATE == 4) {
2054 vinq = _mm256_cvtepu8_epi16(_mm_maskload_epi32(
2055 reinterpret_cast<const int*>(input_row + col / NUM_ELEM_PER_BYTE),
2056 vmask_load));
2057 vinq = _mm256_and_si256(
2058 _mm256_or_si256(vinq, _mm256_slli_epi32(vinq, 4)),
2059 _mm256_set1_epi16(0x0f0f));
2060 } else {
2061 vinq = _mm256_cvtepu8_epi32(_mm_maskload_epi32(
2062 reinterpret_cast<const int*>(input_row + col / NUM_ELEM_PER_BYTE),
2063 vmask_load));
2064 vinq = _mm256_and_si256(
2065 _mm256_or_si256(
2066 _mm256_or_si256(
2067 _mm256_slli_epi32(vinq, 2 * 8 + 2),
2068 _mm256_slli_epi32(vinq, 8 + 4)),
2069 _mm256_or_si256(_mm256_slli_epi32(vinq, 6), vinq)),
2070 _mm256_set1_epi32(0x03030303));
2071 }
2072
2073 __m256 vinq0 = _mm256_cvtepi32_ps(
2074 _mm256_cvtepi8_epi32(_mm256_castsi256_si128(vinq)));
2075 __m256 vinq1 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
2076 _mm_set1_epi64x(_mm256_extract_epi64(vinq, 1))));
2077 __m256 vinq2 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
2078 _mm_set1_epi64x(_mm256_extract_epi64(vinq, 2))));
2079 __m256 vinq3 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
2080 _mm_set1_epi64x(_mm256_extract_epi64(vinq, 3))));
2081
2082 vinq0 = _mm256_fmadd_ps(vscale, vinq0, vbias);
2083 vinq1 = _mm256_fmadd_ps(vscale, vinq1, vbias);
2084 vinq2 = _mm256_fmadd_ps(vscale, vinq2, vbias);
2085 vinq3 = _mm256_fmadd_ps(vscale, vinq3, vbias);
2086
2087 if (std::is_same<OutputType, float>()) {
2088 _mm256_maskstore_ps(output_row_float + col, vmask_store0, vinq0);
2089 _mm256_maskstore_ps(
2090 output_row_float + col + VLEN, vmask_store1, vinq1);
2091 _mm256_maskstore_ps(
2092 output_row_float + col + 2 * VLEN, vmask_store2, vinq2);
2093 _mm256_maskstore_ps(
2094 output_row_float + col + 3 * VLEN, vmask_store3, vinq3);
2095 } else {
2096 _mm_maskstore_epi32(
2097 reinterpret_cast<int*>(output_row + col),
2098 _mm256_castsi256_si128(vmask_store0),
2099 _mm256_cvtps_ph(
2100 vinq0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
2101 _mm_maskstore_epi32(
2102 reinterpret_cast<int*>(output_row + col + VLEN),
2103 _mm256_castsi256_si128(vmask_store1),
2104 _mm256_cvtps_ph(
2105 vinq1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
2106 _mm_maskstore_epi32(
2107 reinterpret_cast<int*>(output_row + col + 2 * VLEN),
2108 _mm256_castsi256_si128(vmask_store2),
2109 _mm256_cvtps_ph(
2110 vinq2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
2111 _mm_maskstore_epi32(
2112 reinterpret_cast<int*>(output_row + col + 3 * VLEN),
2113 _mm256_castsi256_si128(vmask_store3),
2114 _mm256_cvtps_ph(
2115 vinq3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
2116 }
2117 }
2118 } else {
2119 for (; col < output_columns; ++col) {
2120 std::uint8_t quantized = input_row[col / NUM_ELEM_PER_BYTE];
2121 quantized >>= (col % NUM_ELEM_PER_BYTE) * BIT_RATE;
2122 quantized &= (1 << BIT_RATE) - 1;
2123 float output_value = scale * quantized + bias;
2124 if (std::is_same<OutputType, float>()) {
2125 output_row[col] = output_value;
2126 } else {
2127 output_row[col] = cpu_float2half_rn(output_value);
2128 }
2129 }
2130 }
2131 } // for each row
2132}
2133
2134template <typename OutputType>
2135void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(
2136 const std::uint8_t* input,
2137 size_t input_rows,
2138 int input_columns,
2139 OutputType* output) {
2140 constexpr int VLEN = 8;
2141 int output_columns = input_columns - 2 * sizeof(float);
2142
2143 for (size_t row = 0; row < input_rows; ++row) {
2144 const std::uint8_t* input_row = input + row * input_columns;
2145 const float* input_row_scale_bias =
2146 reinterpret_cast<const float*>(input_row + output_columns);
2147 OutputType* output_row = output + row * output_columns;
2148
2149 __m256 scale_v = _mm256_set1_ps(input_row_scale_bias[0]);
2150 __m256 bias_v = _mm256_set1_ps(input_row_scale_bias[1]);
2151
2152 int col;
2153 for (col = 0; col < output_columns / VLEN * VLEN; col += VLEN) {
2154 __m256 in_v = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2155 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(input_row + col))));
#ifdef __FMA__
      __m256 dequantized_v = _mm256_fmadd_ps(in_v, scale_v, bias_v);
#else
      __m256 dequantized_v =
          _mm256_add_ps(_mm256_mul_ps(in_v, scale_v), bias_v);
#endif
      if (std::is_same<OutputType, float>()) {
        float* output_row_float = reinterpret_cast<float*>(output_row);
        _mm256_storeu_ps(output_row_float + col, dequantized_v);
      } else {
        _mm_storeu_si128(
            reinterpret_cast<__m128i*>(output_row + col),
            _mm256_cvtps_ph(
                dequantized_v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
2169 }
2170 }
2171
2172 for (; col < output_columns; ++col) {
2173 float output_value =
2174 input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1];
2175 if (std::is_same<OutputType, float>()) {
2176 output_row[col] = output_value;
2177 } else {
2178 output_row[col] = cpu_float2half_rn(output_value);
2179 }
2180 }
2181 } // for each row
2182}
2183
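// Illustrative usage of the N-bit pair above (a minimal sketch, not taken
// from this file): quantize a 2 x 16 float matrix to 4-bit rowwise storage
// with an fp16 scale and bias appended to each row, then dequantize it back.
//
//   std::vector<float> src(2 * 16);
//   // 16 elements -> 8 data bytes + 2 * sizeof(std::uint16_t) = 12 bytes/row
//   std::vector<std::uint8_t> fused(2 * 12);
//   std::vector<float> dst(2 * 16);
//   fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<float, 4>(
//       src.data(), /*input_rows=*/2, /*input_columns=*/16, fused.data());
//   fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<float, 4>(
//       fused.data(), /*input_rows=*/2, /*input_columns=*/12, dst.data());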
2184#define INSTANTIATE_QuantizationAvx2FunctionsNBits(type, bit_rate) \
2185 template void \
2186 FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<type, bit_rate>( \
2187 const type* input, \
2188 size_t input_rows, \
2189 int input_columns, \
2190 std::uint8_t* output); \
2191 template void \
2192 FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<type, bit_rate>( \
2193 const std::uint8_t* input, \
2194 size_t input_rows, \
2195 int input_columns, \
2196 type* output);
2197
2198// clang-format off
2199INSTANTIATE_QuantizationAvx2FunctionsNBits(float, 2)
2200INSTANTIATE_QuantizationAvx2FunctionsNBits(float, 4)
2201INSTANTIATE_QuantizationAvx2FunctionsNBits(float, 8)
2202INSTANTIATE_QuantizationAvx2FunctionsNBits(float16, 2)
2203INSTANTIATE_QuantizationAvx2FunctionsNBits(float16, 4)
2204INSTANTIATE_QuantizationAvx2FunctionsNBits(float16, 8)
2205// clang-format on
2206#undef INSTANTIATE_QuantizationAvx2FunctionsNBits
2207
2208#define INSTANTIATE_QuantizationAvx2Functions8Bits(type) \
2209 template void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2<type>( \
2210 const type* input, \
2211 size_t input_rows, \
2212 int input_columns, \
2213 std::uint8_t* output); \
2214 template void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2<type>( \
2215 const std::uint8_t* input, \
2216 size_t input_rows, \
2217 int input_columns, \
2218 type* output);
2219
// clang-format off
2221INSTANTIATE_QuantizationAvx2Functions8Bits(float)
2222INSTANTIATE_QuantizationAvx2Functions8Bits(float16)
2223// clang-format on
2224#undef INSTANTIATE_QuantizationAvx2Functions8Bits
2225
2226} // namespace fbgemm
2227