1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * Copyright 2020-2021 Arm Ltd. and affiliates |
4 | * Copyright 2020-2021 FUJITSU LIMITED |
5 | * |
6 | * Licensed under the Apache License, Version 2.0 (the "License"); |
7 | * you may not use this file except in compliance with the License. |
8 | * You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, software |
13 | * distributed under the License is distributed on an "AS IS" BASIS, |
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
15 | * See the License for the specific language governing permissions and |
16 | * limitations under the License. |
17 | *******************************************************************************/ |
18 | |
19 | #include <map> |
20 | #include <vector> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/convolution_pd.hpp" |
24 | |
25 | #include "cpu/cpu_engine.hpp" |
26 | |
27 | #include "cpu/gemm_convolution.hpp" |
28 | #include "cpu/gemm_x8s8s32x_convolution.hpp" |
29 | #include "cpu/ref_convolution.hpp" |
30 | #include "cpu/ref_convolution_int8.hpp" |
31 | #include "cpu/ref_fused_convolution.hpp" |
32 | |
33 | #if DNNL_X64 |
34 | #include "cpu/x64/gemm_bf16_convolution.hpp" |
35 | #include "cpu/x64/ip_convolution.hpp" |
36 | #include "cpu/x64/jit_avx2_1x1_convolution.hpp" |
37 | #include "cpu/x64/jit_avx2_convolution.hpp" |
38 | #include "cpu/x64/jit_avx512_common_1x1_convolution.hpp" |
39 | #include "cpu/x64/jit_avx512_common_convolution.hpp" |
40 | #include "cpu/x64/jit_avx512_core_amx_1x1_convolution.hpp" |
41 | #include "cpu/x64/jit_avx512_core_amx_convolution.hpp" |
42 | #include "cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp" |
43 | #include "cpu/x64/jit_avx512_core_bf16_convolution.hpp" |
44 | #include "cpu/x64/jit_avx512_core_f32_wino_conv_2x3.hpp" |
45 | #include "cpu/x64/jit_avx512_core_f32_wino_conv_4x3.hpp" |
46 | #include "cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.hpp" |
47 | #include "cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp" |
48 | #include "cpu/x64/jit_brdgmm_dw_conv.hpp" |
49 | #include "cpu/x64/jit_brgemm_1x1_conv.hpp" |
50 | #include "cpu/x64/jit_brgemm_conv.hpp" |
51 | #include "cpu/x64/jit_brgemm_conv_bwd.hpp" |
52 | #include "cpu/x64/jit_brgemm_conv_bwd_strided.hpp" |
53 | #include "cpu/x64/jit_brgemm_conv_bwd_w.hpp" |
54 | #include "cpu/x64/jit_sse41_1x1_convolution.hpp" |
55 | #include "cpu/x64/jit_sse41_convolution.hpp" |
56 | #include "cpu/x64/jit_uni_dw_convolution.hpp" |
57 | #include "cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp" |
58 | #include "cpu/x64/jit_uni_x8s8s32x_convolution.hpp" |
59 | using namespace dnnl::impl::cpu::x64; |
60 | #elif DNNL_AARCH64 |
61 | #include "cpu/aarch64/jit_sve_512_1x1_convolution.hpp" |
62 | #include "cpu/aarch64/jit_sve_512_convolution.hpp" |
63 | #include "cpu/aarch64/jit_sve_512_x8s8s32x_convolution.hpp" |
64 | #include "cpu/aarch64/jit_uni_dw_convolution.hpp" |
65 | #if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL |
66 | #include "cpu/aarch64/acl_gemm_convolution.hpp" |
67 | #include "cpu/aarch64/acl_indirect_gemm_convolution.hpp" |
68 | #include "cpu/aarch64/acl_winograd_convolution.hpp" |
69 | #endif |
70 | using namespace dnnl::impl::cpu::aarch64; |
71 | #endif |
72 | |
73 | namespace dnnl { |
74 | namespace impl { |
75 | namespace cpu { |
76 | |
namespace {
// Bring the data type tags (f32, bf16, f16, s8, u8, s32) and the propagation
// kind tags (forward, backward_data, backward_weights) used as table keys
// below into scope.
using namespace dnnl::impl::data_type;
using namespace dnnl::impl::prop_kind;

// clang-format off
/// Returns the table of convolution implementation lists.
///
/// The table is a function-local static, so it is built once and the same
/// object is returned on every call.
///
/// Each key has the form {prop_kind, dt, dt, dt}. get_convolution_impl_list()
/// in this file fills the three data type slots from
/// conv_prop_invariant_src_d(), conv_prop_invariant_wei_d() and
/// conv_prop_invariant_dst_d() respectively, and collapses forward_training /
/// forward_inference into `forward` before the lookup, which is why only
/// `forward` appears as a forward key here.
///
/// Each value is a vector of impl_list_item_t whose last element is nullptr.
///
/// NOTE(review): the order of entries inside a list presumably is the
/// dispatch priority (earlier entries tried first) -- the TODO about moving
/// brgemm<avx2> up points that way. Confirm against the consumer of the list
/// before reordering anything.
///
/// NOTE(review): the CPU_INSTANCE*, REG_CONV_P, REG_BWD_PK and REG_BWD_D_PK
/// macros are defined outside this file. The entries carry no separating
/// commas of their own, so the CPU_INSTANCE* macros presumably expand either
/// to an item followed by a comma or to nothing, depending on the build
/// configuration -- verify against the macro definitions.
const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
    static const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> the_map = REG_CONV_P({
        // FWD fp
        {{forward, f32, f32, f32}, {
            CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
            CPU_INSTANCE_X64(ip_convolution_fwd_t)
            CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core>)
            CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core>)
            CPU_INSTANCE_AVX512(jit_avx512_common_dw_convolution_fwd_t)
            CPU_INSTANCE_AVX512(jit_avx512_common_1x1_convolution_fwd_f32_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_f32_wino_conv_2x3_fwd_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_f32_wino_conv_4x3_fwd_t)
            CPU_INSTANCE_AVX512(jit_avx512_common_convolution_fwd_t<f32>)
            CPU_INSTANCE_AVX2(jit_avx2_dw_convolution_fwd_t)
            CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_fwd_t)
            CPU_INSTANCE_SSE41(jit_sse41_dw_convolution_fwd_t)
            CPU_INSTANCE_SSE41(jit_sse41_1x1_convolution_fwd_t)
            CPU_INSTANCE_AVX2(jit_avx2_convolution_fwd_t)
            CPU_INSTANCE_SSE41(jit_sse41_convolution_fwd_t)
            CPU_INSTANCE_AARCH64_ACL(acl_wino_convolution_fwd_t)
            CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_fwd_t)
            CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_fwd_f32_t)
            CPU_INSTANCE_AARCH64(jit_sve_512_convolution_fwd_t<f32>)
            CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t)
            CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t<f32>)
            CPU_INSTANCE(gemm_convolution_fwd_t)
            // TODO: Move-up the brgemm<avx2> after performance study
            CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t<avx2>)
            CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t<avx2>)
            CPU_INSTANCE(ref_convolution_fwd_t)
            CPU_INSTANCE(ref_fused_convolution_fwd_t)
            nullptr,
        }},
        {{forward, bf16, bf16, f32}, {
            CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
            CPU_INSTANCE_X64(ip_convolution_fwd_t)
            CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
            CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_bf16>)
            CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_bf16>)
            CPU_INSTANCE_AVX512(jit_uni_dw_convolution_fwd_t<avx512_core, bf16, f32>)
            CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_fwd_t<f32>)
            CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_fwd_t)
            CPU_INSTANCE_AVX512(gemm_bf16_convolution_fwd_t<f32>)
            CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t<avx2_vnni_2>)
            CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t<avx2_vnni_2>)
            CPU_INSTANCE(ref_convolution_fwd_t)
            nullptr,
        }},
        {{forward, bf16, bf16, bf16}, {
            CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
            CPU_INSTANCE_X64(ip_convolution_fwd_t)
            CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
            CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_bf16>)
            CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_bf16>)
            CPU_INSTANCE_AVX512(jit_uni_dw_convolution_fwd_t<avx512_core, bf16, bf16>)
            CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_fwd_t<bf16>)
            CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_fwd_t)
            CPU_INSTANCE_AVX512(gemm_bf16_convolution_fwd_t<bf16>)
            CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t<avx2_vnni_2>)
            CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t<avx2_vnni_2>)
            CPU_INSTANCE(ref_convolution_fwd_t)
            CPU_INSTANCE(ref_fused_convolution_fwd_t)
            nullptr,
        }},
        {{forward, f16, f16, f32}, {
            CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
            CPU_INSTANCE_X64(ip_convolution_fwd_t)
            CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx_fp16>)
            CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx_fp16>)
            CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_fp16>)
            CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_fp16>)
            CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t<avx2_vnni_2>)
            CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t<avx2_vnni_2>)
            CPU_INSTANCE(ref_convolution_fwd_t)
            nullptr,
        }},
        {{forward, f16, f16, f16}, {
            CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
            CPU_INSTANCE_X64(ip_convolution_fwd_t)
            CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx_fp16>)
            CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx_fp16>)
            CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_fp16>)
            CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_fp16>)
            CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t<avx2_vnni_2>)
            CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t<avx2_vnni_2>)
            CPU_INSTANCE_AARCH64_ACL(acl_wino_convolution_fwd_t)
            CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t)
            CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t<f16>)
            CPU_INSTANCE(ref_convolution_fwd_t)
            CPU_INSTANCE(ref_fused_convolution_fwd_t)
            nullptr,
        }},
        // BWD_D fp
        // NOTE(review): REG_BWD_D_PK is defined elsewhere; presumably it can
        // drop these lists in builds without backward-data support -- confirm.
        {{backward_data, f32, f32, f32}, REG_BWD_D_PK({
            CPU_INSTANCE_X64(ip_convolution_bwd_data_t)
            CPU_INSTANCE_AMX(brgemm_convolution_bwd_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t<avx512_core>)
            CPU_INSTANCE_AVX512(jit_avx512_common_dw_convolution_bwd_data_t)
            CPU_INSTANCE_AVX512(jit_avx512_common_1x1_convolution_bwd_data_f32_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_f32_wino_conv_4x3_bwd_data_t)
            CPU_INSTANCE_AVX512(jit_avx512_common_convolution_bwd_data_t<f32>)
            CPU_INSTANCE_AVX2(jit_avx2_dw_convolution_bwd_data_t)
            CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_bwd_data_t)
            CPU_INSTANCE_SSE41(jit_sse41_dw_convolution_bwd_data_t)
            CPU_INSTANCE_AVX2(jit_avx2_convolution_bwd_data_t)
            CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_bwd_data_t)
            CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_bwd_data_f32_t)
            CPU_INSTANCE_AARCH64(jit_sve_512_convolution_bwd_data_t<f32>)
            CPU_INSTANCE(gemm_convolution_bwd_data_t)
            CPU_INSTANCE(ref_convolution_bwd_data_t)
            nullptr,
        })},
        {{backward_data, f32, bf16, bf16}, REG_BWD_D_PK({
            CPU_INSTANCE_X64(ip_convolution_bwd_data_t)
            CPU_INSTANCE_AMX(brgemm_convolution_bwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_data_t<f32, bf16, bf16>)
            CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t<avx512_core_bf16>)
            CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_data_t<avx512_core, bf16, f32>)
            CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_data_t<f32>)
            CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_data_t)
            CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_data_t<f32>)
            CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t<avx2_vnni_2>)
            CPU_INSTANCE(ref_convolution_bwd_data_t)
            nullptr,
        })},
        {{backward_data, bf16, bf16, bf16}, REG_BWD_D_PK({
            CPU_INSTANCE_X64(ip_convolution_bwd_data_t)
            CPU_INSTANCE_AMX(brgemm_convolution_bwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_data_t<bf16, bf16, bf16>)
            CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t<avx512_core_bf16>)
            CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_data_t<avx512_core, bf16, bf16>)
            CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_data_t<bf16>)
            CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_data_t)
            CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_data_t<bf16>)
            CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t<avx2_vnni_2>)
            CPU_INSTANCE(ref_convolution_bwd_data_t)
            nullptr,
        })},
        {{backward_data, f32, f16, f16}, REG_BWD_D_PK({
            CPU_INSTANCE_X64(ip_convolution_bwd_data_t)
            CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t<avx512_core_amx_fp16>)
            CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t<avx512_core_fp16>)
            CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t<avx2_vnni_2>)
            CPU_INSTANCE(ref_convolution_bwd_data_t)
            nullptr,
        })},
        {{backward_data, f16, f16, f16}, REG_BWD_D_PK({
            CPU_INSTANCE_X64(ip_convolution_bwd_data_t)
            CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t<avx512_core_amx_fp16>)
            CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t<avx512_core_fp16>)
            CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t<avx2_vnni_2>)
            CPU_INSTANCE(ref_convolution_bwd_data_t)
            nullptr,
        })},
        // BWD_W fp
        // NOTE(review): REG_BWD_PK is defined elsewhere; presumably it can
        // drop these lists in builds without backward support -- confirm.
        {{backward_weights, f32, f32, f32}, REG_BWD_PK({
            CPU_INSTANCE_X64(ip_convolution_bwd_weights_t)
            CPU_INSTANCE_AVX512(jit_avx512_common_dw_convolution_bwd_weights_t)
            CPU_INSTANCE_AVX512(jit_avx512_common_1x1_convolution_bwd_weights_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t)
            CPU_INSTANCE_AVX512(jit_avx512_common_convolution_bwd_weights_t<f32>)
            CPU_INSTANCE_AVX2(jit_avx2_dw_convolution_bwd_weights_t)
            CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_bwd_weights_t)
            CPU_INSTANCE_SSE41(jit_sse41_dw_convolution_bwd_weights_t)
            CPU_INSTANCE_AVX2(jit_avx2_convolution_bwd_weights_t)
            CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_bwd_weights_t)
            CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_bwd_weights_t)
            CPU_INSTANCE_AARCH64(jit_sve_512_convolution_bwd_weights_t<f32>)
            CPU_INSTANCE(gemm_convolution_bwd_weights_t)
            CPU_INSTANCE(ref_convolution_bwd_weights_t)
            nullptr,
        })},
        {{backward_weights, bf16, f32, bf16}, REG_BWD_PK({
            CPU_INSTANCE_X64(ip_convolution_bwd_weights_t)
            CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_weights_t<avx512_core, bf16, f32>)
            CPU_INSTANCE_AMX(brgemm_convolution_bwd_weights_t)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_weights_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_weights_t<f32>)
            CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_weights_t)
            CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_weights_t<f32>)
            CPU_INSTANCE(ref_convolution_bwd_weights_t)
            nullptr,
        })},
        {{backward_weights, bf16, bf16, bf16}, REG_BWD_PK({
            CPU_INSTANCE_X64(ip_convolution_bwd_weights_t)
            CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_weights_t<avx512_core, bf16, bf16>)
            CPU_INSTANCE_AMX(brgemm_convolution_bwd_weights_t)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_weights_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_weights_t<bf16>)
            CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_weights_t)
            CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_weights_t<bf16>)
            CPU_INSTANCE(ref_convolution_bwd_weights_t)
            nullptr,
        })},
        {{backward_weights, f16, f32, f16}, REG_BWD_PK({
            CPU_INSTANCE_X64(ip_convolution_bwd_weights_t)
            CPU_INSTANCE_AMX(brgemm_convolution_bwd_weights_t)
            CPU_INSTANCE(ref_convolution_bwd_weights_t)
            nullptr,
        })},
        {{backward_weights, f16, f16, f16}, REG_BWD_PK({
            CPU_INSTANCE_X64(ip_convolution_bwd_weights_t)
            CPU_INSTANCE_AMX(brgemm_convolution_bwd_weights_t)
            CPU_INSTANCE(ref_convolution_bwd_weights_t)
            nullptr,
        })},
        // FWD int8 (src:s8)
        {{forward, s8, s8, f32}, {
            CPU_INSTANCE_X64(ip_convolution_fwd_t)
            CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
            CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
            CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<s8, f32>)
            CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE(ref_convolution_int8_fwd_t)
            CPU_INSTANCE(ref_fused_convolution_fwd_t)
            nullptr,
        }},
        // NOTE(review): unlike the other src:s8 lists, this one starts with
        // brdgmm_dw_convolution_fwd_t and has neither ip_convolution_fwd_t nor
        // the AMX brgemm entries -- confirm whether that is intentional.
        {{forward, s8, s8, bf16}, {
            CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
            CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE(ref_convolution_int8_fwd_t)
            nullptr,
        }},
        {{forward, s8, s8, s32}, {
            CPU_INSTANCE_X64(ip_convolution_fwd_t)
            CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
            CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
            CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<s8, s32>)
            CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE(ref_convolution_int8_fwd_t)
            CPU_INSTANCE(ref_fused_convolution_fwd_t)
            nullptr,
        }},
        {{forward, s8, s8, s8}, {
            CPU_INSTANCE_X64(ip_convolution_fwd_t)
            CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
            CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
            CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<s8, s8>)
            CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t<s8, s8, s8, s32>)
            CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE(ref_convolution_int8_fwd_t)
            CPU_INSTANCE(ref_fused_convolution_fwd_t)
            nullptr,
        }},
        {{forward, s8, s8, u8}, {
            CPU_INSTANCE_X64(ip_convolution_fwd_t)
            CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
            CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
            CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<s8, u8>)
            CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE(ref_convolution_int8_fwd_t)
            CPU_INSTANCE(ref_fused_convolution_fwd_t)
            nullptr,
        }},
        // FWD int8 (src:u8)
        {{forward, u8, s8, f32}, {
            CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
            CPU_INSTANCE_X64(ip_convolution_fwd_t)
            CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
            CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
            CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<u8, f32>)
            CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE(ref_convolution_int8_fwd_t)
            nullptr,
        }},
        {{forward, u8, s8, bf16}, {
            CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
            CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
            CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE(ref_convolution_int8_fwd_t)
            nullptr,
        }},
        {{forward, u8, s8, s32}, {
            CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
            CPU_INSTANCE_X64(ip_convolution_fwd_t)
            CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
            CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
            CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<u8, s32>)
            CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE(ref_convolution_int8_fwd_t)
            nullptr,
        }},
        {{forward, u8, s8, s8}, {
            CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
            CPU_INSTANCE_X64(ip_convolution_fwd_t)
            CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
            CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
            CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<u8, s8>)
            CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE(ref_convolution_int8_fwd_t)
            CPU_INSTANCE(ref_fused_convolution_fwd_t)
            nullptr,
        }},
        {{forward, u8, s8, u8}, {
            CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
            CPU_INSTANCE_X64(ip_convolution_fwd_t)
            CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
            CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
            CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
            CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
            CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
            CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
            CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<u8, u8>)
            CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
            CPU_INSTANCE(ref_convolution_int8_fwd_t)
            CPU_INSTANCE(ref_fused_convolution_fwd_t)
            nullptr,
        }},
        // BWD int8 (diff_dst:u8)
        {{backward_data, f32, s8, u8}, REG_BWD_D_PK({
            CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
            CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
            nullptr,
        })},
        {{backward_data, bf16, s8, u8}, REG_BWD_D_PK({
            CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
            CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
            nullptr,
        })},
        {{backward_data, s32, s8, u8}, REG_BWD_D_PK({
            CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
            CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
            nullptr,
        })},
        {{backward_data, s8, s8, u8}, REG_BWD_D_PK({
            CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
            CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
            nullptr,
        })},
        {{backward_data, u8, s8, u8}, REG_BWD_D_PK({
            CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
            CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
            nullptr,
        })},
        // BWD int8 (diff_dst:s8)
        {{backward_data, f32, s8, s8}, REG_BWD_D_PK({
            CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
            CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
            nullptr,
        })},
        {{backward_data, bf16, s8, s8}, REG_BWD_D_PK({
            CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
            CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
            nullptr,
        })},
        {{backward_data, s32, s8, s8}, REG_BWD_D_PK({
            CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
            CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
            nullptr,
        })},
        {{backward_data, s8, s8, s8}, REG_BWD_D_PK({
            CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
            CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
            nullptr,
        })},
        {{backward_data, u8, s8, s8}, REG_BWD_D_PK({
            CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
            CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
            nullptr,
        })},
    });
    return the_map;
}
// clang-format on
} // namespace
546 | |
547 | const impl_list_item_t *get_convolution_impl_list( |
548 | const convolution_desc_t *desc) { |
549 | static const impl_list_item_t empty_list[] = {nullptr}; |
550 | |
551 | const bool is_fwd = utils::one_of( |
552 | desc->prop_kind, forward_training, forward_inference); |
553 | prop_kind_t prop_kind = is_fwd ? forward : desc->prop_kind; |
554 | |
555 | pk_dt_impl_key_t key { |
556 | prop_kind, |
557 | conv_prop_invariant_src_d(desc)->data_type, |
558 | conv_prop_invariant_wei_d(desc)->data_type, |
559 | conv_prop_invariant_dst_d(desc)->data_type, |
560 | }; |
561 | |
562 | const auto impl_list_it = impl_list_map().find(key); |
563 | return impl_list_it != impl_list_map().cend() ? impl_list_it->second.data() |
564 | : empty_list; |
565 | } |
566 | |
567 | } // namespace cpu |
568 | } // namespace impl |
569 | } // namespace dnnl |
570 | |