/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmSparse.h"

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif
#include "./MaskAvx2.h"

namespace fbgemm {
namespace internal {

void SparseDenseMMAvx2(
    int M,
    int N,
    const int* row_ptr,
    const int* col_idx,
    const float* values,
    const float* B,
    int ldb,
    float* C,
    int ldc,
    bool accum) {
  // Calculates accum ? C += A * B : C = A * B
  // size of values is equal to number of non-zeros (nnzs)
  // size of row_ptr is equal to M + 1
  // size of col_idx is equal to nnzs
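  //
  // Illustrative example (hypothetical inputs, not taken from this file): for
  // a 2x3 sparse A stored in CSR form as
  //   values  = {1.f, 2.f, 3.f}
  //   col_idx = {0, 2, 1}
  //   row_ptr = {0, 2, 3}
  // row 0 holds A(0,0)=1 and A(0,2)=2 and row 1 holds A(1,1)=3, so
  //   SparseDenseMMAvx2(2, N, row_ptr, col_idx, values, B, /*ldb=*/N, C,
  //                     /*ldc=*/N, /*accum=*/false)
  // computes C = A * B for a contiguous row-major B with N columns.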

  constexpr int VLEN = 8;
  int j = 0;
  int effective_N = (N + VLEN - 1) / (2 * VLEN) * (2 * VLEN);
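  // The loop below walks the columns of B/C in pairs of VLEN-wide blocks.
  // effective_N is chosen so that, for every j the loop visits, the first
  // VLEN columns (j .. j+VLEN-1) are fully in bounds and can use unmasked
  // loads/stores, while the second VLEN columns may be partial and go through
  // mask_v. Whatever is left past effective_N (at most VLEN columns) is
  // handled by the masked remainder loop further down.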
  for (; j < effective_N; j += (2 * VLEN)) {
    // r1 is for j:j+VLEN
    // r2 is for j+VLEN:j+2*VLEN
    // r2_rem is used to calculate the mask for r2
    int r2_rem = N - VLEN - j;
    r2_rem = (r2_rem <= VLEN) ? r2_rem : VLEN;
    r2_rem = (r2_rem < 0) ? 0 : r2_rem;
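    // avx2_ps_or_epi32_masks[n] (from MaskAvx2.h) is used here as a mask that
    // enables the first n of the 8 lanes, so the maskload/maskstore below only
    // touch the valid tail columns. For example, with N = 13 and j = 0,
    // r2_rem = 5 and the r2 half only reads/writes columns 8..12.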
    __m256i mask_v = _mm256_loadu_si256(
        reinterpret_cast<const __m256i*>(&avx2_ps_or_epi32_masks[r2_rem]));
    for (int i = 0; i < M; ++i) {
      __m256 c_v_r1;
      __m256 c_v_r2;
      if (accum) {
        c_v_r1 = _mm256_loadu_ps(C + i * ldc + j);
        c_v_r2 = _mm256_maskload_ps(C + i * ldc + j + VLEN, mask_v);
      } else {
        c_v_r1 = _mm256_set1_ps(0.0f);
        c_v_r2 = _mm256_set1_ps(0.0f);
      }
      int r = row_ptr[i];
      // unrolled by 4
      for (; r < row_ptr[i + 1] - 4; r += 4) {
        int acbr_0 = col_idx[r + 0];
        int acbr_1 = col_idx[r + 1];
        int acbr_2 = col_idx[r + 2];
        int acbr_3 = col_idx[r + 3];
        __m256 a_v0 = _mm256_set1_ps(values[r + 0]);
        __m256 a_v1 = _mm256_set1_ps(values[r + 1]);
        __m256 a_v2 = _mm256_set1_ps(values[r + 2]);
        __m256 a_v3 = _mm256_set1_ps(values[r + 3]);
        __m256 br_v_0_r1 = _mm256_loadu_ps(B + acbr_0 * ldb + j);
        __m256 br_v_1_r1 = _mm256_loadu_ps(B + acbr_1 * ldb + j);
        __m256 br_v_2_r1 = _mm256_loadu_ps(B + acbr_2 * ldb + j);
        __m256 br_v_3_r1 = _mm256_loadu_ps(B + acbr_3 * ldb + j);
        __m256 br_v_0_r2 = _mm256_loadu_ps(B + acbr_0 * ldb + j + VLEN);
        __m256 br_v_1_r2 = _mm256_loadu_ps(B + acbr_1 * ldb + j + VLEN);
        __m256 br_v_2_r2 = _mm256_loadu_ps(B + acbr_2 * ldb + j + VLEN);
        __m256 br_v_3_r2 = _mm256_loadu_ps(B + acbr_3 * ldb + j + VLEN);
        c_v_r1 = _mm256_fmadd_ps(a_v0, br_v_0_r1, c_v_r1);
        c_v_r1 = _mm256_fmadd_ps(a_v1, br_v_1_r1, c_v_r1);
        c_v_r1 = _mm256_fmadd_ps(a_v2, br_v_2_r1, c_v_r1);
        c_v_r1 = _mm256_fmadd_ps(a_v3, br_v_3_r1, c_v_r1);
        c_v_r2 = _mm256_fmadd_ps(a_v0, br_v_0_r2, c_v_r2);
        c_v_r2 = _mm256_fmadd_ps(a_v1, br_v_1_r2, c_v_r2);
        c_v_r2 = _mm256_fmadd_ps(a_v2, br_v_2_r2, c_v_r2);
        c_v_r2 = _mm256_fmadd_ps(a_v3, br_v_3_r2, c_v_r2);
      }
      for (; r < row_ptr[i + 1]; ++r) {
        int acbr = col_idx[r];
        __m256 a_v = _mm256_set1_ps(values[r]);
        __m256 br_v_r1 = _mm256_loadu_ps(B + acbr * ldb + j);
        __m256 br_v_r2 = _mm256_maskload_ps(B + acbr * ldb + j + VLEN, mask_v);
        c_v_r1 = _mm256_fmadd_ps(a_v, br_v_r1, c_v_r1);
        c_v_r2 = _mm256_fmadd_ps(a_v, br_v_r2, c_v_r2);
      }
      _mm256_storeu_ps(C + i * ldc + j, c_v_r1);
      _mm256_maskstore_ps(C + i * ldc + j + VLEN, mask_v, c_v_r2);
    } // i loop
  }
  // Handle remainder j loop
  int rem = N - j;
  if (rem > 0) {
    for (int i = 0; i < M; ++i) {
      __m256 c_v_r;
      __m256i mask_v = _mm256_loadu_si256(
          reinterpret_cast<const __m256i*>(&avx2_ps_or_epi32_masks[rem]));
      if (accum) {
        c_v_r = _mm256_maskload_ps(C + i * ldc + j, mask_v);
      } else {
        c_v_r = _mm256_set1_ps(0.0f);
      }
      int r = row_ptr[i];
      // unrolled by 4
      for (; r < row_ptr[i + 1] - 4; r += 4) {
        int acbr_0 = col_idx[r + 0];
        int acbr_1 = col_idx[r + 1];
        int acbr_2 = col_idx[r + 2];
        int acbr_3 = col_idx[r + 3];
        __m256 a_v0 = _mm256_set1_ps(values[r + 0]);
        __m256 a_v1 = _mm256_set1_ps(values[r + 1]);
        __m256 a_v2 = _mm256_set1_ps(values[r + 2]);
        __m256 a_v3 = _mm256_set1_ps(values[r + 3]);
        __m256 br_v_r0 = _mm256_maskload_ps(B + acbr_0 * ldb + j, mask_v);
        __m256 br_v_r1 = _mm256_maskload_ps(B + acbr_1 * ldb + j, mask_v);
        __m256 br_v_r2 = _mm256_maskload_ps(B + acbr_2 * ldb + j, mask_v);
        __m256 br_v_r3 = _mm256_maskload_ps(B + acbr_3 * ldb + j, mask_v);
        c_v_r = _mm256_fmadd_ps(a_v0, br_v_r0, c_v_r);
        c_v_r = _mm256_fmadd_ps(a_v1, br_v_r1, c_v_r);
        c_v_r = _mm256_fmadd_ps(a_v2, br_v_r2, c_v_r);
        c_v_r = _mm256_fmadd_ps(a_v3, br_v_r3, c_v_r);
      }
      // Handle remainder r loop
      for (; r < row_ptr[i + 1]; ++r) {
        int acbr = col_idx[r];
        __m256 a_v = _mm256_set1_ps(values[r]);
        __m256 br_v_r = _mm256_maskload_ps(B + acbr * ldb + j, mask_v);
        c_v_r = _mm256_fmadd_ps(a_v, br_v_r, c_v_r);
      }
      _mm256_maskstore_ps(C + i * ldc + j, mask_v, c_v_r);
    }
  }
}
} // namespace internal
} // namespace fbgemm