1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3* Copyright 2021 FUJITSU LIMITED
4* Copyright 2022 Arm Ltd. and affiliates
5*
6* Licensed under the Apache License, Version 2.0 (the "License");
7* you may not use this file except in compliance with the License.
8* You may obtain a copy of the License at
9*
10* http://www.apache.org/licenses/LICENSE-2.0
11*
12* Unless required by applicable law or agreed to in writing, software
13* distributed under the License is distributed on an "AS IS" BASIS,
14* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15* See the License for the specific language governing permissions and
16* limitations under the License.
17*******************************************************************************/
18
19#include "cpu/cpu_engine.hpp"
20
21#include "cpu/ncsp_batch_normalization.hpp"
22#include "cpu/nspc_batch_normalization.hpp"
23#include "cpu/ref_batch_normalization.hpp"
24
25#if DNNL_X64
26#include "cpu/x64/jit_uni_batch_normalization.hpp"
27#include "cpu/x64/jit_uni_batch_normalization_s8.hpp"
28#include "cpu/x64/jit_uni_tbb_batch_normalization.hpp"
29using namespace dnnl::impl::cpu::x64;
30#endif
31
32#if DNNL_AARCH64
33#include "cpu/aarch64/jit_uni_batch_normalization.hpp"
34#include "cpu/aarch64/jit_uni_batch_normalization_s8.hpp"
35#if DNNL_AARCH64_USE_ACL
36#include "cpu/aarch64/acl_batch_normalization.hpp"
37#endif
38using namespace dnnl::impl::cpu::aarch64;
39#endif
40
41namespace dnnl {
42namespace impl {
43namespace cpu {
44
45namespace {
46using namespace dnnl::impl::data_type;
47using namespace dnnl::impl::prop_kind;
48
// Builds (once, on first call -- it is a function-local static) and returns
// the registry of CPU batch normalization implementations, keyed by
// propagation direction.
//
// The map has exactly two keys:
//   - {forward}:  floating-point implementations (x64 jit_uni, x64
//                 jit_uni_tbb, AArch64 jit_uni, AArch64 ACL, then the generic
//                 ncsp / nspc / ref templates for f32, bf16 and f16),
//                 followed by the int8 (s8) implementations.
//   - {backward}: the same floating-point families without the ACL entry;
//                 there are no s8 backward entries.
// Every list ends with a nullptr sentinel. get_batch_normalization_impl_list()
// hands out `.data()` of these vectors, so the sentinel is what marks the end
// of the list for the consumer.
//
// NOTE(review): entry order presumably encodes dispatch priority (the first
// implementation that accepts the problem is used) -- confirm against the
// consumer of this list before reordering anything.
// NOTE(review): CPU_INSTANCE_X64 / CPU_INSTANCE_AARCH64 /
// CPU_INSTANCE_AARCH64_ACL / CPU_INSTANCE, REG_BNORM_P and REG_BWD_PK are
// defined elsewhere; they are assumed to drop their argument when the
// corresponding ISA, library, primitive, or direction is disabled at build
// time -- verify against their definitions.
const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
    // clang-format off
    static const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> the_map = REG_BNORM_P({
        {{forward}, {
            /* fp */
            // x64 JIT kernels, widest ISA first.
            CPU_INSTANCE_X64(jit_uni_batch_normalization_fwd_t<avx512_core>)
            CPU_INSTANCE_X64(jit_uni_batch_normalization_fwd_t<avx2>)
            CPU_INSTANCE_X64(jit_uni_batch_normalization_fwd_t<sse41>)
            // x64 jit_uni_tbb variants, same ISA ordering.
            CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_fwd_t<avx512_core>)
            CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_fwd_t<avx2>)
            CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_fwd_t<sse41>)
            // AArch64 JIT kernels, then the ACL-backed implementation.
            CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_fwd_t<sve_512>)
            CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_fwd_t<asimd>)
            CPU_INSTANCE_AARCH64_ACL(acl_batch_normalization_fwd_t)
            // Generic (non-JIT) implementations, one per data type.
            CPU_INSTANCE(ncsp_batch_normalization_fwd_t<f32>)
            CPU_INSTANCE(ncsp_batch_normalization_fwd_t<bf16>)
            CPU_INSTANCE(ncsp_batch_normalization_fwd_t<f16>)
            CPU_INSTANCE(nspc_batch_normalization_fwd_t<f32>)
            CPU_INSTANCE(nspc_batch_normalization_fwd_t<bf16>)
            CPU_INSTANCE(nspc_batch_normalization_fwd_t<f16>)
            CPU_INSTANCE(ref_batch_normalization_fwd_t<f32>)
            CPU_INSTANCE(ref_batch_normalization_fwd_t<bf16>)
            CPU_INSTANCE(ref_batch_normalization_fwd_t<f16>)
            /* int */
            CPU_INSTANCE_X64(jit_uni_batch_normalization_s8_fwd_t<avx512_core>)
            CPU_INSTANCE_X64(jit_uni_batch_normalization_s8_fwd_t<avx2>)
            CPU_INSTANCE_X64(jit_uni_batch_normalization_s8_fwd_t<sse41>)
            CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_s8_fwd_t<sve_512>)
            CPU_INSTANCE(ref_batch_normalization_fwd_t<s8>)
            // End-of-list sentinel.
            nullptr,
        }},
        {{backward}, REG_BWD_PK({
            // x64 JIT kernels, widest ISA first.
            CPU_INSTANCE_X64(jit_uni_batch_normalization_bwd_t<avx512_core>)
            CPU_INSTANCE_X64(jit_uni_batch_normalization_bwd_t<avx2>)
            CPU_INSTANCE_X64(jit_uni_batch_normalization_bwd_t<sse41>)
            // x64 jit_uni_tbb variants, same ISA ordering.
            CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_bwd_t<avx512_core>)
            CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_bwd_t<avx2>)
            CPU_INSTANCE_X64(jit_uni_tbb_batch_normalization_bwd_t<sse41>)
            // AArch64 JIT kernels (no ACL backward entry in this list).
            CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_bwd_t<sve_512>)
            CPU_INSTANCE_AARCH64(jit_uni_batch_normalization_bwd_t<asimd>)
            // Generic (non-JIT) implementations, one per data type.
            CPU_INSTANCE(ncsp_batch_normalization_bwd_t<f32>)
            CPU_INSTANCE(ncsp_batch_normalization_bwd_t<bf16>)
            CPU_INSTANCE(ncsp_batch_normalization_bwd_t<f16>)
            CPU_INSTANCE(nspc_batch_normalization_bwd_t<f32>)
            CPU_INSTANCE(nspc_batch_normalization_bwd_t<bf16>)
            CPU_INSTANCE(nspc_batch_normalization_bwd_t<f16>)
            CPU_INSTANCE(ref_batch_normalization_bwd_t<f32>)
            CPU_INSTANCE(ref_batch_normalization_bwd_t<bf16>)
            CPU_INSTANCE(ref_batch_normalization_bwd_t<f16>)
            // End-of-list sentinel.
            nullptr,
        })},
    });
    // clang-format on
    return the_map;
}
104} // namespace
105
106const impl_list_item_t *get_batch_normalization_impl_list(
107 const batch_normalization_desc_t *desc) {
108 static const impl_list_item_t empty_list[] = {nullptr};
109
110 const bool is_fwd = utils::one_of(
111 desc->prop_kind, forward_training, forward_inference);
112 prop_kind_t prop_kind = is_fwd ? forward : backward;
113
114 pk_impl_key_t key {prop_kind};
115
116 const auto impl_list_it = impl_list_map().find(key);
117 return impl_list_it != impl_list_map().cend() ? impl_list_it->second.data()
118 : empty_list;
119}
120
121} // namespace cpu
122} // namespace impl
123} // namespace dnnl
124