1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | #include "common/utils.hpp" |
21 | #include "cpu/cpu_primitive.hpp" |
22 | #include "cpu/scale_utils.hpp" |
23 | |
24 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
25 | #include "cpu/x64/jit_brgemm_conv.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
32 | using namespace dnnl::impl::status; |
33 | using namespace dnnl::impl::memory_tracking::names; |
34 | using namespace dnnl::impl::utils; |
35 | |
36 | using namespace nstl; |
37 | using namespace data_type; |
38 | |
39 | using namespace jit_avx512_core_brgemm_conv_trans_kernel; |
40 | using namespace jit_avx512_core_brgemm_conv_comp_pad_kernel; |
41 | |
42 | #define ndims_pick(v5, v4, v3) \ |
43 | ((ndims == 5) ? (v5) : (ndims == 4) ? (v4) : (ndims == 3) ? (v3) : 0) |
44 | |
// Initializes the primitive descriptor: validates the convolution problem
// and attributes, fills jcp_ via init_conf, enumerates the batch sizes the
// uker path can encounter, and pre-creates every brgemm descriptor (brgs_)
// together with its bd-mask and post-ops configuration.
template <cpu_isa_t isa, bool use_inversion>
status_t brgemm_convolution_fwd_t<isa, use_inversion>::pd_t::init(
        engine_t *engine) {
    using namespace data_type;
    using namespace utils;

    const auto src_type = src_md(0)->data_type;
    const auto wei_type = weights_md(0)->data_type;
    const auto dst_type = dst_md(0)->data_type;
    const bool is_int8 = one_of(src_type, u8, s8);

    // int8 additionally allows runtime scales in the attributes.
    using skip_mask_t = primitive_attr_t::skip_mask_t;
    auto skip_mask = skip_mask_t::post_ops | skip_mask_t::sum_dt
            | skip_mask_t::zero_points_runtime;
    if (is_int8) skip_mask |= skip_mask_t::scales_runtime;

    // Applicability checks: forward direct convolution, bias data type
    // consistent with the source type, attributes restricted to skip_mask,
    // no zero-sized dims, valid zero points and scales.
    bool ok = is_fwd() && set_default_alg_kind(alg_kind::convolution_direct)
            && IMPLICATION(is_int8,
                    one_of(bias_md_.data_type, data_type::undef, f32, s32, s8,
                            u8))
            && IMPLICATION(!is_int8,
                    one_of(bias_md_.data_type, data_type::undef, f32, src_type))
            && attr()->has_default_values(skip_mask, dst_type)
            && attr()->post_ops_.check_sum_consistent_dt(dst_type)
            && !has_zero_dim_memory() && zero_points_ok() && arg_scales_ok();
    if (!ok) return status::unimplemented;
    const auto is_amx = brgemm_convolution_utils::is_amx(isa);

    CHECK(brgemm_convolution_utils::init_conf(jcp_, isa, *desc(), src_md_,
            weights_md_, dst_md_, bias_md_, attr_, dnnl_get_max_threads()));

    const auto adj_M = nstl::max(jcp_.M, jcp_.M_tail);

    // 1. Use unrolled kernel for exec_trans only to avoid creation a lot of
    // kernels for each kw range
    // 2. For exec_trans block by kw is always KW
    assert(IMPLICATION(jcp_.use_uker, is_amx && jcp_.exec_type == exec_trans));
    assert(IMPLICATION(jcp_.use_interleave_stores, jcp_.use_uker));

    // batchsizes[bs] maps an actual brgemm batch size to a dense index;
    // -1 marks batch sizes that never occur.
    batchsizes.resize(jcp_.max_batch + 1);
    for (int i = 0; i <= jcp_.max_batch; i++)
        batchsizes[i] = -1;

    first_bs = 0;
    bs_c = 0;

    const auto SD = jcp_.stride_d;
    const auto FP = jcp_.f_pad;
    const auto DD = jcp_.dilate_d + 1;
    const auto KD = jcp_.kd;
    const auto ID = jcp_.id;

    const auto SH = jcp_.stride_h;
    const auto TP = jcp_.t_pad;
    const auto DH = jcp_.dilate_h + 1;
    const auto KH = jcp_.kh;
    const auto KW = jcp_.kw;
    const auto IH = jcp_.ih;

    const auto KD_BLOCK = jcp_.kd_block;
    const auto KH_BLOCK = jcp_.kh_block;
    const auto KW_BLOCK = jcp_.kw_block;

    if (jcp_.use_uker) {

        assert(KD % KD_BLOCK == 0);
        assert(KH % KH_BLOCK == 0);

        // Enumerate every batch size that can occur: for each (od, oh)
        // output point the batch size is the number of (kd, kh, kw) taps
        // reading non-padded input, clipped by the kd/kh blocking.
        for (int iod = 0; iod < jcp_.od; iod++) {
            const int iid = iod * SD - FP;
            const int kd_s = div_up(max(0, -iid), DD);
            const int kd_f
                    = KD - div_up(max(0, iid - ID + (KD - 1) * DD + 1), DD);
            const auto kd_l = nstl::min(KD_BLOCK, kd_f - kd_s);
            for (int ioh = 0; ioh < jcp_.oh; ioh++) {

                const auto iih = ioh * SH - TP;
                const auto kh_s
                        = jcp_.is_os_blocking ? 0 : div_up(max(0, -iih), DH);
                const auto kh_f
                        = KH - div_up(max(0, iih - IH + (KH - 1) * DH + 1), DH);
                const auto kh_l = nstl::min(KH_BLOCK, kh_f - kh_s);
                const auto bs = kd_l * kh_l * jcp_.kw;
                // Fully padded points produce a negative product; skip them.
                if (bs < 0) continue;

                if (batchsizes[bs] == -1) {
                    batchsizes[bs] = bs_c;
                    // Remember the first batch size seen; used as the
                    // reference bs when creating post-ops kernels.
                    if (first_bs == 0) first_bs = bs;
                    bs_c++;
                }
            }
        }
    } else {
        // Non-uker path always runs with the maximum batch size.
        batchsizes[jcp_.max_batch] = bs_c;
        first_bs = jcp_.max_batch;
        bs_c++;
    }

    // One descriptor slot per (bs index, M, i_init, i_N, i_K) combination.
    brgs_sz_ = bs_c * adj_M * 2 * 2 * 2;
    brgs_.resize(brgs_sz_);
    bd_masks.resize(brgs_sz_);

    const float alpha = 1.0;
    const float beta = 1.0;

    const auto &p = attr()->post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    with_sum = (sum_idx != -1);

    // os_blocking is supported for exec_trans only
    assert(IMPLICATION(jcp_.exec_type != exec_trans, !jcp_.is_os_blocking));
    assert(IMPLICATION(jcp_.is_os_blocking,
            jcp_.os_block % jcp_.ow == 0 && jcp_.os_block / jcp_.ow <= jcp_.oh
                    && jcp_.os_block / jcp_.ow == jcp_.oh_block));

    // Builds the bd (row) mask for os-blocked kernels: real output rows get
    // 1, the jcp_.oskip gap rows between ow rows get 0; without os-blocking
    // every row is enabled.
    auto maybe_M_mask = [&](int brg_idx, brgemm_attr_t &brgattr, int vM,
                                int vbrgM) {
        if (!jcp_.use_M_mask) return;
        auto sm_size = vbrgM;
        bd_masks[brg_idx] = std::make_shared<std::vector<char>>();
        bd_masks[brg_idx]->resize(sm_size);
        char *bd_mask = bd_masks[brg_idx]->data();
        if (jcp_.is_os_blocking) {
            int ibrgM = 0;
            int iM = 0;
            for (int hh = 0; hh < jcp_.oh_block; hh++) {
                // Once iM reaches vM the remaining rows are masked out.
                auto M_mask = (iM >= vM) ? 0 : 1;
                for (int ww = 0; ww < jcp_.ow_block && ibrgM < sm_size;
                        ww++, ibrgM++, iM += M_mask) {
                    bd_mask[ibrgM] = M_mask;
                }
                for (int kk = 0; kk < jcp_.oskip && ibrgM < sm_size;
                        kk++, ibrgM++) {
                    bd_mask[ibrgM] = 0;
                }
            }
            for (; ibrgM < sm_size; ibrgM++) {
                bd_mask[ibrgM] = 0;
            }
        } else {
            for (int ibrgM = 0; ibrgM < sm_size; ibrgM++) {
                bd_mask[ibrgM] = 1;
            }
        }
        brgattr.bd_mask = bd_mask;
    };

    ic_chunks = div_up(jcp_.nb_ic, jcp_.nb_ic_blocking);
    // Post-work is required whenever anything must run after accumulation:
    // bias, eltwise/binary/sum post-ops, output scales, data-type
    // down-conversion, bd-masking or zero points.
    need_postwork = jcp_.with_bias || jcp_.with_eltwise || jcp_.with_binary
            || (one_of(src_type, u8, s8) && wei_type == s8) // oscales needed
            || (jcp_.dst_dt != jcp_.acc_dt) || jcp_.with_sum || jcp_.use_M_mask
            || jcp_.src_zero_point || jcp_.dst_zero_point;

    const auto M_end = nstl::max(jcp_.M, jcp_.M_tail);

    int K_begin = 0;
    int K_end = (jcp_.K_tail == 0) ? 1 : 2;

    // With no K tail, a single ic chunk and no kd/kh blocking, each kernel
    // invocation starts its own accumulation, so only the initializing
    // (beta == 0, i_init == 1) variants are needed.
    int i_init_begin = (jcp_.K_tail == 0 && jcp_.exec_type == exec_trans
                               && div_up(jcp_.nb_ic, jcp_.nb_ic_blocking) == 1
                               && KD_BLOCK == KD && KH_BLOCK == KH)
            ? 1
            : 0;
    int i_init_end = 2;

    for (int i = 0; i < M_end; i++) {
        auto vM = i + 1;
        // init only needed brgemm descriptors
        if (one_of(jcp_.exec_type, exec_trans, exec_vpad) && vM != jcp_.M
                && vM != jcp_.M_tail)
            continue;
        for (int bs = 0; bs <= jcp_.max_batch; bs++) {
            if (batchsizes[bs] == -1) continue;
            for_(int i_init = i_init_begin; i_init < i_init_end; i_init++)
            for_(int i_N = 0; i_N < 2; i_N++)
            for (int i_K = K_begin; i_K < K_end; i_K++) {
                // i_init selects beta == 0 (initialize C) vs accumulate;
                // i_N / i_K select the tail sizes of N and K.
                auto vbeta = (i_init) ? 0 : beta;
                auto vN = (i_N) ? jcp_.N_tail : jcp_.N;
                auto vK = (i_K) ? jcp_.K_tail : jcp_.K;
                auto vbrgM = jcp_.use_M_mask
                        ? (vM == jcp_.M ? jcp_.brgM : jcp_.brgM_tail)
                        : vM;
                auto brg_idx = get_brg_idx(bs, i, i_init, i_N, i_K);
                // if brgemm_t already created then skip this iteration
                if (brgs_[brg_idx] != nullptr) continue;
                brgs_[brg_idx] = std::make_shared<brgemm_t>();
                brgemm_t *brg = brgs_[brg_idx].get();
                // Keep a default-constructed descriptor for degenerate
                // tails so lookups stay valid, but don't initialize it.
                if (vN == 0 || vK == 0) continue;
                brgemm_strides_t brg_strides;
                brg_strides.stride_a = jcp_.brg_stride_a;
                brg_strides.stride_b = jcp_.brg_stride_b;
                brg->req_cal_comp_pads = jcp_.req_brg_comp_pad
                        && (jcp_.src_zero_point || jcp_.s8s8_avx512);
                const auto strides_ptr = (jcp_.brg_type == brgemm_strd)
                        ? &brg_strides
                        : nullptr;
                CHECK(brgemm_desc_init(brg, isa, jcp_.brg_type, src_type,
                        wei_type, false, false, brgemm_row_major, alpha, vbeta,
                        jcp_.LDA, jcp_.LDB, jcp_.LDC, vbrgM, vN, vK,
                        strides_ptr));

                brgemm_attr_t brgattr;
                brgattr.use_uker = jcp_.use_uker;
                brgattr.use_interleave_stores = jcp_.use_interleave_stores;
                brgattr.hint_prefetching = jcp_.hint_prefetching;
                brgattr.max_bs = bs;
                brgattr.hint_innermost_loop = jcp_.brgemm_bd_loop_innermost
                        ? brgemm_bd_loop_innermost
                        : brgemm_ld_loop_innermost;
                if (jcp_.amx_tile_load_xx) {
                    // assuming 2x2 decomposition in amx brgemm kernel
                    // and overlap of input by kw
                    const auto bd_blocking = 2 * jcp_.amx_h;
                    const auto ld_blocking = 2 * 16;
                    brgattr.hint_expected_A_size = bd_blocking * jcp_.K
                            * jcp_.kd_block * jcp_.kh_block;
                    brgattr.hint_expected_B_size = ld_blocking * jcp_.K
                            * jcp_.kd_block * jcp_.kh_block * jcp_.kw_block;
                    brgattr.hint_expected_C_size = bd_blocking * ld_blocking;
                } else {
                    brgattr.hint_expected_A_size = 0;
                    brgattr.hint_expected_B_size = 0;
                    brgattr.hint_expected_C_size = 0;
                }

                brgattr.wary_tail_read = false;
                maybe_M_mask(brg_idx, brgattr, vM, vbrgM);
                brgattr.bd_mask_level = jcp_.use_M_mask;

                if (is_amx) {
                    // AMX path never uses virtual padding.
                    brgattr.max_top_vpad = 0;
                    brgattr.max_bottom_vpad = 0;
                } else {
                    brgattr.max_top_vpad = jcp_.max_vpad;
                    brgattr.max_bottom_vpad = jcp_.max_vpad;
                }
                brgattr.fpmath_mode = attr()->fpmath_mode_;

                // if need post_ops and there are no intermediate calculations
                // (like ic_chunks > 1 or blocking by kernel) we don't need
                // code without post-ops in brgemm kernel
                if (need_postwork && ic_chunks == 1 && KD_BLOCK == KD
                        && KH_BLOCK == KH && KW_BLOCK == KW)
                    brgattr.postops_only = true;

                CHECK(brgemm_desc_set_attr(brg, brgattr));

                auto LDD = jcp_.oc_without_padding;
                brg->with_sum = with_sum;
                CHECK(brgemm_desc_set_postops(
                        brg, attr(), &dst_md_, LDD, jcp_.bia_dt));
                // Track the largest AMX scratch requirement over all kernels.
                jcp_.amx_buf_size_per_thread
                        = nstl::max(brg->get_wsp_buffer_size(),
                                jcp_.amx_buf_size_per_thread);
            }
        }
    }

    brgemm_convolution_utils::set_amx_wsp_per_thread(jcp_);
    auto scratchpad = scratchpad_registry().registrar();
    brgemm_convolution_utils::init_scratchpad(scratchpad, jcp_);
    if (jcp_.with_scales)
        book_precomputed_scales(scratchpad, attr()->scales_, OC());

    return status::success;
}
311 | |
// Constructor: stores the primitive descriptor and caches a memory
// descriptor wrapper for the bias (weights_md(1)).
template <cpu_isa_t isa, bool use_inversion>
brgemm_convolution_fwd_t<isa, use_inversion>::brgemm_convolution_fwd_t(
        const pd_t *apd)
    : primitive_t(apd), bias_d(pd()->weights_md(1)) {}
316 | |
317 | template <cpu_isa_t isa, bool use_inversion> |
318 | void brgemm_convolution_fwd_t<isa, use_inversion>::get_kw_range( |
319 | int ow, int &kw_s, int &kw_full_s, int &kw_full_f, int &kw_f) const { |
320 | // This function needed for exec_base only |
321 | const auto _pd = pd(); |
322 | const auto &jcp = _pd->jcp_; |
323 | // TODO: calculate these values instead direct loop by kw |
324 | |
325 | const bool is_ow_tail = (jcp.ow - ow < jcp.ow_block); |
326 | const auto M = is_ow_tail ? jcp.ow_tail : jcp.ow_block; |
327 | kw_s = kw_full_s = kw_full_f = kw_f = -1; |
328 | for (int kw = 0; kw < jcp.kw; kw++) { |
329 | int ow_s {0}, ow_f {0}; |
330 | get_ow_range(ow, kw, ow_s, ow_f); |
331 | if (ow_s < ow_f) { |
332 | if (kw_s == -1) kw_s = kw; |
333 | kw_f = kw + 1; |
334 | if (ow_f - ow_s == M) { |
335 | if (kw_full_s == -1) kw_full_s = kw; |
336 | kw_full_f = kw + 1; |
337 | } |
338 | } |
339 | } |
340 | if (kw_f == -1) { |
341 | kw_s = 0; |
342 | kw_f = 0; |
343 | } |
344 | if (kw_full_f == -1) kw_full_s = kw_full_f = kw_f; |
345 | } |
346 | |
347 | template <cpu_isa_t isa, bool use_inversion> |
348 | void brgemm_convolution_fwd_t<isa, use_inversion>::get_ow_range( |
349 | int ow, int kw, int &ow_s, int &ow_f) const { |
350 | // This function needed for exec_base only |
351 | const auto _pd = pd(); |
352 | const auto &jcp = _pd->jcp_; |
353 | |
354 | const bool is_ow_tail = (jcp.ow - ow < jcp.ow_block); |
355 | const auto M = is_ow_tail ? jcp.ow_tail : jcp.ow_block; |
356 | |
357 | const auto IW = jcp.iw; |
358 | const auto SW = jcp.stride_w; |
359 | const auto LP = jcp.l_pad; |
360 | const auto DW = jcp.dilate_w + 1; |
361 | |
362 | const auto iiw = ow * SW - LP; |
363 | auto iw_lp = iiw + kw * DW; |
364 | const auto iw_rp = iw_lp + (M - 1) * SW - IW + 1; |
365 | ow_s = ow; |
366 | |
367 | int ker_idx = 0; |
368 | if (iw_lp < 0) { |
369 | iw_lp = nstl::abs(iw_lp); |
370 | ker_idx += div_up(iw_lp, SW); |
371 | ow_s += ker_idx; |
372 | } |
373 | if (iw_rp > 0) ker_idx += div_up(iw_rp, SW); |
374 | ow_f = ow_s + (M - ker_idx); |
375 | ow_s = nstl::min(ow_s, ow + M); |
376 | ow_f = nstl::min(nstl::max(ow_f, ow_s), ow + M); |
377 | } |
378 | |
379 | template <cpu_isa_t isa, bool use_inversion> |
380 | status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_brg_kernel( |
381 | int bs, int M, int i_N, int i_K, int i_init) { |
382 | if (M <= 0) return status::success; |
383 | const auto _pd = pd(); |
384 | const auto &jcp = _pd->jcp_; |
385 | const auto &brgs = _pd->brgs_; |
386 | |
387 | auto N = (i_N) ? jcp.N_tail : jcp.N; |
388 | auto K = (i_K) ? jcp.K_tail : jcp.K; |
389 | if (N <= 0 || K <= 0) return status::success; |
390 | auto brg_idx = _pd->get_brg_idx(bs, M - 1, i_init, i_N, i_K); |
391 | auto brg = brgs[brg_idx]; |
392 | if (!brg_kernels_[brg_idx] && brg && brg->bcast_dim > 0 && brg->load_dim > 0 |
393 | && brg->reduce_dim > 0) { |
394 | brgemm_kernel_t *brg_kernel = nullptr; |
395 | CHECK(brgemm_kernel_create(&brg_kernel, *brg)); |
396 | CHECK(safe_ptr_assign(brg_kernels_[brg_idx], brg_kernel)); |
397 | if (is_amx) { |
398 | CHECK(brgemm_init_tiles(*brg, &brg_kernel_palettes_[brg_idx].a[0])); |
399 | } |
400 | } |
401 | return status::success; |
402 | } |
403 | |
404 | template <cpu_isa_t isa, bool use_inversion> |
405 | status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernel( |
406 | brgemm_t *bcfg, int ker_idx, bool is_init) { |
407 | if (!bcfg) return status::success; |
408 | const auto _pd = pd(); |
409 | const auto &jcp = _pd->jcp_; |
410 | |
411 | bcfg->LDD = (is_init && jcp.use_buffer) ? jcp.LDC : jcp.LDD; |
412 | bcfg->dt_c = (!is_init && jcp.use_buffer) ? jcp.acc_dt : jcp.dst_dt; // inp |
413 | bcfg->dt_d = (is_init && jcp.use_buffer) ? jcp.acc_dt : jcp.dst_dt; // out |
414 | bcfg->alpha |
415 | = (!is_init && IMPLICATION(jcp.with_sum, jcp.use_buffer)) ? 1 : 0; |
416 | bcfg->beta = is_init ? 0 : 1; |
417 | CHECK(safe_ptr_assign(kernels_po_[ker_idx], |
418 | new jit_brgemm_kernel_post_ops<isa>(jcp, *bcfg, *_pd->attr()))); |
419 | kernels_po_[ker_idx]->create_kernel(); |
420 | return status::success; |
421 | } |
422 | |
423 | template <cpu_isa_t isa, bool use_inversion> |
424 | void brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernels( |
425 | int i_N, int init_bcast_dim, int po_bcast_dim) { |
426 | const auto _pd = pd(); |
427 | const auto &jcp = _pd->jcp_; |
428 | const auto &brgs = _pd->brgs_; |
429 | |
430 | auto N = (i_N) ? jcp.N_tail : jcp.N; |
431 | if (N <= 0) return; |
432 | auto i_K = (jcp.K_tail > 0); |
433 | |
434 | if (init_bcast_dim > 0) { |
435 | auto brg_idx = _pd->get_brg_idx( |
436 | _pd->first_bs, init_bcast_dim - 1, 0, i_N, i_K); |
437 | if (brgs[brg_idx]) { |
438 | auto init_cfg = *(brgs[brg_idx].get()); |
439 | auto ker_init_idx = get_ker_po_idx(init_bcast_dim - 1, false, i_N); |
440 | if (init_cfg.load_dim > 0 && kernels_po_[ker_init_idx] == nullptr) { |
441 | init_cfg.bcast_dim = init_bcast_dim; |
442 | add_po_kernel(&init_cfg, ker_init_idx, true); |
443 | } |
444 | } |
445 | } |
446 | |
447 | if ((_pd->need_postwork || jcp.use_buffer) && po_bcast_dim > 0) { |
448 | auto brg_idx = _pd->get_brg_idx( |
449 | _pd->first_bs, po_bcast_dim - 1, 0, i_N, i_K); |
450 | if (brgs[brg_idx]) { |
451 | auto po_cfg = *(brgs[brg_idx].get()); |
452 | auto ker_po_idx = get_ker_po_idx(po_bcast_dim - 1, true, i_N); |
453 | if (po_cfg.load_dim > 0 && kernels_po_[ker_po_idx] == nullptr) { |
454 | po_cfg.bcast_dim = po_bcast_dim; |
455 | add_po_kernel(&po_cfg, ker_po_idx, false); |
456 | } |
457 | } |
458 | } |
459 | } |
460 | template <cpu_isa_t isa, bool use_inversion> |
461 | int brgemm_convolution_fwd_t<isa, use_inversion>::get_comp_ker_idx( |
462 | const int kd_b, const int kd_e, const int kh_b, const int kh_e, |
463 | const int kw_b, const int kw_e) const { |
464 | const auto _pd = pd(); |
465 | const auto &jcp = _pd->jcp_; |
466 | |
467 | if (!jcp.req_cal_comp_pad) return 0; |
468 | |
469 | assert(kd_e > kd_b && kh_e > kh_b); |
470 | for (int k = 0; k < jcp.ker_ranges_size; k++) { |
471 | if (kd_b == kd_bs[k] && kd_e == kd_es[k] && kh_b == kh_bs[k] |
472 | && kh_e == kh_es[k] && kw_b == kw_bs[k] && kw_e == kw_es[k]) { |
473 | return k; |
474 | } |
475 | } |
476 | |
477 | return -1; |
478 | } |
479 | |
480 | template <cpu_isa_t isa, bool use_inversion> |
481 | int brgemm_convolution_fwd_t<isa, use_inversion>::get_comp_offset(const int g, |
482 | const int ocb, const int ow, const int kd_b, const int kd_e, |
483 | const int kh_b, const int kh_e, const int kw_b, const int kw_e) const { |
484 | const auto _pd = pd(); |
485 | const auto &jcp = _pd->jcp_; |
486 | |
487 | if (!jcp.src_zero_point && !jcp.s8s8_avx512) return 0; |
488 | |
489 | const auto comp_idx = get_comp_ker_idx(kd_b, kd_e, kh_b, kh_e, kw_b, kw_e); |
490 | assert(IMPLICATION(jcp.req_cal_comp_pad, comp_idx >= 0)); |
491 | |
492 | return jcp.req_cal_comp_pad |
493 | ? g * comp_ocb_sz + ocb * comp_ker_sz + comp_idx * comp_kw_sz |
494 | : (g * jcp.nb_oc + ocb) * jcp.oc_block; |
495 | } |
496 | |
// Initializes the primitive: caches frequently used jcp values, pre-computes
// the address-arithmetic strides, creates the copy / compensation helper
// kernels, and instantiates every brgemm and post-ops kernel combination
// that execute() can dispatch to.
template <cpu_isa_t isa, bool use_inversion>
status_t brgemm_convolution_fwd_t<isa, use_inversion>::init(engine_t *engine) {

    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;

    bia_dsz = jcp.bia_dsz;
    acc_dsz = jcp.acc_dsz;
    src_dsz = jcp.src_dsz;
    wei_dsz = jcp.wei_dsz;
    dst_dsz = jcp.dst_dsz;

    auto ndims = _pd->ndims();
    // Debug-only sanity check; init_conf has already validated ndims.
    if (ndims < 3 || ndims > 5) assert(!"Invalid ndims!" );

    // ndims_pick selects the 5d/4d/3d value of each spatial parameter;
    // absent dims collapse to size 1, padding 0 and dilation 0.
    KD = ndims_pick(jcp.kd, 1, 1);
    KH = ndims_pick(jcp.kh, jcp.kh, 1);
    KW = jcp.kw;

    EXT_KD = ndims_pick(jcp.ext_kd, 1, 1);
    EXT_KH = ndims_pick(jcp.ext_kh, jcp.ext_kh, 1);
    EXT_KW = jcp.ext_kw;

    IDP = ndims_pick(jcp.idp, 1, 1);
    IHP = ndims_pick(jcp.ihp, jcp.ihp, 1);
    IWP = jcp.iwp;

    KS = KD * KH * KW;
    KD_BLOCK = ndims_pick(jcp.kd_block, 1, 1);
    KH_BLOCK = ndims_pick(jcp.kh_block, jcp.kh_block, 1);
    KW_BLOCK = jcp.kw_block;
    KD_BLOCK_PAD = ndims_pick(jcp.kd_block_pad, 1, 1);
    KH_BLOCK_PAD = ndims_pick(jcp.kh_block_pad, jcp.kh_block_pad, 1);
    ID = ndims_pick(jcp.id, 1, 1);
    IH = ndims_pick(jcp.ih, jcp.ih, 1);
    IW = jcp.iw;
    OD = ndims_pick(jcp.od, 1, 1);
    OH = ndims_pick(jcp.oh, jcp.oh, 1);
    OW = jcp.ow;
    SD = ndims_pick(jcp.stride_d, 1, 1);
    SH = ndims_pick(jcp.stride_h, jcp.stride_h, 1);
    SW = jcp.stride_w;
    FP = ndims_pick(jcp.f_pad, 0, 0);
    TP = ndims_pick(jcp.t_pad, jcp.t_pad, 0);
    LP = jcp.l_pad;
    DD = ndims_pick(jcp.dilate_d, 0, 0) + 1;
    DH = ndims_pick(jcp.dilate_h, jcp.dilate_h, 0) + 1;
    DW = jcp.dilate_w + 1;

    // const variables used for address calculations
    src_w_sz = static_cast<dim_t>(IW) * jcp.ngroups * jcp.ic_without_padding;
    src_h_sz = IH * src_w_sz;
    src_d_sz = ID * src_h_sz;
    dst_w_sz = static_cast<dim_t>(OW) * jcp.oc_without_padding;
    dst_h_sz = OH * dst_w_sz;
    dst_d_sz = OD * dst_h_sz;

    wei_ic_sz = static_cast<dim_t>(jcp.icp) * jcp.oc_block;
    wei_kw_sz = KW * wei_ic_sz;
    wei_kh_sz = KH * wei_kw_sz;
    wei_kd_sz = KD * wei_kh_sz;
    wei_ocb_sz = jcp.nb_oc * wei_kd_sz;

    // Strides within the compensation buffer (zero-point / s8s8).
    comp_kw_sz = static_cast<dim_t>(jcp.oc_block);
    comp_ker_sz = jcp.ker_ranges_size * comp_kw_sz;
    comp_ocb_sz = jcp.nb_oc * comp_ker_sz;

    // Compensation is computed by a separate kernel only when the brgemm
    // kernels themselves do not compute it (req_brg_comp_pad).
    need_compensation
            = (jcp.src_zero_point || jcp.s8s8_avx512) && !jcp.req_brg_comp_pad;

    // ---- Initialize arrays ---------------------
    brg_kernels_.resize(_pd->brgs_sz_);
    brg_kernel_palettes_.resize(_pd->brgs_sz_);

    for (int i = 0; i < _pd->brgs_sz_; i++)
        brg_kernels_[i] = nullptr;

    int num_po_kernels = nstl::max(jcp.M, jcp.M_tail);
    kernels_po_.resize(num_po_kernels * 2 * 2);
    for (int i = 0; i < num_po_kernels; i++) {
        for_(int i_init = 0; i_init < 2; i_init++)
        for (int i_N = 0; i_N < 2; i_N++)
            kernels_po_[get_ker_po_idx(i, i_init, i_N)] = nullptr;
    }

    // exec_trans copies the input into a blocked pbuffer first.
    if (jcp.exec_type == exec_trans) {
        CHECK(safe_ptr_assign(copy_to_pbuffer_,
                new jit_avx512_core_brgemm_conv_trans_kernel_t(jcp)));
        CHECK(copy_to_pbuffer_->create_kernel());
    }
    if (jcp.copy_block_only) {
        assert(jcp.exec_type == exec_trans && "Missing copy kernel" );
        // Per-block pbuffer: sized by one od/oh/ow block of input.
        const auto iw_block = copy_to_pbuffer_->dst_w(jcp.ow_block);
        const auto ih_block = get_inp_size(IHP, jcp.oh_block, KH, SH, DH - 1);
        const auto id_block = get_inp_size(IDP, jcp.od_block, KD, SD, DD - 1);

        pbuf_w_sz = (dim_t)jcp.ic_block * jcp.kh_sets * jcp.kw_sets * iw_block;
        pbuf_h_sz = pbuf_w_sz * ih_block;
        pbuf_d_sz = pbuf_h_sz * id_block;

    } else {
        // Whole-image pbuffer: sized by the padded input extents.
        pbuf_w_sz = (dim_t)jcp.ic_block * jcp.kh_sets * jcp.kw_sets * jcp.iwp;
        pbuf_h_sz = pbuf_w_sz * jcp.ihp;
        pbuf_d_sz = pbuf_h_sz * jcp.idp;
    }

    if (jcp.req_cal_comp_pad) {
        CHECK(safe_ptr_assign(comp_vpad_pbuffer_,
                new jit_avx512_core_brgemm_conv_comp_pad_kernel_t(jcp)));
        CHECK(comp_vpad_pbuffer_->create_kernel());
    }

    is_amx = brgemm_convolution_utils::is_amx(isa);

    // #TODO: this needed only if we have d/h padding more then kd/kh
    int M_begin = 0;
    int M_end = (jcp.M_tail == jcp.M) ? 1 : 2;
    int N_begin = 0;
    int N_end = (jcp.N_tail == jcp.N) ? 1 : 2;
    int K_begin = 0;
    int K_end = (jcp.K_tail == 0) ? 1 : 2;
    // Mirrors the condition in pd_t::init: when every invocation starts its
    // own accumulation, only the beta == 0 (i_init == 1) kernels exist.
    int i_init_begin = (jcp.K_tail == 0 && jcp.exec_type == exec_trans
                               && div_up(jcp.nb_ic, jcp.nb_ic_blocking) == 1
                               && KD_BLOCK == KD && KH_BLOCK == KH)
            ? 1
            : 0;
    int i_init_end = 2;

    // Generate brgemm kernels for every reachable configuration.
    for (int bs = 0; bs <= jcp.max_batch; bs++) {
        if (_pd->batchsizes[bs] == -1) continue;

        for_(int i_N = N_begin; i_N < N_end; i_N++)
        for_(int i_M = M_begin; i_M < M_end; i_M++)
        for_(int i_init = i_init_begin; i_init < i_init_end; i_init++)
        for (int i_K = K_begin; i_K < K_end; i_K++) {
            auto M = (i_M) ? jcp.M_tail : jcp.M;
            if (M <= 0) continue;
            add_brg_kernel(bs, M, i_N, i_K, i_init);
        }
    }

    for_(int i_N = N_begin; i_N < N_end; i_N++)
    for (int i_M = M_begin; i_M < M_end; i_M++) {
        // init "init" and "po" kernels for cases then we never call brgemm kernels
        // e.g. for d/h padded areas
        // TODO: do this only if d/h padding > kd/kh
        if (IMPLICATION(jcp.exec_type == exec_trans,
                    jcp.od > jcp.id || jcp.oh > jcp.ih)) {
            auto M = (i_M) ? jcp.M_tail : jcp.M;
            add_po_kernels(i_N, M, M);
        }
    }

    if (jcp.exec_type == exec_base) {
        // create brgemm kernels for ow_blocks with padded areas and
        // apply post-ops on final iteration by kw to padded areas in ow_block
        int kw_s {0}, kw_full_s {0}, kw_full_f {0}, kw_f {0}, ow_s {0},
                ow_f {0};
        // Forward sweep over ow blocks from the left edge: stops once a
        // block with no left/right padding is reached (kw range is full).
        for (int ow = 0; ow < OW; ow += jcp.ow_block) {
            get_kw_range(ow, kw_s, kw_full_s, kw_full_f, kw_f);
            for (int kw = kw_s; kw < kw_f; kw++) {
                get_ow_range(ow, kw, ow_s, ow_f);
                if (ow_f - ow_s <= 0) continue;

                auto M = ow_f - ow_s;
                if (M <= 0) continue;
                for (int bs = 0; bs <= jcp.max_batch; bs++) {
                    if (_pd->batchsizes[bs] == -1) continue;
                    for_(int i_init = 0; i_init < 2; i_init++)
                    for_(int i_N = 0; i_N < 2; i_N++)
                    for (int i_K = 0; i_K < 2; i_K++) {
                        add_brg_kernel(bs, M, i_N, i_K, i_init);
                    }
                }
            }

            bool is_ow_tail = (jcp.ow - ow < jcp.ow_block);
            // Kernels handling the left (i_side == 0) and right padded
            // fringes of the block.
            for_(int i_N = 0; i_N < 2; i_N++)
            for (int i_side = 0; i_side < 2; i_side++) {
                auto M = is_ow_tail ? jcp.M_tail : jcp.M;
                if (M <= 0) continue;
                get_ow_range(ow, kw_s, ow_s, ow_f);
                const auto init_bcast_dim
                        = (i_side == 0) ? (ow_s - ow) : (ow + M - ow_f);
                get_ow_range(ow, kw_f - 1, ow_s, ow_f);
                const auto po_bcast_dim
                        = (i_side == 0) ? (ow_s - ow) : (ow + M - ow_f);
                add_po_kernels(i_N, init_bcast_dim, po_bcast_dim);
            }

            if (kw_f == jcp.kw && kw_s == 0) break;
        }

        // Same sweep from the right edge backwards, to cover blocks with
        // right-side padding only.
        for (int ow = (jcp.nb_ow - 1) * jcp.ow_block; ow >= 0;
                ow -= jcp.ow_block) {
            get_kw_range(ow, kw_s, kw_full_s, kw_full_f, kw_f);
            for (int kw = kw_s; kw < kw_f; kw++) {
                get_ow_range(ow, kw, ow_s, ow_f);
                if (ow_f - ow_s <= 0) continue;

                auto M = ow_f - ow_s;
                if (M <= 0) continue;
                for (int bs = 0; bs <= jcp.max_batch; bs++) {
                    if (_pd->batchsizes[bs] == -1) continue;
                    for_(int i_init = 0; i_init < 2; i_init++)
                    for_(int i_N = 0; i_N < 2; i_N++)
                    for (int i_K = 0; i_K < 2; i_K++) {
                        add_brg_kernel(bs, M, i_N, i_K, i_init);
                    }
                }
            }

            bool is_ow_tail = (jcp.ow - ow < jcp.ow_block);

            for_(int i_N = 0; i_N < 2; i_N++)
            for (int i_side = 0; i_side < 2; i_side++) {
                auto M = is_ow_tail ? jcp.M_tail : jcp.M;
                if (M <= 0) continue;
                get_ow_range(ow, kw_s, ow_s, ow_f);
                const auto init_bcast_dim
                        = (i_side == 0) ? (ow_s - ow) : (ow + M - ow_f);
                get_ow_range(ow, kw_f - 1, ow_s, ow_f);
                const auto po_bcast_dim
                        = (i_side == 0) ? (ow_s - ow) : (ow + M - ow_f);
                add_po_kernels(i_N, init_bcast_dim, po_bcast_dim);
            }

            if (kw_f == jcp.kw && kw_s == 0) break;
        }
    }

    // pre-calculated values
    if (jcp.exec_type == exec_vpad) {
        // Per-(ow block, kw) virtual padding amounts used at execution time.
        owb_kw_top_vpads.resize(jcp.nb_ow * jcp.kw);
        owb_kw_bottom_vpads.resize(jcp.nb_ow * jcp.kw);

        for (int owb = 0; owb < jcp.nb_ow; owb++) {
            const int ow = owb * jcp.ow_block;
            const bool is_ow_tail = (jcp.ow - ow < jcp.ow_block);
            const int ow_b {ow}, ow_e {ow + (is_ow_tail ? jcp.M_tail : jcp.M)};
            const auto ow_l = ow_e - ow_b;
            MAYBE_UNUSED(ow_l);
            assert(0 <= ow_l && ow_l <= jcp.ow_block);
            const auto iiw_b = ow_b * SW - LP;
            const auto iiw_e = (ow_e - 1) * SW - LP + 1;
            const auto iiw_l = iiw_e - iiw_b;
            for (int kw = 0; kw < KW; kw++) {
                const auto iw = iiw_b + kw * DW;
                // Rows of the block reading before/past the input width.
                const auto top_vpad = iw >= 0 ? 0 : div_up(abs(iw), SW);
                const auto bottom_vpad
                        = iw + iiw_l <= IW ? 0 : div_up(iw + iiw_l - IW, SW);
                assert(top_vpad == 0 || bottom_vpad == 0);
                owb_kw_top_vpads[owb * KW + kw] = top_vpad;
                owb_kw_bottom_vpads[owb * KW + kw] = bottom_vpad;
            }
        }
    }

    // pre-calculate unique kernel combination
    if (jcp.req_cal_comp_pad) {
        std::set<std::vector<int>> unique_kernels;
        size_t k = 0;
        kd_bs.resize(jcp.ker_ranges_size);
        kd_es.resize(jcp.ker_ranges_size);
        kh_bs.resize(jcp.ker_ranges_size);
        kh_es.resize(jcp.ker_ranges_size);
        kw_bs.resize(jcp.ker_ranges_size);
        kw_es.resize(jcp.ker_ranges_size);

        // Records the range tuple at index k unless it was seen before:
        // before the insert the set holds exactly k entries, so an
        // unchanged size means the tuple is a duplicate.
        const auto update_kernels = [&](int kd_b, int kd_e, int kh_b, int kh_e,
                                            int kw_b, int kw_e) {
            unique_kernels.insert({kd_b, kd_e, kh_b, kh_e, kw_b, kw_e});
            if (k == unique_kernels.size()) return;
            kd_bs[k] = kd_b;
            kd_es[k] = kd_e;
            kh_bs[k] = kh_b;
            kh_es[k] = kh_e;
            kw_bs[k] = kw_b;
            kw_es[k] = kw_e;
            k++;
            assert(k <= static_cast<size_t>(jcp.ker_ranges_size));
        };

        // Walk every output block and record the kd/kh/kw ranges its
        // kernels will use, mirroring the dispatch logic in execute().
        for_(int odb = 0; odb < jcp.nb_od; odb++)
        for_(int ohb = 0; ohb < jcp.nb_oh; ohb++)
        for (int owb = 0; owb < jcp.nb_ow; owb++) {
            auto od_begin = odb * jcp.od_block;
            auto od_end = nstl::min(OD, od_begin + jcp.od_block);
            auto oh_begin = ohb * jcp.oh_block;
            // for is_os_blocking array will change to one line
            auto oh_end = jcp.is_os_blocking
                    ? oh_begin + 1
                    : nstl::min(OH, oh_begin + jcp.oh_block);
            for_(int od = od_begin; od < od_end; od++)
            for (int oh = oh_begin; oh < oh_end; oh++) {
                int kw_s {0}, kw_full_s {0}, kw_f {0}, kw_full_f {0};
                const int ow = owb * jcp.ow_block;
                const int iid = ndims_pick(od * SD - FP, 0, 0);
                const int kd_s
                        = ndims_pick(div_up(nstl::max(0, -iid), DD), 0, 0);
                const int kd_f = ndims_pick(
                        KD - div_up(max(0, iid - ID + (KD - 1) * DD + 1), DD),
                        1, 1);
                const int iih = ndims_pick(oh * SH - TP, oh * SH - TP, 0);
                const auto kh_s_ = div_up(max(0, -iih), DH);
                const auto kh_s = ndims_pick(kh_s_, kh_s_, 0);
                const auto kh_f_
                        = KH - div_up(max(0, iih - IH + (KH - 1) * DH + 1), DH);
                const auto kh_f = ndims_pick(kh_f_, kh_f_, 1);
                get_kw_range(ow, kw_s, kw_full_s, kw_full_f, kw_f);
                if (kd_f > kd_s && kh_f > kh_s && kw_f > kw_s) {
                    if (jcp.exec_type == exec_vpad) {
                        update_kernels(kd_s, kd_f, kh_s, kh_f, 0, KW);
                    } else if (jcp.exec_type == exec_base) {
                        // exec_base splits the kw range into the partially
                        // covered left edge, fully covered middle (blocked
                        // by KW_BLOCK), and partially covered right edge.
                        if (kw_s < kw_full_s) {
                            for (auto kw = kw_s; kw < kw_full_s; kw++) {
                                update_kernels(
                                        kd_s, kd_f, kh_s, kh_f, kw, kw + 1);
                            }
                        }
                        if (kw_full_s < kw_full_f) {
                            for (auto kw = kw_full_s; kw < kw_full_f;
                                    kw += KW_BLOCK) {
                                const auto kw_e
                                        = nstl::min(kw_full_f, kw + KW_BLOCK);
                                update_kernels(
                                        kd_s, kd_f, kh_s, kh_f, kw, kw_e);
                            }
                        }
                        if (kw_full_f < kw_f) {
                            for (auto kw = kw_full_f; kw < kw_f; kw++) {
                                update_kernels(
                                        kd_s, kd_f, kh_s, kh_f, kw, kw + 1);
                            }
                        }
                    }
                }
            }
        }
        // Number of unique kernel ranges actually recorded.
        ker_vpad_sz = k;
    }

    return status::success;
}
template <cpu_isa_t isa, bool use_inversion>
struct brgemm_convolution_fwd_t<isa, use_inversion>::brgemm_thread_ctx_t {
    // Per-thread execution context: bundles thread-local buffers and the
    // current position in the (mb, group, ocb, od/oh/ow block) loop nest so
    // the kernel drivers (ker_base/ker_trans/ker_vpad) take one argument.
    brgemm_thread_ctx_t(brgemm_exec_ctx_t &brgemm_ctx_, int ithr_,
            brgemm_batch_element_t *__restrict brg_batch_, char *c_buffer_,
            char *wsp_tile_)
        : brgemm_ctx(brgemm_ctx_)
        , ithr(ithr_)
        , brg_batch(brg_batch_)
        , c_buffer(c_buffer_)
        , wsp_tile(wsp_tile_) {}

    brgemm_exec_ctx_t &brgemm_ctx; // shared execution context (src/wei/dst)
    int ithr; // thread index
    brgemm_batch_element_t *__restrict brg_batch; // thread-local batch (A/B ptrs)
    char *c_buffer; // thread-local accumulation buffer (used if jcp.use_buffer)
    char *wsp_tile; // AMX tile scratch workspace
    S_t cur_palette; // AMX palette currently configured for this thread
    int g, n, ocb; // group, minibatch, oc-block indices of current work item
    int od, odb, oh, ohb, owb; // spatial indices and spatial block indices
    int icc; // current ic chunk index
    const float *oscales {nullptr}; // precomputed output scales
    int32_t src_zp_vals; // src zero-point value
    int32_t *src_zp_comp_ptr; // src zero-point compensation (nullptr if unused)
    int32_t *dst_zp_vals; // dst zero-point (nullptr if unused)
    int32_t *s8s8_comp_ptr; // s8s8 compensation (nullptr if unused)
};
866 | |
867 | template <cpu_isa_t isa, bool use_inversion> |
868 | status_t brgemm_convolution_fwd_t<isa, use_inversion>::execute( |
869 | const exec_ctx_t &ctx) const { |
870 | const auto _pd = pd(); |
871 | const auto &jcp = _pd->jcp_; |
872 | |
873 | DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC); |
874 | DEFINE_ZERO_POINT_VALUE(dst_zero_point, DNNL_ARG_DST); |
875 | |
876 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
877 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
878 | |
879 | const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(), |
880 | src_scales, wei_scales, _pd->OC(), _pd->attr()); |
881 | |
882 | brgemm_exec_ctx_t brgemm_ctx(ctx, _pd); |
883 | |
884 | const char *const __restrict src = brgemm_ctx.src; |
885 | const char *const __restrict wei = brgemm_ctx.weights; |
886 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
887 | |
888 | const auto |
889 | = weights_d.size() - weights_d.additional_buffer_size(); |
890 | auto w = const_cast<char *>(brgemm_ctx.weights); |
891 | const auto s8s8_comp_offset = jcp.req_cal_comp_pad |
892 | ? jcp.ngroups * jcp.nb_oc * jcp.kd * jcp.kh * jcp.kw * jcp.oc_block |
893 | : jcp.ngroups * jcp.nb_oc * jcp.oc_block; |
894 | int32_t *s8s8_compensation = jcp.s8s8_avx512 |
895 | ? reinterpret_cast<int32_t *>(w + extra_data_offset) |
896 | : nullptr; |
897 | int32_t *zp_compensation = jcp.src_zero_point |
898 | ? reinterpret_cast<int32_t *>(&w[extra_data_offset]) |
899 | + (jcp.s8s8_avx512 ? s8s8_comp_offset : 0) |
900 | : nullptr; |
901 | |
902 | const memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor(); |
903 | brgemm_batch_element_t *const __restrict brg_batch_global |
904 | = (jcp.brg_type == brgemm_strd && jcp.exec_type != exec_vpad) |
905 | ? nullptr |
906 | : scratchpad.template get<brgemm_batch_element_t>( |
907 | key_brgemm_primitive_batch); |
908 | char *const __restrict c_buffer_global = (jcp.use_buffer) |
909 | ? scratchpad.template get<char>(key_brgemm_primitive_buffer) |
910 | : nullptr; |
911 | |
912 | auto inp_p_buffer = (jcp.exec_type == exec_trans) |
913 | ? scratchpad.template get<char>(key_conv_brgemm_inp_buffer) |
914 | : nullptr; |
915 | auto inp_p_buffer_mask = (jcp.exec_type == exec_trans) |
916 | ? scratchpad.template get<uint8_t>(key_conv_brgemm_inp_buffer_mask) |
917 | : nullptr; |
918 | int32_t *src_zp_comp_base = jcp.src_zero_point |
919 | ? (jcp.req_cal_comp_pad ? scratchpad.template get<int32_t>( |
920 | key_brgemm_primitive_zp_comp_a) |
921 | : zp_compensation) |
922 | : nullptr; |
923 | int32_t *s8s8_comp_base = jcp.s8s8_avx512 |
924 | ? (jcp.req_cal_comp_pad ? scratchpad.template get<int32_t>( |
925 | key_brgemm_primitive_buffer_comp) |
926 | : s8s8_compensation) |
927 | : nullptr; |
928 | const auto dst_zp_vals = jcp.dst_zero_point ? &dst_zero_point : nullptr; |
929 | const auto src_zp_vals = src_zero_point; |
930 | |
931 | cal_compensation(wei, src_zp_comp_base, s8s8_comp_base); |
932 | |
933 | char *const wsp_tile_global = is_amx |
934 | ? scratchpad.template get<char>(key_conv_amx_tile_buffer) |
935 | : nullptr; |
936 | |
937 | // --------------- Parallel section ------------------------------ |
938 | const dim_t work_amount = static_cast<dim_t>(jcp.mb) * jcp.ngroups |
939 | * jcp.nb_oc * jcp.nb_od * jcp.nb_oh * jcp.nb_ow; |
940 | // TODO: consider loop by icc be innermost because for current |
941 | // implementation if we use buffer then we accumulate in it only on row |
942 | // or made ic_chunks = 1 if use_buffer |
943 | // or (looks more general) increase buffer size to store several rows |
944 | |
945 | parallel(jcp.nthr, [&](const int ithr, const int nthr) { |
946 | if (ithr >= work_amount) return; |
947 | |
948 | brgemm_batch_element_t *const __restrict brg_batch = brg_batch_global |
949 | + static_cast<size_t>(ithr) * jcp.adjusted_batch_size; |
950 | char *const __restrict c_buffer = (jcp.use_buffer) |
951 | ? c_buffer_global + ithr * acc_dsz * jcp.buffer_size |
952 | : nullptr; |
953 | char *inp_buffer = (jcp.exec_type == exec_trans) |
954 | ? inp_p_buffer + src_dsz * ithr * jcp.inp_buffer_size |
955 | : nullptr; |
956 | if (is_amx) { |
957 | // Workaround: for some machines SEGFAULT possible on tile load |
958 | // if the page was not touched before it |
959 | for (dim_t i = 0; i < jcp.inp_buffer_size; |
960 | i += brgemm_convolution_utils::P4K) |
961 | inp_buffer[i] = 0; |
962 | } |
963 | |
964 | uint8_t *__restrict inp_buffer_mask = (jcp.exec_type == exec_trans) |
965 | ? inp_p_buffer_mask + ithr * jcp.inp_buffer_mask_size |
966 | : nullptr; |
967 | |
968 | char *const wsp_tile = is_amx |
969 | ? wsp_tile_global + ithr * jcp.amx_buf_size_per_thread |
970 | : nullptr; |
971 | dim_t start {0}, end {0}; |
972 | balance211(work_amount, nthr, ithr, start, end); |
973 | int n {0}, g {0}, ocb {0}, odb {0}, ohb {0}, owb {0}; |
974 | if (jcp.loop_order == loop_ndhwgc) |
975 | nd_iterator_init(start, n, jcp.mb, odb, jcp.nb_od, ohb, jcp.nb_oh, |
976 | owb, jcp.nb_ow, g, jcp.ngroups, ocb, jcp.nb_oc); |
977 | else if (jcp.loop_order == loop_ngcdhw) |
978 | nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocb, jcp.nb_oc, |
979 | odb, jcp.nb_od, ohb, jcp.nb_oh, owb, jcp.nb_ow); |
980 | else |
981 | assert(!"Unknown loop order" ); |
982 | |
983 | brgemm_thread_ctx_t btc( |
984 | brgemm_ctx, ithr, brg_batch, c_buffer, wsp_tile); |
985 | std::memset(btc.cur_palette.a, 0, AMX_PALETTE_SIZE); |
986 | |
987 | int last_n = -1; |
988 | int last_g = -1; |
989 | int last_icc = -1; |
990 | int last_odb = -1; |
991 | int last_ohb = -1; |
992 | int last_owb = -1; |
993 | for (auto work = start; work < end; work++) { |
994 | btc.g = g; |
995 | btc.n = n; |
996 | btc.ocb = ocb; |
997 | btc.odb = odb; |
998 | btc.ohb = ohb; |
999 | btc.owb = owb; |
1000 | btc.oscales = oscales; |
1001 | btc.src_zp_vals = src_zp_vals; |
1002 | btc.dst_zp_vals = jcp.dst_zero_point ? dst_zp_vals : nullptr; |
1003 | btc.src_zp_comp_ptr |
1004 | = jcp.src_zero_point ? src_zp_comp_base : nullptr; |
1005 | btc.s8s8_comp_ptr = jcp.s8s8_avx512 ? s8s8_comp_base : nullptr; |
1006 | |
1007 | if (jcp.exec_type == exec_trans && (last_n != n || last_g != g)) { |
1008 | if (!jcp.copy_block_only) |
1009 | std::memset( |
1010 | inp_buffer_mask, false, jcp.inp_buffer_mask_size); |
1011 | } |
1012 | auto od_begin = odb * jcp.od_block; |
1013 | auto od_end = nstl::min(OD, od_begin + jcp.od_block); |
1014 | auto oh_begin = ohb * jcp.oh_block; |
1015 | // if is_os_blocking is true then we do only one iteration of loop |
1016 | // by oh and process entire oh block in kernel call |
1017 | auto oh_end = jcp.is_os_blocking |
1018 | ? oh_begin + 1 |
1019 | : nstl::min(OH, oh_begin + jcp.oh_block); |
1020 | for_(int od = od_begin; od < od_end; od++) |
1021 | for (int oh = oh_begin; oh < oh_end; oh++) { |
1022 | for (int icc = 0; icc < _pd->ic_chunks; icc++) { |
1023 | btc.od = od; |
1024 | btc.oh = oh; |
1025 | btc.icc = icc; |
1026 | |
1027 | if (jcp.exec_type == exec_base) { |
1028 | ker_base(btc); |
1029 | } else if (jcp.exec_type == exec_trans) { |
1030 | maybe_conv_inp(ithr, src, inp_buffer, inp_buffer_mask, |
1031 | g, n, icc, odb, ohb, owb, last_g, last_n, |
1032 | last_icc, last_odb, last_ohb, last_owb); |
1033 | ker_trans(btc, inp_buffer); |
1034 | } else if (jcp.exec_type == exec_vpad) { |
1035 | ker_vpad(btc); |
1036 | } else |
1037 | assert(!"Unknown exec type" ); |
1038 | last_n = n; |
1039 | last_g = g; |
1040 | last_icc = icc; |
1041 | last_odb = odb; |
1042 | last_ohb = ohb; |
1043 | last_owb = owb; |
1044 | } |
1045 | } |
1046 | if (jcp.loop_order == loop_ndhwgc) |
1047 | nd_iterator_step(n, jcp.mb, odb, jcp.nb_od, ohb, jcp.nb_oh, owb, |
1048 | jcp.nb_ow, g, jcp.ngroups, ocb, jcp.nb_oc); |
1049 | else if (jcp.loop_order == loop_ngcdhw) |
1050 | nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocb, jcp.nb_oc, odb, |
1051 | jcp.nb_od, ohb, jcp.nb_oh, owb, jcp.nb_ow); |
1052 | else |
1053 | assert(!"Unknown loop order" ); |
1054 | } |
1055 | if (is_amx) { amx_tile_release(); } |
1056 | }); |
1057 | |
1058 | if (_pd->wants_zero_pad_dst()) ctx.memory(DNNL_ARG_DST)->zero_pad(ctx); |
1059 | |
1060 | return status::success; |
1061 | } |
1062 | |
1063 | template <cpu_isa_t isa, bool use_inversion> |
1064 | status_t brgemm_convolution_fwd_t<isa, use_inversion>::cal_compensation( |
1065 | const char *__restrict weights, int32_t *src_zp_buffer, |
1066 | int32_t *s8s8_comp_buffer) const { |
1067 | const auto _pd = pd(); |
1068 | const auto &jcp = _pd->jcp_; |
1069 | |
1070 | if (!jcp.req_cal_comp_pad) return status::success; |
1071 | |
1072 | if (jcp.src_zero_point) |
1073 | std::memset(src_zp_buffer, 0, sizeof(int32_t) * jcp.comp_a_buffer_size); |
1074 | if (jcp.s8s8_avx512) |
1075 | std::memset(s8s8_comp_buffer, 0, |
1076 | sizeof(int32_t) * jcp.s8s8_comp_buffer_size); |
1077 | |
1078 | const auto work_amount |
1079 | = static_cast<dim_t>(jcp.ngroups) * jcp.nb_oc * ker_vpad_sz; |
1080 | const auto is_small_shape = work_amount <= jcp.nthr |
1081 | && (work_amount * jcp.oc_block * jcp.icp |
1082 | <= platform::get_per_core_cache_size(1)); |
1083 | const int nthr = is_small_shape ? 1 : jcp.nthr; |
1084 | |
1085 | parallel(nthr, [&](const int ithr, const int nthr) { |
1086 | if (ithr >= work_amount) return; |
1087 | |
1088 | dim_t start {0}, end {0}; |
1089 | int g {0}, ocb {0}, k {0}; |
1090 | balance211(work_amount, nthr, ithr, start, end); |
1091 | nd_iterator_init(start, g, jcp.ngroups, ocb, jcp.nb_oc, k, ker_vpad_sz); |
1092 | for (auto work = start; work < end; work++) { |
1093 | const dim_t kd_b {kd_bs[k]}, kd_e {kd_es[k]}, kh_b {kh_bs[k]}, |
1094 | kh_e {kh_es[k]}, kw_b {kw_bs[k]}, kw_e {kw_es[k]}; |
1095 | assert(kd_e > kd_b && kh_e > kh_b && kw_e > kw_b); |
1096 | |
1097 | const auto buffer_offs |
1098 | = g * comp_ocb_sz + ocb * comp_ker_sz + k * comp_kw_sz; |
1099 | const auto wei_offs = (g * jcp.nb_oc + ocb) * wei_kd_sz |
1100 | + kd_b * wei_kh_sz + kh_b * wei_kw_sz + kw_b * wei_ic_sz; |
1101 | |
1102 | jit_brgemm_conv_comp_pad_call_s p; |
1103 | |
1104 | p.kd_l = kd_e - kd_b; |
1105 | p.kh_l = kh_e - kh_b; |
1106 | p.kw_l = kw_e - kw_b; |
1107 | p.ptr_in = &weights[wei_offs]; |
1108 | p.ptr_zp_out = jcp.src_zero_point ? &src_zp_buffer[buffer_offs] |
1109 | : nullptr; |
1110 | p.ptr_cp_out = jcp.s8s8_avx512 ? &s8s8_comp_buffer[buffer_offs] |
1111 | : nullptr; |
1112 | (*comp_vpad_pbuffer_)(&p); |
1113 | |
1114 | nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc, k, ker_vpad_sz); |
1115 | } |
1116 | }); |
1117 | return status::success; |
1118 | } |
1119 | |
// Initializes and/or post-processes the output rows that are NOT covered by
// the brgemm kernel call itself (the left/right parts of the ow block that
// fall entirely into padding), using the dedicated post-ops kernels.
template <cpu_isa_t isa, bool use_inversion>
void brgemm_convolution_fwd_t<isa, use_inversion>::perform_outwork(
        char *dst_base, char *dst, char *c_buffer, const char *bias_w, int od,
        int oh, int ow, int g_oc, bool is_oc_tail, int ker_ow_s, int ker_ow_f,
        int kd_l, int kh_l, const void *post_ops_binary_rhs_arg_vec,
        const float *oscales, int32_t src_zp_vals, int32_t *src_zp_ptr,
        int32_t *dst_zp_ptr, int32_t *s8s8_compensation, bool maybe_do_init,
        bool do_postwork, bool do_post_comp) const {

    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;

    // With sum post-op and no intermediate buffer, dst already holds the
    // values to accumulate into, so initialization must be skipped.
    const auto do_init
            = maybe_do_init && IMPLICATION(jcp.with_sum, jcp.use_buffer);
    if (!do_init && !do_postwork) return;

    assert(!jcp.is_os_blocking);

    const bool is_ow_tail = (OW - ow < jcp.ow_block);

    const auto M = is_ow_tail ? jcp.M_tail : jcp.M;
    // If the kernel range is empty (kd_l * kh_l <= 0) the whole block is
    // "outwork" - treat the kernel-covered range as empty at `ow`.
    const auto kdh_l = kd_l * kh_l;
    const auto ow_s = (kdh_l <= 0) ? ow : ker_ow_s;
    const auto ow_f = (kdh_l <= 0) ? ow : ker_ow_f;
    assert(ow <= ow_s && ow_s <= ow_f && ow_f <= ow + M);

    brgemm_kernel_post_ops_t p;
    if (do_postwork) {
        p.ptr_bias = (void *)(bias_w);
        p.ptr_scales = (void *)(&oscales[jcp.is_oc_scale * g_oc]);
        p.ptr_binary_post_ops_rhs = post_ops_binary_rhs_arg_vec;
        p.dst_orig = dst;
        p.c_zp_values = dst_zp_ptr;
        p.a_comp_val = src_zp_vals;
    }

    // Invokes the post-ops kernel for a sub-range [ow_pw_s, ow_pw_s+ow_pw_l)
    // either as an init pass (zeroing) or a post-work pass.
    auto call_outwork_ker = [&](bool is_postwork, bool has_postcomp,
                                    int ow_pw_s, int ow_pw_l) {
        auto ker_po_idx = get_ker_po_idx(ow_pw_l - 1, is_postwork, is_oc_tail);
        const auto outwork_ker = kernels_po_[ker_po_idx].get();
        assert(outwork_ker != nullptr && ow_pw_l == outwork_ker->brg.bcast_dim);
        if (is_postwork) {
            p.apply_comp = has_postcomp;
            p.a_zp_compensation = has_postcomp && jcp.src_zero_point
                    ? &src_zp_ptr[ow_pw_s * jcp.LDB]
                    : src_zp_ptr;
            p.s8s8_compensation = has_postcomp && jcp.s8s8_avx512
                    ? &s8s8_compensation[ow_pw_s * jcp.LDB]
                    : s8s8_compensation;

            p.ptr_out = dst_base
                    + dst_dsz
                            * (od * dst_h_sz + oh * dst_w_sz
                                    + ow_pw_s * jcp.oc_without_padding);
            // Read from the accumulation buffer if one is used, else in-place.
            p.ptr_in = static_cast<void *>(jcp.use_buffer
                            ? (c_buffer + acc_dsz * (ow_pw_s - ow) * jcp.LDC)
                            : p.ptr_out);
        } else {
            p.apply_comp = has_postcomp;
            // Init pass: zero either the buffer slice or dst directly.
            char *const ptr_Cz = jcp.use_buffer
                    ? (c_buffer + acc_dsz * (ow_pw_s - ow) * jcp.LDC)
                    : dst_base
                            + dst_dsz
                                    * (od * dst_h_sz + oh * dst_w_sz
                                            + ow_pw_s * jcp.oc_without_padding);
            p.ptr_out = static_cast<void *>(ptr_Cz);
        }
        (*outwork_ker)(&p);
    };

    if (ow < ow_s) {
        // left side
        const auto ow_pw_l = ow_s - ow;
        if (do_init) call_outwork_ker(false, false, ow, ow_pw_l);
        if (do_postwork) call_outwork_ker(true, do_post_comp, ow, ow_pw_l);
    }
    if (ow_f < ow + M) {
        // right side
        const auto ow_pw_l = ow + M - ow_f;
        if (do_init) call_outwork_ker(false, false, ow_f, ow_pw_l);
        if (do_postwork) call_outwork_ker(true, do_post_comp, ow_f, ow_pw_l);
    }
}
1203 | |
1204 | template <cpu_isa_t isa, bool use_inversion> |
1205 | void brgemm_convolution_fwd_t<isa, use_inversion>::call_brgemm_kernel( |
1206 | brgemm_thread_ctx_t &btc, int brg_idx, int batch_size, char *ptr_C, |
1207 | char *ptr_D, const char *bias_w, int g_oc, bool do_postops, |
1208 | const void *binary_post_ops_rhs, int32_t src_zp_vals, |
1209 | int32_t *src_zp_ptr, int32_t *dst_zp_ptr, int32_t *s8s8_comp, |
1210 | bool do_only_comp) const { |
1211 | |
1212 | const auto _pd = pd(); |
1213 | const auto &jcp = _pd->jcp_; |
1214 | |
1215 | const auto brg_ker = brg_kernels_[brg_idx].get(); |
1216 | assert(brg_ker != nullptr); |
1217 | |
1218 | // TODO: avoid costly tile reconfigurations |
1219 | if (is_amx) { |
1220 | if (std::memcmp(btc.cur_palette.a, brg_kernel_palettes_[brg_idx].a, |
1221 | AMX_PALETTE_SIZE) |
1222 | != 0) { |
1223 | amx_tile_configure(brg_kernel_palettes_[brg_idx].a); |
1224 | std::memcpy(btc.cur_palette.a, brg_kernel_palettes_[brg_idx].a, |
1225 | AMX_PALETTE_SIZE); |
1226 | } |
1227 | } |
1228 | |
1229 | const auto do_only_pass_comp = !do_postops && jcp.src_zero_point |
1230 | && (jcp.req_brg_comp_pad || jcp.max_vpad > 0); |
1231 | const auto maybe_do_postops |
1232 | = one_of(true, do_postops, do_only_comp, do_only_pass_comp); |
1233 | if (maybe_do_postops) { |
1234 | const brgemm_post_ops_data_t post_ops_data { |
1235 | static_cast<const char *>(bias_w), |
1236 | &btc.oscales[jcp.is_oc_scale * g_oc], binary_post_ops_rhs, |
1237 | static_cast<size_t>(g_oc), 0, btc.brgemm_ctx.dst, 0, |
1238 | static_cast<void *>(src_zp_ptr), nullptr, |
1239 | static_cast<void *>(dst_zp_ptr), false, src_zp_vals, |
1240 | do_only_comp, do_only_pass_comp}; |
1241 | |
1242 | void *scratch = is_amx ? static_cast<void *>(btc.wsp_tile) |
1243 | : static_cast<void *>(s8s8_comp); |
1244 | |
1245 | if (do_postops) |
1246 | brgemm_kernel_execute_postops(brg_ker, batch_size, btc.brg_batch, |
1247 | ptr_C, ptr_D, post_ops_data, scratch); |
1248 | else |
1249 | brgemm_kernel_execute_postops(brg_ker, batch_size, btc.brg_batch, |
1250 | ptr_C, ptr_C, post_ops_data, scratch); |
1251 | } else |
1252 | brgemm_kernel_execute(brg_ker, batch_size, btc.brg_batch, ptr_C, |
1253 | static_cast<void *>(btc.wsp_tile)); |
1254 | } |
1255 | |
// Copies the input block required by (g, n, icc, odb, ohb, owb) into the
// physically padded thread-local buffer, unless that block is already there
// (tracked either by the last_* indices in copy_block_only mode, or by the
// persistent per-block bitmask otherwise).
template <cpu_isa_t isa, bool use_inversion>
void brgemm_convolution_fwd_t<isa, use_inversion>::maybe_conv_inp(int ithr,
        const char *__restrict src, char *__restrict inp_buffer,
        uint8_t *__restrict inp_buffer_mask, int g, int n, int icc, int odb,
        int ohb, int owb, int last_g, int last_n, int last_icc, int last_odb,
        int last_ohb, int last_owb) const {

    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;
    const auto icb = icc * jcp.nb_ic_blocking;

// Accessor into the "already copied" bitmask, indexed by (icb, odb, ohb, owb).
#define bmask(icb, odb, ohb, owb) \
    inp_buffer_mask[(((icb)*jcp.nb_od + (odb)) * jcp.nb_oh + (ohb)) \
                    * jcp.nb_ow \
            + (owb)]

    // Skip the copy if this exact block was copied already.
    if (jcp.copy_block_only) {
        if (last_g == g && last_n == n && last_icc == icc && last_odb == odb
                && last_ohb == ohb && last_owb == owb)
            return;
    } else {
        if (bmask(icb, odb, ohb, owb)) return;
    }

    auto cp = jit_brgemm_conv_trans_kernel_call_s();

    // Whether the neighboring (previous) od/oh blocks are already in the
    // buffer - if so, overlapping rows need not be copied again.
    const auto prev_odb = (jcp.copy_block_only || odb == 0
                                  || bmask(icb, odb - 1, ohb, owb) == 0)
            ? false
            : true;

    const auto prev_ohb = (jcp.copy_block_only || ohb == 0
                                  || bmask(icb, odb, ohb - 1, owb) == 0)
            ? false
            : true;

    const auto prev_odb_ohb
            = (jcp.copy_block_only
                      || (odb > 0 && ohb > 0
                              && bmask(icb, odb - 1, ohb - 1, owb) == 0))
            ? false
            : true;

    const auto ic = icb * jcp.ic_block;
    const auto g_ic = g * jcp.ic + ic;
    const auto oh = ohb * jcp.oh_block;
    const auto ow = owb * jcp.ow_block;
    const auto iw = nstl::max(0, ow * SW - LP);

    int id_start {0}, id_end {0}, ih_start {0}, ih_end {0};
    int virt_id_start {0}, virt_id_end {0}, virt_ih_start {0}, virt_ih_end {0};

    // Computes the [start, end) source range (clamped to real input) and the
    // "virtual" range (including padding) for one spatial dimension; `prev`
    // tells whether the previous block's tail can be reused.
    auto get_start_end = [](int &start, int &end, int &virt_start,
                                 int &virt_end, int b, int bs, int i, int o,
                                 int s, int p, int k, int d, bool prev) {
        const auto o_b = saturate(0, o, b * bs);
        const auto prev_o_b = saturate(0, o, (b - 1) * bs);
        const auto virt_cur_start = o_b * s - p;
        const auto cur_start = saturate(0, i, virt_cur_start);
        const auto virt_prev_start = prev_o_b * s - p;
        const auto i_bs = get_inp_size(i, bs, k, s, d);
        const auto virt_i_bs = calculate_end_padding(
                0, bs, 0, s, calculate_extended_filter_size(k, d));
        const auto virt_prev_end = prev ? virt_prev_start + virt_i_bs : -p;
        const auto prev_end = prev ? saturate(0, i, virt_prev_end) : 0;
        virt_start = nstl::max(virt_prev_end, virt_cur_start);
        start = nstl::max(prev_end, cur_start);
        virt_end = virt_cur_start + virt_i_bs;
        end = saturate(0, i, cur_start + i_bs);
    };
    get_start_end(id_start, id_end, virt_id_start, virt_id_end, odb,
            jcp.od_block, nstl::min(ID, IDP - FP), OD, SD, FP, KD, DD - 1,
            prev_odb && prev_odb_ohb);
    get_start_end(ih_start, ih_end, virt_ih_start, virt_ih_end, ohb,
            jcp.oh_block, nstl::min(IH, IHP - TP), OH, SH, TP, KH, DH - 1,
            prev_ohb && prev_odb_ohb);

    // how many real data rows to copy (including padding)
    const auto rows_to_copy = ih_end - ih_start;
    cp.owb = owb;
    cp.ic = ic;
    const auto iw_buf = jcp.copy_block_only ? 0 : (ow * SW);
    dim_t inp_offset_start, out_offset_start;

    for (int kh = 0; kh < jcp.kh_sets; kh++) {
        if (jcp.kh_sets > 1) {
            assert(!jcp.is_os_blocking);
            // Per-kh copy: each kh value gets its own slice in the buffer.
            const auto ih_s = oh * SH + kh * DH - TP;
            const auto ih_f = (oh + jcp.oh_block - 1) * SH + kh * DH - TP + 1;

            cp.t_pad = max(0, -ih_s);
            cp.b_pad = max(0, ih_f - jcp.ih);
            cp.h_count = max(0, jcp.oh_block);
            const auto ih_buf = (jcp.copy_block_only ? 0 : ih_start) + TP;

            inp_offset_start = static_cast<dim_t>(n) * src_d_sz
                    + max(ih_s, ih_start) * src_w_sz
                    + iw * jcp.ngroups * jcp.ic_without_padding + g_ic;

            // inp_buffer has physical padding
            out_offset_start = (jcp.copy_block_only ? 0
                                                    : static_cast<dim_t>(icb)
                                       * pbuf_d_sz)
                    + ih_buf * pbuf_w_sz
                    + (iw_buf * jcp.kh_sets + kh) * jcp.kw_sets * jcp.ic_block;
        } else {
            // For os_blocking:
            // We have to zero top and bottom padding now
            // taking into account that batch size is always the same (kh_s is 0 for os_blocking)
            // TODO: extend M_mask (may be different for different kh) to avoid copying
            // top/bottom padded rows and avoid extra calculations in kernel
            // also for convolutions with pw == 0 the copy routine maybe not needed
            cp.t_pad = jcp.is_os_blocking ? max(0, -virt_ih_start) : 0;
            cp.b_pad = jcp.is_os_blocking ? max(0, virt_ih_end - IH) : 0;
            cp.h_count = max(0, rows_to_copy) + cp.t_pad + cp.b_pad;
            const auto ih_buf
                    = (jcp.copy_block_only ? 0 : ih_start) + TP - cp.t_pad;

            inp_offset_start = static_cast<dim_t>(n) * src_d_sz
                    + ih_start * src_w_sz
                    + iw * jcp.ngroups * jcp.ic_without_padding + g_ic;

            // inp_buffer has physical padding
            out_offset_start = (jcp.copy_block_only ? 0
                                                    : static_cast<dim_t>(icb)
                                       * pbuf_d_sz)
                    + ih_buf * pbuf_w_sz
                    + iw_buf * jcp.ic_block * jcp.kh_sets * jcp.kw_sets;
        }

        // Copy row-by-row along the depth dimension via the jit copy kernel.
        for (int id = id_start; id < id_end; id++) {
            const auto inp_offset = inp_offset_start + id * src_h_sz;
            const auto id_buf = id - (jcp.copy_block_only ? id_start : 0) + FP;
            const auto out_offset = out_offset_start + id_buf * pbuf_h_sz;
            cp.src = src + src_dsz * inp_offset;
            cp.dst = inp_buffer + src_dsz * out_offset;
            (*copy_to_pbuffer_)(&cp);
        }
    }
    // Mark this block as present in the persistent buffer.
    if (!jcp.copy_block_only) bmask(icb, odb, ohb, owb) = 1;

#undef bmask
}
1399 | |
// Common local-variable preamble shared by the kernel drivers
// (ker_base/ker_trans/ker_vpad). Kept as a macro because it declares many
// locals used directly by the driver bodies.
// Fix: the macro name was missing after `#define`; it is invoked as
// BRGEMM_CONV_KER_HEADER below.
#define BRGEMM_CONV_KER_HEADER \
    const char *const __restrict src = btc.brgemm_ctx.src; \
    const char *const __restrict weights = btc.brgemm_ctx.weights; \
    const char *const __restrict bias = btc.brgemm_ctx.bias; \
    char *const __restrict dst = btc.brgemm_ctx.dst; \
    const std::vector<const void *> &post_ops_binary_rhs_arg_vec \
            = btc.brgemm_ctx.post_ops_binary_rhs_arg_vec; \
    const int oc = btc.ocb * jcp.oc_block; \
    const int g_oc = btc.g * jcp.oc + oc; \
    const int icb = btc.icc * jcp.nb_ic_blocking; \
    const int ic = icb * jcp.ic_block; \
    const int g_ic = btc.g * jcp.ic + ic; \
    const int ow = btc.owb * jcp.ow_block; \
    const int oh = btc.ohb * jcp.oh_block; \
    const int iid = ndims_pick(btc.od * SD - FP, 0, 0); \
    const int kd_s = ndims_pick(div_up(max(0, -iid), DD), 0, 0); \
    const int kd_f = ndims_pick( \
            KD - div_up(max(0, iid - ID + (KD - 1) * DD + 1), DD), 1, 1); \
    const auto kd_l = kd_f - kd_s; \
    const auto iih = ndims_pick(btc.oh * SH - TP, btc.oh * SH - TP, 0); \
    const auto kh_s_ = div_up(max(0, -iih), DH); \
    const auto kh_s = jcp.is_os_blocking ? 0 : ndims_pick(kh_s_, kh_s_, 0); \
    const auto kh_f_ = KH - div_up(max(0, iih - IH + (KH - 1) * DH + 1), DH); \
    const auto kh_f = ndims_pick(kh_f_, kh_f_, 1); \
    const auto kh_l = kh_f - kh_s; \
    const bool is_oc_tail = (jcp.oc - oc < jcp.oc_block); \
    const bool is_ic_tail = (btc.icc == _pd->ic_chunks - 1 \
            && ((jcp.ic - ic) % jcp.ic_block != 0)); \
    const bool is_ow_tail = (OW - ow < jcp.ow_block); \
    const bool is_oh_tail = (OH - oh < jcp.oh_block); \
    const char *const __restrict bias_w \
            = bias ? bias + (bias_d.blk_off(g_oc) * bia_dsz) : nullptr; \
    const auto nb_ic_b = nstl::min(jcp.nb_ic_blocking, jcp.nb_ic - icb) \
            - (is_ic_tail ? 1 : 0); \
    char *const __restrict dst_base \
            = dst + dst_dsz * (btc.n * dst_d_sz + g_oc); \
    char *ptr_C; \
    char *ptr_D; \
    int kd_b(0), kd_e(0), kh_b(0), kh_e(0), k_l(0), iiw_b(0);
1439 | |
// Kernel driver for the exec_base execution mode: splits the kw range into
// left-padded, fully-covered and right-padded parts and calls the brgemm
// kernel on kd/kh/kw sub-blocks, finishing with perform_outwork() for the
// output columns not covered by any kernel point.
template <cpu_isa_t isa, bool use_inversion>
void brgemm_convolution_fwd_t<isa, use_inversion>::ker_base(
        brgemm_thread_ctx_t &btc) const {

    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;
    auto ndims = _pd->ndims();

    BRGEMM_CONV_KER_HEADER;
    MAYBE_UNUSED(is_ow_tail);
    MAYBE_UNUSED(is_oh_tail);

    int kw_s {0}, kw_full_s {0}, kw_f {0}, kw_full_f {0}, kw_b(0), kw_e(0);

    // [kw_s, kw_f): kw values contributing to this ow block at all;
    // [kw_full_s, kw_full_f): kw values covering the entire ow block.
    get_kw_range(ow, kw_s, kw_full_s, kw_full_f, kw_f);

    const auto src_base = src + src_dsz * (btc.n * src_d_sz + g_ic);
    const auto wei_base
            = weights + wei_dsz * (btc.g * wei_ocb_sz + btc.ocb * wei_kd_sz);

    // Fills btc.brg_batch with A/B pointers for the current (kd, kh, kw)
    // sub-range and runs the brgemm kernel.
    const auto call_brgemm = [&](int brg_idx, int ic_block_s, int n_ic_blocks,
                                     int32_t *src_zp, int32_t *s8s8_comp,
                                     bool do_postops, bool do_only_comp) {
        if (k_l <= 0) return;

        for (int i_icb = 0; i_icb < n_ic_blocks; i_icb++) {
            const auto ic_off = (ic_block_s + i_icb) * jcp.ic_block;
            const auto src_ic = ic_off;
            const auto wei_ic = ic + ic_off;
            const auto n_icb_off = i_icb * k_l;
            const auto src_base_ic = src_base + src_dsz * src_ic;
            const auto wei_base_ic = wei_base + wei_dsz * wei_ic * jcp.oc_block;

            auto k = 0;
            for (int kd = kd_b; kd < kd_e; kd++) {
                const auto id = iid + kd * DD;
                const auto src_base_kd = src_base_ic + src_dsz * id * src_h_sz;
                const auto wei_base_kd = wei_base_ic
                        + wei_dsz * maybe_invert(kd, KD) * wei_kh_sz;
                for (int kh = kh_b; kh < kh_e; kh++) {
                    const auto ih = iih + kh * DH;
                    const auto src_base_kh
                            = src_base_kd + src_dsz * ih * src_w_sz;
                    const auto wei_base_kh = wei_base_kd
                            + wei_dsz * maybe_invert(kh, KH) * wei_kw_sz;
                    for (int kw = kw_b; kw < kw_e; kw++) {
                        const auto iw = iiw_b + kw * DW;
                        btc.brg_batch[n_icb_off + k].ptr.A = src_base_kh
                                + src_dsz * iw * jcp.ngroups
                                        * jcp.ic_without_padding;
                        btc.brg_batch[n_icb_off + k].vvpad.top = 0;
                        btc.brg_batch[n_icb_off + k].vvpad.bottom = 0;
                        // general wei layout is gOdhwI<block_o><block_i>
                        btc.brg_batch[n_icb_off + k].ptr.B = wei_base_kh
                                + wei_dsz * maybe_invert(kw, KW) * wei_ic_sz;
                        k++;
                    }
                }
            }
        }
        call_brgemm_kernel(btc, brg_idx, k_l * n_ic_blocks, ptr_C, ptr_D,
                bias_w, g_oc, do_postops, post_ops_binary_rhs_arg_vec.data(),
                btc.src_zp_vals, src_zp, btc.dst_zp_vals, s8s8_comp,
                do_only_comp);
    };

    // Processes the currently selected [kd_b,kd_e) x [kh_b,kh_e) x [kw_b,kw_e)
    // sub-block: picks the kernel variant, runs it, then handles outwork.
    const auto kdhw_loop = [&]() {
        if (kw_e - kw_b <= 0) return;
        int ow_b {0}, ow_e {0};
        get_ow_range(ow, kw_b, ow_b, ow_e);

        // Init on the very first sub-block; post-work on the very last one.
        const auto do_init
                = btc.icc == 0 && kd_b == kd_s && kh_b == kh_s && kw_b == kw_s;
        const auto do_postwork = _pd->need_postwork
                && btc.icc == (_pd->ic_chunks - 1) && kd_e == kd_f
                && kh_e == kh_f && kw_e == kw_f;
        const auto do_only_comp = need_compensation && kd_e == kd_f
                && kh_e == kh_f && kw_e != kw_f
                && btc.icc == (_pd->ic_chunks - 1);
        if (ow_e - ow_b <= 0 && !do_init && !do_postwork) return;

        k_l = (kd_e - kd_b) * (kh_e - kh_b) * (kw_e - kw_b);
        iiw_b = ow_b * SW - LP;
        ptr_D = dst_base
                + dst_dsz
                        * (btc.od * dst_h_sz + btc.oh * dst_w_sz
                                + ow_b * jcp.oc_without_padding);
        ptr_C = (jcp.use_buffer)
                ? btc.c_buffer + acc_dsz * (ow_b - ow) * jcp.LDC
                : static_cast<char *>(ptr_D);

        const auto ow_l = ow_e - ow_b;
        assert(0 <= ow_l && ow_l <= jcp.ow_block);

        const auto comp_ker_offs = get_comp_offset(
                btc.g, btc.ocb, ow_b, kd_s, kd_f, kh_s, kh_f, kw_b, kw_e);

        // Kernel index table: [do_init][is_ic_tail].
        const auto ker_i = ow_l - 1;
        int kernel_idx[2][2];
        kernel_idx[false][false]
                = _pd->get_brg_idx(k_l, ker_i, false, is_oc_tail, false);
        kernel_idx[true][false]
                = _pd->get_brg_idx(k_l, ker_i, true, is_oc_tail, false);
        kernel_idx[false][true]
                = _pd->get_brg_idx(k_l, ker_i, false, is_oc_tail, true);
        kernel_idx[true][true]
                = _pd->get_brg_idx(k_l, ker_i, true, is_oc_tail, true);

        if (ow_l > 0 && k_l > 0) {
            if (nb_ic_b > 0) {
                const auto brg_idx = kernel_idx[do_init][false];
                call_brgemm(brg_idx, 0, nb_ic_b,
                        jcp.src_zero_point ? &btc.src_zp_comp_ptr[comp_ker_offs]
                                           : nullptr,
                        jcp.s8s8_avx512 ? &btc.s8s8_comp_ptr[comp_ker_offs]
                                        : nullptr,
                        do_postwork && !is_ic_tail, do_only_comp);
            }

            if (is_ic_tail) {
                // The tail kernel initializes only if the main part did not run.
                const auto use_init_ker = (do_init && nb_ic_b == 0);
                const auto brg_ic_tail_idx = kernel_idx[use_init_ker][true];
                call_brgemm(brg_ic_tail_idx, nb_ic_b, 1,
                        jcp.src_zero_point ? &btc.src_zp_comp_ptr[comp_ker_offs]
                                           : nullptr,
                        jcp.s8s8_avx512 ? &btc.s8s8_comp_ptr[comp_ker_offs]
                                        : nullptr,
                        do_postwork, do_only_comp);
            }
        }

        // Handle output columns outside [ow_b, ow_e) (pure padding area).
        perform_outwork(dst_base, dst, btc.c_buffer, bias_w, btc.od, btc.oh, ow,
                g_oc, is_oc_tail, ow_b, ow_e, kd_l, kh_l,
                post_ops_binary_rhs_arg_vec.data(), btc.oscales,
                btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
                btc.s8s8_comp_ptr, do_init, do_postwork, false);
    };

    if (kd_f > kd_s && kh_f > kh_s && kw_f > kw_s) {
        // kw values with left padding
        if (kw_s < kw_full_s) {
            for (kd_b = kd_s; kd_b < kd_f; kd_b += KD_BLOCK_PAD) {
                kd_e = nstl::min(kd_f, kd_b + KD_BLOCK_PAD);
                for (kh_b = kh_s; kh_b < kh_f; kh_b += KH_BLOCK_PAD) {
                    kh_e = nstl::min(kh_f, kh_b + KH_BLOCK_PAD);
                    for (auto kw = kw_s; kw < kw_full_s; kw++) {
                        kw_b = kw;
                        kw_e = kw + 1;
                        kdhw_loop();
                    }
                }
            }
        }

        // kw values covering full ow_block
        if (kw_full_s < kw_full_f) {
            for (kd_b = kd_s; kd_b < kd_f; kd_b += KD_BLOCK) {
                kd_e = nstl::min(kd_f, kd_b + KD_BLOCK);
                for (kh_b = kh_s; kh_b < kh_f; kh_b += KH_BLOCK) {
                    kh_e = nstl::min(kh_f, kh_b + KH_BLOCK);
                    for (kw_b = kw_full_s; kw_b < kw_full_f; kw_b += KW_BLOCK) {
                        kw_e = nstl::min(kw_full_f, kw_b + KW_BLOCK);
                        kdhw_loop();
                    }
                }
            }
        }

        // kw values with right padding
        if (kw_full_f < kw_f) {
            for (kd_b = kd_s; kd_b < kd_f; kd_b += KD_BLOCK_PAD) {
                kd_e = nstl::min(kd_f, kd_b + KD_BLOCK_PAD);
                for (kh_b = kh_s; kh_b < kh_f; kh_b += KH_BLOCK_PAD) {
                    kh_e = nstl::min(kh_f, kh_b + KH_BLOCK_PAD);
                    for (int kw = kw_full_f; kw < kw_f; kw++) {
                        kw_b = kw;
                        kw_e = kw + 1;
                        kdhw_loop();
                    }
                }
            }
        }
    } else {
        // Empty kernel range: the whole ow block is init/post-work only.
        const auto do_init = btc.icc == 0;
        const auto do_postwork
                = _pd->need_postwork && btc.icc == (_pd->ic_chunks - 1);
        perform_outwork(dst_base, dst, btc.c_buffer, bias_w, btc.od, btc.oh, ow,
                g_oc, is_oc_tail, ow, ow, kd_l, kh_l,
                post_ops_binary_rhs_arg_vec.data(), btc.oscales,
                btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
                btc.s8s8_comp_ptr, do_init, do_postwork, false);
    }
}
1633 | |
// Kernel driver for the pre-transposed-input path: the source data was
// previously packed into 'inp_buffer' (layout Cdhw<ic_block>c, see the
// comment inside the batch-filling loop), so brgemm A-pointers are formed
// from that buffer rather than from the user src tensor.
template <cpu_isa_t isa, bool use_inversion>
void brgemm_convolution_fwd_t<isa, use_inversion>::ker_trans(
        brgemm_thread_ctx_t &btc, char *inp_buffer) const {

    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;
    auto ndims = _pd->ndims();

    // Declares the per-call locals shared by all ker_* variants (src/dst/
    // weights bases and sizes, ow/oh/iid/iih, kd_s/kd_f/kh_s/kh_f, tail
    // flags, ptr_C/ptr_D, k_l, etc.). Defined earlier in this file.
    BRGEMM_CONV_KER_HEADER;
    MAYBE_UNUSED(g_ic);
    MAYBE_UNUSED(src);

    const auto wei_base
            = weights + wei_dsz * (btc.g * wei_ocb_sz + btc.ocb * wei_kd_sz);
    // Bounds of the current ow/oh block; tails use the remainder sizes.
    const int ow_b {ow},
            ow_e {ow + (is_ow_tail ? jcp.ow % jcp.ow_block : jcp.ow_block)};
    const int oh_b {oh},
            oh_e {oh + (is_oh_tail ? jcp.oh % jcp.oh_block : jcp.oh_block)};
    iiw_b = ow_b * SW - LP;
    ptr_D = dst_base
            + dst_dsz
                    * (btc.od * dst_h_sz + btc.oh * dst_w_sz
                            + ow_b * jcp.oc_without_padding);
    // Accumulate into the f32 scratch buffer when one is used, otherwise
    // accumulate directly into dst.
    ptr_C = (jcp.use_buffer) ? btc.c_buffer + acc_dsz * (ow_b - ow) * jcp.LDC
                             : static_cast<char *>(ptr_D);

    const auto ow_l = ow_e - ow_b;
    const auto oh_l = oh_e - oh_b;
    assert(0 <= ow_l && ow_l <= jcp.ow_block && 0 <= oh_l
            && oh_l <= jcp.oh_block);

    // Kernel selector for the output length (M dimension of brgemm); with
    // os-blocking one kernel spans the whole oh_l * ow_l tile.
    const auto ker_i = (jcp.is_os_blocking ? oh_l * ow_l : ow_l) - 1;

    // Fills btc.brg_batch with A (packed input) / B (weights) pointer pairs
    // for 'n_ic_blocks' ic blocks starting at 'ic_block_s', then invokes the
    // brgemm kernel 'brg_idx'. Relies on kd_b/kd_e/kh_b/kh_e/k_l set by the
    // enclosing kdhw_loop.
    const auto call_brgemm = [&](int brg_idx, int ic_block_s, int n_ic_blocks,
                                     bool do_postops) {
        if (k_l <= 0) return;

        // With kh_sets > 1 the kh taps are folded into the packed-buffer
        // layout, so only a single kh iteration is performed here; same
        // folding for kw via kw_sets below.
        const auto kh_ee = jcp.kh_sets > 1 ? kh_b + 1 : kh_e;
        const auto kw_e = jcp.kw_sets > 1 ? 1 : KW;
        const auto pbuf_base = inp_buffer
                + src_dsz
                        * ((jcp.copy_block_only
                                        ? 0
                                        : ((icb + ic_block_s) * pbuf_d_sz)));
        // copy_block_only buffers hold just the current block, so the block
        // origin must be subtracted from the absolute input coordinates.
        const auto iid_shift = jcp.copy_block_only
                ? nstl::max(0, btc.odb * jcp.od_block * SD - FP)
                : 0;
        const auto iih_shift = jcp.copy_block_only
                ? nstl::max(0, btc.ohb * jcp.oh_block * SH - TP)
                : 0;
        const auto iiw_shift
                = jcp.copy_block_only ? (btc.owb * jcp.ow_block * SW) : 0;

        for (int i_icb = 0; i_icb < n_ic_blocks; i_icb++) {
            const auto ic_off = (ic_block_s + i_icb) * jcp.ic_block;
            const auto wei_ic = ic + ic_off;
            const auto n_icb_off = i_icb * k_l;
            const auto pbuf_base_ic = pbuf_base
                    + src_dsz
                            * ((jcp.copy_block_only ? 0 : (i_icb * pbuf_d_sz)));
            const auto wei_base_ic = wei_base + wei_dsz * wei_ic * jcp.oc_block;

            auto k = 0;
            for (int kd = kd_b; kd < kd_e; kd++) {
                const auto id = iid - iid_shift + kd * DD + FP;
                const auto pbuf_base_kd
                        = pbuf_base_ic + src_dsz * id * pbuf_h_sz;
                // maybe_invert() flips the tap index for deconvolution
                // (use_inversion) builds.
                const auto wei_base_kd = wei_base_ic
                        + wei_dsz * maybe_invert(kd, KD) * wei_kh_sz;
                for (int kh = kh_b; kh < kh_ee; kh++) {
                    // NOTE(review): the kh_sets > 1 row index (iih + 2 * TP)
                    // presumably matches how the copy kernel packs folded kh
                    // rows — confirm against the trans-kernel layout.
                    const auto ih = jcp.kh_sets > 1
                            ? (iih + 2 * TP)
                            : (iih - iih_shift + kh * DH + TP);
                    const auto pbuf_base_kh
                            = pbuf_base_kd + src_dsz * ih * pbuf_w_sz;
                    const auto wei_base_kh = wei_base_kd
                            + wei_dsz
                                    * ((jcp.kh_sets > 1 ? 0
                                                        : maybe_invert(kh, KH))
                                            * wei_kw_sz);
                    for (int kw = 0; kw < kw_e; kw++) {
                        const auto iw = iiw_b - iiw_shift + kw * DW + LP;
                        // inp_buffer layout is Cdhw<ic_block>c
                        btc.brg_batch[n_icb_off + k].ptr.A = pbuf_base_kh
                                + src_dsz * iw * jcp.ic_block * jcp.kh_sets
                                        * jcp.kw_sets;
                        // Packed input already contains padding, so no
                        // virtual padding is needed on this path.
                        btc.brg_batch[n_icb_off + k].vvpad.top = 0;
                        btc.brg_batch[n_icb_off + k].vvpad.bottom = 0;
                        // general wei layout is gOdhwI<block_o><block_i>
                        btc.brg_batch[n_icb_off + k].ptr.B = wei_base_kh
                                + wei_dsz * maybe_invert(kw, KW) * wei_ic_sz;
                        k++;
                    }
                }
            }
        }

        call_brgemm_kernel(btc, brg_idx, k_l * n_ic_blocks, ptr_C, ptr_D,
                bias_w, g_oc, do_postops, post_ops_binary_rhs_arg_vec.data(),
                btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
                btc.s8s8_comp_ptr, false);
    };

    // Processes one kd/kh sub-block: picks the brgemm kernel variant
    // (init vs accumulate x ic-tail vs full) and dispatches call_brgemm for
    // the full ic blocks and the ic tail.
    const auto kdhw_loop = [&]() {
        // Initialize the accumulator only on the very first ic chunk and
        // first kd/kh sub-block; run post-ops only on the last of each.
        const auto do_init = btc.icc == 0 && kd_b == kd_s && kh_b == kh_s;
        const auto do_postwork = _pd->need_postwork
                && btc.icc == (_pd->ic_chunks - 1) && kd_e == kd_f
                && kh_e == kh_f;
        if (ow_e - ow_b <= 0 && !do_init && !do_postwork) return;

        k_l = (kd_e - kd_b) * (jcp.kh_sets > 1 ? 1 : (kh_e - kh_b))
                * (jcp.kw_sets > 1 ? 1 : KW);

        // kernel_idx[do_init][is_ic_tail]
        int kernel_idx[2][2];
        kernel_idx[false][false]
                = _pd->get_brg_idx(k_l, ker_i, false, is_oc_tail, false);
        kernel_idx[true][false]
                = _pd->get_brg_idx(k_l, ker_i, true, is_oc_tail, false);
        kernel_idx[false][true]
                = _pd->get_brg_idx(k_l, ker_i, false, is_oc_tail, true);
        kernel_idx[true][true]
                = _pd->get_brg_idx(k_l, ker_i, true, is_oc_tail, true);

        if (nb_ic_b > 0) {
            const auto brg_idx = kernel_idx[do_init][false];
            // Postwork is deferred to the ic-tail call when a tail exists.
            call_brgemm(brg_idx, 0, nb_ic_b, do_postwork && !is_ic_tail);
        }

        if (is_ic_tail) {
            // Use the initializing kernel only if no full ic block ran first.
            const auto use_init_ker = (do_init && nb_ic_b == 0);
            const auto brg_ic_tail_idx = kernel_idx[use_init_ker][true];
            call_brgemm(brg_ic_tail_idx, nb_ic_b, 1, do_postwork);
        }
    };

    if (kd_f > kd_s && kh_f > kh_s) {
        // kw values covering full ow_block
        for (kd_b = kd_s; kd_b < kd_f; kd_b += KD_BLOCK) {
            kd_e = nstl::min(kd_f, kd_b + KD_BLOCK);
            for (kh_b = kh_s; kh_b < kh_f; kh_b += KH_BLOCK) {
                kh_e = nstl::min(kh_f, kh_b + KH_BLOCK);
                kdhw_loop();
            }
        }
    } else {
        // Degenerate case: no kernel taps overlap this output block (fully
        // padded); only zero-initialization / post-ops on dst are needed.
        const auto do_init = btc.icc == 0;
        const auto do_postwork
                = _pd->need_postwork && btc.icc == (_pd->ic_chunks - 1);
        perform_outwork(dst_base, dst, btc.c_buffer, bias_w, btc.od, btc.oh, ow,
                g_oc, is_oc_tail, ow, ow, kd_l, kh_l,
                post_ops_binary_rhs_arg_vec.data(), btc.oscales,
                btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
                btc.s8s8_comp_ptr, do_init, do_postwork, false);
    }
}
1789 | |
// Kernel driver for the virtual-padding path: brgemm reads directly from the
// user src tensor and top/bottom rows that fall into padding are skipped via
// per-kw virtual-padding descriptors (vvpad) precomputed per owb in
// owb_kw_top_vpads / owb_kw_bottom_vpads.
template <cpu_isa_t isa, bool use_inversion>
void brgemm_convolution_fwd_t<isa, use_inversion>::ker_vpad(
        brgemm_thread_ctx_t &btc) const {

    const auto _pd = pd();
    const auto &jcp = _pd->jcp_;
    auto ndims = _pd->ndims();

    // Declares the per-call locals shared by all ker_* variants (src/dst/
    // weights bases and sizes, ow/oh/iid/iih, kd_s/kd_f/kh_s/kh_f, tail
    // flags, ptr_C/ptr_D, k_l, etc.). Defined earlier in this file.
    BRGEMM_CONV_KER_HEADER;
    MAYBE_UNUSED(is_oh_tail);

    const char *const __restrict src_base
            = src + src_dsz * (btc.n * src_d_sz + g_ic);

    const char *const __restrict wei_base
            = weights + wei_dsz * (btc.g * wei_ocb_sz + btc.ocb * wei_kd_sz);

    // Bounds of the current ow block; M/M_tail is the brgemm row count.
    const int ow_b {ow}, ow_e {ow + (is_ow_tail ? jcp.M_tail : jcp.M)};
    iiw_b = ow_b * SW - LP;
    ptr_D = dst_base
            + dst_dsz
                    * (btc.od * dst_h_sz + btc.oh * dst_w_sz
                            + ow_b * jcp.oc_without_padding);
    // Accumulate into the f32 scratch buffer when one is used, otherwise
    // accumulate directly into dst.
    ptr_C = (jcp.use_buffer) ? btc.c_buffer + acc_dsz * (ow_b - ow) * jcp.LDC
                             : static_cast<char *>(ptr_D);

    const auto ow_l = ow_e - ow_b;
    assert(0 <= ow_l && ow_l <= jcp.ow_block);
    // Kernel selector for the output length (M dimension of brgemm).
    const auto ker_i = ow_l - 1;
    // Per-kw virtual-padding amounts for this ow block.
    const dim_t *const __restrict kw_top_vpads
            = owb_kw_top_vpads.data() + btc.owb * KW;
    const dim_t *const __restrict kw_bottom_vpads
            = owb_kw_bottom_vpads.data() + btc.owb * KW;

    // Fills btc.brg_batch with A (src) / B (weights) pointer pairs plus
    // vvpad descriptors for 'n_ic_blocks' ic blocks starting at
    // 'ic_block_s', then invokes the brgemm kernel 'brg_idx'. Relies on
    // kd_b/kd_e/kh_b/kh_e/k_l set by the enclosing kdhw_loop.
    const auto call_brgemm = [&](int brg_idx, int ic_block_s, int n_ic_blocks,
                                     int32_t *src_zp, int32_t *s8s8_comp,
                                     bool do_postops) {
        for (int i_icb = 0; i_icb < n_ic_blocks; i_icb++) {
            const auto ic_off = (ic_block_s + i_icb) * jcp.ic_block;
            const auto src_ic = ic_off;
            const auto wei_ic = ic + ic_off;
            const auto n_icb_off = i_icb * k_l;
            const char *const __restrict src_base_ic
                    = src_base + src_dsz * src_ic;
            const char *const __restrict wei_base_ic
                    = wei_base + wei_dsz * wei_ic * jcp.oc_block;
            brgemm_batch_element_t *const __restrict icb_batch
                    = btc.brg_batch + n_icb_off;

            auto k = 0;
            for (int kd = kd_b; kd < kd_e; kd++) {
                const auto id = iid + kd * DD;
                const char *const __restrict src_base_kd
                        = src_base_ic + src_dsz * id * src_h_sz;
                // maybe_invert() flips the tap index for deconvolution
                // (use_inversion) builds.
                const char *const __restrict wei_base_kd = wei_base_ic
                        + wei_dsz * maybe_invert(kd, KD) * wei_kh_sz;
                for (int kh = kh_b; kh < kh_e; kh++) {
                    const auto ih = iih + kh * DH;
                    const char *const __restrict src_base_kh
                            = src_base_kd + src_dsz * ih * src_w_sz;
                    const char *const __restrict wei_base_kh = wei_base_kd
                            + wei_dsz * maybe_invert(kh, KH) * wei_kw_sz;
                    for (int kw = 0; kw < KW; kw++) {
                        const auto iw = iiw_b + kw * DW;
                        const auto ptr_A = src_base_kh
                                + static_cast<ptrdiff_t>(src_dsz) * iw
                                        * jcp.ngroups * jcp.ic_without_padding;
                        if (jcp.max_vpad) {
                            icb_batch[k].vvpad.top = kw_top_vpads[kw];
                            icb_batch[k].vvpad.bottom = kw_bottom_vpads[kw];
                        }
                        // general wei layout is gOdhwI<block_o><block_i>
                        const auto ptr_B = wei_base_kh
                                + wei_dsz * maybe_invert(kw, KW) * wei_ic_sz;

                        icb_batch[k].ptr.A = ptr_A;
                        icb_batch[k].ptr.B = ptr_B;

                        k++;
                    }
                }
            }
        }

        call_brgemm_kernel(btc, brg_idx, k_l * n_ic_blocks, ptr_C, ptr_D,
                bias_w, g_oc, do_postops, post_ops_binary_rhs_arg_vec.data(),
                btc.src_zp_vals, src_zp, btc.dst_zp_vals, s8s8_comp, false);
    };

    // Processes one kd/kh sub-block: picks the brgemm kernel variant
    // (init vs accumulate x ic-tail vs full) and dispatches call_brgemm for
    // the full ic blocks and the ic tail.
    const auto kdhw_loop = [&]() {
        // Initialize the accumulator only on the very first ic chunk and
        // first kd/kh sub-block; run post-ops only on the last of each.
        const auto do_init = btc.icc == 0 && kd_b == kd_s && kh_b == kh_s;
        const auto do_postwork = _pd->need_postwork
                && btc.icc == (_pd->ic_chunks - 1) && kd_e == kd_f
                && kh_e == kh_f;

        if (ow_e - ow_b <= 0 && !do_init && !do_postwork) return;

        k_l = (kd_e - kd_b) * (kh_e - kh_b) * KW;
        // kernel_idx[do_init][is_ic_tail]
        int kernel_idx[2][2];
        kernel_idx[false][false]
                = _pd->get_brg_idx(k_l, ker_i, false, is_oc_tail, false);
        kernel_idx[true][false]
                = _pd->get_brg_idx(k_l, ker_i, true, is_oc_tail, false);
        kernel_idx[false][true]
                = _pd->get_brg_idx(k_l, ker_i, false, is_oc_tail, true);
        kernel_idx[true][true]
                = _pd->get_brg_idx(k_l, ker_i, true, is_oc_tail, true);

        // Offset into the src-zero-point / s8s8 compensation buffers for
        // this kd/kh sub-block.
        const auto comp_offs = get_comp_offset(
                btc.g, btc.ocb, ow, kd_b, kd_e, kh_b, kh_e, 0, KW);

        if (nb_ic_b > 0) {
            const auto brg_idx = kernel_idx[do_init][false];
            // Postwork is deferred to the ic-tail call when a tail exists.
            call_brgemm(brg_idx, 0, nb_ic_b,
                    jcp.src_zero_point ? &btc.src_zp_comp_ptr[comp_offs]
                                       : nullptr,
                    jcp.s8s8_avx512 ? &btc.s8s8_comp_ptr[comp_offs] : nullptr,
                    do_postwork && !is_ic_tail);
        }

        if (is_ic_tail) {
            // Use the initializing kernel only if no full ic block ran first.
            const auto use_init_ker = (do_init && nb_ic_b == 0);
            const auto brg_ic_tail_idx = kernel_idx[use_init_ker][true];
            call_brgemm(brg_ic_tail_idx, nb_ic_b, 1,
                    jcp.src_zero_point ? &btc.src_zp_comp_ptr[comp_offs]
                                       : nullptr,
                    jcp.s8s8_avx512 ? &btc.s8s8_comp_ptr[comp_offs] : nullptr,
                    do_postwork);
        }
    };

    if (kd_f > kd_s && kh_f > kh_s) {
        // kw values covering full ow_block
        for (kd_b = kd_s; kd_b < kd_f; kd_b += KD_BLOCK) {
            kd_e = nstl::min(kd_f, kd_b + KD_BLOCK);
            for (kh_b = kh_s; kh_b < kh_f; kh_b += KH_BLOCK) {
                kh_e = nstl::min(kh_f, kh_b + KH_BLOCK);
                kdhw_loop();
            }
        }
    } else {
        // Degenerate case: no kernel taps overlap this output block (fully
        // padded); only zero-initialization / post-ops on dst are needed.
        const auto do_init = btc.icc == 0;
        const auto do_postwork
                = _pd->need_postwork && btc.icc == (_pd->ic_chunks - 1);
        perform_outwork(dst_base, dst, btc.c_buffer, bias_w, btc.od, btc.oh, ow,
                g_oc, is_oc_tail, ow, ow, kd_l, kh_l,
                post_ops_binary_rhs_arg_vec.data(), btc.oscales,
                btc.src_zp_vals, btc.src_zp_comp_ptr, btc.dst_zp_vals,
                btc.s8s8_comp_ptr, do_init, do_postwork, false);
    }
}
1941 | |
#undef BRGEMM_CONV_KER_HEADER

// Explicit instantiations for every supported ISA. The second template
// argument enables the use_inversion variant (kernel-tap inversion via
// maybe_invert(), used for deconvolution); it defaults to false when omitted.
template struct brgemm_convolution_fwd_t<avx2>;
template struct brgemm_convolution_fwd_t<avx2_vnni_2>;
template struct brgemm_convolution_fwd_t<avx2_vnni_2, true>;
template struct brgemm_convolution_fwd_t<avx512_core>;
template struct brgemm_convolution_fwd_t<avx512_core, true>;
template struct brgemm_convolution_fwd_t<avx512_core_vnni>;
template struct brgemm_convolution_fwd_t<avx512_core_bf16>;
template struct brgemm_convolution_fwd_t<avx512_core_bf16, true>;
template struct brgemm_convolution_fwd_t<avx512_core_fp16>;
template struct brgemm_convolution_fwd_t<avx512_core_fp16, true>;
template struct brgemm_convolution_fwd_t<avx512_core_amx>;
template struct brgemm_convolution_fwd_t<avx512_core_amx, true>;
template struct brgemm_convolution_fwd_t<avx512_core_amx_fp16>;
template struct brgemm_convolution_fwd_t<avx512_core_amx_fp16, true>;
1958 | |
1959 | } // namespace x64 |
1960 | |
1961 | } // namespace cpu |
1962 | } // namespace impl |
1963 | } // namespace dnnl |
1964 | |
1965 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
1966 | |