1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_RNN_CPU_RNN_PD_HPP
18#define CPU_RNN_CPU_RNN_PD_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/nstl.hpp"
22#include "common/rnn_pd.hpp"
23#include "common/type_helpers.hpp"
24#include "common/utils.hpp"
25
26#include "cpu/cpu_engine.hpp"
27
28#include "cpu/rnn/rnn_utils.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace cpu {
33
34struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t {
35 using rnn_fwd_pd_t::rnn_fwd_pd_t;
36
37protected:
38 status_t set_default_params() {
39 using namespace format_tag;
40 if (src_layer_md_.format_kind == format_kind::any)
41 CHECK(memory_desc_init_by_tag(src_layer_md_, tnc));
42 if (dst_layer_md_.format_kind == format_kind::any)
43 CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc));
44
45 if (is_augru()) {
46 if (augru_attention_md().format_kind == format_kind::any)
47 CHECK(memory_desc_init_by_tag(augru_attention_md(), tnc));
48 }
49
50 // Optional parameters
51 if (with_src_iter() && src_iter_md_.format_kind == format_kind::any)
52 CHECK(memory_desc_init_by_tag(src_iter_md_, ldnc));
53 if (with_src_iter_c() && src_iter_c_md_.format_kind == format_kind::any)
54 CHECK(memory_desc_init_by_tag(src_iter_c_md_, ldnc));
55 if (is_lstm_peephole()
56 && weights_peephole_md_.format_kind == format_kind::any)
57 CHECK(memory_desc_init_by_tag(weights_peephole_md_, ldgo));
58 if (with_bias() && bias_md_.format_kind == format_kind::any)
59 CHECK(memory_desc_init_by_tag(bias_md_, ldgo));
60 if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any)
61 CHECK(memory_desc_init_by_tag(dst_iter_md_, ldnc));
62 if (with_dst_iter_c() && dst_iter_c_md_.format_kind == format_kind::any)
63 CHECK(memory_desc_init_by_tag(dst_iter_c_md_, ldnc));
64
65 return status::success;
66 }
67
68 status_t check_layout_consistency(bool is_brgemm) {
69 using namespace format_tag;
70 using namespace data_type;
71 using namespace types;
72
73 const auto is_blocked = [&](const memory_desc_t &md, int ndims,
74 bool require_last_dim_contiguous) {
75 return md.format_kind == format_kind::blocked && md.ndims == ndims
76 && IMPLICATION(require_last_dim_contiguous,
77 md.format_desc.blocking.strides[md.ndims - 1] == 1);
78 };
79
80 bool ok = true;
81 ok = ok && is_blocked(src_layer_md_, 3, true)
82 && is_blocked(dst_layer_md_, 3, true);
83 ok = ok
84 && IMPLICATION(!is_zero_md(&src_iter_md_),
85 is_blocked(src_iter_md_, 4, true))
86 && IMPLICATION(!is_zero_md(&src_iter_c_md_),
87 is_blocked(src_iter_c_md_, 4, true))
88 && IMPLICATION(!is_zero_md(&dst_iter_md_),
89 is_blocked(dst_iter_md_, 4, true))
90 && IMPLICATION(!is_zero_md(&dst_iter_c_md_),
91 is_blocked(dst_iter_c_md_, 4, true));
92
93 if (weights_layer_md_.format_kind == format_kind::rnn_packed)
94 ok = ok
95 && (weights_layer_md_.format_desc.rnn_packed_desc.format
96 == rnn_packed_memory_format_t::ldigo_p);
97 else
98 ok = ok
99 && (rnn_utils::is_ldigo(&weights_layer_md_)
100 || rnn_utils::is_ldigo_blocked(&weights_layer_md_));
101
102 if (weights_iter_md_.format_kind == format_kind::rnn_packed)
103 ok = ok
104 && (weights_iter_md_.format_desc.rnn_packed_desc.format
105 == rnn_packed_memory_format_t::ldigo_p);
106 else
107 ok = ok
108 && (rnn_utils::is_ldigo(&weights_iter_md_)
109 || rnn_utils::is_ldigo_blocked(&weights_iter_md_));
110
111 ok = ok
112 && IMPLICATION(is_lstm_peephole(),
113 memory_desc_matches_tag(weights_peephole_md_, ldgo));
114
115 if (is_lstm_projection()) {
116 if (weights_projection_md_.format_kind == format_kind::rnn_packed)
117 ok = ok
118 && (weights_projection_md_.format_desc.rnn_packed_desc
119 .format
120 == rnn_packed_memory_format_t::ldio_p);
121 else
122 ok = ok
123 && (rnn_utils::is_ldio(&weights_projection_md_)
124 || rnn_utils::is_ldio_blocked(
125 &weights_projection_md_));
126 }
127
128 ok = ok
129 && IMPLICATION(
130 with_bias(), memory_desc_matches_tag(bias_md_, ldgo));
131
132 /* Int8 is supported only for packed weights, if not BRGEMM version */
133 const data_type_t weights_iter_dt = weights_iter_md_.data_type;
134 const data_type_t weights_layer_dt = weights_layer_md_.data_type;
135 if (!rnn_utils::is_ldigo_blocked(&weights_iter_md_))
136 ok = ok
137 && IMPLICATION(weights_iter_dt == s8,
138 weights_iter_md_.format_kind
139 == format_kind::rnn_packed);
140 if (!rnn_utils::is_ldigo_blocked(&weights_layer_md_))
141 ok = ok
142 && IMPLICATION(weights_layer_dt == s8,
143 weights_layer_md_.format_kind
144 == format_kind::rnn_packed);
145 return ok ? status::success : status::unimplemented;
146 }
147};
148
149struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t {
150 using rnn_bwd_pd_t::rnn_bwd_pd_t;
151
152protected:
153 status_t set_default_params() {
154 using namespace format_tag;
155 if (src_layer_md_.format_kind == format_kind::any)
156 CHECK(memory_desc_init_by_tag(src_layer_md_, tnc));
157 if (dst_layer_md_.format_kind == format_kind::any)
158 CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc));
159
160 if (is_augru()) {
161 if (augru_attention_md().format_kind == format_kind::any)
162 CHECK(memory_desc_init_by_tag(augru_attention_md(), tnc));
163 if (diff_augru_attention_md().format_kind == format_kind::any)
164 CHECK(memory_desc_init_by_tag(diff_augru_attention_md(), tnc));
165 }
166
167 if (diff_src_layer_md_.format_kind == format_kind::any)
168 CHECK(memory_desc_init_by_tag(diff_src_layer_md_, tnc));
169 if (diff_weights_layer_md_.format_kind == format_kind::any) {
170 CHECK(memory_desc_init_by_tag(diff_weights_layer_md_, ldigo));
171 CHECK(rnn_utils::set_good_strides(diff_weights_layer_md_, ldigo));
172 }
173 if (diff_weights_iter_md_.format_kind == format_kind::any) {
174 CHECK(memory_desc_init_by_tag(diff_weights_iter_md_, ldigo));
175 CHECK(rnn_utils::set_good_strides(diff_weights_iter_md_, ldigo));
176 }
177 if (diff_dst_layer_md_.format_kind == format_kind::any)
178 CHECK(memory_desc_init_by_tag(diff_dst_layer_md_, tnc));
179
180 // Optional parameters
181 if (with_src_iter() && src_iter_md_.format_kind == format_kind::any)
182 CHECK(memory_desc_init_by_tag(src_iter_md_, ldnc));
183 if (with_src_iter_c() && src_iter_c_md_.format_kind == format_kind::any)
184 CHECK(memory_desc_init_by_tag(src_iter_c_md_, ldnc));
185 if (is_lstm_peephole()
186 && weights_peephole_md_.format_kind == format_kind::any)
187 CHECK(memory_desc_init_by_tag(weights_peephole_md_, ldgo));
188 if (is_lstm_projection()
189 && weights_projection_md_.format_kind == format_kind::any)
190 CHECK(memory_desc_init_by_tag(weights_projection_md_, ldoi));
191 if (with_bias() && bias_md_.format_kind == format_kind::any)
192 CHECK(memory_desc_init_by_tag(bias_md_, ldgo));
193 if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any)
194 CHECK(memory_desc_init_by_tag(dst_iter_md_, ldnc));
195 if (with_dst_iter_c() && dst_iter_c_md_.format_kind == format_kind::any)
196 CHECK(memory_desc_init_by_tag(dst_iter_c_md_, ldnc));
197
198 if (with_src_iter()
199 && diff_src_iter_md_.format_kind == format_kind::any)
200 CHECK(memory_desc_init_by_tag(diff_src_iter_md_, ldnc));
201 if (with_src_iter_c()
202 && diff_src_iter_c_md_.format_kind == format_kind::any)
203 CHECK(memory_desc_init_by_tag(diff_src_iter_c_md_, ldnc));
204 if (is_lstm_peephole()
205 && diff_weights_peephole_md_.format_kind == format_kind::any)
206 CHECK(memory_desc_init_by_tag(diff_weights_peephole_md_, ldgo));
207 if (is_lstm_projection()
208 && diff_weights_projection_md_.format_kind == format_kind::any)
209 CHECK(memory_desc_init_by_tag(diff_weights_projection_md_, ldio));
210 if (with_bias() && diff_bias_md_.format_kind == format_kind::any)
211 CHECK(memory_desc_init_by_tag(diff_bias_md_, ldgo));
212 if (with_dst_iter()
213 && diff_dst_iter_md_.format_kind == format_kind::any)
214 CHECK(memory_desc_init_by_tag(diff_dst_iter_md_, ldnc));
215 if (with_dst_iter_c()
216 && diff_dst_iter_c_md_.format_kind == format_kind::any)
217 CHECK(memory_desc_init_by_tag(diff_dst_iter_c_md_, ldnc));
218
219 return status::success;
220 }
221
222 status_t check_layout_consistency(bool is_brgemm) {
223 using namespace format_tag;
224 using namespace types;
225
226 const auto is_blocked = [&](const memory_desc_t &md, int ndims,
227 bool require_last_dim_contiguous) {
228 return md.format_kind == format_kind::blocked && md.ndims == ndims
229 && IMPLICATION(require_last_dim_contiguous,
230 md.format_desc.blocking.strides[md.ndims - 1] == 1);
231 };
232
233 bool ok = true;
234 ok = ok && is_blocked(src_layer_md_, 3, true)
235 && is_blocked(dst_layer_md_, 3, true);
236 ok = ok
237 && IMPLICATION(!is_zero_md(&src_iter_md_),
238 is_blocked(src_iter_md_, 4, true))
239 && IMPLICATION(!is_zero_md(&src_iter_c_md_),
240 is_blocked(src_iter_c_md_, 4, true))
241 && IMPLICATION(!is_zero_md(&dst_iter_md_),
242 is_blocked(dst_iter_md_, 4, true))
243 && IMPLICATION(!is_zero_md(&dst_iter_c_md_),
244 is_blocked(dst_iter_c_md_, 4, true));
245
246 const auto check_weights_consistency =
247 [&](const memory_desc_t &weights_md) {
248 if (weights_md.format_kind == format_kind::rnn_packed)
249 return ok
250 && weights_md.format_desc.rnn_packed_desc.format
251 == rnn_packed_memory_format_t::ldgoi_p;
252 else if (is_brgemm)
253 return ok && rnn_utils::is_ldgoi_blocked(&weights_md);
254 else
255 return ok && rnn_utils::is_ldgoi(&weights_md);
256 };
257
258 ok = check_weights_consistency(weights_layer_md_);
259 ok = check_weights_consistency(weights_iter_md_);
260
261 ok = ok
262 && IMPLICATION(is_augru(),
263 memory_desc_matches_tag(augru_attention_md(), tnc));
264 ok = ok
265 && IMPLICATION(is_lstm_peephole(),
266 memory_desc_matches_tag(weights_peephole_md_, ldgo));
267 ok = ok
268 && IMPLICATION(is_lstm_projection(),
269 memory_desc_matches_tag(weights_projection_md_, ldoi));
270 ok = ok
271 && IMPLICATION(
272 with_bias(), memory_desc_matches_tag(bias_md_, ldgo));
273
274 ok = ok && is_blocked(diff_src_layer_md_, 3, true)
275 && is_blocked(diff_dst_layer_md_, 3, true);
276 ok = ok
277 && IMPLICATION(!is_zero_md(&diff_src_iter_md_),
278 is_blocked(diff_src_iter_md_, 4, true))
279 && IMPLICATION(!is_zero_md(&diff_src_iter_c_md_),
280 is_blocked(diff_src_iter_c_md_, 4, true))
281 && IMPLICATION(!is_zero_md(&diff_dst_iter_md_),
282 is_blocked(diff_dst_iter_md_, 4, true))
283 && IMPLICATION(!is_zero_md(&diff_dst_iter_c_md_),
284 is_blocked(diff_dst_iter_c_md_, 4, true));
285
286 ok = ok
287 && IMPLICATION(is_augru(),
288 memory_desc_matches_tag(
289 diff_augru_attention_md(), tnc));
290 ok = ok && rnn_utils::is_ldigo(&diff_weights_layer_md_)
291 && rnn_utils::is_ldigo(&diff_weights_iter_md_);
292 ok = ok
293 && IMPLICATION(is_lstm_peephole()
294 && !is_zero_md(&diff_weights_peephole_md_),
295 memory_desc_matches_tag(
296 diff_weights_peephole_md_, ldgo));
297 ok = ok
298 && IMPLICATION(!is_zero_md(&diff_weights_projection_md_),
299 memory_desc_matches_tag(
300 diff_weights_projection_md_, ldio));
301 ok = ok
302 && IMPLICATION(!is_zero_md(&diff_bias_md_),
303 memory_desc_matches_tag(diff_bias_md_, ldgo));
304
305 return ok ? status::success : status::unimplemented;
306 }
307};
308
309} // namespace cpu
310} // namespace impl
311} // namespace dnnl
312
313#endif
314