/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
7#define FBGEMM_EXPORTS
8#include "fbgemm/FbgemmSparse.h"
9
10#if defined(__x86_64__) || defined(__i386__) || \
11 (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
12#include <immintrin.h>
13#endif
14
15namespace fbgemm {
16namespace internal {
17
// Computes C = A * B (or C += A * B when `accum` is true), where A is a
// sparse matrix stored in CSR format (row_ptr / col_idx / values) and
// B, C are dense row-major matrices.
//
// A is M x K (K is implied by the column indices), B is K x N with row
// stride ldb, and C is M x N with row stride ldc. AVX512 kernel: each
// output row is produced 2*VLEN (= 32) floats at a time, followed by a
// masked remainder pass for a tail of up to VLEN columns.
//
// Parameters:
//   M, N     output dimensions
//   row_ptr  size M + 1; the non-zeros of row i of A occupy the index
//            range [row_ptr[i], row_ptr[i + 1]) in col_idx / values
//   col_idx  size nnz; column of A (i.e. row of B) for each stored value
//   values   size nnz; the non-zero values of A
//   B, ldb   dense input matrix and its leading dimension (row stride)
//   C, ldc   output matrix and its leading dimension (row stride)
//   accum    true: accumulate into existing C; false: overwrite C
void SparseDenseMMAvx512(
    int M,
    int N,
    const int* row_ptr,
    const int* col_idx,
    const float* values,
    const float* B,
    int ldb,
    float* C,
    int ldc,
    bool accum) {
  // Calculates accum ? C += A * B : C = A * B
  // size of values is equal to number of non-zeros (nnzs)
  // size of row_ptr is equal to M + 1
  // size of col_idx is equal to nnzs
  constexpr int VLEN = 16; // floats per __m512 vector
  int j = 0;
  // Column count covered by the 2*VLEN-wide main loop: (N + VLEN - 1)
  // rounded up to a multiple of 2*VLEN. Consequently:
  //   - if N mod (2*VLEN) > VLEN, effective_N >= N and the main loop also
  //     absorbs the tail, with the second vector of each pair masked;
  //   - otherwise effective_N <= N and a tail of <= VLEN columns is left
  //     for the masked remainder loop below.
  const int effective_N = ((int)((N + VLEN - 1) / (2 * VLEN))) * (2 * VLEN);
  for (; j < effective_N; j += 2 * VLEN) {
    // r1 is for j:j+VLEN
    // r2 is for j+VLEN:j+2*VLEN
    // r2_rem is used to calculate the mask for r2
    int r2_rem = N - VLEN - j;
    r2_rem = (r2_rem <= VLEN) ? r2_rem : (VLEN);
    r2_rem = (r2_rem < 0) ? 0 : r2_rem;
    // r2_rem is clamped to [0, VLEN], so the shift below fits: the low
    // r2_rem bits select the lanes of r2 that are inside the matrix.
    __mmask16 mask_v = (((long long)1) << r2_rem) - 1;
    for (int i = 0; i < M; ++i) {
      __m512 c_v_r1;
      __m512 c_v_r2;
      if (accum) {
        // Seed the accumulators with the existing C values; masked-out
        // lanes of r2 are zeroed and never stored back.
        c_v_r1 = _mm512_loadu_ps(C + i * ldc + j);
        c_v_r2 = _mm512_maskz_loadu_ps(mask_v, C + i * ldc + j + VLEN);
      } else {
        c_v_r1 = _mm512_set1_ps(0.0f);
        c_v_r2 = _mm512_set1_ps(0.0f);
      }
      // Non-zeros of row i live in [row_ptr[i], row_ptr[i + 1]);
      // r_end_aligned is that count rounded down to a multiple of 3.
      int r = row_ptr[i];
      int r_end_aligned = row_ptr[i] + (row_ptr[i + 1] - row_ptr[i]) / 3 * 3;
      // unrolled by 3 (presumably to overlap FMA latency — tuning choice)
      for (; r < r_end_aligned; r += 3) {
        // Each non-zero A(i, acbr_k) scales row acbr_k of B.
        int acbr_0 = col_idx[r + 0];
        int acbr_1 = col_idx[r + 1];
        int acbr_2 = col_idx[r + 2];
        __m512 a_v_0 = _mm512_set1_ps(values[r + 0]);
        __m512 a_v_1 = _mm512_set1_ps(values[r + 1]);
        __m512 a_v_2 = _mm512_set1_ps(values[r + 2]);
        __m512 br_v_0_r1 = _mm512_loadu_ps(B + acbr_0 * ldb + j);
        __m512 br_v_1_r1 = _mm512_loadu_ps(B + acbr_1 * ldb + j);
        __m512 br_v_2_r1 = _mm512_loadu_ps(B + acbr_2 * ldb + j);
        c_v_r1 = _mm512_fmadd_ps(a_v_0, br_v_0_r1, c_v_r1);
        c_v_r1 = _mm512_fmadd_ps(a_v_1, br_v_1_r1, c_v_r1);
        c_v_r1 = _mm512_fmadd_ps(a_v_2, br_v_2_r1, c_v_r1);
        __m512 br_v_0_r2 =
            _mm512_maskz_loadu_ps(mask_v, B + acbr_0 * ldb + j + VLEN);
        __m512 br_v_1_r2 =
            _mm512_maskz_loadu_ps(mask_v, B + acbr_1 * ldb + j + VLEN);
        __m512 br_v_2_r2 =
            _mm512_maskz_loadu_ps(mask_v, B + acbr_2 * ldb + j + VLEN);
        c_v_r2 = _mm512_fmadd_ps(a_v_0, br_v_0_r2, c_v_r2);
        c_v_r2 = _mm512_fmadd_ps(a_v_1, br_v_1_r2, c_v_r2);
        c_v_r2 = _mm512_fmadd_ps(a_v_2, br_v_2_r2, c_v_r2);
      }
      // Leftover (< 3) non-zeros of this row.
      for (; r < row_ptr[i + 1]; ++r) {
        int acbr = col_idx[r];
        __m512 a_v = _mm512_set1_ps(values[r]);
        __m512 br_v_r1 = _mm512_loadu_ps(B + acbr * ldb + j);
        c_v_r1 = _mm512_fmadd_ps(a_v, br_v_r1, c_v_r1);
        __m512 br_v_r2 =
            _mm512_maskz_loadu_ps(mask_v, B + acbr * ldb + j + VLEN);
        c_v_r2 = _mm512_fmadd_ps(a_v, br_v_r2, c_v_r2);
      }
      _mm512_storeu_ps(C + i * ldc + j, c_v_r1);
      // Only the in-bounds lanes of r2 are written back.
      _mm512_mask_storeu_ps(C + i * ldc + j + VLEN, mask_v, c_v_r2);
    } // i loop
  }
  // Handle remainder j loop: at most VLEN trailing columns, all masked.
  int rem = N - j;
  if (rem > 0) {
    for (int i = 0; i < M; ++i) {
      __m512 c_v;
      // rem is in (0, VLEN], so the shift fits __mmask16.
      __mmask16 mask_v = (((long long)1) << rem) - 1;
      if (accum) {
        c_v = _mm512_maskz_loadu_ps(mask_v, C + i * ldc + j);
      } else {
        c_v = _mm512_set1_ps(0.0f);
      }
      // Same CSR row walk as above, but with a single masked accumulator,
      // so a deeper unroll (8) is used here.
      int r = row_ptr[i];
      int r_end_aligned = row_ptr[i] + (row_ptr[i + 1] - row_ptr[i]) / 8 * 8;
      // unrolled by 8
      for (; r < r_end_aligned; r += 8) {
        int acbr_0 = col_idx[r + 0];
        int acbr_1 = col_idx[r + 1];
        int acbr_2 = col_idx[r + 2];
        int acbr_3 = col_idx[r + 3];
        int acbr_4 = col_idx[r + 4];
        int acbr_5 = col_idx[r + 5];
        int acbr_6 = col_idx[r + 6];
        int acbr_7 = col_idx[r + 7];
        __m512 a_v_0 = _mm512_set1_ps(values[r + 0]);
        __m512 a_v_1 = _mm512_set1_ps(values[r + 1]);
        __m512 a_v_2 = _mm512_set1_ps(values[r + 2]);
        __m512 a_v_3 = _mm512_set1_ps(values[r + 3]);
        __m512 a_v_4 = _mm512_set1_ps(values[r + 4]);
        __m512 a_v_5 = _mm512_set1_ps(values[r + 5]);
        __m512 a_v_6 = _mm512_set1_ps(values[r + 6]);
        __m512 a_v_7 = _mm512_set1_ps(values[r + 7]);
        __m512 br_v_0 = _mm512_maskz_loadu_ps(mask_v, B + acbr_0 * ldb + j);
        __m512 br_v_1 = _mm512_maskz_loadu_ps(mask_v, B + acbr_1 * ldb + j);
        __m512 br_v_2 = _mm512_maskz_loadu_ps(mask_v, B + acbr_2 * ldb + j);
        __m512 br_v_3 = _mm512_maskz_loadu_ps(mask_v, B + acbr_3 * ldb + j);
        __m512 br_v_4 = _mm512_maskz_loadu_ps(mask_v, B + acbr_4 * ldb + j);
        __m512 br_v_5 = _mm512_maskz_loadu_ps(mask_v, B + acbr_5 * ldb + j);
        __m512 br_v_6 = _mm512_maskz_loadu_ps(mask_v, B + acbr_6 * ldb + j);
        __m512 br_v_7 = _mm512_maskz_loadu_ps(mask_v, B + acbr_7 * ldb + j);
        c_v = _mm512_fmadd_ps(a_v_0, br_v_0, c_v);
        c_v = _mm512_fmadd_ps(a_v_1, br_v_1, c_v);
        c_v = _mm512_fmadd_ps(a_v_2, br_v_2, c_v);
        c_v = _mm512_fmadd_ps(a_v_3, br_v_3, c_v);
        c_v = _mm512_fmadd_ps(a_v_4, br_v_4, c_v);
        c_v = _mm512_fmadd_ps(a_v_5, br_v_5, c_v);
        c_v = _mm512_fmadd_ps(a_v_6, br_v_6, c_v);
        c_v = _mm512_fmadd_ps(a_v_7, br_v_7, c_v);
      }
      // Handle remainder r loop: leftover (< 8) non-zeros of this row.
      for (; r < row_ptr[i + 1]; ++r) {
        int acbr = col_idx[r];
        __m512 a_v = _mm512_set1_ps(values[r]);
        __m512 br_v = _mm512_maskz_loadu_ps(mask_v, B + acbr * ldb + j);
        c_v = _mm512_fmadd_ps(a_v, br_v, c_v);
      }
      _mm512_mask_storeu_ps(C + i * ldc + j, mask_v, c_v);
    }
  }
}
152} // namespace internal
153} // namespace fbgemm
154