/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"
#include "cpu/scale_utils.hpp"

#include "cpu/x64/amx_tile_configure.hpp"
#include "cpu/x64/cpu_barrier.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/jit_brgemm_inner_product.hpp"
#include "cpu/x64/jit_transpose_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::cpu::x64::brgemm_inner_product_utils;
using namespace dnnl::impl::data_type;
using namespace dnnl::impl::format_tag;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::status;
using namespace dnnl::impl::utils;

using namespace nstl;

#define get_blk_off(d, dt, ...) \
    (types::data_type_size((dt)) * (d).blk_off(__VA_ARGS__))

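// Helper shared by the forward and backward-by-data drivers: runs the given
// jit_brgemm_copy_to_coarse_t kernel to copy 'os_work' rows from 'data' into
// the temporary buffer 'tr_data'; 'is_last_blk' marks the tail row block.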
namespace {
template <typename ker_type>
void copy_data_chunk(ker_type &ker, char *tr_data, const char *data,
        int os_work, bool is_last_blk) {
    auto ctx = jit_brgemm_copy_to_coarse_t::ctx_t();
    ctx.data = (void *)data;
    ctx.tr_data = (void *)tr_data;
    ctx.os_work = os_work;
    ctx.last_row_blk = is_last_blk ? 1 : 0;
    (*ker)(&ctx);
}
} // namespace

template <cpu_isa_t isa>
status_t brgemm_inner_product_fwd_t<isa>::execute_forward(
        const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(
                    pd()->attr()->post_ops_, ctx);

    memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jbgp = pd()->jbgp_;

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);

    const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
            src_scales, wei_scales, pd()->OC(), pd()->attr());

    const bool is_f32 = everyone_is(f32, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt);

    const size_t src_dt_size = types::data_type_size(jbgp.src_dt);
    const size_t bia_dt_size
            = jbgp.with_bias ? types::data_type_size(jbgp.bia_dt) : 0;
    const size_t acc_dt_size = types::data_type_size(jbgp.acc_dt);
    const size_t dst_dt_size = types::data_type_size(jbgp.dst_dt);

    auto addr_batch_global = scratchpad.template get<brgemm_batch_element_t>(
            key_brgemm_primitive_batch);
    auto a_buffer_global = (jbgp.use_buffer_a)
            ? scratchpad.template get<char>(key_brgemm_primitive_buffer_a)
            : nullptr;
    auto c_buffer_global = (jbgp.use_buffer)
            ? scratchpad.template get<char>(key_brgemm_primitive_buffer)
            : nullptr;
    const bool is_amx = jbgp.is_amx;
    auto wsp_tile_base = is_amx
            ? ctx.get_scratchpad_grantor().template get<char>(
                    key_conv_amx_tile_buffer)
            : nullptr;

    const int ic_chunks = div_up(jbgp.nb_ic, jbgp.nb_ic_blocking);

    const bool are_post_ops_applicable = one_of(true, jbgp.with_sum,
            jbgp.with_bias, jbgp.with_scales, jbgp.with_eltwise,
            jbgp.with_binary, jbgp.acc_dt != jbgp.dst_dt, jbgp.signed_input);

    size_t offset = types::data_type_size(jbgp.wei_dt)
            * (weights_d.size() - weights_d.additional_buffer_size());
    auto compensation = (jbgp.signed_input)
            ? reinterpret_cast<const int32_t *>(&weights[offset])
            : nullptr;

    bool is_os_tail = (jbgp.mb < jbgp.os_block);
    bool is_oc_tail = (jbgp.oc < jbgp.oc_block);
    int base_brg_ker_idx = brgemm_inner_product_utils::
            get_brg_kernel_index( // TODO: Can be calculated on initialization stage
                    jbgp, false, false, is_os_tail, is_oc_tail, false);

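    // Computes one (os block, oc block, ic chunk) piece of the destination:
    // optionally copies the source chunk into the A buffer, builds the brgemm
    // batch of A/B pointers, and runs the brgemm kernel, fusing post-ops on
    // the last ic chunk when the ic dimension is not split across threads.
    // The K tail, if any, is handled by a separate tail kernel.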
    const auto ker = [&](int ithr_oc_mb, int nthr_oc_mb, int ithr_ic, int n,
                             int ocb, int icc, bool do_init, int buffer_a_osb,
                             bool copy_buffer_a) {
        const int ithr = nthr_oc_mb * ithr_ic + ithr_oc_mb;
        auto addr_batch = addr_batch_global + ithr * jbgp.adjusted_batch_size;

        const size_t a_buffer_osb_stride
                = src_dt_size * jbgp.LDA * jbgp.os_block;
        const size_t a_buffer_per_thr
                = a_buffer_osb_stride * jbgp.nb_os_blocking;
        auto a_buffer = (jbgp.use_buffer_a)
                ? a_buffer_global + ithr * a_buffer_per_thr
                        + buffer_a_osb * a_buffer_osb_stride
                : nullptr;

        const int oc = ocb * jbgp.oc_block;
        const size_t dst_off = get_blk_off(dst_d, jbgp.dst_dt, n, oc);

        const bool use_c_buffer = (jbgp.with_sum)
                || (jbgp.use_buffer && (jbgp.nthr_ic_b == 1 || ithr_ic > 0));

        char *c_buffer = nullptr;
        if (use_c_buffer) {
            const size_t c_buf_thr_idx = jbgp.nthr_ic_b <= 1
                    ? ithr
                    : (jbgp.acc_dt != jbgp.dst_dt || jbgp.with_sum
                                    ? ithr_ic
                                    : ithr_ic - 1);
            const size_t c_buf_num_rows = jbgp.nthr_ic_b > 1 ? jbgp.mb : jbgp.M;
            const size_t c_buffer_shift
                    = c_buf_thr_idx * c_buf_num_rows * jbgp.LDC;
            const size_t c_buffer_off = acc_dt_size * c_buffer_shift
                    + (jbgp.nthr_ic_b > 1 ? acc_dt_size * dst_off / dst_dt_size
                                          : 0);
            c_buffer = c_buffer_global + c_buffer_off;
        }

        char *wsp_tile = is_amx
                ? wsp_tile_base + ithr * jbgp.amx_buf_size_per_thread
                : nullptr;
        int icb = icc * jbgp.nb_ic_blocking;
        int ic = icb * jbgp.ic_block;

        bool kernel_init = do_init;

        bool is_os_tail = (jbgp.mb - n < jbgp.os_block);
        bool is_oc_tail = (jbgp.oc - oc < jbgp.oc_block);
        bool is_last_ic_chunk = icc == ic_chunks - 1;
        bool is_ic_tail = is_last_ic_chunk && jbgp.K_tail > 0;
        const int remaining_ic_blks
                = (jbgp.use_buffer_a ? utils::rnd_up(jbgp.ic, jbgp.ic_block)
                                     : jbgp.ic)
                - ic;
        const int gemm_batch
                = nstl::min(jbgp.gemm_batch_size, remaining_ic_blks / jbgp.K);

        auto is_bs_tail = (gemm_batch != jbgp.gemm_batch_size);
        int brg_ker_idx = brgemm_inner_product_utils::get_brg_kernel_index(
                jbgp, is_bs_tail, kernel_init, is_os_tail, is_oc_tail, false);
        auto brg_kernel = brg_kernels_[brg_ker_idx].get();

        if (copy_buffer_a) {
            assert(!jbgp.is_bf32);
            auto src_ptr = src + get_blk_off(src_d, jbgp.src_dt, n, ic);
            copy_data_chunk(copy_src_kernel_, a_buffer, src_ptr,
                    is_os_tail ? jbgp.mb - n : jbgp.os_block, is_last_ic_chunk);
        }
        if (gemm_batch > 0 && brg_kernel != nullptr) {
            if (is_amx && (is_os_tail || is_oc_tail))
                amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);
            const int ic_blocks_per_batch = jbgp.K / jbgp.ic_block;
            for (int b = 0; b < gemm_batch; b++) {
                auto A_ptr = jbgp.use_buffer_a
                        ? (a_buffer + src_dt_size * b * jbgp.K)
                        : (src
                                + get_blk_off(src_d, jbgp.src_dt, n,
                                        ic + b * jbgp.K));
                addr_batch[b].ptr.A = A_ptr;
                addr_batch[b].ptr.B = weights
                        + get_blk_off(weights_d, jbgp.wei_dt, ocb,
                                icb + b * ic_blocks_per_batch);
            }

            auto ptr_D = dst + dst_off;
            auto ptr_C = use_c_buffer ? c_buffer : ptr_D;

            if (jbgp.nthr_ic_b == 1 && are_post_ops_applicable
                    && is_last_ic_chunk && !is_ic_tail) {
                void *scratch = is_amx
                        ? static_cast<void *>(wsp_tile)
                        : (jbgp.signed_input ? static_cast<void *>(
                                   const_cast<int *>(&compensation[oc]))
                                             : nullptr);
                auto ptr_bias
                        = jbgp.with_bias ? bias + bia_dt_size * oc : nullptr;
                const brgemm_post_ops_data_t post_ops_data {
                        static_cast<const void *>(ptr_bias),
                        &oscales[jbgp.is_oc_scale * oc],
                        post_ops_binary_rhs_arg_vec.data(),
                        static_cast<size_t>(oc), 0, dst};

                brgemm_kernel_execute_postops(brg_kernel, gemm_batch,
                        addr_batch, (void *)ptr_C, (void *)ptr_D, post_ops_data,
                        scratch);
            } else {
                brgemm_kernel_execute(brg_kernel, gemm_batch, addr_batch,
                        (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
            }

            if (is_amx && (is_os_tail || is_oc_tail))
                amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
        }

        if (is_ic_tail) {
            assert(!jbgp.use_buffer_a);
            int ic_block = gemm_batch * jbgp.K / jbgp.ic_block;
            addr_batch[0].ptr.A = src
                    + get_blk_off(src_d, jbgp.src_dt, n,
                            ic + ic_block * jbgp.ic_block);
            addr_batch[0].ptr.B = weights
                    + get_blk_off(weights_d, jbgp.wei_dt, ocb, icb + ic_block);

            auto use_init_ker = (kernel_init && gemm_batch == 0);
            int brg_ker_idx = brgemm_inner_product_utils::get_brg_kernel_index(
                    jbgp, false, use_init_ker, is_os_tail, is_oc_tail, true);
            auto brg_kernel_ic_tail = brg_kernels_[brg_ker_idx].get();
            if (is_amx)
                amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);
            auto ptr_D = dst + dst_off;
            auto ptr_C = use_c_buffer ? c_buffer : ptr_D;
            if (jbgp.nthr_ic_b == 1 && are_post_ops_applicable) {
                void *scratch = is_amx
                        ? static_cast<void *>(wsp_tile)
                        : (jbgp.signed_input ? static_cast<void *>(
                                   const_cast<int *>(&compensation[oc]))
                                             : nullptr);
                auto ptr_bias
                        = jbgp.with_bias ? bias + bia_dt_size * oc : nullptr;
                const brgemm_post_ops_data_t post_ops_data {
                        static_cast<const void *>(ptr_bias),
                        &oscales[jbgp.is_oc_scale * oc],
                        post_ops_binary_rhs_arg_vec.data(),
                        static_cast<size_t>(oc), 0, dst};

                brgemm_kernel_execute_postops(brg_kernel_ic_tail, 1, addr_batch,
                        (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch);
            } else {
                brgemm_kernel_execute(brg_kernel_ic_tail, 1, addr_batch,
                        (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
            }
            if (is_amx)
                amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
        }
    };

    const int os_chunks = div_up(jbgp.nb_os, jbgp.nb_os_blocking);
    const int oc_chunks = div_up(jbgp.nb_oc, jbgp.nb_oc_blocking);
    const int work_amount = oc_chunks * os_chunks;

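    // Splits 'nthr' threads into an ic-reduction group of size nthr_ic and an
    // oc/mb group of size nthr_oc_mb, computes this thread's coordinates in
    // that grid, and returns false for threads that have no work assigned.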
    const auto init_thr_groups
            = [&](const int ithr, const int nthr, int &nthr_ic, int &nthr_oc_mb,
                      int &ithr_ic, int &ithr_oc_mb) {
                  nthr_ic = jbgp.nthr_ic_b <= nthr ? jbgp.nthr_ic_b : 1;
                  nthr_oc_mb = nthr / nthr_ic;
                  ithr_ic = ithr / nthr_oc_mb;
                  ithr_oc_mb = ithr % nthr_oc_mb;
                  if (ithr_oc_mb >= work_amount || ithr_ic >= ic_chunks
                          || ithr >= rnd_dn(nthr, nthr_ic))
                      return false;
                  return true;
              };

    // If work_amount == 1 we limit num_threads to 1, as parallel(1, ...) does
    // not create a parallel section at all. We do not limit num_threads for
    // the 1 < work_amount < dnnl_get_max_threads() case to avoid the potential
    // overhead of spawning a different number of OMP threads from layer to
    // layer.
    const int num_threads = (work_amount == 1 ? 1 : jbgp.nthr);
    parallel(num_threads, [&](const int ithr, const int nthr) {
        int nthr_ic {1}, nthr_oc_mb {1}, ithr_ic {0}, ithr_oc_mb {0};
        bool ok = init_thr_groups(
                ithr, nthr, nthr_ic, nthr_oc_mb, ithr_ic, ithr_oc_mb);
        if (!ok) return;

        int start {0}, end {0};
        balance211(work_amount, nthr_oc_mb, ithr_oc_mb, start, end);

        int icc_start {0}, icc_end {ic_chunks};
        if (nthr_ic > 1)
            balance211(ic_chunks, nthr_ic, ithr_ic, icc_start, icc_end);

        const int icc_work = icc_end - icc_start;

        if (is_amx)
            amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);

        int occ {0}, osc {0};
        nd_iterator_init(start, osc, os_chunks, occ, oc_chunks);
        while (start < end) {
            int ocb_s = occ * jbgp.nb_oc_blocking;
            int ocb_e = nstl::min(ocb_s + jbgp.nb_oc_blocking, jbgp.nb_oc);
            int ocb_work = ocb_e - ocb_s;

            int osb_s = osc * jbgp.nb_os_blocking;
            int osb_e = nstl::min(osb_s + jbgp.nb_os_blocking, jbgp.nb_os);
            int osb_work = osb_e - osb_s;

            // Each thread runs the below loops:
            int loop_start = 0, loop_end = icc_work * osb_work * ocb_work;
            int icc = 0, osb = 0, ocb = 0;

            // If a buffer is required, the innermost loop is over icc_work.
            const bool ocb_inner_most
                    = is_f32 && !(jbgp.is_bf32 || jbgp.use_buffer);
            if (ocb_inner_most)
                nd_iterator_init(
                        0, icc, icc_work, osb, osb_work, ocb, ocb_work);
            else
                nd_iterator_init(
                        0, osb, osb_work, ocb, ocb_work, icc, icc_work);

            while (loop_start < loop_end) {
                const int n = (osb + osb_s) * jbgp.os_block;
                const int cur_icc = icc + icc_start;
                const bool copy_buffer_a = jbgp.use_buffer_a
                        && IMPLICATION(ocb_inner_most, ocb == 0);
                ker(ithr_oc_mb, nthr_oc_mb, ithr_ic, n, ocb + ocb_s, cur_icc,
                        cur_icc == icc_start, osb, copy_buffer_a);

                ++loop_start;
                if (ocb_inner_most)
                    nd_iterator_step(
                            icc, icc_work, osb, osb_work, ocb, ocb_work);
                else
                    nd_iterator_step(
                            osb, osb_work, ocb, ocb_work, icc, icc_work);
            }

            ++start;
            nd_iterator_step(osc, os_chunks, occ, oc_chunks);
        }
        if (is_amx) amx_tile_release();
    });

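    // When the ic dimension is split across threads, the partial products were
    // accumulated into per-thread buffers above; reduce them into the
    // destination (or into the sum buffer) and apply the post-ops once per
    // oc block.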
    if (jbgp.nthr_ic_b > 1) {
        assert(jbgp.use_buffer && is_f32);

        const auto get_dst_reduced_off = [&](int ithr_ic, int osb, int ocb) {
            assert(jbgp.nthr_ic_b > 1);
            int os = osb * jbgp.os_block;
            int oc = ocb * jbgp.oc_block;
            const size_t dst_off = get_blk_off(dst_d, jbgp.dst_dt, os, oc);
            if (ithr_ic == 0) return dst_off;
            assert(ithr_ic > 0);
            const size_t ic_buf_idx = jbgp.with_sum ? ithr_ic : ithr_ic - 1;
            return dst_off + (ic_buf_idx * jbgp.mb * jbgp.LDC * acc_dt_size);
        };

        parallel(num_threads, [&](const int ithr, const int nthr) {
            int nthr_ic {1}, nthr_oc_mb {1}, ithr_ic {0}, ithr_oc_mb {0};
            bool ok = init_thr_groups(
                    ithr, nthr, nthr_ic, nthr_oc_mb, ithr_ic, ithr_oc_mb);
            if (!ok) return;

            int ocmb_start {0}, ocmb_end {0};
            int start {0}, end {0};
            balance211(
                    work_amount, nthr_oc_mb, ithr_oc_mb, ocmb_start, ocmb_end);
            balance211(ocmb_end - ocmb_start, nthr_ic, ithr_ic, start, end);

            int occ {0}, osc {0};
            nd_iterator_init(
                    ocmb_start + start, osc, os_chunks, occ, oc_chunks);
            while (start < end) {
                int ocb_s = occ * jbgp.nb_oc_blocking;
                int ocb_e = nstl::min(ocb_s + jbgp.nb_oc_blocking, jbgp.nb_oc);

                int osb_s = osc * jbgp.nb_os_blocking;
                int osb_e = nstl::min(osb_s + jbgp.nb_os_blocking, jbgp.nb_os);

                for (int osb = osb_s; osb < osb_e; ++osb) {
                    int cur_os_block = nstl::min(
                            jbgp.os - osb * jbgp.os_block, jbgp.os_block);
                    const bool is_os_tail = cur_os_block < jbgp.os_block;
                    const int cur_oc_chunk_size
                            = nstl::min(jbgp.LDC, ocb_e * jbgp.oc_block)
                            - ocb_s * jbgp.oc_block;
                    char *dst_reduced = (jbgp.with_sum ? c_buffer_global : dst)
                            + get_dst_reduced_off(0, osb, ocb_s);
                    const size_t os_offset = jbgp.LDC * acc_dt_size;
                    for (int ic_buf = 0; ic_buf < nthr_ic - 1; ++ic_buf) {
                        const char *c_buffer = c_buffer_global
                                + get_dst_reduced_off(ic_buf + 1, osb, ocb_s);
                        for (int os = 0; os < cur_os_block; ++os) {
                            acc_ker_->accumulate(
                                    (float *)(dst_reduced + os * os_offset),
                                    (float *)(c_buffer + os * os_offset),
                                    cur_oc_chunk_size);
                        }
                    }
                    if (are_post_ops_applicable) {
                        for (int ocb = ocb_s; ocb < ocb_e; ++ocb) {
                            const bool is_oc_tail
                                    = (jbgp.oc - ocb * jbgp.oc_block
                                            < jbgp.oc_block);
                            const int brg_ker_idx = brgemm_inner_product_utils::
                                    get_brg_kernel_index(jbgp, false, false,
                                            is_os_tail, is_oc_tail, false);
                            const auto brg_kernel
                                    = brg_kernels_[brg_ker_idx].get();
                            const int os = osb * jbgp.os_block;
                            const int oc = ocb * jbgp.oc_block;
                            const auto ptr_bias = jbgp.with_bias
                                    ? bias + bia_dt_size * oc
                                    : nullptr;
                            auto ptr_D = dst
                                    + get_blk_off(dst_d, jbgp.dst_dt, os, oc);
                            auto ptr_C = (jbgp.with_sum ? c_buffer_global : dst)
                                    + get_dst_reduced_off(0, osb, ocb);

                            char *wsp_tile = is_amx
                                    ? wsp_tile_base
                                            + ithr * jbgp.amx_buf_size_per_thread
                                    : nullptr;

                            void *scratch = is_amx
                                    ? static_cast<void *>(wsp_tile)
                                    : (jbgp.signed_input ? static_cast<void *>(
                                               const_cast<int *>(
                                                       &compensation[oc]))
                                                         : nullptr);

                            const brgemm_post_ops_data_t post_ops_data {
                                    static_cast<const void *>(ptr_bias),
                                    &oscales[jbgp.is_oc_scale * oc],
                                    post_ops_binary_rhs_arg_vec.data(),
                                    static_cast<size_t>(oc), 0, dst, 0, nullptr,
                                    nullptr, nullptr, true /* skip_accm */};
                            brgemm_kernel_execute_postops(brg_kernel, 0,
                                    nullptr, (void *)ptr_C, (void *)ptr_D,
                                    post_ops_data, scratch);
                        }
                    }
                }
                ++start;
                nd_iterator_step(osc, os_chunks, occ, oc_chunks);
            }
        });
    }

    return status::success;
}

template struct brgemm_inner_product_fwd_t<avx2_vnni_2>;
template struct brgemm_inner_product_fwd_t<avx512_core>;
template struct brgemm_inner_product_fwd_t<avx512_core_bf16>;
template struct brgemm_inner_product_fwd_t<avx512_core_vnni>;
template struct brgemm_inner_product_fwd_t<avx512_core_amx>;
template struct brgemm_inner_product_fwd_t<avx512_core_fp16>;
template struct brgemm_inner_product_fwd_t<avx512_core_amx_fp16>;

template <cpu_isa_t isa>
void brgemm_inner_product_bwd_data_t<isa>::execute_backward_data(
        const exec_ctx_t &ctx) const {

    auto diff_dst_ = CTX_IN_MEM(const char *, DNNL_ARG_DIFF_DST);
    auto weights_ = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto diff_src_ = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_SRC);

    auto diff_src = const_cast<char *>(diff_src_);
    auto weights = const_cast<char *>(weights_);
    auto diff_dst = const_cast<char *>(diff_dst_);

    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jbgp = pd()->jbgp_;

    const bool is_f32 = everyone_is(f32, jbgp.src_dt, jbgp.wei_dt, jbgp.dst_dt);
    const bool is_bf16 = everyone_is(bf16, jbgp.wei_dt, jbgp.dst_dt);
    const bool is_f16 = everyone_is(f16, jbgp.wei_dt, jbgp.dst_dt);
    const bool is_f32_out = jbgp.src_dt == f32;
    const bool is_amx = jbgp.is_amx;
    const size_t buf_dt_size = types::data_type_size(
            isa == avx512_core_fp16 ? f32 : jbgp.wei_dt);

    const dim_t wei_dt_size = types::data_type_size(jbgp.wei_dt);

    memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();
    brgemm_batch_element_t *addr_batch_global
            = scratchpad.template get<brgemm_batch_element_t>(
                    key_brgemm_primitive_batch);
    char *c_buffer_global = (jbgp.use_buffer)
            ? scratchpad.template get<char>(key_brgemm_primitive_buffer)
            : nullptr;
    char *b_buffer_global = jbgp.use_buffer_b
            ? scratchpad.template get<char>(key_brgemm_primitive_buffer_b)
            : nullptr;
    char *a_buffer_global = jbgp.use_buffer_a
            ? scratchpad.template get<char>(key_brgemm_primitive_buffer_a)
            : nullptr;
    auto wsp_tile_base = is_amx
            ? ctx.get_scratchpad_grantor().template get<char>(
                    key_conv_amx_tile_buffer)
            : nullptr;

    const int oc_chunks = div_up(jbgp.nb_oc, jbgp.nb_oc_blocking);
    bool is_os_tail = (jbgp.mb < jbgp.os_block);
    bool is_ic_tail = (jbgp.ic < jbgp.ic_block);
    bool is_oc_tail = (jbgp.oc < jbgp.oc_block) && !jbgp.use_buffer_a;

    const dim_t acc_dt_sz = types::data_type_size(jbgp.acc_dt);
    const dim_t src_dt_sz = types::data_type_size(jbgp.src_dt);

    const int base_brg_ker_idx = brgemm_inner_product_utils::
            get_brg_kernel_index( // TODO: Can be calculated on initialization stage
                    jbgp, false, false, is_os_tail, is_ic_tail, is_oc_tail);

    const int os_chunks = div_up(jbgp.nb_os, jbgp.nb_os_blocking);
    const int work_amount = jbgp.nb_ic * os_chunks;
    const int num_threads = (work_amount == 1 ? 1 : jbgp.nthr);

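    // Weights are kept in the forward blocked layout, so map the backward-data
    // (icb, ocb) block coordinates onto the forward ic/oc block sizes (derived
    // from the weights tag) and return a pointer to the requested sub-block.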
    const auto get_weights_ptr = [&](int icb, int ocb) {
        int fwd_ic_block
                = (is_amx && !jbgp.is_bf32) ? 2 * jbgp.simd_w : jbgp.simd_w;
        int fwd_oc_block = 0;
        switch (jbgp.wei_tag) {
            case OI16i64o:
            case OIw16i64o:
            case OIhw16i64o:
            case OIdhw16i64o:
            case OI8i64o2i:
            case OIw8i64o2i:
            case OIhw8i64o2i:
            case OIdhw8i64o2i:
            case OI16i64o2i:
            case OIw16i64o2i:
            case OIhw16i64o2i:
            case OIdhw16i64o2i: fwd_oc_block = 4 * jbgp.simd_w; break;
            case OI16i32o:
            case OIw16i32o:
            case OIhw16i32o:
            case OIdhw16i32o:
            case OI8i32o2i:
            case OIw8i32o2i:
            case OIhw8i32o2i:
            case OIdhw8i32o2i:
            case OI16i32o2i:
            case OIw16i32o2i:
            case OIhw16i32o2i:
            case OIdhw16i32o2i: fwd_oc_block = 2 * jbgp.simd_w; break;
            default: fwd_oc_block = jbgp.simd_w;
        };
        int fwd_icb = icb * jbgp.ic_block / fwd_ic_block;
        int fwd_ocb = ocb * jbgp.oc_block / fwd_oc_block;
        char *ptr_wei_local = weights
                + get_blk_off(weights_d, jbgp.wei_dt, fwd_ocb, fwd_icb);

        int fwd_ocb_simd = (ocb * jbgp.oc_block) % fwd_oc_block;
        int fwd_icb_simd = (icb * jbgp.ic_block) % fwd_ic_block;
        int blk_sz = is_bf16 || (is_f16 && isa != avx512_core_fp16) ? 2 : 1;

        return ptr_wei_local
                + wei_dt_size
                * (fwd_icb_simd / blk_sz * blk_sz * fwd_oc_block
                        + blk_sz * fwd_ocb_simd);
    };

    const auto transform_b_chunk
            = [&](char *tr_wei, const char *wei, int trans_batch, int current_N,
                      int current_K) {
                  auto ctx = jit_brgemm_trans_wei_t::ctx_t();
                  ctx.src = (void *)wei;
                  ctx.tr_src = (void *)tr_wei;
                  ctx.current_gemm_batch = trans_batch;
                  ctx.current_N = current_N;
                  ctx.current_K = current_K;
                  (*trans_B_kernel_)(&ctx);
              };

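    // Computes one (os block, ic block, oc chunk) piece of diff_src:
    // optionally copies diff_dst into the A buffer, transposes the required
    // weights blocks into the B buffer (unless they were transposed globally
    // up front), and runs the brgemm kernel, writing either directly to
    // diff_src or into an accumulation buffer that is reduced later.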
    const auto ker = [&](int ithr_ic_mb, int nthr_ic_mb, int ithr_oc,
                             int nthr_oc, int n, int icb, int occ, bool do_init,
                             bool do_b_transpose) {
        const int ithr = nthr_ic_mb * ithr_oc + ithr_ic_mb;
        brgemm_batch_element_t *addr_batch
                = addr_batch_global + ithr * jbgp.adjusted_batch_size;

        const int ic = icb * jbgp.ic_block;
        const int ocb = occ * jbgp.nb_oc_blocking;
        const int oc = ocb * jbgp.oc_block;
        const size_t dsrc_off = get_blk_off(diff_src_d, jbgp.src_dt, n, ic);
        const int adj_buffers = (jbgp.src_dt == f32) ? 1 : 0;
        const size_t c_buf_shift = jbgp.nthr_oc_b > 1
                ? (ithr_oc - adj_buffers)
                        * static_cast<size_t>(jbgp.mb * jbgp.LDC)
                : ithr * static_cast<size_t>(jbgp.LDC * jbgp.M);
        const size_t c_buf_off
                = types::data_type_size(jbgp.acc_dt) * c_buf_shift
                + (jbgp.nthr_oc_b > 1 ? acc_dt_sz * dsrc_off / src_dt_sz : 0);
        bool use_c_buf = false;
        if (is_f32_out && jbgp.use_buffer) {
            use_c_buf = (jbgp.nthr_oc_b == 1 || ithr_oc > 0);
        } else if (!is_f32_out && jbgp.use_buffer) {
            if (jbgp.nthr_oc_b > 1)
                use_c_buf = true;
            else
                use_c_buf = (jbgp.nthr_oc_b == 1 || ithr_oc > 0);
        }

        const size_t a_buffer_size_per_thr
                = jbgp.os_block * jbgp.LDA * types::data_type_size(jbgp.dst_dt);
        char *c_buffer = use_c_buf ? c_buffer_global + c_buf_off : nullptr;
        char *a_buffer = jbgp.use_buffer_a
                ? a_buffer_global + ithr * a_buffer_size_per_thr
                : diff_dst;
        char *wsp_tile = is_amx
                ? wsp_tile_base + ithr * jbgp.amx_buf_size_per_thread
                : nullptr;

        bool kernel_init = do_init;

        const bool is_os_tail = (jbgp.mb - n < jbgp.os_block);
        const bool is_ic_tail = (jbgp.ic - ic < jbgp.ic_block);
        const bool is_last_oc_chunk = occ == oc_chunks - 1;
        const bool is_oc_tail = is_last_oc_chunk && jbgp.K_tail > 0;

        const int rnd_oc
                = rnd_up(jbgp.oc, jbgp.use_buffer_a ? jbgp.oc_block : 1);
        const int nb_oc_b
                = nstl::min((rnd_oc - oc) / jbgp.oc_block, jbgp.nb_oc_blocking);

        auto is_bs_tail = (nb_oc_b != jbgp.nb_oc_blocking);
        const int brg_ker_idx
                = brgemm_inner_product_utils::get_brg_kernel_index(jbgp,
                        is_bs_tail, kernel_init, is_os_tail, is_ic_tail, false);
        auto brg_kernel = brg_kernels_[brg_ker_idx].get();

        const int size_B = jbgp.LDB * rnd_up(jbgp.K, 2);

        const size_t b_buf_shift = jbgp.ip_bwd_d_global_b_transpose
                ? icb * jbgp.nb_oc + ocb
                : ithr * jbgp.gemm_batch_size;
        const size_t b_buf_off = buf_dt_size * b_buf_shift * size_B;
        char *b_buffer = b_buffer_global + b_buf_off;

        char *ptr_D = diff_src + dsrc_off;
        char *ptr_C = use_c_buf ? c_buffer : ptr_D;

        if (jbgp.use_buffer_a)
            copy_data_chunk(copy_diff_dst_kernel_, a_buffer,
                    diff_dst + get_blk_off(diff_dst_d, jbgp.dst_dt, n, oc),
                    is_os_tail ? jbgp.os - n : jbgp.os_block, is_last_oc_chunk);

        if (nb_oc_b > 0 && brg_kernel != nullptr) {
            if (is_amx && (is_os_tail || is_ic_tail))
                amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);

            for (int oc_block = 0; oc_block < nb_oc_b; oc_block++) {
                addr_batch[oc_block].ptr.A = jbgp.use_buffer_a
                        ? a_buffer
                                + oc_block * jbgp.oc_block
                                        * types::data_type_size(jbgp.dst_dt)
                        : diff_dst
                                + get_blk_off(diff_dst_d, jbgp.dst_dt, n,
                                        oc + oc_block * jbgp.oc_block);
                addr_batch[oc_block].ptr.B
                        = b_buffer + buf_dt_size * (oc_block * size_B);
                if (!jbgp.ip_bwd_d_global_b_transpose && do_b_transpose)
                    transform_b_chunk((char *)addr_batch[oc_block].ptr.B,
                            get_weights_ptr(icb, ocb + oc_block), 1,
                            is_ic_tail ? jbgp.ic % jbgp.ic_block
                                       : jbgp.ic_block,
                            jbgp.oc_block);
            }

            if (jbgp.use_buffer && (jbgp.nthr_oc_b <= 1 || num_threads == 1)
                    && is_last_oc_chunk && !is_oc_tail) {
                void *scratch
                        = is_amx ? static_cast<void *>(wsp_tile) : nullptr;
                const brgemm_post_ops_data_t empty_po_data {};
                brgemm_kernel_execute_postops(brg_kernel, nb_oc_b, addr_batch,
                        (void *)c_buffer, (void *)ptr_D, empty_po_data,
                        scratch);

            } else {
                brgemm_kernel_execute(brg_kernel, nb_oc_b, addr_batch,
                        (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
            }
            if (is_amx && (is_os_tail || is_ic_tail))
                amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
        }
        if (is_oc_tail) {
            assert(!jbgp.use_buffer_a);

            const int oc_block = nb_oc_b;
            addr_batch[0].ptr.A = diff_dst
                    + get_blk_off(diff_dst_d, jbgp.dst_dt, n,
                            oc + oc_block * jbgp.oc_block);
            addr_batch[0].ptr.B = b_buffer + buf_dt_size * (oc_block * size_B);
            if (!jbgp.ip_bwd_d_global_b_transpose && do_b_transpose) {
                transform_b_chunk((char *)addr_batch[0].ptr.B,
                        get_weights_ptr(icb, ocb + oc_block), 1,
                        is_ic_tail ? jbgp.ic % jbgp.ic_block : jbgp.ic_block,
                        jbgp.K_tail);
            }

            auto use_init_ker = (kernel_init && nb_oc_b == 0);
            const int brg_kernel_oc_tail_idx
                    = brgemm_inner_product_utils::get_brg_kernel_index(jbgp,
                            false, use_init_ker, is_os_tail, is_ic_tail, true);
            auto brg_kernel_oc_tail
                    = brg_kernels_[brg_kernel_oc_tail_idx].get();
            if (is_amx)
                amx_tile_configure(
                        &brg_kernel_palettes_[brg_kernel_oc_tail_idx][0]);
            if (jbgp.use_buffer && jbgp.nthr_oc_b <= 1) {
                void *scratch
                        = is_amx ? static_cast<void *>(wsp_tile) : nullptr;
                const brgemm_post_ops_data_t empty_po_data {};
                brgemm_kernel_execute_postops(brg_kernel_oc_tail, 1, addr_batch,
                        (void *)c_buffer, (void *)ptr_D, empty_po_data,
                        scratch);

            } else {
                brgemm_kernel_execute(brg_kernel_oc_tail, 1, addr_batch,
                        (void *)ptr_C, is_amx ? (void *)wsp_tile : nullptr);
            }
            if (is_amx)
                amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);
        }
    };

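    // Optional up-front pass: transpose all weights blocks into the B buffer
    // in parallel over (icb, ocb) chunks so that the brgemm calls below can
    // reuse them without per-call transposes.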
    if (jbgp.ip_bwd_d_global_b_transpose && jbgp.use_buffer_b) {
        assert(IMPLICATION(
                jbgp.ip_bwd_d_global_b_transpose, jbgp.nthr_oc_b == 1));
        parallel(num_threads, [&](const int ithr, const int nthr) {
            int start {0}, end {0};
            int max_ch_block = nstl::max(jbgp.ic_block, jbgp.oc_block);
            int ic_chunk_sz = max_ch_block / jbgp.ic_block;
            int oc_chunk_sz = max_ch_block / jbgp.oc_block;
            int nc_ic = utils::div_up(jbgp.nb_ic, ic_chunk_sz);
            int nc_oc = utils::div_up(jbgp.nb_oc, oc_chunk_sz);
            int transp_work_amount = nc_ic * nc_oc;
            balance211(transp_work_amount, nthr, ithr, start, end);
            int icc, occ;
            nd_iterator_init(start, icc, nc_ic, occ, nc_oc);
            while (start < end) {
                int icb_start = icc * ic_chunk_sz;
                int icb_end = nstl::min((icc + 1) * ic_chunk_sz, jbgp.nb_ic);
                int ocb_start = occ * oc_chunk_sz;
                int ocb_end = nstl::min((occ + 1) * oc_chunk_sz, jbgp.nb_oc);
                for_(int icb = icb_start; icb < icb_end; icb++)
                for (int ocb = ocb_start; ocb < ocb_end; ocb++) {
                    int ic = icb * jbgp.ic_block;
                    int oc = ocb * jbgp.oc_block;
                    bool is_ic_tail = (jbgp.ic - ic < jbgp.ic_block);
                    bool is_oc_tail = (jbgp.oc - oc < jbgp.oc_block);
                    const int size_B = jbgp.LDB * rnd_up(jbgp.K, 2);
                    char *b_buffer = b_buffer_global
                            + buf_dt_size
                                    * ((dim_t)icb * jbgp.nb_oc * size_B
                                            + (dim_t)ocb * size_B);

                    transform_b_chunk(b_buffer, get_weights_ptr(icb, ocb), 1,
                            is_ic_tail ? jbgp.ic % jbgp.ic_block
                                       : jbgp.ic_block,
                            is_oc_tail ? jbgp.oc % jbgp.oc_block
                                       : jbgp.oc_block);
                }
                ++start;
                nd_iterator_step(icc, nc_ic, occ, nc_oc);
            }
        });
    }

    parallel(num_threads, [&](const int ithr, const int nthr) {
        const int nthr_oc = jbgp.nthr_oc_b <= nthr ? jbgp.nthr_oc_b : 1;
        const int nthr_ic_mb = nthr / nthr_oc;
        const int ithr_ic_mb = ithr % nthr_ic_mb;
        const int ithr_oc = ithr / nthr_ic_mb;
        if (ithr_ic_mb >= work_amount || ithr_oc >= oc_chunks
                || ithr >= rnd_dn(nthr, nthr_oc))
            return;

        int start {0}, end {0};
        balance211(work_amount, nthr_ic_mb, ithr_ic_mb, start, end);
        int occ_start {0}, occ_end {oc_chunks};
        if (nthr_oc > 1)
            balance211(oc_chunks, nthr_oc, ithr_oc, occ_start, occ_end);

        if (is_amx)
            amx_tile_configure(&brg_kernel_palettes_[base_brg_ker_idx][0]);

        int icb {0}, oss {0};
        nd_iterator_init(start, oss, os_chunks, icb, jbgp.nb_ic);
        while (start < end) {
            const int nb_os_blocking
                    = nstl::min(jbgp.nb_os - oss * jbgp.nb_os_blocking,
                            jbgp.nb_os_blocking);
            const int occ_work = occ_end - occ_start;
            const int loop_iteration = nb_os_blocking * occ_work;

            for (int iter = 0; iter < loop_iteration; ++iter) {
                int osb = 0, occ = occ_start;
                if (jbgp.use_buffer || !is_f32) {
                    osb += iter / occ_work;
                    occ += iter % occ_work;
                } else {
                    occ += iter / nb_os_blocking;
                    osb += iter % nb_os_blocking;
                }
                int n = (oss * jbgp.nb_os_blocking + osb) * jbgp.os_block;
                ker(ithr_ic_mb, nthr_ic_mb, ithr_oc, nthr_oc, n, icb, occ,
                        occ == occ_start, osb == 0 || occ_work > 1);
            }
            ++start;
            nd_iterator_step(oss, os_chunks, icb, jbgp.nb_ic);
        }
        if (is_amx) amx_tile_release();
    });

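    // When the oc dimension is split across threads, reduce the per-thread
    // partial buffers into diff_src and, for non-f32 outputs, convert the
    // accumulated f32 values to the destination data type.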
    if (jbgp.nthr_oc_b > 1) {
        parallel(num_threads, [&](const int ithr, const int nthr) {
            const int nthr_oc = jbgp.nthr_oc_b <= nthr ? jbgp.nthr_oc_b : 1;
            if (nthr_oc <= 1) return;

            const int ddst_elems = jbgp.LDC * jbgp.os;
            const int reduce_chunk_size = 64;
            int start {0}, end {0};
            balance211(div_up(ddst_elems, reduce_chunk_size), nthr, ithr, start,
                    end);
            const dim_t reduce_start = start * reduce_chunk_size;
            const dim_t reduce_finish
                    = nstl::min(end * reduce_chunk_size, ddst_elems);
            if (reduce_finish <= reduce_start) return;
            const dim_t elems_to_reduce = reduce_finish - reduce_start;
            const dim_t acc_dt_sz = types::data_type_size(jbgp.acc_dt);

            char *dsrc_reduced = diff_src + src_dt_sz * reduce_start;
            char *c_buffer_start = c_buffer_global + acc_dt_sz * reduce_start;

            float *out_buffer = is_f32_out
                    ? reinterpret_cast<float *>(dsrc_reduced)
                    : reinterpret_cast<float *>(c_buffer_start);
            int oc_buf_idx = !is_f32_out;
            int oc_buf_end = is_f32_out;
            for (int oc_buf = oc_buf_idx; oc_buf < nthr_oc - oc_buf_end;
                    oc_buf++) {
                const dim_t c_buf_offt = acc_dt_sz
                        * (oc_buf * jbgp.os * jbgp.LDC + reduce_start);
                char *c_buffer = c_buffer_global + c_buf_offt;

                acc_ker_->accumulate((float *)out_buffer, (float *)c_buffer,
                        elems_to_reduce);
                if (!is_f32_out && oc_buf == (nthr_oc - oc_buf_end) - 1) {
                    if (is_bf16) {
                        cvt_float_to_bfloat16((bfloat16_t *)dsrc_reduced,
                                (const float *)out_buffer, elems_to_reduce);
                    } else if (is_f16) {
                        cvt_float_to_float16((float16_t *)dsrc_reduced,
                                (const float *)out_buffer, elems_to_reduce);
                    }
                }
            }
        });
    }
}

template struct brgemm_inner_product_bwd_data_t<avx512_core>;
template struct brgemm_inner_product_bwd_data_t<avx512_core_amx>;
template struct brgemm_inner_product_bwd_data_t<avx512_core_bf16>;
template struct brgemm_inner_product_bwd_data_t<avx512_core_amx_fp16>;
template struct brgemm_inner_product_bwd_data_t<avx512_core_fp16>;

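// Per-thread execution context for backward-by-weights: scratchpad pointers,
// the thread decomposition over ic/oc/os (reduction) chunks, and helpers that
// locate this thread's portions of the A and B transposition buffers.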
template <cpu_isa_t isa>
struct brgemm_inner_product_bwd_weights_t<isa>::thread_info_t {
    const char *src;
    const char *diff_dst;
    char *diff_weights;
    char *diff_bias;

    const memory_tracking::grantor_t scratchpad;

    char *buffer_c = nullptr;
    char *buffer_bias = nullptr;
    char *wsp_tile_base = nullptr;

    int ithr;
    int ithr_ic_c, ithr_oc_c, ithr_os_c;
    int nthr;
    int nthr_ic_c, nthr_oc_c, nthr_os_c;

    int os_c_start = 0, os_c_end = 0, os_c_work;
    int oc_c_start = 0, oc_c_end = 0, oc_c_work;
    int ic_c_start = 0, ic_c_end = 0, ic_c_work;
    simple_barrier::ctx_t *barrier_ctx;

    thread_info_t(const brgemm_inner_product_bwd_weights_t *self,
            const exec_ctx_t &ctx, int ithr)
        : scratchpad(ctx.get_scratchpad_grantor()), ithr(ithr) {

        src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
        diff_dst = CTX_IN_MEM(const char *, DNNL_ARG_DIFF_DST);
        diff_weights = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_WEIGHTS);
        diff_bias = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_BIAS);
        const auto &jbgp = self->pd()->jbgp_;

        const bool is_amx = jbgp.is_amx;

        buffer_c = (jbgp.use_buffer)
                ? scratchpad.template get<char>(key_brgemm_primitive_buffer)
                : nullptr;

        buffer_bias = (jbgp.with_bias
                              && (jbgp.bia_dt != data_type::f32
                                      || jbgp.nthr_mb > 1))
                ? scratchpad.template get<char>(key_iprod_bias_bf16_convert_wsp)
                : nullptr;

        buffer_a_
                = scratchpad.template get<char>(key_brgemm_primitive_buffer_a);
        buffer_b_ = jbgp.use_buffer_b
                ? scratchpad.template get<char>(key_brgemm_primitive_buffer_b)
                : nullptr;

        thread_local_input_buffers_
                = jbgp.ip_bwd_w_local_buffers_for_input_tensors;
        int ic_chunks = utils::div_up(jbgp.nb_ic, jbgp.nb_ic_blocking);
        int os_chunks = utils::div_up(jbgp.nb_os, jbgp.nb_os_blocking);
        nb_ic_blocking_ = jbgp.nb_ic_blocking;
        nb_oc_blocking_ = jbgp.nb_oc_blocking;
        const size_t os_chunks_per_thr = utils::div_up(os_chunks, jbgp.nthr_mb);
        const size_t num_os_chunks_per_thread
                = thread_local_input_buffers_ ? 1 : os_chunks_per_thr;

        if (jbgp.use_buffer_a) {
            const size_t dt_sz = buf_dt_size(jbgp.src_dt, jbgp.isa);
            const size_t ic_chunks_per_thr
                    = utils::div_up(ic_chunks, jbgp.nthr_ic_b);
            const size_t num_ic_chunks_per_thread
                    = thread_local_input_buffers_ ? 1 : ic_chunks_per_thr;
            const size_t block_A_size = dt_sz * jbgp.LDA * jbgp.M;
            const size_t os_chunk_A_buffer
                    = jbgp.gemm_batch_size * block_A_size;
            const size_t ic_os_chunk_A_buffer
                    = jbgp.nb_ic_blocking * os_chunk_A_buffer;

            buffer_a_icb_shift_ = os_chunk_A_buffer;
            buffer_a_osb_shift_ = block_A_size;
            buffer_a_osc_shift_ = thread_local_input_buffers_
                    ? 0
                    : ic_chunks_per_thr * ic_os_chunk_A_buffer;
            const size_t buffer_a_thread_shift = num_ic_chunks_per_thread
                    * num_os_chunks_per_thread * ic_os_chunk_A_buffer;

            buffer_a_ = buffer_a_ + ithr * buffer_a_thread_shift;
        }

        if (jbgp.use_buffer_b) {
            const auto buf_dt = jbgp.dst_dt == f16 && isa == avx512_core_fp16
                    ? data_type::f32
                    : jbgp.dst_dt;
            const size_t dt_sz = buf_dt_size(jbgp.dst_dt, jbgp.isa);
            assert(types::data_type_size(buf_dt) == dt_sz);
            const size_t block_B_size = dt_sz * jbgp.LDB * jbgp.K;
            const size_t os_chunk_B_buffer
                    = jbgp.gemm_batch_size * block_B_size;
            buffer_b_ocb_shift_ = dt_sz * jbgp.oc_block
                    * data_type_vnni_granularity(buf_dt);
            buffer_b_osb_shift_ = block_B_size;
            buffer_b_osc_shift_
                    = thread_local_input_buffers_ ? 0 : os_chunk_B_buffer;

            const size_t buffer_b_thread_shift
                    = num_os_chunks_per_thread * os_chunk_B_buffer;

            buffer_b_ = buffer_b_ + ithr * buffer_b_thread_shift;
        }

        wsp_tile_base = is_amx
                ? ctx.get_scratchpad_grantor().template get<char>(
                        key_conv_amx_tile_buffer)
                : nullptr;

        nthr = jbgp.nthr;
        nthr_ic_c = jbgp.nthr_ic_b;
        nthr_oc_c = jbgp.nthr_oc_b;
        nthr_os_c = jbgp.nthr_mb;

        ithr_ic_c = ithr % nthr_ic_c;
        ithr_oc_c = ithr / nthr_ic_c % nthr_oc_c;
        ithr_os_c = ithr / nthr_ic_c / nthr_oc_c;

        int oc_chunks = utils::div_up(jbgp.nb_oc, jbgp.nb_oc_blocking);

        /* reduction dimension */
        balance211(os_chunks, nthr_os_c, ithr_os_c, os_c_start, os_c_end);
        os_c_work = os_c_end - os_c_start;

        balance211(oc_chunks, nthr_oc_c, ithr_oc_c, oc_c_start, oc_c_end);
        oc_c_work = oc_c_end - oc_c_start;

        balance211(ic_chunks, nthr_ic_c, ithr_ic_c, ic_c_start, ic_c_end);
        ic_c_work = ic_c_end - ic_c_start;

        if (dnnl_thr_syncable())
            barrier_ctx = scratchpad.template get<simple_barrier::ctx_t>(
                    key_conv_wei_bia_reduction_bctx);
    }

    char *get_buffer_a_ptr(int icb, int osc) const {
        if (!buffer_a_) return (char *)nullptr;

        const int icb_idx = thread_local_input_buffers_
                ? icb % nb_ic_blocking_
                : icb - ic_c_start * nb_ic_blocking_;
        const int osc_idx = thread_local_input_buffers_ ? 0 : osc - os_c_start;

        return buffer_a_ + osc_idx * buffer_a_osc_shift_
                + icb_idx * buffer_a_icb_shift_;
    }

    char *get_buffer_b_ptr(int ocb, int osc) const {
        if (!buffer_b_) return (char *)nullptr;

        const int ocb_idx = ocb % nb_oc_blocking_;
        const int osc_idx = thread_local_input_buffers_ ? 0 : osc - os_c_start;

        return buffer_b_ + osc_idx * buffer_b_osc_shift_
                + ocb_idx * buffer_b_ocb_shift_;
    }

    size_t get_buffer_a_osb_shift() const { return buffer_a_osb_shift_; }
    size_t get_buffer_b_osb_shift() const { return buffer_b_osb_shift_; }

private:
    char *buffer_a_ = nullptr;
    char *buffer_b_ = nullptr;

    bool thread_local_input_buffers_ = false;
    int nb_ic_blocking_ = 1;
    int nb_oc_blocking_ = 1;
    size_t buffer_a_icb_shift_ = 0;
    size_t buffer_a_osc_shift_ = 0;
    size_t buffer_a_osb_shift_ = 0;
    size_t buffer_b_ocb_shift_ = 0;
    size_t buffer_b_osc_shift_ = 0;
    size_t buffer_b_osb_shift_ = 0;
};

template <cpu_isa_t isa>
void brgemm_inner_product_bwd_weights_t<isa>::transform_matrix_a_chunk(
        char *tr_src, const char *src, int trans_batch, int current_m,
        int current_k) const {
    auto ctx = jit_brgemm_trans_src_t::ctx_t();
    ctx.src = (void *)src;
    ctx.tr_src = (void *)tr_src;
    ctx.current_gemm_batch = trans_batch;
    ctx.current_M = current_m;
    ctx.current_K = current_k;
    (*trans_A_kernel_)(&ctx);
}

template <cpu_isa_t isa>
void brgemm_inner_product_bwd_weights_t<isa>::transform_matrix_b_chunk(
        char *tr_diff_dst, const char *diff_dst, int trans_batch,
        int current_col_size, int current_row_size) const {
    auto ctx = jit_brgemm_trans_to_vnni_t::ctx_t();
    ctx.src = (void *)diff_dst;
    ctx.tr_src = (void *)tr_diff_dst;
    ctx.current_gemm_batch = trans_batch;
    ctx.current_col_size = current_col_size;
    ctx.current_row_size = current_row_size;
    (*trans_B_kernel_)(&ctx);
}

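// Converts one accumulated ic_block x oc_block chunk into the diff_weights
// layout, using either the AMX-specific transform kernel or the to-VNNI
// transpose kernel depending on the configuration.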
template <cpu_isa_t isa>
void brgemm_inner_product_bwd_weights_t<isa>::transpose_matrix_c_chunk(
        const thread_info_t *ti, const int ocb, const int icb, int oc_size,
        int ic_size, bool is_reduction) const {
    const auto &jbgp = pd()->jbgp_;

    if (jbgp.is_amx) {
        auto p = jit_amx_ip_trans_diff_wei::ctx_t();

        const dim_t ext_nb_ic = div_up(jbgp.ic, ext_ic_block_);
        dim_t icb_shift = (icb * (jbgp.ic_block / ext_ic_block_))
                * ext_ic_block_ * ext_oc_block_;

        dim_t ocb_shift = (ocb * (jbgp.oc_block / ext_oc_block_)) * ext_nb_ic
                * ext_ic_block_ * ext_oc_block_;
        dim_t out_offset = ocb_shift + icb_shift;

        p.src = get_wei_acc_ptr(ti, ocb, icb, 0);
        p.dst = (void *)(ti->diff_weights
                + types::data_type_size(jbgp.wei_dt) * out_offset);

        p.last_ic_block = (jbgp.ic <= ext_ic_block_
                                  || (jbgp.nb_ic > 1 && icb == jbgp.nb_ic - 1))
                ? 1
                : 0;
        p.last_oc_block = (jbgp.oc <= ext_oc_block_
                                  || (jbgp.nb_oc > 1 && ocb == jbgp.nb_oc - 1))
                ? 1
                : 0;
        (*diff_wei_trans_kernel_)(&p);
    } else {
        auto ctx = jit_brgemm_trans_to_vnni_t::ctx_t();
        ctx.src = (void *)(get_wei_acc_ptr(ti, ocb, icb, 0));

        ctx.tr_src = (void *)(ti->diff_weights
                + types::data_type_size(jbgp.wei_dt)
                        * get_wei_offset(ocb, icb));

        ctx.current_gemm_batch = 1;
        ctx.current_col_size = oc_size;
        ctx.current_row_size = ic_size;
        (*trans_C_kernel_)(&ctx);
    }
}

template <cpu_isa_t isa>
dim_t brgemm_inner_product_bwd_weights_t<isa>::get_wei_offset(
        int ocb, int icb) const {
    const auto &jbgp = pd()->jbgp_;
    if (jbgp.is_amx) {
        const dim_t offset
                = jbgp.kd * jbgp.kh * jbgp.kw * jbgp.ic_block * jbgp.oc_block;
        return (ocb * jbgp.nb_ic + icb) * offset;
    } else {
        const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
        return diff_weights_d.blk_off(ocb, icb);
    }
}

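// Returns the buffer where the (ocb, icb) weights block is accumulated:
// directly in diff_weights when no reduction or data type conversion is
// needed, otherwise a slot in the scratchpad reduction buffer selected by
// reduction_buf_idx (or by ti->ithr_os_c when reduction_buf_idx < 0).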
template <cpu_isa_t isa>
char *brgemm_inner_product_bwd_weights_t<isa>::get_wei_acc_ptr(
        const thread_info_t *ti, int ocb, int icb,
        int reduction_buf_idx) const {
    const auto &jbgp = pd()->jbgp_;

    const int reduction_buf_start_idx = jbgp.wei_dt == f32;
    // The reduction_buf_idx argument allows manually selecting the required
    // reduction buffer index, which is needed for the reduction and
    // diff_weights transform parts. It is -1 by default; if
    // reduction_buf_idx < 0, then ti->ithr_os_c is used to compute the
    // current reduction index.
    const int buf_idx = reduction_buf_idx >= 0
            ? reduction_buf_idx
            : (ti->ithr_os_c - reduction_buf_start_idx);
    const size_t acc_dt_size = types::data_type_size(jbgp.acc_dt);

    if ((jbgp.nthr_mb > 1 && buf_idx < 0)
            || (jbgp.wei_dt == jbgp.acc_dt && reduction_buf_idx < 0
                    && ti->ithr_os_c == 0)) {
        MAYBE_UNUSED(reduction_buf_idx);
        const int icb_scale = (!jbgp.is_amx || jbgp.wei_dt == jbgp.acc_dt)
                ? jbgp.ic_block / jbgp.simd_w
                : 1;
        const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
        return (char *)ti->diff_weights
                + get_blk_off(
                        diff_weights_d, jbgp.wei_dt, ocb, icb * icb_scale);
    }

    if (!jbgp.use_buffer) return nullptr;

    const int ocb_l = ocb % jbgp.nb_oc_blocking;
    const int icb_l = icb % jbgp.nb_ic_blocking;

    if (jbgp.nthr_mb > 1 || jbgp.harness == harness_mb_reduction) {
        const size_t icc = icb / jbgp.nb_ic_blocking;
        const size_t occ = ocb / jbgp.nb_oc_blocking;
        const size_t num_ic_chunks = div_up(jbgp.nb_ic, jbgp.nb_ic_blocking);
        const size_t num_oc_chunks = div_up(jbgp.nb_oc, jbgp.nb_oc_blocking);
        const size_t block_size = acc_dt_size * jbgp.ic_block * jbgp.oc_block;
        const size_t chunk_size
                = block_size * jbgp.nb_ic_blocking * jbgp.nb_oc_blocking;
        const size_t reduction_buf_shift
                = num_ic_chunks * num_oc_chunks * chunk_size * buf_idx;
        return ti->buffer_c + reduction_buf_shift
                + (occ * num_ic_chunks + icc) * chunk_size
                + (ocb_l * jbgp.nb_ic_blocking + icb_l) * block_size;
    } else if (jbgp.nthr_mb == 1) {
        MAYBE_UNUSED(reduction_buf_idx);
        const size_t blk_size = acc_dt_size * jbgp.ic_block * jbgp.oc_block;
        const size_t buf_size_per_thread
                = blk_size * jbgp.nb_ic_blocking * jbgp.nb_oc_blocking;
        const size_t offset_within_thread_buf
                = blk_size * (jbgp.nb_ic_blocking * ocb_l + icb_l);
        const size_t offset
                = ti->ithr * buf_size_per_thread + offset_within_thread_buf;
        return ti->buffer_c + offset;
    }

    assert(!"unsupported case");
    return nullptr;
};

template <cpu_isa_t isa>
void brgemm_inner_product_bwd_weights_t<isa>::compute_diff_weights_and_bias(
        const thread_info_t *ti) const {
    auto diff_dst = const_cast<char *>(ti->diff_dst);
    auto diff_bias = ti->diff_bias;

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());

    const auto &jbgp = pd()->jbgp_;

    const size_t bia_dt_size
            = jbgp.with_bias ? types::data_type_size(jbgp.bia_dt) : 0;
    const size_t acc_dt_size = types::data_type_size(jbgp.acc_dt);

    const int oc_chunk_sz = jbgp.oc_block * jbgp.nb_oc_blocking;

    brgemm_batch_element_t *addr_batch_global
            = ti->scratchpad.template get<brgemm_batch_element_t>(
                    key_brgemm_primitive_batch);

    const bool is_bf16 = jbgp.wei_dt == bf16;
    const bool is_amx_bf16 = is_bf16 && isa == avx512_core_amx;
    char *wsp_tile_global = (is_amx_bf16) ? ti->wsp_tile_base : nullptr;
    int os_chunks = utils::div_up(jbgp.nb_os, jbgp.nb_os_blocking);

    const auto get_bia_acc_ptr = [&](int oc) {
        const int reduction_buf_start_idx = jbgp.bia_dt == f32;
        if (jbgp.bia_dt != data_type::f32
                || (jbgp.nthr_mb > 1
                        && ti->ithr_os_c >= reduction_buf_start_idx)) {
            return ti->buffer_bias
                    + acc_dt_size * (ti->ithr_os_c - reduction_buf_start_idx)
                    * jbgp.oc
                    + acc_dt_size * oc;
        } else {
            return ti->diff_bias + bia_dt_size * oc;
        }
    };

    const auto a_buf_osb_shift = ti->get_buffer_a_osb_shift();
    const auto b_buf_osb_shift = ti->get_buffer_b_osb_shift();

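    // Computes the diff_weights contribution of one (osc, icb, ocb) block
    // (plus diff_bias when icb == 0): transforms the required src/diff_dst
    // chunks into the A/B buffers if needed, runs the brgemm kernel over the
    // os blocks and the os tail, and, when no cross-thread os reduction is
    // needed, converts the accumulated block into the diff_weights layout on
    // the last os chunk.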
    const auto ker = [&](const int osc, const int icb, const int ocb,
                             const int osc_prev, const int icc_prev,
                             const int occ_prev) {
        brgemm_batch_element_t *addr_batch
                = addr_batch_global + ti->ithr * jbgp.adjusted_batch_size;
        char *wsp_tile = is_amx_bf16
                ? wsp_tile_global + ti->ithr * jbgp.amx_buf_size_per_thread
                : nullptr;
        int ic = icb * jbgp.ic_block;
        int oc = ocb * jbgp.oc_block;
        int osb = osc * jbgp.nb_os_blocking;
        int n = osb * jbgp.os_block;

        // Base buffer pointers for the current kernel iteration; the
        // a_buf_osb_shift / b_buf_osb_shift values are used to shift them
        // with respect to the osb loop variable.
        char *a_buffer = ti->get_buffer_a_ptr(icb, osc);
        char *b_buffer = ti->get_buffer_b_ptr(ocb, osc);

        bool kernel_init = (osc == ti->os_c_start);

        bool is_os_tail = jbgp.mb - n < jbgp.os_block * jbgp.nb_os_blocking;
        bool is_ic_tail = jbgp.ic - ic < jbgp.ic_block;
        bool is_oc_tail = jbgp.oc - oc < jbgp.oc_block;
        const int oc_chunk_tail = jbgp.oc % oc_chunk_sz;
        const bool is_last_oc_chunk = jbgp.oc - oc < oc_chunk_sz;
        const int curr_oc_chunk_sz = oc_chunk_tail > 0 && is_last_oc_chunk
                ? oc_chunk_tail
                : oc_chunk_sz;

        const bool transform_weights = jbgp.wei_dt != jbgp.acc_dt
                && (jbgp.nthr_mb == 1 || os_chunks == 1)
                && osc == (os_chunks - 1);
        const bool transform_b = jbgp.ip_bwd_w_local_buffers_for_input_tensors
                ? jbgp.use_buffer_b && icb % jbgp.nb_ic_blocking == 0
                        && ocb % jbgp.nb_oc_blocking == 0
                        && IMPLICATION(osc_prev == osc,
                                occ_prev != ocb / jbgp.nb_oc_blocking)
                : jbgp.use_buffer_b
                        && icb == ti->ic_c_start * jbgp.nb_ic_blocking
                        && ocb % jbgp.nb_oc_blocking == 0;
        const bool transform_a = jbgp.ip_bwd_w_local_buffers_for_input_tensors
                ? jbgp.use_buffer_a && ocb % jbgp.nb_oc_blocking == 0
                        && IMPLICATION(osc_prev == osc,
                                icc_prev != icb / jbgp.nb_ic_blocking)
                : jbgp.use_buffer_a
                        && ocb == ti->oc_c_start * jbgp.nb_oc_blocking;

        auto nb_os_b = is_os_tail ? (jbgp.mb - n) / jbgp.os_block
                                  : jbgp.nb_os_blocking;

        auto is_bs_tail = (nb_os_b != jbgp.nb_os_blocking);
        const int brg_ker_idx
                = brgemm_inner_product_utils::get_brg_kernel_index(jbgp,
                        is_bs_tail, kernel_init, is_ic_tail, is_oc_tail, false);
        auto brg_kernel = brg_kernels_[brg_ker_idx].get();

        if (kernel_init && (is_ic_tail || is_oc_tail))
            utils::array_set(get_wei_acc_ptr(ti, ocb, icb), 0,
                    types::data_type_size(jbgp.acc_dt) * jbgp.ic_block
                            * jbgp.oc_block);
        if (nb_os_b > 0 && brg_kernel != nullptr) {
            if (jbgp.is_amx)
                amx_tile_configure(&brg_kernel_palettes_[brg_ker_idx][0]);
            if (transform_a) {
                const memory_desc_wrapper src_d(pd()->src_md());
                auto src_ptr = ti->src
                        + types::data_type_size(jbgp.src_dt)
                                * src_d.blk_off(n, ic);

                transform_matrix_a_chunk(a_buffer, src_ptr, nb_os_b,
                        is_ic_tail ? jbgp.ic % jbgp.ic_block : jbgp.ic_block,
                        jbgp.os_block);
            }

            if (transform_b) {
                auto diff_dst_ptr = diff_dst
                        + types::data_type_size(jbgp.dst_dt)
                                * diff_dst_d.blk_off(n, oc);
                transform_matrix_b_chunk(b_buffer, diff_dst_ptr, nb_os_b,
                        curr_oc_chunk_sz, jbgp.os_block);
            }

            for (int os_block = 0; os_block < nb_os_b; os_block++) {
                auto a_ptr = a_buffer + os_block * a_buf_osb_shift;
                addr_batch[os_block].ptr.A = a_ptr;
                auto diff_dst_ptr = diff_dst
                        + types::data_type_size(jbgp.dst_dt)
                                * diff_dst_d.blk_off(
                                        n + os_block * jbgp.os_block, oc);
                if (jbgp.use_buffer_b) {
                    auto b_ptr = b_buffer + os_block * b_buf_osb_shift;
                    addr_batch[os_block].ptr.B = b_ptr;
                } else {
                    addr_batch[os_block].ptr.B = diff_dst_ptr;
                }
                if (jbgp.with_bias && icb == 0) {
                    brgemm_kernel_diff_bias_t p;
                    auto bias_ptr = diff_bias + bia_dt_size * oc;
                    p.ptr_diff_dst = (void *)addr_batch[os_block].ptr.B;
                    p.ptr_diff_bias_acc = (void *)get_bia_acc_ptr(oc);
                    p.ptr_diff_bias = (void *)bias_ptr;
                    bool is_first = kernel_init && os_block == 0;
                    bool is_last = (jbgp.nthr_mb == 1 || os_chunks == 1)
                            && osc == os_chunks - 1 && os_block == nb_os_b - 1
                            && !is_os_tail;
                    p.flags = 0 | (is_first ? FLAG_REDUCE_FIRST : 0)
                            | (is_last ? FLAG_REDUCE_LAST : 0);

                    (*kernels_db_[false][is_oc_tail])(&p);
                }
            }
            brgemm_kernel_execute(brg_kernel, nb_os_b, addr_batch,
                    (void *)get_wei_acc_ptr(ti, ocb, icb), wsp_tile);
        }

        if (is_os_tail) {
            int os_block = nb_os_b;
            auto a_ptr = a_buffer + os_block * a_buf_osb_shift;

            if (transform_a) {
                const memory_desc_wrapper src_d(pd()->src_md());
                auto src_ptr = ti->src
                        + types::data_type_size(jbgp.src_dt)
                                * src_d.blk_off(
                                        n + os_block * jbgp.os_block, ic);
                transform_matrix_a_chunk(a_ptr, src_ptr, 1,
                        is_ic_tail ? jbgp.ic % jbgp.ic_block : jbgp.ic_block,
                        jbgp.mb % jbgp.os_block);
            }

            addr_batch[0].ptr.A = a_ptr;
            auto diff_dst_ptr = diff_dst
                    + types::data_type_size(jbgp.dst_dt)
                            * diff_dst_d.blk_off(
                                    n + os_block * jbgp.os_block, oc);
            if (jbgp.use_buffer_b) {
                auto b_ptr = b_buffer + os_block * b_buf_osb_shift;

                if (transform_b)
                    transform_matrix_b_chunk(b_ptr, diff_dst_ptr, 1,
                            curr_oc_chunk_sz, jbgp.mb % jbgp.os_block);
                addr_batch[0].ptr.B = b_ptr;
            } else {
                addr_batch[0].ptr.B = diff_dst_ptr;
            }

            if (jbgp.with_bias && icb == 0) {
                brgemm_kernel_diff_bias_t p;
                auto bias_ptr = diff_bias + bia_dt_size * oc;
                p.ptr_diff_dst = (void *)addr_batch[0].ptr.B;
                p.ptr_diff_bias_acc = (void *)get_bia_acc_ptr(oc);
                p.ptr_diff_bias = (void *)bias_ptr;
                bool is_first = kernel_init && os_block == 0;
                bool is_last = (jbgp.nthr_mb == 1 || os_chunks == 1)
                        && osc == os_chunks - 1;
                p.flags = 0 | (is_first ? FLAG_REDUCE_FIRST : 0)
                        | (is_last ? FLAG_REDUCE_LAST : 0);

                (*kernels_db_[true][is_oc_tail])(&p);
            }

            auto use_init_ker = (kernel_init && nb_os_b == 0);
            const int brg_ker_idx_os_tail
                    = brgemm_inner_product_utils::get_brg_kernel_index(jbgp,
                            false, use_init_ker, is_ic_tail, is_oc_tail, true);
            auto brg_kernel_os_tail = brg_kernels_[brg_ker_idx_os_tail].get();
            if (brg_kernel_os_tail != nullptr) {
                if (jbgp.is_amx)
                    amx_tile_configure(
                            &brg_kernel_palettes_[brg_ker_idx_os_tail][0]);
                brgemm_kernel_execute(brg_kernel_os_tail, 1, addr_batch,
                        (void *)get_wei_acc_ptr(ti, ocb, icb), wsp_tile);
            }
        }

        if (transform_weights) {
            transpose_matrix_c_chunk(ti, ocb, icb,
                    is_oc_tail ? jbgp.oc % jbgp.oc_block : jbgp.oc_block,
                    is_ic_tail ? jbgp.ic % jbgp.ic_block : jbgp.ic_block);
        }
    };

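    // Choose the loop order over the assigned (osc, icc, occ) chunk ranges
    // based on whether thread-local input buffers are used and on the harness
    // type, then iterate and call the kernel for each (ocb, icb) block pair.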
    const auto occ_work = (ti->oc_c_end - ti->oc_c_start);
    const auto icc_work = (ti->ic_c_end - ti->ic_c_start);
    const auto osc_work = (ti->os_c_end - ti->os_c_start);

    auto loop_idx = 0;
    const auto loop_end = occ_work * icc_work * osc_work;

    int occ_idx = 0, icc_idx = 0, osc_idx = 0;
    loop_order_t loop_order = jbgp.ip_bwd_w_local_buffers_for_input_tensors
            ? loop_order_t::osc_icc_occ
            : jbgp.harness == harness_mb_reduction ? loop_order_t::osc_occ_icc
                                                   : loop_order_t::occ_icc_osc;

    switch (loop_order) {
        case loop_order_t::osc_icc_occ:
            nd_iterator_init(loop_idx, osc_idx, osc_work, icc_idx, icc_work,
                    occ_idx, occ_work);
            break;
        case loop_order_t::osc_occ_icc:
            nd_iterator_init(loop_idx, osc_idx, osc_work, occ_idx, occ_work,
                    icc_idx, icc_work);
            break;
        case loop_order_t::occ_icc_osc:
            nd_iterator_init(loop_idx, occ_idx, occ_work, icc_idx, icc_work,
                    osc_idx, osc_work);
    };

    int osc_prev = -1, icc_prev = -1, occ_prev = -1;
    while (loop_idx < loop_end) {
        const int occ = ti->oc_c_start + occ_idx;
        const int icc = ti->ic_c_start + icc_idx;
        const int osc = ti->os_c_start + osc_idx;

        const int ocb_work = nstl::min(
                jbgp.nb_oc_blocking, jbgp.nb_oc - occ * jbgp.nb_oc_blocking);
        const int icb_work = nstl::min(
                jbgp.nb_ic_blocking, jbgp.nb_ic - icc * jbgp.nb_ic_blocking);

        for_(int ocb = 0; ocb < ocb_work; ocb++)
        for (int icb = 0; icb < icb_work; icb++) {
            ker(osc, icc * jbgp.nb_ic_blocking + icb,
                    occ * jbgp.nb_oc_blocking + ocb, osc_prev, icc_prev,
                    occ_prev);
        }
        osc_prev = osc;
        icc_prev = icc;
        occ_prev = occ;

        ++loop_idx;

        switch (loop_order) {
            case loop_order_t::osc_icc_occ:
                nd_iterator_step(osc_idx, osc_work, icc_idx, icc_work, occ_idx,
                        occ_work);
                break;
            case loop_order_t::osc_occ_icc:
                nd_iterator_step(osc_idx, osc_work, occ_idx, occ_work, icc_idx,
                        icc_work);
                break;
            case loop_order_t::occ_icc_osc:
                nd_iterator_step(occ_idx, occ_work, icc_idx, icc_work, osc_idx,
                        osc_work);
        };
    }
    if (jbgp.is_amx) amx_tile_release();
}

template <cpu_isa_t isa>
void brgemm_inner_product_bwd_weights_t<
        isa>::reduce_and_convert_diff_weights_and_bias(const thread_info_t *ti)
        const {
    const auto &jbgp = pd()->jbgp_;

    if (dnnl_thr_syncable() && jbgp.nthr > 1)
        simple_barrier::barrier(ti->barrier_ctx, jbgp.nthr);
    if (ti->nthr_os_c == 1) return;

    const bool is_f32_out = jbgp.wei_dt == data_type::f32;
    const int icb_scale = is_f32_out ? jbgp.ic_block / jbgp.simd_w : 1;

    const int icb_work = ti->ic_c_work * jbgp.nb_ic_blocking;
    const int ocb_work = ti->oc_c_work * jbgp.nb_oc_blocking;
    const int work = ocb_work * icb_work;

    int os_chunks = utils::div_up(jbgp.nb_os, jbgp.nb_os_blocking);
    int reduce_buffers = nstl::min(ti->nthr_os_c, os_chunks);
    int reduce_buf_idx_start = !is_f32_out;
    int reduce_buf_idx_end = reduce_buffers - is_f32_out;

    int start = 0, end = 0;
    balance211(work, ti->nthr_os_c, ti->ithr_os_c, start, end);
    if (start == end) return;

    int icb_l = 0, ocb_l = 0;
    const int acc_size = jbgp.ic_block * jbgp.oc_block;

    for (int ir = reduce_buf_idx_start; ir < reduce_buf_idx_end; ++ir) {
        int counter = start;
        nd_iterator_init(start, ocb_l, ocb_work, icb_l, icb_work);
        while (counter < end) {
            const int ocb = ti->oc_c_start * jbgp.nb_oc_blocking + ocb_l;
            const int icb = ti->ic_c_start * jbgp.nb_ic_blocking + icb_l;
            char *wei_to_reduce = get_wei_acc_ptr(ti, ocb, icb, ir);
            const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
            char *wei_reduced = !is_f32_out ? get_wei_acc_ptr(ti, ocb, icb, 0)
                                            : ti->diff_weights
                            + get_blk_off(diff_weights_d, jbgp.wei_dt, ocb,
                                    icb * icb_scale);
            acc_ker_->accumulate(
                    (float *)(wei_reduced), (float *)(wei_to_reduce), acc_size);
            if (!is_f32_out && ir + 1 == reduce_buf_idx_end) {
                transpose_matrix_c_chunk(ti, ocb, icb * icb_scale,
                        jbgp.oc_block, jbgp.ic_block, true);
            }
            ++counter;
            nd_iterator_step(ocb_l, ocb_work, icb_l, icb_work);
        }
    }

    if (jbgp.with_bias && ti->ithr_ic_c == 0 && ti->ic_c_work > 0
            && ti->ithr_os_c == 0 && ti->os_c_work > 0 && ti->oc_c_work > 0) {
        const bool is_f32_bias = jbgp.bia_dt == data_type::f32;
        float *bias_reduced = is_f32_bias ? (float *)ti->diff_bias
                                          : (float *)ti->buffer_bias;
        int reduce_buf_idx_start = !is_f32_bias;
        int reduce_buf_idx_end = reduce_buffers - 1;
        int oc_chunk_size = jbgp.nb_oc_blocking * jbgp.oc_block;
        int oc = ti->oc_c_start * oc_chunk_size;
        int acc_size = nstl::min(ti->oc_c_work * oc_chunk_size, jbgp.oc - oc);

        int ir = reduce_buf_idx_start;
        for (; ir < reduce_buf_idx_end; ++ir) {
            float *bias_to_reduce = (float *)ti->buffer_bias + ir * jbgp.oc;
            acc_ker_->accumulate(
                    &bias_reduced[oc], &bias_to_reduce[oc], acc_size);
        }

        if (!is_f32_bias) {
            float *bias_to_reduce = (float *)ti->buffer_bias + ir * jbgp.oc;
            switch (jbgp.bia_dt) {
                case data_type::bf16:
                    add_floats_and_cvt_to_bfloat16(
                            (bfloat16_t *)(ti->diff_bias) + oc,
                            &bias_reduced[oc], &bias_to_reduce[oc], acc_size);
                    break;
                case data_type::f16:
                    add_floats_and_cvt_to_float16(
                            (float16_t *)(ti->diff_bias) + oc,
                            &bias_reduced[oc], &bias_to_reduce[oc], acc_size);
                    break;
                default: assert(!"invalid data type");
            }
        }
    }
}

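// Top-level backward-by-weights driver: initializes the reduction barrier when
// threads can synchronize, computes per-thread partial diff_weights/diff_bias,
// and then reduces and converts them, either in the same parallel region (with
// a barrier) or in a second parallel region otherwise.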
template <cpu_isa_t isa>
void brgemm_inner_product_bwd_weights_t<isa>::execute_backward_weights(
        const exec_ctx_t &ctx) const {
    const auto &jbgp = pd()->jbgp_;

    if (dnnl_thr_syncable() && jbgp.nthr > 1) {
        auto scratchpad = ctx.get_scratchpad_grantor();
        simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_wei_bia_reduction_bctx));
    }

    parallel(jbgp.nthr, [&](const int ithr, const int nthr) {
        thread_info_t thread_info(this, ctx, ithr);
        compute_diff_weights_and_bias(&thread_info);

        if (dnnl_thr_syncable()) {
            reduce_and_convert_diff_weights_and_bias(&thread_info);
        }
    });

    if (!dnnl_thr_syncable()) {
        parallel(jbgp.nthr, [&](const int ithr, const int nthr) {
            thread_info_t thread_info(this, ctx, ithr);
            reduce_and_convert_diff_weights_and_bias(&thread_info);
        });
    }
}

template struct brgemm_inner_product_bwd_weights_t<avx512_core_amx_fp16>;
template struct brgemm_inner_product_bwd_weights_t<avx512_core_fp16>;
template struct brgemm_inner_product_bwd_weights_t<avx512_core_amx>;
template struct brgemm_inner_product_bwd_weights_t<avx512_core_bf16>;
template struct brgemm_inner_product_bwd_weights_t<avx512_core>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s