1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3* Copyright 2020 FUJITSU LIMITED
4* Copyright 2022 Arm Ltd. and affiliates
5*
6* Licensed under the Apache License, Version 2.0 (the "License");
7* you may not use this file except in compliance with the License.
8* You may obtain a copy of the License at
9*
10* http://www.apache.org/licenses/LICENSE-2.0
11*
12* Unless required by applicable law or agreed to in writing, software
13* distributed under the License is distributed on an "AS IS" BASIS,
14* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15* See the License for the specific language governing permissions and
16* limitations under the License.
17*******************************************************************************/
18
19#include "cpu/cpu_engine.hpp"
20
21#include "cpu/nchw_pooling.hpp"
22#include "cpu/nhwc_pooling.hpp"
23#include "cpu/ref_pooling.hpp"
24
25#if DNNL_X64
26#include "cpu/x64/jit_uni_i8i8_pooling.hpp"
27#include "cpu/x64/jit_uni_pooling.hpp"
28using namespace dnnl::impl::cpu::x64;
29#elif DNNL_AARCH64
30#include "cpu/aarch64/jit_uni_i8i8_pooling.hpp"
31#include "cpu/aarch64/jit_uni_pooling.hpp"
32using namespace dnnl::impl::cpu::aarch64;
33#if DNNL_AARCH64_USE_ACL
34#include "cpu/aarch64/acl_pooling.hpp"
35#endif // DNNL_AARCH64_USE_ACL
36#endif
37
38namespace dnnl {
39namespace impl {
40namespace cpu {
41
42namespace {
43using namespace dnnl::impl::data_type;
44using namespace dnnl::impl::prop_kind;
45
46// clang-format off
47const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
48 static const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> the_map = REG_POOLING_P({
49 {{forward}, {
50 /* fp */
51 CPU_INSTANCE_X64(jit_uni_pooling_fwd_t<avx512_core_fp16, f16>)
52 CPU_INSTANCE_X64(jit_uni_pooling_fwd_t<avx512_core, bf16>)
53 CPU_INSTANCE_X64(jit_uni_pooling_fwd_t<avx512_core, f32>)
54 CPU_INSTANCE_X64(jit_uni_pooling_fwd_t<avx2_vnni_2, bf16>)
55 CPU_INSTANCE_X64(jit_uni_pooling_fwd_t<avx2_vnni_2, f16>)
56 CPU_INSTANCE_X64(jit_uni_pooling_fwd_t<avx2, f32>)
57 CPU_INSTANCE_X64(jit_uni_pooling_fwd_t<avx, f32>)
58 CPU_INSTANCE_X64(jit_uni_pooling_fwd_t<sse41, f32>)
59 CPU_INSTANCE_AARCH64(jit_uni_pooling_fwd_t<sve_512, f32>)
60 CPU_INSTANCE_AARCH64_ACL(acl_pooling_fwd_t)
61 CPU_INSTANCE(nchw_pooling_fwd_t<bf16>)
62 CPU_INSTANCE(nchw_pooling_fwd_t<f32>)
63 CPU_INSTANCE(nchw_pooling_fwd_t<f16>)
64 CPU_INSTANCE(nhwc_pooling_fwd_t<bf16>)
65 CPU_INSTANCE(nhwc_pooling_fwd_t<f32>)
66 CPU_INSTANCE(nhwc_pooling_fwd_t<f16>)
67 CPU_INSTANCE(ref_pooling_fwd_t<f32>)
68 CPU_INSTANCE(ref_pooling_fwd_t<bf16, f32>)
69 CPU_INSTANCE(ref_pooling_fwd_t<f16, f32>)
70 /* int */
71 CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t<avx512_core>)
72 CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t<avx2>)
73 CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t<sse41>)
74 CPU_INSTANCE_AARCH64(jit_uni_i8i8_pooling_fwd_t<sve_512>)
75 CPU_INSTANCE(ref_pooling_fwd_t<s32>)
76 CPU_INSTANCE(ref_pooling_fwd_t<s8, s32>)
77 CPU_INSTANCE(ref_pooling_fwd_t<u8, s32>)
78 nullptr,
79 }},
80 {{backward}, REG_BWD_PK({
81 CPU_INSTANCE_X64(jit_uni_pooling_bwd_t<avx512_core_fp16, f16>)
82 CPU_INSTANCE_X64(jit_uni_pooling_bwd_t<avx512_core, bf16>)
83 CPU_INSTANCE_X64(jit_uni_pooling_bwd_t<avx512_core, f32>)
84 CPU_INSTANCE_X64(jit_uni_pooling_bwd_t<avx2, f32>)
85 CPU_INSTANCE_X64(jit_uni_pooling_bwd_t<avx, f32>)
86 CPU_INSTANCE_X64(jit_uni_pooling_bwd_t<sse41, f32>)
87 CPU_INSTANCE_AARCH64(jit_uni_pooling_bwd_t<sve_512, f32>)
88 CPU_INSTANCE(nchw_pooling_bwd_t<bf16>)
89 CPU_INSTANCE(nchw_pooling_bwd_t<f32>)
90 CPU_INSTANCE(nchw_pooling_bwd_t<f16>)
91 CPU_INSTANCE(nhwc_pooling_bwd_t<bf16>)
92 CPU_INSTANCE(nhwc_pooling_bwd_t<f32>)
93 CPU_INSTANCE(nhwc_pooling_bwd_t<f16>)
94 CPU_INSTANCE(ref_pooling_bwd_t<f32>)
95 CPU_INSTANCE(ref_pooling_bwd_t<bf16>)
96 CPU_INSTANCE(ref_pooling_bwd_t<f16>)
97 nullptr,
98 })},
99 });
100 return the_map;
101}
102// clang-format on
103} // namespace
104
105const impl_list_item_t *get_pooling_impl_list(const pooling_desc_t *desc) {
106 static const impl_list_item_t empty_list[] = {nullptr};
107
108 const bool is_fwd = utils::one_of(
109 desc->prop_kind, forward_training, forward_inference);
110 prop_kind_t prop_kind = is_fwd ? forward : backward;
111
112 pk_impl_key_t key {prop_kind};
113
114 const auto impl_list_it = impl_list_map().find(key);
115 return impl_list_it != impl_list_map().cend() ? impl_list_it->second.data()
116 : empty_list;
117}
118
119} // namespace cpu
120} // namespace impl
121} // namespace dnnl
122