1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/dnnl_thread.hpp"
19#include "common/type_helpers.hpp"
20#include "common/utils.hpp"
21
22#include "cpu/x64/lrn/jit_avx512_common_lrn.hpp"
23#include "cpu/x64/lrn/lrn_executor_factory.hpp"
24
25namespace dnnl {
26namespace impl {
27namespace cpu {
28namespace x64 {
29
30static constexpr int vsize = 16;
31
32using namespace dnnl::impl::status;
33using namespace dnnl::impl::utils;
34using namespace data_type;
35
36template <data_type_t d_type>
37status_t jit_avx512_common_lrn_fwd_t<d_type>::pd_t::init(engine_t *engine) {
38 using namespace prop_kind;
39 using namespace alg_kind;
40
41 const memory_desc_wrapper src_d(src_md());
42 const memory_desc_wrapper dst_d(dst_md());
43
44 const bool ok = is_fwd() && mayiuse(avx512_core) && !has_zero_dim_memory()
45 && everyone_is(d_type, src_d.data_type(), dst_d.data_type())
46 && IMPLICATION(d_type == f16, mayiuse(avx512_core_fp16))
47 && src_d.ndims() == 4 && attr()->has_default_values()
48 && set_default_formats_common() && src_d == dst_d;
49 if (!ok) return unimplemented;
50
51 const auto fmt_tag
52 = src_d.matches_one_of_tag(format_tag::nhwc, format_tag::nChw16c);
53
54 const bool args_ok_across = desc()->alg_kind == lrn_across_channels
55 && desc()->local_size >= 1 && desc()->local_size <= 16
56 && (desc()->lrn_beta == 0.75 || desc()->lrn_beta == 1.0)
57 && src_d.matches_tag(fmt_tag)
58 && IMPLICATION(fmt_tag == format_tag::nChw16c,
59 src_d.dims()[1] % vsize == 0 && desc()->local_size == 5);
60
61 if (!args_ok_across) return unimplemented;
62
63 if (desc()->prop_kind == forward_training) {
64 dims_t ws_dims = {MB(), C(), H(), 2 * W()};
65 memory_desc_init_by_tag(ws_md_, 4, ws_dims, d_type, fmt_tag);
66 }
67
68 return success;
69}
70
71template <data_type_t d_type>
72jit_avx512_common_lrn_fwd_t<d_type>::jit_avx512_common_lrn_fwd_t(
73 const pd_t *apd)
74 : primitive_t(apd)
75 , lrn_executor_(lrn::lrn_executor_factory_t::create_executor<d_type,
76 typename jit_avx512_common_lrn_fwd_t<d_type>::pd_t>(
77 pd(), lrn::direction::forward)) {}
78
79template <data_type_t d_type>
80jit_avx512_common_lrn_fwd_t<d_type>::~jit_avx512_common_lrn_fwd_t() = default;
81
82template struct jit_avx512_common_lrn_fwd_t<f32>;
83template struct jit_avx512_common_lrn_fwd_t<bf16>;
84template struct jit_avx512_common_lrn_fwd_t<f16>;
85
86template <data_type_t d_type>
87status_t jit_avx512_common_lrn_bwd_t<d_type>::pd_t::init(engine_t *engine) {
88 using namespace alg_kind;
89
90 const memory_desc_wrapper src_d(src_md());
91 const memory_desc_wrapper diff_src_d(diff_src_md());
92 const memory_desc_wrapper diff_dst_d(diff_dst_md());
93
94 const bool ok = !is_fwd() && mayiuse(avx512_core) && !has_zero_dim_memory()
95 && utils::everyone_is(d_type, src_d.data_type(),
96 diff_src_d.data_type(), diff_dst_d.data_type())
97 && IMPLICATION(d_type == f16, mayiuse(avx512_core_fp16))
98 && src_d.ndims() == 4 && attr()->has_default_values()
99 && set_default_formats_common() && src_d == diff_dst_d
100 && diff_dst_d == diff_src_d;
101 if (!ok) return unimplemented;
102
103 const dims_t ws_dims = {MB(), C(), H(), 2 * W()};
104 const auto fmt_tag
105 = src_d.matches_one_of_tag(format_tag::nhwc, format_tag::nChw16c);
106 memory_desc_init_by_tag(ws_md_, 4, ws_dims, d_type, fmt_tag);
107 if (!compare_ws(hint_fwd_pd_)) return unimplemented;
108
109 const bool args_ok_across = true && desc()->alg_kind == lrn_across_channels
110 && desc()->local_size >= 1 && desc()->local_size <= 16
111 && (desc()->lrn_beta == 0.75 || desc()->lrn_beta == 1.0)
112 && src_d.matches_tag(fmt_tag)
113 && IMPLICATION(fmt_tag == format_tag::nChw16c,
114 src_d.dims()[1] % vsize == 0 && desc()->local_size == 5);
115
116 return args_ok_across ? success : unimplemented;
117}
118
119template <data_type_t d_type>
120jit_avx512_common_lrn_bwd_t<d_type>::jit_avx512_common_lrn_bwd_t(
121 const pd_t *apd)
122 : primitive_t(apd)
123 , lrn_executor_(lrn::lrn_executor_factory_t::create_executor<d_type,
124 typename jit_avx512_common_lrn_bwd_t<d_type>::pd_t>(
125 pd(), lrn::direction::backward)) {}
126
127template <data_type_t d_type>
128jit_avx512_common_lrn_bwd_t<d_type>::~jit_avx512_common_lrn_bwd_t() = default;
129
130template struct jit_avx512_common_lrn_bwd_t<f32>;
131template struct jit_avx512_common_lrn_bwd_t<bf16>;
132template struct jit_avx512_common_lrn_bwd_t<f16>;
133
134} // namespace x64
135} // namespace cpu
136} // namespace impl
137} // namespace dnnl
138