1 | /******************************************************************************* |
2 | * Copyright 2019-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | |
22 | #include "common/bfloat16.hpp" |
23 | #include "cpu/x64/jit_avx512_core_bf16_convolution.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
30 | using namespace dnnl::impl::status; |
31 | using namespace dnnl::impl::memory_tracking::names; |
32 | using namespace dnnl::impl::utils; |
33 | |
34 | using namespace nstl; |
35 | |
// Computes a weights offset that transparently handles grouped vs.
// non-grouped convolutions: the group index `g` is prepended to the
// blk_off() argument list only when the primitive has groups.
#define wht_blk_off(d, g, ...) \
    (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
                         : (d).blk_off(__VA_ARGS__))
39 | |
40 | void jit_avx512_core_bf16_convolution_fwd_t::prepare_padded_bias( |
41 | const char *&bias, const memory_tracking::grantor_t &scratchpad) const { |
42 | if (!pd()->wants_padded_bias()) return; |
43 | |
44 | const size_t bia_dt_size = pd()->jcp_.typesize_bia; |
45 | auto padded_bias = scratchpad.template get<char>( |
46 | memory_tracking::names::key_conv_padded_bias); |
47 | utils::array_copy( |
48 | padded_bias, bias, bia_dt_size * pd()->jcp_.oc_without_padding); |
49 | utils::array_set(padded_bias + bia_dt_size * pd()->jcp_.oc_without_padding, |
50 | 0.f, bia_dt_size * (pd()->jcp_.oc - pd()->jcp_.oc_without_padding)); |
51 | bias = padded_bias; |
52 | } |
53 | |
54 | void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_1d( |
55 | const exec_ctx_t &ctx) const { |
56 | const auto &jcp = pd()->jcp_; |
57 | auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); |
58 | auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); |
59 | auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS); |
60 | auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST); |
61 | const auto post_ops_binary_rhs_arg_vec |
62 | = binary_injector::prepare_binary_args(jcp.post_ops, ctx); |
63 | |
64 | prepare_padded_bias(bias, ctx.get_scratchpad_grantor()); |
65 | |
66 | const size_t bia_dt_size = pd()->jcp_.typesize_bia; |
67 | |
68 | const memory_desc_wrapper src_d(pd()->src_md()); |
69 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
70 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
71 | |
72 | assert(jcp.nb_oc % jcp.nb_oc_blocking == 0); |
73 | |
74 | int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking; |
75 | // TODO: experiment with g_blocking for perf fine tuning |
76 | int g_blocking = 1; |
77 | int nb_groups = jcp.ngroups / g_blocking; |
78 | dim_t work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow; |
79 | |
80 | int nthr = jcp.aligned_threads ? jcp.aligned_threads : jcp.nthr; |
81 | parallel(nthr, [&](const int ithr, const int nthr) { |
82 | dim_t start {0}, end {0}; |
83 | balance211(work_amount, nthr, ithr, start, end); |
84 | auto par_conv = jit_conv_call_s(); |
85 | |
86 | int n {0}, gg {0}, occ {0}, owb {0}; |
87 | |
88 | if (jcp.loop_order == loop_cwgn) { |
89 | int dummy {0}; |
90 | nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg, |
91 | nb_groups, n, jcp.mb, dummy, 1); |
92 | } else if (jcp.loop_order == loop_gncw) { |
93 | int dummy {0}; |
94 | nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks, |
95 | owb, jcp.nb_ow, dummy, 1); |
96 | } else if (jcp.loop_order == loop_nhwcg) { |
97 | nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, |
98 | gg, nb_groups); |
99 | } else |
100 | assert(!"unsupported loop order" ); |
101 | |
102 | while (start < end) { |
103 | int ocb = occ * jcp.nb_oc_blocking; |
104 | int g = gg * g_blocking; |
105 | int ow_s = owb * jcp.ow_block; |
106 | int iw_s = ow_s * jcp.stride_w; |
107 | |
108 | const bool is_dst_layout_nxc = jcp.dst_tag == format_tag::nwc; |
109 | const int oc_idx = is_dst_layout_nxc |
110 | ? g * jcp.oc + ocb * jcp.oc_block |
111 | : g * jcp.nb_oc + ocb; |
112 | auto dst_w |
113 | = dst + jcp.typesize_out * dst_d.blk_off(n, oc_idx, ow_s); |
114 | auto bias_w = bias ? bias |
115 | + bia_dt_size * oc_idx |
116 | * (is_dst_layout_nxc ? 1 : jcp.oc_block) |
117 | : nullptr; |
118 | const bool is_src_layout_nxc = jcp.src_tag == format_tag::nwc; |
119 | const int ic_idx = is_src_layout_nxc ? g * jcp.ic : g * jcp.nb_ic; |
120 | auto src_w = src + src_d.blk_off(n, ic_idx, iw_s); |
121 | auto wht_w = weights + wht_blk_off(weights_d, g, ocb); |
122 | |
123 | par_conv.load_work = this_block_size(ocb * jcp.oc_block, |
124 | jcp.oc_without_padding, jcp.nb_oc_blocking * jcp.oc_block); |
125 | par_conv.src = src_w; |
126 | par_conv.dst = dst_w; |
127 | par_conv.filt = wht_w; |
128 | par_conv.bias = bias_w; |
129 | par_conv.owb = owb; |
130 | |
131 | par_conv.oc_l_off = oc_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block); |
132 | par_conv.dst_orig = dst; |
133 | par_conv.post_ops_binary_rhs_arg_vec |
134 | = post_ops_binary_rhs_arg_vec.data(); |
135 | (*kernel_)(&par_conv); |
136 | |
137 | if (jcp.loop_order == loop_cwgn) { |
138 | int dummy {0}; |
139 | nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, gg, |
140 | nb_groups, n, jcp.mb, dummy, 1); |
141 | } else if (jcp.loop_order == loop_gncw) { |
142 | int dummy {0}; |
143 | nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ, |
144 | oc_chunks, owb, jcp.nb_ow, dummy, 1); |
145 | } else if (jcp.loop_order == loop_nhwcg) { |
146 | ++start; |
147 | nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, gg, |
148 | nb_groups); |
149 | } else |
150 | assert(!"unsupported loop order" ); |
151 | } |
152 | }); |
153 | } |
154 | |
// Forward 2D bf16 convolution driver. Work units (mb, groups, oc chunks,
// oh, ow blocks) are balanced across threads; each thread iterates its
// share in the configured loop order and issues one JIT kernel call per
// output row, clipping the filter height against top/bottom padding.
void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_2d(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    // Per-execution rhs pointers for binary post-ops, forwarded to the kernel.
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    // May redirect `bias` to a zero-padded scratchpad copy.
    prepare_padded_bias(bias, ctx.get_scratchpad_grantor());

    const size_t bia_dt_size = pd()->jcp_.typesize_bia;

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    // TODO: experiment with g_blocking for perf fine tuning
    int g_blocking = 1;
    int nb_groups = jcp.ngroups / g_blocking;
    dim_t work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;

    int nthr = jcp.aligned_threads ? jcp.aligned_threads : jcp.nthr;
    parallel(nthr, [&](const int ithr, const int nthr) {
        dim_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);
        auto par_conv = jit_conv_call_s();

        // Strides (in layout elements) along the height dimension.
        size_t src_h_stride = src_d.blk_off(0, 0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

        int n {0}, gg {0}, occ {0}, oh_s {0}, owb {0};

        if (jcp.loop_order == loop_cwgn)
            nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
                    nb_groups, n, jcp.mb, oh_s, jcp.oh);
        else if (jcp.loop_order == loop_gncw)
            nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks,
                    owb, jcp.nb_ow, oh_s, jcp.oh);
        else if (jcp.loop_order == loop_nhwcg)
            nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
                    occ, oc_chunks, gg, nb_groups);
        else
            assert(!"unsupported loop order" );

        while (start < end) {

            int ocb = occ * jcp.nb_oc_blocking;
            int g = gg * g_blocking;
            int work_rem = end - start;
            int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
            // Process the remaining consecutive output rows of this thread's
            // share in one inner loop (bounded by oh)...
            int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
            // ...except nhwcg order, which advances a single row per step.
            if (jcp.loop_order == loop_nhwcg) oh_e = oh_s + 1; //step instead

            int ow_s = owb * jcp.ow_block;
            int iw_s = ow_s * jcp.stride_w;

            // Channel indexing differs between plain (nhwc) and blocked
            // layouts: nhwc addresses raw channels, blocked addresses blocks.
            const bool is_dst_layout_nxc = jcp.dst_tag == format_tag::nhwc;
            const int oc_idx = is_dst_layout_nxc
                    ? g * jcp.oc + ocb * jcp.oc_block
                    : g * jcp.nb_oc + ocb;
            auto dst_w = dst
                    + jcp.typesize_out * dst_d.blk_off(n, oc_idx, oh_s, ow_s);
            auto bias_w = bias ? bias
                            + bia_dt_size * oc_idx
                                    * (is_dst_layout_nxc ? 1 : jcp.oc_block)
                               : nullptr;
            const bool is_src_layout_nxc = jcp.src_tag == format_tag::nhwc;
            const int ic_idx = is_src_layout_nxc ? g * jcp.ic : g * jcp.nb_ic;
            auto src_w = src + src_d.blk_off(n, ic_idx, ih_s, iw_s);
            auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);

            for (int oj = oh_s, ij = ih_s; oj < oh_e;
                    ++oj, ij += jcp.stride_h) {
                int dilate_h = jcp.dilate_h + 1;
                // Number of filter rows falling above the image (top padding)
                // and below it (bottom padding) for this output row; div_up
                // accounts for the holes introduced by dilation.
                int i_t_overflow = div_up(max(0, -ij), dilate_h);
                int i_b_overflow = div_up(
                        max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
                        dilate_h);
                int kh_padding
                        = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);
                // Skip the padded top rows in both src and weights.
                auto aux_src = src_w + i_t_overflow * dilate_h * src_h_stride;
                auto aux_wht = wht_w + i_t_overflow * wht_h_stride;

                // Clip the oc span against the unpadded channel tail.
                par_conv.load_work = utils::this_block_size(ocb * jcp.oc_block,
                        jcp.oc_without_padding,
                        jcp.nb_oc_blocking * jcp.oc_block);
                par_conv.src = aux_src;
                par_conv.dst = dst_w;
                par_conv.filt = aux_wht;
                par_conv.bias = bias_w;
                par_conv.kh_padding = kh_padding;
                par_conv.owb = owb;

                // Logical oc offset for binary post-ops rhs addressing.
                par_conv.oc_l_off
                        = oc_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block);
                par_conv.dst_orig = dst;
                par_conv.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                (*kernel_)(&par_conv);

                src_w += src_h_stride * jcp.stride_h;
                dst_w += jcp.typesize_out * dst_h_stride;
            }
            if (jcp.loop_order == loop_cwgn)
                nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, gg,
                        nb_groups, n, jcp.mb, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_gncw)
                nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ,
                        oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_nhwcg) {
                ++start;
                nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ,
                        oc_chunks, gg, nb_groups);
            } else
                assert(!"unsupported loop order" );
        }
    });
}
279 | |
// Forward 3D bf16 convolution driver. Same scheme as the 2D driver with an
// additional depth dimension: the depth clipping (kd_padding) is computed
// once per work item, while the height clipping (kh_padding) is computed
// per output row inside the inner loop.
void jit_avx512_core_bf16_convolution_fwd_t::execute_forward_3d(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    // Per-execution rhs pointers for binary post-ops, forwarded to the kernel.
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(jcp.post_ops, ctx);

    // May redirect `bias` to a zero-padded scratchpad copy.
    prepare_padded_bias(bias, ctx.get_scratchpad_grantor());

    const size_t bia_dt_size = pd()->jcp_.typesize_bia;

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    // TODO: experiment with g_blocking for perf fine tuning
    int g_blocking = 1;
    int nb_groups = jcp.ngroups / g_blocking;
    dim_t work_amount
            = jcp.mb * nb_groups * oc_chunks * jcp.od * jcp.oh * jcp.nb_ow;

    int nthr = jcp.aligned_threads ? jcp.aligned_threads : jcp.nthr;
    parallel(nthr, [&](const int ithr, const int nthr) {
        dim_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);
        auto par_conv = jit_conv_call_s();

        // Strides (in layout elements) along depth and height.
        size_t src_d_stride = src_d.blk_off(0, 0, 1);
        size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
        size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
        size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);

        int n {0}, gg {0}, occ {0}, od_s {0}, oh_s {0}, owb {0};

        if (jcp.loop_order == loop_cwgn)
            nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
                    nb_groups, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
        else if (jcp.loop_order == loop_gncw)
            nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks,
                    owb, jcp.nb_ow, od_s, jcp.od, oh_s, jcp.oh);
        else if (jcp.loop_order == loop_nhwcg)
            nd_iterator_init(start, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh, owb,
                    jcp.nb_ow, occ, oc_chunks, gg, nb_groups);
        else
            assert(!"unsupported loop order" );

        while (start < end) {

            int ocb = occ * jcp.nb_oc_blocking;
            int g = gg * g_blocking;
            int work_rem = end - start;
            int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
            // Consecutive output rows handled by the inner loop (bounded by
            // oh); nhwcg order advances a single row per step instead.
            int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
            if (jcp.loop_order == loop_nhwcg) oh_e = oh_s + 1; //step instead
            int ow_s = owb * jcp.ow_block;
            int iw_s = ow_s * jcp.stride_w;

            // Depth clipping: count filter slices falling into front/back
            // padding (div_up accounts for dilation holes).
            int id_s = -jcp.f_pad + od_s * jcp.stride_d;
            int dilate_d = jcp.dilate_d + 1;
            int d_t_overflow = div_up(max(0, -id_s), dilate_d);
            int d_b_overflow = div_up(
                    max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1),
                    dilate_d);
            int kd_padding = nstl::max(0, jcp.kd - d_t_overflow - d_b_overflow);

            // Channel indexing differs between plain (ndhwc) and blocked
            // layouts: ndhwc addresses raw channels, blocked addresses blocks.
            const bool is_dst_layout_nxc = jcp.dst_tag == format_tag::ndhwc;
            const int oc_idx = is_dst_layout_nxc
                    ? g * jcp.oc + ocb * jcp.oc_block
                    : g * jcp.nb_oc + ocb;
            auto dst_w = dst
                    + jcp.typesize_out
                            * dst_d.blk_off(n, oc_idx, od_s, oh_s, ow_s);
            auto bias_w = bias ? bias
                            + bia_dt_size * oc_idx
                                    * (is_dst_layout_nxc ? 1 : jcp.oc_block)
                               : nullptr;
            const bool is_src_layout_nxc = jcp.src_tag == format_tag::ndhwc;
            const int ic_idx = is_src_layout_nxc ? g * jcp.ic : g * jcp.nb_ic;
            // Skip the padded front slices in both src and weights.
            auto src_w = src + src_d.blk_off(n, ic_idx, id_s, ih_s, iw_s)
                    + d_t_overflow * dilate_d * src_d_stride;
            auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0)
                    + d_t_overflow * wht_d_stride;

            for (int oj = oh_s, ij = ih_s; oj < oh_e;
                    ++oj, ij += jcp.stride_h) {
                int dilate_h = jcp.dilate_h + 1;
                // Height clipping for this output row (top/bottom padding).
                int i_t_overflow = div_up(max(0, -ij), dilate_h);
                int i_b_overflow = div_up(
                        max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
                        dilate_h);
                int kh_padding
                        = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);
                auto aux_src = src_w + i_t_overflow * dilate_h * src_h_stride;
                auto aux_wht = wht_w + i_t_overflow * wht_h_stride;

                // Clip the oc span against the unpadded channel tail.
                par_conv.load_work = utils::this_block_size(ocb * jcp.oc_block,
                        jcp.oc_without_padding,
                        jcp.nb_oc_blocking * jcp.oc_block);
                par_conv.src = aux_src;
                par_conv.dst = dst_w;
                par_conv.filt = aux_wht;
                par_conv.bias = bias_w;
                par_conv.kh_padding = kh_padding;
                par_conv.kd_padding = kd_padding;
                par_conv.owb = owb;

                // Logical oc offset for binary post-ops rhs addressing.
                par_conv.oc_l_off
                        = oc_idx * (is_dst_layout_nxc ? 1 : jcp.oc_block);
                par_conv.dst_orig = dst;
                par_conv.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                (*kernel_)(&par_conv);

                src_w += src_h_stride * jcp.stride_h;
                dst_w += jcp.typesize_out * dst_h_stride;
            }
            if (jcp.loop_order == loop_cwgn)
                nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, gg,
                        nb_groups, n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_gncw)
                nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ,
                        oc_chunks, owb, jcp.nb_ow, od_s, jcp.od, oh_s, jcp.oh);
            else if (jcp.loop_order == loop_nhwcg) {
                ++start;
                nd_iterator_step(n, jcp.mb, od_s, jcp.od, oh_s, jcp.oh, owb,
                        jcp.nb_ow, occ, oc_chunks, gg, nb_groups);
            } else
                assert(!"unsupported loop order" );
        }
    });
}
418 | |
// Backward-data 3D bf16 convolution driver. For each diff_src point the
// valid filter range along depth (kd_lo/kd_len) and height (kh_lo/kh_len)
// is derived from padding, stride and dilation, and the JIT kernel is
// invoked once per input row. Three code paths compute the ranges:
// unit-stride/no-dilation (fast path), dilated unit-stride, and strided.
void jit_avx512_core_bf16_convolution_bwd_data_t ::execute_backward_data_3d(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = pd()->jcp_;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
        // TODO: experiment with g_blocking for perf fine tuning
        int g_blocking = 1;
        int nb_groups = jcp.ngroups / g_blocking;
        // Work is partitioned over (groups, mb, ic chunks, id, ih).
        int work_amount = nb_groups * jcp.mb * ic_chunks * jcp.id * jcp.ih;
        balance211(work_amount, nthr, ithr, start, end);

        auto par_conv = jit_conv_call_s();

        // Strides (in layout elements) along the height dimension.
        size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1);
        size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);

        bool is_fast_path_d = jcp.dilate_d == 0 && jcp.stride_d == 1;
        bool is_fast_path_h = jcp.dilate_h == 0 && jcp.stride_h == 1;

        int n {0}, gg {0}, icc {0}, id_s {0}, ih_s {0};
        if (jcp.loop_order == loop_cgn)
            nd_iterator_init(start, icc, ic_chunks, gg, nb_groups, n, jcp.mb,
                    id_s, jcp.id, ih_s, jcp.ih);
        else if (jcp.loop_order == loop_gnc)
            nd_iterator_init(start, gg, nb_groups, n, jcp.mb, icc, ic_chunks,
                    id_s, jcp.id, ih_s, jcp.ih);
        else if (jcp.loop_order == loop_nhwcg)
            nd_iterator_init(start, n, jcp.mb, id_s, jcp.id, ih_s, jcp.ih, icc,
                    ic_chunks, gg, nb_groups);
        else
            assert(!"unsupported loop order" );

        while (start < end) {
            int icb = icc * jcp.nb_ic_blocking;
            int g = gg * g_blocking;
            int work_rem = end - start;
            // Consecutive input rows handled by the inner loop (bounded by
            // ih); nhwcg order advances a single row per step instead.
            int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
            if (jcp.loop_order == loop_nhwcg) ih_e = ih_s + 1; //step instead

            // Depth range: od_s is the first contributing diff_dst slice,
            // kd_lo the first valid filter tap, kd_len the number of taps.
            int od_s = 0, kd_len = 0, kd_lo = 0;
            if (is_fast_path_d) {
                // dilate == 0 && stride == 1: simple overflow counts.
                int d_t_overflow = max(0, jcp.kd - 1 - id_s - jcp.f_pad);
                int d_b_overflow
                        = max(0, jcp.kd - jcp.id + id_s - jcp.back_pad);
                kd_len = jcp.kd - d_t_overflow - d_b_overflow;
                kd_lo = d_b_overflow;
                od_s = id_s + jcp.f_pad - d_b_overflow;
            } else if (jcp.dilate_d != 0) { // stride == 1
                int dilate_d = jcp.dilate_d + 1;
                // Note: use div_up to account for "holes" in filter
                int d_t_overflow = div_up(
                        max(0, (jcp.kd - 1) * dilate_d - id_s - jcp.f_pad),
                        dilate_d);
                int d_b_overflow
                        = div_up(max(0,
                                         (jcp.kd - 1) * dilate_d + 1 - jcp.id
                                                 + id_s - jcp.back_pad),
                                dilate_d);
                kd_len = jcp.kd - d_t_overflow - d_b_overflow;
                kd_lo = d_b_overflow;
                od_s = id_s + jcp.f_pad - d_b_overflow * dilate_d;
            } else { // dilate == 0
                // Strided case: only filter taps aligned with the stride
                // contribute, hence the modulo arithmetic below.
                int d_t_overflow = max(
                        0, (jcp.kd - 1 - id_s - jcp.f_pad) / jcp.stride_d);
                int d_b_overflow = max(0,
                        (jcp.kd - jcp.id + id_s - jcp.back_pad) / jcp.stride_d);
                int overflow_kd_hi = jcp.kd - 1
                        - modulo(
                                jcp.id - 1 + jcp.back_pad - id_s, jcp.stride_d);
                int overflow_kd_lo = (id_s + jcp.f_pad) % jcp.stride_d;

                kd_len = (overflow_kd_hi - overflow_kd_lo) / jcp.stride_d + 1
                        - d_t_overflow - d_b_overflow;
                kd_lo = overflow_kd_lo + d_b_overflow * jcp.stride_d;
                od_s = (id_s + jcp.f_pad - kd_lo) / jcp.stride_d;
            }
            assert(kd_len >= 0);

            // Channel indexing differs between plain (ndhwc) and blocked
            // layouts: ndhwc addresses raw channels, blocked addresses blocks.
            const bool is_dsrc_layout_nxc = jcp.src_tag == format_tag::ndhwc;
            const int ic_idx = is_dsrc_layout_nxc
                    ? g * jcp.ic + icb * jcp.ic_block
                    : g * jcp.nb_ic + icb;
            auto diff_src_w = diff_src
                    + jcp.typesize_out * diff_src_d.blk_off(n, ic_idx, id_s);
            const bool is_ddst_layout_nxc = jcp.dst_tag == format_tag::ndhwc;
            const int oc_idx = is_ddst_layout_nxc ? g * jcp.oc : g * jcp.nb_oc;
            auto diff_dst_w = diff_dst + diff_dst_d.blk_off(n, oc_idx, od_s);
            auto wht_w = weights + wht_blk_off(weights_d, g, 0, icb, kd_lo);

            for (int ij = ih_s; ij < ih_e; ++ij) {
                // Height range per input row; same three cases as for depth.
                int oj, kh_len, kh_lo;
                if (is_fast_path_h) { // dilate == 0 && stride == 1
                    int i_t_overflow = max(0, jcp.kh - 1 - ij - jcp.t_pad);
                    int i_b_overflow = max(0, jcp.kh - jcp.ih + ij - jcp.b_pad);
                    kh_len = jcp.kh - i_t_overflow - i_b_overflow;
                    kh_lo = i_b_overflow;
                    oj = ij + jcp.t_pad - i_b_overflow;
                } else if (jcp.dilate_h != 0) { // stride == 1
                    int dilate_h = jcp.dilate_h + 1;
                    // Note: use div_up to account for "holes" in filter
                    int i_t_overflow = div_up(
                            max(0, (jcp.kh - 1) * dilate_h - ij - jcp.t_pad),
                            dilate_h);
                    int i_b_overflow
                            = div_up(max(0,
                                             (jcp.kh - 1) * dilate_h + 1
                                                     - jcp.ih + ij - jcp.b_pad),
                                    dilate_h);
                    kh_len = jcp.kh - i_t_overflow - i_b_overflow;
                    kh_lo = i_b_overflow;
                    oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
                } else { // dilate == 0
                    int i_t_overflow = max(
                            0, (jcp.kh - 1 - ij - jcp.t_pad) / jcp.stride_h);
                    int i_b_overflow = max(0,
                            (jcp.kh - jcp.ih + ij - jcp.b_pad) / jcp.stride_h);
                    int overflow_kh_hi = jcp.kh - 1
                            - modulo(jcp.ih - 1 + jcp.b_pad - ij, jcp.stride_h);
                    int overflow_kh_lo = (ij + jcp.t_pad) % jcp.stride_h;

                    kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
                            + 1 - i_t_overflow - i_b_overflow;
                    kh_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
                    oj = (ij + jcp.t_pad - kh_lo) / jcp.stride_h;
                }
                assert(kh_len >= 0);

                // Clip the ic span against the channel tail.
                par_conv.load_work = utils::this_block_size(icb * jcp.ic_block,
                        jcp.ic, jcp.nb_ic_blocking * jcp.ic_block);
                par_conv.src = diff_src_w
                        + jcp.typesize_out * ij * diff_src_h_stride;
                par_conv.dst = diff_dst_w + oj * diff_dst_h_stride;
                par_conv.filt = wht_w + kh_lo * wht_h_stride;
                par_conv.kh_padding = kh_len;
                par_conv.kd_padding = kd_len;

                (*kernel_)(&par_conv);
            }

            if (jcp.loop_order == loop_cgn)
                nd_iterator_jump(start, end, icc, ic_chunks, gg, nb_groups, n,
                        jcp.mb, id_s, jcp.id, ih_s, jcp.ih);
            else if (jcp.loop_order == loop_gnc)
                nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, icc,
                        ic_chunks, id_s, jcp.id, ih_s, jcp.ih);
            else if (jcp.loop_order == loop_nhwcg) {
                ++start;
                nd_iterator_step(n, jcp.mb, id_s, jcp.id, ih_s, jcp.ih, icc,
                        ic_chunks, gg, nb_groups);
            } else
                assert(!"unsupported loop order" );
        }
    });
}
584 | |
// Backward-data 1D/2D bf16 convolution driver (ndims selects the blk_off
// arity below). For each diff_src row the valid filter-height range
// (k_lo/k_len) is derived from padding, stride and dilation, and the JIT
// kernel is invoked once per input row.
void jit_avx512_core_bf16_convolution_bwd_data_t ::execute_backward_data(
        const exec_ctx_t &ctx) const {
    auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const auto &jcp = pd()->jcp_;

    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
        // TODO: experiment with g_blocking for perf fine tuning
        int g_blocking = 1;
        int nb_groups = jcp.ngroups / g_blocking;
        // Work is partitioned over (groups, mb, ic chunks, ih, iw blocks).
        int work_amount = nb_groups * jcp.mb * ic_chunks * jcp.ih * jcp.nb_iw;
        balance211(work_amount, nthr, ithr, start, end);

        auto par_conv = jit_conv_call_s();
        // Strides (in layout elements) along the height dimension.
        size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1);
        size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1);
        size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);

        bool is_fast_path = jcp.dilate_h == 0 && jcp.stride_h == 1;

        int n {0}, gg {0}, icc {0}, ih_s {0}, iwb {0};
        if (jcp.loop_order == loop_cwgn)
            nd_iterator_init(start, icc, ic_chunks, iwb, jcp.nb_iw, gg,
                    nb_groups, n, jcp.mb, ih_s, jcp.ih);
        else if (jcp.loop_order == loop_gncw)
            nd_iterator_init(start, gg, nb_groups, n, jcp.mb, icc, ic_chunks,
                    iwb, jcp.nb_iw, ih_s, jcp.ih);
        else if (jcp.loop_order == loop_nhwcg)
            nd_iterator_init(start, n, jcp.mb, ih_s, jcp.ih, iwb, jcp.nb_iw,
                    icc, ic_chunks, gg, nb_groups);
        else
            assert(!"unsupported loop order" );

        while (start < end) {
            int icb = icc * jcp.nb_ic_blocking;
            int g = gg * g_blocking;
            int work_rem = end - start;
            // Consecutive input rows handled by the inner loop (bounded by
            // ih); nhwcg order advances a single row per step instead.
            int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
            if (jcp.loop_order == loop_nhwcg) ih_e = ih_s + 1; //step instead
            int iw_s = iwb * jcp.iw_block;
            int ow_s = iw_s / jcp.stride_w;

            auto diff_src_w = diff_src;
            auto diff_dst_w = diff_dst;
            // Channel indexing differs between plain (nwc/nhwc) and blocked
            // layouts: plain addresses raw channels, blocked addresses blocks.
            const bool is_ddst_layout_nxc = utils::one_of(
                    jcp.dst_tag, format_tag::nwc, format_tag::nhwc);
            const int oc_idx = is_ddst_layout_nxc ? g * jcp.oc : g * jcp.nb_oc;
            const bool is_dsrc_layout_nxc = utils::one_of(
                    jcp.src_tag, format_tag::nwc, format_tag::nhwc);
            const int ic_idx = is_dsrc_layout_nxc
                    ? g * jcp.ic + icb * jcp.ic_block
                    : g * jcp.nb_ic + icb;
            // 1D tensors have no height coordinate in blk_off.
            if (jcp.ndims == 3) {
                diff_src_w += jcp.typesize_out
                        * diff_src_d.blk_off(n, ic_idx, iw_s);
                diff_dst_w += diff_dst_d.blk_off(n, oc_idx, ow_s);
            } else {
                diff_src_w += jcp.typesize_out
                        * diff_src_d.blk_off(n, ic_idx, 0, iw_s);
                diff_dst_w += diff_dst_d.blk_off(n, oc_idx, 0, ow_s);
            }
            auto wht_w = weights + wht_blk_off(weights_d, g, 0, icb);

            for (int ij = ih_s; ij < ih_e; ++ij) {
                // Valid filter-height range for this input row: k_lo is the
                // first valid tap, k_len the tap count, oj the first
                // contributing diff_dst row.
                int oj, k_len, k_lo;
                if (is_fast_path) { // dilate == 0 && stride == 1
                    int i_t_overflow = max(0, jcp.kh - 1 - ij - jcp.t_pad);
                    int i_b_overflow = max(0, jcp.kh - jcp.ih + ij - jcp.b_pad);
                    k_len = jcp.kh - i_t_overflow - i_b_overflow;
                    k_lo = i_b_overflow;
                    oj = ij + jcp.t_pad - i_b_overflow;
                } else if (jcp.dilate_h != 0) { // stride == 1
                    int dilate_h = jcp.dilate_h + 1;
                    // Note: use div_up to account for "holes" in filter
                    int i_t_overflow = div_up(
                            max(0, (jcp.kh - 1) * dilate_h - ij - jcp.t_pad),
                            dilate_h);
                    int i_b_overflow
                            = div_up(max(0,
                                             (jcp.kh - 1) * dilate_h + 1
                                                     - jcp.ih + ij - jcp.b_pad),
                                    dilate_h);
                    k_len = jcp.kh - i_t_overflow - i_b_overflow;
                    k_lo = i_b_overflow;
                    oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
                } else { // dilate == 0
                    // Strided case: only filter taps aligned with the stride
                    // contribute, hence the modulo arithmetic below.
                    int i_t_overflow = max(
                            0, (jcp.kh - 1 - ij - jcp.t_pad) / jcp.stride_h);
                    int i_b_overflow = max(0,
                            (jcp.kh - jcp.ih + ij - jcp.b_pad) / jcp.stride_h);
                    int overflow_kh_hi = jcp.kh - 1
                            - modulo(jcp.ih - 1 + jcp.b_pad - ij, jcp.stride_h);
                    int overflow_kh_lo = (ij + jcp.t_pad) % jcp.stride_h;

                    k_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h + 1
                            - i_t_overflow - i_b_overflow;
                    k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
                    oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
                }
                assert(k_len >= 0);

                // Clip the ic span against the channel tail.
                par_conv.load_work = utils::this_block_size(icb * jcp.ic_block,
                        jcp.ic, jcp.nb_ic_blocking * jcp.ic_block);
                par_conv.src = diff_src_w
                        + jcp.typesize_out * ij * diff_src_h_stride;
                par_conv.dst = diff_dst_w + oj * diff_dst_h_stride;
                par_conv.filt = wht_w + k_lo * wht_h_stride;
                par_conv.kh_padding = k_len;
                par_conv.iwb = iwb;

                (*kernel_)(&par_conv);
            }

            if (jcp.loop_order == loop_cwgn)
                nd_iterator_jump(start, end, icc, ic_chunks, iwb, jcp.nb_iw, gg,
                        nb_groups, n, jcp.mb, ih_s, jcp.ih);
            else if (jcp.loop_order == loop_gncw)
                nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, icc,
                        ic_chunks, iwb, jcp.nb_iw, ih_s, jcp.ih);
            else if (jcp.loop_order == loop_nhwcg) {
                ++start;
                nd_iterator_step(n, jcp.mb, ih_s, jcp.ih, iwb, jcp.nb_iw, icc,
                        ic_chunks, gg, nb_groups);
            } else
                assert(!"unsupported loop order" );
        }
    });
}
721 | |
// One-time initialization for the backward-weights primitive: caches the
// thread decomposition chosen at pd time and generates all JIT kernels
// (main kernel, optional src/dst transpose kernels, and the f32
// accumulator used to reduce per-thread partial results when the
// minibatch dimension is split across threads).
status_t jit_avx512_core_bf16_convolution_bwd_weights_t ::init(
        engine_t *engine) {
    const auto &j = pd()->jcp_;

    // Thread decomposition over mb / groups / oc blocks / ic blocks.
    nthr_ = j.nthr;
    nthr_mb_ = j.nthr_mb;
    nthr_g_ = j.nthr_g;
    nthr_oc_b_ = j.nthr_oc_b;
    nthr_ic_b_ = j.nthr_ic_b;

    CHECK(safe_ptr_assign(
            kernel_, new jit_avx512_core_bf16_conv_bwd_weights_kernel_f32(j)));
    CHECK(kernel_->create_kernel());

    // Optional layout-transpose kernels for src and diff_dst.
    if (j.transpose_src) {
        CHECK(safe_ptr_assign(trans_kernel_, create_trans_src(&j)));
        CHECK(trans_kernel_->create_kernel());
    }
    if (j.transpose_dst) {
        CHECK(safe_ptr_assign(trans_dst_kernel_, create_trans_dst(&j)));
        CHECK(trans_dst_kernel_->create_kernel());
    }
    // Reduction across mb-split threads accumulates in f32.
    if (nthr_mb_ > 1) {
        CHECK(safe_ptr_assign(
                acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
        CHECK(acc_ker_->create_kernel());
    }

    return status::success;
}
752 | |
// Per-thread bookkeeping for backward-by-weights: raw tensor pointers,
// scratchpad-derived buffers/barriers, and this thread's slice of the
// (mb, group, oc-block, ic-block) work decomposition.
struct jit_avx512_core_bf16_convolution_bwd_weights_t ::thread_info_t {
    const src_data_t *src = nullptr;
    const diff_dst_data_t *diff_dst = nullptr;
    const void *diff_weights = nullptr;
    const void *diff_bias = nullptr;

    // Copied by value so each thread can query the grantor independently.
    const memory_tracking::grantor_t scratchpad;

    // Transposed copies of src/diff_dst (set only when jcp.transpose_*).
    src_data_t *tr_src = nullptr;
    diff_dst_data_t *tr_diff_dst = nullptr;
    // Barriers coordinating threads that cooperate on filling the shared
    // transposed buffers (only under jcp.global_transpose).
    simple_barrier::ctx_t *tr_src_bctx = nullptr;
    simple_barrier::ctx_t *tr_diff_dst_bctx = nullptr;

    // f32 scratch buffers for reducing partial weights/bias over mb threads.
    float *wei_bia_reduction = nullptr;
    float *bia_reduction = nullptr;
    simple_barrier::ctx_t *wei_bia_reduction_bctx = nullptr;

    int ithr = 0;
    // Coordinates of this thread inside the 4D thread grid
    // (ic-block, oc-block, group, minibatch).
    int ithr_ic_b = 0, ithr_oc_b = 0, ithr_g = 0, ithr_mb = 0;
    // Linear ids with the oc (resp. ic) dimension collapsed; used to index
    // the per-row transpose barriers above.
    int ithr_but_oc = 0;
    int ithr_but_ic = 0;

    // This thread's half-open [start, end) slice per dimension
    // (work = end - start).
    int img_start = 0, img_end = 0, img_work = 0;
    int g_start = 0, g_end = 0, g_work = 0;
    int oc_b_start = 0, oc_b_end = 0, oc_b_work = 0;
    int ic_b_start = 0, ic_b_end = 0, ic_b_work = 0;

    thread_info_t(const jit_avx512_core_bf16_convolution_bwd_weights_t *self,
            const exec_ctx_t &ctx, int ithr)
        : scratchpad(ctx.get_scratchpad_grantor()), ithr(ithr) {
        diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
        src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
        diff_weights = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_WEIGHTS);

        const auto &jcp = self->kernel_->jcp;
        // f32 bias with an oc tail accumulates into an oc_block-padded
        // scratchpad buffer instead of the (unpadded) user buffer.
        diff_bias = self->pd()->with_bias()
                        && (jcp.oc_without_padding % jcp.oc_block != 0)
                        && self->pd()->jcp_.bia_dt == data_type::f32
                ? (void *)scratchpad.template get<float>(key_conv_padded_bias)
                : CTX_OUT_MEM(void *, DNNL_ARG_DIFF_BIAS);

        if (jcp.transpose_src) {
            tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);

            if (jcp.global_transpose)
                tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                        key_conv_tr_src_bctx);
        }
        if (jcp.transpose_dst) {
            tr_diff_dst = scratchpad.template get<diff_dst_data_t>(
                    key_conv_tr_diff_dst);

            if (jcp.global_transpose)
                tr_diff_dst_bctx
                        = scratchpad.template get<simple_barrier::ctx_t>(
                                key_conv_tr_diff_dst_bctx);
        }
        wei_bia_reduction
                = scratchpad.template get<float>(key_conv_wei_bia_reduction);
        bia_reduction = nullptr;
        if (jcp.with_bias) {
            const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block
                    * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
            // bf16 weights need one extra buffer (thread 0 cannot write the
            // user's bf16 buffer directly); f32 uses one fewer.
            const int num_wei_buffers = jcp.wei_dt == data_type::bf16
                    ? jcp.nthr_mb
                    : jcp.nthr_mb - 1;
            // Bias reduction space is laid out right after the weights one.
            bia_reduction = wei_bia_reduction + wei_size * num_wei_buffers;
        }

        if (jcp.global_transpose)
            wei_bia_reduction_bctx
                    = scratchpad.template get<simple_barrier::ctx_t>(
                            key_conv_wei_bia_reduction_bctx);

        // Decompose the flat thread id; ic-block varies fastest, then
        // oc-block, group, and minibatch.
        ithr_ic_b = ithr % self->nthr_ic_b_;
        ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_;
        ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_;
        ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_;

        ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_
                + ithr_ic_b;

        ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_
                + ithr_oc_b;

        int work_amount = jcp.nthr_mb_work;
        /* reduction dimension */
        balance211(work_amount, self->nthr_mb_, ithr_mb, img_start, img_end);
        img_work = img_end - img_start;

        /* independent dimensions */
        balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end);
        g_work = g_end - g_start;

        balance211(
                jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start, oc_b_end);
        oc_b_work = oc_b_end - oc_b_start;

        balance211(
                jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start, ic_b_end);
        ic_b_work = ic_b_end - ic_b_start;
    }
};
856 | |
857 | size_t jit_avx512_core_bf16_convolution_bwd_weights_t::tr_src_buf_number( |
858 | const thread_info_t *ti, int g, int ic) const { |
859 | const jit_conv_conf_t &jcp = this->kernel_->jcp; |
860 | return jcp.global_transpose |
861 | ? ti->ithr_mb * jcp.nb_ic * jcp.ngroups + g * jcp.nb_ic + ic |
862 | : ti->ithr; |
863 | } |
864 | size_t jit_avx512_core_bf16_convolution_bwd_weights_t::tr_diff_dst_buf_number( |
865 | const thread_info_t *ti, int g, int oc) const { |
866 | const jit_conv_conf_t &jcp = this->kernel_->jcp; |
867 | return jcp.global_transpose |
868 | ? ti->ithr_mb * jcp.nb_oc * jcp.ngroups + g * jcp.nb_oc + oc |
869 | : ti->ithr; |
870 | } |
871 | |
// Transpose `row_count` rows of blocked-layout src into the tr_src layout
// expected by the bwd-weights kernel, software-pipelining with a 2-deep
// circular buffer so each JIT call also carries prefetch pointers for the
// next row.
void jit_avx512_core_bf16_convolution_bwd_weights_t ::trans_src(
        src_data_t *tr_src, const src_data_t *src, int row_count) const {
    const jit_conv_conf_t &jcp = this->kernel_->jcp;
    const int pf_depth = 2; // prefetch pipeline depth
    struct {
        const src_data_t *src;
        src_data_t *tr_src;
    } pf_circ_buf_src[pf_depth];

    assert(jcp.ic_block == 16 || jcp.is_1stconv);
    const int src_stride = jcp.iw * jcp.ic_block;
    const int tr_src_stride = jcp.tr_iw * jcp.ic_block;

    // Run pf_depth - 1 extra iterations to drain the pipeline.
    for (int iwork = 0; iwork < row_count + pf_depth - 1; iwork++) {
        pf_circ_buf_src[iwork % pf_depth] = {src, tr_src};

        if (iwork >= pf_depth - 1) {
            int old_idx = (iwork - pf_depth + 1) % pf_depth;
            auto ctx = jit_trans_src_t::ctx_t();
            // Transpose the oldest queued row; the current row's pointers
            // act as prefetch hints.
            ctx.src = pf_circ_buf_src[old_idx].src;
            ctx.tr_src = pf_circ_buf_src[old_idx].tr_src;
            ctx.src_prf = src;
            ctx.tr_src_prf = tr_src;
            (*trans_kernel_)(&ctx);
        }
        src += src_stride;
        tr_src += tr_src_stride;
    }
}
901 | |
// Transpose src rows stored in plain nxc (channels-last) layout into the
// blocked tr_src layout. Starting at spatial position `spatial_start`
// (`spatial_start_offset` elements into `src_base`), converts `row_count`
// rows; whenever the spatial extent is exhausted it advances to the next
// ic block, `chb_stride` elements further into `src_base`.
void jit_avx512_core_bf16_convolution_bwd_weights_t::trans_src_nxc(
        src_data_t *tr_src, const src_data_t *src_base, int spatial_start,
        dim_t spatial_start_offset, int icb_start, dim_t chb_stride,
        int row_count) const {
    const jit_conv_conf_t &jcp = this->kernel_->jcp;
    const int src_stride = jcp.iw * jcp.ngroups * jcp.ic;
    const int tr_src_stride = jcp.tr_iw * jcp.ic_block;

    int work_rest = row_count;
    int max_spatial_work = jcp.id * jcp.ih;
    int sp_work = nstl::min(work_rest, max_spatial_work - spatial_start);
    const src_data_t *src = src_base + spatial_start_offset;
    int icb = 0;
    // The last ic block may carry a tail narrower than a full ic_block.
    const int ic_tail_work = jcp.ic_tail ? jcp.ic_tail : jcp.ic_block;
    while (work_rest > 0) {
        for (int iwork = 0; iwork < sp_work; iwork++) {
            auto ctx = jit_trans_src_t::ctx_t();
            ctx.src = src;
            ctx.tr_src = tr_src;
            assert(icb_start + icb < jcp.nb_ic);
            // Only the last ic block uses the tail width.
            ctx.ch_work = (icb_start + icb + 1) == jcp.nb_ic ? ic_tail_work
                                                             : jcp.ic_block;
            ctx.src_prf = nullptr;
            ctx.tr_src_prf = nullptr;
            (*trans_kernel_)(&ctx);
            src += src_stride;
            tr_src += tr_src_stride;
        }
        work_rest -= sp_work;
        sp_work = nstl::min(work_rest, max_spatial_work);
        icb++;
        src = src_base + icb * chb_stride;
    }
}
936 | |
// Transpose `row_count` rows of blocked-layout diff_dst into the
// tr_diff_dst layout, using the same 2-deep software-prefetch pipeline as
// trans_src().
void jit_avx512_core_bf16_convolution_bwd_weights_t ::trans_dst(
        diff_dst_data_t *tr_diff_dst, const diff_dst_data_t *diff_dst,
        int row_count) const {

    const jit_conv_conf_t &jcp = this->kernel_->jcp;
    const int pf_depth = 2; // prefetch pipeline depth
    struct {
        const diff_dst_data_t *diff_dst;
        diff_dst_data_t *tr_diff_dst;
    } pf_circ_buf_dst[pf_depth];

    // NOTE(review): this checks ic_block although the routine walks oc
    // blocks — presumably ic_block == oc_block (== 16) here; confirm.
    assert(jcp.ic_block == 16 || jcp.is_1stconv);
    const int diff_dst_stride = jcp.ow * jcp.oc_block;
    const int tr_diff_dst_stride = jcp.tr_ow * jcp.oc_block;

    // Run pf_depth - 1 extra iterations to drain the pipeline.
    for (int iwork = 0; iwork < row_count + pf_depth - 1; iwork++) {
        pf_circ_buf_dst[iwork % pf_depth]
                = {(diff_dst_data_t *)diff_dst, tr_diff_dst};

        if (iwork >= pf_depth - 1) {
            int old_idx = (iwork - pf_depth + 1) % pf_depth;
            auto ctx = jit_trans_dst_t::ctx_t();
            // Transpose the oldest queued row; the current row's pointers
            // act as prefetch hints.
            ctx.src = pf_circ_buf_dst[old_idx].diff_dst;
            ctx.tr_src = pf_circ_buf_dst[old_idx].tr_diff_dst;
            ctx.src_prf = diff_dst;
            ctx.tr_src_prf = tr_diff_dst;
            (*trans_dst_kernel_)(&ctx);
        }
        diff_dst += diff_dst_stride;
        tr_diff_dst += tr_diff_dst_stride;
    }
}
969 | |
970 | void jit_avx512_core_bf16_convolution_bwd_weights_t::trans_dst_nxc( |
971 | diff_dst_data_t *tr_diff_dst, const diff_dst_data_t *diff_dst_base, |
972 | int spatial_start, dim_t spatial_start_offset, int ocb_start, |
973 | dim_t chb_stride, int row_count) const { |
974 | const jit_conv_conf_t &jcp = this->kernel_->jcp; |
975 | const int diff_dst_stride = jcp.ow * jcp.ngroups * jcp.oc; |
976 | const int tr_diff_dst_stride = jcp.tr_ow * jcp.oc_block; |
977 | int work_rest = row_count; |
978 | int max_spatial_work = jcp.od * jcp.oh; |
979 | int sp_work = nstl::min(work_rest, max_spatial_work - spatial_start); |
980 | const src_data_t *diff_dst = diff_dst_base + spatial_start_offset; |
981 | int ocb = 0; |
982 | const int oc_tail_work = jcp.oc_tail ? jcp.oc_tail : jcp.oc_block; |
983 | while (work_rest > 0) { |
984 | for (int iwork = 0; iwork < sp_work; iwork++) { |
985 | auto ctx = jit_trans_dst_t::ctx_t(); |
986 | ctx.src = diff_dst; |
987 | ctx.tr_src = tr_diff_dst; |
988 | assert(ocb_start + ocb < jcp.nb_oc); |
989 | ctx.ch_work = (ocb_start + ocb + 1) == jcp.nb_oc ? oc_tail_work |
990 | : jcp.oc_block; |
991 | ctx.src_prf = nullptr; |
992 | ctx.tr_src_prf = nullptr; |
993 | (*trans_dst_kernel_)(&ctx); |
994 | diff_dst += diff_dst_stride; |
995 | tr_diff_dst += tr_diff_dst_stride; |
996 | } |
997 | work_rest -= sp_work; |
998 | sp_work = nstl::min(work_rest, max_spatial_work); |
999 | ocb++; |
1000 | diff_dst = diff_dst_base + ocb * chb_stride; |
1001 | } |
1002 | } |
1003 | |
// Backward-by-weights for 2D convolution. Each thread walks its slice of
// the flattened (mb, oh) reduction space; per chunk it (optionally)
// transposes src/diff_dst into scratchpad layouts — cooperatively under
// jcp.global_transpose, locally otherwise — and then invokes the JIT
// kernel for every (g, oc_b, ic_b, oh-block) tile it owns.
void jit_avx512_core_bf16_convolution_bwd_weights_t ::compute_diff_weights_2d(
        const thread_info_t *ti) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
    const auto &jcp = kernel_->jcp;
    const bool is_src_layout_nxc = jcp.src_tag == format_tag::nhwc;
    const bool is_ddst_layout_nxc = jcp.dst_tag == format_tag::nhwc;

    const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
            * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
    const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
    const int optimal_hblock = jcp.spatial_blk_size;

    // bf16 diff_weights always accumulate into the f32 scratchpad (the
    // result is converted after the reduction); for f32, mb-thread 0 writes
    // directly into the user buffer and only threads > 0 use scratchpad.
    float *diff_wei;
    if (diff_weights_d.data_type() == data_type::bf16)
        diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size;
    else
        diff_wei = ti->ithr_mb == 0
                ? (float *)ti->diff_weights
                : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;

    // Same buffer-selection scheme for the bias accumulator.
    float *diff_bias = nullptr;
    if (jcp.with_bias) {
        if (jcp.bia_dt == data_type::bf16)
            diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size;
        else
            diff_bias = ti->ithr_mb == 0
                    ? (float *)ti->diff_bias
                    : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size;
    }

    // Element offset of row `oj` inside the transposed diff_dst buffer
    // assigned to (g, oc).
    auto tr_diff_dst_off = [&](int g, int oc, int oj) {
        const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
        return tr_diff_dst_buf_number(ti, g, oc) * jcp.tr_diff_dst_buf_size
                + oj * tr_row_size;
    };

    int img {0}, oh_s {0};
    int start = ti->img_start;
    int end = ti->img_end;

    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);

    nd_iterator_init(start, img, jcp.mb, oh_s, jcp.oh);

    // One iteration per (img, oh-chunk) this thread owns.
    while (start < end) {
        auto p = jit_conv_call_s();
        int work_rem = end - start;
        // Clamp the oh chunk to the current image.
        const int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
        // Input rows touched by [oh_s, oh_e), accounting for top padding and
        // the dilated filter extent.
        int ih_s = nstl::max(0, -jcp.t_pad + oh_s * jcp.stride_h);
        const int ih_e = nstl::min(
                jcp.ih, -jcp.t_pad + (oh_e - 1) * jcp.stride_h + ext_kh);

        auto tr_src_off = [&](int g, int ic, int ih_end, int ij) {
            const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
            // Aligned to buffer end to use guard elements
            return tr_src_buf_number(ti, g, ic) * jcp.tr_src_buf_size
                    + (jcp.ih - ih_end + ij) * tr_row_size;
        };

        if (jcp.global_transpose) {
            using simple_barrier::barrier;
            // TODO: try to call local transpositions just before jit kernel
            /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */
            // Threads differing only in their oc coordinate cooperate on
            // transposing src; barriers bracket use of the shared buffer.
            if (jcp.transpose_src) {
                int j {0};
                const int work_amount
                        = ti->g_work * ti->ic_b_work * (ih_e - ih_s);
                int tr_start {0}, tr_end {0};
                balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, tr_start,
                        tr_end);

                int g {0}, ic_b {0};
                nd_iterator_init(tr_start, g, ti->g_work, ic_b, ti->ic_b_work,
                        j, ih_e - ih_s);

                if (nthr_oc_b_ > 1)
                    barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
                while (tr_start < tr_end) {
                    int g_ = g + ti->g_start;
                    int ic_b_ = ic_b + ti->ic_b_start;
                    int j_s = j + ih_s;
                    int j_e = j_s + nstl::min(tr_end - tr_start, ih_e - j_s);
                    const int ic_off_idx = is_src_layout_nxc
                            ? g_ * jcp.ic + ic_b_ * jcp.ic_block
                            : g_ * jcp.nb_ic + ic_b_;
                    const src_data_t *src
                            = &ti->src[src_d.blk_off(img, ic_off_idx, j_s)];
                    src_data_t *tr_src
                            = &ti->tr_src[tr_src_off(g_, ic_b_, ih_e, j_s)];

                    if (is_src_layout_nxc)
                        trans_src_nxc(tr_src, src, 0, 0, ic_b_, 0, j_e - j_s);
                    else
                        trans_src(tr_src, src, j_e - j_s);

                    nd_iterator_jump(tr_start, tr_end, g, ti->g_work, ic_b,
                            ti->ic_b_work, j, ih_e - ih_s);
                }
                if (nthr_oc_b_ > 1)
                    barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
            }
            // Symmetric scheme: threads differing only in ic cooperate on
            // transposing diff_dst.
            if (jcp.transpose_dst) {
                int j {0};
                const int work_amount
                        = ti->g_work * ti->oc_b_work * (oh_e - oh_s);
                int tr_start {0}, tr_end {0};
                balance211(work_amount, nthr_ic_b_, ti->ithr_ic_b, tr_start,
                        tr_end);

                int g {0}, oc_b {0};
                nd_iterator_init(tr_start, g, ti->g_work, oc_b, ti->oc_b_work,
                        j, oh_e - oh_s);

                if (nthr_ic_b_ > 1)
                    barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
                while (tr_start < tr_end) {
                    int g_ = g + ti->g_start;
                    int oc_b_ = oc_b + ti->oc_b_start;
                    int j_s = j + oh_s;
                    int j_e = j_s + nstl::min(tr_end - tr_start, oh_e - j_s);
                    const int oc_off_idx = is_ddst_layout_nxc
                            ? g_ * jcp.oc + oc_b_ * jcp.oc_block
                            : g_ * jcp.nb_oc + oc_b_;
                    const diff_dst_data_t *diff_dst
                            = &ti->diff_dst[diff_dst_d.blk_off(
                                    img, oc_off_idx, j_s)];
                    diff_dst_data_t *tr_diff_dst
                            = &ti->tr_diff_dst[tr_diff_dst_off(g_, oc_b_, j_s)];

                    if (is_ddst_layout_nxc)
                        trans_dst_nxc(tr_diff_dst, diff_dst, 0, 0, oc_b_, 0,
                                j_e - j_s);
                    else
                        trans_dst(tr_diff_dst, diff_dst, j_e - j_s);

                    nd_iterator_jump(tr_start, tr_end, g, ti->g_work, oc_b,
                            ti->oc_b_work, j, oh_e - oh_s);
                }
                if (nthr_ic_b_ > 1)
                    barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
            }
        }
        // With global transpose the whole chunk must be processed at once
        // (the shared buffers hold exactly [oh_s, oh_e)); otherwise use the
        // tuned spatial block size.
        int height_block = jcp.global_transpose ? oh_e - oh_s : optimal_hblock;
        int ic_b_step
                = jcp.uses_permw_transposition ? jcp.nb_ic_blocking_max : 1;
        int icb_work = ti->ic_b_end - ti->ic_b_start;
        // Split an awkward remainder (between 1x and 2x of the step) into
        // two roughly equal steps.
        if (ic_b_step > 1 && icb_work > ic_b_step && icb_work < 2 * ic_b_step)
            ic_b_step = utils::div_up(icb_work, 2);

        for_(int ohb_s = oh_s; ohb_s < oh_e; ohb_s += height_block)
        for_(int g = ti->g_start; g < ti->g_end; ++g)
        for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b)
        for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end;
                ic_b += ic_b_step) {
            const int ohb_e = nstl::min(ohb_s + height_block, oh_e);
            const int ihb_s = nstl::max(0, -jcp.t_pad + ohb_s * jcp.stride_h);
            const int ihb_e = nstl::min(
                    jcp.ih, -jcp.t_pad + (ohb_e - 1) * jcp.stride_h + ext_kh);
            assert(IMPLICATION(jcp.global_transpose,
                    oh_s == ohb_s && oh_e == ohb_e && ih_s == ihb_s
                            && ih_e == ihb_e));
            const int ic_off_idx = is_src_layout_nxc
                    ? g * jcp.ic + ic_b * jcp.ic_block
                    : g * jcp.nb_ic + ic_b;
            const int oc_off_idx = is_ddst_layout_nxc
                    ? g * jcp.oc + oc_b * jcp.oc_block
                    : g * jcp.nb_oc + oc_b;
            const int ic_to_compute = this_block_size(
                    ic_b * jcp.ic_block, jcp.ic, ic_b_step * jcp.ic_block);
            const int oc_to_compute = this_block_size(
                    oc_b * jcp.oc_block, jcp.oc, jcp.oc_block);

            // Pick the src pointer: freshly (locally) transposed, already
            // transposed by the cooperative phase above, or raw user memory.
            if (jcp.transpose_src) {
                if (!jcp.global_transpose) {
                    const src_data_t *src
                            = (src_data_t *)&ti->src[src_d.blk_off(
                                    img, ic_off_idx, ihb_s)];
                    src_data_t *tr_src
                            = &ti->tr_src[tr_src_off(g, ic_b, ihb_e, ihb_s)];
                    if (is_src_layout_nxc)
                        trans_src_nxc(
                                tr_src, src, 0, 0, ic_b, 0, ihb_e - ihb_s);
                    else
                        trans_src(tr_src, src, ihb_e - ihb_s);
                    p.src = tr_src;
                } else {
                    p.src = &ti->tr_src[tr_src_off(g, ic_b, ihb_e, ihb_s)];
                }
            } else {
                p.src = &ti->src[src_d.blk_off(img, ic_off_idx, ihb_s)];
            }

            // Same choice for diff_dst. Note the local transpose always
            // targets this thread's private buffer at offset (0, 0, 0).
            if (jcp.transpose_dst) {
                if (!jcp.global_transpose) {
                    const diff_dst_data_t *diff_dst
                            = &ti->diff_dst[diff_dst_d.blk_off(
                                    img, oc_off_idx, ohb_s)];
                    diff_dst_data_t *tr_diff_dst
                            = &ti->tr_diff_dst[tr_diff_dst_off(0, 0, 0)];
                    if (is_ddst_layout_nxc)
                        trans_dst_nxc(tr_diff_dst, diff_dst, 0, 0, oc_b, 0,
                                ohb_e - ohb_s);
                    else
                        trans_dst(tr_diff_dst, diff_dst, ohb_e - ohb_s);
                    p.dst = tr_diff_dst;
                } else {
                    p.dst = &ti->tr_diff_dst[tr_diff_dst_off(g, oc_b, ohb_s)];
                }
            } else {
                p.dst = &ti->diff_dst[diff_dst_d.blk_off(
                        img, oc_off_idx, ohb_s)];
            }

            p.filt = diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b);
            p.bias = diff_bias ? diff_bias + g * rnd_up(jcp.oc, jcp.oc_block)
                            + oc_b * jcp.oc_block
                               : nullptr;
            // True on the very first chunk of this thread's reduction —
            // presumably tells the kernel to initialize rather than
            // accumulate; confirm against the kernel's p.channel handling.
            p.channel = (start == ti->img_start) && (ohb_s == oh_s);
            p.reduce_work = ic_to_compute;
            p.load_work = oc_to_compute;
            p.os_index_begin = ohb_s;
            p.os_index_end = ohb_e;
            p.flags = 0 | (ic_b == 0 ? FLAG_IC_FIRST : 0);
            assert(ohb_e <= jcp.oh);
            (*kernel_)(&p);
        }

        // Advance (img, oh_s) to the next chunk of this thread's range.
        nd_iterator_jump(start, end, img, jcp.mb, oh_s, jcp.oh);
    }
}
1236 | |
// Backward-by-weights for 3D convolution. Same structure as the 2D variant
// but the reduction space is (mb, od): transposes operate on whole
// depth-slices ((d_e - d_s) * ih / oh rows) and the kernel additionally
// receives kd padding/offset for the front/back-padded depth positions.
void jit_avx512_core_bf16_convolution_bwd_weights_t ::compute_diff_weights_3d(
        const thread_info_t *ti) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
    const auto &jcp = kernel_->jcp;
    const bool is_src_layout_nxc = jcp.src_tag == format_tag::ndhwc;
    const bool is_ddst_layout_nxc = jcp.dst_tag == format_tag::ndhwc;
    const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
            * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
    const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
    const int optimal_dblock = jcp.spatial_blk_size;

    // bf16 diff_weights always accumulate into the f32 scratchpad; for f32,
    // mb-thread 0 writes directly into the user buffer.
    float *diff_wei;
    if (diff_weights_d.data_type() == data_type::bf16)
        diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size;
    else
        diff_wei = ti->ithr_mb == 0
                ? (float *)ti->diff_weights
                : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;

    // Same buffer-selection scheme for the bias accumulator.
    float *diff_bias = nullptr;
    if (jcp.with_bias) {
        if (jcp.bia_dt == data_type::bf16)
            diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size;
        else
            diff_bias = ti->ithr_mb == 0
                    ? (float *)ti->diff_bias
                    : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size;
    }

    // Element offset of depth-slice `od` inside the transposed diff_dst
    // buffer assigned to (g, oc).
    auto tr_diff_dst_off_3d = [&](int g, int oc, int od) {
        assert(IMPLICATION(is_ddst_layout_nxc, jcp.transpose_dst));
        const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
        const size_t tr_3d_size = tr_row_size * jcp.oh;
        return tr_diff_dst_buf_number(ti, g, oc) * jcp.tr_diff_dst_buf_size
                + od * tr_3d_size;
    };
    int img {0}, od_s {0};
    int start = ti->img_start;
    int end = ti->img_end;

    int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);

    nd_iterator_init(start, img, jcp.mb, od_s, jcp.od);
    // One iteration per (img, od-chunk) this thread owns.
    while (start < end) {
        auto p = jit_conv_call_s();
        int work_rem = end - start;
        // Clamp the od chunk to the current image.
        const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
        // Input depth range touched by [od_s, od_e), accounting for front
        // padding and the dilated filter extent.
        int id_s = nstl::max(0, -jcp.f_pad + od_s * jcp.stride_d);
        const int id_e = nstl::min(
                jcp.id, -jcp.f_pad + (od_e - 1) * jcp.stride_d + ext_kd);

        auto tr_src_off_3d = [&](int g, int ic, int id_end, int id) {
            const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
            const size_t tr_3d_size = tr_row_size * jcp.ih;
            // Aligned to buffer end to use guard elements
            return tr_src_buf_number(ti, g, ic) * jcp.tr_src_buf_size
                    + (jcp.id - id_end + id) * tr_3d_size;
        };

        if (jcp.global_transpose) {
            using simple_barrier::barrier;
            // TODO: try to call local transpositions just before jit kernel
            /* tr_src[nb_ic][id][16][~iw~] <- src[nb_ic][id][iw][16] */
            // Threads differing only in their oc coordinate cooperate on
            // transposing src; barriers bracket use of the shared buffer.
            if (jcp.transpose_src) {
                int d {0};

                const int work_amount
                        = ti->g_work * ti->ic_b_work * (id_e - id_s);

                int tr_start {0}, tr_end {0};
                balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, tr_start,
                        tr_end);

                int g {0}, ic_b {0};

                nd_iterator_init(tr_start, g, ti->g_work, ic_b, ti->ic_b_work,
                        d, id_e - id_s);

                if (nthr_oc_b_ > 1)
                    barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
                while (tr_start < tr_end) {
                    int g_ = g + ti->g_start;
                    int ic_b_ = ic_b + ti->ic_b_start;
                    int d_s = d + id_s;
                    int d_e = d_s + nstl::min(tr_end - tr_start, id_e - d_s);

                    const int ic_off_idx = is_src_layout_nxc
                            ? g_ * jcp.ic + ic_b_ * jcp.ic_block
                            : g_ * jcp.nb_ic + ic_b_;
                    const src_data_t *src
                            = &ti->src[src_d.blk_off(img, ic_off_idx, d_s)];
                    src_data_t *tr_src
                            = &ti->tr_src[tr_src_off_3d(g_, ic_b_, id_e, d_s)];

                    // Whole depth slices: ih rows per depth position.
                    if (is_src_layout_nxc)
                        trans_src_nxc(tr_src, src, 0, 0, ic_b_, 0,
                                (d_e - d_s) * jcp.ih);
                    else
                        trans_src(tr_src, src, (d_e - d_s) * jcp.ih);

                    nd_iterator_jump(tr_start, tr_end, g, ti->g_work, ic_b,
                            ti->ic_b_work, d, id_e - id_s);
                }
                if (nthr_oc_b_ > 1)
                    barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
            }
            // Symmetric scheme: threads differing only in ic cooperate on
            // transposing diff_dst.
            if (jcp.transpose_dst) {
                int d {0};

                const int work_amount
                        = ti->g_work * ti->oc_b_work * (od_e - od_s);

                int tr_start {0}, tr_end {0};
                balance211(work_amount, nthr_ic_b_, ti->ithr_ic_b, tr_start,
                        tr_end);

                int g {0}, oc_b {0};

                nd_iterator_init(tr_start, g, ti->g_work, oc_b, ti->oc_b_work,
                        d, od_e - od_s);

                if (nthr_ic_b_ > 1)
                    barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
                while (tr_start < tr_end) {
                    int g_ = g + ti->g_start;
                    int oc_b_ = oc_b + ti->oc_b_start;
                    int d_s = d + od_s;
                    int d_e = d_s + nstl::min(tr_end - tr_start, od_e - d_s);
                    const int oc_off_idx = is_ddst_layout_nxc
                            ? g_ * jcp.oc + oc_b_ * jcp.oc_block
                            : g_ * jcp.nb_oc + oc_b_;

                    const diff_dst_data_t *diff_dst
                            = &ti->diff_dst[diff_dst_d.blk_off(
                                    img, oc_off_idx, d_s)];
                    diff_dst_data_t *tr_diff_dst
                            = &ti->tr_diff_dst[tr_diff_dst_off_3d(
                                    g_, oc_b_, d_s)];

                    if (is_ddst_layout_nxc)
                        trans_dst_nxc(tr_diff_dst, diff_dst, 0, 0, oc_b_, 0,
                                (d_e - d_s) * jcp.oh);
                    else
                        trans_dst(tr_diff_dst, diff_dst, (d_e - d_s) * jcp.oh);

                    nd_iterator_jump(tr_start, tr_end, g, ti->g_work, oc_b,
                            ti->oc_b_work, d, od_e - od_s);
                }
                if (nthr_ic_b_ > 1)
                    barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
            }
        }

        // With global transpose the whole chunk must be processed at once;
        // otherwise use the tuned spatial block size.
        int depth_block = jcp.global_transpose ? od_e - od_s : optimal_dblock;

        for_(int odb_s = od_s; odb_s < od_e; odb_s += depth_block)
        for_(int g = ti->g_start; g < ti->g_end; ++g)
        for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b)
        for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
            const int odb_e = nstl::min(odb_s + depth_block, od_e);
            const int idb_s = nstl::max(0, -jcp.f_pad + odb_s * jcp.stride_d);
            const int idb_e = nstl::min(
                    jcp.id, -jcp.f_pad + (odb_e - 1) * jcp.stride_d + ext_kd);
            // Number of kd taps cut off by front/back depth padding for this
            // depth block, and the byte offset of the first valid kd tap in
            // the weights.
            const int kdb_front_pad
                    = nstl::max(0, jcp.f_pad - odb_s * jcp.stride_d);
            // Assumes kd_back_pad = 0 when kernel is dilated
            const int kdb_back_pad = nstl::max(
                    0, odb_s * jcp.stride_d + jcp.kd - jcp.f_pad - jcp.id);
            const int kdb_pad_off = nstl::min(jcp.kd - 1, kdb_front_pad)
                    * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block
                    * jcp.typesize_out;

            assert(IMPLICATION(jcp.global_transpose,
                    od_s == odb_s && od_e == odb_e && id_s == idb_s
                            && id_e == idb_e));
            const int ic_off_idx = is_src_layout_nxc
                    ? g * jcp.ic + ic_b * jcp.ic_block
                    : g * jcp.nb_ic + ic_b;
            const int oc_off_idx = is_ddst_layout_nxc
                    ? g * jcp.oc + oc_b * jcp.oc_block
                    : g * jcp.nb_oc + oc_b;
            const int ic_to_compute = this_block_size(
                    ic_b * jcp.ic_block, jcp.ic, jcp.ic_block);
            const int oc_to_compute = this_block_size(
                    oc_b * jcp.oc_block, jcp.oc, jcp.oc_block);
            // Pick the src pointer: locally transposed, cooperatively
            // transposed above, or raw user memory.
            if (jcp.transpose_src) {
                if (!jcp.global_transpose) {
                    const src_data_t *src
                            = (src_data_t *)&ti->src[src_d.blk_off(
                                    img, ic_off_idx, idb_s)];
                    src_data_t *tr_src
                            = &ti->tr_src[tr_src_off_3d(g, ic_b, idb_e, idb_s)];
                    if (is_src_layout_nxc)
                        trans_src_nxc(tr_src, src, 0, 0, ic_b, 0,
                                (idb_e - idb_s) * jcp.ih);
                    else
                        trans_src(tr_src, src, (idb_e - idb_s) * jcp.ih);
                    p.src = tr_src;
                } else {
                    p.src = &ti->tr_src[tr_src_off_3d(g, ic_b, idb_e, idb_s)];
                }
            } else {
                p.src = &ti->src[src_d.blk_off(img, ic_off_idx, idb_s)];
            }

            // Same choice for diff_dst; local transpose targets this
            // thread's private buffer at offset (0, 0, 0).
            if (jcp.transpose_dst) {
                if (!jcp.global_transpose) {
                    const diff_dst_data_t *diff_dst
                            = &ti->diff_dst[diff_dst_d.blk_off(
                                    img, oc_off_idx, odb_s)];
                    diff_dst_data_t *tr_diff_dst
                            = &ti->tr_diff_dst[tr_diff_dst_off_3d(0, 0, 0)];
                    if (is_ddst_layout_nxc)
                        trans_dst_nxc(tr_diff_dst, diff_dst, 0, 0, oc_b, 0,
                                (odb_e - odb_s) * jcp.oh);
                    else
                        trans_dst(tr_diff_dst, diff_dst,
                                (odb_e - odb_s) * jcp.oh);
                    p.dst = tr_diff_dst;
                } else {
                    p.dst = &ti->tr_diff_dst[tr_diff_dst_off_3d(
                            g, oc_b, odb_s)];
                }
            } else {
                p.dst = &ti->diff_dst[diff_dst_d.blk_off(
                        img, oc_off_idx, odb_s)];
            }

            p.filt = diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b);
            p.bias = diff_bias ? diff_bias + g * rnd_up(jcp.oc, jcp.oc_block)
                            + oc_b * jcp.oc_block
                               : nullptr;
            // True on the very first chunk of this thread's reduction —
            // presumably tells the kernel to initialize rather than
            // accumulate; confirm against the kernel's p.channel handling.
            p.channel = (start == ti->img_start) && (odb_s == od_s);
            p.reduce_work = ic_to_compute;
            p.load_work = oc_to_compute;
            p.os_index_begin = odb_s;
            p.os_index_end = odb_e;
            p.kd_padding = jcp.kd - kdb_front_pad - kdb_back_pad;
            p.kd_offset = kdb_pad_off;
            p.flags = 0 | (ic_b == 0 ? FLAG_IC_FIRST : 0);
            assert(odb_e <= jcp.od);
            (*kernel_)(&p);
        }

        // Advance (img, od_s) to the next chunk of this thread's range.
        nd_iterator_jump(start, end, img, jcp.mb, od_s, jcp.od);
    }
}
1486 | |
1487 | void jit_avx512_core_bf16_convolution_bwd_weights_t ::compute_diff_weights( |
1488 | const thread_info_t *ti) const { |
1489 | const memory_desc_wrapper src_d(pd()->src_md()); |
1490 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
1491 | const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); |
1492 | const auto &jcp = kernel_->jcp; |
1493 | const bool is_src_layout_nxc = utils::one_of( |
1494 | jcp.src_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
1495 | const bool is_ddst_layout_nxc = utils::one_of( |
1496 | jcp.dst_tag, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
1497 | const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic |
1498 | * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd; |
1499 | const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block; |
1500 | |
1501 | float *diff_wei; |
1502 | if (diff_weights_d.data_type() == data_type::bf16) |
1503 | diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size; |
1504 | else |
1505 | diff_wei = ti->ithr_mb == 0 |
1506 | ? (float *)ti->diff_weights |
1507 | : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; |
1508 | |
1509 | float *diff_bias = nullptr; |
1510 | if (jcp.with_bias) { |
1511 | if (jcp.bia_dt == data_type::bf16) |
1512 | diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size; |
1513 | else |
1514 | diff_bias = ti->ithr_mb == 0 |
1515 | ? (float *)ti->diff_bias |
1516 | : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size; |
1517 | } |
1518 | |
1519 | auto tr_src_off = [&](int g, int ic, int ij) { |
1520 | assert(IMPLICATION(is_src_layout_nxc, jcp.transpose_src)); |
1521 | const size_t tr_row_size = jcp.tr_iw * jcp.ic_block; |
1522 | return tr_src_buf_number(ti, g, ic) * jcp.tr_src_buf_size |
1523 | + ij * tr_row_size; |
1524 | }; |
1525 | |
1526 | auto tr_src_off_3d = [&](int g, int ic, int id, int ij) { |
1527 | assert(IMPLICATION(is_ddst_layout_nxc, jcp.transpose_dst)); |
1528 | const size_t tr_row_size = jcp.tr_iw * jcp.ic_block; |
1529 | const size_t tr_3d_size = tr_row_size * jcp.ih; |
1530 | return tr_src_buf_number(ti, g, ic) * jcp.tr_src_buf_size |
1531 | + id * tr_3d_size + ij * tr_row_size; |
1532 | }; |
1533 | |
1534 | auto tr_diff_dst_off = [&](int g, int oc, int oj) { |
1535 | const size_t tr_row_size = jcp.tr_ow * jcp.oc_block; |
1536 | return tr_diff_dst_buf_number(ti, g, oc) * jcp.tr_diff_dst_buf_size |
1537 | + oj * tr_row_size; |
1538 | }; |
1539 | |
1540 | auto tr_diff_dst_off_3d = [&](int g, int oc, int od, int oj) { |
1541 | const size_t tr_row_size = jcp.tr_ow * jcp.oc_block; |
1542 | const size_t tr_3d_size = tr_row_size * jcp.oh; |
1543 | return tr_diff_dst_buf_number(ti, g, oc) * jcp.tr_diff_dst_buf_size |
1544 | + od * tr_3d_size + oj * tr_row_size; |
1545 | }; |
1546 | |
1547 | auto uker_trans = [&](int img, int g = 0, int ic_b = 0) { |
1548 | int j {0}, d {0}; |
1549 | int my_work = jcp.ih * jcp.id; |
1550 | int ic; |
1551 | int icb_start = ic_b; |
1552 | if (jcp.global_transpose) { |
1553 | const int work_amount = is_src_layout_nxc |
1554 | ? ti->ic_b_work * jcp.ih * jcp.id |
1555 | : ti->g_work * ti->ic_b_work * jcp.ih * jcp.id; |
1556 | |
1557 | int start {0}, end {0}; |
1558 | balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, start, end); |
1559 | my_work = end - start; |
1560 | |
1561 | if (is_src_layout_nxc) { |
1562 | if (jcp.ndims == 5) |
1563 | nd_iterator_init( |
1564 | start, ic_b, ti->ic_b_work, d, jcp.id, j, jcp.ih); |
1565 | else |
1566 | nd_iterator_init(start, ic_b, ti->ic_b_work, j, jcp.ih); |
1567 | } else { |
1568 | if (jcp.ndims == 5) |
1569 | nd_iterator_init(start, g, ti->g_work, ic_b, ti->ic_b_work, |
1570 | d, jcp.id, j, jcp.ih); |
1571 | else |
1572 | nd_iterator_init(start, g, ti->g_work, ic_b, ti->ic_b_work, |
1573 | j, jcp.ih); |
1574 | } |
1575 | g += ti->g_start; |
1576 | ic_b += ti->ic_b_start; |
1577 | icb_start = ic_b; |
1578 | ic = is_src_layout_nxc ? g * jcp.ic + ic_b * jcp.ic_block |
1579 | : g * jcp.nb_ic + ic_b; |
1580 | } else { |
1581 | ic = is_src_layout_nxc ? g * jcp.ic + ic_b * jcp.ic_block |
1582 | : g * jcp.nb_ic + ic_b; |
1583 | g = 0; |
1584 | ic_b = 0; |
1585 | } |
1586 | const bool need_local_gwork = is_src_layout_nxc && jcp.global_transpose; |
1587 | const auto local_gwork = need_local_gwork ? ti->g_work : 1; |
1588 | |
1589 | for (int gg = g; gg < g + local_gwork; ++gg) { |
1590 | if (need_local_gwork) ic = gg * jcp.ic + ic_b * jcp.ic_block; |
1591 | src_data_t *tr_src = (jcp.ndims == 5) |
1592 | ? &ti->tr_src[tr_src_off_3d(gg, ic_b, d, j)] |
1593 | : &ti->tr_src[tr_src_off(gg, ic_b, j)]; |
1594 | auto src_offset = is_src_layout_nxc |
1595 | ? src_d.blk_off(img, ic) |
1596 | : (jcp.ndims == 5 ? src_d.blk_off(img, ic, d, j) |
1597 | : src_d.blk_off(img, ic, j)); |
1598 | src_data_t *src = (src_data_t *)&ti->src[src_offset]; |
1599 | |
1600 | if (is_src_layout_nxc) { |
1601 | dim_t sp_start_offset = (jcp.ndims == 5) |
1602 | ? src_d.blk_off(0, 0, d, j) |
1603 | : src_d.blk_off(0, 0, j); |
1604 | dim_t ch_shift = src_d.blk_off(0, jcp.ic_block); |
1605 | int sp_start_idx = d * jcp.ih + j; |
1606 | trans_src_nxc(tr_src, src, sp_start_idx, sp_start_offset, |
1607 | icb_start, ch_shift, my_work); |
1608 | } else |
1609 | trans_src(tr_src, src, my_work); |
1610 | } |
1611 | }; |
1612 | |
1613 | auto diff_dst_trans = [&](int img, int g = 0, int oc_b = 0) { |
1614 | int j {0}, d {0}; |
1615 | int my_work = jcp.oh * jcp.od; |
1616 | int oc; |
1617 | int ocb_start = oc_b; |
1618 | |
1619 | if (jcp.global_transpose) { |
1620 | const size_t work_amount = is_ddst_layout_nxc |
1621 | ? ti->oc_b_work * jcp.oh * jcp.od |
1622 | : ti->g_work * ti->oc_b_work * jcp.oh * jcp.od; |
1623 | |
1624 | size_t start {0}, end {0}; |
1625 | balance211(work_amount, nthr_ic_b_, ti->ithr_ic_b, start, end); |
1626 | my_work = end - start; |
1627 | |
1628 | if (is_ddst_layout_nxc) { |
1629 | if (jcp.ndims == 5) |
1630 | nd_iterator_init( |
1631 | start, oc_b, ti->oc_b_work, d, jcp.od, j, jcp.oh); |
1632 | else |
1633 | nd_iterator_init(start, oc_b, ti->oc_b_work, j, jcp.oh); |
1634 | } else { |
1635 | if (jcp.ndims == 5) |
1636 | nd_iterator_init(start, g, ti->g_work, oc_b, ti->oc_b_work, |
1637 | d, jcp.od, j, jcp.oh); |
1638 | else |
1639 | nd_iterator_init(start, g, ti->g_work, oc_b, ti->oc_b_work, |
1640 | j, jcp.oh); |
1641 | } |
1642 | g += ti->g_start; |
1643 | oc_b += ti->oc_b_start; |
1644 | ocb_start = oc_b; |
1645 | oc = is_ddst_layout_nxc ? g * jcp.oc + oc_b * jcp.oc_block |
1646 | : g * jcp.nb_oc + oc_b; |
1647 | } else { |
1648 | oc = is_ddst_layout_nxc ? g * jcp.oc + oc_b * jcp.oc_block |
1649 | : g * jcp.nb_oc + oc_b; |
1650 | g = 0; |
1651 | oc_b = 0; |
1652 | } |
1653 | const bool need_local_gwork |
1654 | = is_ddst_layout_nxc && jcp.global_transpose; |
1655 | const auto local_gwork = need_local_gwork ? ti->g_work : 1; |
1656 | |
1657 | for (int gg = g; gg < g + local_gwork; ++gg) { |
1658 | if (need_local_gwork) oc = gg * jcp.oc + oc_b * jcp.oc_block; |
1659 | diff_dst_data_t *tr_diff_dst = (jcp.ndims == 5) |
1660 | ? &ti->tr_diff_dst[tr_diff_dst_off_3d(gg, oc_b, d, j)] |
1661 | : &ti->tr_diff_dst[tr_diff_dst_off(gg, oc_b, j)]; |
1662 | auto ddst_offset = is_ddst_layout_nxc |
1663 | ? diff_dst_d.blk_off(img, oc) |
1664 | : (jcp.ndims == 5 ? diff_dst_d.blk_off(img, oc, d, j) |
1665 | : diff_dst_d.blk_off(img, oc, j)); |
1666 | const diff_dst_data_t *diff_dst = &ti->diff_dst[ddst_offset]; |
1667 | |
1668 | if (is_ddst_layout_nxc) { |
1669 | dim_t sp_start_offset = (jcp.ndims == 5) |
1670 | ? diff_dst_d.blk_off(0, 0, d, j) |
1671 | : diff_dst_d.blk_off(0, 0, j); |
1672 | dim_t ch_shift = diff_dst_d.blk_off(0, jcp.oc_block); |
1673 | int sp_start_idx = d * jcp.oh + j; |
1674 | trans_dst_nxc(tr_diff_dst, diff_dst, sp_start_idx, |
1675 | sp_start_offset, ocb_start, ch_shift, my_work); |
1676 | } else |
1677 | trans_dst(tr_diff_dst, diff_dst, my_work); |
1678 | } |
1679 | }; |
1680 | |
1681 | for (int img = ti->img_start; img < ti->img_end; ++img) { |
1682 | auto p = jit_conv_call_s(); |
1683 | if (jcp.global_transpose) { |
1684 | using simple_barrier::barrier; |
1685 | // TODO: try to call local transpositions just before jit kernel |
1686 | /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */ |
1687 | if (jcp.transpose_src) { |
1688 | if (nthr_oc_b_ > 1) |
1689 | barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_); |
1690 | uker_trans(img); |
1691 | if (nthr_oc_b_ > 1) |
1692 | barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_); |
1693 | } |
1694 | if (jcp.transpose_dst) { |
1695 | if (nthr_ic_b_ > 1) |
1696 | barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_); |
1697 | diff_dst_trans(img); |
1698 | if (nthr_ic_b_ > 1) |
1699 | barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_); |
1700 | } |
1701 | } |
1702 | int ic_b_step |
1703 | = jcp.uses_permw_transposition ? jcp.nb_ic_blocking_max : 1; |
1704 | int icb_work = ti->ic_b_end - ti->ic_b_start; |
1705 | if (ic_b_step > 1 && icb_work > ic_b_step && icb_work < 2 * ic_b_step) |
1706 | ic_b_step = utils::div_up(icb_work, 2); |
1707 | for_(int g = ti->g_start; g < ti->g_end; ++g) |
1708 | for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) |
1709 | for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; |
1710 | ic_b += ic_b_step) { |
1711 | const int ic_off_idx = is_src_layout_nxc |
1712 | ? g * jcp.ic + ic_b * jcp.ic_block |
1713 | : g * jcp.nb_ic + ic_b; |
1714 | const int oc_off_idx = is_ddst_layout_nxc |
1715 | ? g * jcp.oc + oc_b * jcp.oc_block |
1716 | : g * jcp.nb_oc + oc_b; |
1717 | const int ic_to_compute = this_block_size( |
1718 | ic_b * jcp.ic_block, jcp.ic, ic_b_step * jcp.ic_block); |
1719 | const int oc_to_compute = this_block_size( |
1720 | oc_b * jcp.oc_block, jcp.oc, jcp.oc_block); |
1721 | if (jcp.transpose_src) { |
1722 | if (!jcp.global_transpose) { |
1723 | uker_trans(img, g, ic_b); |
1724 | if (jcp.ndims == 5) { |
1725 | p.src = &ti->tr_src[tr_src_off_3d(g, ic_b, 0, 0)]; |
1726 | } else { |
1727 | p.src = &ti->tr_src[tr_src_off(g, ic_b, 0)]; |
1728 | } |
1729 | } else { |
1730 | if (jcp.ndims == 5) { |
1731 | p.src = &ti->tr_src[tr_src_off_3d(g, ic_b, 0, 0)]; |
1732 | } else { |
1733 | p.src = &ti->tr_src[tr_src_off(g, ic_b, 0)]; |
1734 | } |
1735 | } |
1736 | } else { |
1737 | p.src = &ti->src[src_d.blk_off(img, ic_off_idx)]; |
1738 | } |
1739 | |
1740 | if (jcp.transpose_dst) { |
1741 | if (!jcp.global_transpose) { |
1742 | diff_dst_trans(img, g, oc_b); |
1743 | if (jcp.ndims == 5) { |
1744 | p.dst = &ti->tr_diff_dst[tr_diff_dst_off_3d( |
1745 | 0, 0, 0, 0)]; |
1746 | } else { |
1747 | p.dst = &ti->tr_diff_dst[tr_diff_dst_off(0, 0, 0)]; |
1748 | } |
1749 | } else { |
1750 | if (jcp.ndims == 5) { |
1751 | p.dst = &ti->tr_diff_dst[tr_diff_dst_off_3d( |
1752 | g, oc_b, 0, 0)]; |
1753 | } else { |
1754 | p.dst = &ti->tr_diff_dst[tr_diff_dst_off(g, oc_b, 0)]; |
1755 | } |
1756 | } |
1757 | } else { |
1758 | p.dst = &ti->diff_dst[diff_dst_d.blk_off(img, oc_off_idx)]; |
1759 | } |
1760 | |
1761 | p.filt = diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b); |
1762 | p.bias = diff_bias ? diff_bias + g * rnd_up(jcp.oc, jcp.oc_block) |
1763 | + oc_b * jcp.oc_block |
1764 | : nullptr; |
1765 | p.channel = (img == ti->img_start); |
1766 | p.flags = 0 | (ic_b == 0 ? FLAG_IC_FIRST : 0); |
1767 | p.reduce_work = ic_to_compute; |
1768 | p.load_work = oc_to_compute; |
1769 | (*kernel_)(&p); |
1770 | } |
1771 | } |
1772 | } |
1773 | |
// Combine the per-thread partial diff_weights / diff_bias accumulators into
// the final output tensors, converting the f32 accumulators to bf16 when the
// destination data type is bf16.
void jit_avx512_core_bf16_convolution_bwd_weights_t ::
        reduce_and_convert_diff_weights_and_bias(
                const thread_info_t *ti) const {
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;
    // Number of f32 elements in one full weights tensor, i.e. the stride
    // between consecutive per-thread reduction buffers.
    const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
            * jcp.ic_block * jcp.kh * jcp.kw * ((jcp.ndims == 5) ? jcp.kd : 1);

    const bool is_bf16_out = diff_weights_d.data_type() == data_type::bf16;
    const bool is_bf16_bias = jcp.with_bias && jcp.bia_dt == data_type::bf16;
    if (nthr_mb_ == 1) {
        // Single minibatch-reduction thread: nothing to accumulate across
        // threads; only a f32 -> bf16 conversion may be required.
        if (is_bf16_out) {
            // reduction is not required, only conversion
            for_(int g = ti->g_start; g < ti->g_end; g++)
            for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; oc_b++) {
                // Contiguous span of this thread's ic_b slice for one
                // (g, oc_b) weights block.
                const size_t acc_size = (size_t)ti->ic_b_work * jcp.kh * jcp.kw
                        * ((jcp.ndims == 5) ? jcp.kd : 1) * jcp.ic_block
                        * jcp.oc_block;
                const size_t off
                        = wht_blk_off(diff_weights_d, g, oc_b, ti->ic_b_start);
                cvt_float_to_bfloat16((bfloat16_t *)(ti->diff_weights) + off,
                        (ti->wei_bia_reduction + off), acc_size);
            }
        }

        // Convert the bias accumulator; only the first thread along the ic_b
        // dimension does this to avoid duplicate stores.
        if (is_bf16_bias && ti->ithr_ic_b == 0 && ti->ic_b_work > 0) {
            for (int g = ti->g_start; g < ti->g_end; g++) {
                // Destination uses the unpadded oc stride, the accumulation
                // buffer uses the oc_block-padded stride.
                int result_start_idx = g * jcp.oc_without_padding
                        + ti->oc_b_start * jcp.oc_block;
                int buffer_start_idx = g * rnd_up(jcp.oc, jcp.oc_block)
                        + ti->oc_b_start * jcp.oc_block;
                // Clamp the last block so padded oc entries are not written.
                const size_t acc_size = nstl::min(jcp.oc_without_padding,
                                                ti->oc_b_end * jcp.oc_block)
                        - ti->oc_b_start * jcp.oc_block;
                bfloat16_t *diff_bias
                        = (bfloat16_t *)ti->diff_bias + result_start_idx;
                float *buffer = ti->bia_reduction + buffer_start_idx;
                cvt_float_to_bfloat16(diff_bias, buffer, acc_size);
            }
        }
        return;
    }

    /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
    // With global transposition this routine runs in the same parallel region
    // as the compute stage, so wait until every thread has finished writing
    // its partial buffer before reading the buffers of other threads.
    if (jcp.global_transpose)
        simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);

    // The reduction work for this thread group is split along the combined
    // (ic_b x outermost-kernel-dim) axis: kd for 3D, kh otherwise.
    const int ic_b_kh_work
            = ti->ic_b_work * ((jcp.ndims == 5) ? jcp.kd : jcp.kh);
    const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;

    int start {0}, end {0};
    balance211(work, nthr_mb_, ti->ithr_mb, start, end);
    if (start == end) return;

    const int _start_nthr_mb = 1;
    // Accumulate each remaining minibatch-thread buffer into the running
    // result, walking this thread's [start, end) share of the work space.
    for (int thr_mb = _start_nthr_mb; thr_mb < nthr_mb_; ++thr_mb) {
        int w = start;
        int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_kh_start {0};
        nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
                ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        while (w < end) {
            const int g = ti->g_start + sub_g_start;
            const int oc_b = ti->oc_b_start + sub_oc_b_start;
            const int ic_b = ti->ic_b_start
                    + sub_ic_b_kh_start / ((jcp.ndims == 5) ? jcp.kd : jcp.kh);
            // kX is the kd index for 3D convolutions, the kh index otherwise.
            const int kX
                    = sub_ic_b_kh_start % ((jcp.ndims == 5) ? jcp.kd : jcp.kh);

            // Longest contiguous run that stays within both this thread's
            // work share and the current ic_b sub-range.
            const size_t acc_size = (size_t)jcp.kw * jcp.ic_block * jcp.oc_block
                    * ((jcp.ndims == 5) ? jcp.kh : 1)
                    * nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start);

            const size_t off = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kX);

            // For f32 output, thread 0 accumulated directly into
            // diff_weights, so that is the reduction target; for bf16 output
            // every thread (including 0) used a scratch buffer.
            float *wei_reduced = is_bf16_out
                    ? ti->wei_bia_reduction + off
                    : (float *)(ti->diff_weights) + off;

            // Same layout consequence: with f32 output, buffer k belongs to
            // minibatch thread k + 1, hence the index shift.
            int thr_mb_buffer_idx = is_bf16_out ? thr_mb : thr_mb - 1;
            float *wei_to_reduce = ti->wei_bia_reduction
                    + thr_mb_buffer_idx * wei_size + off;

            if (is_bf16_out && thr_mb == nthr_mb_ - 1)
                // the last iteration for bfloat16 requires conversion and
                // store to diff_weights array
                add_floats_and_cvt_to_bfloat16(
                        (bfloat16_t *)(ti->diff_weights) + off, wei_reduced,
                        wei_to_reduce, acc_size);
            else
                acc_ker_->accumulate(wei_reduced, wei_to_reduce, acc_size);

            nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
                    ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        }
        // Bias reduction mirrors the weights reduction above; it is done by a
        // single thread per group slice (first ic_b thread, first mb thread).
        if (jcp.with_bias && ti->ithr_ic_b == 0 && ti->ic_b_work > 0
                && ti->ithr_mb == 0 && ti->img_work > 0) {
            for (int g = ti->g_start; g < ti->g_end; g++) {
                float *bias_reduced = is_bf16_bias ? ti->bia_reduction
                                                  : (float *)(ti->diff_bias);
                int thr_mb_buffer_idx = is_bf16_bias ? thr_mb : thr_mb - 1;
                int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
                float *bias_to_reduce
                        = ti->bia_reduction + thr_mb_buffer_idx * bias_buf_size;
                // Clamp so padded oc entries are never written to the output.
                const size_t acc_size = nstl::min(jcp.oc_without_padding,
                                                ti->oc_b_end * jcp.oc_block)
                        - ti->oc_b_start * jcp.oc_block;
                int idx = g * rnd_up(jcp.oc, jcp.oc_block)
                        + ti->oc_b_start * jcp.oc_block;
                if (is_bf16_bias && thr_mb == nthr_mb_ - 1) {
                    // the last iteration for bfloat16 requires conversion and
                    // store to diff_weights array
                    int diff_bias_idx = g * jcp.oc_without_padding
                            + ti->oc_b_start * jcp.oc_block;
                    add_floats_and_cvt_to_bfloat16(
                            (bfloat16_t *)(ti->diff_bias) + diff_bias_idx,
                            &bias_reduced[idx], &bias_to_reduce[idx], acc_size);
                } else {
                    acc_ker_->accumulate(
                            &bias_reduced[idx], &bias_to_reduce[idx], acc_size);
                }
            }
        }
    }
}
1900 | |
1901 | void jit_avx512_core_bf16_convolution_bwd_weights_t::prepare_scratchpad_data( |
1902 | const exec_ctx_t &ctx) const { |
1903 | auto scratchpad = ctx.get_scratchpad_grantor(); |
1904 | |
1905 | const auto &jcp = pd()->jcp_; |
1906 | |
1907 | if (jcp.transpose_src) { |
1908 | // XXX: See the comment about tr_iw and guarding elements in |
1909 | // jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf() |
1910 | auto tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src); |
1911 | // Zero out guard elements that cross a buffer boundary to prevent a |
1912 | // race condition due to buffer overflows from memory optimization where |
1913 | // buffers sharing padding |
1914 | for (size_t ithr = 1; ithr <= jcp.tr_src_buf_count; ++ithr) { |
1915 | src_data_t *ts = &tr_src[ithr * jcp.tr_src_buf_size]; |
1916 | for (int i = 0; i < jcp.tr_src_num_guard_elems; ++i) |
1917 | ts[i] = 0; |
1918 | } |
1919 | |
1920 | if (jcp.global_transpose && jcp.nthr_oc_b > 1) { |
1921 | const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b; |
1922 | auto tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>( |
1923 | key_conv_tr_src_bctx); |
1924 | for (int i = 0; i < tr_src_bctx_size; ++i) |
1925 | simple_barrier::ctx_init(&tr_src_bctx[i]); |
1926 | } |
1927 | } |
1928 | if (jcp.global_transpose && jcp.transpose_dst) { |
1929 | if (jcp.nthr_ic_b > 1) { |
1930 | const int tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b; |
1931 | auto tr_diff_dst_bctx |
1932 | = scratchpad.template get<simple_barrier::ctx_t>( |
1933 | key_conv_tr_diff_dst_bctx); |
1934 | for (int i = 0; i < tr_diff_dst_bctx_size; ++i) |
1935 | simple_barrier::ctx_init(&tr_diff_dst_bctx[i]); |
1936 | } |
1937 | } |
1938 | |
1939 | if (jcp.global_transpose |
1940 | && (nthr_mb_ > 1 |
1941 | || pd()->diff_weights_md(0)->data_type |
1942 | == data_type::bf16)) { |
1943 | // TODO: don't use barrier for case |
1944 | // diff_weights_type == data_type::bf16 && nthr_mb_ == 1 |
1945 | simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>( |
1946 | key_conv_wei_bia_reduction_bctx)); |
1947 | } |
1948 | } |
1949 | |
1950 | void jit_avx512_core_bf16_convolution_bwd_weights_t ::execute_backward_weights( |
1951 | const exec_ctx_t &ctx) const { |
1952 | prepare_scratchpad_data(ctx); |
1953 | |
1954 | const auto &jcp = pd()->jcp_; |
1955 | parallel(nthr_, [&](const int ithr, const int nthr) { |
1956 | assert(nthr_ == nthr); |
1957 | assert(utils::one_of(pd()->ndims(), 3, 4, 5)); |
1958 | |
1959 | thread_info_t thread_info(this, ctx, ithr); |
1960 | switch (jcp.harness) { |
1961 | case harness_2d_reduction: |
1962 | compute_diff_weights_2d(&thread_info); |
1963 | if (jcp.global_transpose) |
1964 | reduce_and_convert_diff_weights_and_bias(&thread_info); |
1965 | break; |
1966 | case harness_3d_reduction: |
1967 | compute_diff_weights_3d(&thread_info); |
1968 | if (jcp.global_transpose) |
1969 | reduce_and_convert_diff_weights_and_bias(&thread_info); |
1970 | break; |
1971 | case harness_compute_full_spatial: |
1972 | case harness_mb_reduction: |
1973 | compute_diff_weights(&thread_info); |
1974 | if (jcp.global_transpose) |
1975 | reduce_and_convert_diff_weights_and_bias(&thread_info); |
1976 | break; |
1977 | default: assert(!"Invalid harness type" ); |
1978 | } |
1979 | }); |
1980 | |
1981 | if (!jcp.global_transpose) { |
1982 | parallel(nthr_, [&](const int ithr, const int nthr) { |
1983 | assert(nthr_ == nthr); |
1984 | thread_info_t thread_info(this, ctx, ithr); |
1985 | reduce_and_convert_diff_weights_and_bias(&thread_info); |
1986 | }); |
1987 | } |
1988 | |
1989 | if (pd()->with_bias() && (jcp.oc_without_padding % jcp.oc_block != 0) |
1990 | && jcp.bia_dt != data_type::bf16) { |
1991 | auto diff_bias = ctx.get_scratchpad_grantor().template get<const float>( |
1992 | key_conv_padded_bias); |
1993 | auto diff_bias_in = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS); |
1994 | const int padded_stride = rnd_up(jcp.oc, jcp.oc_block); |
1995 | const int stride = jcp.oc_without_padding; |
1996 | for (int g = 0; g < jcp.ngroups; ++g) { |
1997 | utils::array_copy(diff_bias_in + g * stride, |
1998 | diff_bias + g * padded_stride, stride); |
1999 | } |
2000 | } |
2001 | } |
2002 | |
2003 | } // namespace x64 |
2004 | } // namespace cpu |
2005 | } // namespace impl |
2006 | } // namespace dnnl |
2007 | |
2008 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
2009 | |