1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "oneapi/dnnl/dnnl_types.h" |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/dnnl_thread.hpp" |
21 | #include "common/type_helpers.hpp" |
22 | #include "common/utils.hpp" |
23 | #include "cpu/x64/jit_avx512_core_bf16_1x1_convolution.hpp" |
24 | |
25 | #include "cpu/x64/jit_generator.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
32 | using namespace dnnl::impl::status; |
33 | using namespace dnnl::impl::memory_tracking::names; |
34 | using namespace dnnl::impl::utils; |
35 | using namespace dnnl::impl::prop_kind; |
36 | |
// Selects the blk_off() overload of memory descriptor wrapper `f` according
// to the tensor rank. An integer named `ndims` must be visible wherever the
// macro is expanded:
//   ndims == 3     -> f.blk_off(n, c, w)
//   ndims == 4     -> f.blk_off(n, c, h, w)
//   anything else  -> f.blk_off(n, c, d, h, w)
#define data_blk_off(f, n, c, d, h, w) \
    ((ndims != 3) ? ((ndims != 4) ? (f).blk_off(n, c, d, h, w) \
                                  : (f).blk_off(n, c, h, w)) \
                  : (f).blk_off(n, c, w))
41 | |
42 | namespace { |
43 | /*TODO: investigate why common balance2D defined in common/dnnl_thread.hpp |
44 | * not used here ?*/ |
45 | template <typename T, typename U> |
46 | void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, T nx, T &nx_start, |
47 | T &nx_end, T nx_divider) { |
48 | const T grp_size = utils::div_up(nthr, nx_divider); |
49 | const T grp_count = utils::div_up(nthr, grp_size); |
50 | |
51 | T grp = ithr / grp_size; |
52 | T grp_ithr = ithr % grp_size; |
53 | T grp_nthr = grp_size; |
54 | T first_grps = nthr % grp_count; |
55 | if (first_grps > 0 && grp >= first_grps) { |
56 | ithr -= first_grps * grp_size; |
57 | grp_nthr--; |
58 | grp = ithr / grp_nthr + first_grps; |
59 | grp_ithr = ithr % grp_nthr; |
60 | } |
61 | balance211(nx, grp_count, grp, nx_start, nx_end); |
62 | balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end); |
63 | } |
64 | } // namespace |
65 | |
66 | /* convolution forward */ |
67 | template <data_type_t dst_type> |
68 | void jit_avx512_core_bf16_1x1_convolution_fwd_t<dst_type>::execute_forward( |
69 | const exec_ctx_t &ctx) const { |
70 | auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); |
71 | auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); |
72 | auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
73 | auto dst = CTX_OUT_MEM(const char *, DNNL_ARG_DST); |
74 | auto weights_dw = CTX_IN_MEM( |
75 | const dw_wei_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS); |
76 | const auto post_ops_binary_rhs_arg_vec |
77 | = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx); |
78 | const auto post_ops_binary_rhs_arg_vec_dw = pd()->jcp_dw_ != nullptr |
79 | ? binary_injector::prepare_binary_args(pd()->jcp_dw_->post_ops, ctx, |
80 | pd()->jcp_.post_ops.entry_.size() + 1) |
81 | : std::vector<const void *> {}; |
82 | |
83 | auto scratchpad = ctx.get_scratchpad_grantor(); |
84 | |
85 | const auto &jcp = kernel_->jcp; |
86 | if (pd()->wants_padded_bias()) { |
87 | const size_t bia_dt_size = pd()->jcp_.typesize_bia; |
88 | auto padded_bias = scratchpad.template get<char>(key_conv_padded_bias); |
89 | utils::array_copy( |
90 | padded_bias, bias, bia_dt_size * jcp.oc_without_padding); |
91 | utils::array_set(padded_bias + bia_dt_size * jcp.oc_without_padding, |
92 | 0.f, bia_dt_size * (jcp.oc - jcp.oc_without_padding)); |
93 | bias = padded_bias; |
94 | } |
95 | |
96 | float *bias_dw = nullptr; |
97 | if (pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)->data_type |
98 | == data_type::bf16) { |
99 | auto jcp_dw = pd()->jcp_dw_; |
100 | memory_tracking::grantor_t dw_scratchpad( |
101 | scratchpad, memory_tracking::names::prefix_fusion); |
102 | auto bias_in = CTX_IN_MEM( |
103 | const src_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS); |
104 | bias_dw = dw_scratchpad.template get<float>( |
105 | key_conv_bias_bf16_convert_wsp); |
106 | cvt_bfloat16_to_float(bias_dw, bias_in, jcp_dw->oc_without_padding); |
107 | utils::array_set(bias_dw + jcp_dw->oc_without_padding, 0.f, |
108 | jcp_dw->oc - jcp_dw->oc_without_padding); |
109 | } else { |
110 | auto bias_in = CTX_IN_MEM( |
111 | const float *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS); |
112 | bias_dw = const_cast<float *>(bias_in); |
113 | } |
114 | |
115 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
116 | execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw, |
117 | dst, scratchpad, post_ops_binary_rhs_arg_vec.data(), |
118 | post_ops_binary_rhs_arg_vec_dw.data()); |
119 | }); |
120 | |
121 | if (pd()->wants_zero_pad_dst()) ctx.zero_pad_output(DNNL_ARG_DST); |
122 | } |
123 | |
124 | template <data_type_t dst_type> |
125 | void jit_avx512_core_bf16_1x1_convolution_fwd_t<dst_type>::execute_forward_thr( |
126 | const int ithr, const int nthr, const src_data_t *src, |
127 | const wei_data_t *weights, const char *bias, |
128 | const dw_wei_data_t *weights_dw, const float *bias_dw, const char *dst, |
129 | const memory_tracking::grantor_t &scratchpad, |
130 | const void *post_ops_binary_rhs_arg_vec, |
131 | const void *post_ops_binary_rhs_arg_vec_dw) const { |
132 | const memory_desc_wrapper src_d(pd()->src_md()); |
133 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
134 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
135 | const memory_desc_wrapper dw_weights_d( |
136 | pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS)); |
137 | const memory_desc_wrapper dw_bias_d( |
138 | pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)); |
139 | |
140 | const auto &jcp = kernel_->jcp; |
141 | auto rtus_space = pd()->rtus_.reduce_src_ |
142 | ? scratchpad.get<src_data_t>(key_conv_rtus_space) |
143 | : nullptr; |
144 | float *store_buffer = scratchpad.template get<float>(key_conv_store_wsp); |
145 | |
146 | const int ndims = src_d.ndims(); |
147 | const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1; |
148 | const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4]; |
149 | const int stride_w = pd()->desc()->strides[ndims - 3]; |
150 | |
151 | auto p = jit_1x1_conv_call_s(); |
152 | |
153 | auto rp = rtus_driver_t<avx512_core>::call_params_t(); |
154 | |
155 | const int nb_oc = jcp.nb_load; |
156 | const int nb_ic = jcp.nb_reduce; |
157 | const int nb_ic_blocking = jcp.nb_reduce_blocking; |
158 | |
159 | // override some constants for fused dw_conv |
160 | const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block; |
161 | const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast; |
162 | const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking; |
163 | const int nb_bcast_blocking_max |
164 | = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max; |
165 | const int nb_load_blocking = jcp.nb_load_blocking; |
166 | const int nb_load_blocking_max = jcp.with_dw_conv |
167 | ? jcp.nb_load_blocking |
168 | : jcp.nb_load_blocking_max; |
169 | |
170 | // Begin: declare Variables needed for dw conv. |
171 | dst_data_t *pbuf; //bf16->bf16 fusion |
172 | size_t row_offset; |
173 | const auto jcp_dw = pd()->jcp_dw_; |
174 | const int nb_buffer = jcp.nb_load_blocking; |
175 | std::vector<decltype(pbuf)> addrs; |
176 | const bool is_dst_layout_nxc = utils::one_of( |
177 | jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
178 | const bool is_src_layout_nxc = utils::one_of( |
179 | jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
180 | |
181 | auto step = [](int default_step, int remaining, int tail_step) { |
182 | assert(default_step <= tail_step); |
183 | return remaining < tail_step ? remaining : default_step; |
184 | }; |
185 | |
186 | auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g, |
187 | int &bcast_step, int &od, int &oh, int &ow, |
188 | int &id, int &ih, int &iw) { |
189 | int osb {0}; |
190 | nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast); |
191 | bcast_step = step( |
192 | nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max); |
193 | bcast_step = nstl::min(bcast_step, bcast_end - iwork); |
194 | |
195 | const int os = osb * os_block; |
196 | od = os / (jcp.oh * jcp.ow); |
197 | int os_2d = os % (jcp.oh * jcp.ow); |
198 | oh = os_2d / jcp.ow; |
199 | ow = os_2d % jcp.ow; |
200 | |
201 | id = od * stride_d; |
202 | ih = oh * stride_h; |
203 | iw = ow * stride_w; |
204 | rp.iw_start = iw; |
205 | |
206 | p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block); |
207 | rp.os = p.bcast_dim; |
208 | }; |
209 | |
210 | auto init_load = [&](int ocb, int ocb_end, int &load_step) { |
211 | load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max); |
212 | const auto max_oc |
213 | = nstl::min(ocb_end * jcp.oc_block, jcp.oc_without_padding); |
214 | p.load_dim = this_block_size( |
215 | ocb * jcp.oc_block, max_oc, load_step * jcp.oc_block); |
216 | }; |
217 | |
218 | auto init_reduce = [&](int icb) { |
219 | const int nb_ic_blocking_step |
220 | = nstl::min(icb + nb_ic_blocking, nb_ic) - icb; |
221 | p.first_last_flag = 0 | (icb == 0 ? FLAG_REDUCE_FIRST : 0) |
222 | | (icb + nb_ic_blocking_step >= nb_ic ? FLAG_REDUCE_LAST : 0); |
223 | |
224 | p.reduce_dim = this_block_size( |
225 | icb * jcp.ic_block, jcp.ic, nb_ic_blocking_step * jcp.ic_block); |
226 | rp.icb = p.reduce_dim; |
227 | }; |
228 | |
229 | auto ker_1x1 = [&](int ocb, int ocb_start, int icb, int n, int g, int od, |
230 | int oh, int ow, int id, int ih, int iw) { |
231 | const int oc_off_idx = is_dst_layout_nxc |
232 | ? g * jcp.oc + ocb * jcp.oc_block |
233 | : g * nb_oc + ocb; |
234 | const size_t dst_off = data_blk_off(dst_d, n, oc_off_idx, od, oh, ow); |
235 | |
236 | void *output_data = jcp.with_dw_conv |
237 | ? (void *)(pbuf + (oh % jcp_dw->kh) * row_offset) |
238 | : (void *)(&dst[dst_off * dst_d.data_type_size()]); |
239 | p.output_data = output_data; |
240 | |
241 | p.bias_data = &bias[jcp.typesize_bia * oc_off_idx |
242 | * (is_dst_layout_nxc ? 1 : jcp.oc_block)]; |
243 | p.load_data |
244 | = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb) |
245 | : weights_d.blk_off(ocb, icb)]; |
246 | |
247 | const int ic_off_idx = is_src_layout_nxc |
248 | ? g * jcp.ic + icb * jcp.ic_block |
249 | : g * nb_ic + icb; |
250 | if (pd()->rtus_.reduce_src_) { |
251 | rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_ |
252 | + (is_src_layout_nxc ? ic_off_idx |
253 | : jcp.is * ic_off_idx * jcp.ic_block); |
254 | if (ocb == ocb_start) { |
255 | rp.src = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw); |
256 | (*rtus_driver_)(&rp); |
257 | } |
258 | p.bcast_data = rp.ws; |
259 | } else |
260 | p.bcast_data = src + data_blk_off(src_d, n, ic_off_idx, id, ih, iw); |
261 | |
262 | const size_t grp_count = utils::div_up( |
263 | jcp.nthr, utils::div_up(jcp.nthr, jcp.load_grp_count)); |
264 | const size_t max_load_per_thread = is_dst_layout_nxc |
265 | ? jcp.load_dim |
266 | : rnd_up((jcp.load_dim / grp_count), jcp.load_block); |
267 | const size_t str_size = jcp.bcast_dim * max_load_per_thread; |
268 | p.store_buffer = store_buffer + ithr * str_size |
269 | + data_blk_off(dst_d, 0, 0, od, oh, ow); |
270 | |
271 | p.dst_l_off = dst_off; |
272 | p.oc_l_off = oc_off_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block); |
273 | p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec; |
274 | p.dst_orig = dst; |
275 | |
276 | (*kernel_)(&p); |
277 | }; |
278 | |
279 | auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start, |
280 | int ocb_end) { |
281 | if (bcast_start >= bcast_end || ocb_start >= ocb_end) return; |
282 | if (jcp.loop_order == loop_lbr) { |
283 | int ocb = ocb_start; |
284 | while (ocb < ocb_end) { |
285 | int load_step; |
286 | init_load(ocb, ocb_end, load_step); |
287 | int iwork = bcast_start; |
288 | while (iwork < bcast_end) { |
289 | int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0}, |
290 | id {0}, ih {0}, iw {0}; |
291 | init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, |
292 | id, ih, iw); |
293 | for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { |
294 | init_reduce(icb); |
295 | ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih, |
296 | iw); |
297 | } |
298 | iwork += bcast_step; |
299 | } |
300 | ocb += load_step; |
301 | } |
302 | } else if (jcp.loop_order == loop_blr) { |
303 | int iwork = bcast_start; |
304 | while (iwork < bcast_end) { |
305 | int n {0}, g {0}, bcast_step {0}, od {0}, oh {0}, ow {0}, |
306 | id {0}, ih {0}, iw {0}; |
307 | init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id, |
308 | ih, iw); |
309 | int ocb = ocb_start; |
310 | while (ocb < ocb_end) { |
311 | int load_step; |
312 | init_load(ocb, ocb_end, load_step); |
313 | for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) { |
314 | init_reduce(icb); |
315 | ker_1x1(ocb, ocb_start, icb, n, g, od, oh, ow, id, ih, |
316 | iw); |
317 | } |
318 | ocb += load_step; |
319 | } |
320 | iwork += bcast_step; |
321 | } |
322 | } else { |
323 | assert(!"unsupported loop order" ); |
324 | } |
325 | }; |
326 | |
327 | auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) { |
328 | int oh_1x1 = nstl::max(dw_oh * jcp_dw->stride_h - jcp_dw->t_pad, 0); |
329 | |
330 | for (int i = 0; i < jcp_dw->kh; ++i) |
331 | addrs[i] = pbuf + ((oh_1x1++) % jcp_dw->kh) * row_offset; |
332 | |
333 | const auto ocb_end = ocb_start + load_step; |
334 | const auto wch_stride = (is_src_layout_nxc ? 1 : jcp_dw->iw) |
335 | * jcp_dw->nb_ch_blocking * jcp_dw->ch_block; |
336 | |
337 | const int dil_h = jcp_dw->dilate_h + 1; |
338 | const int str_h = jcp_dw->stride_h; |
339 | const int ch_num = jcp_dw->nb_ch_blocking; |
340 | |
341 | for (int ch = ocb_start; ch < ocb_end; ch += jcp_dw->nb_ch_blocking) { |
342 | |
343 | const int i_t_overflow |
344 | = nstl::max(0, (int)(jcp_dw->t_pad - dw_oh * str_h)); |
345 | const int i_b_overflow |
346 | = nstl::max(jcp_dw->ih, |
347 | (int)(dw_oh * str_h + (jcp_dw->kh - 1) * dil_h |
348 | - jcp_dw->t_pad + 1)) |
349 | - jcp_dw->ih; |
350 | |
351 | const int kh = div_up(i_t_overflow, dil_h); |
352 | const int kh_padding = jcp_dw->kh - div_up(i_t_overflow, dil_h) |
353 | - div_up(i_b_overflow, dil_h); |
354 | |
355 | const int ow = 0; |
356 | const int kw = 0; |
357 | jit_conv_call_s par_conv_dw; |
358 | |
359 | par_conv_dw.src = addrs.data(); |
360 | |
361 | const size_t ch_step = is_dst_layout_nxc |
362 | ? jcp_dw->ch_block |
363 | : dst_d.blk_off(0, 1, 0, 0); |
364 | par_conv_dw.dst |
365 | = &dst[(dst_d.blk_off(n, 0, dw_oh, ow) + ch * ch_step) |
366 | * dst_d.data_type_size()]; |
367 | |
368 | par_conv_dw.filt |
369 | = &weights_dw[dw_weights_d.blk_off(ch, 0, 0, kh, kw)]; |
370 | if (bias) |
371 | par_conv_dw.bias |
372 | = &bias_dw[dw_bias_d.blk_off(ch * jcp_dw->ch_block)]; |
373 | |
374 | par_conv_dw.kh_padding = (size_t)nstl::max(0, kh_padding); |
375 | |
376 | par_conv_dw.load_work = (nstl::min(ch + ch_num, jcp_dw->nb_ch) - ch) |
377 | * jcp_dw->ch_block; |
378 | |
379 | par_conv_dw.oc_l_off = ch * jcp_dw->ch_block; |
380 | par_conv_dw.post_ops_binary_rhs_arg_vec |
381 | = post_ops_binary_rhs_arg_vec_dw; |
382 | par_conv_dw.dst_orig = dst; |
383 | |
384 | (*kernel_dw_)(&par_conv_dw); |
385 | |
386 | for (int i = 0; i < jcp_dw->kh; ++i) |
387 | addrs[i] += wch_stride; |
388 | } |
389 | }; |
390 | |
391 | auto conv_dw = [&]() { |
392 | // Set variables |
393 | memory_tracking::grantor_t dw_scratchpad( |
394 | scratchpad, memory_tracking::names::prefix_fusion); |
395 | const auto dw_conv_buffer |
396 | = dw_scratchpad.get<dst_data_t>(key_fusion_inout_buffer); |
397 | |
398 | const auto dw_conv_buffer_size_ |
399 | = jcp_dw->kh * jcp.ow * nb_buffer * jcp.oc_block; |
400 | pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_; |
401 | row_offset = dw_conv_buffer_size_ / jcp_dw->kh; |
402 | addrs.resize(jcp_dw->kh); |
403 | |
404 | int bcast_start {0}, bcast_end {0}, ocb_start, ocb_end; |
405 | balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw->oh, bcast_start, |
406 | bcast_end, nb_oc, ocb_start, ocb_end, jcp.load_grp_count); |
407 | |
408 | while (ocb_start < ocb_end) { |
409 | int load_step; |
410 | init_load(ocb_start, ocb_end, load_step); |
411 | |
412 | int oh_1x1 = 0; |
413 | auto bcast_iter = bcast_start; |
414 | while (bcast_iter < bcast_end) { |
415 | int n {0}, g {0}, oh_dw {0}; |
416 | nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw, |
417 | jcp_dw->oh); |
418 | if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary |
419 | const int oh_1x1_range |
420 | = oh_dw * jcp_dw->stride_h - jcp_dw->t_pad; |
421 | const int oh_1x1_begin = nstl::max(oh_1x1_range, 0); |
422 | const int oh_1x1_end |
423 | = nstl::min(oh_1x1_range + jcp_dw->kh, jcp.oh); |
424 | oh_1x1 = nstl::max( |
425 | oh_1x1_begin, oh_1x1); // Skip rows computed previously |
426 | |
427 | // dw_spatial to 1x1 spatial conversion. if jcp.oh != jcp_dw->oh |
428 | const int bcast_start_1x1 |
429 | = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1; |
430 | const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end; |
431 | |
432 | conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start, |
433 | ocb_start + load_step); |
434 | oh_1x1 = oh_1x1_end; |
435 | ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw); |
436 | |
437 | bcast_iter += nb_bcast_blocking; |
438 | } |
439 | ocb_start += load_step; |
440 | } |
441 | }; |
442 | |
443 | if (jcp.with_dw_conv) { |
444 | conv_dw(); |
445 | } else { |
446 | const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; |
447 | int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0}; |
448 | balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, jcp.nb_load, |
449 | ocb_start, ocb_end, jcp.load_grp_count); |
450 | |
451 | conv_1x1(bcast_start, bcast_end, ocb_start, ocb_end); |
452 | } |
453 | } |
454 | |
// Explicit instantiations of the forward primitive for the f32 and bf16
// destination data types.
// NOTE(review): REG_AVX512_ISA is defined outside this file; presumably it
// discards its argument when AVX-512 support is not built in -- confirm.
REG_AVX512_ISA(template struct jit_avx512_core_bf16_1x1_convolution_fwd_t<
        data_type::f32>);
REG_AVX512_ISA(template struct jit_avx512_core_bf16_1x1_convolution_fwd_t<
        data_type::bf16>);
459 | |
460 | template <data_type_t diff_src_type> |
461 | void jit_avx512_core_bf16_1x1_convolution_bwd_data_t< |
462 | diff_src_type>::execute_backward_data(const exec_ctx_t &ctx) const { |
463 | auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST); |
464 | auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); |
465 | auto diff_src = CTX_OUT_MEM(diff_src_data_t *, DNNL_ARG_DIFF_SRC); |
466 | auto scratchpad = ctx.get_scratchpad_grantor(); |
467 | const auto &jcp = kernel_->jcp; |
468 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
469 | assert(nthr == jcp.nthr); |
470 | execute_backward_data_thr( |
471 | ithr, nthr, diff_dst, weights, diff_src, scratchpad); |
472 | }); |
473 | } |
474 | |
475 | template <data_type_t diff_src_type> |
476 | void jit_avx512_core_bf16_1x1_convolution_bwd_data_t< |
477 | diff_src_type>::execute_backward_data_thr(const int ithr, |
478 | const int nthr, const diff_dst_data_t *diff_dst, |
479 | const wei_data_t *weights, diff_src_data_t *diff_src, |
480 | const memory_tracking::grantor_t &scratchpad) const { |
481 | |
482 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
483 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
484 | const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); |
485 | |
486 | const auto &jcp = kernel_->jcp; |
487 | |
488 | auto rtus_space = pd()->rtus_.reduce_src_ |
489 | ? scratchpad.template get<diff_src_data_t>(key_conv_rtus_space) |
490 | : nullptr; |
491 | float *store_buffer = scratchpad.template get<float>(key_conv_store_wsp); |
492 | const int ndims = diff_src_d.ndims(); |
493 | const int stride_d = (ndims == 5) ? pd()->desc()->strides[0] : 1; |
494 | const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[ndims - 4]; |
495 | const int stride_w = pd()->desc()->strides[ndims - 3]; |
496 | |
497 | const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast; |
498 | |
499 | auto step = [](int default_step, int remaining, int tail_step) { |
500 | assert(default_step <= tail_step); |
501 | return remaining < tail_step ? remaining : default_step; |
502 | }; |
503 | |
504 | auto p = jit_1x1_conv_call_s(); |
505 | |
506 | auto rp = rtus_driver_t<avx512_core>::call_params_t(); |
507 | const int nb_ic = jcp.nb_load; |
508 | const int nb_oc = jcp.nb_reduce; |
509 | const int os_block = jcp.bcast_block; |
510 | const int nb_oc_blocking = jcp.nb_reduce_blocking; |
511 | |
512 | int bcast_start {0}, bcast_end {0}, icb_start {0}, icb_end {0}; |
513 | balance2D(nthr, ithr, work_amount, bcast_start, bcast_end, jcp.nb_load, |
514 | icb_start, icb_end, jcp.load_grp_count); |
515 | |
516 | auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step, int &od, |
517 | int &oh, int &ow, int &id, int &ih, int &iw) { |
518 | int osb {0}; |
519 | nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, jcp.nb_bcast); |
520 | bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb, |
521 | jcp.nb_bcast_blocking_max); |
522 | bcast_step = nstl::min(bcast_step, bcast_end - iwork); |
523 | |
524 | const int os = osb * os_block; |
525 | od = os / (jcp.oh * jcp.ow); |
526 | const int os_2d = os % (jcp.oh * jcp.ow); |
527 | oh = os_2d / jcp.ow; |
528 | ow = os_2d % jcp.ow; |
529 | id = od * stride_d; |
530 | ih = oh * stride_h; |
531 | iw = ow * stride_w; |
532 | rp.iw_start = iw; |
533 | |
534 | p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block); |
535 | rp.os = p.bcast_dim; |
536 | }; |
537 | |
538 | auto init_load = [&](int icb, int &load_step) { |
539 | load_step = step( |
540 | jcp.nb_load_blocking, icb_end - icb, jcp.nb_load_blocking_max); |
541 | const int max_ic = nstl::min(icb_end * jcp.ic_block, jcp.ic); |
542 | p.load_dim = this_block_size( |
543 | icb * jcp.ic_block, max_ic, load_step * jcp.ic_block); |
544 | rp.icb = p.load_dim; |
545 | }; |
546 | |
547 | auto init_reduce = [&](int ocb) { |
548 | const int nb_oc_blocking_step |
549 | = nstl::min(ocb + nb_oc_blocking, nb_oc) - ocb; |
550 | p.first_last_flag = 0 | (ocb == 0 ? FLAG_REDUCE_FIRST : 0) |
551 | | (ocb + nb_oc_blocking_step >= nb_oc ? FLAG_REDUCE_LAST : 0); |
552 | |
553 | p.reduce_dim = this_block_size( |
554 | ocb * jcp.oc_block, jcp.oc, nb_oc_blocking_step * jcp.oc_block); |
555 | }; |
556 | |
557 | auto inner_ker = [&](int icb, int ocb, int n, int g, int od, int oh, int ow, |
558 | int id, int ih, int iw) { |
559 | const bool is_dsrc_layout_nxc = utils::one_of(jcp.src_tag, |
560 | format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
561 | const int ic_off_idx = is_dsrc_layout_nxc |
562 | ? g * jcp.ic + icb * jcp.ic_block |
563 | : g * nb_ic + icb; |
564 | const size_t diff_src_off |
565 | = data_blk_off(diff_src_d, n, ic_off_idx, id, ih, iw); |
566 | |
567 | rp.src = diff_src + diff_src_off; |
568 | if (pd()->rtus_.reduce_src_) { |
569 | rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_; |
570 | p.output_data = rp.ws; |
571 | } else |
572 | p.output_data = rp.src; |
573 | p.load_data |
574 | = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb) |
575 | : weights_d.blk_off(ocb, icb)]; |
576 | |
577 | const bool is_ddst_layout_nxc = utils::one_of(jcp.dst_tag, |
578 | format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
579 | const int oc_off_idx = is_ddst_layout_nxc |
580 | ? g * jcp.oc + ocb * jcp.oc_block |
581 | : g * nb_oc + ocb; |
582 | p.bcast_data = diff_dst |
583 | + data_blk_off(diff_dst_d, n, oc_off_idx, od, oh, ow); |
584 | |
585 | const size_t grp_count = utils::div_up( |
586 | jcp.nthr, utils::div_up(jcp.nthr, jcp.load_grp_count)); |
587 | const size_t max_load_per_thread = is_dsrc_layout_nxc |
588 | ? jcp.load_dim |
589 | : rnd_up((jcp.load_dim / grp_count), jcp.load_block); |
590 | const size_t str_size = jcp.bcast_dim * max_load_per_thread; |
591 | p.store_buffer = store_buffer + ithr * str_size |
592 | + data_blk_off(diff_src_d, 0, 0, id, ih, iw); |
593 | (*kernel_)(&p); |
594 | if (pd()->rtus_.reduce_src_) (*rtus_driver_)(&rp); |
595 | }; |
596 | |
597 | if (jcp.loop_order == loop_lbr) { |
598 | int icb = icb_start; |
599 | while (icb < icb_end) { |
600 | int load_step; |
601 | init_load(icb, load_step); |
602 | int iwork = bcast_start; |
603 | while (iwork < bcast_end) { |
604 | int n, g, bcast_step, od, oh, ow, id, ih, iw; |
605 | init_bcast(iwork, n, g, bcast_step, od, oh, ow, id, ih, iw); |
606 | for (int ocb = 0; ocb < nb_oc; ocb += nb_oc_blocking) { |
607 | init_reduce(ocb); |
608 | inner_ker(icb, ocb, n, g, od, oh, ow, id, ih, iw); |
609 | } |
610 | iwork += bcast_step; |
611 | } |
612 | icb += load_step; |
613 | } |
614 | } else { |
615 | assert(!"unsupported loop order" ); |
616 | } |
617 | } |
618 | |
// Explicit instantiations of the backward-data primitive for the f32 and bf16
// diff_src data types.
// NOTE(review): REG_AVX512_ISA is defined outside this file; presumably it
// discards its argument when AVX-512 support is not built in -- confirm.
REG_AVX512_ISA(template struct jit_avx512_core_bf16_1x1_convolution_bwd_data_t<
        data_type::f32>);
REG_AVX512_ISA(template struct jit_avx512_core_bf16_1x1_convolution_bwd_data_t<
        data_type::bf16>);
623 | |
/* convolution backward wrt weights */
625 | |
// Offset into weights descriptor wrapper `d`. The group index `g` is passed
// to blk_off() only when the primitive descriptor reports grouped weights;
// otherwise it is dropped. Requires pd() to be callable at the expansion
// site.
#define wht_blk_off(d, g, ...) \
    (!pd()->with_groups() ? (d).blk_off(__VA_ARGS__) \
                          : (d).blk_off((g), __VA_ARGS__))
629 | |
630 | template <data_type_t diff_weights_type> |
631 | status_t |
632 | jit_avx512_core_bf16_1x1_convolution_bwd_weights_t<diff_weights_type>::init( |
633 | engine_t *engine) { |
634 | CHECK(safe_ptr_assign(kernel_, |
635 | new jit_avx512_core_bf16_1x1_conv_kernel( |
636 | pd()->jcp_, *pd()->attr(), *pd()->dst_md(0)))); |
637 | |
638 | CHECK(safe_ptr_assign( |
639 | acc_ker_, new cpu_accumulator_1d_t<data_type::f32>())); |
640 | CHECK(kernel_->create_kernel()); |
641 | CHECK(acc_ker_->create_kernel()); |
642 | |
643 | if (!pd()->jcp_.uses_permw_transposition) { |
644 | const bool is_src_layout_nxc = utils::one_of(pd()->jcp_.src_tag, |
645 | format_tag::ndhwc, format_tag::nhwc, format_tag::nwc); |
646 | const bool is_ddst_layout_nxc = utils::one_of(pd()->jcp_.dst_tag, |
647 | format_tag::ndhwc, format_tag::nhwc, format_tag::nwc); |
648 | if (!is_src_layout_nxc || !is_ddst_layout_nxc) { |
649 | CHECK(safe_ptr_assign(tr_reorder_, |
650 | new jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t())); |
651 | CHECK(tr_reorder_->create_kernel()); |
652 | } |
653 | if (is_src_layout_nxc) { |
654 | int ic = pd()->jcp_.ic * pd()->jcp_.ngroups; |
655 | CHECK(safe_ptr_assign(tr_reorder_nhwc_src_, |
656 | new jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t(ic))); |
657 | CHECK(tr_reorder_nhwc_src_->create_kernel()); |
658 | } |
659 | if (is_ddst_layout_nxc) { |
660 | int oc = pd()->jcp_.oc * pd()->jcp_.ngroups; |
661 | CHECK(safe_ptr_assign(tr_reorder_nhwc_ddst_, |
662 | new jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t(oc))); |
663 | CHECK(tr_reorder_nhwc_ddst_->create_kernel()); |
664 | } |
665 | } |
666 | |
667 | CHECK(init_rtus_driver<avx512_core>(this)); |
668 | return status::success; |
669 | } |
670 | |
671 | template <data_type_t diff_weights_type> |
672 | void jit_avx512_core_bf16_1x1_convolution_bwd_weights_t<diff_weights_type>:: |
673 | execute_backward_weights(const exec_ctx_t &ctx) const { |
674 | auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST); |
675 | auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); |
676 | auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, DNNL_ARG_DIFF_WEIGHTS); |
677 | |
678 | auto scratchpad = ctx.get_scratchpad_grantor(); |
679 | const auto &jcp = pd()->jcp_; |
680 | |
681 | float *diff_bias = nullptr; |
682 | if (jcp.with_bias && pd()->jcp_.bia_dt == data_type::f32) { |
683 | diff_bias = pd()->with_bias() && jcp.oc_without_padding % jcp.oc_block |
684 | ? scratchpad.template get<float>(key_conv_padded_bias) |
685 | : CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS); |
686 | } |
687 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
688 | const memory_desc_wrapper src_d(pd()->src_md()); |
689 | const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); |
690 | |
691 | auto rtus_space = scratchpad.template get<src_data_t>(key_conv_rtus_space); |
692 | auto wei_reduction = scratchpad.template get<float>(key_conv_wei_reduction); |
693 | |
694 | auto tr_src_buffer = !jcp.uses_permw_transposition |
695 | ? scratchpad.template get<src_data_t>(key_conv_tr_src) |
696 | : nullptr; |
697 | auto tr_diff_buffer = !jcp.uses_permw_transposition |
698 | ? scratchpad.template get<diff_dst_data_t>(key_conv_tr_diff_dst) |
699 | : nullptr; |
700 | |
701 | const int ndims = src_d.ndims(); |
702 | const int wei_size = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block) |
703 | * rnd_up(jcp.ic, jcp.ic_block); |
704 | const int n_wei_buffers |
705 | = jcp.dst_dt == data_type::bf16 ? jcp.nthr_mb : jcp.nthr_mb - 1; |
706 | auto bia_reduction = wei_reduction + n_wei_buffers * wei_size; |
707 | |
708 | simple_barrier::ctx_t reduction_barrier; |
709 | if (dnnl_thr_syncable()) simple_barrier::ctx_init(&reduction_barrier); |
710 | |
711 | // TODO (Roma): remove this restriction |
712 | assert(jcp.stride_w == 1 && jcp.stride_h == 1); |
713 | |
714 | const int nb_ic_blocking = jcp.nb_bcast_blocking; |
715 | |
716 | const int nb_oc = jcp.nb_load; |
717 | const int nb_oc_blocking = jcp.nb_load_blocking; |
718 | |
719 | const int sp_nb = jcp.nb_reduce; |
720 | const int mb_sp_work = jcp.mb * sp_nb; |
721 | |
722 | const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0]; |
723 | const int stride_w = pd()->desc()->strides[ndims - 3]; |
724 | |
725 | auto step = [](int default_step, int remaining, int tail_step) { |
726 | assert(default_step <= tail_step); |
727 | return remaining < tail_step ? remaining : default_step; |
728 | }; |
729 | |
730 | const bool is_ddst_layout_nxc = utils::one_of( |
731 | jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
732 | |
733 | auto maybe_zero_icpad = [&](const int g_start, const int g_end, |
734 | const int ocb_start, const int ocb_end) { |
735 | // write zeros to IC padded region. |
736 | const int ic_tail = jcp.ic_without_padding % jcp.ic_block; |
737 | if (ic_tail != 0) { |
738 | for_(int g = g_start; g < g_end; ++g) |
739 | for (int z_ocb = ocb_start; z_ocb < ocb_end; ++z_ocb) { |
740 | const int z_icb = jcp.nb_bcast - 1; |
741 | const size_t off = wht_blk_off(diff_weights_d, g, z_ocb, z_icb) |
742 | + ic_tail * jcp.oc_block; |
743 | diff_wei_data_t *z_wei = diff_weights + off; |
744 | const int zero_work |
745 | = (jcp.nb_bcast * jcp.ic_block - jcp.ic_without_padding) |
746 | * jcp.oc_block; |
747 | PRAGMA_OMP_SIMD() |
748 | for (int o = 0; o < zero_work; ++o) { |
749 | z_wei[o] = 0; |
750 | } |
751 | } |
752 | } |
753 | }; |
754 | |
755 | auto ker = [&](const int ithr, const int nthr) { |
756 | assert(nthr == jcp.nthr); |
757 | |
758 | const int ithr_ic_b = ithr % jcp.nthr_ic_b; |
759 | const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b; |
760 | const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g; |
761 | const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / jcp.nthr_g; |
762 | |
763 | /* reduction dimension */ |
764 | int mb_sp_b_start {0}, mb_sp_b_end {0}; |
765 | balance211( |
766 | mb_sp_work, jcp.nthr_mb, ithr_mb, mb_sp_b_start, mb_sp_b_end); |
767 | |
768 | /* independent dimensions */ |
769 | int g_start {0}, oc_b_start {0}, ic_b_start {0}; |
770 | int g_end {0}, oc_b_end {0}, ic_b_end {0}; |
771 | |
772 | balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end); |
773 | balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, oc_b_end); |
774 | balance211( |
775 | jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, ic_b_end); |
776 | |
777 | float *diff_wei; |
778 | if (diff_weights_type == data_type::bf16) { |
779 | diff_wei = wei_reduction + (ithr_mb)*wei_size; |
780 | } else { |
781 | diff_wei = ithr_mb == 0 |
782 | ? (float *)diff_weights |
783 | : (float *)wei_reduction + (ithr_mb - 1) * wei_size; |
784 | } |
785 | |
786 | float *diff_bia = nullptr; |
787 | if (jcp.with_bias) { |
788 | const int bias_size = jcp.ngroups * jcp.nb_load * jcp.oc_block; |
789 | if (jcp.bia_dt == data_type::bf16) { |
790 | diff_bia = bia_reduction + (ithr_mb)*bias_size; |
791 | } else { |
792 | diff_bia = ithr_mb == 0 |
793 | ? (float *)diff_bias |
794 | : (float *)bia_reduction + (ithr_mb - 1) * bias_size; |
795 | } |
796 | } |
797 | |
798 | int sp_b_step = 0; |
799 | for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end; |
800 | mb_sp_b += sp_b_step) { |
801 | int img {0}, sp_b {0}; |
802 | nd_iterator_init(mb_sp_b, img, jcp.mb, sp_b, sp_nb); |
803 | sp_b_step = step(jcp.nb_reduce_blocking, |
804 | nstl::min(sp_nb - sp_b, mb_sp_b_end - mb_sp_b), |
805 | jcp.nb_reduce_blocking_max); |
806 | |
807 | for (int g = g_start; g < g_end; ++g) { |
808 | int load_step = 0; |
809 | int bcast_step = 0; |
810 | for (int ic_b = ic_b_start; ic_b < ic_b_end; |
811 | ic_b += bcast_step) { |
812 | bcast_step = step(nb_ic_blocking, ic_b_end - ic_b, |
813 | jcp.nb_bcast_blocking_max); |
814 | for (int oc_b = oc_b_start; oc_b < oc_b_end; |
815 | oc_b += load_step) { |
816 | load_step = step(nb_oc_blocking, oc_b_end - oc_b, |
817 | jcp.nb_load_blocking_max); |
818 | |
819 | float *store_to; |
820 | |
821 | const size_t off |
822 | = wht_blk_off(diff_weights_d, g, oc_b, ic_b); |
823 | store_to = diff_wei + off; |
824 | |
825 | const bool is_src_layout_nxc |
826 | = utils::one_of(jcp.src_tag, format_tag::nwc, |
827 | format_tag::nhwc, format_tag::ndhwc); |
828 | const int ic_off_idx = is_src_layout_nxc |
829 | ? g * jcp.ic + ic_b * jcp.ic_block |
830 | : g * nb_oc + ic_b; |
831 | const src_data_t *diff_src |
832 | = &src[src_d.blk_off(img, ic_off_idx)]; |
833 | const int oc_off_idx = is_ddst_layout_nxc |
834 | ? g * jcp.oc + oc_b * jcp.oc_block |
835 | : g * nb_oc + oc_b; |
836 | const diff_dst_data_t *pdiff_dst |
837 | = &diff_dst[diff_dst_d.blk_off( |
838 | img, oc_off_idx)]; |
839 | const src_data_t *local_src = diff_src; |
840 | |
841 | auto p = jit_1x1_conv_call_s(); |
842 | auto rp = rtus_driver_t<avx512_core>::call_params_t(); |
843 | |
844 | p.output_stride = utils::rnd_up(jcp.ic, jcp.oc_block) |
845 | * jcp.oc_block * jcp.typesize_out; |
846 | |
847 | p.load_dim = this_block_size(oc_b * jcp.oc_block, |
848 | jcp.oc, load_step * jcp.oc_block); |
849 | |
850 | p.bcast_dim = this_block_size(ic_b * jcp.ic_block, |
851 | jcp.ic, bcast_step * jcp.ic_block); |
852 | rp.icb = p.bcast_dim; |
853 | p.output_data = store_to; |
854 | |
855 | p.reduce_dim = sp_b_step * jcp.reduce_block; |
856 | if (!jcp.uses_permw_transposition) |
857 | p.reduce_dim = nstl::min(p.reduce_dim, |
858 | (size_t)jcp.reduce_dim |
859 | - sp_b * jcp.reduce_block); |
860 | |
861 | rp.os = p.reduce_dim; |
862 | |
863 | p.first_last_flag = 0 |
864 | | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST |
865 | : 0) |
866 | | (ic_b == 0 ? FLAG_COMPUTE_BIAS : 0); |
867 | |
868 | int sp = sp_b * jcp.reduce_block; |
869 | int oc_mult = is_ddst_layout_nxc ? jcp.ngroups * jcp.oc |
870 | : jcp.oc_block; |
871 | p.load_data = pdiff_dst + sp * oc_mult; |
872 | |
873 | if (pd()->rtus_.reduce_src_) { |
874 | const int oh = sp / jcp.ow; |
875 | const int ow = sp % jcp.ow; |
876 | |
877 | const int ih = oh * stride_h; |
878 | const int iw = ow * stride_w; |
879 | rp.iw_start = iw; |
880 | |
881 | rp.ws = rtus_space |
882 | + ithr * pd()->rtus_.space_per_thread_ |
883 | + sp * jcp.ic_block; |
884 | |
885 | if (ndims == 3) |
886 | rp.src = local_src |
887 | + iw * src_d.blocking_desc().strides[2]; |
888 | else |
889 | rp.src = local_src |
890 | + ih * src_d.blocking_desc().strides[2] |
891 | + iw * src_d.blocking_desc().strides[3]; |
892 | (*rtus_driver_)(&rp); |
893 | |
894 | p.bcast_data = rp.ws; |
895 | } else { |
896 | int ic_mult = is_src_layout_nxc |
897 | ? jcp.ngroups * jcp.ic |
898 | : jcp.ic_block; |
899 | p.bcast_data = local_src + sp * ic_mult; |
900 | } |
901 | if (!jcp.uses_permw_transposition) { |
902 | bf16_support::jit_call_t ptr; |
903 | ptr.nelems = p.reduce_dim; |
904 | int thr_src_block_size = rnd_up(jcp.reduce_dim, 2) |
905 | * jcp.ic_block * jcp.nb_bcast_blocking_max; |
906 | src_data_t *tr_src |
907 | = &tr_src_buffer[ithr * thr_src_block_size]; |
908 | for (int bs = 0; bs < bcast_step; bs++) { |
909 | size_t src_off = bs * jcp.ic_block |
910 | * (is_src_layout_nxc ? 1 |
911 | : jcp.reduce_dim); |
912 | size_t src_tr_off = bs |
913 | * rnd_up(jcp.reduce_dim, 2) |
914 | * jcp.ic_block; |
915 | src_data_t *curr_inp = &( |
916 | (src_data_t *)p.bcast_data)[src_off]; |
917 | src_data_t *curr_out = &tr_src[src_tr_off]; |
918 | int ch_work = nstl::min<int>( |
919 | p.bcast_dim - bs * jcp.bcast_block, 16); |
920 | assert(ch_work <= 16); |
921 | ptr.mask = (1 << ch_work) - 1; |
922 | ptr.inp = (void *)curr_inp; |
923 | ptr.out = (void *)curr_out; |
924 | if (is_src_layout_nxc) |
925 | (*tr_reorder_nhwc_src_)(&ptr); |
926 | else |
927 | (*tr_reorder_)(&ptr); |
928 | } |
929 | |
930 | p.bcast_data = (void *)tr_src; |
931 | int thr_dst_block_size = rnd_up(jcp.reduce_dim, 2) |
932 | * jcp.oc_block * jcp.nb_load_blocking_max; |
933 | diff_dst_data_t *tr_diff_dst = &tr_diff_buffer[ithr |
934 | * thr_dst_block_size]; |
935 | for (int ls = 0; ls < load_step; ls++) { |
936 | size_t ddst_off = ls * jcp.oc_block |
937 | * (is_ddst_layout_nxc ? 1 : jcp.os); |
938 | size_t ddst_tr_off = ls |
939 | * rnd_up(jcp.reduce_dim, 2) |
940 | * jcp.oc_block; |
941 | diff_dst_data_t *curr_inp |
942 | = &((diff_dst_data_t *) |
943 | p.load_data)[ddst_off]; |
944 | diff_dst_data_t *curr_out |
945 | = &tr_diff_dst[ddst_tr_off]; |
946 | int ch_work = nstl::min<int>( |
947 | p.load_dim - ls * jcp.load_block, 16); |
948 | ptr.mask = (1 << ch_work) - 1; |
949 | ptr.inp = (void *)curr_inp; |
950 | ptr.out = (void *)curr_out; |
951 | if (is_ddst_layout_nxc) |
952 | (*tr_reorder_nhwc_ddst_)(&ptr); |
953 | else |
954 | (*tr_reorder_)(&ptr); |
955 | } |
956 | p.load_data = (void *)tr_diff_dst; |
957 | } //if (!jcp.uses_permw_transposition) |
958 | |
959 | p.bias_data = diff_bia |
960 | ? &diff_bia[oc_off_idx |
961 | * (is_ddst_layout_nxc ? 1 |
962 | : jcp.oc_block)] |
963 | : nullptr; |
964 | (*kernel_)(&p); |
965 | } |
966 | } |
967 | } |
968 | } |
969 | }; |
970 | |
971 | auto ker_reduce_and_convert_diff_wei_bia = [&](const int ithr, |
972 | const int nthr) { |
973 | assert(nthr == jcp.nthr); |
974 | |
975 | const int ithr_ic_b = ithr % jcp.nthr_ic_b; |
976 | const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b; |
977 | const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g; |
978 | const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b / jcp.nthr_g; |
979 | |
980 | /* independent dimensions */ |
981 | int g_start {0}, oc_b_start {0}, ic_b_start {0}; |
982 | int g_end {0}, oc_b_end {0}, ic_b_end {0}; |
983 | |
984 | balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end); |
985 | balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start, oc_b_end); |
986 | balance211( |
987 | jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start, ic_b_end); |
988 | |
989 | const int g_work = g_end - g_start; |
990 | const int oc_b_work = oc_b_end - oc_b_start; |
991 | const int ic_b_work = ic_b_end - ic_b_start; |
992 | |
993 | const int _start_nthr_mb = 1; |
994 | const bool is_bf16_out = diff_weights_type == data_type::bf16; |
995 | const bool is_bf16_bias |
996 | = jcp.with_bias && jcp.bia_dt == data_type::bf16; |
997 | /* diff_weights[:] += sum(ws_reduction_[thr_mb][:]) */ |
998 | if (jcp.nthr_mb > _start_nthr_mb) { |
999 | if (dnnl_thr_syncable()) |
1000 | simple_barrier::barrier(&reduction_barrier, jcp.nthr); |
1001 | const int work = g_work * oc_b_work * ic_b_work; |
1002 | int start {0}, end {0}; |
1003 | balance211(work, jcp.nthr_mb, ithr_mb, start, end); |
1004 | if (start == end) return; |
1005 | |
1006 | for (int thr_mb = _start_nthr_mb; thr_mb < jcp.nthr_mb; ++thr_mb) { |
1007 | int w = start; |
1008 | int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_start {0}; |
1009 | nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start, |
1010 | oc_b_work, sub_ic_b_start, ic_b_work); |
1011 | while (w < end) { |
1012 | const int g = g_start + sub_g_start; |
1013 | const int oc_b = oc_b_start + sub_oc_b_start; |
1014 | const int ic_b = ic_b_start + sub_ic_b_start; |
1015 | const int ic_to_accumulate |
1016 | = nstl::min(end - w, ic_b_work - sub_ic_b_start) |
1017 | * jcp.ic_block; |
1018 | const int acc_size |
1019 | = this_block_size(ic_b * jcp.ic_block, |
1020 | jcp.ic_without_padding, ic_to_accumulate) |
1021 | * jcp.oc_block; |
1022 | |
1023 | const size_t off |
1024 | = wht_blk_off(diff_weights_d, g, oc_b, ic_b); |
1025 | float *wei_reduced = is_bf16_out |
1026 | ? wei_reduction + off |
1027 | : (float *)diff_weights + off; |
1028 | |
1029 | int thr_mb_buffer_idx = is_bf16_out ? thr_mb : thr_mb - 1; |
1030 | float *wei_to_reduce = wei_reduction |
1031 | + thr_mb_buffer_idx * wei_size + off; |
1032 | if (is_bf16_out && thr_mb == jcp.nthr_mb - 1) |
1033 | // the last iteration for bfloat16 requires conversion |
1034 | // and store to diff_weights array |
1035 | add_floats_and_cvt_to_bfloat16( |
1036 | (bfloat16_t *)(diff_weights + off), wei_reduced, |
1037 | wei_to_reduce, acc_size); |
1038 | else |
1039 | acc_ker_->accumulate( |
1040 | wei_reduced, wei_to_reduce, acc_size); |
1041 | |
1042 | nd_iterator_jump(w, end, sub_g_start, g_work, |
1043 | sub_oc_b_start, oc_b_work, sub_ic_b_start, |
1044 | ic_b_work); |
1045 | } |
1046 | |
1047 | if (jcp.with_bias && ithr_ic_b == 0 && ic_b_work > 0 |
1048 | && ithr_mb == 0) { |
1049 | for (int g = g_start; g < g_end; g++) { |
1050 | float *bias_reduced |
1051 | = is_bf16_bias ? bia_reduction : diff_bias; |
1052 | int thr_mb_buffer_idx |
1053 | = is_bf16_bias ? thr_mb : thr_mb - 1; |
1054 | int bias_buf_size |
1055 | = jcp.ngroups * rnd_up(jcp.oc, jcp.oc_block); |
1056 | float *bias_to_reduce = bia_reduction |
1057 | + thr_mb_buffer_idx * bias_buf_size; |
1058 | const size_t acc_size |
1059 | = this_block_size(oc_b_start * jcp.oc_block, |
1060 | jcp.oc_without_padding, |
1061 | (oc_b_end - oc_b_start) * jcp.oc_block); |
1062 | int idx = g * rnd_up(jcp.oc, jcp.oc_block) |
1063 | + oc_b_start * jcp.oc_block; |
1064 | if (is_bf16_bias && thr_mb == jcp.nthr_mb - 1) { |
1065 | // the last iteration for bfloat16 requires conversion and |
1066 | // store to diff_weights array |
1067 | int diff_bias_idx = g * jcp.oc_without_padding |
1068 | + oc_b_start * jcp.oc_block; |
1069 | bfloat16_t *diff_bias_result |
1070 | = CTX_OUT_MEM( |
1071 | bfloat16_t *, DNNL_ARG_DIFF_BIAS) |
1072 | + diff_bias_idx; |
1073 | add_floats_and_cvt_to_bfloat16(diff_bias_result, |
1074 | &bias_reduced[idx], &bias_to_reduce[idx], |
1075 | acc_size); |
1076 | } else { |
1077 | acc_ker_->accumulate(&bias_reduced[idx], |
1078 | &bias_to_reduce[idx], acc_size); |
1079 | } |
1080 | } |
1081 | } |
1082 | } |
1083 | } else { |
1084 | if (is_bf16_out) { |
1085 | const auto ic_work = nstl::min(jcp.ic, ic_b_end * jcp.ic_block) |
1086 | - ic_b_start * jcp.ic_block; |
1087 | for_(int g = g_start; g < g_end; g++) |
1088 | for (int oc_b = oc_b_start; oc_b < oc_b_end; oc_b++) { |
1089 | const size_t acc_size = (size_t)ic_work * jcp.oc_block; |
1090 | const size_t off |
1091 | = wht_blk_off(diff_weights_d, g, oc_b, ic_b_start); |
1092 | |
1093 | cvt_float_to_bfloat16((bfloat16_t *)(diff_weights + off), |
1094 | (const float *)(wei_reduction + off), acc_size); |
1095 | } |
1096 | } |
1097 | |
1098 | if (is_bf16_bias && ithr_ic_b == 0 && ic_b_work > 0) { |
1099 | for (int g = g_start; g < g_end; g++) { |
1100 | int result_start_idx = g * jcp.oc_without_padding |
1101 | + oc_b_start * jcp.oc_block; |
1102 | int buffer_start_idx = g * rnd_up(jcp.oc, jcp.oc_block) |
1103 | + oc_b_start * jcp.oc_block; |
1104 | const size_t acc_size = nstl::min(jcp.oc_without_padding, |
1105 | oc_b_end * jcp.oc_block) |
1106 | - oc_b_start * jcp.oc_block; |
1107 | bfloat16_t *diff_bias_result |
1108 | = CTX_OUT_MEM(bfloat16_t *, DNNL_ARG_DIFF_BIAS) |
1109 | + result_start_idx; |
1110 | float *buffer = bia_reduction + buffer_start_idx; |
1111 | cvt_float_to_bfloat16(diff_bias_result, buffer, acc_size); |
1112 | } |
1113 | } |
1114 | } |
1115 | if (ic_b_end >= jcp.nb_bcast) { |
1116 | maybe_zero_icpad(g_start, g_end, oc_b_start, oc_b_end); |
1117 | } |
1118 | }; |
1119 | |
1120 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
1121 | assert(nthr == jcp.nthr); |
1122 | ker(ithr, jcp.nthr); |
1123 | if (dnnl_thr_syncable()) |
1124 | ker_reduce_and_convert_diff_wei_bia(ithr, jcp.nthr); |
1125 | }); |
1126 | |
1127 | if (!dnnl_thr_syncable()) { |
1128 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
1129 | assert(nthr == jcp.nthr); |
1130 | ker_reduce_and_convert_diff_wei_bia(ithr, jcp.nthr); |
1131 | }); |
1132 | } |
1133 | |
1134 | if (pd()->jcp_.bia_dt == data_type::f32 |
1135 | && jcp.oc_without_padding % jcp.oc_block) { |
1136 | auto diff_bias_in = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS); |
1137 | utils::array_copy(diff_bias_in, diff_bias, jcp.oc_without_padding); |
1138 | } |
1139 | } |
1140 | |
1141 | REG_AVX512_ISA( |
1142 | template struct jit_avx512_core_bf16_1x1_convolution_bwd_weights_t< |
1143 | data_type::f32>); |
1144 | REG_AVX512_ISA( |
1145 | template struct jit_avx512_core_bf16_1x1_convolution_bwd_weights_t< |
1146 | data_type::bf16>); |
1147 | |
1148 | } // namespace x64 |
1149 | } // namespace cpu |
1150 | } // namespace impl |
1151 | } // namespace dnnl |
1152 | |