1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/dnnl_thread.hpp"
19#include "common/type_helpers.hpp"
20#include "common/utils.hpp"
21
22#include "cpu/cpu_primitive.hpp"
23#include "cpu/scale_utils.hpp"
24
25#include "cpu/x64/jit_avx512_core_amx_conv_utils.hpp"
26#include "cpu/x64/jit_avx512_core_amx_convolution.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31namespace x64 {
32
33using namespace dnnl::impl::status;
34using namespace dnnl::impl::memory_tracking::names;
35using namespace dnnl::impl::utils;
36
37using namespace nstl;
38
39#define wht_blk_off(d, g, ...) \
40 (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
41 : (d).blk_off(__VA_ARGS__))
42
43#define mem_blk_off(mdw, n, c, d, h, w) \
44 (pd()->ndims() == 3 \
45 ? (mdw).blk_off((n), (c), (w)) \
46 : (pd()->ndims() == 4 \
47 ? (mdw).blk_off((n), (c), (h), (w)) \
48 : (mdw).blk_off((n), (c), (d), (h), (w))))
49
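// Helper used when indexing the compressed zero-point padding buffer: for a
// row/chunk start 'ub' it returns min(ub, lv) + max(0, ub - uv), clamped to
// 'ub', i.e. the number of padded rows (top region of size 'lv' plus rows at
// or past the bottom-padding start 'uv') that precede 'ub'.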
50template <typename T>
51static inline T accum_with_upper_bound(T ub, T lv, T uv) {
52 return nstl::min(ub, nstl::min(ub, lv) + nstl::max(0, ub - uv));
53}
54
55void jit_avx512_core_amx_convolution_fwd_t::prepare_padded_bias(
56 const char *&bias, const memory_tracking::grantor_t &scratchpad) const {
57 if (!pd()->wants_padded_bias()) return;
58
59 const size_t bia_dt_size = pd()->jcp_.typesize_bia;
60 auto padded_bias = scratchpad.template get<char>(
61 memory_tracking::names::key_conv_padded_bias);
62 utils::array_copy(
63 padded_bias, bias, bia_dt_size * pd()->jcp_.oc_without_padding);
64 utils::array_set(padded_bias + bia_dt_size * pd()->jcp_.oc_without_padding,
65 0.f, bia_dt_size * (pd()->jcp_.oc - pd()->jcp_.oc_without_padding));
66 bias = padded_bias;
67}
68
69status_t
70jit_avx512_core_amx_convolution_fwd_t::execute_forward_reduced_lowering(
71 const exec_ctx_t &ctx) const {
72 const auto &jcp = pd()->jcp_;
73 const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
74 const auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
75 auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
76 auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
77 const auto post_ops_binary_rhs_arg_vec
78 = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
79
80 DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
81 DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
82
83 const memory_desc_wrapper src_d(pd()->src_md());
84 const memory_desc_wrapper dst_d(pd()->dst_md());
85 const memory_desc_wrapper weights_d(pd()->weights_md(0));
86 const memory_desc_wrapper bias_d(pd()->weights_md(1));
87
88 const size_t src_dt_size = types::data_type_size(src_d.data_type());
89 const size_t wei_dt_size = types::data_type_size(weights_d.data_type());
90 const size_t bia_dt_size
91 = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
92 const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
93
94 prepare_padded_bias(bias, ctx.get_scratchpad_grantor());
95
96 assert(jcp.is_relo);
97 assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
98
99 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
100 DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
101 DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
102
103 const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
104 src_scales, wei_scales, pd()->OC(), pd()->attr());
105
106 auto inp_p_buffer = ctx.get_scratchpad_grantor().template get<char>(
107 key_conv_amx_inp_buffer); // TODO: fix the get<> template argument
108 auto wei_buffer = ctx.get_scratchpad_grantor().template get<char>(
109 key_conv_amx_wei_buffer); // TODO: fix the get<> template argument
110 auto wsp = ctx.get_scratchpad_grantor().template get<int32_t>(
111 key_conv_amx_wsp_buffer);
112 auto tcfg = ctx.get_scratchpad_grantor().template get<char>(
113 key_conv_amx_tilecfg);
114 auto zero_point_pbuff = ctx.get_scratchpad_grantor().template get<int32_t>(
115 key_conv_zero_point_pad);
116 auto zp_flags_ = ctx.get_scratchpad_grantor().template get<bool>(
117 key_conv_zero_point_flag);
118
119 const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
120 char *w = const_cast<char *>(weights);
121 int32_t *zp_compensation = jcp.src_zero_point
122 ? reinterpret_cast<int32_t *>(w + offset)
123 : nullptr;
124
125 const int t_pad_output = jcp.t_pad_output;
126 const int b_pad_output = jcp.b_pad_output;
127 const int b_pad_start = nstl::max(jcp.oh - b_pad_output, t_pad_output);
128 const int zp_buff_b_pad_start
129 = nstl::max(jcp.oh_pad - b_pad_output, t_pad_output);
130
131 const int ngroups = jcp.ngroups;
132 const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
133 const int oh_chunks = utils::div_up(jcp.oh, jcp.oh_blk_size);
134 const int work_amount
135 = jcp.mb * jcp.ngroups * oh_chunks * jcp.nb_ow * oc_chunks;
136 const int zp_pbuff_size = jcp.zp_pbuff_size;
137
138 // reorder weights from (g)Owhi16o to (g)OR16r16o4r, where r := whi
139 auto p = jit_conv_call_s();
140 p.src = weights;
141 p.dst = wei_buffer;
142 kernel_->copy_to_wbuffer()(&p);
143 const char *wei = wei_buffer;
144
145 const size_t oc_subblock_step
146 = jcp.kh * jcp.kw * jcp.ic_block_int_np * jcp.oc_block;
147 const size_t wei_oc_shift = (size_t)jcp.nb_oc_blocking * jcp.nb_ic_int
148 * rnd_up(oc_subblock_step, jcp.ic_block_int * jcp.oc_block);
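// wei_oc_shift is the element stride between consecutive oc chunks in the
// re-packed weights buffer; the rnd_up pads each oc block out to whole
// ic_block_int x oc_block tiles.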
149
150 // Initialize the tile configuration in memory, so that each thread can
151 // load this configuration from memory via `amx_tile_configure(tcfg)`.
152 kernel_->tile_configure(tcfg);
153
154 // init zero_point padding buffer
155 const bool req_zero_point_buffer = jcp.req_zero_point_buffer;
156 const bool zp_pbuff_outer_compute = jcp.zp_pbuff_outer_compute;
157 const bool zp_pbuff_parallel_block
158 = req_zero_point_buffer && !zp_pbuff_outer_compute;
159 if (req_zero_point_buffer && zp_pbuff_outer_compute) {
160 const size_t wei_oc_step = (size_t)jcp.kh * jcp.kw * jcp.ic_block_int_np
161 * jcp.nb_oc_blocking * jcp.oc_block;
162 const int sp_stride = mem_blk_off(dst_d, 0, 0, 0, 0, 1);
163 const int dilate_h = jcp.dilate_h + 1;
164 const int gen_kh = (jcp.kh - 1) * dilate_h + 1;
165 const int oh_work = jcp.oh_pad;
166 parallel_nd(
167 ngroups, oc_chunks, oh_work, [&](dim_t g, dim_t occ, dim_t oh) {
168 auto p = jit_conv_call_s();
169
170 const int oh_ = oh >= zp_buff_b_pad_start
171 ? b_pad_start + oh - zp_buff_b_pad_start
172 : oh;
173 const int ih = oh_ * jcp.stride_h - jcp.t_pad;
174 const int t_overflow
175 = nstl::min(jcp.kh, div_up(max(0, -ih), dilate_h));
176 const int b_overflow = nstl::min(jcp.kh,
177 div_up(nstl::max(0, ih + gen_kh - jcp.ih),
178 dilate_h));
179 p.t_overflow = t_overflow;
180 p.b_overflow = b_overflow;
181 p.kh_padding
182 = nstl::max(0, jcp.kh - t_overflow - b_overflow);
183
184 const int ocb = g * jcp.oc
185 + occ * jcp.nb_oc_blocking * jcp.oc_block;
186 const size_t ch_offset = dst_d.blk_off(0, ocb);
187 const size_t sp_offset = oh * jcp.ow_pad * sp_stride;
188 p.zero_point_pbuff
189 = &zero_point_pbuff[ch_offset + sp_offset];
190 p.oc_blocks = occ * jcp.nb_oc_blocking;
191 p.filt = weights
192 + wei_dt_size * (g * oc_chunks + occ) * wei_oc_step;
193 p.src_zero_point = src_zero_point;
194
195 kernel_->zp_pbuff_kernel()(&p);
196 });
197 }
198
199 // TODO: implement 2D parallelization driver (g * spatial x oc) to increase
200 // input data reuse and parallelize input data reorders
201 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
202 int start {0}, end {0};
203 balance211(work_amount, nthr, ithr, start, end);
204 int32_t *local_zp_pbuff = req_zero_point_buffer
205 ? (zp_pbuff_outer_compute
206 ? zero_point_pbuff
207 : &zero_point_pbuff[ithr * zp_pbuff_size])
208 : nullptr;
209 bool *zp_flags = zp_pbuff_parallel_block
210 ? &zp_flags_[ithr * oc_chunks * ngroups]
211 : nullptr;
212 if (zp_pbuff_parallel_block) {
213 PRAGMA_OMP_SIMD()
214 for (int oc = 0; oc < oc_chunks * ngroups; oc++)
215 zp_flags[oc] = true;
216 }
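// Per-thread flags: an entry stays true while the zero-point padding buffer
// for that (g, occ) pair has not been computed by this thread yet; it is
// cleared below once the buffer is filled.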
217
218 auto p = jit_conv_call_s();
219 amx_tile_configure(tcfg);
220
221 const int oh_work = jcp.oh_pad;
222 const int sp_stride = mem_blk_off(dst_d, 0, 0, 0, 0, 1);
223 const int dilate_h = jcp.dilate_h + 1;
224 const int gen_kh = (jcp.kh - 1) * dilate_h + 1;
225 const size_t wei_oc_step = (size_t)jcp.kh * jcp.kw * jcp.ic_block_int_np
226 * jcp.nb_oc_blocking * jcp.oc_block;
227 size_t oc_stride = dst_d.blk_off(0, 1);
228 const int owb_limit = jcp.nb_ow - jcp.r_pad_blk - jcp.no_pad_w_blk;
229
230 int mb {0}, g {0}, ohc {0}, owb {0}, occ {0};
231 // need "inner" oh blocks w.r.t. ow blocks to allow pbuffer reuse
232 nd_iterator_init(start, mb, jcp.mb, g, jcp.ngroups, owb, jcp.nb_ow, ohc,
233 oh_chunks, occ, oc_chunks);
234 int last_copied_mb = -1;
235 int last_copied_ohc = -1;
236 int last_copied_owb = -1;
237 int last_copied_g = -1;
238 while (start < end) {
239 char *inp_buffer
240 = inp_p_buffer + src_dt_size * ithr * jcp.inp_buffer_size;
241
242 assert(IMPLICATION(
243 jcp.ngroups > 1, jcp.oc == jcp.oc_without_padding));
244 int oc = g * jcp.oc + occ * jcp.nb_oc_blocking * jcp.oc_block;
245 int ocb = jcp.is_nspc ? oc : oc / jcp.oc_block;
246 const char *bias_w = bias
247 ? bias + (bias_d.blk_off(oc) * bia_dt_size)
248 : nullptr;
249 p.zp_compensation
250 = jcp.src_zero_point ? zp_compensation + oc : nullptr;
251 p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
252 p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr;
253
254 int oh_s = ohc * jcp.oh_blk_size;
255 int oh_e = nstl::min(jcp.oh, oh_s + jcp.oh_blk_size);
256 bool is_inp_buffer_relevant = true && last_copied_mb == mb
257 && last_copied_ohc == ohc && last_copied_owb == owb
258 && last_copied_g == g;
259 bool has_inp_buffer_overlap = true && last_copied_mb == mb
260 && last_copied_owb == owb && last_copied_g == g
261 && jcp.oh_blk_size == jcp.nb_oh_blocking;
262 bool is_zp_pbuff_relevant = zp_pbuff_parallel_block
263 ? zp_flags[g * oc_chunks + occ] // true -> not yet computed by this thread
264 : false;
265
266 int cur_t_pad = nstl::max(0, t_pad_output - oh_s);
267 int cur_b_pad = nstl::max(
268 nstl::max(0, jcp.oh - b_pad_output - oh_s), cur_t_pad);
269 size_t zp_oh
270 = accum_with_upper_bound(oh_s, t_pad_output, b_pad_start);
271
272 int limit_idx = 0;
273 constexpr int limit_size = 5;
274 for (; req_zero_point_buffer && limit_idx < limit_size;
275 limit_idx++) {
276 // find current 'oh_blk' index from 'oh_s'
277 if ((size_t)oh_s < jcp.h_blk_limits[limit_idx]) break;
278 }
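// limit_idx now selects the vertical zero-point sub-region that contains
// 'oh_s'; it is advanced inside the oh loop below and passed to the kernel
// as p.ohb.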
279
280 if (is_zp_pbuff_relevant) {
281 assert(!zp_pbuff_outer_compute);
282 zp_flags[g * oc_chunks + occ] = false;
283 for (int oh_pad = 0; oh_pad < oh_work; ++oh_pad) {
284 const int oh_ = oh_pad >= zp_buff_b_pad_start
285 ? b_pad_start + oh_pad - zp_buff_b_pad_start
286 : oh_pad;
287 const int ih = oh_ * jcp.stride_h - jcp.t_pad;
288 const int t_overflow
289 = nstl::min(jcp.kh, div_up(max(0, -ih), dilate_h));
290 const int b_overflow = nstl::min(jcp.kh,
291 div_up(nstl::max(0, ih + gen_kh - jcp.ih),
292 dilate_h));
293 p.t_overflow = t_overflow;
294 p.b_overflow = b_overflow;
295 p.kh_padding
296 = nstl::max(0, jcp.kh - t_overflow - b_overflow);
297
298 const size_t ch_offset = dst_d.blk_off(0, oc);
299 const size_t sp_offset = oh_pad * jcp.ow_pad * sp_stride;
300 p.zero_point_pbuff = &local_zp_pbuff[ch_offset + sp_offset];
301 p.oc_blocks = occ * jcp.nb_oc_blocking;
302 p.filt = weights
303 + wei_dt_size * (g * oc_chunks + occ) * wei_oc_step;
304 p.src_zero_point = src_zero_point;
305
306 kernel_->zp_pbuff_kernel()(&p);
307 }
308 }
309
310 int oh_step = jcp.nb_oh_blocking * jcp.oh_per_tile;
311 for (int oh = oh_s; oh < oh_e; oh += oh_step) {
312 const int inp_buffer_h_step
313 = jcp.stride_h * jcp.ic_without_padding;
314 assert(jcp.is_nspc);
315 assert(jcp.stride_h <= jcp.kh);
316
317 int ow = owb * jcp.ow_block;
318
319 char *inp_buffer_oh
320 = inp_buffer + src_dt_size * oh * inp_buffer_h_step;
321
322 if (!is_inp_buffer_relevant) {
323 // prepare padded input buffer
324 const int icb = g * jcp.ic;
325 size_t inp_offset = src_d.blk_off(mb, icb);
326 const int iw_step = jcp.ngroups * jcp.ic_without_padding;
327 const char *psrc = src + src_dt_size * inp_offset;
328 // calculate the input-row overlap with the previous oh block (rows already in inp_buffer)
329 const int ih_overlap = has_inp_buffer_overlap
330 * nstl::max(0, jcp.kh - oh_step * jcp.stride_h);
331 const int kh_eff = jcp.kh - ih_overlap;
332 // prepare padded input buffer
333 char *pdst = inp_buffer_oh
334 + src_dt_size * ih_overlap * jcp.ic_without_padding;
335 for (int doh = 0; doh < oh_step; doh++) {
336 const int ih_s = (doh + oh) * jcp.stride_h - jcp.t_pad
337 + ih_overlap;
338 const int ih_e = ih_s + kh_eff;
339 const int ih = nstl::max(0, ih_s);
340 p.t_overflow = nstl::max(0, -ih_s);
341 p.b_overflow = nstl::min<int>(
342 kh_eff, nstl::max(0, ih_e - jcp.ih));
343 p.kh_padding = nstl::max<int>(
344 0, (kh_eff - p.t_overflow - p.b_overflow));
345 p.kh_offset = kh_eff;
346
347 const int iw_s = ow * jcp.stride_w - jcp.l_pad;
348 const int iw_e = iw_s + jcp.iwp;
349 const int iw = nstl::max(0, iw_s);
350 p.f_overflow = nstl::max(0, -iw_s);
351 p.back_overflow = nstl::max(0, iw_e - jcp.iw);
352 p.kw_padding = nstl::max<int>(
353 0, jcp.iwp - p.f_overflow - p.back_overflow);
354
355 p.src = psrc
356 + src_dt_size * (ih * jcp.iw + iw) * iw_step;
357 p.dst = pdst
358 + src_dt_size * doh * jcp.iwp * jcp.kh
359 * jcp.ic_without_padding;
360
361 kernel_->copy_to_pbuffer()(&p);
362 }
363 }
364
365 p.src = inp_buffer_oh;
366 size_t dst_offset = mem_blk_off(dst_d, mb, ocb, 0, oh, ow);
367 p.dst = dst + dst_dt_size * dst_offset;
368 const size_t pbuff_offset
369 = zp_oh * jcp.ow_pad * sp_stride + ocb * oc_stride;
370 p.zero_point_pbuff = req_zero_point_buffer
371 ? &local_zp_pbuff[pbuff_offset]
372 : nullptr;
373 p.filt = wei
374 + wei_dt_size * (g * oc_chunks + occ) * wei_oc_shift;
375 p.bias = bias_w;
376 p.scales = &oscales[jcp.is_oc_scale * oc];
377 p.dst_scale = &dst_scales[0];
378
379 p.acc_s32 = wsp + ithr * jcp.wsp_buffer_size;
380
381 if (req_zero_point_buffer
382 && (size_t)oh >= jcp.h_blk_limits[limit_idx])
383 limit_idx++;
384 assert(limit_idx < 6);
385 p.ohb = limit_idx;
386 p.last_h = (oh + oh_step <= oh_e);
387 const int zp_owb = nstl::min(jcp.l_pad_blk, owb)
388 + nstl::max(0, owb - owb_limit);
389 p.owb = req_zero_point_buffer ? zp_owb : owb;
390
391 p.oc_blocks = occ * jcp.nb_oc_blocking;
392
393 p.oc_l_off = oc;
394 p.post_ops_binary_rhs_arg_vec
395 = post_ops_binary_rhs_arg_vec.data();
396 p.dst_orig = dst;
397
398 (*kernel_)(&p);
399
400 if (req_zero_point_buffer) {
401 zp_oh += accum_with_upper_bound(
402 oh_step, cur_t_pad, cur_b_pad);
403 cur_t_pad = nstl::max(0, cur_t_pad - oh_step);
404 cur_b_pad = nstl::max(0, cur_b_pad - oh_step);
405 }
406 }
407 last_copied_mb = mb;
408 last_copied_ohc = ohc;
409 last_copied_owb = owb;
410 last_copied_g = g;
411 ++start;
412 // need "inner" oh blocks w.r.t. ow blocks to allow pbuffer reuse
413 nd_iterator_step(mb, jcp.mb, g, jcp.ngroups, owb, jcp.nb_ow, ohc,
414 oh_chunks, occ, oc_chunks);
415 }
416
417 amx_tile_release();
418 });
419 return status::success;
420}
421
422status_t jit_avx512_core_amx_convolution_fwd_t::execute_forward(
423 const exec_ctx_t &ctx) const {
424 const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
425 const auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
426 auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
427 auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
428 const auto post_ops_binary_rhs_arg_vec
429 = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
430
431 const memory_desc_wrapper src_d(pd()->src_md());
432 const memory_desc_wrapper dst_d(pd()->dst_md());
433 const memory_desc_wrapper weights_d(pd()->weights_md(0));
434 const memory_desc_wrapper bias_d(pd()->weights_md(1));
435
436 DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
437 DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
438
439 const size_t bia_dt_size = pd()->with_bias()
440 ? types::data_type_size(pd()->desc()->bias_desc.data_type)
441 : 0;
442 const size_t dst_dt_size
443 = types::data_type_size(pd()->desc()->dst_desc.data_type);
444 const size_t src_dt_size
445 = types::data_type_size(pd()->desc()->src_desc.data_type);
446 const size_t wei_dt_size
447 = types::data_type_size(pd()->desc()->weights_desc.data_type);
448
449 prepare_padded_bias(bias, ctx.get_scratchpad_grantor());
450
451 const auto &jcp = pd()->jcp_;
452 assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
453
454 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
455 DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
456 DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
457
458 const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
459 src_scales, wei_scales, pd()->OC(), pd()->attr());
460
461 // TODO: use block offset instead of hand-calculated one
462 //size_t wei_oc_shift = wht_blk_off(weights_d, 0, 1);
463 const size_t wei_oc_shift = (size_t)jcp.nb_oc_blocking * jcp.nb_ic_int
464 * jcp.kd * jcp.kh * jcp.kw * jcp.ic_block_int_np * jcp.oc_block;
465 const size_t wei_d_shift
466 = (size_t)jcp.kh * jcp.kw * jcp.ic_block_int_np * jcp.oc_block;
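// wei_oc_shift: element stride between consecutive oc chunks of the weights;
// wei_d_shift: element stride between kd slices within a chunk, used below to
// skip filter planes that fall entirely into the front (depth) padding.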
467
468 auto inp_p_buffer = ctx.get_scratchpad_grantor().template get<char>(
469 key_conv_amx_inp_buffer); // TODO: fix the get<> template argument
470 auto wsp = ctx.get_scratchpad_grantor().template get<int32_t>(
471 key_conv_amx_wsp_buffer);
472 auto tcfg = ctx.get_scratchpad_grantor().template get<char>(
473 key_conv_amx_tilecfg);
474 auto zero_point_pbuff = ctx.get_scratchpad_grantor().template get<int32_t>(
475 key_conv_zero_point_pad);
476 auto zp_flags_ = ctx.get_scratchpad_grantor().template get<bool>(
477 key_conv_zero_point_flag);
478
479 const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
480 char *w = const_cast<char *>(weights);
481 int32_t *zp_compensation = jcp.src_zero_point
482 ? reinterpret_cast<int32_t *>(&w[offset])
483 : nullptr;
484
485 const int f_pad_output = jcp.f_pad_output;
486 const int back_pad_output = jcp.back_pad_output;
487 const int back_pad_start
488 = nstl::max(jcp.od - back_pad_output, f_pad_output);
489 const int zp_buff_back_pad_start
490 = nstl::max(jcp.od_pad - back_pad_output, f_pad_output);
491 const int t_pad_output = jcp.t_pad_output;
492 const int b_pad_output = jcp.b_pad_output;
493 const int b_pad_start = nstl::max(jcp.oh - b_pad_output, t_pad_output);
494 const int zp_buff_b_pad_start
495 = nstl::max(jcp.oh_pad - b_pad_output, t_pad_output);
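// The zero-point padding buffer stores only the padded output planes/rows
// (od_pad x oh_pad x ow_pad); *_pad_start and zp_buff_*_pad_start translate
// between full output coordinates and this compressed layout.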
496
497 const int ngroups = jcp.ngroups;
498 const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
499 const int oh_chunks = utils::div_up(jcp.oh, jcp.oh_blk_size);
500 const size_t work_amount = (size_t)jcp.mb * jcp.ngroups * jcp.od * oh_chunks
501 * jcp.nb_ow * oc_chunks;
502 const int zp_pbuff_size = jcp.zp_pbuff_size;
503
504 // Initialize the tile configuration in memory, so that each thread can
505 // load this configuration from memory via `amx_tile_configure(tcfg)`.
506 kernel_->tile_configure(tcfg);
507
508 // init zero_point padding buffer
509 const bool req_zero_point_buffer = jcp.req_zero_point_buffer;
510 const bool zp_pbuff_outer_compute = jcp.zp_pbuff_outer_compute;
511 const bool zp_pbuff_parallel_block
512 = req_zero_point_buffer && !zp_pbuff_outer_compute;
513 if (req_zero_point_buffer && zp_pbuff_outer_compute) {
514 const int sp_stride = mem_blk_off(dst_d, 0, 0, 0, 0, 1);
515 const int dilate_d = jcp.dilate_d + 1;
516 const int gen_kd = (jcp.kd - 1) * dilate_d + 1;
517 const int dilate_h = jcp.dilate_h + 1;
518 const int gen_kh = (jcp.kh - 1) * dilate_h + 1;
519 const int od_work = jcp.od_pad;
520 const int oh_work = jcp.oh_pad;
521 parallel_nd(ngroups, oc_chunks, od_work, oh_work,
522 [&](dim_t g, dim_t occ, dim_t od, dim_t oh) {
523 auto p = jit_conv_call_s();
524
525 const int od_ = od >= zp_buff_back_pad_start
526 ? back_pad_start + od - zp_buff_back_pad_start
527 : od;
528 const int id = od_ * jcp.stride_d - jcp.f_pad;
529 const int f_overflow
530 = nstl::min(jcp.kd, div_up(max(0, -id), dilate_d));
531 const int back_overflow = nstl::min(jcp.kd,
532 div_up(nstl::max(0, id + gen_kd - jcp.id),
533 dilate_d));
534 p.f_overflow = f_overflow;
535 p.back_overflow = back_overflow;
536 p.kd_padding
537 = nstl::max(0, jcp.kd - f_overflow - back_overflow);
538 const int oh_ = oh >= zp_buff_b_pad_start
539 ? b_pad_start + oh - zp_buff_b_pad_start
540 : oh;
541 const int ih = oh_ * jcp.stride_h - jcp.t_pad;
542 const int t_overflow
543 = nstl::min(jcp.kh, div_up(max(0, -ih), dilate_h));
544 const int b_overflow = nstl::min(jcp.kh,
545 div_up(nstl::max(0, ih + gen_kh - jcp.ih),
546 dilate_h));
547 p.t_overflow = t_overflow;
548 p.b_overflow = b_overflow;
549 p.kh_padding
550 = nstl::max(0, jcp.kh - t_overflow - b_overflow);
551
552 const int ocb = g * jcp.oc
553 + occ * jcp.nb_oc_blocking * jcp.oc_block;
554 const size_t ch_offset = dst_d.blk_off(0, ocb);
555 auto sp_offset
556 = (od * jcp.oh_pad + oh) * jcp.ow_pad * sp_stride;
557 p.zero_point_pbuff
558 = &zero_point_pbuff[ch_offset + sp_offset];
559 p.oc_blocks = occ * jcp.nb_oc_blocking;
560 p.filt = weights
561 + wei_dt_size * (g * oc_chunks + occ)
562 * wei_oc_shift;
563 p.src_zero_point = src_zero_point;
564
565 kernel_->zp_pbuff_kernel()(&p);
566 });
567 }
568
569 // TODO: implement 2D parallelization driver (g * spatial x oc) to increase
570 // input data reuse and parallelize input data reorders
571 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
572 size_t start {0}, end {0};
573 balance211(work_amount, nthr, ithr, start, end);
574 int32_t *local_zp_pbuff = req_zero_point_buffer
575 ? (zp_pbuff_outer_compute
576 ? zero_point_pbuff
577 : &zero_point_pbuff[ithr * zp_pbuff_size])
578 : nullptr;
579 bool *zp_flags = zp_pbuff_parallel_block
580 ? &zp_flags_[ithr * oc_chunks * ngroups]
581 : nullptr;
582 if (zp_pbuff_parallel_block) {
583 PRAGMA_OMP_SIMD()
584 for (int oc = 0; oc < oc_chunks * ngroups; oc++)
585 zp_flags[oc] = true;
586 }
587
588 auto p = jit_conv_call_s();
589 amx_tile_configure(tcfg);
590
591 const int oh_work = jcp.oh_pad;
592 const int sp_stride = mem_blk_off(dst_d, 0, 0, 0, 0, 1);
593 const int dilate_d = jcp.dilate_d + 1;
594 const int dilate_h = jcp.dilate_h + 1;
595 const int gen_kh = (jcp.kh - 1) * dilate_h + 1;
596 size_t oc_stride = dst_d.blk_off(0, 1);
597 const int owb_limit = jcp.nb_ow - jcp.r_pad_blk - jcp.no_pad_w_blk;
598
599 int mb {0}, g {0}, odc {0}, ohc {0}, owb {0}, occ {0};
600 nd_iterator_init(start, mb, jcp.mb, g, jcp.ngroups, odc, jcp.od, ohc,
601 oh_chunks, owb, jcp.nb_ow, occ, oc_chunks);
602 int last_copied_mb = -1;
603 int last_copied_odc = -1;
604 int last_copied_ohc = -1;
605 int last_copied_owb = -1;
606 int last_copied_g = -1;
607 while (start < end) {
608 char *inp_buffer
609 = inp_p_buffer + src_dt_size * ithr * jcp.inp_buffer_size;
610
611 assert(IMPLICATION(
612 jcp.ngroups > 1, jcp.oc == jcp.oc_without_padding));
613 int oc = g * jcp.oc + occ * jcp.nb_oc_blocking * jcp.oc_block;
614 int ocb = jcp.is_nspc ? oc : oc / jcp.oc_block;
615 const char *bias_w = bias
616 ? bias + (bias_d.blk_off(oc) * bia_dt_size)
617 : nullptr;
618 p.zp_compensation
619 = jcp.src_zero_point ? zp_compensation + oc : nullptr;
620 p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
621 p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr;
622
623 const size_t inp_src_d_stride = mem_blk_off(src_d, 0, 0, 1, 0, 0);
624 int oh_s = ohc * jcp.oh_blk_size;
625 int oh_e = nstl::min(jcp.oh, oh_s + jcp.oh_blk_size);
626 bool is_inp_buffer_relevant = true && last_copied_mb == mb
627 && last_copied_odc == odc && last_copied_ohc == ohc
628 && last_copied_owb == owb && last_copied_g == g;
629 bool is_zp_pbuff_relevant = zp_pbuff_parallel_block
630 ? zp_flags[g * oc_chunks + occ] // true -> not yet computed by this thread
631 : false;
632
633 int cur_t_pad = nstl::max(0, t_pad_output - oh_s);
634 int cur_b_pad = nstl::max(
635 nstl::max(0, jcp.oh - b_pad_output - oh_s), cur_t_pad);
636 const size_t zp_od = odc >= back_pad_start
637 ? odc - back_pad_start + zp_buff_back_pad_start
638 : nstl::min(f_pad_output, odc);
639 size_t zp_oh
640 = accum_with_upper_bound(oh_s, t_pad_output, b_pad_start);
641
642 int limit_idx = 0;
643 constexpr int limit_size = 5;
644 for (; req_zero_point_buffer && limit_idx < limit_size;
645 limit_idx++) {
646 // find current 'oh_blk' index from 'oh_s'
647 if ((size_t)oh_s < jcp.h_blk_limits[limit_idx]) break;
648 }
649
650 if (is_zp_pbuff_relevant) {
651 assert(pd()->ndims() != 5);
652 assert(!zp_pbuff_outer_compute);
653 zp_flags[g * oc_chunks + occ] = false;
654 for (int oh_pad = 0; oh_pad < oh_work; ++oh_pad) {
655 const int oh_ = oh_pad >= zp_buff_b_pad_start
656 ? b_pad_start + oh_pad - zp_buff_b_pad_start
657 : oh_pad;
658 const int ih = oh_ * jcp.stride_h - jcp.t_pad;
659 const int t_overflow
660 = nstl::min(jcp.kh, div_up(max(0, -ih), dilate_h));
661 const int b_overflow = nstl::min(jcp.kh,
662 div_up(nstl::max(0, ih + gen_kh - jcp.ih),
663 dilate_h));
664 p.t_overflow = t_overflow;
665 p.b_overflow = b_overflow;
666 p.kh_padding
667 = nstl::max(0, jcp.kh - t_overflow - b_overflow);
668
669 const size_t ch_offset = dst_d.blk_off(0, oc);
670 const size_t sp_offset = oh_pad * jcp.ow_pad * sp_stride;
671 p.zero_point_pbuff = &local_zp_pbuff[ch_offset + sp_offset];
672 p.oc_blocks = occ * jcp.nb_oc_blocking;
673 p.filt = weights
674 + wei_dt_size * (g * oc_chunks + occ)
675 * wei_oc_shift;
676 p.src_zero_point = src_zero_point;
677
678 kernel_->zp_pbuff_kernel()(&p);
679 }
680 }
681
682 const int id_s = odc * jcp.stride_d - jcp.f_pad;
683 const int d_f_overflow
684 = nstl::min(jcp.kd, div_up(max(0, -id_s), dilate_d));
685 const int d_back_overflow = nstl::min(jcp.kd,
686 div_up(max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1),
687 dilate_d));
688 p.kd_padding
689 = nstl::max(0, jcp.kd - d_f_overflow - d_back_overflow);
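// d_f_overflow / d_back_overflow count the filter taps that land in the
// front/back depth padding; the kernel iterates over kd_padding taps only and
// the weights pointer is advanced past the skipped front planes below.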
690
691 int oh_step = jcp.nb_oh_blocking * jcp.oh_per_tile;
692 for (int oh = oh_s; oh < oh_e; oh += oh_step) {
693 const int gen_kh = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1);
694 const int gen_stride_h = nstl::min(jcp.stride_h, gen_kh);
695 if (!is_inp_buffer_relevant) {
696 const int iw = nstl::max(
697 0, owb * jcp.ow_block * jcp.stride_w - jcp.l_pad);
698
699 assert(IMPLICATION(
700 jcp.ngroups > 1, jcp.ic == jcp.ic_without_padding));
701 const int icb = g * (jcp.is_nspc ? jcp.ic : jcp.nb_ic);
702
703 // generalized kh including dilation
704 // the current implementation of the copy routine is not
705 // optimal for small jcp.oh_blk_size as it copies
706 // dilation rows into the buffer
707 const bool continuous_copy = gen_kh >= jcp.stride_h;
708 int current_oh_block = nstl::min(oh_e - oh, oh_step);
709 int num_copy_calls = continuous_copy ? 1 : current_oh_block;
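// If gen_kh >= stride_h, the input rows needed by consecutive output rows
// overlap, so one contiguous copy covers the whole oh block; otherwise a
// separate copy call is issued per output row of the block.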
710 for (int ohi = 0; ohi < num_copy_calls; ohi++) {
711 int ih_copy_start
712 = (oh + ohi) * jcp.stride_h - jcp.t_pad;
713 int ih_copy_end = ih_copy_start + gen_kh;
714 if (continuous_copy) {
715 ih_copy_end
716 += jcp.stride_h * (current_oh_block - 1);
717 if (oh > oh_s)
718 // non-first block: shift the start to the end
719 // of the previous block:
720 // ih_copy_end_prev =
721 // (oh - 1) * str_h - t_pad + gen_kh
722 ih_copy_start += gen_kh - jcp.stride_h;
723 }
724 int ih_zero_top = nstl::max(0, -ih_copy_start);
725 int ih_zero_bottom = nstl::max(0, ih_copy_end - jcp.ih);
726 // total number of input rows to fill (real rows plus zero padding)
727 int rows_to_copy = ih_copy_end - ih_copy_start;
728 p.kh_padding = max(0, rows_to_copy);
729 p.t_overflow = ih_zero_top;
730 p.b_overflow = ih_zero_bottom;
731 p.owb = owb;
732 int ih = nstl::max(ih_copy_start, 0);
733 size_t inp_offset
734 = mem_blk_off(src_d, mb, icb, id_s, ih, iw)
735 + d_f_overflow * dilate_d * inp_src_d_stride;
736 p.src = src + src_dt_size * inp_offset;
737 // inp_buffer has physical padding
738 int ih_buf = continuous_copy
739 ? ih_copy_start + jcp.t_pad
740 - oh_s * jcp.stride_h
741 : gen_stride_h * (oh + ohi - oh_s);
742 p.dst = inp_buffer
743 + src_dt_size * ih_buf * jcp.iwp
744 * jcp.ic_block_int_np;
745
746 kernel_->copy_to_pbuffer()(&p);
747 }
748 }
749 int ih_buf = gen_stride_h * (oh - oh_s);
750 int ow = owb * jcp.ow_block;
751 p.src = inp_buffer
752 + src_dt_size * ih_buf * jcp.iwp * jcp.ic_block_int_np;
753
754 size_t dst_offset = mem_blk_off(dst_d, mb, ocb, odc, oh, ow);
755 p.dst = dst + dst_dt_size * dst_offset;
756 const size_t pbuff_offset
757 = (zp_od * jcp.oh_pad + zp_oh) * jcp.ow_pad * sp_stride
758 + ocb * oc_stride;
759 p.zero_point_pbuff = req_zero_point_buffer
760 ? &local_zp_pbuff[pbuff_offset]
761 : nullptr;
762
763 p.filt = weights
764 + ((g * oc_chunks + occ) * wei_oc_shift
765 + d_f_overflow * wei_d_shift)
766 * wei_dt_size;
767 p.bias = bias_w;
768 p.scales = &oscales[jcp.is_oc_scale * oc];
769 p.dst_scale = &dst_scales[0];
770
771 p.acc_s32 = wsp + ithr * jcp.wsp_buffer_size;
772 if (req_zero_point_buffer
773 && (size_t)oh >= jcp.h_blk_limits[limit_idx])
774 limit_idx++;
775 assert(limit_idx < 6);
776 p.ohb = limit_idx;
777 p.last_h = (oh + oh_step <= oh_e);
778 const int zp_owb = nstl::min(jcp.l_pad_blk, owb)
779 + nstl::max(0, owb - owb_limit);
780 p.owb = req_zero_point_buffer ? zp_owb : owb;
781
782 p.oc_blocks = occ * jcp.nb_oc_blocking;
783
784 p.oc_l_off = oc;
785 p.post_ops_binary_rhs_arg_vec
786 = post_ops_binary_rhs_arg_vec.data();
787 p.dst_orig = dst;
788
789 (*kernel_)(&p);
790
791 if (req_zero_point_buffer) {
792 zp_oh += accum_with_upper_bound(
793 oh_step, cur_t_pad, cur_b_pad);
794 cur_t_pad = nstl::max(0, cur_t_pad - oh_step);
795 cur_b_pad = nstl::max(0, cur_b_pad - oh_step);
796 }
797 }
798 last_copied_mb = mb;
799 last_copied_odc = odc;
800 last_copied_ohc = ohc;
801 last_copied_owb = owb;
802 last_copied_g = g;
803 ++start;
804 nd_iterator_step(mb, jcp.mb, g, jcp.ngroups, odc, jcp.od, ohc,
805 oh_chunks, owb, jcp.nb_ow, occ, oc_chunks);
806 }
807
808 amx_tile_release();
809 });
810 return status::success;
811}
812
813template <data_type_t diff_src_type, data_type_t wei_type,
814 data_type_t diff_dst_type>
815status_t jit_avx512_core_amx_convolution_bwd_data_t<diff_src_type, wei_type,
816 diff_dst_type>::execute_backward(const exec_ctx_t &ctx) const {
817 const auto diff_dst = CTX_IN_MEM(const char *, DNNL_ARG_DIFF_DST);
818 const auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
819 auto diff_src = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_SRC);
820
821 const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
822 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
823 const memory_desc_wrapper weights_d(pd()->weights_md(0));
824
825 // unused in the bf16 kernel, but attributes have a scales buffer by default
826 // and using it here simplifies the shared `execute_backward_convolution_body`.
827 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
828 DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
829 DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
830
831 const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
832 src_scales, wei_scales, pd()->OC(), pd()->attr());
833
834 amx_utils::execute_backward_convolution_body(ctx, pd()->jcp_, kernel_,
835 diff_dst, weights, nullptr /* no bias */, oscales, dst_scales,
836 diff_src, diff_dst_d, weights_d,
837 memory_desc_wrapper(nullptr) /* no bias */, diff_src_d);
838 return status::success;
839}
840
841template struct jit_avx512_core_amx_convolution_bwd_data_t<data_type::bf16,
842 data_type::bf16, data_type::bf16>;
843template struct jit_avx512_core_amx_convolution_bwd_data_t<data_type::f32,
844 data_type::bf16, data_type::bf16>;
845
846status_t jit_avx512_core_amx_convolution_bwd_weights_t::init(engine_t *engine) {
847 const auto &j = pd()->jcp_;
848
849 nthr_ = j.nthr;
850 nthr_mb_ = j.nthr_mb;
851 nthr_g_ = j.nthr_g;
852 nthr_oc_b_ = j.nthr_oc_b;
853 nthr_ic_b_ = j.nthr_ic_b;
854
855 CHECK(safe_ptr_assign(
856 kernel_, new jit_avx512_core_amx_bwd_weights_kernel_t(j)));
857 CHECK(kernel_->create_kernel());
858
859 CHECK(safe_ptr_assign(trans_kernel_, create_trans_src(&j)));
860 CHECK(trans_kernel_->create_kernel());
861 CHECK(safe_ptr_assign(trans_dst_kernel_, create_trans_dst(&j)));
862 CHECK(trans_dst_kernel_->create_kernel());
863 if (nthr_mb_ > 1) {
864 CHECK(safe_ptr_assign(
865 acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
866 CHECK(acc_ker_->create_kernel());
867 }
868 if (j.transform_to_vnni) {
869 CHECK(safe_ptr_assign(diff_wei_trans_kernel_,
870 new jit_diff_wei_trans_to_vnni_t(
871 j.wei_dt, j.kd, j.kh, j.kw, j.ic_block, j.oc_block)));
872 CHECK(diff_wei_trans_kernel_->create_kernel());
873 }
874 return status::success;
875}
876
877struct jit_avx512_core_amx_convolution_bwd_weights_t::thread_info_t {
878 const src_data_t *src = nullptr;
879 const diff_dst_data_t *diff_dst = nullptr;
880 const void *diff_weights = nullptr;
881 const void *diff_bias = nullptr;
882
883 const memory_tracking::grantor_t scratchpad;
884
885 src_data_t *tr_src = nullptr;
886 diff_dst_data_t *tr_diff_dst = nullptr;
887 simple_barrier::ctx_t *tr_src_bctx = nullptr;
888 simple_barrier::ctx_t *tr_diff_dst_bctx = nullptr;
889
890 float *wei_bia_reduction = nullptr;
891 float *bia_reduction = nullptr;
892 simple_barrier::ctx_t *wei_bia_reduction_bctx = nullptr;
893
894 int ithr = 0;
895 int ithr_ic_b = 0, ithr_oc_b = 0, ithr_g = 0, ithr_mb = 0;
896 int ithr_but_oc = 0;
897 int ithr_but_ic = 0;
898
899 int img_start = 0, img_end = 0, img_work = 0;
900 int g_start = 0, g_end = 0, g_work = 0;
901 int oc_b_start = 0, oc_b_end = 0, oc_b_work = 0;
902 int ic_b_start = 0, ic_b_end = 0, ic_b_work = 0;
903
904 thread_info_t(const jit_avx512_core_amx_convolution_bwd_weights_t *self,
905 const exec_ctx_t &ctx, int ithr)
906 : scratchpad(ctx.get_scratchpad_grantor()), ithr(ithr) {
907 diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
908 src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
909 diff_weights = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_WEIGHTS);
910
911 const auto &jcp = self->kernel_->jcp;
912 diff_bias = self->pd()->with_bias()
913 && (jcp.oc_without_padding % jcp.oc_block != 0)
914 && self->pd()->jcp_.bia_dt == data_type::f32
915 ? (void *)scratchpad.template get<float>(key_conv_padded_bias)
916 : CTX_OUT_MEM(void *, DNNL_ARG_DIFF_BIAS);
917
918 tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
919 if (jcp.global_transpose)
920 tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
921 key_conv_tr_src_bctx);
922
923 tr_diff_dst = scratchpad.template get<diff_dst_data_t>(
924 key_conv_tr_diff_dst);
925 if (jcp.global_transpose)
926 tr_diff_dst_bctx = scratchpad.template get<simple_barrier::ctx_t>(
927 key_conv_tr_diff_dst_bctx);
928 wei_bia_reduction
929 = scratchpad.template get<float>(key_conv_wei_bia_reduction);
930 bia_reduction = nullptr;
931 if (jcp.with_bias) {
932 const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block
933 * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
934 const int num_wei_buffers = jcp.wei_dt == data_type::bf16
935 ? jcp.nthr_mb
936 : jcp.nthr_mb - 1;
937 bia_reduction = wei_bia_reduction + wei_size * num_wei_buffers;
938 }
939
940 wei_bia_reduction_bctx = scratchpad.template get<simple_barrier::ctx_t>(
941 key_conv_wei_bia_reduction_bctx);
942
943 ithr_ic_b = ithr % self->nthr_ic_b_;
944 ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_;
945 ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_;
946 ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_;
947
948 ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_
949 + ithr_ic_b;
950
951 ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_
952 + ithr_oc_b;
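// The linear thread id decomposes as
// ithr = ((ithr_mb * nthr_g_ + ithr_g) * nthr_oc_b_ + ithr_oc_b) * nthr_ic_b_ + ithr_ic_b.
// ithr_but_oc / ithr_but_ic are the thread's coordinates with the oc (resp.
// ic) dimension folded out; they index the transpose barriers shared across
// that dimension.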
953
954 int work_amount = jcp.nthr_mb_work;
955 /* reduction dimension */
956 balance211(work_amount, self->nthr_mb_, ithr_mb, img_start, img_end);
957 img_work = img_end - img_start;
958
959 /* independent dimensions */
960 balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end);
961 g_work = g_end - g_start;
962
963 balance211(
964 jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start, oc_b_end);
965 oc_b_work = oc_b_end - oc_b_start;
966
967 balance211(
968 jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start, ic_b_end);
969 if (jcp.transform_to_vnni) {
970 if (ic_b_start % 2 != 0) ic_b_start++;
971 if (ic_b_end != jcp.nb_ic && ic_b_end % 2 != 0) ic_b_end++;
972 }
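// transform_to_vnni packs pairs of ic blocks together, so the per-thread ic
// range is aligned to an even block index (keeping each vnni pair within a
// single thread).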
973 ic_b_work = ic_b_end - ic_b_start;
974 }
975};
976
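// With global_transpose each (mb-thread, group, channel block) owns a
// distinct transpose buffer, so other threads can read it after the barrier;
// without it every thread reuses a single private buffer.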
977size_t jit_avx512_core_amx_convolution_bwd_weights_t::tr_src_buf_number(
978 const thread_info_t *ti, int g, int ic) const {
979 const jit_conv_conf_t &jcp = this->kernel_->jcp;
980 return jcp.global_transpose
981 ? ti->ithr_mb * jcp.nb_ic * jcp.ngroups + g * jcp.nb_ic + ic
982 : ti->ithr;
983}
984
985size_t jit_avx512_core_amx_convolution_bwd_weights_t::tr_diff_dst_buf_number(
986 const thread_info_t *ti, int g, int oc) const {
987 const jit_conv_conf_t &jcp = this->kernel_->jcp;
988 return jcp.global_transpose
989 ? ti->ithr_mb * jcp.nb_oc * jcp.ngroups + g * jcp.nb_oc + oc
990 : ti->ithr;
991}
992
993void jit_avx512_core_amx_convolution_bwd_weights_t::trans_src_nxc(
994 src_data_t *tr_src, const src_data_t *src_base, int spatial_start,
995 dim_t spatial_start_offset, int icb_start, dim_t chb_stride,
996 int row_count) const {
997 const jit_conv_conf_t &jcp = this->kernel_->jcp;
998 const int src_stride = jcp.iw * jcp.ngroups * jcp.ic;
999 const int tr_src_stride = jcp.tr_iw * jcp.ic_block;
1000
1001 int work_rest = row_count;
1002 int max_spatial_work = jcp.id * jcp.ih;
1003 int sp_work = nstl::min(work_rest, max_spatial_work - spatial_start);
1004 const src_data_t *src = src_base + spatial_start_offset;
1005 int icb = 0;
1006 const int ic_tail_work = jcp.ic_tail ? jcp.ic_tail : jcp.ic_block;
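// The last ic block may be a tail; ch_work tells the transpose kernel how
// many channels of that block are valid.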
1007 while (work_rest > 0) {
1008 for (int iwork = 0; iwork < sp_work; iwork++) {
1009 auto ctx = jit_trans_src_t::ctx_t();
1010 ctx.src = src;
1011 ctx.tr_src = tr_src;
1012 assert(icb_start + icb < jcp.nb_ic);
1013 ctx.ch_work = (icb_start + icb + 1) == jcp.nb_ic ? ic_tail_work
1014 : jcp.ic_block;
1015 ctx.src_prf = nullptr;
1016 ctx.tr_src_prf = nullptr;
1017 (*trans_kernel_)(&ctx);
1018 src += src_stride;
1019 tr_src += tr_src_stride;
1020 }
1021 work_rest -= sp_work;
1022 sp_work = nstl::min(work_rest, max_spatial_work);
1023 icb++;
1024 src = src_base + icb * chb_stride;
1025 }
1026}
1027
1028void jit_avx512_core_amx_convolution_bwd_weights_t::trans_dst_nxc(
1029 diff_dst_data_t *tr_diff_dst, const diff_dst_data_t *diff_dst_base,
1030 int spatial_start, dim_t spatial_start_offset, int ocb_start,
1031 dim_t chb_stride, int row_count) const {
1032 const jit_conv_conf_t &jcp = this->kernel_->jcp;
1033 const int diff_dst_stride = jcp.ow * jcp.ngroups * jcp.oc;
1034 const int tr_diff_dst_stride = jcp.tr_ow * jcp.oc_block;
1035 int work_rest = row_count;
1036 int max_spatial_work = jcp.od * jcp.oh;
1037 int sp_work = nstl::min(work_rest, max_spatial_work - spatial_start);
1038 const diff_dst_data_t *diff_dst = diff_dst_base + spatial_start_offset;
1039 int ocb = 0;
1040 const int oc_tail_work = jcp.oc_tail ? jcp.oc_tail : jcp.oc_block;
1041 while (work_rest > 0) {
1042 for (int iwork = 0; iwork < sp_work; iwork++) {
1043 auto ctx = jit_trans_dst_t::ctx_t();
1044 ctx.src = diff_dst;
1045 ctx.tr_src = tr_diff_dst;
1046 assert(ocb_start + ocb < jcp.nb_oc);
1047 ctx.ch_work = (ocb_start + ocb + 1) == jcp.nb_oc ? oc_tail_work
1048 : jcp.oc_block;
1049 ctx.src_prf = nullptr;
1050 ctx.tr_src_prf = nullptr;
1051 (*trans_dst_kernel_)(&ctx);
1052 diff_dst += diff_dst_stride;
1053 tr_diff_dst += tr_diff_dst_stride;
1054 }
1055 work_rest -= sp_work;
1056 sp_work = nstl::min(work_rest, max_spatial_work);
1057 ocb++;
1058 diff_dst = diff_dst_base + ocb * chb_stride;
1059 }
1060}
1061
1062void jit_avx512_core_amx_convolution_bwd_weights_t::compute_diff_weights_2d(
1063 const thread_info_t *ti) const {
1064 const memory_desc_wrapper src_d(pd()->src_md());
1065 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
1066 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
1067 const auto &jcp = kernel_->jcp;
1068
1069 const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
1070 * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
1071 const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
1072 const int optimal_hblock = jcp.spatial_blk_size;
1073
1074 float *diff_wei;
1075 if (diff_weights_d.data_type() == data_type::bf16)
1076 diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size;
1077 else
1078 diff_wei = ti->ithr_mb == 0
1079 ? (float *)ti->diff_weights
1080 : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
1081
1082 float *diff_bias = nullptr;
1083 if (jcp.with_bias) {
1084 if (jcp.bia_dt == data_type::bf16)
1085 diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size;
1086 else
1087 diff_bias = ti->ithr_mb == 0
1088 ? (float *)ti->diff_bias
1089 : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size;
1090 }
1091
1092 auto tr_diff_dst_off = [&](int g, int oc, int oj) {
1093 const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
1094 int adj = (jcp.global_transpose) ? 1 : jcp.nb_oc_blocking;
1095 return tr_diff_dst_buf_number(ti, g, oc) * adj
1096 * jcp.tr_diff_dst_buf_size
1097 + oj * tr_row_size;
1098 };
1099
1100 int img {0}, oh_s {0};
1101 int start = ti->img_start;
1102 int end = ti->img_end;
1103
1104 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
1105
1106 nd_iterator_init(start, img, jcp.mb, oh_s, jcp.oh);
1107
1108 while (start < end) {
1109 auto p = jit_conv_call_s();
1110 int work_rem = end - start;
1111 const int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
1112 int ih_s = nstl::max(0, -jcp.t_pad + oh_s * jcp.stride_h);
1113 const int ih_e = nstl::min(
1114 jcp.ih, -jcp.t_pad + (oh_e - 1) * jcp.stride_h + ext_kh);
1115
1116 auto tr_src_off = [&](int g, int ic, int ih_end, int ij) {
1117 const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
1118 int adj = (jcp.global_transpose) ? 1 : jcp.nb_ic_blocking;
1119 // Aligned to buffer end to use guard elements
1120 return tr_src_buf_number(ti, g, ic) * adj * jcp.tr_src_buf_size
1121 + (jcp.ih - ih_end + ij) * tr_row_size;
1122 };
1123
1124 if (jcp.global_transpose) {
1125 using simple_barrier::barrier;
1126 // TODO: try to call local transpositions just before jit kernel
1127 /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */
1128 int j {0};
1129 int work_amount = ti->g_work * ti->ic_b_work * (ih_e - ih_s);
1130 int tr_start {0}, tr_end {0};
1131 balance211(
1132 work_amount, nthr_oc_b_, ti->ithr_oc_b, tr_start, tr_end);
1133
1134 int g {0}, ic_b {0};
1135 nd_iterator_init(tr_start, g, ti->g_work, ic_b, ti->ic_b_work, j,
1136 ih_e - ih_s);
1137
1138 if (nthr_oc_b_ > 1)
1139 barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
1140 while (tr_start < tr_end) {
1141 int g_ = g + ti->g_start;
1142 int ic_b_ = ic_b + ti->ic_b_start;
1143 int j_s = j + ih_s;
1144 int j_e = j_s + nstl::min(tr_end - tr_start, ih_e - j_s);
1145 const int ic_off_idx = g_ * jcp.ic + ic_b_ * jcp.ic_block;
1146 const src_data_t *src
1147 = &ti->src[src_d.blk_off(img, ic_off_idx, j_s)];
1148 src_data_t *tr_src
1149 = &ti->tr_src[tr_src_off(g_, ic_b_, ih_e, j_s)];
1150 trans_src_nxc(tr_src, src, 0, 0, ic_b_, 0, j_e - j_s);
1151 nd_iterator_jump(tr_start, tr_end, g, ti->g_work, ic_b,
1152 ti->ic_b_work, j, ih_e - ih_s);
1153 }
1154 if (nthr_oc_b_ > 1)
1155 barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
1156
1157 j = 0;
1158 work_amount = ti->g_work * ti->oc_b_work * (oh_e - oh_s);
1159 tr_start = 0;
1160 tr_end = 0;
1161 balance211(
1162 work_amount, nthr_ic_b_, ti->ithr_ic_b, tr_start, tr_end);
1163
1164 g = 0;
1165 int oc_b = 0;
1166 nd_iterator_init(tr_start, g, ti->g_work, oc_b, ti->oc_b_work, j,
1167 oh_e - oh_s);
1168
1169 if (nthr_ic_b_ > 1)
1170 barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
1171 while (tr_start < tr_end) {
1172 int g_ = g + ti->g_start;
1173 int oc_b_ = oc_b + ti->oc_b_start;
1174 int j_s = j + oh_s;
1175 int j_e = j_s + nstl::min(tr_end - tr_start, oh_e - j_s);
1176 const int oc_off_idx = g_ * jcp.oc + oc_b_ * jcp.oc_block;
1177 const diff_dst_data_t *diff_dst
1178 = &ti->diff_dst[diff_dst_d.blk_off(
1179 img, oc_off_idx, j_s)];
1180 diff_dst_data_t *tr_diff_dst
1181 = &ti->tr_diff_dst[tr_diff_dst_off(g_, oc_b_, j_s)];
1182
1183 trans_dst_nxc(tr_diff_dst, diff_dst, 0, 0, oc_b_, 0, j_e - j_s);
1184
1185 nd_iterator_jump(tr_start, tr_end, g, ti->g_work, oc_b,
1186 ti->oc_b_work, j, oh_e - oh_s);
1187 }
1188 if (nthr_ic_b_ > 1)
1189 barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
1190 }
1191 int height_block = jcp.global_transpose ? oh_e - oh_s : optimal_hblock;
1192
1193 for_(int ohb_s = oh_s; ohb_s < oh_e; ohb_s += height_block)
1194 for_(int g = ti->g_start; g < ti->g_end; ++g)
1195 for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end;
1196 oc_b += jcp.nb_oc_blocking)
1197 for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end;
1198 ic_b += jcp.nb_ic_blocking) {
1199 const int ohb_e = nstl::min(ohb_s + height_block, oh_e);
1200 const int ihb_s = nstl::max(0, -jcp.t_pad + ohb_s * jcp.stride_h);
1201 const int ihb_e = nstl::min(
1202 jcp.ih, -jcp.t_pad + (ohb_e - 1) * jcp.stride_h + ext_kh);
1203 assert(IMPLICATION(jcp.global_transpose,
1204 oh_s == ohb_s && oh_e == ohb_e && ih_s == ihb_s
1205 && ih_e == ihb_e));
1206 const int nb_ic_blocks = (ic_b + jcp.nb_ic_blocking > ti->ic_b_end)
1207 ? 1
1208 : jcp.nb_ic_blocking;
1209 const int nb_oc_blocks = (oc_b + jcp.nb_oc_blocking > ti->oc_b_end)
1210 ? 1
1211 : jcp.nb_oc_blocking;
1212 const int ic_to_compute
1213 = this_block_size((ic_b + nb_ic_blocks - 1) * jcp.ic_block,
1214 jcp.ic, jcp.ic_block);
1215 const int oc_to_compute
1216 = this_block_size((oc_b + nb_oc_blocks - 1) * jcp.oc_block,
1217 jcp.oc, jcp.oc_block);
1218
1219 if (!jcp.global_transpose) {
1220 src_data_t *tr_src
1221 = &ti->tr_src[tr_src_off(0, 0, ihb_e, ihb_s)];
1222
1223 for (int icb = 0; icb < nb_ic_blocks; icb++) {
1224 const int ic_off_idx
1225 = g * jcp.ic + (ic_b + icb) * jcp.ic_block;
1226 const src_data_t *src
1227 = (src_data_t *)&ti->src[src_d.blk_off(
1228 img, ic_off_idx, ihb_s)];
1229 src_data_t *tr_src_local
1230 = tr_src + icb * jcp.tr_src_buf_size;
1231 trans_src_nxc(tr_src_local, src, 0, 0, (ic_b + icb), 0,
1232 ihb_e - ihb_s);
1233 }
1234 p.src = tr_src;
1235 } else {
1236 p.src = &ti->tr_src[tr_src_off(g, ic_b, ihb_e, ihb_s)];
1237 }
1238
1239 if (!jcp.global_transpose) {
1240 diff_dst_data_t *tr_diff_dst
1241 = &ti->tr_diff_dst[tr_diff_dst_off(0, 0, 0)];
1242 for (int ocb = 0; ocb < nb_oc_blocks; ocb++) {
1243 const int oc_off_idx
1244 = g * jcp.oc + (oc_b + ocb) * jcp.oc_block;
1245 const diff_dst_data_t *diff_dst
1246 = &ti->diff_dst[diff_dst_d.blk_off(
1247 img, oc_off_idx, ohb_s)];
1248 diff_dst_data_t *tr_diff_dst_local
1249 = tr_diff_dst + ocb * jcp.tr_diff_dst_buf_size;
1250 trans_dst_nxc(tr_diff_dst_local, diff_dst, 0, 0,
1251 (oc_b + ocb), 0, ohb_e - ohb_s);
1252 }
1253 p.dst = tr_diff_dst;
1254 } else {
1255 p.dst = &ti->tr_diff_dst[tr_diff_dst_off(g, oc_b, ohb_s)];
1256 }
1257
1258 p.filt = (jcp.transform_to_vnni)
1259 ? diff_wei + wei_offset_int(g, oc_b, ic_b, 0)
1260 : diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b);
1261 p.bias = diff_bias + g * rnd_up(jcp.oc, jcp.oc_block)
1262 + oc_b * jcp.oc_block;
1263 p.channel = (start == ti->img_start) && (ohb_s == oh_s);
1264
1265 p.reduce_work = ic_to_compute;
1266 p.load_work = oc_to_compute; // used only for the mask
1267 p.os_index_begin = ohb_s;
1268 p.os_index_end = ohb_e;
1269 p.flags = 0 | (ic_b == 0 ? FLAG_IC_FIRST : 0);
1270 p.last_ic_block = (nb_ic_blocks == jcp.nb_ic_blocking) ? 0 : 1;
1271 p.last_oc_block = (nb_oc_blocks == jcp.nb_oc_blocking) ? 0 : 1;
1272 assert(ohb_e <= jcp.oh);
1273 (*kernel_)(&p);
1274 }
1275
1276 nd_iterator_jump(start, end, img, jcp.mb, oh_s, jcp.oh);
1277 }
1278}
1279
1280void jit_avx512_core_amx_convolution_bwd_weights_t::compute_diff_weights_3d(
1281 const thread_info_t *ti) const {
1282 const memory_desc_wrapper src_d(pd()->src_md());
1283 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
1284 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
1285 const auto &jcp = kernel_->jcp;
1286 const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
1287 * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
1288 const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
1289 const int optimal_dblock = jcp.spatial_blk_size;
1290
1291 float *diff_wei;
1292 if (diff_weights_d.data_type() == data_type::bf16)
1293 diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size;
1294 else
1295 diff_wei = ti->ithr_mb == 0
1296 ? (float *)ti->diff_weights
1297 : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
1298
1299 float *diff_bias = nullptr;
1300 if (jcp.with_bias) {
1301 if (jcp.bia_dt == data_type::bf16)
1302 diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size;
1303 else
1304 diff_bias = ti->ithr_mb == 0
1305 ? (float *)ti->diff_bias
1306 : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size;
1307 }
1308
1309 auto tr_diff_dst_off_3d = [&](int g, int oc, int od) {
1310 const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
1311 const size_t tr_3d_size = tr_row_size * jcp.oh;
1312 int adj = (jcp.global_transpose) ? 1 : jcp.nb_oc_blocking;
1313 return tr_diff_dst_buf_number(ti, g, oc) * adj
1314 * jcp.tr_diff_dst_buf_size
1315 + od * tr_3d_size;
1316 };
1317 int img {0}, od_s {0};
1318 int start = ti->img_start;
1319 int end = ti->img_end;
1320
1321 int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
1322
1323 nd_iterator_init(start, img, jcp.mb, od_s, jcp.od);
1324 while (start < end) {
1325 auto p = jit_conv_call_s();
1326 int work_rem = end - start;
1327 const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
1328 int id_s = nstl::max(0, -jcp.f_pad + od_s * jcp.stride_d);
1329 const int id_e = nstl::min(
1330 jcp.id, -jcp.f_pad + (od_e - 1) * jcp.stride_d + ext_kd);
1331
1332 auto tr_src_off_3d = [&](int g, int ic, int id_end, int id) {
1333 const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
1334 const size_t tr_3d_size = tr_row_size * jcp.ih;
1335 int adj = (jcp.global_transpose) ? 1 : jcp.nb_ic_blocking;
1336 // Aligned to buffer end to use guard elements
1337 return tr_src_buf_number(ti, g, ic) * adj * jcp.tr_src_buf_size
1338 + (jcp.id - id_end + id) * tr_3d_size;
1339 };
1340
1341 if (jcp.global_transpose) {
1342 using simple_barrier::barrier;
1343 // TODO: try to call local transpositions just before jit kernel
1344 /* tr_src[nb_ic][id][16][~iw~] <- src[nb_ic][id][iw][16] */
1345 int d {0};
1346
1347 int work_amount = ti->g_work * ti->ic_b_work * (id_e - id_s);
1348
1349 int tr_start {0}, tr_end {0};
1350 balance211(
1351 work_amount, nthr_oc_b_, ti->ithr_oc_b, tr_start, tr_end);
1352
1353 int g {0}, ic_b {0};
1354
1355 nd_iterator_init(tr_start, g, ti->g_work, ic_b, ti->ic_b_work, d,
1356 id_e - id_s);
1357
1358 if (nthr_oc_b_ > 1)
1359 barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
1360 while (tr_start < tr_end) {
1361 int g_ = g + ti->g_start;
1362 int ic_b_ = ic_b + ti->ic_b_start;
1363 int d_s = d + id_s;
1364 int d_e = d_s + nstl::min(tr_end - tr_start, id_e - d_s);
1365
1366 const int ic_off_idx = g_ * jcp.ic + ic_b_ * jcp.ic_block;
1367 const src_data_t *src
1368 = &ti->src[src_d.blk_off(img, ic_off_idx, d_s)];
1369 src_data_t *tr_src
1370 = &ti->tr_src[tr_src_off_3d(g_, ic_b_, id_e, d_s)];
1371 trans_src_nxc(
1372 tr_src, src, 0, 0, ic_b_, 0, (d_e - d_s) * jcp.ih);
1373 nd_iterator_jump(tr_start, tr_end, g, ti->g_work, ic_b,
1374 ti->ic_b_work, d, id_e - id_s);
1375 }
1376 if (nthr_oc_b_ > 1)
1377 barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
1378
1379 d = 0;
1380
1381 work_amount = ti->g_work * ti->oc_b_work * (od_e - od_s);
1382
1383 tr_start = 0, tr_end = 0;
1384 balance211(
1385 work_amount, nthr_ic_b_, ti->ithr_ic_b, tr_start, tr_end);
1386
1387 g = 0;
1388 int oc_b = 0;
1389
1390 nd_iterator_init(tr_start, g, ti->g_work, oc_b, ti->oc_b_work, d,
1391 od_e - od_s);
1392
1393 if (nthr_ic_b_ > 1)
1394 barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
1395 while (tr_start < tr_end) {
1396 int g_ = g + ti->g_start;
1397 int oc_b_ = oc_b + ti->oc_b_start;
1398 int d_s = d + od_s;
1399 int d_e = d_s + nstl::min(tr_end - tr_start, od_e - d_s);
1400 const int oc_off_idx = g_ * jcp.oc + oc_b_ * jcp.oc_block;
1401
1402 const diff_dst_data_t *diff_dst
1403 = &ti->diff_dst[diff_dst_d.blk_off(
1404 img, oc_off_idx, d_s)];
1405 diff_dst_data_t *tr_diff_dst
1406 = &ti->tr_diff_dst[tr_diff_dst_off_3d(g_, oc_b_, d_s)];
1407
1408 trans_dst_nxc(tr_diff_dst, diff_dst, 0, 0, oc_b_, 0,
1409 (d_e - d_s) * jcp.oh);
1410
1411 nd_iterator_jump(tr_start, tr_end, g, ti->g_work, oc_b,
1412 ti->oc_b_work, d, od_e - od_s);
1413 }
1414 if (nthr_ic_b_ > 1)
1415 barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
1416 }
1417
1418 int depth_block = jcp.global_transpose ? od_e - od_s : optimal_dblock;
1419
1420 for_(int odb_s = od_s; odb_s < od_e; odb_s += depth_block)
1421 for_(int g = ti->g_start; g < ti->g_end; ++g)
1422 for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end;
1423 oc_b += jcp.nb_oc_blocking)
1424 for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end;
1425 ic_b += jcp.nb_ic_blocking) {
1426 const int odb_e = nstl::min(odb_s + depth_block, od_e);
1427 const int idb_s = nstl::max(0, -jcp.f_pad + odb_s * jcp.stride_d);
1428 const int idb_e = nstl::min(
1429 jcp.id, -jcp.f_pad + (odb_e - 1) * jcp.stride_d + ext_kd);
1430 const int kdb_front_pad
1431 = nstl::max(0, jcp.f_pad - odb_s * jcp.stride_d);
1432 // Assumes kd_back_pad = 0 when the kernel is dilated
1433 const int kdb_back_pad = nstl::max(
1434 0, odb_s * jcp.stride_d + jcp.kd - jcp.f_pad - jcp.id);
1435 const int kdb_pad_off = nstl::min(jcp.kd - 1, kdb_front_pad)
1436 * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block
1437 * jcp.typesize_out;
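// kdb_pad_off is a byte offset (typesize_out is already applied) that skips
// the kd planes falling into the front padding; it is passed to the kernel
// as p.kd_offset.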
1438
1439 assert(IMPLICATION(jcp.global_transpose,
1440 od_s == odb_s && od_e == odb_e && id_s == idb_s
1441 && id_e == idb_e));
1442 const int nb_ic_blocks = (ic_b + jcp.nb_ic_blocking > ti->ic_b_end)
1443 ? 1
1444 : jcp.nb_ic_blocking;
1445 const int nb_oc_blocks = (oc_b + jcp.nb_oc_blocking > ti->oc_b_end)
1446 ? 1
1447 : jcp.nb_oc_blocking;
1448 const int ic_to_compute
1449 = this_block_size((ic_b + nb_ic_blocks - 1) * jcp.ic_block,
1450 jcp.ic, jcp.ic_block);
1451 const int oc_to_compute
1452 = this_block_size((oc_b + nb_oc_blocks - 1) * jcp.oc_block,
1453 jcp.oc, jcp.oc_block);
1454
1455 if (!jcp.global_transpose) {
1456 src_data_t *tr_src
1457 = &ti->tr_src[tr_src_off_3d(0, 0, idb_e, idb_s)];
1458
1459 for (int icb = 0; icb < nb_ic_blocks; icb++) {
1460 const int ic_off_idx
1461 = g * jcp.ic + (ic_b + icb) * jcp.ic_block;
1462 const src_data_t *src
1463 = (src_data_t *)&ti->src[src_d.blk_off(
1464 img, ic_off_idx, idb_s)];
1465 src_data_t *tr_src_local
1466 = tr_src + icb * jcp.tr_src_buf_size;
1467 trans_src_nxc(tr_src_local, src, 0, 0, (ic_b + icb), 0,
1468 (idb_e - idb_s) * jcp.ih);
1469 }
1470 p.src = tr_src;
1471 } else {
1472 p.src = &ti->tr_src[tr_src_off_3d(g, ic_b, idb_e, idb_s)];
1473 }
1474
1475 if (!jcp.global_transpose) {
1476 diff_dst_data_t *tr_diff_dst
1477 = &ti->tr_diff_dst[tr_diff_dst_off_3d(0, 0, 0)];
1478 for (int ocb = 0; ocb < nb_oc_blocks; ocb++) {
1479 const int oc_off_idx
1480 = g * jcp.oc + (oc_b + ocb) * jcp.oc_block;
1481 const diff_dst_data_t *diff_dst
1482 = &ti->diff_dst[diff_dst_d.blk_off(
1483 img, oc_off_idx, odb_s)];
1484 diff_dst_data_t *tr_diff_dst_local
1485 = tr_diff_dst + ocb * jcp.tr_diff_dst_buf_size;
1486 trans_dst_nxc(tr_diff_dst_local, diff_dst, 0, 0,
1487 (oc_b + ocb), 0, (odb_e - odb_s) * jcp.oh);
1488 }
1489 p.dst = tr_diff_dst;
1490 } else {
1491 p.dst = &ti->tr_diff_dst[tr_diff_dst_off_3d(g, oc_b, odb_s)];
1492 }
1493
1494 p.filt = (jcp.transform_to_vnni)
1495 ? diff_wei + wei_offset_int(g, oc_b, ic_b, 0)
1496 : diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b);
1497 p.bias = diff_bias + g * rnd_up(jcp.oc, jcp.oc_block)
1498 + oc_b * jcp.oc_block;
1499 p.channel = (start == ti->img_start) && (odb_s == od_s);
1500
1501 p.reduce_work = ic_to_compute;
1502 p.load_work = oc_to_compute;
1503 p.os_index_begin = odb_s;
1504 p.os_index_end = odb_e;
1505 p.kd_padding = jcp.kd - kdb_front_pad - kdb_back_pad;
1506 p.kd_offset = kdb_pad_off;
1507 p.flags = (ic_b == 0 ? FLAG_IC_FIRST : 0);
1508 p.last_ic_block = (nb_ic_blocks == jcp.nb_ic_blocking) ? 0 : 1;
1509 p.last_oc_block = (nb_oc_blocks == jcp.nb_oc_blocking) ? 0 : 1;
1510 assert(odb_e <= jcp.od);
1511 (*kernel_)(&p);
1512 }
1513
1514 nd_iterator_jump(start, end, img, jcp.mb, od_s, jcp.od);
1515 }
1516}
1517
1518void jit_avx512_core_amx_convolution_bwd_weights_t::compute_diff_weights(
1519 const thread_info_t *ti) const {
1520 const memory_desc_wrapper src_d(pd()->src_md());
1521 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
1522 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
1523 const auto &jcp = kernel_->jcp;
1524 const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
1525 * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
1526 const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
1527
1528 float *diff_wei;
1529 if (diff_weights_d.data_type() == data_type::bf16)
1530 diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size;
1531 else
1532 diff_wei = ti->ithr_mb == 0
1533 ? (float *)ti->diff_weights
1534 : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
1535
1536 float *diff_bias = nullptr;
1537 if (jcp.with_bias) {
1538 if (jcp.bia_dt == data_type::bf16)
1539 diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size;
1540 else
1541 diff_bias = ti->ithr_mb == 0
1542 ? (float *)ti->diff_bias
1543 : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size;
1544 }
1545
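    // Offset helpers into the transposed scratch buffers. Without global
    // transpose each thread owns nb_ic_blocking (resp. nb_oc_blocking)
    // consecutive buffers, which the 'adj' factor accounts for.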
1546 auto tr_src_off = [&](int g, int ic, int nb_ic_block, int ij) {
1547 const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
1548 int adj = (jcp.global_transpose) ? 1 : jcp.nb_ic_blocking;
1549 return tr_src_buf_number(ti, g, ic) * jcp.tr_src_buf_size * adj
1550 + ij * tr_row_size + nb_ic_block * jcp.tr_src_buf_size;
1551 };
1552
1553 auto tr_src_off_3d = [&](int g, int ic, int nb_ic_block, int id, int ij) {
1554 const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
1555 const size_t tr_3d_size = tr_row_size * jcp.ih;
1556 int adj = (jcp.global_transpose) ? 1 : jcp.nb_ic_blocking;
1557 return tr_src_buf_number(ti, g, ic) * jcp.tr_src_buf_size * adj
1558 + id * tr_3d_size + ij * tr_row_size
1559 + nb_ic_block * jcp.tr_src_buf_size;
1560 };
1561
1562 auto tr_diff_dst_off = [&](int g, int oc, int nb_oc_block, int oj) {
1563 const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
1564 int adj = (jcp.global_transpose) ? 1 : jcp.nb_oc_blocking;
1565 return tr_diff_dst_buf_number(ti, g, oc) * jcp.tr_diff_dst_buf_size
1566 * adj
1567 + oj * tr_row_size + nb_oc_block * jcp.tr_diff_dst_buf_size;
1568 };
1569
1570 auto tr_diff_dst_off_3d
1571 = [&](int g, int oc, int nb_oc_block, int od, int oj) {
1572 const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
1573 const size_t tr_3d_size = tr_row_size * jcp.oh;
1574 int adj = (jcp.global_transpose) ? 1 : jcp.nb_oc_blocking;
1575 return tr_diff_dst_buf_number(ti, g, oc)
1576 * jcp.tr_diff_dst_buf_size * adj
1577 + od * tr_3d_size + oj * tr_row_size
1578                       + nb_oc_block * jcp.tr_diff_dst_buf_size;
1579 };
1580
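    // Transpose the source of one image into tr_src. With global transpose the
    // ih * id rows of all assigned ic blocks are distributed over the nthr_oc_b_
    // threads sharing the same ic slice via balance211; otherwise only the block
    // given by (g, ic_b, nb_ic_block) is transposed locally.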
1581 auto uker_trans = [&](int img, int g = 0, int ic_b = 0,
1582 int nb_ic_block = 0) {
1583 int j {0}, d {0};
1584 int my_work = jcp.ih * jcp.id;
1585 int ic;
1586 int icb_start = ic_b;
1587 if (jcp.global_transpose) {
1588 const int work_amount = ti->ic_b_work * jcp.ih * jcp.id;
1589
1590 int start {0}, end {0};
1591 balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, start, end);
1592 my_work = end - start;
1593
1594 if (jcp.ndims == 5)
1595 nd_iterator_init(
1596 start, ic_b, ti->ic_b_work, d, jcp.id, j, jcp.ih);
1597 else
1598 nd_iterator_init(start, ic_b, ti->ic_b_work, j, jcp.ih);
1599 g += ti->g_start;
1600 ic_b += ti->ic_b_start;
1601 icb_start = ic_b;
1602 ic = g * jcp.ic + ic_b * jcp.ic_block;
1603
1604 } else {
1605 ic = g * jcp.ic + ic_b * jcp.ic_block;
1606 g = 0;
1607 ic_b = 0;
1608 }
1609 const bool need_local_gwork = jcp.global_transpose;
1610 const auto local_gwork = need_local_gwork ? ti->g_work : 1;
1611
1612 for (int gg = g; gg < g + local_gwork; ++gg) {
1613 if (need_local_gwork) ic = gg * jcp.ic + ic_b * jcp.ic_block;
1614 src_data_t *tr_src = (jcp.ndims == 5)
1615 ? &ti->tr_src[tr_src_off_3d(gg, ic_b, nb_ic_block, d, j)]
1616 : &ti->tr_src[tr_src_off(gg, ic_b, nb_ic_block, j)];
1617 auto src_offset = src_d.blk_off(img, ic);
1618 src_data_t *src = (src_data_t *)&ti->src[src_offset];
1619
1620 dim_t sp_start_offset = (jcp.ndims == 5) ? src_d.blk_off(0, 0, d, j)
1621 : src_d.blk_off(0, 0, j);
1622 dim_t ch_shift = src_d.blk_off(0, jcp.ic_block);
1623 int sp_start_idx = d * jcp.ih + j;
1624 trans_src_nxc(tr_src, src, sp_start_idx, sp_start_offset, icb_start,
1625 ch_shift, my_work);
1626 }
1627 };
1628
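    // Analogous transpose of diff_dst into tr_diff_dst; with global transpose
    // the rows are balanced over the nthr_ic_b_ threads sharing the same oc slice.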
1629 auto diff_dst_trans = [&](int img, int g = 0, int oc_b = 0,
1630 int nb_oc_block = 0) {
1631 int j {0}, d {0};
1632 int my_work = jcp.oh * jcp.od;
1633 int oc;
1634 int ocb_start = oc_b;
1635
1636 if (jcp.global_transpose) {
1637 const size_t work_amount = ti->oc_b_work * jcp.oh * jcp.od;
1638
1639 size_t start {0}, end {0};
1640 balance211(work_amount, nthr_ic_b_, ti->ithr_ic_b, start, end);
1641 my_work = end - start;
1642
1643 if (jcp.ndims == 5)
1644 nd_iterator_init(
1645 start, oc_b, ti->oc_b_work, d, jcp.od, j, jcp.oh);
1646 else
1647 nd_iterator_init(start, oc_b, ti->oc_b_work, j, jcp.oh);
1648
1649 g += ti->g_start;
1650 oc_b += ti->oc_b_start;
1651 ocb_start = oc_b;
1652 oc = g * jcp.oc + oc_b * jcp.oc_block;
1653 } else {
1654 oc = g * jcp.oc + oc_b * jcp.oc_block;
1655 g = 0;
1656 oc_b = 0;
1657 }
1658 const bool need_local_gwork = jcp.global_transpose;
1659 const auto local_gwork = need_local_gwork ? ti->g_work : 1;
1660
1661 for (int gg = g; gg < g + local_gwork; ++gg) {
1662 if (need_local_gwork) oc = gg * jcp.oc + oc_b * jcp.oc_block;
1663 diff_dst_data_t *tr_diff_dst = (jcp.ndims == 5)
1664 ? &ti->tr_diff_dst[tr_diff_dst_off_3d(
1665 gg, oc_b, nb_oc_block, d, j)]
1666 : &ti->tr_diff_dst[tr_diff_dst_off(
1667 gg, oc_b, nb_oc_block, j)];
1668 auto ddst_offset = diff_dst_d.blk_off(img, oc);
1669 const diff_dst_data_t *diff_dst = &ti->diff_dst[ddst_offset];
1670
1671 dim_t sp_start_offset = (jcp.ndims == 5)
1672 ? diff_dst_d.blk_off(0, 0, d, j)
1673 : diff_dst_d.blk_off(0, 0, j);
1674 dim_t ch_shift = diff_dst_d.blk_off(0, jcp.oc_block);
1675 int sp_start_idx = d * jcp.oh + j;
1676 trans_dst_nxc(tr_diff_dst, diff_dst, sp_start_idx, sp_start_offset,
1677 ocb_start, ch_shift, my_work);
1678 }
1679 };
1680
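    // Main loop: for each image assigned to this thread, transpose (globally or
    // locally) and accumulate the weight gradient one (g, oc_b, ic_b) tile at a
    // time.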
1681 for (int img = ti->img_start; img < ti->img_end; ++img) {
1682 auto p = jit_conv_call_s();
1683 if (jcp.global_transpose) {
1684 using simple_barrier::barrier;
1685 // TODO: try to call local transpositions just before jit kernel
1686 /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */
1687 if (nthr_oc_b_ > 1)
1688 barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
1689 uker_trans(img);
1690 if (nthr_oc_b_ > 1)
1691 barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
1692 if (nthr_ic_b_ > 1)
1693 barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
1694 diff_dst_trans(img);
1695 if (nthr_ic_b_ > 1)
1696 barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
1697 }
1698
1699 for_(int g = ti->g_start; g < ti->g_end; ++g)
1700 for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end;
1701 oc_b += jcp.nb_oc_blocking)
1702 for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end;
1703 ic_b += jcp.nb_ic_blocking) {
1704 const int nb_ic_blocks = (ic_b + jcp.nb_ic_blocking > ti->ic_b_end)
1705 ? 1
1706 : jcp.nb_ic_blocking;
1707 const int nb_oc_blocks = (oc_b + jcp.nb_oc_blocking > ti->oc_b_end)
1708 ? 1
1709 : jcp.nb_oc_blocking;
1710
1711 const int ic_to_compute
1712 = this_block_size((ic_b + nb_ic_blocks - 1) * jcp.ic_block,
1713 jcp.ic, jcp.ic_block);
1714 const int oc_to_compute
1715 = this_block_size((oc_b + nb_oc_blocks - 1) * jcp.oc_block,
1716 jcp.oc, jcp.oc_block);
1717
1718 if (!jcp.global_transpose) {
1719 for (int nib = 0; nib < nb_ic_blocks; nib++)
1720 uker_trans(img, g, ic_b + nib, nib);
1721 }
1722 if (jcp.ndims == 5) {
1723 p.src = &ti->tr_src[tr_src_off_3d(g, ic_b, 0, 0, 0)];
1724 } else {
1725 p.src = &ti->tr_src[tr_src_off(g, ic_b, 0, 0)];
1726 }
1727
1728 if (!jcp.global_transpose) {
1729 for (int nob = 0; nob < nb_oc_blocks; nob++)
1730 diff_dst_trans(img, g, oc_b + nob, nob);
1731 if (jcp.ndims == 5) {
1732 p.dst = &ti->tr_diff_dst[tr_diff_dst_off_3d(0, 0, 0, 0, 0)];
1733 } else {
1734 p.dst = &ti->tr_diff_dst[tr_diff_dst_off(0, 0, 0, 0)];
1735 }
1736 } else {
1737 if (jcp.ndims == 5) {
1738 p.dst = &ti->tr_diff_dst[tr_diff_dst_off_3d(
1739 g, oc_b, 0, 0, 0)];
1740 } else {
1741 p.dst = &ti->tr_diff_dst[tr_diff_dst_off(g, oc_b, 0, 0)];
1742 }
1743 }
1744
1745 p.filt = (jcp.transform_to_vnni)
1746 ? diff_wei + wei_offset_int(g, oc_b, ic_b, 0)
1747 : diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b);
1748 p.bias = diff_bias + g * rnd_up(jcp.oc, jcp.oc_block)
1749 + oc_b * jcp.oc_block;
1750 p.channel = (img == ti->img_start);
1751         p.flags = (ic_b == 0 ? FLAG_IC_FIRST : 0);
1752
1753 p.reduce_work = ic_to_compute;
1754 p.load_work = oc_to_compute;
1755
1756 p.last_ic_block = (nb_ic_blocks == jcp.nb_ic_blocking) ? 0 : 1;
1757 p.last_oc_block = (nb_oc_blocks == jcp.nb_oc_blocking) ? 0 : 1;
1758
1759 (*kernel_)(&p);
1760 }
1761 }
1762}
1763
1764void jit_avx512_core_amx_convolution_bwd_weights_t::store_in_vnni_format(
1765 const thread_info_t *ti) const {
1766 const auto &jcp = kernel_->jcp;
1767
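    // Repack the f32 reduction buffer into the bf16 VNNI weights layout; input
    // channel blocks are consumed in pairs, as the VNNI format requires.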
1768 for_(int g = ti->g_start; g < ti->g_end; g++)
1769 for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; oc_b++)
1770 for_(int ic_b = ti->ic_b_start; ic_b < ti->ic_b_start + ti->ic_b_work;
1771 ic_b += 2)
1772 {
1773 jit_conv_call_s p = jit_conv_call_s();
1774
1775 bfloat16_t *output = (bfloat16_t *)ti->diff_weights
1776 + wei_offset_ext(g, oc_b, (ic_b / 2), 0);
1777 float *input = ti->wei_bia_reduction + wei_offset_int(g, oc_b, ic_b, 0);
1778
1779 p.src = (void *)input;
1780 p.dst = (void *)output;
1781 p.last_ic_block = ((ic_b + 1) >= jcp.nb_ic) ? 1 : 0;
1782 (*diff_wei_trans_kernel_)(&p);
1783 }
1784}
1785
1786void jit_avx512_core_amx_convolution_bwd_weights_t::
1787 reduce_and_convert_diff_weights_and_bias(
1788 const thread_info_t *ti) const {
1789 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
1790
1791 const auto &jcp = kernel_->jcp;
1792 const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
1793 * jcp.ic_block * jcp.kh * jcp.kw * ((jcp.ndims == 5) ? jcp.kd : 1);
1794
1795 const bool is_bf16_out = diff_weights_d.data_type() == data_type::bf16;
1796 const bool is_bf16_bias = jcp.with_bias && jcp.bia_dt == data_type::bf16;
1797
1798 if (nthr_mb_ == 1) {
1799 if (is_bf16_out) {
1800 // reduction is not required, only conversion
1801 if (jcp.transform_to_vnni) {
1802 store_in_vnni_format(ti);
1803 } else {
1804 for_(int g = ti->g_start; g < ti->g_end; g++)
1805 for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; oc_b++) {
1806 const size_t acc_size = (size_t)ti->ic_b_work * jcp.kh
1807 * jcp.kw * ((jcp.ndims == 5) ? jcp.kd : 1)
1808 * jcp.ic_block * jcp.oc_block;
1809 const size_t off = wht_blk_off(
1810 diff_weights_d, g, oc_b, ti->ic_b_start);
1811 cvt_float_to_bfloat16(
1812 (bfloat16_t *)(ti->diff_weights) + off,
1813 (ti->wei_bia_reduction + off), acc_size);
1814 }
1815 }
1816 }
1817 if (is_bf16_bias && ti->ithr_ic_b == 0 && ti->ic_b_work > 0) {
1818 for (int g = ti->g_start; g < ti->g_end; g++) {
1819 int result_start_idx = g * jcp.oc_without_padding
1820 + ti->oc_b_start * jcp.oc_block;
1821 int buffer_start_idx = g * rnd_up(jcp.oc, jcp.oc_block)
1822 + ti->oc_b_start * jcp.oc_block;
1823 const size_t acc_size = nstl::min(jcp.oc_without_padding,
1824 ti->oc_b_end * jcp.oc_block)
1825 - ti->oc_b_start * jcp.oc_block;
1826 bfloat16_t *diff_bias
1827 = (bfloat16_t *)ti->diff_bias + result_start_idx;
1828 float *buffer = ti->bia_reduction + buffer_start_idx;
1829 cvt_float_to_bfloat16(diff_bias, buffer, acc_size);
1830 }
1831 }
1832 return;
1833 }
1834
1835 /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
1836 if (jcp.global_transpose)
1837 simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);
1838
1839 const int ic_b_kh_work
1840 = ti->ic_b_work * ((jcp.ndims == 5) ? jcp.kd : jcp.kh);
1841 const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;
1842
1843 int start {0}, end {0};
1844 balance211(work, nthr_mb_, ti->ithr_mb, start, end);
1845 if (!jcp.transform_to_vnni && start == end) return;
1846
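    // Accumulate the other threads' partial results into the destination:
    // diff_weights itself for f32 output, or reduction buffer 0 for bf16 output
    // (in which case the conversion happens on the last iteration).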
1847 const int _start_nthr_mb = 1;
1848 for (int thr_mb = _start_nthr_mb; thr_mb < nthr_mb_; ++thr_mb) {
1849 int w = start;
1850 int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_kh_start {0};
1851 nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
1852 ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
1853 while (w < end) {
1854 const int g = ti->g_start + sub_g_start;
1855 const int oc_b = ti->oc_b_start + sub_oc_b_start;
1856 const int ic_b = ti->ic_b_start
1857 + sub_ic_b_kh_start / ((jcp.ndims == 5) ? jcp.kd : jcp.kh);
1858 const int kX
1859 = sub_ic_b_kh_start % ((jcp.ndims == 5) ? jcp.kd : jcp.kh);
1860
1861 const size_t acc_size = (size_t)jcp.kw * jcp.ic_block * jcp.oc_block
1862 * ((jcp.ndims == 5) ? jcp.kh : 1)
1863 * nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start);
1864
1865 const size_t off_ext
1866 = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kX);
1867 const size_t off_int = (jcp.transform_to_vnni)
1868 ? wei_offset_int(g, oc_b, ic_b, kX)
1869 : off_ext;
1870
1871 float *wei_reduced = is_bf16_out
1872 ? ti->wei_bia_reduction + off_int
1873 : (float *)(ti->diff_weights) + off_ext;
1874
1875 int thr_mb_buffer_idx = is_bf16_out ? thr_mb : thr_mb - 1;
1876 float *wei_to_reduce = ti->wei_bia_reduction
1877 + thr_mb_buffer_idx * wei_size + off_int;
1878
1879 if (!jcp.transform_to_vnni && is_bf16_out && thr_mb == nthr_mb_ - 1)
1880 // the last iteration for bfloat16 requires conversion and
1881                 // store to the diff_weights array
1882 add_floats_and_cvt_to_bfloat16(
1883 (bfloat16_t *)(ti->diff_weights) + off_ext, wei_reduced,
1884 wei_to_reduce, acc_size);
1885 else
1886 acc_ker_->accumulate(wei_reduced, wei_to_reduce, acc_size);
1887
1888 nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
1889 ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
1890 }
1891 if (jcp.with_bias && ti->ithr_ic_b == 0 && ti->ic_b_work > 0
1892 && ti->ithr_mb == 0 && ti->img_work > 0) {
1893 for (int g = ti->g_start; g < ti->g_end; g++) {
1894 float *bias_reduced = is_bf16_bias ? ti->bia_reduction
1895 : (float *)(ti->diff_bias);
1896 int thr_mb_buffer_idx = is_bf16_bias ? thr_mb : thr_mb - 1;
1897 int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
1898 float *bias_to_reduce
1899 = ti->bia_reduction + thr_mb_buffer_idx * bias_buf_size;
1900 const size_t acc_size = nstl::min(jcp.oc_without_padding,
1901 ti->oc_b_end * jcp.oc_block)
1902 - ti->oc_b_start * jcp.oc_block;
1903 int idx = g * rnd_up(jcp.oc, jcp.oc_block)
1904 + ti->oc_b_start * jcp.oc_block;
1905 if (is_bf16_bias && thr_mb == nthr_mb_ - 1) {
1906 // the last iteration for bfloat16 requires conversion and
1907                     // store to the diff_bias array
1908 int diff_bias_idx = g * jcp.oc_without_padding
1909 + ti->oc_b_start * jcp.oc_block;
1910 add_floats_and_cvt_to_bfloat16(
1911 (bfloat16_t *)(ti->diff_bias) + diff_bias_idx,
1912 &bias_reduced[idx], &bias_to_reduce[idx], acc_size);
1913 } else {
1914 acc_ker_->accumulate(
1915 &bias_reduced[idx], &bias_to_reduce[idx], acc_size);
1916 }
1917 }
1918 }
1919 }
1920
1921 if (jcp.transform_to_vnni && jcp.global_transpose) {
1922 simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);
1923 store_in_vnni_format(ti);
1924 }
1925}
1926
1927void jit_avx512_core_amx_convolution_bwd_weights_t::prepare_scratchpad_data(
1928 const exec_ctx_t &ctx) const {
1929 auto scratchpad = ctx.get_scratchpad_grantor();
1930
1931 const auto &jcp = pd()->jcp_;
1932
1933 // XXX: See the comment about tr_iw and guarding elements in
1934 // jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf()
1935 auto tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
1936 // Zero out guard elements that cross a buffer boundary to prevent a
1937     // race condition caused by buffer overflows from the memory optimization
1938     // that lets adjacent buffers share their padding
1939
1940 for (size_t ithr = 1; ithr <= jcp.tr_src_buf_count; ++ithr) {
1941 src_data_t *ts
1942 = &tr_src[ithr * jcp.tr_src_buf_size * jcp.nb_ic_blocking];
1943 for (int i = 0; i < jcp.tr_src_num_guard_elems; ++i)
1944 ts[i] = 0;
1945 }
1946
1947 if (jcp.global_transpose && jcp.nthr_oc_b > 1) {
1948 const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
1949 auto tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
1950 key_conv_tr_src_bctx);
1951 for (int i = 0; i < tr_src_bctx_size; ++i)
1952 simple_barrier::ctx_init(&tr_src_bctx[i]);
1953 }
1954 if (jcp.global_transpose) {
1955 if (jcp.nthr_ic_b > 1) {
1956 const int tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
1957 auto tr_diff_dst_bctx
1958 = scratchpad.template get<simple_barrier::ctx_t>(
1959 key_conv_tr_diff_dst_bctx);
1960 for (int i = 0; i < tr_diff_dst_bctx_size; ++i)
1961 simple_barrier::ctx_init(&tr_diff_dst_bctx[i]);
1962 }
1963 }
1964
1965 if (nthr_mb_ > 1
1966 || pd()->diff_weights_md(0)->data_type == data_type::bf16) {
1967         // TODO: don't use a barrier for the case
1968 // diff_weights_type == data_type::bf16 && nthr_mb_ == 1
1969 simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>(
1970 key_conv_wei_bia_reduction_bctx));
1971 }
1972}
1973
1974void jit_avx512_core_amx_convolution_bwd_weights_t::execute_backward_weights(
1975 const exec_ctx_t &ctx) const {
1976 prepare_scratchpad_data(ctx);
1977
1978 auto tcfg = ctx.get_scratchpad_grantor().template get<char>(
1979 key_conv_amx_tilecfg);
1980 kernel_->tile_configure(tcfg);
1981
1982 const auto &jcp = pd()->jcp_;
1983 parallel(nthr_, [&](const int ithr, const int nthr) {
1984 assert(nthr_ == nthr);
1985 assert(utils::one_of(pd()->ndims(), 3, 4, 5));
1986
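        // Load the AMX tile palette prepared by tile_configure() above.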
1987 amx_tile_configure(tcfg);
1988
1989 thread_info_t thread_info(this, ctx, ithr);
1990 switch (jcp.harness) {
1991 case harness_2d_reduction:
1992 compute_diff_weights_2d(&thread_info);
1993 if (jcp.global_transpose)
1994 reduce_and_convert_diff_weights_and_bias(&thread_info);
1995 break;
1996 case harness_3d_reduction:
1997 compute_diff_weights_3d(&thread_info);
1998 if (jcp.global_transpose)
1999 reduce_and_convert_diff_weights_and_bias(&thread_info);
2000 break;
2001 case harness_compute_full_spatial:
2002 case harness_mb_reduction:
2003 compute_diff_weights(&thread_info);
2004 if (jcp.global_transpose)
2005 reduce_and_convert_diff_weights_and_bias(&thread_info);
2006 break;
2007 default: assert(!"Invalid harness type");
2008 }
2009
2010 amx_tile_release();
2011 });
2012
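    // Without global transpose the cross-thread reduction (and the VNNI repack
    // below, if requested) runs in separate parallel regions once every thread
    // has finished its partial accumulation.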
2013 if (!jcp.global_transpose) {
2014 parallel(nthr_, [&](const int ithr, const int nthr) {
2015 assert(nthr_ == nthr);
2016 thread_info_t thread_info(this, ctx, ithr);
2017 reduce_and_convert_diff_weights_and_bias(&thread_info);
2018 });
2019 }
2020
2021 if (jcp.transform_to_vnni && !jcp.global_transpose) {
2022 parallel(nthr_, [&](const int ithr, const int nthr) {
2023 assert(nthr_ == nthr);
2024 thread_info_t thread_info(this, ctx, ithr);
2025 store_in_vnni_format(&thread_info);
2026 });
2027 }
2028
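    // diff_bias was accumulated with oc rounded up to a multiple of oc_block;
    // copy the unpadded part of each group into the user-provided buffer.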
2029 if (pd()->with_bias() && (jcp.oc_without_padding % jcp.oc_block != 0)
2030 && jcp.bia_dt != data_type::bf16) {
2031 auto diff_bias = ctx.get_scratchpad_grantor().template get<const float>(
2032 key_conv_padded_bias);
2033 auto diff_bias_in = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS);
2034 const int padded_stride = rnd_up(jcp.oc, jcp.oc_block);
2035 const int stride = jcp.oc_without_padding;
2036 for (int g = 0; g < jcp.ngroups; ++g) {
2037 utils::array_copy(diff_bias_in + g * stride,
2038 diff_bias + g * padded_stride, stride);
2039 }
2040 }
2041}
2042
2043} // namespace x64
2044} // namespace cpu
2045} // namespace impl
2046} // namespace dnnl
2047
2048// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
2049