1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_UNI_CONVERT_XF16_HPP |
18 | #define CPU_X64_JIT_UNI_CONVERT_XF16_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/float16.hpp" |
24 | #include "common/nstl.hpp" |
25 | #include "common/type_helpers.hpp" |
26 | |
27 | #include "cpu/x64/cpu_isa_traits.hpp" |
28 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
29 | #include "cpu/x64/jit_generator.hpp" |
30 | |
31 | #include "oneapi/dnnl/dnnl_debug.h" |
32 | |
33 | namespace dnnl { |
34 | namespace impl { |
35 | namespace cpu { |
36 | namespace x64 { |
37 | |
namespace cvt_xf16_support {

// Argument bundle passed by pointer to the f32 -> xf16 conversion kernels
// (see jit_cvt_ps_to_xf16_t::operator()).
// NOTE(review): presumably the generated code reads these fields by offset,
// so the field order and types should stay as they are - confirm against the
// corresponding generate() implementation.
struct jit_call_t {
    // Source buffer.
    void *inp;
    // Destination buffer; the wrapper unpoisons
    // nelems * sizeof(float16_t) bytes of it after the call.
    void *out;
    // NOTE(review): not referenced anywhere in this header; its meaning is
    // defined by the kernel implementation - confirm there.
    void *add;
    // Element count supplied at call time; the wrapper falls back to it when
    // it was constructed with nelems == 0.
    size_t nelems;
};

// Argument bundle passed by pointer to the xf16 -> f32 conversion kernels
// (filled in by jit_cvt_xf16_to_ps_t::operator()).
struct jit_cvt_xf16_to_ps_params_t {
    // Source buffer (f16 or bf16 data).
    const void *inp;
    // Destination buffer (the wrapper passes a float pointer here).
    void *out;
    // Element count forwarded from the caller.
    size_t nelems;
    // Row count forwarded from the caller (the wrapper defaults it to 1).
    size_t rows;
};

} // namespace cvt_xf16_support
52 | |
53 | template <cpu_isa_t isa> |
54 | struct jit_uni_cvt_ps_to_xf16_t : public jit_generator { |
55 | |
56 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_cvt_ps_to_xf16_t) |
57 | |
58 | jit_uni_cvt_ps_to_xf16_t(impl::data_type_t dt, size_t nelems = 0) |
59 | : jit_generator(jit_name()) |
60 | , output_dt_(dt) |
61 | , nelems_(nelems) |
62 | , is_dynamic_size_(nelems_ == 0) |
63 | , tail_size_(nelems_ % simd_w_) {} |
64 | |
65 | void generate() override; |
66 | |
67 | protected: |
68 | const impl::data_type_t output_dt_; // bf16 or f16 |
69 | const size_t nelems_; |
70 | const bool is_dynamic_size_; |
71 | const int tail_size_; |
72 | |
73 | constexpr static int simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float); |
74 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
75 | using Vmm_down_t = typename vreg_traits<Vmm>::Vmm_lower_t; |
76 | |
77 | const Vmm vmm_input = Vmm(0); |
78 | const Vmm_down_t vmm_output = Vmm_down_t(1); |
79 | |
80 | // used in avx2_vnni_2 |
81 | const Vmm vmm_in_mask = Vmm(2); |
82 | const Vmm_down_t vmm_out_mask = Vmm(3); |
83 | // used in bf16 emulation |
84 | const Vmm vmm_one = Vmm(2); |
85 | const Vmm vmm_even = Vmm(3); |
86 | const Vmm vmm_selector = Vmm(4); |
87 | const Vmm vmm_fp32_tmp = Vmm(5); |
88 | // used in avx512_core[_fp16] |
89 | const Xbyak::Opmask ktail_f32_mask = Xbyak::Opmask(2); |
90 | const Xbyak::Opmask ktail_xf16_mask = Xbyak::Opmask(3); |
91 | |
92 | Xbyak::Reg64 reg_input = rax; |
93 | Xbyak::Reg64 reg_output = rbx; |
94 | Xbyak::Reg64 reg_nelems = rdx; |
95 | Xbyak::Reg64 reg_tail = rcx; |
96 | Xbyak::Reg64 reg_tmp = r8; |
97 | Xbyak::Reg64 reg_scratch = r9; |
98 | |
99 | void setup_mask(); |
100 | virtual void cvt_ps_to_xf16(const int idx, const bool is_tail); |
101 | virtual void init_bf16() {} // unused for f16 |
102 | }; |
103 | |
104 | struct jit_avx512_core_cvt_ps_to_bf16_t |
105 | : public jit_uni_cvt_ps_to_xf16_t<avx512_core> { |
106 | |
107 | jit_avx512_core_cvt_ps_to_bf16_t(impl::data_type_t dt, size_t nelems = 0) |
108 | : jit_uni_cvt_ps_to_xf16_t<avx512_core>(dt, nelems) |
109 | , use_bf16_emu_(!mayiuse(avx512_core_bf16)) |
110 | , bf16_emu_(use_bf16_emu_ ? utils::make_unique<bf16_emulation_t>(this, |
111 | vmm_one, vmm_even, vmm_selector, reg_scratch, |
112 | vmm_fp32_tmp) |
113 | : nullptr) { |
114 | assert(dt == data_type::bf16); |
115 | } |
116 | |
117 | private: |
118 | const bool use_bf16_emu_; |
119 | std::unique_ptr<bf16_emulation_t> bf16_emu_; |
120 | |
121 | void cvt_ps_to_xf16(const int idx, const bool is_tail) override; |
122 | void init_bf16() override { |
123 | if (use_bf16_emu_) bf16_emu_->init_vcvtneps2bf16(); |
124 | } |
125 | }; |
126 | |
127 | struct jit_cvt_ps_to_xf16_t { |
128 | |
129 | jit_cvt_ps_to_xf16_t(impl::data_type_t data_type, size_t nelems = 0) |
130 | : nelems_(nelems) { |
131 | if (data_type == data_type::f16 && mayiuse(avx512_core_fp16)) |
132 | kernel_ = utils::make_unique< |
133 | jit_uni_cvt_ps_to_xf16_t<avx512_core_fp16>>( |
134 | data_type, nelems); |
135 | else if (data_type == data_type::bf16 && mayiuse(avx512_core)) |
136 | kernel_ = utils::make_unique<jit_avx512_core_cvt_ps_to_bf16_t>( |
137 | data_type, nelems); |
138 | else if (mayiuse(avx2_vnni_2)) |
139 | kernel_ = utils::make_unique<jit_uni_cvt_ps_to_xf16_t<avx2_vnni_2>>( |
140 | data_type, nelems); |
141 | else { |
142 | assert(!"unsupported ISA for converter" ); |
143 | return; |
144 | } |
145 | kernel_->create_kernel(); |
146 | } |
147 | |
148 | void operator()(cvt_xf16_support::jit_call_t *params) const { |
149 | (*kernel_)(params); |
150 | msan_unpoison(params->out, |
151 | (nelems_ ? nelems_ : params->nelems) * sizeof(float16_t)); |
152 | } |
153 | |
154 | private: |
155 | std::unique_ptr<jit_generator> kernel_; |
156 | const size_t nelems_; |
157 | }; |
158 | |
159 | template <cpu_isa_t isa> |
160 | struct jit_uni_cvt_xf16_to_ps_t : public jit_generator { |
161 | |
162 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_cvt_xf16_to_ps_t) |
163 | |
164 | jit_uni_cvt_xf16_to_ps_t( |
165 | impl::data_type_t dt, bool with_add, size_t row_stride) |
166 | : jit_generator(jit_name()) |
167 | , input_dt_(dt) |
168 | , with_add_(with_add) |
169 | , row_stride_(row_stride) { |
170 | create_kernel(); |
171 | } |
172 | |
173 | void generate() override; |
174 | |
175 | protected: |
176 | constexpr static int elem_granularity = isa == avx2_vnni_2 ? 2 : 1; |
177 | constexpr static int simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float); |
178 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
179 | using Vmm_down_t = typename vreg_traits<Vmm>::Vmm_lower_t; |
180 | |
181 | const impl::data_type_t input_dt_; |
182 | const bool with_add_; |
183 | |
184 | const size_t row_stride_; |
185 | |
186 | const Xbyak::Reg64 reg_input = rax; |
187 | const Xbyak::Reg64 reg_output = rbx; |
188 | const Xbyak::Reg64 reg_nelems = r8; |
189 | const Xbyak::Reg64 reg_nrows = r9; |
190 | |
191 | const Xbyak::Reg64 reg_tail = rcx; //used for cl |
192 | |
193 | const Xbyak::Reg64 reg_long_row_stride = r10; |
194 | const Xbyak::Reg64 reg_rollback = r11; |
195 | const Xbyak::Reg64 reg_nelems_save = r12; |
196 | |
197 | const Xbyak::Reg64 reg_tmp = r13; |
198 | |
199 | const Xbyak::Opmask ktail_mask = Xbyak::Opmask(1); |
200 | |
201 | const Vmm vmm_tmp = Vmm(13); |
202 | const Vmm vmm_dst = Vmm(14); |
203 | const Vmm vmm_dst_2 = Vmm(15); |
204 | const Vmm_down_t vmm_in_mask = Vmm_down_t(15); |
205 | |
206 | Vmm get_vmm_src(int idx) { return Vmm(get_even_src_idx(idx)); } |
207 | int get_even_src_idx(int idx) { |
208 | assert(idx < 4); |
209 | return idx; |
210 | } |
211 | int get_odd_src_idx(int idx) { |
212 | assert(idx < 4); |
213 | return idx + 4; |
214 | } |
215 | |
216 | void convert_xf16(const int idx, const bool handle_x2); |
217 | void cvt_tail(); |
218 | }; |
219 | |
220 | struct jit_cvt_xf16_to_ps_t { |
221 | |
222 | jit_cvt_xf16_to_ps_t(impl::data_type_t data_type, bool with_add = false, |
223 | size_t row_stride = 0) { |
224 | if (data_type == data_type::f16 && mayiuse(avx512_core_fp16)) |
225 | kernel_ = utils::make_unique< |
226 | jit_uni_cvt_xf16_to_ps_t<avx512_core_fp16>>( |
227 | data_type, with_add, row_stride); |
228 | else if (data_type == data_type::bf16 && mayiuse(avx512_core)) |
229 | kernel_ = utils::make_unique<jit_uni_cvt_xf16_to_ps_t<avx512_core>>( |
230 | data_type, with_add, row_stride); |
231 | else if (mayiuse(avx2_vnni_2)) { |
232 | if (row_stride != 0) { |
233 | assert(!"unsupported row_stride for avx2_vnni_2" ); |
234 | return; |
235 | } else if (with_add) { |
236 | assert(!"untested implementation 'with_add' for avx2_vnni_2" ); |
237 | return; |
238 | } |
239 | kernel_ = utils::make_unique<jit_uni_cvt_xf16_to_ps_t<avx2_vnni_2>>( |
240 | data_type, with_add, row_stride); |
241 | } else { |
242 | assert(!"unsupported configuration for converter" ); |
243 | return; |
244 | } |
245 | kernel_->create_kernel(); |
246 | } |
247 | |
248 | void operator()( |
249 | float *out, const void *inp, size_t nelems, size_t rows = 1) const { |
250 | cvt_xf16_support::jit_cvt_xf16_to_ps_params_t p; |
251 | p.inp = inp; |
252 | p.out = (void *)out; |
253 | p.nelems = nelems; |
254 | p.rows = rows; |
255 | (*kernel_)(&p); |
256 | msan_unpoison(out, nelems * sizeof(float)); |
257 | } |
258 | |
259 | void operator()(float *out, const float16_t *inp, size_t nelems, |
260 | size_t rows = 1) const { |
261 | (*this)(out, (const void *)inp, nelems, rows); |
262 | } |
263 | |
264 | void operator()(float *out, const bfloat16_t *inp, size_t nelems, |
265 | size_t rows = 1) const { |
266 | (*this)(out, (const void *)inp, nelems, rows); |
267 | } |
268 | |
269 | private: |
270 | std::unique_ptr<jit_generator> kernel_; |
271 | }; |
272 | |
273 | } // namespace x64 |
274 | } // namespace cpu |
275 | } // namespace impl |
276 | } // namespace dnnl |
277 | |
278 | #endif |
279 | |