1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_RNN_RNN_UTILS_HPP |
18 | #define CPU_RNN_RNN_UTILS_HPP |
19 | |
20 | #include <memory> |
21 | #include <type_traits> |
22 | |
23 | #include "common/c_types_map.hpp" |
24 | #include "common/memory_desc_wrapper.hpp" |
25 | #include "common/primitive.hpp" |
26 | #include "common/utils.hpp" |
27 | |
28 | #include "cpu/platform.hpp" |
29 | |
30 | #include "cpu/gemm/gemm_pack.hpp" |
31 | |
32 | #if DNNL_X64 |
33 | #include "cpu/x64/cpu_isa_traits.hpp" |
34 | #endif |
35 | |
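// The *_sig macros below stamp out the long member-function signatures shared
// by the RNN postgemm, cell, and grid implementations, so declarations and
// definitions cannot drift apart. Illustrative use (hypothetical class name;
// the surrounding class is expected to provide the gates_t, scratch_t, etc.
// typedefs used in the parameter lists):
//
//     struct my_postgemm_t {
//         rnn_postgemm_sig(execute); // void execute(<args below>) const;
//     };
//     rnn_postgemm_sig(my_postgemm_t::execute) { /* postgemm body */ }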
36 | #define rnn_postgemm_sig(f) \ |
37 | void f(const rnn_utils::rnn_conf_t &rnn, \ |
38 | rnn_utils::cell_position_t cell_position, gates_t *ws_gates_, \ |
39 | scratch_t *scratch_gates_, const dst_layer_t *augru_attention_, \ |
40 | dst_layer_t *dst_layer_, void *dst_iter_c_, \ |
41 | const src_iter_t *src_iter_, const void *src_iter_c_, \ |
42 | gemm_acc_t *diff_src_layer_, gemm_acc_t *diff_augru_attention_, \ |
43 | gemm_acc_t *diff_src_iter_, gemm_acc_t *diff_src_iter_c_, \ |
44 | gemm_acc_t *diff_dst_layer_, gemm_acc_t *diff_dst_iter_, \ |
45 | gemm_acc_t *diff_dst_iter_c_, const float *weights_peephole_, \ |
46 | const void *bias_, gates_t *ws_grid_, scratch_t *scratch_cell_, \ |
47 | dst_iter_t *dst_iter_, float *weights_scales_, int block_step) \ |
48 | const |
49 | |
50 | #if DNNL_X64 |
51 | #define rnn_merged_layer_execution_sig(f) \ |
52 | dnnl_status_t f(const exec_ctx_t &ctx, const rnn_utils::rnn_conf_t &rnn, \ |
53 | rnn_utils::cell_position_t cell_position, weights_t **w_layer_, \ |
54 | const src_layer_t *src_layer_, scratch_t *scratch_gates_, \ |
55 | gemm_acc_t *diff_src_layer_, gemm_acc_t *diff_w_layer_, \ |
56 | gemm_acc_t *amx_scratchpad, \ |
57 | x64::brgemm_batch_element_t *addr_batch_global) const |
58 | |
59 | #define rnn_cell_execution_sig(f) \ |
60 | dnnl_status_t f(const exec_ctx_t &ctx, const rnn_utils::rnn_conf_t &rnn, \ |
61 | rnn_utils::cell_position_t cell_position, dst_layer_t *dst_layer_, \ |
62 | void *dst_iter_c_, gemm_acc_t *diff_src_layer_, \ |
63 | gemm_acc_t *diff_augru_attention_, gemm_acc_t *diff_src_iter_, \ |
64 | gemm_acc_t *diff_src_iter_c_, weights_t **w_layer_, \ |
65 | weights_t **w_iter_, weights_t **w_projection_, \ |
66 | const float *weights_peephole_, const float *w_proj_comp, \ |
67 | void **bias_, const src_layer_t *src_layer_, \ |
68 | const src_layer_t *augru_attention_, const src_iter_t *src_iter_, \ |
69 | const void *src_iter_c_, gemm_acc_t *diff_dst_layer_, \ |
70 | gemm_acc_t *diff_dst_iter_, gemm_acc_t *diff_dst_iter_c_, \ |
71 | gemm_acc_t *diff_w_layer_, gemm_acc_t *diff_w_iter_, \ |
72 | float *diff_weights_projection_, float *diff_weights_peephole_, \ |
73 | float *diff_bias_, gates_t *ws_gates_, scratch_t *scratch_gates_, \ |
74 | ht_t *proj_ht_, gemm_acc_t *scratch_diff_ht_, gates_t *ws_grid_, \ |
75 | scratch_t *scratch_cell_, scratch_t *scratch_gates_blocked_, \ |
76 | scratch_t *scratch_src_layer_, scratch_t *scratch_src_iter_, \ |
77 | dst_iter_t *dst_iter_, gemm_acc_t *amx_scratchpad, \ |
78 | x64::brgemm_batch_element_t *addr_batch_global) const |
79 | |
80 | #define rnn_grid_execution_sig(f) \ |
81 | dnnl_status_t f(const exec_ctx_t &ctx, const rnn_utils::rnn_conf_t &rnn, \ |
82 | weights_t **weights_layer_, weights_t **weights_iter_, \ |
83 | weights_t **weights_projection_, const float *weights_peephole_, \ |
84 | const float *w_proj_comp, void **bias_, \ |
85 | const src_layer_t *src_layer_, \ |
86 | const src_layer_t *augru_attention_, const src_iter_t *src_iter_, \ |
87 | const void *src_iter_c_, dst_layer_t *dst_layer_, \ |
88 | dst_iter_t *dst_iter_, void *dst_iter_c_, \ |
89 | src_layer_t *ws_states_layer_, src_iter_t *ws_states_iter_, \ |
90 | void *ws_states_iter_c_, gemm_acc_t *ws_diff_states_layer_, \ |
91 | gemm_acc_t *ws_diff_states_iter_, \ |
92 | gemm_acc_t *ws_diff_states_iter_c_, gates_t *ws_gates_, \ |
93 | ht_t *ws_ht_, gates_t *ws_grid_, scratch_t *scratch_gates_, \ |
94 | ht_t *scratch_ht_, gemm_acc_t *scratch_diff_ht_, \ |
95 | scratch_t *scratch_cell_, scratch_t *scratch_gates_blocked_, \ |
96 | scratch_t *scratch_src_layer_, scratch_t *scratch_src_iter_, \ |
97 | gemm_acc_t *diff_augru_attention_, \ |
98 | gemm_acc_t *diff_weights_layer_, gemm_acc_t *diff_weights_iter_, \ |
99 | float *diff_weights_projection_, float *diff_weights_peephole_, \ |
100 | float *diff_bias_, gemm_acc_t *amx_scratchpad, \ |
101 | x64::brgemm_batch_element_t *addr_batch_global) const |
102 | #else |
103 | #define rnn_merged_layer_execution_sig(f) \ |
104 | dnnl_status_t f(const rnn_utils::rnn_conf_t &rnn, \ |
105 | rnn_utils::cell_position_t cell_position, weights_t **w_layer_, \ |
106 | const src_layer_t *src_layer_, scratch_t *scratch_gates_, \ |
107 | gemm_acc_t *diff_src_layer_, gemm_acc_t *diff_w_layer_) const |
108 | |
109 | #define rnn_cell_execution_sig(f) \ |
110 | dnnl_status_t f(const rnn_utils::rnn_conf_t &rnn, \ |
111 | rnn_utils::cell_position_t cell_position, dst_layer_t *dst_layer_, \ |
112 | void *dst_iter_c_, gemm_acc_t *diff_src_layer_, \ |
113 | gemm_acc_t *diff_augru_attention_, gemm_acc_t *diff_src_iter_, \ |
114 | gemm_acc_t *diff_src_iter_c_, weights_t **w_layer_, \ |
115 | weights_t **w_iter_, weights_t **w_projection_, \ |
116 | const float *weights_peephole_, const float *w_proj_comp, \ |
117 | void **bias_, const src_layer_t *src_layer_, \ |
118 | const src_layer_t *augru_attention_, const src_iter_t *src_iter_, \ |
119 | const void *src_iter_c_, gemm_acc_t *diff_dst_layer_, \ |
120 | gemm_acc_t *diff_dst_iter_, gemm_acc_t *diff_dst_iter_c_, \ |
121 | gemm_acc_t *diff_w_layer_, gemm_acc_t *diff_w_iter_, \ |
122 | float *diff_weights_projection_, float *diff_weights_peephole_, \ |
123 | float *diff_bias_, gates_t *ws_gates_, scratch_t *scratch_gates_, \ |
124 | ht_t *proj_ht_, gemm_acc_t *scratch_diff_ht_, gates_t *ws_grid_, \ |
125 | scratch_t *scratch_cell_, dst_iter_t *dst_iter_, \ |
126 | gemm_acc_t *amx_scratchpad) const |
127 | |
128 | #define rnn_grid_execution_sig(f) \ |
129 | dnnl_status_t f(const rnn_utils::rnn_conf_t &rnn, \ |
130 | weights_t **weights_layer_, weights_t **weights_iter_, \ |
131 | weights_t **weights_projection_, const float *weights_peephole_, \ |
132 | const float *w_proj_comp, void **bias_, \ |
133 | const src_layer_t *src_layer_, \ |
134 | const src_layer_t *augru_attention_, const src_iter_t *src_iter_, \ |
135 | const void *src_iter_c_, dst_layer_t *dst_layer_, \ |
136 | dst_iter_t *dst_iter_, void *dst_iter_c_, \ |
137 | src_layer_t *ws_states_layer_, src_iter_t *ws_states_iter_, \ |
138 | void *ws_states_iter_c_, gemm_acc_t *ws_diff_states_layer_, \ |
139 | gemm_acc_t *ws_diff_states_iter_, \ |
140 | gemm_acc_t *ws_diff_states_iter_c_, gates_t *ws_gates_, \ |
141 | ht_t *ws_ht_, gates_t *ws_grid_, scratch_t *scratch_gates_, \ |
142 | ht_t *scratch_ht_, gemm_acc_t *scratch_diff_ht_, \ |
143 | scratch_t *scratch_cell_, gemm_acc_t *diff_augru_attention_, \ |
144 | gemm_acc_t *diff_weights_layer_, gemm_acc_t *diff_weights_iter_, \ |
145 | float *diff_weights_projection_, float *diff_weights_peephole_, \ |
146 | float *diff_bias_, gemm_acc_t *amx_scratchpad) const |
147 | #endif |
148 | |
149 | #define rnn_gemm_sig(f) \ |
150 | dnnl_status_t f(const char transA, const char transB, dim_t m, dim_t n, \ |
151 | dim_t k, const float alpha, const weights_t *a_, const dim_t ldA, \ |
152 | const gemm_data_t *b_, const dim_t ldB, const float beta, \ |
153 | gemm_acc_t *c_, const dim_t ldC) const |
154 | |
155 | #define rnn_bias_prepare_sig(f) \ |
156 | void f(const rnn_utils::rnn_conf_t &rnn, void **bias_, const void *b_, \ |
157 | void *scratch_bias_) const |
158 | |
159 | #define rnn_bias_prepare_sig_templ(f) \ |
160 | template <typename T> \ |
161 | static void f(const rnn_utils::rnn_conf_t &rnn, T **bias_, const T *b_, \ |
162 | T *scratch_bias_) |
163 | |
164 | #define rnn_bias_finalize_sig(f) \ |
165 | void f(const rnn_utils::rnn_conf_t &rnn, void *scratch_bias_, \ |
166 | const float *w_iter_comp, const float *w_layer_comp) const |
167 | |
168 | #define rnn_weights_assign_sig(f) \ |
169 | void f(const rnn_utils::rnn_conf_t &rnn, const memory_desc_t *md, \ |
170 | int n_parts, const int *gates_per_part, weights_t **weights_, \ |
171 | const weights_t *w_) const |
172 | |
173 | namespace dnnl { |
174 | namespace impl { |
175 | namespace cpu { |
176 | |
177 | namespace rnn_utils { |
178 | |
179 | enum execution_direction_t { |
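    // l2r / r2l: unidirectional execution, iterating left-to-right or
    // right-to-left over the time dimension; bi_concat / bi_sum:
    // bidirectional execution with dst_layer concatenated, respectively
    // summed, over the two directions.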
180 | l2r, |
181 | r2l, |
182 | bi_concat, |
183 | bi_sum, |
184 | }; |
185 | |
186 | enum cell_position_t { |
187 | middle_cell = 0x0, |
188 | first_layer = 0x1, |
189 | first_iter = 0x2, |
190 | last_layer = 0x4, |
191 | last_iter = 0x8, |
192 | c_state_first_iter = 0x10, |
193 | c_state_last_iter = 0x20 |
194 | }; |
195 | |
196 | enum class weights_type_t { |
197 | layer, |
198 | iter, |
199 | projection, |
200 | peephole, |
201 | }; |
202 | |
203 | inline cell_position_t &operator|=(cell_position_t &lhs, cell_position_t rhs) { |
204 | lhs = static_cast<cell_position_t>( |
205 | static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs)); |
206 | return lhs; |
207 | } |
208 | |
209 | inline cell_position_t operator|(cell_position_t lhs, cell_position_t rhs) { |
210 | return static_cast<cell_position_t>( |
211 | static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs)); |
212 | } |
213 | |
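// Naming convention for the mixed configurations below: the u8/s8 tokens mark
// the unsigned/signed int8 flavor (cf. is_signed_int8_conf() and
// is_unsigned_int8_conf() further down), and the last token matches the
// dst_layer data type selected in init_conf().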
214 | enum data_type_conf_t { |
215 | all_f32, |
216 | all_bf16, |
217 | u8u8u8f32, |
218 | f32u8f32f32, |
219 | u8u8u8u8, |
220 | f32u8f32u8, |
221 | s8s8s8f32, |
222 | f32s8f32f32, |
223 | s8s8s8s8, |
224 | f32s8f32s8 |
225 | }; |
226 | |
227 | enum brgemm_rnn_execute_loop_order_t { |
228 | // default for kernels w/o loop order choice |
229 | undefined = 0x0, |
230 | // m_blocking loop is outermost |
231 | mblk_nblk = 0x1, |
232 | // n_blocking loop is outermost |
233 | nblk_mblk = 0x2 |
234 | }; |
235 | |
236 | struct diff_src_brgemm_conf_t { |
237 | dim_t M = 0, N = 0, K = 0; |
238 | |
239 | dim_t n_block = 0, N_blocks = 0, n_tail = 0; |
240 | dim_t m_block = 0, M_blocks = 0; |
241 | |
242 | dim_t K_blocks = 0, k_block = 0, k_tail = 0; |
243 | dim_t Kpadded = 0; |
244 | |
245 | dim_t N_iter = 0, N_layer = 0; |
246 | dim_t N_layer_blocks = 0, n_layer_tail = 0; |
247 | dim_t N_iter_blocks = 0, n_iter_tail = 0; |
248 | dim_t LDA = 0, LDB = 0, LDC = 0; |
249 | |
250 | #if DNNL_X64 |
251 | x64::cpu_isa_t isa = x64::isa_undef; |
252 | #endif |
253 | |
254 | brgemm_rnn_execute_loop_order_t loop_order |
255 | = brgemm_rnn_execute_loop_order_t::undefined; |
256 | int gates_block; |
257 | }; |
258 | |
259 | struct diff_wei_brgemm_conf_t { |
260 | dim_t M = 0, M_layer = 0, M_iter = 0, N = 0, K = 0; |
261 | |
262 | dim_t n_block = 0, N_blocks = 0, n_tail = 0; |
263 | dim_t m_block = 0, M_blocks = 0; |
264 | dim_t K_blocks = 0, k_block = 0, k_tail = 0; |
265 | dim_t Kpadded = 0; |
266 | |
267 | dim_t LDA_layer = 0, LDA_iter = 0, LDB = 0, LDC_iter = 0, LDC_layer = 0; |
268 | |
269 | bool global_transpose = false; |
270 | |
271 | #if DNNL_X64 |
272 | x64::cpu_isa_t isa = x64::isa_undef; |
273 | #endif |
274 | |
275 | brgemm_rnn_execute_loop_order_t loop_order |
276 | = brgemm_rnn_execute_loop_order_t::undefined; |
277 | }; |
278 | |
279 | struct rnn_conf_t { |
280 | execution_direction_t exec_dir; |
281 | data_type_conf_t dt_conf; |
    data_type_t cell_dt = data_type::undef; // The data type used by the cell
283 | data_type_t bias_dt = data_type::undef; |
284 | data_type_t src_iter_c_dt = data_type::undef; |
285 | data_type_t dst_iter_c_dt = data_type::undef; |
286 | |
287 | int n_layer = 0, n_iter = 0, n_dir = 0, n_gates = 0, n_states = 0; |
288 | int mb = 0; |
289 | int slc = 0, sic = 0, dhc = 0, dic = 0, dlc = 0; |
290 | //int gates_ld, gates_nld, gates_ws_ld; |
291 | |
292 | int n_parts_weights_layer = 0; |
293 | int parts_weights_layer[DNNL_RNN_MAX_N_PARTS]; |
294 | size_t part_weights_layer_pack_size[DNNL_RNN_MAX_N_PARTS]; |
295 | |
296 | int n_parts_weights_iter = 0; |
297 | int parts_weights_iter[DNNL_RNN_MAX_N_PARTS]; |
298 | size_t part_weights_iter_pack_size[DNNL_RNN_MAX_N_PARTS]; |
299 | |
300 | int n_parts_weights_projection = 0; |
301 | int parts_weights_projection[DNNL_RNN_MAX_N_PARTS]; |
302 | size_t part_weights_projection_pack_size[DNNL_RNN_MAX_N_PARTS]; |
303 | |
304 | int n_bias = 0, n_parts_bias = 0, parts_bias[DNNL_RNN_MAX_N_PARTS]; |
305 | |
306 | /* Size of packed data in bytes */ |
307 | size_t weights_layer_comp_offset = 0, weights_layer_pack_size = 0; |
308 | size_t weights_iter_comp_offset = 0, weights_iter_pack_size = 0; |
309 | size_t weights_projection_comp_offset = 0, weights_projection_pack_size = 0; |
310 | |
    bool copy_bias = false;
312 | int weights_layer_ld = 0, weights_layer_nld = 0; |
313 | int diff_weights_layer_ld = 0, diff_weights_layer_nld = 0; |
314 | int weights_iter_ld = 0, weights_iter_nld = 0; |
315 | int diff_weights_iter_ld = 0, diff_weights_iter_nld = 0; |
316 | int weights_projection_ld = 0, weights_projection_nld = 0; |
317 | int diff_weights_projection_ld = 0, diff_weights_projection_nld = 0; |
318 | |
319 | int proj_ht_ld = 0, proj_ht_nld = 0; |
320 | |
321 | int ws_gates_ld = 0, ws_gates_nld = 0; |
322 | int ws_ht_ld = 0, ws_ht_nld = 0; |
323 | int ws_states_layer_ld = 0, ws_states_layer_nld = 0; |
324 | int ws_states_iter_ld = 0, ws_states_iter_nld = 0; |
325 | int ws_states_iter_c_ld = 0, ws_states_iter_c_nld = 0; |
326 | int ws_diff_states_layer_ld = 0, ws_diff_states_layer_nld = 0; |
327 | int ws_diff_states_iter_ld = 0, ws_diff_states_iter_nld = 0; |
328 | int ws_diff_states_iter_c_ld = 0, ws_diff_states_iter_c_nld = 0; |
329 | |
330 | int scratch_gates_ld = 0, scratch_gates_nld = 0; |
331 | int scratch_ht_ld = 0, scratch_ht_nld = 0; |
332 | int scratch_diff_ht_ld = 0, scratch_diff_ht_nld = 0; |
333 | |
334 | int src_layer_ld_ = 0, src_layer_nld_ = 0; |
335 | int src_iter_ld_ = 0, src_iter_nld_ = 0; |
336 | int src_iter_c_ld_ = 0, src_iter_c_nld_ = 0; |
337 | int dst_layer_ld_ = 0, dst_layer_nld_ = 0; |
338 | int dst_iter_ld_ = 0, dst_iter_nld_ = 0; |
339 | int dst_iter_c_ld_ = 0, dst_iter_c_nld_ = 0; |
340 | |
341 | int weights_iter_compensation_size = 0, weights_layer_compensation_size = 0; |
    bool is_fwd = false, is_training = false, is_lbr = false,
            is_lstm_peephole = false, is_lstm_projection = false,
            is_augru = false, is_orig_gru = false;
    bool use_workspace = false;
345 | |
346 | // Size of workspace for each tensor in bytes |
347 | // Notes: |
    // 1. For non-LSTMP, ws_states_iter_size == ws_states_layer_size. The
    //    corresponding pointers should point to the same places.
350 | size_t ws_gates_size = 0; |
351 | size_t ws_ht_size = 0; |
352 | size_t ws_states_layer_size = 0; |
353 | size_t ws_states_iter_size = 0; |
354 | size_t ws_states_iter_c_size = 0; |
355 | size_t ws_diff_states_layer_size = 0; |
356 | size_t ws_diff_states_iter_size = 0; |
357 | size_t ws_diff_states_iter_c_size = 0; |
358 | size_t scratch_gates_size = 0; |
359 | |
360 | size_t scratch_gates_blocked_size = 0; |
361 | size_t scratch_gates_blocked_nested_reorder_size = 0; |
362 | size_t scratch_src_layer_size = 0; |
363 | size_t scratch_src_layer_nested_reorder_size = 0; |
364 | size_t scratch_src_iter_size = 0; |
365 | size_t scratch_src_iter_nested_reorder_size = 0; |
366 | |
367 | size_t scratch_ht_size = 0; |
368 | size_t scratch_diff_ht_size = 0; |
369 | size_t scratch_cell_size = 0; |
370 | size_t ws_grid_comp_size = 0; |
371 | size_t ws_per_cell = 0; |
372 | size_t ws_bias_size = 0; |
373 | |
374 | bool src_layer_is_trivial_stride = false; |
375 | bool merge_gemm_iter = false, merge_gemm_layer = false, |
376 | force_nocopy = false, use_layer_packed_gemm = false, |
377 | use_iter_packed_gemm = false, use_projection_packed_gemm = false; |
378 | int n_iter_scratch_gates = 0; |
379 | |
380 | inline bool is_int8_conf() const { |
381 | return is_signed_int8_conf() || is_unsigned_int8_conf(); |
382 | } |
383 | inline bool is_signed_int8_conf() const { |
384 | return utils::one_of( |
385 | dt_conf, s8s8s8f32, f32s8f32f32, s8s8s8s8, f32s8f32s8); |
386 | } |
387 | inline bool is_unsigned_int8_conf() const { |
388 | return utils::one_of( |
389 | dt_conf, u8u8u8f32, f32u8f32f32, u8u8u8u8, f32u8f32u8); |
390 | } |
391 | |
392 | inline bool is_cell_dt_int8() const { |
393 | return is_cell_dt_signed_int8() || is_cell_dt_unsigned_int8(); |
394 | } |
395 | inline bool is_cell_dt_signed_int8() const { |
396 | return cell_dt == data_type::s8; |
397 | } |
398 | inline bool is_cell_dt_unsigned_int8() const { |
399 | return cell_dt == data_type::u8; |
400 | } |
401 | |
402 | inline bool is_cell_int8_amx() const { |
403 | #if DNNL_X64 |
404 | return brgemm_isa == x64::avx512_core_amx && is_cell_dt_int8(); |
405 | #else |
406 | return false; |
407 | #endif |
408 | } |
409 | |
410 | inline bool is_bf16_conf() const { return dt_conf == all_bf16; } |
411 | |
412 | inline bool is_f32_conf() const { return dt_conf == all_f32; } |
413 | |
414 | inline bool is_cell_dt_f32() const { return cell_dt == data_type::f32; } |
415 | inline bool is_cell_dt_bf16() const { return cell_dt == data_type::bf16; } |
416 | inline bool is_cell_bf16_amx() const { |
417 | #if DNNL_X64 |
418 | return brgemm_isa == x64::avx512_core_amx && is_cell_dt_bf16(); |
419 | #else |
420 | return false; |
421 | #endif |
422 | } |
423 | inline bool is_bf32() const { return is_cell_bf16_amx() && is_f32_conf(); } |
424 | |
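    // The skip_*_copy() predicates below tell when a user tensor can be used
    // directly instead of being copied into the workspace (only possible for
    // left-to-right execution and compatible data type configurations). The
    // *_ld(cell_position) helpers that follow pick the matching leading
    // dimension: the user tensor's one when the copy is skipped at this cell
    // position, the workspace one otherwise.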
425 | inline bool skip_src_layer_copy() const { |
426 | return (exec_dir == l2r) && !is_bf32() |
427 | && utils::one_of(dt_conf, s8s8s8f32, f32s8f32f32, s8s8s8s8, |
428 | f32s8f32s8, u8u8u8u8, u8u8u8f32, f32u8f32u8, |
429 | f32u8f32f32, all_f32, all_bf16); |
430 | } |
431 | inline bool skip_src_iter_copy() const { |
432 | return (exec_dir == l2r) && (src_iter_ld_ > 0) && !is_bf32() |
433 | && utils::one_of(dt_conf, s8s8s8s8, s8s8s8f32, u8u8u8u8, |
434 | u8u8u8f32, all_f32, all_bf16); |
435 | } |
436 | inline bool skip_dst_layer_copy() const { |
437 | return (exec_dir == l2r) && !is_bf32() |
438 | && utils::one_of(dt_conf, s8s8s8s8, f32s8f32s8, u8u8u8u8, |
439 | f32u8f32u8, all_f32, all_bf16); |
440 | } |
441 | inline bool skip_dst_iter_copy() const { |
442 | return (exec_dir == l2r) && (dst_iter_ld_ > 0) && !is_bf32() |
443 | && utils::one_of(dt_conf, s8s8s8s8, s8s8s8f32, u8u8u8u8, |
444 | u8u8u8f32, all_f32, all_bf16); |
445 | } |
446 | |
447 | inline dim_t src_layer_ld(cell_position_t cell_position) const { |
448 | return (cell_position & first_layer) && skip_src_layer_copy() |
449 | ? src_layer_ld_ |
450 | : (cell_position & last_iter) && skip_dst_iter_copy() |
451 | ? dst_iter_ld_ |
452 | : ws_states_layer_ld; |
453 | } |
454 | |
455 | inline dim_t src_iter_ld(cell_position_t cell_position) const { |
456 | return (cell_position & first_iter) && skip_src_iter_copy() |
457 | ? src_iter_ld_ |
458 | : ((cell_position & last_layer) && skip_dst_layer_copy() |
459 | && !(cell_position & first_iter) |
460 | ? dst_layer_ld_ |
461 | : ws_states_iter_ld); |
462 | } |
463 | |
464 | inline dim_t layer_brgemm_desc(cell_position_t cell_position) const { |
465 | return ((cell_position & first_layer) && skip_src_layer_copy()) |
466 | ? 0 |
467 | : ((cell_position & last_iter) && skip_dst_iter_copy()) ? 1 : 2; |
468 | } |
469 | |
470 | inline dim_t iter_brgemm_desc(cell_position_t cell_position) const { |
471 | return ((cell_position & first_iter) && skip_src_iter_copy()) |
472 | ? 0 |
473 | : ((cell_position & last_layer) && skip_dst_layer_copy() |
474 | && !(cell_position & first_iter)) |
475 | ? 1 |
476 | : 2; |
477 | } |
478 | |
    // Returns the index of the brgemm kernel for the 2nd part of the
    // iteration gemm in the vanilla GRU cell for the current position.
    // Note: this method must be kept in sync with dst_iter_part2_ld() and
    // with the LDA2_2[] values initialization order.
483 | inline dim_t iter_part2_brgemm_desc(cell_position_t cell_position) const { |
484 | if (cell_position & last_layer) { |
485 | return (cell_position & last_layer) && skip_dst_layer_copy() |
486 | ? 0 |
487 | : (cell_position & last_iter) && skip_dst_iter_copy() ? 1 |
488 | : 2; |
489 | } else { |
490 | return (cell_position & last_iter) && skip_dst_iter_copy() ? 1 : 3; |
491 | } |
492 | } |
493 | |
494 | inline dim_t src_iter_c_ld(cell_position_t cell_position) const { |
495 | return (cell_position & c_state_first_iter) ? src_iter_c_ld_ |
496 | : ws_states_iter_c_ld; |
497 | } |
498 | |
499 | inline dim_t dst_layer_ld( |
500 | cell_position_t cell_position, bool after_proj = false) const { |
501 | // We use scratch_ht and not dst_layer for lstmp |
502 | if (is_lstm_projection && !after_proj) return scratch_ht_ld; |
503 | |
504 | return (cell_position & last_layer) && skip_dst_layer_copy() |
505 | ? dst_layer_ld_ |
506 | : (cell_position & last_iter) && skip_dst_iter_copy() |
507 | ? dst_iter_ld_ |
508 | : ws_states_layer_ld; |
509 | } |
510 | |
511 | inline dim_t dst_brgemm_desc( |
512 | cell_position_t cell_position, bool after_proj = false) const { |
513 | // We use scratch_ht and not dst_layer for lstmp |
514 | if (is_lstm_projection && !after_proj) return 0; |
515 | |
516 | return (cell_position & last_layer) && skip_dst_layer_copy() |
517 | ? 1 |
518 | : (cell_position & last_iter) && skip_dst_iter_copy() ? 2 : 3; |
519 | } |
520 | |
521 | inline dim_t dst_iter_ld(cell_position_t cell_position) const { |
522 | return (cell_position & last_iter) && skip_dst_iter_copy() |
523 | ? dst_iter_ld_ |
524 | : ws_states_iter_ld; |
525 | } |
526 | |
    // Returns the dst tensor leading dimension for the 2nd part of the
    // iteration gemm in the vanilla GRU cell for the current position
529 | inline dim_t dst_iter_part2_ld(cell_position_t cell_position) const { |
530 | return (cell_position & last_layer) ? dst_layer_ld(cell_position) |
531 | : dst_iter_ld(cell_position); |
532 | } |
533 | |
534 | inline dim_t dst_iter_c_ld(cell_position_t cell_position) const { |
535 | return (cell_position & c_state_last_iter) ? dst_iter_c_ld_ |
536 | : ws_states_iter_c_ld; |
537 | } |
538 | |
539 | // // when skipping copy, the output ld can be states_ws_ld, |
540 | // // dst_iter_ld or dst_layer_ld depending on the cell position |
541 | // inline dim_t dst_ld(cell_position_t cell_position) const { |
542 | // return (cell_position & last_layer) ? dst_layer_ld(cell_position) |
543 | // : dst_iter_ld(cell_position); |
544 | // } |
545 | inline dim_t dst_copy_ld(cell_position_t cell_position) const { |
546 | return dst_iter_ld(cell_position); |
547 | } |
548 | |
549 | inline bool need_gemm_layer(cell_position_t cell_position) const { |
        // In case of merge_gemm_layer we might still need a layer gemm if we
        // store the states of the last iteration in the destination memory.
        // The exception to this rule is the first layer, in which case all
        // states are kept in the user's src_layer, hence making the full
        // merged gemm possible.
555 | return IMPLICATION(merge_gemm_layer, |
556 | skip_dst_iter_copy() && (cell_position & last_iter) |
557 | && !(cell_position & first_layer)); |
558 | } |
559 | bool is_brgemm; |
560 | |
561 | diff_src_brgemm_conf_t diff_src_brgemm; |
562 | diff_wei_brgemm_conf_t diff_wei_brgemm; |
563 | |
564 | dim_t M, N, K1, K2; |
565 | |
566 | dim_t LDB1, LDB2; |
567 | dim_t LDA1[3]; |
568 | dim_t LDA2[3]; |
569 | // LDA for iter part2 gemm in vanilla gru cell |
570 | dim_t LDA2_2[4]; |
571 | dim_t LDC; |
572 | |
573 | dim_t m_block, M_blocks; |
574 | dim_t n_block, N_blocks, n_tail; |
575 | |
576 | dim_t k2_block, k1_block, k1_tail, k2_tail; |
577 | dim_t KB1_blocks, KB2_blocks; |
578 | dim_t K1padded, K2padded; |
579 | |
580 | dim_t Kproj, Kprojpadded; |
581 | dim_t kproj_block, KBproj_blocks, kproj_tail; |
582 | |
583 | dim_t Nproj, Nproj_blocks, nproj_tail; |
584 | dim_t LDAproj, LDBproj, LDCproj[4]; |
585 | int dhc_block_peephole, dhc_tail_peephole, dhc_blocks_peephole; |
586 | bool brgemm_fwd_iter_layer_fuse_possible = false; |
587 | |
588 | dim_t nthr; |
589 | #if DNNL_X64 |
590 | x64::cpu_isa_t brgemm_isa; |
591 | #endif |
592 | bool unfused_post_gemm; |
593 | brgemm_rnn_execute_loop_order_t loop_order |
594 | = brgemm_rnn_execute_loop_order_t::undefined; |
595 | |
596 | // for merged layer computation in brgemm |
597 | dim_t Mlayermerged; |
598 | dim_t mlayermerged_block, Mlayermerged_blocks; |
599 | }; |
600 | |
601 | bool is_ldigo(const memory_desc_wrapper &md); |
602 | bool is_ldgoi(const memory_desc_wrapper &md); |
603 | bool is_ldio(const memory_desc_wrapper &md); |
604 | bool is_ldoi(const memory_desc_wrapper &md); |
605 | bool is_ldigo_blocked(const memory_desc_wrapper &md); |
606 | bool is_ldgoi_blocked(const memory_desc_wrapper &md); |
607 | bool is_ldio_blocked(const memory_desc_wrapper &md); |
608 | bool is_ldoi_blocked(const memory_desc_wrapper &md); |
609 | |
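// A minimal sketch of the usual get_good_ld() policy (the definition lives in
// rnn_utils.cpp; exact constants there may differ): pad the leading dimension
// up to a full cache line and step off large power-of-two values to limit 4K
// aliasing across rows:
//
//     int ld = rnd_up(dim, 64 / sizeof_dt); // round up to 64-byte lines
//     return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld; // dodge aliasing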
610 | int get_good_ld(int dim, int sizeof_dt); |
611 | |
612 | template <typename T> |
613 | bool init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, |
614 | const primitive_attr_t &attr, const memory_desc_wrapper &src_layer_d, |
615 | const memory_desc_wrapper &src_iter_d, |
616 | const memory_desc_wrapper &src_iter_c_d, |
617 | const memory_desc_wrapper &weights_layer_d, |
618 | const memory_desc_wrapper &weights_iter_d, |
619 | const memory_desc_wrapper &weights_projection_d, |
620 | const memory_desc_wrapper &dst_layer_d, |
621 | const memory_desc_wrapper &dst_iter_d, |
622 | const memory_desc_wrapper &dst_iter_c_d, |
623 | const memory_desc_wrapper &bias_d) { |
624 | rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training, |
625 | prop_kind::forward_inference); |
626 | rnn.is_training = utils::one_of( |
627 | rd.prop_kind, prop_kind::forward_training, prop_kind::backward); |
628 | rnn.is_lbr = utils::one_of(rd.cell_kind, dnnl_lbr_gru, dnnl_lbr_augru); |
629 | rnn.is_lstm_peephole = rd.cell_kind == dnnl_vanilla_lstm |
630 | && !memory_desc_wrapper(rd.weights_peephole_desc).is_zero(); |
631 | rnn.is_lstm_projection = rd.cell_kind == dnnl_vanilla_lstm |
632 | && !memory_desc_wrapper(rd.weights_projection_desc).is_zero(); |
633 | rnn.is_augru |
634 | = utils::one_of(rd.cell_kind, dnnl_lbr_augru, dnnl_vanilla_augru); |
635 | rnn.bias_dt = bias_d.is_zero() ? data_type::f32 : bias_d.data_type(); |
636 | rnn.src_iter_c_dt = src_iter_c_d.is_zero() ? data_type::f32 |
637 | : src_iter_c_d.data_type(); |
638 | rnn.dst_iter_c_dt = dst_iter_c_d.is_zero() ? data_type::f32 |
639 | : dst_iter_c_d.data_type(); |
640 | |
641 | rnn.cell_dt = data_traits<typename T::src_layer_t>::data_type; |
642 | switch (rd.direction) { |
643 | case dnnl_unidirectional_left2right: rnn.exec_dir = l2r; break; |
644 | case dnnl_unidirectional_right2left: rnn.exec_dir = r2l; break; |
645 | case dnnl_bidirectional_concat: rnn.exec_dir = bi_concat; break; |
646 | case dnnl_bidirectional_sum: rnn.exec_dir = bi_sum; break; |
647 | default: break; |
648 | } |
649 | |
650 | if (utils::everyone_is(data_type::f32, src_layer_d.data_type(), |
651 | dst_layer_d.data_type(), weights_layer_d.data_type())) |
652 | rnn.dt_conf = all_f32; |
653 | else if (utils::everyone_is(data_type::bf16, src_layer_d.data_type(), |
654 | dst_layer_d.data_type(), weights_layer_d.data_type())) { |
655 | if (!platform::has_data_type_support(data_type::bf16)) return false; |
656 | rnn.dt_conf = all_bf16; |
657 | } else if (dst_layer_d.data_type() == data_type::u8) { |
658 | if (IMPLICATION( |
659 | src_iter_d.md_, src_iter_d.data_type() == data_type::u8)) |
660 | rnn.dt_conf = u8u8u8u8; |
661 | else |
662 | rnn.dt_conf = f32u8f32u8; |
663 | } else if (dst_layer_d.data_type() == data_type::s8) { |
664 | if (IMPLICATION( |
665 | src_iter_d.md_, src_iter_d.data_type() == data_type::s8)) |
666 | rnn.dt_conf = s8s8s8s8; |
667 | else |
668 | rnn.dt_conf = f32s8f32s8; |
669 | |
670 | } else if (dst_layer_d.data_type() == data_type::f32) { |
671 | if (IMPLICATION( |
672 | src_iter_d.md_, src_iter_d.data_type() == data_type::u8)) |
673 | rnn.dt_conf = u8u8u8f32; |
674 | else if (IMPLICATION(src_iter_d.md_, |
675 | src_iter_d.data_type() == data_type::s8)) |
676 | rnn.dt_conf = s8s8s8f32; |
677 | else if (IMPLICATION(src_layer_d.md_, |
678 | src_layer_d.data_type() == data_type::s8)) |
679 | rnn.dt_conf = f32s8f32f32; |
680 | else |
681 | rnn.dt_conf = f32u8f32f32; |
682 | } |
683 | |
    // Set the members defining the problem sizes
685 | rnn.n_layer = weights_layer_d.dims()[0]; |
686 | rnn.n_iter = src_layer_d.dims()[0]; |
687 | rnn.n_dir = weights_layer_d.dims()[1]; |
688 | rnn.n_gates = weights_layer_d.dims()[3]; |
689 | rnn.n_states = rd.cell_kind == dnnl_vanilla_lstm ? 2 : 1; |
690 | rnn.n_bias = rnn.n_gates + rnn.is_lbr; |
691 | rnn.mb = src_layer_d.dims()[1]; |
692 | rnn.sic = weights_iter_d.dims()[2]; |
693 | rnn.slc = weights_layer_d.dims()[2]; |
694 | rnn.dhc = weights_layer_d.dims()[4]; |
695 | rnn.dlc = rnn.is_lstm_projection ? weights_projection_d.dims()[3] : rnn.dhc; |
696 | // All supported cells have dic == dlc |
697 | rnn.dic = rnn.dlc; |
698 | |
    // set members with user memories' leading dimensions
700 | // Assumption: weights datatype size is the same as state datatype size |
701 | assert(types::data_type_size(weights_layer_d.data_type()) |
702 | == types::data_type_size(src_layer_d.data_type())); |
703 | |
    // set workspace leading dimensions (and non-leading dimensions)
705 | |
706 | // the ws and scratch proj_ht need to match as we use them interchangeably |
707 | assert(IMPLICATION(rnn.is_lstm_projection, |
708 | sizeof(typename T::ht_t) == sizeof(typename T::dst_iter_t))); |
709 | rnn.proj_ht_nld = rnn.mb; |
710 | rnn.proj_ht_ld = get_good_ld(rnn.dhc, sizeof(typename T::ht_t)); |
711 | |
712 | rnn.ws_gates_nld = rnn.mb; |
713 | rnn.ws_gates_ld |
714 | = get_good_ld(rnn.dhc * rnn.n_gates, sizeof(typename T::gates_t)); |
715 | rnn.ws_ht_nld = rnn.proj_ht_nld; |
716 | rnn.ws_ht_ld = rnn.proj_ht_ld; |
717 | |
718 | rnn.ws_states_layer_nld = rnn.mb; |
719 | static_assert(std::is_same<typename T::src_layer_t, |
720 | typename T::src_iter_t>::value, |
721 | "src_layer_t and src_iter_t must be the same" ); |
722 | rnn.ws_states_layer_ld |
723 | = get_good_ld(nstl::max(rnn.sic, nstl::max(rnn.slc, rnn.dlc)), |
724 | sizeof(typename T::src_layer_t)); |
    // there is no need for a separate ws_states_iter for now as all
    // supported cells have dst_iter == dst_layer
727 | rnn.ws_states_iter_nld = rnn.ws_states_layer_nld; |
728 | rnn.ws_states_iter_ld = rnn.ws_states_layer_ld; |
729 | |
730 | // we do not need a good ld for iter_c as it is not involved in GEMM |
731 | rnn.ws_states_iter_c_nld = rnn.mb; |
732 | rnn.ws_states_iter_c_ld = rnn.dhc; |
733 | |
734 | // TODO: be more restrictive on the leading dimensions |
735 | rnn.ws_diff_states_layer_nld = rnn.mb; |
736 | rnn.ws_diff_states_layer_ld = get_good_ld( |
737 | nstl::max(nstl::max(rnn.slc, rnn.dic), nstl::max(rnn.sic, rnn.dhc)), |
738 | sizeof(typename T::gemm_acc_t)); |
739 | |
740 | rnn.ws_diff_states_iter_nld = rnn.mb; |
741 | rnn.ws_diff_states_iter_ld = get_good_ld( |
742 | nstl::max(nstl::max(rnn.slc, rnn.dic), nstl::max(rnn.sic, rnn.dhc)), |
743 | sizeof(typename T::gemm_acc_t)); |
744 | |
745 | rnn.ws_diff_states_iter_c_nld = rnn.mb; |
746 | rnn.ws_diff_states_iter_c_ld = rnn.dhc; |
747 | |
    // set scratch (non-)leading dimensions
    // scratch gates is used to store intermediate gates before the postgemm
    // operation; note: we also use it in lstmp as a temporary scratchpad
    // between the projection and the downconversion, hence the max with dlc
752 | rnn.scratch_gates_nld = rnn.mb; |
753 | rnn.scratch_gates_ld |
754 | = get_good_ld(nstl::max(rnn.dlc, rnn.n_gates * rnn.dhc), |
755 | sizeof(typename T::scratch_t)); |
756 | rnn.scratch_ht_nld = rnn.proj_ht_nld; |
757 | rnn.scratch_ht_ld = rnn.proj_ht_ld; |
758 | |
759 | rnn.scratch_diff_ht_nld = rnn.mb; |
760 | rnn.scratch_diff_ht_ld |
761 | = get_good_ld(rnn.dlc, sizeof(typename T::gemm_acc_t)); |
762 | |
    // Assumption: {src,dst}_layer has tnc layout, {src,dst}_iter has ldnc.
764 | rnn.src_layer_ld_ = src_layer_d.blocking_desc().strides[1]; |
765 | rnn.dst_layer_ld_ = dst_layer_d.blocking_desc().strides[1]; |
766 | rnn.src_iter_ld_ = types::is_zero_md(src_iter_d.md_) |
767 | ? 0 |
768 | : src_iter_d.blocking_desc().strides[2]; |
769 | rnn.dst_iter_ld_ = types::is_zero_md(dst_iter_d.md_) |
770 | ? 0 |
771 | : dst_iter_d.blocking_desc().strides[2]; |
772 | rnn.src_iter_c_ld_ = types::is_zero_md(src_iter_c_d.md_) |
773 | ? 0 |
774 | : src_iter_c_d.blocking_desc().strides[2]; |
775 | rnn.dst_iter_c_ld_ = types::is_zero_md(dst_iter_c_d.md_) |
776 | ? 0 |
777 | : dst_iter_c_d.blocking_desc().strides[2]; |
778 | |
779 | /* Set the correct number of weights parts */ |
780 | rnn.is_orig_gru = utils::one_of( |
781 | rd.cell_kind, alg_kind::vanilla_gru, alg_kind::vanilla_augru); |
782 | rnn.n_parts_weights_layer = 1; |
783 | rnn.parts_weights_layer[0] = rnn.n_gates; |
784 | rnn.parts_weights_layer[1] = 0; |
785 | |
786 | rnn.n_parts_weights_iter = rnn.is_orig_gru ? 2 : 1; |
787 | rnn.parts_weights_iter[0] = rnn.is_orig_gru ? 2 : rnn.n_gates; |
788 | rnn.parts_weights_iter[1] = rnn.is_orig_gru ? 1 : 0; |
789 | |
790 | rnn.n_parts_weights_projection = 1; |
791 | rnn.parts_weights_projection[0] = 1; |
792 | |
793 | rnn.n_parts_bias = 1; |
794 | rnn.parts_bias[0] = rnn.n_bias; |
795 | rnn.parts_bias[1] = 0; |
796 | |
    /* Decide which gemm implementation to use: packed/nonpacked jit/cblas
     * and whether to merge gemm across iterations */
799 | const bool is_f32 = rnn.dt_conf == all_f32, |
800 | is_bf16 = rnn.dt_conf == all_bf16; |
801 | const bool is_gru = utils::one_of(rd.cell_kind, alg_kind::vanilla_gru, |
802 | alg_kind::lbr_gru, alg_kind::vanilla_augru, alg_kind::lbr_augru); |
803 | const bool is_inference = !rnn.is_training; |
804 | |
805 | // To be able to merge the GEMM on the layer input when not |
806 | // copying, we need to have a trivial stride for the T dimension |
807 | rnn.src_layer_is_trivial_stride = src_layer_d.blocking_desc().strides[0] |
808 | == (rnn.src_layer_ld_ * rnn.mb); |
809 | const auto dst_layer_is_trivial_stride |
810 | = dst_layer_d.blocking_desc().strides[0] |
811 | == (rnn.dst_layer_ld_ * rnn.mb); |
812 | |
813 | rnn.merge_gemm_layer = (!rnn.is_brgemm) |
814 | ? ((rnn.is_fwd && rnn.src_layer_is_trivial_stride) |
815 | || ((rd.prop_kind == prop_kind::backward) |
816 | && dst_layer_is_trivial_stride)) |
817 | && (((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd) |
818 | || rnn.is_int8_conf()) |
819 | : false; |
820 | rnn.merge_gemm_iter = (!rnn.is_brgemm) |
821 | ? dst_layer_is_trivial_stride && !(rnn.is_fwd || is_gru) |
822 | : false; |
823 | rnn.force_nocopy = false; |
824 | #if DNNL_X64 |
825 | rnn.force_nocopy = x64::mayiuse(x64::avx) |
826 | && ((is_inference && (rnn.n_layer > 1 || rnn.mb < 100)) |
827 | || (rnn.is_training && rnn.dhc < 500)); |
828 | #endif |
829 | |
    /* Decide whether to copy bias */
831 | rnn.copy_bias = rnn.is_int8_conf(); |
832 | |
833 | rnn.use_layer_packed_gemm = !rnn.is_brgemm |
834 | ? utils::one_of(weights_layer_d.format_kind(), format_kind::any, |
835 | format_kind::rnn_packed) |
836 | && is_inference |
837 | && ((is_f32 && pack_sgemm_supported() && rnn.n_iter == 1) |
838 | || rnn.is_int8_conf() || is_bf16) |
839 | : false; |
840 | rnn.use_iter_packed_gemm = !rnn.is_brgemm |
841 | ? utils::one_of(weights_iter_d.format_kind(), format_kind::any, |
842 | format_kind::rnn_packed) |
843 | && is_inference |
844 | && ((is_f32 && pack_sgemm_supported() && rnn.mb >= 16) |
845 | || rnn.is_int8_conf() || is_bf16) |
846 | : false; |
847 | rnn.use_projection_packed_gemm = !rnn.is_brgemm |
848 | ? utils::one_of(weights_projection_d.format_kind(), |
849 | format_kind::any, format_kind::rnn_packed) |
850 | && is_inference |
851 | && ((is_f32 && pack_sgemm_supported() && rnn.n_iter == 1) |
852 | || rnn.is_int8_conf() || is_bf16) |
853 | : false; |
854 | |
855 | #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL |
856 | // XXX: Threadpool runtime may use different number of threads at execute |
857 | // and create stages. GEMM packed API is not aware of number of threads as |
858 | // of now. In order to synchronize all layers, GEMM pack API should be |
859 | // modified to accept number of threads instead of taking it from |
860 | // `dnnl_get_max_threads()`, and rnn_packed_desc_t should be updated with |
861 | // `nthr` member to pass this information between different parts of packed |
862 | // API, since `get_size` call happens on RNN side, while packing happens |
863 | // on reorder side. Consider enabling later. |
    // `test_iface_runtime_attr` was disabled for RNN with threadpool since
    // this is the only working approach for int8 computations in RNN for now.
866 | // Consider enabling it once resolved. |
867 | rnn.use_layer_packed_gemm = false; |
868 | rnn.use_iter_packed_gemm = false; |
869 | rnn.use_projection_packed_gemm = false; |
870 | #endif |
871 | |
872 | /* Set packed gemm sizes */ |
873 | /* TODO: investigate the benefit of mixing packed and non-packed weights parts */ |
874 | const auto set_pack_sizes |
875 | = [&](bool merge, bool &do_pack, size_t &weights_pack_size, |
876 | int &n_parts, int *parts, size_t *parts_pack_size, |
877 | size_t &comp_offset, int ic, int oc, int weights_oc, |
878 | dim_t data_ld) -> bool { |
879 | bool pack = true; |
880 | weights_pack_size = 0; |
881 | for (int p = 0; p < n_parts; p++) { |
882 | const dim_t m_p = rnn.is_fwd ? (parts[p] * oc) : ic; |
883 | const dim_t k_p = rnn.is_fwd ? ic : (parts[p] * oc); |
884 | const dim_t n_p = merge ? rnn.mb * rnn.n_iter : rnn.mb; |
885 | bool pack_part = true; |
886 | |
887 | dnnl_status_t st = dnnl_success; |
888 | switch (rnn.dt_conf) { |
889 | case all_f32: |
890 | st = sgemm_pack_get_size("A" , "N" , "N" , &m_p, &n_p, &k_p, |
891 | &m_p, &data_ld, &parts_pack_size[p], &pack_part); |
892 | break; |
893 | case s8s8s8f32: |
894 | case f32s8f32f32: |
895 | case s8s8s8s8: |
896 | case f32s8f32s8: |
897 | st = gemm_s8s8s32_pack_get_size("A" , "N" , "N" , &m_p, &n_p, |
898 | &k_p, &m_p, &data_ld, &parts_pack_size[p], |
899 | &pack_part); |
900 | break; |
901 | case u8u8u8f32: |
902 | case f32u8f32f32: |
903 | case u8u8u8u8: |
904 | case f32u8f32u8: |
905 | st = gemm_s8u8s32_pack_get_size("A" , "N" , "N" , &m_p, &n_p, |
906 | &k_p, &m_p, &data_ld, &parts_pack_size[p], |
907 | &pack_part); |
908 | break; |
909 | case all_bf16: |
910 | st = gemm_bf16bf16f32_pack_get_size("A" , "N" , "N" , &m_p, |
911 | &n_p, &k_p, &m_p, &data_ld, &parts_pack_size[p], |
912 | &pack_part); |
913 | break; |
914 | default: assert(!"Unsupported configuration" ); |
915 | } |
916 | if (st != dnnl_success) return false; |
917 | |
918 | pack = pack && pack_part; |
919 | weights_pack_size += rnn.n_layer * rnn.n_dir * parts_pack_size[p]; |
920 | } |
921 | |
922 | // NOTE: pack is updated only for f32. We force pack for int8 |
923 | do_pack = (rnn.dt_conf == all_f32) ? pack : true; |
924 | comp_offset = weights_pack_size; |
925 | const bool need_compensation = rnn.is_int8_conf(); |
926 | weights_pack_size += (need_compensation ? rnn.n_layer * rnn.n_dir : 0) |
927 | * weights_oc * sizeof(float); |
928 | |
929 | return true; |
930 | }; |
931 | // TODO: the activation leading dimension can vary for first layer/iteration |
932 | if (rnn.use_layer_packed_gemm) { |
933 | bool ok = set_pack_sizes(rnn.merge_gemm_layer, |
934 | rnn.use_layer_packed_gemm, rnn.weights_layer_pack_size, |
935 | rnn.n_parts_weights_layer, rnn.parts_weights_layer, |
936 | rnn.part_weights_layer_pack_size, rnn.weights_layer_comp_offset, |
937 | rnn.slc, rnn.dhc, rnn.n_gates * rnn.dhc, |
938 | rnn.ws_states_layer_ld); |
939 | if (!ok) return false; |
940 | } |
941 | |
942 | if (rnn.use_iter_packed_gemm) { |
943 | bool ok = set_pack_sizes(rnn.merge_gemm_iter, rnn.use_iter_packed_gemm, |
944 | rnn.weights_iter_pack_size, rnn.n_parts_weights_iter, |
945 | rnn.parts_weights_iter, rnn.part_weights_iter_pack_size, |
946 | rnn.weights_iter_comp_offset, rnn.sic, rnn.dhc, |
947 | rnn.n_gates * rnn.dhc, rnn.ws_states_iter_ld); |
948 | if (!ok) return false; |
949 | } |
950 | |
951 | if (rnn.use_projection_packed_gemm) { |
952 | bool ok = set_pack_sizes(false, rnn.use_projection_packed_gemm, |
953 | rnn.weights_projection_pack_size, |
954 | rnn.n_parts_weights_projection, rnn.parts_weights_projection, |
955 | rnn.part_weights_projection_pack_size, |
956 | rnn.weights_projection_comp_offset, rnn.dhc, rnn.dic, rnn.dic, |
957 | rnn.scratch_ht_ld); |
958 | if (!ok) return false; |
959 | } |
960 | |
961 | return true; |
962 | } |
963 | |
964 | template <typename T> |
965 | void set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd, |
966 | const memory_desc_wrapper &weights_layer_d, |
967 | const memory_desc_wrapper &weights_iter_d, |
968 | const memory_desc_wrapper &weights_projection_d, |
969 | const memory_desc_wrapper &diff_weights_layer_d, |
970 | const memory_desc_wrapper &diff_weights_iter_d, |
971 | const memory_desc_wrapper &diff_weights_projection_d) { |
972 | |
973 | // Set leading dimensions for input weights arrays depending on input format |
974 | const auto set_dims |
975 | = [&](const memory_desc_wrapper &md, int &ld, int &nld) { |
976 | ld = 0; |
977 | nld = 0; |
978 | if (md.is_blocking_desc()) { |
979 | if (is_ldigo(md)) { |
980 | ld = (int)md.blocking_desc().strides[2]; |
981 | nld = md.dims()[2]; |
982 | } else if (is_ldgoi(md)) { |
983 | ld = (int)md.blocking_desc().strides[4]; |
984 | nld = md.dims()[3] * md.dims()[4]; |
985 | } else if (is_ldoi(md)) { |
986 | ld = (int)md.blocking_desc().strides[3]; |
987 | nld = md.dims()[3]; |
988 | } else if (is_ldio(md)) { |
989 | ld = (int)md.blocking_desc().strides[2]; |
990 | nld = md.dims()[2]; |
991 | } else |
992 | assert(!"unsupported weights format" ); |
993 | } |
994 | }; |
995 | set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld); |
996 | set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld); |
997 | set_dims(weights_projection_d, rnn.weights_projection_ld, |
998 | rnn.weights_projection_nld); |
999 | if (!rnn.is_fwd) { |
1000 | set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld, |
1001 | rnn.diff_weights_layer_nld); |
1002 | set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld, |
1003 | rnn.diff_weights_iter_nld); |
1004 | set_dims(diff_weights_projection_d, rnn.diff_weights_projection_ld, |
1005 | rnn.diff_weights_projection_nld); |
1006 | } |
1007 | |
1008 | assert(weights_layer_d.data_type() == weights_iter_d.data_type()); |
1009 | assert(IMPLICATION(diff_weights_layer_d.ndims() != 0, |
1010 | (diff_weights_layer_d.data_type() |
1011 | == diff_weights_iter_d.data_type()))); |
1012 | |
    /* Set workspace sizes to store:
     * states to compute a pass
     * diff states to compute bwd pass (training only)
     * intermediate results from the gates
     */
1018 | |
1019 | assert(sizeof(typename T::src_layer_t) == sizeof(typename T::dst_layer_t)); |
1020 | assert(sizeof(typename T::src_iter_t) == sizeof(typename T::dst_iter_t)); |
1021 | } |
1022 | |
1023 | template <typename T> |
1024 | void set_workspace_sizes(rnn_conf_t &rnn, const rnn_desc_t &rd) { |
1025 | rnn.use_workspace = rnn.is_training; |
    // TODO: for inference, we can make ws_states_* smaller, though that
    // would depend on the grid execution
1028 | rnn.ws_states_layer_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir |
1029 | * (rnn.n_iter + 1) * rnn.mb * rnn.ws_states_layer_ld |
1030 | * sizeof(typename T::src_layer_t); |
1031 | rnn.ws_states_iter_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir |
1032 | * (rnn.n_iter + 1) * rnn.mb * rnn.ws_states_iter_ld |
1033 | * sizeof(typename T::src_iter_t); |
1034 | bool is_lstm = rd.cell_kind == dnnl_vanilla_lstm; |
1035 | rnn.ws_states_iter_c_size = is_lstm ? (size_t)(rnn.n_layer + 1) * rnn.n_dir |
1036 | * (rnn.n_iter + 1) * rnn.mb * rnn.ws_states_iter_c_ld |
1037 | * types::data_type_size(rnn.src_iter_c_dt) |
1038 | : 0; |
1039 | |
1040 | rnn.ws_diff_states_layer_size = rnn.is_training |
1041 | ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb |
1042 | * rnn.ws_diff_states_layer_ld |
1043 | * sizeof(typename T::gemm_acc_t) |
1044 | : (size_t)0; |
1045 | rnn.ws_diff_states_iter_size = rnn.is_training |
1046 | ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb |
1047 | * rnn.ws_diff_states_iter_ld |
1048 | * sizeof(typename T::gemm_acc_t) |
1049 | : (size_t)0; |
1050 | rnn.ws_diff_states_iter_c_size = rnn.is_training && is_lstm |
1051 | ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb |
1052 | * rnn.ws_diff_states_iter_c_ld |
1053 | * sizeof(typename T::gemm_acc_t) |
1054 | : (size_t)0; |
1055 | |
1056 | rnn.ws_gates_size = rnn.is_training |
1057 | ? (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.ws_gates_nld |
1058 | * rnn.ws_gates_ld * sizeof(typename T::gates_t) |
1059 | : (size_t)0; |
1060 | rnn.ws_ht_size = rnn.is_training |
1061 | ? (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.ws_ht_nld |
1062 | * rnn.ws_ht_ld * sizeof(typename T::dst_iter_t) |
1063 | : (size_t)0; |
1064 | rnn.n_iter_scratch_gates |
1065 | = (rnn.merge_gemm_layer || rnn.merge_gemm_iter) ? rnn.n_iter : 1; |
1066 | rnn.scratch_gates_size = sizeof(typename T::scratch_t) |
1067 | * rnn.n_iter_scratch_gates * rnn.scratch_gates_nld |
1068 | * rnn.scratch_gates_ld; |
1069 | rnn.scratch_ht_size |
1070 | = sizeof(typename T::ht_t) * rnn.scratch_ht_nld * rnn.scratch_ht_ld; |
1071 | rnn.scratch_diff_ht_size = rnn.is_training ? sizeof(typename T::gemm_acc_t) |
1072 | * rnn.scratch_diff_ht_nld * rnn.scratch_diff_ht_ld |
1073 | : (size_t)0; |
1074 | |
1075 | /* set other sizes */ |
1076 | /// scratchpad buffer for each cell to hold intermediate data in gru/lbr_gru |
1077 | rnn.scratch_cell_size = rnn.is_lbr |
1078 | ? (size_t)rnn.scratch_gates_nld * rnn.scratch_gates_ld |
1079 | * sizeof(typename T::gemm_acc_t) |
1080 | : (utils::one_of(rd.cell_kind, alg_kind::vanilla_gru, |
1081 | alg_kind::vanilla_augru) |
1082 | ? (size_t)rnn.ws_states_layer_nld |
1083 | * rnn.ws_states_layer_ld |
1084 | * sizeof(typename T::gemm_acc_t) |
1085 | : 0); |
1086 | /// workspace needed for lbr GRU |
1087 | rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dhc |
1088 | * sizeof(typename T::gemm_acc_t); |
1089 | rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer |
1090 | * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * sizeof(float); |
1091 | /// bias ws needed to add compensation in int8 |
1092 | rnn.ws_bias_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dhc |
1093 | * types::data_type_size(rnn.bias_dt); |
1094 | } |
1095 | |
1096 | void set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset, |
1097 | size_t &ws_ht_offset, size_t &ws_state_layer_offset, |
1098 | size_t &ws_states_iter_offset, size_t &ws_states_iter_c_offset, |
1099 | size_t &ws_diff_states_layer_offset, size_t &ws_diff_states_iter_offset, |
1100 | size_t &ws_diff_states_iter_c_offset, size_t &ws_grid_comp_offset, |
1101 | size_t &ws_bias_offset, size_t &scratch_gates_offset, |
1102 | size_t &scratch_ht_offset, size_t &scratch_diff_ht_offset, |
1103 | size_t &scratch_cell_offset, size_t &scratchpad_size, |
1104 | size_t &workspace_size); |
1105 | |
1106 | void get_scratchpad_and_workspace_sizes( |
1107 | const rnn_conf_t &rnn, size_t &scratchpad_size, size_t &workspace_size); |
1108 | status_t set_expected_desc(rnn_conf_t &rnn, memory_desc_t &weights_md, |
1109 | weights_type_t weights_type); |
1110 | status_t set_good_strides(memory_desc_t &weights_md, format_tag_t tag); |
1111 | |
1112 | using byte = unsigned char; |
1113 | template <size_t Tdims> |
1114 | struct raw_array_offset_calculator_t { |
1115 | template <typename... Targs> |
1116 | raw_array_offset_calculator_t( |
1117 | const byte *base, const dim_t dt_size, Targs... Fargs) |
1118 | : base_ptr_(base), dt_size_(dt_size), dims_ {Fargs...} {} |
1119 | |
1120 | template <typename... Targs> |
1121 | raw_array_offset_calculator_t(std::nullptr_t, Targs... Fargs) = delete; |
1122 | |
1123 | template <typename... Targs> |
1124 | inline const void *operator()(Targs... Fargs) const { |
1125 | assert(static_cast<bool>(base_ptr_)); |
1126 | return base_ptr_ + (offset(1, Fargs...) * dt_size_); |
1127 | } |
1128 | |
1129 | private: |
1130 | template <typename... Targs> |
1131 | inline size_t offset(size_t const dimension, size_t element) const { |
1132 | return element; |
1133 | } |
1134 | template <typename... Targs> |
1135 | inline size_t offset( |
1136 | size_t const dimension, size_t theta, size_t element) const { |
1137 | return element + (dims_[dimension] * theta); |
1138 | } |
1139 | |
1140 | template <typename... Targs> |
1141 | inline size_t offset(size_t const dimension, size_t theta, size_t element, |
1142 | Targs... Fargs) const { |
1143 | const size_t t_prime = element + (dims_[dimension] * theta); |
1144 | return offset(dimension + 1, t_prime, Fargs...); |
1145 | } |
1146 | |
1147 | const byte *const base_ptr_; |
1148 | const dim_t dt_size_; |
1149 | const int dims_[Tdims]; |
1150 | }; |
1151 | |
1152 | template <typename... Targs> |
1153 | raw_array_offset_calculator_t<sizeof...(Targs)> make_raw_aoc( |
1154 | const void *base, const dim_t dt_size, Targs... Fargs) { |
1155 | return raw_array_offset_calculator_t<sizeof...(Targs)>( |
1156 | static_cast<const byte *>(base), dt_size, |
1157 | std::forward<Targs>(Fargs)...); |
1158 | } |
1159 | |
1160 | template <typename T> |
1161 | struct ws_gates_aoc { |
1162 | ws_gates_aoc(const rnn_conf_t &rnn, T *data) |
1163 | : gates_(data, rnn.ws_gates_nld, rnn.ws_gates_ld), DHC_(rnn.dhc) {} |
1164 | T &operator()(int batch, int gate, int dhc) const { |
1165 | return gates_(batch, gate * DHC_ + dhc); |
1166 | } |
1167 | |
1168 | private: |
1169 | const dnnl::impl::utils::array_offset_calculator<T, 2> gates_; |
1170 | const int DHC_; |
1171 | }; |
1172 | using ws_gates_aoc_t = ws_gates_aoc<float>; |
1173 | using ws_gates_aoc_s32_t = ws_gates_aoc<int32_t>; |
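
// The gates accessors flatten the (gate, dhc) pair onto the leading dimension
// of a 2D array: element (batch, gate, dhc) lives at offset
// batch * ws_gates_ld + gate * rnn.dhc + dhc. Illustrative use (the pointer
// name is hypothetical):
//
//     ws_gates_aoc_t ws_gates(rnn, ws_gates_ptr); // ws_gates_ptr: float *
//     float &g = ws_gates(b, 2, d); // gate 2 of batch entry b, channel d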
1174 | |
1175 | template <typename T> |
1176 | struct ws_ht_aoc { |
1177 | ws_ht_aoc(const rnn_conf_t &rnn, T *data) |
1178 | : ht_(data, rnn.ws_ht_nld, rnn.ws_ht_ld) {} |
1179 | T &operator()(int batch, int dhc) const { return ht_(batch, dhc); } |
1180 | |
1181 | private: |
1182 | const dnnl::impl::utils::array_offset_calculator<T, 2> ht_; |
1183 | }; |
1184 | |
1185 | template <typename T> |
1186 | struct scratch_gates_aoc { |
1187 | scratch_gates_aoc(const rnn_conf_t &rnn, T *data) |
1188 | : gates_(data, rnn.scratch_gates_nld, rnn.scratch_gates_ld) |
1189 | , DHC_(rnn.dhc) {} |
1190 | T &operator()(int batch, int gate, int dhc) const { |
1191 | return gates_(batch, gate * DHC_ + dhc); |
1192 | } |
1193 | |
1194 | private: |
1195 | const dnnl::impl::utils::array_offset_calculator<T, 2> gates_; |
1196 | const int DHC_; |
1197 | }; |
1198 | using scratch_gates_aoc_t = scratch_gates_aoc<float>; |
1199 | using scratch_gates_aoc_s32_t = scratch_gates_aoc<int32_t>; |
1200 | |
1201 | template <typename T> |
1202 | struct scratch_ht_aoc { |
1203 | scratch_ht_aoc(const rnn_conf_t &rnn, T *data) |
1204 | : ht_(data, rnn.scratch_ht_nld, rnn.scratch_ht_ld) {} |
1205 | T &operator()(int batch, int dhc) const { return ht_(batch, dhc); } |
1206 | |
1207 | private: |
1208 | const dnnl::impl::utils::array_offset_calculator<T, 2> ht_; |
1209 | }; |
1210 | using scratch_ht_aoc_t = scratch_ht_aoc<float>; |
1211 | using scratch_ht_aoc_s32_t = scratch_ht_aoc<int32_t>; |
1212 | |
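// LSTM peephole weights hold one dhc-long vector per affected gate; the fixed
// "3" below corresponds to the input, forget, and output gates (the cell
// candidate gate has no peephole connection).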
1213 | template <typename T> |
1214 | struct weights_peephole_aoc_t { |
1215 | weights_peephole_aoc_t(const rnn_conf_t &rnn, T *data) |
1216 | : weights_peephole_(data, 3, rnn.dhc) {} |
1217 | T &operator()(int g, int dhc) const { return weights_peephole_(g, dhc); } |
1218 | |
1219 | private: |
1220 | const utils::array_offset_calculator<T, 2> weights_peephole_; |
1221 | }; |
1222 | |
1223 | float to_float(const void *data, const data_type_t dt); |
1224 | |
1225 | struct bias_linear_exec_aoc_t { |
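    // Keeps either an f32 or a bf16 offset calculator in the union below.
    // array_offset_calculator has no default constructor, so the active
    // member is created with placement new and destroyed explicitly in the
    // destructor, keyed on bias_dt_.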
1226 | bias_linear_exec_aoc_t(const rnn_conf_t &rnn, void **bias) |
1227 | : bias_dt_(rnn.bias_dt), bias_present_(static_cast<bool>(bias)) { |
1228 | |
1229 | if (bias_dt_ == data_type::f32) |
1230 | new (std::addressof(bias_f32_aoc_)) |
1231 | utils::array_offset_calculator<float *, 3>( |
1232 | reinterpret_cast<float **>(bias), rnn.n_layer, |
1233 | rnn.n_dir, rnn.n_parts_bias); |
1234 | else |
1235 | new (std::addressof(bias_bf16_aoc_)) |
1236 | utils::array_offset_calculator<bfloat16_t *, 3>( |
1237 | reinterpret_cast<bfloat16_t **>(bias), rnn.n_layer, |
1238 | rnn.n_dir, rnn.n_parts_bias); |
1239 | } |
1240 | |
1241 | void **operator()(int layer, int dir) const { |
1242 | if (bias_present_) { |
1243 | if (bias_dt_ == data_type::f32) |
1244 | return reinterpret_cast<void **>( |
1245 | &bias_f32_aoc_.operator()(layer, dir, 0)); |
1246 | else if (bias_dt_ == data_type::bf16) |
1247 | return reinterpret_cast<void **>( |
1248 | &bias_bf16_aoc_.operator()(layer, dir, 0)); |
1249 | } |
1250 | |
1251 | return nullptr; |
1252 | } |
1253 | |
1254 | ~bias_linear_exec_aoc_t() { |
1255 | if (bias_dt_ == data_type::f32) |
1256 | bias_f32_aoc_.~array_offset_calculator<float *, 3>(); |
1257 | else |
1258 | bias_bf16_aoc_.~array_offset_calculator<bfloat16_t *, 3>(); |
1259 | } |
1260 | |
1261 | DNNL_DISALLOW_COPY_AND_ASSIGN(bias_linear_exec_aoc_t); |
1262 | bias_linear_exec_aoc_t(bias_linear_exec_aoc_t &&) = delete; |
1263 | bias_linear_exec_aoc_t &operator=(bias_linear_exec_aoc_t &&) = delete; |
1264 | |
1265 | private: |
1266 | data_type_t bias_dt_; |
1267 | bool bias_present_; |
1268 | union { |
1269 | utils::array_offset_calculator<float *, 3> bias_f32_aoc_; |
1270 | utils::array_offset_calculator<bfloat16_t *, 3> bias_bf16_aoc_; |
1271 | }; |
1272 | }; |
1273 | |
1274 | template <typename T> |
1275 | struct ws_states_layer_aoc { |
1276 | ws_states_layer_aoc(const rnn_conf_t &rnn, T *data, int leading_dim) |
1277 | : state_(data, rnn.ws_states_layer_nld, leading_dim) {} |
1278 | ws_states_layer_aoc(const rnn_conf_t &rnn, T *data) |
1279 | : state_(data, rnn.ws_states_layer_nld, rnn.ws_states_layer_ld) {} |
1280 | T &operator()(int batch, int dhc) const { return state_(batch, dhc); } |
1281 | |
1282 | private: |
1283 | const dnnl::impl::utils::array_offset_calculator<T, 2> state_; |
1284 | }; |
1285 | |
1286 | template <typename T> |
1287 | struct ws_states_iter_aoc { |
1288 | ws_states_iter_aoc(const rnn_conf_t &rnn, T *data, int leading_dim) |
1289 | : state_(data, rnn.ws_states_iter_nld, leading_dim) {} |
1290 | ws_states_iter_aoc(const rnn_conf_t &rnn, T *data) |
1291 | : state_(data, rnn.ws_states_iter_nld, rnn.ws_states_iter_ld) {} |
1292 | T &operator()(int batch, int dhc) const { return state_(batch, dhc); } |
1293 | |
1294 | private: |
1295 | const dnnl::impl::utils::array_offset_calculator<T, 2> state_; |
1296 | }; |
1297 | |
1298 | template <typename T> |
1299 | struct augru_attention_aoc { |
1300 | augru_attention_aoc(const rnn_conf_t &rnn, T *data) |
1301 | : state_(data, rnn.mb) {} |
1302 | T &operator()(int batch) const { return state_(batch); } |
1303 | |
1304 | private: |
1305 | const dnnl::impl::utils::array_offset_calculator<T, 1> state_; |
1306 | }; |
1307 | |
1308 | template <typename T> |
1309 | struct ws_diff_states_layer_aoc { |
1310 | ws_diff_states_layer_aoc(const rnn_conf_t &rnn, T *data) |
1311 | : diff_states_layer_(data, rnn.ws_diff_states_layer_nld, |
1312 | rnn.ws_diff_states_layer_ld) {} |
1313 | T &operator()(int batch, int dhc) const { |
1314 | return diff_states_layer_(batch, dhc); |
1315 | } |
1316 | |
1317 | private: |
1318 | const dnnl::impl::utils::array_offset_calculator<T, 2> diff_states_layer_; |
1319 | }; |
1320 | |
1321 | template <typename T> |
1322 | struct ws_diff_states_iter_aoc { |
1323 | ws_diff_states_iter_aoc(const rnn_conf_t &rnn, T *data) |
1324 | : diff_states_iter_(data, rnn.ws_diff_states_iter_nld, |
1325 | rnn.ws_diff_states_iter_ld) {} |
1326 | T &operator()(int batch, int dhc) const { |
1327 | return diff_states_iter_(batch, dhc); |
1328 | } |
1329 | |
1330 | private: |
1331 | const dnnl::impl::utils::array_offset_calculator<T, 2> diff_states_iter_; |
1332 | }; |
1333 | |
1334 | template <typename T> |
1335 | struct ws_diff_states_iter_c_aoc { |
1336 | ws_diff_states_iter_c_aoc(const rnn_conf_t &rnn, T *data) |
1337 | : diff_states_iter_c_(data, rnn.ws_diff_states_iter_c_nld, |
1338 | rnn.ws_diff_states_iter_c_ld) {} |
1339 | T &operator()(int batch, int dhc) const { |
1340 | return diff_states_iter_c_(batch, dhc); |
1341 | } |
1342 | |
1343 | private: |
1344 | const dnnl::impl::utils::array_offset_calculator<T, 2> diff_states_iter_c_; |
1345 | }; |
1346 | |
1347 | struct ws_diff_w_iter_aoc_t { |
1348 | ws_diff_w_iter_aoc_t(const rnn_conf_t &rnn, float *data) |
1349 | : diff_weights_iter_( |
1350 | data, rnn.diff_weights_iter_nld, rnn.diff_weights_iter_ld) |
1351 | , DHC_(rnn.dhc) {} |
1352 | float &operator()(int sic, int gate, int dhc) const { |
1353 | return diff_weights_iter_(sic, gate * DHC_ + dhc); |
1354 | } |
1355 | |
1356 | private: |
1357 | const dnnl::impl::utils::array_offset_calculator<float, 2> |
1358 | diff_weights_iter_; |
1359 | const int DHC_; |
1360 | }; |
1361 | |
1362 | const void *inc_ptr(const void *data, data_type_t data_type, int offset); |
1363 | void *inc_ptr(void *data, data_type_t data_type, int offset); |
1364 | |
1365 | } // namespace rnn_utils |
1366 | } // namespace cpu |
1367 | } // namespace impl |
1368 | } // namespace dnnl |
1369 | #endif |
1370 | |