1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/QuantUtilsAvx2.h" |
9 | #if defined(__x86_64__) || defined(__i386__) || \ |
10 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
11 | #include <immintrin.h> |
12 | #endif |
#include <algorithm> // for std::min/std::max
#include <cassert> // for assert
#include <cfloat> // for FLT_MAX
#include <cmath> // for nearbyint
#include <cstring> // for memcpy
#include <limits> // for numeric_limits
19 | #include "./MaskAvx2.h" |
20 | #include "fbgemm/Types.h" |
21 | |
22 | namespace fbgemm { |
23 | |
24 | using namespace std; |
25 | //////////////////////////////////////////////////////////////////////////////// |
26 | // Utility functions |
27 | |
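// Quantizes a contiguous array of floats to 8-bit integers (T = int8_t or
// uint8_t) using qparams.scale and qparams.zero_point. The LEGACY path adds
// the zero point in floating point before rounding; the non-legacy path rounds
// src / scale first and then adds the zero point as an integer.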
28 | template <typename T, bool LEGACY> |
29 | void QuantizeAvx2( |
30 | const float* src, |
31 | T* dst, |
32 | int64_t len, |
33 | const TensorQuantizationParams& qparams) { |
34 | #if defined(__AVX2__) && (defined(__FMA__) || defined(_MSC_VER)) |
35 | constexpr int VLEN = 8; |
36 | constexpr int32_t min_val = std::numeric_limits<T>::min(); |
37 | constexpr int32_t max_val = std::numeric_limits<T>::max(); |
38 | // This is the largest int32 value less than int32_max |
39 | // that is exactly representable in float |
40 | constexpr int32_t int32_float_max_val = |
41 | std::numeric_limits<int32_t>::max() - 127; |
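// 2147483520 == 2^31 - 128. float(INT32_MAX) would round up to 2^31, which
// _mm256_cvtps_epi32 would then convert to INT32_MIN.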
42 | int i = 0; |
43 | float inverse_scale = 1.f / qparams.scale; |
44 | __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale); |
45 | // clang-format off |
46 | __m256i shuffle_mask_v = _mm256_set_epi8( |
47 | 0xff, 0xff, 0xff, 0xff, |
48 | 0xff, 0xff, 0xff, 0xff, |
49 | 0xff, 0xff, 0xff, 0xff, |
50 | 0x0c, 0x08, 0x04, 0x00, |
51 | 0xff, 0xff, 0xff, 0xff, |
52 | 0xff, 0xff, 0xff, 0xff, |
53 | 0xff, 0xff, 0xff, 0xff, |
54 | 0x0c, 0x08, 0x04, 0x00); |
55 | // clang-format on |
56 | __m256i permute_mask_v = |
57 | _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); |
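// shuffle_mask_v keeps byte 0 of each 32-bit element within each 128-bit half
// (bytes 0, 4, 8, 12), and permute_mask_v then gathers the two 4-byte groups
// from the low and high halves into the low 8 bytes, so the 8 quantized values
// end up contiguous for the 64-bit store below.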
58 | const auto zero_point_v_legacy = _mm256_set1_ps(qparams.zero_point); |
59 | const auto zero_point_v_non_legacy = _mm256_set1_epi32(qparams.zero_point); |
60 | for (; i < len / VLEN * VLEN; i += VLEN) { |
61 | __m256 src_v = _mm256_loadu_ps(src + i); |
62 | __m256 transformed_v; |
63 | if (LEGACY) { // static if |
64 | transformed_v = |
65 | _mm256_fmadd_ps(src_v, inverse_scale_v, zero_point_v_legacy); |
66 | } else { |
67 | transformed_v = _mm256_mul_ps(src_v, inverse_scale_v); |
68 | } |
// If a floating point value is greater than int32_max,
// _mm256_cvtps_epi32 converts it to a negative value (INT32_MIN). Clip at
// int32_float_max_val to avoid this.
72 | transformed_v = |
73 | _mm256_min_ps(transformed_v, _mm256_set1_ps(int32_float_max_val)); |
74 | |
75 | __m256i rounded_v = _mm256_cvtps_epi32(transformed_v); |
76 | if (!LEGACY) { |
77 | rounded_v = _mm256_add_epi32(rounded_v, zero_point_v_non_legacy); |
78 | } |
79 | __m256i clipped_v = _mm256_min_epi32( |
80 | _mm256_max_epi32(rounded_v, _mm256_set1_epi32(min_val)), |
81 | _mm256_set1_epi32(max_val)); |
82 | |
83 | // An instruction sequence to save 8 32-bit integers as 8 8-bit integers |
84 | clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v); |
85 | clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v); |
86 | _mm_storel_epi64( |
87 | reinterpret_cast<__m128i*>(dst + i), _mm256_castsi256_si128(clipped_v)); |
88 | } |
89 | |
90 | // Handle remainder using mask instructions so that |
91 | // the main loop and remainder loop have the same behavior |
92 | int64_t rem = len - i; |
93 | if (rem > 0) { |
94 | __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>( |
95 | internal::avx2_ps_or_epi32_masks[rem])); |
96 | // __m128i store_mask_v = _mm_load_si128( |
97 | // reinterpret_cast<const __m128i*>(internal::sse_epi8_masks[rem])); |
98 | __m256 src_v = _mm256_maskload_ps(src + i, mask_v); |
99 | __m256 transformed_v; |
100 | if (LEGACY) { |
101 | transformed_v = |
102 | _mm256_fmadd_ps(src_v, inverse_scale_v, zero_point_v_legacy); |
103 | } else { |
104 | transformed_v = _mm256_mul_ps(src_v, inverse_scale_v); |
105 | } |
106 | transformed_v = |
107 | _mm256_min_ps(transformed_v, _mm256_set1_ps(int32_float_max_val)); |
108 | |
109 | __m256i rounded_v = _mm256_cvtps_epi32(transformed_v); |
110 | if (!LEGACY) { |
111 | rounded_v = _mm256_add_epi32(rounded_v, zero_point_v_non_legacy); |
112 | } |
113 | __m256i clipped_v = _mm256_min_epi32( |
114 | _mm256_max_epi32(rounded_v, _mm256_set1_epi32(min_val)), |
115 | _mm256_set1_epi32(max_val)); |
116 | |
117 | // An instruction sequence to save "rem" number of 32-bit integers |
118 | // as "rem" number of 8-bit integers |
119 | clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v); |
120 | clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v); |
// Do not use _mm_maskmoveu_si128 instead of memcpy here:
// ASAN reports false positives for _mm_maskmoveu_si128, and the instruction
// sometimes causes segfaults (root cause unknown).
124 | memcpy(dst + i, reinterpret_cast<void*>(&clipped_v), rem * sizeof(T)); |
125 | // _mm_maskmoveu_si128( |
126 | // _mm256_castsi256_si128(clipped_v), |
127 | // store_mask_v, |
128 | // reinterpret_cast<char*>(dst + i)); |
129 | } |
130 | #endif |
131 | } |
132 | |
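// Marsaglia xorshift128 pseudo-random number generator. Used below by
// FusedQuantizeDequantizeAvx2 to decide, per vector, whether to pass the input
// through unquantized when noise_ratio > 0.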
133 | uint32_t Xor128(void) { |
134 | /* library-local */ static uint32_t x = 123456789; |
135 | /* library-local */ static uint32_t y = 362436069; |
136 | /* library-local */ static uint32_t z = 521288629; |
137 | /* library-local */ static uint32_t w = 88675123; |
138 | uint32_t t; |
139 | t = x ^ (x << 11); |
140 | x = y; |
141 | y = z; |
142 | z = w; |
143 | return w = w ^ (w >> 19) ^ (t ^ (t >> 8)); |
144 | } |
145 | |
146 | // Instantiate QuantizeAvx2 for known datatypes |
147 | #define SPECIALIZE_QUANTIZEAVX2(T, LEGACY) \ |
148 | template void QuantizeAvx2<T, LEGACY>( \ |
149 | const float* src, \ |
150 | T* dst, \ |
151 | int64_t len, \ |
152 | const TensorQuantizationParams& qparams); |
153 | SPECIALIZE_QUANTIZEAVX2(uint8_t, true) |
154 | SPECIALIZE_QUANTIZEAVX2(int8_t, true) |
155 | SPECIALIZE_QUANTIZEAVX2(uint8_t, false) |
156 | SPECIALIZE_QUANTIZEAVX2(int8_t, false) |
157 | #undef SPECIALIZE_QUANTIZEAVX2 |
158 | |
159 | template <typename T> |
void NO_SANITIZE("address") FusedQuantizeDequantizeAvx2(
161 | const float* src, |
162 | float* dst, |
163 | int len, |
164 | const TensorQuantizationParams& qparams, |
165 | float noise_ratio) { |
166 | float inverse_scale = 1.f / qparams.scale; |
167 | constexpr int32_t min_val = std::numeric_limits<T>::min(); |
168 | constexpr int32_t max_val = std::numeric_limits<T>::max(); |
169 | (void)inverse_scale; // Suppress unused variable warning |
170 | (void)min_val; // Suppress unused variable warning |
171 | (void)max_val; // Suppress unused variable warning |
172 | #if defined(__AVX2__) && (defined(__FMA__) || defined(_MSC_VER)) |
173 | |
174 | constexpr int VLEN = 8; |
175 | // This is the largest int32 value less than int32_max |
176 | // that is exactly representable in float |
177 | constexpr int32_t int32_float_max_val = |
178 | std::numeric_limits<int32_t>::max() - 127; |
179 | int i = 0; |
180 | uint32_t rand; |
181 | __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale); |
182 | __m256 scale_v = _mm256_set1_ps(qparams.scale); |
183 | __m256 zp_v = _mm256_set1_ps(qparams.zero_point); |
184 | |
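// Each output is the round-trip
//   dst = (clamp(round(src / scale) + zero_point, min_val, max_val)
//          - zero_point) * scale.
// When noise_ratio > 0, roughly a noise_ratio fraction of the 8-wide vectors
// is copied through unchanged instead of being quantized and dequantized.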
185 | for (; i < len / VLEN * VLEN; i += VLEN) { |
186 | // prefetch src and dst |
187 | _mm_prefetch(reinterpret_cast<const char*>(src + i + VLEN), _MM_HINT_T0); |
188 | _mm_prefetch(reinterpret_cast<const char*>(dst + i + VLEN), _MM_HINT_T0); |
189 | |
190 | __m256 src_v = _mm256_loadu_ps(src + i); |
191 | __m256 transformed_v; |
192 | if (noise_ratio > 0) { |
193 | rand = Xor128() % 10; |
194 | if (rand < noise_ratio * 10) { |
195 | _mm256_storeu_ps(dst + i, src_v); |
196 | continue; |
197 | } |
198 | } |
199 | |
200 | transformed_v = _mm256_mul_ps(src_v, inverse_scale_v); |
// If a floating point value is greater than int32_max,
// _mm256_cvtps_epi32 converts it to a negative value (INT32_MIN). Clip at
// int32_float_max_val to avoid this.
204 | transformed_v = |
205 | _mm256_min_ps(transformed_v, _mm256_set1_ps(int32_float_max_val)); |
206 | |
207 | __m256i rounded_v = _mm256_cvtps_epi32(transformed_v); |
208 | rounded_v = |
209 | _mm256_add_epi32(rounded_v, _mm256_set1_epi32(qparams.zero_point)); |
210 | __m256i clipped_v = _mm256_min_epi32( |
211 | _mm256_max_epi32(rounded_v, _mm256_set1_epi32(min_val)), |
212 | _mm256_set1_epi32(max_val)); |
213 | |
214 | // convert int32 to float32 |
215 | __m256 fp32_clipped_v = _mm256_cvtepi32_ps(clipped_v); |
// subtract the zero point, then multiply by the scale
217 | __m256 fp32_dq_sub = _mm256_sub_ps(fp32_clipped_v, zp_v); |
218 | __m256 fp32_dq = _mm256_mul_ps(fp32_dq_sub, scale_v); |
219 | |
// store the fused quantize-dequantize fp32 values into dst
221 | _mm256_storeu_ps(dst + i, fp32_dq); |
222 | } |
223 | |
224 | // Handle remainder using mask instructions so that |
225 | // the main loop and remainder loop have the same behavior |
226 | int rem = len - i; |
227 | if (rem > 0) { |
228 | __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>( |
229 | internal::avx2_ps_or_epi32_masks[rem])); |
230 | |
231 | __m256 src_v = _mm256_maskload_ps(src + i, mask_v); |
232 | __m256 transformed_v; |
233 | |
234 | if (noise_ratio > 0) { |
235 | rand = Xor128() % 10; |
236 | if (rand < noise_ratio * 10) { |
// Use a masked store so we do not write past the valid "rem" elements.
_mm256_maskstore_ps(dst + i, mask_v, src_v);
238 | return; |
239 | } |
240 | } |
241 | |
242 | transformed_v = _mm256_mul_ps(src_v, inverse_scale_v); |
// If a floating point value is greater than int32_max,
// _mm256_cvtps_epi32 converts it to a negative value (INT32_MIN). Clip at
// int32_float_max_val to avoid this.
246 | transformed_v = |
247 | _mm256_min_ps(transformed_v, _mm256_set1_ps(int32_float_max_val)); |
248 | |
249 | __m256i rounded_v = _mm256_cvtps_epi32(transformed_v); |
250 | rounded_v = |
251 | _mm256_add_epi32(rounded_v, _mm256_set1_epi32(qparams.zero_point)); |
252 | |
253 | __m256i clipped_v = _mm256_min_epi32( |
254 | _mm256_max_epi32(rounded_v, _mm256_set1_epi32(min_val)), |
255 | _mm256_set1_epi32(max_val)); |
256 | |
257 | // convert int32 to float32 |
258 | __m256 fp32_clipped_v = _mm256_cvtepi32_ps(clipped_v); |
// subtract the zero point, then multiply by the scale
260 | __m256 fp32_dq_sub = |
261 | _mm256_sub_ps(fp32_clipped_v, _mm256_set1_ps(qparams.zero_point)); |
262 | __m256 fp32_dq = _mm256_mul_ps(fp32_dq_sub, _mm256_set1_ps(qparams.scale)); |
263 | |
264 | // store fp32 values with mask |
265 | _mm256_maskstore_ps(dst + i, mask_v, fp32_dq); |
266 | } |
267 | #endif |
268 | } |
269 | |
// Instantiate FusedQuantizeDequantizeAvx2 for known datatypes
271 | #define SPECIALIZE_FUSEDDQAVX2(T) \ |
272 | template void FusedQuantizeDequantizeAvx2<T>( \ |
273 | const float* src, \ |
274 | float* dst, \ |
275 | int len, \ |
276 | const TensorQuantizationParams& qparams, \ |
277 | float noise_ratio); |
278 | SPECIALIZE_FUSEDDQAVX2(uint8_t) |
279 | SPECIALIZE_FUSEDDQAVX2(int8_t) |
280 | |
281 | #undef SPECIALIZE_FUSEDDQAVX2 |
282 | |
283 | void FindMinMax(const float* a, float* min, float* max, int64_t len) { |
284 | if (len <= 0) { |
285 | *min = 0.0f; |
286 | *max = 0.0f; |
287 | return; |
288 | } |
289 | |
290 | float temp_min = *a, temp_max = *a; |
291 | int64_t i = 0; |
292 | |
293 | #ifdef __AVX__ |
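// Vectorized pass: keep 8 running minima/maxima, reduce them after the loop,
// and let the scalar loop below handle the remaining len % 8 elements.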
294 | __m256 min_v = _mm256_set1_ps(*a), max_v = _mm256_set1_ps(*a); |
295 | constexpr int VLEN = 8; |
296 | if (len >= VLEN) { |
297 | for (; i < len / VLEN * VLEN; i += VLEN) { |
298 | min_v = _mm256_min_ps(min_v, _mm256_loadu_ps(a + i)); |
299 | max_v = _mm256_max_ps(max_v, _mm256_loadu_ps(a + i)); |
300 | } |
301 | |
302 | float min_buf[VLEN], max_buf[VLEN]; |
303 | _mm256_storeu_ps(min_buf, min_v); |
304 | _mm256_storeu_ps(max_buf, max_v); |
305 | for (int j = 0; j < VLEN; ++j) { |
306 | temp_min = std::min(temp_min, min_buf[j]); |
307 | temp_max = std::max(temp_max, max_buf[j]); |
308 | } |
309 | } |
310 | #endif |
311 | |
312 | for (; i < len; i++) { |
313 | temp_min = std::min(temp_min, a[i]); |
314 | temp_max = std::max(temp_max, a[i]); |
315 | } |
316 | *min = temp_min; |
317 | *max = temp_max; |
318 | } |
319 | |
320 | //////////////////////////////////////////////////////////////////////////////// |
321 | // Requantization (with floats) |
322 | |
323 | #ifdef __AVX2__ |
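// Requantizes a contiguous run of len int32 values as a single 1 x len row by
// delegating to the generic tensor-granularity kernel below with symmetric A
// and B (zero points of 0), no bias, and no fused ReLU.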
324 | void RequantizeAvx2( |
325 | const int32_t* src, |
326 | uint8_t* dst, |
327 | int len, |
328 | const RequantizationParams& params) { |
329 | int32_t Bq_zero_point[] = {0}; |
330 | |
331 | requantizationParams_t<> reqObj = { |
332 | 0, // Aq_zero_point |
333 | Bq_zero_point, |
334 | params.target_qparams.zero_point, |
&params.real_multiplier,
336 | nullptr, // row_offsets |
337 | nullptr, // col_offsets |
338 | nullptr, // bias |
339 | static_cast<std::uint32_t>(len), // ncols |
340 | 1, // groups |
341 | nullptr}; // act_times_w_scale |
342 | requantizeOutputProcessingAvx2< |
343 | true, // A_SYMMETRIC |
344 | true, // B_SYMMETRIC |
345 | QuantizationGranularity::TENSOR, |
346 | false, // HAS_BIAS |
347 | false // FUSE_RELU |
348 | >(dst, src, {0, 1, 0, len}, len, len, reqObj); |
349 | } |
350 | |
351 | void RequantizeFixedPointAvx2( |
352 | const int32_t* src, |
353 | uint8_t* dst, |
354 | int len, |
355 | const RequantizationParams& params) { |
356 | constexpr int VLEN = 8; |
357 | |
358 | __m256i b = _mm256_set1_epi32(params.multiplier); |
359 | |
// AVX2 doesn't support a 64-bit arithmetic right shift.
// As a workaround, we convert the 64-bit multiplied results to uint64_t by
// adding 0x8000000000000000ULL, do a logical right shift, and then subtract
// (0x8000000000000000ULL >> right_shift).
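// Scalar sketch of one output lane (s = params.right_shift,
// zp = params.target_qparams.zero_point), matching the scalar tail loop below:
//   int64_t ab = int64_t(src[i]) * params.multiplier;
//   int64_t q = ((ab + (1ll << (s - 1))) >> s) + zp;  // arithmetic shift
//   dst[i] = clamp(q, 0, 255);
// The nudges compute the same arithmetic shift with logical shifts by biasing
// the value with 2^63 before the shift and subtracting (2^63 >> s) afterwards.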
364 | __m256i pre_shift_nudge = _mm256_set1_epi64x( |
365 | (1ll << (params.right_shift - 1)) + 0x8000000000000000ULL); |
366 | __m256i post_shift_nudge = _mm256_set1_epi64x( |
367 | params.target_qparams.zero_point - |
368 | (0x8000000000000000ULL >> params.right_shift)); |
369 | |
370 | __m256i min_v = _mm256_set1_epi32(numeric_limits<uint8_t>::min()); |
371 | __m256i max_v = _mm256_set1_epi32(numeric_limits<uint8_t>::max()); |
372 | |
373 | __m256i shuffle_mask_v = _mm256_set_epi8( |
374 | 0xff, |
375 | 0xff, |
376 | 0xff, |
377 | 0xff, |
378 | 0xff, |
379 | 0xff, |
380 | 0xff, |
381 | 0xff, |
382 | 0xff, |
383 | 0xff, |
384 | 0xff, |
385 | 0xff, |
386 | 0x0c, |
387 | 0x08, |
388 | 0x04, |
389 | 0x00, |
390 | 0xff, |
391 | 0xff, |
392 | 0xff, |
393 | 0xff, |
394 | 0xff, |
395 | 0xff, |
396 | 0xff, |
397 | 0xff, |
398 | 0xff, |
399 | 0xff, |
400 | 0xff, |
401 | 0xff, |
402 | 0x0c, |
403 | 0x08, |
404 | 0x04, |
405 | 0x00); |
406 | __m256i permute_mask_v = |
407 | _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); |
408 | |
409 | int i = 0; |
410 | for (; i < len / VLEN * VLEN; i += VLEN) { |
411 | __m256i a_v = _mm256_loadu_si256((const __m256i*)(src + i)); |
412 | |
413 | // a = a0 | a1 | a2 | a3 | a4 | a5 | a6 | a7 |
// b = b0 | b1 | b2 | b3 | b4 | b5 | b6 | b7
415 | __m256i a_even_v = a_v; |
416 | __m256i a_odd_v = _mm256_srli_si256(a_v, 4); |
417 | |
418 | __m256i ab_even_v = _mm256_mul_epi32(a_even_v, b); |
419 | __m256i ab_odd_v = _mm256_mul_epi32(a_odd_v, b); |
420 | |
421 | __m256i even_rounded_v = _mm256_add_epi64(ab_even_v, pre_shift_nudge); |
422 | __m256i odd_rounded_v = _mm256_add_epi64(ab_odd_v, pre_shift_nudge); |
423 | |
424 | __m256i even_result_v = _mm256_add_epi64( |
425 | _mm256_srli_epi64(even_rounded_v, params.right_shift), |
426 | post_shift_nudge); |
427 | __m256i odd_result_v = _mm256_add_epi64( |
428 | _mm256_srli_epi64(odd_rounded_v, params.right_shift), post_shift_nudge); |
429 | odd_result_v = _mm256_slli_si256(odd_result_v, 4); |
430 | |
431 | // even_result_v has numbers we want in its even 32-bit SIMD lanes, and |
432 | // odd_result_v has numbers we want in its odd 32-bit SIMD lanes. |
433 | // Use blend to combine them. |
434 | __m256i result_v = _mm256_blend_epi32(even_result_v, odd_result_v, 0xaa); |
435 | __m256i clipped_v = |
436 | _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, result_v)); |
437 | |
438 | clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v); |
439 | clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v); |
440 | *(int64_t*)(dst + i) = _mm256_extract_epi64(clipped_v, 0); |
441 | } |
442 | |
443 | for (; i < len; ++i) { |
444 | int64_t ab_64 = |
445 | static_cast<int64_t>(src[i]) * static_cast<int64_t>(params.multiplier); |
446 | int64_t nudge = 1ll << std::max(0, params.right_shift - 1); |
447 | int64_t quantized_down = params.target_qparams.zero_point + |
448 | ((ab_64 + nudge) >> params.right_shift); |
449 | dst[i] = std::min<int64_t>(std::max<int64_t>(quantized_down, 0l), 255l); |
450 | } |
451 | } |
452 | #endif |
453 | |
454 | template < |
455 | bool A_SYMMETRIC, |
456 | bool B_SYMMETRIC, |
457 | QuantizationGranularity Q_GRAN, |
458 | bool HAS_BIAS, |
459 | bool FUSE_RELU, |
460 | typename BIAS_TYPE, |
461 | bool DIRECT> |
462 | void requantizeOutputProcessingAvx2( |
463 | uint8_t* out, |
464 | const int32_t* inp, |
465 | const block_type_t& block, |
466 | int ld_out, |
467 | int ld_in, |
468 | const requantizationParams_t<BIAS_TYPE>& r) { |
// Adaptation of the implementation in QNNPACK/src/requantization/fp32-sse2.c,
// using AVX2 instructions
471 | int quant_param_idx = 0; |
472 | if (Q_GRAN == QuantizationGranularity::GROUP) { |
473 | int ncol_per_group = r.ncols / r.groups; |
474 | int g = block.col_start / ncol_per_group; |
475 | quant_param_idx = g; |
476 | } |
477 | __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]); |
478 | |
479 | // Broadcasted reciprocal of act_times_w_scale |
480 | __m256 act_times_w_rcp_v; |
481 | if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) { |
482 | if (is_same<BIAS_TYPE, float>::value) { |
483 | act_times_w_rcp_v = |
484 | _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]); |
485 | } |
486 | } |
487 | |
488 | __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0)); |
489 | __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255)); |
490 | |
491 | assert( |
492 | (A_SYMMETRIC == (r.A_zero_point == 0)) && |
493 | "A_SYMMETRIC == true if and only if A_zero_point == 0" ); |
494 | assert( |
495 | (B_SYMMETRIC == |
496 | ((Q_GRAN == QuantizationGranularity::TENSOR && r.B_zero_point[0] == 0) || |
497 | r.row_offsets == nullptr)) && |
498 | "B_SYMMETRIC == true if and only if B_zero_point == 0 " |
499 | "or r.row_offsets == nullptr" ); |
500 | assert( |
501 | (HAS_BIAS == (r.bias != nullptr)) && |
502 | "HAS_BIAS == true if and only if bias != nullptr" ); |
503 | |
504 | __m256i A_zero_point_v = _mm256_set1_epi32(r.A_zero_point); |
505 | __m256i C_zero_point_epi16_v = _mm256_set1_epi16(r.C_zero_point); |
506 | __m256i C_zero_point_epi8_v = _mm256_set1_epi8(r.C_zero_point); |
507 | |
508 | __m256i permute_mask_v = |
509 | _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); |
510 | |
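// Per output element this computes, in sketch form (float bias case):
//   acc = inp - A_zero_point * col_offset[j] - B_zero_point * row_offset[i]
//   out = clamp(round((float(acc) + bias[j] / act_times_w_scale) * C_multiplier)
//               + C_zero_point, 0, 255)
// With an int32 bias, bias[j] is added to acc before the conversion to float.
// FUSE_RELU clamps the low end at C_zero_point instead of 0.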
511 | constexpr int VLEN = 8; |
512 | for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { |
513 | // Scale row_offset with Bq_zero_point |
514 | int32_t row_offset = 0; |
515 | if (B_SYMMETRIC) { |
516 | row_offset = 0; |
517 | } else if ( |
518 | Q_GRAN == QuantizationGranularity::TENSOR || |
519 | Q_GRAN == QuantizationGranularity::GROUP) { |
520 | row_offset = |
521 | r.row_offsets[i - block.row_start] * r.B_zero_point[quant_param_idx]; |
522 | } else { |
523 | assert( |
524 | Q_GRAN == QuantizationGranularity::OUT_CHANNEL && |
525 | "unknown quantization granularity" ); |
526 | } |
527 | __m256i row_offset_v = _mm256_set1_epi32(row_offset); |
528 | |
529 | int j = block.col_start; |
530 | for (; j < block.col_start + (block.col_size / (VLEN * 4) * (VLEN * 4)); |
531 | j += (VLEN * 4)) { |
532 | __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
533 | inp + (i - block.row_start) * ld_in + (j - block.col_start))); |
534 | __m256i y_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
535 | inp + (i - block.row_start) * ld_in + (j - block.col_start) + |
536 | 1 * VLEN)); |
537 | __m256i z_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
538 | inp + (i - block.row_start) * ld_in + (j - block.col_start) + |
539 | 2 * VLEN)); |
540 | __m256i w_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
541 | inp + (i - block.row_start) * ld_in + (j - block.col_start) + |
542 | 3 * VLEN)); |
543 | |
544 | if (!A_SYMMETRIC) { |
545 | __m256i col_off_v; |
546 | if (DIRECT == false) { |
547 | col_off_v = _mm256_mullo_epi32( |
548 | A_zero_point_v, |
549 | _mm256_loadu_si256( |
550 | reinterpret_cast<const __m256i*>(r.col_offsets + j))); |
551 | } else { |
552 | col_off_v = _mm256_mullo_epi32( |
553 | A_zero_point_v, |
554 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
555 | r.col_offsets + j + i * block.col_size))); |
556 | } |
557 | |
558 | x_v = _mm256_sub_epi32(x_v, col_off_v); |
559 | |
560 | if (DIRECT == false) { |
561 | col_off_v = _mm256_mullo_epi32( |
562 | A_zero_point_v, |
563 | _mm256_loadu_si256( |
564 | reinterpret_cast<const __m256i*>(r.col_offsets + j + VLEN))); |
565 | } else { |
566 | col_off_v = _mm256_mullo_epi32( |
567 | A_zero_point_v, |
568 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
569 | r.col_offsets + j + VLEN + i * block.col_size))); |
570 | } |
571 | |
572 | y_v = _mm256_sub_epi32(y_v, col_off_v); |
573 | |
574 | if (DIRECT == false) { |
575 | col_off_v = _mm256_mullo_epi32( |
576 | A_zero_point_v, |
577 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
578 | r.col_offsets + j + 2 * VLEN))); |
579 | } else { |
580 | col_off_v = _mm256_mullo_epi32( |
581 | A_zero_point_v, |
582 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
583 | r.col_offsets + j + 2 * VLEN + i * block.col_size))); |
584 | } |
585 | |
586 | z_v = _mm256_sub_epi32(z_v, col_off_v); |
587 | |
588 | if (DIRECT == false) { |
589 | col_off_v = _mm256_mullo_epi32( |
590 | A_zero_point_v, |
591 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
592 | r.col_offsets + j + 3 * VLEN))); |
593 | } else { |
594 | col_off_v = _mm256_mullo_epi32( |
595 | A_zero_point_v, |
596 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
597 | r.col_offsets + j + 3 * VLEN + i * block.col_size))); |
598 | } |
599 | |
600 | w_v = _mm256_sub_epi32(w_v, col_off_v); |
601 | } |
602 | |
603 | if (!B_SYMMETRIC) { |
604 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
605 | row_offset_v = _mm256_mullo_epi32( |
606 | _mm256_set1_epi32(r.row_offsets[i - block.row_start]), |
607 | _mm256_loadu_si256( |
608 | reinterpret_cast<const __m256i*>(r.B_zero_point + j))); |
609 | } |
610 | x_v = _mm256_sub_epi32(x_v, row_offset_v); |
611 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
612 | row_offset_v = _mm256_mullo_epi32( |
613 | _mm256_set1_epi32(r.row_offsets[i - block.row_start]), |
614 | _mm256_loadu_si256( |
615 | reinterpret_cast<const __m256i*>(r.B_zero_point + j + VLEN))); |
616 | } |
617 | y_v = _mm256_sub_epi32(y_v, row_offset_v); |
618 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
619 | row_offset_v = _mm256_mullo_epi32( |
620 | _mm256_set1_epi32(r.row_offsets[i - block.row_start]), |
621 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
622 | r.B_zero_point + j + 2 * VLEN))); |
623 | } |
624 | z_v = _mm256_sub_epi32(z_v, row_offset_v); |
625 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
626 | row_offset_v = _mm256_mullo_epi32( |
627 | _mm256_set1_epi32(r.row_offsets[i - block.row_start]), |
628 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
629 | r.B_zero_point + j + 3 * VLEN))); |
630 | } |
631 | w_v = _mm256_sub_epi32(w_v, row_offset_v); |
632 | } |
633 | __m256 xf_v, yf_v, zf_v, wf_v; |
634 | if (HAS_BIAS) { |
635 | if (is_same<BIAS_TYPE, float>::value) { |
636 | __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v; |
637 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
638 | x_bias_v = _mm256_div_ps( |
639 | _mm256_loadu_ps( |
640 | reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)), |
641 | _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN)); |
642 | y_bias_v = _mm256_div_ps( |
643 | _mm256_loadu_ps( |
644 | reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)), |
645 | _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN)); |
646 | z_bias_v = _mm256_div_ps( |
647 | _mm256_loadu_ps( |
648 | reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)), |
649 | _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN)); |
650 | w_bias_v = _mm256_div_ps( |
651 | _mm256_loadu_ps( |
652 | reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)), |
653 | _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN)); |
654 | } else { |
655 | x_bias_v = _mm256_mul_ps( |
656 | _mm256_loadu_ps( |
657 | reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)), |
658 | act_times_w_rcp_v); |
659 | y_bias_v = _mm256_mul_ps( |
660 | _mm256_loadu_ps( |
661 | reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)), |
662 | act_times_w_rcp_v); |
663 | z_bias_v = _mm256_mul_ps( |
664 | _mm256_loadu_ps( |
665 | reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)), |
666 | act_times_w_rcp_v); |
667 | w_bias_v = _mm256_mul_ps( |
668 | _mm256_loadu_ps( |
669 | reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)), |
670 | act_times_w_rcp_v); |
671 | } |
672 | xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); |
673 | yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v); |
674 | zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v); |
675 | wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v); |
676 | } else { |
677 | x_v = _mm256_add_epi32( |
678 | x_v, |
679 | _mm256_loadu_si256( |
680 | reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN))); |
681 | y_v = _mm256_add_epi32( |
682 | y_v, |
683 | _mm256_loadu_si256( |
684 | reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN))); |
685 | z_v = _mm256_add_epi32( |
686 | z_v, |
687 | _mm256_loadu_si256( |
688 | reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); |
689 | w_v = _mm256_add_epi32( |
690 | w_v, |
691 | _mm256_loadu_si256( |
692 | reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); |
693 | xf_v = _mm256_cvtepi32_ps(x_v); |
694 | yf_v = _mm256_cvtepi32_ps(y_v); |
695 | zf_v = _mm256_cvtepi32_ps(z_v); |
696 | wf_v = _mm256_cvtepi32_ps(w_v); |
697 | } |
698 | } else { |
699 | xf_v = _mm256_cvtepi32_ps(x_v); |
700 | yf_v = _mm256_cvtepi32_ps(y_v); |
701 | zf_v = _mm256_cvtepi32_ps(z_v); |
702 | wf_v = _mm256_cvtepi32_ps(w_v); |
703 | } |
704 | |
705 | /* |
706 | * Convert int32_t input to FP32 and multiply by FP32 scale. |
707 | * Both operations involve statistically unbiased roundings (with |
708 | * default MXCSR rounding mode): |
 * - Large int32_t values can't be exactly represented as FP32.
 *   The CVTDQ2PS instruction on x86 rounds them to the nearest
 *   FP32 value with ties to even (assuming the default MXCSR rounding
 *   mode).
 * - The product of two FP32 values is generally not exactly
 *   representable as an FP32 value, and will be rounded to the nearest
 *   FP32 value with ties to even under the default MXCSR rounding mode.
716 | */ |
717 | __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v; |
718 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
719 | x_scaled_v = |
720 | _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j + 0 * VLEN)); |
721 | y_scaled_v = |
722 | _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + 1 * VLEN)); |
723 | z_scaled_v = |
724 | _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); |
725 | w_scaled_v = |
726 | _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); |
727 | } else { |
728 | x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); |
729 | y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); |
730 | z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); |
731 | w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); |
732 | } |
733 | |
734 | /* |
735 | * Convert scaled FP32 result to int32_t using CVTPS2DQ instruction. |
736 | * CVTPS2DQ instruction rounds result according to nearest FP32 value |
737 | * with ties to even (assuming default MXCSR rounding mode). However, |
738 | * when conversion overflows, it produces INT32_MIN as a result. For |
739 | * large positive inputs the result of conversion can become negative, |
740 | * which affects the final requantization result. Note that on x86 |
741 | * SSE2 we have e.g. int32_t(float(INT32_MAX)) == INT32_MIN! This |
742 | * happens because float(INT32_MAX) rounds to 2**31, which overflows |
743 | * int32_t when it is converted back to integer. |
744 | * |
745 | * Thankfully, we can prove that overflow never happens in this |
746 | * requantization scheme. The largest positive input is INT32_MAX |
747 | * (2**31 - 1), which turns into 2**31 when converted to float. The |
748 | * largest scale value is 0x1.FFFFFEp-1. When multiplied together, the |
749 | * result is 2147483520 (compare to INT32_MAX = 2147483647), which |
750 | * fits into int32_t without overflow. |
751 | */ |
752 | __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); |
753 | __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v); |
754 | __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v); |
755 | __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v); |
756 | |
757 | /* |
758 | * Standard final sequence on x86 AVX2: |
759 | * - Pack to int16_t and saturate |
760 | * - Add zero point |
761 | * - Pack to uint8_t and saturate |
762 | * - Clamp between qmin and qmax |
763 | */ |
764 | __m256i xy_packed_v = _mm256_adds_epi16( |
765 | _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v); |
766 | __m256i zw_packed_v = _mm256_adds_epi16( |
767 | _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v); |
768 | __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v); |
769 | __m256i xyzw_clamped_v = _mm256_max_epu8( |
770 | FUSE_RELU ? C_zero_point_epi8_v : min_v, |
771 | _mm256_min_epu8(xyzw_packed_v, max_v)); |
772 | |
773 | /* |
774 | * xyzw_clamped_v has results in the following layout so we need to |
775 | * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7 |
776 | */ |
777 | xyzw_clamped_v = |
778 | _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); |
779 | |
780 | /* |
781 | * 4x CVTDQ2PS |
782 | * 4x MULPS |
783 | * 4x CVTPS2DQ |
784 | * 2x PACKSSDW |
785 | * 1x PACKUSWB |
786 | * 2x PADDW |
787 | * 1x PMAXUB |
788 | * 1x PMINUB |
789 | * 1x PERMD |
790 | * --------------------- |
791 | * 20 instructions total |
792 | */ |
793 | _mm256_storeu_si256( |
794 | reinterpret_cast<__m256i*>(out + i * ld_out + j), xyzw_clamped_v); |
795 | } // j loop vectorized and unrolled 4x |
796 | |
797 | for (; j < block.col_start + (block.col_size / VLEN * VLEN); j += VLEN) { |
798 | __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
799 | inp + (i - block.row_start) * ld_in + (j - block.col_start))); |
800 | |
801 | if (!A_SYMMETRIC) { |
802 | __m256i col_off_v; |
803 | if (DIRECT == false) { |
804 | col_off_v = _mm256_mullo_epi32( |
805 | A_zero_point_v, |
806 | _mm256_loadu_si256( |
807 | reinterpret_cast<const __m256i*>(r.col_offsets + j))); |
808 | } else { |
809 | col_off_v = _mm256_mullo_epi32( |
810 | A_zero_point_v, |
811 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
812 | r.col_offsets + j + i * block.col_size))); |
813 | } |
814 | x_v = _mm256_sub_epi32(x_v, col_off_v); |
815 | } |
816 | |
817 | if (!B_SYMMETRIC) { |
818 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
819 | row_offset_v = _mm256_mullo_epi32( |
820 | _mm256_set1_epi32(r.row_offsets[i - block.row_start]), |
821 | _mm256_loadu_si256( |
822 | reinterpret_cast<const __m256i*>(r.B_zero_point + j))); |
823 | } |
824 | x_v = _mm256_sub_epi32(x_v, row_offset_v); |
825 | } |
826 | __m256 xf_v; |
827 | if (HAS_BIAS) { |
828 | if (is_same<BIAS_TYPE, float>::value) { |
829 | __m256 x_bias_v; |
830 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
831 | x_bias_v = _mm256_div_ps( |
832 | _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)), |
833 | _mm256_loadu_ps(r.act_times_w_scale + j)); |
834 | } else { |
835 | x_bias_v = _mm256_mul_ps( |
836 | _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)), |
837 | act_times_w_rcp_v); |
838 | } |
839 | xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); |
840 | } else { |
841 | x_v = _mm256_add_epi32( |
842 | x_v, |
843 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); |
844 | xf_v = _mm256_cvtepi32_ps(x_v); |
845 | } |
846 | } else { |
847 | xf_v = _mm256_cvtepi32_ps(x_v); |
848 | } |
849 | |
850 | __m256 x_scaled_v; |
851 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
852 | x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j)); |
853 | } else { |
854 | x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); |
855 | } |
856 | __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); |
857 | |
858 | __m256i x_packed_v = _mm256_adds_epi16( |
859 | _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()), |
860 | C_zero_point_epi16_v); |
861 | x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256()); |
862 | __m256i x_clamped_v = _mm256_max_epu8( |
863 | FUSE_RELU ? C_zero_point_epi8_v : min_v, |
864 | _mm256_min_epu8(x_packed_v, max_v)); |
865 | |
866 | /* |
867 | * x_clamped_v has results in the following layout so we need to |
868 | * permute: x0-3 garbage0-11 x4-7 garbage12-23 |
869 | */ |
870 | x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v); |
871 | |
872 | /* |
873 | * 1x CVTDQ2PS |
874 | * 1x MULPS |
875 | * 1x CVTPS2DQ |
876 | * 1x PACKSSDW |
877 | * 1x PACKUSWB |
878 | * 1x PADDW |
879 | * 1x PMAXUB |
880 | * 1x PMINUB |
881 | * 1x PERMD |
882 | * --------------------- |
883 | * 9 instructions total |
884 | */ |
885 | _mm_storel_epi64( |
886 | reinterpret_cast<__m128i*>(out + i * ld_out + j), |
887 | _mm256_castsi256_si128(x_clamped_v)); |
888 | } // j loop vectorized |
889 | |
890 | int remainder = block.col_start + block.col_size - j; |
891 | if (remainder > 0) { |
892 | __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>( |
893 | internal::avx2_ps_or_epi32_masks[remainder])); |
894 | |
895 | __m256i x_v = _mm256_maskload_epi32( |
896 | inp + (i - block.row_start) * ld_in + (j - block.col_start), mask_v); |
897 | |
898 | if (!A_SYMMETRIC) { |
899 | __m256i col_off_v; |
900 | if (DIRECT == false) { |
901 | col_off_v = _mm256_mullo_epi32( |
902 | A_zero_point_v, _mm256_maskload_epi32(r.col_offsets + j, mask_v)); |
903 | } else { |
904 | col_off_v = _mm256_mullo_epi32( |
905 | A_zero_point_v, |
906 | _mm256_maskload_epi32( |
907 | r.col_offsets + j + i * block.col_size, mask_v)); |
908 | } |
909 | x_v = _mm256_sub_epi32(x_v, col_off_v); |
910 | } |
911 | |
912 | if (!B_SYMMETRIC) { |
913 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
914 | row_offset_v = _mm256_mullo_epi32( |
915 | _mm256_set1_epi32(r.row_offsets[i - block.row_start]), |
916 | _mm256_maskload_epi32(r.B_zero_point + j, mask_v)); |
917 | } |
918 | x_v = _mm256_sub_epi32(x_v, row_offset_v); |
919 | } |
920 | |
921 | __m256 xf_v; |
922 | if (HAS_BIAS) { |
923 | if (is_same<BIAS_TYPE, float>::value) { |
924 | __m256 x_bias_v; |
925 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
926 | x_bias_v = _mm256_div_ps( |
927 | _mm256_maskload_ps( |
928 | reinterpret_cast<const float*>(r.bias + j), mask_v), |
929 | _mm256_maskload_ps(r.act_times_w_scale + j, mask_v)); |
930 | } else { |
931 | x_bias_v = _mm256_mul_ps( |
932 | _mm256_maskload_ps( |
933 | reinterpret_cast<const float*>(r.bias + j), mask_v), |
934 | act_times_w_rcp_v); |
935 | } |
936 | xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); |
937 | } else { |
938 | x_v = _mm256_add_epi32( |
939 | x_v, |
940 | _mm256_maskload_epi32( |
941 | reinterpret_cast<const int*>(r.bias + j), mask_v)); |
942 | xf_v = _mm256_cvtepi32_ps(x_v); |
943 | } |
944 | } else { |
945 | xf_v = _mm256_cvtepi32_ps(x_v); |
946 | } |
947 | |
948 | __m256 x_scaled_v; |
949 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
950 | x_scaled_v = |
951 | _mm256_mul_ps(xf_v, _mm256_maskload_ps(r.C_multiplier + j, mask_v)); |
952 | } else { |
953 | x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); |
954 | } |
955 | __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); |
956 | |
957 | __m256i x_packed_v = _mm256_adds_epi16( |
958 | _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()), |
959 | C_zero_point_epi16_v); |
960 | x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256()); |
961 | __m256i x_clamped_v = _mm256_max_epu8( |
962 | FUSE_RELU ? C_zero_point_epi8_v : min_v, |
963 | _mm256_min_epu8(x_packed_v, max_v)); |
964 | |
965 | /* |
966 | * x_clamped_v has results in the following layout so we need to |
967 | * permute: x0-3 garbage0-11 x4-7 garbage12-23 |
968 | */ |
969 | x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v); |
970 | |
971 | /* |
972 | * 1x CVTDQ2PS |
973 | * 1x MULPS |
974 | * 1x CVTPS2DQ |
975 | * 1x PACKSSDW |
976 | * 1x PACKUSWB |
977 | * 1x PADDW |
978 | * 1x PMAXUB |
979 | * 1x PMINUB |
980 | * 1x PERMD |
981 | * --------------------- |
982 | * 9 instructions total |
983 | */ |
984 | alignas(64) uint8_t x_clamped_buffer[32]; |
985 | _mm256_store_si256( |
986 | reinterpret_cast<__m256i*>(x_clamped_buffer), x_clamped_v); |
987 | for (int k = 0; k < remainder; ++k) { |
988 | out[i * ld_out + j + k] = x_clamped_buffer[k]; |
989 | } |
990 | } // j loop remainder |
991 | } // i loop |
992 | } |
993 | |
994 | template < |
995 | bool A_SYMMETRIC, |
996 | bool B_SYMMETRIC, |
997 | QuantizationGranularity Q_GRAN, |
998 | bool HAS_BIAS, |
999 | bool FUSE_RELU> |
1000 | void requantizeForFloatAvx2( |
1001 | float* out, |
1002 | const int32_t* inp, |
1003 | const block_type_t& block, |
1004 | int ld_out, |
1005 | int ld_in, |
1006 | const requantizationForFloatParams_t& r) { |
// Adaptation of the implementation in QNNPACK/src/requantization/fp32-sse2.c,
// using AVX2 instructions
1009 | int quant_param_idx = 0; |
1010 | if (Q_GRAN == QuantizationGranularity::GROUP) { |
1011 | int ncol_per_group = r.ncols / r.groups; |
1012 | int g = block.col_start / ncol_per_group; |
1013 | quant_param_idx = g; |
1014 | } |
1015 | __m256 multiplier_v = _mm256_set1_ps(r.A_scale * r.B_scale[quant_param_idx]); |
1016 | |
1017 | assert( |
1018 | (A_SYMMETRIC == (r.A_zero_point == 0)) && |
1019 | "A_SYMMETRIC == true if and only if A_zero_point == 0" ); |
1020 | assert( |
1021 | (B_SYMMETRIC == |
1022 | ((Q_GRAN == QuantizationGranularity::TENSOR && r.B_zero_point[0] == 0) || |
1023 | r.row_offsets == nullptr)) && |
1024 | "B_SYMMETRIC == true if and only if B_zero_point == 0 " |
1025 | "or r.row_offsets == nullptr" ); |
1026 | assert( |
1027 | (HAS_BIAS == (r.bias != nullptr)) && |
1028 | "HAS_BIAS == true if and only if bias != nullptr" ); |
1029 | |
1030 | __m256i A_zero_point_v = _mm256_set1_epi32(r.A_zero_point); |
1031 | |
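// Per output element this computes, in sketch form:
//   acc = inp - A_zero_point * col_offset[j] - B_zero_point * row_offset[i]
//   out = float(acc) * A_scale * B_scale + bias[j]
// with the result clamped at 0 when FUSE_RELU is set.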
1032 | constexpr int VLEN = 8; |
1033 | for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { |
1034 | // Scale row_offset with Bq_zero_point |
1035 | int32_t row_offset = 0; |
1036 | if (B_SYMMETRIC) { |
1037 | row_offset = 0; |
1038 | } else if ( |
1039 | Q_GRAN == QuantizationGranularity::TENSOR || |
1040 | Q_GRAN == QuantizationGranularity::GROUP) { |
1041 | row_offset = |
1042 | r.row_offsets[i - block.row_start] * r.B_zero_point[quant_param_idx]; |
1043 | } else { |
1044 | assert( |
1045 | Q_GRAN == QuantizationGranularity::OUT_CHANNEL && |
1046 | "unknown quantization granularity" ); |
1047 | } |
1048 | __m256i row_offset_v = _mm256_set1_epi32(row_offset); |
1049 | |
1050 | int j = block.col_start; |
1051 | for (; j < block.col_start + (block.col_size / VLEN * VLEN); j += VLEN) { |
1052 | __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
1053 | inp + (i - block.row_start) * ld_in + (j - block.col_start))); |
1054 | |
1055 | if (!A_SYMMETRIC) { |
1056 | __m256i col_off_v = _mm256_mullo_epi32( |
1057 | A_zero_point_v, |
1058 | _mm256_loadu_si256( |
1059 | reinterpret_cast<const __m256i*>(r.col_offsets + j))); |
1060 | x_v = _mm256_sub_epi32(x_v, col_off_v); |
1061 | } |
1062 | |
1063 | if (!B_SYMMETRIC) { |
1064 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
1065 | row_offset_v = _mm256_mullo_epi32( |
1066 | _mm256_set1_epi32(r.row_offsets[i - block.row_start]), |
1067 | _mm256_loadu_si256( |
1068 | reinterpret_cast<const __m256i*>(r.B_zero_point + j))); |
1069 | } |
1070 | x_v = _mm256_sub_epi32(x_v, row_offset_v); |
1071 | } |
1072 | |
1073 | __m256 x_scaled_v; |
1074 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
1075 | x_scaled_v = _mm256_mul_ps( |
1076 | _mm256_cvtepi32_ps(x_v), |
1077 | _mm256_mul_ps( |
1078 | _mm256_set1_ps(r.A_scale), _mm256_loadu_ps(r.B_scale + j))); |
1079 | } else { |
1080 | x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); |
1081 | } |
1082 | |
1083 | if (HAS_BIAS) { |
1084 | x_scaled_v = _mm256_add_ps(x_scaled_v, _mm256_loadu_ps(r.bias + j)); |
1085 | } |
1086 | if (FUSE_RELU) { |
1087 | x_scaled_v = _mm256_max_ps(_mm256_setzero_ps(), x_scaled_v); |
1088 | } |
1089 | |
1090 | _mm256_storeu_ps(out + i * ld_out + j, x_scaled_v); |
1091 | } // j loop vectorized |
1092 | |
1093 | int remainder = block.col_start + block.col_size - j; |
1094 | if (remainder > 0) { |
1095 | __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>( |
1096 | internal::avx2_ps_or_epi32_masks[remainder])); |
1097 | |
1098 | __m256i x_v = _mm256_maskload_epi32( |
1099 | inp + (i - block.row_start) * ld_in + (j - block.col_start), mask_v); |
1100 | |
1101 | if (!A_SYMMETRIC) { |
1102 | __m256i col_off_v = _mm256_mullo_epi32( |
1103 | A_zero_point_v, _mm256_maskload_epi32(r.col_offsets + j, mask_v)); |
1104 | x_v = _mm256_sub_epi32(x_v, col_off_v); |
1105 | } |
1106 | |
1107 | if (!B_SYMMETRIC) { |
1108 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
1109 | row_offset_v = _mm256_mullo_epi32( |
1110 | _mm256_set1_epi32(r.row_offsets[i - block.row_start]), |
1111 | _mm256_maskload_epi32(r.B_zero_point + j, mask_v)); |
1112 | } |
1113 | x_v = _mm256_sub_epi32(x_v, row_offset_v); |
1114 | } |
1115 | |
1116 | __m256 x_scaled_v; |
1117 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
1118 | x_scaled_v = _mm256_mul_ps( |
1119 | _mm256_cvtepi32_ps(x_v), |
1120 | _mm256_mul_ps( |
1121 | _mm256_set1_ps(r.A_scale), |
1122 | _mm256_maskload_ps(r.B_scale + j, mask_v))); |
1123 | } else { |
1124 | x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); |
1125 | } |
1126 | |
1127 | if (HAS_BIAS) { |
1128 | x_scaled_v = |
1129 | _mm256_add_ps(x_scaled_v, _mm256_maskload_ps(r.bias + j, mask_v)); |
1130 | } |
1131 | if (FUSE_RELU) { |
1132 | x_scaled_v = _mm256_max_ps(_mm256_setzero_ps(), x_scaled_v); |
1133 | } |
1134 | |
1135 | _mm256_maskstore_ps(out + i * ld_out + j, mask_v, x_scaled_v); |
1136 | } // j loop remainder |
1137 | } // i loop |
1138 | } |
1139 | |
1140 | template < |
1141 | bool A_SYMMETRIC, |
1142 | bool B_SYMMETRIC, |
1143 | QuantizationGranularity Q_GRAN, |
1144 | bool HAS_BIAS, |
1145 | bool FUSE_RELU, |
1146 | int C_PER_G, |
1147 | typename BIAS_TYPE> |
1148 | void requantizeOutputProcessingGConvAvx2( |
1149 | uint8_t* out, |
1150 | const int32_t* inp, |
1151 | const block_type_t& block, |
1152 | int ld_out, |
1153 | int ld_in, |
1154 | const requantizationParams_t<BIAS_TYPE>& r) { |
// Adaptation of the implementation in QNNPACK/src/requantization/fp32-sse2.c,
// using AVX2 instructions
1157 | int quant_param_idx = 0; |
1158 | if (Q_GRAN == QuantizationGranularity::GROUP) { |
1159 | int ncol_per_group = r.ncols / r.groups; |
1160 | int g = block.col_start / ncol_per_group; |
1161 | quant_param_idx = g; |
1162 | } |
1163 | __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]); |
1164 | |
1165 | // Broadcasted reciprocal of act_times_w_scale |
1166 | __m256 act_times_w_rcp_v; |
1167 | if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) { |
1168 | if (is_same<BIAS_TYPE, float>::value) { |
1169 | act_times_w_rcp_v = |
1170 | _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]); |
1171 | } |
1172 | } |
1173 | __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0)); |
1174 | __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255)); |
1175 | |
1176 | assert( |
1177 | (A_SYMMETRIC == (r.A_zero_point == 0)) && |
1178 | "A_SYMMETRIC == true if and only if A_zero_point == 0" ); |
1179 | assert( |
1180 | (B_SYMMETRIC == |
1181 | ((Q_GRAN == QuantizationGranularity::TENSOR && r.B_zero_point[0] == 0) || |
1182 | r.row_offsets == nullptr)) && |
1183 | "B_SYMMETRIC == true if and only if B_zero_point == 0 " |
1184 | "or r.row_offsets == nullptr" ); |
1185 | assert( |
1186 | (HAS_BIAS == (r.bias != nullptr)) && |
1187 | "HAS_BIAS == true if and only if bias != nullptr" ); |
1188 | |
1189 | __m256i A_zero_point_v = _mm256_set1_epi32(r.A_zero_point); |
1190 | __m256i C_zero_point_epi16_v = _mm256_set1_epi16(r.C_zero_point); |
1191 | __m256i C_zero_point_epi8_v = _mm256_set1_epi8(r.C_zero_point); |
1192 | |
1193 | __m256i permute_mask_v = |
1194 | _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); |
1195 | |
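// With VLEN == 8 int32 lanes, one vector spans 8 / C_PER_G groups: 4 groups
// when C_PER_G == 2, 2 groups when C_PER_G == 4, exactly 1 group when
// C_PER_G == 8, and half a group when C_PER_G == 16. Per-group parameters
// (row offsets, B zero points, multipliers, act_times_w_scale) therefore have
// to be broadcast or interleaved below to match that lane layout.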
1196 | constexpr int VLEN = 8; |
1197 | for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { |
1198 | int j = block.col_start; |
1199 | for (; j < block.col_start + (block.col_size / VLEN * VLEN); j += VLEN) { |
1200 | __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
1201 | inp + (i - block.row_start) * ld_in + (j - block.col_start))); |
1202 | |
1203 | if (!A_SYMMETRIC) { |
1204 | __m256i col_off_v = _mm256_mullo_epi32( |
1205 | A_zero_point_v, |
1206 | _mm256_loadu_si256( |
1207 | reinterpret_cast<const __m256i*>(r.col_offsets + j))); |
1208 | x_v = _mm256_sub_epi32(x_v, col_off_v); |
1209 | } |
1210 | |
1211 | if (!B_SYMMETRIC) { |
1212 | __m256i row_offset_v; |
1213 | |
1214 | if (C_PER_G == 2) { |
1215 | // When C_PER_G == 2, we need to handle 4 groups at a time to fully |
1216 | // utilize 32B AVX2 vector register (C_PER_G * 4 * sizeof(int32_t) == |
1217 | // 32B) |
// Load row_offsets for 4 groups and duplicate each offset into the two
// lanes of its group.
1219 | row_offset_v = |
1220 | _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps( |
1221 | _mm256_castps128_ps256( |
1222 | _mm_loadu_ps(reinterpret_cast<const float*>( |
1223 | r.row_offsets + (i - block.row_start) * 4))), |
1224 | permute_mask_v))); |
1225 | |
1226 | } |
1227 | // When C_PER_G == 4, we need to handle 2 groups at a time to fully |
1228 | // utilize 32B AVX2 vector register (C_PER_G * 2 * sizeof(int32_t) == |
1229 | // 32B) |
// When C_PER_G == 8, on the other hand, we only need 1 group at a time.
1231 | |
1232 | // Groups 0 and 1 when C_PER_G == 4 |
1233 | // Group 0 when C_PER_G == 8 |
1234 | else if (C_PER_G == 4) { |
// Load row_offsets for 2 groups and broadcast each one 4 times because
// we have 4 channels per group.
1237 | // groups 0 and 1 |
1238 | row_offset_v = _mm256_insertf128_si256( |
1239 | _mm256_castsi128_si256( |
1240 | _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 2 + 0])), |
1241 | _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 2 + 1]), |
1242 | 1); |
1243 | } else if (C_PER_G == 8) { |
1244 | row_offset_v = |
1245 | _mm256_set1_epi32(r.row_offsets[(i - block.row_start)]); |
1246 | } else { |
1247 | assert(C_PER_G == 16); |
1248 | row_offset_v = |
1249 | _mm256_set1_epi32(r.row_offsets[(i - block.row_start)]); |
1250 | } |
1251 | |
1252 | __m256i B_zero_point_v = _mm256_set1_epi32(r.B_zero_point[0]); |
1253 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
1254 | B_zero_point_v = _mm256_loadu_si256( |
1255 | reinterpret_cast<const __m256i*>(r.B_zero_point + j)); |
1256 | } else if (Q_GRAN == QuantizationGranularity::GROUP) { |
1257 | if (C_PER_G == 2) { |
1258 | B_zero_point_v = |
1259 | _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps( |
1260 | _mm256_castps128_ps256( |
1261 | _mm_loadu_ps(reinterpret_cast<const float*>( |
1262 | r.B_zero_point + quant_param_idx))), |
1263 | permute_mask_v))); |
1264 | } else if (C_PER_G == 4) { |
1265 | B_zero_point_v = _mm256_insertf128_si256( |
1266 | _mm256_castsi128_si256( |
1267 | _mm_set1_epi32(r.B_zero_point[quant_param_idx])), |
1268 | _mm_set1_epi32(r.B_zero_point[quant_param_idx + 1]), |
1269 | 1); |
1270 | } else if (C_PER_G == 8) { |
1271 | B_zero_point_v = _mm256_set1_epi32(r.B_zero_point[quant_param_idx]); |
1272 | } else { |
1273 | B_zero_point_v = _mm256_set1_epi32(r.B_zero_point[quant_param_idx]); |
1274 | } |
1275 | } |
1276 | row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v); |
1277 | x_v = _mm256_sub_epi32(x_v, row_offset_v); |
1278 | } |
1279 | __m256 xf_v; |
1280 | if (HAS_BIAS) { |
1281 | if (is_same<BIAS_TYPE, float>::value) { |
1282 | __m256 x_bias_v = |
1283 | _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)); |
1284 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
1285 | x_bias_v = _mm256_div_ps( |
1286 | x_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j)); |
1287 | } else if (Q_GRAN == QuantizationGranularity::GROUP) { |
1288 | __m256 diviser_v; |
1289 | if (C_PER_G == 2) { |
1290 | diviser_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps( |
1291 | _mm256_castps128_ps256( |
1292 | _mm_loadu_ps(r.act_times_w_scale + quant_param_idx)), |
1293 | permute_mask_v)); |
1294 | } else if (C_PER_G == 4) { |
1295 | diviser_v = _mm256_insertf128_ps( |
1296 | _mm256_castps128_ps256( |
1297 | _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 0])), |
1298 | _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 1]), |
1299 | 1); |
1300 | } else if (C_PER_G == 8) { |
1301 | diviser_v = _mm256_set1_ps(r.act_times_w_scale[quant_param_idx]); |
1302 | } else { |
1303 | assert(C_PER_G == 16); |
1304 | diviser_v = _mm256_set1_ps(r.act_times_w_scale[quant_param_idx]); |
1305 | } |
1306 | x_bias_v = _mm256_div_ps(x_bias_v, diviser_v); |
1307 | } else { |
1308 | x_bias_v = _mm256_mul_ps(x_bias_v, act_times_w_rcp_v); |
1309 | } |
1310 | xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); |
1311 | } else { |
1312 | x_v = _mm256_add_epi32( |
1313 | x_v, |
1314 | _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); |
1315 | xf_v = _mm256_cvtepi32_ps(x_v); |
1316 | } |
1317 | } else { |
1318 | xf_v = _mm256_cvtepi32_ps(x_v); |
1319 | } |
1320 | |
1321 | /* |
1322 | * Convert int32_t input to FP32 and multiply by FP32 scale. |
1323 | * Both operations involve statistically unbiased roundings (with |
1324 | * default MXCSR rounding mode): |
 * - Large int32_t values can't be exactly represented as FP32.
 *   The CVTDQ2PS instruction on x86 rounds them to the nearest
 *   FP32 value with ties to even (assuming the default MXCSR rounding
 *   mode).
 * - The product of two FP32 values is generally not exactly
 *   representable as an FP32 value, and will be rounded to the nearest
 *   FP32 value with ties to even under the default MXCSR rounding mode.
1332 | */ |
1333 | __m256 x_scaled_v; |
1334 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
1335 | x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j)); |
1336 | } else if (Q_GRAN == QuantizationGranularity::GROUP) { |
1337 | if (C_PER_G == 2) { |
1338 | multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps( |
1339 | _mm256_castps128_ps256( |
1340 | _mm_loadu_ps(r.C_multiplier + quant_param_idx)), |
1341 | permute_mask_v)); |
1342 | } else if (C_PER_G == 4) { |
1343 | multiplier_v = _mm256_insertf128_ps( |
1344 | _mm256_castps128_ps256( |
1345 | _mm_set1_ps(r.C_multiplier[quant_param_idx])), |
1346 | _mm_set1_ps(r.C_multiplier[quant_param_idx + 1]), |
1347 | 1); |
1348 | } else if (C_PER_G == 8) { |
1349 | multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]); |
1350 | } else { |
1351 | multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]); |
1352 | } |
1353 | x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); |
1354 | } else { |
1355 | x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); |
1356 | } |
1357 | |
1358 | /* |
1359 | * Convert scaled FP32 result to int32_t using CVTPS2DQ instruction. |
1360 | * CVTPS2DQ instruction rounds result according to nearest FP32 value |
1361 | * with ties to even (assuming default MXCSR rounding mode). However, |
1362 | * when conversion overflows, it produces INT32_MIN as a result. For |
1363 | * large positive inputs the result of conversion can become negative, |
1364 | * which affects the final requantization result. Note that on x86 |
1365 | * SSE2 we have e.g. int32_t(float(INT32_MAX)) == INT32_MIN! This |
1366 | * happens because float(INT32_MAX) rounds to 2**31, which overflows |
1367 | * int32_t when it is converted back to integer. |
1368 | * |
1369 | * Thankfully, we can prove that overflow never happens in this |
1370 | * requantization scheme. The largest positive input is INT32_MAX |
1371 | * (2**31 - 1), which turns into 2**31 when converted to float. The |
1372 | * largest scale value is 0x1.FFFFFEp-1. When multiplied together, the |
1373 | * result is 2147483520 (compare to INT32_MAX = 2147483647), which |
1374 | * fits into int32_t without overflow. |
1375 | */ |
1376 | __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); |
1377 | |
1378 | /* |
1379 | * Standard final sequence on x86 AVX2: |
1380 | * - Pack to int16_t and saturate |
1381 | * - Add zero point |
1382 | * - Pack to uint8_t and saturate |
1383 | * - Clamp between qmin and qmax |
1384 | */ |
1385 | __m256i x_packed_v = _mm256_adds_epi16( |
1386 | _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()), |
1387 | C_zero_point_epi16_v); |
1388 | x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256()); |
1389 | __m256i x_clamped_v = _mm256_max_epu8( |
1390 | FUSE_RELU ? C_zero_point_epi8_v : min_v, |
1391 | _mm256_min_epu8(x_packed_v, max_v)); |
1392 | |
1393 | /* |
1394 | * x_clamped_v has results in the following layout so we need to |
1395 | * permute: x0-3 garbage0-11 x4-7 garbage12-23 |
1396 | */ |
1397 | x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v); |
1398 | |
1399 | /* |
1400 | * 1x CVTDQ2PS |
1401 | * 1x MULPS |
1402 | * 1x CVTPS2DQ |
1403 | * 1x PACKSSDW |
1404 | * 1x PACKUSWB |
1405 | * 1x PADDW |
1406 | * 1x PMAXUB |
1407 | * 1x PMINUB |
1408 | * 1x PERMD |
1409 | * --------------------- |
1410 | * 9 instructions total |
1411 | */ |
1412 | |
1413 | _mm_storel_epi64( |
1414 | reinterpret_cast<__m128i*>(out + i * ld_out + j), |
1415 | _mm256_castsi256_si128(x_clamped_v)); |
1416 | } // j loop vectorized |
1417 | |
1418 | const int remainder = block.col_start + block.col_size - j; |
1419 | (void)remainder; // Suppress unused variable warning |
1420 | assert(remainder == 0); |
1421 | } // i loop |
1422 | } |
1423 | |
1424 | #define INSTANTIATE_REQUANTIZE_BIAS_TYPE( \ |
1425 | A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE) \ |
1426 | template void FBGEMM_API requantizeOutputProcessingAvx2< \ |
1427 | A_SYM, \ |
1428 | B_SYM, \ |
1429 | Q_GRAN, \ |
1430 | BIAS, \ |
1431 | RELU, \ |
1432 | BIAS_TYPE, \ |
1433 | false>( \ |
1434 | uint8_t * out, \ |
1435 | const int32_t* inp, \ |
1436 | const block_type_t& block, \ |
1437 | int ld_out, \ |
1438 | int ld_in, \ |
1439 | const requantizationParams_t<BIAS_TYPE>& r); \ |
1440 | template void FBGEMM_API requantizeOutputProcessingAvx2< \ |
1441 | A_SYM, \ |
1442 | B_SYM, \ |
1443 | Q_GRAN, \ |
1444 | BIAS, \ |
1445 | RELU, \ |
1446 | BIAS_TYPE, \ |
1447 | true>( \ |
1448 | uint8_t * out, \ |
1449 | const int32_t* inp, \ |
1450 | const block_type_t& block, \ |
1451 | int ld_out, \ |
1452 | int ld_in, \ |
1453 | const requantizationParams_t<BIAS_TYPE>& r); \ |
1454 | template void requantizeOutputProcessingGConvAvx2< \ |
1455 | A_SYM, \ |
1456 | B_SYM, \ |
1457 | Q_GRAN, \ |
1458 | BIAS, \ |
1459 | RELU, \ |
1460 | 2, \ |
1461 | BIAS_TYPE>( \ |
1462 | uint8_t * out, \ |
1463 | const int32_t* inp, \ |
1464 | const block_type_t& block, \ |
1465 | int ld_out, \ |
1466 | int ld_in, \ |
1467 | const requantizationParams_t<BIAS_TYPE>& r); \ |
1468 | template void requantizeOutputProcessingGConvAvx2< \ |
1469 | A_SYM, \ |
1470 | B_SYM, \ |
1471 | Q_GRAN, \ |
1472 | BIAS, \ |
1473 | RELU, \ |
1474 | 4, \ |
1475 | BIAS_TYPE>( \ |
1476 | uint8_t * out, \ |
1477 | const int32_t* inp, \ |
1478 | const block_type_t& block, \ |
1479 | int ld_out, \ |
1480 | int ld_in, \ |
1481 | const requantizationParams_t<BIAS_TYPE>& r); \ |
1482 | template void requantizeOutputProcessingGConvAvx2< \ |
1483 | A_SYM, \ |
1484 | B_SYM, \ |
1485 | Q_GRAN, \ |
1486 | BIAS, \ |
1487 | RELU, \ |
1488 | 8, \ |
1489 | BIAS_TYPE>( \ |
1490 | uint8_t * out, \ |
1491 | const int32_t* inp, \ |
1492 | const block_type_t& block, \ |
1493 | int ld_out, \ |
1494 | int ld_in, \ |
1495 | const requantizationParams_t<BIAS_TYPE>& r); \ |
1496 | template void requantizeOutputProcessingGConvAvx2< \ |
1497 | A_SYM, \ |
1498 | B_SYM, \ |
1499 | Q_GRAN, \ |
1500 | BIAS, \ |
1501 | RELU, \ |
1502 | 16, \ |
1503 | BIAS_TYPE>( \ |
1504 | uint8_t * out, \ |
1505 | const int32_t* inp, \ |
1506 | const block_type_t& block, \ |
1507 | int ld_out, \ |
1508 | int ld_in, \ |
1509 | const requantizationParams_t<BIAS_TYPE>& r); |
1510 | |
1511 | #define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \ |
1512 | INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, float) \ |
1513 | INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, int32_t) \ |
1514 | template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ |
1515 | float* out, \ |
1516 | const int32_t* inp, \ |
1517 | const block_type_t& block, \ |
1518 | int ld_out, \ |
1519 | int ld_in, \ |
1520 | const requantizationForFloatParams_t& r); |
1521 | |
1522 | #define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \ |
1523 | INSTANTIATE_REQUANTIZE(true, B_SYM, Q_GRAN, BIAS, RELU) \ |
1524 | INSTANTIATE_REQUANTIZE(false, B_SYM, Q_GRAN, BIAS, RELU) |
1525 | |
1526 | #define INSTANTIATE_B_SYM(Q_GRAN, BIAS, RELU) \ |
1527 | INSTANTIATE_A_SYM(true, Q_GRAN, BIAS, RELU) \ |
1528 | INSTANTIATE_A_SYM(false, Q_GRAN, BIAS, RELU) |
1529 | |
1530 | #define INSTANTIATE_Q_GRANS(BIAS, RELU) \ |
1531 | INSTANTIATE_B_SYM(QuantizationGranularity::TENSOR, BIAS, RELU) \ |
1532 | INSTANTIATE_B_SYM(QuantizationGranularity::GROUP, BIAS, RELU) \ |
1533 | INSTANTIATE_B_SYM(QuantizationGranularity::OUT_CHANNEL, BIAS, RELU) |
1534 | |
1535 | #define INSTANTIATE_BIAS(RELU) \ |
1536 | INSTANTIATE_Q_GRANS(true, RELU) \ |
1537 | INSTANTIATE_Q_GRANS(false, RELU) |
1538 | |
1539 | INSTANTIATE_BIAS(true) |
1540 | INSTANTIATE_BIAS(false) |
1541 | |
1542 | #undef INSTANTIATE_A_SYM |
1543 | #undef INSTANTIATE_B_SYM |
1544 | #undef INSTANTIATE_Q_GRANS |
1545 | #undef INSTANTIATE_BIAS |
1546 | |
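// Scalar fp32 <-> fp16 conversion helpers. They are used below to convert
// fp16 inputs in the scalar tail loops and to round-trip the per-row fp16
// scale/bias of the fused N-bit rowwise format.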
1547 | static inline uint16_t floatToHalf(float val) { |
1548 | #ifdef _MSC_VER |
1549 | // Use _mm256_cvtps_ph/_mm256_cvtph_ps because _cvtsh_ss/_cvtss_sh don't |
1550 | // exist in MSVC. |
1551 | __m256 val_v = _mm256_set1_ps(val); |
1552 | __m128i val_half_v = |
1553 | _mm256_cvtps_ph(val_v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); |
1554 | return static_cast<std::uint16_t>(_mm_cvtsi128_si32(val_half_v)); |
1555 | #else |
1556 | return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); |
1557 | #endif |
1558 | } |
1559 | static inline float halfToFloat(uint16_t val) { |
1560 | #ifdef _MSC_VER |
1561 | return _mm256_cvtss_f32(_mm256_cvtph_ps(_mm_cvtsi32_si128(val))); |
1562 | #else |
1563 | return _cvtsh_ss(val); |
1564 | #endif |
1565 | } |
1566 | |
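// Quantizes each row of `input` to BIT_RATE-bit values, with a per-row fp16
// scale and fp16 bias fused at the end of the row:
//
//   [ceil(input_columns * BIT_RATE / 8) bytes of packed values]
//   [fp16 scale][fp16 bias]
//
// Per-element sketch (this is what the scalar tail loop below computes):
//   q = clamp(lrintf((x - row_min) / scale), 0, (1 << BIT_RATE) - 1)
// where scale = (row_max - row_min) / ((1 << BIT_RATE) - 1) rounded through
// fp16, and row_min is the fp16-rounded bias.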
1567 | template <typename InputType, int BIT_RATE> |
1568 | void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2( |
1569 | const InputType* input, |
1570 | size_t input_rows, |
1571 | int input_columns, |
1572 | std::uint8_t* output) { |
1573 | static_assert( |
1574 | std::is_same<InputType, float>() || std::is_same<InputType, float16>(), |
1575 | "Only float and float16 types are allowed." ); |
1576 | constexpr int VLEN = 8; |
1577 | constexpr int NUM_ELEM_PER_BYTE = 8 / BIT_RATE; |
1578 | int output_columns = |
1579 | (input_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE + |
1580 | 2 * sizeof(std::uint16_t); |
1581 | |
  float* input_row_float_for_fp16 = nullptr;
1583 | if (std::is_same<InputType, float16>()) { |
1584 | input_row_float_for_fp16 = static_cast<float*>( |
1585 | fbgemmAlignedAlloc(64, input_columns * sizeof(float))); |
1586 | } |
1587 | |
1588 | for (size_t row = 0; row < input_rows; ++row) { |
1589 | const InputType* input_row = input + row * input_columns; |
1590 | const float* input_row_float; |
1591 | if (std::is_same<InputType, float>()) { |
      // NOTE: this reinterpret_cast only works around C++ type requirements:
      // this branch is taken only when InputType is float, so `input_row`
      // already points to float data. Replace it with `if constexpr` once
      // C++17 is allowed.
1595 | input_row_float = reinterpret_cast<const float*>(input_row); |
1596 | } else { |
1597 | input_row_float = input_row_float_for_fp16; |
1598 | } |
1599 | |
1600 | std::uint8_t* output_row = output + row * output_columns; |
1601 | std::uint16_t* output_row_scale_bias = reinterpret_cast<std::uint16_t*>( |
1602 | output_row + |
1603 | (input_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE); |
1604 | |
1605 | float minimum_element = FLT_MAX; |
1606 | float maximum_element = -FLT_MAX; |
1607 | __m256 min_v = _mm256_set1_ps(minimum_element); |
1608 | __m256 max_v = _mm256_set1_ps(maximum_element); |
1609 | |
1610 | int col; |
1611 | for (col = 0; col < input_columns / VLEN * VLEN; col += VLEN) { |
1612 | __m256 in_v; |
1613 | if (std::is_same<InputType, float>()) { |
1614 | in_v = _mm256_loadu_ps(input_row_float + col); |
1615 | } else { |
1616 | __m128i in_half_v = |
1617 | _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_row + col)); |
1618 | in_v = _mm256_cvtph_ps(in_half_v); |
1619 | _mm256_store_ps(input_row_float_for_fp16 + col, in_v); |
1620 | } |
1621 | |
1622 | min_v = _mm256_min_ps(min_v, in_v); |
1623 | max_v = _mm256_max_ps(max_v, in_v); |
1624 | } |
1625 | alignas(64) float min_buf[VLEN], max_buf[VLEN]; |
1626 | _mm256_store_ps(min_buf, min_v); |
1627 | _mm256_store_ps(max_buf, max_v); |
1628 | for (int i = 0; i < VLEN; ++i) { |
1629 | minimum_element = std::min(minimum_element, min_buf[i]); |
1630 | maximum_element = std::max(maximum_element, max_buf[i]); |
1631 | } |
1632 | |
1633 | for (; col < input_columns; ++col) { |
1634 | if (std::is_same<InputType, float>()) { |
1635 | minimum_element = std::min(minimum_element, input_row_float[col]); |
1636 | maximum_element = std::max(maximum_element, input_row_float[col]); |
1637 | } else { |
1638 | float element = halfToFloat(input_row[col]); |
1639 | input_row_float_for_fp16[col] = element; |
1640 | minimum_element = std::min(minimum_element, element); |
1641 | maximum_element = std::max(maximum_element, element); |
1642 | } |
1643 | } |
1644 | |
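    // Round the bias (row minimum) and the scale through fp16 before using
    // them, so that quantization below uses exactly the values that
    // dequantization will read back.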
1645 | output_row_scale_bias[1] = floatToHalf(minimum_element); |
1646 | minimum_element = halfToFloat(output_row_scale_bias[1]); |
1647 | const float range = maximum_element - minimum_element; |
1648 | |
1649 | float scale = range == 0 ? 1.0f : range / ((1 << BIT_RATE) - 1); |
1650 | std::uint16_t scale_fp16 = floatToHalf(scale); |
1651 | scale = halfToFloat(scale_fp16); |
1652 | if (scale == 0) { |
      // Corner-case handling when maximum_element == minimum_element:
      // any scale works because X - minimum_element is 0 for every X in the
      // row.
1656 | scale = 1.0f; |
1657 | } |
1658 | float inverse_scale = 1.0f / scale; |
1659 | if (std::isinf(inverse_scale)) { |
1660 | scale = 1.0f; |
1661 | inverse_scale = 1.0f; |
1662 | } |
1663 | |
1664 | output_row_scale_bias[0] = floatToHalf(scale); |
1665 | |
1666 | col = 0; |
1667 | if (BIT_RATE == 2 || BIT_RATE == 4) { |
1668 | __m256i permute_mask1_v = |
1669 | _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); |
1670 | __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale); |
1671 | min_v = _mm256_set1_ps(minimum_element); |
1672 | |
1673 | for (; col + 4 * VLEN <= input_columns; col += 4 * VLEN) { |
1674 | __m256i x_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( |
1675 | _mm256_sub_ps(_mm256_loadu_ps(input_row_float + col), min_v), |
1676 | inverse_scale_v)); |
1677 | __m256i y_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( |
1678 | _mm256_sub_ps(_mm256_loadu_ps(input_row_float + col + VLEN), min_v), |
1679 | inverse_scale_v)); |
1680 | __m256i z_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( |
1681 | _mm256_sub_ps( |
1682 | _mm256_loadu_ps(input_row_float + col + 2 * VLEN), min_v), |
1683 | inverse_scale_v)); |
1684 | __m256i w_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( |
1685 | _mm256_sub_ps( |
1686 | _mm256_loadu_ps(input_row_float + col + 3 * VLEN), min_v), |
1687 | inverse_scale_v)); |
1688 | |
1689 | // An instruction sequence to save 32 32-bit integers as 8-bit integers |
1690 | __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v); |
1691 | __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v); |
1692 | __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v); |
1693 | xyzw_packed_v = |
1694 | _mm256_permutevar8x32_epi32(xyzw_packed_v, permute_mask1_v); |
1695 | |
1696 | // saturate to BIT_RATE |
1697 | xyzw_packed_v = _mm256_min_epu8( |
1698 | xyzw_packed_v, |
1699 | _mm256_set1_epi8(static_cast<char>((1 << BIT_RATE) - 1))); |
1700 | |
1701 | if (BIT_RATE == 4) { |
1702 | // pack into lower 8-bit of each 16-bit |
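          // e.g., 4-bit values a (low byte) and b (high byte) of each 16-bit
          // lane become (b << 4) | a in the lane's low byte.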
1703 | xyzw_packed_v = _mm256_and_si256( |
1704 | _mm256_or_si256( |
1705 | xyzw_packed_v, _mm256_srli_epi16(xyzw_packed_v, 4)), |
1706 | _mm256_set1_epi16(0x00ff)); |
1707 | } else { |
1708 | // pack into lower 8-bit of each 32-bit |
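          // e.g., 2-bit values a, b, c, d in the four bytes of each 32-bit
          // lane become (d << 6) | (c << 4) | (b << 2) | a in the low byte.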
1709 | xyzw_packed_v = _mm256_and_si256( |
1710 | _mm256_or_si256( |
1711 | _mm256_or_si256( |
1712 | xyzw_packed_v, _mm256_srli_epi32(xyzw_packed_v, 6)), |
1713 | _mm256_or_si256( |
1714 | _mm256_srli_epi32(xyzw_packed_v, 8 + 4), |
1715 | _mm256_srli_epi32(xyzw_packed_v, 2 * 8 + 2))), |
1716 | _mm256_set1_epi32(0x00ff)); |
1717 | } |
1718 | |
1719 | __m128i out_v; |
1720 | if (BIT_RATE == 4) { |
1721 | // avx2 doesn't have _mm256_cvtepi16_epi8 |
1722 | out_v = _mm_packus_epi16( |
1723 | _mm256_castsi256_si128(xyzw_packed_v), |
1724 | _mm256_extractf128_si256(xyzw_packed_v, 1)); |
1725 | _mm_storeu_si128( |
1726 | reinterpret_cast<__m128i*>(output_row + col / NUM_ELEM_PER_BYTE), |
1727 | out_v); |
1728 | } else { |
1729 | // avx2 doesn't have _mm256_cvtepi32_epi8 |
1730 | out_v = _mm_packus_epi32( |
1731 | _mm256_castsi256_si128(xyzw_packed_v), |
1732 | _mm256_extractf128_si256(xyzw_packed_v, 1)); |
1733 | out_v = _mm_packus_epi16(out_v, out_v); |
1734 | _mm_storel_epi64( |
1735 | reinterpret_cast<__m128i*>(output_row + col / NUM_ELEM_PER_BYTE), |
1736 | out_v); |
1737 | } |
1738 | } |
1739 | } |
1740 | |
1741 | for (; col < input_columns; ++col) { |
1742 | float X = input_row_float[col]; |
1743 | std::uint8_t quantized = std::max( |
1744 | 0, |
1745 | std::min<int>( |
1746 | std::lrintf((X - minimum_element) * inverse_scale), |
1747 | (1 << BIT_RATE) - 1)); |
1748 | if (col % NUM_ELEM_PER_BYTE == 0) { |
1749 | output_row[col / NUM_ELEM_PER_BYTE] = quantized; |
1750 | } else { |
1751 | output_row[col / NUM_ELEM_PER_BYTE] |= |
1752 | (quantized << ((col % NUM_ELEM_PER_BYTE) * BIT_RATE)); |
1753 | } |
1754 | } |
1755 | } // for each row |
1756 | |
1757 | if (std::is_same<InputType, float16>()) { |
1758 | fbgemmAlignedFree(input_row_float_for_fp16); |
1759 | } |
1760 | } |
1761 | |
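// Quantizes each row of `input` to 8-bit values, with a per-row float scale
// and float bias fused at the end of the row:
//
//   [input_columns bytes of quantized values][float scale][float bias]
//
// Per-element sketch (this is what the scalar tail loop below computes):
//   q = lrintf((x - row_min) * 255 / (row_max - row_min + kEpsilon))
// The stored scale is (row_max - row_min) / 255 and the stored bias is
// row_min.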
1762 | template <typename InputType> |
1763 | void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2( |
1764 | const InputType* input, |
1765 | size_t input_rows, |
1766 | int input_columns, |
1767 | std::uint8_t* output) { |
1768 | constexpr int VLEN = 8; |
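  // kEpsilon keeps the inverse scale finite when every element of a row is
  // identical (range == 0).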
1769 | constexpr float kEpsilon = 1e-8f; |
1770 | |
1771 | __m256i permute_mask1_v = |
1772 | _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); |
1773 | // clang-format off |
1774 | __m256i shuffle_mask_v = _mm256_set_epi8( |
1775 | 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, |
1776 | 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00, |
1777 | 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, |
1778 | 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00); |
1779 | // clang-format on |
1780 | |
1781 | __m256i permute_mask2_v = |
1782 | _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); |
1783 | |
1784 | int output_columns = input_columns + 2 * sizeof(float); |
  float* input_row_float_for_fp16 = nullptr;
1786 | if (std::is_same<InputType, float16>()) { |
1787 | input_row_float_for_fp16 = static_cast<float*>( |
1788 | fbgemmAlignedAlloc(64, input_columns * sizeof(float))); |
1789 | } |
1790 | for (size_t row = 0; row < input_rows; ++row) { |
1791 | const InputType* input_row = input + row * input_columns; |
1792 | const float* input_row_float; |
1793 | if (std::is_same<InputType, float>()) { |
      // NOTE: this reinterpret_cast only works around C++ type requirements:
      // this branch is taken only when InputType is float, so `input_row`
      // already points to float data. Replace it with `if constexpr` once
      // C++17 is allowed.
1797 | input_row_float = reinterpret_cast<const float*>(input_row); |
1798 | } else { |
1799 | input_row_float = input_row_float_for_fp16; |
1800 | } |
1801 | std::uint8_t* output_row = output + row * output_columns; |
1802 | float* output_row_scale_bias = |
1803 | reinterpret_cast<float*>(output_row + input_columns); |
1804 | |
1805 | float minimum_element = FLT_MAX; |
1806 | float maximum_element = -FLT_MAX; |
1807 | __m256 min_v = _mm256_set1_ps(minimum_element); |
1808 | __m256 max_v = _mm256_set1_ps(maximum_element); |
1809 | int col; |
1810 | for (col = 0; col < input_columns / VLEN * VLEN; col += VLEN) { |
1811 | __m256 in_v; |
1812 | if (std::is_same<InputType, float>()) { |
1813 | in_v = _mm256_loadu_ps(input_row_float + col); |
1814 | } else { |
1815 | __m128i in_half_v = |
1816 | _mm_loadu_si128(reinterpret_cast<const __m128i*>(input_row + col)); |
1817 | in_v = _mm256_cvtph_ps(in_half_v); |
1818 | _mm256_store_ps(input_row_float_for_fp16 + col, in_v); |
1819 | } |
1820 | min_v = _mm256_min_ps(min_v, in_v); |
1821 | max_v = _mm256_max_ps(max_v, in_v); |
1822 | } |
1823 | alignas(64) float min_buf[VLEN], max_buf[VLEN]; |
1824 | _mm256_store_ps(min_buf, min_v); |
1825 | _mm256_store_ps(max_buf, max_v); |
1826 | for (int i = 0; i < VLEN; ++i) { |
1827 | minimum_element = std::min(minimum_element, min_buf[i]); |
1828 | maximum_element = std::max(maximum_element, max_buf[i]); |
1829 | } |
1830 | |
1831 | for (; col < input_columns; ++col) { |
1832 | if (std::is_same<InputType, float>()) { |
1833 | minimum_element = std::min(minimum_element, input_row_float[col]); |
1834 | maximum_element = std::max(maximum_element, input_row_float[col]); |
1835 | } else { |
1836 | float element = halfToFloat(input_row[col]); |
1837 | input_row_float_for_fp16[col] = element; |
1838 | minimum_element = std::min(minimum_element, element); |
1839 | maximum_element = std::max(maximum_element, element); |
1840 | } |
1841 | } |
1842 | |
1843 | float range = maximum_element - minimum_element; |
1844 | output_row_scale_bias[0] = range / 255.0f; |
1845 | output_row_scale_bias[1] = minimum_element; |
1846 | const auto inverse_scale = 255.0f / (range + kEpsilon); |
1847 | min_v = _mm256_set1_ps(minimum_element); |
1848 | __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale); |
1849 | |
1850 | for (col = 0; col < input_columns / (4 * VLEN) * (4 * VLEN); |
1851 | col += 4 * VLEN) { |
1852 | __m256i x_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( |
1853 | _mm256_sub_ps(_mm256_loadu_ps(input_row_float + col), min_v), |
1854 | inverse_scale_v)); |
1855 | __m256i y_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( |
1856 | _mm256_sub_ps(_mm256_loadu_ps(input_row_float + col + VLEN), min_v), |
1857 | inverse_scale_v)); |
1858 | __m256i z_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( |
1859 | _mm256_sub_ps( |
1860 | _mm256_loadu_ps(input_row_float + col + 2 * VLEN), min_v), |
1861 | inverse_scale_v)); |
1862 | __m256i w_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( |
1863 | _mm256_sub_ps( |
1864 | _mm256_loadu_ps(input_row_float + col + 3 * VLEN), min_v), |
1865 | inverse_scale_v)); |
1866 | |
1867 | // An instruction sequence to save 32 32-bit integers as 8-bit integers |
1868 | __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v); |
1869 | __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v); |
1870 | __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v); |
1871 | xyzw_packed_v = |
1872 | _mm256_permutevar8x32_epi32(xyzw_packed_v, permute_mask1_v); |
1873 | _mm256_storeu_si256( |
1874 | reinterpret_cast<__m256i*>(output_row + col), xyzw_packed_v); |
1875 | } |
1876 | for (; col < input_columns / VLEN * VLEN; col += VLEN) { |
1877 | __m256i rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( |
1878 | _mm256_sub_ps(_mm256_loadu_ps(input_row_float + col), min_v), |
1879 | inverse_scale_v)); |
1880 | |
1881 | // An instruction sequence to save 8 32-bit integers as 8-bit integers |
1882 | rounded_v = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v); |
1883 | rounded_v = _mm256_permutevar8x32_epi32(rounded_v, permute_mask2_v); |
1884 | _mm_storel_epi64( |
1885 | reinterpret_cast<__m128i*>(output_row + col), |
1886 | _mm256_castsi256_si128(rounded_v)); |
1887 | } |
1888 | for (; col < input_columns; ++col) { |
1889 | output_row[col] = |
1890 | std::lrintf((input_row_float[col] - minimum_element) * inverse_scale); |
1891 | } |
1892 | } // for each row |
1893 | if (std::is_same<InputType, float16>()) { |
1894 | fbgemmAlignedFree(input_row_float_for_fp16); |
1895 | } |
1896 | } |
1897 | |
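// Inverse of FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2: each input row
// consists of packed BIT_RATE-bit values followed by an fp16 scale and fp16
// bias, and every element is dequantized as
//
//   x = scale * q + bias
//
// (see the scalar fallback at the end of the row loop).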
1898 | template <typename OutputType, int BIT_RATE> |
1899 | void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2( |
1900 | const std::uint8_t* input, |
1901 | size_t input_rows, |
1902 | int input_columns, |
1903 | OutputType* output) { |
1904 | static_assert( |
1905 | std::is_same<OutputType, float>() || std::is_same<OutputType, float16>(), |
1906 | "Only float and float16 types are allowed." ); |
1907 | constexpr int VLEN = 8; |
1908 | constexpr int NUM_ELEM_PER_BYTE = 8 / BIT_RATE; |
1909 | int output_columns = |
1910 | (input_columns - 2 * sizeof(uint16_t)) * NUM_ELEM_PER_BYTE; |
1911 | |
  // Compute a remainder for the vector load.
  // Since every row is followed by 2 fp16 values (scale and bias), we
  // conveniently don't need a mask at bit-rate granularity, only at 32-bit
  // granularity.
1916 | constexpr int NUM_ELEM_PER_32BIT = 32 / BIT_RATE; |
  // multiply by 4 because each iteration handles 4 VLEN-wide vectors
1918 | constexpr int NUM_OF_32BIT_PER_VLOAD = VLEN * 4 / NUM_ELEM_PER_32BIT; |
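  // e.g., BIT_RATE == 4 gives NUM_ELEM_PER_32BIT == 8 and
  // NUM_OF_32BIT_PER_VLOAD == 4; BIT_RATE == 2 gives 16 and 2.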
1919 | |
1920 | int remainder_32bit_granularity, remainder; |
1921 | __m128i vmask_load; |
1922 | __m256i vmask_store0, vmask_store1, vmask_store2, vmask_store3; |
1923 | if (BIT_RATE == 4 || BIT_RATE == 2) { |
1924 | remainder_32bit_granularity = (output_columns + NUM_ELEM_PER_32BIT - 1) / |
1925 | NUM_ELEM_PER_32BIT % NUM_OF_32BIT_PER_VLOAD; |
1926 | vmask_load = _mm_lddqu_si128(reinterpret_cast<const __m128i*>( |
1927 | internal::avx2_ps_or_epi32_combined_mask + NUM_OF_32BIT_PER_VLOAD + |
1928 | (NUM_OF_32BIT_PER_VLOAD - remainder_32bit_granularity) % |
1929 | NUM_OF_32BIT_PER_VLOAD)); |
1930 | remainder = output_columns % (4 * VLEN); |
1931 | int remainder_ratio = 1; |
1932 | if (std::is_same<OutputType, float16>()) { |
      // For fp16 we only need half of the mask.
      //
      // For instance, if the remainder is 2, for FP32 the masks are
      // {-1, -1, 0, ..., 0}, {0, ..., 0}, {0, ..., 0}, {0, ..., 0}
      // (8 32-bit integers for each mask)
      // while for FP16 we only need
      // {-1, 0, 0, 0}, {0, ..., 0}, {0, ..., 0}, {0, ..., 0}
      // (4 32-bit integers for each mask)
      // since we reinterpret 2 FP16 numbers as one 32-bit number.
      // NOTE: for bit_rate 4 or 2, remainders are always a multiple of 2 or 4,
      // so we don't have to worry about an odd number of FP16 values.
      //
      // Or, if the remainder is 30, for FP32 the masks are
      // {-1, ..., -1}, {-1, ..., -1}, {-1, ..., -1}, {-1, .., -1, 0, 0}
      // while for FP16 we only need
      // {-1, ..., -1}, {-1, ..., -1}, {-1, ..., -1}, {-1, -1, -1, 0}
1949 | remainder_ratio = 2; |
1950 | } |
1951 | vmask_store0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
1952 | internal::avx2_ps_or_epi32_combined_mask + |
1953 | (VLEN - std::min(remainder, VLEN) / remainder_ratio % (VLEN + 1)))); |
1954 | vmask_store1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
1955 | internal::avx2_ps_or_epi32_combined_mask + |
1956 | (VLEN - |
1957 | std::max(0, std::min(remainder - VLEN, VLEN) / remainder_ratio) % |
1958 | (VLEN + 1)))); |
1959 | vmask_store2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
1960 | internal::avx2_ps_or_epi32_combined_mask + |
1961 | (VLEN - |
1962 | std::max(0, std::min(remainder - 2 * VLEN, VLEN) / remainder_ratio) % |
1963 | (VLEN + 1)))); |
1964 | vmask_store3 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>( |
1965 | internal::avx2_ps_or_epi32_combined_mask + |
1966 | (VLEN - |
1967 | std::max(0, std::min(remainder - 3 * VLEN, VLEN) / remainder_ratio) % |
1968 | (VLEN + 1)))); |
1969 | } |
1970 | |
1971 | for (size_t row = 0; row < input_rows; ++row) { |
1972 | const std::uint8_t* input_row = input + row * input_columns; |
1973 | const uint16_t* input_row_scale_bias = reinterpret_cast<const uint16_t*>( |
1974 | input_row + |
1975 | (output_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE); |
1976 | float scale = halfToFloat(input_row_scale_bias[0]); |
1977 | float bias = halfToFloat(input_row_scale_bias[1]); |
1978 | OutputType* output_row = output + row * output_columns; |
1979 | float* output_row_float; |
1980 | if (std::is_same<OutputType, float>()) { |
      // NOTE: this reinterpret_cast only works around C++ type requirements:
      // this branch is taken only when OutputType is float, so `output_row`
      // already points to float data. Replace it with `if constexpr` once
      // C++17 is allowed.
1984 | output_row_float = reinterpret_cast<float*>(output_row); |
1985 | } |
1986 | |
1987 | int col = 0; |
1988 | if (BIT_RATE == 4 || BIT_RATE == 2) { |
1989 | __m256 vscale = _mm256_set1_ps(scale); |
1990 | __m256 vbias = _mm256_set1_ps(bias); |
1991 | for (; col + 4 * VLEN <= output_columns; col += 4 * VLEN) { |
1992 | __m256i vinq; |
1993 | // unpack to 8-bit integers |
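        // For BIT_RATE == 4, each byte (hi << 4) | lo expands to a 16-bit lane
        // with lo in the low byte and hi in the high byte; for BIT_RATE == 2,
        // each byte expands so that each 2-bit value lands in its own byte of
        // a 32-bit lane.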
1994 | if (BIT_RATE == 4) { |
1995 | vinq = _mm256_cvtepu8_epi16( |
1996 | _mm_loadu_si128(reinterpret_cast<const __m128i*>( |
1997 | input_row + col / NUM_ELEM_PER_BYTE))); |
1998 | vinq = _mm256_and_si256( |
1999 | _mm256_or_si256(vinq, _mm256_slli_epi32(vinq, 4)), |
2000 | _mm256_set1_epi16(0x0f0f)); |
2001 | } else { |
2002 | vinq = _mm256_cvtepu8_epi32( |
2003 | _mm_loadl_epi64(reinterpret_cast<const __m128i*>( |
2004 | input_row + col / NUM_ELEM_PER_BYTE))); |
2005 | vinq = _mm256_and_si256( |
2006 | _mm256_or_si256( |
2007 | _mm256_or_si256( |
2008 | _mm256_slli_epi32(vinq, 2 * 8 + 2), |
2009 | _mm256_slli_epi32(vinq, 8 + 4)), |
2010 | _mm256_or_si256(_mm256_slli_epi32(vinq, 6), vinq)), |
2011 | _mm256_set1_epi32(0x03030303)); |
2012 | } |
2013 | __m256 vinq0 = _mm256_cvtepi32_ps( |
2014 | _mm256_cvtepi8_epi32(_mm256_castsi256_si128(vinq))); |
2015 | __m256 vinq1 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32( |
2016 | _mm_set1_epi64x(_mm256_extract_epi64(vinq, 1)))); |
2017 | __m256 vinq2 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32( |
2018 | _mm_set1_epi64x(_mm256_extract_epi64(vinq, 2)))); |
2019 | __m256 vinq3 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32( |
2020 | _mm_set1_epi64x(_mm256_extract_epi64(vinq, 3)))); |
2021 | vinq0 = _mm256_fmadd_ps(vscale, vinq0, vbias); |
2022 | vinq1 = _mm256_fmadd_ps(vscale, vinq1, vbias); |
2023 | vinq2 = _mm256_fmadd_ps(vscale, vinq2, vbias); |
2024 | vinq3 = _mm256_fmadd_ps(vscale, vinq3, vbias); |
2025 | |
2026 | if (std::is_same<OutputType, float>()) { |
2027 | _mm256_storeu_ps(output_row_float + col, vinq0); |
2028 | _mm256_storeu_ps(output_row_float + col + VLEN, vinq1); |
2029 | _mm256_storeu_ps(output_row_float + col + 2 * VLEN, vinq2); |
2030 | _mm256_storeu_ps(output_row_float + col + 3 * VLEN, vinq3); |
2031 | } else { |
2032 | _mm_storeu_si128( |
2033 | reinterpret_cast<__m128i*>(output_row + col), |
2034 | _mm256_cvtps_ph( |
2035 | vinq0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); |
2036 | _mm_storeu_si128( |
2037 | reinterpret_cast<__m128i*>(output_row + col + VLEN), |
2038 | _mm256_cvtps_ph( |
2039 | vinq1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); |
2040 | _mm_storeu_si128( |
2041 | reinterpret_cast<__m128i*>(output_row + col + 2 * VLEN), |
2042 | _mm256_cvtps_ph( |
2043 | vinq2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); |
2044 | _mm_storeu_si128( |
2045 | reinterpret_cast<__m128i*>(output_row + col + 3 * VLEN), |
2046 | _mm256_cvtps_ph( |
2047 | vinq3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); |
2048 | } |
2049 | } |
2050 | |
2051 | if (remainder) { |
2052 | __m256i vinq; |
2053 | if (BIT_RATE == 4) { |
2054 | vinq = _mm256_cvtepu8_epi16(_mm_maskload_epi32( |
2055 | reinterpret_cast<const int*>(input_row + col / NUM_ELEM_PER_BYTE), |
2056 | vmask_load)); |
2057 | vinq = _mm256_and_si256( |
2058 | _mm256_or_si256(vinq, _mm256_slli_epi32(vinq, 4)), |
2059 | _mm256_set1_epi16(0x0f0f)); |
2060 | } else { |
2061 | vinq = _mm256_cvtepu8_epi32(_mm_maskload_epi32( |
2062 | reinterpret_cast<const int*>(input_row + col / NUM_ELEM_PER_BYTE), |
2063 | vmask_load)); |
2064 | vinq = _mm256_and_si256( |
2065 | _mm256_or_si256( |
2066 | _mm256_or_si256( |
2067 | _mm256_slli_epi32(vinq, 2 * 8 + 2), |
2068 | _mm256_slli_epi32(vinq, 8 + 4)), |
2069 | _mm256_or_si256(_mm256_slli_epi32(vinq, 6), vinq)), |
2070 | _mm256_set1_epi32(0x03030303)); |
2071 | } |
2072 | |
2073 | __m256 vinq0 = _mm256_cvtepi32_ps( |
2074 | _mm256_cvtepi8_epi32(_mm256_castsi256_si128(vinq))); |
2075 | __m256 vinq1 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32( |
2076 | _mm_set1_epi64x(_mm256_extract_epi64(vinq, 1)))); |
2077 | __m256 vinq2 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32( |
2078 | _mm_set1_epi64x(_mm256_extract_epi64(vinq, 2)))); |
2079 | __m256 vinq3 = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32( |
2080 | _mm_set1_epi64x(_mm256_extract_epi64(vinq, 3)))); |
2081 | |
2082 | vinq0 = _mm256_fmadd_ps(vscale, vinq0, vbias); |
2083 | vinq1 = _mm256_fmadd_ps(vscale, vinq1, vbias); |
2084 | vinq2 = _mm256_fmadd_ps(vscale, vinq2, vbias); |
2085 | vinq3 = _mm256_fmadd_ps(vscale, vinq3, vbias); |
2086 | |
2087 | if (std::is_same<OutputType, float>()) { |
2088 | _mm256_maskstore_ps(output_row_float + col, vmask_store0, vinq0); |
2089 | _mm256_maskstore_ps( |
2090 | output_row_float + col + VLEN, vmask_store1, vinq1); |
2091 | _mm256_maskstore_ps( |
2092 | output_row_float + col + 2 * VLEN, vmask_store2, vinq2); |
2093 | _mm256_maskstore_ps( |
2094 | output_row_float + col + 3 * VLEN, vmask_store3, vinq3); |
2095 | } else { |
2096 | _mm_maskstore_epi32( |
2097 | reinterpret_cast<int*>(output_row + col), |
2098 | _mm256_castsi256_si128(vmask_store0), |
2099 | _mm256_cvtps_ph( |
2100 | vinq0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); |
2101 | _mm_maskstore_epi32( |
2102 | reinterpret_cast<int*>(output_row + col + VLEN), |
2103 | _mm256_castsi256_si128(vmask_store1), |
2104 | _mm256_cvtps_ph( |
2105 | vinq1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); |
2106 | _mm_maskstore_epi32( |
2107 | reinterpret_cast<int*>(output_row + col + 2 * VLEN), |
2108 | _mm256_castsi256_si128(vmask_store2), |
2109 | _mm256_cvtps_ph( |
2110 | vinq2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); |
2111 | _mm_maskstore_epi32( |
2112 | reinterpret_cast<int*>(output_row + col + 3 * VLEN), |
2113 | _mm256_castsi256_si128(vmask_store3), |
2114 | _mm256_cvtps_ph( |
2115 | vinq3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); |
2116 | } |
2117 | } |
2118 | } else { |
2119 | for (; col < output_columns; ++col) { |
2120 | std::uint8_t quantized = input_row[col / NUM_ELEM_PER_BYTE]; |
2121 | quantized >>= (col % NUM_ELEM_PER_BYTE) * BIT_RATE; |
2122 | quantized &= (1 << BIT_RATE) - 1; |
2123 | float output_value = scale * quantized + bias; |
2124 | if (std::is_same<OutputType, float>()) { |
2125 | output_row[col] = output_value; |
2126 | } else { |
2127 | output_row[col] = cpu_float2half_rn(output_value); |
2128 | } |
2129 | } |
2130 | } |
2131 | } // for each row |
2132 | } |
2133 | |
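// Inverse of FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2: each input row
// consists of output_columns uint8 values followed by a float scale and float
// bias, and every element is dequantized as x = q * scale + bias (see the
// scalar tail loop below).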
2134 | template <typename OutputType> |
2135 | void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2( |
2136 | const std::uint8_t* input, |
2137 | size_t input_rows, |
2138 | int input_columns, |
2139 | OutputType* output) { |
2140 | constexpr int VLEN = 8; |
2141 | int output_columns = input_columns - 2 * sizeof(float); |
2142 | |
2143 | for (size_t row = 0; row < input_rows; ++row) { |
2144 | const std::uint8_t* input_row = input + row * input_columns; |
2145 | const float* input_row_scale_bias = |
2146 | reinterpret_cast<const float*>(input_row + output_columns); |
2147 | OutputType* output_row = output + row * output_columns; |
2148 | |
2149 | __m256 scale_v = _mm256_set1_ps(input_row_scale_bias[0]); |
2150 | __m256 bias_v = _mm256_set1_ps(input_row_scale_bias[1]); |
2151 | |
2152 | int col; |
2153 | for (col = 0; col < output_columns / VLEN * VLEN; col += VLEN) { |
2154 | __m256 in_v = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( |
2155 | _mm_loadl_epi64(reinterpret_cast<const __m128i*>(input_row + col)))); |
2156 | #ifdef __FMA__ |
      __m256 dequantized_v = _mm256_fmadd_ps(in_v, scale_v, bias_v);
#else
      __m256 dequantized_v =
          _mm256_add_ps(_mm256_mul_ps(in_v, scale_v), bias_v);
#endif
      if (std::is_same<OutputType, float>()) {
        float* output_row_float = reinterpret_cast<float*>(output_row);
        _mm256_storeu_ps(output_row_float + col, dequantized_v);
      } else {
        _mm_storeu_si128(
            reinterpret_cast<__m128i*>(output_row + col),
            _mm256_cvtps_ph(
                dequantized_v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
2169 | } |
2170 | } |
2171 | |
2172 | for (; col < output_columns; ++col) { |
2173 | float output_value = |
2174 | input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1]; |
2175 | if (std::is_same<OutputType, float>()) { |
2176 | output_row[col] = output_value; |
2177 | } else { |
2178 | output_row[col] = cpu_float2half_rn(output_value); |
2179 | } |
2180 | } |
2181 | } // for each row |
2182 | } |
2183 | |
2184 | #define INSTANTIATE_QuantizationAvx2FunctionsNBits(type, bit_rate) \ |
2185 | template void \ |
2186 | FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2<type, bit_rate>( \ |
2187 | const type* input, \ |
2188 | size_t input_rows, \ |
2189 | int input_columns, \ |
2190 | std::uint8_t* output); \ |
2191 | template void \ |
2192 | FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2<type, bit_rate>( \ |
2193 | const std::uint8_t* input, \ |
2194 | size_t input_rows, \ |
2195 | int input_columns, \ |
2196 | type* output); |
2197 | |
2198 | // clang-format off |
2199 | INSTANTIATE_QuantizationAvx2FunctionsNBits(float, 2) |
2200 | INSTANTIATE_QuantizationAvx2FunctionsNBits(float, 4) |
2201 | INSTANTIATE_QuantizationAvx2FunctionsNBits(float, 8) |
2202 | INSTANTIATE_QuantizationAvx2FunctionsNBits(float16, 2) |
2203 | INSTANTIATE_QuantizationAvx2FunctionsNBits(float16, 4) |
2204 | INSTANTIATE_QuantizationAvx2FunctionsNBits(float16, 8) |
2205 | // clang-format on |
2206 | #undef INSTANTIATE_QuantizationAvx2FunctionsNBits |
2207 | |
2208 | #define INSTANTIATE_QuantizationAvx2Functions8Bits(type) \ |
2209 | template void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2<type>( \ |
2210 | const type* input, \ |
2211 | size_t input_rows, \ |
2212 | int input_columns, \ |
2213 | std::uint8_t* output); \ |
2214 | template void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2<type>( \ |
2215 | const std::uint8_t* input, \ |
2216 | size_t input_rows, \ |
2217 | int input_columns, \ |
2218 | type* output); |
2219 | |
2220 | // clang-format off |
2221 | INSTANTIATE_QuantizationAvx2Functions8Bits(float) |
2222 | INSTANTIATE_QuantizationAvx2Functions8Bits(float16) |
2223 | // clang-format on |
2224 | #undef INSTANTIATE_QuantizationAvx2Functions8Bits |
2225 | |
2226 | } // namespace fbgemm |
2227 | |