1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "cpu/cpu_engine.hpp"
18
19#include "cpu/ref_lrn.hpp"
20
21#if DNNL_X64
22#include "cpu/x64/lrn/jit_avx512_common_lrn.hpp"
23#include "cpu/x64/lrn/jit_uni_lrn.hpp"
24using namespace dnnl::impl::cpu::x64;
25#endif
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30
31namespace {
32using namespace dnnl::impl::data_type;
33using namespace dnnl::impl::prop_kind;
34
35// clang-format off
36const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
37 static const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> the_map = REG_LRN_P({
38 {{forward}, {
39 CPU_INSTANCE_X64(jit_avx512_common_lrn_fwd_t<f32>)
40 CPU_INSTANCE_X64(jit_avx512_common_lrn_fwd_t<bf16>)
41 CPU_INSTANCE_X64(jit_avx512_common_lrn_fwd_t<f16>)
42 CPU_INSTANCE_X64(jit_uni_lrn_fwd_t<avx512_core_fp16, f16>)
43 CPU_INSTANCE_X64(jit_uni_lrn_fwd_t<avx512_core, f32>)
44 CPU_INSTANCE_X64(jit_uni_lrn_fwd_t<avx512_core, bf16>)
45 CPU_INSTANCE_X64(jit_uni_lrn_fwd_t<avx2, f32>)
46 CPU_INSTANCE_X64(jit_uni_lrn_fwd_t<sse41, f32>)
47 CPU_INSTANCE(ref_lrn_fwd_t<f32>)
48 CPU_INSTANCE(ref_lrn_fwd_t<bf16>)
49 CPU_INSTANCE(ref_lrn_fwd_t<f16>)
50 nullptr,
51 }},
52 {{backward}, REG_BWD_PK({
53 CPU_INSTANCE_X64(jit_avx512_common_lrn_bwd_t<f32>)
54 CPU_INSTANCE_X64(jit_avx512_common_lrn_bwd_t<bf16>)
55 CPU_INSTANCE_X64(jit_avx512_common_lrn_bwd_t<f16>)
56 CPU_INSTANCE_X64(jit_uni_lrn_bwd_t<avx512_core_fp16, f16>)
57 CPU_INSTANCE_X64(jit_uni_lrn_bwd_t<avx512_core, f32>)
58 CPU_INSTANCE_X64(jit_uni_lrn_bwd_t<avx512_core, bf16>)
59 CPU_INSTANCE_X64(jit_uni_lrn_bwd_t<avx2, f32>)
60 CPU_INSTANCE(ref_lrn_bwd_t<f32>)
61 CPU_INSTANCE(ref_lrn_bwd_t<bf16>)
62 CPU_INSTANCE(ref_lrn_bwd_t<f16>)
63 nullptr,
64 })},
65 });
66 return the_map;
67}
68// clang-format on
69} // namespace
70
71const impl_list_item_t *get_lrn_impl_list(const lrn_desc_t *desc) {
72 static const impl_list_item_t empty_list[] = {nullptr};
73
74 const bool is_fwd = utils::one_of(
75 desc->prop_kind, forward_training, forward_inference);
76 prop_kind_t prop_kind = is_fwd ? forward : backward;
77
78 pk_impl_key_t key {prop_kind};
79
80 const auto impl_list_it = impl_list_map().find(key);
81 return impl_list_it != impl_list_map().cend() ? impl_list_it->second.data()
82 : empty_list;
83}
84
85} // namespace cpu
86} // namespace impl
87} // namespace dnnl
88