/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmSparse.h"
#include "fbgemm/Utils.h"
#include "fbgemm/spmmUtilsAvx2.h"

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif
#include <cassert>

namespace fbgemm {

namespace internal {

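// Horizontally sums the eight 32-bit lanes of an AVX2 vector into a single
// int32 result.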
static inline int32_t horizontal_add(__m256i a) {
  __m256i t1 = _mm256_hadd_epi32(a, a);
  __m256i t2 = _mm256_hadd_epi32(t1, t1);
  __m128i t3 = _mm256_extracti128_si256(t2, 1);
  __m128i t4 = _mm_add_epi32(_mm256_castsi256_si128(t2), t3);
  return _mm_cvtsi128_si32(t4);
}

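// Requantizes `len` int32 accumulators in `src` into uint8 outputs in `dst`:
// optionally compensates for a nonzero activation zero point using the
// per-row weight offsets, adds the float bias (divided by act_scale *
// weight_scale so it lives in the accumulator domain), scales by
// act_scale * weight_scale / C_scale (per tensor or per output channel),
// rounds, adds the output zero point, saturates to uint8, and optionally
// fuses ReLU.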
template <
    bool FUSE_RELU,
    bool ACT_ZP_0, // is activation zero point 0?
    bool HAS_BIAS,
    QuantizationGranularity Q_GRAN>
static inline void requantizeForMV(
    uint8_t* dst,
    int32_t* src,
    int len,
    trRequantizationParams_t& rParams) {
  constexpr int VLEN_INT32 = 16;
  __m512i C_zero_point_epi8_v = _mm512_set1_epi8(rParams.C_zero_point);
  __m512i C_zero_point_epi32_v = _mm512_set1_epi32(rParams.C_zero_point);
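  // The epi32 -> epi16 -> epi8 packs below operate within 128-bit lanes, so
  // the valid bytes end up in the low dword of each lane. This cross-lane
  // permute gathers dwords {0, 4, 8, 12} into the bottom 128 bits to restore
  // the original element order.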
  // clang-format off
  __m512i permute_mask_v = _mm512_set_epi32(
      0x0F, 0x0B, 0x07, 0x03,
      0x0E, 0x0A, 0x06, 0x02,
      0x0D, 0x09, 0x05, 0x01,
      0x0C, 0x08, 0x04, 0x00);
  // clang-format on
  int i = 0;
  for (; i < len / VLEN_INT32 * VLEN_INT32; i += VLEN_INT32) {
    __m512i x_v = _mm512_loadu_si512(src + i);
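    // Zero-point compensation: subtract act_zero_point * row_offset(W) from
    // the raw accumulator (skipped entirely when the activation zero point is
    // known to be 0).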
    if (!ACT_ZP_0) {
      __m512i weight_row_offset_v =
          _mm512_loadu_si512(rParams.weight_row_offsets + i);
      __m512i act_zero_point_v = _mm512_set1_epi32(rParams.act_zero_point);
      x_v = _mm512_sub_epi32(
          x_v, _mm512_mullo_epi32(act_zero_point_v, weight_row_offset_v));
    }
    __m512 act_times_w_scale_v;
    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
      act_times_w_scale_v = _mm512_loadu_ps(rParams.act_times_w_scale + i);
    } else {
      act_times_w_scale_v = _mm512_set1_ps(rParams.act_times_w_scale[0]);
    }
    __m512 c_scale_v = _mm512_set1_ps(rParams.C_scale);
    __m512 act_times_w_div_c_v = _mm512_div_ps(act_times_w_scale_v, c_scale_v);

    __m512 xf_v;
    if (HAS_BIAS) {
      __m512 bias_v = _mm512_loadu_ps(rParams.bias + i);
      bias_v = _mm512_div_ps(bias_v, act_times_w_scale_v);
      xf_v = _mm512_add_ps(_mm512_cvtepi32_ps(x_v), bias_v);
    } else {
      xf_v = _mm512_cvtepi32_ps(x_v);
    }

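    // Requantize: y = round(x * act_scale * w_scale / C_scale) + C_zero_point.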
    __m512 x_scaled_v = _mm512_mul_ps(xf_v, act_times_w_div_c_v);
    __m512i x_rounded_v = _mm512_cvtps_epi32(x_scaled_v);
    __m512i x_added_v = _mm512_add_epi32(x_rounded_v, C_zero_point_epi32_v);

    // Saturate to uint8 and restore element order.
    __m512i x_clamped_v = _mm512_packs_epi32(x_added_v, _mm512_setzero_si512());
    x_clamped_v = _mm512_packus_epi16(x_clamped_v, _mm512_setzero_si512());
    if (FUSE_RELU) {
      x_clamped_v = _mm512_max_epu8(C_zero_point_epi8_v, x_clamped_v);
    }
    x_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, x_clamped_v);

    // dst is not guaranteed to be 16-byte aligned, so use an unaligned store.
    _mm_storeu_si128(
        reinterpret_cast<__m128i*>(dst + i),
        _mm512_castsi512_si128(x_clamped_v));
  }
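  // Handle the remaining (len % 16) outputs with masked loads and a masked
  // byte store.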
  int rem_int32 = len - i;
  if (rem_int32 > 0) {
    __mmask64 mask_int8_v = (((long long)1) << rem_int32) - 1;
    __mmask16 mask_int32_v = (((long long)1) << rem_int32) - 1;
    __m512i x_v = _mm512_maskz_loadu_epi32(mask_int32_v, src + i);

    if (!ACT_ZP_0) {
      __m512i weight_row_offset_v = _mm512_maskz_loadu_epi32(
          mask_int32_v, rParams.weight_row_offsets + i);
      __m512i act_zero_point_v = _mm512_set1_epi32(rParams.act_zero_point);
      x_v = _mm512_sub_epi32(
          x_v, _mm512_mullo_epi32(act_zero_point_v, weight_row_offset_v));
    }
    __m512 act_times_w_scale_v;
    if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
      act_times_w_scale_v =
          _mm512_maskz_loadu_ps(mask_int32_v, rParams.act_times_w_scale + i);
    } else {
      act_times_w_scale_v = _mm512_set1_ps(rParams.act_times_w_scale[0]);
    }
    __m512 c_scale_v = _mm512_set1_ps(rParams.C_scale);
    __m512 act_times_w_div_c_v = _mm512_div_ps(act_times_w_scale_v, c_scale_v);

    __m512 xf_v;
    if (HAS_BIAS) {
      __m512 bias_v = _mm512_maskz_loadu_ps(mask_int32_v, rParams.bias + i);
      bias_v = _mm512_div_ps(bias_v, act_times_w_scale_v);
      xf_v = _mm512_add_ps(_mm512_cvtepi32_ps(x_v), bias_v);
    } else {
      xf_v = _mm512_cvtepi32_ps(x_v);
    }

    __m512 x_scaled_v = _mm512_mul_ps(xf_v, act_times_w_div_c_v);
    __m512i x_rounded_v = _mm512_cvtps_epi32(x_scaled_v);
    __m512i x_added_v = _mm512_add_epi32(x_rounded_v, C_zero_point_epi32_v);

    __m512i x_clamped_v = _mm512_packs_epi32(x_added_v, _mm512_setzero_si512());
    x_clamped_v = _mm512_packus_epi16(x_clamped_v, _mm512_setzero_si512());
    if (FUSE_RELU) {
      x_clamped_v = _mm512_max_epu8(C_zero_point_epi8_v, x_clamped_v);
    }
    x_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, x_clamped_v);

    _mm512_mask_storeu_epi8(dst + i, mask_int8_v, x_clamped_v);
  }
}

// matrix-vector product
// i.e., produces the same results as SparseDenseInt8MMAvx512 with N == 1
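// A is stored in block-CSR form (BCSRMatrix<>): each block packs
// BCSRMatrix<>::CB consecutive int8 weights from one row, and the columns are
// split into tiles of width BCSRMatrix<>::COLTILE, with each tile carrying its
// own section of the row-pointer array. Every column tile contributes a
// partial product that is accumulated into C_i32 before requantization.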
template <bool FUSE_RELU, QuantizationGranularity Q_GRAN>
void SparseDenseInt8MVAvx512(
    const std::unique_ptr<BCSRMatrix<>>& bcsr,
    const uint8_t* B,
    int ldb,
    int32_t* C_i32,
    uint8_t* C_u8,
    trRequantizationParams_t& rParams,
    bool accum,
    int thread_id,
    int num_threads) {
  // num_threads is only used in the assert below; cast to void to avoid an
  // unused-variable warning in release builds.
  (void)num_threads;
  // Calculates accum ? C += A * B : C = A * B
  constexpr int VLEN_INT32 = 16;

  constexpr int block_size = BCSRMatrix<>::CB;
  constexpr int colTileSize = BCSRMatrix<>::COLTILE;

  // all work is done by thread 0 for now
  assert(num_threads > 0 && "Number of threads should be > 0");
  if (thread_id > 0) {
    return;
  }

  assert(ldb == 1 && "ldb should be 1");
  __m512i one_16bit_v = _mm512_set1_epi16(1);
  // Number of columns in the sparse matrix A
  int K = bcsr->C;
  int M = bcsr->R;
  assert(K % 4 == 0 && "K should be a multiple of 4");
  assert((K > 0) && "K needs to be positive");
  int kTiles = (K + colTileSize - 1) / colTileSize;
  const int* row_ptr = bcsr->rowBPtr.data();
  const int* col_idx = bcsr->colBIdx.data();
  const int8_t* values = bcsr->values.data();
  for (int kt = 0; kt < kTiles; ++kt) {
    const int* cur_row_ptr = row_ptr + kt * M;
    const uint8_t* cur_B = B + kt * colTileSize * ldb;
    // TODO: unroll this loop?
    for (int i = 0; i < M; ++i) {
      __m512i res = _mm512_set1_epi32(0);
      int r = cur_row_ptr[i];
      int r_end_aligned = cur_row_ptr[i] +
          (cur_row_ptr[i + 1] - cur_row_ptr[i]) / VLEN_INT32 * VLEN_INT32;
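      // Process 16 BCSR blocks per iteration: load 16 * block_size int8
      // weights, gather the matching block_size-byte groups of activations
      // from the dense vector B (one 32-bit gather element per block), then
      // multiply u8 * s8 pairs and widen/accumulate into 32-bit lanes.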
      for (; r < r_end_aligned; r += VLEN_INT32) {
        __m512i a_v = _mm512_loadu_si512(values + r * block_size);
        __m512i b_idx = _mm512_loadu_si512(col_idx + r);
        __m512i b_v = _mm512_i32gather_epi32(
            b_idx, reinterpret_cast<const int32_t*>(cur_B), block_size);
        __m512i c_i16_v = _mm512_maddubs_epi16(b_v, a_v);
        __m512i c_i32_v = _mm512_madd_epi16(one_16bit_v, c_i16_v);
        res = _mm512_add_epi32(res, c_i32_v);
      }

      // Masked tail for the remaining (< 16) blocks of this row
      int rem = cur_row_ptr[i + 1] - r;
      if (rem > 0) {
        __mmask16 mask_int32_v = (((long long)1) << (rem)) - 1;
        __m512i a_v =
            _mm512_maskz_loadu_epi32(mask_int32_v, values + r * block_size);
        __m512i b_idx = _mm512_maskz_loadu_epi32(mask_int32_v, col_idx + r);
        __m512i b_v = _mm512_i32gather_epi32(
            b_idx, reinterpret_cast<const int32_t*>(cur_B), block_size);
        __m512i c_i16_v = _mm512_maddubs_epi16(b_v, a_v);
        __m512i c_i32_v = _mm512_madd_epi16(one_16bit_v, c_i16_v);
        res = _mm512_add_epi32(res, c_i32_v);
      }
      // Horizontal reduce
      // _mm512_reduce_add_epi32 requires gcc 7 or later
#if __GNUC__ >= 7
      int32_t res_i32 = _mm512_reduce_add_epi32(res);
#else
      __m256i low = _mm512_castsi512_si256(res);
      __m256i high = _mm512_extracti64x4_epi64(res, 1);
      int32_t res_i32 = horizontal_add(_mm256_add_epi32(low, high));
#endif

      // store the results
      if (accum || kt > 0) {
        C_i32[i] += res_i32;
      } else {
        C_i32[i] = res_i32;
      }
    }
  }
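  // Requantize the int32 accumulators into uint8 outputs, dispatching on
  // whether the activation zero point is zero and whether a bias is present.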
  if (rParams.bias == nullptr) {
    if (rParams.act_zero_point) {
      requantizeForMV<FUSE_RELU, false, false, Q_GRAN>(C_u8, C_i32, M, rParams);
    } else {
      requantizeForMV<FUSE_RELU, true, false, Q_GRAN>(C_u8, C_i32, M, rParams);
    }
  } else {
    if (rParams.act_zero_point) {
      requantizeForMV<FUSE_RELU, false, true, Q_GRAN>(C_u8, C_i32, M, rParams);
    } else {
      requantizeForMV<FUSE_RELU, true, true, Q_GRAN>(C_u8, C_i32, M, rParams);
    }
  }
}

#define CREATE_INSTANCE(FUSE_RELU, QGRAN)                  \
  template void SparseDenseInt8MVAvx512<FUSE_RELU, QGRAN>( \
      const std::unique_ptr<BCSRMatrix<>>& bcsr,           \
      const uint8_t* B,                                    \
      int ldb,                                             \
      int32_t* C_i32,                                      \
      uint8_t* C_u8,                                       \
      trRequantizationParams_t& rParams,                   \
      bool accum,                                          \
      int thread_id,                                       \
      int num_threads);
CREATE_INSTANCE(true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, QuantizationGranularity::OUT_CHANNEL)
#undef CREATE_INSTANCE

} // namespace internal
} // namespace fbgemm