1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "brgemm_cell_common_fwd.hpp" |
18 | |
19 | #include "common/dnnl_thread.hpp" |
20 | #include "common/utils.hpp" |
21 | #include "cpu/x64/rnn/brgemm_cell_common_utils.hpp" |
22 | |
23 | using namespace dnnl::impl::utils; |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
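// brgemm_dst_layer_iter_t computes the gates scratch of a cell:
//   C = src_layer * w_layer + src_iter * w_iter (per gate),
// blocked over M (minibatch) and N (output channels). The constructor only
// caches pointers, leading dimensions, blocking sizes and the brgemm
// kernels/palettes selected for this cell position; when the layer and iter
// GEMMs share shapes and leading dimensions (sic == slc, LDAi == LDAl) they
// are later fused into a single batched kernel call
// (is_fused_layer_iter_brgemm_).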
30 | template <typename src_t, typename weights_t, typename scratch_t, |
31 | typename gemm_acc_t> |
32 | brgemm_dst_layer_iter_t<src_t, weights_t, scratch_t, |
33 | gemm_acc_t>::brgemm_dst_layer_iter_t(const ref_rnn_brgemm_t &rnn_brgemm, |
34 | const rnn_utils::rnn_conf_t &rnn, |
35 | rnn_utils::cell_position_t cell_position, const src_t *src_iter, |
36 | const src_t *src_layer, weights_t *w_iter, weights_t *w_layer, |
37 | scratch_t *scratch_gates, gemm_acc_t *amx_scratchpad, |
38 | x64::brgemm_batch_element_t *addr_batch_global, |
39 | const postgemm_fused_t &fused_postgemm) |
40 | : rnn_brgemm_(rnn_brgemm) |
41 | , rnn_(rnn) |
42 | , need_gemm_layer_(rnn_.need_gemm_layer(cell_position)) |
43 | , layer_desc_idx_(rnn_.layer_brgemm_desc(cell_position)) |
44 | , iter_desc_idx_(rnn_.iter_brgemm_desc(cell_position)) |
45 | , Al_(src_layer) |
46 | , Ai_(src_iter) |
47 | , Bl_(w_layer) |
48 | , Bi_(w_iter) |
49 | , C_(scratch_gates) |
50 | , LDAl_(rnn_.src_layer_ld(cell_position)) |
51 | , LDAi_(rnn_.src_iter_ld(cell_position)) |
52 | , max_nthr_(rnn_.nthr) |
53 | , n_blocking_((rnn_.unfused_post_gemm) ? rnn_.N_blocks * rnn_.n_gates |
54 | : rnn_.N_blocks) |
55 | , m_blocking_(rnn_.M_blocks) |
56 | , work_amount_(n_blocking_ * m_blocking_) |
57 | , Bl_n_offset_(rnn_.K1padded * rnn_.n_block) |
58 | , Bi_n_offset_(rnn_.K2padded * rnn_.n_block) |
59 | , Bl_g_offset_(rnn_.N_blocks * Bl_n_offset_) |
60 | , Bi_g_offset_(rnn_.N_blocks * Bi_n_offset_) |
61 | , Al_k_tail_offset_(rnn_.KB1_blocks * rnn_.k1_block) |
62 | , Ai_k_tail_offset_(rnn_.KB2_blocks * rnn_.k2_block) |
63 | , Bl_kb_offset_(rnn_.k1_block * rnn_.n_block) |
64 | , Bi_kb_offset_(rnn_.k2_block * rnn_.n_block) |
65 | , Bl_k_tail_offset_(rnn_.KB1_blocks * rnn_.k1_block * rnn_.n_block) |
66 | , Bi_k_tail_offset_(rnn_.KB2_blocks * rnn_.k2_block * rnn_.n_block) |
67 | , n_gates_(rnn.unfused_post_gemm ? 1 : rnn.n_gates) |
68 | , brgemm_kernel_iter_main_( |
69 | rnn_brgemm_.kernel_iter_b1_[iter_desc_idx_].get()) |
70 | , brgemm_kernel_iter_n_tail_( |
71 | rnn_brgemm_.kernel_iter_N_tail_b1_[iter_desc_idx_].get()) |
72 | , brgemm_kernel_iter_k_tail_( |
73 | rnn_brgemm_.kernel_iter_K2_tail_b1_[iter_desc_idx_].get()) |
74 | , brgemm_kernel_iter_nk_tail_( |
75 | rnn_brgemm_.kernel_iter_NK2_tail_b1_[iter_desc_idx_].get()) |
76 | , brgemm_kernel_layer_main_( |
77 | rnn_brgemm_.kernel_layer_b0_[layer_desc_idx_].get()) |
78 | , brgemm_kernel_layer_n_tail_( |
79 | rnn_brgemm_.kernel_layer_N_tail_b0_[layer_desc_idx_].get()) |
80 | , brgemm_kernel_layer_k_tail_( |
81 | rnn_brgemm_.kernel_layer_K1_tail_b1_[layer_desc_idx_].get()) |
82 | , brgemm_kernel_layer_nk_tail_( |
83 | rnn_brgemm_.kernel_layer_NK1_tail_b1_[layer_desc_idx_].get()) |
84 | , pallete_buff_iter_main_(rnn.k1_block == rnn.k2_block && need_gemm_layer_ |
85 | ? rnn_brgemm_.pallete_buff_layer_ |
86 | : rnn_brgemm_.pallete_buff_iter_) |
87 | , pallete_buff_iter_n_tail_(rnn.k1_block == rnn.k2_block && need_gemm_layer_ |
88 | ? rnn_brgemm_.pallete_buff_layer_n_tail_ |
89 | : rnn_brgemm_.pallete_buff_iter_n_tail_) |
90 | , pallete_buff_iter_k_tail_(rnn.k1_tail == rnn.k2_tail && need_gemm_layer_ |
91 | ? rnn_brgemm_.pallete_buff_k1_tail_ |
92 | : rnn_brgemm_.pallete_buff_k2_tail_) |
93 | , pallete_buff_iter_nk_tail_(rnn.k1_tail == rnn.k2_tail && need_gemm_layer_ |
94 | ? rnn_brgemm_.pallete_buff_nk1_tail_ |
95 | : rnn_brgemm_.pallete_buff_nk2_tail_) |
96 | , pallete_buff_layer_main_(rnn_brgemm_.pallete_buff_layer_) |
97 | , pallete_buff_layer_n_tail_(rnn_brgemm_.pallete_buff_layer_n_tail_) |
98 | , pallete_buff_layer_k_tail_(rnn_brgemm_.pallete_buff_k1_tail_) |
99 | , pallete_buff_layer_nk_tail_(rnn_brgemm_.pallete_buff_nk1_tail_) |
100 | , amx_scratchpad_(amx_scratchpad) |
101 | , addr_batch_global_(addr_batch_global) |
102 | , fused_postgemm_(fused_postgemm) |
103 | , is_fused_layer_iter_brgemm_( |
104 | rnn_.sic == rnn_.slc && LDAi_ == LDAl_ && need_gemm_layer_) {} |
105 | |
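// Dispatches to the fused layer+iter kernel when both GEMMs can be chained
// into one brgemm batch, otherwise to the generic kernel; either way the
// work is spread across up to max_nthr_ threads.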
106 | template <typename src_t, typename weights_t, typename scratch_t, |
107 | typename gemm_acc_t> |
108 | void brgemm_dst_layer_iter_t<src_t, weights_t, scratch_t, gemm_acc_t>::execute() |
109 | const { |
110 | if (is_fused_layer_iter_brgemm_) { |
111 | parallel(max_nthr_, [this](const int ithr, const int nthr) { |
112 | this->kernel_fused_iter_layer(ithr, nthr); |
113 | }); |
114 | } else { |
115 | parallel(max_nthr_, [this](const int ithr, const int nthr) { |
116 | this->kernel(ithr, nthr); |
117 | }); |
118 | } |
119 | } |
120 | |
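// Generic (non-fused) kernel: each thread takes a contiguous slice of the
// n_blocking_ x m_blocking_ work items (balance211) and walks it in the
// configured loop order. On AMX, tile palettes are reloaded lazily through
// amx_tile_configuration_loader_t, and a per-thread accumulation buffer is
// carved out of amx_scratchpad_.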
121 | template <typename src_t, typename weights_t, typename scratch_t, |
122 | typename gemm_acc_t> |
123 | void brgemm_dst_layer_iter_t<src_t, weights_t, scratch_t, gemm_acc_t>::kernel( |
124 | const int ithr, const int nthr) const { |
125 | using namespace cpu::rnn_utils; |
126 | |
127 | int start = 0, end = 0; |
128 | balance211(work_amount_, nthr, ithr, start, end); |
129 | |
130 | const bool is_amx = rnn_.is_cell_int8_amx() || rnn_.is_cell_bf16_amx(); |
131 | gemm_acc_t *const amx_buffer = is_amx |
132 | ? amx_scratchpad_ + rnn_.m_block * rnn_.n_block * ithr |
133 | : nullptr; |
134 | const int max_K_Block = nstl::max(rnn_.KB1_blocks + 1, |
135 | nstl::max(rnn_.KBproj_blocks + 1, rnn_.KB2_blocks + 1)); |
136 | brgemm_batch_element_t *const addr_batch |
137 | = addr_batch_global_ + ithr * max_K_Block; |
138 | |
139 | const char *pallete_buff_iter = nullptr; |
140 | const char *pallete_buff_layer = nullptr; |
141 | const char *pallete_buff_iter_k_tail = nullptr; |
142 | const char *pallete_buff_layer_k_tail = nullptr; |
143 | |
144 | dim_t nb_i = 0, mb = 0; |
145 | switch (rnn_.loop_order) { |
146 | case brgemm_rnn_execute_loop_order_t::mblk_nblk: |
147 | nd_iterator_init(start, mb, m_blocking_, nb_i, n_blocking_); |
148 | break; |
149 | case brgemm_rnn_execute_loop_order_t::nblk_mblk: |
150 | nd_iterator_init(start, nb_i, n_blocking_, mb, m_blocking_); |
151 | break; |
        default: assert(!"unsupported loop order");
153 | } |
154 | |
155 | amx_tile_configuration_loader_t load_cfg_if_needed; |
156 | |
157 | while (start < end) { |
158 | const auto m = mb * rnn_.m_block; |
159 | const auto nb = (rnn_.unfused_post_gemm) ? nb_i / rnn_.n_gates : nb_i; |
160 | const auto n = nb * rnn_.n_block; |
161 | const auto g_unfused |
162 | = (rnn_.unfused_post_gemm) ? nb_i % rnn_.n_gates : 0; |
163 | |
164 | const auto *const Al_m = Al_ + m * LDAl_; |
165 | const auto *const Ai_m = Ai_ + m * LDAi_; |
166 | const auto *const Bl_n = Bl_ + nb * Bl_n_offset_; |
167 | const auto *const Bi_n = Bi_ + nb * Bi_n_offset_; |
168 | auto *const C_n = C_ + m * rnn_.LDC + n; |
169 | |
170 | const brgemm_kernel_t *brgemm_kernel_layer_b0 |
171 | = brgemm_kernel_layer_main_; |
172 | const brgemm_kernel_t *brgemm_kernel_iter = brgemm_kernel_iter_main_; |
173 | const brgemm_kernel_t *brgemm_kernel_layer_k_tail |
174 | = brgemm_kernel_layer_k_tail_; |
175 | const brgemm_kernel_t *brgemm_kernel_iter_k_tail |
176 | = brgemm_kernel_iter_k_tail_; |
177 | |
178 | if (is_amx) { |
179 | pallete_buff_iter = pallete_buff_iter_main_; |
180 | pallete_buff_layer = pallete_buff_layer_main_; |
181 | pallete_buff_iter_k_tail = pallete_buff_iter_k_tail_; |
182 | pallete_buff_layer_k_tail = pallete_buff_layer_k_tail_; |
183 | } |
184 | |
185 | const bool do_n_tail = (n + rnn_.n_block) > rnn_.N; |
186 | if (do_n_tail) { |
187 | brgemm_kernel_layer_b0 = brgemm_kernel_layer_n_tail_; |
188 | brgemm_kernel_iter = brgemm_kernel_iter_n_tail_; |
189 | brgemm_kernel_layer_k_tail = brgemm_kernel_layer_nk_tail_; |
190 | brgemm_kernel_iter_k_tail = brgemm_kernel_iter_nk_tail_; |
191 | |
192 | if (is_amx) { |
193 | pallete_buff_iter = pallete_buff_iter_n_tail_; |
194 | pallete_buff_layer = pallete_buff_layer_n_tail_; |
195 | pallete_buff_iter_k_tail = pallete_buff_iter_nk_tail_; |
196 | pallete_buff_layer_k_tail = pallete_buff_layer_nk_tail_; |
197 | } |
198 | } |
199 | |
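        // Full K blocks: for every gate, accumulate the layer GEMM
        // (KB1 blocks of k1_block) into C first with the beta = 0 kernel,
        // then the iter GEMM (KB2 blocks of k2_block) on top of it.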
200 | for (int g = 0; g < n_gates_; g++) { |
201 | const int lg = g + g_unfused; |
202 | const auto *const Bl_g = Bl_n + lg * Bl_g_offset_; |
203 | const auto *const Bi_g = Bi_n + lg * Bi_g_offset_; |
204 | auto *const C_g = C_n + lg * rnn_.N; |
205 | |
206 | if (need_gemm_layer_) { |
207 | if (is_amx) load_cfg_if_needed(pallete_buff_layer); |
208 | for (int i = 0; i < rnn_.KB1_blocks; i++) { |
209 | addr_batch[i].ptr.A = Al_m + i * rnn_.k1_block; |
210 | addr_batch[i].ptr.B = Bl_g + i * Bl_kb_offset_; |
211 | } |
212 | brgemm_kernel_execute(brgemm_kernel_layer_b0, rnn_.KB1_blocks, |
213 | addr_batch, reinterpret_cast<void *>(C_g), amx_buffer); |
214 | } |
215 | |
216 | for (int i = 0; i < rnn_.KB2_blocks; i++) { |
217 | addr_batch[i].ptr.A = Ai_m + i * rnn_.k2_block; |
218 | addr_batch[i].ptr.B = Bi_g + i * Bi_kb_offset_; |
219 | } |
220 | if (is_amx) load_cfg_if_needed(pallete_buff_iter); |
221 | brgemm_kernel_execute(brgemm_kernel_iter, rnn_.KB2_blocks, |
222 | addr_batch, reinterpret_cast<void *>(C_g), amx_buffer); |
223 | } |
224 | |
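        // K-dimension remainders are accumulated with dedicated
        // single-batch tail kernels (and, on AMX, tail tile palettes).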
225 | if (rnn_.k1_tail && need_gemm_layer_) { |
226 | if (is_amx) load_cfg_if_needed(pallete_buff_layer_k_tail); |
227 | |
228 | for (int g = 0; g < n_gates_; g++) { |
229 | const int lg = g + g_unfused; |
230 | const auto *const Bl_g = Bl_n + lg * Bl_g_offset_; |
231 | auto *const C_g = C_n + lg * rnn_.N; |
232 | |
233 | addr_batch[0].ptr.A = Al_m + Al_k_tail_offset_; |
234 | addr_batch[0].ptr.B = Bl_g + Bl_k_tail_offset_; |
235 | brgemm_kernel_execute(brgemm_kernel_layer_k_tail, 1, addr_batch, |
236 | reinterpret_cast<void *>(C_g), amx_buffer); |
237 | } |
238 | } |
239 | |
240 | if (rnn_.k2_tail) { |
241 | if (is_amx) load_cfg_if_needed(pallete_buff_iter_k_tail); |
242 | |
243 | for (int g = 0; g < n_gates_; g++) { |
244 | const int lg = g + g_unfused; |
245 | const auto *const Bi_g = Bi_n + lg * Bi_g_offset_; |
246 | auto *const C_g = C_n + lg * rnn_.N; |
247 | |
248 | addr_batch[0].ptr.A = Ai_m + Ai_k_tail_offset_; |
249 | addr_batch[0].ptr.B = Bi_g + Bi_k_tail_offset_; |
250 | brgemm_kernel_execute(brgemm_kernel_iter_k_tail, 1, addr_batch, |
251 | reinterpret_cast<void *>(C_g), amx_buffer); |
252 | } |
253 | } |
254 | |
255 | if (!rnn_.unfused_post_gemm) { |
256 | const auto block_step = (do_n_tail ? rnn_.n_tail : rnn_.n_block) |
257 | * sizeof(scratch_t); |
258 | fused_postgemm_(m, n, nb_i, Ai_m, C_n, block_step); |
259 | } |
260 | |
261 | ++start; |
262 | switch (rnn_.loop_order) { |
263 | case brgemm_rnn_execute_loop_order_t::mblk_nblk: |
264 | nd_iterator_step(mb, m_blocking_, nb_i, n_blocking_); |
265 | break; |
266 | case brgemm_rnn_execute_loop_order_t::nblk_mblk: |
267 | nd_iterator_step(nb_i, n_blocking_, mb, m_blocking_); |
268 | break; |
            default: assert(!"unsupported loop order");
270 | } |
271 | } |
272 | } |
273 | |
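// Fused variant: since sic == slc and the leading dimensions match, the
// layer and iter weight blocks share the same offsets, so the A/B pointer
// batches of both GEMMs are chained and executed by one brgemm call per
// gate (KB1 + KB2 batch elements, plus a combined K-tail batch).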
274 | template <typename src_t, typename weights_t, typename scratch_t, |
275 | typename gemm_acc_t> |
276 | void brgemm_dst_layer_iter_t<src_t, weights_t, scratch_t, |
277 | gemm_acc_t>::kernel_fused_iter_layer(const int ithr, |
278 | const int nthr) const { |
279 | using namespace cpu::rnn_utils; |
280 | |
281 | int start = 0, end = 0; |
282 | balance211(work_amount_, nthr, ithr, start, end); |
283 | |
284 | const bool is_amx = rnn_.is_cell_int8_amx() || rnn_.is_cell_bf16_amx(); |
285 | gemm_acc_t *const amx_buffer = is_amx |
286 | ? amx_scratchpad_ + rnn_.m_block * rnn_.n_block * ithr |
287 | : nullptr; |
288 | const int max_K_Block = 2 |
289 | * nstl::max(rnn_.KB1_blocks + 1, |
290 | nstl::max(rnn_.KBproj_blocks + 1, rnn_.KB2_blocks + 1)); |
291 | brgemm_batch_element_t *const addr_batch |
292 | = addr_batch_global_ + ithr * max_K_Block; |
293 | |
294 | const char *pallete_buff = nullptr; |
295 | const char *pallete_buff_k_tail = nullptr; |
296 | |
297 | dim_t nb_i = 0, mb = 0; |
298 | switch (rnn_.loop_order) { |
299 | case brgemm_rnn_execute_loop_order_t::mblk_nblk: |
300 | nd_iterator_init(start, mb, m_blocking_, nb_i, n_blocking_); |
301 | break; |
302 | case brgemm_rnn_execute_loop_order_t::nblk_mblk: |
303 | nd_iterator_init(start, nb_i, n_blocking_, mb, m_blocking_); |
304 | break; |
        default: assert(!"unsupported loop order");
306 | } |
307 | |
308 | amx_tile_configuration_loader_t load_cfg_if_needed; |
309 | const auto LDA = LDAl_; |
310 | const auto B_n_offset = Bl_n_offset_; |
311 | const auto B_g_offset = Bl_g_offset_; |
312 | const auto B_kb_offset = Bl_kb_offset_; |
313 | const auto KB_blocks |
314 | = (need_gemm_layer_ ? rnn_.KB1_blocks : 0) + rnn_.KB2_blocks; |
315 | const auto KB_blocks_tail = (need_gemm_layer_ ? 1 : 0) + 1; |
316 | const auto A_k_tail_offset = Al_k_tail_offset_; |
317 | const auto B_k_tail_offset = Bl_k_tail_offset_; |
318 | |
319 | while (start < end) { |
320 | const auto m = mb * rnn_.m_block; |
321 | const auto nb = (rnn_.unfused_post_gemm) ? nb_i / rnn_.n_gates : nb_i; |
322 | const auto n = nb * rnn_.n_block; |
323 | const auto g_unfused |
324 | = (rnn_.unfused_post_gemm) ? nb_i % rnn_.n_gates : 0; |
325 | |
326 | const auto *const Al_m = Al_ + m * LDA; |
327 | const auto *const Ai_m = Ai_ + m * LDA; |
328 | const auto *const Bl_n = Bl_ + nb * B_n_offset; |
329 | const auto *const Bi_n = Bi_ + nb * B_n_offset; |
330 | auto *const C_n = C_ + m * rnn_.LDC + n; |
331 | |
332 | const brgemm_kernel_t *brgemm_kernel = brgemm_kernel_layer_main_; |
333 | const brgemm_kernel_t *brgemm_kernel_k_tail |
334 | = brgemm_kernel_layer_k_tail_; |
335 | |
336 | if (is_amx) { |
337 | pallete_buff = pallete_buff_layer_main_; |
338 | pallete_buff_k_tail = pallete_buff_layer_k_tail_; |
339 | } |
340 | |
341 | const bool do_n_tail = (n + rnn_.n_block) > rnn_.N; |
342 | if (do_n_tail) { |
343 | brgemm_kernel = brgemm_kernel_layer_n_tail_; |
344 | brgemm_kernel_k_tail = brgemm_kernel_layer_nk_tail_; |
345 | |
346 | if (is_amx) { |
347 | pallete_buff = pallete_buff_layer_n_tail_; |
348 | pallete_buff_k_tail = pallete_buff_layer_nk_tail_; |
349 | } |
350 | } |
351 | |
352 | for (int g = 0; g < n_gates_; g++) { |
353 | const int lg = g + g_unfused; |
354 | const auto *const Bl_g = Bl_n + lg * B_g_offset; |
355 | const auto *const Bi_g = Bi_n + lg * B_g_offset; |
356 | auto *const C_g = C_n + lg * rnn_.N; |
357 | int batch_idx = 0; |
358 | |
359 | if (need_gemm_layer_) { |
360 | for (; batch_idx < rnn_.KB1_blocks; batch_idx++) { |
361 | addr_batch[batch_idx].ptr.A |
362 | = Al_m + batch_idx * rnn_.k1_block; |
363 | addr_batch[batch_idx].ptr.B |
364 | = Bl_g + batch_idx * B_kb_offset; |
365 | } |
366 | } |
367 | |
368 | int iter_idx = 0; |
369 | for (; batch_idx < KB_blocks; batch_idx++) { |
370 | addr_batch[batch_idx].ptr.A = Ai_m + iter_idx * rnn_.k2_block; |
371 | addr_batch[batch_idx].ptr.B = Bi_g + iter_idx * B_kb_offset; |
372 | iter_idx++; |
373 | } |
374 | |
375 | if (is_amx) load_cfg_if_needed(pallete_buff); |
376 | brgemm_kernel_execute(brgemm_kernel, KB_blocks, addr_batch, |
377 | reinterpret_cast<void *>(C_g), amx_buffer); |
378 | } |
379 | |
380 | if (rnn_.k2_tail) { |
381 | for (int g = 0; g < n_gates_; g++) { |
382 | const int lg = g + g_unfused; |
383 | auto *const C_g = C_n + lg * rnn_.N; |
384 | |
385 | int batch_idx = 0; |
386 | if (need_gemm_layer_) { |
387 | const auto *const Bl_g = Bl_n + lg * B_g_offset; |
388 | addr_batch[batch_idx].ptr.A = Al_m + A_k_tail_offset; |
389 | addr_batch[batch_idx].ptr.B = Bl_g + B_k_tail_offset; |
390 | batch_idx++; |
391 | } |
392 | const auto *const Bi_g = Bi_n + lg * B_g_offset; |
393 | addr_batch[batch_idx].ptr.A = Ai_m + A_k_tail_offset; |
394 | addr_batch[batch_idx].ptr.B = Bi_g + B_k_tail_offset; |
395 | |
396 | if (is_amx) load_cfg_if_needed(pallete_buff_k_tail); |
397 | brgemm_kernel_execute(brgemm_kernel_k_tail, KB_blocks_tail, |
398 | addr_batch, reinterpret_cast<void *>(C_g), amx_buffer); |
399 | } |
400 | } |
401 | |
402 | if (!rnn_.unfused_post_gemm) { |
403 | const auto block_step = (do_n_tail ? rnn_.n_tail : rnn_.n_block) |
404 | * sizeof(scratch_t); |
405 | fused_postgemm_(m, n, nb_i, Ai_m, C_n, block_step); |
406 | } |
407 | |
408 | ++start; |
409 | switch (rnn_.loop_order) { |
410 | case brgemm_rnn_execute_loop_order_t::mblk_nblk: |
411 | nd_iterator_step(mb, m_blocking_, nb_i, n_blocking_); |
412 | break; |
413 | case brgemm_rnn_execute_loop_order_t::nblk_mblk: |
414 | nd_iterator_step(nb_i, n_blocking_, mb, m_blocking_); |
415 | break; |
            default: assert(!"unsupported loop order");
417 | } |
418 | } |
419 | } |
420 | |
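// brgemm_dst_proj_t computes the projection GEMM of a cell:
//   C = proj_ht * w_projection,
// writing straight to dst_layer for f32 cells (LDC_ taken from
// dst_layer_ld) and to the gates scratch otherwise.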
421 | template <typename src_t, typename weights_t, typename gemm_acc_t> |
422 | brgemm_dst_proj_t<src_t, weights_t, gemm_acc_t>::brgemm_dst_proj_t( |
423 | const ref_rnn_brgemm_t &rnn_brgemm, const rnn_utils::rnn_conf_t &rnn, |
424 | rnn_utils::cell_position_t cell_position, const src_t *proj_ht, |
425 | const weights_t *w_projection, gemm_acc_t *output, |
426 | gemm_acc_t *amx_scratchpad, |
427 | x64::brgemm_batch_element_t *addr_batch_global, |
428 | const postgemm_fused_t &fused_postgemm) |
429 | : rnn_brgemm_(rnn_brgemm) |
430 | , rnn_(rnn) |
431 | , proj_desc_idx_(rnn_.is_cell_dt_f32() |
432 | ? rnn_.dst_brgemm_desc(cell_position, true) |
433 | : 0) |
434 | , A_(proj_ht) |
435 | , B_(w_projection) |
436 | , C_(output) |
437 | , LDC_(rnn_.is_cell_dt_f32() ? rnn_.dst_layer_ld(cell_position, true) |
438 | : rnn_.scratch_gates_ld) |
439 | , max_nthr_(rnn_.nthr) |
440 | , work_amount_proj_(rnn_.Nproj_blocks * rnn_.M_blocks) |
441 | , B_n_offset_(rnn_.Kprojpadded * rnn_.n_block) |
442 | , Bp_kb_offset_(rnn_.kproj_block * rnn_.n_block) |
443 | , amx_scratchpad_(amx_scratchpad) |
444 | , addr_batch_global_(addr_batch_global) |
445 | , brgemm_kernel_main_(rnn_brgemm_.kernel_proj_b0_[proj_desc_idx_].get()) |
446 | , brgemm_kernel_n_tail_( |
447 | rnn_brgemm_.kernel_proj_N_tail_b0_[proj_desc_idx_].get()) |
448 | , brgemm_kernel_nk_tail_( |
449 | rnn_brgemm_.kernel_proj_NK_tail_b1_[proj_desc_idx_].get()) |
450 | , brgemm_kernel_k_tail_( |
451 | rnn_brgemm_.kernel_proj_K_tail_b1_[proj_desc_idx_].get()) |
452 | , fused_postgemm_(fused_postgemm) {} |
453 | |
454 | template <typename src_t, typename weights_t, typename gemm_acc_t> |
455 | void brgemm_dst_proj_t<src_t, weights_t, gemm_acc_t>::execute() const { |
456 | parallel(max_nthr_, [this](const int ithr, const int nthr) { |
457 | this->kernel(ithr, nthr); |
458 | }); |
459 | } |
460 | |
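// On AMX the projection K dimension is processed in KBproj blocks with an
// optional reconfigured kproj tail; on other ISAs a single batch element
// covers the whole K range.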
461 | template <typename src_t, typename weights_t, typename gemm_acc_t> |
462 | void brgemm_dst_proj_t<src_t, weights_t, gemm_acc_t>::kernel( |
463 | const int ithr, const int nthr) const { |
464 | using namespace cpu::rnn_utils; |
465 | |
466 | int start = 0, end = 0; |
467 | balance211(work_amount_proj_, nthr, ithr, start, end); |
468 | const bool is_amx = rnn_.is_cell_int8_amx() || rnn_.is_cell_bf16_amx(); |
469 | const int max_K_Block = nstl::max(rnn_.KB1_blocks + 1, |
470 | nstl::max(rnn_.KBproj_blocks + 1, rnn_.KB2_blocks + 1)); |
471 | auto *const amx_buffer = is_amx |
472 | ? amx_scratchpad_ + rnn_.m_block * rnn_.n_block * ithr |
473 | : nullptr; |
474 | auto *const addr_batch = is_amx ? addr_batch_global_ + ithr * max_K_Block |
475 | : addr_batch_global_ + ithr; |
476 | amx_tile_configuration_loader_t load_cfg_if_needed; |
477 | |
478 | if (is_amx) load_cfg_if_needed(rnn_brgemm_.pallete_buff_proj_); |
479 | |
480 | int nb = 0, mb = 0; |
481 | switch (rnn_.loop_order) { |
482 | case brgemm_rnn_execute_loop_order_t::mblk_nblk: |
483 | nd_iterator_init(start, mb, rnn_.M_blocks, nb, rnn_.Nproj_blocks); |
484 | break; |
485 | case brgemm_rnn_execute_loop_order_t::nblk_mblk: |
486 | nd_iterator_init(start, nb, rnn_.Nproj_blocks, mb, rnn_.M_blocks); |
487 | break; |
        default: assert(!"unsupported loop order");
489 | } |
490 | |
491 | while (start < end) { |
492 | const int n = nb * rnn_.n_block; |
493 | const int m = mb * rnn_.m_block; |
494 | const bool do_n_tail = (n + rnn_.n_block) > rnn_.Nproj; |
495 | const int block_step = ((do_n_tail) ? rnn_.nproj_tail : rnn_.n_block) |
496 | * sizeof(src_t); |
497 | |
498 | const auto *const Ap_m = A_ + m * rnn_.LDAproj; |
499 | const auto *const Bp_n = B_ + nb * B_n_offset_; |
500 | auto *const Cp_n = C_ + m * LDC_ + n; |
501 | |
502 | const brgemm_kernel_t *const brgemm_kernel_proj_b0 |
503 | = do_n_tail ? brgemm_kernel_n_tail_ : brgemm_kernel_main_; |
504 | |
505 | if (is_amx) { |
506 | if (do_n_tail) |
507 | load_cfg_if_needed(rnn_brgemm_.pallete_buff_nproj_tail_); |
508 | for (int k = 0; k < rnn_.KBproj_blocks; k++) { |
509 | addr_batch[k].ptr.A = Ap_m + k * rnn_.kproj_block; |
510 | addr_batch[k].ptr.B = Bp_n + k * Bp_kb_offset_; |
511 | } |
512 | brgemm_kernel_execute(brgemm_kernel_proj_b0, rnn_.KBproj_blocks, |
513 | addr_batch, reinterpret_cast<void *>(Cp_n), amx_buffer); |
514 | |
515 | if (rnn_.kproj_tail) { |
516 | const brgemm_kernel_t *brgemm_kernel_proj_tail; |
517 | const char *tail_cfg_kproj, *tail_recfg; |
518 | if (do_n_tail) { |
519 | tail_cfg_kproj = rnn_brgemm_.pallete_buff_nkproj_tail_; |
520 | tail_recfg = rnn_brgemm_.pallete_buff_nproj_tail_; |
521 | brgemm_kernel_proj_tail = brgemm_kernel_nk_tail_; |
522 | } else { |
523 | tail_cfg_kproj = rnn_brgemm_.pallete_buff_kproj_tail_; |
524 | tail_recfg = rnn_brgemm_.pallete_buff_proj_; |
525 | brgemm_kernel_proj_tail = brgemm_kernel_k_tail_; |
526 | } |
527 | load_cfg_if_needed(tail_cfg_kproj); |
528 | addr_batch[0].ptr.A |
529 | = Ap_m + rnn_.KBproj_blocks * rnn_.kproj_block; |
530 | addr_batch[0].ptr.B = Bp_n |
531 | + rnn_.KBproj_blocks * rnn_.kproj_block * rnn_.n_block; |
532 | brgemm_kernel_execute(brgemm_kernel_proj_tail, 1, addr_batch, |
533 | reinterpret_cast<void *>(Cp_n), amx_buffer); |
534 | load_cfg_if_needed(tail_recfg); |
535 | } |
536 | } else { |
537 | addr_batch[0].ptr.A = Ap_m; |
538 | addr_batch[0].ptr.B = Bp_n; |
539 | brgemm_kernel_execute(brgemm_kernel_proj_b0, 1, addr_batch, |
540 | reinterpret_cast<void *>(Cp_n), amx_buffer); |
541 | } |
542 | |
543 | if (!rnn_.unfused_post_gemm) { |
544 | fused_postgemm_(m, n, Cp_n, block_step); |
545 | } |
546 | |
547 | ++start; |
548 | switch (rnn_.loop_order) { |
549 | case brgemm_rnn_execute_loop_order_t::mblk_nblk: |
550 | nd_iterator_step(mb, rnn_.M_blocks, nb, rnn_.Nproj_blocks); |
551 | break; |
552 | case brgemm_rnn_execute_loop_order_t::nblk_mblk: |
553 | nd_iterator_step(nb, rnn_.Nproj_blocks, mb, rnn_.M_blocks); |
554 | break; |
            default: assert(!"unsupported loop order");
556 | } |
557 | } |
558 | } |
559 | |
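// brgemm_gru_t runs the GRU cell in two chained GEMM passes over each M
// block: pass 1 computes the layer contribution for all gates and the iter
// contribution (w_iter0) for all but the last gate, and its fused post-GEMM
// part 1 materializes the input of the second pass in d_layer; pass 2
// multiplies that input by w_iter1 to produce the last (candidate) gate.
// work_amount_ is m_blocking_ only, because a thread must finish pass 1 for
// all N blocks of an M block before pass 2 can start.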
560 | template <typename src_t, typename weights_t, typename scratch_t, |
561 | typename gemm_acc_t> |
562 | brgemm_gru_t<src_t, weights_t, scratch_t, gemm_acc_t>::brgemm_gru_t( |
563 | const ref_rnn_brgemm_t &rnn_brgemm, const rnn_utils::rnn_conf_t &rnn, |
564 | rnn_utils::cell_position_t cell_position, const src_t *src_iter, |
565 | const src_t *src_layer, weights_t *w_iter0, weights_t *w_iter1, |
566 | weights_t *w_layer, src_t *d_layer, scratch_t *scratch_gates, |
567 | scratch_t *scratch_cell, gemm_acc_t *amx_scratchpad, |
568 | x64::brgemm_batch_element_t *addr_batch_global, |
569 | const postgemm_fused_t &fused_postgemm_part1, |
570 | const postgemm_fused_t &fused_postgemm_part2) |
571 | : rnn_brgemm_(rnn_brgemm) |
572 | , rnn_(rnn) |
573 | , need_gemm_layer_(rnn_.need_gemm_layer(cell_position)) |
574 | , layer_desc_idx_(rnn_.layer_brgemm_desc(cell_position)) |
575 | , iter_desc_idx_(rnn_.iter_brgemm_desc(cell_position)) |
576 | , iter_part2_desc_idx_(rnn_.iter_part2_brgemm_desc(cell_position)) |
577 | , Al_(src_layer) |
578 | , Ai_(src_iter) |
579 | , Bl_(w_layer) |
580 | , Bi_(w_iter0) |
581 | , Bi2_(w_iter1) |
582 | , C_gates_(scratch_gates) |
583 | , C_cell_(scratch_cell) |
584 | , Dl_(d_layer) |
585 | , LDAl_(rnn_.src_layer_ld(cell_position)) |
586 | , LDAi_(rnn_.src_iter_ld(cell_position)) |
587 | , max_nthr_(rnn_.nthr) |
588 | , n_blocking_((rnn_.unfused_post_gemm) ? rnn_.N_blocks * rnn_.n_gates |
589 | : rnn_.N_blocks) |
590 | , m_blocking_(rnn_.M_blocks) |
591 | , work_amount_(m_blocking_) |
592 | , Bl_n_offset_(rnn_.K1padded * rnn_.n_block) |
593 | , Bi_n_offset_(rnn_.K2padded * rnn_.n_block) |
594 | , Bl_g_offset_(rnn_.N_blocks * Bl_n_offset_) |
595 | , Bi_g_offset_(rnn_.N_blocks * Bi_n_offset_) |
596 | , Al_k_tail_offset_(rnn_.KB1_blocks * rnn_.k1_block) |
597 | , Ai_k_tail_offset_(rnn_.KB2_blocks * rnn_.k2_block) |
598 | , Bl_kb_offset_(rnn_.k1_block * rnn_.n_block) |
599 | , Bi_kb_offset_(rnn_.k2_block * rnn_.n_block) |
600 | , Bl_k_tail_offset_(rnn_.KB1_blocks * rnn_.k1_block * rnn_.n_block) |
601 | , Bi_k_tail_offset_(rnn_.KB2_blocks * rnn_.k2_block * rnn_.n_block) |
602 | , n_gates_(rnn.unfused_post_gemm ? 1 : rnn.n_gates) |
603 | , brgemm_kernel_iter_p0_main_(need_gemm_layer_ |
604 | ? rnn_brgemm_.kernel_iter_b1_[iter_desc_idx_].get() |
605 | : rnn_brgemm_.kernel_iter_b0_[iter_desc_idx_].get()) |
606 | , brgemm_kernel_iter_p0_n_tail_(need_gemm_layer_ |
607 | ? rnn_brgemm_.kernel_iter_N_tail_b1_[iter_desc_idx_].get() |
608 | : rnn_brgemm_.kernel_iter_N_tail_b0_[iter_desc_idx_] |
609 | .get()) |
610 | , brgemm_kernel_iter_p0_k_tail_( |
611 | rnn_brgemm_.kernel_iter_K2_tail_b1_[iter_desc_idx_].get()) |
612 | , brgemm_kernel_iter_p0_nk_tail_( |
613 | rnn_brgemm_.kernel_iter_NK2_tail_b1_[iter_desc_idx_].get()) |
614 | , brgemm_kernel_iter_p1_main_( |
615 | rnn_brgemm_.kernel_iter_p2_b1_[iter_part2_desc_idx_].get()) |
616 | , brgemm_kernel_iter_p1_n_tail_( |
617 | rnn_brgemm_.kernel_iter_p2_N_tail_b1_[iter_part2_desc_idx_].get()) |
618 | , brgemm_kernel_iter_p1_k_tail_( |
619 | rnn_brgemm_.kernel_iter_p2_K2_tail_b1_[iter_part2_desc_idx_] |
620 | .get()) |
621 | , brgemm_kernel_iter_p1_nk_tail_( |
622 | rnn_brgemm_.kernel_iter_p2_NK2_tail_b1_[iter_part2_desc_idx_] |
623 | .get()) |
624 | , brgemm_kernel_layer_main_( |
625 | rnn_brgemm_.kernel_layer_b0_[layer_desc_idx_].get()) |
626 | , brgemm_kernel_layer_n_tail_( |
627 | rnn_brgemm_.kernel_layer_N_tail_b0_[layer_desc_idx_].get()) |
628 | , brgemm_kernel_layer_k_tail_( |
629 | rnn_brgemm_.kernel_layer_K1_tail_b1_[layer_desc_idx_].get()) |
630 | , brgemm_kernel_layer_nk_tail_( |
631 | rnn_brgemm_.kernel_layer_NK1_tail_b1_[layer_desc_idx_].get()) |
632 | , pallete_buff_iter_main_(rnn.k1_block == rnn.k2_block |
633 | ? rnn_brgemm_.pallete_buff_layer_ |
634 | : rnn_brgemm_.pallete_buff_iter_) |
635 | , pallete_buff_iter_n_tail_(rnn.k1_block == rnn.k2_block |
636 | ? rnn_brgemm_.pallete_buff_layer_n_tail_ |
637 | : rnn_brgemm_.pallete_buff_iter_n_tail_) |
638 | , pallete_buff_iter_k_tail_(rnn.k1_tail == rnn.k2_tail |
639 | ? rnn_brgemm_.pallete_buff_k1_tail_ |
640 | : rnn_brgemm_.pallete_buff_k2_tail_) |
641 | , pallete_buff_iter_nk_tail_(rnn.k1_tail == rnn.k2_tail |
642 | ? rnn_brgemm_.pallete_buff_nk1_tail_ |
643 | : rnn_brgemm_.pallete_buff_nk2_tail_) |
644 | , pallete_buff_layer_main_(rnn_brgemm_.pallete_buff_layer_) |
645 | , pallete_buff_layer_n_tail_(rnn_brgemm_.pallete_buff_layer_n_tail_) |
646 | , pallete_buff_layer_k_tail_(rnn_brgemm_.pallete_buff_k1_tail_) |
647 | , pallete_buff_layer_nk_tail_(rnn_brgemm_.pallete_buff_nk1_tail_) |
648 | , amx_scratchpad_(amx_scratchpad) |
649 | , addr_batch_global_(addr_batch_global) |
650 | , fused_postgemm_part1_(fused_postgemm_part1) |
651 | , fused_postgemm_part2_(fused_postgemm_part2) |
652 | , is_fused_layer_iter_brgemm_(true) {} |
653 | |
654 | template <typename src_t, typename weights_t, typename scratch_t, |
655 | typename gemm_acc_t> |
656 | void brgemm_gru_t<src_t, weights_t, scratch_t, gemm_acc_t>::execute() const { |
657 | assert(is_fused_layer_iter_brgemm_); |
658 | parallel(max_nthr_, [this](const int ithr, const int nthr) { |
659 | this->kernel(ithr, nthr); |
660 | }); |
661 | } |
662 | |
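// Threads are balanced over M blocks; within an M block, the two passes
// below each sweep all N blocks (see the constructor comment above).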
663 | template <typename src_t, typename weights_t, typename scratch_t, |
664 | typename gemm_acc_t> |
665 | void brgemm_gru_t<src_t, weights_t, scratch_t, gemm_acc_t>::kernel( |
666 | const int ithr, const int nthr) const { |
667 | int start = 0, end = 0; |
668 | balance211(work_amount_, nthr, ithr, start, end); |
669 | |
670 | const bool is_amx = rnn_.is_cell_int8_amx() || rnn_.is_cell_bf16_amx(); |
671 | gemm_acc_t *const amx_buffer = is_amx |
672 | ? amx_scratchpad_ + rnn_.m_block * rnn_.n_block * ithr |
673 | : nullptr; |
674 | const int max_K_Block = 2 |
675 | * nstl::max(rnn_.KB1_blocks + 1, |
676 | nstl::max(rnn_.KBproj_blocks + 1, rnn_.KB2_blocks + 1)); |
677 | brgemm_batch_element_t *const addr_batch |
678 | = addr_batch_global_ + ithr * max_K_Block; |
679 | |
680 | const char *pallete_buff_layer = nullptr; |
681 | const char *pallete_buff_layer_k_tail = nullptr; |
682 | const char *pallete_buff_iter = nullptr; |
683 | const char *pallete_buff_iter_k_tail = nullptr; |
684 | |
685 | amx_tile_configuration_loader_t load_cfg_if_needed; |
686 | while (start < end) { |
687 | dim_t mb = start; |
688 | const auto m = mb * rnn_.m_block; |
689 | const auto *const Al_m = Al_ + m * LDAl_; |
690 | const auto *const Ai_m = Ai_ + m * LDAi_; |
691 | const auto *const Ai2_m = Dl_ + m * LDAl_; |
692 | |
693 | for (dim_t nb_i = 0; nb_i < n_blocking_; nb_i++) { |
694 | const auto nb |
695 | = (rnn_.unfused_post_gemm) ? nb_i / rnn_.n_gates : nb_i; |
696 | const auto n = nb * rnn_.n_block; |
697 | |
698 | const auto *const Bl_n = Bl_ + nb * Bl_n_offset_; |
699 | const auto *const Bi_n = Bi_ + nb * Bi_n_offset_; |
700 | auto *const C_gates_n = C_gates_ + m * rnn_.LDC + n; |
701 | auto *const C_cell_n = C_cell_ + m * rnn_.LDC + n; |
702 | |
703 | const brgemm_kernel_t *brgemm_kernel_layer |
704 | = brgemm_kernel_layer_main_; |
705 | const brgemm_kernel_t *brgemm_kernel_layer_k_tail |
706 | = brgemm_kernel_layer_k_tail_; |
707 | const brgemm_kernel_t *brgemm_kernel_iter_p0 |
708 | = brgemm_kernel_iter_p0_main_; |
709 | const brgemm_kernel_t *brgemm_kernel_iter_p0_k_tail |
710 | = brgemm_kernel_iter_p0_k_tail_; |
711 | |
712 | if (is_amx) { |
713 | pallete_buff_layer = pallete_buff_layer_main_; |
714 | pallete_buff_layer_k_tail = pallete_buff_layer_k_tail_; |
715 | pallete_buff_iter = pallete_buff_iter_main_; |
716 | pallete_buff_iter_k_tail = pallete_buff_iter_k_tail_; |
717 | } |
718 | |
719 | const bool do_n_tail = (n + rnn_.n_block) > rnn_.N; |
720 | if (do_n_tail) { |
721 | brgemm_kernel_layer = brgemm_kernel_layer_n_tail_; |
722 | brgemm_kernel_layer_k_tail = brgemm_kernel_layer_nk_tail_; |
723 | brgemm_kernel_iter_p0 = brgemm_kernel_iter_p0_n_tail_; |
724 | brgemm_kernel_iter_p0_k_tail = brgemm_kernel_iter_p0_nk_tail_; |
725 | |
726 | if (is_amx) { |
727 | pallete_buff_layer = pallete_buff_layer_n_tail_; |
728 | pallete_buff_layer_k_tail = pallete_buff_layer_nk_tail_; |
729 | pallete_buff_iter = pallete_buff_iter_n_tail_; |
730 | pallete_buff_iter_k_tail = pallete_buff_iter_nk_tail_; |
731 | } |
732 | } |
733 | |
734 | if (need_gemm_layer_) { |
735 | if (is_amx) load_cfg_if_needed(pallete_buff_layer); |
736 | for (int g = 0; g < n_gates_; g++) { |
737 | const auto *const Bl_g = Bl_n + g * Bl_g_offset_; |
738 | auto *const C_gates_g = C_gates_n + g * rnn_.N; |
739 | |
740 | for (int batch_idx = 0; batch_idx < rnn_.KB1_blocks; |
741 | batch_idx++) { |
742 | addr_batch[batch_idx].ptr.A |
743 | = Al_m + batch_idx * rnn_.k1_block; |
744 | addr_batch[batch_idx].ptr.B |
745 | = Bl_g + batch_idx * Bl_kb_offset_; |
746 | } |
747 | brgemm_kernel_execute(brgemm_kernel_layer, rnn_.KB1_blocks, |
748 | addr_batch, reinterpret_cast<void *>(C_gates_g), |
749 | amx_buffer); |
750 | } |
751 | } |
752 | |
753 | if (need_gemm_layer_ && rnn_.k1_tail > 0) { |
754 | if (is_amx) load_cfg_if_needed(pallete_buff_layer_k_tail); |
755 | for (int g = 0; g < n_gates_; g++) { |
756 | const auto *const Bl_g = Bl_n + g * Bl_g_offset_; |
757 | auto *const C_gates_g = C_gates_n + g * rnn_.N; |
758 | |
759 | addr_batch[0].ptr.A |
760 | = Al_m + rnn_.KB1_blocks * rnn_.k1_block; |
761 | addr_batch[0].ptr.B |
762 | = Bl_g + rnn_.KB1_blocks * Bl_kb_offset_; |
763 | brgemm_kernel_execute(brgemm_kernel_layer_k_tail, 1, |
764 | addr_batch, reinterpret_cast<void *>(C_gates_g), |
765 | amx_buffer); |
766 | } |
767 | } |
768 | if (is_amx) load_cfg_if_needed(pallete_buff_iter); |
769 | for (int g = 0; g < n_gates_ - 1; g++) { |
770 | const auto *const Bi_g = Bi_n + g * Bi_g_offset_; |
771 | auto *const C_gates_g = C_gates_n + g * rnn_.N; |
772 | |
773 | for (int batch_idx = 0; batch_idx < rnn_.KB2_blocks; |
774 | batch_idx++) { |
775 | addr_batch[batch_idx].ptr.A |
776 | = Ai_m + batch_idx * rnn_.k2_block; |
777 | addr_batch[batch_idx].ptr.B |
778 | = Bi_g + batch_idx * Bi_kb_offset_; |
779 | } |
780 | |
781 | brgemm_kernel_execute(brgemm_kernel_iter_p0, rnn_.KB2_blocks, |
782 | addr_batch, reinterpret_cast<void *>(C_gates_g), |
783 | amx_buffer); |
784 | } |
785 | |
786 | if (rnn_.k2_tail > 0) { |
787 | if (is_amx) load_cfg_if_needed(pallete_buff_iter_k_tail); |
788 | for (int g = 0; g < n_gates_ - 1; g++) { |
789 | const auto *const Bi_g = Bi_n + g * Bi_g_offset_; |
790 | auto *const C_gates_g = C_gates_n + g * rnn_.N; |
791 | |
792 | addr_batch[0].ptr.A |
793 | = Ai_m + rnn_.KB2_blocks * rnn_.k2_block; |
794 | addr_batch[0].ptr.B |
795 | = Bi_g + rnn_.KB2_blocks * Bi_kb_offset_; |
796 | |
797 | brgemm_kernel_execute(brgemm_kernel_iter_p0_k_tail, 1, |
798 | addr_batch, reinterpret_cast<void *>(C_gates_g), |
799 | amx_buffer); |
800 | } |
801 | } |
802 | |
803 | if (!rnn_.unfused_post_gemm) { |
                const auto block_step = (do_n_tail ? rnn_.n_tail : rnn_.n_block)
                        * sizeof(scratch_t);
806 | fused_postgemm_part1_( |
807 | m, n, nb_i, Ai_m + n, C_gates_n, C_cell_n, block_step); |
808 | } |
809 | } |
810 | |
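        // Pass 2: the last gate is computed from d_layer (written by the
        // pass-1 post-GEMM) against w_iter1; the final fused post-GEMM runs
        // once per M block, after the last N block.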
811 | for (dim_t nb_i = 0; nb_i < n_blocking_; nb_i++) { |
812 | const auto nb |
813 | = (rnn_.unfused_post_gemm) ? nb_i / rnn_.n_gates : nb_i; |
814 | const auto n = nb * rnn_.n_block; |
815 | |
816 | const auto *const Bi2_n = Bi2_ + nb * Bi_n_offset_; |
817 | auto *const C_gates_n = C_gates_ + m * rnn_.LDC + n; |
818 | |
819 | const brgemm_kernel_t *brgemm_kernel_iter_p1 |
820 | = brgemm_kernel_iter_p1_main_; |
821 | const brgemm_kernel_t *brgemm_kernel_iter_p1_k_tail |
822 | = brgemm_kernel_iter_p1_k_tail_; |
823 | |
824 | if (is_amx) { |
825 | pallete_buff_iter = pallete_buff_iter_main_; |
826 | pallete_buff_iter_k_tail = pallete_buff_iter_k_tail_; |
827 | } |
828 | |
829 | const bool do_n_tail = (n + rnn_.n_block) > rnn_.N; |
830 | if (do_n_tail) { |
831 | brgemm_kernel_iter_p1 = brgemm_kernel_iter_p1_n_tail_; |
832 | brgemm_kernel_iter_p1_k_tail = brgemm_kernel_iter_p1_nk_tail_; |
833 | |
834 | if (is_amx) { |
835 | pallete_buff_iter = pallete_buff_iter_n_tail_; |
836 | pallete_buff_iter_k_tail = pallete_buff_iter_nk_tail_; |
837 | } |
838 | } |
839 | |
840 | if (is_amx) load_cfg_if_needed(pallete_buff_iter); |
841 | for (int g = 0; g < 1; g++) { |
842 | const auto *const Bi2_g = Bi2_n + g * Bi_g_offset_; |
843 | auto *const C_gates_g = C_gates_n + (n_gates_ - 1) * rnn_.N; |
844 | |
845 | for (int batch_idx = 0; batch_idx < rnn_.KB2_blocks; |
846 | batch_idx++) { |
847 | addr_batch[batch_idx].ptr.A |
848 | = Ai2_m + batch_idx * rnn_.k2_block; |
849 | addr_batch[batch_idx].ptr.B |
850 | = Bi2_g + batch_idx * Bi_kb_offset_; |
851 | } |
852 | |
853 | brgemm_kernel_execute(brgemm_kernel_iter_p1, rnn_.KB2_blocks, |
854 | addr_batch, reinterpret_cast<void *>(C_gates_g), |
855 | amx_buffer); |
856 | } |
857 | |
858 | if (rnn_.k2_tail > 0) { |
859 | if (is_amx) load_cfg_if_needed(pallete_buff_iter_k_tail); |
860 | for (int g = 0; g < 1; g++) { |
861 | const auto *const Bi2_g = Bi2_n + g * Bi_g_offset_; |
862 | auto *const C_gates_g = C_gates_n + (n_gates_ - 1) * rnn_.N; |
863 | |
864 | addr_batch[0].ptr.A |
865 | = Ai2_m + rnn_.KB2_blocks * rnn_.k2_block; |
866 | addr_batch[0].ptr.B |
867 | = Bi2_g + rnn_.KB2_blocks * Bi_kb_offset_; |
868 | |
869 | brgemm_kernel_execute(brgemm_kernel_iter_p1_k_tail, 1, |
870 | addr_batch, reinterpret_cast<void *>(C_gates_g), |
871 | amx_buffer); |
872 | } |
873 | } |
874 | if (!rnn_.unfused_post_gemm && nb_i == n_blocking_ - 1) { |
                fused_postgemm_part2_(m, 0, 0, Ai_m, C_gates_ + m * rnn_.LDC,
                        C_cell_ + m * rnn_.LDC, rnn_.N * sizeof(scratch_t));
877 | } |
878 | } |
879 | ++start; |
880 | } |
881 | } |
882 | |
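// brgemm_merged_layer_t performs only the layer GEMM, with the M dimension
// merged across cell iterations (Mlayermerged_blocks), so the layer input
// of several iterations is computed as one large blocked GEMM; there is no
// iter part and no fused post-GEMM.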
883 | template <typename src_t, typename weights_t, typename scratch_t, |
884 | typename gemm_acc_t> |
885 | brgemm_merged_layer_t<src_t, weights_t, scratch_t, |
886 | gemm_acc_t>::brgemm_merged_layer_t(const ref_rnn_brgemm_t &rnn_brgemm, |
887 | const rnn_utils::rnn_conf_t &rnn, |
888 | rnn_utils::cell_position_t cell_position, const src_t *src_layer, |
889 | weights_t *w_layer, scratch_t *scratch_gates, |
890 | gemm_acc_t *amx_scratchpad, |
891 | x64::brgemm_batch_element_t *addr_batch_global) |
892 | : rnn_brgemm_(rnn_brgemm) |
893 | , rnn_(rnn) |
894 | , layer_desc_idx_(rnn_.layer_brgemm_desc(cell_position)) |
895 | , Al_(src_layer) |
896 | , Bl_(w_layer) |
897 | , C_(scratch_gates) |
898 | , LDAl_(rnn_.src_layer_ld(cell_position)) |
899 | , max_nthr_(rnn_.nthr) |
900 | , n_blocking_((rnn_.unfused_post_gemm) ? rnn_.N_blocks * rnn_.n_gates |
901 | : rnn_.N_blocks) |
902 | , m_blocking_(rnn_.Mlayermerged_blocks) |
903 | , work_amount_(n_blocking_ * m_blocking_) |
904 | , Bl_n_offset_(rnn_.K1padded * rnn_.n_block) |
905 | , Bl_g_offset_(rnn_.N_blocks * Bl_n_offset_) |
906 | , Al_k_tail_offset_(rnn_.KB1_blocks * rnn_.k1_block) |
907 | , Bl_kb_offset_(rnn_.k1_block * rnn_.n_block) |
908 | , Bl_k_tail_offset_(rnn_.KB1_blocks * rnn_.k1_block * rnn_.n_block) |
909 | , n_gates_(rnn.unfused_post_gemm ? 1 : rnn.n_gates) |
910 | , brgemm_kernel_layer_main_( |
911 | rnn_brgemm_.kernel_layermerged_b0_[layer_desc_idx_].get()) |
912 | , brgemm_kernel_layer_n_tail_( |
913 | rnn_brgemm_.kernel_layermerged_N_tail_b0_[layer_desc_idx_].get()) |
914 | , brgemm_kernel_layer_k_tail_( |
915 | rnn_brgemm_.kernel_layermerged_K1_tail_b1_[layer_desc_idx_].get()) |
916 | , brgemm_kernel_layer_nk_tail_( |
917 | rnn_brgemm_.kernel_layermerged_NK1_tail_b1_[layer_desc_idx_] |
918 | .get()) |
919 | , pallete_buff_layer_main_(rnn_brgemm_.pallete_buff_layermerged_) |
920 | , pallete_buff_layer_n_tail_(rnn_brgemm_.pallete_buff_layermerged_n_tail_) |
921 | , pallete_buff_layer_k_tail_(rnn_brgemm_.pallete_buff_layermerged_k1_tail_) |
922 | , pallete_buff_layer_nk_tail_( |
923 | rnn_brgemm_.pallete_buff_layermerged_nk1_tail_) |
924 | , amx_scratchpad_(amx_scratchpad) |
925 | , addr_batch_global_(addr_batch_global) {} |
926 | |
927 | template <typename src_t, typename weights_t, typename scratch_t, |
928 | typename gemm_acc_t> |
929 | void brgemm_merged_layer_t<src_t, weights_t, scratch_t, gemm_acc_t>::execute() |
930 | const { |
931 | parallel(max_nthr_, [this](const int ithr, const int nthr) { |
932 | this->kernel(ithr, nthr); |
933 | }); |
934 | } |
935 | |
936 | template <typename src_t, typename weights_t, typename scratch_t, |
937 | typename gemm_acc_t> |
938 | void brgemm_merged_layer_t<src_t, weights_t, scratch_t, gemm_acc_t>::kernel( |
939 | const int ithr, const int nthr) const { |
940 | using namespace cpu::rnn_utils; |
941 | |
942 | int start = 0, end = 0; |
943 | balance211(work_amount_, nthr, ithr, start, end); |
944 | |
945 | const bool is_amx = rnn_.is_cell_int8_amx() || rnn_.is_cell_bf16_amx(); |
946 | const auto m_block = rnn_.mlayermerged_block; |
947 | gemm_acc_t *const amx_buffer = is_amx |
948 | ? amx_scratchpad_ + m_block * rnn_.n_block * ithr |
949 | : nullptr; |
950 | const int max_K_Block = rnn_.KB1_blocks + 1; |
951 | brgemm_batch_element_t *const addr_batch |
952 | = addr_batch_global_ + ithr * max_K_Block; |
953 | |
954 | const char *pallete_buff_layer = nullptr; |
955 | const char *pallete_buff_layer_k_tail = nullptr; |
956 | |
957 | dim_t nb_i = 0, mb = 0; |
958 | switch (rnn_.loop_order) { |
959 | case brgemm_rnn_execute_loop_order_t::mblk_nblk: |
960 | nd_iterator_init(start, mb, m_blocking_, nb_i, n_blocking_); |
961 | break; |
962 | case brgemm_rnn_execute_loop_order_t::nblk_mblk: |
963 | nd_iterator_init(start, nb_i, n_blocking_, mb, m_blocking_); |
964 | break; |
        default: assert(!"unsupported loop order");
966 | } |
967 | |
968 | amx_tile_configuration_loader_t load_cfg_if_needed; |
969 | |
970 | while (start < end) { |
971 | const auto m = mb * m_block; |
972 | const auto nb = (rnn_.unfused_post_gemm) ? nb_i / rnn_.n_gates : nb_i; |
973 | const auto n = nb * rnn_.n_block; |
974 | const auto g_unfused |
975 | = (rnn_.unfused_post_gemm) ? nb_i % rnn_.n_gates : 0; |
976 | |
977 | const auto *const Al_m = Al_ + m * LDAl_; |
978 | const auto *const Bl_n = Bl_ + nb * Bl_n_offset_; |
979 | auto *const C_n = C_ + m * rnn_.LDC + n; |
980 | |
981 | const brgemm_kernel_t *brgemm_kernel_layer_b0 |
982 | = brgemm_kernel_layer_main_; |
983 | const brgemm_kernel_t *brgemm_kernel_layer_k_tail |
984 | = brgemm_kernel_layer_k_tail_; |
985 | |
986 | if (is_amx) { |
987 | pallete_buff_layer = pallete_buff_layer_main_; |
988 | pallete_buff_layer_k_tail = pallete_buff_layer_k_tail_; |
989 | } |
990 | |
991 | const bool do_n_tail = (n + rnn_.n_block) > rnn_.N; |
992 | if (do_n_tail) { |
993 | brgemm_kernel_layer_b0 = brgemm_kernel_layer_n_tail_; |
994 | brgemm_kernel_layer_k_tail = brgemm_kernel_layer_nk_tail_; |
995 | |
996 | if (is_amx) { |
997 | pallete_buff_layer = pallete_buff_layer_n_tail_; |
998 | pallete_buff_layer_k_tail = pallete_buff_layer_nk_tail_; |
999 | } |
1000 | } |
1001 | |
1002 | for (int g = 0; g < n_gates_; g++) { |
1003 | const int lg = g + g_unfused; |
1004 | const auto *const Bl_g = Bl_n + lg * Bl_g_offset_; |
1005 | auto *const C_g = C_n + lg * rnn_.N; |
1006 | |
1007 | if (is_amx) load_cfg_if_needed(pallete_buff_layer); |
1008 | for (int i = 0; i < rnn_.KB1_blocks; i++) { |
1009 | addr_batch[i].ptr.A = Al_m + i * rnn_.k1_block; |
1010 | addr_batch[i].ptr.B = Bl_g + i * Bl_kb_offset_; |
1011 | } |
1012 | brgemm_kernel_execute(brgemm_kernel_layer_b0, rnn_.KB1_blocks, |
1013 | addr_batch, reinterpret_cast<void *>(C_g), amx_buffer); |
1014 | } |
1015 | |
1016 | if (rnn_.k1_tail) { |
1017 | if (is_amx) load_cfg_if_needed(pallete_buff_layer_k_tail); |
1018 | |
1019 | for (int g = 0; g < n_gates_; g++) { |
1020 | const int lg = g + g_unfused; |
1021 | const auto *const Bl_g = Bl_n + lg * Bl_g_offset_; |
1022 | auto *const C_g = C_n + lg * rnn_.N; |
1023 | |
1024 | addr_batch[0].ptr.A = Al_m + Al_k_tail_offset_; |
1025 | addr_batch[0].ptr.B = Bl_g + Bl_k_tail_offset_; |
1026 | brgemm_kernel_execute(brgemm_kernel_layer_k_tail, 1, addr_batch, |
1027 | reinterpret_cast<void *>(C_g), amx_buffer); |
1028 | } |
1029 | } |
1030 | |
1031 | ++start; |
1032 | switch (rnn_.loop_order) { |
1033 | case brgemm_rnn_execute_loop_order_t::mblk_nblk: |
1034 | nd_iterator_step(mb, m_blocking_, nb_i, n_blocking_); |
1035 | break; |
1036 | case brgemm_rnn_execute_loop_order_t::nblk_mblk: |
1037 | nd_iterator_step(nb_i, n_blocking_, mb, m_blocking_); |
1038 | break; |
            default: assert(!"unsupported loop order");
1040 | } |
1041 | } |
1042 | } |
1043 | |
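// Explicit instantiations for the supported cell data type configurations:
// int8 (u8/s8 activations, s8 weights, s32 accumulation), f32 and bf16.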
1044 | template class brgemm_dst_layer_iter_t<uint8_t, int8_t, int32_t, int32_t>; |
1045 | template class brgemm_dst_layer_iter_t<int8_t, int8_t, int32_t, int32_t>; |
1046 | template class brgemm_dst_layer_iter_t<float, float, float, float>; |
1047 | template class brgemm_dst_layer_iter_t<bfloat16_t, bfloat16_t, float, float>; |
1048 | |
1049 | template class brgemm_dst_proj_t<float, float, float>; |
1050 | template class brgemm_dst_proj_t<bfloat16_t, bfloat16_t, float>; |
1051 | template class brgemm_dst_proj_t<int8_t, int8_t, int32_t>; |
1052 | template class brgemm_dst_proj_t<uint8_t, int8_t, int32_t>; |
1053 | |
1054 | template class brgemm_gru_t<uint8_t, int8_t, int32_t, int32_t>; |
1055 | template class brgemm_gru_t<int8_t, int8_t, int32_t, int32_t>; |
1056 | template class brgemm_gru_t<float, float, float, float>; |
1057 | template class brgemm_gru_t<bfloat16_t, bfloat16_t, float, float>; |
1058 | |
1059 | template class brgemm_merged_layer_t<uint8_t, int8_t, int32_t, int32_t>; |
1060 | template class brgemm_merged_layer_t<int8_t, int8_t, int32_t, int32_t>; |
1061 | template class brgemm_merged_layer_t<float, float, float, float>; |
1062 | template class brgemm_merged_layer_t<bfloat16_t, bfloat16_t, float, float>; |
1063 | |
1064 | } // namespace x64 |
1065 | } // namespace cpu |
1066 | } // namespace impl |
1067 | } // namespace dnnl |
1068 | |