/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_GENERATOR_HPP
#define CPU_X64_JIT_GENERATOR_HPP

#include <limits.h>
#include <vector>

#include "common/bit_cast.hpp"
#include "common/compiler_workarounds.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"

#include "cpu/jit_utils/jit_utils.hpp"

#if defined(_WIN32) && !defined(__GNUC__)
#define STRUCT_ALIGN(al, ...) __declspec(align(al)) __VA_ARGS__
#else
#define STRUCT_ALIGN(al, ...) __VA_ARGS__ __attribute__((__aligned__(al)))
#endif

#if defined(_WIN32)
#define OFFSET_SHADOWSPACE 0x28
#endif

#if GCC_WA_NO_TREE_DOMINATOR_OPTS
#define ATTRIBUTE_OPTIMIZE __attribute__((optimize("no-tree-dominator-opts")))
#else
#define ATTRIBUTE_OPTIMIZE
#endif

#define DECLARE_CPU_JIT_AUX_FUNCTIONS(gen_name) \
    const char *name() const override { return STRINGIFY(gen_name); } \
    const char *source_file() const override { return __FILE__; } \
    static const char *jit_name() { \
        static constexpr char ret[] = "/oneDNN:" STRINGIFY(gen_name); \
        return ret; \
    }
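
// Usage sketch (illustrative; the kernel class name below is hypothetical):
// a generator pulls in these naming hooks by invoking the macro with its own
// class name, e.g.
//
//     struct jit_uni_example_kernel_t : public jit_generator {
//         DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_example_kernel_t)
//         ...
//     };
//
// name(), source_file() and jit_name() then identify the kernel in profiling
// and debugging output.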

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

// TODO: move this to jit_generator class?
namespace {

typedef enum {
    MAX_CODE_SIZE = 256 * 1024,
} max_code_size_t;

// TODO: move this somewhere else? Although this is only used by jit kernels
// (Roma)
static inline int float2int(float x) {
    return utils::bit_cast<int>(x);
}
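
// For example, float2int(1.0f) returns 0x3f800000: the IEEE-754 bit pattern
// of the float, unchanged, which lets kernels materialize float constants
// through integer moves (see init_vmm() below).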

static inline void tc_configure_tile(
        palette_config_t *tc, int t, int rows, int cols) {
    const bool rows_ok = (size_t)t < sizeof(tc->rows) / sizeof(tc->rows[0]);
    const bool cols_ok = (size_t)t < sizeof(tc->cols) / sizeof(tc->cols[0]);
    if (rows_ok && cols_ok) {
        tc->rows[t] = rows;
        tc->cols[t] = cols;
    } else {
        assert(!"out of range");
    }
}
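
// Usage sketch (illustrative; field names of palette_config_t such as
// palette_id are assumptions based on the ldtilecfg layout): a kernel fills
// the palette once and then loads it, roughly as
//
//     palette_config_t tc = {};
//     tc.palette_id = 1;
//     tc_configure_tile(&tc, /*t=*/0, /*rows=*/16, /*cols=*/64);
//     // ... configure the remaining tiles, then emit ldtilecfg from &tc.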

// TODO: A GPR class that hides ABI details from the JIT kernels and allows
// numbering registers from 0 to 14 (x86_64) / 6 (x32) (gpr0, gpr1, ...) and
// stack register (sr).
//
// This will allow using syntax like this:
//
// param = gpr0;
// reg_input = gpr0;
// reg_output = gpr1;
// ...
//
// #ifndef XBYAK64
// mov(param, ptr[sr])
// #endif
//
// (Roma)

} // namespace

#ifdef XBYAK64
constexpr Xbyak::Operand::Code abi_save_gpr_regs[] = {
        Xbyak::Operand::RBP,
        Xbyak::Operand::RBX,
        Xbyak::Operand::R12,
        Xbyak::Operand::R13,
        Xbyak::Operand::R14,
        Xbyak::Operand::R15,
#ifdef _WIN32
        Xbyak::Operand::RDI,
        Xbyak::Operand::RSI,
#endif
};

constexpr Xbyak::Operand::Code abi_param_regs[] = {
#ifdef _WIN32
        Xbyak::Operand::RCX, Xbyak::Operand::RDX, Xbyak::Operand::R8,
        Xbyak::Operand::R9
#else
        Xbyak::Operand::RDI, Xbyak::Operand::RSI, Xbyak::Operand::RDX,
        Xbyak::Operand::RCX, Xbyak::Operand::R8, Xbyak::Operand::R9
#endif
};

constexpr Xbyak::Operand::Code abi_not_param_reg =
#ifdef _WIN32
        Xbyak::Operand::RDI;
#else
        Xbyak::Operand::RCX;
#endif

#define abi_param1 Xbyak::Reg64(abi_param_regs[0])
#define abi_param2 Xbyak::Reg64(abi_param_regs[1])
#define abi_param3 Xbyak::Reg64(abi_param_regs[2])
#define abi_param4 Xbyak::Reg64(abi_param_regs[3])
#define abi_param5 Xbyak::Reg64(abi_param_regs[4])
#define abi_param6 Xbyak::Reg64(abi_param_regs[5])
#define abi_not_param1 Xbyak::Reg64(abi_not_param_reg)

#endif
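
// With these definitions a kernel body is ABI-portable; for example
//
//     mov(reg_ptr, abi_param1); // reg_ptr: a kernel-chosen Reg64 (example)
//
// reads the first argument from RDI under System V and from RCX on Windows.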

class jit_generator : public Xbyak::MmapAllocator,
                      public Xbyak::CodeGenerator,
                      public c_compatible {
public:
    using c_compatible::operator new;
    using c_compatible::operator new[];
    using c_compatible::operator delete;
    using c_compatible::operator delete[];

private:
    const size_t xmm_len = 16;
#ifdef _WIN32
    const size_t xmm_to_preserve_start = 6;
    const size_t xmm_to_preserve = 10;
#else
    const size_t xmm_to_preserve_start = 0;
    const size_t xmm_to_preserve = 0;
#endif

    const size_t num_abi_save_gpr_regs
            = sizeof(abi_save_gpr_regs) / sizeof(abi_save_gpr_regs[0]);

    const size_t size_of_abi_save_regs
            = num_abi_save_gpr_regs * rax.getBit() / 8
            + xmm_to_preserve * xmm_len;

public:
    enum {
        _cmp_eq_oq = 0u,
        _cmp_lt_os = 1u,
        _cmp_le_os = 2u,
        _cmp_neq_uq = 4u,
        _cmp_nlt_us = 5u,
        _cmp_nle_us = 6u,

        _op_floor = 1u,
        _op_mxcsr = 4u,
    };

    Xbyak::Reg64 param1 = abi_param1;
    const int EVEX_max_8b_offt = 0x200;
    const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;

    inline size_t get_size_of_abi_save_regs() { return size_of_abi_save_regs; }

    void preamble() {
        if (xmm_to_preserve) {
            sub(rsp, xmm_to_preserve * xmm_len);
            for (size_t i = 0; i < xmm_to_preserve; ++i)
                uni_vmovdqu(ptr[rsp + i * xmm_len],
                        Xbyak::Xmm(xmm_to_preserve_start + i));
        }
        for (size_t i = 0; i < num_abi_save_gpr_regs; ++i) {
            push(Xbyak::Reg64(abi_save_gpr_regs[i]));
            // Stack magic: save rsp into rbp state to be able to unwind stack.
            if (i == 0) mov(rbp, rsp);
        }
#ifndef DNNL_ENABLE_MEM_DEBUG
        // do not use RBP in mem debug mode to enable backtracing from jit code
        if (is_valid_isa(avx512_core)) {
            mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
        }
#endif

#ifdef DNNL_ENABLE_MEM_DEBUG
        // This section poisons vector registers with NaNs to catch situations
        // when leftover trash in the register is used in a valid instruction
        // without handling trash first.
        {
            // Preserve GPR since GeMM code relies on this one.
            push(abi_not_param1);
            if (is_valid_isa(avx512_core)) {
                for (int i = 0; i < 32; i++) {
                    init_vmm(Xbyak::Zmm(i), abi_not_param1, NAN);
                }
            } else if (is_valid_isa(avx)) {
                for (int i = 0; i < 16; i++) {
                    init_vmm(Xbyak::Ymm(i), abi_not_param1, NAN);
                }
            } else {
                for (int i = 0; i < 16; i++) {
                    init_vmm(Xbyak::Xmm(i), abi_not_param1, NAN);
                }
            }
            pop(abi_not_param1);
        }
#endif
    }
    // This function returns the address on the stack of the first argument
    // that is not passed by register.
    // By default it assumes it is called after the prologue.
    // Note that we cannot use RBP inside, as we override it in preamble()
    // for address computation in EVEX instructions.
    inline const Xbyak::RegExp get_stack_params_address(
            bool after_prolog = true) {
        int saved_regs_size = after_prolog ? get_size_of_abi_save_regs() : 0;
#ifdef _WIN32
        // Using stack layout described in MS ABI
        // (https://docs.microsoft.com/en-us/cpp/build/stack-usage?view=vs-2019)
        // here, the return address and the first 4 parameters are allocated
        // on the stack
        int first_params_and_return_addr_size = 40;
#else
        // In System V ABI, only the return address is stacked
        // before the arguments
        int first_params_and_return_addr_size = 8;
#endif
        return rsp + saved_regs_size + first_params_and_return_addr_size;
    }
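
    // Usage sketch (illustrative; reg_arg7/reg_arg8 are hypothetical kernel
    // registers): under System V, integer arguments 7 and up live on the
    // stack, so a kernel can fetch them as
    //
    //     mov(reg_arg7, ptr[get_stack_params_address()]);
    //     mov(reg_arg8, ptr[get_stack_params_address() + 8]);
    //
    // On Windows the same expressions return the 5th and 6th arguments,
    // since only four are passed in registers there.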

    void uni_vzeroupper() {
        if (mayiuse(avx)) vzeroupper();
    }

    void postamble() {
        for (size_t i = 0; i < num_abi_save_gpr_regs; ++i)
            pop(Xbyak::Reg64(abi_save_gpr_regs[num_abi_save_gpr_regs - 1 - i]));
        if (xmm_to_preserve) {
            for (size_t i = 0; i < xmm_to_preserve; ++i)
                uni_vmovdqu(Xbyak::Xmm(xmm_to_preserve_start + i),
                        ptr[rsp + i * xmm_len]);
            add(rsp, xmm_to_preserve * xmm_len);
        }
        uni_vzeroupper();
        ret();
    }
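
    // A typical generate() body is bracketed by the two helpers above
    // (sketch; everything between the calls is kernel-specific):
    //
    //     void generate() override {
    //         preamble(); // saves ABI regs; also sets up rbp (see above)
    //         // ... kernel body using abi_param1, uni_* helpers, ...
    //         postamble(); // restores regs, vzeroupper if needed, ret
    //     }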

    template <typename T>
    Xbyak::Address EVEX_compress_addr(
            Xbyak::Reg64 base, T raw_offt, bool bcast = false) {
        using Xbyak::Address;
        using Xbyak::Reg64;
        using Xbyak::RegExp;
        using Xbyak::Zmm;

        assert(raw_offt <= INT_MAX);
        auto offt = static_cast<int>(raw_offt);

        int scale = 0;

#ifndef DNNL_ENABLE_MEM_DEBUG
        // do not use RBP in mem debug mode to enable backtracing from jit code
        if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
            offt = offt - 2 * EVEX_max_8b_offt;
            scale = 1;
        } else if (3 * EVEX_max_8b_offt <= offt
                && offt < 5 * EVEX_max_8b_offt) {
            offt = offt - 4 * EVEX_max_8b_offt;
            scale = 2;
        }
#endif

        auto re = RegExp() + base + offt;
        if (scale) re = re + reg_EVEX_max_8b_offt * scale;

        if (bcast)
            return zword_b[re];
        else
            return zword[re];
    }
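
    // Sketch of the intent: with reg_EVEX_max_8b_offt preloaded to
    // 2 * EVEX_max_8b_offt in preamble(), offsets in [0x200, 0xa00) are
    // rebased around that register so they fit the compact disp8*N EVEX
    // encoding. For instance (reg_src is a hypothetical base register):
    //
    //     vmovups(zmm0, EVEX_compress_addr(reg_src, 0x240));
    //     // emits zword[reg_src + rbp * 1 - 0x1c0], a 1-byte displacement,
    //     // instead of the 4-byte displacement form zword[reg_src + 0x240].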

    Xbyak::Address make_safe_addr(const Xbyak::Reg64 &reg_out, size_t offt,
            const Xbyak::Reg64 &tmp_reg, bool bcast = false) {
        if (offt > INT_MAX) {
            mov(tmp_reg, offt);
            return bcast ? ptr_b[reg_out + tmp_reg] : ptr[reg_out + tmp_reg];
        } else {
            return bcast ? ptr_b[reg_out + offt] : ptr[reg_out + offt];
        }
    }

    Xbyak::Address EVEX_compress_addr_safe(const Xbyak::Reg64 &base,
            size_t raw_offt, const Xbyak::Reg64 &reg_offt, bool bcast = false) {
        if (raw_offt > INT_MAX) {
            return make_safe_addr(base, raw_offt, reg_offt, bcast);
        } else {
            return EVEX_compress_addr(base, raw_offt, bcast);
        }
    }

    void safe_add(const Xbyak::Reg64 &base, size_t raw_offt,
            const Xbyak::Reg64 &reg_offt) {
        if (raw_offt > INT_MAX) {
            mov(reg_offt, raw_offt);
            add(base, reg_offt);
        } else {
            add(base, raw_offt);
        }
    }

    void safe_sub(const Xbyak::Reg64 &base, size_t raw_offt,
            const Xbyak::Reg64 &reg_offt) {
        if (raw_offt > INT_MAX) {
            mov(reg_offt, raw_offt);
            sub(base, reg_offt);
        } else {
            sub(base, raw_offt);
        }
    }

    // Disallow char-based labels completely
    void L(const char *label) = delete;
    void L(Xbyak::Label &label) { Xbyak::CodeGenerator::L(label); }

    void L_aligned(Xbyak::Label &label, int alignment = 16) {
        align(alignment);
        L(label);
    }

    void uni_vpxor(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx512_core))
            vpxord(x1, x2, op);
        else if (is_valid_isa(avx))
            vpxor(x1, x2, op);
        else {
            assert(x1.isEqualIfNotInherited(x2));
            pxor(x2, op);
        }
    }
    void uni_vpxor(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx512_core))
            vpxord(x1, x2, op);
        else if (is_valid_isa(avx2))
            vpxor(x1, x2, op);
        else
            vxorps(x1, x2, op);
    }
    void uni_vpxor(const Xbyak::Zmm &x1, const Xbyak::Zmm &x2,
            const Xbyak::Operand &op) {
        vpxord(x1, x2, op);
    }

    void uni_vmovss(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovss(addr, x);
        else
            movss(addr, x);
    }
    void uni_vmovss(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        if (is_valid_isa(avx))
            vmovss(x, addr);
        else
            movss(x, addr);
    }
    void uni_vmovss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2) {
        if (is_valid_isa(avx))
            vmovss(x1, x1, x2);
        else
            movss(x1, x2);
    }
    void uni_vmovss(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
        vmovss(addr, Xbyak::Xmm(x.getIdx()));
    }
    void uni_vmovss(const Xbyak::Ymm &x, const Xbyak::Address &addr) {
        vmovss(Xbyak::Xmm(x.getIdx()), addr);
    }
    void uni_vmovss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2) {
        vmovss(Xbyak::Xmm(x1.getIdx()), Xbyak::Xmm(x2.getIdx()));
    }

    void uni_vmovsd(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovsd(addr, x);
        else
            movsd(addr, x);
    }
    void uni_vmovsd(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
        vmovsd(addr, x);
    }
    void uni_vmovsd(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        if (is_valid_isa(avx))
            vmovsd(x, addr);
        else
            movsd(x, addr);
    }
    void uni_vmovsd(const Xbyak::Ymm &x, const Xbyak::Address &addr) {
        vmovsd(x, addr);
    }

    void uni_vmovlps(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovlps(addr, x);
        else
            movlps(addr, x);
    }
    void uni_vmovlps(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
        vmovlps(addr, x);
    }
    void uni_vmovlps(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        if (is_valid_isa(avx))
            vmovlps(x, addr);
        else
            movlps(x, addr);
    }
    void uni_vmovlps(const Xbyak::Ymm &x, const Xbyak::Address &addr) {
        vmovlps(x, addr);
    }

    void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovdqu(addr, x);
        else
            movdqu(addr, x);
    }
    void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
        vmovdqu(addr, x);
    }
    void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Zmm &x) {
        vmovdqu32(addr, x);
    }

    void uni_vmovdqu(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        if (is_valid_isa(avx))
            vmovdqu(x, addr);
        else
            movdqu(x, addr);
    }
    void uni_vmovdqu(const Xbyak::Ymm &x, const Xbyak::Address &addr) {
        vmovdqu(x, addr);
    }
    void uni_vmovdqu(const Xbyak::Zmm &x, const Xbyak::Address &addr) {
        vmovdqu32(x, addr);
    }

    void uni_vmovdqu16(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx512_core))
            vmovdqu16(addr, x);
        else if (is_valid_isa(avx))
            vmovups(addr, x);
        else
            movups(addr, x);
    }

    void uni_vmovdqu16(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        if (is_valid_isa(avx512_core))
            vmovdqu16(x, addr);
        else if (is_valid_isa(avx))
            vmovups(x, addr);
        else
            movups(x, addr);
    }

    void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovups(addr, x);
        else
            movups(addr, x);
    }
    void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
        vmovups(addr, x);
    }

    void uni_vmovups(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vmovups(x, op);
        else
            movups(x, op);
    }
    void uni_vmovups(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        vmovups(x, op);
    }

    void uni_vmovups_tail(const Xbyak::Address &addr, const Xbyak::Ymm &mask,
            const Xbyak::Ymm &x) {
        vmaskmovps(addr, mask, x);
    }
    void uni_vmovups_tail(const Xbyak::Ymm &x, const Xbyak::Ymm &mask,
            const Xbyak::Address &addr) {
        vmaskmovps(x, mask, addr);
    }

    void uni_vmovups_tail(const Xbyak::Address &addr, const Xbyak::Opmask &mask,
            const Xbyak::Zmm &x) {
        vmovups(addr | mask, x);
    }
    void uni_vmovups_tail(const Xbyak::Zmm &x, const Xbyak::Opmask &mask,
            const Xbyak::Address &addr) {
        vmovups(x | mask | T_z, addr);
    }

    void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovntps(addr, x);
        else
            movntps(addr, x);
    }

    void uni_vbroadcastss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx2) || (is_valid_isa(avx) && op.isMEM()))
            vbroadcastss(x, op);
        else if (is_valid_isa(avx)) {
            vmovss(x, x, op);
            vshufps(x, x, x, 0x0);
        } else {
            movss(x, op);
            shufps(x, x, 0x0);
        }
    }
    void uni_vbroadcastss(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        if (op.isMEM() || is_valid_isa(avx2)) {
            vbroadcastss(x, op);
        } else {
            Xbyak::Xmm t(x.getIdx());
            if (!t.isEqualIfNotInherited(op)) movss(t, op);
            vinsertf128(x, x, t, 1);
            vshufps(x, x, x, 0);
        }
    }

    void uni_vpbroadcastb(const Xbyak::Ymm &x, const Xbyak::Reg8 &r) {
        if (is_valid_isa(avx512_core))
            vpbroadcastb(x, r); // broadcast reg8 directly
        else if (is_valid_isa(avx2)) {
            const Xbyak::Xmm t(x.getIdx());
            uni_vmovd(t, r.cvt32());
            vpbroadcastb(x, t);
        }
        assert(is_valid_isa(avx2) && "avx does not support vpbroadcastb");
    }

    void uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx2))
            vpbroadcastd(x, op);
        else if (is_valid_isa(avx)) {
            if (op.isMEM())
                vmovss(x, op.getAddress());
            else
                vmovss(x, x, op);
            vpshufd(x, x, 0x0);
        } else {
            movss(x, op);
            pshufd(x, x, 0x0);
        }
    }
    void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Reg32 &r) {
        if (is_valid_isa(avx512_core))
            vpbroadcastd(x, r); // broadcast reg32 directly
        else {
            const Xbyak::Xmm t(x.getIdx());
            uni_vmovd(t, r);
            uni_vpbroadcastd(x, t);
        }
    }
    void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx2)) {
            vpbroadcastd(x, op);
        } else {
            const Xbyak::Xmm t(x.getIdx());
            if (!t.isEqualIfNotInherited(op)) {
                if (op.isMEM())
                    vmovss(t, op.getAddress());
                else
                    vmovss(t, t, op);
            }
            vinsertf128(x, x, t, 1);
            vshufps(x, x, x, 0);
        }
    }

    void uni_vshufps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, Xbyak::uint8 imm) {
        if (is_valid_isa(avx))
            vshufps(x1, x2, op, imm);
        else {
            movups(x1, x2);
            shufps(x1, op, imm);
        }
    }

    void uni_vpshufd(
            const Xbyak::Xmm &x1, const Xbyak::Operand &op, Xbyak::uint8 imm) {
        if (is_valid_isa(avx))
            vpshufd(x1, op, imm);
        else {
            pshufd(x1, op, imm);
        }
    }

    void uni_vrcpss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vrcpss(x, x, op);
        else
            rcpss(x, op);
    }
    void uni_vrcpss(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2) {
        Xbyak::Xmm x1_(x1.getIdx());
        Xbyak::Xmm x2_(x2.getIdx());
        vrcpss(x1_, x1_, x2_);
    }
    void uni_vrcpss(const Xbyak::Ymm &x, const Xbyak::Address &op) {
        Xbyak::Xmm x_(x.getIdx());
        vrcpss(x_, x_, op);
    }

    void uni_vrcpps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vrcpps(x, op);
        else
            rcpps(x, op);
    }
    void uni_vrcpps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        vrcpps(x, op);
    }
    void uni_vrcpps(const Xbyak::Zmm &x, const Xbyak::Operand &op) {
        vrcp14ps(x, op);
    }

    void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vdivps(x, op1, op2);
        else {
            assert(x.isEqualIfNotInherited(op1));
            divps(x, op2);
        }
    }
    void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vdivps(x, op1, op2);
    }

    void uni_vdivss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vdivss(x, op1, op2);
        else {
            assert(x.isEqualIfNotInherited(op1));
            divss(x, op2);
        }
    }

    void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx))
            vdivps(x, op1, op2);
        else {
            movups(buf, op1);
            divps(buf, op2);
            if (x.getIdx() != buf.getIdx()) { movups(x, buf); }
        }
    }

    void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
        vdivps(x, op1, op2);
    }

    void uni_vaddps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vaddps(x, op1, op2);
        else {
            assert(x.getIdx() == op1.getIdx());
            addps(x, op2);
        }
    }
    void uni_vaddps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vaddps(x, op1, op2);
    }
    void uni_vaddss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vaddss(x, op1, op2);
        else {
            assert(x.isEqualIfNotInherited(op1));
            addss(x, op2);
        }
    }
    void uni_vaddss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vaddss(x, op1, op2);
    }

    void uni_vphaddd(const Xbyak::Xmm &x, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx)) {
            vphaddd(x, x2, op);
        } else {
            assert(x.isEqualIfNotInherited(op));
            phaddd(x, op);
        }
    }

    void uni_vhaddps(const Xbyak::Xmm &x, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx)) {
            vhaddps(x, x2, op);
        } else {
            assert(x.isEqualIfNotInherited(op));
            haddps(x, op);
        }
    }

    void uni_vpsignd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpsignd(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            psignd(x1, op);
        }
    }
    void uni_vpsignd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpsignd(x1, x2, op);
    }

    void uni_vpsubd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpsubd(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            psubd(x1, op);
        }
    }
    void uni_vpsubd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpsubd(x1, x2, op);
    }

    void uni_vpsubb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpsubb(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            psubb(x1, op);
        }
    }
    void uni_vpsubb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpsubb(x1, x2, op);
    }

    void uni_vsubss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vsubss(x, op1, op2);
        else {
            assert(x.isEqualIfNotInherited(op1));
            subss(x, op2);
        }
    }
    void uni_vsubss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vsubss(x, Xbyak::Xmm(op1.getIdx()), Xbyak::Xmm(op2.getIdx()));
    }

    void uni_vsubss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx))
            vsubss(x, op1, op2);
        else {
            if (!buf.isEqualIfNotInherited(op1)) {
                assert(!buf.isEqualIfNotInherited(op2));
                movss(buf, op1);
            }
            subss(buf, op2);
            if (x.getIdx() != buf.getIdx()) movss(x, buf);
        }
    }

    void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vsubps(x, op1, op2);
        else {
            assert(x.isEqualIfNotInherited(op1));
            subps(x, op2);
        }
    }
    void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vsubps(x, op1, op2);
    }

    void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx))
            vsubps(x, op1, op2);
        else {
            movups(buf, op1);
            subps(buf, op2);
            if (x.getIdx() != buf.getIdx()) { movups(x, buf); }
        }
    }

    void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
        vsubps(x, op1, op2);
    }

    void uni_vpmulld(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx)) {
            vpmulld(x1, x2, op);
        } else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            pmulld(x1, op);
        }
    }
    void uni_vpmulld(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpmulld(x1, x2, op);
    }

    void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vmulps(x, op1, op2);
        else {
            if (!x.isEqualIfNotInherited(op1)) {
                assert(!x.isEqualIfNotInherited(op2));
                movups(x, op1);
            }
            mulps(x, op2);
        }
    }
    void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx))
            vmulps(x, op1, op2);
        else {
            if (!buf.isEqualIfNotInherited(op1)) {
                assert(!buf.isEqualIfNotInherited(op2));
                movups(buf, op1);
            }
            mulps(buf, op2);
            if (x.getIdx() != buf.getIdx()) movups(x, buf);
        }
    }
    void uni_vmulps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vmulps(x, op1, op2);
    }

    void uni_vmulss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vmulss(x, op1, op2);
        else {
            assert(x.isEqualIfNotInherited(op1));
            mulss(x, op2);
        }
    }
    void uni_vmulss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx))
            vmulss(x, op1, op2);
        else {
            if (!buf.isEqualIfNotInherited(op1)) {
                assert(!buf.isEqualIfNotInherited(op2));
                movss(buf, op1);
            }
            mulss(buf, op2);
            if (x.getIdx() != buf.getIdx()) movss(x, buf);
        }
    }
    void uni_vmulss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Address &op2) {
        vmulss(x, Xbyak::Xmm(op1.getIdx()), op2);
    }
    void uni_vmulss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Ymm &op2) {
        vmulss(x, Xbyak::Xmm(op1.getIdx()), Xbyak::Xmm(op2.getIdx()));
    }

    void uni_vfmadd132ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx2))
            vfmadd132ps(x1, x2, op);
        else if (is_valid_isa(avx)) {
            assert(x1.getIdx() != x2.getIdx());
            vmulps(x1, x1, op);
            vaddps(x1, x1, x2);
        } else {
            assert(buf.getIdx() != x2.getIdx());
            if (x1.getIdx() != buf.getIdx()) movups(buf, x1);
            mulps(buf, op);
            addps(buf, x2);
            if (x1.getIdx() != buf.getIdx()) movups(x1, buf);
        }
    }
    void uni_vfmadd132ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x1 gets overridden by x1*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfmadd132ps(x1, x2, op, x1);
    }

    void uni_vfmadd132ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &buf) {
        if (is_valid_isa(avx2))
            vfmadd132ps(x1, x2, op);
        else {
            // AVX is assumed because of YMM
            assert(buf.getIdx() != x2.getIdx());
            vmulps(buf, x1, op);
            vaddps(x1, buf, x2);
        }
    }
    void uni_vfmadd132ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        // Note: AVX: x1 gets overridden by x1*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfmadd132ps(x1, x2, op, x1);
    }

    void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx2))
            vfmadd213ps(x1, x2, op);
        else if (is_valid_isa(avx)) {
            assert(!buf.isEqualIfNotInherited(op));
            vmulps(buf, x1, x2);
            vaddps(x1, buf, op);
        } else {
            assert(!buf.isEqualIfNotInherited(op));
            if (x1.getIdx() != buf.getIdx()) movups(buf, x1);
            mulps(buf, x2);
            addps(buf, op);
            if (x1.getIdx() != buf.getIdx()) movups(x1, buf);
        }
    }
    void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x1 gets overridden by x1*x2
        // This is incorrect if x1 == op
        if (!is_valid_isa(avx2)) assert(!x1.isEqualIfNotInherited(op));
        uni_vfmadd213ps(x1, x2, op, x1);
    }

    void uni_vfmadd213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &buf) {
        if (is_valid_isa(avx2))
            vfmadd213ps(x1, x2, op);
        else {
            // AVX is assumed because of YMM
            assert(!buf.isEqualIfNotInherited(op));
            vmulps(buf, x1, x2);
            vaddps(x1, buf, op);
        }
    }
    void uni_vfmadd213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        // Note: AVX: x1 gets overridden by x1*x2
        // This is incorrect if x1 == op
        if (!is_valid_isa(avx2)) assert(!x1.isEqualIfNotInherited(op));
        uni_vfmadd213ps(x1, x2, op, x1);
    }

    void uni_vfmadd213ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx2))
            vfmadd213ss(x1, x2, op);
        else if (is_valid_isa(avx)) {
            assert(!buf.isEqualIfNotInherited(op));
            vmulss(buf, x1, x2);
            vaddss(x1, buf, op);
        } else {
            assert(!buf.isEqualIfNotInherited(op));
            if (x1.getIdx() != buf.getIdx()) movss(buf, x1);
            mulss(buf, x2);
            addss(buf, op);
            if (x1.getIdx() != buf.getIdx()) movss(x1, buf);
        }
    }
    void uni_vfmadd213ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x1 gets overridden by x1*x2
        // This is incorrect if x1 == op
        if (!is_valid_isa(avx2)) assert(!x1.isEqualIfNotInherited(op));
        uni_vfmadd213ss(x1, x2, op, x1);
    }

    void uni_vfmadd213ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &buf) {
        if (is_valid_isa(avx2))
            vfmadd213ss(x1, x2, op);
        else {
            // AVX is assumed because of YMM
            assert(!buf.isEqualIfNotInherited(op));
            vmulss(buf, x1, x2);
            vaddss(x1, buf, op);
        }
    }
    void uni_vfmadd213ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        // Note: AVX: x1 gets overridden by x1*x2
        // This is incorrect if x1 == op
        if (!is_valid_isa(avx2)) assert(!x1.isEqualIfNotInherited(op));
        uni_vfmadd213ss(x1, x2, op, x1);
    }

    void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx2))
            vfmadd231ps(x1, x2, op);
        else if (is_valid_isa(avx)) {
            assert(buf.getIdx() != x1.getIdx());
            vmulps(buf, x2, op);
            vaddps(x1, x1, buf);
        } else {
            assert(buf.getIdx() != x1.getIdx());
            if (x2.getIdx() != buf.getIdx()) movups(buf, x2);
            mulps(buf, op);
            addps(x1, buf);
        }
    }
    void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfmadd231ps(x1, x2, op, x2);
    }

    void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &buf) {
        if (is_valid_isa(avx2))
            vfmadd231ps(x1, x2, op);
        else {
            // AVX is assumed because of YMM
            assert(buf.getIdx() != x1.getIdx());
            vmulps(buf, x2, op);
            vaddps(x1, x1, buf);
        }
    }
    void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        // Note: AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfmadd231ps(x1, x2, op, x2);
    }

    void uni_vfmadd231ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx2))
            vfmadd231ss(x1, x2, op);
        else if (is_valid_isa(avx)) {
            assert(buf.getIdx() != x1.getIdx());
            vmulss(buf, x2, op);
            vaddss(x1, x1, buf);
        } else {
            assert(buf.getIdx() != x1.getIdx());
            if (x2.getIdx() != buf.getIdx()) movss(buf, x2);
            mulss(buf, op);
            addss(x1, buf);
        }
    }
    void uni_vfmadd231ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfmadd231ss(x1, x2, op, x2);
    }

    void uni_vfmadd231ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &buf) {
        if (is_valid_isa(avx2))
            vfmadd231ss(Xbyak::Xmm(x1.getIdx()), Xbyak::Xmm(x2.getIdx()), op);
        else {
            // AVX is assumed because of YMM
            assert(buf.getIdx() != x1.getIdx());
            vmulss(buf, x2, op);
            vaddss(x1, x1, buf);
        }
    }
    void uni_vfmadd231ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        // Note: AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfmadd231ss(x1, x2, op, x2);
    }

    void uni_vfnmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx2))
            vfnmadd231ps(x1, x2, op);
        else if (is_valid_isa(avx)) {
            assert(buf.getIdx() != x1.getIdx());
            vmulps(buf, x2, op);
            vsubps(x1, x1, buf);
        } else {
            assert(buf.getIdx() != x1.getIdx());
            if (x2.getIdx() != buf.getIdx()) movups(buf, x2);
            mulps(buf, op);
            subps(x1, buf);
        }
    }
    void uni_vfnmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfnmadd231ps(x1, x2, op, x2);
    }

    void uni_vfnmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &buf) {
        if (is_valid_isa(avx2))
            vfnmadd231ps(x1, x2, op);
        else {
            // AVX is assumed because of YMM
            assert(buf.getIdx() != x1.getIdx());
            vmulps(buf, x2, op);
            vsubps(x1, x1, buf);
        }
    }
    void uni_vfnmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        // Note: AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfnmadd231ps(x1, x2, op, x2);
    }

    void uni_vfnmadd231ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx2))
            vfnmadd231ss(x1, x2, op);
        else if (is_valid_isa(avx)) {
            assert(buf.getIdx() != x1.getIdx());
            vmulss(buf, x2, op);
            vsubss(x1, x1, buf);
        } else {
            assert(buf.getIdx() != x1.getIdx());
            if (x2.getIdx() != buf.getIdx()) movss(buf, x2);
            mulss(buf, op);
            subss(x1, buf);
        }
    }
    void uni_vfnmadd231ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfnmadd231ss(x1, x2, op, x2);
    }

    void uni_vfnmadd231ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &buf) {
        if (is_valid_isa(avx2))
            vfnmadd231ss(x1, x2, op);
        else {
            // AVX is assumed because of YMM
            assert(buf.getIdx() != x1.getIdx());
            vmulss(buf, x2, op);
            vsubss(x1, x1, buf);
        }
    }
    void uni_vfnmadd231ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        // Note: AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfnmadd231ss(x1, x2, op, x2);
    }

    void uni_vfmsub213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx2))
            vfmsub213ps(x1, x2, op);
        else if (is_valid_isa(avx)) {
            assert(!buf.isEqualIfNotInherited(op));
            vmulps(buf, x1, x2);
            vsubps(x1, buf, op);
        } else {
            assert(!buf.isEqualIfNotInherited(op));
            if (buf.getIdx() != x1.getIdx()) movups(buf, x1);
            mulps(buf, x2);
            subps(buf, op);
            if (buf.getIdx() != x1.getIdx()) movups(x1, buf);
        }
    }
    void uni_vfmsub213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x1 gets overridden by x1*x2
        // This is incorrect if x1 == op
        if (!is_valid_isa(avx2)) assert(!x1.isEqualIfNotInherited(op));
        uni_vfmsub213ps(x1, x2, op, x1);
    }

    void uni_vfmsub213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &buf) {
        if (is_valid_isa(avx2))
            vfmsub213ps(x1, x2, op);
        else {
            // AVX is assumed because of YMM
            assert(!buf.isEqualIfNotInherited(op));
            vmulps(buf, x1, x2);
            vsubps(x1, buf, op);
        }
    }
    void uni_vfmsub213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        // Note: AVX: x1 gets overridden by x1*x2
        // This is incorrect if x1 == op
        if (!is_valid_isa(avx2)) assert(!x1.isEqualIfNotInherited(op));
        uni_vfmsub213ps(x1, x2, op, x1);
    }

    void uni_vsqrtps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vsqrtps(x, op);
        else
            sqrtps(x, op);
    }
    void uni_vsqrtps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        vsqrtps(x, op);
    }

    void uni_vpaddd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpaddd(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            paddd(x1, op);
        }
    }
    void uni_vpaddd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpaddd(x1, x2, op);
    }

    void uni_vpaddb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpaddb(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            paddb(x1, op);
        }
    }
    void uni_vpaddb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpaddb(x1, x2, op);
    }

    void uni_vpmaddwd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmaddwd(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            pmaddwd(x1, op);
        }
    }
    void uni_vpmaddwd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpmaddwd(x1, x2, op);
    }

    void uni_vpmaddubsw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmaddubsw(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            pmaddubsw(x1, op);
        }
    }
    void uni_vpmaddubsw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpmaddubsw(x1, x2, op);
    }

    void uni_vandps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vandps(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            andps(x1, op);
        }
    }
    void uni_vandps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        if (!is_valid_isa(avx512_core) || x1.getBit() < 512)
            vandps(x1, x2, op);
        else
            vpandd(x1, x2, op);
    }

    void uni_vorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vorps(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            orps(x1, op);
        }
    }
    void uni_vorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        if (!is_valid_isa(avx512_core) || x1.getBit() < 512)
            vorps(x1, x2, op);
        else
            vpord(x1, x2, op);
    }

    void uni_vxorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vxorps(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) { uni_vmovups(x1, x2); }
            xorps(x1, op);
        }
    }
    void uni_vxorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        if (!is_valid_isa(avx512_core) || x1.getBit() < 512)
            vxorps(x1, x2, op);
        else
            vpxord(x1, x2, op);
    }

    void uni_vpslld(
            const Xbyak::Xmm &x, const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx))
            vpslld(x, op, imm);
        else {
            assert(x.isEqualIfNotInherited(op));
            pslld(x, imm);
        }
    }
    void uni_vpslld(
            const Xbyak::Ymm &x, const Xbyak::Operand &op, const int imm) {
        vpslld(x, op, imm);
    }

    void uni_vpsrld(
            const Xbyak::Xmm &x, const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx))
            vpsrld(x, op, imm);
        else {
            if (!x.isEqualIfNotInherited(op)) uni_vmovups(x, op);
            psrld(x, imm);
        }
    }
    void uni_vpsrld(
            const Xbyak::Ymm &x, const Xbyak::Operand &op, const int imm) {
        vpsrld(x, op, imm);
    }

    void uni_vmaxps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vmaxps(x, op1, op2);
        else {
            if (!x.isEqualIfNotInherited(op1)) movups(x, op1);
            maxps(x, op2);
        }
    }
    void uni_vmaxps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vmaxps(x, op1, op2);
    }

    void uni_vmaxss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vmaxss(x, op1, op2);
        else {
            if (!x.isEqualIfNotInherited(op1)) movss(x, op1);
            maxss(x, op2);
        }
    }

    void uni_vminps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vminps(x, op1, op2);
        else {
            if (!x.isEqualIfNotInherited(op1)) movups(x, op1);
            minps(x, op2);
        }
    }

    void uni_vminps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vminps(x, op1, op2);
    }

    void uni_vminss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vminss(x, op1, op2);
        else {
            if (!x.isEqualIfNotInherited(op1)) movss(x, op1);
            minss(x, op2);
        }
    }

    void uni_vpmovsxbd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmovsxbd(x, op);
        else
            pmovsxbd(x, op);
    }

    void uni_vpmovsxbd(const Xbyak::Ymm &y, const Xbyak::Operand &op) {
        vpmovsxbd(y, op);
    }

    void uni_vpmovzxbd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmovzxbd(x, op);
        else
            pmovzxbd(x, op);
    }
    void uni_vpmovzxbd(const Xbyak::Ymm &y, const Xbyak::Operand &op) {
        vpmovzxbd(y, op);
    }

    void uni_vcmpps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, int cmp_predicate) {
        if (is_valid_isa(avx))
            vcmpps(x1, x2, op, cmp_predicate);
        else {
            if (x1.getIdx() != x2.getIdx()) uni_vmovups(x1, x2);
            cmpps(x1, op, cmp_predicate);
        }
    }
    void uni_vcmpps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, int cmp_predicate) {
        vcmpps(x1, x2, op, cmp_predicate);
    }

    void uni_vtestps(const Xbyak::Xmm &x1, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vtestps(x1, op);
        else
            ptest(x1, op);
    }

    void uni_vtestps(const Xbyak::Ymm &x1, const Xbyak::Operand &op) {
        assert(!(x1.isZMM() || op.isZMM()));
        vtestps(x1, op);
    }

    void uni_vblendvps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &msk) {
        if (is_valid_isa(avx))
            vblendvps(x1, x2, op, msk);
        else {
            assert(x1.getIdx() == x2.getIdx());
            assert(msk.getIdx() == 0);
            blendvps(x1, op);
        }
    }
    void uni_vblendvps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &msk) {
        vblendvps(x1, x2, op, msk);
    }

    void uni_vblendps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const int imm) {
        assert(!x1.isZMM() && !x2.isZMM());

        if (is_valid_isa(avx))
            vblendps(x1, x2, op, imm);
        else {
            assert(x1.getIdx() == x2.getIdx());
            blendps(x1, op, imm);
        }
    }

    void uni_vroundps(
            const Xbyak::Xmm &x, const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx512_core))
            vrndscaleps(x, op, imm & 0x3);
        else if (is_valid_isa(avx))
            vroundps(x, op, imm);
        else
            roundps(x, op, imm);
    }

    void uni_vroundps(
            const Xbyak::Ymm &x, const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx512_core))
            vrndscaleps(x, op, imm & 0x3);
        else
            vroundps(x, op, imm);
    }

    void uni_vroundps(
            const Xbyak::Zmm &x, const Xbyak::Operand &op, const int imm) {
        vrndscaleps(x, op, imm & 0x3);
    }

    void uni_vcvtps2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vcvtps2dq(x, op);
        else
            cvtps2dq(x, op);
    }
    void uni_vcvtps2dq(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        vcvtps2dq(x, op);
    }

    void uni_vcvtdq2ps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vcvtdq2ps(x, op);
        else
            cvtdq2ps(x, op);
    }

    void uni_vcvtdq2ps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        vcvtdq2ps(x, op);
    }

    void uni_vcvtph2psx(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        assert(is_valid_isa(avx2));
        if (is_valid_isa(avx512_core_fp16))
            vcvtph2psx(x, op);
        else if (is_valid_isa(avx2)) {
            assert(IMPLICATION(op.isMEM(), !op.getAddress().isBroadcast()));
            vcvtph2ps(x, op);
        }
    }

    void uni_vcvtps2phx(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        assert(is_valid_isa(avx512_core_fp16));
        vcvtps2phx(x, addr);
    }

    void uni_vcvtps2phx(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        assert(is_valid_isa(avx2));
        vcvtps2ph(addr, x, _op_mxcsr);
    }

    void uni_vcvtps2phx(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2) {
        assert(is_valid_isa(avx2));
        if (is_valid_isa(avx512_core_fp16))
            vcvtps2phx(x1, x2);
        else if (is_valid_isa(avx2))
            vcvtps2ph(x1, x2, _op_mxcsr);
    }

    void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Xmm &x2) {
        movmskps(x1.cvt64(), x2);
    }
    void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Ymm &x2) {
        vmovmskps(x1, x2);
    }

    void uni_vmovd(const Xbyak::Reg32 &r, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovd(r, x);
        else
            movd(r, x);
    }
    void uni_vmovd(const Xbyak::Xmm &x, const Xbyak::Reg32 &r) {
        if (is_valid_isa(avx))
            vmovd(x, r);
        else
            movd(x, r);
    }
    void uni_vmovd(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovd(addr, x);
        else
            movd(addr, x);
    }

    void uni_vmovd(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        if (is_valid_isa(avx))
            vmovd(x, addr);
        else
            movd(x, addr);
    }

    void uni_vmovq(const Xbyak::Xmm &x, const Xbyak::Reg64 &r) {
        if (is_valid_isa(avx))
            vmovq(x, r);
        else
            movq(x, r);
    }
    void uni_vmovq(const Xbyak::Reg64 &r, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovq(r, x);
        else
            movq(r, x);
    }
    void uni_vmovq(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovq(addr, x);
        else
            movq(addr, x);
    }
    void uni_vmovq(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        if (is_valid_isa(avx))
            vmovq(x, addr);
        else
            movq(x, addr);
    }

    void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpackssdw(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            packssdw(x1, op);
        }
    }
    void uni_vpackssdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpackssdw(x1, x2, op);
    }

    void uni_vpackuswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpackuswb(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            packuswb(x1, op);
        }
    }
    void uni_vpackuswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpackuswb(x1, x2, op);
    }

    void uni_vpacksswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpacksswb(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            packsswb(x1, op);
        }
    }
    void uni_vpacksswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpacksswb(x1, x2, op);
    }

    void uni_vpinsrb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx))
            vpinsrb(x1, x2, op, imm);
        else {
            assert(x1.getIdx() == x2.getIdx());
            pinsrb(x1, op, imm);
        }
    }

    void uni_vpinsrb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const int imm) {
        vpinsrb(x1, x2, op, imm);
    }

    void uni_vpinsrd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx))
            vpinsrd(x1, x2, op, imm);
        else {
            assert(x1.getIdx() == x2.getIdx());
            pinsrd(x1, op, imm);
        }
    }
    void uni_vpinsrd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const int imm) {
        vpinsrd(x1, x2, op, imm);
    }

    void uni_vpinsrq(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx))
            vpinsrq(x1, x2, op, imm);
        else {
            assert(x1.getIdx() == x2.getIdx());
            pinsrq(x1, op, imm);
        }
    }
    void uni_vpinsrq(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const int imm) {
        vpinsrq(x1, x2, op, imm);
    }

    void uni_vpinsrw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx))
            vpinsrw(x1, x2, op, imm);
        else {
            assert(x1.getIdx() == x2.getIdx());
            pinsrw(x1, op, imm);
        }
    }
    void uni_vpinsrw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const int imm) {
        vpinsrw(x1, x2, op, imm);
    }

    void uni_vpextrb(
            const Xbyak::Operand &op, const Xbyak::Xmm &x, const int imm) {
        if (is_valid_isa(avx))
            vpextrb(op, x, imm);
        else
            pextrb(op, x, imm);
    }

    void uni_vpextrb(
            const Xbyak::Operand &op, const Xbyak::Ymm &x, const int imm) {
        vpextrb(op, x, imm);
    }

    void uni_vpextrw(
            const Xbyak::Operand &op, const Xbyak::Xmm &x, const int imm) {
        if (is_valid_isa(avx))
            vpextrw(op, x, imm);
        else
            pextrw(op, x, imm);
    }
    void uni_vpextrw(
            const Xbyak::Operand &op, const Xbyak::Ymm &x, const int imm) {
        vpextrw(op, x, imm);
    }

    void uni_vpextrd(
            const Xbyak::Operand &op, const Xbyak::Xmm &x, const int imm) {
        if (is_valid_isa(avx))
            vpextrd(op, x, imm);
        else
            pextrd(op, x, imm);
    }
    void uni_vpextrd(
            const Xbyak::Operand &op, const Xbyak::Ymm &x, const int imm) {
        vpextrd(op, x, imm);
    }

    void uni_vpextrq(
            const Xbyak::Operand &op, const Xbyak::Xmm &x, const int imm) {
        if (is_valid_isa(avx))
            vpextrq(op, x, imm);
        else
            pextrq(op, x, imm);
    }
    void uni_vpextrq(
            const Xbyak::Operand &op, const Xbyak::Ymm &x, const int imm) {
        vpextrq(op, x, imm);
    }

    void uni_vpmaxsd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmaxsd(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            pmaxsd(x1, op);
        }
    }

    void uni_vpmaxsd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpmaxsd(x1, x2, op);
    }

    void uni_vpmaxsb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmaxsb(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            pmaxsb(x1, op);
        }
    }

    void uni_vpmaxsb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpmaxsb(x1, x2, op);
    }

    void uni_vpminub(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpminub(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            pminub(x1, op);
        }
    }

    // Note: the instructions below are used in custom forks and are put here
    // to decrease the maintenance burden on the user side.
    // Once an instruction becomes used in one of the oneDNN implementations,
    // please move it out of this section.
    void uni_vpshufb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpshufb(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            pshufb(x1, op);
        }
    }
    void uni_vpshufb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpshufb(x1, x2, op);
    }

    void uni_vpand(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx512_core) && x1.getBit() == 512)
            vpandd(x1, x2, op);
        else if (is_valid_isa(avx))
            vpand(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            pand(x1, op);
        }
    }

    void uni_vpslldq(
            const Xbyak::Xmm &x, const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx))
            vpslldq(x, op, imm);
        else {
            assert(x.isEqualIfNotInherited(op));
            pslldq(x, imm);
        }
    }
    void uni_vpslldq(
            const Xbyak::Ymm &x, const Xbyak::Operand &op, const int imm) {
        vpslldq(x, op, imm);
    }

    void uni_vpmovsxwd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmovsxwd(x, op);
        else
            pmovsxwd(x, op);
    }
    void uni_vpmovsxwd(const Xbyak::Ymm &y, const Xbyak::Operand &op) {
        vpmovsxwd(y, op);
    }

    void uni_vpmovsxdq(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmovsxdq(x, op);
        else
            pmovsxdq(x, op);
    }
    void uni_vpmovsxdq(const Xbyak::Ymm &y, const Xbyak::Operand &op) {
        vpmovsxdq(y, op);
    }

    void uni_vpmovzxwd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmovzxwd(x, op);
        else
            pmovzxwd(x, op);
    }
    void uni_vpmovzxwd(const Xbyak::Ymm &y, const Xbyak::Operand &op) {
        vpmovzxwd(y, op);
    }

    void uni_vpcmpeqd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpcmpeqd(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) uni_vmovups(x1, x2);
            pcmpeqd(x1, op);
        }
    }
    void uni_vpcmpeqd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpcmpeqd(x1, x2, op);
    }

    void uni_vpackusdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpackusdw(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            packusdw(x1, op);
        }
    }
    void uni_vpackusdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpackusdw(x1, x2, op);
    }

    void uni_vpminsd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpminsd(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            pminsd(x1, op);
        }
    }

    void uni_movshdup(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vmovshdup(x, op);
        else
            movshdup(x, op);
    }
    void uni_movshdup(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        vmovshdup(x, op);
    }

    void uni_vmovhlps(
            const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Xmm &x3) {
        if (is_valid_isa(avx))
            vmovhlps(x1, x2, x3);
        else {
            assert(x1.getIdx() == x2.getIdx());
            movhlps(x1, x3);
        }
    }
    void uni_vmovhlps(
            const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Ymm &x3) {
        vmovhlps(x1, x2, x3);
    }

    // End of custom instructions section.

    void mul_by_const(
            const Xbyak::Reg &out, const Xbyak::Reg64 &tmp, int value) {
        // Generates a shift + add sequence for multiplying the contents of
        // the out register by a known JIT-time value. Clobbers the tmp
        // register.
        //
        // Pros compared to mul/imul:
        // - does not require using fixed, known registers (the one-operand
        //   forms of mul/imul implicitly use rax/rdx)
        // Still, there are probably a lot of cases when mul/imul is faster on
        // Intel(R) Core(TM) processors. Not intended for the critical path.

        // TODO: detect when overflow is imminent (Roma)
        // TODO: detect when using mul/imul is a better option (Roma)
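
        // Worked example (illustrative): for value == 10 (0b1010), the loop
        // below emits the following sequence, with x denoting the initial
        // contents of out:
        //     xor  tmp, tmp
        //     shl  out, 1     // out = 2*x  (bit 1 of value is set)
        //     add  tmp, out   // tmp = 2*x
        //     shl  out, 2     // out = 8*x  (bit 3 of value is set)
        //     add  tmp, out   // tmp = 10*x
        //     mov  out, tmp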

        int p = 0; // the current power of 2
        int old_p = 0; // the last power of 2 with a set bit in value;
                       // out holds (x << old_p), x being its initial contents

        xor_(tmp, tmp);
        while (value) {
            if (value & 1) {
                int shift = p - old_p;
                if (shift) {
                    shl(out, shift);
                    old_p = p;
                }
                add(tmp, out);
            }
            value >>= 1;
            p++;
        }
        mov(out, tmp);
    }

    /*
      Saturation facility functions. They enable preparing the registers that
      hold the saturation bounds and applying the saturation to a floating
      point register.
    */
    template <typename Vmm>
    void init_vmm(Vmm vmm, Xbyak::Reg64 reg_tmp, float value) {
        Xbyak::Xmm xmm_tmp(vmm.getIdx());
        mov(reg_tmp, float2int(value));
        uni_vmovq(xmm_tmp, reg_tmp);
        if (vmm.isYMM() || vmm.isZMM())
            uni_vbroadcastss(vmm, xmm_tmp);
        else
            uni_vshufps(vmm, xmm_tmp, xmm_tmp, 0);
    }

    template <typename Vmm>
    void init_saturate_f32(Vmm vmm_lbound, Vmm vmm_ubound, Xbyak::Reg64 reg_tmp,
            data_type_t idt, data_type_t odt, bool force_lbound = false) {
        using namespace data_type;
        if (!((idt == f32) && utils::one_of(odt, u8, s8, s32))) return;

        // vmm_lbound is initialized whenever odt == u8 or force_lbound is
        // set, so it must not alias vmm_ubound in those cases.
        assert(IMPLICATION(odt == u8 || force_lbound,
                vmm_lbound.getIdx() != vmm_ubound.getIdx()));

        // There is no need to saturate on the lower bound for signed integer
        // types, as the conversion to int returns INT_MIN, and the proper
        // saturation happens later in store_data. The force_lbound parameter
        // forces unconditional saturation to lbound.
        if (odt == u8)
            uni_vpxor(vmm_lbound, vmm_lbound, vmm_lbound);
        else if (force_lbound) {
            const float saturation_lbound = odt == s8 ? INT8_MIN : INT32_MIN;
            init_vmm(vmm_lbound, reg_tmp, saturation_lbound);
        }

        const float saturation_ubound = types::max_value<float>(odt);
        init_vmm(vmm_ubound, reg_tmp, saturation_ubound);
    }

    template <typename Vmm>
    void saturate_f32(const Vmm &vmm, const Vmm &vmm_lbound,
            const Vmm &vmm_ubound, data_type_t odt, bool force_lbound = false) {
        // This function is used to saturate to odt in f32 before converting
        // to s32, in order to avoid bad saturation caused by the cvtps2dq
        // behavior (it returns INT_MIN if the f32 value is out of the s32
        // range).
        using namespace data_type;
        if (!utils::one_of(odt, u8, s8, s32)) return;

        // There is no need to apply the lower saturation bound when odt is
        // signed, as cvtps2dq returns INT_MIN if the value does not fit. The
        // force_lbound parameter forces unconditional saturation to lbound.
        if (odt == u8 || force_lbound) {
            if (is_valid_isa(avx))
                vmaxps(vmm, vmm, vmm_lbound);
            else
                maxps(vmm, vmm_lbound);
        }
        if (is_valid_isa(avx))
            vminps(vmm, vmm, vmm_ubound);
        else
            minps(vmm, vmm_ubound);
    }
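
    // Typical usage sketch for the saturation facility (illustrative; the
    // vmm/reg names and simd_w are hypothetical, and uni_vcvtps2dq stands for
    // the f32 -> s32 conversion an implementation would use):
    //     init_saturate_f32(vmm_lb, vmm_ub, reg_tmp, data_type::f32, odt);
    //     ... compute f32 results in vmm_dst ...
    //     saturate_f32(vmm_dst, vmm_lb, vmm_ub, odt);
    //     uni_vcvtps2dq(vmm_dst, vmm_dst); // safe: no spurious INT_MIN
    //     store_data(odt, vmm_dst, reg_dst, 0, simd_w);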

    /**
     * load_bytes is a utility function that loads load_size
     * (0 <= load_size <= 32) contiguous bytes into the Xmm/Ymm register from
     * the memory referenced by the ptr[reg + offset] address.
     *
     * Functionally, an invocation of load_bytes is equivalent to the
     * following loop:
     *
     * for (int idx = 0; idx < load_size; ++idx)
     *     vpinsrb(xmm, xmm, ptr[reg + offset + idx], idx);
     *
     * TODO: Add an option to zero out the unloaded bytes in the Xmm register.
     * TODO: Add an option for unsafe_load wherein one could read outside the
     *       provided memory buffer so as to minimize the total number of read
     *       memory instructions.
     */
    template <typename Vmm>
    void load_bytes(
            const Vmm &vmm, const Xbyak::Address &src_addr, int load_size) {

        constexpr bool is_vmm_supported = std::is_same<Vmm, Xbyak::Ymm>::value
                || std::is_same<Vmm, Xbyak::Xmm>::value;
        if (!is_vmm_supported) {
            assert(!"load_bytes() is only supported for xmm and ymm");
            return;
        }

        const auto addr = [&](int bytes_offset) {
            return ptr[src_addr.getRegExp()
                    + Xbyak::RegExp(bytes_offset * sizeof(int8_t))];
        };

        helper_load_bytes(vmm, load_size, addr);
    }

    template <typename Vmm>
    void load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset,
            int load_size) {

        constexpr bool is_vmm_supported = std::is_same<Vmm, Xbyak::Ymm>::value
                || std::is_same<Vmm, Xbyak::Xmm>::value;
        if (!is_vmm_supported) {
            assert(!"load_bytes() is only supported for xmm and ymm");
            return;
        }

        // Ensure offset fits in 4 bytes so it can be encoded in the
        // instruction
        assert(offset >= INT_MIN && offset <= INT_MAX);

        const auto addr = [&](int bytes_offset) {
            return ptr[reg + offset + bytes_offset * sizeof(int8_t)];
        };

        helper_load_bytes(vmm, load_size, addr);
    }
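
    // Example (illustrative; xmm0 and reg_src are hypothetical): load a
    // 5-byte int8 tail from ptr[reg_src + 0] into the low bytes of xmm0:
    //     load_bytes(Xbyak::Xmm(0), reg_src, 0, 5);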

private:
    template <typename Vmm, typename AddrFunc>
    void helper_load_bytes(
            const Vmm &vmm, int load_size, const AddrFunc &addr) {

        constexpr bool is_xmm = std::is_same<Vmm, Xbyak::Xmm>::value;
        constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
        assert((is_xmm || is_ymm) && "only Xmm or Ymm registers are allowed");

        MAYBE_UNUSED(is_xmm);
        MAYBE_UNUSED(is_ymm);

        // Ensure data fits completely inside the Xmm/Ymm register
        assert(load_size >= 0 && load_size <= 32);

        // At most 16 bytes can fit inside the Xmm register
        assert(IMPLICATION(load_size > 16, is_ymm));

        // Ensure that vector register is compatible with the ISA in hand
        assert(IMPLICATION(is_ymm, is_valid_isa(avx)));

        assert(is_valid_isa(sse41)
                && "routine is not supported for the current isa");

        auto xmm = Xbyak::Xmm(vmm.getIdx());
        auto ymm = Xbyak::Ymm(vmm.getIdx());

        if (load_size == 32) {
            vmovups(ymm, addr(0));
            return;
        }

        int start_bytes = 0;
        int bytes_to_load = load_size;

        if (load_size > 16) {
            // Prepare to insert to upper bits of ymm
            start_bytes = 16;
            bytes_to_load -= 16;
        }

        if (bytes_to_load >= 8 && bytes_to_load < 16)
            uni_vpinsrq(xmm, xmm, addr(start_bytes), 0);
        else if (bytes_to_load == 16)
            uni_vmovdqu(xmm, addr(start_bytes));

        switch (bytes_to_load) {
            case 0: break;
            case 1: uni_vpinsrb(xmm, xmm, addr(start_bytes), 0); break;
            case 2: uni_vpinsrw(xmm, xmm, addr(start_bytes), 0); break;
            case 3:
                uni_vpinsrw(xmm, xmm, addr(start_bytes), 0);
                uni_vpinsrb(xmm, xmm, addr(start_bytes + 2), 2);
                break;
            case 4: uni_vpinsrd(xmm, xmm, addr(start_bytes), 0); break;
            case 5:
                uni_vpinsrd(xmm, xmm, addr(start_bytes), 0);
                uni_vpinsrb(xmm, xmm, addr(start_bytes + 4), 4);
                break;
            case 6:
                uni_vpinsrd(xmm, xmm, addr(start_bytes), 0);
                uni_vpinsrw(xmm, xmm, addr(start_bytes + 4), 2);
                break;
            case 7:
                uni_vpinsrd(xmm, xmm, addr(start_bytes), 0);
                uni_vpinsrw(xmm, xmm, addr(start_bytes + 4), 2);
                uni_vpinsrb(xmm, xmm, addr(start_bytes + 6), 6);
                break;
            case 8: break;
            case 9: uni_vpinsrb(xmm, xmm, addr(start_bytes + 8), 8); break;
            case 10: uni_vpinsrw(xmm, xmm, addr(start_bytes + 8), 4); break;
            case 11:
                uni_vpinsrw(xmm, xmm, addr(start_bytes + 8), 4);
                uni_vpinsrb(xmm, xmm, addr(start_bytes + 10), 10);
                break;
            case 12: uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2); break;
            case 13:
                uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2);
                uni_vpinsrb(xmm, xmm, addr(start_bytes + 12), 12);
                break;
            case 14:
                uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2);
                uni_vpinsrw(xmm, xmm, addr(start_bytes + 12), 6);
                break;
            case 15:
                uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2);
                uni_vpinsrw(xmm, xmm, addr(start_bytes + 12), 6);
                uni_vpinsrb(xmm, xmm, addr(start_bytes + 14), 14);
                break;
            case 16: break;
            default: assert(!"improper load size");
        }

        if (load_size > 16) {
            vinsertf128(ymm, ymm, xmm, 1); // insert to upper bits of ymm
            vinsertf128(ymm, ymm, addr(0), 0); // insert to lower bits of ymm
        }
    }

    /**
     * store_bytes is a utility function that stores store_size
     * (0 <= store_size <= 32) contiguous bytes from the Xmm/Ymm register to
     * the memory referenced by the ptr[reg + offset] address.
     *
     * Additionally, when store_size > 16, the input Ymm register is not
     * preserved due to the usage of the vextractf128 instruction.
     *
     * Functionally, an invocation of store_bytes is equivalent to the
     * following loop:
     *
     * for (int idx = 0; idx < store_size; ++idx)
     *     vpextrb(ptr[reg + offset + idx], xmm, idx);
     *
     * TODO: Add an option for unsafe_store wherein one could store extra
     *       dwords past the provided memory buffer so as to minimize the
     *       total number of write memory instructions.
     */
public:
    template <typename Vmm>
    void store_bytes(
            const Vmm &vmm, const Xbyak::Address &dst_addr, int store_size) {
        const auto addr = [&](int bytes_offset) {
            return ptr[dst_addr.getRegExp()
                    + Xbyak::RegExp(bytes_offset * sizeof(int8_t))];
        };
        store_bytes(vmm, store_size, addr);
    }

    template <typename Vmm>
    void store_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int64_t offset,
            int store_size) {

        // Ensure offset fits in 4 bytes so it can be encoded in the
        // instruction
        assert(offset >= INT_MIN && offset <= INT_MAX);

        const auto addr = [&](int bytes_offset) {
            return ptr[reg + offset + bytes_offset * sizeof(int8_t)];
        };

        store_bytes(vmm, store_size, addr);
    }
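
    // Example (illustrative; xmm3 and reg_dst are hypothetical): store the
    // lowest 7 bytes of xmm3 to ptr[reg_dst + 0]:
    //     store_bytes(Xbyak::Xmm(3), reg_dst, 0, 7);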

private:
    template <typename Vmm, typename AddrFunc>
    void store_bytes(const Vmm &vmm, int store_size, const AddrFunc &addr) {

        constexpr bool is_xmm = std::is_same<Vmm, Xbyak::Xmm>::value;
        constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
        static_assert(
                is_xmm || is_ymm, "only Xmm or Ymm registers are allowed");

        MAYBE_UNUSED(is_xmm);
        MAYBE_UNUSED(is_ymm);

        // Ensure data fits completely inside the Xmm/Ymm register
        assert(store_size >= 0 && store_size <= 32);

        // At most 16 bytes can fit inside the Xmm register
        assert(IMPLICATION(store_size > 16, is_ymm));

        // Ensure that vector register is compatible with the ISA in hand
        assert(IMPLICATION(is_ymm, is_valid_isa(avx)));

        assert(is_valid_isa(sse41)
                && "routine is not supported for the current isa");

        auto xmm = Xbyak::Xmm(vmm.getIdx());
        auto ymm = Xbyak::Ymm(vmm.getIdx());

        if (store_size == 32) {
            vmovups(addr(0), ymm);
            return;
        }

        int start_bytes = 0;
        int bytes_to_store = store_size;

        if (store_size > 16) {
            // store the lower 16 bytes (xmm aliases the low half of ymm)
            vmovdqu(addr(0), xmm);
            start_bytes = 16;
            bytes_to_store -= 16;
            // extract the upper 16 bytes of ymm into xmm
            vextractf128(xmm, ymm, 1);
        }

        if (bytes_to_store >= 8 && bytes_to_store < 16)
            uni_vpextrq(addr(start_bytes), xmm, 0);
        else if (bytes_to_store == 16)
            uni_vmovdqu(addr(start_bytes), xmm);

        switch (bytes_to_store) {
            case 0: break;
            case 1: uni_vpextrb(addr(start_bytes), xmm, 0); break;
            case 2: uni_vpextrw(addr(start_bytes), xmm, 0); break;
            case 3:
                uni_vpextrw(addr(start_bytes), xmm, 0);
                uni_vpextrb(addr(start_bytes + 2), xmm, 2);
                break;
            case 4: uni_vpextrd(addr(start_bytes), xmm, 0); break;
            case 5:
                uni_vpextrd(addr(start_bytes), xmm, 0);
                uni_vpextrb(addr(start_bytes + 4), xmm, 4);
                break;
            case 6:
                uni_vpextrd(addr(start_bytes), xmm, 0);
                uni_vpextrw(addr(start_bytes + 4), xmm, 2);
                break;
            case 7:
                uni_vpextrd(addr(start_bytes), xmm, 0);
                uni_vpextrw(addr(start_bytes + 4), xmm, 2);
                uni_vpextrb(addr(start_bytes + 6), xmm, 6);
                break;
            case 8: break;
            case 9: uni_vpextrb(addr(start_bytes + 8), xmm, 8); break;
            case 10: uni_vpextrw(addr(start_bytes + 8), xmm, 4); break;
            case 11:
                uni_vpextrw(addr(start_bytes + 8), xmm, 4);
                uni_vpextrb(addr(start_bytes + 10), xmm, 10);
                break;
            case 12: uni_vpextrd(addr(start_bytes + 8), xmm, 2); break;
            case 13:
                uni_vpextrd(addr(start_bytes + 8), xmm, 2);
                uni_vpextrb(addr(start_bytes + 12), xmm, 12);
                break;
            case 14:
                uni_vpextrd(addr(start_bytes + 8), xmm, 2);
                uni_vpextrw(addr(start_bytes + 12), xmm, 6);
                break;
            case 15:
                uni_vpextrd(addr(start_bytes + 8), xmm, 2);
                uni_vpextrw(addr(start_bytes + 12), xmm, 6);
                uni_vpextrb(addr(start_bytes + 14), xmm, 14);
                break;
            case 16: break;
            default: assert(!"improper store size");
        }
    }

public:
    /**
     * load_bytes_to_dword_extension is a utility function that loads
     * load_size (0 <= load_size <= 8) contiguous bytes into the Xmm/Ymm
     * register from the memory referenced by the ptr[reg + offset] address
     * and then sign- or zero-extends them to double words.
     *
     * Functionally, an invocation of load_bytes_to_dword_extension is
     * equivalent to the following:
     *
     * for (int idx = 0; idx < load_size; ++idx)
     *     vpinsrb(xmm, xmm, ptr[reg + offset + idx], idx);
     * if (is_signed) vpmovsxbd(vmm, vmm); else vpmovzxbd(vmm, vmm);
     *
     * Valid values for the load_size variable are:
     * [0..4] for the XMM version of the function,
     * [0..8] for the YMM version of the function.
     * TODO: Implement this routine for every ISA.
     */
    template <typename Vmm>
    void load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 &reg,
            int64_t offset, bool is_signed, int load_size) {
        // Ensure offset fits in 4 bytes so it can be encoded in the
        // instruction
        assert(offset >= INT_MIN && offset <= INT_MAX);
        load_bytes_to_dword_extension(
                vmm, ptr[reg + offset], is_signed, load_size);
    }

    template <typename Vmm>
    void load_bytes_to_dword_extension(const Vmm &vmm,
            const Xbyak::Address &src_addr, bool is_signed, int load_size) {

        constexpr bool is_vmm_supported = std::is_same<Vmm, Xbyak::Ymm>::value
                || std::is_same<Vmm, Xbyak::Xmm>::value;
        if (!is_vmm_supported) {
            assert(!"load_bytes_to_dword_extension() is only supported for "
                    "xmm and ymm");
            return;
        }

        constexpr bool is_xmm = std::is_same<Vmm, Xbyak::Xmm>::value;
        constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
        MAYBE_UNUSED(is_xmm);
        MAYBE_UNUSED(is_ymm);

        // Ensure extended double words fit inside Ymm (32 * load_size <= 256)
        assert(load_size >= 0 && load_size <= 8);
        // For the Xmm register, load capacity is halved (32 * load_size <= 128)
        assert(IMPLICATION(is_xmm, load_size <= 4));

        // Ensure that vector register is compatible with the ISA in hand
        assert(IMPLICATION(is_ymm, is_valid_isa(avx)));

        assert(is_valid_isa(sse41)
                && "routine is not supported for the current isa");

        // For load_size == 8/4, do the load and the extension in one go
        if (load_size == 8) {
            const auto ymm = Xbyak::Ymm(vmm.getIdx());
            if (is_signed)
                vpmovsxbd(ymm, src_addr);
            else
                vpmovzxbd(ymm, src_addr);
        } else if (load_size == 4) {
            const auto xmm = Xbyak::Xmm(vmm.getIdx());
            if (is_signed)
                uni_vpmovsxbd(xmm, src_addr);
            else
                uni_vpmovzxbd(xmm, src_addr);
        } else {
            load_bytes(vmm, src_addr, load_size);
            if (is_signed)
                uni_vpmovsxbd(vmm, vmm);
            else
                uni_vpmovzxbd(vmm, vmm);
        }
    }
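
    // Example (illustrative; ymm2 and reg_src are hypothetical): sign-extend
    // 8 s8 values from ptr[reg_src] into 8 s32 lanes of ymm2:
    //     load_bytes_to_dword_extension(Xbyak::Ymm(2), reg_src, 0, true, 8);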

    /* A utility function to store data of type type_out from the vmm register
     * to memory. store_size elements of type_out are written to memory,
     * beginning at the ptr[reg + offset] address.
     *
     * Note: the content of the Vmm register is not guaranteed to be preserved
     * after the invocation of this routine.
     *
     * TODO: Support every possible data type.
     */
    template <typename Vmm>
    void store_data(data_type_t type_out, const Vmm &vmm,
            const Xbyak::Reg64 &reg, int64_t offset, int store_size) {
        constexpr bool is_vmm_supported = std::is_same<Vmm, Xbyak::Ymm>::value
                || std::is_same<Vmm, Xbyak::Xmm>::value;
        using supported_vmm_t = typename utils::conditional<is_vmm_supported,
                Vmm, Xbyak::Ymm /*dummy*/>::type;

        if (!is_vmm_supported) {
            assert(!"store_data() is only supported for xmm and ymm");
            return;
        }
        helper_store_data(type_out, supported_vmm_t(vmm.getIdx()), reg, offset,
                store_size);
    }
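
    // Example (illustrative; ymm1 and reg_dst are hypothetical): saturate and
    // pack 8 s32 lanes of ymm1 down to s8 and store 8 bytes at ptr[reg_dst]
    // (the Ymm int8 path requires avx2, per the assert below):
    //     store_data(data_type::s8, Xbyak::Ymm(1), reg_dst, 0, 8);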

private:
    template <typename Vmm>
    void helper_store_data(data_type_t type_out, const Vmm &vmm,
            const Xbyak::Reg64 &reg, int64_t offset, int store_size) {

        assert(is_valid_isa(sse41)
                && "routine is not supported for the current isa");
        constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;

        // Owing to the lack of cross-lane operations in ISAs that are not
        // avx2-compatible, this functionality remains unimplemented for the
        // int8 data types on Ymm.
        const bool is_int8_dt
                = utils::one_of(type_out, data_type::s8, data_type::u8);
        assert(IMPLICATION(is_ymm && is_int8_dt, is_valid_isa(avx2)));

        // Ensure that vector register is compatible with the ISA in hand
        assert(IMPLICATION(is_ymm, is_valid_isa(avx)));

        MAYBE_UNUSED(is_ymm);
        MAYBE_UNUSED(is_int8_dt);

        auto ymm = Xbyak::Ymm(vmm.getIdx());
        auto xmm = Xbyak::Xmm(vmm.getIdx());

        switch (type_out) {
            case data_type::f32:
            case data_type::s32:
                store_bytes(vmm, reg, offset, sizeof(int32_t) * store_size);
                break;
            case data_type::u8:
            case data_type::s8:
                uni_vpackssdw(vmm, vmm, vmm);
                // For each 64-bit chunk y_i, the following cross-lane
                // operation on ymm yields
                // [y_3 y_2 y_1 y_0] |--> [0 0 y_2 y_0]
                if (is_ymm) vpermq(ymm, ymm, 0x08);
                if (type_out == data_type::s8)
                    uni_vpacksswb(vmm, vmm, vmm);
                else
                    uni_vpackuswb(vmm, vmm, vmm);
                store_bytes(vmm, reg, offset, store_size);
                break;
            case data_type::bf16:
                vcvtneps2bf16(xmm, vmm,
                        is_valid_isa(avx512_core_bf16) ? Xbyak::EvexEncoding
                                                       : Xbyak::VexEncoding);
                store_bytes(vmm, reg, offset, sizeof(bfloat16_t) * store_size);
                break;
            case data_type::f16:
                vcvtps2ph(xmm, vmm, _op_mxcsr);
                store_bytes(vmm, reg, offset, sizeof(float16_t) * store_size);
                break;
            default: assert(!"unsupported destination data type");
        }
    }

public:
    /* A utility function to load data of type type_in into the vmm register
     * from memory. load_size elements of type_in are read from memory,
     * beginning at the ptr[reg + offset] address.
     *
     * TODO: Support every possible data type.
     */
    template <typename Vmm>
    void load_data(data_type_t type_in, const Vmm &vmm, const Xbyak::Reg64 &reg,
            int64_t offset, int load_size) {
        // Ensure offset fits in 4 bytes so it can be encoded in the
        // instruction
        assert(offset >= INT_MIN && offset <= INT_MAX);
        load_data(type_in, vmm, ptr[reg + offset], load_size);
    }

    template <typename Vmm>
    void load_data(data_type_t type_in, const Vmm &vmm,
            const Xbyak::Address &src_addr, int load_size) {

        assert(is_valid_isa(sse41)
                && "routine is not supported for the current isa");

        switch (type_in) {
            case data_type::f32:
            case data_type::s32:
                load_bytes(vmm, src_addr, sizeof(int32_t) * load_size);
                break;
            case data_type::s8:
            case data_type::u8:
                load_bytes_to_dword_extension(
                        vmm, src_addr, type_in == data_type::s8, load_size);
                break;
            case data_type::bf16:
                load_bytes(vmm, src_addr, sizeof(bfloat16_t) * load_size);
                uni_vpmovzxwd(vmm, vmm);
                uni_vpslld(vmm, vmm, 16);
                break;
            case data_type::f16:
                load_bytes(vmm, src_addr, sizeof(float16_t) * load_size);
                vcvtph2ps(vmm,
                        typename vreg_traits<Vmm>::Vmm_lower_t(vmm.getIdx()));
                break;
            default: assert(!"unsupported source data type");
        }
    }
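
    // Example (illustrative; xmm0 and reg_src are hypothetical): load a
    // 3-element u8 tail, zero-extended to s32 lanes of xmm0:
    //     load_data(data_type::u8, Xbyak::Xmm(0), reg_src, 0, 3);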

    /* A utility function to process an f32 tail (load, store or other)
     * depending on the tail size, which is read from a Reg64 at runtime. The
     * tail size must be a value from 0 to simd_w - 1 (e.g. 3 for Xmm and 7
     * for Ymm with f32 data). The tail_process function takes an integer
     * argument that specifies the behavior for each tail size.
     *
     * Only supported for Xmm and Ymm.
     */
    template <typename Vmm>
    void runtime_tail_process(const Xbyak::Reg64 &reg_tail,
            const Xbyak::Reg64 &reg_tmp,
            const std::function<void(int)> &tail_process,
            const data_type_t data_type = data_type::f32) {
        const auto simd_w
                = vreg_traits<Vmm>::vlen / types::data_type_size(data_type);

        Xbyak::Label label_tbl, label_tbl_end;
        std::vector<Xbyak::Label> l_case(simd_w);

        mov(reg_tmp, label_tbl);
        const Xbyak::Address label_address
                = ptr[reg_tmp + reg_tail * sizeof(void *)];
        jmp(label_address, T_NEAR);

        // create the jump table
        L(label_tbl);
        for (size_t i = 0; i < simd_w; i++)
            putL(l_case[i]);

        // cases for each tail size - from 0 to simd_w - 1
        L(l_case[0]);
        jmp(label_tbl_end, T_NEAR);
        for (size_t i = 1; i < simd_w; i++) {
            L(l_case[i]);
            tail_process(i);
            jmp(label_tbl_end, T_NEAR);
        }
        L(label_tbl_end);
    }
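
    // Usage sketch (illustrative; the register and vmm names are
    // hypothetical): branch on the runtime tail length held in reg_tail and
    // load that many f32 elements:
    //     runtime_tail_process<Xbyak::Ymm>(reg_tail, reg_tmp, [&](int tail) {
    //         load_data(data_type::f32, ymm_acc, reg_src, 0, tail);
    //     });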

    DNNL_DISALLOW_COPY_AND_ASSIGN(jit_generator);

public:
    /* All uni_ instructions -- apart from uni_vzeroupper() -- will comply with
     * the max_cpu_isa argument */
    jit_generator(const char *name, void *code_ptr = nullptr,
            size_t code_size = MAX_CODE_SIZE, bool use_autogrow = true,
            cpu_isa_t max_cpu_isa = get_max_cpu_isa())
        : Xbyak::MmapAllocator(name)
        , Xbyak::CodeGenerator(code_size,
                  (code_ptr == nullptr && use_autogrow) ? Xbyak::AutoGrow
                                                        : code_ptr,
                  /*allocator=*/this)
        , max_cpu_isa_(max_cpu_isa) {}

    virtual ~jit_generator() {}

    virtual const char *name() const = 0;
    virtual const char *source_file() const = 0;

    void register_jit_code(const Xbyak::uint8 *code, size_t code_size) const {
        jit_utils::register_jit_code(code, code_size, name(), source_file());
    }

    const Xbyak::uint8 *jit_ker() const { return jit_ker_; }

    template <typename... kernel_args_t>
    void operator()(kernel_args_t... args) const {
        using jit_kernel_func_t = void (*)(const kernel_args_t... args);
        auto *fptr = (jit_kernel_func_t)jit_ker_;
        (*fptr)(std::forward<kernel_args_t>(args)...);
    }

    virtual status_t create_kernel() {
        int err_code = Xbyak::GetError();
        if (err_code == Xbyak::ERR_CANT_ALLOC) return status::out_of_memory;
        if (err_code != Xbyak::ERR_NONE) return status::runtime_error;
        generate();
        jit_ker_ = getCode();
        return (jit_ker_) ? status::success : status::runtime_error;
    }
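
    // Minimal usage sketch (illustrative, not a real oneDNN kernel; assumes
    // an abi_param1 alias for the first-argument register is available from
    // earlier in this header):
    //     struct jit_store_42_t : public jit_generator {
    //         jit_store_42_t() : jit_generator("jit_store_42_t") {}
    //         const char *name() const override { return "jit_store_42_t"; }
    //         const char *source_file() const override { return __FILE__; }
    //         void generate() override {
    //             mov(dword[abi_param1], 42); // *(int *)arg0 = 42
    //             ret();
    //         }
    //     };
    //     ...
    //     jit_store_42_t ker;
    //     int out = 0;
    //     if (ker.create_kernel() == status::success) ker(&out); // out == 42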

private:
    const cpu_isa_t max_cpu_isa_;
    const Xbyak::uint8 *getCode() {
        this->ready();
        if (!is_initialized()) return nullptr;
        const Xbyak::uint8 *code = CodeGenerator::getCode();
        register_jit_code(code, getSize());
        return code;
    }

    inline bool is_valid_isa(cpu_isa_t isa) {
        return is_subset(isa, max_cpu_isa_) && mayiuse(isa);
    }

    static inline bool is_initialized() {
        return Xbyak::GetError() == Xbyak::ERR_NONE;
    }

protected:
    virtual void generate() = 0;
    const Xbyak::uint8 *jit_ker_ = nullptr;
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif