/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_RNN_RNN_UTILS_HPP
#define CPU_RNN_RNN_UTILS_HPP

#include <memory>
#include <type_traits>

#include "common/c_types_map.hpp"
#include "common/memory_desc_wrapper.hpp"
#include "common/primitive.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"

#include "cpu/gemm/gemm_pack.hpp"

#if DNNL_X64
#include "cpu/x64/cpu_isa_traits.hpp"
#endif

#define rnn_postgemm_sig(f) \
    void f(const rnn_utils::rnn_conf_t &rnn, \
            rnn_utils::cell_position_t cell_position, gates_t *ws_gates_, \
            scratch_t *scratch_gates_, const dst_layer_t *augru_attention_, \
            dst_layer_t *dst_layer_, void *dst_iter_c_, \
            const src_iter_t *src_iter_, const void *src_iter_c_, \
            gemm_acc_t *diff_src_layer_, gemm_acc_t *diff_augru_attention_, \
            gemm_acc_t *diff_src_iter_, gemm_acc_t *diff_src_iter_c_, \
            gemm_acc_t *diff_dst_layer_, gemm_acc_t *diff_dst_iter_, \
            gemm_acc_t *diff_dst_iter_c_, const float *weights_peephole_, \
            const void *bias_, gates_t *ws_grid_, scratch_t *scratch_cell_, \
            dst_iter_t *dst_iter_, float *weights_scales_, int block_step) \
            const
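
// Usage sketch (illustrative only, not a definition from this file): cell
// implementations declare their postgemm hook through this macro so that every
// variant shares one parameter list. Assuming an enclosing class that defines
// the gates_t/scratch_t/src_iter_t/dst_layer_t/dst_iter_t/gemm_acc_t aliases,
// a hypothetical declaration and out-of-class definition would look like:
//
//   struct hypothetical_postgemm_t {
//       rnn_postgemm_sig(execute); // expands to `void execute(...) const;`
//   };
//
//   rnn_postgemm_sig(hypothetical_postgemm_t::execute) {
//       // apply the elementwise/activation part of the cell here
//   }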

#if DNNL_X64
#define rnn_merged_layer_execution_sig(f) \
    dnnl_status_t f(const exec_ctx_t &ctx, const rnn_utils::rnn_conf_t &rnn, \
            rnn_utils::cell_position_t cell_position, weights_t **w_layer_, \
            const src_layer_t *src_layer_, scratch_t *scratch_gates_, \
            gemm_acc_t *diff_src_layer_, gemm_acc_t *diff_w_layer_, \
            gemm_acc_t *amx_scratchpad, \
            x64::brgemm_batch_element_t *addr_batch_global) const

#define rnn_cell_execution_sig(f) \
    dnnl_status_t f(const exec_ctx_t &ctx, const rnn_utils::rnn_conf_t &rnn, \
            rnn_utils::cell_position_t cell_position, dst_layer_t *dst_layer_, \
            void *dst_iter_c_, gemm_acc_t *diff_src_layer_, \
            gemm_acc_t *diff_augru_attention_, gemm_acc_t *diff_src_iter_, \
            gemm_acc_t *diff_src_iter_c_, weights_t **w_layer_, \
            weights_t **w_iter_, weights_t **w_projection_, \
            const float *weights_peephole_, const float *w_proj_comp, \
            void **bias_, const src_layer_t *src_layer_, \
            const src_layer_t *augru_attention_, const src_iter_t *src_iter_, \
            const void *src_iter_c_, gemm_acc_t *diff_dst_layer_, \
            gemm_acc_t *diff_dst_iter_, gemm_acc_t *diff_dst_iter_c_, \
            gemm_acc_t *diff_w_layer_, gemm_acc_t *diff_w_iter_, \
            float *diff_weights_projection_, float *diff_weights_peephole_, \
            float *diff_bias_, gates_t *ws_gates_, scratch_t *scratch_gates_, \
            ht_t *proj_ht_, gemm_acc_t *scratch_diff_ht_, gates_t *ws_grid_, \
            scratch_t *scratch_cell_, scratch_t *scratch_gates_blocked_, \
            scratch_t *scratch_src_layer_, scratch_t *scratch_src_iter_, \
            dst_iter_t *dst_iter_, gemm_acc_t *amx_scratchpad, \
            x64::brgemm_batch_element_t *addr_batch_global) const

#define rnn_grid_execution_sig(f) \
    dnnl_status_t f(const exec_ctx_t &ctx, const rnn_utils::rnn_conf_t &rnn, \
            weights_t **weights_layer_, weights_t **weights_iter_, \
            weights_t **weights_projection_, const float *weights_peephole_, \
            const float *w_proj_comp, void **bias_, \
            const src_layer_t *src_layer_, \
            const src_layer_t *augru_attention_, const src_iter_t *src_iter_, \
            const void *src_iter_c_, dst_layer_t *dst_layer_, \
            dst_iter_t *dst_iter_, void *dst_iter_c_, \
            src_layer_t *ws_states_layer_, src_iter_t *ws_states_iter_, \
            void *ws_states_iter_c_, gemm_acc_t *ws_diff_states_layer_, \
            gemm_acc_t *ws_diff_states_iter_, \
            gemm_acc_t *ws_diff_states_iter_c_, gates_t *ws_gates_, \
            ht_t *ws_ht_, gates_t *ws_grid_, scratch_t *scratch_gates_, \
            ht_t *scratch_ht_, gemm_acc_t *scratch_diff_ht_, \
            scratch_t *scratch_cell_, scratch_t *scratch_gates_blocked_, \
            scratch_t *scratch_src_layer_, scratch_t *scratch_src_iter_, \
            gemm_acc_t *diff_augru_attention_, \
            gemm_acc_t *diff_weights_layer_, gemm_acc_t *diff_weights_iter_, \
            float *diff_weights_projection_, float *diff_weights_peephole_, \
            float *diff_bias_, gemm_acc_t *amx_scratchpad, \
            x64::brgemm_batch_element_t *addr_batch_global) const
#else
#define rnn_merged_layer_execution_sig(f) \
    dnnl_status_t f(const rnn_utils::rnn_conf_t &rnn, \
            rnn_utils::cell_position_t cell_position, weights_t **w_layer_, \
            const src_layer_t *src_layer_, scratch_t *scratch_gates_, \
            gemm_acc_t *diff_src_layer_, gemm_acc_t *diff_w_layer_) const

#define rnn_cell_execution_sig(f) \
    dnnl_status_t f(const rnn_utils::rnn_conf_t &rnn, \
            rnn_utils::cell_position_t cell_position, dst_layer_t *dst_layer_, \
            void *dst_iter_c_, gemm_acc_t *diff_src_layer_, \
            gemm_acc_t *diff_augru_attention_, gemm_acc_t *diff_src_iter_, \
            gemm_acc_t *diff_src_iter_c_, weights_t **w_layer_, \
            weights_t **w_iter_, weights_t **w_projection_, \
            const float *weights_peephole_, const float *w_proj_comp, \
            void **bias_, const src_layer_t *src_layer_, \
            const src_layer_t *augru_attention_, const src_iter_t *src_iter_, \
            const void *src_iter_c_, gemm_acc_t *diff_dst_layer_, \
            gemm_acc_t *diff_dst_iter_, gemm_acc_t *diff_dst_iter_c_, \
            gemm_acc_t *diff_w_layer_, gemm_acc_t *diff_w_iter_, \
            float *diff_weights_projection_, float *diff_weights_peephole_, \
            float *diff_bias_, gates_t *ws_gates_, scratch_t *scratch_gates_, \
            ht_t *proj_ht_, gemm_acc_t *scratch_diff_ht_, gates_t *ws_grid_, \
            scratch_t *scratch_cell_, dst_iter_t *dst_iter_, \
            gemm_acc_t *amx_scratchpad) const

#define rnn_grid_execution_sig(f) \
    dnnl_status_t f(const rnn_utils::rnn_conf_t &rnn, \
            weights_t **weights_layer_, weights_t **weights_iter_, \
            weights_t **weights_projection_, const float *weights_peephole_, \
            const float *w_proj_comp, void **bias_, \
            const src_layer_t *src_layer_, \
            const src_layer_t *augru_attention_, const src_iter_t *src_iter_, \
            const void *src_iter_c_, dst_layer_t *dst_layer_, \
            dst_iter_t *dst_iter_, void *dst_iter_c_, \
            src_layer_t *ws_states_layer_, src_iter_t *ws_states_iter_, \
            void *ws_states_iter_c_, gemm_acc_t *ws_diff_states_layer_, \
            gemm_acc_t *ws_diff_states_iter_, \
            gemm_acc_t *ws_diff_states_iter_c_, gates_t *ws_gates_, \
            ht_t *ws_ht_, gates_t *ws_grid_, scratch_t *scratch_gates_, \
            ht_t *scratch_ht_, gemm_acc_t *scratch_diff_ht_, \
            scratch_t *scratch_cell_, gemm_acc_t *diff_augru_attention_, \
            gemm_acc_t *diff_weights_layer_, gemm_acc_t *diff_weights_iter_, \
            float *diff_weights_projection_, float *diff_weights_peephole_, \
            float *diff_bias_, gemm_acc_t *amx_scratchpad) const
#endif

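// Usage sketch (illustrative only): the execution-signature macros above keep
// the x64 and generic builds on a single parameter list, so grid/cell drivers
// can declare their entry points uniformly, e.g. in a hypothetical driver:
//
//   struct hypothetical_rnn_driver_t {
//       rnn_grid_execution_sig(linear_execution); // const member declaration
//       rnn_cell_execution_sig(cell_execution); // one cell of the grid
//       rnn_merged_layer_execution_sig(merged_layer_execution);
//   };
//
// The x64 variants additionally carry the execution context, the blocked
// scratch buffers, and the brgemm batch element array used by the brgemm/AMX
// code path.
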
#define rnn_gemm_sig(f) \
    dnnl_status_t f(const char transA, const char transB, dim_t m, dim_t n, \
            dim_t k, const float alpha, const weights_t *a_, const dim_t ldA, \
            const gemm_data_t *b_, const dim_t ldB, const float beta, \
            gemm_acc_t *c_, const dim_t ldC) const

#define rnn_bias_prepare_sig(f) \
    void f(const rnn_utils::rnn_conf_t &rnn, void **bias_, const void *b_, \
            void *scratch_bias_) const

#define rnn_bias_prepare_sig_templ(f) \
    template <typename T> \
    static void f(const rnn_utils::rnn_conf_t &rnn, T **bias_, const T *b_, \
            T *scratch_bias_)

#define rnn_bias_finalize_sig(f) \
    void f(const rnn_utils::rnn_conf_t &rnn, void *scratch_bias_, \
            const float *w_iter_comp, const float *w_layer_comp) const

#define rnn_weights_assign_sig(f) \
    void f(const rnn_utils::rnn_conf_t &rnn, const memory_desc_t *md, \
            int n_parts, const int *gates_per_part, weights_t **weights_, \
            const weights_t *w_) const

namespace dnnl {
namespace impl {
namespace cpu {

namespace rnn_utils {

enum execution_direction_t {
    l2r,
    r2l,
    bi_concat,
    bi_sum,
};

enum cell_position_t {
    middle_cell = 0x0,
    first_layer = 0x1,
    first_iter = 0x2,
    last_layer = 0x4,
    last_iter = 0x8,
    c_state_first_iter = 0x10,
    c_state_last_iter = 0x20
};

enum class weights_type_t {
    layer,
    iter,
    projection,
    peephole,
};

inline cell_position_t &operator|=(cell_position_t &lhs, cell_position_t rhs) {
    lhs = static_cast<cell_position_t>(
            static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs));
    return lhs;
}

inline cell_position_t operator|(cell_position_t lhs, cell_position_t rhs) {
    return static_cast<cell_position_t>(
            static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs));
}
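
// Usage sketch (illustrative only): positions are bit flags combined while
// walking the grid; the names below (lay, it, n_layer, n_iter) are
// hypothetical loop variables:
//
//   cell_position_t pos = middle_cell;
//   if (lay == 0) pos |= first_layer;
//   if (it == 0) pos |= first_iter;
//   if (lay == n_layer - 1) pos |= last_layer;
//   if (it == n_iter - 1) pos |= last_iter;
//   const bool writes_user_dst_iter = (pos & last_iter) != 0;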

enum data_type_conf_t {
    all_f32,
    all_bf16,
    u8u8u8f32,
    f32u8f32f32,
    u8u8u8u8,
    f32u8f32u8,
    s8s8s8f32,
    f32s8f32f32,
    s8s8s8s8,
    f32s8f32s8
};

enum brgemm_rnn_execute_loop_order_t {
    // default for kernels w/o loop order choice
    undefined = 0x0,
    // m_blocking loop is outermost
    mblk_nblk = 0x1,
    // n_blocking loop is outermost
    nblk_mblk = 0x2
};

struct diff_src_brgemm_conf_t {
    dim_t M = 0, N = 0, K = 0;

    dim_t n_block = 0, N_blocks = 0, n_tail = 0;
    dim_t m_block = 0, M_blocks = 0;

    dim_t K_blocks = 0, k_block = 0, k_tail = 0;
    dim_t Kpadded = 0;

    dim_t N_iter = 0, N_layer = 0;
    dim_t N_layer_blocks = 0, n_layer_tail = 0;
    dim_t N_iter_blocks = 0, n_iter_tail = 0;
    dim_t LDA = 0, LDB = 0, LDC = 0;

#if DNNL_X64
    x64::cpu_isa_t isa = x64::isa_undef;
#endif

    brgemm_rnn_execute_loop_order_t loop_order
            = brgemm_rnn_execute_loop_order_t::undefined;
    int gates_block;
};

struct diff_wei_brgemm_conf_t {
    dim_t M = 0, M_layer = 0, M_iter = 0, N = 0, K = 0;

    dim_t n_block = 0, N_blocks = 0, n_tail = 0;
    dim_t m_block = 0, M_blocks = 0;
    dim_t K_blocks = 0, k_block = 0, k_tail = 0;
    dim_t Kpadded = 0;

    dim_t LDA_layer = 0, LDA_iter = 0, LDB = 0, LDC_iter = 0, LDC_layer = 0;

    bool global_transpose = false;

#if DNNL_X64
    x64::cpu_isa_t isa = x64::isa_undef;
#endif

    brgemm_rnn_execute_loop_order_t loop_order
            = brgemm_rnn_execute_loop_order_t::undefined;
};

struct rnn_conf_t {
    execution_direction_t exec_dir;
    data_type_conf_t dt_conf;
    data_type_t cell_dt = data_type::undef; // The data type used by cell
    data_type_t bias_dt = data_type::undef;
    data_type_t src_iter_c_dt = data_type::undef;
    data_type_t dst_iter_c_dt = data_type::undef;

    int n_layer = 0, n_iter = 0, n_dir = 0, n_gates = 0, n_states = 0;
    int mb = 0;
    int slc = 0, sic = 0, dhc = 0, dic = 0, dlc = 0;
    //int gates_ld, gates_nld, gates_ws_ld;

    int n_parts_weights_layer = 0;
    int parts_weights_layer[DNNL_RNN_MAX_N_PARTS];
    size_t part_weights_layer_pack_size[DNNL_RNN_MAX_N_PARTS];

    int n_parts_weights_iter = 0;
    int parts_weights_iter[DNNL_RNN_MAX_N_PARTS];
    size_t part_weights_iter_pack_size[DNNL_RNN_MAX_N_PARTS];

    int n_parts_weights_projection = 0;
    int parts_weights_projection[DNNL_RNN_MAX_N_PARTS];
    size_t part_weights_projection_pack_size[DNNL_RNN_MAX_N_PARTS];

    int n_bias = 0, n_parts_bias = 0, parts_bias[DNNL_RNN_MAX_N_PARTS];

    /* Size of packed data in bytes */
    size_t weights_layer_comp_offset = 0, weights_layer_pack_size = 0;
    size_t weights_iter_comp_offset = 0, weights_iter_pack_size = 0;
    size_t weights_projection_comp_offset = 0, weights_projection_pack_size = 0;

    bool copy_bias = 0;
    int weights_layer_ld = 0, weights_layer_nld = 0;
    int diff_weights_layer_ld = 0, diff_weights_layer_nld = 0;
    int weights_iter_ld = 0, weights_iter_nld = 0;
    int diff_weights_iter_ld = 0, diff_weights_iter_nld = 0;
    int weights_projection_ld = 0, weights_projection_nld = 0;
    int diff_weights_projection_ld = 0, diff_weights_projection_nld = 0;

    int proj_ht_ld = 0, proj_ht_nld = 0;

    int ws_gates_ld = 0, ws_gates_nld = 0;
    int ws_ht_ld = 0, ws_ht_nld = 0;
    int ws_states_layer_ld = 0, ws_states_layer_nld = 0;
    int ws_states_iter_ld = 0, ws_states_iter_nld = 0;
    int ws_states_iter_c_ld = 0, ws_states_iter_c_nld = 0;
    int ws_diff_states_layer_ld = 0, ws_diff_states_layer_nld = 0;
    int ws_diff_states_iter_ld = 0, ws_diff_states_iter_nld = 0;
    int ws_diff_states_iter_c_ld = 0, ws_diff_states_iter_c_nld = 0;

    int scratch_gates_ld = 0, scratch_gates_nld = 0;
    int scratch_ht_ld = 0, scratch_ht_nld = 0;
    int scratch_diff_ht_ld = 0, scratch_diff_ht_nld = 0;

    int src_layer_ld_ = 0, src_layer_nld_ = 0;
    int src_iter_ld_ = 0, src_iter_nld_ = 0;
    int src_iter_c_ld_ = 0, src_iter_c_nld_ = 0;
    int dst_layer_ld_ = 0, dst_layer_nld_ = 0;
    int dst_iter_ld_ = 0, dst_iter_nld_ = 0;
    int dst_iter_c_ld_ = 0, dst_iter_c_nld_ = 0;

    int weights_iter_compensation_size = 0, weights_layer_compensation_size = 0;
    bool is_fwd = 0, is_training = 0, is_lbr = 0, is_lstm_peephole = 0,
         is_lstm_projection = 0, is_augru = 0, is_orig_gru = 0;
    bool use_workspace = 0;

    // Size of workspace for each tensor in bytes
    // Notes:
    // 1. For non-LSTMP ws_states_iter_size == ws_states_layer_size. The
    //    corresponding pointers should point to the same places.
    size_t ws_gates_size = 0;
    size_t ws_ht_size = 0;
    size_t ws_states_layer_size = 0;
    size_t ws_states_iter_size = 0;
    size_t ws_states_iter_c_size = 0;
    size_t ws_diff_states_layer_size = 0;
    size_t ws_diff_states_iter_size = 0;
    size_t ws_diff_states_iter_c_size = 0;
    size_t scratch_gates_size = 0;

    size_t scratch_gates_blocked_size = 0;
    size_t scratch_gates_blocked_nested_reorder_size = 0;
    size_t scratch_src_layer_size = 0;
    size_t scratch_src_layer_nested_reorder_size = 0;
    size_t scratch_src_iter_size = 0;
    size_t scratch_src_iter_nested_reorder_size = 0;

    size_t scratch_ht_size = 0;
    size_t scratch_diff_ht_size = 0;
    size_t scratch_cell_size = 0;
    size_t ws_grid_comp_size = 0;
    size_t ws_per_cell = 0;
    size_t ws_bias_size = 0;

    bool src_layer_is_trivial_stride = false;
    bool merge_gemm_iter = false, merge_gemm_layer = false,
         force_nocopy = false, use_layer_packed_gemm = false,
         use_iter_packed_gemm = false, use_projection_packed_gemm = false;
    int n_iter_scratch_gates = 0;

    inline bool is_int8_conf() const {
        return is_signed_int8_conf() || is_unsigned_int8_conf();
    }
    inline bool is_signed_int8_conf() const {
        return utils::one_of(
                dt_conf, s8s8s8f32, f32s8f32f32, s8s8s8s8, f32s8f32s8);
    }
    inline bool is_unsigned_int8_conf() const {
        return utils::one_of(
                dt_conf, u8u8u8f32, f32u8f32f32, u8u8u8u8, f32u8f32u8);
    }

    inline bool is_cell_dt_int8() const {
        return is_cell_dt_signed_int8() || is_cell_dt_unsigned_int8();
    }
    inline bool is_cell_dt_signed_int8() const {
        return cell_dt == data_type::s8;
    }
    inline bool is_cell_dt_unsigned_int8() const {
        return cell_dt == data_type::u8;
    }

    inline bool is_cell_int8_amx() const {
#if DNNL_X64
        return brgemm_isa == x64::avx512_core_amx && is_cell_dt_int8();
#else
        return false;
#endif
    }

    inline bool is_bf16_conf() const { return dt_conf == all_bf16; }

    inline bool is_f32_conf() const { return dt_conf == all_f32; }

    inline bool is_cell_dt_f32() const { return cell_dt == data_type::f32; }
    inline bool is_cell_dt_bf16() const { return cell_dt == data_type::bf16; }
    inline bool is_cell_bf16_amx() const {
#if DNNL_X64
        return brgemm_isa == x64::avx512_core_amx && is_cell_dt_bf16();
#else
        return false;
#endif
    }
    inline bool is_bf32() const { return is_cell_bf16_amx() && is_f32_conf(); }

    inline bool skip_src_layer_copy() const {
        return (exec_dir == l2r) && !is_bf32()
                && utils::one_of(dt_conf, s8s8s8f32, f32s8f32f32, s8s8s8s8,
                        f32s8f32s8, u8u8u8u8, u8u8u8f32, f32u8f32u8,
                        f32u8f32f32, all_f32, all_bf16);
    }
    inline bool skip_src_iter_copy() const {
        return (exec_dir == l2r) && (src_iter_ld_ > 0) && !is_bf32()
                && utils::one_of(dt_conf, s8s8s8s8, s8s8s8f32, u8u8u8u8,
                        u8u8u8f32, all_f32, all_bf16);
    }
    inline bool skip_dst_layer_copy() const {
        return (exec_dir == l2r) && !is_bf32()
                && utils::one_of(dt_conf, s8s8s8s8, f32s8f32s8, u8u8u8u8,
                        f32u8f32u8, all_f32, all_bf16);
    }
    inline bool skip_dst_iter_copy() const {
        return (exec_dir == l2r) && (dst_iter_ld_ > 0) && !is_bf32()
                && utils::one_of(dt_conf, s8s8s8s8, s8s8s8f32, u8u8u8u8,
                        u8u8u8f32, all_f32, all_bf16);
    }

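    // When the copies above are skipped, boundary cells read from (or write
    // to) the user's memory directly, so the leading dimension to use depends
    // on the cell position. Illustrative sketch (assumed names, not from this
    // file):
    //
    //   // first layer with the copy skipped: read src_layer with the user ld
    //   const dim_t ld_usr = rnn.src_layer_ld(first_layer);
    //   // interior cell: read the workspace with ws_states_layer_ld
    //   const dim_t ld_ws = rnn.src_layer_ld(middle_cell);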
    inline dim_t src_layer_ld(cell_position_t cell_position) const {
        return (cell_position & first_layer) && skip_src_layer_copy()
                ? src_layer_ld_
                : (cell_position & last_iter) && skip_dst_iter_copy()
                        ? dst_iter_ld_
                        : ws_states_layer_ld;
    }

    inline dim_t src_iter_ld(cell_position_t cell_position) const {
        return (cell_position & first_iter) && skip_src_iter_copy()
                ? src_iter_ld_
                : ((cell_position & last_layer) && skip_dst_layer_copy()
                                  && !(cell_position & first_iter)
                                ? dst_layer_ld_
                                : ws_states_iter_ld);
    }

    inline dim_t layer_brgemm_desc(cell_position_t cell_position) const {
        return ((cell_position & first_layer) && skip_src_layer_copy())
                ? 0
                : ((cell_position & last_iter) && skip_dst_iter_copy()) ? 1 : 2;
    }

    inline dim_t iter_brgemm_desc(cell_position_t cell_position) const {
        return ((cell_position & first_iter) && skip_src_iter_copy())
                ? 0
                : ((cell_position & last_layer) && skip_dst_layer_copy()
                          && !(cell_position & first_iter))
                        ? 1
                        : 2;
    }

    // Returns the index of the brgemm kernel for the 2nd part of the iteration
    // gemm in a vanilla GRU cell for the current position.
    // Note: this method must be kept in sync with dst_iter_part2_ld() and with
    // the initialization order of the LDA2_2[] values.
    inline dim_t iter_part2_brgemm_desc(cell_position_t cell_position) const {
        if (cell_position & last_layer) {
            return (cell_position & last_layer) && skip_dst_layer_copy()
                    ? 0
                    : (cell_position & last_iter) && skip_dst_iter_copy() ? 1
                                                                          : 2;
        } else {
            return (cell_position & last_iter) && skip_dst_iter_copy() ? 1 : 3;
        }
    }

    inline dim_t src_iter_c_ld(cell_position_t cell_position) const {
        return (cell_position & c_state_first_iter) ? src_iter_c_ld_
                                                    : ws_states_iter_c_ld;
    }

    inline dim_t dst_layer_ld(
            cell_position_t cell_position, bool after_proj = false) const {
        // We use scratch_ht and not dst_layer for lstmp
        if (is_lstm_projection && !after_proj) return scratch_ht_ld;

        return (cell_position & last_layer) && skip_dst_layer_copy()
                ? dst_layer_ld_
                : (cell_position & last_iter) && skip_dst_iter_copy()
                        ? dst_iter_ld_
                        : ws_states_layer_ld;
    }

    inline dim_t dst_brgemm_desc(
            cell_position_t cell_position, bool after_proj = false) const {
        // We use scratch_ht and not dst_layer for lstmp
        if (is_lstm_projection && !after_proj) return 0;

        return (cell_position & last_layer) && skip_dst_layer_copy()
                ? 1
                : (cell_position & last_iter) && skip_dst_iter_copy() ? 2 : 3;
    }

    inline dim_t dst_iter_ld(cell_position_t cell_position) const {
        return (cell_position & last_iter) && skip_dst_iter_copy()
                ? dst_iter_ld_
                : ws_states_iter_ld;
    }

    // Returns the dst tensor leading dimension for the 2nd part of the
    // iteration gemm in a vanilla GRU cell for the current position.
    inline dim_t dst_iter_part2_ld(cell_position_t cell_position) const {
        return (cell_position & last_layer) ? dst_layer_ld(cell_position)
                                            : dst_iter_ld(cell_position);
    }

    inline dim_t dst_iter_c_ld(cell_position_t cell_position) const {
        return (cell_position & c_state_last_iter) ? dst_iter_c_ld_
                                                   : ws_states_iter_c_ld;
    }

    // // when skipping copy, the output ld can be states_ws_ld,
    // // dst_iter_ld or dst_layer_ld depending on the cell position
    // inline dim_t dst_ld(cell_position_t cell_position) const {
    //     return (cell_position & last_layer) ? dst_layer_ld(cell_position)
    //                                         : dst_iter_ld(cell_position);
    // }
    inline dim_t dst_copy_ld(cell_position_t cell_position) const {
        return dst_iter_ld(cell_position);
    }

    inline bool need_gemm_layer(cell_position_t cell_position) const {
        // In case of merge_gemm_layer we might still need a layer gemm if we
        // store the states of the last iteration in the destination memory.
        // The exception to this rule is the first layer, in which case all
        // states are kept in the user's src_layer, hence making a fully merged
        // gemm possible.
        return IMPLICATION(merge_gemm_layer,
                skip_dst_iter_copy() && (cell_position & last_iter)
                        && !(cell_position & first_layer));
    }
    bool is_brgemm;

    diff_src_brgemm_conf_t diff_src_brgemm;
    diff_wei_brgemm_conf_t diff_wei_brgemm;

    dim_t M, N, K1, K2;

    dim_t LDB1, LDB2;
    dim_t LDA1[3];
    dim_t LDA2[3];
    // LDA for iter part2 gemm in vanilla gru cell
    dim_t LDA2_2[4];
    dim_t LDC;

    dim_t m_block, M_blocks;
    dim_t n_block, N_blocks, n_tail;

    dim_t k2_block, k1_block, k1_tail, k2_tail;
    dim_t KB1_blocks, KB2_blocks;
    dim_t K1padded, K2padded;

    dim_t Kproj, Kprojpadded;
    dim_t kproj_block, KBproj_blocks, kproj_tail;

    dim_t Nproj, Nproj_blocks, nproj_tail;
    dim_t LDAproj, LDBproj, LDCproj[4];
    int dhc_block_peephole, dhc_tail_peephole, dhc_blocks_peephole;
    bool brgemm_fwd_iter_layer_fuse_possible = false;

    dim_t nthr;
#if DNNL_X64
    x64::cpu_isa_t brgemm_isa;
#endif
    bool unfused_post_gemm;
    brgemm_rnn_execute_loop_order_t loop_order
            = brgemm_rnn_execute_loop_order_t::undefined;

    // for merged layer computation in brgemm
    dim_t Mlayermerged;
    dim_t mlayermerged_block, Mlayermerged_blocks;
};

bool is_ldigo(const memory_desc_wrapper &md);
bool is_ldgoi(const memory_desc_wrapper &md);
bool is_ldio(const memory_desc_wrapper &md);
bool is_ldoi(const memory_desc_wrapper &md);
bool is_ldigo_blocked(const memory_desc_wrapper &md);
bool is_ldgoi_blocked(const memory_desc_wrapper &md);
bool is_ldio_blocked(const memory_desc_wrapper &md);
bool is_ldoi_blocked(const memory_desc_wrapper &md);

int get_good_ld(int dim, int sizeof_dt);
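// Note (informal, definition lives in rnn_utils.cpp): a "good" leading
// dimension is one padded for efficient GEMM. As an illustration only,
// assuming the padding targets a 64-byte cache line, get_good_ld(17,
// sizeof(float)) would round 17 f32 elements up to 32 (128 bytes), and
// implementations typically also nudge the result off large power-of-two
// strides to limit cache aliasing.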

template <typename T>
bool init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
        const primitive_attr_t &attr, const memory_desc_wrapper &src_layer_d,
        const memory_desc_wrapper &src_iter_d,
        const memory_desc_wrapper &src_iter_c_d,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &weights_projection_d,
        const memory_desc_wrapper &dst_layer_d,
        const memory_desc_wrapper &dst_iter_d,
        const memory_desc_wrapper &dst_iter_c_d,
        const memory_desc_wrapper &bias_d) {
    rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training,
            prop_kind::forward_inference);
    rnn.is_training = utils::one_of(
            rd.prop_kind, prop_kind::forward_training, prop_kind::backward);
    rnn.is_lbr = utils::one_of(rd.cell_kind, dnnl_lbr_gru, dnnl_lbr_augru);
    rnn.is_lstm_peephole = rd.cell_kind == dnnl_vanilla_lstm
            && !memory_desc_wrapper(rd.weights_peephole_desc).is_zero();
    rnn.is_lstm_projection = rd.cell_kind == dnnl_vanilla_lstm
            && !memory_desc_wrapper(rd.weights_projection_desc).is_zero();
    rnn.is_augru
            = utils::one_of(rd.cell_kind, dnnl_lbr_augru, dnnl_vanilla_augru);
    rnn.bias_dt = bias_d.is_zero() ? data_type::f32 : bias_d.data_type();
    rnn.src_iter_c_dt = src_iter_c_d.is_zero() ? data_type::f32
                                               : src_iter_c_d.data_type();
    rnn.dst_iter_c_dt = dst_iter_c_d.is_zero() ? data_type::f32
                                               : dst_iter_c_d.data_type();

    rnn.cell_dt = data_traits<typename T::src_layer_t>::data_type;
    switch (rd.direction) {
        case dnnl_unidirectional_left2right: rnn.exec_dir = l2r; break;
        case dnnl_unidirectional_right2left: rnn.exec_dir = r2l; break;
        case dnnl_bidirectional_concat: rnn.exec_dir = bi_concat; break;
        case dnnl_bidirectional_sum: rnn.exec_dir = bi_sum; break;
        default: break;
    }

    if (utils::everyone_is(data_type::f32, src_layer_d.data_type(),
                dst_layer_d.data_type(), weights_layer_d.data_type()))
        rnn.dt_conf = all_f32;
    else if (utils::everyone_is(data_type::bf16, src_layer_d.data_type(),
                     dst_layer_d.data_type(), weights_layer_d.data_type())) {
        if (!platform::has_data_type_support(data_type::bf16)) return false;
        rnn.dt_conf = all_bf16;
    } else if (dst_layer_d.data_type() == data_type::u8) {
        if (IMPLICATION(
                    src_iter_d.md_, src_iter_d.data_type() == data_type::u8))
            rnn.dt_conf = u8u8u8u8;
        else
            rnn.dt_conf = f32u8f32u8;
    } else if (dst_layer_d.data_type() == data_type::s8) {
        if (IMPLICATION(
                    src_iter_d.md_, src_iter_d.data_type() == data_type::s8))
            rnn.dt_conf = s8s8s8s8;
        else
            rnn.dt_conf = f32s8f32s8;

    } else if (dst_layer_d.data_type() == data_type::f32) {
        if (IMPLICATION(
                    src_iter_d.md_, src_iter_d.data_type() == data_type::u8))
            rnn.dt_conf = u8u8u8f32;
        else if (IMPLICATION(src_iter_d.md_,
                         src_iter_d.data_type() == data_type::s8))
            rnn.dt_conf = s8s8s8f32;
        else if (IMPLICATION(src_layer_d.md_,
                         src_layer_d.data_type() == data_type::s8))
            rnn.dt_conf = f32s8f32f32;
        else
            rnn.dt_conf = f32u8f32f32;
    }

    // Set problem members defining problem sizes
    rnn.n_layer = weights_layer_d.dims()[0];
    rnn.n_iter = src_layer_d.dims()[0];
    rnn.n_dir = weights_layer_d.dims()[1];
    rnn.n_gates = weights_layer_d.dims()[3];
    rnn.n_states = rd.cell_kind == dnnl_vanilla_lstm ? 2 : 1;
    rnn.n_bias = rnn.n_gates + rnn.is_lbr;
    rnn.mb = src_layer_d.dims()[1];
    rnn.sic = weights_iter_d.dims()[2];
    rnn.slc = weights_layer_d.dims()[2];
    rnn.dhc = weights_layer_d.dims()[4];
    rnn.dlc = rnn.is_lstm_projection ? weights_projection_d.dims()[3] : rnn.dhc;
    // All supported cells have dic == dlc
    rnn.dic = rnn.dlc;

    // set members with the leading dimensions of the user's memories
    // Assumption: weights datatype size is the same as state datatype size
    assert(types::data_type_size(weights_layer_d.data_type())
            == types::data_type_size(src_layer_d.data_type()));

    // set workspace leading dimensions (and non-leading dimensions)

    // the ws and scratch proj_ht need to match as we use them interchangeably
    assert(IMPLICATION(rnn.is_lstm_projection,
            sizeof(typename T::ht_t) == sizeof(typename T::dst_iter_t)));
    rnn.proj_ht_nld = rnn.mb;
    rnn.proj_ht_ld = get_good_ld(rnn.dhc, sizeof(typename T::ht_t));

    rnn.ws_gates_nld = rnn.mb;
    rnn.ws_gates_ld
            = get_good_ld(rnn.dhc * rnn.n_gates, sizeof(typename T::gates_t));
    rnn.ws_ht_nld = rnn.proj_ht_nld;
    rnn.ws_ht_ld = rnn.proj_ht_ld;

    rnn.ws_states_layer_nld = rnn.mb;
    static_assert(std::is_same<typename T::src_layer_t,
                          typename T::src_iter_t>::value,
            "src_layer_t and src_iter_t must be the same");
    rnn.ws_states_layer_ld
            = get_good_ld(nstl::max(rnn.sic, nstl::max(rnn.slc, rnn.dlc)),
                    sizeof(typename T::src_layer_t));
    // there is no need for a separate ws_states_iter for now as all
    // supported cells have dst_iter == dst_layer
    rnn.ws_states_iter_nld = rnn.ws_states_layer_nld;
    rnn.ws_states_iter_ld = rnn.ws_states_layer_ld;

    // we do not need a good ld for iter_c as it is not involved in GEMM
    rnn.ws_states_iter_c_nld = rnn.mb;
    rnn.ws_states_iter_c_ld = rnn.dhc;

    // TODO: be more restrictive on the leading dimensions
    rnn.ws_diff_states_layer_nld = rnn.mb;
    rnn.ws_diff_states_layer_ld = get_good_ld(
            nstl::max(nstl::max(rnn.slc, rnn.dic), nstl::max(rnn.sic, rnn.dhc)),
            sizeof(typename T::gemm_acc_t));

    rnn.ws_diff_states_iter_nld = rnn.mb;
    rnn.ws_diff_states_iter_ld = get_good_ld(
            nstl::max(nstl::max(rnn.slc, rnn.dic), nstl::max(rnn.sic, rnn.dhc)),
            sizeof(typename T::gemm_acc_t));

    rnn.ws_diff_states_iter_c_nld = rnn.mb;
    rnn.ws_diff_states_iter_c_ld = rnn.dhc;

    // set scratch leading dimensions (and non-leading dimensions)
    // scratch_gates is used to store the intermediate gates before the
    // postgemm operation. Note: we also use it in lstmp as a temporary
    // scratchpad between the projection and the downconversion, hence the
    // max with dlc.
    rnn.scratch_gates_nld = rnn.mb;
    rnn.scratch_gates_ld
            = get_good_ld(nstl::max(rnn.dlc, rnn.n_gates * rnn.dhc),
                    sizeof(typename T::scratch_t));
    rnn.scratch_ht_nld = rnn.proj_ht_nld;
    rnn.scratch_ht_ld = rnn.proj_ht_ld;

    rnn.scratch_diff_ht_nld = rnn.mb;
    rnn.scratch_diff_ht_ld
            = get_good_ld(rnn.dlc, sizeof(typename T::gemm_acc_t));

    // Assumption: {src,dst}_layer has tnc layout, {src,dst}_iter has ldnc
    rnn.src_layer_ld_ = src_layer_d.blocking_desc().strides[1];
    rnn.dst_layer_ld_ = dst_layer_d.blocking_desc().strides[1];
    rnn.src_iter_ld_ = types::is_zero_md(src_iter_d.md_)
            ? 0
            : src_iter_d.blocking_desc().strides[2];
    rnn.dst_iter_ld_ = types::is_zero_md(dst_iter_d.md_)
            ? 0
            : dst_iter_d.blocking_desc().strides[2];
    rnn.src_iter_c_ld_ = types::is_zero_md(src_iter_c_d.md_)
            ? 0
            : src_iter_c_d.blocking_desc().strides[2];
    rnn.dst_iter_c_ld_ = types::is_zero_md(dst_iter_c_d.md_)
            ? 0
            : dst_iter_c_d.blocking_desc().strides[2];

    /* Set the correct number of weights parts */
    rnn.is_orig_gru = utils::one_of(
            rd.cell_kind, alg_kind::vanilla_gru, alg_kind::vanilla_augru);
    rnn.n_parts_weights_layer = 1;
    rnn.parts_weights_layer[0] = rnn.n_gates;
    rnn.parts_weights_layer[1] = 0;

    rnn.n_parts_weights_iter = rnn.is_orig_gru ? 2 : 1;
    rnn.parts_weights_iter[0] = rnn.is_orig_gru ? 2 : rnn.n_gates;
    rnn.parts_weights_iter[1] = rnn.is_orig_gru ? 1 : 0;

    rnn.n_parts_weights_projection = 1;
    rnn.parts_weights_projection[0] = 1;

    rnn.n_parts_bias = 1;
    rnn.parts_bias[0] = rnn.n_bias;
    rnn.parts_bias[1] = 0;

    /* Decide which gemm implementation to use: packed/nonpacked jit/cblas
     * and whether to merge gemm across iterations */
    const bool is_f32 = rnn.dt_conf == all_f32,
               is_bf16 = rnn.dt_conf == all_bf16;
    const bool is_gru = utils::one_of(rd.cell_kind, alg_kind::vanilla_gru,
            alg_kind::lbr_gru, alg_kind::vanilla_augru, alg_kind::lbr_augru);
    const bool is_inference = !rnn.is_training;

    // To be able to merge the GEMM on the layer input when not
    // copying, we need to have a trivial stride for the T dimension
    rnn.src_layer_is_trivial_stride = src_layer_d.blocking_desc().strides[0]
            == (rnn.src_layer_ld_ * rnn.mb);
    const auto dst_layer_is_trivial_stride
            = dst_layer_d.blocking_desc().strides[0]
            == (rnn.dst_layer_ld_ * rnn.mb);

    rnn.merge_gemm_layer = (!rnn.is_brgemm)
            ? ((rnn.is_fwd && rnn.src_layer_is_trivial_stride)
                      || ((rd.prop_kind == prop_kind::backward)
                              && dst_layer_is_trivial_stride))
                    && (((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd)
                            || rnn.is_int8_conf())
            : false;
    rnn.merge_gemm_iter = (!rnn.is_brgemm)
            ? dst_layer_is_trivial_stride && !(rnn.is_fwd || is_gru)
            : false;
    rnn.force_nocopy = false;
#if DNNL_X64
    rnn.force_nocopy = x64::mayiuse(x64::avx)
            && ((is_inference && (rnn.n_layer > 1 || rnn.mb < 100))
                    || (rnn.is_training && rnn.dhc < 500));
#endif

    /* Decide whether to copy bias */
    rnn.copy_bias = rnn.is_int8_conf();

    rnn.use_layer_packed_gemm = !rnn.is_brgemm
            ? utils::one_of(weights_layer_d.format_kind(), format_kind::any,
                      format_kind::rnn_packed)
                    && is_inference
                    && ((is_f32 && pack_sgemm_supported() && rnn.n_iter == 1)
                            || rnn.is_int8_conf() || is_bf16)
            : false;
    rnn.use_iter_packed_gemm = !rnn.is_brgemm
            ? utils::one_of(weights_iter_d.format_kind(), format_kind::any,
                      format_kind::rnn_packed)
                    && is_inference
                    && ((is_f32 && pack_sgemm_supported() && rnn.mb >= 16)
                            || rnn.is_int8_conf() || is_bf16)
            : false;
    rnn.use_projection_packed_gemm = !rnn.is_brgemm
            ? utils::one_of(weights_projection_d.format_kind(),
                      format_kind::any, format_kind::rnn_packed)
                    && is_inference
                    && ((is_f32 && pack_sgemm_supported() && rnn.n_iter == 1)
                            || rnn.is_int8_conf() || is_bf16)
            : false;

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
    // XXX: Threadpool runtime may use a different number of threads at the
    // execute and create stages. The GEMM packed API is not aware of the
    // number of threads as of now. In order to synchronize all layers, the
    // GEMM pack API should be modified to accept the number of threads instead
    // of taking it from `dnnl_get_max_threads()`, and rnn_packed_desc_t should
    // be updated with an `nthr` member to pass this information between the
    // different parts of the packed API, since the `get_size` call happens on
    // the RNN side, while packing happens on the reorder side. Consider
    // enabling later.
    // `test_iface_runtime_attr` was disabled for RNN with threadpool because
    // this is the only working approach for int8 computations in RNN for now.
    // Consider enabling it once resolved.
    rnn.use_layer_packed_gemm = false;
    rnn.use_iter_packed_gemm = false;
    rnn.use_projection_packed_gemm = false;
#endif

    /* Set packed gemm sizes */
    /* TODO: investigate the benefit of mixing packed and non-packed weights parts */
    const auto set_pack_sizes
            = [&](bool merge, bool &do_pack, size_t &weights_pack_size,
                      int &n_parts, int *parts, size_t *parts_pack_size,
                      size_t &comp_offset, int ic, int oc, int weights_oc,
                      dim_t data_ld) -> bool {
        bool pack = true;
        weights_pack_size = 0;
        for (int p = 0; p < n_parts; p++) {
            const dim_t m_p = rnn.is_fwd ? (parts[p] * oc) : ic;
            const dim_t k_p = rnn.is_fwd ? ic : (parts[p] * oc);
            const dim_t n_p = merge ? rnn.mb * rnn.n_iter : rnn.mb;
            bool pack_part = true;

            dnnl_status_t st = dnnl_success;
            switch (rnn.dt_conf) {
                case all_f32:
                    st = sgemm_pack_get_size("A", "N", "N", &m_p, &n_p, &k_p,
                            &m_p, &data_ld, &parts_pack_size[p], &pack_part);
                    break;
                case s8s8s8f32:
                case f32s8f32f32:
                case s8s8s8s8:
                case f32s8f32s8:
                    st = gemm_s8s8s32_pack_get_size("A", "N", "N", &m_p, &n_p,
                            &k_p, &m_p, &data_ld, &parts_pack_size[p],
                            &pack_part);
                    break;
                case u8u8u8f32:
                case f32u8f32f32:
                case u8u8u8u8:
                case f32u8f32u8:
                    st = gemm_s8u8s32_pack_get_size("A", "N", "N", &m_p, &n_p,
                            &k_p, &m_p, &data_ld, &parts_pack_size[p],
                            &pack_part);
                    break;
                case all_bf16:
                    st = gemm_bf16bf16f32_pack_get_size("A", "N", "N", &m_p,
                            &n_p, &k_p, &m_p, &data_ld, &parts_pack_size[p],
                            &pack_part);
                    break;
                default: assert(!"Unsupported configuration");
            }
            if (st != dnnl_success) return false;

            pack = pack && pack_part;
            weights_pack_size += rnn.n_layer * rnn.n_dir * parts_pack_size[p];
        }

        // NOTE: pack is updated only for f32. We force pack for int8
        do_pack = (rnn.dt_conf == all_f32) ? pack : true;
        comp_offset = weights_pack_size;
        const bool need_compensation = rnn.is_int8_conf();
        weights_pack_size += (need_compensation ? rnn.n_layer * rnn.n_dir : 0)
                * weights_oc * sizeof(float);

        return true;
    };
    // TODO: the activation leading dimension can vary for first layer/iteration
    if (rnn.use_layer_packed_gemm) {
        bool ok = set_pack_sizes(rnn.merge_gemm_layer,
                rnn.use_layer_packed_gemm, rnn.weights_layer_pack_size,
                rnn.n_parts_weights_layer, rnn.parts_weights_layer,
                rnn.part_weights_layer_pack_size, rnn.weights_layer_comp_offset,
                rnn.slc, rnn.dhc, rnn.n_gates * rnn.dhc,
                rnn.ws_states_layer_ld);
        if (!ok) return false;
    }

    if (rnn.use_iter_packed_gemm) {
        bool ok = set_pack_sizes(rnn.merge_gemm_iter, rnn.use_iter_packed_gemm,
                rnn.weights_iter_pack_size, rnn.n_parts_weights_iter,
                rnn.parts_weights_iter, rnn.part_weights_iter_pack_size,
                rnn.weights_iter_comp_offset, rnn.sic, rnn.dhc,
                rnn.n_gates * rnn.dhc, rnn.ws_states_iter_ld);
        if (!ok) return false;
    }

    if (rnn.use_projection_packed_gemm) {
        bool ok = set_pack_sizes(false, rnn.use_projection_packed_gemm,
                rnn.weights_projection_pack_size,
                rnn.n_parts_weights_projection, rnn.parts_weights_projection,
                rnn.part_weights_projection_pack_size,
                rnn.weights_projection_comp_offset, rnn.dhc, rnn.dic, rnn.dic,
                rnn.scratch_ht_ld);
        if (!ok) return false;
    }

    return true;
}

template <typename T>
void set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
        const memory_desc_wrapper &weights_layer_d,
        const memory_desc_wrapper &weights_iter_d,
        const memory_desc_wrapper &weights_projection_d,
        const memory_desc_wrapper &diff_weights_layer_d,
        const memory_desc_wrapper &diff_weights_iter_d,
        const memory_desc_wrapper &diff_weights_projection_d) {

    // Set leading dimensions for input weights arrays depending on input format
    const auto set_dims
            = [&](const memory_desc_wrapper &md, int &ld, int &nld) {
                  ld = 0;
                  nld = 0;
                  if (md.is_blocking_desc()) {
                      if (is_ldigo(md)) {
                          ld = (int)md.blocking_desc().strides[2];
                          nld = md.dims()[2];
                      } else if (is_ldgoi(md)) {
                          ld = (int)md.blocking_desc().strides[4];
                          nld = md.dims()[3] * md.dims()[4];
                      } else if (is_ldoi(md)) {
                          ld = (int)md.blocking_desc().strides[3];
                          nld = md.dims()[3];
                      } else if (is_ldio(md)) {
                          ld = (int)md.blocking_desc().strides[2];
                          nld = md.dims()[2];
                      } else
                          assert(!"unsupported weights format");
                  }
              };
    set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld);
    set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld);
    set_dims(weights_projection_d, rnn.weights_projection_ld,
            rnn.weights_projection_nld);
    if (!rnn.is_fwd) {
        set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld,
                rnn.diff_weights_layer_nld);
        set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld,
                rnn.diff_weights_iter_nld);
        set_dims(diff_weights_projection_d, rnn.diff_weights_projection_ld,
                rnn.diff_weights_projection_nld);
    }

    assert(weights_layer_d.data_type() == weights_iter_d.data_type());
    assert(IMPLICATION(diff_weights_layer_d.ndims() != 0,
            (diff_weights_layer_d.data_type()
                    == diff_weights_iter_d.data_type())));

    /* Set workspace sizes to store:
     * states to compute a pass
     * diff states to compute bwd pass (training only)
     * intermediate results from the gates
     */

    assert(sizeof(typename T::src_layer_t) == sizeof(typename T::dst_layer_t));
    assert(sizeof(typename T::src_iter_t) == sizeof(typename T::dst_iter_t));
}

template <typename T>
void set_workspace_sizes(rnn_conf_t &rnn, const rnn_desc_t &rd) {
    rnn.use_workspace = rnn.is_training;
    // TODO: for inference, we can make ws_states_* smaller, but
    // that would depend on the grid execution
    rnn.ws_states_layer_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir
            * (rnn.n_iter + 1) * rnn.mb * rnn.ws_states_layer_ld
            * sizeof(typename T::src_layer_t);
    rnn.ws_states_iter_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir
            * (rnn.n_iter + 1) * rnn.mb * rnn.ws_states_iter_ld
            * sizeof(typename T::src_iter_t);
    bool is_lstm = rd.cell_kind == dnnl_vanilla_lstm;
    rnn.ws_states_iter_c_size = is_lstm ? (size_t)(rnn.n_layer + 1) * rnn.n_dir
                    * (rnn.n_iter + 1) * rnn.mb * rnn.ws_states_iter_c_ld
                    * types::data_type_size(rnn.src_iter_c_dt)
                                        : 0;

    rnn.ws_diff_states_layer_size = rnn.is_training
            ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
                    * rnn.ws_diff_states_layer_ld
                    * sizeof(typename T::gemm_acc_t)
            : (size_t)0;
    rnn.ws_diff_states_iter_size = rnn.is_training
            ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
                    * rnn.ws_diff_states_iter_ld
                    * sizeof(typename T::gemm_acc_t)
            : (size_t)0;
    rnn.ws_diff_states_iter_c_size = rnn.is_training && is_lstm
            ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
                    * rnn.ws_diff_states_iter_c_ld
                    * sizeof(typename T::gemm_acc_t)
            : (size_t)0;

    rnn.ws_gates_size = rnn.is_training
            ? (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.ws_gates_nld
                    * rnn.ws_gates_ld * sizeof(typename T::gates_t)
            : (size_t)0;
    rnn.ws_ht_size = rnn.is_training
            ? (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.ws_ht_nld
                    * rnn.ws_ht_ld * sizeof(typename T::dst_iter_t)
            : (size_t)0;
    rnn.n_iter_scratch_gates
            = (rnn.merge_gemm_layer || rnn.merge_gemm_iter) ? rnn.n_iter : 1;
    rnn.scratch_gates_size = sizeof(typename T::scratch_t)
            * rnn.n_iter_scratch_gates * rnn.scratch_gates_nld
            * rnn.scratch_gates_ld;
    rnn.scratch_ht_size
            = sizeof(typename T::ht_t) * rnn.scratch_ht_nld * rnn.scratch_ht_ld;
    rnn.scratch_diff_ht_size = rnn.is_training ? sizeof(typename T::gemm_acc_t)
                    * rnn.scratch_diff_ht_nld * rnn.scratch_diff_ht_ld
                                               : (size_t)0;

    /* set other sizes */
    /// scratchpad buffer for each cell to hold intermediate data in gru/lbr_gru
    rnn.scratch_cell_size = rnn.is_lbr
            ? (size_t)rnn.scratch_gates_nld * rnn.scratch_gates_ld
                    * sizeof(typename T::gemm_acc_t)
            : (utils::one_of(rd.cell_kind, alg_kind::vanilla_gru,
                       alg_kind::vanilla_augru)
                            ? (size_t)rnn.ws_states_layer_nld
                                    * rnn.ws_states_layer_ld
                                    * sizeof(typename T::gemm_acc_t)
                            : 0);
    /// workspace needed for lbr GRU
    rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dhc
            * sizeof(typename T::gemm_acc_t);
    rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer
            * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * sizeof(float);
    /// bias ws needed to add compensation in int8
    rnn.ws_bias_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dhc
            * types::data_type_size(rnn.bias_dt);
}
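
// Illustrative sizing example (assumed values, not from this file): with
// n_layer = 2, n_dir = 1, n_iter = 10, mb = 32, ws_states_layer_ld = 64 and an
// f32 src_layer_t, set_workspace_sizes() above gives
//   ws_states_layer_size = (2 + 1) * 1 * (10 + 1) * 32 * 64 * 4 = 270336 bytes,
// i.e. one (layers + 1) x dirs x (iters + 1) grid of mb x ld state blocks.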

void set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset,
        size_t &ws_ht_offset, size_t &ws_state_layer_offset,
        size_t &ws_states_iter_offset, size_t &ws_states_iter_c_offset,
        size_t &ws_diff_states_layer_offset, size_t &ws_diff_states_iter_offset,
        size_t &ws_diff_states_iter_c_offset, size_t &ws_grid_comp_offset,
        size_t &ws_bias_offset, size_t &scratch_gates_offset,
        size_t &scratch_ht_offset, size_t &scratch_diff_ht_offset,
        size_t &scratch_cell_offset, size_t &scratchpad_size,
        size_t &workspace_size);

void get_scratchpad_and_workspace_sizes(
        const rnn_conf_t &rnn, size_t &scratchpad_size, size_t &workspace_size);
status_t set_expected_desc(rnn_conf_t &rnn, memory_desc_t &weights_md,
        weights_type_t weights_type);
status_t set_good_strides(memory_desc_t &weights_md, format_tag_t tag);

using byte = unsigned char;
template <size_t Tdims>
struct raw_array_offset_calculator_t {
    template <typename... Targs>
    raw_array_offset_calculator_t(
            const byte *base, const dim_t dt_size, Targs... Fargs)
        : base_ptr_(base), dt_size_(dt_size), dims_ {Fargs...} {}

    template <typename... Targs>
    raw_array_offset_calculator_t(std::nullptr_t, Targs... Fargs) = delete;

    template <typename... Targs>
    inline const void *operator()(Targs... Fargs) const {
        assert(static_cast<bool>(base_ptr_));
        return base_ptr_ + (offset(1, Fargs...) * dt_size_);
    }

private:
    template <typename... Targs>
    inline size_t offset(size_t const dimension, size_t element) const {
        return element;
    }
    template <typename... Targs>
    inline size_t offset(
            size_t const dimension, size_t theta, size_t element) const {
        return element + (dims_[dimension] * theta);
    }

    template <typename... Targs>
    inline size_t offset(size_t const dimension, size_t theta, size_t element,
            Targs... Fargs) const {
        const size_t t_prime = element + (dims_[dimension] * theta);
        return offset(dimension + 1, t_prime, Fargs...);
    }

    const byte *const base_ptr_;
    const dim_t dt_size_;
    const int dims_[Tdims];
};

template <typename... Targs>
raw_array_offset_calculator_t<sizeof...(Targs)> make_raw_aoc(
        const void *base, const dim_t dt_size, Targs... Fargs) {
    return raw_array_offset_calculator_t<sizeof...(Targs)>(
            static_cast<const byte *>(base), dt_size,
            std::forward<Targs>(Fargs)...);
}
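
// Usage sketch (illustrative only): compute the address of element (l, d, i)
// of a dense, row-major [L][D][I] array stored behind a type-erased pointer;
// base, L, D, I, l, d, i are hypothetical names:
//
//   const auto aoc = make_raw_aoc(base, sizeof(float), L, D, I);
//   const float *elem = static_cast<const float *>(aoc(l, d, i));
//   // aoc(l, d, i) == base + ((l * D + d) * I + i) * sizeof(float)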

template <typename T>
struct ws_gates_aoc {
    ws_gates_aoc(const rnn_conf_t &rnn, T *data)
        : gates_(data, rnn.ws_gates_nld, rnn.ws_gates_ld), DHC_(rnn.dhc) {}
    T &operator()(int batch, int gate, int dhc) const {
        return gates_(batch, gate * DHC_ + dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> gates_;
    const int DHC_;
};
using ws_gates_aoc_t = ws_gates_aoc<float>;
using ws_gates_aoc_s32_t = ws_gates_aoc<int32_t>;
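
// Usage sketch (illustrative only): index one cell's gates workspace as
// [batch][gate][channel] without computing strides by hand; ws_gates_cell is
// a hypothetical pointer to the current cell's slice of ws_gates:
//
//   ws_gates_aoc_t gates(rnn, ws_gates_cell);
//   float g = gates(b, 1, c); // batch b, gate 1, channel c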

template <typename T>
struct ws_ht_aoc {
    ws_ht_aoc(const rnn_conf_t &rnn, T *data)
        : ht_(data, rnn.ws_ht_nld, rnn.ws_ht_ld) {}
    T &operator()(int batch, int dhc) const { return ht_(batch, dhc); }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> ht_;
};

template <typename T>
struct scratch_gates_aoc {
    scratch_gates_aoc(const rnn_conf_t &rnn, T *data)
        : gates_(data, rnn.scratch_gates_nld, rnn.scratch_gates_ld)
        , DHC_(rnn.dhc) {}
    T &operator()(int batch, int gate, int dhc) const {
        return gates_(batch, gate * DHC_ + dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> gates_;
    const int DHC_;
};
using scratch_gates_aoc_t = scratch_gates_aoc<float>;
using scratch_gates_aoc_s32_t = scratch_gates_aoc<int32_t>;

template <typename T>
struct scratch_ht_aoc {
    scratch_ht_aoc(const rnn_conf_t &rnn, T *data)
        : ht_(data, rnn.scratch_ht_nld, rnn.scratch_ht_ld) {}
    T &operator()(int batch, int dhc) const { return ht_(batch, dhc); }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> ht_;
};
using scratch_ht_aoc_t = scratch_ht_aoc<float>;
using scratch_ht_aoc_s32_t = scratch_ht_aoc<int32_t>;

template <typename T>
struct weights_peephole_aoc_t {
    weights_peephole_aoc_t(const rnn_conf_t &rnn, T *data)
        : weights_peephole_(data, 3, rnn.dhc) {}
    T &operator()(int g, int dhc) const { return weights_peephole_(g, dhc); }

private:
    const utils::array_offset_calculator<T, 2> weights_peephole_;
};

float to_float(const void *data, const data_type_t dt);

struct bias_linear_exec_aoc_t {
    bias_linear_exec_aoc_t(const rnn_conf_t &rnn, void **bias)
        : bias_dt_(rnn.bias_dt), bias_present_(static_cast<bool>(bias)) {

        if (bias_dt_ == data_type::f32)
            new (std::addressof(bias_f32_aoc_))
                    utils::array_offset_calculator<float *, 3>(
                            reinterpret_cast<float **>(bias), rnn.n_layer,
                            rnn.n_dir, rnn.n_parts_bias);
        else
            new (std::addressof(bias_bf16_aoc_))
                    utils::array_offset_calculator<bfloat16_t *, 3>(
                            reinterpret_cast<bfloat16_t **>(bias), rnn.n_layer,
                            rnn.n_dir, rnn.n_parts_bias);
    }

    void **operator()(int layer, int dir) const {
        if (bias_present_) {
            if (bias_dt_ == data_type::f32)
                return reinterpret_cast<void **>(
                        &bias_f32_aoc_.operator()(layer, dir, 0));
            else if (bias_dt_ == data_type::bf16)
                return reinterpret_cast<void **>(
                        &bias_bf16_aoc_.operator()(layer, dir, 0));
        }

        return nullptr;
    }

    ~bias_linear_exec_aoc_t() {
        if (bias_dt_ == data_type::f32)
            bias_f32_aoc_.~array_offset_calculator<float *, 3>();
        else
            bias_bf16_aoc_.~array_offset_calculator<bfloat16_t *, 3>();
    }

    DNNL_DISALLOW_COPY_AND_ASSIGN(bias_linear_exec_aoc_t);
    bias_linear_exec_aoc_t(bias_linear_exec_aoc_t &&) = delete;
    bias_linear_exec_aoc_t &operator=(bias_linear_exec_aoc_t &&) = delete;

private:
    data_type_t bias_dt_;
    bool bias_present_;
    union {
        utils::array_offset_calculator<float *, 3> bias_f32_aoc_;
        utils::array_offset_calculator<bfloat16_t *, 3> bias_bf16_aoc_;
    };
};
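
// Note (informal): the union plus placement-new pattern above lets a single
// object hold either the f32 or the bf16 calculator without a heap
// allocation; the destructor destroys whichever member was constructed.
// Usage sketch (illustrative only, bias_ptrs is a hypothetical void** array):
//
//   bias_linear_exec_aoc_t bias_aoc(rnn, bias_ptrs);
//   void **cell_bias = bias_aoc(lay, dir); // f32 or bf16 per rnn.bias_dt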

template <typename T>
struct ws_states_layer_aoc {
    ws_states_layer_aoc(const rnn_conf_t &rnn, T *data, int leading_dim)
        : state_(data, rnn.ws_states_layer_nld, leading_dim) {}
    ws_states_layer_aoc(const rnn_conf_t &rnn, T *data)
        : state_(data, rnn.ws_states_layer_nld, rnn.ws_states_layer_ld) {}
    T &operator()(int batch, int dhc) const { return state_(batch, dhc); }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> state_;
};

template <typename T>
struct ws_states_iter_aoc {
    ws_states_iter_aoc(const rnn_conf_t &rnn, T *data, int leading_dim)
        : state_(data, rnn.ws_states_iter_nld, leading_dim) {}
    ws_states_iter_aoc(const rnn_conf_t &rnn, T *data)
        : state_(data, rnn.ws_states_iter_nld, rnn.ws_states_iter_ld) {}
    T &operator()(int batch, int dhc) const { return state_(batch, dhc); }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> state_;
};

template <typename T>
struct augru_attention_aoc {
    augru_attention_aoc(const rnn_conf_t &rnn, T *data)
        : state_(data, rnn.mb) {}
    T &operator()(int batch) const { return state_(batch); }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 1> state_;
};

template <typename T>
struct ws_diff_states_layer_aoc {
    ws_diff_states_layer_aoc(const rnn_conf_t &rnn, T *data)
        : diff_states_layer_(data, rnn.ws_diff_states_layer_nld,
                rnn.ws_diff_states_layer_ld) {}
    T &operator()(int batch, int dhc) const {
        return diff_states_layer_(batch, dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> diff_states_layer_;
};

template <typename T>
struct ws_diff_states_iter_aoc {
    ws_diff_states_iter_aoc(const rnn_conf_t &rnn, T *data)
        : diff_states_iter_(data, rnn.ws_diff_states_iter_nld,
                rnn.ws_diff_states_iter_ld) {}
    T &operator()(int batch, int dhc) const {
        return diff_states_iter_(batch, dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> diff_states_iter_;
};

template <typename T>
struct ws_diff_states_iter_c_aoc {
    ws_diff_states_iter_c_aoc(const rnn_conf_t &rnn, T *data)
        : diff_states_iter_c_(data, rnn.ws_diff_states_iter_c_nld,
                rnn.ws_diff_states_iter_c_ld) {}
    T &operator()(int batch, int dhc) const {
        return diff_states_iter_c_(batch, dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<T, 2> diff_states_iter_c_;
};

struct ws_diff_w_iter_aoc_t {
    ws_diff_w_iter_aoc_t(const rnn_conf_t &rnn, float *data)
        : diff_weights_iter_(
                data, rnn.diff_weights_iter_nld, rnn.diff_weights_iter_ld)
        , DHC_(rnn.dhc) {}
    float &operator()(int sic, int gate, int dhc) const {
        return diff_weights_iter_(sic, gate * DHC_ + dhc);
    }

private:
    const dnnl::impl::utils::array_offset_calculator<float, 2>
            diff_weights_iter_;
    const int DHC_;
};

const void *inc_ptr(const void *data, data_type_t data_type, int offset);
void *inc_ptr(void *data, data_type_t data_type, int offset);

} // namespace rnn_utils
} // namespace cpu
} // namespace impl
} // namespace dnnl
#endif