/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/memory.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"

#include "cpu/x64/injectors/injector_utils.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp"
#include "cpu/x64/jit_uni_1x1_conv_utils.hpp"

#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::utils;
using namespace dnnl::impl::data_type;
using namespace dnnl::impl::prop_kind;
using namespace Xbyak;

template <typename Vmm>
_jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::
        _jit_avx512_core_x8s8s32x_1x1_conv_kernel(
                const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr,
                const memory_desc_t &dst_md)
    : jit_generator(jit_name())
    , jcp(ajcp)
    , attr_(attr)
    , postops_injector_(nullptr) {
    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
        using namespace binary_injector;
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        static constexpr unsigned helper_vmm_idx = 31;
        const size_t oc_block_tail = jcp.oc_block % isa_simd_width_;
        const size_t tail_size = oc_block_tail
                ? oc_block_tail
                : jcp.oc_without_padding % isa_simd_width_;
        static constexpr bool use_exact_tail_scalar_bcast = true;

        const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
                r14, r15, r13, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md), tail_size, postops_mask,
                use_exact_tail_scalar_bcast};
        const static_params_t static_params {
                this->param1, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx512_core, Vmm>>(
                this, jcp.post_ops, static_params);
    }
    if (jcp.dst_dt == data_type::bf16 && !isa_has_bf16(jcp.isa))
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                bf16_emu_reserv_4, bf16_emu_reserv_5, bf16_emu_reserv_5);
}

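// Naming convention used throughout the 1x1 kernels: the convolution is
// treated as an implicit GEMM where the "bcast" dim is the src/dst spatial
// dimension, the "load" dim is oc and the "reduce" dim is ic. bcast_loop()
// walks one block of the bcast dim in steps of jcp.ur rows and calls
// reduce_loop() for every step.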
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::bcast_loop(
        int load_loop_blk) {
    mov(aux1_reg_bcast_data, reg_bcast_data);
    mov(aux_reg_bcast_data, reg_bcast_data);

    mov(aux_reg_output_data, reg_output_data);
    mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_off));

    Label bcast_loop;
    Label bcast_loop_tail;

    cmp(bcast_loop_iter, jcp.ur);
    jl(bcast_loop_tail, T_NEAR);

    L(bcast_loop);
    {
        assert(jcp.bcast_block % jcp.ur == 0);
        int num_substeps = jcp.bcast_block / jcp.ur;
        assert(num_substeps > 0 && num_substeps < 10);
        for (int i = 0; i < num_substeps; i++) {
            reduce_loop(load_loop_blk, jcp.ur, false);
            if (i < num_substeps - 1) {
                add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data, jcp.bcast_loop_output_substep);
            } else {
                add(aux1_reg_bcast_data,
                        jcp.bcast_loop_bcast_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_bcast_substep);
                int output_offset = jcp.bcast_loop_output_step
                        - (num_substeps - 1) * jcp.bcast_loop_output_substep;

                add(aux_reg_output_data, output_offset);
            }
        }
        sub(bcast_loop_iter, jcp.bcast_block);
        cmp(bcast_loop_iter, jcp.bcast_block);
        jge(bcast_loop, T_NEAR);
    }

    L(bcast_loop_tail);
    if (jcp.ur_tail) {
        Label bcast_loop_tail_out;
        cmp(bcast_loop_iter, 0);
        jz(bcast_loop_tail_out, T_NEAR);
        reduce_loop(load_loop_blk, jcp.ur_tail, true);
        L(bcast_loop_tail_out);
    }
}

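// Load `op` (of data type `type_in`) into `vmm_in` and convert it to f32.
// When mask_flag is set the load goes through k_load_dim_mask with
// zero-masking so the oc tail never reads past the end of the row.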
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::cvt2ps(data_type_t type_in,
        const Vmm vmm_in, const Xbyak::Operand &op, bool mask_flag) {
    using namespace data_type;
    const Vmm vmm = mask_flag ? vmm_in | k_load_dim_mask | T_z : vmm_in;
    switch (type_in) {
        case f32:
        case s32: vmovups(vmm, op); break;
        case bf16: vpmovzxwd(vmm, op); break;
        case s8: vpmovsxbd(vmm, op); break;
        case u8: vpmovzxbd(vmm, op); break;
        default: assert(!"unsupported data type");
    }
    if (one_of(type_in, s32, s8, u8))
        vcvtdq2ps(vmm_in, vmm_in);
    else if (type_in == bf16)
        vpslld(vmm_in, vmm_in, 16);
}

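// Helper that visits every (i_load, i_ur) accumulator pair and reports
// whether the masked (oc tail) path must be taken for that load block.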
template <typename F>
static void iterate(const int load_loop_blk, const int ur,
        const bool last_oc_block_flag, const bool force_masking, const F &f) {
    for (int i_load = 0; i_load < load_loop_blk; i_load++) {
        const bool mask_flag = force_masking
                || (last_oc_block_flag && i_load + 1 == load_loop_blk);
        for (int i_ur = 0; i_ur < ur; i_ur++)
            f(mask_flag, i_load, i_ur);
    }
}
template <typename F>
static void iterate(const int load_loop_blk, const int ur,
        const bool last_oc_block_flag, const F &f) {
    iterate(load_loop_blk, ur, last_oc_block_flag, false, f);
}
template <typename F>
static void iterate(const int load_loop_blk, const int ur, const F &f) {
    iterate(load_loop_blk, ur, false, false, f);
}

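// Address of the dst element written by accumulator (i_load, i_ur). With a
// fused depthwise convolution the rows are packed by the load blocking,
// otherwise they follow the plain channels-last oc stride.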
template <typename Vmm>
Address _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::output_ptr(
        const int i_load, const int i_ur) {
    const size_t ur_stride = jcp.with_dw_conv
            ? jcp.nb_load_blocking * jcp.oc_block * i_ur
            : jcp.oc_without_padding * jcp.ngroups * i_ur;

    return EVEX_compress_addr(aux_reg_output_data,
            jcp.typesize_out * (ur_stride + i_load * jcp.load_block));
};

template <typename Vmm>
int _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::vreg_accum_idx(
        const int load_loop_blk, int i_load, int i_ur) const {
    return (i_ur * load_loop_blk + i_load);
};

template <typename Vmm>
Vmm _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::vreg_accum(
        const int load_loop_blk, int i_load, int i_ur) const {
    return Vmm(vreg_accum_idx(load_loop_blk, i_load, i_ur));
};

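// Sum post-op: reload the previously stored dst values, optionally subtract
// the sum zero-point and accumulate them into the accumulators scaled by
// sum_scale (a plain vaddps when the scale is 1.f). The code is registered
// as a lambda with the post-ops injector so it runs in post-op order.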
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::apply_sum(
        const int load_loop_blk, const int ur, const bool mask_flag_in,
        const float *p_sum_scale, const int32_t *p_sum_zp) {
    if (jcp.with_sum) {
        const float sum_scale = *p_sum_scale;
        const int32_t sum_zp = *p_sum_zp;
        const auto sum_injector_lam
                = [this, sum_scale, sum_zp, load_loop_blk](const bool mask_flag,
                          const int i_load, const int i_ur) {
                      const auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                      cvt2ps(jcp.sum_dt, vmm_prev_dst, output_ptr(i_load, i_ur),
                              mask_flag);
                      if (sum_zp != 0) vsubps(vmm_prev_dst, vmm_tmp);
                      if (sum_scale == 1.f)
                          vaddps(r, vmm_prev_dst);
                      else
                          vfmadd231ps(
                                  r, vmm_prev_dst, zword_b[reg_ptr_sum_scale]);
                  };
        const auto sum_injector = [=]() {
            iterate(load_loop_blk, ur, mask_flag_in, sum_injector_lam);
        };
        if (sum_zp != 0) vcvtdq2ps(vmm_tmp, ptr_b[rsp + reg_ptr_sum_zp_off]);
        postops_injector_->set_lambda_injector(
                primitive_kind::sum, sum_injector);
    }
}

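// Dispatch the post-op chain (sum, eltwise, binary) over all live
// accumulators. Binary post-ops need per-vmm dst offsets plus a separate
// tail path that is taken only for the last oc block.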
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::apply_postops(
        const int load_loop_blk, const int ur, const bool mask_flag_in,
        const float *p_sum_scale, const int32_t *p_sum_zp) {
    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {

        apply_sum(load_loop_blk, ur, mask_flag_in, p_sum_scale, p_sum_zp);

        injector_utils::vmm_index_set_t vmm_idxs;
        if (jcp.with_binary) {
            binary_injector::rhs_arg_dynamic_params_t rhs_arg_params,
                    rhs_arg_params_tail;
            const auto mask_tail = jcp.oc_without_padding % jcp.load_block;
            const bool oc_blk_is_smaller_than_vmm
                    = jcp.oc_block < isa_simd_width_;
            iterate(load_loop_blk, ur, mask_tail, oc_blk_is_smaller_than_vmm,
                    [&](const bool mask_flag, const int i_load,
                            const int i_ur) {
                        const int ur_stride = jcp.with_dw_conv
                                ? jcp.nb_load_blocking * jcp.oc_block * i_ur
                                : jcp.oc_without_padding * jcp.ngroups * i_ur;
                        const size_t aux_output_l_off = jcp.typesize_out
                                * (ur_stride + i_load * jcp.load_block);
                        const auto vmm_idx
                                = vreg_accum_idx(load_loop_blk, i_load, i_ur);
                        vmm_idxs.emplace(vmm_idx);

                        rhs_arg_params_tail.vmm_idx_to_out_reg.emplace(
                                vmm_idx, aux_reg_output_data);
                        rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace(
                                vmm_idx, aux_output_l_off);
                        if (mask_flag)
                            rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx);
                    });
            rhs_arg_params = rhs_arg_params_tail;
            rhs_arg_params.vmm_tail_idx_.clear();

            mov(abi_param1, EVEX_compress_addr(rsp, reg_abi_param1_backup));

            Label postops_done;
            if (mask_tail || oc_blk_is_smaller_than_vmm) {
                Label postops_no_tail;
                if (mask_tail) {
                    test(reg_reduce_pos_flag, FLAG_OC_LAST);
                    jz(postops_no_tail, T_NEAR);
                    cmp(reg_load_loop_work, 0);
                    jg(postops_no_tail, T_NEAR);
                }
                postops_injector_->compute_vector_range(
                        vmm_idxs, rhs_arg_params_tail);
                jmp(postops_done, T_NEAR);
                L(postops_no_tail);
            }
            postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
            L(postops_done);

        } else {
            iterate(load_loop_blk, ur,
                    [&](const bool, const int i_load, const int i_ur) {
                        vmm_idxs.emplace(
                                vreg_accum_idx(load_loop_blk, i_load, i_ur));
                    });
            postops_injector_->compute_vector_range(vmm_idxs);
        }
    }
}

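// reduce_loop() accumulates one bcast block (ur rows) against load_loop_blk
// oc blocks over the whole ic (reduce) dimension, then runs store(), which
// converts the s32 accumulators to f32, applies compensations, scales,
// post-ops and zero-points, and writes the result in the dst data type.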
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::reduce_loop(
        int load_loop_blk, int ur, bool wraparound) {
    auto vreg_load
            = [=](int i_load) { return Vmm(ur * load_loop_blk + i_load); };

    auto bias_ptr = [=](int i_load) {
        return EVEX_compress_addr(
                reg_bias_data, jcp.typesize_bia * jcp.oc_block * i_load);
    };

    auto comp_ptr = [=](int i_load) {
        return EVEX_compress_addr(
                reg_comp_data, sizeof(int32_t) * jcp.oc_block * i_load);
    };

    auto scale_ptr = [=](int i_load) {
        return EVEX_compress_addr(reg_ptr_scales,
                jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load));
    };

    auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
        assert(i_ur < jcp.ur);
        assert(i_reduce <= jcp.reduce_loop_unroll);
        assert(jcp.reduce_loop_unroll == jcp.reduce_block);

        int offt = (jcp.ic_without_padding * i_ur * jcp.ngroups + i_reduce);

        return EVEX_compress_addr(
                aux_reg_bcast_data, jcp.typesize_in * offt, bcast);
    };

    auto load_ptr = [=](int i_reduce, int i_load) {
        int u0 = i_reduce % jcp.reduce_loop_unroll;
        int u1 = i_reduce / jcp.reduce_loop_unroll;

        int offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block;

        return EVEX_compress_addr(aux_reg_load_data,
                u1 * jcp.reduce_loop_load_step + jcp.typesize_in * offt);
    };

    auto init = [=]() {
        for (int i_load = 0; i_load < load_loop_blk; ++i_load)
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                vpxord(r, r, r);
            }
        if (jcp.signed_input) {
            mov(reg_scratch, -128);
            vpbroadcastb(vmm_shift, reg_scratch.cvt8());
        }
    };

    auto store = [=](const bool mask_flag_in) {
        const auto &p = attr_.post_ops_;
        const int sum_idx = p.find(primitive_kind::sum);
        const float *p_sum_scale = nullptr;
        const int32_t *p_sum_zp = nullptr;
        if (sum_idx != -1) {
            p_sum_scale = &p.entry_[sum_idx].sum.scale;
            p_sum_zp = &p.entry_[sum_idx].sum.zero_point;
        }
        const auto p_sum_scale_val = p_sum_scale ? *p_sum_scale : 1.f;
        const auto p_sum_zp_val = p_sum_zp ? *p_sum_zp : 0;
        const bool is_scale_or_zp_sum
                = p_sum_zp_val != 0 || p_sum_scale_val != 1.f;
        mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
        mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
        if (is_scale_or_zp_sum) {
            mov(EVEX_compress_addr(rsp, reg_load_data_off), reg_load_data);
            if (p_sum_zp_val != 0) {
                mov(reg_load_data, p_sum_zp_val);
                mov(ptr[rsp + reg_ptr_sum_zp_off], reg_load_data);
            }
            if (p_sum_scale_val != 1.f)
                mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
        }
        if (jcp.signed_input && (!jcp.has_vnni)) {
            mov(reg_scratch, float2int(jcp.wei_adj_scale));
        }
        if (jcp.src_zero_point) {
            mov(reg_zp_compensation,
                    EVEX_compress_addr(rsp, reg_zp_compensation_off));
            mov(reg_src_zero_point,
                    EVEX_compress_addr(rsp, reg_src_zero_point_off));
        }
        for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
            const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1;
            auto vmm_bias = vmm_tmp;
            auto vmm_comp = vmm_bcast;
            if (jcp.with_bias) {
                if (jcp.signed_input || jcp.dst_scale)
                    mov(reg_bias_data,
                            EVEX_compress_addr(rsp, reg_bias_data_off));
                cvt2ps(jcp.bia_dt, vmm_bias, bias_ptr(i_load), mask_flag);
            }
            if (jcp.signed_input) {
                mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off));
                cvt2ps(data_type::s32, vmm_comp, comp_ptr(i_load), mask_flag);
            }
            if (jcp.src_zero_point) {
                // zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32
                const int zp_offset = sizeof(int32_t) * i_load * jcp.load_block;
                vmovups(vmm_zp,
                        EVEX_compress_addr(reg_zp_compensation, zp_offset));
                vpmulld(vmm_zp, vmm_zp,
                        EVEX_compress_addr(
                                reg_src_zero_point, 0, jcp.zp_src_is_common));
                // upscale to f32
                const Vmm vmm_
                        = mask_flag ? vmm_zp | k_load_dim_mask | T_z : vmm_zp;
                vcvtdq2ps(vmm_, vmm_);
            }
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                vcvtdq2ps(r, r);
                if (jcp.signed_input) vaddps(r, r, vmm_comp);
                if (jcp.src_zero_point) vaddps(r, r, vmm_zp);

                const Vmm mask_vmm = mask_flag ? r | k_load_dim_mask | T_z : r;
                vmulps(mask_vmm, r, scale_ptr(i_load));

                if (jcp.with_bias) vaddps(r, r, vmm_bias);
            }
        }

        apply_postops(load_loop_blk, ur, mask_flag_in, p_sum_scale, p_sum_zp);

        if (jcp.dst_scale) {
            mov(reg_ptr_dst_scale, EVEX_compress_addr(rsp, reg_dst_scale_off));
            vmovups(vmm_dst_scale, EVEX_compress_addr(reg_ptr_dst_scale, 0));

            /* Apply dst scale to accumulator */
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                const bool mask_flag
                        = mask_flag_in && i_load == load_loop_blk - 1;
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    const auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                    const Vmm mask_vmm
                            = mask_flag ? r | k_load_dim_mask | T_z : r;
                    vmulps(mask_vmm, r, vmm_dst_scale);
                }
            }
        }

        if (jcp.dst_zero_point) {
            mov(reg_dst_zero_point,
                    EVEX_compress_addr(rsp, reg_dst_zero_point_off));
            vcvtdq2ps(vmm_zp, EVEX_compress_addr(reg_dst_zero_point, 0, true));

            /* Add dst zero_point to accumulator */
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    const auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                    vaddps(r, r, vmm_zp);
                }
            }
        }

        // Properly saturate the accumulators for integer datatypes
        if (one_of(jcp.dst_dt, u8, s8, s32)) {
            init_saturate_f32(vmm_zero, vmm_saturation,
                    reg_ptr_saturation_ubound, f32, jcp.dst_dt);
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                    saturate_f32(r, vmm_zero, vmm_saturation, jcp.dst_dt);
                    vcvtps2dq(r, r);
                }
            }
        }

        if (jcp.dst_dt == data_type::bf16 && !isa_has_bf16(jcp.isa))
            bf16_emu_->init_vcvtneps2bf16();

        // store to the destination
        if (jcp.dst_dt == data_type::bf16 && isa_has_bf16(jcp.isa)) {
            // Optimization: use single store instruction for pair
            // of the nearest vectors along LOAD dimension
            for (int i_ur = 0; i_ur < ur; i_ur++) {
                int i_load = 0;
                for (; i_load < rnd_dn(load_loop_blk, 2); i_load += 2) {
                    auto vmm_dst = vreg_accum(load_loop_blk, i_load, i_ur);
                    auto vmm_dst_next
                            = vreg_accum(load_loop_blk, i_load + 1, i_ur);
                    vcvtne2ps2bf16(vmm_dst, vmm_dst_next, vmm_dst);
                    bool mask_flag
                            = mask_flag_in && i_load + 2 == load_loop_blk;
                    vmovdqu16(output_ptr(i_load, i_ur),
                            maybe_mask_vmm(vmm_dst, mask_flag));
                }
                if (load_loop_blk % 2 != 0) {
                    auto vmm_accum = vreg_accum(load_loop_blk, i_load, i_ur);
                    auto vmm_down = Vmm_down_t(vmm_accum.getIdx());
                    vcvtneps2bf16(vmm_down, vmm_accum);
                    vmovdqu16(output_ptr(i_load, i_ur),
                            maybe_mask_vmm_down(vmm_down,
                                    jcp.ic_block == 4 || mask_flag_in));
                }
            }
        } else {
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                const bool mask_flag
                        = mask_flag_in && i_load == load_loop_blk - 1;
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                    const Vmm r_vmm = mask_flag ? r | k_load_dim_mask : r;

                    switch (jcp.dst_dt) {
                        case data_type::f32:
                        case data_type::s32:
                            vmovups(output_ptr(i_load, i_ur), r_vmm);
                            break;
                        case data_type::s8:
                            vpmovsdb(output_ptr(i_load, i_ur), r_vmm);
                            break;
                        case data_type::u8:
                            vpmovusdb(output_ptr(i_load, i_ur), r_vmm);
                            break;
                        case data_type::bf16: {
                            bf16_emu_->vcvtneps2bf16(
                                    ymm_store, Zmm(r.getIdx()));
                            vmovdqu16(output_ptr(i_load, i_ur),
                                    maybe_mask_vmm_down(vmm_store(),
                                            jcp.ic_block == 4 || mask_flag));
                        } break;
                        default: assert(!"unknown dst_dt");
                    }
                }
            }
        }
        mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
        if (is_scale_or_zp_sum)
            mov(reg_load_data, EVEX_compress_addr(rsp, reg_load_data_off));
    };

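    // One int8 FMA step: with VNNI a single vpdpbusd performs u8*s8 with s32
    // accumulation; without VNNI it is emulated via vpmaddubsw (u8*s8 -> s16
    // pairs) followed by vpmaddwd against vmm_one (words of 0x0001) to widen
    // and reduce to s32 before the final add.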
    auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
        if (jcp.has_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei);
        } else {
            vpmaddubsw(vmm_tmp, vreg_src, vreg_wei);
            vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
            vpaddd(vreg_acc, vreg_acc, vmm_tmp);
        }
    };

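    // fma_block processes reduce_loop_unroll ic values in steps of 4 (four
    // int8 values packed per dword): broadcast 4 src values per row, then
    // multiply-accumulate them against every loaded weight block. The last
    // block handles the ic tail with a partial byte load.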
    auto fma_block = [=](bool last_block) {
        int reduce_step = 4;
        int ic_tail_size = jcp.ic_without_padding % reduce_step;
        int loop_unroll = last_block && jcp.ic != jcp.ic_without_padding
                ? rnd_up(jcp.ic_without_padding % jcp.ic_block, reduce_step)
                : jcp.reduce_loop_unroll;
        for (int i_reduce = 0; i_reduce < loop_unroll;
                i_reduce += reduce_step) {
            for (int i_load = 0; i_load < load_loop_blk; ++i_load)
                vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load));
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                if (last_block && ic_tail_size != 0
                        && i_reduce == loop_unroll - reduce_step) {
                    Xmm xmm_bcast = Xmm(vmm_bcast.getIdx());
                    load_bytes(xmm_bcast, aux_reg_bcast_data,
                            jcp.ic_without_padding * i_ur + i_reduce,
                            ic_tail_size);
                    vpbroadcastd(vmm_bcast, xmm_bcast);
                } else {
                    vpbroadcastd(vmm_bcast, bcast_ptr(i_reduce, i_ur, false));
                }
                if (jcp.signed_input) vpsubb(vmm_bcast, vmm_bcast, vmm_shift);
                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    compute(vreg_accum(load_loop_blk, i_load, i_ur),
                            vreg_load(i_load), vmm_bcast);
                }
            }
        }
    };

    Label reduce_loop;
    Label reduce_loop_tail;

    mov(aux_reg_load_data, reg_load_data);

    mov(aux_reg_bcast_data, aux1_reg_bcast_data);
    init();

    mov(reduce_loop_iter, reg_reduce_loop_work);
    sub(reduce_loop_iter, jcp.reduce_loop_unroll);
    jle(reduce_loop_tail, T_NEAR);

    L(reduce_loop);
    {
        fma_block(false);
        add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jg(reduce_loop, T_NEAR);
    }

    L(reduce_loop_tail);
    if (jcp.ic != jcp.ic_without_padding) {
        fma_block(true);
    } else {
        fma_block(false);
    }

    if (jcp.oc_without_padding != jcp.oc) {
        Label end_store, common_store;
        mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);

        /* Check if it is the last load_loop_blk */
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
        cmp(reg_load_loop_work, 0);
        jg(common_store, T_NEAR);

        /* Check if it is the last ocb */
        test(reg_reduce_pos_flag, FLAG_OC_LAST);
        jz(common_store, T_NEAR);

        store(true);
        jmp(end_store, T_NEAR);

        L(common_store);
        store(false);

        L(end_store);

        add(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
    } else {
        store(false);
    }
}

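// generate() emits the kernel entry: it sets up the ones constant used by
// the non-VNNI path, backs up the call arguments on the stack, prepares the
// oc tail opmasks, and then emits one specialization of the load loop per
// supported load_loop_blk (1..6), dispatching on the remaining load work.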
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::generate() {

    preamble();

    const int simd_w = jcp.ic_block;
    xor_(reg_scratch, reg_scratch);
    Reg16 _t = reg_scratch.cvt16();
    mov(_t, 0x1);
    vpbroadcastw(vmm_one, _t);

    sub(rsp, stack_space_needed);
    if (jcp.with_binary) {
        const auto zeroed_reg = r15;
        xor_(zeroed_reg, zeroed_reg);
        mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off), zeroed_reg);
        mov(EVEX_compress_addr(rsp, reg_abi_param1_backup), abi_param1);
    }

    if (jcp.with_bias) mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
    if (jcp.signed_input) {
        mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
        mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]);
        mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data);
    }
    if (jcp.src_zero_point) {
        mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
        mov(EVEX_compress_addr(rsp, reg_zp_compensation_off),
                reg_zp_compensation);
        mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
        mov(EVEX_compress_addr(rsp, reg_src_zero_point_off),
                reg_src_zero_point);
    }
    if (jcp.dst_scale) {
        if (!jcp.signed_input)
            mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
        mov(reg_ptr_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
        mov(EVEX_compress_addr(rsp, reg_dst_scale_off), reg_ptr_dst_scale);
    }
    if (jcp.dst_zero_point) {
        mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
        mov(EVEX_compress_addr(rsp, reg_dst_zero_point_off),
                reg_dst_zero_point);
    }
    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
    mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
    mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
    mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
    mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);

    mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
    mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
    mov(EVEX_compress_addr(rsp, bcast_loop_work_off), reg_bcast_loop_work);
    mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
    mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);

    if (jcp.ic_block == 4 && jcp.dst_dt == data_type::bf16) {
        Reg32 reg_tail_32 = reg_load_dim_tail_mask.cvt32();
        mov(reg_tail_32, (1 << jcp.ic_block) - 1);
        kmovb(k_load_dim_tail_mask, reg_tail_32);
    }

    const int load_dim_tail
            = (one_of(jcp.prop_kind, forward_training, forward_inference)
                              ? jcp.oc_without_padding
                              : jcp.load_dim)
            % jcp.load_block;
    const bool use_extended_mask
            = jcp.dst_dt == data_type::bf16 && isa_has_bf16(jcp.isa);
    if (load_dim_tail) {
        Reg32 reg_tail_32 = reg_load_dim_tail_mask.cvt32();
        mov(reg_tail_32, (1 << load_dim_tail) - 1);
        kmovw(k_load_dim_tail_mask, reg_tail_32);
        kmovw(postops_mask, reg_tail_32);

        if (use_extended_mask) {
            mov(reg_tail_32.cvt32(),
                    (1 << (load_dim_tail + jcp.load_block)) - 1);
            kmovd(k_load_dim_tail_mask_extended, reg_tail_32.cvt32());
        }
    } else if (jcp.with_binary)
        if (jcp.oc_block != isa_simd_width_) {
            const int mask = (1 << jcp.oc_block) - 1;
            const Reg32 reg_tail_32 = reg_load_dim_tail_mask.cvt32();
            mov(reg_tail_32, mask);
            kmovw(postops_mask, reg_tail_32);
        }

    auto load_loop_body = [=](int load_loop_blk) {
        if (load_dim_tail) {
            kxnorw(k_load_dim_mask, k_load_dim_mask, k_load_dim_mask);
            if (use_extended_mask)
                kxnord(k_load_dim_mask_extended, k_load_dim_mask_extended,
                        k_load_dim_mask_extended);
            Label no_update_mask;
            test(reg_reduce_pos_flag, FLAG_OC_LAST);
            jz(no_update_mask, T_NEAR);
            cmp(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
            jg(no_update_mask, T_NEAR);
            kmovw(k_load_dim_mask, k_load_dim_tail_mask);
            if (use_extended_mask)
                kmovd(k_load_dim_mask_extended, k_load_dim_tail_mask_extended);
            L(no_update_mask);
        } else if (jcp.ic_block == 4 && jcp.dst_dt == data_type::bf16) {
            kmovw(k_load_dim_mask, k_load_dim_tail_mask);
        }

        bcast_loop(load_loop_blk);
        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
        if (jcp.with_bias) {
            if (jcp.signed_input || jcp.dst_scale)
                mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_off));
            add(reg_bias_data,
                    load_loop_blk * jcp.load_block * jcp.typesize_bia);
            if (jcp.signed_input || jcp.dst_scale)
                mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
        }
        if (jcp.with_binary) {
            mov(reg_scratch,
                    EVEX_compress_addr(rsp, reg_binary_post_op_acc_off));
            add(reg_scratch, jcp.load_block * load_loop_blk);
            mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off),
                    reg_scratch);
        }
        if (jcp.signed_input) {
            mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off));
            add(reg_comp_data,
                    load_loop_blk * jcp.load_block * sizeof(int32_t));
            mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data);
        }
        if (jcp.src_zero_point) {
            mov(reg_zp_compensation,
                    EVEX_compress_addr(rsp, reg_zp_compensation_off));
            add(reg_zp_compensation,
                    load_loop_blk * jcp.load_block * sizeof(int32_t));
            mov(EVEX_compress_addr(rsp, reg_zp_compensation_off),
                    reg_zp_compensation);
        }
        mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
        mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
        add(reg_ptr_scales,
                jcp.is_oc_scale * load_loop_blk * jcp.load_block
                        * sizeof(float));
        mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
        mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
        add(reg_output_data, load_loop_blk * jcp.load_block * jcp.typesize_out);
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
    };

    Label load_loop_blk[7];

    static const int ur_cases_fma_expl_bcast[] = {2, 5, 6, 9, 14, 32};
    const int size_ur_cases_fma = sizeof(ur_cases_fma_expl_bcast);
    const int *ur_cases_fma = ur_cases_fma_expl_bcast;
    const int *ur_cases = ur_cases_fma;
    const int num_ur_cases = (size_ur_cases_fma) / sizeof(*ur_cases);

    for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
        int label_idx = num_ur_cases - ur_idx - 1;
        if (jcp.ur <= ur_cases[ur_idx]) {
            cmp(reg_load_loop_work, simd_w * (label_idx + 1));
            jle(load_loop_blk[label_idx], T_NEAR);
        }
    }

    for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
        if (jcp.ur <= ur_cases[ur_idx]) {
            int label_idx = num_ur_cases - ur_idx - 1;
            L(load_loop_blk[label_idx]);
            {
                if (label_idx == 0) {
                    cmp(reg_load_loop_work, 0);
                    je(load_loop_blk[num_ur_cases], T_NEAR);
                }

                for (int _i = 1; _i <= label_idx + 1; _i++) {
                    prefetcht0(ptr[reg_load_data + _i * jcp.ic * jcp.oc_block]);
                    prefetcht1(ptr[reg_output_data + _i * jcp.oc_block]);
                }

                load_loop_body(label_idx + 1);
                if (label_idx - 1 > 0) {
                    cmp(reg_load_loop_work, 2 * label_idx * simd_w);
                    je(load_loop_blk[label_idx - 1], T_NEAR);
                }
                cmp(reg_load_loop_work, (label_idx + 1) * simd_w);
                jge(load_loop_blk[label_idx]);
            }
            for (int idx = label_idx - 1; idx > 0; --idx) {
                cmp(reg_load_loop_work, simd_w * (idx + 1));
                je(load_loop_blk[idx], T_NEAR);
            }
            if (ur_idx < num_ur_cases - 2) {
                cmp(reg_load_loop_work, simd_w);
                jle(load_loop_blk[0], T_NEAR);
            }
        }
    }
    L(load_loop_blk[num_ur_cases]);

    add(rsp, stack_space_needed);

    postamble();

    if (jcp.with_eltwise) postops_injector_->prepare_table();
}

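// init_conf() validates the descriptor (int8 src/weights, channels-last
// layouts, 1x1 kernel with no padding and unit strides), fills the blocking
// parameters of jcp and picks the weights tag, returning unimplemented for
// anything the kernel cannot handle.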
status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(
        jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd,
        const memory_desc_t *&src_md, memory_desc_t &weights_md,
        memory_desc_t &dst_md, memory_desc_t &bias_md,
        const primitive_attr_t &attr, int nthreads, bool reduce_src) {

    if (!mayiuse(avx512_core)) return status::unimplemented;

    // used for bf16 output
    jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
                                        : bf16_emulation_t::get_isa();

    const memory_desc_wrapper src_d(src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper dst_d(&dst_md);
    const memory_desc_wrapper bias_d(&bias_md);

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    if (!one_of(src_d.data_type(), data_type::u8, data_type::s8)
            || weights_d.data_type() != data_type::s8
            || !one_of(dst_d.data_type(), data_type::f32, data_type::s32,
                    data_type::s8, data_type::u8, data_type::bf16))
        return status::unimplemented;

    jcp.nthr = nthreads;

    jcp.has_vnni = mayiuse(avx512_core_vnni);

    int ndims = src_d.ndims();
    jcp.ndims = ndims;

    jcp.prop_kind = cd.prop_kind;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ic_without_padding = jcp.ic;

    const bool is_1d = ndims == 3;
    const bool is_3d = ndims == 5;

    jcp.id = is_3d ? src_d.dims()[2] : 1;
    jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = is_3d ? dst_d.dims()[2] : 1;
    jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];

    jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];

    jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
    jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];

    jcp.stride_d = is_3d ? cd.strides[0] : 1;
    jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];

    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
    jcp.signed_input = (src_d.data_type() == data_type::s8);

    jcp.os = static_cast<dim_t>(jcp.od) * jcp.oh * jcp.ow;
    jcp.is = static_cast<dim_t>(jcp.id) * jcp.ih * jcp.iw;

    if (jcp.os > INT_MAX || jcp.is > INT_MAX) return status::unimplemented;

    const auto &post_ops = attr.post_ops_;
    const int dw_conv_ind = post_ops.find(primitive_kind::convolution);
    jcp.with_dw_conv = dw_conv_ind != -1;
    // Using dw_conv_ind as upper-bound below, as post-ops after it will be
    // handled in depthwise convolution.
    const int eltwise_ind
            = post_ops.find(primitive_kind::eltwise, 0, dw_conv_ind);
    jcp.with_eltwise = eltwise_ind != -1;

    const int binary_ind
            = post_ops.find(primitive_kind::binary, 0, dw_conv_ind);
    jcp.with_binary = binary_ind != -1;

    const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind);
    jcp.with_sum = sum_ind != -1;

    if (dw_conv_ind >= 0) {
        // dw_conv and post_ops after it are handled externally, so skip them
        jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(),
                post_ops.entry_.cbegin() + dw_conv_ind);
    } else {
        jcp.post_ops = post_ops;
    }

    const auto zp = attr.zero_points_;
    jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST);
    jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC);
    jcp.zp_src_is_common
            = zp.common(DNNL_ARG_SRC); // otherwise, it's per-channel
    assert(IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common));

    if ((jcp.dst_zero_point || jcp.src_zero_point) && jcp.with_dw_conv)
        return status::unimplemented;

    format_tag_t dat_tag = utils::pick(
            ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
    jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);

    bool args_ok = jcp.src_tag == dat_tag && jcp.dst_tag == dat_tag;
    if (!args_ok) return status::unimplemented;

    if (jcp.ngroups == 1) {
        jcp.oc = rnd_up(jcp.oc, 16);
        jcp.ic = rnd_up(jcp.ic, 16);
    }

    using namespace injector;
    static constexpr bool sum_at_pos_0_only = false;
    static constexpr bool sum_requires_scale_one = false;
    static constexpr bool sum_requires_zp_zero = false;
    const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
            jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
            sum_requires_zp_zero});
    if (!post_ops_ok_) return status::unimplemented;

    const int simd_w = (jcp.ic % 16 == 0 && jcp.oc % 16 == 0)
            ? 16
            : (jcp.ic % 8 == 0 && jcp.oc % 8 == 0) ? 8 : 4;

    auto set_or_check_wei_format = [&]() -> bool {
        using namespace format_tag;
        using namespace memory_extra_flags;
        const format_tag_t wei_tags[3][2][3]
                = {{{OIw4i16o4i, OIhw4i16o4i, OIdhw4i16o4i},
                           {gOIw4i16o4i, gOIhw4i16o4i, gOIdhw4i16o4i}},
                        {{OIw2i8o4i, OIhw2i8o4i, OIdhw2i8o4i},
                                {gOIw2i8o4i, gOIhw2i8o4i, gOIdhw2i8o4i}},
                        {{OIw4o4i, OIhw4o4i, OIdhw4o4i},
                                {gOIw4o4i, gOIhw4o4i, gOIdhw4o4i}}};

        const int simd_idx = simd_w == 16 ? 0 : simd_w == 8 ? 1 : 2;
        const auto wei_tag = wei_tags[simd_idx][with_groups][ndims - 3];
        memory_desc_t want_wei_md = weights_md;
        memory_desc_init_by_tag(want_wei_md, wei_tag);
        if (jcp.signed_input) {
            want_wei_md.extra.flags = 0 | compensation_conv_s8s8 | scale_adjust;
            want_wei_md.extra.compensation_mask
                    = (1 << 0) + (with_groups ? (1 << 1) : 0);
            want_wei_md.extra.scale_adjust
                    = mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
        }
        if (jcp.src_zero_point) {
            want_wei_md.extra.flags |= compensation_conv_asymmetric_src;
            want_wei_md.extra.asymm_compensation_mask
                    = (1 << 0) + (with_groups ? (1 << 1) : 0);
        }

        if (weights_md.format_kind == format_kind::any) {
            weights_md = want_wei_md;
            return true;
        }

        return weights_md == want_wei_md;
    };

    if (!set_or_check_wei_format()) return status::unimplemented;

    args_ok = true && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
            && jcp.f_pad == 0 && jcp.t_pad == 0 && jcp.l_pad == 0
            && jcp.stride_d == 1 && jcp.stride_h == 1
            && jcp.stride_w == 1 // TODO: support some strides
            && jcp.od == jcp.id && jcp.oh == jcp.ih
            && jcp.ow == jcp.iw // enforce rpad = 0
            && jcp.kd == 1 && jcp.kh == 1 && jcp.kw == 1;
    if (!args_ok) return status::unimplemented;

    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    jcp.dst_dt = cd.dst_desc.data_type;
    jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt);

    jcp.ic_block = jcp.oc_block = simd_w;

    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());
    jcp.typesize_bia
            = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;

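    // Blocking heuristics: pick the spatial unroll jcp.ur from the available
    // accumulator registers, then split the load (oc), bcast (spatial) and
    // reduce (ic) dims into blocks sized to keep the working set within L2
    // and to balance the work across jcp.nthr threads.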
    const int SMALL_SPATIAL = 7 * 7;
    const int BIG_REDUCE_DIM = 1024;

    int load_blocking = 0;
    int load_blocking_max = 0;
    int bcast_blocking = 0;
    int bcast_blocking_max = 0;
    int reduce_blocking = 0;
    int reduce_blocking_max = 0;
    jcp.load_grp_count = 1;
    jcp.use_vmovntps = false;

    const int L2_size
            = platform::get_per_core_cache_size(2) / sizeof(jcp.typesize_in);
    const int L2_capacity = (L2_size * 3) / 4;

    const bool req_extra_bf16_regs
            = jcp.dst_dt == data_type::bf16 && !isa_has_bf16(jcp.isa);
    int size_threshold = req_extra_bf16_regs ? 25 : 28;
    int max_regs = 0;
    int min_regs = 6;
    if (jcp.has_vnni && !req_extra_bf16_regs)
        max_regs = ((jcp.oh > size_threshold && jcp.ow > size_threshold)
                           && (jcp.oc < 128 || jcp.ic < 128))
                ? min_regs
                : 9;
    else
        max_regs = 8;
    jcp.expl_bcast = true;

    if (jcp.mb == 1 && jcp.ic > 128
            && (jcp.oh <= size_threshold && jcp.ow <= size_threshold)) {
        if (jcp.os <= SMALL_SPATIAL && jcp.oc * jcp.ic < L2_size)
            max_regs = min_regs; // mobilenet_v2 performance improvement
        jcp.ur = nstl::min<dim_t>(max_regs, jcp.os);
    } else {
        const int spatial = jcp.od * jcp.oh;
        jcp.ur = 1;
        for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) {
            if ((spatial >= size_threshold && spatial % ur_w == 0)
                    || (spatial < size_threshold && jcp.os % ur_w == 0)) {
                jcp.ur = ur_w;
                break;
            }
        }
        if (jcp.ur == 1) {
            jcp.ur = nstl::min<dim_t>(max_regs, jcp.os);
            int os_tail = jcp.os % max_regs;
            for (int i = max_regs; i >= min_regs; i--) {
                int i_tail = jcp.os % i;
                if (i_tail > os_tail || i_tail == 0) {
                    jcp.ur = i;
                    os_tail = i_tail;
                    if (i_tail == 0) break;
                }
            }
        }
    }
    if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur);

    jcp.reduce_dim = jcp.ic;
    jcp.reduce_block = jcp.ic_block;

    jcp.load_dim = jcp.oc;
    jcp.load_block = jcp.oc_block;

    jcp.bcast_dim = jcp.is;

    jcp.bcast_block = jcp.ur;

    jcp.reduce_loop_unroll = jcp.reduce_block;
    jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll * jcp.typesize_in;

    jcp.reduce_loop_load_step
            = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;

    jcp.bcast_loop_output_step
            = jcp.ur * jcp.ngroups * jcp.oc_without_padding * jcp.typesize_out;
    jcp.bcast_loop_output_substep = -1; // unused
    jcp.bcast_loop_bcast_step
            = jcp.ur * jcp.ngroups * jcp.ic_without_padding * jcp.typesize_in;
    jcp.bcast_loop_bcast_substep = -1; // unused

    jcp.load_loop_load_step = jcp.reduce_dim * jcp.load_block * jcp.typesize_in;

    jcp.load_loop_iter_step = jcp.load_block;

    jcp.loop_order = reduce_src ? loop_blr : loop_lbr;

    int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
    int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    reduce_blocking = nb_reduce;
    if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
        reduce_blocking = 64;
    else if (jcp.bcast_dim > SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
        reduce_blocking = 16;
    reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true);
    reduce_blocking *= jcp.reduce_block;

    bool cmp_reduce = reduce_blocking <= jcp.reduce_dim;
    if (cmp_reduce) jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
    load_blocking = jcp.load_dim;

    jcp.load_grp_count = div_up(jcp.nthr, jcp.mb * jcp.ngroups * nb_bcast);
    jcp.load_grp_count = best_divider(
            jcp.nthr, jcp.load_grp_count, 2 * jcp.load_grp_count, false);

    if (jcp.bcast_dim <= SMALL_SPATIAL
            && jcp.load_dim * jcp.reduce_dim >= L2_size) {
        jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
    } else if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.mb <= jcp.nthr
            && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) {
        jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2);
        load_blocking = jcp.load_block;
    }

    bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
                             div_up(jcp.nthr, jcp.load_grp_count))
            * jcp.bcast_block;
    bcast_blocking = nstl::min<dim_t>(jcp.bcast_dim, bcast_blocking);
    bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);

    int space_for_bcast = (L2_capacity - /* kernel_size - */
            2 * jcp.load_block * reduce_blocking - jcp.ur * reduce_blocking
            - 3 * 1024);
    if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) space_for_bcast /= 2;

    int bcast_in_cache
            = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
    bcast_blocking = nstl::min(
            bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));

    load_blocking_max = load_blocking;
    bcast_blocking_max = bcast_blocking * 3 / 2;
    reduce_blocking_max = reduce_blocking;

    assert(load_blocking);
    assert(load_blocking_max);
    assert(bcast_blocking);
    assert(bcast_blocking_max);
    assert(reduce_blocking);
    assert(reduce_blocking_max);
    assert(load_blocking % jcp.load_block == 0);
    assert(reduce_blocking % jcp.reduce_block == 0);
    assert(load_blocking_max % jcp.load_block == 0);
    assert(reduce_blocking_max % jcp.reduce_block == 0);

    assert(jcp.reduce_loop_unroll % 4 == 0);
    assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);

    assert(jcp.bcast_block % jcp.ur == 0);
    assert(jcp.reduce_dim % jcp.reduce_block == 0);

    jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.ur;

    jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
    jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
    jcp.nb_load_blocking = load_blocking / jcp.load_block;
    jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
    jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
    jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block;

    jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
    jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
    jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    // minimum size of load dim chunk for work distribution within threads
    jcp.nb_load_chunk = 1;
    // performance improvements for googlenet_v3, mb=1;
    // TODO: generalize this condition and rewrite it in appropriate manner
    int ncores_per_socket = (int)cpu().getNumCores(
            Xbyak::util::IntelCpuTopologyLevel::CoreLevel);
    if (jcp.mb == 1 && jcp.nb_load % 4 == 0 && jcp.ic / jcp.oc >= 4
            && jcp.ic * jcp.oc <= L2_size && jcp.nthr <= ncores_per_socket) {
        jcp.nb_load_chunk = 4;
        jcp.load_grp_count = nstl::max(jcp.nb_load / 4, jcp.load_grp_count);
    }

    /* Adjust the thread decomposition to improve performance for small-size
     * problems. The threshold 8192 is empirical; for now simply set the
     * thread count to the max of nb_load and nb_bcast.
     * TODO: add get_thr_eff func to compute the optimal thread count
     * TODO: the threshold can be increased when the initial stride > 1 */
    auto bcast_size
            = (dim_t)jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim;
    if (jcp.typesize_in * bcast_size < 8192 && jcp.ngroups < jcp.nthr
            && jcp.nb_bcast * jcp.nb_load < jcp.nthr) {
        int nthr = nstl::max(jcp.nb_load, jcp.nb_bcast);
        jcp.nthr = nstl::min(jcp.nthr, nthr);
    }

    const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
    const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
    const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
    const int wei_mask_per_oc = 1 << (int)with_groups;
    jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc;
    jcp.dst_scale = !dst_scales.has_default_values();

    // only common src & dst scales are supported
    // only common and per-oc-channel weight scales are supported
    const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc)
            && everyone_is(src_scales.mask_, dst_scales.mask_, 0);
    if (!scales_ok) return status::unimplemented;

    jcp.wei_adj_scale
            = (weights_d.extra().flags & memory_extra_flags::scale_adjust)
            ? weights_d.extra().scale_adjust
            : 1.f;

    return status::success;
}

void jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad(
        memory_tracking::registrar_t &scratchpad,
        const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
    using namespace dnnl::impl::memory_tracking::names;

    const int wei_mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_;
    const dim_t scales_count = wei_mask == 0 ? 1 : jcp.oc * jcp.ngroups;
    const dim_t count = nstl::max<dim_t>(scales_count, (dim_t)jcp.ic_block);
    scratchpad.book<float>(key_conv_adjusted_scales, count);
}

template struct _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Xbyak::Zmm>;
template struct _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Xbyak::Ymm>;
template struct _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Xbyak::Xmm>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl