1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/nstl.hpp" |
20 | #include "common/type_helpers.hpp" |
21 | #include "common/utils.hpp" |
22 | #include "cpu/cpu_primitive.hpp" |
23 | |
24 | #include "cpu/x64/jit_brgemm_conv_bwd_strided.hpp" |
25 | #include "cpu/x64/jit_brgemm_conv_bwd_utils.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
32 | using namespace dnnl::impl::status; |
33 | using namespace dnnl::impl::memory_tracking::names; |
34 | using namespace dnnl::impl::utils; |
35 | |
36 | using namespace nstl; |
37 | using namespace data_type; |
38 | |
39 | using namespace jit_avx512_core_brgemm_conv_bwd_trans_kernel; |
40 | |
// Picks one of three values depending on the tensor rank: `v5` for 5D, `v4`
// for 4D and `v3` for 3D problems; any other rank yields 0.
// NOTE: the macro reads a variable named `ndims` that must be visible at the
// point of expansion.
#define ndims_pick(v5, v4, v3) \
    ((ndims == 5) ? (v5) : ((ndims == 4) ? (v4) : ((ndims == 3) ? (v3) : 0)))
43 | |
44 | static bool impl_supports_datatype(data_type_t data_type) { |
45 | switch (data_type) { |
46 | case data_type::bf16: return x64::mayiuse(x64::avx512_core); |
47 | case data_type::f16: return x64::mayiuse(x64::avx512_core_fp16); |
48 | case data_type::f32: |
49 | case data_type::s32: |
50 | case data_type::s8: |
51 | case data_type::u8: return true; |
52 | default: return false; |
53 | } |
54 | } |
55 | |
56 | template <cpu_isa_t isa, bool enable_postops> |
57 | status_t brgemm_convolution_bwd_strided_t<isa, enable_postops>::pd_t::init( |
58 | engine_t *engine) { |
59 | using namespace data_type; |
60 | |
61 | const auto diff_src_type = diff_src_md(0)->data_type; |
62 | const auto wei_type = weights_md(0)->data_type; |
63 | const auto diff_dst_type = diff_dst_md(0)->data_type; |
64 | |
65 | using skip_mask_t = primitive_attr_t::skip_mask_t; |
66 | auto skip_mask = enable_postops |
67 | ? (skip_mask_t::post_ops | skip_mask_t::sum_dt |
68 | | skip_mask_t::zero_points_runtime) |
69 | : skip_mask_t::none; |
70 | |
71 | const bool ok = is_bwd_d() |
72 | && set_default_alg_kind(alg_kind::convolution_direct) |
73 | && impl_supports_datatype(diff_src_type) |
74 | && impl_supports_datatype(wei_type) |
75 | && impl_supports_datatype(diff_dst_type) |
76 | && one_of(wei_type, f32, bf16, f16) && wei_type == diff_dst_type |
77 | && one_of(diff_src_type, wei_type, f32) |
78 | && one_of(with_bias(), one_of(bias_md_.data_type, f32, wei_type)) |
79 | && attr()->has_default_values(skip_mask, diff_src_type) |
80 | && IMPLICATION(enable_postops, |
81 | attr()->post_ops_.check_sum_consistent_dt(diff_src_type)) |
82 | && !has_zero_dim_memory(); |
83 | |
84 | if (!ok) return status::unimplemented; |
85 | |
86 | const auto is_amx = brgemm_convolution_bwd_utils::is_amx(isa); |
87 | |
88 | CHECK(brgemm_convolution_bwd_utils::init_conf(jcp_, isa, desc_, |
89 | diff_dst_md_, weights_md_, diff_src_md_, bias_md_, attr_, |
90 | dnnl_get_max_threads(), enable_postops)); |
91 | |
92 | const auto adj_M = nstl::max(jcp_.M, jcp_.M_tail); |
93 | |
94 | batchsizes.resize(jcp_.max_batch + 1); |
95 | for (int i = 0; i <= jcp_.max_batch; i++) |
96 | batchsizes[i] = -1; |
97 | |
98 | first_bs = 0; |
99 | bs_c = 0; |
100 | |
101 | batchsizes[jcp_.max_batch] = bs_c; |
102 | first_bs = jcp_.max_batch; |
103 | bs_c++; |
104 | |
105 | brgs_sz_ = bs_c * adj_M * 2 * 2 * 2; |
106 | brgs_.resize(brgs_sz_); |
107 | bd_masks.resize(brgs_sz_); |
108 | |
109 | const float alpha = 1.0; |
110 | const float beta = 1.0; |
111 | |
112 | const auto &p = attr()->post_ops_; |
113 | const int sum_idx = p.find(primitive_kind::sum); |
114 | const bool with_sum = (sum_idx != -1); |
115 | |
116 | auto maybe_M_mask |
117 | = [&](int brg_idx, brgemm_attr_t &brgattr, int vM, int vbrgM) { |
118 | if (!jcp_.use_M_mask) return; |
119 | auto sm_size = vbrgM; |
120 | bd_masks[brg_idx] = std::make_shared<std::vector<char>>(); |
121 | bd_masks[brg_idx]->resize(sm_size); |
122 | char *bd_mask = bd_masks[brg_idx]->data(); |
123 | for (int ibrgM = 0; ibrgM < sm_size; ibrgM++) { |
124 | bd_mask[ibrgM] = 1; |
125 | } |
126 | brgattr.bd_mask = bd_mask; |
127 | }; |
128 | |
129 | const auto M_end = nstl::max(jcp_.M, jcp_.M_tail); |
130 | for (int i = 0; i < M_end; i++) { |
131 | auto vM = i + 1; |
132 | // init only needed brgemm descriptors |
133 | if (one_of(jcp_.exec_type, exec_trans, exec_vpad) && vM != jcp_.M |
134 | && vM != jcp_.M_tail) |
135 | continue; |
136 | for (int bs = 0; bs <= jcp_.max_batch; bs++) { |
137 | if (batchsizes[bs] == -1) continue; |
138 | for_(int i_init = 0; i_init < 2; i_init++) |
139 | for_(int i_N = 0; i_N < 2; i_N++) |
140 | for (int i_K = 0; i_K < 2; i_K++) { |
141 | auto vbeta = (i_init) ? 0 : beta; |
142 | auto vN = (i_N) ? jcp_.N_tail : jcp_.N; |
143 | auto vK = (i_K) ? jcp_.K_tail : jcp_.K; |
144 | auto vbrgM = jcp_.use_M_mask |
145 | ? (vM == jcp_.M ? jcp_.brgM : jcp_.brgM_tail) |
146 | : vM; |
147 | auto brg_idx = get_brg_idx(bs, i, i_init, i_N, i_K); |
148 | // if brgemm_t already created then skip this iteration |
149 | if (brgs_[brg_idx] != nullptr) continue; |
150 | brgs_[brg_idx] = std::make_shared<brgemm_t>(); |
151 | brgemm_t *brg = brgs_[brg_idx].get(); |
152 | if (vN == 0 || vK == 0) continue; |
153 | brgemm_strides_t brg_strides; |
154 | brg_strides.stride_a = jcp_.brg_stride_a; |
155 | brg_strides.stride_b = jcp_.brg_stride_b; |
156 | brg->req_cal_comp_pads = jcp_.req_brg_comp_pad |
157 | && nstl::max(jcp_.l_pad, jcp_.r_pad); |
158 | const auto strides_ptr = (jcp_.brg_type == brgemm_strd) |
159 | ? &brg_strides |
160 | : nullptr; |
161 | CHECK(brgemm_desc_init(brg, isa, jcp_.brg_type, diff_dst_type, |
162 | wei_type, false, false, brgemm_row_major, alpha, vbeta, |
163 | jcp_.LDA, jcp_.LDB, jcp_.LDC, vbrgM, vN, vK, |
164 | strides_ptr)); |
165 | |
166 | brgemm_attr_t brgattr; |
167 | brgattr.use_uker = jcp_.use_uker; |
168 | brgattr.use_interleave_stores = jcp_.use_interleave_stores; |
169 | brgattr.hint_prefetching = jcp_.hint_prefetching; |
170 | brgattr.max_bs = bs; |
171 | brgattr.hint_innermost_loop = jcp_.brgemm_bd_loop_innermost |
172 | ? brgemm_bd_loop_innermost |
173 | : brgemm_ld_loop_innermost; |
174 | if (jcp_.amx_tile_load_xx) { |
175 | // assuming 2x2 decomposition in amx brgemm kernel |
176 | // and overlap of input by kw |
177 | const auto bd_blocking = 2 * jcp_.amx_h; |
178 | const auto ld_blocking = 2 * 16; |
179 | brgattr.hint_expected_A_size = bd_blocking * jcp_.K |
180 | * jcp_.kd_block * jcp_.kh_block; |
181 | brgattr.hint_expected_B_size = ld_blocking * jcp_.K |
182 | * jcp_.kd_block * jcp_.kh_block * jcp_.kw_block; |
183 | brgattr.hint_expected_C_size = bd_blocking * ld_blocking; |
184 | } else { |
185 | brgattr.hint_expected_A_size = 0; |
186 | brgattr.hint_expected_B_size = 0; |
187 | brgattr.hint_expected_C_size = 0; |
188 | } |
189 | |
190 | brgattr.wary_tail_read = false; |
191 | maybe_M_mask(brg_idx, brgattr, vM, vbrgM); |
192 | brgattr.bd_mask_level = jcp_.use_M_mask; |
193 | |
194 | if (is_amx) { |
195 | brgattr.max_top_vpad = 0; |
196 | brgattr.max_bottom_vpad = 0; |
197 | } else { |
198 | brgattr.max_top_vpad = jcp_.max_vpad; |
199 | brgattr.max_bottom_vpad = jcp_.max_vpad; |
200 | } |
201 | brgattr.generate_skip_accumulation = true; |
202 | CHECK(brgemm_desc_set_attr(brg, brgattr)); |
203 | |
204 | auto LDD = jcp_.stride_w * jcp_.ic_without_padding; |
205 | brg->with_sum = with_sum; |
206 | CHECK(brgemm_desc_set_postops( |
207 | brg, attr(), &diff_src_md_, LDD, jcp_.bia_dt)); |
208 | jcp_.amx_buf_size_per_thread |
209 | = nstl::max(brg->get_wsp_buffer_size(), |
210 | jcp_.amx_buf_size_per_thread); |
211 | } |
212 | } |
213 | } |
214 | |
215 | auto scratchpad = scratchpad_registry().registrar(); |
216 | brgemm_convolution_bwd_utils::init_scratchpad(scratchpad, jcp_); |
217 | |
218 | return status::success; |
219 | } |
220 | |
221 | template <cpu_isa_t isa, bool enable_postops> |
222 | status_t brgemm_convolution_bwd_strided_t<isa, enable_postops>::add_brg_kernel( |
223 | int bs, int M, int i_N, int i_K, int i_init) { |
224 | if (M <= 0) return status::success; |
225 | const auto _pd = pd(); |
226 | const auto &jcp = _pd->jcp_; |
227 | const auto &brgs = _pd->brgs_; |
228 | |
229 | auto N = (i_N) ? jcp.N_tail : jcp.N; |
230 | auto K = (i_K) ? jcp.K_tail : jcp.K; |
231 | if (N <= 0 || K <= 0) return status::success; |
232 | auto brg_idx = _pd->get_brg_idx(bs, M - 1, i_init, i_N, i_K); |
233 | auto brg = brgs[brg_idx]; |
234 | if (!brg_kernels_[brg_idx] && brg && brg->bcast_dim > 0 && brg->load_dim > 0 |
235 | && brg->reduce_dim > 0) { |
236 | brgemm_kernel_t *brg_kernel = nullptr; |
237 | CHECK(brgemm_kernel_create(&brg_kernel, *brg)); |
238 | CHECK(safe_ptr_assign(brg_kernels_[brg_idx], brg_kernel)); |
239 | if (is_amx) { |
240 | CHECK(brgemm_init_tiles(*brg, &brg_kernel_palettes_[brg_idx].a[0])); |
241 | } |
242 | } |
243 | return status::success; |
244 | } |
245 | |
246 | template <cpu_isa_t isa, bool enable_postops> |
247 | status_t brgemm_convolution_bwd_strided_t<isa, enable_postops>::init( |
248 | engine_t *engine) { |
249 | |
250 | const auto _pd = pd(); |
251 | const auto &jcp = _pd->jcp_; |
252 | |
253 | bia_dsz = jcp.bia_dsz; |
254 | acc_dsz = jcp.acc_dsz; |
255 | src_dsz = jcp.src_dsz; |
256 | wei_dsz = jcp.wei_dsz; |
257 | dst_dsz = jcp.dst_dsz; |
258 | |
259 | auto ndims = _pd->ndims(); |
260 | if (ndims < 3 || ndims > 5) assert(!"Invalid ndims!" ); |
261 | |
262 | KD = ndims_pick(jcp.kd, 1, 1); |
263 | KH = ndims_pick(jcp.kh, jcp.kh, 1); |
264 | KW = jcp.kw; |
265 | |
266 | EXT_KD = ndims_pick(jcp.ext_kd, 1, 1); |
267 | EXT_KH = ndims_pick(jcp.ext_kh, jcp.ext_kh, 1); |
268 | EXT_KW = jcp.ext_kw; |
269 | |
270 | ODP = ndims_pick(jcp.odp, 1, 1); |
271 | OHP = ndims_pick(jcp.ohp, jcp.ohp, 1); |
272 | OWP = jcp.owp; |
273 | |
274 | KS = KD * KH * KW; |
275 | KD_BLOCK = ndims_pick(jcp.kd_block, 1, 1); |
276 | KH_BLOCK = ndims_pick(jcp.kh_block, jcp.kh_block, 1); |
277 | KW_BLOCK = jcp.kw_block; |
278 | KD_BLOCK_PAD = ndims_pick(jcp.kd_block_pad, 1, 1); |
279 | KH_BLOCK_PAD = ndims_pick(jcp.kh_block_pad, jcp.kh_block_pad, 1); |
280 | ID = ndims_pick(jcp.id, 1, 1); |
281 | IH = ndims_pick(jcp.ih, jcp.ih, 1); |
282 | IW = jcp.iw; |
283 | OD = ndims_pick(jcp.od, 1, 1); |
284 | OH = ndims_pick(jcp.oh, jcp.oh, 1); |
285 | OW = jcp.ow; |
286 | SD = ndims_pick(jcp.stride_d, 1, 1); |
287 | SH = ndims_pick(jcp.stride_h, jcp.stride_h, 1); |
288 | SW = jcp.stride_w; |
289 | FP = ndims_pick(jcp.f_pad, 0, 0); |
290 | TP = ndims_pick(jcp.t_pad, jcp.t_pad, 0); |
291 | LP = jcp.l_pad; |
292 | DD = ndims_pick(jcp.dilate_d, 0, 0) + 1; |
293 | DH = ndims_pick(jcp.dilate_h, jcp.dilate_h, 0) + 1; |
294 | DW = jcp.dilate_w + 1; |
295 | |
296 | oc_chunks = div_up(jcp.nb_oc, jcp.nb_oc_blocking); |
297 | |
298 | // const variables used for address calculations |
299 | src_w_sz = static_cast<dim_t>(OW) * jcp.ngroups * jcp.oc_without_padding; |
300 | src_h_sz = OH * src_w_sz; |
301 | src_d_sz = OD * src_h_sz; |
302 | dst_w_sz = static_cast<dim_t>(IW) * jcp.ic_without_padding; |
303 | dst_h_sz = IH * dst_w_sz; |
304 | dst_d_sz = ID * dst_h_sz; |
305 | |
306 | wei_oc_sz = static_cast<dim_t>(jcp.ocp) * jcp.ic_block; |
307 | wei_kw_sz = KW * wei_oc_sz; |
308 | wei_kh_sz = KH * wei_kw_sz; |
309 | wei_kd_sz = KD * wei_kh_sz; |
310 | wei_icb_sz = jcp.nb_ic * wei_kd_sz; |
311 | |
312 | need_postwork = jcp.with_bias || jcp.with_eltwise || jcp.with_binary |
313 | || (jcp.dst_dt != jcp.acc_dt) || jcp.with_sum || jcp.use_M_mask |
314 | || jcp.src_zero_point || jcp.dst_zero_point; |
315 | |
316 | // ---- Initialize arrays --------------------- |
317 | brg_kernels_.resize(_pd->brgs_sz_); |
318 | brg_kernel_palettes_.resize(_pd->brgs_sz_); |
319 | |
320 | for (int i = 0; i < _pd->brgs_sz_; i++) |
321 | brg_kernels_[i] = nullptr; |
322 | |
323 | CHECK(safe_ptr_assign(copy_to_pbuffer_, |
324 | new jit_avx512_core_brgemm_conv_bwd_trans_kernel_t(jcp))); |
325 | CHECK(copy_to_pbuffer_->create_kernel()); |
326 | |
327 | const auto ow_block = jcp.owp; |
328 | const auto oh_block = jcp.ohp; |
329 | const auto od_block = jcp.odp; |
330 | |
331 | pbuf_w_sz = (dim_t)jcp.oc_block * ow_block; |
332 | pbuf_h_sz = pbuf_w_sz * oh_block; |
333 | pbuf_d_sz = pbuf_h_sz * od_block; |
334 | |
335 | is_amx = brgemm_convolution_bwd_utils::is_amx(isa); |
336 | |
337 | // TODO: this is only needed if we have d/h padding exceeding kd/kh |
338 | int M_begin = 0; |
339 | int M_end = (jcp.M_tail == jcp.M) ? 1 : 2; |
340 | int N_begin = 0; |
341 | int N_end = (jcp.N_tail == jcp.N) ? 1 : 2; |
342 | int K_begin = 0; |
343 | int K_end = (jcp.K_tail == jcp.K) ? 1 : 2; |
344 | int i_init_begin = (div_up(jcp.nb_oc, jcp.nb_oc_blocking) == 1 |
345 | && KD_BLOCK == KD && KH_BLOCK == KH) |
346 | ? 1 |
347 | : 0; |
348 | int i_init_end = 2; |
349 | |
350 | for (int bs = 0; bs <= jcp.max_batch; bs++) { |
351 | if (_pd->batchsizes[bs] == -1) continue; |
352 | |
353 | for_(int i_N = N_begin; i_N < N_end; i_N++) |
354 | for_(int i_M = M_begin; i_M < M_end; i_M++) |
355 | for_(int i_init = i_init_begin; i_init < i_init_end; i_init++) |
356 | for (int i_K = K_begin; i_K < K_end; i_K++) { |
357 | auto M = (i_M) ? jcp.M_tail : jcp.M; |
358 | if (M <= 0) continue; |
359 | add_brg_kernel(bs, M, i_N, i_K, i_init); |
360 | } |
361 | } |
362 | |
363 | return status::success; |
364 | } |
365 | |
366 | template <cpu_isa_t isa, bool enable_postops> |
367 | status_t brgemm_convolution_bwd_strided_t<isa, enable_postops>::execute( |
368 | const exec_ctx_t &ctx) const { |
369 | const auto _pd = pd(); |
370 | const auto &jcp = _pd->jcp_; |
371 | |
372 | // XXX: brgemm requires scales to be passed, so passing default wei scales |
373 | DEFINE_ARG_SCALES_BUFFER(oscales, DNNL_ARG_WEIGHTS); |
374 | |
375 | const memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor(); |
376 | brgemm_batch_element_t *const __restrict brg_batch_global |
377 | = (jcp.brg_type == brgemm_strd && jcp.exec_type != exec_vpad) |
378 | ? nullptr |
379 | : scratchpad.template get<brgemm_batch_element_t>( |
380 | key_brgemm_primitive_batch); |
381 | char *const __restrict c_buffer_global = (jcp.use_buffer) |
382 | ? scratchpad.template get<char>(key_brgemm_primitive_buffer) |
383 | : nullptr; |
384 | |
385 | auto inp_p_buffer = (jcp.exec_type == exec_trans) |
386 | ? scratchpad.template get<char>(key_conv_brgemm_inp_buffer) |
387 | : nullptr; |
388 | auto inp_p_buffer_mask = (jcp.exec_type == exec_trans) |
389 | ? scratchpad.template get<uint8_t>(key_conv_brgemm_inp_buffer_mask) |
390 | : nullptr; |
391 | |
392 | char *const wsp_tile_global = is_amx |
393 | ? scratchpad.template get<char>(key_conv_amx_tile_buffer) |
394 | : nullptr; |
395 | |
396 | brgemm_bwd_exec_ctx_t brgemm_ctx(ctx, _pd); |
397 | |
398 | const char *const __restrict diff_dst = brgemm_ctx.diff_dst; |
399 | |
400 | const dim_t work_amount = static_cast<dim_t>(jcp.mb) * jcp.ngroups |
401 | * jcp.nb_ic * jcp.nb_id * jcp.nb_ih * jcp.nb_iw; |
402 | |
403 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
404 | if (ithr >= work_amount) return; |
405 | |
406 | brgemm_batch_element_t *const __restrict brg_batch = brg_batch_global |
407 | + static_cast<size_t>(ithr) * jcp.adjusted_batch_size; |
408 | char *const __restrict c_buffer = (jcp.use_buffer) |
409 | ? c_buffer_global + ithr * acc_dsz * jcp.buffer_size |
410 | : nullptr; |
411 | char *inp_buffer = (jcp.exec_type == exec_trans) |
412 | ? inp_p_buffer + src_dsz * ithr * jcp.inp_buffer_size |
413 | : nullptr; |
414 | if (is_amx) { |
415 | // Workaround: for some machines SEGFAULT possible on tile load |
416 | // if the page was not touched before it |
417 | for (dim_t i = 0; i < jcp.inp_buffer_size; |
418 | i += brgemm_convolution_bwd_utils::P4K) |
419 | inp_buffer[i] = 0; |
420 | } |
421 | |
422 | uint8_t *__restrict inp_buffer_mask = (jcp.exec_type == exec_trans) |
423 | ? inp_p_buffer_mask + ithr * jcp.inp_buffer_mask_size |
424 | : nullptr; |
425 | |
426 | char *const wsp_tile = is_amx |
427 | ? wsp_tile_global + ithr * 2 * brgemm_convolution_bwd_utils::P4K |
428 | : nullptr; |
429 | dim_t start {0}, end {0}; |
430 | balance211(work_amount, nthr, ithr, start, end); |
431 | int n {0}, g {0}, icb {0}, idb {0}, ihb {0}, iwb {0}; |
432 | |
433 | nd_iterator_init(start, n, jcp.mb, idb, jcp.nb_id, ihb, jcp.nb_ih, iwb, |
434 | jcp.nb_iw, g, jcp.ngroups, icb, jcp.nb_ic); |
435 | |
436 | brgemm_bwd_thread_ctx_t btc( |
437 | brgemm_ctx, ithr, brg_batch, c_buffer, wsp_tile); |
438 | std::memset(btc.cur_palette.a, 0, AMX_PALETTE_SIZE); |
439 | |
440 | int last_n = -1; |
441 | int last_g = -1; |
442 | int last_occ = -1; |
443 | int last_idb = -1; |
444 | int last_ihb = -1; |
445 | int last_iwb = -1; |
446 | for (auto work = start; work < end; work++) { |
447 | btc.g = g; |
448 | btc.n = n; |
449 | btc.icb = icb; |
450 | btc.idb = idb; |
451 | btc.ihb = ihb; |
452 | btc.iwb = iwb; |
453 | btc.oscales = oscales; |
454 | |
455 | auto id_begin = idb * jcp.id_block; |
456 | auto id_end = nstl::min(ID, id_begin + jcp.id_block); |
457 | auto ih_begin = ihb * jcp.ih_block; |
458 | auto ih_end = nstl::min(IH, ih_begin + jcp.ih_block); |
459 | |
460 | for_(int id = id_begin; id < id_end; id++) |
461 | for_(int ih = ih_begin; ih < ih_end; ih++) |
462 | for (int occ = 0; occ < oc_chunks; occ++) { |
463 | btc.id = id; |
464 | btc.ih = ih; |
465 | btc.occ = occ; |
466 | |
467 | maybe_trans_inp(ithr, diff_dst, inp_buffer, inp_buffer_mask, g, |
468 | n, occ, idb, ihb, iwb, last_g, last_n, last_occ, |
469 | last_idb, last_ihb, last_iwb); |
470 | for (int sw = 0; sw < SW; sw++) { |
471 | btc.sw = sw; |
472 | ker_trans(btc, inp_buffer); |
473 | } |
474 | |
475 | last_n = n; |
476 | last_g = g; |
477 | last_occ = occ; |
478 | last_idb = idb; |
479 | last_ihb = ihb; |
480 | last_iwb = iwb; |
481 | } |
482 | nd_iterator_step(n, jcp.mb, idb, jcp.nb_id, ihb, jcp.nb_ih, iwb, |
483 | jcp.nb_iw, g, jcp.ngroups, icb, jcp.nb_ic); |
484 | } |
485 | if (is_amx) { amx_tile_release(); } |
486 | }); |
487 | |
488 | return status::success; |
489 | } |
490 | |
491 | template <cpu_isa_t isa, bool enable_postops> |
492 | void brgemm_convolution_bwd_strided_t<isa, enable_postops>::call_brgemm_kernel( |
493 | brgemm_bwd_thread_ctx_t &btc, int brg_idx, int batch_size, char *ptr_C, |
494 | char *ptr_D, const char *bias_w, int g_ic, bool do_postops, |
495 | const void *binary_post_ops_rhs, int32_t src_zp_vals, |
496 | int32_t *src_zp_ptr, int32_t *dst_zp_ptr, int32_t *s8s8_comp, |
497 | bool do_only_comp, bool is_first_call_postops) const { |
498 | |
499 | const auto _pd = pd(); |
500 | const auto &jcp = _pd->jcp_; |
501 | |
502 | const auto brg_ker = brg_kernels_[brg_idx].get(); |
503 | assert(brg_ker != nullptr); |
504 | |
505 | if (is_first_call_postops) return; |
506 | |
507 | if (is_amx) { |
508 | if (std::memcmp(btc.cur_palette.a, brg_kernel_palettes_[brg_idx].a, |
509 | AMX_PALETTE_SIZE) |
510 | != 0) { |
511 | amx_tile_configure(brg_kernel_palettes_[brg_idx].a); |
512 | std::memcpy(btc.cur_palette.a, brg_kernel_palettes_[brg_idx].a, |
513 | AMX_PALETTE_SIZE); |
514 | } |
515 | } |
516 | |
517 | const auto do_only_pass_comp = !do_postops && jcp.src_zero_point |
518 | && (jcp.req_brg_comp_pad || jcp.max_vpad > 0); |
519 | const auto do_skip_accm = batch_size == 0; |
520 | const auto maybe_do_postops = one_of( |
521 | true, do_postops, do_only_comp, do_only_pass_comp, do_skip_accm); |
522 | if (maybe_do_postops) { |
523 | const brgemm_post_ops_data_t post_ops_data { |
524 | static_cast<const char *>(bias_w), |
525 | &btc.oscales[jcp.is_ic_scale * g_ic], binary_post_ops_rhs, |
526 | static_cast<size_t>(g_ic), 0, btc.brgemm_ctx.dst, 0, |
527 | static_cast<void *>(src_zp_ptr), nullptr, |
528 | static_cast<void *>(dst_zp_ptr), do_skip_accm, src_zp_vals, |
529 | do_only_comp, do_only_pass_comp}; |
530 | |
531 | void *scratch = is_amx ? static_cast<void *>(btc.wsp_tile) |
532 | : static_cast<void *>(s8s8_comp); |
533 | |
534 | if (do_postops || do_skip_accm) |
535 | brgemm_kernel_execute_postops(brg_ker, batch_size, btc.brg_batch, |
536 | ptr_C, ptr_D, post_ops_data, scratch); |
537 | else |
538 | brgemm_kernel_execute_postops(brg_ker, batch_size, btc.brg_batch, |
539 | ptr_C, ptr_C, post_ops_data, scratch); |
540 | } else |
541 | brgemm_kernel_execute(brg_ker, batch_size, btc.brg_batch, ptr_C, |
542 | static_cast<void *>(btc.wsp_tile)); |
543 | } |
544 | |
545 | template <cpu_isa_t isa, bool enable_postops> |
546 | void brgemm_convolution_bwd_strided_t<isa, enable_postops>::maybe_trans_inp( |
547 | int ithr, const char *__restrict src, char *__restrict inp_buffer, |
548 | uint8_t *__restrict inp_buffer_mask, int g, int n, int occ, int idb, |
549 | int ihb, int iwb, int last_g, int last_n, int last_occ, int last_idb, |
550 | int last_ihb, int last_iwb) const { |
551 | |
552 | const auto _pd = pd(); |
553 | const auto &jcp = _pd->jcp_; |
554 | const auto ocb = occ * jcp.nb_oc_blocking; |
555 | |
556 | if (last_g == g && last_n == n && last_occ == occ && last_idb == idb |
557 | && last_ihb == ihb && last_iwb == iwb) |
558 | return; |
559 | |
560 | auto cp = jit_brgemm_conv_bwd_trans_kernel_call_s(); |
561 | |
562 | const auto oc = ocb * jcp.oc_block; |
563 | const auto g_oc = g * jcp.oc + oc; |
564 | |
565 | const auto sw = jcp.l_pad % jcp.stride_w; |
566 | const auto kw = (jcp.kw - 1) % jcp.stride_w; |
567 | const auto kw_x = (jcp.kw - 1) - nstl::modulo(kw - sw, jcp.stride_w); |
568 | const auto ow = (iwb * jcp.iw_block + jcp.l_pad - kw_x * (jcp.dilate_w + 1)) |
569 | / jcp.stride_w; |
570 | |
571 | int od_start {0}, od_end {0}, oh_start {0}, oh_end {0}; |
572 | |
573 | const auto sh = jcp.t_pad % jcp.stride_h; |
574 | const auto kh = (jcp.kh - 1) % jcp.stride_h; |
575 | const auto kh_x = (jcp.kh - 1) - nstl::modulo(kh - sh, jcp.stride_h); |
576 | oh_start = (ihb * jcp.ih_block + jcp.t_pad - kh_x * (jcp.dilate_h + 1)) |
577 | / jcp.stride_h; |
578 | oh_end = oh_start + jcp.oh_block; |
579 | |
580 | const auto sd = jcp.f_pad % jcp.stride_d; |
581 | const auto kd = (jcp.kd - 1) % jcp.stride_d; |
582 | const auto kd_x = (jcp.kd - 1) - nstl::modulo(kd - sd, jcp.stride_d); |
583 | od_start = (idb * jcp.id_block + jcp.f_pad - kd_x * (jcp.dilate_d + 1)) |
584 | / jcp.stride_d; |
585 | od_end = od_start + jcp.od_block; |
586 | |
587 | const auto rows_to_copy = min(jcp.oh, oh_end) - max(0, oh_start); |
588 | cp.iwb = iwb; |
589 | cp.oc = oc; |
590 | const auto ow_buf = ow; |
591 | dim_t inp_offset_start, out_offset_start; |
592 | |
593 | cp.t_pad = 0; |
594 | cp.b_pad = 0; |
595 | cp.h_count = max(0, rows_to_copy); |
596 | |
597 | const auto oh_buf = max(0, oh_start); |
598 | |
599 | inp_offset_start = static_cast<dim_t>(n) * src_d_sz |
600 | + max(0, oh_start) * src_w_sz |
601 | + max(0, ow) * jcp.ngroups * jcp.oc_without_padding + g_oc; |
602 | out_offset_start = oh_buf * pbuf_w_sz + ow_buf * jcp.oc_block; |
603 | |
604 | for (int od = max(0, od_start); od < min(jcp.od, od_end); od++) { |
605 | const auto inp_offset = inp_offset_start + od * src_h_sz; |
606 | const auto od_buf = od; |
607 | const auto out_offset = out_offset_start + od_buf * pbuf_h_sz; |
608 | cp.src = src + src_dsz * inp_offset; |
609 | cp.dst = inp_buffer + src_dsz * out_offset; |
610 | |
611 | (*copy_to_pbuffer_)(&cp); |
612 | } |
613 | } |
614 | |
// Computes the diff_src block addressed by `btc` (g, n, icb, id, ih, iwb, sw,
// occ) from the data previously copied into `inp_buffer`.
//
// Steps:
//   1. derive, per spatial dimension, the range of kernel indices
//      [k_s, k_f) that contributes to the current position (set_k_range);
//   2. walk that range in KD_BLOCK x KH_BLOCK pieces (kdhw_loop);
//   3. for every piece fill btc.brg_batch with A (inp_buffer) and B (weights)
//      pointers and run the selected BRGEMM kernel (call_brgemm).
//
// NOTE: the three lambdas below communicate through locals captured by
// reference (kd_b, kd_e, kh_b, kh_e, k_l, is_first_call_postops, ...), so the
// order of the statements matters.
615 | template <cpu_isa_t isa, bool enable_postops> |
616 | void brgemm_convolution_bwd_strided_t<isa, enable_postops>::ker_trans( |
617 | brgemm_bwd_thread_ctx_t &btc, char *inp_buffer) const { |
618 | |
619 | const auto _pd = pd(); |
620 | const auto &jcp = _pd->jcp_; |
621 | auto ndims = _pd->ndims(); |
622 | |
623 | const char *const __restrict weights = btc.brgemm_ctx.weights; |
624 | const char *const __restrict bias = btc.brgemm_ctx.bias; |
625 | char *const __restrict dst = btc.brgemm_ctx.dst; |
626 | const std::vector<const void *> &post_ops_binary_rhs_arg_vec |
627 | = btc.brgemm_ctx.post_ops_binary_rhs_arg_vec; |
628 | const int ic = btc.icb * jcp.ic_block; |
629 | const int g_ic = btc.g * jcp.ic + ic; |
630 | const int ocb = btc.occ * jcp.nb_oc_blocking; |
631 | const int oc = ocb * jcp.oc_block; |
// The stride phase `sw` shifts the starting width position inside the block.
632 | const dim_t iw = btc.iwb * jcp.iw_block + btc.sw; |
633 | const dim_t ih = btc.ih; |
634 | const dim_t id = btc.id; |
635 | |
636 | // od = (id + FP - kd * DD) / SD <-- general relation for all sets of (od, id, kd) that overlap |
637 | // for a given index from diff_src, we need to find the appropriate stride sector |
638 | int kd_s_(0), kh_s_(0), kw_s(0), kd_f_(0), kh_f_(0), kw_f(0); |
639 | |
// set_k_range(P, D, S, i, O, K, k_s, k_f, is_w), with P = padding,
// D = dilation + 1, S = stride, i = input index, O = output size and
// K = kernel size, as passed by the three calls below:
//   * finds the smallest s >= 0 such that (i + P - s * D) is divisible by S;
//   * sets [k_s, k_f) to [0, K) for the width dimension (is_w), otherwise
//     clamps it using i, P, O, S and D;
//   * advances k_s until k_s % S == s.
640 | auto set_k_range = [&](int P, int D, int S, dim_t i, dim_t O, int K, |
641 | int &k_s, int &k_f, bool is_w) { |
642 | int s(0), o_test(0); |
643 | while (true) { |
644 | o_test = i + P - s * D; |
645 | if (o_test % S == 0) break; |
646 | s++; |
647 | } |
648 | |
649 | k_f = is_w ? K : min(K, static_cast<int>(div_up(i + P + 1, D))); |
650 | k_s = is_w ? 0 : max(0, static_cast<int>(div_up(i + P - O * S + 1, D))); |
651 | |
652 | while (k_s % S != s) |
653 | k_s++; |
654 | }; |
655 | |
656 | set_k_range(FP, DD, SD, id, OD, KD, kd_s_, kd_f_, false); |
657 | set_k_range(TP, DH, SH, ih, OH, KH, kh_s_, kh_f_, false); |
658 | set_k_range(LP, DW, SW, iw, OW, KW, kw_s, kw_f, true); |
659 | |
// Dimensions that do not exist for the given rank use the fixed range [0, 1).
660 | const auto kh_f = ndims_pick(kh_f_, kh_f_, 1); |
661 | const auto kh_s = ndims_pick(kh_s_, kh_s_, 0); |
662 | |
663 | const auto kd_f = ndims_pick(kd_f_, 1, 1); |
664 | const auto kd_s = ndims_pick(kd_s_, 0, 0); |
665 | |
666 | const bool is_oc_tail |
667 | = (btc.occ == oc_chunks - 1 && ((jcp.oc - oc) % jcp.oc_block != 0)); |
668 | |
669 | const bool is_ic_tail = (jcp.ic - ic < jcp.ic_block); |
670 | const char *const __restrict bias_w |
671 | = bias ? bias + (bias_d.blk_off(g_ic) * bia_dsz) : nullptr; |
// Number of full OC blocks in this chunk; a tail block, if any, is handled
// by a separate kernel call.
672 | const auto nb_oc_b = nstl::min(jcp.nb_oc_blocking, jcp.nb_oc - ocb) |
673 | - (is_oc_tail ? 1 : 0); |
674 | char *const __restrict dst_base = dst + dst_dsz * (btc.n * dst_d_sz + g_ic); |
675 | char *ptr_C; |
676 | char *ptr_D; |
677 | int kd_b(0), kd_e(0), kh_b(0), kh_e(0), k_l(0); |
678 | |
679 | const auto wei_base |
680 | = weights + wei_dsz * (btc.g * wei_icb_sz + btc.icb * wei_kd_sz); |
681 | const dim_t iw_b {iw}; |
682 | |
683 | ptr_D = dst_base |
684 | + dst_dsz |
685 | * (id * dst_h_sz + ih * dst_w_sz |
686 | + iw_b * jcp.ic_without_padding); |
// Accumulate either in the intermediate buffer or directly in the output.
687 | ptr_C = (jcp.use_buffer) ? btc.c_buffer : static_cast<char *>(ptr_D); |
688 | |
689 | const auto ker_i = (jcp.M > 0 ? jcp.M : jcp.M_tail) - 1; |
690 | |
691 | bool is_first_call_postops = false, |
692 | is_first_call_postops_state_changed = false; |
// call_brgemm: fills the batch for `n_oc_blocks` OC blocks starting at block
// `oc_block_s` over the current [kd_b, kd_e) x [kh_b, kh_e) x [kw_s, kw_f)
// kernel range and runs kernel `brg_idx` with the number of elements written.
693 | const auto call_brgemm = [&](int brg_idx, int oc_block_s, int n_oc_blocks, |
694 | bool do_postops) { |
695 | const auto kh_ee = kh_e; |
696 | const auto kw_e = kw_f; |
697 | const auto pbuf_base = inp_buffer; |
698 | |
699 | int k_sum = 0; |
700 | for (int i_ocb = 0; i_ocb < n_oc_blocks; i_ocb++) { |
701 | const auto oc_off = (oc_block_s + i_ocb) * jcp.oc_block; |
702 | const auto wei_oc = oc + oc_off; |
703 | const auto n_ocb_off = i_ocb * k_l; |
// NOTE(review): the A base pointer is the same for every OC block of the
// chunk; only the weights pointer is advanced per block. Confirm this matches
// the layout produced by the copy kernel when n_oc_blocks > 1.
704 | const auto pbuf_base_oc = pbuf_base; |
705 | const auto wei_base_oc = wei_base + wei_dsz * wei_oc * jcp.ic_block; |
706 | |
707 | auto k = 0; |
708 | for (int kd = kd_b; kd < kd_e; kd++) { |
// Only kernel indices that map to an integer output coordinate contribute.
709 | auto od = (id - kd * DD + FP); |
710 | if (od % SD != 0) continue; |
711 | od /= SD; |
712 | const auto pbuf_base_kd |
713 | = pbuf_base_oc + src_dsz * od * pbuf_h_sz; |
714 | const auto wei_base_kd = wei_base_oc + wei_dsz * kd * wei_kh_sz; |
715 | for (int kh = kh_b; kh < kh_ee; kh++) { |
716 | auto oh = (ih - kh * DH + TP); |
717 | if (oh % SH != 0) continue; |
718 | oh /= SH; |
719 | const auto pbuf_base_kh |
720 | = pbuf_base_kd + src_dsz * oh * pbuf_w_sz; |
721 | const auto wei_base_kh |
722 | = wei_base_kd + wei_dsz * kh * wei_kw_sz; |
// kw_s was aligned to the stride phase by set_k_range, so stepping by SW
// visits only the contributing kernel columns.
723 | for (int kw = kw_s; kw < kw_e; kw += SW) { |
724 | const auto ow = (iw - kw * DW + LP) / SW; |
725 | // inp_buffer layout is Cdhw<oc_block>c |
726 | btc.brg_batch[n_ocb_off + k].ptr.A = pbuf_base_kh |
727 | + src_dsz * (ow + jcp.l_ovf) * jcp.oc_block; |
728 | btc.brg_batch[n_ocb_off + k].vvpad.top = 0; |
729 | btc.brg_batch[n_ocb_off + k].vvpad.bottom = 0; |
730 | // general wei layout is gIdhwO<block_i><block_o> |
731 | btc.brg_batch[n_ocb_off + k].ptr.B |
732 | = wei_base_kh + wei_dsz * kw * wei_oc_sz; |
733 | k++; |
734 | } |
735 | } |
736 | } |
737 | k_sum += k; |
738 | } |
739 | call_brgemm_kernel(btc, brg_idx, k_sum, ptr_C, ptr_D, bias_w, g_ic, |
740 | do_postops, post_ops_binary_rhs_arg_vec.data(), 0, nullptr, |
741 | nullptr, nullptr, false, is_first_call_postops); |
// The flag is evaluated once, after the first kernel call of this ker_trans
// invocation; when it becomes true, later call_brgemm_kernel calls return
// immediately.
742 | if (!is_first_call_postops_state_changed) { |
743 | const auto do_only_pass_comp = !do_postops && jcp.src_zero_point |
744 | && (jcp.req_brg_comp_pad || jcp.max_vpad > 0); |
745 | const auto do_skip_accm = k_sum == 0; |
746 | is_first_call_postops |
747 | = one_of(true, do_postops, do_only_pass_comp, do_skip_accm); |
748 | is_first_call_postops_state_changed = true; |
749 | } |
750 | |
751 | MAYBE_UNUSED(bias_w); |
752 | MAYBE_UNUSED(ptr_C); |
753 | MAYBE_UNUSED(post_ops_binary_rhs_arg_vec); |
754 | }; |
755 | |
// kdhw_loop: processes the current [kd_b, kd_e) x [kh_b, kh_e) piece. It
// decides whether this call initializes the accumulator (first OC chunk and
// first piece) and whether post-work is applied (last OC chunk and last
// piece), then runs the kernel for the full OC blocks and, separately, for
// the OC tail block.
756 | const auto kdhw_loop = [&]() { |
757 | const auto do_init = btc.occ == 0 && kd_b == kd_s && kh_b == kh_s; |
758 | const auto do_postwork = need_postwork && btc.occ == (oc_chunks - 1) |
759 | && kd_e == kd_f && kh_e == kh_f; |
760 | |
761 | const int kd_l = div_up(kd_e - kd_b, SD); |
762 | const int kh_l = div_up(kh_e - kh_b, SH); |
763 | const int kw_l = div_up(kw_f - kw_s, SW); |
764 | k_l = kd_l * kh_l * kw_l; |
765 | |
// kernel_idx[init][oc_tail]: descriptor index for each init/tail variant.
766 | int kernel_idx[2][2]; |
767 | kernel_idx[false][false] |
768 | = _pd->get_brg_idx(k_l, ker_i, false, is_ic_tail, false); |
769 | kernel_idx[true][false] |
770 | = _pd->get_brg_idx(k_l, ker_i, true, is_ic_tail, false); |
771 | kernel_idx[false][true] |
772 | = _pd->get_brg_idx(k_l, ker_i, false, is_ic_tail, true); |
773 | kernel_idx[true][true] |
774 | = _pd->get_brg_idx(k_l, ker_i, true, is_ic_tail, true); |
775 | |
776 | if (nb_oc_b > 0) { |
777 | const auto brg_idx = kernel_idx[do_init][false]; |
778 | call_brgemm(brg_idx, 0, nb_oc_b, do_postwork && !is_oc_tail); |
779 | } |
780 | |
781 | if (is_oc_tail) { |
// The tail call initializes the accumulator only if no full-block call did.
782 | const auto use_init_ker = (do_init && nb_oc_b == 0); |
783 | const auto brg_oc_tail_idx = kernel_idx[use_init_ker][true]; |
784 | call_brgemm(brg_oc_tail_idx, nb_oc_b, 1, do_postwork); |
785 | } |
786 | }; |
787 | |
788 | if (kd_f > kd_s && kh_f > kh_s) { |
789 | // kw values covering full ow_block |
790 | for (kd_b = kd_s; kd_b < kd_f; kd_b += KD_BLOCK) { |
791 | kd_e = nstl::min(kd_f, kd_b + KD_BLOCK); |
792 | for (kh_b = kh_s; kh_b < kh_f; kh_b += KH_BLOCK) { |
793 | kh_e = nstl::min(kh_f, kh_b + KH_BLOCK); |
794 | kdhw_loop(); |
795 | } |
796 | } |
797 | } else { |
// Empty kd or kh range: run the loop once with empty ranges so that the
// kernel is still called with a zero-sized batch.
798 | kd_b = kd_e = kd_s; |
799 | kh_b = kh_e = kh_s; |
800 | kdhw_loop(); |
801 | } |
802 | } |
803 | |
// Explicit instantiations: each listed ISA is instantiated twice, once with
// the default second template argument and once with enable_postops == true.
804 | template struct brgemm_convolution_bwd_strided_t<avx512_core_amx>; |
805 | template struct brgemm_convolution_bwd_strided_t<avx512_core_amx, true>; |
806 | template struct brgemm_convolution_bwd_strided_t<avx512_core_amx_fp16>; |
807 | template struct brgemm_convolution_bwd_strided_t<avx512_core_amx_fp16, true>; |
808 | |
809 | } // namespace x64 |
810 | |
811 | } // namespace cpu |
812 | } // namespace impl |
813 | } // namespace dnnl |
814 | |
815 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
816 | |