1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | |
22 | #include "cpu/x64/lrn/jit_avx512_common_lrn.hpp" |
23 | #include "cpu/x64/lrn/lrn_executor_factory.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
// Channel-block width used by both init() checks below: an nChw16c source must
// have a channel count that is a multiple of this value. (Presumably the number
// of fp32 lanes in a 512-bit register — TODO confirm.)
constexpr int vsize = 16;
31 | |
32 | using namespace dnnl::impl::status; |
33 | using namespace dnnl::impl::utils; |
34 | using namespace data_type; |
35 | |
36 | template <data_type_t d_type> |
37 | status_t jit_avx512_common_lrn_fwd_t<d_type>::pd_t::init(engine_t *engine) { |
38 | using namespace prop_kind; |
39 | using namespace alg_kind; |
40 | |
41 | const memory_desc_wrapper src_d(src_md()); |
42 | const memory_desc_wrapper dst_d(dst_md()); |
43 | |
44 | const bool ok = is_fwd() && mayiuse(avx512_core) && !has_zero_dim_memory() |
45 | && everyone_is(d_type, src_d.data_type(), dst_d.data_type()) |
46 | && IMPLICATION(d_type == f16, mayiuse(avx512_core_fp16)) |
47 | && src_d.ndims() == 4 && attr()->has_default_values() |
48 | && set_default_formats_common() && src_d == dst_d; |
49 | if (!ok) return unimplemented; |
50 | |
51 | const auto fmt_tag |
52 | = src_d.matches_one_of_tag(format_tag::nhwc, format_tag::nChw16c); |
53 | |
54 | const bool args_ok_across = desc()->alg_kind == lrn_across_channels |
55 | && desc()->local_size >= 1 && desc()->local_size <= 16 |
56 | && (desc()->lrn_beta == 0.75 || desc()->lrn_beta == 1.0) |
57 | && src_d.matches_tag(fmt_tag) |
58 | && IMPLICATION(fmt_tag == format_tag::nChw16c, |
59 | src_d.dims()[1] % vsize == 0 && desc()->local_size == 5); |
60 | |
61 | if (!args_ok_across) return unimplemented; |
62 | |
63 | if (desc()->prop_kind == forward_training) { |
64 | dims_t ws_dims = {MB(), C(), H(), 2 * W()}; |
65 | memory_desc_init_by_tag(ws_md_, 4, ws_dims, d_type, fmt_tag); |
66 | } |
67 | |
68 | return success; |
69 | } |
70 | |
71 | template <data_type_t d_type> |
72 | jit_avx512_common_lrn_fwd_t<d_type>::jit_avx512_common_lrn_fwd_t( |
73 | const pd_t *apd) |
74 | : primitive_t(apd) |
75 | , lrn_executor_(lrn::lrn_executor_factory_t::create_executor<d_type, |
76 | typename jit_avx512_common_lrn_fwd_t<d_type>::pd_t>( |
77 | pd(), lrn::direction::forward)) {} |
78 | |
79 | template <data_type_t d_type> |
80 | jit_avx512_common_lrn_fwd_t<d_type>::~jit_avx512_common_lrn_fwd_t() = default; |
81 | |
82 | template struct jit_avx512_common_lrn_fwd_t<f32>; |
83 | template struct jit_avx512_common_lrn_fwd_t<bf16>; |
84 | template struct jit_avx512_common_lrn_fwd_t<f16>; |
85 | |
86 | template <data_type_t d_type> |
87 | status_t jit_avx512_common_lrn_bwd_t<d_type>::pd_t::init(engine_t *engine) { |
88 | using namespace alg_kind; |
89 | |
90 | const memory_desc_wrapper src_d(src_md()); |
91 | const memory_desc_wrapper diff_src_d(diff_src_md()); |
92 | const memory_desc_wrapper diff_dst_d(diff_dst_md()); |
93 | |
94 | const bool ok = !is_fwd() && mayiuse(avx512_core) && !has_zero_dim_memory() |
95 | && utils::everyone_is(d_type, src_d.data_type(), |
96 | diff_src_d.data_type(), diff_dst_d.data_type()) |
97 | && IMPLICATION(d_type == f16, mayiuse(avx512_core_fp16)) |
98 | && src_d.ndims() == 4 && attr()->has_default_values() |
99 | && set_default_formats_common() && src_d == diff_dst_d |
100 | && diff_dst_d == diff_src_d; |
101 | if (!ok) return unimplemented; |
102 | |
103 | const dims_t ws_dims = {MB(), C(), H(), 2 * W()}; |
104 | const auto fmt_tag |
105 | = src_d.matches_one_of_tag(format_tag::nhwc, format_tag::nChw16c); |
106 | memory_desc_init_by_tag(ws_md_, 4, ws_dims, d_type, fmt_tag); |
107 | if (!compare_ws(hint_fwd_pd_)) return unimplemented; |
108 | |
109 | const bool args_ok_across = true && desc()->alg_kind == lrn_across_channels |
110 | && desc()->local_size >= 1 && desc()->local_size <= 16 |
111 | && (desc()->lrn_beta == 0.75 || desc()->lrn_beta == 1.0) |
112 | && src_d.matches_tag(fmt_tag) |
113 | && IMPLICATION(fmt_tag == format_tag::nChw16c, |
114 | src_d.dims()[1] % vsize == 0 && desc()->local_size == 5); |
115 | |
116 | return args_ok_across ? success : unimplemented; |
117 | } |
118 | |
119 | template <data_type_t d_type> |
120 | jit_avx512_common_lrn_bwd_t<d_type>::jit_avx512_common_lrn_bwd_t( |
121 | const pd_t *apd) |
122 | : primitive_t(apd) |
123 | , lrn_executor_(lrn::lrn_executor_factory_t::create_executor<d_type, |
124 | typename jit_avx512_common_lrn_bwd_t<d_type>::pd_t>( |
125 | pd(), lrn::direction::backward)) {} |
126 | |
127 | template <data_type_t d_type> |
128 | jit_avx512_common_lrn_bwd_t<d_type>::~jit_avx512_common_lrn_bwd_t() = default; |
129 | |
130 | template struct jit_avx512_common_lrn_bwd_t<f32>; |
131 | template struct jit_avx512_common_lrn_bwd_t<bf16>; |
132 | template struct jit_avx512_common_lrn_bwd_t<f16>; |
133 | |
134 | } // namespace x64 |
135 | } // namespace cpu |
136 | } // namespace impl |
137 | } // namespace dnnl |
138 | |