1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/nstl.hpp" |
20 | #include "common/type_helpers.hpp" |
21 | #include "common/utils.hpp" |
22 | |
23 | #include "cpu/cpu_primitive.hpp" |
24 | #include "cpu/scale_utils.hpp" |
25 | |
26 | #include "cpu/x64/amx_tile_configure.hpp" |
27 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
28 | #include "cpu/x64/jit_brgemm_1x1_conv.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | namespace x64 { |
34 | |
35 | using namespace dnnl::impl::status; |
36 | using namespace dnnl::impl::memory_tracking::names; |
37 | using namespace dnnl::impl::utils; |
38 | |
39 | using namespace nstl; |
40 | using namespace data_type; |
41 | |
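// Pick the parameter matching the number of spatial dims: ndims == 5 -> 3D,
// ndims == 4 -> 2D, ndims == 3 -> 1D convolution.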
42 | #define ndims_pick(v5, v4, v3) \ |
43 | ((ndims == 5) ? (v5) : (ndims == 4) ? (v4) : (ndims == 3) ? (v3) : 0) |
44 | |
45 | template <cpu_isa_t isa> |
46 | status_t brgemm_1x1_convolution_fwd_t<isa>::pd_t::init(engine_t *engine) { |
47 | using namespace data_type; |
48 | using namespace utils; |
49 | |
50 | const auto src_type = src_md(0)->data_type; |
51 | const auto wei_type = weights_md(0)->data_type; |
52 | const auto dst_type = dst_md(0)->data_type; |
53 | const bool is_int8 = one_of(src_type, u8, s8); |
54 | |
55 | using skip_mask_t = primitive_attr_t::skip_mask_t; |
56 | auto skip_mask = skip_mask_t::post_ops | skip_mask_t::sum_dt |
57 | | skip_mask_t::zero_points_runtime; |
    if (is_int8) skip_mask |= skip_mask_t::scales_runtime;
59 | |
60 | bool ok = is_fwd() && set_default_alg_kind(alg_kind::convolution_direct) |
61 | && expect_data_types(src_type, wei_type, data_type::undef, dst_type, |
62 | data_type::undef) |
63 | && IMPLICATION(is_int8, |
64 | one_of(bias_md_.data_type, data_type::undef, f32, s32, s8, |
65 | u8)) |
66 | && IMPLICATION(!is_int8, |
67 | one_of(bias_md_.data_type, data_type::undef, f32, src_type)) |
68 | && attr()->has_default_values(skip_mask, dst_type) |
69 | && attr()->post_ops_.check_sum_consistent_dt(dst_type) |
70 | && !has_zero_dim_memory() && zero_points_ok() && arg_scales_ok(); |
71 | if (!ok) return status::unimplemented; |
72 | |
73 | CHECK(brgemm_convolution_utils::init_1x1_conf(jcp_, isa, *desc(), src_md_, |
74 | weights_md_, dst_md_, bias_md_, attr_, dnnl_get_max_threads())); |
75 | |
76 | for (int i = 0; i < 16; i++) |
77 | brgs_[i].bcast_dim = brgs_[i].load_dim = brgs_[i].reduce_dim = 0; |
78 | |
79 | const float alpha = 1.0; |
80 | const float beta = 1.0; |
81 | const auto &p = attr()->post_ops_; |
82 | const int sum_idx = p.find(primitive_kind::sum); |
83 | with_sum = (sum_idx != -1); |
84 | sum_scale = with_sum ? p.entry_[sum_idx].sum.scale : 0.0; |
85 | |
86 | ic_chunks = div_up(jcp_.nb_ic, jcp_.nb_ic_blocking); |
87 | need_postwork = jcp_.with_bias || jcp_.with_eltwise || jcp_.with_binary |
            || (is_int8 && wei_type == s8) // oscales needed
89 | || (jcp_.dst_dt != jcp_.acc_dt) || jcp_.with_sum; |
90 | |
91 | int i_init_begin = (ic_chunks == 1) ? 1 : 0; |
92 | int i_init_end = 2; |
93 | |
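    // Initialize brgemm descriptors for all 16 combinations of full/tail
    // blocks along M, N and K plus the two accumulation flavors: i_init == 1
    // gives beta = 0 (overwrite C), i_init == 0 gives beta = 1 (accumulate
    // over the previous ic chunk). With a single ic chunk there is nothing
    // to accumulate, so only the beta = 0 kernels are created.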
94 | for_(int i_M = 0; i_M < 2; i_M++) |
95 | for_(int i_N = 0; i_N < 2; i_N++) |
96 | for_(int i_K = 0; i_K < 2; i_K++) |
97 | for (int i_init = i_init_begin; i_init < i_init_end; i_init++) { |
98 | auto vbeta = (i_init) ? 0 : beta; |
99 | auto vM = (i_M) ? jcp_.M_tail : jcp_.M; |
100 | auto vN = (i_N) ? jcp_.N_tail : jcp_.N; |
101 | auto vK = (i_K) ? jcp_.K_tail : jcp_.K; |
102 | brgemm_t &brg = brgs_[get_brg_idx(i_init, i_M, i_N, i_K)]; |
103 | if (vM == 0 || vN == 0 || vK == 0) continue; |
104 | brgemm_strides_t brg_strides; |
105 | brg_strides.stride_a = jcp_.brg_stride_a; |
106 | brg_strides.stride_b = jcp_.brg_stride_b; |
107 | const auto strides_ptr |
108 | = (jcp_.brg_type == brgemm_strd) ? &brg_strides : nullptr; |
109 | CHECK(brgemm_desc_init(&brg, isa, jcp_.brg_type, src_type, wei_type, |
110 | false, false, brgemm_row_major, alpha, vbeta, jcp_.LDA, |
111 | jcp_.LDB, jcp_.LDC, vM, vN, vK, strides_ptr)); |
112 | |
113 | brgemm_attr_t brgattr; |
114 | brgattr.max_bs = jcp_.gemm_batch_size; |
115 | brgattr.hint_innermost_loop = jcp_.brgemm_bd_loop_innermost |
116 | ? brgemm_bd_loop_innermost |
117 | : brgemm_ld_loop_innermost; |
118 | brgattr.max_top_vpad = jcp_.max_vpad; |
119 | brgattr.max_bottom_vpad = jcp_.max_vpad; |
120 | |
121 | // assuming 2x2 decomposition in amx brgemm kernel |
122 | const auto bd_blocking = 2 * jcp_.amx_h; |
123 | brgattr.hint_expected_A_size = bd_blocking * vK; |
124 | brgattr.hint_expected_B_size = vN * vK; |
125 | brgattr.hint_expected_C_size = bd_blocking * vN; |
126 | |
127 | brgattr.wary_tail_read = false; |
128 | brgattr.use_uker = jcp_.use_uker; |
129 | brgattr.use_interleave_stores = brgattr.use_uker; |
130 | brgattr.hint_prefetching = jcp_.hint_prefetching; |
131 | brgattr.fpmath_mode = attr()->fpmath_mode_; |
        // If post-work is required and there is no intermediate accumulation
        // (ic_chunks == 1), the brgemm kernel always finishes with post-ops,
        // so the code path without them can be skipped.
135 | if (need_postwork && ic_chunks == 1) brgattr.postops_only = true; |
136 | |
137 | CHECK(brgemm_desc_set_attr(&brg, brgattr)); |
138 | auto LDD = jcp_.oc_without_padding; |
139 | brg.with_sum = with_sum; |
140 | CHECK(brgemm_desc_set_postops( |
141 | &brg, attr(), &dst_md_, LDD, jcp_.bia_dt)); |
142 | jcp_.amx_buf_size_per_thread = nstl::max( |
143 | brg.get_wsp_buffer_size(), jcp_.amx_buf_size_per_thread); |
144 | } |
145 | |
146 | brgemm_convolution_utils::set_amx_wsp_per_thread(jcp_); |
147 | auto scratchpad = scratchpad_registry().registrar(); |
148 | brgemm_convolution_utils::init_scratchpad(scratchpad, jcp_); |
149 | if (jcp_.with_scales) |
150 | book_precomputed_scales(scratchpad, attr()->scales_, OC()); |
151 | |
152 | return status::success; |
153 | } |
154 | |
155 | template <cpu_isa_t isa> |
156 | status_t brgemm_1x1_convolution_fwd_t<isa>::init(engine_t *engine) { |
157 | auto ndims = pd()->ndims(); |
    if (ndims < 3 || ndims > 5) assert(!"Invalid ndims!");
159 | |
160 | const auto &jcp = pd()->jcp_; |
161 | |
162 | ID = ndims_pick(jcp.id, 1, 1); |
163 | IH = ndims_pick(jcp.ih, jcp.ih, 1); |
164 | IW = jcp.iw; |
165 | |
166 | OD = ndims_pick(jcp.od, 1, 1); |
167 | OH = ndims_pick(jcp.oh, jcp.oh, 1); |
168 | OW = jcp.ow; |
169 | |
170 | SD = ndims_pick(jcp.stride_d, 1, 1); |
171 | SH = ndims_pick(jcp.stride_h, jcp.stride_h, 1); |
172 | SW = jcp.stride_w; |
173 | |
174 | bia_dsz = jcp.bia_dsz; |
175 | acc_dsz = jcp.acc_dsz; |
176 | src_dsz = jcp.src_dsz; |
177 | wei_dsz = jcp.wei_dsz; |
178 | |
179 | // const variables used for address calculations |
180 | src_w_sz = (dim_t)IW * jcp.ngroups * jcp.ic_without_padding; |
181 | src_h_sz = IH * src_w_sz; |
182 | src_d_sz = ID * src_h_sz; |
183 | dst_w_sz = (dim_t)OW * jcp.oc_without_padding; |
184 | dst_h_sz = OH * dst_w_sz; |
185 | dst_d_sz = OD * dst_h_sz; |
186 | |
187 | const auto src_type = pd()->src_md(0)->data_type; |
188 | |
189 | const auto last_ic_block |
190 | = src_type == f16 ? 1 : data_type_vnni_granularity(src_type); |
191 | |
192 | wei_oc_sz = jcp.wei_plain ? jcp.oc : jcp.oc_block; |
193 | wei_ic_sz = jcp.wei_plain |
194 | ? (dim_t)rnd_up(jcp.ic, last_ic_block) * jcp.oc |
195 | : (dim_t)rnd_up(jcp.ic, last_ic_block) * jcp.oc_block; |
196 | wei_ocb_sz = jcp.wei_plain ? jcp.oc_block * last_ic_block |
197 | : jcp.nb_oc * wei_ic_sz; |
198 | |
199 | for (int i = 0; i < 16; i++) |
200 | brg_kernels_[i] = nullptr; |
201 | |
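    // rtus stands for "reduction to unit stride": for strided 1x1
    // convolutions a copy kernel gathers the used input pixels into a dense
    // per-thread buffer so the convolution becomes a plain GEMM over it.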
202 | if (jcp.is_rtus) { |
203 | CHECK(safe_ptr_assign(rtus_kernel_, |
204 | new jit_avx512_core_brgemm_conv_trans_kernel:: |
205 | jit_avx512_core_brgemm_conv_rtus_kernel_t(jcp))); |
206 | CHECK(rtus_kernel_->create_kernel()); |
207 | } |
208 | int i_init_begin = (pd()->ic_chunks == 1) ? 1 : 0; |
209 | int i_init_end = 2; |
210 | |
211 | const bool is_amx = brgemm_convolution_utils::is_amx(isa); |
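    // Instantiate kernels only for the descriptors that pd_t::init()
    // actually initialized (non-zero bcast/load/reduce dims) and, for AMX,
    // deduplicate their tile-configuration palettes so that runtime tile
    // reconfigurations can be skipped when consecutive kernels share one.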
212 | for_(int i_M = 0; i_M < 2; i_M++) |
213 | for_(int i_N = 0; i_N < 2; i_N++) |
214 | for_(int i_K = 0; i_K < 2; i_K++) |
215 | for (int i_init = i_init_begin; i_init < i_init_end; i_init++) { |
216 | auto brg_idx = get_brg_idx(i_init, i_M, i_N, i_K); |
217 | auto &brg = pd()->brgs_[brg_idx]; |
218 | if (brg.bcast_dim > 0 && brg.load_dim > 0 && brg.reduce_dim > 0 |
219 | && !brg_kernels_[brg_idx]) { |
220 | brgemm_kernel_t *brg_kernel = nullptr; |
221 | CHECK(brgemm_kernel_create(&brg_kernel, brg)); |
222 | CHECK(safe_ptr_assign(brg_kernels_[brg_idx], brg_kernel)); |
223 | if (is_amx) { |
224 | amx_palette_t tmp; |
225 | int &palette_idx = brg_kernel_palette_idx_[brg_idx]; |
226 | palette_idx = -1; |
227 | CHECK(brgemm_init_tiles(brg, tmp.p)); |
228 | // check if it's in set of tile configs |
229 | for (size_t i = 0; i < brg_kernel_palette_.size(); i++) { |
230 | const bool is_match = 0 |
231 | == std::memcmp(brg_kernel_palette_[i].p, tmp.p, |
232 | AMX_PALETTE_SIZE); |
233 | if (is_match) { |
234 | palette_idx = i; |
235 | break; |
236 | } |
237 | } |
238 | // add to set of tile configs if needed |
239 | if (palette_idx == -1) { |
240 | palette_idx = brg_kernel_palette_.size(); |
241 | brg_kernel_palette_.push_back(tmp); |
242 | } |
243 | } |
244 | } |
245 | } |
246 | return status::success; |
247 | } |
248 | |
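// Gather the (strided) input points needed for the current os-block into
// the dense per-thread inp_buffer; inp_buffer_mask remembers which
// (ic chunk, os-block) pairs were already copied so the work is done once.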
249 | template <cpu_isa_t isa> |
250 | void brgemm_1x1_convolution_fwd_t<isa>::maybe_rtus(int ithr, |
251 | const char *__restrict src, char *__restrict inp_buffer, |
252 | uint8_t *__restrict inp_buffer_mask, int g, int n, int icc, int od, |
253 | int oh, int ow) const { |
254 | const auto &jcp = pd()->jcp_; |
255 | if (!jcp.is_rtus) return; |
256 | assert(jcp.is_os_blocking); |
257 | const size_t src_dt_size = jcp.src_dsz; |
258 | |
259 | const auto os = (od * OH + oh) * OW + ow; |
260 | const auto osb = os / jcp.os_block; |
261 | |
262 | uint8_t *bmask = &inp_buffer_mask[icc * jcp.nb_os + osb]; |
263 | if (bmask && *bmask) return; // skip if already masked |
264 | if (bmask) *bmask = 1; // set mask to skip next time |
265 | |
266 | const auto g_ic = g * jcp.ic_without_padding |
267 | + icc * jcp.nb_ic_blocking * jcp.ic_block; |
268 | |
269 | auto call_kernel = [&](int nh, int nw, int od, int oh, int ow) { |
270 | assert(nh == 0 || (nw == 0 && ow == 0)); |
271 | if (utils::everyone_is(0, nh, nw)) return; |
272 | const int id = od * jcp.stride_d; |
273 | const int ih = oh * jcp.stride_h; |
274 | const int iw = ow * jcp.stride_w; |
275 | const auto inp_offset = n * src_d_sz + id * src_h_sz + ih * src_w_sz |
276 | + iw * jcp.ngroups * jcp.ic_without_padding + g_ic; |
277 | auto p = jit_avx512_core_brgemm_conv_trans_kernel:: |
278 | jit_brgemm_conv_trans_kernel_call_s(); |
279 | p.h_count = nh; |
280 | p.owb = nw; |
281 | p.src = src + src_dt_size * inp_offset; |
282 | p.dst = inp_buffer; |
283 | (*rtus_kernel_)(&p); |
284 | inp_buffer += src_dt_size * (nh * jcp.ow + nw) * jcp.LDA; |
285 | }; |
286 | |
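    // Copy `count` output points starting at (od, oh, ow): first complete
    // the current partial row, then whole rows plane by plane, and finally
    // a trailing partial row if anything remains.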
287 | const bool is_os_tail = jcp.os - os < jcp.os_block; |
288 | int count = is_os_tail ? jcp.M_tail : jcp.M; |
289 | |
290 | if (count < OW || ow > 0) { |
291 | // copy to end of row |
292 | const auto nw = nstl::min(count, OW - ow); |
293 | call_kernel(0, nw, od, oh, ow); |
294 | count -= nw; |
295 | if (count == 0) return; |
296 | ow = 0; |
297 | oh = (oh + 1) % OH; |
298 | if (oh == 0) od++; |
299 | } |
300 | |
301 | while (od < OD) { |
302 | // copy to end of column |
303 | const auto nh = nstl::min(count / OW, OH - oh); |
304 | call_kernel(nh, 0, od, oh, ow); |
305 | count -= nh * OW; |
306 | if (count == 0) return; |
307 | oh = (oh + nh) % OH; |
308 | if (oh == 0) od++; |
309 | if (count < OW) { |
310 | // copy partial row |
311 | const auto nw = count; |
312 | call_kernel(0, nw, od, oh, ow); |
313 | return; |
314 | } |
315 | } |
316 | } |
317 | |
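// Run the brgemm kernel(s) for a single (g, n, ocb, od, oh, ow, icc) work
// item: fill the batch with A/B pointers for the ic blocks of this chunk,
// reconfigure AMX tiles only when the palette changes, and execute with
// post-ops on the last ic chunk (or plain accumulation otherwise).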
318 | template <cpu_isa_t isa> |
319 | void brgemm_1x1_convolution_fwd_t<isa>::exec_ker( |
320 | const brgemm_exec_ctx_t &brgemm_ctx, int ithr, |
321 | brgemm_batch_element_t *const __restrict brg_batch, |
322 | char *const c_buffer, const char *inp_buffer, int g, int n, int ocb, |
323 | int od, int oh, int ow, int icc, int *last_palette_idx, |
324 | const float *oscales, int32_t src_zp_vals, int32_t *src_zp_comp, |
325 | int32_t *dst_zp_vals, int32_t *s8s8_compensation) const { |
326 | |
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper weights_d(pd()->weights_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper bias_d(pd()->weights_md(1));
330 | const size_t src_dt_size = types::data_type_size(src_d.data_type()); |
331 | const size_t wei_dt_size = types::data_type_size(weights_d.data_type()); |
332 | const size_t dst_dt_size = types::data_type_size(dst_d.data_type()); |
333 | |
334 | const char *const __restrict src = brgemm_ctx.src; |
335 | const char *const __restrict weights = brgemm_ctx.weights; |
336 | const char *const __restrict bias = brgemm_ctx.bias; |
337 | char *const __restrict dst = brgemm_ctx.dst; |
338 | const std::vector<const void *> &post_ops_binary_rhs_arg_vec |
339 | = brgemm_ctx.post_ops_binary_rhs_arg_vec; |
340 | |
341 | const auto &jcp = pd()->jcp_; |
342 | auto ndims = pd()->ndims(); |
343 | |
344 | const bool is_amx = brgemm_convolution_utils::is_amx(isa); |
345 | char *const wsp_tile = is_amx |
346 | ? brgemm_ctx.wsp_tile + ithr * jcp.amx_buf_size_per_thread |
347 | : nullptr; |
348 | |
349 | const int id = ndims_pick(od * SD, 0, 0); |
350 | const int ih = ndims_pick(oh * SH, oh * SH, 0); |
351 | const int iw = ow * SW; |
352 | |
353 | const int oc = ocb * jcp.oc_block; |
354 | const int g_oc = g * jcp.oc + oc; |
355 | |
356 | const int icb = icc * jcp.nb_ic_blocking; |
357 | const int ic = icb * jcp.ic_block; |
358 | const int g_ic = g * jcp.ic + ic; |
359 | |
360 | const bool kernel_init = (icc == 0); |
361 | |
362 | const auto os = (od * OH + oh) * OW + ow; |
363 | |
364 | const bool is_os_tail = jcp.is_os_blocking ? (jcp.os - os < jcp.os_block) |
365 | : (OW - ow < jcp.ow_block); |
366 | const bool is_oc_tail = (jcp.oc - oc < jcp.oc_block); |
367 | const bool is_ic_tail = (icc == pd()->ic_chunks - 1 |
368 | && ((jcp.ic - ic) % jcp.ic_block != 0)); |
369 | |
370 | const auto src_offset = n * src_d_sz + id * src_h_sz + ih * src_w_sz |
371 | + iw * jcp.ngroups * jcp.ic_without_padding + g_ic; |
372 | const auto src_base |
373 | = jcp.is_rtus ? inp_buffer : src + src_dt_size * src_offset; |
374 | const auto wei_offset = jcp.wei_plain ? g * wei_ic_sz + ocb * wei_ocb_sz |
375 | : g * wei_ocb_sz + ocb * wei_ic_sz; |
376 | const auto wei_base = weights + wei_dt_size * wei_offset; |
377 | const auto ptr_D = dst |
378 | + dst_dt_size |
379 | * (n * dst_d_sz + od * dst_h_sz + oh * dst_w_sz |
380 | + ow * jcp.oc_without_padding + g_oc); |
381 | char *const ptr_C = (jcp.use_buffer) ? c_buffer : (char *)ptr_D; |
382 | |
383 | const auto bias_w |
384 | = bias ? bias + (bias_d.blk_off(g_oc) * bia_dsz) : nullptr; |
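    // number of full ic blocks in this chunk; an ic tail, if present, is
    // processed separately by the tail kernel below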
385 | const auto nb_ic_b = nstl::min(jcp.nb_ic_blocking, jcp.nb_ic - icb) |
386 | - (is_ic_tail ? 1 : 0); |
387 | |
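    // src zero-point and s8s8 compensation are applied only on the last ic
    // chunk, together with the rest of the post-work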
388 | const auto comp_offset = (g * jcp.nb_oc + ocb) * jcp.oc_block; |
389 | int32_t *src_zp_comp_ptr |
390 | = (jcp.src_zero_point && icc == pd()->ic_chunks - 1) |
391 | ? &src_zp_comp[comp_offset] |
392 | : nullptr; |
393 | int32_t *s8s8_comp_ptr = (jcp.s8s8_avx512 && icc == pd()->ic_chunks - 1) |
394 | ? &s8s8_compensation[comp_offset] |
395 | : nullptr; |
396 | |
397 | const auto call_brgemm = [=](int brg_idx, int ic_block_s, int n_ic_blocks, |
398 | bool do_postops) { |
399 | for (int k = 0; k < n_ic_blocks; k++) { |
400 | const auto ic_off = (ic_block_s + k) * jcp.ic_block; |
401 | const auto src_ic = ic_off; |
402 | const auto wei_ic = ic + ic_off; |
403 | const auto ptr_A = src_base + src_dt_size * src_ic; |
404 | const auto ptr_B = wei_base + wei_dt_size * wei_ic * wei_oc_sz; |
405 | brg_batch[k].ptr.A = ptr_A; |
406 | brg_batch[k].ptr.B = ptr_B; |
407 | brg_batch[k].vvpad.top = 0; |
408 | brg_batch[k].vvpad.bottom = 0; |
409 | } |
410 | |
411 | // NOTE: avoid some costly tile reconfigurations here by keeping track |
412 | // of the previous brg kernel tile configuration palette |
413 | // TODO: adjust harness to avoid even more tile reconfigurations |
414 | if (is_amx) { |
415 | const int curr_palette_idx = brg_kernel_palette_idx_[brg_idx]; |
416 | if (curr_palette_idx != *last_palette_idx) { |
417 | amx_tile_configure(brg_kernel_palette_[curr_palette_idx].p); |
418 | *last_palette_idx = curr_palette_idx; |
419 | } |
420 | } |
421 | |
422 | const brgemm_kernel_t *brg_ker = brg_kernels_[brg_idx].get(); |
423 | if (do_postops) { |
424 | const brgemm_post_ops_data_t post_ops_data { |
425 | static_cast<const void *>(bias_w), |
426 | &oscales[jcp.is_oc_scale * g_oc], |
427 | post_ops_binary_rhs_arg_vec.data(), |
428 | static_cast<size_t>(g_oc), 0, dst, 0, |
429 | static_cast<void *>(src_zp_comp_ptr), nullptr, |
430 | static_cast<void *>(dst_zp_vals), false, src_zp_vals}; |
431 | |
432 | void *scratch = is_amx ? static_cast<void *>(wsp_tile) |
433 | : static_cast<void *>(s8s8_comp_ptr); |
434 | brgemm_kernel_execute_postops(brg_ker, n_ic_blocks, brg_batch, |
435 | (void *)ptr_C, (void *)ptr_D, post_ops_data, scratch); |
436 | } else { |
437 | void *scratch = is_amx ? static_cast<void *>(wsp_tile) |
438 | : static_cast<void *>(s8s8_comp_ptr); |
439 | brgemm_kernel_execute( |
440 | brg_ker, n_ic_blocks, brg_batch, (void *)ptr_C, scratch); |
441 | } |
442 | }; |
443 | |
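    // post-work (bias, scales, post-ops, zero points, conversion to dst_dt)
    // is applied only once accumulation over all ic chunks is complete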
444 | const auto do_post_work = (pd()->need_postwork || jcp.use_buffer) |
445 | && icc == pd()->ic_chunks - 1; |
446 | |
447 | if (nb_ic_b > 0) { |
448 | const auto brg_idx |
449 | = get_brg_idx(kernel_init, is_os_tail, is_oc_tail, false); |
450 | call_brgemm(brg_idx, 0, nb_ic_b, do_post_work && !is_ic_tail); |
451 | } |
452 | if (is_ic_tail) { |
453 | const auto use_init_ker = (kernel_init && nb_ic_b == 0); |
454 | const auto brg_idx |
455 | = get_brg_idx(use_init_ker, is_os_tail, is_oc_tail, true); |
456 | |
457 | call_brgemm(brg_idx, nb_ic_b, 1, do_post_work); |
458 | } |
459 | } |
460 | |
461 | template <cpu_isa_t isa> |
462 | status_t brgemm_1x1_convolution_fwd_t<isa>::execute_forward_all( |
463 | const exec_ctx_t &ctx) const { |
464 | |
465 | brgemm_exec_ctx_t brgemm_ctx(ctx, pd()); |
466 | |
467 | const memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor(); |
468 | |
469 | const auto &jcp = pd()->jcp_; |
470 | const bool is_amx = brgemm_convolution_utils::is_amx(isa); |
471 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
472 | |
473 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
474 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
475 | |
476 | const float *oscales = precompute_scales(ctx.get_scratchpad_grantor(), |
477 | src_scales, wei_scales, pd()->OC(), pd()->attr()); |
478 | |
479 | DEFINE_ZERO_POINT_VALUE(src_zero_point, DNNL_ARG_SRC); |
480 | DEFINE_ZERO_POINT_VALUE(dst_zero_point, DNNL_ARG_DST); |
481 | |
    const auto extra_data_offset
            = weights_d.size() - weights_d.additional_buffer_size();
484 | auto w = const_cast<char *>(brgemm_ctx.weights); |
485 | int32_t *s8s8_compensation = (jcp.s8s8_avx512) |
486 | ? reinterpret_cast<int32_t *>(w + extra_data_offset) |
487 | : nullptr; |
488 | int32_t *zp_compensation = (jcp.src_zero_point) |
489 | ? reinterpret_cast<int32_t *>(&w[extra_data_offset]) |
490 | + (jcp.s8s8_avx512 ? jcp.s8s8_comp_buffer_size : 0) |
491 | : nullptr; |
492 | int32_t *dst_zp_vals = jcp.dst_zero_point ? &dst_zero_point : nullptr; |
493 | |
494 | brgemm_batch_element_t *const brg_batch_global |
495 | = (jcp.brg_type != brgemm_strd) |
496 | ? scratchpad.template get<brgemm_batch_element_t>( |
497 | key_brgemm_primitive_batch) |
498 | : nullptr; |
499 | char *const c_buffer_global = (jcp.use_buffer) |
500 | ? scratchpad.template get<char>(key_brgemm_primitive_buffer) |
501 | : nullptr; |
502 | char *inp_buffer_base = (jcp.is_rtus) |
503 | ? scratchpad.template get<char>(key_conv_brgemm_inp_buffer) |
504 | : nullptr; |
505 | uint8_t *inp_buffer_mask_base = (jcp.is_rtus) |
506 | ? scratchpad.template get<uint8_t>(key_conv_brgemm_inp_buffer_mask) |
507 | : nullptr; |
508 | |
509 | if (jcp.is_os_blocking) { |
510 | const int os_chunks = div_up(jcp.nb_os, jcp.nb_os_blocking); |
511 | const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_oc * os_chunks; |
512 | |
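// BRGC_WO expands into the parallel work-partitioning loop; its variadic
// arguments set the nd_iterator nesting order, selected below based on
// jcp.loop_order.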
513 | #define BRGC_WO(...) \ |
514 | parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) { \ |
515 | if (ithr >= work_amount) return; \ |
516 | brgemm_batch_element_t *const brg_batch \ |
517 | = brg_batch_global + (size_t)ithr * jcp.adjusted_batch_size; \ |
518 | char *const c_buffer = (jcp.use_buffer) \ |
519 | ? c_buffer_global + ithr * acc_dsz * jcp.LDC * jcp.M \ |
520 | : nullptr; \ |
521 | char *inp_buffer = (jcp.is_rtus) \ |
522 | ? inp_buffer_base + ithr * src_dsz * jcp.inp_buffer_size \ |
523 | : nullptr; \ |
524 | uint8_t *__restrict inp_buffer_mask = (jcp.is_rtus) \ |
525 | ? inp_buffer_mask_base + ithr * jcp.inp_buffer_mask_size \ |
526 | : nullptr; \ |
527 | int last_n = -1; \ |
528 | int last_g = -1; \ |
529 | int last_palette_idx = -1; \ |
530 | int start {0}, end {0}; \ |
531 | balance211(work_amount, nthr, ithr, start, end); \ |
532 | int n {0}, g {0}, ocb {0}, oss {0}; \ |
533 | nd_iterator_init(start, __VA_ARGS__); \ |
534 | for (auto work = start; work < end; work++) { \ |
535 | if (jcp.is_rtus && (last_n != n || last_g != g)) \ |
536 | std::memset(inp_buffer_mask, 0, jcp.inp_buffer_mask_size); \ |
537 | const auto osb_start = oss * jcp.nb_os_blocking; \ |
538 | const auto osb_range \ |
539 | = nstl::min(jcp.nb_os - osb_start, jcp.nb_os_blocking); \ |
540 | for (int osb = 0; osb < osb_range; osb++) { \ |
541 | const int os = (osb_start + osb) * jcp.os_block; \ |
542 | const int od = os / (OH * OW); \ |
543 | const int oh = (os % (OH * OW)) / OW; \ |
544 | const int ow = os % OW; \ |
545 | char *inp_buffer_sp = (jcp.is_rtus) \ |
546 | ? inp_buffer + src_dsz * os * jcp.LDA \ |
547 | : nullptr; \ |
548 | for (int icc = 0; icc < pd()->ic_chunks; icc++) { \ |
549 | if (jcp.is_rtus) \ |
550 | maybe_rtus(ithr, brgemm_ctx.src, inp_buffer_sp, \ |
551 | inp_buffer_mask, g, n, icc, od, oh, ow); \ |
552 | exec_ker(brgemm_ctx, ithr, brg_batch, c_buffer, \ |
553 | inp_buffer_sp, g, n, ocb, od, oh, ow, icc, \ |
554 | &last_palette_idx, oscales, src_zero_point, \ |
555 | zp_compensation, dst_zp_vals, s8s8_compensation); \ |
556 | } \ |
557 | } \ |
558 | last_n = n; \ |
559 | last_g = g; \ |
560 | nd_iterator_step(__VA_ARGS__); \ |
561 | } \ |
562 | if (is_amx) amx_tile_release(); \ |
563 | }); |
564 | |
565 | if (jcp.loop_order == loop_ndhwgc) |
566 | BRGC_WO(n, jcp.mb, oss, os_chunks, g, jcp.ngroups, ocb, jcp.nb_oc) |
567 | else if (jcp.loop_order == loop_ngcdhw) |
568 | BRGC_WO(n, jcp.mb, g, jcp.ngroups, ocb, jcp.nb_oc, oss, os_chunks) |
569 | else |
570 | assert(!"Unknown loop order" ); |
571 | |
572 | #undef BRGC_WO |
573 | |
574 | } else { |
575 | const int work_amount |
576 | = jcp.mb * jcp.ngroups * jcp.nb_oc * OD * OH * jcp.nb_ow; |
577 | |
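// Same parallel driver as in the os-blocking branch above, but iterating
// over explicit od / oh / ow-block coordinates.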
578 | #define BRGC_WO(...) \ |
579 | parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) { \ |
580 | if (ithr >= work_amount) return; \ |
581 | brgemm_batch_element_t *const brg_batch \ |
582 | = brg_batch_global + (size_t)ithr * jcp.adjusted_batch_size; \ |
583 | char *const c_buffer = (jcp.use_buffer) \ |
584 | ? c_buffer_global + ithr * acc_dsz * jcp.LDC * jcp.M \ |
585 | : nullptr; \ |
586 | int last_palette_idx = -1; \ |
587 | int start {0}, end {0}; \ |
588 | balance211(work_amount, nthr, ithr, start, end); \ |
589 | int n {0}, g {0}, ocb {0}, od {0}, oh {0}, owb {0}; \ |
590 | nd_iterator_init(start, __VA_ARGS__); \ |
591 | for (auto work = start; work < end; work++) { \ |
592 | for (int icc = 0; icc < pd()->ic_chunks; icc++) { \ |
593 | const int ow = owb * jcp.ow_block; \ |
594 | exec_ker(brgemm_ctx, ithr, brg_batch, c_buffer, nullptr, g, n, \ |
595 | ocb, od, oh, ow, icc, &last_palette_idx, oscales, \ |
596 | src_zero_point, zp_compensation, dst_zp_vals, \ |
597 | s8s8_compensation); \ |
598 | } \ |
599 | nd_iterator_step(__VA_ARGS__); \ |
600 | } \ |
601 | if (is_amx) amx_tile_release(); \ |
602 | }); |
603 | |
604 | if (jcp.loop_order == loop_ndhwgc) |
605 | BRGC_WO(n, jcp.mb, od, OD, oh, OH, owb, jcp.nb_ow, g, jcp.ngroups, |
606 | ocb, jcp.nb_oc) |
607 | else if (jcp.loop_order == loop_ngcdhw) |
608 | BRGC_WO(n, jcp.mb, g, jcp.ngroups, ocb, jcp.nb_oc, od, OD, oh, OH, |
609 | owb, jcp.nb_ow) |
610 | else |
611 | assert(!"Unknown loop order" ); |
612 | |
613 | #undef BRGC_WO |
614 | } |
615 | |
616 | return status::success; |
617 | } |
618 | |
619 | template struct brgemm_1x1_convolution_fwd_t<avx2>; |
620 | template struct brgemm_1x1_convolution_fwd_t<avx2_vnni_2>; |
621 | template struct brgemm_1x1_convolution_fwd_t<avx512_core>; |
622 | template struct brgemm_1x1_convolution_fwd_t<avx512_core_vnni>; |
623 | template struct brgemm_1x1_convolution_fwd_t<avx512_core_bf16>; |
624 | template struct brgemm_1x1_convolution_fwd_t<avx512_core_fp16>; |
625 | template struct brgemm_1x1_convolution_fwd_t<avx512_core_amx>; |
626 | template struct brgemm_1x1_convolution_fwd_t<avx512_core_amx_fp16>; |
627 | |
628 | } // namespace x64 |
629 | } // namespace cpu |
630 | } // namespace impl |
631 | } // namespace dnnl |
632 | |
633 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
634 | |