/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_primitive.hpp"

#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
#include "cpu/x64/jit_uni_softmax.hpp"
#include "cpu/x64/utils/jit_io_helper.hpp"

#if defined(__INTEL_COMPILER) && (__INTEL_COMPILER < 1900)
// Intel Compilers 17.x and 18.x do not like that diff_src_ptr() is only used
// in a single descendant class and mark it as unused. This breaks builds
// with DNNL_WERROR=on. Disabling the warning for this file seems to be less
// ugly than all the fixes that I came up with.
#pragma warning disable : 177
#endif

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;

template <cpu_isa_t isa>
struct jit_softmax_t : public jit_generator {
    struct call_params_t {
        // keep all sizes at 8 bytes -- jit code expects this
        const void *src, *dst, *diff_dst; // src doubles as diff_src
        const void *interim; // scratch memory for intermediate storage
        const void *src_scales; // src_scales defined for all data type cases
        const void *dst_scales; // dst_scales defined for all data type cases
        size_t process_n_elems;
    };
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_softmax_t)

    // cpu specific part
    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    const AddressFrame &vmmword
            = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;
    const int vlen = cpu_isa_traits<isa>::vlen;

    const softmax_pd_t *pd_;
    const memory_desc_wrapper src_d_, dst_d_, diff_dst_d_;
    io::jit_io_multi_dt_helper_t<Vmm> io_;

    std::unique_ptr<jit_uni_eltwise_injector_f32<isa>> exp_injector_;
    std::unique_ptr<jit_uni_eltwise_injector_f32<isa>> log_injector_;

    Reg64 reg_param = abi_param1;

    Reg64 reg_exp_injector_table = rax;
    Reg64 reg_log_injector_table = rbx;
    Reg64 reg_src = r8;
    Reg64 reg_diff_src = reg_src;
    Reg64 reg_dst = r9;
    Reg64 reg_diff_dst = r14;
    Reg64 reg_src_spat_offt = r10;
    Reg64 reg_process_n_elems = r11;
    Reg64 reg_reverse_n_elems = r12;
    Reg64 reg_tmp = r13;
    Reg64 reg_dst_spat_offt = r15;
    Reg64 reg_diff_dst_spat_offt = reg_log_injector_table;
    Reg64 reg_interim = reg_diff_dst;
    Reg64 reg_interim_spat_offt = abi_not_param1;
    Reg64 reg_src_scales = rsi;
    Reg64 reg_dst_scales = rdx;

    Opmask injector_mask = Opmask(1);

    Vmm vtmp; // assigned at the places where it is used
    Vmm tail_vmask = Vmm(0);
    Xmm xneg_flt_max = Xmm(12);
    Vmm vneg_flt_max = Vmm(isa == avx512_core ? 28 : 12);
    Xmm xone = Xmm(13);
    Vmm vone = Vmm(isa == avx512_core ? 29 : 13);
    Vmm vsum = Vmm(isa == avx512_core ? 30 : 14);
    Vmm vmax = Vmm(isa == avx512_core ? 31 : 15);
    Vmm vsbr = vsum; // must not be equal to vmax
    Vmm vzero = Vmm(isa == avx512_core ? 21 : 11);
    Vmm vcvt_vmm = Vmm(isa == avx512_core ? 22 : 10);
    Vmm vsaturation_ubound = vneg_flt_max;

    bool is_bf16_ = false;
    bool is_f16_ = false;
    bool is_avx2_ne_xf16_ = false;
    bool is_softmax_ = pd_->is_softmax();
    bool is_logsoftmax_ = pd_->is_logsoftmax();
    bool axis_is_blocked_;
    bool need_scratchpad_;

    size_t simd_w_ = 0;
    size_t unroll_regs_ = 4;

    size_t axis_simd_full_;
    size_t axis_simd_tail_;
    size_t n_loops_;
    size_t loop_tail_;
    size_t process_n_elems_;
    size_t src_axis_stride_;
    size_t interim_axis_stride_;
    size_t dst_axis_stride_;
    size_t diff_dst_axis_stride_;

    const int bf16_emu_zmm_1_idx_ = 23;
    const int bf16_emu_zmm_2_idx_ = 24;
    const int bf16_emu_zmm_3_idx_ = 25;
    const int bf16_emu_zmm_4_idx_ = 26;
    const int tail_opmask_idx_ = 2;

    Opmask tail_opmask = Opmask(tail_opmask_idx_);

    void operator()(const call_params_t *p) {
        return jit_generator::operator()(p);
    }

    cpu_isa_t get_io_isa() {
        // reusing avx512_core instantiation for xf16 on AVX512_CORE+
        // reusing avx2 instantiation for xf16 on AVX2_VNNI_2
        const bool is_reuse_avx512_core = isa == avx512_core
                && (mayiuse(avx512_core_bf16) || mayiuse(avx512_core_fp16));
        const bool is_reuse_avx2 = isa == avx2 && mayiuse(avx2_vnni_2);
        if (is_bf16_ || is_f16_) {
            return is_reuse_avx512_core
                    ? is_f16_ ? avx512_core_fp16 : avx512_core_bf16
                    : is_reuse_avx2 ? avx2_vnni_2 : isa;
        } else
            return isa;
    }

    bool is_data_type_xf16(data_type_t dt) {
        return utils::one_of(dt, data_type::bf16, data_type::f16);
    }

    void compute_predefined_variables() {
        n_loops_ = axis_simd_full_ / unroll_regs_;
        loop_tail_ = axis_simd_full_ - n_loops_ * unroll_regs_;
        process_n_elems_ = compute_process_n_elems(dst_d_);
        src_axis_stride_ = compute_axis_stride(src_d_);
        interim_axis_stride_ = simd_w_ * sizeof(float);
        dst_axis_stride_ = compute_axis_stride(dst_d_);
        if (!pd_->is_fwd())
            diff_dst_axis_stride_ = compute_axis_stride(diff_dst_d_);
        axis_is_blocked_ = pd_->axis_size(true) != pd_->axis_size();
    }

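    // Bookkeeping unit for one step along the axis: simd_w_ elements for
    // plain layouts; for blocked layouts (inner_nblks != 0) a step jumps to
    // the next inner block, i.e. strides[axis] elements, and the total count
    // passed in call_params_t is scaled accordingly by the driver.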
    size_t compute_process_n_elems(const memory_desc_wrapper &mdw) {
        const auto &bd = mdw.blocking_desc();
        if (bd.inner_nblks) return bd.strides[pd_->axis()];
        return simd_w_;
    }

    size_t compute_axis_stride(const memory_desc_wrapper &mdw) {
        return compute_process_n_elems(mdw) * mdw.data_type_size();
    }

    void load_common_params() {
        mov(reg_tmp, float2int(1.0f));
        uni_vmovq(xone, reg_tmp);
        uni_vbroadcastss(vone, xone);
        mov(reg_tmp, float2int(-FLT_MAX));
        uni_vmovq(xneg_flt_max, reg_tmp);
        uni_vbroadcastss(vneg_flt_max, xneg_flt_max);

#define PARAM_OFF(x) offsetof(call_params_t, x)
        mov(reg_process_n_elems, ptr[reg_param + PARAM_OFF(process_n_elems)]);
        mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
        if (pd_->is_fwd())
            mov(reg_src, ptr[reg_param + PARAM_OFF(src)]);
        else {
            mov(reg_diff_src, ptr[reg_param + PARAM_OFF(src)]); // src is reused
            mov(reg_diff_dst, ptr[reg_param + PARAM_OFF(diff_dst)]);
        }
        if (need_scratchpad_) {
            mov(reg_interim, ptr[reg_param + PARAM_OFF(interim)]);
        }
        mov(reg_src_scales, ptr[reg_param + PARAM_OFF(src_scales)]);
        mov(reg_dst_scales, ptr[reg_param + PARAM_OFF(dst_scales)]);
#undef PARAM_OFF
    }

    Address diff_src_ptr(size_t offt = 0) {
        return vmmword[reg_diff_src + reg_src_spat_offt + offt];
    }

    Address src_ptr(size_t offt = 0) {
        return vmmword[reg_src + reg_src_spat_offt + offt];
    }

    Address interim_ptr(size_t offt = 0) {
        return vmmword[reg_interim + reg_interim_spat_offt + offt];
    }

    Address dst_ptr(size_t offt = 0) {
        return vmmword[reg_dst + reg_dst_spat_offt + offt];
    }

    Address diff_dst_ptr(size_t offt = 0) {
        return vmmword[reg_diff_dst + reg_diff_dst_spat_offt + offt];
    }

    enum class op_t : unsigned { max, sum };

    void perform_op(Vmm v, Vmm vtmp, op_t op) {
        if (op == op_t::max)
            uni_vmaxps(v, v, vtmp);
        else if (op == op_t::sum)
            uni_vaddps(v, v, vtmp);
    }

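    // Horizontal reduction: collapses all lanes of vsrc to a single value,
    // broadcast to every lane, by repeatedly swapping halves into vtmp and
    // combining with perform_op: 256-bit halves first on zmm, 128-bit lanes
    // on ymm, then 64- and 32-bit swaps within a lane.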
    void get_horizontal_op(const Vmm &vsrc, const Vmm &vtmp, op_t op) {
        const Zmm &zsrc = Zmm(vsrc.getIdx());
        const Zmm &ztmp = Zmm(vtmp.getIdx());
        const Ymm &ysrc = Ymm(vsrc.getIdx());
        const Ymm &ytmp = Ymm(vtmp.getIdx());

        if (is_superset(isa, avx512_core)) {
            vshuff32x4(ztmp, zsrc, zsrc, 0x4E); // 256-bit shuffle
            perform_op(vsrc, vtmp, op);
            vshuff32x4(ztmp, zsrc, zsrc, 0xB1); // 128/256-bit shuffle
            perform_op(vsrc, vtmp, op);
        } else if (is_superset(isa, avx2)) {
            vperm2f128(ytmp, ysrc, ysrc, 0x1); // 128/256-bit shuffle
            perform_op(vsrc, vtmp, op);
        }
        uni_vshufps(vtmp, vsrc, vsrc, 0x4E); // 64/128-bit shuffle
        perform_op(vsrc, vtmp, op);
        uni_vshufps(vtmp, vsrc, vsrc, 0xB1); // 32/64-bit shuffle
        perform_op(vsrc, vtmp, op);
    }

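    // Drives one pass over the axis: the main loop calls body() on
    // unroll_regs_ full simd words per iteration, the tail loop handles the
    // remaining full words, and the final step processes the simd tail under
    // a mask.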
    template <typename body_t>
    void axis_loop(body_t body) {
        Label main_loop, tail_loop, tail_axis;

        // reg_reverse_n_elems counts down to dispatch between the labels
        mov(reg_reverse_n_elems, reg_process_n_elems);
        xor_(reg_src_spat_offt, reg_src_spat_offt); // src/diff_src addr
        xor_(reg_dst_spat_offt, reg_dst_spat_offt); // dst addr
        if (need_scratchpad_)
            xor_(reg_interim_spat_offt, reg_interim_spat_offt); // scratch addr
        if (!pd_->is_fwd())
            xor_(reg_diff_dst_spat_offt, reg_diff_dst_spat_offt); // d_dst addr
        L(main_loop);
        {
            if (n_loops_) {
                cmp(reg_reverse_n_elems, unroll_regs_ * process_n_elems_);
                jl(tail_loop, T_NEAR);

                body(unroll_regs_, false);
                sub(reg_reverse_n_elems, unroll_regs_ * process_n_elems_);
                add(reg_src_spat_offt, unroll_regs_ * src_axis_stride_);
                add(reg_dst_spat_offt, unroll_regs_ * dst_axis_stride_);
                if (need_scratchpad_)
                    add(reg_interim_spat_offt,
                            unroll_regs_ * interim_axis_stride_);
                if (!pd_->is_fwd())
                    add(reg_diff_dst_spat_offt,
                            unroll_regs_ * diff_dst_axis_stride_);
                jmp(main_loop);
            }
        }

        L(tail_loop);
        {
            if (loop_tail_) {
                body(loop_tail_, false);
                add(reg_src_spat_offt, loop_tail_ * src_axis_stride_);
                add(reg_dst_spat_offt, loop_tail_ * dst_axis_stride_);
                if (need_scratchpad_)
                    add(reg_interim_spat_offt,
                            loop_tail_ * interim_axis_stride_);
                if (!pd_->is_fwd())
                    add(reg_diff_dst_spat_offt,
                            loop_tail_ * diff_dst_axis_stride_);
            }
        }

        L(tail_axis);
        {
            if (axis_simd_tail_) { body(1, true); }
        }
    }

    void uni_vaddps_maybe_tail(
            const Vmm &v1, const Vmm &v2, const Vmm &vtmp, const bool tail) {
        if (tail) {
            if (is_superset(isa, avx512_core)) {
                uni_vaddps(v1 | tail_opmask, v1, v2);
            } else {
                uni_vpxor(vtmp, vtmp, vtmp);
                uni_vblendvps(vtmp, vtmp, v2, tail_vmask);
                uni_vaddps(v1, v1, vtmp);
            }
        } else
            uni_vaddps(v1, v1, v2);
    }

    void uni_vmaxps_maybe_tail(
            const Vmm &v1, const Vmm &v2, const Vmm &vtmp, const bool tail) {
        if (tail) {
            if (is_superset(isa, avx512_core)) {
                uni_vmaxps(v1 | tail_opmask, v1, v2);
            } else if (is_superset(isa, avx)) {
                uni_vblendvps(v2, vneg_flt_max, v2, tail_vmask);
                uni_vmaxps(v1, v1, v2);
            } else {
                uni_vmovups(vtmp, v2);
                uni_vmovups(v2, vneg_flt_max);
                uni_vblendvps(v2, v2, vtmp, tail_vmask);
                uni_vmaxps(v1, v1, v2);
            }
        } else
            uni_vmaxps(v1, v1, v2);
    }

    void store(const Address &addr, const Vmm &vmm, data_type_t dt,
            bool tail = false) {
        // Use a temporary register when a conversion is needed: for
        // logsoftmax the store happens before exp is applied, so the source
        // vmm must keep its f32 contents and the down-conversion is done on
        // a copy.
        const bool need_restore = is_logsoftmax_ && dt != data_type::f32;
        Vmm src_vmm = vmm;

        if (tail && axis_is_blocked_) {
            if (is_superset(isa, avx512_core)
                    && utils::one_of(dt, data_type::f32, data_type::bf16,
                            data_type::f16)) {
                src_vmm = vzero | tail_opmask;
                uni_vxorps(vzero, vzero, vzero);
                uni_vmovups(src_vmm, vmm);
                src_vmm = vzero;
            } else {
                uni_vpxor(vzero, vzero, vzero);
                uni_vblendvps(vzero, vzero, src_vmm, tail_vmask);
                src_vmm = vzero;
            }
        } else if (need_restore) {
            uni_vmovups(vcvt_vmm, vmm);
            src_vmm = vcvt_vmm;
        }

        io_[dt]->store(src_vmm, addr, tail && !axis_is_blocked_);
    }

    // Use ne_convert instruction to load xf16 even/odd elements from memory
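    // (even-indexed elements land in one vmm, odd-indexed in another); max is
    // order-independent, so there is no need to merge the interleaved halves
    // back to plain order here.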
    void accumulate_avx2_ne_xf16_vmax() {
        // flush to -FLT_MAX before accumulation
        uni_vmovups(vmax, vneg_flt_max);

        axis_loop([&](int unroll, bool tail = false) {
            for (int i = 0; i < unroll; i += 2) {
                const bool can_load_two_simdw = unroll - i >= 2;
                Vmm vreg_tmp_src_even = Vmm(i + 1);
                Vmm vreg_tmp_src_odd = Vmm(i + 2);
                vtmp = Vmm(i + 3);
                if (can_load_two_simdw) {
                    io_[src_d_.data_type()]->load_two_simdw_xf16(
                            src_ptr(src_axis_stride_ * i), vreg_tmp_src_even,
                            vreg_tmp_src_odd);
                } else
                    io_[src_d_.data_type()]->load(src_ptr(src_axis_stride_ * i),
                            vreg_tmp_src_even, tail);
                uni_vmaxps_maybe_tail(vmax, vreg_tmp_src_even, vtmp, tail);
                if (can_load_two_simdw)
                    uni_vmaxps_maybe_tail(vmax, vreg_tmp_src_odd, vtmp, tail);
            }
        });

        get_horizontal_op(vmax, vtmp = vsum, op_t::max);
    }

    void accumulate_vmax() {
        if (is_avx2_ne_xf16_ && is_data_type_xf16(src_d_.data_type())) {
            accumulate_avx2_ne_xf16_vmax();
            return;
        }

        // flush to -FLT_MAX before accumulation
        uni_vmovups(vmax, vneg_flt_max);

        axis_loop([&](int unroll, bool tail = false) {
            for (int i = 0; i < unroll; i++) {
                Vmm vreg_tmp_src = Vmm(i + 1);
                vtmp = Vmm(i + 2);
                // on avx2 with f32, take the max directly from memory for
                // performance
                if (!tail && isa == avx2
                        && src_d_.data_type() == data_type::f32) {
                    uni_vmaxps(vmax, vmax, src_ptr(src_axis_stride_ * i));
                } else {
                    io_[src_d_.data_type()]->load(
                            src_ptr(src_axis_stride_ * i), vreg_tmp_src, tail);
                    uni_vmaxps_maybe_tail(vmax, vreg_tmp_src, vtmp, tail);
                }
            }
        });

        get_horizontal_op(vmax, vtmp = vsum, op_t::max);
    }

    // Use ne_convert instruction to load xf16 even/odd elements from memory
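    // Unlike the vmax pass, the interleaved halves are merged back to plain
    // order here, since the intermediate results are stored to dst in memory
    // order.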
    void accumulate_avx2_ne_xf16_vsum() {
        // Initialize saturation vector register
        io_.init_saturate_f32({dst_d_.data_type()});

        uni_vpxor(vsum, vsum, vsum); // flush to zero before accumulation

        axis_loop([&](int unroll, bool tail = false) {
            for (int i = 0; i < unroll; i += 2) {
                const bool can_load_two_simdw = unroll - i >= 2;
                Vmm vreg_tmp_src_even = Vmm(i + 1);
                Vmm vreg_tmp_src_odd = Vmm(i + 2);
                vtmp = Vmm(i + 3);
                if (can_load_two_simdw) {
                    io_[src_d_.data_type()]->load_two_simdw_xf16(
                            src_ptr(src_axis_stride_ * i), vreg_tmp_src_even,
                            vreg_tmp_src_odd);
                    io_[src_d_.data_type()]->merge_interleaved_to_plain(
                            vreg_tmp_src_even, vreg_tmp_src_odd, vtmp);
                } else
                    io_[src_d_.data_type()]->load(src_ptr(src_axis_stride_ * i),
                            vreg_tmp_src_even, tail);
                for (int i_odd = 0; i_odd < 2 && i_odd + i < unroll; i_odd++) {
                    const auto vreg_tmp_src
                            = i_odd ? vreg_tmp_src_odd : vreg_tmp_src_even;
                    uni_vsubps(vreg_tmp_src, vreg_tmp_src, vmax);
                    if (is_logsoftmax_) // store before applying exp
                        store(dst_ptr(dst_axis_stride_ * (i + i_odd)),
                                vreg_tmp_src, dst_d_.data_type(), tail);
                    exp_injector_->compute_vector(vreg_tmp_src.getIdx());
                    uni_vaddps_maybe_tail(vsum, vreg_tmp_src, vtmp, tail);
                    if (is_softmax_) // store after applying exp
                        store(dst_ptr(dst_axis_stride_ * (i + i_odd)),
                                vreg_tmp_src, dst_d_.data_type(), tail);
                }
            }
        });

        get_horizontal_op(vsum, vtmp = vmax, op_t::sum);
        if (is_softmax_) uni_vdivps(vsum, vone, vsum, vtmp = vmax);
        if (is_logsoftmax_) log_injector_->compute_vector(vsum.getIdx());
    }

    void accumulate_vsum() {
        if (is_avx2_ne_xf16_ && is_data_type_xf16(src_d_.data_type())) {
            accumulate_avx2_ne_xf16_vsum();
            return;
        }

        // Initialize saturation vector register
        io_.init_saturate_f32({dst_d_.data_type()});

        uni_vpxor(vsum, vsum, vsum); // flush to zero before accumulation

        axis_loop([&](int unroll, bool tail = false) {
            for (int i = 0; i < unroll; i++) {
                Vmm vreg_tmp_src = Vmm(i + 1);
                vtmp = Vmm(i + 2);
                io_[src_d_.data_type()]->load(
                        src_ptr(src_axis_stride_ * i), vreg_tmp_src, tail);
                uni_vsubps(vreg_tmp_src, vreg_tmp_src, vmax);
                if (is_logsoftmax_) { // store before applying exp
                    if (need_scratchpad_)
                        store(interim_ptr(interim_axis_stride_ * i),
                                vreg_tmp_src, data_type::f32, tail);
                    else
                        store(dst_ptr(dst_axis_stride_ * i), vreg_tmp_src,
                                dst_d_.data_type(), tail);
                }
                exp_injector_->compute_vector(vreg_tmp_src.getIdx());
                uni_vaddps_maybe_tail(vsum, vreg_tmp_src, vtmp, tail);
                if (is_softmax_) { // store after applying exp
                    if (need_scratchpad_)
                        store(interim_ptr(interim_axis_stride_ * i),
                                vreg_tmp_src, data_type::f32, tail);
                    else
                        store(dst_ptr(dst_axis_stride_ * i), vreg_tmp_src,
                                dst_d_.data_type(), tail);
                }
            }
        });

        get_horizontal_op(vsum, vtmp = vmax, op_t::sum);
        if (is_softmax_) uni_vdivps(vsum, vone, vsum, vtmp = vmax);
        if (is_logsoftmax_) log_injector_->compute_vector(vsum.getIdx());
    }

    // Use ne_convert instruction to load xf16 even/odd elements from memory
    void compute_avx2_ne_xf16_dst() {
        axis_loop([&](int unroll, bool tail = false) {
            for (int i = 0; i < unroll; i += 2) {
                const bool can_load_two_simdw = unroll - i >= 2;
                Vmm vreg_tmp_src_even = Vmm(i + 1);
                Vmm vreg_tmp_src_odd = Vmm(i + 2);
                vtmp = Vmm(i + 3);
                if (can_load_two_simdw) {
                    io_[dst_d_.data_type()]->load_two_simdw_xf16(
                            dst_ptr(dst_axis_stride_ * i), vreg_tmp_src_even,
                            vreg_tmp_src_odd);
                    io_[dst_d_.data_type()]->merge_interleaved_to_plain(
                            vreg_tmp_src_even, vreg_tmp_src_odd, vtmp);
                } else
                    io_[dst_d_.data_type()]->load(dst_ptr(dst_axis_stride_ * i),
                            vreg_tmp_src_even, tail);
                for (int i_odd = 0; i_odd < 2 && i_odd + i < unroll; i_odd++) {
                    const auto vreg_tmp_src
                            = i_odd ? vreg_tmp_src_odd : vreg_tmp_src_even;
                    if (is_softmax_)
                        uni_vmulps(vreg_tmp_src, vreg_tmp_src, vsum);
                    if (is_logsoftmax_)
                        uni_vsubps(vreg_tmp_src, vreg_tmp_src, vsum);

                    store(dst_ptr(dst_axis_stride_ * (i + i_odd)), vreg_tmp_src,
                            dst_d_.data_type(), tail);
                }
            }
        });
    }

    void compute_dst() {
        if (is_avx2_ne_xf16_ && is_data_type_xf16(dst_d_.data_type())) {
            compute_avx2_ne_xf16_dst();
            return;
        }

        axis_loop([&](int unroll, bool tail = false) {
            for (int i = 0; i < unroll; i++) {
                Vmm vreg_tmp_src = Vmm(i + 1);
                if (need_scratchpad_)
                    io_[data_type::f32]->load(
                            interim_ptr(interim_axis_stride_ * i), vreg_tmp_src,
                            tail);
                else
                    io_[dst_d_.data_type()]->load(
                            dst_ptr(dst_axis_stride_ * i), vreg_tmp_src, tail);

                if (is_softmax_) uni_vmulps(vreg_tmp_src, vreg_tmp_src, vsum);
                if (is_logsoftmax_)
                    uni_vsubps(vreg_tmp_src, vreg_tmp_src, vsum);

                if (is_superset(isa, avx512_core)) {
                    Vmm vscale = vmax;
                    uni_vmovups(vscale, ptr[reg_src_scales]);
                    uni_vmulps(vreg_tmp_src, vreg_tmp_src, vscale);
                    // Reserved spot for post-ops injector
                    uni_vmovups(vscale, ptr[reg_dst_scales]);
                    uni_vmulps(vreg_tmp_src, vreg_tmp_src, vscale);
                }
                store(dst_ptr(dst_axis_stride_ * i), vreg_tmp_src,
                        dst_d_.data_type(), tail);
            }
        });
    }

    void accumulate_vsbr() {
        uni_vpxor(vsbr, vsbr, vsbr); // flush to zero before accumulation

        axis_loop([&](int unroll, bool tail = false) {
            for (int i = 0; i < unroll; i++) {
                Vmm vreg_tmp_dst = Vmm(i * 2 + 1);
                Vmm vreg_tmp_diff_dst = Vmm(i * 2 + 2);
                io_[diff_dst_d_.data_type()]->load(
                        diff_dst_ptr(diff_dst_axis_stride_ * i),
                        vreg_tmp_diff_dst, tail);
                if (is_softmax_) {
                    io_[dst_d_.data_type()]->load(
                            dst_ptr(dst_axis_stride_ * i), vreg_tmp_dst, tail);
                    uni_vmulps(
                            vreg_tmp_diff_dst, vreg_tmp_diff_dst, vreg_tmp_dst);
                }
                uni_vaddps(vsbr, vsbr, vreg_tmp_diff_dst);
            }
        });

        get_horizontal_op(vsbr, vtmp = vmax, op_t::sum);
    }

    void compute_diff_src() {
        axis_loop([&](int unroll, bool tail = false) {
            for (int i = 0; i < unroll; i++) {
                Vmm vreg_tmp_dst = Vmm(i * 2 + 1);
                Vmm vreg_tmp_diff_dst = Vmm(i * 2 + 2);
                io_[dst_d_.data_type()]->load(
                        dst_ptr(dst_axis_stride_ * i), vreg_tmp_dst, tail);
                io_[diff_dst_d_.data_type()]->load(
                        diff_dst_ptr(diff_dst_axis_stride_ * i),
                        vreg_tmp_diff_dst, tail);
                if (is_softmax_) {
                    uni_vsubps(vreg_tmp_diff_dst, vreg_tmp_diff_dst, vsbr);
                    uni_vmulps(
                            vreg_tmp_diff_dst, vreg_tmp_dst, vreg_tmp_diff_dst);
                }
                if (is_logsoftmax_) {
                    exp_injector_->compute_vector(vreg_tmp_dst.getIdx());
                    uni_vfnmadd231ps(vreg_tmp_diff_dst, vreg_tmp_dst, vsbr);
                }
                store(diff_src_ptr(src_axis_stride_ * i), vreg_tmp_diff_dst,
                        src_d_.data_type(), tail);
            }
        });
    }

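    // Numerically stable forward pass over the axis:
    //   softmax:    dst_i = exp(src_i - max(src)) / sum_j exp(src_j - max(src))
    //   logsoftmax: dst_i = (src_i - max(src)) - log(sum_j exp(src_j - max(src)))
    // accumulate_vsum() finishes with 1/sum (or log(sum)), so compute_dst()
    // only needs a multiply (or a subtract).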
    void forward() {
        accumulate_vmax();
        accumulate_vsum();
        compute_dst();
    }

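    // Backward pass over the axis:
    //   softmax:    diff_src_i = dst_i * (diff_dst_i - sum_j dst_j * diff_dst_j)
    //   logsoftmax: diff_src_i = diff_dst_i - exp(dst_i) * sum_j diff_dst_j
    // with the reduction computed by accumulate_vsbr().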
    void backward() {
        accumulate_vsbr();
        compute_diff_src();
    }

    // Either use this stub, or duplicate the injector setup in each derived
    // ctor: the methods involved are not yet defined at the point where the
    // base ctor runs.
    void generate() override {
        if (pd_->is_fwd() || is_logsoftmax_)
            exp_injector_.reset(new jit_uni_eltwise_injector_f32<isa>(this,
                    alg_kind::eltwise_exp, 0.0f, 0.0f, 1.0f, true,
                    reg_exp_injector_table, injector_mask));
        if (pd_->is_fwd() && is_logsoftmax_) {
            log_injector_.reset(new jit_uni_eltwise_injector_f32<isa>(this,
                    alg_kind::eltwise_log, 0.0f, 0.0f, 1.0f, true,
                    reg_log_injector_table, injector_mask));
        }

        compute_predefined_variables();
        preamble();
        io_.init_bf16();
        if (exp_injector_) exp_injector_->load_table_addr();
        if (log_injector_) log_injector_->load_table_addr();
        if (axis_simd_tail_) io_.prepare_tail_mask();
        load_common_params();
        if (pd_->is_fwd())
            forward();
        else
            backward();
        postamble();
        if (exp_injector_) exp_injector_->prepare_table();
        if (log_injector_) log_injector_->prepare_table();
    }

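    // The io helper is configured for every data type the kernel touches;
    // the f32 entry covers the intermediate statistics (max/sum), and the
    // saturation bound is keyed off the dst data type.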
    jit_softmax_t(const softmax_pd_t *pd)
        : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa)
        , pd_(pd)
        , src_d_(pd_->is_fwd() ? pd_->src_md() : pd_->diff_src_md())
        , dst_d_(pd_->dst_md())
        , diff_dst_d_(pd_->diff_dst_md()) {
        is_bf16_ = utils::one_of(
                data_type::bf16, src_d_.data_type(), dst_d_.data_type());
        is_f16_ = utils::one_of(
                data_type::f16, src_d_.data_type(), dst_d_.data_type());
        simd_w_ = vlen / sizeof(float); // bf16 works on ymms
        is_avx2_ne_xf16_
                = isa == avx2 && mayiuse(avx2_vnni_2) && (is_bf16_ || is_f16_);
        axis_simd_full_ = pd_->axis_size() / simd_w_;
        axis_simd_tail_ = pd_->axis_size() % simd_w_;
        need_scratchpad_ = utils::one_of(
                dst_d_.data_type(), data_type::u8, data_type::s8);

        io::io_conf_t io_conf;
        io::io_tail_conf_t io_tail_conf(simd_w_, axis_simd_tail_,
                tail_opmask_idx_, tail_vmask.getIdx(), reg_tmp);
        io::io_emu_bf16_conf_t io_bf16_conf(bf16_emu_zmm_1_idx_,
                bf16_emu_zmm_2_idx_, bf16_emu_zmm_3_idx_, reg_tmp,
                bf16_emu_zmm_4_idx_);
        io::io_saturation_conf_t io_saturation_conf(
                vzero.getIdx(), vsaturation_ubound.getIdx(), reg_tmp);
        io_ = io::jit_io_multi_dt_helper_t<Vmm>(this, get_io_isa(),
                {src_d_.data_type(), dst_d_.data_type(),
                        data_type::f32 /* stats */},
                io_conf, io_tail_conf, io_bf16_conf,
                {{dst_d_.data_type(), io_saturation_conf}});
    }
};

template <cpu_isa_t isa>
jit_uni_softmax_fwd_t<isa>::jit_uni_softmax_fwd_t(const pd_t *apd)
    : primitive_t(apd)
    , softmax_driver_(new softmax_impl::driver_t<isa>(pd())) {}

template <cpu_isa_t isa>
jit_uni_softmax_fwd_t<isa>::~jit_uni_softmax_fwd_t() {
    delete softmax_driver_;
}

template <cpu_isa_t isa>
status_t jit_uni_softmax_fwd_t<isa>::init(engine_t *engine) {
    return softmax_driver_->create_kernel();
}

template <cpu_isa_t isa>
status_t jit_uni_softmax_fwd_t<isa>::execute(const exec_ctx_t &ctx) const {
    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    auto scratchpad_ptr = ctx.get_scratchpad_grantor().template get<char>(
            memory_tracking::names::key_softmax_interim_store);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const auto src_data_type_size = src_d.data_type_size();
    const auto dst_data_type_size = dst_d.data_type_size();
    const auto &bd = src_d.blocking_desc();
    const auto axis = pd()->axis();

    const auto axis_size_padded = pd()->axis_size(true);
    const auto inner_stride
            = bd.inner_nblks ? bd.inner_blks[bd.inner_nblks - 1] : (dim_t)1;
    const auto inner_size = bd.strides[axis] / inner_stride;
    const auto process_n_elems = pd()->axis_size() * inner_size;
    const auto outer_stride = axis_size_padded * inner_size;
    const auto outer_size = src_d.nelems(true) / outer_stride;

    const int nthr = pd()->nthr_;

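    // Each (outer, inner) pair addresses one instance of the softmax axis;
    // the kernel walks the axis itself. The interim scratchpad, when present,
    // is per-thread, sized to hold the padded axis in f32.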
    parallel_nd_ext(nthr, outer_size, inner_size,
            [&](int ithr, int, dim_t ou, dim_t in) {
                dim_t offset = (ou * outer_stride + in * inner_stride);
                const char *src_ptr = src + offset * src_data_type_size;
                char *dst_ptr = dst + offset * dst_data_type_size;
                char *interim_ptr = scratchpad_ptr ? scratchpad_ptr
                                + ithr * axis_size_padded * sizeof(float)
                                                   : nullptr;
                softmax_driver_->exec(src_ptr, dst_ptr, interim_ptr, src_scales,
                        dst_scales, process_n_elems);
            });

    return status::success;
}

template <cpu_isa_t isa>
jit_uni_softmax_bwd_t<isa>::jit_uni_softmax_bwd_t(const pd_t *apd)
    : primitive_t(apd)
    , softmax_driver_(new softmax_impl::driver_t<isa>(pd())) {}

template <cpu_isa_t isa>
jit_uni_softmax_bwd_t<isa>::~jit_uni_softmax_bwd_t() {
    delete softmax_driver_;
}

template <cpu_isa_t isa>
status_t jit_uni_softmax_bwd_t<isa>::init(engine_t *engine) {
    return softmax_driver_->create_kernel();
}

template <cpu_isa_t isa>
status_t jit_uni_softmax_bwd_t<isa>::execute(const exec_ctx_t &ctx) const {
    auto dst = CTX_IN_MEM(const char *, DNNL_ARG_DST);
    auto diff_dst = CTX_IN_MEM(const char *, DNNL_ARG_DIFF_DST);
    auto diff_src = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const auto dst_data_type_size = dst_d.data_type_size();
    const auto diff_dst_data_type_size = diff_dst_d.data_type_size();
    const auto diff_src_data_type_size = diff_src_d.data_type_size();
    const auto &bd = dst_d.blocking_desc();
    const auto axis = pd()->axis();

    const auto inner_stride
            = bd.inner_nblks ? bd.inner_blks[bd.inner_nblks - 1] : (dim_t)1;
    const auto inner_size = bd.strides[axis] / inner_stride;
    const auto process_n_elems = pd()->axis_size() * inner_size;
    const auto outer_stride = pd()->axis_size(true) * inner_size;
    const auto outer_size = dst_d.nelems(true) / outer_stride;

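    // Same outer x inner decomposition as the forward pass; no scratchpad is
    // needed on this path.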
    parallel_nd(outer_size, inner_size, [&](dim_t ou, dim_t in) {
        dim_t offset = (ou * outer_stride + in * inner_stride);
        char *diff_src_ptr = diff_src + offset * diff_src_data_type_size;
        const char *dst_ptr = dst + offset * dst_data_type_size;
        const char *diff_dst_ptr = diff_dst + offset * diff_dst_data_type_size;
        softmax_driver_->exec(
                diff_src_ptr, dst_ptr, diff_dst_ptr, process_n_elems);
    });

    return status::success;
}

namespace softmax_impl {

template <cpu_isa_t isa>
struct driver_t : public c_compatible {

    driver_t(const softmax_pd_t *pd) : pd_(pd), ker_(pd_) {}

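    // Forward-pass entry point.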
    void exec(const void *src, void *dst, void *interim, const void *src_scales,
            const void *dst_scales, const dim_t process_n_elems) {
        typename jit_softmax_t<isa>::call_params_t p;
        p.process_n_elems = process_n_elems;
        p.src = src;
        p.dst = dst;
        p.interim = interim;
        p.src_scales = src_scales;
        p.dst_scales = dst_scales;
        ker_(&p);
    }

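    // Backward-pass entry point: diff_src travels through the src field of
    // call_params_t (the kernel aliases reg_diff_src to reg_src).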
    void exec(void *diff_src, const void *dst, const void *diff_dst,
            const dim_t process_n_elems) {
        typename jit_softmax_t<isa>::call_params_t p;
        p.process_n_elems = process_n_elems;
        p.src = diff_src;
        p.dst = dst;
        p.diff_dst = diff_dst;
        ker_(&p);
    }

    status_t create_kernel() { return ker_.create_kernel(); }

private:
    const softmax_pd_t *pd_;
    jit_softmax_t<isa> ker_;
};

} // namespace softmax_impl

/* struct instantiation */
template struct jit_uni_softmax_fwd_t<sse41>;
template struct jit_uni_softmax_fwd_t<avx2>;
template struct jit_uni_softmax_fwd_t<avx512_core>;
template struct jit_uni_softmax_bwd_t<avx512_core>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl