1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #define FBGEMM_EXPORTS |
8 | #include "fbgemm/FbgemmI8DepthwiseAvx2.h" |
9 | |
10 | #if defined(__x86_64__) || defined(__i386__) || \ |
11 | (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))) |
12 | #include <immintrin.h> |
13 | #endif |
14 | |
15 | #include "./MaskAvx2.h" |
16 | #include "fbgemm/UtilsAvx2.h" |
17 | |
18 | namespace fbgemm { |
19 | |
20 | PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix( |
21 | int OC, |
22 | int kernel_prod, |
23 | const int8_t* smat) |
24 | : OC_(OC), kernel_prod_(kernel_prod) { |
25 | // The input is in OC T R S layout. |
26 | // Transpose the input matrix to make packing faster. |
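  // smat[j * kernel_prod + i] is the ith filter coefficient of output channel
  // j; smat_transposed[i * OC + j] makes each filter position contiguous
  // across channels so the packing loop below can load 32 channels at a time.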
27 | int8_t* smat_transposed = static_cast<int8_t*>( |
28 | fbgemmAlignedAlloc(64, OC * kernel_prod * sizeof(int8_t))); |
29 | for (int i = 0; i < kernel_prod; ++i) { |
30 | for (int j = 0; j < OC; ++j) { |
31 | smat_transposed[i * OC + j] = smat[i + j * kernel_prod]; |
32 | } |
33 | } |
34 | |
35 | // Allocate packed arrays |
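  // kernel_prod is rounded up to an even number because vpmaddubsw consumes
  // filter coefficients in pairs (e.g., kernel_prod == 9 -> 10), and OC is
  // rounded up to a multiple of 32 channels per packed block.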
36 | int kernel_prod_aligned = (kernel_prod + 1) / 2 * 2; |
37 | pmat_ = static_cast<int8_t*>(fbgemmAlignedAlloc( |
38 | 64, ((OC + 31) / 32) * kernel_prod_aligned * 32 * sizeof(int8_t))); |
39 | |
40 | // Pack input matrix |
41 | // The layout is optimized to use vpmaddubsw efficiently (see |
42 | // genMaddEpi16xNPacked function). |
  // For a group of 32 channels, we have kernel_prod_aligned 32B SIMD
  // registers (10 when kernel_prod == 9, i.e., a 3x3 kernel).
  // Denote the jth filter coefficient of the ith channel as (i, j).
45 | // 0th SIMD register: |
46 | // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3) |
47 | // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3) |
48 | // 1st SIMD register: |
49 | // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3) |
50 | // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3) |
51 | // 2nd SIMD register: |
52 | // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3) |
53 | // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3) |
54 | // 3rd SIMD register: |
55 | // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3) |
56 | // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3) |
  // 4th-7th SIMD registers: same as the previous 4 registers but for the
  // 4th-7th filter coefficients
59 | // ... |
60 | // |
61 | // REMAINDER |
  // If kernel_prod % 4 == 1, for example when kernel_prod == 9:
63 | // 8th SIMD register: |
64 | // (0, 8), zero, ..., (7, 8), zero |
65 | // (16, 8), zero, ..., (23, 8), zero |
66 | // 9th SIMD register: |
67 | // (8, 8), zero, ..., (15, 8), zero |
68 | // (24, 8), zero, ..., (31, 8), zero |
  // We use madd_epi16_packed instead of vpmaddubsw for this case.
70 | // |
  // If kernel_prod % 4 == 2, for example when kernel_prod == 10:
72 | // 8th SIMD register: |
73 | // (0, 8), (0, 9), ..., (7, 8), (7, 9) |
74 | // (16, 8), (16, 9), ..., (23, 8), (23, 9) |
75 | // 9th SIMD register: |
76 | // (8, 8), (8, 9), ..., (15, 8), (15, 9) |
77 | // (24, 8), (24, 9), ..., (31, 8), (31, 9) |
78 | // |
  // If kernel_prod % 4 == 3, for example when kernel_prod == 11:
80 | // 8th SIMD register: |
81 | // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero |
82 | // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero |
83 | // 9th SIMD register: |
84 | // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero |
85 | // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero |
86 | // 10th SIMD register: |
87 | // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero |
88 | // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero |
89 | // 11th SIMD register: |
90 | // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero |
91 | // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero |
92 | |
93 | // Allocate buffers |
94 | auto b_v = static_cast<__m256i*>( |
95 | fbgemmAlignedAlloc(64, kernel_prod * sizeof(__m256i))); |
96 | auto b_interleaved_epi16 = static_cast<__m256i*>( |
97 | fbgemmAlignedAlloc(64, kernel_prod_aligned * sizeof(__m256i))); |
98 | auto b_interleaved_epi32 = static_cast<__m256i*>( |
99 | fbgemmAlignedAlloc(64, kernel_prod_aligned * sizeof(__m256i))); |
100 | for (int k1 = 0; k1 < OC; k1 += 32) { |
101 | int remainder = OC - k1; |
102 | if (remainder < 32) { |
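      // Partial tail group: a masked 32-bit load reads only the remaining
      // remainder / 4 * 4 channels so we never read past the end of
      // smat_transposed.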
103 | __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>( |
104 | internal::avx2_ps_or_epi32_masks[remainder / 4])); |
105 | for (int i = 0; i < kernel_prod; ++i) { |
106 | b_v[i] = _mm256_maskload_epi32( |
107 | reinterpret_cast<const int*>(smat_transposed + i * OC + k1), |
108 | mask_v); |
109 | } |
110 | } else { |
111 | for (int i = 0; i < kernel_prod; ++i) { |
112 | b_v[i] = _mm256_lddqu_si256( |
113 | reinterpret_cast<const __m256i*>(smat_transposed + i * OC + k1)); |
114 | } |
115 | } |
116 | |
117 | // Interleave 2 SIMD registers |
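    // b_v[i] holds the ith filter coefficient for 32 consecutive channels,
    // one byte per channel. _mm256_unpacklo/hi_epi8 interleave within each
    // 128-bit lane, so the even output register covers channels 0-7 and
    // 16-23 and the odd one covers channels 8-15 and 24-31, with coefficients
    // 2i and 2i+1 adjacent. A trailing odd filter is interleaved with zeros.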
118 | __m256i zero_v = _mm256_setzero_si256(); |
119 | for (int i = 0; i < kernel_prod_aligned / 2; ++i) { |
120 | if (2 * i + 1 >= kernel_prod) { |
121 | b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v); |
122 | b_interleaved_epi16[2 * i + 1] = |
123 | _mm256_unpackhi_epi8(b_v[2 * i], zero_v); |
124 | } else { |
125 | b_interleaved_epi16[2 * i] = |
126 | _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]); |
127 | b_interleaved_epi16[2 * i + 1] = |
128 | _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]); |
129 | } |
130 | } |
131 | |
132 | // Interleave 4 SIMD registers |
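    // A second interleave at 16-bit granularity yields groups of 4
    // consecutive filter coefficients per channel: b_interleaved_epi32[4*i]
    // ends up as (0, 4i..4i+3), (1, 4i..4i+3), ..., matching the layout
    // sketched above.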
133 | for (int i = 0; i < kernel_prod_aligned / 4; ++i) { |
134 | b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16( |
135 | b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]); |
136 | b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16( |
137 | b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]); |
138 | b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16( |
139 | b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]); |
140 | b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16( |
141 | b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]); |
142 | } |
143 | for (int i = kernel_prod_aligned / 4 * 4; i < kernel_prod_aligned; ++i) { |
144 | b_interleaved_epi32[i] = b_interleaved_epi16[i]; |
145 | } |
146 | |
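    // Store this channel group's registers contiguously into pmat_.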
147 | for (int i = 0; i < kernel_prod_aligned; ++i) { |
148 | _mm256_storeu_si256( |
149 | reinterpret_cast<__m256i*>( |
150 | &pmat_[((k1 / 32) * kernel_prod_aligned + i) * 32]), |
151 | b_interleaved_epi32[i]); |
152 | } |
153 | } |
154 | fbgemmAlignedFree(b_v); |
155 | fbgemmAlignedFree(b_interleaved_epi16); |
156 | fbgemmAlignedFree(b_interleaved_epi32); |
157 | fbgemmAlignedFree(smat_transposed); |
158 | } |
159 | |
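// Returns the byte offset in pmat_ of the cth filter coefficient of output
// channel r, i.e., the inverse of the packing performed by the constructor.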
160 | int PackedDepthWiseConvMatrix::addr(int r, int c) { |
161 | int kernel_prod_aligned = (kernel_prod_ + 1) / 2 * 2; |
162 | if (c >= kernel_prod_ / 4 * 4 && |
163 | (kernel_prod_ % 4 == 1 || kernel_prod_ % 4 == 2)) { |
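    // Remainder columns in the madd_epi16_packed layout: each 128-bit lane
    // holds 8 channels at 2 bytes per channel.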
164 | int kBlock = r / 32; |
165 | int reg_idx = (r % 16) / 8 + c / 4 * 4; |
166 | |
167 | int blk_idx = kBlock * kernel_prod_aligned + reg_idx; |
168 | |
169 | int r_ = r % 8; |
170 | int c_ = c % 4; |
171 | |
172 | int in_blk_idx = (r % 32) / 16 * 16 + 2 * r_ + c_; |
173 | return blk_idx * 32 + in_blk_idx; |
174 | |
175 | } else { |
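    // Full groups of 4 columns in the vpmaddubsw layout: each 128-bit lane
    // holds 4 channels at 4 consecutive coefficients per channel.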
176 | int kBlock = r / 32; |
177 | int reg_idx = (r % 16) / 4 + c / 4 * 4; |
178 | |
179 | int blk_idx = kBlock * kernel_prod_aligned + reg_idx; |
180 | |
181 | int r_ = r % 4; |
182 | int c_ = c % 4; |
183 | |
184 | int in_blk_idx = (r % 32) / 16 * 16 + 4 * r_ + c_; |
185 | return blk_idx * 32 + in_blk_idx; |
186 | } |
187 | } |
188 | |
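// Writes the filter back in the original (OC, kernel_prod) row-major layout,
// undoing the packing.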
189 | void PackedDepthWiseConvMatrix::unpack(int8_t* unpacked_data) { |
190 | for (int r = 0; r < OC_; ++r) { |
191 | for (int c = 0; c < kernel_prod_; ++c) { |
192 | unpacked_data[r * kernel_prod_ + c] = pmat_[addr(r, c)]; |
193 | } |
194 | } |
195 | } |
196 | |
197 | PackedDepthWiseConvMatrix::~PackedDepthWiseConvMatrix() { |
198 | fbgemmAlignedFree(pmat_); |
199 | } |
200 | |
201 | } // namespace fbgemm |
202 | |