/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_RNN_REF_RNN_HPP
#define CPU_RNN_REF_RNN_HPP

#include <assert.h>
#include <tuple>

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "common/reorder.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/gemm/gemm.hpp"
#include "cpu/gemm/os_blas.hpp"

#include "cpu/rnn/cpu_rnn_pd.hpp"
#include "cpu/rnn/postgemm_dispatcher.hpp"
#if DNNL_X64
#include "cpu/x64/rnn/rnn_brgemm_utils.hpp"
#endif
#include "cpu/rnn/rnn_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

namespace {
template <typename gates_t, typename acc_t>
// The loop body needs to be put in a function as some versions of icc have
// an issue with lambdas & macros inside omp simd loops
inline void body_loop(int i, int k, const gates_t *ws_gates, acc_t *diff_bias,
        const rnn_utils::rnn_conf_t &rnn) {
    for (int j = 0; j < rnn.mb; j++)
        diff_bias[i * rnn.dhc + k]
                += ws_gates[j * rnn.scratch_gates_ld + i * rnn.dhc + k];
}
} // namespace

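// gates_reduction() accumulates the bias gradient by reducing the scratch
// gates over the minibatch dimension:
//     diff_bias[i * dhc + k] += sum_{j < mb} ws_gates[j, i, k]
// for every gate i and channel k.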
template <typename gates_t, typename acc_t>
void gates_reduction(const rnn_utils::rnn_conf_t &rnn, const gates_t *ws_gates_,
        acc_t *diff_bias_) {

    // @todo block k on simd-width to enable vectorization in
    // parallel_nd path
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_OMP && _OPENMP >= 201307 \
        && (!defined(__INTEL_COMPILER) || __INTEL_COMPILER < 1910)
#pragma omp parallel for simd collapse(2)
    for (int i = 0; i < rnn.n_gates; i++)
        for (int k = 0; k < rnn.dhc; k++)
            body_loop(i, k, ws_gates_, diff_bias_, rnn);
#else
    parallel_nd(rnn.n_gates, rnn.dhc, [&](dim_t i, dim_t k) {
        body_loop(i, k, ws_gates_, diff_bias_, rnn);
    });
#endif
}

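// Reference RNN primitive, templated on the propagation kind (forward or
// backward) and on the source, weights, and accumulator data types. The same
// class implements both the plain reference path and, on x64, the
// BRGEMM-based path (see pd_t::init_ref() and pd_t::init_brgemm()).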
template <prop_kind_t aprop, impl::data_type_t src_type,
        impl::data_type_t weights_type, impl::data_type_t acc_type>
struct _ref_rnn_common_t : public primitive_t {
    static constexpr impl::data_type_t scratch_type
            = aprop == prop_kind::forward ? acc_type : src_type;

    /* These types are defined for each element in the cell execution */
    typedef typename prec_traits<src_type>::type src_layer_t;
    typedef typename prec_traits<src_type>::type src_iter_t;
    typedef typename prec_traits<src_type>::type dst_layer_t;
    typedef typename prec_traits<src_type>::type dst_iter_t;
    typedef typename prec_traits<weights_type>::type weights_t;
    typedef typename prec_traits<src_type>::type gemm_data_t;
    typedef typename prec_traits<acc_type>::type gemm_acc_t;
    typedef typename prec_traits<scratch_type>::type scratch_t;
    typedef typename prec_traits<src_type>::type ht_t;
    typedef typename prec_traits<src_type>::type gates_t;

    using class_name
            = _ref_rnn_common_t<aprop, src_type, weights_type, acc_type>;
#if DNNL_X64
    using ref_rnn_brgemm_t = x64::rnn_brgemm_utils::rnn_brgemm_t<aprop>;
#endif

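    // Member-function-pointer types for the hooks selected in init():
    // cell/grid/merged-layer execution, the GEMM flavor, bias
    // preparation/finalization, and weights assignment.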
    typedef rnn_cell_execution_sig((class_name::*cell_execution_f));
    typedef rnn_grid_execution_sig((class_name::*grid_execution_f));
    typedef rnn_merged_layer_execution_sig(
            (class_name::*merged_layer_execution_f));

    typedef rnn_gemm_sig((class_name::*gemm_t));
    typedef rnn_bias_prepare_sig((class_name::*bias_prepare_t));
    typedef rnn_bias_finalize_sig((class_name::*bias_finalize_t));
    typedef rnn_weights_assign_sig((class_name::*weights_assign_t));

    using base_pd_t =
            typename utils::conditional<aprop == prop_kind::forward,
                    cpu_rnn_fwd_pd_t, cpu_rnn_bwd_pd_t>::type;

    struct pd_t : public base_pd_t {
        using base_pd_t::base_pd_t;

        const char *impl_name() const {
#if DNNL_X64
            using namespace dnnl::impl::cpu::x64;
            return rnn_.is_brgemm
                    ? JIT_IMPL_NAME_HELPER("brgemm:", rnn_.brgemm_isa, "")
                    : "ref";
#else
            return "ref";
#endif
        }

        DECLARE_COMMON_PD_T(impl_name(), class_name, USE_GLOBAL_SCRATCHPAD);

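        // init_ref() configures the plain reference path: it validates the
        // cell kind, data types, and attributes, fixes the weights memory
        // formats, and computes the scratchpad and workspace sizes.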
        status_t init_ref(engine_t *engine) {
            using namespace prop_kind;
            using namespace utils;
            using namespace format_tag;
            using namespace rnn_utils;
            const alg_kind_t cell_kind = this->desc()->cell_kind;

            const data_type_t src_layer_dt
                    = this->desc()->src_layer_desc.data_type;
            const data_type_t weights_iter_dt
                    = this->desc()->weights_iter_desc.data_type;
            const data_type_t weights_layer_dt
                    = this->desc()->weights_layer_desc.data_type;

            bool ok = true
                    && one_of(cell_kind, alg_kind::vanilla_rnn,
                            alg_kind::vanilla_lstm, alg_kind::vanilla_gru,
                            alg_kind::lbr_gru, alg_kind::vanilla_augru,
                            alg_kind::lbr_augru)
                    && IMPLICATION(aprop == prop_kind::forward,
                            one_of(this->desc()->prop_kind, forward_training,
                                    forward_inference))
                    && IMPLICATION(aprop == backward,
                            one_of(this->desc()->prop_kind, backward))
                    && src_layer_dt == src_type
                    && everyone_is(
                            weights_type, weights_iter_dt, weights_layer_dt)
                    && this->set_default_params() == status::success
                    && this->with_bias();

            if (!ok) return status::unimplemented;

            rnn_ = zero<decltype(rnn_)>();
            rnn_.is_brgemm = false;
            ok = init_conf<class_name>(rnn_, *this->desc(), *this->attr(),
                    this->src_md(0), this->src_md(1), this->src_md(2),
                    this->weights_md(0), this->weights_md(1),
                    this->arg_md(DNNL_ARG_WEIGHTS_PROJECTION), this->dst_md(0),
                    this->dst_md(1), this->dst_md(2),
                    this->arg_md(DNNL_ARG_BIAS));
            if (!ok) return status::unimplemented;

            if (rnn_.is_bf16_conf()) {
                if (!utils::one_of(
                            rnn_.bias_dt, data_type::bf16, data_type::f32)
                        || rnn_.src_iter_c_dt != rnn_.dst_iter_c_dt
                        || !utils::one_of(rnn_.src_iter_c_dt, data_type::undef,
                                data_type::bf16, data_type::f32))
                    return status::unimplemented;
            } else if (rnn_.bias_dt != data_type::f32
                    || !utils::one_of(rnn_.src_iter_c_dt, data_type::undef,
                            data_type::f32)
                    || rnn_.src_iter_c_dt != rnn_.dst_iter_c_dt)
                return status::unimplemented;

            /* check that no data shift has been passed to s8s8 lstm */
            if (!IMPLICATION(rnn_.is_signed_int8_conf(),
                        this->attr()->rnn_data_qparams_.shift_ == 0.f))
                return status::unimplemented;

            /* check that only supported attributes have been passed */
            primitive_attr_t::skip_mask_t attr_mask
                    = primitive_attr_t::skip_mask_t::rnn_tparams;
            if (weights_layer_dt == data_type::s8)
                attr_mask = attr_mask
                        | primitive_attr_t::skip_mask_t::rnn_data_qparams
                        | primitive_attr_t::skip_mask_t::rnn_weights_qparams
                        | primitive_attr_t::skip_mask_t::
                                rnn_weights_projection_qparams;
            ok = ok && this->attr()->has_default_values(attr_mask);
            if (!ok) return status::unimplemented;

            // Set weights descriptors to the desired format
            memory_desc_t new_weights_layer_md = *this->weights_md(0);
            CHECK(set_expected_desc(rnn_, new_weights_layer_md,
                    rnn_utils::weights_type_t::layer));
            if (this->weights_layer_md_.format_kind == format_kind::any) {
                this->weights_layer_md_ = new_weights_layer_md;
            } else if (this->weights_layer_md_.format_kind
                    == format_kind::rnn_packed) {
                if (this->weights_layer_md_ != new_weights_layer_md)
                    return status::unimplemented;
            }

            memory_desc_t new_weights_iter_md = *this->weights_md(1);
            CHECK(set_expected_desc(rnn_, new_weights_iter_md,
                    rnn_utils::weights_type_t::iter));
            if (this->weights_iter_md_.format_kind == format_kind::any) {
                this->weights_iter_md_ = new_weights_iter_md;
            } else if (this->weights_iter_md_.format_kind
                    == format_kind::rnn_packed) {
                if (this->weights_iter_md_ != new_weights_iter_md)
                    return status::unimplemented;
            }

            if (rnn_.is_lstm_projection) {
                memory_desc_t new_weights_projection_md
                        = *this->arg_md(DNNL_ARG_WEIGHTS_PROJECTION);
                CHECK(set_expected_desc(rnn_, new_weights_projection_md,
                        rnn_utils::weights_type_t::projection));
                if (this->weights_projection_md_.format_kind
                        == format_kind::any) {
                    this->weights_projection_md_ = new_weights_projection_md;
                } else if (this->weights_projection_md_.format_kind
                        == format_kind::rnn_packed) {
                    if (this->weights_projection_md_
                            != new_weights_projection_md)
                        return status::unimplemented;
                }
            }

            CHECK(this->check_layout_consistency(false /*is_brgemm*/));

            set_conf<class_name>(rnn_, *this->desc(), this->weights_md(0),
                    this->weights_md(1),
                    this->arg_md(DNNL_ARG_WEIGHTS_PROJECTION),
                    this->diff_weights_md(0), this->diff_weights_md(1),
                    this->arg_md(DNNL_ARG_DIFF_WEIGHTS_PROJECTION));
            set_workspace_sizes<class_name>(rnn_, *this->desc());
            return status::success;
        }

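        // init_brgemm() configures the BRGEMM-based path. It is compiled only
        // for x64 builds and additionally gates on ISA support (avx512_core
        // and newer feature sets) and on the cell/data-type combinations the
        // JIT kernels cover; everything else falls back to init_ref().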
        status_t init_brgemm(engine_t *engine) {
            using namespace prop_kind;
            using namespace utils;
            using namespace format_tag;
            using namespace rnn_utils;
#if DNNL_X64
            using namespace x64;
            const alg_kind_t cell_kind = this->desc()->cell_kind;

            const data_type_t src_layer_dt
                    = this->desc()->src_layer_desc.data_type;
            const data_type_t weights_iter_dt
                    = this->desc()->weights_iter_desc.data_type;
            const data_type_t weights_layer_dt
                    = this->desc()->weights_layer_desc.data_type;

            bool ok = one_of(cell_kind, alg_kind::vanilla_rnn,
                             alg_kind::vanilla_lstm, alg_kind::vanilla_gru,
                             alg_kind::vanilla_augru)
                    && IMPLICATION(aprop == prop_kind::forward,
                            one_of(this->desc()->prop_kind, forward_training,
                                    forward_inference))
                    && IMPLICATION(aprop == backward,
                            one_of(this->desc()->prop_kind, backward))
                    // cell_type (or src_type) and primitive data type should
                    // match, except for the bf32 case.
                    && IMPLICATION(
                            this->attr()->fpmath_mode_ == fpmath_mode::strict,
                            src_layer_dt == src_type
                                    && everyone_is(weights_type,
                                            weights_iter_dt, weights_layer_dt))
                    && this->set_default_params() == status::success
                    && this->with_bias();

            if (!ok) return status::unimplemented;

            rnn_ = zero<decltype(rnn_)>();
            rnn_.is_brgemm = true;
            ok = init_conf<class_name>(rnn_, *this->desc(), *this->attr(),
                    this->src_md(0), this->src_md(1), this->src_md(2),
                    this->weights_md(0), this->weights_md(1),
                    this->arg_md(DNNL_ARG_WEIGHTS_PROJECTION), this->dst_md(0),
                    this->dst_md(1), this->dst_md(2),
                    this->arg_md(DNNL_ARG_BIAS));

            ok = ok
                    && IMPLICATION(one_of(this->desc()->prop_kind,
                                           forward_training, backward),
                            (rnn_.is_bf16_conf() || rnn_.is_f32_conf()));

            if (!ok) return status::unimplemented;

            // Support for the GRU / AUGRU cell in the BRGEMM-based
            // implementation is limited to the forward_inference pass for
            // now; the f32 cell is disabled due to performance degradation.
            // TODO: Improve GRU / AUGRU coverage in the BRGEMM-based
            // implementation
            ok = IMPLICATION(rnn_.is_orig_gru,
                    this->desc()->prop_kind == forward_inference
                            && !rnn_.is_cell_dt_f32());
            if (!ok) return status::unimplemented;

            if (rnn_.is_cell_dt_f32()
                    && utils::one_of(this->desc()->prop_kind, backward,
                            forward_training))
                return status::unimplemented;

            if (!(IMPLICATION((cell_kind == alg_kind::vanilla_lstm
                                      && rnn_.is_lstm_projection),
                        this->desc()->prop_kind == forward_inference)))
                return status::unimplemented;

            if (rnn_.is_bf16_conf()) {
                if (!mayiuse(avx512_core_bf16)
                        || !utils::one_of(
                                rnn_.bias_dt, data_type::bf16, data_type::f32)
                        || rnn_.src_iter_c_dt != rnn_.dst_iter_c_dt
                        || !utils::one_of(rnn_.src_iter_c_dt, data_type::undef,
                                data_type::bf16, data_type::f32))
                    return status::unimplemented;
            } else if (rnn_.bias_dt != data_type::f32
                    || !utils::one_of(rnn_.src_iter_c_dt, data_type::undef,
                            data_type::f32)
                    || rnn_.src_iter_c_dt != rnn_.dst_iter_c_dt)
                return status::unimplemented;

            if (rnn_.is_signed_int8_conf() && !mayiuse(avx512_core_amx))
                return status::unimplemented;
            if (rnn_.is_int8_conf() && !mayiuse(avx512_core_vnni))
                return status::unimplemented;
            if (rnn_.is_f32_conf() && !mayiuse(avx512_core))
                return status::unimplemented;

            /* check that no shift has been passed to s8s8 amx lstm */
            if (!IMPLICATION(rnn_.is_signed_int8_conf(),
                        this->attr()->rnn_data_qparams_.shift_ == 0))
                return status::unimplemented;

            /* check that only supported attributes have been passed */
            primitive_attr_t::skip_mask_t attr_mask
                    = primitive_attr_t::skip_mask_t::rnn_tparams;
            if (weights_layer_dt == data_type::s8)
                attr_mask = attr_mask
                        | primitive_attr_t::skip_mask_t::rnn_data_qparams
                        | primitive_attr_t::skip_mask_t::rnn_weights_qparams
                        | primitive_attr_t::skip_mask_t::
                                rnn_weights_projection_qparams;
            ok = ok && this->attr()->has_default_values(attr_mask);
            if (!ok) return status::unimplemented;

            set_conf<class_name>(rnn_, *this->desc(), this->weights_md(0),
                    this->weights_md(1),
                    this->arg_md(DNNL_ARG_WEIGHTS_PROJECTION),
                    this->diff_weights_md(0), this->diff_weights_md(1),
                    this->arg_md(DNNL_ARG_DIFF_WEIGHTS_PROJECTION));

            CHECK(ref_rnn_brgemm_t::configure_brgemm(rnn_,
                    this->desc()->cell_kind, sizeof(src_layer_t),
                    sizeof(scratch_t)));

            // must be called after configure_brgemm()
            set_workspace_sizes<class_name>(rnn_, *this->desc());

            // Only AMX LSTM supports s8s8 for now
            if (rnn_.is_signed_int8_conf() && !rnn_.is_cell_int8_amx())
                return status::unimplemented;

            // Set weights descriptors to the desired format
            memory_desc_t new_weights_layer_md = *this->weights_md(0);
            CHECK(set_expected_desc(rnn_, new_weights_layer_md,
                    rnn_utils::weights_type_t::layer));
            if (this->weights_layer_md_.format_kind == format_kind::any) {
                this->weights_layer_md_ = new_weights_layer_md;
            } else if (this->weights_layer_md_ != new_weights_layer_md) {
                return status::unimplemented;
            }

            memory_desc_t new_weights_iter_md = *this->weights_md(1);
            CHECK(set_expected_desc(rnn_, new_weights_iter_md,
                    rnn_utils::weights_type_t::iter));
            if (this->weights_iter_md_.format_kind == format_kind::any) {
                this->weights_iter_md_ = new_weights_iter_md;
            } else if (this->weights_iter_md_ != new_weights_iter_md) {
                return status::unimplemented;
            }

            if (rnn_.is_lstm_projection) {
                memory_desc_t new_weights_projection_md
                        = *this->arg_md(DNNL_ARG_WEIGHTS_PROJECTION);
                CHECK(set_expected_desc(rnn_, new_weights_projection_md,
                        rnn_utils::weights_type_t::projection));
                if (this->weights_projection_md_.format_kind
                        == format_kind::any) {
                    this->weights_projection_md_ = new_weights_projection_md;
                } else if (this->weights_projection_md_
                        != new_weights_projection_md) {
                    return status::unimplemented;
                }
            }

            if (rnn_.is_unsigned_int8_conf()) {
                const memory_desc_wrapper &weights_layer_d(
                        this->weights_layer_md_);
                const memory_desc_wrapper &weights_iter_d(
                        this->weights_iter_md_);
                const auto &pdims_l = weights_layer_d.padded_dims();
                const auto &pdims_i = weights_iter_d.padded_dims();
                rnn_.weights_layer_comp_offset = rnn_.n_layer * rnn_.n_dir
                        * rnn_.n_gates * pdims_l[2] * pdims_l[4];
                rnn_.weights_iter_comp_offset = rnn_.n_layer * rnn_.n_dir
                        * rnn_.n_gates * pdims_i[2] * pdims_i[4];
                if (rnn_.is_lstm_projection) {
                    const memory_desc_wrapper &weights_proj_d(
                            this->weights_projection_md_);
                    const auto &pdims_p = weights_proj_d.padded_dims();
                    rnn_.weights_projection_comp_offset = rnn_.n_layer
                            * rnn_.n_dir * pdims_p[2] * pdims_p[3];
                } else {
                    rnn_.weights_projection_comp_offset = 0;
                }
            }

            CHECK(this->check_layout_consistency(true /*is_brgemm*/));

            if (rnn_.is_bf32()) {
                const memory_desc_wrapper weights_layer_d(
                        this->weights_layer_md_);
                memory_desc_t weights_layer_md;
                const memory_desc_wrapper weights_iter_d(
                        this->weights_iter_md_);
                memory_desc_t weights_iter_md;

                const auto bf16_tag = rnn_.n_block == 64
                        ? format_tag::ldgOI64o2i
                        : format_tag::ldgOI32o2i;
                memory_desc_init_by_tag(weights_layer_md,
                        weights_layer_d.ndims(), weights_layer_d.dims(),
                        data_type::bf16, bf16_tag);
                CHECK(reorder_primitive_desc_create(bf32_wei_layer_reorder_pd_,
                        engine, weights_layer_d.md_, &weights_layer_md,
                        nullptr));

                memory_desc_init_by_tag(weights_iter_md, weights_iter_d.ndims(),
                        weights_iter_d.dims(), data_type::bf16, bf16_tag);
                CHECK(reorder_primitive_desc_create(bf32_wei_iter_reorder_pd_,
                        engine, weights_iter_d.md_, &weights_iter_md, nullptr));
            }

            return status::success;
#else
            return status::unimplemented;
#endif
        }

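        // init() tries the BRGEMM-based configuration first and, if that
        // reports unimplemented, falls back to the reference one. On success
        // it books the scratchpad and, for training, exposes the workspace
        // through ws_md_.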
        status_t init(engine_t *engine) {
            status_t st = init_brgemm(engine);
            if (st != status::success) {
                rnn_.is_brgemm = false;
                st = init_ref(engine);
            }
            if (st == status::success) {
                size_t scratchpad_sz {0}, ws_sz {0};
                get_scratchpad_and_workspace_sizes(rnn_, scratchpad_sz, ws_sz);

                init_scratchpad(scratchpad_sz);
                // initialize the workspace if needed
                if (rnn_.is_training) {
                    dims_t ws_dims = {(dim_t)ws_sz};
                    memory_desc_init_by_tag(this->ws_md_, 1, ws_dims,
                            data_type::u8, format_tag::x);
                }
            }
            return st;
        }

        rnn_utils::rnn_conf_t rnn_;
#if DNNL_X64
        std::shared_ptr<primitive_desc_t> bf32_wei_layer_reorder_pd_;
        std::shared_ptr<primitive_desc_t> bf32_wei_iter_reorder_pd_;
#endif
    private:
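        // Books all scratchpad buffers: the main RNN workspace blob, the
        // arrays of per-layer/direction weights and bias pointers, the
        // scratch gates/ht/diff_ht/cell buffers, and (for the BRGEMM bf32
        // path) the nested scratchpads of the weights reorder primitives.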
        void init_scratchpad(size_t scratchpad_sz) {
            using namespace memory_tracking::names;
            auto scratchpad = this->scratchpad_registry().registrar();

            {
                static constexpr size_t data_size
                        = 1; // "true" data size already incorporated
                static constexpr size_t data_align
                        = alignof(float); // "worst" case scenario
                static constexpr size_t perf_align = 4096;
                scratchpad.book(key_rnn_space, scratchpad_sz, data_size,
                        data_align, perf_align);
            }

            const int max_nparts
                    = utils::one_of(this->cell_kind(), alg_kind::vanilla_gru,
                              alg_kind::vanilla_augru)
                    ? 2
                    : 1;
            const int ptr_wei_sz = rnn_.n_layer * rnn_.n_dir * max_nparts;
            scratchpad.template book<float *>(
                    key_rnn_ptrs_wei_layer, ptr_wei_sz);
            scratchpad.template book<float *>(
                    key_rnn_ptrs_wei_iter, ptr_wei_sz);
            scratchpad.template book<float *>(
                    key_rnn_ptrs_wei_projection, ptr_wei_sz);

            const auto bias_dt_size = types::data_type_size(
                    this->arg_md(DNNL_ARG_BIAS)->data_type);
            scratchpad.template book<void *>(
                    key_rnn_ptrs_bia, ptr_wei_sz * bias_dt_size);

            scratchpad.template book<scratch_t>(
                    key_rnn_gates, rnn_.scratch_gates_size);
            scratchpad.template book<ht_t>(key_rnn_ht, rnn_.scratch_ht_size);
            scratchpad.template book<gemm_acc_t>(
                    key_rnn_diff_ht, rnn_.scratch_diff_ht_size);
            scratchpad.template book<scratch_t>(
                    key_rnn_cell, rnn_.scratch_cell_size);

#if DNNL_X64
            if (rnn_.is_brgemm) {
                ref_rnn_brgemm_t::init_scratchpad(rnn_, scratchpad,
                        sizeof(gemm_acc_t), alignof(gemm_acc_t));
                if (rnn_.is_bf32()) {
                    scratchpad.book(key_nested_multiple + 0,
                            bf32_wei_layer_reorder_pd_->scratchpad_registry());
                    scratchpad.book(key_nested_multiple + 1,
                            bf32_wei_iter_reorder_pd_->scratchpad_registry());
                }
            }
#endif
        }
    };

    _ref_rnn_common_t(const pd_t *apd)
        : primitive_t(apd), rnn_postgemm_(nullptr) {}

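    // The primitive-level init() wires up the execution hooks declared above:
    // it picks packed vs. plain GEMM, the cell-execution function for the
    // requested cell kind (reference or BRGEMM), the merged-layer and grid
    // functions, creates the postgemm dispatcher, and precomputes the
    // workspace/scratchpad offsets. On x64 it also creates the bf32 weights
    // reorders and the BRGEMM kernels.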
    status_t init(engine_t *engine) override {
        /// @todo set max_feature_size assuming that we limit the number of
        /// iterations and layers to one if slc != dhc and sic != dhc
        /// respectively

        bias_preparation_func = &class_name::bias_prepare;
        bias_finalization_func = &class_name::bias_finalize;

        const auto set_gemm_funcs
                = [](bool packed_gemm, gemm_t &g, weights_assign_t &a,
                          bool is_brgemm) {
                      if (packed_gemm) {
                          g = &class_name::packed_gemm;
                          a = &class_name::assign_packed_weights;
                      } else {
                          g = (!is_brgemm) ? &class_name::gemm : nullptr;
                          a = &class_name::assign_weights;
                      }
                  };
        set_gemm_funcs(pd()->rnn_.use_iter_packed_gemm, gemm_iter_func,
                weights_iter_assign_func, pd()->rnn_.is_brgemm);

        set_gemm_funcs(pd()->rnn_.use_layer_packed_gemm, gemm_layer_func,
                weights_layer_assign_func, pd()->rnn_.is_brgemm);

        if (pd()->rnn_.is_lstm_projection) {
            set_gemm_funcs(pd()->rnn_.use_projection_packed_gemm,
                    gemm_projection_func, weights_projection_assign_func,
                    pd()->rnn_.is_brgemm);
        }

        rnn_postgemm_ = new rnn_postgemm_dispatcher<aprop, src_type,
                scratch_type, acc_type>(pd()->rnn_, pd());
        assert(rnn_postgemm_ != nullptr);
        switch (pd()->cell_kind()) {
            case alg_kind::vanilla_rnn:
            case alg_kind::vanilla_lstm:
                cell_func = (pd()->rnn_.is_brgemm)
                        ? (aprop == prop_kind::forward
                                        ? &class_name::cell_execution_brgemm_fwd
                                        : &class_name::
                                                  cell_execution_brgemm_bwd)
                        : &class_name::cell_execution_ref;
                break;
            case alg_kind::vanilla_gru:
            case alg_kind::vanilla_augru:
                cell_func = (pd()->rnn_.is_brgemm)
                        ? &class_name::cell_execution_brgemm_fwd
                        : &class_name::cell_execution_gru;
                break;
            case alg_kind::lbr_augru:
            case alg_kind::lbr_gru:
                cell_func = &class_name::cell_execution_gru_lbr;
                break;
            default: break;
        }

        merged_layer_func = pd()->rnn_.is_brgemm && pd()->rnn_.merge_gemm_layer
                        && aprop == prop_kind::forward
                ? &class_name::merged_layer_brgemm_fwd
                : &class_name::merged_layer_execution_ref;
        grid_computation = &class_name::linear_execution;

        size_t scratchpad_size, workspace_size;
        rnn_utils::set_offsets(pd()->rnn_, ws_gates_offset_, ws_ht_offset_,
                ws_states_layer_offset_, ws_states_iter_offset_,
                ws_states_iter_c_offset_, ws_diff_states_layer_offset_,
                ws_diff_states_iter_offset_, ws_diff_states_iter_c_offset_,
                ws_grid_comp_offset_, ws_bias_offset_, scratch_gates_offset_,
                scratch_ht_offset_, scratch_diff_ht_offset_,
                scratch_cell_offset_, scratchpad_size, workspace_size);
#if DNNL_X64
        const auto rnn = pd()->rnn_;
        if (rnn.is_brgemm) {
            if (rnn.is_bf32()) {
                CHECK(pd()->bf32_wei_layer_reorder_pd_->create_primitive(
                        bf32_wei_layer_reorder_, engine));
                CHECK(pd()->bf32_wei_iter_reorder_pd_->create_primitive(
                        bf32_wei_iter_reorder_, engine));
            }
            return rnn_brgemm_.init_kernels(rnn, src_type, weights_type);
        }
#endif
        return status::success;
    }

    ~_ref_rnn_common_t() { delete rnn_postgemm_; }

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_(ctx);
        return status::success;
    }

private:
#if DNNL_X64
    ref_rnn_brgemm_t rnn_brgemm_;
    std::shared_ptr<primitive_t> bf32_wei_layer_reorder_;
    std::shared_ptr<primitive_t> bf32_wei_iter_reorder_;
#endif
    void execute_(const exec_ctx_t &ctx) const;

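    // The rnn_*_sig macros (see rnn_utils.hpp) expand to member-function
    // declarations whose signatures match the function-pointer typedefs
    // above; init() selects which variant is used at runtime.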
    rnn_grid_execution_sig(linear_execution);
    rnn_cell_execution_sig(cell_execution_ref);
    rnn_merged_layer_execution_sig(merged_layer_execution_ref);
    rnn_cell_execution_sig(cell_execution_brgemm_fwd);
    rnn_merged_layer_execution_sig(merged_layer_brgemm_fwd);
    rnn_cell_execution_sig(cell_execution_brgemm_bwd);

    rnn_cell_execution_sig(cell_execution_gru);
    rnn_cell_execution_sig(cell_execution_gru_lbr);
    rnn_gemm_sig(gemm);
    rnn_gemm_sig(packed_gemm);
    rnn_bias_prepare_sig(bias_prepare);
    rnn_bias_finalize_sig(bias_finalize);
    rnn_weights_assign_sig(assign_weights);
    rnn_weights_assign_sig(assign_packed_weights);

    float (*activation_func)(float s, float alpha, float clipping);

    template <typename input_t>
    void copy_init_layer(const rnn_utils::rnn_conf_t &rnn,
            src_layer_t *ws_states_layer_, gemm_acc_t *ws_diff_states_layer_,
            const input_t *xt_, const gemm_acc_t *diff_dst_layer) const;

    template <typename input_t>
    void copy_init_iter(const rnn_utils::rnn_conf_t &rnn,
            src_iter_t *ws_states_iter_, void *ws_states_iter_c_,
            gemm_acc_t *ws_diff_states_iter_,
            gemm_acc_t *ws_diff_states_iter_c_, const input_t *src_iter_,
            const void *src_iter_c_, const gemm_acc_t *diff_dst_iter_,
            const float *diff_dst_iter_c_) const;

    template <typename dst_layer_dt, typename dst_iter_dt>
    void copy_res_layer(const rnn_utils::rnn_conf_t &rnn,
            dst_layer_dt *dst_layer_, gemm_acc_t *diff_src_layer_,
            const dst_iter_dt *dst_iter_, const src_layer_t *ws_states_layer_,
            const gemm_acc_t *ws_diff_states_layer_) const;

    template <typename prim_dst_iter_t, typename prim_dst_layer_t>
    void copy_res_iter(const rnn_utils::rnn_conf_t &rnn,
            prim_dst_iter_t *dst_iter_, void *dst_iter_c_,
            gemm_acc_t *diff_src_iter_, float *diff_src_iter_c_,
            const prim_dst_layer_t *dst_layer_,
            const src_iter_t *ws_states_iter_, const void *ws_states_iter_c,
            const gemm_acc_t *ws_diff_states_iter_,
            const gemm_acc_t *ws_diff_states_iter_c_) const;

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    size_t ws_gates_offset_;
    size_t ws_ht_offset_;
    size_t ws_states_layer_offset_;
    size_t ws_states_iter_offset_;
    size_t ws_states_iter_c_offset_;
    size_t ws_bias_offset_;
    size_t ws_diff_states_layer_offset_;
    size_t ws_diff_states_iter_offset_;
    size_t ws_diff_states_iter_c_offset_;
    size_t ws_grid_comp_offset_;
    size_t scratch_gates_offset_;
    size_t scratch_ht_offset_;
    size_t scratch_diff_ht_offset_;
    size_t scratch_cell_offset_;
    rnn_postgemm_dispatcher<aprop, src_type, scratch_type, acc_type>
            *rnn_postgemm_;

    grid_execution_f grid_computation;
    cell_execution_f cell_func;
    merged_layer_execution_f merged_layer_func;

    bias_prepare_t bias_preparation_func;
    bias_finalize_t bias_finalization_func;
    weights_assign_t weights_layer_assign_func;
    weights_assign_t weights_iter_assign_func;
    weights_assign_t weights_projection_assign_func;

    gemm_t gemm_layer_func;
    gemm_t gemm_iter_func;
    gemm_t gemm_projection_func;
};

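// Concrete configurations exposed by this header: f32 and bf16 for both
// propagation kinds, plus int8 (u8 or s8 activations with s8 weights and s32
// accumulation) for forward only.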
using ref_rnn_fwd_f32_t = _ref_rnn_common_t<prop_kind::forward, data_type::f32,
        data_type::f32, data_type::f32>;
using ref_rnn_bwd_f32_t = _ref_rnn_common_t<prop_kind::backward,
        data_type::f32, data_type::f32, data_type::f32>;
using ref_rnn_fwd_bf16_t = _ref_rnn_common_t<prop_kind::forward,
        data_type::bf16, data_type::bf16, data_type::f32>;
using ref_rnn_bwd_bf16_t = _ref_rnn_common_t<prop_kind::backward,
        data_type::bf16, data_type::bf16, data_type::f32>;
using ref_rnn_fwd_u8s8_t = _ref_rnn_common_t<prop_kind::forward, data_type::u8,
        data_type::s8, data_type::s32>;
using ref_rnn_fwd_s8s8_t = _ref_rnn_common_t<prop_kind::forward, data_type::s8,
        data_type::s8, data_type::s32>;

} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s