1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <atomic>
18
19#include "oneapi/dnnl/dnnl_types.h"
20
21#include "common/bfloat16.hpp"
22#include "common/c_types_map.hpp"
23#include "common/dnnl_thread.hpp"
24#include "common/type_helpers.hpp"
25#include "common/utils.hpp"
26#include "cpu/x64/gemm_bf16_convolution.hpp"
27#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace cpu {
32namespace x64 {
33
34using namespace dnnl::impl::status;
35using namespace dnnl::impl::memory_tracking::names;
36using namespace dnnl::impl::utils;
37using namespace dnnl::impl::cpu::x64::bf16_support;
38
39// The two stand-alone functions below were moved out of execute_backward_data
40// and execute_backward_weights to avoid warnings with gcc 6.x and 7.x compilers:
41// "declared with greater visibility than the type of its field",
42// emitted when one lambda function is declared within another one.
43void store_bfloat16_in_parallel(bfloat16_t *output_data, const float *acc_data,
44 size_t parallel_work, size_t parallel_work_size, bool do_in_parallel) {
45 parallel(do_in_parallel ? 0 : 1, [&](const int ithr, const int nthr) {
46 size_t start = 0, end = 0;
47 balance211(parallel_work, nthr, ithr, start, end);
48 if (start < end)
49 cvt_float_to_bfloat16(&output_data[start * parallel_work_size],
50 &acc_data[start * parallel_work_size],
51 (end - start) * parallel_work_size);
52 });
53}
54
55void cvt_acc_to_dst(const conv_gemm_conf_t &jcp, size_t g_start, size_t g_end,
56 const float *acc_base, bfloat16_t *diff_weights) {
57 const size_t parallel_work_size = jcp.ic * jcp.ks;
58 parallel(jcp.nthr == 1 ? 0 : 1, [&](const int ithr, const int nthr) {
59 size_t w_start = 0, w_end = 0;
60 balance211(parallel_work_size, nthr, ithr, w_start, w_end);
61 for_(auto w = w_start; w < w_end; ++w)
62 for (auto g = g_start; g < g_end; ++g) {
63 const float *__restrict acc_ptr
64 = acc_base + (w * jcp.ngroups + g) * jcp.oc;
65 bfloat16_t *__restrict dw_ptr
66 = diff_weights + (w * jcp.ngroups + g) * jcp.oc;
67 cvt_float_to_bfloat16(dw_ptr, acc_ptr, jcp.oc);
68 }
69 });
70}
71
72template <data_type_t dst_data_type>
73gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::pp_ker_t(const pd_t *pd)
74 : jit_generator(jit_name())
75 , jcp_(pd->jcp_)
76 , do_sum_(dst_data_type != data_type::f32 && jcp_.with_sum)
77 , max_data_reg_idx_(31)
78 , max_unroll_(12)
79 , compute_reg_step_(1)
80 , data_reg_base_idx_(0) {
81 using namespace types;
82 using namespace Xbyak;
83
84 if (!mayiuse(avx512_core))
85 // bf16 is not supported
86 return;
87
88 const auto &post_ops = jcp_.post_ops;
89 if (jcp_.with_eltwise || jcp_.with_binary) {
90#define PARAM_OFF(field) offsetof(ker_args, field)
91 static constexpr bool preserve_gpr = true;
92 static constexpr bool preserve_vmm = true;
93 static constexpr size_t helper_vmm_idx = 31;
94 static constexpr size_t tail_size = 1;
95 static constexpr bool use_exact_tail_scalar_bcast = false;
96 const binary_injector::rhs_arg_static_params_t rhs_arg_static_params {
97 helper_vmm_idx, reserved_eltwise_gpr, r14, r15, preserve_gpr,
98 preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec),
99 PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()),
100 tail_size, kreg_rem_mask, use_exact_tail_scalar_bcast};
101 const binary_injector::static_params_t binary_static_params {
102 this->reg_param, rhs_arg_static_params};
103 static constexpr bool save_state = true;
104 const eltwise_injector::static_params_t eltwise_static_params {
105 save_state, reserved_eltwise_gpr, reserved_eltwise_maskr};
106
107 postops_injector_ = utils::make_unique<
108 injector::jit_uni_postops_injector_t<avx512_core>>(
109 this, post_ops, binary_static_params, eltwise_static_params);
110#undef PARAM_OFF
111 }
112
113 if (do_sum_) {
114 compute_reg_step_ = 2;
115 vreg_sum_scale = Zmm(data_reg_base_idx_++);
116 }
117
118 if (jcp_.with_bias) vreg_bias = Zmm(data_reg_base_idx_++);
119
120 vlen_ = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
121
122 isa_ = mayiuse(avx512_core_bf16) ? avx512_core_bf16
123 : bf16_emulation_t::get_isa();
124
125 if (isa_ != avx512_core_bf16) {
126 max_data_reg_idx_ = 26;
127 bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
128 bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
129 bf16_emu_reserv_4, bf16_emu_reserv_5, bf16_emu_reserv_6);
130 }
131
132 max_unroll_
133 = (max_data_reg_idx_ - data_reg_base_idx_ + 1) / compute_reg_step_;
134}
135
136template <data_type_t dst_data_type>
137void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::apply_postops(
138 const bool apply_mask, const size_t out_offset, const int vmm_idx) {
139#define PARAM_OFF(x) offsetof(ker_args, x)
140 if (jcp_.with_eltwise || jcp_.with_binary) {
141 if (jcp_.with_binary) {
142 binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
143 rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst);
144 rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
145 vmm_idx, out_offset * sizeof(dst_data_t));
146 if (apply_mask) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
147
148 postops_injector_->compute_vector(vmm_idx, rhs_arg_params);
149 } else
150 postops_injector_->compute_vector(vmm_idx);
151 }
152#undef PARAM_OFF
153}
154
// Emits the JIT post-processing kernel.
//
// As written below, the generated code runs an outer loop `oc_work` times.
// Each outer iteration walks `spatial_length` accumulator elements in vector
// chunks: load the accumulator, add the bias (if any), fuse the scaled
// previous destination value (if do_sum_), run the injected post-ops, then
// store the result as f32 or bf16. After each outer iteration the dst/acc
// base pointers advance by the byte strides read from ker_args, and the bias
// pointer (if any) advances by one accumulator element.
155template <data_type_t dst_data_type>
156void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::generate() {
157 using namespace Xbyak;
158 using namespace utils;
159
160 preamble();
161
 // NOTE(review): presumably the argument pointer arrives in rcx under the
 // Windows x64 calling convention -- confirm against reg_param's definition.
162#ifdef _WIN32
163 mov(reg_param, rcx);
164#endif
165
 // Load the kernel arguments from the ker_args structure.
166#define PARAM_OFF(x) offsetof(ker_args, x)
167 mov(reg_dst_base, ptr[reg_param + PARAM_OFF(dst)]);
168 mov(reg_acc_base, ptr[reg_param + PARAM_OFF(acc)]);
169 if (jcp_.with_bias) mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
170 mov(reg_dst_str, ptr[reg_param + PARAM_OFF(dst_stride_in_bytes)]);
171 mov(reg_acc_str, ptr[reg_param + PARAM_OFF(acc_stride_in_bytes)]);
172 mov(reg_len, ptr[reg_param + PARAM_OFF(spatial_length)]);
173 mov(reg_oc_iter, ptr[reg_param + PARAM_OFF(oc_work)]);
174
175 if (jcp_.with_binary) {
176 // zero initialize binary post_ops offset accumulator (store on stack)
 // The stack adjustment made by this push is undone by
 // add(rsp, stack_space_needed) before postamble(); stack_space_needed is
 // declared elsewhere (presumably the size of this single push -- confirm).
177 const auto binary_post_op_acc_off_reg = reg_tmp;
178 xor_(binary_post_op_acc_off_reg, binary_post_op_acc_off_reg);
179 push(binary_post_op_acc_off_reg);
180 }
181
182 if (do_sum_)
183 vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]);
184#undef PARAM_OFF
185
186 // Load accumulated value, apply sum (if any), bias (if any)
187 // and relu (if any); then convert to destination type and store
 // Parameters of the emitter lambda:
 //   offset     - element offset from reg_acc / reg_dst for this vector
 //   idx        - slot passed to vreg_dst() / vreg_prev_dst() to pick registers
 //   apply_mask - use kreg_rem_mask for the loads and the store (tail vector)
188 auto compute = [&](size_t offset, int idx, bool apply_mask) {
189 auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)];
190 auto vreg_dst_ = vreg_dst(idx);
191
192 if (dst_data_type == data_type::bf16 && isa_ != avx512_core_bf16)
193 bf16_emu_->init_vcvtneps2bf16();
194
195 if (apply_mask) vreg_dst_ = vreg_dst_ | kreg_rem_mask;
196 vmovups(vreg_dst_, acc_addr);
197
198 if (jcp_.with_bias) vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias);
199
200 auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)];
201 if (do_sum_) {
202 auto vreg_prev_dst_ = vreg_prev_dst(idx);
 // NOTE(review): the constructor sets do_sum_ only when
 // dst_data_type != f32, so this f32 branch looks unreachable as written.
203 if (dst_data_type == data_type::f32) {
204 if (apply_mask) vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask;
205
206 vmovups(vreg_prev_dst_, dst_addr);
207 } else if (dst_data_type == data_type::bf16) {
208 auto vreg_prev_dst_ymm_ = vreg_prev_dst_ymm(idx);
209 if (apply_mask)
210 vreg_prev_dst_ymm_ = vreg_prev_dst_ymm_ | kreg_rem_mask;
211
 // Load the 16-bit values, zero-extend them to 32 bits and shift left
 // by 16 so the bf16 bits occupy the upper half of each f32 lane.
212 vmovdqu16(vreg_prev_dst_ymm_, dst_addr);
213 vpmovzxwd(vreg_prev_dst(idx), vreg_prev_dst_ymm_);
214 vpslld(vreg_prev_dst(idx), vreg_prev_dst(idx), 0x10);
215 } else
216 assert(!"unsupported data type");
217
 // dst += prev_dst * sum_scale
218 vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale);
219 }
220
221 apply_postops(apply_mask, offset, vreg_dst_idx(idx));
222
223 if (dst_data_type == data_type::bf16) {
224 // TODO: implement store by zmm registers for bf16
225 auto vreg_dst_ymm_ = vreg_dst_ymm(idx);
 // Native conversion instruction when available, emulation otherwise.
226 if (isa_ == avx512_core_bf16)
227 vcvtneps2bf16(vreg_dst_ymm_, vreg_dst(idx));
228 else
229 bf16_emu_->vcvtneps2bf16(vreg_dst_ymm_, vreg_dst(idx));
230
231 if (apply_mask) vreg_dst_ymm_ = vreg_dst_ymm_ | kreg_rem_mask;
232
233 vmovdqu16(dst_addr, vreg_dst_ymm_);
234 } else if (dst_data_type == data_type::f32)
235 vmovups(dst_addr, vreg_dst_);
236 else
237 assert(!"unimplemented");
238 };
239
240 // Advance all pointers by an immediate
241 auto advance_ptrs_imm = [&](size_t offset) {
242 add(reg_dst, offset * sizeof(dst_data_t));
243 add(reg_acc, offset * sizeof(acc_data_t));
244 };
245
246 Xbyak::Label oc_loop, oc_loop_end;
247
 // Skip the whole loop when there are no output channels to process.
248 cmp(reg_oc_iter, 0);
249 jle(oc_loop_end, T_NEAR);
250
251 L(oc_loop);
252
 // Reset the per-iteration length counter and the working pointers.
253 mov(reg_len_iter, reg_len);
254 mov(reg_dst, reg_dst_base);
255 mov(reg_acc, reg_acc_base);
256
257 if (jcp_.with_bias) vbroadcastss(vreg_bias, ptr[reg_bias]);
258
259 constexpr int n_unroll = default_unroll_2_pow_; // unroll by powers of 2
260 // from 2^n to 2^0
261 assert((1 << n_unroll) <= max_unroll_);
262
 // One loop is emitted per unroll factor 2^i, largest first. Each loop
 // handles `unroll * vlen_` elements per pass and jumps to the label of the
 // next smaller loop (l_simd_loop[i]) once fewer elements than that remain.
263 Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
264 for (int i = n_unroll; i >= 0; i--) {
265 const int unroll = 1 << i; // 4, 2, 1
266 L(l_simd_loop[i + 1]);
267 {
268 const int loop_len = unroll * vlen_;
269 cmp(reg_len_iter, loop_len);
270 jl(l_simd_loop[i], T_NEAR);
271 for (int j = 0; j < unroll; j++)
272 compute(j * vlen_, j, false);
273
274 advance_ptrs_imm(loop_len);
275 sub(reg_len_iter, loop_len);
276 jmp(l_simd_loop[i + 1], T_NEAR);
277 }
278 }
279 L(l_simd_loop[0]);
280
 // Tail: build the mask (1 << remaining) - 1. A zero mask means nothing is
 // left, in which case the masked compute is skipped.
281 mov(reg_tmp, reg_len_iter); // reg_tmp is rcx, and we need cl for the shift
282 mov(reg_rem_mask, 1);
283 shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen_ == 16
284 sub(reg_rem_mask, 1);
285 jz(l_simd_notail, T_NEAR);
286 kmovq(kreg_rem_mask, reg_rem_mask);
287 compute(0, 0, true);
288
289 L(l_simd_notail);
290
 // Move on to the next outer iteration: advance base pointers by their
 // strides, the bias pointer by one element, and bump the stack-resident
 // binary post-op offset counter.
291 add(reg_dst_base, reg_dst_str);
292 add(reg_acc_base, reg_acc_str);
293 if (jcp_.with_bias) add(reg_bias, sizeof(acc_data_t));
294 if (jcp_.with_binary)
295 inc(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off));
296
297 dec(reg_oc_iter);
298 jnz(oc_loop, T_NEAR); // oc_loop end
299
300 L(oc_loop_end);
301
 // Release the stack space taken for the binary post-op accumulator.
302 if (jcp_.with_binary) add(rsp, stack_space_needed);
303
304 postamble();
305
 // The eltwise injector's table is emitted after the kernel's postamble.
306 if (jcp_.with_eltwise) postops_injector_->prepare_table();
307}
308
309// operator () specialized for nspc format
310template <data_type_t dst_data_type>
311void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::operator()(
312 dst_data_t *dst, const acc_data_t *acc, const acc_data_t *bias,
313 float sum_scale, size_t oc_work,
314 const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
315 const size_t g_oc_offset) {
316
317 ker_args args;
318 args.acc = acc;
319 args.dst = dst;
320 args.bias = bias;
321 args.sum_scale = sum_scale;
322 args.dst_stride_in_bytes = sizeof(dst_data_t);
323 args.acc_stride_in_bytes = sizeof(acc_data_t);
324 args.spatial_length = 1;
325 args.oc_work = oc_work;
326
327 args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
328 args.dst_orig = dst_orig;
329 args.g_oc_offset = g_oc_offset;
330 jit_generator::operator()(&args);
331}
332
333// operator () specialized for ncsp format
334template <data_type_t dst_data_type>
335void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::operator()(
336 dst_data_t *dst, const acc_data_t *acc, const acc_data_t *bias,
337 float sum_scale, size_t dst_stride_in_elements,
338 size_t acc_stride_in_elements, size_t sp_len, size_t oc_len,
339 const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
340 const size_t g_oc_offset) {
341 if (sp_len == 0) return;
342
343 ker_args args;
344 args.acc = acc;
345 args.dst = dst;
346 args.bias = bias;
347 args.sum_scale = sum_scale;
348 args.dst_stride_in_bytes = dst_stride_in_elements * sizeof(dst_data_t);
349 args.acc_stride_in_bytes = acc_stride_in_elements * sizeof(acc_data_t);
350 args.spatial_length = sp_len;
351 args.oc_work = oc_len;
352
353 args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
354 args.dst_orig = dst_orig;
355 args.g_oc_offset = g_oc_offset;
356 jit_generator::operator()(&args);
357}
358
359template <data_type_t dst_data_type>
360status_t gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward_nspc(
361 const exec_ctx_t &ctx) const {
362 auto src_base = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
363 auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
364 auto dst_base = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
365 const auto post_ops_binary_rhs_arg_vec
366 = binary_injector::prepare_binary_args(
367 this->pd()->attr()->post_ops_, ctx);
368
369 auto scratchpad = ctx.get_scratchpad_grantor();
370 const conv_gemm_conf_t &jcp = pd()->jcp_;
371
372 float *bia_base = nullptr;
373 if (jcp.with_bias) {
374 if (pd()->desc()->bias_desc.data_type == data_type::bf16) {
375 auto bias_in = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_BIAS);
376 bia_base = ctx.get_scratchpad_grantor().template get<float>(
377 key_conv_bias_bf16_convert_wsp);
378 cvt_bfloat16_to_float(bia_base, bias_in, jcp.ngroups * jcp.oc);
379 } else {
380 auto bias_in = CTX_IN_MEM(const float *, DNNL_ARG_BIAS);
381 bia_base = const_cast<float *>(bias_in);
382 }
383 }
384 assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));
385
386 std::atomic<status_t> st(status::success);
387 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
388 status_t st_thr = execute_forward_thr_nspc(ithr, nthr, src_base,
389 wei_base, bia_base, dst_base, scratchpad,
390 post_ops_binary_rhs_arg_vec.data());
391 if (st_thr != status::success) st = st_thr;
392 });
393
394 return st;
395}
396
397template <data_type_t dst_data_type>
398status_t gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward_thr_nspc(
399 const int ithr, const int nthr, const src_data_t *src_base,
400 const wei_data_t *wei_base, const float *bia_base, dst_data_t *dst_base,
401 const memory_tracking::grantor_t &scratchpad,
402 const void *post_ops_binary_rhs_arg_vec) const {
403 const conv_gemm_conf_t &jcp = pd()->jcp_;
404
405 // Src Format: mb-spatial-groups-input_channels
406 const size_t src_mb_stride = static_cast<size_t>(jcp.id) * jcp.ih * jcp.iw
407 * jcp.ngroups * jcp.ic;
408 const size_t src_g_stride = jcp.ic;
409 // Wei Format: spatial-input_channels-groups-output_channels
410 const size_t wei_g_stride = pd()->with_groups() ? jcp.oc : 0;
411
412 // Dst Format: mb-spatial-groups-output_channels
413 const size_t dst_mb_stride = static_cast<size_t>(jcp.od) * jcp.oh * jcp.ow
414 * jcp.ngroups * jcp.oc;
415 const size_t dst_g_stride = jcp.oc;
416 const size_t dst_os_stride = jcp.ngroups * jcp.oc;
417
418 src_data_t *__restrict col = scratchpad.get<src_data_t>(key_conv_gemm_col)
419 + (ptrdiff_t)ithr * jcp.im2col_sz;
420 src_data_t *__restrict imtr = scratchpad.get<src_data_t>(key_conv_gemm_imtr)
421 + (ptrdiff_t)ithr * jcp.is * jcp.ic;
422 acc_data_t *__restrict acc = scratchpad.get<acc_data_t>(key_conv_gemm_acc)
423 + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc;
424
425 const auto &post_ops = pd()->attr()->post_ops_;
426 const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
427 const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;
428
429 dim_t g {0}, n {0}, ohb {0}, owb {0};
430 dim_t start = 0, end = 0;
431
432 const bool is_problem_3d = pd()->ndims() == 5;
433 assert(IMPLICATION(is_problem_3d,
434 jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow
435 && jcp.ic_block == jcp.ic));
436
437 const dim_t nb_oh = div_up(jcp.oh, jcp.oh_block);
438 const dim_t nb_ow = div_up(jcp.ow, jcp.ow_block);
439 const dim_t work_amount = jcp.ngroups * jcp.mb * nb_oh * nb_ow;
440 balance211(work_amount, nthr, ithr, start, end);
441 nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
442
443 if (jcp.im2col_sz && is_problem_3d) {
444 // jit_gemm_convolution_utils::im2col_dt_3d() requires external
445 // data initialization by zeroes
446 // For performance reasons use uint16_t as a proxy for bfloat16_t
447 uint16_t *__restrict col_r
448 = reinterpret_cast<uint16_t *__restrict>(col);
449 constexpr uint16_t zero_val = 0;
450
451 PRAGMA_OMP_SIMD()
452 for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
453 col_r[i] = zero_val;
454 }
455 for (dim_t iwork = start; iwork < end; ++iwork) {
456 int oh = ohb * jcp.oh_block;
457 int ow = owb * jcp.ow_block;
458 const src_data_t *__restrict src
459 = src_base + n * src_mb_stride + g * src_g_stride;
460 const wei_data_t *__restrict wei = wei_base + g * wei_g_stride;
461
462 const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
463 const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
464 if (jcp.im2col_sz && is_problem_3d)
465 jit_gemm_convolution_utils::transpose_dt(jcp, src, imtr);
466
467 for (int od = 0; od < jcp.od; od++) {
468 dst_data_t *__restrict dst = dst_base + n * dst_mb_stride
469 + g * dst_g_stride
470 + ((od * jcp.oh + oh) * jcp.ow + ow) * dst_os_stride;
471 if (jcp.im2col_sz) {
472 if (is_problem_3d)
473 jit_gemm_convolution_utils::im2col_dt_3d<src_data_t,
474 src_data_t>(jcp, imtr, col, od);
475 else
476 jit_gemm_convolution_utils::im2col_dt<src_data_t,
477 src_data_t>(
478 jcp, src, imtr, col, oh, h_step, ow, w_step);
479 }
480
481 const dim_t M = jcp.oc;
482 const dim_t K = jcp.ks * jcp.ic;
483 const dim_t N = h_step * w_step;
484 const dim_t LDA = M * jcp.ngroups;
485 const dim_t LDB = jcp.im2col_sz ? N : K * jcp.ngroups;
486 const char *BT = jcp.im2col_sz ? "T" : "N";
487 const float onef = 1.f;
488 const float beta = this->beta_;
489 const src_data_t *__restrict src_od
490 = src + od * jcp.oh * jcp.ow * jcp.ngroups * jcp.ic;
491 const bool acc_needed = dst_data_type == data_type::bf16;
492 status_t st = gemm_bf16bf16f32("N", BT, &M, &N, &K, &onef, wei,
493 &LDA, jcp.im2col_sz ? col : (src_data_t *)src_od, &LDB,
494 &beta, acc_needed ? acc : (float *)dst,
495 acc_needed ? &M : &LDA);
496 if (st != status::success) return st;
497
498 const bool do_postprocess = pd()->is_postprocess_required();
499 if (do_postprocess) {
500 parallel_nd_ext(jcp.nthr == 1 ? 0 : 1, N,
501 [&](size_t ithr, size_t nthr, size_t os) {
502 const float *__restrict acc_arr = acc + os * jcp.oc;
503 const float *__restrict bia_arr
504 = (bia_base == nullptr)
505 ? nullptr
506 : bia_base + g * jcp.oc;
507 dst_data_t *__restrict dst_arr
508 = dst + os * dst_os_stride;
509
510 (*pp_ker_)(dst_arr,
511 acc_needed ? acc_arr : (float *)dst_arr,
512 bia_arr, sum_scale, jcp.oc,
513 post_ops_binary_rhs_arg_vec, dst_base,
514 g * jcp.oc);
515 });
516 }
517 }
518 nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
519 }
520 return status::success;
521}
522
523template <data_type_t dst_data_type>
524status_t gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward_ncsp(
525 const exec_ctx_t &ctx) const {
526 auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
527 auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
528 auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
529 const auto post_ops_binary_rhs_arg_vec
530 = binary_injector::prepare_binary_args(
531 this->pd()->attr()->post_ops_, ctx);
532
533 bool is_bf16_dst = dst_data_type == data_type::bf16;
534
535 auto col = ctx.get_scratchpad_grantor().template get<src_data_t>(
536 key_conv_gemm_col);
537 acc_data_t *acc_base = is_bf16_dst
538 ? ctx.get_scratchpad_grantor().template get<acc_data_t>(
539 key_conv_int_dat_in_acc_dt)
540 : nullptr;
541
542 const conv_gemm_conf_t &jcp = this->pd()->jcp_;
543
544 float *bias = nullptr;
545 if (jcp.with_bias) {
546 if (pd()->desc()->bias_desc.data_type == data_type::bf16) {
547 auto bias_in = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_BIAS);
548 bias = ctx.get_scratchpad_grantor().template get<float>(
549 key_conv_bias_bf16_convert_wsp);
550 cvt_bfloat16_to_float(bias, bias_in, jcp.ngroups * jcp.oc);
551 } else {
552 auto bias_in = CTX_IN_MEM(const float *, DNNL_ARG_BIAS);
553 bias = const_cast<float *>(bias_in);
554 }
555 }
556
557 const auto &post_ops = pd()->attr()->post_ops_;
558 const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
559 const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;
560
561 const dim_t M = jcp.os * jcp.od;
562 const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id;
563 const size_t dst_step = (size_t)jcp.oc * M;
564 const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;
565 const size_t weights_oc_size = jcp.ic * jcp.ks;
566
567 const dim_t LDB = weights_oc_size;
568 const dim_t work_amount
569 = (size_t)jcp.ngroups * jcp.mb * jcp.od * jcp.os_nb_block;
570 const bool is_problem_3d = pd()->ndims() == 5;
571 std::atomic<status_t> st(status::success);
572
573 auto inner_ker = [&](const int ic, const int oc, const int groups,
574 const int od, const int spatial,
575 const src_data_t *src, const wei_data_t *weights,
576 src_data_t *col, dst_data_t *dst_im,
577 acc_data_t *acc, int ic_block, int oc_block) {
578 const dim_t os_block = nstl::min(
579 (dim_t)jcp.os_block, (dim_t)jcp.os - spatial * jcp.os_block);
580
581 if (jcp.im2col_sz) {
582 if (!is_problem_3d) {
583 jit_gemm_convolution_utils::im2col<src_data_t>(jcp, src, col,
584 spatial * jcp.os_block, os_block, ic, ic_block);
585 } else {
586 assert(jcp.ic_block == jcp.ic);
587 jit_gemm_convolution_utils::im2col_3d<src_data_t>(
588 jcp, src, col, od, spatial * jcp.os_block, os_block);
589 }
590 }
591
592 const acc_data_t one = 1.0;
593 const dim_t N = oc_block;
594 const dim_t K = ic_block * jcp.ks;
595 const dim_t m = os_block;
596 const dim_t LDA = jcp.im2col_sz ? m : M;
597 const dim_t LDC = is_bf16_dst ? m : M;
598 const float beta = (ic == 0) ? this->beta_ : one;
599 auto out_off = spatial * jcp.os_block + od * jcp.os;
600 dst_data_t *dst_local = dst_im + out_off;
601
602 status_t st_thr = gemm_bf16bf16f32("N", "N", &m, &N, &K, &one,
603 jcp.im2col_sz ? col : src + ic * M + out_off, &LDA, weights,
604 &LDB, &beta, acc, &LDC);
605
606 if (st_thr != status::success) {
607 st = st_thr;
608 return;
609 }
610
611 if (this->pd()->is_postprocess_required() && ic + ic_block >= jcp.ic) {
612 size_t acc_str = LDC;
613 size_t dst_str = M;
614 float *bias_ptr = bias ? bias + groups * jcp.oc + oc : nullptr;
615 (*pp_ker_)(dst_local, acc, bias_ptr, sum_scale, dst_str, acc_str, m,
616 oc_block, post_ops_binary_rhs_arg_vec.data(), dst,
617 groups * jcp.oc + oc);
618 }
619 };
620
621 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
622 src_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
623 if (is_problem_3d) {
624 // jit_gemm_convolution_utils::im2col_3d() requires external
625 // data initialization by zeroes
626 for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
627 _col[i] = (src_data_t)0;
628 }
629 dim_t g {0}, n {0}, od {0}, nb_os {0};
630 dim_t start = 0, end = 0;
631 dim_t oc_start = 0, oc_end = 0;
632
633 assert(jcp.loop_order == gemm_loop_lbr);
634 balance2D(nthr, ithr, work_amount, start, end, jcp.oc, oc_start, oc_end,
635 dim_t(jcp.nthr_oc));
636
637 nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, od, jcp.od, nb_os,
638 jcp.os_nb_block);
639 for (dim_t iwork = start; iwork < end; ++iwork) {
640 for_(dim_t oc = (dim_t)oc_start; oc < (dim_t)oc_end;
641 oc += jcp.oc_block)
642 for (dim_t ic = 0; ic < jcp.ic; ic += jcp.ic_block) {
643 const src_data_t *_src = src + (n * jcp.ngroups + g) * src_step;
644 const wei_data_t *_weights = weights + g * weights_g_size
645 + oc * weights_oc_size + ic * jcp.ks;
646 dst_data_t *_dst_im
647 = dst + (n * jcp.ngroups + g) * dst_step + oc * M;
648 auto out_off = nb_os * jcp.os_block + od * jcp.os;
649 dst_data_t *dst_local = _dst_im + out_off;
650 const dim_t sizeof_cacheline_float = 16;
651 acc_data_t *_acc = is_bf16_dst ? acc_base
652 + ithr
653 * rnd_up(jcp.oc_block * jcp.os_block,
654 sizeof_cacheline_float)
655 : (acc_data_t *)dst_local;
656
657 const dim_t ic_block = nstl::min(jcp.ic - ic, jcp.ic_block);
658 const dim_t oc_block
659 = nstl::min(dim_t(oc_end) - oc, jcp.oc_block);
660
661 inner_ker(ic, oc, g, od, nb_os, _src, _weights, _col, _dst_im,
662 _acc, ic_block, oc_block);
663 }
664 nd_iterator_step(g, jcp.ngroups, n, jcp.mb, od, jcp.od, nb_os,
665 jcp.os_nb_block);
666 }
667 });
668
669 return st;
670}
671
672template <data_type_t diff_src_data_type>
673status_t gemm_bf16_convolution_bwd_data_t<diff_src_data_type>::
674 execute_backward_data_nspc(const exec_ctx_t &ctx) const {
675
676 auto diff_dst_base = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
677 auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
678 auto diff_src_base = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);
679
680 auto scratchpad = ctx.get_scratchpad_grantor();
681 const conv_gemm_conf_t &jcp = pd()->jcp_;
682
683 std::atomic<status_t> st(status::success);
684 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
685 status_t st_thr = execute_backward_data_thr_nspc(
686 ithr, nthr, diff_src_base, wei_base, diff_dst_base, scratchpad);
687 if (st_thr != status::success) st = st_thr;
688 });
689
690 return st;
691}
692
693template <data_type_t diff_src_data_type>
694status_t gemm_bf16_convolution_bwd_data_t<
695 diff_src_data_type>::execute_backward_data_thr_nspc(const int ithr,
696 const int nthr, diff_src_data_t *diff_src_base,
697 const wei_data_t *wei_base, const diff_dst_data_t *diff_dst_base,
698 const memory_tracking::grantor_t &scratchpad) const {
699
700 const conv_gemm_conf_t &jcp = pd()->jcp_;
701
702 // Diff_dst Format: mb-spatial-groups-output_channels
703 const size_t diff_dst_mb_stride = static_cast<size_t>(jcp.od) * jcp.oh
704 * jcp.ow * jcp.ngroups * jcp.oc;
705 const size_t diff_dst_g_stride = jcp.oc;
706
707 // Wei Format: spatial-input_channels-groups-output_channels
708 const size_t wei_g_stride = pd()->with_groups() ? jcp.oc : 0;
709
710 // Diff_src Format: mb-spatial-groups-input_channels
711 const size_t diff_src_mb_stride = static_cast<size_t>(jcp.id) * jcp.ih
712 * jcp.iw * jcp.ngroups * jcp.ic;
713 const size_t diff_src_g_stride = jcp.ic;
714 const size_t diff_src_os_stride = jcp.ngroups * jcp.ic;
715
716 // threads share work across mini-batch and groups
717 const dim_t work_amount = jcp.ngroups * jcp.mb;
718
719 acc_data_t *__restrict col = scratchpad.get<acc_data_t>(key_conv_gemm_col)
720 + (ptrdiff_t)ithr * jcp.im2col_sz;
721 acc_data_t *__restrict acc = scratchpad.get<acc_data_t>(key_conv_gemm_acc)
722 + (ptrdiff_t)ithr * jcp.is * jcp.id * jcp.ic;
723
724 dim_t n {0}, g {0};
725 dim_t start = 0, end = 0;
726
727 balance211(work_amount, nthr, ithr, start, end);
728 nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);
729
730 for (dim_t iwork = start; iwork < end; ++iwork) {
731 const diff_dst_data_t *__restrict diff_dst = diff_dst_base
732 + n * diff_dst_mb_stride + g * diff_dst_g_stride;
733 const wei_data_t *__restrict wei = wei_base + g * wei_g_stride;
734 diff_src_data_t *__restrict diff_src = diff_src_base
735 + n * diff_src_mb_stride + g * diff_src_g_stride;
736
737 const dim_t M = jcp.ks * jcp.ic;
738 const dim_t N = jcp.os * jcp.od;
739 const dim_t K = jcp.oc;
740
741 const float onef = 1.0f, zerof = 0.0f;
742 const dim_t LD = K * jcp.ngroups;
743
744 status_t st = gemm_bf16bf16f32("T", "N", &M, &N, &K, &onef, wei, &LD,
745 diff_dst, &LD, &zerof, jcp.im2col_sz ? col : acc, &M);
746 if (st != status::success) return st;
747
748 if (jcp.im2col_sz)
749 jit_gemm_convolution_utils::col2im_dt<acc_data_t>(jcp, col, acc);
750
751 const bool is_diff_src_bf16 = diff_src_data_type == data_type::bf16;
752
753 if (is_diff_src_bf16 && jcp.ngroups == 1 && jcp.nthr != 1) {
754 cvt_float_to_bfloat16((bfloat16_t *)diff_src, (const float *)acc,
755 static_cast<size_t>(jcp.is) * jcp.id * jcp.ic);
756 } else if (is_diff_src_bf16) {
757 parallel_nd_ext(jcp.nthr == 1 ? 0 : 1,
758 static_cast<size_t>(jcp.is) * jcp.id,
759 [&](size_t ithr, size_t nthr, size_t is) {
760 diff_src_data_t *__restrict diff_src_loc
761 = diff_src + is * diff_src_os_stride;
762 const acc_data_t *__restrict acc_loc
763 = acc + is * jcp.ic;
764 cvt_float_to_bfloat16((bfloat16_t *)diff_src_loc,
765 (const float *)acc_loc, jcp.ic);
766 });
767 } else {
768 assert(diff_src_data_type == data_type::f32);
769 parallel_nd_ext(jcp.nthr == 1 ? 0 : 1,
770 static_cast<size_t>(jcp.is) * jcp.id,
771 [&](size_t ithr, size_t nthr, size_t is) {
772 diff_src_data_t *__restrict diff_src_loc
773 = diff_src + is * diff_src_os_stride;
774 const acc_data_t *__restrict acc_loc
775 = acc + is * jcp.ic;
776 PRAGMA_OMP_SIMD()
777 for (int ic = 0; ic < jcp.ic; ++ic)
778 diff_src_loc[ic] = acc_loc[ic];
779 });
780 }
781 nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
782 }
783 return status::success;
784}
785
786template <data_type_t diff_src_data_type>
787status_t gemm_bf16_convolution_bwd_data_t<diff_src_data_type>::
788 execute_backward_data_ncsp(const exec_ctx_t &ctx) const {
789 auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
790 auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
791 auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC);
792
793 auto col = ctx.get_scratchpad_grantor().template get<acc_data_t>(
794 key_conv_gemm_col);
795 acc_data_t *acc_base = diff_src_data_type == data_type::bf16
796 ? ctx.get_scratchpad_grantor().template get<acc_data_t>(
797 key_conv_int_dat_in_acc_dt)
798 : nullptr;
799
800 const conv_gemm_conf_t &jcp = this->pd()->jcp_;
801
802 const dim_t M = jcp.os * jcp.od;
803 const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id;
804 const size_t dst_step = (size_t)jcp.oc * M;
805 const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;
806
807 const dim_t m = jcp.os_block;
808 const dim_t K = jcp.oc;
809 const dim_t N = jcp.ic * jcp.ks;
810
811 const dim_t work_amount = (size_t)jcp.ngroups * jcp.mb;
812 const bool is_problem_3d = pd()->ndims() == 5;
813
814 std::atomic<status_t> st(status::success);
815
816 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
817 acc_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
818
819 dim_t g {0}, n {0};
820 dim_t start = 0, end = 0;
821 balance211(work_amount, nthr, ithr, start, end);
822 nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb);
823 for (dim_t iwork = start; iwork < end; ++iwork) {
824
825 diff_src_data_t *diff_src_local
826 = diff_src + (n * jcp.ngroups + g) * src_step;
827 acc_data_t *acc = diff_src_data_type == data_type::bf16
828 ? acc_base + ithr * rnd_up(src_step, 16)
829 : (acc_data_t *)diff_src_local;
830
831 if (is_problem_3d && jcp.im2col_sz > 0) {
832 // jit_gemm_convolution_utils::col2im_3d() assumes that the
833 // accumulator is initialized by zeroes
834 for (size_t i = 0; i < src_step; i++)
835 acc[i] = (acc_data_t)0;
836 }
837
838 const wei_data_t *_weights = weights + g * weights_g_size;
839 for_(int od = 0; od < jcp.od; ++od)
840 for (int os_nb = 0; os_nb < jcp.os_nb_block; ++os_nb) {
841 auto out_off = os_nb * m + od * jcp.os;
842 const diff_dst_data_t *_diff_dst
843 = diff_dst + (n * jcp.ngroups + g) * dst_step + out_off;
844 const dim_t os_block
845 = nstl::min((dim_t)jcp.os_block, jcp.os - os_nb * m);
846 const dim_t LDC = jcp.im2col_sz ? os_block : M;
847
848 const acc_data_t zero = 0.0, one = 1.0;
849 status_t st_thr = gemm_bf16bf16f32("N", "T", &os_block, &N, &K,
850 &one, _diff_dst, &M, _weights, &N, &zero,
851 jcp.im2col_sz ? _col : acc + out_off, &LDC);
852
853 if (st_thr != status::success) {
854 st = st_thr;
855 return;
856 }
857
858 if (jcp.im2col_sz) {
859 if (!is_problem_3d)
860 jit_gemm_convolution_utils::col2im(
861 jcp, _col, acc, os_nb * jcp.os_block, os_block);
862 else
863 jit_gemm_convolution_utils::col2im_3d(jcp, _col, acc,
864 od, os_nb * jcp.os_block, os_block);
865 }
866 }
867 if (diff_src_data_type == data_type::bf16) {
868 size_t spatial_size = (size_t)jcp.ih * jcp.iw * jcp.id;
869 store_bfloat16_in_parallel((bfloat16_t *)diff_src_local,
870 (const float *)acc, jcp.ic, spatial_size,
871 jcp.nthr == 1);
872 }
873 nd_iterator_step(g, jcp.ngroups, n, jcp.mb);
874 }
875 });
876
877 return st;
878}
879
880template <data_type_t diff_wei_data_type>
881void gemm_bf16_convolution_bwd_weights_t<
882 diff_wei_data_type>::bf16_bwd_weights_reduction_par_nspc(int ithr_mb,
883 int nthr_mb, size_t g_start, size_t g_end, const conv_gemm_conf_t &jcp,
884 const acc_data_t *weights_reduce_base,
885 diff_wei_data_t *weights_base) const {
886 assert(nthr_mb > 1); // no reduction for nthr_mb == 1
887
888 const bool is_bf16_out = diff_wei_data_type == data_type::bf16;
889 const dim_t weights_g_size = jcp.oc;
890 dim_t weights_start {0}, weights_end {0};
891 balance211(jcp.ks * jcp.ic, nthr_mb, ithr_mb, weights_start, weights_end);
892
893 for (auto tidx = 1; tidx < nthr_mb; ++tidx) {
894 const acc_data_t *ws_base
895 = weights_reduce_base + tidx * weights_g_size * jcp.ks * jcp.ic;
896 for_(auto w = weights_start; w < weights_end; ++w)
897 for (auto g = g_start; g < g_end; ++g) {
898 const acc_data_t *ws_ptr = ws_base + w * jcp.oc;
899 float *wei_reduced = is_bf16_out
900 ? (float *)weights_reduce_base + w * jcp.oc
901 : (float *)weights_base + (w * jcp.ngroups + g) * jcp.oc;
902 if (is_bf16_out && tidx == nthr_mb - 1) {
903 // the last iteration for bfloat16 requires conversion
904 // and store to diff_weights array
905 diff_wei_data_t *dwei_ptr
906 = weights_base + (w * jcp.ngroups + g) * jcp.oc;
907 add_floats_and_cvt_to_bfloat16(
908 (bfloat16_t *)(dwei_ptr), wei_reduced, ws_ptr, jcp.oc);
909 } else {
910 acc_ker_->accumulate(wei_reduced, ws_ptr, jcp.oc);
911 }
912 }
913 }
914}
915
916template <data_type_t diff_wei_data_type>
917void gemm_bf16_convolution_bwd_weights_t<
918 diff_wei_data_type>::bf16_bwd_weights_reduction_par_ncsp(int ithr_mb,
919 int nthr_mb, const conv_gemm_conf_t &jcp,
920 const acc_data_t *weights_reduce_base,
921 diff_wei_data_t *weights_base) const {
922 assert(nthr_mb > 1); // no reduction for nthr_mb == 1
923
924 const bool is_bf16_out = diff_wei_data_type == data_type::bf16;
925 const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;
926 size_t weights_start {0}, weights_end {0};
927 balance211(weights_g_size, nthr_mb, ithr_mb, weights_start, weights_end);
928
929 if (weights_start >= weights_end) return; // nothing to do
930
931 size_t acc_size = weights_end - weights_start;
932 float *wei_reduced = is_bf16_out
933 ? (float *)weights_reduce_base + weights_start
934 : (float *)weights_base + weights_start;
935 if (!is_bf16_out) {
936 // f32 diff_weights require initialization by weights_reduce
937 // for thr_mb = 0
938 for (size_t i = 0; i < acc_size; i++)
939 wei_reduced[i] = ((float *)weights_reduce_base + weights_start)[i];
940 }
941
942 for (int thr_mb = 1; thr_mb < nthr_mb; ++thr_mb) {
943 float *wei_to_reduce = (float *)weights_reduce_base
944 + thr_mb * weights_g_size + weights_start;
945
946 if (is_bf16_out && thr_mb == nthr_mb - 1)
947 // the last iteration for bfloat16 requires conversion
948 // and store to diff_weights array
949 add_floats_and_cvt_to_bfloat16(
950 (bfloat16_t *)(weights_base + weights_start), wei_reduced,
951 wei_to_reduce, acc_size);
952 else
953 acc_ker_->accumulate(wei_reduced, wei_to_reduce, acc_size);
954 }
955}
956
957template <data_type_t diff_wei_data_type>
958status_t gemm_bf16_convolution_bwd_weights_t<diff_wei_data_type>::
959 execute_backward_weights_nspc(const exec_ctx_t &ctx) const {
960 auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
961 auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
962 auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, DNNL_ARG_DIFF_WEIGHTS);
963
964 auto col = ctx.get_scratchpad_grantor().template get<src_data_t>(
965 key_conv_gemm_col);
966 auto wei_reduction = ctx.get_scratchpad_grantor().template get<acc_data_t>(
967 key_conv_wei_reduction);
968 const conv_gemm_conf_t &jcp = this->pd()->jcp_;
969
970 acc_data_t *acc_base = diff_wei_data_type == data_type::bf16
971 ? ctx.get_scratchpad_grantor().template get<acc_data_t>(
972 key_conv_int_dat_in_acc_dt)
973 : (acc_data_t *)diff_weights;
974
975 float *diff_bias = nullptr;
976 if (jcp.with_bias) {
977 if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16)
978 diff_bias = ctx.get_scratchpad_grantor().template get<float>(
979 key_conv_bias_bf16_convert_wsp);
980 else
981 diff_bias = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS);
982 }
983
984 const dim_t K = jcp.os * static_cast<size_t>(jcp.od);
985 const size_t src_step
986 = static_cast<size_t>(jcp.ic) * jcp.ih * jcp.iw * jcp.id;
987 const size_t dst_step = jcp.oc * K;
988 const size_t weights_g_size = jcp.oc;
989
990 const dim_t k = jcp.os;
991 const dim_t M = jcp.oc;
992 const dim_t N = static_cast<dim_t>(jcp.ic) * jcp.ks;
993 const dim_t LDB = jcp.ngroups * jcp.oc;
994 const dim_t LDA = jcp.im2col_sz ? jcp.oh * jcp.ow : jcp.ngroups * jcp.ic;
995 const bool is_problem_3d = pd()->ndims() == 5;
996
997 std::atomic<status_t> st(status::success);
998
999 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
1000 int ithr_g, nthr_g, ithr_mb, nthr_mb;
1001 size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};
1002
1003 const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
1004 jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups,
1005 mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb);
1006
1007 assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
1008
1009 const int need_reduction = nthr_mb != 1;
1010 src_data_t *__restrict imtr
1011 = ctx.get_scratchpad_grantor().template get<src_data_t>(
1012 key_conv_gemm_imtr)
1013 + (ptrdiff_t)ithr * jcp.id * jcp.ic * jcp.is;
1014
1015 if (ithr_g != -1 && ithr_mb != -1) {
1016 balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
1017 balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);
1018
1019 assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));
1020
1021 src_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
1022 if (is_problem_3d) {
1023 // jit_gemm_convolution_utils::im2col_3d() requires external
1024 // data initialization by zeroes
1025 // For performance reasons use uint16_t as proxy for bfloat16_t
1026 uint16_t *__restrict _col_r
1027 = reinterpret_cast<uint16_t *__restrict>(_col);
1028 constexpr uint16_t zero_val = 0;
1029
1030 PRAGMA_OMP_SIMD()
1031 for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
1032 _col_r[i] = zero_val;
1033 }
1034
1035 acc_data_t *weights_reduce_base = wei_reduction
1036 + ithr_g * nthr_mb * weights_g_size * jcp.ks * jcp.ic;
1037 acc_data_t *weights_reduce = weights_reduce_base
1038 + ithr_mb * weights_g_size * jcp.ks * jcp.ic;
1039
1040 const bool use_diff_wei
1041 = ithr_mb == 0 && diff_wei_data_type == data_type::f32;
1042 for (size_t g = g_start; g < g_end; ++g) {
1043 acc_data_t *_diff_weights = use_diff_wei
1044 ? (acc_data_t *)diff_weights + g * weights_g_size
1045 : need_reduction ? weights_reduce
1046 : acc_base + g * weights_g_size;
1047 const dim_t LDC = use_diff_wei
1048 ? jcp.ngroups * jcp.oc
1049 : need_reduction ? jcp.oc : jcp.ngroups * jcp.oc;
1050 for (size_t mb = mb_start; mb < mb_end; ++mb) {
1051 const src_data_t *_src
1052 = src + mb * jcp.ngroups * src_step + g * jcp.ic;
1053 if (jcp.im2col_sz && is_problem_3d)
1054 jit_gemm_convolution_utils::transpose_dt(
1055 jcp, _src, imtr);
1056 for (int od = 0; od < jcp.od; ++od) {
1057 const diff_dst_data_t *_diff_dst = diff_dst
1058 + mb * jcp.ngroups * dst_step
1059 + od * k * jcp.ngroups * jcp.oc + g * jcp.oc;
1060
1061 if (jcp.im2col_sz) {
1062 if (is_problem_3d)
1063 jit_gemm_convolution_utils::im2col_dt_3d<
1064 src_data_t, src_data_t>(
1065 jcp, imtr, _col, od);
1066 else
1067 jit_gemm_convolution_utils::im2col_dt<
1068 src_data_t, src_data_t>(jcp, _src, imtr,
1069 _col, 0, jcp.oh, 0, jcp.ow);
1070 }
1071 const float zero = 0.0f, one = 1.0f;
1072 status_t st_thr = gemm_bf16bf16f32("N",
1073 jcp.im2col_sz ? "N" : "T", &M, &N, &k, &one,
1074 _diff_dst, &LDB,
1075 jcp.im2col_sz
1076 ? _col
1077 : _src + od * k * jcp.ngroups * jcp.ic,
1078 &LDA, mb == mb_start && od == 0 ? &zero : &one,
1079 _diff_weights, &LDC);
1080 if (st_thr != status::success) {
1081 st = st_thr;
1082 // Finish the loops early if failure occured.
1083 g = g_end;
1084 mb = mb_end;
1085 od = jcp.od;
1086 }
1087 }
1088 }
1089 }
1090 if (need_reduction && dnnl_thr_syncable()) {
1091 dnnl_thr_barrier();
1092 if (st != status::success) return;
1093 bf16_bwd_weights_reduction_par_nspc(ithr_mb, nthr_mb, g_start,
1094 g_end, jcp, weights_reduce_base, diff_weights);
1095 } else if (diff_wei_data_type == data_type::bf16
1096 && g_end > g_start) {
1097 cvt_acc_to_dst(jcp, g_start, g_end, (const float *)acc_base,
1098 (bfloat16_t *)diff_weights);
1099 }
1100 } else {
1101 if (need_reduction && dnnl_thr_syncable()) dnnl_thr_barrier();
1102 }
1103 });
1104
1105 if (jcp.need_wei_reduction && !dnnl_thr_syncable()) {
1106 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
1107 int ithr_g, nthr_g, ithr_mb, nthr_mb;
1108 size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};
1109
1110 const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
1111 jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr,
1112 jcp.ngroups, mb_for_balance, ithr_g, nthr_g, ithr_mb,
1113 nthr_mb);
1114
1115 assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
1116 const int need_reduction = nthr_mb != 1;
1117
1118 if (need_reduction && ithr_g != -1 && ithr_mb != -1) {
1119 balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
1120 balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);
1121
1122 assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));
1123
1124 acc_data_t *weights_reduce_base = wei_reduction
1125 + ithr_g * nthr_mb * weights_g_size * jcp.ic * jcp.ks;
1126
1127 bf16_bwd_weights_reduction_par_nspc(ithr_mb, nthr_mb, g_start,
1128 g_end, jcp, weights_reduce_base, diff_weights);
1129 }
1130 });
1131 }
1132
1133 if (jcp.with_bias) {
1134 parallel_nd(jcp.ngroups, jcp.oc, [&](dim_t g, dim_t oc) {
1135 acc_data_t db = 0;
1136 const dim_t offset_base = g * jcp.oc + oc;
1137 for_(dim_t mb = 0; mb < jcp.mb; ++mb)
1138 for_(dim_t od = 0; od < jcp.od; ++od)
1139 for (dim_t oh = 0; oh < jcp.oh; ++oh) {
1140 const dim_t width_stride = jcp.ngroups * jcp.oc;
1141 const diff_dst_data_t *__restrict diff_dst_arr = diff_dst
1142 + offset_base
1143 + ((mb * jcp.od + od) * jcp.oh + oh) * jcp.ow
1144 * width_stride;
1145
1146 PRAGMA_OMP_SIMD(reduction(+ : db))
1147 for (dim_t ow = 0; ow < jcp.ow; ++ow) {
1148 db += diff_dst_arr[ow * width_stride];
1149 }
1150 }
1151 diff_bias[g * jcp.oc + oc] = db;
1152 });
1153
1154 if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16) {
1155 auto diff_bias_in = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DIFF_BIAS);
1156 cvt_float_to_bfloat16(
1157 diff_bias_in, diff_bias, jcp.ngroups * jcp.oc);
1158 }
1159 }
1160 return st;
1161}
1162
1163template <data_type_t diff_wei_data_type>
1164status_t gemm_bf16_convolution_bwd_weights_t<diff_wei_data_type>::
1165 execute_backward_weights_ncsp(const exec_ctx_t &ctx) const {
1166 auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
1167 auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
1168 auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, DNNL_ARG_DIFF_WEIGHTS);
1169
1170 auto col = ctx.get_scratchpad_grantor().template get<src_data_t>(
1171 key_conv_gemm_col);
1172 auto wei_reduction = ctx.get_scratchpad_grantor().template get<acc_data_t>(
1173 key_conv_wei_reduction);
1174
1175 const conv_gemm_conf_t &jcp = this->pd()->jcp_;
1176
1177 acc_data_t *acc_base = diff_wei_data_type == data_type::bf16
1178 ? ctx.get_scratchpad_grantor().template get<acc_data_t>(
1179 key_conv_int_dat_in_acc_dt)
1180 : (acc_data_t *)diff_weights;
1181
1182 float *diff_bias = nullptr;
1183 if (jcp.with_bias) {
1184 if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16)
1185 diff_bias = ctx.get_scratchpad_grantor().template get<float>(
1186 key_conv_bias_bf16_convert_wsp);
1187 else
1188 diff_bias = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS);
1189 }
1190
1191 const dim_t K = jcp.os * jcp.od;
1192 const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id;
1193 const size_t dst_step = (size_t)jcp.oc * K;
1194 const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;
1195
1196 const dim_t k = jcp.os_block;
1197 const dim_t N = jcp.oc;
1198 const dim_t M = jcp.ic * jcp.ks;
1199 const bool is_problem_3d = pd()->ndims() == 5;
1200
1201 std::atomic<status_t> st(status::success);
1202 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
1203 int ithr_g, nthr_g, ithr_mb, nthr_mb;
1204 size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};
1205
1206 const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
1207 jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups,
1208 mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb);
1209
1210 assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
1211 const int need_reduction = nthr_mb != 1;
1212
1213 if (ithr_g != -1 && ithr_mb != -1) {
1214 balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
1215 balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);
1216
1217 assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));
1218
1219 src_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
1220 // non-blocked jit_gemm_convolution_utils::im2col_3d() requires
1221 // external data initialization by zeroes
1222 const bool outer_padding = jcp.os_nb_block == 1;
1223 if (outer_padding && is_problem_3d) {
1224 for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
1225 _col[i] = (src_data_t)0;
1226 }
1227
1228 acc_data_t *weights_reduce_base
1229 = wei_reduction + ithr_g * nthr_mb * weights_g_size;
1230 acc_data_t *weights_reduce
1231 = weights_reduce_base + ithr_mb * weights_g_size;
1232
1233 for (size_t g = g_start; g < g_end; ++g) {
1234 acc_data_t *acc = need_reduction
1235 ? weights_reduce
1236 : (acc_base + g * weights_g_size);
1237 for (size_t mb = mb_start; mb < mb_end; ++mb) {
1238 const src_data_t *_src
1239 = src + (mb * jcp.ngroups + g) * src_step;
1240 for_(int od = 0; od < jcp.od; ++od)
1241 for (int os_nb = 0; os_nb < jcp.os_nb_block; ++os_nb) {
1242 auto out_off = os_nb * k + od * jcp.os;
1243 const dim_t os_block = nstl::min(
1244 (dim_t)jcp.os_block, jcp.os - os_nb * k);
1245 const diff_dst_data_t *_diff_dst = diff_dst
1246 + (mb * jcp.ngroups + g) * dst_step + out_off;
1247
1248 if (jcp.im2col_sz) {
1249 if (!is_problem_3d)
1250 jit_gemm_convolution_utils::im2col<src_data_t>(
1251 jcp, _src, _col, os_nb * jcp.os_block,
1252 os_block, 0, jcp.ic);
1253 else
1254 jit_gemm_convolution_utils::im2col_3d<
1255 src_data_t>(jcp, _src, _col, od,
1256 os_nb * jcp.os_block, os_block);
1257 }
1258
1259 const dim_t LDA = jcp.im2col_sz ? os_block : K;
1260 const acc_data_t zero = 0.0, one = 1.0;
1261 status_t st_thr = gemm_bf16bf16f32("T", "N", &M, &N,
1262 &os_block, &one,
1263 jcp.im2col_sz ? _col : _src + out_off, &LDA,
1264 _diff_dst, &K,
1265 mb == mb_start && os_nb == 0 && od == 0 ? &zero
1266 : &one,
1267 acc, &M);
1268
1269 if (st_thr != status::success) {
1270 st = st_thr;
1271 // Finish the loops early if failure occured.
1272 g = g_end;
1273 mb = mb_end;
1274 od = jcp.od;
1275 os_nb = jcp.os_nb_block;
1276 }
1277 }
1278 }
1279 }
1280 if (need_reduction && dnnl_thr_syncable()) {
1281 dnnl_thr_barrier();
1282 if (st != status::success) return;
1283 diff_wei_data_t *weights_base
1284 = diff_weights + g_start * weights_g_size;
1285 bf16_bwd_weights_reduction_par_ncsp(ithr_mb, nthr_mb, jcp,
1286 weights_reduce_base, weights_base);
1287 } else if (diff_wei_data_type == data_type::bf16
1288 && g_end > g_start) {
1289 const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;
1290 const size_t work_size = (g_end - g_start) * weights_g_size;
1291 bfloat16_t *diff_weights_local
1292 = (bfloat16_t *)diff_weights + g_start * weights_g_size;
1293 const float *acc_local
1294 = (const float *)acc_base + g_start * weights_g_size;
1295 store_bfloat16_in_parallel(diff_weights_local, acc_local,
1296 work_size, 1, jcp.nthr == 1);
1297 }
1298 } else {
1299 if (need_reduction && dnnl_thr_syncable()) dnnl_thr_barrier();
1300 }
1301 });
1302
1303 if (st != status::success) return st;
1304
1305 if (jcp.need_wei_reduction && !dnnl_thr_syncable()) {
1306 parallel(jcp.nthr, [&](const int ithr, const int nthr) {
1307 int ithr_g, nthr_g, ithr_mb, nthr_mb;
1308 size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};
1309
1310 const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
1311 jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr,
1312 jcp.ngroups, mb_for_balance, ithr_g, nthr_g, ithr_mb,
1313 nthr_mb);
1314
1315 assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
1316 const int need_reduction = nthr_mb != 1;
1317
1318 if (need_reduction && ithr_g != -1 && ithr_mb != -1) {
1319 balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
1320 balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);
1321
1322 assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));
1323
1324 acc_data_t *weights_reduce_base
1325 = wei_reduction + ithr_g * nthr_mb * weights_g_size;
1326
1327 diff_wei_data_t *weights_base
1328 = diff_weights + g_start * weights_g_size;
1329 bf16_bwd_weights_reduction_par_ncsp(ithr_mb, nthr_mb, jcp,
1330 weights_reduce_base, weights_base);
1331 }
1332 });
1333 }
1334
1335 if (jcp.with_bias) {
1336 parallel_nd(jcp.ngroups, jcp.oc, [&](size_t g, size_t oc) {
1337 acc_data_t db = 0;
1338 dim_t offset_ = g * dst_step + oc * K;
1339 for (dim_t mb = 0; mb < jcp.mb; ++mb) {
1340 dim_t offset = offset_ + mb * jcp.ngroups * dst_step;
1341 for_(dim_t od = 0; od < jcp.od; ++od)
1342 for (dim_t oh = 0; oh < jcp.oh; ++oh) {
1343 PRAGMA_OMP_SIMD(reduction(+ : db))
1344 for (dim_t ow = 0; ow < jcp.ow; ++ow) {
1345 db += diff_dst[offset];
1346 offset++;
1347 }
1348 }
1349 }
1350 diff_bias[g * jcp.oc + oc] = db;
1351 });
1352
1353 if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16) {
1354 auto diff_bias_in = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DIFF_BIAS);
1355 cvt_float_to_bfloat16(
1356 diff_bias_in, diff_bias, jcp.ngroups * jcp.oc);
1357 }
1358 }
1359
1360 return st;
1361}
1362
// Explicit instantiations: the member functions of these class templates are
// defined in this source file, so each primitive is instantiated here for
// the data_type::f32 and data_type::bf16 template arguments.
template struct gemm_bf16_convolution_fwd_t<data_type::f32>;
template struct gemm_bf16_convolution_fwd_t<data_type::bf16>;
template struct gemm_bf16_convolution_bwd_data_t<data_type::f32>;
template struct gemm_bf16_convolution_bwd_data_t<data_type::bf16>;
template struct gemm_bf16_convolution_bwd_weights_t<data_type::f32>;
template struct gemm_bf16_convolution_bwd_weights_t<data_type::bf16>;
1369
1370} // namespace x64
1371} // namespace cpu
1372} // namespace impl
1373} // namespace dnnl
1374