/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cassert>

#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/utils/jit_io_helper.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace io {

io_conf_t::io_conf_t(const bool nt_stores_enabled)
    : nt_stores_enabled_(nt_stores_enabled) {}

io_tail_conf_t::io_tail_conf_t(const std::size_t simd_w,
        const std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
        const int tail_vmm_mask_idx, const Xbyak::Reg64 &reg_tmp)
    : simd_w_(simd_w)
    , tail_size_(tail_size)
    , tail_opmask_(tail_opmask)
    , tail_vmm_mask_idx_(tail_vmm_mask_idx)
    , reg_tmp_(reg_tmp) {}

io_tail_conf_t::io_tail_conf_t(const std::size_t simd_w,
        const std::size_t tail_size, int tail_opmask_idx,
        const int tail_vmm_mask_idx, const Xbyak::Reg64 &reg_tmp)
    : simd_w_(simd_w)
    , tail_size_(tail_size)
    , tail_opmask_(Xbyak::Opmask(tail_opmask_idx))
    , tail_vmm_mask_idx_(tail_vmm_mask_idx)
    , reg_tmp_(reg_tmp) {}

io_emu_bf16_conf_t::io_emu_bf16_conf_t(const Xbyak::Zmm &bf16_emu_reserv_1,
        const Xbyak::Zmm &bf16_emu_reserv_2,
        const Xbyak::Zmm &bf16_emu_reserv_3, const Xbyak::Reg64 &reg_tmp,
        const Xbyak::Zmm &bf16_emu_reserv_4)
    : bf16_emu_reserv_1_(bf16_emu_reserv_1)
    , bf16_emu_reserv_2_(bf16_emu_reserv_2)
    , bf16_emu_reserv_3_(bf16_emu_reserv_3)
    , reg_tmp_(reg_tmp)
    , bf16_emu_reserv_4_(bf16_emu_reserv_4) {}

io_emu_bf16_conf_t::io_emu_bf16_conf_t(int bf16_emu_reserv_1_idx,
        int bf16_emu_reserv_2_idx, int bf16_emu_reserv_3_idx,
        const Xbyak::Reg64 &reg_tmp, int bf16_emu_reserv_4_idx)
    : bf16_emu_reserv_1_(Xbyak::Zmm(bf16_emu_reserv_1_idx))
    , bf16_emu_reserv_2_(Xbyak::Zmm(bf16_emu_reserv_2_idx))
    , bf16_emu_reserv_3_(Xbyak::Zmm(bf16_emu_reserv_3_idx))
    , reg_tmp_(reg_tmp)
    , bf16_emu_reserv_4_(Xbyak::Zmm(bf16_emu_reserv_4_idx)) {}

io_saturation_conf_t::io_saturation_conf_t(const int vreg_zero_saturation_idx,
        const int vreg_saturation_ubound_idx, const Xbyak::Reg64 &reg_tmp)
    : vreg_zero_saturation_idx_(vreg_zero_saturation_idx)
    , vreg_saturation_ubound_idx_(vreg_saturation_ubound_idx)
    , reg_tmp_(reg_tmp) {}

io_gather_conf_t::io_gather_conf_t(const std::size_t simd_w,
        const Xbyak::Opmask &full_opmask, const int full_vmm_mask_idx,
        const Xbyak::Reg64 &reg_tmp, const Xbyak::Reg64 &reg_tmp1,
        const utils::optional_t<int> &vmm_tmp_idx)
    : simd_w_(simd_w)
    , full_opmask_(full_opmask)
    , full_vmm_mask_idx_(full_vmm_mask_idx)
    , reg_tmp_(reg_tmp)
    , reg_tmp1_(reg_tmp1)
    , vmm_tmp_idx_(vmm_tmp_idx) {}

template <typename Vmm>
jit_io_helper_t<Vmm>::jit_io_helper_t(jit_generator *host, const cpu_isa_t &isa,
        const data_type_t &data_type, const io_conf_t &io_conf,
        const utils::optional_t<io_tail_conf_t> &tail_conf,
        const utils::optional_t<io_emu_bf16_conf_t> &bf16_conf,
        const utils::optional_t<io_saturation_conf_t> &saturation_conf,
        const utils::optional_t<io_gather_conf_t> &gather_conf)
    : host_(host)
    , isa_(isa)
    , data_type_(data_type)
    , bf16_supported_(is_data_type_supported(data_type::bf16))
    , f16_supported_(is_data_type_supported(data_type::f16))
    , bf16_emu_(nullptr)
    , io_conf_(io_conf)
    , tail_conf_(tail_conf)
    , bf16_conf_(bf16_conf)
    , saturation_conf_(saturation_conf)
    , gather_conf_(gather_conf) {

    if (data_type_ == data_type::bf16 && isa == avx512_core) {
        assert(bf16_conf.has_value()
                && "Config for bf16 emulation is not set.");
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(host_,
                bf16_conf->bf16_emu_reserv_1_, bf16_conf->bf16_emu_reserv_2_,
                bf16_conf->bf16_emu_reserv_3_, bf16_conf->reg_tmp_,
                bf16_conf->bf16_emu_reserv_4_);
    }

    assert(utils::one_of(data_type_, data_type::f16, data_type::bf16,
            data_type::f32, data_type::s8, data_type::u8, data_type::s32)
            && is_data_type_supported(data_type_)
            && "Supported data types: f16, bf16, f32, s8, u8, s32");

    /*
     * vpmovsxbd and vpmovzxbd for AVX are defined only for XMM. Since AVX2
     * they are also defined for YMM. To avoid a workaround with a potential
     * performance penalty, s8/u8 with AVX is disabled for YMM.
     */
    static constexpr bool is_xmm = std::is_same<Vmm, Xbyak::Xmm>::value;
    const bool is_avx_u8s8 = (isa_ == avx
            && utils::one_of(data_type_, data_type::s8, data_type::u8));
    MAYBE_UNUSED(is_xmm);
    MAYBE_UNUSED(is_avx_u8s8);

    assert(IMPLICATION(is_avx_u8s8, is_xmm)
            && "s8/u8 with AVX should be used with an XMM vreg");

    static constexpr bool is_zmm = std::is_same<Vmm, Xbyak::Zmm>::value;
    MAYBE_UNUSED(is_zmm);
    assert(IMPLICATION(!is_superset(isa_, avx512_core), !is_zmm)
            && "This architecture does not support zmms.");
}
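
// Typical usage from a kernel generator, sketched for illustration only
// (the register indices and configuration values below are hypothetical):
//
//     io::io_tail_conf_t tail_conf(simd_w, tail_size, k_tail_mask,
//             vmm_tail_mask.getIdx(), reg_tmp);
//     io::jit_io_helper_t<Vmm> io(this, isa, dst_dt, io::io_conf_t(false),
//             tail_conf, utils::nullopt, utils::nullopt, utils::nullopt);
//     io.prepare_tail_mask(); // once, in the kernel preamble
//     io.load(ptr[reg_src], vmm_val, /* tail = */ true);
//     io.store(vmm_val, ptr[reg_dst], /* tail = */ true);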

template <typename Vmm>
jit_io_helper_t<Vmm>::~jit_io_helper_t() = default;

template <typename Vmm>
bool jit_io_helper_t<Vmm>::is_data_type_supported(const data_type_t dt) {
    switch (dt) {
        case data_type::f32:
        case data_type::s32:
        case data_type::u8:
        case data_type::s8: return true;
        case data_type::bf16:
            return is_superset(isa_, avx512_core) || isa_ == avx2_vnni_2;
        case data_type::f16:
            return is_superset(isa_, avx512_core_fp16) || isa_ == avx2_vnni_2;
        default: assert(!"Unsupported data type");
    }
    return false;
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::init_bf16() {
    if (bf16_emu_) {
        assert(bf16_conf_.has_value()
                && "Config for bf16 emulation is not set.");
        bf16_emu_->init_vcvtneps2bf16();
    }
}

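// Builds a k-mask with the lowest `how_many_bits_to_set` bits set, e.g. a
// tail of 3 elements yields a mask of 0b111.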
template <typename Vmm>
void jit_io_helper_t<Vmm>::prepare_opmask(
        const std::size_t how_many_bits_to_set, const Xbyak::Reg64 &reg_tmp,
        const Xbyak::Opmask &mask) {
    const int mask_f32 = (1 << how_many_bits_to_set) - 1;
    const Xbyak::Reg32 regw_tmp = reg_tmp.cvt32();
    host_->mov(regw_tmp, mask_f32);
    host_->kmovw(mask, regw_tmp);
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::prepare_vmm_mask(
        const std::size_t how_many_bits_to_set, const std::size_t simd_w,
        const Xbyak::Reg64 &reg_tmp, const Vmm &mask) {
    static const uint32_t mask_f32[14]
            = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                    0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0};

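    // The table above holds seven ones followed by seven zeros. Loading
    // simd_w dwords starting at &mask_f32[7 - n] therefore yields a vector
    // whose first n dwords are 0xffffffff and whose remaining dwords are 0,
    // which is exactly the per-element mask required by vmaskmovps.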
    if (how_many_bits_to_set < simd_w) {
        host_->mov(reg_tmp,
                reinterpret_cast<size_t>(&mask_f32[7 - how_many_bits_to_set]));
        host_->uni_vmovups(mask, host_->ptr[reg_tmp]);
    } else if (how_many_bits_to_set == simd_w) {
        host_->uni_vcmpps(mask, mask, mask, jit_generator::_cmp_eq_oq);
    } else {
        assert(!"Can't set so many bits.");
    }
}

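// Packs saturated s32 values down to 8 bits. The pack instructions operate
// within 128-bit lanes, so for Ymms a cross-lane vpermq first gathers the
// packed words into the low lane before the final pack to bytes.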
template <typename Vmm>
void jit_io_helper_t<Vmm>::prepare_i8_data_to_store(const Vmm &i8_vmm) {
    assert(saturation_conf_.has_value() && "Config for saturation is not set.");

    static constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;

    host_->uni_vpackssdw(
            i8_vmm, i8_vmm, Vmm(saturation_conf_->vreg_zero_saturation_idx_));
    if (is_ymm) {
        // dst[63:0] = src[63:0]
        // dst[127:64] = src[191:128]
        // dst[191:128] = src[127:64]
        // dst[255:192] = src[127:64]
        const auto src_ymm = Xbyak::Ymm(i8_vmm.getIdx());
        host_->vpermq(src_ymm, src_ymm, 0x58);
    }

    if (data_type_ == data_type::s8)
        host_->uni_vpacksswb(i8_vmm, i8_vmm,
                Vmm(saturation_conf_->vreg_zero_saturation_idx_));
    else
        host_->uni_vpackuswb(i8_vmm, i8_vmm,
                Vmm(saturation_conf_->vreg_zero_saturation_idx_));
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::prepare_xf16_data_to_store(const Vmm &vmm) {
    assert(!is_superset(isa_, avx512_core));
    const auto &cvt_lower_vmm =
            typename vreg_traits<Vmm>::Vmm_lower_t(vmm.getIdx());

    if (data_type_ == data_type::bf16)
        host_->vcvtneps2bf16(cvt_lower_vmm, vmm, Xbyak::VexEncoding);
    else
        host_->uni_vcvtps2phx(cvt_lower_vmm, vmm);
}

template <>
void jit_io_helper_t<Xbyak::Zmm>::emu_gather(const Xbyak::Reg64 &src_reg,
        const Xbyak::Zmm &indices_vmm, const Xbyak::Zmm &dst_vmm,
        const bool tail) {
    assert(gather_conf_.has_value()
            && "Config for loading with the use of gather instruction is not "
               "set.");
    assert(gather_conf_->vmm_tmp_idx_.has_value()
            && "Temporary vreg is not set.");
    assert(IMPLICATION(tail, tail_conf_.has_value())
            && "Config for tail processing is not set.");

    const Xbyak::Xmm xmm_tmp = Xbyak::Xmm(gather_conf_->full_vmm_mask_idx_);
    const Xbyak::Xmm xmm_dst = Xbyak::Xmm(*gather_conf_->vmm_tmp_idx_);
    const Xbyak::Ymm dst_ymm = Xbyak::Ymm(dst_vmm.getIdx());

    host_->mov(gather_conf_->reg_tmp_, 0);
    host_->mov(gather_conf_->reg_tmp1_, src_reg);

    // The bf16->f32 conversion here is split into two parts.
    // While loading bf16 words here, the words are interleaved with zeros,
    // and in convert_to_f32 they are shifted left to finally become f32.
    // For f16 we do not need such interleaving.
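    // Concretely, each bf16 word is inserted at an even word position
    // (vpinsrw index j * 2), i.e. into the low half of a dword; the vpslld
    // by 16 in convert_to_f32 then moves it to the high half, producing a
    // valid f32 bit pattern.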
    const int xmm_size_elem = (data_type_ == data_type::f16) ? 8 : 4;
    const int number_of_xmms = tail
            ? utils::div_up(tail_conf_->tail_size_, xmm_size_elem)
            : utils::div_up(gather_conf_->simd_w_, xmm_size_elem);
    const int num_indices_in_xmm = 16 / sizeof(int);
    for (int i = 0, idx = 0; i < number_of_xmms; i++) {

        const int number_of_values_to_load = i == number_of_xmms - 1 && tail
                        && tail_conf_->tail_size_ % xmm_size_elem != 0
                ? tail_conf_->tail_size_ % xmm_size_elem
                : xmm_size_elem;
        for (int j = 0; j < number_of_values_to_load; j++) {

            if (j % num_indices_in_xmm == 0)
                host_->vextractf32x4(xmm_tmp, indices_vmm, idx++);
            host_->vpextrd(gather_conf_->reg_tmp_.cvt32(), xmm_tmp, j);
            host_->add(src_reg, gather_conf_->reg_tmp_);
            switch (data_type_) {
                case data_type::f16:
                    host_->vpinsrw(xmm_dst, xmm_dst, host_->ptr[src_reg], j);
                    break;
                case data_type::bf16:
                    host_->vpinsrw(
                            xmm_dst, xmm_dst, host_->ptr[src_reg], j * 2);
                    break;
                case data_type::s8:
                case data_type::u8:
                    host_->vpinsrb(xmm_dst, xmm_dst, host_->ptr[src_reg],
                            i * xmm_size_elem + j);
                    break;
                default: assert(!"Unsupported data type.");
            }
            host_->mov(src_reg, gather_conf_->reg_tmp1_);
        }
        if (data_type_ == data_type::bf16) {
            host_->vinsertf32x4(dst_vmm, dst_vmm, xmm_dst, i);
            host_->vpxord(xmm_dst, xmm_dst, xmm_dst);
        } else if (data_type_ == data_type::f16) {
            host_->vinsertf32x4(dst_ymm, dst_ymm, xmm_dst, i);
            host_->vpxord(xmm_dst, xmm_dst, xmm_dst);
        }
    }

    if (data_type_ == data_type::bf16)
        convert_to_f32(dst_vmm, dst_vmm, data_type_);
    else if (data_type_ == data_type::f16)
        convert_to_f32(dst_vmm, dst_ymm, data_type_);
    else if (data_type_ == data_type::s8 || data_type_ == data_type::u8)
        convert_to_f32(dst_vmm, xmm_dst, data_type_);
}

template <>
void jit_io_helper_t<Xbyak::Ymm>::emu_gather(const Xbyak::Reg64 &src_reg,
        const Xbyak::Ymm &indices_vmm, const Xbyak::Ymm &dst_vmm,
        const bool tail) {
    assert(gather_conf_.has_value()
            && "Config for loading with the use of gather instruction is not "
               "set.");
    assert(gather_conf_->vmm_tmp_idx_.has_value()
            && "Temporary vreg is not set.");
    assert(IMPLICATION(tail, tail_conf_.has_value())
            && "Config for tail processing is not set.");

    const Xbyak::Xmm xmm_tmp = Xbyak::Xmm(gather_conf_->full_vmm_mask_idx_);
    const Xbyak::Xmm xmm_dst = Xbyak::Xmm(*gather_conf_->vmm_tmp_idx_);

    host_->mov(gather_conf_->reg_tmp_, 0);
    host_->mov(gather_conf_->reg_tmp1_, src_reg);

    // The bf16->f32 conversion here is split into two parts.
    // While loading bf16 words here, the words are interleaved with zeros,
    // and in convert_to_f32 they are shifted left to finally become f32.
    // For f16 we do not need such interleaving.
    const int xmm_size_elem = (data_type_ == data_type::f16) ? 8 : 4;
    const int number_of_xmms = tail
            ? utils::div_up(tail_conf_->tail_size_, xmm_size_elem)
            : utils::div_up(gather_conf_->simd_w_, xmm_size_elem);
    for (int i = 0; i < number_of_xmms; i++) {
        host_->vextractf128(xmm_tmp, indices_vmm, i);

        const int number_of_values_to_load = i == number_of_xmms - 1 && tail
                        && tail_conf_->tail_size_ % xmm_size_elem != 0
                ? tail_conf_->tail_size_ % xmm_size_elem
                : xmm_size_elem;
        for (int j = 0; j < number_of_values_to_load; j++) {
            host_->vpextrd(gather_conf_->reg_tmp_.cvt32(), xmm_tmp, j);
            host_->add(src_reg, gather_conf_->reg_tmp_);
            switch (data_type_) {
                case data_type::f32:
                case data_type::s32: {
                    host_->vpinsrd(xmm_dst, xmm_dst, host_->ptr[src_reg], j);
                    break;
                }
                case data_type::f16:
                    assert(f16_supported_ && "Unsupported data type.");
                    host_->vpinsrw(xmm_dst, xmm_dst, host_->ptr[src_reg], j);
                    break;
                case data_type::bf16:
                    assert(bf16_supported_ && "Unsupported data type.");
                    host_->vpinsrw(
                            xmm_dst, xmm_dst, host_->ptr[src_reg], j * 2);
                    break;
                case data_type::s8:
                case data_type::u8: {
                    host_->vpinsrb(xmm_dst, xmm_dst, host_->ptr[src_reg],
                            i * xmm_size_elem + j);
                    break;
                }
                default: assert(!"Unsupported data type.");
            }
            host_->mov(src_reg, gather_conf_->reg_tmp1_);
        }

        if (data_type_ == data_type::f32 || data_type_ == data_type::s32) {
            host_->vinsertf128(dst_vmm, dst_vmm, xmm_dst, i);
        }
    }

    if (data_type_ == data_type::s32 || data_type_ == data_type::bf16)
        convert_to_f32(dst_vmm, dst_vmm, data_type_);
    else if (utils::one_of(
                     data_type_, data_type::s8, data_type::u8, data_type::f16))
        convert_to_f32(dst_vmm, xmm_dst, data_type_);
}

template <>
void jit_io_helper_t<Xbyak::Xmm>::emu_gather(const Xbyak::Reg64 &src_reg,
        const Xbyak::Xmm &indices_vmm, const Xbyak::Xmm &dst_vmm,
        const bool tail) {
    assert(gather_conf_.has_value()
            && "Config for loading with the use of gather instruction is not "
               "set.");
    assert(IMPLICATION(tail, tail_conf_.has_value())
            && "Config for tail processing is not set.");

    host_->mov(gather_conf_->reg_tmp_, 0);
    host_->mov(gather_conf_->reg_tmp1_, src_reg);

    // The bf16->f32 conversion here is split into two parts.
    // While loading bf16 words here, the words are interleaved with zeros,
    // and in convert_to_f32 they are shifted left to finally become f32.
    // For f16 we do not need such interleaving.
    const unsigned xmm_size_elem = (data_type_ == data_type::f16) ? 8 : 4;

    const unsigned number_of_values_to_load
            = tail ? tail_conf_->tail_size_ : xmm_size_elem;
    for (unsigned j = 0; j < number_of_values_to_load; j++) {
        host_->pextrd(gather_conf_->reg_tmp_.cvt32(), indices_vmm, j);
        host_->add(src_reg, gather_conf_->reg_tmp_);
        switch (data_type_) {
            case data_type::f32:
            case data_type::s32: {
                host_->pinsrd(dst_vmm, host_->ptr[src_reg], j);
                break;
            }
            case data_type::f16:
                assert(f16_supported_ && "Unsupported data type.");
                host_->pinsrw(dst_vmm, host_->ptr[src_reg], j);
                break;
            case data_type::bf16:
                assert(bf16_supported_ && "Unsupported data type.");
                host_->pinsrw(dst_vmm, host_->ptr[src_reg], j * 2);
                break;
            case data_type::s8:
            case data_type::u8: {
                host_->pinsrb(dst_vmm, host_->ptr[src_reg], j);
                break;
            }
            default: assert(!"Unsupported data type.");
        }
        host_->mov(src_reg, gather_conf_->reg_tmp1_);
    }

    if (data_type_ != data_type::f32)
        convert_to_f32(dst_vmm, dst_vmm, data_type_);
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::prepare_tail_mask() {
    assert(tail_conf_.has_value() && "Config for tail processing is not set.");

    if (!tail_conf_->tail_size_) return;

    if (is_superset(isa_, avx512_core))
        prepare_opmask(tail_conf_->tail_size_, tail_conf_->reg_tmp_,
                tail_conf_->tail_opmask_);
    else if (is_superset(isa_, sse41))
        prepare_vmm_mask(tail_conf_->tail_size_, tail_conf_->simd_w_,
                tail_conf_->reg_tmp_, Vmm(tail_conf_->tail_vmm_mask_idx_));
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::prepare_full_mask() {
    assert(gather_conf_.has_value()
            && "Config for loading with the use of gather instruction is not "
               "set.");

    if (utils::one_of(data_type_, data_type::f16, data_type::bf16,
                data_type::s8, data_type::u8))
        return;

    if (is_superset(isa_, avx512_core))
        prepare_opmask(gather_conf_->simd_w_, gather_conf_->reg_tmp_,
                gather_conf_->full_opmask_);
    else if (isa_ == avx2)
        prepare_vmm_mask(gather_conf_->simd_w_, gather_conf_->simd_w_,
                gather_conf_->reg_tmp_, Vmm(gather_conf_->full_vmm_mask_idx_));
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::init_full_mask() {
    assert(gather_conf_.has_value()
            && "Config for loading with the use of gather instruction is not "
               "set.");

    if (isa_ == avx2) {
        const Vmm vmm_mask = Vmm(gather_conf_->full_vmm_mask_idx_);
        host_->uni_vxorps(vmm_mask, vmm_mask, vmm_mask);
    }
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::init_saturate_f32() const {
    assert(saturation_conf_.has_value() && "Config for saturation is not set.");

    if (utils::one_of(
                data_type_, data_type::u8, data_type::s8, data_type::s32))
        host_->init_saturate_f32(
                Vmm(saturation_conf_->vreg_zero_saturation_idx_),
                Vmm(saturation_conf_->vreg_saturation_ubound_idx_),
                saturation_conf_->reg_tmp_, data_type::f32, data_type_);
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::gather(const Xbyak::Reg64 &src_reg,
        const Vmm &indices_vmm, const Vmm &dst_vmm, const bool tail) {
    assert(gather_conf_.has_value()
            && "Config for loading with the use of gather instruction is not "
               "set.");
    assert(IMPLICATION(tail, tail_conf_.has_value())
            && "Config for tail processing is not set.");

    const Vmm &mask = tail ? Vmm(tail_conf_->tail_vmm_mask_idx_)
                           : Vmm(gather_conf_->full_vmm_mask_idx_);

    const Vmm dst_vmm_with_mask = tail ? dst_vmm | tail_conf_->tail_opmask_
                                       : dst_vmm | gather_conf_->full_opmask_;

    const bool can_use_gather_instruction
            = isa_ == avx2 || is_superset(isa_, avx512_core);

    if ((data_type_ == data_type::f32 || data_type_ == data_type::s32)
            && can_use_gather_instruction) {
        if (data_type_ == data_type::f32) {
            if (isa_ == avx2)
                host_->vgatherdps(
                        dst_vmm, host_->ptr[src_reg + indices_vmm], mask);
            else
                host_->vgatherdps(
                        dst_vmm_with_mask, host_->ptr[src_reg + indices_vmm]);
        } else {
            if (isa_ == avx2)
                host_->vpgatherdd(
                        dst_vmm, host_->ptr[src_reg + indices_vmm], mask);
            else
                host_->vpgatherdd(
                        dst_vmm_with_mask, host_->ptr[src_reg + indices_vmm]);
            convert_to_f32(dst_vmm, dst_vmm, data_type_);
        }

        // The processing mask has to be restored after the gather because
        // the instruction zeroes it.
        if (tail)
            prepare_tail_mask();
        else
            prepare_full_mask();
    } else {
        emu_gather(src_reg, indices_vmm, dst_vmm, tail);
    }
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::load(const Xbyak::Address &src_addr,
        const Vmm &dst_raw_vmm, const bool tail) {
    assert(IMPLICATION(tail, tail_conf_.has_value())
            && "Config for tail processing is not set.");

    const bool is_avx512 = is_superset(isa_, avx512_core);

    const auto dst_vmm = tail && is_avx512
            ? (dst_raw_vmm | tail_conf_->tail_opmask_ | host_->T_z)
            : dst_raw_vmm;

    const bool is_i8 = utils::one_of(data_type_, data_type::s8, data_type::u8);
    const bool is_xf16
            = utils::one_of(data_type_, data_type::bf16, data_type::f16);
    const bool is_tail_load_supported = is_avx512;
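    // Fall back to scalar (byte-by-byte) tail loads when no suitable masking
    // exists: sse41 has no vmaskmovps, and below avx512 the per-dword vmm
    // masks cannot express sub-dword (i8/xf16) elements.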
    const bool can_load_byte_by_byte = tail
            && (isa_ == sse41
                    || (!is_tail_load_supported && (is_i8 || is_xf16)));

    if (can_load_byte_by_byte) {
        load_byte_by_byte(src_addr, dst_vmm, tail_conf_->tail_size_);
    } else {
        switch (data_type_) {
            case data_type::f32: load_f32(src_addr, dst_vmm, tail); break;
            case data_type::s32: load_s32(src_addr, dst_vmm, tail); break;
            case data_type::bf16: load_bf16(src_addr, dst_vmm); break;
            case data_type::f16: load_f16(src_addr, dst_vmm); break;
            case data_type::s8:
            case data_type::u8: load_i8(src_addr, dst_vmm); break;
            default: assert(!"Unsupported data type.");
        }
    }
}

template <>
void jit_io_helper_t<Xbyak::Zmm>::load_byte_by_byte(
        const Xbyak::Address &src_addr, const Xbyak::Zmm &dst_vmm,
        const int load_size) {
    assert(!"Load byte by byte is not supported for Zmms.");
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::load_byte_by_byte(const Xbyak::Address &src_addr,
        const Vmm &dst_vmm, const int load_size) {
    host_->uni_vxorps(dst_vmm, dst_vmm, dst_vmm);
    host_->load_data(data_type_, dst_vmm, src_addr, load_size);

    if (utils::one_of(data_type_, data_type::s32, data_type::s8, data_type::u8))
        convert_to_f32(dst_vmm, dst_vmm, data_type::s32);
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::load_f32(
        const Xbyak::Address &src_addr, const Vmm &dst_vmm, const bool tail) {
    if (tail && !is_superset(isa_, avx512_core))
        host_->vmaskmovps(
                dst_vmm, Vmm(tail_conf_->tail_vmm_mask_idx_), src_addr);
    else
        host_->uni_vmovups(dst_vmm, src_addr);
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::load_s32(
        const Xbyak::Address &src_addr, const Vmm &dst_vmm, const bool tail) {
    if (is_superset(isa_, avx512_core))
        host_->uni_vcvtdq2ps(dst_vmm, src_addr);
    else {
        load_f32(src_addr, dst_vmm, tail);
        convert_to_f32(dst_vmm, dst_vmm, data_type::s32);
    }
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::load_bf16(
        const Xbyak::Address &src_addr, const Vmm &dst_vmm) {
    assert(bf16_supported_ && "Unsupported data type.");

    host_->vpmovzxwd(dst_vmm, src_addr);
    convert_to_f32(dst_vmm, dst_vmm, data_type::bf16);
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::load_f16(
        const Xbyak::Address &src_addr, const Vmm &dst_vmm) {
    assert(f16_supported_ && "Unsupported data type.");
    host_->uni_vcvtph2psx(dst_vmm, src_addr);
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::load_i8(
        const Xbyak::Address &src_addr, const Vmm &dst_vmm) {
    if (data_type_ == data_type::s8)
        host_->uni_vpmovsxbd(dst_vmm, src_addr);
    else
        host_->uni_vpmovzxbd(dst_vmm, src_addr);

    convert_to_f32(dst_vmm, dst_vmm, data_type::s32);
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::load_two_simdw_xf16(const Xbyak::Address &src_addr,
        const Vmm &dst_even_vmm, const Vmm &dst_odd_vmm) {
    // The outputs are in odd/even interleaved layout. Currently only
    // bf16/f16 without tail on AVX2_VNNI_2 is supported.
    assert(utils::one_of(data_type_, data_type::bf16, data_type::f16)
            && isa_ == avx2_vnni_2 && "Unsupported data type.");

    if (data_type_ == data_type::bf16) {
        host_->vcvtneebf162ps(dst_even_vmm, src_addr);
        host_->vcvtneobf162ps(dst_odd_vmm, src_addr);
    } else {
        host_->vcvtneeph2ps(dst_even_vmm, src_addr);
        host_->vcvtneoph2ps(dst_odd_vmm, src_addr);
    }
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::merge_interleaved_to_plain(
        const Vmm &vmm_even, const Vmm &vmm_odd, const Vmm &vmm_aux0) {
    // Merge inputs in odd/even interleaved layout into plain layout.
    assert(vmm_even.isYMM() && vmm_odd.isYMM()
            && "Merge interleaved to plain only supports Ymms");
    Xbyak::Ymm ymm_even = Xbyak::Ymm(vmm_even.getIdx());
    Xbyak::Ymm ymm_odd = Xbyak::Ymm(vmm_odd.getIdx());
    Xbyak::Ymm ymm_aux0 = Xbyak::Ymm(vmm_aux0.getIdx());
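    // Note: ymm_aux1 deliberately aliases vmm_odd; this is safe because
    // vpunpckhdq reads ymm_odd in the same instruction that overwrites it,
    // and ymm_odd is only written again by the final vperm2i128.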
    Xbyak::Ymm ymm_aux1 = Xbyak::Ymm(vmm_odd.getIdx());

    host_->vpunpckldq(ymm_aux0, ymm_even, ymm_odd);
    host_->vpunpckhdq(ymm_aux1, ymm_even, ymm_odd);
    host_->vperm2i128(ymm_even, ymm_aux0, ymm_aux1, 0x20);
    host_->vperm2i128(ymm_odd, ymm_aux0, ymm_aux1, 0x31);
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::store(const Vmm &src_raw_vmm,
        const Xbyak::Address &dst_raw_addr, const bool tail) {
    assert(IMPLICATION(tail, tail_conf_.has_value())
            && "Config for tail processing is not set.");
    assert(!(tail && io_conf_.nt_stores_enabled_)
            && "Usage of non-temporal stores with tail leads to a "
               "general-protection exception.");

    const bool is_avx512 = is_superset(isa_, avx512_core);

    const auto dst_addr = tail && is_avx512
            ? (dst_raw_addr | tail_conf_->tail_opmask_)
            : dst_raw_addr;
    const auto src_vmm = tail && is_avx512
            ? (src_raw_vmm | tail_conf_->tail_opmask_)
            : src_raw_vmm;

    const bool is_store_tail_supported = is_avx512;
    const bool is_i8 = utils::one_of(data_type_, data_type::s8, data_type::u8);
    const bool is_xf16
            = utils::one_of(data_type_, data_type::bf16, data_type::f16);

    const bool can_store_byte_by_byte = tail
            && (isa_ == sse41
                    || (!is_store_tail_supported && (is_i8 || is_xf16)));

    if (data_type_ == data_type::s32 || is_i8) saturate(src_raw_vmm);

    if (can_store_byte_by_byte) {
        const size_t store_size
                = tail_conf_->tail_size_ * types::data_type_size(data_type_);
        store_byte_by_byte(src_vmm, dst_addr, store_size);
    } else {
        switch (data_type_) {
            case data_type::f32:
            case data_type::s32: store_f32(src_vmm, dst_addr, tail); break;
            case data_type::bf16: store_bf16(src_vmm, dst_addr); break;
            case data_type::f16: store_f16(src_vmm, dst_addr); break;
            case data_type::s8:
            case data_type::u8: store_i8(src_vmm, dst_raw_addr); break;
            default: assert(!"Unsupported data type.");
        }
    }
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::saturate(const Vmm &vmm) {
    assert(saturation_conf_.has_value() && "Config for saturation is not set.");

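    // Clamp the f32 values to the representable range of the destination
    // integer data type, then convert them to s32 (rounding controlled by
    // MXCSR).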
    host_->saturate_f32(vmm, Vmm(saturation_conf_->vreg_zero_saturation_idx_),
            Vmm(saturation_conf_->vreg_saturation_ubound_idx_), data_type_);
    host_->uni_vcvtps2dq(vmm, vmm);
}

template <>
void jit_io_helper_t<Xbyak::Zmm>::store_byte_by_byte(const Xbyak::Zmm &src_zmm,
        const Xbyak::Address &dst_addr, const int store_size) {
    assert(!"Store byte by byte is not supported for Zmms.");
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::store_byte_by_byte(const Vmm &src_vmm,
        const Xbyak::Address &dst_addr, const int store_size) {
    const bool is_i8 = utils::one_of(data_type_, data_type::s8, data_type::u8);
    const bool is_xf16
            = utils::one_of(data_type_, data_type::bf16, data_type::f16);
    const auto &cvt_lower_vmm =
            typename vreg_traits<Vmm>::Vmm_lower_t(src_vmm.getIdx());

    if (is_i8) prepare_i8_data_to_store(src_vmm);
    if (is_xf16) prepare_xf16_data_to_store(src_vmm);

    host_->store_bytes(is_xf16 ? cvt_lower_vmm : src_vmm, dst_addr, store_size);
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::store_f32(
        const Vmm &src_vmm, const Xbyak::Address &dst_addr, const bool tail) {
    if (io_conf_.nt_stores_enabled_)
        host_->uni_vmovntps(dst_addr, src_vmm);
    else if (!is_superset(isa_, avx512_core) && tail)
        host_->vmaskmovps(
                dst_addr, Vmm(tail_conf_->tail_vmm_mask_idx_), src_vmm);
    else
        host_->uni_vmovups(dst_addr, src_vmm);
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::store_bf16(
        const Vmm &src_vmm, const Xbyak::Address &dst_addr) {
    assert(bf16_supported_ && "Unsupported data type.");
    assert((src_vmm.isZMM() || src_vmm.isYMM())
            && "Store operation for bf16 is not supported for Xmms.");

    const auto &cvt_lower_vmm =
            typename vreg_traits<Vmm>::Vmm_lower_t(src_vmm.getIdx());

    if (bf16_emu_)
        bf16_emu_->vcvtneps2bf16(cvt_lower_vmm, src_vmm);
    else
        host_->vcvtneps2bf16(cvt_lower_vmm, src_vmm,
                mayiuse(avx512_core) ? Xbyak::EvexEncoding
                                     : Xbyak::VexEncoding);

    if (io_conf_.nt_stores_enabled_)
        host_->uni_vmovntps(dst_addr, cvt_lower_vmm);
    else
        host_->uni_vmovdqu16(dst_addr, cvt_lower_vmm);
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::store_f16(
        const Vmm &src_vmm, const Xbyak::Address &dst_addr) {
    assert(f16_supported_ && "Unsupported data type.");
    assert((src_vmm.isZMM() || src_vmm.isYMM())
            && "Store operation for f16 is not supported for Xmms.");

    const auto &cvt_lower_vmm =
            typename vreg_traits<Vmm>::Vmm_lower_t(src_vmm.getIdx());

    host_->uni_vcvtps2phx(cvt_lower_vmm, src_vmm);

    if (io_conf_.nt_stores_enabled_)
        host_->uni_vmovntps(dst_addr, cvt_lower_vmm);
    else
        host_->uni_vmovdqu16(dst_addr, cvt_lower_vmm);
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::store_i8(
        const Vmm &src_vmm, const Xbyak::Address &dst_addr) {
    if (!is_superset(isa_, avx512_core)) {
        static constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;

        prepare_i8_data_to_store(src_vmm);
        if (is_ymm)
            host_->uni_vmovq(dst_addr, Xbyak::Xmm(src_vmm.getIdx()));
        else
            host_->uni_vmovd(dst_addr, src_vmm);
    } else {
        using namespace std::placeholders;
        static constexpr bool is_zmm = std::is_same<Vmm, Xbyak::Zmm>::value;

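        // Select the signed or unsigned saturating 32-bit -> 8-bit
        // down-conversion; std::bind yields one callable usable with both
        // memory and register destinations.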
        auto store_i8_fn = data_type_ == data_type::s8
                ? std::bind(&jit_generator::vpmovsdb, host_, _1, _2)
                : std::bind(&jit_generator::vpmovusdb, host_, _1, _2);

        if (io_conf_.nt_stores_enabled_ && is_zmm) {
            Xbyak::Xmm src_xmm(src_vmm.getIdx());
            store_i8_fn(src_xmm, src_vmm);
            host_->uni_vmovntps(dst_addr, src_xmm);
        } else {
            store_i8_fn(dst_addr, src_vmm);
        }
    }
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::convert_to_f32(const Vmm &dst_vmm,
        const Xbyak::Xmm &src_vmm, const data_type_t src_data_type) {
    switch (src_data_type) {
        case data_type::s32: {
            assert(dst_vmm.getIdx() == src_vmm.getIdx());
            host_->uni_vcvtdq2ps(dst_vmm, dst_vmm);
            break;
        }
        case data_type::bf16:
            assert(bf16_supported_ && "Unsupported data type.");
            host_->vpslld(dst_vmm, src_vmm, 0x10);
            break;
        case data_type::f16:
            assert(f16_supported_ && "Unsupported data type.");
            host_->vcvtph2ps(dst_vmm, src_vmm);
            break;
        case data_type::s8: {
            host_->uni_vpmovsxbd(dst_vmm, src_vmm);
            host_->uni_vcvtdq2ps(dst_vmm, dst_vmm);
            break;
        }
        case data_type::u8: {
            host_->uni_vpmovzxbd(dst_vmm, src_vmm);
            host_->uni_vcvtdq2ps(dst_vmm, dst_vmm);
            break;
        }
        default: assert(!"Unsupported data type.");
    }
}

template <typename Vmm>
void jit_io_helper_t<Vmm>::broadcast(
        const Xbyak::Address &src_addr, const Vmm &dst_vmm) {
    switch (data_type_) {
        case data_type::f32: host_->uni_vbroadcastss(dst_vmm, src_addr); break;
        case data_type::bf16:
            assert(bf16_supported_ && "Unsupported data type.");
            host_->vpbroadcastw(dst_vmm, src_addr);
            convert_to_f32(dst_vmm, dst_vmm, data_type_);
            break;
        case data_type::f16:
            assert(f16_supported_ && "Unsupported data type.");
            host_->uni_vcvtph2psx(dst_vmm, host_->ptr_b[src_addr.getRegExp()]);
            break;
        case data_type::s32: {
            if (is_superset(isa_, avx512_core)) {
                host_->uni_vcvtdq2ps(
                        dst_vmm, host_->ptr_b[src_addr.getRegExp()]);
            } else {
                host_->uni_vbroadcastss(dst_vmm, src_addr);
                convert_to_f32(dst_vmm, dst_vmm, data_type_);
            }
            break;
        }
        case data_type::s8:
        case data_type::u8: {
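            // Load one byte into lane 0, widen and convert it to f32 in
            // place, then broadcast the converted value to all lanes.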
            const Xbyak::Xmm dst_xmm {dst_vmm.getIdx()};
            host_->uni_vpinsrb(dst_xmm, dst_xmm, src_addr, 0);
            convert_to_f32(dst_vmm, dst_vmm, data_type_);
            host_->uni_vbroadcastss(dst_vmm, dst_xmm);

            break;
        }
        default: assert(!"Unsupported data type.");
    }
}

template <typename Vmm>
jit_io_multi_dt_helper_t<Vmm>::jit_io_multi_dt_helper_t(jit_generator *host,
        const cpu_isa_t &isa, const data_types_t &data_types,
        const io_conf_t &io_conf,
        const utils::optional_t<io_tail_conf_t> &tail_conf,
        const utils::optional_t<io_emu_bf16_conf_t> &bf16_conf,
        const std::map<data_type_t, io_saturation_conf_t> &saturation_confs,
        const utils::optional_t<io_gather_conf_t> &gather_conf) {
    assert(!data_types.empty());
    for (const auto &dt : data_types) {
        // can be replaced by try_emplace from C++17
        if (storage_.find(dt) == storage_.cend()) {

            const auto saturation_conf = saturation_confs.find(dt);
            const bool store_saturation_needed
                    = saturation_conf != saturation_confs.cend();

            storage_.emplace(dt,
                    std::make_shared<jit_io_helper_t<Vmm>>(host, isa, dt,
                            io_conf, tail_conf,
                            dt == data_type::bf16 ? bf16_conf : utils::nullopt,
                            store_saturation_needed
                                    ? utils::optional_t<io_saturation_conf_t> {
                                            saturation_conf->second}
                                    : utils::nullopt,
                            gather_conf));
        }
    }
}

template <typename Vmm>
std::shared_ptr<jit_io_helper_t<Vmm>> jit_io_multi_dt_helper_t<Vmm>::at(
        const data_type_t dt) const {
    const auto it = storage_.find(dt);
    if (it != storage_.cend()) return it->second;
    return nullptr;
}

template <typename Vmm>
std::shared_ptr<jit_io_helper_t<Vmm>> jit_io_multi_dt_helper_t<Vmm>::operator[](
        const data_type_t dt) const {
    auto res = this->at(dt);
    if (res == nullptr) { assert(!"data not found in io"); }
    return res;
}

template <typename Vmm>
void jit_io_multi_dt_helper_t<Vmm>::prepare_tail_mask() {
    return storage_.cbegin()->second->prepare_tail_mask();
}

template <typename Vmm>
void jit_io_multi_dt_helper_t<Vmm>::prepare_full_mask() {
    return storage_.cbegin()->second->prepare_full_mask();
}

template <typename Vmm>
void jit_io_multi_dt_helper_t<Vmm>::init_saturate_f32(
        const data_types_t &store_data_types) {
    for (const auto &dt : store_data_types) {
        const auto it = storage_.find(dt);
        if (it != storage_.cend()) {
            if (it->second->saturation_conf_.has_value())
                it->second->init_saturate_f32();
        }
    }
}

template <typename Vmm>
void jit_io_multi_dt_helper_t<Vmm>::init_full_mask() {
    return storage_.cbegin()->second->init_full_mask();
}

template <typename Vmm>
void jit_io_multi_dt_helper_t<Vmm>::init_bf16() {
    const auto bf16_io_helper = at(data_type::bf16);
    if (bf16_io_helper) bf16_io_helper->init_bf16();
}

template <typename Vmm>
jit_io_multi_dt_helper_t<Vmm>::~jit_io_multi_dt_helper_t() = default;

template class jit_io_helper_t<Xbyak::Zmm>;
template class jit_io_helper_t<Xbyak::Ymm>;
template class jit_io_helper_t<Xbyak::Xmm>;

template class jit_io_multi_dt_helper_t<Xbyak::Zmm>;
template class jit_io_multi_dt_helper_t<Xbyak::Ymm>;
template class jit_io_multi_dt_helper_t<Xbyak::Xmm>;

} // namespace io
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl