/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/jit_uni_eltwise_int.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;

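// Argument block passed to the generated kernel: source/destination pointers
// and the number of elements to process. `for_comparison` mirrors `from` here;
// the generated code below only reads `from`, `to`, and `work_amount`.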
struct jit_args_int8_t {
    const void *from;
    const void *for_comparison;
    const void *to;
    size_t work_amount;
};

struct jit_uni_eltwise_int_kernel : public jit_generator {
    jit_uni_eltwise_int_kernel(const eltwise_pd_t *pd, const char *name)
        : jit_generator(name), pd_(pd) {}

    void operator()(jit_args_int8_t *p) { jit_generator::operator()(p); }

protected:
    data_type_t data_type() const { return pd_->src_md()->data_type; }
    int dtype_size() const { return types::data_type_size(data_type()); }

    const eltwise_desc_t &desc() const { return *pd_->desc(); }

private:
    const eltwise_pd_t *pd_;
};

/* jit kernels */
namespace {
using namespace Xbyak;

template <cpu_isa_t isa>
struct jit_uni_subkernel_int_t : public jit_uni_eltwise_int_kernel {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_subkernel_int)

    jit_uni_subkernel_int_t(const eltwise_pd_t *pd)
        : jit_uni_eltwise_int_kernel(pd, jit_name()) {
        using namespace data_type;

        // ReLU and linear for int types: s32, s8, u8; forward direction only
        assert(utils::one_of(desc().alg_kind, alg_kind::eltwise_relu,
                alg_kind::eltwise_linear));
        assert(utils::one_of(data_type(), s32, s8, u8));
        assert(utils::one_of(isa, sse41, avx2, avx512_core));
    }

    void generate() override {
        Reg64 param = abi_param1;

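        // Per-pass loop configuration: index 0 is the vectorized main loop
        // (simd_w elements per iteration; int data is widened to 32-bit
        // lanes, hence vlen / sizeof(float)), index 1 is the scalar tail
        // (one element per iteration). `uf` is the unroll factor (currently
        // 1 for both passes) and `shift` is the byte stride per step.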
        const size_t vlen = cpu_isa_traits<isa>::vlen;
        const size_t simd_w = vlen / sizeof(float);
        const size_t loop_dec[] = {simd_w, 1};
        const size_t uf[] = {1, 1};
        const size_t shift[] = {dtype_size() * simd_w, (size_t)dtype_size()};
        const bool loop_vectorize[] = {true, false};

        preamble();

#define GET_OFF(field) offsetof(jit_args_int8_t, field)
        mov(reg_from, ptr[param + GET_OFF(from)]);
        mov(reg_to, ptr[param + GET_OFF(to)]);
        mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
#undef GET_OFF

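        // Broadcast alpha and beta to all lanes: float2int() reinterprets
        // the f32 bit pattern as an integer so the constant can travel
        // through a GPR into an Xmm, then uni_vbroadcastss replicates it
        // across the full vector.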
        mov(imm_addr64, float2int(desc().alpha));
        uni_vmovq(xmm_alpha, imm_addr64);
        uni_vbroadcastss(vmm_alpha, xmm_alpha);

        mov(imm_addr64, float2int(desc().beta));
        uni_vmovq(xmm_beta, imm_addr64);
        uni_vbroadcastss(vmm_beta, xmm_beta);

        uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
        xor_(reg_int8, reg_int8);
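        // On AVX-512 the scalar tail stores a single byte via a masked
        // vpmovsdb/vpmovusdb, so prepare an opmask with only bit 0 set.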
        if (isa == avx512_core) {
            mov(reg_int8.cvt8(), 0x01);
            kmovw(k_mask_int8, reg_int8.cvt32());
        }

        Label loop_label[3];

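        // Two passes over the data: id = 0 runs the vectorized body while at
        // least uf * simd_w elements remain, id = 1 finishes the tail one
        // element at a time; loop_label[2] is the common exit.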
        for (int id = 0; id < 2; id++) {
            L(loop_label[id]);
            cmp(reg_work_amount, uf[id] * loop_dec[id] - 1);
            jle(loop_label[id + 1], T_NEAR);

            compute_step(
                    loop_vectorize[id], uf[id], shift[id], desc().alg_kind);

            add(reg_from, uf[id] * shift[id]);
            add(reg_to, uf[id] * shift[id]);

            sub(reg_work_amount, uf[id] * loop_dec[id]);
            jmp(loop_label[id]);
        }

        L(loop_label[2]);
        postamble();
    }

private:
    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    using opmask_t = const Xbyak::Opmask;

    Reg64 reg_from = rax;
    Reg64 reg_to = r8;
    Reg64 reg_work_amount = rsi;
    Reg64 imm_addr64 = rbx;
    Reg64 reg_int8 = r9;

    Xmm xmm_alpha = Xmm(13);
    Xmm xmm_beta = Xmm(14);

    Vmm vmm_tmp = Vmm(isa == avx512_core ? 26 : 11);
    Vmm vmm_alpha = Vmm(isa == avx512_core ? 27 : 13);
    Vmm vmm_beta = Vmm(isa == avx512_core ? 28 : 14);
    Vmm vmm_zero = Vmm(isa == avx512_core ? 29 : 15);
    Vmm vmm_mask = Vmm(isa == avx512_core ? 30 : 12);

    opmask_t k_mask = k1;
    opmask_t k_mask_int8 = k2; // mask for storing a single byte on AVX-512

    bool is32bit() const { return data_type() == data_type::s32; }

    // Load 32-bit data type (s32)
    void load_32bit(
            const bool vectorize, const Vmm &vr_from, const Address &mem_from) {

        if (vectorize) {
            // load full Vmm size
            uni_vmovups(vr_from, mem_from);
        } else {
            // load exactly one data item
            movss(Xmm(vr_from.getIdx()), mem_from);
        }
    }

    // Load 8-bit data type (u8/s8)
    void load_8bit(const bool vectorize, const Vmm &vr_from,
            const Address &mem_from, bool is_signed) {

        // u8/s8 data is loaded and sign-/zero-extended to s32
        if (vectorize) {
            // load full Vmm size
            if (is_signed)
                uni_vpmovsxbd(vr_from, mem_from);
            else
                uni_vpmovzxbd(vr_from, mem_from);
        } else {
            // load exactly one data item
            mov(reg_int8.cvt8(), mem_from);
            if (is_signed)
                movsx(reg_int8.cvt32(), reg_int8.cvt8());
            else
                movzx(reg_int8.cvt32(), reg_int8.cvt8());
            uni_vmovq(Xmm(vr_from.getIdx()), reg_int8);
        }
    }

    // Load vregs with data from mem
    void load(
            const bool vectorize, const Vmm &vr_from, const Address &mem_from) {

        // Branching on data size
        if (is32bit())
            load_32bit(vectorize, vr_from, mem_from);
        else
            load_8bit(
                    vectorize, vr_from, mem_from, data_type() == data_type::s8);
    }

    // Processing
    void process_linear(const Vmm &vr_to, const Vmm &vr_from);
    void process_relu(const Vmm &vr_to, const Vmm &vr_from);

    // Store s32 for any isa
    void store_32bit(
            const bool vectorize, const Address &mem_to, const Vmm &vr_to) {
        if (vectorize) {
            // store full Vmm size
            uni_vmovups(mem_to, vr_to);
        } else {
            // store exactly one data item
            movss(mem_to, Xmm(vr_to.getIdx()));
        }
    }

    // Store 8-bit int - isa-dependent
    void store_8bit(const bool vectorize, const Address &mem_to,
            const Vmm &vr_to, bool is_signed);

    // Store results from vregs to mem
    void store(const bool vectorize, const Address &mem_to, const Vmm &vr_to) {
        // Branching on data size
        if (is32bit())
            store_32bit(vectorize, mem_to, vr_to);
        else
            store_8bit(vectorize, mem_to, vr_to, data_type() == data_type::s8);
    }

    void compute_step(bool vectorize, const size_t uf, const size_t shift,
            const alg_kind_t alg) {

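        // Register allocation: Vmm(0) stays free (SSE4.1 blendvps implicitly
        // reads its mask from xmm0), sources occupy Vmm(1..uf) and results
        // Vmm(uf + 1..2 * uf).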
        auto vreg_from = [&](const size_t i) -> Vmm { return Vmm(i + 1); };
        auto vreg_to = [&](const size_t i) -> Vmm { return Vmm(uf + i + 1); };

        // 1. Load (vregs <- mem)
        for (size_t i = 0; i < uf; i++)
            load(vectorize, vreg_from(i), ptr[reg_from + i * shift]);

        // 2. Process (vregs <- vregs)
        switch (alg) {
            case alg_kind::eltwise_linear:
                for (size_t i = 0; i < uf; i++)
                    process_linear(vreg_to(i), vreg_from(i));
                break;
            case alg_kind::eltwise_relu:
                for (size_t i = 0; i < uf; i++)
                    process_relu(vreg_to(i), vreg_from(i));
                break;
            default: assert(!"unsupported alg");
        }

        // 3. Store (mem <- vregs)
        for (size_t i = 0; i < uf; i++)
            store(vectorize, ptr[reg_to + i * shift], vreg_to(i));
    }
};

template <cpu_isa_t isa>
void jit_uni_subkernel_int_t<isa>::process_linear(
        const Vmm &vr_to, const Vmm &vr_from) {
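    // linear: y = alpha * x + beta, computed in f32. vfmadd213ps evaluates
    // vr_to = vmm_alpha * vr_to + vmm_beta on the converted inputs.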
    uni_vcvtdq2ps(vr_to, vr_from);
    uni_vfmadd213ps(vr_to, vmm_alpha, vmm_beta);

    // Saturate to the dst data type range before converting from f32 to
    // s32: an out-of-range cvtps2dq would produce the integer indefinite
    // value instead of clamping.
    Vmm vmm_saturation_ubound = vmm_tmp;
    Reg64 reg_tmp = r10;
    uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
    init_saturate_f32(vmm_zero, vmm_saturation_ubound, reg_tmp, data_type::f32,
            data_type());
    saturate_f32(vr_to, vmm_zero, vmm_saturation_ubound, data_type());

    uni_vcvtps2dq(vr_to, vr_to);
}

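// ReLU: y = x > 0 ? x : alpha * x. Each ISA specialization converts to f32,
// computes alpha * x, then blends the original values back in where x > 0
// before converting back to s32.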
template <cpu_isa_t isa>
void jit_uni_subkernel_int_t<isa>::process_relu(
        const Vmm &vr_to, const Vmm &vr_from) {
    assert(!"unsupported isa");
}

template <>
void jit_uni_subkernel_int_t<sse41>::process_relu(
        const Vmm &vr_to, const Vmm &vr_from) {

    cvtdq2ps(vr_from, vr_from);
    movups(vr_to, vr_from);
    mulps(vr_to, vmm_alpha);

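    // SSE4.1 blendvps takes its selector implicitly from xmm0, hence the
    // mask must be built in Vmm(0).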
    Vmm mask = Vmm(0);
    movups(mask, vr_from);
    cmpps(mask, vmm_zero, _cmp_nle_us);
    blendvps(vr_to, vr_from);
    cvtps2dq(vr_to, vr_to);
}

template <>
void jit_uni_subkernel_int_t<avx2>::process_relu(
        const Vmm &vr_to, const Vmm &vr_from) {

    vcvtdq2ps(vr_from, vr_from);
    vmulps(vr_to, vr_from, vmm_alpha);
    vcmpgtps(vmm_mask, vr_from, vmm_zero);
    vblendvps(vr_to, vr_to, vr_from, vmm_mask);
    vcvtps2dq(vr_to, vr_to);
}

template <>
void jit_uni_subkernel_int_t<avx512_core>::process_relu(
        const Vmm &vr_to, const Vmm &vr_from) {

    vcvtdq2ps(vr_from, vr_from);
    vmulps(vr_to, vr_from, vmm_alpha);
    vcmpps(k_mask, vr_from, vmm_zero, _cmp_nle_us);
    vblendmps(vr_to | k_mask, vr_to, vr_from);
    vcvtps2dq(vr_to, vr_to);
}

template <cpu_isa_t isa>
void jit_uni_subkernel_int_t<isa>::store_8bit(const bool vectorize,
        const Address &mem_to, const Vmm &vr_to, bool is_signed) {
    assert(!"unsupported isa");
}

template <>
void jit_uni_subkernel_int_t<sse41>::store_8bit(const bool vectorize,
        const Address &mem_to, const Vmm &vr_to, bool is_signed) {
    if (vectorize) {
        // store full Vmm size
        // s32 -> s16
        packssdw(vr_to, vmm_zero);
        // s16 -> s8/u8
        if (is_signed)
            packsswb(vr_to, vmm_zero);
        else
            packuswb(vr_to, vmm_zero);

        movd(mem_to, Xmm(vr_to.getIdx()));
    } else {
        // store exactly one data item
        // s32 stored as s8/u8
        packssdw(vr_to, vmm_zero);
        if (is_signed)
            packsswb(vr_to, vmm_zero);
        else
            packuswb(vr_to, vmm_zero);
        movd(reg_int8.cvt32(), Xmm(vr_to.getIdx()));
        mov(mem_to, reg_int8.cvt8());
    }
}

template <>
void jit_uni_subkernel_int_t<avx2>::store_8bit(const bool vectorize,
        const Address &mem_to, const Vmm &vr_to, bool is_signed) {
    if (vectorize) {
        // store full Vmm size
        // s32 -> s16: vpackssdw packs within each 128-bit lane, yielding
        // {qw0, 0, qw1, 0}
        vpackssdw(vr_to, vr_to, vmm_zero);
        // permute to restore order: {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0}
        vpermq(Ymm(vr_to.getIdx()), Ymm(vr_to.getIdx()), 0x58);

        // s16 -> s8/u8: the 8 valid s16 in the low lane become the 8 bytes
        // of the low qword
        if (is_signed)
            vpacksswb(vr_to, vr_to, vmm_zero);
        else
            vpackuswb(vr_to, vr_to, vmm_zero);
        uni_vmovq(mem_to, Xmm(vr_to.getIdx()));
    } else {
        // store exactly one data item
        // s32 stored as s8/u8
        vpackssdw(vr_to, vr_to, vmm_zero);
        if (is_signed)
            vpacksswb(vr_to, vr_to, vmm_zero);
        else
            vpackuswb(vr_to, vr_to, vmm_zero);
        vmovd(reg_int8.cvt32(), Xmm(vr_to.getIdx()));
        mov(mem_to, reg_int8.cvt8());
    }
}

template <>
void jit_uni_subkernel_int_t<avx512_core>::store_8bit(const bool vectorize,
        const Address &mem_to, const Vmm &vr_to, bool is_signed) {
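    // AVX-512 provides direct saturating down-conversion stores: vpmovsdb
    // (signed saturation, s32 -> s8) and vpmovusdb (unsigned saturation,
    // s32 -> u8). The scalar tail masks the store down to one byte with
    // k_mask_int8.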
    if (vectorize) {
        // store full Vmm size
        if (is_signed)
            vpmovsdb(mem_to, vr_to);
        else
            vpmovusdb(mem_to, vr_to);
    } else {
        // store exactly one data item
        // s32 stored as s8/u8
        if (is_signed)
            vpmovsdb(mem_to, vr_to | k_mask_int8);
        else
            vpmovusdb(mem_to, vr_to | k_mask_int8);
    }
}

} /* namespace */

template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_eltwise_int_fwd_t<isa, d_type>::pd_t::init(engine_t *engine) {
    bool ok = is_fwd() && mayiuse(isa)
            && utils::everyone_is(
                    d_type, src_md()->data_type, dst_md()->data_type)
            // only relu and linear so far
            && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu,
                    alg_kind::eltwise_linear)
            && !has_zero_dim_memory()
            && memory_desc_wrapper(src_md()).is_dense(true)
            && attr()->has_default_values() && set_default_formats_common()
            && memory_desc_wrapper(src_md()) == memory_desc_wrapper(dst_md());

    return ok ? status::success : status::unimplemented;
}

template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_int_fwd_t<isa, d_type>::jit_uni_eltwise_int_fwd_t(
        const pd_t *apd)
    : primitive_t(apd) {}

template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_eltwise_int_fwd_t<isa, d_type>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(kernel_, new jit_uni_subkernel_int_t<isa>(pd())));
    return kernel_->create_kernel();
}

template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_int_fwd_t<isa, d_type>::~jit_uni_eltwise_int_fwd_t() {
    delete kernel_;
}

template <cpu_isa_t isa, impl::data_type_t d_type>
status_t jit_uni_eltwise_int_fwd_t<isa, d_type>::execute_forward(
        const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());

    const size_t nelems = src_d.nelems(true);

    src += src_d.offset0();
    dst += src_d.offset0();

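    // Split the work into 64-byte cache-line-sized chunks and let
    // balance211 distribute them evenly across threads, presumably so that
    // no two threads write within the same dst cache line (avoids false
    // sharing).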
    const int cache_line = 64 / src_d.data_type_size();
    parallel(0, [&](const int ithr, const int nthr) {
        size_t start {0}, end {0};

        balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
        start = nstl::min(nelems, start * cache_line);
        end = nstl::min(nelems, end * cache_line);

        auto arg = jit_args_int8_t();
        arg.from = (const void *)&src[start];
        arg.for_comparison = (const void *)&src[start];
        arg.to = (const void *)&dst[start];
        arg.work_amount = end - start;
        if (arg.work_amount) (*kernel_)(&arg);
    });
    return status::success;
}

using namespace data_type;

template struct jit_uni_eltwise_int_fwd_t<sse41, s32>;
template struct jit_uni_eltwise_int_fwd_t<avx2, s32>;
template struct jit_uni_eltwise_int_fwd_t<avx512_core, s32>;

template struct jit_uni_eltwise_int_fwd_t<sse41, s8>;
template struct jit_uni_eltwise_int_fwd_t<avx2, s8>;
template struct jit_uni_eltwise_int_fwd_t<avx512_core, s8>;

template struct jit_uni_eltwise_int_fwd_t<sse41, u8>;
template struct jit_uni_eltwise_int_fwd_t<avx2, u8>;
template struct jit_uni_eltwise_int_fwd_t<avx512_core, u8>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl