/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/memory.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/platform.hpp"

#include "cpu/x64/injectors/injector_utils.hpp"
#include "cpu/x64/injectors/jit_uni_binary_injector.hpp"
#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_uni_1x1_conv_utils.hpp"
#include "cpu/x64/jit_uni_x8s8s32x_1x1_conv_kernel.hpp"

#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
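// GET_OFF(field) yields the byte offset of `field` inside the runtime
// argument struct jit_1x1_conv_call_s, for the param1-relative loads below.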

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::utils;
using namespace dnnl::impl::data_type;
using namespace Xbyak;
using namespace injector_utils;

template <cpu_isa_t isa, typename Vmm>
_jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::_jit_uni_x8s8s32x_1x1_conv_kernel(
        const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_t &dst_md)
    : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa)
    , jcp(ajcp)
    , attr_(attr) {
    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
        using namespace binary_injector;
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = true;
        rhs_arg_static_params_t rhs_arg_static_params {15, r13, r14, r15,
                preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md)};
        static_params_t static_params {this->param1, rhs_arg_static_params};

        postops_injector_
                = utils::make_unique<injector::jit_uni_postops_injector_t<isa>>(
                        this, jcp.post_ops, static_params);
    }
}

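// Load `load_size` elements starting at reg[offset] into vmm_in and convert
// them to f32; integer inputs go through vcvtdq2ps, f32 is loaded as-is.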
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::cvt2ps(data_type_t type_in,
        const Vmm &vmm_in, const Reg64 &reg, int offset, int load_size) {
    load_data(type_in, vmm_in, reg, offset, load_size);
    if (type_in != data_type::f32) uni_vcvtdq2ps(vmm_in, vmm_in);
}

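// Walk the spatial (bcast) dimension in chunks of jcp.ur rows, invoking
// reduce_loop for each chunk; a remainder smaller than jcp.ur is handled by
// the jcp.ur_tail path at the end.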
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::bcast_loop(
        int load_loop_blk) {
    mov(aux1_reg_bcast_data, reg_bcast_data);
    mov(aux_reg_bcast_data, reg_bcast_data);

    mov(aux_reg_output_data, reg_output_data);
    mov(reg_bcast_loop_iter, ptr[rsp + bcast_loop_work_off]);

    Label bcast_loop;
    Label bcast_loop_tail;

    cmp(reg_bcast_loop_iter, jcp.ur);
    jl(bcast_loop_tail, T_NEAR);

    L(bcast_loop);
    {
        assert(jcp.bcast_block == jcp.ur);
        reduce_loop(load_loop_blk, jcp.ur, false);
        add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step);
        add(aux_reg_output_data, jcp.bcast_loop_output_step);

        sub(reg_bcast_loop_iter, jcp.bcast_block);
        cmp(reg_bcast_loop_iter, jcp.bcast_block);
        jge(bcast_loop, T_NEAR);
    }

    L(bcast_loop_tail);
    if (jcp.ur_tail) {
        Label bcast_loop_tail_out;
        cmp(reg_bcast_loop_iter, 0);
        jz(bcast_loop_tail_out, T_NEAR);
        reduce_loop(load_loop_blk, jcp.ur_tail, true);
        L(bcast_loop_tail_out);
    }
}

template <cpu_isa_t isa, typename Vmm>
int _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::output_ptr(
        const int i_load, const int i_ur) {
    const size_t ur_stride = jcp.with_dw_conv
            ? jcp.nb_load_blocking * jcp.oc_block * i_ur
            : jcp.oc_without_padding * i_ur;

    return jcp.typesize_out * (ur_stride + i_load * jcp.load_block);
}

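// Accumulators are allocated from the top of the register file down
// (vmm15, vmm14, ...) so that vmm0/xmm0 stays free for the injectors.
// Illustrative example: with load_loop_blk = 2 and ur = 3 the six
// accumulators occupy vmm15..vmm10.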
template <cpu_isa_t isa, typename Vmm>
int _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::vreg_accum_idx(
        const int load_loop_blk, const int i_load, const int i_ur) {
    const int vmm_idx = i_ur * load_loop_blk + i_load;
    assert(vmm_idx < ker_max_reg_idx);
    return (15 - vmm_idx);
}

template <cpu_isa_t isa, typename Vmm>
Vmm _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::vreg_accum(
        const int load_loop_blk, const int i_load, const int i_ur) {
    return Vmm(vreg_accum_idx(load_loop_blk, i_load, i_ur));
}

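// Apply f(i_ur, i_load) across the whole ur x load_loop_blk accumulator tile.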
template <typename F>
void iterate(const int ur, const int load_loop_blk, const F &f) {
    for (int i_ur = 0; i_ur < ur; ++i_ur)
        for (int i_load = 0; i_load < load_loop_blk; ++i_load)
            f(i_ur, i_load);
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::apply_sum(const int ur,
        const int load_loop_blk, const bool mask_flag_in,
        const float *p_sum_scale, const int32_t *p_sum_zp) {

    if (jcp.with_sum) {
        assert(!utils::any_null(p_sum_scale, p_sum_zp)
                && "p_sum_scale or p_sum_zp = nullptr");
        const float sum_scale = *p_sum_scale;
        const int32_t sum_zp = *p_sum_zp;
        const auto sum_injector_lam = [this, mask_flag_in, load_loop_blk,
                                              sum_scale, sum_zp](const int i_ur,
                                              const int i_load) {
            const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1;
            const auto vmm_prev_dst = vmm_zero;

            const auto r = vreg_accum(load_loop_blk, i_load, i_ur);
            cvt2ps(jcp.sum_dt, vmm_prev_dst, aux_reg_output_data,
                    output_ptr(i_load, i_ur),
                    mask_flag ? get_tail_size() : simd_w);

            if (sum_zp != 0) {
                uni_vbroadcastss(vmm_tmp, ptr[reg_ptr_sum_zp]);
                uni_vcvtdq2ps(vmm_tmp, vmm_tmp);
                uni_vsubps(vmm_prev_dst, vmm_prev_dst, vmm_tmp);
            }
            if (sum_scale == 1.f)
                uni_vaddps(r, r, vmm_prev_dst);
            else {
                uni_vbroadcastss(vmm_tmp, ptr[reg_ptr_sum_scale]);
                uni_vfmadd231ps(r, vmm_prev_dst, vmm_tmp);
            }
        };
        const auto sum_injector
                = [=]() { iterate(ur, load_loop_blk, sum_injector_lam); };
        if (sum_zp != 0)
            mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
        postops_injector_->set_lambda_injector(
                primitive_kind::sum, sum_injector);
    }
}

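// reg_ptr_sum_zp appears to share a GPR with another live value (note the
// reg_bcast_loop_iter_off stack slot used below), so it is spilled before
// the injector call and restored afterwards.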
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::apply_postops(const int ur,
        const int load_loop_blk, const bool mask_flag_in,
        const float *p_sum_scale, const int32_t *p_sum_zp) {

    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
        if (jcp.with_sum && *p_sum_zp != 0)
            mov(ptr[rsp + reg_bcast_loop_iter_off], reg_ptr_sum_zp);
        apply_sum(ur, load_loop_blk, mask_flag_in, p_sum_scale, p_sum_zp);

        binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
        vmm_index_set_t vmm_idxs;
        if (jcp.with_binary) {
            iterate(ur, load_loop_blk, [&](const int i_ur, const int i_load) {
                const int ur_stride = jcp.with_dw_conv
                        ? jcp.nb_load_blocking * jcp.oc_block * i_ur
                        : jcp.oc_without_padding * jcp.ngroups * i_ur;
                const size_t aux_output_offset = jcp.typesize_out
                        * (ur_stride + i_load * jcp.load_block);
                const auto vmm_idx
                        = vreg_accum_idx(load_loop_blk, i_load, i_ur);
                vmm_idxs.emplace(vmm_idx);

                rhs_arg_params.vmm_idx_to_out_reg.emplace(
                        vmm_idx, aux_reg_output_data);
                rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(
                        vmm_idx, aux_output_offset);
            });

            postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
        } else {
            iterate(ur, load_loop_blk, [&](const int i_ur, const int i_load) {
                vmm_idxs.emplace(vreg_accum_idx(load_loop_blk, i_load, i_ur));
            });
            postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
        }
        if (jcp.with_sum && *p_sum_zp != 0)
            mov(reg_ptr_sum_zp, ptr[rsp + reg_bcast_loop_iter_off]);
    }
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::reduce_loop(
        int load_loop_blk, int ur, bool wraparound) {

    // Use 0x10001 to represent two 16-bit words of 0x1 and avoid
    // uni_vpbroadcastb, which is missing from the jit generator.
    const auto xmm_one = Xmm(vmm_one.getIdx());
    mov(reg_init_bcast, 0x10001);
    uni_vmovq(xmm_one, reg_init_bcast);
    uni_vpbroadcastd(vmm_one, xmm_one);

    auto vreg_load = [&](int i_load) {
        const int vmm_idx = ur * load_loop_blk + i_load;
        assert(vmm_idx < ker_max_reg_idx);
        /* remap the register indices to
         * avoid passing xmm0 to eltwise injector */
        return Vmm(15 - vmm_idx);
    };

    auto bcast_ptr = [&](int i_reduce, int i_ur) {
        assert(i_ur < jcp.ur);
        assert(i_reduce <= jcp.reduce_loop_unroll);
        assert(jcp.reduce_loop_unroll == jcp.reduce_block);

        int offt = (jcp.ic_without_padding * i_ur + i_reduce);

        return ptr[aux_reg_bcast_data + jcp.typesize_in * offt];
    };

    auto load_ptr = [&](int i_reduce, int i_load) {
        int u0 = i_reduce % jcp.reduce_loop_unroll;
        int u1 = i_reduce / jcp.reduce_loop_unroll;

        int offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block;

        return ptr[aux_reg_load_data + u1 * jcp.reduce_loop_load_step
                + jcp.typesize_in * offt];
    };

    auto init = [&]() {
        for (int i_load = 0; i_load < load_loop_blk; ++i_load)
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                uni_vpxor(r, r, r);
            }
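        // vmm_shift holds 128 in every byte: signed int8 source values are
        // biased into the u8 domain before vpmaddubsw, and the precomputed
        // compensation (vmm_comp, added back in store()) undoes the bias.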
        if (jcp.signed_input) {
            // Use 0x80808080 to represent four bytes of 128 and avoid
            // uni_vpbroadcastb, which is missing from the jit generator.
            auto xmm_shift = Xbyak::Xmm(vmm_shift.getIdx());
            auto _t32 = reg_init_bcast.cvt32();
            mov(_t32, 0x80808080);
            uni_vpinsrd(xmm_shift, xmm_shift, _t32, 0);
            uni_vpbroadcastd(vmm_shift, xmm_shift);
        }
    };

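    // store(): convert the s32 accumulators to f32, fold in the input
    // compensation terms, per-channel scales and bias, run the post-ops,
    // apply dst scale/zero-point, saturate back to the dst integer type,
    // and write the tile out.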
    auto store = [&](const bool mask_flag_in) {
        const auto &p = attr_.post_ops_;
        const int sum_idx = p.find(primitive_kind::sum);
        const float *p_sum_scale
                = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr;
        const int32_t *p_sum_zp
                = (sum_idx != -1) ? &p.entry_[sum_idx].sum.zero_point : nullptr;
        mov(ptr[rsp + reg_bcast_data_off], reg_bcast_data);
        mov(reg_ptr_scales, ptr[rsp + reg_ptr_sum_scale_off]);
        if (p_sum_scale && *p_sum_scale != 1.f) {
            mov(ptr[rsp + reg_load_data_off], reg_load_data);
            mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
        }
        if (jcp.src_zero_point) {
            mov(reg_zp_compensation, ptr[rsp + reg_zp_compensation_off]);
            mov(reg_src_zero_point, ptr[rsp + reg_src_zero_point_off]);
        }
        for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
            if (jcp.src_zero_point) {
                uni_vpbroadcastd(vmm_zp, ptr[reg_src_zero_point]);
            }
            const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1;
            const int load_size = mask_flag ? get_tail_size() : simd_w;
            const auto ptr_scales_offset
                    = jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load);
            if (jcp.with_bias) {
                if (jcp.signed_input || jcp.dst_scale)
                    mov(reg_bias_data, ptr[rsp + reg_bias_data_off]);
                cvt2ps(jcp.bia_dt, vmm_bias, reg_bias_data,
                        jcp.typesize_bia * jcp.oc_block * i_load, load_size);
            }
            if (jcp.signed_input) {
                mov(reg_comp_data, ptr[rsp + reg_comp_data_off]);
                cvt2ps(data_type::s32, vmm_comp, reg_comp_data,
                        sizeof(int32_t) * jcp.oc_block * i_load, load_size);
            }
            if (jcp.src_zero_point) {
                const int zp_offset = sizeof(int32_t) * i_load * jcp.oc_block;
                load_data(data_type::s32, vmm_zp_comp, reg_zp_compensation,
                        zp_offset, load_size);
                uni_vpmulld(vmm_zp_comp, vmm_zp_comp, vmm_zp);

                // convert to f32
                uni_vcvtdq2ps(vmm_zp_comp, vmm_zp_comp);
            }

            if (mask_flag) {
                uni_vpxor(vmm_scale, vmm_scale, vmm_scale);
                cvt2ps(data_type::f32, vmm_scale, reg_ptr_scales,
                        ptr_scales_offset, get_tail_size());
            } else {
                uni_vmovups(vmm_scale, ptr[reg_ptr_scales + ptr_scales_offset]);
            }

            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                const auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                uni_vcvtdq2ps(r, r);
                if (jcp.signed_input) uni_vaddps(r, r, vmm_comp);
                if (jcp.src_zero_point) uni_vaddps(r, r, vmm_zp_comp);

                uni_vmulps(r, r, vmm_scale);

                if (jcp.with_bias) uni_vaddps(r, r, vmm_bias);
            }
        }

        apply_postops(ur, load_loop_blk, mask_flag_in, p_sum_scale, p_sum_zp);

        if (jcp.dst_scale) {
            mov(reg_ptr_dst_scale, ptr[rsp + reg_dst_scale_off]);
            uni_vmovups(vmm_dst_scale, ptr[reg_ptr_dst_scale]);

            /* Apply dst scale to accumulator */
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    const auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                    uni_vmulps(r, r, vmm_dst_scale);
                }
            }
        }

        if (jcp.dst_zero_point) {
            mov(reg_dst_zero_point, ptr[rsp + reg_dst_zero_point_off]);
            uni_vpbroadcastd(vmm_zp, ptr[reg_dst_zero_point]);
            uni_vcvtdq2ps(vmm_zp, vmm_zp);

            /* Add dst zero_point to accumulator */
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    const auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                    uni_vaddps(r, r, vmm_zp);
                }
            }
        }

        // Properly saturate the accumulators for integer datatypes
        if (utils::one_of(jcp.dst_dt, u8, s8, s32)) {
            init_saturate_f32(vmm_zero, vmm_saturation, aux_reg_saturation, f32,
                    jcp.dst_dt);

            for (int i_ur = 0; i_ur < ur; ++i_ur)
                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                    saturate_f32(r, vmm_zero, vmm_saturation, jcp.dst_dt);
                    uni_vcvtps2dq(r, r);
                }
        }

        /* write out register to output_addr */
        for (int i_ur = 0; i_ur < ur; ++i_ur) {
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                const bool mask_flag
                        = mask_flag_in && i_load == load_loop_blk - 1;
                auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                store_data(jcp.dst_dt, r, aux_reg_output_data,
                        output_ptr(i_load, i_ur),
                        mask_flag ? get_tail_size() : simd_w);
            }
        }
        mov(reg_bcast_data, ptr[rsp + reg_bcast_data_off]);
        if (p_sum_scale && *p_sum_scale != 1.f)
            mov(reg_load_data, ptr[rsp + reg_load_data_off]);
    };

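    // One 4-way int8 dot product per dword lane. With VNNI this is a single
    // vpdpbusd; otherwise it is emulated in three steps: vpmaddubsw
    // multiplies u8 x s8 and adds adjacent pairs to s16 (this step may
    // saturate for extreme inputs), vpmaddwd against vmm_one (each 16-bit
    // word = 1) widens and adds adjacent pairs to s32, and vpaddd
    // accumulates.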
    auto compute = [&](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
        if (jcp.has_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei, VexEncoding);
        } else {
            uni_vpmaddubsw(vmm_tmp, vreg_src, vreg_wei);
            uni_vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
            uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp);
        }
    };

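    // fma_block walks the reduce (IC) dimension four input channels at a
    // time: each step broadcasts one dword (4 x int8) of the source row and
    // multiplies it against the matching 4-channel weight slices. On the last
    // block of a channel count that is not a multiple of 4, only
    // ic_tail_size bytes are loaded before the broadcast so the read stays
    // in bounds.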
    auto fma_block = [&](bool last_block) {
        int reduce_step = 4;
        int ic_tail_size = jcp.ic_without_padding % reduce_step;
        int loop_unroll = last_block && jcp.ic != jcp.ic_without_padding
                ? rnd_up(jcp.ic_without_padding % jcp.ic_block, reduce_step)
                : jcp.reduce_loop_unroll;
        for (int i_reduce = 0; i_reduce < loop_unroll;
                i_reduce += reduce_step) {
            for (int i_load = 0; i_load < load_loop_blk; ++i_load)
                uni_vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load));
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                if (last_block && ic_tail_size != 0
                        && i_reduce == loop_unroll - reduce_step) {
                    load_bytes(vmm_bcast, aux_reg_bcast_data,
                            jcp.ic_without_padding * i_ur + i_reduce,
                            ic_tail_size);
                    uni_vpbroadcastd(vmm_bcast, Xmm(vmm_bcast.getIdx()));
                } else {
                    uni_vpbroadcastd(vmm_bcast, bcast_ptr(i_reduce, i_ur));
                }
                if (jcp.signed_input)
                    uni_vpsubb(vmm_bcast, vmm_bcast, vmm_shift);
                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    compute(vreg_accum(load_loop_blk, i_load, i_ur),
                            vreg_load(i_load), vmm_bcast);
                }
            }
        }
    };

    Label reduce_loop;
    Label reduce_loop_tail;

    mov(aux_reg_load_data, reg_load_data);

    mov(aux_reg_bcast_data, aux1_reg_bcast_data);
    init();

    mov(reg_reduce_loop_iter, reg_reduce_loop_work);
    sub(reg_reduce_loop_iter, jcp.reduce_loop_unroll);
    jle(reduce_loop_tail, T_NEAR);

    L(reduce_loop);
    {
        fma_block(false);
        add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reg_reduce_loop_iter, jcp.reduce_loop_unroll);
        jg(reduce_loop, T_NEAR);
    }

    L(reduce_loop_tail);
    fma_block(jcp.ic != jcp.ic_without_padding);

    if (jcp.oc_without_padding != jcp.oc) {
        Label end_store, common_store;
        mov(ptr[rsp + reg_bcast_data_off], reg_bcast_data);

        /* Check if it is the last load_loop_blk */
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
        cmp(reg_load_loop_work, 0);
        jg(common_store, T_NEAR);

        /* Check if it is the last ocb */
        test(reg_reduce_pos_flag, FLAG_OC_LAST);
        jz(common_store, T_NEAR);

        store(true);
        jmp(end_store, T_NEAR);

        L(common_store);
        store(false);

        L(end_store);

        add(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
    } else {
        store(false);
    }
}

template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_1x1_conv_kernel<isa, Vmm>::generate() {
    preamble();

    sub(rsp, stack_space_needed);
    if (jcp.with_binary) {
        // zero initialize binary post_ops offset accumulator (store on stack)
        const auto binary_post_op_acc_off_reg = r15;
        xor_(binary_post_op_acc_off_reg, binary_post_op_acc_off_reg);
        mov(ptr[rsp + reg_binary_post_op_acc_off], binary_post_op_acc_off_reg);
    }

    if (jcp.with_bias) mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
    if (jcp.signed_input) {
        mov(ptr[rsp + reg_bias_data_off], reg_bias_data);
        mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]);
        mov(ptr[rsp + reg_comp_data_off], reg_comp_data);
    }
    if (jcp.src_zero_point) {
        mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
        mov(ptr[rsp + reg_zp_compensation_off], reg_zp_compensation);
        mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
        mov(ptr[rsp + reg_src_zero_point_off], reg_src_zero_point);
    }
    if (jcp.dst_scale) {
        if (!jcp.signed_input) mov(ptr[rsp + reg_bias_data_off], reg_bias_data);
        mov(reg_ptr_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
        mov(ptr[rsp + reg_dst_scale_off], reg_ptr_dst_scale);
    }
    if (jcp.dst_zero_point) {
        mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
        mov(ptr[rsp + reg_dst_zero_point_off], reg_dst_zero_point);
    }
    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
    mov(ptr[rsp + reg_ptr_sum_scale_off], reg_ptr_scales);
    mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
    mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
    mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);

    mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
    mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
    mov(ptr[rsp + bcast_loop_work_off], reg_bcast_loop_work);
    mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
    mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);

    auto load_loop_body = [&](int load_loop_blk) {
        bcast_loop(load_loop_blk);
        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
        if (jcp.with_bias) {
            if (jcp.signed_input || jcp.dst_scale)
                mov(reg_bias_data, ptr[rsp + reg_bias_data_off]);
            add(reg_bias_data,
                    load_loop_blk * jcp.load_block * jcp.typesize_bia);
            if (jcp.signed_input || jcp.dst_scale)
                mov(ptr[rsp + reg_bias_data_off], reg_bias_data);
        }
        if (jcp.with_binary) {
            mov(aux_reg_load_data,
                    EVEX_compress_addr(rsp, reg_binary_post_op_acc_off));
            add(aux_reg_load_data, jcp.load_block * load_loop_blk);
            mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off),
                    aux_reg_load_data);
        }
        if (jcp.signed_input) {
            mov(reg_comp_data, ptr[rsp + reg_comp_data_off]);
            add(reg_comp_data,
                    load_loop_blk * jcp.load_block * sizeof(int32_t));
            mov(ptr[rsp + reg_comp_data_off], reg_comp_data);
        }
        if (jcp.src_zero_point) {
            mov(reg_zp_compensation, ptr[rsp + reg_zp_compensation_off]);
            add(reg_zp_compensation,
                    load_loop_blk * jcp.load_block * sizeof(int32_t));
            mov(ptr[rsp + reg_zp_compensation_off], reg_zp_compensation);
        }
        mov(ptr[rsp + reg_bcast_data_off], reg_bcast_data);
        mov(reg_ptr_scales, ptr[rsp + reg_ptr_sum_scale_off]);
        add(reg_ptr_scales,
                jcp.is_oc_scale * load_loop_blk * jcp.load_block
                        * sizeof(float));
        mov(ptr[rsp + reg_ptr_sum_scale_off], reg_ptr_scales);
        mov(reg_bcast_data, ptr[rsp + reg_bcast_data_off]);
        add(reg_output_data, load_loop_blk * jcp.load_block * jcp.typesize_out);
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
    };

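    // Dispatch over load_loop_blk (1..4 OC blocks per iteration). Larger
    // jcp.ur values leave fewer spare registers: ur * load_loop_blk
    // accumulators plus load_loop_blk weight registers are mapped downward
    // from vmm15 (see vreg_accum_idx and vreg_load), so ur_cases caps the
    // unroll that each block count can afford.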
    static const int ur_cases[] = {2, 3, 5, 12};
    constexpr int num_ur_cases = sizeof(ur_cases) / sizeof(*ur_cases);
    Label load_loop_blk[num_ur_cases + 1];

    for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
        int label_idx = num_ur_cases - ur_idx - 1;
        if (jcp.ur <= ur_cases[ur_idx]) {
            cmp(reg_load_loop_work, simd_w * (label_idx + 1));
            jle(load_loop_blk[label_idx], T_NEAR);
        }
    }

    for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
        if (jcp.ur <= ur_cases[ur_idx]) {
            int label_idx = num_ur_cases - ur_idx - 1;
            L(load_loop_blk[label_idx]);
            {
                if (label_idx == 0) {
                    cmp(reg_load_loop_work, 0);
                    je(load_loop_blk[num_ur_cases], T_NEAR);
                }

                load_loop_body(label_idx + 1);
                if (label_idx - 1 > 0) {
                    cmp(reg_load_loop_work, 2 * label_idx * simd_w);
                    je(load_loop_blk[label_idx - 1], T_NEAR);
                }
                cmp(reg_load_loop_work, (label_idx + 1) * simd_w);
                jge(load_loop_blk[label_idx]);
            }
            for (int idx = label_idx - 1; idx > 0; --idx) {
                cmp(reg_load_loop_work, simd_w * (idx + 1));
                je(load_loop_blk[idx], T_NEAR);
            }
            if (ur_idx < num_ur_cases - 2) {
                cmp(reg_load_loop_work, simd_w);
                jle(load_loop_blk[0], T_NEAR);
            }
        }
    }
    L(load_loop_blk[num_ur_cases]);
    add(rsp, stack_space_needed);
    postamble();

    if (jcp.with_eltwise) postops_injector_->prepare_table();
}

template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_1x1_conv_kernel<isa>::init_conf(
        jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd,
        const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d,
        primitive_attr_t &attr, int nthreads, bool reduce_src) {
    if (!mayiuse(isa)) return status::unimplemented;

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    if (!one_of(src_d.data_type(), data_type::u8, data_type::s8)
            || weights_d.data_type() != data_type::s8
            || !one_of(dst_d.data_type(), data_type::f32, data_type::s32,
                    data_type::s8, data_type::u8))
        return status::unimplemented;

    const int ndims = src_d.ndims();
    jcp.nthr = nthreads;
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ic_without_padding = jcp.ic;
    jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
    jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims - 2];
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims - 2];
    jcp.ow = dst_d.dims()[ndims - 1];
    jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
    jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims - 4];
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims - 4];
    jcp.stride_w = cd.strides[ndims - 3];
    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    jcp.signed_input = (src_d.data_type() == data_type::s8);

    jcp.os = static_cast<dim_t>(jcp.od) * jcp.oh * jcp.ow;
    jcp.is = static_cast<dim_t>(jcp.id) * jcp.ih * jcp.iw;

    const auto &post_ops = attr.post_ops_;
    const int dw_conv_ind = post_ops.find(primitive_kind::convolution);
    jcp.with_dw_conv = dw_conv_ind != -1;
    // Using dw_conv_ind as upper-bound below, as post-ops after it will be
    // handled in depthwise convolution.
    const int eltwise_ind
            = post_ops.find(primitive_kind::eltwise, 0, dw_conv_ind);
    jcp.with_eltwise = eltwise_ind != -1;

    const int binary_ind
            = post_ops.find(primitive_kind::binary, 0, dw_conv_ind);
    jcp.with_binary = binary_ind != -1;

    const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind);
    jcp.with_sum = sum_ind != -1;

    const auto zp = attr.zero_points_;
    jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST);
    jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC);
    jcp.zp_src_is_common
            = zp.common(DNNL_ARG_SRC); // otherwise, it's per-channel
    assert(IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common));

    if ((jcp.dst_zero_point || jcp.src_zero_point) && jcp.with_dw_conv)
        return status::unimplemented;

    format_tag_t dat_tag = utils::pick(
            ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
    jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
    jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);

    bool args_ok = true && jcp.ngroups == 1 && jcp.src_tag == dat_tag
            && jcp.dst_tag == dat_tag;
    if (!args_ok) return status::unimplemented;

    jcp.has_vnni = mayiuse(avx2_vnni);

    jcp.oc = rnd_up(jcp.oc, simd_w);
    jcp.ic = rnd_up(jcp.ic, simd_w);

    if (dw_conv_ind >= 0) {
        // dw_conv and post_ops after it are handled externally, so skip them
        jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(),
                post_ops.entry_.cbegin() + dw_conv_ind);
    } else {
        jcp.post_ops = post_ops;
    }

    for (auto &post_op : jcp.post_ops.entry_)
        if (post_op.is_binary() && post_op.binary.src1_desc.dims[1] != 1) {
            post_op.binary.src1_desc.dims[1] = jcp.oc;
        }

    using namespace injector;
    const bool post_ops_ok_ = post_ops_ok({isa, {eltwise, binary, sum},
            jcp.post_ops, &dst_d, false, false, false});
    if (!post_ops_ok_) return status::unimplemented;

    args_ok = true && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
            && jcp.f_pad == 0 && jcp.t_pad == 0 && jcp.l_pad == 0
            && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1
            && jcp.ow == jcp.iw && jcp.oh == jcp.ih && jcp.od == jcp.id
            && jcp.kd == 1 && jcp.kh == 1 && jcp.kw == 1;
    if (!args_ok) return status::unimplemented;

    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    jcp.dst_dt = cd.dst_desc.data_type;
    jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt);

    jcp.ic_block = jcp.oc_block = simd_w;

    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());
    jcp.typesize_bia
            = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;

    const int SMALL_SPATIAL = 7 * 7;
    const int BIG_REDUCE_DIM = 512;

    int load_blocking = 0;
    int load_blocking_max = 0;
    int bcast_blocking = 0;
    int bcast_blocking_max = 0;
    int reduce_blocking = 0;
    int reduce_blocking_max = 0;
    jcp.load_grp_count = 1;

    const int L2_size
            = platform::get_per_core_cache_size(2) / sizeof(jcp.typesize_in);
    const int L2_capacity = (L2_size * 3) / 4;

    int size_threshold = 28;

    int min_regs = 3;
    int max_regs = 5;

    if (jcp.mb == 1 && jcp.ic > 128
            && (jcp.oh <= size_threshold && jcp.ow <= size_threshold)) {
        if (jcp.os <= SMALL_SPATIAL && jcp.oc * jcp.ic < L2_size)
            max_regs = min_regs = 3;
        jcp.ur = nstl::min<dim_t>(max_regs, jcp.os);
    } else {
        const int spatial = jcp.od * jcp.oh;
        jcp.ur = 1;
        for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) {
            if ((spatial >= size_threshold && spatial % ur_w == 0)
                    || (spatial < size_threshold && jcp.os % ur_w == 0)) {
                jcp.ur = ur_w;
                break;
            }
        }
        if (jcp.ur == 1) {
            jcp.ur = nstl::min<dim_t>(max_regs, jcp.os);
            int os_tail = jcp.os % max_regs;
            for (int i = max_regs; i >= min_regs; i--) {
                int i_tail = jcp.os % i;
                if (i_tail > os_tail || i_tail == 0) {
                    jcp.ur = i;
                    os_tail = i_tail;
                    if (i_tail == 0) break;
                }
            }
        }
    }
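    // Illustrative example: with od = 1, oh = ow = 28, spatial = 28 meets
    // size_threshold; ur = 5 does not divide 28 but ur = 4 does, so the loop
    // above settles on jcp.ur = 4.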

    if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur);
    jcp.reduce_dim = jcp.ic;
    jcp.reduce_block = jcp.ic_block;

    jcp.load_dim = jcp.oc;
    jcp.load_block = jcp.oc_block;

    jcp.bcast_dim = jcp.is;

    jcp.bcast_block = jcp.ur;

    jcp.reduce_loop_unroll = jcp.reduce_block;
    jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll * jcp.typesize_in;

    jcp.reduce_loop_load_step
            = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;

    jcp.bcast_loop_output_step
            = jcp.ur * jcp.oc_without_padding * jcp.typesize_out;
    jcp.bcast_loop_bcast_step
            = jcp.ur * jcp.ic_without_padding * jcp.typesize_in;

    jcp.load_loop_load_step = jcp.reduce_dim * jcp.load_block * jcp.typesize_in;

    jcp.load_loop_iter_step = jcp.load_block;

    jcp.loop_order = reduce_src ? loop_blr : loop_lbr;

    int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
    int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

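    // Split a big reduce (IC) dimension into chunks so the weights slice for
    // one chunk stays cache-resident; a small spatial footprint can afford
    // larger chunks (64 blocks) than a large one (16).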
    reduce_blocking = nb_reduce;
    if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
        reduce_blocking = 64;
    else if (jcp.bcast_dim > SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
        reduce_blocking = 16;

    reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true);
    reduce_blocking *= jcp.reduce_block;

    bool cmp_reduce = reduce_blocking <= jcp.reduce_dim;
    if (cmp_reduce) jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
    load_blocking = jcp.load_dim;

    jcp.load_grp_count = div_up(jcp.nthr, jcp.mb * jcp.ngroups * nb_bcast);
    jcp.load_grp_count = best_divider(
            jcp.nthr, jcp.load_grp_count, 2 * jcp.load_grp_count, false);

    if (jcp.bcast_dim <= SMALL_SPATIAL
            && jcp.load_dim * jcp.reduce_dim >= L2_size) {
        jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
    } else if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.mb <= jcp.nthr
            && jcp.load_dim > 256 && jcp.load_dim / jcp.reduce_dim >= 4) {
        jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2);
        load_blocking = jcp.load_block;
    }

    bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
                             div_up(jcp.nthr, jcp.load_grp_count))
            * jcp.bcast_block;
    bcast_blocking = nstl::min<dim_t>(jcp.bcast_dim, bcast_blocking);
    bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);

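    // Budget the bcast (source) tile against what is left of L2 after the
    // weights chunk (2 load blocks x reduce_blocking) and the in-flight
    // bcast rows (jcp.ur x reduce_blocking), minus ~3 KB of slack; halve it
    // when the problem clearly overflows L2.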
    int space_for_bcast = (L2_capacity - /* kernel_size - */
            2 * jcp.load_block * reduce_blocking - jcp.ur * reduce_blocking
            - 3 * 1024);
    if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) space_for_bcast /= 2;

    int bcast_in_cache
            = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
    bcast_blocking = nstl::min(
            bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));

    load_blocking_max = load_blocking;
    bcast_blocking_max = bcast_blocking * 3 / 2;
    reduce_blocking_max = reduce_blocking;

    const bool params_ok = true && load_blocking > 0 && load_blocking_max > 0
            && bcast_blocking > 0 && bcast_blocking_max > 0
            && reduce_blocking > 0 && reduce_blocking_max > 0
            && load_blocking % jcp.load_block == 0
            && reduce_blocking % jcp.reduce_block == 0
            && load_blocking_max % jcp.load_block == 0
            && reduce_blocking_max % jcp.reduce_block == 0
            && jcp.reduce_loop_unroll % 4 == 0
            && jcp.reduce_dim % jcp.reduce_loop_unroll == 0
            && jcp.bcast_block % jcp.ur == 0
            && jcp.reduce_dim % jcp.reduce_block == 0;

    assert(params_ok && "parameter values are inconsistent");
    if (!params_ok) return status::unimplemented;

    jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.ur;

    jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
    jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
    jcp.nb_load_blocking = load_blocking / jcp.load_block;
    jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
    jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
    jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block;

    jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
    jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
    jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    // minimum size of load dim chunk for work distribution within threads
    jcp.nb_load_chunk = 1;

    const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
    const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
    const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
    const int wei_mask_per_oc = 1 << (int)with_groups;
    jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc;
    jcp.dst_scale = !dst_scales.has_default_values();

    const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc)
            && everyone_is(src_scales.mask_, dst_scales.mask_, 0);
    if (!scales_ok) return status::unimplemented;

    jcp.wei_adj_scale
            = (weights_d.extra().flags & memory_extra_flags::scale_adjust)
            ? weights_d.extra().scale_adjust
            : 1.f;

    return status::success;
}

template <cpu_isa_t isa>
void jit_uni_x8s8s32x_1x1_conv_kernel<isa>::init_scratchpad(
        memory_tracking::registrar_t &scratchpad,
        const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
    using namespace dnnl::impl::memory_tracking::names;

    const int wei_mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_;
    const dim_t scales_count = wei_mask == 0 ? 1 : jcp.oc * jcp.ngroups;
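    // Book at least 8 floats, presumably so that full-width vector loads of
    // the adjusted scales never read past the end of the buffer.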
    const dim_t count = nstl::max<dim_t>(scales_count, 8);
    scratchpad.book<float>(key_conv_adjusted_scales, count);
}

template struct _jit_uni_x8s8s32x_1x1_conv_kernel<avx2, Ymm>;
template struct _jit_uni_x8s8s32x_1x1_conv_kernel<sse41, Xmm>;
template struct jit_uni_x8s8s32x_1x1_conv_kernel<avx2>;
template struct jit_uni_x8s8s32x_1x1_conv_kernel<sse41>;
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl