1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/FbgemmI8Spmdm.h" |
9 | |
10 | #include <algorithm> |
11 | #include <array> |
12 | #include <cassert> |
13 | #include <cmath> |
14 | #include <cstring> |
15 | #include "./OptimizedKernelsAvx2.h" |
16 | |
#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
// Running totals, in nanoseconds, of the time spent in each phase of
// CompressedSparseColumn::SpMDM and CompressedSparseColumn::SparseConv.
// They only exist when the build defines FBGEMM_MEASURE_TIME_BREAKDOWN and are
// only ever added to by the code in this file.

// SpMDM: from entry until the hyper-sparse check has been passed.
double spmdm_initial_time{0.0};
// SpMDM: transposing each 32-row tile of the uint8 matrix A.
double spmdm_transpose_uint8_time{0.0};
// SpMDM: loading the C tile into its transposed scratch buffer (or zeroing it).
double spmdm_transpose_32xN_time{0.0};
// SpMDM: the spmdmKernelAvx2 call itself.
double spmdm_compute_time{0.0};
// SpMDM: transposing the scratch buffer back into C.
double spmdm_transpose_Nx32_time{0.0};
// SpMDM: whole call, measured on the transpose-based path.
double spmdm_run_time{0.0};
// SparseConv: whole call.
double sconv_run_time{0.0};
#endif
26 | |
27 | using namespace std; |
28 | |
29 | namespace fbgemm { |
30 | |
31 | CompressedSparseColumn::CompressedSparseColumn(int num_of_rows, int num_of_cols) |
32 | : num_rows_(num_of_rows), |
33 | colptr_(num_of_cols + 1), |
34 | hyper_sparse_(false), |
35 | old_nnz_(-1) {} |
36 | |
37 | double CompressedSparseColumn::Density() const { |
38 | return static_cast<double>(NumOfNonZeros()) / (NumOfRows() * NumOfCols()); |
39 | } |
40 | |
41 | bool CompressedSparseColumn::IsHyperSparse() const { |
42 | if (NumOfNonZeros() != old_nnz_) { |
43 | old_nnz_ = NumOfNonZeros(); |
44 | // The number of non-zero per row is very small. |
45 | hyper_sparse_ = static_cast<double>(old_nnz_) / NumOfRows() < 0.3; |
46 | } |
47 | |
48 | return hyper_sparse_; |
49 | } |
50 | |
51 | // TODO: fallback when AVX2 is not available |
52 | void CompressedSparseColumn::SpMDM( |
53 | const block_type_t& block, |
54 | const uint8_t* A, |
55 | int lda, |
56 | bool accumulation, |
57 | int32_t* C, |
58 | int ldc) const { |
59 | int K = NumOfRows(); |
60 | int N = block.col_size; |
61 | |
62 | if (K == 0 || N == 0 || block.row_size == 0) { |
63 | return; |
64 | } |
65 | |
66 | #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN |
67 | std::chrono::time_point<std::chrono::high_resolution_clock> t_very_start, |
68 | t_start, t_end; |
69 | double dt; |
70 | t_start = std::chrono::high_resolution_clock::now(); |
71 | t_very_start = std::chrono::high_resolution_clock::now(); |
72 | #endif |
73 | |
74 | // Note: These (and others below) cause a ~2-3% overall performance drop in |
75 | // resnet/resnext so we are keeping arrays with dynamic size for gcc/clang and |
76 | // dynamically allocated memory for MSVC even though dynamically allocated |
77 | // memory works for all compilers. |
78 | #ifdef _MSC_VER |
79 | uint8_t* A_buffer = |
80 | static_cast<uint8_t*>(fbgemmAlignedAlloc(64, K * 32 * sizeof(uint8_t))); |
81 | int32_t* C_buffer = |
82 | static_cast<int32_t*>(fbgemmAlignedAlloc(64, N * 32 * sizeof(int32_t))); |
83 | #else |
84 | alignas(64) uint8_t A_buffer[K * 32]; |
85 | alignas(64) int32_t C_buffer[N * 32]; |
86 | #endif |
87 | |
88 | // If we compute C = C + A * B, where B is a sparse matrix in CSC format, for |
89 | // each non-zero in B, we'd need to access the corresponding column in A. |
90 | // This results in strided access, which we want to avoid. |
91 | // Instead, we pre-transpose A and C, and compute C = (C^T + B^T * A^T)^T |
92 | |
93 | if (IsHyperSparse()) { |
94 | // The cost of transpose is O(K*N) and we do O(NNZ*N) multiplications. |
95 | // If NNZ/K is small, it's not worth doing transpose so we just use this |
96 | // scalar loop. |
97 | #ifdef _MSC_VER |
98 | int32_t* C_temp = static_cast<int32_t*>( |
99 | fbgemmAlignedAlloc(64, block.row_size * sizeof(int32_t))); |
100 | #else |
101 | int32_t C_temp[block.row_size]; |
102 | #endif |
103 | if (accumulation) { |
104 | for (int j = 0; j < block.col_size; ++j) { |
105 | int k = colptr_[block.col_start + j]; |
106 | int k_end = colptr_[block.col_start + j + 1]; |
107 | if (k_end == k) { |
108 | } else if (k_end == k + 1) { |
109 | int row = rowidx_[k]; |
110 | int w = values_[k]; |
111 | for (int i = 0; i < block.row_size; ++i) { |
112 | C[i * ldc + j] += A[(block.row_start + i) * lda + row] * w; |
113 | } |
114 | } else { |
115 | for (int i = 0; i < block.row_size; ++i) { |
116 | C_temp[i] = C[i * ldc + j]; |
117 | } |
118 | for (; k < k_end; ++k) { |
119 | int row = rowidx_[k]; |
120 | int w = values_[k]; |
121 | for (int i = 0; i < block.row_size; ++i) { |
122 | C_temp[i] += A[(block.row_start + i) * lda + row] * w; |
123 | } |
124 | } |
125 | for (int i = 0; i < block.row_size; ++i) { |
126 | C[i * ldc + j] = C_temp[i]; |
127 | } |
128 | } |
129 | } // for each column of B |
130 | } else { |
131 | for (int j = 0; j < block.col_size; ++j) { |
132 | int k = colptr_[block.col_start + j]; |
133 | int k_end = colptr_[block.col_start + j + 1]; |
134 | if (k_end == k) { |
135 | for (int i = 0; i < block.row_size; ++i) { |
136 | C[i * ldc + j] = 0; |
137 | } |
138 | } else if (k_end == k + 1) { |
139 | int row = rowidx_[k]; |
140 | int w = values_[k]; |
141 | for (int i = 0; i < block.row_size; ++i) { |
142 | C[i * ldc + j] = A[(block.row_start + i) * lda + row] * w; |
143 | } |
144 | } else { |
145 | for (int i = 0; i < block.row_size; ++i) { |
146 | C_temp[i] = 0; |
147 | } |
148 | for (; k < k_end; ++k) { |
149 | int row = rowidx_[k]; |
150 | int w = values_[k]; |
151 | for (int i = 0; i < block.row_size; ++i) { |
152 | C_temp[i] += A[(block.row_start + i) * lda + row] * w; |
153 | } |
154 | } |
155 | for (int i = 0; i < block.row_size; ++i) { |
156 | C[i * ldc + j] = C_temp[i]; |
157 | } |
158 | } |
159 | } // for each column of B |
160 | } |
161 | #ifdef _MSC_VER |
162 | fbgemmAlignedFree(A_buffer); |
163 | fbgemmAlignedFree(C_buffer); |
164 | fbgemmAlignedFree(C_temp); |
165 | #endif |
166 | return; |
167 | } |
168 | |
169 | #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN |
170 | t_end = std::chrono::high_resolution_clock::now(); |
171 | dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start) |
172 | .count(); |
173 | spmdm_initial_time += (dt); |
174 | t_start = std::chrono::high_resolution_clock::now(); |
175 | #endif |
176 | |
177 | // Take 32 rows at a time |
178 | int i_end = block.row_start + block.row_size; |
179 | for (int i1 = block.row_start; i1 < i_end; i1 += 32) { |
180 | // Transpose 32 x K submatrix of A |
181 | if (i_end - i1 < 32) { |
182 | #ifdef _MSC_VER |
183 | uint8_t* A_temp_buffer = static_cast<uint8_t*>( |
184 | fbgemmAlignedAlloc(64, K * 32 * sizeof(uint8_t))); |
185 | #else |
186 | alignas(64) uint8_t A_temp_buffer[K * 32]; |
187 | #endif |
188 | for (int i2 = 0; i2 < (i_end - i1) / 8 * 8; i2 += 8) { |
189 | transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32); |
190 | } |
191 | |
192 | for (int i2 = (i_end - i1) / 8 * 8; i2 < i_end - i1; ++i2) { |
193 | memcpy( |
194 | A_temp_buffer + i2 * K, A + (i1 + i2) * lda, K * sizeof(uint8_t)); |
195 | } |
196 | memset( |
197 | A_temp_buffer + (i_end - i1) * K, |
198 | 0, |
199 | (32 - (i_end - i1)) * K * sizeof(uint8_t)); |
200 | for (int i2 = (i_end - i1) / 8 * 8; i2 < 32; i2 += 8) { |
201 | transpose_8rows(K, A_temp_buffer + i2 * K, K, A_buffer + i2, 32); |
202 | } |
203 | #ifdef _MSC_VER |
204 | fbgemmAlignedFree(A_temp_buffer); |
205 | #endif |
206 | } else { |
207 | for (int i2 = 0; i2 < 32; i2 += 8) { |
208 | transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32); |
209 | } |
210 | } |
211 | |
212 | #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN |
213 | t_end = std::chrono::high_resolution_clock::now(); |
214 | dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start) |
215 | .count(); |
216 | spmdm_transpose_uint8_time += (dt); |
217 | t_start = std::chrono::high_resolution_clock::now(); |
218 | #endif |
219 | |
220 | if (accumulation) { |
221 | // Transpose 32 x N submatrix of C to fill N x 32 C_buffer |
222 | transpose_simd( |
223 | std::min(32, i_end - i1), |
224 | N, |
225 | reinterpret_cast<const float*>(C + (i1 - block.row_start) * ldc), |
226 | ldc, |
227 | reinterpret_cast<float*>(C_buffer), |
228 | 32); |
229 | } else { |
230 | memset(C_buffer, 0, N * 32 * sizeof(int32_t)); |
231 | } |
232 | |
233 | #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN |
234 | t_end = std::chrono::high_resolution_clock::now(); |
235 | dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start) |
236 | .count(); |
237 | spmdm_transpose_32xN_time += (dt); |
238 | t_start = std::chrono::high_resolution_clock::now(); |
239 | #endif |
240 | |
241 | spmdmKernelAvx2( |
242 | block.col_size, |
243 | A_buffer, |
244 | colptr_.data() + block.col_start, |
245 | values_.data(), |
246 | rowidx_.data(), |
247 | C_buffer); |
248 | |
249 | #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN |
250 | t_end = std::chrono::high_resolution_clock::now(); |
251 | dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start) |
252 | .count(); |
253 | spmdm_compute_time += (dt); |
254 | t_start = std::chrono::high_resolution_clock::now(); |
255 | #endif |
256 | |
257 | // Transpose N x 32 C_buffer to fill 32 x N submatrix of C |
258 | transpose_simd( |
259 | N, |
260 | std::min(32, i_end - i1), |
261 | reinterpret_cast<const float*>(C_buffer), |
262 | 32, |
263 | reinterpret_cast<float*>(C + (i1 - block.row_start) * ldc), |
264 | ldc); |
265 | |
266 | #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN |
267 | t_end = std::chrono::high_resolution_clock::now(); |
268 | dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start) |
269 | .count(); |
270 | spmdm_transpose_Nx32_time += (dt); |
271 | t_start = std::chrono::high_resolution_clock::now(); |
272 | #endif |
273 | } |
274 | |
275 | #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN |
276 | t_end = std::chrono::high_resolution_clock::now(); |
277 | dt = |
278 | std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_very_start) |
279 | .count(); |
280 | spmdm_run_time += (dt); |
281 | t_start = std::chrono::high_resolution_clock::now(); |
282 | #endif |
283 | #ifdef _MSC_VER |
284 | fbgemmAlignedFree(A_buffer); |
285 | fbgemmAlignedFree(C_buffer); |
286 | #endif |
287 | } |
288 | |
289 | void CompressedSparseColumn::SparseConv( |
290 | const conv_param_t<>& conv_p, |
291 | const block_type_t& block, |
292 | const uint8_t* A, |
293 | int32_t A_zero_point, |
294 | bool accumulation, |
295 | int32_t* C, |
296 | int ldc) const { |
297 | int K = NumOfRows(); |
298 | int N = block.col_size; |
299 | |
300 | if (K == 0 || N == 0) { |
301 | return; |
302 | } |
303 | |
304 | #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN |
305 | std::chrono::time_point<std::chrono::high_resolution_clock> t_start, t_end; |
306 | double dt; |
307 | t_start = std::chrono::high_resolution_clock::now(); |
308 | #endif |
309 | |
310 | // TODO: if not hyper sparse, transpose a block of A matrix as in SpMDM. |
311 | if (!accumulation) { |
312 | for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { |
313 | for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { |
314 | C[(i - block.row_start) * ldc + j - block.col_start] = 0; |
315 | } |
316 | } |
317 | } |
318 | for (int j = block.col_start; j < block.col_start + block.col_size; ++j) { |
319 | for (int k = colptr_[j]; k < colptr_[j + 1]; ++k) { |
320 | int v = values_[k]; |
321 | for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { |
322 | int ow = i % conv_p.OUT_DIM[1]; |
323 | int oh = i / conv_p.OUT_DIM[1] % conv_p.OUT_DIM[0]; |
324 | int n = i / conv_p.OUT_DIM[1] / conv_p.OUT_DIM[0]; |
325 | assert(n < conv_p.MB); |
326 | int iw = -conv_p.pad[1] + ow * conv_p.stride[1] + kw_[k]; |
327 | int ih = -conv_p.pad[0] + oh * conv_p.stride[0] + kh_[k]; |
328 | |
329 | if (ih >= 0 && ih < conv_p.IN_DIM[0] && iw >= 0 && |
330 | iw < conv_p.IN_DIM[1]) { |
331 | C[(i - block.row_start) * ldc + j - block.col_start] += |
332 | A[((n * conv_p.IN_DIM[0] + ih) * conv_p.IN_DIM[1] + iw) * |
333 | conv_p.IC + |
334 | ic_[k]] * |
335 | v; |
336 | } else { |
337 | C[(i - block.row_start) * ldc + j - block.col_start] += |
338 | A_zero_point * v; |
339 | } |
340 | } |
341 | } |
342 | } // for each column of B |
343 | |
344 | #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN |
345 | t_end = std::chrono::high_resolution_clock::now(); |
346 | dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start) |
347 | .count(); |
348 | sconv_run_time += (dt); |
349 | #endif |
350 | } |
351 | |
352 | } // namespace fbgemm |
353 | |