1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/FbgemmSparse.h" |
9 | |
10 | #if defined(__x86_64__) || defined(__i386__) || \ |
11 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
12 | #include <immintrin.h> |
13 | #endif |
14 | |
15 | namespace fbgemm { |
16 | namespace internal { |
17 | |
// Sparse (CSR) x dense matrix multiply using AVX-512:
//   accum ? C += A * B : C = A * B
// A is an M-row sparse matrix in CSR form (row_ptr / col_idx / values);
// col_idx[r] selects the row of B that values[r] multiplies, so B must have
// at least max(col_idx)+1 rows. B and C are dense row-major with leading
// dimensions ldb and ldc (both assumed >= N).
//
// Columns are processed 2*VLEN (= 32) at a time: the first 16-lane vector
// (r1) is always fully in bounds and uses unmasked loads/stores, while the
// second (r2) is masked to cover a possibly partial tail. Columns left over
// after the paired-vector loop are handled by a fully masked remainder loop.
void SparseDenseMMAvx512(
    int M, // rows of A and C
    int N, // columns of B and C
    const int* row_ptr, // CSR row offsets, size M + 1
    const int* col_idx, // CSR column indices, size nnz
    const float* values, // CSR non-zero values, size nnz
    const float* B, // dense input, row-major
    int ldb, // leading dimension (row stride) of B
    float* C, // dense output, row-major
    int ldc, // leading dimension (row stride) of C
    bool accum) { // true: accumulate into C; false: overwrite C
  // Calculates accum ? C += A * B : C = A * B
  // size of values is equal to number of non-zeros (nnzs)
  // size of row_ptr is equal to M + 1
  // size of col_idx is equal to nnzs
  constexpr int VLEN = 16; // floats per __m512 register
  int j = 0;
  // Column count covered by the paired-vector (2*VLEN-wide) main loop:
  // ceil over VLEN, floor over 2*VLEN, i.e. floor((N + VLEN - 1) / 32) * 32.
  // This guarantees that inside the loop j + VLEN <= N, so the r1 vector can
  // use unmasked loads/stores; only r2 ever needs a mask.
  const int effective_N = ((int)((N + VLEN - 1) / (2 * VLEN))) * (2 * VLEN);
  for (; j < effective_N; j += 2 * VLEN) {
    // r1 is for j:j+VLEN
    // r2 is for j+VLEN:j+2*VLEN
    // r2_rem is used to calculate the mask for r2
    int r2_rem = N - VLEN - j; // valid lanes left for the second vector
    r2_rem = (r2_rem <= VLEN) ? r2_rem : (VLEN); // clamp to [0, VLEN]
    r2_rem = (r2_rem < 0) ? 0 : r2_rem;
    // Low r2_rem bits set; computed in 64-bit so r2_rem == 16 still yields
    // 0xFFFF without shifting past the mask width.
    __mmask16 mask_v = (((long long)1) << r2_rem) - 1;
    for (int i = 0; i < M; ++i) {
      __m512 c_v_r1; // accumulator for columns j : j+VLEN
      __m512 c_v_r2; // accumulator for columns j+VLEN : j+2*VLEN (masked)
      if (accum) {
        // Seed accumulators with the existing C values (C += A * B).
        c_v_r1 = _mm512_loadu_ps(C + i * ldc + j);
        c_v_r2 = _mm512_maskz_loadu_ps(mask_v, C + i * ldc + j + VLEN);
      } else {
        c_v_r1 = _mm512_set1_ps(0.0f);
        c_v_r2 = _mm512_set1_ps(0.0f);
      }
      int r = row_ptr[i];
      // Largest multiple-of-3 prefix of row i's non-zeros.
      int r_end_aligned = row_ptr[i] + (row_ptr[i + 1] - row_ptr[i]) / 3 * 3;
      // unrolled by 3
      for (; r < r_end_aligned; r += 3) {
        // Each non-zero A(i, acbr_k) scales row acbr_k of B.
        int acbr_0 = col_idx[r + 0];
        int acbr_1 = col_idx[r + 1];
        int acbr_2 = col_idx[r + 2];
        __m512 a_v_0 = _mm512_set1_ps(values[r + 0]);
        __m512 a_v_1 = _mm512_set1_ps(values[r + 1]);
        __m512 a_v_2 = _mm512_set1_ps(values[r + 2]);
        // First vector: unmasked, always fully in bounds (see effective_N).
        __m512 br_v_0_r1 = _mm512_loadu_ps(B + acbr_0 * ldb + j);
        __m512 br_v_1_r1 = _mm512_loadu_ps(B + acbr_1 * ldb + j);
        __m512 br_v_2_r1 = _mm512_loadu_ps(B + acbr_2 * ldb + j);
        c_v_r1 = _mm512_fmadd_ps(a_v_0, br_v_0_r1, c_v_r1);
        c_v_r1 = _mm512_fmadd_ps(a_v_1, br_v_1_r1, c_v_r1);
        c_v_r1 = _mm512_fmadd_ps(a_v_2, br_v_2_r1, c_v_r1);
        // Second vector: masked so out-of-range lanes load as zero.
        __m512 br_v_0_r2 =
            _mm512_maskz_loadu_ps(mask_v, B + acbr_0 * ldb + j + VLEN);
        __m512 br_v_1_r2 =
            _mm512_maskz_loadu_ps(mask_v, B + acbr_1 * ldb + j + VLEN);
        __m512 br_v_2_r2 =
            _mm512_maskz_loadu_ps(mask_v, B + acbr_2 * ldb + j + VLEN);
        c_v_r2 = _mm512_fmadd_ps(a_v_0, br_v_0_r2, c_v_r2);
        c_v_r2 = _mm512_fmadd_ps(a_v_1, br_v_1_r2, c_v_r2);
        c_v_r2 = _mm512_fmadd_ps(a_v_2, br_v_2_r2, c_v_r2);
      }
      // Remainder of the unroll-by-3 loop: 0-2 leftover non-zeros.
      for (; r < row_ptr[i + 1]; ++r) {
        int acbr = col_idx[r];
        __m512 a_v = _mm512_set1_ps(values[r]);
        __m512 br_v_r1 = _mm512_loadu_ps(B + acbr * ldb + j);
        c_v_r1 = _mm512_fmadd_ps(a_v, br_v_r1, c_v_r1);
        __m512 br_v_r2 =
            _mm512_maskz_loadu_ps(mask_v, B + acbr * ldb + j + VLEN);
        c_v_r2 = _mm512_fmadd_ps(a_v, br_v_r2, c_v_r2);
      }
      // Unmasked store for r1; masked store for r2 so columns >= N are
      // never written.
      _mm512_storeu_ps(C + i * ldc + j, c_v_r1);
      _mm512_mask_storeu_ps(C + i * ldc + j + VLEN, mask_v, c_v_r2);
    } // i loop
  }
  // Handle remainder j loop: at most 2*VLEN - 1 columns remain, but only up
  // to VLEN of them can be non-empty here when effective_N == 0 (N <= VLEN)
  // or N - effective_N <= VLEN; a single masked vector covers them.
  int rem = N - j;
  if (rem > 0) {
    for (int i = 0; i < M; ++i) {
      __m512 c_v;
      // Low `rem` bits set (rem <= VLEN, so this fits __mmask16).
      __mmask16 mask_v = (((long long)1) << rem) - 1;
      if (accum) {
        c_v = _mm512_maskz_loadu_ps(mask_v, C + i * ldc + j);
      } else {
        c_v = _mm512_set1_ps(0.0f);
      }
      int r = row_ptr[i];
      // Largest multiple-of-8 prefix of row i's non-zeros. Deeper unroll
      // than the main loop since only one vector is live per non-zero.
      int r_end_aligned = row_ptr[i] + (row_ptr[i + 1] - row_ptr[i]) / 8 * 8;
      // unrolled by 8
      for (; r < r_end_aligned; r += 8) {
        int acbr_0 = col_idx[r + 0];
        int acbr_1 = col_idx[r + 1];
        int acbr_2 = col_idx[r + 2];
        int acbr_3 = col_idx[r + 3];
        int acbr_4 = col_idx[r + 4];
        int acbr_5 = col_idx[r + 5];
        int acbr_6 = col_idx[r + 6];
        int acbr_7 = col_idx[r + 7];
        __m512 a_v_0 = _mm512_set1_ps(values[r + 0]);
        __m512 a_v_1 = _mm512_set1_ps(values[r + 1]);
        __m512 a_v_2 = _mm512_set1_ps(values[r + 2]);
        __m512 a_v_3 = _mm512_set1_ps(values[r + 3]);
        __m512 a_v_4 = _mm512_set1_ps(values[r + 4]);
        __m512 a_v_5 = _mm512_set1_ps(values[r + 5]);
        __m512 a_v_6 = _mm512_set1_ps(values[r + 6]);
        __m512 a_v_7 = _mm512_set1_ps(values[r + 7]);
        __m512 br_v_0 = _mm512_maskz_loadu_ps(mask_v, B + acbr_0 * ldb + j);
        __m512 br_v_1 = _mm512_maskz_loadu_ps(mask_v, B + acbr_1 * ldb + j);
        __m512 br_v_2 = _mm512_maskz_loadu_ps(mask_v, B + acbr_2 * ldb + j);
        __m512 br_v_3 = _mm512_maskz_loadu_ps(mask_v, B + acbr_3 * ldb + j);
        __m512 br_v_4 = _mm512_maskz_loadu_ps(mask_v, B + acbr_4 * ldb + j);
        __m512 br_v_5 = _mm512_maskz_loadu_ps(mask_v, B + acbr_5 * ldb + j);
        __m512 br_v_6 = _mm512_maskz_loadu_ps(mask_v, B + acbr_6 * ldb + j);
        __m512 br_v_7 = _mm512_maskz_loadu_ps(mask_v, B + acbr_7 * ldb + j);
        c_v = _mm512_fmadd_ps(a_v_0, br_v_0, c_v);
        c_v = _mm512_fmadd_ps(a_v_1, br_v_1, c_v);
        c_v = _mm512_fmadd_ps(a_v_2, br_v_2, c_v);
        c_v = _mm512_fmadd_ps(a_v_3, br_v_3, c_v);
        c_v = _mm512_fmadd_ps(a_v_4, br_v_4, c_v);
        c_v = _mm512_fmadd_ps(a_v_5, br_v_5, c_v);
        c_v = _mm512_fmadd_ps(a_v_6, br_v_6, c_v);
        c_v = _mm512_fmadd_ps(a_v_7, br_v_7, c_v);
      }
      // Handle remainder r loop
      for (; r < row_ptr[i + 1]; ++r) {
        int acbr = col_idx[r];
        __m512 a_v = _mm512_set1_ps(values[r]);
        __m512 br_v = _mm512_maskz_loadu_ps(mask_v, B + acbr * ldb + j);
        c_v = _mm512_fmadd_ps(a_v, br_v, c_v);
      }
      // Masked store: only the valid tail columns of C are written.
      _mm512_mask_storeu_ps(C + i * ldc + j, mask_v, c_v);
    }
  }
}
152 | } // namespace internal |
153 | } // namespace fbgemm |
154 | |