1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX512_CORE_BF16CVT_HPP |
18 | #define CPU_X64_JIT_AVX512_CORE_BF16CVT_HPP |
19 | |
#include <assert.h>

#include <memory>

#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "cpu/x64/cpu_isa_traits.hpp"
#include "oneapi/dnnl/dnnl_debug.h"

#include "common/bfloat16.hpp"
#include "cpu/x64/jit_generator.hpp"
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | namespace x64 { |
35 | |
36 | namespace bf16_support { |
// Parameter block passed (by pointer, as abi_param1) to the JIT-ed kernels
// below. Field offsets are read from generated code via GET_OFF/offsetof,
// so the member order/layout must stay in sync with the kernels' loads.
struct jit_call_t {
    void *inp; // source array
    void *out; // destination array
    void *add; // second source array (used by the add+convert kernel)
    size_t nelems; // number of elements to process
    int mask; // 16-bit load mask (used by the reorder kernel)
};
44 | } // namespace bf16_support |
45 | |
46 | #define GET_OFF(field) offsetof(bf16_support::jit_call_t, field) |
47 | |
// Emulates the AVX512_BF16 instructions `vdpbf16ps` and `vcvtneps2bf16`
// with plain avx512_core instructions, for CPUs without native bf16
// support. The caller supplies the constant/temporary vector registers and
// a 64-bit scratch GPR; init_vcvtneps2bf16() must be emitted once before
// the emulated vcvtneps2bf16() sequence is used.
struct bf16_emulation_t {
    using opmask_t = const Xbyak::Opmask;
    using Zmm_t = const Xbyak::Zmm;
    using Ymm_t = const Xbyak::Ymm;
    using Xmm_t = const Xbyak::Xmm;
    using reg64_t = const Xbyak::Reg64;

    // Full constructor: tr0 and tr1 are two temporaries; vdpbf16ps
    // requires them to be distinct registers.
    bf16_emulation_t(jit_generator *host, Zmm_t one, Zmm_t even, Zmm_t selector,
            reg64_t scratch, Zmm_t tr0, Zmm_t tr1)
        : host_(host)
        , one_(one)
        , even_(even)
        , selector_(selector)
        , scratch_(scratch)
        , tr0_(tr0)
        , tr1_(tr1) {}

    // Single-temporary constructor (tr1 aliases tr0): only valid when the
    // object is used exclusively for vcvtneps2bf16().
    bf16_emulation_t(jit_generator *host, Zmm_t one, Zmm_t even, Zmm_t selector,
            reg64_t scratch, Zmm_t tr0)
        : bf16_emulation_t(host, one, even, selector, scratch, tr0, tr0) {}

    // Emulated `vdpbf16ps acc, wei, inp`: each 32-bit lane of wei/inp holds
    // a pair of bf16 values; acc(f32) += hi(wei)*hi(inp) + lo(wei)*lo(inp).
    void vdpbf16ps(Zmm_t &acc, Zmm_t wei, Zmm_t inp) {
        // upper bf16 of each lane -> f32: arithmetic shift right then shift
        // back left leaves the bf16 bits in the high half, low half zeroed
        host_->vpsrad(tr0_, wei, 16);
        host_->vpslld(tr0_, tr0_, 16);

        host_->vpsrad(tr1_, inp, 16);
        host_->vpslld(tr1_, tr1_, 16);

        host_->vfmadd231ps(acc, tr1_, tr0_);

        // lower bf16 of each lane -> f32: move it into the high half
        host_->vpslld(tr0_, wei, 16);
        host_->vpslld(tr1_, inp, 16);

        host_->vfmadd231ps(acc, tr1_, tr0_);
    }

    // Emulated `vcvtneps2bf16 out, in`: down-convert a vector of f32 to
    // bf16 with round-to-nearest-even. Supported width combinations:
    // zmm in -> ymm out, or ymm in -> xmm out.
    void vcvtneps2bf16(Xmm_t &out, Xmm_t &in) {
        const bool input_is_zmm = in.isZMM() && out.isYMM();
        const bool input_is_ymm = in.isYMM() && out.isXMM();
        assert((input_is_zmm || input_is_ymm)
                && "Incorrect usage of vcvtneps2bf16 instruction." );

        if (input_is_zmm)
            vcvtneps2bf16(out, in, tr0_, one_, even_, selector_);
        else if (input_is_ymm) {
            // reuse the same physical registers at ymm width
            const Ymm_t tr0_y {tr0_.getIdx()};
            const Ymm_t even_y {even_.getIdx()};
            const Ymm_t selector_y {selector_.getIdx()};
            const Ymm_t one_y {one_.getIdx()};

            vcvtneps2bf16(out, in, tr0_y, one_y, even_y, selector_y);
        }
    }

private:
    // Round-to-nearest-even f32 -> bf16 core sequence:
    //   rounded = in + 0x7fff + ((in >> 16) & 1)
    // i.e. bias by half a bf16 ulp and break ties toward the even
    // (zero-LSB) result; vfixupimmps then overrides NaN/Inf lanes per
    // `selector` (see init_vcvtneps2bf16), and the high 16 bits of each
    // lane are packed into `out`.
    void vcvtneps2bf16(const Xbyak::Operand &out, const Xmm_t &in,
            const Xmm_t &tr0, const Xbyak::Operand &one, const Xmm_t &even,
            const Xbyak::Operand &selector) {
        host_->vpsrld(tr0, in, 16);
        host_->vpandd(tr0, tr0, one); // LSB of the truncated bf16 result

        host_->vpaddd(tr0, even, tr0); // 0x7fff, plus 1 when LSB is odd

        host_->vpaddd(tr0, in, tr0); // biased value
        host_->vfixupimmps(tr0, in, selector, 0); // fix NaN/Inf lanes

        host_->vpsrad(tr0, tr0, 16);
        host_->vpmovdw(out, tr0); // pack low words (= rounded bf16) to out
    }

public:
    // Emits code that materializes the constants used by the emulated
    // vcvtneps2bf16: one_ = 1 (LSB mask), even_ = 0x7fff (rounding bias),
    // selector_ = vfixupimmps control word that quietizes NaNs and copies
    // +/-Inf through unchanged.
    void init_vcvtneps2bf16() {
        const int selector_int32 =
                /* snan input to qnan output (preserving input bits 0..21) */
                encode_fixup_selector(
                        fixup_input_code_snan, fixup_output_code_qnan_input)
                |
                /* qnan input to qnan output (preserving input bits 0..21) */
                encode_fixup_selector(
                        fixup_input_code_qnan, fixup_output_code_qnan_input)
                |
                /* neg inf input copied to output */
                encode_fixup_selector(
                        fixup_input_code_ninf, fixup_output_code_copy_input)
                |
                /* pos inf input copied to output */
                encode_fixup_selector(
                        fixup_input_code_pinf, fixup_output_code_copy_input);

        host_->xor_(scratch_, scratch_);
        host_->mov(scratch_.cvt32(), 0x1);
        host_->vpbroadcastd(one_, scratch_.cvt32());

        host_->xor_(scratch_, scratch_);
        host_->mov(scratch_.cvt32(), 0x7fff);
        host_->vpbroadcastd(even_, scratch_.cvt32());

        host_->xor_(scratch_, scratch_);
        host_->mov(scratch_.cvt32(), selector_int32);
        host_->vpbroadcastd(selector_, scratch_.cvt32());
    }

    // Minimal ISA required to run the emulation sequences above.
    static cpu_isa_t get_isa() { return avx512_core; }

private:
    jit_generator *const host_;
    Zmm_t one_; // all lanes = 1 (see init_vcvtneps2bf16)
    Zmm_t even_; // all lanes = 0x7fff, round-to-nearest-even bias
    Zmm_t selector_; // vfixupimmps control table
    reg64_t scratch_; // GPR used to materialize the constants
    Zmm_t tr0_; // temporary
    Zmm_t tr1_; // temporary; may alias tr0_ (single-temporary ctor)

    // Builds one term of the vfixupimmps control word: places the 4-bit
    // `output` response code at the nibble selected by input class `input`.
    int encode_fixup_selector(int input, int output) {
        return ((output) << (4 * (input)));
    }

    // Input-class and output-response codes for the vfixupimmps control
    // word (values follow the instruction's token encoding).
    enum {
        fixup_input_code_qnan = 0,
        fixup_input_code_snan = 1,
        fixup_input_code_ninf = 4,
        fixup_input_code_pinf = 5,
        fixup_output_code_copy_input = 1,
        fixup_output_code_qnan_input = 2,
    };
};
174 | |
175 | // performs element-by-element sum of inp and add float arrays and stores |
176 | // result to bfloat16 out array with downconversion |
177 | struct jit_avx512_core_add_cvt_ps_to_bf16_t : public jit_generator { |
178 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_add_cvt_ps_to_bf16) |
179 | |
180 | jit_avx512_core_add_cvt_ps_to_bf16_t() |
181 | : jit_generator(jit_name()), simd_w_(16) { |
182 | bf16_emu_ = new bf16_emulation_t( |
183 | this, one, even, selector, scratch, fp32_tmp, fp32_tmp); |
184 | |
185 | create_kernel(); |
186 | } |
187 | |
188 | ~jit_avx512_core_add_cvt_ps_to_bf16_t() { delete bf16_emu_; } |
189 | |
190 | void generate() override { |
191 | preamble(); |
192 | |
193 | bool use_bf16_emu = !mayiuse(avx512_core_bf16); |
194 | |
195 | auto add_cvt = [&](size_t idx, Xbyak::Opmask ktail_mask) { |
196 | vmovups(fp32_inp | ktail_mask | T_z, |
197 | ptr[reg_inp + sizeof(float) * (idx)]); |
198 | vaddps(fp32_inp | ktail_mask | T_z, fp32_inp, |
199 | ptr[reg_add + sizeof(float) * (idx)]); |
200 | if (use_bf16_emu) |
201 | bf16_emu_->vcvtneps2bf16(bf16_out, fp32_inp); |
202 | else |
203 | vcvtneps2bf16(bf16_out, fp32_inp); |
204 | |
205 | vmovdqu16(yword[reg_out + sizeof(bfloat16_t) * (idx)] | ktail_mask, |
206 | bf16_out); |
207 | }; |
208 | |
209 | mov(reg_inp, ptr[abi_param1 + GET_OFF(inp)]); |
210 | mov(reg_add, ptr[abi_param1 + GET_OFF(add)]); |
211 | mov(reg_out, ptr[abi_param1 + GET_OFF(out)]); |
212 | mov(reg_nelems, ptr[abi_param1 + GET_OFF(nelems)]); |
213 | |
214 | if (use_bf16_emu) bf16_emu_->init_vcvtneps2bf16(); |
215 | |
216 | mov(reg32_tail, 0xffff); |
217 | kmovw(ktail_mask, reg32_tail); |
218 | |
219 | constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0 |
220 | Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail; |
221 | for (int i = n_unroll; i >= 0; i--) { |
222 | const int unroll = 1 << i; // 4, 2, 1 |
223 | L(l_simd_loop[i + 1]); |
224 | { |
225 | cmp(reg_nelems, simd_w_ * unroll); |
226 | jl(l_simd_loop[i], T_NEAR); |
227 | for (int j = 0; j < simd_w_ * unroll; j += simd_w_) { |
228 | add_cvt(j, ktail_mask); |
229 | } |
230 | add(reg_inp, simd_w_ * unroll * sizeof(float)); |
231 | add(reg_add, simd_w_ * unroll * sizeof(float)); |
232 | add(reg_out, simd_w_ * unroll * sizeof(bfloat16_t)); |
233 | |
234 | sub(reg_nelems, simd_w_ * unroll); |
235 | jmp(l_simd_loop[i + 1], T_NEAR); |
236 | } |
237 | } |
238 | L(l_simd_loop[0]); |
239 | test(reg_nelems, reg_nelems); |
240 | jz(l_simd_notail); |
241 | // JIT of `tail_mask_ = (1 << (nelems_ % simd_w_)) - 1;` |
242 | mov(reg32_mask, 1); |
243 | mov(reg64_tail, reg_nelems); |
244 | shl(reg32_mask, reg8_mask_shift); |
245 | sub(reg32_mask, 1); |
246 | kmovd(ktail_mask, reg32_mask); |
247 | add_cvt(0, ktail_mask); |
248 | L(l_simd_notail); |
249 | |
250 | postamble(); |
251 | } |
252 | |
253 | void operator()(bf16_support::jit_call_t *params) const { |
254 | jit_generator::operator()(params); |
255 | msan_unpoison(params->out, params->nelems * sizeof(bfloat16_t)); |
256 | } |
257 | |
258 | private: |
259 | int simd_w_; |
260 | |
261 | bf16_emulation_t *bf16_emu_; |
262 | |
263 | Xbyak::Opmask ktail_mask = k2; |
264 | Xbyak::Zmm fp32_inp = Xbyak::Zmm(0); |
265 | Xbyak::Zmm fp32_tmp = Xbyak::Zmm(1); |
266 | |
267 | Xbyak::Zmm one = Xbyak::Zmm(2); |
268 | Xbyak::Zmm even = Xbyak::Zmm(3); |
269 | Xbyak::Zmm selector = Xbyak::Zmm(4); |
270 | Xbyak::Reg64 scratch = r15; |
271 | |
272 | Xbyak::Ymm bf16_out = Xbyak::Ymm(5); |
273 | |
274 | Xbyak::Reg64 reg_inp = rax; |
275 | Xbyak::Reg64 reg_out = rbx; |
276 | Xbyak::Reg64 reg_add = r11; |
277 | Xbyak::Reg64 reg_nelems = rdx; |
278 | |
279 | Xbyak::Reg64 reg64_tail = rcx; |
280 | Xbyak::Reg32 reg32_tail = ecx; |
281 | Xbyak::Reg8 reg8_mask_shift = cl; |
282 | Xbyak::Reg32 reg32_mask = r8d; |
283 | }; |
284 | |
285 | // implementation of reorder of part of tensor [s][16c] -> [S][16c][2s] |
286 | // it is required for quick implementation of 1x1 bf16 bwd_w jit kernel |
287 | // w/o using permw instruction inside |
288 | // TODO: consider modification/replacement for outer transformation jit kernel |
289 | struct jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t : public jit_generator { |
290 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_bf16_reorder_s16c_to_S16c2s) |
291 | |
292 | jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t() |
293 | : jit_generator(jit_name()), simd_w_(16), in_stride_(16) {} |
294 | |
295 | jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t(int in_stride) |
296 | : jit_generator(jit_name()), simd_w_(16), in_stride_(in_stride) {} |
297 | |
298 | ~jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t() {} |
299 | |
300 | void generate() override { |
301 | preamble(); |
302 | |
303 | mov(reg32_tail, ptr[abi_param1 + GET_OFF(mask)]); |
304 | mov(reg_inp, ptr[abi_param1 + GET_OFF(inp)]); |
305 | mov(reg_out, ptr[abi_param1 + GET_OFF(out)]); |
306 | mov(reg_nelems, ptr[abi_param1 + GET_OFF(nelems)]); |
307 | |
308 | auto zmm_reg = [=](int idx) { |
309 | assert(idx < 31); |
310 | return Xbyak::Zmm(idx); |
311 | }; |
312 | |
313 | kmovd(ktail_mask_lo, reg32_tail); |
314 | kshiftld(ktail_mask_hi, ktail_mask_lo, 16); |
315 | |
316 | Xbyak::Label dst_prm_table; |
317 | mov(reg_prm, dst_prm_table); |
318 | vmovups(zmm_prm, ptr[reg_prm]); |
319 | |
320 | constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0 |
321 | int sizeofcacheline = 2 * simd_w_ * sizeof(bfloat16_t); |
322 | int in_stride_bytes = in_stride_ * sizeof(bfloat16_t); |
323 | Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail; |
324 | for (int i = n_unroll; i >= 0; i--) { |
325 | const int unroll = 1 << i; // 4, 2, 1 |
326 | L(l_simd_loop[i + 1]); |
327 | { |
328 | cmp(reg_nelems, 2 * unroll); |
329 | jl(l_simd_loop[i], T_NEAR); |
330 | for (int j = 0; j < unroll; j++) { |
331 | auto zmm_inp = zmm_reg(j); |
332 | if (in_stride_ == 16) |
333 | vmovups(zmm_inp, zword[reg_inp + j * sizeofcacheline]); |
334 | else { |
335 | vmovdqu16(zmm_inp | ktail_mask_lo | T_z, |
336 | zword[reg_inp + 2 * j * in_stride_bytes]); |
337 | vmovdqu16(zmm_inp | ktail_mask_hi, |
338 | zword[reg_inp + (2 * j + 1) * in_stride_bytes |
339 | - 32]); |
340 | } |
341 | vpermw(zmm_inp, zmm_prm, zmm_inp); |
342 | vmovups(zword[reg_out + j * sizeofcacheline], zmm_inp); |
343 | } |
344 | add(reg_inp, |
345 | unroll |
346 | * (in_stride_ == 16 ? sizeofcacheline |
347 | : 2 * in_stride_bytes)); |
348 | add(reg_out, unroll * sizeofcacheline); |
349 | |
350 | sub(reg_nelems, 2 * unroll); |
351 | jmp(l_simd_loop[i + 1], T_NEAR); |
352 | } |
353 | } |
354 | L(l_simd_loop[0]); |
355 | |
356 | test(reg_nelems, reg_nelems); |
357 | jz(l_simd_notail); |
358 | |
359 | auto zmm_inp = zmm_reg(0); |
360 | vpxord(zmm_inp, zmm_inp, zmm_inp); |
361 | vmovdqu16(zmm_inp | ktail_mask_lo | T_z, ptr[reg_inp]); |
362 | vpermw(zmm_inp, zmm_prm, zmm_inp); |
363 | vmovups(zword[reg_out], zmm_inp); |
364 | |
365 | L(l_simd_notail); |
366 | |
367 | postamble(); |
368 | |
369 | const uint16_t dst_prm_array[32] = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, |
370 | 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, |
371 | 29, 14, 30, 15, 31}; |
372 | |
373 | align(64); |
374 | L(dst_prm_table); |
375 | for (size_t i = 0; i < 32; ++i) |
376 | dw(dst_prm_array[i]); |
377 | } |
378 | |
379 | void operator()(bf16_support::jit_call_t *params) const { |
380 | jit_generator::operator()(params); |
381 | msan_unpoison(params->out, params->nelems * sizeof(bfloat16_t)); |
382 | } |
383 | |
384 | private: |
385 | int simd_w_; |
386 | int in_stride_; |
387 | |
388 | Xbyak::Opmask ktail_mask_lo = k2; |
389 | Xbyak::Opmask ktail_mask_hi = k3; |
390 | Xbyak::Zmm zmm_prm = Xbyak::Zmm(31); |
391 | |
392 | Xbyak::Reg64 reg_inp = rax; |
393 | Xbyak::Reg64 reg_out = rbx; |
394 | Xbyak::Reg64 reg_prm = r11; |
395 | Xbyak::Reg64 reg_nelems = rdx; |
396 | |
397 | Xbyak::Reg32 reg32_tail = abi_not_param1.cvt32(); |
398 | }; |
399 | |
400 | #undef GET_OFF |
401 | } // namespace x64 |
402 | } // namespace cpu |
403 | } // namespace impl |
404 | } // namespace dnnl |
405 | |
406 | #endif |
407 | |