1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16#include <tuple>
17#include <utility>
18#include "common/dnnl_thread.hpp"
19#include "cpu/rnn/rnn_utils.hpp"
20#include "cpu/x64/rnn/rnn_brgemm_utils.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace cpu {
25namespace x64 {
26namespace rnn_brgemm_utils {
27
28namespace {
29
// Forward declarations of the file-local blocking/ISA selection heuristics
// defined below. K1/K2 are the two reduction dimensions (layer/iter GEMMs),
// M/N are the brgemm problem dimensions, As/Bs/Cs are approximate matrix
// footprints in bytes used by the L2-cache heuristics.
x64::cpu_isa_t brgemm_calc_isa(dim_t K1, dim_t K2, bool is_int8, bool is_bf16);
std::pair<dim_t, dim_t> brgemm_calc_k_block(dim_t K1, dim_t K2, dim_t M,
        dim_t n_block, alg_kind_t cell_kind, dim_t src_layer_type_size,
        dim_t As, dim_t Bs, dim_t Cs, dim_t l2_cache_size, x64::cpu_isa_t isa,
        bool is_int8, bool is_bf16);
std::pair<dim_t, dim_t> brgemm_calc_k_block_amx(
        dim_t K1, dim_t K2, bool is_int8);
std::pair<dim_t, dim_t> brgemm_calc_k_block_vanilla_rnn(dim_t K1, dim_t K2,
        dim_t M, dim_t n_block, dim_t src_layer_type_size, dim_t As, dim_t Bs,
        dim_t Cs, dim_t l2_cache_size, bool is_bf16);

dim_t brgemm_calc_m_block(alg_kind_t cell_kind, prop_kind_t aprop, dim_t nthr,
        dim_t M, dim_t N_blocks, bool is_f32, bool is_int8_amx,
        bool is_bf16_amx, float work_by_N, dim_t As, dim_t Bs, dim_t Cs,
        dim_t l2_cache_size);
dim_t brgemm_calc_m_block_vanilla_rnn(dim_t nthr, dim_t M, dim_t N_blocks,
        bool is_int8_amx, bool is_bf16_amx, float work_by_N, dim_t As, dim_t Bs,
        dim_t Cs, dim_t l2_cache_size);
dim_t brgemm_calc_m_block_lstm(dim_t nthr, dim_t M, dim_t N_blocks, bool is_f32,
        bool is_int8_amx, bool is_bf16_amx, float work_by_N, dim_t As, dim_t Cs,
        dim_t l2_cache_size);
dim_t adjust_m_block_lstm(dim_t nthr, dim_t M, dim_t N_blocks, bool is_int8_amx,
        bool is_bf16_amx);
53
54x64::cpu_isa_t brgemm_calc_isa(dim_t K1, dim_t K2, bool is_int8, bool is_bf16) {
55 const bool is_amx_int8 = is_int8 && x64::mayiuse(x64::avx512_core_amx);
56 const bool is_amx_bf16 = is_bf16 && x64::mayiuse(x64::avx512_core_amx);
57
58 if (is_amx_int8 || is_amx_bf16) {
59 const dim_t padding = (is_int8 ? 4 : (is_bf16 ? 2 : 1));
60 const auto result = brgemm_calc_k_block_amx(K1, K2, is_int8);
61 const auto k1_block_amx = result.first;
62 const auto k2_block_amx = result.second;
63 const auto k1_block_tail = K1 % k1_block_amx;
64 const auto k2_block_tail = K2 % k2_block_amx;
65 const bool amx_block_invalid = k1_block_tail % padding
66 || k2_block_tail % padding || k1_block_amx % padding
67 || k2_block_amx % padding;
68
69 if (!amx_block_invalid) return x64::avx512_core_amx;
70 }
71
72 if (is_int8) {
73 return x64::avx512_core_vnni;
74 } else if (is_bf16) {
75 return x64::avx512_core_bf16;
76 }
77
78 return x64::isa_undef;
79}
80
81std::pair<dim_t, dim_t> brgemm_calc_k_block(dim_t K1, dim_t K2, dim_t M,
82 dim_t n_block, alg_kind_t cell_kind, dim_t src_layer_type_size,
83 dim_t As, dim_t Bs, dim_t Cs, dim_t l2_cache_size, x64::cpu_isa_t isa,
84 bool is_int8, bool is_bf16) {
85 const bool is_amx_int8 = is_int8 && isa == x64::avx512_core_amx;
86 const bool is_amx_bf16 = is_bf16 && isa == x64::avx512_core_amx;
87
88 if (is_amx_int8 || is_amx_bf16)
89 return brgemm_calc_k_block_amx(K1, K2, is_int8);
90 else if (cell_kind == alg_kind::vanilla_rnn)
91 return brgemm_calc_k_block_vanilla_rnn(K1, K2, M, n_block,
92 src_layer_type_size, As, Bs, Cs, l2_cache_size, is_bf16);
93
94 return std::make_pair(K1, K2);
95}
96
// Computes the K blocking for AMX kernels: each K block must fit in one tile
// row span (64 int8 elements or 32 bf16 elements per row).
std::pair<dim_t, dim_t> brgemm_calc_k_block_amx(
        dim_t K1, dim_t K2, bool is_int8) {
    const bool is_amx_int8 = is_int8 && x64::mayiuse(x64::avx512_core_amx);
    const dim_t max_row_width = is_amx_int8 ? 64 : 32;

    dim_t k1_block = nstl::min(K1, max_row_width);
    dim_t k2_block = nstl::min(K2, max_row_width);

    // NOTE(review): this condition is always true, since k1_block/k2_block
    // are computed as min(K*, max_row_width) and therefore can never exceed
    // K1/K2. As written, both blocks are unconditionally unified to the
    // smaller of the two. Possibly strict '<' ("block was clipped") was
    // intended — confirm intent before changing.
    if (k1_block <= K1 || k2_block <= K2) {
        const dim_t t_k_block = nstl::min(k1_block, k2_block);
        k2_block = k1_block = t_k_block;
    }

    return std::make_pair(k1_block, k2_block);
}
112
113std::pair<dim_t, dim_t> brgemm_calc_k_block_vanilla_rnn(dim_t K1, dim_t K2,
114 dim_t M, dim_t n_block, dim_t src_layer_type_size, dim_t As, dim_t Bs,
115 dim_t Cs, dim_t l2_cache_size, bool is_bf16) {
116
117 //Heuristics experimentally selected.
118 const bool should_adjust_by_l2 = static_cast<float>(As + Bs + Cs)
119 >= 0.25 * static_cast<float>(l2_cache_size);
120 dim_t k1_block = K1;
121 dim_t k2_block = K2;
122
123 if (should_adjust_by_l2) {
124 int block_size = (l2_cache_size * 0.25f)
125 / ((M + n_block) * src_layer_type_size);
126
127 if (is_bf16) {
128 // due to weights format ldgOI32o2i block_size should be even
129 block_size -= (block_size % 2);
130 block_size = nstl::max(block_size, 0);
131 }
132 if (block_size) {
133 k1_block = nstl::min(K1, static_cast<dim_t>(block_size));
134 k2_block = nstl::min(K2, static_cast<dim_t>(block_size));
135 }
136 }
137
138 return std::make_pair(k1_block, k2_block);
139}
140
141dim_t brgemm_calc_m_block(alg_kind_t cell_kind, prop_kind_t aprop, dim_t nthr,
142 dim_t M, dim_t N_blocks, bool is_f32, bool is_int8_amx,
143 bool is_bf16_amx, float work_by_N, dim_t As, dim_t Bs, dim_t Cs,
144 dim_t l2_cache_size) {
145 if (cell_kind == alg_kind::vanilla_rnn
146 || (cell_kind == alg_kind::vanilla_lstm
147 && aprop == prop_kind::backward))
148 return brgemm_calc_m_block_vanilla_rnn(nthr, M, N_blocks, is_int8_amx,
149 is_bf16_amx, work_by_N, As, Bs, Cs, l2_cache_size);
150 else
151 return brgemm_calc_m_block_lstm(nthr, M, N_blocks, is_f32, is_int8_amx,
152 is_bf16_amx, work_by_N, As, Cs, l2_cache_size);
153}
154
155dim_t brgemm_calc_m_block_vanilla_rnn(dim_t nthr, dim_t M, dim_t N_blocks,
156 bool is_int8_amx, bool is_bf16_amx, float work_by_N, dim_t As, dim_t Bs,
157 dim_t Cs, dim_t l2_cache_size) {
158
159 //Heuristics experimentally selected.
160 const float decimal_n_factor = work_by_N - std::floor(work_by_N);
161 static constexpr float thread_balance_threashold = 0.9;
162
163 dim_t m_block = M;
164
165 if (work_by_N < 1.0)
166 return adjust_m_block_lstm(nthr, M, N_blocks, is_int8_amx, is_bf16_amx);
167 else if (decimal_n_factor < thread_balance_threashold
168 && decimal_n_factor != 0.0f) {
169
170 const dim_t m_block_start = M / 2;
171 const dim_t m_block_end = 8;
172
173 float max_decimal_mn = 0.0;
174 dim_t best_candidate = 0.0;
175 bool found_best_solution = false;
176
177 for (dim_t m_block_it = m_block_start; m_block_it >= m_block_end;
178 m_block_it--) {
179 if (M % m_block_it == 0) {
180 const auto m_blocks = M / m_block_it;
181 const auto work_by_MN
182 = static_cast<float>(m_blocks * N_blocks) / nthr;
183
184 const float work_by_MN_decimal
185 = work_by_MN - std::floor(work_by_MN);
186
187 static constexpr float tolerance = 0.01;
188 if (work_by_MN_decimal > (max_decimal_mn + tolerance)) {
189 best_candidate = m_block_it;
190 max_decimal_mn = work_by_MN_decimal;
191 }
192
193 if (work_by_MN_decimal >= thread_balance_threashold
194 || work_by_MN_decimal == 0.0f) {
195 m_block = m_block_it;
196 found_best_solution = true;
197 break;
198 }
199 }
200 }
201
202 if (!found_best_solution) {
203 if ((decimal_n_factor < max_decimal_mn)
204 || (static_cast<float>(As)
205 > (0.5f * static_cast<float>(l2_cache_size)))) {
206 m_block = best_candidate;
207 }
208 }
209 }
210
211 return m_block;
212}
213
214dim_t brgemm_calc_m_block_lstm(dim_t nthr, dim_t M, dim_t N_blocks, bool is_f32,
215 bool is_int8_amx, bool is_bf16_amx, float work_by_N, dim_t As, dim_t Cs,
216 dim_t l2_cache_size) {
217 const bool adj_by_l2 = is_f32
218 ? true
219 : (static_cast<float>(As + Cs)
220 < 0.6 * static_cast<float>(l2_cache_size));
221
222 if (work_by_N > 2.0 || (work_by_N > 1.0 && adj_by_l2))
223 return M;
224 else
225 return adjust_m_block_lstm(nthr, M, N_blocks, is_int8_amx, is_bf16_amx);
226}
227
228dim_t adjust_m_block_lstm(dim_t nthr, dim_t M, dim_t N_blocks, bool is_int8_amx,
229 bool is_bf16_amx) {
230
231 const bool is_amx = is_int8_amx || is_bf16_amx;
232
233 const dim_t max_m_blocks = (is_amx ? 1 : 4) * utils::div_up(nthr, N_blocks);
234 const dim_t max_m_value = is_amx ? 64 : 24;
235 const dim_t max_M
236 = nstl::min(max_m_value, nstl::max((dim_t)1, M / max_m_blocks));
237 const dim_t min_M = 4;
238
239 dim_t m_block = 1;
240 for (dim_t m = max_M; m >= min_M; m--)
241 if (M % m == 0) {
242 m_block = m;
243 break;
244 }
245 if (m_block == 1) m_block = M;
246
247 return m_block;
248}
249
250x64::cpu_isa_t adjust_isa_by_m_block(
251 x64::cpu_isa_t current_isa, dim_t m_block, bool is_int8_amx) {
252 /*
253 * If we have m<4 TMUL and AVX512 vnni calculate the same number of
254 * operation per instruction but TMUL is 2x slower for int8 in terms of
255 * throughput.
256 */
257 if (is_int8_amx && m_block < 4) {
258 if (x64::mayiuse(x64::avx512_core_amx)) return x64::avx512_core_amx;
259 }
260
261 return current_isa;
262}
263
264} // namespace
265
// Books the scratchpad memory shared by forward and backward brgemm RNN
// implementations: per-thread accumulation buffers for AMX kernels, bf32
// weight-reorder buffers, and the brgemm batch-element array.
void rnn_brgemm_base_t::init_scratchpad(const cpu::rnn_utils::rnn_conf_t &rnn,
        memory_tracking::registrar_t &scratchpad, dim_t gemm_acc_type_size,
        dim_t gemm_acc_align) {

    using namespace memory_tracking::names;

    // AMX kernels accumulate into a per-thread buffer of one (m_block x
    // n_block) tile; with merged-layer computation the larger of the two
    // M blockings is used.
    if (rnn.is_cell_int8_amx() || rnn.is_cell_bf16_amx()) {
        const auto m_block = rnn.merge_gemm_layer
                ? nstl::max(rnn.m_block, rnn.mlayermerged_block)
                : rnn.m_block;
        size_t n_elements = m_block * rnn.n_block;
        scratchpad.book(key_brgemm_primitive_buffer, rnn.nthr * n_elements,
                gemm_acc_type_size, gemm_acc_align);
    }

    // bf32 execution keeps bf16 copies of the f32 weights (and attention)
    // in scratchpad, laid out in the blocked format the kernels expect.
    if (rnn.is_bf32()) {

        const dims_t wei_layer_dims
                = {rnn.n_layer, rnn.n_dir, rnn.n_gates, rnn.slc, rnn.dlc};
        const dims_t wei_iter_dims
                = {rnn.n_layer, rnn.n_dir, rnn.n_gates, rnn.sic, rnn.dic};

        memory_desc_t wei_layer_desc;
        const auto tag = rnn.n_block == 64 ? format_tag::ldgOI64o2i
                                           : format_tag::ldgOI32o2i;
        memory_desc_init_by_tag(
                wei_layer_desc, 5, wei_layer_dims, data_type::bf16, tag);

        memory_desc_t wei_iter_desc;
        memory_desc_init_by_tag(
                wei_iter_desc, 5, wei_iter_dims, data_type::bf16, tag);

        scratchpad.book(key_rnn_bf32_wei_layer_trans,
                memory_desc_wrapper(wei_layer_desc).size(), 64);

        scratchpad.book(key_rnn_bf32_wei_iter_trans,
                memory_desc_wrapper(wei_iter_desc).size(), 64);

        scratchpad.book(key_rnn_bf32_attention_trans,
                rnn.n_iter * rnn.mb * sizeof(bfloat16_t), 64);
    }

    // Batch-element array sized for the largest K blocking across layer,
    // iter and projection GEMMs (doubled when the fused iter+layer kernel
    // may be used), one stretch per thread.
    const int max_K_Block
            = nstl::max(rnn.KB1_blocks + 1,
                      nstl::max(rnn.KBproj_blocks + 1, rnn.KB2_blocks + 1))
            * (rnn.brgemm_fwd_iter_layer_fuse_possible ? 2 : 1);
    scratchpad.template book<x64::brgemm_batch_element_t>(
            key_brgemm_primitive_batch, max_K_Block * rnn.nthr);
}
315
316status_t rnn_brgemm_t<prop_kind::forward>::configure_brgemm(
317 cpu::rnn_utils::rnn_conf_t &rnn, alg_kind_t cell_kind,
318 dim_t src_layer_type_size, dim_t scratch_type_size) {
319 using namespace cpu::rnn_utils;
320
321 rnn.M = rnn.mb;
322 rnn.N = rnn.dhc;
323 rnn.K1 = rnn.slc;
324 rnn.K2 = rnn.sic;
325 const auto is_int8 = rnn.is_cell_dt_int8();
326 const auto is_bf16 = rnn.is_cell_dt_bf16();
327
328 const dim_t padding = (is_int8 ? 4 : (is_bf16 ? 2 : 1));
329 rnn.K1padded = utils::rnd_up(rnn.K1, padding);
330 rnn.K2padded = utils::rnd_up(rnn.K2, padding);
331
332 rnn.brgemm_isa = brgemm_calc_isa(rnn.K1, rnn.K2, is_int8, is_bf16);
333 const int bf32_reduction_dim_threshold = 128;
334 const bool is_shape_ok_for_bf32 = rnn.K1 >= bf32_reduction_dim_threshold
335 && rnn.K2 >= bf32_reduction_dim_threshold;
336 const bool is_bf32 = is_bf16 && rnn.brgemm_isa == avx512_core_amx
337 && rnn.dt_conf == all_f32
338 // workspace data type and layouts can differ between fwd and bwd
339 // implementations during training.
340 && !rnn.is_training
341 // bf16 lstm_projection is not supported, so neither is bf32.
342 && !rnn.is_lstm_projection && is_shape_ok_for_bf32;
343 if (!IMPLICATION(rnn.is_cell_dt_bf16(), rnn.is_bf16_conf() || is_bf32))
344 return status::unimplemented;
345
346 rnn.nthr = dnnl_get_max_threads();
347 const bool is_amx_isa_selected
348 = rnn.is_cell_int8_amx() || rnn.is_cell_bf16_amx();
349 const bool can_use_block64
350 = is_amx_isa_selected && rnn.N % 64 == 0 && !rnn.is_lstm_projection;
351 rnn.n_block = can_use_block64 ? 64 : 32;
352 rnn.N_blocks = utils::div_up(rnn.N, rnn.n_block);
353 rnn.n_tail = rnn.N % rnn.n_block;
354
355 const float work_by_N
356 = static_cast<float>(rnn.N_blocks) / static_cast<float>(rnn.nthr);
357
358 const dim_t l2_cache_size = platform::get_per_core_cache_size(2);
359 const dim_t As = src_layer_type_size * rnn.M * (nstl::max(rnn.K1, rnn.K2));
360 const dim_t Bs
361 = src_layer_type_size * (nstl::max(rnn.K1, rnn.K2)) * rnn.n_block;
362 const dim_t Cs
363 = scratch_type_size * (rnn.n_gates + 1) * (rnn.M * rnn.n_block);
364
365 std::tie(rnn.k1_block, rnn.k2_block) = brgemm_calc_k_block(rnn.K1, rnn.K2,
366 rnn.M, rnn.n_block, cell_kind, src_layer_type_size, As, Bs, Cs,
367 l2_cache_size, rnn.brgemm_isa, is_int8, is_bf16);
368 rnn.KB1_blocks = rnn.K1 / rnn.k1_block;
369 rnn.k1_tail = rnn.K1 % rnn.k1_block;
370 rnn.KB2_blocks = rnn.K2 / rnn.k2_block;
371 rnn.k2_tail = rnn.K2 % rnn.k2_block;
372 rnn.m_block = brgemm_calc_m_block(cell_kind, prop_kind::forward, rnn.nthr,
373 rnn.M, rnn.N_blocks, rnn.is_cell_dt_f32(), rnn.is_cell_int8_amx(),
374 rnn.is_cell_bf16_amx(), work_by_N, As, Bs, Cs, l2_cache_size);
375
376 rnn.M_blocks = rnn.M / rnn.m_block;
377
378 rnn.brgemm_isa = adjust_isa_by_m_block(
379 rnn.brgemm_isa, rnn.m_block, rnn.is_cell_int8_amx());
380 // Unfused post-gemm for lstm cell allows to parallelize across gates loop
381 // and reduces brgemm problem size for the single iteration of parallel loop
382 rnn.unfused_post_gemm = cell_kind == alg_kind::vanilla_lstm
383 ? IMPLICATION(rnn.M_blocks > 1, rnn.is_cell_bf16_amx())
384 : false;
385
386 rnn.LDA1[0] = rnn.src_layer_ld_;
387 rnn.LDA1[1] = rnn.dst_iter_ld_;
388 rnn.LDA1[2] = rnn.ws_states_layer_ld;
389
390 rnn.LDA2[0] = rnn.src_iter_ld_;
391 rnn.LDA2[1] = rnn.dst_layer_ld_;
392 rnn.LDA2[2] = rnn.ws_states_iter_ld;
393
394 rnn.LDA2_2[0] = rnn.dst_layer_ld_;
395 rnn.LDA2_2[1] = rnn.dst_iter_ld_;
396 rnn.LDA2_2[2] = rnn.ws_states_layer_ld;
397 rnn.LDA2_2[3] = rnn.ws_states_iter_ld;
398
399 rnn.LDB1 = rnn.n_block;
400 rnn.LDB2 = rnn.n_block;
401 rnn.LDC = rnn.scratch_gates_ld;
402
403 auto get_dim = [&](dim_t block, dim_t tail) {
404 return (block == 0) ? tail : block;
405 };
406
407 dim_t n_block = nstl::min(rnn.N, rnn.n_block);
408 dim_t n_tail = nstl::min(rnn.N, rnn.nproj_tail);
409 if (rnn.LDA1[0] < rnn.k1_block && rnn.LDA1[1] < rnn.k1_block
410 && rnn.LDA1[2] < rnn.k1_block)
411 return status::unimplemented;
412 if (rnn.LDA2[0] < rnn.k2_block && rnn.LDA2[1] < rnn.k2_block
413 && rnn.LDA2[2] < rnn.k2_block)
414 return status::unimplemented;
415 if (rnn.LDB1 < get_dim(n_block, n_tail)
416 && rnn.LDB2 < get_dim(n_block, n_tail))
417 return status::unimplemented;
418 if (rnn.LDC < get_dim(n_block, n_tail)) return status::unimplemented;
419
420 rnn.KBproj_blocks = 0;
421 rnn.kproj_tail = 0;
422 rnn.kproj_block = 0;
423
424 if (rnn.is_lstm_projection) {
425 rnn.Nproj = rnn.dic;
426 rnn.Nproj_blocks = utils::div_up(rnn.Nproj, rnn.n_block);
427 rnn.nproj_tail = rnn.Nproj % rnn.n_block;
428
429 rnn.Kproj = rnn.dhc;
430 rnn.Kprojpadded = utils::rnd_up(rnn.Kproj, padding);
431 if (rnn.is_cell_int8_amx() || rnn.is_cell_bf16_amx()) {
432 const dim_t max_row_width = rnn.is_cell_int8_amx() ? 64 : 32;
433 rnn.kproj_block = nstl::min(rnn.Kproj, (dim_t)max_row_width);
434
435 rnn.KBproj_blocks = rnn.Kproj / rnn.kproj_block;
436 rnn.kproj_tail = rnn.Kproj % rnn.kproj_block;
437
438 if ((rnn.kproj_tail % padding) || (rnn.kproj_block % padding)) {
439 rnn.kproj_block = rnn.Kproj;
440 rnn.kproj_tail = 0;
441 rnn.brgemm_isa = rnn.is_cell_dt_int8() ? x64::avx512_core_vnni
442 : x64::avx512_core_bf16;
443 } else {
444 rnn.brgemm_isa = x64::avx512_core_amx;
445 }
446 } else {
447 rnn.kproj_block = rnn.Kproj;
448 rnn.KBproj_blocks = rnn.Kproj / rnn.kproj_block;
449 }
450 rnn.LDAproj = rnn.proj_ht_ld;
451 rnn.LDBproj = rnn.n_block;
452 if (rnn.dt_conf != cpu::rnn_utils::all_f32) {
453 rnn.LDCproj[0] = rnn.scratch_gates_ld;
454 } else {
455 rnn.LDCproj[0] = rnn.scratch_ht_ld;
456 rnn.LDCproj[1] = rnn.dst_layer_ld_;
457 rnn.LDCproj[2] = rnn.dst_iter_ld_;
458 rnn.LDCproj[3] = rnn.ws_states_layer_ld;
459 }
460
461 dim_t n_block = nstl::min(rnn.Nproj, rnn.n_block);
462 dim_t n_tail = nstl::min(rnn.Nproj, rnn.nproj_tail);
463 bool check_LDC = false;
464 if (rnn.dt_conf != cpu::rnn_utils::all_f32) {
465 check_LDC = rnn.LDCproj[0] < get_dim(n_block, n_tail);
466 } else {
467 check_LDC = rnn.LDCproj[0] < get_dim(n_block, n_tail)
468 && rnn.LDCproj[1] < get_dim(n_block, n_tail)
469 && rnn.LDCproj[2] < get_dim(n_block, n_tail)
470 && rnn.LDCproj[3] < get_dim(n_block, n_tail);
471 }
472 if (rnn.LDAproj < rnn.kproj_block
473 || rnn.LDBproj < get_dim(n_block, n_tail) || check_LDC)
474 return status::unimplemented;
475 }
476
477 // enable merged across n_iter dimension layer part of the cell computation
478 // TODO: extend coverage for other problem types
479 const bool mlc_cell_type_ok = cell_kind == alg_kind::vanilla_lstm
480 && !rnn.is_lstm_projection && !rnn.is_lstm_peephole;
481 const int mlc_mb_max_threshold = 1;
482 const int mlc_n_iter_min_threshold = 2;
483 const int mlc_n_layer_max_threshold = 1;
484 const bool mlc_problem_shape_ok = rnn.mb <= mlc_mb_max_threshold
485 && rnn.n_iter >= mlc_n_iter_min_threshold
486 && rnn.n_layer <= mlc_n_layer_max_threshold;
487 // if rnn.skip_dst_iter_copy() == false we might need to reduce number of
488 // merged cells by 1 (rnn.n_iter - 1)
489 // but if in addition rnn.skip_src_layer_copy() == true then on cell
490 // position 'first_layer' we still should merge for all the n_iter cells
491 // so current support is limited by the case when layer computation is
492 // merged for all rnn.n_iter cells
493 const bool mlc_m_dim_adjustment_not_required
494 = IMPLICATION(rnn.skip_dst_iter_copy(),
495 rnn.skip_src_layer_copy() && rnn.n_layer == 1);
496 const bool merged_layer_compute_applicable = rnn.src_layer_is_trivial_stride
497 && mlc_cell_type_ok && mlc_problem_shape_ok
498 && mlc_m_dim_adjustment_not_required;
499 if (merged_layer_compute_applicable) {
500 rnn.merge_gemm_layer = true;
501
502 // required adjustment if mlc_m_dim_adjustment_not_required = false
503 const int n_iters_to_merge = rnn.n_iter;
504 rnn.Mlayermerged = rnn.mb * n_iters_to_merge;
505 rnn.mlayermerged_block = brgemm_calc_m_block(cell_kind,
506 prop_kind::forward, rnn.nthr, rnn.Mlayermerged, rnn.N_blocks,
507 rnn.is_cell_dt_f32(), rnn.is_cell_int8_amx(),
508 rnn.is_cell_bf16_amx(), work_by_N, As, Bs, Cs, l2_cache_size);
509
510 rnn.Mlayermerged_blocks = rnn.Mlayermerged / rnn.mlayermerged_block;
511 }
512
513 rnn.brgemm_fwd_iter_layer_fuse_possible
514 = rnn.slc == rnn.sic && !rnn.merge_gemm_layer;
515
516 if (!rnn.is_orig_gru) {
517 rnn.loop_order = rnn.is_cell_int8_amx() || rnn.is_cell_bf16_amx()
518 ? brgemm_rnn_execute_loop_order_t::mblk_nblk
519 : brgemm_rnn_execute_loop_order_t::nblk_mblk;
520 }
521 return status::success;
522}
523
524status_t init_brgemm_kernel(x64::brgemm_t *desc, x64::cpu_isa_t isa,
525 impl::data_type_t src_type, impl::data_type_t weights_type,
526 std::unique_ptr<x64::brgemm_kernel_t> &ker, dim_t M, dim_t N, dim_t K,
527 dim_t LDA, dim_t LDB, dim_t LDC, float beta, dim_t max_bs,
528 dim_t hint_expected_A_size = LLONG_MAX,
529 dim_t hint_expected_B_size = LLONG_MAX,
530 dim_t hint_expected_C_size = LLONG_MAX) {
531 bool transA = false;
532 bool transB = false;
533 x64::brgemm_layout_t layout = x64::brgemm_row_major;
534 CHECK(brgemm_desc_init(desc, isa, x64::brgemm_addr, src_type, weights_type,
535 transA, transB, layout, 1.0, beta, LDA, LDB, LDC, M, N, K));
536
537 x64::brgemm_attr_t brgattr;
538
539 brgattr.hint_expected_A_size = hint_expected_A_size;
540 brgattr.hint_expected_B_size = hint_expected_B_size;
541 brgattr.hint_expected_C_size = hint_expected_C_size;
542 brgattr.max_bs = max_bs;
543 brgattr.max_top_vpad = 0;
544 brgattr.max_bottom_vpad = 0;
545 brgemm_desc_set_attr(desc, brgattr);
546
547 x64::brgemm_kernel_t *_t_ptr;
548 CHECK(brgemm_kernel_create(&_t_ptr, *desc));
549 safe_ptr_assign<x64::brgemm_kernel_t>(ker, _t_ptr);
550
551 return status::success;
552};
553
554status_t rnn_brgemm_t<prop_kind::forward>::brgemm_rnn_init_tiles(
555 brgemm_t *desc_array, dim_t size, brgemm_pallete_t pallete) {
556
557 for (dim_t it = 0; it < size; ++it) {
558 const auto &desc = desc_array[it];
559 const bool desc_empty
560 = utils::everyone_is(0, desc.LDA, desc.LDB, desc.LDC);
561 if (!desc_empty) return brgemm_init_tiles(desc, pallete);
562 }
563
564 return status::unimplemented;
565}
566
// Convenience overload: initializes tiles from the base (layer/iter) kernel
// descriptor array.
status_t rnn_brgemm_t<prop_kind::forward>::brgemm_rnn_init_tiles(
        brgemm_t *desc_array, brgemm_pallete_t pallete) {
    return brgemm_rnn_init_tiles(desc_array, num_base_kernels_, pallete);
}
// Convenience overload: initializes tiles from the lstm-projection kernel
// descriptor array.
status_t rnn_brgemm_t<prop_kind::forward>::brgemm_rnn_init_tiles_proj(
        brgemm_t *desc_array, brgemm_pallete_t pallete) {
    return brgemm_rnn_init_tiles(desc_array, num_proj_kernels_, pallete);
}
575
// Creates every brgemm kernel variant needed by the forward RNN execution:
// layer and iter GEMMs (with beta = 0 or 1), their N/K tail versions, the
// GRU part-2 kernels, and the lstm-projection kernels; finally initializes
// the AMX tile palettes for the configurations actually in use.
// Kernel naming convention: *_b0/_b1 = beta 0/1; *_N_tail/_K*_tail = tail
// kernels in the N and K dimensions; index i selects the A leading dimension
// (user memory vs. workspace variants).
status_t rnn_brgemm_t<prop_kind::forward>::init_kernels(
        const cpu::rnn_utils::rnn_conf_t &rnn, data_type_t src_type,
        data_type_t weights_type) {

    // Thin wrapper binding the data types so each call site below passes
    // only the GEMM geometry.
    const auto init_brgemm
            = [&](x64::brgemm_t *desc, x64::cpu_isa_t isa,
                      std::unique_ptr<x64::brgemm_kernel_t> &ker, dim_t M,
                      dim_t N, dim_t K, dim_t LDA, dim_t LDB, dim_t LDC,
                      float beta, dim_t max_bs) {
                  return init_brgemm_kernel(desc, isa, src_type, weights_type,
                          ker, M, N, K, LDA, LDB, LDC, beta, max_bs);
              };

    const int brgemm_n = nstl::min(rnn.N, rnn.n_block);
    const int brgemm_n_tail = nstl::min(rnn.N, rnn.n_tail);
    // When the fused iter+layer kernel is possible, the batch holds both
    // GEMMs' blocks, hence the doubled max_bs.
    const int max_bs_factor = rnn.brgemm_fwd_iter_layer_fuse_possible ? 2 : 1;

    for (int i = 0; i < num_base_kernels_; i++) {
        // Layer GEMM (beta = 0): merged-across-iterations variant uses the
        // larger merged M blocking.
        if (rnn.merge_gemm_layer) {
            init_brgemm(&desc_layermerged_b0_[i], rnn.brgemm_isa,
                    kernel_layermerged_b0_[i], rnn.mlayermerged_block, brgemm_n,
                    rnn.k1_block, rnn.LDA1[i], rnn.LDB1, rnn.LDC, 0.0,
                    rnn.KB1_blocks);
        } else {
            init_brgemm(&desc_layer_b0_[i], rnn.brgemm_isa, kernel_layer_b0_[i],
                    rnn.m_block, brgemm_n, rnn.k1_block, rnn.LDA1[i], rnn.LDB1,
                    rnn.LDC, 0.0, max_bs_factor * rnn.KB1_blocks);
        }

        // Iter GEMM with beta = 0 (overwrite) and beta = 1 (accumulate).
        init_brgemm(&desc_iter_b0_[i], rnn.brgemm_isa, kernel_iter_b0_[i],
                rnn.m_block, brgemm_n, rnn.k2_block, rnn.LDA2[i], rnn.LDB2,
                rnn.LDC, 0.0, rnn.KB2_blocks);
        init_brgemm(&desc_iter_b1_[i], rnn.brgemm_isa, kernel_iter_b1_[i],
                rnn.m_block, brgemm_n, rnn.k2_block, rnn.LDA2[i], rnn.LDB2,
                rnn.LDC, 1.0, rnn.KB2_blocks);
        // N-tail kernels handle the last (partial) N block.
        if (rnn.n_tail) {
            if (rnn.merge_gemm_layer) {
                init_brgemm(&desc_layermerged_N_tail_b0_[i], rnn.brgemm_isa,
                        kernel_layermerged_N_tail_b0_[i],
                        rnn.mlayermerged_block, brgemm_n_tail, rnn.k1_block,
                        rnn.LDA1[i], rnn.LDB1, rnn.LDC, 0.0, rnn.KB1_blocks);
            } else {
                init_brgemm(&desc_layer_N_tail_b0_[i], rnn.brgemm_isa,
                        kernel_layer_N_tail_b0_[i], rnn.m_block, brgemm_n_tail,
                        rnn.k1_block, rnn.LDA1[i], rnn.LDB1, rnn.LDC, 0.0,
                        max_bs_factor * rnn.KB1_blocks);
            }

            init_brgemm(&desc_iter_N_tail_b0_[i], rnn.brgemm_isa,
                    kernel_iter_N_tail_b0_[i], rnn.m_block, brgemm_n_tail,
                    rnn.k2_block, rnn.LDA2[i], rnn.LDB2, rnn.LDC, 0.0,
                    rnn.KB2_blocks);
            init_brgemm(&desc_iter_N_tail_b1_[i], rnn.brgemm_isa,
                    kernel_iter_N_tail_b1_[i], rnn.m_block, brgemm_n_tail,
                    rnn.k2_block, rnn.LDA2[i], rnn.LDB2, rnn.LDC, 1.0,
                    rnn.KB2_blocks);
        }
        // K-tail kernels accumulate (beta = 1) the remainder of K1 on top of
        // the full-block result; batch size is 1 (single tail block).
        if (rnn.k1_tail) {
            if (rnn.merge_gemm_layer) {
                init_brgemm(&desc_layermerged_K1_tail_b1_[i], rnn.brgemm_isa,
                        kernel_layermerged_K1_tail_b1_[i],
                        rnn.mlayermerged_block, brgemm_n, rnn.k1_tail,
                        rnn.LDA1[i], rnn.LDB1, rnn.LDC, 1.0, 1);
            } else {
                init_brgemm(&desc_layer_K1_tail_b1_[i], rnn.brgemm_isa,
                        kernel_layer_K1_tail_b1_[i], rnn.m_block, brgemm_n,
                        rnn.k1_tail, rnn.LDA1[i], rnn.LDB1, rnn.LDC, 1.0,
                        max_bs_factor * 1);
            }
        }
        if (rnn.k2_tail)
            init_brgemm(&desc_iter_K2_tail_b1_[i], rnn.brgemm_isa,
                    kernel_iter_K2_tail_b1_[i], rnn.m_block, brgemm_n,
                    rnn.k2_tail, rnn.LDA2[i], rnn.LDB2, rnn.LDC, 1.0, 1);
        // Combined N-tail + K-tail kernels.
        if (rnn.k1_tail && rnn.n_tail) {
            if (rnn.merge_gemm_layer) {
                init_brgemm(&desc_layermerged_NK1_tail_b1_[i], rnn.brgemm_isa,
                        kernel_layermerged_NK1_tail_b1_[i],
                        rnn.mlayermerged_block, brgemm_n_tail, rnn.k1_tail,
                        rnn.LDA1[i], rnn.LDB1, rnn.LDC, 1.0, 1);
            } else {
                init_brgemm(&desc_layer_NK1_tail_b1_[i], rnn.brgemm_isa,
                        kernel_layer_NK1_tail_b1_[i], rnn.m_block,
                        brgemm_n_tail, rnn.k1_tail, rnn.LDA1[i], rnn.LDB1,
                        rnn.LDC, 1.0, max_bs_factor * 1);
            }
        }
        if (rnn.k2_tail && rnn.n_tail)
            init_brgemm(&desc_iter_NK2_tail_b1_[i], rnn.brgemm_isa,
                    kernel_iter_NK2_tail_b1_[i], rnn.m_block, brgemm_n_tail,
                    rnn.k2_tail, rnn.LDA2[i], rnn.LDB2, rnn.LDC, 1.0, 1);
    }
    // Original GRU needs a second iter GEMM pass (part 2) with its own set of
    // A leading dimensions (LDA2_2), always accumulating (beta = 1).
    if (rnn.is_orig_gru) {
        for (int i = 0; i < num_vanilla_gru_iter_part2_kernels_; i++) {
            init_brgemm(&desc_iter_p2_b1_[i], rnn.brgemm_isa,
                    kernel_iter_p2_b1_[i], rnn.m_block, brgemm_n, rnn.k2_block,
                    rnn.LDA2_2[i], rnn.LDB2, rnn.LDC, 1.0, rnn.KB2_blocks);
            if (rnn.n_tail)
                init_brgemm(&desc_iter_p2_N_tail_b1_[i], rnn.brgemm_isa,
                        kernel_iter_p2_N_tail_b1_[i], rnn.m_block,
                        brgemm_n_tail, rnn.k2_block, rnn.LDA2_2[i], rnn.LDB2,
                        rnn.LDC, 1.0, rnn.KB2_blocks);
            if (rnn.k2_tail)
                init_brgemm(&desc_iter_p2_K2_tail_b1_[i], rnn.brgemm_isa,
                        kernel_iter_p2_K2_tail_b1_[i], rnn.m_block, brgemm_n,
                        rnn.k2_tail, rnn.LDA2_2[i], rnn.LDB2, rnn.LDC, 1.0, 1);
            if (rnn.k2_tail && rnn.n_tail)
                init_brgemm(&desc_iter_p2_NK2_tail_b1_[i], rnn.brgemm_isa,
                        kernel_iter_p2_NK2_tail_b1_[i], rnn.m_block,
                        brgemm_n_tail, rnn.k2_tail, rnn.LDA2_2[i], rnn.LDB2,
                        rnn.LDC, 1.0, 1);
        }
    }
    // LSTM projection GEMMs. For f32 all LDCproj output variants get a
    // kernel; otherwise only the scratch variant (index 0) is needed.
    if (rnn.is_lstm_projection) {
        const dim_t brgemm_np = nstl::min(rnn.Nproj, rnn.n_block);
        const dim_t brgemm_np_tail = nstl::min(rnn.Nproj, rnn.nproj_tail);
        const int n_kernel = (rnn.dt_conf == cpu::rnn_utils::all_f32)
                ? num_proj_kernels_
                : 1;
        for (int i = 0; i < n_kernel; i++) {
            init_brgemm(&desc_proj_b0_[i], rnn.brgemm_isa, kernel_proj_b0_[i],
                    rnn.m_block, brgemm_np, rnn.kproj_block, rnn.LDAproj,
                    rnn.LDBproj, rnn.LDCproj[i], 0.0, rnn.KBproj_blocks);
            if (rnn.nproj_tail) {
                init_brgemm(&desc_proj_N_tail_b0_[i], rnn.brgemm_isa,
                        kernel_proj_N_tail_b0_[i], rnn.m_block, brgemm_np_tail,
                        rnn.kproj_block, rnn.LDAproj, rnn.LDBproj,
                        rnn.LDCproj[i], 0.0, rnn.KBproj_blocks);
                init_brgemm(&desc_proj_N_tail_b1_[i], rnn.brgemm_isa,
                        kernel_proj_N_tail_b1_[i], rnn.m_block, brgemm_np_tail,
                        rnn.kproj_block, rnn.LDAproj, rnn.LDBproj,
                        rnn.LDCproj[i], 1.0, rnn.KBproj_blocks);
            }
            // K tails only occur with the AMX blocking scheme.
            if (rnn.is_cell_int8_amx() || rnn.is_cell_bf16_amx()) {
                if (rnn.kproj_tail)
                    init_brgemm(&desc_proj_K_tail_b1_[i], rnn.brgemm_isa,
                            kernel_proj_K_tail_b1_[i], rnn.m_block, brgemm_np,
                            rnn.kproj_tail, rnn.LDAproj, rnn.LDBproj,
                            rnn.LDCproj[i], 1.0, 1);
                if (rnn.kproj_tail && rnn.nproj_tail)
                    init_brgemm(&desc_proj_NK_tail_b1_[i], rnn.brgemm_isa,
                            kernel_proj_NK_tail_b1_[i], rnn.m_block,
                            brgemm_np_tail, rnn.kproj_tail, rnn.LDAproj,
                            rnn.LDBproj, rnn.LDCproj[i], 1.0, 1);
            }
        }
    }

    // AMX only: configure one tile palette per kernel family actually used.
    if (rnn.is_cell_int8_amx() || rnn.is_cell_bf16_amx()) {
        if (rnn.merge_gemm_layer)
            CHECK(brgemm_rnn_init_tiles(
                    desc_layermerged_b0_, pallete_buff_layermerged_));
        else
            CHECK(brgemm_rnn_init_tiles(desc_layer_b0_, pallete_buff_layer_));
        CHECK(brgemm_rnn_init_tiles(desc_iter_b0_, pallete_buff_iter_));

        if (rnn.n_tail) {
            if (rnn.merge_gemm_layer)
                CHECK(brgemm_rnn_init_tiles(desc_layermerged_N_tail_b0_,
                        pallete_buff_layermerged_n_tail_));
            else
                CHECK(brgemm_rnn_init_tiles(
                        desc_layer_N_tail_b0_, pallete_buff_layer_n_tail_));
            CHECK(brgemm_rnn_init_tiles(
                    desc_iter_N_tail_b0_, pallete_buff_iter_n_tail_));
        }
        if (rnn.k1_tail) {
            if (rnn.merge_gemm_layer)
                CHECK(brgemm_rnn_init_tiles(desc_layermerged_K1_tail_b1_,
                        pallete_buff_layermerged_k1_tail_));
            else
                CHECK(brgemm_rnn_init_tiles(
                        desc_layer_K1_tail_b1_, pallete_buff_k1_tail_));
        }
        if (rnn.k2_tail)
            CHECK(brgemm_rnn_init_tiles(
                    desc_iter_K2_tail_b1_, pallete_buff_k2_tail_));
        if (rnn.k1_tail && rnn.n_tail) {
            if (rnn.merge_gemm_layer)
                CHECK(brgemm_rnn_init_tiles(desc_layermerged_NK1_tail_b1_,
                        pallete_buff_layermerged_nk1_tail_));
            else
                CHECK(brgemm_rnn_init_tiles(
                        desc_layer_NK1_tail_b1_, pallete_buff_nk1_tail_));
        }
        if (rnn.k2_tail && rnn.n_tail)
            CHECK(brgemm_rnn_init_tiles(
                    desc_iter_NK2_tail_b1_, pallete_buff_nk2_tail_));
        if (rnn.is_lstm_projection) {
            CHECK(brgemm_rnn_init_tiles_proj(
                    desc_proj_b0_, pallete_buff_proj_));
            if (rnn.nproj_tail)
                CHECK(brgemm_rnn_init_tiles_proj(
                        desc_proj_N_tail_b0_, pallete_buff_nproj_tail_));
            if (rnn.kproj_tail)
                CHECK(brgemm_rnn_init_tiles_proj(
                        desc_proj_K_tail_b1_, pallete_buff_kproj_tail_));
            if (rnn.kproj_tail && rnn.nproj_tail)
                CHECK(brgemm_rnn_init_tiles_proj(
                        desc_proj_NK_tail_b1_, pallete_buff_nkproj_tail_));
        }
    }

    return status::success;
}
781
// Backward pass scratchpad: books the common (base) scratchpad plus buffers
// for internal reorders of the scratch gates and the transposed src
// layer/iter used by the diff-weights computation.
void rnn_brgemm_t<prop_kind::backward>::init_scratchpad(
        const cpu::rnn_utils::rnn_conf_t &rnn,
        memory_tracking::registrar_t &scratchpad, dim_t gemm_acc_type_size,
        dim_t gemm_acc_align) {

    rnn_brgemm_base_t::init_scratchpad(
            rnn, scratchpad, gemm_acc_type_size, gemm_acc_align);

    using namespace memory_tracking::names;

    // init scratchpad for internal reorders:
    const auto data_size
            = rnn.is_bf16_conf() ? sizeof(bfloat16_t) : sizeof(float);
    const auto &d_wei = rnn.diff_wei_brgemm;
    // Per-thread blocked copy of the scratch gates (K padded to the vnni
    // granularity).
    const auto scratch_gates_blocked_per_thr = d_wei.Kpadded * d_wei.n_block;
    const auto scratch_gates_blocked_size
            = rnn.nthr * scratch_gates_blocked_per_thr;
    scratchpad.book(key_rnn_gates_blocked, scratch_gates_blocked_size,
            data_size, gemm_acc_align);

    // Transposed src buffers: one global copy when the transpose is done
    // up-front, otherwise one m_block-sized slice per thread.
    const auto scratch_src_layer_size = d_wei.global_transpose
            ? d_wei.M_layer * d_wei.Kpadded
            : rnn.nthr * std::min(d_wei.m_block, d_wei.M_layer) * d_wei.Kpadded;
    scratchpad.book(key_rnn_src_layer_trans, scratch_src_layer_size, data_size,
            gemm_acc_align);

    const auto scratch_src_iter_size = d_wei.global_transpose
            ? d_wei.M_iter * d_wei.Kpadded
            : rnn.nthr * std::min(d_wei.m_block, d_wei.M_iter) * d_wei.Kpadded;
    scratchpad.book(key_rnn_src_iter_trans, scratch_src_iter_size, data_size,
            gemm_acc_align);
}
814
// Configures the brgemm blocking for the backward pass: fills
// rnn.diff_src_brgemm (diff-states GEMM) and rnn.diff_wei_brgemm
// (diff-weights GEMM), then derives the shared rnn.* blocking fields as
// element-wise maxima of the two sub-configurations.
// Returns status::unimplemented when leading dimensions cannot cover the
// chosen blocks or when the data type combination is unsupported.
status_t rnn_brgemm_t<prop_kind::backward>::configure_brgemm(
        cpu::rnn_utils::rnn_conf_t &rnn, alg_kind_t cell_kind,
        dim_t src_layer_type_size, dim_t scratch_type_size) {
    using namespace cpu::rnn_utils;

    // The brgemm-based backward pass has no int8 support.
    if (rnn.is_int8_conf() || rnn.is_cell_dt_int8())
        return status::unimplemented;

    auto &diff_src_conf = rnn.diff_src_brgemm;

    // diff_src GEMM shape: M = mb, K = dhc; N covers both the iter (sic) and
    // layer (slc) outputs, blocked jointly via their maximum.
    diff_src_conf.M = rnn.mb;
    diff_src_conf.N_iter = rnn.sic;
    diff_src_conf.N_layer = rnn.slc;
    diff_src_conf.N = nstl::max(diff_src_conf.N_iter, diff_src_conf.N_layer);
    diff_src_conf.K = rnn.dhc;

    rnn.nthr = dnnl_get_max_threads();
    // Fixed N blocking of 32; block/tail bookkeeping is kept for the combined
    // N as well as for the layer and iter parts individually.
    diff_src_conf.n_block = 32;
    diff_src_conf.N_blocks
            = utils::div_up(diff_src_conf.N, diff_src_conf.n_block);
    diff_src_conf.n_tail = diff_src_conf.N % diff_src_conf.n_block;
    diff_src_conf.N_layer_blocks
            = utils::div_up(diff_src_conf.N_layer, diff_src_conf.n_block);
    diff_src_conf.n_layer_tail = diff_src_conf.N_layer % diff_src_conf.n_block;
    diff_src_conf.N_iter_blocks
            = utils::div_up(diff_src_conf.N_iter, diff_src_conf.n_block);
    diff_src_conf.n_iter_tail = diff_src_conf.N_iter % diff_src_conf.n_block;

    // Average number of N blocks per thread - input to the m_block heuristic.
    const float work_by_N = static_cast<float>(diff_src_conf.N_blocks)
            / static_cast<float>(rnn.nthr);

    // Approximate footprints of the A, B and C matrices, used by the blocking
    // heuristics to try to fit the working set into the per-core L2 cache.
    const dim_t l2_cache_size = platform::get_per_core_cache_size(2);
    const dim_t As = src_layer_type_size * diff_src_conf.M * diff_src_conf.K;
    const dim_t Bs
            = src_layer_type_size * diff_src_conf.K * diff_src_conf.n_block;
    const dim_t Cs = scratch_type_size * (rnn.n_gates + 1)
            * (diff_src_conf.M * diff_src_conf.n_block);

    const auto is_bf16 = rnn.is_cell_dt_bf16();

    // bf16 kernels consume K in pairs, so K is padded up to an even value.
    const dim_t padding = is_bf16 ? 2 : 1;
    diff_src_conf.Kpadded = utils::rnd_up(diff_src_conf.K, padding);

    // K1 == K2 == diff_src K here; int8 was already rejected above.
    diff_src_conf.isa
            = brgemm_calc_isa(diff_src_conf.K, diff_src_conf.K, false, is_bf16);
    const bool is_bf16_amx
            = is_bf16 && diff_src_conf.isa == x64::avx512_core_amx;
    // For large K on AMX (and no N tail) compute the gates one at a time:
    // a beta=0 kernel for the first gate, beta=1 accumulation for the rest.
    const bool split_gates_computation = is_bf16_amx && diff_src_conf.K >= 1024
            && diff_src_conf.n_tail == 0;
    diff_src_conf.gates_block = split_gates_computation ? 1 : rnn.n_gates;

    std::tie(diff_src_conf.k_block, std::ignore) = brgemm_calc_k_block(
            diff_src_conf.K, diff_src_conf.K, diff_src_conf.M,
            diff_src_conf.n_block, cell_kind, src_layer_type_size, As, Bs, Cs,
            l2_cache_size, diff_src_conf.isa, false, rnn.is_cell_dt_bf16());

    // The batched GEMM iterates over all gates, hence K_blocks * n_gates.
    diff_src_conf.K_blocks = diff_src_conf.K / diff_src_conf.k_block;
    diff_src_conf.K_blocks *= rnn.n_gates;
    diff_src_conf.k_tail = diff_src_conf.K % diff_src_conf.k_block;

    diff_src_conf.m_block = brgemm_calc_m_block(cell_kind, prop_kind::backward,
            rnn.nthr, diff_src_conf.M, diff_src_conf.N_blocks,
            rnn.is_cell_dt_f32(), false, is_bf16_amx, work_by_N, As, Bs, Cs,
            l2_cache_size);

    diff_src_conf.M_blocks = diff_src_conf.M / diff_src_conf.m_block;
    // A reads the scratch gates, C writes the diff states workspace; their
    // leading dimensions come from the already-configured rnn conf.
    diff_src_conf.LDA = rnn.scratch_gates_ld;
    diff_src_conf.LDB = diff_src_conf.n_block;
    diff_src_conf.LDC = rnn.ws_diff_states_iter_ld;

    // brgemm requires leading dimensions that cover the block sizes.
    if (diff_src_conf.LDA < diff_src_conf.k_block) return status::unimplemented;

    const dim_t n_block = nstl::min(diff_src_conf.N, diff_src_conf.n_block);

    if (diff_src_conf.LDB < n_block) return status::unimplemented;
    if (diff_src_conf.LDC < n_block) return status::unimplemented;

    // No projection GEMM in the backward brgemm path.
    rnn.KBproj_blocks = 0;
    rnn.kproj_tail = 0;
    rnn.kproj_block = 0;

    auto &diff_wei_conf = rnn.diff_wei_brgemm;
    // With mb > 1 the sources are transposed up front ("global" transpose);
    // with mb == 1 each thread handles its own slice instead.
    diff_wei_conf.global_transpose = rnn.mb > 1;
    // diff_wei GEMM shape: M = sic/slc (jointly blocked via the maximum),
    // K = mb, N = dhc * n_gates.
    diff_wei_conf.M_iter = rnn.sic;
    diff_wei_conf.M_layer = rnn.slc;
    diff_wei_conf.M = nstl::max(rnn.sic, rnn.slc);
    diff_wei_conf.N = rnn.dhc * rnn.n_gates;
    // Non-f32 scratch implies bf16 data, which needs an even K.
    diff_wei_conf.K = (scratch_type_size != sizeof(float))
            ? utils::rnd_up(rnn.mb, 2)
            : rnn.mb;
    diff_wei_conf.Kpadded = utils::rnd_up(diff_wei_conf.K, padding);

    diff_wei_conf.isa
            = brgemm_calc_isa(diff_wei_conf.K, diff_wei_conf.K, false, is_bf16);

    // A wider N block (64) is used on bf16 AMX when N divides evenly and the
    // LSTM peephole path (which expects the 32-wide layout) is not involved.
    const bool is_wei_bf16_amx = rnn.is_cell_dt_bf16()
            && diff_wei_conf.isa == x64::avx512_core_amx;
    const bool diff_wei_can_use_nblock64 = is_wei_bf16_amx
            && diff_wei_conf.N % 64 == 0 && !rnn.is_lstm_peephole;
    diff_wei_conf.n_block = diff_wei_can_use_nblock64 ? 64 : 32;
    diff_wei_conf.N_blocks
            = utils::div_up(diff_wei_conf.N, diff_wei_conf.n_block);
    diff_wei_conf.n_tail = diff_wei_conf.N % diff_wei_conf.n_block;

    // Matrix footprints for the diff_wei blocking heuristics (cf. As/Bs/Cs).
    const dim_t As_wei
            = src_layer_type_size * diff_wei_conf.M * diff_wei_conf.K;
    const dim_t Bs_wei
            = src_layer_type_size * diff_wei_conf.K * diff_wei_conf.n_block;
    const dim_t Cs_wei = scratch_type_size * (rnn.n_gates + 1)
            * (diff_wei_conf.M * diff_wei_conf.n_block);

    std::tie(diff_wei_conf.k_block, std::ignore)
            = brgemm_calc_k_block(diff_wei_conf.K, diff_wei_conf.K,
                    diff_wei_conf.M, diff_wei_conf.n_block, cell_kind,
                    src_layer_type_size, As_wei, Bs_wei, Cs_wei, l2_cache_size,
                    diff_wei_conf.isa, false, rnn.is_cell_dt_bf16());

    diff_wei_conf.K_blocks = diff_wei_conf.K / diff_wei_conf.k_block;
    diff_wei_conf.k_tail = diff_wei_conf.K % diff_wei_conf.k_block;

    if (diff_wei_conf.M_iter != diff_wei_conf.M_layer) {
        // Different iter/layer M: use a single M block covering all of M.
        diff_wei_conf.m_block = diff_wei_conf.M;
        diff_wei_conf.M_blocks = 1;
    } else {
        const float work_by_N_wei = static_cast<float>(diff_wei_conf.N_blocks)
                / static_cast<float>(rnn.nthr);

        diff_wei_conf.m_block
                = brgemm_calc_m_block(cell_kind, prop_kind::backward, rnn.nthr,
                        diff_wei_conf.M, diff_wei_conf.N_blocks,
                        rnn.is_cell_dt_f32(), false, is_wei_bf16_amx,
                        work_by_N_wei, As_wei, Bs_wei, Cs_wei, l2_cache_size);
        diff_wei_conf.M_blocks = diff_wei_conf.M / diff_wei_conf.m_block;
    }

    diff_wei_conf.LDA_layer = diff_wei_conf.K;
    diff_wei_conf.LDA_iter = diff_wei_conf.K;
    diff_wei_conf.LDB = diff_wei_conf.n_block;
    diff_wei_conf.LDC_iter = rnn.diff_weights_iter_ld;
    diff_wei_conf.LDC_layer = rnn.diff_weights_layer_ld;

    if (diff_wei_conf.LDA_layer < diff_wei_conf.k_block
            || diff_wei_conf.LDA_iter < diff_wei_conf.k_block)
        return status::unimplemented;

    if (rnn.is_lstm_peephole) { configure_brgemm_peephole(rnn); }

    // Shared rnn.* blocking fields: element-wise maxima over the diff_wei and
    // diff_src configurations; the K2/k2/KB2 family mirrors the K1 values.
    rnn.M = nstl::max(diff_wei_conf.M, diff_src_conf.M);
    rnn.N = nstl::max(diff_wei_conf.N, diff_src_conf.N);
    rnn.K1 = nstl::max(diff_wei_conf.K, diff_src_conf.K);
    rnn.K2 = rnn.K1;
    rnn.m_block = nstl::max(diff_wei_conf.m_block, diff_src_conf.m_block);
    rnn.M_blocks = nstl::max(diff_wei_conf.M_blocks, diff_src_conf.M_blocks);
    rnn.n_block = nstl::max(diff_wei_conf.n_block, diff_src_conf.n_block);
    rnn.N_blocks = nstl::max(diff_wei_conf.N_blocks, diff_src_conf.N_blocks);
    rnn.n_tail = nstl::max(diff_wei_conf.n_tail, diff_src_conf.n_tail);
    rnn.k1_block = nstl::max(diff_wei_conf.k_block, diff_src_conf.k_block);
    rnn.k2_block = rnn.k1_block;
    rnn.k1_tail = nstl::max(diff_wei_conf.k_tail, diff_src_conf.k_tail);
    rnn.k2_tail = rnn.k1_tail;
    rnn.KB1_blocks = nstl::max(diff_wei_conf.K_blocks, diff_src_conf.K_blocks);
    rnn.KB2_blocks = rnn.KB1_blocks;
    rnn.K1padded = nstl::max(diff_wei_conf.Kpadded, diff_src_conf.Kpadded);
    rnn.K2padded = rnn.K1padded;
    rnn.unfused_post_gemm = true;

    // Report AMX if either sub-configuration selected it.
    if (utils::one_of(
                x64::avx512_core_amx, diff_wei_conf.isa, diff_src_conf.isa)) {
        rnn.brgemm_isa = x64::avx512_core_amx;
    } else {
        rnn.brgemm_isa = diff_wei_conf.isa;
    }

    // Loop order: M-outer for bf16 AMX, N-outer otherwise. The original GRU
    // keeps whatever default order the conf already holds.
    if (!rnn.is_orig_gru) {
        rnn.diff_src_brgemm.loop_order
                = is_bf16 && diff_src_conf.isa == x64::avx512_core_amx
                ? brgemm_rnn_execute_loop_order_t::mblk_nblk
                : brgemm_rnn_execute_loop_order_t::nblk_mblk;
        rnn.diff_wei_brgemm.loop_order
                = is_bf16 && diff_wei_conf.isa == x64::avx512_core_amx
                ? brgemm_rnn_execute_loop_order_t::mblk_nblk
                : brgemm_rnn_execute_loop_order_t::nblk_mblk;
    }

    return status::success;
}
1001
1002static dim_t divide_block_to_improve_thread_balance(
1003 const dim_t initial_work_amount, const dim_t division_block,
1004 const dim_t nthr) {
1005
1006 const float nthr_f = static_cast<float>(nthr);
1007 const float initial_work = static_cast<float>(initial_work_amount) / nthr_f;
1008 const float decimal_initial_factor
1009 = initial_work - std::floor(initial_work);
1010 static constexpr float thread_balance_threashold = 0.8;
1011 static constexpr float tolerance = 0.01;
1012
1013 float max_decimal_factor = -1.0;
1014 dim_t best_candidate = -1.0;
1015 bool found_best_solution = false;
1016
1017 if (decimal_initial_factor < thread_balance_threashold
1018 && decimal_initial_factor != 0.0f) {
1019
1020 for (const int block_size : {4096, 2048, 1024, 512, 256, 128, 64, 32}) {
1021
1022 if (division_block <= block_size) continue;
1023
1024 const auto blocks = utils::div_up(division_block, block_size);
1025
1026 const float work
1027 = static_cast<float>(initial_work_amount * blocks) / nthr_f;
1028 const float work_decimal = work - std::floor(work);
1029
1030 if (work_decimal == 0.0f
1031 || (max_decimal_factor != 0.0f
1032 ? work_decimal
1033 > (max_decimal_factor + tolerance)
1034 : work_decimal >= thread_balance_threashold)
1035
1036 ) {
1037 best_candidate = block_size;
1038 max_decimal_factor = work_decimal;
1039 }
1040
1041 if (work >= nthr_f
1042 && (work_decimal >= thread_balance_threashold
1043 || work_decimal == 0.0f)) {
1044 found_best_solution = true;
1045 break;
1046 }
1047 }
1048 }
1049
1050 if (found_best_solution
1051 || (!found_best_solution
1052 && max_decimal_factor
1053 > decimal_initial_factor + tolerance)) {
1054 return best_candidate;
1055 }
1056
1057 return division_block;
1058}
1059
1060void rnn_brgemm_t<prop_kind::backward>::configure_brgemm_peephole(
1061 cpu::rnn_utils::rnn_conf_t &rnn) {
1062 static constexpr dim_t n_gates = 3;
1063 rnn.dhc_block_peephole = divide_block_to_improve_thread_balance(
1064 n_gates, rnn.dhc, rnn.nthr);
1065 rnn.dhc_blocks_peephole = utils::div_up(rnn.dhc, rnn.dhc_block_peephole);
1066 rnn.dhc_tail_peephole = rnn.dhc % rnn.dhc_block_peephole;
1067}
1068
// Creates all brgemm descriptors/kernels needed for the diff_src (diff
// states) computation: the main iter+layer kernel, N-tail kernels for the
// layer and iter parts, K-tail kernels, and - when the gates computation is
// split - their beta=1 accumulating variants. For bf16 AMX it also
// initializes the tile palettes of the beta=0 descriptors.
static status_t init_kernels_diff_src(rnn_diff_src_brgemm_t &diff_src,
        const cpu::rnn_utils::rnn_conf_t &rnn, data_type_t src_type,
        data_type_t weights_type) {

    // Thin wrapper over init_brgemm_kernel supplying the full diff_src
    // A/B/C matrix sizes from the configured conf.
    const auto init_brgemm_diff_src
            = [&](x64::brgemm_t *desc, x64::cpu_isa_t isa,
                      std::unique_ptr<x64::brgemm_kernel_t> &ker, dim_t M,
                      dim_t N, dim_t K, dim_t LDA, dim_t LDB, dim_t LDC,
                      float beta, dim_t max_bs) {
                  const dim_t A_size
                          = rnn.diff_src_brgemm.M * rnn.diff_src_brgemm.Kpadded;
                  const dim_t B_size
                          = rnn.diff_src_brgemm.Kpadded * rnn.diff_src_brgemm.N;
                  const dim_t C_size
                          = rnn.diff_src_brgemm.M * rnn.diff_src_brgemm.N;
                  return init_brgemm_kernel(desc, isa, src_type, weights_type,
                          ker, M, N, K, LDA, LDB, LDC, beta, max_bs, A_size,
                          B_size, C_size);
              };

    const auto &diff_src_conf = rnn.diff_src_brgemm;
    const int n_diff_src = nstl::min(diff_src_conf.N, diff_src_conf.n_block);
    const int n_diff_src_iter_tail
            = nstl::min(diff_src_conf.N_iter, diff_src_conf.n_iter_tail);
    const int n_diff_src_layer_tail
            = nstl::min(diff_src_conf.N_layer, diff_src_conf.n_layer_tail);
    // NOTE(review): diff_src_conf.K_blocks already includes the n_gates
    // factor (see configure_brgemm), so this looks like it over-counts the
    // max batch size by n_gates - presumably harmless for descriptor init,
    // but worth confirming.
    const auto K_batch_size = rnn.n_gates * diff_src_conf.K_blocks;
    // gates_block != n_gates means the gates are computed in separate passes,
    // requiring beta=1 variants that accumulate into C.
    const auto split_gates_computation
            = diff_src_conf.gates_block != rnn.n_gates;
    // Main kernel over the full N block, accumulating over K for all gates.
    init_brgemm_diff_src(&diff_src.desc_iter_layer_beta0_, diff_src_conf.isa,
            diff_src.kernel_iter_layer_beta0_, diff_src_conf.m_block,
            n_diff_src, diff_src_conf.k_block, diff_src_conf.LDA,
            diff_src_conf.LDB, diff_src_conf.LDC, 0.0, K_batch_size);
    if (split_gates_computation)
        init_brgemm_diff_src(&diff_src.desc_iter_layer_beta1_,
                diff_src_conf.isa, diff_src.kernel_iter_layer_beta1_,
                diff_src_conf.m_block, n_diff_src, diff_src_conf.k_block,
                diff_src_conf.LDA, diff_src_conf.LDB, diff_src_conf.LDC, 1.0,
                K_batch_size);

    // N-tail kernel for the layer part (slc not divisible by n_block).
    if (n_diff_src_layer_tail) {
        init_brgemm_diff_src(&diff_src.desc_layer_N_tail_beta0_,
                diff_src_conf.isa, diff_src.kernel_layer_N_tail_beta0_,
                diff_src_conf.m_block, n_diff_src_layer_tail,
                diff_src_conf.k_block, diff_src_conf.LDA, diff_src_conf.LDB,
                diff_src_conf.LDC, 0.0, K_batch_size);
        if (split_gates_computation)
            init_brgemm_diff_src(&diff_src.desc_layer_N_tail_beta1_,
                    diff_src_conf.isa, diff_src.kernel_layer_N_tail_beta1_,
                    diff_src_conf.m_block, n_diff_src_layer_tail,
                    diff_src_conf.k_block, diff_src_conf.LDA, diff_src_conf.LDB,
                    diff_src_conf.LDC, 1.0, K_batch_size);
    }

    // N-tail kernel for the iter part (sic not divisible by n_block).
    if (n_diff_src_iter_tail) {
        init_brgemm_diff_src(&diff_src.desc_iter_N_tail_beta0_,
                diff_src_conf.isa, diff_src.kernel_iter_N_tail_beta0_,
                diff_src_conf.m_block, n_diff_src_iter_tail,
                diff_src_conf.k_block, diff_src_conf.LDA, diff_src_conf.LDB,
                diff_src_conf.LDC, 0.0, K_batch_size);
        if (split_gates_computation)
            init_brgemm_diff_src(&diff_src.desc_iter_N_tail_beta1_,
                    diff_src_conf.isa, diff_src.kernel_iter_N_tail_beta1_,
                    diff_src_conf.m_block, n_diff_src_iter_tail,
                    diff_src_conf.k_block, diff_src_conf.LDA, diff_src_conf.LDB,
                    diff_src_conf.LDC, 1.0, K_batch_size);
    }

    // K-tail kernels accumulate (beta=1) on top of the main kernels' result;
    // their batch is one K-tail chunk per gate.
    if (diff_src_conf.k_tail) {
        init_brgemm_diff_src(&diff_src.desc_iter_layer_K_tail_beta1_,
                diff_src_conf.isa, diff_src.kernel_iter_layer_K_tail_beta1_,
                diff_src_conf.m_block, n_diff_src, diff_src_conf.k_tail,
                diff_src_conf.LDA, diff_src_conf.LDB, diff_src_conf.LDC, 1.0,
                rnn.n_gates);

        if (n_diff_src_layer_tail) {
            init_brgemm_diff_src(&diff_src.desc_layer_NK_tail_beta1_,
                    diff_src_conf.isa, diff_src.kernel_layer_NK_tail_beta1_,
                    diff_src_conf.m_block, n_diff_src_layer_tail,
                    diff_src_conf.k_tail, diff_src_conf.LDA, diff_src_conf.LDB,
                    diff_src_conf.LDC, 1.0, rnn.n_gates);
        }

        if (n_diff_src_iter_tail) {
            init_brgemm_diff_src(&diff_src.desc_iter_NK_tail_beta1_,
                    diff_src_conf.isa, diff_src.kernel_iter_NK_tail_beta1_,
                    diff_src_conf.m_block, n_diff_src_iter_tail,
                    diff_src_conf.k_tail, diff_src_conf.LDA, diff_src_conf.LDB,
                    diff_src_conf.LDC, 1.0, rnn.n_gates);
        }
    }

    const bool is_bf16_amx = rnn.is_cell_dt_bf16()
            && diff_src_conf.isa == x64::avx512_core_amx;

    // AMX needs a tile palette per descriptor shape; initialize one for each
    // descriptor created above (beta does not affect the tile config, so the
    // beta=0 descriptors are used).
    if (is_bf16_amx) {
        CHECK(brgemm_init_tiles(diff_src.desc_iter_layer_beta0_,
                diff_src.pallete_buff_iter_layer_));

        if (n_diff_src_layer_tail)
            CHECK(brgemm_init_tiles(diff_src.desc_layer_N_tail_beta0_,
                    diff_src.pallete_buff_layer_n_tail_));

        if (n_diff_src_iter_tail)
            CHECK(brgemm_init_tiles(diff_src.desc_iter_N_tail_beta0_,
                    diff_src.pallete_buff_iter_n_tail_));

        if (diff_src_conf.k_tail) {
            CHECK(brgemm_init_tiles(diff_src.desc_iter_layer_K_tail_beta1_,
                    diff_src.pallete_buff_iter_layer_k_tail_));

            if (n_diff_src_layer_tail)
                CHECK(brgemm_init_tiles(diff_src.desc_layer_NK_tail_beta1_,
                        diff_src.pallete_buff_layer_nk_tail_));

            if (n_diff_src_iter_tail)
                CHECK(brgemm_init_tiles(diff_src.desc_iter_NK_tail_beta1_,
                        diff_src.pallete_buff_iter_nk_tail_));
        }
    }

    return status::success;
}
1192
// Creates all brgemm descriptors/kernels for the diff_weights computation:
// separate iter/layer kernels (their M may differ when slc != sic), N-tail
// and K-tail variants, AMX tile palettes, and the jit copy_B kernel used to
// reorder scratch gates into the blocked layout the GEMM expects.
static status_t init_kernels_diff_wei(rnn_diff_wei_brgemm_t &diff_wei,
        const cpu::rnn_utils::rnn_conf_t &rnn, data_type_t src_type,
        data_type_t weights_type) {

    // Thin wrapper over init_brgemm_kernel supplying the full diff_wei
    // A/B/C matrix sizes from the configured conf.
    const auto init_brgemm_diff_wei
            = [&](x64::brgemm_t *desc, x64::cpu_isa_t isa,
                      std::unique_ptr<x64::brgemm_kernel_t> &ker, dim_t M,
                      dim_t N, dim_t K, dim_t LDA, dim_t LDB, dim_t LDC,
                      float beta, dim_t max_bs) {
                  const dim_t A_size
                          = rnn.diff_wei_brgemm.M * rnn.diff_wei_brgemm.Kpadded;
                  const dim_t B_size
                          = rnn.diff_wei_brgemm.Kpadded * rnn.diff_wei_brgemm.N;
                  const dim_t C_size
                          = rnn.diff_wei_brgemm.M * rnn.diff_wei_brgemm.N;
                  return init_brgemm_kernel(desc, isa, src_type, weights_type,
                          ker, M, N, K, LDA, LDB, LDC, beta, max_bs, A_size,
                          B_size, C_size);
              };

    const auto &diff_wei_conf = rnn.diff_wei_brgemm;
    // When slc != sic the conf uses a single M block covering all of M (see
    // configure_brgemm), so the iter/layer kernels get their exact M sizes.
    const bool is_m_block_equal = rnn.slc == rnn.sic;
    const auto m_block_iter
            = is_m_block_equal ? diff_wei_conf.m_block : diff_wei_conf.M_iter;
    const auto m_block_layer
            = is_m_block_equal ? diff_wei_conf.m_block : diff_wei_conf.M_layer;
    const auto n_diff_wei = nstl::min(diff_wei_conf.N, diff_wei_conf.n_block);
    const auto n_diff_wei_tail
            = nstl::min(diff_wei_conf.N, diff_wei_conf.n_tail);

    // Main kernels: beta=1 since diff weights accumulate across iterations;
    // the batch covers all K blocks.
    init_brgemm_diff_wei(&diff_wei.desc_iter_beta1_, diff_wei_conf.isa,
            diff_wei.kernel_iter_beta1_, m_block_iter, n_diff_wei,
            diff_wei_conf.k_block, diff_wei_conf.LDA_iter, diff_wei_conf.LDB,
            diff_wei_conf.LDC_iter, 1.0, diff_wei_conf.K_blocks);
    init_brgemm_diff_wei(&diff_wei.desc_layer_beta1_, diff_wei_conf.isa,
            diff_wei.kernel_layer_beta1_, m_block_layer, n_diff_wei,
            diff_wei_conf.k_block, diff_wei_conf.LDA_layer, diff_wei_conf.LDB,
            diff_wei_conf.LDC_layer, 1.0, diff_wei_conf.K_blocks);

    // N-tail variants (N not divisible by n_block).
    if (n_diff_wei_tail) {
        init_brgemm_diff_wei(&diff_wei.desc_iter_N_tail_beta1_,
                diff_wei_conf.isa, diff_wei.kernel_iter_N_tail_beta1_,
                m_block_iter, n_diff_wei_tail, diff_wei_conf.k_block,
                diff_wei_conf.LDA_iter, diff_wei_conf.LDB,
                diff_wei_conf.LDC_iter, 1.0, diff_wei_conf.K_blocks);
        init_brgemm_diff_wei(&diff_wei.desc_layer_N_tail_beta1_,
                diff_wei_conf.isa, diff_wei.kernel_layer_N_tail_beta1_,
                m_block_layer, n_diff_wei_tail, diff_wei_conf.k_block,
                diff_wei_conf.LDA_layer, diff_wei_conf.LDB,
                diff_wei_conf.LDC_layer, 1.0, diff_wei_conf.K_blocks);

        // Combined N-tail + K-tail variants: a single K-tail chunk (max_bs 1).
        if (diff_wei_conf.k_tail) {
            init_brgemm_diff_wei(&diff_wei.desc_iter_NK_tail_beta1_,
                    diff_wei_conf.isa, diff_wei.kernel_iter_NK_tail_beta1_,
                    m_block_iter, n_diff_wei_tail, diff_wei_conf.k_tail,
                    diff_wei_conf.LDA_iter, diff_wei_conf.LDB,
                    diff_wei_conf.LDC_iter, 1.0, 1);
            init_brgemm_diff_wei(&diff_wei.desc_layer_NK_tail_beta1_,
                    diff_wei_conf.isa, diff_wei.kernel_layer_NK_tail_beta1_,
                    m_block_layer, n_diff_wei_tail, diff_wei_conf.k_tail,
                    diff_wei_conf.LDA_layer, diff_wei_conf.LDB,
                    diff_wei_conf.LDC_layer, 1.0, 1);
        }
    }

    // K-tail variants over the full N block (max_bs 1).
    if (diff_wei_conf.k_tail) {
        init_brgemm_diff_wei(&diff_wei.desc_iter_K_tail_beta1_,
                diff_wei_conf.isa, diff_wei.kernel_iter_K_tail_beta1_,
                m_block_iter, n_diff_wei, diff_wei_conf.k_tail,
                diff_wei_conf.LDA_iter, diff_wei_conf.LDB,
                diff_wei_conf.LDC_iter, 1.0, 1);
        init_brgemm_diff_wei(&diff_wei.desc_layer_K_tail_beta1_,
                diff_wei_conf.isa, diff_wei.kernel_layer_K_tail_beta1_,
                m_block_layer, n_diff_wei, diff_wei_conf.k_tail,
                diff_wei_conf.LDA_layer, diff_wei_conf.LDB,
                diff_wei_conf.LDC_layer, 1.0, 1);
    }

    const bool is_bf16_amx_wei = rnn.is_cell_dt_bf16()
            && diff_wei_conf.isa == x64::avx512_core_amx;

    // AMX: one tile palette per distinct descriptor shape created above.
    if (is_bf16_amx_wei) {
        CHECK(brgemm_init_tiles(
                diff_wei.desc_iter_beta1_, diff_wei.pallete_buff_iter_));
        CHECK(brgemm_init_tiles(
                diff_wei.desc_layer_beta1_, diff_wei.pallete_buff_layer_));
        if (n_diff_wei_tail) {
            CHECK(brgemm_init_tiles(diff_wei.desc_iter_N_tail_beta1_,
                    diff_wei.pallete_buff_iter_n_tail_));
            CHECK(brgemm_init_tiles(diff_wei.desc_layer_N_tail_beta1_,
                    diff_wei.pallete_buff_layer_n_tail_));

            if (diff_wei_conf.k_tail) {
                CHECK(brgemm_init_tiles(diff_wei.desc_iter_NK_tail_beta1_,
                        diff_wei.pallete_buff_iter_nk_tail_));
                CHECK(brgemm_init_tiles(diff_wei.desc_layer_NK_tail_beta1_,
                        diff_wei.pallete_buff_layer_nk_tail_));
            }
        }

        if (diff_wei_conf.k_tail) {
            CHECK(brgemm_init_tiles(diff_wei.desc_iter_K_tail_beta1_,
                    diff_wei.pallete_buff_iter_k_tail_));
            CHECK(brgemm_init_tiles(diff_wei.desc_layer_K_tail_beta1_,
                    diff_wei.pallete_buff_layer_k_tail_));
        }
    }

    // Creating temporary matmul configuration descriptor to use copy_B jit
    // kernels from brgemm matmul copy routines for reordering scratch gates in
    // diff_wei rnn brgemm implementation.
    // TODO: provide unification of jit-based copy routines with implementation
    // independent interface
    matmul::brgemm_matmul_conf_t tmp_matmul_conf_for_reorder;
    tmp_matmul_conf_for_reorder.wei_tag = format_tag::ab;
    // The reorder treats the scratch gates as a (mb x scratch_gates_ld) B
    // matrix blocked by diff_wei's n_block.
    tmp_matmul_conf_for_reorder.N = rnn.scratch_gates_ld;
    tmp_matmul_conf_for_reorder.K = rnn.mb;
    tmp_matmul_conf_for_reorder.wei_n_blk = tmp_matmul_conf_for_reorder.N_blk
            = diff_wei_conf.n_block;
    tmp_matmul_conf_for_reorder.N_tail = diff_wei_conf.n_tail;
    tmp_matmul_conf_for_reorder.LDB = diff_wei_conf.LDB;
    tmp_matmul_conf_for_reorder.src_dt = tmp_matmul_conf_for_reorder.wei_dt
            = rnn.is_cell_dt_bf16() ? data_type::bf16 : data_type::f32;
    tmp_matmul_conf_for_reorder.a_dt_sz = tmp_matmul_conf_for_reorder.tr_a_dt_sz
            = types::data_type_size(tmp_matmul_conf_for_reorder.src_dt);
    tmp_matmul_conf_for_reorder.b_dt_sz = tmp_matmul_conf_for_reorder.tr_b_dt_sz
            = types::data_type_size(tmp_matmul_conf_for_reorder.wei_dt);
    // Note: the member name's "srcatch" spelling is historical.
    CHECK(matmul::create_brgemm_matmul_copy_b(
            diff_wei.srcatch_gates_reorder_kernel_,
            &tmp_matmul_conf_for_reorder));

    return status::success;
}
1326
1327status_t rnn_brgemm_t<prop_kind::backward>::init_kernels(
1328 const cpu::rnn_utils::rnn_conf_t &rnn, data_type_t src_type,
1329 data_type_t weights_type) {
1330
1331 init_kernels_diff_src(diff_src_, rnn, src_type, weights_type);
1332 init_kernels_diff_wei(diff_wei_, rnn, src_type, weights_type);
1333 if (rnn.is_lstm_peephole) CHECK(init_peephole_kernels(rnn));
1334
1335 const auto n_diff_wei_tail
1336 = nstl::min(rnn.diff_wei_brgemm.N, rnn.diff_wei_brgemm.n_tail);
1337 kernel_gates_reduction_
1338 = utils::make_unique<jit_gates_reduction_t>(rnn, false /*n_tail*/);
1339 kernel_gates_reduction_->create_kernel();
1340
1341 if (n_diff_wei_tail) {
1342 kernel_gates_reduction_tail_
1343 = utils::make_unique<jit_gates_reduction_t>(
1344 rnn, true /*n_tail*/);
1345 kernel_gates_reduction_tail_->create_kernel();
1346 }
1347
1348 if (rnn.mb == 1) {
1349 if (src_type == data_type::bf16) {
1350 const bool is_m_block_equal = rnn.slc == rnn.sic;
1351 const auto m_block_iter = is_m_block_equal
1352 ? rnn.diff_wei_brgemm.m_block
1353 : rnn.diff_wei_brgemm.M_iter;
1354
1355 kernel_transpose_single_row_iter_
1356 = utils::make_unique<jit_brgemm_transpose_single_row_t>(
1357 m_block_iter);
1358 CHECK(kernel_transpose_single_row_iter_->create_kernel());
1359
1360 if (!is_m_block_equal) {
1361 const auto m_block_layer = is_m_block_equal
1362 ? rnn.diff_wei_brgemm.m_block
1363 : rnn.diff_wei_brgemm.M_layer;
1364 kernel_transpose_single_row_layer_
1365 = utils::make_unique<jit_brgemm_transpose_single_row_t>(
1366 m_block_layer);
1367 CHECK(kernel_transpose_single_row_layer_->create_kernel());
1368 }
1369 }
1370 } else {
1371 jit_brgemm_primitive_conf_t trans_conf;
1372 trans_conf.prop_kind = dnnl_backward_weights;
1373 trans_conf.src_dt = src_type;
1374 static constexpr int blk_size = 16;
1375 trans_conf.os_block = blk_size; // src's rows block size
1376 trans_conf.ic_block = blk_size; // src's cols block size
1377 trans_conf.M = 0;
1378 const auto rnd_up_size = (src_type == data_type::bf16 ? 2 : 1);
1379 trans_conf.LDA
1380 = utils::rnd_up(rnn.mb, rnd_up_size); // dst's leading dim
1381 trans_conf.K_tail = rnn.mb % blk_size; // src's rows tail
1382
1383 const int LDA_iter[]
1384 = {rnn.src_iter_ld_, rnn.dst_layer_ld_, rnn.ws_states_iter_ld};
1385 trans_conf.M_tail = rnn.sic % blk_size; // src's cols tail
1386 for (int i = 0; i < num_base_kernels_; i++) {
1387 trans_conf.ic = LDA_iter[i];
1388 CHECK(create_brgemm_trans_src(
1389 kernel_transpose_iter_[i], &trans_conf));
1390 }
1391
1392 const int LDA_layer[]
1393 = {rnn.src_layer_ld_, rnn.dst_iter_ld_, rnn.ws_states_layer_ld};
1394 trans_conf.M_tail = rnn.slc % blk_size; // src's cols tail
1395 for (int i = 0; i < num_base_kernels_; i++) {
1396 trans_conf.ic = LDA_layer[i];
1397 CHECK(create_brgemm_trans_src(
1398 kernel_transpose_layer_[i], &trans_conf));
1399 }
1400 }
1401
1402 return status::success;
1403}
1404
1405status_t rnn_brgemm_t<prop_kind::backward>::init_peephole_kernels(
1406 const cpu::rnn_utils::rnn_conf_t &rnn) {
1407
1408 if (rnn.dhc_blocks_peephole) {
1409 kernel_peephole_ = utils::make_unique<jit_diff_weights_peephole_t>(
1410 rnn, rnn.dhc_block_peephole);
1411 CHECK(kernel_peephole_->create_kernel());
1412 }
1413
1414 if (rnn.dhc_tail_peephole) {
1415 kernel_peephole_tail_ = utils::make_unique<jit_diff_weights_peephole_t>(
1416 rnn, rnn.dhc_tail_peephole);
1417 CHECK(kernel_peephole_tail_->create_kernel());
1418 }
1419
1420 return status::success;
1421}
1422
1423} // namespace rnn_brgemm_utils
1424} // namespace x64
1425} // namespace cpu
1426} // namespace impl
1427} // namespace dnnl
1428