/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmSparse.h"
#include "fbgemm/spmmUtilsAvx2.h"

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif
#include <algorithm> // for min and max
#include <cassert>
#include <cstring>
#include "./MaskAvx2.h"

namespace fbgemm {
namespace internal {
static inline __m256i permute_row(__m256i row) {
  // clang-format off
  __m256i ret = _mm256_shuffle_epi8(
      row,
      _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
                      15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0));
  // clang-format on
  return ret;
}

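// Interleaves 4 rows of 32 int8 values at byte granularity: afterwards the
// n-th 32-bit lane of the concatenated result holds
// {data[0][n], data[1][n], data[2][n], data[3][n]}, i.e. one byte from each
// row of column n. This is the layout vpmaddubsw needs to multiply a
// broadcast 1x4 weight block against 4 rows of B at once.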
static inline void interleave_4rows(__m256i data[]) {
  __m256i __t0 = _mm256_unpacklo_epi32(data[0], data[1]);
  __m256i __t1 = _mm256_unpackhi_epi32(data[0], data[1]);
  __m256i __t2 = _mm256_unpacklo_epi32(data[2], data[3]);
  __m256i __t3 = _mm256_unpackhi_epi32(data[2], data[3]);
  __m256i __tt0 = _mm256_unpacklo_epi64(__t0, __t2);
  __m256i __tt1 = _mm256_unpacklo_epi64(__t1, __t3);
  __m256i __tt2 = _mm256_unpackhi_epi64(__t0, __t2);
  __m256i __tt3 = _mm256_unpackhi_epi64(__t1, __t3);
  __m256i row0 = _mm256_permute2x128_si256(__tt0, __tt2, 0x20);
  __m256i row1 = _mm256_permute2x128_si256(__tt1, __tt3, 0x20);
  __m256i row2 = _mm256_permute2x128_si256(__tt0, __tt2, 0x31);
  __m256i row3 = _mm256_permute2x128_si256(__tt1, __tt3, 0x31);
  // End of int32 transpose
  // Now we only need a simple row permutation to get the right result
  data[0] = permute_row(row0);
  data[1] = permute_row(row1);
  data[2] = permute_row(row2);
  data[3] = permute_row(row3);
}

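// Computes C = A * B (or C += A * B when accum is true): A is an M x K
// sparse int8 matrix stored in BCSR format with 1x4 blocking, B is a dense
// K x N uint8 matrix with leading dimension ldb, and the int32 accumulator
// C_i32 is requantized into the uint8 output C_u8 (both with leading
// dimension ldc) using rParams.
//
// A minimal call sketch, assuming square N x N operands and a previously
// built BCSR handle (shapes and flags here are illustrative only):
//
//   SparseDenseInt8MMAvx2<false, QuantizationGranularity::TENSOR>(
//       N, bcsr, B, /*ldb=*/N, C_i32, C_u8, /*ldc=*/N, rParams,
//       /*accum=*/false, /*thread_id=*/0, /*num_threads=*/1);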
template <bool FUSE_RELU, QuantizationGranularity Q_GRAN>
void SparseDenseInt8MMAvx2(
    int N,
    const std::unique_ptr<BCSRMatrix<>>& bcsr,
    const uint8_t* B,
    int ldb,
    int32_t* C_i32,
    uint8_t* C_u8,
    int ldc,
    trRequantizationParams_t& rParams,
    bool accum,
    int /*thread_id*/,
    int /*num_threads*/) {
  // Calculates accum ? C += A * B : C = A * B
  constexpr int VLEN_INT8 = 32;
  constexpr int VLEN_INT32 = 8;
  constexpr int rowBlockSize = BCSRMatrix<>::RB;
  (void)rowBlockSize; // Suppress unused variable warning
  constexpr int colBlockSize = BCSRMatrix<>::CB;

  constexpr int colTileSize = BCSRMatrix<>::COLTILE;
  int K = bcsr->C;
  int M = bcsr->R;
  int kTiles = (K + colTileSize - 1) / colTileSize;

  for (int i = 0; i < M; ++i) {
    if (!accum) {
      int j = 0;
      __m256i c_v = _mm256_setzero_si256();
      for (; j < N / VLEN_INT32 * VLEN_INT32; j += VLEN_INT32) {
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(C_i32 + i * ldc + j), c_v);
      }
      // Handle remainder
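      // avx2_ps_or_epi32_combined_mask holds 8 lanes of -1 followed by 8
      // lanes of 0; loading 8 lanes starting at offset (VLEN_INT32 - rem)
      // therefore yields a mask whose first rem lanes are set, so the
      // masked store touches exactly the rem remaining columns.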
      int rem = N - j;
      if (rem > 0) {
        __m256i mask_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
            &avx2_ps_or_epi32_combined_mask[VLEN_INT32 - rem]));
        _mm256_maskstore_epi32(
            reinterpret_cast<int32_t*>(C_i32 + i * ldc + j), mask_v, c_v);
      }
    }
    for (int kt = 0; kt < kTiles; ++kt) {
      int* row_ptr = bcsr->rowBPtr.data() + kt * M;
      int* col_idx = bcsr->colBIdx.data();
      int8_t* values = bcsr->values.data();
      int curKSize = std::min(K - kt * colTileSize, colTileSize);

      int r = row_ptr[i];
      // Process the nonzero blocks of row i one at a time (not unrolled).
      for (; r < row_ptr[i + 1]; ++r) {
        // The kernel relies on 1x4 blocking: one row and four columns per
        // nonzero block.
        assert(rowBlockSize == 1 && "row block size should be 1");
        assert(colBlockSize == 4 && "column block size should be 4");
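        // Each nonzero block is a 1x4 run of int8 weights; reinterpret it
        // as a single int32 and broadcast it across all lanes so it can be
        // multiplied against 4 interleaved rows of B.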
        int acbr_block = col_idx[r];
        int32_t v = reinterpret_cast<const int32_t*>(values)[r];
        __m256i a_v = _mm256_set1_epi32(v);
        int j = 0;
        for (; j < N / VLEN_INT8 * VLEN_INT8; j += VLEN_INT8) {
          __m256i br_v[4] = {};

          for (int idx = 0;
               idx < std::min(4, curKSize - acbr_block * colBlockSize);
               ++idx) {
            br_v[idx] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                B + (acbr_block * colBlockSize + idx + kt * colTileSize) * ldb +
                j));
          }

          // interleave these 4 rows
          interleave_4rows(br_v);

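          // vpmaddubsw multiplies unsigned B bytes by the signed weights
          // and adds adjacent pairs into int16; multiplying by 1 with
          // vpmaddwd then widens and sums the pairs into int32, completing
          // the 4-term dot product for each output column.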
          __m256i one_16bit_v = _mm256_set1_epi16(1);
          __m256i c_v[4];
          for (int idx = 0; idx < 4; ++idx) {
            c_v[idx] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                C_i32 + i * ldc + j + idx * VLEN_INT32));
            __m256i c_i16_v = _mm256_maddubs_epi16(br_v[idx], a_v);
            __m256i c_i32_v = _mm256_madd_epi16(one_16bit_v, c_i16_v);
            c_v[idx] = _mm256_add_epi32(c_v[idx], c_i32_v);
            _mm256_storeu_si256(
                reinterpret_cast<__m256i*>(
                    C_i32 + i * ldc + j + idx * VLEN_INT32),
                c_v[idx]);
          }
        }
        // Handle remainder j loop
        int rem = N - j;
        if (rem > 0) {
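          // Copy the ragged tail of each B row into a zero-padded stack
          // buffer so a full 32-byte load never reads past the end of B.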
          __m256i br_v[4] = {};
          for (int idx = 0;
               idx < std::min(4, curKSize - acbr_block * colBlockSize);
               ++idx) {
            uint8_t tmpDest[VLEN_INT8] = {};
            std::memcpy(
                tmpDest,
                B + (acbr_block * colBlockSize + idx + kt * colTileSize) * ldb +
                    j,
                rem);
            br_v[idx] =
                _mm256_loadu_si256(reinterpret_cast<const __m256i*>(tmpDest));
          }
          // interleave these 4 rows
          interleave_4rows(br_v);

          __m256i c_v[4] = {};
          int idx1 = 0;
          for (; idx1 < rem / VLEN_INT32; ++idx1) {
            c_v[idx1] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                C_i32 + i * ldc + j + idx1 * VLEN_INT32));
          }
          int rem_int32 = rem - idx1 * VLEN_INT32;
          __m256i mask_int32_v;
          if (rem_int32 > 0) {
            mask_int32_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                &avx2_ps_or_epi32_combined_mask[VLEN_INT32 - rem_int32]));
            c_v[idx1] = _mm256_maskload_epi32(
                reinterpret_cast<const int*>(
                    C_i32 + i * ldc + j + idx1 * VLEN_INT32),
                mask_int32_v);
          }

          __m256i one_16bit_v = _mm256_set1_epi16(1);
          for (int idx = 0; idx < 4; ++idx) {
            __m256i c_i16_v = _mm256_maddubs_epi16(br_v[idx], a_v);
            __m256i c_i32_v = _mm256_madd_epi16(one_16bit_v, c_i16_v);
            c_v[idx] = _mm256_add_epi32(c_v[idx], c_i32_v);
          }

          int idx2 = 0;
          for (; idx2 < rem / VLEN_INT32; ++idx2) {
            _mm256_storeu_si256(
                reinterpret_cast<__m256i*>(
                    C_i32 + i * ldc + j + idx2 * VLEN_INT32),
                c_v[idx2]);
          }
          if (rem_int32 > 0) {
            _mm256_maskstore_epi32(
                reinterpret_cast<int*>(C_i32 + i * ldc + j + idx2 * VLEN_INT32),
                mask_int32_v,
                c_v[idx2]);
          }
        }
      }
    }
  }

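  // Requantize the int32 accumulator into the uint8 output. The runtime
  // flags (activation zero point, bias presence) are translated into the
  // compile-time template parameters of trRequantizeOpt; weights are
  // treated as symmetric throughout.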
  block_type_t block{0, M, 0, N};
  if (rParams.bias == nullptr) {
    if (rParams.act_zero_point) {
      trRequantizeOpt<
          FUSE_RELU,
          /*ACT_SYMMETRIC*/ false,
          /*WEIGHT_SYMMETRIC*/ true,
          /*HAS_BIAS*/ false,
          Q_GRAN>(C_u8, C_i32, block, ldc, ldc, rParams);
    } else {
      trRequantizeOpt<
          FUSE_RELU,
          /*ACT_SYMMETRIC*/ true,
          /*WEIGHT_SYMMETRIC*/ true,
          /*HAS_BIAS*/ false,
          Q_GRAN>(C_u8, C_i32, block, ldc, ldc, rParams);
    }
  } else {
    if (rParams.act_zero_point) {
      trRequantizeOpt<
          FUSE_RELU,
          /*ACT_SYMMETRIC*/ false,
          /*WEIGHT_SYMMETRIC*/ true,
          /*HAS_BIAS*/ true,
          Q_GRAN>(C_u8, C_i32, block, ldc, ldc, rParams);
    } else {
      trRequantizeOpt<
          FUSE_RELU,
          /*ACT_SYMMETRIC*/ true,
          /*WEIGHT_SYMMETRIC*/ true,
          /*HAS_BIAS*/ true,
          Q_GRAN>(C_u8, C_i32, block, ldc, ldc, rParams);
    }
  }
}

#define CREATE_INSTANCE(FUSE_RELU, QGRAN)                \
  template void SparseDenseInt8MMAvx2<FUSE_RELU, QGRAN>( \
      int N,                                             \
      const std::unique_ptr<BCSRMatrix<>>& bcsr,         \
      const uint8_t* B,                                  \
      int ldb,                                           \
      int32_t* C_i32,                                    \
      uint8_t* C_u8,                                     \
      int ldc,                                           \
      trRequantizationParams_t& rParams,                 \
      bool accum,                                        \
      int thread_id,                                     \
      int num_threads);
CREATE_INSTANCE(true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, QuantizationGranularity::OUT_CHANNEL)
#undef CREATE_INSTANCE

} // namespace internal
} // namespace fbgemm