/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"

#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif

#include "./MaskAvx2.h"
#include "fbgemm/UtilsAvx2.h"

namespace fbgemm {

PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix(
    int OC,
    int kernel_prod,
    const int8_t* smat)
    : OC_(OC), kernel_prod_(kernel_prod) {
  // The input is in (OC, T, R, S) layout.
  // Transpose it to (T, R, S, OC) so that the channels packed together below
  // are contiguous, which makes packing faster.
  int8_t* smat_transposed = static_cast<int8_t*>(
      fbgemmAlignedAlloc(64, OC * kernel_prod * sizeof(int8_t)));
  for (int i = 0; i < kernel_prod; ++i) {
    for (int j = 0; j < OC; ++j) {
      smat_transposed[i * OC + j] = smat[i + j * kernel_prod];
    }
  }

  // Allocate the packed array: ceil(OC / 32) channel blocks, each holding
  // kernel_prod_aligned 32-byte SIMD registers.
  int kernel_prod_aligned = (kernel_prod + 1) / 2 * 2;
  pmat_ = static_cast<int8_t*>(fbgemmAlignedAlloc(
      64, ((OC + 31) / 32) * kernel_prod_aligned * 32 * sizeof(int8_t)));

  // Pack input matrix
  // The layout is optimized to use vpmaddubsw efficiently (see
  // genMaddEpi16xNPacked function).
  // For a group of 32 channels with kernel_prod == 9 (e.g., a 3x3 kernel),
  // the packed weights occupy 10 (= kernel_prod_aligned) 32B SIMD registers.
  // Denote the jth filter coefficient of the ith channel as (i, j).
  // 0th SIMD register:
  // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3)
  // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3)
  // 1st SIMD register:
  // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3)
  // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3)
  // 2nd SIMD register:
  // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3)
  // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3)
  // 3rd SIMD register:
  // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3)
  // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3)
  // 4-7th SIMD registers: same as the previous 4 registers but for the 4-7th
  // filter coefficients
  // ...
  //
  // REMAINDER
  // If kernel_prod % 4 == 1, for example when kernel_prod == 9,
  // 8th SIMD register:
  // (0, 8), zero, ..., (7, 8), zero
  // (16, 8), zero, ..., (23, 8), zero
  // 9th SIMD register:
  // (8, 8), zero, ..., (15, 8), zero
  // (24, 8), zero, ..., (31, 8), zero
  // We use madd_epi16_packed for this case.
  //
  // If kernel_prod % 4 == 2, for example when kernel_prod == 10,
  // 8th SIMD register:
  // (0, 8), (0, 9), ..., (7, 8), (7, 9)
  // (16, 8), (16, 9), ..., (23, 8), (23, 9)
  // 9th SIMD register:
  // (8, 8), (8, 9), ..., (15, 8), (15, 9)
  // (24, 8), (24, 9), ..., (31, 8), (31, 9)
  //
  // If kernel_prod % 4 == 3, for example when kernel_prod == 11,
  // 8th SIMD register:
  // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero
  // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero
  // 9th SIMD register:
  // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero
  // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero
  // 10th SIMD register:
  // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero
  // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero
  // 11th SIMD register:
  // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero
  // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero
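  //
  // As a worked example (derived from the layout above), with kernel_prod ==
  // 9 (kernel_prod_aligned == 10), the weight of channel 20 for filter
  // coefficient 2 sits in the 1st SIMD register of the first 32-channel
  // block, at byte 18 of that register (upper half starts at byte 16, first
  // channel of that half, third coefficient), i.e. at offset
  // 1 * 32 + 18 == 50 in pmat_. addr() below computes this mapping.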

  // Scratch buffers of SIMD registers, reused for each group of 32 channels.
  auto b_v = static_cast<__m256i*>(
      fbgemmAlignedAlloc(64, kernel_prod * sizeof(__m256i)));
  auto b_interleaved_epi16 = static_cast<__m256i*>(
      fbgemmAlignedAlloc(64, kernel_prod_aligned * sizeof(__m256i)));
  auto b_interleaved_epi32 = static_cast<__m256i*>(
      fbgemmAlignedAlloc(64, kernel_prod_aligned * sizeof(__m256i)));
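  // Pack one group of 32 output channels per iteration. For a trailing
  // partial group, a masked load avoids reading past the end of
  // smat_transposed.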
  for (int k1 = 0; k1 < OC; k1 += 32) {
    int remainder = OC - k1;
    if (remainder < 32) {
      __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
          internal::avx2_ps_or_epi32_masks[remainder / 4]));
      for (int i = 0; i < kernel_prod; ++i) {
        b_v[i] = _mm256_maskload_epi32(
            reinterpret_cast<const int*>(smat_transposed + i * OC + k1),
            mask_v);
      }
    } else {
      for (int i = 0; i < kernel_prod; ++i) {
        b_v[i] = _mm256_lddqu_si256(
            reinterpret_cast<const __m256i*>(smat_transposed + i * OC + k1));
      }
    }

    // Interleave pairs of SIMD registers at byte granularity. b_v[i] holds
    // the ith coefficient of 32 channels (one byte per channel), so after
    // this step each 16-bit lane of b_interleaved_epi16[2 * i] holds
    // coefficients (2 * i, 2 * i + 1) of one channel for channels 0-7 and
    // 16-23, while b_interleaved_epi16[2 * i + 1] covers channels 8-15 and
    // 24-31. When kernel_prod is odd, the last coefficient is paired with
    // zeros.
    __m256i zero_v = _mm256_setzero_si256();
    for (int i = 0; i < kernel_prod_aligned / 2; ++i) {
      if (2 * i + 1 >= kernel_prod) {
        b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v);
        b_interleaved_epi16[2 * i + 1] =
            _mm256_unpackhi_epi8(b_v[2 * i], zero_v);
      } else {
        b_interleaved_epi16[2 * i] =
            _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]);
        b_interleaved_epi16[2 * i + 1] =
            _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]);
      }
    }

    // Interleave groups of 4 registers at 16-bit granularity so that each
    // 32-bit lane ends up with 4 consecutive coefficients of one channel,
    // which is the operand layout vpmaddubsw expects.
    for (int i = 0; i < kernel_prod_aligned / 4; ++i) {
      b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16(
          b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
      b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16(
          b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
      b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16(
          b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
      b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16(
          b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
    }
    // Leftover registers keep the 16-bit interleaved layout (the REMAINDER
    // cases described above).
    for (int i = kernel_prod_aligned / 4 * 4; i < kernel_prod_aligned; ++i) {
      b_interleaved_epi32[i] = b_interleaved_epi16[i];
    }

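    // Write the interleaved registers for this 32-channel block into pmat_.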
    for (int i = 0; i < kernel_prod_aligned; ++i) {
      _mm256_storeu_si256(
          reinterpret_cast<__m256i*>(
              &pmat_[((k1 / 32) * kernel_prod_aligned + i) * 32]),
          b_interleaved_epi32[i]);
    }
  }
  fbgemmAlignedFree(b_v);
  fbgemmAlignedFree(b_interleaved_epi16);
  fbgemmAlignedFree(b_interleaved_epi32);
  fbgemmAlignedFree(smat_transposed);
}

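// Returns the byte offset within pmat_ of the weight for output channel r and
// filter coefficient c, i.e. the inverse of the packed layout described in
// the constructor.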
int PackedDepthWiseConvMatrix::addr(int r, int c) {
  int kernel_prod_aligned = (kernel_prod_ + 1) / 2 * 2;
  if (c >= kernel_prod_ / 4 * 4 &&
      (kernel_prod_ % 4 == 1 || kernel_prod_ % 4 == 2)) {
    // Remainder registers: each channel owns 2 bytes of a register
    // (16-bit interleaved layout).
    int kBlock = r / 32;
    int reg_idx = (r % 16) / 8 + c / 4 * 4;

    int blk_idx = kBlock * kernel_prod_aligned + reg_idx;

    int r_ = r % 8;
    int c_ = c % 4;

    int in_blk_idx = (r % 32) / 16 * 16 + 2 * r_ + c_;
    return blk_idx * 32 + in_blk_idx;

  } else {
    // Main registers: each channel owns 4 bytes of a register
    // (32-bit interleaved layout).
    int kBlock = r / 32;
    int reg_idx = (r % 16) / 4 + c / 4 * 4;

    int blk_idx = kBlock * kernel_prod_aligned + reg_idx;

    int r_ = r % 4;
    int c_ = c % 4;

    int in_blk_idx = (r % 32) / 16 * 16 + 4 * r_ + c_;
    return blk_idx * 32 + in_blk_idx;
  }
}

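// Recovers the original OC x kernel_prod row-major weight matrix by reading
// each packed element back through addr().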
void PackedDepthWiseConvMatrix::unpack(int8_t* unpacked_data) {
  for (int r = 0; r < OC_; ++r) {
    for (int c = 0; c < kernel_prod_; ++c) {
      unpacked_data[r * kernel_prod_ + c] = pmat_[addr(r, c)];
    }
  }
}

PackedDepthWiseConvMatrix::~PackedDepthWiseConvMatrix() {
  fbgemmAlignedFree(pmat_);
}

} // namespace fbgemm