/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/cpu_primitive.hpp"

#include "cpu/x64/jit_brgemm_conv_bwd_strided.hpp"
#include "cpu/x64/jit_brgemm_conv_bwd_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

using namespace nstl;
using namespace data_type;

using namespace jit_avx512_core_brgemm_conv_bwd_trans_kernel;

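// Pick a value based on the number of spatial dims: v5 for 3D convolutions
// (5D tensors), v4 for 2D, v3 for 1D. E.g. ndims_pick(jcp.kd, 1, 1) yields
// the real kd only in the 3D case and collapses it to 1 otherwise.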
#define ndims_pick(v5, v4, v3) \
    ((ndims == 5) ? (v5) : (ndims == 4) ? (v4) : (ndims == 3) ? (v3) : 0)

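// Data type support is gated on ISA: bf16 requires at least avx512_core and
// f16 requires avx512_core_fp16, while f32 and the integer types are
// supported unconditionally by this implementation.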
static bool impl_supports_datatype(data_type_t data_type) {
    switch (data_type) {
        case data_type::bf16: return x64::mayiuse(x64::avx512_core);
        case data_type::f16: return x64::mayiuse(x64::avx512_core_fp16);
        case data_type::f32:
        case data_type::s32:
        case data_type::s8:
        case data_type::u8: return true;
        default: return false;
    }
}

template <cpu_isa_t isa, bool enable_postops>
status_t brgemm_convolution_bwd_strided_t<isa, enable_postops>::pd_t::init(
        engine_t *engine) {
    using namespace data_type;

    const auto diff_src_type = diff_src_md(0)->data_type;
    const auto wei_type = weights_md(0)->data_type;
    const auto diff_dst_type = diff_dst_md(0)->data_type;

    using skip_mask_t = primitive_attr_t::skip_mask_t;
    auto skip_mask = enable_postops
            ? (skip_mask_t::post_ops | skip_mask_t::sum_dt
                    | skip_mask_t::zero_points_runtime)
            : skip_mask_t::none;

    const bool ok = is_bwd_d()
            && set_default_alg_kind(alg_kind::convolution_direct)
            && impl_supports_datatype(diff_src_type)
            && impl_supports_datatype(wei_type)
            && impl_supports_datatype(diff_dst_type)
            && one_of(wei_type, f32, bf16, f16) && wei_type == diff_dst_type
            && one_of(diff_src_type, wei_type, f32)
            && IMPLICATION(with_bias(),
                    one_of(bias_md_.data_type, f32, wei_type))
            && attr()->has_default_values(skip_mask, diff_src_type)
            && IMPLICATION(enable_postops,
                    attr()->post_ops_.check_sum_consistent_dt(diff_src_type))
            && !has_zero_dim_memory();

    if (!ok) return status::unimplemented;

    const auto is_amx = brgemm_convolution_bwd_utils::is_amx(isa);

    CHECK(brgemm_convolution_bwd_utils::init_conf(jcp_, isa, desc_,
            diff_dst_md_, weights_md_, diff_src_md_, bias_md_, attr_,
            dnnl_get_max_threads(), enable_postops));

    const auto adj_M = nstl::max(jcp_.M, jcp_.M_tail);

    batchsizes.resize(jcp_.max_batch + 1);
    for (int i = 0; i <= jcp_.max_batch; i++)
        batchsizes[i] = -1;

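    // Only the full batch size (jcp_.max_batch) gets a valid index; all
    // other entries stay -1 so the descriptor/kernel creation loops below
    // skip them.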
    first_bs = 0;
    bs_c = 0;

    batchsizes[jcp_.max_batch] = bs_c;
    first_bs = jcp_.max_batch;
    bs_c++;

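    // One descriptor slot per (batch size, M) pair times the three binary
    // variants: init vs. accumulate (beta), N tail and K tail -- hence the
    // 2 * 2 * 2 factor.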
    brgs_sz_ = bs_c * adj_M * 2 * 2 * 2;
    brgs_.resize(brgs_sz_);
    bd_masks.resize(brgs_sz_);

    const float alpha = 1.0f;
    const float beta = 1.0f;

    const auto &p = attr()->post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    const bool with_sum = (sum_idx != -1);

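    // NB: brgemm_attr_t only keeps a raw bd_mask pointer, so the mask
    // storage itself is owned by bd_masks to keep it alive. All rows are
    // enabled here (mask of all ones).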
    auto maybe_M_mask
            = [&](int brg_idx, brgemm_attr_t &brgattr, int vM, int vbrgM) {
                  if (!jcp_.use_M_mask) return;
                  auto sm_size = vbrgM;
                  bd_masks[brg_idx] = std::make_shared<std::vector<char>>();
                  bd_masks[brg_idx]->resize(sm_size);
                  char *bd_mask = bd_masks[brg_idx]->data();
                  for (int ibrgM = 0; ibrgM < sm_size; ibrgM++) {
                      bd_mask[ibrgM] = 1;
                  }
                  brgattr.bd_mask = bd_mask;
              };

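    // Create brgemm descriptors for all M values from 1 to M_end; for
    // exec_trans/exec_vpad only the two values actually used at runtime
    // (jcp_.M and jcp_.M_tail) are initialized.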
    const auto M_end = nstl::max(jcp_.M, jcp_.M_tail);
    for (int i = 0; i < M_end; i++) {
        auto vM = i + 1;
        // init only needed brgemm descriptors
        if (one_of(jcp_.exec_type, exec_trans, exec_vpad) && vM != jcp_.M
                && vM != jcp_.M_tail)
            continue;
        for (int bs = 0; bs <= jcp_.max_batch; bs++) {
            if (batchsizes[bs] == -1) continue;
            for_(int i_init = 0; i_init < 2; i_init++)
            for_(int i_N = 0; i_N < 2; i_N++)
            for (int i_K = 0; i_K < 2; i_K++) {
                auto vbeta = (i_init) ? 0.f : beta;
                auto vN = (i_N) ? jcp_.N_tail : jcp_.N;
                auto vK = (i_K) ? jcp_.K_tail : jcp_.K;
                auto vbrgM = jcp_.use_M_mask
                        ? (vM == jcp_.M ? jcp_.brgM : jcp_.brgM_tail)
                        : vM;
                auto brg_idx = get_brg_idx(bs, i, i_init, i_N, i_K);
                // if the brgemm descriptor was already created, skip it
                if (brgs_[brg_idx] != nullptr) continue;
                brgs_[brg_idx] = std::make_shared<brgemm_t>();
                brgemm_t *brg = brgs_[brg_idx].get();
                if (vN == 0 || vK == 0) continue;
                brgemm_strides_t brg_strides;
                brg_strides.stride_a = jcp_.brg_stride_a;
                brg_strides.stride_b = jcp_.brg_stride_b;
                brg->req_cal_comp_pads = jcp_.req_brg_comp_pad
                        && nstl::max(jcp_.l_pad, jcp_.r_pad) > 0;
                const auto strides_ptr = (jcp_.brg_type == brgemm_strd)
                        ? &brg_strides
                        : nullptr;
                CHECK(brgemm_desc_init(brg, isa, jcp_.brg_type, diff_dst_type,
                        wei_type, false, false, brgemm_row_major, alpha, vbeta,
                        jcp_.LDA, jcp_.LDB, jcp_.LDC, vbrgM, vN, vK,
                        strides_ptr));

                brgemm_attr_t brgattr;
                brgattr.use_uker = jcp_.use_uker;
                brgattr.use_interleave_stores = jcp_.use_interleave_stores;
                brgattr.hint_prefetching = jcp_.hint_prefetching;
                brgattr.max_bs = bs;
                brgattr.hint_innermost_loop = jcp_.brgemm_bd_loop_innermost
                        ? brgemm_bd_loop_innermost
                        : brgemm_ld_loop_innermost;
                if (jcp_.amx_tile_load_xx) {
                    // assuming 2x2 decomposition in amx brgemm kernel
                    // and overlap of input by kw
                    const auto bd_blocking = 2 * jcp_.amx_h;
                    const auto ld_blocking = 2 * 16;
                    brgattr.hint_expected_A_size = bd_blocking * jcp_.K
                            * jcp_.kd_block * jcp_.kh_block;
                    brgattr.hint_expected_B_size = ld_blocking * jcp_.K
                            * jcp_.kd_block * jcp_.kh_block * jcp_.kw_block;
                    brgattr.hint_expected_C_size = bd_blocking * ld_blocking;
                } else {
                    brgattr.hint_expected_A_size = 0;
                    brgattr.hint_expected_B_size = 0;
                    brgattr.hint_expected_C_size = 0;
                }

                brgattr.wary_tail_read = false;
                maybe_M_mask(brg_idx, brgattr, vM, vbrgM);
                brgattr.bd_mask_level = jcp_.use_M_mask;

                if (is_amx) {
                    brgattr.max_top_vpad = 0;
                    brgattr.max_bottom_vpad = 0;
                } else {
                    brgattr.max_top_vpad = jcp_.max_vpad;
                    brgattr.max_bottom_vpad = jcp_.max_vpad;
                }
                brgattr.generate_skip_accumulation = true;
                CHECK(brgemm_desc_set_attr(brg, brgattr));

                auto LDD = jcp_.stride_w * jcp_.ic_without_padding;
                brg->with_sum = with_sum;
                CHECK(brgemm_desc_set_postops(
                        brg, attr(), &diff_src_md_, LDD, jcp_.bia_dt));
                jcp_.amx_buf_size_per_thread
                        = nstl::max(brg->get_wsp_buffer_size(),
                                jcp_.amx_buf_size_per_thread);
            }
        }
    }

    auto scratchpad = scratchpad_registry().registrar();
    brgemm_convolution_bwd_utils::init_scratchpad(scratchpad, jcp_);

    return status::success;
}

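// Lazily JIT a kernel for an already initialized descriptor; zero-sized
// dimensions and already compiled slots are skipped. For AMX the tile
// palette is precomputed here so execute() can reconfigure tiles cheaply.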
template <cpu_isa_t isa, bool enable_postops>
status_t brgemm_convolution_bwd_strided_t<isa, enable_postops>::add_brg_kernel(
        int bs, int M, int i_N, int i_K, int i_init) {
    if (M <= 0) return status::success;
    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;
    const auto &brgs = _pd->brgs_;

    auto N = (i_N) ? jcp.N_tail : jcp.N;
    auto K = (i_K) ? jcp.K_tail : jcp.K;
    if (N <= 0 || K <= 0) return status::success;
    auto brg_idx = _pd->get_brg_idx(bs, M - 1, i_init, i_N, i_K);
    auto brg = brgs[brg_idx];
    if (!brg_kernels_[brg_idx] && brg && brg->bcast_dim > 0 && brg->load_dim > 0
            && brg->reduce_dim > 0) {
        brgemm_kernel_t *brg_kernel = nullptr;
        CHECK(brgemm_kernel_create(&brg_kernel, *brg));
        CHECK(safe_ptr_assign(brg_kernels_[brg_idx], brg_kernel));
        if (is_amx) {
            CHECK(brgemm_init_tiles(*brg, &brg_kernel_palettes_[brg_idx].a[0]));
        }
    }
    return status::success;
}

template <cpu_isa_t isa, bool enable_postops>
status_t brgemm_convolution_bwd_strided_t<isa, enable_postops>::init(
        engine_t *engine) {

    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;

    bia_dsz = jcp.bia_dsz;
    acc_dsz = jcp.acc_dsz;
    src_dsz = jcp.src_dsz;
    wei_dsz = jcp.wei_dsz;
    dst_dsz = jcp.dst_dsz;

    auto ndims = _pd->ndims();
    if (ndims < 3 || ndims > 5) assert(!"Invalid ndims!");

    KD = ndims_pick(jcp.kd, 1, 1);
    KH = ndims_pick(jcp.kh, jcp.kh, 1);
    KW = jcp.kw;

    EXT_KD = ndims_pick(jcp.ext_kd, 1, 1);
    EXT_KH = ndims_pick(jcp.ext_kh, jcp.ext_kh, 1);
    EXT_KW = jcp.ext_kw;

    ODP = ndims_pick(jcp.odp, 1, 1);
    OHP = ndims_pick(jcp.ohp, jcp.ohp, 1);
    OWP = jcp.owp;

    KS = KD * KH * KW;
    KD_BLOCK = ndims_pick(jcp.kd_block, 1, 1);
    KH_BLOCK = ndims_pick(jcp.kh_block, jcp.kh_block, 1);
    KW_BLOCK = jcp.kw_block;
    KD_BLOCK_PAD = ndims_pick(jcp.kd_block_pad, 1, 1);
    KH_BLOCK_PAD = ndims_pick(jcp.kh_block_pad, jcp.kh_block_pad, 1);
    ID = ndims_pick(jcp.id, 1, 1);
    IH = ndims_pick(jcp.ih, jcp.ih, 1);
    IW = jcp.iw;
    OD = ndims_pick(jcp.od, 1, 1);
    OH = ndims_pick(jcp.oh, jcp.oh, 1);
    OW = jcp.ow;
    SD = ndims_pick(jcp.stride_d, 1, 1);
    SH = ndims_pick(jcp.stride_h, jcp.stride_h, 1);
    SW = jcp.stride_w;
    FP = ndims_pick(jcp.f_pad, 0, 0);
    TP = ndims_pick(jcp.t_pad, jcp.t_pad, 0);
    LP = jcp.l_pad;
    DD = ndims_pick(jcp.dilate_d, 0, 0) + 1;
    DH = ndims_pick(jcp.dilate_h, jcp.dilate_h, 0) + 1;
    DW = jcp.dilate_w + 1;

    oc_chunks = div_up(jcp.nb_oc, jcp.nb_oc_blocking);

    // const variables used for address calculations
    src_w_sz = static_cast<dim_t>(OW) * jcp.ngroups * jcp.oc_without_padding;
    src_h_sz = OH * src_w_sz;
    src_d_sz = OD * src_h_sz;
    dst_w_sz = static_cast<dim_t>(IW) * jcp.ic_without_padding;
    dst_h_sz = IH * dst_w_sz;
    dst_d_sz = ID * dst_h_sz;

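    // Weights are walked as gIdhwO<block_i><block_o> (see ker_trans), so the
    // stride hierarchy below is icb > kd > kh > kw > oc-block.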
    wei_oc_sz = static_cast<dim_t>(jcp.ocp) * jcp.ic_block;
    wei_kw_sz = KW * wei_oc_sz;
    wei_kh_sz = KH * wei_kw_sz;
    wei_kd_sz = KD * wei_kh_sz;
    wei_icb_sz = jcp.nb_ic * wei_kd_sz;

    need_postwork = jcp.with_bias || jcp.with_eltwise || jcp.with_binary
            || (jcp.dst_dt != jcp.acc_dt) || jcp.with_sum || jcp.use_M_mask
            || jcp.src_zero_point || jcp.dst_zero_point;

    // ---- Initialize arrays ---------------------
    brg_kernels_.resize(_pd->brgs_sz_);
    brg_kernel_palettes_.resize(_pd->brgs_sz_);

    for (int i = 0; i < _pd->brgs_sz_; i++)
        brg_kernels_[i] = nullptr;

    CHECK(safe_ptr_assign(copy_to_pbuffer_,
            new jit_avx512_core_brgemm_conv_bwd_trans_kernel_t(jcp)));
    CHECK(copy_to_pbuffer_->create_kernel());

    const auto ow_block = jcp.owp;
    const auto oh_block = jcp.ohp;
    const auto od_block = jcp.odp;

    pbuf_w_sz = (dim_t)jcp.oc_block * ow_block;
    pbuf_h_sz = pbuf_w_sz * oh_block;
    pbuf_d_sz = pbuf_h_sz * od_block;

    is_amx = brgemm_convolution_bwd_utils::is_amx(isa);

    // TODO: this is only needed if we have d/h padding exceeding kd/kh
    int M_begin = 0;
    int M_end = (jcp.M_tail == jcp.M) ? 1 : 2;
    int N_begin = 0;
    int N_end = (jcp.N_tail == jcp.N) ? 1 : 2;
    int K_begin = 0;
    int K_end = (jcp.K_tail == jcp.K) ? 1 : 2;
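    // When there is a single oc chunk and no kd/kh blocking, each output
    // tile is produced by exactly one brgemm call, so only the initializing
    // (beta = 0, i.e. i_init == 1) kernels are needed.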
    int i_init_begin = (div_up(jcp.nb_oc, jcp.nb_oc_blocking) == 1
                               && KD_BLOCK == KD && KH_BLOCK == KH)
            ? 1
            : 0;
    int i_init_end = 2;

    for (int bs = 0; bs <= jcp.max_batch; bs++) {
        if (_pd->batchsizes[bs] == -1) continue;

        for_(int i_N = N_begin; i_N < N_end; i_N++)
        for_(int i_M = M_begin; i_M < M_end; i_M++)
        for_(int i_init = i_init_begin; i_init < i_init_end; i_init++)
        for (int i_K = K_begin; i_K < K_end; i_K++) {
            auto M = (i_M) ? jcp.M_tail : jcp.M;
            if (M <= 0) continue;
            add_brg_kernel(bs, M, i_N, i_K, i_init);
        }
    }

    return status::success;
}

template <cpu_isa_t isa, bool enable_postops>
status_t brgemm_convolution_bwd_strided_t<isa, enable_postops>::execute(
        const exec_ctx_t &ctx) const {
    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;

    // XXX: brgemm requires scales to be passed, so passing default wei scales
    DEFINE_ARG_SCALES_BUFFER(oscales, DNNL_ARG_WEIGHTS);

    const memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();
    brgemm_batch_element_t *const __restrict brg_batch_global
            = (jcp.brg_type == brgemm_strd && jcp.exec_type != exec_vpad)
            ? nullptr
            : scratchpad.template get<brgemm_batch_element_t>(
                    key_brgemm_primitive_batch);
    char *const __restrict c_buffer_global = (jcp.use_buffer)
            ? scratchpad.template get<char>(key_brgemm_primitive_buffer)
            : nullptr;

    auto inp_p_buffer = (jcp.exec_type == exec_trans)
            ? scratchpad.template get<char>(key_conv_brgemm_inp_buffer)
            : nullptr;
    auto inp_p_buffer_mask = (jcp.exec_type == exec_trans)
            ? scratchpad.template get<uint8_t>(key_conv_brgemm_inp_buffer_mask)
            : nullptr;

    char *const wsp_tile_global = is_amx
            ? scratchpad.template get<char>(key_conv_amx_tile_buffer)
            : nullptr;

    brgemm_bwd_exec_ctx_t brgemm_ctx(ctx, _pd);

    const char *const __restrict diff_dst = brgemm_ctx.diff_dst;

    const dim_t work_amount = static_cast<dim_t>(jcp.mb) * jcp.ngroups
            * jcp.nb_ic * jcp.nb_id * jcp.nb_ih * jcp.nb_iw;

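    // Each thread takes a contiguous [start, end) slice of the
    // mb * ngroups * nb_ic * nb_id * nb_ih * nb_iw iteration space
    // (balance211) and uses private batch, accumulator and transpose
    // buffers.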
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        if (ithr >= work_amount) return;

        brgemm_batch_element_t *const __restrict brg_batch = brg_batch_global
                + static_cast<size_t>(ithr) * jcp.adjusted_batch_size;
        char *const __restrict c_buffer = (jcp.use_buffer)
                ? c_buffer_global + ithr * acc_dsz * jcp.buffer_size
                : nullptr;
        char *inp_buffer = (jcp.exec_type == exec_trans)
                ? inp_p_buffer + src_dsz * ithr * jcp.inp_buffer_size
                : nullptr;
        if (is_amx) {
            // Workaround: for some machines SEGFAULT possible on tile load
            // if the page was not touched before it
            for (dim_t i = 0; i < jcp.inp_buffer_size;
                    i += brgemm_convolution_bwd_utils::P4K)
                inp_buffer[i] = 0;
        }

        uint8_t *__restrict inp_buffer_mask = (jcp.exec_type == exec_trans)
                ? inp_p_buffer_mask + ithr * jcp.inp_buffer_mask_size
                : nullptr;

        char *const wsp_tile = is_amx
                ? wsp_tile_global + ithr * 2 * brgemm_convolution_bwd_utils::P4K
                : nullptr;
        dim_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);
        int n {0}, g {0}, icb {0}, idb {0}, ihb {0}, iwb {0};

        nd_iterator_init(start, n, jcp.mb, idb, jcp.nb_id, ihb, jcp.nb_ih, iwb,
                jcp.nb_iw, g, jcp.ngroups, icb, jcp.nb_ic);

        brgemm_bwd_thread_ctx_t btc(
                brgemm_ctx, ithr, brg_batch, c_buffer, wsp_tile);
        std::memset(btc.cur_palette.a, 0, AMX_PALETTE_SIZE);

        int last_n = -1;
        int last_g = -1;
        int last_occ = -1;
        int last_idb = -1;
        int last_ihb = -1;
        int last_iwb = -1;
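        // The last_* trackers let maybe_trans_inp() skip re-packing diff_dst
        // when consecutive work items hit the same (g, n, occ, idb, ihb, iwb)
        // block.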
        for (auto work = start; work < end; work++) {
            btc.g = g;
            btc.n = n;
            btc.icb = icb;
            btc.idb = idb;
            btc.ihb = ihb;
            btc.iwb = iwb;
            btc.oscales = oscales;

            auto id_begin = idb * jcp.id_block;
            auto id_end = nstl::min(ID, id_begin + jcp.id_block);
            auto ih_begin = ihb * jcp.ih_block;
            auto ih_end = nstl::min(IH, ih_begin + jcp.ih_block);

            for_(int id = id_begin; id < id_end; id++)
            for_(int ih = ih_begin; ih < ih_end; ih++)
            for (int occ = 0; occ < oc_chunks; occ++) {
                btc.id = id;
                btc.ih = ih;
                btc.occ = occ;

                maybe_trans_inp(ithr, diff_dst, inp_buffer, inp_buffer_mask, g,
                        n, occ, idb, ihb, iwb, last_g, last_n, last_occ,
                        last_idb, last_ihb, last_iwb);
                for (int sw = 0; sw < SW; sw++) {
                    btc.sw = sw;
                    ker_trans(btc, inp_buffer);
                }

                last_n = n;
                last_g = g;
                last_occ = occ;
                last_idb = idb;
                last_ihb = ihb;
                last_iwb = iwb;
            }
            nd_iterator_step(n, jcp.mb, idb, jcp.nb_id, ihb, jcp.nb_ih, iwb,
                    jcp.nb_iw, g, jcp.ngroups, icb, jcp.nb_ic);
        }
        if (is_amx) { amx_tile_release(); }
    });

    return status::success;
}

template <cpu_isa_t isa, bool enable_postops>
void brgemm_convolution_bwd_strided_t<isa, enable_postops>::call_brgemm_kernel(
        brgemm_bwd_thread_ctx_t &btc, int brg_idx, int batch_size, char *ptr_C,
        char *ptr_D, const char *bias_w, int g_ic, bool do_postops,
        const void *binary_post_ops_rhs, int32_t src_zp_vals,
        int32_t *src_zp_ptr, int32_t *dst_zp_ptr, int32_t *s8s8_comp,
        bool do_only_comp, bool is_first_call_postops) const {

    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;

    const auto brg_ker = brg_kernels_[brg_idx].get();
    assert(brg_ker != nullptr);

    if (is_first_call_postops) return;

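    // Reconfigure AMX tiles only when the palette differs from the one
    // loaded by the previous brgemm call; reloading the tile configuration
    // is expensive.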
    if (is_amx) {
        if (std::memcmp(btc.cur_palette.a, brg_kernel_palettes_[brg_idx].a,
                    AMX_PALETTE_SIZE)
                != 0) {
            amx_tile_configure(brg_kernel_palettes_[brg_idx].a);
            std::memcpy(btc.cur_palette.a, brg_kernel_palettes_[brg_idx].a,
                    AMX_PALETTE_SIZE);
        }
    }

    const auto do_only_pass_comp = !do_postops && jcp.src_zero_point
            && (jcp.req_brg_comp_pad || jcp.max_vpad > 0);
    const auto do_skip_accm = batch_size == 0;
    const auto maybe_do_postops = one_of(
            true, do_postops, do_only_comp, do_only_pass_comp, do_skip_accm);
    if (maybe_do_postops) {
        const brgemm_post_ops_data_t post_ops_data {
                static_cast<const char *>(bias_w),
                &btc.oscales[jcp.is_ic_scale * g_ic], binary_post_ops_rhs,
                static_cast<size_t>(g_ic), 0, btc.brgemm_ctx.dst, 0,
                static_cast<void *>(src_zp_ptr), nullptr,
                static_cast<void *>(dst_zp_ptr), do_skip_accm, src_zp_vals,
                do_only_comp, do_only_pass_comp};

        void *scratch = is_amx ? static_cast<void *>(btc.wsp_tile)
                               : static_cast<void *>(s8s8_comp);

        if (do_postops || do_skip_accm)
            brgemm_kernel_execute_postops(brg_ker, batch_size, btc.brg_batch,
                    ptr_C, ptr_D, post_ops_data, scratch);
        else
            brgemm_kernel_execute_postops(brg_ker, batch_size, btc.brg_batch,
                    ptr_C, ptr_C, post_ops_data, scratch);
    } else
        brgemm_kernel_execute(brg_ker, batch_size, btc.brg_batch, ptr_C,
                static_cast<void *>(btc.wsp_tile));
}

template <cpu_isa_t isa, bool enable_postops>
void brgemm_convolution_bwd_strided_t<isa, enable_postops>::maybe_trans_inp(
        int ithr, const char *__restrict src, char *__restrict inp_buffer,
        uint8_t *__restrict inp_buffer_mask, int g, int n, int occ, int idb,
        int ihb, int iwb, int last_g, int last_n, int last_occ, int last_idb,
        int last_ihb, int last_iwb) const {

    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;
    const auto ocb = occ * jcp.nb_oc_blocking;

    if (last_g == g && last_n == n && last_occ == occ && last_idb == idb
            && last_ihb == ihb && last_iwb == iwb)
        return;

    auto cp = jit_brgemm_conv_bwd_trans_kernel_call_s();

    const auto oc = ocb * jcp.oc_block;
    const auto g_oc = g * jcp.oc + oc;

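    // Presumably this locates the last kernel tap whose stride phase matches
    // this iw block and derives the first diff_dst column it reads:
    // ow = (iw + l_pad - kw_x * (dilate_w + 1)) / stride_w. The same phase
    // matching is repeated below for the h and d dimensions.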
    const auto sw = jcp.l_pad % jcp.stride_w;
    const auto kw = (jcp.kw - 1) % jcp.stride_w;
    const auto kw_x = (jcp.kw - 1) - nstl::modulo(kw - sw, jcp.stride_w);
    const auto ow = (iwb * jcp.iw_block + jcp.l_pad - kw_x * (jcp.dilate_w + 1))
            / jcp.stride_w;

    int od_start {0}, od_end {0}, oh_start {0}, oh_end {0};

    const auto sh = jcp.t_pad % jcp.stride_h;
    const auto kh = (jcp.kh - 1) % jcp.stride_h;
    const auto kh_x = (jcp.kh - 1) - nstl::modulo(kh - sh, jcp.stride_h);
    oh_start = (ihb * jcp.ih_block + jcp.t_pad - kh_x * (jcp.dilate_h + 1))
            / jcp.stride_h;
    oh_end = oh_start + jcp.oh_block;

    const auto sd = jcp.f_pad % jcp.stride_d;
    const auto kd = (jcp.kd - 1) % jcp.stride_d;
    const auto kd_x = (jcp.kd - 1) - nstl::modulo(kd - sd, jcp.stride_d);
    od_start = (idb * jcp.id_block + jcp.f_pad - kd_x * (jcp.dilate_d + 1))
            / jcp.stride_d;
    od_end = od_start + jcp.od_block;

    const auto rows_to_copy = min(jcp.oh, oh_end) - max(0, oh_start);
    cp.iwb = iwb;
    cp.oc = oc;
    const auto ow_buf = ow;
    dim_t inp_offset_start, out_offset_start;

    cp.t_pad = 0;
    cp.b_pad = 0;
    cp.h_count = max(0, rows_to_copy);

    const auto oh_buf = max(0, oh_start);

    inp_offset_start = static_cast<dim_t>(n) * src_d_sz
            + max(0, oh_start) * src_w_sz
            + max(0, ow) * jcp.ngroups * jcp.oc_without_padding + g_oc;
    out_offset_start = oh_buf * pbuf_w_sz + ow_buf * jcp.oc_block;

    for (int od = max(0, od_start); od < min(jcp.od, od_end); od++) {
        const auto inp_offset = inp_offset_start + od * src_h_sz;
        const auto od_buf = od;
        const auto out_offset = out_offset_start + od_buf * pbuf_h_sz;
        cp.src = src + src_dsz * inp_offset;
        cp.dst = inp_buffer + src_dsz * out_offset;

        (*copy_to_pbuffer_)(&cp);
    }
}

template <cpu_isa_t isa, bool enable_postops>
void brgemm_convolution_bwd_strided_t<isa, enable_postops>::ker_trans(
        brgemm_bwd_thread_ctx_t &btc, char *inp_buffer) const {

    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;
    auto ndims = _pd->ndims();

    const char *const __restrict weights = btc.brgemm_ctx.weights;
    const char *const __restrict bias = btc.brgemm_ctx.bias;
    char *const __restrict dst = btc.brgemm_ctx.dst;
    const std::vector<const void *> &post_ops_binary_rhs_arg_vec
            = btc.brgemm_ctx.post_ops_binary_rhs_arg_vec;
    const int ic = btc.icb * jcp.ic_block;
    const int g_ic = btc.g * jcp.ic + ic;
    const int ocb = btc.occ * jcp.nb_oc_blocking;
    const int oc = ocb * jcp.oc_block;
    const dim_t iw = btc.iwb * jcp.iw_block + btc.sw;
    const dim_t ih = btc.ih;
    const dim_t id = btc.id;

    // od = (id + FP - kd * DD) / SD <-- general relation for all sets of
    // (od, id, kd) that overlap; for a given index from diff_src we need to
    // find the appropriate stride sector
    int kd_s_(0), kh_s_(0), kw_s(0), kd_f_(0), kh_f_(0), kw_f(0);

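    // For strided bwd-d only every stride-th kernel tap contributes to a
    // given input point: o = (i + pad - k * dilate) must be divisible by the
    // stride. set_k_range finds the first aligned tap s and clamps [k_s, k_f)
    // so that the matching output index stays inside [0, O).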
    auto set_k_range = [&](int P, int D, int S, dim_t i, dim_t O, int K,
                               int &k_s, int &k_f, bool is_w) {
        int s(0), o_test(0);
        while (true) {
            o_test = i + P - s * D;
            if (o_test % S == 0) break;
            s++;
        }

        k_f = is_w ? K : min(K, static_cast<int>(div_up(i + P + 1, D)));
        k_s = is_w ? 0 : max(0, static_cast<int>(div_up(i + P - O * S + 1, D)));

        while (k_s % S != s)
            k_s++;
    };

    set_k_range(FP, DD, SD, id, OD, KD, kd_s_, kd_f_, false);
    set_k_range(TP, DH, SH, ih, OH, KH, kh_s_, kh_f_, false);
    set_k_range(LP, DW, SW, iw, OW, KW, kw_s, kw_f, true);

    const auto kh_f = ndims_pick(kh_f_, kh_f_, 1);
    const auto kh_s = ndims_pick(kh_s_, kh_s_, 0);

    const auto kd_f = ndims_pick(kd_f_, 1, 1);
    const auto kd_s = ndims_pick(kd_s_, 0, 0);

    const bool is_oc_tail
            = (btc.occ == oc_chunks - 1 && ((jcp.oc - oc) % jcp.oc_block != 0));

    const bool is_ic_tail = (jcp.ic - ic < jcp.ic_block);
    const char *const __restrict bias_w
            = bias ? bias + (bias_d.blk_off(g_ic) * bia_dsz) : nullptr;
    const auto nb_oc_b = nstl::min(jcp.nb_oc_blocking, jcp.nb_oc - ocb)
            - (is_oc_tail ? 1 : 0);
    char *const __restrict dst_base = dst + dst_dsz * (btc.n * dst_d_sz + g_ic);
    char *ptr_C;
    char *ptr_D;
    int kd_b(0), kd_e(0), kh_b(0), kh_e(0), k_l(0);

    const auto wei_base
            = weights + wei_dsz * (btc.g * wei_icb_sz + btc.icb * wei_kd_sz);
    const dim_t iw_b {iw};

    ptr_D = dst_base
            + dst_dsz
                    * (id * dst_h_sz + ih * dst_w_sz
                            + iw_b * jcp.ic_without_padding);
    ptr_C = (jcp.use_buffer) ? btc.c_buffer : static_cast<char *>(ptr_D);

    const auto ker_i = (jcp.M > 0 ? jcp.M : jcp.M_tail) - 1;

    bool is_first_call_postops = false,
         is_first_call_postops_state_changed = false;
    const auto call_brgemm = [&](int brg_idx, int oc_block_s, int n_oc_blocks,
                                     bool do_postops) {
        const auto kh_ee = kh_e;
        const auto kw_e = kw_f;
        const auto pbuf_base = inp_buffer;

        int k_sum = 0;
        for (int i_ocb = 0; i_ocb < n_oc_blocks; i_ocb++) {
            const auto oc_off = (oc_block_s + i_ocb) * jcp.oc_block;
            const auto wei_oc = oc + oc_off;
            const auto n_ocb_off = i_ocb * k_l;
            const auto pbuf_base_oc = pbuf_base;
            const auto wei_base_oc = wei_base + wei_dsz * wei_oc * jcp.ic_block;

            auto k = 0;
            for (int kd = kd_b; kd < kd_e; kd++) {
                auto od = (id - kd * DD + FP);
                if (od % SD != 0) continue;
                od /= SD;
                const auto pbuf_base_kd
                        = pbuf_base_oc + src_dsz * od * pbuf_h_sz;
                const auto wei_base_kd = wei_base_oc + wei_dsz * kd * wei_kh_sz;
                for (int kh = kh_b; kh < kh_ee; kh++) {
                    auto oh = (ih - kh * DH + TP);
                    if (oh % SH != 0) continue;
                    oh /= SH;
                    const auto pbuf_base_kh
                            = pbuf_base_kd + src_dsz * oh * pbuf_w_sz;
                    const auto wei_base_kh
                            = wei_base_kd + wei_dsz * kh * wei_kw_sz;
                    for (int kw = kw_s; kw < kw_e; kw += SW) {
                        const auto ow = (iw - kw * DW + LP) / SW;
                        // inp_buffer layout is Cdhw<oc_block>c
                        btc.brg_batch[n_ocb_off + k].ptr.A = pbuf_base_kh
                                + src_dsz * (ow + jcp.l_ovf) * jcp.oc_block;
                        btc.brg_batch[n_ocb_off + k].vvpad.top = 0;
                        btc.brg_batch[n_ocb_off + k].vvpad.bottom = 0;
                        // general wei layout is gIdhwO<block_i><block_o>
                        btc.brg_batch[n_ocb_off + k].ptr.B
                                = wei_base_kh + wei_dsz * kw * wei_oc_sz;
                        k++;
                    }
                }
            }
            k_sum += k;
        }
        call_brgemm_kernel(btc, brg_idx, k_sum, ptr_C, ptr_D, bias_w, g_ic,
                do_postops, post_ops_binary_rhs_arg_vec.data(), 0, nullptr,
                nullptr, nullptr, false, is_first_call_postops);
        if (!is_first_call_postops_state_changed) {
            const auto do_only_pass_comp = !do_postops && jcp.src_zero_point
                    && (jcp.req_brg_comp_pad || jcp.max_vpad > 0);
            const auto do_skip_accm = k_sum == 0;
            is_first_call_postops
                    = one_of(true, do_postops, do_only_pass_comp, do_skip_accm);
            is_first_call_postops_state_changed = true;
        }

        MAYBE_UNUSED(bias_w);
        MAYBE_UNUSED(ptr_C);
        MAYBE_UNUSED(post_ops_binary_rhs_arg_vec);
    };

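    // Initialize the accumulator only on the first (occ, kd, kh) chunk and
    // apply post-work only on the last one; intermediate chunks accumulate
    // into C.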
    const auto kdhw_loop = [&]() {
        const auto do_init = btc.occ == 0 && kd_b == kd_s && kh_b == kh_s;
        const auto do_postwork = need_postwork && btc.occ == (oc_chunks - 1)
                && kd_e == kd_f && kh_e == kh_f;

        const int kd_l = div_up(kd_e - kd_b, SD);
        const int kh_l = div_up(kh_e - kh_b, SH);
        const int kw_l = div_up(kw_f - kw_s, SW);
        k_l = kd_l * kh_l * kw_l;

        int kernel_idx[2][2];
        kernel_idx[false][false]
                = _pd->get_brg_idx(k_l, ker_i, false, is_ic_tail, false);
        kernel_idx[true][false]
                = _pd->get_brg_idx(k_l, ker_i, true, is_ic_tail, false);
        kernel_idx[false][true]
                = _pd->get_brg_idx(k_l, ker_i, false, is_ic_tail, true);
        kernel_idx[true][true]
                = _pd->get_brg_idx(k_l, ker_i, true, is_ic_tail, true);

        if (nb_oc_b > 0) {
            const auto brg_idx = kernel_idx[do_init][false];
            call_brgemm(brg_idx, 0, nb_oc_b, do_postwork && !is_oc_tail);
        }

        if (is_oc_tail) {
            const auto use_init_ker = (do_init && nb_oc_b == 0);
            const auto brg_oc_tail_idx = kernel_idx[use_init_ker][true];
            call_brgemm(brg_oc_tail_idx, nb_oc_b, 1, do_postwork);
        }
    };

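    // An empty (kd, kh) range means this input point only overlaps padding;
    // kdhw_loop() is still called once with k_l == 0 so that the
    // skip-accumulation path can zero the output and run post-ops.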
    if (kd_f > kd_s && kh_f > kh_s) {
        // kw values covering full ow_block
        for (kd_b = kd_s; kd_b < kd_f; kd_b += KD_BLOCK) {
            kd_e = nstl::min(kd_f, kd_b + KD_BLOCK);
            for (kh_b = kh_s; kh_b < kh_f; kh_b += KH_BLOCK) {
                kh_e = nstl::min(kh_f, kh_b + KH_BLOCK);
                kdhw_loop();
            }
        }
    } else {
        kd_b = kd_e = kd_s;
        kh_b = kh_e = kh_s;
        kdhw_loop();
    }
}

template struct brgemm_convolution_bwd_strided_t<avx512_core_amx>;
template struct brgemm_convolution_bwd_strided_t<avx512_core_amx, true>;
template struct brgemm_convolution_bwd_strided_t<avx512_core_amx_fp16>;
template struct brgemm_convolution_bwd_strided_t<avx512_core_amx_fp16, true>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s