/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/cpu_primitive.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"

#include "cpu/x64/jit_brgemm_conv_bwd_w.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

using namespace nstl;
using namespace data_type;

#define wht_blk_off(d, g, ...) \
    (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
                         : (d).blk_off(__VA_ARGS__))

status_t brgemm_convolution_bwd_weights_t::pd_t::init(engine_t *engine) {
    const auto src_type = src_md(0)->data_type;
    const auto diff_wei_type = diff_weights_md(0)->data_type;
    const auto diff_bia_type = diff_weights_md(1)->data_type;
    const auto diff_dst_type = diff_dst_md(0)->data_type;
    bool ok = true && is_bwd_w()
            && set_default_alg_kind(alg_kind::convolution_direct)
            && utils::one_of(src_type, bf16, f16) && diff_dst_type == src_type
            && utils::one_of(diff_wei_type, f32, src_type)
            && utils::one_of(diff_bia_type, data_type::undef, f32, src_type)
            && attr()->has_default_values() && !has_zero_dim_memory();
    if (!ok) return status::unimplemented;

    auto scratchpad = scratchpad_registry().registrar();

    status_t status = brgemm_convolution_utils::init_conf_bwd_w(jcp_, *desc(),
            src_md_, diff_weights_md_, diff_bias_md_, diff_dst_md_, attr_,
            dnnl_get_max_threads());
    if (status != status::success) return status;

    status = brgemm_convolution_utils::init_scratchpad_bwd_w(
            scratchpad, jcp_, src_md_, diff_weights_md_, diff_dst_md_);

    if (status != status::success) return status;
    copy2jit_jcp();

    bs_c = jcp_.var_bs ? 1 : (jcp_.max_batch + 1);
    batchsizes.resize(bs_c + 1);
    for (int i = 0; i <= bs_c; i++)
        batchsizes[i] = -1;

    batchsizes[1] = 0;

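    // The brgemm descriptor cache below is indexed by
    // (batch size, M, i_init, i_N, i_K): the three factors of 2 cover the
    // i_init/i_N/i_K booleans, and (adj_M + 1) covers the possible M values.
    // This is a reading of get_brg_idx(), which is defined elsewhere.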
    const auto adj_M = nstl::max(jcp_.M, jcp_.M_tail);
    brgs_sz_ = bs_c * (adj_M + 1) * 2 * 2 * 2;
    brgs_.resize(brgs_sz_);
    bd_masks.resize(brgs_sz_);

    const float alpha = 1.0;
    const float beta = 1.0;

    int M_begin = 0;
    int M_end = (jcp_.M_tail == jcp_.M || jcp_.M_tail == 0) ? 1 : 2;
    int N_begin = 0;
    int N_end = (jcp_.N_tail == jcp_.N || jcp_.N_tail == 0) ? 1 : 2;
    int K_begin = 0;
    int K_end = (jcp_.K_tail == jcp_.K || jcp_.K_tail == 0) ? 1 : 2;
    int init_begin = 0;
    int init_end = 2;
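    // An *_end of 2 above means both the full-block and the tail variant of
    // the corresponding dimension must be generated; when the tail is absent
    // or coincides with the full size, a single variant suffices.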

    const auto wei_type = src_type;

    for (int i = M_begin; i < M_end; i++) {
        auto M = (i) ? jcp_.M_tail : jcp_.M;
        if (M <= 0) continue;
        // initialize only the brgemm descriptors that are actually needed
        for (int bs = 0; bs <= jcp_.max_batch; bs++) {
            if (batchsizes[bs] == -1) continue;
            for_(int i_init = init_begin; i_init < init_end; i_init++)
            for_(int i_N = N_begin; i_N < N_end; i_N++)
            for (int i_K = K_begin; i_K < K_end; i_K++) {
                auto vbeta = (i_init) ? 0 : beta;
                auto vN = (i_N) ? jcp_.N_tail : jcp_.N;
                auto vK = (i_K) ? jcp_.K_tail : jcp_.K;
                if (vN == 0 || vK == 0) continue;
                auto brg_idx = get_brg_idx(bs, M, i_init, i_N, i_K);
                // skip this iteration if the brgemm_t was already created
                if (brgs_[brg_idx] != nullptr) continue;
                brgs_[brg_idx] = std::make_shared<brgemm_t>();
                brgemm_t *brg = brgs_[brg_idx].get();
                CHECK(brgemm_desc_init(brg, jcp_.isa, jcp_.brg_type, src_type,
                        wei_type, false, false, brgemm_row_major, alpha, vbeta,
                        jcp_.LDA, jcp_.LDB, jcp_.LDC, M, vN, vK, nullptr));

                brgemm_attr_t brgattr;
                brgattr.use_uker = jcp_.use_uker;
                brgattr.use_interleave_stores = jcp_.use_interleave_stores;
                brgattr.hint_prefetching = jcp_.hint_prefetching;
                brgattr.var_bs = jcp_.var_bs;
                brgattr.max_bs = jcp_.max_batch;
                brgattr.hint_innermost_loop = jcp_.brgemm_bd_loop_innermost
                        ? brgemm_bd_loop_innermost
                        : brgemm_ld_loop_innermost;

                brgattr.hint_expected_A_size = 0;
                brgattr.hint_expected_B_size = 0;
                brgattr.hint_expected_C_size = 0;

                brgattr.wary_tail_read = false;
                brgattr.bd_mask_level = jcp_.use_M_mask;

                brgattr.max_top_vpad = 0;
                brgattr.max_bottom_vpad = 0;

                brgattr.LDA2 = jcp_.tr_iw * jcp_.ih * jcp_.id;
                brgattr.LDB2 = jcp_.tr_ow * jcp_.oc_block * jcp_.oh * jcp_.od;
                brgattr.LDC2_M = jcp_.oc_block * jcp_.kd * jcp_.kh * jcp_.kw;
                brgattr.LDC2_N = jcp_.nb_ic * jcp_.ic_block * jcp_.oc_block
                        * jcp_.kd * jcp_.kh * jcp_.kw;

                CHECK(brgemm_desc_set_attr(brg, brgattr));
            }
        }
    }
    return status;
}

// jit_jcp_ is used to initialize the transpose kernels shared with the jit
// implementation
void brgemm_convolution_bwd_weights_t::pd_t::copy2jit_jcp() {
    jit_jcp_ = zero<decltype(jit_jcp_)>();
    jit_jcp_.prop_kind = jcp_.prop_kind;
    jit_jcp_.has_vnni = jcp_.has_vnni;
    jit_jcp_.harness = jcp_.harness;
    jit_jcp_.simd_w = jcp_.simd_w;
    jit_jcp_.ndims = jcp_.ndims;
    jit_jcp_.mb = jcp_.mb;
    jit_jcp_.ngroups = jcp_.ngroups;
    jit_jcp_.ic = jcp_.ic;
    jit_jcp_.oc = jcp_.oc;
    jit_jcp_.oc_without_padding = jcp_.oc;
    jit_jcp_.ic_without_padding = jcp_.ic_without_padding;
    jit_jcp_.id = jcp_.id;
    jit_jcp_.ih = jcp_.ih;
    jit_jcp_.iw = jcp_.iw;
    jit_jcp_.od = jcp_.od;
    jit_jcp_.oh = jcp_.oh;
    jit_jcp_.ow = jcp_.ow;
    jit_jcp_.f_pad = jcp_.f_pad;
    jit_jcp_.l_pad = jcp_.l_pad;
    jit_jcp_.t_pad = jcp_.t_pad;
    jit_jcp_.back_pad = jcp_.back_pad;
    jit_jcp_.r_pad = jcp_.r_pad;
    jit_jcp_.b_pad = jcp_.b_pad;
    jit_jcp_.kd = jcp_.kd;
    jit_jcp_.kh = jcp_.kh;
    jit_jcp_.kw = jcp_.kw;
    jit_jcp_.stride_d = jcp_.stride_d;
    jit_jcp_.stride_h = jcp_.stride_h;
    jit_jcp_.stride_w = jcp_.stride_w;
    jit_jcp_.dilate_d = jcp_.dilate_d;
    jit_jcp_.dilate_h = jcp_.dilate_h;
    jit_jcp_.dilate_w = jcp_.dilate_w;
    jit_jcp_.src_tag = jcp_.src_tag;
    jit_jcp_.wei_tag = jcp_.wei_tag;
    jit_jcp_.dst_tag = jcp_.dst_tag;
    jit_jcp_.with_bias = jcp_.with_bias;
    jit_jcp_.with_sum = jcp_.with_sum;
    jit_jcp_.with_eltwise = jcp_.with_eltwise;
    jit_jcp_.with_binary = jcp_.with_binary;
    jit_jcp_.is_fused_conv = jcp_.is_fused_conv;
    jit_jcp_.nb_ic = jcp_.nb_ic;
    jit_jcp_.ic_block = jcp_.ic_block;
    jit_jcp_.nb_oc = jcp_.nb_oc;
    jit_jcp_.oc_block = jcp_.oc_block;
    jit_jcp_.nb_oc_blocking = jcp_.nb_oc_blocking;

    jit_jcp_.ic_tail = jcp_.ic_tail;
    jit_jcp_.oc_tail = jcp_.oc_tail;

    jit_jcp_.tr_iw = jcp_.tr_iw;
    jit_jcp_.tr_ow = jcp_.tr_ow;
    jit_jcp_.tr_diff_dst_buf_size = jcp_.tr_diff_dst_buf_size;
    jit_jcp_.typesize_in = jcp_.typesize_in;
    jit_jcp_.typesize_out = jcp_.typesize_out;
    jit_jcp_.ddst_dt = jcp_.dst_dt;
}

status_t brgemm_convolution_bwd_weights_t::add_brg_kernel(
        int bs, int M, int i_N, int i_K, int i_init) {
    if (M <= 0) return status::success;
    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;
    const auto &brgs = _pd->brgs_;

    auto N = (i_N) ? jcp.N_tail : jcp.N;
    auto K = (i_K) ? jcp.K_tail : jcp.K;
    if (N <= 0 || K <= 0) return status::success;
    auto brg_idx = _pd->get_brg_idx(bs, M, i_init, i_N, i_K);
    auto brg = brgs[brg_idx];
    if (!brg_kernels_[brg_idx] && brg && brg->bcast_dim > 0 && brg->load_dim > 0
            && brg->reduce_dim > 0) {
        brgemm_kernel_t *brg_kernel = nullptr;
        CHECK(brgemm_kernel_create(&brg_kernel, *brg));
        CHECK(safe_ptr_assign(brg_kernels_[brg_idx], brg_kernel));
        CHECK(brgemm_init_tiles(*brg, &brg_kernel_palettes_[brg_idx].a[0]));
    }
    return status::success;
}

status_t brgemm_convolution_bwd_weights_t::init(engine_t *engine) {
    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;
    const auto &jit_jcp = pd()->jit_jcp_;

    CHECK(safe_ptr_assign(trans_kernel_, create_trans_src(&jit_jcp)));
    CHECK(trans_kernel_->create_kernel());
    CHECK(safe_ptr_assign(trans_dst_kernel_, create_trans_dst(&jit_jcp)));
    CHECK(trans_dst_kernel_->create_kernel());

    if (jcp.with_bias) {
        CHECK(safe_ptr_assign(diff_bias_kernel_,
                new jit_avx512_core_amx_bwd_bias_kernel_t(jit_jcp)));
        CHECK(diff_bias_kernel_->create_kernel());
    }

    if (jcp.nthr_mb > 1) {
        CHECK(safe_ptr_assign(
                acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
        CHECK(acc_ker_->create_kernel());
    }
    if (jcp.transform_to_vnni) {
        CHECK(safe_ptr_assign(diff_wei_trans_kernel_,
                new jit_diff_wei_trans_to_vnni_t(jcp.wei_dt, jcp.kd, jcp.kh,
                        jcp.kw, jcp.ic_block, jcp.oc_block)));
        CHECK(diff_wei_trans_kernel_->create_kernel());
    }

    brg_kernels_.resize(_pd->brgs_sz_);
    brg_kernel_palettes_.resize(_pd->brgs_sz_);

    for (int i = 0; i < _pd->brgs_sz_; i++)
        brg_kernels_[i] = nullptr;

    int M_begin = 0;
    int M_end = (jcp.M_tail == jcp.M || jcp.M_tail == 0) ? 1 : 2;
    int N_begin = 0;
    int N_end = (jcp.N_tail == jcp.N || jcp.N_tail == 0) ? 1 : 2;
    int K_begin = 0;
    int K_end = (jcp.K_tail == jcp.K || jcp.K_tail == 0) ? 1 : 2;
    int init_begin = 0;
    int init_end = 2;

    for (int bs = 0; bs <= jcp.max_batch; bs++) {
        if (_pd->batchsizes[bs] == -1) continue;

        for_(int i_N = N_begin; i_N < N_end; i_N++)
        for_(int i_M = M_begin; i_M < M_end; i_M++)
        for_(int i_init = init_begin; i_init < init_end; i_init++)
        for (int i_K = K_begin; i_K < K_end; i_K++) {
            auto M = (i_M) ? jcp.M_tail : jcp.M;
            if (M <= 0) continue;
            add_brg_kernel(bs, M, i_N, i_K, i_init);
        }
    }

    return status::success;
}

struct brgemm_convolution_bwd_weights_t::thread_info_t {
    const src_data_t *src = nullptr;
    const diff_dst_data_t *diff_dst = nullptr;
    const void *diff_weights = nullptr;
    const void *diff_bias = nullptr;

    const brgemm_convolution_bwd_weights_t *self;
    const memory_tracking::grantor_t scratchpad;

    src_data_t *tr_src = nullptr;
    diff_dst_data_t *tr_diff_dst = nullptr;
    simple_barrier::ctx_t *tr_src_bctx = nullptr;
    simple_barrier::ctx_t *tr_diff_dst_bctx = nullptr;

    float *wei_bia_reduction = nullptr;
    float *bia_reduction = nullptr;
    simple_barrier::ctx_t *wei_bia_reduction_bctx = nullptr;

    // All nthreads are mapped to a multidimensional "cube" with sizes:
    // (nthr_mb, nthr_g, nthr_oc, nthr_ic).
    // Variables ithr_* define the coordinates and "layers" of the current
    // thread in this "cube".
    int ithr = 0;
    int ithr_ic_b = 0, ithr_oc_b = 0, ithr_g = 0, ithr_mb = 0;
    int ithr_but_oc = 0;
    int ithr_but_ic = 0;

    int img_start = 0, img_end = 0, img_work = 0;
    int g_start = 0, g_end = 0, g_work = 0;
    int oc_b_start = 0, oc_b_end = 0, oc_b_work = 0;
    int ic_b_start = 0, ic_b_end = 0, ic_b_work = 0;

    S_t cur_palette;
    brgemm_batch_element_t *__restrict brg_batch;
    char *wsp_tile;
    const exec_ctx_t &exec_ctx;
    const jit_brgemm_conv_conf_t &jcp;
    const memory_desc_wrapper src_d;
    const memory_desc_wrapper diff_dst_d;
    const memory_desc_wrapper diff_weights_d;

    thread_info_t(const brgemm_convolution_bwd_weights_t *pcnv,
            const exec_ctx_t &ctx, int ithr)
        : self(pcnv)
        , scratchpad(ctx.get_scratchpad_grantor())
        , ithr(ithr)
        , exec_ctx(ctx)
        , jcp(self->pd()->jcp_)
        , src_d(self->pd()->src_md())
        , diff_dst_d(self->pd()->diff_dst_md())
        , diff_weights_d(self->pd()->diff_weights_md(0)) {
        diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
        src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
        diff_weights = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_WEIGHTS);

        diff_bias = self->pd()->with_bias() && (jcp.oc % jcp.oc_block != 0)
                        && self->pd()->jcp_.bia_dt == data_type::f32
                ? (void *)scratchpad.template get<float>(key_conv_padded_bias)
                : CTX_OUT_MEM(void *, DNNL_ARG_DIFF_BIAS);

        tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
        if (jcp.global_transpose)
            tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                    key_conv_tr_src_bctx);

        tr_diff_dst = scratchpad.template get<diff_dst_data_t>(
                key_conv_tr_diff_dst);
        if (jcp.global_transpose)
            tr_diff_dst_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                    key_conv_tr_diff_dst_bctx);
        wei_bia_reduction
                = scratchpad.template get<float>(key_conv_wei_bia_reduction);
        bia_reduction = nullptr;
        if (jcp.with_bias) {
            const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block
                    * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
            const int num_wei_buffers = jcp.wei_dt != data_type::f32
                    ? jcp.nthr_mb
                    : jcp.nthr_mb - 1;
            bia_reduction = wei_bia_reduction + wei_size * num_wei_buffers;
        }

        wei_bia_reduction_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_wei_bia_reduction_bctx);

        ithr_ic_b = ithr % jcp.nthr_ic_b;
        ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b;
        ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g;
        ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / jcp.nthr_g;

        ithr_but_oc
                = (ithr_mb * jcp.nthr_g + ithr_g) * jcp.nthr_ic_b + ithr_ic_b;

        ithr_but_ic
                = (ithr_mb * jcp.nthr_g + ithr_g) * jcp.nthr_oc_b + ithr_oc_b;
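
        // Example decomposition, assuming nthr_ic_b = 2, nthr_oc_b = 3 and
        // nthr_g = 1: thread ithr = 7 maps to ithr_ic_b = 1, ithr_oc_b = 0,
        // ithr_g = 0, ithr_mb = 1.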

        int work_amount = jcp.nthr_mb_work;
        /* reduction dimension */
        balance211(work_amount, jcp.nthr_mb, ithr_mb, img_start, img_end);
        img_work = img_end - img_start;

        /* independent dimensions */
        balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end);
        g_work = g_end - g_start;

        balance211(jcp.nb_oc, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, oc_b_end);
        oc_b_work = oc_b_end - oc_b_start;

        balance211(jcp.nb_ic, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, ic_b_end);
        if (jcp.transform_to_vnni) {
            if (ic_b_start % 2 != 0) ic_b_start++;
            if (ic_b_end != jcp.nb_ic && ic_b_end % 2 != 0) ic_b_end++;
        }
        ic_b_work = ic_b_end - ic_b_start;

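        // Start from an all-zero palette so the first call_brgemm_kernel()
        // on this thread is guaranteed to (re)configure the AMX tiles.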
        std::memset(cur_palette.a, 0, AMX_PALETTE_SIZE);
        brgemm_batch_element_t *const __restrict brg_batch_global
                = (jcp.brg_type == brgemm_strd)
                ? nullptr
                : scratchpad.template get<brgemm_batch_element_t>(
                        key_brgemm_primitive_batch);
        brg_batch = brg_batch_global
                + static_cast<size_t>(ithr) * jcp.adjusted_batch_size;

        auto wsp_tile_global
                = scratchpad.template get<char>(key_conv_amx_tile_buffer);
        wsp_tile = wsp_tile_global + ithr * 2 * brgemm_convolution_utils::P4K;
    }

    const pd_t *pd() const { return self->pd(); }

    inline int get_inp_start(int out_s, int pad, int str) const {
        return nstl::max(0, -pad + out_s * str);
    }

    inline int get_inp_end(int out_e, int is, int pad, int str, int ek) const {
        return nstl::min(is, -pad + (out_e - 1) * str + ek);
    }
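
    // The two helpers above map an output range [out_s, out_e) to the input
    // range [-pad + out_s * str, -pad + (out_e - 1) * str + ek), clamped to
    // [0, is); here ek is the extended (dilated) kernel size.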

    inline int get_id_start(int od_s) const {
        return get_inp_start(od_s, jcp.f_pad, jcp.stride_d);
    }
    inline int get_ih_start(int oh_s) const {
        return get_inp_start(oh_s, jcp.t_pad, jcp.stride_h);
    }

    inline int get_id_end(int od_e) const {
        return get_inp_end(od_e, jcp.id, jcp.f_pad, jcp.stride_d, jcp.ext_kd);
    }
    inline int get_ih_end(int oh_e) const {
        return get_inp_end(oh_e, jcp.ih, jcp.t_pad, jcp.stride_h, jcp.ext_kh);
    }

    size_t tr_src_buf_number(int g, int icb) const {
        return jcp.global_transpose
                ? ithr_mb * jcp.nb_ic * jcp.ngroups + g * jcp.nb_ic + icb
                : ithr;
    }

    size_t tr_diff_dst_buf_number(int g, int ocb) const {
        // For the current loop order (xoi), if jcp.tr_ocb_chunk is set then
        // the same area in the tr_diff_dst buffer can be reused.
        if (jcp.tr_ocb_chunk)
            return jcp.global_transpose
                    ? ((ithr_mb * jcp.ngroups + g) * jcp.nthr_oc_b + ithr_oc_b)
                            * jcp.nb_oc_blocking
                            + (ocb - oc_b_start) % jcp.nb_oc_blocking
                    : ithr;
        else
            return jcp.global_transpose
                    ? ithr_mb * jcp.nb_oc * jcp.ngroups + g * jcp.nb_oc + ocb
                    : ithr;
    }

    size_t tr_src_off(int g, int icb, int id, int ih) const {
        const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
        const size_t tr_3d_size = tr_row_size * jcp.ih;
        int adj = (jcp.global_transpose) ? 1 : jcp.nb_ic_blocking;
        // Aligned to the buffer end to make use of the guard elements
        return tr_src_buf_number(g, icb) * adj * jcp.tr_src_buf_size
                + id * tr_3d_size + ih * tr_row_size;
    }

    size_t tr_diff_dst_off(int g, int ocb, int od, int oh) const {
        const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
        const size_t tr_3d_size = tr_row_size * jcp.oh;
        int adj = (jcp.global_transpose) ? 1 : jcp.nb_oc_blocking;
        return tr_diff_dst_buf_number(g, ocb) * adj * jcp.tr_diff_dst_buf_size
                + od * tr_3d_size + oh * tr_row_size;
    }

    void trans_src_nxc(src_data_t *tr_src, const src_data_t *src_base,
            int spatial_start, dim_t spatial_start_offset, int icb_start,
            dim_t chb_stride, int row_count) const {
        const int src_stride = jcp.iw * jcp.ngroups * jcp.ic;
        const int tr_src_stride = jcp.tr_iw * jcp.ic_block;

        int work_rest = row_count;
        int max_spatial_work = jcp.id * jcp.ih;
        int sp_work = nstl::min(work_rest, max_spatial_work - spatial_start);
        const src_data_t *src = src_base + spatial_start_offset;
        int icb = 0;
        const int ic_tail_work = jcp.ic_tail ? jcp.ic_tail : jcp.ic_block;
        while (work_rest > 0) {
            for (int iwork = 0; iwork < sp_work; iwork++) {
                // For 1x1 convolutions with strides, transpose only the rows
                // that are actually needed
                if (IMPLICATION(jcp.kh == 1, iwork % jcp.stride_h == 0)) {
                    auto ctx = jit_trans_src_t::ctx_t();
                    ctx.src = src;
                    ctx.tr_src = tr_src;
                    assert(icb_start + icb < jcp.nb_ic);
                    ctx.ch_work = (icb_start + icb + 1) == jcp.nb_ic
                            ? ic_tail_work
                            : jcp.ic_block;
                    ctx.src_prf = nullptr;
                    ctx.tr_src_prf = nullptr;
                    (*self->trans_kernel_)(&ctx);
                }
                src += src_stride;
                tr_src += tr_src_stride;
            }
            work_rest -= sp_work;
            sp_work = nstl::min(work_rest, max_spatial_work);
            icb++;
            src = src_base + icb * chb_stride;
        }
    }

    void trans_dst_nxc(diff_dst_data_t *tr_diff_dst,
            const diff_dst_data_t *diff_dst_base, int spatial_start,
            dim_t spatial_start_offset, int ocb_start, dim_t chb_stride,
            int row_count) const {
        const int diff_dst_stride = jcp.ow * jcp.ngroups * jcp.oc;
        const int tr_diff_dst_stride = jcp.tr_ow * jcp.oc_block;
        int work_rest = row_count;
        int max_spatial_work = jcp.od * jcp.oh;
        int sp_work = nstl::min(work_rest, max_spatial_work - spatial_start);
        const diff_dst_data_t *diff_dst = diff_dst_base + spatial_start_offset;
        int ocb = 0;
        const int oc_tail_work = jcp.oc_tail ? jcp.oc_tail : jcp.oc_block;
        while (work_rest > 0) {
            for (int iwork = 0; iwork < sp_work; iwork++) {
                auto ctx = jit_trans_dst_t::ctx_t();
                ctx.src = diff_dst;
                ctx.tr_src = tr_diff_dst;
                assert(ocb_start + ocb < jcp.nb_oc);
                ctx.ch_work = (ocb_start + ocb + 1) == jcp.nb_oc ? oc_tail_work
                                                                 : jcp.oc_block;
                ctx.src_prf = nullptr;
                ctx.tr_src_prf = nullptr;
                (*self->trans_dst_kernel_)(&ctx);
                diff_dst += diff_dst_stride;
                tr_diff_dst += tr_diff_dst_stride;
            }
            work_rest -= sp_work;
            sp_work = nstl::min(work_rest, max_spatial_work);
            ocb++;
            diff_dst = diff_dst_base + ocb * chb_stride;
        }
    }

    void maybe_global_transpose(int img, int ocb_s, int ocb_e, int icb_s,
            int icb_e, int od_s, int odb_s, int odb_e, int oh_s, int ohb_s,
            int ohb_e) const {
        if (!jcp.global_transpose) return;

        using simple_barrier::barrier;
        const int icb_work = icb_e - icb_s;
        const int ocb_work = ocb_e - ocb_s;

        // The barrier must stay outside of the work condition to avoid a
        // possible hang
        if (jcp.nthr_oc_b > 1)
            barrier(&tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b);

        if (icb_work > 0) {
            const auto id_s = get_id_start(od_s);
            const auto ih_s = get_ih_start(oh_s);

            const auto idb_s = get_id_start(odb_s);
            const auto idb_e = get_id_end(odb_e);

            const auto ihb_s = get_ih_start(ohb_s);
            const auto ihb_e = get_ih_end(ohb_e);

            int work_amount
                    = g_work * icb_work * (idb_e - idb_s) * (ihb_e - ihb_s);
            int tr_start {0}, tr_end {0};
            balance211(work_amount, jcp.nthr_oc_b, ithr_oc_b, tr_start, tr_end);

            int g {0}, ic_b {0}, jd {0}, jh {0};
            nd_iterator_init(tr_start, g, g_work, ic_b, icb_work, jd,
                    idb_e - idb_s, jh, ihb_e - ihb_s);

            while (tr_start < tr_end) {
                int g_ = g + g_start;
                int ic_b_ = ic_b + icb_s;

                int jd_s = jd + idb_s;

                int jh_s = jh + ihb_s;
                int jh_e = jh_s + nstl::min(tr_end - tr_start, ihb_e - jh_s);

                const int ic_off_idx = g_ * jcp.ic + ic_b_ * jcp.ic_block;

                const src_data_t *p_src {nullptr};
                if (jcp.harness == harness_2d_reduction) {
                    p_src = &src[src_d.blk_off(img, ic_off_idx, jh_s)];
                } else if (jcp.harness == harness_3d_reduction) {
                    p_src = &src[src_d.blk_off(img, ic_off_idx, jd_s, jh_s)];
                } else
                    assert(!"Invalid harness type");

                src_data_t *p_tr_src = &tr_src[tr_src_off(
                        g_, ic_b_, jd_s - id_s, jh_s - ih_s)];
                trans_src_nxc(p_tr_src, p_src, 0, 0, ic_b_, 0, jh_e - jh_s);

                nd_iterator_jump(tr_start, tr_end, g, g_work, ic_b, icb_work,
                        jd, idb_e - idb_s, jh, ihb_e - ihb_s);
            }
        }
        if (jcp.nthr_oc_b > 1)
            barrier(&tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b);

        // The barrier must stay outside of the work condition to avoid a
        // possible hang
        if (jcp.nthr_ic_b > 1)
            barrier(&tr_diff_dst_bctx[ithr_but_ic], jcp.nthr_ic_b);

        if (ocb_work > 0) {
            int jd = 0;
            int jh = 0;
            int work_amount
                    = g_work * ocb_work * (odb_e - odb_s) * (ohb_e - ohb_s);
            int tr_start = 0;
            int tr_end = 0;
            balance211(work_amount, jcp.nthr_ic_b, ithr_ic_b, tr_start, tr_end);

            int g = 0;
            int oc_b = 0;
            nd_iterator_init(tr_start, g, g_work, oc_b, ocb_work, jd,
                    odb_e - odb_s, jh, ohb_e - ohb_s);

            while (tr_start < tr_end) {
                int g_ = g + g_start;
                int oc_b_ = oc_b + ocb_s;
                int jd_s = jd + odb_s;
                int jh_s = jh + ohb_s;
                int jh_e = jh_s + nstl::min(tr_end - tr_start, ohb_e - jh_s);
                const int oc_off_idx = g_ * jcp.oc + oc_b_ * jcp.oc_block;

                const diff_dst_data_t *p_diff_dst {nullptr};
                if (jcp.harness == harness_2d_reduction) {
                    p_diff_dst = &diff_dst[diff_dst_d.blk_off(
                            img, oc_off_idx, jh_s)];
                } else if (jcp.harness == harness_3d_reduction) {
                    p_diff_dst = &diff_dst[diff_dst_d.blk_off(
                            img, oc_off_idx, jd_s, jh_s)];
                } else
                    assert(!"Invalid harness type");

                diff_dst_data_t *p_tr_diff_dst = &tr_diff_dst[tr_diff_dst_off(
                        g_, oc_b_, jd_s - od_s, jh_s - oh_s)];
                trans_dst_nxc(
                        p_tr_diff_dst, p_diff_dst, 0, 0, oc_b_, 0, jh_e - jh_s);

                nd_iterator_jump(tr_start, tr_end, g, g_work, oc_b, ocb_work,
                        jd, odb_e - odb_s, jh, ohb_e - ohb_s);
            }
        }
        if (jcp.nthr_ic_b > 1)
            barrier(&tr_diff_dst_bctx[ithr_but_ic], jcp.nthr_ic_b);
    }

    void maybe_local_transpose(void *&p_src, void *&p_dst, int img, int g,
            int ic_b, int oc_b, int od_s, int odb_s, int odb_e, int oh_s,
            int ohb_s, int ohb_e) const {

        const int idb_s = get_id_start(odb_s);
        const int ihb_s = get_ih_start(ohb_s);

        const int idb_e = get_id_end(odb_e);
        const int ihb_e = get_ih_end(ohb_e);

        const int id_s = get_id_start(od_s);
        const int ih_s = get_ih_start(oh_s);

        if (jcp.global_transpose) {
            p_src = &tr_src[tr_src_off(g, ic_b, 0, 0)];
            p_dst = &tr_diff_dst[tr_diff_dst_off(g, oc_b, 0, 0)];
            return;
        }

        const int nb_ic_blocks = (ic_b + jcp.nb_ic_blocking > ic_b_end)
                ? 1
                : jcp.nb_ic_blocking;

        const int nb_oc_blocks = (oc_b + jcp.nb_oc_blocking > oc_b_end)
                ? 1
                : jcp.nb_oc_blocking;

        for_(int idb = idb_s; idb < idb_e; idb++)
        for (int icb = 0; icb < nb_ic_blocks; icb++) {
            const int ic_off_idx = g * jcp.ic + (ic_b + icb) * jcp.ic_block;
            src_data_t *p_tr_src
                    = &tr_src[tr_src_off(0, 0, idb - id_s, ihb_s - ih_s)];
            src_data_t *tr_src_local = p_tr_src + icb * jcp.tr_src_buf_size;
            const src_data_t *p_raw_src {nullptr};
            if (jcp.harness == harness_2d_reduction) {
                p_raw_src = (src_data_t
                                *)&src[src_d.blk_off(img, ic_off_idx, ihb_s)];
            } else if (jcp.harness == harness_3d_reduction) {
                p_raw_src = (src_data_t *)&src[src_d.blk_off(
                        img, ic_off_idx, idb, ihb_s)];
            } else
                assert(!"Invalid harness type");
            trans_src_nxc(tr_src_local, p_raw_src, 0, 0, (ic_b + icb), 0,
                    (ihb_e - ihb_s));
        }

        p_src = &tr_src[tr_src_off(0, 0, 0, 0)]; // p_tr_src;

        for_(int odb = odb_s; odb < odb_e; odb++)
        for (int ocb = 0; ocb < nb_oc_blocks; ocb++) {
            const int oc_off_idx = g * jcp.oc + (oc_b + ocb) * jcp.oc_block;
            const diff_dst_data_t *p_raw_diff_dst {nullptr};
            if (jcp.harness == harness_2d_reduction) {
                p_raw_diff_dst
                        = &diff_dst[diff_dst_d.blk_off(img, oc_off_idx, ohb_s)];
            } else if (jcp.harness == harness_3d_reduction) {
                p_raw_diff_dst = &diff_dst[diff_dst_d.blk_off(
                        img, oc_off_idx, odb, ohb_s)];
            } else
                assert(!"Invalid harness type");
            diff_dst_data_t *p_tr_diff_dst = &tr_diff_dst[tr_diff_dst_off(
                    0, 0, odb - od_s, ohb_s - oh_s)];
            diff_dst_data_t *tr_diff_dst_local
                    = p_tr_diff_dst + ocb * jcp.tr_diff_dst_buf_size;
            trans_dst_nxc(tr_diff_dst_local, p_raw_diff_dst, 0, 0, (oc_b + ocb),
                    0, (ohb_e - ohb_s));
        }
        p_dst = &tr_diff_dst[tr_diff_dst_off(0, 0, 0, 0)]; // p_tr_diff_dst;
    }

    bool just_init_output(
            int start, int end, float *diff_wei, float *diff_bias) {
        if (start < end || g_start >= g_end || oc_b_start >= oc_b_end
                || ic_b_start >= ic_b_end)
            return false;
        // In the rare case that the thread has no work along the spatial
        // dimension, the output still needs to be initialized.
        if (jcp.with_bias) {
            for_(int g = g_start; g < g_end; ++g)
            {
                void *p_bias = diff_bias + g * rnd_up(jcp.oc, jcp.oc_block)
                        + oc_b_start * jcp.oc_block;
                auto bias_amount = (oc_b_end - oc_b_start) * jcp.oc_block;
                std::memset(p_bias, 0, bias_amount * jcp.acc_dsz);
            }
        }

        for_(int g = g_start; g < g_end; ++g)
        for (int oc_b = oc_b_start; oc_b < oc_b_end; oc_b++) {
            auto wei_offs_ext = pd()->ndims() == 3
                    ? wht_blk_off(diff_weights_d, g, oc_b, ic_b_start, 0)
                    : (pd()->ndims() == 4 ? wht_blk_off(
                               diff_weights_d, g, oc_b, ic_b_start, 0, 0)
                                          : wht_blk_off(diff_weights_d, g, oc_b,
                                                  ic_b_start, 0, 0, 0));
            void *ptr_C = (jcp.transform_to_vnni)
                    ? diff_wei
                            + self->wei_offset_int(g, oc_b, ic_b_start, 0, 0, 0)
                    : diff_wei + wei_offs_ext;

            auto C_amount = jcp.kd * jcp.kh * jcp.kw * (ic_b_end - ic_b_start)
                    * jcp.ic_block * jcp.oc_block;

            std::memset(ptr_C, 0, C_amount * jcp.acc_dsz);
        }
        return true;
    }
};

void brgemm_convolution_bwd_weights_t::call_brgemm_kernel(
        thread_info_t &btc, int brg_idx, int batch_size, void *ptr_C) const {

    const auto brg_ker = brg_kernels_[brg_idx].get();
    assert(brg_ker != nullptr);

    // TODO: avoid costly tile reconfigurations
    if (std::memcmp(btc.cur_palette.a, brg_kernel_palettes_[brg_idx].a,
                AMX_PALETTE_SIZE)
            != 0) {
        amx_tile_configure(brg_kernel_palettes_[brg_idx].a);
        std::memcpy(btc.cur_palette.a, brg_kernel_palettes_[brg_idx].a,
                AMX_PALETTE_SIZE);
    }

    brgemm_kernel_execute(brg_ker, batch_size, btc.brg_batch, ptr_C,
            static_cast<void *>(btc.wsp_tile));
}

void brgemm_convolution_bwd_weights_t::compute_diff_weights_2d(
        thread_info_t *ti) const {

    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;

    const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
            * jcp.ic_block * jcp.kd * jcp.kh * jcp.kw;
    const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
    const int optimal_spblock = jcp.spatial_blk_size;

    float *diff_wei;
    if (diff_weights_d.data_type() != data_type::f32)
        diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size;
    else
        diff_wei = ti->ithr_mb == 0
                ? (float *)ti->diff_weights
                : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;

    float *diff_bias = nullptr;
    if (jcp.with_bias) {
        if (jcp.bia_dt != data_type::f32)
            diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size;
        else
            diff_bias = ti->ithr_mb == 0
                    ? (float *)ti->diff_bias
                    : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size;
    }

    int img {0}, oh_s {0};
    int start = ti->img_start;
    int end = ti->img_end;

    int img_s {0};
    nd_iterator_init(start, img_s, jcp.mb, oh_s, jcp.oh);
    img = img_s;

    auto do_brgemm_call = [&](int g, int bs, int ic_b, int oc_b, int ohb_s,
                                  int bs_ih_s, const void *p_src,
                                  const void *p_dst, int kh, int kw,
                                  bool do_init) {
        const int ihb_s = ti->get_ih_start(ohb_s);

        const int bs_oh_s = utils::saturate(0, jcp.oh,
                (bs_ih_s + jcp.t_pad - kh * (jcp.dilate_h + 1)) / jcp.stride_h);

        auto ocb_end = get_end(oc_b, jcp.nb_oc_blocking, ti->oc_b_end);
        auto icb_end = get_end(ic_b, jcp.nb_ic_blocking, ti->ic_b_end);
        const int src_stride_w_shift = jcp.tr_iw / jcp.stride_w;
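        // The transposed src rows appear to be grouped by (kw % stride_w),
        // each group holding tr_iw / stride_w pixels, so ptr_A below selects
        // the group for this kw plus the starting input row (a reading of the
        // layout produced by the transpose kernels).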
        const void *ptr_A = ((src_data_t *)p_src)
                + _pd->filter_w_to_src(kw) / jcp.stride_w
                + (kw % jcp.stride_w) * src_stride_w_shift
                + (bs_ih_s - ihb_s) * jcp.tr_iw * jcp.ic_block;
        const void *ptr_B = ((diff_dst_data_t *)p_dst)
                + (bs_oh_s - ohb_s) * jcp.tr_ow * jcp.oc_block;

        void *ptr_C = (jcp.transform_to_vnni)
                ? diff_wei + wei_offset_int(g, oc_b, ic_b, 0, kh, kw)
                : diff_wei
                        + (pd()->ndims() == 3 ? wht_blk_off(
                                   diff_weights_d, g, oc_b, ic_b, kw)
                                              : wht_blk_off(diff_weights_d, g,
                                                      oc_b, ic_b, kh, kw));
        bool M_tail = (icb_end < ic_b + jcp.nb_ic_blocking);
        bool N_tail = (ocb_end < oc_b + jcp.nb_oc_blocking);

        auto brg_idx = _pd->get_brg_idx(
                bs, M_tail ? jcp.M_tail : jcp.M, do_init, N_tail, false);

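        // One batch element per output row: A advances by stride_h transposed
        // input rows per step, B by one transposed diff_dst row.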
        for (int ohb = 0; ohb < bs; ohb++) {
            ti->brg_batch[ohb].ptr.A = (char *)ptr_A
                    + ohb * jcp.typesize_in * jcp.tr_iw * jcp.ic_block
                            * jcp.stride_h;
            ti->brg_batch[ohb].ptr.B = (char *)ptr_B
                    + ohb * jcp.typesize_in * jcp.tr_ow * jcp.oc_block;
        }

        call_brgemm_kernel(*ti, brg_idx, bs, ptr_C);
    };

    if (ti->just_init_output(start, end, diff_wei, diff_bias)) return;

    while (start < end) {
        const int oh_e = _pd->get_finish_oh(
                oh_s, start, get_end(start, jcp.oh_block, end));
        int height_block = jcp.global_transpose ? oh_e - oh_s : optimal_spblock;

        // The loop over ohb_s has only one iteration in the global_transpose
        // case because height_block == oh_e - oh_s
        for (int ohb_s = oh_s; ohb_s < oh_e; ohb_s += height_block) {
            const int ohb_e = get_end(ohb_s, height_block, oh_e);
            assert(ohb_e <= jcp.oh);

            ti->maybe_global_transpose(img,
                    jcp.tr_ocb_chunk ? 0 : ti->oc_b_start,
                    jcp.tr_ocb_chunk ? 0 : ti->oc_b_end,
                    jcp.tr_icb_chunk ? 0 : ti->ic_b_start,
                    jcp.tr_icb_chunk ? 0 : ti->ic_b_end, 0, 0, 1, oh_s, ohb_s,
                    ohb_e);

            for_(int g = ti->g_start; g < ti->g_end; ++g)
            for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end;
                    oc_b += jcp.nb_oc_blocking) {
                const int oc_b_e
                        = get_end(oc_b, jcp.nb_oc_blocking, ti->oc_b_end);

                if (jcp.tr_ocb_chunk)
                    ti->maybe_global_transpose(img, oc_b, oc_b_e, 0, 0, 0, 0, 1,
                            oh_s, ohb_s, ohb_e);

                for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end;
                        ic_b += jcp.nb_ic_blocking) {

                    const int ic_b_e
                            = get_end(ic_b, jcp.nb_ic_blocking, ti->ic_b_end);

                    if (oc_b == ti->oc_b_start && jcp.tr_icb_chunk)
                        ti->maybe_global_transpose(img, 0, 0, ic_b, ic_b_e, 0,
                                0, 1, oh_s, ohb_s, ohb_e);

                    void *p_src {nullptr};
                    void *p_dst {nullptr};
                    ti->maybe_local_transpose(p_src, p_dst, img, g, ic_b, oc_b,
                            0, 0, 1, oh_s, ohb_s, ohb_e);

                    if (jcp.with_bias && ic_b == 0) {
                        auto bp = jit_conv_call_s();

                        bp.bias = diff_bias + g * rnd_up(jcp.oc, jcp.oc_block)
                                + oc_b * jcp.oc_block;
                        bp.channel
                                = (start == ti->img_start) && (ohb_s == oh_s);

                        bp.os_index_begin = ohb_s;
                        bp.os_index_end = ohb_e;

                        bp.last_oc_block
                                = ((oc_b_e - oc_b) == jcp.nb_oc_blocking) ? 0
                                                                          : 1;

                        bp.dst = p_dst;

                        (*diff_bias_kernel_)(&bp);
                    }

                    if (ti->g_start == ti->g_end
                            || ti->oc_b_start == ti->oc_b_end
                            || ti->ic_b_start == ti->ic_b_end)
                        continue;

                    const auto do_init = (start == ti->img_start);

                    for (int kh = 0; kh < jcp.kh; kh++) {
                        const int bs_ih_s = _pd->get_start_ih(kh, ohb_s);
                        const int bs_ih_e = _pd->get_finish_ih(kh, ohb_e);
                        const auto bs = div_up(bs_ih_e - bs_ih_s, jcp.stride_h);
                        if (bs == 0 && !do_init) continue;

                        for_(int s = 0; s < jcp.stride_w; s++)
                        for (int kw = s; kw < jcp.kw; kw += jcp.stride_w)
                            do_brgemm_call(g, bs, ic_b, oc_b, ohb_s, bs_ih_s,
                                    p_src, p_dst, kh, kw, do_init);
                    }
                }
            }
        }

        nd_iterator_jump(start, get_end(start, jcp.oh_block, end), img, jcp.mb,
                oh_s, jcp.oh);
    }
}

void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d(
        thread_info_t *ti) const {

    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;

    const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
            * jcp.ic_block * jcp.kd * jcp.kh * jcp.kw;
    const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
    const int optimal_spblock = jcp.spatial_blk_size;

    float *diff_wei;
    if (diff_weights_d.data_type() != data_type::f32)
        diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size;
    else
        diff_wei = ti->ithr_mb == 0
                ? (float *)ti->diff_weights
                : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;

    float *diff_bias = nullptr;
    if (jcp.with_bias) {
        if (jcp.bia_dt != data_type::f32)
            diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size;
        else
            diff_bias = ti->ithr_mb == 0
                    ? (float *)ti->diff_bias
                    : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size;
    }

    int img {0}, od_s {0};
    int start = ti->img_start;
    int end = ti->img_end;

    int img_s {0};
    nd_iterator_init(start, img_s, jcp.mb, od_s, jcp.od);
    img = img_s;

    auto do_brgemm_call = [&](int g, int bs_d, int bs_h, int ic_b, int oc_b,
                                  int od_s, int oh_s, int bs_id_s, int bs_ih_s,
                                  const void *p_src, const void *p_dst, int kd,
                                  int kh, int kw, bool do_init) {
        const int id_s = ti->get_id_start(od_s);
        const int ih_s = ti->get_ih_start(oh_s);

        const int bs_od_s = utils::saturate(0, jcp.od,
                (bs_id_s + jcp.f_pad - kd * (jcp.dilate_d + 1)) / jcp.stride_d);

        const int bs_oh_s = utils::saturate(0, jcp.oh,
                (bs_ih_s + jcp.t_pad - kh * (jcp.dilate_h + 1)) / jcp.stride_h);

        auto ocb_end = get_end(oc_b, jcp.nb_oc_blocking, ti->oc_b_end);
        auto icb_end = get_end(ic_b, jcp.nb_ic_blocking, ti->ic_b_end);
        const int src_stride_w_shift = jcp.tr_iw / jcp.stride_w;
        const void *ptr_A = ((src_data_t *)p_src)
                + _pd->filter_w_to_src(kw) / jcp.stride_w
                + (kw % jcp.stride_w) * src_stride_w_shift
                + (bs_ih_s - ih_s) * jcp.tr_iw * jcp.ic_block
                + (bs_id_s - id_s) * jcp.ih * jcp.tr_iw * jcp.ic_block;
        const void *ptr_B = ((diff_dst_data_t *)p_dst)
                + (bs_oh_s - oh_s) * jcp.tr_ow * jcp.oc_block
                + (bs_od_s - od_s) * jcp.oh * jcp.tr_ow * jcp.oc_block;
        void *ptr_C = (jcp.transform_to_vnni)
                ? diff_wei + wei_offset_int(g, oc_b, ic_b, kd, kh, kw)
                : diff_wei
                        + wht_blk_off(
                                diff_weights_d, g, oc_b, ic_b, kd, kh, kw);
        bool M_tail = (icb_end < ic_b + jcp.nb_ic_blocking);
        bool N_tail = (ocb_end < oc_b + jcp.nb_oc_blocking);

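        // The brgemm batch walks the (od, oh) grid: bs_d depth steps times
        // bs_h height steps, with batch element (odb, ohb) at odb * bs_h + ohb.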
        const auto bs = bs_d * bs_h;
        auto brg_idx = _pd->get_brg_idx(
                bs, M_tail ? jcp.M_tail : jcp.M, do_init, N_tail, false);

        for (int odb = 0; odb < bs_d; odb++) {
            for (int ohb = 0; ohb < bs_h; ohb++) {
                ti->brg_batch[odb * bs_h + ohb].ptr.A = (char *)ptr_A
                        + ohb * jcp.typesize_in * jcp.tr_iw * jcp.ic_block
                                * jcp.stride_h
                        + odb * jcp.typesize_in * jcp.ih * jcp.tr_iw
                                * jcp.ic_block * jcp.stride_d;
                ti->brg_batch[odb * bs_h + ohb].ptr.B = (char *)ptr_B
                        + ohb * jcp.typesize_in * jcp.tr_ow * jcp.oc_block
                        + odb * jcp.typesize_in * jcp.oh * jcp.tr_ow
                                * jcp.oc_block;
            }
        }

        call_brgemm_kernel(*ti, brg_idx, bs, ptr_C);
    };

    if (ti->just_init_output(start, end, diff_wei, diff_bias)) return;

    const auto oh_s = 0;
    const auto oh_e = jcp.oh;

    while (start < end) {
        const int od_e = _pd->get_finish_od(
                od_s, start, get_end(start, jcp.od_block, end));
        int sp_block = jcp.global_transpose ? od_e - od_s : optimal_spblock;

        // The loop over odb_s has only one iteration in the global_transpose
        // case because sp_block == od_e - od_s
        for (int odb_s = od_s; odb_s < od_e; odb_s += sp_block) {
            const int odb_e = get_end(odb_s, sp_block, od_e);
            assert(odb_e <= jcp.od);

            for (int ohb_s = oh_s; ohb_s < oh_e; ohb_s += jcp.oh_block) {
                const auto ohb_e = get_end(ohb_s, jcp.oh_block, jcp.oh);

                ti->maybe_global_transpose(img,
                        jcp.tr_ocb_chunk ? 0 : ti->oc_b_start,
                        jcp.tr_ocb_chunk ? 0 : ti->oc_b_end,
                        jcp.tr_icb_chunk ? 0 : ti->ic_b_start,
                        jcp.tr_icb_chunk ? 0 : ti->ic_b_end, od_s, odb_s, odb_e,
                        oh_s, ohb_s, ohb_e);

                for_(int g = ti->g_start; g < ti->g_end; ++g)
                for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end;
                        oc_b += jcp.nb_oc_blocking) {
                    const int oc_b_e
                            = get_end(oc_b, jcp.nb_oc_blocking, ti->oc_b_end);
                    if (jcp.tr_ocb_chunk)
                        ti->maybe_global_transpose(img, oc_b, oc_b_e, 0, 0,
                                od_s, odb_s, odb_e, oh_s, ohb_s, ohb_e);

                    for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end;
                            ic_b += jcp.nb_ic_blocking) {

                        const int ic_b_e = get_end(
                                ic_b, jcp.nb_ic_blocking, ti->ic_b_end);

                        if (oc_b == ti->oc_b_start && jcp.tr_icb_chunk)
                            ti->maybe_global_transpose(img, 0, 0, ic_b, ic_b_e,
                                    od_s, odb_s, odb_e, oh_s, ohb_s, ohb_e);

                        void *p_src {nullptr};
                        void *p_dst {nullptr};
                        ti->maybe_local_transpose(p_src, p_dst, img, g, ic_b,
                                oc_b, od_s, odb_s, odb_e, oh_s, ohb_s, ohb_e);

                        if (jcp.with_bias && ic_b == 0) {
                            for (int iodb = odb_s; iodb < odb_e; iodb++) {
                                auto bp = jit_conv_call_s();

                                bp.bias = diff_bias
                                        + g * rnd_up(jcp.oc, jcp.oc_block)
                                        + oc_b * jcp.oc_block;
                                bp.os_index_begin = ohb_s;
                                bp.os_index_end = ohb_e;

                                bp.last_oc_block
                                        = ((oc_b_e - oc_b)
                                                  == jcp.nb_oc_blocking)
                                        ? 0
                                        : 1;

                                bp.channel = (start == ti->img_start)
                                        && (odb_s == od_s) && (iodb == odb_s)
                                        && (ohb_s == oh_s);
                                bp.dst = ((diff_dst_data_t *)p_dst)
                                        + (iodb - od_s) * jcp.oh * jcp.tr_ow
                                                * jcp.oc_block
                                        + (ohb_s - oh_s) * jcp.tr_ow
                                                * jcp.oc_block;
                                (*diff_bias_kernel_)(&bp);
                            }
                        }

                        if (ti->g_start == ti->g_end
                                || ti->oc_b_start == ti->oc_b_end
                                || ti->ic_b_start == ti->ic_b_end)
                            continue;

                        const auto do_init
                                = (start == ti->img_start && ohb_s == oh_s);

                        for (int kd = 0; kd < jcp.kd; kd++) {
                            const int bs_id_s = _pd->get_start_id(kd, odb_s);
                            const int bs_id_e = _pd->get_finish_id(kd, odb_e);
                            const auto bs_d
                                    = div_up(bs_id_e - bs_id_s, jcp.stride_d);
                            // bs_d may be 0, but brgemm may still need to be
                            // called to initialize the output
                            if (bs_d == 0 && !do_init) continue;

                            for (int kh = 0; kh < jcp.kh; kh++) {
                                const int bs_ih_s
                                        = _pd->get_start_ih(kh, ohb_s);
                                const int bs_ih_e
                                        = _pd->get_finish_ih(kh, ohb_e);
                                const auto bs_h = div_up(
                                        bs_ih_e - bs_ih_s, jcp.stride_h);
                                if (bs_h == 0 && !do_init) continue;

                                for_(int s = 0; s < jcp.stride_w; s++)
                                for (int kw = s; kw < jcp.kw;
                                        kw += jcp.stride_w)
                                    do_brgemm_call(g, bs_d, bs_h, ic_b, oc_b,
                                            od_s, oh_s, bs_id_s, bs_ih_s, p_src,
                                            p_dst, kd, kh, kw, do_init);
                            }
                        }
                    }
                }
            }
        }

        nd_iterator_jump(start, get_end(start, jcp.od_block, end), img, jcp.mb,
                od_s, jcp.od);
    }
}

void brgemm_convolution_bwd_weights_t::store_in_vnni_format(
        thread_info_t *ti) const {
    const auto &jcp = pd()->jcp_;

    for_(int g = ti->g_start; g < ti->g_end; g++)
    for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; oc_b++)
    for_(int ic_b = ti->ic_b_start; ic_b < ti->ic_b_start + ti->ic_b_work;
            ic_b += 2)
    {
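        // The VNNI transform appears to consume two ic blocks at a time (note
        // the step of 2 above and the ic_b / 2 in the external offset below);
        // last_ic_block flags the pair that contains the final ic block.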
        jit_conv_call_s p = jit_conv_call_s();

        bfloat16_t *output = (bfloat16_t *)ti->diff_weights
                + wei_offset_ext(g, oc_b, (ic_b / 2), 0);
        float *input = ti->wei_bia_reduction + wei_offset_int(g, oc_b, ic_b, 0);

        p.src = (void *)input;
        p.dst = (void *)output;
        p.last_ic_block = ((ic_b + 1) >= jcp.nb_ic) ? 1 : 0;
        (*diff_wei_trans_kernel_)(&p);
    }
}

void brgemm_convolution_bwd_weights_t::reduce_and_convert_diff_weights_and_bias(
        thread_info_t *ti) const {

    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = pd()->jcp_;
    const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
            * jcp.ic_block * jcp.kh * jcp.kw * ((jcp.ndims == 5) ? jcp.kd : 1);

    const auto wei_dt = diff_weights_d.data_type();
    const auto bia_dt = jcp.bia_dt;
    const bool is_f32_out = wei_dt == data_type::f32;
    const bool is_f32_bias = bia_dt == data_type::f32;

    if (jcp.nthr_mb == 1) {
        if (!is_f32_out) {
            // reduction is not required, only conversion
            if (jcp.transform_to_vnni) {
                store_in_vnni_format(ti);
            } else {
                for_(int g = ti->g_start; g < ti->g_end; g++)
                for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; oc_b++) {
                    const size_t acc_size = (size_t)ti->ic_b_work * jcp.kh
                            * jcp.kw * ((jcp.ndims == 5) ? jcp.kd : 1)
                            * jcp.ic_block * jcp.oc_block;
                    const size_t off = wht_blk_off(
                            diff_weights_d, g, oc_b, ti->ic_b_start);
                    types::cvt_from_float(wei_dt,
                            (void *)((char *)ti->diff_weights
                                    + off * types::data_type_size(wei_dt)),
                            (ti->wei_bia_reduction + off), acc_size);
                }
            }
        }
        if (pd()->with_bias() && !is_f32_bias && ti->ithr_ic_b == 0
                && ti->ic_b_work > 0) {
            for (int g = ti->g_start; g < ti->g_end; g++) {
                int result_start_idx
                        = g * jcp.oc + ti->oc_b_start * jcp.oc_block;
                int buffer_start_idx = g * rnd_up(jcp.oc, jcp.oc_block)
                        + ti->oc_b_start * jcp.oc_block;
                const size_t acc_size
                        = nstl::min(jcp.oc, ti->oc_b_end * jcp.oc_block)
                        - ti->oc_b_start * jcp.oc_block;
                void *diff_bias = (char *)ti->diff_bias
                        + result_start_idx * types::data_type_size(bia_dt);
                float *buffer = ti->bia_reduction + buffer_start_idx;
                types::cvt_from_float(
                        bia_dt, diff_bias, (const float *)buffer, acc_size);
            }
        }
        return;
    }

    /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
    if (jcp.global_transpose)
        simple_barrier::barrier(ti->wei_bia_reduction_bctx, jcp.nthr);

    const int ic_b_kh_work
            = ti->ic_b_work * ((jcp.ndims == 5) ? jcp.kd : jcp.kh);
    if (ic_b_kh_work <= 0 || ti->oc_b_work == 0 || ti->g_work == 0) {
        // TODO: double check if a barrier is needed here
        // and at the end of function
        if (jcp.transform_to_vnni && jcp.global_transpose)
            simple_barrier::barrier(ti->wei_bia_reduction_bctx, jcp.nthr);
        return;
    }

    const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;

    int start {0}, end {0};
    balance211(work, jcp.nthr_mb, ti->ithr_mb, start, end);
    if (!jcp.transform_to_vnni && start == end) return;

    const int _start_nthr_mb = 1;
    for (int thr_mb = _start_nthr_mb; thr_mb < jcp.nthr_mb; ++thr_mb) {
        int w = start;
        int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_kh_start {0};
        nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
                ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        while (w < end) {
            const int g = ti->g_start + sub_g_start;
            const int oc_b = ti->oc_b_start + sub_oc_b_start;
            const int ic_b = ti->ic_b_start
                    + sub_ic_b_kh_start / ((jcp.ndims == 5) ? jcp.kd : jcp.kh);
            const int kX
                    = sub_ic_b_kh_start % ((jcp.ndims == 5) ? jcp.kd : jcp.kh);

            const size_t acc_size = (size_t)jcp.kw * jcp.ic_block * jcp.oc_block
                    * ((jcp.ndims == 5) ? jcp.kh : 1)
                    * nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start);

            const size_t off_ext
                    = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kX);
            const size_t off_int = (jcp.transform_to_vnni)
                    ? wei_offset_int(g, oc_b, ic_b, kX)
                    : off_ext;

            float *wei_reduced = is_f32_out
                    ? (float *)(ti->diff_weights) + off_ext
                    : ti->wei_bia_reduction + off_int;

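            // For f32 output the master copy is diff_weights itself, so
            // reduction buffer 0 belongs to thr_mb == 1; for low-precision
            // output every thread, including thr_mb == 0, has its own buffer.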
            int thr_mb_buffer_idx = is_f32_out ? thr_mb - 1 : thr_mb;
            float *wei_to_reduce = ti->wei_bia_reduction
                    + thr_mb_buffer_idx * wei_size + off_int;

            if (!jcp.transform_to_vnni && !is_f32_out
                    && thr_mb == jcp.nthr_mb - 1) {
                // the last iteration for low-precision weights (bf16/f16)
                // requires conversion and a store to the diff_weights array
                if (wei_dt == bf16)
                    add_floats_and_cvt_to_bfloat16(
                            (bfloat16_t *)(ti->diff_weights) + off_ext,
                            wei_reduced, wei_to_reduce, acc_size);
                else if (wei_dt == f16)
                    add_floats_and_cvt_to_float16(
                            (float16_t *)(ti->diff_weights) + off_ext,
                            wei_reduced, wei_to_reduce, acc_size);
            } else
                acc_ker_->accumulate(wei_reduced, wei_to_reduce, acc_size);

            nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
                    ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        }
        if (jcp.with_bias && ti->ithr_ic_b == 0 && ti->ic_b_work > 0
                && ti->ithr_mb == 0 && ti->img_work > 0) {
            for (int g = ti->g_start; g < ti->g_end; g++) {
                float *bias_reduced = is_f32_bias ? (float *)(ti->diff_bias)
                                                  : ti->bia_reduction;
                int thr_mb_buffer_idx = is_f32_bias ? thr_mb - 1 : thr_mb;
                int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
                float *bias_to_reduce
                        = ti->bia_reduction + thr_mb_buffer_idx * bias_buf_size;
                const size_t acc_size
                        = nstl::min(jcp.oc, ti->oc_b_end * jcp.oc_block)
                        - ti->oc_b_start * jcp.oc_block;
                int idx = g * rnd_up(jcp.oc, jcp.oc_block)
                        + ti->oc_b_start * jcp.oc_block;
                if (!is_f32_bias && thr_mb == jcp.nthr_mb - 1) {
                    // the last iteration for low-precision bias (bf16/f16)
                    // requires conversion and a store to the diff_bias array
                    int diff_bias_idx
                            = g * jcp.oc + ti->oc_b_start * jcp.oc_block;
                    if (bia_dt == bf16)
                        add_floats_and_cvt_to_bfloat16(
                                (bfloat16_t *)(ti->diff_bias) + diff_bias_idx,
                                &bias_reduced[idx], &bias_to_reduce[idx],
                                acc_size);
                    else if (bia_dt == f16)
                        add_floats_and_cvt_to_float16(
                                (float16_t *)(ti->diff_bias) + diff_bias_idx,
                                &bias_reduced[idx], &bias_to_reduce[idx],
                                acc_size);
                } else {
                    acc_ker_->accumulate(
                            &bias_reduced[idx], &bias_to_reduce[idx], acc_size);
                }
            }
        }
    }

    if (jcp.transform_to_vnni && jcp.global_transpose) {
        simple_barrier::barrier(ti->wei_bia_reduction_bctx, jcp.nthr);
        store_in_vnni_format(ti);
    }
}

void brgemm_convolution_bwd_weights_t::prepare_scratchpad_data(
        const exec_ctx_t &ctx) const {
    auto scratchpad = ctx.get_scratchpad_grantor();

    const auto &jcp = pd()->jcp_;

    auto tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
    // Zero out guard elements that cross a buffer boundary to prevent a
    // race condition caused by buffer overflows from the memory optimization
    // where buffers share padding
    // TODO: optimize it
    for (size_t isb = 1; isb <= jcp.tr_src_buf_count; ++isb) {
        src_data_t *ts
                = &tr_src[isb * jcp.tr_src_buf_size * jcp.nb_ic_blocking];
        for (int i = 0; i < jcp.tr_src_num_guard_elems; ++i)
            ts[i] = 0;
    }

    if (jcp.global_transpose && jcp.nthr_oc_b > 1) {
        const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
        auto tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_tr_src_bctx);
        for (int i = 0; i < tr_src_bctx_size; ++i)
            simple_barrier::ctx_init(&tr_src_bctx[i]);
    }
    if (jcp.global_transpose) {
        if (jcp.nthr_ic_b > 1) {
            const int tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
            auto tr_diff_dst_bctx
                    = scratchpad.template get<simple_barrier::ctx_t>(
                            key_conv_tr_diff_dst_bctx);
            for (int i = 0; i < tr_diff_dst_bctx_size; ++i)
                simple_barrier::ctx_init(&tr_diff_dst_bctx[i]);
        }
    }

    if (jcp.nthr_mb > 1
            || pd()->diff_weights_md(0)->data_type != data_type::f32) {
        // TODO: don't use barrier for case
        // diff_weights_type != data_type::f32 && nthr_mb_ == 1
        simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_wei_bia_reduction_bctx));
    }
}

void brgemm_convolution_bwd_weights_t::execute_backward_weights(
        const exec_ctx_t &ctx) const {
    prepare_scratchpad_data(ctx);

    const auto &jcp = pd()->jcp_;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        assert(jcp.nthr == nthr);
        assert(utils::one_of(pd()->ndims(), 3, 4, 5));

        thread_info_t thread_info(this, ctx, ithr);
        switch (jcp.harness) {
            case harness_2d_reduction:
                compute_diff_weights_2d(&thread_info);
                if (jcp.global_transpose)
                    reduce_and_convert_diff_weights_and_bias(&thread_info);
                break;
            case harness_3d_reduction:
                compute_diff_weights_3d(&thread_info);
                if (jcp.global_transpose)
                    reduce_and_convert_diff_weights_and_bias(&thread_info);
                break;
            default: assert(!"Invalid harness type");
        }

        amx_tile_release();
    });

    if (!jcp.global_transpose) {
        parallel(jcp.nthr, [&](const int ithr, const int nthr) {
            assert(jcp.nthr == nthr);
            thread_info_t thread_info(this, ctx, ithr);
            reduce_and_convert_diff_weights_and_bias(&thread_info);
        });
    }

    if (jcp.transform_to_vnni && !jcp.global_transpose) {
        parallel(jcp.nthr, [&](const int ithr, const int nthr) {
            assert(jcp.nthr == nthr);
            thread_info_t thread_info(this, ctx, ithr);
            store_in_vnni_format(&thread_info);
        });
    }

    if (pd()->with_bias() && (jcp.oc % jcp.oc_block != 0)
            && jcp.bia_dt == data_type::f32) {
        auto diff_bias = ctx.get_scratchpad_grantor().template get<const float>(
                key_conv_padded_bias);
        auto diff_bias_in = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS);
        const int padded_stride = rnd_up(jcp.oc, jcp.oc_block);
        const int stride = jcp.oc;
        for (int g = 0; g < jcp.ngroups; ++g) {
            utils::array_copy(diff_bias_in + g * stride,
                    diff_bias + g * padded_stride, stride);
        }
    }
}

} // namespace x64

} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s