/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/bfloat16.hpp"
#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/nstl.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_uni_eltwise.hpp"

#define GET_OFF(field) offsetof(jit_args_t, field)

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;

struct jit_args_t {
    const void *src; // fwd: src; bwd: src/dst based on alg;
    const void *dst; // fwd: dst; bwd: diff_src;
    const void *diff_dst; // fwd: nullptr; bwd: diff_dst;
    size_t work_amount;
};
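
// The same argument block serves both directions; `dst` always names the
// kernel output. An illustrative sketch of how the execute() functions
// below fill it (mirrors the fwd and bwd paths at the bottom of this file):
//
//     jit_args_t args;
//     args.src = src + start; // bwd: src or dst, depending on use_dst()
//     args.dst = dst + start; // bwd: diff_src
//     args.diff_dst = is_fwd ? nullptr : diff_dst + start;
//     args.work_amount = end - start;
//     (*kernel_)(&args);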

struct jit_uni_eltwise_kernel : public jit_generator {
    jit_uni_eltwise_kernel(const eltwise_pd_t *pd, const char *name)
        : jit_generator(name), pd_(pd) {}

    void operator()(jit_args_t *p) { jit_generator::operator()(p); }

protected:
    const eltwise_pd_t *pd_;

    data_type_t data_type() const {
        return pd_->use_dst() ? pd_->dst_md()->data_type
                              : pd_->src_md()->data_type;
    }
    bool is_bf16() const { return data_type() == data_type::bf16; }
    bool is_f16() const { return data_type() == data_type::f16; }
    int dtype_size() const { return types::data_type_size(data_type()); }
};

// jit kernels
namespace {

struct jit_bf16_injector_t {
    jit_bf16_injector_t(
            jit_generator *host, Opmask k_tail_mask, bf16_emulation_t *emu)
        : h(host), k_tail_mask_(k_tail_mask), emu_(emu) {}

    void prepare_mask() {
        // set a single-lane opmask; the remainder loop uses it to process
        // bf16 elements one at a time
        Reg64 reg_tmp = h->r14;
        h->sub(h->rsp, 8); // sizeof(Reg64)
        h->mov(h->ptr[h->rsp], reg_tmp);
        h->mov(reg_tmp.cvt32(), 0x1);
        h->kmovd(k_tail_mask_, reg_tmp.cvt32());
        h->mov(reg_tmp, h->ptr[h->rsp]);
        h->add(h->rsp, 8);
    }

    void load_bf16_cvt_to_f32(size_t idx, Reg64 reg_src, bool is_tail = false,
            size_t offset = 0) {
        Zmm zmm_f32 = Zmm(idx);
        zmm_f32 = is_tail ? zmm_f32 | k_tail_mask_ | Xbyak::util::T_z : zmm_f32;
        // zero-extend each 16-bit bf16 value to 32 bits, then shift it into
        // the upper half: a bf16 number is the top 16 bits of an f32
        h->vpmovzxwd(zmm_f32, h->ptr[reg_src + offset]);
        h->vpslld(zmm_f32, zmm_f32, 16);
    }
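
    // For reference, the upconvert above is the JIT analogue of this scalar
    // sketch (a hypothetical helper, not part of the kernel):
    //
    //     float bf16_to_f32(uint16_t b) {
    //         uint32_t u = uint32_t(b) << 16;
    //         float f;
    //         std::memcpy(&f, &u, sizeof(f));
    //         return f;
    //     }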

    void cvt_f32_to_bf16_store(int step, size_t idx, Reg64 reg_dst,
            bool is_tail = false, size_t offset = 0) {
        assert(step >= 1 && step <= 2
                && IMPLICATION(step == 2, is_tail == false));
        if (step == 2 && !is_tail) {
            Ymm ymm_bf16_0 = Ymm(idx);
            Ymm ymm_bf16_1 = Ymm(idx + 1);
            Zmm zmm_f32_0 = Zmm(idx);
            Zmm zmm_f32_1 = Zmm(idx + 1);
            if (emu_) {
                emu_->vcvtneps2bf16(ymm_bf16_0, zmm_f32_0);
                emu_->vcvtneps2bf16(ymm_bf16_1, zmm_f32_1);
                // pack both converted halves into one zmm for a single store
                h->vinserti64x4(zmm_f32_0, zmm_f32_0, ymm_bf16_1, 1);
                h->vmovups(h->ptr[reg_dst + offset], zmm_f32_0);
            } else {
                h->vcvtne2ps2bf16(zmm_f32_1, zmm_f32_1, zmm_f32_0);
                h->vmovups(h->ptr[reg_dst + offset], zmm_f32_1);
            }
        } else {
            Ymm ymm_bf16 = Ymm(idx);
            Zmm zmm_f32 = Zmm(idx);
            if (emu_)
                emu_->vcvtneps2bf16(ymm_bf16, zmm_f32);
            else
                h->vcvtneps2bf16(ymm_bf16, zmm_f32);
            if (!is_tail)
                h->vmovdqu16(h->ptr[reg_dst + offset], ymm_bf16);
            else
                h->vmovdqu16(h->ptr[reg_dst + offset] | k_tail_mask_, ymm_bf16);
        }
    }
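
    // The native vcvtneps2bf16 rounds to nearest-even. A scalar sketch of
    // that rounding (a hypothetical helper; NaN and denormal handling
    // omitted for brevity):
    //
    //     uint16_t f32_to_bf16_rne(float f) {
    //         uint32_t u;
    //         std::memcpy(&u, &f, sizeof(u));
    //         u += 0x7fff + ((u >> 16) & 1); // ties round to even
    //         return uint16_t(u >> 16);
    //     }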

private:
    jit_generator *const h;
    Xbyak::Opmask k_tail_mask_;
    bf16_emulation_t *const emu_;
};

template <cpu_isa_t isa>
struct jit_uni_kernel_t : public jit_uni_eltwise_kernel {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel)

    jit_uni_kernel_t(const eltwise_pd_t *pd)
        : jit_uni_eltwise_kernel(pd, jit_name()) {
        if (is_bf16()) {
            if (!mayiuse(avx512_core_bf16))
                bf16_emu_.reset(new bf16_emulation_t(this, bf16_emu_reserv_1,
                        bf16_emu_reserv_2, bf16_emu_reserv_3, bf16_emu_scratch,
                        bf16_emu_reserv_5));
            bf16_injector_.reset(new jit_bf16_injector_t(
                    this, k_tail_mask, bf16_emu_.get()));
        }

        const auto &desc = *pd_->desc();
        // there are no auxiliary vregs to save on the fwd path
        const bool is_fwd = pd_->is_fwd();
        const bool save_state = !is_fwd;
        eltwise_injector_.reset(new jit_uni_eltwise_injector_f32<isa>(this,
                desc.alg_kind, desc.alpha, desc.beta, 1.f, save_state,
                reg_injector_table, injector_mask, is_fwd, pd_->use_dst()));
    }

    void generate() override {
        const bool is_fwd = pd_->is_fwd();
        preamble();

        if (is_bf16()) {
            bf16_injector_->prepare_mask();
            if (!mayiuse(avx512_core_bf16)) bf16_emu_->init_vcvtneps2bf16();
        }

        Reg64 param = abi_param1;
        mov(reg_src, ptr[param + GET_OFF(src)]);
        mov(reg_dst, ptr[param + GET_OFF(dst)]);
        if (!is_fwd) mov(reg_diff_dst, ptr[param + GET_OFF(diff_dst)]);
        mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
        eltwise_injector_->load_table_addr();

        Label remainder_loop_start, remainder_loop_end;
        Label vectorized_loop_start, vectorized_loop_end;

        cmp(reg_work_amount, simd_w());
        jl(remainder_loop_start, T_NEAR);

        L(vectorized_loop_start);

        // TODO: consider improving.
        // This piece of code makes the preserve_zero property a natural
        // restriction of this implementation. It works with any dense and
        // blocked layout, but a problem arises when the blocked dimension
        // is not divisible by the block size. For such a case, the code
        // below would have to save a mask marking the lanes where zero
        // padding must be preserved and apply it to the register before
        // storing into dst memory. Until there is a restriction to those
        // blocked layouts where this behavior can be controlled relatively
        // easily, the fix would cost a lot of code and complicate the
        // compute logic significantly.
        if (is_bf16()) {
            bf16_injector_->load_bf16_cvt_to_f32(vmm_src.getIdx(), reg_src);
            eltwise_injector_->compute_vector(vmm_src.getIdx());
            if (!is_fwd) {
                bf16_injector_->load_bf16_cvt_to_f32(
                        vmm_diff_dst.getIdx(), reg_diff_dst);
                uni_vmulps(vmm_src, vmm_src, vmm_diff_dst);
            }
            bf16_injector_->cvt_f32_to_bf16_store(1, vmm_src.getIdx(), reg_dst);
        } else if (is_f16()) {
            vcvtph2psx(vmm_src, ptr[reg_src]);
            eltwise_injector_->compute_vector(vmm_src.getIdx());
            if (!is_fwd) {
                vcvtph2psx(vmm_diff_dst, ptr[reg_diff_dst]);
                uni_vmulps(vmm_src, vmm_src, vmm_diff_dst);
            }
            vcvtps2ph(ptr[reg_dst], vmm_src, _op_mxcsr);
        } else {
            uni_vmovups(vmm_src, ptr[reg_src]);
            eltwise_injector_->compute_vector(vmm_src.getIdx());
            if (!is_fwd) {
                uni_vmovups(vmm_diff_dst, ptr[reg_diff_dst]);
                uni_vmulps(vmm_src, vmm_src, vmm_diff_dst);
            }
            uni_vmovups(ptr[reg_dst], vmm_src);
        }

        const auto shift = vlen();
        add(reg_src, shift);
        add(reg_dst, shift);
        if (!is_fwd) add(reg_diff_dst, shift);

        sub(reg_work_amount, simd_w());
        cmp(reg_work_amount, simd_w());
        jge(vectorized_loop_start, T_NEAR);

        L(vectorized_loop_end);

        L(remainder_loop_start);

        cmp(reg_work_amount, 0);
        jle(remainder_loop_end, T_NEAR);
        if (is_bf16()) {
            bf16_injector_->load_bf16_cvt_to_f32(
                    vmm_src.getIdx(), reg_src, true);
            eltwise_injector_->compute_vector(vmm_src.getIdx());
            if (!is_fwd) {
                bf16_injector_->load_bf16_cvt_to_f32(
                        vmm_diff_dst.getIdx(), reg_diff_dst, true);
                uni_vmulps(vmm_src, vmm_src, vmm_diff_dst);
            }
            bf16_injector_->cvt_f32_to_bf16_store(
                    1, vmm_src.getIdx(), reg_dst, true);
        } else if (is_f16()) {
            vxorps(xmm_src, xmm_src, xmm_src);
            vcvtsh2ss(xmm_src, xmm_src, ptr[reg_src]);
            eltwise_injector_->compute_vector(vmm_src.getIdx());
            if (!is_fwd) {
                vxorps(xmm_diff_dst, xmm_diff_dst, xmm_diff_dst);
                vcvtsh2ss(xmm_diff_dst, xmm_diff_dst, ptr[reg_diff_dst]);
                uni_vmulps(xmm_src, xmm_src, xmm_diff_dst);
            }
            vcvtss2sh(xmm_src, xmm_src, xmm_src);
            vmovsh(ptr[reg_dst], xmm_src);
        } else {
            uni_vmovss(xmm_src, ptr[reg_src]);
            eltwise_injector_->compute_vector(xmm_src.getIdx());
            if (!is_fwd) {
                uni_vmovss(xmm_diff_dst, ptr[reg_diff_dst]);
                uni_vmulps(xmm_src, xmm_src, xmm_diff_dst);
            }
            uni_vmovss(ptr[reg_dst], xmm_src);
        }
        add(reg_src, dtype_size());
        add(reg_dst, dtype_size());
        if (!is_fwd) add(reg_diff_dst, dtype_size());

        dec(reg_work_amount);
        jmp(remainder_loop_start, T_NEAR);

        L(remainder_loop_end);

        postamble();

        eltwise_injector_->prepare_table();
    }

private:
    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    int vlen() {
        int vlen = cpu_isa_traits<isa>::vlen;
        // 16-bit types are upconverted to f32 in full-width vregs, so each
        // iteration consumes only half a vector length of memory
        return is_bf16() || is_f16() ? vlen / 2 : vlen;
    }
    int simd_w() { return vlen() / dtype_size(); }
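
    // Worked numbers, assuming avx512_core (cpu_isa_traits vlen = 64 bytes):
    // f32:  vlen() = 64, simd_w() = 64 / 4 = 16 elements per iteration;
    // bf16: vlen() = 32, simd_w() = 32 / 2 = 16 elements, i.e. one ymm worth
    // of bf16 in memory expands into one zmm of f32 in registers.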

    Reg64 reg_src = rax;
    Reg64 reg_dst = r8;
    Reg64 reg_injector_table = r9;
    Reg64 reg_diff_dst = r10;
    Reg64 reg_work_amount = rsi;
    Reg64 imm_addr64 = rbx;

    Opmask injector_mask = Opmask(1);

    Xmm xmm_src = Xmm(1);
    Vmm vmm_src = Vmm(1);
    Xmm xmm_diff_dst = Xmm(2);
    Vmm vmm_diff_dst = Vmm(2);
    std::unique_ptr<jit_uni_eltwise_injector_f32<isa>> eltwise_injector_;

    /* bf16 support */
    Zmm bf16_emu_reserv_1 = Zmm(26);
    Zmm bf16_emu_reserv_2 = Zmm(27);
    Zmm bf16_emu_reserv_3 = Zmm(28);
    Reg64 bf16_emu_scratch = r14;
    Zmm bf16_emu_reserv_5 = Zmm(29);

    Opmask k_tail_mask = k6;

    std::unique_ptr<jit_bf16_injector_t> bf16_injector_;
    std::unique_ptr<bf16_emulation_t> bf16_emu_;
};

} // namespace

template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_eltwise_fwd_t<isa, d_type>::pd_t::init(engine_t *engine) {
    using namespace alg_kind;

    const memory_desc_wrapper src_d(src_md());

    bool ok = mayiuse(isa) && is_fwd()
            && utils::everyone_is(
                    d_type, src_md()->data_type, dst_md()->data_type)
            && IMPLICATION(src_md()->data_type == data_type::bf16,
                    mayiuse(avx512_core))
            && IMPLICATION(src_md()->data_type == data_type::f16,
                    mayiuse(avx512_core_fp16))
            && !has_zero_dim_memory() && src_d.is_dense(true)
            && eltwise_injector::is_supported(isa, desc_.alg_kind)
            // refer to a comment in jit_uni_kernel why this is needed
            && IMPLICATION(!src_d.is_dense(), is_zero_preserved())
            && attr()->has_default_values() && set_default_formats_common()
            && src_d == memory_desc_wrapper(dst_md());
    return ok ? status::success : status::unimplemented;
}

template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_fwd_t<isa, d_type>::jit_uni_eltwise_fwd_t(const pd_t *apd)
    : primitive_t(apd) {}

template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_fwd_t<isa, d_type>::~jit_uni_eltwise_fwd_t() = default;

template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_eltwise_fwd_t<isa, d_type>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(kernel_, new jit_uni_kernel_t<isa>(pd())));
    return kernel_->create_kernel();
}

template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_eltwise_fwd_t<isa, d_type>::execute(
        const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);

    const memory_desc_wrapper data_d(pd()->src_md());
    const auto nelems = data_d.nelems(true);
    const int simd_w = 64 / data_d.data_type_size();

    src += data_d.offset0();
    dst += data_d.offset0();

    parallel(0, [&](const int ithr, const int nthr) {
        dim_t start {0}, end {0};

        balance211(utils::div_up(nelems, simd_w), nthr, ithr, start, end);
        start = nstl::min(nelems, start * simd_w);
        end = nstl::min(nelems, end * simd_w);
        if (start == end) return;

        jit_args_t args;
        args.src = src + start;
        args.dst = dst + start;
        args.diff_dst = nullptr;
        args.work_amount = end - start;
        (*kernel_)(&args);
    });

    return status::success;
}
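
// Worked partitioning example (hypothetical numbers): nelems = 100 f32
// elements, simd_w = 64 / 4 = 16, nthr = 4. balance211 splits
// div_up(100, 16) = 7 blocks as 2/2/2/1 per thread, so the element ranges
// become [0, 32), [32, 64), [64, 96) and, after the nstl::min clamp,
// [96, 100). The last thread passes work_amount = 4 to the kernel, which
// handles it entirely in the remainder loop.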

template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_eltwise_bwd_t<isa, d_type>::pd_t::init(engine_t *engine) {
    using namespace alg_kind;

    const memory_desc_wrapper data_d(data_md());

    bool ok = mayiuse(isa) && !is_fwd()
            && utils::everyone_is(d_type, data_md()->data_type,
                    diff_src_md()->data_type, diff_dst_md()->data_type)
            && IMPLICATION(data_md()->data_type == data_type::bf16,
                    mayiuse(avx512_core))
            && IMPLICATION(data_md()->data_type == data_type::f16,
                    mayiuse(avx512_core_fp16))
            && !has_zero_dim_memory() && set_default_formats_common()
            && data_d.is_dense(true) && eltwise_injector::is_isa_supported(isa)
            && eltwise_injector::is_alg_supported(desc_.alg_kind)
            // refer to a comment in jit_uni_kernel why this is needed
            && IMPLICATION(!data_d.is_dense(), is_zero_preserved())
            && data_d == memory_desc_wrapper(diff_dst_md())
            && memory_desc_wrapper(diff_src_md())
                    == memory_desc_wrapper(diff_dst_md())
            && attr()->has_default_values();
    return ok ? status::success : status::unimplemented;
}

template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_bwd_t<isa, d_type>::jit_uni_eltwise_bwd_t(const pd_t *apd)
    : primitive_t(apd) {}

template <cpu_isa_t isa, data_type_t d_type>
jit_uni_eltwise_bwd_t<isa, d_type>::~jit_uni_eltwise_bwd_t() = default;

template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_eltwise_bwd_t<isa, d_type>::init(engine_t *engine) {
    CHECK(safe_ptr_assign(kernel_, new jit_uni_kernel_t<isa>(pd())));
    return kernel_->create_kernel();
}

template <cpu_isa_t isa, data_type_t d_type>
status_t jit_uni_eltwise_bwd_t<isa, d_type>::execute(
        const exec_ctx_t &ctx) const {
    auto src = pd()->use_dst() ? CTX_IN_MEM(const data_t *, DNNL_ARG_DST)
                               : CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
    auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
    auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper data_d(pd()->data_md());
    const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
    const auto nelems = data_d.nelems(true);
    const int simd_w = 64 / data_d.data_type_size();

    src += data_d.offset0();
    diff_dst += diff_data_d.offset0();
    diff_src += diff_data_d.offset0();

    parallel(0, [&](const int ithr, const int nthr) {
        dim_t start {0}, end {0};

        balance211(utils::div_up(nelems, simd_w), nthr, ithr, start, end);
        start = nstl::min(nelems, start * simd_w);
        end = nstl::min(nelems, end * simd_w);
        if (start == end) return;

        jit_args_t args;
        args.src = src + start;
        args.dst = diff_src + start; // the kernel writes diff_src via dst
        args.diff_dst = diff_dst + start;
        args.work_amount = end - start;
        (*kernel_)(&args);
    });

    return status::success;
}
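
// A note on the bwd math, matching the kernel body above: on the bwd path
// the injector evaluates the derivative of the eltwise op, so the kernel
// effectively computes diff_src[i] = f'(src[i] or dst[i]) * diff_dst[i],
// with src vs dst selected by use_dst().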

template struct jit_uni_eltwise_fwd_t<sse41, data_type::f32>;
template struct jit_uni_eltwise_fwd_t<avx, data_type::f32>;
template struct jit_uni_eltwise_fwd_t<avx2, data_type::f32>;
template struct jit_uni_eltwise_fwd_t<avx512_core, data_type::f32>;
template struct jit_uni_eltwise_fwd_t<avx512_core, data_type::bf16>;
template struct jit_uni_eltwise_fwd_t<avx512_core_fp16, data_type::f16>;

template struct jit_uni_eltwise_bwd_t<sse41, data_type::f32>;
template struct jit_uni_eltwise_bwd_t<avx, data_type::f32>;
template struct jit_uni_eltwise_bwd_t<avx2, data_type::f32>;
template struct jit_uni_eltwise_bwd_t<avx512_core, data_type::f32>;
template struct jit_uni_eltwise_bwd_t<avx512_core, data_type::bf16>;
template struct jit_uni_eltwise_bwd_t<avx512_core_fp16, data_type::f16>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl