1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_RNN_CPU_RNN_PD_HPP |
18 | #define CPU_RNN_CPU_RNN_PD_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/nstl.hpp" |
22 | #include "common/rnn_pd.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | #include "common/utils.hpp" |
25 | |
26 | #include "cpu/cpu_engine.hpp" |
27 | |
28 | #include "cpu/rnn/rnn_utils.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | |
34 | struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t { |
35 | using rnn_fwd_pd_t::rnn_fwd_pd_t; |
36 | |
37 | protected: |
38 | status_t set_default_params() { |
39 | using namespace format_tag; |
40 | if (src_layer_md_.format_kind == format_kind::any) |
41 | CHECK(memory_desc_init_by_tag(src_layer_md_, tnc)); |
42 | if (dst_layer_md_.format_kind == format_kind::any) |
43 | CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc)); |
44 | |
45 | if (is_augru()) { |
46 | if (augru_attention_md().format_kind == format_kind::any) |
47 | CHECK(memory_desc_init_by_tag(augru_attention_md(), tnc)); |
48 | } |
49 | |
50 | // Optional parameters |
51 | if (with_src_iter() && src_iter_md_.format_kind == format_kind::any) |
52 | CHECK(memory_desc_init_by_tag(src_iter_md_, ldnc)); |
53 | if (with_src_iter_c() && src_iter_c_md_.format_kind == format_kind::any) |
54 | CHECK(memory_desc_init_by_tag(src_iter_c_md_, ldnc)); |
55 | if (is_lstm_peephole() |
56 | && weights_peephole_md_.format_kind == format_kind::any) |
57 | CHECK(memory_desc_init_by_tag(weights_peephole_md_, ldgo)); |
58 | if (with_bias() && bias_md_.format_kind == format_kind::any) |
59 | CHECK(memory_desc_init_by_tag(bias_md_, ldgo)); |
60 | if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any) |
61 | CHECK(memory_desc_init_by_tag(dst_iter_md_, ldnc)); |
62 | if (with_dst_iter_c() && dst_iter_c_md_.format_kind == format_kind::any) |
63 | CHECK(memory_desc_init_by_tag(dst_iter_c_md_, ldnc)); |
64 | |
65 | return status::success; |
66 | } |
67 | |
68 | status_t check_layout_consistency(bool is_brgemm) { |
69 | using namespace format_tag; |
70 | using namespace data_type; |
71 | using namespace types; |
72 | |
73 | const auto is_blocked = [&](const memory_desc_t &md, int ndims, |
74 | bool require_last_dim_contiguous) { |
75 | return md.format_kind == format_kind::blocked && md.ndims == ndims |
76 | && IMPLICATION(require_last_dim_contiguous, |
77 | md.format_desc.blocking.strides[md.ndims - 1] == 1); |
78 | }; |
79 | |
80 | bool ok = true; |
81 | ok = ok && is_blocked(src_layer_md_, 3, true) |
82 | && is_blocked(dst_layer_md_, 3, true); |
83 | ok = ok |
84 | && IMPLICATION(!is_zero_md(&src_iter_md_), |
85 | is_blocked(src_iter_md_, 4, true)) |
86 | && IMPLICATION(!is_zero_md(&src_iter_c_md_), |
87 | is_blocked(src_iter_c_md_, 4, true)) |
88 | && IMPLICATION(!is_zero_md(&dst_iter_md_), |
89 | is_blocked(dst_iter_md_, 4, true)) |
90 | && IMPLICATION(!is_zero_md(&dst_iter_c_md_), |
91 | is_blocked(dst_iter_c_md_, 4, true)); |
92 | |
93 | if (weights_layer_md_.format_kind == format_kind::rnn_packed) |
94 | ok = ok |
95 | && (weights_layer_md_.format_desc.rnn_packed_desc.format |
96 | == rnn_packed_memory_format_t::ldigo_p); |
97 | else |
98 | ok = ok |
99 | && (rnn_utils::is_ldigo(&weights_layer_md_) |
100 | || rnn_utils::is_ldigo_blocked(&weights_layer_md_)); |
101 | |
102 | if (weights_iter_md_.format_kind == format_kind::rnn_packed) |
103 | ok = ok |
104 | && (weights_iter_md_.format_desc.rnn_packed_desc.format |
105 | == rnn_packed_memory_format_t::ldigo_p); |
106 | else |
107 | ok = ok |
108 | && (rnn_utils::is_ldigo(&weights_iter_md_) |
109 | || rnn_utils::is_ldigo_blocked(&weights_iter_md_)); |
110 | |
111 | ok = ok |
112 | && IMPLICATION(is_lstm_peephole(), |
113 | memory_desc_matches_tag(weights_peephole_md_, ldgo)); |
114 | |
115 | if (is_lstm_projection()) { |
116 | if (weights_projection_md_.format_kind == format_kind::rnn_packed) |
117 | ok = ok |
118 | && (weights_projection_md_.format_desc.rnn_packed_desc |
119 | .format |
120 | == rnn_packed_memory_format_t::ldio_p); |
121 | else |
122 | ok = ok |
123 | && (rnn_utils::is_ldio(&weights_projection_md_) |
124 | || rnn_utils::is_ldio_blocked( |
125 | &weights_projection_md_)); |
126 | } |
127 | |
128 | ok = ok |
129 | && IMPLICATION( |
130 | with_bias(), memory_desc_matches_tag(bias_md_, ldgo)); |
131 | |
132 | /* Int8 is supported only for packed weights, if not BRGEMM version */ |
133 | const data_type_t weights_iter_dt = weights_iter_md_.data_type; |
134 | const data_type_t weights_layer_dt = weights_layer_md_.data_type; |
135 | if (!rnn_utils::is_ldigo_blocked(&weights_iter_md_)) |
136 | ok = ok |
137 | && IMPLICATION(weights_iter_dt == s8, |
138 | weights_iter_md_.format_kind |
139 | == format_kind::rnn_packed); |
140 | if (!rnn_utils::is_ldigo_blocked(&weights_layer_md_)) |
141 | ok = ok |
142 | && IMPLICATION(weights_layer_dt == s8, |
143 | weights_layer_md_.format_kind |
144 | == format_kind::rnn_packed); |
145 | return ok ? status::success : status::unimplemented; |
146 | } |
147 | }; |
148 | |
149 | struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t { |
150 | using rnn_bwd_pd_t::rnn_bwd_pd_t; |
151 | |
152 | protected: |
153 | status_t set_default_params() { |
154 | using namespace format_tag; |
155 | if (src_layer_md_.format_kind == format_kind::any) |
156 | CHECK(memory_desc_init_by_tag(src_layer_md_, tnc)); |
157 | if (dst_layer_md_.format_kind == format_kind::any) |
158 | CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc)); |
159 | |
160 | if (is_augru()) { |
161 | if (augru_attention_md().format_kind == format_kind::any) |
162 | CHECK(memory_desc_init_by_tag(augru_attention_md(), tnc)); |
163 | if (diff_augru_attention_md().format_kind == format_kind::any) |
164 | CHECK(memory_desc_init_by_tag(diff_augru_attention_md(), tnc)); |
165 | } |
166 | |
167 | if (diff_src_layer_md_.format_kind == format_kind::any) |
168 | CHECK(memory_desc_init_by_tag(diff_src_layer_md_, tnc)); |
169 | if (diff_weights_layer_md_.format_kind == format_kind::any) { |
170 | CHECK(memory_desc_init_by_tag(diff_weights_layer_md_, ldigo)); |
171 | CHECK(rnn_utils::set_good_strides(diff_weights_layer_md_, ldigo)); |
172 | } |
173 | if (diff_weights_iter_md_.format_kind == format_kind::any) { |
174 | CHECK(memory_desc_init_by_tag(diff_weights_iter_md_, ldigo)); |
175 | CHECK(rnn_utils::set_good_strides(diff_weights_iter_md_, ldigo)); |
176 | } |
177 | if (diff_dst_layer_md_.format_kind == format_kind::any) |
178 | CHECK(memory_desc_init_by_tag(diff_dst_layer_md_, tnc)); |
179 | |
180 | // Optional parameters |
181 | if (with_src_iter() && src_iter_md_.format_kind == format_kind::any) |
182 | CHECK(memory_desc_init_by_tag(src_iter_md_, ldnc)); |
183 | if (with_src_iter_c() && src_iter_c_md_.format_kind == format_kind::any) |
184 | CHECK(memory_desc_init_by_tag(src_iter_c_md_, ldnc)); |
185 | if (is_lstm_peephole() |
186 | && weights_peephole_md_.format_kind == format_kind::any) |
187 | CHECK(memory_desc_init_by_tag(weights_peephole_md_, ldgo)); |
188 | if (is_lstm_projection() |
189 | && weights_projection_md_.format_kind == format_kind::any) |
190 | CHECK(memory_desc_init_by_tag(weights_projection_md_, ldoi)); |
191 | if (with_bias() && bias_md_.format_kind == format_kind::any) |
192 | CHECK(memory_desc_init_by_tag(bias_md_, ldgo)); |
193 | if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any) |
194 | CHECK(memory_desc_init_by_tag(dst_iter_md_, ldnc)); |
195 | if (with_dst_iter_c() && dst_iter_c_md_.format_kind == format_kind::any) |
196 | CHECK(memory_desc_init_by_tag(dst_iter_c_md_, ldnc)); |
197 | |
198 | if (with_src_iter() |
199 | && diff_src_iter_md_.format_kind == format_kind::any) |
200 | CHECK(memory_desc_init_by_tag(diff_src_iter_md_, ldnc)); |
201 | if (with_src_iter_c() |
202 | && diff_src_iter_c_md_.format_kind == format_kind::any) |
203 | CHECK(memory_desc_init_by_tag(diff_src_iter_c_md_, ldnc)); |
204 | if (is_lstm_peephole() |
205 | && diff_weights_peephole_md_.format_kind == format_kind::any) |
206 | CHECK(memory_desc_init_by_tag(diff_weights_peephole_md_, ldgo)); |
207 | if (is_lstm_projection() |
208 | && diff_weights_projection_md_.format_kind == format_kind::any) |
209 | CHECK(memory_desc_init_by_tag(diff_weights_projection_md_, ldio)); |
210 | if (with_bias() && diff_bias_md_.format_kind == format_kind::any) |
211 | CHECK(memory_desc_init_by_tag(diff_bias_md_, ldgo)); |
212 | if (with_dst_iter() |
213 | && diff_dst_iter_md_.format_kind == format_kind::any) |
214 | CHECK(memory_desc_init_by_tag(diff_dst_iter_md_, ldnc)); |
215 | if (with_dst_iter_c() |
216 | && diff_dst_iter_c_md_.format_kind == format_kind::any) |
217 | CHECK(memory_desc_init_by_tag(diff_dst_iter_c_md_, ldnc)); |
218 | |
219 | return status::success; |
220 | } |
221 | |
222 | status_t check_layout_consistency(bool is_brgemm) { |
223 | using namespace format_tag; |
224 | using namespace types; |
225 | |
226 | const auto is_blocked = [&](const memory_desc_t &md, int ndims, |
227 | bool require_last_dim_contiguous) { |
228 | return md.format_kind == format_kind::blocked && md.ndims == ndims |
229 | && IMPLICATION(require_last_dim_contiguous, |
230 | md.format_desc.blocking.strides[md.ndims - 1] == 1); |
231 | }; |
232 | |
233 | bool ok = true; |
234 | ok = ok && is_blocked(src_layer_md_, 3, true) |
235 | && is_blocked(dst_layer_md_, 3, true); |
236 | ok = ok |
237 | && IMPLICATION(!is_zero_md(&src_iter_md_), |
238 | is_blocked(src_iter_md_, 4, true)) |
239 | && IMPLICATION(!is_zero_md(&src_iter_c_md_), |
240 | is_blocked(src_iter_c_md_, 4, true)) |
241 | && IMPLICATION(!is_zero_md(&dst_iter_md_), |
242 | is_blocked(dst_iter_md_, 4, true)) |
243 | && IMPLICATION(!is_zero_md(&dst_iter_c_md_), |
244 | is_blocked(dst_iter_c_md_, 4, true)); |
245 | |
246 | const auto check_weights_consistency = |
247 | [&](const memory_desc_t &weights_md) { |
248 | if (weights_md.format_kind == format_kind::rnn_packed) |
249 | return ok |
250 | && weights_md.format_desc.rnn_packed_desc.format |
251 | == rnn_packed_memory_format_t::ldgoi_p; |
252 | else if (is_brgemm) |
253 | return ok && rnn_utils::is_ldgoi_blocked(&weights_md); |
254 | else |
255 | return ok && rnn_utils::is_ldgoi(&weights_md); |
256 | }; |
257 | |
258 | ok = check_weights_consistency(weights_layer_md_); |
259 | ok = check_weights_consistency(weights_iter_md_); |
260 | |
261 | ok = ok |
262 | && IMPLICATION(is_augru(), |
263 | memory_desc_matches_tag(augru_attention_md(), tnc)); |
264 | ok = ok |
265 | && IMPLICATION(is_lstm_peephole(), |
266 | memory_desc_matches_tag(weights_peephole_md_, ldgo)); |
267 | ok = ok |
268 | && IMPLICATION(is_lstm_projection(), |
269 | memory_desc_matches_tag(weights_projection_md_, ldoi)); |
270 | ok = ok |
271 | && IMPLICATION( |
272 | with_bias(), memory_desc_matches_tag(bias_md_, ldgo)); |
273 | |
274 | ok = ok && is_blocked(diff_src_layer_md_, 3, true) |
275 | && is_blocked(diff_dst_layer_md_, 3, true); |
276 | ok = ok |
277 | && IMPLICATION(!is_zero_md(&diff_src_iter_md_), |
278 | is_blocked(diff_src_iter_md_, 4, true)) |
279 | && IMPLICATION(!is_zero_md(&diff_src_iter_c_md_), |
280 | is_blocked(diff_src_iter_c_md_, 4, true)) |
281 | && IMPLICATION(!is_zero_md(&diff_dst_iter_md_), |
282 | is_blocked(diff_dst_iter_md_, 4, true)) |
283 | && IMPLICATION(!is_zero_md(&diff_dst_iter_c_md_), |
284 | is_blocked(diff_dst_iter_c_md_, 4, true)); |
285 | |
286 | ok = ok |
287 | && IMPLICATION(is_augru(), |
288 | memory_desc_matches_tag( |
289 | diff_augru_attention_md(), tnc)); |
290 | ok = ok && rnn_utils::is_ldigo(&diff_weights_layer_md_) |
291 | && rnn_utils::is_ldigo(&diff_weights_iter_md_); |
292 | ok = ok |
293 | && IMPLICATION(is_lstm_peephole() |
294 | && !is_zero_md(&diff_weights_peephole_md_), |
295 | memory_desc_matches_tag( |
296 | diff_weights_peephole_md_, ldgo)); |
297 | ok = ok |
298 | && IMPLICATION(!is_zero_md(&diff_weights_projection_md_), |
299 | memory_desc_matches_tag( |
300 | diff_weights_projection_md_, ldio)); |
301 | ok = ok |
302 | && IMPLICATION(!is_zero_md(&diff_bias_md_), |
303 | memory_desc_matches_tag(diff_bias_md_, ldgo)); |
304 | |
305 | return ok ? status::success : status::unimplemented; |
306 | } |
307 | }; |
308 | |
309 | } // namespace cpu |
310 | } // namespace impl |
311 | } // namespace dnnl |
312 | |
313 | #endif |
314 | |