1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/nstl.hpp" |
20 | #include "common/utils.hpp" |
21 | |
22 | #include "cpu/x64/jit_generator.hpp" |
23 | |
24 | #include "cpu/x64/jit_uni_eltwise_int.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | namespace x64 { |
30 | |
31 | using namespace Xbyak; |
32 | |
// Argument bundle passed from the host to the generated JIT kernel.
struct jit_args_int8_t {
    const void *from; // source buffer (read position)
    // NOTE(review): generate() only loads `from`, `to` and `work_amount`;
    // `for_comparison` appears kept for ABI layout compatibility - confirm.
    const void *for_comparison;
    const void *to; // destination buffer (write position)
    size_t work_amount; // number of elements to process
};
39 | |
// Common base for the integer eltwise JIT kernels: keeps the primitive
// descriptor around and exposes the accessors the generated code needs.
struct jit_uni_eltwise_int_kernel : public jit_generator {
    jit_uni_eltwise_int_kernel(const eltwise_pd_t *pd, const char *name)
        : jit_generator(name), pd_(pd) {}

    // Invoke the generated code with the given argument bundle.
    void operator()(jit_args_int8_t *p) { jit_generator::operator()(p); }

protected:
    // Source data type; src and dst types match (enforced in pd_t::init).
    data_type_t data_type() const { return pd_->src_md()->data_type; }
    // Size in bytes of one element of data_type().
    int dtype_size() const { return types::data_type_size(data_type()); }

    const eltwise_desc_t &desc() const { return *pd_->desc(); }

private:
    const eltwise_pd_t *pd_; // borrowed; owned by the primitive descriptor
};
55 | |
56 | /* jit kernels */ |
57 | namespace { |
58 | using namespace Xbyak; |
59 | |
// JIT kernel computing eltwise relu/linear over a contiguous run of
// s32/s8/u8 elements. The generated code executes a vectorized main loop
// (one full Vmm of elements per iteration) followed by a scalar tail loop.
template <cpu_isa_t isa>
struct jit_uni_subkernel_int_t : public jit_uni_eltwise_int_kernel {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_subkernel_int)

    jit_uni_subkernel_int_t(const eltwise_pd_t *pd)
        : jit_uni_eltwise_int_kernel(pd, jit_name()) {
        using namespace data_type;

        // Relu and linear for int types: s32, s8, u8; Only forward direction
        assert(utils::one_of(desc().alg_kind, alg_kind::eltwise_relu,
                alg_kind::eltwise_linear));
        assert(utils::one_of(data_type(), s32, s8, u8));
        assert(utils::one_of(isa, sse41, avx2, avx512_core));
    }

    void generate() override {
        Reg64 param = abi_param1;

        const size_t vlen = cpu_isa_traits<isa>::vlen;
        // Every source type is widened to s32 lanes for the computation, so
        // the simd width is counted in 4-byte elements regardless of type.
        const size_t simd_w = vlen / sizeof(float);
        // Per-pass constants: pass 0 is the vectorized main loop, pass 1 the
        // scalar tail (one element at a time).
        const size_t loop_dec[] = {simd_w, 1}; // elements consumed per iter
        const size_t uf[] = {1, 1}; // unroll factor per pass
        const size_t shift[] = {dtype_size() * simd_w, (size_t)dtype_size()};
        const bool loop_vectorize[] = {true, false};

        preamble();

// Fetch the kernel arguments (jit_args_int8_t) pointed to by abi_param1.
#define GET_OFF(field) offsetof(jit_args_int8_t, field)
        mov(reg_from, ptr[param + GET_OFF(from)]);
        mov(reg_to, ptr[param + GET_OFF(to)]);
        mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
#undef GET_OFF

        // Broadcast alpha and beta (f32 bit patterns) to all vector lanes.
        mov(imm_addr64, float2int(desc().alpha));
        uni_vmovq(xmm_alpha, imm_addr64);
        uni_vbroadcastss(vmm_alpha, xmm_alpha);

        mov(imm_addr64, float2int(desc().beta));
        uni_vmovq(xmm_beta, imm_addr64);
        uni_vbroadcastss(vmm_beta, xmm_beta);

        uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
        xor_(reg_int8, reg_int8);
        if (isa == avx512_core) {
            // Single-lane opmask used to store exactly one byte in the tail.
            mov(reg_int8.cvt8(), 0x01);
            kmovw(k_mask_int8, reg_int8.cvt32());
        }

        Label loop_label[3];

        // Two chained loops: [0] vectorized main loop, [1] scalar tail;
        // loop_label[2] is the common exit.
        for (int id = 0; id < 2; id++) {
            L(loop_label[id]);
            // Fall through to the next loop once fewer than
            // uf[id] * loop_dec[id] elements remain.
            cmp(reg_work_amount, uf[id] * loop_dec[id] - 1);
            jle(loop_label[id + 1], T_NEAR);

            compute_step(
                    loop_vectorize[id], uf[id], shift[id], desc().alg_kind);

            add(reg_from, uf[id] * shift[id]);
            add(reg_to, uf[id] * shift[id]);

            sub(reg_work_amount, uf[id] * loop_dec[id]);
            jmp(loop_label[id]);
        }

        L(loop_label[2]);
        postamble();
    }

private:
    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    using opmask_t = const Xbyak::Opmask;

    Reg64 reg_from = rax; // current read pointer
    Reg64 reg_to = r8; // current write pointer
    Reg64 reg_work_amount = rsi; // elements left to process
    Reg64 imm_addr64 = rbx; // scratch for immediate materialization
    Reg64 reg_int8 = r9; // scratch for single-byte loads/stores

    Xmm xmm_alpha = Xmm(13);
    Xmm xmm_beta = Xmm(14);

    // Higher register indices are only encodable with AVX-512.
    Vmm vmm_tmp = Vmm(isa == avx512_core ? 26 : 11);
    Vmm vmm_alpha = Vmm(isa == avx512_core ? 27 : 13);
    Vmm vmm_beta = Vmm(isa == avx512_core ? 28 : 14);
    Vmm vmm_zero = Vmm(isa == avx512_core ? 29 : 15);
    Vmm vmm_mask = Vmm(isa == avx512_core ? 30 : 12);

    opmask_t k_mask = k1;
    opmask_t k_mask_int8 = k2; // Mask for store 1 byte in case of AVX512

    bool is32bit() const { return data_type() == data_type::s32; }

    // Load 32bit data type (s32)
    void load_32bit(
            const bool vectorize, const Vmm &vr_from, const Address &mem_from) {

        if (vectorize) {
            // load full Vmm size
            uni_vmovups(vr_from, mem_from);
        } else {
            // load exactly one data item
            movss(Xmm(vr_from.getIdx()), mem_from);
        }
    }

    // Load 8bit data type (u8/s8), widening each byte to an s32 lane.
    void load_8bit(const bool vectorize, const Vmm &vr_from,
            const Address &mem_from, bool is_signed) {

        // data type u8/s8 load as s32
        if (vectorize) {
            // load full Vmm size
            if (is_signed)
                uni_vpmovsxbd(vr_from, mem_from);
            else
                uni_vpmovzxbd(vr_from, mem_from);
        } else {
            // load exactly one data item; sign/zero-extend through a GPR
            mov(reg_int8.cvt8(), mem_from);
            if (is_signed)
                movsx(reg_int8.cvt32(), reg_int8.cvt8());
            else
                movzx(reg_int8.cvt32(), reg_int8.cvt8());
            uni_vmovq(Xmm(vr_from.getIdx()), reg_int8);
        }
    }

    // Load vregs with data from mem
    void load(
            const bool vectorize, const Vmm &vr_from, const Address &mem_from) {

        // Branching on data size
        if (is32bit())
            load_32bit(vectorize, vr_from, mem_from);
        else
            load_8bit(
                    vectorize, vr_from, mem_from, data_type() == data_type::s8);
    }

    // Processing: per-algorithm compute on s32 lanes (defined per ISA below).
    void process_linear(const Vmm &vr_to, const Vmm &vr_from);
    void process_relu(const Vmm &vr_to, const Vmm &vr_from);

    // Store s32 for any isa
    void store_32bit(
            const bool vectorize, const Address &mem_to, const Vmm &vr_to) {
        if (vectorize) {
            // store full Vmm size
            uni_vmovups(mem_to, vr_to);
        } else {
            // store exactly one data item
            movss(mem_to, Xmm(vr_to.getIdx()));
        }
    }

    // Store 8 bit int - isa-dependent
    void store_8bit(const bool vectorize, const Address &mem_to,
            const Vmm &vr_to, bool is_signed);

    // Store results from vregs to mem
    void store(const bool vectorize, const Address &mem_to, const Vmm &vr_to) {
        // Branching on data size
        if (is32bit())
            store_32bit(vectorize, mem_to, vr_to);
        else
            store_8bit(vectorize, mem_to, vr_to, data_type() == data_type::s8);
    }

    // One loop iteration: load uf vectors, apply the eltwise op, store back.
    void compute_step(bool vectorize, const size_t uf, const size_t shift,
            const alg_kind_t alg) {

        // Input vectors occupy Vmm(1..uf); outputs follow right after.
        auto vreg_from = [&](const size_t i) -> Vmm { return Vmm(i + 1); };
        auto vreg_to = [&](const size_t i) -> Vmm { return Vmm(uf + i + 1); };

        // 1. Load (vregs <- mem)
        for (size_t i = 0; i < uf; i++)
            load(vectorize, vreg_from(i), ptr[reg_from + i * shift]);

        // 2. Process (vregs <- vregs)
        switch (alg) {
            case alg_kind::eltwise_linear:
                for (size_t i = 0; i < uf; i++)
                    process_linear(vreg_to(i), vreg_from(i));
                break;
            case alg_kind::eltwise_relu:
                for (size_t i = 0; i < uf; i++)
                    process_relu(vreg_to(i), vreg_from(i));
                break;
            default: assert(!"unsupported alg" );
        }

        // 3. Store (mem <- vregs)
        for (size_t i = 0; i < uf; i++)
            store(vectorize, ptr[reg_to + i * shift], vreg_to(i));
    }
};
257 | |
// linear: dst = saturate(alpha * src + beta), computed in f32 lanes.
template <cpu_isa_t isa>
void jit_uni_subkernel_int_t<isa>::process_linear(
        const Vmm &vr_to, const Vmm &vr_from) {
    // s32 -> f32
    uni_vcvtdq2ps(vr_to, vr_from);
    // vr_to = vr_to * alpha + beta
    uni_vfmadd213ps(vr_to, vmm_alpha, vmm_beta);

    // Saturate before converting from f32 to s32
    Vmm vmm_saturation_ubound = vmm_tmp;
    Reg64 reg_tmp = r10;
    // vmm_zero doubles as the saturation lower bound here; it is re-zeroed
    // right before init_saturate_f32 fills it for data_type().
    uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
    init_saturate_f32(vmm_zero, vmm_saturation_ubound, reg_tmp, data_type::f32,
            data_type());
    saturate_f32(vr_to, vmm_zero, vmm_saturation_ubound, data_type());

    // f32 -> s32
    uni_vcvtps2dq(vr_to, vr_to);
}
274 | |
// Generic fallback: relu is only provided by the per-ISA specializations
// below, so instantiating this template for any other ISA is a bug.
template <cpu_isa_t isa>
void jit_uni_subkernel_int_t<isa>::process_relu(
        const Vmm &vr_to, const Vmm &vr_from) {
    assert(!"unsupported isa" );
}
280 | |
// relu (sse41): dst = src > 0 ? src : alpha * src, computed in f32 lanes.
template <>
void jit_uni_subkernel_int_t<sse41>::process_relu(
        const Vmm &vr_to, const Vmm &vr_from) {

    // s32 -> f32, in place (vr_from is clobbered)
    cvtdq2ps(vr_from, vr_from);
    movups(vr_to, vr_from);
    mulps(vr_to, vmm_alpha); // negative-side result: alpha * src

    // blendvps implicitly uses xmm0 as its mask, hence Vmm(0) here.
    Vmm mask = Vmm(0);
    movups(mask, vr_from);
    cmpps(mask, vmm_zero, _cmp_nle_us); // mask = (src > 0)
    blendvps(vr_to, vr_from); // keep src where mask is set
    cvtps2dq(vr_to, vr_to); // f32 -> s32
}
295 | |
// relu (avx2): dst = src > 0 ? src : alpha * src, computed in f32 lanes.
template <>
void jit_uni_subkernel_int_t<avx2>::process_relu(
        const Vmm &vr_to, const Vmm &vr_from) {

    // s32 -> f32, in place (vr_from is clobbered)
    vcvtdq2ps(vr_from, vr_from);
    vmulps(vr_to, vr_from, vmm_alpha); // negative-side result: alpha * src
    vcmpgtps(vmm_mask, vr_from, vmm_zero); // mask = (src > 0)
    vblendvps(vr_to, vr_to, vr_from, vmm_mask); // keep src where positive
    vcvtps2dq(vr_to, vr_to); // f32 -> s32
}
306 | |
// relu (avx512_core): dst = src > 0 ? src : alpha * src, using an opmask
// register for the blend instead of a vector mask.
template <>
void jit_uni_subkernel_int_t<avx512_core>::process_relu(
        const Vmm &vr_to, const Vmm &vr_from) {

    // s32 -> f32, in place (vr_from is clobbered)
    vcvtdq2ps(vr_from, vr_from);
    vmulps(vr_to, vr_from, vmm_alpha); // negative-side result: alpha * src
    vcmpps(k_mask, vr_from, vmm_zero, _cmp_nle_us); // k_mask = (src > 0)
    vblendmps(vr_to | k_mask, vr_to, vr_from); // keep src where positive
    vcvtps2dq(vr_to, vr_to); // f32 -> s32
}
317 | |
// Generic fallback: 8-bit stores are only provided by the per-ISA
// specializations below, so instantiating this template is a bug.
template <cpu_isa_t isa>
void jit_uni_subkernel_int_t<isa>::store_8bit(const bool vectorize,
        const Address &mem_to, const Vmm &vr_to, bool is_signed) {
    assert(!"unsupported isa" );
}
323 | |
// Store s32 lanes as s8/u8 (sse41): narrow via saturating packs, then write
// the low 4 bytes (vector path) or 1 byte (scalar tail).
template <>
void jit_uni_subkernel_int_t<sse41>::store_8bit(const bool vectorize,
        const Address &mem_to, const Vmm &vr_to, bool is_signed) {
    if (vectorize) {
        // store full Vmm size
        // s32 -> s16
        packssdw(vr_to, vmm_zero);
        // s16 -> s8/u8
        if (is_signed)
            packsswb(vr_to, vmm_zero);
        else
            packuswb(vr_to, vmm_zero);

        // 4 packed bytes live in the low dword
        movd(mem_to, Xmm(vr_to.getIdx()));
    } else {
        // store exactly one data item
        // s32 save as s8/u8
        packssdw(vr_to, vmm_zero);
        if (is_signed)
            packsswb(vr_to, vmm_zero);
        else
            packuswb(vr_to, vmm_zero);
        // move packed result to a GPR and write only the lowest byte
        movd(reg_int8.cvt32(), Xmm(vr_to.getIdx()));
        mov(mem_to, reg_int8.cvt8());
    }
}
350 | |
// Store s32 lanes as s8/u8 (avx2): saturating packs operate per 128-bit
// lane, so a cross-lane permute is needed to restore element order.
template <>
void jit_uni_subkernel_int_t<avx2>::store_8bit(const bool vectorize,
        const Address &mem_to, const Vmm &vr_to, bool is_signed) {
    if (vectorize) {
        // store full Vmm size
        // s32 -> s16 = {qw0, 0, qw1, 0}
        vpackssdw(vr_to, vr_to, vmm_zero);
        // permute to restore order{qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0}
        vpermq(Ymm(vr_to.getIdx()), Ymm(vr_to.getIdx()), 0x58);

        // s16 -> s8/u8 : {16 x s16}{16 x 0} -> {32 x s8/u8}
        if (is_signed)
            vpacksswb(vr_to, vr_to, vmm_zero);
        else
            vpackuswb(vr_to, vr_to, vmm_zero);
        // 8 packed bytes live in the low qword
        uni_vmovq(mem_to, Xmm(vr_to.getIdx()));
    } else {
        // store exactly one data item
        // s32 save as s8/u8
        vpackssdw(vr_to, vr_to, vmm_zero);
        if (is_signed)
            vpacksswb(vr_to, vr_to, vmm_zero);
        else
            vpackuswb(vr_to, vr_to, vmm_zero);
        // move packed result to a GPR and write only the lowest byte
        vmovd(reg_int8.cvt32(), Xmm(vr_to.getIdx()));
        mov(mem_to, reg_int8.cvt8());
    }
}
379 | |
// Store s32 lanes as s8/u8 (avx512_core): use the dedicated saturating
// down-convert instructions; the scalar tail masks the store to one byte.
template <>
void jit_uni_subkernel_int_t<avx512_core>::store_8bit(const bool vectorize,
        const Address &mem_to, const Vmm &vr_to, bool is_signed) {
    if (vectorize) {
        // store full Vmm size
        if (is_signed)
            vpmovsdb(mem_to, vr_to);
        else
            vpmovusdb(mem_to, vr_to);
    } else {
        // store exactly one data item
        // s32 save as s8/u8; k_mask_int8 limits the store to a single byte
        if (is_signed)
            vpmovsdb(mem_to, vr_to | k_mask_int8);
        else
            vpmovusdb(mem_to, vr_to | k_mask_int8);
    }
}
398 | |
399 | } /* namespace */ |
400 | |
401 | template <cpu_isa_t isa, data_type_t d_type> |
402 | status_t jit_uni_eltwise_int_fwd_t<isa, d_type>::pd_t::init(engine_t *engine) { |
403 | bool ok = is_fwd() && mayiuse(isa) |
404 | && utils::everyone_is( |
405 | d_type, src_md()->data_type, dst_md()->data_type) |
406 | // only relu and linear so far |
407 | && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu, |
408 | alg_kind::eltwise_linear) |
409 | && !has_zero_dim_memory() |
410 | && memory_desc_wrapper(src_md()).is_dense(true) |
411 | && attr()->has_default_values() && set_default_formats_common() |
412 | && memory_desc_wrapper(src_md()) == memory_desc_wrapper(dst_md()); |
413 | |
414 | return ok ? status::success : status::unimplemented; |
415 | } |
416 | |
// Primitive constructor; kernel creation is deferred to init().
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_int_fwd_t<isa, d_type>::jit_uni_eltwise_int_fwd_t(
        const pd_t *apd)
    : primitive_t(apd) {}
421 | |
// Allocate and JIT-compile the kernel. kernel_ is freed in the destructor.
template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_eltwise_int_fwd_t<isa, d_type>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(kernel_, new jit_uni_subkernel_int_t<isa>(pd())));
    return kernel_->create_kernel();
}
427 | |
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_int_fwd_t<isa, d_type>::~jit_uni_eltwise_int_fwd_t() {
    // kernel_ is a raw owning pointer (declared in the header); release it.
    delete kernel_;
}
432 | |
// Run the eltwise op over the whole tensor, splitting the flat element
// range across threads in cache-line-sized chunks.
template <cpu_isa_t isa, impl::data_type_t d_type>
status_t jit_uni_eltwise_int_fwd_t<isa, d_type>::execute_forward(
        const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());

    // Padded elements are processed too (nelems(true)); src and dst share a
    // dense identical layout (checked in pd_t::init), so a flat walk is valid.
    const size_t nelems = src_d.nelems(true);

    src += src_d.offset0();
    dst += src_d.offset0();

    // Elements per 64-byte cache line; chunking on this granularity keeps
    // adjacent threads writing disjoint cache lines.
    const int cache_line = 64 / src_d.data_type_size();
    parallel(0, [&](const int ithr, const int nthr) {
        size_t start {0}, end {0};

        // Balance whole cache-line chunks, then clamp back to elements.
        balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
        start = nstl::min(nelems, start * cache_line);
        end = nstl::min(nelems, end * cache_line);

        auto arg = jit_args_int8_t();
        arg.from = (const void *)&src[start];
        arg.for_comparison = (const void *)&src[start];
        arg.to = (const void *)&dst[start];
        arg.work_amount = end - start;
        if (arg.work_amount) (*kernel_)(&arg);
    });
    return status::success;
}
463 | |
using namespace data_type;

// Explicit instantiations for every supported (isa, data type) pair.
template struct jit_uni_eltwise_int_fwd_t<sse41, s32>;
template struct jit_uni_eltwise_int_fwd_t<avx2, s32>;
template struct jit_uni_eltwise_int_fwd_t<avx512_core, s32>;

template struct jit_uni_eltwise_int_fwd_t<sse41, s8>;
template struct jit_uni_eltwise_int_fwd_t<avx2, s8>;
template struct jit_uni_eltwise_int_fwd_t<avx512_core, s8>;

template struct jit_uni_eltwise_int_fwd_t<sse41, u8>;
template struct jit_uni_eltwise_int_fwd_t<avx2, u8>;
template struct jit_uni_eltwise_int_fwd_t<avx512_core, u8>;
477 | |
478 | } // namespace x64 |
479 | } // namespace cpu |
480 | } // namespace impl |
481 | } // namespace dnnl |
482 | |