1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/memory.hpp"
19#include "common/memory_tracking.hpp"
20#include "common/nstl.hpp"
21#include "common/type_helpers.hpp"
22#include "common/utils.hpp"
23
24#include "cpu/platform.hpp"
25#include "cpu/x64/injectors/injector_utils.hpp"
26#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
27#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
28#include "cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.hpp"
29
30#define GET_OFF(field) offsetof(jit_conv_call_s, field)
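// GET_OFF(x) yields the byte offset of field x within the jit_conv_call_s
// argument structure, e.g. ptr[param1 + GET_OFF(bias)] addresses the bias
// pointer passed in by the driver code.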
31
32namespace dnnl {
33namespace impl {
34namespace cpu {
35namespace x64 {
36
37using namespace dnnl::impl::memory_tracking::names;
38using namespace dnnl::impl::utils;
39using namespace dnnl::impl::data_type;
40using namespace Xbyak;
41
42namespace {
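// Chooses the nesting order of the driver loops over the output. The loop_*
// values presumably encode that order (n = minibatch, g = groups, c = oc
// block chunks, w/h = output spatial blocks); the heuristic picks it from the
// group count and from how many images each thread gets.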
43void pick_loop_order(jit_conv_conf_t &jcp, int nthr) {
44 jcp.loop_order = loop_cwgn;
45 if (jcp.ngroups > 1) {
46 jcp.loop_order = loop_ngcw;
47 if (jcp.mb < nthr)
48 jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg;
49 } else if (jcp.mb >= nthr && jcp.ic_without_padding <= 16) {
50 jcp.loop_order = loop_ngcw;
51 }
52}
53} // namespace
54
55template <typename Vmm>
56_jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::_jit_avx512_core_x8s8s32x_fwd_kernel(
57 const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
58 const memory_desc_t &dst_md)
59 : jit_generator(jit_name())
60 , jcp(ajcp)
61 , attr_(attr)
62 , postops_injector_(nullptr) {
63 if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
64 using namespace binary_injector;
65 static constexpr bool preserve_gpr = true;
66 static constexpr bool preserve_vmm = false;
67 static constexpr size_t helper_vmm_idx = 31;
68 const size_t oc_block_tail = jcp.oc_block % isa_simd_width_;
69 const size_t tail_size = oc_block_tail
70 ? oc_block_tail
71 : jcp.oc_without_padding % isa_simd_width_;
72 static constexpr bool use_exact_tail_scalar_bcast = false;
73
74 const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
75 r14, r15, r13, preserve_gpr, preserve_vmm,
76 GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
77 memory_desc_wrapper(dst_md), tail_size, postops_mask,
78 use_exact_tail_scalar_bcast};
79 const static_params_t static_params {
80 this->param1, rhs_arg_static_params};
81
82 postops_injector_ = utils::make_unique<
83 injector::jit_uni_postops_injector_t<avx512_core, Vmm>>(
84 this, jcp.post_ops, static_params);
85 }
86 if (!isa_has_bf16(jcp.isa) && jcp.dst_dt == data_type::bf16)
87 bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
88 bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
89 bf16_emu_scratch, bf16_emu_reserv_4, bf16_emu_reserv_4);
90}
91
92template <typename Vmm>
93void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::prepare_output(int ur_w) {
94 int nb_oc_block
95 = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
96 for (int k = 0; k < nb_oc_block; k++)
97 for (int j = 0; j < ur_w; j++) {
98 Vmm vmm = vmm_out(j, k);
99 vpxord(vmm, vmm, vmm);
100 }
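    // For s8 source (signed_input) the kernel later adds 128 to the
    // activations so that the u8 x s8 vpdpbusd path can be used; that extra
    // contribution is expected to be cancelled by the precomputed
    // compensation term added in store_output(). vmm_shift holds the
    // broadcast 128.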
101 if (jcp.signed_input) {
102 mov(reg_scratch, 128);
103 if (jcp.is_depthwise && !jcp.is_fast_depthwise)
104 vpbroadcastd(vmm_shift, reg_scratch.cvt32());
105 else
106 vpbroadcastb(vmm_shift, reg_scratch.cvt8());
107 }
108}
109
110template <typename Vmm>
111Vmm _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::vmm_mask(
112 const Vmm vmm_in, bool mask_flag, bool store) {
113 return vmm_in;
114}
115
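// Zmm specialization below: the OC tail is handled with the k-mask register;
// stores merge into the existing destination while loads use zeroing-masking
// (T_z). The generic variant above returns the register unchanged (tails for
// Ymm/Xmm are handled elsewhere).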
116template <>
117Zmm _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::vmm_mask(
118 const Zmm zmm_in, bool mask_flag, bool store) {
119 return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
120 : zmm_in;
121}
122
123template <typename Vmm>
124void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::cvt2ps(data_type_t type_in,
125 const Vmm vmm_in, const Operand &op, bool mask_flag) {
126 using namespace data_type;
127 const Vmm vmm = vmm_mask(vmm_in, mask_flag);
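    // bf16 is widened to f32 by zero-extending each 16-bit element and
    // shifting it into the upper half of its 32-bit lane; integer inputs are
    // widened to s32 here and converted to f32 at the end.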
128 switch (type_in) {
129 case f32:
130 case s32: vmovups(vmm, op); break;
131 case bf16:
132 vpmovzxwd(vmm, op);
133 vpslld(vmm_in, vmm_in, 16);
134 break;
135 case s8: vpmovsxbd(vmm, op); break;
136 case u8: vpmovzxbd(vmm, op); break;
137 default: assert(!"unsupported data type");
138 }
139 if (one_of(type_in, s32, s8, u8)) vcvtdq2ps(vmm_in, vmm_in);
140}
141
142template <typename F>
143static void iterate(const int nb_oc_block, const int ur_w,
144 const bool last_oc_block_flag, const bool force_masking, const F &f) {
145 for (int k = 0; k < nb_oc_block; k++) {
146 const bool mask_flag
147 = force_masking || (last_oc_block_flag && k + 1 == nb_oc_block);
148 for (int j = 0; j < ur_w; j++)
149 f(mask_flag, k, j);
150 }
151}
152template <typename F>
153static void iterate(const int nb_oc_block, const int ur_w,
154 const bool last_oc_block_flag, const F &f) {
155 iterate(nb_oc_block, ur_w, last_oc_block_flag, false, f);
156}
157template <typename F>
158static void iterate(const int nb_oc_block, const int ur_w, const F &f) {
159 iterate(nb_oc_block, ur_w, false, false, f);
160}
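// iterate() walks the (oc_block, ur_w) tile of accumulator registers and
// tells the callback whether the current oc block needs tail masking; it is
// used by apply_sum() and apply_postops() below.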
161
162template <typename Vmm>
163void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::apply_sum(int ur_w,
164 bool last_oc_block_flag, const int nb_oc_block, const int oc_block,
165 const float *p_sum_scale, const int32_t *p_sum_zp) {
166 if (jcp.with_sum) {
167 const float sum_scale = *p_sum_scale;
168 const int32_t sum_zp = *p_sum_zp;
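        // sum post-op: acc += sum_scale * (dst_prev - sum_zp), where the
        // previously stored dst is loaded as jcp.sum_dt and converted to f32.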
169 const auto sum_injector_lam = [this, oc_block, sum_scale, sum_zp](
170 const bool mask_flag, const int k,
171 const int j) {
172 int aux_output_offset = jcp.typesize_out
173 * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups);
174 auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
175 Vmm vmm = vmm_out(j, k);
176 cvt2ps(jcp.sum_dt, vmm_prev_dst, addr, mask_flag);
177 if (sum_zp != 0) vsubps(vmm_prev_dst, vmm_sum_zp);
178 if (sum_scale == 1.f)
179 vaddps(vmm, vmm_prev_dst);
180 else
181 vfmadd231ps(vmm, vmm_prev_dst, zword_b[reg_ptr_sum_scale]);
182 };
183 const auto sum_injector = [=]() {
184 iterate(nb_oc_block, ur_w, last_oc_block_flag, sum_injector_lam);
185 };
186 if (sum_scale != 1.f)
187 mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
188 if (sum_zp != 0) {
189 mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
190 vcvtdq2ps(vmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
191 }
192 postops_injector_->set_lambda_injector(
193 primitive_kind::sum, sum_injector);
194 }
195}
196
197template <typename Vmm>
198void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::apply_postops(int ur_w,
199 bool last_oc_block_flag, const int nb_oc_block, const int oc_block,
200 const float *p_sum_scale, const int32_t *p_sum_zp) {
201 if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
202 apply_sum(ur_w, last_oc_block_flag, nb_oc_block, oc_block, p_sum_scale,
203 p_sum_zp);
204
205 injector_utils::vmm_index_set_t vmm_idxs;
206 if (jcp.with_binary) {
207 binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
208 const bool oc_blk_is_smaller_than_vmm = oc_block < isa_simd_width_;
209 iterate(nb_oc_block, ur_w, last_oc_block_flag,
210 oc_blk_is_smaller_than_vmm,
211 [&](const bool mask_flag, const int k, const int j) {
212 const size_t aux_output_l_off = jcp.typesize_out
213 * (k * oc_block
214 + j * jcp.oc_without_padding
215 * jcp.ngroups);
216 const auto vmm_idx = vmm_out_idx(j, k);
217 vmm_idxs.emplace(vmm_idx);
218
219 rhs_arg_params.vmm_idx_to_out_reg.emplace(
220 vmm_idx, reg_out);
221 rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
222 vmm_idx, aux_output_l_off);
223 if (mask_flag)
224 rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
225 });
226
227 postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
228 } else {
229 iterate(nb_oc_block, ur_w,
230 [&](const bool, const int k, const int j) {
231 vmm_idxs.emplace(vmm_out_idx(j, k));
232 });
233 postops_injector_->compute_vector_range(vmm_idxs);
234 }
235 }
236}
237
238template <typename Vmm>
239void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::store_output(
240 int ur_w, bool last_oc_block_flag) {
241 int nb_oc_block
242 = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
243 int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block;
244
245 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
246 mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
247 if (jcp.signed_input)
248 mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);
249
250 if (jcp.src_zero_point) {
251 mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
252 mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
253 }
254
255 const auto &p = attr_.post_ops_;
256 const int sum_idx = p.find(primitive_kind::sum);
257 const float *p_sum_scale = nullptr;
258 const int32_t *p_sum_zp = nullptr;
259 if (sum_idx != -1) {
260 const auto &p_entry = p.entry_[sum_idx];
261 p_sum_scale = &p_entry.sum.scale;
262 p_sum_zp = &p_entry.sum.zero_point;
263 }
264
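    // Per output point the s32 accumulator is post-processed roughly as:
    //   acc (+ s8s8 compensation) (+ src_zp * zp_compensation) -> f32
    //   * scale (+ bias) -> post-ops -> (* dst_scale) (+ dst_zp)
    //   -> saturate / convert -> store as jcp.dst_dt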
265 for (int k = 0; k < nb_oc_block; k++) {
266 const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
267 int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block);
268 if (jcp.with_bias) {
269 int bias_offset = jcp.typesize_bia * k * oc_block;
270 auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
271
272 cvt2ps(jcp.bia_dt, vmm_bias, bias_addr, mask_flag);
273 }
274 if (jcp.signed_input) {
275 int comp_offset = sizeof(int32_t) * k * oc_block;
276 Vmm vmm_comp_ = vmm_mask(vmm_comp, mask_flag);
277 vmovups(vmm_comp_,
278 EVEX_compress_addr(reg_compensation, comp_offset));
279 }
280 if (jcp.src_zero_point) {
281 // zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32
282 int zp_offset = sizeof(int32_t) * k * oc_block;
283 Vmm vmm_zp_ = vmm_mask(vmm_zp, mask_flag);
284 vmovups(vmm_zp_,
285 EVEX_compress_addr(reg_zp_compensation, zp_offset));
286 vpmulld(vmm_zp_, vmm_zp_,
287 EVEX_compress_addr(
288 reg_src_zero_point, 0, jcp.zp_src_is_common));
289 }
        /* permute (for fast depthwise) and add compensation, zero_point and
           bias to the accumulators */
291 for (int j = 0; j < ur_w; j++) {
292 Vmm vmm = vmm_out(j, k);
293 if (jcp.is_fast_depthwise)
294 vpermd(zmm_out(j, k), zmm_permute, zmm_out(j, k));
            /* add the compensation while still in s32 to avoid the loss of
               precision that occurs when converting s32 values above 2^24
               to f32. TODO: do the same for bias */
298 if (jcp.signed_input) vpaddd(vmm, vmm, vmm_comp);
299 if (jcp.src_zero_point) vpaddd(vmm, vmm, vmm_zp);
300 vcvtdq2ps(vmm, vmm);
301
302 const Vmm vmm_k = vmm_mask(vmm, mask_flag);
303 vmulps(vmm_k, vmm,
304 EVEX_compress_addr(reg_ptr_scales, scale_offset));
305
306 if (jcp.with_bias) vaddps(vmm, vmm, vmm_bias);
307 }
308 }
309
310 apply_postops(ur_w, last_oc_block_flag, nb_oc_block, oc_block, p_sum_scale,
311 p_sum_zp);
312
313 if (jcp.dst_scale) {
314 mov(reg_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
315 vmovups(vmm_dst_scale, EVEX_compress_addr(reg_dst_scale, 0));
316
317 /* Apply dst scale to accumulator */
318 for (int k = 0; k < nb_oc_block; k++) {
319 const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
320 for (int j = 0; j < ur_w; j++) {
321 Vmm vmm = vmm_out(j, k);
322 const Vmm vmm_k = vmm_mask(vmm, mask_flag);
323 vmulps(vmm_k, vmm, vmm_dst_scale);
324 }
325 }
326 }
327
328 if (jcp.dst_zero_point) {
329 mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
330 vcvtdq2ps(vmm_zp, EVEX_compress_addr(reg_dst_zero_point, 0, true));
331
332 /* Add dst zero_point to accumulator */
333 for (int k = 0; k < nb_oc_block; k++) {
334 for (int j = 0; j < ur_w; j++) {
335 Vmm vmm = vmm_out(j, k);
336 vaddps(vmm, vmm, vmm_zp);
337 }
338 }
339 }
340
341 // Properly saturate the accumulators for integer datatypes
342 if (one_of(jcp.dst_dt, u8, s8, s32)) {
343 init_saturate_f32(
344 vmm_zero, vmm_saturation, aux_reg_saturation, f32, jcp.dst_dt);
345 for (int k = 0; k < nb_oc_block; k++) {
346 for (int j = 0; j < ur_w; j++) {
347 Vmm vmm = vmm_out(j, k);
348 saturate_f32(vmm, vmm_zero, vmm_saturation, jcp.dst_dt);
349 vcvtps2dq(vmm, vmm);
350 }
351 }
352 }
353
354 if (!isa_has_bf16(jcp.isa) && jcp.dst_dt == data_type::bf16)
355 bf16_emu_->init_vcvtneps2bf16();
356
357 /* write out register to output_addr */
358 if (jcp.dst_dt == data_type::bf16 && isa_has_bf16(jcp.isa)) {
        // Optimization: use a single store instruction for each pair of
        // adjacent vectors along the OC dimension
361 for (int j = 0; j < ur_w; j++) {
362 int k = 0;
363 for (; k < rnd_dn(nb_oc_block, 2); k += 2) {
364 Vmm vmm = vmm_out(j, k);
365 Vmm vmm_next = vmm_out(j, k + 1);
366
367 int aux_output_offset = jcp.typesize_out
368 * (k * oc_block
369 + j * jcp.oc_without_padding * jcp.ngroups);
370 auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
371
372 vcvtne2ps2bf16(vmm, vmm_next, vmm);
373 // mask only needed for last oc_block
374 const bool mask_flag
375 = last_oc_block_flag && k + 2 == nb_oc_block;
376
377 vmovdqu16(addr, maybe_mask_vmm(vmm, mask_flag));
378 }
379 if (nb_oc_block % 2 != 0) {
380 Vmm vmm = vmm_out(j, k);
381 auto vmm_down = Vmm_down_t(vmm.getIdx());
382 int aux_output_offset = jcp.typesize_out
383 * (k * oc_block
384 + j * jcp.oc_without_padding * jcp.ngroups);
385 auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
386 vcvtneps2bf16(vmm_down, vmm);
                // for xmm, the upper half is zero after the conversion to
                // bf16, so always mask; also mask for tails
389 bool mask_flag = jcp.simd_w == 4 || last_oc_block_flag;
390 vmovdqu16(addr, maybe_mask_vmm_down(vmm_down, mask_flag));
391 }
392 }
393 } else {
394 for (int k = 0; k < nb_oc_block; k++) {
395 const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
396 for (int j = 0; j < ur_w; j++) {
397 int aux_output_offset = jcp.typesize_out
398 * (k * oc_block
399 + j * jcp.oc_without_padding * jcp.ngroups);
400 auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
401
402 Vmm vmm = vmm_out(j, k);
403 const Vmm r_vmm = vmm_mask(vmm, mask_flag, true);
404
405 switch (jcp.dst_dt) {
406 case data_type::f32:
407 case data_type::s32: vmovups(addr, r_vmm); break;
408 case data_type::s8: vpmovsdb(addr, r_vmm); break;
409 case data_type::u8: vpmovusdb(addr, r_vmm); break;
410 case data_type::bf16:
411 store_bf16(addr, vmm.getIdx(),
412 get_src_down_idx(nb_oc_block), mask_flag);
413 break;
414 default: assert(!"unknown dst_dt");
415 }
416 }
417 }
418 }
419}
420
421template <typename Vmm>
422void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker_dw(int ur_w,
423 int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
424 assert(!"invalid group blocking for depthwise convolution");
425}
426
427template <>
428void _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::compute_ker_dw(int ur_w,
429 int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
430
431 const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input);
432
433 if (jcp.src_zero_point) {
434 push(aux_reg_ker_d);
435 mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
436 }
437
438 auto input_spatial_index = [=](int oi, int ki) {
439 return (ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l);
440 };
441
442 auto input_offset2 = [=](int ii, int ci) {
443 if (jcp.is_fused_conv)
444 return jcp.typesize_in
445 * (ii * jcp.dw_conv_buffer_oc + ci * jcp.ch_block);
446 else
447 return jcp.typesize_in * (ii * jcp.ngroups + ci * jcp.ch_block);
448 };
449
450 auto input_offset3 = [=](int oi, int ci, int ki) {
451 return jcp.typesize_in * input_offset2(input_spatial_index(oi, ki), ci);
452 };
453
454 auto kernel_offset = [=](int ci, int ki) {
455 return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block);
456 };
457
458 auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
459 // okay for depthwise since src is zero-extended
460 if (jcp.has_vnni) {
461 vpdpbusd(vreg_acc, vreg_src, vreg_wei);
462 } else {
463 vpmaddwd(zmm_tmp, vreg_src, vreg_wei);
464 vpaddd(vreg_acc, vreg_acc, zmm_tmp);
465 }
466 };
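    // Note: without VNNI, vpmaddwd on the dword-widened operands computes
    // src_lo16 * wei_lo16 + src_hi16 * wei_hi16 per lane; since the
    // zero-extended src has a zero upper half, this reduces to the desired
    // u8 x s8 product, accumulated with vpaddd.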
467
468 int ii_start = 0;
469 int ii_end = -1;
470 if (jcp.is_resrc_depthwise && !h_padded) {
471 // find bounds of input spatial indices
472 bool first = true;
473 for (int ki = 0; ki < jcp.kw; ki++) {
474 int oi_start = get_ow_start(ki, pad_l);
475 int oi_end = get_ow_end(ur_w, ki, pad_r);
476 for (int oi = oi_start; oi < oi_end; oi++) {
477 int ii = input_spatial_index(oi, ki);
478 if (first || ii < ii_start) ii_start = ii;
479 if (first || ii > ii_end) ii_end = ii;
480 first = false;
481 }
482 }
483 }
484
485 if (jcp.signed_input) vmovups(zmm_shifted_zero, vmm_shift);
486
487 for (int ci = 0; ci < jcp.nb_ch_blocking; ci++) {
488 const bool mask_flag = last_ic_block_flag != no_last_block
489 && ci == jcp.nb_ch_blocking - 1;
490 if (jcp.is_resrc_depthwise && !h_padded) {
491 // now we can load input once and reuse up to jcp.kw times
492 for (int ii = ii_start; ii <= ii_end; ii++) {
493 int aux_input_offset = input_offset2(ii, ci);
494 const Zmm zmm_inp_tmp = zmm_inp(ii, jcp.nb_ch_blocking);
495 const Zmm zmm_inp_msk = mask_flag
496 ? zmm_inp_tmp | ktail_mask | T_z
497 : zmm_inp_tmp;
498 if (jcp.is_fast_depthwise) {
499 assert(!mask_flag);
500 vbroadcasti32x4(zmm_inp_msk,
501 EVEX_compress_addr(aux_reg_inp, aux_input_offset));
502 } else {
503 vpmovzxbd(zmm_inp_msk,
504 EVEX_compress_addr(aux_reg_inp, aux_input_offset));
505 }
506 if (jcp.signed_input)
507 vpaddb(zmm_inp_tmp, zmm_inp_tmp, vmm_shift);
508 }
509 }
510 for (int ki = 0; ki < jcp.kw; ki++) {
511 int aux_kernel_offset = kernel_offset(ci, ki);
512 const int oi_start = get_ow_start(ki, pad_l);
513 const int oi_end = get_ow_end(ur_w, ki, pad_r);
514 if (compute_kernel) {
515 if (jcp.is_fast_depthwise) {
516 vbroadcasti32x4(zmm_wei,
517 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
518 vmovdqu8(zmm_wei | kblend_mask | T_z, zmm_wei);
519 } else {
520 vpmovsxbd(zmm_wei,
521 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
522 }
523
524 if (h_padded) {
525 assert(jcp.signed_input);
526 for (int oi = 0; oi < ur_w; oi++)
527 compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero);
528 } else {
529 const Zmm r_zmm_src
530 = mask_flag ? zmm_src | ktail_mask : zmm_src;
531 int start_ = jcp.signed_input ? 0 : oi_start;
532 int end_ = jcp.signed_input ? ur_w : oi_end;
533 for (int oi = start_; oi < end_; oi++) {
534 if (oi >= oi_start && oi < oi_end) {
535 if (jcp.is_resrc_depthwise) {
536 int ii = input_spatial_index(oi, ki);
537 zmm_src = zmm_inp(ii, jcp.nb_ch_blocking);
538 } else {
539 int aux_input_offset
540 = input_offset3(oi, ci, ki);
541 if (jcp.is_fast_depthwise) {
542 assert(!mask_flag);
543 vbroadcasti32x4(r_zmm_src,
544 EVEX_compress_addr(aux_reg_inp,
545 aux_input_offset));
546 } else {
547 vpmovzxbd(r_zmm_src,
548 EVEX_compress_addr(aux_reg_inp,
549 aux_input_offset));
550 }
551 if (jcp.signed_input)
552 vpaddb(zmm_src, zmm_src, vmm_shift);
553 }
554 compute(zmm_out(oi, ci), zmm_wei, zmm_src);
555 } else {
556 assert(jcp.signed_input);
557 compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero);
558 }
559 }
560 }
561 }
562 if (jcp.src_zero_point) {
563 /* calculate src_zero_point padding as:
564 * (is_padding ?
565 * src_zero_point_s32 * conv(1, wei_s32) : 0) */
566 if (jcp.is_fast_depthwise || !compute_kernel) {
567 vpmovsxbd(zmm_wei,
568 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
569 if (jcp.is_fast_depthwise)
570 vpermd(zmm_wei, zmm_permute, zmm_wei);
571 } // else: already loaded weights from previous block
572 int zp_offset = 0;
573 for (int oi = 0; oi < ur_w; oi++) {
574 if (oi < oi_start || oi >= oi_end || h_padded) {
575 vpmulld(vmm_zp_tmp, zmm_wei,
576 EVEX_compress_addr(reg_src_zero_point,
577 zp_offset, jcp.zp_src_is_common));
578 vpaddd(zmm_out(oi, ci), zmm_out(oi, ci), vmm_zp_tmp);
579 }
580 }
581 }
582 }
583 }
584 if (jcp.src_zero_point) pop(aux_reg_ker_d);
585}
586
587template <typename Vmm>
588void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker(int ur_w, int pad_l,
589 int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
590 if (jcp.is_depthwise)
591 return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded);
592
593 const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input);
594
595 assert(IMPLICATION(h_padded, jcp.src_zero_point || jcp.signed_input));
596
597 if (jcp.src_zero_point) {
598 push(aux_reg_ker_d);
599 mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
600 }
601
602 int kw = jcp.kw;
603 int stride_w = jcp.stride_w;
604 int ic_block = jcp.ic_block;
605 int oc_block = jcp.oc_block;
606 int ch_block_all = jcp.ch_block * ic_block * oc_block;
607
608 int nb_oc_block = jcp.nb_oc_blocking;
609
610 auto input_offset = [=](int oi, int ic, int ki) {
611 return jcp.typesize_in
612 * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
613 * jcp.ic_without_padding * jcp.ngroups
614 + ic_sub_step * ic);
615 };
616 auto kernel_offset = [=](int ii, int ic, int ki) {
617 return jcp.typesize_in
618 * ((ii * jcp.nb_ic * jcp.kd * jcp.kh * jcp.kw + ki)
619 * ch_block_all
620 + ic_sub_step * ic * oc_block);
621 };
622 auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
623 if (jcp.has_vnni) {
624 vpdpbusd(vreg_acc, vreg_src, vreg_wei);
625 } else {
626 vpmaddubsw(vmm_tmp, vreg_src, vreg_wei);
627 vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
628 vpaddd(vreg_acc, vreg_acc, vmm_tmp);
629 }
630 };
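    // Without VNNI, the u8 x s8 dot product over groups of 4 input channels
    // is emulated with vpmaddubsw (pairwise u8*s8 -> s16, which may saturate)
    // followed by vpmaddwd against a vector of ones (s16 pairs -> s32) and
    // vpaddd. The possible s16 saturation is why weights are pre-scaled by
    // 0.5 at reorder time on such systems (see scale_adjust in init_conf) and
    // compensated through wei_adj_scale.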
631
632 for (int ki = 0; ki < kw; ki++) {
633 int jj_start = get_ow_start(ki, pad_l);
634 int jj_end = get_ow_end(ur_w, ki, pad_r);
635 int ic_tail_size = jcp.ic_without_padding % ic_sub_step;
636 int _start = jcp.signed_input ? 0 : jj_start;
637 int _end = jcp.signed_input ? ur_w : jj_end;
638 /* Skip the last loads of input
639 if (ic%16)/ic_sub_step < ic_block/ic_sub_step */
640 int icb = (last_ic_block_flag != no_last_block)
641 ? div_up((jcp.ic_without_padding % ic_block), ic_sub_step)
642 : ic_block / ic_sub_step;
643 if (compute_kernel) {
644 for (int ic = 0; ic < icb; ic++) {
645 if (h_padded) {
646 // fill padded area with shifted value in first iteration
647 if (ic == 0) {
648 Vmm inp = vmm_inp(0, nb_oc_block);
649 vmovups(inp, vmm_shift); // bcast(128)
650 }
651 } else {
652 for (int jj = _start; jj < _end; jj++) {
653 int aux_input_offset = input_offset(jj, ic, ki);
654 if (jj >= jj_start && jj < jj_end) {
655 if (last_ic_block_flag == last_sp_block
656 && ic_tail_size != 0 && ic == icb - 1) {
657 Xmm xmm_tmp = Xmm(
658 vmm_inp(jj, nb_oc_block).getIdx());
659 load_bytes(xmm_tmp, aux_reg_inp,
660 aux_input_offset, ic_tail_size);
661 vpbroadcastd(vmm_inp(jj, nb_oc_block), xmm_tmp);
662 } else {
663 vpbroadcastd(vmm_inp(jj, nb_oc_block),
664 EVEX_compress_addr(
665 aux_reg_inp, aux_input_offset));
666 }
667 if (jcp.signed_input)
668 vpaddb(vmm_inp(jj, nb_oc_block),
669 vmm_inp(jj, nb_oc_block), vmm_shift);
670 } else {
671 // fill padded area with shifted value in
672 // first iteration
673 if (jcp.signed_input && ic == 0) {
674 Vmm inp = vmm_inp(jj, nb_oc_block);
675 vmovups(inp, vmm_shift);
676 }
677 }
678 }
679 }
680 for (int ii = 0; ii < nb_oc_block; ii++) {
681 int aux_kernel_offset = kernel_offset(ii, ic, ki);
682 vmovups(vmm_wei,
683 EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
684 for (int jj = _start; jj < _end; jj++) {
685 Vmm inp = h_padded ? vmm_inp(0, nb_oc_block)
686 : vmm_inp(jj, nb_oc_block);
687 compute(vmm_out(jj, ii), vmm_wei, inp);
688 }
689 }
690 }
691 }
692 if (jcp.src_zero_point) {
693 /* calculate src_zero_point padding as:
694 * (is_padding ? src_zero_point_s32 * conv(1, wei_s8) : 0) */
695 Vmm vmm_tmp = vmm_inp(0, nb_oc_block);
696 for (int jj = 0; jj < ur_w; jj++) {
697 if (jj < jj_start || jj >= jj_end || h_padded) {
698 for (int ii = 0; ii < nb_oc_block; ii++) {
699 vpxord(vmm_zp_tmp, vmm_zp_tmp, vmm_zp_tmp);
700 for (int ic = 0; ic < icb; ic++) {
701 int aux_kernel_offset = kernel_offset(ii, ic, ki);
702 if (jcp.has_vnni) {
703 vpdpbusd(vmm_zp_tmp, vmm_zp_one,
704 EVEX_compress_addr(aux_reg_ker,
705 aux_kernel_offset));
706 } else {
707 vpmaddubsw(vmm_tmp, vmm_zp_one,
708 EVEX_compress_addr(aux_reg_ker,
709 aux_kernel_offset));
710 vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
711 vpaddd(vmm_zp_tmp, vmm_zp_tmp, vmm_tmp);
712 }
713 }
714 int zp_offset = 0;
715 vpmulld(vmm_zp_tmp, vmm_zp_tmp,
716 EVEX_compress_addr(reg_src_zero_point,
717 zp_offset, jcp.zp_src_is_common));
718 vpaddd(vmm_out(jj, ii), vmm_out(jj, ii), vmm_zp_tmp);
719 }
720 }
721 }
722 }
723 }
724
725 if (jcp.src_zero_point) pop(aux_reg_ker_d);
726}
727
728template <typename Vmm>
729void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::kh_loop(
730 int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) {
731 Label kd_label, kh_label, skip_kd_loop, skip_kh_loop;
732 Label f_overflow_label, no_f_overflow_label, d_h_f_overflow_label,
733 t_overflow_label, no_t_overflow_label, b_overflow_label,
734 no_b_overflow_label, back_overflow_label, no_back_overflow_label,
735 d_h_back_overflow_label;
736
737 int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
738 int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all;
739 int shift_input_ptr
740 = jcp.typesize_in * jcp.iw * jcp.ic_without_padding * jcp.ngroups;
741
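    // With signed input or a src zero-point, filter rows/planes that fall
    // entirely into padding still contribute (through the shifted-zero or
    // zero-point term), so the *_overflow loops below run compute_ker with
    // h_padded = true over those kernel positions instead of skipping them.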
742 if (jcp.ndims == 5) {
743 mov(aux_reg_ker_d, reg_ker);
744 mov(aux_reg_inp_d, reg_inp);
745 if (jcp.signed_input || jcp.src_zero_point) {
746 //TODO: May be avoided when f_pad=0 and dd0
747 //TODO: Potential optimization by precomputing, when kd <<< od?
748 mov(reg_ki, ptr[param1 + GET_OFF(f_overflow)]);
749 cmp(reg_ki, 0);
750 je(no_f_overflow_label, T_NEAR);
751 L(f_overflow_label);
752 {
753 mov(aux_reg_ker, aux_reg_ker_d);
754 mov(reg_kj, jcp.kh);
755 L(d_h_f_overflow_label);
756 {
757 compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
758 add(aux_reg_ker, shift_kernel_ptr);
759 dec(reg_kj);
760 jne(d_h_f_overflow_label);
761 }
762 add(aux_reg_ker_d, shift_kernel_ptr * jcp.kh);
763 dec(reg_ki);
764 jne(f_overflow_label);
765 }
766 L(no_f_overflow_label);
767 }
768
769 mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
770 if ((jcp.signed_input || jcp.src_zero_point) || (jcp.dilate_d >= jcp.id)
771 || (!(jcp.signed_input || jcp.src_zero_point)
772 && (jcp.kd - 1) * (jcp.dilate_d + 1)
773 < nstl::max(jcp.f_pad, jcp.back_pad))) {
774 cmp(reg_ki, 0);
775 je(skip_kd_loop, T_NEAR);
776 }
777 L(kd_label);
778 mov(aux_reg_inp, aux_reg_inp_d);
779 mov(aux_reg_ker, aux_reg_ker_d);
780 } else {
781 if (jcp.is_fused_conv) {
782 mov(aux_reg_inp_buffer_ptr, reg_inp_buffer_ptr);
783 } else {
784 mov(aux_reg_inp, reg_inp);
785 }
786 mov(aux_reg_ker, reg_ker);
787 }
788
789 if ((jcp.signed_input || jcp.src_zero_point) && jcp.ndims > 3) {
790 mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
791 cmp(reg_overflow, 0);
792 je(no_t_overflow_label, T_NEAR);
793 L(t_overflow_label);
794 {
795 compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
796
797 add(aux_reg_ker, shift_kernel_ptr);
798 dec(reg_overflow);
799 cmp(reg_overflow, 0);
800 jg(t_overflow_label, T_NEAR);
801 }
802 L(no_t_overflow_label);
803 }
804 mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
805 if (jcp.signed_input || jcp.src_zero_point || (jcp.dilate_h >= jcp.ih)
806 || (!(jcp.signed_input || jcp.src_zero_point)
807 && (jcp.kh - 1) * (jcp.dilate_h + 1)
808 < nstl::max(jcp.t_pad, jcp.b_pad))) {
809 cmp(reg_kj, 0);
810 je(skip_kh_loop, T_NEAR);
811 }
812 L(kh_label);
813 {
814 if (jcp.is_fused_conv) {
815 mov(aux_reg_inp, ptr[aux_reg_inp_buffer_ptr]);
816 add(aux_reg_inp, reg_inp);
817 }
818 compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);
819
820 add(aux_reg_ker, shift_kernel_ptr);
821 if (jcp.is_fused_conv) {
822 add(aux_reg_inp_buffer_ptr, sizeof(void *));
823 } else {
824 add(aux_reg_inp, shift_input_ptr * (jcp.dilate_h + 1));
825 }
826 dec(reg_kj);
827 cmp(reg_kj, 0);
828 jg(kh_label, T_NEAR);
829 }
830 L(skip_kh_loop);
831 if ((jcp.signed_input || jcp.src_zero_point) && jcp.ndims > 3) {
832 mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
833 cmp(reg_overflow, 0);
834 je(no_b_overflow_label, T_NEAR);
835 L(b_overflow_label);
836 {
837 compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
838
839 add(aux_reg_ker, shift_kernel_ptr);
840 dec(reg_overflow);
841 cmp(reg_overflow, 0);
842 jg(b_overflow_label, T_NEAR);
843 }
844 L(no_b_overflow_label);
845 }
846
847 if (jcp.ndims == 5) {
848 add(aux_reg_inp_d, shift_input_ptr * jcp.ih * (jcp.dilate_d + 1));
849 add(aux_reg_ker_d, shift_kernel_ptr * jcp.kh);
850 dec(reg_ki);
851 jne(kd_label, T_NEAR);
852
853 L(skip_kd_loop);
854 if (jcp.signed_input || jcp.src_zero_point) {
855 mov(reg_ki, ptr[param1 + GET_OFF(back_overflow)]);
856 cmp(reg_ki, 0);
857 je(no_back_overflow_label, T_NEAR);
858 L(back_overflow_label);
859 {
860 mov(aux_reg_ker, aux_reg_ker_d);
861 mov(reg_kj, jcp.kh);
862 L(d_h_back_overflow_label);
863 {
864 compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
865 add(aux_reg_ker, shift_kernel_ptr);
866 dec(reg_kj);
867 jne(d_h_back_overflow_label);
868 }
869 add(aux_reg_ker_d, shift_kernel_ptr * jcp.kh);
870 dec(reg_ki);
871 jne(back_overflow_label);
872 }
873 L(no_back_overflow_label);
874 }
875 }
876}
877
878template <typename Vmm>
879void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::icb_loop(
880 int ur_w, int pad_l, int pad_r, bool is_last_sp_block) {
881
882 if (jcp.src_zero_point && !jcp.is_depthwise) {
883 xor_(reg_scratch, reg_scratch);
884 Reg8 _t8 = reg_scratch.cvt8();
885 mov(_t8, 0x1);
886 vpbroadcastb(vmm_zp_one, _t8);
887 }
888
889 prepare_output(ur_w);
890
891 // IC loop
892 Label icb_label;
893 mov(reg_icb, jcp.nb_ic);
894 L(icb_label);
895 const bool do_icb_loop
896 = jcp.is_depthwise ? jcp.nb_ch > jcp.nb_ch_blocking : jcp.nb_ic > 1;
897 if (jcp.ngroups % jcp.ch_block != 0 || jcp.ic_without_padding != jcp.ic) {
898 Label common_ker, end_ker;
899 if (do_icb_loop) {
900 if (jcp.is_depthwise)
901 cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
902 else
903 cmp(reg_icb, 1); // The last IC block
904 jne(common_ker, T_NEAR);
905 }
906 kh_loop(ur_w, pad_l, pad_r,
907 is_last_sp_block ? last_sp_block : last_ic_block);
908 if (do_icb_loop) {
909 jmp(end_ker, T_NEAR);
910
911 L(common_ker);
912 kh_loop(ur_w, pad_l, pad_r, no_last_block);
913
914 L(end_ker);
915 }
916 } else {
917 kh_loop(ur_w, pad_l, pad_r, no_last_block);
918 }
919 // End of IC Loop
920 if (do_icb_loop) {
921 int inp_step = jcp.ic_block;
922 const size_t ker_step = (size_t)jcp.kd * jcp.kh * jcp.kw * jcp.oc_block
923 * jcp.ic_block;
924 add(reg_inp, jcp.typesize_in * inp_step);
925 safe_add(reg_ker, jcp.typesize_in * ker_step, reg_ker_long_offt);
926
927 dec(reg_icb);
928 cmp(reg_icb, 0);
929 jg(icb_label, T_NEAR);
930
931 sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic);
932 safe_sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic,
933 reg_ker_long_offt);
934 }
935
936 if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
937 Label common_store, end_store;
938
939 if (jcp.is_depthwise)
940 cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
941 else
942 cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
943
944 jne(common_store, T_NEAR);
945
946 store_output(ur_w, true); // last oc block
947 jmp(end_store, T_NEAR);
948
949 L(common_store);
950 store_output(ur_w, false);
951
952 L(end_store);
953 } else {
954 store_output(ur_w, false);
955 }
956}
957
958template <typename Vmm>
959void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::generate() {
960 Label permute_index_table;
961 int in_ic_shift = jcp.is_fused_conv ? jcp.dw_conv_buffer_oc
962 : jcp.ic_without_padding * jcp.ngroups;
963 const int urw_inp_stride = jcp.ur_w * jcp.stride_w;
964 const int n_urw_l_pad
965 = nstl::min(div_up(jcp.l_pad, urw_inp_stride), jcp.ow / jcp.ur_w);
966 const int inp_shift_pad = nstl::max(0,
967 jcp.typesize_in * (n_urw_l_pad * urw_inp_stride - jcp.l_pad)
968 * in_ic_shift);
969 int inp_shift = jcp.typesize_in * (jcp.ur_w * jcp.stride_w * in_ic_shift);
970 int out_shift = jcp.typesize_out
971 * (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups);
972 preamble();
973
974 if (jcp.is_depthwise) {
975 bool is_zero_point = jcp.src_zero_point || jcp.dst_zero_point;
976 // dst zero point and dst scale reuse the same register
977 int idx = jcp.max_regs_ur - 1
978 + nstl::max(2 * is_zero_point, static_cast<int>(jcp.dst_scale));
979 if (!jcp.is_resrc_depthwise) zmm_src = Zmm(++idx);
980 if (!jcp.has_vnni) zmm_tmp = Zmm(++idx);
981 if (jcp.is_fast_depthwise) zmm_permute = Zmm(++idx);
982 if (jcp.signed_input) zmm_shifted_zero = Zmm(++idx);
        // Due to an extra register being used for shifts and compensations
        // and/or saturation, we increment by one more.
985 if (jcp.signed_input || jcp.need_saturation) ++idx;
986
987 assert(IMPLICATION(!jcp.dst_scale && !is_zero_point
988 && jcp.dst_dt != data_type::bf16,
989 idx == ker_dw_reg_base_idx));
990 }
991 if (!jcp.is_depthwise && (!jcp.has_vnni)) {
992 xor_(reg_scratch, reg_scratch);
993 Reg16 _t16 = reg_scratch.cvt16();
994 mov(_t16, 0x1);
995 vpbroadcastw(vmm_one, _t16);
996 }
997 if (jcp.is_fused_conv) {
998 mov(reg_inp_buffer_ptr, ptr[param1 + GET_OFF(src)]);
999 /* In case of fused depthwise convolution, `param.src` is not a pointer
1000 to input, instead it points to a buffer containing pointers to
1001 consecutive rows of input in format wc with c=jcp.dw_conv_buffer_oc.
1002 Example: [ptr_to_inp_row0, ptr_to_inp_row1, ptr_to_inp_row2].
1003 Traverse the data as
1004 mov(reg_data, ptr[reg_input_buffer_ptr])
1005 ... process row0 ...
1006 add(reg_input_buffer_ptr, sizeof(void*))
1007 mov(reg_data, ptr[reg_input_buffer_ptr])
1008 ... process row1 ...
1009 add(reg_input_buffer_ptr, sizeof(void*))
1010 mov(reg_data, ptr[reg_input_buffer_ptr])
1011 ... process row2 ...
1012 */
1013 xor_(reg_inp, reg_inp);
1014 } else {
1015 mov(reg_inp, ptr[param1 + GET_OFF(src)]);
1016 }
1017 mov(reg_out, ptr[param1 + GET_OFF(dst)]);
1018 mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
1019
1020 if (jcp.simd_w == 4 && jcp.dst_dt == data_type::bf16) {
1021 auto reg_tail_32 = reg_oi.cvt32();
1022 mov(reg_tail_32, (1 << jcp.simd_w) - 1);
1023 kmovb(ktail_mask, reg_tail_32);
1024 }
1025 if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
1026 int tail_size = jcp.is_depthwise
1027 ? jcp.ngroups % jcp.ch_block
1028 : jcp.oc_without_padding % jcp.oc_block;
1029 int mask = (1 << tail_size) - 1;
1030 mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
1031 Reg32 regw_tmp = reg_oi.cvt32();
1032 mov(regw_tmp, mask);
1033 kmovw(ktail_mask, regw_tmp);
1034 kmovw(postops_mask, regw_tmp);
1035
        // To account for the special store optimization, where two oc_blocks
        // are combined into a single write, extend the mask to 32 bits
        // (32 bf16 values)
1038 const int nb_block
1039 = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
1040 const bool need_extended_mask = jcp.dst_dt == data_type::bf16
1041 && isa_has_bf16(jcp.isa) && nb_block > 1;
1042 if (need_extended_mask) {
1043 mov(regw_tmp, (1 << (tail_size + jcp.simd_w)) - 1);
1044 kmovd(ktail_mask_extended, regw_tmp);
1045 }
1046 } else if (jcp.with_binary)
1047 if (jcp.oc_block != isa_simd_width_) {
1048 const int mask = (1 << jcp.oc_block) - 1;
1049 const Reg32 regw_tmp = reg_oi.cvt32();
1050 mov(regw_tmp, mask);
1051 kmovw(postops_mask, regw_tmp);
1052 }
1053 if (jcp.is_fast_depthwise) {
1054 // prepare mask register for blending weights
1055 mov(reg_scratch, 0x8888444422221111);
1056 kmovq(kblend_mask, reg_scratch);
1057 // load permute indices from data section
1058 mov(reg_scratch, permute_index_table);
1059 vmovdqu32(zmm_permute, ptr[reg_scratch]);
1060 }
1061 const int extended_filter_size
1062 = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
1063 const int r_pad = nstl::max(0, jcp.r_pad);
1064 const int ow_with_no_rpad = 1
1065 + (jcp.iw + jcp.l_pad + nstl::min(0, jcp.r_pad)
1066 - extended_filter_size)
1067 / jcp.stride_w;
1068 const int n_urw_per_ow_block = jcp.ow_block / jcp.ur_w;
1069 const int max_safe_iw = nstl::max(
1070 0, jcp.iw - div_up(ic_sub_step, jcp.ic_without_padding));
1071 const int max_safe_ow = jcp.ic_without_padding % ic_sub_step == 0
1072 ? jcp.ow
1073 : (max_safe_iw + jcp.l_pad - extended_filter_size) / jcp.stride_w;
1074 Label middle_block_label, done_compute;
1075 std::vector<Label> ow_block_jmp_table;
1076
1077 // r_pad_fall_through is a special ow_block, where the block overlaps
1078 // both middle_block and r_pad/ur_w_tail region when it exists.
1079 // The number of ur_w's to compute in middle_block before executing
1080 // r_pad region is stored in r_pad_fall_through_n_urw and the ow_block
1081 // number is stored in r_pad_fall_through_ow_block.
1082 int r_pad_fall_through_ow_block = 0;
1083 int r_pad_fall_through_n_urw = 0;
1084
1085 if (jcp.nb_ow > 1) {
        // Only one ow block is processed per jit call. The index of this ow
        // block is passed as the owb parameter, and the padding handling
        // depends on it.
1089 //
1090 // The compute block to run is determined by using a jmp-table.
1091 // jmp-table Layout:
1092 // idx -> addr
1093 // 0 -> [...l_pad_region label[0]...]
1094 // : : : : : : : : : : : : : : :
1095 // L -> [...l_pad_region label[L]...]
1096 // L+1 -> [...r_pad_region label[0]...]
1097 // : : : : : : : : : : : : : : :
1098 // L+R -> [...r_pad_region label[R]...]
1099 //
1100 // Note: Label for middle_block is not stored in the jmp-table.
1101 //
1102 // During jit call, the jump address is calculated as below:
1103 // if (owb < L) {
1104 // jmp([jmp_table + owb*sizeof(void*)]);
1105 // } else if (owb < X) {
1106 // // X is the number of ow_blocks before r_pad region (see below).
1107 // jmp(middle_block);
1108 // } else {
1109 // sub(owb, X);
1110 // jmp([jmp_table + owb*sizeof(void*) + L*sizeof(void)]);
1111 // }
1112 //
        // To configure the jmp-table, a few constants (namely
        // r_pad_fall_through_n_urw, r_pad_fall_through_ow_block,
        // n_l_pad_labels, n_labels) must be known before the compute assembly
        // is written. So, we first simulate the filter path without emitting
        // any assembly, which keeps the math for these constants simple and
        // self-explanatory.
1119
1120 // Begin simulation without writing assembly
1121 int n_l_pad_labels = 0;
1122 int n_labels = 0;
1123 int cur_ow = 0;
1124
1125 // l_pad region:
1126 n_l_pad_labels = div_up(n_urw_l_pad, n_urw_per_ow_block);
1127 n_labels = n_l_pad_labels;
1128 cur_ow += n_urw_l_pad * jcp.ur_w;
1129
1130 // middle_region:
1131 int n_urw_middle_block_loop = 0;
1132 int cur_r_pad = nstl::max(0,
1133 calculate_end_padding(jcp.l_pad, cur_ow + jcp.ur_w, jcp.iw,
1134 jcp.stride_w, extended_filter_size));
1135 if (cur_ow + jcp.ur_w <= jcp.ow && cur_r_pad == 0) {
1136 n_urw_middle_block_loop
1137 = nstl::max(0,
1138 nstl::min(ow_with_no_rpad, max_safe_ow) - cur_ow)
1139 / jcp.ur_w;
1140 cur_ow += n_urw_middle_block_loop * jcp.ur_w;
1141 }
1142 r_pad_fall_through_n_urw = (cur_ow / jcp.ur_w) % n_urw_per_ow_block;
1143 r_pad_fall_through_ow_block = cur_ow / (n_urw_per_ow_block * jcp.ur_w);
1144
1145 // r_pad or last_sp_block
1146 if (cur_ow + jcp.ur_w <= jcp.ow) {
1147 if (r_pad_fall_through_n_urw == 0) ++n_labels;
1148 const int n_urw_r_pad_region = (jcp.ow - cur_ow) / jcp.ur_w;
1149 n_labels += nstl::max(0,
1150 div_up(r_pad_fall_through_n_urw + n_urw_r_pad_region,
1151 n_urw_per_ow_block)
1152 - 1);
1153 }
1154
1155 if (jcp.ur_w_tail != 0) {
1156 if (jcp.ow % jcp.ow_block == jcp.ur_w_tail) ++n_labels;
1157 }
1158 // End of simulation
1159
1160 ow_block_jmp_table.resize(n_labels);
1161
1162 // Begin jump-table logic
1163 Label ow_block_jmp_table_label;
1164 if (!ow_block_jmp_table.empty())
1165 mov(reg_jmp_tbl_base, ow_block_jmp_table_label);
1166 mov(reg_oi, n_urw_per_ow_block);
1167 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
1168 if (jcp.l_pad > 0) {
1169 Label middle_or_rpad_check;
1170 cmp(reg_owb, n_l_pad_labels);
1171 jge(middle_or_rpad_check, T_NEAR);
1172 jmp(ptr[reg_jmp_tbl_base + reg_owb * sizeof(void *)]);
1173 L(middle_or_rpad_check);
1174 // harness passes shifted src pointer that does not take
1175 // left-padding into account. So, we must re-shift here.
1176 const int inp_shift_pad_middle_block = -1 * jcp.typesize_in
1177 * nstl::min(jcp.l_pad, n_urw_l_pad * urw_inp_stride)
1178 * in_ic_shift;
1179 add(reg_inp, inp_shift_pad_middle_block);
1180 }
1181 if (r_pad_fall_through_n_urw != 0) {
1182 mov(reg_scratch, r_pad_fall_through_n_urw);
1183 cmp(reg_owb, r_pad_fall_through_ow_block);
1184 cmove(reg_oi, reg_scratch);
1185 if (n_urw_middle_block_loop > 0) {
1186 sub(reg_owb, r_pad_fall_through_ow_block);
1187 // simple middle_block
1188 jle(middle_block_label, T_NEAR);
1189 dec(reg_owb);
1190 } else {
1191 sub(reg_owb, r_pad_fall_through_ow_block + 1);
1192 }
1193 } else {
1194 sub(reg_owb, r_pad_fall_through_ow_block);
1195 // simple middle_block
1196 if (n_urw_middle_block_loop) jl(middle_block_label, T_NEAR);
1197 }
1198 // r_pad-only region
1199 if (!ow_block_jmp_table.empty())
1200 jmp(ptr[reg_jmp_tbl_base + reg_owb * sizeof(void *)
1201 + n_l_pad_labels * sizeof(void *)]);
1202
1203 if (!ow_block_jmp_table.empty()) {
1204 align(8);
1205 L(ow_block_jmp_table_label);
1206 {
1207 for (size_t i = 0; i < ow_block_jmp_table.size(); ++i) {
1208 putL(ow_block_jmp_table[i]);
1209 }
1210 }
1211 }
1212 // End of jump-table logic
1213 }
1214
1215 // Begin kernel
1216 int cur_ow = 0;
1217 int cur_n_oi = 0; // used only for jcp.nb_ow > 1 scenario
1218 int label_cntr = 0;
1219 int cur_l_pad = 0;
1220 if (jcp.l_pad > 0) {
1221 for (cur_l_pad = jcp.l_pad;
1222 cur_l_pad > 0 && cur_ow + jcp.ur_w <= jcp.ow;
1223 cur_l_pad -= urw_inp_stride) {
1224 if (jcp.nb_ow > 1 && cur_n_oi == 0) {
1225 // cur_n_oi == 0 signifies beginning of new ow_block
1226 // (or end of previous block)
1227 const dim_t inp_lpad_region_shift = -label_cntr * jcp.ow_block
1228 * jcp.stride_w * in_ic_shift;
1229 L(ow_block_jmp_table[label_cntr++]);
1230 // harness passes shifted src pointer that does not take
1231 // left-padding into account. So, we must re-shift here.
1232 add(reg_inp, inp_lpad_region_shift);
1233 }
1234
1235 cur_ow += jcp.ur_w;
1236 int cur_r_pad = nstl::max(0,
1237 calculate_end_padding(jcp.l_pad, cur_ow, jcp.iw,
1238 jcp.stride_w, extended_filter_size));
1239 icb_loop(jcp.ur_w, cur_l_pad, cur_r_pad, cur_ow > max_safe_ow);
1240 add(reg_out, out_shift);
1241 dec(reg_oi);
1242
1243 if (jcp.nb_ow > 1 && ++cur_n_oi == n_urw_per_ow_block) {
1244 // We compute one owb per jit call. So, insert an
1245 // unconditional jmp, after computing one owb.
1246 jmp(done_compute, T_NEAR);
1247 cur_n_oi = 0;
1248 }
1249 }
1250 if (jcp.nb_ow == 1 || cur_n_oi != 0) {
1251 // Let it "fall-through" middle_block_label
1252 add(reg_inp, inp_shift_pad);
1253 }
1254 }
1255
1256 // middle_block
1257 {
1258 int cur_r_pad = nstl::max(0,
1259 calculate_end_padding(jcp.l_pad, cur_ow + jcp.ur_w, jcp.iw,
1260 jcp.stride_w, extended_filter_size));
1261 if (cur_r_pad == 0 && cur_ow + jcp.ur_w <= jcp.ow) {
1262 int n_oi_middle_block_loop
1263 = nstl::max(0,
1264 nstl::min(ow_with_no_rpad, max_safe_ow) - cur_ow)
1265 / jcp.ur_w;
1266 if (jcp.nb_ow == 1 && n_oi_middle_block_loop > 1)
1267 mov(reg_oi, n_oi_middle_block_loop);
1268 L(middle_block_label);
1269 if (n_oi_middle_block_loop > 0) {
1270 icb_loop(jcp.ur_w, 0, 0, false);
1271 add(reg_inp, inp_shift);
1272 add(reg_out, out_shift);
1273 if (n_oi_middle_block_loop > 1) {
1274 dec(reg_oi);
1275 jg(middle_block_label, T_NEAR);
1276 }
1277 }
1278 cur_ow += n_oi_middle_block_loop * jcp.ur_w;
1279 cur_n_oi = (cur_n_oi + n_oi_middle_block_loop) % n_urw_per_ow_block;
1280 }
1281 }
1282
1283 // r_pad region or last_sp_block
1284 if (cur_ow + jcp.ur_w <= jcp.ow) {
1285 if (jcp.nb_ow > 1) {
1286 if (cur_n_oi == 0) {
1287 jmp(done_compute, T_NEAR);
1288 } else {
1289 // r_pad fall-through
1290 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
1291 cmp(reg_owb, r_pad_fall_through_ow_block);
1292 jne(done_compute, T_NEAR);
1293 }
1294 }
1295
1296 while (cur_ow + jcp.ur_w <= jcp.ow) {
1297 if (jcp.nb_ow > 1 && cur_n_oi == 0) {
1298 L(ow_block_jmp_table[label_cntr++]);
1299 }
1300 cur_ow += jcp.ur_w;
1301 int cur_r_pad = calculate_end_padding(jcp.l_pad, cur_ow, jcp.iw,
1302 jcp.stride_w, extended_filter_size);
1303 assert(cur_r_pad > 0 || cur_ow > max_safe_ow); // else, why be here?
1304 icb_loop(jcp.ur_w, 0, cur_r_pad, cur_ow > max_safe_ow);
1305 add(reg_inp, inp_shift);
1306 add(reg_out, out_shift);
1307
1308 if (jcp.nb_ow > 1 && ++cur_n_oi == n_urw_per_ow_block) {
1309 // We compute one owb per jit call. So, insert an
1310 // unconditional jmp, after computing one owb.
1311 jmp(done_compute, T_NEAR);
1312 cur_n_oi = 0;
1313 }
1314 }
1315 // Let it fall-through ur_w_tail
1316 }
1317
1318 // ur_w_tail
1319 if (jcp.ur_w_tail != 0) {
1320 if (jcp.nb_ow > 1) {
1321 if (cur_n_oi == 0) {
1322 jmp(done_compute, T_NEAR);
1323 L(ow_block_jmp_table[label_cntr++]);
1324 } else {
                // When there is no r_pad region, there is an ambiguity
                // between middle_blocks and r_pad_fall_through_ow_block.
                // If they are not properly distinguished, both may end up
                // computing the ur_w_tail work at the end.
1330 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
1331 cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
1332 jne(done_compute, T_NEAR);
1333 }
1334 }
1335 icb_loop(jcp.ur_w_tail, nstl::max(0, cur_l_pad), r_pad, true);
1336 }
1337 L(done_compute);
1338 assert(ow_block_jmp_table.size() == static_cast<size_t>(label_cntr));
1339
1340 postamble();
1341
1342 if (jcp.with_eltwise) postops_injector_->prepare_table();
1343
1344 if (jcp.is_fast_depthwise) {
1345 align(64);
1346 L(permute_index_table);
1347 const uint32_t _idx[]
1348 = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
1349 for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
1350 dd(_idx[i]);
1351 }
1352}
1353
1354status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
1355 const convolution_desc_t &cd, memory_desc_t &src_md,
1356 memory_desc_t &weights_md, memory_desc_t &dst_md,
1357 memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
1358 using namespace prop_kind;
1359
1360 const memory_desc_wrapper src_d(&src_md);
1361 const memory_desc_wrapper weights_d(&weights_md);
1362 const memory_desc_wrapper dst_d(&dst_md);
1363 const memory_desc_wrapper bias_d(&bias_md);
1364
1365 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
1366 const int ndims = src_d.ndims();
1367 const bool is_1d = ndims == 3;
1368 const bool is_2d = ndims == 4;
1369 const bool is_3d = ndims == 5;
1370 assert(is_1d || is_2d || is_3d);
1371
1372 if (!(mayiuse(avx512_core)
1373 && one_of(src_d.data_type(), data_type::u8, data_type::s8)
1374 && weights_d.data_type() == data_type::s8
1375 && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
1376 data_type::s8, data_type::u8, data_type::bf16)))
1377 return status::unimplemented;
1378
1379 jcp = zero<decltype(jcp)>();
1380 jcp.nthr = nthreads;
1381 jcp.ndims = ndims;
1382 jcp.prop_kind = cd.prop_kind;
1383 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1384 jcp.mb = src_d.dims()[0];
1385 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
1386 jcp.oc_without_padding = jcp.oc;
1387 jcp.ic = src_d.dims()[1] / jcp.ngroups;
1388 jcp.ic_without_padding = jcp.ic;
1389 jcp.id = is_3d ? src_d.dims()[2] : 1;
1390 jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
1391 jcp.iw = src_d.dims()[ndims - 1];
1392 jcp.od = is_3d ? dst_d.dims()[2] : 1;
1393 jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
1394 jcp.ow = dst_d.dims()[ndims - 1];
1395 jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
1396 jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
1397 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
1398 jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
1399 jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
1400 jcp.l_pad = cd.padding[0][ndims - 3];
1401 jcp.stride_d = is_3d ? cd.strides[0] : 1;
1402 jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
1403 jcp.stride_w = cd.strides[ndims - 3];
1404 jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
1405
1406 jcp.ur_h = 1; /* no code-unrolling by h so far */
1407 jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
1408 jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
1409 jcp.dilate_w = cd.dilates[ndims - 3];
1410
1411 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
1412 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
1413 int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
1414 jcp.r_pad = calculate_end_padding(
1415 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
1416 jcp.b_pad = calculate_end_padding(
1417 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
1418 jcp.back_pad = calculate_end_padding(
1419 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
1420
    jcp.signed_input = src_d.data_type() == data_type::s8;
1422 jcp.need_saturation = utils::one_of(dst_d.data_type(), u8, s8, s32);
1423 jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
1424
1425 // Used for bfloat16 output
1426 jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16
1427 : bf16_emulation_t::get_isa();
1428
1429 if (jcp.is_depthwise && is_3d)
1430 // NOTE: 3D depthwise is not currently supported here.
1431 return status::unimplemented;
1432
1433 if (jcp.is_depthwise) {
1434 jcp.ch_block = 16;
1435 jcp.ic_block = 1;
1436 jcp.oc_block = 1;
1437 } else {
1438 jcp.ch_block = 1;
1439 jcp.ic_block = 16;
1440 jcp.oc_block = 16;
1441
1442 if (jcp.ngroups == 1) {
            /* For non-grouped convolutions, pad channels up to a multiple
               of 16 if needed */
1444 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
1445 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
1446 } else if (jcp.ngroups != 1
1447 && ((jcp.ic % jcp.ic_block != 0)
1448 || (jcp.oc % jcp.oc_block != 0))) {
            /* For grouped convolutions, oneDNN doesn't support padding.
               When the number of channels per group is not a multiple of 16:
               - use Ymm when it is a multiple of 8,
               - use Xmm when it is a multiple of 4,
               - otherwise return unimplemented. */
1454 jcp.ic_block = (jcp.ic % 8 == 0) && (jcp.oc % 8 == 0) ? 8 : 4;
1455 jcp.oc_block = jcp.ic_block;
1456 }
1457 if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0)
1458 return status::unimplemented;
1459 }
1460
1461 jcp.simd_w = jcp.is_depthwise ? jcp.ch_block : jcp.ic_block;
1462
1463 const auto zp = attr.zero_points_;
1464 jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST);
1465 jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC);
1466 jcp.zp_src_is_common
1467 = zp.common(DNNL_ARG_SRC); // otherwise, it's per-channel
1468 assert(IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common));
1469
1470 if ((jcp.dst_zero_point || jcp.src_zero_point) && jcp.is_fused_conv)
1471 return status::unimplemented;
1472
1473 const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
1474 const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
1475 const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
1476 const int wei_mask_per_oc = 1 << (int)with_groups;
1477 jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc;
1478 jcp.dst_scale = !dst_scales.has_default_values();
1479
1480 // only common src & dst scales are supported
1481 // only common and per-oc-channel weight scales are supported
1482 const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc)
1483 && everyone_is(src_scales.mask_, dst_scales.mask_, 0);
1484 if (!scales_ok) return status::unimplemented;
1485
1486 jcp.has_vnni = mayiuse(avx512_core_vnni);
1487 const bool bf16_req_extra_regs = cd.dst_desc.data_type == data_type::bf16
1488 && !isa_has_bf16(jcp.isa);
1489 jcp.is_fast_depthwise = true && jcp.is_depthwise && jcp.has_vnni
1490 && jcp.ngroups % jcp.ch_block == 0
1491 && !bf16_req_extra_regs; /* groups not multiple of
1492 ch_block (= 16) would require byte masking for load from src */
1493
1494 jcp.is_resrc_depthwise = jcp.is_depthwise && jcp.stride_w < jcp.kw
1495 && jcp.kw < 4 && jcp.dilate_w == 0;
1496
1497 if (jcp.is_depthwise) {
1498 jcp.max_regs_ur = 31 - jcp.is_fast_depthwise - !jcp.is_resrc_depthwise
1499 - jcp.signed_input - (!jcp.has_vnni)
1500 - (jcp.signed_input || jcp.need_saturation) // both alias
1501 - (bf16_req_extra_regs ? 4 : 0);
1502 } else {
1503 jcp.max_regs_ur = bf16_req_extra_regs ? 26 : jcp.has_vnni ? 31 : 28;
1504 }
1505
1506 // TODO: re-implement so that the JIT Kernel uses the least amount of
1507 // registers. Currently, there are issues because of compile and run time
1508 // definitions.
1509 if (jcp.dst_scale) jcp.max_regs_ur = 26;
1510 if (jcp.src_zero_point || jcp.dst_zero_point) jcp.max_regs_ur = 25;
1511
1512 auto set_or_check_wei_format = [&]() {
1513 using namespace format_tag;
1514 using namespace memory_extra_flags;
1515 format_tag_t wei_tag;
1516 if (jcp.ic_block == 16 || jcp.ch_block == 16) {
1517 if (is_3d) {
1518 wei_tag = with_groups ? gOIdhw4i16o4i : OIdhw4i16o4i;
1519 } else if (is_1d) {
1520 wei_tag = with_groups ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i
1521 : OIw4i16o4i;
1522 } else {
1523 assert(is_2d);
1524 wei_tag = with_groups
1525 ? jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i
1526 : OIhw4i16o4i;
1527 }
1528 } else if (jcp.ic_block == 8) {
1529 assert(with_groups);
1530 wei_tag = is_3d ? gOIdhw2i8o4i : is_2d ? gOIhw2i8o4i : gOIw2i8o4i;
1531 } else {
1532 assert(with_groups && jcp.ic_block == 4);
1533 wei_tag = is_3d ? gOIdhw4o4i : is_2d ? gOIhw4o4i : gOIw4o4i;
1534 }
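        // The *4i16o4i-style layouts keep a small group of input channels
        // (ic_sub_step, presumably 4) innermost so that one 32-bit weight
        // load feeds a single int8 dot-product step, with 16 output channels
        // per block matching the zmm width; depthwise uses the plain 16g
        // layouts.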
1535
1536 memory_desc_t want_wei_md = weights_md;
1537 memory_desc_init_by_tag(want_wei_md, wei_tag);
1538 if (jcp.signed_input) {
1539 want_wei_md.extra.flags = 0 | compensation_conv_s8s8 | scale_adjust;
1540 want_wei_md.extra.compensation_mask = (1 << 0)
1541 + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
1542 want_wei_md.extra.scale_adjust
1543 = mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
1544 }
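        // scale_adjust = 0.5 on non-VNNI systems: the emulated u8 x s8 path
        // (vpmaddubsw) can saturate in s16, so the weights are halved at
        // reorder time and the factor is undone later via wei_adj_scale.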
1545 if (jcp.src_zero_point) {
1546 want_wei_md.extra.flags |= compensation_conv_asymmetric_src;
1547 want_wei_md.extra.asymm_compensation_mask = (1 << 0)
1548 + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
1549 }
1550
1551 if (weights_md.format_kind == format_kind::any) {
1552 weights_md = want_wei_md;
1553 return true;
1554 }
1555
1556 return weights_md == want_wei_md;
1557 };
1558
1559 if (!set_or_check_wei_format()) return status::unimplemented;
1560
1561 format_tag_t dat_tag = utils::pick(
1562 ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
1563
1564 if (src_d.format_kind() == format_kind::any) {
1565 CHECK(memory_desc_init_by_tag(src_md, dat_tag));
1566 jcp.src_tag = dat_tag;
1567 } else {
1568 jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
1569 }
1570 if (jcp.src_tag != dat_tag) return status::unimplemented;
1571
1572 if (dst_d.format_kind() == format_kind::any) {
1573 CHECK(memory_desc_init_by_tag(dst_md, dat_tag));
1574 jcp.dst_tag = dat_tag;
1575 } else {
1576 jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
1577 }
1578 if (jcp.dst_tag != dat_tag) return status::unimplemented;
1579
1580 if (jcp.with_bias) {
1581 if (bias_d.format_kind() == format_kind::any)
1582 CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
1583 }
1584 jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
1585 jcp.dst_dt = cd.dst_desc.data_type;
1586
1587 CHECK(attr.set_default_formats(&dst_md));
1588
1589 const auto &post_ops = attr.post_ops_;
1590 const int eltwise_ind = post_ops.find(primitive_kind::eltwise);
1591 jcp.with_eltwise = eltwise_ind != -1;
1592
1593 const int binary_ind = post_ops.find(primitive_kind::binary);
1594 jcp.with_binary = binary_ind != -1;
1595
1596 const int sum_ind = post_ops.find(primitive_kind::sum);
1597 jcp.with_sum = sum_ind != -1;
1598 jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt);
1599
1600 jcp.post_ops = post_ops;
1601
1602 using namespace injector;
1603 static constexpr bool sum_at_pos_0_only = false;
1604 static constexpr bool sum_requires_scale_one = false;
1605 static constexpr bool sum_requires_zp_zero = false;
1606 const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
1607 jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
1608 sum_requires_zp_zero});
1609 if (!post_ops_ok_) return status::unimplemented;
1610
1611 jcp.typesize_in = types::data_type_size(src_d.data_type());
1612 jcp.typesize_out = types::data_type_size(dst_d.data_type());
1613 jcp.typesize_bia
1614 = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
1615
1616 jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
1617 jcp.nb_ic = jcp.ic / jcp.ic_block;
1618 jcp.nb_oc = jcp.oc / jcp.oc_block;
1619
1620 // Try to use 4 channel-groups at a time to avoid false sharing (depthwise)
1621 int nb_ch_blocking = 4;
1622 for (/* init above */; nb_ch_blocking > 1; nb_ch_blocking--)
1623 if (jcp.nb_ch % nb_ch_blocking == 0) break;
1624 jcp.nb_ch_blocking = jcp.is_depthwise ? nb_ch_blocking : 1;
1625
1626 // If OC blocking is incommensurate with the number of OC blocks (general
1627 // requirement for all convolutions), or if it results in an unrolling
1628 // factor smaller than the left padding (special requirement for SSD:fc6),
1629 // then search for a smaller OC blocking that satisfies both constraints.
1630 auto is_oc_blocking_ok = [&](int block) {
1631 int ur_w = nstl::min(jcp.ow, jcp.max_regs_ur / (block + 1));
1632 return jcp.nb_oc % block == 0 && jcp.l_pad <= ur_w
1633 && jcp.ow % ur_w != 1;
1634 };
1635
1636 // choose nb_oc work chunk size for distribution within threads
1637 int max_threading_nb_oc_chunk = 4;
1638 // Performance improvements for googlenet_v3 and resnet_50 with mb = 1;
1639 // TODO: generalize this condition and rewrite it in appropriate manner
1640 int ncores_per_socket = (int)cpu().getNumCores(
1641 Xbyak::util::IntelCpuTopologyLevel::CoreLevel);
1642 if (jcp.has_vnni && jcp.mb == 1 && jcp.kh == 3 && jcp.kw == 3
1643 && jcp.stride_w == 1 && jcp.ic % 64 == 0
1644 && jcp.nthr <= ncores_per_socket)
1645 max_threading_nb_oc_chunk = 2;
1646 jcp.nb_oc_blocking_thr_chunk
1647 = nstl::min(max_threading_nb_oc_chunk, jcp.nb_oc);
1648 for (; jcp.nb_oc_blocking_thr_chunk > 1; jcp.nb_oc_blocking_thr_chunk--) {
1649 if (is_oc_blocking_ok(jcp.nb_oc_blocking_thr_chunk)) break;
1650 }
1651
1652 // choose oc blocking for computational kernel
1653 jcp.nb_oc_blocking = jcp.nb_oc_blocking_thr_chunk;
1654
1655 // Performance improvements for googlenet_v3 with mb = 1;
1656 // TODO: generalize this condition and rewrite it in appropriate manner
    const int size_threshold_for_nb_oc_blocking_reduction = 17;
    if (jcp.mb == 1 && jcp.ow <= size_threshold_for_nb_oc_blocking_reduction
1659 && jcp.stride_w == 1 && jcp.nthr <= ncores_per_socket
1660 && !(jcp.kh == 1 && jcp.kw == 3)
1661 && !(jcp.kh >= 7 && jcp.oc % 64 == 0)) {
1662 const int max_nb_oc_blocking = 2;
1663 jcp.nb_oc_blocking = nstl::min(max_nb_oc_blocking, jcp.nb_oc);
1664 for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
1665 if (jcp.nb_oc_blocking_thr_chunk % jcp.nb_oc_blocking == 0
1666 && is_oc_blocking_ok(jcp.nb_oc_blocking))
1667 break;
1668 }
1669
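    // Register budget for the compute loop: roughly ur_w accumulators per
    // oc (or ch) block plus one broadcast source register per ur_w point,
    // hence the "+ 1" in the non-depthwise divisor; the re-sourced depthwise
    // variant keeps kw-reusable inputs live instead, which changes the
    // formula.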
1670 if (jcp.is_resrc_depthwise)
1671 jcp.ur_w = (jcp.max_regs_ur - jcp.kw + jcp.stride_w)
1672 / (jcp.nb_ch_blocking + jcp.stride_w);
1673 else
1674 jcp.ur_w = jcp.max_regs_ur
1675 / (jcp.is_depthwise ? jcp.nb_ch_blocking
1676 : jcp.nb_oc_blocking + 1);
1677 if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
1678 jcp.ur_w_tail = jcp.ow % jcp.ur_w;
1679
1680 auto get_thr_eff = [=](int nb_ow, int nthr) {
1681 int base_work_amount = jcp.mb * jcp.nb_ch * jcp.od * jcp.oh
1682 * (jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk);
1683 auto work_amount = base_work_amount * nb_ow;
1684 return float(work_amount) / rnd_up(work_amount, nthr);
1685 };
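    // thr_eff is a load-balance ratio: 1.0 means the ow-block decomposition
    // divides the work evenly across nthr threads; smaller values mean some
    // threads sit idle in the last round.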
1686
1687 auto get_ow_block = [=](int ur_w, int nthr) {
1688 int res_ow_block = jcp.ow;
1689 float best_thr_eff = get_thr_eff(1, nthr);
1690 float thr_eff;
1691 int max_nb_ow = div_up(jcp.ow, ur_w);
1692 for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) {
1693 int ow_block
1694 = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow);
1695 if (ow_block < jcp.nb_oc_blocking_thr_chunk * jcp.oc_block
1696 && best_thr_eff > 0.8f)
1697 break;
1698 if (div_up(jcp.ow, ow_block) != nb_ow) continue;
1699 thr_eff = get_thr_eff(nb_ow, nthr);
1700 if (ow_block >= ur_w && thr_eff > 1.1f * best_thr_eff) {
1701 res_ow_block = ow_block;
1702 best_thr_eff = thr_eff;
1703 }
1704 if (best_thr_eff > 0.9f) break;
1705 }
1706 return res_ow_block;
1707 };
1708
1709 jcp.ow_block = get_ow_block(jcp.ur_w, jcp.nthr);
1710 jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
1711 float thr_eff = get_thr_eff(jcp.nb_ow, jcp.nthr);
1712
    /* Adjust the thread decomposition to improve thr_eff for small problem
     * sizes; the L1_cache_size threshold is empirical. */
1716 size_t wei_size
1717 = sizeof(float) * jcp.ic * jcp.oc * jcp.kh * jcp.kw * jcp.kd;
1718 size_t out_size
1719 = jcp.mb * jcp.typesize_out * jcp.oc * jcp.oh * jcp.ow * jcp.od;
1720 size_t inp_size
1721 = jcp.mb * jcp.typesize_in * jcp.ic * jcp.ih * jcp.iw * jcp.id;
1722 size_t total_size = jcp.ngroups * (wei_size + out_size + inp_size);
1723 const unsigned int L1_cache_size = platform::get_per_core_cache_size(1);
1724
1725 if (thr_eff < 0.9f && jcp.ngroups < jcp.nthr
1726 && (total_size < L1_cache_size)) {
1727 int ow_block = jcp.ow_block;
1728 float best_thr_eff = -1.0f;
1729 float eff = -1.0f;
1730 int end_nthr = with_groups ? jcp.ngroups : 1;
1731 for (int nthr = jcp.nthr / 2; nthr > end_nthr; nthr--) {
1732 ow_block = get_ow_block(jcp.ur_w, nthr);
1733 eff = get_thr_eff(div_up(jcp.ow, ow_block), nthr);
1734 if (eff > 1.1f * best_thr_eff) {
1735 best_thr_eff = eff;
1736 jcp.ow_block = ow_block;
1737 jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
1738 jcp.nthr = jcp.aligned_threads = nthr;
1739 if (best_thr_eff > 0.9f) break;
1740 }
1741 }
1742 }
1743
1744 if (jcp.oc % jcp.oc_block != 0) return status::unimplemented;
1745
1746 pick_loop_order(jcp, jcp.nthr);
1747
1748 jcp.nb_ic_L2 = jcp.nb_ic;
1749
1750 jcp.wei_adj_scale
1751 = (weights_d.extra().flags & memory_extra_flags::scale_adjust)
1752 ? weights_d.extra().scale_adjust
1753 : 1.f;
1754
1755 return status::success;
1756}
1757
1758void jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad(
1759 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
1760 const primitive_attr_t &attr) {
1761 const int wei_mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_;
1762 const dim_t scales_count = wei_mask == 0 ? 1 : jcp.oc * jcp.ngroups;
1763 dim_t count = wei_mask == 0 ? (dim_t)16 : scales_count;
1764 scratchpad.book<float>(key_conv_adjusted_scales, count);
1765}
1766
1767template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>;
1768template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Ymm>;
1769template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Xmm>;
1770} // namespace x64
1771} // namespace cpu
1772} // namespace impl
1773} // namespace dnnl
1774
1775// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
1776