1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3* Copyright 2021 FUJITSU LIMITED
4* Copyright 2021-2022 Arm Ltd. and affiliates
5*
6* Licensed under the Apache License, Version 2.0 (the "License");
7* you may not use this file except in compliance with the License.
8* You may obtain a copy of the License at
9*
10* http://www.apache.org/licenses/LICENSE-2.0
11*
12* Unless required by applicable law or agreed to in writing, software
13* distributed under the License is distributed on an "AS IS" BASIS,
14* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15* See the License for the specific language governing permissions and
16* limitations under the License.
17*******************************************************************************/
18
19#include "cpu/cpu_engine.hpp"
20
21#include "cpu/ref_softmax.hpp"
22
23#if DNNL_X64
24#include "cpu/x64/jit_uni_softmax.hpp"
25using namespace dnnl::impl::cpu::x64;
26#elif DNNL_AARCH64
27#include "cpu/aarch64/jit_uni_softmax.hpp"
28#if DNNL_AARCH64_USE_ACL
29#include "cpu/aarch64/acl_softmax.hpp"
30#endif
31using namespace dnnl::impl::cpu::aarch64;
32#endif
33
34namespace dnnl {
35namespace impl {
36namespace cpu {
37
namespace {
// Unqualified propagation-kind names used in this file (forward, backward,
// forward_training, forward_inference) come from these directives. Because
// they sit inside an unnamed namespace, the names are also reachable from the
// enclosing dnnl::impl::cpu scope (using-directives are transitive).
using namespace dnnl::impl::data_type;
using namespace dnnl::impl::prop_kind;

// Registry of CPU softmax implementations, keyed by propagation kind.
//
// Returns a reference to a function-local static map that is built once, on
// the first call. Each value is a vector of implementation entries that ends
// with a nullptr. get_softmax_impl_list() hands out `.data()` of these vectors,
// so the nullptr presumably marks the end of the list for whoever iterates it.
//
// Entry order matters. Within each list, implementations run from the widest
// ISA (avx512_core / sve_512) down to the portable reference kernel, which is
// presumably the order they are tried in during dispatch. Keep the reference
// implementation last and do not reorder entries casually.
//
// NOTE(review): the CPU_INSTANCE_X64 / CPU_INSTANCE_AARCH64 /
// CPU_INSTANCE_AARCH64_ACL wrappers presumably expand to nothing when the
// matching architecture or ACL support is not compiled in (compare the
// conditional #includes at the top of this file). REG_SOFTMAX_P and
// REG_BWD_PK presumably filter the table by build-time primitive and
// propagation-kind configuration. Confirm against the macro definitions.
const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
    // clang-format off
    static std::map<pk_impl_key_t, std::vector<impl_list_item_t>> the_map = REG_SOFTMAX_P({
        // Both forward_training and forward_inference are looked up under this
        // key (see get_softmax_impl_list).
        {{forward}, {
            CPU_INSTANCE_X64(jit_uni_softmax_fwd_t<avx512_core>)
            CPU_INSTANCE_X64(jit_uni_softmax_fwd_t<avx2>)
            CPU_INSTANCE_X64(jit_uni_softmax_fwd_t<sse41>)
            CPU_INSTANCE_AARCH64(jit_uni_softmax_fwd_t<sve_512>)
            CPU_INSTANCE_AARCH64_ACL(acl_softmax_fwd_t)
            CPU_INSTANCE(ref_softmax_fwd_t)
            nullptr, // end-of-list marker
        }},
        // Backward registers fewer optimized kernels than forward: there are
        // no avx2/sse41 JIT variants and no ACL variant in this list.
        {{backward}, REG_BWD_PK({
            CPU_INSTANCE_X64(jit_uni_softmax_bwd_t<avx512_core>)
            CPU_INSTANCE_AARCH64(jit_uni_softmax_bwd_t<sve_512>)
            CPU_INSTANCE(ref_softmax_bwd_t)
            nullptr, // end-of-list marker
        })},
    });
    // clang-format on
    return the_map;
}

} // namespace
66
67const impl_list_item_t *get_softmax_impl_list(const softmax_desc_t *desc) {
68 static const impl_list_item_t empty_list[] = {nullptr};
69
70 const bool is_fwd = utils::one_of(
71 desc->prop_kind, forward_training, forward_inference);
72 prop_kind_t prop_kind = is_fwd ? forward : backward;
73
74 pk_impl_key_t key {prop_kind};
75
76 const auto impl_list_it = impl_list_map().find(key);
77 return impl_list_it != impl_list_map().cend() ? impl_list_it->second.data()
78 : empty_list;
79}
80
81} // namespace cpu
82} // namespace impl
83} // namespace dnnl
84