/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <algorithm> // for min and max
#include <cassert>
#include <cmath> // for lrintf and sqrt
#include <cstdint>
#include <type_traits> // for is_same
#include <utility> // for pair

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif

namespace fbgemm {

// Almost the same as ReQuantizeOutput in OutputProcessing-inl.h, but with a
// different row_offset for each row because of depth-wise convolution.
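//
// Template parameters, as used below: K_PER_G is the number of channels per
// group (1 or 2 here; see the asserts), A_SYMMETRIC / B_SYMMETRIC mean the
// corresponding zero point is 0 so that offset correction can be skipped,
// and Q_GRAN selects per-tensor, per-group, or per-output-channel
// quantization parameters.
//
// In scalar form (a reference sketch of what the vectorized paths compute;
// it matches the remainder loop at the end of requantize_, with g the
// quantization-parameter index implied by Q_GRAN):
//   raw = C_int32[j]
//       - B_zero_point[g] * row_offsets[j / K_PER_G]  // unless B_SYMMETRIC
//       - A_zero_point * col_offsets[j];              // unless A_SYMMETRIC
//   val = raw + bias[j]                               // int32 bias
//       = raw + bias[j] / act_times_w_scale[g];       // float bias
//   C_uint8[j] = clamp(round(val * C_multiplier[g]) + C_zero_point, 0, 255);
// When FUSE_RELU is true, the lower clamp bound is C_zero_point instead of 0.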
template <
    bool FUSE_RELU,
    bool HAS_BIAS,
    QuantizationGranularity Q_GRAN,
    bool A_SYMMETRIC,
    bool B_SYMMETRIC,
    int K_PER_G,
    typename BIAS_TYPE>
static ALWAYS_INLINE void requantize_(
    std::int32_t A_zero_point,
    const std::int32_t* B_zero_point,
    const float* C_multiplier,
    std::int32_t C_zero_point,
    const std::int32_t* C_int32,
    std::uint8_t* C_uint8,
    int n,
    const std::int32_t* row_offsets,
    const std::int32_t* col_offsets,
    const BIAS_TYPE* bias,
    const float* act_times_w_scale = nullptr) {
  __m256 multiplier_v = _mm256_setzero_ps();
  // Broadcasted reciprocal of act_times_w_scale
  __m256 act_times_w_rcp_v = _mm256_setzero_ps();
  __m256i B_zero_point_v = _mm256_setzero_si256();
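  // For per-tensor quantization every lane shares the same parameters, so
  // broadcast them once here instead of reloading inside the loops.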
  if (Q_GRAN == QuantizationGranularity::TENSOR) {
    multiplier_v = _mm256_set1_ps(*C_multiplier);
    if (std::is_same<BIAS_TYPE, float>::value) {
      act_times_w_rcp_v = _mm256_set1_ps(1.0f / (*act_times_w_scale));
    }
    B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
  }

  __m256i min_v = _mm256_set1_epi8(static_cast<std::uint8_t>(0));
  __m256i max_v = _mm256_set1_epi8(static_cast<std::uint8_t>(255));

  if (A_SYMMETRIC) {
    assert(A_zero_point == 0 || col_offsets == nullptr);
  }
  __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point);
  __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point);
  __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point);

  __m256i permute_mask_v =
      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
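  // This permutation serves two purposes below. Applied to a 4-element
  // vector zero-extended to 8 lanes, _mm256_permutevar8x32_ps spreads
  // [a b c d] to [a . b . c . d .], and the following _mm256_moveldup_ps
  // duplicates the even lanes to give [a a b b c c d d] -- the
  // "broadcast by 2" used for per-group values when K_PER_G == 2. Applied
  // after packs/packus (which interleave the two 128-bit lanes), the same
  // permutation restores sequential byte order before the final store.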

  constexpr int VLEN = 8;
  int j = 0;
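  // Main loop: VLEN = 8 int32 lanes per 256-bit vector, unrolled 4x so each
  // iteration produces 32 uint8 outputs.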
  for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
    __m256i x_v =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
    __m256i y_v = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(C_int32 + j + VLEN));
    __m256i z_v = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN));
    __m256i w_v = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN));

    __m256i row_offset_v;
    if (!B_SYMMETRIC) {
      if (K_PER_G == 1) {
        row_offset_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(row_offsets + j));
      } else {
        assert(K_PER_G == 2);
        // Load row_offsets for 4 groups and broadcast each one 2 times.
        row_offset_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(_mm_loadu_ps(
                    reinterpret_cast<const float*>(row_offsets + j / 2))),
                permute_mask_v)));
      }
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
        B_zero_point_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(B_zero_point + j));
      } else if (Q_GRAN == QuantizationGranularity::GROUP) {
        assert(K_PER_G == 2);
        B_zero_point_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(_mm_loadu_ps(
                    reinterpret_cast<const float*>(B_zero_point + j / 2))),
                permute_mask_v)));
      }
      row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
      x_v = _mm256_sub_epi32(x_v, row_offset_v);
    }
    __m256i col_off_v;
    if (!A_SYMMETRIC) {
      col_off_v = _mm256_mullo_epi32(
          A_zero_point_v,
          _mm256_loadu_si256(
              reinterpret_cast<const __m256i*>(col_offsets + j)));
      x_v = _mm256_sub_epi32(x_v, col_off_v);
    }

    if (!B_SYMMETRIC) {
      if (K_PER_G == 1) {
        row_offset_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(row_offsets + j + VLEN));
      } else {
        row_offset_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(
                    _mm_loadu_ps(reinterpret_cast<const float*>(
                        row_offsets + (j + VLEN) / 2))),
                permute_mask_v)));
      }
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
        B_zero_point_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(B_zero_point + j + VLEN));
      } else if (Q_GRAN == QuantizationGranularity::GROUP) {
        B_zero_point_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(
                    _mm_loadu_ps(reinterpret_cast<const float*>(
                        B_zero_point + (j + VLEN) / 2))),
                permute_mask_v)));
      }
      row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
      y_v = _mm256_sub_epi32(y_v, row_offset_v);
    }
    if (!A_SYMMETRIC) {
      col_off_v = _mm256_mullo_epi32(
          A_zero_point_v,
          _mm256_loadu_si256(
              reinterpret_cast<const __m256i*>(col_offsets + j + VLEN)));
      y_v = _mm256_sub_epi32(y_v, col_off_v);
    }

    if (!B_SYMMETRIC) {
      if (K_PER_G == 1) {
        row_offset_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN));
      } else {
        row_offset_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(
                    _mm_loadu_ps(reinterpret_cast<const float*>(
                        row_offsets + (j + 2 * VLEN) / 2))),
                permute_mask_v)));
      }
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
        B_zero_point_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(B_zero_point + j + 2 * VLEN));
      } else if (Q_GRAN == QuantizationGranularity::GROUP) {
        B_zero_point_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(
                    _mm_loadu_ps(reinterpret_cast<const float*>(
                        B_zero_point + (j + 2 * VLEN) / 2))),
                permute_mask_v)));
      }
      row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
      z_v = _mm256_sub_epi32(z_v, row_offset_v);
    }
    if (!A_SYMMETRIC) {
      col_off_v = _mm256_mullo_epi32(
          A_zero_point_v,
          _mm256_loadu_si256(
              reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN)));
      z_v = _mm256_sub_epi32(z_v, col_off_v);
    }

    if (!B_SYMMETRIC) {
      if (K_PER_G == 1) {
        row_offset_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN));
      } else {
        row_offset_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(
                    _mm_loadu_ps(reinterpret_cast<const float*>(
                        row_offsets + (j + 3 * VLEN) / 2))),
                permute_mask_v)));
      }
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
        B_zero_point_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(B_zero_point + j + 3 * VLEN));
      } else if (Q_GRAN == QuantizationGranularity::GROUP) {
        B_zero_point_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(
                    _mm_loadu_ps(reinterpret_cast<const float*>(
                        B_zero_point + (j + 3 * VLEN) / 2))),
                permute_mask_v)));
      }
      row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
      w_v = _mm256_sub_epi32(w_v, row_offset_v);
    }
    if (!A_SYMMETRIC) {
      col_off_v = _mm256_mullo_epi32(
          A_zero_point_v,
          _mm256_loadu_si256(
              reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN)));
      w_v = _mm256_sub_epi32(w_v, col_off_v);
    }

    // Convert to float
    __m256 xf_v, yf_v, zf_v, wf_v;
    if (HAS_BIAS) { // static if
      if (std::is_same<BIAS_TYPE, float>::value) {
        __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
            (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
          x_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
              _mm256_loadu_ps(act_times_w_scale + j + 0 * VLEN));
          y_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
              _mm256_loadu_ps(act_times_w_scale + j + 1 * VLEN));
          z_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
              _mm256_loadu_ps(act_times_w_scale + j + 2 * VLEN));
          w_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
              _mm256_loadu_ps(act_times_w_scale + j + 3 * VLEN));
        } else if (Q_GRAN == QuantizationGranularity::GROUP) {
          assert(K_PER_G == 2);
          x_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
              _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                  _mm256_castps128_ps256(
                      _mm_loadu_ps(act_times_w_scale + j / 2)),
                  permute_mask_v)));
          y_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
              _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                  _mm256_castps128_ps256(
                      _mm_loadu_ps(act_times_w_scale + (j + VLEN) / 2)),
                  permute_mask_v)));
          z_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
              _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                  _mm256_castps128_ps256(
                      _mm_loadu_ps(act_times_w_scale + (j + 2 * VLEN) / 2)),
                  permute_mask_v)));
          w_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
              _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                  _mm256_castps128_ps256(
                      _mm_loadu_ps(act_times_w_scale + (j + 3 * VLEN) / 2)),
                  permute_mask_v)));
        } else {
          x_bias_v = _mm256_mul_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
              act_times_w_rcp_v);
          y_bias_v = _mm256_mul_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
              act_times_w_rcp_v);
          z_bias_v = _mm256_mul_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
              act_times_w_rcp_v);
          w_bias_v = _mm256_mul_ps(
              _mm256_loadu_ps(
                  reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
              act_times_w_rcp_v);
        }
        xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
        yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
        zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
        wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
      } else {
        x_v = _mm256_add_epi32(
            x_v,
            _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(bias + j + 0 * VLEN)));
        y_v = _mm256_add_epi32(
            y_v,
            _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(bias + j + 1 * VLEN)));
        z_v = _mm256_add_epi32(
            z_v,
            _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN)));
        w_v = _mm256_add_epi32(
            w_v,
            _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN)));
        xf_v = _mm256_cvtepi32_ps(x_v);
        yf_v = _mm256_cvtepi32_ps(y_v);
        zf_v = _mm256_cvtepi32_ps(z_v);
        wf_v = _mm256_cvtepi32_ps(w_v);
      }
    } else {
      xf_v = _mm256_cvtepi32_ps(x_v);
      yf_v = _mm256_cvtepi32_ps(y_v);
      zf_v = _mm256_cvtepi32_ps(z_v);
      wf_v = _mm256_cvtepi32_ps(w_v);
    }

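    // Pick the output multiplier for each block of 8 lanes according to
    // Q_GRAN: a plain vector load for per-channel parameters, the
    // broadcast-by-2 trick for per-group parameters with K_PER_G == 2, or
    // the tensor-wide value already broadcast before the loop.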
    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
      multiplier_v = _mm256_loadu_ps(C_multiplier + j + 0 * VLEN);
    } else if (Q_GRAN == QuantizationGranularity::GROUP) {
      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
          _mm256_castps128_ps256(_mm_loadu_ps(C_multiplier + j / 2)),
          permute_mask_v));
    }
    __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
      multiplier_v = _mm256_loadu_ps(C_multiplier + j + 1 * VLEN);
    } else if (Q_GRAN == QuantizationGranularity::GROUP) {
      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
          _mm256_castps128_ps256(_mm_loadu_ps(C_multiplier + (j + VLEN) / 2)),
          permute_mask_v));
    }
    __m256 y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
      multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN);
    } else if (Q_GRAN == QuantizationGranularity::GROUP) {
      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
          _mm256_castps128_ps256(
              _mm_loadu_ps(C_multiplier + (j + 2 * VLEN) / 2)),
          permute_mask_v));
    }
    __m256 z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
      multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN);
    } else if (Q_GRAN == QuantizationGranularity::GROUP) {
      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
          _mm256_castps128_ps256(
              _mm_loadu_ps(C_multiplier + (j + 3 * VLEN) / 2)),
          permute_mask_v));
    }
    __m256 w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);

    __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
    __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
    __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
    __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);

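    // Pack the 4x8 int32 results down to 32 uint8 outputs: saturating-pack
    // to int16, add the output zero point with saturation, saturating-pack
    // to uint8 (which also clamps to [0, 255]), then permute to undo the
    // 128-bit lane interleaving of the packs. When FUSE_RELU is set, the
    // lower clamp becomes C_zero_point, implementing the fused ReLU.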
    __m256i xy_packed_v = _mm256_adds_epi16(
        _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
    __m256i zw_packed_v = _mm256_adds_epi16(
        _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
    __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
    __m256i xyzw_clamped_v = _mm256_max_epu8(
        FUSE_RELU ? C_zero_point_epi8_v : min_v,
        _mm256_min_epu8(xyzw_packed_v, max_v));

    xyzw_clamped_v =
        _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);

    _mm256_storeu_si256(
        reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v);
  } // j loop vectorized and unrolled 4x

  for (; j < n / VLEN * VLEN; j += VLEN) {
    __m256i x_v =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));

    if (!B_SYMMETRIC) {
      __m256i row_offset_v;
      if (K_PER_G == 1) {
        row_offset_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(row_offsets + j));
      } else {
        assert(K_PER_G == 2);
        // Load row_offsets for 4 groups and broadcast each one 2 times.
        row_offset_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(_mm_loadu_ps(
                    reinterpret_cast<const float*>(row_offsets + j / 2))),
                permute_mask_v)));
      }
      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
          (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
        B_zero_point_v = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(B_zero_point + j));
      } else if (Q_GRAN == QuantizationGranularity::GROUP) {
        assert(K_PER_G == 2);
        B_zero_point_v =
            _mm256_castps_si256(_mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                _mm256_castps128_ps256(_mm_loadu_ps(
                    reinterpret_cast<const float*>(B_zero_point + j / 2))),
                permute_mask_v)));
      }
      row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
      x_v = _mm256_sub_epi32(x_v, row_offset_v);
    }
    if (!A_SYMMETRIC) {
      __m256i col_off_v = _mm256_mullo_epi32(
          A_zero_point_v,
          _mm256_loadu_si256(
              reinterpret_cast<const __m256i*>(col_offsets + j)));
      x_v = _mm256_sub_epi32(x_v, col_off_v);
    }

    // Convert to float
    __m256 xf_v;
    if (HAS_BIAS) { // static if
      if (std::is_same<BIAS_TYPE, float>::value) {
        __m256 x_bias_v;
        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
            (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
          x_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
              _mm256_loadu_ps(act_times_w_scale + j));
        } else if (Q_GRAN == QuantizationGranularity::GROUP) {
          x_bias_v = _mm256_div_ps(
              _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
              _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
                  _mm256_castps128_ps256(
                      _mm_loadu_ps(act_times_w_scale + j / 2)),
                  permute_mask_v)));
        } else {
          x_bias_v = _mm256_mul_ps(
              _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
              act_times_w_rcp_v);
        }
        xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
      } else {
        x_v = _mm256_add_epi32(
            x_v,
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
        xf_v = _mm256_cvtepi32_ps(x_v);
      }
    } else {
      xf_v = _mm256_cvtepi32_ps(x_v);
    }

    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
      multiplier_v = _mm256_loadu_ps(C_multiplier + j);
    } else if (Q_GRAN == QuantizationGranularity::GROUP) {
      multiplier_v = _mm256_moveldup_ps(_mm256_permutevar8x32_ps(
          _mm256_castps128_ps256(_mm_loadu_ps(C_multiplier + j / 2)),
          permute_mask_v));
    }
    __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
    __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);

    __m256i x_packed_v = _mm256_adds_epi16(
        _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
        C_zero_point_epi16_v);
    x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
    __m256i x_clamped_v = _mm256_max_epu8(
        FUSE_RELU ? C_zero_point_epi8_v : min_v,
        _mm256_min_epu8(x_packed_v, max_v));

    x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);

    _mm_storel_epi64(
        reinterpret_cast<__m128i*>(C_uint8 + j),
        _mm256_castsi256_si128(x_clamped_v));
  } // j loop vectorized

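  // Scalar epilogue for the last n % VLEN elements; this spells out the same
  // requantization arithmetic the vector paths above implement.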
  for (; j < n; ++j) {
    std::int32_t raw = C_int32[j];
    int quant_param_idx = 0;
    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL ||
        (Q_GRAN == QuantizationGranularity::GROUP && K_PER_G == 1)) {
      quant_param_idx = j;
    } else if (Q_GRAN == QuantizationGranularity::GROUP) {
      quant_param_idx = j / 2;
    }
    if (!B_SYMMETRIC) {
      raw -= B_zero_point[quant_param_idx] * row_offsets[j / K_PER_G];
    }
    if (!A_SYMMETRIC) {
      raw -= A_zero_point * col_offsets[j];
    }
    float raw_f;
    if (HAS_BIAS) { // static if
      if (std::is_same<BIAS_TYPE, float>::value) {
        raw_f = raw;
        raw_f += bias[j] / act_times_w_scale[quant_param_idx];
      } else {
        raw += bias[j];
        raw_f = raw;
      }
    } else {
      raw_f = raw;
    }

    float ab = raw_f * C_multiplier[quant_param_idx];
    long rounded = std::lrintf(ab) + C_zero_point;

    C_uint8[j] = std::max(
        FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
        std::min(255l, rounded));
  }
}

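// Returns the factor pair of n whose members are closest to each other
// (i.e., closest to sqrt(n)), smaller factor first; e.g.,
// closest_factors_(12) == {3, 4} and closest_factors_(7) == {1, 7}.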
static inline std::pair<int, int> closest_factors_(int n) {
  int a = static_cast<int>(std::sqrt(n));
  while (n % a != 0) {
    a--;
  }
  return {a, n / a}; // a <= n / a
}

} // namespace fbgemm