/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "cpu/x64/rnn/brgemm_cell_common_bwd.hpp"

#include "common/dnnl_thread.hpp"
#include "common/utils.hpp"

#include "cpu/x64/rnn/brgemm_cell_common_reorders.hpp"
#include "cpu/x64/rnn/brgemm_cell_common_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::utils;

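// Computes diff_src_layer and diff_src_iter as batched brgemm calls: A is
// the scratch diff-gates buffer, B the blocked iter/layer weights, and C
// the diff_src output, with the batch accumulating over gates and K blocks.
// The full-block and K-tail kernels are shared between the iter and layer
// GEMMs (kernel_iter_layer_*); only the N-tail kernels differ.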
template <typename weights_t, typename scratch_t, typename gemm_acc_t>
brgemm_diff_src_layer_iter_t<weights_t, scratch_t,
        gemm_acc_t>::brgemm_diff_src_layer_iter_t(const ref_rnn_brgemm_t
                &rnn_brgemm,
        const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, scratch_t *scratch_gates,
        weights_t *w_iter, weights_t *w_layer, gemm_acc_t *diff_src_iter,
        gemm_acc_t *diff_src_layer, gemm_acc_t *amx_scratchpad,
        x64::brgemm_batch_element_t *addr_batch_global)
    : rnn_brgemm_(rnn_brgemm)
    , rnn_(rnn)
    , A_(scratch_gates)
    , B_wei_iter_(w_iter)
    , B_wei_layer_(w_layer)
    , C_diff_iter_(diff_src_iter)
    , C_diff_layer_(diff_src_layer)
    , k_blocks_n_gates_(rnn.diff_src_brgemm.K_blocks)
    , k_blocks_(rnn.diff_src_brgemm.K_blocks / rnn.n_gates)
    , k_tail_(rnn.diff_src_brgemm.k_tail)
    , k_block_(rnn.diff_src_brgemm.k_block)
    , A_k_tail_offset_(k_blocks_ * k_block_)
    , B_k_tail_offset_(A_k_tail_offset_ * rnn.diff_src_brgemm.n_block)
    , B_nb_offset_(rnn.diff_src_brgemm.Kpadded * rnn.diff_src_brgemm.n_block)
    , B_kb_offset_(k_block_ * rnn.diff_src_brgemm.n_block)
    , B_gb_iter_offset_(rnn.diff_src_brgemm.Kpadded
              * rnn.diff_src_brgemm.n_block * rnn.diff_src_brgemm.N_iter_blocks)
    , B_gb_layer_offset_(rnn.diff_src_brgemm.Kpadded
              * rnn.diff_src_brgemm.n_block
              * rnn.diff_src_brgemm.N_layer_blocks)
    , LDA_(rnn.diff_src_brgemm.LDA)
    , LDC_(rnn.diff_src_brgemm.LDC)
    , max_nthr_(rnn.nthr)
    , n_blocking_(rnn.diff_src_brgemm.N_blocks)
    , m_blocking_(rnn.diff_src_brgemm.M_blocks)
    , work_amount_(n_blocking_ * m_blocking_)
    , max_n_layer_blocks_(rnn.diff_src_brgemm.N_layer_blocks)
    , max_n_iter_blocks_(rnn.diff_src_brgemm.N_iter_blocks)
    , gemm_layer_needed_(rnn.need_gemm_layer(cell_position))
    , kernel_iter_full_blocks_b0_(
              rnn_brgemm_.diff_src_.kernel_iter_layer_beta0_.get())
    , kernel_iter_full_blocks_b1_(
              rnn_brgemm_.diff_src_.kernel_iter_layer_beta1_.get())
    , kernel_iter_n_tail_b0_(
              rnn_brgemm_.diff_src_.kernel_iter_N_tail_beta0_.get())
    , kernel_iter_n_tail_b1_(
              rnn_brgemm_.diff_src_.kernel_iter_N_tail_beta1_.get())
    , kernel_iter_k_tail_(
              rnn_brgemm_.diff_src_.kernel_iter_layer_K_tail_beta1_.get())
    , kernel_iter_nk_tail_(
              rnn_brgemm_.diff_src_.kernel_iter_NK_tail_beta1_.get())
    , kernel_layer_full_blocks_b0_(
              rnn_brgemm_.diff_src_.kernel_iter_layer_beta0_.get())
    , kernel_layer_full_blocks_b1_(
              rnn_brgemm_.diff_src_.kernel_iter_layer_beta1_.get())
    , kernel_layer_n_tail_b0_(
              rnn_brgemm_.diff_src_.kernel_layer_N_tail_beta0_.get())
    , kernel_layer_n_tail_b1_(
              rnn_brgemm_.diff_src_.kernel_layer_N_tail_beta1_.get())
    , kernel_layer_k_tail_(
              rnn_brgemm_.diff_src_.kernel_iter_layer_K_tail_beta1_.get())
    , kernel_layer_nk_tail_(
              rnn_brgemm_.diff_src_.kernel_layer_NK_tail_beta1_.get())
    , amx_scratchpad_(amx_scratchpad)
    , addr_batch_global_(addr_batch_global) {}

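// Dispatches to the AMX-specific kernel for bf16 cells on an AMX ISA and
// to the generic brgemm kernel otherwise.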
template <typename weights_t, typename scratch_t, typename gemm_acc_t>
void brgemm_diff_src_layer_iter_t<weights_t, scratch_t, gemm_acc_t>::execute()
        const {
    if (rnn_.is_cell_dt_bf16()
            && rnn_.diff_src_brgemm.isa == x64::avx512_core_amx) {
        parallel(max_nthr_, [this](const int ithr, const int nthr) {
            this->kernel_amx(ithr, nthr);
        });
    } else {
        parallel(max_nthr_, [this](const int ithr, const int nthr) {
            this->kernel(ithr, nthr);
        });
    }
}

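// Processes a single (m, n) block for the given gates range: fills the
// batch of A/B pointer pairs (one entry per gate and K block), reconfigures
// the AMX tiles only when the required palette changes, and executes the
// brgemm kernels for diff_src_iter and/or diff_src_layer. A K tail, if
// present, is handled by a separate one-step-per-gate batch.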
template <typename weights_t, typename scratch_t, typename gemm_acc_t>
void brgemm_diff_src_layer_iter_t<weights_t, scratch_t,
        gemm_acc_t>::kernel_amx_compute_iter(const int m_block_id,
        const int n_block_id, const int gates_start, const int gates_end,
        thread_exec_ctx_t &ctx) const {

    const int m = m_block_id * rnn_.diff_src_brgemm.m_block;
    const int n = n_block_id * rnn_.diff_src_brgemm.n_block;
    const int num_gates = gates_end - gates_start;
    const scratch_t *const A_m = A_ + m * LDA_;
    const auto B_n_offset = n_block_id * B_nb_offset_;
    const weights_t *const B_wei_iter_n = B_wei_iter_ + B_n_offset;
    const weights_t *const B_wei_layer_n = B_wei_layer_ + B_n_offset;
    const auto C_offset = m * LDC_ + n;
    gemm_acc_t *const C_diff_iter_n = C_diff_iter_ + C_offset;
    gemm_acc_t *const C_diff_layer_n = C_diff_layer_ + C_offset;

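    // The first gates block overwrites C (beta = 0 kernels); subsequent
    // gates blocks accumulate onto it (beta = 1 kernels).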
    const brgemm_kernel_t *kernel_iter = gates_start == 0
            ? kernel_iter_full_blocks_b0_
            : kernel_iter_full_blocks_b1_;
    const brgemm_kernel_t *kernel_iter_k_tail = kernel_iter_k_tail_;
    const brgemm_kernel_t *kernel_layer = gates_start == 0
            ? kernel_layer_full_blocks_b0_
            : kernel_layer_full_blocks_b1_;
    const brgemm_kernel_t *kernel_layer_k_tail = kernel_layer_k_tail_;

    const char *kernel_iter_config
            = rnn_brgemm_.diff_src_.pallete_buff_iter_layer_;
    const char *kernel_iter_k_tail_config
            = rnn_brgemm_.diff_src_.pallete_buff_iter_layer_k_tail_;
    const char *kernel_layer_config
            = rnn_brgemm_.diff_src_.pallete_buff_iter_layer_;
    const char *kernel_layer_k_tail_config
            = rnn_brgemm_.diff_src_.pallete_buff_iter_layer_k_tail_;

    const bool should_calc_diff_src_layer
            = gemm_layer_needed_ && n_block_id < max_n_layer_blocks_;
    const bool should_calc_diff_src_iter = n_block_id < max_n_iter_blocks_;

    if (should_calc_diff_src_iter) {
        const bool do_n_iter_tail = (n + rnn_.diff_src_brgemm.n_block)
                > rnn_.diff_src_brgemm.N_iter;

        if (do_n_iter_tail) {
            kernel_iter = gates_start == 0 ? kernel_iter_n_tail_b0_
                                           : kernel_iter_n_tail_b1_;
            kernel_iter_k_tail = kernel_iter_nk_tail_;
            kernel_iter_config
                    = rnn_brgemm_.diff_src_.pallete_buff_iter_n_tail_;
            kernel_iter_k_tail_config
                    = rnn_brgemm_.diff_src_.pallete_buff_iter_nk_tail_;
        }

        for (int gate_id = gates_start; gate_id < gates_end; gate_id++) {
            const auto g_block_id = gate_id * k_blocks_;
            const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
            const auto B_g_offset = gate_id * B_gb_iter_offset_;
            const auto A_gm = A_m + A_gb_offset;
            const auto B_wei_iter_gn = B_wei_iter_n + B_g_offset;
            for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
                ctx.addr_batch[g_block_id + k_block_id].ptr.A
                        = A_gm + k_block_id * k_block_;
                ctx.addr_batch[g_block_id + k_block_id].ptr.B
                        = B_wei_iter_gn + k_block_id * B_kb_offset_;
            }
        }

        ctx.tile_configure_if_needed(kernel_iter_config);
        brgemm_kernel_execute(kernel_iter, k_blocks_ * num_gates,
                ctx.addr_batch, reinterpret_cast<void *>(C_diff_iter_n),
                ctx.amx_buffer);
    }

    if (should_calc_diff_src_layer) {
        const bool do_n_layer_tail = (n + rnn_.diff_src_brgemm.n_block)
                > rnn_.diff_src_brgemm.N_layer;

        if (do_n_layer_tail) {
            kernel_layer = gates_start == 0 ? kernel_layer_n_tail_b0_
                                            : kernel_layer_n_tail_b1_;
            kernel_layer_k_tail = kernel_layer_nk_tail_;
            kernel_layer_config
                    = rnn_brgemm_.diff_src_.pallete_buff_layer_n_tail_;
            kernel_layer_k_tail_config
                    = rnn_brgemm_.diff_src_.pallete_buff_layer_nk_tail_;
        }

        for (int gate_id = gates_start; gate_id < gates_end; gate_id++) {
            const auto g_block_id = gate_id * k_blocks_;
            const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
            const auto B_g_offset = gate_id * B_gb_layer_offset_;
            const auto A_gm = A_m + A_gb_offset;
            const auto B_wei_layer_gn = B_wei_layer_n + B_g_offset;
            for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
                ctx.addr_batch[g_block_id + k_block_id].ptr.A
                        = A_gm + k_block_id * k_block_;
                ctx.addr_batch[g_block_id + k_block_id].ptr.B
                        = B_wei_layer_gn + k_block_id * B_kb_offset_;
            }
        }

        ctx.tile_configure_if_needed(kernel_layer_config);
        brgemm_kernel_execute(kernel_layer, k_blocks_ * num_gates,
                ctx.addr_batch, reinterpret_cast<void *>(C_diff_layer_n),
                ctx.amx_buffer);
    }

    if (should_calc_diff_src_iter && k_tail_) {
        for (int gate_id = gates_start; gate_id < gates_end; gate_id++) {
            const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
            const auto B_gb_offset = gate_id * B_gb_iter_offset_;
            ctx.addr_batch[gate_id].ptr.A
                    = A_m + A_gb_offset + A_k_tail_offset_;
            ctx.addr_batch[gate_id].ptr.B
                    = B_wei_iter_n + B_gb_offset + B_k_tail_offset_;
        }

        ctx.tile_configure_if_needed(kernel_iter_k_tail_config);
        brgemm_kernel_execute(kernel_iter_k_tail, num_gates, ctx.addr_batch,
                reinterpret_cast<void *>(C_diff_iter_n), ctx.amx_buffer);
    }

    if (should_calc_diff_src_layer && k_tail_) {
        for (int gate_id = gates_start; gate_id < gates_end; gate_id++) {
            const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
            const auto B_gb_offset = gate_id * B_gb_layer_offset_;
            ctx.addr_batch[gate_id].ptr.A
                    = A_m + A_gb_offset + A_k_tail_offset_;
            ctx.addr_batch[gate_id].ptr.B
                    = B_wei_layer_n + B_gb_offset + B_k_tail_offset_;
        }

        ctx.tile_configure_if_needed(kernel_layer_k_tail_config);
        brgemm_kernel_execute(kernel_layer_k_tail, num_gates, ctx.addr_batch,
                reinterpret_cast<void *>(C_diff_layer_n), ctx.amx_buffer);
    }
}

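// AMX flavor of the diff_src kernel: each thread uses a private batch array
// and AMX accumulation buffer, gets its share of the M x N block space from
// balance211, and walks it in the configured loop order once per gates
// block, which keeps tile reconfigurations rare.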
// TODO: consider merging with kernel() - check performance after the merge
template <typename weights_t, typename scratch_t, typename gemm_acc_t>
void brgemm_diff_src_layer_iter_t<weights_t, scratch_t, gemm_acc_t>::kernel_amx(
        const int ithr, const int nthr) const {
    using namespace cpu::rnn_utils;

    int mn_start = 0, mn_end = 0;
    balance211(work_amount_, nthr, ithr, mn_start, mn_end);

    int n_block_id = 0, m_block_id = 0;
    const auto n_gates = rnn_.n_gates;
    const int gates_block_size = rnn_.diff_src_brgemm.gates_block;
    thread_exec_ctx_t ctx;
    ctx.addr_batch = addr_batch_global_ + ithr * (k_blocks_n_gates_ + 1);
    ctx.amx_buffer = amx_scratchpad_
            + rnn_.diff_src_brgemm.m_block * rnn_.diff_src_brgemm.n_block
                    * ithr;

    for (int gate_idx = 0; gate_idx < n_gates; gate_idx += gates_block_size) {
        const int gates_start = gate_idx;
        const int gates_end = nstl::min(gate_idx + gates_block_size, n_gates);

        switch (rnn_.diff_src_brgemm.loop_order) {
            case brgemm_rnn_execute_loop_order_t::mblk_nblk:
                nd_iterator_init(mn_start, m_block_id, m_blocking_, n_block_id,
                        n_blocking_);
                break;
            case brgemm_rnn_execute_loop_order_t::nblk_mblk:
                nd_iterator_init(mn_start, n_block_id, n_blocking_, m_block_id,
                        m_blocking_);
                break;
            default: assert(!"unsupported loop order");
        }
        int mn_idx = mn_start;
        while (mn_idx < mn_end) {
            kernel_amx_compute_iter(
                    m_block_id, n_block_id, gates_start, gates_end, ctx);
            ++mn_idx;
            switch (rnn_.diff_src_brgemm.loop_order) {
                case brgemm_rnn_execute_loop_order_t::mblk_nblk:
                    nd_iterator_step(
                            m_block_id, m_blocking_, n_block_id, n_blocking_);
                    break;
                case brgemm_rnn_execute_loop_order_t::nblk_mblk:
                    nd_iterator_step(
                            n_block_id, n_blocking_, m_block_id, m_blocking_);
                    break;
                default: assert(!"unsupported loop order");
            }
        }
    }
}

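// Generic (non-AMX) diff_src kernel: all gates are chained into a single
// brgemm batch of k_blocks_n_gates_ steps, so only the beta = 0 full-block
// kernels are needed; the K tail then accumulates on top with beta = 1.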
template <typename weights_t, typename scratch_t, typename gemm_acc_t>
void brgemm_diff_src_layer_iter_t<weights_t, scratch_t, gemm_acc_t>::kernel(
        const int ithr, const int nthr) const {
    int start = 0, end = 0;
    balance211(work_amount_, nthr, ithr, start, end);

    int n_block_id = 0, m_block_id = 0;
    nd_iterator_init(start, n_block_id, n_blocking_, m_block_id, m_blocking_);

    x64::brgemm_batch_element_t *const addr_batch
            = addr_batch_global_ + ithr * (k_blocks_n_gates_ + 1);
    const auto n_gates = rnn_.n_gates;

    while (start < end) {
        const int m = m_block_id * rnn_.diff_src_brgemm.m_block;
        const int n = n_block_id * rnn_.diff_src_brgemm.n_block;
        const scratch_t *const A_m = A_ + m * LDA_;
        const auto B_n_offset = n_block_id * B_nb_offset_;
        const weights_t *const B_wei_iter_n = B_wei_iter_ + B_n_offset;
        const weights_t *const B_wei_layer_n = B_wei_layer_ + B_n_offset;
        const auto C_offset = m * LDC_ + n;
        gemm_acc_t *const C_diff_iter_n = C_diff_iter_ + C_offset;
        gemm_acc_t *const C_diff_layer_n = C_diff_layer_ + C_offset;
        const brgemm_kernel_t *kernel_iter = kernel_iter_full_blocks_b0_;
        const brgemm_kernel_t *kernel_iter_k_tail = kernel_iter_k_tail_;
        const brgemm_kernel_t *kernel_layer = kernel_layer_full_blocks_b0_;
        const brgemm_kernel_t *kernel_layer_k_tail = kernel_layer_k_tail_;
        const bool should_calc_diff_src_layer
                = gemm_layer_needed_ && n_block_id < max_n_layer_blocks_;
        const bool should_calc_diff_src_iter = n_block_id < max_n_iter_blocks_;

        if (should_calc_diff_src_iter) {
            const bool do_n_iter_tail = (n + rnn_.diff_src_brgemm.n_block)
                    > rnn_.diff_src_brgemm.N_iter;

            if (do_n_iter_tail) {
                kernel_iter = kernel_iter_n_tail_b0_;
                kernel_iter_k_tail = kernel_iter_nk_tail_;
            }

            for (int gate_id = 0; gate_id < n_gates; gate_id++) {
                const auto g_block_id = gate_id * k_blocks_;
                const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
                const auto B_gb_offset = gate_id * B_gb_iter_offset_;
                const auto A_gm = A_m + A_gb_offset;
                const auto B_wei_iter_gn = B_wei_iter_n + B_gb_offset;
                for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
                    addr_batch[g_block_id + k_block_id].ptr.A
                            = A_gm + k_block_id * k_block_;
                    addr_batch[g_block_id + k_block_id].ptr.B
                            = B_wei_iter_gn + k_block_id * B_kb_offset_;
                }
            }

            brgemm_kernel_execute(kernel_iter, k_blocks_n_gates_, addr_batch,
                    reinterpret_cast<void *>(C_diff_iter_n), nullptr);
        }

        if (should_calc_diff_src_layer) {
            const bool do_n_layer_tail = (n + rnn_.diff_src_brgemm.n_block)
                    > rnn_.diff_src_brgemm.N_layer;

            if (do_n_layer_tail) {
                kernel_layer = kernel_layer_n_tail_b0_;
                kernel_layer_k_tail = kernel_layer_nk_tail_;
            }

            for (int gate_id = 0; gate_id < n_gates; gate_id++) {
                const auto g_block_id = gate_id * k_blocks_;
                const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
                const auto B_gb_offset = gate_id * B_gb_layer_offset_;
                const auto A_gm = A_m + A_gb_offset;
                const auto B_wei_layer_gn = B_wei_layer_n + B_gb_offset;
                for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
                    addr_batch[g_block_id + k_block_id].ptr.A
                            = A_gm + k_block_id * k_block_;
                    addr_batch[g_block_id + k_block_id].ptr.B
                            = B_wei_layer_gn + k_block_id * B_kb_offset_;
                }
            }
            brgemm_kernel_execute(kernel_layer, k_blocks_n_gates_, addr_batch,
                    reinterpret_cast<void *>(C_diff_layer_n), nullptr);
        }

        if (should_calc_diff_src_iter && k_tail_) {
            for (int gate_id = 0; gate_id < n_gates; gate_id++) {
                const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
                const auto B_gb_offset = gate_id * B_gb_iter_offset_;
                addr_batch[gate_id].ptr.A
                        = A_m + A_gb_offset + A_k_tail_offset_;
                addr_batch[gate_id].ptr.B
                        = B_wei_iter_n + B_gb_offset + B_k_tail_offset_;
            }

            brgemm_kernel_execute(kernel_iter_k_tail, n_gates, addr_batch,
                    reinterpret_cast<void *>(C_diff_iter_n), nullptr);
        }

        if (should_calc_diff_src_layer && k_tail_) {
            for (int gate_id = 0; gate_id < n_gates; gate_id++) {
                const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
                const auto B_gb_offset = gate_id * B_gb_layer_offset_;
                addr_batch[gate_id].ptr.A
                        = A_m + A_gb_offset + A_k_tail_offset_;
                addr_batch[gate_id].ptr.B
                        = B_wei_layer_n + B_gb_offset + B_k_tail_offset_;
            }

            brgemm_kernel_execute(kernel_layer_k_tail, n_gates, addr_batch,
                    reinterpret_cast<void *>(C_diff_layer_n), nullptr);
        }

        ++start;
        nd_iterator_step(n_block_id, n_blocking_, m_block_id, m_blocking_);
    }
}

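// Computes diff_weights_layer, diff_weights_iter and diff_bias: A holds the
// (transposed) src_iter/src_layer activations, B the scratch diff-gates
// repacked into the blocked layout brgemm expects, and C the weights
// gradients. All kernels here use beta = 1, so the results accumulate in C.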
template <typename src_layer_t, typename src_iter_t, typename scratch_t,
        typename gemm_acc_t>
brgemm_diff_weights_layer_iter_t<src_layer_t, src_iter_t, scratch_t,
        gemm_acc_t>::brgemm_diff_weights_layer_iter_t(const ref_rnn_brgemm_t
                &rnn_brgemm,
        const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, const src_iter_t *src_iter,
        scratch_t *const A_iter_transposed_scratch,
        const src_layer_t *src_layer,
        scratch_t *const A_layer_transposed_scratch, const scratch_t *scratch,
        scratch_t *scratch_gates_blocked, gemm_acc_t *diff_weights_iter,
        gemm_acc_t *diff_weights_layer, gemm_acc_t *diff_bias,
        gemm_acc_t *amx_scratchpad,
        x64::brgemm_batch_element_t *addr_batch_global)
    : rnn_brgemm_(rnn_brgemm)
    , rnn_(rnn)
    , is_amx_(rnn_.is_cell_bf16_amx())
    , A_iter_(src_iter)
    , A_iter_transposed_scratch_(A_iter_transposed_scratch)
    , A_layer_(src_layer)
    , A_layer_transposed_scratch_(A_layer_transposed_scratch)
    , B_(scratch)
    , B_blocked_scratch_(scratch_gates_blocked)
    , C_iter_(diff_weights_iter)
    , C_layer_(diff_weights_layer)
    , diff_bias_(diff_bias)
    , LDA_iter_(rnn.diff_wei_brgemm.LDA_iter)
    , LDA_layer_(rnn.diff_wei_brgemm.LDA_layer)
    , LDC_iter_(rnn.diff_wei_brgemm.LDC_iter)
    , LDC_layer_(rnn.diff_wei_brgemm.LDC_layer)
    , max_nthr_(rnn.nthr)
    , n_blocking_(rnn.diff_wei_brgemm.N_blocks)
    , m_blocking_(rnn.diff_wei_brgemm.M_blocks)
    , k_blocks_(rnn.diff_wei_brgemm.K_blocks)
    , k_tail_(rnn.diff_wei_brgemm.k_tail)
    , k_block_(rnn.diff_wei_brgemm.k_block)
    , m_iter_block_(rnn.slc == rnn.sic ? rnn.diff_wei_brgemm.m_block
                                       : rnn.diff_wei_brgemm.M_iter)
    , m_layer_block_(rnn.slc == rnn.sic ? rnn.diff_wei_brgemm.m_block
                                        : rnn.diff_wei_brgemm.M_layer)
    , A_k_iter_tail_offset_(k_blocks_ * k_block_)
    , A_k_layer_tail_offset_(k_blocks_ * k_block_)
    , B_kb_offset_(k_block_ * rnn.diff_wei_brgemm.n_block)
    , B_k_tail_offset_(k_blocks_ * k_block_ * rnn.scratch_gates_ld)
    , B_k_tail_offset_blocked_(
              k_blocks_ * k_block_ * rnn.diff_wei_brgemm.n_block)
    , work_amount_(n_blocking_ * m_blocking_)
    , kernel_iter_full_blocks_(rnn_brgemm.diff_wei_.kernel_iter_beta1_.get())
    , kernel_iter_n_tail_(rnn_brgemm.diff_wei_.kernel_iter_N_tail_beta1_.get())
    , kernel_iter_k_tail_(rnn_brgemm.diff_wei_.kernel_iter_K_tail_beta1_.get())
    , kernel_iter_nk_tail_(
              rnn_brgemm.diff_wei_.kernel_iter_NK_tail_beta1_.get())
    , kernel_layer_full_blocks_(rnn_brgemm.diff_wei_.kernel_layer_beta1_.get())
    , kernel_layer_n_tail_(
              rnn_brgemm.diff_wei_.kernel_layer_N_tail_beta1_.get())
    , kernel_layer_k_tail_(
              rnn_brgemm.diff_wei_.kernel_layer_K_tail_beta1_.get())
    , kernel_layer_nk_tail_(
              rnn_brgemm.diff_wei_.kernel_layer_NK_tail_beta1_.get())
    , cell_position_(cell_position)
    , kernel_gates_reduction_(rnn_brgemm.kernel_gates_reduction_.get())
    , kernel_gates_reduction_tail_(
              rnn_brgemm.kernel_gates_reduction_tail_.get())
    , kernel_transpose_iter_(rnn_brgemm.kernel_transpose_single_row_iter_.get())
    , kernel_transpose_layer_(rnn_brgemm.kernel_transpose_single_row_layer_
                      ? rnn_brgemm.kernel_transpose_single_row_layer_.get()
                      : kernel_transpose_iter_)
    , amx_scratchpad_(amx_scratchpad)
    , addr_batch_global_(addr_batch_global) {}

template <typename src_layer_t, typename src_iter_t, typename scratch_t,
        typename gemm_acc_t>
void brgemm_diff_weights_layer_iter_t<src_layer_t, src_iter_t, scratch_t,
        gemm_acc_t>::execute() const {
    if (is_amx_) {
        parallel(max_nthr_, [this](const int ithr, const int nthr) {
            this->kernel_amx(ithr, nthr);
        });
    } else {
        parallel(max_nthr_, [this](const int ithr, const int nthr) {
            this->kernel(ithr, nthr);
        });
    }
}

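// Repacks one N block of the scratch diff-gates into the blocked B layout
// required by brgemm, reusing the matmul copy-B kernel.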
template <typename src_layer_t, typename src_iter_t, typename scratch_t,
        typename gemm_acc_t>
void brgemm_diff_weights_layer_iter_t<src_layer_t, src_iter_t, scratch_t,
        gemm_acc_t>::reorder_scratch_gates(const scratch_t *src, scratch_t *dst,
        const bool do_n_tail) const {
    auto ctx = matmul::jit_brgemm_matmul_copy_b_t::ctx_t();
    ctx.src = (void *)src;
    ctx.tr_src = (void *)dst;
    ctx.current_K_start = 0;
    ctx.current_K_iters = rnn_.mb;
    ctx.current_N_blk = do_n_tail ? rnn_.diff_wei_brgemm.n_tail
                                  : rnn_.diff_wei_brgemm.n_block;
    (*rnn_brgemm_.diff_wei_.srcatch_gates_reorder_kernel_)(&ctx);
}

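// Generic (non-AMX) diff_weights kernel. Each thread keeps private scratch
// for the repacked B block and, unless a global transpose was done up
// front, for the transposed A blocks; last_n_block_id/last_m_block_id allow
// the reorder and transpose results to be reused for as long as the
// corresponding block id does not change. The diff_bias reduction runs once
// per N block (at m_block_id == 0) on the freshly repacked gates.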
template <typename src_layer_t, typename src_iter_t, typename scratch_t,
        typename gemm_acc_t>
void brgemm_diff_weights_layer_iter_t<src_layer_t, src_iter_t, scratch_t,
        gemm_acc_t>::kernel(const int ithr, const int nthr) const {

    const bool global_transpose = rnn_.diff_wei_brgemm.global_transpose;

    scratch_t *const B_blocked = B_blocked_scratch_
            + ithr * rnn_.diff_wei_brgemm.Kpadded
                    * rnn_.diff_wei_brgemm.n_block;

    scratch_t *const A_iter_transposed_ithr = global_transpose
            ? A_iter_transposed_scratch_
            : (A_iter_transposed_scratch_
                    + ithr * rnn_.diff_wei_brgemm.Kpadded * m_iter_block_);

    scratch_t *const A_layer_transposed_ithr = global_transpose
            ? A_layer_transposed_scratch_
            : A_layer_transposed_scratch_
                    + ithr * rnn_.diff_wei_brgemm.Kpadded * m_layer_block_;

    int start = 0, end = 0;
    balance211(work_amount_, nthr, ithr, start, end);

    int n_block_id = 0, m_block_id = 0, last_n_block_id = -1,
        last_m_block_id = -1;

    nd_iterator_init(start, n_block_id, n_blocking_, m_block_id, m_blocking_);

    x64::brgemm_batch_element_t *const addr_batch
            = addr_batch_global_ + ithr * (k_blocks_ + 1);

    while (start < end) {
        const bool should_reorder_gates = last_n_block_id != n_block_id;
        const bool transpose_needed
                = !(rnn_.mb == 1 && std::is_same<float, src_iter_t>::value);
        const bool should_transpose_src = transpose_needed && !global_transpose
                && (last_m_block_id != m_block_id);

        const int m_iter = m_block_id * m_iter_block_;
        const int m_layer = m_block_id * m_layer_block_;
        const src_iter_t *const A_iter_m = global_transpose
                ? A_iter_transposed_ithr + m_iter * LDA_iter_
                : A_iter_ + m_iter;
        const src_layer_t *const A_layer_m = global_transpose
                ? A_layer_transposed_ithr + m_layer * LDA_layer_
                : A_layer_ + m_layer;

        src_iter_t *const A_iter_transposed
                = (global_transpose || !transpose_needed)
                ? const_cast<src_iter_t *>(A_iter_m)
                : A_iter_transposed_ithr;
        src_layer_t *const A_layer_transposed
                = (global_transpose || !transpose_needed)
                ? const_cast<src_layer_t *>(A_layer_m)
                : A_layer_transposed_ithr;

        const int n = n_block_id * rnn_.diff_wei_brgemm.n_block;
        const scratch_t *const B_n = B_ + n;
        const auto C_iter_offset = m_iter * LDC_iter_ + n;
        const auto C_layer_offset = m_layer * LDC_layer_ + n;
        gemm_acc_t *const C_diff_iter_n = C_iter_ + C_iter_offset;
        gemm_acc_t *const C_diff_layer_n = C_layer_ + C_layer_offset;

        const bool do_n_tail
                = (n + rnn_.diff_wei_brgemm.n_block) > rnn_.diff_wei_brgemm.N;

        const brgemm_kernel_t *kernel_iter = kernel_iter_full_blocks_;
        const brgemm_kernel_t *kernel_iter_k_tail = kernel_iter_k_tail_;
        const brgemm_kernel_t *kernel_layer = kernel_layer_full_blocks_;
        const brgemm_kernel_t *kernel_layer_k_tail = kernel_layer_k_tail_;
        const auto *kernel_reduction = kernel_gates_reduction_;

        if (do_n_tail) {
            kernel_iter = kernel_iter_n_tail_;
            kernel_iter_k_tail = kernel_iter_nk_tail_;
            kernel_layer = kernel_layer_n_tail_;
            kernel_layer_k_tail = kernel_layer_nk_tail_;
            kernel_reduction = kernel_gates_reduction_tail_;
        }

        if (should_reorder_gates) {
            reorder_scratch_gates(B_n, B_blocked, do_n_tail);

            if (m_block_id == 0) {
                jit_gates_reduction_t::call_params_t params;
                params.src = reinterpret_cast<const void *>(B_blocked);
                params.dst = reinterpret_cast<void *>(diff_bias_ + n);
                (*kernel_reduction)(&params);
            }
        }

        for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
            addr_batch[k_block_id].ptr.A
                    = A_iter_transposed + k_block_id * k_block_;
            addr_batch[k_block_id].ptr.B
                    = B_blocked + k_block_id * B_kb_offset_;
        }
        if (should_transpose_src) {
            jit_brgemm_transpose_single_row_t::call_params_t params;
            params.src = reinterpret_cast<const void *>(A_iter_m);
            params.dst = reinterpret_cast<void *>(A_iter_transposed);
            (*kernel_transpose_iter_)(&params);
        }
        brgemm_kernel_execute(kernel_iter, k_blocks_, addr_batch,
                reinterpret_cast<void *>(C_diff_iter_n), nullptr);

        for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
            addr_batch[k_block_id].ptr.A
                    = A_layer_transposed + k_block_id * k_block_;
            addr_batch[k_block_id].ptr.B
                    = B_blocked + k_block_id * B_kb_offset_;
        }
        if (should_transpose_src) {
            jit_brgemm_transpose_single_row_t::call_params_t params;
            params.src = reinterpret_cast<const void *>(A_layer_m);
            params.dst = reinterpret_cast<void *>(A_layer_transposed);
            (*kernel_transpose_layer_)(&params);
        }
        brgemm_kernel_execute(kernel_layer, k_blocks_, addr_batch,
                reinterpret_cast<void *>(C_diff_layer_n), nullptr);

        if (k_tail_) {
            const auto B_blocked_k_tail = B_blocked + B_k_tail_offset_blocked_;

            addr_batch[0].ptr.A = A_iter_transposed + A_k_iter_tail_offset_;
            addr_batch[0].ptr.B = B_blocked_k_tail;

            brgemm_kernel_execute(kernel_iter_k_tail, 1, addr_batch,
                    reinterpret_cast<void *>(C_diff_iter_n), nullptr);

            addr_batch[0].ptr.A = A_layer_transposed + A_k_layer_tail_offset_;
            addr_batch[0].ptr.B = B_blocked_k_tail;

            brgemm_kernel_execute(kernel_layer_k_tail, 1, addr_batch,
                    reinterpret_cast<void *>(C_diff_layer_n), nullptr);
        }

        if (should_reorder_gates) { last_n_block_id = n_block_id; }
        if (should_transpose_src) { last_m_block_id = m_block_id; }

        ++start;
        nd_iterator_step(n_block_id, n_blocking_, m_block_id, m_blocking_);
    }
}

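// AMX flavor of the diff_weights kernel: same reuse of the repacked B and
// transposed A blocks as in the generic path, plus lazy AMX tile
// (re)configuration. When M_iter == M_layer the iter tile palettes are
// reused for the layer GEMMs.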
template <typename src_layer_t, typename src_iter_t, typename scratch_t,
        typename gemm_acc_t>
void brgemm_diff_weights_layer_iter_t<src_layer_t, src_iter_t, scratch_t,
        gemm_acc_t>::kernel_amx(const int ithr, const int nthr) const {
    using namespace cpu::rnn_utils;

    const bool global_transpose = rnn_.diff_wei_brgemm.global_transpose;

    int start = 0, end = 0;
    balance211(work_amount_, nthr, ithr, start, end);

    int n_block_id = 0, m_block_id = 0, last_n_block_id = -1,
        last_m_block_id = -1;
    switch (rnn_.diff_wei_brgemm.loop_order) {
        case brgemm_rnn_execute_loop_order_t::mblk_nblk:
            nd_iterator_init(
                    start, m_block_id, m_blocking_, n_block_id, n_blocking_);
            break;
        case brgemm_rnn_execute_loop_order_t::nblk_mblk:
            nd_iterator_init(
                    start, n_block_id, n_blocking_, m_block_id, m_blocking_);
            break;
        default: assert(!"unsupported loop order");
    }

    x64::brgemm_batch_element_t *const addr_batch
            = addr_batch_global_ + ithr * (k_blocks_ + 1);
    scratch_t *const B_blocked = B_blocked_scratch_
            + ithr * rnn_.diff_wei_brgemm.Kpadded
                    * rnn_.diff_wei_brgemm.n_block;
    scratch_t *const A_iter_transposed_ithr = global_transpose
            ? A_iter_transposed_scratch_
            : (A_iter_transposed_scratch_
                    + ithr * rnn_.diff_wei_brgemm.Kpadded * m_iter_block_);
    scratch_t *const A_layer_transposed_ithr = global_transpose
            ? A_layer_transposed_scratch_
            : A_layer_transposed_scratch_
                    + ithr * rnn_.diff_wei_brgemm.Kpadded * m_layer_block_;

    gemm_acc_t *const amx_buffer = amx_scratchpad_
            + rnn_.diff_wei_brgemm.m_block * rnn_.diff_wei_brgemm.n_block
                    * ithr;
    const bool m_equal
            = rnn_.diff_wei_brgemm.M_iter == rnn_.diff_wei_brgemm.M_layer;
    amx_tile_configuration_loader_t load_cfg_if_needed;

    while (start < end) {
        const bool should_reorder_gates = last_n_block_id != n_block_id;
        const bool should_transpose_src
                = !global_transpose && last_m_block_id != m_block_id;

        const int m_iter = m_block_id * m_iter_block_;
        const int m_layer = m_block_id * m_layer_block_;
        const src_iter_t *const A_iter_m = global_transpose
                ? A_iter_transposed_ithr + m_iter * LDA_iter_
                : A_iter_ + m_iter;
        const src_layer_t *const A_layer_m = global_transpose
                ? A_layer_transposed_ithr + m_layer * LDA_layer_
                : A_layer_ + m_layer;

        src_iter_t *const A_iter_transposed = global_transpose
                ? const_cast<src_iter_t *>(A_iter_m)
                : A_iter_transposed_ithr;
        src_layer_t *const A_layer_transposed = global_transpose
                ? const_cast<src_layer_t *>(A_layer_m)
                : A_layer_transposed_ithr;

        const int n = n_block_id * rnn_.diff_wei_brgemm.n_block;
        const scratch_t *const B_n = B_ + n;
        const auto C_iter_offset = m_iter * LDC_iter_ + n;
        const auto C_layer_offset = m_layer * LDC_layer_ + n;
        gemm_acc_t *const C_diff_iter_n = C_iter_ + C_iter_offset;
        gemm_acc_t *const C_diff_layer_n = C_layer_ + C_layer_offset;

        const bool do_n_tail
                = (n + rnn_.diff_wei_brgemm.n_block) > rnn_.diff_wei_brgemm.N;

        const brgemm_kernel_t *kernel_iter = kernel_iter_full_blocks_;
        const brgemm_kernel_t *kernel_iter_k_tail = kernel_iter_k_tail_;
        const brgemm_kernel_t *kernel_layer = kernel_layer_full_blocks_;
        const brgemm_kernel_t *kernel_layer_k_tail = kernel_layer_k_tail_;
        const auto *kernel_reduction = kernel_gates_reduction_;

        const char *kernel_iter_config
                = rnn_brgemm_.diff_wei_.pallete_buff_iter_;
        const char *kernel_iter_k_tail_config
                = rnn_brgemm_.diff_wei_.pallete_buff_iter_k_tail_;
        const char *kernel_layer_config = m_equal
                ? rnn_brgemm_.diff_wei_.pallete_buff_iter_
                : rnn_brgemm_.diff_wei_.pallete_buff_layer_;
        const char *kernel_layer_k_tail_config = m_equal
                ? rnn_brgemm_.diff_wei_.pallete_buff_iter_k_tail_
                : rnn_brgemm_.diff_wei_.pallete_buff_layer_k_tail_;

        if (do_n_tail) {
            kernel_iter = kernel_iter_n_tail_;
            kernel_iter_k_tail = kernel_iter_nk_tail_;
            kernel_layer = kernel_layer_n_tail_;
            kernel_layer_k_tail = kernel_layer_nk_tail_;
            kernel_reduction = kernel_gates_reduction_tail_;

            kernel_iter_config
                    = rnn_brgemm_.diff_wei_.pallete_buff_iter_n_tail_;
            kernel_iter_k_tail_config
                    = rnn_brgemm_.diff_wei_.pallete_buff_iter_nk_tail_;
            kernel_layer_config = m_equal
                    ? rnn_brgemm_.diff_wei_.pallete_buff_iter_n_tail_
                    : rnn_brgemm_.diff_wei_.pallete_buff_layer_n_tail_;
            kernel_layer_k_tail_config = m_equal
                    ? rnn_brgemm_.diff_wei_.pallete_buff_iter_nk_tail_
                    : rnn_brgemm_.diff_wei_.pallete_buff_layer_nk_tail_;
        }

        if (should_reorder_gates) {
            reorder_scratch_gates(B_n, B_blocked, do_n_tail);

            if (m_block_id == 0) {
                jit_gates_reduction_t::call_params_t params;
                params.src = reinterpret_cast<const void *>(B_blocked);
                params.dst = reinterpret_cast<void *>(diff_bias_ + n);
                (*kernel_reduction)(&params);
            }
        }

        for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
            addr_batch[k_block_id].ptr.A
                    = A_iter_transposed + k_block_id * k_block_;
            addr_batch[k_block_id].ptr.B
                    = B_blocked + k_block_id * B_kb_offset_;
        }
        if (should_transpose_src) {
            jit_brgemm_transpose_single_row_t::call_params_t params;
            params.src = reinterpret_cast<const void *>(A_iter_m);
            params.dst = reinterpret_cast<void *>(A_iter_transposed);
            (*kernel_transpose_iter_)(&params);
        }

        load_cfg_if_needed(kernel_iter_config);
        brgemm_kernel_execute(kernel_iter, k_blocks_, addr_batch,
                reinterpret_cast<void *>(C_diff_iter_n), amx_buffer);

        for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
            addr_batch[k_block_id].ptr.A
                    = A_layer_transposed + k_block_id * k_block_;
            addr_batch[k_block_id].ptr.B
                    = B_blocked + k_block_id * B_kb_offset_;
        }

        if (should_transpose_src) {
            jit_brgemm_transpose_single_row_t::call_params_t params;
            params.src = reinterpret_cast<const void *>(A_layer_m);
            params.dst = reinterpret_cast<void *>(A_layer_transposed);
            (*kernel_transpose_layer_)(&params);
        }

        load_cfg_if_needed(kernel_layer_config);
        brgemm_kernel_execute(kernel_layer, k_blocks_, addr_batch,
                reinterpret_cast<void *>(C_diff_layer_n), amx_buffer);

        if (k_tail_) {
            const auto B_blocked_k_tail = B_blocked + B_k_tail_offset_blocked_;

            addr_batch[0].ptr.A = A_iter_transposed + A_k_iter_tail_offset_;
            addr_batch[0].ptr.B = B_blocked_k_tail;

            load_cfg_if_needed(kernel_iter_k_tail_config);
            brgemm_kernel_execute(kernel_iter_k_tail, 1, addr_batch,
                    reinterpret_cast<void *>(C_diff_iter_n), amx_buffer);

            addr_batch[0].ptr.A = A_layer_transposed + A_k_layer_tail_offset_;
            addr_batch[0].ptr.B = B_blocked_k_tail;

            load_cfg_if_needed(kernel_layer_k_tail_config);
            brgemm_kernel_execute(kernel_layer_k_tail, 1, addr_batch,
                    reinterpret_cast<void *>(C_diff_layer_n), amx_buffer);
        }

        if (should_reorder_gates) { last_n_block_id = n_block_id; }
        if (should_transpose_src) { last_m_block_id = m_block_id; }

        ++start;
        switch (rnn_.diff_wei_brgemm.loop_order) {
            case brgemm_rnn_execute_loop_order_t::mblk_nblk:
                nd_iterator_step(
                        m_block_id, m_blocking_, n_block_id, n_blocking_);
                break;
            case brgemm_rnn_execute_loop_order_t::nblk_mblk:
                nd_iterator_step(
                        n_block_id, n_blocking_, m_block_id, m_blocking_);
                break;
            default: assert(!"unsupported loop order");
        }
    }
}

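// Computes the LSTM peephole weights gradient: for each peephole gate and
// dhc block, the jit kernel multiplies the gate's scratch diff by the
// corresponding cell state and accumulates into diff_weights_peephole.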
template <typename scratch_t>
brgemm_diff_wei_peep_t<scratch_t>::brgemm_diff_wei_peep_t(
        const ref_rnn_brgemm_t &rnn_brgemm, const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position,
        const scratch_t *scratch_gates, const void *src_iter_c,
        const void *dst_iter_c, float *diff_weights_peephole)
    : rnn_(rnn)
    , scratch_gates_(scratch_gates)
    , src_iter_c_(src_iter_c)
    , dst_iter_c_(dst_iter_c)
    , diff_weights_peephole_(diff_weights_peephole)
    , work_amount_(n_gates_ * rnn_.dhc_blocks_peephole)
    , dst_iter_c_ld_(rnn.dst_iter_c_ld(cell_position))
    , src_iter_c_ld_(rnn.src_iter_c_ld(cell_position))
    , kernel_(rnn_brgemm.kernel_peephole_.get())
    , kernel_tail_(rnn_brgemm.kernel_peephole_tail_.get()) {}

template <typename scratch_t>
void brgemm_diff_wei_peep_t<scratch_t>::execute() const {
    parallel(rnn_.nthr, [this](const int ithr, const int nthr) {
        this->kernel(ithr, nthr);
    });
}

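// The work is split across threads over (gate, dhc block) pairs; the tail
// kernel handles the last, partial dhc block when dhc_tail_peephole is set.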
template <typename scratch_t>
void brgemm_diff_wei_peep_t<scratch_t>::kernel(
        const int ithr, const int nthr) const {

    int start = 0, end = 0;
    balance211(work_amount_, nthr, ithr, start, end);

    int g = 0, dhc_block_id = 0;

    nd_iterator_init(
            start, g, n_gates_, dhc_block_id, rnn_.dhc_blocks_peephole);

    const auto dst_iter_c = rnn_utils::make_raw_aoc(dst_iter_c_,
            types::data_type_size(rnn_.dst_iter_c_dt),
            rnn_.ws_states_iter_c_nld, dst_iter_c_ld_);
    const auto src_iter_c = rnn_utils::make_raw_aoc(src_iter_c_,
            types::data_type_size(rnn_.src_iter_c_dt),
            rnn_.ws_states_iter_c_nld, src_iter_c_ld_);

    const rnn_utils::ws_gates_aoc<const scratch_t> scratch_gates(
            rnn_, scratch_gates_);
    const rnn_utils::weights_peephole_aoc_t<float> diff_weights_peephole(
            rnn_, diff_weights_peephole_);

    while (start < end) {
        const auto dhc = dhc_block_id * rnn_.dhc_block_peephole;
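        // With the i, f, c~, o gate order, the input and forget gate
        // peepholes (g < 2) read c_{t-1} (src_iter_c), while the output gate
        // peephole (g == 2) reads c_t (dst_iter_c) and its scratch is stored
        // at gate index 3.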
        const auto &c_states = g < 2 ? src_iter_c : dst_iter_c;
        const auto scratch_g = g == 2 ? 3 : g;
        const auto *const kernel = rnn_.dhc_tail_peephole
                        && dhc_block_id == rnn_.dhc_blocks_peephole - 1
                ? kernel_tail_
                : kernel_;

        jit_diff_weights_peephole_t::call_params_t params;

        for (int mb = 0; mb < rnn_.mb; ++mb) {
            params.c_states = c_states(mb, dhc);
            params.scratch_gates = &scratch_gates(mb, scratch_g, dhc);
            params.dst = &diff_weights_peephole(g, dhc);
            (*kernel)(&params);
        }

        ++start;
        nd_iterator_step(g, n_gates_, dhc_block_id, rnn_.dhc_blocks_peephole);
    }
}

template class brgemm_diff_src_layer_iter_t<float, float, float>;
template class brgemm_diff_src_layer_iter_t<bfloat16_t, bfloat16_t, float>;

template class brgemm_diff_weights_layer_iter_t<float, float, float, float>;
template class brgemm_diff_weights_layer_iter_t<bfloat16_t, bfloat16_t,
        bfloat16_t, float>;

template class brgemm_diff_wei_peep_t<bfloat16_t>;
template class brgemm_diff_wei_peep_t<float>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl