1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_CONVERT_XF16_HPP
18#define CPU_X64_JIT_UNI_CONVERT_XF16_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/float16.hpp"
24#include "common/nstl.hpp"
25#include "common/type_helpers.hpp"
26
27#include "cpu/x64/cpu_isa_traits.hpp"
28#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
29#include "cpu/x64/jit_generator.hpp"
30
31#include "oneapi/dnnl/dnnl_debug.h"
32
33namespace dnnl {
34namespace impl {
35namespace cpu {
36namespace x64 {
37
namespace cvt_xf16_support {
// Runtime arguments for the f32 -> xf16 (bf16 / f16) conversion kernel; a
// pointer to it is what jit_cvt_ps_to_xf16_t::operator() hands to the kernel.
// NOTE(review): the generated code presumably reads these fields by offset,
// so their order and types should be treated as fixed - confirm in the .cpp.
struct jit_call_t {
    void *inp; // source buffer (presumably f32 elements - TODO confirm)
    void *out; // destination buffer of 2-byte xf16 elements
    void *add; // NOTE(review): not referenced in this header; its meaning is
            // defined by the kernel implementation - confirm before use
    size_t nelems; // element count; the wrapper relies on it when the kernel
            // was created with nelems == 0 (dynamic size)
};
// Runtime arguments for the xf16 -> f32 conversion kernel, filled in by
// jit_cvt_xf16_to_ps_t::operator().
struct jit_cvt_xf16_to_ps_params_t {
    const void *inp; // source buffer of bf16 or f16 elements
    void *out; // destination f32 buffer
    size_t nelems; // element count (whether per row or in total when
            // rows > 1 is decided by the kernel - TODO confirm)
    size_t rows; // number of rows to convert; callers default it to 1
};
} // namespace cvt_xf16_support
52
53template <cpu_isa_t isa>
54struct jit_uni_cvt_ps_to_xf16_t : public jit_generator {
55
56 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_cvt_ps_to_xf16_t)
57
58 jit_uni_cvt_ps_to_xf16_t(impl::data_type_t dt, size_t nelems = 0)
59 : jit_generator(jit_name())
60 , output_dt_(dt)
61 , nelems_(nelems)
62 , is_dynamic_size_(nelems_ == 0)
63 , tail_size_(nelems_ % simd_w_) {}
64
65 void generate() override;
66
67protected:
68 const impl::data_type_t output_dt_; // bf16 or f16
69 const size_t nelems_;
70 const bool is_dynamic_size_;
71 const int tail_size_;
72
73 constexpr static int simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
74 using Vmm = typename cpu_isa_traits<isa>::Vmm;
75 using Vmm_down_t = typename vreg_traits<Vmm>::Vmm_lower_t;
76
77 const Vmm vmm_input = Vmm(0);
78 const Vmm_down_t vmm_output = Vmm_down_t(1);
79
80 // used in avx2_vnni_2
81 const Vmm vmm_in_mask = Vmm(2);
82 const Vmm_down_t vmm_out_mask = Vmm(3);
83 // used in bf16 emulation
84 const Vmm vmm_one = Vmm(2);
85 const Vmm vmm_even = Vmm(3);
86 const Vmm vmm_selector = Vmm(4);
87 const Vmm vmm_fp32_tmp = Vmm(5);
88 // used in avx512_core[_fp16]
89 const Xbyak::Opmask ktail_f32_mask = Xbyak::Opmask(2);
90 const Xbyak::Opmask ktail_xf16_mask = Xbyak::Opmask(3);
91
92 Xbyak::Reg64 reg_input = rax;
93 Xbyak::Reg64 reg_output = rbx;
94 Xbyak::Reg64 reg_nelems = rdx;
95 Xbyak::Reg64 reg_tail = rcx;
96 Xbyak::Reg64 reg_tmp = r8;
97 Xbyak::Reg64 reg_scratch = r9;
98
99 void setup_mask();
100 virtual void cvt_ps_to_xf16(const int idx, const bool is_tail);
101 virtual void init_bf16() {} // unused for f16
102};
103
104struct jit_avx512_core_cvt_ps_to_bf16_t
105 : public jit_uni_cvt_ps_to_xf16_t<avx512_core> {
106
107 jit_avx512_core_cvt_ps_to_bf16_t(impl::data_type_t dt, size_t nelems = 0)
108 : jit_uni_cvt_ps_to_xf16_t<avx512_core>(dt, nelems)
109 , use_bf16_emu_(!mayiuse(avx512_core_bf16))
110 , bf16_emu_(use_bf16_emu_ ? utils::make_unique<bf16_emulation_t>(this,
111 vmm_one, vmm_even, vmm_selector, reg_scratch,
112 vmm_fp32_tmp)
113 : nullptr) {
114 assert(dt == data_type::bf16);
115 }
116
117private:
118 const bool use_bf16_emu_;
119 std::unique_ptr<bf16_emulation_t> bf16_emu_;
120
121 void cvt_ps_to_xf16(const int idx, const bool is_tail) override;
122 void init_bf16() override {
123 if (use_bf16_emu_) bf16_emu_->init_vcvtneps2bf16();
124 }
125};
126
127struct jit_cvt_ps_to_xf16_t {
128
129 jit_cvt_ps_to_xf16_t(impl::data_type_t data_type, size_t nelems = 0)
130 : nelems_(nelems) {
131 if (data_type == data_type::f16 && mayiuse(avx512_core_fp16))
132 kernel_ = utils::make_unique<
133 jit_uni_cvt_ps_to_xf16_t<avx512_core_fp16>>(
134 data_type, nelems);
135 else if (data_type == data_type::bf16 && mayiuse(avx512_core))
136 kernel_ = utils::make_unique<jit_avx512_core_cvt_ps_to_bf16_t>(
137 data_type, nelems);
138 else if (mayiuse(avx2_vnni_2))
139 kernel_ = utils::make_unique<jit_uni_cvt_ps_to_xf16_t<avx2_vnni_2>>(
140 data_type, nelems);
141 else {
142 assert(!"unsupported ISA for converter");
143 return;
144 }
145 kernel_->create_kernel();
146 }
147
148 void operator()(cvt_xf16_support::jit_call_t *params) const {
149 (*kernel_)(params);
150 msan_unpoison(params->out,
151 (nelems_ ? nelems_ : params->nelems) * sizeof(float16_t));
152 }
153
154private:
155 std::unique_ptr<jit_generator> kernel_;
156 const size_t nelems_;
157};
158
159template <cpu_isa_t isa>
160struct jit_uni_cvt_xf16_to_ps_t : public jit_generator {
161
162 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_cvt_xf16_to_ps_t)
163
164 jit_uni_cvt_xf16_to_ps_t(
165 impl::data_type_t dt, bool with_add, size_t row_stride)
166 : jit_generator(jit_name())
167 , input_dt_(dt)
168 , with_add_(with_add)
169 , row_stride_(row_stride) {
170 create_kernel();
171 }
172
173 void generate() override;
174
175protected:
176 constexpr static int elem_granularity = isa == avx2_vnni_2 ? 2 : 1;
177 constexpr static int simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
178 using Vmm = typename cpu_isa_traits<isa>::Vmm;
179 using Vmm_down_t = typename vreg_traits<Vmm>::Vmm_lower_t;
180
181 const impl::data_type_t input_dt_;
182 const bool with_add_;
183
184 const size_t row_stride_;
185
186 const Xbyak::Reg64 reg_input = rax;
187 const Xbyak::Reg64 reg_output = rbx;
188 const Xbyak::Reg64 reg_nelems = r8;
189 const Xbyak::Reg64 reg_nrows = r9;
190
191 const Xbyak::Reg64 reg_tail = rcx; //used for cl
192
193 const Xbyak::Reg64 reg_long_row_stride = r10;
194 const Xbyak::Reg64 reg_rollback = r11;
195 const Xbyak::Reg64 reg_nelems_save = r12;
196
197 const Xbyak::Reg64 reg_tmp = r13;
198
199 const Xbyak::Opmask ktail_mask = Xbyak::Opmask(1);
200
201 const Vmm vmm_tmp = Vmm(13);
202 const Vmm vmm_dst = Vmm(14);
203 const Vmm vmm_dst_2 = Vmm(15);
204 const Vmm_down_t vmm_in_mask = Vmm_down_t(15);
205
206 Vmm get_vmm_src(int idx) { return Vmm(get_even_src_idx(idx)); }
207 int get_even_src_idx(int idx) {
208 assert(idx < 4);
209 return idx;
210 }
211 int get_odd_src_idx(int idx) {
212 assert(idx < 4);
213 return idx + 4;
214 }
215
216 void convert_xf16(const int idx, const bool handle_x2);
217 void cvt_tail();
218};
219
220struct jit_cvt_xf16_to_ps_t {
221
222 jit_cvt_xf16_to_ps_t(impl::data_type_t data_type, bool with_add = false,
223 size_t row_stride = 0) {
224 if (data_type == data_type::f16 && mayiuse(avx512_core_fp16))
225 kernel_ = utils::make_unique<
226 jit_uni_cvt_xf16_to_ps_t<avx512_core_fp16>>(
227 data_type, with_add, row_stride);
228 else if (data_type == data_type::bf16 && mayiuse(avx512_core))
229 kernel_ = utils::make_unique<jit_uni_cvt_xf16_to_ps_t<avx512_core>>(
230 data_type, with_add, row_stride);
231 else if (mayiuse(avx2_vnni_2)) {
232 if (row_stride != 0) {
233 assert(!"unsupported row_stride for avx2_vnni_2");
234 return;
235 } else if (with_add) {
236 assert(!"untested implementation 'with_add' for avx2_vnni_2");
237 return;
238 }
239 kernel_ = utils::make_unique<jit_uni_cvt_xf16_to_ps_t<avx2_vnni_2>>(
240 data_type, with_add, row_stride);
241 } else {
242 assert(!"unsupported configuration for converter");
243 return;
244 }
245 kernel_->create_kernel();
246 }
247
248 void operator()(
249 float *out, const void *inp, size_t nelems, size_t rows = 1) const {
250 cvt_xf16_support::jit_cvt_xf16_to_ps_params_t p;
251 p.inp = inp;
252 p.out = (void *)out;
253 p.nelems = nelems;
254 p.rows = rows;
255 (*kernel_)(&p);
256 msan_unpoison(out, nelems * sizeof(float));
257 }
258
259 void operator()(float *out, const float16_t *inp, size_t nelems,
260 size_t rows = 1) const {
261 (*this)(out, (const void *)inp, nelems, rows);
262 }
263
264 void operator()(float *out, const bfloat16_t *inp, size_t nelems,
265 size_t rows = 1) const {
266 (*this)(out, (const void *)inp, nelems, rows);
267 }
268
269private:
270 std::unique_ptr<jit_generator> kernel_;
271};
272
273} // namespace x64
274} // namespace cpu
275} // namespace impl
276} // namespace dnnl
277
278#endif
279