1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_LRN_JIT_AVX512_COMMON_LRN_HPP
18#define CPU_X64_LRN_JIT_AVX512_COMMON_LRN_HPP
19
20#include <memory>
21#include "common/c_types_map.hpp"
22#include "common/primitive.hpp"
23
24#include "cpu/cpu_lrn_pd.hpp"
25#include "cpu/x64/cpu_isa_traits.hpp"
26#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
27#include "cpu/x64/lrn/lrn_executor.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace cpu {
32namespace x64 {
33template <data_type_t d_type>
34struct jit_avx512_common_lrn_fwd_t : public primitive_t {
35 struct pd_t : public cpu_lrn_fwd_pd_t {
36 using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t;
37
38 DECLARE_COMMON_PD_T(
39 JIT_IMPL_NAME_HELPER("lrn_jit:",
40 utils::map(true, avx512_core,
41 d_type == data_type::bf16
42 && mayiuse(avx512_core_bf16),
43 avx512_core_bf16,
44 d_type == data_type::bf16
45 && !mayiuse(avx512_core_bf16),
46 bf16_emulation_t::get_isa(),
47 d_type == data_type::f16, avx512_core_fp16),
48 ""),
49 jit_avx512_common_lrn_fwd_t);
50
51 status_t init(engine_t *engine);
52 };
53
54 jit_avx512_common_lrn_fwd_t(const pd_t *apd);
55 ~jit_avx512_common_lrn_fwd_t();
56
57 using data_t = typename prec_traits<d_type>::type;
58
59 status_t init(engine_t *engine) override {
60 return lrn_executor_->create_kernel();
61 }
62
63 status_t execute(const exec_ctx_t &ctx) const override {
64 return lrn_executor_->execute(ctx);
65 }
66
67private:
68 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
69 std::unique_ptr<lrn::i_lrn_executor_t> lrn_executor_;
70};
71
72template <data_type_t d_type>
73struct jit_avx512_common_lrn_bwd_t : public primitive_t {
74 struct pd_t : public cpu_lrn_bwd_pd_t {
75 using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t;
76
77 DECLARE_COMMON_PD_T(
78 JIT_IMPL_NAME_HELPER("lrn_jit:",
79 utils::map(true, avx512_core,
80 d_type == data_type::bf16
81 && mayiuse(avx512_core_bf16),
82 avx512_core_bf16,
83 d_type == data_type::bf16
84 && !mayiuse(avx512_core_bf16),
85 bf16_emulation_t::get_isa(),
86 d_type == data_type::f16, avx512_core_fp16),
87 ""),
88 jit_avx512_common_lrn_bwd_t);
89
90 status_t init(engine_t *engine);
91 };
92
93 jit_avx512_common_lrn_bwd_t(const pd_t *apd);
94 ~jit_avx512_common_lrn_bwd_t();
95
96 using data_t = typename prec_traits<d_type>::type;
97
98 status_t init(engine_t *engine) override {
99 return lrn_executor_->create_kernel();
100 }
101
102 status_t execute(const exec_ctx_t &ctx) const override {
103 return lrn_executor_->execute(ctx);
104 }
105
106private:
107 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
108 std::unique_ptr<lrn::i_lrn_executor_t> lrn_executor_;
109};
110
111} // namespace x64
112} // namespace cpu
113} // namespace impl
114} // namespace dnnl
115
116#endif
117
118// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
119