/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include "fbgemm/spmmUtilsAvx2.h"
#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif
#include <cassert> // for assert
#include "./MaskAvx2.h"

namespace fbgemm {

template <
    bool FUSE_RELU,
    bool ACT_SYMMETRIC,
    bool WEIGHT_SYMMETRIC,
    bool HAS_BIAS,
    QuantizationGranularity Q_GRAN>
FBGEMM_API void trRequantizeOpt(
    uint8_t* out,
    const int32_t* inp,
    const block_type_t& block,
    int ld_out,
    int ld_in,
    const trRequantizationParams_t& r) {
  assert(
      (Q_GRAN != QuantizationGranularity::GROUP) &&
      "GROUP Granularity is not supported");

  // Broadcast act_times_w_scale / C_scale
  __m256 act_times_w_div_c_v;
  if (Q_GRAN != QuantizationGranularity::OUT_CHANNEL) {
    act_times_w_div_c_v = _mm256_set1_ps(r.act_times_w_scale[0] / r.C_scale);
  }

  __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
  __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));

  assert(
      (ACT_SYMMETRIC == (r.act_zero_point == 0)) &&
      "ACT_SYMMETRIC == true if and only if act_zero_point == 0");
  assert(
      (WEIGHT_SYMMETRIC ==
       ((Q_GRAN == QuantizationGranularity::TENSOR &&
         r.weight_zero_points[0] == 0) ||
        r.act_col_offsets == nullptr)) &&
      "WEIGHT_SYMMETRIC == true if and only if weight_zero_point == 0 "
      "or r.act_col_offsets == nullptr");
  assert(
      (HAS_BIAS == (r.bias != nullptr)) &&
      "HAS_BIAS == true if and only if bias != nullptr");

  __m256i C_zero_point_epi16_v = _mm256_set1_epi16(r.C_zero_point);
  __m256i C_zero_point_epi8_v = _mm256_set1_epi8(r.C_zero_point);

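  // _mm256_packs_epi32 / _mm256_packus_epi16 operate within 128-bit lanes,
  // so the packed results come out lane-interleaved; this cross-lane permute
  // mask restores the original column order after packing.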
  __m256i permute_mask_v =
      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);

  constexpr int VLEN = 8;
  for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
    // Scale the per-row weight offset by act_zero_point
    int32_t row_offset = 0;
    if (!ACT_SYMMETRIC) {
      row_offset = r.act_zero_point * r.weight_row_offsets[i];
    }

    __m256i row_offset_v = _mm256_set1_epi32(row_offset);

    int weight_zeropoint_idx = 0;
    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
      weight_zeropoint_idx = i;
    }
    __m256 bias_v;
    if (HAS_BIAS) {
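      // Pre-divide the bias by act_times_w_scale so it can be added before
      // the single multiply by act_times_w_scale / C_scale below; its net
      // contribution to the scaled result is bias / C_scale.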
      float bias = r.bias[i] / r.act_times_w_scale[weight_zeropoint_idx];
      bias_v = _mm256_set1_ps(bias);
    }

    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
      float act_times_w_div_c =
          r.act_times_w_scale[weight_zeropoint_idx] / r.C_scale;
      act_times_w_div_c_v = _mm256_set1_ps(act_times_w_div_c);
    }

    __m256i weight_zeropoint_v =
        _mm256_set1_epi32(r.weight_zero_points[weight_zeropoint_idx]);

    int j = block.col_start;
    for (; j < block.col_start + (block.col_size / (VLEN * 4) * (VLEN * 4));
         j += (VLEN * 4)) {
      __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start)));
      __m256i y_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start) +
          1 * VLEN));
      __m256i z_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start) +
          2 * VLEN));
      __m256i w_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start) +
          3 * VLEN));

      if (!ACT_SYMMETRIC) {
        x_v = _mm256_sub_epi32(x_v, row_offset_v);
        y_v = _mm256_sub_epi32(y_v, row_offset_v);
        z_v = _mm256_sub_epi32(z_v, row_offset_v);
        w_v = _mm256_sub_epi32(w_v, row_offset_v);
      }
      if (!WEIGHT_SYMMETRIC) {
        __m256i col_offset_v = _mm256_mullo_epi32(
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                r.act_col_offsets + j - block.col_start)),
            weight_zeropoint_v);
        x_v = _mm256_sub_epi32(x_v, col_offset_v);

        col_offset_v = _mm256_mullo_epi32(
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                r.act_col_offsets + VLEN + j - block.col_start)),
            weight_zeropoint_v);
        y_v = _mm256_sub_epi32(y_v, col_offset_v);

        col_offset_v = _mm256_mullo_epi32(
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                r.act_col_offsets + 2 * VLEN + j - block.col_start)),
            weight_zeropoint_v);
        z_v = _mm256_sub_epi32(z_v, col_offset_v);

        col_offset_v = _mm256_mullo_epi32(
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                r.act_col_offsets + 3 * VLEN + j - block.col_start)),
            weight_zeropoint_v);
        w_v = _mm256_sub_epi32(w_v, col_offset_v);
      }

      /*
       * Convert the int32_t input to FP32 and multiply by the FP32 scale.
       * Both operations use statistically unbiased rounding (with the
       * default MXCSR rounding mode):
       * - Large int32_t values can't be exactly represented as FP32. The
       *   CVTDQ2PS instruction on x86 rounds them to the nearest FP32
       *   value, with ties to even (assuming the default MXCSR rounding
       *   mode).
       * - The product of two FP32 values is generally not exactly
       *   representable as an FP32 value, and is rounded to the nearest
       *   FP32 value, with ties to even, under the default MXCSR rounding
       *   mode.
       */
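      // For example, 2**24 + 1 = 16777217 is not exactly representable in
      // FP32 and converts to 16777216.0f under round-to-nearest-even.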
      __m256 xf_v, yf_v, zf_v, wf_v;
      if (HAS_BIAS) {
        xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), bias_v);
        yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), bias_v);
        zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), bias_v);
        wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), bias_v);
      } else {
        xf_v = _mm256_cvtepi32_ps(x_v);
        yf_v = _mm256_cvtepi32_ps(y_v);
        zf_v = _mm256_cvtepi32_ps(z_v);
        wf_v = _mm256_cvtepi32_ps(w_v);
      }

      __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;

      x_scaled_v = _mm256_mul_ps(xf_v, act_times_w_div_c_v);
      y_scaled_v = _mm256_mul_ps(yf_v, act_times_w_div_c_v);
      z_scaled_v = _mm256_mul_ps(zf_v, act_times_w_div_c_v);
      w_scaled_v = _mm256_mul_ps(wf_v, act_times_w_div_c_v);

      /*
       * Convert the scaled FP32 result to int32_t using the CVTPS2DQ
       * instruction. CVTPS2DQ rounds the result to the nearest integer,
       * with ties to even (assuming the default MXCSR rounding mode).
       * However, when the conversion overflows, it produces INT32_MIN as
       * the result. For large positive inputs the result of the conversion
       * can therefore become negative, which affects the final
       * requantization result. Note that on x86 SSE2 we have e.g.
       * int32_t(float(INT32_MAX)) == INT32_MIN! This happens because
       * float(INT32_MAX) rounds to 2**31, which overflows int32_t when it
       * is converted back to integer.
       *
       * Thankfully, we can prove that overflow never happens in this
       * requantization scheme. The largest positive input is INT32_MAX
       * (2**31 - 1), which turns into 2**31 when converted to float. The
       * largest scale value is 0x1.FFFFFEp-1. When multiplied together, the
       * result is 2147483520 (compare to INT32_MAX = 2147483647), which
       * fits into int32_t without overflow.
       */
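      // (The proof above assumes the requantization scale
      // act_times_w_scale / C_scale is below 1.0; 0x1.FFFFFEp-1 is the
      // largest FP32 value less than 1.)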
      __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
      __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
      __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
      __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);

      /*
       * Standard final sequence on x86 AVX2:
       * - Pack to int16_t and saturate
       * - Add zero point
       * - Pack to uint8_t and saturate
       * - Clamp between qmin and qmax
       */
      __m256i xy_packed_v = _mm256_adds_epi16(
          _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
      __m256i zw_packed_v = _mm256_adds_epi16(
          _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
      __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
      __m256i xyzw_clamped_v = _mm256_max_epu8(
          FUSE_RELU ? C_zero_point_epi8_v : min_v,
          _mm256_min_epu8(xyzw_packed_v, max_v));

      /*
       * xyzw_clamped_v has results in the following layout so we need to
       * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7
       */
      xyzw_clamped_v =
          _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);

      _mm256_storeu_si256(
          reinterpret_cast<__m256i*>(out + i * ld_out + j), xyzw_clamped_v);
    } // j loop vectorized and unrolled 4x

    for (; j < block.col_start + (block.col_size / VLEN * VLEN); j += VLEN) {
      __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start)));

      if (!ACT_SYMMETRIC) {
        x_v = _mm256_sub_epi32(x_v, row_offset_v);
      }
      if (!WEIGHT_SYMMETRIC) {
        __m256i col_offset_v = _mm256_mullo_epi32(
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                r.act_col_offsets + j - block.col_start)),
            weight_zeropoint_v);
        x_v = _mm256_sub_epi32(x_v, col_offset_v);
      }
      __m256 xf_v;
      if (HAS_BIAS) {
        xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), bias_v);
      } else {
        xf_v = _mm256_cvtepi32_ps(x_v);
      }

      __m256 x_scaled_v = _mm256_mul_ps(xf_v, act_times_w_div_c_v);
      __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);

      __m256i x_packed_v = _mm256_adds_epi16(
          _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
          C_zero_point_epi16_v);
      x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
      __m256i x_clamped_v = _mm256_max_epu8(
          FUSE_RELU ? C_zero_point_epi8_v : min_v,
          _mm256_min_epu8(x_packed_v, max_v));

      /*
       * x_clamped_v has results in the following layout so we need to
       * permute: x0-3 garbage0-11 x4-7 garbage12-23
       */
      x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);

      _mm_storel_epi64(
          reinterpret_cast<__m128i*>(out + i * ld_out + j),
          _mm256_castsi256_si128(x_clamped_v));
    } // j loop vectorized

    int remainder = block.col_start + block.col_size - j;
    if (remainder > 0) {
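      // avx2_ps_or_epi32_masks[remainder] has the low `remainder` 32-bit
      // lanes enabled, so the masked loads below read only the valid tail
      // elements.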
      __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
          internal::avx2_ps_or_epi32_masks[remainder]));

      __m256i x_v = _mm256_maskload_epi32(
          inp + (i - block.row_start) * ld_in + (j - block.col_start), mask_v);

      if (!ACT_SYMMETRIC) {
        x_v = _mm256_sub_epi32(x_v, row_offset_v);
      }
      if (!WEIGHT_SYMMETRIC) {
        __m256i col_offset_v = _mm256_mullo_epi32(
            _mm256_maskload_epi32(
                r.act_col_offsets + j - block.col_start, mask_v),
            weight_zeropoint_v);
        x_v = _mm256_sub_epi32(x_v, col_offset_v);
      }

      __m256 xf_v;
      if (HAS_BIAS) {
        xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), bias_v);
      } else {
        xf_v = _mm256_cvtepi32_ps(x_v);
      }
      __m256 x_scaled_v = _mm256_mul_ps(xf_v, act_times_w_div_c_v);
      __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);

      __m256i x_packed_v = _mm256_adds_epi16(
          _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
          C_zero_point_epi16_v);
      x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
      __m256i x_clamped_v = _mm256_max_epu8(
          FUSE_RELU ? C_zero_point_epi8_v : min_v,
          _mm256_min_epu8(x_packed_v, max_v));

      /*
       * x_clamped_v has results in the following layout so we need to
       * permute: x0-3 garbage0-11 x4-7 garbage12-23
       */
      x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);

      alignas(64) uint8_t x_clamped_buffer[32];
      _mm256_store_si256(
          reinterpret_cast<__m256i*>(x_clamped_buffer), x_clamped_v);
      for (int k = 0; k < remainder; ++k) {
        out[i * ld_out + j + k] = x_clamped_buffer[k];
      }
    } // j loop remainder
  } // i loop
}
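
// A minimal usage sketch (illustrative only: buffer contents, scales, and
// zero points are made-up values, and only the trRequantizationParams_t
// fields read by the kernel above are shown):
//
//   std::vector<int32_t> acc(4 * 16); // int32 accumulators, ld_in == 16
//   std::vector<uint8_t> dst(4 * 16); // uint8 output, ld_out == 16
//   std::vector<int32_t> row_offsets(4, 0);
//   float scale = 0.05f; // act_scale * weight_scale
//   int32_t wzp = 0;
//   trRequantizationParams_t reqParams;
//   reqParams.act_zero_point = 0; // ACT_SYMMETRIC
//   reqParams.weight_row_offsets = row_offsets.data();
//   reqParams.act_col_offsets = nullptr; // WEIGHT_SYMMETRIC
//   reqParams.weight_zero_points = &wzp;
//   reqParams.act_times_w_scale = &scale;
//   reqParams.C_scale = 0.1f;
//   reqParams.C_zero_point = 2;
//   reqParams.bias = nullptr; // !HAS_BIAS
//   block_type_t block{/*row_start=*/0, /*row_size=*/4,
//                      /*col_start=*/0, /*col_size=*/16};
//   trRequantizeOpt<
//       /*FUSE_RELU=*/false,
//       /*ACT_SYMMETRIC=*/true,
//       /*WEIGHT_SYMMETRIC=*/true,
//       /*HAS_BIAS=*/false,
//       QuantizationGranularity::TENSOR>(
//       dst.data(), acc.data(), block, /*ld_out=*/16, /*ld_in=*/16, reqParams);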

#define CREATE_INSTANCE(                                         \
    FUSE_RELU, ACT_SYMMETRIC, WEIGHT_SYMMETRIC, HAS_BIAS, QGRAN) \
  template FBGEMM_API void trRequantizeOpt<                      \
      FUSE_RELU,                                                 \
      ACT_SYMMETRIC,                                             \
      WEIGHT_SYMMETRIC,                                          \
      HAS_BIAS,                                                  \
      QGRAN>(                                                    \
      uint8_t * out,                                             \
      const int32_t* inp,                                        \
      const block_type_t& block,                                 \
      int ld_out,                                                \
      int ld_in,                                                 \
      const trRequantizationParams_t& r);
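// Explicitly instantiate all 16 boolean combinations of the template
// parameters for each of the two supported granularities (TENSOR and
// OUT_CHANNEL).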
CREATE_INSTANCE(true, true, true, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, true, true, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, true, false, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, true, false, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, false, true, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, false, true, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, false, false, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, false, false, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, true, true, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, true, true, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, true, false, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, true, false, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, false, true, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, false, true, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, false, false, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, false, false, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, true, true, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(true, true, true, false, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(true, true, false, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(true, true, false, false, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(true, false, true, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(true, false, true, false, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(true, false, false, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(true, false, false, false, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, true, true, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, true, true, false, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, true, false, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, true, false, false, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, false, true, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, false, true, false, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, false, false, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(
    false,
    false,
    false,
    false,
    QuantizationGranularity::OUT_CHANNEL)
#undef CREATE_INSTANCE

} // namespace fbgemm