/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cassert>

#include "common/bfloat16.hpp"
#include "common/c_types_map.hpp"

#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/shuffle/jit_uni_shuffle_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;

#define GET_OFF(field) offsetof(jit_shuffle_call_s, field)

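// Number of elements of zero padding needed to fill the last channel block
// when the number of channels is not a multiple of the block size.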
static size_t get_padding_size(const jit_shuffle_conf_t &conf) {
    const auto padding_tail_size = conf.c % conf.blk_size;
    return (padding_tail_size) ? conf.blk_size - padding_tail_size : 0;
}

template <cpu_isa_t isa>
jit_uni_shuffle_kernel_t<isa>::jit_uni_shuffle_kernel_t(
        const jit_shuffle_conf_t conf)
    : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa)
    , conf_(conf)
    , padding_size_(get_padding_size(conf)) {}

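// The generic version is a no-op; only the ISAs specialized below need an
// explicit tail mask.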
template <cpu_isa_t isa>
void jit_uni_shuffle_kernel_t<isa>::prepare_mask() {}

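// avx512_core: set the low conf_.simd_tail bits of the opmask used to
// predicate tail loads and stores.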
template <>
void jit_uni_shuffle_kernel_t<avx512_core>::prepare_mask() {
    const size_t tail_mask = (1ULL << conf_.simd_tail) - 1ULL;
    const Reg64 &reg_tail = reg_tmp_;
    mov(reg_tail.cvt32(), tail_mask);
    kmovw(k_tail_mask_, reg_tail.cvt32());
}

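// avx: there are no opmask registers, so the tail mask is a vector of
// per-lane all-ones/all-zeros loaded from a static table at an offset that
// leaves exactly conf_.simd_tail active lanes.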
template <>
void jit_uni_shuffle_kernel_t<avx>::prepare_mask() {
    static constexpr uint32_t mask[16]
            = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                    0xffffffff, 0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0, 0};
    mov(reg_tmp_, reinterpret_cast<size_t>(&mask[8 - conf_.simd_tail]));
    vmovups(vmm_tail_mask_, ptr[reg_tmp_]);
}

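// avx512_core: emulated gather used for bf16 data. Byte offsets are extracted
// from the index vector with vpextrd, the 16-bit elements are inserted one by
// one with vpinsrw, and the assembled xmm halves are merged into the result.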
template <>
void jit_uni_shuffle_kernel_t<avx512_core>::emu_gather_data(
        const Reg64 &reg_src_addr, const int indices_idx, const int data_idx,
        const bool is_tail) {
    assert(conf_.data_type == data_type::bf16);

    const Xmm xmm_tmp = Xmm(vmm_full_mask_.getIdx());
    const Xmm xmm_dst = Xmm(vmm_tmp_.getIdx());

    xor_(reg_tmp_, reg_tmp_);
    mov(reg_tmp1_, reg_src_addr);

    constexpr unsigned xmm_size_elem = 8; // bf16
    constexpr unsigned xmm_size_elem_half = xmm_size_elem / 2;

    const unsigned number_of_xmms = is_tail
            ? utils::div_up(conf_.simd_tail, xmm_size_elem)
            : utils::div_up(conf_.simd_w, xmm_size_elem);

    for (unsigned i = 0; i < number_of_xmms; i++) {
        const unsigned number_of_xmm_halfs = is_tail && i == number_of_xmms - 1
                ? utils::div_up(
                        conf_.simd_tail, xmm_size_elem_half + i * xmm_size_elem)
                : 2;

        for (unsigned j = 0; j < number_of_xmm_halfs; j++) {
            const unsigned rem = conf_.simd_tail % xmm_size_elem_half;
            const unsigned number_of_values_to_load = is_tail
                            && i == number_of_xmms - 1
                            && j == number_of_xmm_halfs - 1 && rem
                    ? rem
                    : xmm_size_elem_half;

            vextractf32x4(xmm_tmp, Zmm(indices_idx), j + i * 2);
            for (unsigned k = 0; k < number_of_values_to_load; k++) {
                vpextrd(reg_tmp_.cvt32(), xmm_tmp, k + j * xmm_size_elem_half);
                add(reg_src_addr, reg_tmp_);
                vpinsrw(xmm_dst, xmm_dst, ptr[reg_src_addr],
                        k + j * xmm_size_elem_half);
                mov(reg_src_addr, reg_tmp1_);
            }
        }
        vinsertf128(Ymm(data_idx), Ymm(data_idx), xmm_dst, i);
    }
}

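// avx: emulated gather for 32-bit elements, used when a hardware gather is
// not available. Offsets are extracted with vpextrd and elements inserted
// with vpinsrd, one 128-bit lane of the ymm register at a time.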
template <>
void jit_uni_shuffle_kernel_t<avx>::emu_gather_data(const Reg64 &reg_src_addr,
        const int indices_idx, const int data_idx, const bool is_tail) {
    const Xmm xmm_tmp = Xmm(vmm_full_mask_.getIdx());
    const Xmm xmm_dst = Xmm(vmm_tmp_.getIdx());

    xor_(reg_tmp_, reg_tmp_);
    mov(reg_tmp1_, reg_src_addr);

    constexpr unsigned xmm_size_elem = 4;

    const unsigned number_of_xmms = is_tail
            ? utils::div_up(conf_.simd_tail, xmm_size_elem)
            : utils::div_up(conf_.simd_w, xmm_size_elem);
    for (unsigned i = 0; i < number_of_xmms; i++) {
        vextractf128(xmm_tmp, Ymm(indices_idx), i);

        const unsigned number_of_values_to_load = i == number_of_xmms - 1
                        && is_tail && conf_.simd_tail % xmm_size_elem != 0
                ? conf_.simd_tail % xmm_size_elem
                : xmm_size_elem;
        for (unsigned j = 0; j < number_of_values_to_load; j++) {
            vpextrd(reg_tmp_.cvt32(), xmm_tmp, j);
            add(reg_src_addr, reg_tmp_);
            vpinsrd(xmm_dst, xmm_dst, ptr[reg_src_addr], j);
            mov(reg_src_addr, reg_tmp1_);
        }

        vinsertf128(Ymm(data_idx), Ymm(data_idx), xmm_dst, i);
    }
}

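// sse41: emulated gather; a single xmm register is filled element by element
// via pextrd/pinsrd.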
template <>
void jit_uni_shuffle_kernel_t<sse41>::emu_gather_data(const Reg64 &reg_src_addr,
        const int indices_idx, const int data_idx, const bool is_tail) {
    xor_(reg_tmp_, reg_tmp_);
    mov(reg_tmp1_, reg_src_addr);

    constexpr unsigned xmm_size_elem = 4;

    const unsigned number_of_values_to_load
            = is_tail ? conf_.simd_tail : xmm_size_elem;
    for (unsigned j = 0; j < number_of_values_to_load; j++) {
        pextrd(reg_tmp_.cvt32(), Xmm(indices_idx), j);
        add(reg_src_addr, reg_tmp_);
        pinsrd(Xmm(data_idx), ptr[reg_src_addr], j);
        mov(reg_src_addr, reg_tmp1_);
    }
}

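// avx512_core: hardware gather (vgatherdps) for 32-bit data types; other
// types (bf16) fall back to the emulated gather above.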
template <>
void jit_uni_shuffle_kernel_t<avx512_core>::gather_data(
        const Reg64 &reg_src_addr, const int indices_idx, const int data_idx,
        const bool is_tail) {
    if (conf_.dt_size == sizeof(float)) {
        const Opmask &mask = is_tail ? k_tail_mask_ : k_full_mask_;
        if (!is_tail) {
            // The full mask has to be set to all ones before every gather,
            // because a completed vgatherdps zeroes the mask.
            mov(reg_tmp_.cvt32(), (1ULL << conf_.simd_w) - 1ULL);
            kmovw(k_full_mask_, reg_tmp_.cvt32());
        }
        vgatherdps(Vmm(data_idx) | mask, ptr[reg_src_addr + Vmm(indices_idx)]);
        // The tail mask has to be restored after the gather for the same
        // reason: vgatherdps zeroes it.
        if (is_tail) prepare_mask();
    } else {
        emu_gather_data(reg_src_addr, indices_idx, data_idx, is_tail);
    }
}

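// avx/avx2: hardware gather (vpgatherdd for s32, vgatherdps otherwise) is
// used only on avx2 with 32-bit data; all other cases go through the emulated
// gather.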
template <>
void jit_uni_shuffle_kernel_t<avx>::gather_data(const Reg64 &reg_src_addr,
        const int indices_idx, const int data_idx, const bool is_tail) {
    if (conf_.isa == avx2 && conf_.dt_size == sizeof(float)) {
        const Vmm &mask = is_tail ? vmm_tail_mask_ : vmm_full_mask_;
        if (!is_tail) {
            // The full mask has to be set to all ones before every gather,
            // because a completed vpgatherdd/vgatherdps zeroes the mask.
            if (conf_.data_type == data_type::s32)
                vpcmpeqw(vmm_full_mask_, vmm_full_mask_, vmm_full_mask_);
            else
                vcmpps(vmm_full_mask_, vmm_full_mask_, vmm_full_mask_,
                        _cmp_eq_oq);
        }
        if (conf_.data_type == data_type::s32)
            vpgatherdd(
                    Vmm(data_idx), ptr[reg_src_addr + Vmm(indices_idx)], mask);
        else
            vgatherdps(
                    Vmm(data_idx), ptr[reg_src_addr + Vmm(indices_idx)], mask);
        // The tail mask has to be restored after the gather for the same
        // reason: the gather instruction zeroes it.
        if (is_tail) prepare_mask();
    } else {
        emu_gather_data(reg_src_addr, indices_idx, data_idx, is_tail);
    }
}

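// sse41: no hardware gather is available, so always emulate.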
template <>
void jit_uni_shuffle_kernel_t<sse41>::gather_data(const Reg64 &reg_src_addr,
        const int indices_idx, const int data_idx, const bool is_tail) {
    emu_gather_data(reg_src_addr, indices_idx, data_idx, is_tail);
}

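// avx512_core: store a full vector or a masked tail. When the tail plus the
// required zero padding covers a whole vector, the padded lanes are zeroed in
// a temporary register and a full store is issued instead of a masked one.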
template <>
void jit_uni_shuffle_kernel_t<avx512_core>::store_data(const int data_idx,
        const Reg64 &reg_dst_addr, const int offset, const bool is_tail) {
    const auto extend_for_padding
            = is_tail && padding_size_ + conf_.simd_tail >= conf_.simd_w;
    if (conf_.data_type == data_type::bf16) {
        const Ymm to_store_data = Ymm(data_idx);
        const Ymm ymm_tmp = Ymm(vmm_tmp_.getIdx());

        if (extend_for_padding) {
            vmovdqu16(ymm_tmp | k_tail_mask_ | T_z, to_store_data);
            vmovups(ptr[reg_dst_addr + offset], ymm_tmp);
        } else {
            if (is_tail)
                vmovdqu16(ptr[reg_dst_addr + offset] | k_tail_mask_,
                        to_store_data);
            else
                vmovups(ptr[reg_dst_addr + offset], to_store_data);
        }
    } else {
        if (extend_for_padding) {
            vmovups(vmm_tmp_ | k_tail_mask_ | T_z, Vmm(data_idx));
            vmovups(ptr[reg_dst_addr + offset], vmm_tmp_);
        } else {
            if (is_tail)
                vmovups(ptr[reg_dst_addr + offset] | k_tail_mask_,
                        Vmm(data_idx));
            else
                vmovups(ptr[reg_dst_addr + offset], Vmm(data_idx));
        }
    }
    append_zero_padding(reg_dst_, extend_for_padding);
}

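// avx: store a full vector, or a masked tail via vmaskmovps; the padded-tail
// case blends the data with zeros and issues a full store.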
template <>
void jit_uni_shuffle_kernel_t<avx>::store_data(const int data_idx,
        const Reg64 &reg_dst_addr, const int offset, const bool is_tail) {
    const auto extend_for_padding
            = is_tail && padding_size_ + conf_.simd_tail >= conf_.simd_w;

    if (extend_for_padding) {
        uni_vxorps(vmm_tmp_, vmm_tmp_, vmm_tmp_);
        uni_vblendvps(vmm_tmp_, vmm_tmp_, Vmm(data_idx), vmm_tail_mask_);
        vmovups(ptr[reg_dst_addr + offset], vmm_tmp_);
    } else {
        if (is_tail)
            vmaskmovps(
                    ptr[reg_dst_addr + offset], vmm_tail_mask_, Vmm(data_idx));
        else
            vmovups(ptr[reg_dst_addr + offset], Vmm(data_idx));
    }
    append_zero_padding(reg_dst_, extend_for_padding);
}

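// sse41: tail elements are stored one by one with pextrd; full vectors with
// movups.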
template <>
void jit_uni_shuffle_kernel_t<sse41>::store_data(const int data_idx,
        const Reg64 &reg_dst_addr, const int offset, const bool is_tail) {
    if (is_tail)
        for (unsigned i = 0; i < conf_.simd_tail; i++) {
            pextrd(ptr[reg_dst_addr + offset + i * conf_.dt_size],
                    Xmm(data_idx), i);
        }
    else
        movups(ptr[reg_dst_addr + offset], Vmm(data_idx));

    append_zero_padding(reg_dst_, false);
}

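// Main loop for blocked formats: iterates over full channel blocks (cb loop)
// and the spatial split (sp loop). Per-block index vectors are loaded once,
// then each spatial point gathers the shuffled source elements and stores
// them to the destination. A partial channel block (cb_loop_size != blk_size)
// is processed separately after the full blocks.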
template <cpu_isa_t isa>
void jit_uni_shuffle_kernel_t<isa>::shuffle_blocked_format() {
    const Reg64 &reg_sp = reg_tmp2_;
    const Reg64 &reg_cb = reg_tmp3_;
    const Reg64 &reg_cb_loop_size = reg_tmp4_;
    const Reg64 &reg_blk_tail = reg_tmp5_;
    const Reg64 &reg_src_save = reg_tmp6_;
    const int simd_in_blk = conf_.blk_size / conf_.simd_w;
    const int simd_in_tail_blk
            = utils::div_up(conf_.c % conf_.blk_size, conf_.simd_w);
    const Vmm vmm_tmp[4] = {Vmm(5), Vmm(6), Vmm(7), Vmm(8)};

    auto load_indices = ([&](bool is_blk_tail) {
        const int simd_to_process
                = is_blk_tail ? simd_in_tail_blk : simd_in_blk;
        for (int i = 0; i < simd_to_process; ++i)
            uni_vmovdqu(vmm_tmp[i],
                    ptr[reg_indices_
                            + i * conf_.simd_w * conf_.el_size_of_indices]);
    });

    auto shuffle = ([&](bool is_blk_tail) {
        const int simd_to_process
                = is_blk_tail ? simd_in_tail_blk : simd_in_blk;
        for (int i = 0; i < simd_to_process; ++i) {
            const bool simd_tail_condition = is_blk_tail && conf_.simd_tail > 0
                    && i == simd_to_process - 1;
            uni_vmovups(vmm_indices_, vmm_tmp[i]);
            gather_data(reg_src_, vmm_indices_.getIdx(), vmm_src_.getIdx(),
                    simd_tail_condition);

            store_data(vmm_src_.getIdx(), reg_dst_,
                    i * conf_.simd_w * conf_.dt_size, simd_tail_condition);
        }
    });

    mov(reg_cb_loop_size, ptr[reg_param + GET_OFF(cb_loop_size)]);

    Label sp_loop_begin, sp_loop_end;
    Label sp_tail_loop_begin, sp_tail_loop_end;
    Label cb_loop_begin, cb_loop_end;
    Label simd_loop_begin, simd_loop_end;
    Label blk_tail_loop_begin, blk_tail_loop_end;
    Label blk_tail_check_end;
    Label no_tail;

    xor_(reg_blk_tail, reg_blk_tail);

    cmp(reg_cb_loop_size, conf_.blk_size);
    je(no_tail, T_NEAR);

    mov(reg_blk_tail, reg_cb_loop_size);
    xor_(reg_cb_loop_size, reg_cb_loop_size);

    L(no_tail);

    xor_(reg_cb, reg_cb);
    L(cb_loop_begin);
    {
        cmp(reg_cb, reg_cb_loop_size);
        jge(cb_loop_end, T_NEAR);

        load_indices(false);

        mov(reg_src_save, reg_src_);

        xor_(reg_sp, reg_sp);
        L(sp_loop_begin);
        {
            cmp(reg_sp, conf_.sp_split_size);
            jge(sp_loop_end, T_NEAR);

            shuffle(false);

            inc(reg_sp);
            add(reg_src_, conf_.blk_size * conf_.dt_size);
            add(reg_dst_, conf_.blk_size * conf_.dt_size);

            jmp(sp_loop_begin);
        }
        L(sp_loop_end);

        mov(reg_src_, reg_src_save);

        add(reg_cb, conf_.blk_size);
        add(reg_dst_,
                conf_.blk_size * (conf_.sp - conf_.sp_split_size)
                        * conf_.dt_size);
        add(reg_indices_, conf_.blk_size * conf_.el_size_of_indices);

        jmp(cb_loop_begin);
    }
    L(cb_loop_end);

    cmp(reg_blk_tail, 0);
    je(blk_tail_check_end, T_NEAR);

    load_indices(true);

    xor_(reg_sp, reg_sp);
    L(sp_tail_loop_begin);
    {
        cmp(reg_sp, conf_.sp_split_size);
        jge(sp_tail_loop_end, T_NEAR);

        shuffle(true);

        inc(reg_sp);
        add(reg_src_, conf_.blk_size * conf_.dt_size);
        add(reg_dst_, conf_.blk_size * conf_.dt_size);

        jmp(sp_tail_loop_begin);
    }
    L(sp_tail_loop_end);

    L(blk_tail_check_end);
}

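// Writes zeros after the valid channel data when the current block is flagged
// as padded (reg_padded_block). The padding is emitted with vector stores
// where possible, then 8-byte and single-byte stores for the remainder.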
template <cpu_isa_t isa>
void jit_uni_shuffle_kernel_t<isa>::append_zero_padding(
        const Reg64 &reg_dst_addr, const bool extend_for_padding) {

    static constexpr size_t reg64_size = 8;
    const size_t simd_w_byte = conf_.simd_w * sizeof(float);

    if (!padding_size_) return;

    const auto padding_start
            = (extend_for_padding) ? conf_.simd_w : conf_.c % conf_.blk_size;
    const auto padding_end = (extend_for_padding)
            ? padding_size_ - (conf_.simd_w - conf_.simd_tail)
            : padding_size_;
    const auto off_start = padding_start * conf_.dt_size;
    const auto padding_to_add = padding_end * conf_.dt_size;

    if (!padding_to_add) return;

    Label end;
    unsigned int off = 0;

    cmp(reg_padded_block, 0);
    je(end, T_NEAR);

    if (simd_w_byte <= padding_to_add) {
        uni_vxorps(vmm_zero_, vmm_zero_, vmm_zero_);
        for (; off + simd_w_byte < padding_to_add; off += simd_w_byte)
            uni_vmovups(ptr[reg_dst_addr + off_start + off], vmm_zero_);
    }

    if (off != padding_to_add) {
        xor_(reg_tmp_, reg_tmp_);
        for (; off + reg64_size < padding_to_add; off += reg64_size)
            mov(ptr[reg_dst_addr + off_start + off], reg_tmp_);
        for (; off < padding_to_add; off++)
            mov(ptr[reg_dst_addr + off_start + off], Reg8(reg_tmp_.getIdx()));
    }

    L(end);
}

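// Kernel entry point: loads the call arguments, prepares the tail mask if
// needed, and dispatches to the blocked-format shuffle body.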
template <cpu_isa_t isa>
void jit_uni_shuffle_kernel_t<isa>::generate() {
    preamble();

#if defined(_WIN32)
    // Always mimic the Unix ABI: swap rcx and rdi so the first argument
    // register matches the rest of the kernel.
    xor_(rdi, rcx);
    xor_(rcx, rdi);
    xor_(rdi, rcx);
#endif

    if (conf_.isa == avx2) {
        // The mask register may contain NaNs at kernel entry, and
        // vcmpps(vmm, vmm, vmm) with _cmp_eq_oq zeroes the lanes that hold
        // NaN instead of setting them to all ones. Zero the register up
        // front so the full-mask comparison works reliably.
        uni_vxorps(vmm_full_mask_, vmm_full_mask_, vmm_full_mask_);
    }

    if (conf_.simd_tail > 0) prepare_mask();

    mov(reg_indices_, ptr[reg_param + GET_OFF(input_off_ptr)]);

    mov(reg_src_, ptr[reg_param + GET_OFF(src)]);
    mov(reg_dst_, ptr[reg_param + GET_OFF(dst)]);
    mov(reg_padded_block, ptr[reg_param + GET_OFF(is_padded_block)]);

    shuffle_blocked_format();

    postamble();
}

template struct jit_uni_shuffle_kernel_t<sse41>;
template struct jit_uni_shuffle_kernel_t<avx>;
template struct jit_uni_shuffle_kernel_t<avx512_core>;

#undef GET_OFF

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s