/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_AVX512_CORE_BF16CVT_HPP
#define CPU_X64_JIT_AVX512_CORE_BF16CVT_HPP

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "cpu/x64/cpu_isa_traits.hpp"
#include "oneapi/dnnl/dnnl_debug.h"

#include "common/bfloat16.hpp"
#include "cpu/x64/jit_generator.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace bf16_support {
// Arguments passed to the JIT kernels defined in this file.
struct jit_call_t {
    void *inp; // source array
    void *out; // destination array
    void *add; // second f32 addend (used by the add+convert kernel)
    size_t nelems; // number of elements to process
    int mask; // tail mask (used by the reorder kernel)
};
} // namespace bf16_support

#define GET_OFF(field) offsetof(bf16_support::jit_call_t, field)

// Emulates the avx512_core_bf16 instructions vdpbf16ps and vcvtneps2bf16
// with plain avx512_core instructions, using caller-provided temporaries.
struct bf16_emulation_t {
    using opmask_t = const Xbyak::Opmask;
    using Zmm_t = const Xbyak::Zmm;
    using Ymm_t = const Xbyak::Ymm;
    using Xmm_t = const Xbyak::Xmm;
    using reg64_t = const Xbyak::Reg64;

    bf16_emulation_t(jit_generator *host, Zmm_t one, Zmm_t even, Zmm_t selector,
            reg64_t scratch, Zmm_t tr0, Zmm_t tr1)
        : host_(host)
        , one_(one)
        , even_(even)
        , selector_(selector)
        , scratch_(scratch)
        , tr0_(tr0)
        , tr1_(tr1) {}

    bf16_emulation_t(jit_generator *host, Zmm_t one, Zmm_t even, Zmm_t selector,
            reg64_t scratch, Zmm_t tr0)
        : bf16_emulation_t(host, one, even, selector, scratch, tr0, tr0) {}

    void vdpbf16ps(Zmm_t &acc, Zmm_t wei, Zmm_t inp) {
        // Expand the high bf16 half of each 32-bit lane to f32 (shift right,
        // then shift back so it occupies the upper 16 bits).
        host_->vpsrad(tr0_, wei, 16);
        host_->vpslld(tr0_, tr0_, 16);

        host_->vpsrad(tr1_, inp, 16);
        host_->vpslld(tr1_, tr1_, 16);

        host_->vfmadd231ps(acc, tr1_, tr0_);

        // Expand the low bf16 halves and accumulate their product as well.
        host_->vpslld(tr0_, wei, 16);
        host_->vpslld(tr1_, inp, 16);

        host_->vfmadd231ps(acc, tr1_, tr0_);
    }
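
    // Scalar sketch of the semantics emulated above (for documentation only;
    // bf16_hi/bf16_lo are hypothetical helpers that expand the high/low bf16
    // half of a 32-bit lane to f32 by placing it in the upper 16 bits):
    //
    //   for (int i = 0; i < 16; ++i)
    //       acc[i] += bf16_hi(wei[i]) * bf16_hi(inp[i])
    //               + bf16_lo(wei[i]) * bf16_lo(inp[i]);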

    // Downconverts the f32 vector 'in' to bf16 in 'out' with
    // round-to-nearest-even (zmm -> ymm or ymm -> xmm).
    void vcvtneps2bf16(Xmm_t &out, Xmm_t &in) {
        const bool input_is_zmm = in.isZMM() && out.isYMM();
        const bool input_is_ymm = in.isYMM() && out.isXMM();
        assert((input_is_zmm || input_is_ymm)
                && "Incorrect usage of vcvtneps2bf16 instruction.");

        if (input_is_zmm)
            vcvtneps2bf16(out, in, tr0_, one_, even_, selector_);
        else if (input_is_ymm) {
            const Ymm_t tr0_y {tr0_.getIdx()};
            const Ymm_t even_y {even_.getIdx()};
            const Ymm_t selector_y {selector_.getIdx()};
            const Ymm_t one_y {one_.getIdx()};

            vcvtneps2bf16(out, in, tr0_y, one_y, even_y, selector_y);
        }
    }

private:
    void vcvtneps2bf16(const Xbyak::Operand &out, const Xmm_t &in,
            const Xmm_t &tr0, const Xbyak::Operand &one, const Xmm_t &even,
            const Xbyak::Operand &selector) {
        // Isolate the LSB of the truncated bf16 mantissa...
        host_->vpsrld(tr0, in, 16);
        host_->vpandd(tr0, tr0, one);

        // ...and add it to the 0x7fff bias to get round-to-nearest-even.
        host_->vpaddd(tr0, even, tr0);

        host_->vpaddd(tr0, in, tr0);
        // Patch NaN/Inf inputs that the bias addition would corrupt.
        host_->vfixupimmps(tr0, in, selector, 0);

        // Move the rounded upper halves down and pack dwords to words.
        host_->vpsrad(tr0, tr0, 16);
        host_->vpmovdw(out, tr0);
    }
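
    // Scalar reference of the rounding performed above (a sketch; the
    // vfixupimmps NaN/Inf handling is omitted and 'cvt_f32_to_bf16_rne' is a
    // hypothetical name):
    //
    //   uint16_t cvt_f32_to_bf16_rne(float f) {
    //       uint32_t x;
    //       memcpy(&x, &f, sizeof(x));
    //       x += 0x7fff + ((x >> 16) & 1); // round to nearest, ties to even
    //       return (uint16_t)(x >> 16);
    //   }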

public:
    void init_vcvtneps2bf16() {
        const int selector_int32 =
                /* snan input to qnan output (preserving input bits 0..21) */
                encode_fixup_selector(
                        fixup_input_code_snan, fixup_output_code_qnan_input)
                |
                /* qnan input to qnan output (preserving input bits 0..21) */
                encode_fixup_selector(
                        fixup_input_code_qnan, fixup_output_code_qnan_input)
                |
                /* neg inf input copied to output */
                encode_fixup_selector(
                        fixup_input_code_ninf, fixup_output_code_copy_input)
                |
                /* pos inf input copied to output */
                encode_fixup_selector(
                        fixup_input_code_pinf, fixup_output_code_copy_input);

        host_->xor_(scratch_, scratch_);
        host_->mov(scratch_.cvt32(), 0x1);
        host_->vpbroadcastd(one_, scratch_.cvt32());

        host_->xor_(scratch_, scratch_);
        host_->mov(scratch_.cvt32(), 0x7fff);
        host_->vpbroadcastd(even_, scratch_.cvt32());

        host_->xor_(scratch_, scratch_);
        host_->mov(scratch_.cvt32(), selector_int32);
        host_->vpbroadcastd(selector_, scratch_.cvt32());
    }
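
    // For reference: with the input/output tokens defined at the bottom of
    // this class, the selector constant above works out to
    //   (2 << 4) | (2 << 0) | (1 << 16) | (1 << 20) == 0x110022.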

    static cpu_isa_t get_isa() { return avx512_core; }

private:
    jit_generator *const host_;
    Zmm_t one_;
    Zmm_t even_;
    Zmm_t selector_;
    reg64_t scratch_;
    Zmm_t tr0_;
    Zmm_t tr1_;

    int encode_fixup_selector(int input, int output) {
        return output << (4 * input);
    }

    enum {
        fixup_input_code_qnan = 0,
        fixup_input_code_snan = 1,
        fixup_input_code_ninf = 4,
        fixup_input_code_pinf = 5,
        fixup_output_code_copy_input = 1,
        fixup_output_code_qnan_input = 2,
    };
};
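
// Usage sketch (illustrative, not part of the library API): inside a JIT
// kernel targeting plain avx512_core, reserve the helper registers, emit the
// conversion constants once, then emulate the bf16 instructions; the register
// choices and the ymm_dst/zmm_src names here are arbitrary examples:
//
//   bf16_emulation_t emu(this, Xbyak::Zmm(28) /*one*/,
//           Xbyak::Zmm(29) /*even*/, Xbyak::Zmm(30) /*selector*/,
//           r15 /*scratch*/, Xbyak::Zmm(31) /*tr0*/);
//   emu.init_vcvtneps2bf16();
//   emu.vcvtneps2bf16(ymm_dst, zmm_src); // f32 zmm -> bf16 ymm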

// Performs an element-wise sum of the 'inp' and 'add' f32 arrays and stores
// the result to the bf16 'out' array with down-conversion.
struct jit_avx512_core_add_cvt_ps_to_bf16_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_add_cvt_ps_to_bf16)

    jit_avx512_core_add_cvt_ps_to_bf16_t()
        : jit_generator(jit_name()), simd_w_(16) {
        bf16_emu_ = new bf16_emulation_t(
                this, one, even, selector, scratch, fp32_tmp, fp32_tmp);

        create_kernel();
    }

    ~jit_avx512_core_add_cvt_ps_to_bf16_t() { delete bf16_emu_; }

    void generate() override {
        preamble();

        bool use_bf16_emu = !mayiuse(avx512_core_bf16);

        auto add_cvt = [&](size_t idx, Xbyak::Opmask ktail_mask) {
            vmovups(fp32_inp | ktail_mask | T_z,
                    ptr[reg_inp + sizeof(float) * (idx)]);
            vaddps(fp32_inp | ktail_mask | T_z, fp32_inp,
                    ptr[reg_add + sizeof(float) * (idx)]);
            if (use_bf16_emu)
                bf16_emu_->vcvtneps2bf16(bf16_out, fp32_inp);
            else
                vcvtneps2bf16(bf16_out, fp32_inp);

            vmovdqu16(yword[reg_out + sizeof(bfloat16_t) * (idx)] | ktail_mask,
                    bf16_out);
        };

        mov(reg_inp, ptr[abi_param1 + GET_OFF(inp)]);
        mov(reg_add, ptr[abi_param1 + GET_OFF(add)]);
        mov(reg_out, ptr[abi_param1 + GET_OFF(out)]);
        mov(reg_nelems, ptr[abi_param1 + GET_OFF(nelems)]);

        if (use_bf16_emu) bf16_emu_->init_vcvtneps2bf16();

        // Full 16-lane mask for the main (non-tail) loop iterations.
        mov(reg32_tail, 0xffff);
        kmovw(ktail_mask, reg32_tail);

        constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0
        Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
        for (int i = n_unroll; i >= 0; i--) {
            const int unroll = 1 << i; // 4, 2, 1
            L(l_simd_loop[i + 1]);
            {
                cmp(reg_nelems, simd_w_ * unroll);
                jl(l_simd_loop[i], T_NEAR);
                for (int j = 0; j < simd_w_ * unroll; j += simd_w_) {
                    add_cvt(j, ktail_mask);
                }
                add(reg_inp, simd_w_ * unroll * sizeof(float));
                add(reg_add, simd_w_ * unroll * sizeof(float));
                add(reg_out, simd_w_ * unroll * sizeof(bfloat16_t));

                sub(reg_nelems, simd_w_ * unroll);
                jmp(l_simd_loop[i + 1], T_NEAR);
            }
        }
        L(l_simd_loop[0]);
        test(reg_nelems, reg_nelems);
        jz(l_simd_notail);
        // JIT of `tail_mask_ = (1 << (nelems_ % simd_w_)) - 1;`
        mov(reg32_mask, 1);
        mov(reg64_tail, reg_nelems);
        shl(reg32_mask, reg8_mask_shift);
        sub(reg32_mask, 1);
        kmovd(ktail_mask, reg32_mask);
        add_cvt(0, ktail_mask);
        L(l_simd_notail);

        postamble();
    }

    void operator()(bf16_support::jit_call_t *params) const {
        jit_generator::operator()(params);
        msan_unpoison(params->out, params->nelems * sizeof(bfloat16_t));
    }

private:
    int simd_w_;

    bf16_emulation_t *bf16_emu_;

    Xbyak::Opmask ktail_mask = k2;
    Xbyak::Zmm fp32_inp = Xbyak::Zmm(0);
    Xbyak::Zmm fp32_tmp = Xbyak::Zmm(1);

    Xbyak::Zmm one = Xbyak::Zmm(2);
    Xbyak::Zmm even = Xbyak::Zmm(3);
    Xbyak::Zmm selector = Xbyak::Zmm(4);
    Xbyak::Reg64 scratch = r15;

    Xbyak::Ymm bf16_out = Xbyak::Ymm(5);

    Xbyak::Reg64 reg_inp = rax;
    Xbyak::Reg64 reg_out = rbx;
    Xbyak::Reg64 reg_add = r11;
    Xbyak::Reg64 reg_nelems = rdx;

    Xbyak::Reg64 reg64_tail = rcx;
    Xbyak::Reg32 reg32_tail = ecx;
    Xbyak::Reg8 reg8_mask_shift = cl;
    Xbyak::Reg32 reg32_mask = r8d;
};
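
// Usage sketch (illustrative; 'src', 'acc', 'dst', and 'n' are hypothetical
// caller-side names): compute dst[i] = bf16(src[i] + acc[i]) for n elements.
//
//   jit_avx512_core_add_cvt_ps_to_bf16_t add_cvt_ker;
//   bf16_support::jit_call_t p;
//   p.inp = (void *)src; // const float *
//   p.add = (void *)acc; // const float *
//   p.out = (void *)dst; // bfloat16_t *
//   p.nelems = n;
//   add_cvt_ker(&p);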

// Implements a reorder of a part of a tensor: [s][16c] -> [S][16c][2s].
// It is needed for a fast 1x1 bf16 bwd_w jit kernel that avoids using the
// permw instruction inside.
// TODO: consider modification/replacement for outer transformation jit kernel
struct jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_bf16_reorder_s16c_to_S16c2s)

    jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t()
        : jit_generator(jit_name()), simd_w_(16), in_stride_(16) {}

    jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t(int in_stride)
        : jit_generator(jit_name()), simd_w_(16), in_stride_(in_stride) {}

    ~jit_avx512_core_bf16_reorder_s16c_to_S16c2s_t() {}

    void generate() override {
        preamble();

        mov(reg32_tail, ptr[abi_param1 + GET_OFF(mask)]);
        mov(reg_inp, ptr[abi_param1 + GET_OFF(inp)]);
        mov(reg_out, ptr[abi_param1 + GET_OFF(out)]);
        mov(reg_nelems, ptr[abi_param1 + GET_OFF(nelems)]);

        auto zmm_reg = [=](int idx) {
            assert(idx < 31); // zmm31 is reserved for the permutation table
            return Xbyak::Zmm(idx);
        };

        // The low mask covers the first input row (words 0..15), the high
        // mask the second one (words 16..31).
        kmovd(ktail_mask_lo, reg32_tail);
        kshiftld(ktail_mask_hi, ktail_mask_lo, 16);

        Xbyak::Label dst_prm_table;
        mov(reg_prm, dst_prm_table);
        vmovups(zmm_prm, ptr[reg_prm]);

        constexpr int n_unroll = 2; // unroll by powers of 2 from 2^n to 2^0
        int sizeofcacheline = 2 * simd_w_ * sizeof(bfloat16_t);
        int in_stride_bytes = in_stride_ * sizeof(bfloat16_t);
        Xbyak::Label l_simd_loop[n_unroll + 2], l_simd_notail;
        for (int i = n_unroll; i >= 0; i--) {
            const int unroll = 1 << i; // 4, 2, 1
            L(l_simd_loop[i + 1]);
            {
                cmp(reg_nelems, 2 * unroll);
                jl(l_simd_loop[i], T_NEAR);
                for (int j = 0; j < unroll; j++) {
                    auto zmm_inp = zmm_reg(j);
                    // Load two consecutive spatial rows (16 channels each)
                    // into one zmm, then interleave their words.
                    if (in_stride_ == 16)
                        vmovups(zmm_inp, zword[reg_inp + j * sizeofcacheline]);
                    else {
                        vmovdqu16(zmm_inp | ktail_mask_lo | T_z,
                                zword[reg_inp + 2 * j * in_stride_bytes]);
                        vmovdqu16(zmm_inp | ktail_mask_hi,
                                zword[reg_inp + (2 * j + 1) * in_stride_bytes
                                        - 32]);
                    }
                    vpermw(zmm_inp, zmm_prm, zmm_inp);
                    vmovups(zword[reg_out + j * sizeofcacheline], zmm_inp);
                }
                add(reg_inp,
                        unroll
                                * (in_stride_ == 16 ? sizeofcacheline
                                                    : 2 * in_stride_bytes));
                add(reg_out, unroll * sizeofcacheline);

                sub(reg_nelems, 2 * unroll);
                jmp(l_simd_loop[i + 1], T_NEAR);
            }
        }
        L(l_simd_loop[0]);

        test(reg_nelems, reg_nelems);
        jz(l_simd_notail);

        // Odd tail: only one input row is left; the second half stays zero.
        auto zmm_inp = zmm_reg(0);
        vpxord(zmm_inp, zmm_inp, zmm_inp);
        vmovdqu16(zmm_inp | ktail_mask_lo | T_z, ptr[reg_inp]);
        vpermw(zmm_inp, zmm_prm, zmm_inp);
        vmovups(zword[reg_out], zmm_inp);

        L(l_simd_notail);

        postamble();

        // Word permutation that interleaves the two 16-channel input rows:
        // {row0[0], row1[0], row0[1], row1[1], ...}.
        const uint16_t dst_prm_array[32] = {0, 16, 1, 17, 2, 18, 3, 19, 4, 20,
                5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13,
                29, 14, 30, 15, 31};

        align(64);
        L(dst_prm_table);
        for (size_t i = 0; i < 32; ++i)
            dw(dst_prm_array[i]);
    }

    void operator()(bf16_support::jit_call_t *params) const {
        jit_generator::operator()(params);
        msan_unpoison(params->out, params->nelems * sizeof(bfloat16_t));
    }

private:
    int simd_w_;
    int in_stride_;

    Xbyak::Opmask ktail_mask_lo = k2;
    Xbyak::Opmask ktail_mask_hi = k3;
    Xbyak::Zmm zmm_prm = Xbyak::Zmm(31);

    Xbyak::Reg64 reg_inp = rax;
    Xbyak::Reg64 reg_out = rbx;
    Xbyak::Reg64 reg_prm = r11;
    Xbyak::Reg64 reg_nelems = rdx;

    Xbyak::Reg32 reg32_tail = abi_not_param1.cvt32();
};
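
// Illustration of the resulting layout (a sketch): with c = 0..15 indexing
// channels and S indexing output rows, each output row packs two input rows
// pairwise interleaved,
//
//   out[S][c][0] = in[2 * S][c];
//   out[S][c][1] = in[2 * S + 1][c];
//
// which is exactly the dst_prm_table permutation {0, 16, 1, 17, ...} applied
// to the 32 bf16 words of a cache line.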

#undef GET_OFF
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif