1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/FbgemmSparse.h" |
9 | #include "fbgemm/Utils.h" |
10 | #include "fbgemm/spmmUtilsAvx2.h" |
11 | |
12 | #if defined(__x86_64__) || defined(__i386__) || \ |
13 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
14 | #include <immintrin.h> |
15 | #endif |
16 | #include <cassert> |
17 | |
18 | namespace fbgemm { |
19 | |
20 | namespace internal { |
21 | |
// Sums all eight 32-bit lanes of `a` and returns the scalar total.
// Wraparound int32 addition is associative, so any reduction order
// yields the same result.
static inline int32_t horizontal_add(__m256i a) {
  // Fold the two 128-bit lanes together, then reduce 4 -> 2 -> 1 lanes
  // with shuffles.
  __m128i lane_sum = _mm_add_epi32(
      _mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
  __m128i pair_sum = _mm_add_epi32(
      lane_sum, _mm_shuffle_epi32(lane_sum, _MM_SHUFFLE(1, 0, 3, 2)));
  __m128i total = _mm_add_epi32(
      pair_sum, _mm_shuffle_epi32(pair_sum, _MM_SHUFFLE(2, 3, 0, 1)));
  return _mm_cvtsi128_si32(total);
}
29 | |
30 | template < |
31 | bool FUSE_RELU, |
32 | bool ACT_ZP_0, // is activation zero point 0? |
33 | bool HAS_BIAS, |
34 | QuantizationGranularity Q_GRAN> |
35 | static inline void requantizeForMV( |
36 | uint8_t* dst, |
37 | int32_t* src, |
38 | int len, |
39 | trRequantizationParams_t& rParams) { |
40 | constexpr int VLEN_INT32 = 16; |
41 | __m512i C_zero_point_epi8_v = _mm512_set1_epi8(rParams.C_zero_point); |
42 | __m512i C_zero_point_epi32_v = _mm512_set1_epi32(rParams.C_zero_point); |
43 | // clang-format off |
44 | __m512i permute_mask_v = _mm512_set_epi32( |
45 | 0x0F, 0x0B, 0x07, 0x03, |
46 | 0x0E, 0x0A, 0x06, 0x02, |
47 | 0x0D, 0x09, 0x05, 0x01, |
48 | 0x0C, 0x08, 0x04, 0x00); |
49 | // clang-format on |
50 | int i = 0; |
51 | for (; i < len / VLEN_INT32 * VLEN_INT32; i += VLEN_INT32) { |
52 | __m512i x_v = _mm512_loadu_si512(src + i); |
53 | if (!ACT_ZP_0) { |
54 | __m512i weight_row_offset_v = |
55 | _mm512_loadu_si512(rParams.weight_row_offsets + i); |
56 | __m512i act_zero_point_v = _mm512_set1_epi32(rParams.act_zero_point); |
57 | x_v = _mm512_sub_epi32( |
58 | x_v, _mm512_mullo_epi32(act_zero_point_v, weight_row_offset_v)); |
59 | } |
60 | __m512 act_times_w_scale_v; |
61 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
62 | act_times_w_scale_v = _mm512_loadu_ps(rParams.act_times_w_scale + i); |
63 | } else { |
64 | act_times_w_scale_v = _mm512_set1_ps(rParams.act_times_w_scale[0]); |
65 | } |
66 | __m512 c_scale_v = _mm512_set1_ps(rParams.C_scale); |
67 | __m512 act_times_w_div_c_v = _mm512_div_ps(act_times_w_scale_v, c_scale_v); |
68 | |
69 | __m512 xf_v; |
70 | if (HAS_BIAS) { |
71 | __m512 bias_v = _mm512_loadu_ps(rParams.bias + i); |
72 | bias_v = _mm512_div_ps(bias_v, act_times_w_scale_v); |
73 | xf_v = _mm512_add_ps(_mm512_cvtepi32_ps(x_v), bias_v); |
74 | } else { |
75 | xf_v = _mm512_cvtepi32_ps(x_v); |
76 | } |
77 | |
78 | __m512 x_scaled_v = _mm512_mul_ps(xf_v, act_times_w_div_c_v); |
79 | __m512i x_rounded_v = _mm512_cvtps_epi32(x_scaled_v); |
80 | __m512i x_added_v = _mm512_add_epi32(x_rounded_v, C_zero_point_epi32_v); |
81 | |
82 | __m512i x_clamped_v = _mm512_packs_epi32(x_added_v, _mm512_setzero_si512()); |
83 | x_clamped_v = _mm512_packus_epi16(x_clamped_v, _mm512_setzero_si512()); |
84 | if (FUSE_RELU) { |
85 | x_clamped_v = _mm512_max_epu8(C_zero_point_epi8_v, x_clamped_v); |
86 | } |
87 | x_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, x_clamped_v); |
88 | |
89 | _mm_store_si128( |
90 | reinterpret_cast<__m128i*>(dst + i), |
91 | _mm512_castsi512_si128(x_clamped_v)); |
92 | } |
93 | int rem_int32 = len - i; |
94 | if (rem_int32 > 0) { |
95 | __mmask64 mask_int8_v = (((long long)1) << rem_int32) - 1; |
96 | __mmask16 mask_int32_v = (((long long)1) << rem_int32) - 1; |
97 | __m512i x_v = _mm512_maskz_loadu_epi32(mask_int32_v, src + i); |
98 | |
99 | if (!ACT_ZP_0) { |
100 | __m512i weight_row_offset_v = _mm512_maskz_loadu_epi32( |
101 | mask_int32_v, rParams.weight_row_offsets + i); |
102 | __m512i act_zero_point_v = _mm512_set1_epi32(rParams.act_zero_point); |
103 | x_v = _mm512_sub_epi32( |
104 | x_v, _mm512_mullo_epi32(act_zero_point_v, weight_row_offset_v)); |
105 | } |
106 | __m512 act_times_w_scale_v; |
107 | if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { |
108 | act_times_w_scale_v = |
109 | _mm512_maskz_loadu_ps(mask_int32_v, rParams.act_times_w_scale + i); |
110 | } else { |
111 | act_times_w_scale_v = _mm512_set1_ps(rParams.act_times_w_scale[0]); |
112 | } |
113 | __m512 c_scale_v = _mm512_set1_ps(rParams.C_scale); |
114 | __m512 act_times_w_div_c_v = _mm512_div_ps(act_times_w_scale_v, c_scale_v); |
115 | |
116 | __m512 xf_v; |
117 | if (HAS_BIAS) { |
118 | __m512 bias_v = _mm512_maskz_loadu_ps(mask_int32_v, rParams.bias + i); |
119 | bias_v = _mm512_div_ps(bias_v, act_times_w_scale_v); |
120 | xf_v = _mm512_add_ps(_mm512_cvtepi32_ps(x_v), bias_v); |
121 | } else { |
122 | xf_v = _mm512_cvtepi32_ps(x_v); |
123 | } |
124 | |
125 | __m512 x_scaled_v = _mm512_mul_ps(xf_v, act_times_w_div_c_v); |
126 | __m512i x_rounded_v = _mm512_cvtps_epi32(x_scaled_v); |
127 | __m512i x_added_v = _mm512_add_epi32(x_rounded_v, C_zero_point_epi32_v); |
128 | |
129 | __m512i x_clamped_v = _mm512_packs_epi32(x_added_v, _mm512_setzero_si512()); |
130 | x_clamped_v = _mm512_packus_epi16(x_clamped_v, _mm512_setzero_si512()); |
131 | if (FUSE_RELU) { |
132 | x_clamped_v = _mm512_max_epu8(C_zero_point_epi8_v, x_clamped_v); |
133 | } |
134 | x_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, x_clamped_v); |
135 | |
136 | _mm512_mask_storeu_epi8(dst + i, mask_int8_v, x_clamped_v); |
137 | } |
138 | } |
139 | |
// matrix-vector product
// i.e., produces same results as SparseDenseInt8MMAvx512 with N == 1
//
// Computes accum ? C += A * B : C = A * B, where A is the int8 sparse
// matrix in block-CSR form (`bcsr`, CB = 4 columns per block) and B is a
// dense uint8 vector (ldb must be 1). The raw int32 products are written
// to C_i32 and then requantized into C_u8 using rParams.
// Single-threaded for now: threads other than 0 return immediately.
template <bool FUSE_RELU, QuantizationGranularity Q_GRAN>
void SparseDenseInt8MVAvx512(
    const std::unique_ptr<BCSRMatrix<>>& bcsr,
    const uint8_t* B,
    int ldb,
    int32_t* C_i32,
    uint8_t* C_u8,
    trRequantizationParams_t& rParams,
    bool accum,
    int thread_id,
    int num_threads) {
  (void)num_threads; // Suppress unused variable warning
  // Calculates accum ? C += A * B : C = A * B
  constexpr int VLEN_INT32 = 16;

  constexpr int block_size = BCSRMatrix<>::CB;
  constexpr int colTileSize = BCSRMatrix<>::COLTILE;

  // all work is done by thread 0 for now
  assert(num_threads > 0 && "Numbers of threads should be > 0");
  if (thread_id > 0) {
    return;
  }

  assert(ldb == 1 && "ldb should be 1");
  __m512i one_16bit_v = _mm512_set1_epi16(1);
  // Number of columns in the sparse matrix A
  int K = bcsr->C;
  int M = bcsr->R;
  assert(K % 4 == 0 && "K should be multiple of 4");
  assert((K > 0) && "K needs to be positive");
  // The sparse matrix is stored in column tiles of COLTILE columns each;
  // each tile has its own set of M row pointers.
  int kTiles = (K + colTileSize - 1) / colTileSize;
  const int* row_ptr = bcsr->rowBPtr.data();
  const int* col_idx = bcsr->colBIdx.data();
  const int8_t* values = bcsr->values.data();
  for (int kt = 0; kt < kTiles; ++kt) {
    const int* cur_row_ptr = row_ptr + kt * M;
    const uint8_t* cur_B = B + kt * colTileSize * ldb;
    // TODO: unroll this loop?
    for (int i = 0; i < M; ++i) {
      __m512i res = _mm512_set1_epi32(0);
      int r = cur_row_ptr[i];
      // Process 16 nonzero blocks (of 4 int8 values each) per iteration.
      int r_end_aligned = cur_row_ptr[i] +
          (cur_row_ptr[i + 1] - cur_row_ptr[i]) / VLEN_INT32 * VLEN_INT32;
      for (; r < r_end_aligned; r += VLEN_INT32) {
        // a_v: 16 blocks x 4 int8 weights; b_v: the matching 4-byte groups
        // of B gathered via the blocks' column indices (scale = block_size).
        __m512i a_v = _mm512_loadu_si512(values + r * block_size);
        __m512i b_idx = _mm512_loadu_si512(col_idx + r);
        __m512i b_v = _mm512_i32gather_epi32(
            b_idx, reinterpret_cast<const int32_t*>(cur_B), block_size);
        // u8 * s8 -> pairwise i16 sums, then i16 pairs * 1 -> i32 sums,
        // i.e. a 4-element dot product per 32-bit lane.
        __m512i c_i16_v = _mm512_maddubs_epi16(b_v, a_v);
        __m512i c_i32_v = _mm512_madd_epi16(one_16bit_v, c_i16_v);
        res = _mm512_add_epi32(res, c_i32_v);
      }

      // Masked remainder: fewer than 16 blocks left in this row.
      int rem = cur_row_ptr[i + 1] - r;
      if (rem > 0) {
        __mmask16 mask_int32_v = (((long long)1) << (rem)) - 1;
        __m512i a_v =
            _mm512_maskz_loadu_epi32(mask_int32_v, values + r * block_size);
        __m512i b_idx = _mm512_maskz_loadu_epi32(mask_int32_v, col_idx + r);
        __m512i b_v = _mm512_i32gather_epi32(
            b_idx, reinterpret_cast<const int32_t*>(cur_B), block_size);
        __m512i c_i16_v = _mm512_maddubs_epi16(b_v, a_v);
        __m512i c_i32_v = _mm512_madd_epi16(one_16bit_v, c_i16_v);
        res = _mm512_add_epi32(res, c_i32_v);
      }
      // Horizontal reduce
      // _mm512_reduce_add_epi32 is only available for gcc version > 7
#if __GNUC__ >= 7
      int32_t res_i32 = _mm512_reduce_add_epi32(res);
#else
      __m256i low = _mm512_castsi512_si256(res);
      __m256i high = _mm512_extracti64x4_epi64(res, 1);
      int32_t res_i32 = horizontal_add(_mm256_add_epi32(low, high));
#endif

      // store the results
      // Tiles after the first must accumulate into the partial sums
      // produced by earlier tiles, regardless of `accum`.
      if (accum || kt > 0) {
        C_i32[i] += res_i32;
      } else {
        C_i32[i] = res_i32;
      }
    }
  }
  // Requantize the int32 results to uint8, dispatching on whether the
  // activation zero point is zero and whether a bias is present so the
  // inner loops can drop the corresponding work at compile time.
  if (rParams.bias == nullptr) {
    if (rParams.act_zero_point) {
      requantizeForMV<FUSE_RELU, false, false, Q_GRAN>(C_u8, C_i32, M, rParams);
    } else {
      requantizeForMV<FUSE_RELU, true, false, Q_GRAN>(C_u8, C_i32, M, rParams);
    }
  } else {
    if (rParams.act_zero_point) {
      requantizeForMV<FUSE_RELU, false, true, Q_GRAN>(C_u8, C_i32, M, rParams);
    } else {
      requantizeForMV<FUSE_RELU, true, true, Q_GRAN>(C_u8, C_i32, M, rParams);
    }
  }
}
240 | |
// Explicit template instantiations for every supported
// (FUSE_RELU, QuantizationGranularity) combination, so the definition
// above can stay in this translation unit while callers link against it.
#define CREATE_INSTANCE(FUSE_RELU, QGRAN)                  \
  template void SparseDenseInt8MVAvx512<FUSE_RELU, QGRAN>( \
      const std::unique_ptr<BCSRMatrix<>>& bcsr,           \
      const uint8_t* B,                                    \
      int ldb,                                             \
      int32_t* C_i32,                                      \
      uint8_t* C_u8,                                       \
      trRequantizationParams_t& rParams,                   \
      bool accum,                                          \
      int thread_id,                                       \
      int num_threads);
CREATE_INSTANCE(true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, QuantizationGranularity::OUT_CHANNEL)
#undef CREATE_INSTANCE
257 | |
258 | } // namespace internal |
259 | } // namespace fbgemm |
260 | |