/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmSparse.h"
#include "fbgemm/spmmUtilsAvx2.h"

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif
#include <algorithm> // for min and max
#include <cassert>
#include <cstring>
#include "./MaskAvx2.h"

namespace fbgemm {
namespace internal {

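// Within each 128-bit lane, gather every 4th byte: bytes 0,4,8,12 first, then
// 1,5,9,13, then 2,6,10,14, then 3,7,11,15. Combined with the int32 transpose
// in interleave_4rows below, this places the 4 bytes belonging to the same
// output column (one from each of 4 consecutive rows of B) next to each
// other, which is the operand layout _mm256_maddubs_epi16 consumes.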
static inline __m256i permute_row(__m256i row) {
  // clang-format off
  __m256i ret = _mm256_shuffle_epi8(
      row,
      _mm256_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0,
                      15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0));
  // clang-format on
  return ret;
}

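// Interleave a 4x32 uint8 tile held in data[0..3]: on return, data[k] covers
// output columns 8*k .. 8*k+7, holding for each column the 4 bytes from the
// original rows back to back.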
static inline void interleave_4rows(__m256i data[]) {
  __m256i __t0 = _mm256_unpacklo_epi32(data[0], data[1]);
  __m256i __t1 = _mm256_unpackhi_epi32(data[0], data[1]);
  __m256i __t2 = _mm256_unpacklo_epi32(data[2], data[3]);
  __m256i __t3 = _mm256_unpackhi_epi32(data[2], data[3]);
  __m256i __tt0 = _mm256_unpacklo_epi64(__t0, __t2);
  __m256i __tt1 = _mm256_unpacklo_epi64(__t1, __t3);
  __m256i __tt2 = _mm256_unpackhi_epi64(__t0, __t2);
  __m256i __tt3 = _mm256_unpackhi_epi64(__t1, __t3);
  __m256i row0 = _mm256_permute2x128_si256(__tt0, __tt2, 0x20);
  __m256i row1 = _mm256_permute2x128_si256(__tt1, __tt3, 0x20);
  __m256i row2 = _mm256_permute2x128_si256(__tt0, __tt2, 0x31);
  __m256i row3 = _mm256_permute2x128_si256(__tt1, __tt3, 0x31);
  // End of int32 transpose
  // Now we only need a simple row permutation to get the right result
  data[0] = permute_row(row0);
  data[1] = permute_row(row1);
  data[2] = permute_row(row2);
  data[3] = permute_row(row3);
}

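// Computes C = A * B (or C += A * B when accum is true), where A is a sparse
// int8 matrix in BCSR layout with 1x4 blocks (RB = 1, CB = 4), B is a dense
// uint8 matrix, and results are accumulated in int32 before being requantized
// to uint8. The K dimension is processed in tiles of BCSRMatrix<>::COLTILE
// columns of A at a time.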
template <bool FUSE_RELU, QuantizationGranularity Q_GRAN>
void SparseDenseInt8MMAvx2(
    int N,
    const std::unique_ptr<BCSRMatrix<>>& bcsr,
    const uint8_t* B,
    int ldb,
    int32_t* C_i32,
    uint8_t* C_u8,
    int ldc,
    trRequantizationParams_t& rParams,
    bool accum,
    int /*thread_id*/,
    int /*num_threads*/) {
  // Calculates accum ? C += A * B : C = A * B
  constexpr int VLEN_INT8 = 32;
  constexpr int VLEN_INT32 = 8;
  constexpr int rowBlockSize = BCSRMatrix<>::RB;
  (void)rowBlockSize; // Suppress unused variable warning
  constexpr int colBlockSize = BCSRMatrix<>::CB;

  constexpr int colTileSize = BCSRMatrix<>::COLTILE;
  int K = bcsr->C;
  int M = bcsr->R;
  int kTiles = (K + colTileSize - 1) / colTileSize;

  for (int i = 0; i < M; ++i) {
    if (!accum) {
      int j = 0;
      __m256i c_v = _mm256_set1_epi32(0);
      for (; j < N / VLEN_INT32 * VLEN_INT32; j += VLEN_INT32) {
        _mm256_storeu_si256(
            reinterpret_cast<__m256i*>(C_i32 + i * ldc + j), c_v);
      }
      // Handle remainder
      int rem = N - j;
      if (rem > 0) {
        __m256i mask_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
            &avx2_ps_or_epi32_combined_mask[VLEN_INT32 - rem]));
        _mm256_maskstore_epi32(
            reinterpret_cast<int32_t*>(C_i32 + i * ldc + j), mask_v, c_v);
      }
    }
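    // Walk the nonzero 1x4 blocks of row i one K tile at a time; rowBPtr
    // concatenates one row-pointer array per K tile.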
    for (int kt = 0; kt < kTiles; ++kt) {
      int* row_ptr = bcsr->rowBPtr.data() + kt * M;
      int* col_idx = bcsr->colBIdx.data();
      int8_t* values = bcsr->values.data();
      int curKSize = std::min(K - kt * colTileSize, colTileSize);

      int r = row_ptr[i];
      // int r_end_aligned = row_ptr[i] + (row_ptr[i + 1] - row_ptr[i]) / 4 * 4;
      // unrolled by 1
      for (; r < row_ptr[i + 1]; ++r) {
        // The kernel below is only correct for 1x4 BCSR blocks
        assert(rowBlockSize == 1 && "row block size should be 1");
        assert(colBlockSize == 4 && "column block size should be 4");
        int acbr_block = col_idx[r];
        // Load the 4 int8 values of A's current 1x4 block and broadcast them
        // to every 32-bit element
        int32_t v = reinterpret_cast<const int32_t*>(values)[r];
        __m256i a_v = _mm256_set1_epi32(v);
        int j = 0;
        for (; j < N / VLEN_INT8 * VLEN_INT8; j += VLEN_INT8) {
          __m256i br_v[4] = {};

          for (int idx = 0;
               idx < std::min(4, curKSize - acbr_block * colBlockSize);
               ++idx) {
            br_v[idx] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                B + (acbr_block * colBlockSize + idx + kt * colTileSize) * ldb +
                j));
          }

          // interleave these 4 rows
          interleave_4rows(br_v);

          __m256i one_16bit_v = _mm256_set1_epi16(1);
          // For each group of 4 adjacent bytes, maddubs multiplies the
          // unsigned B bytes with the signed A bytes and adds adjacent pairs
          // into int16; madd against ones then widens and sums the int16
          // pairs, yielding one int32 dot product per output column.
          __m256i c_v[4];
          for (int idx = 0; idx < 4; ++idx) {
            c_v[idx] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                C_i32 + i * ldc + j + idx * VLEN_INT32));
            __m256i c_i16_v = _mm256_maddubs_epi16(br_v[idx], a_v);
            __m256i c_i32_v = _mm256_madd_epi16(one_16bit_v, c_i16_v);
            c_v[idx] = _mm256_add_epi32(c_v[idx], c_i32_v);
            _mm256_storeu_si256(
                reinterpret_cast<__m256i*>(
                    C_i32 + i * ldc + j + idx * VLEN_INT32),
                c_v[idx]);
          }
        }
        // Handle the remainder of the j loop
        int rem = N - j;
        if (rem > 0) {
          __m256i br_v[4] = {};
          for (int idx = 0;
               idx < std::min(4, curKSize - acbr_block * colBlockSize);
               ++idx) {
            // Stage the tail through a zero-padded buffer so the 32-byte
            // load never reads past the end of B
            uint8_t tmpDest[VLEN_INT8] = {};
            std::memcpy(
                tmpDest,
                B + (acbr_block * colBlockSize + idx + kt * colTileSize) * ldb +
                    j,
                rem);
            br_v[idx] =
                _mm256_loadu_si256(reinterpret_cast<const __m256i*>(tmpDest));
          }
          // interleave these 4 rows
          interleave_4rows(br_v);

          __m256i c_v[4] = {};
          int idx1 = 0;
          for (; idx1 < rem / VLEN_INT32; ++idx1) {
            c_v[idx1] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                C_i32 + i * ldc + j + idx1 * VLEN_INT32));
          }
          int rem_int32 = rem - idx1 * VLEN_INT32;
          __m256i mask_int32_v;
          if (rem_int32 > 0) {
            mask_int32_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
                &avx2_ps_or_epi32_combined_mask[VLEN_INT32 - rem_int32]));
            c_v[idx1] = _mm256_maskload_epi32(
                reinterpret_cast<const int*>(
                    C_i32 + i * ldc + j + idx1 * VLEN_INT32),
                mask_int32_v);
          }

          __m256i one_16bit_v = _mm256_set1_epi16(1);
          for (int idx = 0; idx < 4; ++idx) {
            __m256i c_i16_v = _mm256_maddubs_epi16(br_v[idx], a_v);
            __m256i c_i32_v = _mm256_madd_epi16(one_16bit_v, c_i16_v);
            c_v[idx] = _mm256_add_epi32(c_v[idx], c_i32_v);
          }

          int idx2 = 0;
          for (; idx2 < rem / VLEN_INT32; ++idx2) {
            _mm256_storeu_si256(
                reinterpret_cast<__m256i*>(
                    C_i32 + i * ldc + j + idx2 * VLEN_INT32),
                c_v[idx2]);
          }
          if (rem_int32 > 0) {
            _mm256_maskstore_epi32(
                reinterpret_cast<int*>(C_i32 + i * ldc + j + idx2 * VLEN_INT32),
                mask_int32_v,
                c_v[idx2]);
          }
        }
      }
    }
  }

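  // Requantize the int32 accumulators into C_u8: WEIGHT_SYMMETRIC is always
  // true on this path, ACT_SYMMETRIC holds when the activation zero point is
  // 0, and HAS_BIAS is chosen by whether rParams.bias is set.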
  block_type_t block{0, M, 0, N};
  if (rParams.bias == nullptr) {
    if (rParams.act_zero_point) {
      trRequantizeOpt<
          FUSE_RELU,
          /*ACT_SYMMETRIC*/ false,
          /*WEIGHT_SYMMETRIC*/ true,
          /*HAS_BIAS*/ false,
          Q_GRAN>(C_u8, C_i32, block, ldc, ldc, rParams);
    } else {
      trRequantizeOpt<
          FUSE_RELU,
          /*ACT_SYMMETRIC*/ true,
          /*WEIGHT_SYMMETRIC*/ true,
          /*HAS_BIAS*/ false,
          Q_GRAN>(C_u8, C_i32, block, ldc, ldc, rParams);
    }
  } else {
    if (rParams.act_zero_point) {
      trRequantizeOpt<
          FUSE_RELU,
          /*ACT_SYMMETRIC*/ false,
          /*WEIGHT_SYMMETRIC*/ true,
          /*HAS_BIAS*/ true,
          Q_GRAN>(C_u8, C_i32, block, ldc, ldc, rParams);
    } else {
      trRequantizeOpt<
          FUSE_RELU,
          /*ACT_SYMMETRIC*/ true,
          /*WEIGHT_SYMMETRIC*/ true,
          /*HAS_BIAS*/ true,
          Q_GRAN>(C_u8, C_i32, block, ldc, ldc, rParams);
    }
  }
}

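// Explicit instantiations for the four supported combinations of ReLU fusion
// and quantization granularity.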
#define CREATE_INSTANCE(FUSE_RELU, QGRAN)                \
  template void SparseDenseInt8MMAvx2<FUSE_RELU, QGRAN>( \
      int N,                                             \
      const std::unique_ptr<BCSRMatrix<>>& bcsr,         \
      const uint8_t* B,                                  \
      int ldb,                                           \
      int32_t* C_i32,                                    \
      uint8_t* C_u8,                                     \
      int ldc,                                           \
      trRequantizationParams_t& rParams,                 \
      bool accum,                                        \
      int thread_id,                                     \
      int num_threads);
CREATE_INSTANCE(true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, QuantizationGranularity::OUT_CHANNEL)
#undef CREATE_INSTANCE

} // namespace internal
} // namespace fbgemm