/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/memory.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"
#include "cpu/x64/injectors/injector_utils.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_uni_x8s8s32x_conv_kernel.hpp"

#define GET_OFF(field) offsetof(jit_conv_call_s, field)

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;
using namespace Xbyak;
using namespace injector_utils;

42namespace {
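// Heuristic for the nesting order of the driver loops: grouped convolutions
// iterate over groups/channels before spatial dims, and when the minibatch
// cannot occupy all threads the spatial-first orders expose more parallelism.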
43void pick_loop_order(jit_conv_conf_t &jcp) {
44 jcp.loop_order = loop_cwgn;
45 if (jcp.ngroups > 1) {
46 jcp.loop_order = loop_ngcw;
47 if (jcp.mb < jcp.nthr)
48 jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg;
49 } else if (jcp.mb >= jcp.nthr && jcp.ic_without_padding <= 8) {
50 jcp.loop_order = loop_ngcw;
51 }
52}
53} // namespace
54
55template <cpu_isa_t isa, typename Vmm>
56_jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::_jit_uni_x8s8s32x_fwd_kernel(
57 const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
58 const memory_desc_t &dst_md)
59 : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa)
60 , jcp(ajcp)
61 , attr_(attr) {
62 if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
63 using namespace binary_injector;
64 static constexpr bool preserve_gpr = true;
65 static constexpr bool preserve_vmm = false;
66 static constexpr size_t helper_vmm_idx = 15;
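        // The binary-injector tail is either the remainder of an oc block
        // w.r.t. the SIMD width or, when oc_block is a multiple of it, the
        // remainder of the unpadded number of output channels.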
67 const size_t oc_block_tail = jcp.oc_block % isa_simd_width_;
68 const size_t tail_size = oc_block_tail
69 ? oc_block_tail
70 : jcp.oc_without_padding % isa_simd_width_;
71
72 const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
73 r13, r14, r15, preserve_gpr, preserve_vmm,
74 GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
75 memory_desc_wrapper(dst_md), tail_size, true};
76 const static_params_t static_params {
77 this->param1, rhs_arg_static_params};
78
79 postops_injector_
80 = utils::make_unique<injector::jit_uni_postops_injector_t<isa>>(
81 this, jcp.post_ops, static_params);
82 }
83}
84
85template <cpu_isa_t isa, typename Vmm>
86void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::prepare_output(int ur_w) {
87 int nb_oc_block
88 = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
89 for (int k = 0; k < nb_oc_block; ++k)
90 for (int j = 0; j < ur_w; ++j) {
91 Vmm vmm = vmm_out(j, k);
92 uni_vpxor(vmm, vmm, vmm);
93 }
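    // For an s8 source, inputs are shifted by 128 (0x80 per byte) into the u8
    // domain so the u8*s8 multiply-accumulate path can be used; the matching
    // compensation is added back to the accumulators in store_output().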
94 if (jcp.signed_input) {
95 auto xmm_shift = Xbyak::Xmm(vmm_shift.getIdx());
96 if (jcp.is_depthwise)
97 mov(reg_scratch, 128);
98 else
99 mov(reg_scratch, 0x80808080);
100 uni_vmovq(xmm_shift, reg_scratch);
101 uni_vpbroadcastd(vmm_shift, xmm_shift);
102 }
103}
104
105template <cpu_isa_t isa, typename Vmm>
106void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::cvt2ps(data_type_t type_in,
107 const Vmm &vmm_in, const Reg64 &reg, int offset, int load_size) {
108
109 load_data(type_in, vmm_in, reg, offset, load_size);
110 if (type_in != data_type::f32) uni_vcvtdq2ps(vmm_in, vmm_in);
111}
112
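// Visit every (oc block, ur_w position) accumulator, reporting whether the
// last oc block needs tail masking (or whether masking is forced for all
// blocks when the oc block is narrower than the vector width).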
113template <typename F>
114void iterate(const int nb_oc_block, const int ur_w,
115 const bool last_oc_block_flag, const bool force_masking, const F &f) {
116 for (int k = 0; k < nb_oc_block; k++) {
117 const bool mask_flag
118 = force_masking || (last_oc_block_flag && k == nb_oc_block - 1);
119 for (int j = 0; j < ur_w; j++)
120 f(mask_flag, k, j);
121 }
122}
123template <typename F>
124void iterate(const int nb_oc_block, const int ur_w,
125 const bool last_oc_block_flag, const F &f) {
126 iterate(nb_oc_block, ur_w, last_oc_block_flag, false, f);
127}
128template <typename F>
129void iterate(const int nb_oc_block, const int ur_w, const F &f) {
130 iterate(nb_oc_block, ur_w, false, false, f);
131}
132
133template <cpu_isa_t isa, typename Vmm>
134void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::apply_sum(const int nb_oc_block,
135 const int ur_w, const bool last_oc_block_flag, const int oc_block,
136 const float *p_sum_scale, const int32_t *p_sum_zp) {
137 if (jcp.with_sum) {
138 assert(!utils::any_null(p_sum_scale, p_sum_zp)
139 && "p_sum_scale or p_sum_zp = nullptr");
140 const float sum_scale = *p_sum_scale;
141 const int32_t sum_zp = *p_sum_zp;
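        // The sum post-op accumulates the previous destination values:
        //     acc += sum_scale * (dst_prev - sum_zero_point)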
142 const auto sum_injector_lam = [this, oc_block, sum_scale, sum_zp](
143 const bool mask_flag, const int k,
144 const int j) {
145 const int aux_output_offset = jcp.typesize_out
146 * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups);
147 cvt2ps(jcp.sum_dt, vmm_prev_dst, reg_out, aux_output_offset,
148 mask_flag ? get_tail_size() : get_blocking_size());
149 const Vmm vmm = vmm_out(j, k);
150 if (sum_zp != 0) {
151 uni_vbroadcastss(vmm_tmp, ptr[reg_ptr_sum_zp]);
152 uni_vcvtdq2ps(vmm_tmp, vmm_tmp);
153 uni_vsubps(vmm_prev_dst, vmm_prev_dst, vmm_tmp);
154 }
155 if (sum_scale == 1.f)
156 uni_vaddps(vmm, vmm, vmm_prev_dst);
157 else {
158 uni_vbroadcastss(vmm_tmp, ptr[reg_ptr_sum_scale]);
159 uni_vfmadd231ps(vmm, vmm_prev_dst, vmm_tmp);
160 }
161 };
162 const auto sum_injector = [=]() {
163 iterate(nb_oc_block, ur_w, last_oc_block_flag, sum_injector_lam);
164 };
165 if (*p_sum_scale != 1.f)
166 mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
167 if (*p_sum_zp != 0)
168 mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
169 postops_injector_->set_lambda_injector(
170 primitive_kind::sum, sum_injector);
171 }
172}
173
174template <cpu_isa_t isa, typename Vmm>
175void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::apply_postops(
176 const int nb_oc_block, const int ur_w, const bool last_oc_block_flag,
177 const int oc_block, const float *p_sum_scale, const int32_t *p_sum_zp) {
178 if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
179 if (jcp.with_sum && *p_sum_zp != 0) push(reg_ptr_sum_zp);
180 apply_sum(nb_oc_block, ur_w, last_oc_block_flag, oc_block, p_sum_scale,
181 p_sum_zp);
182
183 vmm_index_set_t vmm_idxs;
184 if (jcp.with_binary) {
185 binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
186 const bool oc_blk_is_smaller_than_vmm = oc_block < isa_simd_width_;
187 iterate(nb_oc_block, ur_w, last_oc_block_flag,
188 oc_blk_is_smaller_than_vmm,
189 [&](const bool mask_flag, const int k, const int j) {
190 const size_t aux_output_offset = jcp.typesize_out
191 * (k * oc_block
192 + j * jcp.oc_without_padding
193 * jcp.ngroups);
194 const auto vmm_idx = vmm_out_idx(j, k);
195 vmm_idxs.emplace(vmm_idx);
196
197 rhs_arg_params.vmm_idx_to_out_reg.emplace(
198 vmm_idx, reg_out);
199 rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
200 vmm_idx, aux_output_offset);
201
202 if (mask_flag)
203 rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);
204 });
205
206 postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
207 } else {
208 iterate(nb_oc_block, ur_w,
209 [&](const bool, const int k, const int j) {
210 vmm_idxs.emplace(vmm_out_idx(j, k));
211 });
212 postops_injector_->compute_vector_range(vmm_idxs);
213 }
214 if (jcp.with_sum && *p_sum_zp != 0) pop(reg_ptr_sum_zp);
215 }
216}
217
218template <cpu_isa_t isa, typename Vmm>
219void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::store_output(
220 int ur_w, bool last_oc_block_flag) {
221 int nb_oc_block
222 = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
223 int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block;
224
225 mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
226 mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
227 if (jcp.signed_input)
228 mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);
229
230 if (jcp.src_zero_point) {
231 mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
232 mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
233 uni_vpbroadcastd(vmm_zp, ptr[reg_src_zero_point]);
234 }
235
236 const auto &p = attr_.post_ops_;
237 const int sum_idx = p.find(primitive_kind::sum);
238 const float *p_sum_scale = nullptr;
239 const int32_t *p_sum_zp = nullptr;
240 if (sum_idx != -1) {
241 const auto &p_entry = p.entry_[sum_idx];
242 p_sum_scale = &p_entry.sum.scale;
243 p_sum_zp = &p_entry.sum.zero_point;
244 }
245
246 for (int k = 0; k < nb_oc_block; ++k) {
247 const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
248 const int load_size = mask_flag ? get_tail_size() : get_blocking_size();
249 int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block);
250 if (jcp.with_bias) {
251 int bias_offset = jcp.typesize_bia * k * oc_block;
252 cvt2ps(jcp.bia_dt, vmm_bias, reg_bias, bias_offset, load_size);
253 }
254 if (jcp.signed_input) {
255 const int comp_offset = sizeof(int32_t) * k * oc_block;
256 load_data(data_type::s32, vmm_comp, reg_compensation, comp_offset,
257 load_size);
258 }
259 if (jcp.src_zero_point) {
260 const int zp_offset = sizeof(int32_t) * k * oc_block;
261 load_data(data_type::s32, vmm_zp_comp, reg_zp_compensation,
262 zp_offset, load_size);
263 uni_vpmulld(vmm_zp_comp, vmm_zp_comp, vmm_zp);
264 }
        /* add compensation, zero-point term and bias to the accumulators */
266 if (mask_flag) {
267 uni_vpxor(vmm_scale, vmm_scale, vmm_scale);
268 load_data(data_type::s32, vmm_scale, reg_ptr_scales, scale_offset,
269 get_tail_size());
270 } else {
271 uni_vmovups(vmm_scale, ptr[reg_ptr_scales + scale_offset]);
272 }
273
274 for (int j = 0; j < ur_w; ++j) {
275 const Vmm vmm = vmm_out(j, k);
276
            /* add the compensation in s32 to avoid the loss of precision
               that occurs when converting s32 values above 2^24 to f32
               TODO: do the same for bias */
280 if (jcp.signed_input) uni_vpaddd(vmm, vmm, vmm_comp);
281 if (jcp.src_zero_point) uni_vpaddd(vmm, vmm, vmm_zp_comp);
282 uni_vcvtdq2ps(vmm, vmm);
283
284 uni_vmulps(vmm, vmm, vmm_scale);
285
286 if (jcp.with_bias) uni_vaddps(vmm, vmm, vmm_bias);
287 }
288 }
289
290 apply_postops(nb_oc_block, ur_w, last_oc_block_flag, oc_block, p_sum_scale,
291 p_sum_zp);
292
293 if (jcp.dst_scale) {
294 mov(reg_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
295 uni_vmovups(vmm_dst_scale, ptr[reg_dst_scale]);
296
297 /* Apply dst scale to accumulator */
298 for (int k = 0; k < nb_oc_block; k++) {
299 for (int j = 0; j < ur_w; j++) {
300 const Vmm vmm = vmm_out(j, k);
301 uni_vmulps(vmm, vmm, vmm_dst_scale);
302 }
303 }
304 }
305
306 if (jcp.dst_zero_point) {
307 mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
308 uni_vpbroadcastd(vmm_zp, ptr[reg_dst_zero_point]);
309 uni_vcvtdq2ps(vmm_zp, vmm_zp);
310
311 /* Add dst zero_point to accumulator */
312 for (int k = 0; k < nb_oc_block; k++) {
313 for (int j = 0; j < ur_w; j++) {
314 const Vmm vmm = vmm_out(j, k);
315 uni_vaddps(vmm, vmm, vmm_zp);
316 }
317 }
318 }
319
320 // Properly saturate the accumulators for integer datatypes
321
322 // No need to saturate on lower bound for signed integer types, as
323 // the conversion to int would return INT_MIN, and then proper
324 // saturation will happen in store_data
325 if (jcp.dst_dt == data_type::u8) {
326 uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
327 for (int k = 0; k < nb_oc_block; ++k) {
328 for (int j = 0; j < ur_w; ++j) {
329 Vmm vmm = vmm_out(j, k);
330 uni_vmaxps(vmm, vmm, vmm_zero);
331 }
332 }
333 }
334 if (utils::one_of(
335 jcp.dst_dt, data_type::u8, data_type::s8, data_type::s32)) {
336 float saturation_ubound = types::max_value<float>(jcp.dst_dt);
337 Xmm xmm_saturation(vmm_saturation.getIdx());
338 mov(reg_ptr_saturation_ubound, float2int(saturation_ubound));
339 uni_vmovq(xmm_saturation, reg_ptr_saturation_ubound);
340 uni_vbroadcastss(vmm_saturation, xmm_saturation);
341
342 for (int k = 0; k < nb_oc_block; ++k) {
343 for (int j = 0; j < ur_w; ++j) {
344 Vmm vmm = vmm_out(j, k);
345 uni_vminps(vmm, vmm, vmm_saturation);
346 }
347 }
348 }
349
350 // Convert float accumulators to int datatype if needed
351 if (utils::one_of(jcp.dst_dt, data_type::u8, data_type::s8, data_type::s32))
352 for (int k = 0; k < nb_oc_block; ++k)
353 for (int j = 0; j < ur_w; ++j) {
354 Vmm vmm = vmm_out(j, k);
355 uni_vcvtps2dq(vmm, vmm);
356 }
357
358 /* write out register to output_addr */
359 for (int k = 0; k < nb_oc_block; ++k) {
360 const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
361 for (int j = 0; j < ur_w; ++j) {
362 Vmm r_vmm = vmm_out(j, k);
363 int aux_output_offset = jcp.typesize_out
364 * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups);
365 store_data(jcp.dst_dt, r_vmm, reg_out, aux_output_offset,
366 mask_flag ? get_tail_size() : get_blocking_size());
367 }
368 }
369}
370
371template <cpu_isa_t isa, typename Vmm>
372void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::compute_ker_dw(int ur_w, int pad_l,
373 int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
374
375 if (!(utils::one_of(isa, avx2) && std::is_same<Vmm, Xbyak::Ymm>::value)
376 && !(utils::one_of(isa, sse41)
377 && std::is_same<Vmm, Xbyak::Xmm>::value))
378 assert(!"invalid group blocking for depthwise convolution");
379
380 const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input);
381
382 if (jcp.src_zero_point) {
383 push(aux_reg_ker_d);
384 mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
385 uni_vpbroadcastd(vmm_zp, ptr[reg_src_zero_point]);
386 }
387
388 auto input_spatial_index = [=](int oi, int ki) {
389 return (ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l);
390 };
391
392 auto input_offset2 = [=](int ii, int ci) {
393 if (jcp.is_fused_conv)
394 return jcp.typesize_in
395 * (ii * jcp.dw_conv_buffer_oc + ci * jcp.ch_block);
396 else
397 return jcp.typesize_in * (ii * jcp.ngroups + ci * jcp.ch_block);
398 };
399
400 auto input_offset3 = [=](int oi, int ci, int ki) {
401 return jcp.typesize_in * input_offset2(input_spatial_index(oi, ki), ci);
402 };
403
404 auto kernel_offset = [=](int ci, int ki) {
405 return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block);
406 };
407
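    // Accumulate one filter tap: acc += src * wei in s32, using the VNNI
    // dot-product instruction when available, or a pmaddwd-based emulation
    // otherwise.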
408 auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
409 if (jcp.has_vnni) {
410 vpdpbusd(vreg_acc, vreg_src, vreg_wei, VexEncoding);
411 } else {
412 // okay for depthwise since src is zero-extended
413 uni_vpmaddwd(vmm_dw_tmp, vreg_src, vreg_wei);
414 uni_vpaddd(vreg_acc, vreg_acc, vmm_dw_tmp);
415 }
416 };
417
418 int ii_start = 0;
419 int ii_end = -1;
420 if (jcp.is_resrc_depthwise && !h_padded) {
421 // find bounds of input spatial indices
422 bool first = true;
423 for (int ki = 0; ki < jcp.kw; ++ki) {
424 int oi_start = get_ow_start(ki, pad_l);
425 int oi_end = get_ow_end(ur_w, ki, pad_r);
426 for (int oi = oi_start; oi < oi_end; ++oi) {
427 int ii = input_spatial_index(oi, ki);
428 if (first || ii < ii_start) ii_start = ii;
429 if (first || ii > ii_end) ii_end = ii;
430 first = false;
431 }
432 }
433 }
434
435 for (int ci = 0; ci < jcp.nb_ch_blocking; ++ci) {
436 const bool mask_flag = last_ic_block_flag != no_last_block
437 && ci == jcp.nb_ch_blocking - 1;
438 if (jcp.is_resrc_depthwise && !h_padded) {
439 // now we can load input once and reuse up to jcp.kw times
440 for (int ii = ii_start; ii <= ii_end; ++ii) {
441 int aux_input_offset = input_offset2(ii, ci);
442 const Vmm vmm_inp_tmp = vmm_inp(ii, jcp.nb_ch_blocking);
443
444 uni_vpxor(vmm_inp_tmp, vmm_inp_tmp, vmm_inp_tmp);
445 load_data(data_type::u8, vmm_inp_tmp, aux_reg_inp,
446 aux_input_offset,
447 mask_flag ? get_tail_size() : get_blocking_size());
448 if (jcp.signed_input)
449 uni_vpaddb(vmm_inp_tmp, vmm_inp_tmp, vmm_shift);
450 }
451 }
452 for (int ki = 0; ki < jcp.kw; ++ki) {
453 int aux_kernel_offset = kernel_offset(ci, ki);
454 int oi_start = get_ow_start(ki, pad_l);
455 int oi_end = get_ow_end(ur_w, ki, pad_r);
456
457 if (compute_kernel) {
458 uni_vpmovsxbd(vmm_wei, ptr[aux_reg_ker + aux_kernel_offset]);
459 if (h_padded) {
460 assert(jcp.signed_input);
461 for (int oi = 0; oi < ur_w; ++oi)
462 compute(vmm_out(oi, ci), vmm_wei, vmm_shift);
463 } else {
464 int start = jcp.signed_input ? 0 : oi_start;
465 int end = jcp.signed_input ? ur_w : oi_end;
466 for (int oi = start; oi < end; ++oi) {
467 if (oi >= oi_start && oi < oi_end) {
468 if (jcp.is_resrc_depthwise) {
469 int ii = input_spatial_index(oi, ki);
470 vmm_dw_src = vmm_inp(ii, jcp.nb_ch_blocking);
471 } else {
472 int aux_input_offset
473 = input_offset3(oi, ci, ki);
474 load_data(data_type::u8, vmm_dw_src,
475 aux_reg_inp, aux_input_offset,
476 mask_flag ? get_tail_size()
477 : get_blocking_size());
478 if (jcp.signed_input)
479 uni_vpaddb(
480 vmm_dw_src, vmm_dw_src, vmm_shift);
481 }
482 compute(vmm_out(oi, ci), vmm_wei, vmm_dw_src);
483 } else {
484 assert(jcp.signed_input);
485 compute(vmm_out(oi, ci), vmm_wei, vmm_shift);
486 }
487 }
488 }
489 }
490 if (jcp.src_zero_point) {
491 /* calculate src_zero_point padding as:
492 * (is_padding ?
493 * src_zero_point_s32 * conv(1, wei_s32) : 0) */
494 if (!compute_kernel) {
495 uni_vpmovsxbd(
496 vmm_wei, ptr[aux_reg_ker + aux_kernel_offset]);
497 } // else: already loaded weights from previous block
498 for (int oi = 0; oi < ur_w; oi++) {
499 if (oi < oi_start || oi >= oi_end || h_padded) {
500 uni_vpmulld(vmm_zp_dw_tmp, vmm_wei, vmm_zp);
501 uni_vpaddd(vmm_out(oi, ci), vmm_out(oi, ci),
502 vmm_zp_dw_tmp);
503 }
504 }
505 }
506 }
507 }
508
509 if (jcp.src_zero_point) pop(aux_reg_ker_d);
510}
511
512template <cpu_isa_t isa, typename Vmm>
513void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::compute_ker(int ur_w, int pad_l,
514 int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
515 if (jcp.is_depthwise)
516 return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded);
517
518 int kw = jcp.kw;
519 int stride_w = jcp.stride_w;
520 int ic_block = jcp.ic_block;
521 int oc_block = jcp.oc_block;
522 int ch_block_all = jcp.ch_block * ic_block * oc_block;
523
524 int nb_oc_block = jcp.nb_oc_blocking;
525
526 const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input);
527
528 assert(IMPLICATION(h_padded, jcp.src_zero_point || jcp.signed_input));
529
530 if (jcp.src_zero_point) {
531 push(aux_reg_ker_d);
532 mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
533 }
534
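    // src is laid out as n(d)(h)wc, so moving by one output pixel advances
    // stride_w * ic_without_padding * ngroups input elements; `ic` indexes
    // groups of ic_sub_step input channels packed together in the weights.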
535 auto input_offset = [=](int oi, int ic, int ki) {
536 return jcp.typesize_in
537 * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
538 * jcp.ic_without_padding * jcp.ngroups
539 + ic_sub_step * ic);
540 };
541 auto kernel_offset = [=](int ii, int ic, int ki) {
542 return jcp.typesize_in
543 * ((ii * jcp.nb_ic * jcp.kd * jcp.kh * jcp.kw + ki)
544 * ch_block_all
545 + ic_sub_step * ic * oc_block);
546 };
547 auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
548 if (jcp.has_vnni) {
549 vpdpbusd(vreg_acc, vreg_src, vreg_wei, VexEncoding);
550 } else {
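            // Emulate vpdpbusd: pmaddubsw forms pairwise u8*s8 products summed
            // into s16, then pmaddwd against a vector of 16-bit ones widens
            // and adds adjacent pairs into s32.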
551 uni_vpmaddubsw(vmm_tmp, vreg_src, vreg_wei);
552 uni_vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
553 uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp);
554 }
555 };
556
557 for (int ki = 0; ki < kw; ++ki) {
558 const int ow_start = get_ow_start(ki, pad_l);
559 const int ow_end = get_ow_end(ur_w, ki, pad_r);
560 const int ic_tail_size = jcp.ic_without_padding % ic_sub_step;
561
562 const int _start = jcp.signed_input ? 0 : ow_start;
563 const int _end = jcp.signed_input ? ur_w : ow_end;
564
        /* Skip the last input loads when
           (ic % ic_block) / ic_sub_step < ic_block / ic_sub_step */
567 const int icb = (last_ic_block_flag != no_last_block)
568 ? div_up((jcp.ic_without_padding % ic_block), ic_sub_step)
569 : ic_block / ic_sub_step;
570
571 if (compute_kernel) {
572 for (int ic = 0; ic < icb; ++ic) {
573 if (h_padded) {
574 // fill padded area with shifted value in first iteration
575 if (ic == 0) {
576 const Vmm inp = vmm_inp(0, nb_oc_block);
577 uni_vmovups(inp, vmm_shift);
578 }
579 } else {
580 for (int jj = _start; jj < _end; ++jj) {
581 int aux_input_offset = input_offset(jj, ic, ki);
582 if (jj >= ow_start && jj < ow_end) {
583 const bool need_partial_ic_bcast = true
584 && last_ic_block_flag == last_sp_block
585 && ic_tail_size != 0 && ic == icb - 1;
586
587 if (need_partial_ic_bcast) {
588 const auto inp_bcastd_vmm
589 = vmm_inp(jj, nb_oc_block);
590 const auto inp_bcastd
591 = Xmm(inp_bcastd_vmm.getIdx());
592 load_bytes(inp_bcastd_vmm, aux_reg_inp,
593 aux_input_offset, ic_tail_size);
594 uni_vpbroadcastd(
595 vmm_inp(jj, nb_oc_block), inp_bcastd);
596 } else {
597 uni_vpbroadcastd(vmm_inp(jj, nb_oc_block),
598 ptr[aux_reg_inp + aux_input_offset]);
599 }
600 if (jcp.signed_input)
601 uni_vpaddb(vmm_inp(jj, nb_oc_block),
602 vmm_inp(jj, nb_oc_block), vmm_shift);
603 } else {
604 // fill padded area with shifted value in
605 // first iteration
606 if (jcp.signed_input && ic == 0) {
607 const Vmm inp = vmm_inp(jj, nb_oc_block);
608 uni_vmovups(inp, vmm_shift);
609 }
610 }
611 }
612 }
613 for (int ii = 0; ii < nb_oc_block; ++ii) {
614 const int aux_kernel_offset = kernel_offset(ii, ic, ki);
615 uni_vmovdqu(vmm_wei, ptr[aux_reg_ker + aux_kernel_offset]);
616 for (int jj = _start; jj < _end; ++jj) {
617 const Vmm inp = vmm_inp(h_padded ? 0 : jj, nb_oc_block);
618 compute(vmm_out(jj, ii), vmm_wei, inp);
619 }
620 }
621 }
622 }
623 if (jcp.src_zero_point) {
624 uni_vpbroadcastd(vmm_zp, ptr[reg_src_zero_point]);
625 /* calculate src_zero_point padding as:
626 * (is_padding ? src_zero_point_s32 * conv(1, wei_s8) : 0) */
627 const Vmm vmm_wacc = vmm_inp(0, nb_oc_block);
628 for (int jj = 0; jj < ur_w; jj++) {
629 if (jj < ow_start || jj >= ow_end || h_padded) {
630 for (int ii = 0; ii < nb_oc_block; ii++) {
631 uni_vpxor(vmm_tmp, vmm_tmp, vmm_tmp);
632 for (int ic = 0; ic < icb; ic++) {
633 const int aux_kernel_offset
634 = kernel_offset(ii, ic, ki);
635 if (jcp.has_vnni) {
636 vpdpbusd(vmm_tmp, vmm_zp_one,
637 ptr[aux_reg_ker + aux_kernel_offset],
638 VexEncoding);
639 } else {
640 uni_vpmaddubsw(vmm_wacc, vmm_zp_one,
641 ptr[aux_reg_ker + aux_kernel_offset]);
642 uni_vpmaddwd(vmm_wacc, vmm_wacc, vmm_one);
643 uni_vpaddd(vmm_tmp, vmm_tmp, vmm_wacc);
644 }
645 }
646 uni_vpmulld(vmm_tmp, vmm_tmp, vmm_zp);
647 uni_vpaddd(vmm_out(jj, ii), vmm_out(jj, ii), vmm_tmp);
648 }
649 }
650 }
651 }
652 }
653 if (jcp.src_zero_point) pop(aux_reg_ker_d);
654}
655
656template <cpu_isa_t isa, typename Vmm>
657void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::kh_loop(
658 int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) {
659
660 Label kd_label, kh_label, skip_kd_loop, skip_kh_loop;
661 Label f_overflow_label, no_f_overflow_label, d_h_f_overflow_label,
662 t_overflow_label, no_t_overflow_label, b_overflow_label,
663 no_b_overflow_label, back_overflow_label, no_back_overflow_label,
664 d_h_back_overflow_label;
665
666 int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
667 int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all;
668 int shift_input_ptr
669 = jcp.typesize_in * jcp.iw * jcp.ic_without_padding * jcp.ngroups;
670
671 if (jcp.src_zero_point && !jcp.is_depthwise) {
672 const auto xmm_one = Xbyak::Xmm(vmm_zp_one.getIdx());
673 mov(reg_scratch, 0x01010101);
674 uni_vmovq(xmm_one, reg_scratch);
675 uni_vpbroadcastd(vmm_zp_one, xmm_one);
676 }
677
678 if (jcp.ndims == 5) {
679 mov(aux_reg_ker_d, reg_ker);
680 mov(aux_reg_inp_d, reg_inp);
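        // With a signed source or a source zero-point, filter taps that fall
        // into the padded area still contribute (through the 128-shift
        // compensation or the zero-point term), so padded kd (and later kh)
        // positions are processed with h_padded = true instead of being
        // skipped.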
681 if (jcp.signed_input || jcp.src_zero_point) {
682 //TODO: May be avoided when f_pad=0 and dd0
683 //TODO: Potential optimization by precomputing, when kd <<< od?
684 mov(reg_ki, ptr[param1 + GET_OFF(f_overflow)]);
685 cmp(reg_ki, 0);
686 je(no_f_overflow_label, T_NEAR);
687 L(f_overflow_label);
688 {
689 mov(aux_reg_ker, aux_reg_ker_d);
690 mov(reg_kj, jcp.kh);
691 L(d_h_f_overflow_label);
692 {
693 compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
694 add(aux_reg_ker, shift_kernel_ptr);
695 dec(reg_kj);
696 jne(d_h_f_overflow_label);
697 }
698 add(aux_reg_ker_d, shift_kernel_ptr * jcp.kh);
699 dec(reg_ki);
700 jne(f_overflow_label);
701 }
702 L(no_f_overflow_label);
703 }
704
705 mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
706 if ((jcp.signed_input || jcp.src_zero_point) || (jcp.dilate_d >= jcp.id)
707 || (!(jcp.signed_input || jcp.src_zero_point)
708 && (jcp.kd - 1) * (jcp.dilate_d + 1)
709 < nstl::max(jcp.f_pad, jcp.back_pad))) {
710 cmp(reg_ki, 0);
711 je(skip_kd_loop, T_NEAR);
712 }
713 L(kd_label);
714 mov(aux_reg_inp, aux_reg_inp_d);
715 mov(aux_reg_ker, aux_reg_ker_d);
716 } else {
717 if (jcp.is_fused_conv) {
718 mov(aux_reg_inp_buffer_ptr, reg_inp_buffer_ptr);
719 } else {
720 mov(aux_reg_inp, reg_inp);
721 }
722 mov(aux_reg_ker, reg_ker);
723 }
724
725 if ((jcp.signed_input || jcp.src_zero_point) && jcp.ndims > 3) {
726 mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
727 cmp(reg_overflow, 0);
728 je(no_t_overflow_label, T_NEAR);
729 L(t_overflow_label);
730 {
731 compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
732
733 add(aux_reg_ker, shift_kernel_ptr);
734 dec(reg_overflow);
735 cmp(reg_overflow, 0);
736 jg(t_overflow_label, T_NEAR);
737 }
738 L(no_t_overflow_label);
739 }
740 mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
741 if ((jcp.signed_input || jcp.src_zero_point) || (jcp.dilate_h >= jcp.ih)
742 || (!(jcp.signed_input || jcp.src_zero_point)
743 && (jcp.kh - 1) * (jcp.dilate_h + 1)
744 < nstl::max(jcp.t_pad, jcp.b_pad))) {
745 cmp(reg_kj, 0);
746 je(skip_kh_loop, T_NEAR);
747 }
748 L(kh_label);
749 {
750 if (jcp.is_fused_conv) {
751 mov(aux_reg_inp, ptr[aux_reg_inp_buffer_ptr]);
752 add(aux_reg_inp, reg_inp);
753 }
754 compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);
755
756 add(aux_reg_ker, shift_kernel_ptr);
757 if (jcp.is_fused_conv) {
758 add(aux_reg_inp_buffer_ptr, sizeof(void *));
759 } else {
760 add(aux_reg_inp, shift_input_ptr * (jcp.dilate_h + 1));
761 }
762 dec(reg_kj);
763 cmp(reg_kj, 0);
764 jg(kh_label, T_NEAR);
765 }
766 L(skip_kh_loop);
767 if ((jcp.signed_input || jcp.src_zero_point) && jcp.ndims > 3) {
768 mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
769 cmp(reg_overflow, 0);
770 je(no_b_overflow_label, T_NEAR);
771 L(b_overflow_label);
772 {
773 compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
774
775 add(aux_reg_ker, shift_kernel_ptr);
776 dec(reg_overflow);
777 cmp(reg_overflow, 0);
778 jg(b_overflow_label, T_NEAR);
779 }
780 L(no_b_overflow_label);
781 }
782 if (jcp.ndims == 5) {
783 add(aux_reg_inp_d, shift_input_ptr * jcp.ih * (jcp.dilate_d + 1));
784 add(aux_reg_ker_d, shift_kernel_ptr * jcp.kh);
785 dec(reg_ki);
786 jne(kd_label, T_NEAR);
787
788 L(skip_kd_loop);
789 if (jcp.signed_input || jcp.src_zero_point) {
790 mov(reg_ki, ptr[param1 + GET_OFF(back_overflow)]);
791 cmp(reg_ki, 0);
792 je(no_back_overflow_label, T_NEAR);
793 L(back_overflow_label);
794 {
795 mov(aux_reg_ker, aux_reg_ker_d);
796 mov(reg_kj, jcp.kh);
797 L(d_h_back_overflow_label);
798 {
799 compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
800 add(aux_reg_ker, shift_kernel_ptr);
801 dec(reg_kj);
802 jne(d_h_back_overflow_label);
803 }
804 add(aux_reg_ker_d, shift_kernel_ptr * jcp.kh);
805 dec(reg_ki);
806 jne(back_overflow_label);
807 }
808 L(no_back_overflow_label);
809 }
810 }
811}
812
813template <cpu_isa_t isa, typename Vmm>
814void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::icb_loop(
815 int ur_w, int pad_l, int pad_r, bool is_last_sp_block) {
816 prepare_output(ur_w);
817
818 // IC loop
819 Label icb_label;
820 mov(reg_icb, jcp.nb_ic);
821 L(icb_label);
822 const bool do_icb_loop
823 = jcp.is_depthwise ? jcp.nb_ch > jcp.nb_ch_blocking : jcp.nb_ic > 1;
824 if (jcp.ngroups % jcp.ch_block != 0 || jcp.ic_without_padding != jcp.ic) {
825 Label common_ker, end_ker;
826 if (do_icb_loop) {
827 if (jcp.is_depthwise)
828 cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
829 else
830 cmp(reg_icb, 1); // The last IC block
831 jne(common_ker, T_NEAR);
832 }
833 kh_loop(ur_w, pad_l, pad_r,
834 is_last_sp_block ? last_sp_block : last_ic_block);
835 if (do_icb_loop) {
836 jmp(end_ker, T_NEAR);
837
838 L(common_ker);
839 kh_loop(ur_w, pad_l, pad_r, no_last_block);
840
841 L(end_ker);
842 }
843 } else {
844 kh_loop(ur_w, pad_l, pad_r, no_last_block);
845 }
846 // End of IC Loop
847 if (do_icb_loop) {
848 int inp_step = jcp.ic_block;
849 const size_t ker_step = (size_t)jcp.kd * jcp.kh * jcp.kw * jcp.oc_block
850 * jcp.ic_block;
851 add(reg_inp, jcp.typesize_in * inp_step);
852 safe_add(reg_ker, jcp.typesize_in * ker_step, reg_ker_long_offt);
853
854 dec(reg_icb);
855 cmp(reg_icb, 0);
856 jg(icb_label, T_NEAR);
857
858 sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic);
859 safe_sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic,
860 reg_ker_long_offt);
861 }
862
863 if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
864 Label common_store, end_store;
865
866 if (jcp.is_depthwise)
867 cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
868 else
869 cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
870
871 jne(common_store, T_NEAR);
872
873 store_output(ur_w, true); // last oc block
874 jmp(end_store, T_NEAR);
875
876 L(common_store);
877 store_output(ur_w, false);
878
879 L(end_store);
880 } else {
881 store_output(ur_w, false);
882 }
883}
884
885template <cpu_isa_t isa, typename Vmm>
886void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::generate() {
887 Label permute_index_table;
888 int in_ic_shift = jcp.is_fused_conv ? jcp.dw_conv_buffer_oc
889 : jcp.ic_without_padding * jcp.ngroups;
890 const int urw_inp_stride = jcp.ur_w * jcp.stride_w;
891 const int n_urw_l_pad
892 = nstl::min(div_up(jcp.l_pad, urw_inp_stride), jcp.ow / jcp.ur_w);
893 const int inp_shift_pad = nstl::max(0,
894 jcp.typesize_in * (n_urw_l_pad * urw_inp_stride - jcp.l_pad)
895 * in_ic_shift);
896 int inp_shift = jcp.typesize_in * (jcp.ur_w * jcp.stride_w * in_ic_shift);
897 int out_shift = jcp.typesize_out
898 * (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups);
899 preamble();
900
901 if (jcp.is_depthwise) {
902 const bool is_zero_point = jcp.src_zero_point || jcp.dst_zero_point;
903 // dst zero point and dst scale reuse the same register
904 int idx = ker_max_reg + 1 - jcp.max_regs_ur
905 - nstl::max(2 * is_zero_point, static_cast<int>(jcp.dst_scale));
906 if (!jcp.is_resrc_depthwise) vmm_dw_src = Vmm(--idx);
907 if (!jcp.has_vnni) vmm_dw_tmp = Vmm(--idx);
908 if (jcp.signed_input) {
909 --idx; // due to extra register used for compensations
910 }
911 assert(IMPLICATION(!jcp.dst_scale && !is_zero_point,
912 idx == ker_max_reg - ker_dw_reg_base_idx));
913 }
914
915 if (!jcp.is_depthwise && (!jcp.has_vnni)) {
916 auto vmm_one_128 = Xbyak::Xmm(vmm_one.getIdx());
917 mov(reg_scratch, 0x10001);
918 uni_vmovq(vmm_one_128, reg_scratch);
919 uni_vpbroadcastd(vmm_one, vmm_one_128);
920 }
921
922 if (jcp.is_fused_conv) {
923 mov(reg_inp_buffer_ptr, ptr[param1 + GET_OFF(src)]);
        /* In case of fused depthwise convolution, `param.src` is not a pointer
        to the input; instead it points to a buffer containing pointers to
        consecutive rows of the input in wc format with c=jcp.dw_conv_buffer_oc.
        Example: [ptr_to_inp_row0, ptr_to_inp_row1, ptr_to_inp_row2].
        Traverse the data as
        mov(reg_data, ptr[reg_input_buffer_ptr])
        ... process row0 ...
        add(reg_input_buffer_ptr, sizeof(void*))
        mov(reg_data, ptr[reg_input_buffer_ptr])
        ... process row1 ...
        add(reg_input_buffer_ptr, sizeof(void*))
        mov(reg_data, ptr[reg_input_buffer_ptr])
        ... process row2 ...
        */
938 xor_(reg_inp, reg_inp);
939 } else {
940 mov(reg_inp, ptr[param1 + GET_OFF(src)]);
941 }
942 mov(reg_out, ptr[param1 + GET_OFF(dst)]);
943 mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
944
945 if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
946 mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
947 }
948
949 const int extended_filter_size
950 = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
951 const int r_pad = nstl::max(0, jcp.r_pad);
952 const int ow_with_no_rpad = 1
953 + (jcp.iw + jcp.l_pad + nstl::min(0, jcp.r_pad)
954 - extended_filter_size)
955 / jcp.stride_w;
956 const int n_urw_per_ow_block = jcp.ow_block / jcp.ur_w;
957 const int max_safe_iw = nstl::max(
958 0, jcp.iw - div_up(ic_sub_step, jcp.ic_without_padding));
959 const int max_safe_ow = jcp.ic_without_padding % ic_sub_step == 0
960 ? jcp.ow
961 : (max_safe_iw + jcp.l_pad - extended_filter_size) / jcp.stride_w;
962 Label middle_block_label, done_compute;
963 std::vector<Label> ow_block_jmp_table;
964
    // r_pad_fall_through is a special ow_block that, when it exists, overlaps
    // both the middle_block and the r_pad/ur_w_tail region.
    // The number of ur_w's to compute in middle_block before executing the
    // r_pad region is stored in r_pad_fall_through_n_urw, and the ow_block
    // index is stored in r_pad_fall_through_ow_block.
970 int r_pad_fall_through_ow_block = 0;
971 int r_pad_fall_through_n_urw = 0;
972
973 if (jcp.nb_ow > 1) {
        // Only one ow block is processed per jit call. The index of that
        // ow block is passed in the owb parameter, and the padding handling
        // depends on it.
977 //
978 // The compute block to run is determined by using a jmp-table.
979 // jmp-table Layout:
980 // idx -> addr
981 // 0 -> [...l_pad_region label[0]...]
982 // : : : : : : : : : : : : : : :
983 // L -> [...l_pad_region label[L]...]
984 // L+1 -> [...r_pad_region label[0]...]
985 // : : : : : : : : : : : : : : :
986 // L+R -> [...r_pad_region label[R]...]
987 //
988 // Note: Label for middle_block is not stored in the jmp-table.
989 //
990 // During jit call, the jump address is calculated as below:
991 // if (owb < n) {
992 // jmp([jmp_table + owb*sizeof(void*)]);
993 // } else if (owb < X) {
994 // // X is the number of ow_blocks before r_pad region (see below).
995 // jmp(middle_block);
996 // } else {
997 // sub(owb, X);
998 // jmp([jmp_table + owb*sizeof(void*) + L*sizeof(void)]);
999 // }
1000 //
1001 // To configure the jmp-table, we need to determine some constants
1002 // (namely, r_pad_fall_through_n_urw, r_pad_fall_through_ow_block,
1003 // n_l_pad_labels, n_labels) ahead of writing the compute assembly. So,
        // we first simulate the filter path without emitting any assembly.
        // This keeps the math for calculating the constants simple and
        // self-explanatory.
1007
1008 // Begin simulation without writing assembly
1009 int n_l_pad_labels = 0;
1010 int n_labels = 0;
1011 int cur_ow = 0;
1012
1013 // l_pad region:
1014 n_l_pad_labels = div_up(n_urw_l_pad, n_urw_per_ow_block);
1015 n_labels = n_l_pad_labels;
1016 cur_ow += n_urw_l_pad * jcp.ur_w;
1017
1018 // middle_region:
1019 int n_urw_middle_block_loop = 0;
1020 int cur_r_pad = nstl::max(0,
1021 calculate_end_padding(jcp.l_pad, cur_ow + jcp.ur_w, jcp.iw,
1022 jcp.stride_w, extended_filter_size));
1023 if (cur_ow + jcp.ur_w <= jcp.ow && cur_r_pad == 0) {
1024 n_urw_middle_block_loop
1025 = nstl::max(0,
1026 nstl::min(ow_with_no_rpad, max_safe_ow) - cur_ow)
1027 / jcp.ur_w;
1028 cur_ow += n_urw_middle_block_loop * jcp.ur_w;
1029 }
1030 r_pad_fall_through_n_urw = (cur_ow / jcp.ur_w) % n_urw_per_ow_block;
1031 r_pad_fall_through_ow_block = cur_ow / (n_urw_per_ow_block * jcp.ur_w);
1032
1033 // r_pad or last_sp_block
1034 if (cur_ow + jcp.ur_w <= jcp.ow) {
1035 if (r_pad_fall_through_n_urw == 0) ++n_labels;
1036 const int n_urw_r_pad_region = (jcp.ow - cur_ow) / jcp.ur_w;
1037 n_labels += nstl::max(0,
1038 div_up(r_pad_fall_through_n_urw + n_urw_r_pad_region,
1039 n_urw_per_ow_block)
1040 - 1);
1041 }
1042
1043 if (jcp.ur_w_tail != 0) {
1044 if (jcp.ow % jcp.ow_block == jcp.ur_w_tail) ++n_labels;
1045 }
1046 // End of simulation
1047
1048 ow_block_jmp_table.resize(n_labels);
1049
1050 // Begin jump-table logic
1051 Label ow_block_jmp_table_label;
1052 if (!ow_block_jmp_table.empty())
1053 mov(reg_jmp_tbl_base, ow_block_jmp_table_label);
1054 mov(reg_oi, n_urw_per_ow_block);
1055 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
1056 if (jcp.l_pad > 0) {
1057 Label middle_or_rpad_check;
1058 cmp(reg_owb, n_l_pad_labels);
1059 jge(middle_or_rpad_check, T_NEAR);
1060 jmp(ptr[reg_jmp_tbl_base + reg_owb * sizeof(void *)]);
1061 L(middle_or_rpad_check);
            // The harness passes a shifted src pointer that does not take
            // left-padding into account, so we must re-shift here.
1064 const int inp_shift_pad_middle_block = -1 * jcp.typesize_in
1065 * nstl::min(jcp.l_pad, n_urw_l_pad * urw_inp_stride)
1066 * in_ic_shift;
1067 add(reg_inp, inp_shift_pad_middle_block);
1068 }
1069 if (r_pad_fall_through_n_urw != 0) {
1070 mov(reg_scratch, r_pad_fall_through_n_urw);
1071 cmp(reg_owb, r_pad_fall_through_ow_block);
1072 cmove(reg_oi, reg_scratch);
1073 if (n_urw_middle_block_loop > 0) {
1074 sub(reg_owb, r_pad_fall_through_ow_block);
1075 // simple middle_block
1076 jle(middle_block_label, T_NEAR);
1077 dec(reg_owb);
1078 } else {
1079 sub(reg_owb, r_pad_fall_through_ow_block + 1);
1080 }
1081 } else {
1082 sub(reg_owb, r_pad_fall_through_ow_block);
1083 // simple middle_block
1084 if (n_urw_middle_block_loop) jl(middle_block_label, T_NEAR);
1085 }
1086 // r_pad-only region
1087 if (!ow_block_jmp_table.empty())
1088 jmp(ptr[reg_jmp_tbl_base + reg_owb * sizeof(void *)
1089 + n_l_pad_labels * sizeof(void *)]);
1090
1091 if (!ow_block_jmp_table.empty()) {
1092 align(8);
1093 L(ow_block_jmp_table_label);
1094 {
1095 for (size_t i = 0; i < ow_block_jmp_table.size(); ++i) {
1096 putL(ow_block_jmp_table[i]);
1097 }
1098 }
1099 }
1100 // End of jump-table logic
1101 }
1102
1103 // Begin kernel
1104 int cur_ow = 0;
1105 int cur_n_oi = 0; // used only for jcp.nb_ow > 1 scenario
1106 int label_cntr = 0;
1107 int cur_l_pad = 0;
1108 if (jcp.l_pad > 0) {
1109 for (cur_l_pad = jcp.l_pad;
1110 cur_l_pad > 0 && cur_ow + jcp.ur_w <= jcp.ow;
1111 cur_l_pad -= urw_inp_stride) {
1112 if (jcp.nb_ow > 1 && cur_n_oi == 0) {
1113 // cur_n_oi == 0 signifies beginning of new ow_block
1114 // (or end of previous block)
1115 const dim_t inp_lpad_region_shift = -label_cntr * jcp.ow_block
1116 * jcp.stride_w * in_ic_shift;
1117 L(ow_block_jmp_table[label_cntr++]);
                // The harness passes a shifted src pointer that does not take
                // left-padding into account, so we must re-shift here.
1120 add(reg_inp, inp_lpad_region_shift);
1121 }
1122
1123 cur_ow += jcp.ur_w;
1124 int cur_r_pad = nstl::max(0,
1125 calculate_end_padding(jcp.l_pad, cur_ow, jcp.iw,
1126 jcp.stride_w, extended_filter_size));
1127 icb_loop(jcp.ur_w, cur_l_pad, cur_r_pad, cur_ow > max_safe_ow);
1128 add(reg_out, out_shift);
1129 dec(reg_oi);
1130
1131 if (jcp.nb_ow > 1 && ++cur_n_oi == n_urw_per_ow_block) {
                // We compute one owb per jit call, so insert an unconditional
                // jmp after computing one owb.
1134 jmp(done_compute, T_NEAR);
1135 cur_n_oi = 0;
1136 }
1137 }
1138 if (jcp.nb_ow == 1 || cur_n_oi != 0) {
1139 // Let it "fall-through" middle_block_label
1140 add(reg_inp, inp_shift_pad);
1141 }
1142 }
1143
1144 // middle_block
1145 {
1146 int cur_r_pad = nstl::max(0,
1147 calculate_end_padding(jcp.l_pad, cur_ow + jcp.ur_w, jcp.iw,
1148 jcp.stride_w, extended_filter_size));
1149 if (cur_r_pad == 0 && cur_ow + jcp.ur_w <= jcp.ow) {
1150 int n_oi_middle_block_loop
1151 = nstl::max(0,
1152 nstl::min(ow_with_no_rpad, max_safe_ow) - cur_ow)
1153 / jcp.ur_w;
1154 if (jcp.nb_ow == 1 && n_oi_middle_block_loop > 1)
1155 mov(reg_oi, n_oi_middle_block_loop);
1156 L(middle_block_label);
1157 if (n_oi_middle_block_loop > 0) {
1158 icb_loop(jcp.ur_w, 0, 0, false);
1159 add(reg_inp, inp_shift);
1160 add(reg_out, out_shift);
1161 if (n_oi_middle_block_loop > 1) {
1162 dec(reg_oi);
1163 jg(middle_block_label, T_NEAR);
1164 }
1165 }
1166 cur_ow += n_oi_middle_block_loop * jcp.ur_w;
1167 cur_n_oi = (cur_n_oi + n_oi_middle_block_loop) % n_urw_per_ow_block;
1168 }
1169 }
1170
1171 // r_pad region or last_sp_block
1172 if (cur_ow + jcp.ur_w <= jcp.ow) {
1173 if (jcp.nb_ow > 1) {
1174 if (cur_n_oi == 0) {
1175 jmp(done_compute, T_NEAR);
1176 } else {
1177 // r_pad fall-through
1178 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
1179 cmp(reg_owb, r_pad_fall_through_ow_block);
1180 jne(done_compute, T_NEAR);
1181 }
1182 }
1183
1184 while (cur_ow + jcp.ur_w <= jcp.ow) {
1185 if (jcp.nb_ow > 1 && cur_n_oi == 0) {
1186 L(ow_block_jmp_table[label_cntr++]);
1187 }
1188 cur_ow += jcp.ur_w;
1189 int cur_r_pad = calculate_end_padding(jcp.l_pad, cur_ow, jcp.iw,
1190 jcp.stride_w, extended_filter_size);
1191 assert(cur_r_pad > 0 || cur_ow > max_safe_ow); // else, why be here?
1192 icb_loop(jcp.ur_w, 0, cur_r_pad, cur_ow > max_safe_ow);
1193 add(reg_inp, inp_shift);
1194 add(reg_out, out_shift);
1195
1196 if (jcp.nb_ow > 1 && ++cur_n_oi == n_urw_per_ow_block) {
                // We compute one owb per jit call, so insert an unconditional
                // jmp after computing one owb.
1199 jmp(done_compute, T_NEAR);
1200 cur_n_oi = 0;
1201 }
1202 }
1203 // Let it fall-through ur_w_tail
1204 }
1205
1206 // ur_w_tail
1207 if (jcp.ur_w_tail != 0) {
1208 if (jcp.nb_ow > 1) {
1209 if (cur_n_oi == 0) {
1210 jmp(done_compute, T_NEAR);
1211 L(ow_block_jmp_table[label_cntr++]);
1212 } else {
                // When there is no r_pad region, there is an ambiguity between
                // middle_blocks and the r_pad_fall_through_ow_block. If they
                // are not properly distinguished, both paths would try to
                // compute the ur_w_tail work at the end.
1218 mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
1219 cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
1220 jne(done_compute, T_NEAR);
1221 }
1222 }
1223 icb_loop(jcp.ur_w_tail, nstl::max(0, cur_l_pad), r_pad, true);
1224 }
1225 L(done_compute);
1226 assert(ow_block_jmp_table.size() == static_cast<size_t>(label_cntr));
1227
1228 postamble();
1229
1230 if (jcp.with_eltwise) postops_injector_->prepare_table();
1231}
1232
1233template <cpu_isa_t isa>
1234status_t jit_uni_x8s8s32x_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
1235 const convolution_desc_t &cd, memory_desc_t &src_md,
1236 memory_desc_t &weights_md, memory_desc_t &dst_md,
1237 memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
1238 using namespace prop_kind;
1239
1240 const memory_desc_wrapper src_d(&src_md);
1241 const memory_desc_wrapper weights_d(&weights_md);
1242 const memory_desc_wrapper dst_d(&dst_md);
1243 const memory_desc_wrapper bias_d(&bias_md);
1244
1245 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
1246 const int ndims = src_d.ndims();
1247 const bool is_1d = ndims == 3;
1248 const bool is_2d = ndims == 4;
1249 const bool is_3d = ndims == 5;
1250 const bool is_avx2 = isa == avx2;
1251 assert(is_1d || is_2d || is_3d);
1252
1253 if (!(mayiuse(isa)
1254 && one_of(src_d.data_type(), data_type::u8, data_type::s8)
1255 && weights_d.data_type() == data_type::s8
1256 && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
1257 data_type::s8, data_type::u8)))
1258 return status::unimplemented;
1259
1260 jcp = zero<decltype(jcp)>();
1261 jcp.nthr = nthreads;
1262 jcp.ndims = ndims;
1263 jcp.prop_kind = cd.prop_kind;
1264 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1265 jcp.mb = src_d.dims()[0];
1266 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
1267 jcp.oc_without_padding = jcp.oc;
1268 jcp.ic = src_d.dims()[1] / jcp.ngroups;
1269 jcp.ic_without_padding = jcp.ic;
1270 jcp.id = is_3d ? src_d.dims()[2] : 1;
1271 jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
1272 jcp.iw = src_d.dims()[ndims - 1];
1273 jcp.od = is_3d ? dst_d.dims()[2] : 1;
1274 jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
1275 jcp.ow = dst_d.dims()[ndims - 1];
1276 jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
1277 jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
1278 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
1279 jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
1280 jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
1281 jcp.l_pad = cd.padding[0][ndims - 3];
1282 jcp.stride_d = is_3d ? cd.strides[0] : 1;
1283 jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
1284 jcp.stride_w = cd.strides[ndims - 3];
1285 jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
1286
1287 jcp.ur_h = 1; /* no code-unrolling by h so far */
1288
1289 jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
1290 jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
1291 jcp.dilate_w = cd.dilates[ndims - 3];
1292
1293 int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
1294 int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
1295 int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d);
1296 jcp.r_pad = calculate_end_padding(
1297 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
1298 jcp.b_pad = calculate_end_padding(
1299 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
1300 jcp.back_pad = calculate_end_padding(
1301 jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd);
1302
1303 jcp.has_vnni = mayiuse(avx2_vnni);
1304
1305 jcp.signed_input = src_d.data_type() == data_type::s8;
1306 jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
1307
1308 const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
1309 const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
1310 const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
1311
1312 const int wei_mask_per_oc = 1 << (int)with_groups;
1313 jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc;
1314 jcp.dst_scale = !dst_scales.has_default_values();
1315
1316 const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc)
1317 && everyone_is(src_scales.mask_, dst_scales.mask_, 0);
1318 if (!scales_ok) return status::unimplemented;
1319
1320 const auto zp = attr.zero_points_;
1321 jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST);
1322 jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC);
1323 jcp.zp_src_is_common
1324 = zp.common(DNNL_ARG_SRC); // otherwise, it's per-channel
1325 assert(IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common));
1326
1327 if ((jcp.dst_zero_point || jcp.src_zero_point) && jcp.is_fused_conv)
1328 return status::unimplemented;
1329
1330 if (is_3d && jcp.is_depthwise) return status::unimplemented;
1331
1332 if (jcp.is_depthwise) {
1333 jcp.ch_block = is_avx2 ? 8 : 4;
1334 jcp.ic_block = 1;
1335 jcp.oc_block = 1;
1336 } else {
1337 jcp.ch_block = 1;
1338 jcp.ic_block = is_avx2 ? 8 : 4;
1339 jcp.oc_block = is_avx2 ? 8 : 4;
1340
1341 if (jcp.ngroups == 1) {
            /* For non-grouped convolutions, pad channels up to a multiple of
             * the channel block if needed */
1343 jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
1344 jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
1345 } else if (jcp.ngroups != 1
1346 && ((jcp.ic % jcp.ic_block != 0)
1347 || (jcp.oc % jcp.oc_block != 0))) {
            /* For grouped convolutions, oneDNN doesn't support padding.
             * When the number of channels per group is not a multiple of 8:
             * - use the 4-channel (Xmm) blocking when it is a multiple of 4;
             * - otherwise return unimplemented. */
1352 jcp.oc_block = jcp.ic_block = 4;
1353 }
1354 if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0)
1355 return status::unimplemented;
1356 }
1357
1358 jcp.is_resrc_depthwise = true && jcp.is_depthwise && jcp.stride_w < jcp.kw
1359 && jcp.kw < 4 && jcp.dilate_w == 0;
1360
1361 if (jcp.is_depthwise) {
1362 jcp.max_regs_ur = 14 - !jcp.is_resrc_depthwise - jcp.signed_input
1363 + (jcp.has_vnni);
1364 } else {
1365 jcp.max_regs_ur = jcp.has_vnni ? 15 - jcp.signed_input : 12;
1366 }
1367
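    // Destination scale and zero-point handling reserve extra vector
    // registers, so fewer accumulators are left for unrolling.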
1368 if (jcp.dst_scale) jcp.max_regs_ur = 10;
1369 if (jcp.src_zero_point || jcp.dst_zero_point) jcp.max_regs_ur = 9;
1370
1371 auto set_or_check_wei_format = [&]() {
1372 using namespace format_tag;
1373 using namespace memory_extra_flags;
1374 const int c_mask = 0x1,
1375 g_mask = 0x3; // mask for i/o-channel and ngroups
1376 format_tag_t wei_tag;
1377 if (jcp.ic_block == 8 || jcp.ch_block == 8) {
1378 if (is_1d) {
1379 wei_tag = with_groups ? jcp.is_depthwise ? Goiw8g : gOIw2i8o4i
1380 : OIw2i8o4i;
1381 } else if (is_2d) {
1382 wei_tag = with_groups ? jcp.is_depthwise ? Goihw8g : gOIhw2i8o4i
1383 : OIhw2i8o4i;
1384 } else {
1385 wei_tag = with_groups ? gOIdhw2i8o4i : OIdhw2i8o4i;
1386 }
1387 } else {
1388 if (is_avx2) {
1389 assert(with_groups && jcp.ic_block == 4);
1390 wei_tag = is_3d ? gOIdhw4o4i : is_2d ? gOIhw4o4i : gOIw4o4i;
1391 } else {
1392 if (is_1d) {
1393 wei_tag = with_groups ? jcp.is_depthwise ? Goiw4g : gOIw4o4i
1394 : OIw4o4i;
1395 } else if (is_2d) {
1396 wei_tag = with_groups
1397 ? jcp.is_depthwise ? Goihw4g : gOIhw4o4i
1398 : OIhw4o4i;
1399 } else {
1400 wei_tag = with_groups ? gOIdhw4o4i : OIdhw4o4i;
1401 }
1402 }
1403 }
1404 memory_desc_t want_wei_md = weights_md;
1405 memory_desc_init_by_tag(want_wei_md, wei_tag);
1406
1407 if (jcp.signed_input) {
1408 want_wei_md.extra.flags = 0
1409 | memory_extra_flags::compensation_conv_s8s8
1410 | memory_extra_flags::scale_adjust;
1411 want_wei_md.extra.compensation_mask
1412 = (with_groups && !jcp.is_depthwise) ? g_mask : c_mask;
1413 want_wei_md.extra.scale_adjust = (jcp.has_vnni) ? 1.f : 0.5f;
1414 }
1415 if (jcp.src_zero_point) {
1416 want_wei_md.extra.flags |= compensation_conv_asymmetric_src;
1417 want_wei_md.extra.asymm_compensation_mask
1418 = (with_groups && !jcp.is_depthwise) ? g_mask : c_mask;
1419 }
1420
1421 if (weights_md.format_kind == format_kind::any) {
1422 weights_md = want_wei_md;
1423 return true;
1424 }
1425
1426 return weights_md == want_wei_md;
1427 };
1428
1429 if (!set_or_check_wei_format()) return status::unimplemented;
1430 format_tag_t dat_tag = utils::pick(
1431 ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
1432
1433 if (src_d.format_kind() == format_kind::any) {
1434 CHECK(memory_desc_init_by_tag(src_md, dat_tag));
1435 jcp.src_tag = dat_tag;
1436 } else {
1437 jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
1438 }
1439 if (jcp.src_tag != dat_tag) return status::unimplemented;
1440
1441 if (dst_d.format_kind() == format_kind::any) {
1442 CHECK(memory_desc_init_by_tag(dst_md, dat_tag));
1443 jcp.dst_tag = dat_tag;
1444 } else {
1445 jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
1446 }
1447 if (jcp.dst_tag != dat_tag) return status::unimplemented;
1448
1449 if (jcp.with_bias) {
1450 if (bias_d.format_kind() == format_kind::any)
1451 CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
1452 }
1453 jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
1454 jcp.dst_dt = cd.dst_desc.data_type;
1455
1456 CHECK(attr.set_default_formats(&dst_md));
1457
1458 const auto &post_ops = attr.post_ops_;
1459
1460 const int eltwise_ind = post_ops.find(primitive_kind::eltwise);
1461 jcp.with_eltwise = eltwise_ind != -1;
1462 const int binary_ind = post_ops.find(primitive_kind::binary);
1463 jcp.with_binary = binary_ind != -1;
1464 const int sum_ind = post_ops.find(primitive_kind::sum);
1465 jcp.with_sum = sum_ind != -1;
1466 jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt);
1467
1468 jcp.post_ops = post_ops;
1469
1470 using namespace injector;
1471
1472 const bool post_ops_ok_ = post_ops_ok({isa, {eltwise, binary, sum},
1473 jcp.post_ops, &dst_d, false, false, false});
1474 if (!post_ops_ok_) return status::unimplemented;
1475
1476 jcp.typesize_in = types::data_type_size(src_d.data_type());
1477 jcp.typesize_out = types::data_type_size(dst_d.data_type());
1478 jcp.typesize_bia
1479 = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
1480
1481 jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
1482 jcp.nb_ic = jcp.ic / jcp.ic_block;
1483 jcp.nb_oc = jcp.oc / jcp.oc_block;
1484
1485 // Try to use 4 channel-groups at a time to avoid false sharing (depthwise)
1486 int nb_ch_blocking = 4;
1487 for (/* init above */; nb_ch_blocking > 1; --nb_ch_blocking)
1488 if (jcp.nb_ch % nb_ch_blocking == 0) break;
1489 jcp.nb_ch_blocking = jcp.is_depthwise ? nb_ch_blocking : 1;
1490
1491 // If OC blocking is incommensurate with the number of OC blocks (general
1492 // requirement for all convolutions), or if it results in an unrolling
1493 // factor smaller than the left padding (special requirement for SSD:fc6),
1494 // then search for a smaller OC blocking that satisfies both constraints.
1495 auto is_oc_blocking_ok = [&](int block) {
1496 int ur_w = nstl::min(jcp.ow, jcp.max_regs_ur / (block + 1));
1497 return jcp.nb_oc % block == 0 && jcp.l_pad <= ur_w
1498 && jcp.ow % ur_w != 1;
1499 };
1500
    // choose the nb_oc work chunk size for distribution across threads
1502 int max_threading_nb_oc_chunk = 4;
1503
1504 jcp.nb_oc_blocking_thr_chunk
1505 = nstl::min(max_threading_nb_oc_chunk, jcp.nb_oc);
1506 for (; jcp.nb_oc_blocking_thr_chunk > 1; --jcp.nb_oc_blocking_thr_chunk) {
1507 if (is_oc_blocking_ok(jcp.nb_oc_blocking_thr_chunk)) break;
1508 }
1509
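    // Thread efficiency is the fraction of a perfectly balanced distribution:
    // work_amount / rnd_up(work_amount, nthr); the closer to 1, the less
    // imbalance across threads.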
1510 auto get_thr_eff = [=](int nb_ow, int nthr) {
1511 int base_work_amount = jcp.mb * jcp.nb_ch * jcp.od * jcp.oh
1512 * (jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk);
1513 auto work_amount = base_work_amount * nb_ow;
1514 return float(work_amount) / rnd_up(work_amount, nthr);
1515 };
1516
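    // Split ow into nb_ow blocks, accepting a smaller ow_block only when it
    // improves thread efficiency by more than 10% and stays at least ur_w
    // wide.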
1517 auto get_ow_block = [=](int ur_w, int nthr) {
1518 int res_ow_block = jcp.ow;
1519 float best_thr_eff = get_thr_eff(1, nthr);
1520 float thr_eff;
1521
1522 int max_nb_ow = div_up(jcp.ow, ur_w);
1523 for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) {
1524 int ow_block
1525 = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow);
1526 if (ow_block < jcp.nb_oc_blocking_thr_chunk * jcp.oc_block
1527 && best_thr_eff > 0.8f)
1528 break;
1529 if (div_up(jcp.ow, ow_block) != nb_ow) continue;
1530 thr_eff = get_thr_eff(nb_ow, nthr);
1531 if (ow_block >= ur_w && thr_eff > 1.1f * best_thr_eff) {
1532 res_ow_block = ow_block;
1533 best_thr_eff = thr_eff;
1534 }
1535 if (best_thr_eff > 0.9f) break;
1536 }
1537 return res_ow_block;
1538 };
1539
1540 // choose oc blocking for computational kernel
1541 jcp.nb_oc_blocking = jcp.nb_oc_blocking_thr_chunk;
1542
1543 if (jcp.is_resrc_depthwise)
1544 jcp.ur_w = (jcp.max_regs_ur - jcp.kw + jcp.stride_w)
1545 / (jcp.nb_ch_blocking + jcp.stride_w);
1546 else
1547 jcp.ur_w = jcp.max_regs_ur
1548 / (jcp.is_depthwise ? jcp.nb_ch_blocking
1549 : jcp.nb_oc_blocking + 1);
1550 if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
1551 jcp.ur_w_tail = jcp.ow % jcp.ur_w;
1552
1553 jcp.ow_block = jcp.ow;
1554
1555 jcp.ow_block = get_ow_block(jcp.ur_w, jcp.nthr);
1556 jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
1557 float thr_eff = get_thr_eff(jcp.nb_ow, jcp.nthr);
1558
    /* Adjust the thread decomposition to improve thr_eff for small problems;
     * the L1_cache_size threshold is empirical. */
1562 size_t wei_size
1563 = sizeof(float) * jcp.ic * jcp.oc * jcp.kh * jcp.kw * jcp.kd;
1564 size_t out_size
1565 = jcp.mb * jcp.typesize_out * jcp.oc * jcp.oh * jcp.ow * jcp.od;
1566 size_t inp_size
1567 = jcp.mb * jcp.typesize_in * jcp.ic * jcp.ih * jcp.iw * jcp.id;
1568 size_t total_size = jcp.ngroups * (wei_size + out_size + inp_size);
1569 const unsigned int L1_cache_size = platform::get_per_core_cache_size(1);
1570
1571 if (jcp.ngroups < jcp.nthr && (total_size < L1_cache_size)) {
1572 int ow_block = jcp.ow_block;
1573 float best_thr_eff = thr_eff;
1574 float eff = -1.0f;
1575 int end_nthr = with_groups ? jcp.ngroups : 1;
1576 for (int nthr = jcp.nthr / 2; nthr >= end_nthr; nthr--) {
1577 ow_block = get_ow_block(jcp.ur_w, nthr);
1578 eff = get_thr_eff(div_up(jcp.ow, ow_block), nthr);
1579 if (eff > 1.1f * best_thr_eff) {
1580 best_thr_eff = eff;
1581 jcp.ow_block = ow_block;
1582 jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
1583 jcp.nthr = jcp.aligned_threads = nthr;
1584 if (best_thr_eff > 0.9f) break;
1585 }
1586 }
1587 }
1588
1589 bool args_ok = true && jcp.oc % jcp.oc_block == 0
1590 && IMPLICATION(!jcp.is_1stconv, jcp.ic % jcp.ic_block == 0);
1591 if (!args_ok) return status::unimplemented;
1592
1593 pick_loop_order(jcp);
1594
1595 jcp.nb_ic_L2 = jcp.nb_ic;
1596
1597 jcp.wei_adj_scale
1598 = (weights_d.extra().flags & memory_extra_flags::scale_adjust)
1599 ? weights_d.extra().scale_adjust
1600 : 1.f;
1601
1602 return status::success;
1603}
1604
1605template <cpu_isa_t isa>
1606void jit_uni_x8s8s32x_fwd_kernel<isa>::init_scratchpad(
1607 memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
1608 const primitive_attr_t &attr) {
1609
1610 const int mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_;
1611 const dim_t scales_count = mask == 0 ? 1 : jcp.oc * jcp.ngroups;
1612 dim_t count = scales_count == 1 ? (dim_t)8 : scales_count;
1613 scratchpad.book<float>(key_conv_adjusted_scales, count);
1614}
1615
1616template struct _jit_uni_x8s8s32x_fwd_kernel<avx2, Ymm>;
1617template struct _jit_uni_x8s8s32x_fwd_kernel<avx2, Xmm>;
1618template struct _jit_uni_x8s8s32x_fwd_kernel<sse41, Xmm>;
1619template struct jit_uni_x8s8s32x_fwd_kernel<avx2>;
1620template struct jit_uni_x8s8s32x_fwd_kernel<sse41>;
1621
1622} // namespace x64
1623} // namespace cpu
1624} // namespace impl
1625} // namespace dnnl
1626