1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_RNN_REF_RNN_HPP |
18 | #define CPU_RNN_REF_RNN_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <tuple> |
22 | |
23 | #include "common/c_types_map.hpp" |
24 | #include "common/memory_tracking.hpp" |
25 | #include "common/primitive.hpp" |
26 | #include "common/reorder.hpp" |
27 | #include "common/type_helpers.hpp" |
28 | #include "common/utils.hpp" |
29 | |
30 | #include "cpu/gemm/gemm.hpp" |
31 | #include "cpu/gemm/os_blas.hpp" |
32 | |
33 | #include "cpu/rnn/cpu_rnn_pd.hpp" |
34 | #include "cpu/rnn/postgemm_dispatcher.hpp" |
35 | #if DNNL_X64 |
36 | #include "cpu/x64/rnn/rnn_brgemm_utils.hpp" |
37 | #endif |
38 | #include "cpu/rnn/rnn_utils.hpp" |
39 | namespace dnnl { |
40 | namespace impl { |
41 | namespace cpu { |
42 | |
43 | namespace { |
44 | template <typename gates_t, typename acc_t> |
45 | // The loop body needs to be put in a function as some versions of icc have |
46 | // an issue with lambdas & macros inside omp simd loops |
47 | inline void body_loop(int i, int k, const gates_t *ws_gates, acc_t *diff_bias, |
48 | const rnn_utils::rnn_conf_t &rnn) { |
49 | for (int j = 0; j < rnn.mb; j++) |
50 | diff_bias[i * rnn.dhc + k] |
51 | += ws_gates[j * rnn.scratch_gates_ld + i * rnn.dhc + k]; |
52 | } |
53 | } // namespace |
54 | |
// Reduces the per-minibatch scratch gates into the bias gradient:
// diff_bias[i, k] += sum over mb of ws_gates[mb, i, k]. Parallelized over
// the (n_gates x dhc) iteration space; the inner minibatch reduction is
// sequential inside body_loop().
template <typename gates_t, typename acc_t>
void gates_reduction(const rnn_utils::rnn_conf_t &rnn, const gates_t *ws_gates_,
        acc_t *diff_bias_) {

    // @todo block k on simd-width to enable vectorization in
    // parallel_nd path
    // OpenMP >= 4.0 runtime (excluding icc >= 19.1, which miscompiles this
    // construct): use a collapsed parallel-for-simd directly. The two loops
    // must stay in this exact canonical form for collapse(2) to be valid.
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP && _OPENMP >= 201307 \
        && (!defined(__INTEL_COMPILER) || __INTEL_COMPILER < 1910)
#pragma omp parallel for simd collapse(2)
    for (int i = 0; i < rnn.n_gates; i++)
        for (int k = 0; k < rnn.dhc; k++)
            body_loop(i, k, ws_gates_, diff_bias_, rnn);
#else
    // Fallback: dnnl's generic 2D parallelizer (TBB/threadpool/older OMP).
    parallel_nd(rnn.n_gates, rnn.dhc, [&](dim_t i, dim_t k) {
        body_loop(i, k, ws_gates_, diff_bias_, rnn);
    });
#endif
}
73 | |
74 | template <prop_kind_t aprop, impl::data_type_t src_type, |
75 | impl::data_type_t weights_type, impl::data_type_t acc_type> |
76 | struct _ref_rnn_common_t : public primitive_t { |
77 | static constexpr impl::data_type_t scratch_type |
78 | = aprop == prop_kind::forward ? acc_type : src_type; |
79 | |
80 | /* These types are defined for each element in the cell execution */ |
81 | typedef typename prec_traits<src_type>::type src_layer_t; |
82 | typedef typename prec_traits<src_type>::type src_iter_t; |
83 | typedef typename prec_traits<src_type>::type dst_layer_t; |
84 | typedef typename prec_traits<src_type>::type dst_iter_t; |
85 | typedef typename prec_traits<weights_type>::type weights_t; |
86 | typedef typename prec_traits<src_type>::type gemm_data_t; |
87 | typedef typename prec_traits<acc_type>::type gemm_acc_t; |
88 | typedef typename prec_traits<scratch_type>::type scratch_t; |
89 | typedef typename prec_traits<src_type>::type ht_t; |
90 | typedef typename prec_traits<src_type>::type gates_t; |
91 | |
92 | using class_name |
93 | = _ref_rnn_common_t<aprop, src_type, weights_type, acc_type>; |
94 | #if DNNL_X64 |
95 | using ref_rnn_brgemm_t = x64::rnn_brgemm_utils::rnn_brgemm_t<aprop>; |
96 | #endif |
97 | |
98 | typedef rnn_cell_execution_sig((class_name::*cell_execution_f)); |
99 | typedef rnn_grid_execution_sig((class_name::*grid_execution_f)); |
100 | typedef rnn_merged_layer_execution_sig( |
101 | (class_name::*merged_layer_execution_f)); |
102 | |
103 | typedef rnn_gemm_sig((class_name::*gemm_t)); |
104 | typedef rnn_bias_prepare_sig((class_name::*bias_prepare_t)); |
105 | typedef rnn_bias_finalize_sig((class_name::*bias_finalize_t)); |
106 | typedef rnn_weights_assign_sig((class_name::*weights_assign_t)); |
107 | |
108 | using base_pd_t = |
109 | typename utils::conditional<false || aprop == prop_kind::forward, |
110 | cpu_rnn_fwd_pd_t, cpu_rnn_bwd_pd_t>::type; |
111 | |
112 | struct pd_t : public base_pd_t { |
113 | using base_pd_t::base_pd_t; |
114 | |
115 | const char *impl_name() const { |
116 | #if DNNL_X64 |
117 | using namespace dnnl::impl::cpu::x64; |
118 | return rnn_.is_brgemm |
119 | ? JIT_IMPL_NAME_HELPER("brgemm:" , rnn_.brgemm_isa, "" ) |
120 | : "ref" ; |
121 | #else |
122 | return "ref" ; |
123 | #endif |
124 | } |
125 | |
126 | DECLARE_COMMON_PD_T(impl_name(), class_name, USE_GLOBAL_SCRATCHPAD); |
127 | |
128 | status_t init_ref(engine_t *engine) { |
129 | using namespace prop_kind; |
130 | using namespace utils; |
131 | using namespace format_tag; |
132 | using namespace rnn_utils; |
133 | const alg_kind_t cell_kind = this->desc()->cell_kind; |
134 | |
135 | const data_type_t src_layer_dt |
136 | = this->desc()->src_layer_desc.data_type; |
137 | const data_type_t weights_iter_dt |
138 | = this->desc()->weights_iter_desc.data_type; |
139 | const data_type_t weights_layer_dt |
140 | = this->desc()->weights_layer_desc.data_type; |
141 | |
142 | bool ok = true |
143 | && one_of(cell_kind, alg_kind::vanilla_rnn, |
144 | alg_kind::vanilla_lstm, alg_kind::vanilla_gru, |
145 | alg_kind::lbr_gru, alg_kind::vanilla_augru, |
146 | alg_kind::lbr_augru) |
147 | && IMPLICATION(aprop == prop_kind::forward, |
148 | one_of(this->desc()->prop_kind, forward_training, |
149 | forward_inference)) |
150 | && IMPLICATION(aprop == backward, |
151 | one_of(this->desc()->prop_kind, backward)) |
152 | && src_layer_dt == src_type |
153 | && everyone_is( |
154 | weights_type, weights_iter_dt, weights_layer_dt) |
155 | && this->set_default_params() == status::success |
156 | && this->with_bias(); |
157 | |
158 | if (!ok) return status::unimplemented; |
159 | |
160 | rnn_ = zero<decltype(rnn_)>(); |
161 | rnn_.is_brgemm = false; |
162 | ok = init_conf<class_name>(rnn_, *this->desc(), *this->attr(), |
163 | this->src_md(0), this->src_md(1), this->src_md(2), |
164 | this->weights_md(0), this->weights_md(1), |
165 | this->arg_md(DNNL_ARG_WEIGHTS_PROJECTION), this->dst_md(0), |
166 | this->dst_md(1), this->dst_md(2), |
167 | this->arg_md(DNNL_ARG_BIAS)); |
168 | if (!ok) return status::unimplemented; |
169 | |
170 | if (rnn_.is_bf16_conf()) { |
171 | if (!utils::one_of( |
172 | rnn_.bias_dt, data_type::bf16, data_type::f32) |
173 | || rnn_.src_iter_c_dt != rnn_.dst_iter_c_dt |
174 | || !utils::one_of(rnn_.src_iter_c_dt, data_type::undef, |
175 | data_type::bf16, data_type::f32)) |
176 | return status::unimplemented; |
177 | } else if (rnn_.bias_dt != data_type::f32 |
178 | || !utils::one_of(rnn_.src_iter_c_dt, data_type::undef, |
179 | data_type::f32) |
180 | || rnn_.src_iter_c_dt != rnn_.dst_iter_c_dt) |
181 | return status::unimplemented; |
182 | |
183 | /* check that no data shift have been passed to s8s8 lstm */ |
184 | if (!IMPLICATION(rnn_.is_signed_int8_conf(), |
185 | this->attr()->rnn_data_qparams_.shift_ == 0.f)) |
186 | return status::unimplemented; |
187 | |
188 | /* check that only supported attr have been passed */ |
189 | primitive_attr_t::skip_mask_t attr_mask |
190 | = primitive_attr_t::skip_mask_t::rnn_tparams; |
191 | if (weights_layer_dt == data_type::s8) |
192 | attr_mask = attr_mask |
193 | | primitive_attr_t::skip_mask_t::rnn_data_qparams |
194 | | primitive_attr_t::skip_mask_t::rnn_weights_qparams |
195 | | primitive_attr_t::skip_mask_t:: |
196 | rnn_weights_projection_qparams; |
197 | ok = ok && this->attr()->has_default_values(attr_mask); |
198 | if (!ok) return status::unimplemented; |
199 | |
200 | // Set weights descriptors to desired format |
201 | memory_desc_t new_weights_layer_md = *this->weights_md(0); |
202 | CHECK(set_expected_desc(rnn_, new_weights_layer_md, |
203 | rnn_utils::weights_type_t::layer)); |
204 | if (this->weights_layer_md_.format_kind == format_kind::any) { |
205 | this->weights_layer_md_ = new_weights_layer_md; |
206 | } else if (this->weights_layer_md_.format_kind |
207 | == format_kind::rnn_packed) { |
208 | if (this->weights_layer_md_ != new_weights_layer_md) |
209 | return status::unimplemented; |
210 | } |
211 | |
212 | memory_desc_t new_weights_iter_md = *this->weights_md(1); |
213 | CHECK(set_expected_desc(rnn_, new_weights_iter_md, |
214 | rnn_utils::weights_type_t::iter)); |
215 | if (this->weights_iter_md_.format_kind == format_kind::any) { |
216 | this->weights_iter_md_ = new_weights_iter_md; |
217 | } else if (this->weights_iter_md_.format_kind |
218 | == format_kind::rnn_packed) { |
219 | if (this->weights_iter_md_ != new_weights_iter_md) |
220 | return status::unimplemented; |
221 | } |
222 | |
223 | if (rnn_.is_lstm_projection) { |
224 | memory_desc_t new_weights_projection_md |
225 | = *this->arg_md(DNNL_ARG_WEIGHTS_PROJECTION); |
226 | CHECK(set_expected_desc(rnn_, new_weights_projection_md, |
227 | rnn_utils::weights_type_t::projection)); |
228 | if (this->weights_projection_md_.format_kind |
229 | == format_kind::any) { |
230 | this->weights_projection_md_ = new_weights_projection_md; |
231 | } else if (this->weights_projection_md_.format_kind |
232 | == format_kind::rnn_packed) { |
233 | if (this->weights_projection_md_ |
234 | != new_weights_projection_md) |
235 | return status::unimplemented; |
236 | } |
237 | } |
238 | |
239 | CHECK(this->check_layout_consistency(false /*is_brgemm*/)); |
240 | |
241 | set_conf<class_name>(rnn_, *this->desc(), this->weights_md(0), |
242 | this->weights_md(1), |
243 | this->arg_md(DNNL_ARG_WEIGHTS_PROJECTION), |
244 | this->diff_weights_md(0), this->diff_weights_md(1), |
245 | this->arg_md(DNNL_ARG_DIFF_WEIGHTS_PROJECTION)); |
246 | set_workspace_sizes<class_name>(rnn_, *this->desc()); |
247 | return status::success; |
248 | } |
249 | |
250 | status_t init_brgemm(engine_t *engine) { |
251 | using namespace prop_kind; |
252 | using namespace utils; |
253 | using namespace format_tag; |
254 | using namespace rnn_utils; |
255 | #if DNNL_X64 |
256 | using namespace x64; |
257 | const alg_kind_t cell_kind = this->desc()->cell_kind; |
258 | |
259 | const data_type_t src_layer_dt |
260 | = this->desc()->src_layer_desc.data_type; |
261 | const data_type_t weights_iter_dt |
262 | = this->desc()->weights_iter_desc.data_type; |
263 | const data_type_t weights_layer_dt |
264 | = this->desc()->weights_layer_desc.data_type; |
265 | |
266 | bool ok = one_of(cell_kind, alg_kind::vanilla_rnn, |
267 | alg_kind::vanilla_lstm, alg_kind::vanilla_gru, |
268 | alg_kind::vanilla_augru) |
269 | && IMPLICATION(aprop == prop_kind::forward, |
270 | one_of(this->desc()->prop_kind, forward_training, |
271 | forward_inference)) |
272 | && IMPLICATION(aprop == backward, |
273 | one_of(this->desc()->prop_kind, backward)) |
274 | // cell_type (or src_type) and primitive data type should |
275 | // match, except for the bf32 case. |
276 | && IMPLICATION( |
277 | this->attr()->fpmath_mode_ == fpmath_mode::strict, |
278 | src_layer_dt == src_type |
279 | && everyone_is(weights_type, |
280 | weights_iter_dt, weights_layer_dt)) |
281 | && this->set_default_params() == status::success |
282 | && this->with_bias(); |
283 | |
284 | if (!ok) return status::unimplemented; |
285 | |
286 | rnn_ = zero<decltype(rnn_)>(); |
287 | rnn_.is_brgemm = true; |
288 | ok = init_conf<class_name>(rnn_, *this->desc(), *this->attr(), |
289 | this->src_md(0), this->src_md(1), this->src_md(2), |
290 | this->weights_md(0), this->weights_md(1), |
291 | this->arg_md(DNNL_ARG_WEIGHTS_PROJECTION), this->dst_md(0), |
292 | this->dst_md(1), this->dst_md(2), |
293 | this->arg_md(DNNL_ARG_BIAS)); |
294 | |
295 | ok = ok |
296 | && IMPLICATION(one_of(this->desc()->prop_kind, |
297 | forward_training, backward), |
298 | (rnn_.is_bf16_conf() || rnn_.is_f32_conf())); |
299 | |
300 | if (!ok) return status::unimplemented; |
301 | |
302 | // Support for GRU / AUGRU cell in BRGEMM-based implementation is |
303 | // limited by forward_inference pass for now, all_f32 is disabled |
304 | // due to performance degradation. |
305 | // TODO: Improve GRU / AUGRU coverage in BRGEMM-based implementation |
306 | ok = IMPLICATION(rnn_.is_orig_gru, |
307 | this->desc()->prop_kind == forward_inference |
308 | && !rnn_.is_cell_dt_f32()); |
309 | if (!ok) return status::unimplemented; |
310 | |
311 | if (rnn_.is_cell_dt_f32() |
312 | && utils::one_of(this->desc()->prop_kind, backward, |
313 | forward_training)) |
314 | return status::unimplemented; |
315 | |
316 | if (!(IMPLICATION((cell_kind == alg_kind::vanilla_lstm |
317 | && rnn_.is_lstm_projection), |
318 | this->desc()->prop_kind == forward_inference))) |
319 | return status::unimplemented; |
320 | |
321 | if (rnn_.is_bf16_conf()) { |
322 | if (!mayiuse(avx512_core_bf16) |
323 | || !utils::one_of( |
324 | rnn_.bias_dt, data_type::bf16, data_type::f32) |
325 | || rnn_.src_iter_c_dt != rnn_.dst_iter_c_dt |
326 | || !utils::one_of(rnn_.src_iter_c_dt, data_type::undef, |
327 | data_type::bf16, data_type::f32)) |
328 | return status::unimplemented; |
329 | } else if (rnn_.bias_dt != data_type::f32 |
330 | || !utils::one_of(rnn_.src_iter_c_dt, data_type::undef, |
331 | data_type::f32) |
332 | || rnn_.src_iter_c_dt != rnn_.dst_iter_c_dt) |
333 | return status::unimplemented; |
334 | |
335 | if (rnn_.is_signed_int8_conf() && !mayiuse(avx512_core_amx)) |
336 | return status::unimplemented; |
337 | if (rnn_.is_int8_conf() && !mayiuse(avx512_core_vnni)) |
338 | return status::unimplemented; |
339 | if (rnn_.is_f32_conf() && !mayiuse(avx512_core)) |
340 | return status::unimplemented; |
341 | |
342 | /* check that no shift have been passed to s8s8 amx lstm */ |
343 | if (!IMPLICATION(rnn_.is_signed_int8_conf(), |
344 | this->attr()->rnn_data_qparams_.shift_ == 0)) |
345 | return status::unimplemented; |
346 | |
347 | /* check that only supported attr have been passed */ |
348 | primitive_attr_t::skip_mask_t attr_mask |
349 | = primitive_attr_t::skip_mask_t::rnn_tparams; |
350 | if (weights_layer_dt == data_type::s8) |
351 | attr_mask = attr_mask |
352 | | primitive_attr_t::skip_mask_t::rnn_data_qparams |
353 | | primitive_attr_t::skip_mask_t::rnn_weights_qparams |
354 | | primitive_attr_t::skip_mask_t:: |
355 | rnn_weights_projection_qparams; |
356 | ok = ok && this->attr()->has_default_values(attr_mask); |
357 | if (!ok) return status::unimplemented; |
358 | |
359 | set_conf<class_name>(rnn_, *this->desc(), this->weights_md(0), |
360 | this->weights_md(1), |
361 | this->arg_md(DNNL_ARG_WEIGHTS_PROJECTION), |
362 | this->diff_weights_md(0), this->diff_weights_md(1), |
363 | this->arg_md(DNNL_ARG_DIFF_WEIGHTS_PROJECTION)); |
364 | |
365 | CHECK(ref_rnn_brgemm_t::configure_brgemm(rnn_, |
366 | this->desc()->cell_kind, sizeof(src_layer_t), |
367 | sizeof(scratch_t))); |
368 | |
369 | // must be called after configure_brgemm() |
370 | set_workspace_sizes<class_name>(rnn_, *this->desc()); |
371 | |
372 | // Only AMX LSTM supports s8s8 now |
373 | if (rnn_.is_signed_int8_conf() && !rnn_.is_cell_int8_amx()) |
374 | return status::unimplemented; |
375 | |
376 | // Set weights descriptors to desired format |
377 | memory_desc_t new_weights_layer_md = *this->weights_md(0); |
378 | CHECK(set_expected_desc(rnn_, new_weights_layer_md, |
379 | rnn_utils::weights_type_t::layer)); |
380 | if (this->weights_layer_md_.format_kind == format_kind::any) { |
381 | this->weights_layer_md_ = new_weights_layer_md; |
382 | } else if (this->weights_layer_md_ != new_weights_layer_md) { |
383 | return status::unimplemented; |
384 | } |
385 | |
386 | memory_desc_t new_weights_iter_md = *this->weights_md(1); |
387 | CHECK(set_expected_desc(rnn_, new_weights_iter_md, |
388 | rnn_utils::weights_type_t::iter)); |
389 | if (this->weights_iter_md_.format_kind == format_kind::any) { |
390 | this->weights_iter_md_ = new_weights_iter_md; |
391 | } else if (this->weights_iter_md_ != new_weights_iter_md) { |
392 | return status::unimplemented; |
393 | } |
394 | if (rnn_.is_lstm_projection) { |
395 | memory_desc_t new_weights_projection_md |
396 | = *this->arg_md(DNNL_ARG_WEIGHTS_PROJECTION); |
397 | CHECK(set_expected_desc(rnn_, new_weights_projection_md, |
398 | rnn_utils::weights_type_t::projection)); |
399 | if (this->weights_projection_md_.format_kind |
400 | == format_kind::any) { |
401 | this->weights_projection_md_ = new_weights_projection_md; |
402 | } else if (this->weights_projection_md_ |
403 | != new_weights_projection_md) { |
404 | return status::unimplemented; |
405 | } |
406 | } |
407 | if (rnn_.is_unsigned_int8_conf()) { |
408 | const memory_desc_wrapper &weights_layer_d( |
409 | this->weights_layer_md_); |
410 | const memory_desc_wrapper &weights_iter_d( |
411 | this->weights_iter_md_); |
412 | const auto &pdims_l = weights_layer_d.padded_dims(); |
413 | const auto &pdims_i = weights_iter_d.padded_dims(); |
414 | rnn_.weights_layer_comp_offset = rnn_.n_layer * rnn_.n_dir |
415 | * rnn_.n_gates * pdims_l[2] * pdims_l[4]; |
416 | rnn_.weights_iter_comp_offset = rnn_.n_layer * rnn_.n_dir |
417 | * rnn_.n_gates * pdims_i[2] * pdims_i[4]; |
418 | if (rnn_.is_lstm_projection) { |
419 | const memory_desc_wrapper &weights_proj_d( |
420 | this->weights_projection_md_); |
421 | const auto &pdims_p = weights_proj_d.padded_dims(); |
422 | rnn_.weights_projection_comp_offset = rnn_.n_layer |
423 | * rnn_.n_dir * pdims_p[2] * pdims_p[3]; |
424 | } else { |
425 | rnn_.weights_projection_comp_offset = 0; |
426 | } |
427 | } |
428 | CHECK(this->check_layout_consistency(true /*is_brgemm*/)); |
429 | |
430 | if (rnn_.is_bf32()) { |
431 | const memory_desc_wrapper weights_layer_d( |
432 | this->weights_layer_md_); |
433 | memory_desc_t weights_layer_md; |
434 | const memory_desc_wrapper weights_iter_d( |
435 | this->weights_iter_md_); |
436 | memory_desc_t weights_iter_md; |
437 | |
438 | const auto bf16_tag = rnn_.n_block == 64 |
439 | ? format_tag::ldgOI64o2i |
440 | : format_tag::ldgOI32o2i; |
441 | memory_desc_init_by_tag(weights_layer_md, |
442 | weights_layer_d.ndims(), weights_layer_d.dims(), |
443 | data_type::bf16, bf16_tag); |
444 | CHECK(reorder_primitive_desc_create(bf32_wei_layer_reorder_pd_, |
445 | engine, weights_layer_d.md_, &weights_layer_md, |
446 | nullptr)); |
447 | |
448 | memory_desc_init_by_tag(weights_iter_md, weights_iter_d.ndims(), |
449 | weights_iter_d.dims(), data_type::bf16, bf16_tag); |
450 | CHECK(reorder_primitive_desc_create(bf32_wei_iter_reorder_pd_, |
451 | engine, weights_iter_d.md_, &weights_iter_md, nullptr)); |
452 | } |
453 | |
454 | return status::success; |
455 | #else |
456 | return status::unimplemented; |
457 | #endif |
458 | } |
459 | |
460 | status_t init(engine_t *engine) { |
461 | status_t st = init_brgemm(engine); |
462 | if (st != status::success) { |
463 | rnn_.is_brgemm = false; |
464 | st = init_ref(engine); |
465 | } |
466 | if (st == status::success) { |
467 | size_t scratchpad_sz {0}, ws_sz {0}; |
468 | get_scratchpad_and_workspace_sizes(rnn_, scratchpad_sz, ws_sz); |
469 | |
470 | init_scratchpad(scratchpad_sz); |
471 | // initialize the workspace if needed |
472 | if (rnn_.is_training) { |
473 | dims_t ws_dims = {(dim_t)ws_sz}; |
474 | memory_desc_init_by_tag(this->ws_md_, 1, ws_dims, |
475 | data_type::u8, format_tag::x); |
476 | } |
477 | } |
478 | return st; |
479 | } |
480 | |
481 | rnn_utils::rnn_conf_t rnn_; |
482 | #if DNNL_X64 |
483 | std::shared_ptr<primitive_desc_t> bf32_wei_layer_reorder_pd_; |
484 | std::shared_ptr<primitive_desc_t> bf32_wei_iter_reorder_pd_; |
485 | #endif |
486 | private: |
487 | void init_scratchpad(size_t scratchpad_sz) { |
488 | using namespace memory_tracking::names; |
489 | auto scratchpad = this->scratchpad_registry().registrar(); |
490 | |
491 | { |
492 | static constexpr size_t data_size |
493 | = 1; // "true" data size already incorporated |
494 | static constexpr size_t data_align |
495 | = alignof(float); // "worst" case scenario |
496 | static constexpr size_t perf_align = 4096; |
497 | scratchpad.book(key_rnn_space, scratchpad_sz, data_size, |
498 | data_align, perf_align); |
499 | } |
500 | |
501 | const int max_nparts |
502 | = utils::one_of(this->cell_kind(), alg_kind::vanilla_gru, |
503 | alg_kind::vanilla_augru) |
504 | ? 2 |
505 | : 1; |
506 | const int ptr_wei_sz = rnn_.n_layer * rnn_.n_dir * max_nparts; |
507 | scratchpad.template book<float *>( |
508 | key_rnn_ptrs_wei_layer, ptr_wei_sz); |
509 | scratchpad.template book<float *>( |
510 | key_rnn_ptrs_wei_iter, ptr_wei_sz); |
511 | scratchpad.template book<float *>( |
512 | key_rnn_ptrs_wei_projection, ptr_wei_sz); |
513 | |
514 | const auto bias_dt_size = types::data_type_size( |
515 | this->arg_md(DNNL_ARG_BIAS)->data_type); |
516 | scratchpad.template book<void *>( |
517 | key_rnn_ptrs_bia, ptr_wei_sz * bias_dt_size); |
518 | |
519 | scratchpad.template book<scratch_t>( |
520 | key_rnn_gates, rnn_.scratch_gates_size); |
521 | scratchpad.template book<ht_t>(key_rnn_ht, rnn_.scratch_ht_size); |
522 | scratchpad.template book<gemm_acc_t>( |
523 | key_rnn_diff_ht, rnn_.scratch_diff_ht_size); |
524 | scratchpad.template book<scratch_t>( |
525 | key_rnn_cell, rnn_.scratch_cell_size); |
526 | |
527 | #if DNNL_X64 |
528 | if (rnn_.is_brgemm) { |
529 | ref_rnn_brgemm_t::init_scratchpad(rnn_, scratchpad, |
530 | sizeof(gemm_acc_t), alignof(gemm_acc_t)); |
531 | if (rnn_.is_bf32()) { |
532 | scratchpad.book(key_nested_multiple + 0, |
533 | bf32_wei_layer_reorder_pd_->scratchpad_registry()); |
534 | scratchpad.book(key_nested_multiple + 1, |
535 | bf32_wei_iter_reorder_pd_->scratchpad_registry()); |
536 | } |
537 | } |
538 | #endif |
539 | } |
540 | }; |
541 | |
542 | _ref_rnn_common_t(const pd_t *apd) |
543 | : primitive_t(apd), rnn_postgemm_(nullptr) {} |
544 | |
545 | status_t init(engine_t *engine) override { |
546 | /// @todo set max_feature_size assuming that we limit the number of |
547 | /// iterations and layer to one if slc != dhc and sic != dhc |
548 | /// respectively |
549 | |
550 | bias_preparation_func = &class_name::bias_prepare; |
551 | bias_finalization_func = &class_name::bias_finalize; |
552 | |
553 | const auto set_gemm_funcs |
554 | = [](bool packed_gemm, gemm_t &g, weights_assign_t &a, |
555 | bool is_brgemm) { |
556 | if (packed_gemm) { |
557 | g = &class_name::packed_gemm; |
558 | a = &class_name::assign_packed_weights; |
559 | } else { |
560 | g = (!is_brgemm) ? &class_name::gemm : nullptr; |
561 | a = &class_name::assign_weights; |
562 | } |
563 | }; |
564 | set_gemm_funcs(pd()->rnn_.use_iter_packed_gemm, gemm_iter_func, |
565 | weights_iter_assign_func, pd()->rnn_.is_brgemm); |
566 | |
567 | set_gemm_funcs(pd()->rnn_.use_layer_packed_gemm, gemm_layer_func, |
568 | weights_layer_assign_func, pd()->rnn_.is_brgemm); |
569 | |
570 | if (pd()->rnn_.is_lstm_projection) { |
571 | set_gemm_funcs(pd()->rnn_.use_projection_packed_gemm, |
572 | gemm_projection_func, weights_projection_assign_func, |
573 | pd()->rnn_.is_brgemm); |
574 | } |
575 | |
576 | rnn_postgemm_ = new rnn_postgemm_dispatcher<aprop, src_type, |
577 | scratch_type, acc_type>(pd()->rnn_, pd()); |
578 | assert(rnn_postgemm_ != nullptr); |
579 | switch (pd()->cell_kind()) { |
580 | case alg_kind::vanilla_rnn: |
581 | case alg_kind::vanilla_lstm: |
582 | cell_func = (pd()->rnn_.is_brgemm) |
583 | ? (aprop == prop_kind::forward |
584 | ? &class_name::cell_execution_brgemm_fwd |
585 | : &class_name:: |
586 | cell_execution_brgemm_bwd) |
587 | : &class_name::cell_execution_ref; |
588 | break; |
589 | case alg_kind::vanilla_gru: |
590 | case alg_kind::vanilla_augru: |
591 | cell_func = (pd()->rnn_.is_brgemm) |
592 | ? &class_name::cell_execution_brgemm_fwd |
593 | : &class_name::cell_execution_gru; |
594 | break; |
595 | case alg_kind::lbr_augru: |
596 | case alg_kind::lbr_gru: |
597 | cell_func = &class_name::cell_execution_gru_lbr; |
598 | break; |
599 | default: break; |
600 | } |
601 | |
602 | merged_layer_func = pd()->rnn_.is_brgemm && pd()->rnn_.merge_gemm_layer |
603 | && aprop == prop_kind::forward |
604 | ? &class_name::merged_layer_brgemm_fwd |
605 | : &class_name::merged_layer_execution_ref; |
606 | grid_computation = &class_name::linear_execution; |
607 | |
608 | size_t scratchpad_size, workspace_size; |
609 | rnn_utils::set_offsets(pd()->rnn_, ws_gates_offset_, ws_ht_offset_, |
610 | ws_states_layer_offset_, ws_states_iter_offset_, |
611 | ws_states_iter_c_offset_, ws_diff_states_layer_offset_, |
612 | ws_diff_states_iter_offset_, ws_diff_states_iter_c_offset_, |
613 | ws_grid_comp_offset_, ws_bias_offset_, scratch_gates_offset_, |
614 | scratch_ht_offset_, scratch_diff_ht_offset_, |
615 | scratch_cell_offset_, scratchpad_size, workspace_size); |
616 | #if DNNL_X64 |
617 | const auto rnn = pd()->rnn_; |
618 | if (rnn.is_brgemm) { |
619 | if (rnn.is_bf32()) { |
620 | |
621 | pd()->bf32_wei_layer_reorder_pd_->create_primitive( |
622 | bf32_wei_layer_reorder_, engine); |
623 | |
624 | pd()->bf32_wei_iter_reorder_pd_->create_primitive( |
625 | bf32_wei_iter_reorder_, engine); |
626 | } |
627 | return rnn_brgemm_.init_kernels(rnn, src_type, weights_type); |
628 | } |
629 | #endif |
630 | return status::success; |
631 | } |
632 | |
633 | ~_ref_rnn_common_t() { delete rnn_postgemm_; } |
634 | |
635 | status_t execute(const exec_ctx_t &ctx) const override { |
636 | execute_(ctx); |
637 | return status::success; |
638 | } |
639 | |
640 | private: |
641 | #if DNNL_X64 |
642 | ref_rnn_brgemm_t rnn_brgemm_; |
643 | std::shared_ptr<primitive_t> bf32_wei_layer_reorder_; |
644 | std::shared_ptr<primitive_t> bf32_wei_iter_reorder_; |
645 | #endif |
646 | void execute_(const exec_ctx_t &ctx) const; |
647 | |
648 | rnn_grid_execution_sig(linear_execution); |
649 | rnn_cell_execution_sig(cell_execution_ref); |
650 | rnn_merged_layer_execution_sig(merged_layer_execution_ref); |
651 | rnn_cell_execution_sig(cell_execution_brgemm_fwd); |
652 | rnn_merged_layer_execution_sig(merged_layer_brgemm_fwd); |
653 | rnn_cell_execution_sig(cell_execution_brgemm_bwd); |
654 | |
655 | rnn_cell_execution_sig(cell_execution_gru); |
656 | rnn_cell_execution_sig(cell_execution_gru_lbr); |
657 | rnn_gemm_sig(gemm); |
658 | rnn_gemm_sig(packed_gemm); |
659 | rnn_bias_prepare_sig(bias_prepare); |
660 | rnn_bias_finalize_sig(bias_finalize); |
661 | rnn_weights_assign_sig(assign_weights); |
662 | rnn_weights_assign_sig(assign_packed_weights); |
663 | |
664 | float (*activation_func)(float s, float alpha, float cliping); |
665 | |
666 | template <typename input_t> |
667 | void copy_init_layer(const rnn_utils::rnn_conf_t &rnn, |
668 | src_layer_t *ws_states_layer_, gemm_acc_t *ws_diff_states_layer_, |
669 | const input_t *xt_, const gemm_acc_t *diff_dst_layer) const; |
670 | |
671 | template <typename input_t> |
672 | void copy_init_iter(const rnn_utils::rnn_conf_t &rnn, |
673 | src_iter_t *ws_states_iter_, void *ws_states_iter_c_, |
674 | gemm_acc_t *ws_diff_states_iter_, |
675 | gemm_acc_t *ws_diff_states_iter_c_, const input_t *src_iter_, |
676 | const void *src_iter_c_, const gemm_acc_t *diff_dst_iter_, |
677 | const float *diff_dst_iter_c_) const; |
678 | |
679 | template <typename dst_layer_dt, typename dst_iter_dt> |
680 | void copy_res_layer(const rnn_utils::rnn_conf_t &rnn, |
681 | dst_layer_dt *dst_layer_, gemm_acc_t *diff_src_layer_, |
682 | const dst_iter_dt *dst_iter_, const src_layer_t *ws_states_layer_, |
683 | const gemm_acc_t *ws_diff_states_layer_) const; |
684 | |
685 | template <typename prim_dst_iter_t, typename prim_dst_layer_t> |
686 | void copy_res_iter(const rnn_utils::rnn_conf_t &rnn, |
687 | prim_dst_iter_t *dst_iter_, void *dst_iter_c_, |
688 | gemm_acc_t *diff_src_iter_, float *diff_src_iter_c_, |
689 | const prim_dst_layer_t *dst_layer_, |
690 | const src_iter_t *ws_states_iter_, const void *ws_states_iter_c, |
691 | const gemm_acc_t *ws_diff_states_iter_, |
692 | const gemm_acc_t *ws_diff_states_iter_c_) const; |
693 | |
694 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
695 | |
696 | size_t ws_gates_offset_; |
697 | size_t ws_ht_offset_; |
698 | size_t ws_states_layer_offset_; |
699 | size_t ws_states_iter_offset_; |
700 | size_t ws_states_iter_c_offset_; |
701 | size_t ws_bias_offset_; |
702 | size_t ws_diff_states_layer_offset_; |
703 | size_t ws_diff_states_iter_offset_; |
704 | size_t ws_diff_states_iter_c_offset_; |
705 | size_t ws_grid_comp_offset_; |
706 | size_t scratch_gates_offset_; |
707 | size_t scratch_ht_offset_; |
708 | size_t scratch_diff_ht_offset_; |
709 | size_t scratch_cell_offset_; |
710 | rnn_postgemm_dispatcher<aprop, src_type, scratch_type, acc_type> |
711 | *rnn_postgemm_; |
712 | |
713 | grid_execution_f grid_computation; |
714 | cell_execution_f cell_func; |
715 | merged_layer_execution_f merged_layer_func; |
716 | |
717 | bias_prepare_t bias_preparation_func; |
718 | bias_finalize_t bias_finalization_func; |
719 | weights_assign_t weights_layer_assign_func; |
720 | weights_assign_t weights_iter_assign_func; |
721 | weights_assign_t weights_projection_assign_func; |
722 | |
723 | gemm_t gemm_layer_func; |
724 | gemm_t gemm_iter_func; |
725 | gemm_t gemm_projection_func; |
726 | }; |
727 | |
// Concrete instantiations registered in the CPU implementation list.
// f32 forward/backward:
using ref_rnn_fwd_f32_t = _ref_rnn_common_t<prop_kind::forward, data_type::f32,
        data_type::f32, data_type::f32>;
using ref_rnn_bwd_f32_t = _ref_rnn_common_t<prop_kind::backward, data_type::f32,
        data_type::f32, data_type::f32>;
// bf16 forward/backward (f32 accumulation):
using ref_rnn_fwd_bf16_t = _ref_rnn_common_t<prop_kind::forward,
        data_type::bf16, data_type::bf16, data_type::f32>;
using ref_rnn_bwd_bf16_t = _ref_rnn_common_t<prop_kind::backward,
        data_type::bf16, data_type::bf16, data_type::f32>;
// int8 inference flavors (u8/s8 sources, s8 weights, s32 accumulation):
using ref_rnn_fwd_u8s8_t = _ref_rnn_common_t<prop_kind::forward, data_type::u8,
        data_type::s8, data_type::s32>;
using ref_rnn_fwd_s8s8_t = _ref_rnn_common_t<prop_kind::forward, data_type::s8,
        data_type::s8, data_type::s32>;
740 | } // namespace cpu |
741 | } // namespace impl |
742 | } // namespace dnnl |
743 | #endif |
744 | |
745 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
746 | |