1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3* Copyright 2020-2021 Arm Ltd. and affiliates
4* Copyright 2020-2021 FUJITSU LIMITED
5*
6* Licensed under the Apache License, Version 2.0 (the "License");
7* you may not use this file except in compliance with the License.
8* You may obtain a copy of the License at
9*
10* http://www.apache.org/licenses/LICENSE-2.0
11*
12* Unless required by applicable law or agreed to in writing, software
13* distributed under the License is distributed on an "AS IS" BASIS,
14* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15* See the License for the specific language governing permissions and
16* limitations under the License.
17*******************************************************************************/
18
19#include <map>
20#include <vector>
21
22#include "common/c_types_map.hpp"
23#include "common/convolution_pd.hpp"
24
25#include "cpu/cpu_engine.hpp"
26
27#include "cpu/gemm_convolution.hpp"
28#include "cpu/gemm_x8s8s32x_convolution.hpp"
29#include "cpu/ref_convolution.hpp"
30#include "cpu/ref_convolution_int8.hpp"
31#include "cpu/ref_fused_convolution.hpp"
32
33#if DNNL_X64
34#include "cpu/x64/gemm_bf16_convolution.hpp"
35#include "cpu/x64/ip_convolution.hpp"
36#include "cpu/x64/jit_avx2_1x1_convolution.hpp"
37#include "cpu/x64/jit_avx2_convolution.hpp"
38#include "cpu/x64/jit_avx512_common_1x1_convolution.hpp"
39#include "cpu/x64/jit_avx512_common_convolution.hpp"
40#include "cpu/x64/jit_avx512_core_amx_1x1_convolution.hpp"
41#include "cpu/x64/jit_avx512_core_amx_convolution.hpp"
42#include "cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp"
43#include "cpu/x64/jit_avx512_core_bf16_convolution.hpp"
44#include "cpu/x64/jit_avx512_core_f32_wino_conv_2x3.hpp"
45#include "cpu/x64/jit_avx512_core_f32_wino_conv_4x3.hpp"
46#include "cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
47#include "cpu/x64/jit_avx512_core_x8s8s32x_convolution.hpp"
48#include "cpu/x64/jit_brdgmm_dw_conv.hpp"
49#include "cpu/x64/jit_brgemm_1x1_conv.hpp"
50#include "cpu/x64/jit_brgemm_conv.hpp"
51#include "cpu/x64/jit_brgemm_conv_bwd.hpp"
52#include "cpu/x64/jit_brgemm_conv_bwd_strided.hpp"
53#include "cpu/x64/jit_brgemm_conv_bwd_w.hpp"
54#include "cpu/x64/jit_sse41_1x1_convolution.hpp"
55#include "cpu/x64/jit_sse41_convolution.hpp"
56#include "cpu/x64/jit_uni_dw_convolution.hpp"
57#include "cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp"
58#include "cpu/x64/jit_uni_x8s8s32x_convolution.hpp"
59using namespace dnnl::impl::cpu::x64;
60#elif DNNL_AARCH64
61#include "cpu/aarch64/jit_sve_512_1x1_convolution.hpp"
62#include "cpu/aarch64/jit_sve_512_convolution.hpp"
63#include "cpu/aarch64/jit_sve_512_x8s8s32x_convolution.hpp"
64#include "cpu/aarch64/jit_uni_dw_convolution.hpp"
65#if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL
66#include "cpu/aarch64/acl_gemm_convolution.hpp"
67#include "cpu/aarch64/acl_indirect_gemm_convolution.hpp"
68#include "cpu/aarch64/acl_winograd_convolution.hpp"
69#endif
70using namespace dnnl::impl::cpu::aarch64;
71#endif
72
73namespace dnnl {
74namespace impl {
75namespace cpu {
76
77namespace {
78using namespace dnnl::impl::data_type;
79using namespace dnnl::impl::prop_kind;
80
81// clang-format off
82const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
83 static const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> the_map = REG_CONV_P({
84 // FWD fp
85 {{forward, f32, f32, f32}, {
86 CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
87 CPU_INSTANCE_X64(ip_convolution_fwd_t)
88 CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
89 CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
90 CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core>)
91 CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core>)
92 CPU_INSTANCE_AVX512(jit_avx512_common_dw_convolution_fwd_t)
93 CPU_INSTANCE_AVX512(jit_avx512_common_1x1_convolution_fwd_f32_t)
94 CPU_INSTANCE_AVX512(jit_avx512_core_f32_wino_conv_2x3_fwd_t)
95 CPU_INSTANCE_AVX512(jit_avx512_core_f32_wino_conv_4x3_fwd_t)
96 CPU_INSTANCE_AVX512(jit_avx512_common_convolution_fwd_t<f32>)
97 CPU_INSTANCE_AVX2(jit_avx2_dw_convolution_fwd_t)
98 CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_fwd_t)
99 CPU_INSTANCE_SSE41(jit_sse41_dw_convolution_fwd_t)
100 CPU_INSTANCE_SSE41(jit_sse41_1x1_convolution_fwd_t)
101 CPU_INSTANCE_AVX2(jit_avx2_convolution_fwd_t)
102 CPU_INSTANCE_SSE41(jit_sse41_convolution_fwd_t)
103 CPU_INSTANCE_AARCH64_ACL(acl_wino_convolution_fwd_t)
104 CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_fwd_t)
105 CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_fwd_f32_t)
106 CPU_INSTANCE_AARCH64(jit_sve_512_convolution_fwd_t<f32>)
107 CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t)
108 CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t<f32>)
109 CPU_INSTANCE(gemm_convolution_fwd_t)
110 // TODO: Move-up the brgemm<avx2> after performance study
111 CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t<avx2>)
112 CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t<avx2>)
113 CPU_INSTANCE(ref_convolution_fwd_t)
114 CPU_INSTANCE(ref_fused_convolution_fwd_t)
115 nullptr,
116 }},
117 {{forward, bf16, bf16, f32}, {
118 CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
119 CPU_INSTANCE_X64(ip_convolution_fwd_t)
120 CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
121 CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
122 CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
123 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
124 CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_bf16>)
125 CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_bf16>)
126 CPU_INSTANCE_AVX512(jit_uni_dw_convolution_fwd_t<avx512_core, bf16, f32>)
127 CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_fwd_t<f32>)
128 CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_fwd_t)
129 CPU_INSTANCE_AVX512(gemm_bf16_convolution_fwd_t<f32>)
130 CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t<avx2_vnni_2>)
131 CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t<avx2_vnni_2>)
132 CPU_INSTANCE(ref_convolution_fwd_t)
133 nullptr,
134 }},
135 {{forward, bf16, bf16, bf16}, {
136 CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
137 CPU_INSTANCE_X64(ip_convolution_fwd_t)
138 CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
139 CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
140 CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
141 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
142 CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_bf16>)
143 CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_bf16>)
144 CPU_INSTANCE_AVX512(jit_uni_dw_convolution_fwd_t<avx512_core, bf16, bf16>)
145 CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_fwd_t<bf16>)
146 CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_fwd_t)
147 CPU_INSTANCE_AVX512(gemm_bf16_convolution_fwd_t<bf16>)
148 CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t<avx2_vnni_2>)
149 CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t<avx2_vnni_2>)
150 CPU_INSTANCE(ref_convolution_fwd_t)
151 CPU_INSTANCE(ref_fused_convolution_fwd_t)
152 nullptr,
153 }},
154 {{forward, f16, f16, f32}, {
155 CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
156 CPU_INSTANCE_X64(ip_convolution_fwd_t)
157 CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx_fp16>)
158 CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx_fp16>)
159 CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_fp16>)
160 CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_fp16>)
161 CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t<avx2_vnni_2>)
162 CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t<avx2_vnni_2>)
163 CPU_INSTANCE(ref_convolution_fwd_t)
164 nullptr,
165 }},
166 {{forward, f16, f16, f16}, {
167 CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
168 CPU_INSTANCE_X64(ip_convolution_fwd_t)
169 CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx_fp16>)
170 CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx_fp16>)
171 CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_fp16>)
172 CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_fp16>)
173 CPU_INSTANCE_AVX2(brgemm_1x1_convolution_fwd_t<avx2_vnni_2>)
174 CPU_INSTANCE_AVX2(brgemm_convolution_fwd_t<avx2_vnni_2>)
175 CPU_INSTANCE_AARCH64_ACL(acl_wino_convolution_fwd_t)
176 CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t)
177 CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t<f16>)
178 CPU_INSTANCE(ref_convolution_fwd_t)
179 CPU_INSTANCE(ref_fused_convolution_fwd_t)
180 nullptr,
181 }},
182 // BWD_D fp
183 {{backward_data, f32, f32, f32}, REG_BWD_D_PK({
184 CPU_INSTANCE_X64(ip_convolution_bwd_data_t)
185 CPU_INSTANCE_AMX(brgemm_convolution_bwd_t<avx512_core_amx>)
186 CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t<avx512_core>)
187 CPU_INSTANCE_AVX512(jit_avx512_common_dw_convolution_bwd_data_t)
188 CPU_INSTANCE_AVX512(jit_avx512_common_1x1_convolution_bwd_data_f32_t)
189 CPU_INSTANCE_AVX512(jit_avx512_core_f32_wino_conv_4x3_bwd_data_t)
190 CPU_INSTANCE_AVX512(jit_avx512_common_convolution_bwd_data_t<f32>)
191 CPU_INSTANCE_AVX2(jit_avx2_dw_convolution_bwd_data_t)
192 CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_bwd_data_t)
193 CPU_INSTANCE_SSE41(jit_sse41_dw_convolution_bwd_data_t)
194 CPU_INSTANCE_AVX2(jit_avx2_convolution_bwd_data_t)
195 CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_bwd_data_t)
196 CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_bwd_data_f32_t)
197 CPU_INSTANCE_AARCH64(jit_sve_512_convolution_bwd_data_t<f32>)
198 CPU_INSTANCE(gemm_convolution_bwd_data_t)
199 CPU_INSTANCE(ref_convolution_bwd_data_t)
200 nullptr,
201 })},
202 {{backward_data, f32, bf16, bf16}, REG_BWD_D_PK({
203 CPU_INSTANCE_X64(ip_convolution_bwd_data_t)
204 CPU_INSTANCE_AMX(brgemm_convolution_bwd_t<avx512_core_amx>)
205 CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t<avx512_core_amx>)
206 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_data_t<f32, bf16, bf16>)
207 CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t<avx512_core_bf16>)
208 CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_data_t<avx512_core, bf16, f32>)
209 CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_data_t<f32>)
210 CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_data_t)
211 CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_data_t<f32>)
212 CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t<avx2_vnni_2>)
213 CPU_INSTANCE(ref_convolution_bwd_data_t)
214 nullptr,
215 })},
216 {{backward_data, bf16, bf16, bf16}, REG_BWD_D_PK({
217 CPU_INSTANCE_X64(ip_convolution_bwd_data_t)
218 CPU_INSTANCE_AMX(brgemm_convolution_bwd_t<avx512_core_amx>)
219 CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t<avx512_core_amx>)
220 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_data_t<bf16, bf16, bf16>)
221 CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t<avx512_core_bf16>)
222 CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_data_t<avx512_core, bf16, bf16>)
223 CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_data_t<bf16>)
224 CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_data_t)
225 CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_data_t<bf16>)
226 CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t<avx2_vnni_2>)
227 CPU_INSTANCE(ref_convolution_bwd_data_t)
228 nullptr,
229 })},
230 {{backward_data, f32, f16, f16}, REG_BWD_D_PK({
231 CPU_INSTANCE_X64(ip_convolution_bwd_data_t)
232 CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t<avx512_core_amx_fp16>)
233 CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t<avx512_core_fp16>)
234 CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t<avx2_vnni_2>)
235 CPU_INSTANCE(ref_convolution_bwd_data_t)
236 nullptr,
237 })},
238 {{backward_data, f16, f16, f16}, REG_BWD_D_PK({
239 CPU_INSTANCE_X64(ip_convolution_bwd_data_t)
240 CPU_INSTANCE_AMX(brgemm_convolution_bwd_strided_t<avx512_core_amx_fp16>)
241 CPU_INSTANCE_AVX512(brgemm_convolution_bwd_t<avx512_core_fp16>)
242 CPU_INSTANCE_AVX2(brgemm_convolution_bwd_t<avx2_vnni_2>)
243 CPU_INSTANCE(ref_convolution_bwd_data_t)
244 nullptr,
245 })},
246 // BWD_W fp
247 {{backward_weights, f32, f32, f32}, REG_BWD_PK({
248 CPU_INSTANCE_X64(ip_convolution_bwd_weights_t)
249 CPU_INSTANCE_AVX512(jit_avx512_common_dw_convolution_bwd_weights_t)
250 CPU_INSTANCE_AVX512(jit_avx512_common_1x1_convolution_bwd_weights_t)
251 CPU_INSTANCE_AVX512(jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t)
252 CPU_INSTANCE_AVX512(jit_avx512_common_convolution_bwd_weights_t<f32>)
253 CPU_INSTANCE_AVX2(jit_avx2_dw_convolution_bwd_weights_t)
254 CPU_INSTANCE_AVX2(jit_avx2_1x1_convolution_bwd_weights_t)
255 CPU_INSTANCE_SSE41(jit_sse41_dw_convolution_bwd_weights_t)
256 CPU_INSTANCE_AVX2(jit_avx2_convolution_bwd_weights_t)
257 CPU_INSTANCE_AARCH64(jit_sve_512_dw_convolution_bwd_weights_t)
258 CPU_INSTANCE_AARCH64(jit_sve_512_1x1_convolution_bwd_weights_t)
259 CPU_INSTANCE_AARCH64(jit_sve_512_convolution_bwd_weights_t<f32>)
260 CPU_INSTANCE(gemm_convolution_bwd_weights_t)
261 CPU_INSTANCE(ref_convolution_bwd_weights_t)
262 nullptr,
263 })},
264 {{backward_weights, bf16, f32, bf16}, REG_BWD_PK({
265 CPU_INSTANCE_X64(ip_convolution_bwd_weights_t)
266 CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_weights_t<avx512_core, bf16, f32>)
267 CPU_INSTANCE_AMX(brgemm_convolution_bwd_weights_t)
268 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_weights_t)
269 CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_weights_t<f32>)
270 CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_weights_t)
271 CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_weights_t<f32>)
272 CPU_INSTANCE(ref_convolution_bwd_weights_t)
273 nullptr,
274 })},
275 {{backward_weights, bf16, bf16, bf16}, REG_BWD_PK({
276 CPU_INSTANCE_X64(ip_convolution_bwd_weights_t)
277 CPU_INSTANCE_AVX512(jit_uni_dw_convolution_bwd_weights_t<avx512_core, bf16, bf16>)
278 CPU_INSTANCE_AMX(brgemm_convolution_bwd_weights_t)
279 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_bwd_weights_t)
280 CPU_INSTANCE_AVX512(jit_avx512_core_bf16_1x1_convolution_bwd_weights_t<bf16>)
281 CPU_INSTANCE_AVX512(jit_avx512_core_bf16_convolution_bwd_weights_t)
282 CPU_INSTANCE_AVX512(gemm_bf16_convolution_bwd_weights_t<bf16>)
283 CPU_INSTANCE(ref_convolution_bwd_weights_t)
284 nullptr,
285 })},
286 {{backward_weights, f16, f32, f16}, REG_BWD_PK({
287 CPU_INSTANCE_X64(ip_convolution_bwd_weights_t)
288 CPU_INSTANCE_AMX(brgemm_convolution_bwd_weights_t)
289 CPU_INSTANCE(ref_convolution_bwd_weights_t)
290 nullptr,
291 })},
292 {{backward_weights, f16, f16, f16}, REG_BWD_PK({
293 CPU_INSTANCE_X64(ip_convolution_bwd_weights_t)
294 CPU_INSTANCE_AMX(brgemm_convolution_bwd_weights_t)
295 CPU_INSTANCE(ref_convolution_bwd_weights_t)
296 nullptr,
297 })},
298 // FWD int8 (src:s8)
299 {{forward, s8, s8, f32}, {
300 CPU_INSTANCE_X64(ip_convolution_fwd_t)
301 CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
302 CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
303 CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
304 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
305 CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
306 CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
307 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
308 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
309 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
310 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
311 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
312 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
313 CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<s8, f32>)
314 CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
315 CPU_INSTANCE(ref_convolution_int8_fwd_t)
316 CPU_INSTANCE(ref_fused_convolution_fwd_t)
317 nullptr,
318 }},
319 {{forward, s8, s8, bf16}, {
320 CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
321 CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
322 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
323 CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
324 CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
325 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
326 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
327 CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
328 CPU_INSTANCE(ref_convolution_int8_fwd_t)
329 nullptr,
330 }},
331 {{forward, s8, s8, s32}, {
332 CPU_INSTANCE_X64(ip_convolution_fwd_t)
333 CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
334 CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
335 CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
336 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
337 CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
338 CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
339 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
340 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
341 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
342 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
343 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
344 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
345 CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<s8, s32>)
346 CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
347 CPU_INSTANCE(ref_convolution_int8_fwd_t)
348 CPU_INSTANCE(ref_fused_convolution_fwd_t)
349 nullptr,
350 }},
351 {{forward, s8, s8, s8}, {
352 CPU_INSTANCE_X64(ip_convolution_fwd_t)
353 CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
354 CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
355 CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
356 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
357 CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
358 CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
359 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
360 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
361 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
362 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
363 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
364 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
365 CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<s8, s8>)
366 CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t<s8, s8, s8, s32>)
367 CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
368 CPU_INSTANCE(ref_convolution_int8_fwd_t)
369 CPU_INSTANCE(ref_fused_convolution_fwd_t)
370 nullptr,
371 }},
372 {{forward, s8, s8, u8}, {
373 CPU_INSTANCE_X64(ip_convolution_fwd_t)
374 CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
375 CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
376 CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
377 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
378 CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
379 CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
380 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
381 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
382 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
383 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
384 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
385 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
386 CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<s8, u8>)
387 CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
388 CPU_INSTANCE(ref_convolution_int8_fwd_t)
389 CPU_INSTANCE(ref_fused_convolution_fwd_t)
390 nullptr,
391 }},
392 // FWD int8 (src:u8)
393 {{forward, u8, s8, f32}, {
394 CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
395 CPU_INSTANCE_X64(ip_convolution_fwd_t)
396 CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
397 CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
398 CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
399 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
400 CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
401 CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
402 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
403 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
404 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
405 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
406 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
407 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
408 CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<u8, f32>)
409 CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
410 CPU_INSTANCE(ref_convolution_int8_fwd_t)
411 nullptr,
412 }},
413 {{forward, u8, s8, bf16}, {
414 CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
415 CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
416 CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
417 CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
418 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
419 CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
420 CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
421 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
422 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
423 CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
424 CPU_INSTANCE(ref_convolution_int8_fwd_t)
425 nullptr,
426 }},
427 {{forward, u8, s8, s32}, {
428 CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
429 CPU_INSTANCE_X64(ip_convolution_fwd_t)
430 CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
431 CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
432 CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
433 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
434 CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
435 CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
436 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
437 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
438 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
439 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
440 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
441 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
442 CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<u8, s32>)
443 CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
444 CPU_INSTANCE(ref_convolution_int8_fwd_t)
445 nullptr,
446 }},
447 {{forward, u8, s8, s8}, {
448 CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
449 CPU_INSTANCE_X64(ip_convolution_fwd_t)
450 CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
451 CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
452 CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
453 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
454 CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
455 CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
456 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
457 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
458 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
459 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
460 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
461 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
462 CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<u8, s8>)
463 CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
464 CPU_INSTANCE(ref_convolution_int8_fwd_t)
465 CPU_INSTANCE(ref_fused_convolution_fwd_t)
466 nullptr,
467 }},
468 {{forward, u8, s8, u8}, {
469 CPU_INSTANCE_AVX512(brdgmm_dw_convolution_fwd_t)
470 CPU_INSTANCE_X64(ip_convolution_fwd_t)
471 CPU_INSTANCE_AMX(brgemm_1x1_convolution_fwd_t<avx512_core_amx>)
472 CPU_INSTANCE_AMX(brgemm_convolution_fwd_t<avx512_core_amx>)
473 CPU_INSTANCE_AMX(jit_avx512_core_amx_1x1_convolution_fwd_t)
474 CPU_INSTANCE_AMX(jit_avx512_core_amx_convolution_fwd_t)
475 CPU_INSTANCE_AVX512(brgemm_1x1_convolution_fwd_t<avx512_core_vnni>)
476 CPU_INSTANCE_AVX512(brgemm_convolution_fwd_t<avx512_core_vnni>)
477 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t)
478 CPU_INSTANCE_AVX512(jit_avx512_core_x8s8s32x_convolution_fwd_t)
479 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_1x1_convolution_fwd_t<avx2>)
480 CPU_INSTANCE_AVX2(jit_uni_x8s8s32x_convolution_fwd_t<avx2>)
481 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_1x1_convolution_fwd_t<sse41>)
482 CPU_INSTANCE_SSE41(jit_uni_x8s8s32x_convolution_fwd_t<sse41>)
483 CPU_INSTANCE_AARCH64(jit_sve_512_x8s8s32x_convolution_fwd_t<u8, u8>)
484 CPU_INSTANCE(gemm_x8s8s32x_convolution_fwd_t)
485 CPU_INSTANCE(ref_convolution_int8_fwd_t)
486 CPU_INSTANCE(ref_fused_convolution_fwd_t)
487 nullptr,
488 }},
489 // BWD int8 (diff_dst:u8)
490 {{backward_data, f32, s8, u8}, REG_BWD_D_PK({
491 CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
492 CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
493 nullptr,
494 })},
495 {{backward_data, bf16, s8, u8}, REG_BWD_D_PK({
496 CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
497 CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
498 nullptr,
499 })},
500 {{backward_data, s32, s8, u8}, REG_BWD_D_PK({
501 CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
502 CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
503 nullptr,
504 })},
505 {{backward_data, s8, s8, u8}, REG_BWD_D_PK({
506 CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
507 CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
508 nullptr,
509 })},
510 {{backward_data, u8, s8, u8}, REG_BWD_D_PK({
511 CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
512 CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
513 nullptr,
514 })},
515 // BWD int8 (diff_dst:s8)
516 {{backward_data, f32, s8, s8}, REG_BWD_D_PK({
517 CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
518 CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
519 nullptr,
520 })},
521 {{backward_data, bf16, s8, s8}, REG_BWD_D_PK({
522 CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
523 CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
524 nullptr,
525 })},
526 {{backward_data, s32, s8, s8}, REG_BWD_D_PK({
527 CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
528 CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
529 nullptr,
530 })},
531 {{backward_data, s8, s8, s8}, REG_BWD_D_PK({
532 CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
533 CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
534 nullptr,
535 })},
536 {{backward_data, u8, s8, s8}, REG_BWD_D_PK({
537 CPU_INSTANCE(gemm_x8s8s32x_convolution_bwd_data_t)
538 CPU_INSTANCE(ref_convolution_int8_bwd_data_t)
539 nullptr,
540 })},
541 });
542 return the_map;
543}
544// clang-format on
545} // namespace
546
547const impl_list_item_t *get_convolution_impl_list(
548 const convolution_desc_t *desc) {
549 static const impl_list_item_t empty_list[] = {nullptr};
550
551 const bool is_fwd = utils::one_of(
552 desc->prop_kind, forward_training, forward_inference);
553 prop_kind_t prop_kind = is_fwd ? forward : desc->prop_kind;
554
555 pk_dt_impl_key_t key {
556 prop_kind,
557 conv_prop_invariant_src_d(desc)->data_type,
558 conv_prop_invariant_wei_d(desc)->data_type,
559 conv_prop_invariant_dst_d(desc)->data_type,
560 };
561
562 const auto impl_list_it = impl_list_map().find(key);
563 return impl_list_it != impl_list_map().cend() ? impl_list_it->second.data()
564 : empty_list;
565}
566
567} // namespace cpu
568} // namespace impl
569} // namespace dnnl
570