/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "cpu/x64/rnn/brgemm_cell_common_bwd.hpp"

#include "common/dnnl_thread.hpp"
#include "common/utils.hpp"

#include "cpu/x64/rnn/brgemm_cell_common_reorders.hpp"
#include "cpu/x64/rnn/brgemm_cell_common_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::utils;

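// Computes the cell backward gradients w.r.t. the layer input
// (diff_src_layer) and the recurrent/iteration input (diff_src_iter).
// A holds the scratch gates, B the iteration/layer weights, and C the
// corresponding diff_src buffers; results are accumulated over all gates
// with BRGEMM batches built from K blocks plus an optional K tail.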
template <typename weights_t, typename scratch_t, typename gemm_acc_t>
brgemm_diff_src_layer_iter_t<weights_t, scratch_t,
        gemm_acc_t>::brgemm_diff_src_layer_iter_t(const ref_rnn_brgemm_t
                                                          &rnn_brgemm,
        const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, scratch_t *scratch_gates,
        weights_t *w_iter, weights_t *w_layer, gemm_acc_t *diff_src_iter,
        gemm_acc_t *diff_src_layer, gemm_acc_t *amx_scratchpad,
        x64::brgemm_batch_element_t *addr_batch_global)
    : rnn_brgemm_(rnn_brgemm)
    , rnn_(rnn)
    , A_(scratch_gates)
    , B_wei_iter_(w_iter)
    , B_wei_layer_(w_layer)
    , C_diff_iter_(diff_src_iter)
    , C_diff_layer_(diff_src_layer)
    , k_blocks_n_gates_(rnn.diff_src_brgemm.K_blocks)
    , k_blocks_(rnn.diff_src_brgemm.K_blocks / rnn.n_gates)
    , k_tail_(rnn.diff_src_brgemm.k_tail)
    , k_block_(rnn.diff_src_brgemm.k_block)
    , A_k_tail_offset_(k_blocks_ * k_block_)
    , B_k_tail_offset_(A_k_tail_offset_ * rnn.diff_src_brgemm.n_block)
    , B_nb_offset_(rnn.diff_src_brgemm.Kpadded * rnn.diff_src_brgemm.n_block)
    , B_kb_offset_(k_block_ * rnn.diff_src_brgemm.n_block)
    , B_gb_iter_offset_(rnn.diff_src_brgemm.Kpadded
              * rnn.diff_src_brgemm.n_block * rnn.diff_src_brgemm.N_iter_blocks)
    , B_gb_layer_offset_(rnn.diff_src_brgemm.Kpadded
              * rnn.diff_src_brgemm.n_block
              * rnn.diff_src_brgemm.N_layer_blocks)
    , LDA_(rnn.diff_src_brgemm.LDA)
    , LDC_(rnn.diff_src_brgemm.LDC)
    , max_nthr_(rnn.nthr)
    , n_blocking_(rnn.diff_src_brgemm.N_blocks)
    , m_blocking_(rnn.diff_src_brgemm.M_blocks)
    , work_amount_(n_blocking_ * m_blocking_)
    , max_n_layer_blocks_(rnn.diff_src_brgemm.N_layer_blocks)
    , max_n_iter_blocks_(rnn.diff_src_brgemm.N_iter_blocks)
    , gemm_layer_needed_(rnn.need_gemm_layer(cell_position))
    , kernel_iter_full_blocks_b0_(
              rnn_brgemm_.diff_src_.kernel_iter_layer_beta0_.get())
    , kernel_iter_full_blocks_b1_(
              rnn_brgemm_.diff_src_.kernel_iter_layer_beta1_.get())

    , kernel_iter_n_tail_b0_(
              rnn_brgemm_.diff_src_.kernel_iter_N_tail_beta0_.get())
    , kernel_iter_n_tail_b1_(
              rnn_brgemm_.diff_src_.kernel_iter_N_tail_beta1_.get())
    , kernel_iter_k_tail_(
              rnn_brgemm_.diff_src_.kernel_iter_layer_K_tail_beta1_.get())
    , kernel_iter_nk_tail_(
              rnn_brgemm_.diff_src_.kernel_iter_NK_tail_beta1_.get())
    , kernel_layer_full_blocks_b0_(
              rnn_brgemm_.diff_src_.kernel_iter_layer_beta0_.get())
    , kernel_layer_full_blocks_b1_(
              rnn_brgemm_.diff_src_.kernel_iter_layer_beta1_.get())
    , kernel_layer_n_tail_b0_(
              rnn_brgemm_.diff_src_.kernel_layer_N_tail_beta0_.get())
    , kernel_layer_n_tail_b1_(
              rnn_brgemm_.diff_src_.kernel_layer_N_tail_beta1_.get())
    , kernel_layer_k_tail_(
              rnn_brgemm_.diff_src_.kernel_iter_layer_K_tail_beta1_.get())
    , kernel_layer_nk_tail_(
              rnn_brgemm_.diff_src_.kernel_layer_NK_tail_beta1_.get())
    , amx_scratchpad_(amx_scratchpad)
    , addr_batch_global_(addr_batch_global) {}

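// Dispatches the per-thread kernel: the AMX variant is used for bf16 cells
// configured for avx512_core_amx, the generic variant otherwise.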
template <typename weights_t, typename scratch_t, typename gemm_acc_t>
void brgemm_diff_src_layer_iter_t<weights_t, scratch_t, gemm_acc_t>::execute()
        const {
    if (rnn_.is_cell_dt_bf16()
            && rnn_.diff_src_brgemm.isa == x64::avx512_core_amx) {
        parallel(max_nthr_, [this](const int ithr, const int nthr) {
            this->kernel_amx(ithr, nthr);
        });
    } else {
        parallel(max_nthr_, [this](const int ithr, const int nthr) {
            this->kernel(ithr, nthr);
        });
    }
}

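// Processes a single (m, n) block for the given gate range on the AMX path:
// builds the BRGEMM batch of A/B pointers per gate and K block, reconfigures
// the tile palette when needed, and executes the full-K GEMMs followed by the
// K-tail GEMMs for diff_src_iter and diff_src_layer. N-tail kernels and
// palettes are selected on the last, partial N block.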
template <typename weights_t, typename scratch_t, typename gemm_acc_t>
void brgemm_diff_src_layer_iter_t<weights_t, scratch_t,
        gemm_acc_t>::kernel_amx_compute_iter(const int m_block_id,
        const int n_block_id, const int gates_start, const int gates_end,
        thread_exec_ctx_t &ctx) const {

    const int m = m_block_id * rnn_.diff_src_brgemm.m_block;
    const int n = n_block_id * rnn_.diff_src_brgemm.n_block;
    const int num_gates = gates_end - gates_start;
    const scratch_t *const A_m = A_ + m * LDA_;
    const auto B_n_offset = n_block_id * B_nb_offset_;
    const weights_t *const B_wei_iter_n = B_wei_iter_ + B_n_offset;
    const weights_t *const B_wei_layer_n = B_wei_layer_ + B_n_offset;
    const auto C_offset = m * LDC_ + n;
    gemm_acc_t *const C_diff_iter_n = C_diff_iter_ + C_offset;
    gemm_acc_t *const C_diff_layer_n = C_diff_layer_ + C_offset;

    const brgemm_kernel_t *kernel_iter = gates_start == 0
            ? kernel_iter_full_blocks_b0_
            : kernel_iter_full_blocks_b1_;
    const brgemm_kernel_t *kernel_iter_k_tail = kernel_iter_k_tail_;
    const brgemm_kernel_t *kernel_layer = gates_start == 0
            ? kernel_layer_full_blocks_b0_
            : kernel_layer_full_blocks_b1_;
    const brgemm_kernel_t *kernel_layer_k_tail = kernel_layer_k_tail_;

    const char *kernel_iter_config
            = rnn_brgemm_.diff_src_.pallete_buff_iter_layer_;
    const char *kernel_iter_k_tail_config
            = rnn_brgemm_.diff_src_.pallete_buff_iter_layer_k_tail_;
    const char *kernel_layer_config
            = rnn_brgemm_.diff_src_.pallete_buff_iter_layer_;
    const char *kernel_layer_k_tail_config
            = rnn_brgemm_.diff_src_.pallete_buff_iter_layer_k_tail_;

    const bool should_calc_diff_src_layer
            = gemm_layer_needed_ && n_block_id < max_n_layer_blocks_;
    const bool should_calc_diff_src_iter = n_block_id < max_n_iter_blocks_;

    if (should_calc_diff_src_iter) {
        const bool do_n_iter_tail = (n + rnn_.diff_src_brgemm.n_block)
                > rnn_.diff_src_brgemm.N_iter;

        if (do_n_iter_tail) {
            kernel_iter = gates_start == 0 ? kernel_iter_n_tail_b0_
                                           : kernel_iter_n_tail_b1_;
            kernel_iter_k_tail = kernel_iter_nk_tail_;
            kernel_iter_config
                    = rnn_brgemm_.diff_src_.pallete_buff_iter_n_tail_;
            kernel_iter_k_tail_config
                    = rnn_brgemm_.diff_src_.pallete_buff_iter_nk_tail_;
        }

        for (int gate_id = gates_start; gate_id < gates_end; gate_id++) {
            const auto g_block_id = gate_id * k_blocks_;
            const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
            const auto B_g_offset = gate_id * B_gb_iter_offset_;
            const auto A_gm = A_m + A_gb_offset;
            const auto B_wei_iter_gn = B_wei_iter_n + B_g_offset;
            for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
                ctx.addr_batch[g_block_id + k_block_id].ptr.A
                        = A_gm + k_block_id * k_block_;
                ctx.addr_batch[g_block_id + k_block_id].ptr.B
                        = B_wei_iter_gn + k_block_id * B_kb_offset_;
            }
        }

        ctx.tile_configure_if_needed(kernel_iter_config);
        brgemm_kernel_execute(kernel_iter, k_blocks_ * num_gates,
                ctx.addr_batch, reinterpret_cast<void *>(C_diff_iter_n),
                ctx.amx_buffer);
    }

    if (should_calc_diff_src_layer) {
        const bool do_n_layer_tail = (n + rnn_.diff_src_brgemm.n_block)
                > rnn_.diff_src_brgemm.N_layer;

        if (do_n_layer_tail) {
            kernel_layer = gates_start == 0 ? kernel_layer_n_tail_b0_
                                            : kernel_layer_n_tail_b1_;
            kernel_layer_k_tail = kernel_layer_nk_tail_;
            kernel_layer_config
                    = rnn_brgemm_.diff_src_.pallete_buff_layer_n_tail_;
            kernel_layer_k_tail_config
                    = rnn_brgemm_.diff_src_.pallete_buff_layer_nk_tail_;
        }

        for (int gate_id = gates_start; gate_id < gates_end; gate_id++) {
            const auto g_block_id = gate_id * k_blocks_;
            const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
            const auto B_g_offset = gate_id * B_gb_layer_offset_;
            const auto A_gm = A_m + A_gb_offset;
            const auto B_wei_layer_gn = B_wei_layer_n + B_g_offset;
            for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
                ctx.addr_batch[g_block_id + k_block_id].ptr.A
                        = A_gm + k_block_id * k_block_;
                ctx.addr_batch[g_block_id + k_block_id].ptr.B
                        = B_wei_layer_gn + k_block_id * B_kb_offset_;
            }
        }

        ctx.tile_configure_if_needed(kernel_layer_config);
        brgemm_kernel_execute(kernel_layer, k_blocks_ * num_gates,
                ctx.addr_batch, reinterpret_cast<void *>(C_diff_layer_n),
                ctx.amx_buffer);
    }

    if (should_calc_diff_src_iter && k_tail_) {
        for (int gate_id = gates_start; gate_id < gates_end; gate_id++) {
            const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
            const auto B_gb_offset = gate_id * B_gb_iter_offset_;
            ctx.addr_batch[gate_id].ptr.A
                    = A_m + A_gb_offset + A_k_tail_offset_;
            ctx.addr_batch[gate_id].ptr.B
                    = B_wei_iter_n + B_gb_offset + B_k_tail_offset_;
        }

        ctx.tile_configure_if_needed(kernel_iter_k_tail_config);
        brgemm_kernel_execute(kernel_iter_k_tail, num_gates, ctx.addr_batch,
                reinterpret_cast<void *>(C_diff_iter_n), ctx.amx_buffer);
    }

    if (should_calc_diff_src_layer && k_tail_) {
        for (int gate_id = gates_start; gate_id < gates_end; gate_id++) {
            const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
            const auto B_gb_offset = gate_id * B_gb_layer_offset_;
            ctx.addr_batch[gate_id].ptr.A
                    = A_m + A_gb_offset + A_k_tail_offset_;
            ctx.addr_batch[gate_id].ptr.B
                    = B_wei_layer_n + B_gb_offset + B_k_tail_offset_;
        }

        ctx.tile_configure_if_needed(kernel_layer_k_tail_config);
        brgemm_kernel_execute(kernel_layer_k_tail, num_gates, ctx.addr_batch,
                reinterpret_cast<void *>(C_diff_layer_n), ctx.amx_buffer);
    }
}

// TODO consider merging with kernel - check perf after merge
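// AMX path: splits the M x N block space across threads, walks the blocks in
// the configured loop order, and processes gates in chunks of
// diff_src_brgemm.gates_block, using the beta = 0 kernels for the first chunk
// and the beta = 1 (accumulating) kernels for the remaining ones.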
template <typename weights_t, typename scratch_t, typename gemm_acc_t>
void brgemm_diff_src_layer_iter_t<weights_t, scratch_t, gemm_acc_t>::kernel_amx(
        const int ithr, const int nthr) const {
    using namespace cpu::rnn_utils;

    int mn_start = 0, mn_end = 0;
    balance211(work_amount_, nthr, ithr, mn_start, mn_end);

    int n_block_id = 0, m_block_id = 0;
    const auto n_gates = rnn_.n_gates;
    const int gates_block_size = rnn_.diff_src_brgemm.gates_block;
    thread_exec_ctx_t ctx;
    ctx.addr_batch = addr_batch_global_ + ithr * (k_blocks_n_gates_ + 1);
    ctx.amx_buffer = amx_scratchpad_
            + rnn_.diff_src_brgemm.m_block * rnn_.diff_src_brgemm.n_block
                    * ithr;

    for (int gate_idx = 0; gate_idx < n_gates; gate_idx += gates_block_size) {
        const int gates_start = gate_idx;
        const int gates_end = nstl::min(gate_idx + gates_block_size, n_gates);

        switch (rnn_.diff_src_brgemm.loop_order) {
            case brgemm_rnn_execute_loop_order_t::mblk_nblk:
                nd_iterator_init(mn_start, m_block_id, m_blocking_, n_block_id,
                        n_blocking_);
                break;
            case brgemm_rnn_execute_loop_order_t::nblk_mblk:
                nd_iterator_init(mn_start, n_block_id, n_blocking_, m_block_id,
                        m_blocking_);
                break;
            default: assert(!"unsupported loop order");
        }
        int mn_idx = mn_start;
        while (mn_idx < mn_end) {
            kernel_amx_compute_iter(
                    m_block_id, n_block_id, gates_start, gates_end, ctx);
            ++mn_idx;
            switch (rnn_.diff_src_brgemm.loop_order) {
                case brgemm_rnn_execute_loop_order_t::mblk_nblk:
                    nd_iterator_step(
                            m_block_id, m_blocking_, n_block_id, n_blocking_);
                    break;
                case brgemm_rnn_execute_loop_order_t::nblk_mblk:
                    nd_iterator_step(
                            n_block_id, n_blocking_, m_block_id, m_blocking_);
                    break;
                default: assert(!"unsupported loop order");
            }
        }
    }
}

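// Generic (non-AMX) path: same M x N blocking, but all gates are chained into
// a single BRGEMM batch per block (beta = 0 for the full-K part, beta = 1 for
// the K tail), and no AMX tile configuration takes place.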
template <typename weights_t, typename scratch_t, typename gemm_acc_t>
void brgemm_diff_src_layer_iter_t<weights_t, scratch_t, gemm_acc_t>::kernel(
        const int ithr, const int nthr) const {
    int start = 0, end = 0;
    balance211(work_amount_, nthr, ithr, start, end);

    int n_block_id = 0, m_block_id = 0;
    nd_iterator_init(start, n_block_id, n_blocking_, m_block_id, m_blocking_);

    x64::brgemm_batch_element_t *const addr_batch
            = addr_batch_global_ + ithr * (k_blocks_n_gates_ + 1);
    const auto n_gates = rnn_.n_gates;

    while (start < end) {
        const int m = m_block_id * rnn_.diff_src_brgemm.m_block;
        const int n = n_block_id * rnn_.diff_src_brgemm.n_block;
        const scratch_t *const A_m = A_ + m * LDA_;
        const auto B_n_offset = n_block_id * B_nb_offset_;
        const weights_t *const B_wei_iter_n = B_wei_iter_ + B_n_offset;
        const weights_t *const B_wei_layer_n = B_wei_layer_ + B_n_offset;
        const auto C_offset = m * LDC_ + n;
        gemm_acc_t *const C_diff_iter_n = C_diff_iter_ + C_offset;
        gemm_acc_t *const C_diff_layer_n = C_diff_layer_ + C_offset;
        const brgemm_kernel_t *kernel_iter = kernel_iter_full_blocks_b0_;
        const brgemm_kernel_t *kernel_iter_k_tail = kernel_iter_k_tail_;
        const brgemm_kernel_t *kernel_layer = kernel_layer_full_blocks_b0_;
        const brgemm_kernel_t *kernel_layer_k_tail = kernel_layer_k_tail_;
        const bool should_calc_diff_src_layer
                = gemm_layer_needed_ && n_block_id < max_n_layer_blocks_;
        const bool should_calc_diff_src_iter = n_block_id < max_n_iter_blocks_;

        if (should_calc_diff_src_iter) {
            const bool do_n_iter_tail = (n + rnn_.diff_src_brgemm.n_block)
                    > rnn_.diff_src_brgemm.N_iter;

            if (do_n_iter_tail) {
                kernel_iter = kernel_iter_n_tail_b0_;
                kernel_iter_k_tail = kernel_iter_nk_tail_;
            }

            for (int gate_id = 0; gate_id < n_gates; gate_id++) {
                const auto g_block_id = gate_id * k_blocks_;
                const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
                const auto B_gb_offset = gate_id * B_gb_iter_offset_;
                const auto A_gm = A_m + A_gb_offset;
                const auto B_wei_iter_gn = B_wei_iter_n + B_gb_offset;
                for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
                    addr_batch[g_block_id + k_block_id].ptr.A
                            = A_gm + k_block_id * k_block_;
                    addr_batch[g_block_id + k_block_id].ptr.B
                            = B_wei_iter_gn + k_block_id * B_kb_offset_;
                }
            }

            brgemm_kernel_execute(kernel_iter, k_blocks_n_gates_, addr_batch,
                    reinterpret_cast<void *>(C_diff_iter_n), nullptr);
        }

        if (should_calc_diff_src_layer) {
            const bool do_n_layer_tail = (n + rnn_.diff_src_brgemm.n_block)
                    > rnn_.diff_src_brgemm.N_layer;

            if (do_n_layer_tail) {
                kernel_layer = kernel_layer_n_tail_b0_;
                kernel_layer_k_tail = kernel_layer_nk_tail_;
            }

            for (int gate_id = 0; gate_id < n_gates; gate_id++) {
                const auto g_block_id = gate_id * k_blocks_;
                const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
                const auto B_gb_offset = gate_id * B_gb_layer_offset_;
                const auto A_gm = A_m + A_gb_offset;
                const auto B_wei_layer_gn = B_wei_layer_n + B_gb_offset;
                for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
                    addr_batch[g_block_id + k_block_id].ptr.A
                            = A_gm + k_block_id * k_block_;
                    addr_batch[g_block_id + k_block_id].ptr.B
                            = B_wei_layer_gn + k_block_id * B_kb_offset_;
                }
            }
            brgemm_kernel_execute(kernel_layer, k_blocks_n_gates_, addr_batch,
                    reinterpret_cast<void *>(C_diff_layer_n), nullptr);
        }

        if (should_calc_diff_src_iter && k_tail_) {
            for (int gate_id = 0; gate_id < n_gates; gate_id++) {
                const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
                const auto B_gb_offset = gate_id * B_gb_iter_offset_;
                addr_batch[gate_id].ptr.A
                        = A_m + A_gb_offset + A_k_tail_offset_;
                addr_batch[gate_id].ptr.B
                        = B_wei_iter_n + B_gb_offset + B_k_tail_offset_;
            }

            brgemm_kernel_execute(kernel_iter_k_tail, n_gates, addr_batch,
                    reinterpret_cast<void *>(C_diff_iter_n), nullptr);
        }

        if (should_calc_diff_src_layer && k_tail_) {
            for (int gate_id = 0; gate_id < n_gates; gate_id++) {
                const auto A_gb_offset = gate_id * rnn_.diff_src_brgemm.K;
                const auto B_gb_offset = gate_id * B_gb_layer_offset_;
                addr_batch[gate_id].ptr.A
                        = A_m + A_gb_offset + A_k_tail_offset_;
                addr_batch[gate_id].ptr.B
                        = B_wei_layer_n + B_gb_offset + B_k_tail_offset_;
            }

            brgemm_kernel_execute(kernel_layer_k_tail, n_gates, addr_batch,
                    reinterpret_cast<void *>(C_diff_layer_n), nullptr);
        }

        ++start;
        nd_iterator_step(n_block_id, n_blocking_, m_block_id, m_blocking_);
    }
}

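// Computes the cell backward gradients w.r.t. the iteration and layer weights
// (and diff_bias via gate reduction). A holds the src_iter/src_layer
// activations, transposed either globally or on the fly per block, B holds the
// scratch gates reordered into a blocked layout, and C the diff_weights
// buffers.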
template <typename src_layer_t, typename src_iter_t, typename scratch_t,
        typename gemm_acc_t>
brgemm_diff_weights_layer_iter_t<src_layer_t, src_iter_t, scratch_t,
        gemm_acc_t>::brgemm_diff_weights_layer_iter_t(const ref_rnn_brgemm_t
                                                              &rnn_brgemm,
        const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, const src_layer_t *src_iter,
        scratch_t *const A_iter_transposed_scratch, const src_iter_t *src_layer,
        scratch_t *const A_layer_transposed_scratch, const scratch_t *scratch,
        scratch_t *scratch_gates_blocked, gemm_acc_t *diff_weights_iter,
        gemm_acc_t *diff_weights_layer, gemm_acc_t *diff_bias,
        gemm_acc_t *amx_scratchpad,
        x64::brgemm_batch_element_t *addr_batch_global)
    : rnn_brgemm_(rnn_brgemm)
    , rnn_(rnn)
    , is_amx_(rnn_.is_cell_bf16_amx())
    , A_iter_(src_iter)
    , A_iter_transposed_scratch_(A_iter_transposed_scratch)
    , A_layer_(src_layer)
    , A_layer_transposed_scratch_(A_layer_transposed_scratch)
    , B_(scratch)
    , B_blocked_scratch_(scratch_gates_blocked)
    , C_iter_(diff_weights_iter)
    , C_layer_(diff_weights_layer)
    , diff_bias_(diff_bias)
    , LDA_iter_(rnn.diff_wei_brgemm.LDA_iter)
    , LDA_layer_(rnn.diff_wei_brgemm.LDA_layer)
    , LDC_iter_(rnn.diff_wei_brgemm.LDC_iter)
    , LDC_layer_(rnn.diff_wei_brgemm.LDC_layer)
    , max_nthr_(rnn.nthr)
    , n_blocking_(rnn.diff_wei_brgemm.N_blocks)
    , m_blocking_(rnn.diff_wei_brgemm.M_blocks)
    , k_blocks_(rnn.diff_wei_brgemm.K_blocks)
    , k_tail_(rnn.diff_wei_brgemm.k_tail)
    , k_block_(rnn.diff_wei_brgemm.k_block)
    , m_iter_block_(rnn.slc == rnn.sic ? rnn.diff_wei_brgemm.m_block
                                       : rnn.diff_wei_brgemm.M_iter)
    , m_layer_block_(rnn.slc == rnn.sic ? rnn.diff_wei_brgemm.m_block
                                        : rnn.diff_wei_brgemm.M_layer)
    , A_k_iter_tail_offset_(k_blocks_ * k_block_)
    , A_k_layer_tail_offset_(k_blocks_ * k_block_)
    , B_kb_offset_(k_block_ * rnn.diff_wei_brgemm.n_block)
    , B_k_tail_offset_(k_blocks_ * k_block_ * rnn.scratch_gates_ld)
    , B_k_tail_offset_blocked_(
              k_blocks_ * k_block_ * rnn.diff_wei_brgemm.n_block)
    , work_amount_(n_blocking_ * m_blocking_)
    , kernel_iter_full_blocks_(rnn_brgemm.diff_wei_.kernel_iter_beta1_.get())
    , kernel_iter_n_tail_(rnn_brgemm.diff_wei_.kernel_iter_N_tail_beta1_.get())
    , kernel_iter_k_tail_(rnn_brgemm.diff_wei_.kernel_iter_K_tail_beta1_.get())
    , kernel_iter_nk_tail_(
              rnn_brgemm.diff_wei_.kernel_iter_NK_tail_beta1_.get())
    , kernel_layer_full_blocks_(rnn_brgemm.diff_wei_.kernel_layer_beta1_.get())
    , kernel_layer_n_tail_(
              rnn_brgemm.diff_wei_.kernel_layer_N_tail_beta1_.get())
    , kernel_layer_k_tail_(
              rnn_brgemm.diff_wei_.kernel_layer_K_tail_beta1_.get())
    , kernel_layer_nk_tail_(
              rnn_brgemm.diff_wei_.kernel_layer_NK_tail_beta1_.get())
    , cell_position_(cell_position)
    , kernel_gates_reduction_(rnn_brgemm.kernel_gates_reduction_.get())
    , kernel_gates_reduction_tail_(
              rnn_brgemm.kernel_gates_reduction_tail_.get())
    , kernel_transpose_iter_(rnn_brgemm.kernel_transpose_single_row_iter_.get())
    , kernel_transpose_layer_(rnn_brgemm.kernel_transpose_single_row_layer_
                      ? rnn_brgemm.kernel_transpose_single_row_layer_.get()
                      : kernel_transpose_iter_)
    , amx_scratchpad_(amx_scratchpad)
    , addr_batch_global_(addr_batch_global) {}

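// Dispatches to the AMX kernel for bf16 AMX cells, otherwise to the generic
// kernel.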
template <typename src_layer_t, typename src_iter_t, typename scratch_t,
        typename gemm_acc_t>
void brgemm_diff_weights_layer_iter_t<src_layer_t, src_iter_t, scratch_t,
        gemm_acc_t>::execute() const {
    if (is_amx_) {
        parallel(max_nthr_, [this](const int ithr, const int nthr) {
            this->kernel_amx(ithr, nthr);
        });
    } else {
        parallel(max_nthr_, [this](const int ithr, const int nthr) {
            this->kernel(ithr, nthr);
        });
    }
}

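// Copies one N block of scratch gates into the blocked layout expected as B by
// the BRGEMM kernels, using the matmul copy-B kernel; the N tail uses n_tail
// columns instead of a full n_block.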
template <typename src_layer_t, typename src_iter_t, typename scratch_t,
        typename gemm_acc_t>
void brgemm_diff_weights_layer_iter_t<src_layer_t, src_iter_t, scratch_t,
        gemm_acc_t>::reorder_scratch_gates(const scratch_t *src, scratch_t *dst,
        const bool do_n_tail) const {
    auto ctx = matmul::jit_brgemm_matmul_copy_b_t::ctx_t();
    ctx.src = (void *)src;
    ctx.tr_src = (void *)dst;
    ctx.current_K_start = 0;
    ctx.current_K_iters = rnn_.mb;
    ctx.current_N_blk = do_n_tail ? rnn_.diff_wei_brgemm.n_tail
                                  : rnn_.diff_wei_brgemm.n_block;
    (*rnn_brgemm_.diff_wei_.srcatch_gates_reorder_kernel_)(&ctx);
}

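// Generic (non-AMX) path: for each assigned (n, m) block, reorders the current
// scratch-gates N block once per N block (reducing it into diff_bias on the
// first M block), transposes the A inputs when a per-block transpose is
// required, then runs the full-K and K-tail BRGEMMs for both weight gradients.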
template <typename src_layer_t, typename src_iter_t, typename scratch_t,
        typename gemm_acc_t>
void brgemm_diff_weights_layer_iter_t<src_layer_t, src_iter_t, scratch_t,
        gemm_acc_t>::kernel(const int ithr, const int nthr) const {

    const bool global_transpose = rnn_.diff_wei_brgemm.global_transpose;

    scratch_t *const B_blocked = B_blocked_scratch_
            + ithr * rnn_.diff_wei_brgemm.Kpadded
                    * rnn_.diff_wei_brgemm.n_block;

    scratch_t *const A_iter_transposed_ithr = global_transpose
            ? A_iter_transposed_scratch_
            : (A_iter_transposed_scratch_
                    + ithr * rnn_.diff_wei_brgemm.Kpadded * m_iter_block_);

    scratch_t *const A_layer_transposed_ithr = global_transpose
            ? A_layer_transposed_scratch_
            : A_layer_transposed_scratch_
                    + ithr * rnn_.diff_wei_brgemm.Kpadded * m_layer_block_;

    int start = 0, end = 0;
    balance211(work_amount_, nthr, ithr, start, end);

    int n_block_id = 0, m_block_id = 0, last_n_block_id = -1,
        last_m_block_id = -1;

    nd_iterator_init(start, n_block_id, n_blocking_, m_block_id, m_blocking_);

    x64::brgemm_batch_element_t *const addr_batch
            = addr_batch_global_ + ithr * (k_blocks_ + 1);

    while (start < end) {
        const bool should_reorder_gates = last_n_block_id != n_block_id;
        const bool transpose_needed
                = !(rnn_.mb == 1 && std::is_same<float, src_iter_t>::value);
        const bool should_transpose_src = transpose_needed && !global_transpose
                && (last_m_block_id != m_block_id);

        const int m_iter = m_block_id * m_iter_block_;
        const int m_layer = m_block_id * m_layer_block_;
        const src_iter_t *const A_iter_m = global_transpose
                ? A_iter_transposed_ithr + m_iter * LDA_iter_
                : A_iter_ + m_iter;
        const src_layer_t *const A_layer_m = global_transpose
                ? A_layer_transposed_ithr + m_layer * LDA_layer_
                : A_layer_ + m_layer;

        src_iter_t *const A_iter_transposed
                = (global_transpose || !transpose_needed)
                ? const_cast<src_iter_t *>(A_iter_m)
                : A_iter_transposed_ithr;
        src_layer_t *const A_layer_transposed
                = (global_transpose || !transpose_needed)
                ? const_cast<src_layer_t *>(A_layer_m)
                : A_layer_transposed_ithr;

        const int n = n_block_id * rnn_.diff_wei_brgemm.n_block;
        const scratch_t *const B_n = B_ + n;
        const auto C_iter_offset = m_iter * LDC_iter_ + n;
        const auto C_layer_offset = m_layer * LDC_layer_ + n;
        gemm_acc_t *const C_diff_iter_n = C_iter_ + C_iter_offset;
        gemm_acc_t *const C_diff_layer_n = C_layer_ + C_layer_offset;

        const bool do_n_tail
                = (n + rnn_.diff_wei_brgemm.n_block) > rnn_.diff_wei_brgemm.N;

        const brgemm_kernel_t *kernel_iter = kernel_iter_full_blocks_;
        const brgemm_kernel_t *kernel_iter_k_tail = kernel_iter_k_tail_;
        const brgemm_kernel_t *kernel_layer = kernel_layer_full_blocks_;
        const brgemm_kernel_t *kernel_layer_k_tail = kernel_layer_k_tail_;
        const auto *kernel_reduction = kernel_gates_reduction_;

        if (do_n_tail) {
            kernel_iter = kernel_iter_n_tail_;
            kernel_iter_k_tail = kernel_iter_nk_tail_;
            kernel_layer = kernel_layer_n_tail_;
            kernel_layer_k_tail = kernel_layer_nk_tail_;
            kernel_reduction = kernel_gates_reduction_tail_;
        }

        if (should_reorder_gates) {
            reorder_scratch_gates(B_n, B_blocked, do_n_tail);

            if (m_block_id == 0) {
                jit_gates_reduction_t::call_params_t params;
                params.src = reinterpret_cast<const void *>(B_blocked);
                params.dst = reinterpret_cast<void *>(diff_bias_ + n);
                (*kernel_reduction)(&params);
            }
        }

        for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
            addr_batch[k_block_id].ptr.A
                    = A_iter_transposed + k_block_id * k_block_;
            addr_batch[k_block_id].ptr.B
                    = B_blocked + k_block_id * B_kb_offset_;
        }
        if (should_transpose_src) {
            jit_brgemm_transpose_single_row_t::call_params_t params;
            params.src = reinterpret_cast<const void *>(A_iter_m);
            params.dst = reinterpret_cast<void *>(A_iter_transposed);
            (*kernel_transpose_iter_)(&params);
        }
        brgemm_kernel_execute(kernel_iter, k_blocks_, addr_batch,
                reinterpret_cast<void *>(C_diff_iter_n), nullptr);

        for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
            addr_batch[k_block_id].ptr.A
                    = A_layer_transposed + k_block_id * k_block_;
            addr_batch[k_block_id].ptr.B
                    = B_blocked + k_block_id * B_kb_offset_;
        }
        if (should_transpose_src) {
            jit_brgemm_transpose_single_row_t::call_params_t params;
            params.src = reinterpret_cast<const void *>(A_layer_m);
            params.dst = reinterpret_cast<void *>(A_layer_transposed);
            (*kernel_transpose_layer_)(&params);
        }
        brgemm_kernel_execute(kernel_layer, k_blocks_, addr_batch,
                reinterpret_cast<void *>(C_diff_layer_n), nullptr);

        if (k_tail_) {
            const auto B_blocked_k_tail = B_blocked + B_k_tail_offset_blocked_;

            addr_batch[0].ptr.A = A_iter_transposed + A_k_iter_tail_offset_;
            addr_batch[0].ptr.B = B_blocked_k_tail;

            brgemm_kernel_execute(kernel_iter_k_tail, 1, addr_batch,
                    reinterpret_cast<void *>(C_diff_iter_n), nullptr);

            addr_batch[0].ptr.A = A_layer_transposed + A_k_layer_tail_offset_;
            addr_batch[0].ptr.B = B_blocked_k_tail;

            brgemm_kernel_execute(kernel_layer_k_tail, 1, addr_batch,
                    reinterpret_cast<void *>(C_diff_layer_n), nullptr);
        }

        if (should_reorder_gates) { last_n_block_id = n_block_id; }
        if (should_transpose_src) { last_m_block_id = m_block_id; }

        ++start;
        nd_iterator_step(n_block_id, n_blocking_, m_block_id, m_blocking_);
    }
}

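// AMX path: same structure as kernel(), with AMX tile palette management and a
// configurable (m-major or n-major) loop order; the iter palettes are reused
// for the layer GEMMs when M_iter == M_layer.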
template <typename src_layer_t, typename src_iter_t, typename scratch_t,
        typename gemm_acc_t>
void brgemm_diff_weights_layer_iter_t<src_layer_t, src_iter_t, scratch_t,
        gemm_acc_t>::kernel_amx(const int ithr, const int nthr) const {
    using namespace cpu::rnn_utils;

    const bool global_transpose = rnn_.diff_wei_brgemm.global_transpose;

    int start = 0, end = 0;
    balance211(work_amount_, nthr, ithr, start, end);

    int n_block_id = 0, m_block_id = 0, last_n_block_id = -1,
        last_m_block_id = -1;
    switch (rnn_.diff_wei_brgemm.loop_order) {
        case brgemm_rnn_execute_loop_order_t::mblk_nblk:
            nd_iterator_init(
                    start, m_block_id, m_blocking_, n_block_id, n_blocking_);
            break;
        case brgemm_rnn_execute_loop_order_t::nblk_mblk:
            nd_iterator_init(
                    start, n_block_id, n_blocking_, m_block_id, m_blocking_);
            break;
        default: assert(!"unsupported loop order");
    }

    x64::brgemm_batch_element_t *const addr_batch
            = addr_batch_global_ + ithr * (k_blocks_ + 1);
    scratch_t *const B_blocked = B_blocked_scratch_
            + ithr * rnn_.diff_wei_brgemm.Kpadded
                    * rnn_.diff_wei_brgemm.n_block;
    scratch_t *const A_iter_transposed_ithr = global_transpose
            ? A_iter_transposed_scratch_
            : (A_iter_transposed_scratch_
                    + ithr * rnn_.diff_wei_brgemm.Kpadded * m_iter_block_);
    scratch_t *const A_layer_transposed_ithr = global_transpose
            ? A_layer_transposed_scratch_
            : A_layer_transposed_scratch_
                    + ithr * rnn_.diff_wei_brgemm.Kpadded * m_layer_block_;

    gemm_acc_t *const amx_buffer = amx_scratchpad_
            + rnn_.diff_wei_brgemm.m_block * rnn_.diff_wei_brgemm.n_block
                    * ithr;
    const bool m_equal
            = rnn_.diff_wei_brgemm.M_iter == rnn_.diff_wei_brgemm.M_layer;
    amx_tile_configuration_loader_t load_cfg_if_needed;

    while (start < end) {
        const bool should_reorder_gates = last_n_block_id != n_block_id;
        const bool should_transpose_src
                = !global_transpose && last_m_block_id != m_block_id;

        const int m_iter = m_block_id * m_iter_block_;
        const int m_layer = m_block_id * m_layer_block_;
        const src_iter_t *const A_iter_m = global_transpose
                ? A_iter_transposed_ithr + m_iter * LDA_iter_
                : A_iter_ + m_iter;
        const src_layer_t *const A_layer_m = global_transpose
                ? A_layer_transposed_ithr + m_layer * LDA_layer_
                : A_layer_ + m_layer;

        src_iter_t *const A_iter_transposed = global_transpose
                ? const_cast<src_iter_t *>(A_iter_m)
                : A_iter_transposed_ithr;
        src_layer_t *const A_layer_transposed = global_transpose
                ? const_cast<src_layer_t *>(A_layer_m)
                : A_layer_transposed_ithr;

        const int n = n_block_id * rnn_.diff_wei_brgemm.n_block;
        const scratch_t *const B_n = B_ + n;
        const auto C_iter_offset = m_iter * LDC_iter_ + n;
        const auto C_layer_offset = m_layer * LDC_layer_ + n;
        gemm_acc_t *const C_diff_iter_n = C_iter_ + C_iter_offset;
        gemm_acc_t *const C_diff_layer_n = C_layer_ + C_layer_offset;

        const bool do_n_tail
                = (n + rnn_.diff_wei_brgemm.n_block) > rnn_.diff_wei_brgemm.N;

        const brgemm_kernel_t *kernel_iter = kernel_iter_full_blocks_;
        const brgemm_kernel_t *kernel_iter_k_tail = kernel_iter_k_tail_;
        const brgemm_kernel_t *kernel_layer = kernel_layer_full_blocks_;
        const brgemm_kernel_t *kernel_layer_k_tail = kernel_layer_k_tail_;
        const auto *kernel_reduction = kernel_gates_reduction_;

        const char *kernel_iter_config
                = rnn_brgemm_.diff_wei_.pallete_buff_iter_;
        const char *kernel_iter_k_tail_config
                = rnn_brgemm_.diff_wei_.pallete_buff_iter_k_tail_;
        const char *kernel_layer_config = m_equal
                ? rnn_brgemm_.diff_wei_.pallete_buff_iter_
                : rnn_brgemm_.diff_wei_.pallete_buff_layer_;
        const char *kernel_layer_k_tail_config = m_equal
                ? rnn_brgemm_.diff_wei_.pallete_buff_iter_k_tail_
                : rnn_brgemm_.diff_wei_.pallete_buff_layer_k_tail_;

        if (do_n_tail) {
            kernel_iter = kernel_iter_n_tail_;
            kernel_iter_k_tail = kernel_iter_nk_tail_;
            kernel_layer = kernel_layer_n_tail_;
            kernel_layer_k_tail = kernel_layer_nk_tail_;
            kernel_reduction = kernel_gates_reduction_tail_;

            kernel_iter_config
                    = rnn_brgemm_.diff_wei_.pallete_buff_iter_n_tail_;
            kernel_iter_k_tail_config
                    = rnn_brgemm_.diff_wei_.pallete_buff_iter_nk_tail_;
            kernel_layer_config = m_equal
                    ? rnn_brgemm_.diff_wei_.pallete_buff_iter_n_tail_
                    : rnn_brgemm_.diff_wei_.pallete_buff_layer_n_tail_;
            kernel_layer_k_tail_config = m_equal
                    ? rnn_brgemm_.diff_wei_.pallete_buff_iter_nk_tail_
                    : rnn_brgemm_.diff_wei_.pallete_buff_layer_nk_tail_;
        }

        if (should_reorder_gates) {
            reorder_scratch_gates(B_n, B_blocked, do_n_tail);

            if (m_block_id == 0) {
                jit_gates_reduction_t::call_params_t params;
                params.src = reinterpret_cast<const void *>(B_blocked);
                params.dst = reinterpret_cast<void *>(diff_bias_ + n);
                (*kernel_reduction)(&params);
            }
        }

        for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
            addr_batch[k_block_id].ptr.A
                    = A_iter_transposed + k_block_id * k_block_;
            addr_batch[k_block_id].ptr.B
                    = B_blocked + k_block_id * B_kb_offset_;
        }
        if (should_transpose_src) {
            jit_brgemm_transpose_single_row_t::call_params_t params;
            params.src = reinterpret_cast<const void *>(A_iter_m);
            params.dst = reinterpret_cast<void *>(A_iter_transposed);
            (*kernel_transpose_iter_)(&params);
        }

        load_cfg_if_needed(kernel_iter_config);
        brgemm_kernel_execute(kernel_iter, k_blocks_, addr_batch,
                reinterpret_cast<void *>(C_diff_iter_n), amx_buffer);

        for (int k_block_id = 0; k_block_id < k_blocks_; k_block_id++) {
            addr_batch[k_block_id].ptr.A
                    = A_layer_transposed + k_block_id * k_block_;
            addr_batch[k_block_id].ptr.B
                    = B_blocked + k_block_id * B_kb_offset_;
        }

        if (should_transpose_src) {
            jit_brgemm_transpose_single_row_t::call_params_t params;
            params.src = reinterpret_cast<const void *>(A_layer_m);
            params.dst = reinterpret_cast<void *>(A_layer_transposed);
            (*kernel_transpose_layer_)(&params);
        }

        load_cfg_if_needed(kernel_layer_config);
        brgemm_kernel_execute(kernel_layer, k_blocks_, addr_batch,
                reinterpret_cast<void *>(C_diff_layer_n), amx_buffer);

        if (k_tail_) {
            const auto B_blocked_k_tail = B_blocked + B_k_tail_offset_blocked_;

            addr_batch[0].ptr.A = A_iter_transposed + A_k_iter_tail_offset_;
            addr_batch[0].ptr.B = B_blocked_k_tail;

            load_cfg_if_needed(kernel_iter_k_tail_config);
            brgemm_kernel_execute(kernel_iter_k_tail, 1, addr_batch,
                    reinterpret_cast<void *>(C_diff_iter_n), amx_buffer);

            addr_batch[0].ptr.A = A_layer_transposed + A_k_layer_tail_offset_;
            addr_batch[0].ptr.B = B_blocked_k_tail;

            load_cfg_if_needed(kernel_layer_k_tail_config);
            brgemm_kernel_execute(kernel_layer_k_tail, 1, addr_batch,
                    reinterpret_cast<void *>(C_diff_layer_n), amx_buffer);
        }

        if (should_reorder_gates) { last_n_block_id = n_block_id; }
        if (should_transpose_src) { last_m_block_id = m_block_id; }

        ++start;
        switch (rnn_.diff_wei_brgemm.loop_order) {
            case brgemm_rnn_execute_loop_order_t::mblk_nblk:
                nd_iterator_step(
                        m_block_id, m_blocking_, n_block_id, n_blocking_);
                break;
            case brgemm_rnn_execute_loop_order_t::nblk_mblk:
                nd_iterator_step(
                        n_block_id, n_blocking_, m_block_id, m_blocking_);
                break;
            default: assert(!"unsupported loop order");
        }
    }
}

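// Computes the LSTM peephole weight gradients. Work is split over (gate, dhc
// block) pairs; gates 0 and 1 read the previous cell state (src_iter_c), gate
// 2 reads the updated cell state (dst_iter_c) and maps to scratch gate 3.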
template <typename scratch_t>
brgemm_diff_wei_peep_t<scratch_t>::brgemm_diff_wei_peep_t(
        const ref_rnn_brgemm_t &rnn_brgemm, const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position,
        const scratch_t *scratch_gates, const void *src_iter_c,
        const void *dst_iter_c, float *diff_weights_peephole)
    : rnn_(rnn)
    , scratch_gates_(scratch_gates)
    , src_iter_c_(src_iter_c)
    , dst_iter_c_(dst_iter_c)
    , diff_weights_peephole_(diff_weights_peephole)
    , work_amount_(n_gates_ * rnn_.dhc_blocks_peephole)
    , dst_iter_c_ld_(rnn.dst_iter_c_ld(cell_position))
    , src_iter_c_ld_(rnn.src_iter_c_ld(cell_position))
    , kernel_(rnn_brgemm.kernel_peephole_.get())
    , kernel_tail_(rnn_brgemm.kernel_peephole_tail_.get()) {}

template <typename scratch_t>
void brgemm_diff_wei_peep_t<scratch_t>::execute() const {
    parallel(rnn_.nthr, [this](const int ithr, const int nthr) {
        this->kernel(ithr, nthr);
    });
}

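// Per-thread loop over (gate, dhc block) pairs; each minibatch row is handled
// by the JIT peephole kernel, with the tail kernel used for the last, partial
// dhc block.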
template <typename scratch_t>
void brgemm_diff_wei_peep_t<scratch_t>::kernel(
        const int ithr, const int nthr) const {

    int start = 0, end = 0;
    balance211(work_amount_, nthr, ithr, start, end);

    int g = 0, dhc_block_id = 0;

    nd_iterator_init(
            start, g, n_gates_, dhc_block_id, rnn_.dhc_blocks_peephole);

    const auto dst_iter_c = rnn_utils::make_raw_aoc(dst_iter_c_,
            types::data_type_size(rnn_.dst_iter_c_dt),
            rnn_.ws_states_iter_c_nld, dst_iter_c_ld_);
    const auto src_iter_c = rnn_utils::make_raw_aoc(src_iter_c_,
            types::data_type_size(rnn_.src_iter_c_dt),
            rnn_.ws_states_iter_c_nld, src_iter_c_ld_);

    const rnn_utils::ws_gates_aoc<const scratch_t> scratch_gates(
            rnn_, scratch_gates_);
    const rnn_utils::weights_peephole_aoc_t<float> diff_weights_peephole(
            rnn_, diff_weights_peephole_);

    while (start < end) {
        const auto dhc = dhc_block_id * rnn_.dhc_block_peephole;
        const auto &c_states = g < 2 ? src_iter_c : dst_iter_c;
        const auto scratch_g = g == 2 ? 3 : g;
        const auto *const kernel = rnn_.dhc_tail_peephole
                        && dhc_block_id == rnn_.dhc_blocks_peephole - 1
                ? kernel_tail_
                : kernel_;

        jit_diff_weights_peephole_t::call_params_t params;

        for (int mb = 0; mb < rnn_.mb; ++mb) {
            params.c_states = c_states(mb, dhc);
            params.scratch_gates = &scratch_gates(mb, scratch_g, dhc);
            params.dst = &diff_weights_peephole(g, dhc);
            (*kernel)(&params);
        }

        ++start;
        nd_iterator_step(g, n_gates_, dhc_block_id, rnn_.dhc_blocks_peephole);
    }
}

template class brgemm_diff_src_layer_iter_t<float, float, float>;
template class brgemm_diff_src_layer_iter_t<bfloat16_t, bfloat16_t, float>;

template class brgemm_diff_weights_layer_iter_t<float, float, float, float>;
template class brgemm_diff_weights_layer_iter_t<bfloat16_t, bfloat16_t,
        bfloat16_t, float>;

template class brgemm_diff_wei_peep_t<bfloat16_t>;
template class brgemm_diff_wei_peep_t<float>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl