1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #pragma once |
8 | |
9 | #include <algorithm> // for min and max |
10 | #include <cassert> |
11 | #include <cmath> // for lrintf and sqrt |
12 | #include <cstdint> |
13 | #include <type_traits> // for is_same |
14 | |
15 | #if defined(__x86_64__) || defined(__i386__) || \ |
16 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
17 | #include <immintrin.h> |
18 | #endif |
19 | |
20 | namespace fbgemm { |
21 | |
22 | // Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different |
23 | // row_offsets for each row because of depth-wise convolution |
// Requantizes a row of int32 accumulators (C_int32) into uint8 outputs
// (C_uint8) for depth-wise convolution:
//   C_uint8[j] = clamp(round((C_int32[j] - corrections + bias_term)
//                            * C_multiplier) + C_zero_point, 0, 255)
// where the corrections subtract B_zero_point * row_offset and
// A_zero_point * col_offset as dictated by the (A|B)_SYMMETRIC flags.
//
// Template parameters (all resolved at compile time, "static if" style):
//   FUSE_RELU   - clamp the low end to C_zero_point instead of 0.
//   HAS_BIAS    - add bias (float bias is pre-divided by act_times_w_scale;
//                 int32 bias is added directly to the accumulator).
//   Q_GRAN      - quantization granularity (TENSOR / GROUP / OUT_CHANNEL);
//                 selects how C_multiplier / B_zero_point are indexed.
//   A_SYMMETRIC - A_zero_point is 0, so the col_offsets correction is skipped.
//   B_SYMMETRIC - B_zero_point is 0, so the row_offsets correction is skipped.
//   K_PER_G     - channels per group (1 or 2); with 2, per-group values
//                 (row_offsets, B_zero_point, act_times_w_scale) are
//                 broadcast 2x to line up with per-channel lanes.
//
// Runtime parameters:
//   A_zero_point      - activation zero point.
//   B_zero_point      - weight zero point(s); indexing depends on Q_GRAN.
//   C_multiplier      - requantization scale(s); indexing depends on Q_GRAN.
//   C_zero_point      - output zero point.
//   C_int32           - input accumulators (n elements).
//   C_uint8           - output buffer (n elements).
//   n                 - number of output elements in this row.
//   row_offsets       - per-row sums of A; n/K_PER_G elements (per group).
//   col_offsets       - per-column sums of B; n elements.
//   bias              - optional bias (float or int32, see BIAS_TYPE).
//   act_times_w_scale - activation_scale * weight_scale, used to rescale a
//                       float bias into the int32 accumulator domain.
template <
    bool FUSE_RELU,
    bool HAS_BIAS,
    QuantizationGranularity Q_GRAN,
    bool A_SYMMETRIC,
    bool B_SYMMETRIC,
    int K_PER_G,
    typename BIAS_TYPE>
static ALWAYS_INLINE void requantize_(
    std::int32_t A_zero_point,
    const std::int32_t* B_zero_point,
    const float* C_multiplier,
    std::int32_t C_zero_point,
    const std::int32_t* C_int32,
    std::uint8_t* C_uint8,
    int n,
    const std::int32_t* row_offsets,
    const std::int32_t* col_offsets,
    const BIAS_TYPE* bias,
    const float* act_times_w_scale = nullptr) {
  __m256 multiplier_v = _mm256_setzero_ps();
  // Broadcasted reciprocal of act_times_w_scale
  __m256 act_times_w_rcp_v = _mm256_setzero_ps();
  __m256i B_zero_point_v = _mm256_setzero_si256();
  // For TENSOR granularity the scale/zero-point are scalars, so broadcast
  // them once here; for GROUP/OUT_CHANNEL they are (re)loaded inside the
  // loops at the appropriate offset.
  if (Q_GRAN == QuantizationGranularity::TENSOR) {
    multiplier_v = _mm256_set1_ps(*C_multiplier);
    if (std::is_same<BIAS_TYPE, float>::value) {
      act_times_w_rcp_v = _mm256_set1_ps(1.0 / (*act_times_w_scale));
    }
    B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
  }

  // Saturation bounds for the final uint8 clamp (max_v is all-0xFF bytes).
  __m256i min_v = _mm256_set1_epi8(static_cast<std::uint8_t>(0));
  __m256i max_v = _mm256_set1_epi8(static_cast<std::uint8_t>(255));

  if (A_SYMMETRIC) {
    assert(A_zero_point == 0 || col_offsets == nullptr);
  }
  __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point);
  __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point);
  __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point);

  // Used both to undo the 128-bit-lane interleaving of the pack instructions
  // below and to broadcast 4 per-group values into 8 per-channel lanes.
  __m256i permute_mask_v =
      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);

  constexpr int VLEN = 8;
  int j = 0;
  // Main loop: 4x-unrolled, 32 outputs per iteration (x/y/z/w vectors).
  for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
    __m256i x_v =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
    __m256i y_v = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(C_int32 + j + VLEN));
    __m256i z_v = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN));
    __m256i w_v = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN));

    // Subtract B_zero_point * row_offset from x_v (first 8 outputs).
    __m256i row_offset_v;
    if (!B_SYMMETRIC) {
      if (K_PER_G == 1) {
        row_offset_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(row_offsets + j));
      } else {
        assert(K_PER_G == 2);
        // Load row_offsets for 4 groups and broadcast by 2 times.
        // permutevar8x32 spreads the 4 loaded values into even lanes,
        // moveldup duplicates each even lane into the odd lane above it.
        row_offset_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(_mm_loadu_ps(
                    reinterpret_cast<const float*>(row_offsets + j / 2))),
                permute_mask_v)));
      }
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
        B_zero_point_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(B_zero_point + j));
      } else if (Q_GRAN == QuantizationGranularity::GROUP) {
        assert(K_PER_G == 2);
        // Same 4-to-8 broadcast trick for per-group zero points.
        B_zero_point_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(_mm_loadu_ps(
                    reinterpret_cast<const float*>(B_zero_point + j / 2))),
                permute_mask_v)));
      }
      row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
      x_v = _mm256_sub_epi32(x_v, row_offset_v);
    }
    // Subtract A_zero_point * col_offset from x_v.
    __m256i col_off_v;
    if (!A_SYMMETRIC) {
      col_off_v = _mm256_mullo_epi32(
          A_zero_point_v,
          _mm256_loadu_si256(
              reinterpret_cast<const __m256i*>(col_offsets + j)));
      x_v = _mm256_sub_epi32(x_v, col_off_v);
    }

    // Same corrections for y_v (outputs j+VLEN .. j+2*VLEN-1).
    if (!B_SYMMETRIC) {
      if (K_PER_G == 1) {
        row_offset_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(row_offsets + j + VLEN));
      } else {
        row_offset_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(
                    _mm_loadu_ps(reinterpret_cast<const float*>(
                        row_offsets + (j + VLEN) / 2))),
                permute_mask_v)));
      }
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
        B_zero_point_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(B_zero_point + j + VLEN));
      } else if (Q_GRAN == QuantizationGranularity::GROUP) {
        B_zero_point_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(
                    _mm_loadu_ps(reinterpret_cast<const float*>(
                        B_zero_point + (j + VLEN) / 2))),
                permute_mask_v)));
      }
      row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
      y_v = _mm256_sub_epi32(y_v, row_offset_v);
    }
    if (!A_SYMMETRIC) {
      col_off_v = _mm256_mullo_epi32(
          A_zero_point_v,
          _mm256_loadu_si256(
              reinterpret_cast<const __m256i*>(col_offsets + j + VLEN)));
      y_v = _mm256_sub_epi32(y_v, col_off_v);
    }

    // Same corrections for z_v (outputs j+2*VLEN .. j+3*VLEN-1).
    if (!B_SYMMETRIC) {
      if (K_PER_G == 1) {
        row_offset_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN));
      } else {
        row_offset_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(
                    _mm_loadu_ps(reinterpret_cast<const float*>(
                        row_offsets + (j + 2 * VLEN) / 2))),
                permute_mask_v)));
      }
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
        B_zero_point_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(B_zero_point + j + 2 * VLEN));
      } else if (Q_GRAN == QuantizationGranularity::GROUP) {
        B_zero_point_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(
                    _mm_loadu_ps(reinterpret_cast<const float*>(
                        B_zero_point + (j + 2 * VLEN) / 2))),
                permute_mask_v)));
      }
      row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
      z_v = _mm256_sub_epi32(z_v, row_offset_v);
    }
    if (!A_SYMMETRIC) {
      col_off_v = _mm256_mullo_epi32(
          A_zero_point_v,
          _mm256_loadu_si256(
              reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN)));
      z_v = _mm256_sub_epi32(z_v, col_off_v);
    }

    // Same corrections for w_v (outputs j+3*VLEN .. j+4*VLEN-1).
    if (!B_SYMMETRIC) {
      if (K_PER_G == 1) {
        row_offset_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN));
      } else {
        row_offset_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(
                    _mm_loadu_ps(reinterpret_cast<const float*>(
                        row_offsets + (j + 3 * VLEN) / 2))),
                permute_mask_v)));
      }
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
        B_zero_point_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(B_zero_point + j + 3 * VLEN));
      } else if (Q_GRAN == QuantizationGranularity::GROUP) {
        B_zero_point_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(
                    _mm_loadu_ps(reinterpret_cast<const float*>(
                        B_zero_point + (j + 3 * VLEN) / 2))),
                permute_mask_v)));
      }
      row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
      w_v = _mm256_sub_epi32(w_v, row_offset_v);
    }
    if (!A_SYMMETRIC) {
      col_off_v = _mm256_mullo_epi32(
          A_zero_point_v,
          _mm256_loadu_si256(
              reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN)));
      w_v = _mm256_sub_epi32(w_v, col_off_v);
    }

    // convert to float
    __m256 xf_v, yf_v, zf_v, wf_v;
    if (HAS_BIAS) { // static if
      if (std::is_same<BIAS_TYPE, float>::value) {
        // Float bias lives in the real-number domain; divide by
        // act_times_w_scale to bring it into the int32 accumulator domain
        // before adding.
        __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
            (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
          x_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
              _mm256_loadu_ps(act_times_w_scale + j + 0 * VLEN));
          y_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
              _mm256_loadu_ps(act_times_w_scale + j + 1 * VLEN));
          z_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
              _mm256_loadu_ps(act_times_w_scale + j + 2 * VLEN));
          w_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
              _mm256_loadu_ps(act_times_w_scale + j + 3 * VLEN));
        } else if (Q_GRAN == QuantizationGranularity::GROUP) {
          assert(K_PER_G == 2);
          // Per-group scales: broadcast each of 4 scales to 2 lanes.
          x_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
              _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                  _mm256_castps128_ps256(
                      _mm_loadu_ps(act_times_w_scale + j / 2)),
                  permute_mask_v)));
          y_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
              _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                  _mm256_castps128_ps256(
                      _mm_loadu_ps(act_times_w_scale + (j + VLEN) / 2)),
                  permute_mask_v)));
          z_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
              _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                  _mm256_castps128_ps256(
                      _mm_loadu_ps(act_times_w_scale + (j + 2 * VLEN) / 2)),
                  permute_mask_v)));
          w_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
              _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                  _mm256_castps128_ps256(
                      _mm_loadu_ps(act_times_w_scale + (j + 3 * VLEN) / 2)),
                  permute_mask_v)));
        } else {
          // TENSOR granularity: multiply by the precomputed reciprocal.
          x_bias_v = _mm256_mul_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
              act_times_w_rcp_v);
          y_bias_v = _mm256_mul_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
              act_times_w_rcp_v);
          z_bias_v = _mm256_mul_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
              act_times_w_rcp_v);
          w_bias_v = _mm256_mul_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
              act_times_w_rcp_v);
        }
        xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
        yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
        zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
        wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
      } else {
        // Integer bias: add directly in the int32 domain, then convert.
        x_v = _mm256_add_epi32(
            x_v,
            _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(bias + j + 0 * VLEN)));
        y_v = _mm256_add_epi32(
            y_v,
            _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(bias + j + 1 * VLEN)));
        z_v = _mm256_add_epi32(
            z_v,
            _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN)));
        w_v = _mm256_add_epi32(
            w_v,
            _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN)));
        xf_v = _mm256_cvtepi32_ps(x_v);
        yf_v = _mm256_cvtepi32_ps(y_v);
        zf_v = _mm256_cvtepi32_ps(z_v);
        wf_v = _mm256_cvtepi32_ps(w_v);
      }
    } else {
      xf_v = _mm256_cvtepi32_ps(x_v);
      yf_v = _mm256_cvtepi32_ps(y_v);
      zf_v = _mm256_cvtepi32_ps(z_v);
      wf_v = _mm256_cvtepi32_ps(w_v);
    }

    // Scale each vector by its (granularity-dependent) multiplier.
    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
      multiplier_v = _mm256_loadu_ps(C_multiplier + j + 0 * VLEN);
    } else if (Q_GRAN == QuantizationGranularity::GROUP) {
      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
          _mm256_castps128_ps256(_mm_loadu_ps(C_multiplier + j / 2)),
          permute_mask_v));
    }
    __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
      multiplier_v = _mm256_loadu_ps(C_multiplier + j + 1 * VLEN);
    } else if (Q_GRAN == QuantizationGranularity::GROUP) {
      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
          _mm256_castps128_ps256(_mm_loadu_ps(C_multiplier + (j + VLEN) / 2)),
          permute_mask_v));
    }
    __m256 y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
      multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN);
    } else if (Q_GRAN == QuantizationGranularity::GROUP) {
      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
          _mm256_castps128_ps256(
              _mm_loadu_ps(C_multiplier + (j + 2 * VLEN) / 2)),
          permute_mask_v));
    }
    __m256 z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
      multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN);
    } else if (Q_GRAN == QuantizationGranularity::GROUP) {
      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
          _mm256_castps128_ps256(
              _mm_loadu_ps(C_multiplier + (j + 3 * VLEN) / 2)),
          permute_mask_v));
    }
    __m256 w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);

    // Round to nearest int32 (uses the current rounding mode, normally
    // round-to-nearest-even).
    __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
    __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
    __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
    __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);

    // Narrow int32 -> int16 with saturation, add the output zero point in
    // the int16 domain (saturating), then narrow int16 -> uint8 with
    // unsigned saturation.
    __m256i xy_packed_v = _mm256_adds_epi16(
        _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
    __m256i zw_packed_v = _mm256_adds_epi16(
        _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
    __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
    // Clamp; with FUSE_RELU the lower bound is the zero point so that
    // dequantized outputs are >= 0.
    __m256i xyzw_clamped_v = _mm256_max_epu8(
        FUSE_RELU ? C_zero_point_epi8_v : min_v,
        _mm256_min_epu8(xyzw_packed_v, max_v));

    // The AVX2 pack instructions operate per 128-bit lane, so the bytes are
    // interleaved; this permute restores the original element order:
    // xyzw_clamped_v has results in the following layout before the permute
    // x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7
    xyzw_clamped_v =
        _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);

    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v);
  } // j loop vectorized and unrolled 4x

  // Vector tail: one vector (8 outputs) at a time; same steps as above but
  // for a single x vector, storing only the low 8 bytes.
  for (; j < n / VLEN * VLEN; j += VLEN) {
    __m256i x_v =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));

    if (!B_SYMMETRIC) {
      __m256i row_offset_v;
      if (K_PER_G == 1) {
        row_offset_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(row_offsets + j));
      } else {
        assert(K_PER_G == 2);
        // Load row_offsets for 4 groups and broadcast by 2 times.
        row_offset_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(_mm_loadu_ps(
                    reinterpret_cast<const float*>(row_offsets + j / 2))),
                permute_mask_v)));
      }
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
        B_zero_point_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(B_zero_point + j));
      } else if (Q_GRAN == QuantizationGranularity::GROUP) {
        assert(K_PER_G == 2);
        B_zero_point_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(_mm_loadu_ps(
                    reinterpret_cast<const float*>(B_zero_point + j / 2))),
                permute_mask_v)));
      }
      row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
      x_v = _mm256_sub_epi32(x_v, row_offset_v);
    }
    if (!A_SYMMETRIC) {
      __m256i col_off_v = _mm256_mullo_epi32(
          A_zero_point_v,
          _mm256_loadu_si256(
              reinterpret_cast<const __m256i*>(col_offsets + j)));
      x_v = _mm256_sub_epi32(x_v, col_off_v);
    }

    // Convert to float
    __m256 xf_v;
    if (HAS_BIAS) { // static if
      if (std::is_same<BIAS_TYPE, float>::value) {
        __m256 x_bias_v;
        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
            (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
          x_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
              _mm256_loadu_ps(act_times_w_scale + j));
        } else if (Q_GRAN == QuantizationGranularity::GROUP) {
          x_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
              _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                  _mm256_castps128_ps256(
                      _mm_loadu_ps(act_times_w_scale + j / 2)),
                  permute_mask_v)));
        } else {
          x_bias_v = _mm256_mul_ps(
              _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
              act_times_w_rcp_v);
        }
        xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
      } else {
        x_v = _mm256_add_epi32(
            x_v,
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
        xf_v = _mm256_cvtepi32_ps(x_v);
      }
    } else {
      xf_v = _mm256_cvtepi32_ps(x_v);
    }

    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
      multiplier_v = _mm256_loadu_ps(C_multiplier + j);
    } else if (Q_GRAN == QuantizationGranularity::GROUP) {
      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
          _mm256_castps128_ps256(_mm_loadu_ps(C_multiplier + j / 2)),
          permute_mask_v));
    }
    __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
    __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);

    // Pack against zero vectors: only the low 8 result bytes are meaningful.
    __m256i x_packed_v = _mm256_adds_epi16(
        _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
        C_zero_point_epi16_v);
    x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
    __m256i x_clamped_v = _mm256_max_epu8(
        FUSE_RELU ? C_zero_point_epi8_v : min_v,
        _mm256_min_epu8(x_packed_v, max_v));

    // Undo the per-lane interleaving of the pack instructions.
    x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);

    // Store just the 8 valid bytes.
    _mm_storel_epi64(
        reinterpret_cast<__m128i*>(C_uint8 + j),
        _mm256_castsi256_si128(x_clamped_v));
  } // j loop vectorized

  // Scalar tail: remaining (< VLEN) outputs, mirroring the vector math.
  for (; j < n; ++j) {
    std::int32_t raw = C_int32[j];
    // Index into per-channel / per-group quantization parameters.
    int quant_param_idx = 0;
    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
      quant_param_idx = j;
    } else if (Q_GRAN == QuantizationGranularity::GROUP) {
      quant_param_idx = j / 2;
    }
    if (!B_SYMMETRIC) {
      // row_offsets are per group, hence the j / K_PER_G index.
      raw -= B_zero_point[quant_param_idx] * row_offsets[j / K_PER_G];
    }
    if (!A_SYMMETRIC) {
      raw -= A_zero_point * col_offsets[j];
    }
    float raw_f;
    if (HAS_BIAS) { // static if
      if (std::is_same<BIAS_TYPE, float>::value) {
        raw_f = raw;
        raw_f += bias[j] / act_times_w_scale[quant_param_idx];
      } else {
        raw += bias[j];
        raw_f = raw;
      }
    } else {
      raw_f = raw;
    }

    float ab = raw_f * C_multiplier[quant_param_idx];
    // lrintf rounds using the current rounding mode, matching
    // _mm256_cvtps_epi32 in the vector loops.
    long rounded = lrintf(ab) + C_zero_point;

    C_uint8[j] = std::max(
        FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
        std::min(255l, rounded));
  }
}
524 | |
// Returns the factor pair (a, b) of n with a * b == n and a <= b whose
// members are as close to each other as possible (i.e., the factorization
// closest to a square). Used to pick 2D tiling/blocking shapes.
//
// @param n positive integer to factor (n >= 1)
// @return {a, n / a} with a <= n / a
static inline std::pair<int, int> closest_factors_(int n) {
  // n == 0 would make the loop below divide by zero; negative n makes
  // sqrt/modulo meaningless.
  assert(n > 0);
  int a = static_cast<int>(std::sqrt(static_cast<double>(n)));
  // std::sqrt can round either way for large n; correct a to the exact
  // floor of the integer square root. 64-bit arithmetic avoids overflow of
  // (a + 1) * (a + 1) near INT_MAX.
  while (static_cast<std::int64_t>(a + 1) * (a + 1) <= n) {
    ++a;
  }
  while (a > 1 && static_cast<std::int64_t>(a) * a > n) {
    --a;
  }
  // Walk down to the nearest divisor; terminates since 1 divides every n.
  while (n % a != 0) {
    a--;
  }
  return {a, n / a}; // a <= n / a
}
532 | |
533 | } // namespace fbgemm |
534 | |