1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17/*
18 General architecture
19
20 for diff states, we have n_states + 1 as we have n_states diff
21 to propagate to the previous iteration and 1 states to propagate
22 to the previous layer
23 index 0 is dh for cell(t-1, l) to consume // replaced by diff_src_iter
24 index 1 is dc for cell(t-1, l) to consume // replaced by diff_src_iter_c
25 index 2 is dh for cell(t, l-1) to consume // replace by diff_src_layer
26 this indexing enables to have the same indexing for states in elemwise
27 function
28 only the cell execution function should be impacted
29
30 */
31
32#include "common/dnnl_thread.hpp"
33#include "common/stream.hpp"
34
35#include "cpu/simple_q10n.hpp"
36
37#include "cpu/gemm/gemm.hpp"
38#include "cpu/gemm/gemm_pack.hpp"
39
40#include "cpu/rnn/ref_rnn.hpp"
41
42namespace dnnl {
43namespace impl {
44namespace cpu {
45
46using namespace dnnl::impl::utils;
47using namespace dnnl::impl::memory_tracking::names;
48using namespace rnn_utils;
49#define AOC array_offset_calculator
50
51// GEMM functions wrapper definitions
52
// Fallback for data-type combinations that have no plain (non-packed) GEMM
// implementation: trips an assert in debug builds and reports
// dnnl_unimplemented so the caller can fail gracefully in release builds.
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_gemm_sig(
        (_ref_rnn_common_t<aprop, src_type, weights_type, acc_type>::gemm)) {
    assert(!"non packed gemm is unavailable for this data type");
    return dnnl_unimplemented;
}
60
61template <>
62rnn_gemm_sig((ref_rnn_fwd_f32_t::gemm)) {
63 assert(ldA * ldB * ldC != 0);
64 return extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_,
65 &ldB, &beta, c_, &ldC, nullptr, pd()->rnn_.force_nocopy);
66}
67
68template <>
69rnn_gemm_sig((ref_rnn_bwd_f32_t::gemm)) {
70 assert(ldA * ldB * ldC != 0);
71 return extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_,
72 &ldB, &beta, c_, &ldC, nullptr, pd()->rnn_.force_nocopy);
73}
74
75template <>
76rnn_gemm_sig((ref_rnn_fwd_bf16_t::gemm)) {
77 assert(ldA * ldB * ldC != 0);
78 return gemm_bf16bf16f32(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_,
79 &ldB, &beta, c_, &ldC);
80}
81
82template <>
83rnn_gemm_sig((ref_rnn_bwd_bf16_t::gemm)) {
84 assert(ldA * ldB * ldC != 0);
85 return gemm_bf16bf16f32(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_,
86 &ldB, &beta, c_, &ldC);
87}
88
89// packed GEMM functions wrapper definitions
90
// Fallback for data-type combinations that have no packed GEMM
// implementation: trips an assert in debug builds and reports
// dnnl_unimplemented in release builds.
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::packed_gemm)) {
    assert(!"packed gemm is unavailable for this datatype");
    return dnnl_unimplemented;
}
98
99template <>
100rnn_gemm_sig(ref_rnn_fwd_f32_t::packed_gemm) {
101 assert(transA == 'N' && transB == 'N' && alpha == 1.);
102 return sgemm_compute(
103 "P", "N", &m, &n, &k, a_, &ldA, b_, &ldB, &beta, c_, &ldC);
104}
105
106template <>
107rnn_gemm_sig(ref_rnn_bwd_f32_t::packed_gemm) {
108 assert(transA == 'N' && transB == 'N' && alpha == 1.);
109 return sgemm_compute(
110 "P", "N", &m, &n, &k, a_, &ldA, b_, &ldB, &beta, c_, &ldC);
111}
112
113template <>
114rnn_gemm_sig((ref_rnn_fwd_bf16_t::packed_gemm)) {
115 assert(transA == 'N' && transB == 'N' && alpha == 1.);
116 return gemm_bf16bf16f32_compute(
117 "P", "N", &m, &n, &k, a_, &ldA, b_, &ldB, &beta, c_, &ldC);
118}
119
120template <>
121rnn_gemm_sig((ref_rnn_bwd_bf16_t::packed_gemm)) {
122 assert(transA == 'N' && transB == 'N' && alpha == 1.);
123 return gemm_bf16bf16f32_compute(
124 "P", "N", &m, &n, &k, a_, &ldA, b_, &ldB, &beta, c_, &ldC);
125}
126
127template <>
128rnn_gemm_sig(ref_rnn_fwd_u8s8_t::packed_gemm) {
129 assert(transA == 'N' && transB == 'N' && alpha == 1.);
130 int32_t offsetc = 0;
131 return gemm_s8u8s32_compute("P", "N", "F", &m, &n, &k, a_, &ldA, b_, &ldB,
132 &beta, c_, &ldC, &offsetc);
133}
134
135template <>
136rnn_gemm_sig(ref_rnn_fwd_s8s8_t::packed_gemm) {
137 assert(transA == 'N' && transB == 'N' && alpha == 1.);
138 int32_t offsetc = 0;
139 return gemm_s8s8s32_compute("P", "N", "F", &m, &n, &k, a_, &ldA, b_, &ldB,
140 &beta, c_, &ldC, &offsetc);
141}
142
143//*************** Grid computations strategy: linear ***************//
// Runs the whole RNN grid (layers x directions x time steps) sequentially.
// Forward walks layers and iterations in increasing order, backward in
// decreasing order.  When merge_gemm_layer is enabled, the layer GEMM is
// hoisted out of the iteration loop (before it for forward, after it for
// backward); similarly merge_gemm_iter batches the backward iter-weights
// GEMM across time steps.
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_grid_execution_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::linear_execution)) {
    // Workspace views.  The "+ 1" on the layer/iter axes leaves a slot for
    // the initial states (src_*) next to each cell's outputs.
    const AOC<src_layer_t, 4> ws_states_layer(ws_states_layer_, rnn.n_layer + 1,
            rnn.n_dir, rnn.n_iter + 1,
            rnn.ws_states_layer_nld * rnn.ws_states_layer_ld);
    const AOC<const src_layer_t, 3> augru_attention(
            augru_attention_, rnn.n_iter, rnn.mb, 1);
    const AOC<src_iter_t, 4> ws_states_iter(ws_states_iter_, rnn.n_layer + 1,
            rnn.n_dir, rnn.n_iter + 1,
            rnn.ws_states_iter_nld * rnn.ws_states_iter_ld);
    // The c-state workspace is accessed through a raw (byte-based) AOC since
    // its data type (f32/bf16) is only known at runtime.
    // NOTE(review): the stride below uses ws_diff_states_iter_c_{nld,ld} for
    // the *states* workspace — presumably these match the ws_states_iter_c
    // dimensions; confirm against rnn_utils workspace layout.
    const auto ws_states_iter_c = rnn_utils::make_raw_aoc(ws_states_iter_c_,
            types::data_type_size(rnn.src_iter_c_dt), rnn.n_layer + 1,
            rnn.n_dir, rnn.n_iter + 1,
            rnn.ws_diff_states_iter_c_nld * rnn.ws_diff_states_iter_c_ld);
    const AOC<gemm_acc_t, 4> ws_diff_states_layer(ws_diff_states_layer_,
            rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1,
            rnn.ws_diff_states_layer_nld * rnn.ws_diff_states_layer_ld);
    const AOC<gemm_acc_t, 3> diff_augru_attention(
            diff_augru_attention_, rnn.n_iter, rnn.mb, 1);
    const AOC<gemm_acc_t, 4> ws_diff_states_iter(ws_diff_states_iter_,
            rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1,
            rnn.ws_diff_states_iter_nld * rnn.ws_diff_states_iter_ld);
    const AOC<gemm_acc_t, 4> ws_diff_states_iter_c(ws_diff_states_iter_c_,
            rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1,
            rnn.ws_diff_states_iter_c_nld * rnn.ws_diff_states_iter_c_ld);
    const AOC<gates_t, 4> ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir,
            rnn.n_iter, rnn.ws_gates_nld * rnn.ws_gates_ld);
    const AOC<dst_iter_t, 4> ws_ht(ws_ht_, rnn.n_layer, rnn.n_dir, rnn.n_iter,
            rnn.ws_ht_nld * rnn.ws_ht_ld);
    // Weights are stored as per-(layer, dir) arrays of part pointers.
    const AOC<weights_t *, 3> weights_layer(
            weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer);
    const AOC<weights_t *, 3> weights_iter(
            weights_iter_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter);
    const AOC<weights_t *, 2> weights_projection(
            weights_projection_, rnn.n_layer, rnn.n_dir);
    const AOC<const float, 3> weights_peephole(
            weights_peephole_, rnn.n_layer, rnn.n_dir, 3 * rnn.dhc);
    bias_linear_exec_aoc_t bias(rnn, bias_);
    const AOC<gemm_acc_t, 3> diff_weights_layer(diff_weights_layer_,
            rnn.n_layer, rnn.n_dir,
            rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld);
    const AOC<gemm_acc_t, 3> diff_weights_iter(diff_weights_iter_, rnn.n_layer,
            rnn.n_dir, rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld);
    const AOC<float, 3> diff_weights_peephole(
            diff_weights_peephole_, rnn.n_layer, rnn.n_dir, 3 * rnn.dhc);
    const AOC<float, 3> diff_weights_projection(diff_weights_projection_,
            rnn.n_layer, rnn.n_dir,
            rnn.diff_weights_projection_nld * rnn.diff_weights_projection_ld);
    const AOC<float, 3> diff_bias(
            diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dhc);
    const AOC<gates_t, 4> ws_grid(
            ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell);

    /* Raw inputs/outputs coming from the user */
    // Here we cannot use AOC as user's input can have arbitrary strides, so we use desc_wrapper.
    const auto src_layer_mdw = memory_desc_wrapper(pd()->src_md(0));
    const auto dst_layer_mdw = memory_desc_wrapper(pd()->dst_md(0));
    const auto src_iter_mdw = memory_desc_wrapper(pd()->src_md(1));
    const auto dst_iter_mdw = memory_desc_wrapper(pd()->dst_md(1));
    const auto src_iter_c_mdw = memory_desc_wrapper(pd()->src_md(2));
    const auto dst_iter_c_mdw = memory_desc_wrapper(pd()->dst_md(2));

// Since the function FN(...) returns by reference so an extra exception
// has to be made for nullptr argument
#define SAFE_PTR(FN, ...) CONCAT2(FN, _) ? &(FN(__VA_ARGS__)) : nullptr
    // Runs the merged layer GEMM once per (layer, dir) when merge_gemm_layer
    // is enabled and the current propagation kind matches target_prop;
    // otherwise it is a no-op.  Called before the iteration loop for
    // forward and after it for backward.
    const auto compute_merged_layer_part_if_applicable
            = [&](prop_kind_t target_prop, int dir, int lay) {
                if (IMPLICATION(rnn.merge_gemm_layer, aprop != target_prop))
                    return dnnl_success;

            cell_position_t cell_position = middle_cell;
            if (lay == 0) cell_position |= first_layer;

            // For the first layer the input may be read straight from the
            // user's src_layer buffer when the copy is skipped.
            const src_layer_t *src_layer
                    = lay == 0 && rnn.skip_src_layer_copy()
                    ? src_layer_
                    : SAFE_PTR(ws_states_layer, lay, dir, 1, 0);
#if DNNL_X64
            CHECK((this->*merged_layer_func)(ctx, rnn, cell_position,
                    SAFE_PTR(weights_layer, lay, dir, 0), src_layer,
                    scratch_gates_,
                    SAFE_PTR(ws_diff_states_layer, lay, dir, 0, 0),
                    SAFE_PTR(diff_weights_layer, lay, dir, 0),
                    amx_scratchpad, addr_batch_global));
#else
            CHECK((this->*merged_layer_func)(rnn, cell_position,
                    SAFE_PTR(weights_layer, lay, dir, 0), src_layer,
                    scratch_gates_,
                    SAFE_PTR(ws_diff_states_layer, lay, dir, 0, 0),
                    SAFE_PTR(diff_weights_layer, lay, dir, 0)));
#endif
            return dnnl_success;
        };

    // We run the grid of computation
    for_(int dir = 0; dir < rnn.n_dir; dir++)
    for (int j = 0; j < rnn.n_layer; j++) {
        // Layers are visited bottom-up for forward, top-down for backward.
        const int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1;

        CHECK(compute_merged_layer_part_if_applicable(
                prop_kind::forward, dir, lay));

        // TODO: enable merging projection gemm in bwd lstm projection

        for (int i = 0; i < rnn.n_iter; i++) {
            // Time steps run forward or backward depending on propagation.
            const int iter
                    = (aprop == prop_kind::forward) ? i : rnn.n_iter - i - 1;

            // We set parameters to the cell execution call

            // dst_layer is equal to dst_iter. To avoid
            // duplication of memory access we hence use only
            // dst_layer and set dst_iter to nullptr, unless we
            // cannot for one of the following condition:
            // - in the last layer and last iteration, we need to
            //   copy ht in two tensors (dst_layer and dst_iter)
            dst_layer_t *cell_dst_layer
                    = &(ws_states_layer(lay + 1, dir, iter + 1, 0));
            dst_iter_t *cell_dst_iter = nullptr;
            const src_layer_t *cell_src_layer
                    = &(ws_states_layer(lay, dir, iter + 1, 0));
            const src_iter_t *cell_src_iter
                    = &(ws_states_iter(lay + 1, dir, iter, 0));

            void *cell_dst_iter_c = const_cast<void *>(
                    ws_states_iter_c(lay + 1, dir, iter + 1, 0));
            const void *cell_src_iter_c
                    = ws_states_iter_c(lay + 1, dir, iter, 0);

            // the cell_position is used only when skip_data_copy is
            // supported currently supported only for forward
            cell_position_t cell_position = middle_cell;
            if (iter == 0) cell_position |= first_iter;
            if (lay == 0) cell_position |= first_layer;
            if (iter == rnn.n_iter - 1) cell_position |= last_iter;
            if (lay == rnn.n_layer - 1) cell_position |= last_layer;

            // The dst_* paths should be before the src_* paths as
            // the later will override cell_src_layer and
            // cell_src_iter appropriately for 1st layer and 1st
            // iter.
            const bool last_iter_skip_copy
                    = rnn.skip_dst_iter_copy() && (cell_position & last_iter);
            if (last_iter_skip_copy) {
                // Write the last iteration directly into the user's
                // dst_iter buffer instead of the workspace.
                cell_dst_layer = dst_iter_ + dst_iter_mdw.off(lay, dir, 0, 0);
                cell_src_layer
                        = dst_iter_ + dst_iter_mdw.off(lay - 1, dir, 0, 0);
            }

            if (rnn.skip_dst_layer_copy() && (cell_position & last_layer)) {
                // Note: for last layer and last iter, the output is in dst_layer
                // and still need to be copied to dst_iter
                cell_dst_layer = dst_layer_ + dst_layer_mdw.off(iter, 0, 0);
                cell_dst_iter = last_iter_skip_copy
                        ? dst_iter_ + dst_iter_mdw.off(lay, dir, 0, 0)
                        : nullptr;
                cell_src_iter = (iter != 0)
                        ? dst_layer_ + dst_layer_mdw.off(iter - 1, 0, 0)
                        : cell_src_iter;
            }
            if (rnn.skip_src_iter_copy() && (cell_position & first_iter))
                cell_src_iter = src_iter_ + src_iter_mdw.off(lay, dir, 0, 0);

            if (rnn.skip_src_layer_copy() && (cell_position & first_layer))
                cell_src_layer = src_layer_ + src_layer_mdw.off(iter, 0, 0);

            // because the c state is always f32 and require no
            // conversion, we can always skip to copy for the 1st
            // and last iteration
            if (iter == 0 && src_iter_c_) {
                cell_src_iter_c = inc_ptr(src_iter_c_, rnn.src_iter_c_dt,
                        src_iter_c_mdw.off(lay, dir, 0, 0));
                cell_position |= c_state_first_iter;
            }
            if (iter == rnn.n_iter - 1 && dst_iter_c_) {
                cell_dst_iter_c = inc_ptr(dst_iter_c_, rnn.dst_iter_c_dt,
                        dst_iter_c_mdw.off(lay, dir, 0, 0));
                cell_position |= c_state_last_iter;
            }
            // When scratch gates are not shared across iterations, each
            // time step gets its own slice.
            const size_t sg_start_idx = rnn.n_iter_scratch_gates == 1
                    ? static_cast<size_t>(0)
                    : static_cast<size_t>(iter) * rnn.scratch_gates_nld
                            * rnn.scratch_gates_ld;
            const auto cell_scratch_gates = &scratch_gates_[sg_start_idx];

            // LSTM projection: training keeps ht in the workspace (needed
            // by backward), inference reuses a scratch buffer.
            dst_iter_t *proj_ht = nullptr;
            if (rnn.is_lstm_projection) {
                if (rnn.is_training)
                    proj_ht = &(ws_ht(lay, dir, iter, 0));
                else
                    proj_ht = scratch_ht_;
            }

#if DNNL_X64
            CHECK((this->*cell_func)(ctx, rnn, cell_position, cell_dst_layer,
                    cell_dst_iter_c,
                    SAFE_PTR(ws_diff_states_layer, lay, dir, iter, 0),
                    SAFE_PTR(diff_augru_attention, iter, 0, 0),
                    SAFE_PTR(ws_diff_states_iter, lay, dir, iter, 0),
                    SAFE_PTR(ws_diff_states_iter_c, lay, dir, iter, 0),
                    SAFE_PTR(weights_layer, lay, dir, 0),
                    SAFE_PTR(weights_iter, lay, dir, 0),
                    SAFE_PTR(weights_projection, lay, dir),
                    SAFE_PTR(weights_peephole, lay, dir, 0),
                    w_proj_comp ? w_proj_comp + (j * rnn.n_dir + dir) * rnn.dic
                                : nullptr,
                    bias(lay, dir), cell_src_layer,
                    SAFE_PTR(augru_attention, iter, 0, 0), cell_src_iter,
                    cell_src_iter_c,
                    SAFE_PTR(ws_diff_states_layer, lay + 1, dir, iter, 0),
                    SAFE_PTR(ws_diff_states_iter, lay, dir, iter + 1, 0),
                    SAFE_PTR(ws_diff_states_iter_c, lay, dir, iter + 1, 0),
                    SAFE_PTR(diff_weights_layer, lay, dir, 0),
                    SAFE_PTR(diff_weights_iter, lay, dir, 0),
                    SAFE_PTR(diff_weights_projection, lay, dir, 0),
                    SAFE_PTR(diff_weights_peephole, lay, dir, 0),
                    SAFE_PTR(diff_bias, lay, dir, 0),
                    SAFE_PTR(ws_gates, lay, dir, iter, 0), cell_scratch_gates,
                    proj_ht, scratch_diff_ht_,
                    SAFE_PTR(ws_grid, lay, dir, iter, 0), scratch_cell_,
                    scratch_gates_blocked_, scratch_src_layer_,
                    scratch_src_iter_, cell_dst_iter, amx_scratchpad,
                    addr_batch_global));
#else
            CHECK((this->*cell_func)(rnn, cell_position, cell_dst_layer,
                    cell_dst_iter_c,
                    SAFE_PTR(ws_diff_states_layer, lay, dir, iter, 0),
                    SAFE_PTR(diff_augru_attention, iter, 0, 0),
                    SAFE_PTR(ws_diff_states_iter, lay, dir, iter, 0),
                    SAFE_PTR(ws_diff_states_iter_c, lay, dir, iter, 0),
                    SAFE_PTR(weights_layer, lay, dir, 0),
                    SAFE_PTR(weights_iter, lay, dir, 0),
                    SAFE_PTR(weights_projection, lay, dir),
                    SAFE_PTR(weights_peephole, lay, dir, 0),
                    w_proj_comp ? w_proj_comp + (j * rnn.n_dir + dir) * rnn.dic
                                : nullptr,
                    bias(lay, dir), cell_src_layer,
                    SAFE_PTR(augru_attention, iter, 0, 0), cell_src_iter,
                    cell_src_iter_c,
                    SAFE_PTR(ws_diff_states_layer, lay + 1, dir, iter, 0),
                    SAFE_PTR(ws_diff_states_iter, lay, dir, iter + 1, 0),
                    SAFE_PTR(ws_diff_states_iter_c, lay, dir, iter + 1, 0),
                    SAFE_PTR(diff_weights_layer, lay, dir, 0),
                    SAFE_PTR(diff_weights_iter, lay, dir, 0),
                    SAFE_PTR(diff_weights_projection, lay, dir, 0),
                    SAFE_PTR(diff_weights_peephole, lay, dir, 0),
                    SAFE_PTR(diff_bias, lay, dir, 0),
                    SAFE_PTR(ws_gates, lay, dir, iter, 0), cell_scratch_gates,
                    proj_ht, scratch_diff_ht_,
                    SAFE_PTR(ws_grid, lay, dir, iter, 0), scratch_cell_,
                    cell_dst_iter, amx_scratchpad));
#endif
        }

        CHECK(compute_merged_layer_part_if_applicable(
                prop_kind::backward, dir, lay));
#undef SAFE_PTR

        if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) {
            // This is split in 3 pieces if we skip copies.
            // last iter in user mem, middle iters in ws, first iter in user mem
            // Note 1: here we assume no change in datatypes for src_iter, ws_iter and dst_iter

            const dst_iter_t *states_iter = nullptr;
            int states_iter_ld = 0;
            int niter_merge_gemm_iter = 0;

            // By default the h states live in the workspace; if the first
            // iteration was skipped, start from slot 1.
            states_iter = &(
                    ws_states_iter(lay + 1, dir, rnn.skip_src_iter_copy(), 0));
            states_iter_ld = rnn.ws_states_iter_ld;
            if (rnn.skip_dst_layer_copy()
                    && (lay == rnn.n_layer - 1)) { // last layer
                states_iter = dst_layer_;
                states_iter_ld = rnn.dst_layer_ld_;
            }
            niter_merge_gemm_iter = rnn.n_iter - rnn.skip_src_iter_copy();
            if (niter_merge_gemm_iter > 0) {
                // Accumulate diff_weights_iter across all merged iterations.
                CHECK(gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.sic,
                        rnn.mb * niter_merge_gemm_iter, 1.0,
                        (weights_t *)scratch_gates_
                                + rnn.skip_src_iter_copy()
                                        * rnn.scratch_gates_nld
                                        * rnn.scratch_gates_ld,
                        rnn.scratch_gates_ld, states_iter, states_iter_ld, 1.0,
                        &(diff_weights_iter(lay, dir, 0)),
                        rnn.diff_weights_iter_ld));
            }

            if (rnn.skip_src_iter_copy()) {
                // First iteration reads straight from the user's src_iter.
                states_iter = src_iter_ + src_iter_mdw.off(lay, dir, 0, 0);
                states_iter_ld = rnn.src_iter_ld_;
                niter_merge_gemm_iter = 1;
                CHECK(gemm('N', 'T', rnn.n_gates * rnn.dhc, rnn.sic,
                        rnn.mb * niter_merge_gemm_iter, 1.0,
                        (weights_t *)scratch_gates_, rnn.scratch_gates_ld,
                        states_iter, states_iter_ld, 1.0,
                        &(diff_weights_iter(lay, dir, 0)),
                        rnn.diff_weights_iter_ld));
            }
        }
    }
    return dnnl_success;
}
449
450//********* GRID computations strategy: utility functions **********//
451
452// for bf32 src_data_t(bf16) and input_data_t(f32) types can be different.
453template <typename src_data_t, typename input_data_t>
454void copy_init_layer_fwd_template(const rnn_conf_t &rnn,
455 src_data_t *__restrict ws_states_layer_,
456 const input_data_t *__restrict xt_, const memory_desc_wrapper &xt_d) {
457
458 const AOC<src_data_t, 4> ws_states_layer(ws_states_layer_, rnn.n_dir,
459 rnn.n_iter + 1, rnn.mb, rnn.ws_states_layer_ld);
460
461 parallel_nd(rnn.n_iter, rnn.mb, [&](dim_t it, dim_t b) {
462 auto xxt = xt_ + xt_d.blk_off(it, b);
463 src_data_t *ws_l2r_ptr = &(ws_states_layer(0, it + 1, b, 0));
464 src_data_t *ws_r2l_ptr
465 = &(ws_states_layer(rnn.n_dir - 1, rnn.n_iter - it, b, 0));
466 if (rnn.exec_dir != r2l) {
467 if (rnn.is_bf32()) {
468 cvt_float_to_bfloat16(
469 (bfloat16_t *)ws_l2r_ptr, (const float *)xxt, rnn.slc);
470 } else {
471 PRAGMA_OMP_SIMD()
472 for (int c = 0; c < rnn.slc; c++)
473 ws_l2r_ptr[c] = xxt[c];
474 }
475 }
476 if (rnn.exec_dir != l2r) {
477 if (rnn.is_bf32()) {
478 cvt_float_to_bfloat16(
479 (bfloat16_t *)ws_r2l_ptr, (const float *)xxt, rnn.slc);
480 } else {
481 PRAGMA_OMP_SIMD()
482 for (int c = 0; c < rnn.slc; c++)
483 ws_r2l_ptr[c] = xxt[c];
484 }
485 }
486 });
487}
488
489template <typename acc_data_t>
490void copy_init_layer_bwd_template(const rnn_conf_t &rnn,
491 acc_data_t *ws_diff_states_layer_, const acc_data_t *diff_dst_layer_,
492 const memory_desc_wrapper &diff_dst_layer_d) {
493 const AOC<acc_data_t, 5> ws_diff_states_layer(ws_diff_states_layer_,
494 rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb,
495 rnn.ws_diff_states_layer_ld);
496
497 switch (rnn.exec_dir) {
498 case bi_concat:
499 parallel_nd(rnn.n_iter, rnn.mb, [&](dim_t it, dim_t b) {
500 const auto diff_dst_layer_x
501 = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
502 for (int s = 0; s < rnn.dlc; s++) {
503 ws_diff_states_layer(rnn.n_layer, 0, it, b, s)
504 = diff_dst_layer_x[s];
505 ws_diff_states_layer(
506 rnn.n_layer, 1, rnn.n_iter - it - 1, b, s)
507 = diff_dst_layer_x[rnn.dlc + s];
508 }
509 });
510 break;
511 case bi_sum:
512 parallel_nd(rnn.n_iter, rnn.mb, [&](dim_t it, dim_t b) {
513 const auto diff_dst_layer_x
514 = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
515 for (int s = 0; s < rnn.dlc; s++) {
516 ws_diff_states_layer(rnn.n_layer, 0, it, b, s)
517 = diff_dst_layer_x[s];
518 ws_diff_states_layer(
519 rnn.n_layer, 1, rnn.n_iter - it - 1, b, s)
520 = diff_dst_layer_x[s];
521 }
522 });
523 break;
524 case l2r:
525 parallel_nd(rnn.n_iter, rnn.mb, [&](dim_t it, dim_t b) {
526 const auto diff_dst_layer_x
527 = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
528 for (int s = 0; s < rnn.dlc; s++) {
529 ws_diff_states_layer(rnn.n_layer, 0, it, b, s)
530 = diff_dst_layer_x[s];
531 }
532 });
533 break;
534 case r2l:
535 parallel_nd(rnn.n_iter, rnn.mb, [&](dim_t it, dim_t b) {
536 const auto diff_dst_layer_x = diff_dst_layer_
537 + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b);
538 for (int s = 0; s < rnn.dlc; s++) {
539 ws_diff_states_layer(rnn.n_layer, 0, it, b, s)
540 = diff_dst_layer_x[s];
541 }
542 });
543 break;
544 default: assert(!"Unsupported direction"); break;
545 }
546}
547
// Instantiates copy_init_layer() for forward configurations: only the
// activations (xt_) are staged into the workspace; the diff arguments are
// part of the shared signature and are unused on this path.
#define RNN_DECL_COPY_INIT_LAYER_FWD(cname) \
    template <> \
    template <typename input_data_t> \
    void cname::copy_init_layer(const rnn_conf_t &rnn, \
            src_layer_t *ws_states_layer_, gemm_acc_t *ws_diff_states_layer_, \
            const input_data_t *xt_, const gemm_acc_t *diff_dst_layer_) \
            const { \
        copy_init_layer_fwd_template(rnn, ws_states_layer_, xt_, \
                memory_desc_wrapper(pd()->src_md(0))); \
    }

RNN_DECL_COPY_INIT_LAYER_FWD(ref_rnn_fwd_f32_t)
RNN_DECL_COPY_INIT_LAYER_FWD(ref_rnn_fwd_bf16_t)
RNN_DECL_COPY_INIT_LAYER_FWD(ref_rnn_fwd_u8s8_t)
RNN_DECL_COPY_INIT_LAYER_FWD(ref_rnn_fwd_s8s8_t)
563
// Instantiates copy_init_layer() for backward configurations: only the
// output diffs (diff_dst_layer_) are staged; the forward-only arguments
// are part of the shared signature and are unused on this path.
#define RNN_DECL_COPY_INIT_LAYER_BWD(cname) \
    template <> \
    template <typename input_data_t> \
    void cname::copy_init_layer(const rnn_conf_t &rnn, \
            src_layer_t *ws_states_layer_, gemm_acc_t *ws_diff_states_layer_, \
            const input_data_t *xt_, const gemm_acc_t *diff_dst_layer_) \
            const { \
        copy_init_layer_bwd_template(rnn, ws_diff_states_layer_, \
                diff_dst_layer_, memory_desc_wrapper(pd()->diff_dst_md(0))); \
    }

RNN_DECL_COPY_INIT_LAYER_BWD(ref_rnn_bwd_f32_t)
RNN_DECL_COPY_INIT_LAYER_BWD(ref_rnn_bwd_bf16_t)
577
578/* For int8 configuration, input iteration states may be of types f32 or u8
579 * Internally h_state is always stored in u8 and c_state is always stored in f32
580 * If input states are of type u8 then h state is copied and c state is dequantized
581 * If input states are of type f32 then h state is quantized and c_state is copied
582 * */
// Stages the user's initial iteration states (h and, for LSTM, c) into the
// workspace for the forward pass, quantizing f32 h-states for int8
// configurations.  Missing src_iter_ means the initial state is zero.
template <typename src_data_t, typename input_data_t>
void copy_init_iter_fwd_template(const rnn_conf_t &rnn, const rnn_pd_t *pd,
        src_data_t *__restrict ws_states_iter_,
        void *__restrict ws_states_iter_c_,
        const input_data_t *__restrict src_iter_,
        const memory_desc_wrapper &src_iter_d,
        const void *__restrict src_iter_c_,
        const memory_desc_wrapper &src_iter_c_d) {
    const AOC<src_data_t, 5> ws_states_iter(ws_states_iter_, rnn.n_layer + 1,
            rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.ws_states_iter_ld);
    // c-state workspace is typed at runtime (f32/bf16), hence the raw AOC.
    const auto ws_states_iter_c_aoc = rnn_utils::make_raw_aoc(ws_states_iter_c_,
            types::data_type_size(rnn.src_iter_c_dt), rnn.n_layer + 1,
            rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.ws_states_iter_c_ld);

    const float data_shift = pd->attr()->rnn_data_qparams_.shift_;
    const float data_scale = pd->attr()->rnn_data_qparams_.scale_;

    // Quantize only when the int8 config is given f32 iteration states
    // (or no states at all).
    const bool quantize = rnn.is_int8_conf()
            && IMPLICATION(pd->with_src_iter(),
                    pd->src_md(1)->data_type == data_type::f32);
    const auto maybe_q = [&](input_data_t f) {
        if (quantize) {
            float qf = f * data_scale + data_shift;
            return qz_a1b0<float, src_data_t>()(qf);
        } else
            return (src_data_t)f;
    };
    const src_data_t zero = maybe_q(0.f);
    // Zero-fills one c-state element through the type-erased workspace
    // pointer, dispatching on the runtime c-state data type.
    const auto zero_ws_iter_c = [&](int lay, int dir, int mb_id, int sic_id) {
        void *ws_states_iter_c = const_cast<void *>(
                ws_states_iter_c_aoc(lay, dir, 0, mb_id, sic_id));
        if (rnn.src_iter_c_dt == data_type::f32)
            *(static_cast<float *>(ws_states_iter_c)) = 0.0f;
        else if (rnn.src_iter_c_dt == data_type::bf16)
            *(static_cast<bfloat16_t *>(ws_states_iter_c)) = 0.0f;
    };

    if (src_iter_) {
        // Copy (and possibly quantize) the user's h states into iter slot 0.
        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb,
                [&](dim_t lay, dim_t dir, dim_t b) {
                    const auto *ss
                            = &src_iter_[src_iter_d.blk_off(lay, dir, b, 0)];
                    auto *dd = &ws_states_iter(lay + 1, dir, 0, b, 0);
                    PRAGMA_OMP_SIMD()
                    for (int s = 0; s < rnn.sic; s++)
                        dd[s] = maybe_q(ss[s]);
                });
    } else {
        // No user state: zero-initialize h (quantized zero) and, for LSTM,
        // the c state as well.
        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb,
                [&](dim_t lay, dim_t dir, dim_t b) {
                    for (int j = 0; j < rnn.sic; j++)
                        ws_states_iter(lay + 1, dir, 0, b, j) = zero;
                    if (pd->cell_kind() == alg_kind::vanilla_lstm)
                        for (int j = 0; j < rnn.dhc; j++)
                            zero_ws_iter_c(lay + 1, dir, b, j);
                });
    }
}
641
642template <typename acc_data_t>
643void copy_init_iter_bwd_template(const rnn_conf_t &rnn, const rnn_pd_t *pd,
644 acc_data_t *ws_diff_states_iter_, acc_data_t *ws_diff_states_iter_c_,
645 const acc_data_t *diff_dst_iter_,
646 const memory_desc_wrapper diff_dst_iter_d,
647 const float *diff_dst_iter_c_,
648 const memory_desc_wrapper diff_dst_iter_c_d) {
649 const AOC<acc_data_t, 5> ws_diff_states_iter(ws_diff_states_iter_,
650 rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb,
651 rnn.ws_diff_states_iter_ld);
652 const AOC<acc_data_t, 5> ws_diff_states_iter_c(ws_diff_states_iter_c_,
653 rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb,
654 rnn.ws_diff_states_iter_c_ld);
655 if (diff_dst_iter_) {
656 parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb,
657 [&](dim_t lay, dim_t dir, dim_t b) {
658 array_copy(
659 &(ws_diff_states_iter(lay, dir, rnn.n_iter, b, 0)),
660 diff_dst_iter_
661 + diff_dst_iter_d.blk_off(lay, dir, b),
662 rnn.dic);
663 if (pd->cell_kind() == alg_kind::vanilla_lstm)
664 array_copy(&(ws_diff_states_iter_c(
665 lay, dir, rnn.n_iter, b, 0)),
666 diff_dst_iter_c_
667 + diff_dst_iter_c_d.blk_off(
668 lay, dir, b),
669 rnn.dhc);
670 });
671 } else {
672 parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb,
673 [&](dim_t lay, dim_t dir, dim_t i) {
674 for (int j = 0; j < rnn.dic; j++)
675 ws_diff_states_iter(lay, dir, rnn.n_iter, i, j) = 0.0f;
676 if (pd->cell_kind() == alg_kind::vanilla_lstm)
677 for (int j = 0; j < rnn.dhc; j++)
678 ws_diff_states_iter_c(lay, dir, rnn.n_iter, i, j)
679 = 0.0f;
680 });
681 }
682}
683
// Instantiates copy_init_iter() for forward configurations: stages the
// user's initial h/c states into the workspace; the diff arguments are
// part of the shared signature and are unused on this path.
#define RNN_DECL_COPY_INIT_ITER_FWD(cname) \
    template <> \
    template <typename input_data_t> \
    void cname::copy_init_iter(const rnn_conf_t &rnn, \
            src_layer_t *__restrict ws_states_iter_, \
            void *__restrict ws_states_iter_c_, \
            gemm_acc_t *__restrict ws_diff_states_iter_, \
            gemm_acc_t *__restrict ws_diff_states_iter_c_, \
            const input_data_t *__restrict src_iter_, \
            const void *__restrict src_iter_c_, \
            const gemm_acc_t *__restrict diff_dst_iter_, \
            const float *__restrict diff_dst_iter_c_) const { \
        auto src_iter_d = memory_desc_wrapper(pd()->src_md(1)); \
        auto src_iter_c_d = memory_desc_wrapper(pd()->src_md(2)); \
        copy_init_iter_fwd_template(rnn, pd(), ws_states_iter_, \
                ws_states_iter_c_, src_iter_, src_iter_d, src_iter_c_, \
                src_iter_c_d); \
    }

RNN_DECL_COPY_INIT_ITER_FWD(ref_rnn_fwd_f32_t)
RNN_DECL_COPY_INIT_ITER_FWD(ref_rnn_fwd_bf16_t)
RNN_DECL_COPY_INIT_ITER_FWD(ref_rnn_fwd_u8s8_t)
RNN_DECL_COPY_INIT_ITER_FWD(ref_rnn_fwd_s8s8_t)
707
// Instantiates copy_init_iter() for backward configurations: stages the
// user's diff_dst_iter(_c) into the diff workspaces; the forward-only
// arguments are part of the shared signature and are unused on this path.
#define RNN_DECL_COPY_INIT_ITER_BWD(cname) \
    template <> \
    template <typename input_data_t> \
    void cname::copy_init_iter(const rnn_conf_t &rnn, \
            src_layer_t *ws_states_iter_, void *ws_states_iter_c_, \
            gemm_acc_t *ws_diff_states_iter_, \
            gemm_acc_t *ws_diff_states_iter_c_, const input_data_t *src_iter_, \
            const void *src_iter_c_, const gemm_acc_t *diff_dst_iter_, \
            const float *diff_dst_iter_c_) const { \
        auto diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_md(1)); \
        auto diff_dst_iter_c_d = memory_desc_wrapper(pd()->diff_dst_md(2)); \
        copy_init_iter_bwd_template(rnn, pd(), ws_diff_states_iter_, \
                ws_diff_states_iter_c_, diff_dst_iter_, diff_dst_iter_d, \
                diff_dst_iter_c_, diff_dst_iter_c_d); \
    }

RNN_DECL_COPY_INIT_ITER_BWD(ref_rnn_bwd_f32_t)
RNN_DECL_COPY_INIT_ITER_BWD(ref_rnn_bwd_bf16_t)
726
// Copies the top layer's outputs from the workspace into the user's
// dst_layer, handling direction layout (concat/sum/l2r/r2l) and int8
// dequantization.  When skip_dst_iter_copy() the last time step already
// lives in dst_iter and is copied from there instead.
template <typename src_data_t, typename dst_layer_dt, typename dst_iter_dt>
void copy_res_layer_fwd_template(const rnn_conf_t &rnn, const rnn_pd_t *pd,
        dst_layer_dt *dst_layer_, memory_desc_wrapper &dst_layer_d,
        const dst_iter_dt *dst_iter_, const memory_desc_wrapper &dst_iter_d,
        const src_data_t *ws_states_layer_) {

    const AOC<const src_data_t, 5> ws_states_layer(ws_states_layer_,
            rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb,
            rnn.ws_states_layer_ld);
    const float shift = (pd->attr()->rnn_data_qparams_.shift_);
    const float scale = (pd->attr()->rnn_data_qparams_.scale_);

    // int8 config with f32 destination requires dequantization; for bi_sum
    // it happens in acc_vec (after summation), otherwise at copy time.
    const bool dequantize
            = pd->dst_md(0)->data_type == data_type::f32 && rnn.is_int8_conf();
    const bool dequantize_at_copy = dequantize && rnn.exec_dir != bi_sum;

    // minor optimization helper for a compiler
    static constexpr bool rnn_u8u8_case
            = std::is_same<dst_layer_dt, uint8_t>::value
            && std::is_same<src_data_t, uint8_t>::value;
    static constexpr bool rnn_s8s8_case
            = std::is_same<dst_layer_dt, int8_t>::value
            && std::is_same<src_data_t, int8_t>::value;

    // Straight copy of dlc channels, optionally dequantizing each value.
    const auto copy_vec = [&](dst_layer_dt *dd, const src_data_t *ss) {
        if (dequantize_at_copy) {
            PRAGMA_OMP_SIMD()
            for (int s = 0; s < rnn.dlc; s++)
                dd[s] = (dst_layer_dt)(((float)ss[s] - shift) / scale);
        } else {
            PRAGMA_OMP_SIMD()
            for (int s = 0; s < rnn.dlc; s++)
                dd[s] = (dst_layer_dt)ss[s];
        }
    };

    // Accumulating copy used for bi_sum: dd += ss with the appropriate
    // saturation/dequantization for the data-type combination.
    const auto acc_vec = [&](dst_layer_dt *dd, const src_data_t *ss) {
        if (dequantize) {
            PRAGMA_OMP_SIMD()
            for (int s = 0; s < rnn.dlc; s++) {
                float val = (float)ss[s] + dd[s];
                val = qz_a1b0<float, src_data_t>()(val);
                // each summand carries one quantization shift, hence 2x
                dd[s] = (dst_layer_dt)((val - 2 * shift) / scale);
            }
        } else if (rnn_u8u8_case
                || rnn_s8s8_case) { // instead of checking for rnn.is_int8()
            PRAGMA_OMP_SIMD()
            for (int s = 0; s < rnn.dlc; s++)
                dd[s] = saturate<dst_layer_dt, int16_t>(
                        (int16_t)dd[s] + (int16_t)ss[s]);
        } else {
            PRAGMA_OMP_SIMD()
            for (int s = 0; s < rnn.dlc; s++)
                dd[s] += (dst_layer_dt)ss[s];
        }
    };

    // if skip_dst_iter_copy, then the data for the last iteration is
    // in dst_iter, not in workspace
    parallel_nd(rnn.n_iter - (rnn.skip_dst_iter_copy() ? 1 : 0), rnn.mb,
            [&](dim_t it, dim_t b) {
                int dir = 0;
                if (rnn.exec_dir != r2l) {
                    const auto *ss
                            = &ws_states_layer(rnn.n_layer, dir, it + 1, b, 0);
                    auto *dd = &dst_layer_[dst_layer_d.blk_off(
                            it, b, dir * rnn.dlc)];
                    copy_vec(dd, ss);
                    dir = 1;
                }
                if (rnn.exec_dir != l2r) {
                    // r2l results are stored time-reversed in the workspace.
                    const auto *ss = &ws_states_layer(
                            rnn.n_layer, dir, rnn.n_iter - it, b, 0);
                    if (rnn.exec_dir == bi_sum) {
                        auto *dd = &dst_layer_[dst_layer_d.blk_off(it, b, 0)];
                        acc_vec(dd, ss);
                    } else {
                        auto *dd = &dst_layer_[dst_layer_d.blk_off(
                                it, b, dir * rnn.dlc)];
                        copy_vec(dd, ss);
                    }
                }
            });
    if (rnn.skip_dst_iter_copy()) {
        // Last time step was written straight to dst_iter; copy it from
        // there into dst_layer.
        parallel_nd(rnn.mb, [&](dim_t b) {
            const int it = rnn.n_iter - 1;
            int dir = 0;
            if (rnn.exec_dir != r2l) {
                const auto *ss = dst_iter_
                        + dst_iter_d.blk_off(rnn.n_layer - 1, dir, b, 0);
                auto *dd = &dst_layer_[dst_layer_d.blk_off(
                        it, b, dir * rnn.dlc)];
                copy_vec(dd, (src_data_t *)ss);
                dir = 1;
            }
            if (rnn.exec_dir != l2r) {
                const auto *ss = dst_iter_
                        + dst_iter_d.blk_off(rnn.n_layer - 1, dir, b, 0);
                if (rnn.exec_dir == bi_sum) {
                    auto *dd = &dst_layer_[dst_layer_d.blk_off(it, b, 0)];
                    acc_vec(dd, (src_data_t *)ss);
                } else {
                    auto *dd = &dst_layer_[dst_layer_d.blk_off(
                            it, b, dir * rnn.dlc)];
                    copy_vec(dd, (src_data_t *)ss);
                }
            }
        });
    }
}
837
838template <typename acc_data_t>
839void copy_res_layer_bwd_template(const rnn_conf_t &rnn,
840 acc_data_t *diff_src_layer_, memory_desc_wrapper &diff_src_layer_d,
841 const acc_data_t *ws_diff_states_layer_) {
842 const AOC<const acc_data_t, 5> ws_diff_states_layer(ws_diff_states_layer_,
843 rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb,
844 rnn.ws_diff_states_layer_ld);
845
846 parallel_nd(rnn.n_iter, rnn.mb, [&](dim_t it, dim_t b) {
847 int dir = 0;
848 for (int s = 0; s < rnn.slc; s++) {
849 acc_data_t *dst_addr = diff_src_layer_
850 + diff_src_layer_d.blk_off(
851 (rnn.exec_dir == r2l) ? rnn.n_iter - 1 - it : it, b,
852 dir * rnn.slc + s);
853 acc_data_t res = ws_diff_states_layer(0, 0, it, b, s);
854 if (rnn.n_dir - 1)
855 res += ws_diff_states_layer(0, 1, rnn.n_iter - 1 - it, b, s);
856 dst_addr[0] = res;
857 }
858 });
859}
860
861#define RNN_DECL_COPY_RES_LAYER_FWD(cname) \
862 template <> \
863 template <typename dst_layer_dt, typename dst_iter_dt> \
864 void cname::copy_res_layer(const rnn_conf_t &rnn, \
865 dst_layer_dt *dst_layer_, gemm_acc_t *diff_src_layer, \
866 const dst_iter_dt *dst_iter_, const src_layer_t *ws_states_layer_, \
867 const gemm_acc_t *ws_diff_states_layer_) const { \
868 auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0)); \
869 auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1)); \
870 copy_res_layer_fwd_template(rnn, pd(), dst_layer_, dst_layer_d, \
871 dst_iter_, dst_iter_d, ws_states_layer_); \
872 }
873
874RNN_DECL_COPY_RES_LAYER_FWD(ref_rnn_fwd_f32_t)
875RNN_DECL_COPY_RES_LAYER_FWD(ref_rnn_fwd_bf16_t)
876RNN_DECL_COPY_RES_LAYER_FWD(ref_rnn_fwd_u8s8_t)
877RNN_DECL_COPY_RES_LAYER_FWD(ref_rnn_fwd_s8s8_t)
878
// Generates the backward copy_res_layer specialization for `cname`: it
// forwards to copy_res_layer_bwd_template, which writes the accumulated
// workspace diff states back into the user's diff_src_layer buffer.
// The dst_* and ws_states_* arguments are unused on the backward path.
#define RNN_DECL_COPY_RES_LAYER_BWD(cname) \
    template <> \
    template <typename dst_layer_dt, typename dst_iter_dt> \
    void cname::copy_res_layer(const rnn_conf_t &rnn, \
            dst_layer_dt *dst_layer_, gemm_acc_t *diff_src_layer_, \
            const dst_iter_dt *dst_iter_, const src_layer_t *ws_states_layer_, \
            const gemm_acc_t *ws_diff_states_layer_) const { \
        auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_md(0)); \
        copy_res_layer_bwd_template(rnn, diff_src_layer_, diff_src_layer_d, \
                ws_diff_states_layer_); \
    }

RNN_DECL_COPY_RES_LAYER_BWD(ref_rnn_bwd_f32_t)
RNN_DECL_COPY_RES_LAYER_BWD(ref_rnn_bwd_bf16_t)
893
894template <typename src_data_t, typename dst_iter_dt, typename dst_layer_dt>
895void copy_res_iter_fwd_template(const rnn_conf_t &rnn, const rnn_pd_t *pd,
896 dst_iter_dt *dst_iter_, memory_desc_wrapper &dst_iter_d,
897 void *dst_iter_c_, memory_desc_wrapper dst_iter_c_d,
898 const dst_layer_dt *dst_layer_, memory_desc_wrapper dst_layer_d,
899 const src_data_t *ws_states_iter_, const void *ws_states_iter_c_) {
900 if (dst_iter_ == nullptr) return;
901
902 const AOC<const src_data_t, 5> ws_states_iter(ws_states_iter_,
903 rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb,
904 rnn.ws_states_iter_ld);
905
906 const float data_shift = pd->attr()->rnn_data_qparams_.shift_;
907 const float data_scale = pd->attr()->rnn_data_qparams_.scale_;
908
909 const bool dequantize = pd->with_dst_iter()
910 && pd->dst_md(1)->data_type == data_type::f32 && rnn.is_int8_conf();
911 const auto copy_vec = [&](dst_iter_dt *dd, const src_data_t *ss) {
912 if (dequantize) {
913 PRAGMA_OMP_SIMD()
914 for (int s = 0; s < rnn.dic; s++)
915 dd[s] = (dst_iter_dt)(((float)ss[s] - data_shift) / data_scale);
916 } else {
917 PRAGMA_OMP_SIMD()
918 for (int s = 0; s < rnn.dic; s++)
919 dd[s] = (dst_iter_dt)ss[s];
920 }
921 };
922
923 // If skip_dst_layer_copy, then the data to copy for the last
924 // layer is in dst_layer, not in workspace.
925 const auto n_layer_in_ws = rnn.n_layer - rnn.skip_dst_layer_copy();
926
927 parallel_nd(n_layer_in_ws, rnn.n_dir, rnn.mb,
928 [&](dim_t lay, dim_t dir, dim_t b) {
929 const auto *ss
930 = &ws_states_iter(lay + 1, dir, rnn.n_iter, b, 0);
931 auto *dd = dst_iter_ + dst_iter_d.blk_off(lay, dir, b, 0);
932 copy_vec(dd, ss);
933 });
934
935 if (rnn.skip_dst_layer_copy()) {
936 parallel_nd(rnn.n_dir, rnn.mb, [&](dim_t dir, dim_t b) {
937 const auto *ss
938 = &dst_layer_[dst_layer_d.blk_off(rnn.n_iter - 1, b, dir)];
939 auto *dd = &dst_iter_[dst_iter_d.blk_off(
940 rnn.n_layer - 1, dir, b, 0)];
941 copy_vec(dd, (src_data_t *)ss);
942 });
943 }
944}
945
946template <typename acc_data_t>
947void copy_res_iter_bwd_template(const rnn_conf_t &rnn, const rnn_pd_t *pd,
948 acc_data_t *diff_src_iter_, memory_desc_wrapper &diff_src_iter_d,
949 float *diff_src_iter_c_, memory_desc_wrapper &diff_src_iter_c_d,
950 const acc_data_t *ws_diff_states_iter_,
951 const acc_data_t *ws_diff_states_iter_c_) {
952 const AOC<const acc_data_t, 5> ws_diff_states_iter(ws_diff_states_iter_,
953 rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb,
954 rnn.ws_diff_states_iter_ld);
955 const AOC<const acc_data_t, 5> ws_diff_states_iter_c(ws_diff_states_iter_c_,
956 rnn.n_layer + 1, rnn.n_dir, rnn.n_iter + 1, rnn.mb,
957 rnn.ws_diff_states_iter_c_ld);
958 if (diff_src_iter_) {
959 parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb,
960 [&](dim_t lay, dim_t dir, dim_t b) {
961 for (int s = 0; s < rnn.sic; s++) {
962 diff_src_iter_[diff_src_iter_d.blk_off(lay, dir, b, s)]
963 = ws_diff_states_iter(lay, dir, 0, b, s);
964 }
965 if (pd->cell_kind() == alg_kind::vanilla_lstm)
966 for (int s = 0; s < rnn.dhc; s++) {
967 diff_src_iter_c_[diff_src_iter_c_d.blk_off(
968 lay, dir, b, s)]
969 = ws_diff_states_iter_c(lay, dir, 0, b, s);
970 }
971 });
972 }
973}
974
// Generates the forward copy_res_iter specialization for `cname`: it
// forwards to copy_res_iter_fwd_template, which copies the last-iteration
// states from the workspace (or from dst_layer when the last layer was
// written there directly) into dst_iter. The diff_* arguments are unused
// on the forward path.
#define RNN_DECL_COPY_RES_ITER_FWD(cname) \
    template <> \
    template <typename dst_iter_dt, typename dst_layer_dt> \
    void cname::copy_res_iter(const rnn_conf_t &rnn, dst_iter_dt *dst_iter_, \
            void *dst_iter_c_, gemm_acc_t *diff_src_iter_, \
            float *diff_src_iter_c_, const dst_layer_dt *dst_layer_, \
            const src_layer_t *ws_states_layer_, \
            const void *ws_states_iter_c_, \
            const gemm_acc_t *ws_diff_states_iter_, \
            const gemm_acc_t *ws_diff_states_iter_c_) const { \
        auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0)); \
        auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1)); \
        auto dst_iter_c_d = memory_desc_wrapper(pd()->dst_md(2)); \
        copy_res_iter_fwd_template(rnn, pd(), dst_iter_, dst_iter_d, \
                dst_iter_c_, dst_iter_c_d, dst_layer_, dst_layer_d, \
                ws_states_layer_, ws_states_iter_c_); \
    }

RNN_DECL_COPY_RES_ITER_FWD(ref_rnn_fwd_f32_t)
RNN_DECL_COPY_RES_ITER_FWD(ref_rnn_fwd_bf16_t)
RNN_DECL_COPY_RES_ITER_FWD(ref_rnn_fwd_u8s8_t)
RNN_DECL_COPY_RES_ITER_FWD(ref_rnn_fwd_s8s8_t)
997
998#define RNN_DECL_COPY_RES_ITER_BWD(cname) \
999 template <> \
1000 template <typename output_data_t, typename dst_data_t> \
1001 void cname::copy_res_iter(const rnn_conf_t &rnn, output_data_t *dst_iter_, \
1002 void *dst_iter_c_, gemm_acc_t *diff_src_iter_, \
1003 float *diff_src_iter_c_, const dst_data_t *dst_layer_, \
1004 const src_layer_t *ws_states_layer_, \
1005 const void *ws_states_iter_c_, \
1006 const gemm_acc_t *ws_diff_states_iter_, \
1007 const gemm_acc_t *ws_diff_states_iter_c_) const { \
1008 auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_md(1)); \
1009 auto diff_src_iter_c_d = memory_desc_wrapper(pd()->diff_src_md(2)); \
1010 copy_res_iter_bwd_template(rnn, pd(), diff_src_iter_, diff_src_iter_d, \
1011 diff_src_iter_c_, diff_src_iter_c_d, ws_diff_states_iter_, \
1012 ws_diff_states_iter_c_); \
1013 }
1014
1015RNN_DECL_COPY_RES_ITER_BWD(ref_rnn_bwd_f32_t)
1016RNN_DECL_COPY_RES_ITER_BWD(ref_rnn_bwd_bf16_t)
1017
1018rnn_bias_prepare_sig_templ(copy_bias_to_scratch) {
1019 const AOC<T, 3> scratch_bias(
1020 scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dhc);
1021
1022 parallel_nd(rnn.n_layer * rnn.n_dir, [&](dim_t i) {
1023 const int off = i * rnn.n_bias * rnn.dhc;
1024 PRAGMA_OMP_SIMD()
1025 for (int j = 0; j < rnn.n_bias * rnn.dhc; j++)
1026 scratch_bias_[off + j] = b_[off + j];
1027 });
1028}
1029
1030rnn_bias_prepare_sig_templ(copy_bias_to_ws) {
1031 /* Original set of bias provided by the user */
1032 const AOC<const T, 5> b(b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dhc);
1033 /* Array of pointers initialized in packing */
1034 const AOC<T *, 3> bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
1035 const AOC<T, 3> scratch_bias(
1036 scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dhc);
1037
1038 for (int i = 0; i < rnn.n_layer; i++) {
1039 for (int d = 0; d < rnn.n_dir; d++) {
1040 int offset_bias = 0;
1041 for (int p = 0; p < rnn.n_parts_bias; p++) {
1042 bias(i, d, p) = rnn.copy_bias
1043 ? const_cast<T *>(&scratch_bias(i, d, offset_bias))
1044 : const_cast<T *>(&b(i, d, offset_bias));
1045 offset_bias += rnn.parts_bias[p] * rnn.dhc;
1046 }
1047 }
1048 }
1049}
1050
1051template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
1052 data_type_t acc_type>
1053rnn_bias_prepare_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
1054 acc_type>::bias_prepare)) {
1055
1056 if (rnn.copy_bias) {
1057 if (rnn.bias_dt == data_type::f32)
1058 copy_bias_to_scratch(rnn, reinterpret_cast<float **>(bias_),
1059 static_cast<const float *>(b_),
1060 static_cast<float *>(scratch_bias_));
1061 else if (rnn.bias_dt == data_type::bf16)
1062 copy_bias_to_scratch(rnn, reinterpret_cast<bfloat16_t **>(bias_),
1063 static_cast<const bfloat16_t *>(b_),
1064 static_cast<bfloat16_t *>(scratch_bias_));
1065 else
1066 assert("Unsupported bias data type");
1067 }
1068
1069 if (rnn.bias_dt == data_type::f32)
1070 copy_bias_to_ws(rnn, reinterpret_cast<float **>(bias_),
1071 static_cast<const float *>(b_),
1072 static_cast<float *>(scratch_bias_));
1073 else if (rnn.bias_dt == data_type::bf16)
1074 copy_bias_to_ws(rnn, reinterpret_cast<bfloat16_t **>(bias_),
1075 static_cast<const bfloat16_t *>(b_),
1076 static_cast<bfloat16_t *>(scratch_bias_));
1077 else
1078 assert("Unsupported bias data type");
1079}
1080
1081static void apply_bias_compensation(const rnn_utils::rnn_conf_t &rnn,
1082 float *scratch_bias_, const float *w_iter_comp,
1083 const float *w_layer_comp, const float data_shift,
1084 const float data_scale, const float *const weights_scales,
1085 const bool scale_per_oc) {
1086
1087 for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++)
1088 for (int j = 0; j < rnn.n_bias * rnn.dhc; j++) {
1089 const size_t off = i * rnn.n_bias * rnn.dhc + j;
1090 const float weights_scale
1091 = scale_per_oc ? weights_scales[j] : weights_scales[0];
1092 scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off])
1093 * data_shift / (weights_scale * data_scale);
1094 }
1095}
1096
// Finalizes the bias before execution: for unsigned int8 configurations
// the scratch bias must absorb the precomputed weights compensation terms
// (see apply_bias_compensation); for all other configurations this is a
// no-op.
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_bias_finalize_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::bias_finalize)) {
    if (rnn.is_unsigned_int8_conf()) {
        // Data quantization parameters supplied via primitive attributes.
        const float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
        const float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
        const float *const weights_scales
                = pd()->attr()->rnn_weights_qparams_.scales_;
        // A non-zero mask means one scale per output channel.
        const bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0;

        apply_bias_compensation(rnn, static_cast<float *>(scratch_bias_),
                w_iter_comp, w_layer_comp, data_shift, data_scale,
                weights_scales, scale_per_oc);
    }
}
1113
// Fills the `weights_` pointer array for packed-GEMM weights: one pointer
// per (layer, direction, part) into the contiguous packed buffer `w_`.
// The running offset accumulates across all layers/directions/parts, so
// the iteration order below must not change.
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_weights_assign_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::assign_packed_weights)) {
    assert(md->format_kind == format_kind::rnn_packed);
    const auto packed_desc = md->format_desc.rnn_packed_desc;
    const AOC<weights_t *, 3> weights(
            weights_, rnn.n_layer, rnn.n_dir, packed_desc.n_parts);

    size_t offset_packed = 0;
    for (int l = 0; l < rnn.n_layer; l++)
        for (int d = 0; d < rnn.n_dir; d++) {
            for (int p = 0; p < packed_desc.n_parts; p++) {
                weights(l, d, p) = (weights_t *)&w_[offset_packed];
                // part_pack_size is in bytes; convert to element count.
                offset_packed
                        += packed_desc.part_pack_size[p] / sizeof(weights_t);
            }
        }
}
1133
// Fills the `weights_` pointer array for plain (blocked-format) weights:
// one pointer per (layer, direction, part) into the user buffer `w_`,
// using the blocking strides of the memory descriptor. Unlike the packed
// variant, the per-part offset restarts for each (layer, direction).
template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
        data_type_t acc_type>
rnn_weights_assign_sig((_ref_rnn_common_t<aprop, src_type, weights_type,
        acc_type>::assign_weights)) {
    assert(md->format_kind == format_kind::blocked);
    const auto &blk = md->format_desc.blocking;
    /* Original set of weights provided by the user */
    const AOC<const weights_t, 3> w(
            w_, rnn.n_layer, rnn.n_dir, (int)blk.strides[1]);
    /* Array of pointers for each part of weights */
    const AOC<weights_t *, 3> weights(
            weights_, rnn.n_layer, rnn.n_dir, n_parts);

    for (int i = 0; i < rnn.n_layer; i++)
        for (int d = 0; d < rnn.n_dir; d++) {
            size_t offset_weights = 0;
            for (int p = 0; p < n_parts; p++) {
                weights(i, d, p) = (weights_t *)&w(i, d, offset_weights);
                // Advance by the number of gates in this part times the
                // stride of the gate dimension.
                offset_weights += gates_per_part[p] * blk.strides[3];
            }
        }
}
1156
1157//********************* Execution function *********************//
1158template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type,
1159 data_type_t acc_type>
1160void _ref_rnn_common_t<aprop, src_type, weights_type, acc_type>::execute_(
1161 const exec_ctx_t &ctx) const {
1162 const rnn_conf_t &rnn = this->pd()->rnn_;
1163 auto src_layer = CTX_IN_MEM(const src_layer_t *, DNNL_ARG_SRC_LAYER);
1164 auto augru_attention
1165 = CTX_IN_MEM(const src_layer_t *, DNNL_ARG_AUGRU_ATTENTION);
1166 auto src_iter = CTX_IN_MEM(const char *, DNNL_ARG_SRC_ITER);
1167 auto src_iter_c = CTX_IN_MEM(const void *, DNNL_ARG_SRC_ITER_C);
1168 auto layer_weights_n_comp
1169 = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS_LAYER);
1170 auto iter_weights_n_comp = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS_ITER);
1171 auto weights_peephole
1172 = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS_PEEPHOLE);
1173 auto projection_weights_n_comp
1174 = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS_PROJECTION);
1175 auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
1176
1177 auto dst_layer = rnn.is_fwd
1178 ? CTX_OUT_MEM(char *, DNNL_ARG_DST_LAYER)
1179 : const_cast<char *>(CTX_IN_MEM(const char *, DNNL_ARG_DST_LAYER));
1180 auto dst_iter = rnn.is_fwd
1181 ? CTX_OUT_MEM(char *, DNNL_ARG_DST_ITER)
1182 : const_cast<char *>(CTX_IN_MEM(const char *, DNNL_ARG_DST_ITER));
1183 auto dst_iter_c = CTX_OUT_MEM(void *, DNNL_ARG_DST_ITER_C);
1184
1185 auto diff_dst_layer
1186 = CTX_IN_MEM(const gemm_acc_t *, DNNL_ARG_DIFF_DST_LAYER);
1187 auto diff_dst_iter = CTX_IN_MEM(const gemm_acc_t *, DNNL_ARG_DIFF_DST_ITER);
1188 auto diff_dst_iter_c = CTX_IN_MEM(const float *, DNNL_ARG_DIFF_DST_ITER_C);
1189
1190 auto w_layer = reinterpret_cast<const weights_t *>(layer_weights_n_comp);
1191 auto w_iter = reinterpret_cast<const weights_t *>(iter_weights_n_comp);
1192 auto w_projection
1193 = reinterpret_cast<const weights_t *>(projection_weights_n_comp);
1194 auto w_layer_comp = reinterpret_cast<const float *>(
1195 layer_weights_n_comp + rnn.weights_layer_comp_offset);
1196 auto w_iter_comp = reinterpret_cast<const float *>(
1197 iter_weights_n_comp + rnn.weights_iter_comp_offset);
1198 auto w_projection_comp = reinterpret_cast<const float *>(
1199 projection_weights_n_comp + rnn.weights_projection_comp_offset);
1200 auto scratchpad = ctx.get_scratchpad_grantor();
1201
1202 auto ptr_wei_layer
1203 = scratchpad.template get<weights_t *>(key_rnn_ptrs_wei_layer);
1204 auto ptr_wei_iter
1205 = scratchpad.template get<weights_t *>(key_rnn_ptrs_wei_iter);
1206 auto ptr_wei_projection
1207 = scratchpad.template get<weights_t *>(key_rnn_ptrs_wei_projection);
1208 auto ptr_bias = scratchpad.template get<void *>(key_rnn_ptrs_bia);
1209 // Here we use scratch_gates for the output of GEMMs on FWD and on input of GEMMs for BWD.
1210 // None of the values are kept for bwd
1211 auto scratch_gates = scratchpad.template get<scratch_t>(key_rnn_gates);
1212#if DNNL_X64
1213 const auto scratch_gates_blocked
1214 = scratchpad.template get<scratch_t>(key_rnn_gates_blocked);
1215 const auto scratch_src_layer
1216 = scratchpad.template get<scratch_t>(key_rnn_src_layer_trans);
1217 const auto scratch_src_iter
1218 = scratchpad.template get<scratch_t>(key_rnn_src_iter_trans);
1219#endif
1220
1221 auto scratch_ht = scratchpad.template get<ht_t>(key_rnn_ht);
1222 auto scratch_diff_ht = scratchpad.template get<gemm_acc_t>(key_rnn_diff_ht);
1223 auto scratch_cell = scratchpad.template get<scratch_t>(key_rnn_cell);
1224
1225 gemm_acc_t *amx_scratchpad = nullptr;
1226#if DNNL_X64
1227 x64::brgemm_batch_element_t *addr_batch_global = nullptr;
1228 if (rnn.is_brgemm && (rnn.is_cell_int8_amx() || rnn.is_cell_bf16_amx())) {
1229 amx_scratchpad = scratchpad.template get<gemm_acc_t>(
1230 key_brgemm_primitive_buffer);
1231 }
1232 addr_batch_global = scratchpad.template get<x64::brgemm_batch_element_t>(
1233 key_brgemm_primitive_batch);
1234#endif
1235 // Fetching buffers from the workspace
1236 // if no workspace was provided we use the scratchpad
1237 char *scratch_ptr = scratchpad.template get<char>(key_rnn_space);
1238 char *ws_ptr = nullptr;
1239 if (rnn.use_workspace)
1240 ws_ptr = rnn.is_fwd ? CTX_OUT_MEM(char *, DNNL_ARG_WORKSPACE)
1241 : const_cast<char *>(CTX_IN_MEM(
1242 const char *, DNNL_ARG_WORKSPACE));
1243
1244 char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr;
1245 // ws_gates is only used to pass data from FWD to BWD.
1246 // assumption: in training, src_data_t and weights_t match
1247 gates_t *ws_gates = (gates_t *)(base_ptr + ws_gates_offset_);
1248 dst_iter_t *ws_ht = (dst_iter_t *)(base_ptr + ws_ht_offset_);
1249 src_layer_t *ws_states_layer
1250 = (src_layer_t *)(base_ptr + ws_states_layer_offset_);
1251 src_iter_t *ws_states_iter
1252 = (src_iter_t *)(base_ptr + ws_states_iter_offset_);
1253 void *ws_states_iter_c = (void *)(base_ptr + ws_states_iter_c_offset_);
1254 gemm_acc_t *ws_diff_states_layer
1255 = (gemm_acc_t *)(base_ptr + ws_diff_states_layer_offset_);
1256 gemm_acc_t *ws_diff_states_iter
1257 = (gemm_acc_t *)(base_ptr + ws_diff_states_iter_offset_);
1258 gemm_acc_t *ws_diff_states_iter_c
1259 = (gemm_acc_t *)(base_ptr + ws_diff_states_iter_c_offset_);
1260 gates_t *ws_grid = (gates_t *)(base_ptr + ws_grid_comp_offset_);
1261
1262 auto diff_src_layer = CTX_OUT_MEM(gemm_acc_t *, DNNL_ARG_DIFF_SRC_LAYER);
1263 auto diff_src_iter = CTX_OUT_MEM(gemm_acc_t *, DNNL_ARG_DIFF_SRC_ITER);
1264 auto diff_src_iter_c = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_SRC_ITER_C);
1265
1266 auto diff_augru_attention
1267 = CTX_OUT_MEM(gemm_acc_t *, DNNL_ARG_DIFF_AUGRU_ATTENTION);
1268 auto diff_weights_layer
1269 = CTX_OUT_MEM(gemm_acc_t *, DNNL_ARG_DIFF_WEIGHTS_LAYER);
1270 auto diff_weights_iter
1271 = CTX_OUT_MEM(gemm_acc_t *, DNNL_ARG_DIFF_WEIGHTS_ITER);
1272 auto diff_weights_projection
1273 = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_WEIGHTS_PROJECTION);
1274 auto diff_weights_peephole
1275 = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE);
1276 auto diff_bias = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS);
1277
1278 // Fetching extra buffers from scratchpad
1279 void *ws_bias = static_cast<void *>(scratch_ptr + ws_bias_offset_);
1280 /* Pack(if using packed gemm API) or copy(if input arrays have bad leading
1281 * dimension */
1282 (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias);
1283
1284 const memory_desc_t *weights_layer_md = pd()->weights_md(0);
1285 const memory_desc_t *weights_iter_md = pd()->weights_md(1);
1286
1287 const auto tag = rnn.n_block == 64 ? format_tag::ldgOI64o2i
1288 : format_tag::ldgOI32o2i;
1289 memory_desc_t wei_layer_desc;
1290 memory_desc_init_by_tag(wei_layer_desc, weights_layer_md->ndims,
1291 weights_layer_md->dims, data_type::bf16, tag);
1292
1293 memory_desc_t wei_iter_desc;
1294 memory_desc_init_by_tag(wei_iter_desc, weights_iter_md->ndims,
1295 weights_iter_md->dims, data_type::bf16, tag);
1296
1297#if DNNL_X64
1298 if (rnn.is_bf32()) {
1299 if (rnn.is_augru) {
1300 const auto bf32_augru_attention
1301 = scratchpad.template get<src_layer_t>(
1302 key_rnn_bf32_attention_trans);
1303 cvt_float_to_bfloat16((bfloat16_t *)bf32_augru_attention,
1304 (float *)augru_attention, rnn.n_iter * rnn.mb);
1305 augru_attention = bf32_augru_attention;
1306 }
1307 engine_t *engine = ctx.stream()->engine();
1308 auto wei_layer_mem
1309 = scratchpad.get_memory_storage(key_rnn_bf32_wei_layer_trans);
1310 auto wei_iter_mem
1311 = scratchpad.get_memory_storage(key_rnn_bf32_wei_iter_trans);
1312 {
1313 memory_t reorder_dst(
1314 engine, &wei_layer_desc, std::move(wei_layer_mem));
1315 exec_args_t reorder_args;
1316 reorder_args[DNNL_ARG_SRC] = ctx.args().at(DNNL_ARG_WEIGHTS_LAYER);
1317 reorder_args[DNNL_ARG_DST] = {&reorder_dst, false};
1318 exec_ctx_t reorder_ctx(ctx, std::move(reorder_args));
1319 nested_scratchpad_t ns(
1320 ctx, key_nested_multiple + 0, bf32_wei_layer_reorder_);
1321 reorder_ctx.set_scratchpad_grantor(ns.grantor());
1322 bf32_wei_layer_reorder_->execute(reorder_ctx);
1323 w_layer = scratchpad.template get<weights_t>(
1324 key_rnn_bf32_wei_layer_trans);
1325 weights_layer_md = &wei_layer_desc;
1326 }
1327
1328 {
1329 memory_t reorder_dst(
1330 engine, &wei_iter_desc, std::move(wei_iter_mem));
1331 exec_args_t reorder_args;
1332 reorder_args[DNNL_ARG_SRC] = ctx.args().at(DNNL_ARG_WEIGHTS_ITER);
1333 reorder_args[DNNL_ARG_DST] = {&reorder_dst, false};
1334 exec_ctx_t reorder_ctx(ctx, std::move(reorder_args));
1335 nested_scratchpad_t ns(
1336 ctx, key_nested_multiple + 1, bf32_wei_iter_reorder_);
1337 reorder_ctx.set_scratchpad_grantor(ns.grantor());
1338 bf32_wei_iter_reorder_->execute(reorder_ctx);
1339 w_iter = scratchpad.template get<weights_t>(
1340 key_rnn_bf32_wei_iter_trans);
1341 weights_iter_md = &wei_iter_desc;
1342 }
1343 }
1344#endif
1345
1346 (this->*weights_iter_assign_func)(rnn, weights_iter_md,
1347 rnn.n_parts_weights_iter, rnn.parts_weights_iter, ptr_wei_iter,
1348 w_iter);
1349 (this->*weights_layer_assign_func)(rnn, weights_layer_md,
1350 rnn.n_parts_weights_layer, rnn.parts_weights_layer, ptr_wei_layer,
1351 w_layer);
1352
1353 if (rnn.is_lstm_projection) {
1354 (this->*weights_projection_assign_func)(rnn,
1355 pd()->arg_md(DNNL_ARG_WEIGHTS_PROJECTION),
1356 rnn.n_parts_weights_projection, rnn.parts_weights_projection,
1357 ptr_wei_projection, w_projection);
1358 }
1359
1360 (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp);
1361
1362 // we first need to copy the initial states and input into ws
1363 if (!(rnn.skip_src_layer_copy() && rnn.is_fwd)) {
1364 if (pd()->src_md(0)->data_type == data_type::f32)
1365 copy_init_layer(rnn, ws_states_layer, ws_diff_states_layer,
1366 (const float *)src_layer, diff_dst_layer);
1367 else
1368 copy_init_layer(rnn, ws_states_layer, ws_diff_states_layer,
1369 src_layer, diff_dst_layer);
1370 }
1371
1372 if (!(rnn.skip_src_iter_copy() && rnn.is_fwd)) {
1373 if (pd()->src_md(1)->data_type == data_type::f32)
1374 copy_init_iter(rnn, ws_states_iter,
1375 static_cast<void *>(ws_states_iter_c), ws_diff_states_iter,
1376 ws_diff_states_iter_c, (const float *)src_iter, src_iter_c,
1377 diff_dst_iter, diff_dst_iter_c);
1378 else
1379 copy_init_iter(rnn, ws_states_iter, ws_states_iter_c,
1380 ws_diff_states_iter, ws_diff_states_iter_c,
1381 (const src_iter_t *)src_iter, src_iter_c, diff_dst_iter,
1382 diff_dst_iter_c);
1383 }
1384
1385 // run the execution on the grid
1386 (this->*grid_computation)(
1387#if DNNL_X64
1388 ctx,
1389#endif
1390 rnn, ptr_wei_layer, ptr_wei_iter, ptr_wei_projection,
1391 weights_peephole, w_projection_comp, ptr_bias, src_layer,
1392 augru_attention, (const src_iter_t *)src_iter, src_iter_c,
1393 (dst_layer_t *)dst_layer, (dst_iter_t *)dst_iter, dst_iter_c,
1394 ws_states_layer, ws_states_iter, ws_states_iter_c,
1395 ws_diff_states_layer, ws_diff_states_iter, ws_diff_states_iter_c,
1396 ws_gates, ws_ht, ws_grid, scratch_gates, scratch_ht,
1397 scratch_diff_ht, scratch_cell,
1398#if DNNL_X64
1399 scratch_gates_blocked, scratch_src_layer, scratch_src_iter,
1400#endif
1401 diff_augru_attention, diff_weights_layer, diff_weights_iter,
1402 diff_weights_projection, diff_weights_peephole, diff_bias,
1403 amx_scratchpad
1404#if DNNL_X64
1405 ,
1406 addr_batch_global
1407#endif
1408 );
1409
1410 // Finally we copy the results to the result buffers
1411 if (!(rnn.skip_dst_layer_copy() && rnn.is_fwd)) {
1412 if (pd()->dst_md(0)->data_type == data_type::f32)
1413 copy_res_layer(rnn, (float *)dst_layer, diff_src_layer, dst_iter,
1414 ws_states_layer, ws_diff_states_layer);
1415 else
1416 copy_res_layer(rnn, (dst_layer_t *)dst_layer, diff_src_layer,
1417 dst_iter, ws_states_layer, ws_diff_states_layer);
1418 }
1419
1420 if (!(rnn.skip_dst_iter_copy() && rnn.is_fwd)) {
1421 if (pd()->dst_md(1)->data_type == data_type::f32)
1422 copy_res_iter(rnn, (float *)dst_iter, dst_iter_c, diff_src_iter,
1423 diff_src_iter_c, dst_layer, ws_states_iter,
1424 ws_states_iter_c, ws_diff_states_iter,
1425 ws_diff_states_iter_c);
1426 else
1427 copy_res_iter(rnn, (dst_iter_t *)dst_iter, dst_iter_c,
1428 diff_src_iter, diff_src_iter_c, dst_layer, ws_states_iter,
1429 ws_states_iter_c, ws_diff_states_iter,
1430 ws_diff_states_iter_c);
1431 }
1432};
1433
/* Fix for MSVS warning C4661: explicitly declare the member-template
 * specializations that are defined in other translation units so that
 * MSVC does not warn about missing definitions when the classes below
 * are explicitly instantiated. */
template <>
rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_ref);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_brgemm_fwd);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_brgemm_bwd);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr);
template <>
rnn_merged_layer_execution_sig(ref_rnn_fwd_f32_t::merged_layer_execution_ref);
template <>
rnn_merged_layer_execution_sig(ref_rnn_fwd_f32_t::merged_layer_brgemm_fwd);
template <>
rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_ref);
template <>
rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_brgemm_fwd);
template <>
rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_brgemm_bwd);
template <>
rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru);
template <>
rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr);
template <>
rnn_merged_layer_execution_sig(ref_rnn_bwd_f32_t::merged_layer_execution_ref);
template <>
rnn_merged_layer_execution_sig(ref_rnn_bwd_f32_t::merged_layer_brgemm_fwd);

template <>
rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_ref);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_brgemm_fwd);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_brgemm_bwd);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_gru);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_bf16_t::cell_execution_gru_lbr);
template <>
rnn_merged_layer_execution_sig(ref_rnn_fwd_bf16_t::merged_layer_execution_ref);
template <>
rnn_merged_layer_execution_sig(ref_rnn_fwd_bf16_t::merged_layer_brgemm_fwd);
template <>
rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_ref);
template <>
rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_brgemm_fwd);
template <>
rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_brgemm_bwd);
template <>
rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_gru);
template <>
rnn_cell_execution_sig(ref_rnn_bwd_bf16_t::cell_execution_gru_lbr);
template <>
rnn_merged_layer_execution_sig(ref_rnn_bwd_bf16_t::merged_layer_execution_ref);
template <>
rnn_merged_layer_execution_sig(ref_rnn_bwd_bf16_t::merged_layer_brgemm_fwd);

template <>
rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_ref);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_brgemm_fwd);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_brgemm_bwd);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr);
template <>
rnn_merged_layer_execution_sig(ref_rnn_fwd_u8s8_t::merged_layer_execution_ref);
template <>
rnn_merged_layer_execution_sig(ref_rnn_fwd_u8s8_t::merged_layer_brgemm_fwd);

template <>
rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_ref);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_brgemm_fwd);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_brgemm_bwd);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_gru);
template <>
rnn_cell_execution_sig(ref_rnn_fwd_s8s8_t::cell_execution_gru_lbr);
template <>
rnn_merged_layer_execution_sig(ref_rnn_fwd_s8s8_t::merged_layer_execution_ref);
template <>
rnn_merged_layer_execution_sig(ref_rnn_fwd_s8s8_t::merged_layer_brgemm_fwd);

// Explicit instantiations of the common implementation for every
// supported (prop_kind, src, weights, acc) data type combination.
template struct _ref_rnn_common_t<prop_kind::forward, data_type::f32,
        data_type::f32, data_type::f32>;
template struct _ref_rnn_common_t<prop_kind::backward, data_type::f32,
        data_type::f32, data_type::f32>;

template struct _ref_rnn_common_t<prop_kind::forward, data_type::bf16,
        data_type::bf16, data_type::f32>;
template struct _ref_rnn_common_t<prop_kind::backward, data_type::bf16,
        data_type::bf16, data_type::f32>;

template struct _ref_rnn_common_t<prop_kind::forward, data_type::u8,
        data_type::s8, data_type::s32>;
template struct _ref_rnn_common_t<prop_kind::forward, data_type::s8,
        data_type::s8, data_type::s32>;
1537
1538#undef AOC
1539
1540} // namespace cpu
1541} // namespace impl
1542} // namespace dnnl
1543