1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/cpu_engine.hpp" |
18 | |
19 | #include "cpu/gemm_inner_product.hpp" |
20 | #include "cpu/gemm_x8s8s32x_inner_product.hpp" |
21 | #include "cpu/ref_inner_product.hpp" |
22 | #include "cpu/ref_inner_product_int8.hpp" |
23 | |
24 | #if DNNL_X64 |
25 | #include "cpu/x64/gemm_bf16_inner_product.hpp" |
26 | #include "cpu/x64/jit_brgemm_inner_product.hpp" |
27 | using namespace dnnl::impl::cpu::x64; |
28 | #endif |
29 | |
30 | #if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL |
31 | #include "cpu/aarch64/acl_inner_product.hpp" |
32 | using namespace dnnl::impl::cpu::aarch64; |
33 | #endif |
34 | |
35 | namespace dnnl { |
36 | namespace impl { |
37 | namespace cpu { |
38 | |
namespace {
// These using-directives bring the data type names (f32, bf16, f16, s8, u8,
// s32) and prop_kind names (forward, backward_data, backward_weights, ...)
// into scope for the table below. Because an unnamed namespace is implicitly
// nominated by its enclosing namespace (and using-directives are transitive
// for unqualified lookup), these names are also reachable from
// get_inner_product_impl_list() after this namespace closes.
using namespace dnnl::impl::data_type;
using namespace dnnl::impl::prop_kind;

// Registry of CPU inner product implementations.
//
// Each entry maps a key {prop_kind, src dt, weights dt, dst dt} to a list of
// implementation candidates terminated by nullptr. The key is built by
// get_inner_product_impl_list() from these descriptors:
//   - forward (training or inference): {forward, src, weights, dst}
//   - backward_data:                   {backward_data, diff_src, weights, diff_dst}
//   - backward_weights:                {backward_weights, src, diff_weights, diff_dst}
//
// Within each list, entries run from AMX / AVX-512 / AVX2 brgemm kernels,
// through gemm-based kernels, down to ref_* as the final fallback.
// NOTE(review): this reads as a priority order (the first candidate that
// accepts the problem wins). The code that walks these lists is not in this
// file, so confirm before reordering entries.
//
// NOTE(review): CPU_INSTANCE*, REG_IP_P and REG_BWD_PK are defined outside
// this file (presumably in cpu/cpu_engine.hpp). They appear to filter entries
// by enabled ISA/backend and by enabled primitive/propagation kind at build
// time. TODO confirm there.
//
// The map is a function-local static, so it is built once on the first call
// (initialization is thread-safe since C++11) and reused afterwards.
// clang-format off
const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
    static const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> the_map = REG_IP_P({
        // ---- forward, floating point ----
        // NOTE(review): the AMX entries tagged "bf32" presumably serve f32
        // problems computed with reduced (bf16) precision. Confirm.
        {{forward, f32, f32, f32}, {
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>) // bf32
            CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core>)
            CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t)
            CPU_INSTANCE(gemm_inner_product_fwd_t<f32>)
            CPU_INSTANCE(ref_inner_product_fwd_t)
            nullptr,
        }},
        {{forward, bf16, bf16, f32}, {
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_bf16>)
            CPU_INSTANCE_AVX512(gemm_bf16_inner_product_fwd_t<f32>)
            CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t<avx2_vnni_2>)
            CPU_INSTANCE(ref_inner_product_fwd_t)
            nullptr,
        }},
        {{forward, bf16, bf16, bf16}, {
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_bf16>)
            CPU_INSTANCE_AVX512(gemm_bf16_inner_product_fwd_t<bf16>)
            CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t<avx2_vnni_2>)
            CPU_INSTANCE(ref_inner_product_fwd_t)
            nullptr,
        }},
        {{forward, f16, f16, f32}, {
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx_fp16>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_fp16>)
            CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t<avx2_vnni_2>)
            CPU_INSTANCE(ref_inner_product_fwd_t)
            nullptr,
        }},
        {{forward, f16, f16, f16}, {
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx_fp16>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_fp16>)
            CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t<avx2_vnni_2>)
            CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t)
            CPU_INSTANCE(ref_inner_product_fwd_t)
            nullptr,
        }},
        // ---- backward_data: key is {diff_src, weights, diff_dst} ----
        {{backward_data, f32, f32, f32}, REG_BWD_PK({
            CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>) // bf32
            CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core>)
            CPU_INSTANCE(gemm_inner_product_bwd_data_t<f32>)
            CPU_INSTANCE(ref_inner_product_bwd_data_t)
            nullptr,
        })},
        {{backward_data, f32, bf16, bf16}, REG_BWD_PK({
            CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core_bf16>)
            CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_data_t<f32>)
            CPU_INSTANCE(ref_inner_product_bwd_data_t)
            nullptr,
        })},
        {{backward_data, bf16, bf16, bf16}, REG_BWD_PK({
            CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core_bf16>)
            CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_data_t<bf16>)
            CPU_INSTANCE(ref_inner_product_bwd_data_t)
            nullptr,
        })},
        {{backward_data, f32, f16, f16}, REG_BWD_PK({
            CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx_fp16>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core_fp16>)
            CPU_INSTANCE(ref_inner_product_bwd_data_t)
            nullptr,
        })},
        {{backward_data, f16, f16, f16}, REG_BWD_PK({
            CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx_fp16>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core_fp16>)
            CPU_INSTANCE(ref_inner_product_bwd_data_t)
            nullptr,
        })},
        // ---- backward_weights: key is {src, diff_weights, diff_dst} ----
        {{backward_weights, f32, f32, f32}, REG_BWD_PK({
            CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t<avx512_core_amx>) // bf32
            CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t<avx512_core>)
            CPU_INSTANCE(gemm_inner_product_bwd_weights_t<f32>)
            CPU_INSTANCE(ref_inner_product_bwd_weights_t)
            nullptr,
        })},
        {{backward_weights, bf16, f32, bf16}, REG_BWD_PK({
            CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t<avx512_core_bf16>)
            CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_weights_t<f32>)
            CPU_INSTANCE(ref_inner_product_bwd_weights_t)
            nullptr,
        })},
        {{backward_weights, bf16, bf16, bf16}, REG_BWD_PK({
            CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t<avx512_core_bf16>)
            CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_weights_t<bf16>)
            CPU_INSTANCE(ref_inner_product_bwd_weights_t)
            nullptr,
        })},
        {{backward_weights, f16, f32, f16}, REG_BWD_PK({
            CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t<avx512_core_amx_fp16>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t<avx512_core_fp16>)
            CPU_INSTANCE(ref_inner_product_bwd_weights_t)
            nullptr,
        })},
        {{backward_weights, f16, f16, f16}, REG_BWD_PK({
            CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t<avx512_core_amx_fp16>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t<avx512_core_fp16>)
            CPU_INSTANCE(ref_inner_product_bwd_weights_t)
            nullptr,
        })},
        // ---- forward, int8: s8/u8 source with s8 weights ----
        {{forward, s8, s8, f32}, {
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
            CPU_INSTANCE(ref_inner_product_int8_fwd_t)
            nullptr,
        }},
        {{forward, s8, s8, s32}, {
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
            CPU_INSTANCE(ref_inner_product_int8_fwd_t)
            nullptr,
        }},
        {{forward, s8, s8, s8}, {
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
            CPU_INSTANCE(ref_inner_product_int8_fwd_t)
            nullptr,
        }},
        {{forward, s8, s8, u8}, {
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
            CPU_INSTANCE(ref_inner_product_int8_fwd_t)
            nullptr,
        }},
        {{forward, u8, s8, f32}, {
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
            CPU_INSTANCE(ref_inner_product_int8_fwd_t)
            nullptr,
        }},
        {{forward, u8, s8, s32}, {
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
            CPU_INSTANCE(ref_inner_product_int8_fwd_t)
            nullptr,
        }},
        {{forward, u8, s8, s8}, {
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
            CPU_INSTANCE(ref_inner_product_int8_fwd_t)
            nullptr,
        }},
        {{forward, u8, s8, u8}, {
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
            CPU_INSTANCE(ref_inner_product_int8_fwd_t)
            nullptr,
        }},
        // bf16 destination: these lists carry no gemm_x8s8s32x entry, so
        // brgemm kernels fall straight back to the int8 reference.
        {{forward, s8, s8, bf16}, {
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE(ref_inner_product_int8_fwd_t)
            nullptr,
        }},
        {{forward, u8, s8, bf16}, {
            CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
            CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
            CPU_INSTANCE(ref_inner_product_int8_fwd_t)
            nullptr,
        }},
    });
    return the_map;
}
// clang-format on
} // namespace
224 | |
225 | const impl_list_item_t *get_inner_product_impl_list( |
226 | const inner_product_desc_t *desc) { |
227 | static const impl_list_item_t empty_list[] = {nullptr}; |
228 | |
229 | const bool is_fwd = utils::one_of( |
230 | desc->prop_kind, forward_training, forward_inference); |
231 | prop_kind_t prop_kind = is_fwd ? forward : desc->prop_kind; |
232 | |
233 | const memory_desc_t *src_md = desc->prop_kind == backward_data |
234 | ? &desc->diff_src_desc |
235 | : &desc->src_desc; |
236 | const memory_desc_t *wei_md = desc->prop_kind == backward_weights |
237 | ? &desc->diff_weights_desc |
238 | : &desc->weights_desc; |
239 | const memory_desc_t *dst_md |
240 | = is_fwd ? &desc->dst_desc : &desc->diff_dst_desc; |
241 | pk_dt_impl_key_t key { |
242 | prop_kind, |
243 | src_md->data_type, |
244 | wei_md->data_type, |
245 | dst_md->data_type, |
246 | }; |
247 | |
248 | const auto impl_list_it = impl_list_map().find(key); |
249 | return impl_list_it != impl_list_map().cend() ? impl_list_it->second.data() |
250 | : empty_list; |
251 | } |
252 | |
253 | } // namespace cpu |
254 | } // namespace impl |
255 | } // namespace dnnl |
256 | |