1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3* Copyright 2022 Arm Ltd. and affiliates
4*
5* Licensed under the Apache License, Version 2.0 (the "License");
6* you may not use this file except in compliance with the License.
7* You may obtain a copy of the License at
8*
9* http://www.apache.org/licenses/LICENSE-2.0
10*
11* Unless required by applicable law or agreed to in writing, software
12* distributed under the License is distributed on an "AS IS" BASIS,
13* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14* See the License for the specific language governing permissions and
15* limitations under the License.
16*******************************************************************************/
17
18#include "cpu/cpu_engine.hpp"
19#include "cpu/ref_prelu.hpp"
20
21#if DNNL_X64
22#include "cpu/x64/prelu/jit_prelu_backward.hpp"
23#include "cpu/x64/prelu/jit_prelu_forward.hpp"
24
25using namespace dnnl::impl::cpu::x64;
26#elif DNNL_AARCH64 && DNNL_AARCH64_USE_ACL
27#include "cpu/aarch64/acl_prelu.hpp"
28using namespace dnnl::impl::cpu::aarch64;
29#endif
30
31namespace dnnl {
32namespace impl {
33namespace cpu {
34
namespace {
// Bring prop-kind enumerators (forward, backward, forward_training, ...) into
// scope for the table below and for get_prelu_impl_list().
// NOTE(review): nothing from data_type appears to be used in this file.
using namespace dnnl::impl::data_type;
using namespace dnnl::impl::prop_kind;

// Builds (once, on first call) and returns the table of PReLU implementations
// keyed by propagation direction: {forward} and {backward}.
//
// Each value is a nullptr-terminated list of implementation entries. Within a
// list the entries are presumably tried in order, first match wins, so more
// specialized implementations come before the reference fallback. Confirm
// against the primitive-descriptor iterator.
//
// REG_PRELU_P / REG_BWD_PK are assumed to be build-configuration filters that
// can compile the whole table or the backward list out. CPU_INSTANCE_X64 /
// CPU_INSTANCE_AARCH64_ACL are assumed to expand to nothing on other
// architectures. TODO: confirm in the impl-registration macros.
//
// The table is a function-local static, so it is constructed on first use.
// C++11 guarantees thread-safe initialization of such statics.
// clang-format off
const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
    static const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> the_map = REG_PRELU_P({
        // Forward list: x64 JIT, then Arm Compute Library (AArch64), then the
        // generic reference implementation.
        {{forward}, {
            CPU_INSTANCE_X64(jit_prelu_fwd_t)
            CPU_INSTANCE_AARCH64_ACL(acl_prelu_fwd_t)
            CPU_INSTANCE(ref_prelu_fwd_t)
            nullptr,
        }},
        // Backward list: x64 JIT, then reference. No ACL backward entry is
        // registered here, so AArch64 backward uses the reference path.
        {{backward}, REG_BWD_PK({
            CPU_INSTANCE_X64(jit_prelu_bwd_t)
            CPU_INSTANCE(ref_prelu_bwd_t)
            nullptr,
        })},
    });
    return the_map;
}
// clang-format on
} // namespace
58
59const impl_list_item_t *get_prelu_impl_list(const prelu_desc_t *desc) {
60 static const impl_list_item_t empty_list[] = {nullptr};
61
62 const bool is_fwd = utils::one_of(
63 desc->prop_kind, forward_training, forward_inference);
64 prop_kind_t prop_kind = is_fwd ? forward : backward;
65
66 pk_impl_key_t key {prop_kind};
67
68 const auto impl_list_it = impl_list_map().find(key);
69 return impl_list_it != impl_list_map().cend() ? impl_list_it->second.data()
70 : empty_list;
71}
72
73} // namespace cpu
74} // namespace impl
75} // namespace dnnl
76