/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "common/float16.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_uni_convert_xf16.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;

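// GET_OFF computes the byte offset of a field within the runtime argument
// struct (cvt_xf16_support::jit_call_t) that the caller passes to the
// generated kernel in abi_param1. A hypothetical call site (field names as
// used via GET_OFF below) would fill the struct and invoke the kernel:
//
//   cvt_xf16_support::jit_call_t p;
//   p.inp = src_f32; // f32 input
//   p.out = dst_xf16; // bf16/f16 output
//   p.nelems = n; // read only by dynamic-size kernels
//   kernel(&p);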
#define GET_OFF(field) offsetof(cvt_xf16_support::jit_call_t, field)

template <cpu_isa_t isa>
void jit_uni_cvt_ps_to_xf16_t<isa>::generate() {

    preamble();

    mov(reg_input, ptr[abi_param1 + GET_OFF(inp)]);
    mov(reg_output, ptr[abi_param1 + GET_OFF(out)]);
    if (is_dynamic_size_) mov(reg_nelems, ptr[abi_param1 + GET_OFF(nelems)]);

    init_bf16();

    if (is_dynamic_size_) { // determine nelems after JIT is called
        constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0
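        // A minimal trace, assuming simd_w_ = 16 and nelems = 100 at run
        // time: the 4x-unrolled loop converts 64 elements, the 2x loop 32,
        // the 1x loop none, and the masked tail below handles the last 4.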
        Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
        for (int i = n_unroll; i >= 0; i--) {
            const int unroll = 1 << i; // 4, 2, 1
            L(l_simd_loop[i + 1]);
            {
                cmp(reg_nelems, simd_w_ * unroll);
                jl(l_simd_loop[i], T_NEAR);
                for (int j = 0; j < simd_w_ * unroll; j += simd_w_) {
                    cvt_ps_to_xf16(j, false);
                }
                add(reg_input, simd_w_ * unroll * sizeof(float));
                add(reg_output, simd_w_ * unroll * sizeof(float16_t));
                sub(reg_nelems, simd_w_ * unroll);
                jmp(l_simd_loop[i + 1], T_NEAR);
            }
        }
        L(l_simd_loop[0]);

        test(reg_nelems, reg_nelems);
        jz(l_simd_notail, T_NEAR);

        mov(reg_tail, reg_nelems);
        setup_mask();

        cvt_ps_to_xf16(0, true);

        L(l_simd_notail);
    } else {
        const size_t blocked_size = (nelems_ / simd_w_) * simd_w_;
        constexpr size_t unroll_length = 1024;
        const size_t number_of_loops = blocked_size / unroll_length;
        const size_t loop_tail = blocked_size % unroll_length;

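        // A worked example, assuming nelems_ = 2500, simd_w_ = 16 and
        // tail_size_ = nelems_ % simd_w_: blocked_size = 2496, so
        // number_of_loops = 2 (2048 elements), loop_tail = 448, and the
        // remaining tail_size_ = 4 elements go through the masked path.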
        if (number_of_loops > 0) {
            Xbyak::Label l_number_of_loops;
            mov(reg_nelems, number_of_loops);
            L(l_number_of_loops);
            for (size_t i = 0; i < unroll_length; i += simd_w_)
                cvt_ps_to_xf16(i, false);
            add(reg_input, sizeof(float) * unroll_length);
            add(reg_output, sizeof(float16_t) * unroll_length);

            dec(reg_nelems);
            cmp(reg_nelems, 0);
            jg(l_number_of_loops, T_NEAR);
        }
        if (loop_tail > 0) {
            for (size_t i = 0; i < loop_tail; i += simd_w_)
                cvt_ps_to_xf16(i, false);
            add(reg_input, sizeof(float) * loop_tail);
            add(reg_output, sizeof(float16_t) * loop_tail);
        }
        if (tail_size_ != 0) {
            setup_mask();
            cvt_ps_to_xf16(0, true);
        }
    }
    postamble();
}

template <cpu_isa_t isa>
void jit_uni_cvt_ps_to_xf16_t<isa>::setup_mask() {
    const Xbyak::Reg32 reg_mask = reg_tmp.cvt32();
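    // Build a mask with the low (nelems % simd_w_) bits set, e.g. a tail of
    // 3 elements yields 0b111; it is loaded both as the 16-bit-element store
    // mask (kmovd) and as the 32-bit-element load mask (kmovw).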
    if (is_dynamic_size_) {
        mov(reg_mask, 1);
        shl(reg_mask, reg_tail.cvt8());
        sub(reg_mask, 1);
    } else {
        mov(reg_mask, (1 << tail_size_) - 1);
    }
    kmovd(ktail_xf16_mask, reg_mask);
    kmovw(ktail_f32_mask, reg_mask);
}

template <>
void jit_uni_cvt_ps_to_xf16_t<avx2_vnni_2>::setup_mask() {
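    // AVX2 has no opmask registers, so the tail mask is emulated by loading
    // a sliding window from this table: pointing the load at
    // &mask_in[8 - tail] yields exactly `tail` all-ones dwords followed by
    // zeros.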
    static const uint32_t mask_in[16]
            = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                    0xffffffff, 0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0, 0};
    const Xbyak::Reg64 reg64_mask = reg_tmp;

    if (!is_dynamic_size_) {
        constexpr int max_words_in_ymm = 8;
        auto mask_in_offset = max_words_in_ymm - tail_size_;
        mov(reg64_mask, reinterpret_cast<size_t>(&mask_in[mask_in_offset]));
    } else {
        mov(reg64_mask, reinterpret_cast<size_t>(&mask_in[8]));
        mov(reg_scratch, reg_tail);
        shl(reg_scratch, 2);
        sub(reg64_mask, reg_scratch);
    }
    vmovups(vmm_in_mask, ptr[reg64_mask]);
}

// NOTE: putting the function's definition in the header results in
// a compilation error for VS.
template <cpu_isa_t isa>
void jit_uni_cvt_ps_to_xf16_t<isa>::cvt_ps_to_xf16(
        const int idx, const bool is_tail) {
    assert(!"unimplemented template");
}

template <>
void jit_uni_cvt_ps_to_xf16_t<avx512_core_fp16>::cvt_ps_to_xf16(
        const int idx, const bool is_tail) {
    const Vmm vmm_m_in = is_tail ? vmm_input | ktail_f32_mask | T_z : vmm_input;
    const size_t out_offset = sizeof(float16_t) * idx;
    const auto addr_m_out = is_tail
            ? ptr[reg_output + out_offset] | ktail_xf16_mask
            : ptr[reg_output + out_offset];
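    // Load the f32 inputs (zero-masked on the tail), then down-convert to
    // f16; _op_mxcsr selects rounding according to the current MXCSR mode.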
    vmovups(vmm_m_in, ptr[reg_input + sizeof(float) * idx]);
    vcvtps2ph(addr_m_out, vmm_input, _op_mxcsr);
}

void jit_avx512_core_cvt_ps_to_bf16_t::cvt_ps_to_xf16(
        const int idx, const bool is_tail) {
    const size_t out_offset = sizeof(bfloat16_t) * idx;
    const auto addr_m_out = is_tail
            ? ptr[reg_output + out_offset] | ktail_xf16_mask
            : ptr[reg_output + out_offset];

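    // vcvtneps2bf16 requires the AVX512_BF16 extension; without it the
    // conversion falls back to the bf16_emu_ software sequence.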
    if (use_bf16_emu_) {
        const Vmm vmm_m_in
                = is_tail ? vmm_input | ktail_f32_mask | T_z : vmm_input;
        vmovups(vmm_m_in, ptr[reg_input + sizeof(float) * idx]);
        bf16_emu_->vcvtneps2bf16(vmm_output, vmm_input);
    } else {
        const auto vmm_m_out
                = is_tail ? vmm_output | ktail_xf16_mask | T_z : vmm_output;
        vcvtneps2bf16(vmm_m_out, ptr[reg_input + sizeof(float) * idx]);
    }
    vmovdqu16(addr_m_out, vmm_output);
}

template <>
void jit_uni_cvt_ps_to_xf16_t<avx2_vnni_2>::cvt_ps_to_xf16(
        const int idx, const bool is_tail) {
    if (is_tail) {
        uni_vxorps(vmm_input, vmm_input, vmm_input);
        vmaskmovps(
                vmm_input, vmm_in_mask, ptr[reg_input + sizeof(float) * idx]);
    } else if (output_dt_ == data_type::f16) {
        vmovups(vmm_input, ptr[reg_input + sizeof(float) * idx]);
    }

    switch (output_dt_) {
        case data_type::bf16:
            if (is_tail)
                vcvtneps2bf16(vmm_output, vmm_input, Xbyak::VexEncoding);
            else
                vcvtneps2bf16(vmm_output,
                        yword[reg_input + sizeof(float) * idx],
                        Xbyak::VexEncoding);
            break;
        case data_type::f16:
            if (is_tail)
                vcvtps2ph(vmm_output, vmm_input, _op_mxcsr);
            else
                vcvtps2ph(ptr[reg_output + sizeof(float16_t) * idx], vmm_input,
                        _op_mxcsr);
            break;
        default: assert(!"Invalid datatype");
    }

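    // AVX2 has no masked store at 16-bit granularity (vmaskmovps works on
    // dwords only), so the converted tail is written out byte-wise instead.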
    if (is_tail) {
        auto tail_store = [&](int load_size) {
            store_bytes(vmm_output, reg_output, sizeof(float16_t) * idx,
                    sizeof(float16_t) * load_size);
        };
        if (is_dynamic_size_)
            runtime_tail_process<Xbyak::Xmm>(
                    reg_tail, reg_tmp, tail_store, data_type::f16);
        else
            tail_store(tail_size_);

    } else if (output_dt_ == data_type::bf16)
        vmovups(ptr[reg_output + sizeof(bfloat16_t) * idx], vmm_output);
}

#undef GET_OFF

template struct jit_uni_cvt_ps_to_xf16_t<avx2_vnni_2>;
template struct jit_uni_cvt_ps_to_xf16_t<avx512_core>;
template struct jit_uni_cvt_ps_to_xf16_t<avx512_core_fp16>;

#define GET_OFF(field) \
    offsetof(cvt_xf16_support::jit_cvt_xf16_to_ps_params_t, field)

template <cpu_isa_t isa>
void jit_uni_cvt_xf16_to_ps_t<isa>::generate() {
    preamble();
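    // An x86 addressing displacement is a signed 32-bit immediate, so a row
    // stride whose byte size does not fit in 32 bits must be carried in a
    // register rather than folded into the address.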
    const bool long_row_stride = (row_stride_ * sizeof(float16_t) >> 32) != 0;
    MAYBE_UNUSED(long_row_stride);

    mov(reg_input, ptr[abi_param1 + GET_OFF(inp)]);
    mov(reg_output, ptr[abi_param1 + GET_OFF(out)]);
    mov(reg_nelems, ptr[abi_param1 + GET_OFF(nelems)]);
    mov(reg_nrows, ptr[abi_param1 + GET_OFF(rows)]);

    Label l_row_start, l_row_end;
    Label l_exit; // used for row_stride_

    if (row_stride_) {
        test(reg_nrows, reg_nrows);
        jz(l_exit, T_NEAR); // fast exit: nrows == 0
        mov(reg_nelems_save, reg_nelems);
        mov(reg_rollback, reg_nelems);
        and_(reg_rollback, ~(simd_w_ - 1));
        neg(reg_rollback);
        if (long_row_stride) {
            mov(reg_long_row_stride, row_stride_ * sizeof(float16_t));
            lea(reg_long_row_stride,
                    ptr[reg_long_row_stride
                            + reg_rollback * sizeof(float16_t)]);
        }
    }

    L(l_row_start);

    constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0
    Label l_simd_loop[n_unroll + 2];
    for (int i = n_unroll; i >= 0; i--) {
        const int unroll = 1 << i; // 4, 2, 1
        assert(IMPLICATION(unroll > 1, unroll % 2 == 0));
        L(l_simd_loop[i + 1]);
        {
            cmp(reg_nelems, simd_w_ * unroll);
            jl(l_simd_loop[i], T_NEAR);
            for (int j = 0; j < utils::div_up(unroll, elem_granularity); ++j)
                convert_xf16(j, unroll > 1);
            add(reg_input, simd_w_ * unroll * sizeof(float16_t));
            add(reg_output, simd_w_ * unroll * sizeof(float));
            sub(reg_nelems, simd_w_ * unroll);
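            // Only the widest loop needs a back-edge: once nelems drops
            // below 4 * simd_w_, each narrower body can execute at most once
            // before falling through to the next check.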
            if (i == n_unroll && n_unroll != 0) jmp(l_simd_loop[i + 1], T_NEAR);
        }
    }
    L(l_simd_loop[0]);

    test(reg_nelems, reg_nelems);
    jz(l_row_end, T_NEAR);

    mov(reg_tail, reg_nelems);
    cvt_tail();

    L(l_row_end);

    if (row_stride_) {
        dec(reg_nrows);
        jz(l_exit, T_NEAR);

        // wraparound
        lea(reg_output, ptr[reg_output + reg_rollback * sizeof(float)]);
        if (long_row_stride)
            add(reg_input, reg_long_row_stride);
        else
            lea(reg_input,
                    ptr[reg_input + reg_rollback * sizeof(float16_t)
                            + row_stride_ * sizeof(float16_t)]);
        mov(reg_nelems, reg_nelems_save);
        jmp(l_row_start);

        L(l_exit);
    }

    postamble();
}

template <cpu_isa_t isa>
void jit_uni_cvt_xf16_to_ps_t<isa>::convert_xf16(
        const int idx, const bool handle_x2) {
    const size_t offset = idx * simd_w_;
    const auto out_addr = ptr[reg_output + sizeof(float) * offset];
    const auto in_addr = ptr[reg_input + sizeof(bfloat16_t) * offset];
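    // bf16 is the upper half of an f32: zero-extending each 16-bit word to
    // 32 bits and shifting left by 16 restores the original float bit
    // pattern. f16 needs a real conversion instruction instead.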
    switch (input_dt_) {
        case data_type::bf16:
            vpmovzxwd(get_vmm_src(idx), in_addr);
            vpslld(get_vmm_src(idx), get_vmm_src(idx), 0x10);
            break;
        case data_type::f16: vcvtph2psx(get_vmm_src(idx), in_addr); break;
        default: assert(!"Invalid datatype");
    }
    if (with_add_) vaddps(get_vmm_src(idx), get_vmm_src(idx), out_addr);
    uni_vmovdqu(out_addr, get_vmm_src(idx));
}

template <typename Wmm>
struct helper_avx2_cvt_xf16_t {
    static void convert_xf16(jit_generator *host,
            const impl::data_type_t input_dt, const Xbyak::Address in_addr,
            const int even_src, const int odd_src, const int tmp_1,
            const int tmp_2) {
        const Wmm vmm_even_src = Wmm(even_src);
        const Wmm vmm_odd_src = Wmm(odd_src);
        const Wmm vmm_tmp_1 = Wmm(tmp_1);
        const Wmm vmm_tmp_2 = Wmm(tmp_2);

        switch (input_dt) {
            case data_type::bf16:
                host->vcvtneebf162ps(vmm_even_src, in_addr);
                host->vcvtneobf162ps(vmm_odd_src, in_addr);
                break;
            case data_type::f16:
                host->vcvtneeph2ps(vmm_even_src, in_addr);
                host->vcvtneoph2ps(vmm_odd_src, in_addr);
                break;
            default: assert(!"Invalid datatype");
        }
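        // The AVX-NE-CONVERT instructions above convert the even- and
        // odd-indexed xf16 elements into separate registers; interleaving
        // them back with unpackl/h restores pairs (e0,o0), (e1,o1), ...
        // within each 128-bit lane.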
        host->vpunpckldq(vmm_tmp_1, vmm_even_src, vmm_odd_src);
        host->vpunpckhdq(vmm_tmp_2, vmm_even_src, vmm_odd_src);
    }
};

template <>
void jit_uni_cvt_xf16_to_ps_t<avx2_vnni_2>::convert_xf16(
        const int idx, const bool handle_x2) {
    const Vmm vmm_tmp_1 = vmm_tmp;
    const Vmm vmm_tmp_2 = Vmm(get_even_src_idx(idx));
    const size_t offset = idx * simd_w_ * elem_granularity;
    const auto in_addr = ptr[reg_input + sizeof(bfloat16_t) * offset];
    auto get_out_addr = [&](const size_t offset_xmmword = 0) {
        return ptr[reg_output + sizeof(float) * (offset + offset_xmmword)];
    };

    if (handle_x2)
        helper_avx2_cvt_xf16_t<Xbyak::Ymm>::convert_xf16(this, input_dt_,
                in_addr, get_even_src_idx(idx), get_odd_src_idx(idx),
                vmm_tmp_1.getIdx(), vmm_tmp_2.getIdx());
    else
        helper_avx2_cvt_xf16_t<Xbyak::Xmm>::convert_xf16(this, input_dt_,
                in_addr, get_even_src_idx(idx), get_odd_src_idx(idx),
                vmm_tmp_1.getIdx(), vmm_tmp_2.getIdx());

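    // After the unpacks the interleaved halves still sit in separate 128-bit
    // lanes; vperm2f128 stitches the low halves (0x20) and the high halves
    // (0x31) back into memory order.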
    vperm2f128(vmm_dst, vmm_tmp_1, vmm_tmp_2, 0x20);
    if (handle_x2) vperm2f128(vmm_dst_2, vmm_tmp_1, vmm_tmp_2, 0x31);

    if (with_add_) {
        vaddps(vmm_dst, vmm_dst, get_out_addr());
        if (handle_x2) vaddps(vmm_dst_2, vmm_dst_2, get_out_addr(simd_w_));
    }
    uni_vmovdqu(get_out_addr(), vmm_dst);
    if (handle_x2) uni_vmovdqu(get_out_addr(simd_w_), vmm_dst_2);
}

template <cpu_isa_t isa>
void jit_uni_cvt_xf16_to_ps_t<isa>::cvt_tail() {
    const Reg32 reg32_mask
            = reg_nelems.cvt32(); // no need for reg_nelems anymore

    // ktail_mask <-- (1 << (nelems % simd_w_)) - 1
    mov(reg32_mask, 1);
    shl(reg32_mask, reg_tail.cvt8());
    sub(reg32_mask, 1);
    kmovd(ktail_mask, reg32_mask);

    auto vmm_masked = get_vmm_src(0) | ktail_mask | T_z;
    switch (input_dt_) {
        case data_type::bf16:
            vpmovzxwd(vmm_masked, ptr[reg_input]);
            vpslld(vmm_masked, get_vmm_src(0), 0x10);
            break;
        case data_type::f16: vcvtph2psx(vmm_masked, ptr[reg_input]); break;
        default: assert(!"Invalid datatype");
    }
    if (with_add_) vaddps(vmm_masked, get_vmm_src(0), ptr[reg_output]);
    vmovdqu32(ptr[reg_output] | ktail_mask, get_vmm_src(0));
}

template <>
void jit_uni_cvt_xf16_to_ps_t<avx2_vnni_2>::cvt_tail() {
    const Vmm vmm_output = get_vmm_src(0);
    const Vmm_down_t vmm_input = Vmm_down_t(vmm_output.getIdx());
    auto runtime_tail_load = [&](int load_size) {
        load_bytes(vmm_input, reg_input, 0, sizeof(bfloat16_t) * load_size);
    };
    auto runtime_tail_store = [&](int load_size) {
        store_data(data_type::f32, vmm_output, reg_output, 0, load_size);
    };

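    // The tail length is only known at run time on AVX2: load the remaining
    // xf16 words byte-wise into the lower half register, widen in place,
    // then store the f32 result byte-wise.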
    uni_vxorps(vmm_input, vmm_input, vmm_input);
    runtime_tail_process<Xbyak::Xmm>(
            reg_tail, reg_tmp, runtime_tail_load, data_type::f16);
    switch (input_dt_) {
        case data_type::bf16:
            vpmovzxwd(vmm_output, vmm_input);
            vpslld(vmm_output, vmm_output, 0x10);
            break;
        case data_type::f16: vcvtph2ps(vmm_output, vmm_input); break;
        default: assert(!"Invalid datatype");
    }
    runtime_tail_process<Xbyak::Ymm>(
            reg_tail, reg_tmp, runtime_tail_store, data_type::f32);
}

#undef GET_OFF

template struct jit_uni_cvt_xf16_to_ps_t<avx2_vnni_2>;
template struct jit_uni_cvt_xf16_to_ps_t<avx512_core>;
template struct jit_uni_cvt_xf16_to_ps_t<avx512_core_fp16>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl