/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include "fbgemm/spmmUtilsAvx2.h"
#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif
#include <cassert> // for assert
#include "./MaskAvx2.h"

namespace fbgemm {

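// Requantizes a block of int32_t intermediate results into uint8_t outputs
// using the scales, zero points, and offsets carried in
// trRequantizationParams_t. `block` selects the submatrix to process, and
// `ld_in` / `ld_out` are the row strides of the input and output buffers.
// The boolean template parameters mirror runtime properties (asserted
// below) so that the compiler can prune the corresponding corrections:
// ReLU fusion, activation/weight zero-point handling, and bias addition.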
template <
    bool FUSE_RELU,
    bool ACT_SYMMETRIC,
    bool WEIGHT_SYMMETRIC,
    bool HAS_BIAS,
    QuantizationGranularity Q_GRAN>
FBGEMM_API void trRequantizeOpt(
    uint8_t* out,
    const int32_t* inp,
    const block_type_t& block,
    int ld_out,
    int ld_in,
    const trRequantizationParams_t& r) {
  assert(
      (Q_GRAN != QuantizationGranularity::GROUP) &&
      "GROUP Granularity is not supported");

  // Broadcasted act_times_w_scale / C_scale
  __m256 act_times_w_div_c_v;
  if (Q_GRAN != QuantizationGranularity::OUT_CHANNEL) {
    act_times_w_div_c_v = _mm256_set1_ps(r.act_times_w_scale[0] / r.C_scale);
  }

  __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
  __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));

  assert(
      (ACT_SYMMETRIC == (r.act_zero_point == 0)) &&
      "ACT_SYMMETRIC == true if and only if act_zero_point == 0");
  assert(
      (WEIGHT_SYMMETRIC ==
       ((Q_GRAN == QuantizationGranularity::TENSOR &&
         r.weight_zero_points[0] == 0) ||
        r.act_col_offsets == nullptr)) &&
      "WEIGHT_SYMMETRIC == true if and only if weight_zero_point == 0 "
      "or r.act_col_offsets == nullptr");
  assert(
      (HAS_BIAS == (r.bias != nullptr)) &&
      "HAS_BIAS == true if and only if bias != nullptr");

  __m256i C_zero_point_epi16_v = _mm256_set1_epi16(r.C_zero_point);
  __m256i C_zero_point_epi8_v = _mm256_set1_epi8(r.C_zero_point);

  __m256i permute_mask_v =
      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
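  // _mm256_set_epi32 lists elements from lane 7 down to lane 0, so this
  // mask selects dwords {0, 4, 1, 5, 2, 6, 3, 7}, i.e. it interleaves the
  // two 128-bit halves of a vector. It is used below to undo the lane-wise
  // ordering produced by the AVX2 pack instructions.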

  constexpr int VLEN = 8;
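  // The first j loop below is unrolled 4x and emits VLEN * 4 = 32 uint8_t
  // outputs per iteration; a second loop handles whole vectors of VLEN
  // columns, and a masked epilogue handles any remaining columns.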
  for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
    // Scale weight_row_offset with act_zero_point
    int32_t row_offset = 0;
    if (!ACT_SYMMETRIC) {
      row_offset = r.act_zero_point * r.weight_row_offsets[i];
    }

    __m256i row_offset_v = _mm256_set1_epi32(row_offset);

    int weight_zeropoint_idx = 0;
    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
      weight_zeropoint_idx = i;
    }
    __m256 bias_v;
    if (HAS_BIAS) {
      float bias = r.bias[i] / r.act_times_w_scale[weight_zeropoint_idx];
      bias_v = _mm256_set1_ps(bias);
    }

    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
      float act_times_w_div_c =
          r.act_times_w_scale[weight_zeropoint_idx] / r.C_scale;
      act_times_w_div_c_v = _mm256_set1_ps(act_times_w_div_c);
    }

    __m256i weight_zeropoint_v =
        _mm256_set1_epi32(r.weight_zero_points[weight_zeropoint_idx]);

    int j = block.col_start;
    for (; j < block.col_start + (block.col_size / (VLEN * 4) * (VLEN * 4));
         j += (VLEN * 4)) {
      __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start)));
      __m256i y_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start) +
          1 * VLEN));
      __m256i z_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start) +
          2 * VLEN));
      __m256i w_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start) +
          3 * VLEN));

      if (!ACT_SYMMETRIC) {
        x_v = _mm256_sub_epi32(x_v, row_offset_v);
        y_v = _mm256_sub_epi32(y_v, row_offset_v);
        z_v = _mm256_sub_epi32(z_v, row_offset_v);
        w_v = _mm256_sub_epi32(w_v, row_offset_v);
      }
      if (!WEIGHT_SYMMETRIC) {
        __m256i col_offset_v = _mm256_mullo_epi32(
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                r.act_col_offsets + j - block.col_start)),
            weight_zeropoint_v);
        x_v = _mm256_sub_epi32(x_v, col_offset_v);

        col_offset_v = _mm256_mullo_epi32(
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                r.act_col_offsets + VLEN + j - block.col_start)),
            weight_zeropoint_v);
        y_v = _mm256_sub_epi32(y_v, col_offset_v);

        col_offset_v = _mm256_mullo_epi32(
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                r.act_col_offsets + 2 * VLEN + j - block.col_start)),
            weight_zeropoint_v);
        z_v = _mm256_sub_epi32(z_v, col_offset_v);

        col_offset_v = _mm256_mullo_epi32(
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                r.act_col_offsets + 3 * VLEN + j - block.col_start)),
            weight_zeropoint_v);
        w_v = _mm256_sub_epi32(w_v, col_offset_v);
      }

      /*
       * Convert the int32_t input to FP32 and multiply by the FP32 scale.
       * Both operations involve statistically unbiased rounding (with the
       * default MXCSR rounding mode):
       * - Large int32_t values can't be exactly represented as FP32. The
       *   CVTDQ2PS instruction on x86 rounds them to the nearest FP32
       *   value with ties to even (assuming the default MXCSR rounding
       *   mode).
       * - The product of two FP32 values is generally not exactly
       *   representable as an FP32 value, and is rounded to the nearest
       *   FP32 value with ties to even under the default MXCSR rounding
       *   mode.
       */
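      /*
       * For example, INT32_MAX = 2147483647 falls between the representable
       * FP32 values 2147483520 (2**31 - 2**7) and 2147483648 (2**31), so
       * CVTDQ2PS converts it to the nearer of the two, 2147483648.0f.
       */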
      __m256 xf_v, yf_v, zf_v, wf_v;
      if (HAS_BIAS) {
        xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), bias_v);
        yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), bias_v);
        zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), bias_v);
        wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), bias_v);
      } else {
        xf_v = _mm256_cvtepi32_ps(x_v);
        yf_v = _mm256_cvtepi32_ps(y_v);
        zf_v = _mm256_cvtepi32_ps(z_v);
        wf_v = _mm256_cvtepi32_ps(w_v);
      }

      __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;

      x_scaled_v = _mm256_mul_ps(xf_v, act_times_w_div_c_v);
      y_scaled_v = _mm256_mul_ps(yf_v, act_times_w_div_c_v);
      z_scaled_v = _mm256_mul_ps(zf_v, act_times_w_div_c_v);
      w_scaled_v = _mm256_mul_ps(wf_v, act_times_w_div_c_v);

      /*
       * Convert the scaled FP32 result to int32_t using the CVTPS2DQ
       * instruction, which rounds to the nearest integer with ties to even
       * (assuming the default MXCSR rounding mode). However, when the
       * conversion overflows, it produces INT32_MIN as the result. For
       * large positive inputs the result of the conversion can thus become
       * negative, which would corrupt the final requantization result.
       * Note that on x86 we have e.g. int32_t(float(INT32_MAX)) ==
       * INT32_MIN! This happens because float(INT32_MAX) rounds up to
       * 2**31, which overflows int32_t when converted back to an integer.
       *
       * Thankfully, we can prove that overflow never happens in this
       * requantization scheme. The largest positive input is INT32_MAX
       * (2**31 - 1), which becomes 2**31 when converted to float. The
       * largest scale value is 0x1.FFFFFEp-1. Their product is 2147483520
       * (compare to INT32_MAX = 2147483647), which fits into int32_t
       * without overflow.
       */
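      /*
       * Concretely: 2**31 * 0x1.FFFFFEp-1 = 2**31 * (1 - 2**-24)
       *                                   = 2**31 - 2**7
       *                                   = 2147483520,
       * and both factors and the product are exactly representable in FP32,
       * so the conversion back to int32_t cannot overflow.
       */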
      __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
      __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
      __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
      __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);

      /*
       * Standard final sequence on x86 AVX2:
       * - Pack to int16_t and saturate
       * - Add zero point
       * - Pack to uint8_t and saturate
       * - Clamp between qmin and qmax
       */
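      /*
       * Note that _mm256_packs_epi32 and _mm256_packus_epi16 pack within
       * each 128-bit lane independently, which is why the cross-lane
       * permute below is needed afterwards.
       */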
      __m256i xy_packed_v = _mm256_adds_epi16(
          _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
      __m256i zw_packed_v = _mm256_adds_epi16(
          _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
      __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
      __m256i xyzw_clamped_v = _mm256_max_epu8(
          FUSE_RELU ? C_zero_point_epi8_v : min_v,
          _mm256_min_epu8(xyzw_packed_v, max_v));

      /*
       * xyzw_clamped_v has results in the following layout so we need to
       * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7
       */
      xyzw_clamped_v =
          _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);

      _mm256_storeu_si256(
          reinterpret_cast<__m256i*>(out + i * ld_out + j), xyzw_clamped_v);
    } // j loop vectorized and unrolled 4x

    for (; j < block.col_start + (block.col_size / VLEN * VLEN); j += VLEN) {
      __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
          inp + (i - block.row_start) * ld_in + (j - block.col_start)));

      if (!ACT_SYMMETRIC) {
        x_v = _mm256_sub_epi32(x_v, row_offset_v);
      }
      if (!WEIGHT_SYMMETRIC) {
        __m256i col_offset_v = _mm256_mullo_epi32(
            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                r.act_col_offsets + j - block.col_start)),
            weight_zeropoint_v);
        x_v = _mm256_sub_epi32(x_v, col_offset_v);
      }
      __m256 xf_v;
      if (HAS_BIAS) {
        xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), bias_v);
      } else {
        xf_v = _mm256_cvtepi32_ps(x_v);
      }

      __m256 x_scaled_v = _mm256_mul_ps(xf_v, act_times_w_div_c_v);
      __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);

      __m256i x_packed_v = _mm256_adds_epi16(
          _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
          C_zero_point_epi16_v);
      x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
      __m256i x_clamped_v = _mm256_max_epu8(
          FUSE_RELU ? C_zero_point_epi8_v : min_v,
          _mm256_min_epu8(x_packed_v, max_v));

      /*
       * x_clamped_v has results in the following layout so we need to
       * permute: x0-3 garbage0-11 x4-7 garbage12-23
       */
      x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);

      _mm_storel_epi64(
          reinterpret_cast<__m128i*>(out + i * ld_out + j),
          _mm256_castsi256_si128(x_clamped_v));
    } // j loop vectorized

    int remainder = block.col_start + block.col_size - j;
    if (remainder > 0) {
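      // avx2_ps_or_epi32_masks[remainder] enables the first `remainder`
      // 32-bit lanes, so the masked loads below touch only the valid
      // trailing columns.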
      __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
          internal::avx2_ps_or_epi32_masks[remainder]));

      __m256i x_v = _mm256_maskload_epi32(
          inp + (i - block.row_start) * ld_in + (j - block.col_start), mask_v);

      if (!ACT_SYMMETRIC) {
        x_v = _mm256_sub_epi32(x_v, row_offset_v);
      }
      if (!WEIGHT_SYMMETRIC) {
        __m256i col_offset_v = _mm256_mullo_epi32(
            _mm256_maskload_epi32(
                r.act_col_offsets + j - block.col_start, mask_v),
            weight_zeropoint_v);
        x_v = _mm256_sub_epi32(x_v, col_offset_v);
      }

      __m256 xf_v;
      if (HAS_BIAS) {
        xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), bias_v);
      } else {
        xf_v = _mm256_cvtepi32_ps(x_v);
      }
      __m256 x_scaled_v = _mm256_mul_ps(xf_v, act_times_w_div_c_v);
      __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);

      __m256i x_packed_v = _mm256_adds_epi16(
          _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
          C_zero_point_epi16_v);
      x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
      __m256i x_clamped_v = _mm256_max_epu8(
          FUSE_RELU ? C_zero_point_epi8_v : min_v,
          _mm256_min_epu8(x_packed_v, max_v));

      /*
       * x_clamped_v has results in the following layout so we need to
       * permute: x0-3 garbage0-11 x4-7 garbage12-23
       */
      x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);

      alignas(64) uint8_t x_clamped_buffer[32];
      _mm256_store_si256(
          reinterpret_cast<__m256i*>(x_clamped_buffer), x_clamped_v);
      for (int k = 0; k < remainder; ++k) {
        out[i * ld_out + j + k] = x_clamped_buffer[k];
      }
    } // j loop remainder
  } // i loop
}

#define CREATE_INSTANCE(                                         \
    FUSE_RELU, ACT_SYMMETRIC, WEIGHT_SYMMETRIC, HAS_BIAS, QGRAN) \
  template FBGEMM_API void trRequantizeOpt<                      \
      FUSE_RELU,                                                 \
      ACT_SYMMETRIC,                                             \
      WEIGHT_SYMMETRIC,                                          \
      HAS_BIAS,                                                  \
      QGRAN>(                                                    \
      uint8_t * out,                                             \
      const int32_t* inp,                                        \
      const block_type_t& block,                                 \
      int ld_out,                                                \
      int ld_in,                                                 \
      const trRequantizationParams_t& r);
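// Explicit instantiations for every combination of the four boolean
// template parameters, at TENSOR and OUT_CHANNEL granularities.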
CREATE_INSTANCE(true, true, true, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, true, true, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, true, false, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, true, false, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, false, true, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, false, true, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, false, false, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, false, false, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, true, true, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, true, true, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, true, false, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, true, false, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, false, true, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, false, true, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, false, false, true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, false, false, false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, true, true, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(true, true, true, false, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(true, true, false, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(true, true, false, false, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(true, false, true, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(true, false, true, false, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(true, false, false, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(true, false, false, false, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, true, true, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, true, true, false, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, true, false, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, true, false, false, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, false, true, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, false, true, false, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, false, false, true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(
    false,
    false,
    false,
    false,
    QuantizationGranularity::OUT_CHANNEL)
#undef CREATE_INSTANCE

} // namespace fbgemm