1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/bfloat16.hpp" |
18 | #include "common/c_types_map.hpp" |
19 | #include "common/dnnl_thread.hpp" |
20 | #include "common/nstl.hpp" |
21 | #include "common/utils.hpp" |
22 | |
23 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
24 | #include "cpu/x64/jit_generator.hpp" |
25 | |
26 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
27 | #include "cpu/x64/jit_uni_eltwise.hpp" |
28 | |
29 | #define GET_OFF(field) offsetof(jit_args_t, field) |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | namespace cpu { |
34 | namespace x64 { |
35 | |
36 | using namespace Xbyak; |
37 | |
// Argument pack the driver fills in and passes to the generated kernel.
// Field offsets are taken via GET_OFF, so names/layout are part of the ABI
// between the C++ driver and the JIT-ed code.
struct jit_args_t {
    const void *src; // fwd: src; bwd: src/dst based on alg;
    const void *dst; // fwd: dst; bwd: diff_src;
    const void *diff_dst; // fwd: nullptr; bwd: diff_dst;
    size_t work_amount; // number of elements left to process
};
44 | |
45 | struct jit_uni_eltwise_kernel : public jit_generator { |
46 | jit_uni_eltwise_kernel(const eltwise_pd_t *pd, const char *name) |
47 | : jit_generator(name), pd_(pd) {} |
48 | |
49 | void operator()(jit_args_t *p) { jit_generator::operator()(p); } |
50 | |
51 | protected: |
52 | const eltwise_pd_t *pd_; |
53 | |
54 | data_type_t data_type() const { |
55 | return pd_->use_dst() ? pd_->dst_md()->data_type |
56 | : pd_->src_md()->data_type; |
57 | } |
58 | bool is_bf16() const { return data_type() == data_type::bf16; } |
59 | bool is_f16() const { return data_type() == data_type::f16; } |
60 | int dtype_size() const { return types::data_type_size(data_type()); } |
61 | }; |
62 | |
63 | // jit kernels |
64 | namespace { |
65 | |
// Helper that emits bf16 <-> f32 conversion sequences into the host
// generator. If `emu` is non-null, conversions go through the software
// emulation path (for machines without native avx512_core_bf16 support).
struct jit_bf16_injector_t {
    jit_bf16_injector_t(
            jit_generator *host, Opmask k_tail_mask, bf16_emulation_t *emu)
        : h(host), k_tail_mask_(k_tail_mask), emu_(emu) {}

    // Emits code that sets k_tail_mask_ to 0x1 (single-element tail mask).
    // r14 is spilled to the stack around the sequence so the caller's value
    // is preserved.
    void prepare_mask() {
        Reg64 reg_tmp = h->r14;
        h->sub(h->rsp, 8); // sizeof(Reg64)
        h->mov(h->ptr[h->rsp], reg_tmp);
        h->mov(reg_tmp.cvt32(), 0x1);
        h->kmovd(k_tail_mask_, reg_tmp.cvt32());
        h->mov(reg_tmp, h->ptr[h->rsp]);
        h->add(h->rsp, 8);
    }

    // Emits a load of bf16 values from [reg_src + offset] converted to f32
    // in Zmm(idx): zero-extend 16-bit words to dwords, then shift left by 16
    // so each bf16 pattern lands in the high half of an f32. With is_tail,
    // only the lanes enabled by k_tail_mask_ are written (rest zeroed).
    void load_bf16_cvt_to_f32(size_t idx, Reg64 reg_src, bool is_tail = false,
            size_t offset = 0) {
        Zmm zmm_f32 = Zmm(idx);
        zmm_f32 = is_tail ? zmm_f32 | k_tail_mask_ | Xbyak::util::T_z : zmm_f32;
        h->vpmovzxwd(zmm_f32, h->ptr[reg_src + offset]);
        h->vpslld(zmm_f32, zmm_f32, 16);
    }

    // Emits conversion of `step` zmm registers of f32 (starting at idx) to
    // bf16 and a store to [reg_dst + offset]. step == 2 packs two zmm into
    // one 64-byte store; step == 1 stores 32 bytes (or a masked tail).
    void cvt_f32_to_bf16_store(int step, size_t idx, Reg64 reg_dst,
            bool is_tail = false, size_t offset = 0) {
        assert(step >= 1 && step <= 2
                && IMPLICATION(step == 2, is_tail == false));
        if (step == 2 && !is_tail) {
            // Note: Ymm(idx) aliases the low 256 bits of Zmm(idx).
            Ymm ymm_bf16_0 = Ymm(idx);
            Ymm ymm_bf16_1 = Ymm(idx + 1);
            Zmm zmm_f32_0 = Zmm(idx);
            Zmm zmm_f32_1 = Zmm(idx + 1);
            if (emu_) {
                // Emulation converts one ymm at a time; combine the two
                // halves into a single zmm before the full-width store.
                emu_->vcvtneps2bf16(ymm_bf16_0, zmm_f32_0);
                emu_->vcvtneps2bf16(ymm_bf16_1, zmm_f32_1);
                h->vinserti64x4(zmm_f32_0, zmm_f32_0, ymm_bf16_1, 1);
                h->vmovups(h->ptr[reg_dst + offset], zmm_f32_0);
            } else {
                // Native instruction converts two zmm sources in one shot.
                h->vcvtne2ps2bf16(zmm_f32_1, zmm_f32_1, zmm_f32_0);
                h->vmovups(h->ptr[reg_dst + offset], zmm_f32_1);
            }
        } else {
            Ymm ymm_bf16 = Ymm(idx);
            Zmm zmm_f32 = Zmm(idx);
            if (emu_)
                emu_->vcvtneps2bf16(ymm_bf16, zmm_f32);
            else
                h->vcvtneps2bf16(ymm_bf16, zmm_f32);
            if (!is_tail)
                h->vmovdqu16(h->ptr[reg_dst + offset], ymm_bf16);
            else
                h->vmovdqu16(h->ptr[reg_dst + offset] | k_tail_mask_, ymm_bf16);
        }
    }

private:
    jit_generator *const h; // host generator the code is emitted into
    Xbyak::Opmask k_tail_mask_;
    bf16_emulation_t *const emu_; // nullptr when native bf16 cvt is available
};
126 | |
127 | template <cpu_isa_t isa> |
128 | struct jit_uni_kernel_t : public jit_uni_eltwise_kernel { |
129 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel) |
130 | |
131 | jit_uni_kernel_t(const eltwise_pd_t *pd) |
132 | : jit_uni_eltwise_kernel(pd, jit_name()) { |
133 | if (is_bf16()) { |
134 | if (!mayiuse(avx512_core_bf16)) |
135 | bf16_emu_.reset(new bf16_emulation_t(this, bf16_emu_reserv_1, |
136 | bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_scratch, |
137 | bf16_emu_reserv_5)); |
138 | bf16_injector_.reset(new jit_bf16_injector_t( |
139 | this, k_tail_mask, bf16_emu_.get())); |
140 | } |
141 | |
142 | const auto &desc = *pd_->desc(); |
143 | // there's no auxiliary vregs on fwd path |
144 | const bool is_fwd = pd_->is_fwd(); |
145 | const bool save_state = is_fwd ? false : true; |
146 | eltwise_injector_.reset(new jit_uni_eltwise_injector_f32<isa>(this, |
147 | desc.alg_kind, desc.alpha, desc.beta, 1.f, save_state, |
148 | reg_injector_table, injector_mask, is_fwd, pd_->use_dst())); |
149 | } |
150 | |
151 | void generate() override { |
152 | const bool is_fwd = pd_->is_fwd(); |
153 | preamble(); |
154 | |
155 | if (is_bf16()) { |
156 | bf16_injector_->prepare_mask(); |
157 | if (!mayiuse(avx512_core_bf16)) bf16_emu_->init_vcvtneps2bf16(); |
158 | } |
159 | |
160 | Reg64 param = abi_param1; |
161 | mov(reg_src, ptr[param + GET_OFF(src)]); |
162 | mov(reg_dst, ptr[param + GET_OFF(dst)]); |
163 | if (!is_fwd) mov(reg_diff_dst, ptr[param + GET_OFF(diff_dst)]); |
164 | mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]); |
165 | eltwise_injector_->load_table_addr(); |
166 | |
167 | Label reminder_loop_start, reminder_loop_end; |
168 | Label vectorized_loop_start, vectorized_loop_end; |
169 | |
170 | cmp(reg_work_amount, simd_w()); |
171 | jl(reminder_loop_start, T_NEAR); |
172 | |
173 | L(vectorized_loop_start); |
174 | |
175 | // TODO: consider improving. |
176 | // This piece of code is responsible for the preserve_zero function |
177 | // being a natural restriction of this implementation. It works with any |
178 | // dense and blocked layout, but the problem raises when blocking |
179 | // dimension is not divisible by block size. For such case, the code |
180 | // below should save the mask, where zero padding should be preserved |
181 | // and apply it on register before storing into dst memory. Until |
182 | // there's a restriction on certain blocked layouts, when this behavior |
183 | // can be relevantly easy controlled, this will cost much from code |
184 | // perspective and will complicate the compute logic significantly. |
185 | if (is_bf16()) { |
186 | bf16_injector_->load_bf16_cvt_to_f32(vmm_src.getIdx(), reg_src); |
187 | eltwise_injector_->compute_vector(vmm_src.getIdx()); |
188 | if (!is_fwd) { |
189 | bf16_injector_->load_bf16_cvt_to_f32( |
190 | vmm_diff_dst.getIdx(), reg_diff_dst); |
191 | uni_vmulps(vmm_src, vmm_src, vmm_diff_dst); |
192 | } |
193 | bf16_injector_->cvt_f32_to_bf16_store(1, vmm_src.getIdx(), reg_dst); |
194 | } else if (is_f16()) { |
195 | vcvtph2psx(vmm_src, ptr[reg_src]); |
196 | eltwise_injector_->compute_vector(vmm_src.getIdx()); |
197 | if (!is_fwd) { |
198 | vcvtph2psx(vmm_diff_dst, ptr[reg_diff_dst]); |
199 | uni_vmulps(vmm_src, vmm_src, vmm_diff_dst); |
200 | } |
201 | vcvtps2ph(ptr[reg_dst], vmm_src, _op_mxcsr); |
202 | } else { |
203 | uni_vmovups(vmm_src, ptr[reg_src]); |
204 | eltwise_injector_->compute_vector(vmm_src.getIdx()); |
205 | if (!is_fwd) { |
206 | uni_vmovups(vmm_diff_dst, ptr[reg_diff_dst]); |
207 | uni_vmulps(vmm_src, vmm_src, vmm_diff_dst); |
208 | } |
209 | uni_vmovups(ptr[reg_dst], vmm_src); |
210 | } |
211 | |
212 | const auto shift = vlen(); |
213 | add(reg_src, shift); |
214 | add(reg_dst, shift); |
215 | if (!is_fwd) add(reg_diff_dst, shift); |
216 | |
217 | sub(reg_work_amount, simd_w()); |
218 | cmp(reg_work_amount, simd_w()); |
219 | jge(vectorized_loop_start, T_NEAR); |
220 | |
221 | L(vectorized_loop_end); |
222 | |
223 | L(reminder_loop_start); |
224 | |
225 | cmp(reg_work_amount, 0); |
226 | jle(reminder_loop_end, T_NEAR); |
227 | if (is_bf16()) { |
228 | bf16_injector_->load_bf16_cvt_to_f32( |
229 | vmm_src.getIdx(), reg_src, true); |
230 | eltwise_injector_->compute_vector(vmm_src.getIdx()); |
231 | if (!is_fwd) { |
232 | bf16_injector_->load_bf16_cvt_to_f32( |
233 | vmm_diff_dst.getIdx(), reg_diff_dst, true); |
234 | uni_vmulps(vmm_src, vmm_src, vmm_diff_dst); |
235 | } |
236 | bf16_injector_->cvt_f32_to_bf16_store( |
237 | 1, vmm_src.getIdx(), reg_dst, true); |
238 | } else if (is_f16()) { |
239 | vxorps(xmm_src, xmm_src, xmm_src); |
240 | vcvtsh2ss(xmm_src, xmm_src, ptr[reg_src]); |
241 | eltwise_injector_->compute_vector(vmm_src.getIdx()); |
242 | if (!is_fwd) { |
243 | vxorps(xmm_diff_dst, xmm_diff_dst, xmm_diff_dst); |
244 | vcvtsh2ss(xmm_diff_dst, xmm_diff_dst, ptr[reg_diff_dst]); |
245 | uni_vmulps(xmm_src, xmm_src, xmm_diff_dst); |
246 | } |
247 | vcvtss2sh(xmm_src, xmm_src, xmm_src); |
248 | vmovsh(ptr[reg_dst], xmm_src); |
249 | } else { |
250 | uni_vmovss(xmm_src, ptr[reg_src]); |
251 | eltwise_injector_->compute_vector(xmm_src.getIdx()); |
252 | if (!is_fwd) { |
253 | uni_vmovss(xmm_diff_dst, ptr[reg_diff_dst]); |
254 | uni_vmulps(xmm_src, xmm_src, xmm_diff_dst); |
255 | } |
256 | uni_vmovss(ptr[reg_dst], xmm_src); |
257 | } |
258 | add(reg_src, dtype_size()); |
259 | add(reg_dst, dtype_size()); |
260 | if (!is_fwd) add(reg_diff_dst, dtype_size()); |
261 | |
262 | dec(reg_work_amount); |
263 | jmp(reminder_loop_start, T_NEAR); |
264 | |
265 | L(reminder_loop_end); |
266 | |
267 | postamble(); |
268 | |
269 | eltwise_injector_->prepare_table(); |
270 | } |
271 | |
272 | private: |
273 | using Vmm = typename cpu_isa_traits<isa>::Vmm; |
274 | |
275 | int vlen() { |
276 | int vlen = cpu_isa_traits<isa>::vlen; |
277 | return is_bf16() || is_f16() ? vlen / 2 : vlen; |
278 | } |
279 | int simd_w() { return vlen() / dtype_size(); } |
280 | |
281 | Reg64 reg_src = rax; |
282 | Reg64 reg_dst = r8; |
283 | Reg64 reg_injector_table = r9; |
284 | Reg64 reg_diff_dst = r10; |
285 | Reg64 reg_work_amount = rsi; |
286 | Reg64 imm_addr64 = rbx; |
287 | |
288 | Opmask injector_mask = Opmask(1); |
289 | |
290 | Xmm xmm_src = Xmm(1); |
291 | Vmm vmm_src = Vmm(1); |
292 | Xmm xmm_diff_dst = Xmm(2); |
293 | Vmm vmm_diff_dst = Vmm(2); |
294 | std::unique_ptr<jit_uni_eltwise_injector_f32<isa>> eltwise_injector_; |
295 | |
296 | /* bf16 support */ |
297 | Zmm bf16_emu_reserv_1 = Zmm(26); |
298 | Zmm bf16_emu_reserv_2 = Zmm(27); |
299 | Zmm bf16_emu_reserv_3 = Zmm(28); |
300 | Reg64 bf16_emu_scratch = r14; |
301 | Zmm bf16_emu_reserv_5 = Zmm(29); |
302 | |
303 | Opmask k_tail_mask = k6; |
304 | |
305 | std::unique_ptr<jit_bf16_injector_t> bf16_injector_; |
306 | std::unique_ptr<bf16_emulation_t> bf16_emu_; |
307 | }; |
308 | |
309 | } // namespace |
310 | |
311 | template <cpu_isa_t isa, data_type_t d_type> |
312 | status_t jit_uni_eltwise_fwd_t<isa, d_type>::pd_t::init(engine_t *engine) { |
313 | using namespace alg_kind; |
314 | |
315 | const memory_desc_wrapper src_d(src_md()); |
316 | |
317 | bool ok = mayiuse(isa) && is_fwd() |
318 | && utils::everyone_is( |
319 | d_type, src_md()->data_type, dst_md()->data_type) |
320 | && IMPLICATION(src_md()->data_type == data_type::bf16, |
321 | mayiuse(avx512_core)) |
322 | && IMPLICATION(src_md()->data_type == data_type::f16, |
323 | mayiuse(avx512_core_fp16)) |
324 | && !has_zero_dim_memory() && src_d.is_dense(true) |
325 | && eltwise_injector::is_supported(isa, desc_.alg_kind) |
326 | // refer to a comment in jit_uni_kernel why this is needed |
327 | && IMPLICATION(!src_d.is_dense(), is_zero_preserved()) |
328 | && attr()->has_default_values() && set_default_formats_common() |
329 | && src_d == memory_desc_wrapper(dst_md()); |
330 | return ok ? status::success : status::unimplemented; |
331 | } |
332 | |
// Trivial constructor: the kernel is built later in init().
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_fwd_t<isa, d_type>::jit_uni_eltwise_fwd_t(const pd_t *apd)
    : primitive_t(apd) {}
336 | |
// Defaulted out-of-line so kernel_'s unique_ptr deleter sees the full type.
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_fwd_t<isa, d_type>::~jit_uni_eltwise_fwd_t() = default;
339 | |
// Allocates the JIT kernel and generates its code.
template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_eltwise_fwd_t<isa, d_type>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(kernel_, new jit_uni_kernel_t<isa>(pd())));
    return kernel_->create_kernel();
}
345 | |
346 | template <cpu_isa_t isa, data_type_t d_type> |
347 | status_t jit_uni_eltwise_fwd_t<isa, d_type>::execute( |
348 | const exec_ctx_t &ctx) const { |
349 | auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
350 | auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); |
351 | |
352 | const memory_desc_wrapper data_d(pd()->src_md()); |
353 | const auto nelems = data_d.nelems(true); |
354 | const int simd_w = 64 / data_d.data_type_size(); |
355 | |
356 | src += data_d.offset0(); |
357 | dst += data_d.offset0(); |
358 | |
359 | parallel(0, [&](const int ithr, const int nthr) { |
360 | dim_t start {0}, end {0}; |
361 | |
362 | balance211(utils::div_up(nelems, simd_w), nthr, ithr, start, end); |
363 | start = nstl::min(nelems, start * simd_w); |
364 | end = nstl::min(nelems, end * simd_w); |
365 | if (start == end) return; |
366 | |
367 | jit_args_t args; |
368 | args.src = src + start; |
369 | args.dst = dst + start; |
370 | args.diff_dst = nullptr; |
371 | args.work_amount = end - start; |
372 | (*kernel_)(&args); |
373 | }); |
374 | |
375 | return status::success; |
376 | } |
377 | |
378 | template <cpu_isa_t isa, data_type_t d_type> |
379 | status_t jit_uni_eltwise_bwd_t<isa, d_type>::pd_t::init(engine_t *engine) { |
380 | using namespace alg_kind; |
381 | |
382 | const memory_desc_wrapper data_d(data_md()); |
383 | |
384 | bool ok = mayiuse(isa) && !is_fwd() |
385 | && utils::everyone_is(d_type, data_md()->data_type, |
386 | diff_src_md()->data_type, diff_dst_md()->data_type) |
387 | && IMPLICATION(data_md()->data_type == data_type::bf16, |
388 | mayiuse(avx512_core)) |
389 | && IMPLICATION(data_md()->data_type == data_type::f16, |
390 | mayiuse(avx512_core_fp16)) |
391 | && !has_zero_dim_memory() && set_default_formats_common() |
392 | && data_d.is_dense(true) && eltwise_injector::is_isa_supported(isa) |
393 | && eltwise_injector::is_alg_supported(desc_.alg_kind) |
394 | // refer to a comment in jit_uni_kernel why this is needed |
395 | && IMPLICATION(!data_d.is_dense(), is_zero_preserved()) |
396 | && data_d == memory_desc_wrapper(diff_dst_md()) |
397 | && memory_desc_wrapper(diff_src_md()) |
398 | == memory_desc_wrapper(diff_dst_md()) |
399 | && attr()->has_default_values(); |
400 | return ok ? status::success : status::unimplemented; |
401 | } |
402 | |
// Trivial constructor: the kernel is built later in init().
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_bwd_t<isa, d_type>::jit_uni_eltwise_bwd_t(const pd_t *apd)
    : primitive_t(apd) {}
406 | |
// Defaulted out-of-line so kernel_'s unique_ptr deleter sees the full type.
template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_bwd_t<isa, d_type>::~jit_uni_eltwise_bwd_t() = default;
409 | |
// Allocates the JIT kernel and generates its code.
template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_eltwise_bwd_t<isa, d_type>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(kernel_, new jit_uni_kernel_t<isa>(pd())));
    return kernel_->create_kernel();
}
415 | |
416 | template <cpu_isa_t isa, data_type_t d_type> |
417 | status_t jit_uni_eltwise_bwd_t<isa, d_type>::execute( |
418 | const exec_ctx_t &ctx) const { |
419 | auto src = pd()->use_dst() ? CTX_IN_MEM(const data_t *, DNNL_ARG_DST) |
420 | : CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
421 | auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST); |
422 | auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC); |
423 | |
424 | const memory_desc_wrapper data_d(pd()->data_md()); |
425 | const memory_desc_wrapper diff_data_d(pd()->diff_src_md()); |
426 | const auto nelems = data_d.nelems(true); |
427 | const int simd_w = 64 / data_d.data_type_size(); |
428 | |
429 | src += data_d.offset0(); |
430 | diff_dst += diff_data_d.offset0(); |
431 | diff_src += diff_data_d.offset0(); |
432 | |
433 | parallel(0, [&](const int ithr, const int nthr) { |
434 | dim_t start {0}, end {0}; |
435 | |
436 | balance211(utils::div_up(nelems, simd_w), nthr, ithr, start, end); |
437 | start = nstl::min(nelems, start * simd_w); |
438 | end = nstl::min(nelems, end * simd_w); |
439 | if (start == end) return; |
440 | |
441 | jit_args_t args; |
442 | args.src = src + start; |
443 | args.dst = diff_src + start; |
444 | args.diff_dst = diff_dst + start; |
445 | args.work_amount = end - start; |
446 | (*kernel_)(&args); |
447 | }); |
448 | |
449 | return status::success; |
450 | } |
451 | |
// Explicit instantiations for every supported (isa, data type) pair.
template struct jit_uni_eltwise_fwd_t<sse41, data_type::f32>;
template struct jit_uni_eltwise_fwd_t<avx, data_type::f32>;
template struct jit_uni_eltwise_fwd_t<avx2, data_type::f32>;
template struct jit_uni_eltwise_fwd_t<avx512_core, data_type::f32>;
template struct jit_uni_eltwise_fwd_t<avx512_core, data_type::bf16>;
template struct jit_uni_eltwise_fwd_t<avx512_core_fp16, data_type::f16>;

template struct jit_uni_eltwise_bwd_t<sse41, data_type::f32>;
template struct jit_uni_eltwise_bwd_t<avx, data_type::f32>;
template struct jit_uni_eltwise_bwd_t<avx2, data_type::f32>;
template struct jit_uni_eltwise_bwd_t<avx512_core, data_type::f32>;
template struct jit_uni_eltwise_bwd_t<avx512_core, data_type::bf16>;
template struct jit_uni_eltwise_bwd_t<avx512_core_fp16, data_type::f16>;
465 | |
466 | } // namespace x64 |
467 | } // namespace cpu |
468 | } // namespace impl |
469 | } // namespace dnnl |
470 | |