/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmSparse.h"

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif
#include "./MaskAvx2.h"

namespace fbgemm {
namespace internal {

void SparseDenseMMAvx2(
    int M,
    int N,
    const int* row_ptr,
    const int* col_idx,
    const float* values,
    const float* B,
    int ldb,
    float* C,
    int ldc,
    bool accum) {
  // Calculates accum ? C += A * B : C = A * B
  // size of values is equal to the number of non-zeros (nnzs)
  // size of row_ptr is equal to M + 1
  // size of col_idx is equal to nnzs
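  //
  // A is in compressed sparse row (CSR) format. For example, the 2x3 matrix
  //   [ 1 0 2 ]
  //   [ 0 3 0 ]
  // would be passed as values = {1, 2, 3}, col_idx = {0, 2, 1} and
  // row_ptr = {0, 2, 3}.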

  constexpr int VLEN = 8;
  int j = 0;
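  // effective_N rounds N so that the main loop below works on blocks of
  // 2 * VLEN columns (masking the upper VLEN half of a block when it runs
  // past N); at most VLEN leftover columns fall through to the remainder
  // loop further down.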
  int effective_N = (N + VLEN - 1) / (2 * VLEN) * (2 * VLEN);
  for (; j < effective_N; j += (2 * VLEN)) {
    // r1 is for j:j+VLEN
    // r2 is for j+VLEN:j+2*VLEN
    // r2_rem is used to calculate the mask for r2
    int r2_rem = N - VLEN - j;
    r2_rem = (r2_rem <= VLEN) ? r2_rem : VLEN;
    r2_rem = (r2_rem < 0) ? 0 : r2_rem;
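    // mask_v enables the first r2_rem lanes for the masked loads/stores of
    // the j+VLEN:j+2*VLEN columns.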
    __m256i mask_v = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(&avx2_ps_or_epi32_masks[r2_rem]));
    for (int i = 0; i < M; ++i) {
      __m256 c_v_r1;
      __m256 c_v_r2;
      if (accum) {
        c_v_r1 = _mm256_loadu_ps(C + i * ldc + j);
        c_v_r2 = _mm256_maskload_ps(C + i * ldc + j + VLEN, mask_v);
      } else {
        c_v_r1 = _mm256_set1_ps(0.0f);
        c_v_r2 = _mm256_set1_ps(0.0f);
      }
      int r = row_ptr[i];
      // unrolled by 4
      for (; r < row_ptr[i + 1] - 4; r += 4) {
        int acbr_0 = col_idx[r + 0];
        int acbr_1 = col_idx[r + 1];
        int acbr_2 = col_idx[r + 2];
        int acbr_3 = col_idx[r + 3];
        __m256 a_v0 = _mm256_set1_ps(values[r + 0]);
        __m256 a_v1 = _mm256_set1_ps(values[r + 1]);
        __m256 a_v2 = _mm256_set1_ps(values[r + 2]);
        __m256 a_v3 = _mm256_set1_ps(values[r + 3]);
        __m256 br_v_0_r1 = _mm256_loadu_ps(B + acbr_0 * ldb + j);
        __m256 br_v_1_r1 = _mm256_loadu_ps(B + acbr_1 * ldb + j);
        __m256 br_v_2_r1 = _mm256_loadu_ps(B + acbr_2 * ldb + j);
        __m256 br_v_3_r1 = _mm256_loadu_ps(B + acbr_3 * ldb + j);
        __m256 br_v_0_r2 = _mm256_loadu_ps(B + acbr_0 * ldb + j + VLEN);
        __m256 br_v_1_r2 = _mm256_loadu_ps(B + acbr_1 * ldb + j + VLEN);
        __m256 br_v_2_r2 = _mm256_loadu_ps(B + acbr_2 * ldb + j + VLEN);
        __m256 br_v_3_r2 = _mm256_loadu_ps(B + acbr_3 * ldb + j + VLEN);
        c_v_r1 = _mm256_fmadd_ps(a_v0, br_v_0_r1, c_v_r1);
        c_v_r1 = _mm256_fmadd_ps(a_v1, br_v_1_r1, c_v_r1);
        c_v_r1 = _mm256_fmadd_ps(a_v2, br_v_2_r1, c_v_r1);
        c_v_r1 = _mm256_fmadd_ps(a_v3, br_v_3_r1, c_v_r1);
        c_v_r2 = _mm256_fmadd_ps(a_v0, br_v_0_r2, c_v_r2);
        c_v_r2 = _mm256_fmadd_ps(a_v1, br_v_1_r2, c_v_r2);
        c_v_r2 = _mm256_fmadd_ps(a_v2, br_v_2_r2, c_v_r2);
        c_v_r2 = _mm256_fmadd_ps(a_v3, br_v_3_r2, c_v_r2);
      }
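      // Handle remainder r loop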
      for (; r < row_ptr[i + 1]; ++r) {
        int acbr = col_idx[r];
        __m256 a_v = _mm256_set1_ps(values[r]);
        __m256 br_v_r1 = _mm256_loadu_ps(B + acbr * ldb + j);
        __m256 br_v_r2 = _mm256_maskload_ps(B + acbr * ldb + j + VLEN, mask_v);
        c_v_r1 = _mm256_fmadd_ps(a_v, br_v_r1, c_v_r1);
        c_v_r2 = _mm256_fmadd_ps(a_v, br_v_r2, c_v_r2);
      }
      _mm256_storeu_ps(C + i * ldc + j, c_v_r1);
      _mm256_maskstore_ps(C + i * ldc + j + VLEN, mask_v, c_v_r2);
    } // i loop
  }
  // Handle remainder j loop
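  // (at most VLEN columns are left here, so one masked vector per row
  // covers them)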
  int rem = N - j;
  if (rem > 0) {
    for (int i = 0; i < M; ++i) {
      __m256 c_v_r;
      __m256i mask_v = _mm256_loadu_si256(
          reinterpret_cast<const __m256i*>(&avx2_ps_or_epi32_masks[rem]));
      if (accum) {
        c_v_r = _mm256_maskload_ps(C + i * ldc + j, mask_v);
      } else {
        c_v_r = _mm256_set1_ps(0.0f);
      }
      int r = row_ptr[i];
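      // unrolled by 4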
      for (; r < row_ptr[i + 1] - 4; r += 4) {
        int acbr_0 = col_idx[r + 0];
        int acbr_1 = col_idx[r + 1];
        int acbr_2 = col_idx[r + 2];
        int acbr_3 = col_idx[r + 3];
        __m256 a_v0 = _mm256_set1_ps(values[r + 0]);
        __m256 a_v1 = _mm256_set1_ps(values[r + 1]);
        __m256 a_v2 = _mm256_set1_ps(values[r + 2]);
        __m256 a_v3 = _mm256_set1_ps(values[r + 3]);
        __m256 br_v_r0 = _mm256_maskload_ps(B + acbr_0 * ldb + j, mask_v);
        __m256 br_v_r1 = _mm256_maskload_ps(B + acbr_1 * ldb + j, mask_v);
        __m256 br_v_r2 = _mm256_maskload_ps(B + acbr_2 * ldb + j, mask_v);
        __m256 br_v_r3 = _mm256_maskload_ps(B + acbr_3 * ldb + j, mask_v);
        c_v_r = _mm256_fmadd_ps(a_v0, br_v_r0, c_v_r);
        c_v_r = _mm256_fmadd_ps(a_v1, br_v_r1, c_v_r);
        c_v_r = _mm256_fmadd_ps(a_v2, br_v_r2, c_v_r);
        c_v_r = _mm256_fmadd_ps(a_v3, br_v_r3, c_v_r);
      }
      // Handle remainder r loop
      for (; r < row_ptr[i + 1]; ++r) {
        int acbr = col_idx[r];
        __m256 a_v = _mm256_set1_ps(values[r]);
        __m256 br_v_r = _mm256_maskload_ps(B + acbr * ldb + j, mask_v);
        c_v_r = _mm256_fmadd_ps(a_v, br_v_r, c_v_r);
      }
      _mm256_maskstore_ps(C + i * ldc + j, mask_v, c_v_r);
    }
  }
}
} // namespace internal
} // namespace fbgemm