/*******************************************************************************
* Copyright 2019-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "common/bfloat16.hpp"
#include "cpu/x64/jit_avx512_core_bf16_convolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

using namespace nstl;

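// Weights offset helper: when the convolution has groups, the group index is
// passed as the leading dimension to blk_off(); otherwise it is dropped and
// the remaining indices are forwarded unchanged.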
#define wht_blk_off(d, g, ...) \
    (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
                         : (d).blk_off(__VA_ARGS__))

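// If the primitive descriptor requests a padded bias, copy the user bias into
// a scratchpad buffer sized to the padded OC, zero-fill the tail, and repoint
// `bias` at that buffer so the kernel can always read whole OC blocks.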
void jit_avx512_core_bf16_convolution_fwd_t::prepare_padded_bias(
        const char *&bias, const memory_tracking::grantor_t &scratchpad) const {
    if (!pd()->wants_padded_bias()) return;

    const size_t bia_dt_size = pd()->jcp_.typesize_bia;
    auto padded_bias = scratchpad.template get<char>(
            memory_tracking::names::key_conv_padded_bias);
    utils::array_copy(
            padded_bias, bias, bia_dt_size * pd()->jcp_.oc_without_padding);
    utils::array_set(padded_bias + bia_dt_size * pd()->jcp_.oc_without_padding,
            0.f, bia_dt_size * (pd()->jcp_.oc - pd()->jcp_.oc_without_padding));
    bias = padded_bias;
}

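// Forward 1D driver: the work amount is mb * groups * oc chunks * ow blocks;
// each thread walks its share in the configured loop order and calls the JIT
// kernel once per (n, g, ocb, owb) tile.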
void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_1d(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    prepare_padded_bias(bias, ctx.get_scratchpad_grantor());

    const size_t bia_dt_size = pd()->jcp_.typesize_bia;

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    // TODO: experiment with g_blocking for perf fine tuning
    int g_blocking = 1;
    int nb_groups = jcp.ngroups / g_blocking;
    dim_t work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow;

    int nthr = jcp.aligned_threads ? jcp.aligned_threads : jcp.nthr;
    parallel(nthr, [&](const int ithr, const int nthr) {
        dim_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);
        auto par_conv = jit_conv_call_s();

        int n {0}, gg {0}, occ {0}, owb {0};

        if (jcp.loop_order == loop_cwgn) {
            int dummy {0};
            nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
                    nb_groups, n, jcp.mb, dummy, 1);
        } else if (jcp.loop_order == loop_gncw) {
            int dummy {0};
            nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks,
                    owb, jcp.nb_ow, dummy, 1);
        } else if (jcp.loop_order == loop_nhwcg) {
            nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks,
                    gg, nb_groups);
        } else
            assert(!"unsupported loop order");

        while (start < end) {
            int ocb = occ * jcp.nb_oc_blocking;
            int g = gg * g_blocking;
            int ow_s = owb * jcp.ow_block;
            int iw_s = ow_s * jcp.stride_w;

            const bool is_dst_layout_nxc = jcp.dst_tag == format_tag::nwc;
            const int oc_idx = is_dst_layout_nxc
                    ? g * jcp.oc + ocb * jcp.oc_block
                    : g * jcp.nb_oc + ocb;
            auto dst_w
                    = dst + jcp.typesize_out * dst_d.blk_off(n, oc_idx, ow_s);
            auto bias_w = bias ? bias
                            + bia_dt_size * oc_idx
                                    * (is_dst_layout_nxc ? 1 : jcp.oc_block)
                               : nullptr;
            const bool is_src_layout_nxc = jcp.src_tag == format_tag::nwc;
            const int ic_idx = is_src_layout_nxc ? g * jcp.ic : g * jcp.nb_ic;
            auto src_w = src + src_d.blk_off(n, ic_idx, iw_s);
            auto wht_w = weights + wht_blk_off(weights_d, g, ocb);

            par_conv.load_work = this_block_size(ocb * jcp.oc_block,
                    jcp.oc_without_padding, jcp.nb_oc_blocking * jcp.oc_block);
            par_conv.src = src_w;
            par_conv.dst = dst_w;
            par_conv.filt = wht_w;
            par_conv.bias = bias_w;
            par_conv.owb = owb;

            par_conv.oc_l_off = oc_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block);
            par_conv.dst_orig = dst;
            par_conv.post_ops_binary_rhs_arg_vec
                    = post_ops_binary_rhs_arg_vec.data();
            (*kernel_)(&par_conv);

            if (jcp.loop_order == loop_cwgn) {
                int dummy {0};
                nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, gg,
                        nb_groups, n, jcp.mb, dummy, 1);
            } else if (jcp.loop_order == loop_gncw) {
                int dummy {0};
                nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ,
                        oc_chunks, owb, jcp.nb_ow, dummy, 1);
            } else if (jcp.loop_order == loop_nhwcg) {
                ++start;
                nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, gg,
                        nb_groups);
            } else
                assert(!"unsupported loop order");
        }
    });
}

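// Forward 2D driver: same decomposition as the 1D case plus the oh dimension;
// rows that fall into top/bottom padding are skipped by trimming kh_padding
// inside the output-row loop below.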
void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_2d(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    prepare_padded_bias(bias, ctx.get_scratchpad_grantor());

    const size_t bia_dt_size = pd()->jcp_.typesize_bia;

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    // TODO: experiment with g_blocking for perf fine tuning
    int g_blocking = 1;
    int nb_groups = jcp.ngroups / g_blocking;
    dim_t work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;

    int nthr = jcp.aligned_threads ? jcp.aligned_threads : jcp.nthr;
    parallel(nthr, [&](const int ithr, const int nthr) {
        dim_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);
        auto par_conv = jit_conv_call_s();

        size_t src_h_stride = src_d.blk_off(0, 0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

        int n {0}, gg {0}, occ {0}, oh_s {0}, owb {0};

        if (jcp.loop_order == loop_cwgn)
            nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
                    nb_groups, n, jcp.mb, oh_s, jcp.oh);
        else if (jcp.loop_order == loop_gncw)
            nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks,
                    owb, jcp.nb_ow, oh_s, jcp.oh);
        else if (jcp.loop_order == loop_nhwcg)
            nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
                    occ, oc_chunks, gg, nb_groups);
        else
            assert(!"unsupported loop order");

        while (start < end) {

            int ocb = occ * jcp.nb_oc_blocking;
            int g = gg * g_blocking;
            int work_rem = end - start;
            int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
            int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
            if (jcp.loop_order == loop_nhwcg) oh_e = oh_s + 1; //step instead

            int ow_s = owb * jcp.ow_block;
            int iw_s = ow_s * jcp.stride_w;

            const bool is_dst_layout_nxc = jcp.dst_tag == format_tag::nhwc;
            const int oc_idx = is_dst_layout_nxc
                    ? g * jcp.oc + ocb * jcp.oc_block
                    : g * jcp.nb_oc + ocb;
            auto dst_w = dst
                    + jcp.typesize_out * dst_d.blk_off(n, oc_idx, oh_s, ow_s);
            auto bias_w = bias ? bias
                            + bia_dt_size * oc_idx
                                    * (is_dst_layout_nxc ? 1 : jcp.oc_block)
                               : nullptr;
            const bool is_src_layout_nxc = jcp.src_tag == format_tag::nhwc;
            const int ic_idx = is_src_layout_nxc ? g * jcp.ic : g * jcp.nb_ic;
            auto src_w = src + src_d.blk_off(n, ic_idx, ih_s, iw_s);
            auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);

            for (int oj = oh_s, ij = ih_s; oj < oh_e;
                    ++oj, ij += jcp.stride_h) {
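                // i_t_overflow / i_b_overflow count the kernel rows that land
                // in the top / bottom padding for this output row; only the
                // remaining kh_padding rows are passed to the kernel, and the
                // src and weights pointers are advanced past the skipped rows.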
                int dilate_h = jcp.dilate_h + 1;
                int i_t_overflow = div_up(max(0, -ij), dilate_h);
                int i_b_overflow = div_up(
                        max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
                        dilate_h);
                int kh_padding
                        = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);
                auto aux_src = src_w + i_t_overflow * dilate_h * src_h_stride;
                auto aux_wht = wht_w + i_t_overflow * wht_h_stride;

                par_conv.load_work = utils::this_block_size(ocb * jcp.oc_block,
                        jcp.oc_without_padding,
                        jcp.nb_oc_blocking * jcp.oc_block);
                par_conv.src = aux_src;
                par_conv.dst = dst_w;
                par_conv.filt = aux_wht;
                par_conv.bias = bias_w;
                par_conv.kh_padding = kh_padding;
                par_conv.owb = owb;

                par_conv.oc_l_off
                        = oc_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block);
                par_conv.dst_orig = dst;
                par_conv.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                (*kernel_)(&par_conv);

                src_w += src_h_stride * jcp.stride_h;
                dst_w += jcp.typesize_out * dst_h_stride;
            }
            if (jcp.loop_order == loop_cwgn)
                nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, gg,
                        nb_groups, n, jcp.mb, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_gncw)
                nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ,
                        oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_nhwcg) {
                ++start;
                nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ,
                        oc_chunks, gg, nb_groups);
            } else
                assert(!"unsupported loop order");
        }
    });
}

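// Forward 3D driver: adds the od dimension; depth padding is handled once per
// tile (kd_padding), while height padding is handled per output row as in 2D.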
void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_3d(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    prepare_padded_bias(bias, ctx.get_scratchpad_grantor());

    const size_t bia_dt_size = pd()->jcp_.typesize_bia;

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    // TODO: experiment with g_blocking for perf fine tuning
    int g_blocking = 1;
    int nb_groups = jcp.ngroups / g_blocking;
    dim_t work_amount
            = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh * jcp.nb_ow;

    int nthr = jcp.aligned_threads ? jcp.aligned_threads : jcp.nthr;
    parallel(nthr, [&](const int ithr, const int nthr) {
        dim_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);
        auto par_conv = jit_conv_call_s();

        size_t src_d_stride = src_d.blk_off(0, 0, 1);
        size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
        size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);

        int n {0}, gg {0}, occ {0}, od_s {0}, oh_s {0}, owb {0};

        if (jcp.loop_order == loop_cwgn)
            nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
                    nb_groups, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
        else if (jcp.loop_order == loop_gncw)
            nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks,
                    owb, jcp.nb_ow, od_s, jcp.od, oh_s, jcp.oh);
        else if (jcp.loop_order == loop_nhwcg)
            nd_iterator_init(start, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh, owb,
                    jcp.nb_ow, occ, oc_chunks, gg, nb_groups);
        else
            assert(!"unsupported loop order");

        while (start < end) {

            int ocb = occ * jcp.nb_oc_blocking;
            int g = gg * g_blocking;
            int work_rem = end - start;
            int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
            int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
            if (jcp.loop_order == loop_nhwcg) oh_e = oh_s + 1; //step instead
            int ow_s = owb * jcp.ow_block;
            int iw_s = ow_s * jcp.stride_w;

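            // Depth padding: d_t_overflow / d_b_overflow are the kernel taps
            // falling into the front / back padding for this od_s, counted in
            // filter elements (hence the div_up by the dilation step).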
            int id_s = -jcp.f_pad + od_s * jcp.stride_d;
            int dilate_d = jcp.dilate_d + 1;
            int d_t_overflow = div_up(max(0, -id_s), dilate_d);
            int d_b_overflow = div_up(
                    max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1),
                    dilate_d);
            int kd_padding = nstl::max(0, jcp.kd - d_t_overflow - d_b_overflow);

            const bool is_dst_layout_nxc = jcp.dst_tag == format_tag::ndhwc;
            const int oc_idx = is_dst_layout_nxc
                    ? g * jcp.oc + ocb * jcp.oc_block
                    : g * jcp.nb_oc + ocb;
            auto dst_w = dst
                    + jcp.typesize_out
                            * dst_d.blk_off(n, oc_idx, od_s, oh_s, ow_s);
            auto bias_w = bias ? bias
                            + bia_dt_size * oc_idx
                                    * (is_dst_layout_nxc ? 1 : jcp.oc_block)
                               : nullptr;
            const bool is_src_layout_nxc = jcp.src_tag == format_tag::ndhwc;
            const int ic_idx = is_src_layout_nxc ? g * jcp.ic : g * jcp.nb_ic;
            auto src_w = src + src_d.blk_off(n, ic_idx, id_s, ih_s, iw_s)
                    + d_t_overflow * dilate_d * src_d_stride;
            auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0)
                    + d_t_overflow * wht_d_stride;

            for (int oj = oh_s, ij = ih_s; oj < oh_e;
                    ++oj, ij += jcp.stride_h) {
                int dilate_h = jcp.dilate_h + 1;
                int i_t_overflow = div_up(max(0, -ij), dilate_h);
                int i_b_overflow = div_up(
                        max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
                        dilate_h);
                int kh_padding
                        = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);
                auto aux_src = src_w + i_t_overflow * dilate_h * src_h_stride;
                auto aux_wht = wht_w + i_t_overflow * wht_h_stride;

                par_conv.load_work = utils::this_block_size(ocb * jcp.oc_block,
                        jcp.oc_without_padding,
                        jcp.nb_oc_blocking * jcp.oc_block);
                par_conv.src = aux_src;
                par_conv.dst = dst_w;
                par_conv.filt = aux_wht;
                par_conv.bias = bias_w;
                par_conv.kh_padding = kh_padding;
                par_conv.kd_padding = kd_padding;
                par_conv.owb = owb;

                par_conv.oc_l_off
                        = oc_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block);
                par_conv.dst_orig = dst;
                par_conv.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                (*kernel_)(&par_conv);

                src_w += src_h_stride * jcp.stride_h;
                dst_w += jcp.typesize_out * dst_h_stride;
            }
            if (jcp.loop_order == loop_cwgn)
                nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, gg,
                        nb_groups, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_gncw)
                nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ,
                        oc_chunks, owb, jcp.nb_ow, od_s, jcp.od, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_nhwcg) {
                ++start;
                nd_iterator_step(n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh, owb,
                        jcp.nb_ow, occ, oc_chunks, gg, nb_groups);
            } else
                assert(!"unsupported loop order");
        }
    });
}

void jit_avx512_core_bf16_convolution_bwd_data_t::execute_backward_data_3d(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = pd()->jcp_;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
        // TODO: experiment with g_blocking for perf fine tuning
        int g_blocking = 1;
        int nb_groups = jcp.ngroups / g_blocking;
        int work_amount = nb_groups * jcp.mb * ic_chunks * jcp.id * jcp.ih;
        balance211(work_amount, nthr, ithr, start, end);

        auto par_conv = jit_conv_call_s();

        size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1);
        size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);

        bool is_fast_path_d = jcp.dilate_d == 0 && jcp.stride_d == 1;
        bool is_fast_path_h = jcp.dilate_h == 0 && jcp.stride_h == 1;

        int n {0}, gg {0}, icc {0}, id_s {0}, ih_s {0};
        if (jcp.loop_order == loop_cgn)
            nd_iterator_init(start, icc, ic_chunks, gg, nb_groups, n, jcp.mb,
                    id_s, jcp.id, ih_s, jcp.ih);
        else if (jcp.loop_order == loop_gnc)
            nd_iterator_init(start, gg, nb_groups, n, jcp.mb, icc, ic_chunks,
                    id_s, jcp.id, ih_s, jcp.ih);
        else if (jcp.loop_order == loop_nhwcg)
            nd_iterator_init(start, n, jcp.mb, id_s, jcp.id, ih_s, jcp.ih, icc,
                    ic_chunks, gg, nb_groups);
        else
            assert(!"unsupported loop order");

        while (start < end) {
            int icb = icc * jcp.nb_ic_blocking;
            int g = gg * g_blocking;
            int work_rem = end - start;
            int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
            if (jcp.loop_order == loop_nhwcg) ih_e = ih_s + 1; //step instead

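            // Backward data iterates over diff_src rows. The three branches
            // below compute, for the current input depth id_s, how many kernel
            // taps overlap valid diff_dst (kd_len), the first contributing tap
            // (kd_lo) and the matching output depth od_s, for the unit-stride,
            // dilated and strided cases respectively.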
            int od_s = 0, kd_len = 0, kd_lo = 0;
            if (is_fast_path_d) { // dilate == 0 && stride == 1
                int d_t_overflow = max(0, jcp.kd - 1 - id_s - jcp.f_pad);
                int d_b_overflow
                        = max(0, jcp.kd - jcp.id + id_s - jcp.back_pad);
                kd_len = jcp.kd - d_t_overflow - d_b_overflow;
                kd_lo = d_b_overflow;
                od_s = id_s + jcp.f_pad - d_b_overflow;
            } else if (jcp.dilate_d != 0) { // stride == 1
                int dilate_d = jcp.dilate_d + 1;
                // Note: use div_up to account for "holes" in filter
                int d_t_overflow = div_up(
                        max(0, (jcp.kd - 1) * dilate_d - id_s - jcp.f_pad),
                        dilate_d);
                int d_b_overflow
                        = div_up(max(0,
                                         (jcp.kd - 1) * dilate_d + 1 - jcp.id
                                                 + id_s - jcp.back_pad),
                                dilate_d);
                kd_len = jcp.kd - d_t_overflow - d_b_overflow;
                kd_lo = d_b_overflow;
                od_s = id_s + jcp.f_pad - d_b_overflow * dilate_d;
            } else { // dilate == 0
                int d_t_overflow = max(
                        0, (jcp.kd - 1 - id_s - jcp.f_pad) / jcp.stride_d);
                int d_b_overflow = max(0,
                        (jcp.kd - jcp.id + id_s - jcp.back_pad) / jcp.stride_d);
                int overflow_kd_hi = jcp.kd - 1
                        - modulo(
                                jcp.id - 1 + jcp.back_pad - id_s, jcp.stride_d);
                int overflow_kd_lo = (id_s + jcp.f_pad) % jcp.stride_d;

                kd_len = (overflow_kd_hi - overflow_kd_lo) / jcp.stride_d + 1
                        - d_t_overflow - d_b_overflow;
                kd_lo = overflow_kd_lo + d_b_overflow * jcp.stride_d;
                od_s = (id_s + jcp.f_pad - kd_lo) / jcp.stride_d;
            }
            assert(kd_len >= 0);

            const bool is_dsrc_layout_nxc = jcp.src_tag == format_tag::ndhwc;
            const int ic_idx = is_dsrc_layout_nxc
                    ? g * jcp.ic + icb * jcp.ic_block
                    : g * jcp.nb_ic + icb;
            auto diff_src_w = diff_src
                    + jcp.typesize_out * diff_src_d.blk_off(n, ic_idx, id_s);
            const bool is_ddst_layout_nxc = jcp.dst_tag == format_tag::ndhwc;
            const int oc_idx = is_ddst_layout_nxc ? g * jcp.oc : g * jcp.nb_oc;
            auto diff_dst_w = diff_dst + diff_dst_d.blk_off(n, oc_idx, od_s);
            auto wht_w = weights + wht_blk_off(weights_d, g, 0, icb, kd_lo);

            for (int ij = ih_s; ij < ih_e; ++ij) {
                int oj, kh_len, kh_lo;
                if (is_fast_path_h) { // dilate == 0 && stride == 1
                    int i_t_overflow = max(0, jcp.kh - 1 - ij - jcp.t_pad);
                    int i_b_overflow = max(0, jcp.kh - jcp.ih + ij - jcp.b_pad);
                    kh_len = jcp.kh - i_t_overflow - i_b_overflow;
                    kh_lo = i_b_overflow;
                    oj = ij + jcp.t_pad - i_b_overflow;
                } else if (jcp.dilate_h != 0) { // stride == 1
                    int dilate_h = jcp.dilate_h + 1;
                    // Note: use div_up to account for "holes" in filter
                    int i_t_overflow = div_up(
                            max(0, (jcp.kh - 1) * dilate_h - ij - jcp.t_pad),
                            dilate_h);
                    int i_b_overflow
                            = div_up(max(0,
                                             (jcp.kh - 1) * dilate_h + 1
                                                     - jcp.ih + ij - jcp.b_pad),
                                    dilate_h);
                    kh_len = jcp.kh - i_t_overflow - i_b_overflow;
                    kh_lo = i_b_overflow;
                    oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
                } else { // dilate == 0
                    int i_t_overflow = max(
                            0, (jcp.kh - 1 - ij - jcp.t_pad) / jcp.stride_h);
                    int i_b_overflow = max(0,
                            (jcp.kh - jcp.ih + ij - jcp.b_pad) / jcp.stride_h);
                    int overflow_kh_hi = jcp.kh - 1
                            - modulo(jcp.ih - 1 + jcp.b_pad - ij, jcp.stride_h);
                    int overflow_kh_lo = (ij + jcp.t_pad) % jcp.stride_h;

                    kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
                            + 1 - i_t_overflow - i_b_overflow;
                    kh_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
                    oj = (ij + jcp.t_pad - kh_lo) / jcp.stride_h;
                }
                assert(kh_len >= 0);

                par_conv.load_work = utils::this_block_size(icb * jcp.ic_block,
                        jcp.ic, jcp.nb_ic_blocking * jcp.ic_block);
                par_conv.src = diff_src_w
                        + jcp.typesize_out * ij * diff_src_h_stride;
                par_conv.dst = diff_dst_w + oj * diff_dst_h_stride;
                par_conv.filt = wht_w + kh_lo * wht_h_stride;
                par_conv.kh_padding = kh_len;
                par_conv.kd_padding = kd_len;

                (*kernel_)(&par_conv);
            }

            if (jcp.loop_order == loop_cgn)
                nd_iterator_jump(start, end, icc, ic_chunks, gg, nb_groups, n,
                        jcp.mb, id_s, jcp.id, ih_s, jcp.ih);
            else if (jcp.loop_order == loop_gnc)
                nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, icc,
                        ic_chunks, id_s, jcp.id, ih_s, jcp.ih);
            else if (jcp.loop_order == loop_nhwcg) {
                ++start;
                nd_iterator_step(n, jcp.mb, id_s, jcp.id, ih_s, jcp.ih, icc,
                        ic_chunks, gg, nb_groups);
            } else
                assert(!"unsupported loop order");
        }
    });
}

void jit_avx512_core_bf16_convolution_bwd_data_t::execute_backward_data(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = pd()->jcp_;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
        // TODO: experiment with g_blocking for perf fine tuning
        int g_blocking = 1;
        int nb_groups = jcp.ngroups / g_blocking;
        int work_amount = nb_groups * jcp.mb * ic_chunks * jcp.ih * jcp.nb_iw;
        balance211(work_amount, nthr, ithr, start, end);

        auto par_conv = jit_conv_call_s();
        size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1);
        size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

        bool is_fast_path = jcp.dilate_h == 0 && jcp.stride_h == 1;

        int n {0}, gg {0}, icc {0}, ih_s {0}, iwb {0};
        if (jcp.loop_order == loop_cwgn)
            nd_iterator_init(start, icc, ic_chunks, iwb, jcp.nb_iw, gg,
                    nb_groups, n, jcp.mb, ih_s, jcp.ih);
        else if (jcp.loop_order == loop_gncw)
            nd_iterator_init(start, gg, nb_groups, n, jcp.mb, icc, ic_chunks,
                    iwb, jcp.nb_iw, ih_s, jcp.ih);
        else if (jcp.loop_order == loop_nhwcg)
            nd_iterator_init(start, n, jcp.mb, ih_s, jcp.ih, iwb, jcp.nb_iw,
                    icc, ic_chunks, gg, nb_groups);
        else
            assert(!"unsupported loop order");

        while (start < end) {
            int icb = icc * jcp.nb_ic_blocking;
            int g = gg * g_blocking;
            int work_rem = end - start;
            int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
            if (jcp.loop_order == loop_nhwcg) ih_e = ih_s + 1; //step instead
            int iw_s = iwb * jcp.iw_block;
            int ow_s = iw_s / jcp.stride_w;

            auto diff_src_w = diff_src;
            auto diff_dst_w = diff_dst;
            const bool is_ddst_layout_nxc = utils::one_of(
                    jcp.dst_tag, format_tag::nwc, format_tag::nhwc);
            const int oc_idx = is_ddst_layout_nxc ? g * jcp.oc : g * jcp.nb_oc;
            const bool is_dsrc_layout_nxc = utils::one_of(
                    jcp.src_tag, format_tag::nwc, format_tag::nhwc);
            const int ic_idx = is_dsrc_layout_nxc
                    ? g * jcp.ic + icb * jcp.ic_block
                    : g * jcp.nb_ic + icb;
            if (jcp.ndims == 3) {
                diff_src_w += jcp.typesize_out
                        * diff_src_d.blk_off(n, ic_idx, iw_s);
                diff_dst_w += diff_dst_d.blk_off(n, oc_idx, ow_s);
            } else {
                diff_src_w += jcp.typesize_out
                        * diff_src_d.blk_off(n, ic_idx, 0, iw_s);
                diff_dst_w += diff_dst_d.blk_off(n, oc_idx, 0, ow_s);
            }
            auto wht_w = weights + wht_blk_off(weights_d, g, 0, icb);

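            // For every diff_src row ij, find the kernel rows that map onto
            // valid diff_dst rows: k_lo is the first contributing tap, k_len
            // the number of taps, and oj the corresponding first output row.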
            for (int ij = ih_s; ij < ih_e; ++ij) {
                int oj, k_len, k_lo;
                if (is_fast_path) { // dilate == 0 && stride == 1
                    int i_t_overflow = max(0, jcp.kh - 1 - ij - jcp.t_pad);
                    int i_b_overflow = max(0, jcp.kh - jcp.ih + ij - jcp.b_pad);
                    k_len = jcp.kh - i_t_overflow - i_b_overflow;
                    k_lo = i_b_overflow;
                    oj = ij + jcp.t_pad - i_b_overflow;
                } else if (jcp.dilate_h != 0) { // stride == 1
                    int dilate_h = jcp.dilate_h + 1;
                    // Note: use div_up to account for "holes" in filter
                    int i_t_overflow = div_up(
                            max(0, (jcp.kh - 1) * dilate_h - ij - jcp.t_pad),
                            dilate_h);
                    int i_b_overflow
                            = div_up(max(0,
                                             (jcp.kh - 1) * dilate_h + 1
                                                     - jcp.ih + ij - jcp.b_pad),
                                    dilate_h);
                    k_len = jcp.kh - i_t_overflow - i_b_overflow;
                    k_lo = i_b_overflow;
                    oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
                } else { // dilate == 0
                    int i_t_overflow = max(
                            0, (jcp.kh - 1 - ij - jcp.t_pad) / jcp.stride_h);
                    int i_b_overflow = max(0,
                            (jcp.kh - jcp.ih + ij - jcp.b_pad) / jcp.stride_h);
                    int overflow_kh_hi = jcp.kh - 1
                            - modulo(jcp.ih - 1 + jcp.b_pad - ij, jcp.stride_h);
                    int overflow_kh_lo = (ij + jcp.t_pad) % jcp.stride_h;

                    k_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h + 1
                            - i_t_overflow - i_b_overflow;
                    k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
                    oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
                }
                assert(k_len >= 0);

                par_conv.load_work = utils::this_block_size(icb * jcp.ic_block,
                        jcp.ic, jcp.nb_ic_blocking * jcp.ic_block);
                par_conv.src = diff_src_w
                        + jcp.typesize_out * ij * diff_src_h_stride;
                par_conv.dst = diff_dst_w + oj * diff_dst_h_stride;
                par_conv.filt = wht_w + k_lo * wht_h_stride;
                par_conv.kh_padding = k_len;
                par_conv.iwb = iwb;

                (*kernel_)(&par_conv);
            }

            if (jcp.loop_order == loop_cwgn)
                nd_iterator_jump(start, end, icc, ic_chunks, iwb, jcp.nb_iw, gg,
                        nb_groups, n, jcp.mb, ih_s, jcp.ih);
            else if (jcp.loop_order == loop_gncw)
                nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, icc,
                        ic_chunks, iwb, jcp.nb_iw, ih_s, jcp.ih);
            else if (jcp.loop_order == loop_nhwcg) {
                ++start;
                nd_iterator_step(n, jcp.mb, ih_s, jcp.ih, iwb, jcp.nb_iw, icc,
                        ic_chunks, gg, nb_groups);
            } else
                assert(!"unsupported loop order");
        }
    });
}

status_t jit_avx512_core_bf16_convolution_bwd_weights_t::init(
        engine_t *engine) {
    const auto &j = pd()->jcp_;

    nthr_ = j.nthr;
    nthr_mb_ = j.nthr_mb;
    nthr_g_ = j.nthr_g;
    nthr_oc_b_ = j.nthr_oc_b;
    nthr_ic_b_ = j.nthr_ic_b;

    CHECK(safe_ptr_assign(
            kernel_, new jit_avx512_core_bf16_conv_bwd_weights_kernel_f32(j)));
    CHECK(kernel_->create_kernel());

    if (j.transpose_src) {
        CHECK(safe_ptr_assign(trans_kernel_, create_trans_src(&j)));
        CHECK(trans_kernel_->create_kernel());
    }
    if (j.transpose_dst) {
        CHECK(safe_ptr_assign(trans_dst_kernel_, create_trans_dst(&j)));
        CHECK(trans_dst_kernel_->create_kernel());
    }
    if (nthr_mb_ > 1) {
        CHECK(safe_ptr_assign(
                acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
        CHECK(acc_ker_->create_kernel());
    }

    return status::success;
}

struct jit_avx512_core_bf16_convolution_bwd_weights_t::thread_info_t {
    const src_data_t *src = nullptr;
    const diff_dst_data_t *diff_dst = nullptr;
    const void *diff_weights = nullptr;
    const void *diff_bias = nullptr;

    const memory_tracking::grantor_t scratchpad;

    src_data_t *tr_src = nullptr;
    diff_dst_data_t *tr_diff_dst = nullptr;
    simple_barrier::ctx_t *tr_src_bctx = nullptr;
    simple_barrier::ctx_t *tr_diff_dst_bctx = nullptr;

    float *wei_bia_reduction = nullptr;
    float *bia_reduction = nullptr;
    simple_barrier::ctx_t *wei_bia_reduction_bctx = nullptr;

    int ithr = 0;
    int ithr_ic_b = 0, ithr_oc_b = 0, ithr_g = 0, ithr_mb = 0;
    int ithr_but_oc = 0;
    int ithr_but_ic = 0;

    int img_start = 0, img_end = 0, img_work = 0;
    int g_start = 0, g_end = 0, g_work = 0;
    int oc_b_start = 0, oc_b_end = 0, oc_b_work = 0;
    int ic_b_start = 0, ic_b_end = 0, ic_b_work = 0;

    thread_info_t(const jit_avx512_core_bf16_convolution_bwd_weights_t *self,
            const exec_ctx_t &ctx, int ithr)
        : scratchpad(ctx.get_scratchpad_grantor()), ithr(ithr) {
        diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
        src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
        diff_weights = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_WEIGHTS);

        const auto &jcp = self->kernel_->jcp;
        diff_bias = self->pd()->with_bias()
                        && (jcp.oc_without_padding % jcp.oc_block != 0)
                        && self->pd()->jcp_.bia_dt == data_type::f32
                ? (void *)scratchpad.template get<float>(key_conv_padded_bias)
                : CTX_OUT_MEM(void *, DNNL_ARG_DIFF_BIAS);

        if (jcp.transpose_src) {
            tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);

            if (jcp.global_transpose)
                tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                        key_conv_tr_src_bctx);
        }
        if (jcp.transpose_dst) {
            tr_diff_dst = scratchpad.template get<diff_dst_data_t>(
                    key_conv_tr_diff_dst);

            if (jcp.global_transpose)
                tr_diff_dst_bctx
                        = scratchpad.template get<simple_barrier::ctx_t>(
                                key_conv_tr_diff_dst_bctx);
        }
        wei_bia_reduction
                = scratchpad.template get<float>(key_conv_wei_bia_reduction);
        bia_reduction = nullptr;
        if (jcp.with_bias) {
            const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block
                    * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
            const int num_wei_buffers = jcp.wei_dt == data_type::bf16
                    ? jcp.nthr_mb
                    : jcp.nthr_mb - 1;
            bia_reduction = wei_bia_reduction + wei_size * num_wei_buffers;
        }

        if (jcp.global_transpose)
            wei_bia_reduction_bctx
                    = scratchpad.template get<simple_barrier::ctx_t>(
                            key_conv_wei_bia_reduction_bctx);

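        // Decompose the flat thread id into (mb, g, oc_b, ic_b) coordinates,
        // with ic_b varying fastest; ithr_but_oc / ithr_but_ic identify the
        // thread within the subgrid that excludes the oc / ic dimension and
        // are used to index the transpose barriers.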
        ithr_ic_b = ithr % self->nthr_ic_b_;
        ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_;
        ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_;
        ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_;

        ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_
                + ithr_ic_b;

        ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_
                + ithr_oc_b;

        int work_amount = jcp.nthr_mb_work;
        /* reduction dimension */
        balance211(work_amount, self->nthr_mb_, ithr_mb, img_start, img_end);
        img_work = img_end - img_start;

        /* independent dimensions */
        balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end);
        g_work = g_end - g_start;

        balance211(
                jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start, oc_b_end);
        oc_b_work = oc_b_end - oc_b_start;

        balance211(
                jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start, ic_b_end);
        ic_b_work = ic_b_end - ic_b_start;
    }
};

size_t jit_avx512_core_bf16_convolution_bwd_weights_t::tr_src_buf_number(
        const thread_info_t *ti, int g, int ic) const {
    const jit_conv_conf_t &jcp = this->kernel_->jcp;
    return jcp.global_transpose
            ? ti->ithr_mb * jcp.nb_ic * jcp.ngroups + g * jcp.nb_ic + ic
            : ti->ithr;
}
size_t jit_avx512_core_bf16_convolution_bwd_weights_t::tr_diff_dst_buf_number(
        const thread_info_t *ti, int g, int oc) const {
    const jit_conv_conf_t &jcp = this->kernel_->jcp;
    return jcp.global_transpose
            ? ti->ithr_mb * jcp.nb_oc * jcp.ngroups + g * jcp.nb_oc + oc
            : ti->ithr;
}

void jit_avx512_core_bf16_convolution_bwd_weights_t::trans_src(
        src_data_t *tr_src, const src_data_t *src, int row_count) const {
    const jit_conv_conf_t &jcp = this->kernel_->jcp;
    const int pf_depth = 2;
    struct {
        const src_data_t *src;
        src_data_t *tr_src;
    } pf_circ_buf_src[pf_depth];

    assert(jcp.ic_block == 16 || jcp.is_1stconv);
    const int src_stride = jcp.iw * jcp.ic_block;
    const int tr_src_stride = jcp.tr_iw * jcp.ic_block;

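    // Two-deep software pipeline: each iteration enqueues the current row and
    // transposes the row submitted on the previous iteration, passing the
    // newest row as the prefetch hint for the transpose kernel.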
    for (int iwork = 0; iwork < row_count + pf_depth - 1; iwork++) {
        pf_circ_buf_src[iwork % pf_depth] = {src, tr_src};

        if (iwork >= pf_depth - 1) {
            int old_idx = (iwork - pf_depth + 1) % pf_depth;
            auto ctx = jit_trans_src_t::ctx_t();
            ctx.src = pf_circ_buf_src[old_idx].src;
            ctx.tr_src = pf_circ_buf_src[old_idx].tr_src;
            ctx.src_prf = src;
            ctx.tr_src_prf = tr_src;
            (*trans_kernel_)(&ctx);
        }
        src += src_stride;
        tr_src += tr_src_stride;
    }
}

void jit_avx512_core_bf16_convolution_bwd_weights_t::trans_src_nxc(
        src_data_t *tr_src, const src_data_t *src_base, int spatial_start,
        dim_t spatial_start_offset, int icb_start, dim_t chb_stride,
        int row_count) const {
    const jit_conv_conf_t &jcp = this->kernel_->jcp;
    const int src_stride = jcp.iw * jcp.ngroups * jcp.ic;
    const int tr_src_stride = jcp.tr_iw * jcp.ic_block;

    int work_rest = row_count;
    int max_spatial_work = jcp.id * jcp.ih;
    int sp_work = nstl::min(work_rest, max_spatial_work - spatial_start);
    const src_data_t *src = src_base + spatial_start_offset;
    int icb = 0;
    const int ic_tail_work = jcp.ic_tail ? jcp.ic_tail : jcp.ic_block;
    while (work_rest > 0) {
        for (int iwork = 0; iwork < sp_work; iwork++) {
            auto ctx = jit_trans_src_t::ctx_t();
            ctx.src = src;
            ctx.tr_src = tr_src;
            assert(icb_start + icb < jcp.nb_ic);
            ctx.ch_work = (icb_start + icb + 1) == jcp.nb_ic ? ic_tail_work
                                                             : jcp.ic_block;
            ctx.src_prf = nullptr;
            ctx.tr_src_prf = nullptr;
            (*trans_kernel_)(&ctx);
            src += src_stride;
            tr_src += tr_src_stride;
        }
        work_rest -= sp_work;
        sp_work = nstl::min(work_rest, max_spatial_work);
        icb++;
        src = src_base + icb * chb_stride;
    }
}

void jit_avx512_core_bf16_convolution_bwd_weights_t::trans_dst(
        diff_dst_data_t *tr_diff_dst, const diff_dst_data_t *diff_dst,
        int row_count) const {

    const jit_conv_conf_t &jcp = this->kernel_->jcp;
    const int pf_depth = 2;
    struct {
        const diff_dst_data_t *diff_dst;
        diff_dst_data_t *tr_diff_dst;
    } pf_circ_buf_dst[pf_depth];

    assert(jcp.ic_block == 16 || jcp.is_1stconv);
    const int diff_dst_stride = jcp.ow * jcp.oc_block;
    const int tr_diff_dst_stride = jcp.tr_ow * jcp.oc_block;

    for (int iwork = 0; iwork < row_count + pf_depth - 1; iwork++) {
        pf_circ_buf_dst[iwork % pf_depth]
                = {(diff_dst_data_t *)diff_dst, tr_diff_dst};

        if (iwork >= pf_depth - 1) {
            int old_idx = (iwork - pf_depth + 1) % pf_depth;
            auto ctx = jit_trans_dst_t::ctx_t();
            ctx.src = pf_circ_buf_dst[old_idx].diff_dst;
            ctx.tr_src = pf_circ_buf_dst[old_idx].tr_diff_dst;
            ctx.src_prf = diff_dst;
            ctx.tr_src_prf = tr_diff_dst;
            (*trans_dst_kernel_)(&ctx);
        }
        diff_dst += diff_dst_stride;
        tr_diff_dst += tr_diff_dst_stride;
    }
}

void jit_avx512_core_bf16_convolution_bwd_weights_t::trans_dst_nxc(
        diff_dst_data_t *tr_diff_dst, const diff_dst_data_t *diff_dst_base,
        int spatial_start, dim_t spatial_start_offset, int ocb_start,
        dim_t chb_stride, int row_count) const {
    const jit_conv_conf_t &jcp = this->kernel_->jcp;
    const int diff_dst_stride = jcp.ow * jcp.ngroups * jcp.oc;
    const int tr_diff_dst_stride = jcp.tr_ow * jcp.oc_block;
    int work_rest = row_count;
    int max_spatial_work = jcp.od * jcp.oh;
    int sp_work = nstl::min(work_rest, max_spatial_work - spatial_start);
    const diff_dst_data_t *diff_dst = diff_dst_base + spatial_start_offset;
    int ocb = 0;
    const int oc_tail_work = jcp.oc_tail ? jcp.oc_tail : jcp.oc_block;
    while (work_rest > 0) {
        for (int iwork = 0; iwork < sp_work; iwork++) {
            auto ctx = jit_trans_dst_t::ctx_t();
            ctx.src = diff_dst;
            ctx.tr_src = tr_diff_dst;
            assert(ocb_start + ocb < jcp.nb_oc);
            ctx.ch_work = (ocb_start + ocb + 1) == jcp.nb_oc ? oc_tail_work
                                                             : jcp.oc_block;
            ctx.src_prf = nullptr;
            ctx.tr_src_prf = nullptr;
            (*trans_dst_kernel_)(&ctx);
            diff_dst += diff_dst_stride;
            tr_diff_dst += tr_diff_dst_stride;
        }
        work_rest -= sp_work;
        sp_work = nstl::min(work_rest, max_spatial_work);
        ocb++;
        diff_dst = diff_dst_base + ocb * chb_stride;
    }
}

void jit_avx512_core_bf16_convolution_bwd_weights_t::compute_diff_weights_2d(
        const thread_info_t *ti) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
    const auto &jcp = kernel_->jcp;
    const bool is_src_layout_nxc = jcp.src_tag == format_tag::nhwc;
    const bool is_ddst_layout_nxc = jcp.dst_tag == format_tag::nhwc;

    const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
            * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
    const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
    const int optimal_hblock = jcp.spatial_blk_size;

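    // Pick the accumulation buffers: bf16 diff_weights always accumulate into
    // the f32 scratchpad, while for f32 diff_weights the first thread along
    // the reduction (mb) dimension writes directly into the user buffer and
    // the remaining threads use per-thread reduction buffers.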
    float *diff_wei;
    if (diff_weights_d.data_type() == data_type::bf16)
        diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size;
    else
        diff_wei = ti->ithr_mb == 0
                ? (float *)ti->diff_weights
                : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;

    float *diff_bias = nullptr;
    if (jcp.with_bias) {
        if (jcp.bia_dt == data_type::bf16)
            diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size;
        else
            diff_bias = ti->ithr_mb == 0
                    ? (float *)ti->diff_bias
                    : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size;
    }

    auto tr_diff_dst_off = [&](int g, int oc, int oj) {
        const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
        return tr_diff_dst_buf_number(ti, g, oc) * jcp.tr_diff_dst_buf_size
                + oj * tr_row_size;
    };

    int img {0}, oh_s {0};
    int start = ti->img_start;
    int end = ti->img_end;

    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);

    nd_iterator_init(start, img, jcp.mb, oh_s, jcp.oh);

    while (start < end) {
        auto p = jit_conv_call_s();
        int work_rem = end - start;
        const int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
        int ih_s = nstl::max(0, -jcp.t_pad + oh_s * jcp.stride_h);
        const int ih_e = nstl::min(
                jcp.ih, -jcp.t_pad + (oh_e - 1) * jcp.stride_h + ext_kh);

        auto tr_src_off = [&](int g, int ic, int ih_end, int ij) {
            const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
            // Aligned to buffer end to use guard elements
            return tr_src_buf_number(ti, g, ic) * jcp.tr_src_buf_size
                    + (jcp.ih - ih_end + ij) * tr_row_size;
        };

        if (jcp.global_transpose) {
            using simple_barrier::barrier;
            // TODO: try to call local transpositions just before jit kernel
            /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */
            if (jcp.transpose_src) {
                int j {0};
                const int work_amount
                        = ti->g_work * ti->ic_b_work * (ih_e - ih_s);
                int tr_start {0}, tr_end {0};
                balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, tr_start,
                        tr_end);

                int g {0}, ic_b {0};
                nd_iterator_init(tr_start, g, ti->g_work, ic_b, ti->ic_b_work,
                        j, ih_e - ih_s);

                if (nthr_oc_b_ > 1)
                    barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
                while (tr_start < tr_end) {
                    int g_ = g + ti->g_start;
                    int ic_b_ = ic_b + ti->ic_b_start;
                    int j_s = j + ih_s;
                    int j_e = j_s + nstl::min(tr_end - tr_start, ih_e - j_s);
                    const int ic_off_idx = is_src_layout_nxc
                            ? g_ * jcp.ic + ic_b_ * jcp.ic_block
                            : g_ * jcp.nb_ic + ic_b_;
                    const src_data_t *src
                            = &ti->src[src_d.blk_off(img, ic_off_idx, j_s)];
                    src_data_t *tr_src
                            = &ti->tr_src[tr_src_off(g_, ic_b_, ih_e, j_s)];

                    if (is_src_layout_nxc)
                        trans_src_nxc(tr_src, src, 0, 0, ic_b_, 0, j_e - j_s);
                    else
                        trans_src(tr_src, src, j_e - j_s);

                    nd_iterator_jump(tr_start, tr_end, g, ti->g_work, ic_b,
                            ti->ic_b_work, j, ih_e - ih_s);
                }
                if (nthr_oc_b_ > 1)
                    barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
            }
            if (jcp.transpose_dst) {
                int j {0};
                const int work_amount
                        = ti->g_work * ti->oc_b_work * (oh_e - oh_s);
                int tr_start {0}, tr_end {0};
                balance211(work_amount, nthr_ic_b_, ti->ithr_ic_b, tr_start,
                        tr_end);

                int g {0}, oc_b {0};
                nd_iterator_init(tr_start, g, ti->g_work, oc_b, ti->oc_b_work,
                        j, oh_e - oh_s);

                if (nthr_ic_b_ > 1)
                    barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
                while (tr_start < tr_end) {
                    int g_ = g + ti->g_start;
                    int oc_b_ = oc_b + ti->oc_b_start;
                    int j_s = j + oh_s;
                    int j_e = j_s + nstl::min(tr_end - tr_start, oh_e - j_s);
                    const int oc_off_idx = is_ddst_layout_nxc
                            ? g_ * jcp.oc + oc_b_ * jcp.oc_block
                            : g_ * jcp.nb_oc + oc_b_;
                    const diff_dst_data_t *diff_dst
                            = &ti->diff_dst[diff_dst_d.blk_off(
                                    img, oc_off_idx, j_s)];
                    diff_dst_data_t *tr_diff_dst
                            = &ti->tr_diff_dst[tr_diff_dst_off(g_, oc_b_, j_s)];

                    if (is_ddst_layout_nxc)
                        trans_dst_nxc(tr_diff_dst, diff_dst, 0, 0, oc_b_, 0,
                                j_e - j_s);
                    else
                        trans_dst(tr_diff_dst, diff_dst, j_e - j_s);

                    nd_iterator_jump(tr_start, tr_end, g, ti->g_work, oc_b,
                            ti->oc_b_work, j, oh_e - oh_s);
                }
                if (nthr_ic_b_ > 1)
                    barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
            }
        }
        int height_block = jcp.global_transpose ? oh_e - oh_s : optimal_hblock;
        int ic_b_step
                = jcp.uses_permw_transposition ? jcp.nb_ic_blocking_max : 1;
        int icb_work = ti->ic_b_end - ti->ic_b_start;
        if (ic_b_step > 1 && icb_work > ic_b_step && icb_work < 2 * ic_b_step)
            ic_b_step = utils::div_up(icb_work, 2);

        for_(int ohb_s = oh_s; ohb_s < oh_e; ohb_s += height_block)
        for_(int g = ti->g_start; g < ti->g_end; ++g)
        for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b)
        for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end;
                ic_b += ic_b_step) {
            const int ohb_e = nstl::min(ohb_s + height_block, oh_e);
            const int ihb_s = nstl::max(0, -jcp.t_pad + ohb_s * jcp.stride_h);
            const int ihb_e = nstl::min(
                    jcp.ih, -jcp.t_pad + (ohb_e - 1) * jcp.stride_h + ext_kh);
            assert(IMPLICATION(jcp.global_transpose,
                    oh_s == ohb_s && oh_e == ohb_e && ih_s == ihb_s
                            && ih_e == ihb_e));
            const int ic_off_idx = is_src_layout_nxc
                    ? g * jcp.ic + ic_b * jcp.ic_block
                    : g * jcp.nb_ic + ic_b;
            const int oc_off_idx = is_ddst_layout_nxc
                    ? g * jcp.oc + oc_b * jcp.oc_block
                    : g * jcp.nb_oc + oc_b;
            const int ic_to_compute = this_block_size(
                    ic_b * jcp.ic_block, jcp.ic, ic_b_step * jcp.ic_block);
            const int oc_to_compute = this_block_size(
                    oc_b * jcp.oc_block, jcp.oc, jcp.oc_block);

            if (jcp.transpose_src) {
                if (!jcp.global_transpose) {
                    const src_data_t *src
                            = (src_data_t *)&ti->src[src_d.blk_off(
                                    img, ic_off_idx, ihb_s)];
                    src_data_t *tr_src
                            = &ti->tr_src[tr_src_off(g, ic_b, ihb_e, ihb_s)];
                    if (is_src_layout_nxc)
                        trans_src_nxc(
                                tr_src, src, 0, 0, ic_b, 0, ihb_e - ihb_s);
                    else
                        trans_src(tr_src, src, ihb_e - ihb_s);
                    p.src = tr_src;
                } else {
                    p.src = &ti->tr_src[tr_src_off(g, ic_b, ihb_e, ihb_s)];
                }
            } else {
                p.src = &ti->src[src_d.blk_off(img, ic_off_idx, ihb_s)];
            }

            if (jcp.transpose_dst) {
                if (!jcp.global_transpose) {
                    const diff_dst_data_t *diff_dst
                            = &ti->diff_dst[diff_dst_d.blk_off(
                                    img, oc_off_idx, ohb_s)];
                    diff_dst_data_t *tr_diff_dst
                            = &ti->tr_diff_dst[tr_diff_dst_off(0, 0, 0)];
                    if (is_ddst_layout_nxc)
                        trans_dst_nxc(tr_diff_dst, diff_dst, 0, 0, oc_b, 0,
                                ohb_e - ohb_s);
                    else
                        trans_dst(tr_diff_dst, diff_dst, ohb_e - ohb_s);
                    p.dst = tr_diff_dst;
                } else {
                    p.dst = &ti->tr_diff_dst[tr_diff_dst_off(g, oc_b, ohb_s)];
                }
            } else {
                p.dst = &ti->diff_dst[diff_dst_d.blk_off(
                        img, oc_off_idx, ohb_s)];
            }

            p.filt = diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b);
            p.bias = diff_bias ? diff_bias + g * rnd_up(jcp.oc, jcp.oc_block)
                            + oc_b * jcp.oc_block
                               : nullptr;
            p.channel = (start == ti->img_start) && (ohb_s == oh_s);
            p.reduce_work = ic_to_compute;
            p.load_work = oc_to_compute;
            p.os_index_begin = ohb_s;
            p.os_index_end = ohb_e;
            p.flags = 0 | (ic_b == 0 ? FLAG_IC_FIRST : 0);
            assert(ohb_e <= jcp.oh);
            (*kernel_)(&p);
        }

        nd_iterator_jump(start, end, img, jcp.mb, oh_s, jcp.oh);
    }
}

void jit_avx512_core_bf16_convolution_bwd_weights_t::compute_diff_weights_3d(
        const thread_info_t *ti) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
    const auto &jcp = kernel_->jcp;
    const bool is_src_layout_nxc = jcp.src_tag == format_tag::ndhwc;
    const bool is_ddst_layout_nxc = jcp.dst_tag == format_tag::ndhwc;
    const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
            * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
    const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
    const int optimal_dblock = jcp.spatial_blk_size;

    float *diff_wei;
    if (diff_weights_d.data_type() == data_type::bf16)
        diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size;
    else
        diff_wei = ti->ithr_mb == 0
                ? (float *)ti->diff_weights
                : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;

    float *diff_bias = nullptr;
    if (jcp.with_bias) {
        if (jcp.bia_dt == data_type::bf16)
            diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size;
        else
            diff_bias = ti->ithr_mb == 0
                    ? (float *)ti->diff_bias
                    : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size;
    }

    auto tr_diff_dst_off_3d = [&](int g, int oc, int od) {
        assert(IMPLICATION(is_ddst_layout_nxc, jcp.transpose_dst));
        const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
        const size_t tr_3d_size = tr_row_size * jcp.oh;
        return tr_diff_dst_buf_number(ti, g, oc) * jcp.tr_diff_dst_buf_size
                + od * tr_3d_size;
    };
    int img {0}, od_s {0};
    int start = ti->img_start;
    int end = ti->img_end;

    int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);

    nd_iterator_init(start, img, jcp.mb, od_s, jcp.od);
    while (start < end) {
        auto p = jit_conv_call_s();
        int work_rem = end - start;
        const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
        int id_s = nstl::max(0, -jcp.f_pad + od_s * jcp.stride_d);
        const int id_e = nstl::min(
                jcp.id, -jcp.f_pad + (od_e - 1) * jcp.stride_d + ext_kd);

        auto tr_src_off_3d = [&](int g, int ic, int id_end, int id) {
            const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
            const size_t tr_3d_size = tr_row_size * jcp.ih;
            // Aligned to buffer end to use guard elements
            return tr_src_buf_number(ti, g, ic) * jcp.tr_src_buf_size
                    + (jcp.id - id_end + id) * tr_3d_size;
        };

        if (jcp.global_transpose) {
            using simple_barrier::barrier;
            // TODO: try to call local transpositions just before jit kernel
            /* tr_src[nb_ic][id][16][~iw~] <- src[nb_ic][id][iw][16] */
            if (jcp.transpose_src) {
                int d {0};

                const int work_amount
                        = ti->g_work * ti->ic_b_work * (id_e - id_s);

                int tr_start {0}, tr_end {0};
                balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, tr_start,
                        tr_end);

                int g {0}, ic_b {0};

                nd_iterator_init(tr_start, g, ti->g_work, ic_b, ti->ic_b_work,
                        d, id_e - id_s);

                if (nthr_oc_b_ > 1)
                    barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
                while (tr_start < tr_end) {
                    int g_ = g + ti->g_start;
                    int ic_b_ = ic_b + ti->ic_b_start;
                    int d_s = d + id_s;
                    int d_e = d_s + nstl::min(tr_end - tr_start, id_e - d_s);

                    const int ic_off_idx = is_src_layout_nxc
                            ? g_ * jcp.ic + ic_b_ * jcp.ic_block
                            : g_ * jcp.nb_ic + ic_b_;
                    const src_data_t *src
                            = &ti->src[src_d.blk_off(img, ic_off_idx, d_s)];
                    src_data_t *tr_src
                            = &ti->tr_src[tr_src_off_3d(g_, ic_b_, id_e, d_s)];

                    if (is_src_layout_nxc)
                        trans_src_nxc(tr_src, src, 0, 0, ic_b_, 0,
                                (d_e - d_s) * jcp.ih);
                    else
                        trans_src(tr_src, src, (d_e - d_s) * jcp.ih);

                    nd_iterator_jump(tr_start, tr_end, g, ti->g_work, ic_b,
                            ti->ic_b_work, d, id_e - id_s);
                }
                if (nthr_oc_b_ > 1)
                    barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
            }
            if (jcp.transpose_dst) {
                int d {0};

                const int work_amount
                        = ti->g_work * ti->oc_b_work * (od_e - od_s);

                int tr_start {0}, tr_end {0};
                balance211(work_amount, nthr_ic_b_, ti->ithr_ic_b, tr_start,
                        tr_end);

                int g {0}, oc_b {0};

                nd_iterator_init(tr_start, g, ti->g_work, oc_b, ti->oc_b_work,
                        d, od_e - od_s);

                if (nthr_ic_b_ > 1)
                    barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
                while (tr_start < tr_end) {
                    int g_ = g + ti->g_start;
                    int oc_b_ = oc_b + ti->oc_b_start;
                    int d_s = d + od_s;
                    int d_e = d_s + nstl::min(tr_end - tr_start, od_e - d_s);
                    const int oc_off_idx = is_ddst_layout_nxc
                            ? g_ * jcp.oc + oc_b_ * jcp.oc_block
                            : g_ * jcp.nb_oc + oc_b_;

                    const diff_dst_data_t *diff_dst
                            = &ti->diff_dst[diff_dst_d.blk_off(
                                    img, oc_off_idx, d_s)];
                    diff_dst_data_t *tr_diff_dst
                            = &ti->tr_diff_dst[tr_diff_dst_off_3d(
                                    g_, oc_b_, d_s)];

                    if (is_ddst_layout_nxc)
                        trans_dst_nxc(tr_diff_dst, diff_dst, 0, 0, oc_b_, 0,
                                (d_e - d_s) * jcp.oh);
                    else
                        trans_dst(tr_diff_dst, diff_dst, (d_e - d_s) * jcp.oh);

                    nd_iterator_jump(tr_start, tr_end, g, ti->g_work, oc_b,
                            ti->oc_b_work, d, od_e - od_s);
                }
                if (nthr_ic_b_ > 1)
                    barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
            }
        }

        int depth_block = jcp.global_transpose ? od_e - od_s : optimal_dblock;

        for_(int odb_s = od_s; odb_s < od_e; odb_s += depth_block)
        for_(int g = ti->g_start; g < ti->g_end; ++g)
        for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b)
        for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
            const int odb_e = nstl::min(odb_s + depth_block, od_e);
            const int idb_s = nstl::max(0, -jcp.f_pad + odb_s * jcp.stride_d);
            const int idb_e = nstl::min(
                    jcp.id, -jcp.f_pad + (odb_e - 1) * jcp.stride_d + ext_kd);
            const int kdb_front_pad
                    = nstl::max(0, jcp.f_pad - odb_s * jcp.stride_d);
            // Assumes kd_back_pad = 0 when kernel is dilated
            const int kdb_back_pad = nstl::max(
                    0, odb_s * jcp.stride_d + jcp.kd - jcp.f_pad - jcp.id);
            const int kdb_pad_off = nstl::min(jcp.kd - 1, kdb_front_pad)
                    * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block
                    * jcp.typesize_out;

            assert(IMPLICATION(jcp.global_transpose,
                    od_s == odb_s && od_e == odb_e && id_s == idb_s
                            && id_e == idb_e));
            const int ic_off_idx = is_src_layout_nxc
                    ? g * jcp.ic + ic_b * jcp.ic_block
                    : g * jcp.nb_ic + ic_b;
            const int oc_off_idx = is_ddst_layout_nxc
                    ? g * jcp.oc + oc_b * jcp.oc_block
                    : g * jcp.nb_oc + oc_b;
            const int ic_to_compute = this_block_size(
                    ic_b * jcp.ic_block, jcp.ic, jcp.ic_block);
            const int oc_to_compute = this_block_size(
                    oc_b * jcp.oc_block, jcp.oc, jcp.oc_block);
            if (jcp.transpose_src) {
                if (!jcp.global_transpose) {
                    const src_data_t *src
                            = (src_data_t *)&ti->src[src_d.blk_off(
                                    img, ic_off_idx, idb_s)];
                    src_data_t *tr_src
                            = &ti->tr_src[tr_src_off_3d(g, ic_b, idb_e, idb_s)];
                    if (is_src_layout_nxc)
                        trans_src_nxc(tr_src, src, 0, 0, ic_b, 0,
                                (idb_e - idb_s) * jcp.ih);
                    else
                        trans_src(tr_src, src, (idb_e - idb_s) * jcp.ih);
                    p.src = tr_src;
                } else {
                    p.src = &ti->tr_src[tr_src_off_3d(g, ic_b, idb_e, idb_s)];
                }
            } else {
                p.src = &ti->src[src_d.blk_off(img, ic_off_idx, idb_s)];
            }

            if (jcp.transpose_dst) {
                if (!jcp.global_transpose) {
                    const diff_dst_data_t *diff_dst
                            = &ti->diff_dst[diff_dst_d.blk_off(
                                    img, oc_off_idx, odb_s)];
                    diff_dst_data_t *tr_diff_dst
                            = &ti->tr_diff_dst[tr_diff_dst_off_3d(0, 0, 0)];
                    if (is_ddst_layout_nxc)
                        trans_dst_nxc(tr_diff_dst, diff_dst, 0, 0, oc_b, 0,
                                (odb_e - odb_s) * jcp.oh);
                    else
                        trans_dst(tr_diff_dst, diff_dst,
                                (odb_e - odb_s) * jcp.oh);
                    p.dst = tr_diff_dst;
                } else {
                    p.dst = &ti->tr_diff_dst[tr_diff_dst_off_3d(
                            g, oc_b, odb_s)];
                }
            } else {
                p.dst = &ti->diff_dst[diff_dst_d.blk_off(
                        img, oc_off_idx, odb_s)];
            }

            p.filt = diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b);
            p.bias = diff_bias ? diff_bias + g * rnd_up(jcp.oc, jcp.oc_block)
                            + oc_b * jcp.oc_block
                               : nullptr;
            p.channel = (start == ti->img_start) && (odb_s == od_s);
            p.reduce_work = ic_to_compute;
            p.load_work = oc_to_compute;
            p.os_index_begin = odb_s;
            p.os_index_end = odb_e;
            p.kd_padding = jcp.kd - kdb_front_pad - kdb_back_pad;
            p.kd_offset = kdb_pad_off;
            p.flags = 0 | (ic_b == 0 ? FLAG_IC_FIRST : 0);
            assert(odb_e <= jcp.od);
            (*kernel_)(&p);
        }

        nd_iterator_jump(start, end, img, jcp.mb, od_s, jcp.od);
    }
}

1487void jit_avx512_core_bf16_convolution_bwd_weights_t ::compute_diff_weights(
1488 const thread_info_t *ti) const {
1489 const memory_desc_wrapper src_d(pd()->src_md());
1490 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
1491 const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
1492 const auto &jcp = kernel_->jcp;
1493 const bool is_src_layout_nxc = utils::one_of(
1494 jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
1495 const bool is_ddst_layout_nxc = utils::one_of(
1496 jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
1497 const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
1498 * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
1499 const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
1500
1501 float *diff_wei;
1502 if (diff_weights_d.data_type() == data_type::bf16)
1503 diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size;
1504 else
1505 diff_wei = ti->ithr_mb == 0
1506 ? (float *)ti->diff_weights
1507 : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
1508
1509 float *diff_bias = nullptr;
1510 if (jcp.with_bias) {
1511 if (jcp.bia_dt == data_type::bf16)
1512 diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size;
1513 else
1514 diff_bias = ti->ithr_mb == 0
1515 ? (float *)ti->diff_bias
1516 : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size;
1517 }
1518
    auto tr_src_off = [&](int g, int ic, int ij) {
        assert(IMPLICATION(is_src_layout_nxc, jcp.transpose_src));
        const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
        return tr_src_buf_number(ti, g, ic) * jcp.tr_src_buf_size
                + ij * tr_row_size;
    };

    auto tr_src_off_3d = [&](int g, int ic, int id, int ij) {
        assert(IMPLICATION(is_src_layout_nxc, jcp.transpose_src));
        const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
        const size_t tr_3d_size = tr_row_size * jcp.ih;
        return tr_src_buf_number(ti, g, ic) * jcp.tr_src_buf_size
                + id * tr_3d_size + ij * tr_row_size;
    };

    auto tr_diff_dst_off = [&](int g, int oc, int oj) {
        const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
        return tr_diff_dst_buf_number(ti, g, oc) * jcp.tr_diff_dst_buf_size
                + oj * tr_row_size;
    };

    auto tr_diff_dst_off_3d = [&](int g, int oc, int od, int oj) {
        const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
        const size_t tr_3d_size = tr_row_size * jcp.oh;
        return tr_diff_dst_buf_number(ti, g, oc) * jcp.tr_diff_dst_buf_size
                + od * tr_3d_size + oj * tr_row_size;
    };

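    // Repack a source block into the tr_src buffer. With global_transpose the
    // rows are shared among the nthr_oc_b_ threads that use this buffer;
    // otherwise the calling thread transposes only its own (g, ic_b) block.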
    auto uker_trans = [&](int img, int g = 0, int ic_b = 0) {
        int j {0}, d {0};
        int my_work = jcp.ih * jcp.id;
        int ic;
        int icb_start = ic_b;
        if (jcp.global_transpose) {
            const int work_amount = is_src_layout_nxc
                    ? ti->ic_b_work * jcp.ih * jcp.id
                    : ti->g_work * ti->ic_b_work * jcp.ih * jcp.id;

            int start {0}, end {0};
            balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, start, end);
            my_work = end - start;

            if (is_src_layout_nxc) {
                if (jcp.ndims == 5)
                    nd_iterator_init(
                            start, ic_b, ti->ic_b_work, d, jcp.id, j, jcp.ih);
                else
                    nd_iterator_init(start, ic_b, ti->ic_b_work, j, jcp.ih);
            } else {
                if (jcp.ndims == 5)
                    nd_iterator_init(start, g, ti->g_work, ic_b, ti->ic_b_work,
                            d, jcp.id, j, jcp.ih);
                else
                    nd_iterator_init(start, g, ti->g_work, ic_b, ti->ic_b_work,
                            j, jcp.ih);
            }
            g += ti->g_start;
            ic_b += ti->ic_b_start;
            icb_start = ic_b;
            ic = is_src_layout_nxc ? g * jcp.ic + ic_b * jcp.ic_block
                                   : g * jcp.nb_ic + ic_b;
        } else {
            ic = is_src_layout_nxc ? g * jcp.ic + ic_b * jcp.ic_block
                                   : g * jcp.nb_ic + ic_b;
            g = 0;
            ic_b = 0;
        }
        const bool need_local_gwork = is_src_layout_nxc && jcp.global_transpose;
        const auto local_gwork = need_local_gwork ? ti->g_work : 1;

        for (int gg = g; gg < g + local_gwork; ++gg) {
            if (need_local_gwork) ic = gg * jcp.ic + ic_b * jcp.ic_block;
            src_data_t *tr_src = (jcp.ndims == 5)
                    ? &ti->tr_src[tr_src_off_3d(gg, ic_b, d, j)]
                    : &ti->tr_src[tr_src_off(gg, ic_b, j)];
            auto src_offset = is_src_layout_nxc
                    ? src_d.blk_off(img, ic)
                    : (jcp.ndims == 5 ? src_d.blk_off(img, ic, d, j)
                                      : src_d.blk_off(img, ic, j));
            src_data_t *src = (src_data_t *)&ti->src[src_offset];

            if (is_src_layout_nxc) {
                dim_t sp_start_offset = (jcp.ndims == 5)
                        ? src_d.blk_off(0, 0, d, j)
                        : src_d.blk_off(0, 0, j);
                dim_t ch_shift = src_d.blk_off(0, jcp.ic_block);
                int sp_start_idx = d * jcp.ih + j;
                trans_src_nxc(tr_src, src, sp_start_idx, sp_start_offset,
                        icb_start, ch_shift, my_work);
            } else
                trans_src(tr_src, src, my_work);
        }
    };

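    // Same as uker_trans, but for diff_dst: repack into tr_diff_dst, sharing
    // the work across nthr_ic_b_ threads when global_transpose is set.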
    auto diff_dst_trans = [&](int img, int g = 0, int oc_b = 0) {
        int j {0}, d {0};
        int my_work = jcp.oh * jcp.od;
        int oc;
        int ocb_start = oc_b;

        if (jcp.global_transpose) {
            const size_t work_amount = is_ddst_layout_nxc
                    ? ti->oc_b_work * jcp.oh * jcp.od
                    : ti->g_work * ti->oc_b_work * jcp.oh * jcp.od;

            size_t start {0}, end {0};
            balance211(work_amount, nthr_ic_b_, ti->ithr_ic_b, start, end);
            my_work = end - start;

            if (is_ddst_layout_nxc) {
                if (jcp.ndims == 5)
                    nd_iterator_init(
                            start, oc_b, ti->oc_b_work, d, jcp.od, j, jcp.oh);
                else
                    nd_iterator_init(start, oc_b, ti->oc_b_work, j, jcp.oh);
            } else {
                if (jcp.ndims == 5)
                    nd_iterator_init(start, g, ti->g_work, oc_b, ti->oc_b_work,
                            d, jcp.od, j, jcp.oh);
                else
                    nd_iterator_init(start, g, ti->g_work, oc_b, ti->oc_b_work,
                            j, jcp.oh);
            }
            g += ti->g_start;
            oc_b += ti->oc_b_start;
            ocb_start = oc_b;
            oc = is_ddst_layout_nxc ? g * jcp.oc + oc_b * jcp.oc_block
                                    : g * jcp.nb_oc + oc_b;
        } else {
            oc = is_ddst_layout_nxc ? g * jcp.oc + oc_b * jcp.oc_block
                                    : g * jcp.nb_oc + oc_b;
            g = 0;
            oc_b = 0;
        }
        const bool need_local_gwork
                = is_ddst_layout_nxc && jcp.global_transpose;
        const auto local_gwork = need_local_gwork ? ti->g_work : 1;

        for (int gg = g; gg < g + local_gwork; ++gg) {
            if (need_local_gwork) oc = gg * jcp.oc + oc_b * jcp.oc_block;
            diff_dst_data_t *tr_diff_dst = (jcp.ndims == 5)
                    ? &ti->tr_diff_dst[tr_diff_dst_off_3d(gg, oc_b, d, j)]
                    : &ti->tr_diff_dst[tr_diff_dst_off(gg, oc_b, j)];
            auto ddst_offset = is_ddst_layout_nxc
                    ? diff_dst_d.blk_off(img, oc)
                    : (jcp.ndims == 5 ? diff_dst_d.blk_off(img, oc, d, j)
                                      : diff_dst_d.blk_off(img, oc, j));
            const diff_dst_data_t *diff_dst = &ti->diff_dst[ddst_offset];

            if (is_ddst_layout_nxc) {
                dim_t sp_start_offset = (jcp.ndims == 5)
                        ? diff_dst_d.blk_off(0, 0, d, j)
                        : diff_dst_d.blk_off(0, 0, j);
                dim_t ch_shift = diff_dst_d.blk_off(0, jcp.oc_block);
                int sp_start_idx = d * jcp.oh + j;
                trans_dst_nxc(tr_diff_dst, diff_dst, sp_start_idx,
                        sp_start_offset, ocb_start, ch_shift, my_work);
            } else
                trans_dst(tr_diff_dst, diff_dst, my_work);
        }
    };

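    // Main accumulation loop over the images assigned to this thread. With
    // global_transpose the transpose buffers are filled cooperatively behind
    // barriers once per image; otherwise each block is repacked on demand.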
    for (int img = ti->img_start; img < ti->img_end; ++img) {
        auto p = jit_conv_call_s();
        if (jcp.global_transpose) {
            using simple_barrier::barrier;
            // TODO: try to call local transpositions just before jit kernel
            /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */
            if (jcp.transpose_src) {
                if (nthr_oc_b_ > 1)
                    barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
                uker_trans(img);
                if (nthr_oc_b_ > 1)
                    barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
            }
            if (jcp.transpose_dst) {
                if (nthr_ic_b_ > 1)
                    barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
                diff_dst_trans(img);
                if (nthr_ic_b_ > 1)
                    barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
            }
        }
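        // With permw transposition several ic blocks are fed to the kernel at
        // once; an awkward remainder (between one and two steps) is split
        // evenly.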
        int ic_b_step
                = jcp.uses_permw_transposition ? jcp.nb_ic_blocking_max : 1;
        int icb_work = ti->ic_b_end - ti->ic_b_start;
        if (ic_b_step > 1 && icb_work > ic_b_step && icb_work < 2 * ic_b_step)
            ic_b_step = utils::div_up(icb_work, 2);
        for_(int g = ti->g_start; g < ti->g_end; ++g)
        for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b)
        for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end;
                ic_b += ic_b_step) {
            const int ic_off_idx = is_src_layout_nxc
                    ? g * jcp.ic + ic_b * jcp.ic_block
                    : g * jcp.nb_ic + ic_b;
            const int oc_off_idx = is_ddst_layout_nxc
                    ? g * jcp.oc + oc_b * jcp.oc_block
                    : g * jcp.nb_oc + oc_b;
            const int ic_to_compute = this_block_size(
                    ic_b * jcp.ic_block, jcp.ic, ic_b_step * jcp.ic_block);
            const int oc_to_compute = this_block_size(
                    oc_b * jcp.oc_block, jcp.oc, jcp.oc_block);
            if (jcp.transpose_src) {
                if (!jcp.global_transpose) {
                    uker_trans(img, g, ic_b);
                    if (jcp.ndims == 5) {
                        p.src = &ti->tr_src[tr_src_off_3d(g, ic_b, 0, 0)];
                    } else {
                        p.src = &ti->tr_src[tr_src_off(g, ic_b, 0)];
                    }
                } else {
                    if (jcp.ndims == 5) {
                        p.src = &ti->tr_src[tr_src_off_3d(g, ic_b, 0, 0)];
                    } else {
                        p.src = &ti->tr_src[tr_src_off(g, ic_b, 0)];
                    }
                }
            } else {
                p.src = &ti->src[src_d.blk_off(img, ic_off_idx)];
            }

            if (jcp.transpose_dst) {
                if (!jcp.global_transpose) {
                    diff_dst_trans(img, g, oc_b);
                    if (jcp.ndims == 5) {
                        p.dst = &ti->tr_diff_dst[tr_diff_dst_off_3d(
                                0, 0, 0, 0)];
                    } else {
                        p.dst = &ti->tr_diff_dst[tr_diff_dst_off(0, 0, 0)];
                    }
                } else {
                    if (jcp.ndims == 5) {
                        p.dst = &ti->tr_diff_dst[tr_diff_dst_off_3d(
                                g, oc_b, 0, 0)];
                    } else {
                        p.dst = &ti->tr_diff_dst[tr_diff_dst_off(g, oc_b, 0)];
                    }
                }
            } else {
                p.dst = &ti->diff_dst[diff_dst_d.blk_off(img, oc_off_idx)];
            }

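            // p.channel is set for the first image handled by this thread so
            // the kernel can initialize rather than accumulate; FLAG_IC_FIRST
            // marks the first input-channel block (ic_b == 0).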
            p.filt = diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b);
            p.bias = diff_bias ? diff_bias + g * rnd_up(jcp.oc, jcp.oc_block)
                            + oc_b * jcp.oc_block
                               : nullptr;
            p.channel = (img == ti->img_start);
            p.flags = 0 | (ic_b == 0 ? FLAG_IC_FIRST : 0);
            p.reduce_work = ic_to_compute;
            p.load_work = oc_to_compute;
            (*kernel_)(&p);
        }
    }
}

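// Reduce the per-thread accumulation buffers into the user diff_weights /
// diff_bias memory, converting f32 accumulators to bf16 on the last step when
// the destination data type is bf16.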
void jit_avx512_core_bf16_convolution_bwd_weights_t ::
        reduce_and_convert_diff_weights_and_bias(
                const thread_info_t *ti) const {
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;
    const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
            * jcp.ic_block * jcp.kh * jcp.kw * ((jcp.ndims == 5) ? jcp.kd : 1);

    const bool is_bf16_out = diff_weights_d.data_type() == data_type::bf16;
    const bool is_bf16_bias = jcp.with_bias && jcp.bia_dt == data_type::bf16;
    if (nthr_mb_ == 1) {
        if (is_bf16_out) {
            // reduction is not required, only conversion
            for_(int g = ti->g_start; g < ti->g_end; g++)
            for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; oc_b++) {
                const size_t acc_size = (size_t)ti->ic_b_work * jcp.kh * jcp.kw
                        * ((jcp.ndims == 5) ? jcp.kd : 1) * jcp.ic_block
                        * jcp.oc_block;
                const size_t off
                        = wht_blk_off(diff_weights_d, g, oc_b, ti->ic_b_start);
                cvt_float_to_bfloat16((bfloat16_t *)(ti->diff_weights) + off,
                        (ti->wei_bia_reduction + off), acc_size);
            }
        }

        if (is_bf16_bias && ti->ithr_ic_b == 0 && ti->ic_b_work > 0) {
            for (int g = ti->g_start; g < ti->g_end; g++) {
                int result_start_idx = g * jcp.oc_without_padding
                        + ti->oc_b_start * jcp.oc_block;
                int buffer_start_idx = g * rnd_up(jcp.oc, jcp.oc_block)
                        + ti->oc_b_start * jcp.oc_block;
                const size_t acc_size = nstl::min(jcp.oc_without_padding,
                                                ti->oc_b_end * jcp.oc_block)
                        - ti->oc_b_start * jcp.oc_block;
                bfloat16_t *diff_bias
                        = (bfloat16_t *)ti->diff_bias + result_start_idx;
                float *buffer = ti->bia_reduction + buffer_start_idx;
                cvt_float_to_bfloat16(diff_bias, buffer, acc_size);
            }
        }
        return;
    }

    /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
    if (jcp.global_transpose)
        simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);

    const int ic_b_kh_work
            = ti->ic_b_work * ((jcp.ndims == 5) ? jcp.kd : jcp.kh);
    const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;

    int start {0}, end {0};
    balance211(work, nthr_mb_, ti->ithr_mb, start, end);
    if (start == end) return;

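    // Accumulate every other thread's copy into the base copy (the reduction
    // buffer for bf16 output, the user diff_weights for f32 output); for bf16
    // the very last addition converts and stores straight into diff_weights.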
    const int _start_nthr_mb = 1;
    for (int thr_mb = _start_nthr_mb; thr_mb < nthr_mb_; ++thr_mb) {
        int w = start;
        int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_kh_start {0};
        nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
                ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        while (w < end) {
            const int g = ti->g_start + sub_g_start;
            const int oc_b = ti->oc_b_start + sub_oc_b_start;
            const int ic_b = ti->ic_b_start
                    + sub_ic_b_kh_start / ((jcp.ndims == 5) ? jcp.kd : jcp.kh);
            const int kX
                    = sub_ic_b_kh_start % ((jcp.ndims == 5) ? jcp.kd : jcp.kh);

            const size_t acc_size = (size_t)jcp.kw * jcp.ic_block * jcp.oc_block
                    * ((jcp.ndims == 5) ? jcp.kh : 1)
                    * nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start);

            const size_t off = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kX);

            float *wei_reduced = is_bf16_out
                    ? ti->wei_bia_reduction + off
                    : (float *)(ti->diff_weights) + off;

            int thr_mb_buffer_idx = is_bf16_out ? thr_mb : thr_mb - 1;
            float *wei_to_reduce = ti->wei_bia_reduction
                    + thr_mb_buffer_idx * wei_size + off;

            if (is_bf16_out && thr_mb == nthr_mb_ - 1)
                // the last iteration for bfloat16 requires conversion and
                // store to diff_weights array
                add_floats_and_cvt_to_bfloat16(
                        (bfloat16_t *)(ti->diff_weights) + off, wei_reduced,
                        wei_to_reduce, acc_size);
            else
                acc_ker_->accumulate(wei_reduced, wei_to_reduce, acc_size);

            nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
                    ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        }
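        // diff_bias is reduced only by the thread with ithr_ic_b == 0 and
        // ithr_mb == 0 (and non-empty work) for its oc range.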
        if (jcp.with_bias && ti->ithr_ic_b == 0 && ti->ic_b_work > 0
                && ti->ithr_mb == 0 && ti->img_work > 0) {
            for (int g = ti->g_start; g < ti->g_end; g++) {
                float *bias_reduced = is_bf16_bias ? ti->bia_reduction
                                                   : (float *)(ti->diff_bias);
                int thr_mb_buffer_idx = is_bf16_bias ? thr_mb : thr_mb - 1;
                int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
                float *bias_to_reduce
                        = ti->bia_reduction + thr_mb_buffer_idx * bias_buf_size;
                const size_t acc_size = nstl::min(jcp.oc_without_padding,
                                                ti->oc_b_end * jcp.oc_block)
                        - ti->oc_b_start * jcp.oc_block;
                int idx = g * rnd_up(jcp.oc, jcp.oc_block)
                        + ti->oc_b_start * jcp.oc_block;
                if (is_bf16_bias && thr_mb == nthr_mb_ - 1) {
                    // the last iteration for bfloat16 requires conversion and
                    // store to diff_bias array
                    int diff_bias_idx = g * jcp.oc_without_padding
                            + ti->oc_b_start * jcp.oc_block;
                    add_floats_and_cvt_to_bfloat16(
                            (bfloat16_t *)(ti->diff_bias) + diff_bias_idx,
                            &bias_reduced[idx], &bias_to_reduce[idx], acc_size);
                } else {
                    acc_ker_->accumulate(
                            &bias_reduced[idx], &bias_to_reduce[idx], acc_size);
                }
            }
        }
    }
}

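// Initialize scratchpad data for the backward-weights pass: guard elements of
// the tr_src buffers and the barrier contexts used for global transposition
// and for the weights/bias reduction.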
void jit_avx512_core_bf16_convolution_bwd_weights_t::prepare_scratchpad_data(
        const exec_ctx_t &ctx) const {
    auto scratchpad = ctx.get_scratchpad_grantor();

    const auto &jcp = pd()->jcp_;

    if (jcp.transpose_src) {
        // XXX: See the comment about tr_iw and guarding elements in
        // jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf()
        auto tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
        // Zero out guard elements that cross a buffer boundary to prevent a
        // race condition caused by buffer overflows from the memory
        // optimization that lets adjacent buffers share padding.
        for (size_t ithr = 1; ithr <= jcp.tr_src_buf_count; ++ithr) {
            src_data_t *ts = &tr_src[ithr * jcp.tr_src_buf_size];
            for (int i = 0; i < jcp.tr_src_num_guard_elems; ++i)
                ts[i] = 0;
        }

        if (jcp.global_transpose && jcp.nthr_oc_b > 1) {
            const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
            auto tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                    key_conv_tr_src_bctx);
            for (int i = 0; i < tr_src_bctx_size; ++i)
                simple_barrier::ctx_init(&tr_src_bctx[i]);
        }
    }
    if (jcp.global_transpose && jcp.transpose_dst) {
        if (jcp.nthr_ic_b > 1) {
            const int tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b;
            auto tr_diff_dst_bctx
                    = scratchpad.template get<simple_barrier::ctx_t>(
                            key_conv_tr_diff_dst_bctx);
            for (int i = 0; i < tr_diff_dst_bctx_size; ++i)
                simple_barrier::ctx_init(&tr_diff_dst_bctx[i]);
        }
    }

    if (jcp.global_transpose
            && (nthr_mb_ > 1
                    || pd()->diff_weights_md(0)->data_type
                            == data_type::bf16)) {
        // TODO: don't use barrier for case
        // diff_weights_type == data_type::bf16 && nthr_mb_ == 1
        simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_wei_bia_reduction_bctx));
    }
}

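// Top-level driver: every thread computes its partial diff_weights/diff_bias
// according to the selected harness; partial results are then reduced and,
// if needed, converted to bf16.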
void jit_avx512_core_bf16_convolution_bwd_weights_t ::execute_backward_weights(
        const exec_ctx_t &ctx) const {
    prepare_scratchpad_data(ctx);

    const auto &jcp = pd()->jcp_;
    parallel(nthr_, [&](const int ithr, const int nthr) {
        assert(nthr_ == nthr);
        assert(utils::one_of(pd()->ndims(), 3, 4, 5));

        thread_info_t thread_info(this, ctx, ithr);
        switch (jcp.harness) {
            case harness_2d_reduction:
                compute_diff_weights_2d(&thread_info);
                if (jcp.global_transpose)
                    reduce_and_convert_diff_weights_and_bias(&thread_info);
                break;
            case harness_3d_reduction:
                compute_diff_weights_3d(&thread_info);
                if (jcp.global_transpose)
                    reduce_and_convert_diff_weights_and_bias(&thread_info);
                break;
            case harness_compute_full_spatial:
            case harness_mb_reduction:
                compute_diff_weights(&thread_info);
                if (jcp.global_transpose)
                    reduce_and_convert_diff_weights_and_bias(&thread_info);
                break;
            default: assert(!"Invalid harness type");
        }
    });

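    // Without global transposition the reduction runs in a separate parallel
    // region, after all partial results have been produced.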
    if (!jcp.global_transpose) {
        parallel(nthr_, [&](const int ithr, const int nthr) {
            assert(nthr_ == nthr);
            thread_info_t thread_info(this, ctx, ithr);
            reduce_and_convert_diff_weights_and_bias(&thread_info);
        });
    }

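    // For an f32 bias with padded OC, copy diff_bias from the padded
    // scratchpad layout back to the dense user layout.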
    if (pd()->with_bias() && (jcp.oc_without_padding % jcp.oc_block != 0)
            && jcp.bia_dt != data_type::bf16) {
        auto diff_bias = ctx.get_scratchpad_grantor().template get<const float>(
                key_conv_padded_bias);
        auto diff_bias_in = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS);
        const int padded_stride = rnd_up(jcp.oc, jcp.oc_block);
        const int stride = jcp.oc_without_padding;
        for (int g = 0; g < jcp.ngroups; ++g) {
            utils::array_copy(diff_bias_in + g * stride,
                    diff_bias + g * padded_stride, stride);
        }
    }
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s