1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | |
22 | #include "cpu/cpu_primitive.hpp" |
23 | #include "cpu/scale_utils.hpp" |
24 | |
25 | #include "cpu/x64/jit_avx512_core_amx_conv_utils.hpp" |
26 | #include "cpu/x64/jit_avx512_core_amx_convolution.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace x64 { |
32 | |
33 | using namespace dnnl::impl::status; |
34 | using namespace dnnl::impl::memory_tracking::names; |
35 | using namespace dnnl::impl::utils; |
36 | |
37 | using namespace nstl; |
38 | |
// Weights block offset: dispatches to the grouped or non-grouped blk_off()
// overload depending on whether the primitive was created with groups.
#define wht_blk_off(d, g, ...) \
    (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) \
                         : (d).blk_off(__VA_ARGS__))

// src/dst block offset that works for 1D (ndims == 3), 2D (ndims == 4) and
// 3D (ndims == 5) convolutions by dropping the unused spatial indices.
#define mem_blk_off(mdw, n, c, d, h, w) \
    (pd()->ndims() == 3 \
                    ? (mdw).blk_off((n), (c), (w)) \
                    : (pd()->ndims() == 4 \
                                    ? (mdw).blk_off((n), (c), (h), (w)) \
                                    : (mdw).blk_off((n), (c), (d), (h), (w))))
49 | |
50 | template <typename T> |
51 | static inline T accum_with_upper_bound(T ub, T lv, T uv) { |
52 | return nstl::min(ub, nstl::min(ub, lv) + nstl::max(0, ub - uv)); |
53 | } |
54 | |
55 | void jit_avx512_core_amx_convolution_fwd_t::prepare_padded_bias( |
56 | const char *&bias, const memory_tracking::grantor_t &scratchpad) const { |
57 | if (!pd()->wants_padded_bias()) return; |
58 | |
59 | const size_t bia_dt_size = pd()->jcp_.typesize_bia; |
60 | auto padded_bias = scratchpad.template get<char>( |
61 | memory_tracking::names::key_conv_padded_bias); |
62 | utils::array_copy( |
63 | padded_bias, bias, bia_dt_size * pd()->jcp_.oc_without_padding); |
64 | utils::array_set(padded_bias + bia_dt_size * pd()->jcp_.oc_without_padding, |
65 | 0.f, bia_dt_size * (pd()->jcp_.oc - pd()->jcp_.oc_without_padding)); |
66 | bias = padded_bias; |
67 | } |
68 | |
// Forward execution for the "reduced lowering" (jcp.is_relo) AMX path: the
// kh*kw*ic filter dimensions are collapsed into a single reduction dimension.
// This requires (a) a one-time weights reorder into a scratchpad buffer and
// (b) a per-thread padded copy of the input rows (the "pbuffer") that is fed
// to the compute kernel. 2D-spatial (ndims == 4) only on this path.
status_t
jit_avx512_core_amx_convolution_fwd_t::execute_forward_reduced_lowering(
        const exec_ctx_t &ctx) const {
    const auto &jcp = pd()->jcp_;
    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    const size_t src_dt_size = types::data_type_size(src_d.data_type());
    const size_t wei_dt_size = types::data_type_size(weights_d.data_type());
    const size_t bia_dt_size
            = pd()->with_bias() ? types::data_type_size(bias_d.data_type()) : 0;
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());

    prepare_padded_bias(bias, ctx.get_scratchpad_grantor());

    assert(jcp.is_relo);
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    // Combined src*wei output scales, precomputed into the scratchpad.
    const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
            src_scales, wei_scales, pd()->OC(), pd()->attr());

    // Scratchpad buffers: per-thread padded input copy, reordered weights,
    // per-thread int32 accumulators, the AMX tile configuration blob, and
    // (when zero-point padding is needed) the zero-point padding buffer
    // plus its per-thread "not yet computed" flags.
    auto inp_p_buffer = ctx.get_scratchpad_grantor().template get<char>(
            key_conv_amx_inp_buffer); // fix the template
    auto wei_buffer = ctx.get_scratchpad_grantor().template get<char>(
            key_conv_amx_wei_buffer); // fix the template
    auto wsp = ctx.get_scratchpad_grantor().template get<int32_t>(
            key_conv_amx_wsp_buffer);
    auto tcfg = ctx.get_scratchpad_grantor().template get<char>(
            key_conv_amx_tilecfg);
    auto zero_point_pbuff = ctx.get_scratchpad_grantor().template get<int32_t>(
            key_conv_zero_point_pad);
    auto zp_flags_ = ctx.get_scratchpad_grantor().template get<bool>(
            key_conv_zero_point_flag);

    // Source zero-point compensation values live in a trailing buffer
    // appended to the weights blob; 'offset' is where that buffer starts.
    const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    char *w = const_cast<char *>(weights);
    int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(w + offset)
            : nullptr;

    // Output-row boundaries of the top/bottom padded regions; used to map
    // rows of the compressed zero-point padding buffer to real output rows.
    const int t_pad_output = jcp.t_pad_output;
    const int b_pad_output = jcp.b_pad_output;
    const int b_pad_start = nstl::max(jcp.oh - b_pad_output, t_pad_output);
    const int zp_buff_b_pad_start
            = nstl::max(jcp.oh_pad - b_pad_output, t_pad_output);

    const int ngroups = jcp.ngroups;
    const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    const int oh_chunks = utils::div_up(jcp.oh, jcp.oh_blk_size);
    const int work_amount
            = jcp.mb * jcp.ngroups * oh_chunks * jcp.nb_ow * oc_chunks;
    const int zp_pbuff_size = jcp.zp_pbuff_size;

    // reorder weights from (g)Owhi16o to (g)OR16r16o4r, where r := whi
    auto p = jit_conv_call_s();
    p.src = weights;
    p.dst = wei_buffer;
    kernel_->copy_to_wbuffer()(&p);
    const char *wei = wei_buffer;

    const size_t oc_subblock_step
            = jcp.kh * jcp.kw * jcp.ic_block_int_np * jcp.oc_block;
    const size_t wei_oc_shift = (size_t)jcp.nb_oc_blocking * jcp.nb_ic_int
            * rnd_up(oc_subblock_step, jcp.ic_block_int * jcp.oc_block);

    // Initialize the tile configuration in memory, so that each thread can
    // load this configuration from memory via `amx_tile_configure(tcfg)`.
    kernel_->tile_configure(tcfg);

    // init zero_point padding buffer: when 'zp_pbuff_outer_compute' is set
    // the shared buffer is filled once up-front here; otherwise each thread
    // lazily fills a private copy inside the main parallel section below.
    const bool req_zero_point_buffer = jcp.req_zero_point_buffer;
    const bool zp_pbuff_outer_compute = jcp.zp_pbuff_outer_compute;
    const bool zp_pbuff_parallel_block
            = req_zero_point_buffer && !zp_pbuff_outer_compute;
    if (req_zero_point_buffer && zp_pbuff_outer_compute) {
        const size_t wei_oc_step = (size_t)jcp.kh * jcp.kw * jcp.ic_block_int_np
                * jcp.nb_oc_blocking * jcp.oc_block;
        const int sp_stride = mem_blk_off(dst_d, 0, 0, 0, 0, 1);
        const int dilate_h = jcp.dilate_h + 1;
        const int gen_kh = (jcp.kh - 1) * dilate_h + 1;
        const int oh_work = jcp.oh_pad;
        parallel_nd(
                ngroups, oc_chunks, oh_work, [&](dim_t g, dim_t occ, dim_t oh) {
                    auto p = jit_conv_call_s();

                    // map compressed pad-buffer row index to real output row
                    const int oh_ = oh >= zp_buff_b_pad_start
                            ? b_pad_start + oh - zp_buff_b_pad_start
                            : oh;
                    const int ih = oh_ * jcp.stride_h - jcp.t_pad;
                    const int t_overflow
                            = nstl::min(jcp.kh, div_up(max(0, -ih), dilate_h));
                    const int b_overflow = nstl::min(jcp.kh,
                            div_up(nstl::max(0, ih + gen_kh - jcp.ih),
                                    dilate_h));
                    p.t_overflow = t_overflow;
                    p.b_overflow = b_overflow;
                    p.kh_padding
                            = nstl::max(0, jcp.kh - t_overflow - b_overflow);

                    const int ocb = g * jcp.oc
                            + occ * jcp.nb_oc_blocking * jcp.oc_block;
                    const size_t ch_offset = dst_d.blk_off(0, ocb);
                    const size_t sp_offset = oh * jcp.ow_pad * sp_stride;
                    p.zero_point_pbuff
                            = &zero_point_pbuff[ch_offset + sp_offset];
                    p.oc_blocks = occ * jcp.nb_oc_blocking;
                    p.filt = weights
                            + wei_dt_size * (g * oc_chunks + occ) * wei_oc_step;
                    p.src_zero_point = src_zero_point;

                    kernel_->zp_pbuff_kernel()(&p);
                });
    }

    // TODO: implement 2D parallelization driver (g * spatial x oc) to increase
    // input data reuse and parallelize input data reorders
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        int start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);
        // shared pre-computed zero-point buffer vs. this thread's slice
        int32_t *local_zp_pbuff = req_zero_point_buffer
                ? (zp_pbuff_outer_compute
                                ? zero_point_pbuff
                                : &zero_point_pbuff[ithr * zp_pbuff_size])
                : nullptr;
        bool *zp_flags = zp_pbuff_parallel_block
                ? &zp_flags_[ithr * oc_chunks * ngroups]
                : nullptr;
        if (zp_pbuff_parallel_block) {
            // mark every (g, occ) zero-point slice as not-yet-computed
            PRAGMA_OMP_SIMD()
            for (int oc = 0; oc < oc_chunks * ngroups; oc++)
                zp_flags[oc] = true;
        }

        // 'p' is reused across all kernel invocations of this thread; each
        // iteration only overwrites the fields it needs.
        auto p = jit_conv_call_s();
        amx_tile_configure(tcfg);

        const int oh_work = jcp.oh_pad;
        const int sp_stride = mem_blk_off(dst_d, 0, 0, 0, 0, 1);
        const int dilate_h = jcp.dilate_h + 1;
        const int gen_kh = (jcp.kh - 1) * dilate_h + 1;
        const size_t wei_oc_step = (size_t)jcp.kh * jcp.kw * jcp.ic_block_int_np
                * jcp.nb_oc_blocking * jcp.oc_block;
        size_t oc_stride = dst_d.blk_off(0, 1);
        const int owb_limit = jcp.nb_ow - jcp.r_pad_blk - jcp.no_pad_w_blk;

        int mb {0}, g {0}, ohc {0}, owb {0}, occ {0};
        // need "inner" oh blocks w.r.t. ow blocks to allow pbuffer reuse
        nd_iterator_init(start, mb, jcp.mb, g, jcp.ngroups, owb, jcp.nb_ow, ohc,
                oh_chunks, occ, oc_chunks);
        int last_copied_mb = -1;
        int last_copied_ohc = -1;
        int last_copied_owb = -1;
        int last_copied_g = -1;
        while (start < end) {
            char *inp_buffer
                    = inp_p_buffer + src_dt_size * ithr * jcp.inp_buffer_size;

            assert(IMPLICATION(
                    jcp.ngroups > 1, jcp.oc == jcp.oc_without_padding));
            int oc = g * jcp.oc + occ * jcp.nb_oc_blocking * jcp.oc_block;
            int ocb = jcp.is_nspc ? oc : oc / jcp.oc_block;
            const char *bias_w = bias
                    ? bias + (bias_d.blk_off(oc) * bia_dt_size)
                    : nullptr;
            p.zp_compensation
                    = jcp.src_zero_point ? zp_compensation + oc : nullptr;
            p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
            p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr;

            int oh_s = ohc * jcp.oh_blk_size;
            int oh_e = nstl::min(jcp.oh, oh_s + jcp.oh_blk_size);
            // pbuffer contents from the previous work item can be reused
            // wholesale when the (mb, ohc, owb, g) region is identical
            bool is_inp_buffer_relevant = true && last_copied_mb == mb
                    && last_copied_ohc == ohc && last_copied_owb == owb
                    && last_copied_g == g;
            // consecutive oh chunks may share (kh - stride_h) input rows
            bool has_inp_buffer_overlap = true && last_copied_mb == mb
                    && last_copied_owb == owb && last_copied_g == g
                    && jcp.oh_blk_size == jcp.nb_oh_blocking;
            bool is_zp_pbuff_relevant = zp_pbuff_parallel_block
                    ? zp_flags[g * oc_chunks + occ] // already computed?
                    : false;

            int cur_t_pad = nstl::max(0, t_pad_output - oh_s);
            int cur_b_pad = nstl::max(
                    nstl::max(0, jcp.oh - b_pad_output - oh_s), cur_t_pad);
            size_t zp_oh
                    = accum_with_upper_bound(oh_s, t_pad_output, b_pad_start);

            int limit_idx = 0;
            constexpr int limit_size = 5;
            for (; req_zero_point_buffer && limit_idx < limit_size;
                    limit_idx++) {
                // find current 'oh_blk' index from 'oh_s`
                if ((size_t)oh_s < jcp.h_blk_limits[limit_idx]) break;
            }

            if (is_zp_pbuff_relevant) {
                // lazily compute this thread's zero-point pad buffer for the
                // current (g, occ) exactly once, then clear the flag
                assert(!zp_pbuff_outer_compute);
                zp_flags[g * oc_chunks + occ] = false;
                for (int oh_pad = 0; oh_pad < oh_work; ++oh_pad) {
                    const int oh_ = oh_pad >= zp_buff_b_pad_start
                            ? b_pad_start + oh_pad - zp_buff_b_pad_start
                            : oh_pad;
                    const int ih = oh_ * jcp.stride_h - jcp.t_pad;
                    const int t_overflow
                            = nstl::min(jcp.kh, div_up(max(0, -ih), dilate_h));
                    const int b_overflow = nstl::min(jcp.kh,
                            div_up(nstl::max(0, ih + gen_kh - jcp.ih),
                                    dilate_h));
                    p.t_overflow = t_overflow;
                    p.b_overflow = b_overflow;
                    p.kh_padding
                            = nstl::max(0, jcp.kh - t_overflow - b_overflow);

                    const size_t ch_offset = dst_d.blk_off(0, oc);
                    const size_t sp_offset = oh_pad * jcp.ow_pad * sp_stride;
                    p.zero_point_pbuff = &local_zp_pbuff[ch_offset + sp_offset];
                    p.oc_blocks = occ * jcp.nb_oc_blocking;
                    p.filt = weights
                            + wei_dt_size * (g * oc_chunks + occ) * wei_oc_step;
                    p.src_zero_point = src_zero_point;

                    kernel_->zp_pbuff_kernel()(&p);
                }
            }

            int oh_step = jcp.nb_oh_blocking * jcp.oh_per_tile;
            for (int oh = oh_s; oh < oh_e; oh += oh_step) {
                const int inp_buffer_h_step
                        = jcp.stride_h * jcp.ic_without_padding;
                assert(jcp.is_nspc);
                assert(jcp.stride_h <= jcp.kh);

                int ow = owb * jcp.ow_block;

                char *inp_buffer_oh
                        = inp_buffer + src_dt_size * oh * inp_buffer_h_step;

                if (!is_inp_buffer_relevant) {
                    // prepare padded input buffer
                    const int icb = g * jcp.ic;
                    size_t inp_offset = src_d.blk_off(mb, icb);
                    const int iw_step = jcp.ngroups * jcp.ic_without_padding;
                    const char *psrc = src + src_dt_size * inp_offset;
                    // rows already present from the previous oh block need
                    // not be copied again; shrink the effective kh
                    const int ih_overlap = has_inp_buffer_overlap
                            * nstl::max(0, jcp.kh - oh_step * jcp.stride_h);
                    const int kh_eff = jcp.kh - ih_overlap;
                    // prepare padded input buffer
                    char *pdst = inp_buffer_oh
                            + src_dt_size * ih_overlap * jcp.ic_without_padding;
                    for (int doh = 0; doh < oh_step; doh++) {
                        const int ih_s = (doh + oh) * jcp.stride_h - jcp.t_pad
                                + ih_overlap;
                        const int ih_e = ih_s + kh_eff;
                        const int ih = nstl::max(0, ih_s);
                        p.t_overflow = nstl::max(0, -ih_s);
                        p.b_overflow = nstl::min<int>(
                                kh_eff, nstl::max(0, ih_e - jcp.ih));
                        p.kh_padding = nstl::max<int>(
                                0, (kh_eff - p.t_overflow - p.b_overflow));
                        p.kh_offset = kh_eff;

                        const int iw_s = ow * jcp.stride_w - jcp.l_pad;
                        const int iw_e = iw_s + jcp.iwp;
                        const int iw = nstl::max(0, iw_s);
                        p.f_overflow = nstl::max(0, -iw_s);
                        p.back_overflow = nstl::max(0, iw_e - jcp.iw);
                        p.kw_padding = nstl::max<int>(
                                0, jcp.iwp - p.f_overflow - p.back_overflow);

                        p.src = psrc
                                + src_dt_size * (ih * jcp.iw + iw) * iw_step;
                        p.dst = pdst
                                + src_dt_size * doh * jcp.iwp * jcp.kh
                                        * jcp.ic_without_padding;

                        kernel_->copy_to_pbuffer()(&p);
                    }
                }

                p.src = inp_buffer_oh;
                size_t dst_offset = mem_blk_off(dst_d, mb, ocb, 0, oh, ow);
                p.dst = dst + dst_dt_size * dst_offset;
                const size_t pbuff_offset
                        = zp_oh * jcp.ow_pad * sp_stride + ocb * oc_stride;
                p.zero_point_pbuff = req_zero_point_buffer
                        ? &local_zp_pbuff[pbuff_offset]
                        : nullptr;
                p.filt = wei
                        + wei_dt_size * (g * oc_chunks + occ) * wei_oc_shift;
                p.bias = bias_w;
                p.scales = &oscales[jcp.is_oc_scale * oc];
                p.dst_scale = &dst_scales[0];

                p.acc_s32 = wsp + ithr * jcp.wsp_buffer_size;

                // advance to the next h-block once 'oh' crosses its limit;
                // the kernel uses 'ohb' to select the padding sub-region
                if (req_zero_point_buffer
                        && (size_t)oh >= jcp.h_blk_limits[limit_idx])
                    limit_idx++;
                assert(limit_idx < 6);
                p.ohb = limit_idx;
                p.last_h = (oh + oh_step <= oh_e);
                const int zp_owb = nstl::min(jcp.l_pad_blk, owb)
                        + nstl::max(0, owb - owb_limit);
                p.owb = req_zero_point_buffer ? zp_owb : owb;

                p.oc_blocks = occ * jcp.nb_oc_blocking;

                p.oc_l_off = oc;
                p.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                p.dst_orig = dst;

                (*kernel_)(&p);

                if (req_zero_point_buffer) {
                    zp_oh += accum_with_upper_bound(
                            oh_step, cur_t_pad, cur_b_pad);
                    cur_t_pad = nstl::max(0, cur_t_pad - oh_step);
                    cur_b_pad = nstl::max(0, cur_b_pad - oh_step);
                }
            }
            last_copied_mb = mb;
            last_copied_ohc = ohc;
            last_copied_owb = owb;
            last_copied_g = g;
            ++start;
            // need "inner" oh blocks w.r.t. ow blocks to allow pbuffer reuse
            nd_iterator_step(mb, jcp.mb, g, jcp.ngroups, owb, jcp.nb_ow, ohc,
                    oh_chunks, occ, oc_chunks);
        }

        amx_tile_release();
    });
    return status::success;
}
421 | |
// General forward execution path of the AMX convolution (1D/2D/3D,
// non-relo). Input rows are copied into a per-thread physically padded
// buffer ("pbuffer") before invoking the compute kernel; an optional
// zero-point padding buffer handles non-trivial src zero points in the
// padded output regions.
status_t jit_avx512_core_amx_convolution_fwd_t::execute_forward(
        const exec_ctx_t &ctx) const {
    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper bias_d(pd()->weights_md(1));

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    const size_t bia_dt_size = pd()->with_bias()
            ? types::data_type_size(pd()->desc()->bias_desc.data_type)
            : 0;
    const size_t dst_dt_size
            = types::data_type_size(pd()->desc()->dst_desc.data_type);
    const size_t src_dt_size
            = types::data_type_size(pd()->desc()->src_desc.data_type);
    const size_t wei_dt_size
            = types::data_type_size(pd()->desc()->weights_desc.data_type);

    prepare_padded_bias(bias, ctx.get_scratchpad_grantor());

    const auto &jcp = pd()->jcp_;
    assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    // Combined src*wei output scales, precomputed into the scratchpad.
    const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
            src_scales, wei_scales, pd()->OC(), pd()->attr());

    // TODO: use block offset instead of hand-calculated one
    //size_t wei_oc_shift = wht_blk_off(weights_d, 0, 1);
    // byte-less strides (in elements) between consecutive oc chunks and
    // consecutive kd slices of the weights
    const size_t wei_oc_shift = (size_t)jcp.nb_oc_blocking * jcp.nb_ic_int
            * jcp.kd * jcp.kh * jcp.kw * jcp.ic_block_int_np * jcp.oc_block;
    const size_t wei_d_shift
            = (size_t)jcp.kh * jcp.kw * jcp.ic_block_int_np * jcp.oc_block;

    // Scratchpad buffers: per-thread padded input copy, per-thread int32
    // accumulators, AMX tile configuration blob, and (optionally) the
    // zero-point padding buffer with its per-thread lazy-compute flags.
    auto inp_p_buffer = ctx.get_scratchpad_grantor().template get<char>(
            key_conv_amx_inp_buffer); // fix the template
    auto wsp = ctx.get_scratchpad_grantor().template get<int32_t>(
            key_conv_amx_wsp_buffer);
    auto tcfg = ctx.get_scratchpad_grantor().template get<char>(
            key_conv_amx_tilecfg);
    auto zero_point_pbuff = ctx.get_scratchpad_grantor().template get<int32_t>(
            key_conv_zero_point_pad);
    auto zp_flags_ = ctx.get_scratchpad_grantor().template get<bool>(
            key_conv_zero_point_flag);

    // Source zero-point compensation is appended to the weights blob.
    const size_t offset = weights_d.size() - weights_d.additional_buffer_size();
    char *w = const_cast<char *>(weights);
    int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
            : nullptr;

    // Padded-region boundaries along od (front/back) and oh (top/bottom);
    // used to index the compressed zero-point padding buffer.
    const int f_pad_output = jcp.f_pad_output;
    const int back_pad_output = jcp.back_pad_output;
    const int back_pad_start
            = nstl::max(jcp.od - back_pad_output, f_pad_output);
    const int zp_buff_back_pad_start
            = nstl::max(jcp.od_pad - back_pad_output, f_pad_output);
    const int t_pad_output = jcp.t_pad_output;
    const int b_pad_output = jcp.b_pad_output;
    const int b_pad_start = nstl::max(jcp.oh - b_pad_output, t_pad_output);
    const int zp_buff_b_pad_start
            = nstl::max(jcp.oh_pad - b_pad_output, t_pad_output);

    const int ngroups = jcp.ngroups;
    const int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
    const int oh_chunks = utils::div_up(jcp.oh, jcp.oh_blk_size);
    const size_t work_amount = (size_t)jcp.mb * jcp.ngroups * jcp.od * oh_chunks
            * jcp.nb_ow * oc_chunks;
    const int zp_pbuff_size = jcp.zp_pbuff_size;

    // Initialize the tile configuration in memory, so that each thread can
    // load this configuration from memory via `amx_tile_configure(tcfg)`.
    kernel_->tile_configure(tcfg);

    // init zero_point padding buffer: when 'zp_pbuff_outer_compute' is set
    // the shared buffer is filled once up-front here; otherwise each thread
    // lazily fills a private copy inside the main parallel section below.
    const bool req_zero_point_buffer = jcp.req_zero_point_buffer;
    const bool zp_pbuff_outer_compute = jcp.zp_pbuff_outer_compute;
    const bool zp_pbuff_parallel_block
            = req_zero_point_buffer && !zp_pbuff_outer_compute;
    if (req_zero_point_buffer && zp_pbuff_outer_compute) {
        const int sp_stride = mem_blk_off(dst_d, 0, 0, 0, 0, 1);
        const int dilate_d = jcp.dilate_d + 1;
        const int gen_kd = (jcp.kd - 1) * dilate_d + 1;
        const int dilate_h = jcp.dilate_h + 1;
        const int gen_kh = (jcp.kh - 1) * dilate_h + 1;
        const int od_work = jcp.od_pad;
        const int oh_work = jcp.oh_pad;
        parallel_nd(ngroups, oc_chunks, od_work, oh_work,
                [&](dim_t g, dim_t occ, dim_t od, dim_t oh) {
                    auto p = jit_conv_call_s();

                    // map compressed pad-buffer depth index to real od row
                    const int od_ = od >= zp_buff_back_pad_start
                            ? back_pad_start + od - zp_buff_back_pad_start
                            : od;
                    const int id = od_ * jcp.stride_d - jcp.f_pad;
                    const int f_overflow
                            = nstl::min(jcp.kd, div_up(max(0, -id), dilate_d));
                    const int back_overflow = nstl::min(jcp.kd,
                            div_up(nstl::max(0, id + gen_kd - jcp.id),
                                    dilate_d));
                    p.f_overflow = f_overflow;
                    p.back_overflow = back_overflow;
                    p.kd_padding
                            = nstl::max(0, jcp.kd - f_overflow - back_overflow);
                    // map compressed pad-buffer row index to real oh row
                    const int oh_ = oh >= zp_buff_b_pad_start
                            ? b_pad_start + oh - zp_buff_b_pad_start
                            : oh;
                    const int ih = oh_ * jcp.stride_h - jcp.t_pad;
                    const int t_overflow
                            = nstl::min(jcp.kh, div_up(max(0, -ih), dilate_h));
                    const int b_overflow = nstl::min(jcp.kh,
                            div_up(nstl::max(0, ih + gen_kh - jcp.ih),
                                    dilate_h));
                    p.t_overflow = t_overflow;
                    p.b_overflow = b_overflow;
                    p.kh_padding
                            = nstl::max(0, jcp.kh - t_overflow - b_overflow);

                    const int ocb = g * jcp.oc
                            + occ * jcp.nb_oc_blocking * jcp.oc_block;
                    const size_t ch_offset = dst_d.blk_off(0, ocb);
                    auto sp_offset
                            = (od * jcp.oh_pad + oh) * jcp.ow_pad * sp_stride;
                    p.zero_point_pbuff
                            = &zero_point_pbuff[ch_offset + sp_offset];
                    p.oc_blocks = occ * jcp.nb_oc_blocking;
                    p.filt = weights
                            + wei_dt_size * (g * oc_chunks + occ)
                                    * wei_oc_shift;
                    p.src_zero_point = src_zero_point;

                    kernel_->zp_pbuff_kernel()(&p);
                });
    }

    // TODO: implement 2D parallelization driver (g * spatial x oc) to increase
    // input data reuse and parallelize input data reorders
    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
        size_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);
        // shared pre-computed zero-point buffer vs. this thread's slice
        int32_t *local_zp_pbuff = req_zero_point_buffer
                ? (zp_pbuff_outer_compute
                                ? zero_point_pbuff
                                : &zero_point_pbuff[ithr * zp_pbuff_size])
                : nullptr;
        bool *zp_flags = zp_pbuff_parallel_block
                ? &zp_flags_[ithr * oc_chunks * ngroups]
                : nullptr;
        if (zp_pbuff_parallel_block) {
            // mark every (g, occ) zero-point slice as not-yet-computed
            PRAGMA_OMP_SIMD()
            for (int oc = 0; oc < oc_chunks * ngroups; oc++)
                zp_flags[oc] = true;
        }

        // 'p' is reused across all kernel invocations of this thread; each
        // iteration only overwrites the fields it needs.
        auto p = jit_conv_call_s();
        amx_tile_configure(tcfg);

        const int oh_work = jcp.oh_pad;
        const int sp_stride = mem_blk_off(dst_d, 0, 0, 0, 0, 1);
        const int dilate_d = jcp.dilate_d + 1;
        const int dilate_h = jcp.dilate_h + 1;
        const int gen_kh = (jcp.kh - 1) * dilate_h + 1;
        size_t oc_stride = dst_d.blk_off(0, 1);
        const int owb_limit = jcp.nb_ow - jcp.r_pad_blk - jcp.no_pad_w_blk;

        int mb {0}, g {0}, odc {0}, ohc {0}, owb {0}, occ {0};
        nd_iterator_init(start, mb, jcp.mb, g, jcp.ngroups, odc, jcp.od, ohc,
                oh_chunks, owb, jcp.nb_ow, occ, oc_chunks);
        int last_copied_mb = -1;
        int last_copied_odc = -1;
        int last_copied_ohc = -1;
        int last_copied_owb = -1;
        int last_copied_g = -1;
        while (start < end) {
            char *inp_buffer
                    = inp_p_buffer + src_dt_size * ithr * jcp.inp_buffer_size;

            assert(IMPLICATION(
                    jcp.ngroups > 1, jcp.oc == jcp.oc_without_padding));
            int oc = g * jcp.oc + occ * jcp.nb_oc_blocking * jcp.oc_block;
            int ocb = jcp.is_nspc ? oc : oc / jcp.oc_block;
            const char *bias_w = bias
                    ? bias + (bias_d.blk_off(oc) * bia_dt_size)
                    : nullptr;
            p.zp_compensation
                    = jcp.src_zero_point ? zp_compensation + oc : nullptr;
            p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
            p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr;

            const size_t inp_src_d_stride = mem_blk_off(src_d, 0, 0, 1, 0, 0);
            int oh_s = ohc * jcp.oh_blk_size;
            int oh_e = nstl::min(jcp.oh, oh_s + jcp.oh_blk_size);
            // pbuffer contents from the previous work item can be reused
            // wholesale when the (mb, odc, ohc, owb, g) region is identical
            bool is_inp_buffer_relevant = true && last_copied_mb == mb
                    && last_copied_odc == odc && last_copied_ohc == ohc
                    && last_copied_owb == owb && last_copied_g == g;
            bool is_zp_pbuff_relevant = zp_pbuff_parallel_block
                    ? zp_flags[g * oc_chunks + occ] // already computed?
                    : false;

            int cur_t_pad = nstl::max(0, t_pad_output - oh_s);
            int cur_b_pad = nstl::max(
                    nstl::max(0, jcp.oh - b_pad_output - oh_s), cur_t_pad);
            // compressed pad-buffer indices corresponding to (odc, oh_s)
            const size_t zp_od = odc >= back_pad_start
                    ? odc - back_pad_start + zp_buff_back_pad_start
                    : nstl::min(f_pad_output, odc);
            size_t zp_oh
                    = accum_with_upper_bound(oh_s, t_pad_output, b_pad_start);

            int limit_idx = 0;
            constexpr int limit_size = 5;
            for (; req_zero_point_buffer && limit_idx < limit_size;
                    limit_idx++) {
                // find current 'oh_blk' index from 'oh_s`
                if ((size_t)oh_s < jcp.h_blk_limits[limit_idx]) break;
            }

            if (is_zp_pbuff_relevant) {
                // lazily compute this thread's zero-point pad buffer for the
                // current (g, occ) exactly once, then clear the flag
                // (2D-only path: 3D uses the up-front outer compute)
                assert(pd()->ndims() != 5);
                assert(!zp_pbuff_outer_compute);
                zp_flags[g * oc_chunks + occ] = false;
                for (int oh_pad = 0; oh_pad < oh_work; ++oh_pad) {
                    const int oh_ = oh_pad >= zp_buff_b_pad_start
                            ? b_pad_start + oh_pad - zp_buff_b_pad_start
                            : oh_pad;
                    const int ih = oh_ * jcp.stride_h - jcp.t_pad;
                    const int t_overflow
                            = nstl::min(jcp.kh, div_up(max(0, -ih), dilate_h));
                    const int b_overflow = nstl::min(jcp.kh,
                            div_up(nstl::max(0, ih + gen_kh - jcp.ih),
                                    dilate_h));
                    p.t_overflow = t_overflow;
                    p.b_overflow = b_overflow;
                    p.kh_padding
                            = nstl::max(0, jcp.kh - t_overflow - b_overflow);

                    const size_t ch_offset = dst_d.blk_off(0, oc);
                    const size_t sp_offset = oh_pad * jcp.ow_pad * sp_stride;
                    p.zero_point_pbuff = &local_zp_pbuff[ch_offset + sp_offset];
                    p.oc_blocks = occ * jcp.nb_oc_blocking;
                    p.filt = weights
                            + wei_dt_size * (g * oc_chunks + occ)
                                    * wei_oc_shift;
                    p.src_zero_point = src_zero_point;

                    kernel_->zp_pbuff_kernel()(&p);
                }
            }

            // depth-dimension overflow for the current od chunk
            const int id_s = odc * jcp.stride_d - jcp.f_pad;
            const int d_f_overflow
                    = nstl::min(jcp.kd, div_up(max(0, -id_s), dilate_d));
            const int d_back_overflow = nstl::min(jcp.kd,
                    div_up(max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1),
                            dilate_d));
            p.kd_padding
                    = nstl::max(0, jcp.kd - d_f_overflow - d_back_overflow);

            int oh_step = jcp.nb_oh_blocking * jcp.oh_per_tile;
            for (int oh = oh_s; oh < oh_e; oh += oh_step) {
                const int gen_kh = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1);
                const int gen_stride_h = nstl::min(jcp.stride_h, gen_kh);
                if (!is_inp_buffer_relevant) {
                    const int iw = nstl::max(
                            0, owb * jcp.ow_block * jcp.stride_w - jcp.l_pad);

                    assert(IMPLICATION(
                            jcp.ngroups > 1, jcp.ic == jcp.ic_without_padding));
                    const int icb = g * (jcp.is_nspc ? jcp.ic : jcp.nb_ic);

                    // generalized kh including dilation
                    // the current implementation of copy routine is not
                    // optimal for small jcp.oh_blk_size as it copies
                    // dilation rows to buffer
                    const bool continuous_copy = gen_kh >= jcp.stride_h;
                    int current_oh_block = nstl::min(oh_e - oh, oh_step);
                    int num_copy_calls = continuous_copy ? 1 : current_oh_block;
                    for (int ohi = 0; ohi < num_copy_calls; ohi++) {
                        int ih_copy_start
                                = (oh + ohi) * jcp.stride_h - jcp.t_pad;
                        int ih_copy_end = ih_copy_start + gen_kh;
                        if (continuous_copy) {
                            ih_copy_end
                                    += jcp.stride_h * (current_oh_block - 1);
                            if (oh > oh_s)
                                // it's non-first block, shift start to the end
                                // of previous block
                                // ih_copy_end_prev =
                                //     (oh - 1) * str_h - t_pad + kh
                                ih_copy_start += gen_kh - jcp.stride_h;
                        }
                        int ih_zero_top = nstl::max(0, -ih_copy_start);
                        int ih_zero_bottom = nstl::max(0, ih_copy_end - jcp.ih);
                        // how many real data rows to copy (including padding)
                        int rows_to_copy = ih_copy_end - ih_copy_start;
                        p.kh_padding = max(0, rows_to_copy);
                        p.t_overflow = ih_zero_top;
                        p.b_overflow = ih_zero_bottom;
                        p.owb = owb;
                        int ih = nstl::max(ih_copy_start, 0);
                        size_t inp_offset
                                = mem_blk_off(src_d, mb, icb, id_s, ih, iw)
                                + d_f_overflow * dilate_d * inp_src_d_stride;
                        p.src = src + src_dt_size * inp_offset;
                        // inp_buffer has physical padding
                        int ih_buf = continuous_copy
                                ? ih_copy_start + jcp.t_pad
                                        - oh_s * jcp.stride_h
                                : gen_stride_h * (oh + ohi - oh_s);
                        p.dst = inp_buffer
                                + src_dt_size * ih_buf * jcp.iwp
                                        * jcp.ic_block_int_np;

                        kernel_->copy_to_pbuffer()(&p);
                    }
                }
                int ih_buf = gen_stride_h * (oh - oh_s);
                int ow = owb * jcp.ow_block;
                p.src = inp_buffer
                        + src_dt_size * ih_buf * jcp.iwp * jcp.ic_block_int_np;

                size_t dst_offset = mem_blk_off(dst_d, mb, ocb, odc, oh, ow);
                p.dst = dst + dst_dt_size * dst_offset;
                const size_t pbuff_offset
                        = (zp_od * jcp.oh_pad + zp_oh) * jcp.ow_pad * sp_stride
                        + ocb * oc_stride;
                p.zero_point_pbuff = req_zero_point_buffer
                        ? &local_zp_pbuff[pbuff_offset]
                        : nullptr;

                // skip fully-overflowed leading kd slices of the weights
                p.filt = weights
                        + ((g * oc_chunks + occ) * wei_oc_shift
                                  + d_f_overflow * wei_d_shift)
                                * wei_dt_size;
                p.bias = bias_w;
                p.scales = &oscales[jcp.is_oc_scale * oc];
                p.dst_scale = &dst_scales[0];

                p.acc_s32 = wsp + ithr * jcp.wsp_buffer_size;
                // advance to the next h-block once 'oh' crosses its limit;
                // the kernel uses 'ohb' to select the padding sub-region
                if (req_zero_point_buffer
                        && (size_t)oh >= jcp.h_blk_limits[limit_idx])
                    limit_idx++;
                assert(limit_idx < 6);
                p.ohb = limit_idx;
                p.last_h = (oh + oh_step <= oh_e);
                const int zp_owb = nstl::min(jcp.l_pad_blk, owb)
                        + nstl::max(0, owb - owb_limit);
                p.owb = req_zero_point_buffer ? zp_owb : owb;

                p.oc_blocks = occ * jcp.nb_oc_blocking;

                p.oc_l_off = oc;
                p.post_ops_binary_rhs_arg_vec
                        = post_ops_binary_rhs_arg_vec.data();
                p.dst_orig = dst;

                (*kernel_)(&p);

                if (req_zero_point_buffer) {
                    zp_oh += accum_with_upper_bound(
                            oh_step, cur_t_pad, cur_b_pad);
                    cur_t_pad = nstl::max(0, cur_t_pad - oh_step);
                    cur_b_pad = nstl::max(0, cur_b_pad - oh_step);
                }
            }
            last_copied_mb = mb;
            last_copied_odc = odc;
            last_copied_ohc = ohc;
            last_copied_owb = owb;
            last_copied_g = g;
            ++start;
            nd_iterator_step(mb, jcp.mb, g, jcp.ngroups, odc, jcp.od, ohc,
                    oh_chunks, owb, jcp.nb_ow, occ, oc_chunks);
        }

        amx_tile_release();
    });
    return status::success;
}
812 | |
template <data_type_t diff_src_type, data_type_t wei_type,
        data_type_t diff_dst_type>
status_t jit_avx512_core_amx_convolution_bwd_data_t<diff_src_type, wei_type,
        diff_dst_type>::execute_backward(const exec_ctx_t &ctx) const {
    // Backward-by-data entry point: fetch tensor pointers and scales from
    // the execution context, then delegate the whole driver loop to the
    // shared amx_utils implementation.
    const auto diff_dst = CTX_IN_MEM(const char *, DNNL_ARG_DIFF_DST);
    const auto weights = CTX_IN_MEM(const char *, DNNL_ARG_WEIGHTS);
    auto diff_src = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    // unused in kernel for bf16, but attributes have scales buffer by default
    // and using it here simplifies the shared `execute_backward_loop`.
    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    // Fold src and weights scales into a single per-OC array (scratchpad
    // backed) so the kernel applies one multiplier per output channel.
    const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(),
            src_scales, wei_scales, pd()->OC(), pd()->attr());

    // This implementation has no bias; pass null pointer / empty descriptor.
    amx_utils::execute_backward_convolution_body(ctx, pd()->jcp_, kernel_,
            diff_dst, weights, nullptr /* no bias */, oscales, dst_scales,
            diff_src, diff_dst_d, weights_d,
            memory_desc_wrapper(nullptr) /* no bias */, diff_src_d);
    return status::success;
}
840 | |
// Explicit instantiations for the supported backward-data type
// combinations: <diff_src_type, wei_type, diff_dst_type>.
template struct jit_avx512_core_amx_convolution_bwd_data_t<data_type::bf16,
        data_type::bf16, data_type::bf16>;
template struct jit_avx512_core_amx_convolution_bwd_data_t<data_type::f32,
        data_type::bf16, data_type::bf16>;
845 | |
// One-time primitive initialization: cache the thread decomposition chosen
// at pd-creation time and generate all JIT kernels this primitive needs.
// Returns status::success, or the first failing CHECK's status.
status_t jit_avx512_core_amx_convolution_bwd_weights_t::init(engine_t *engine) {
    const auto &j = pd()->jcp_;

    // Thread partitioning over {mb, groups, oc blocks, ic blocks}, decided
    // during configuration; stored locally for use in thread_info_t.
    nthr_ = j.nthr;
    nthr_mb_ = j.nthr_mb;
    nthr_g_ = j.nthr_g;
    nthr_oc_b_ = j.nthr_oc_b;
    nthr_ic_b_ = j.nthr_ic_b;

    // Main diff-weights compute kernel.
    CHECK(safe_ptr_assign(
            kernel_, new jit_avx512_core_amx_bwd_weights_kernel_t(j)));
    CHECK(kernel_->create_kernel());

    // Transposition kernels that lay out src / diff_dst into the AMX-friendly
    // scratch formats consumed by the compute kernel.
    CHECK(safe_ptr_assign(trans_kernel_, create_trans_src(&j)));
    CHECK(trans_kernel_->create_kernel());
    CHECK(safe_ptr_assign(trans_dst_kernel_, create_trans_dst(&j)));
    CHECK(trans_dst_kernel_->create_kernel());
    // Accumulator is needed only when several threads reduce over minibatch.
    if (nthr_mb_ > 1) {
        CHECK(safe_ptr_assign(
                acc_ker_, new cpu_accumulator_1d_t<data_type::f32>()));
        CHECK(acc_ker_->create_kernel());
    }
    // Optional post-processing kernel converting diff weights to VNNI layout.
    if (j.transform_to_vnni) {
        CHECK(safe_ptr_assign(diff_wei_trans_kernel_,
                new jit_diff_wei_trans_to_vnni_t(
                        j.wei_dt, j.kd, j.kh, j.kw, j.ic_block, j.oc_block)));
        CHECK(diff_wei_trans_kernel_->create_kernel());
    }
    return status::success;
}
876 | |
// Per-thread execution context for backward-weights: resolves tensor
// pointers, scratchpad buffers/barriers, and this thread's slice of the
// {mb, group, oc-block, ic-block} work decomposition.
struct jit_avx512_core_amx_convolution_bwd_weights_t::thread_info_t {
    const src_data_t *src = nullptr;
    const diff_dst_data_t *diff_dst = nullptr;
    const void *diff_weights = nullptr;
    const void *diff_bias = nullptr;

    const memory_tracking::grantor_t scratchpad;

    // Transposed-src / transposed-diff_dst scratch buffers and the barriers
    // used to synchronize their producers with consumers (global transpose).
    src_data_t *tr_src = nullptr;
    diff_dst_data_t *tr_diff_dst = nullptr;
    simple_barrier::ctx_t *tr_src_bctx = nullptr;
    simple_barrier::ctx_t *tr_diff_dst_bctx = nullptr;

    // Reduction buffers for accumulating partial diff_weights / diff_bias
    // across mb-threads, plus the barrier guarding the reduction step.
    float *wei_bia_reduction = nullptr;
    float *bia_reduction = nullptr;
    simple_barrier::ctx_t *wei_bia_reduction_bctx = nullptr;

    int ithr = 0;
    // This thread's coordinate along each axis of the 4D thread grid.
    int ithr_ic_b = 0, ithr_oc_b = 0, ithr_g = 0, ithr_mb = 0;
    // Linear thread index with the oc (resp. ic) axis collapsed out; used to
    // pick the right barrier instance for each transpose phase.
    int ithr_but_oc = 0;
    int ithr_but_ic = 0;

    // [start, end) work ranges assigned to this thread per dimension.
    int img_start = 0, img_end = 0, img_work = 0;
    int g_start = 0, g_end = 0, g_work = 0;
    int oc_b_start = 0, oc_b_end = 0, oc_b_work = 0;
    int ic_b_start = 0, ic_b_end = 0, ic_b_work = 0;

    thread_info_t(const jit_avx512_core_amx_convolution_bwd_weights_t *self,
            const exec_ctx_t &ctx, int ithr)
        : scratchpad(ctx.get_scratchpad_grantor()), ithr(ithr) {
        diff_dst = CTX_IN_MEM(const diff_dst_data_t *, DNNL_ARG_DIFF_DST);
        src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
        diff_weights = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_WEIGHTS);

        const auto &jcp = self->kernel_->jcp;
        // When oc has a tail and bias is f32, accumulate into a padded
        // scratch buffer; otherwise write straight into the user's buffer.
        diff_bias = self->pd()->with_bias()
                        && (jcp.oc_without_padding % jcp.oc_block != 0)
                        && self->pd()->jcp_.bia_dt == data_type::f32
                ? (void *)scratchpad.template get<float>(key_conv_padded_bias)
                : CTX_OUT_MEM(void *, DNNL_ARG_DIFF_BIAS);

        tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
        // Barriers are only allocated for the global-transpose scheme.
        if (jcp.global_transpose)
            tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                    key_conv_tr_src_bctx);

        tr_diff_dst = scratchpad.template get<diff_dst_data_t>(
                key_conv_tr_diff_dst);
        if (jcp.global_transpose)
            tr_diff_dst_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                    key_conv_tr_diff_dst_bctx);
        wei_bia_reduction
                = scratchpad.template get<float>(key_conv_wei_bia_reduction);
        bia_reduction = nullptr;
        if (jcp.with_bias) {
            const size_t wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block
                    * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
            // bf16 weights need a buffer per mb-thread (the f32 accumulation
            // cannot reuse the user's bf16 buffer); otherwise thread 0
            // accumulates in place, so one fewer buffer is needed.
            const int num_wei_buffers = jcp.wei_dt == data_type::bf16
                    ? jcp.nthr_mb
                    : jcp.nthr_mb - 1;
            // Bias reduction space lives right after the weights buffers.
            bia_reduction = wei_bia_reduction + wei_size * num_wei_buffers;
        }

        wei_bia_reduction_bctx = scratchpad.template get<simple_barrier::ctx_t>(
                key_conv_wei_bia_reduction_bctx);

        // Decompose the flat thread id into (mb, g, oc_b, ic_b) coordinates;
        // ic_b is the fastest-varying axis.
        ithr_ic_b = ithr % self->nthr_ic_b_;
        ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_;
        ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_;
        ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_;

        ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_
                + ithr_ic_b;

        ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_
                + ithr_oc_b;

        int work_amount = jcp.nthr_mb_work;
        /* reduction dimension */
        balance211(work_amount, self->nthr_mb_, ithr_mb, img_start, img_end);
        img_work = img_end - img_start;

        /* independent dimensions */
        balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end);
        g_work = g_end - g_start;

        balance211(
                jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start, oc_b_end);
        oc_b_work = oc_b_end - oc_b_start;

        balance211(
                jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start, ic_b_end);
        // VNNI output transform processes ic blocks in pairs, so align the
        // per-thread ic range to even block boundaries.
        if (jcp.transform_to_vnni) {
            if (ic_b_start % 2 != 0) ic_b_start++;
            if (ic_b_end != jcp.nb_ic && ic_b_end % 2 != 0) ic_b_end++;
        }
        ic_b_work = ic_b_end - ic_b_start;
    }
};
976 | |
977 | size_t jit_avx512_core_amx_convolution_bwd_weights_t::tr_src_buf_number( |
978 | const thread_info_t *ti, int g, int ic) const { |
979 | const jit_conv_conf_t &jcp = this->kernel_->jcp; |
980 | return jcp.global_transpose |
981 | ? ti->ithr_mb * jcp.nb_ic * jcp.ngroups + g * jcp.nb_ic + ic |
982 | : ti->ithr; |
983 | } |
984 | |
985 | size_t jit_avx512_core_amx_convolution_bwd_weights_t::tr_diff_dst_buf_number( |
986 | const thread_info_t *ti, int g, int oc) const { |
987 | const jit_conv_conf_t &jcp = this->kernel_->jcp; |
988 | return jcp.global_transpose |
989 | ? ti->ithr_mb * jcp.nb_oc * jcp.ngroups + g * jcp.nb_oc + oc |
990 | : ti->ithr; |
991 | } |
992 | |
993 | void jit_avx512_core_amx_convolution_bwd_weights_t::trans_src_nxc( |
994 | src_data_t *tr_src, const src_data_t *src_base, int spatial_start, |
995 | dim_t spatial_start_offset, int icb_start, dim_t chb_stride, |
996 | int row_count) const { |
997 | const jit_conv_conf_t &jcp = this->kernel_->jcp; |
998 | const int src_stride = jcp.iw * jcp.ngroups * jcp.ic; |
999 | const int tr_src_stride = jcp.tr_iw * jcp.ic_block; |
1000 | |
1001 | int work_rest = row_count; |
1002 | int max_spatial_work = jcp.id * jcp.ih; |
1003 | int sp_work = nstl::min(work_rest, max_spatial_work - spatial_start); |
1004 | const src_data_t *src = src_base + spatial_start_offset; |
1005 | int icb = 0; |
1006 | const int ic_tail_work = jcp.ic_tail ? jcp.ic_tail : jcp.ic_block; |
1007 | while (work_rest > 0) { |
1008 | for (int iwork = 0; iwork < sp_work; iwork++) { |
1009 | auto ctx = jit_trans_src_t::ctx_t(); |
1010 | ctx.src = src; |
1011 | ctx.tr_src = tr_src; |
1012 | assert(icb_start + icb < jcp.nb_ic); |
1013 | ctx.ch_work = (icb_start + icb + 1) == jcp.nb_ic ? ic_tail_work |
1014 | : jcp.ic_block; |
1015 | ctx.src_prf = nullptr; |
1016 | ctx.tr_src_prf = nullptr; |
1017 | (*trans_kernel_)(&ctx); |
1018 | src += src_stride; |
1019 | tr_src += tr_src_stride; |
1020 | } |
1021 | work_rest -= sp_work; |
1022 | sp_work = nstl::min(work_rest, max_spatial_work); |
1023 | icb++; |
1024 | src = src_base + icb * chb_stride; |
1025 | } |
1026 | } |
1027 | |
1028 | void jit_avx512_core_amx_convolution_bwd_weights_t::trans_dst_nxc( |
1029 | diff_dst_data_t *tr_diff_dst, const diff_dst_data_t *diff_dst_base, |
1030 | int spatial_start, dim_t spatial_start_offset, int ocb_start, |
1031 | dim_t chb_stride, int row_count) const { |
1032 | const jit_conv_conf_t &jcp = this->kernel_->jcp; |
1033 | const int diff_dst_stride = jcp.ow * jcp.ngroups * jcp.oc; |
1034 | const int tr_diff_dst_stride = jcp.tr_ow * jcp.oc_block; |
1035 | int work_rest = row_count; |
1036 | int max_spatial_work = jcp.od * jcp.oh; |
1037 | int sp_work = nstl::min(work_rest, max_spatial_work - spatial_start); |
1038 | const src_data_t *diff_dst = diff_dst_base + spatial_start_offset; |
1039 | int ocb = 0; |
1040 | const int oc_tail_work = jcp.oc_tail ? jcp.oc_tail : jcp.oc_block; |
1041 | while (work_rest > 0) { |
1042 | for (int iwork = 0; iwork < sp_work; iwork++) { |
1043 | auto ctx = jit_trans_dst_t::ctx_t(); |
1044 | ctx.src = diff_dst; |
1045 | ctx.tr_src = tr_diff_dst; |
1046 | assert(ocb_start + ocb < jcp.nb_oc); |
1047 | ctx.ch_work = (ocb_start + ocb + 1) == jcp.nb_oc ? oc_tail_work |
1048 | : jcp.oc_block; |
1049 | ctx.src_prf = nullptr; |
1050 | ctx.tr_src_prf = nullptr; |
1051 | (*trans_dst_kernel_)(&ctx); |
1052 | diff_dst += diff_dst_stride; |
1053 | tr_diff_dst += tr_diff_dst_stride; |
1054 | } |
1055 | work_rest -= sp_work; |
1056 | sp_work = nstl::min(work_rest, max_spatial_work); |
1057 | ocb++; |
1058 | diff_dst = diff_dst_base + ocb * chb_stride; |
1059 | } |
1060 | } |
1061 | |
// Compute this thread's diff-weights (and optionally diff-bias) partial sums
// for the 2D case, iterating over (mb image, oh chunk) work items.
// For each item: transpose src / diff_dst into scratch (either a barrier-
// synchronized "global" transpose shared across threads, or a local one just
// before the kernel call), then invoke the JIT kernel per
// (oh block, group, oc block, ic block).
void jit_avx512_core_amx_convolution_bwd_weights_t::compute_diff_weights_2d(
        const thread_info_t *ti) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
    const auto &jcp = kernel_->jcp;

    const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
            * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
    const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
    const int optimal_hblock = jcp.spatial_blk_size;

    // Pick the accumulation destination: bf16 weights always go through a
    // per-mb-thread f32 reduction buffer; for f32, mb-thread 0 accumulates
    // directly into the user buffer and others use reduction buffers.
    float *diff_wei;
    if (diff_weights_d.data_type() == data_type::bf16)
        diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size;
    else
        diff_wei = ti->ithr_mb == 0
                ? (float *)ti->diff_weights
                : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;

    // Same scheme for bias.
    float *diff_bias = nullptr;
    if (jcp.with_bias) {
        if (jcp.bia_dt == data_type::bf16)
            diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size;
        else
            diff_bias = ti->ithr_mb == 0
                    ? (float *)ti->diff_bias
                    : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size;
    }

    // Offset (in elements) of output row `oj` inside the tr_diff_dst buffer
    // for (g, oc). Without global transpose, nb_oc_blocking buffers are
    // grouped per thread.
    auto tr_diff_dst_off = [&](int g, int oc, int oj) {
        const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
        int adj = (jcp.global_transpose) ? 1 : jcp.nb_oc_blocking;
        return tr_diff_dst_buf_number(ti, g, oc) * adj
                * jcp.tr_diff_dst_buf_size
                + oj * tr_row_size;
    };

    int img {0}, oh_s {0};
    int start = ti->img_start;
    int end = ti->img_end;

    int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);

    // Flat work index decomposed as (image, oh row).
    nd_iterator_init(start, img, jcp.mb, oh_s, jcp.oh);

    while (start < end) {
        auto p = jit_conv_call_s();
        int work_rem = end - start;
        // Output-row range of this work item, clipped to the image.
        const int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
        // Corresponding input-row range, accounting for top padding and the
        // (dilation-extended) filter footprint.
        int ih_s = nstl::max(0, -jcp.t_pad + oh_s * jcp.stride_h);
        const int ih_e = nstl::min(
                jcp.ih, -jcp.t_pad + (oh_e - 1) * jcp.stride_h + ext_kh);

        auto tr_src_off = [&](int g, int ic, int ih_end, int ij) {
            const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
            int adj = (jcp.global_transpose) ? 1 : jcp.nb_ic_blocking;
            // Aligned to buffer end to use guard elements
            return tr_src_buf_number(ti, g, ic) * adj * jcp.tr_src_buf_size
                    + (jcp.ih - ih_end + ij) * tr_row_size;
        };

        if (jcp.global_transpose) {
            using simple_barrier::barrier;
            // TODO: try to call local transpositions just before jit kernel
            /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */
            // Phase 1: all oc-threads cooperatively transpose this thread
            // group's src rows; work is re-balanced over nthr_oc_b_ threads
            // and fenced by barriers on both sides.
            int j {0};
            int work_amount = ti->g_work * ti->ic_b_work * (ih_e - ih_s);
            int tr_start {0}, tr_end {0};
            balance211(
                    work_amount, nthr_oc_b_, ti->ithr_oc_b, tr_start, tr_end);

            int g {0}, ic_b {0};
            nd_iterator_init(tr_start, g, ti->g_work, ic_b, ti->ic_b_work, j,
                    ih_e - ih_s);

            if (nthr_oc_b_ > 1)
                barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
            while (tr_start < tr_end) {
                int g_ = g + ti->g_start;
                int ic_b_ = ic_b + ti->ic_b_start;
                int j_s = j + ih_s;
                int j_e = j_s + nstl::min(tr_end - tr_start, ih_e - j_s);
                const int ic_off_idx = g_ * jcp.ic + ic_b_ * jcp.ic_block;
                const src_data_t *src
                        = &ti->src[src_d.blk_off(img, ic_off_idx, j_s)];
                src_data_t *tr_src
                        = &ti->tr_src[tr_src_off(g_, ic_b_, ih_e, j_s)];
                trans_src_nxc(tr_src, src, 0, 0, ic_b_, 0, j_e - j_s);
                nd_iterator_jump(tr_start, tr_end, g, ti->g_work, ic_b,
                        ti->ic_b_work, j, ih_e - ih_s);
            }
            if (nthr_oc_b_ > 1)
                barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);

            // Phase 2: symmetric transpose of diff_dst rows, balanced over
            // the nthr_ic_b_ threads instead.
            j = 0;
            work_amount = ti->g_work * ti->oc_b_work * (oh_e - oh_s);
            tr_start = 0;
            tr_end = 0;
            balance211(
                    work_amount, nthr_ic_b_, ti->ithr_ic_b, tr_start, tr_end);

            g = 0;
            int oc_b = 0;
            nd_iterator_init(tr_start, g, ti->g_work, oc_b, ti->oc_b_work, j,
                    oh_e - oh_s);

            if (nthr_ic_b_ > 1)
                barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
            while (tr_start < tr_end) {
                int g_ = g + ti->g_start;
                int oc_b_ = oc_b + ti->oc_b_start;
                int j_s = j + oh_s;
                int j_e = j_s + nstl::min(tr_end - tr_start, oh_e - j_s);
                const int oc_off_idx = g_ * jcp.oc + oc_b_ * jcp.oc_block;
                const diff_dst_data_t *diff_dst
                        = &ti->diff_dst[diff_dst_d.blk_off(
                                img, oc_off_idx, j_s)];
                diff_dst_data_t *tr_diff_dst
                        = &ti->tr_diff_dst[tr_diff_dst_off(g_, oc_b_, j_s)];

                trans_dst_nxc(tr_diff_dst, diff_dst, 0, 0, oc_b_, 0, j_e - j_s);

                nd_iterator_jump(tr_start, tr_end, g, ti->g_work, oc_b,
                        ti->oc_b_work, j, oh_e - oh_s);
            }
            if (nthr_ic_b_ > 1)
                barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
        }
        // Global transpose already covers the whole [oh_s, oh_e) range in one
        // pass; the local scheme tiles it by the configured spatial block.
        int height_block = jcp.global_transpose ? oh_e - oh_s : optimal_hblock;

        for_(int ohb_s = oh_s; ohb_s < oh_e; ohb_s += height_block)
        for_(int g = ti->g_start; g < ti->g_end; ++g)
        for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end;
                oc_b += jcp.nb_oc_blocking)
        for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end;
                ic_b += jcp.nb_ic_blocking) {
            const int ohb_e = nstl::min(ohb_s + height_block, oh_e);
            const int ihb_s = nstl::max(0, -jcp.t_pad + ohb_s * jcp.stride_h);
            const int ihb_e = nstl::min(
                    jcp.ih, -jcp.t_pad + (ohb_e - 1) * jcp.stride_h + ext_kh);
            assert(IMPLICATION(jcp.global_transpose,
                    oh_s == ohb_s && oh_e == ohb_e && ih_s == ihb_s
                            && ih_e == ihb_e));
            // Tail iterations process a single block.
            const int nb_ic_blocks = (ic_b + jcp.nb_ic_blocking > ti->ic_b_end)
                    ? 1
                    : jcp.nb_ic_blocking;
            const int nb_oc_blocks = (oc_b + jcp.nb_oc_blocking > ti->oc_b_end)
                    ? 1
                    : jcp.nb_oc_blocking;
            // Effective channel counts in the last (possibly partial) block.
            const int ic_to_compute
                    = this_block_size((ic_b + nb_ic_blocks - 1) * jcp.ic_block,
                            jcp.ic, jcp.ic_block);
            const int oc_to_compute
                    = this_block_size((oc_b + nb_oc_blocks - 1) * jcp.oc_block,
                            jcp.oc, jcp.oc_block);

            // Local scheme: transpose just the rows/blocks needed now.
            if (!jcp.global_transpose) {
                src_data_t *tr_src
                        = &ti->tr_src[tr_src_off(0, 0, ihb_e, ihb_s)];

                for (int icb = 0; icb < nb_ic_blocks; icb++) {
                    const int ic_off_idx
                            = g * jcp.ic + (ic_b + icb) * jcp.ic_block;
                    const src_data_t *src
                            = (src_data_t *)&ti->src[src_d.blk_off(
                                    img, ic_off_idx, ihb_s)];
                    src_data_t *tr_src_local
                            = tr_src + icb * jcp.tr_src_buf_size;
                    trans_src_nxc(tr_src_local, src, 0, 0, (ic_b + icb), 0,
                            ihb_e - ihb_s);
                }
                p.src = tr_src;
            } else {
                p.src = &ti->tr_src[tr_src_off(g, ic_b, ihb_e, ihb_s)];
            }

            if (!jcp.global_transpose) {
                diff_dst_data_t *tr_diff_dst
                        = &ti->tr_diff_dst[tr_diff_dst_off(0, 0, 0)];
                for (int ocb = 0; ocb < nb_oc_blocks; ocb++) {
                    const int oc_off_idx
                            = g * jcp.oc + (oc_b + ocb) * jcp.oc_block;
                    const diff_dst_data_t *diff_dst
                            = &ti->diff_dst[diff_dst_d.blk_off(
                                    img, oc_off_idx, ohb_s)];
                    diff_dst_data_t *tr_diff_dst_local
                            = tr_diff_dst + ocb * jcp.tr_diff_dst_buf_size;
                    trans_dst_nxc(tr_diff_dst_local, diff_dst, 0, 0,
                            (oc_b + ocb), 0, ohb_e - ohb_s);
                }
                p.dst = tr_diff_dst;
            } else {
                p.dst = &ti->tr_diff_dst[tr_diff_dst_off(g, oc_b, ohb_s)];
            }

            p.filt = (jcp.transform_to_vnni)
                    ? diff_wei + wei_offset_int(g, oc_b, ic_b, 0)
                    : diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b);
            p.bias = diff_bias + g * rnd_up(jcp.oc, jcp.oc_block)
                    + oc_b * jcp.oc_block;
            // First touch of the accumulator -> kernel zero-initializes it.
            p.channel = (start == ti->img_start) && (ohb_s == oh_s);

            p.reduce_work = ic_to_compute;
            p.load_work = oc_to_compute; // it's only for mask
            p.os_index_begin = ohb_s;
            p.os_index_end = ohb_e;
            p.flags = 0 | (ic_b == 0 ? FLAG_IC_FIRST : 0);
            p.last_ic_block = (nb_ic_blocks == jcp.nb_ic_blocking) ? 0 : 1;
            p.last_oc_block = (nb_oc_blocks == jcp.nb_oc_blocking) ? 0 : 1;
            assert(ohb_e <= jcp.oh);
            (*kernel_)(&p);
        }

        nd_iterator_jump(start, end, img, jcp.mb, oh_s, jcp.oh);
    }
}
1279 | |
// 3D counterpart of compute_diff_weights_2d: work items are (mb image,
// od depth-slice) and the transpose/kernel blocking runs over depth instead
// of height. Additionally computes kd padding/offset arguments because the
// filter's depth footprint may be clipped by front/back padding.
void jit_avx512_core_amx_convolution_bwd_weights_t::compute_diff_weights_3d(
        const thread_info_t *ti) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
    const auto &jcp = kernel_->jcp;
    const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
            * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd;
    const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
    const int optimal_dblock = jcp.spatial_blk_size;

    // Accumulation destination selection — see compute_diff_weights_2d.
    float *diff_wei;
    if (diff_weights_d.data_type() == data_type::bf16)
        diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size;
    else
        diff_wei = ti->ithr_mb == 0
                ? (float *)ti->diff_weights
                : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;

    float *diff_bias = nullptr;
    if (jcp.with_bias) {
        if (jcp.bia_dt == data_type::bf16)
            diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size;
        else
            diff_bias = ti->ithr_mb == 0
                    ? (float *)ti->diff_bias
                    : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size;
    }

    // Offset of depth-slice `od` inside the tr_diff_dst buffer for (g, oc);
    // each slice spans a full oh*tr_ow plane.
    auto tr_diff_dst_off_3d = [&](int g, int oc, int od) {
        const size_t tr_row_size = jcp.tr_ow * jcp.oc_block;
        const size_t tr_3d_size = tr_row_size * jcp.oh;
        int adj = (jcp.global_transpose) ? 1 : jcp.nb_oc_blocking;
        return tr_diff_dst_buf_number(ti, g, oc) * adj
                * jcp.tr_diff_dst_buf_size
                + od * tr_3d_size;
    };
    int img {0}, od_s {0};
    int start = ti->img_start;
    int end = ti->img_end;

    int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);

    nd_iterator_init(start, img, jcp.mb, od_s, jcp.od);
    while (start < end) {
        auto p = jit_conv_call_s();
        int work_rem = end - start;
        // Output-depth range of this work item and the corresponding input
        // depth range given front padding and the extended filter.
        const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
        int id_s = nstl::max(0, -jcp.f_pad + od_s * jcp.stride_d);
        const int id_e = nstl::min(
                jcp.id, -jcp.f_pad + (od_e - 1) * jcp.stride_d + ext_kd);

        auto tr_src_off_3d = [&](int g, int ic, int id_end, int id) {
            const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
            const size_t tr_3d_size = tr_row_size * jcp.ih;
            int adj = (jcp.global_transpose) ? 1 : jcp.nb_ic_blocking;
            // Aligned to buffer end to use guard elements
            return tr_src_buf_number(ti, g, ic) * adj * jcp.tr_src_buf_size
                    + (jcp.id - id_end + id) * tr_3d_size;
        };

        if (jcp.global_transpose) {
            using simple_barrier::barrier;
            // TODO: try to call local transpositions just before jit kernel
            /* tr_src[nb_ic][id][16][~iw~] <- src[nb_ic][id][iw][16] */
            // Phase 1: cooperative src transpose over depth slices,
            // re-balanced across nthr_oc_b_ threads, fenced by barriers.
            int d {0};

            int work_amount = ti->g_work * ti->ic_b_work * (id_e - id_s);

            int tr_start {0}, tr_end {0};
            balance211(
                    work_amount, nthr_oc_b_, ti->ithr_oc_b, tr_start, tr_end);

            int g {0}, ic_b {0};

            nd_iterator_init(tr_start, g, ti->g_work, ic_b, ti->ic_b_work, d,
                    id_e - id_s);

            if (nthr_oc_b_ > 1)
                barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
            while (tr_start < tr_end) {
                int g_ = g + ti->g_start;
                int ic_b_ = ic_b + ti->ic_b_start;
                int d_s = d + id_s;
                int d_e = d_s + nstl::min(tr_end - tr_start, id_e - d_s);

                const int ic_off_idx = g_ * jcp.ic + ic_b_ * jcp.ic_block;
                const src_data_t *src
                        = &ti->src[src_d.blk_off(img, ic_off_idx, d_s)];
                src_data_t *tr_src
                        = &ti->tr_src[tr_src_off_3d(g_, ic_b_, id_e, d_s)];
                // Each depth slice contributes jcp.ih rows.
                trans_src_nxc(
                        tr_src, src, 0, 0, ic_b_, 0, (d_e - d_s) * jcp.ih);
                nd_iterator_jump(tr_start, tr_end, g, ti->g_work, ic_b,
                        ti->ic_b_work, d, id_e - id_s);
            }
            if (nthr_oc_b_ > 1)
                barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);

            // Phase 2: cooperative diff_dst transpose, balanced over
            // nthr_ic_b_ threads.
            d = 0;

            work_amount = ti->g_work * ti->oc_b_work * (od_e - od_s);

            tr_start = 0, tr_end = 0;
            balance211(
                    work_amount, nthr_ic_b_, ti->ithr_ic_b, tr_start, tr_end);

            g = 0;
            int oc_b = 0;

            nd_iterator_init(tr_start, g, ti->g_work, oc_b, ti->oc_b_work, d,
                    od_e - od_s);

            if (nthr_ic_b_ > 1)
                barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
            while (tr_start < tr_end) {
                int g_ = g + ti->g_start;
                int oc_b_ = oc_b + ti->oc_b_start;
                int d_s = d + od_s;
                int d_e = d_s + nstl::min(tr_end - tr_start, od_e - d_s);
                const int oc_off_idx = g_ * jcp.oc + oc_b_ * jcp.oc_block;

                const diff_dst_data_t *diff_dst
                        = &ti->diff_dst[diff_dst_d.blk_off(
                                img, oc_off_idx, d_s)];
                diff_dst_data_t *tr_diff_dst
                        = &ti->tr_diff_dst[tr_diff_dst_off_3d(g_, oc_b_, d_s)];

                trans_dst_nxc(tr_diff_dst, diff_dst, 0, 0, oc_b_, 0,
                        (d_e - d_s) * jcp.oh);

                nd_iterator_jump(tr_start, tr_end, g, ti->g_work, oc_b,
                        ti->oc_b_work, d, od_e - od_s);
            }
            if (nthr_ic_b_ > 1)
                barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_);
        }

        // Global transpose covers the whole [od_s, od_e) range at once; the
        // local scheme tiles it by the configured spatial block.
        int depth_block = jcp.global_transpose ? od_e - od_s : optimal_dblock;

        for_(int odb_s = od_s; odb_s < od_e; odb_s += depth_block)
        for_(int g = ti->g_start; g < ti->g_end; ++g)
        for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end;
                oc_b += jcp.nb_oc_blocking)
        for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end;
                ic_b += jcp.nb_ic_blocking) {
            const int odb_e = nstl::min(odb_s + depth_block, od_e);
            const int idb_s = nstl::max(0, -jcp.f_pad + odb_s * jcp.stride_d);
            const int idb_e = nstl::min(
                    jcp.id, -jcp.f_pad + (odb_e - 1) * jcp.stride_d + ext_kd);
            // Number of filter taps clipped away by front/back padding for
            // this depth block; used to shrink kd_padding and to offset into
            // the weights.
            const int kdb_front_pad
                    = nstl::max(0, jcp.f_pad - odb_s * jcp.stride_d);
            // Assumes kd_back_pad = 0 when kernel is dilated
            const int kdb_back_pad = nstl::max(
                    0, odb_s * jcp.stride_d + jcp.kd - jcp.f_pad - jcp.id);
            const int kdb_pad_off = nstl::min(jcp.kd - 1, kdb_front_pad)
                    * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block
                    * jcp.typesize_out;

            assert(IMPLICATION(jcp.global_transpose,
                    od_s == odb_s && od_e == odb_e && id_s == idb_s
                            && id_e == idb_e));
            // Tail iterations process a single block.
            const int nb_ic_blocks = (ic_b + jcp.nb_ic_blocking > ti->ic_b_end)
                    ? 1
                    : jcp.nb_ic_blocking;
            const int nb_oc_blocks = (oc_b + jcp.nb_oc_blocking > ti->oc_b_end)
                    ? 1
                    : jcp.nb_oc_blocking;
            // Effective channel counts in the last (possibly partial) block.
            const int ic_to_compute
                    = this_block_size((ic_b + nb_ic_blocks - 1) * jcp.ic_block,
                            jcp.ic, jcp.ic_block);
            const int oc_to_compute
                    = this_block_size((oc_b + nb_oc_blocks - 1) * jcp.oc_block,
                            jcp.oc, jcp.oc_block);

            // Local scheme: transpose just the slices/blocks needed now.
            if (!jcp.global_transpose) {
                src_data_t *tr_src
                        = &ti->tr_src[tr_src_off_3d(0, 0, idb_e, idb_s)];

                for (int icb = 0; icb < nb_ic_blocks; icb++) {
                    const int ic_off_idx
                            = g * jcp.ic + (ic_b + icb) * jcp.ic_block;
                    const src_data_t *src
                            = (src_data_t *)&ti->src[src_d.blk_off(
                                    img, ic_off_idx, idb_s)];
                    src_data_t *tr_src_local
                            = tr_src + icb * jcp.tr_src_buf_size;
                    trans_src_nxc(tr_src_local, src, 0, 0, (ic_b + icb), 0,
                            (idb_e - idb_s) * jcp.ih);
                }
                p.src = tr_src;
            } else {
                p.src = &ti->tr_src[tr_src_off_3d(g, ic_b, idb_e, idb_s)];
            }

            if (!jcp.global_transpose) {
                diff_dst_data_t *tr_diff_dst
                        = &ti->tr_diff_dst[tr_diff_dst_off_3d(0, 0, 0)];
                for (int ocb = 0; ocb < nb_oc_blocks; ocb++) {
                    const int oc_off_idx
                            = g * jcp.oc + (oc_b + ocb) * jcp.oc_block;
                    const diff_dst_data_t *diff_dst
                            = &ti->diff_dst[diff_dst_d.blk_off(
                                    img, oc_off_idx, odb_s)];
                    diff_dst_data_t *tr_diff_dst_local
                            = tr_diff_dst + ocb * jcp.tr_diff_dst_buf_size;
                    trans_dst_nxc(tr_diff_dst_local, diff_dst, 0, 0,
                            (oc_b + ocb), 0, (odb_e - odb_s) * jcp.oh);
                }
                p.dst = tr_diff_dst;
            } else {
                p.dst = &ti->tr_diff_dst[tr_diff_dst_off_3d(g, oc_b, odb_s)];
            }

            p.filt = (jcp.transform_to_vnni)
                    ? diff_wei + wei_offset_int(g, oc_b, ic_b, 0)
                    : diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b);
            p.bias = diff_bias + g * rnd_up(jcp.oc, jcp.oc_block)
                    + oc_b * jcp.oc_block;
            // First touch of the accumulator -> kernel zero-initializes it.
            p.channel = (start == ti->img_start) && (odb_s == od_s);

            p.reduce_work = ic_to_compute;
            p.load_work = oc_to_compute;
            p.os_index_begin = odb_s;
            p.os_index_end = odb_e;
            p.kd_padding = jcp.kd - kdb_front_pad - kdb_back_pad;
            p.kd_offset = kdb_pad_off;
            p.flags = (ic_b == 0 ? FLAG_IC_FIRST : 0);
            p.last_ic_block = (nb_ic_blocks == jcp.nb_ic_blocking) ? 0 : 1;
            p.last_oc_block = (nb_oc_blocks == jcp.nb_oc_blocking) ? 0 : 1;
            assert(odb_e <= jcp.od);
            (*kernel_)(&p);
        }

        nd_iterator_jump(start, end, img, jcp.mb, od_s, jcp.od);
    }
}
1517 | |
1518 | void jit_avx512_core_amx_convolution_bwd_weights_t::compute_diff_weights( |
1519 | const thread_info_t *ti) const { |
1520 | const memory_desc_wrapper src_d(pd()->src_md()); |
1521 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
1522 | const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0)); |
1523 | const auto &jcp = kernel_->jcp; |
1524 | const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic |
1525 | * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd; |
1526 | const int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block; |
1527 | |
1528 | float *diff_wei; |
1529 | if (diff_weights_d.data_type() == data_type::bf16) |
1530 | diff_wei = ti->wei_bia_reduction + (ti->ithr_mb) * wei_size; |
1531 | else |
1532 | diff_wei = ti->ithr_mb == 0 |
1533 | ? (float *)ti->diff_weights |
1534 | : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size; |
1535 | |
1536 | float *diff_bias = nullptr; |
1537 | if (jcp.with_bias) { |
1538 | if (jcp.bia_dt == data_type::bf16) |
1539 | diff_bias = ti->bia_reduction + (ti->ithr_mb) * bias_buf_size; |
1540 | else |
1541 | diff_bias = ti->ithr_mb == 0 |
1542 | ? (float *)ti->diff_bias |
1543 | : ti->bia_reduction + (ti->ithr_mb - 1) * bias_buf_size; |
1544 | } |
1545 | |
1546 | auto tr_src_off = [&](int g, int ic, int nb_ic_block, int ij) { |
1547 | const size_t tr_row_size = jcp.tr_iw * jcp.ic_block; |
1548 | int adj = (jcp.global_transpose) ? 1 : jcp.nb_ic_blocking; |
1549 | return tr_src_buf_number(ti, g, ic) * jcp.tr_src_buf_size * adj |
1550 | + ij * tr_row_size + nb_ic_block * jcp.tr_src_buf_size; |
1551 | }; |
1552 | |
1553 | auto tr_src_off_3d = [&](int g, int ic, int nb_ic_block, int id, int ij) { |
1554 | const size_t tr_row_size = jcp.tr_iw * jcp.ic_block; |
1555 | const size_t tr_3d_size = tr_row_size * jcp.ih; |
1556 | int adj = (jcp.global_transpose) ? 1 : jcp.nb_ic_blocking; |
1557 | return tr_src_buf_number(ti, g, ic) * jcp.tr_src_buf_size * adj |
1558 | + id * tr_3d_size + ij * tr_row_size |
1559 | + nb_ic_block * jcp.tr_src_buf_size; |
1560 | }; |
1561 | |
1562 | auto tr_diff_dst_off = [&](int g, int oc, int nb_oc_block, int oj) { |
1563 | const size_t tr_row_size = jcp.tr_ow * jcp.oc_block; |
1564 | int adj = (jcp.global_transpose) ? 1 : jcp.nb_oc_blocking; |
1565 | return tr_diff_dst_buf_number(ti, g, oc) * jcp.tr_diff_dst_buf_size |
1566 | * adj |
1567 | + oj * tr_row_size + nb_oc_block * jcp.tr_diff_dst_buf_size; |
1568 | }; |
1569 | |
1570 | auto tr_diff_dst_off_3d |
1571 | = [&](int g, int oc, int nb_oc_block, int od, int oj) { |
1572 | const size_t tr_row_size = jcp.tr_ow * jcp.oc_block; |
1573 | const size_t tr_3d_size = tr_row_size * jcp.oh; |
1574 | int adj = (jcp.global_transpose) ? 1 : jcp.nb_oc_blocking; |
1575 | return tr_diff_dst_buf_number(ti, g, oc) |
1576 | * jcp.tr_diff_dst_buf_size * adj |
1577 | + od * tr_3d_size + oj * tr_row_size |
1578 | + nb_oc_block * jcp.tr_src_buf_size; |
1579 | }; |
1580 | |
1581 | auto uker_trans = [&](int img, int g = 0, int ic_b = 0, |
1582 | int nb_ic_block = 0) { |
1583 | int j {0}, d {0}; |
1584 | int my_work = jcp.ih * jcp.id; |
1585 | int ic; |
1586 | int icb_start = ic_b; |
1587 | if (jcp.global_transpose) { |
1588 | const int work_amount = ti->ic_b_work * jcp.ih * jcp.id; |
1589 | |
1590 | int start {0}, end {0}; |
1591 | balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, start, end); |
1592 | my_work = end - start; |
1593 | |
1594 | if (jcp.ndims == 5) |
1595 | nd_iterator_init( |
1596 | start, ic_b, ti->ic_b_work, d, jcp.id, j, jcp.ih); |
1597 | else |
1598 | nd_iterator_init(start, ic_b, ti->ic_b_work, j, jcp.ih); |
1599 | g += ti->g_start; |
1600 | ic_b += ti->ic_b_start; |
1601 | icb_start = ic_b; |
1602 | ic = g * jcp.ic + ic_b * jcp.ic_block; |
1603 | |
1604 | } else { |
1605 | ic = g * jcp.ic + ic_b * jcp.ic_block; |
1606 | g = 0; |
1607 | ic_b = 0; |
1608 | } |
1609 | const bool need_local_gwork = jcp.global_transpose; |
1610 | const auto local_gwork = need_local_gwork ? ti->g_work : 1; |
1611 | |
1612 | for (int gg = g; gg < g + local_gwork; ++gg) { |
1613 | if (need_local_gwork) ic = gg * jcp.ic + ic_b * jcp.ic_block; |
1614 | src_data_t *tr_src = (jcp.ndims == 5) |
1615 | ? &ti->tr_src[tr_src_off_3d(gg, ic_b, nb_ic_block, d, j)] |
1616 | : &ti->tr_src[tr_src_off(gg, ic_b, nb_ic_block, j)]; |
1617 | auto src_offset = src_d.blk_off(img, ic); |
1618 | src_data_t *src = (src_data_t *)&ti->src[src_offset]; |
1619 | |
1620 | dim_t sp_start_offset = (jcp.ndims == 5) ? src_d.blk_off(0, 0, d, j) |
1621 | : src_d.blk_off(0, 0, j); |
1622 | dim_t ch_shift = src_d.blk_off(0, jcp.ic_block); |
1623 | int sp_start_idx = d * jcp.ih + j; |
1624 | trans_src_nxc(tr_src, src, sp_start_idx, sp_start_offset, icb_start, |
1625 | ch_shift, my_work); |
1626 | } |
1627 | }; |
1628 | |
1629 | auto diff_dst_trans = [&](int img, int g = 0, int oc_b = 0, |
1630 | int nb_oc_block = 0) { |
1631 | int j {0}, d {0}; |
1632 | int my_work = jcp.oh * jcp.od; |
1633 | int oc; |
1634 | int ocb_start = oc_b; |
1635 | |
1636 | if (jcp.global_transpose) { |
1637 | const size_t work_amount = ti->oc_b_work * jcp.oh * jcp.od; |
1638 | |
1639 | size_t start {0}, end {0}; |
1640 | balance211(work_amount, nthr_ic_b_, ti->ithr_ic_b, start, end); |
1641 | my_work = end - start; |
1642 | |
1643 | if (jcp.ndims == 5) |
1644 | nd_iterator_init( |
1645 | start, oc_b, ti->oc_b_work, d, jcp.od, j, jcp.oh); |
1646 | else |
1647 | nd_iterator_init(start, oc_b, ti->oc_b_work, j, jcp.oh); |
1648 | |
1649 | g += ti->g_start; |
1650 | oc_b += ti->oc_b_start; |
1651 | ocb_start = oc_b; |
1652 | oc = g * jcp.oc + oc_b * jcp.oc_block; |
1653 | } else { |
1654 | oc = g * jcp.oc + oc_b * jcp.oc_block; |
1655 | g = 0; |
1656 | oc_b = 0; |
1657 | } |
1658 | const bool need_local_gwork = jcp.global_transpose; |
1659 | const auto local_gwork = need_local_gwork ? ti->g_work : 1; |
1660 | |
1661 | for (int gg = g; gg < g + local_gwork; ++gg) { |
1662 | if (need_local_gwork) oc = gg * jcp.oc + oc_b * jcp.oc_block; |
1663 | diff_dst_data_t *tr_diff_dst = (jcp.ndims == 5) |
1664 | ? &ti->tr_diff_dst[tr_diff_dst_off_3d( |
1665 | gg, oc_b, nb_oc_block, d, j)] |
1666 | : &ti->tr_diff_dst[tr_diff_dst_off( |
1667 | gg, oc_b, nb_oc_block, j)]; |
1668 | auto ddst_offset = diff_dst_d.blk_off(img, oc); |
1669 | const diff_dst_data_t *diff_dst = &ti->diff_dst[ddst_offset]; |
1670 | |
1671 | dim_t sp_start_offset = (jcp.ndims == 5) |
1672 | ? diff_dst_d.blk_off(0, 0, d, j) |
1673 | : diff_dst_d.blk_off(0, 0, j); |
1674 | dim_t ch_shift = diff_dst_d.blk_off(0, jcp.oc_block); |
1675 | int sp_start_idx = d * jcp.oh + j; |
1676 | trans_dst_nxc(tr_diff_dst, diff_dst, sp_start_idx, sp_start_offset, |
1677 | ocb_start, ch_shift, my_work); |
1678 | } |
1679 | }; |
1680 | |
1681 | for (int img = ti->img_start; img < ti->img_end; ++img) { |
1682 | auto p = jit_conv_call_s(); |
1683 | if (jcp.global_transpose) { |
1684 | using simple_barrier::barrier; |
1685 | // TODO: try to call local transpositions just before jit kernel |
1686 | /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */ |
1687 | if (nthr_oc_b_ > 1) |
1688 | barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_); |
1689 | uker_trans(img); |
1690 | if (nthr_oc_b_ > 1) |
1691 | barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_); |
1692 | if (nthr_ic_b_ > 1) |
1693 | barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_); |
1694 | diff_dst_trans(img); |
1695 | if (nthr_ic_b_ > 1) |
1696 | barrier(&ti->tr_diff_dst_bctx[ti->ithr_but_ic], nthr_ic_b_); |
1697 | } |
1698 | |
1699 | for_(int g = ti->g_start; g < ti->g_end; ++g) |
1700 | for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; |
1701 | oc_b += jcp.nb_oc_blocking) |
1702 | for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; |
1703 | ic_b += jcp.nb_ic_blocking) { |
1704 | const int nb_ic_blocks = (ic_b + jcp.nb_ic_blocking > ti->ic_b_end) |
1705 | ? 1 |
1706 | : jcp.nb_ic_blocking; |
1707 | const int nb_oc_blocks = (oc_b + jcp.nb_oc_blocking > ti->oc_b_end) |
1708 | ? 1 |
1709 | : jcp.nb_oc_blocking; |
1710 | |
1711 | const int ic_to_compute |
1712 | = this_block_size((ic_b + nb_ic_blocks - 1) * jcp.ic_block, |
1713 | jcp.ic, jcp.ic_block); |
1714 | const int oc_to_compute |
1715 | = this_block_size((oc_b + nb_oc_blocks - 1) * jcp.oc_block, |
1716 | jcp.oc, jcp.oc_block); |
1717 | |
1718 | if (!jcp.global_transpose) { |
1719 | for (int nib = 0; nib < nb_ic_blocks; nib++) |
1720 | uker_trans(img, g, ic_b + nib, nib); |
1721 | } |
1722 | if (jcp.ndims == 5) { |
1723 | p.src = &ti->tr_src[tr_src_off_3d(g, ic_b, 0, 0, 0)]; |
1724 | } else { |
1725 | p.src = &ti->tr_src[tr_src_off(g, ic_b, 0, 0)]; |
1726 | } |
1727 | |
1728 | if (!jcp.global_transpose) { |
1729 | for (int nob = 0; nob < nb_oc_blocks; nob++) |
1730 | diff_dst_trans(img, g, oc_b + nob, nob); |
1731 | if (jcp.ndims == 5) { |
1732 | p.dst = &ti->tr_diff_dst[tr_diff_dst_off_3d(0, 0, 0, 0, 0)]; |
1733 | } else { |
1734 | p.dst = &ti->tr_diff_dst[tr_diff_dst_off(0, 0, 0, 0)]; |
1735 | } |
1736 | } else { |
1737 | if (jcp.ndims == 5) { |
1738 | p.dst = &ti->tr_diff_dst[tr_diff_dst_off_3d( |
1739 | g, oc_b, 0, 0, 0)]; |
1740 | } else { |
1741 | p.dst = &ti->tr_diff_dst[tr_diff_dst_off(g, oc_b, 0, 0)]; |
1742 | } |
1743 | } |
1744 | |
1745 | p.filt = (jcp.transform_to_vnni) |
1746 | ? diff_wei + wei_offset_int(g, oc_b, ic_b, 0) |
1747 | : diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b); |
1748 | p.bias = diff_bias + g * rnd_up(jcp.oc, jcp.oc_block) |
1749 | + oc_b * jcp.oc_block; |
1750 | p.channel = (img == ti->img_start); |
1751 | p.flags = 0 | (ic_b == 0 ? FLAG_IC_FIRST : 0); |
1752 | |
1753 | p.reduce_work = ic_to_compute; |
1754 | p.load_work = oc_to_compute; |
1755 | |
1756 | p.last_ic_block = (nb_ic_blocks == jcp.nb_ic_blocking) ? 0 : 1; |
1757 | p.last_oc_block = (nb_oc_blocks == jcp.nb_oc_blocking) ? 0 : 1; |
1758 | |
1759 | (*kernel_)(&p); |
1760 | } |
1761 | } |
1762 | } |
1763 | |
1764 | void jit_avx512_core_amx_convolution_bwd_weights_t::store_in_vnni_format( |
1765 | const thread_info_t *ti) const { |
1766 | const auto &jcp = kernel_->jcp; |
1767 | |
1768 | for_(int g = ti->g_start; g < ti->g_end; g++) |
1769 | for_(int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; oc_b++) |
1770 | for_(int ic_b = ti->ic_b_start; ic_b < ti->ic_b_start + ti->ic_b_work; |
1771 | ic_b += 2) |
1772 | { |
1773 | jit_conv_call_s p = jit_conv_call_s(); |
1774 | |
1775 | bfloat16_t *output = (bfloat16_t *)ti->diff_weights |
1776 | + wei_offset_ext(g, oc_b, (ic_b / 2), 0); |
1777 | float *input = ti->wei_bia_reduction + wei_offset_int(g, oc_b, ic_b, 0); |
1778 | |
1779 | p.src = (void *)input; |
1780 | p.dst = (void *)output; |
1781 | p.last_ic_block = ((ic_b + 1) >= jcp.nb_ic) ? 1 : 0; |
1782 | (*diff_wei_trans_kernel_)(&p); |
1783 | } |
1784 | } |
1785 | |
void jit_avx512_core_amx_convolution_bwd_weights_t::
        reduce_and_convert_diff_weights_and_bias(
                const thread_info_t *ti) const {
    // Reduces the per-thread f32 weight/bias accumulation buffers into the
    // final diff_weights/diff_bias tensors. When the destination is bf16 the
    // last reduction step also performs the f32 -> bf16 conversion; with a
    // single mb thread only the conversion (if any) is needed.
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));

    const auto &jcp = kernel_->jcp;
    const int wei_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block * jcp.nb_ic
            * jcp.ic_block * jcp.kh * jcp.kw * ((jcp.ndims == 5) ? jcp.kd : 1);

    const bool is_bf16_out = diff_weights_d.data_type() == data_type::bf16;
    const bool is_bf16_bias = jcp.with_bias && jcp.bia_dt == data_type::bf16;

    // Single mb thread: nothing to sum across threads.
    if (nthr_mb_ == 1) {
        if (is_bf16_out) {
            // reduction is not required, only conversion
            if (jcp.transform_to_vnni) {
                store_in_vnni_format(ti);
            } else {
                // Plain f32 -> bf16 conversion of this thread's weight slice.
                for_(int g = ti->g_start; g < ti->g_end; g++)
                for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; oc_b++) {
                    const size_t acc_size = (size_t)ti->ic_b_work * jcp.kh
                            * jcp.kw * ((jcp.ndims == 5) ? jcp.kd : 1)
                            * jcp.ic_block * jcp.oc_block;
                    const size_t off = wht_blk_off(
                            diff_weights_d, g, oc_b, ti->ic_b_start);
                    cvt_float_to_bfloat16(
                            (bfloat16_t *)(ti->diff_weights) + off,
                            (ti->wei_bia_reduction + off), acc_size);
                }
            }
        }
        // Bias conversion: only one ic-thread per oc slice does it, reading
        // from the oc_block-padded f32 buffer and writing the unpadded
        // bf16 user diff_bias.
        if (is_bf16_bias && ti->ithr_ic_b == 0 && ti->ic_b_work > 0) {
            for (int g = ti->g_start; g < ti->g_end; g++) {
                int result_start_idx = g * jcp.oc_without_padding
                        + ti->oc_b_start * jcp.oc_block;
                int buffer_start_idx = g * rnd_up(jcp.oc, jcp.oc_block)
                        + ti->oc_b_start * jcp.oc_block;
                const size_t acc_size = nstl::min(jcp.oc_without_padding,
                                                ti->oc_b_end * jcp.oc_block)
                        - ti->oc_b_start * jcp.oc_block;
                bfloat16_t *diff_bias
                        = (bfloat16_t *)ti->diff_bias + result_start_idx;
                float *buffer = ti->bia_reduction + buffer_start_idx;
                cvt_float_to_bfloat16(diff_bias, buffer, acc_size);
            }
        }
        return;
    }

    /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
    // Under global_transpose all threads must have finished accumulation
    // before anyone starts reading sibling buffers.
    if (jcp.global_transpose)
        simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);

    // The reduction work is split across mb threads over
    // (g, oc block, ic block x kd-or-kh) units.
    const int ic_b_kh_work
            = ti->ic_b_work * ((jcp.ndims == 5) ? jcp.kd : jcp.kh);
    const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;

    int start {0}, end {0};
    balance211(work, nthr_mb_, ti->ithr_mb, start, end);
    // With transform_to_vnni threads with no reduction work may still need
    // to participate in the final barrier/vnni store below.
    if (!jcp.transform_to_vnni && start == end) return;

    const int _start_nthr_mb = 1;
    // Accumulate each other thread's buffer into the reduced destination.
    for (int thr_mb = _start_nthr_mb; thr_mb < nthr_mb_; ++thr_mb) {
        int w = start;
        int sub_g_start {0}, sub_oc_b_start {0}, sub_ic_b_kh_start {0};
        nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
                ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        while (w < end) {
            const int g = ti->g_start + sub_g_start;
            const int oc_b = ti->oc_b_start + sub_oc_b_start;
            const int ic_b = ti->ic_b_start
                    + sub_ic_b_kh_start / ((jcp.ndims == 5) ? jcp.kd : jcp.kh);
            const int kX
                    = sub_ic_b_kh_start % ((jcp.ndims == 5) ? jcp.kd : jcp.kh);

            // Contiguous run length: up to the end of this ic_b/kX chunk or
            // of this thread's work range, whichever comes first.
            const size_t acc_size = (size_t)jcp.kw * jcp.ic_block * jcp.oc_block
                    * ((jcp.ndims == 5) ? jcp.kh : 1)
                    * nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start);

            // off_ext addresses the user diff_weights layout; off_int the
            // internal (pre-vnni) accumulation layout.
            const size_t off_ext
                    = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kX);
            const size_t off_int = (jcp.transform_to_vnni)
                    ? wei_offset_int(g, oc_b, ic_b, kX)
                    : off_ext;

            float *wei_reduced = is_bf16_out
                    ? ti->wei_bia_reduction + off_int
                    : (float *)(ti->diff_weights) + off_ext;

            // For f32 output, buffer 0 belongs to thread 0's direct write,
            // hence the -1 shift.
            int thr_mb_buffer_idx = is_bf16_out ? thr_mb : thr_mb - 1;
            float *wei_to_reduce = ti->wei_bia_reduction
                    + thr_mb_buffer_idx * wei_size + off_int;

            if (!jcp.transform_to_vnni && is_bf16_out && thr_mb == nthr_mb_ - 1)
                // the last iteration for bfloat16 requires conversion and
                // store to diff_weights array
                add_floats_and_cvt_to_bfloat16(
                        (bfloat16_t *)(ti->diff_weights) + off_ext, wei_reduced,
                        wei_to_reduce, acc_size);
            else
                acc_ker_->accumulate(wei_reduced, wei_to_reduce, acc_size);

            nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
                    ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
        }
        // Bias reduction: done once per oc slice by the (ithr_ic_b == 0,
        // ithr_mb == 0) thread that actually processed images.
        if (jcp.with_bias && ti->ithr_ic_b == 0 && ti->ic_b_work > 0
                && ti->ithr_mb == 0 && ti->img_work > 0) {
            for (int g = ti->g_start; g < ti->g_end; g++) {
                float *bias_reduced = is_bf16_bias ? ti->bia_reduction
                                                  : (float *)(ti->diff_bias);
                int thr_mb_buffer_idx = is_bf16_bias ? thr_mb : thr_mb - 1;
                int bias_buf_size = jcp.ngroups * jcp.nb_oc * jcp.oc_block;
                float *bias_to_reduce
                        = ti->bia_reduction + thr_mb_buffer_idx * bias_buf_size;
                const size_t acc_size = nstl::min(jcp.oc_without_padding,
                                                ti->oc_b_end * jcp.oc_block)
                        - ti->oc_b_start * jcp.oc_block;
                int idx = g * rnd_up(jcp.oc, jcp.oc_block)
                        + ti->oc_b_start * jcp.oc_block;
                if (is_bf16_bias && thr_mb == nthr_mb_ - 1) {
                    // the last iteration for bfloat16 requires conversion and
                    // store to diff_weights array
                    int diff_bias_idx = g * jcp.oc_without_padding
                            + ti->oc_b_start * jcp.oc_block;
                    add_floats_and_cvt_to_bfloat16(
                            (bfloat16_t *)(ti->diff_bias) + diff_bias_idx,
                            &bias_reduced[idx], &bias_to_reduce[idx], acc_size);
                } else {
                    acc_ker_->accumulate(
                            &bias_reduced[idx], &bias_to_reduce[idx], acc_size);
                }
            }
        }
    }

    // vnni store must wait for every thread's reduction to complete.
    if (jcp.transform_to_vnni && jcp.global_transpose) {
        simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);
        store_in_vnni_format(ti);
    }
}
1926 | |
1927 | void jit_avx512_core_amx_convolution_bwd_weights_t::prepare_scratchpad_data( |
1928 | const exec_ctx_t &ctx) const { |
1929 | auto scratchpad = ctx.get_scratchpad_grantor(); |
1930 | |
1931 | const auto &jcp = pd()->jcp_; |
1932 | |
1933 | // XXX: See the comment about tr_iw and guarding elements in |
1934 | // jit_avx512_core_bf16_conv_bwd_weights_kernel_f32::init_conf() |
1935 | auto tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src); |
1936 | // Zero out guard elements that cross a buffer boundary to prevent a |
1937 | // race condition due to buffer overflows from memory optimization where |
1938 | // buffers sharing padding |
1939 | |
1940 | for (size_t ithr = 1; ithr <= jcp.tr_src_buf_count; ++ithr) { |
1941 | src_data_t *ts |
1942 | = &tr_src[ithr * jcp.tr_src_buf_size * jcp.nb_ic_blocking]; |
1943 | for (int i = 0; i < jcp.tr_src_num_guard_elems; ++i) |
1944 | ts[i] = 0; |
1945 | } |
1946 | |
1947 | if (jcp.global_transpose && jcp.nthr_oc_b > 1) { |
1948 | const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b; |
1949 | auto tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>( |
1950 | key_conv_tr_src_bctx); |
1951 | for (int i = 0; i < tr_src_bctx_size; ++i) |
1952 | simple_barrier::ctx_init(&tr_src_bctx[i]); |
1953 | } |
1954 | if (jcp.global_transpose) { |
1955 | if (jcp.nthr_ic_b > 1) { |
1956 | const int tr_diff_dst_bctx_size = jcp.nthr / jcp.nthr_ic_b; |
1957 | auto tr_diff_dst_bctx |
1958 | = scratchpad.template get<simple_barrier::ctx_t>( |
1959 | key_conv_tr_diff_dst_bctx); |
1960 | for (int i = 0; i < tr_diff_dst_bctx_size; ++i) |
1961 | simple_barrier::ctx_init(&tr_diff_dst_bctx[i]); |
1962 | } |
1963 | } |
1964 | |
1965 | if (nthr_mb_ > 1 |
1966 | || pd()->diff_weights_md(0)->data_type == data_type::bf16) { |
1967 | // TODO: don't use barrier for case |
1968 | // diff_weights_type == data_type::bf16 && nthr_mb_ == 1 |
1969 | simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>( |
1970 | key_conv_wei_bia_reduction_bctx)); |
1971 | } |
1972 | } |
1973 | |
void jit_avx512_core_amx_convolution_bwd_weights_t::execute_backward_weights(
        const exec_ctx_t &ctx) const {
    // Top-level backward-weights entry point: prepares scratchpad state and
    // AMX tile configuration, runs the harness-specific compute in parallel,
    // then (for the non-global-transpose path) performs reduction, vnni
    // repacking, and bias un-padding as separate parallel phases.
    prepare_scratchpad_data(ctx);

    auto tcfg = ctx.get_scratchpad_grantor().template get<char>(
            key_conv_amx_tilecfg);
    kernel_->tile_configure(tcfg);

    const auto &jcp = pd()->jcp_;
    parallel(nthr_, [&](const int ithr, const int nthr) {
        assert(nthr_ == nthr);
        assert(utils::one_of(pd()->ndims(), 3, 4, 5));

        // Each thread loads the tile configuration before running kernels.
        amx_tile_configure(tcfg);

        thread_info_t thread_info(this, ctx, ithr);
        // With global_transpose the reduction runs inside the same parallel
        // region (it uses barriers across all threads); otherwise it is
        // deferred to the separate parallel phase below.
        switch (jcp.harness) {
            case harness_2d_reduction:
                compute_diff_weights_2d(&thread_info);
                if (jcp.global_transpose)
                    reduce_and_convert_diff_weights_and_bias(&thread_info);
                break;
            case harness_3d_reduction:
                compute_diff_weights_3d(&thread_info);
                if (jcp.global_transpose)
                    reduce_and_convert_diff_weights_and_bias(&thread_info);
                break;
            case harness_compute_full_spatial:
            case harness_mb_reduction:
                compute_diff_weights(&thread_info);
                if (jcp.global_transpose)
                    reduce_and_convert_diff_weights_and_bias(&thread_info);
                break;
            default: assert(!"Invalid harness type" );
        }

        amx_tile_release();
    });

    // Separate parallel regions give an implicit full barrier between
    // accumulation, reduction, and the vnni store.
    if (!jcp.global_transpose) {
        parallel(nthr_, [&](const int ithr, const int nthr) {
            assert(nthr_ == nthr);
            thread_info_t thread_info(this, ctx, ithr);
            reduce_and_convert_diff_weights_and_bias(&thread_info);
        });
    }

    if (jcp.transform_to_vnni && !jcp.global_transpose) {
        parallel(nthr_, [&](const int ithr, const int nthr) {
            assert(nthr_ == nthr);
            thread_info_t thread_info(this, ctx, ithr);
            store_in_vnni_format(&thread_info);
        });
    }

    // f32 bias was accumulated into an oc_block-padded scratch buffer;
    // copy the unpadded part into the user's diff_bias.
    if (pd()->with_bias() && (jcp.oc_without_padding % jcp.oc_block != 0)
            && jcp.bia_dt != data_type::bf16) {
        auto diff_bias = ctx.get_scratchpad_grantor().template get<const float>(
                key_conv_padded_bias);
        auto diff_bias_in = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS);
        const int padded_stride = rnd_up(jcp.oc, jcp.oc_block);
        const int stride = jcp.oc_without_padding;
        for (int g = 0; g < jcp.ngroups; ++g) {
            utils::array_copy(diff_bias_in + g * stride,
                    diff_bias + g * padded_stride, stride);
        }
    }
}
2042 | |
2043 | } // namespace x64 |
2044 | } // namespace cpu |
2045 | } // namespace impl |
2046 | } // namespace dnnl |
2047 | |
2048 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
2049 | |