1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_LRN_PD_HPP |
18 | #define COMMON_LRN_PD_HPP |
19 | |
20 | #include "oneapi/dnnl/dnnl.h" |
21 | |
22 | #include "c_types_map.hpp" |
23 | #include "primitive_desc.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | |
28 | struct lrn_fwd_pd_t; |
29 | |
30 | struct lrn_pd_t : public primitive_desc_t { |
31 | static constexpr auto base_pkind = primitive_kind::lrn; |
32 | |
33 | const lrn_desc_t *desc() const { return &desc_; } |
34 | const op_desc_t *op_desc() const override { |
35 | return reinterpret_cast<const op_desc_t *>(this->desc()); |
36 | } |
37 | |
38 | status_t query(query_t what, int idx, void *result) const override { |
39 | switch (what) { |
40 | case query::prop_kind: |
41 | *(prop_kind_t *)result = desc()->prop_kind; |
42 | break; |
43 | case query::alg_kind: |
44 | *(alg_kind_t *)result = desc()->alg_kind; |
45 | break; |
46 | case query::alpha_f32: *(float *)result = desc()->lrn_alpha; break; |
47 | case query::beta_f32: *(float *)result = desc()->lrn_beta; break; |
48 | case query::local_size_s64: |
49 | *(dim_t *)result = desc()->local_size; |
50 | break; |
51 | case query::k_f32: *(float *)result = desc()->lrn_k; break; |
52 | default: return primitive_desc_t::query(what, idx, result); |
53 | } |
54 | return status::success; |
55 | } |
56 | |
57 | /* common lrn aux functions */ |
58 | |
59 | dim_t MB() const { return src_md()->dims[0]; } |
60 | dim_t C() const { return src_md()->dims[1]; } |
61 | dim_t D() const { return ndims() >= 5 ? src_md()->dims[ndims() - 3] : 1; } |
62 | dim_t H() const { return ndims() >= 4 ? src_md()->dims[ndims() - 2] : 1; } |
63 | dim_t W() const { return ndims() >= 3 ? src_md()->dims[ndims() - 1] : 1; } |
64 | |
65 | int ndims() const { return src_md()->ndims; } |
66 | |
67 | bool has_zero_dim_memory() const { |
68 | return memory_desc_wrapper(desc_.src_desc).has_zero_dim(); |
69 | } |
70 | |
71 | bool is_fwd() const { |
72 | return utils::one_of(desc_.prop_kind, prop_kind::forward_training, |
73 | prop_kind::forward_inference); |
74 | } |
75 | |
76 | protected: |
77 | lrn_desc_t desc_; |
78 | const lrn_fwd_pd_t *hint_fwd_pd_; |
79 | |
80 | memory_desc_t src_md_; |
81 | memory_desc_t ws_md_; |
82 | |
83 | lrn_pd_t(const lrn_desc_t *adesc, const primitive_attr_t *attr, |
84 | const lrn_fwd_pd_t *hint_fwd_pd) |
85 | : primitive_desc_t(attr, base_pkind) |
86 | , desc_(*adesc) |
87 | , hint_fwd_pd_(hint_fwd_pd) |
88 | , src_md_(desc_.src_desc) |
89 | , ws_md_() {} |
90 | }; |
91 | |
92 | struct lrn_fwd_pd_t : public lrn_pd_t { |
93 | typedef lrn_fwd_pd_t base_class; |
94 | typedef lrn_fwd_pd_t hint_class; |
95 | |
96 | arg_usage_t arg_usage(int arg) const override { |
97 | if (arg == DNNL_ARG_SRC) return arg_usage_t::input; |
98 | |
99 | if (arg == DNNL_ARG_DST) return arg_usage_t::output; |
100 | |
101 | if (arg == DNNL_ARG_WORKSPACE && (!types::is_zero_md(workspace_md()))) |
102 | return arg_usage_t::output; |
103 | |
104 | return primitive_desc_t::arg_usage(arg); |
105 | } |
106 | |
107 | const memory_desc_t *arg_md(int arg) const override { |
108 | switch (arg) { |
109 | case DNNL_ARG_SRC: return src_md(0); |
110 | case DNNL_ARG_DST: return dst_md(0); |
111 | default: return lrn_pd_t::arg_md(arg); |
112 | } |
113 | } |
114 | |
115 | const memory_desc_t *src_md(int index = 0) const override { |
116 | return index == 0 ? &src_md_ : &glob_zero_md; |
117 | } |
118 | const memory_desc_t *dst_md(int index = 0) const override { |
119 | return index == 0 ? &dst_md_ : &glob_zero_md; |
120 | } |
121 | const memory_desc_t *workspace_md(int index = 0) const override { |
122 | return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ |
123 | : &glob_zero_md; |
124 | } |
125 | |
126 | int n_inputs() const override { return 1; } |
127 | int n_outputs() const override { |
128 | return 1 + (!types::is_zero_md(workspace_md())); |
129 | } |
130 | |
131 | protected: |
132 | memory_desc_t dst_md_; |
133 | |
134 | lrn_fwd_pd_t(const lrn_desc_t *adesc, const primitive_attr_t *attr, |
135 | const lrn_fwd_pd_t *hint_fwd_pd) |
136 | : lrn_pd_t(adesc, attr, hint_fwd_pd), dst_md_(desc_.dst_desc) {} |
137 | |
138 | bool set_default_formats_common() { |
139 | return IMPLICATION(dst_md_.format_kind == format_kind::any, |
140 | memory_desc_init_by_md_and_dt( |
141 | dst_md_, src_md_, dst_md_.data_type) |
142 | == status::success); |
143 | } |
144 | }; |
145 | |
146 | struct lrn_bwd_pd_t : public lrn_pd_t { |
147 | typedef lrn_bwd_pd_t base_class; |
148 | typedef lrn_fwd_pd_t hint_class; |
149 | |
150 | arg_usage_t arg_usage(int arg) const override { |
151 | if (utils::one_of(arg, DNNL_ARG_SRC, DNNL_ARG_DIFF_DST)) |
152 | return arg_usage_t::input; |
153 | |
154 | if (arg == DNNL_ARG_DIFF_SRC) return arg_usage_t::output; |
155 | |
156 | if (arg == DNNL_ARG_WORKSPACE && (!types::is_zero_md(workspace_md()))) |
157 | return arg_usage_t::input; |
158 | |
159 | return primitive_desc_t::arg_usage(arg); |
160 | } |
161 | |
162 | const memory_desc_t *arg_md(int arg) const override { |
163 | switch (arg) { |
164 | case DNNL_ARG_SRC: return src_md(0); |
165 | case DNNL_ARG_DIFF_DST: return diff_dst_md(0); |
166 | case DNNL_ARG_DIFF_SRC: return diff_src_md(0); |
167 | default: return lrn_pd_t::arg_md(arg); |
168 | } |
169 | } |
170 | |
171 | const memory_desc_t *src_md(int index = 0) const override { |
172 | return index == 0 ? &src_md_ : &glob_zero_md; |
173 | } |
174 | const memory_desc_t *diff_dst_md(int index = 0) const override { |
175 | return index == 0 ? &diff_dst_md_ : &glob_zero_md; |
176 | } |
177 | const memory_desc_t *diff_src_md(int index = 0) const override { |
178 | return index == 0 ? &diff_src_md_ : &glob_zero_md; |
179 | } |
180 | const memory_desc_t *workspace_md(int index = 0) const override { |
181 | return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ |
182 | : &glob_zero_md; |
183 | } |
184 | |
185 | int n_inputs() const override { |
186 | return 2 + (!types::is_zero_md(workspace_md())); |
187 | } |
188 | int n_outputs() const override { return 1; } |
189 | |
190 | protected: |
191 | memory_desc_t diff_src_md_; |
192 | memory_desc_t diff_dst_md_; |
193 | |
194 | lrn_bwd_pd_t(const lrn_desc_t *adesc, const primitive_attr_t *attr, |
195 | const lrn_fwd_pd_t *hint_fwd_pd) |
196 | : lrn_pd_t(adesc, attr, hint_fwd_pd) |
197 | , diff_src_md_(desc_.diff_src_desc) |
198 | , diff_dst_md_(desc_.diff_dst_desc) {} |
199 | |
200 | bool set_default_formats_common() { |
201 | return IMPLICATION(diff_dst_md_.format_kind == format_kind::any, |
202 | memory_desc_init_by_md_and_dt( |
203 | diff_dst_md_, src_md_, diff_dst_md_.data_type) |
204 | == status::success) |
205 | && IMPLICATION(diff_src_md_.format_kind == format_kind::any, |
206 | memory_desc_init_by_md_and_dt( |
207 | diff_src_md_, src_md_, diff_src_md_.data_type) |
208 | == status::success); |
209 | } |
210 | }; |
211 | |
212 | } // namespace impl |
213 | } // namespace dnnl |
214 | |
215 | #endif |
216 | |
217 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
218 | |