1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cassert> |
18 | |
19 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
20 | #include "cpu/x64/utils/jit_io_helper.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | namespace x64 { |
26 | namespace io { |
27 | |
// General I/O configuration: currently only records whether non-temporal
// (streaming) stores should be emitted by store routines.
io_conf_t::io_conf_t(const bool nt_stores_enabled)
    : nt_stores_enabled_(nt_stores_enabled) {}
30 | |
// Tail-processing configuration taking an explicit opmask register.
// simd_w: full vector width in elements; tail_size: number of valid
// elements in the last (partial) vector; tail_opmask/tail_vmm_mask_idx:
// mask holders for the AVX512 and pre-AVX512 paths respectively;
// reg_tmp: scratch GPR used when materializing the masks.
io_tail_conf_t::io_tail_conf_t(const std::size_t simd_w,
        const std::size_t tail_size, const Xbyak::Opmask &tail_opmask,
        const int tail_vmm_mask_idx, const Xbyak::Reg64 &reg_tmp)
    : simd_w_(simd_w)
    , tail_size_(tail_size)
    , tail_opmask_(tail_opmask)
    , tail_vmm_mask_idx_(tail_vmm_mask_idx)
    , reg_tmp_(reg_tmp) {}
39 | |
// Convenience overload: same as above but takes the opmask by register
// index instead of an Xbyak::Opmask object.
io_tail_conf_t::io_tail_conf_t(const std::size_t simd_w,
        const std::size_t tail_size, int tail_opmask_idx,
        const int tail_vmm_mask_idx, const Xbyak::Reg64 &reg_tmp)
    : simd_w_(simd_w)
    , tail_size_(tail_size)
    , tail_opmask_(Xbyak::Opmask(tail_opmask_idx))
    , tail_vmm_mask_idx_(tail_vmm_mask_idx)
    , reg_tmp_(reg_tmp) {}
48 | |
// Registers reserved for the software bf16 down-conversion sequence
// (bf16_emulation_t) used on ISAs without native f32->bf16 converts.
io_emu_bf16_conf_t::io_emu_bf16_conf_t(const Xbyak::Zmm &bf16_emu_reserv_1,
        const Xbyak::Zmm &bf16_emu_reserv_2,
        const Xbyak::Zmm &bf16_emu_reserv_3, const Xbyak::Reg64 &reg_tmp,
        const Xbyak::Zmm &bf16_emu_reserv_4)
    : bf16_emu_reserv_1_(bf16_emu_reserv_1)
    , bf16_emu_reserv_2_(bf16_emu_reserv_2)
    , bf16_emu_reserv_3_(bf16_emu_reserv_3)
    , reg_tmp_(reg_tmp)
    , bf16_emu_reserv_4_(bf16_emu_reserv_4) {}
58 | |
// Convenience overload: reserved Zmm registers passed by index.
io_emu_bf16_conf_t::io_emu_bf16_conf_t(int bf16_emu_reserv_1_idx,
        int bf16_emu_reserv_2_idx, int bf16_emu_reserv_3_idx,
        const Xbyak::Reg64 &reg_tmp, int bf16_emu_reserv_4_idx)
    : bf16_emu_reserv_1_(Xbyak::Zmm(bf16_emu_reserv_1_idx))
    , bf16_emu_reserv_2_(Xbyak::Zmm(bf16_emu_reserv_2_idx))
    , bf16_emu_reserv_3_(Xbyak::Zmm(bf16_emu_reserv_3_idx))
    , reg_tmp_(reg_tmp)
    , bf16_emu_reserv_4_(Xbyak::Zmm(bf16_emu_reserv_4_idx)) {}
67 | |
// Saturation configuration: indices of the vregs that will hold the lower
// (zero) and upper saturation bounds for f32->int conversion, plus a
// scratch GPR used to initialize them.
io_saturation_conf_t::io_saturation_conf_t(const int vreg_zero_saturation_idx,
        const int vreg_saturation_ubound_idx, const Xbyak::Reg64 &reg_tmp)
    : vreg_zero_saturation_idx_(vreg_zero_saturation_idx)
    , vreg_saturation_ubound_idx_(vreg_saturation_ubound_idx)
    , reg_tmp_(reg_tmp) {}
73 | |
// Gather configuration: full_opmask/full_vmm_mask_idx hold the all-lanes
// mask for hardware gathers (AVX512 / AVX2 respectively); reg_tmp and
// reg_tmp1 are scratch GPRs for the emulated-gather path; vmm_tmp_idx is
// an optional scratch vreg required by some emu_gather specializations.
io_gather_conf_t::io_gather_conf_t(const std::size_t simd_w,
        const Xbyak::Opmask &full_opmask, const int full_vmm_mask_idx,
        const Xbyak::Reg64 &reg_tmp, const Xbyak::Reg64 &reg_tmp1,
        const utils::optional_t<int> &vmm_tmp_idx)
    : simd_w_(simd_w)
    , full_opmask_(full_opmask)
    , full_vmm_mask_idx_(full_vmm_mask_idx)
    , reg_tmp_(reg_tmp)
    , reg_tmp1_(reg_tmp1)
    , vmm_tmp_idx_(vmm_tmp_idx) {}
84 | |
// Main helper constructor. Validates that the requested data type is
// supported on the target ISA and, for bf16 on plain avx512_core (which
// lacks native bf16 converts), instantiates the software emulation
// sequence using the registers reserved in bf16_conf.
template <typename Vmm>
jit_io_helper_t<Vmm>::jit_io_helper_t(jit_generator *host, const cpu_isa_t &isa,
        const data_type_t &data_type, const io_conf_t &io_conf,
        const utils::optional_t<io_tail_conf_t> &tail_conf,
        const utils::optional_t<io_emu_bf16_conf_t> &bf16_conf,
        const utils::optional_t<io_saturation_conf_t> &saturation_conf,
        const utils::optional_t<io_gather_conf_t> &gather_conf)
    : host_(host)
    , isa_(isa)
    , data_type_(data_type)
    , bf16_supported_(is_data_type_supported(data_type::bf16))
    , f16_supported_(is_data_type_supported(data_type::f16))
    , bf16_emu_(nullptr)
    , io_conf_(io_conf)
    , tail_conf_(tail_conf)
    , bf16_conf_(bf16_conf)
    , saturation_conf_(saturation_conf)
    , gather_conf_(gather_conf) {

    // bf16 emulation only on exactly avx512_core; later ISAs (and
    // avx2_vnni_2) provide native conversion instructions.
    if (data_type_ == data_type::bf16 && isa == avx512_core) {
        assert(bf16_conf.has_value()
                && "Config for bf16 emulation is not set.");
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(host_,
                bf16_conf->bf16_emu_reserv_1_, bf16_conf->bf16_emu_reserv_2_,
                bf16_conf->bf16_emu_reserv_3_, bf16_conf->reg_tmp_,
                bf16_conf->bf16_emu_reserv_4_);
    }

    assert(utils::one_of(data_type_, data_type::f16, data_type::bf16,
                   data_type::f32, data_type::s8, data_type::u8, data_type::s32)
            && is_data_type_supported(data_type_)
            && "Supported data types f16, bf16, f32, s8, u8, s32");

    /*
     * vpmovsxbd, vpmovzxbd for AVX are defined only for XMM. Since AVX2
     * they are defined also for YMM. In order to avoid workaround with
     * potential performance penalty AVX with s8u8 disabled with YMM.
     */
    static constexpr bool is_xmm = std::is_same<Vmm, Xbyak::Xmm>::value;
    const bool is_avx_u8s8 = (isa_ == avx
            && utils::one_of(data_type_, data_type::s8, data_type::u8));
    MAYBE_UNUSED(is_xmm);
    MAYBE_UNUSED(is_avx_u8s8);

    assert(IMPLICATION(is_avx_u8s8, is_xmm)
            && "s8u8 with AVX should be used with XMM vreg");

    // Zmm instantiations are only legal on AVX512-capable targets.
    static constexpr bool is_zmm = std::is_same<Vmm, Xbyak::Zmm>::value;
    MAYBE_UNUSED(is_zmm);
    assert(IMPLICATION(!is_superset(isa_, avx512_core), !is_zmm)
            && "This architecture does not support zmms.");
}
137 | |
// Out-of-line defaulted destructor — presumably required so that the
// unique_ptr<bf16_emulation_t> member is destroyed where the type is
// complete (this TU includes jit_avx512_core_bf16cvt.hpp); TODO confirm.
template <typename Vmm>
jit_io_helper_t<Vmm>::~jit_io_helper_t() = default;
140 | |
141 | template <typename Vmm> |
142 | bool jit_io_helper_t<Vmm>::is_data_type_supported(const data_type_t dt) { |
143 | switch (dt) { |
144 | case data_type::f32: |
145 | case data_type::s32: |
146 | case data_type::u8: |
147 | case data_type::s8: return true; |
148 | case data_type::bf16: |
149 | return is_superset(isa_, avx512_core) || isa_ == avx2_vnni_2; |
150 | case data_type::f16: |
151 | return is_superset(isa_, avx512_core_fp16) || isa_ == avx2_vnni_2; |
152 | default: assert(!"Unsupported data type" ); |
153 | } |
154 | return false; |
155 | } |
156 | |
157 | template <typename Vmm> |
158 | void jit_io_helper_t<Vmm>::init_bf16() { |
159 | if (bf16_emu_) { |
160 | assert(bf16_conf_.has_value() |
161 | && "Config for bf16 emulation is not set." ); |
162 | bf16_emu_->init_vcvtneps2bf16(); |
163 | } |
164 | } |
165 | |
166 | template <typename Vmm> |
167 | void jit_io_helper_t<Vmm>::prepare_opmask( |
168 | const std::size_t how_many_bits_to_set, const Xbyak::Reg64 ®_tmp, |
169 | const Xbyak::Opmask &mask) { |
170 | const int mask_f32 = (1 << how_many_bits_to_set) - 1; |
171 | const Xbyak::Reg32 regw_tmp = reg_tmp.cvt32(); |
172 | host_->mov(regw_tmp, mask_f32); |
173 | host_->kmovw(mask, regw_tmp); |
174 | } |
175 | |
// Builds a per-lane vector mask with the first `how_many_bits_to_set`
// dword lanes all-ones, for ISAs without opmask registers.
template <typename Vmm>
void jit_io_helper_t<Vmm>::prepare_vmm_mask(
        const std::size_t how_many_bits_to_set, const std::size_t simd_w,
        const Xbyak::Reg64 &reg_tmp, const Vmm &mask) {
    // 7 all-ones dwords followed by 7 zeros: loading simd_w consecutive
    // dwords starting at offset (7 - how_many_bits_to_set) yields exactly
    // how_many_bits_to_set set lanes. The table only supports partial
    // masks of up to 7 lanes; full masks take the vcmpps branch below.
    static const uint32_t mask_f32[14]
            = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                    0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0};

    if (how_many_bits_to_set < simd_w) {
        // Partial mask: load a shifted window of the static table.
        host_->mov(reg_tmp,
                reinterpret_cast<size_t>(&mask_f32[7 - how_many_bits_to_set]));
        host_->uni_vmovups(mask, host_->ptr[reg_tmp]);
    } else if (how_many_bits_to_set == simd_w) {
        // Full mask via self-comparison (no memory access). NOTE(review):
        // _cmp_eq_oq is false for NaN lanes, so this assumes the mask
        // register does not hold NaN payloads — TODO confirm callers.
        host_->uni_vcmpps(mask, mask, mask, jit_generator::_cmp_eq_oq);
    } else {
        assert(!"Can't set so many bits.");
    }
}
194 | |
// Narrows s32 lanes in i8_vmm down to saturated s8/u8 bytes packed in the
// low part of the register, padding with the zero-saturation register.
template <typename Vmm>
void jit_io_helper_t<Vmm>::prepare_i8_data_to_store(const Vmm &i8_vmm) {
    assert(saturation_conf_.has_value() && "Config for saturation is not set.");

    static constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value;

    // s32 -> s16 with signed saturation; packing against zeros leaves the
    // upper half of each 128-bit lane zeroed.
    host_->uni_vpackssdw(
            i8_vmm, i8_vmm, Vmm(saturation_conf_->vreg_zero_saturation_idx_));
    if (is_ymm) {
        // Pack instructions operate per 128-bit lane, so for Ymm the
        // useful quadwords are interleaved with zeros; vpermq(0x58) moves
        // them next to each other in the low half:
        // dst[63:0] = src[63:0]
        // dst[127:64] = src[191:128]
        // dst[191:128] = src[127:64]
        // dst[255:192] = src[127:64]
        const auto src_ymm = Xbyak::Ymm(i8_vmm.getIdx());
        host_->vpermq(src_ymm, src_ymm, 0x58);
    }

    // s16 -> 8-bit with signed (s8) or unsigned (u8) saturation.
    if (data_type_ == data_type::s8)
        host_->uni_vpacksswb(i8_vmm, i8_vmm,
                Vmm(saturation_conf_->vreg_zero_saturation_idx_));
    else
        host_->uni_vpackuswb(i8_vmm, i8_vmm,
                Vmm(saturation_conf_->vreg_zero_saturation_idx_));
}
219 | |
// Down-converts f32 lanes to bf16/f16 into the lower half of the same
// register, ahead of a partial (byte-by-byte) store. Pre-AVX512 only.
template <typename Vmm>
void jit_io_helper_t<Vmm>::prepare_xf16_data_to_store(const Vmm &vmm) {
    assert(!is_superset(isa_, avx512_core));
    // The narrowed result lives in the lower-width view of `vmm`.
    const auto &cvt_lower_vmm =
            typename vreg_traits<Vmm>::Vmm_lower_t(vmm.getIdx());

    if (data_type_ == data_type::bf16)
        // VEX encoding selects the AVX2-NE-CONVERT form on non-AVX512 ISAs.
        host_->vcvtneps2bf16(cvt_lower_vmm, vmm, Xbyak::VexEncoding);
    else
        host_->uni_vcvtps2phx(cvt_lower_vmm, vmm);
}
231 | |
232 | template <> |
233 | void jit_io_helper_t<Xbyak::Zmm>::emu_gather(const Xbyak::Reg64 &src_reg, |
234 | const Xbyak::Zmm &indices_vmm, const Xbyak::Zmm &dst_vmm, |
235 | const bool tail) { |
236 | assert(gather_conf_.has_value() && "Config for loading with the use of gather instruction is not set." ); |
237 | assert(gather_conf_->vmm_tmp_idx_.has_value() |
238 | && "Temporary vreg is not set." ); |
239 | assert(IMPLICATION(tail, tail_conf_.has_value()) |
240 | && "Config for tail processing is not set." ); |
241 | |
242 | const Xbyak::Xmm xmm_tmp = Xbyak::Xmm(gather_conf_->full_vmm_mask_idx_); |
243 | const Xbyak::Xmm xmm_dst = Xbyak::Xmm(*gather_conf_->vmm_tmp_idx_); |
244 | const Xbyak::Ymm dst_ymm; |
245 | |
246 | host_->mov(gather_conf_->reg_tmp_, 0); |
247 | host_->mov(gather_conf_->reg_tmp1_, src_reg); |
248 | |
249 | // The conversion of bf16->f32 here is split into two parts. |
250 | // Here while loading words of bf16, the words are interleaved, |
251 | // and in convert_to_f32, they are shifted-left to finally convert to f32 |
252 | // For f16 we do not need such interleaving. |
253 | const int xmm_size_elem = (data_type_ == data_type::f16) ? 8 : 4; |
254 | const int number_of_xmms = tail |
255 | ? utils::div_up(tail_conf_->tail_size_, xmm_size_elem) |
256 | : utils::div_up(gather_conf_->simd_w_, xmm_size_elem); |
257 | const int num_indices_in_xmm = 16 / sizeof(int); |
258 | for (int i = 0, idx = 0; i < number_of_xmms; i++) { |
259 | |
260 | const int number_of_values_to_load = i == number_of_xmms - 1 && tail |
261 | && tail_conf_->tail_size_ % xmm_size_elem != 0 |
262 | ? tail_conf_->tail_size_ % xmm_size_elem |
263 | : xmm_size_elem; |
264 | for (int j = 0; j < number_of_values_to_load; j++) { |
265 | |
266 | if (j % num_indices_in_xmm == 0) |
267 | host_->vextractf32x4(xmm_tmp, indices_vmm, idx++); |
268 | host_->vpextrd(gather_conf_->reg_tmp_.cvt32(), xmm_tmp, j); |
269 | host_->add(src_reg, gather_conf_->reg_tmp_); |
270 | switch (data_type_) { |
271 | case data_type::f16: |
272 | host_->vpinsrw(xmm_dst, xmm_dst, host_->ptr[src_reg], j); |
273 | break; |
274 | case data_type::bf16: |
275 | host_->vpinsrw( |
276 | xmm_dst, xmm_dst, host_->ptr[src_reg], j * 2); |
277 | break; |
278 | case data_type::s8: |
279 | case data_type::u8: |
280 | host_->vpinsrb(xmm_dst, xmm_dst, host_->ptr[src_reg], |
281 | i * xmm_size_elem + j); |
282 | break; |
283 | default: assert(!"Unsupported data type." ); |
284 | } |
285 | host_->mov(src_reg, gather_conf_->reg_tmp1_); |
286 | } |
287 | if (data_type_ == data_type::bf16) { |
288 | host_->vinsertf32x4(dst_vmm, dst_vmm, xmm_dst, i); |
289 | host_->vpxord(xmm_dst, xmm_dst, xmm_dst); |
290 | } else if (data_type_ == data_type::f16) { |
291 | host_->vinsertf32x4(dst_ymm, dst_ymm, xmm_dst, i); |
292 | host_->vpxord(xmm_dst, xmm_dst, xmm_dst); |
293 | } |
294 | } |
295 | |
296 | if (data_type_ == data_type::bf16) |
297 | convert_to_f32(dst_vmm, dst_vmm, data_type_); |
298 | else if (data_type_ == data_type::f16) |
299 | convert_to_f32(dst_vmm, dst_ymm, data_type_); |
300 | else if (data_type_ == data_type::s8 || data_type_ == data_type::u8) |
301 | convert_to_f32(dst_vmm, xmm_dst, data_type_); |
302 | } |
303 | |
// Emulated gather for Ymm: loads scattered elements one at a time using
// the dword offsets in indices_vmm and converts the result to f32.
template <>
void jit_io_helper_t<Xbyak::Ymm>::emu_gather(const Xbyak::Reg64 &src_reg,
        const Xbyak::Ymm &indices_vmm, const Xbyak::Ymm &dst_vmm,
        const bool tail) {
    assert(gather_conf_.has_value() && "Config for loading with the use of gather instruction is not set.");
    assert(gather_conf_->vmm_tmp_idx_.has_value()
            && "Temporary vreg is not set.");
    assert(IMPLICATION(tail, tail_conf_.has_value())
            && "Config for tail processing is not set.");

    const Xbyak::Xmm xmm_tmp = Xbyak::Xmm(gather_conf_->full_vmm_mask_idx_);
    const Xbyak::Xmm xmm_dst = Xbyak::Xmm(*gather_conf_->vmm_tmp_idx_);

    host_->mov(gather_conf_->reg_tmp_, 0);
    // Preserve the base pointer; src_reg is clobbered per element below.
    host_->mov(gather_conf_->reg_tmp1_, src_reg);

    // The conversion of bf16->f32 here is split into two parts.
    // Here while loading words of bf16, the words are interleaved,
    // and in convert_to_f32, they are shifted-left to finally convert to f32
    // For f16 we do not need such interleaving.
    const int xmm_size_elem = (data_type_ == data_type::f16) ? 8 : 4;
    const int number_of_xmms = tail
            ? utils::div_up(tail_conf_->tail_size_, xmm_size_elem)
            : utils::div_up(gather_conf_->simd_w_, xmm_size_elem);
    for (int i = 0; i < number_of_xmms; i++) {
        // Next 4 dword offsets.
        host_->vextractf128(xmm_tmp, indices_vmm, i);

        // Last chunk of a tail may hold fewer than xmm_size_elem values.
        const int number_of_values_to_load = i == number_of_xmms - 1 && tail
                        && tail_conf_->tail_size_ % xmm_size_elem != 0
                ? tail_conf_->tail_size_ % xmm_size_elem
                : xmm_size_elem;
        for (int j = 0; j < number_of_values_to_load; j++) {
            host_->vpextrd(gather_conf_->reg_tmp_.cvt32(), xmm_tmp, j);
            host_->add(src_reg, gather_conf_->reg_tmp_);
            switch (data_type_) {
                case data_type::f32:
                case data_type::s32: {
                    host_->vpinsrd(xmm_dst, xmm_dst, host_->ptr[src_reg], j);
                    break;
                }
                case data_type::f16:
                    assert(f16_supported_ && "Unsupported data type.");
                    host_->vpinsrw(xmm_dst, xmm_dst, host_->ptr[src_reg], j);
                    break;
                case data_type::bf16:
                    assert(bf16_supported_ && "Unsupported data type.");
                    // Insert at every other word slot: the interleaving is
                    // undone by the shift in convert_to_f32.
                    host_->vpinsrw(
                            xmm_dst, xmm_dst, host_->ptr[src_reg], j * 2);
                    break;
                case data_type::s8:
                case data_type::u8: {
                    host_->vpinsrb(xmm_dst, xmm_dst, host_->ptr[src_reg],
                            i * xmm_size_elem + j);
                    break;
                }
                default: assert(!"Unsupported data type.");
            }
            // Restore the base pointer for the next offset.
            host_->mov(src_reg, gather_conf_->reg_tmp1_);
        }

        // Dword types fill a full xmm per iteration; move it into place.
        if (data_type_ == data_type::f32 || data_type_ == data_type::s32) {
            host_->vinsertf128(dst_vmm, dst_vmm, xmm_dst, i);
        }
    }

    // f32 needs no conversion; the other types are widened to f32 here.
    if (data_type_ == data_type::s32 || data_type_ == data_type::bf16)
        convert_to_f32(dst_vmm, dst_vmm, data_type_);
    else if (utils::one_of(
                     data_type_, data_type::s8, data_type::u8, data_type::f16))
        convert_to_f32(dst_vmm, xmm_dst, data_type_);
}
375 | |
// Emulated gather for Xmm: loads scattered elements one at a time using
// the dword offsets in indices_vmm and converts the result to f32.
template <>
void jit_io_helper_t<Xbyak::Xmm>::emu_gather(const Xbyak::Reg64 &src_reg,
        const Xbyak::Xmm &indices_vmm, const Xbyak::Xmm &dst_vmm,
        const bool tail) {
    assert(gather_conf_.has_value() && "Config for loading with the use of gather instruction is not set.");
    assert(IMPLICATION(tail, tail_conf_.has_value())
            && "Config for tail processing is not set.");

    host_->mov(gather_conf_->reg_tmp_, 0);
    // Preserve the base pointer; src_reg is clobbered per element below.
    host_->mov(gather_conf_->reg_tmp1_, src_reg);

    // The conversion of bf16->f32 here is split into two parts.
    // Here while loading words of bf16, the words are interleaved,
    // and in convert_to_f32, they are shifted-left to finally convert to f32
    // For f16 we do not need such interleaving.
    const unsigned xmm_size_elem = (data_type_ == data_type::f16) ? 8 : 4;

    const unsigned number_of_values_to_load
            = tail ? tail_conf_->tail_size_ : xmm_size_elem;
    for (unsigned j = 0; j < number_of_values_to_load; j++) {
        host_->pextrd(gather_conf_->reg_tmp_.cvt32(), indices_vmm, j);
        host_->add(src_reg, gather_conf_->reg_tmp_);
        switch (data_type_) {
            case data_type::f32:
            case data_type::s32: {
                host_->pinsrd(dst_vmm, host_->ptr[src_reg], j);
                break;
            }
            case data_type::f16:
                assert(f16_supported_ && "Unsupported data type.");
                host_->pinsrw(dst_vmm, host_->ptr[src_reg], j);
                break;
            case data_type::bf16:
                assert(bf16_supported_ && "Unsupported data type.");
                // Insert at every other word slot: the interleaving is
                // undone by the shift in convert_to_f32.
                host_->pinsrw(dst_vmm, host_->ptr[src_reg], j * 2);
                break;
            case data_type::s8:
            case data_type::u8: {
                host_->pinsrb(dst_vmm, host_->ptr[src_reg], j);
                break;
            }
            default: assert(!"Unsupported data type.");
        }
        // Restore the base pointer for the next offset.
        host_->mov(src_reg, gather_conf_->reg_tmp1_);
    }

    // Everything except f32 needs widening to f32 in place.
    if (data_type_ != data_type::f32)
        convert_to_f32(dst_vmm, dst_vmm, data_type_);
}
425 | |
426 | template <typename Vmm> |
427 | void jit_io_helper_t<Vmm>::prepare_tail_mask() { |
428 | assert(tail_conf_.has_value() && "Config for tail processing is not set." ); |
429 | |
430 | if (!tail_conf_->tail_size_) return; |
431 | |
432 | if (is_superset(isa_, avx512_core)) |
433 | prepare_opmask(tail_conf_->tail_size_, tail_conf_->reg_tmp_, |
434 | tail_conf_->tail_opmask_); |
435 | else if (is_superset(isa_, sse41)) |
436 | prepare_vmm_mask(tail_conf_->tail_size_, tail_conf_->simd_w_, |
437 | tail_conf_->reg_tmp_, Vmm(tail_conf_->tail_vmm_mask_idx_)); |
438 | } |
439 | |
440 | template <typename Vmm> |
441 | void jit_io_helper_t<Vmm>::prepare_full_mask() { |
442 | assert(gather_conf_.has_value() && "Config for loading with the use of gather instruction is not set." ); |
443 | |
444 | if (utils::one_of(data_type_, data_type::f16, data_type::bf16, |
445 | data_type::s8, data_type::u8)) |
446 | return; |
447 | |
448 | if (is_superset(isa_, avx512_core)) |
449 | prepare_opmask(gather_conf_->simd_w_, gather_conf_->reg_tmp_, |
450 | gather_conf_->full_opmask_); |
451 | else if (isa_ == avx2) |
452 | prepare_vmm_mask(gather_conf_->simd_w_, gather_conf_->simd_w_, |
453 | gather_conf_->reg_tmp_, Vmm(gather_conf_->full_vmm_mask_idx_)); |
454 | } |
455 | |
456 | template <typename Vmm> |
457 | void jit_io_helper_t<Vmm>::init_full_mask() { |
458 | assert(gather_conf_.has_value() && "Config for loading with the use of gather instruction is not set." ); |
459 | |
460 | if (isa_ == avx2) { |
461 | const Vmm vmm_mask = Vmm(gather_conf_->full_vmm_mask_idx_); |
462 | host_->uni_vxorps(vmm_mask, vmm_mask, vmm_mask); |
463 | } |
464 | } |
465 | |
466 | template <typename Vmm> |
467 | void jit_io_helper_t<Vmm>::init_saturate_f32() const { |
468 | assert(saturation_conf_.has_value() && "Config for saturation is not set." ); |
469 | |
470 | if (utils::one_of(data_type_, data_type::u8, data_type::s8, data_type::s32)) |
471 | host_->init_saturate_f32( |
472 | Vmm(saturation_conf_->vreg_zero_saturation_idx_), |
473 | Vmm(saturation_conf_->vreg_saturation_ubound_idx_), |
474 | saturation_conf_->reg_tmp_, data_type::f32, data_type_); |
475 | } |
476 | |
// Gathers simd_w (or tail_size) f32 values from src_reg + indices_vmm
// into dst_vmm, converting to f32 when the source type is not f32.
// Uses hardware vgatherdps/vpgatherdd for f32/s32 on AVX2/AVX512 and
// falls back to the per-element emu_gather otherwise.
template <typename Vmm>
void jit_io_helper_t<Vmm>::gather(const Xbyak::Reg64 &src_reg,
        const Vmm &indices_vmm, const Vmm &dst_vmm, const bool tail) {
    assert(gather_conf_.has_value() && "Config for loading with the use of gather instruction is not set.");
    assert(IMPLICATION(tail, tail_conf_.has_value())
            && "Config for tail processing is not set.");

    // AVX2 form: mask as a vector register argument.
    const Vmm &mask = tail ? Vmm(tail_conf_->tail_vmm_mask_idx_)
                           : Vmm(gather_conf_->full_vmm_mask_idx_);

    // AVX512 form: mask attached to the destination as an opmask.
    const Vmm dst_vmm_with_mask = tail ? dst_vmm | tail_conf_->tail_opmask_
                                       : dst_vmm | gather_conf_->full_opmask_;

    const bool can_use_gather_instruction
            = isa_ == avx2 || is_superset(isa_, avx512_core);

    if ((data_type_ == data_type::f32 || data_type_ == data_type::s32)
            && can_use_gather_instruction) {
        if (data_type_ == data_type::f32) {
            if (isa_ == avx2)
                host_->vgatherdps(
                        dst_vmm, host_->ptr[src_reg + indices_vmm], mask);
            else
                host_->vgatherdps(
                        dst_vmm_with_mask, host_->ptr[src_reg + indices_vmm]);
        } else {
            if (isa_ == avx2)
                host_->vpgatherdd(
                        dst_vmm, host_->ptr[src_reg + indices_vmm], mask);
            else
                host_->vpgatherdd(
                        dst_vmm_with_mask, host_->ptr[src_reg + indices_vmm]);
            // s32 was gathered as raw dwords; widen to f32 now.
            convert_to_f32(dst_vmm, dst_vmm, data_type_);
        }

        // Have to restore processing mask after gather because mask
        // was zeroed.
        if (tail)
            prepare_tail_mask();
        else
            prepare_full_mask();
    } else {
        emu_gather(src_reg, indices_vmm, dst_vmm, tail);
    }
}
522 | |
// Loads a vector (full or tail) from src_addr into dst_raw_vmm as f32,
// dispatching per data type. On AVX512, tails are handled by attaching
// the tail opmask (zeroing unloaded lanes); otherwise small types and
// SSE4.1 fall back to a byte-by-byte load.
template <typename Vmm>
void jit_io_helper_t<Vmm>::load(const Xbyak::Address &src_addr,
        const Vmm &dst_raw_vmm, const bool tail) {
    assert(IMPLICATION(tail, tail_conf_.has_value())
            && "Config for tail processing is not set.");

    const bool is_avx512 = is_superset(isa_, avx512_core);

    // T_z zeroes the masked-out lanes of the destination.
    const auto dst_vmm = tail && is_avx512
            ? (dst_raw_vmm | tail_conf_->tail_opmask_ | host_->T_z)
            : dst_raw_vmm;

    const bool is_i8 = utils::one_of(data_type_, data_type::s8, data_type::u8);
    const bool is_xf16
            = utils::one_of(data_type_, data_type::bf16, data_type::f16);
    // Masked tail loads are only available with AVX512 opmasks.
    const bool is_tail_load_supported = is_avx512;
    const bool can_load_byte_by_byte = tail
            && (isa_ == sse41
                    || (!is_tail_load_supported && (is_i8 || is_xf16)));

    if (can_load_byte_by_byte) {
        load_byte_by_byte(src_addr, dst_vmm, tail_conf_->tail_size_);
    } else {
        switch (data_type_) {
            case data_type::f32: load_f32(src_addr, dst_vmm, tail); break;
            case data_type::s32: load_s32(src_addr, dst_vmm, tail); break;
            case data_type::bf16: load_bf16(src_addr, dst_vmm); break;
            case data_type::f16: load_f16(src_addr, dst_vmm); break;
            case data_type::s8:
            case data_type::u8: load_i8(src_addr, dst_vmm); break;
            default: assert(!"Unsupported data type.");
        }
    }
}
557 | |
558 | template <> |
559 | void jit_io_helper_t<Xbyak::Zmm>::load_byte_by_byte( |
560 | const Xbyak::Address &src_addr, const Xbyak::Zmm &dst_vmm, |
561 | const int load_size) { |
562 | assert("Load byte by byte is not supported for Zmms." ); |
563 | } |
564 | |
565 | template <typename Vmm> |
566 | void jit_io_helper_t<Vmm>::load_byte_by_byte(const Xbyak::Address &src_addr, |
567 | const Vmm &dst_vmm, const int load_size) { |
568 | host_->uni_vxorps(dst_vmm, dst_vmm, dst_vmm); |
569 | host_->load_data(data_type_, dst_vmm, src_addr, load_size); |
570 | |
571 | if (utils::one_of(data_type_, data_type::s32, data_type::s8, data_type::u8)) |
572 | convert_to_f32(dst_vmm, dst_vmm, data_type::s32); |
573 | } |
574 | |
575 | template <typename Vmm> |
576 | void jit_io_helper_t<Vmm>::load_f32( |
577 | const Xbyak::Address &src_addr, const Vmm &dst_vmm, const bool tail) { |
578 | if (tail && !is_superset(isa_, avx512_core)) |
579 | host_->vmaskmovps( |
580 | dst_vmm, Vmm(tail_conf_->tail_vmm_mask_idx_), src_addr); |
581 | else |
582 | host_->uni_vmovups(dst_vmm, src_addr); |
583 | } |
584 | |
585 | template <typename Vmm> |
586 | void jit_io_helper_t<Vmm>::load_s32( |
587 | const Xbyak::Address &src_addr, const Vmm &dst_vmm, const bool tail) { |
588 | if (is_superset(isa_, avx512_core)) |
589 | host_->uni_vcvtdq2ps(dst_vmm, src_addr); |
590 | else { |
591 | load_f32(src_addr, dst_vmm, tail); |
592 | convert_to_f32(dst_vmm, dst_vmm, data_type::s32); |
593 | } |
594 | } |
595 | |
// Loads bf16 data as f32: zero-extend the 16-bit words to dword lanes,
// then shift them into the mantissa/exponent position in convert_to_f32.
template <typename Vmm>
void jit_io_helper_t<Vmm>::load_bf16(
        const Xbyak::Address &src_addr, const Vmm &dst_vmm) {
    assert(bf16_supported_ && "Unsupported data type.");

    host_->vpmovzxwd(dst_vmm, src_addr);
    convert_to_f32(dst_vmm, dst_vmm, data_type::bf16);
}
604 | |
// Loads f16 data as f32 using the hardware half-to-single conversion.
template <typename Vmm>
void jit_io_helper_t<Vmm>::load_f16(
        const Xbyak::Address &src_addr, const Vmm &dst_vmm) {
    assert(f16_supported_ && "Unsupported data type.");
    host_->uni_vcvtph2psx(dst_vmm, src_addr);
}
611 | |
612 | template <typename Vmm> |
613 | void jit_io_helper_t<Vmm>::load_i8( |
614 | const Xbyak::Address &src_addr, const Vmm &dst_vmm) { |
615 | if (data_type_ == data_type::s8) |
616 | host_->uni_vpmovsxbd(dst_vmm, src_addr); |
617 | else |
618 | host_->uni_vpmovzxbd(dst_vmm, src_addr); |
619 | |
620 | convert_to_f32(dst_vmm, dst_vmm, data_type::s32); |
621 | } |
622 | |
// Loads two vectors' worth of xf16 data at once using the
// AVX2-NE-CONVERT even/odd-element instructions.
template <typename Vmm>
void jit_io_helper_t<Vmm>::load_two_simdw_xf16(const Xbyak::Address &src_addr,
        const Vmm &dst_even_vmm, const Vmm &dst_odd_vmm) {
    // The outputs are in odd/even interleaved layouts
    // now only support bf16/f16 w/o tail on AVX2_VNNI_2
    assert(utils::one_of(data_type_, data_type::bf16, data_type::f16)
            && isa_ == avx2_vnni_2 && "Unsupported data type.");

    if (data_type_ == data_type::bf16) {
        host_->vcvtneebf162ps(dst_even_vmm, src_addr);
        host_->vcvtneobf162ps(dst_odd_vmm, src_addr);
    } else {
        host_->vcvtneeph2ps(dst_even_vmm, src_addr);
        host_->vcvtneoph2ps(dst_odd_vmm, src_addr);
    }
}
639 | |
// Reorders the even/odd interleaved outputs of load_two_simdw_xf16 back
// into plain element order, in place in vmm_even/vmm_odd.
template <typename Vmm>
void jit_io_helper_t<Vmm>::merge_interleaved_to_plain(
        const Vmm &vmm_even, const Vmm &vmm_odd, const Vmm &vmm_aux0) {
    // Merge inputs in odd/even interleaved layouts to plain layouts
    assert(vmm_even.isYMM() && vmm_odd.isYMM()
            && "Merge interleaved to plain only supports Ymms");
    Xbyak::Ymm ymm_even = Xbyak::Ymm(vmm_even.getIdx());
    Xbyak::Ymm ymm_odd = Xbyak::Ymm(vmm_odd.getIdx());
    Xbyak::Ymm ymm_aux0 = Xbyak::Ymm(vmm_aux0.getIdx());
    // NOTE: ymm_aux1 deliberately aliases ymm_odd — after vpunpckhdq
    // below, ymm_odd's original content is no longer needed, so its
    // register doubles as the second scratch.
    Xbyak::Ymm ymm_aux1 = Xbyak::Ymm(vmm_odd.getIdx());

    // Interleave even/odd pairs within 128-bit lanes...
    host_->vpunpckldq(ymm_aux0, ymm_even, ymm_odd);
    host_->vpunpckhdq(ymm_aux1, ymm_even, ymm_odd);
    // ...then recombine the 128-bit halves into plain order.
    host_->vperm2i128(ymm_even, ymm_aux0, ymm_aux1, 0x20);
    host_->vperm2i128(ymm_odd, ymm_aux0, ymm_aux1, 0x31);
}
656 | |
// Stores a vector (full or tail) of f32 data from src_raw_vmm to memory,
// converting/saturating to the destination data type. On AVX512, tails
// attach the tail opmask to both the address and the source register.
template <typename Vmm>
void jit_io_helper_t<Vmm>::store(const Vmm &src_raw_vmm,
        const Xbyak::Address &dst_raw_addr, const bool tail) {
    assert(IMPLICATION(tail, tail_conf_.has_value())
            && "Config for tail processing is not set.");
    assert(!(tail && io_conf_.nt_stores_enabled_)
            && "Usage of non-temporal stores with tail leads to a general-protection exception.");

    const bool is_avx512 = is_superset(isa_, avx512_core);

    const auto dst_addr = tail && is_avx512
            ? (dst_raw_addr | tail_conf_->tail_opmask_)
            : dst_raw_addr;
    const auto src_vmm = tail && is_avx512
            ? (src_raw_vmm | tail_conf_->tail_opmask_)
            : src_raw_vmm;

    // Masked tail stores are only available with AVX512 opmasks.
    const bool is_store_tail_supported = is_avx512;
    const bool is_i8 = utils::one_of(data_type_, data_type::s8, data_type::u8);
    const bool is_xf16
            = utils::one_of(data_type_, data_type::bf16, data_type::f16);

    const bool can_store_byte_by_byte = tail
            && (isa_ == sse41
                    || (!is_store_tail_supported && (is_i8 || is_xf16)));

    // Integral destinations need clamping + f32->s32 conversion first.
    if (data_type_ == data_type::s32 || is_i8) saturate(src_raw_vmm);

    if (can_store_byte_by_byte) {
        const size_t store_size
                = tail_conf_->tail_size_ * types::data_type_size(data_type_);
        store_byte_by_byte(src_vmm, dst_addr, store_size);
    } else {
        switch (data_type_) {
            case data_type::f32:
            case data_type::s32: store_f32(src_vmm, dst_addr, tail); break;
            case data_type::bf16: store_bf16(src_vmm, dst_addr); break;
            case data_type::f16: store_f16(src_vmm, dst_addr); break;
            case data_type::s8:
            // store_i8 applies its own masking, so it gets the raw address.
            case data_type::u8: store_i8(src_vmm, dst_raw_addr); break;
            default: assert(!"Unsupported data type.");
        }
    }
}
701 | |
// Clamps f32 lanes to the destination type's representable range and
// converts them to s32, in place.
template <typename Vmm>
void jit_io_helper_t<Vmm>::saturate(const Vmm &vmm) {
    assert(saturation_conf_.has_value() && "Config for saturation is not set.");

    host_->saturate_f32(vmm, Vmm(saturation_conf_->vreg_zero_saturation_idx_),
            Vmm(saturation_conf_->vreg_saturation_ubound_idx_), data_type_);
    host_->uni_vcvtps2dq(vmm, vmm);
}
710 | |
711 | template <> |
712 | void jit_io_helper_t<Xbyak::Zmm>::store_byte_by_byte(const Xbyak::Zmm &src_zmm, |
713 | const Xbyak::Address &dst_addr, const int store_size) { |
714 | assert("Store byte by byte is not supported for Zmms." ); |
715 | } |
716 | |
// Stores `store_size` bytes of data from src_vmm. i8/xf16 data is first
// narrowed/packed in-register; xf16 ends up in the lower-width view of
// the same register, which is then the store source.
template <typename Vmm>
void jit_io_helper_t<Vmm>::store_byte_by_byte(const Vmm &src_vmm,
        const Xbyak::Address &dst_addr, const int store_size) {
    const bool is_i8 = utils::one_of(data_type_, data_type::s8, data_type::u8);
    const bool is_xf16
            = utils::one_of(data_type_, data_type::bf16, data_type::f16);
    // Lower-width alias of src_vmm holding the down-converted xf16 data.
    const auto &cvt_lower_vmm =
            typename vreg_traits<Vmm>::Vmm_lower_t(src_vmm.getIdx());

    if (is_i8) prepare_i8_data_to_store(src_vmm);
    if (is_xf16) prepare_xf16_data_to_store(src_vmm);

    host_->store_bytes(is_xf16 ? cvt_lower_vmm : src_vmm, dst_addr, store_size);
}
731 | |
732 | template <typename Vmm> |
733 | void jit_io_helper_t<Vmm>::store_f32( |
734 | const Vmm &src_vmm, const Xbyak::Address &dst_addr, const bool tail) { |
735 | if (io_conf_.nt_stores_enabled_) |
736 | host_->uni_vmovntps(dst_addr, src_vmm); |
737 | else if (!is_superset(isa_, avx512_core) && tail) |
738 | host_->vmaskmovps( |
739 | dst_addr, Vmm(tail_conf_->tail_vmm_mask_idx_), src_vmm); |
740 | else |
741 | host_->uni_vmovups(dst_addr, src_vmm); |
742 | } |
743 | |
// Stores f32 data as bf16: down-convert into the lower-width view of the
// source register (emulated on plain avx512_core), then store the half.
template <typename Vmm>
void jit_io_helper_t<Vmm>::store_bf16(
        const Vmm &src_vmm, const Xbyak::Address &dst_addr) {
    assert(bf16_supported_ && "Unsupported data type.");
    assert((src_vmm.isZMM() || src_vmm.isYMM())
            && "Store operation for bf16 is not supported for Xmms.");

    const auto &cvt_lower_vmm =
            typename vreg_traits<Vmm>::Vmm_lower_t(src_vmm.getIdx());

    if (bf16_emu_)
        bf16_emu_->vcvtneps2bf16(cvt_lower_vmm, src_vmm);
    else
        // EVEX form on AVX512; VEX selects the AVX2-NE-CONVERT form.
        host_->vcvtneps2bf16(cvt_lower_vmm, src_vmm,
                mayiuse(avx512_core) ? Xbyak::EvexEncoding
                                     : Xbyak::VexEncoding);

    if (io_conf_.nt_stores_enabled_)
        host_->uni_vmovntps(dst_addr, cvt_lower_vmm);
    else
        host_->uni_vmovdqu16(dst_addr, cvt_lower_vmm);
}
766 | |
767 | template <typename Vmm> |
768 | void jit_io_helper_t<Vmm>::store_f16( |
769 | const Vmm &src_vmm, const Xbyak::Address &dst_addr) { |
770 | assert(f16_supported_ && "Unsupported data type." ); |
771 | assert((src_vmm.isZMM() || src_vmm.isYMM()) |
772 | && "Store operation for f16 is not supported for Xmms." ); |
773 | |
774 | const auto &cvt_lower_vmm = |
775 | typename vreg_traits<Vmm>::Vmm_lower_t(src_vmm.getIdx()); |
776 | |
777 | host_->uni_vcvtps2phx(cvt_lower_vmm, src_vmm); |
778 | |
779 | if (io_conf_.nt_stores_enabled_) |
780 | host_->uni_vmovntps(dst_addr, cvt_lower_vmm); |
781 | else |
782 | host_->uni_vmovdqu16(dst_addr, cvt_lower_vmm); |
783 | } |
784 | |
785 | template <typename Vmm> |
786 | void jit_io_helper_t<Vmm>::store_i8( |
787 | const Vmm &src_vmm, const Xbyak::Address &dst_addr) { |
788 | if (!is_superset(isa_, avx512_core)) { |
789 | static constexpr bool is_ymm = std::is_same<Vmm, Xbyak::Ymm>::value; |
790 | |
791 | prepare_i8_data_to_store(src_vmm); |
792 | if (is_ymm) |
793 | host_->uni_vmovq(dst_addr, Xbyak::Xmm(src_vmm.getIdx())); |
794 | else |
795 | host_->uni_vmovd(dst_addr, src_vmm); |
796 | } else { |
797 | using namespace std::placeholders; |
798 | static constexpr bool is_zmm = std::is_same<Vmm, Xbyak::Zmm>::value; |
799 | |
800 | auto store_i8_fn = data_type_ == data_type::s8 |
801 | ? std::bind(&jit_generator::vpmovsdb, host_, _1, _2) |
802 | : std::bind(&jit_generator::vpmovusdb, host_, _1, _2); |
803 | |
804 | if (io_conf_.nt_stores_enabled_ && is_zmm) { |
805 | Xbyak::Xmm src_xmm(src_vmm.getIdx()); |
806 | store_i8_fn(src_xmm, src_vmm); |
807 | host_->uni_vmovntps(dst_addr, src_xmm); |
808 | } else { |
809 | store_i8_fn(dst_addr, src_vmm); |
810 | } |
811 | } |
812 | } |
813 | |
template <typename Vmm>
void jit_io_helper_t<Vmm>::convert_to_f32(const Vmm &dst_vmm,
        const Xbyak::Xmm &src_vmm, const data_type_t src_data_type) {
    // Emits code converting packed data of `src_data_type` held in `src_vmm`
    // into packed f32 in `dst_vmm`.
    switch (src_data_type) {
        case data_type::s32: {
            // In-place only: the instruction below reads dst_vmm, so both
            // arguments must refer to the same physical register.
            assert(dst_vmm.getIdx() == src_vmm.getIdx());
            host_->uni_vcvtdq2ps(dst_vmm, dst_vmm);
            break;
        }
        case data_type::bf16:
            assert(bf16_supported_ && "Unsupported data type." );
            // bf16 is the upper 16 bits of an f32: shifting each dword left
            // by 16 recreates the f32 bit pattern (low mantissa bits zeroed).
            host_->vpslld(dst_vmm, src_vmm, 0x10);
            break;
        case data_type::f16:
            assert(f16_supported_ && "Unsupported data type." );
            host_->vcvtph2ps(dst_vmm, src_vmm);
            break;
        case data_type::s8: {
            // Sign-extend bytes to dwords, then convert dwords to f32.
            host_->uni_vpmovsxbd(dst_vmm, src_vmm);
            host_->uni_vcvtdq2ps(dst_vmm, dst_vmm);
            break;
        }
        case data_type::u8: {
            // Zero-extend bytes to dwords, then convert dwords to f32.
            host_->uni_vpmovzxbd(dst_vmm, src_vmm);
            host_->uni_vcvtdq2ps(dst_vmm, dst_vmm);
            break;
        }
        default: assert(!"Unsupported data type." );
    }
}
844 | |
template <typename Vmm>
void jit_io_helper_t<Vmm>::broadcast(
        const Xbyak::Address &src_addr, const Vmm &dst_vmm) {
    // Loads a single scalar of data_type_ from `src_addr` and broadcasts it,
    // converted to f32, across all lanes of `dst_vmm`.
    switch (data_type_) {
        case data_type::f32: host_->uni_vbroadcastss(dst_vmm, src_addr); break;
        case data_type::bf16:
            assert(bf16_supported_ && "Unsupported data type." );
            // Broadcast the 16-bit value first, then up-convert every lane
            // to f32 in place.
            host_->vpbroadcastw(dst_vmm, src_addr);
            convert_to_f32(dst_vmm, dst_vmm, data_type_);
            break;
        case data_type::f16:
            assert(f16_supported_ && "Unsupported data type." );
            // Embedded-broadcast memory operand: convert and broadcast in a
            // single instruction.
            host_->uni_vcvtph2psx(dst_vmm, host_->ptr_b[src_addr.getRegExp()]);
            break;
        case data_type::s32: {
            if (is_superset(isa_, avx512_core)) {
                // AVX512 allows the embedded-broadcast memory operand form.
                host_->uni_vcvtdq2ps(
                        dst_vmm, host_->ptr_b[src_addr.getRegExp()]);
            } else {
                // Otherwise broadcast first, then convert all lanes in place.
                host_->uni_vbroadcastss(dst_vmm, src_addr);
                convert_to_f32(dst_vmm, dst_vmm, data_type_);
            }
            break;
        }
        case data_type::s8:
        case data_type::u8: {
            // Insert the byte into lane 0, widen it to f32, then broadcast
            // the resulting scalar f32 to all lanes.
            const Xbyak::Xmm dst_xmm {dst_vmm.getIdx()};
            host_->uni_vpinsrb(dst_xmm, dst_xmm, src_addr, 0);
            convert_to_f32(dst_vmm, dst_vmm, data_type_);
            host_->uni_vbroadcastss(dst_vmm, dst_xmm);

            break;
        }
        default: assert(!"Unsupported data type." );
    }
}
881 | |
882 | template <typename Vmm> |
883 | jit_io_multi_dt_helper_t<Vmm>::jit_io_multi_dt_helper_t(jit_generator *host, |
884 | const cpu_isa_t &isa, const data_types_t &data_types, |
885 | const io_conf_t &io_conf, |
886 | const utils::optional_t<io_tail_conf_t> &tail_conf, |
887 | const utils::optional_t<io_emu_bf16_conf_t> &bf16_conf, |
888 | const std::map<data_type_t, io_saturation_conf_t> &saturation_confs, |
889 | const utils::optional_t<io_gather_conf_t> &gather_conf) { |
890 | assert(!data_types.empty()); |
891 | for (const auto &dt : data_types) { |
892 | // can be replaced by try_emplace from C++17 |
893 | if (storage_.find(dt) == storage_.cend()) { |
894 | |
895 | const auto saturation_conf = saturation_confs.find(dt); |
896 | const bool store_saturation_needed |
897 | = saturation_conf != saturation_confs.cend(); |
898 | |
899 | storage_.emplace(dt, |
900 | std::make_shared<jit_io_helper_t<Vmm>>(host, isa, dt, |
901 | io_conf, tail_conf, |
902 | dt == data_type::bf16 ? bf16_conf : utils::nullopt, |
903 | store_saturation_needed ? utils::optional_t< |
904 | io_saturation_conf_t> {saturation_conf |
905 | ->second} |
906 | : utils::nullopt, |
907 | gather_conf)); |
908 | } |
909 | } |
910 | } |
911 | |
912 | template <typename Vmm> |
913 | std::shared_ptr<jit_io_helper_t<Vmm>> jit_io_multi_dt_helper_t<Vmm>::at( |
914 | const data_type_t dt) const { |
915 | const auto it = storage_.find(dt); |
916 | if (it != storage_.cend()) return it->second; |
917 | return nullptr; |
918 | } |
919 | |
920 | template <typename Vmm> |
921 | std::shared_ptr<jit_io_helper_t<Vmm>> jit_io_multi_dt_helper_t<Vmm>::operator[]( |
922 | const data_type_t dt) const { |
923 | auto res = this->at(dt); |
924 | if (res == nullptr) { assert(!"data not found in io" ); } |
925 | return res; |
926 | } |
927 | |
928 | template <typename Vmm> |
929 | void jit_io_multi_dt_helper_t<Vmm>::prepare_tail_mask() { |
930 | return storage_.cbegin()->second->prepare_tail_mask(); |
931 | } |
932 | |
933 | template <typename Vmm> |
934 | void jit_io_multi_dt_helper_t<Vmm>::prepare_full_mask() { |
935 | return storage_.cbegin()->second->prepare_full_mask(); |
936 | } |
937 | |
938 | template <typename Vmm> |
939 | void jit_io_multi_dt_helper_t<Vmm>::init_saturate_f32( |
940 | const data_types_t &store_data_types) { |
941 | for (const auto &dt : store_data_types) { |
942 | const auto it = storage_.find(dt); |
943 | if (it != storage_.cend()) { |
944 | if (it->second->saturation_conf_.has_value()) |
945 | it->second->init_saturate_f32(); |
946 | } |
947 | } |
948 | } |
949 | |
950 | template <typename Vmm> |
951 | void jit_io_multi_dt_helper_t<Vmm>::init_full_mask() { |
952 | return storage_.cbegin()->second->init_full_mask(); |
953 | } |
954 | |
955 | template <typename Vmm> |
956 | void jit_io_multi_dt_helper_t<Vmm>::init_bf16() { |
957 | const auto bf16_io_helper = at(data_type::bf16); |
958 | if (bf16_io_helper) bf16_io_helper->init_bf16(); |
959 | } |
960 | |
// NOTE(review): defaulted destructor defined out-of-line rather than in the
// header — presumably to keep destruction of the shared_ptr-held helpers in
// this translation unit; confirm before inlining it.
template <typename Vmm>
jit_io_multi_dt_helper_t<Vmm>::~jit_io_multi_dt_helper_t() = default;
963 | |
// Explicit instantiations for every supported vector register width
// (Zmm = 512-bit, Ymm = 256-bit, Xmm = 128-bit).
template class jit_io_helper_t<Xbyak::Zmm>;
template class jit_io_helper_t<Xbyak::Ymm>;
template class jit_io_helper_t<Xbyak::Xmm>;

template class jit_io_multi_dt_helper_t<Xbyak::Zmm>;
template class jit_io_multi_dt_helper_t<Xbyak::Ymm>;
template class jit_io_multi_dt_helper_t<Xbyak::Xmm>;
971 | |
972 | } // namespace io |
973 | } // namespace x64 |
974 | } // namespace cpu |
975 | } // namespace impl |
976 | } // namespace dnnl |
977 | |