1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/QuantUtilsAvx512.h" |
9 | #if defined(__x86_64__) || defined(__i386__) || \ |
10 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
11 | #include <immintrin.h> |
12 | #endif |
#include <algorithm> // for std::min/std::max
#include <cassert>
#include <cmath> // for nearbyint
#include <limits> // for numeric_limits
17 | |
18 | namespace fbgemm { |
19 | |
20 | using namespace std; |
21 | template < |
22 | bool A_SYMMETRIC, |
23 | bool B_SYMMETRIC, |
24 | QuantizationGranularity Q_GRAN, |
25 | bool HAS_BIAS, |
26 | bool FUSE_RELU, |
27 | int C_PER_G, |
28 | typename BIAS_TYPE> |
29 | void requantizeOutputProcessingGConvAvx512( |
30 | std::uint8_t* out, |
31 | const std::int32_t* inp, |
32 | const block_type_t& block, |
33 | int ld_out, |
34 | int ld_in, |
35 | const requantizationParams_t<BIAS_TYPE>& r) { |
  // Adaptation of the implementation in QNNPACK/src/requantization/fp32-sse2.c
  // using AVX-512 instructions
38 | int quant_param_idx = 0; |
39 | if (Q_GRAN == QuantizationGranularity::GROUP) { |
40 | int ncol_per_group = r.ncols / r.groups; |
41 | int g = block.col_start / ncol_per_group; |
42 | quant_param_idx = g; |
43 | } |
44 | __m512 multiplier_v = _mm512_set1_ps(r.C_multiplier[quant_param_idx]); |
45 | // Broadcasted reciprocal of act_times_w_scale |
46 | __m512 act_times_w_rcp_v; |
47 | |
48 | if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) { |
49 | if (is_same<BIAS_TYPE, float>::value) { |
50 | act_times_w_rcp_v = |
51 | _mm512_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]); |
52 | } |
53 | } |
54 | __m512i min_v = _mm512_set1_epi8(static_cast<uint8_t>(0)); |
55 | __m512i max_v = _mm512_set1_epi8(static_cast<uint8_t>(255)); |
56 | |
  assert(
      (A_SYMMETRIC == (r.A_zero_point == 0)) &&
      "A_SYMMETRIC == true if and only if A_zero_point == 0");
  assert(
      (B_SYMMETRIC ==
       ((Q_GRAN == QuantizationGranularity::TENSOR &&
         r.B_zero_point[0] == 0) ||
        r.row_offsets == nullptr)) &&
      "B_SYMMETRIC == true if and only if B_zero_point == 0 "
      "or r.row_offsets == nullptr");
  assert(
      (HAS_BIAS == (r.bias != nullptr)) &&
      "HAS_BIAS == true if and only if bias != nullptr");
69 | |
70 | __m512i A_zero_point_v = _mm512_set1_epi32(r.A_zero_point); |
71 | __m512i C_zero_point_epi16_v = _mm512_set1_epi16(r.C_zero_point); |
72 | __m512i C_zero_point_epi8_v = _mm512_set1_epi8(r.C_zero_point); |
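  // permute_mask_v_g8, read from low lane to high, is the index vector
  // [0, 8, 1, 9, ..., 7, 15]. Used with _mm512_permutexvar_*, it spreads 8
  // source values into the even lanes; a following moveldup then duplicates
  // each value into the adjacent odd lane, producing the
  // [g0, g0, g1, g1, ..., g7, g7] layout needed when C_PER_G == 2.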
73 | __m512i permute_mask_v_g8 = _mm512_set_epi32( |
74 | 0x0f, |
75 | 0x07, |
76 | 0x0e, |
77 | 0x06, |
78 | 0x0d, |
79 | 0x05, |
80 | 0x0c, |
81 | 0x04, |
82 | 0x0b, |
83 | 0x03, |
84 | 0x0a, |
85 | 0x02, |
86 | 0x09, |
87 | 0x01, |
88 | 0x08, |
89 | 0x00); |
90 | |
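  // permute_mask_v_g4, read from low lane to high, is the index vector
  // [0, 4, 8, 12, 1, 5, 9, 13, ...]. Applied to a 4x-broadcast source
  // [g0, g1, g2, g3, g0, ...] it yields [g0 x4, g1 x4, g2 x4, g3 x4] for
  // C_PER_G == 4; it is reused after the byte packs below to gather the
  // valid dwords into the low 128 bits.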
91 | __m512i permute_mask_v_g4 = _mm512_set_epi32( |
92 | 0x0f, |
93 | 0x0b, |
94 | 0x07, |
95 | 0x03, |
96 | 0x0e, |
97 | 0x0a, |
98 | 0x06, |
99 | 0x02, |
100 | 0x0d, |
101 | 0x09, |
102 | 0x05, |
103 | 0x01, |
104 | 0x0c, |
105 | 0x08, |
106 | 0x04, |
107 | 0x00); |
  // 16 int32 lanes * 32 bits = 512 bits, i.e. one full AVX-512 register
109 | constexpr int VLEN = 16; |
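  // 0x00ff selects the low 8 of the 16 int32 lanes; used for masked loads
  // when C_PER_G == 8, which produces only 8 outputs per iteration.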
110 | const __mmask16 mask = 0x00ff; |
111 | |
112 | for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { |
113 | int j = block.col_start; |
    // The loop bound is rounded up to a multiple of VLEN so that the
    // C_PER_G == 8 case still enters the loop: for all four supported cases
    // (C_PER_G == 2, 4, 8, 16) this inner loop runs exactly one iteration.
    // With C_PER_G == 8 there are only 8 outputs per row while the other
    // cases produce 16, so every per-column load under C_PER_G == 8 is a
    // masked load.
120 | for (; j < block.col_start + ((block.col_size + VLEN - 1) / VLEN * VLEN); |
121 | j += VLEN) { |
122 | __m512i x_v; |
123 | if (C_PER_G != 8) { |
124 | x_v = _mm512_loadu_si512(reinterpret_cast<const __m512i*>( |
125 | inp + (i - block.row_start) * ld_in + (j - block.col_start))); |
126 | } else { |
        // Only C_PER_G == 2, 4, 8, 16 are supported, so this j loop executes
        // exactly one iteration; the masked addressing below would be wrong
        // if it ran more than once.
130 | x_v = _mm512_maskz_loadu_epi32( |
131 | mask, inp + (i - block.row_start) * ld_in + (j - block.col_start)); |
132 | } |
133 | |
134 | if (!A_SYMMETRIC) { |
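        // Scalar view of this block: acc[j] -= r.A_zero_point *
        // r.col_offsets[j], i.e. remove the contribution of A's zero point
        // from the int32 accumulator.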
135 | __m512i col_off_raw_v; |
136 | if (C_PER_G != 8) { |
137 | col_off_raw_v = _mm512_loadu_si512( |
138 | reinterpret_cast<const __m512i*>(r.col_offsets + j)); |
139 | } else { |
140 | col_off_raw_v = _mm512_maskz_loadu_epi32(mask, r.col_offsets + j); |
141 | } |
142 | |
143 | __m512i col_off_v = _mm512_mullo_epi32(A_zero_point_v, col_off_raw_v); |
144 | x_v = _mm512_sub_epi32(x_v, col_off_v); |
145 | } |
146 | |
147 | if (!B_SYMMETRIC) { |
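        // Scalar view of this block: acc[j] -= r.row_offsets[i] *
        // B_zero_point[g], i.e. remove the contribution of B's zero point.
        // The per-group row offset must be replicated across that group's
        // lanes, which is what the C_PER_G branches below arrange.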
148 | __m512i row_offset_v; |
149 | |
150 | if (C_PER_G == 2) { |
          // When C_PER_G == 2, we handle 8 groups at a time to fully utilize
          // the 64B AVX-512 vector register (C_PER_G * 8 groups *
          // sizeof(int32_t) == 64B).
          // Load row_offsets for 8 groups and duplicate each into 2 adjacent
          // lanes.
155 | row_offset_v = |
156 | _mm512_castps_si512(_mm512_moveldup_ps(_mm512_permutexvar_ps( |
157 | permute_mask_v_g8, |
158 | _mm512_castps256_ps512( |
159 | _mm256_loadu_ps(reinterpret_cast<const float*>( |
160 | r.row_offsets + (i - block.row_start) * 8)))))); |
161 | |
162 | } |
163 | // When C_PER_G == 4, we need to handle 4 groups at a time to fully |
164 | // utilize 32B AVX2 vector register (C_PER_G * 4 * sizeof(int32_t) == |
165 | // 32B) |
166 | // When C_PER_G == 8, we just need 1 group at a time on the other hand. |
167 | |
168 | // Groups 0,1,2,3 when C_PER_G == 4 |
169 | // Group 0 when C_PER_G == 8 |
170 | else if (C_PER_G == 4) { |
171 | // Load row_offsets for 4 groups and broadcast by 4 times each because |
172 | // we have 4 channels per group. |
173 | // groups 0,1,2,3 |
174 | row_offset_v = _mm512_permutexvar_epi32( |
175 | permute_mask_v_g4, |
176 | _mm512_broadcast_i32x4( |
177 | _mm_loadu_si128(reinterpret_cast<const __m128i*>( |
178 | r.row_offsets + (i - block.row_start) * 4)))); |
179 | } else if (C_PER_G == 8) { |
180 | row_offset_v = |
181 | _mm512_set1_epi32(r.row_offsets[(i - block.row_start)]); |
182 | } else { |
183 | assert(C_PER_G == 16); |
184 | row_offset_v = |
185 | _mm512_set1_epi32(r.row_offsets[(i - block.row_start)]); |
186 | } |
187 | |
188 | __m512i B_zero_point_v = _mm512_set1_epi32(r.B_zero_point[0]); |
189 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
190 | if (C_PER_G != 8) { |
191 | B_zero_point_v = _mm512_loadu_si512( |
192 | reinterpret_cast<const __m512i*>(r.B_zero_point + j)); |
193 | } else { |
194 | B_zero_point_v = _mm512_maskz_loadu_epi32(mask, r.B_zero_point + j); |
195 | } |
196 | } else if (Q_GRAN == QuantizationGranularity::GROUP) { |
197 | if (C_PER_G == 2) { |
198 | B_zero_point_v = |
199 | _mm512_castps_si512(_mm512_moveldup_ps(_mm512_permutexvar_ps( |
200 | permute_mask_v_g8, |
201 | _mm512_castps256_ps512( |
202 | _mm256_loadu_ps(reinterpret_cast<const float*>( |
203 | r.B_zero_point + quant_param_idx)))))); |
204 | } else if (C_PER_G == 4) { |
205 | B_zero_point_v = _mm512_permutexvar_epi32( |
206 | permute_mask_v_g4, |
207 | _mm512_broadcast_i32x4( |
208 | _mm_loadu_si128(reinterpret_cast<const __m128i*>( |
209 | r.B_zero_point + quant_param_idx)))); |
210 | } else if (C_PER_G == 8) { |
211 | B_zero_point_v = _mm512_set1_epi32(r.B_zero_point[quant_param_idx]); |
212 | } else { |
213 | B_zero_point_v = _mm512_set1_epi32(r.B_zero_point[quant_param_idx]); |
214 | } |
215 | } |
216 | row_offset_v = _mm512_mullo_epi32(row_offset_v, B_zero_point_v); |
217 | x_v = _mm512_sub_epi32(x_v, row_offset_v); |
218 | } |
219 | __m512 xf_v; |
220 | if (HAS_BIAS) { |
221 | if (is_same<BIAS_TYPE, float>::value) { |
222 | __m512 x_bias_v; |
223 | if (C_PER_G != 8) { |
224 | x_bias_v = |
225 | _mm512_loadu_ps(reinterpret_cast<const float*>(r.bias + j)); |
226 | } else { |
227 | x_bias_v = _mm512_maskz_loadu_ps( |
228 | mask, reinterpret_cast<const float*>(r.bias + j)); |
229 | } |
230 | |
231 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
232 | __m512 act_times_w_scale_v; |
233 | if (C_PER_G != 8) { |
234 | act_times_w_scale_v = _mm512_loadu_ps(r.act_times_w_scale + j); |
235 | } else { |
236 | act_times_w_scale_v = |
237 | _mm512_maskz_loadu_ps(mask, r.act_times_w_scale + j); |
238 | } |
239 | x_bias_v = _mm512_div_ps(x_bias_v, act_times_w_scale_v); |
          } else if (Q_GRAN == QuantizationGranularity::GROUP) {
            __m512 divisor_v;
            if (C_PER_G == 2) {
              divisor_v = _mm512_moveldup_ps(_mm512_permutexvar_ps(
                  permute_mask_v_g8,
                  _mm512_castps256_ps512(
                      _mm256_loadu_ps(r.act_times_w_scale + quant_param_idx))));
            } else if (C_PER_G == 4) {
              divisor_v = _mm512_permutexvar_ps(
                  permute_mask_v_g4,
                  _mm512_broadcast_f32x4(
                      _mm_loadu_ps(r.act_times_w_scale + quant_param_idx)));
            } else if (C_PER_G == 8) {
              divisor_v = _mm512_set1_ps(r.act_times_w_scale[quant_param_idx]);
            } else {
              assert(C_PER_G == 16);
              divisor_v = _mm512_set1_ps(r.act_times_w_scale[quant_param_idx]);
            }
            x_bias_v = _mm512_div_ps(x_bias_v, divisor_v);
260 | } else { |
261 | x_bias_v = _mm512_mul_ps(x_bias_v, act_times_w_rcp_v); |
262 | } |
263 | xf_v = _mm512_add_ps(_mm512_cvtepi32_ps(x_v), x_bias_v); |
264 | } else { |
265 | x_v = _mm512_add_epi32( |
266 | x_v, |
267 | _mm512_loadu_si512(reinterpret_cast<const __m512i*>(r.bias + j))); |
268 | xf_v = _mm512_cvtepi32_ps(x_v); |
269 | } |
270 | } else { |
271 | xf_v = _mm512_cvtepi32_ps(x_v); |
272 | } |
273 | |
274 | /* |
275 | * Convert int32_t input to FP32 and multiply by FP32 scale. |
276 | * Both operations involve statistically unbiased roundings (with |
277 | * default MXCSR rounding mode): |
278 | * - Large int32_t values can't be exactly represented as FP32. |
279 | * CVTDQ2PS instruction on x86 would round it according to nearest |
280 | * FP32 value with ties to even (assuming default MXCSR rounding |
281 | * mode). |
282 | * - Product of two FP32 values is generally not exactly |
283 | * representation as an FP32 value, and will be rounded to nearest |
284 | * FP32 value with ties to even with default MXCSR rounding mode. |
285 | */ |
286 | __m512 x_scaled_v; |
287 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
288 | __m512 C_multiplier_v; |
289 | if (C_PER_G != 8) { |
290 | C_multiplier_v = _mm512_loadu_ps(r.C_multiplier + j); |
291 | } else { |
292 | C_multiplier_v = _mm512_maskz_loadu_ps(mask, r.C_multiplier + j); |
293 | } |
294 | x_scaled_v = _mm512_mul_ps(xf_v, C_multiplier_v); |
295 | } else if (Q_GRAN == QuantizationGranularity::GROUP) { |
296 | if (C_PER_G == 2) { |
297 | multiplier_v = _mm512_moveldup_ps(_mm512_permutexvar_ps( |
298 | permute_mask_v_g8, |
299 | _mm512_castps256_ps512( |
300 | _mm256_loadu_ps(r.C_multiplier + quant_param_idx)))); |
301 | } else if (C_PER_G == 4) { |
302 | multiplier_v = _mm512_permutexvar_ps( |
303 | permute_mask_v_g4, |
304 | _mm512_broadcast_f32x4( |
305 | _mm_loadu_ps(r.C_multiplier + quant_param_idx))); |
306 | } else if (C_PER_G == 8) { |
307 | multiplier_v = _mm512_set1_ps(r.C_multiplier[quant_param_idx]); |
308 | } else { |
309 | multiplier_v = _mm512_set1_ps(r.C_multiplier[quant_param_idx]); |
310 | } |
311 | x_scaled_v = _mm512_mul_ps(xf_v, multiplier_v); |
312 | } else { |
313 | x_scaled_v = _mm512_mul_ps(xf_v, multiplier_v); |
314 | } |
315 | |
316 | /* |
317 | * Convert scaled FP32 result to int32_t using CVTPS2DQ instruction. |
318 | * CVTPS2DQ instruction rounds result according to nearest FP32 value |
319 | * with ties to even (assuming default MXCSR rounding mode). However, |
320 | * when conversion overflows, it produces INT32_MIN as a result. For |
321 | * large positive inputs the result of conversion can become negative, |
322 | * which affects the final requantization result. Note that on x86 |
323 | * SSE2 we have e.g. int32_t(float(INT32_MAX)) == INT32_MIN! This |
324 | * happens because float(INT32_MAX) rounds to 2**31, which overflows |
325 | * int32_t when it is converted back to integer. |
326 | * |
327 | * Thankfully, we can prove that overflow never happens in this |
328 | * requantization scheme. The largest positive input is INT32_MAX |
329 | * (2**31 - 1), which turns into 2**31 when converted to float. The |
330 | * largest scale value is 0x1.FFFFFEp-1. When multiplied together, the |
331 | * result is 2147483520 (compare to INT32_MAX = 2147483647), which |
332 | * fits into int32_t without overflow. |
333 | */ |
334 | __m512i x_rounded_v = _mm512_cvtps_epi32(x_scaled_v); |
335 | |
336 | /* |
337 | * Standard final sequence on x86 AVX512: |
338 | * - Pack to int16_t and saturate |
339 | * - Add zero point |
340 | * - Pack to uint8_t and saturate |
341 | * - Clamp between qmin and qmax |
342 | */ |
343 | __m512i x_packed_v = _mm512_adds_epi16( |
344 | _mm512_packs_epi32(x_rounded_v, _mm512_setzero_si512()), |
345 | C_zero_point_epi16_v); |
346 | x_packed_v = _mm512_packus_epi16(x_packed_v, _mm512_setzero_si512()); |
347 | __m512i x_clamped_v = _mm512_max_epu8( |
348 | FUSE_RELU ? C_zero_point_epi8_v : min_v, |
349 | _mm512_min_epu8(x_packed_v, max_v)); |
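      // Note: 512-bit packs/packus operate independently on each 128-bit
      // lane, so the valid bytes end up interleaved with the zero padding;
      // the permute below gathers them together.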
350 | |
351 | /* |
352 | * x_clamped_v has results in the following layout so we need to |
353 | * permute: x0-3 garbage0-11 x4-7 garbage12-23 x8-11 garbage24-35 x12-15 |
354 | * garbage36-47 |
355 | */ |
356 | x_clamped_v = _mm512_permutexvar_epi32(permute_mask_v_g4, x_clamped_v); |
357 | |
358 | /* |
359 | * 1x CVTDQ2PS |
360 | * 1x MULPS |
361 | * 1x CVTPS2DQ |
362 | * 1x PACKSSDW |
363 | * 1x PACKUSWB |
364 | * 1x PADDW |
365 | * 1x PMAXUB |
366 | * 1x PMINUB |
367 | * 1x PERMD |
368 | * --------------------- |
369 | * 9 instructions total |
370 | */ |
371 | if (C_PER_G != 8) { |
372 | _mm_storeu_si128( |
373 | reinterpret_cast<__m128i*>(out + i * ld_out + j), |
374 | _mm512_castsi512_si128(x_clamped_v)); |
375 | } else { |
376 | _mm_storel_epi64( |
377 | reinterpret_cast<__m128i*>(out + i * ld_out + j), |
378 | _mm512_castsi512_si128(x_clamped_v)); |
379 | } |
380 | } // j loop vectorized |
381 | |
382 | #ifndef NDEBUG |
383 | int remainder = block.col_start + block.col_size - j; |
384 | assert(remainder == 0 || C_PER_G == 8); |
385 | #endif |
386 | } // i loop |
387 | } |
388 | |
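// The macros below explicitly instantiate the template for the full cross
// product of {A symmetric} x {B symmetric} x {quantization granularity} x
// {bias} x {relu} x {bias type}, with C_PER_G in {2, 4, 8, 16} expanded
// inside INSTANTIATE_REQUANTIZE_BIAS_TYPE.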
389 | #define INSTANTIATE_REQUANTIZE_BIAS_TYPE( \ |
390 | A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE) \ |
391 | template void requantizeOutputProcessingGConvAvx512< \ |
392 | A_SYM, \ |
393 | B_SYM, \ |
394 | Q_GRAN, \ |
395 | BIAS, \ |
396 | RELU, \ |
397 | 2, \ |
398 | BIAS_TYPE>( \ |
399 | uint8_t * out, \ |
400 | const int32_t* inp, \ |
401 | const block_type_t& block, \ |
402 | int ld_out, \ |
403 | int ld_in, \ |
404 | const requantizationParams_t<BIAS_TYPE>& r); \ |
405 | template void requantizeOutputProcessingGConvAvx512< \ |
406 | A_SYM, \ |
407 | B_SYM, \ |
408 | Q_GRAN, \ |
409 | BIAS, \ |
410 | RELU, \ |
411 | 4, \ |
412 | BIAS_TYPE>( \ |
413 | uint8_t * out, \ |
414 | const int32_t* inp, \ |
415 | const block_type_t& block, \ |
416 | int ld_out, \ |
417 | int ld_in, \ |
418 | const requantizationParams_t<BIAS_TYPE>& r); \ |
419 | template void requantizeOutputProcessingGConvAvx512< \ |
420 | A_SYM, \ |
421 | B_SYM, \ |
422 | Q_GRAN, \ |
423 | BIAS, \ |
424 | RELU, \ |
425 | 8, \ |
426 | BIAS_TYPE>( \ |
427 | uint8_t * out, \ |
428 | const int32_t* inp, \ |
429 | const block_type_t& block, \ |
430 | int ld_out, \ |
431 | int ld_in, \ |
432 | const requantizationParams_t<BIAS_TYPE>& r); \ |
433 | template void requantizeOutputProcessingGConvAvx512< \ |
434 | A_SYM, \ |
435 | B_SYM, \ |
436 | Q_GRAN, \ |
437 | BIAS, \ |
438 | RELU, \ |
439 | 16, \ |
440 | BIAS_TYPE>( \ |
441 | uint8_t * out, \ |
442 | const int32_t* inp, \ |
443 | const block_type_t& block, \ |
444 | int ld_out, \ |
445 | int ld_in, \ |
446 | const requantizationParams_t<BIAS_TYPE>& r); |
447 | |
448 | #define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \ |
449 | INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, float) \ |
450 | INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, int32_t) |
451 | |
452 | #define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \ |
453 | INSTANTIATE_REQUANTIZE(true, B_SYM, Q_GRAN, BIAS, RELU) \ |
454 | INSTANTIATE_REQUANTIZE(false, B_SYM, Q_GRAN, BIAS, RELU) |
455 | |
456 | #define INSTANTIATE_B_SYM(Q_GRAN, BIAS, RELU) \ |
457 | INSTANTIATE_A_SYM(true, Q_GRAN, BIAS, RELU) \ |
458 | INSTANTIATE_A_SYM(false, Q_GRAN, BIAS, RELU) |
459 | |
460 | #define INSTANTIATE_Q_GRANS(BIAS, RELU) \ |
461 | INSTANTIATE_B_SYM(QuantizationGranularity::TENSOR, BIAS, RELU) \ |
462 | INSTANTIATE_B_SYM(QuantizationGranularity::GROUP, BIAS, RELU) \ |
463 | INSTANTIATE_B_SYM(QuantizationGranularity::OUT_CHANNEL, BIAS, RELU) |
464 | |
465 | #define INSTANTIATE_BIAS(RELU) \ |
466 | INSTANTIATE_Q_GRANS(true, RELU) \ |
467 | INSTANTIATE_Q_GRANS(false, RELU) |
468 | |
469 | INSTANTIATE_BIAS(true) |
470 | INSTANTIATE_BIAS(false) |
471 | |
#undef INSTANTIATE_REQUANTIZE_BIAS_TYPE
#undef INSTANTIATE_REQUANTIZE
#undef INSTANTIATE_A_SYM
#undef INSTANTIATE_B_SYM
#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_BIAS
476 | } // namespace fbgemm |
477 | |