1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_GENERATOR_HPP |
18 | #define CPU_X64_JIT_GENERATOR_HPP |
19 | |
20 | #include <limits.h> |
21 | #include <vector> |
22 | |
23 | #include "common/bit_cast.hpp" |
24 | #include "common/compiler_workarounds.hpp" |
25 | #include "common/type_helpers.hpp" |
26 | #include "common/utils.hpp" |
27 | |
28 | #include "cpu/x64/cpu_isa_traits.hpp" |
29 | |
30 | #include "cpu/jit_utils/jit_utils.hpp" |
31 | |
32 | #if defined(_WIN32) && !defined(__GNUC__) |
33 | #define STRUCT_ALIGN(al, ...) __declspec(align(al)) __VA_ARGS__ |
34 | #else |
35 | #define STRUCT_ALIGN(al, ...) __VA_ARGS__ __attribute__((__aligned__(al))) |
36 | #endif |
37 | |
38 | #if defined(_WIN32) |
39 | #define OFFSET_SHADOWSPACE 0x28 |
40 | #endif |
41 | |
42 | #if GCC_WA_NO_TREE_DOMINATOR_OPTS |
43 | #define ATTRIBUTE_OPTIMIZE __attribute__((optimize("no-tree-dominator-opts"))) |
44 | #else |
45 | #define ATTRIBUTE_OPTIMIZE |
46 | #endif |
47 | |
48 | #define DECLARE_CPU_JIT_AUX_FUNCTIONS(gen_name) \ |
49 | const char *name() const override { return STRINGIFY(gen_name); } \ |
50 | const char *source_file() const override { return __FILE__; } \ |
51 | static const char *jit_name() { \ |
52 | static constexpr char ret[] = "/oneDNN:" STRINGIFY(gen_name); \ |
53 | return ret; \ |
54 | } |
55 | |
56 | namespace dnnl { |
57 | namespace impl { |
58 | namespace cpu { |
59 | namespace x64 { |
60 | |
61 | // TODO: move this to jit_generator class? |
62 | namespace { |
63 | |
// Upper bound on the code buffer size handed to Xbyak's CodeGenerator.
typedef enum {
    MAX_CODE_SIZE = 256 * 1024,
} max_code_size_t;
67 | |
68 | // TODO: move this somewhere else? Although this is only used by jit kernels |
69 | // (Roma) |
70 | static inline int float2int(float x) { |
71 | return utils::bit_cast<int>(x); |
72 | } |
73 | |
74 | static inline void tc_configure_tile( |
75 | palette_config_t *tc, int t, int rows, int cols) { |
76 | const bool rows_ok = (size_t)t < sizeof(tc->rows) / sizeof(tc->rows[0]); |
77 | const bool cols_ok = (size_t)t < sizeof(tc->cols) / sizeof(tc->cols[0]); |
78 | if (rows_ok && cols_ok) { |
79 | tc->rows[t] = rows; |
80 | tc->cols[t] = cols; |
81 | } else { |
82 | assert(!"out of range" ); |
83 | } |
84 | } |
85 | |
86 | // TODO: A GPR class that hides ABI details from the JIT kernels and allows |
87 | // numbering registers from 0 to 14 (x86_64) / 6 (x32) (gpr0, gpr1, ...) and |
88 | // stack register (sr). |
89 | // |
90 | // This will allow using syntax like this: |
91 | // |
92 | // param = gpr0; |
93 | // reg_input = gpr0; |
94 | // reg_output = gpr1; |
95 | // ... |
96 | // |
97 | // #ifndef XBYAK64 |
98 | // mov(param, ptr[sr]) |
99 | // #endif |
100 | // |
101 | // (Roma) |
102 | |
103 | } // namespace |
104 | |
105 | #ifdef XBYAK64 |
// Callee-saved GPRs that preamble()/postamble() push and pop. RBP must stay
// first: preamble() copies rsp into it right after the first push (stack
// unwinding relies on that), and postamble() pops in reverse order.
constexpr Xbyak::Operand::Code abi_save_gpr_regs[] = {
        Xbyak::Operand::RBP,
        Xbyak::Operand::RBX,
        Xbyak::Operand::R12,
        Xbyak::Operand::R13,
        Xbyak::Operand::R14,
        Xbyak::Operand::R15,
#ifdef _WIN32
        // RDI/RSI are callee-saved in the Windows x64 ABI only.
        Xbyak::Operand::RDI,
        Xbyak::Operand::RSI,
#endif
};
118 | |
// Integer argument registers in ABI order: 4 on Windows x64, 6 on System V.
constexpr Xbyak::Operand::Code abi_param_regs[] = {
#ifdef _WIN32
        Xbyak::Operand::RCX, Xbyak::Operand::RDX, Xbyak::Operand::R8,
        Xbyak::Operand::R9
#else
        Xbyak::Operand::RDI, Xbyak::Operand::RSI, Xbyak::Operand::RDX,
        Xbyak::Operand::RCX, Xbyak::Operand::R8, Xbyak::Operand::R9
#endif
};
128 | |
// A scratch GPR guaranteed not to carry any incoming argument on the
// current ABI (RDI on Windows, RCX on System V).
constexpr Xbyak::Operand::Code abi_not_param_reg =
#ifdef _WIN32
        Xbyak::Operand::RDI;
#else
        Xbyak::Operand::RCX;
#endif
135 | |
136 | #define abi_param1 Xbyak::Reg64(abi_param_regs[0]) |
137 | #define abi_param2 Xbyak::Reg64(abi_param_regs[1]) |
138 | #define abi_param3 Xbyak::Reg64(abi_param_regs[2]) |
139 | #define abi_param4 Xbyak::Reg64(abi_param_regs[3]) |
140 | #define abi_param5 Xbyak::Reg64(abi_param_regs[4]) |
141 | #define abi_param6 Xbyak::Reg64(abi_param_regs[5]) |
142 | #define abi_not_param1 Xbyak::Reg64(abi_not_param_reg) |
143 | |
144 | #endif |
145 | |
146 | class jit_generator : public Xbyak::MmapAllocator, |
147 | public Xbyak::CodeGenerator, |
148 | public c_compatible { |
149 | public: |
150 | using c_compatible::operator new; |
151 | using c_compatible::operator new[]; |
152 | using c_compatible::operator delete; |
153 | using c_compatible::operator delete[]; |
154 | |
private:
    // Size in bytes of one XMM register spill slot.
    const size_t xmm_len = 16;
#ifdef _WIN32
    // Windows x64 ABI: xmm6..xmm15 are callee-saved and must be preserved.
    const size_t xmm_to_preserve_start = 6;
    const size_t xmm_to_preserve = 10;
#else
    // System V ABI: no callee-saved vector registers.
    const size_t xmm_to_preserve_start = 0;
    const size_t xmm_to_preserve = 0;
#endif

    const size_t num_abi_save_gpr_regs
            = sizeof(abi_save_gpr_regs) / sizeof(abi_save_gpr_regs[0]);

    // Total bytes preamble() pushes: saved GPRs plus spilled XMMs.
    const size_t size_of_abi_save_regs
            = num_abi_save_gpr_regs * rax.getBit() / 8
            + xmm_to_preserve * xmm_len;
171 | |
172 | public: |
    // Immediate operands used by the jit kernels: _cmp_* are floating-point
    // compare predicates (cmpps/vcmpps imm8 encodings), _op_* are rounding /
    // control selectors.
    enum {
        _cmp_eq_oq = 0u,
        _cmp_lt_os = 1u,
        _cmp_le_os = 2u,
        _cmp_neq_uq = 4u,
        _cmp_nlt_us = 5u,
        _cmp_nle_us = 6u,

        _op_floor = 1u,
        _op_mxcsr = 4u,
    };
184 | |
    // First integer argument register per the current ABI.
    Xbyak::Reg64 param1 = abi_param1;
    // EVEX displacement compression: rbp holds 2*EVEX_max_8b_offt so that
    // EVEX_compress_addr() can fold large offsets into a scaled index.
    const int EVEX_max_8b_offt = 0x200;
    const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;

    inline size_t get_size_of_abi_save_regs() { return size_of_abi_save_regs; }
190 | |
    // Emits the kernel prologue: spills callee-saved XMMs (Windows only),
    // pushes the callee-saved GPRs, and primes rbp for EVEX displacement
    // compression. Must be paired with postamble().
    void preamble() {
        if (xmm_to_preserve) {
            sub(rsp, xmm_to_preserve * xmm_len);
            for (size_t i = 0; i < xmm_to_preserve; ++i)
                uni_vmovdqu(ptr[rsp + i * xmm_len],
                        Xbyak::Xmm(xmm_to_preserve_start + i));
        }
        for (size_t i = 0; i < num_abi_save_gpr_regs; ++i) {
            push(Xbyak::Reg64(abi_save_gpr_regs[i]));
            // Stack magic: save rsp into rbp state to be able to unwind stack.
            if (i == 0) mov(rbp, rsp);
        }
#ifndef DNNL_ENABLE_MEM_DEBUG
        // do not use RBP in mem debug mode to enable backtracing from jit code
        if (is_valid_isa(avx512_core)) {
            mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
        }
#endif

#ifdef DNNL_ENABLE_MEM_DEBUG
        // This section poisons vector registers with NaNs to catch situations
        // when leftover trash in the register is used in a valid instruction
        // without handling trash first.
        {
            // Preserve GPR since GeMM code relies on this one.
            push(abi_not_param1);
            if (is_valid_isa(avx512_core)) {
                for (int i = 0; i < 32; i++) {
                    init_vmm(Xbyak::Zmm(i), abi_not_param1, NAN);
                }
            } else if (is_valid_isa(avx)) {
                for (int i = 0; i < 16; i++) {
                    init_vmm(Xbyak::Ymm(i), abi_not_param1, NAN);
                }
            } else {
                for (int i = 0; i < 16; i++) {
                    init_vmm(Xbyak::Xmm(i), abi_not_param1, NAN);
                }
            }
            pop(abi_not_param1);
        }
#endif
    }
234 | |
    // This function returns the address on the stack of the first argument
    // that is not passed by register
    // By default it assumes to be called after the prologue
    // Note: that we cannot use RBP inside as we override it in preamble
    // for address computation in EVEX instructions
    inline const Xbyak::RegExp get_stack_params_address(
            bool after_prolog = true) {
        int saved_regs_size = after_prolog ? get_size_of_abi_save_regs() : 0;
#ifdef _WIN32
        // Using stack layout described in MS ABI
        // (https://docs.microsoft.com/en-us/cpp/build/stack-usage?view=vs-2019)
        // here, the return address and the first 4 parameters are allocated
        // on the stack
        int first_params_and_return_addr_size = 40;
#else
        // In System V ABI, only the return address is stacked
        // before the arguments
        int first_params_and_return_addr_size = 8;
#endif
        return rsp + saved_regs_size + first_params_and_return_addr_size;
    }
256 | |
257 | void uni_vzeroupper() { |
258 | if (mayiuse(avx)) vzeroupper(); |
259 | } |
260 | |
    // Emits the kernel epilogue: restores GPRs in reverse push order, reloads
    // the spilled XMMs (Windows only), clears upper vector state, and returns.
    void postamble() {
        for (size_t i = 0; i < num_abi_save_gpr_regs; ++i)
            pop(Xbyak::Reg64(abi_save_gpr_regs[num_abi_save_gpr_regs - 1 - i]));
        if (xmm_to_preserve) {
            for (size_t i = 0; i < xmm_to_preserve; ++i)
                uni_vmovdqu(Xbyak::Xmm(xmm_to_preserve_start + i),
                        ptr[rsp + i * xmm_len]);
            add(rsp, xmm_to_preserve * xmm_len);
        }
        uni_vzeroupper();
        ret();
    }
273 | |
    // Builds a zword address [base + raw_offt] (broadcast form when bcast),
    // remapping large offsets into (small displacement + rbp * scale) so the
    // EVEX compressed 8-bit displacement encoding can be used. Relies on
    // preamble() having loaded reg_EVEX_max_8b_offt with 2*EVEX_max_8b_offt.
    template <typename T>
    Xbyak::Address EVEX_compress_addr(
            Xbyak::Reg64 base, T raw_offt, bool bcast = false) {
        using Xbyak::Address;
        using Xbyak::Reg64;
        using Xbyak::RegExp;
        using Xbyak::Zmm;

        assert(raw_offt <= INT_MAX);
        auto offt = static_cast<int>(raw_offt);

        int scale = 0;

#ifndef DNNL_ENABLE_MEM_DEBUG
        // do not use RBP in mem debug mode to enable backtracing from jit code
        if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
            // offt in [0x200, 0x600): subtract 1*(2*0x200), add rbp back once
            offt = offt - 2 * EVEX_max_8b_offt;
            scale = 1;
        } else if (3 * EVEX_max_8b_offt <= offt
                && offt < 5 * EVEX_max_8b_offt) {
            // offt in [0x600, 0xa00): subtract 2*(2*0x200), add rbp back twice
            offt = offt - 4 * EVEX_max_8b_offt;
            scale = 2;
        }
#endif

        auto re = RegExp() + base + offt;
        if (scale) re = re + reg_EVEX_max_8b_offt * scale;

        if (bcast)
            return zword_b[re];
        else
            return zword[re];
    }
307 | |
308 | Xbyak::Address make_safe_addr(const Xbyak::Reg64 ®_out, size_t offt, |
309 | const Xbyak::Reg64 &tmp_reg, bool bcast = false) { |
310 | if (offt > INT_MAX) { |
311 | mov(tmp_reg, offt); |
312 | return bcast ? ptr_b[reg_out + tmp_reg] : ptr[reg_out + tmp_reg]; |
313 | } else { |
314 | return bcast ? ptr_b[reg_out + offt] : ptr[reg_out + offt]; |
315 | } |
316 | } |
317 | |
318 | Xbyak::Address EVEX_compress_addr_safe(const Xbyak::Reg64 &base, |
319 | size_t raw_offt, const Xbyak::Reg64 ®_offt, bool bcast = false) { |
320 | if (raw_offt > INT_MAX) { |
321 | return make_safe_addr(base, raw_offt, reg_offt, bcast); |
322 | } else { |
323 | return EVEX_compress_addr(base, raw_offt, bcast); |
324 | } |
325 | } |
326 | |
327 | void safe_add(const Xbyak::Reg64 &base, size_t raw_offt, |
328 | const Xbyak::Reg64 ®_offt) { |
329 | if (raw_offt > INT_MAX) { |
330 | mov(reg_offt, raw_offt); |
331 | add(base, reg_offt); |
332 | } else { |
333 | add(base, raw_offt); |
334 | } |
335 | } |
336 | |
337 | void safe_sub(const Xbyak::Reg64 &base, size_t raw_offt, |
338 | const Xbyak::Reg64 ®_offt) { |
339 | if (raw_offt > INT_MAX) { |
340 | mov(reg_offt, raw_offt); |
341 | sub(base, reg_offt); |
342 | } else { |
343 | sub(base, raw_offt); |
344 | } |
345 | } |
346 | |
    // Disallow char-based labels completely
    void L(const char *label) = delete;
    void L(Xbyak::Label &label) { Xbyak::CodeGenerator::L(label); }

    // Binds `label` at the next `alignment`-byte boundary (pads with nops).
    void L_aligned(Xbyak::Label &label, int alignment = 16) {
        align(alignment);
        L(label);
    }
355 | |
    // x1 = x2 ^ op, dispatched per available ISA. The SSE fallback is
    // destructive (writes x2), hence the x1 == x2 requirement there.
    void uni_vpxor(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx512_core))
            vpxord(x1, x2, op);
        else if (is_valid_isa(avx))
            vpxor(x1, x2, op);
        else {
            assert(x1.isEqualIfNotInherited(x2));
            pxor(x2, op);
        }
    }
    void uni_vpxor(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx512_core))
            vpxord(x1, x2, op);
        else if (is_valid_isa(avx2))
            vpxor(x1, x2, op);
        else
            // AVX1 has no 256-bit integer xor; vxorps is bitwise-identical.
            vxorps(x1, x2, op);
    }
    void uni_vpxor(const Xbyak::Zmm &x1, const Xbyak::Zmm &x2,
            const Xbyak::Operand &op) {
        vpxord(x1, x2, op);
    }
380 | |
    // Scalar single-precision moves (AVX vmovss / SSE movss); Ymm overloads
    // operate on the low Xmm lane of the register.
    void uni_vmovss(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovss(addr, x);
        else
            movss(addr, x);
    }
    void uni_vmovss(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        if (is_valid_isa(avx))
            vmovss(x, addr);
        else
            movss(x, addr);
    }
    void uni_vmovss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2) {
        if (is_valid_isa(avx))
            vmovss(x1, x1, x2);
        else
            movss(x1, x2);
    }
    void uni_vmovss(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
        vmovss(addr, Xbyak::Xmm(x.getIdx()));
    }
    void uni_vmovss(const Xbyak::Ymm &x, const Xbyak::Address &addr) {
        vmovss(Xbyak::Xmm(x.getIdx()), addr);
    }
    void uni_vmovss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2) {
        vmovss(Xbyak::Xmm(x1.getIdx()), Xbyak::Xmm(x2.getIdx()));
    }
408 | |
    // Scalar double / low-64-bit moves (AVX vmovsd / SSE movsd).
    void uni_vmovsd(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovsd(addr, x);
        else
            movsd(addr, x);
    }
    void uni_vmovsd(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
        vmovsd(addr, x);
    }
    void uni_vmovsd(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        if (is_valid_isa(avx))
            vmovsd(x, addr);
        else
            movsd(x, addr);
    }
    void uni_vmovsd(const Xbyak::Ymm &x, const Xbyak::Address &addr) {
        vmovsd(x, addr);
    }

    // Low-64-bit packed-single moves (AVX vmovlps / SSE movlps).
    void uni_vmovlps(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovlps(addr, x);
        else
            movlps(addr, x);
    }
    void uni_vmovlps(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
        vmovlps(addr, x);
    }
    void uni_vmovlps(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        if (is_valid_isa(avx))
            vmovlps(x, addr);
        else
            movlps(x, addr);
    }
    void uni_vmovlps(const Xbyak::Ymm &x, const Xbyak::Address &addr) {
        vmovlps(x, addr);
    }
446 | |
    // Unaligned integer vector moves; Zmm uses vmovdqu32 (EVEX-only form).
    void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovdqu(addr, x);
        else
            movdqu(addr, x);
    }
    void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
        vmovdqu(addr, x);
    }
    void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Zmm &x) {
        vmovdqu32(addr, x);
    }

    void uni_vmovdqu(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        if (is_valid_isa(avx))
            vmovdqu(x, addr);
        else
            movdqu(x, addr);
    }
    void uni_vmovdqu(const Xbyak::Ymm &x, const Xbyak::Address &addr) {
        vmovdqu(x, addr);
    }
    void uni_vmovdqu(const Xbyak::Zmm &x, const Xbyak::Address &addr) {
        vmovdqu32(x, addr);
    }

    // 16-bit-element moves; below avx512_core a plain full-width move is
    // emitted instead (same bytes moved, no masking semantics needed here).
    void uni_vmovdqu16(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx512_core))
            vmovdqu16(addr, x);
        else if (is_valid_isa(avx))
            vmovups(addr, x);
        else
            movups(addr, x);
    }

    void uni_vmovdqu16(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        if (is_valid_isa(avx512_core))
            vmovdqu16(x, addr);
        else if (is_valid_isa(avx))
            vmovups(x, addr);
        else
            movups(x, addr);
    }
490 | |
    // Unaligned packed-single moves (AVX vmovups / SSE movups).
    void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovups(addr, x);
        else
            movups(addr, x);
    }
    void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
        vmovups(addr, x);
    }

    void uni_vmovups(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vmovups(x, op);
        else
            movups(x, op);
    }
    void uni_vmovups(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        vmovups(x, op);
    }

    // Masked (tail) moves: Ymm variants use a vector mask via vmaskmovps,
    // Zmm variants use an opmask register; the Zmm load zeroes masked lanes.
    void uni_vmovups_tail(const Xbyak::Address &addr, const Xbyak::Ymm &mask,
            const Xbyak::Ymm &x) {
        vmaskmovps(addr, mask, x);
    }
    void uni_vmovups_tail(const Xbyak::Ymm &x, const Xbyak::Ymm &mask,
            const Xbyak::Address &addr) {
        vmaskmovps(x, mask, addr);
    }

    void uni_vmovups_tail(const Xbyak::Address &addr, const Xbyak::Opmask &mask,
            const Xbyak::Zmm &x) {
        vmovups(addr | mask, x);
    }
    void uni_vmovups_tail(const Xbyak::Zmm &x, const Xbyak::Opmask &mask,
            const Xbyak::Address &addr) {
        vmovups(x | mask | T_z, addr);
    }
528 | |
529 | void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Xmm &x) { |
530 | if (is_valid_isa(avx)) |
531 | vmovntps(addr, x); |
532 | else |
533 | movntps(addr, x); |
534 | } |
535 | |
    // Broadcasts a scalar single into every lane of x. AVX1 vbroadcastss
    // only accepts a memory source, hence the shuffle-based fallbacks for
    // register sources.
    void uni_vbroadcastss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx2) || (is_valid_isa(avx) && op.isMEM()))
            vbroadcastss(x, op);
        else if (is_valid_isa(avx)) {
            vmovss(x, x, op);
            vshufps(x, x, x, 0x0);
        } else {
            movss(x, op);
            shufps(x, x, 0x0);
        }
    }
    void uni_vbroadcastss(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        if (op.isMEM() || is_valid_isa(avx2)) {
            vbroadcastss(x, op);
        } else {
            // AVX1, register source: replicate the low lane into the high
            // 128 bits, then splat within each 128-bit half.
            Xbyak::Xmm t(x.getIdx());
            if (!t.isEqualIfNotInherited(op)) movss(t, op);
            vinsertf128(x, x, t, 1);
            vshufps(x, x, x, 0);
        }
    }
557 | |
    // Broadcasts the byte register r into every byte lane of x. Requires at
    // least AVX2 (the unconditional assert below documents that plain AVX
    // cannot implement this; on avx512_core the assert also holds).
    void uni_vpbroadcastb(const Xbyak::Ymm &x, const Xbyak::Reg8 &r) {
        if (is_valid_isa(avx512_core))
            vpbroadcastb(x, r); // broadcast reg32 directly
        else if (is_valid_isa(avx2)) {
            // Move the byte into a vector register first; AVX2 vpbroadcastb
            // has no GPR source form.
            const Xbyak::Xmm t(x.getIdx());
            uni_vmovd(t, r.cvt32());
            vpbroadcastb(x, t);
        }
        assert(is_valid_isa(avx2) && "avx does not support vpbroadcastb" );
    }
568 | |
    // Broadcasts a 32-bit element into every dword lane of x, with
    // shuffle-based fallbacks for pre-AVX2 ISAs.
    void uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx2))
            vpbroadcastd(x, op);
        else if (is_valid_isa(avx)) {
            if (op.isMEM())
                vmovss(x, op.getAddress());
            else
                vmovss(x, x, op);
            vpshufd(x, x, 0x0);
        } else {
            movss(x, op);
            pshufd(x, x, 0x0);
        }
    }
    void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Reg32 &r) {
        if (is_valid_isa(avx512_core))
            vpbroadcastd(x, r); // broadcast reg32 directly
        else {
            // No GPR-source broadcast below avx512_core: bounce through the
            // low Xmm lane and broadcast from there.
            const Xbyak::Xmm t(x.getIdx());
            uni_vmovd(t, r);
            uni_vpbroadcastd(x, t);
        }
    }
    void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx2)) {
            vpbroadcastd(x, op);
        } else {
            // AVX1: fill the high half via vinsertf128, then splat.
            const Xbyak::Xmm t(x.getIdx());
            if (!t.isEqualIfNotInherited(op)) {
                if (op.isMEM())
                    vmovss(t, op.getAddress());
                else
                    vmovss(t, t, op);
            }
            vinsertf128(x, x, t, 1);
            vshufps(x, x, x, 0);
        }
    }
607 | |
    // x1 = shufps(x2, op, imm); the SSE path copies x2 into x1 first because
    // shufps is destructive.
    void uni_vshufps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, Xbyak::uint8 imm) {
        if (is_valid_isa(avx))
            vshufps(x1, x2, op, imm);
        else {
            movups(x1, x2);
            shufps(x1, op, imm);
        }
    }

    // x1 = pshufd(op, imm) on either ISA (pshufd is non-destructive).
    void uni_vpshufd(
            const Xbyak::Xmm &x1, const Xbyak::Operand &op, Xbyak::uint8 imm) {
        if (is_valid_isa(avx))
            vpshufd(x1, op, imm);
        else {
            pshufd(x1, op, imm);
        }
    }
626 | |
    // Approximate scalar reciprocal (rcpss family); Ymm overloads act on the
    // low Xmm lane.
    void uni_vrcpss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vrcpss(x, x, op);
        else
            rcpss(x, op);
    }
    void uni_vrcpss(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2) {
        Xbyak::Xmm x1_(x1.getIdx());
        Xbyak::Xmm x2_(x2.getIdx());
        vrcpss(x1_, x1_, x2_);
    }
    void uni_vrcpss(const Xbyak::Ymm &x, const Xbyak::Address &op) {
        Xbyak::Xmm x_(x.getIdx());
        vrcpss(x_, x_, op);
    }

    // Approximate packed reciprocal; Zmm uses the EVEX vrcp14ps form.
    void uni_vrcpps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vrcpps(x, op);
        else
            rcpps(x, op);
    }
    void uni_vrcpps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        vrcpps(x, op);
    }
    void uni_vrcpps(const Xbyak::Zmm &x, const Xbyak::Operand &op) {
        vrcp14ps(x, op);
    }
655 | |
    // x = op1 / op2. The SSE forms are destructive, hence either the
    // x == op1 requirement or the buf-scratch overload below.
    void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vdivps(x, op1, op2);
        else {
            assert(x.isEqualIfNotInherited(op1));
            divps(x, op2);
        }
    }
    void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vdivps(x, op1, op2);
    }

    void uni_vdivss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vdivss(x, op1, op2);
        else {
            assert(x.isEqualIfNotInherited(op1));
            divss(x, op2);
        }
    }

    // buf is a scratch register for the SSE path (left clobbered).
    void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx))
            vdivps(x, op1, op2);
        else {
            movups(buf, op1);
            divps(buf, op2);
            if (x.getIdx() != buf.getIdx()) { movups(x, buf); }
        }
    }

    void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
        vdivps(x, op1, op2);
    }
695 | |
    // x = op1 + op2 (packed / scalar). SSE paths require x == op1 since
    // addps/addss are destructive.
    void uni_vaddps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vaddps(x, op1, op2);
        else {
            assert(x.getIdx() == op1.getIdx());
            addps(x, op2);
        }
    }
    void uni_vaddps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vaddps(x, op1, op2);
    }
    void uni_vaddss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vaddss(x, op1, op2);
        else {
            assert(x.isEqualIfNotInherited(op1));
            addss(x, op2);
        }
    }
    void uni_vaddss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vaddss(x, op1, op2);
    }
722 | |
    // Horizontal adds; SSE forms are destructive, hence the x == op checks.
    void uni_vphaddd(const Xbyak::Xmm &x, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx)) {
            vphaddd(x, x2, op);
        } else {
            assert(x.isEqualIfNotInherited(op));
            phaddd(x, op);
        }
    }

    void uni_vhaddps(const Xbyak::Xmm &x, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx)) {
            vhaddps(x, x2, op);
        } else {
            assert(x.isEqualIfNotInherited(op));
            haddps(x, op);
        }
    }
742 | |
    // Integer sign-apply / subtract wrappers; SSE forms are destructive and
    // require x1 == x2.
    void uni_vpsignd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpsignd(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            psignd(x1, op);
        }
    }
    void uni_vpsignd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpsignd(x1, x2, op);
    }

    void uni_vpsubd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpsubd(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            psubd(x1, op);
        }
    }
    void uni_vpsubd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpsubd(x1, x2, op);
    }

    void uni_vpsubb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpsubb(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            psubb(x1, op);
        }
    }
    void uni_vpsubb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpsubb(x1, x2, op);
    }
784 | |
    // x = op1 - op2 (scalar / packed). Overloads taking buf use it as SSE
    // scratch so x need not alias op1; buf is left clobbered.
    void uni_vsubss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vsubss(x, op1, op2);
        else {
            assert(x.isEqualIfNotInherited(op1));
            subss(x, op2);
        }
    }
    void uni_vsubss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vsubss(x, Xbyak::Xmm(op1.getIdx()), Xbyak::Xmm(op2.getIdx()));
    }

    void uni_vsubss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx))
            vsubss(x, op1, op2);
        else {
            if (!buf.isEqualIfNotInherited(op1)) {
                assert(!buf.isEqualIfNotInherited(op2));
                movss(buf, op1);
            }
            subss(buf, op2);
            if (x.getIdx() != buf.getIdx()) movss(x, buf);
        }
    }

    void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vsubps(x, op1, op2);
        else {
            assert(x.isEqualIfNotInherited(op1));
            subps(x, op2);
        }
    }
    void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vsubps(x, op1, op2);
    }

    void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx))
            vsubps(x, op1, op2);
        else {
            movups(buf, op1);
            subps(buf, op2);
            if (x.getIdx() != buf.getIdx()) { movups(x, buf); }
        }
    }

    void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
        vsubps(x, op1, op2);
    }
842 | |
    // x1 = x2 * op (packed 32-bit integer multiply, low dword of product).
    void uni_vpmulld(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx)) {
            vpmulld(x1, x2, op);
        } else {
            // pmulld is destructive: copy x2 into x1 first when they differ.
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            pmulld(x1, op);
        }
    }
    void uni_vpmulld(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpmulld(x1, x2, op);
    }
856 | |
    // x = op1 * op2 (packed / scalar). buf overloads use buf as SSE scratch
    // (left clobbered) so x need not alias op1.
    void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vmulps(x, op1, op2);
        else {
            if (!x.isEqualIfNotInherited(op1)) {
                assert(!x.isEqualIfNotInherited(op2));
                movups(x, op1);
            }
            mulps(x, op2);
        }
    }
    void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx))
            vmulps(x, op1, op2);
        else {
            if (!buf.isEqualIfNotInherited(op1)) {
                assert(!buf.isEqualIfNotInherited(op2));
                movups(buf, op1);
            }
            mulps(buf, op2);
            if (x.getIdx() != buf.getIdx()) movups(x, buf);
        }
    }
    void uni_vmulps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vmulps(x, op1, op2);
    }

    void uni_vmulss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vmulss(x, op1, op2);
        else {
            assert(x.isEqualIfNotInherited(op1));
            mulss(x, op2);
        }
    }
    void uni_vmulss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx))
            vmulss(x, op1, op2);
        else {
            if (!buf.isEqualIfNotInherited(op1)) {
                assert(!buf.isEqualIfNotInherited(op2));
                movss(buf, op1);
            }
            mulss(buf, op2);
            if (x.getIdx() != buf.getIdx()) movss(x, buf);
        }
    }
    void uni_vmulss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Address &op2) {
        vmulss(x, Xbyak::Xmm(op1.getIdx()), op2);
    }
    void uni_vmulss(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Ymm &op2) {
        vmulss(x, Xbyak::Xmm(op1.getIdx()), Xbyak::Xmm(op2.getIdx()));
    }
917 | |
918 | void uni_vfmadd132ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, |
919 | const Xbyak::Operand &op, const Xbyak::Xmm &buf) { |
920 | if (is_valid_isa(avx2)) |
921 | vfmadd132ps(x1, x2, op); |
922 | else if (is_valid_isa(avx)) { |
923 | assert(x1.getIdx() != x2.getIdx()); |
924 | vmulps(x1, x1, op); |
925 | vaddps(x1, x1, x2); |
926 | } else { |
927 | assert(buf.getIdx() != x2.getIdx()); |
928 | if (x1.getIdx() != buf.getIdx()) movups(buf, x1); |
929 | mulps(buf, op); |
930 | addps(buf, x2); |
931 | if (x1.getIdx() != buf.getIdx()) movups(x1, buf); |
932 | } |
933 | } |
    // Convenience overload that uses x1 itself as the scratch register.
    void uni_vfmadd132ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x1 gets overridden by x1*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfmadd132ps(x1, x2, op, x1);
    }
941 | |
    // x1 = x1 * op + x2 for Ymm; buf is scratch (clobbered) on plain AVX.
    void uni_vfmadd132ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &buf) {
        if (is_valid_isa(avx2))
            vfmadd132ps(x1, x2, op);
        else {
            // AVX is assumed because of YMM
            assert(buf.getIdx() != x2.getIdx());
            vmulps(buf, x1, op);
            vaddps(x1, buf, x2);
        }
    }
    void uni_vfmadd132ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        // Note: AVX: x1 gets overridden by x1*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfmadd132ps(x1, x2, op, x1);
    }
960 | |
    // x1 = x1 * x2 + op, with buf as scratch (clobbered) on pre-FMA ISAs.
    void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx2))
            vfmadd213ps(x1, x2, op);
        else if (is_valid_isa(avx)) {
            assert(!buf.isEqualIfNotInherited(op));
            vmulps(buf, x1, x2);
            vaddps(x1, buf, op);
        } else {
            assert(!buf.isEqualIfNotInherited(op));
            if (x1.getIdx() != buf.getIdx()) movups(buf, x1);
            mulps(buf, x2);
            addps(buf, op);
            if (x1.getIdx() != buf.getIdx()) movups(x1, buf);
        }
    }
    // Convenience overload that uses x1 itself as the scratch register.
    void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x1 gets overridden by x1*x2
        // This is incorrect if x1 == op
        if (!is_valid_isa(avx2)) assert(!x1.isEqualIfNotInherited(op));
        uni_vfmadd213ps(x1, x2, op, x1);
    }
984 | |
985 | void uni_vfmadd213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, |
986 | const Xbyak::Operand &op, const Xbyak::Ymm &buf) { |
987 | if (is_valid_isa(avx2)) |
988 | vfmadd213ps(x1, x2, op); |
989 | else { |
990 | // AVX is assumed because of YMM |
991 | assert(!buf.isEqualIfNotInherited(op)); |
992 | vmulps(buf, x1, x2); |
993 | vaddps(x1, buf, op); |
994 | } |
995 | } |
996 | void uni_vfmadd213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, |
997 | const Xbyak::Operand &op) { |
998 | // Note: AVX: x1 gets overridden by x1*x2 |
999 | // This is incorrect if x1 == op |
1000 | if (!is_valid_isa(avx2)) assert(!x1.isEqualIfNotInherited(op)); |
1001 | uni_vfmadd213ps(x1, x2, op, x1); |
1002 | } |
1003 | |
1004 | void uni_vfmadd213ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, |
1005 | const Xbyak::Operand &op, const Xbyak::Xmm &buf) { |
1006 | if (is_valid_isa(avx2)) |
1007 | vfmadd213ss(x1, x2, op); |
1008 | else if (is_valid_isa(avx)) { |
1009 | assert(!buf.isEqualIfNotInherited(op)); |
1010 | vmulss(buf, x1, x2); |
1011 | vaddss(x1, buf, op); |
1012 | } else { |
1013 | assert(!buf.isEqualIfNotInherited(op)); |
1014 | if (x1.getIdx() != buf.getIdx()) movss(buf, x1); |
1015 | mulss(buf, x2); |
1016 | addss(x1, op); |
1017 | if (x1.getIdx() != buf.getIdx()) movss(x1, buf); |
1018 | } |
1019 | } |
    // Scratch-less scalar FMA213 convenience overload; uses x1 as scratch.
    void uni_vfmadd213ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x1 gets overridden by x1*x2
        // This is incorrect if x1 == op
        if (!is_valid_isa(avx2)) assert(!x1.isEqualIfNotInherited(op));
        uni_vfmadd213ss(x1, x2, op, x1);
    }

    // Ymm flavor of scalar FMA213: x1 = x1 * x2 + op (low element only).
    void uni_vfmadd213ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &buf) {
        if (is_valid_isa(avx2))
            vfmadd213ss(x1, x2, op);
        else {
            // AVX is assumed because of YMM
            assert(!buf.isEqualIfNotInherited(op));
            vmulss(buf, x1, x2);
            vaddss(x1, buf, op);
        }
    }
    void uni_vfmadd213ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        // Note: AVX: x1 gets overridden by x1*x2
        // This is incorrect if x1 == op
        if (!is_valid_isa(avx2)) assert(!x1.isEqualIfNotInherited(op));
        uni_vfmadd213ss(x1, x2, op, x1);
    }
1046 | |
    // Fused multiply-accumulate, 231 form: x1 += x2 * op (packed fp32).
    // Non-FMA fallbacks compute x2*op in buf, then add into x1; buf must
    // not alias the accumulator x1.
    void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx2))
            vfmadd231ps(x1, x2, op);
        else if (is_valid_isa(avx)) {
            assert(buf.getIdx() != x1.getIdx());
            vmulps(buf, x2, op);
            vaddps(x1, x1, buf);
        } else {
            assert(buf.getIdx() != x1.getIdx());
            if (x2.getIdx() != buf.getIdx()) movups(buf, x2);
            mulps(buf, op);
            addps(x1, buf);
        }
    }
    // Scratch-less convenience overload; uses x2 as scratch.
    void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfmadd231ps(x1, x2, op, x2);
    }

    // Ymm flavor of FMA231: x1 += x2 * op.
    void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &buf) {
        if (is_valid_isa(avx2))
            vfmadd231ps(x1, x2, op);
        else {
            // AVX is assumed because of YMM
            assert(buf.getIdx() != x1.getIdx());
            vmulps(buf, x2, op);
            vaddps(x1, x1, buf);
        }
    }
    void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        // Note: AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfmadd231ps(x1, x2, op, x2);
    }
1088 | |
    // Fused multiply-accumulate, 231 form, scalar fp32: x1 += x2 * op
    // (low element only). buf must not alias the accumulator x1.
    void uni_vfmadd231ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx2))
            vfmadd231ss(x1, x2, op);
        else if (is_valid_isa(avx)) {
            assert(buf.getIdx() != x1.getIdx());
            vmulss(buf, x2, op);
            vaddss(x1, x1, buf);
        } else {
            assert(buf.getIdx() != x1.getIdx());
            if (x2.getIdx() != buf.getIdx()) movss(buf, x2);
            mulss(buf, op);
            addss(x1, buf);
        }
    }
    // Scratch-less convenience overload; uses x2 as scratch.
    void uni_vfmadd231ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfmadd231ss(x1, x2, op, x2);
    }

    // Ymm flavor of scalar FMA231. The avx2 branch re-wraps the registers
    // as Xmm because the scalar FMA instruction encodes XMM operands only.
    void uni_vfmadd231ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &buf) {
        if (is_valid_isa(avx2))
            vfmadd231ss(Xbyak::Xmm(x1.getIdx()), Xbyak::Xmm(x2.getIdx()), op);
        else {
            // AVX is assumed because of YMM
            assert(buf.getIdx() != x1.getIdx());
            vmulss(buf, x2, op);
            vaddss(x1, x1, buf);
        }
    }
    void uni_vfmadd231ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        // Note: AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfmadd231ss(x1, x2, op, x2);
    }
1130 | |
    // Fused negated multiply-accumulate, 231 form: x1 = x1 - x2 * op
    // (packed fp32). Fallbacks compute x2*op in buf, then subtract; buf
    // must not alias the accumulator x1.
    void uni_vfnmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx2))
            vfnmadd231ps(x1, x2, op);
        else if (is_valid_isa(avx)) {
            assert(buf.getIdx() != x1.getIdx());
            vmulps(buf, x2, op);
            vsubps(x1, x1, buf);
        } else {
            assert(buf.getIdx() != x1.getIdx());
            if (x2.getIdx() != buf.getIdx()) movups(buf, x2);
            mulps(buf, op);
            subps(x1, buf);
        }
    }
    // Scratch-less convenience overload; uses x2 as scratch.
    void uni_vfnmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfnmadd231ps(x1, x2, op, x2);
    }

    // Ymm flavor: x1 = x1 - x2 * op.
    void uni_vfnmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &buf) {
        if (is_valid_isa(avx2))
            vfnmadd231ps(x1, x2, op);
        else {
            // AVX is assumed because of YMM
            assert(buf.getIdx() != x1.getIdx());
            vmulps(buf, x2, op);
            vsubps(x1, x1, buf);
        }
    }
    void uni_vfnmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        // Note: AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfnmadd231ps(x1, x2, op, x2);
    }
1172 | |
    // Fused negated multiply-accumulate, 231 form, scalar fp32:
    // x1 = x1 - x2 * op (low element only). buf must not alias x1.
    void uni_vfnmadd231ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx2))
            vfnmadd231ss(x1, x2, op);
        else if (is_valid_isa(avx)) {
            assert(buf.getIdx() != x1.getIdx());
            vmulss(buf, x2, op);
            vsubss(x1, x1, buf);
        } else {
            assert(buf.getIdx() != x1.getIdx());
            if (x2.getIdx() != buf.getIdx()) movss(buf, x2);
            mulss(buf, op);
            subss(x1, buf);
        }
    }
    // Scratch-less convenience overload; uses x2 as scratch.
    void uni_vfnmadd231ss(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfnmadd231ss(x1, x2, op, x2);
    }

    // Ymm flavor: x1 = x1 - x2 * op (low element only).
    void uni_vfnmadd231ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &buf) {
        if (is_valid_isa(avx2))
            vfnmadd231ss(x1, x2, op);
        else {
            // AVX is assumed because of YMM
            assert(buf.getIdx() != x1.getIdx());
            vmulss(buf, x2, op);
            vsubss(x1, x1, buf);
        }
    }
    void uni_vfnmadd231ss(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        // Note: AVX: x2 gets overridden by x2*op
        // This is incorrect if x1 == x2
        if (!is_valid_isa(avx2)) assert(x1 != x2);
        uni_vfnmadd231ss(x1, x2, op, x2);
    }
1214 | |
    // Fused multiply-subtract, 213 form: x1 = x1 * x2 - op (packed fp32).
    // Fallbacks build the product in buf, then subtract op; buf must not
    // alias op (op is read after buf is written).
    void uni_vfmsub213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &buf) {
        if (is_valid_isa(avx2))
            vfmsub213ps(x1, x2, op);
        else if (is_valid_isa(avx)) {
            assert(!buf.isEqualIfNotInherited(op));
            vmulps(buf, x1, x2);
            vsubps(x1, buf, op);
        } else {
            assert(!buf.isEqualIfNotInherited(op));
            if (buf.getIdx() != x1.getIdx()) movups(buf, x1);
            mulps(buf, x2);
            subps(buf, op);
            if (buf.getIdx() != x1.getIdx()) movups(x1, buf);
        }
    }
    // Scratch-less convenience overload; uses x1 as scratch.
    void uni_vfmsub213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        // Note: SSE, AVX: x1 gets overridden by x1*x2
        // This is incorrect if x1 == op
        if (!is_valid_isa(avx2)) assert(!x1.isEqualIfNotInherited(op));
        uni_vfmsub213ps(x1, x2, op, x1);
    }

    // Ymm flavor: x1 = x1 * x2 - op.
    void uni_vfmsub213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &buf) {
        if (is_valid_isa(avx2))
            vfmsub213ps(x1, x2, op);
        else {
            // AVX is assumed because of YMM
            assert(!buf.isEqualIfNotInherited(op));
            vmulps(buf, x1, x2);
            vsubps(x1, buf, op);
        }
    }
    void uni_vfmsub213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        // Note: AVX: x1 gets overridden by x1*x2
        // This is incorrect if x1 == op
        if (!is_valid_isa(avx2)) assert(!x1.isEqualIfNotInherited(op));
        uni_vfmsub213ps(x1, x2, op, x1);
    }
1257 | |
    // Packed fp32 square root: x = sqrt(op), VEX or legacy SSE encoding.
    void uni_vsqrtps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vsqrtps(x, op);
        else
            sqrtps(x, op);
    }
    // Ymm flavor; AVX availability is implied by the YMM operand.
    void uni_vsqrtps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        vsqrtps(x, op);
    }
1267 | |
    // Packed int32 add: x1 = x2 + op. The SSE fallback is destructive, so
    // x2 is copied into x1 first when they differ.
    void uni_vpaddd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpaddd(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            paddd(x1, op);
        }
    }
    void uni_vpaddd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpaddd(x1, x2, op);
    }

    // Packed int8 add: x1 = x2 + op.
    void uni_vpaddb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpaddb(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            paddb(x1, op);
        }
    }
    void uni_vpaddb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpaddb(x1, x2, op);
    }

    // Multiply int16 pairs and horizontally add to int32: x1 = madd(x2, op).
    void uni_vpmaddwd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmaddwd(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            pmaddwd(x1, op);
        }
    }
    void uni_vpmaddwd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpmaddwd(x1, x2, op);
    }

    // Multiply u8 x s8 pairs and horizontally add to saturated int16.
    void uni_vpmaddubsw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmaddubsw(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            pmaddubsw(x1, op);
        }
    }
    void uni_vpmaddubsw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpmaddubsw(x1, x2, op);
    }
1323 | |
    // Bitwise AND of fp32 vectors: x1 = x2 & op. The SSE fallback is
    // destructive and requires x1 == x2.
    void uni_vandps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vandps(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            andps(x1, op);
        }
    }
    // For 512-bit registers the integer form (vpandd) is used: the fp form
    // on ZMM needs AVX512DQ, which plain avx512_core may lack.
    void uni_vandps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        if (!is_valid_isa(avx512_core) || x1.getBit() < 512)
            vandps(x1, x2, op);
        else
            vpandd(x1, x2, op);
    }

    // Bitwise OR of fp32 vectors: x1 = x2 | op. SSE fallback requires
    // x1 == x2.
    void uni_vorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vorps(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            orps(x1, op);
        }
    }
    // 512-bit: fall back to the integer form for the same reason as above.
    void uni_vorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        if (!is_valid_isa(avx512_core) || x1.getBit() < 512)
            vorps(x1, x2, op);
        else
            vpord(x1, x2, op);
    }

    // Bitwise XOR of fp32 vectors: x1 = x2 ^ op. Unlike AND/OR above, the
    // SSE fallback copies x2 into x1 first, so x1 != x2 is allowed.
    void uni_vxorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vxorps(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) { uni_vmovups(x1, x2); }
            xorps(x1, op);
        }
    }
    // 512-bit: fall back to the integer form for the same reason as above.
    void uni_vxorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        if (!is_valid_isa(avx512_core) || x1.getBit() < 512)
            vxorps(x1, x2, op);
        else
            vpxord(x1, x2, op);
    }
1374 | |
1375 | void uni_vpslld( |
1376 | const Xbyak::Xmm &x, const Xbyak::Operand &op, const int imm) { |
1377 | if (is_valid_isa(avx)) |
1378 | vpslld(x, op, imm); |
1379 | else { |
1380 | assert(x.isEqualIfNotInherited(op)); |
1381 | pslld(x, imm); |
1382 | } |
1383 | } |
    // Ymm flavor of the packed int32 left shift; AVX2 implied by YMM.
    void uni_vpslld(
            const Xbyak::Ymm &x, const Xbyak::Operand &op, const int imm) {
        vpslld(x, op, imm);
    }

    // Packed int32 logical right shift by immediate: x = op >> imm.
    // SSE psrld is destructive, so op is copied into x first when needed.
    void uni_vpsrld(
            const Xbyak::Xmm &x, const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx))
            vpsrld(x, op, imm);
        else {
            if (!x.isEqualIfNotInherited(op)) uni_vmovups(x, op);
            psrld(x, imm);
        }
    }
    void uni_vpsrld(
            const Xbyak::Ymm &x, const Xbyak::Operand &op, const int imm) {
        vpsrld(x, op, imm);
    }
1402 | |
    // Packed fp32 maximum: x = max(op1, op2). SSE fallback is destructive,
    // so op1 is copied into x first when they differ.
    void uni_vmaxps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vmaxps(x, op1, op2);
        else {
            if (!x.isEqualIfNotInherited(op1)) movups(x, op1);
            maxps(x, op2);
        }
    }
    void uni_vmaxps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vmaxps(x, op1, op2);
    }

    // Scalar fp32 maximum (low element): x = max(op1, op2).
    void uni_vmaxss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vmaxss(x, op1, op2);
        else {
            if (!x.isEqualIfNotInherited(op1)) movss(x, op1);
            maxss(x, op2);
        }
    }

    // Packed fp32 minimum: x = min(op1, op2).
    void uni_vminps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vminps(x, op1, op2);
        else {
            if (!x.isEqualIfNotInherited(op1)) movups(x, op1);
            minps(x, op2);
        }
    }

    void uni_vminps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        vminps(x, op1, op2);
    }

    // Scalar fp32 minimum (low element): x = min(op1, op2).
    void uni_vminss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
            const Xbyak::Operand &op2) {
        if (is_valid_isa(avx))
            vminss(x, op1, op2);
        else {
            if (!x.isEqualIfNotInherited(op1)) movss(x, op1);
            minss(x, op2);
        }
    }
1451 | |
    // Sign-extend packed int8 to int32.
    void uni_vpmovsxbd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmovsxbd(x, op);
        else
            pmovsxbd(x, op);
    }

    void uni_vpmovsxbd(const Xbyak::Ymm &y, const Xbyak::Operand &op) {
        vpmovsxbd(y, op);
    }

    // Zero-extend packed uint8 to int32.
    void uni_vpmovzxbd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmovzxbd(x, op);
        else
            pmovzxbd(x, op);
    }
    void uni_vpmovzxbd(const Xbyak::Ymm &y, const Xbyak::Operand &op) {
        vpmovzxbd(y, op);
    }
1472 | |
    // Packed fp32 compare producing an all-ones/all-zeros mask:
    // x1 = (x2 CMP op) per cmp_predicate (see the _cmp_* encodings).
    // SSE fallback is destructive, so x2 is copied into x1 first.
    void uni_vcmpps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, int cmp_predicate) {
        if (is_valid_isa(avx))
            vcmpps(x1, x2, op, cmp_predicate);
        else {
            if (x1.getIdx() != x2.getIdx()) uni_vmovups(x1, x2);
            cmpps(x1, op, cmp_predicate);
        }
    }
    void uni_vcmpps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, int cmp_predicate) {
        vcmpps(x1, x2, op, cmp_predicate);
    }

    // Set ZF/CF from sign-bit tests of x1 against op (vtestps), or bitwise
    // test (ptest) pre-AVX.
    void uni_vtestps(const Xbyak::Xmm &x1, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vtestps(x1, op);
        else
            ptest(x1, op);
    }

    // No ZMM form exists for vtestps, hence the assert.
    void uni_vtestps(const Xbyak::Ymm &x1, const Xbyak::Operand &op) {
        assert(!(x1.isZMM() || op.isZMM()));
        vtestps(x1, op);
    }
1498 | |
    // Variable blend: x1 = msk ? op : x2, per-element by mask sign bit.
    // SSE blendvps uses xmm0 as an implicit mask and is destructive, hence
    // the asserts in the fallback.
    void uni_vblendvps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const Xbyak::Xmm &msk) {
        if (is_valid_isa(avx))
            vblendvps(x1, x2, op, msk);
        else {
            assert(x1.getIdx() == x2.getIdx());
            assert(msk.getIdx() == 0);
            blendvps(x1, op);
        }
    }
    void uni_vblendvps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const Xbyak::Ymm &msk) {
        vblendvps(x1, x2, op, msk);
    }

    // Immediate blend: x1 = imm-selected mix of x2 and op. No ZMM form
    // exists for (v)blendps, hence the assert.
    void uni_vblendps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const int imm) {
        assert(!x1.isZMM() && !x2.isZMM());

        if (is_valid_isa(avx))
            vblendps(x1, x2, op, imm);
        else {
            assert(x1.getIdx() == x2.getIdx());
            blendps(x1, op, imm);
        }
    }
1525 | |
    // Round packed fp32 per imm. On AVX512, vrndscaleps replaces vroundps;
    // only the low 2 bits (rounding mode) are passed, no scaling.
    void uni_vroundps(
            const Xbyak::Xmm &x, const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx512_core))
            vrndscaleps(x, op, imm & 0x3);
        else if (is_valid_isa(avx))
            vroundps(x, op, imm);
        else
            roundps(x, op, imm);
    }

    void uni_vroundps(
            const Xbyak::Ymm &x, const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx512_core))
            vrndscaleps(x, op, imm & 0x3);
        else
            vroundps(x, op, imm);
    }

    // ZMM form: AVX512 only, so vrndscaleps unconditionally.
    void uni_vroundps(
            const Xbyak::Zmm &x, const Xbyak::Operand &op, const int imm) {
        vrndscaleps(x, op, imm & 0x3);
    }
1548 | |
    // Convert packed fp32 to int32 (rounds per MXCSR).
    void uni_vcvtps2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vcvtps2dq(x, op);
        else
            cvtps2dq(x, op);
    }
    void uni_vcvtps2dq(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        vcvtps2dq(x, op);
    }

    // Convert packed int32 to fp32.
    void uni_vcvtdq2ps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vcvtdq2ps(x, op);
        else
            cvtdq2ps(x, op);
    }

    void uni_vcvtdq2ps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        vcvtdq2ps(x, op);
    }
1569 | |
    // Convert packed fp16 to fp32. Requires at least AVX2 (F16C path);
    // the legacy vcvtph2ps has no embedded-broadcast support, hence the
    // assert on the memory operand.
    void uni_vcvtph2psx(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        assert(is_valid_isa(avx2));
        if (is_valid_isa(avx512_core_fp16))
            vcvtph2psx(x, op);
        else if (is_valid_isa(avx2)) {
            assert(IMPLICATION(op.isMEM(), !op.getAddress().isBroadcast()));
            vcvtph2ps(x, op);
        }
    }

    // Convert packed fp32 to fp16, memory source: fp16 ISA only.
    void uni_vcvtps2phx(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        assert(is_valid_isa(avx512_core_fp16));
        vcvtps2phx(x, addr);
    }

    // Convert packed fp32 to fp16, store to memory; _op_mxcsr selects
    // rounding from MXCSR for the legacy vcvtps2ph encoding.
    void uni_vcvtps2phx(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        assert(is_valid_isa(avx2));
        vcvtps2ph(addr, x, _op_mxcsr);
    }

    // Convert packed fp32 to fp16, register to register.
    void uni_vcvtps2phx(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2) {
        assert(is_valid_isa(avx2));
        if (is_valid_isa(avx512_core_fp16))
            vcvtps2phx(x1, x2);
        else if (is_valid_isa(avx2))
            vcvtps2ph(x1, x2, _op_mxcsr);
    }
1597 | |
    // Extract fp32 sign bits into a GPR bitmask.
    void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Xmm &x2) {
        movmskps(x1.cvt64(), x2);
    }
    void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Ymm &x2) {
        vmovmskps(x1, x2);
    }
1604 | |
    // 32-bit moves between GPR/memory and the low dword of an XMM.
    void uni_vmovd(const Xbyak::Reg32 &r, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovd(r, x);
        else
            movd(r, x);
    }
    void uni_vmovd(const Xbyak::Xmm &x, const Xbyak::Reg32 &r) {
        if (is_valid_isa(avx))
            vmovd(x, r);
        else
            movd(x, r);
    }
    void uni_vmovd(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovd(addr, x);
        else
            movd(addr, x);
    }

    void uni_vmovd(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        if (is_valid_isa(avx))
            vmovd(x, addr);
        else
            movd(x, addr);
    }

    // 64-bit moves between GPR/memory and the low qword of an XMM.
    void uni_vmovq(const Xbyak::Xmm &x, const Xbyak::Reg64 &r) {
        if (is_valid_isa(avx))
            vmovq(x, r);
        else
            movq(x, r);
    }
    void uni_vmovq(const Xbyak::Reg64 &r, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovq(r, x);
        else
            movq(r, x);
    }
    void uni_vmovq(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
        if (is_valid_isa(avx))
            vmovq(addr, x);
        else
            movq(addr, x);
    }
    void uni_vmovq(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
        if (is_valid_isa(avx))
            vmovq(x, addr);
        else
            movq(x, addr);
    }
1655 | |
    // Pack int32 to int16 with signed saturation. The SSE fallback is
    // destructive and requires x1 == x2.
    void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpackssdw(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            packssdw(x1, op);
        }
    }
    void uni_vpackssdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpackssdw(x1, x2, op);
    }

    // Pack int16 to uint8 with unsigned saturation; requires x1 == x2 on
    // SSE.
    void uni_vpackuswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpackuswb(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            packuswb(x1, op);
        }
    }
    void uni_vpackuswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpackuswb(x1, x2, op);
    }

    // Pack int16 to int8 with signed saturation; requires x1 == x2 on SSE.
    void uni_vpacksswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpacksswb(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            packsswb(x1, op);
        }
    }
    void uni_vpacksswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpacksswb(x1, x2, op);
    }
1697 | |
    // Insert a byte from op into lane imm of x2, result in x1. The SSE
    // fallback is destructive and requires x1 == x2 (same for the insert
    // variants below).
    void uni_vpinsrb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx))
            vpinsrb(x1, x2, op, imm);
        else {
            assert(x1.getIdx() == x2.getIdx());
            pinsrb(x1, op, imm);
        }
    }

    void uni_vpinsrb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const int imm) {
        vpinsrb(x1, x2, op, imm);
    }

    // Insert a dword into lane imm.
    void uni_vpinsrd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx))
            vpinsrd(x1, x2, op, imm);
        else {
            assert(x1.getIdx() == x2.getIdx());
            pinsrd(x1, op, imm);
        }
    }
    void uni_vpinsrd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const int imm) {
        vpinsrd(x1, x2, op, imm);
    }

    // Insert a qword into lane imm.
    void uni_vpinsrq(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx))
            vpinsrq(x1, x2, op, imm);
        else {
            assert(x1.getIdx() == x2.getIdx());
            pinsrq(x1, op, imm);
        }
    }
    void uni_vpinsrq(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const int imm) {
        vpinsrq(x1, x2, op, imm);
    }

    // Insert a word into lane imm.
    void uni_vpinsrw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx))
            vpinsrw(x1, x2, op, imm);
        else {
            assert(x1.getIdx() == x2.getIdx());
            pinsrw(x1, op, imm);
        }
    }
    void uni_vpinsrw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op, const int imm) {
        vpinsrw(x1, x2, op, imm);
    }
1754 | |
    // Extract lane imm of x into op (GPR or memory): byte / word / dword /
    // qword variants.
    void uni_vpextrb(
            const Xbyak::Operand &op, const Xbyak::Xmm &x, const int imm) {
        if (is_valid_isa(avx))
            vpextrb(op, x, imm);
        else
            pextrb(op, x, imm);
    }

    void uni_vpextrb(
            const Xbyak::Operand &op, const Xbyak::Ymm &x, const int imm) {
        vpextrb(op, x, imm);
    }

    void uni_vpextrw(
            const Xbyak::Operand &op, const Xbyak::Xmm &x, const int imm) {
        if (is_valid_isa(avx))
            vpextrw(op, x, imm);
        else
            pextrw(op, x, imm);
    }
    void uni_vpextrw(
            const Xbyak::Operand &op, const Xbyak::Ymm &x, const int imm) {
        vpextrw(op, x, imm);
    }

    void uni_vpextrd(
            const Xbyak::Operand &op, const Xbyak::Xmm &x, const int imm) {
        if (is_valid_isa(avx))
            vpextrd(op, x, imm);
        else
            pextrd(op, x, imm);
    }
    void uni_vpextrd(
            const Xbyak::Operand &op, const Xbyak::Ymm &x, const int imm) {
        vpextrd(op, x, imm);
    }

    void uni_vpextrq(
            const Xbyak::Operand &op, const Xbyak::Xmm &x, const int imm) {
        if (is_valid_isa(avx))
            vpextrq(op, x, imm);
        else
            pextrq(op, x, imm);
    }
    void uni_vpextrq(
            const Xbyak::Operand &op, const Xbyak::Ymm &x, const int imm) {
        vpextrq(op, x, imm);
    }
1803 | |
    // Packed signed int32 maximum: x1 = max(x2, op). SSE fallback copies
    // x2 into x1 first since pmaxsd is destructive.
    void uni_vpmaxsd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmaxsd(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            pmaxsd(x1, op);
        }
    }

    void uni_vpmaxsd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpmaxsd(x1, x2, op);
    }

    // Packed signed int8 maximum: x1 = max(x2, op).
    void uni_vpmaxsb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmaxsb(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            pmaxsb(x1, op);
        }
    }

    void uni_vpmaxsb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpmaxsb(x1, x2, op);
    }

    // Packed unsigned int8 minimum: x1 = min(x2, op).
    void uni_vpminub(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpminub(x1, x2, op);
        else {
            if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2);
            pminub(x1, op);
        }
    }
1843 | |
1844 | // Note: instructions below are used in custom forks and are put here to |
1845 | // decrease maintenance burden on user side. |
1846 | // Once the instruction becomes used in one of oneDNN implementations, |
1847 | // please, move out of this section. |
    // Byte shuffle of x2 by control vector op. SSE fallback is destructive
    // and requires x1 == x2.
    void uni_vpshufb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpshufb(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            pshufb(x1, op);
        }
    }
    void uni_vpshufb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpshufb(x1, x2, op);
    }

    // Bitwise integer AND: x1 = x2 & op. Uses the EVEX vpandd form for
    // 512-bit registers (VEX vpand has no ZMM encoding).
    void uni_vpand(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
            const Xbyak::Operand &op) {
        if (is_valid_isa(avx512_core) && x1.getBit() == 512)
            vpandd(x1, x2, op);
        else if (is_valid_isa(avx))
            vpand(x1, x2, op);
        else {
            assert(x1.getIdx() == x2.getIdx());
            pand(x1, op);
        }
    }
1873 | |
    // Whole-register byte shift left by imm. SSE pslldq is destructive,
    // hence the x == op requirement in the fallback.
    void uni_vpslldq(
            const Xbyak::Xmm &x, const Xbyak::Operand &op, const int imm) {
        if (is_valid_isa(avx))
            vpslldq(x, op, imm);
        else {
            assert(x.isEqualIfNotInherited(op));
            pslldq(x, imm);
        }
    }
    void uni_vpslldq(
            const Xbyak::Ymm &x, const Xbyak::Operand &op, const int imm) {
        vpslldq(x, op, imm);
    }

    // Sign-extend packed int16 to int32.
    void uni_vpmovsxwd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmovsxwd(x, op);
        else
            pmovsxwd(x, op);
    }
    void uni_vpmovsxwd(const Xbyak::Ymm &y, const Xbyak::Operand &op) {
        vpmovsxwd(y, op);
    }

    // Sign-extend packed int32 to int64.
    void uni_vpmovsxdq(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmovsxdq(x, op);
        else
            pmovsxdq(x, op);
    }
    void uni_vpmovsxdq(const Xbyak::Ymm &y, const Xbyak::Operand &op) {
        vpmovsxdq(y, op);
    }

    // Zero-extend packed uint16 to int32.
    void uni_vpmovzxwd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
        if (is_valid_isa(avx))
            vpmovzxwd(x, op);
        else
            pmovzxwd(x, op);
    }
    void uni_vpmovzxwd(const Xbyak::Ymm &y, const Xbyak::Operand &op) {
        vpmovzxwd(y, op);
    }
1917 | |
1918 | void uni_vpcmpeqd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, |
1919 | const Xbyak::Operand &op) { |
1920 | if (is_valid_isa(avx)) |
1921 | vpcmpeqd(x1, x2, op); |
1922 | else { |
1923 | if (x1.getIdx() != x2.getIdx()) uni_vmovups(x1, x2); |
1924 | pcmpeqd(x1, op); |
1925 | } |
1926 | } |
1927 | void uni_vpcmpeqd(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, |
1928 | const Xbyak::Operand &op) { |
1929 | vpcmpeqd(x1, x2, op); |
1930 | } |
1931 | |
1932 | void uni_vpackusdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, |
1933 | const Xbyak::Operand &op) { |
1934 | if (is_valid_isa(avx)) |
1935 | vpackusdw(x1, x2, op); |
1936 | else { |
1937 | assert(x1.getIdx() == x2.getIdx()); |
1938 | packusdw(x1, op); |
1939 | } |
1940 | } |
    // Ymm overload: VEX-only, requires AVX.
    void uni_vpackusdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
            const Xbyak::Operand &op) {
        vpackusdw(x1, x2, op);
    }
1945 | |
1946 | void uni_vpminsd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, |
1947 | const Xbyak::Operand &op) { |
1948 | if (is_valid_isa(avx)) |
1949 | vpminsd(x1, x2, op); |
1950 | else { |
1951 | if (x1.getIdx() != x2.getIdx()) movdqa(x1, x2); |
1952 | pminsd(x1, op); |
1953 | } |
1954 | } |
1955 | |
1956 | void uni_movshdup(const Xbyak::Xmm &x, const Xbyak::Operand &op) { |
1957 | if (is_valid_isa(avx)) |
1958 | vmovshdup(x, op); |
1959 | else |
1960 | movshdup(x, op); |
1961 | } |
    // Ymm overload: VEX-only, requires AVX.
    void uni_movshdup(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
        vmovshdup(x, op);
    }
1965 | |
1966 | void uni_vmovhlps( |
1967 | const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Xmm &x3) { |
1968 | if (is_valid_isa(avx)) |
1969 | vmovhlps(x1, x2, x3); |
1970 | else { |
1971 | assert(x1.getIdx() == x2.getIdx()); |
1972 | movhlps(x1, x3); |
1973 | } |
1974 | } |
    // Ymm overload: VEX-only, requires AVX.
    void uni_vmovhlps(
            const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Ymm &x3) {
        vmovhlps(x1, x2, x3);
    }
1979 | |
1980 | // End of custom instructions section. |
1981 | |
1982 | void mul_by_const( |
1983 | const Xbyak::Reg &out, const Xbyak::Reg64 &tmp, int value) { |
1984 | // Generates a shift + add sequence for multiplicating contents of the |
1985 | // out register by a known JIT-time value. Clobbers the tmp register. |
1986 | // |
1987 | // Pros compared to mul/imul: |
1988 | // - does not require using known registers |
1989 | // Still, there are probably a lot of cases when mul/imul is faster on |
1990 | // Intel(R) Core(TM) processors. Not intended for critical path. |
1991 | |
1992 | // TODO: detect when overflow is emminent (Roma) |
1993 | // TODO: detect when using mul/imul is a better option (Roma) |
1994 | |
1995 | int p = 0; // the current power of 2 |
1996 | int old_p = 0; // the last seen power of 2 such that value[old_p] != 0 |
1997 | |
1998 | xor_(tmp, tmp); |
1999 | while (value) { |
2000 | if (value & 1) { |
2001 | int shift = p - old_p; |
2002 | if (shift) { |
2003 | shl(out, shift); |
2004 | old_p = p; |
2005 | } |
2006 | add(tmp, out); |
2007 | } |
2008 | value >>= 1; |
2009 | p++; |
2010 | } |
2011 | mov(out, tmp); |
2012 | } |
2013 | |
2014 | /* |
2015 | Saturation facility functions. enable to prepare the register |
2016 | holding the saturation upperbound and apply the saturation on |
2017 | the floating point register |
2018 | */ |
2019 | template <typename Vmm> |
2020 | void init_vmm(Vmm vmm, Xbyak::Reg64 reg_tmp, float value) { |
2021 | Xbyak::Xmm xmm_tmp(vmm.getIdx()); |
2022 | mov(reg_tmp, float2int(value)); |
2023 | uni_vmovq(xmm_tmp, reg_tmp); |
2024 | if (vmm.isYMM() || vmm.isZMM()) |
2025 | uni_vbroadcastss(vmm, xmm_tmp); |
2026 | else |
2027 | uni_vshufps(vmm, xmm_tmp, xmm_tmp, 0); |
2028 | } |
2029 | |
    // Fills vmm_lbound/vmm_ubound with the f32 saturation bounds for an
    // f32 -> {u8, s8, s32} conversion; a no-op for any other type pair.
    // reg_tmp is clobbered as a staging register. vmm_lbound is written
    // only for u8 output or when force_lbound is set.
    template <typename Vmm>
    void init_saturate_f32(Vmm vmm_lbound, Vmm vmm_ubound, Xbyak::Reg64 reg_tmp,
            data_type_t idt, data_type_t odt, bool force_lbound = false) {
        using namespace data_type;
        if (!((idt == f32) && utils::one_of(odt, u8, s8, s32))) return;

        // When the lower bound is populated, it must live in a register
        // distinct from the upper bound.
        assert(IMPLICATION(idt == u8 || force_lbound,
                vmm_lbound.getIdx() != vmm_ubound.getIdx()));

        // No need to saturate on lower bound for signed integer types, as
        // the conversion to int would return INT_MIN, and then proper
        // saturation will happen in store_data. The param force_lbound, will
        // force saturate values unconditionally to lbound.
        if (odt == u8)
            uni_vpxor(vmm_lbound, vmm_lbound, vmm_lbound);
        else if (force_lbound) {
            const float saturation_lbound = odt == s8 ? INT8_MIN : INT32_MIN;
            init_vmm(vmm_lbound, reg_tmp, saturation_lbound);
        }

        const float saturation_ubound = types::max_value<float>(odt);
        init_vmm(vmm_ubound, reg_tmp, saturation_ubound);
    }
2053 | |
2054 | template <typename Vmm> |
2055 | void saturate_f32(const Vmm &vmm, const Vmm &vmm_lbound, |
2056 | const Vmm &vmm_ubound, data_type_t odt, bool force_lbound = false) { |
2057 | // This function is used to saturate to odt in f32 before converting |
2058 | // to s32 in order to avoid bad saturation due to cvtps2dq |
2059 | // behavior (it returns INT_MIN if the f32 is out of the |
2060 | // s32 range) |
2061 | using namespace data_type; |
2062 | if (!utils::one_of(odt, u8, s8, s32)) return; |
2063 | |
2064 | // no need to apply lower saturation bound when odt is |
2065 | // signed, as cvtps2dq will return MIN_INT if the value |
2066 | // does not fit. The param force_lbound, will force saturate values |
2067 | // unconditionally to lbound. |
2068 | if (odt == u8 || force_lbound) { |
2069 | if (is_valid_isa(avx)) |
2070 | vmaxps(vmm, vmm, vmm_lbound); |
2071 | else |
2072 | maxps(vmm, vmm_lbound); |
2073 | } |
2074 | if (is_valid_isa(avx)) |
2075 | vminps(vmm, vmm, vmm_ubound); |
2076 | else |
2077 | minps(vmm, vmm_ubound); |
2078 | } |
2079 | |
2080 | /** |
2081 | * load_bytes is the utility function to facilitate loading of |
2082 | * load_size (0 <= load_size <= 32) many contiguous bytes into the Xmm/Ymm |
2083 | * register from the memory referenced by ptr[reg + offset] address. |
2084 | * |
2085 | * Functionally, invocation of load_bytes is equivalent to |
2086 | * the following loop: |
2087 | * |
2088 | * for (int idx = 0; idx < load_size; ++idx) |
2089 | * vpinsrb(xmm, xmm, ptr[reg + offset + idx], idx); |
2090 | * |
2091 | * TODO: Add an option to zero-out unloaded bytes in the Xmm register. |
2092 | * TODO: Add an option for unsafe_load wherein one could read outside the |
2093 | * provided memory buffer so as to minimize the total number of read |
2094 | * memory instructions. |
2095 | */ |
2096 | template <typename Vmm> |
2097 | void load_bytes( |
2098 | const Vmm &vmm, const Xbyak::Address &src_addr, int load_size) { |
2099 | |
2100 | constexpr bool is_vmm_supported = std::is_same<Vmm, Xbyak::Ymm>::value |
2101 | || std::is_same<Vmm, Xbyak::Xmm>::value; |
2102 | if (!is_vmm_supported) { |
2103 | assert("load_bytes() is only supported for xmm and ymm" ); |
2104 | return; |
2105 | } |
2106 | |
2107 | const auto addr = [&](int bytes_offset) { |
2108 | return ptr[src_addr.getRegExp() |
2109 | + Xbyak::RegExp(bytes_offset * sizeof(int8_t))]; |
2110 | }; |
2111 | |
2112 | helper_load_bytes(vmm, load_size, addr); |
2113 | } |
2114 | |
2115 | template <typename Vmm> |
2116 | void load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int64_t offset, |
2117 | int load_size) { |
2118 | |
2119 | constexpr bool is_vmm_supported = std::is_same<Vmm, Xbyak::Ymm>::value |
2120 | || std::is_same<Vmm, Xbyak::Xmm>::value; |
2121 | if (!is_vmm_supported) { |
2122 | assert("load_bytes() is only supported for xmm and ymm" ); |
2123 | return; |
2124 | } |
2125 | |
2126 | // Ensure offset is at most 4 bytes to be encoded in the instruction |
2127 | assert(offset >= INT_MIN && offset <= INT_MAX); |
2128 | |
2129 | const auto addr = [&](int bytes_offset) { |
2130 | return ptr[reg + offset + bytes_offset * sizeof(int8_t)]; |
2131 | }; |
2132 | |
2133 | helper_load_bytes(vmm, load_size, addr); |
2134 | } |
2135 | |
2136 | private: |
    // Core implementation shared by both load_bytes() overloads: loads
    // load_size contiguous bytes into vmm using the minimal number of
    // insert/move instructions. `addr` maps a byte offset to an effective
    // address. Bytes beyond load_size are left unspecified (not zeroed).
    // Note: this is a runtime (not static) assert because the public
    // overloads instantiate this helper even for unsupported Vmm types.
    template <typename Vmm, typename AddrFunc>
    void helper_load_bytes(
            const Vmm &vmm, int load_size, const AddrFunc &addr) {

        constexpr bool is_xmm = std::is_same<Vmm, Xbyak::Xmm>::value;
        constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
        assert((is_xmm || is_ymm) && "only Xmm or Ymm registers are allowed");

        MAYBE_UNUSED(is_xmm);
        MAYBE_UNUSED(is_ymm);

        // Ensure data fits completely inside the Xmm/Ymm register
        assert(load_size >= 0 && load_size <= 32);

        // At most 16 bytes can fit inside the Xmm register
        assert(IMPLICATION(load_size > 16, is_ymm));

        // Ensure that vector register is compatible with the ISA in hand
        assert(IMPLICATION(is_ymm, is_valid_isa(avx)));

        assert(is_valid_isa(sse41)
                && "routine is not supported for the current isa");

        auto xmm = Xbyak::Xmm(vmm.getIdx());
        auto ymm = Xbyak::Ymm(vmm.getIdx());

        // Full-width load: one instruction, nothing else to do.
        if (load_size == 32) {
            vmovups(ymm, addr(0));
            return;
        }

        int start_bytes = 0;
        int bytes_to_load = load_size;

        if (load_size > 16) {
            // Prepare to insert to upper bits of ymm
            start_bytes = 16;
            bytes_to_load -= 16;
        }

        // Load the leading 8- or 16-byte chunk here; the switch below only
        // picks up the remaining tail (cases 8-15 load the high bytes).
        if (bytes_to_load >= 8 && bytes_to_load < 16)
            uni_vpinsrq(xmm, xmm, addr(start_bytes), 0);
        else if (bytes_to_load == 16)
            uni_vmovdqu(xmm, addr(start_bytes));

        switch (bytes_to_load) {
            case 0: break;
            case 1: uni_vpinsrb(xmm, xmm, addr(start_bytes), 0); break;
            case 2: uni_vpinsrw(xmm, xmm, addr(start_bytes), 0); break;
            case 3:
                uni_vpinsrw(xmm, xmm, addr(start_bytes), 0);
                uni_vpinsrb(xmm, xmm, addr(start_bytes + 2), 2);
                break;
            case 4: uni_vpinsrd(xmm, xmm, addr(start_bytes), 0); break;
            case 5:
                uni_vpinsrd(xmm, xmm, addr(start_bytes), 0);
                uni_vpinsrb(xmm, xmm, addr(start_bytes + 4), 4);
                break;
            case 6:
                uni_vpinsrd(xmm, xmm, addr(start_bytes), 0);
                uni_vpinsrw(xmm, xmm, addr(start_bytes + 4), 2);
                break;
            case 7:
                uni_vpinsrd(xmm, xmm, addr(start_bytes), 0);
                uni_vpinsrw(xmm, xmm, addr(start_bytes + 4), 2);
                uni_vpinsrb(xmm, xmm, addr(start_bytes + 6), 6);
                break;
            case 8: break;
            case 9: uni_vpinsrb(xmm, xmm, addr(start_bytes + 8), 8); break;
            case 10: uni_vpinsrw(xmm, xmm, addr(start_bytes + 8), 4); break;
            case 11:
                uni_vpinsrw(xmm, xmm, addr(start_bytes + 8), 4);
                uni_vpinsrb(xmm, xmm, addr(start_bytes + 10), 10);
                break;
            case 12: uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2); break;
            case 13:
                uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2);
                uni_vpinsrb(xmm, xmm, addr(start_bytes + 12), 12);
                break;
            case 14:
                uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2);
                uni_vpinsrw(xmm, xmm, addr(start_bytes + 12), 6);
                break;
            case 15:
                uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2);
                uni_vpinsrw(xmm, xmm, addr(start_bytes + 12), 6);
                uni_vpinsrb(xmm, xmm, addr(start_bytes + 14), 14);
                break;
            case 16: break;
            default: assert(!"improper load size");
        }

        if (load_size > 16) {
            vinsertf128(ymm, ymm, xmm, 1); // insert to upper bits of ymm
            vinsertf128(ymm, ymm, addr(0), 0); // insert to lower bits of ymm
        }
    }
2234 | |
2235 | /** |
2236 | * store_bytes is the utility function to facilitate storing of |
2237 | * store_size (0 <= store_size <= 32) many contiguous bytes from the Xmm/Ymm |
2238 | * register into the memory referenced by ptr[reg + offset] address. |
2239 | * |
2240 | * Additionally, when store_size > 16, the input Ymm register will not be |
2241 | * preserved due to the usage of vextracti128 instruction. |
2242 | * |
2243 | * Functionally, invocation of store_bytes is equivalent |
2244 | * to the following loop: |
2245 | * |
2246 | * for (int idx = 0; idx < store_size; ++idx) |
2247 | * vpextrb(ptr[reg + offset + idx], xmm, idx); |
2248 | * |
2249 | * TODO: Add an option for unsafe_store wherein one could store extra dwords |
2250 | * past the provided memory buffer so as to minimize the total number of |
2251 | * write memory instructions. |
2252 | */ |
2253 | public: |
2254 | template <typename Vmm> |
2255 | void store_bytes( |
2256 | const Vmm &vmm, const Xbyak::Address &dst_addr, int store_size) { |
2257 | const auto addr = [&](int bytes_offset) { |
2258 | return ptr[dst_addr.getRegExp() |
2259 | + Xbyak::RegExp(bytes_offset * sizeof(int8_t))]; |
2260 | }; |
2261 | store_bytes(vmm, store_size, addr); |
2262 | } |
2263 | |
2264 | template <typename Vmm> |
2265 | void store_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int64_t offset, |
2266 | int store_size) { |
2267 | |
2268 | // Ensure offset is at most 4 bytes to be encoded in the instruction |
2269 | assert(offset >= INT_MIN && offset <= INT_MAX); |
2270 | |
2271 | const auto addr = [&](int bytes_offset) { |
2272 | return ptr[reg + offset + bytes_offset * sizeof(int8_t)]; |
2273 | }; |
2274 | |
2275 | store_bytes(vmm, store_size, addr); |
2276 | } |
2277 | |
2278 | private: |
    // Core implementation shared by both public store_bytes() overloads:
    // writes store_size contiguous bytes from vmm using the minimal number
    // of extract/move instructions. For store_size > 16 the input ymm is
    // clobbered (its upper half is extracted into the xmm alias).
    template <typename Vmm, typename AddrFunc>
    void store_bytes(const Vmm &vmm, int store_size, const AddrFunc &addr) {

        constexpr bool is_xmm = std::is_same<Vmm, Xbyak::Xmm>::value;
        constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;
        static_assert(
                is_xmm || is_ymm, "only Xmm or Ymm registers are allowed");

        MAYBE_UNUSED(is_xmm);
        MAYBE_UNUSED(is_ymm);

        // Ensure data fits completely inside the Xmm/Ymm register
        assert(store_size >= 0 && store_size <= 32);

        // At most 16 bytes can fit inside the Xmm register
        assert(IMPLICATION(store_size > 16, is_ymm));

        // Ensure that vector register is compatible with the ISA in hand
        assert(IMPLICATION(is_ymm, is_valid_isa(avx)));

        assert(is_valid_isa(sse41)
                && "routine is not supported for the current isa");

        auto xmm = Xbyak::Xmm(vmm.getIdx());
        auto ymm = Xbyak::Ymm(vmm.getIdx());

        // Full-width store: one instruction, nothing else to do.
        if (store_size == 32) {
            vmovups(addr(0), ymm);
            return;
        }

        int start_bytes = 0;
        int bytes_to_store = store_size;

        if (store_size > 16) {
            vmovdqu(addr(0), xmm); // load lower bits from ymm
            start_bytes = 16;
            bytes_to_store -= 16;
            vextractf128(xmm, ymm, 1); // load upper bits from ymm into xmm
        }

        // Store the leading 8- or 16-byte chunk here; the switch below only
        // handles the remaining tail (cases 8-15 store the high bytes).
        if (bytes_to_store >= 8 && bytes_to_store < 16)
            uni_vpextrq(addr(start_bytes), xmm, 0);
        else if (bytes_to_store == 16)
            uni_vmovdqu(addr(start_bytes), xmm);

        switch (bytes_to_store) {
            case 0: break;
            case 1: uni_vpextrb(addr(start_bytes), xmm, 0); break;
            case 2: uni_vpextrw(addr(start_bytes), xmm, 0); break;
            case 3:
                uni_vpextrw(addr(start_bytes), xmm, 0);
                uni_vpextrb(addr(start_bytes + 2), xmm, 2);
                break;
            case 4: uni_vpextrd(addr(start_bytes), xmm, 0); break;
            case 5:
                uni_vpextrd(addr(start_bytes), xmm, 0);
                uni_vpextrb(addr(start_bytes + 4), xmm, 4);
                break;
            case 6:
                uni_vpextrd(addr(start_bytes), xmm, 0);
                uni_vpextrw(addr(start_bytes + 4), xmm, 2);
                break;
            case 7:
                uni_vpextrd(addr(start_bytes), xmm, 0);
                uni_vpextrw(addr(start_bytes + 4), xmm, 2);
                uni_vpextrb(addr(start_bytes + 6), xmm, 6);
                break;
            case 8: break;
            case 9: uni_vpextrb(addr(start_bytes + 8), xmm, 8); break;
            case 10: uni_vpextrw(addr(start_bytes + 8), xmm, 4); break;
            case 11:
                uni_vpextrw(addr(start_bytes + 8), xmm, 4);
                uni_vpextrb(addr(start_bytes + 10), xmm, 10);
                break;
            case 12: uni_vpextrd(addr(start_bytes + 8), xmm, 2); break;
            case 13:
                uni_vpextrd(addr(start_bytes + 8), xmm, 2);
                uni_vpextrb(addr(start_bytes + 12), xmm, 12);
                break;
            case 14:
                uni_vpextrd(addr(start_bytes + 8), xmm, 2);
                uni_vpextrw(addr(start_bytes + 12), xmm, 6);
                break;
            case 15:
                uni_vpextrd(addr(start_bytes + 8), xmm, 2);
                uni_vpextrw(addr(start_bytes + 12), xmm, 6);
                uni_vpextrb(addr(start_bytes + 14), xmm, 14);
                break;
            case 16: break;
            default: assert(!"improper store size");
        }
    }
2372 | |
2373 | public: |
2374 | /** |
2375 | * load_bytes_to_dword_extension is the utility function to facilitate |
2376 | * loading of load_size (0 <= load_size <= 16) many contiguous bytes in |
2377 | * the Xmm register from the memory referenced by ptr[reg + offset] |
2378 | * address and then do signed/zero extension of those to double words. |
2379 | * |
2380 | * Functionally, invocation of load_bytes_to_dword_extension is equivalent |
2381 | * to the following: |
2382 | * |
2383 | * for (int idx = 0; idx < load_size; ++idx) |
2384 | * vpinsrb(xmm, xmm, ptr[reg + offset + idx], idx); |
2385 | * if (is_signed) vpmovsxbd(vmm, vmm); else vpmovzxbd(vmm, vmm); |
2386 | * |
2387 | * Valid values for the load_size variable are: |
2388 | * [0..4] for XMM version of the function |
2389 | * [0..8] for YMM version of the function. |
2390 | * TODO: Implement this routine for every ISA. |
2391 | */ |
    // Overload taking a base register plus a signed byte offset; validates
    // that the displacement is encodable and forwards to the address-based
    // overload.
    template <typename Vmm>
    void load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 &reg,
            int64_t offset, bool is_signed, int load_size) {
        // Ensure offset is at most 4 bytes to be encoded in the instruction
        assert(offset >= INT_MIN && offset <= INT_MAX);
        load_bytes_to_dword_extension(
                vmm, ptr[reg + offset], is_signed, load_size);
    }
2400 | |
2401 | template <typename Vmm> |
2402 | void load_bytes_to_dword_extension(const Vmm &vmm, |
2403 | const Xbyak::Address &src_addr, bool is_signed, int load_size) { |
2404 | |
2405 | constexpr bool is_vmm_supported = std::is_same<Vmm, Xbyak::Ymm>::value |
2406 | || std::is_same<Vmm, Xbyak::Xmm>::value; |
2407 | if (!is_vmm_supported) { |
2408 | assert("load_bytes_to_dword_extension() is only supported for xmm " |
2409 | "and ymm" ); |
2410 | return; |
2411 | } |
2412 | |
2413 | constexpr bool is_xmm = std::is_same<Vmm, Xbyak::Xmm>::value; |
2414 | constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value; |
2415 | MAYBE_UNUSED(is_xmm); |
2416 | MAYBE_UNUSED(is_ymm); |
2417 | |
2418 | // Ensure extended double words fit inside Ymm (32 * load_size <= 256) |
2419 | assert(load_size >= 0 && load_size <= 8); |
2420 | // For Xmm register, load capacity is halved (32 * load_size <= 128) |
2421 | assert(IMPLICATION(is_xmm, load_size <= 4)); |
2422 | |
2423 | // Ensure that vector register is compatible with the ISA in hand |
2424 | assert(IMPLICATION(is_ymm, is_valid_isa(avx))); |
2425 | |
2426 | assert(is_valid_isa(sse41) |
2427 | && "routine is not supported for the current isa" ); |
2428 | |
2429 | // For load_size == 8/4, do load/extension in one go |
2430 | if (load_size == 8) { |
2431 | const auto ymm = Xbyak::Ymm(vmm.getIdx()); |
2432 | if (is_signed) |
2433 | vpmovsxbd(ymm, src_addr); |
2434 | else |
2435 | vpmovzxbd(ymm, src_addr); |
2436 | } else if (load_size == 4) { |
2437 | const auto xmm = Xbyak::Xmm(vmm.getIdx()); |
2438 | if (is_signed) |
2439 | uni_vpmovsxbd(xmm, src_addr); |
2440 | else |
2441 | uni_vpmovzxbd(xmm, src_addr); |
2442 | } else { |
2443 | load_bytes(vmm, src_addr, load_size); |
2444 | if (is_signed) |
2445 | uni_vpmovsxbd(vmm, vmm); |
2446 | else |
2447 | uni_vpmovzxbd(vmm, vmm); |
2448 | } |
2449 | } |
2450 | |
2451 | /* A utility function to store data of type type_out from vmm register |
2452 | * into the memory. Moreover store_size many chunks are written to the |
2453 | * memory beginning with ptr[reg + offset] address. |
2454 | * |
2455 | * Note: Content of Vmm register is not guaranteed to be preserved after the |
2456 | * invocation of this routine. |
2457 | * |
2458 | * TODO: Support for every possible data type. |
2459 | */ |
2460 | template <typename Vmm> |
2461 | void store_data(data_type_t type_out, const Vmm &vmm, |
2462 | const Xbyak::Reg64 ®, int64_t offset, int store_size) { |
2463 | constexpr bool is_vmm_supported = std::is_same<Vmm, Xbyak::Ymm>::value |
2464 | || std::is_same<Vmm, Xbyak::Xmm>::value; |
2465 | using supported_vmm_t = typename utils::conditional<is_vmm_supported, |
2466 | Vmm, Xbyak::Ymm /*dummy*/>::type; |
2467 | |
2468 | if (!is_vmm_supported) { |
2469 | assert("store_data() not supported" ); |
2470 | return; |
2471 | } |
2472 | helper_store_data(type_out, supported_vmm_t(vmm.getIdx()), reg, offset, |
2473 | store_size); |
2474 | } |
2475 | |
2476 | private: |
    // Implementation behind store_data(): down-converts the f32/s32 lanes
    // of vmm to type_out and writes store_size elements to [reg + offset].
    // The register contents are clobbered by the conversion steps.
    template <typename Vmm>
    void helper_store_data(data_type_t type_out, const Vmm &vmm,
            const Xbyak::Reg64 &reg, int64_t offset, int store_size) {

        assert(is_valid_isa(sse41)
                && "routine is not supported for the current isa");
        constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;

        // Owing to lack of cross lane operations in non avx2 compatible isa
        // this functionality remains unimplemented for int8 data type
        const bool is_int8_dt
                = utils::one_of(type_out, data_type::s8, data_type::u8);
        assert(IMPLICATION(is_ymm && is_int8_dt, is_valid_isa(avx2)));

        // Ensure that vector register is compatible with the ISA in hand
        assert(IMPLICATION(is_ymm, is_valid_isa(avx)));

        MAYBE_UNUSED(is_ymm);
        MAYBE_UNUSED(is_int8_dt);

        auto ymm = Xbyak::Ymm(vmm.getIdx());
        auto xmm = Xbyak::Xmm(vmm.getIdx());

        switch (type_out) {
            case data_type::f32:
            case data_type::s32:
                // No conversion needed: store the 4-byte lanes directly.
                store_bytes(vmm, reg, offset, sizeof(int32_t) * store_size);
                break;
            case data_type::u8:
            case data_type::s8:
                // Saturating narrow s32 -> s16 within each 128-bit lane.
                uni_vpackssdw(vmm, vmm, vmm);
                // For each y_i of size 64 bits, following cross lane
                // operation on ymm yields
                // [y_3 y_2 y_1 y_0] |--> [0 0 y_2 y_0]
                if (is_ymm) vpermq(ymm, ymm, 0x08);
                // Saturating narrow s16 -> s8 (signed) or u8 (unsigned).
                if (type_out == data_type::s8)
                    uni_vpacksswb(vmm, vmm, vmm);
                else
                    uni_vpackuswb(vmm, vmm, vmm);
                store_bytes(vmm, reg, offset, store_size);
                break;
            case data_type::bf16:
                // EVEX encoding requires avx512_core_bf16; otherwise fall
                // back to the VEX-encoded form.
                vcvtneps2bf16(xmm, vmm,
                        is_valid_isa(avx512_core_bf16) ? Xbyak::EvexEncoding
                                                       : Xbyak::VexEncoding);
                store_bytes(vmm, reg, offset, sizeof(bfloat16_t) * store_size);
                break;
            case data_type::f16:
                // _op_mxcsr selects the rounding behavior for the f32->f16
                // conversion.
                vcvtps2ph(xmm, vmm, _op_mxcsr);
                store_bytes(vmm, reg, offset, sizeof(float16_t) * store_size);
                break;
            default: assert(!"unsupported destination data type");
        }
    }
2531 | |
2532 | public: |
2533 | /* A utility function to load data of type type_in to vmm register |
2534 | * from the memory. Moreover load_size many chunks are read from the memory |
2535 | * beginning with ptr[reg + offset] address. |
2536 | * |
2537 | * TODO: Support for every possible data type. |
2538 | */ |
    // Overload taking a base register plus a signed byte offset; validates
    // that the displacement is encodable and forwards to the address-based
    // overload.
    template <typename Vmm>
    void load_data(data_type_t type_in, const Vmm &vmm, const Xbyak::Reg64 &reg,
            int64_t offset, int load_size) {
        // Ensure offset is at most 4 bytes to be encoded in the instruction
        assert(offset >= INT_MIN && offset <= INT_MAX);
        load_data(type_in, vmm, ptr[reg + offset], load_size);
    }
2546 | |
    // Loads load_size elements of type_in starting at src_addr into vmm,
    // widening them as appropriate for the source type (int8 -> s32 dwords,
    // bf16/f16 -> f32 lanes; f32/s32 are loaded as-is).
    template <typename Vmm>
    void load_data(data_type_t type_in, const Vmm &vmm,
            const Xbyak::Address &src_addr, int load_size) {

        assert(is_valid_isa(sse41)
                && "routine is not supported for the current isa");

        switch (type_in) {
            case data_type::f32:
            case data_type::s32:
                load_bytes(vmm, src_addr, sizeof(int32_t) * load_size);
                break;
            case data_type::s8:
            case data_type::u8:
                load_bytes_to_dword_extension(
                        vmm, src_addr, type_in == data_type::s8, load_size);
                break;
            case data_type::bf16:
                load_bytes(vmm, src_addr, sizeof(bfloat16_t) * load_size);
                uni_vpmovzxwd(vmm, vmm);
                // bf16 -> f32: place the 16 payload bits in the upper half
                // of each dword lane.
                uni_vpslld(vmm, vmm, 16);
                break;
            case data_type::f16:
                load_bytes(vmm, src_addr, sizeof(float16_t) * load_size);
                // Convert in place: the f16 inputs occupy the lower half of
                // the register (Vmm_lower_t alias of the same register id).
                vcvtph2ps(vmm,
                        typename vreg_traits<Vmm>::Vmm_lower_t(vmm.getIdx()));
                break;
            default: assert(!"unsupported source data type");
        }
    }
2577 | |
2578 | /* A utility function to process f32 tail (load, store or other) depending |
2579 | * on tail size, stored in Reg64. Tail size must be value from 0 to 3/7 |
2580 | * (Xmm/Ymm). Tail process functions require integer as argument to specify |
2581 | * behavior for each tail size. |
2582 | * |
2583 | * Only supported for Xmm and Ymm. |
2584 | */ |
    // Emits a jump table dispatching on the runtime tail size held in
    // reg_tail (0..simd_w-1); case i runs tail_process(i). reg_tmp is
    // clobbered to hold the table base address.
    template <typename Vmm>
    void runtime_tail_process(const Xbyak::Reg64 &reg_tail,
            const Xbyak::Reg64 &reg_tmp,
            const std::function<void(int)> &tail_process,
            const data_type_t data_type = data_type::f32) {
        // Elements per vector register == number of jump-table entries.
        const auto simd_w
                = vreg_traits<Vmm>::vlen / types::data_type_size(data_type);

        Xbyak::Label label_tbl, label_tbl_end;
        std::vector<Xbyak::Label> l_case(simd_w);

        // Indirect jump through the table, indexed by the tail size.
        mov(reg_tmp, label_tbl);
        const Xbyak::Address label_address
                = ptr[reg_tmp + reg_tail * sizeof(void *)];
        jmp(label_address, T_NEAR);

        // create jump table
        L(label_tbl);
        for (size_t i = 0; i < simd_w; i++)
            putL(l_case[i]);

        // cases for each tail size - from 0 to 3/7
        // Tail size 0 is a no-op: jump straight to the end label.
        L(l_case[0]);
        jmp(label_tbl_end, T_NEAR);
        for (size_t i = 1; i < simd_w; i++) {
            L(l_case[i]);
            tail_process(i);
            jmp(label_tbl_end, T_NEAR);
        }
        L(label_tbl_end);
    }
2616 | |
2617 | DNNL_DISALLOW_COPY_AND_ASSIGN(jit_generator); |
2618 | |
2619 | public: |
2620 | /* All uni_ instructions -- apart from uni_vzeroupper() -- will comply with |
2621 | * the max_cpu_isa argument */ |
    // When code_ptr is null and use_autogrow is set, the Xbyak code buffer
    // is created in AutoGrow mode (expands on demand up to code_size);
    // otherwise code is emitted into the caller-provided buffer.
    jit_generator(const char *name, void *code_ptr = nullptr,
            size_t code_size = MAX_CODE_SIZE, bool use_autogrow = true,
            cpu_isa_t max_cpu_isa = get_max_cpu_isa())
        : Xbyak::MmapAllocator(name)
        , Xbyak::CodeGenerator(code_size,
                  (code_ptr == nullptr && use_autogrow) ? Xbyak::AutoGrow
                                                        : code_ptr,
                  /*allocator=*/this)
        , max_cpu_isa_(max_cpu_isa) {}
2631 | |
    virtual ~jit_generator() {}

    // Kernel identification used when registering generated code; concrete
    // kernels provide these via DECLARE_CPU_JIT_AUX_FUNCTIONS.
    virtual const char *name() const = 0;
    virtual const char *source_file() const = 0;
2636 | |
    // Publishes the generated code region to jit_utils, tagged with the
    // kernel name and source file for annotation/registration purposes.
    void register_jit_code(const Xbyak::uint8 *code, size_t code_size) const {
        jit_utils::register_jit_code(code, code_size, name(), source_file());
    }
2640 | |
    // Entry point of the generated kernel; nullptr until create_kernel().
    const Xbyak::uint8 *jit_ker() const { return jit_ker_; }
2642 | |
    // Invokes the generated kernel, casting the stored entry point to a
    // function taking exactly the supplied argument types. The caller must
    // match the kernel's actual calling signature.
    template <typename... kernel_args_t>
    void operator()(kernel_args_t... args) const {
        using jit_kernel_func_t = void (*)(const kernel_args_t... args);
        auto *fptr = (jit_kernel_func_t)jit_ker_;
        (*fptr)(std::forward<kernel_args_t>(args)...);
    }
2649 | |
    // Generates the kernel and publishes its entry point in jit_ker_.
    // Xbyak's sticky error state is checked first so an earlier allocation
    // failure surfaces as out_of_memory instead of failing inside
    // generate().
    virtual status_t create_kernel() {
        int err_code = Xbyak::GetError();
        if (err_code == Xbyak::ERR_CANT_ALLOC) return status::out_of_memory;
        if (err_code != Xbyak::ERR_NONE) return status::runtime_error;
        generate();
        jit_ker_ = getCode();
        return (jit_ker_) ? status::success : status::runtime_error;
    }
2658 | |
2659 | private: |
2660 | const cpu_isa_t max_cpu_isa_; |
    // Finalizes code generation (ready() resolves AutoGrow relocations),
    // registers the emitted region, and returns its address; nullptr if
    // Xbyak reported an error.
    const Xbyak::uint8 *getCode() {
        this->ready();
        if (!is_initialized()) return nullptr;
        const Xbyak::uint8 *code = CodeGenerator::getCode();
        register_jit_code(code, getSize());
        return code;
    }
2668 | |
    // An ISA is usable only when it is both within the caller-imposed cap
    // (max_cpu_isa_) and actually supported by the current CPU.
    inline bool is_valid_isa(cpu_isa_t isa) {
        return is_subset(isa, max_cpu_isa_) && mayiuse(isa);
    }
2672 | |
    // Xbyak reports failures through a sticky global error flag rather
    // than exceptions; a clean flag means generation succeeded so far.
    static inline bool is_initialized() {
        return Xbyak::GetError() == Xbyak::ERR_NONE;
    }
2676 | |
2677 | protected: |
2678 | virtual void generate() = 0; |
2679 | const Xbyak::uint8 *jit_ker_ = nullptr; |
2680 | }; |
2681 | |
2682 | } // namespace x64 |
2683 | } // namespace cpu |
2684 | } // namespace impl |
2685 | } // namespace dnnl |
2686 | |
2687 | #endif |
2688 | |