1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <atomic> |
18 | |
19 | #include "oneapi/dnnl/dnnl_types.h" |
20 | |
21 | #include "common/bfloat16.hpp" |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/dnnl_thread.hpp" |
24 | #include "common/type_helpers.hpp" |
25 | #include "common/utils.hpp" |
26 | #include "cpu/x64/gemm_bf16_convolution.hpp" |
27 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace cpu { |
32 | namespace x64 { |
33 | |
34 | using namespace dnnl::impl::status; |
35 | using namespace dnnl::impl::memory_tracking::names; |
36 | using namespace dnnl::impl::utils; |
37 | using namespace dnnl::impl::cpu::x64::bf16_support; |
38 | |
// The two stand-alone functions below were moved out of execute_backward_data
// and execute_backward_weights to avoid the gcc 6.x and 7.x warning
// "declared with greater visibility than the type of its field",
// which is emitted when one lambda function is declared within another one.
43 | void store_bfloat16_in_parallel(bfloat16_t *output_data, const float *acc_data, |
44 | size_t parallel_work, size_t parallel_work_size, bool do_in_parallel) { |
45 | parallel(do_in_parallel ? 0 : 1, [&](const int ithr, const int nthr) { |
46 | size_t start = 0, end = 0; |
47 | balance211(parallel_work, nthr, ithr, start, end); |
48 | if (start < end) |
49 | cvt_float_to_bfloat16(&output_data[start * parallel_work_size], |
50 | &acc_data[start * parallel_work_size], |
51 | (end - start) * parallel_work_size); |
52 | }); |
53 | } |
54 | |
55 | void cvt_acc_to_dst(const conv_gemm_conf_t &jcp, size_t g_start, size_t g_end, |
56 | const float *acc_base, bfloat16_t *diff_weights) { |
57 | const size_t parallel_work_size = jcp.ic * jcp.ks; |
58 | parallel(jcp.nthr == 1 ? 0 : 1, [&](const int ithr, const int nthr) { |
59 | size_t w_start = 0, w_end = 0; |
60 | balance211(parallel_work_size, nthr, ithr, w_start, w_end); |
61 | for_(auto w = w_start; w < w_end; ++w) |
62 | for (auto g = g_start; g < g_end; ++g) { |
63 | const float *__restrict acc_ptr |
64 | = acc_base + (w * jcp.ngroups + g) * jcp.oc; |
65 | bfloat16_t *__restrict dw_ptr |
66 | = diff_weights + (w * jcp.ngroups + g) * jcp.oc; |
67 | cvt_float_to_bfloat16(dw_ptr, acc_ptr, jcp.oc); |
68 | } |
69 | }); |
70 | } |
71 | |
72 | template <data_type_t dst_data_type> |
73 | gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::pp_ker_t(const pd_t *pd) |
74 | : jit_generator(jit_name()) |
75 | , jcp_(pd->jcp_) |
76 | , do_sum_(dst_data_type != data_type::f32 && jcp_.with_sum) |
77 | , max_data_reg_idx_(31) |
78 | , max_unroll_(12) |
79 | , compute_reg_step_(1) |
80 | , data_reg_base_idx_(0) { |
81 | using namespace types; |
82 | using namespace Xbyak; |
83 | |
84 | if (!mayiuse(avx512_core)) |
85 | // bf16 is not supported |
86 | return; |
87 | |
88 | const auto &post_ops = jcp_.post_ops; |
89 | if (jcp_.with_eltwise || jcp_.with_binary) { |
90 | #define PARAM_OFF(field) offsetof(ker_args, field) |
91 | static constexpr bool preserve_gpr = true; |
92 | static constexpr bool preserve_vmm = true; |
93 | static constexpr size_t helper_vmm_idx = 31; |
94 | static constexpr size_t tail_size = 1; |
95 | static constexpr bool use_exact_tail_scalar_bcast = false; |
96 | const binary_injector::rhs_arg_static_params_t rhs_arg_static_params { |
97 | helper_vmm_idx, reserved_eltwise_gpr, r14, r15, preserve_gpr, |
98 | preserve_vmm, PARAM_OFF(post_ops_binary_rhs_arg_vec), |
99 | PARAM_OFF(dst_orig), memory_desc_wrapper(pd->dst_md()), |
100 | tail_size, kreg_rem_mask, use_exact_tail_scalar_bcast}; |
101 | const binary_injector::static_params_t binary_static_params { |
102 | this->reg_param, rhs_arg_static_params}; |
103 | static constexpr bool save_state = true; |
104 | const eltwise_injector::static_params_t eltwise_static_params { |
105 | save_state, reserved_eltwise_gpr, reserved_eltwise_maskr}; |
106 | |
107 | postops_injector_ = utils::make_unique< |
108 | injector::jit_uni_postops_injector_t<avx512_core>>( |
109 | this, post_ops, binary_static_params, eltwise_static_params); |
110 | #undef PARAM_OFF |
111 | } |
112 | |
113 | if (do_sum_) { |
114 | compute_reg_step_ = 2; |
115 | vreg_sum_scale = Zmm(data_reg_base_idx_++); |
116 | } |
117 | |
118 | if (jcp_.with_bias) vreg_bias = Zmm(data_reg_base_idx_++); |
119 | |
120 | vlen_ = cpu_isa_traits<avx512_core>::vlen / sizeof(float); |
121 | |
122 | isa_ = mayiuse(avx512_core_bf16) ? avx512_core_bf16 |
123 | : bf16_emulation_t::get_isa(); |
124 | |
125 | if (isa_ != avx512_core_bf16) { |
126 | max_data_reg_idx_ = 26; |
127 | bf16_emu_ = utils::make_unique<bf16_emulation_t>(this, |
128 | bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3, |
129 | bf16_emu_reserv_4, bf16_emu_reserv_5, bf16_emu_reserv_6); |
130 | } |
131 | |
132 | max_unroll_ |
133 | = (max_data_reg_idx_ - data_reg_base_idx_ + 1) / compute_reg_step_; |
134 | } |
135 | |
136 | template <data_type_t dst_data_type> |
137 | void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::apply_postops( |
138 | const bool apply_mask, const size_t out_offset, const int vmm_idx) { |
139 | #define PARAM_OFF(x) offsetof(ker_args, x) |
140 | if (jcp_.with_eltwise || jcp_.with_binary) { |
141 | if (jcp_.with_binary) { |
142 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; |
143 | rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, reg_dst); |
144 | rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace( |
145 | vmm_idx, out_offset * sizeof(dst_data_t)); |
146 | if (apply_mask) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); |
147 | |
148 | postops_injector_->compute_vector(vmm_idx, rhs_arg_params); |
149 | } else |
150 | postops_injector_->compute_vector(vmm_idx); |
151 | } |
152 | #undef PARAM_OFF |
153 | } |
154 | |
// Generates the JIT post-processing kernel.
//
// The emitted code reads its inputs from the ker_args structure addressed by
// reg_param and runs an outer loop of oc_work iterations. Each iteration
// processes spatial_length elements starting at the current dst/acc base
// pointers (unrolled vector loops followed by a masked tail), then advances
// the base pointers by the dst/acc strides (in bytes) and, when bias is used,
// the bias pointer by one acc_data_t element.
template <data_type_t dst_data_type>
void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::generate() {
    using namespace Xbyak;
    using namespace utils;

    preamble();

    // Under _WIN32 the kernel argument pointer is taken from rcx.
#ifdef _WIN32
    mov(reg_param, rcx);
#endif

    // Load the kernel arguments into registers.
#define PARAM_OFF(x) offsetof(ker_args, x)
    mov(reg_dst_base, ptr[reg_param + PARAM_OFF(dst)]);
    mov(reg_acc_base, ptr[reg_param + PARAM_OFF(acc)]);
    if (jcp_.with_bias) mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
    mov(reg_dst_str, ptr[reg_param + PARAM_OFF(dst_stride_in_bytes)]);
    mov(reg_acc_str, ptr[reg_param + PARAM_OFF(acc_stride_in_bytes)]);
    mov(reg_len, ptr[reg_param + PARAM_OFF(spatial_length)]);
    mov(reg_oc_iter, ptr[reg_param + PARAM_OFF(oc_work)]);

    if (jcp_.with_binary) {
        // zero initialize binary post_ops offset accumulator (store on stack)
        const auto binary_post_op_acc_off_reg = reg_tmp;
        xor_(binary_post_op_acc_off_reg, binary_post_op_acc_off_reg);
        push(binary_post_op_acc_off_reg);
    }

    if (do_sum_)
        vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]);
#undef PARAM_OFF

    // Load accumulated value, apply sum (if any), bias (if any)
    // and relu (if any); then convert to destination type and store
    //
    // offset     - element offset from reg_acc / reg_dst
    // idx        - logical vector index within the unrolled group
    // apply_mask - use kreg_rem_mask for loads and stores (tail processing)
    auto compute = [&](size_t offset, int idx, bool apply_mask) {
        auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)];
        auto vreg_dst_ = vreg_dst(idx);

        // The emulated conversion is (re)initialized before each use.
        if (dst_data_type == data_type::bf16 && isa_ != avx512_core_bf16)
            bf16_emu_->init_vcvtneps2bf16();

        if (apply_mask) vreg_dst_ = vreg_dst_ | kreg_rem_mask;
        vmovups(vreg_dst_, acc_addr);

        if (jcp_.with_bias) vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias);

        auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)];
        if (do_sum_) {
            // Load the previous destination values as f32 into vreg_prev_dst
            // and accumulate them scaled by vreg_sum_scale.
            auto vreg_prev_dst_ = vreg_prev_dst(idx);
            if (dst_data_type == data_type::f32) {
                if (apply_mask) vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask;

                vmovups(vreg_prev_dst_, dst_addr);
            } else if (dst_data_type == data_type::bf16) {
                auto vreg_prev_dst_ymm_ = vreg_prev_dst_ymm(idx);
                if (apply_mask)
                    vreg_prev_dst_ymm_ = vreg_prev_dst_ymm_ | kreg_rem_mask;

                // Widen 16-bit values to 32 bits and shift them into the
                // upper half of each dword (bf16 -> f32 bit pattern).
                vmovdqu16(vreg_prev_dst_ymm_, dst_addr);
                vpmovzxwd(vreg_prev_dst(idx), vreg_prev_dst_ymm_);
                vpslld(vreg_prev_dst(idx), vreg_prev_dst(idx), 0x10);
            } else
                assert(!"unsupported data type" );

            vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale);
        }

        apply_postops(apply_mask, offset, vreg_dst_idx(idx));

        if (dst_data_type == data_type::bf16) {
            // TODO: implement store by zmm registers for bf16
            auto vreg_dst_ymm_ = vreg_dst_ymm(idx);
            if (isa_ == avx512_core_bf16)
                vcvtneps2bf16(vreg_dst_ymm_, vreg_dst(idx));
            else
                bf16_emu_->vcvtneps2bf16(vreg_dst_ymm_, vreg_dst(idx));

            if (apply_mask) vreg_dst_ymm_ = vreg_dst_ymm_ | kreg_rem_mask;

            vmovdqu16(dst_addr, vreg_dst_ymm_);
        } else if (dst_data_type == data_type::f32)
            vmovups(dst_addr, vreg_dst_);
        else
            assert(!"unimplemented" );
    };

    // Advance all pointers by an immediate
    auto advance_ptrs_imm = [&](size_t offset) {
        add(reg_dst, offset * sizeof(dst_data_t));
        add(reg_acc, offset * sizeof(acc_data_t));
    };

    Xbyak::Label oc_loop, oc_loop_end;

    // Skip the whole loop when there is no oc work.
    cmp(reg_oc_iter, 0);
    jle(oc_loop_end, T_NEAR);

    L(oc_loop);

    // Restart the per-row length counter and working pointers.
    mov(reg_len_iter, reg_len);
    mov(reg_dst, reg_dst_base);
    mov(reg_acc, reg_acc_base);

    // One bias value is broadcast for the whole row.
    if (jcp_.with_bias) vbroadcastss(vreg_bias, ptr[reg_bias]);

    constexpr int n_unroll = default_unroll_2_pow_; // unroll by powers of 2
    // from 2^n to 2^0
    assert((1 << n_unroll) <= max_unroll_);

    // Chain of loops with decreasing unroll factor. Loop i + 1 handles
    // unroll * vlen_ elements per iteration and jumps to label i once fewer
    // elements than that remain; label 0 is the tail handling below.
    Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
    for (int i = n_unroll; i >= 0; i--) {
        const int unroll = 1 << i; // 4, 2, 1
        L(l_simd_loop[i + 1]);
        {
            const int loop_len = unroll * vlen_;
            cmp(reg_len_iter, loop_len);
            jl(l_simd_loop[i], T_NEAR);
            for (int j = 0; j < unroll; j++)
                compute(j * vlen_, j, false);

            advance_ptrs_imm(loop_len);
            sub(reg_len_iter, loop_len);
            jmp(l_simd_loop[i + 1], T_NEAR);
        }
    }
    L(l_simd_loop[0]);

    // Build the tail mask (1 << remaining) - 1; a zero mask means that no
    // elements remain and the masked compute is skipped.
    mov(reg_tmp, reg_len_iter); // reg_tmp is rcx, and we need cl for the shift
    mov(reg_rem_mask, 1);
    shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen_ == 16
    sub(reg_rem_mask, 1);
    jz(l_simd_notail, T_NEAR);
    kmovq(kreg_rem_mask, reg_rem_mask);
    compute(0, 0, true);

    L(l_simd_notail);

    // Move on to the next row.
    add(reg_dst_base, reg_dst_str);
    add(reg_acc_base, reg_acc_str);
    if (jcp_.with_bias) add(reg_bias, sizeof(acc_data_t));
    // NOTE(review): presumably this increments the offset accumulator that
    // was pushed on the stack above - confirm reg_binary_post_op_acc_off.
    if (jcp_.with_binary)
        inc(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off));

    dec(reg_oc_iter);
    jnz(oc_loop, T_NEAR); // oc_loop end

    L(oc_loop_end);

    // NOTE(review): presumably releases the stack slot pushed for the binary
    // post-ops accumulator - confirm stack_space_needed matches the push.
    if (jcp_.with_binary) add(rsp, stack_space_needed);

    postamble();

    if (jcp_.with_eltwise) postops_injector_->prepare_table();
}
308 | |
309 | // operator () specialized for nspc format |
310 | template <data_type_t dst_data_type> |
311 | void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::operator()( |
312 | dst_data_t *dst, const acc_data_t *acc, const acc_data_t *bias, |
313 | float sum_scale, size_t oc_work, |
314 | const void *post_ops_binary_rhs_arg_vec, const void *dst_orig, |
315 | const size_t g_oc_offset) { |
316 | |
317 | ker_args args; |
318 | args.acc = acc; |
319 | args.dst = dst; |
320 | args.bias = bias; |
321 | args.sum_scale = sum_scale; |
322 | args.dst_stride_in_bytes = sizeof(dst_data_t); |
323 | args.acc_stride_in_bytes = sizeof(acc_data_t); |
324 | args.spatial_length = 1; |
325 | args.oc_work = oc_work; |
326 | |
327 | args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; |
328 | args.dst_orig = dst_orig; |
329 | args.g_oc_offset = g_oc_offset; |
330 | jit_generator::operator()(&args); |
331 | } |
332 | |
333 | // operator () specialized for ncsp format |
334 | template <data_type_t dst_data_type> |
335 | void gemm_bf16_convolution_fwd_t<dst_data_type>::pp_ker_t::operator()( |
336 | dst_data_t *dst, const acc_data_t *acc, const acc_data_t *bias, |
337 | float sum_scale, size_t dst_stride_in_elements, |
338 | size_t acc_stride_in_elements, size_t sp_len, size_t oc_len, |
339 | const void *post_ops_binary_rhs_arg_vec, const void *dst_orig, |
340 | const size_t g_oc_offset) { |
341 | if (sp_len == 0) return; |
342 | |
343 | ker_args args; |
344 | args.acc = acc; |
345 | args.dst = dst; |
346 | args.bias = bias; |
347 | args.sum_scale = sum_scale; |
348 | args.dst_stride_in_bytes = dst_stride_in_elements * sizeof(dst_data_t); |
349 | args.acc_stride_in_bytes = acc_stride_in_elements * sizeof(acc_data_t); |
350 | args.spatial_length = sp_len; |
351 | args.oc_work = oc_len; |
352 | |
353 | args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; |
354 | args.dst_orig = dst_orig; |
355 | args.g_oc_offset = g_oc_offset; |
356 | jit_generator::operator()(&args); |
357 | } |
358 | |
359 | template <data_type_t dst_data_type> |
360 | status_t gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward_nspc( |
361 | const exec_ctx_t &ctx) const { |
362 | auto src_base = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); |
363 | auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); |
364 | auto dst_base = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST); |
365 | const auto post_ops_binary_rhs_arg_vec |
366 | = binary_injector::prepare_binary_args( |
367 | this->pd()->attr()->post_ops_, ctx); |
368 | |
369 | auto scratchpad = ctx.get_scratchpad_grantor(); |
370 | const conv_gemm_conf_t &jcp = pd()->jcp_; |
371 | |
372 | float *bia_base = nullptr; |
373 | if (jcp.with_bias) { |
374 | if (pd()->desc()->bias_desc.data_type == data_type::bf16) { |
375 | auto bias_in = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_BIAS); |
376 | bia_base = ctx.get_scratchpad_grantor().template get<float>( |
377 | key_conv_bias_bf16_convert_wsp); |
378 | cvt_bfloat16_to_float(bia_base, bias_in, jcp.ngroups * jcp.oc); |
379 | } else { |
380 | auto bias_in = CTX_IN_MEM(const float *, DNNL_ARG_BIAS); |
381 | bia_base = const_cast<float *>(bias_in); |
382 | } |
383 | } |
384 | assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1)); |
385 | |
386 | std::atomic<status_t> st(status::success); |
387 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
388 | status_t st_thr = execute_forward_thr_nspc(ithr, nthr, src_base, |
389 | wei_base, bia_base, dst_base, scratchpad, |
390 | post_ops_binary_rhs_arg_vec.data()); |
391 | if (st_thr != status::success) st = st_thr; |
392 | }); |
393 | |
394 | return st; |
395 | } |
396 | |
// Per-thread body of the nspc forward pass.
//
// Work items are (mb, group, oh block, ow block) tuples balanced across the
// nthr threads. For each item assigned to thread ithr and for each od, the
// function optionally runs im2col, calls gemm_bf16bf16f32 and, if
// post-processing is required, runs pp_ker_ for each of the N output points.
//
// Returns the status of the first failing gemm call, status::success
// otherwise.
template <data_type_t dst_data_type>
status_t gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward_thr_nspc(
        const int ithr, const int nthr, const src_data_t *src_base,
        const wei_data_t *wei_base, const float *bia_base, dst_data_t *dst_base,
        const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec) const {
    const conv_gemm_conf_t &jcp = pd()->jcp_;

    // Src Format: mb-spatial-groups-input_channels
    const size_t src_mb_stride = static_cast<size_t>(jcp.id) * jcp.ih * jcp.iw
            * jcp.ngroups * jcp.ic;
    const size_t src_g_stride = jcp.ic;
    // Wei Format: spatial-input_channels-groups-output_channels
    const size_t wei_g_stride = pd()->with_groups() ? jcp.oc : 0;

    // Dst Format: mb-spatial-groups-output_channels
    const size_t dst_mb_stride = static_cast<size_t>(jcp.od) * jcp.oh * jcp.ow
            * jcp.ngroups * jcp.oc;
    const size_t dst_g_stride = jcp.oc;
    const size_t dst_os_stride = jcp.ngroups * jcp.oc;

    // Per-thread slices of the scratchpad buffers.
    src_data_t *__restrict col = scratchpad.get<src_data_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    src_data_t *__restrict imtr = scratchpad.get<src_data_t>(key_conv_gemm_imtr)
            + (ptrdiff_t)ithr * jcp.is * jcp.ic;
    acc_data_t *__restrict acc = scratchpad.get<acc_data_t>(key_conv_gemm_acc)
            + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc;

    // The sum scale is taken from post-op entry 0 when that entry is a sum.
    const auto &post_ops = pd()->attr()->post_ops_;
    const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
    const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;

    dim_t g {0}, n {0}, ohb {0}, owb {0};
    dim_t start = 0, end = 0;

    const bool is_problem_3d = pd()->ndims() == 5;
    assert(IMPLICATION(is_problem_3d,
            jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow
                    && jcp.ic_block == jcp.ic));

    // Split the (mb, groups, oh blocks, ow blocks) space across threads and
    // position the iterators at this thread's first work item.
    const dim_t nb_oh = div_up(jcp.oh, jcp.oh_block);
    const dim_t nb_ow = div_up(jcp.ow, jcp.ow_block);
    const dim_t work_amount = jcp.ngroups * jcp.mb * nb_oh * nb_ow;
    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);

    if (jcp.im2col_sz && is_problem_3d) {
        // jit_gemm_convolution_utils::im2col_dt_3d() requires external
        // data initialization by zeroes
        // For performance reasons use uint16_t as a proxy for bfloat16_t
        uint16_t *__restrict col_r
                = reinterpret_cast<uint16_t *__restrict>(col);
        constexpr uint16_t zero_val = 0;

        PRAGMA_OMP_SIMD()
        for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
            col_r[i] = zero_val;
    }
    for (dim_t iwork = start; iwork < end; ++iwork) {
        // First output row/column of the current block.
        int oh = ohb * jcp.oh_block;
        int ow = owb * jcp.ow_block;
        const src_data_t *__restrict src
                = src_base + n * src_mb_stride + g * src_g_stride;
        const wei_data_t *__restrict wei = wei_base + g * wei_g_stride;

        // The last block in each dimension may be shorter than the nominal
        // block size.
        const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
        const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
        if (jcp.im2col_sz && is_problem_3d)
            jit_gemm_convolution_utils::transpose_dt(jcp, src, imtr);

        for (int od = 0; od < jcp.od; od++) {
            dst_data_t *__restrict dst = dst_base + n * dst_mb_stride
                    + g * dst_g_stride
                    + ((od * jcp.oh + oh) * jcp.ow + ow) * dst_os_stride;
            if (jcp.im2col_sz) {
                if (is_problem_3d)
                    jit_gemm_convolution_utils::im2col_dt_3d<src_data_t,
                            src_data_t>(jcp, imtr, col, od);
                else
                    jit_gemm_convolution_utils::im2col_dt<src_data_t,
                            src_data_t>(
                            jcp, src, imtr, col, oh, h_step, ow, w_step);
            }

            // GEMM dimensions: M output channels, N output points of the
            // current block, K = kernel spatial size * input channels.
            const dim_t M = jcp.oc;
            const dim_t K = jcp.ks * jcp.ic;
            const dim_t N = h_step * w_step;
            const dim_t LDA = M * jcp.ngroups;
            const dim_t LDB = jcp.im2col_sz ? N : K * jcp.ngroups;
            const char *BT = jcp.im2col_sz ? "T" : "N" ;
            const float onef = 1.f;
            const float beta = this->beta_;
            // Source pointer used directly as the B matrix when no im2col
            // buffer is needed.
            const src_data_t *__restrict src_od
                    = src + od * jcp.oh * jcp.ow * jcp.ngroups * jcp.ic;
            // bf16 destination: gemm writes into the f32 acc buffer;
            // otherwise gemm writes straight into dst.
            const bool acc_needed = dst_data_type == data_type::bf16;
            status_t st = gemm_bf16bf16f32("N" , BT, &M, &N, &K, &onef, wei,
                    &LDA, jcp.im2col_sz ? col : (src_data_t *)src_od, &LDB,
                    &beta, acc_needed ? acc : (float *)dst,
                    acc_needed ? &M : &LDA);
            if (st != status::success) return st;

            const bool do_postprocess = pd()->is_postprocess_required();
            if (do_postprocess) {
                // NOTE(review): the nested parallel region gets 0 as thread
                // count only when jcp.nthr == 1 and 1 otherwise - presumably
                // because this function already runs inside a parallel
                // region when jcp.nthr != 1; confirm.
                parallel_nd_ext(jcp.nthr == 1 ? 0 : 1, N,
                        [&](size_t ithr, size_t nthr, size_t os) {
                            const float *__restrict acc_arr = acc + os * jcp.oc;
                            const float *__restrict bia_arr
                                    = (bia_base == nullptr)
                                    ? nullptr
                                    : bia_base + g * jcp.oc;
                            dst_data_t *__restrict dst_arr
                                    = dst + os * dst_os_stride;

                            // Post-process jcp.oc values of one output point.
                            (*pp_ker_)(dst_arr,
                                    acc_needed ? acc_arr : (float *)dst_arr,
                                    bia_arr, sum_scale, jcp.oc,
                                    post_ops_binary_rhs_arg_vec, dst_base,
                                    g * jcp.oc);
                        });
            }
        }
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh, owb, nb_ow);
    }
    return status::success;
}
522 | |
// Forward pass for the ncsp layout.
//
// Work items are (group, mb, od, spatial block) tuples; balance2D() splits
// them, together with the output-channel range, across jcp.nthr threads.
// For each item the oc and ic ranges are walked in blocks and inner_ker runs
// im2col (when needed), gemm_bf16bf16f32 and, after the last ic block,
// the post-processing kernel.
//
// Returns status::success, or a non-success status reported by a gemm call.
template <data_type_t dst_data_type>
status_t gemm_bf16_convolution_fwd_t<dst_data_type>::execute_forward_ncsp(
        const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(
                    this->pd()->attr()->post_ops_, ctx);

    bool is_bf16_dst = dst_data_type == data_type::bf16;

    auto col = ctx.get_scratchpad_grantor().template get<src_data_t>(
            key_conv_gemm_col);
    // An intermediate f32 accumulator is only used for a bf16 destination.
    acc_data_t *acc_base = is_bf16_dst
            ? ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    key_conv_int_dat_in_acc_dt)
            : nullptr;

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    // Bias is always consumed as float; a bf16 bias is first converted into
    // a scratchpad buffer.
    float *bias = nullptr;
    if (jcp.with_bias) {
        if (pd()->desc()->bias_desc.data_type == data_type::bf16) {
            auto bias_in = CTX_IN_MEM(const bfloat16_t *, DNNL_ARG_BIAS);
            bias = ctx.get_scratchpad_grantor().template get<float>(
                    key_conv_bias_bf16_convert_wsp);
            cvt_bfloat16_to_float(bias, bias_in, jcp.ngroups * jcp.oc);
        } else {
            auto bias_in = CTX_IN_MEM(const float *, DNNL_ARG_BIAS);
            bias = const_cast<float *>(bias_in);
        }
    }

    // The sum scale is taken from post-op entry 0 when that entry is a sum.
    const auto &post_ops = pd()->attr()->post_ops_;
    const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
    const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;

    // M is the number of output points per channel; the *_step values are
    // the element counts of one (mb, group) slice of src and dst.
    const dim_t M = jcp.os * jcp.od;
    const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = (size_t)jcp.oc * M;
    const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;
    const size_t weights_oc_size = jcp.ic * jcp.ks;

    const dim_t LDB = weights_oc_size;
    const dim_t work_amount
            = (size_t)jcp.ngroups * jcp.mb * jcp.od * jcp.os_nb_block;
    const bool is_problem_3d = pd()->ndims() == 5;
    std::atomic<status_t> st(status::success);

    // Processes one (ic block, oc block, spatial block) tile: im2col when
    // needed, GEMM, and post-processing once the last ic block is done.
    // A gemm failure is stored in st and the tile is abandoned.
    auto inner_ker = [&](const int ic, const int oc, const int groups,
                             const int od, const int spatial,
                             const src_data_t *src, const wei_data_t *weights,
                             src_data_t *col, dst_data_t *dst_im,
                             acc_data_t *acc, int ic_block, int oc_block) {
        // The last spatial block may be shorter than jcp.os_block.
        const dim_t os_block = nstl::min(
                (dim_t)jcp.os_block, (dim_t)jcp.os - spatial * jcp.os_block);

        if (jcp.im2col_sz) {
            if (!is_problem_3d) {
                jit_gemm_convolution_utils::im2col<src_data_t>(jcp, src, col,
                        spatial * jcp.os_block, os_block, ic, ic_block);
            } else {
                assert(jcp.ic_block == jcp.ic);
                jit_gemm_convolution_utils::im2col_3d<src_data_t>(
                        jcp, src, col, od, spatial * jcp.os_block, os_block);
            }
        }

        const acc_data_t one = 1.0;
        const dim_t N = oc_block;
        const dim_t K = ic_block * jcp.ks;
        const dim_t m = os_block;
        const dim_t LDA = jcp.im2col_sz ? m : M;
        const dim_t LDC = is_bf16_dst ? m : M;
        // The first ic block uses beta_; later ic blocks accumulate on top
        // of the previous result.
        const float beta = (ic == 0) ? this->beta_ : one;
        auto out_off = spatial * jcp.os_block + od * jcp.os;
        dst_data_t *dst_local = dst_im + out_off;

        status_t st_thr = gemm_bf16bf16f32("N" , "N" , &m, &N, &K, &one,
                jcp.im2col_sz ? col : src + ic * M + out_off, &LDA, weights,
                &LDB, &beta, acc, &LDC);

        if (st_thr != status::success) {
            st = st_thr;
            return;
        }

        // Post-process only after the last ic block has been accumulated.
        if (this->pd()->is_postprocess_required() && ic + ic_block >= jcp.ic) {
            size_t acc_str = LDC;
            size_t dst_str = M;
            float *bias_ptr = bias ? bias + groups * jcp.oc + oc : nullptr;
            (*pp_ker_)(dst_local, acc, bias_ptr, sum_scale, dst_str, acc_str, m,
                    oc_block, post_ops_binary_rhs_arg_vec.data(), dst,
                    groups * jcp.oc + oc);
        }
    };

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        // Per-thread slice of the im2col buffer.
        src_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
        if (is_problem_3d) {
            // jit_gemm_convolution_utils::im2col_3d() requires external
            // data initialization by zeroes
            for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
                _col[i] = (src_data_t)0;
        }
        dim_t g {0}, n {0}, od {0}, nb_os {0};
        dim_t start = 0, end = 0;
        dim_t oc_start = 0, oc_end = 0;

        assert(jcp.loop_order == gemm_loop_lbr);
        // Determine this thread's range of work items and output channels.
        balance2D(nthr, ithr, work_amount, start, end, jcp.oc, oc_start, oc_end,
                dim_t(jcp.nthr_oc));

        nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, od, jcp.od, nb_os,
                jcp.os_nb_block);
        for (dim_t iwork = start; iwork < end; ++iwork) {
            for_(dim_t oc = (dim_t)oc_start; oc < (dim_t)oc_end;
                    oc += jcp.oc_block)
            for (dim_t ic = 0; ic < jcp.ic; ic += jcp.ic_block) {
                const src_data_t *_src = src + (n * jcp.ngroups + g) * src_step;
                const wei_data_t *_weights = weights + g * weights_g_size
                        + oc * weights_oc_size + ic * jcp.ks;
                dst_data_t *_dst_im
                        = dst + (n * jcp.ngroups + g) * dst_step + oc * M;
                auto out_off = nb_os * jcp.os_block + od * jcp.os;
                dst_data_t *dst_local = _dst_im + out_off;
                // bf16 destination: use this thread's slice of the f32
                // accumulator, sized in multiples of 16 floats; otherwise
                // the GEMM output goes directly into dst.
                const dim_t sizeof_cacheline_float = 16;
                acc_data_t *_acc = is_bf16_dst ? acc_base
                                + ithr
                                        * rnd_up(jcp.oc_block * jcp.os_block,
                                                sizeof_cacheline_float)
                                               : (acc_data_t *)dst_local;

                // The last block in each range may be partial.
                const dim_t ic_block = nstl::min(jcp.ic - ic, jcp.ic_block);
                const dim_t oc_block
                        = nstl::min(dim_t(oc_end) - oc, jcp.oc_block);

                inner_ker(ic, oc, g, od, nb_os, _src, _weights, _col, _dst_im,
                        _acc, ic_block, oc_block);
            }
            nd_iterator_step(g, jcp.ngroups, n, jcp.mb, od, jcp.od, nb_os,
                    jcp.os_nb_block);
        }
    });

    return st;
}
671 | |
672 | template <data_type_t diff_src_data_type> |
673 | status_t gemm_bf16_convolution_bwd_data_t<diff_src_data_type>:: |
674 | execute_backward_data_nspc(const exec_ctx_t &ctx) const { |
675 | |
676 | auto diff_dst_base = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST); |
677 | auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); |
678 | auto diff_src_base = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC); |
679 | |
680 | auto scratchpad = ctx.get_scratchpad_grantor(); |
681 | const conv_gemm_conf_t &jcp = pd()->jcp_; |
682 | |
683 | std::atomic<status_t> st(status::success); |
684 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
685 | status_t st_thr = execute_backward_data_thr_nspc( |
686 | ithr, nthr, diff_src_base, wei_base, diff_dst_base, scratchpad); |
687 | if (st_thr != status::success) st = st_thr; |
688 | }); |
689 | |
690 | return st; |
691 | } |
692 | |
// Per-thread body of the nspc backward-data pass.
//
// Work items are (mb, group) pairs balanced across the nthr threads. For each
// item assigned to thread ithr the function calls gemm_bf16bf16f32, runs
// col2im when an im2col buffer is used, and then writes the f32 accumulator
// to diff_src (converted to bf16 when diff_src is bf16).
//
// Returns the status of the first failing gemm call, status::success
// otherwise.
template <data_type_t diff_src_data_type>
status_t gemm_bf16_convolution_bwd_data_t<
        diff_src_data_type>::execute_backward_data_thr_nspc(const int ithr,
        const int nthr, diff_src_data_t *diff_src_base,
        const wei_data_t *wei_base, const diff_dst_data_t *diff_dst_base,
        const memory_tracking::grantor_t &scratchpad) const {

    const conv_gemm_conf_t &jcp = pd()->jcp_;

    // Diff_dst Format: mb-spatial-groups-output_channels
    const size_t diff_dst_mb_stride = static_cast<size_t>(jcp.od) * jcp.oh
            * jcp.ow * jcp.ngroups * jcp.oc;
    const size_t diff_dst_g_stride = jcp.oc;

    // Wei Format: spatial-input_channels-groups-output_channels
    const size_t wei_g_stride = pd()->with_groups() ? jcp.oc : 0;

    // Diff_src Format: mb-spatial-groups-input_channels
    const size_t diff_src_mb_stride = static_cast<size_t>(jcp.id) * jcp.ih
            * jcp.iw * jcp.ngroups * jcp.ic;
    const size_t diff_src_g_stride = jcp.ic;
    const size_t diff_src_os_stride = jcp.ngroups * jcp.ic;

    // threads share work across mini-batch and groups
    const dim_t work_amount = jcp.ngroups * jcp.mb;

    // Per-thread slices of the scratchpad buffers.
    acc_data_t *__restrict col = scratchpad.get<acc_data_t>(key_conv_gemm_col)
            + (ptrdiff_t)ithr * jcp.im2col_sz;
    acc_data_t *__restrict acc = scratchpad.get<acc_data_t>(key_conv_gemm_acc)
            + (ptrdiff_t)ithr * jcp.is * jcp.id * jcp.ic;

    dim_t n {0}, g {0};
    dim_t start = 0, end = 0;

    balance211(work_amount, nthr, ithr, start, end);
    nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);

    for (dim_t iwork = start; iwork < end; ++iwork) {
        const diff_dst_data_t *__restrict diff_dst = diff_dst_base
                + n * diff_dst_mb_stride + g * diff_dst_g_stride;
        const wei_data_t *__restrict wei = wei_base + g * wei_g_stride;
        diff_src_data_t *__restrict diff_src = diff_src_base
                + n * diff_src_mb_stride + g * diff_src_g_stride;

        const dim_t M = jcp.ks * jcp.ic;
        const dim_t N = jcp.os * jcp.od;
        const dim_t K = jcp.oc;

        const float onef = 1.0f, zerof = 0.0f;
        const dim_t LD = K * jcp.ngroups;

        // The GEMM result goes to col when col2im is needed afterwards,
        // otherwise directly to acc.
        status_t st = gemm_bf16bf16f32("T" , "N" , &M, &N, &K, &onef, wei, &LD,
                diff_dst, &LD, &zerof, jcp.im2col_sz ? col : acc, &M);
        if (st != status::success) return st;

        if (jcp.im2col_sz)
            jit_gemm_convolution_utils::col2im_dt<acc_data_t>(jcp, col, acc);

        const bool is_diff_src_bf16 = diff_src_data_type == data_type::bf16;

        if (is_diff_src_bf16 && jcp.ngroups == 1 && jcp.nthr != 1) {
            // With a single group diff_src_os_stride equals jcp.ic, so the
            // whole buffer is converted with one call.
            cvt_float_to_bfloat16((bfloat16_t *)diff_src, (const float *)acc,
                    static_cast<size_t>(jcp.is) * jcp.id * jcp.ic);
        } else if (is_diff_src_bf16) {
            // Convert one spatial point (jcp.ic values) at a time.
            parallel_nd_ext(jcp.nthr == 1 ? 0 : 1,
                    static_cast<size_t>(jcp.is) * jcp.id,
                    [&](size_t ithr, size_t nthr, size_t is) {
                        diff_src_data_t *__restrict diff_src_loc
                                = diff_src + is * diff_src_os_stride;
                        const acc_data_t *__restrict acc_loc
                                = acc + is * jcp.ic;
                        cvt_float_to_bfloat16((bfloat16_t *)diff_src_loc,
                                (const float *)acc_loc, jcp.ic);
                    });
        } else {
            // f32 diff_src: plain copy of jcp.ic values per spatial point.
            assert(diff_src_data_type == data_type::f32);
            parallel_nd_ext(jcp.nthr == 1 ? 0 : 1,
                    static_cast<size_t>(jcp.is) * jcp.id,
                    [&](size_t ithr, size_t nthr, size_t is) {
                        diff_src_data_t *__restrict diff_src_loc
                                = diff_src + is * diff_src_os_stride;
                        const acc_data_t *__restrict acc_loc
                                = acc + is * jcp.ic;
                        PRAGMA_OMP_SIMD()
                        for (int ic = 0; ic < jcp.ic; ++ic)
                            diff_src_loc[ic] = acc_loc[ic];
                    });
        }
        nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
    }
    return status::success;
}
785 | |
786 | template <data_type_t diff_src_data_type> |
787 | status_t gemm_bf16_convolution_bwd_data_t<diff_src_data_type>:: |
788 | execute_backward_data_ncsp(const exec_ctx_t &ctx) const { |
789 | auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST); |
790 | auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); |
791 | auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC); |
792 | |
793 | auto col = ctx.get_scratchpad_grantor().template get<acc_data_t>( |
794 | key_conv_gemm_col); |
795 | acc_data_t *acc_base = diff_src_data_type == data_type::bf16 |
796 | ? ctx.get_scratchpad_grantor().template get<acc_data_t>( |
797 | key_conv_int_dat_in_acc_dt) |
798 | : nullptr; |
799 | |
800 | const conv_gemm_conf_t &jcp = this->pd()->jcp_; |
801 | |
802 | const dim_t M = jcp.os * jcp.od; |
803 | const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id; |
804 | const size_t dst_step = (size_t)jcp.oc * M; |
805 | const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks; |
806 | |
807 | const dim_t m = jcp.os_block; |
808 | const dim_t K = jcp.oc; |
809 | const dim_t N = jcp.ic * jcp.ks; |
810 | |
811 | const dim_t work_amount = (size_t)jcp.ngroups * jcp.mb; |
812 | const bool is_problem_3d = pd()->ndims() == 5; |
813 | |
814 | std::atomic<status_t> st(status::success); |
815 | |
816 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
817 | acc_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; |
818 | |
819 | dim_t g {0}, n {0}; |
820 | dim_t start = 0, end = 0; |
821 | balance211(work_amount, nthr, ithr, start, end); |
822 | nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb); |
823 | for (dim_t iwork = start; iwork < end; ++iwork) { |
824 | |
825 | diff_src_data_t *diff_src_local |
826 | = diff_src + (n * jcp.ngroups + g) * src_step; |
827 | acc_data_t *acc = diff_src_data_type == data_type::bf16 |
828 | ? acc_base + ithr * rnd_up(src_step, 16) |
829 | : (acc_data_t *)diff_src_local; |
830 | |
831 | if (is_problem_3d && jcp.im2col_sz > 0) { |
832 | // jit_gemm_convolution_utils::col2im_3d() assumes that the |
833 | // accumulator is initialized by zeroes |
834 | for (size_t i = 0; i < src_step; i++) |
835 | acc[i] = (acc_data_t)0; |
836 | } |
837 | |
838 | const wei_data_t *_weights = weights + g * weights_g_size; |
839 | for_(int od = 0; od < jcp.od; ++od) |
840 | for (int os_nb = 0; os_nb < jcp.os_nb_block; ++os_nb) { |
841 | auto out_off = os_nb * m + od * jcp.os; |
842 | const diff_dst_data_t *_diff_dst |
843 | = diff_dst + (n * jcp.ngroups + g) * dst_step + out_off; |
844 | const dim_t os_block |
845 | = nstl::min((dim_t)jcp.os_block, jcp.os - os_nb * m); |
846 | const dim_t LDC = jcp.im2col_sz ? os_block : M; |
847 | |
848 | const acc_data_t zero = 0.0, one = 1.0; |
849 | status_t st_thr = gemm_bf16bf16f32("N" , "T" , &os_block, &N, &K, |
850 | &one, _diff_dst, &M, _weights, &N, &zero, |
851 | jcp.im2col_sz ? _col : acc + out_off, &LDC); |
852 | |
853 | if (st_thr != status::success) { |
854 | st = st_thr; |
855 | return; |
856 | } |
857 | |
858 | if (jcp.im2col_sz) { |
859 | if (!is_problem_3d) |
860 | jit_gemm_convolution_utils::col2im( |
861 | jcp, _col, acc, os_nb * jcp.os_block, os_block); |
862 | else |
863 | jit_gemm_convolution_utils::col2im_3d(jcp, _col, acc, |
864 | od, os_nb * jcp.os_block, os_block); |
865 | } |
866 | } |
867 | if (diff_src_data_type == data_type::bf16) { |
868 | size_t spatial_size = (size_t)jcp.ih * jcp.iw * jcp.id; |
869 | store_bfloat16_in_parallel((bfloat16_t *)diff_src_local, |
870 | (const float *)acc, jcp.ic, spatial_size, |
871 | jcp.nthr == 1); |
872 | } |
873 | nd_iterator_step(g, jcp.ngroups, n, jcp.mb); |
874 | } |
875 | }); |
876 | |
877 | return st; |
878 | } |
879 | |
880 | template <data_type_t diff_wei_data_type> |
881 | void gemm_bf16_convolution_bwd_weights_t< |
882 | diff_wei_data_type>::bf16_bwd_weights_reduction_par_nspc(int ithr_mb, |
883 | int nthr_mb, size_t g_start, size_t g_end, const conv_gemm_conf_t &jcp, |
884 | const acc_data_t *weights_reduce_base, |
885 | diff_wei_data_t *weights_base) const { |
886 | assert(nthr_mb > 1); // no reduction for nthr_mb == 1 |
887 | |
888 | const bool is_bf16_out = diff_wei_data_type == data_type::bf16; |
889 | const dim_t weights_g_size = jcp.oc; |
890 | dim_t weights_start {0}, weights_end {0}; |
891 | balance211(jcp.ks * jcp.ic, nthr_mb, ithr_mb, weights_start, weights_end); |
892 | |
893 | for (auto tidx = 1; tidx < nthr_mb; ++tidx) { |
894 | const acc_data_t *ws_base |
895 | = weights_reduce_base + tidx * weights_g_size * jcp.ks * jcp.ic; |
896 | for_(auto w = weights_start; w < weights_end; ++w) |
897 | for (auto g = g_start; g < g_end; ++g) { |
898 | const acc_data_t *ws_ptr = ws_base + w * jcp.oc; |
899 | float *wei_reduced = is_bf16_out |
900 | ? (float *)weights_reduce_base + w * jcp.oc |
901 | : (float *)weights_base + (w * jcp.ngroups + g) * jcp.oc; |
902 | if (is_bf16_out && tidx == nthr_mb - 1) { |
903 | // the last iteration for bfloat16 requires conversion |
904 | // and store to diff_weights array |
905 | diff_wei_data_t *dwei_ptr |
906 | = weights_base + (w * jcp.ngroups + g) * jcp.oc; |
907 | add_floats_and_cvt_to_bfloat16( |
908 | (bfloat16_t *)(dwei_ptr), wei_reduced, ws_ptr, jcp.oc); |
909 | } else { |
910 | acc_ker_->accumulate(wei_reduced, ws_ptr, jcp.oc); |
911 | } |
912 | } |
913 | } |
914 | } |
915 | |
916 | template <data_type_t diff_wei_data_type> |
917 | void gemm_bf16_convolution_bwd_weights_t< |
918 | diff_wei_data_type>::bf16_bwd_weights_reduction_par_ncsp(int ithr_mb, |
919 | int nthr_mb, const conv_gemm_conf_t &jcp, |
920 | const acc_data_t *weights_reduce_base, |
921 | diff_wei_data_t *weights_base) const { |
922 | assert(nthr_mb > 1); // no reduction for nthr_mb == 1 |
923 | |
924 | const bool is_bf16_out = diff_wei_data_type == data_type::bf16; |
925 | const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks; |
926 | size_t weights_start {0}, weights_end {0}; |
927 | balance211(weights_g_size, nthr_mb, ithr_mb, weights_start, weights_end); |
928 | |
929 | if (weights_start >= weights_end) return; // nothing to do |
930 | |
931 | size_t acc_size = weights_end - weights_start; |
932 | float *wei_reduced = is_bf16_out |
933 | ? (float *)weights_reduce_base + weights_start |
934 | : (float *)weights_base + weights_start; |
935 | if (!is_bf16_out) { |
936 | // f32 diff_weights require initialization by weights_reduce |
937 | // for thr_mb = 0 |
938 | for (size_t i = 0; i < acc_size; i++) |
939 | wei_reduced[i] = ((float *)weights_reduce_base + weights_start)[i]; |
940 | } |
941 | |
942 | for (int thr_mb = 1; thr_mb < nthr_mb; ++thr_mb) { |
943 | float *wei_to_reduce = (float *)weights_reduce_base |
944 | + thr_mb * weights_g_size + weights_start; |
945 | |
946 | if (is_bf16_out && thr_mb == nthr_mb - 1) |
947 | // the last iteration for bfloat16 requires conversion |
948 | // and store to diff_weights array |
949 | add_floats_and_cvt_to_bfloat16( |
950 | (bfloat16_t *)(weights_base + weights_start), wei_reduced, |
951 | wei_to_reduce, acc_size); |
952 | else |
953 | acc_ker_->accumulate(wei_reduced, wei_to_reduce, acc_size); |
954 | } |
955 | } |
956 | |
957 | template <data_type_t diff_wei_data_type> |
958 | status_t gemm_bf16_convolution_bwd_weights_t<diff_wei_data_type>:: |
959 | execute_backward_weights_nspc(const exec_ctx_t &ctx) const { |
960 | auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST); |
961 | auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); |
962 | auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, DNNL_ARG_DIFF_WEIGHTS); |
963 | |
964 | auto col = ctx.get_scratchpad_grantor().template get<src_data_t>( |
965 | key_conv_gemm_col); |
966 | auto wei_reduction = ctx.get_scratchpad_grantor().template get<acc_data_t>( |
967 | key_conv_wei_reduction); |
968 | const conv_gemm_conf_t &jcp = this->pd()->jcp_; |
969 | |
970 | acc_data_t *acc_base = diff_wei_data_type == data_type::bf16 |
971 | ? ctx.get_scratchpad_grantor().template get<acc_data_t>( |
972 | key_conv_int_dat_in_acc_dt) |
973 | : (acc_data_t *)diff_weights; |
974 | |
975 | float *diff_bias = nullptr; |
976 | if (jcp.with_bias) { |
977 | if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16) |
978 | diff_bias = ctx.get_scratchpad_grantor().template get<float>( |
979 | key_conv_bias_bf16_convert_wsp); |
980 | else |
981 | diff_bias = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS); |
982 | } |
983 | |
984 | const dim_t K = jcp.os * static_cast<size_t>(jcp.od); |
985 | const size_t src_step |
986 | = static_cast<size_t>(jcp.ic) * jcp.ih * jcp.iw * jcp.id; |
987 | const size_t dst_step = jcp.oc * K; |
988 | const size_t weights_g_size = jcp.oc; |
989 | |
990 | const dim_t k = jcp.os; |
991 | const dim_t M = jcp.oc; |
992 | const dim_t N = static_cast<dim_t>(jcp.ic) * jcp.ks; |
993 | const dim_t LDB = jcp.ngroups * jcp.oc; |
994 | const dim_t LDA = jcp.im2col_sz ? jcp.oh * jcp.ow : jcp.ngroups * jcp.ic; |
995 | const bool is_problem_3d = pd()->ndims() == 5; |
996 | |
997 | std::atomic<status_t> st(status::success); |
998 | |
999 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
1000 | int ithr_g, nthr_g, ithr_mb, nthr_mb; |
1001 | size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0}; |
1002 | |
1003 | const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1; |
1004 | jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups, |
1005 | mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb); |
1006 | |
1007 | assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1)); |
1008 | |
1009 | const int need_reduction = nthr_mb != 1; |
1010 | src_data_t *__restrict imtr |
1011 | = ctx.get_scratchpad_grantor().template get<src_data_t>( |
1012 | key_conv_gemm_imtr) |
1013 | + (ptrdiff_t)ithr * jcp.id * jcp.ic * jcp.is; |
1014 | |
1015 | if (ithr_g != -1 && ithr_mb != -1) { |
1016 | balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end); |
1017 | balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end); |
1018 | |
1019 | assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0)); |
1020 | |
1021 | src_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; |
1022 | if (is_problem_3d) { |
1023 | // jit_gemm_convolution_utils::im2col_3d() requires external |
1024 | // data initialization by zeroes |
1025 | // For performance reasons use uint16_t as proxy for bfloat16_t |
1026 | uint16_t *__restrict _col_r |
1027 | = reinterpret_cast<uint16_t *__restrict>(_col); |
1028 | constexpr uint16_t zero_val = 0; |
1029 | |
1030 | PRAGMA_OMP_SIMD() |
1031 | for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++) |
1032 | _col_r[i] = zero_val; |
1033 | } |
1034 | |
1035 | acc_data_t *weights_reduce_base = wei_reduction |
1036 | + ithr_g * nthr_mb * weights_g_size * jcp.ks * jcp.ic; |
1037 | acc_data_t *weights_reduce = weights_reduce_base |
1038 | + ithr_mb * weights_g_size * jcp.ks * jcp.ic; |
1039 | |
1040 | const bool use_diff_wei |
1041 | = ithr_mb == 0 && diff_wei_data_type == data_type::f32; |
1042 | for (size_t g = g_start; g < g_end; ++g) { |
1043 | acc_data_t *_diff_weights = use_diff_wei |
1044 | ? (acc_data_t *)diff_weights + g * weights_g_size |
1045 | : need_reduction ? weights_reduce |
1046 | : acc_base + g * weights_g_size; |
1047 | const dim_t LDC = use_diff_wei |
1048 | ? jcp.ngroups * jcp.oc |
1049 | : need_reduction ? jcp.oc : jcp.ngroups * jcp.oc; |
1050 | for (size_t mb = mb_start; mb < mb_end; ++mb) { |
1051 | const src_data_t *_src |
1052 | = src + mb * jcp.ngroups * src_step + g * jcp.ic; |
1053 | if (jcp.im2col_sz && is_problem_3d) |
1054 | jit_gemm_convolution_utils::transpose_dt( |
1055 | jcp, _src, imtr); |
1056 | for (int od = 0; od < jcp.od; ++od) { |
1057 | const diff_dst_data_t *_diff_dst = diff_dst |
1058 | + mb * jcp.ngroups * dst_step |
1059 | + od * k * jcp.ngroups * jcp.oc + g * jcp.oc; |
1060 | |
1061 | if (jcp.im2col_sz) { |
1062 | if (is_problem_3d) |
1063 | jit_gemm_convolution_utils::im2col_dt_3d< |
1064 | src_data_t, src_data_t>( |
1065 | jcp, imtr, _col, od); |
1066 | else |
1067 | jit_gemm_convolution_utils::im2col_dt< |
1068 | src_data_t, src_data_t>(jcp, _src, imtr, |
1069 | _col, 0, jcp.oh, 0, jcp.ow); |
1070 | } |
1071 | const float zero = 0.0f, one = 1.0f; |
1072 | status_t st_thr = gemm_bf16bf16f32("N" , |
1073 | jcp.im2col_sz ? "N" : "T" , &M, &N, &k, &one, |
1074 | _diff_dst, &LDB, |
1075 | jcp.im2col_sz |
1076 | ? _col |
1077 | : _src + od * k * jcp.ngroups * jcp.ic, |
1078 | &LDA, mb == mb_start && od == 0 ? &zero : &one, |
1079 | _diff_weights, &LDC); |
1080 | if (st_thr != status::success) { |
1081 | st = st_thr; |
1082 | // Finish the loops early if failure occured. |
1083 | g = g_end; |
1084 | mb = mb_end; |
1085 | od = jcp.od; |
1086 | } |
1087 | } |
1088 | } |
1089 | } |
1090 | if (need_reduction && dnnl_thr_syncable()) { |
1091 | dnnl_thr_barrier(); |
1092 | if (st != status::success) return; |
1093 | bf16_bwd_weights_reduction_par_nspc(ithr_mb, nthr_mb, g_start, |
1094 | g_end, jcp, weights_reduce_base, diff_weights); |
1095 | } else if (diff_wei_data_type == data_type::bf16 |
1096 | && g_end > g_start) { |
1097 | cvt_acc_to_dst(jcp, g_start, g_end, (const float *)acc_base, |
1098 | (bfloat16_t *)diff_weights); |
1099 | } |
1100 | } else { |
1101 | if (need_reduction && dnnl_thr_syncable()) dnnl_thr_barrier(); |
1102 | } |
1103 | }); |
1104 | |
1105 | if (jcp.need_wei_reduction && !dnnl_thr_syncable()) { |
1106 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
1107 | int ithr_g, nthr_g, ithr_mb, nthr_mb; |
1108 | size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0}; |
1109 | |
1110 | const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1; |
1111 | jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, |
1112 | jcp.ngroups, mb_for_balance, ithr_g, nthr_g, ithr_mb, |
1113 | nthr_mb); |
1114 | |
1115 | assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1)); |
1116 | const int need_reduction = nthr_mb != 1; |
1117 | |
1118 | if (need_reduction && ithr_g != -1 && ithr_mb != -1) { |
1119 | balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end); |
1120 | balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end); |
1121 | |
1122 | assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0)); |
1123 | |
1124 | acc_data_t *weights_reduce_base = wei_reduction |
1125 | + ithr_g * nthr_mb * weights_g_size * jcp.ic * jcp.ks; |
1126 | |
1127 | bf16_bwd_weights_reduction_par_nspc(ithr_mb, nthr_mb, g_start, |
1128 | g_end, jcp, weights_reduce_base, diff_weights); |
1129 | } |
1130 | }); |
1131 | } |
1132 | |
1133 | if (jcp.with_bias) { |
1134 | parallel_nd(jcp.ngroups, jcp.oc, [&](dim_t g, dim_t oc) { |
1135 | acc_data_t db = 0; |
1136 | const dim_t offset_base = g * jcp.oc + oc; |
1137 | for_(dim_t mb = 0; mb < jcp.mb; ++mb) |
1138 | for_(dim_t od = 0; od < jcp.od; ++od) |
1139 | for (dim_t oh = 0; oh < jcp.oh; ++oh) { |
1140 | const dim_t width_stride = jcp.ngroups * jcp.oc; |
1141 | const diff_dst_data_t *__restrict diff_dst_arr = diff_dst |
1142 | + offset_base |
1143 | + ((mb * jcp.od + od) * jcp.oh + oh) * jcp.ow |
1144 | * width_stride; |
1145 | |
1146 | PRAGMA_OMP_SIMD(reduction(+ : db)) |
1147 | for (dim_t ow = 0; ow < jcp.ow; ++ow) { |
1148 | db += diff_dst_arr[ow * width_stride]; |
1149 | } |
1150 | } |
1151 | diff_bias[g * jcp.oc + oc] = db; |
1152 | }); |
1153 | |
1154 | if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16) { |
1155 | auto diff_bias_in = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DIFF_BIAS); |
1156 | cvt_float_to_bfloat16( |
1157 | diff_bias_in, diff_bias, jcp.ngroups * jcp.oc); |
1158 | } |
1159 | } |
1160 | return st; |
1161 | } |
1162 | |
// Backward-weights pass for the ncsp code path.
//
// Stage 1: threads are arranged by bwd_weights_balance() into a
//   (group, minibatch) grid. For every (od, os block) of its minibatch range
//   a thread accumulates, with gemm_bf16bf16f32, the src x diff_dst product
//   into its slab of the key_conv_wei_reduction scratchpad (reduction needed)
//   or into the f32 accumulator acc_base (no reduction).
// Stage 2: partial results are reduced by
//   bf16_bwd_weights_reduction_par_ncsp(), either behind a thread barrier
//   inside stage 1 (dnnl_thr_syncable()) or in a second parallel region.
//   Without reduction, bf16 output is produced by
//   store_bfloat16_in_parallel().
// Stage 3: diff_bias is the sum of diff_dst over minibatch and spatial
//   dimensions; for bf16 bias it is computed in f32 scratch and converted.
//
// Returns status::success, or a non-success status reported by a GEMM call;
// on failure the function returns right after stage 1.
template <data_type_t diff_wei_data_type>
status_t gemm_bf16_convolution_bwd_weights_t<diff_wei_data_type>::
        execute_backward_weights_ncsp(const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, DNNL_ARG_DIFF_WEIGHTS);

    auto col = ctx.get_scratchpad_grantor().template get<src_data_t>(
            key_conv_gemm_col);
    auto wei_reduction = ctx.get_scratchpad_grantor().template get<acc_data_t>(
            key_conv_wei_reduction);

    const conv_gemm_conf_t &jcp = this->pd()->jcp_;

    // bf16 output accumulates in f32 scratch; f32 output accumulates in place.
    acc_data_t *acc_base = diff_wei_data_type == data_type::bf16
            ? ctx.get_scratchpad_grantor().template get<acc_data_t>(
                    key_conv_int_dat_in_acc_dt)
            : (acc_data_t *)diff_weights;

    // bf16 bias is first computed in an f32 scratch buffer.
    float *diff_bias = nullptr;
    if (jcp.with_bias) {
        if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16)
            diff_bias = ctx.get_scratchpad_grantor().template get<float>(
                    key_conv_bias_bf16_convert_wsp);
        else
            diff_bias = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS);
    }

    // Element counts of one (minibatch, group) slice of each tensor.
    const dim_t K = jcp.os * jcp.od;
    const size_t src_step = (size_t)jcp.ic * jcp.ih * jcp.iw * jcp.id;
    const size_t dst_step = (size_t)jcp.oc * K;
    const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;

    // GEMM dimensions shared by all calls below.
    const dim_t k = jcp.os_block;
    const dim_t N = jcp.oc;
    const dim_t M = jcp.ic * jcp.ks;
    const bool is_problem_3d = pd()->ndims() == 5;

    // Shared between threads: any failing GEMM call stores its status here.
    std::atomic<status_t> st(status::success);
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int ithr_g, nthr_g, ithr_mb, nthr_mb;
        size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};

        const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
        jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups,
                mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb);

        assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
        const int need_reduction = nthr_mb != 1;

        if (ithr_g != -1 && ithr_mb != -1) {
            balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
            balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

            assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

            src_data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
            // non-blocked jit_gemm_convolution_utils::im2col_3d() requires
            // external data initialization by zeroes
            const bool outer_padding = jcp.os_nb_block == 1;
            if (outer_padding && is_problem_3d) {
                for (ptrdiff_t i = 0; i < jcp.im2col_sz; i++)
                    _col[i] = (src_data_t)0;
            }

            // Reduction slabs: one block per group-thread, one slab per
            // minibatch-thread inside it.
            acc_data_t *weights_reduce_base
                    = wei_reduction + ithr_g * nthr_mb * weights_g_size;
            acc_data_t *weights_reduce
                    = weights_reduce_base + ithr_mb * weights_g_size;

            for (size_t g = g_start; g < g_end; ++g) {
                acc_data_t *acc = need_reduction
                        ? weights_reduce
                        : (acc_base + g * weights_g_size);
                for (size_t mb = mb_start; mb < mb_end; ++mb) {
                    const src_data_t *_src
                            = src + (mb * jcp.ngroups + g) * src_step;
                    for_(int od = 0; od < jcp.od; ++od)
                    for (int os_nb = 0; os_nb < jcp.os_nb_block; ++os_nb) {
                        auto out_off = os_nb * k + od * jcp.os;
                        // The last block along os may be shorter than
                        // jcp.os_block.
                        const dim_t os_block = nstl::min(
                                (dim_t)jcp.os_block, jcp.os - os_nb * k);
                        const diff_dst_data_t *_diff_dst = diff_dst
                                + (mb * jcp.ngroups + g) * dst_step + out_off;

                        if (jcp.im2col_sz) {
                            if (!is_problem_3d)
                                jit_gemm_convolution_utils::im2col<src_data_t>(
                                        jcp, _src, _col, os_nb * jcp.os_block,
                                        os_block, 0, jcp.ic);
                            else
                                jit_gemm_convolution_utils::im2col_3d<
                                        src_data_t>(jcp, _src, _col, od,
                                        os_nb * jcp.os_block, os_block);
                        }

                        const dim_t LDA = jcp.im2col_sz ? os_block : K;
                        const acc_data_t zero = 0.0, one = 1.0;
                        // beta == 0 on the first call overwrites the target;
                        // later calls accumulate into it.
                        status_t st_thr = gemm_bf16bf16f32("T" , "N" , &M, &N,
                                &os_block, &one,
                                jcp.im2col_sz ? _col : _src + out_off, &LDA,
                                _diff_dst, &K,
                                mb == mb_start && os_nb == 0 && od == 0 ? &zero
                                                                        : &one,
                                acc, &M);

                        if (st_thr != status::success) {
                            st = st_thr;
                            // Finish the loops early if failure occurred.
                            // (No early return: the barrier below must still
                            // be reached by this thread.)
                            g = g_end;
                            mb = mb_end;
                            od = jcp.od;
                            os_nb = jcp.os_nb_block;
                        }
                    }
                }
            }
            if (need_reduction && dnnl_thr_syncable()) {
                dnnl_thr_barrier();
                if (st != status::success) return;
                diff_wei_data_t *weights_base
                        = diff_weights + g_start * weights_g_size;
                bf16_bwd_weights_reduction_par_ncsp(ithr_mb, nthr_mb, jcp,
                        weights_reduce_base, weights_base);
            } else if (diff_wei_data_type == data_type::bf16
                    && g_end > g_start) {
                // No reduction: convert this thread's groups from the f32
                // accumulator into the bf16 diff_weights.
                // NOTE(review): this redeclaration shadows the outer
                // weights_g_size with the same value.
                const size_t weights_g_size = (size_t)jcp.ic * jcp.oc * jcp.ks;
                const size_t work_size = (g_end - g_start) * weights_g_size;
                bfloat16_t *diff_weights_local
                        = (bfloat16_t *)diff_weights + g_start * weights_g_size;
                const float *acc_local
                        = (const float *)acc_base + g_start * weights_g_size;
                store_bfloat16_in_parallel(diff_weights_local, acc_local,
                        work_size, 1, jcp.nthr == 1);
            }
        } else {
            // Threads without work still have to take part in the barrier.
            if (need_reduction && dnnl_thr_syncable()) dnnl_thr_barrier();
        }
    });

    // Do not reduce or compute bias after a failed GEMM.
    if (st != status::success) return st;

    if (jcp.need_wei_reduction && !dnnl_thr_syncable()) {
        // Threads cannot synchronize inside a parallel region, so the
        // reduction runs as a separate region with the same work split.
        parallel(jcp.nthr, [&](const int ithr, const int nthr) {
            int ithr_g, nthr_g, ithr_mb, nthr_mb;
            size_t g_start {0}, g_end {0}, mb_start {0}, mb_end {0};

            const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
            jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr,
                    jcp.ngroups, mb_for_balance, ithr_g, nthr_g, ithr_mb,
                    nthr_mb);

            assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
            const int need_reduction = nthr_mb != 1;

            if (need_reduction && ithr_g != -1 && ithr_mb != -1) {
                balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
                balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);

                assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));

                acc_data_t *weights_reduce_base
                        = wei_reduction + ithr_g * nthr_mb * weights_g_size;

                diff_wei_data_t *weights_base
                        = diff_weights + g_start * weights_g_size;
                bf16_bwd_weights_reduction_par_ncsp(ithr_mb, nthr_mb, jcp,
                        weights_reduce_base, weights_base);
            }
        });
    }

    if (jcp.with_bias) {
        // diff_bias[g, oc] = sum of diff_dst over mb, od, oh, ow.
        parallel_nd(jcp.ngroups, jcp.oc, [&](size_t g, size_t oc) {
            acc_data_t db = 0;
            // Start of the (g, oc) plane in minibatch 0.
            dim_t offset_ = g * dst_step + oc * K;
            for (dim_t mb = 0; mb < jcp.mb; ++mb) {
                dim_t offset = offset_ + mb * jcp.ngroups * dst_step;
                for_(dim_t od = 0; od < jcp.od; ++od)
                for (dim_t oh = 0; oh < jcp.oh; ++oh) {
                    PRAGMA_OMP_SIMD(reduction(+ : db))
                    for (dim_t ow = 0; ow < jcp.ow; ++ow) {
                        db += diff_dst[offset];
                        offset++;
                    }
                }
            }
            diff_bias[g * jcp.oc + oc] = db;
        });

        if (pd()->desc()->diff_bias_desc.data_type == data_type::bf16) {
            // Convert the f32 scratch result into the user's bf16 bias.
            auto diff_bias_in = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DIFF_BIAS);
            cvt_float_to_bfloat16(
                    diff_bias_in, diff_bias, jcp.ngroups * jcp.oc);
        }
    }

    return st;
}
1362 | |
// Explicit instantiations: the member functions of these templates are
// defined in this translation unit, so each primitive is instantiated here
// for the f32 and bf16 values of its data-type template parameter.
template struct gemm_bf16_convolution_fwd_t<data_type::f32>;
template struct gemm_bf16_convolution_fwd_t<data_type::bf16>;
template struct gemm_bf16_convolution_bwd_data_t<data_type::f32>;
template struct gemm_bf16_convolution_bwd_data_t<data_type::bf16>;
template struct gemm_bf16_convolution_bwd_weights_t<data_type::f32>;
template struct gemm_bf16_convolution_bwd_weights_t<data_type::bf16>;
1369 | |
1370 | } // namespace x64 |
1371 | } // namespace cpu |
1372 | } // namespace impl |
1373 | } // namespace dnnl |
1374 | |