1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
#include <algorithm>
#include <climits>
#include <cmath>
#include <tuple>
#include <utility>
18 | #include "common/dnnl_thread.hpp" |
19 | #include "cpu/rnn/rnn_utils.hpp" |
20 | #include "cpu/x64/rnn/rnn_brgemm_utils.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | namespace x64 { |
26 | namespace rnn_brgemm_utils { |
27 | |
28 | namespace { |
29 | |
30 | x64::cpu_isa_t brgemm_calc_isa(dim_t K1, dim_t K2, bool is_int8, bool is_bf16); |
31 | std::pair<dim_t, dim_t> brgemm_calc_k_block(dim_t K1, dim_t K2, dim_t M, |
32 | dim_t n_block, alg_kind_t cell_kind, dim_t src_layer_type_size, |
33 | dim_t As, dim_t Bs, dim_t Cs, dim_t l2_cache_size, x64::cpu_isa_t isa, |
34 | bool is_int8, bool is_bf16); |
35 | std::pair<dim_t, dim_t> brgemm_calc_k_block_amx( |
36 | dim_t K1, dim_t K2, bool is_int8); |
37 | std::pair<dim_t, dim_t> brgemm_calc_k_block_vanilla_rnn(dim_t K1, dim_t K2, |
38 | dim_t M, dim_t n_block, dim_t src_layer_type_size, dim_t As, dim_t Bs, |
39 | dim_t Cs, dim_t l2_cache_size, bool is_bf16); |
40 | |
41 | dim_t brgemm_calc_m_block(alg_kind_t cell_kind, prop_kind_t aprop, dim_t nthr, |
42 | dim_t M, dim_t N_blocks, bool is_f32, bool is_int8_amx, |
43 | bool is_bf16_amx, float work_by_N, dim_t As, dim_t Bs, dim_t Cs, |
44 | dim_t l2_cache_size); |
45 | dim_t brgemm_calc_m_block_vanilla_rnn(dim_t nthr, dim_t M, dim_t N_blocks, |
46 | bool is_int8_amx, bool is_bf16_amx, float work_by_N, dim_t As, dim_t Bs, |
47 | dim_t Cs, dim_t l2_cache_size); |
48 | dim_t brgemm_calc_m_block_lstm(dim_t nthr, dim_t M, dim_t N_blocks, bool is_f32, |
49 | bool is_int8_amx, bool is_bf16_amx, float work_by_N, dim_t As, dim_t Cs, |
50 | dim_t l2_cache_size); |
51 | dim_t adjust_m_block_lstm(dim_t nthr, dim_t M, dim_t N_blocks, bool is_int8_amx, |
52 | bool is_bf16_amx); |
53 | |
54 | x64::cpu_isa_t brgemm_calc_isa(dim_t K1, dim_t K2, bool is_int8, bool is_bf16) { |
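    // Prefer AMX whenever the resulting K blocking is compatible with the
    // tile constraints: both K blocks and their tails must be multiples of
    // the data type packing (4 for int8, 2 for bf16). Otherwise fall back
    // to the VNNI / bf16 vector kernels.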
55 | const bool is_amx_int8 = is_int8 && x64::mayiuse(x64::avx512_core_amx); |
56 | const bool is_amx_bf16 = is_bf16 && x64::mayiuse(x64::avx512_core_amx); |
57 | |
58 | if (is_amx_int8 || is_amx_bf16) { |
59 | const dim_t padding = (is_int8 ? 4 : (is_bf16 ? 2 : 1)); |
60 | const auto result = brgemm_calc_k_block_amx(K1, K2, is_int8); |
61 | const auto k1_block_amx = result.first; |
62 | const auto k2_block_amx = result.second; |
63 | const auto k1_block_tail = K1 % k1_block_amx; |
64 | const auto k2_block_tail = K2 % k2_block_amx; |
65 | const bool amx_block_invalid = k1_block_tail % padding |
66 | || k2_block_tail % padding || k1_block_amx % padding |
67 | || k2_block_amx % padding; |
68 | |
69 | if (!amx_block_invalid) return x64::avx512_core_amx; |
70 | } |
71 | |
72 | if (is_int8) { |
73 | return x64::avx512_core_vnni; |
74 | } else if (is_bf16) { |
75 | return x64::avx512_core_bf16; |
76 | } |
77 | |
78 | return x64::isa_undef; |
79 | } |
80 | |
81 | std::pair<dim_t, dim_t> brgemm_calc_k_block(dim_t K1, dim_t K2, dim_t M, |
82 | dim_t n_block, alg_kind_t cell_kind, dim_t src_layer_type_size, |
83 | dim_t As, dim_t Bs, dim_t Cs, dim_t l2_cache_size, x64::cpu_isa_t isa, |
84 | bool is_int8, bool is_bf16) { |
85 | const bool is_amx_int8 = is_int8 && isa == x64::avx512_core_amx; |
86 | const bool is_amx_bf16 = is_bf16 && isa == x64::avx512_core_amx; |
87 | |
88 | if (is_amx_int8 || is_amx_bf16) |
89 | return brgemm_calc_k_block_amx(K1, K2, is_int8); |
90 | else if (cell_kind == alg_kind::vanilla_rnn) |
91 | return brgemm_calc_k_block_vanilla_rnn(K1, K2, M, n_block, |
92 | src_layer_type_size, As, Bs, Cs, l2_cache_size, is_bf16); |
93 | |
94 | return std::make_pair(K1, K2); |
95 | } |
96 | |
97 | std::pair<dim_t, dim_t> brgemm_calc_k_block_amx( |
98 | dim_t K1, dim_t K2, bool is_int8) { |
99 | const bool is_amx_int8 = is_int8 && x64::mayiuse(x64::avx512_core_amx); |
100 | const dim_t max_row_width = is_amx_int8 ? 64 : 32; |
101 | |
102 | dim_t k1_block = nstl::min(K1, max_row_width); |
103 | dim_t k2_block = nstl::min(K2, max_row_width); |
104 | |
    // Unify the blocks when either K dimension actually needs blocking so
    // both gemms can share a single K block.
    if (k1_block < K1 || k2_block < K2) {
        const dim_t t_k_block = nstl::min(k1_block, k2_block);
        k2_block = k1_block = t_k_block;
    }
109 | |
110 | return std::make_pair(k1_block, k2_block); |
111 | } |
112 | |
113 | std::pair<dim_t, dim_t> brgemm_calc_k_block_vanilla_rnn(dim_t K1, dim_t K2, |
114 | dim_t M, dim_t n_block, dim_t src_layer_type_size, dim_t As, dim_t Bs, |
115 | dim_t Cs, dim_t l2_cache_size, bool is_bf16) { |
116 | |
    // Heuristics selected experimentally.
118 | const bool should_adjust_by_l2 = static_cast<float>(As + Bs + Cs) |
119 | >= 0.25 * static_cast<float>(l2_cache_size); |
120 | dim_t k1_block = K1; |
121 | dim_t k2_block = K2; |
122 | |
123 | if (should_adjust_by_l2) { |
124 | int block_size = (l2_cache_size * 0.25f) |
125 | / ((M + n_block) * src_layer_type_size); |
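        // Illustrative example (assumed values, not from the code): with a
        // 1 MiB per-core L2, bf16 sources (src_layer_type_size = 2), M = 64
        // and n_block = 32, block_size = 262144 / 192 = 1365; the bf16
        // evenness fixup below then lowers it to 1364.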
126 | |
127 | if (is_bf16) { |
            // Due to the ldgOI32o2i weights format, block_size must be even.
129 | block_size -= (block_size % 2); |
130 | block_size = nstl::max(block_size, 0); |
131 | } |
132 | if (block_size) { |
133 | k1_block = nstl::min(K1, static_cast<dim_t>(block_size)); |
134 | k2_block = nstl::min(K2, static_cast<dim_t>(block_size)); |
135 | } |
136 | } |
137 | |
138 | return std::make_pair(k1_block, k2_block); |
139 | } |
140 | |
141 | dim_t brgemm_calc_m_block(alg_kind_t cell_kind, prop_kind_t aprop, dim_t nthr, |
142 | dim_t M, dim_t N_blocks, bool is_f32, bool is_int8_amx, |
143 | bool is_bf16_amx, float work_by_N, dim_t As, dim_t Bs, dim_t Cs, |
144 | dim_t l2_cache_size) { |
145 | if (cell_kind == alg_kind::vanilla_rnn |
146 | || (cell_kind == alg_kind::vanilla_lstm |
147 | && aprop == prop_kind::backward)) |
148 | return brgemm_calc_m_block_vanilla_rnn(nthr, M, N_blocks, is_int8_amx, |
149 | is_bf16_amx, work_by_N, As, Bs, Cs, l2_cache_size); |
150 | else |
151 | return brgemm_calc_m_block_lstm(nthr, M, N_blocks, is_f32, is_int8_amx, |
152 | is_bf16_amx, work_by_N, As, Cs, l2_cache_size); |
153 | } |
154 | |
155 | dim_t brgemm_calc_m_block_vanilla_rnn(dim_t nthr, dim_t M, dim_t N_blocks, |
156 | bool is_int8_amx, bool is_bf16_amx, float work_by_N, dim_t As, dim_t Bs, |
157 | dim_t Cs, dim_t l2_cache_size) { |
158 | |
    // Heuristics selected experimentally.
    const float decimal_n_factor = work_by_N - std::floor(work_by_N);
    static constexpr float thread_balance_threshold = 0.9;
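    // The search below walks divisors of M from M / 2 down to 8, looking
    // for an m_block whose (M / m_block) * N_blocks work items split across
    // nthr threads with a fractional remainder that is either zero or above
    // the threshold, i.e. the last round of the parallel loop stays nearly
    // full.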
162 | |
163 | dim_t m_block = M; |
164 | |
165 | if (work_by_N < 1.0) |
166 | return adjust_m_block_lstm(nthr, M, N_blocks, is_int8_amx, is_bf16_amx); |
    else if (decimal_n_factor < thread_balance_threshold
            && decimal_n_factor != 0.0f) {
169 | |
170 | const dim_t m_block_start = M / 2; |
171 | const dim_t m_block_end = 8; |
172 | |
173 | float max_decimal_mn = 0.0; |
        dim_t best_candidate = 0;
175 | bool found_best_solution = false; |
176 | |
177 | for (dim_t m_block_it = m_block_start; m_block_it >= m_block_end; |
178 | m_block_it--) { |
179 | if (M % m_block_it == 0) { |
180 | const auto m_blocks = M / m_block_it; |
181 | const auto work_by_MN |
182 | = static_cast<float>(m_blocks * N_blocks) / nthr; |
183 | |
184 | const float work_by_MN_decimal |
185 | = work_by_MN - std::floor(work_by_MN); |
186 | |
187 | static constexpr float tolerance = 0.01; |
188 | if (work_by_MN_decimal > (max_decimal_mn + tolerance)) { |
189 | best_candidate = m_block_it; |
190 | max_decimal_mn = work_by_MN_decimal; |
191 | } |
192 | |
                if (work_by_MN_decimal >= thread_balance_threshold
194 | || work_by_MN_decimal == 0.0f) { |
195 | m_block = m_block_it; |
196 | found_best_solution = true; |
197 | break; |
198 | } |
199 | } |
200 | } |
201 | |
202 | if (!found_best_solution) { |
203 | if ((decimal_n_factor < max_decimal_mn) |
204 | || (static_cast<float>(As) |
205 | > (0.5f * static_cast<float>(l2_cache_size)))) { |
206 | m_block = best_candidate; |
207 | } |
208 | } |
209 | } |
210 | |
211 | return m_block; |
212 | } |
213 | |
214 | dim_t brgemm_calc_m_block_lstm(dim_t nthr, dim_t M, dim_t N_blocks, bool is_f32, |
215 | bool is_int8_amx, bool is_bf16_amx, float work_by_N, dim_t As, dim_t Cs, |
216 | dim_t l2_cache_size) { |
217 | const bool adj_by_l2 = is_f32 |
218 | ? true |
219 | : (static_cast<float>(As + Cs) |
220 | < 0.6 * static_cast<float>(l2_cache_size)); |
221 | |
222 | if (work_by_N > 2.0 || (work_by_N > 1.0 && adj_by_l2)) |
223 | return M; |
224 | else |
225 | return adjust_m_block_lstm(nthr, M, N_blocks, is_int8_amx, is_bf16_amx); |
226 | } |
227 | |
228 | dim_t adjust_m_block_lstm(dim_t nthr, dim_t M, dim_t N_blocks, bool is_int8_amx, |
229 | bool is_bf16_amx) { |
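    // Heuristic: pick the largest divisor of M in [min_M, max_M], where
    // max_M caps the block height (64 rows with AMX, 24 otherwise) and
    // max_m_blocks limits the number of blocks per thread. If no divisor
    // in range exists, fall back to a single block covering all of M.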
230 | |
231 | const bool is_amx = is_int8_amx || is_bf16_amx; |
232 | |
233 | const dim_t max_m_blocks = (is_amx ? 1 : 4) * utils::div_up(nthr, N_blocks); |
234 | const dim_t max_m_value = is_amx ? 64 : 24; |
235 | const dim_t max_M |
236 | = nstl::min(max_m_value, nstl::max((dim_t)1, M / max_m_blocks)); |
237 | const dim_t min_M = 4; |
238 | |
239 | dim_t m_block = 1; |
240 | for (dim_t m = max_M; m >= min_M; m--) |
241 | if (M % m == 0) { |
242 | m_block = m; |
243 | break; |
244 | } |
245 | if (m_block == 1) m_block = M; |
246 | |
247 | return m_block; |
248 | } |
249 | |
250 | x64::cpu_isa_t adjust_isa_by_m_block( |
251 | x64::cpu_isa_t current_isa, dim_t m_block, bool is_int8_amx) { |
252 | /* |
253 | * If we have m<4 TMUL and AVX512 vnni calculate the same number of |
254 | * operation per instruction but TMUL is 2x slower for int8 in terms of |
255 | * throughput. |
256 | */ |
257 | if (is_int8_amx && m_block < 4) { |
258 | if (x64::mayiuse(x64::avx512_core_amx)) return x64::avx512_core_amx; |
259 | } |
260 | |
261 | return current_isa; |
262 | } |
263 | |
264 | } // namespace |
265 | |
266 | void rnn_brgemm_base_t::init_scratchpad(const cpu::rnn_utils::rnn_conf_t &rnn, |
267 | memory_tracking::registrar_t &scratchpad, dim_t gemm_acc_type_size, |
268 | dim_t gemm_acc_align) { |
269 | |
270 | using namespace memory_tracking::names; |
271 | |
272 | if (rnn.is_cell_int8_amx() || rnn.is_cell_bf16_amx()) { |
273 | const auto m_block = rnn.merge_gemm_layer |
274 | ? nstl::max(rnn.m_block, rnn.mlayermerged_block) |
275 | : rnn.m_block; |
276 | size_t n_elements = m_block * rnn.n_block; |
277 | scratchpad.book(key_brgemm_primitive_buffer, rnn.nthr * n_elements, |
278 | gemm_acc_type_size, gemm_acc_align); |
279 | } |
280 | |
281 | if (rnn.is_bf32()) { |
282 | |
283 | const dims_t wei_layer_dims |
284 | = {rnn.n_layer, rnn.n_dir, rnn.n_gates, rnn.slc, rnn.dlc}; |
285 | const dims_t wei_iter_dims |
286 | = {rnn.n_layer, rnn.n_dir, rnn.n_gates, rnn.sic, rnn.dic}; |
287 | |
288 | memory_desc_t wei_layer_desc; |
289 | const auto tag = rnn.n_block == 64 ? format_tag::ldgOI64o2i |
290 | : format_tag::ldgOI32o2i; |
291 | memory_desc_init_by_tag( |
292 | wei_layer_desc, 5, wei_layer_dims, data_type::bf16, tag); |
293 | |
294 | memory_desc_t wei_iter_desc; |
295 | memory_desc_init_by_tag( |
296 | wei_iter_desc, 5, wei_iter_dims, data_type::bf16, tag); |
297 | |
298 | scratchpad.book(key_rnn_bf32_wei_layer_trans, |
299 | memory_desc_wrapper(wei_layer_desc).size(), 64); |
300 | |
301 | scratchpad.book(key_rnn_bf32_wei_iter_trans, |
302 | memory_desc_wrapper(wei_iter_desc).size(), 64); |
303 | |
304 | scratchpad.book(key_rnn_bf32_attention_trans, |
305 | rnn.n_iter * rnn.mb * sizeof(bfloat16_t), 64); |
306 | } |
307 | |
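    // Book brgemm batch elements for the largest K-block count used by any
    // gemm of the cell (+1 slot for a possible K tail), doubled when the
    // iter and layer gemms may be fused into a single batched call.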
308 | const int max_K_Block |
309 | = nstl::max(rnn.KB1_blocks + 1, |
310 | nstl::max(rnn.KBproj_blocks + 1, rnn.KB2_blocks + 1)) |
311 | * (rnn.brgemm_fwd_iter_layer_fuse_possible ? 2 : 1); |
312 | scratchpad.template book<x64::brgemm_batch_element_t>( |
313 | key_brgemm_primitive_batch, max_K_Block * rnn.nthr); |
314 | } |
315 | |
316 | status_t rnn_brgemm_t<prop_kind::forward>::configure_brgemm( |
317 | cpu::rnn_utils::rnn_conf_t &rnn, alg_kind_t cell_kind, |
318 | dim_t src_layer_type_size, dim_t scratch_type_size) { |
319 | using namespace cpu::rnn_utils; |
320 | |
321 | rnn.M = rnn.mb; |
322 | rnn.N = rnn.dhc; |
323 | rnn.K1 = rnn.slc; |
324 | rnn.K2 = rnn.sic; |
325 | const auto is_int8 = rnn.is_cell_dt_int8(); |
326 | const auto is_bf16 = rnn.is_cell_dt_bf16(); |
327 | |
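    // K padding reflects the VNNI/AMX packing granularity: int8 kernels
    // pack 4 K-elements per instruction, bf16 kernels pack 2.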
328 | const dim_t padding = (is_int8 ? 4 : (is_bf16 ? 2 : 1)); |
329 | rnn.K1padded = utils::rnd_up(rnn.K1, padding); |
330 | rnn.K2padded = utils::rnd_up(rnn.K2, padding); |
331 | |
332 | rnn.brgemm_isa = brgemm_calc_isa(rnn.K1, rnn.K2, is_int8, is_bf16); |
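    // Heuristic: bf32 pays an up-front conversion cost, so enable it only
    // when the reduction dimensions are large enough to amortize it.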
333 | const int bf32_reduction_dim_threshold = 128; |
334 | const bool is_shape_ok_for_bf32 = rnn.K1 >= bf32_reduction_dim_threshold |
335 | && rnn.K2 >= bf32_reduction_dim_threshold; |
336 | const bool is_bf32 = is_bf16 && rnn.brgemm_isa == avx512_core_amx |
337 | && rnn.dt_conf == all_f32 |
338 | // workspace data type and layouts can differ between fwd and bwd |
339 | // implementations during training. |
340 | && !rnn.is_training |
341 | // bf16 lstm_projection is not supported, so neither is bf32. |
342 | && !rnn.is_lstm_projection && is_shape_ok_for_bf32; |
343 | if (!IMPLICATION(rnn.is_cell_dt_bf16(), rnn.is_bf16_conf() || is_bf32)) |
344 | return status::unimplemented; |
345 | |
346 | rnn.nthr = dnnl_get_max_threads(); |
347 | const bool is_amx_isa_selected |
348 | = rnn.is_cell_int8_amx() || rnn.is_cell_bf16_amx(); |
349 | const bool can_use_block64 |
350 | = is_amx_isa_selected && rnn.N % 64 == 0 && !rnn.is_lstm_projection; |
351 | rnn.n_block = can_use_block64 ? 64 : 32; |
352 | rnn.N_blocks = utils::div_up(rnn.N, rnn.n_block); |
353 | rnn.n_tail = rnn.N % rnn.n_block; |
354 | |
355 | const float work_by_N |
356 | = static_cast<float>(rnn.N_blocks) / static_cast<float>(rnn.nthr); |
357 | |
358 | const dim_t l2_cache_size = platform::get_per_core_cache_size(2); |
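    // Rough footprints of the gemm operands used by the cache-blocking
    // heuristics below: As ~ the activations panel, Bs ~ one weights block,
    // Cs ~ the scratch gates panel (all gates plus one extra buffer).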
359 | const dim_t As = src_layer_type_size * rnn.M * (nstl::max(rnn.K1, rnn.K2)); |
360 | const dim_t Bs |
361 | = src_layer_type_size * (nstl::max(rnn.K1, rnn.K2)) * rnn.n_block; |
362 | const dim_t Cs |
363 | = scratch_type_size * (rnn.n_gates + 1) * (rnn.M * rnn.n_block); |
364 | |
365 | std::tie(rnn.k1_block, rnn.k2_block) = brgemm_calc_k_block(rnn.K1, rnn.K2, |
366 | rnn.M, rnn.n_block, cell_kind, src_layer_type_size, As, Bs, Cs, |
367 | l2_cache_size, rnn.brgemm_isa, is_int8, is_bf16); |
368 | rnn.KB1_blocks = rnn.K1 / rnn.k1_block; |
369 | rnn.k1_tail = rnn.K1 % rnn.k1_block; |
370 | rnn.KB2_blocks = rnn.K2 / rnn.k2_block; |
371 | rnn.k2_tail = rnn.K2 % rnn.k2_block; |
372 | rnn.m_block = brgemm_calc_m_block(cell_kind, prop_kind::forward, rnn.nthr, |
373 | rnn.M, rnn.N_blocks, rnn.is_cell_dt_f32(), rnn.is_cell_int8_amx(), |
374 | rnn.is_cell_bf16_amx(), work_by_N, As, Bs, Cs, l2_cache_size); |
375 | |
376 | rnn.M_blocks = rnn.M / rnn.m_block; |
377 | |
378 | rnn.brgemm_isa = adjust_isa_by_m_block( |
379 | rnn.brgemm_isa, rnn.m_block, rnn.is_cell_int8_amx()); |
    // Unfused post-gemm for the LSTM cell allows parallelizing across the
    // gates loop and reduces the brgemm problem size for a single iteration
    // of the parallel loop.
382 | rnn.unfused_post_gemm = cell_kind == alg_kind::vanilla_lstm |
383 | ? IMPLICATION(rnn.M_blocks > 1, rnn.is_cell_bf16_amx()) |
384 | : false; |
385 | |
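    // Leading dimensions of the A matrix for the layer (LDA1), iter (LDA2)
    // and GRU-part-2 (LDA2_2) gemms; each gemm has several possible A
    // sources depending on the cell position in the grid, hence one leading
    // dimension per source.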
386 | rnn.LDA1[0] = rnn.src_layer_ld_; |
387 | rnn.LDA1[1] = rnn.dst_iter_ld_; |
388 | rnn.LDA1[2] = rnn.ws_states_layer_ld; |
389 | |
390 | rnn.LDA2[0] = rnn.src_iter_ld_; |
391 | rnn.LDA2[1] = rnn.dst_layer_ld_; |
392 | rnn.LDA2[2] = rnn.ws_states_iter_ld; |
393 | |
394 | rnn.LDA2_2[0] = rnn.dst_layer_ld_; |
395 | rnn.LDA2_2[1] = rnn.dst_iter_ld_; |
396 | rnn.LDA2_2[2] = rnn.ws_states_layer_ld; |
397 | rnn.LDA2_2[3] = rnn.ws_states_iter_ld; |
398 | |
399 | rnn.LDB1 = rnn.n_block; |
400 | rnn.LDB2 = rnn.n_block; |
401 | rnn.LDC = rnn.scratch_gates_ld; |
402 | |
403 | auto get_dim = [&](dim_t block, dim_t tail) { |
404 | return (block == 0) ? tail : block; |
405 | }; |
406 | |
407 | dim_t n_block = nstl::min(rnn.N, rnn.n_block); |
    dim_t n_tail = nstl::min(rnn.N, rnn.n_tail);
409 | if (rnn.LDA1[0] < rnn.k1_block && rnn.LDA1[1] < rnn.k1_block |
410 | && rnn.LDA1[2] < rnn.k1_block) |
411 | return status::unimplemented; |
412 | if (rnn.LDA2[0] < rnn.k2_block && rnn.LDA2[1] < rnn.k2_block |
413 | && rnn.LDA2[2] < rnn.k2_block) |
414 | return status::unimplemented; |
415 | if (rnn.LDB1 < get_dim(n_block, n_tail) |
416 | && rnn.LDB2 < get_dim(n_block, n_tail)) |
417 | return status::unimplemented; |
418 | if (rnn.LDC < get_dim(n_block, n_tail)) return status::unimplemented; |
419 | |
420 | rnn.KBproj_blocks = 0; |
421 | rnn.kproj_tail = 0; |
422 | rnn.kproj_block = 0; |
423 | |
424 | if (rnn.is_lstm_projection) { |
425 | rnn.Nproj = rnn.dic; |
426 | rnn.Nproj_blocks = utils::div_up(rnn.Nproj, rnn.n_block); |
427 | rnn.nproj_tail = rnn.Nproj % rnn.n_block; |
428 | |
429 | rnn.Kproj = rnn.dhc; |
430 | rnn.Kprojpadded = utils::rnd_up(rnn.Kproj, padding); |
431 | if (rnn.is_cell_int8_amx() || rnn.is_cell_bf16_amx()) { |
432 | const dim_t max_row_width = rnn.is_cell_int8_amx() ? 64 : 32; |
433 | rnn.kproj_block = nstl::min(rnn.Kproj, (dim_t)max_row_width); |
434 | |
435 | rnn.KBproj_blocks = rnn.Kproj / rnn.kproj_block; |
436 | rnn.kproj_tail = rnn.Kproj % rnn.kproj_block; |
437 | |
438 | if ((rnn.kproj_tail % padding) || (rnn.kproj_block % padding)) { |
439 | rnn.kproj_block = rnn.Kproj; |
440 | rnn.kproj_tail = 0; |
441 | rnn.brgemm_isa = rnn.is_cell_dt_int8() ? x64::avx512_core_vnni |
442 | : x64::avx512_core_bf16; |
443 | } else { |
444 | rnn.brgemm_isa = x64::avx512_core_amx; |
445 | } |
446 | } else { |
447 | rnn.kproj_block = rnn.Kproj; |
448 | rnn.KBproj_blocks = rnn.Kproj / rnn.kproj_block; |
449 | } |
450 | rnn.LDAproj = rnn.proj_ht_ld; |
451 | rnn.LDBproj = rnn.n_block; |
452 | if (rnn.dt_conf != cpu::rnn_utils::all_f32) { |
453 | rnn.LDCproj[0] = rnn.scratch_gates_ld; |
454 | } else { |
455 | rnn.LDCproj[0] = rnn.scratch_ht_ld; |
456 | rnn.LDCproj[1] = rnn.dst_layer_ld_; |
457 | rnn.LDCproj[2] = rnn.dst_iter_ld_; |
458 | rnn.LDCproj[3] = rnn.ws_states_layer_ld; |
459 | } |
460 | |
461 | dim_t n_block = nstl::min(rnn.Nproj, rnn.n_block); |
462 | dim_t n_tail = nstl::min(rnn.Nproj, rnn.nproj_tail); |
463 | bool check_LDC = false; |
464 | if (rnn.dt_conf != cpu::rnn_utils::all_f32) { |
465 | check_LDC = rnn.LDCproj[0] < get_dim(n_block, n_tail); |
466 | } else { |
467 | check_LDC = rnn.LDCproj[0] < get_dim(n_block, n_tail) |
468 | && rnn.LDCproj[1] < get_dim(n_block, n_tail) |
469 | && rnn.LDCproj[2] < get_dim(n_block, n_tail) |
470 | && rnn.LDCproj[3] < get_dim(n_block, n_tail); |
471 | } |
472 | if (rnn.LDAproj < rnn.kproj_block |
473 | || rnn.LDBproj < get_dim(n_block, n_tail) || check_LDC) |
474 | return status::unimplemented; |
475 | } |
476 | |
    // Enable merging the layer part of the cell computation across the
    // n_iter dimension.
478 | // TODO: extend coverage for other problem types |
479 | const bool mlc_cell_type_ok = cell_kind == alg_kind::vanilla_lstm |
480 | && !rnn.is_lstm_projection && !rnn.is_lstm_peephole; |
481 | const int mlc_mb_max_threshold = 1; |
482 | const int mlc_n_iter_min_threshold = 2; |
483 | const int mlc_n_layer_max_threshold = 1; |
484 | const bool mlc_problem_shape_ok = rnn.mb <= mlc_mb_max_threshold |
485 | && rnn.n_iter >= mlc_n_iter_min_threshold |
486 | && rnn.n_layer <= mlc_n_layer_max_threshold; |
    // If rnn.skip_dst_iter_copy() == false, the number of merged cells may
    // need to be reduced by 1 (to rnn.n_iter - 1). If, in addition,
    // rnn.skip_src_layer_copy() == true, then at the 'first_layer' cell
    // position the computation should still be merged across all n_iter
    // cells. Current support is therefore limited to the case where the
    // layer computation is merged across all rnn.n_iter cells.
493 | const bool mlc_m_dim_adjustment_not_required |
494 | = IMPLICATION(rnn.skip_dst_iter_copy(), |
495 | rnn.skip_src_layer_copy() && rnn.n_layer == 1); |
496 | const bool merged_layer_compute_applicable = rnn.src_layer_is_trivial_stride |
497 | && mlc_cell_type_ok && mlc_problem_shape_ok |
498 | && mlc_m_dim_adjustment_not_required; |
499 | if (merged_layer_compute_applicable) { |
500 | rnn.merge_gemm_layer = true; |
501 | |
        // Would require adjustment if mlc_m_dim_adjustment_not_required
        // were false.
503 | const int n_iters_to_merge = rnn.n_iter; |
504 | rnn.Mlayermerged = rnn.mb * n_iters_to_merge; |
505 | rnn.mlayermerged_block = brgemm_calc_m_block(cell_kind, |
506 | prop_kind::forward, rnn.nthr, rnn.Mlayermerged, rnn.N_blocks, |
507 | rnn.is_cell_dt_f32(), rnn.is_cell_int8_amx(), |
508 | rnn.is_cell_bf16_amx(), work_by_N, As, Bs, Cs, l2_cache_size); |
509 | |
510 | rnn.Mlayermerged_blocks = rnn.Mlayermerged / rnn.mlayermerged_block; |
511 | } |
512 | |
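    // When slc == sic, the iter and layer gemms share shapes and can be
    // fused into one brgemm batch (see max_bs_factor in init_kernels and
    // the doubled max_K_Block in init_scratchpad).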
513 | rnn.brgemm_fwd_iter_layer_fuse_possible |
514 | = rnn.slc == rnn.sic && !rnn.merge_gemm_layer; |
515 | |
516 | if (!rnn.is_orig_gru) { |
517 | rnn.loop_order = rnn.is_cell_int8_amx() || rnn.is_cell_bf16_amx() |
518 | ? brgemm_rnn_execute_loop_order_t::mblk_nblk |
519 | : brgemm_rnn_execute_loop_order_t::nblk_mblk; |
520 | } |
521 | return status::success; |
522 | } |
523 | |
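// Helper that initializes a brgemm descriptor and creates the corresponding
// kernel. The hint_expected_*_size defaults of LLONG_MAX mean "no hint";
// callers that know their actual buffer sizes pass them so the generated
// kernel can tune itself to the real problem footprint.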
524 | status_t init_brgemm_kernel(x64::brgemm_t *desc, x64::cpu_isa_t isa, |
525 | impl::data_type_t src_type, impl::data_type_t weights_type, |
526 | std::unique_ptr<x64::brgemm_kernel_t> &ker, dim_t M, dim_t N, dim_t K, |
527 | dim_t LDA, dim_t LDB, dim_t LDC, float beta, dim_t max_bs, |
528 | dim_t hint_expected_A_size = LLONG_MAX, |
529 | dim_t hint_expected_B_size = LLONG_MAX, |
530 | dim_t hint_expected_C_size = LLONG_MAX) { |
531 | bool transA = false; |
532 | bool transB = false; |
533 | x64::brgemm_layout_t layout = x64::brgemm_row_major; |
534 | CHECK(brgemm_desc_init(desc, isa, x64::brgemm_addr, src_type, weights_type, |
535 | transA, transB, layout, 1.0, beta, LDA, LDB, LDC, M, N, K)); |
536 | |
537 | x64::brgemm_attr_t brgattr; |
538 | |
539 | brgattr.hint_expected_A_size = hint_expected_A_size; |
540 | brgattr.hint_expected_B_size = hint_expected_B_size; |
541 | brgattr.hint_expected_C_size = hint_expected_C_size; |
542 | brgattr.max_bs = max_bs; |
543 | brgattr.max_top_vpad = 0; |
544 | brgattr.max_bottom_vpad = 0; |
    CHECK(brgemm_desc_set_attr(desc, brgattr));
546 | |
547 | x64::brgemm_kernel_t *_t_ptr; |
548 | CHECK(brgemm_kernel_create(&_t_ptr, *desc)); |
549 | safe_ptr_assign<x64::brgemm_kernel_t>(ker, _t_ptr); |
550 | |
551 | return status::success; |
552 | }; |
553 | |
554 | status_t rnn_brgemm_t<prop_kind::forward>::brgemm_rnn_init_tiles( |
555 | brgemm_t *desc_array, dim_t size, brgemm_pallete_t pallete) { |
556 | |
557 | for (dim_t it = 0; it < size; ++it) { |
558 | const auto &desc = desc_array[it]; |
559 | const bool desc_empty |
560 | = utils::everyone_is(0, desc.LDA, desc.LDB, desc.LDC); |
561 | if (!desc_empty) return brgemm_init_tiles(desc, pallete); |
562 | } |
563 | |
564 | return status::unimplemented; |
565 | } |
566 | |
567 | status_t rnn_brgemm_t<prop_kind::forward>::brgemm_rnn_init_tiles( |
568 | brgemm_t *desc_array, brgemm_pallete_t pallete) { |
569 | return brgemm_rnn_init_tiles(desc_array, num_base_kernels_, pallete); |
570 | } |
571 | status_t rnn_brgemm_t<prop_kind::forward>::brgemm_rnn_init_tiles_proj( |
572 | brgemm_t *desc_array, brgemm_pallete_t pallete) { |
573 | return brgemm_rnn_init_tiles(desc_array, num_proj_kernels_, pallete); |
574 | } |
575 | |
576 | status_t rnn_brgemm_t<prop_kind::forward>::init_kernels( |
577 | const cpu::rnn_utils::rnn_conf_t &rnn, data_type_t src_type, |
578 | data_type_t weights_type) { |
579 | |
580 | const auto init_brgemm |
581 | = [&](x64::brgemm_t *desc, x64::cpu_isa_t isa, |
582 | std::unique_ptr<x64::brgemm_kernel_t> &ker, dim_t M, |
583 | dim_t N, dim_t K, dim_t LDA, dim_t LDB, dim_t LDC, |
584 | float beta, dim_t max_bs) { |
585 | return init_brgemm_kernel(desc, isa, src_type, weights_type, |
586 | ker, M, N, K, LDA, LDB, LDC, beta, max_bs); |
587 | }; |
588 | |
589 | const int brgemm_n = nstl::min(rnn.N, rnn.n_block); |
590 | const int brgemm_n_tail = nstl::min(rnn.N, rnn.n_tail); |
591 | const int max_bs_factor = rnn.brgemm_fwd_iter_layer_fuse_possible ? 2 : 1; |
592 | |
593 | for (int i = 0; i < num_base_kernels_; i++) { |
594 | if (rnn.merge_gemm_layer) { |
595 | init_brgemm(&desc_layermerged_b0_[i], rnn.brgemm_isa, |
596 | kernel_layermerged_b0_[i], rnn.mlayermerged_block, brgemm_n, |
597 | rnn.k1_block, rnn.LDA1[i], rnn.LDB1, rnn.LDC, 0.0, |
598 | rnn.KB1_blocks); |
599 | } else { |
600 | init_brgemm(&desc_layer_b0_[i], rnn.brgemm_isa, kernel_layer_b0_[i], |
601 | rnn.m_block, brgemm_n, rnn.k1_block, rnn.LDA1[i], rnn.LDB1, |
602 | rnn.LDC, 0.0, max_bs_factor * rnn.KB1_blocks); |
603 | } |
604 | |
605 | init_brgemm(&desc_iter_b0_[i], rnn.brgemm_isa, kernel_iter_b0_[i], |
606 | rnn.m_block, brgemm_n, rnn.k2_block, rnn.LDA2[i], rnn.LDB2, |
607 | rnn.LDC, 0.0, rnn.KB2_blocks); |
608 | init_brgemm(&desc_iter_b1_[i], rnn.brgemm_isa, kernel_iter_b1_[i], |
609 | rnn.m_block, brgemm_n, rnn.k2_block, rnn.LDA2[i], rnn.LDB2, |
610 | rnn.LDC, 1.0, rnn.KB2_blocks); |
611 | if (rnn.n_tail) { |
612 | if (rnn.merge_gemm_layer) { |
613 | init_brgemm(&desc_layermerged_N_tail_b0_[i], rnn.brgemm_isa, |
614 | kernel_layermerged_N_tail_b0_[i], |
615 | rnn.mlayermerged_block, brgemm_n_tail, rnn.k1_block, |
616 | rnn.LDA1[i], rnn.LDB1, rnn.LDC, 0.0, rnn.KB1_blocks); |
617 | } else { |
618 | init_brgemm(&desc_layer_N_tail_b0_[i], rnn.brgemm_isa, |
619 | kernel_layer_N_tail_b0_[i], rnn.m_block, brgemm_n_tail, |
620 | rnn.k1_block, rnn.LDA1[i], rnn.LDB1, rnn.LDC, 0.0, |
621 | max_bs_factor * rnn.KB1_blocks); |
622 | } |
623 | |
624 | init_brgemm(&desc_iter_N_tail_b0_[i], rnn.brgemm_isa, |
625 | kernel_iter_N_tail_b0_[i], rnn.m_block, brgemm_n_tail, |
626 | rnn.k2_block, rnn.LDA2[i], rnn.LDB2, rnn.LDC, 0.0, |
627 | rnn.KB2_blocks); |
628 | init_brgemm(&desc_iter_N_tail_b1_[i], rnn.brgemm_isa, |
629 | kernel_iter_N_tail_b1_[i], rnn.m_block, brgemm_n_tail, |
630 | rnn.k2_block, rnn.LDA2[i], rnn.LDB2, rnn.LDC, 1.0, |
631 | rnn.KB2_blocks); |
632 | } |
633 | if (rnn.k1_tail) { |
634 | if (rnn.merge_gemm_layer) { |
635 | init_brgemm(&desc_layermerged_K1_tail_b1_[i], rnn.brgemm_isa, |
636 | kernel_layermerged_K1_tail_b1_[i], |
637 | rnn.mlayermerged_block, brgemm_n, rnn.k1_tail, |
638 | rnn.LDA1[i], rnn.LDB1, rnn.LDC, 1.0, 1); |
639 | } else { |
640 | init_brgemm(&desc_layer_K1_tail_b1_[i], rnn.brgemm_isa, |
641 | kernel_layer_K1_tail_b1_[i], rnn.m_block, brgemm_n, |
642 | rnn.k1_tail, rnn.LDA1[i], rnn.LDB1, rnn.LDC, 1.0, |
643 | max_bs_factor * 1); |
644 | } |
645 | } |
646 | if (rnn.k2_tail) |
647 | init_brgemm(&desc_iter_K2_tail_b1_[i], rnn.brgemm_isa, |
648 | kernel_iter_K2_tail_b1_[i], rnn.m_block, brgemm_n, |
649 | rnn.k2_tail, rnn.LDA2[i], rnn.LDB2, rnn.LDC, 1.0, 1); |
650 | if (rnn.k1_tail && rnn.n_tail) { |
651 | if (rnn.merge_gemm_layer) { |
652 | init_brgemm(&desc_layermerged_NK1_tail_b1_[i], rnn.brgemm_isa, |
653 | kernel_layermerged_NK1_tail_b1_[i], |
654 | rnn.mlayermerged_block, brgemm_n_tail, rnn.k1_tail, |
655 | rnn.LDA1[i], rnn.LDB1, rnn.LDC, 1.0, 1); |
656 | } else { |
657 | init_brgemm(&desc_layer_NK1_tail_b1_[i], rnn.brgemm_isa, |
658 | kernel_layer_NK1_tail_b1_[i], rnn.m_block, |
659 | brgemm_n_tail, rnn.k1_tail, rnn.LDA1[i], rnn.LDB1, |
660 | rnn.LDC, 1.0, max_bs_factor * 1); |
661 | } |
662 | } |
663 | if (rnn.k2_tail && rnn.n_tail) |
664 | init_brgemm(&desc_iter_NK2_tail_b1_[i], rnn.brgemm_isa, |
665 | kernel_iter_NK2_tail_b1_[i], rnn.m_block, brgemm_n_tail, |
666 | rnn.k2_tail, rnn.LDA2[i], rnn.LDB2, rnn.LDC, 1.0, 1); |
667 | } |
668 | if (rnn.is_orig_gru) { |
669 | for (int i = 0; i < num_vanilla_gru_iter_part2_kernels_; i++) { |
670 | init_brgemm(&desc_iter_p2_b1_[i], rnn.brgemm_isa, |
671 | kernel_iter_p2_b1_[i], rnn.m_block, brgemm_n, rnn.k2_block, |
672 | rnn.LDA2_2[i], rnn.LDB2, rnn.LDC, 1.0, rnn.KB2_blocks); |
673 | if (rnn.n_tail) |
674 | init_brgemm(&desc_iter_p2_N_tail_b1_[i], rnn.brgemm_isa, |
675 | kernel_iter_p2_N_tail_b1_[i], rnn.m_block, |
676 | brgemm_n_tail, rnn.k2_block, rnn.LDA2_2[i], rnn.LDB2, |
677 | rnn.LDC, 1.0, rnn.KB2_blocks); |
678 | if (rnn.k2_tail) |
679 | init_brgemm(&desc_iter_p2_K2_tail_b1_[i], rnn.brgemm_isa, |
680 | kernel_iter_p2_K2_tail_b1_[i], rnn.m_block, brgemm_n, |
681 | rnn.k2_tail, rnn.LDA2_2[i], rnn.LDB2, rnn.LDC, 1.0, 1); |
682 | if (rnn.k2_tail && rnn.n_tail) |
683 | init_brgemm(&desc_iter_p2_NK2_tail_b1_[i], rnn.brgemm_isa, |
684 | kernel_iter_p2_NK2_tail_b1_[i], rnn.m_block, |
685 | brgemm_n_tail, rnn.k2_tail, rnn.LDA2_2[i], rnn.LDB2, |
686 | rnn.LDC, 1.0, 1); |
687 | } |
688 | } |
689 | if (rnn.is_lstm_projection) { |
690 | const dim_t brgemm_np = nstl::min(rnn.Nproj, rnn.n_block); |
691 | const dim_t brgemm_np_tail = nstl::min(rnn.Nproj, rnn.nproj_tail); |
692 | const int n_kernel = (rnn.dt_conf == cpu::rnn_utils::all_f32) |
693 | ? num_proj_kernels_ |
694 | : 1; |
695 | for (int i = 0; i < n_kernel; i++) { |
696 | init_brgemm(&desc_proj_b0_[i], rnn.brgemm_isa, kernel_proj_b0_[i], |
697 | rnn.m_block, brgemm_np, rnn.kproj_block, rnn.LDAproj, |
698 | rnn.LDBproj, rnn.LDCproj[i], 0.0, rnn.KBproj_blocks); |
699 | if (rnn.nproj_tail) { |
700 | init_brgemm(&desc_proj_N_tail_b0_[i], rnn.brgemm_isa, |
701 | kernel_proj_N_tail_b0_[i], rnn.m_block, brgemm_np_tail, |
702 | rnn.kproj_block, rnn.LDAproj, rnn.LDBproj, |
703 | rnn.LDCproj[i], 0.0, rnn.KBproj_blocks); |
704 | init_brgemm(&desc_proj_N_tail_b1_[i], rnn.brgemm_isa, |
705 | kernel_proj_N_tail_b1_[i], rnn.m_block, brgemm_np_tail, |
706 | rnn.kproj_block, rnn.LDAproj, rnn.LDBproj, |
707 | rnn.LDCproj[i], 1.0, rnn.KBproj_blocks); |
708 | } |
709 | if (rnn.is_cell_int8_amx() || rnn.is_cell_bf16_amx()) { |
710 | if (rnn.kproj_tail) |
711 | init_brgemm(&desc_proj_K_tail_b1_[i], rnn.brgemm_isa, |
712 | kernel_proj_K_tail_b1_[i], rnn.m_block, brgemm_np, |
713 | rnn.kproj_tail, rnn.LDAproj, rnn.LDBproj, |
714 | rnn.LDCproj[i], 1.0, 1); |
715 | if (rnn.kproj_tail && rnn.nproj_tail) |
716 | init_brgemm(&desc_proj_NK_tail_b1_[i], rnn.brgemm_isa, |
717 | kernel_proj_NK_tail_b1_[i], rnn.m_block, |
718 | brgemm_np_tail, rnn.kproj_tail, rnn.LDAproj, |
719 | rnn.LDBproj, rnn.LDCproj[i], 1.0, 1); |
720 | } |
721 | } |
722 | } |
723 | |
724 | if (rnn.is_cell_int8_amx() || rnn.is_cell_bf16_amx()) { |
725 | if (rnn.merge_gemm_layer) |
726 | CHECK(brgemm_rnn_init_tiles( |
727 | desc_layermerged_b0_, pallete_buff_layermerged_)); |
728 | else |
729 | CHECK(brgemm_rnn_init_tiles(desc_layer_b0_, pallete_buff_layer_)); |
730 | CHECK(brgemm_rnn_init_tiles(desc_iter_b0_, pallete_buff_iter_)); |
731 | |
732 | if (rnn.n_tail) { |
733 | if (rnn.merge_gemm_layer) |
734 | CHECK(brgemm_rnn_init_tiles(desc_layermerged_N_tail_b0_, |
735 | pallete_buff_layermerged_n_tail_)); |
736 | else |
737 | CHECK(brgemm_rnn_init_tiles( |
738 | desc_layer_N_tail_b0_, pallete_buff_layer_n_tail_)); |
739 | CHECK(brgemm_rnn_init_tiles( |
740 | desc_iter_N_tail_b0_, pallete_buff_iter_n_tail_)); |
741 | } |
742 | if (rnn.k1_tail) { |
743 | if (rnn.merge_gemm_layer) |
744 | CHECK(brgemm_rnn_init_tiles(desc_layermerged_K1_tail_b1_, |
745 | pallete_buff_layermerged_k1_tail_)); |
746 | else |
747 | CHECK(brgemm_rnn_init_tiles( |
748 | desc_layer_K1_tail_b1_, pallete_buff_k1_tail_)); |
749 | } |
750 | if (rnn.k2_tail) |
751 | CHECK(brgemm_rnn_init_tiles( |
752 | desc_iter_K2_tail_b1_, pallete_buff_k2_tail_)); |
753 | if (rnn.k1_tail && rnn.n_tail) { |
754 | if (rnn.merge_gemm_layer) |
755 | CHECK(brgemm_rnn_init_tiles(desc_layermerged_NK1_tail_b1_, |
756 | pallete_buff_layermerged_nk1_tail_)); |
757 | else |
758 | CHECK(brgemm_rnn_init_tiles( |
759 | desc_layer_NK1_tail_b1_, pallete_buff_nk1_tail_)); |
760 | } |
761 | if (rnn.k2_tail && rnn.n_tail) |
762 | CHECK(brgemm_rnn_init_tiles( |
763 | desc_iter_NK2_tail_b1_, pallete_buff_nk2_tail_)); |
764 | if (rnn.is_lstm_projection) { |
765 | CHECK(brgemm_rnn_init_tiles_proj( |
766 | desc_proj_b0_, pallete_buff_proj_)); |
767 | if (rnn.nproj_tail) |
768 | CHECK(brgemm_rnn_init_tiles_proj( |
769 | desc_proj_N_tail_b0_, pallete_buff_nproj_tail_)); |
770 | if (rnn.kproj_tail) |
771 | CHECK(brgemm_rnn_init_tiles_proj( |
772 | desc_proj_K_tail_b1_, pallete_buff_kproj_tail_)); |
773 | if (rnn.kproj_tail && rnn.nproj_tail) |
774 | CHECK(brgemm_rnn_init_tiles_proj( |
775 | desc_proj_NK_tail_b1_, pallete_buff_nkproj_tail_)); |
776 | } |
777 | } |
778 | |
779 | return status::success; |
780 | } |
781 | |
782 | void rnn_brgemm_t<prop_kind::backward>::init_scratchpad( |
783 | const cpu::rnn_utils::rnn_conf_t &rnn, |
784 | memory_tracking::registrar_t &scratchpad, dim_t gemm_acc_type_size, |
785 | dim_t gemm_acc_align) { |
786 | |
787 | rnn_brgemm_base_t::init_scratchpad( |
788 | rnn, scratchpad, gemm_acc_type_size, gemm_acc_align); |
789 | |
790 | using namespace memory_tracking::names; |
791 | |
792 | // init scratchpad for internal reorders: |
793 | const auto data_size |
794 | = rnn.is_bf16_conf() ? sizeof(bfloat16_t) : sizeof(float); |
795 | const auto &d_wei = rnn.diff_wei_brgemm; |
796 | const auto scratch_gates_blocked_per_thr = d_wei.Kpadded * d_wei.n_block; |
797 | const auto scratch_gates_blocked_size |
798 | = rnn.nthr * scratch_gates_blocked_per_thr; |
799 | scratchpad.book(key_rnn_gates_blocked, scratch_gates_blocked_size, |
800 | data_size, gemm_acc_align); |
801 | |
802 | const auto scratch_src_layer_size = d_wei.global_transpose |
803 | ? d_wei.M_layer * d_wei.Kpadded |
804 | : rnn.nthr * std::min(d_wei.m_block, d_wei.M_layer) * d_wei.Kpadded; |
805 | scratchpad.book(key_rnn_src_layer_trans, scratch_src_layer_size, data_size, |
806 | gemm_acc_align); |
807 | |
808 | const auto scratch_src_iter_size = d_wei.global_transpose |
809 | ? d_wei.M_iter * d_wei.Kpadded |
810 | : rnn.nthr * std::min(d_wei.m_block, d_wei.M_iter) * d_wei.Kpadded; |
811 | scratchpad.book(key_rnn_src_iter_trans, scratch_src_iter_size, data_size, |
812 | gemm_acc_align); |
813 | } |
814 | |
815 | status_t rnn_brgemm_t<prop_kind::backward>::configure_brgemm( |
816 | cpu::rnn_utils::rnn_conf_t &rnn, alg_kind_t cell_kind, |
817 | dim_t src_layer_type_size, dim_t scratch_type_size) { |
818 | using namespace cpu::rnn_utils; |
819 | |
820 | if (rnn.is_int8_conf() || rnn.is_cell_dt_int8()) |
821 | return status::unimplemented; |
822 | |
823 | auto &diff_src_conf = rnn.diff_src_brgemm; |
824 | |
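    // diff_src gemm: M = mb, K = dhc (one gemm per gate, folded into the
    // brgemm batch via K_blocks below), N = slc for the layer part and sic
    // for the iter part; N covers the wider of the two.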
825 | diff_src_conf.M = rnn.mb; |
826 | diff_src_conf.N_iter = rnn.sic; |
827 | diff_src_conf.N_layer = rnn.slc; |
828 | diff_src_conf.N = nstl::max(diff_src_conf.N_iter, diff_src_conf.N_layer); |
829 | diff_src_conf.K = rnn.dhc; |
830 | |
831 | rnn.nthr = dnnl_get_max_threads(); |
832 | diff_src_conf.n_block = 32; |
833 | diff_src_conf.N_blocks |
834 | = utils::div_up(diff_src_conf.N, diff_src_conf.n_block); |
835 | diff_src_conf.n_tail = diff_src_conf.N % diff_src_conf.n_block; |
836 | diff_src_conf.N_layer_blocks |
837 | = utils::div_up(diff_src_conf.N_layer, diff_src_conf.n_block); |
838 | diff_src_conf.n_layer_tail = diff_src_conf.N_layer % diff_src_conf.n_block; |
839 | diff_src_conf.N_iter_blocks |
840 | = utils::div_up(diff_src_conf.N_iter, diff_src_conf.n_block); |
841 | diff_src_conf.n_iter_tail = diff_src_conf.N_iter % diff_src_conf.n_block; |
842 | |
843 | const float work_by_N = static_cast<float>(diff_src_conf.N_blocks) |
844 | / static_cast<float>(rnn.nthr); |
845 | |
846 | const dim_t l2_cache_size = platform::get_per_core_cache_size(2); |
847 | const dim_t As = src_layer_type_size * diff_src_conf.M * diff_src_conf.K; |
848 | const dim_t Bs |
849 | = src_layer_type_size * diff_src_conf.K * diff_src_conf.n_block; |
850 | const dim_t Cs = scratch_type_size * (rnn.n_gates + 1) |
851 | * (diff_src_conf.M * diff_src_conf.n_block); |
852 | |
853 | const auto is_bf16 = rnn.is_cell_dt_bf16(); |
854 | |
855 | const dim_t padding = is_bf16 ? 2 : 1; |
856 | diff_src_conf.Kpadded = utils::rnd_up(diff_src_conf.K, padding); |
857 | |
858 | diff_src_conf.isa |
859 | = brgemm_calc_isa(diff_src_conf.K, diff_src_conf.K, false, is_bf16); |
860 | const bool is_bf16_amx |
861 | = is_bf16 && diff_src_conf.isa == x64::avx512_core_amx; |
862 | const bool split_gates_computation = is_bf16_amx && diff_src_conf.K >= 1024 |
863 | && diff_src_conf.n_tail == 0; |
864 | diff_src_conf.gates_block = split_gates_computation ? 1 : rnn.n_gates; |
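    // With gates_block = 1 the gates loop is split: the first gate chunk
    // uses the beta = 0 kernel and subsequent chunks accumulate with
    // beta = 1 (see init_kernels_diff_src).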
865 | |
866 | std::tie(diff_src_conf.k_block, std::ignore) = brgemm_calc_k_block( |
867 | diff_src_conf.K, diff_src_conf.K, diff_src_conf.M, |
868 | diff_src_conf.n_block, cell_kind, src_layer_type_size, As, Bs, Cs, |
869 | l2_cache_size, diff_src_conf.isa, false, rnn.is_cell_dt_bf16()); |
870 | |
871 | diff_src_conf.K_blocks = diff_src_conf.K / diff_src_conf.k_block; |
872 | diff_src_conf.K_blocks *= rnn.n_gates; |
873 | diff_src_conf.k_tail = diff_src_conf.K % diff_src_conf.k_block; |
874 | |
875 | diff_src_conf.m_block = brgemm_calc_m_block(cell_kind, prop_kind::backward, |
876 | rnn.nthr, diff_src_conf.M, diff_src_conf.N_blocks, |
877 | rnn.is_cell_dt_f32(), false, is_bf16_amx, work_by_N, As, Bs, Cs, |
878 | l2_cache_size); |
879 | |
880 | diff_src_conf.M_blocks = diff_src_conf.M / diff_src_conf.m_block; |
881 | diff_src_conf.LDA = rnn.scratch_gates_ld; |
882 | diff_src_conf.LDB = diff_src_conf.n_block; |
883 | diff_src_conf.LDC = rnn.ws_diff_states_iter_ld; |
884 | |
885 | if (diff_src_conf.LDA < diff_src_conf.k_block) return status::unimplemented; |
886 | |
887 | const dim_t n_block = nstl::min(diff_src_conf.N, diff_src_conf.n_block); |
888 | |
889 | if (diff_src_conf.LDB < n_block) return status::unimplemented; |
890 | if (diff_src_conf.LDC < n_block) return status::unimplemented; |
891 | |
892 | rnn.KBproj_blocks = 0; |
893 | rnn.kproj_tail = 0; |
894 | rnn.kproj_block = 0; |
895 | |
896 | auto &diff_wei_conf = rnn.diff_wei_brgemm; |
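    // diff_wei gemm: M = slc/sic (weights rows), N = n_gates * dhc,
    // K = mb, i.e. transposed source activations times scratch diff gates,
    // accumulated over iterations.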
897 | diff_wei_conf.global_transpose = rnn.mb > 1; |
898 | diff_wei_conf.M_iter = rnn.sic; |
899 | diff_wei_conf.M_layer = rnn.slc; |
900 | diff_wei_conf.M = nstl::max(rnn.sic, rnn.slc); |
901 | diff_wei_conf.N = rnn.dhc * rnn.n_gates; |
902 | diff_wei_conf.K = (scratch_type_size != sizeof(float)) |
903 | ? utils::rnd_up(rnn.mb, 2) |
904 | : rnn.mb; |
905 | diff_wei_conf.Kpadded = utils::rnd_up(diff_wei_conf.K, padding); |
906 | |
907 | diff_wei_conf.isa |
908 | = brgemm_calc_isa(diff_wei_conf.K, diff_wei_conf.K, false, is_bf16); |
909 | |
910 | const bool is_wei_bf16_amx = rnn.is_cell_dt_bf16() |
911 | && diff_wei_conf.isa == x64::avx512_core_amx; |
912 | const bool diff_wei_can_use_nblock64 = is_wei_bf16_amx |
913 | && diff_wei_conf.N % 64 == 0 && !rnn.is_lstm_peephole; |
914 | diff_wei_conf.n_block = diff_wei_can_use_nblock64 ? 64 : 32; |
915 | diff_wei_conf.N_blocks |
916 | = utils::div_up(diff_wei_conf.N, diff_wei_conf.n_block); |
917 | diff_wei_conf.n_tail = diff_wei_conf.N % diff_wei_conf.n_block; |
918 | |
919 | const dim_t As_wei |
920 | = src_layer_type_size * diff_wei_conf.M * diff_wei_conf.K; |
921 | const dim_t Bs_wei |
922 | = src_layer_type_size * diff_wei_conf.K * diff_wei_conf.n_block; |
923 | const dim_t Cs_wei = scratch_type_size * (rnn.n_gates + 1) |
924 | * (diff_wei_conf.M * diff_wei_conf.n_block); |
925 | |
926 | std::tie(diff_wei_conf.k_block, std::ignore) |
927 | = brgemm_calc_k_block(diff_wei_conf.K, diff_wei_conf.K, |
928 | diff_wei_conf.M, diff_wei_conf.n_block, cell_kind, |
929 | src_layer_type_size, As_wei, Bs_wei, Cs_wei, l2_cache_size, |
930 | diff_wei_conf.isa, false, rnn.is_cell_dt_bf16()); |
931 | |
932 | diff_wei_conf.K_blocks = diff_wei_conf.K / diff_wei_conf.k_block; |
933 | diff_wei_conf.k_tail = diff_wei_conf.K % diff_wei_conf.k_block; |
934 | |
935 | if (diff_wei_conf.M_iter != diff_wei_conf.M_layer) { |
936 | diff_wei_conf.m_block = diff_wei_conf.M; |
937 | diff_wei_conf.M_blocks = 1; |
938 | } else { |
939 | const float work_by_N_wei = static_cast<float>(diff_wei_conf.N_blocks) |
940 | / static_cast<float>(rnn.nthr); |
941 | |
942 | diff_wei_conf.m_block |
943 | = brgemm_calc_m_block(cell_kind, prop_kind::backward, rnn.nthr, |
944 | diff_wei_conf.M, diff_wei_conf.N_blocks, |
945 | rnn.is_cell_dt_f32(), false, is_wei_bf16_amx, |
946 | work_by_N_wei, As_wei, Bs_wei, Cs_wei, l2_cache_size); |
947 | diff_wei_conf.M_blocks = diff_wei_conf.M / diff_wei_conf.m_block; |
948 | } |
949 | |
950 | diff_wei_conf.LDA_layer = diff_wei_conf.K; |
951 | diff_wei_conf.LDA_iter = diff_wei_conf.K; |
952 | diff_wei_conf.LDB = diff_wei_conf.n_block; |
953 | diff_wei_conf.LDC_iter = rnn.diff_weights_iter_ld; |
954 | diff_wei_conf.LDC_layer = rnn.diff_weights_layer_ld; |
955 | |
956 | if (diff_wei_conf.LDA_layer < diff_wei_conf.k_block |
957 | || diff_wei_conf.LDA_iter < diff_wei_conf.k_block) |
958 | return status::unimplemented; |
959 | |
960 | if (rnn.is_lstm_peephole) { configure_brgemm_peephole(rnn); } |
961 | |
962 | rnn.M = nstl::max(diff_wei_conf.M, diff_src_conf.M); |
963 | rnn.N = nstl::max(diff_wei_conf.N, diff_src_conf.N); |
964 | rnn.K1 = nstl::max(diff_wei_conf.K, diff_src_conf.K); |
965 | rnn.K2 = rnn.K1; |
966 | rnn.m_block = nstl::max(diff_wei_conf.m_block, diff_src_conf.m_block); |
967 | rnn.M_blocks = nstl::max(diff_wei_conf.M_blocks, diff_src_conf.M_blocks); |
968 | rnn.n_block = nstl::max(diff_wei_conf.n_block, diff_src_conf.n_block); |
969 | rnn.N_blocks = nstl::max(diff_wei_conf.N_blocks, diff_src_conf.N_blocks); |
970 | rnn.n_tail = nstl::max(diff_wei_conf.n_tail, diff_src_conf.n_tail); |
971 | rnn.k1_block = nstl::max(diff_wei_conf.k_block, diff_src_conf.k_block); |
972 | rnn.k2_block = rnn.k1_block; |
973 | rnn.k1_tail = nstl::max(diff_wei_conf.k_tail, diff_src_conf.k_tail); |
974 | rnn.k2_tail = rnn.k1_tail; |
975 | rnn.KB1_blocks = nstl::max(diff_wei_conf.K_blocks, diff_src_conf.K_blocks); |
976 | rnn.KB2_blocks = rnn.KB1_blocks; |
977 | rnn.K1padded = nstl::max(diff_wei_conf.Kpadded, diff_src_conf.Kpadded); |
978 | rnn.K2padded = rnn.K1padded; |
979 | rnn.unfused_post_gemm = true; |
980 | |
981 | if (utils::one_of( |
982 | x64::avx512_core_amx, diff_wei_conf.isa, diff_src_conf.isa)) { |
983 | rnn.brgemm_isa = x64::avx512_core_amx; |
984 | } else { |
985 | rnn.brgemm_isa = diff_wei_conf.isa; |
986 | } |
987 | |
988 | if (!rnn.is_orig_gru) { |
989 | rnn.diff_src_brgemm.loop_order |
990 | = is_bf16 && diff_src_conf.isa == x64::avx512_core_amx |
991 | ? brgemm_rnn_execute_loop_order_t::mblk_nblk |
992 | : brgemm_rnn_execute_loop_order_t::nblk_mblk; |
993 | rnn.diff_wei_brgemm.loop_order |
994 | = is_bf16 && diff_wei_conf.isa == x64::avx512_core_amx |
995 | ? brgemm_rnn_execute_loop_order_t::mblk_nblk |
996 | : brgemm_rnn_execute_loop_order_t::nblk_mblk; |
997 | } |
998 | |
999 | return status::success; |
1000 | } |
1001 | |
1002 | static dim_t divide_block_to_improve_thread_balance( |
1003 | const dim_t initial_work_amount, const dim_t division_block, |
1004 | const dim_t nthr) { |
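    // Splits division_block until initial_work_amount * blocks balances
    // better across nthr threads. For example (assumed values), with
    // initial_work_amount = 3 (peephole gates), division_block = 512 (dhc)
    // and nthr = 28, unsplit work keeps only 3 of 28 threads busy; the
    // search settles on a 64-wide block: 3 * div_up(512, 64) = 24 work
    // items, so 24 of 28 threads stay busy in a single round.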
1005 | |
1006 | const float nthr_f = static_cast<float>(nthr); |
1007 | const float initial_work = static_cast<float>(initial_work_amount) / nthr_f; |
1008 | const float decimal_initial_factor |
1009 | = initial_work - std::floor(initial_work); |
    static constexpr float thread_balance_threshold = 0.8;
1011 | static constexpr float tolerance = 0.01; |
1012 | |
1013 | float max_decimal_factor = -1.0; |
    dim_t best_candidate = -1;
1015 | bool found_best_solution = false; |
1016 | |
    if (decimal_initial_factor < thread_balance_threshold
1018 | && decimal_initial_factor != 0.0f) { |
1019 | |
1020 | for (const int block_size : {4096, 2048, 1024, 512, 256, 128, 64, 32}) { |
1021 | |
1022 | if (division_block <= block_size) continue; |
1023 | |
1024 | const auto blocks = utils::div_up(division_block, block_size); |
1025 | |
1026 | const float work |
1027 | = static_cast<float>(initial_work_amount * blocks) / nthr_f; |
1028 | const float work_decimal = work - std::floor(work); |
1029 | |
            if (work_decimal == 0.0f
                    || (max_decimal_factor != 0.0f
                                    ? work_decimal
                                            > (max_decimal_factor + tolerance)
                                    : work_decimal
                                            >= thread_balance_threshold)) {
1037 | best_candidate = block_size; |
1038 | max_decimal_factor = work_decimal; |
1039 | } |
1040 | |
1041 | if (work >= nthr_f |
                && (work_decimal >= thread_balance_threshold
1043 | || work_decimal == 0.0f)) { |
1044 | found_best_solution = true; |
1045 | break; |
1046 | } |
1047 | } |
1048 | } |
1049 | |
    if (found_best_solution
            || max_decimal_factor > decimal_initial_factor + tolerance) {
1054 | return best_candidate; |
1055 | } |
1056 | |
1057 | return division_block; |
1058 | } |
1059 | |
1060 | void rnn_brgemm_t<prop_kind::backward>::configure_brgemm_peephole( |
1061 | cpu::rnn_utils::rnn_conf_t &rnn) { |
1062 | static constexpr dim_t n_gates = 3; |
1063 | rnn.dhc_block_peephole = divide_block_to_improve_thread_balance( |
1064 | n_gates, rnn.dhc, rnn.nthr); |
1065 | rnn.dhc_blocks_peephole = utils::div_up(rnn.dhc, rnn.dhc_block_peephole); |
1066 | rnn.dhc_tail_peephole = rnn.dhc % rnn.dhc_block_peephole; |
1067 | } |
1068 | |
1069 | static status_t init_kernels_diff_src(rnn_diff_src_brgemm_t &diff_src, |
1070 | const cpu::rnn_utils::rnn_conf_t &rnn, data_type_t src_type, |
1071 | data_type_t weights_type) { |
1072 | |
1073 | const auto init_brgemm_diff_src |
1074 | = [&](x64::brgemm_t *desc, x64::cpu_isa_t isa, |
1075 | std::unique_ptr<x64::brgemm_kernel_t> &ker, dim_t M, |
1076 | dim_t N, dim_t K, dim_t LDA, dim_t LDB, dim_t LDC, |
1077 | float beta, dim_t max_bs) { |
1078 | const dim_t A_size |
1079 | = rnn.diff_src_brgemm.M * rnn.diff_src_brgemm.Kpadded; |
1080 | const dim_t B_size |
1081 | = rnn.diff_src_brgemm.Kpadded * rnn.diff_src_brgemm.N; |
1082 | const dim_t C_size |
1083 | = rnn.diff_src_brgemm.M * rnn.diff_src_brgemm.N; |
1084 | return init_brgemm_kernel(desc, isa, src_type, weights_type, |
1085 | ker, M, N, K, LDA, LDB, LDC, beta, max_bs, A_size, |
1086 | B_size, C_size); |
1087 | }; |
1088 | |
1089 | const auto &diff_src_conf = rnn.diff_src_brgemm; |
1090 | const int n_diff_src = nstl::min(diff_src_conf.N, diff_src_conf.n_block); |
1091 | const int n_diff_src_iter_tail |
1092 | = nstl::min(diff_src_conf.N_iter, diff_src_conf.n_iter_tail); |
1093 | const int n_diff_src_layer_tail |
1094 | = nstl::min(diff_src_conf.N_layer, diff_src_conf.n_layer_tail); |
1095 | const auto K_batch_size = rnn.n_gates * diff_src_conf.K_blocks; |
1096 | const auto split_gates_computation |
1097 | = diff_src_conf.gates_block != rnn.n_gates; |
1098 | init_brgemm_diff_src(&diff_src.desc_iter_layer_beta0_, diff_src_conf.isa, |
1099 | diff_src.kernel_iter_layer_beta0_, diff_src_conf.m_block, |
1100 | n_diff_src, diff_src_conf.k_block, diff_src_conf.LDA, |
1101 | diff_src_conf.LDB, diff_src_conf.LDC, 0.0, K_batch_size); |
1102 | if (split_gates_computation) |
1103 | init_brgemm_diff_src(&diff_src.desc_iter_layer_beta1_, |
1104 | diff_src_conf.isa, diff_src.kernel_iter_layer_beta1_, |
1105 | diff_src_conf.m_block, n_diff_src, diff_src_conf.k_block, |
1106 | diff_src_conf.LDA, diff_src_conf.LDB, diff_src_conf.LDC, 1.0, |
1107 | K_batch_size); |
1108 | |
1109 | if (n_diff_src_layer_tail) { |
1110 | init_brgemm_diff_src(&diff_src.desc_layer_N_tail_beta0_, |
1111 | diff_src_conf.isa, diff_src.kernel_layer_N_tail_beta0_, |
1112 | diff_src_conf.m_block, n_diff_src_layer_tail, |
1113 | diff_src_conf.k_block, diff_src_conf.LDA, diff_src_conf.LDB, |
1114 | diff_src_conf.LDC, 0.0, K_batch_size); |
1115 | if (split_gates_computation) |
1116 | init_brgemm_diff_src(&diff_src.desc_layer_N_tail_beta1_, |
1117 | diff_src_conf.isa, diff_src.kernel_layer_N_tail_beta1_, |
1118 | diff_src_conf.m_block, n_diff_src_layer_tail, |
1119 | diff_src_conf.k_block, diff_src_conf.LDA, diff_src_conf.LDB, |
1120 | diff_src_conf.LDC, 1.0, K_batch_size); |
1121 | } |
1122 | |
1123 | if (n_diff_src_iter_tail) { |
1124 | init_brgemm_diff_src(&diff_src.desc_iter_N_tail_beta0_, |
1125 | diff_src_conf.isa, diff_src.kernel_iter_N_tail_beta0_, |
1126 | diff_src_conf.m_block, n_diff_src_iter_tail, |
1127 | diff_src_conf.k_block, diff_src_conf.LDA, diff_src_conf.LDB, |
1128 | diff_src_conf.LDC, 0.0, K_batch_size); |
1129 | if (split_gates_computation) |
1130 | init_brgemm_diff_src(&diff_src.desc_iter_N_tail_beta1_, |
1131 | diff_src_conf.isa, diff_src.kernel_iter_N_tail_beta1_, |
1132 | diff_src_conf.m_block, n_diff_src_iter_tail, |
1133 | diff_src_conf.k_block, diff_src_conf.LDA, diff_src_conf.LDB, |
1134 | diff_src_conf.LDC, 1.0, K_batch_size); |
1135 | } |
1136 | |
1137 | if (diff_src_conf.k_tail) { |
1138 | init_brgemm_diff_src(&diff_src.desc_iter_layer_K_tail_beta1_, |
1139 | diff_src_conf.isa, diff_src.kernel_iter_layer_K_tail_beta1_, |
1140 | diff_src_conf.m_block, n_diff_src, diff_src_conf.k_tail, |
1141 | diff_src_conf.LDA, diff_src_conf.LDB, diff_src_conf.LDC, 1.0, |
1142 | rnn.n_gates); |
1143 | |
1144 | if (n_diff_src_layer_tail) { |
1145 | init_brgemm_diff_src(&diff_src.desc_layer_NK_tail_beta1_, |
1146 | diff_src_conf.isa, diff_src.kernel_layer_NK_tail_beta1_, |
1147 | diff_src_conf.m_block, n_diff_src_layer_tail, |
1148 | diff_src_conf.k_tail, diff_src_conf.LDA, diff_src_conf.LDB, |
1149 | diff_src_conf.LDC, 1.0, rnn.n_gates); |
1150 | } |
1151 | |
1152 | if (n_diff_src_iter_tail) { |
1153 | init_brgemm_diff_src(&diff_src.desc_iter_NK_tail_beta1_, |
1154 | diff_src_conf.isa, diff_src.kernel_iter_NK_tail_beta1_, |
1155 | diff_src_conf.m_block, n_diff_src_iter_tail, |
1156 | diff_src_conf.k_tail, diff_src_conf.LDA, diff_src_conf.LDB, |
1157 | diff_src_conf.LDC, 1.0, rnn.n_gates); |
1158 | } |
1159 | } |
1160 | |
1161 | const bool is_bf16_amx = rnn.is_cell_dt_bf16() |
1162 | && diff_src_conf.isa == x64::avx512_core_amx; |
1163 | |
1164 | if (is_bf16_amx) { |
1165 | CHECK(brgemm_init_tiles(diff_src.desc_iter_layer_beta0_, |
1166 | diff_src.pallete_buff_iter_layer_)); |
1167 | |
1168 | if (n_diff_src_layer_tail) |
1169 | CHECK(brgemm_init_tiles(diff_src.desc_layer_N_tail_beta0_, |
1170 | diff_src.pallete_buff_layer_n_tail_)); |
1171 | |
1172 | if (n_diff_src_iter_tail) |
1173 | CHECK(brgemm_init_tiles(diff_src.desc_iter_N_tail_beta0_, |
1174 | diff_src.pallete_buff_iter_n_tail_)); |
1175 | |
1176 | if (diff_src_conf.k_tail) { |
1177 | CHECK(brgemm_init_tiles(diff_src.desc_iter_layer_K_tail_beta1_, |
1178 | diff_src.pallete_buff_iter_layer_k_tail_)); |
1179 | |
1180 | if (n_diff_src_layer_tail) |
1181 | CHECK(brgemm_init_tiles(diff_src.desc_layer_NK_tail_beta1_, |
1182 | diff_src.pallete_buff_layer_nk_tail_)); |
1183 | |
1184 | if (n_diff_src_iter_tail) |
1185 | CHECK(brgemm_init_tiles(diff_src.desc_iter_NK_tail_beta1_, |
1186 | diff_src.pallete_buff_iter_nk_tail_)); |
1187 | } |
1188 | } |
1189 | |
1190 | return status::success; |
1191 | } |
1192 | |
1193 | static status_t init_kernels_diff_wei(rnn_diff_wei_brgemm_t &diff_wei, |
1194 | const cpu::rnn_utils::rnn_conf_t &rnn, data_type_t src_type, |
1195 | data_type_t weights_type) { |
1196 | |
1197 | const auto init_brgemm_diff_wei |
1198 | = [&](x64::brgemm_t *desc, x64::cpu_isa_t isa, |
1199 | std::unique_ptr<x64::brgemm_kernel_t> &ker, dim_t M, |
1200 | dim_t N, dim_t K, dim_t LDA, dim_t LDB, dim_t LDC, |
1201 | float beta, dim_t max_bs) { |
1202 | const dim_t A_size |
1203 | = rnn.diff_wei_brgemm.M * rnn.diff_wei_brgemm.Kpadded; |
1204 | const dim_t B_size |
1205 | = rnn.diff_wei_brgemm.Kpadded * rnn.diff_wei_brgemm.N; |
1206 | const dim_t C_size |
1207 | = rnn.diff_wei_brgemm.M * rnn.diff_wei_brgemm.N; |
1208 | return init_brgemm_kernel(desc, isa, src_type, weights_type, |
1209 | ker, M, N, K, LDA, LDB, LDC, beta, max_bs, A_size, |
1210 | B_size, C_size); |
1211 | }; |
1212 | |
1213 | const auto &diff_wei_conf = rnn.diff_wei_brgemm; |
1214 | const bool is_m_block_equal = rnn.slc == rnn.sic; |
1215 | const auto m_block_iter |
1216 | = is_m_block_equal ? diff_wei_conf.m_block : diff_wei_conf.M_iter; |
1217 | const auto m_block_layer |
1218 | = is_m_block_equal ? diff_wei_conf.m_block : diff_wei_conf.M_layer; |
1219 | const auto n_diff_wei = nstl::min(diff_wei_conf.N, diff_wei_conf.n_block); |
1220 | const auto n_diff_wei_tail |
1221 | = nstl::min(diff_wei_conf.N, diff_wei_conf.n_tail); |
1222 | |
1223 | init_brgemm_diff_wei(&diff_wei.desc_iter_beta1_, diff_wei_conf.isa, |
1224 | diff_wei.kernel_iter_beta1_, m_block_iter, n_diff_wei, |
1225 | diff_wei_conf.k_block, diff_wei_conf.LDA_iter, diff_wei_conf.LDB, |
1226 | diff_wei_conf.LDC_iter, 1.0, diff_wei_conf.K_blocks); |
1227 | init_brgemm_diff_wei(&diff_wei.desc_layer_beta1_, diff_wei_conf.isa, |
1228 | diff_wei.kernel_layer_beta1_, m_block_layer, n_diff_wei, |
1229 | diff_wei_conf.k_block, diff_wei_conf.LDA_layer, diff_wei_conf.LDB, |
1230 | diff_wei_conf.LDC_layer, 1.0, diff_wei_conf.K_blocks); |
1231 | |
1232 | if (n_diff_wei_tail) { |
1233 | init_brgemm_diff_wei(&diff_wei.desc_iter_N_tail_beta1_, |
1234 | diff_wei_conf.isa, diff_wei.kernel_iter_N_tail_beta1_, |
1235 | m_block_iter, n_diff_wei_tail, diff_wei_conf.k_block, |
1236 | diff_wei_conf.LDA_iter, diff_wei_conf.LDB, |
1237 | diff_wei_conf.LDC_iter, 1.0, diff_wei_conf.K_blocks); |
1238 | init_brgemm_diff_wei(&diff_wei.desc_layer_N_tail_beta1_, |
1239 | diff_wei_conf.isa, diff_wei.kernel_layer_N_tail_beta1_, |
1240 | m_block_layer, n_diff_wei_tail, diff_wei_conf.k_block, |
1241 | diff_wei_conf.LDA_layer, diff_wei_conf.LDB, |
1242 | diff_wei_conf.LDC_layer, 1.0, diff_wei_conf.K_blocks); |
1243 | |
1244 | if (diff_wei_conf.k_tail) { |
1245 | init_brgemm_diff_wei(&diff_wei.desc_iter_NK_tail_beta1_, |
1246 | diff_wei_conf.isa, diff_wei.kernel_iter_NK_tail_beta1_, |
1247 | m_block_iter, n_diff_wei_tail, diff_wei_conf.k_tail, |
1248 | diff_wei_conf.LDA_iter, diff_wei_conf.LDB, |
1249 | diff_wei_conf.LDC_iter, 1.0, 1); |
1250 | init_brgemm_diff_wei(&diff_wei.desc_layer_NK_tail_beta1_, |
1251 | diff_wei_conf.isa, diff_wei.kernel_layer_NK_tail_beta1_, |
1252 | m_block_layer, n_diff_wei_tail, diff_wei_conf.k_tail, |
1253 | diff_wei_conf.LDA_layer, diff_wei_conf.LDB, |
1254 | diff_wei_conf.LDC_layer, 1.0, 1); |
1255 | } |
1256 | } |
1257 | |
1258 | if (diff_wei_conf.k_tail) { |
1259 | init_brgemm_diff_wei(&diff_wei.desc_iter_K_tail_beta1_, |
1260 | diff_wei_conf.isa, diff_wei.kernel_iter_K_tail_beta1_, |
1261 | m_block_iter, n_diff_wei, diff_wei_conf.k_tail, |
1262 | diff_wei_conf.LDA_iter, diff_wei_conf.LDB, |
1263 | diff_wei_conf.LDC_iter, 1.0, 1); |
1264 | init_brgemm_diff_wei(&diff_wei.desc_layer_K_tail_beta1_, |
1265 | diff_wei_conf.isa, diff_wei.kernel_layer_K_tail_beta1_, |
1266 | m_block_layer, n_diff_wei, diff_wei_conf.k_tail, |
1267 | diff_wei_conf.LDA_layer, diff_wei_conf.LDB, |
1268 | diff_wei_conf.LDC_layer, 1.0, 1); |
1269 | } |
1270 | |
1271 | const bool is_bf16_amx_wei = rnn.is_cell_dt_bf16() |
1272 | && diff_wei_conf.isa == x64::avx512_core_amx; |
1273 | |
1274 | if (is_bf16_amx_wei) { |
1275 | CHECK(brgemm_init_tiles( |
1276 | diff_wei.desc_iter_beta1_, diff_wei.pallete_buff_iter_)); |
1277 | CHECK(brgemm_init_tiles( |
1278 | diff_wei.desc_layer_beta1_, diff_wei.pallete_buff_layer_)); |
1279 | if (n_diff_wei_tail) { |
1280 | CHECK(brgemm_init_tiles(diff_wei.desc_iter_N_tail_beta1_, |
1281 | diff_wei.pallete_buff_iter_n_tail_)); |
1282 | CHECK(brgemm_init_tiles(diff_wei.desc_layer_N_tail_beta1_, |
1283 | diff_wei.pallete_buff_layer_n_tail_)); |
1284 | |
1285 | if (diff_wei_conf.k_tail) { |
1286 | CHECK(brgemm_init_tiles(diff_wei.desc_iter_NK_tail_beta1_, |
1287 | diff_wei.pallete_buff_iter_nk_tail_)); |
1288 | CHECK(brgemm_init_tiles(diff_wei.desc_layer_NK_tail_beta1_, |
1289 | diff_wei.pallete_buff_layer_nk_tail_)); |
1290 | } |
1291 | } |
1292 | |
1293 | if (diff_wei_conf.k_tail) { |
1294 | CHECK(brgemm_init_tiles(diff_wei.desc_iter_K_tail_beta1_, |
1295 | diff_wei.pallete_buff_iter_k_tail_)); |
1296 | CHECK(brgemm_init_tiles(diff_wei.desc_layer_K_tail_beta1_, |
1297 | diff_wei.pallete_buff_layer_k_tail_)); |
1298 | } |
1299 | } |
1300 | |
    // Create a temporary matmul configuration descriptor to reuse the copy_B
    // jit kernels from the brgemm matmul copy routines for reordering scratch
    // gates in the diff_wei rnn brgemm implementation.
    // TODO: unify the jit-based copy routines under an implementation-
    // independent interface.
1306 | matmul::brgemm_matmul_conf_t tmp_matmul_conf_for_reorder; |
1307 | tmp_matmul_conf_for_reorder.wei_tag = format_tag::ab; |
1308 | tmp_matmul_conf_for_reorder.N = rnn.scratch_gates_ld; |
1309 | tmp_matmul_conf_for_reorder.K = rnn.mb; |
1310 | tmp_matmul_conf_for_reorder.wei_n_blk = tmp_matmul_conf_for_reorder.N_blk |
1311 | = diff_wei_conf.n_block; |
1312 | tmp_matmul_conf_for_reorder.N_tail = diff_wei_conf.n_tail; |
1313 | tmp_matmul_conf_for_reorder.LDB = diff_wei_conf.LDB; |
1314 | tmp_matmul_conf_for_reorder.src_dt = tmp_matmul_conf_for_reorder.wei_dt |
1315 | = rnn.is_cell_dt_bf16() ? data_type::bf16 : data_type::f32; |
1316 | tmp_matmul_conf_for_reorder.a_dt_sz = tmp_matmul_conf_for_reorder.tr_a_dt_sz |
1317 | = types::data_type_size(tmp_matmul_conf_for_reorder.src_dt); |
1318 | tmp_matmul_conf_for_reorder.b_dt_sz = tmp_matmul_conf_for_reorder.tr_b_dt_sz |
1319 | = types::data_type_size(tmp_matmul_conf_for_reorder.wei_dt); |
1320 | CHECK(matmul::create_brgemm_matmul_copy_b( |
1321 | diff_wei.srcatch_gates_reorder_kernel_, |
1322 | &tmp_matmul_conf_for_reorder)); |
1323 | |
1324 | return status::success; |
1325 | } |
1326 | |
1327 | status_t rnn_brgemm_t<prop_kind::backward>::init_kernels( |
1328 | const cpu::rnn_utils::rnn_conf_t &rnn, data_type_t src_type, |
1329 | data_type_t weights_type) { |
1330 | |
    CHECK(init_kernels_diff_src(diff_src_, rnn, src_type, weights_type));
    CHECK(init_kernels_diff_wei(diff_wei_, rnn, src_type, weights_type));
1333 | if (rnn.is_lstm_peephole) CHECK(init_peephole_kernels(rnn)); |
1334 | |
1335 | const auto n_diff_wei_tail |
1336 | = nstl::min(rnn.diff_wei_brgemm.N, rnn.diff_wei_brgemm.n_tail); |
1337 | kernel_gates_reduction_ |
1338 | = utils::make_unique<jit_gates_reduction_t>(rnn, false /*n_tail*/); |
    CHECK(kernel_gates_reduction_->create_kernel());
1340 | |
1341 | if (n_diff_wei_tail) { |
1342 | kernel_gates_reduction_tail_ |
1343 | = utils::make_unique<jit_gates_reduction_t>( |
1344 | rnn, true /*n_tail*/); |
        CHECK(kernel_gates_reduction_tail_->create_kernel());
1346 | } |
1347 | |
1348 | if (rnn.mb == 1) { |
1349 | if (src_type == data_type::bf16) { |
1350 | const bool is_m_block_equal = rnn.slc == rnn.sic; |
1351 | const auto m_block_iter = is_m_block_equal |
1352 | ? rnn.diff_wei_brgemm.m_block |
1353 | : rnn.diff_wei_brgemm.M_iter; |
1354 | |
1355 | kernel_transpose_single_row_iter_ |
1356 | = utils::make_unique<jit_brgemm_transpose_single_row_t>( |
1357 | m_block_iter); |
1358 | CHECK(kernel_transpose_single_row_iter_->create_kernel()); |
1359 | |
1360 | if (!is_m_block_equal) { |
1361 | const auto m_block_layer = is_m_block_equal |
1362 | ? rnn.diff_wei_brgemm.m_block |
1363 | : rnn.diff_wei_brgemm.M_layer; |
1364 | kernel_transpose_single_row_layer_ |
1365 | = utils::make_unique<jit_brgemm_transpose_single_row_t>( |
1366 | m_block_layer); |
1367 | CHECK(kernel_transpose_single_row_layer_->create_kernel()); |
1368 | } |
1369 | } |
1370 | } else { |
1371 | jit_brgemm_primitive_conf_t trans_conf; |
1372 | trans_conf.prop_kind = dnnl_backward_weights; |
1373 | trans_conf.src_dt = src_type; |
1374 | static constexpr int blk_size = 16; |
1375 | trans_conf.os_block = blk_size; // src's rows block size |
1376 | trans_conf.ic_block = blk_size; // src's cols block size |
1377 | trans_conf.M = 0; |
1378 | const auto rnd_up_size = (src_type == data_type::bf16 ? 2 : 1); |
1379 | trans_conf.LDA |
1380 | = utils::rnd_up(rnn.mb, rnd_up_size); // dst's leading dim |
1381 | trans_conf.K_tail = rnn.mb % blk_size; // src's rows tail |
1382 | |
1383 | const int LDA_iter[] |
1384 | = {rnn.src_iter_ld_, rnn.dst_layer_ld_, rnn.ws_states_iter_ld}; |
1385 | trans_conf.M_tail = rnn.sic % blk_size; // src's cols tail |
1386 | for (int i = 0; i < num_base_kernels_; i++) { |
1387 | trans_conf.ic = LDA_iter[i]; |
1388 | CHECK(create_brgemm_trans_src( |
1389 | kernel_transpose_iter_[i], &trans_conf)); |
1390 | } |
1391 | |
1392 | const int LDA_layer[] |
1393 | = {rnn.src_layer_ld_, rnn.dst_iter_ld_, rnn.ws_states_layer_ld}; |
1394 | trans_conf.M_tail = rnn.slc % blk_size; // src's cols tail |
1395 | for (int i = 0; i < num_base_kernels_; i++) { |
1396 | trans_conf.ic = LDA_layer[i]; |
1397 | CHECK(create_brgemm_trans_src( |
1398 | kernel_transpose_layer_[i], &trans_conf)); |
1399 | } |
1400 | } |
1401 | |
1402 | return status::success; |
1403 | } |
1404 | |
1405 | status_t rnn_brgemm_t<prop_kind::backward>::init_peephole_kernels( |
1406 | const cpu::rnn_utils::rnn_conf_t &rnn) { |
1407 | |
1408 | if (rnn.dhc_blocks_peephole) { |
1409 | kernel_peephole_ = utils::make_unique<jit_diff_weights_peephole_t>( |
1410 | rnn, rnn.dhc_block_peephole); |
1411 | CHECK(kernel_peephole_->create_kernel()); |
1412 | } |
1413 | |
1414 | if (rnn.dhc_tail_peephole) { |
1415 | kernel_peephole_tail_ = utils::make_unique<jit_diff_weights_peephole_t>( |
1416 | rnn, rnn.dhc_tail_peephole); |
1417 | CHECK(kernel_peephole_tail_->create_kernel()); |
1418 | } |
1419 | |
1420 | return status::success; |
1421 | } |
1422 | |
1423 | } // namespace rnn_brgemm_utils |
1424 | } // namespace x64 |
1425 | } // namespace cpu |
1426 | } // namespace impl |
1427 | } // namespace dnnl |
1428 | |