1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 * This source code is licensed under the BSD-style license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7#define FBGEMM_EXPORTS
8#include "fbgemm/FbgemmI8Spmdm.h"
9
10#include <algorithm>
11#include <array>
12#include <cassert>
13#include <cmath>
14#include <cstring>
15#include "./OptimizedKernelsAvx2.h"
16
#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
// Accumulated wall-clock time (nanoseconds) per phase of SpMDM/SparseConv,
// summed across all calls since program start. Only compiled in when
// FBGEMM_MEASURE_TIME_BREAKDOWN is defined. Updated without any
// synchronization, so the numbers are only meaningful for single-threaded
// profiling runs.
double spmdm_initial_time = 0.0;
double spmdm_transpose_uint8_time = 0.0;
double spmdm_transpose_32xN_time = 0.0;
double spmdm_compute_time = 0.0;
double spmdm_transpose_Nx32_time = 0.0;
double spmdm_run_time = 0.0;
double sconv_run_time = 0.0;
#endif
26
27using namespace std;
28
29namespace fbgemm {
30
// Constructs an empty num_of_rows x num_of_cols matrix in compressed sparse
// column (CSC) layout. colptr_ gets num_of_cols + 1 zero-initialized entries
// (standard CSC convention: colptr_[j+1] - colptr_[j] is column j's nnz),
// i.e. the matrix starts with no non-zeros. old_nnz_ = -1 makes the cached
// IsHyperSparse() verdict stale so it is recomputed on first use.
CompressedSparseColumn::CompressedSparseColumn(int num_of_rows, int num_of_cols)
    : num_rows_(num_of_rows),
      colptr_(num_of_cols + 1),
      hyper_sparse_(false),
      old_nnz_(-1) {}
36
37double CompressedSparseColumn::Density() const {
38 return static_cast<double>(NumOfNonZeros()) / (NumOfRows() * NumOfCols());
39}
40
41bool CompressedSparseColumn::IsHyperSparse() const {
42 if (NumOfNonZeros() != old_nnz_) {
43 old_nnz_ = NumOfNonZeros();
44 // The number of non-zero per row is very small.
45 hyper_sparse_ = static_cast<double>(old_nnz_) / NumOfRows() < 0.3;
46 }
47
48 return hyper_sparse_;
49}
50
51// TODO: fallback when AVX2 is not available
// Sparse-matrix times dense-matrix: computes the `block` sub-block of
// C = A * B (or C += A * B when `accumulation` is true), where B is this
// CSC matrix, A is a dense row-major uint8 matrix with leading dimension
// `lda`, and C is a dense row-major int32 matrix with leading dimension
// `ldc`. block.row_* select rows of A/C; block.col_* select columns of
// B/C. K = rows of B = columns of A touched.
//
// Two strategies:
//  * If B is hyper sparse, plain scalar loops over the non-zeros — the
//    O(K*N) cost of transposing A/C would exceed the O(NNZ*N) multiply
//    work saved.
//  * Otherwise, process A/C 32 rows at a time: transpose each 32xK panel
//    of A and the matching 32xN panel of C so that the AVX2 kernel
//    (spmdmKernelAvx2) gets unit-stride access, then transpose the result
//    back into C.
//
// NOTE(review): the FBGEMM_MEASURE_TIME_BREAKDOWN sections use
// std::chrono but <chrono> is not included here — presumably pulled in
// transitively when that macro is defined; confirm before enabling it.
void CompressedSparseColumn::SpMDM(
    const block_type_t& block,
    const uint8_t* A,
    int lda,
    bool accumulation,
    int32_t* C,
    int ldc) const {
  int K = NumOfRows();
  int N = block.col_size;

  // Nothing to compute for an empty block or an empty B.
  if (K == 0 || N == 0 || block.row_size == 0) {
    return;
  }

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
  std::chrono::time_point<std::chrono::high_resolution_clock> t_very_start,
      t_start, t_end;
  double dt;
  t_start = std::chrono::high_resolution_clock::now();
  t_very_start = std::chrono::high_resolution_clock::now();
#endif

// Note: These (and others below) cause a ~2-3% overall performance drop in
// resnet/resnext so we are keeping arrays with dynamic size for gcc/clang and
// dynamically allocated memory for MSVC even though dynamically allocated
// memory works for all compilers.
// (gcc/clang accept runtime-sized stack arrays as an extension; MSVC does
// not, hence the aligned heap allocations under _MSC_VER.)
#ifdef _MSC_VER
  uint8_t* A_buffer =
      static_cast<uint8_t*>(fbgemmAlignedAlloc(64, K * 32 * sizeof(uint8_t)));
  int32_t* C_buffer =
      static_cast<int32_t*>(fbgemmAlignedAlloc(64, N * 32 * sizeof(int32_t)));
#else
  // A_buffer holds one transposed 32-row panel of A (K columns x 32 rows);
  // C_buffer holds the transposed N x 32 panel of C.
  alignas(64) uint8_t A_buffer[K * 32];
  alignas(64) int32_t C_buffer[N * 32];
#endif

  // If we compute C = C + A * B, where B is a sparse matrix in CSC format, for
  // each non-zero in B, we'd need to access the corresponding column in A.
  // This results in strided access, which we want to avoid.
  // Instead, we pre-transpose A and C, and compute C = (C^T + B^T * A^T)^T

  if (IsHyperSparse()) {
    // The cost of transpose is O(K*N) and we do O(NNZ*N) multiplications.
    // If NNZ/K is small, it's not worth doing transpose so we just use this
    // scalar loop.
    // NOTE(review): on MSVC, A_buffer/C_buffer were already allocated above
    // and go unused on this path (freed below) — wasted work, kept as-is.
#ifdef _MSC_VER
    int32_t* C_temp = static_cast<int32_t*>(
        fbgemmAlignedAlloc(64, block.row_size * sizeof(int32_t)));
#else
    // Scratch column of C: accumulate a whole output column here so the
    // inner loops write contiguous memory instead of strided C.
    int32_t C_temp[block.row_size];
#endif
    if (accumulation) {
      for (int j = 0; j < block.col_size; ++j) {
        // Non-zeros of column (block.col_start + j) of B live in
        // [k, k_end) of rowidx_/values_.
        int k = colptr_[block.col_start + j];
        int k_end = colptr_[block.col_start + j + 1];
        if (k_end == k) {
          // Empty column: C += 0, nothing to do.
        } else if (k_end == k + 1) {
          // Single non-zero: accumulate directly into C, no scratch needed.
          int row = rowidx_[k];
          int w = values_[k];
          for (int i = 0; i < block.row_size; ++i) {
            C[i * ldc + j] += A[(block.row_start + i) * lda + row] * w;
          }
        } else {
          // General case: stage the output column in C_temp, accumulate all
          // non-zeros, then write back once.
          for (int i = 0; i < block.row_size; ++i) {
            C_temp[i] = C[i * ldc + j];
          }
          for (; k < k_end; ++k) {
            int row = rowidx_[k];
            int w = values_[k];
            for (int i = 0; i < block.row_size; ++i) {
              C_temp[i] += A[(block.row_start + i) * lda + row] * w;
            }
          }
          for (int i = 0; i < block.row_size; ++i) {
            C[i * ldc + j] = C_temp[i];
          }
        }
      } // for each column of B
    } else {
      // Non-accumulating variant: identical structure, but C is overwritten
      // (an empty column must explicitly zero its output).
      for (int j = 0; j < block.col_size; ++j) {
        int k = colptr_[block.col_start + j];
        int k_end = colptr_[block.col_start + j + 1];
        if (k_end == k) {
          for (int i = 0; i < block.row_size; ++i) {
            C[i * ldc + j] = 0;
          }
        } else if (k_end == k + 1) {
          int row = rowidx_[k];
          int w = values_[k];
          for (int i = 0; i < block.row_size; ++i) {
            C[i * ldc + j] = A[(block.row_start + i) * lda + row] * w;
          }
        } else {
          for (int i = 0; i < block.row_size; ++i) {
            C_temp[i] = 0;
          }
          for (; k < k_end; ++k) {
            int row = rowidx_[k];
            int w = values_[k];
            for (int i = 0; i < block.row_size; ++i) {
              C_temp[i] += A[(block.row_start + i) * lda + row] * w;
            }
          }
          for (int i = 0; i < block.row_size; ++i) {
            C[i * ldc + j] = C_temp[i];
          }
        }
      } // for each column of B
    }
#ifdef _MSC_VER
    fbgemmAlignedFree(A_buffer);
    fbgemmAlignedFree(C_buffer);
    fbgemmAlignedFree(C_temp);
#endif
    return;
  }

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
  t_end = std::chrono::high_resolution_clock::now();
  dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
           .count();
  spmdm_initial_time += (dt);
  t_start = std::chrono::high_resolution_clock::now();
#endif

  // Take 32 rows at a time
  int i_end = block.row_start + block.row_size;
  for (int i1 = block.row_start; i1 < i_end; i1 += 32) {
    // Transpose 32 x K submatrix of A
    if (i_end - i1 < 32) {
      // Remainder panel (< 32 rows): transpose_8rows needs groups of 8
      // rows, so the ragged tail is first copied into A_temp_buffer and
      // zero-padded up to 32 rows before transposing.
#ifdef _MSC_VER
      uint8_t* A_temp_buffer = static_cast<uint8_t*>(
          fbgemmAlignedAlloc(64, K * 32 * sizeof(uint8_t)));
#else
      alignas(64) uint8_t A_temp_buffer[K * 32];
#endif
      // Full groups of 8 rows transpose straight from A.
      for (int i2 = 0; i2 < (i_end - i1) / 8 * 8; i2 += 8) {
        transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32);
      }

      // Copy the remaining (< 8) real rows...
      for (int i2 = (i_end - i1) / 8 * 8; i2 < i_end - i1; ++i2) {
        memcpy(
            A_temp_buffer + i2 * K, A + (i1 + i2) * lda, K * sizeof(uint8_t));
      }
      // ...zero-fill rows (i_end - i1) .. 31 so the kernel sees exactly 32
      // rows (zero rows contribute nothing to the products)...
      memset(
          A_temp_buffer + (i_end - i1) * K,
          0,
          (32 - (i_end - i1)) * K * sizeof(uint8_t));
      // ...and transpose the padded groups from the staging buffer.
      for (int i2 = (i_end - i1) / 8 * 8; i2 < 32; i2 += 8) {
        transpose_8rows(K, A_temp_buffer + i2 * K, K, A_buffer + i2, 32);
      }
#ifdef _MSC_VER
      fbgemmAlignedFree(A_temp_buffer);
#endif
    } else {
      // Full 32-row panel: transpose directly from A, 8 rows per call.
      for (int i2 = 0; i2 < 32; i2 += 8) {
        transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32);
      }
    }

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
    t_end = std::chrono::high_resolution_clock::now();
    dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
             .count();
    spmdm_transpose_uint8_time += (dt);
    t_start = std::chrono::high_resolution_clock::now();
#endif

    if (accumulation) {
      // Transpose 32 x N submatrix of C to fill N x 32 C_buffer
      // (int32 data is moved through float lanes; the transpose only copies
      // 32-bit words, so the reinterpret_cast is bit-preserving).
      transpose_simd(
          std::min(32, i_end - i1),
          N,
          reinterpret_cast<const float*>(C + (i1 - block.row_start) * ldc),
          ldc,
          reinterpret_cast<float*>(C_buffer),
          32);
    } else {
      // Overwrite semantics: start the panel accumulator from zero.
      memset(C_buffer, 0, N * 32 * sizeof(int32_t));
    }

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
    t_end = std::chrono::high_resolution_clock::now();
    dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
             .count();
    spmdm_transpose_32xN_time += (dt);
    t_start = std::chrono::high_resolution_clock::now();
#endif

    // C_buffer (N x 32) += B^T (block.col_size x K) * A_buffer (K x 32).
    spmdmKernelAvx2(
        block.col_size,
        A_buffer,
        colptr_.data() + block.col_start,
        values_.data(),
        rowidx_.data(),
        C_buffer);

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
    t_end = std::chrono::high_resolution_clock::now();
    dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
             .count();
    spmdm_compute_time += (dt);
    t_start = std::chrono::high_resolution_clock::now();
#endif

    // Transpose N x 32 C_buffer to fill 32 x N submatrix of C
    // (only the std::min(32, i_end - i1) valid rows are written back).
    transpose_simd(
        N,
        std::min(32, i_end - i1),
        reinterpret_cast<const float*>(C_buffer),
        32,
        reinterpret_cast<float*>(C + (i1 - block.row_start) * ldc),
        ldc);

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
    t_end = std::chrono::high_resolution_clock::now();
    dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
             .count();
    spmdm_transpose_Nx32_time += (dt);
    t_start = std::chrono::high_resolution_clock::now();
#endif
  }

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
  t_end = std::chrono::high_resolution_clock::now();
  dt =
      std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_very_start)
          .count();
  spmdm_run_time += (dt);
  t_start = std::chrono::high_resolution_clock::now();
#endif
#ifdef _MSC_VER
  fbgemmAlignedFree(A_buffer);
  fbgemmAlignedFree(C_buffer);
#endif
}
288
// Sparse convolution: accumulates this sparse weight matrix against the
// dense uint8 activation tensor A into the int32 output C, for the
// sub-block described by `block`. Each non-zero k carries, besides its
// value values_[k], the kernel coordinates kh_[k]/kw_[k] and input channel
// ic_[k] it belongs to. Row index i of the output is the flattened
// (batch n, output row oh, output col ow) position, decoded below from
// conv_p.OUT_DIM. Taps that fall outside the input image contribute
// A_zero_point * v, which emulates padding with the quantization
// zero-point instead of literal zeros.
//
// Scalar reference implementation: no vectorization, no transposition.
// When `accumulation` is false the output block is zeroed first.
void CompressedSparseColumn::SparseConv(
    const conv_param_t<>& conv_p,
    const block_type_t& block,
    const uint8_t* A,
    int32_t A_zero_point,
    bool accumulation,
    int32_t* C,
    int ldc) const {
  int K = NumOfRows();
  int N = block.col_size;

  // Nothing to do for an empty weight matrix or an empty block.
  if (K == 0 || N == 0) {
    return;
  }

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
  std::chrono::time_point<std::chrono::high_resolution_clock> t_start, t_end;
  double dt;
  t_start = std::chrono::high_resolution_clock::now();
#endif

  // TODO: if not hyper sparse, transpose a block of A matrix as in SpMDM.
  if (!accumulation) {
    // Overwrite semantics: clear the output block before accumulating.
    for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
      for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
        C[(i - block.row_start) * ldc + j - block.col_start] = 0;
      }
    }
  }
  for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
    // Walk the non-zeros of column j of the weight matrix.
    for (int k = colptr_[j]; k < colptr_[j + 1]; ++k) {
      int v = values_[k];
      for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
        // Decode flattened output index i into (n, oh, ow).
        int ow = i % conv_p.OUT_DIM[1];
        int oh = i / conv_p.OUT_DIM[1] % conv_p.OUT_DIM[0];
        int n = i / conv_p.OUT_DIM[1] / conv_p.OUT_DIM[0];
        assert(n < conv_p.MB);
        // Map output position + kernel offset back to input coordinates.
        int iw = -conv_p.pad[1] + ow * conv_p.stride[1] + kw_[k];
        int ih = -conv_p.pad[0] + oh * conv_p.stride[0] + kh_[k];

        if (ih >= 0 && ih < conv_p.IN_DIM[0] && iw >= 0 &&
            iw < conv_p.IN_DIM[1]) {
          // In-bounds tap: A is laid out NHWC (n, ih, iw, ic).
          C[(i - block.row_start) * ldc + j - block.col_start] +=
              A[((n * conv_p.IN_DIM[0] + ih) * conv_p.IN_DIM[1] + iw) *
                    conv_p.IC +
                ic_[k]] *
              v;
        } else {
          // Out-of-image tap: padding contributes the quantization
          // zero-point of A, not a literal zero.
          C[(i - block.row_start) * ldc + j - block.col_start] +=
              A_zero_point * v;
        }
      }
    }
  } // for each column of B

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
  t_end = std::chrono::high_resolution_clock::now();
  dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
           .count();
  sconv_run_time += (dt);
#endif
}
351
352} // namespace fbgemm
353