/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "brgemm_cell_common_fwd.hpp"

#include "common/dnnl_thread.hpp"
#include "common/utils.hpp"
#include "cpu/x64/rnn/brgemm_cell_common_utils.hpp"

using namespace dnnl::impl::utils;

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

template <typename src_t, typename weights_t, typename scratch_t,
        typename gemm_acc_t>
brgemm_dst_layer_iter_t<src_t, weights_t, scratch_t,
        gemm_acc_t>::brgemm_dst_layer_iter_t(const ref_rnn_brgemm_t &rnn_brgemm,
        const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, const src_t *src_iter,
        const src_t *src_layer, weights_t *w_iter, weights_t *w_layer,
        scratch_t *scratch_gates, gemm_acc_t *amx_scratchpad,
        x64::brgemm_batch_element_t *addr_batch_global,
        const postgemm_fused_t &fused_postgemm)
    : rnn_brgemm_(rnn_brgemm)
    , rnn_(rnn)
    , need_gemm_layer_(rnn_.need_gemm_layer(cell_position))
    , layer_desc_idx_(rnn_.layer_brgemm_desc(cell_position))
    , iter_desc_idx_(rnn_.iter_brgemm_desc(cell_position))
    , Al_(src_layer)
    , Ai_(src_iter)
    , Bl_(w_layer)
    , Bi_(w_iter)
    , C_(scratch_gates)
    , LDAl_(rnn_.src_layer_ld(cell_position))
    , LDAi_(rnn_.src_iter_ld(cell_position))
    , max_nthr_(rnn_.nthr)
    , n_blocking_((rnn_.unfused_post_gemm) ? rnn_.N_blocks * rnn_.n_gates
                                           : rnn_.N_blocks)
    , m_blocking_(rnn_.M_blocks)
    , work_amount_(n_blocking_ * m_blocking_)
    , Bl_n_offset_(rnn_.K1padded * rnn_.n_block)
    , Bi_n_offset_(rnn_.K2padded * rnn_.n_block)
    , Bl_g_offset_(rnn_.N_blocks * Bl_n_offset_)
    , Bi_g_offset_(rnn_.N_blocks * Bi_n_offset_)
    , Al_k_tail_offset_(rnn_.KB1_blocks * rnn_.k1_block)
    , Ai_k_tail_offset_(rnn_.KB2_blocks * rnn_.k2_block)
    , Bl_kb_offset_(rnn_.k1_block * rnn_.n_block)
    , Bi_kb_offset_(rnn_.k2_block * rnn_.n_block)
    , Bl_k_tail_offset_(rnn_.KB1_blocks * rnn_.k1_block * rnn_.n_block)
    , Bi_k_tail_offset_(rnn_.KB2_blocks * rnn_.k2_block * rnn_.n_block)
    , n_gates_(rnn.unfused_post_gemm ? 1 : rnn.n_gates)
    , brgemm_kernel_iter_main_(
              rnn_brgemm_.kernel_iter_b1_[iter_desc_idx_].get())
    , brgemm_kernel_iter_n_tail_(
              rnn_brgemm_.kernel_iter_N_tail_b1_[iter_desc_idx_].get())
    , brgemm_kernel_iter_k_tail_(
              rnn_brgemm_.kernel_iter_K2_tail_b1_[iter_desc_idx_].get())
    , brgemm_kernel_iter_nk_tail_(
              rnn_brgemm_.kernel_iter_NK2_tail_b1_[iter_desc_idx_].get())
    , brgemm_kernel_layer_main_(
              rnn_brgemm_.kernel_layer_b0_[layer_desc_idx_].get())
    , brgemm_kernel_layer_n_tail_(
              rnn_brgemm_.kernel_layer_N_tail_b0_[layer_desc_idx_].get())
    , brgemm_kernel_layer_k_tail_(
              rnn_brgemm_.kernel_layer_K1_tail_b1_[layer_desc_idx_].get())
    , brgemm_kernel_layer_nk_tail_(
              rnn_brgemm_.kernel_layer_NK1_tail_b1_[layer_desc_idx_].get())
    , pallete_buff_iter_main_(rnn.k1_block == rnn.k2_block && need_gemm_layer_
                      ? rnn_brgemm_.pallete_buff_layer_
                      : rnn_brgemm_.pallete_buff_iter_)
    , pallete_buff_iter_n_tail_(rnn.k1_block == rnn.k2_block && need_gemm_layer_
                      ? rnn_brgemm_.pallete_buff_layer_n_tail_
                      : rnn_brgemm_.pallete_buff_iter_n_tail_)
    , pallete_buff_iter_k_tail_(rnn.k1_tail == rnn.k2_tail && need_gemm_layer_
                      ? rnn_brgemm_.pallete_buff_k1_tail_
                      : rnn_brgemm_.pallete_buff_k2_tail_)
    , pallete_buff_iter_nk_tail_(rnn.k1_tail == rnn.k2_tail && need_gemm_layer_
                      ? rnn_brgemm_.pallete_buff_nk1_tail_
                      : rnn_brgemm_.pallete_buff_nk2_tail_)
    , pallete_buff_layer_main_(rnn_brgemm_.pallete_buff_layer_)
    , pallete_buff_layer_n_tail_(rnn_brgemm_.pallete_buff_layer_n_tail_)
    , pallete_buff_layer_k_tail_(rnn_brgemm_.pallete_buff_k1_tail_)
    , pallete_buff_layer_nk_tail_(rnn_brgemm_.pallete_buff_nk1_tail_)
    , amx_scratchpad_(amx_scratchpad)
    , addr_batch_global_(addr_batch_global)
    , fused_postgemm_(fused_postgemm)
    , is_fused_layer_iter_brgemm_(
              rnn_.sic == rnn_.slc && LDAi_ == LDAl_ && need_gemm_layer_) {}

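// Dispatches the cell GEMMs: when the layer and iteration inputs match in
// size (sic == slc) and leading stride, and the layer GEMM is required, both
// are computed by a single fused batched brgemm; otherwise the layer and
// iteration GEMMs run as separate kernels.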
template <typename src_t, typename weights_t, typename scratch_t,
        typename gemm_acc_t>
void brgemm_dst_layer_iter_t<src_t, weights_t, scratch_t, gemm_acc_t>::execute()
        const {
    if (is_fused_layer_iter_brgemm_) {
        parallel(max_nthr_, [this](const int ithr, const int nthr) {
            this->kernel_fused_iter_layer(ithr, nthr);
        });
    } else {
        parallel(max_nthr_, [this](const int ithr, const int nthr) {
            this->kernel(ithr, nthr);
        });
    }
}

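// Computes scratch_gates = src_layer * w_layer + src_iter * w_iter for the
// M x N blocks assigned to this thread. The layer GEMM initializes each block
// (beta = 0 kernels) and the iteration GEMM accumulates on top (beta = 1).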
template <typename src_t, typename weights_t, typename scratch_t,
        typename gemm_acc_t>
void brgemm_dst_layer_iter_t<src_t, weights_t, scratch_t, gemm_acc_t>::kernel(
        const int ithr, const int nthr) const {
    using namespace cpu::rnn_utils;

    int start = 0, end = 0;
    balance211(work_amount_, nthr, ithr, start, end);

    const bool is_amx = rnn_.is_cell_int8_amx() || rnn_.is_cell_bf16_amx();
    gemm_acc_t *const amx_buffer = is_amx
            ? amx_scratchpad_ + rnn_.m_block * rnn_.n_block * ithr
            : nullptr;
    const int max_K_Block = nstl::max(rnn_.KB1_blocks + 1,
            nstl::max(rnn_.KBproj_blocks + 1, rnn_.KB2_blocks + 1));
    brgemm_batch_element_t *const addr_batch
            = addr_batch_global_ + ithr * max_K_Block;

    const char *pallete_buff_iter = nullptr;
    const char *pallete_buff_layer = nullptr;
    const char *pallete_buff_iter_k_tail = nullptr;
    const char *pallete_buff_layer_k_tail = nullptr;

    dim_t nb_i = 0, mb = 0;
    switch (rnn_.loop_order) {
        case brgemm_rnn_execute_loop_order_t::mblk_nblk:
            nd_iterator_init(start, mb, m_blocking_, nb_i, n_blocking_);
            break;
        case brgemm_rnn_execute_loop_order_t::nblk_mblk:
            nd_iterator_init(start, nb_i, n_blocking_, mb, m_blocking_);
            break;
        default: assert(!"unsupported loop order");
    }

    amx_tile_configuration_loader_t load_cfg_if_needed;

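    // Each iteration handles one (m_block x n_block) output tile: select the
    // main or N-tail kernels (and AMX palettes), run the K-blocked batched
    // GEMMs per gate, handle the K tails, then apply the fused post-GEMM
    // unless it is deferred to an unfused pass.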
    while (start < end) {
        const auto m = mb * rnn_.m_block;
        const auto nb = (rnn_.unfused_post_gemm) ? nb_i / rnn_.n_gates : nb_i;
        const auto n = nb * rnn_.n_block;
        const auto g_unfused
                = (rnn_.unfused_post_gemm) ? nb_i % rnn_.n_gates : 0;

        const auto *const Al_m = Al_ + m * LDAl_;
        const auto *const Ai_m = Ai_ + m * LDAi_;
        const auto *const Bl_n = Bl_ + nb * Bl_n_offset_;
        const auto *const Bi_n = Bi_ + nb * Bi_n_offset_;
        auto *const C_n = C_ + m * rnn_.LDC + n;

        const brgemm_kernel_t *brgemm_kernel_layer_b0
                = brgemm_kernel_layer_main_;
        const brgemm_kernel_t *brgemm_kernel_iter = brgemm_kernel_iter_main_;
        const brgemm_kernel_t *brgemm_kernel_layer_k_tail
                = brgemm_kernel_layer_k_tail_;
        const brgemm_kernel_t *brgemm_kernel_iter_k_tail
                = brgemm_kernel_iter_k_tail_;

        if (is_amx) {
            pallete_buff_iter = pallete_buff_iter_main_;
            pallete_buff_layer = pallete_buff_layer_main_;
            pallete_buff_iter_k_tail = pallete_buff_iter_k_tail_;
            pallete_buff_layer_k_tail = pallete_buff_layer_k_tail_;
        }

        const bool do_n_tail = (n + rnn_.n_block) > rnn_.N;
        if (do_n_tail) {
            brgemm_kernel_layer_b0 = brgemm_kernel_layer_n_tail_;
            brgemm_kernel_iter = brgemm_kernel_iter_n_tail_;
            brgemm_kernel_layer_k_tail = brgemm_kernel_layer_nk_tail_;
            brgemm_kernel_iter_k_tail = brgemm_kernel_iter_nk_tail_;

            if (is_amx) {
                pallete_buff_iter = pallete_buff_iter_n_tail_;
                pallete_buff_layer = pallete_buff_layer_n_tail_;
                pallete_buff_iter_k_tail = pallete_buff_iter_nk_tail_;
                pallete_buff_layer_k_tail = pallete_buff_layer_nk_tail_;
            }
        }

        for (int g = 0; g < n_gates_; g++) {
            const int lg = g + g_unfused;
            const auto *const Bl_g = Bl_n + lg * Bl_g_offset_;
            const auto *const Bi_g = Bi_n + lg * Bi_g_offset_;
            auto *const C_g = C_n + lg * rnn_.N;

            if (need_gemm_layer_) {
                if (is_amx) load_cfg_if_needed(pallete_buff_layer);
                for (int i = 0; i < rnn_.KB1_blocks; i++) {
                    addr_batch[i].ptr.A = Al_m + i * rnn_.k1_block;
                    addr_batch[i].ptr.B = Bl_g + i * Bl_kb_offset_;
                }
                brgemm_kernel_execute(brgemm_kernel_layer_b0, rnn_.KB1_blocks,
                        addr_batch, reinterpret_cast<void *>(C_g), amx_buffer);
            }

            for (int i = 0; i < rnn_.KB2_blocks; i++) {
                addr_batch[i].ptr.A = Ai_m + i * rnn_.k2_block;
                addr_batch[i].ptr.B = Bi_g + i * Bi_kb_offset_;
            }
            if (is_amx) load_cfg_if_needed(pallete_buff_iter);
            brgemm_kernel_execute(brgemm_kernel_iter, rnn_.KB2_blocks,
                    addr_batch, reinterpret_cast<void *>(C_g), amx_buffer);
        }

        if (rnn_.k1_tail && need_gemm_layer_) {
            if (is_amx) load_cfg_if_needed(pallete_buff_layer_k_tail);

            for (int g = 0; g < n_gates_; g++) {
                const int lg = g + g_unfused;
                const auto *const Bl_g = Bl_n + lg * Bl_g_offset_;
                auto *const C_g = C_n + lg * rnn_.N;

                addr_batch[0].ptr.A = Al_m + Al_k_tail_offset_;
                addr_batch[0].ptr.B = Bl_g + Bl_k_tail_offset_;
                brgemm_kernel_execute(brgemm_kernel_layer_k_tail, 1, addr_batch,
                        reinterpret_cast<void *>(C_g), amx_buffer);
            }
        }

        if (rnn_.k2_tail) {
            if (is_amx) load_cfg_if_needed(pallete_buff_iter_k_tail);

            for (int g = 0; g < n_gates_; g++) {
                const int lg = g + g_unfused;
                const auto *const Bi_g = Bi_n + lg * Bi_g_offset_;
                auto *const C_g = C_n + lg * rnn_.N;

                addr_batch[0].ptr.A = Ai_m + Ai_k_tail_offset_;
                addr_batch[0].ptr.B = Bi_g + Bi_k_tail_offset_;
                brgemm_kernel_execute(brgemm_kernel_iter_k_tail, 1, addr_batch,
                        reinterpret_cast<void *>(C_g), amx_buffer);
            }
        }

        if (!rnn_.unfused_post_gemm) {
            const auto block_step = (do_n_tail ? rnn_.n_tail : rnn_.n_block)
                    * sizeof(scratch_t);
            fused_postgemm_(m, n, nb_i, Ai_m, C_n, block_step);
        }

        ++start;
        switch (rnn_.loop_order) {
            case brgemm_rnn_execute_loop_order_t::mblk_nblk:
                nd_iterator_step(mb, m_blocking_, nb_i, n_blocking_);
                break;
            case brgemm_rnn_execute_loop_order_t::nblk_mblk:
                nd_iterator_step(nb_i, n_blocking_, mb, m_blocking_);
                break;
            default: assert(!"unsupported loop order");
        }
    }
}

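// Fused variant: because src_layer and src_iter share geometry here, the
// layer K blocks and the iteration K blocks are chained into one brgemm
// batch, so each gate needs a single kernel call instead of two.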
template <typename src_t, typename weights_t, typename scratch_t,
        typename gemm_acc_t>
void brgemm_dst_layer_iter_t<src_t, weights_t, scratch_t,
        gemm_acc_t>::kernel_fused_iter_layer(const int ithr,
        const int nthr) const {
    using namespace cpu::rnn_utils;

    int start = 0, end = 0;
    balance211(work_amount_, nthr, ithr, start, end);

    const bool is_amx = rnn_.is_cell_int8_amx() || rnn_.is_cell_bf16_amx();
    gemm_acc_t *const amx_buffer = is_amx
            ? amx_scratchpad_ + rnn_.m_block * rnn_.n_block * ithr
            : nullptr;
    const int max_K_Block = 2
            * nstl::max(rnn_.KB1_blocks + 1,
                    nstl::max(rnn_.KBproj_blocks + 1, rnn_.KB2_blocks + 1));
    brgemm_batch_element_t *const addr_batch
            = addr_batch_global_ + ithr * max_K_Block;

    const char *pallete_buff = nullptr;
    const char *pallete_buff_k_tail = nullptr;

    dim_t nb_i = 0, mb = 0;
    switch (rnn_.loop_order) {
        case brgemm_rnn_execute_loop_order_t::mblk_nblk:
            nd_iterator_init(start, mb, m_blocking_, nb_i, n_blocking_);
            break;
        case brgemm_rnn_execute_loop_order_t::nblk_mblk:
            nd_iterator_init(start, nb_i, n_blocking_, mb, m_blocking_);
            break;
        default: assert(!"unsupported loop order");
    }

    amx_tile_configuration_loader_t load_cfg_if_needed;
    const auto LDA = LDAl_;
    const auto B_n_offset = Bl_n_offset_;
    const auto B_g_offset = Bl_g_offset_;
    const auto B_kb_offset = Bl_kb_offset_;
    const auto KB_blocks
            = (need_gemm_layer_ ? rnn_.KB1_blocks : 0) + rnn_.KB2_blocks;
    const auto KB_blocks_tail = (need_gemm_layer_ ? 1 : 0) + 1;
    const auto A_k_tail_offset = Al_k_tail_offset_;
    const auto B_k_tail_offset = Bl_k_tail_offset_;

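    // The fused path is taken only when the layer and iteration GEMMs share
    // blocking (sic == slc, equal leading strides), so the layer-side strides
    // and offsets above are reused for the iteration operands as well, and
    // KB_blocks counts the combined batch length.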
    while (start < end) {
        const auto m = mb * rnn_.m_block;
        const auto nb = (rnn_.unfused_post_gemm) ? nb_i / rnn_.n_gates : nb_i;
        const auto n = nb * rnn_.n_block;
        const auto g_unfused
                = (rnn_.unfused_post_gemm) ? nb_i % rnn_.n_gates : 0;

        const auto *const Al_m = Al_ + m * LDA;
        const auto *const Ai_m = Ai_ + m * LDA;
        const auto *const Bl_n = Bl_ + nb * B_n_offset;
        const auto *const Bi_n = Bi_ + nb * B_n_offset;
        auto *const C_n = C_ + m * rnn_.LDC + n;

        const brgemm_kernel_t *brgemm_kernel = brgemm_kernel_layer_main_;
        const brgemm_kernel_t *brgemm_kernel_k_tail
                = brgemm_kernel_layer_k_tail_;

        if (is_amx) {
            pallete_buff = pallete_buff_layer_main_;
            pallete_buff_k_tail = pallete_buff_layer_k_tail_;
        }

        const bool do_n_tail = (n + rnn_.n_block) > rnn_.N;
        if (do_n_tail) {
            brgemm_kernel = brgemm_kernel_layer_n_tail_;
            brgemm_kernel_k_tail = brgemm_kernel_layer_nk_tail_;

            if (is_amx) {
                pallete_buff = pallete_buff_layer_n_tail_;
                pallete_buff_k_tail = pallete_buff_layer_nk_tail_;
            }
        }

        for (int g = 0; g < n_gates_; g++) {
            const int lg = g + g_unfused;
            const auto *const Bl_g = Bl_n + lg * B_g_offset;
            const auto *const Bi_g = Bi_n + lg * B_g_offset;
            auto *const C_g = C_n + lg * rnn_.N;
            int batch_idx = 0;

            if (need_gemm_layer_) {
                for (; batch_idx < rnn_.KB1_blocks; batch_idx++) {
                    addr_batch[batch_idx].ptr.A
                            = Al_m + batch_idx * rnn_.k1_block;
                    addr_batch[batch_idx].ptr.B
                            = Bl_g + batch_idx * B_kb_offset;
                }
            }

            int iter_idx = 0;
            for (; batch_idx < KB_blocks; batch_idx++) {
                addr_batch[batch_idx].ptr.A = Ai_m + iter_idx * rnn_.k2_block;
                addr_batch[batch_idx].ptr.B = Bi_g + iter_idx * B_kb_offset;
                iter_idx++;
            }

            if (is_amx) load_cfg_if_needed(pallete_buff);
            brgemm_kernel_execute(brgemm_kernel, KB_blocks, addr_batch,
                    reinterpret_cast<void *>(C_g), amx_buffer);
        }

        if (rnn_.k2_tail) {
            for (int g = 0; g < n_gates_; g++) {
                const int lg = g + g_unfused;
                auto *const C_g = C_n + lg * rnn_.N;

                int batch_idx = 0;
                if (need_gemm_layer_) {
                    const auto *const Bl_g = Bl_n + lg * B_g_offset;
                    addr_batch[batch_idx].ptr.A = Al_m + A_k_tail_offset;
                    addr_batch[batch_idx].ptr.B = Bl_g + B_k_tail_offset;
                    batch_idx++;
                }
                const auto *const Bi_g = Bi_n + lg * B_g_offset;
                addr_batch[batch_idx].ptr.A = Ai_m + A_k_tail_offset;
                addr_batch[batch_idx].ptr.B = Bi_g + B_k_tail_offset;

                if (is_amx) load_cfg_if_needed(pallete_buff_k_tail);
                brgemm_kernel_execute(brgemm_kernel_k_tail, KB_blocks_tail,
                        addr_batch, reinterpret_cast<void *>(C_g), amx_buffer);
            }
        }

        if (!rnn_.unfused_post_gemm) {
            const auto block_step = (do_n_tail ? rnn_.n_tail : rnn_.n_block)
                    * sizeof(scratch_t);
            fused_postgemm_(m, n, nb_i, Ai_m, C_n, block_step);
        }

        ++start;
        switch (rnn_.loop_order) {
            case brgemm_rnn_execute_loop_order_t::mblk_nblk:
                nd_iterator_step(mb, m_blocking_, nb_i, n_blocking_);
                break;
            case brgemm_rnn_execute_loop_order_t::nblk_mblk:
                nd_iterator_step(nb_i, n_blocking_, mb, m_blocking_);
                break;
            default: assert(!"unsupported loop order");
        }
    }
}

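// Projection GEMM: multiplies the projection input (proj_ht) by w_projection.
// For f32 cells the output is addressed with the destination-layer stride,
// otherwise with the scratch-gates stride.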
template <typename src_t, typename weights_t, typename gemm_acc_t>
brgemm_dst_proj_t<src_t, weights_t, gemm_acc_t>::brgemm_dst_proj_t(
        const ref_rnn_brgemm_t &rnn_brgemm, const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, const src_t *proj_ht,
        const weights_t *w_projection, gemm_acc_t *output,
        gemm_acc_t *amx_scratchpad,
        x64::brgemm_batch_element_t *addr_batch_global,
        const postgemm_fused_t &fused_postgemm)
    : rnn_brgemm_(rnn_brgemm)
    , rnn_(rnn)
    , proj_desc_idx_(rnn_.is_cell_dt_f32()
                      ? rnn_.dst_brgemm_desc(cell_position, true)
                      : 0)
    , A_(proj_ht)
    , B_(w_projection)
    , C_(output)
    , LDC_(rnn_.is_cell_dt_f32() ? rnn_.dst_layer_ld(cell_position, true)
                                 : rnn_.scratch_gates_ld)
    , max_nthr_(rnn_.nthr)
    , work_amount_proj_(rnn_.Nproj_blocks * rnn_.M_blocks)
    , B_n_offset_(rnn_.Kprojpadded * rnn_.n_block)
    , Bp_kb_offset_(rnn_.kproj_block * rnn_.n_block)
    , amx_scratchpad_(amx_scratchpad)
    , addr_batch_global_(addr_batch_global)
    , brgemm_kernel_main_(rnn_brgemm_.kernel_proj_b0_[proj_desc_idx_].get())
    , brgemm_kernel_n_tail_(
              rnn_brgemm_.kernel_proj_N_tail_b0_[proj_desc_idx_].get())
    , brgemm_kernel_nk_tail_(
              rnn_brgemm_.kernel_proj_NK_tail_b1_[proj_desc_idx_].get())
    , brgemm_kernel_k_tail_(
              rnn_brgemm_.kernel_proj_K_tail_b1_[proj_desc_idx_].get())
    , fused_postgemm_(fused_postgemm) {}

template <typename src_t, typename weights_t, typename gemm_acc_t>
void brgemm_dst_proj_t<src_t, weights_t, gemm_acc_t>::execute() const {
    parallel(max_nthr_, [this](const int ithr, const int nthr) {
        this->kernel(ithr, nthr);
    });
}

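// Per-thread projection work: each thread processes its share of M x Nproj
// blocks. On AMX the K dimension is batched over KBproj blocks (plus an
// optional K tail with its own palette); otherwise a single batch element
// covers the whole K range.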
template <typename src_t, typename weights_t, typename gemm_acc_t>
void brgemm_dst_proj_t<src_t, weights_t, gemm_acc_t>::kernel(
        const int ithr, const int nthr) const {
    using namespace cpu::rnn_utils;

    int start = 0, end = 0;
    balance211(work_amount_proj_, nthr, ithr, start, end);
    const bool is_amx = rnn_.is_cell_int8_amx() || rnn_.is_cell_bf16_amx();
    const int max_K_Block = nstl::max(rnn_.KB1_blocks + 1,
            nstl::max(rnn_.KBproj_blocks + 1, rnn_.KB2_blocks + 1));
    auto *const amx_buffer = is_amx
            ? amx_scratchpad_ + rnn_.m_block * rnn_.n_block * ithr
            : nullptr;
    auto *const addr_batch = is_amx ? addr_batch_global_ + ithr * max_K_Block
                                    : addr_batch_global_ + ithr;
    amx_tile_configuration_loader_t load_cfg_if_needed;

    if (is_amx) load_cfg_if_needed(rnn_brgemm_.pallete_buff_proj_);

    int nb = 0, mb = 0;
    switch (rnn_.loop_order) {
        case brgemm_rnn_execute_loop_order_t::mblk_nblk:
            nd_iterator_init(start, mb, rnn_.M_blocks, nb, rnn_.Nproj_blocks);
            break;
        case brgemm_rnn_execute_loop_order_t::nblk_mblk:
            nd_iterator_init(start, nb, rnn_.Nproj_blocks, mb, rnn_.M_blocks);
            break;
        default: assert(!"unsupported loop order");
    }

    while (start < end) {
        const int n = nb * rnn_.n_block;
        const int m = mb * rnn_.m_block;
        const bool do_n_tail = (n + rnn_.n_block) > rnn_.Nproj;
        const int block_step = ((do_n_tail) ? rnn_.nproj_tail : rnn_.n_block)
                * sizeof(src_t);

        const auto *const Ap_m = A_ + m * rnn_.LDAproj;
        const auto *const Bp_n = B_ + nb * B_n_offset_;
        auto *const Cp_n = C_ + m * LDC_ + n;

        const brgemm_kernel_t *const brgemm_kernel_proj_b0
                = do_n_tail ? brgemm_kernel_n_tail_ : brgemm_kernel_main_;

        if (is_amx) {
            if (do_n_tail)
                load_cfg_if_needed(rnn_brgemm_.pallete_buff_nproj_tail_);
            for (int k = 0; k < rnn_.KBproj_blocks; k++) {
                addr_batch[k].ptr.A = Ap_m + k * rnn_.kproj_block;
                addr_batch[k].ptr.B = Bp_n + k * Bp_kb_offset_;
            }
            brgemm_kernel_execute(brgemm_kernel_proj_b0, rnn_.KBproj_blocks,
                    addr_batch, reinterpret_cast<void *>(Cp_n), amx_buffer);

            if (rnn_.kproj_tail) {
                const brgemm_kernel_t *brgemm_kernel_proj_tail;
                const char *tail_cfg_kproj, *tail_recfg;
                if (do_n_tail) {
                    tail_cfg_kproj = rnn_brgemm_.pallete_buff_nkproj_tail_;
                    tail_recfg = rnn_brgemm_.pallete_buff_nproj_tail_;
                    brgemm_kernel_proj_tail = brgemm_kernel_nk_tail_;
                } else {
                    tail_cfg_kproj = rnn_brgemm_.pallete_buff_kproj_tail_;
                    tail_recfg = rnn_brgemm_.pallete_buff_proj_;
                    brgemm_kernel_proj_tail = brgemm_kernel_k_tail_;
                }
                load_cfg_if_needed(tail_cfg_kproj);
                addr_batch[0].ptr.A
                        = Ap_m + rnn_.KBproj_blocks * rnn_.kproj_block;
                addr_batch[0].ptr.B = Bp_n
                        + rnn_.KBproj_blocks * rnn_.kproj_block * rnn_.n_block;
                brgemm_kernel_execute(brgemm_kernel_proj_tail, 1, addr_batch,
                        reinterpret_cast<void *>(Cp_n), amx_buffer);
                load_cfg_if_needed(tail_recfg);
            }
        } else {
            addr_batch[0].ptr.A = Ap_m;
            addr_batch[0].ptr.B = Bp_n;
            brgemm_kernel_execute(brgemm_kernel_proj_b0, 1, addr_batch,
                    reinterpret_cast<void *>(Cp_n), amx_buffer);
        }

        if (!rnn_.unfused_post_gemm) {
            fused_postgemm_(m, n, Cp_n, block_step);
        }

        ++start;
        switch (rnn_.loop_order) {
            case brgemm_rnn_execute_loop_order_t::mblk_nblk:
                nd_iterator_step(mb, rnn_.M_blocks, nb, rnn_.Nproj_blocks);
                break;
            case brgemm_rnn_execute_loop_order_t::nblk_mblk:
                nd_iterator_step(nb, rnn_.Nproj_blocks, mb, rnn_.M_blocks);
                break;
            default: assert(!"unsupported loop order");
        }
    }
}

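// GRU cell GEMMs, executed in two passes over the N blocks: the first pass
// runs the layer GEMM for all gates and the iteration GEMM (w_iter0) for the
// first n_gates - 1 gates, followed by the part-1 postgemm; the second pass
// computes the last gate from d_layer and w_iter1 and finishes with the
// part-2 postgemm.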
template <typename src_t, typename weights_t, typename scratch_t,
        typename gemm_acc_t>
brgemm_gru_t<src_t, weights_t, scratch_t, gemm_acc_t>::brgemm_gru_t(
        const ref_rnn_brgemm_t &rnn_brgemm, const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, const src_t *src_iter,
        const src_t *src_layer, weights_t *w_iter0, weights_t *w_iter1,
        weights_t *w_layer, src_t *d_layer, scratch_t *scratch_gates,
        scratch_t *scratch_cell, gemm_acc_t *amx_scratchpad,
        x64::brgemm_batch_element_t *addr_batch_global,
        const postgemm_fused_t &fused_postgemm_part1,
        const postgemm_fused_t &fused_postgemm_part2)
    : rnn_brgemm_(rnn_brgemm)
    , rnn_(rnn)
    , need_gemm_layer_(rnn_.need_gemm_layer(cell_position))
    , layer_desc_idx_(rnn_.layer_brgemm_desc(cell_position))
    , iter_desc_idx_(rnn_.iter_brgemm_desc(cell_position))
    , iter_part2_desc_idx_(rnn_.iter_part2_brgemm_desc(cell_position))
    , Al_(src_layer)
    , Ai_(src_iter)
    , Bl_(w_layer)
    , Bi_(w_iter0)
    , Bi2_(w_iter1)
    , C_gates_(scratch_gates)
    , C_cell_(scratch_cell)
    , Dl_(d_layer)
    , LDAl_(rnn_.src_layer_ld(cell_position))
    , LDAi_(rnn_.src_iter_ld(cell_position))
    , max_nthr_(rnn_.nthr)
    , n_blocking_((rnn_.unfused_post_gemm) ? rnn_.N_blocks * rnn_.n_gates
                                           : rnn_.N_blocks)
    , m_blocking_(rnn_.M_blocks)
    , work_amount_(m_blocking_)
    , Bl_n_offset_(rnn_.K1padded * rnn_.n_block)
    , Bi_n_offset_(rnn_.K2padded * rnn_.n_block)
    , Bl_g_offset_(rnn_.N_blocks * Bl_n_offset_)
    , Bi_g_offset_(rnn_.N_blocks * Bi_n_offset_)
    , Al_k_tail_offset_(rnn_.KB1_blocks * rnn_.k1_block)
    , Ai_k_tail_offset_(rnn_.KB2_blocks * rnn_.k2_block)
    , Bl_kb_offset_(rnn_.k1_block * rnn_.n_block)
    , Bi_kb_offset_(rnn_.k2_block * rnn_.n_block)
    , Bl_k_tail_offset_(rnn_.KB1_blocks * rnn_.k1_block * rnn_.n_block)
    , Bi_k_tail_offset_(rnn_.KB2_blocks * rnn_.k2_block * rnn_.n_block)
    , n_gates_(rnn.unfused_post_gemm ? 1 : rnn.n_gates)
    , brgemm_kernel_iter_p0_main_(need_gemm_layer_
                      ? rnn_brgemm_.kernel_iter_b1_[iter_desc_idx_].get()
                      : rnn_brgemm_.kernel_iter_b0_[iter_desc_idx_].get())
    , brgemm_kernel_iter_p0_n_tail_(need_gemm_layer_
                      ? rnn_brgemm_.kernel_iter_N_tail_b1_[iter_desc_idx_].get()
                      : rnn_brgemm_.kernel_iter_N_tail_b0_[iter_desc_idx_]
                                .get())
    , brgemm_kernel_iter_p0_k_tail_(
              rnn_brgemm_.kernel_iter_K2_tail_b1_[iter_desc_idx_].get())
    , brgemm_kernel_iter_p0_nk_tail_(
              rnn_brgemm_.kernel_iter_NK2_tail_b1_[iter_desc_idx_].get())
    , brgemm_kernel_iter_p1_main_(
              rnn_brgemm_.kernel_iter_p2_b1_[iter_part2_desc_idx_].get())
    , brgemm_kernel_iter_p1_n_tail_(
              rnn_brgemm_.kernel_iter_p2_N_tail_b1_[iter_part2_desc_idx_].get())
    , brgemm_kernel_iter_p1_k_tail_(
              rnn_brgemm_.kernel_iter_p2_K2_tail_b1_[iter_part2_desc_idx_]
                      .get())
    , brgemm_kernel_iter_p1_nk_tail_(
              rnn_brgemm_.kernel_iter_p2_NK2_tail_b1_[iter_part2_desc_idx_]
                      .get())
    , brgemm_kernel_layer_main_(
              rnn_brgemm_.kernel_layer_b0_[layer_desc_idx_].get())
    , brgemm_kernel_layer_n_tail_(
              rnn_brgemm_.kernel_layer_N_tail_b0_[layer_desc_idx_].get())
    , brgemm_kernel_layer_k_tail_(
              rnn_brgemm_.kernel_layer_K1_tail_b1_[layer_desc_idx_].get())
    , brgemm_kernel_layer_nk_tail_(
              rnn_brgemm_.kernel_layer_NK1_tail_b1_[layer_desc_idx_].get())
    , pallete_buff_iter_main_(rnn.k1_block == rnn.k2_block
                      ? rnn_brgemm_.pallete_buff_layer_
                      : rnn_brgemm_.pallete_buff_iter_)
    , pallete_buff_iter_n_tail_(rnn.k1_block == rnn.k2_block
                      ? rnn_brgemm_.pallete_buff_layer_n_tail_
                      : rnn_brgemm_.pallete_buff_iter_n_tail_)
    , pallete_buff_iter_k_tail_(rnn.k1_tail == rnn.k2_tail
                      ? rnn_brgemm_.pallete_buff_k1_tail_
                      : rnn_brgemm_.pallete_buff_k2_tail_)
    , pallete_buff_iter_nk_tail_(rnn.k1_tail == rnn.k2_tail
                      ? rnn_brgemm_.pallete_buff_nk1_tail_
                      : rnn_brgemm_.pallete_buff_nk2_tail_)
    , pallete_buff_layer_main_(rnn_brgemm_.pallete_buff_layer_)
    , pallete_buff_layer_n_tail_(rnn_brgemm_.pallete_buff_layer_n_tail_)
    , pallete_buff_layer_k_tail_(rnn_brgemm_.pallete_buff_k1_tail_)
    , pallete_buff_layer_nk_tail_(rnn_brgemm_.pallete_buff_nk1_tail_)
    , amx_scratchpad_(amx_scratchpad)
    , addr_batch_global_(addr_batch_global)
    , fused_postgemm_part1_(fused_postgemm_part1)
    , fused_postgemm_part2_(fused_postgemm_part2)
    , is_fused_layer_iter_brgemm_(true) {}

template <typename src_t, typename weights_t, typename scratch_t,
        typename gemm_acc_t>
void brgemm_gru_t<src_t, weights_t, scratch_t, gemm_acc_t>::execute() const {
    assert(is_fused_layer_iter_brgemm_);
    parallel(max_nthr_, [this](const int ithr, const int nthr) {
        this->kernel(ithr, nthr);
    });
}

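// Work is partitioned across threads over M blocks only; each thread sweeps
// all N blocks for its rows, so the part-2 postgemm, which consumes the full
// gate row, can run once the last N block of that row is done.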
template <typename src_t, typename weights_t, typename scratch_t,
        typename gemm_acc_t>
void brgemm_gru_t<src_t, weights_t, scratch_t, gemm_acc_t>::kernel(
        const int ithr, const int nthr) const {
    int start = 0, end = 0;
    balance211(work_amount_, nthr, ithr, start, end);

    const bool is_amx = rnn_.is_cell_int8_amx() || rnn_.is_cell_bf16_amx();
    gemm_acc_t *const amx_buffer = is_amx
            ? amx_scratchpad_ + rnn_.m_block * rnn_.n_block * ithr
            : nullptr;
    const int max_K_Block = 2
            * nstl::max(rnn_.KB1_blocks + 1,
                    nstl::max(rnn_.KBproj_blocks + 1, rnn_.KB2_blocks + 1));
    brgemm_batch_element_t *const addr_batch
            = addr_batch_global_ + ithr * max_K_Block;

    const char *pallete_buff_layer = nullptr;
    const char *pallete_buff_layer_k_tail = nullptr;
    const char *pallete_buff_iter = nullptr;
    const char *pallete_buff_iter_k_tail = nullptr;

    amx_tile_configuration_loader_t load_cfg_if_needed;
    while (start < end) {
        dim_t mb = start;
        const auto m = mb * rnn_.m_block;
        const auto *const Al_m = Al_ + m * LDAl_;
        const auto *const Ai_m = Ai_ + m * LDAi_;
        const auto *const Ai2_m = Dl_ + m * LDAl_;

        for (dim_t nb_i = 0; nb_i < n_blocking_; nb_i++) {
            const auto nb
                    = (rnn_.unfused_post_gemm) ? nb_i / rnn_.n_gates : nb_i;
            const auto n = nb * rnn_.n_block;

            const auto *const Bl_n = Bl_ + nb * Bl_n_offset_;
            const auto *const Bi_n = Bi_ + nb * Bi_n_offset_;
            auto *const C_gates_n = C_gates_ + m * rnn_.LDC + n;
            auto *const C_cell_n = C_cell_ + m * rnn_.LDC + n;

            const brgemm_kernel_t *brgemm_kernel_layer
                    = brgemm_kernel_layer_main_;
            const brgemm_kernel_t *brgemm_kernel_layer_k_tail
                    = brgemm_kernel_layer_k_tail_;
            const brgemm_kernel_t *brgemm_kernel_iter_p0
                    = brgemm_kernel_iter_p0_main_;
            const brgemm_kernel_t *brgemm_kernel_iter_p0_k_tail
                    = brgemm_kernel_iter_p0_k_tail_;

            if (is_amx) {
                pallete_buff_layer = pallete_buff_layer_main_;
                pallete_buff_layer_k_tail = pallete_buff_layer_k_tail_;
                pallete_buff_iter = pallete_buff_iter_main_;
                pallete_buff_iter_k_tail = pallete_buff_iter_k_tail_;
            }

            const bool do_n_tail = (n + rnn_.n_block) > rnn_.N;
            if (do_n_tail) {
                brgemm_kernel_layer = brgemm_kernel_layer_n_tail_;
                brgemm_kernel_layer_k_tail = brgemm_kernel_layer_nk_tail_;
                brgemm_kernel_iter_p0 = brgemm_kernel_iter_p0_n_tail_;
                brgemm_kernel_iter_p0_k_tail = brgemm_kernel_iter_p0_nk_tail_;

                if (is_amx) {
                    pallete_buff_layer = pallete_buff_layer_n_tail_;
                    pallete_buff_layer_k_tail = pallete_buff_layer_nk_tail_;
                    pallete_buff_iter = pallete_buff_iter_n_tail_;
                    pallete_buff_iter_k_tail = pallete_buff_iter_nk_tail_;
                }
            }

            if (need_gemm_layer_) {
                if (is_amx) load_cfg_if_needed(pallete_buff_layer);
                for (int g = 0; g < n_gates_; g++) {
                    const auto *const Bl_g = Bl_n + g * Bl_g_offset_;
                    auto *const C_gates_g = C_gates_n + g * rnn_.N;

                    for (int batch_idx = 0; batch_idx < rnn_.KB1_blocks;
                            batch_idx++) {
                        addr_batch[batch_idx].ptr.A
                                = Al_m + batch_idx * rnn_.k1_block;
                        addr_batch[batch_idx].ptr.B
                                = Bl_g + batch_idx * Bl_kb_offset_;
                    }
                    brgemm_kernel_execute(brgemm_kernel_layer, rnn_.KB1_blocks,
                            addr_batch, reinterpret_cast<void *>(C_gates_g),
                            amx_buffer);
                }
            }

            if (need_gemm_layer_ && rnn_.k1_tail > 0) {
                if (is_amx) load_cfg_if_needed(pallete_buff_layer_k_tail);
                for (int g = 0; g < n_gates_; g++) {
                    const auto *const Bl_g = Bl_n + g * Bl_g_offset_;
                    auto *const C_gates_g = C_gates_n + g * rnn_.N;

                    addr_batch[0].ptr.A
                            = Al_m + rnn_.KB1_blocks * rnn_.k1_block;
                    addr_batch[0].ptr.B
                            = Bl_g + rnn_.KB1_blocks * Bl_kb_offset_;
                    brgemm_kernel_execute(brgemm_kernel_layer_k_tail, 1,
                            addr_batch, reinterpret_cast<void *>(C_gates_g),
                            amx_buffer);
                }
            }
            if (is_amx) load_cfg_if_needed(pallete_buff_iter);
            for (int g = 0; g < n_gates_ - 1; g++) {
                const auto *const Bi_g = Bi_n + g * Bi_g_offset_;
                auto *const C_gates_g = C_gates_n + g * rnn_.N;

                for (int batch_idx = 0; batch_idx < rnn_.KB2_blocks;
                        batch_idx++) {
                    addr_batch[batch_idx].ptr.A
                            = Ai_m + batch_idx * rnn_.k2_block;
                    addr_batch[batch_idx].ptr.B
                            = Bi_g + batch_idx * Bi_kb_offset_;
                }

                brgemm_kernel_execute(brgemm_kernel_iter_p0, rnn_.KB2_blocks,
                        addr_batch, reinterpret_cast<void *>(C_gates_g),
                        amx_buffer);
            }

            if (rnn_.k2_tail > 0) {
                if (is_amx) load_cfg_if_needed(pallete_buff_iter_k_tail);
                for (int g = 0; g < n_gates_ - 1; g++) {
                    const auto *const Bi_g = Bi_n + g * Bi_g_offset_;
                    auto *const C_gates_g = C_gates_n + g * rnn_.N;

                    addr_batch[0].ptr.A
                            = Ai_m + rnn_.KB2_blocks * rnn_.k2_block;
                    addr_batch[0].ptr.B
                            = Bi_g + rnn_.KB2_blocks * Bi_kb_offset_;

                    brgemm_kernel_execute(brgemm_kernel_iter_p0_k_tail, 1,
                            addr_batch, reinterpret_cast<void *>(C_gates_g),
                            amx_buffer);
                }
            }

            if (!rnn_.unfused_post_gemm) {
                const auto block_step
                        = (do_n_tail ? rnn_.n_tail : rnn_.n_block);
                fused_postgemm_part1_(
                        m, n, nb_i, Ai_m + n, C_gates_n, C_cell_n, block_step);
            }
        }

        for (dim_t nb_i = 0; nb_i < n_blocking_; nb_i++) {
            const auto nb
                    = (rnn_.unfused_post_gemm) ? nb_i / rnn_.n_gates : nb_i;
            const auto n = nb * rnn_.n_block;

            const auto *const Bi2_n = Bi2_ + nb * Bi_n_offset_;
            auto *const C_gates_n = C_gates_ + m * rnn_.LDC + n;

            const brgemm_kernel_t *brgemm_kernel_iter_p1
                    = brgemm_kernel_iter_p1_main_;
            const brgemm_kernel_t *brgemm_kernel_iter_p1_k_tail
                    = brgemm_kernel_iter_p1_k_tail_;

            if (is_amx) {
                pallete_buff_iter = pallete_buff_iter_main_;
                pallete_buff_iter_k_tail = pallete_buff_iter_k_tail_;
            }

            const bool do_n_tail = (n + rnn_.n_block) > rnn_.N;
            if (do_n_tail) {
                brgemm_kernel_iter_p1 = brgemm_kernel_iter_p1_n_tail_;
                brgemm_kernel_iter_p1_k_tail = brgemm_kernel_iter_p1_nk_tail_;

                if (is_amx) {
                    pallete_buff_iter = pallete_buff_iter_n_tail_;
                    pallete_buff_iter_k_tail = pallete_buff_iter_nk_tail_;
                }
            }

            if (is_amx) load_cfg_if_needed(pallete_buff_iter);
            for (int g = 0; g < 1; g++) {
                const auto *const Bi2_g = Bi2_n + g * Bi_g_offset_;
                auto *const C_gates_g = C_gates_n + (n_gates_ - 1) * rnn_.N;

                for (int batch_idx = 0; batch_idx < rnn_.KB2_blocks;
                        batch_idx++) {
                    addr_batch[batch_idx].ptr.A
                            = Ai2_m + batch_idx * rnn_.k2_block;
                    addr_batch[batch_idx].ptr.B
                            = Bi2_g + batch_idx * Bi_kb_offset_;
                }

                brgemm_kernel_execute(brgemm_kernel_iter_p1, rnn_.KB2_blocks,
                        addr_batch, reinterpret_cast<void *>(C_gates_g),
                        amx_buffer);
            }

            if (rnn_.k2_tail > 0) {
                if (is_amx) load_cfg_if_needed(pallete_buff_iter_k_tail);
                for (int g = 0; g < 1; g++) {
                    const auto *const Bi2_g = Bi2_n + g * Bi_g_offset_;
                    auto *const C_gates_g = C_gates_n + (n_gates_ - 1) * rnn_.N;

                    addr_batch[0].ptr.A
                            = Ai2_m + rnn_.KB2_blocks * rnn_.k2_block;
                    addr_batch[0].ptr.B
                            = Bi2_g + rnn_.KB2_blocks * Bi_kb_offset_;

                    brgemm_kernel_execute(brgemm_kernel_iter_p1_k_tail, 1,
                            addr_batch, reinterpret_cast<void *>(C_gates_g),
                            amx_buffer);
                }
            }
            if (!rnn_.unfused_post_gemm && nb_i == n_blocking_ - 1) {
                fused_postgemm_part2_(m, 0, 0, Ai_m, C_gates_ + m * rnn_.LDC,
                        C_cell_ + m * rnn_.LDC, rnn_.N);
            }
        }
        ++start;
    }
}

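// Merged-layer GEMM: computes only the layer part, scratch_gates = src_layer
// * w_layer, using the merged-M blocking (Mlayermerged_blocks); there is no
// iteration GEMM in this primitive.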
template <typename src_t, typename weights_t, typename scratch_t,
        typename gemm_acc_t>
brgemm_merged_layer_t<src_t, weights_t, scratch_t,
        gemm_acc_t>::brgemm_merged_layer_t(const ref_rnn_brgemm_t &rnn_brgemm,
        const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, const src_t *src_layer,
        weights_t *w_layer, scratch_t *scratch_gates,
        gemm_acc_t *amx_scratchpad,
        x64::brgemm_batch_element_t *addr_batch_global)
    : rnn_brgemm_(rnn_brgemm)
    , rnn_(rnn)
    , layer_desc_idx_(rnn_.layer_brgemm_desc(cell_position))
    , Al_(src_layer)
    , Bl_(w_layer)
    , C_(scratch_gates)
    , LDAl_(rnn_.src_layer_ld(cell_position))
    , max_nthr_(rnn_.nthr)
    , n_blocking_((rnn_.unfused_post_gemm) ? rnn_.N_blocks * rnn_.n_gates
                                           : rnn_.N_blocks)
    , m_blocking_(rnn_.Mlayermerged_blocks)
    , work_amount_(n_blocking_ * m_blocking_)
    , Bl_n_offset_(rnn_.K1padded * rnn_.n_block)
    , Bl_g_offset_(rnn_.N_blocks * Bl_n_offset_)
    , Al_k_tail_offset_(rnn_.KB1_blocks * rnn_.k1_block)
    , Bl_kb_offset_(rnn_.k1_block * rnn_.n_block)
    , Bl_k_tail_offset_(rnn_.KB1_blocks * rnn_.k1_block * rnn_.n_block)
    , n_gates_(rnn.unfused_post_gemm ? 1 : rnn.n_gates)
    , brgemm_kernel_layer_main_(
              rnn_brgemm_.kernel_layermerged_b0_[layer_desc_idx_].get())
    , brgemm_kernel_layer_n_tail_(
              rnn_brgemm_.kernel_layermerged_N_tail_b0_[layer_desc_idx_].get())
    , brgemm_kernel_layer_k_tail_(
              rnn_brgemm_.kernel_layermerged_K1_tail_b1_[layer_desc_idx_].get())
    , brgemm_kernel_layer_nk_tail_(
              rnn_brgemm_.kernel_layermerged_NK1_tail_b1_[layer_desc_idx_]
                      .get())
    , pallete_buff_layer_main_(rnn_brgemm_.pallete_buff_layermerged_)
    , pallete_buff_layer_n_tail_(rnn_brgemm_.pallete_buff_layermerged_n_tail_)
    , pallete_buff_layer_k_tail_(rnn_brgemm_.pallete_buff_layermerged_k1_tail_)
    , pallete_buff_layer_nk_tail_(
              rnn_brgemm_.pallete_buff_layermerged_nk1_tail_)
    , amx_scratchpad_(amx_scratchpad)
    , addr_batch_global_(addr_batch_global) {}

template <typename src_t, typename weights_t, typename scratch_t,
        typename gemm_acc_t>
void brgemm_merged_layer_t<src_t, weights_t, scratch_t, gemm_acc_t>::execute()
        const {
    parallel(max_nthr_, [this](const int ithr, const int nthr) {
        this->kernel(ithr, nthr);
    });
}

template <typename src_t, typename weights_t, typename scratch_t,
        typename gemm_acc_t>
void brgemm_merged_layer_t<src_t, weights_t, scratch_t, gemm_acc_t>::kernel(
        const int ithr, const int nthr) const {
    using namespace cpu::rnn_utils;

    int start = 0, end = 0;
    balance211(work_amount_, nthr, ithr, start, end);

    const bool is_amx = rnn_.is_cell_int8_amx() || rnn_.is_cell_bf16_amx();
    const auto m_block = rnn_.mlayermerged_block;
    gemm_acc_t *const amx_buffer = is_amx
            ? amx_scratchpad_ + m_block * rnn_.n_block * ithr
            : nullptr;
    const int max_K_Block = rnn_.KB1_blocks + 1;
    brgemm_batch_element_t *const addr_batch
            = addr_batch_global_ + ithr * max_K_Block;

    const char *pallete_buff_layer = nullptr;
    const char *pallete_buff_layer_k_tail = nullptr;

    dim_t nb_i = 0, mb = 0;
    switch (rnn_.loop_order) {
        case brgemm_rnn_execute_loop_order_t::mblk_nblk:
            nd_iterator_init(start, mb, m_blocking_, nb_i, n_blocking_);
            break;
        case brgemm_rnn_execute_loop_order_t::nblk_mblk:
            nd_iterator_init(start, nb_i, n_blocking_, mb, m_blocking_);
            break;
        default: assert(!"unsupported loop order");
    }

    amx_tile_configuration_loader_t load_cfg_if_needed;

    while (start < end) {
        const auto m = mb * m_block;
        const auto nb = (rnn_.unfused_post_gemm) ? nb_i / rnn_.n_gates : nb_i;
        const auto n = nb * rnn_.n_block;
        const auto g_unfused
                = (rnn_.unfused_post_gemm) ? nb_i % rnn_.n_gates : 0;

        const auto *const Al_m = Al_ + m * LDAl_;
        const auto *const Bl_n = Bl_ + nb * Bl_n_offset_;
        auto *const C_n = C_ + m * rnn_.LDC + n;

        const brgemm_kernel_t *brgemm_kernel_layer_b0
                = brgemm_kernel_layer_main_;
        const brgemm_kernel_t *brgemm_kernel_layer_k_tail
                = brgemm_kernel_layer_k_tail_;

        if (is_amx) {
            pallete_buff_layer = pallete_buff_layer_main_;
            pallete_buff_layer_k_tail = pallete_buff_layer_k_tail_;
        }

        const bool do_n_tail = (n + rnn_.n_block) > rnn_.N;
        if (do_n_tail) {
            brgemm_kernel_layer_b0 = brgemm_kernel_layer_n_tail_;
            brgemm_kernel_layer_k_tail = brgemm_kernel_layer_nk_tail_;

            if (is_amx) {
                pallete_buff_layer = pallete_buff_layer_n_tail_;
                pallete_buff_layer_k_tail = pallete_buff_layer_nk_tail_;
            }
        }

        for (int g = 0; g < n_gates_; g++) {
            const int lg = g + g_unfused;
            const auto *const Bl_g = Bl_n + lg * Bl_g_offset_;
            auto *const C_g = C_n + lg * rnn_.N;

            if (is_amx) load_cfg_if_needed(pallete_buff_layer);
            for (int i = 0; i < rnn_.KB1_blocks; i++) {
                addr_batch[i].ptr.A = Al_m + i * rnn_.k1_block;
                addr_batch[i].ptr.B = Bl_g + i * Bl_kb_offset_;
            }
            brgemm_kernel_execute(brgemm_kernel_layer_b0, rnn_.KB1_blocks,
                    addr_batch, reinterpret_cast<void *>(C_g), amx_buffer);
        }

        if (rnn_.k1_tail) {
            if (is_amx) load_cfg_if_needed(pallete_buff_layer_k_tail);

            for (int g = 0; g < n_gates_; g++) {
                const int lg = g + g_unfused;
                const auto *const Bl_g = Bl_n + lg * Bl_g_offset_;
                auto *const C_g = C_n + lg * rnn_.N;

                addr_batch[0].ptr.A = Al_m + Al_k_tail_offset_;
                addr_batch[0].ptr.B = Bl_g + Bl_k_tail_offset_;
                brgemm_kernel_execute(brgemm_kernel_layer_k_tail, 1, addr_batch,
                        reinterpret_cast<void *>(C_g), amx_buffer);
            }
        }

        ++start;
        switch (rnn_.loop_order) {
            case brgemm_rnn_execute_loop_order_t::mblk_nblk:
                nd_iterator_step(mb, m_blocking_, nb_i, n_blocking_);
                break;
            case brgemm_rnn_execute_loop_order_t::nblk_mblk:
                nd_iterator_step(nb_i, n_blocking_, mb, m_blocking_);
                break;
            default: assert(!"unsupported loop order");
        }
    }
}

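// Explicit instantiations for the supported (src, weights, scratch, acc)
// data type combinations: u8/s8 int8, f32, and bf16 cells.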
template class brgemm_dst_layer_iter_t<uint8_t, int8_t, int32_t, int32_t>;
template class brgemm_dst_layer_iter_t<int8_t, int8_t, int32_t, int32_t>;
template class brgemm_dst_layer_iter_t<float, float, float, float>;
template class brgemm_dst_layer_iter_t<bfloat16_t, bfloat16_t, float, float>;

template class brgemm_dst_proj_t<float, float, float>;
template class brgemm_dst_proj_t<bfloat16_t, bfloat16_t, float>;
template class brgemm_dst_proj_t<int8_t, int8_t, int32_t>;
template class brgemm_dst_proj_t<uint8_t, int8_t, int32_t>;

template class brgemm_gru_t<uint8_t, int8_t, int32_t, int32_t>;
template class brgemm_gru_t<int8_t, int8_t, int32_t, int32_t>;
template class brgemm_gru_t<float, float, float, float>;
template class brgemm_gru_t<bfloat16_t, bfloat16_t, float, float>;

template class brgemm_merged_layer_t<uint8_t, int8_t, int32_t, int32_t>;
template class brgemm_merged_layer_t<int8_t, int8_t, int32_t, int32_t>;
template class brgemm_merged_layer_t<float, float, float, float>;
template class brgemm_merged_layer_t<bfloat16_t, bfloat16_t, float, float>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl