1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "cpu/cpu_engine.hpp"
18
19#include "cpu/gemm_inner_product.hpp"
20#include "cpu/gemm_x8s8s32x_inner_product.hpp"
21#include "cpu/ref_inner_product.hpp"
22#include "cpu/ref_inner_product_int8.hpp"
23
24#if DNNL_X64
25#include "cpu/x64/gemm_bf16_inner_product.hpp"
26#include "cpu/x64/jit_brgemm_inner_product.hpp"
27using namespace dnnl::impl::cpu::x64;
28#endif
29
30#if DNNL_AARCH64 && DNNL_AARCH64_USE_ACL
31#include "cpu/aarch64/acl_inner_product.hpp"
32using namespace dnnl::impl::cpu::aarch64;
33#endif
34
35namespace dnnl {
36namespace impl {
37namespace cpu {
38
39namespace {
40using namespace dnnl::impl::data_type;
41using namespace dnnl::impl::prop_kind;
42
43// clang-format off
44const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
45 static const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> the_map = REG_IP_P({
46 {{forward, f32, f32, f32}, {
47 CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>) // bf32
48 CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core>)
49 CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t)
50 CPU_INSTANCE(gemm_inner_product_fwd_t<f32>)
51 CPU_INSTANCE(ref_inner_product_fwd_t)
52 nullptr,
53 }},
54 {{forward, bf16, bf16, f32}, {
55 CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
56 CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_bf16>)
57 CPU_INSTANCE_AVX512(gemm_bf16_inner_product_fwd_t<f32>)
58 CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t<avx2_vnni_2>)
59 CPU_INSTANCE(ref_inner_product_fwd_t)
60 nullptr,
61 }},
62 {{forward, bf16, bf16, bf16}, {
63 CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
64 CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_bf16>)
65 CPU_INSTANCE_AVX512(gemm_bf16_inner_product_fwd_t<bf16>)
66 CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t<avx2_vnni_2>)
67 CPU_INSTANCE(ref_inner_product_fwd_t)
68 nullptr,
69 }},
70 {{forward, f16, f16, f32}, {
71 CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx_fp16>)
72 CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_fp16>)
73 CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t<avx2_vnni_2>)
74 CPU_INSTANCE(ref_inner_product_fwd_t)
75 nullptr,
76 }},
77 {{forward, f16, f16, f16}, {
78 CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx_fp16>)
79 CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_fp16>)
80 CPU_INSTANCE_AVX2(brgemm_inner_product_fwd_t<avx2_vnni_2>)
81 CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t)
82 CPU_INSTANCE(ref_inner_product_fwd_t)
83 nullptr,
84 }},
85 {{backward_data, f32, f32, f32}, REG_BWD_PK({
86 CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>) // bf32
87 CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core>)
88 CPU_INSTANCE(gemm_inner_product_bwd_data_t<f32>)
89 CPU_INSTANCE(ref_inner_product_bwd_data_t)
90 nullptr,
91 })},
92 {{backward_data, f32, bf16, bf16}, REG_BWD_PK({
93 CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>)
94 CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core_bf16>)
95 CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_data_t<f32>)
96 CPU_INSTANCE(ref_inner_product_bwd_data_t)
97 nullptr,
98 })},
99 {{backward_data, bf16, bf16, bf16}, REG_BWD_PK({
100 CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>)
101 CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core_bf16>)
102 CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_data_t<bf16>)
103 CPU_INSTANCE(ref_inner_product_bwd_data_t)
104 nullptr,
105 })},
106 {{backward_data, f32, f16, f16}, REG_BWD_PK({
107 CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx_fp16>)
108 CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core_fp16>)
109 CPU_INSTANCE(ref_inner_product_bwd_data_t)
110 nullptr,
111 })},
112 {{backward_data, f16, f16, f16}, REG_BWD_PK({
113 CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx_fp16>)
114 CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core_fp16>)
115 CPU_INSTANCE(ref_inner_product_bwd_data_t)
116 nullptr,
117 })},
118 {{backward_weights, f32, f32, f32}, REG_BWD_PK({
119 CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t<avx512_core_amx>) // bf32
120 CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t<avx512_core>)
121 CPU_INSTANCE(gemm_inner_product_bwd_weights_t<f32>)
122 CPU_INSTANCE(ref_inner_product_bwd_weights_t)
123 nullptr,
124 })},
125 {{backward_weights, bf16, f32, bf16}, REG_BWD_PK({
126 CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t<avx512_core_amx>)
127 CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t<avx512_core_bf16>)
128 CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_weights_t<f32>)
129 CPU_INSTANCE(ref_inner_product_bwd_weights_t)
130 nullptr,
131 })},
132 {{backward_weights, bf16, bf16, bf16}, REG_BWD_PK({
133 CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t<avx512_core_amx>)
134 CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t<avx512_core_bf16>)
135 CPU_INSTANCE_AVX512(gemm_bf16_inner_product_bwd_weights_t<bf16>)
136 CPU_INSTANCE(ref_inner_product_bwd_weights_t)
137 nullptr,
138 })},
139 {{backward_weights, f16, f32, f16}, REG_BWD_PK({
140 CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t<avx512_core_amx_fp16>)
141 CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t<avx512_core_fp16>)
142 CPU_INSTANCE(ref_inner_product_bwd_weights_t)
143 nullptr,
144 })},
145 {{backward_weights, f16, f16, f16}, REG_BWD_PK({
146 CPU_INSTANCE_AMX(brgemm_inner_product_bwd_weights_t<avx512_core_amx_fp16>)
147 CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_weights_t<avx512_core_fp16>)
148 CPU_INSTANCE(ref_inner_product_bwd_weights_t)
149 nullptr,
150 })},
151 {{forward, s8, s8, f32}, {
152 CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
153 CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
154 CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
155 CPU_INSTANCE(ref_inner_product_int8_fwd_t)
156 nullptr,
157 }},
158 {{forward, s8, s8, s32}, {
159 CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
160 CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
161 CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
162 CPU_INSTANCE(ref_inner_product_int8_fwd_t)
163 nullptr,
164 }},
165 {{forward, s8, s8, s8}, {
166 CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
167 CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
168 CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
169 CPU_INSTANCE(ref_inner_product_int8_fwd_t)
170 nullptr,
171 }},
172 {{forward, s8, s8, u8}, {
173 CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
174 CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
175 CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
176 CPU_INSTANCE(ref_inner_product_int8_fwd_t)
177 nullptr,
178 }},
179 {{forward, u8, s8, f32}, {
180 CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
181 CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
182 CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
183 CPU_INSTANCE(ref_inner_product_int8_fwd_t)
184 nullptr,
185 }},
186 {{forward, u8, s8, s32}, {
187 CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
188 CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
189 CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
190 CPU_INSTANCE(ref_inner_product_int8_fwd_t)
191 nullptr,
192 }},
193 {{forward, u8, s8, s8}, {
194 CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
195 CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
196 CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
197 CPU_INSTANCE(ref_inner_product_int8_fwd_t)
198 nullptr,
199 }},
200 {{forward, u8, s8, u8}, {
201 CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
202 CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
203 CPU_INSTANCE(gemm_x8s8s32x_inner_product_fwd_t)
204 CPU_INSTANCE(ref_inner_product_int8_fwd_t)
205 nullptr,
206 }},
207 {{forward, s8, s8, bf16}, {
208 CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
209 CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
210 CPU_INSTANCE(ref_inner_product_int8_fwd_t)
211 nullptr,
212 }},
213 {{forward, u8, s8, bf16}, {
214 CPU_INSTANCE_AMX(brgemm_inner_product_fwd_t<avx512_core_amx>)
215 CPU_INSTANCE_AVX512(brgemm_inner_product_fwd_t<avx512_core_vnni>)
216 CPU_INSTANCE(ref_inner_product_int8_fwd_t)
217 nullptr,
218 }},
219 });
220 return the_map;
221}
222// clang-format on
223} // namespace
224
225const impl_list_item_t *get_inner_product_impl_list(
226 const inner_product_desc_t *desc) {
227 static const impl_list_item_t empty_list[] = {nullptr};
228
229 const bool is_fwd = utils::one_of(
230 desc->prop_kind, forward_training, forward_inference);
231 prop_kind_t prop_kind = is_fwd ? forward : desc->prop_kind;
232
233 const memory_desc_t *src_md = desc->prop_kind == backward_data
234 ? &desc->diff_src_desc
235 : &desc->src_desc;
236 const memory_desc_t *wei_md = desc->prop_kind == backward_weights
237 ? &desc->diff_weights_desc
238 : &desc->weights_desc;
239 const memory_desc_t *dst_md
240 = is_fwd ? &desc->dst_desc : &desc->diff_dst_desc;
241 pk_dt_impl_key_t key {
242 prop_kind,
243 src_md->data_type,
244 wei_md->data_type,
245 dst_md->data_type,
246 };
247
248 const auto impl_list_it = impl_list_map().find(key);
249 return impl_list_it != impl_list_map().cend() ? impl_list_it->second.data()
250 : empty_list;
251}
252
253} // namespace cpu
254} // namespace impl
255} // namespace dnnl
256