1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | |
19 | #include "common/float16.hpp" |
20 | |
21 | #include "cpu/x64/cpu_isa_traits.hpp" |
22 | #include "cpu/x64/jit_generator.hpp" |
23 | #include "cpu/x64/jit_uni_convert_xf16.hpp" |
24 | |
25 | namespace dnnl { |
26 | namespace impl { |
27 | namespace cpu { |
28 | namespace x64 { |
29 | |
30 | using namespace Xbyak; |
31 | |
32 | #define GET_OFF(field) offsetof(cvt_xf16_support::jit_call_t, field) |
33 | |
template <cpu_isa_t isa>
void jit_uni_cvt_ps_to_xf16_t<isa>::generate() {
    // Emits a kernel that converts a contiguous array of f32 values into
    // xf16 (bf16 or f16). Two code shapes are generated:
    //  * dynamic size: the element count arrives at kernel-call time via
    //    the args struct, so the conversion loop is driven by reg_nelems;
    //  * static size: nelems_ is known at generation time, so the loops are
    //    fully unrolled/specialized here and only a masked tail remains.

    preamble();

    // Load kernel arguments (see cvt_xf16_support::jit_call_t layout).
    mov(reg_input, ptr[abi_param1 + GET_OFF(inp)]);
    mov(reg_output, ptr[abi_param1 + GET_OFF(out)]);
    if (is_dynamic_size_) mov(reg_nelems, ptr[abi_param1 + GET_OFF(nelems)]);

    init_bf16();

    if (is_dynamic_size_) { // determine nelems after JIT is called
        constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0
        Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
        // Cascade of loops with decreasing unroll factor: each level consumes
        // batches of simd_w_ * unroll elements while enough remain, then
        // falls through (jl) to the next-narrower level.
        for (int i = n_unroll; i >= 0; i--) {
            const int unroll = 1 << i; // 4, 2, 1
            L(l_simd_loop[i + 1]);
            {
                cmp(reg_nelems, simd_w_ * unroll);
                jl(l_simd_loop[i], T_NEAR);
                for (int j = 0; j < simd_w_ * unroll; j += simd_w_) {
                    cvt_ps_to_xf16(j, false);
                }
                // Advance pointers: input is f32 (4 bytes), output is a
                // 2-byte xf16 type.
                add(reg_input, simd_w_ * unroll * sizeof(float));
                add(reg_output, simd_w_ * unroll * sizeof(float16_t));
                sub(reg_nelems, simd_w_ * unroll);
                jmp(l_simd_loop[i + 1], T_NEAR);
            }
        }
        L(l_simd_loop[0]);

        // Remaining 0 < nelems < simd_w_ elements are handled with a
        // runtime-computed mask; skip entirely when nelems == 0.
        test(reg_nelems, reg_nelems);
        jz(l_simd_notail, T_NEAR);

        mov(reg_tail, reg_nelems);
        setup_mask();

        cvt_ps_to_xf16(0, true);

        L(l_simd_notail);
    } else {
        // Static-size path: split nelems_ into full-simd batches plus a
        // compile-time-known tail (tail_size_).
        const size_t blocked_size = (nelems_ / simd_w_) * simd_w_;
        constexpr size_t unroll_length = 1024;
        const size_t number_of_loops = blocked_size / unroll_length;
        const size_t loop_tail = blocked_size % unroll_length;

        if (number_of_loops > 0) {
            // Runtime loop over fully-unrolled chunks of unroll_length
            // elements; reg_nelems is reused as the chunk counter here.
            Xbyak::Label l_number_of_loops;
            mov(reg_nelems, number_of_loops);
            L(l_number_of_loops);
            for (size_t i = 0; i < unroll_length; i += simd_w_)
                cvt_ps_to_xf16(i, false);
            add(reg_input, sizeof(float) * unroll_length);
            add(reg_output, sizeof(float16_t) * unroll_length);

            dec(reg_nelems);
            cmp(reg_nelems, 0);
            jg(l_number_of_loops, T_NEAR);
        }
        if (loop_tail > 0) {
            // Leftover full-simd batches that did not fill a whole chunk.
            for (size_t i = 0; i < loop_tail; i += simd_w_)
                cvt_ps_to_xf16(i, false);
            add(reg_input, sizeof(float) * loop_tail);
            add(reg_output, sizeof(float16_t) * loop_tail);
        }
        if (tail_size_ != 0) {
            // Sub-simd remainder: converted once under a constant mask.
            setup_mask();
            cvt_ps_to_xf16(0, true);
        }
    }
    postamble();
}
106 | |
107 | template <cpu_isa_t isa> |
108 | void jit_uni_cvt_ps_to_xf16_t<isa>::setup_mask() { |
109 | const Xbyak::Reg32 reg_mask = reg_tmp.cvt32(); |
110 | if (is_dynamic_size_) { |
111 | mov(reg_mask, 1); |
112 | shl(reg_mask, reg_tail.cvt8()); |
113 | sub(reg_mask, 1); |
114 | } else { |
115 | mov(reg_mask, (1 << tail_size_) - 1); |
116 | } |
117 | kmovd(ktail_xf16_mask, reg_mask); |
118 | kmovw(ktail_f32_mask, reg_mask); |
119 | } |
120 | |
template <>
void jit_uni_cvt_ps_to_xf16_t<avx2_vnni_2>::setup_mask() {
    // AVX2 has no opmask registers, so the tail mask is a YMM vector loaded
    // from a sliding window over this table: reading 8 dwords starting at
    // index (8 - tail) yields `tail` all-ones lanes followed by zeros.
    // The table is function-local static, so its address stays valid for
    // the lifetime of the generated code.
    static const uint32_t mask_in[16]
            = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
                    0xffffffff, 0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0, 0};
    const Xbyak::Reg64 reg64_mask = reg_tmp;

    if (!is_dynamic_size_) {
        // Static tail: bake the window start address in as an immediate.
        constexpr int max_words_in_ymm = 8;
        auto mask_in_offset = max_words_in_ymm - tail_size_;
        mov(reg64_mask, reinterpret_cast<size_t>(&mask_in[mask_in_offset]));
    } else {
        // Dynamic tail: start from &mask_in[8] and step back
        // reg_tail * sizeof(uint32_t) bytes (shl by 2 == *4) at runtime.
        mov(reg64_mask, reinterpret_cast<size_t>(&mask_in[8]));
        mov(reg_scratch, reg_tail);
        shl(reg_scratch, 2);
        sub(reg64_mask, reg_scratch);
    }
    vmovups(vmm_in_mask, ptr[reg64_mask]);
}
140 | |
141 | // NOTE: putting the function's definition in the header results in |
142 | // a compilation error for VS. |
template <cpu_isa_t isa>
void jit_uni_cvt_ps_to_xf16_t<isa>::cvt_ps_to_xf16(
        const int idx, const bool is_tail) {
    // The generic template must never be reached: every instantiated ISA
    // (avx2_vnni_2, avx512_core via the bf16 subclass, avx512_core_fp16)
    // provides an explicit specialization/override below.
    assert(!"unimplemented template" );
}
148 | |
template <>
void jit_uni_cvt_ps_to_xf16_t<avx512_core_fp16>::cvt_ps_to_xf16(
        const int idx, const bool is_tail) {
    // Converts one simd_w_-wide vector of f32 at element offset `idx` into
    // f16 with a single vcvtps2ph to a memory destination. For the tail,
    // the load is zero-masked (T_z) and the store is masked, so neither
    // side touches memory past the buffer end.
    const Vmm vmm_m_in = is_tail ? vmm_input | ktail_f32_mask | T_z : vmm_input;
    const size_t out_offset = sizeof(float16_t) * idx;
    const auto addr_m_out = is_tail
            ? ptr[reg_output + out_offset] | ktail_xf16_mask
            : ptr[reg_output + out_offset];
    vmovups(vmm_m_in, ptr[reg_input + sizeof(float) * idx]);
    vcvtps2ph(addr_m_out, vmm_input, _op_mxcsr);
}
160 | |
161 | void jit_avx512_core_cvt_ps_to_bf16_t::cvt_ps_to_xf16( |
162 | const int idx, const bool is_tail) { |
163 | const size_t out_offset = sizeof(float16_t) * idx; |
164 | const auto addr_m_out = is_tail |
165 | ? ptr[reg_output + out_offset] | ktail_xf16_mask |
166 | : ptr[reg_output + out_offset]; |
167 | |
168 | if (use_bf16_emu_) { |
169 | const Vmm vmm_m_in |
170 | = is_tail ? vmm_input | ktail_f32_mask | T_z : vmm_input; |
171 | vmovups(vmm_m_in, ptr[reg_input + sizeof(float) * idx]); |
172 | bf16_emu_->vcvtneps2bf16(vmm_output, vmm_input); |
173 | } else { |
174 | const auto vmm_m_out |
175 | = is_tail ? vmm_output | ktail_xf16_mask | T_z : vmm_output; |
176 | vcvtneps2bf16(vmm_m_out, ptr[reg_input + sizeof(float) * idx]); |
177 | } |
178 | vmovdqu16(addr_m_out, vmm_output); |
179 | } |
180 | |
template <>
void jit_uni_cvt_ps_to_xf16_t<avx2_vnni_2>::cvt_ps_to_xf16(
        const int idx, const bool is_tail) {
    // AVX2 path: no opmask registers, so tails use vmaskmovps for the load
    // (with vmm_in_mask prepared by setup_mask()) and byte-granular stores.
    if (is_tail) {
        // Zero first so unmasked lanes hold 0 rather than stale data.
        uni_vxorps(vmm_input, vmm_input, vmm_input);
        vmaskmovps(
                vmm_input, vmm_in_mask, ptr[reg_input + sizeof(float) * idx]);
    } else if (output_dt_ == data_type::f16) {
        // bf16 full-vector path converts straight from memory below, so a
        // register load is only needed for f16.
        vmovups(vmm_input, ptr[reg_input + sizeof(float) * idx]);
    }

    switch (output_dt_) {
        case data_type::bf16:
            if (is_tail)
                vcvtneps2bf16(vmm_output, vmm_input, Xbyak::VexEncoding);
            else
                vcvtneps2bf16(vmm_output,
                        yword[reg_input + sizeof(float) * idx],
                        Xbyak::VexEncoding);
            break;
        case data_type::f16:
            if (is_tail)
                vcvtps2ph(vmm_output, vmm_input, _op_mxcsr);
            else
                // Full f16 vector is stored directly by vcvtps2ph; no
                // separate store is emitted later for this case.
                vcvtps2ph(ptr[reg_output + sizeof(float16_t) * idx], vmm_input,
                        _op_mxcsr);
            break;
        default: assert(!"Invalid datatype" );
    }

    if (is_tail) {
        // Store exactly `tail` elements: byte-wise store of
        // tail * sizeof(xf16) bytes, with the count either compile-time
        // (tail_size_) or runtime (reg_tail).
        auto tail_store = [&](int load_size) {
            store_bytes(vmm_output, reg_output, sizeof(float16_t) * idx,
                    sizeof(float16_t) * load_size);
        };
        if (is_dynamic_size_)
            runtime_tail_process<Xbyak::Xmm>(
                    reg_tail, reg_tmp, tail_store, data_type::f16);
        else
            tail_store(tail_size_);

    } else if (output_dt_ == data_type::bf16)
        vmovups(ptr[reg_output + sizeof(bfloat16_t) * idx], vmm_output);
}
225 | |
226 | #undef GET_OFF |
227 | |
228 | template struct jit_uni_cvt_ps_to_xf16_t<avx2_vnni_2>; |
229 | template struct jit_uni_cvt_ps_to_xf16_t<avx512_core>; |
230 | template struct jit_uni_cvt_ps_to_xf16_t<avx512_core_fp16>; |
231 | |
232 | #define GET_OFF(field) \ |
233 | offsetof(cvt_xf16_support::jit_cvt_xf16_to_ps_params_t, field) |
234 | |
template <cpu_isa_t isa>
void jit_uni_cvt_xf16_to_ps_t<isa>::generate() {
    preamble();
    // When row_stride_ in bytes does not fit in 32 bits, pointer advancement
    // needs a dedicated 64-bit register instead of a lea displacement.
    const bool long_row_stride = (row_stride_ * sizeof(float16_t) >> 32) != 0;
    MAYBE_UNUSED(long_row_stride);

    // Kernel arguments (see cvt_xf16_support::jit_cvt_xf16_to_ps_params_t).
    mov(reg_input, ptr[abi_param1 + GET_OFF(inp)]);
    mov(reg_output, ptr[abi_param1 + GET_OFF(out)]);
    mov(reg_nelems, ptr[abi_param1 + GET_OFF(nelems)]);
    mov(reg_nrows, ptr[abi_param1 + GET_OFF(rows)]);

    Label l_row_start, l_row_end;
    Label l_exit; // used for row_stride_

    if (row_stride_) {
        test(reg_nrows, reg_nrows);
        jz(l_exit, T_NEAR); // fast exit: nrows == 0
        mov(reg_nelems_save, reg_nelems);
        // reg_rollback <- -(nelems rounded down to a simd_w_ multiple):
        // the amount by which the simd loops below advance the pointers
        // per row (the masked tail does not advance them).
        mov(reg_rollback, reg_nelems);
        and_(reg_rollback, ~(simd_w_ - 1));
        neg(reg_rollback);
        if (long_row_stride) {
            // Precompute (row_stride - consumed) in bytes once, since the
            // stride is too large for an addressing-mode displacement.
            mov(reg_long_row_stride, row_stride_ * sizeof(float16_t));
            lea(reg_long_row_stride,
                    ptr[reg_long_row_stride
                            + reg_rollback * sizeof(float16_t)]);
        }
    }

    L(l_row_start);

    constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0
    Label l_simd_loop[n_unroll + 2];
    // Cascade of loops with decreasing unroll factor; each level falls
    // through (jl) to the next-narrower one when too few elements remain.
    for (int i = n_unroll; i >= 0; i--) {
        const int unroll = 1 << i; // 4, 2, 1
        assert(IMPLICATION(unroll > 1, unroll % 2 == 0));
        L(l_simd_loop[i + 1]);
        {
            cmp(reg_nelems, simd_w_ * unroll);
            jl(l_simd_loop[i], T_NEAR);
            // convert_xf16 may process elem_granularity vectors per call
            // (e.g. the even/odd AVX2 converters), hence the div_up.
            for (int j = 0; j < utils::div_up(unroll, elem_granularity); ++j)
                convert_xf16(j, unroll > 1);
            add(reg_input, simd_w_ * unroll * sizeof(float16_t));
            add(reg_output, simd_w_ * unroll * sizeof(float));
            sub(reg_nelems, simd_w_ * unroll);
            // Only the widest level needs to loop back: after it finishes,
            // nelems < simd_w_ * unroll, so each narrower level can run at
            // most once and simply falls through.
            if (i == n_unroll && n_unroll != 0) jmp(l_simd_loop[i + 1], T_NEAR);
        }
    }
    L(l_simd_loop[0]);

    // Sub-simd remainder of the row (0 < nelems < simd_w_), masked.
    test(reg_nelems, reg_nelems);
    jz(l_row_end, T_NEAR);

    mov(reg_tail, reg_nelems);
    cvt_tail();

    L(l_row_end);

    if (row_stride_) {
        dec(reg_nrows);
        jz(l_exit, T_NEAR);

        // wraparound
        // Undo the per-row pointer advancement (rollback is negative) and
        // step the input forward by one row stride, then restart the row.
        lea(reg_output, ptr[reg_output + reg_rollback * sizeof(float)]);
        if (long_row_stride)
            add(reg_input, reg_long_row_stride);
        else
            lea(reg_input,
                    ptr[reg_input + reg_rollback * sizeof(float16_t)
                            + row_stride_ * sizeof(float16_t)]);
        mov(reg_nelems, reg_nelems_save);
        jmp(l_row_start);

        L(l_exit);
    }

    postamble();
}
313 | |
template <cpu_isa_t isa>
void jit_uni_cvt_xf16_to_ps_t<isa>::convert_xf16(
        const int idx, const bool handle_x2) {
    // Converts one simd_w_-wide vector of xf16 at vector index `idx` into
    // f32, optionally accumulating into the existing output (with_add_).
    const size_t offset = idx * simd_w_;
    const auto out_addr = ptr[reg_output + sizeof(float) * offset];
    // sizeof(bfloat16_t) is used for the input offset for both input types;
    // bf16 and f16 are both 2 bytes, so the address is correct either way.
    const auto in_addr = ptr[reg_input + sizeof(bfloat16_t) * offset];
    switch (input_dt_) {
        case data_type::bf16:
            // bf16 -> f32: zero-extend each 16-bit word to 32 bits, then
            // shift it into the upper half of the dword (f32 bit pattern).
            vpmovzxwd(get_vmm_src(idx), in_addr);
            vpslld(get_vmm_src(idx), get_vmm_src(idx), 0x10);
            break;
        case data_type::f16: vcvtph2psx(get_vmm_src(idx), in_addr); break;
        default: assert(!"Invalid datatype" );
    }
    if (with_add_) vaddps(get_vmm_src(idx), get_vmm_src(idx), out_addr);
    uni_vmovdqu(out_addr, get_vmm_src(idx));
}
331 | |
332 | template <typename Wmm> |
333 | struct helper_avx2_cvt_xf16_t { |
334 | static void convert_xf16(jit_generator *host, |
335 | const impl::data_type_t input_dt, const Xbyak::Address in_addr, |
336 | const int even_src, const int odd_src, const int tmp_1, |
337 | const int tmp_2) { |
338 | const Wmm vmm_even_src = Wmm(even_src); |
339 | const Wmm vmm_odd_src = Wmm(odd_src); |
340 | const Wmm vmm_tmp_1 = Wmm(tmp_1); |
341 | const Wmm vmm_tmp_2 = Wmm(tmp_2); |
342 | |
343 | switch (input_dt) { |
344 | case data_type::bf16: |
345 | host->vcvtneebf162ps(vmm_even_src, in_addr); |
346 | host->vcvtneobf162ps(vmm_odd_src, in_addr); |
347 | break; |
348 | case data_type::f16: |
349 | host->vcvtneeph2ps(vmm_even_src, in_addr); |
350 | host->vcvtneoph2ps(vmm_odd_src, in_addr); |
351 | break; |
352 | default: assert(!"Invalid datatype" ); |
353 | } |
354 | host->vpunpckldq(vmm_tmp_1, vmm_even_src, vmm_odd_src); |
355 | host->vpunpckhdq(vmm_tmp_2, vmm_even_src, vmm_odd_src); |
356 | } |
357 | }; |
358 | |
template <>
void jit_uni_cvt_xf16_to_ps_t<avx2_vnni_2>::convert_xf16(
        const int idx, const bool handle_x2) {
    // AVX2 path using the even/odd AVX-NE-CONVERT instructions: one call
    // handles elem_granularity vectors (two when handle_x2).
    const Vmm vmm_tmp_1 = vmm_tmp;
    // NOTE: vmm_tmp_2 deliberately aliases the even-src register; the
    // helper's unpack sequence tolerates that (see helper_avx2_cvt_xf16_t).
    const Vmm vmm_tmp_2 = Vmm(get_even_src_idx(idx));
    const size_t offset = idx * simd_w_ * elem_granularity;
    const auto in_addr = ptr[reg_input + sizeof(bfloat16_t) * offset];
    auto get_out_addr = [&](const size_t offset_xmmword = 0) {
        return ptr[reg_output + sizeof(float) * (offset + offset_xmmword)];
    };

    // Full-Ymm conversion when processing 2x simd_w_, Xmm width otherwise.
    if (handle_x2)
        helper_avx2_cvt_xf16_t<Xbyak::Ymm>::convert_xf16(this, input_dt_,
                in_addr, get_even_src_idx(idx), get_odd_src_idx(idx),
                vmm_tmp_1.getIdx(), vmm_tmp_2.getIdx());
    else
        helper_avx2_cvt_xf16_t<Xbyak::Xmm>::convert_xf16(this, input_dt_,
                in_addr, get_even_src_idx(idx), get_odd_src_idx(idx),
                vmm_tmp_1.getIdx(), vmm_tmp_2.getIdx());

    // Merge the unpacked halves back into element order: low 128-bit lanes
    // of tmp_1/tmp_2 form the first output vector, high lanes the second.
    vperm2f128(vmm_dst, vmm_tmp_1, vmm_tmp_2, 0x20);
    if (handle_x2) vperm2f128(vmm_dst_2, vmm_tmp_1, vmm_tmp_2, 0x31);

    if (with_add_) {
        vaddps(vmm_dst, vmm_dst, get_out_addr());
        if (handle_x2) vaddps(vmm_dst_2, vmm_dst_2, get_out_addr(simd_w_));
    }
    uni_vmovdqu(get_out_addr(), vmm_dst);
    if (handle_x2) uni_vmovdqu(get_out_addr(simd_w_), vmm_dst_2);
}
389 | |
template <cpu_isa_t isa>
void jit_uni_cvt_xf16_to_ps_t<isa>::cvt_tail() {
    // AVX-512 tail: convert the remaining reg_tail (< simd_w_) elements
    // under an opmask so no out-of-bounds memory is touched.
    const Reg32 reg32_mask
            = reg_nelems.cvt32(); // no need for reg_nelems anymore

    // ktail_mask <-- (1 << (nelems % simd_w_)) - 1
    mov(reg32_mask, 1);
    shl(reg32_mask, reg_tail.cvt8());
    sub(reg32_mask, 1);
    kmovd(ktail_mask, reg32_mask);

    // All ops below write through the zeroing mask (T_z), so inactive
    // lanes of the source register stay zero.
    auto vmm_masked = get_vmm_src(0) | ktail_mask | T_z;
    switch (input_dt_) {
        case data_type::bf16:
            // bf16 -> f32: zero-extend words to dwords, then shift into
            // the upper half of each dword to form the f32 bit pattern.
            vpmovzxwd(vmm_masked, ptr[reg_input]);
            vpslld(vmm_masked, get_vmm_src(0), 0x10);
            break;
        case data_type::f16: vcvtph2psx(vmm_masked, ptr[reg_input]); break;
        default: assert(!"Invalid datatype" );
    }
    if (with_add_) vaddps(vmm_masked, get_vmm_src(0), ptr[reg_output]);
    vmovdqu32(ptr[reg_output] | ktail_mask, get_vmm_src(0));
}
413 | |
414 | template <> |
415 | void jit_uni_cvt_xf16_to_ps_t<avx2_vnni_2>::cvt_tail() { |
416 | const Vmm vmm_output = get_vmm_src(0); |
417 | const Vmm_down_t vmm_input = Vmm_down_t(vmm_output.getIdx()); |
418 | auto runtime_tail_load = [&](int load_size) { |
419 | load_bytes(vmm_input, reg_input, 0, sizeof(bfloat16_t) * load_size); |
420 | }; |
421 | auto runtime_tail_store = [&](int load_size) { |
422 | store_data(data_type::f32, vmm_output, reg_output, 0, load_size); |
423 | }; |
424 | |
425 | uni_vxorps(vmm_input, vmm_input, vmm_input); |
426 | runtime_tail_process<Xbyak::Xmm>( |
427 | reg_tail, reg_tmp, runtime_tail_load, data_type::f16); |
428 | switch (input_dt_) { |
429 | case data_type::bf16: |
430 | vpmovzxwd(vmm_output, vmm_input); |
431 | vpslld(vmm_output, vmm_input, 0x10); |
432 | break; |
433 | case data_type::f16: vcvtph2ps(vmm_output, vmm_input); break; |
434 | default: assert(!"Invalid datatype" ); |
435 | } |
436 | runtime_tail_process<Xbyak::Ymm>( |
437 | reg_tail, reg_tmp, runtime_tail_store, data_type::f32); |
438 | } |
439 | |
440 | #undef GET_OFF |
441 | |
442 | template struct jit_uni_cvt_xf16_to_ps_t<avx2_vnni_2>; |
443 | template struct jit_uni_cvt_xf16_to_ps_t<avx512_core>; |
444 | template struct jit_uni_cvt_xf16_to_ps_t<avx512_core_fp16>; |
445 | |
446 | } // namespace x64 |
447 | } // namespace cpu |
448 | } // namespace impl |
449 | } // namespace dnnl |
450 | |