/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include "fbgemm/QuantUtilsAvx512.h"
#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif
#include <algorithm> // for std::min/std::max
#include <cassert>
#include <cmath> // for nearbyint
#include <limits> // for numeric_limits

namespace fbgemm {

using namespace std;

template <
    bool A_SYMMETRIC,
    bool B_SYMMETRIC,
    QuantizationGranularity Q_GRAN,
    bool HAS_BIAS,
    bool FUSE_RELU,
    int C_PER_G,
    typename BIAS_TYPE>
void requantizeOutputProcessingGConvAvx512(
    std::uint8_t* out,
    const std::int32_t* inp,
    const block_type_t& block,
    int ld_out,
    int ld_in,
    const requantizationParams_t<BIAS_TYPE>& r) {
  // Adaptation of the implementation in QNNPACK/src/requantization/fp32-sse2.c
  // using AVX512 instructions.
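  //
  // Scalar sketch of the requantization performed below (the vector code
  // computes the same thing up to 16 outputs at a time):
  //
  //   int32_t acc = inp[...];
  //   if (!A_SYMMETRIC) acc -= r.A_zero_point * r.col_offsets[col];
  //   if (!B_SYMMETRIC) acc -= B_zero_point * r.row_offsets[row];
  //   float x = acc + (HAS_BIAS ? bias[col] / act_times_w_scale : 0.0f);
  //   int32_t y = nearbyint(x * C_multiplier) + r.C_zero_point;
  //   out[...] = clamp(y, FUSE_RELU ? r.C_zero_point : 0, 255);
  //
  // B_zero_point, act_times_w_scale, and C_multiplier are selected per
  // tensor, per group, or per output channel according to Q_GRAN. The
  // bias / act_times_w_scale term applies to a float bias; an int32_t bias
  // is added to the accumulator directly.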
  int quant_param_idx = 0;
  if (Q_GRAN == QuantizationGranularity::GROUP) {
    int ncol_per_group = r.ncols / r.groups;
    int g = block.col_start / ncol_per_group;
    quant_param_idx = g;
  }
  __m512 multiplier_v = _mm512_set1_ps(r.C_multiplier[quant_param_idx]);
  // Broadcasted reciprocal of act_times_w_scale, used to scale a float bias.
  // It is only precomputed when the scale does not vary per output column;
  // the OUT_CHANNEL case divides by the per-column scale inside the loop.
  __m512 act_times_w_rcp_v;

  if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) {
    if (is_same<BIAS_TYPE, float>::value) {
      act_times_w_rcp_v =
          _mm512_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]);
    }
  }
  __m512i min_v = _mm512_set1_epi8(static_cast<uint8_t>(0));
  __m512i max_v = _mm512_set1_epi8(static_cast<uint8_t>(255));

  assert(
      (A_SYMMETRIC == (r.A_zero_point == 0)) &&
      "A_SYMMETRIC == true if and only if A_zero_point == 0");
  assert(
      (B_SYMMETRIC ==
       ((Q_GRAN == QuantizationGranularity::TENSOR && r.B_zero_point[0] == 0) ||
        r.row_offsets == nullptr)) &&
      "B_SYMMETRIC == true if and only if B_zero_point == 0 "
      "or r.row_offsets == nullptr");
  assert(
      (HAS_BIAS == (r.bias != nullptr)) &&
      "HAS_BIAS == true if and only if bias != nullptr");

  __m512i A_zero_point_v = _mm512_set1_epi32(r.A_zero_point);
  __m512i C_zero_point_epi16_v = _mm512_set1_epi16(r.C_zero_point);
  __m512i C_zero_point_epi8_v = _mm512_set1_epi8(r.C_zero_point);
  __m512i permute_mask_v_g8 = _mm512_set_epi32(
      0x0f,
      0x07,
      0x0e,
      0x06,
      0x0d,
      0x05,
      0x0c,
      0x04,
      0x0b,
      0x03,
      0x0a,
      0x02,
      0x09,
      0x01,
      0x08,
      0x00);
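  // Note: with _mm512_set_epi32 the last argument is lane 0, so from low lane
  // to high this index vector reads [0, 8, 1, 9, 2, 10, ..., 7, 15].
  // Permuting a value whose low 256 bits hold v0..v7 with this mask spreads
  // those 8 values into the even lanes; a following _mm512_moveldup_ps then
  // copies each even lane into the adjacent odd lane, producing
  // [v0, v0, v1, v1, ..., v7, v7], i.e. one value per channel for
  // C_PER_G == 2 (8 groups x 2 channels).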

  __m512i permute_mask_v_g4 = _mm512_set_epi32(
      0x0f,
      0x0b,
      0x07,
      0x03,
      0x0e,
      0x0a,
      0x06,
      0x02,
      0x0d,
      0x09,
      0x05,
      0x01,
      0x0c,
      0x08,
      0x04,
      0x00);
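  // From low lane to high this index vector reads [0, 4, 8, 12, 1, 5, 9, 13,
  // ...]. Applied to the result of _mm512_broadcast_i32x4/_f32x4 of
  // [v0, v1, v2, v3] it yields [v0 x4, v1 x4, v2 x4, v3 x4], i.e. one value
  // per channel for C_PER_G == 4 (4 groups x 4 channels). The same mask is
  // reused at the end of the loop to gather dwords 0, 4, 8, and 12 of the
  // packed result into the low 128 bits before storing.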
  // 16 int32 elements fill one 512-bit vector (16 * 32 = 512 bits).
  constexpr int VLEN = 16;
  // Mask selecting the low 8 int32 lanes, used for the C_PER_G == 8 case
  // where only 8 outputs are produced per iteration.
  const __mmask16 mask = 0x00ff;

  for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
    int j = block.col_start;
    // The iteration count of the inner loop is rounded up so that the
    // C_PER_G == 8 case (8 outputs instead of 16) still runs. For the
    // currently supported cases (C_PER_G == 2, 4, 8, 16) the inner loop
    // executes exactly one iteration, and for C_PER_G == 8 all loads are
    // masked to the low 8 lanes.
    for (; j < block.col_start + ((block.col_size + VLEN - 1) / VLEN * VLEN);
         j += VLEN) {
      __m512i x_v;
      if (C_PER_G != 8) {
        x_v = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(
            inp + (i - block.row_start) * ld_in + (j - block.col_start)));
      } else {
        // Only C_PER_G == 2, 4, 8, 16 are supported, so this j loop runs a
        // single iteration; the masked load below would be wrong if it ran
        // more than once.
        x_v = _mm512_maskz_loadu_epi32(
            mask, inp + (i - block.row_start) * ld_in + (j - block.col_start));
      }

      if (!A_SYMMETRIC) {
        __m512i col_off_raw_v;
        if (C_PER_G != 8) {
          col_off_raw_v = _mm512_loadu_si512(
              reinterpret_cast<const __m512i*>(r.col_offsets + j));
        } else {
          col_off_raw_v = _mm512_maskz_loadu_epi32(mask, r.col_offsets + j);
        }

        __m512i col_off_v = _mm512_mullo_epi32(A_zero_point_v, col_off_raw_v);
        x_v = _mm512_sub_epi32(x_v, col_off_v);
      }

      if (!B_SYMMETRIC) {
        __m512i row_offset_v;

        if (C_PER_G == 2) {
          // When C_PER_G == 2 we handle 8 groups at a time to fully utilize
          // the 64B AVX-512 register (8 groups * 2 channels * sizeof(int32_t)
          // == 64B): load the row offsets of 8 groups and duplicate each one
          // into a pair of adjacent lanes.
          row_offset_v =
              _mm512_castps_si512(_mm512_moveldup_ps(_mm512_permutexvar_ps(
                  permute_mask_v_g8,
                  _mm512_castps256_ps512(
                      _mm256_loadu_ps(reinterpret_cast<const float*>(
                          r.row_offsets + (i - block.row_start) * 8))))));
        } else if (C_PER_G == 4) {
          // When C_PER_G == 4 we handle 4 groups at a time (4 groups * 4
          // channels * sizeof(int32_t) == 64B): load the row offsets of
          // groups 0-3 and broadcast each one 4 times, once per channel.
          row_offset_v = _mm512_permutexvar_epi32(
              permute_mask_v_g4,
              _mm512_broadcast_i32x4(
                  _mm_loadu_si128(reinterpret_cast<const __m128i*>(
                      r.row_offsets + (i - block.row_start) * 4))));
        } else if (C_PER_G == 8) {
          // When C_PER_G == 8 or 16, a single group's row offset is broadcast
          // to every lane.
          row_offset_v =
              _mm512_set1_epi32(r.row_offsets[(i - block.row_start)]);
        } else {
          assert(C_PER_G == 16);
          row_offset_v =
              _mm512_set1_epi32(r.row_offsets[(i - block.row_start)]);
        }

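        // Pick up the B zero point(s) matching the quantization granularity:
        // one value per output column for OUT_CHANNEL, one per group
        // (broadcast with the same shuffles as the row offsets above) for
        // GROUP, and the single tensor-wide value otherwise.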
        __m512i B_zero_point_v = _mm512_set1_epi32(r.B_zero_point[0]);
        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
          if (C_PER_G != 8) {
            B_zero_point_v = _mm512_loadu_si512(
                reinterpret_cast<const __m512i*>(r.B_zero_point + j));
          } else {
            B_zero_point_v = _mm512_maskz_loadu_epi32(mask, r.B_zero_point + j);
          }
        } else if (Q_GRAN == QuantizationGranularity::GROUP) {
          if (C_PER_G == 2) {
            B_zero_point_v =
                _mm512_castps_si512(_mm512_moveldup_ps(_mm512_permutexvar_ps(
                    permute_mask_v_g8,
                    _mm512_castps256_ps512(
                        _mm256_loadu_ps(reinterpret_cast<const float*>(
                            r.B_zero_point + quant_param_idx))))));
          } else if (C_PER_G == 4) {
            B_zero_point_v = _mm512_permutexvar_epi32(
                permute_mask_v_g4,
                _mm512_broadcast_i32x4(
                    _mm_loadu_si128(reinterpret_cast<const __m128i*>(
                        r.B_zero_point + quant_param_idx))));
          } else if (C_PER_G == 8) {
            B_zero_point_v = _mm512_set1_epi32(r.B_zero_point[quant_param_idx]);
          } else {
            B_zero_point_v = _mm512_set1_epi32(r.B_zero_point[quant_param_idx]);
          }
        }
        row_offset_v = _mm512_mullo_epi32(row_offset_v, B_zero_point_v);
        x_v = _mm512_sub_epi32(x_v, row_offset_v);
      }
      __m512 xf_v;
      if (HAS_BIAS) {
        if (is_same<BIAS_TYPE, float>::value) {
          __m512 x_bias_v;
          if (C_PER_G != 8) {
            x_bias_v =
                _mm512_loadu_ps(reinterpret_cast<const float*>(r.bias + j));
          } else {
            x_bias_v = _mm512_maskz_loadu_ps(
                mask, reinterpret_cast<const float*>(r.bias + j));
          }

          if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
            __m512 act_times_w_scale_v;
            if (C_PER_G != 8) {
              act_times_w_scale_v = _mm512_loadu_ps(r.act_times_w_scale + j);
            } else {
              act_times_w_scale_v =
                  _mm512_maskz_loadu_ps(mask, r.act_times_w_scale + j);
            }
            x_bias_v = _mm512_div_ps(x_bias_v, act_times_w_scale_v);
          } else if (Q_GRAN == QuantizationGranularity::GROUP) {
            __m512 divisor_v;
            if (C_PER_G == 2) {
              divisor_v = _mm512_moveldup_ps(_mm512_permutexvar_ps(
                  permute_mask_v_g8,
                  _mm512_castps256_ps512(
                      _mm256_loadu_ps(r.act_times_w_scale + quant_param_idx))));
            } else if (C_PER_G == 4) {
              divisor_v = _mm512_permutexvar_ps(
                  permute_mask_v_g4,
                  _mm512_broadcast_f32x4(
                      _mm_loadu_ps(r.act_times_w_scale + quant_param_idx)));
            } else if (C_PER_G == 8) {
              divisor_v = _mm512_set1_ps(r.act_times_w_scale[quant_param_idx]);
            } else {
              assert(C_PER_G == 16);
              divisor_v = _mm512_set1_ps(r.act_times_w_scale[quant_param_idx]);
            }
            x_bias_v = _mm512_div_ps(x_bias_v, divisor_v);
          } else {
            x_bias_v = _mm512_mul_ps(x_bias_v, act_times_w_rcp_v);
          }
          xf_v = _mm512_add_ps(_mm512_cvtepi32_ps(x_v), x_bias_v);
        } else {
          x_v = _mm512_add_epi32(
              x_v,
              _mm512_loadu_si512(reinterpret_cast<const __m512i*>(r.bias + j)));
          xf_v = _mm512_cvtepi32_ps(x_v);
        }
      } else {
        xf_v = _mm512_cvtepi32_ps(x_v);
      }

      /*
       * Convert the int32_t input to FP32 and multiply by the FP32 scale.
       * Both operations involve statistically unbiased roundings (with
       * default MXCSR rounding mode):
       * - Large int32_t values can't be exactly represented as FP32. The
       *   CVTDQ2PS instruction on x86 rounds them to the nearest FP32 value
       *   with ties to even (assuming default MXCSR rounding mode).
       * - The product of two FP32 values is generally not exactly
       *   representable as an FP32 value, and is rounded to the nearest
       *   FP32 value with ties to even with default MXCSR rounding mode.
       */
      __m512 x_scaled_v;
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
        __m512 C_multiplier_v;
        if (C_PER_G != 8) {
          C_multiplier_v = _mm512_loadu_ps(r.C_multiplier + j);
        } else {
          C_multiplier_v = _mm512_maskz_loadu_ps(mask, r.C_multiplier + j);
        }
        x_scaled_v = _mm512_mul_ps(xf_v, C_multiplier_v);
      } else if (Q_GRAN == QuantizationGranularity::GROUP) {
        if (C_PER_G == 2) {
          multiplier_v = _mm512_moveldup_ps(_mm512_permutexvar_ps(
              permute_mask_v_g8,
              _mm512_castps256_ps512(
                  _mm256_loadu_ps(r.C_multiplier + quant_param_idx))));
        } else if (C_PER_G == 4) {
          multiplier_v = _mm512_permutexvar_ps(
              permute_mask_v_g4,
              _mm512_broadcast_f32x4(
                  _mm_loadu_ps(r.C_multiplier + quant_param_idx)));
        } else if (C_PER_G == 8) {
          multiplier_v = _mm512_set1_ps(r.C_multiplier[quant_param_idx]);
        } else {
          multiplier_v = _mm512_set1_ps(r.C_multiplier[quant_param_idx]);
        }
        x_scaled_v = _mm512_mul_ps(xf_v, multiplier_v);
      } else {
        x_scaled_v = _mm512_mul_ps(xf_v, multiplier_v);
      }

      /*
       * Convert the scaled FP32 result to int32_t using the CVTPS2DQ
       * instruction. CVTPS2DQ converts to the nearest int32_t value, with
       * ties rounded to even (assuming default MXCSR rounding mode).
       * However, when the conversion overflows it produces INT32_MIN as the
       * result. For large positive inputs the result of the conversion can
       * become negative, which affects the final requantization result. Note
       * that on x86 SSE2 we have e.g. int32_t(float(INT32_MAX)) == INT32_MIN!
       * This happens because float(INT32_MAX) rounds to 2**31, which
       * overflows int32_t when it is converted back to integer.
       *
       * Thankfully, we can prove that overflow never happens in this
       * requantization scheme. The largest positive input is INT32_MAX
       * (2**31 - 1), which turns into 2**31 when converted to float. The
       * largest scale value is 0x1.FFFFFEp-1. When multiplied together, the
       * result is 2147483520 (compare to INT32_MAX = 2147483647), which
       * fits into int32_t without overflow.
       */
      __m512i x_rounded_v = _mm512_cvtps_epi32(x_scaled_v);

      /*
       * Standard final sequence on x86 AVX512:
       * - Pack to int16_t and saturate
       * - Add zero point
       * - Pack to uint8_t and saturate
       * - Clamp between qmin and qmax
       */
      __m512i x_packed_v = _mm512_adds_epi16(
          _mm512_packs_epi32(x_rounded_v, _mm512_setzero_si512()),
          C_zero_point_epi16_v);
      x_packed_v = _mm512_packus_epi16(x_packed_v, _mm512_setzero_si512());
      __m512i x_clamped_v = _mm512_max_epu8(
          FUSE_RELU ? C_zero_point_epi8_v : min_v,
          _mm512_min_epu8(x_packed_v, max_v));

      /*
       * x_clamped_v has results in the following layout so we need to
       * permute: x0-3 garbage0-11 x4-7 garbage12-23 x8-11 garbage24-35 x12-15
       * garbage36-47
       */
      x_clamped_v = _mm512_permutexvar_epi32(permute_mask_v_g4, x_clamped_v);
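      // After this permute the valid result bytes end up in the low 128 bits
      // of x_clamped_v (the low 64 bits when C_PER_G == 8), which is why a
      // single _mm_storeu_si128 / _mm_storel_epi64 below is enough.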

      /*
       * 1x CVTDQ2PS
       * 1x MULPS
       * 1x CVTPS2DQ
       * 1x PACKSSDW
       * 1x PACKUSWB
       * 1x PADDW
       * 1x PMAXUB
       * 1x PMINUB
       * 1x PERMD
       * ---------------------
       * 9 instructions total
       */
      if (C_PER_G != 8) {
        _mm_storeu_si128(
            reinterpret_cast<__m128i*>(out + i * ld_out + j),
            _mm512_castsi512_si128(x_clamped_v));
      } else {
        _mm_storel_epi64(
            reinterpret_cast<__m128i*>(out + i * ld_out + j),
            _mm512_castsi512_si128(x_clamped_v));
      }
    } // j loop vectorized

#ifndef NDEBUG
    int remainder = block.col_start + block.col_size - j;
    assert(remainder == 0 || C_PER_G == 8);
#endif
  } // i loop
}
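
// The macros below explicitly instantiate
// requantizeOutputProcessingGConvAvx512 for every supported combination of
// A/B symmetry, quantization granularity, bias presence, fused ReLU,
// C_PER_G in {2, 4, 8, 16}, and bias type (float or int32_t), so that the
// template definition can stay in this translation unit.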

#define INSTANTIATE_REQUANTIZE_BIAS_TYPE( \
    A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE) \
  template void requantizeOutputProcessingGConvAvx512< \
      A_SYM, \
      B_SYM, \
      Q_GRAN, \
      BIAS, \
      RELU, \
      2, \
      BIAS_TYPE>( \
      uint8_t * out, \
      const int32_t* inp, \
      const block_type_t& block, \
      int ld_out, \
      int ld_in, \
      const requantizationParams_t<BIAS_TYPE>& r); \
  template void requantizeOutputProcessingGConvAvx512< \
      A_SYM, \
      B_SYM, \
      Q_GRAN, \
      BIAS, \
      RELU, \
      4, \
      BIAS_TYPE>( \
      uint8_t * out, \
      const int32_t* inp, \
      const block_type_t& block, \
      int ld_out, \
      int ld_in, \
      const requantizationParams_t<BIAS_TYPE>& r); \
  template void requantizeOutputProcessingGConvAvx512< \
      A_SYM, \
      B_SYM, \
      Q_GRAN, \
      BIAS, \
      RELU, \
      8, \
      BIAS_TYPE>( \
      uint8_t * out, \
      const int32_t* inp, \
      const block_type_t& block, \
      int ld_out, \
      int ld_in, \
      const requantizationParams_t<BIAS_TYPE>& r); \
  template void requantizeOutputProcessingGConvAvx512< \
      A_SYM, \
      B_SYM, \
      Q_GRAN, \
      BIAS, \
      RELU, \
      16, \
      BIAS_TYPE>( \
      uint8_t * out, \
      const int32_t* inp, \
      const block_type_t& block, \
      int ld_out, \
      int ld_in, \
      const requantizationParams_t<BIAS_TYPE>& r);

#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
  INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, float) \
  INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, int32_t)

#define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \
  INSTANTIATE_REQUANTIZE(true, B_SYM, Q_GRAN, BIAS, RELU) \
  INSTANTIATE_REQUANTIZE(false, B_SYM, Q_GRAN, BIAS, RELU)

#define INSTANTIATE_B_SYM(Q_GRAN, BIAS, RELU) \
  INSTANTIATE_A_SYM(true, Q_GRAN, BIAS, RELU) \
  INSTANTIATE_A_SYM(false, Q_GRAN, BIAS, RELU)

#define INSTANTIATE_Q_GRANS(BIAS, RELU) \
  INSTANTIATE_B_SYM(QuantizationGranularity::TENSOR, BIAS, RELU) \
  INSTANTIATE_B_SYM(QuantizationGranularity::GROUP, BIAS, RELU) \
  INSTANTIATE_B_SYM(QuantizationGranularity::OUT_CHANNEL, BIAS, RELU)

#define INSTANTIATE_BIAS(RELU) \
  INSTANTIATE_Q_GRANS(true, RELU) \
  INSTANTIATE_Q_GRANS(false, RELU)

INSTANTIATE_BIAS(true)
INSTANTIATE_BIAS(false)

#undef INSTANTIATE_A_SYM
#undef INSTANTIATE_B_SYM
#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_BIAS
} // namespace fbgemm