1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/dnnl_thread.hpp" |
21 | #include "common/math_utils.hpp" |
22 | #include "common/memory_tracking.hpp" |
23 | #include "common/nstl.hpp" |
24 | #include "common/type_helpers.hpp" |
25 | #include "common/utils.hpp" |
26 | |
27 | #include "cpu/cpu_primitive.hpp" |
28 | |
29 | #include "cpu/x64/jit_generator.hpp" |
30 | |
31 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
32 | #include "cpu/x64/jit_uni_softmax.hpp" |
33 | #include "cpu/x64/utils/jit_io_helper.hpp" |
34 | |
35 | #if defined(__INTEL_COMPILER) && (__INTEL_COMPILER < 1900) |
36 | // Intel Compilers 17.x and 18.x do not like that diff_src_ptr() is only used |
37 | // in a single descendant class and marks it as unused. This breaks builds |
38 | // with DNNL_WERROR=on. Disabling the warning for this file seems to be less |
39 | // ugly than all the fixes that I came up with. |
40 | #pragma warning disable : 177 |
41 | #endif |
42 | |
43 | namespace dnnl { |
44 | namespace impl { |
45 | namespace cpu { |
46 | namespace x64 { |
47 | |
48 | using namespace Xbyak; |
49 | |
// JIT kernel computing softmax/logsoftmax along the softmax axis. A single
// generator emits either the forward pass (max -> subtract/exp -> sum ->
// normalize) or the backward pass, chosen at code-generation time from
// pd_->is_fwd(). One kernel invocation processes `process_n_elems` elements
// per axis step for a given spatial coordinate.
template <cpu_isa_t isa>
struct jit_softmax_t : public jit_generator {
    struct call_params_t {
        // keep all sizes at 8 bytes -- jit code expects this
        const void *src, *dst, *diff_dst; // src dubs as diff_src
        const void *interim; // scratch memory for intermediate storage
        const void *src_scales; // src_scales defined for all data type cases
        const void *dst_scales; // dst_scales defined for all data type cases
        size_t process_n_elems;
    };
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_softmax_t)

    // cpu specific part
    using Vmm = typename cpu_isa_traits<isa>::Vmm;
    const AddressFrame &vmmword
            = (isa == sse41) ? xword : (isa == avx2) ? yword : zword;
    const int vlen = cpu_isa_traits<isa>::vlen;

    const softmax_pd_t *pd_;
    const memory_desc_wrapper src_d_, dst_d_, diff_dst_d_;
    // per-data-type load/store helpers (handle bf16/f16/int8 conversions)
    io::jit_io_multi_dt_helper_t<Vmm> io_;

    std::unique_ptr<jit_uni_eltwise_injector_f32<isa>> exp_injector_;
    std::unique_ptr<jit_uni_eltwise_injector_f32<isa>> log_injector_;

    Reg64 reg_param = abi_param1;

    Reg64 reg_exp_injector_table = rax;
    Reg64 reg_log_injector_table = rbx;
    Reg64 reg_src = r8;
    // fwd uses only reg_src, bwd only reg_diff_src -- safe to alias
    Reg64 reg_diff_src = reg_src;
    Reg64 reg_dst = r9;
    Reg64 reg_diff_dst = r14;
    Reg64 reg_src_spat_offt = r10;
    Reg64 reg_process_n_elems = r11;
    Reg64 reg_reverse_n_elems = r12;
    Reg64 reg_tmp = r13;
    Reg64 reg_dst_spat_offt = r15;
    // log injector is created on fwd only (see generate()), diff_dst offset
    // is needed on bwd only -- the register can serve both roles
    Reg64 reg_diff_dst_spat_offt = reg_log_injector_table;
    // scratchpad is used on fwd only (see need_scratchpad_), diff_dst base
    // pointer on bwd only -- safe to alias
    Reg64 reg_interim = reg_diff_dst;
    Reg64 reg_interim_spat_offt = abi_not_param1;
    Reg64 reg_src_scales = rsi;
    Reg64 reg_dst_scales = rdx;

    Opmask injector_mask = Opmask(1);

    Vmm vtmp; // assigned at places where used
    Vmm tail_vmask = Vmm(0);
    Xmm xneg_flt_max = Xmm(12);
    Vmm vneg_flt_max = Vmm(isa == avx512_core ? 28 : 12);
    Xmm xone = Xmm(13);
    Vmm vone = Vmm(isa == avx512_core ? 29 : 13);
    Vmm vsum = Vmm(isa == avx512_core ? 30 : 14);
    Vmm vmax = Vmm(isa == avx512_core ? 31 : 15);
    Vmm vsbr = vsum; // must be not equal to vmax
    Vmm vzero = Vmm(isa == avx512_core ? 21 : 11);
    Vmm vcvt_vmm = Vmm(isa == avx512_core ? 22 : 10);
    Vmm vsaturation_ubound = vneg_flt_max;

    bool is_bf16_ = false;
    bool is_f16_ = false;
    // true when the avx2_vnni_2 ne_convert load path for bf16/f16 is taken
    bool is_avx2_ne_xf16_ = false;
    bool is_softmax_ = pd_->is_softmax();
    bool is_logsoftmax_ = pd_->is_logsoftmax();
    bool axis_is_blocked_; // axis has physical padding (blocked layout)
    bool need_scratchpad_; // u8/s8 dst: keep f32 intermediates in scratch

    size_t simd_w_ = 0;
    size_t unroll_regs_ = 4; // axis chunks processed per main-loop iteration

    size_t axis_simd_full_; // number of full simd_w_ chunks along the axis
    size_t axis_simd_tail_; // remainder elements (< simd_w_)
    size_t n_loops_; // full unrolled iterations
    size_t loop_tail_; // leftover full-simd chunks after unrolling
    size_t process_n_elems_; // elements handled per axis step
    size_t src_axis_stride_; // byte strides between consecutive axis chunks
    size_t interim_axis_stride_;
    size_t dst_axis_stride_;
    size_t diff_dst_axis_stride_;

    const int bf16_emu_zmm_1_idx_ = 23;
    const int bf16_emu_zmm_2_idx_ = 24;
    const int bf16_emu_zmm_3_idx_ = 25;
    const int bf16_emu_zmm_4_idx_ = 26;
    const int tail_opmask_idx_ = 2;

    Opmask tail_opmask = Opmask(tail_opmask_idx_);

    void operator()(const call_params_t *p) {
        return jit_generator::operator()(p);
    }

    // ISA to instantiate the IO helpers with; may be promoted beyond `isa`
    // so that native xf16 conversion instructions are used when available.
    cpu_isa_t get_io_isa() {
        // reusing avx512_core instantiation for xf16 on AVX512_CORE+
        // reusing avx2 instantiation for xf16 on AVX2_VNNI_2
        const bool is_reuse_avx512_core = isa == avx512_core
                && (mayiuse(avx512_core_bf16) || mayiuse(avx512_core_fp16));
        const bool is_reuse_avx2 = isa == avx2 && mayiuse(avx2_vnni_2);
        if (is_bf16_ || is_f16_) {
            return is_reuse_avx512_core
                    ? is_f16_ ? avx512_core_fp16 : avx512_core_bf16
                    : is_reuse_avx2 ? avx2_vnni_2 : isa;
        } else
            return isa;
    }

    bool is_data_type_xf16(data_type_t dt) {
        return utils::one_of(dt, data_type::bf16, data_type::f16);
    }

    // Precompute loop counts and byte strides consumed by the generated code.
    void compute_predefined_variables() {
        n_loops_ = axis_simd_full_ / unroll_regs_;
        loop_tail_ = axis_simd_full_ - n_loops_ * unroll_regs_;
        process_n_elems_ = compute_process_n_elems(dst_d_);
        src_axis_stride_ = compute_axis_stride(src_d_);
        // scratch always holds f32 values, one simd chunk per axis step
        interim_axis_stride_ = simd_w_ * sizeof(float);
        dst_axis_stride_ = compute_axis_stride(dst_d_);
        if (!pd_->is_fwd())
            diff_dst_axis_stride_ = compute_axis_stride(diff_dst_d_);
        axis_is_blocked_ = pd_->axis_size(true) != pd_->axis_size();
    }

    // Elements processed per axis step: for blocked layouts this is the
    // stride at the axis (covers whole inner blocks), otherwise one simd
    // chunk.
    size_t compute_process_n_elems(const memory_desc_wrapper &mdw) {
        const auto &bd = mdw.blocking_desc();
        if (bd.inner_nblks) return bd.strides[pd_->axis()];
        return simd_w_;
    }

    // Byte distance between consecutive axis steps for the given memory.
    size_t compute_axis_stride(const memory_desc_wrapper &mdw) {
        return compute_process_n_elems(mdw) * mdw.data_type_size();
    }

    // Materialize broadcast constants (1.0f, -FLT_MAX) and load the kernel
    // arguments from the call_params_t pointed to by abi_param1.
    void load_common_params() {
        mov(reg_tmp, float2int(1.0f));
        uni_vmovq(xone, reg_tmp);
        uni_vbroadcastss(vone, xone);
        mov(reg_tmp, float2int(-FLT_MAX));
        uni_vmovq(xneg_flt_max, reg_tmp);
        uni_vbroadcastss(vneg_flt_max, xneg_flt_max);

#define PARAM_OFF(x) offsetof(call_params_t, x)
        mov(reg_process_n_elems, ptr[reg_param + PARAM_OFF(process_n_elems)]);
        mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
        if (pd_->is_fwd())
            mov(reg_src, ptr[reg_param + PARAM_OFF(src)]);
        else {
            mov(reg_diff_src, ptr[reg_param + PARAM_OFF(src)]); // src is reused
            mov(reg_diff_dst, ptr[reg_param + PARAM_OFF(diff_dst)]);
        }
        if (need_scratchpad_) {
            mov(reg_interim, ptr[reg_param + PARAM_OFF(interim)]);
        }
        mov(reg_src_scales, ptr[reg_param + PARAM_OFF(src_scales)]);
        mov(reg_dst_scales, ptr[reg_param + PARAM_OFF(dst_scales)]);
#undef PARAM_OFF
    }

    // Addressing helpers: base register + running spatial offset + immediate.
    Address diff_src_ptr(size_t offt = 0) {
        return vmmword[reg_diff_src + reg_src_spat_offt + offt];
    }

    Address src_ptr(size_t offt = 0) {
        return vmmword[reg_src + reg_src_spat_offt + offt];
    }

    Address interim_ptr(size_t offt = 0) {
        return vmmword[reg_interim + reg_interim_spat_offt + offt];
    }

    Address dst_ptr(size_t offt = 0) {
        return vmmword[reg_dst + reg_dst_spat_offt + offt];
    }

    Address diff_dst_ptr(size_t offt = 0) {
        return vmmword[reg_diff_dst + reg_diff_dst_spat_offt + offt];
    }

    enum class op_t : unsigned { max, sum };

    void perform_op(Vmm v, Vmm vtmp, op_t op) {
        if (op == op_t::max)
            uni_vmaxps(v, v, vtmp);
        else if (op == op_t::sum)
            uni_vaddps(v, v, vtmp);
    }

    // Horizontal reduction (max or sum) across all lanes of vsrc. Pairwise
    // shuffle-and-combine halves the span each step; after the sequence every
    // lane of vsrc holds the reduced value. vtmp is clobbered.
    void get_horizontal_op(const Vmm &vsrc, const Vmm &vtmp, op_t op) {
        const Zmm &zsrc = Zmm(vsrc.getIdx());
        const Zmm &ztmp = Zmm(vtmp.getIdx());
        const Ymm &ysrc = Ymm(vsrc.getIdx());
        const Ymm &ytmp = Ymm(vtmp.getIdx());

        if (is_superset(isa, avx512_core)) {
            vshuff32x4(ztmp, zsrc, zsrc, 0x4E); // 256-bit shuffle
            perform_op(vsrc, vtmp, op);
            vshuff32x4(ztmp, zsrc, zsrc, 0xB1); // 128/256-bit shuffle
            perform_op(vsrc, vtmp, op);
        } else if (is_superset(isa, avx2)) {
            vperm2f128(ytmp, ysrc, ysrc, 0x1); // 128/256-bit shuffle
            perform_op(vsrc, vtmp, op);
        }
        uni_vshufps(vtmp, vsrc, vsrc, 0x4E); // 64/128-bit shuffle
        perform_op(vsrc, vtmp, op);
        uni_vshufps(vtmp, vsrc, vsrc, 0xB1); // 32/64-bit shuffle
        perform_op(vsrc, vtmp, op);
    }

    // Emit the traversal along the axis: an unrolled main loop (unroll_regs_
    // axis steps per iteration), a loop for the remaining full-simd steps,
    // and a final masked pass for the sub-simd tail. `body(unroll, tail)`
    // emits the payload for each phase.
    template <typename body_t>
    void axis_loop(body_t body) {
        Label main_loop, tail_loop, tail_axis;

        // reverse_spat_offt to dispatch between labels
        mov(reg_reverse_n_elems, reg_process_n_elems);
        xor_(reg_src_spat_offt, reg_src_spat_offt); // src/diff_src addr
        xor_(reg_dst_spat_offt, reg_dst_spat_offt); // dst addr
        if (need_scratchpad_)
            xor_(reg_interim_spat_offt, reg_interim_spat_offt); // scratch addr
        if (!pd_->is_fwd())
            xor_(reg_diff_dst_spat_offt, reg_diff_dst_spat_offt); // d_dst addr
        L(main_loop);
        {
            if (n_loops_) {
                cmp(reg_reverse_n_elems, unroll_regs_ * process_n_elems_);
                jl(tail_loop, T_NEAR);

                body(unroll_regs_, false);
                sub(reg_reverse_n_elems, unroll_regs_ * process_n_elems_);
                add(reg_src_spat_offt, unroll_regs_ * src_axis_stride_);
                add(reg_dst_spat_offt, unroll_regs_ * dst_axis_stride_);
                if (need_scratchpad_)
                    add(reg_interim_spat_offt,
                            unroll_regs_ * interim_axis_stride_);
                if (!pd_->is_fwd())
                    add(reg_diff_dst_spat_offt,
                            unroll_regs_ * diff_dst_axis_stride_);
                jmp(main_loop);
            }
        }

        L(tail_loop);
        {
            if (loop_tail_) {
                body(loop_tail_, false);
                add(reg_src_spat_offt, loop_tail_ * src_axis_stride_);
                add(reg_dst_spat_offt, loop_tail_ * dst_axis_stride_);
                if (need_scratchpad_)
                    add(reg_interim_spat_offt,
                            loop_tail_ * interim_axis_stride_);
                if (!pd_->is_fwd())
                    add(reg_diff_dst_spat_offt,
                            loop_tail_ * diff_dst_axis_stride_);
            }
        }

        L(tail_axis);
        {
            if (axis_simd_tail_) { body(1, true); }
        }
    }

    // Tail-aware add: masked-off lanes of v2 contribute zero to v1.
    void uni_vaddps_maybe_tail(
            const Vmm &v1, const Vmm &v2, const Vmm &vtmp, const bool tail) {
        if (tail) {
            if (is_superset(isa, avx512_core)) {
                uni_vaddps(v1 | tail_opmask, v1, v2);
            } else {
                // no opmasks: zero out the inactive lanes via blend first
                uni_vpxor(vtmp, vtmp, vtmp);
                uni_vblendvps(vtmp, vtmp, v2, tail_vmask);
                uni_vaddps(v1, v1, vtmp);
            }
        } else
            uni_vaddps(v1, v1, v2);
    }

    // Tail-aware max: masked-off lanes of v2 are replaced with -FLT_MAX so
    // they cannot affect the running maximum.
    void uni_vmaxps_maybe_tail(
            const Vmm &v1, const Vmm &v2, const Vmm &vtmp, const bool tail) {
        if (tail) {
            if (is_superset(isa, avx512_core)) {
                uni_vmaxps(v1 | tail_opmask, v1, v2);
            } else if (is_superset(isa, avx)) {
                uni_vblendvps(v2, vneg_flt_max, v2, tail_vmask);
                uni_vmaxps(v1, v1, v2);
            } else {
                // sse41 blend operates in-place, needs an extra copy
                uni_vmovups(vtmp, v2);
                uni_vmovups(v2, vneg_flt_max);
                uni_vblendvps(v2, v2, vtmp, tail_vmask);
                uni_vmaxps(v1, v1, v2);
            }
        } else
            uni_vmaxps(v1, v1, v2);
    }

    void store(const Address &addr, const Vmm &vmm, data_type_t dt,
            bool tail = false) {
        // Use temporary register in storing when conversion is needed
        // Or we need to restore data back to fp32 since we apply exp after
        // storing and data should be fp32
        const bool need_restore = is_logsoftmax_ && dt != data_type::f32;
        Vmm src_vmm = vmm;

        if (tail && axis_is_blocked_) {
            // blocked+padded axis: store a full vector with the padding
            // lanes zeroed instead of a masked partial store
            if (is_superset(isa, avx512_core)
                    && utils::one_of(dt, data_type::f32, data_type::bf16,
                            data_type::f16)) {
                src_vmm = vzero | tail_opmask;
                uni_vxorps(vzero, vzero, vzero);
                uni_vmovups(src_vmm, vmm);
                src_vmm = vzero;
            } else {
                uni_vpxor(vzero, vzero, vzero);
                uni_vblendvps(vzero, vzero, src_vmm, tail_vmask);
                src_vmm = vzero;
            }
        } else if (need_restore) {
            uni_vmovups(vcvt_vmm, vmm);
            src_vmm = vcvt_vmm;
        }

        io_[dt]->store(src_vmm, addr, tail && !axis_is_blocked_);
    }

    // Use ne_convert instruction to load xf16 even/odd elements from memory
    void accumulate_avx2_ne_xf16_vmax() {
        // flush to -FLT_MAX before accumulation
        uni_vmovups(vmax, vneg_flt_max);

        axis_loop([&](int unroll, bool tail = false) {
            // pairs of axis steps are loaded as interleaved even/odd vectors;
            // no deinterleave needed since max is lane-order independent
            for (int i = 0; i < unroll; i += 2) {
                const bool can_load_two_simdw = unroll - i >= 2;
                Vmm vreg_tmp_src_even = Vmm(i + 1);
                Vmm vreg_tmp_src_odd = Vmm(i + 2);
                vtmp = Vmm(i + 3);
                if (can_load_two_simdw) {
                    io_[src_d_.data_type()]->load_two_simdw_xf16(
                            src_ptr(src_axis_stride_ * i), vreg_tmp_src_even,
                            vreg_tmp_src_odd);
                } else
                    io_[src_d_.data_type()]->load(src_ptr(src_axis_stride_ * i),
                            vreg_tmp_src_even, tail);
                uni_vmaxps_maybe_tail(vmax, vreg_tmp_src_even, vtmp, tail);
                if (can_load_two_simdw)
                    uni_vmaxps_maybe_tail(vmax, vreg_tmp_src_odd, vtmp, tail);
            }
        });

        get_horizontal_op(vmax, vtmp = vsum, op_t::max);
    }

    // Pass 1 (fwd): vmax := max over the axis, broadcast to all lanes.
    void accumulate_vmax() {
        if (is_avx2_ne_xf16_ && is_data_type_xf16(src_d_.data_type())) {
            accumulate_avx2_ne_xf16_vmax();
            return;
        }

        // flush to -FLT_MAX before accumulation
        uni_vmovups(vmax, vneg_flt_max);

        axis_loop([&](int unroll, bool tail = false) {
            for (int i = 0; i < unroll; i++) {
                Vmm vreg_tmp_src = Vmm(i + 1);
                vtmp = Vmm(i + 2);
                // do maxps directly from memory on f32 avx2 for performance purpose
                if (!tail && isa == avx2
                        && src_d_.data_type() == data_type::f32) {
                    uni_vmaxps(vmax, vmax, src_ptr(src_axis_stride_ * i));
                } else {
                    io_[src_d_.data_type()]->load(
                            src_ptr(src_axis_stride_ * i), vreg_tmp_src, tail);
                    uni_vmaxps_maybe_tail(vmax, vreg_tmp_src, vtmp, tail);
                }
            }
        });

        get_horizontal_op(vmax, vtmp = vsum, op_t::max);
    }

    // Use ne_convert instruction to load xf16 even/odd elements from memory
    void accumulate_avx2_ne_xf16_vsum() {
        // Initialize saturation vector register
        io_.init_saturate_f32({dst_d_.data_type()});

        uni_vpxor(vsum, vsum, vsum); // flush to zero before accumulation

        axis_loop([&](int unroll, bool tail = false) {
            for (int i = 0; i < unroll; i += 2) {
                const bool can_load_two_simdw = unroll - i >= 2;
                Vmm vreg_tmp_src_even = Vmm(i + 1);
                Vmm vreg_tmp_src_odd = Vmm(i + 2);
                vtmp = Vmm(i + 3);
                if (can_load_two_simdw) {
                    io_[src_d_.data_type()]->load_two_simdw_xf16(
                            src_ptr(src_axis_stride_ * i), vreg_tmp_src_even,
                            vreg_tmp_src_odd);
                    // stores below need plain order -- deinterleave
                    io_[src_d_.data_type()]->merge_interleaved_to_plain(
                            vreg_tmp_src_even, vreg_tmp_src_odd, vtmp);
                } else
                    io_[src_d_.data_type()]->load(src_ptr(src_axis_stride_ * i),
                            vreg_tmp_src_even, tail);
                for (int i_odd = 0; i_odd < 2 && i_odd + i < unroll; i_odd++) {
                    const auto vreg_tmp_src
                            = i_odd ? vreg_tmp_src_odd : vreg_tmp_src_even;
                    uni_vsubps(vreg_tmp_src, vreg_tmp_src, vmax);
                    if (is_logsoftmax_) // store before applying exp
                        store(dst_ptr(dst_axis_stride_ * (i + i_odd)),
                                vreg_tmp_src, dst_d_.data_type(), tail);
                    exp_injector_->compute_vector(vreg_tmp_src.getIdx());
                    uni_vaddps_maybe_tail(vsum, vreg_tmp_src, vtmp, tail);
                    if (is_softmax_) // store after applying exp
                        store(dst_ptr(dst_axis_stride_ * (i + i_odd)),
                                vreg_tmp_src, dst_d_.data_type(), tail);
                }
            }
        });

        // vsum := 1/sum (softmax) or log(sum) (logsoftmax) in every lane
        get_horizontal_op(vsum, vtmp = vmax, op_t::sum);
        if (is_softmax_) uni_vdivps(vsum, vone, vsum, vtmp = vmax);
        if (is_logsoftmax_) log_injector_->compute_vector(vsum.getIdx());
    }

    // Pass 2 (fwd): compute exp(x - max), accumulate the sum, and store the
    // intermediate result (to dst, or to f32 scratch for u8/s8 dst).
    void accumulate_vsum() {
        if (is_avx2_ne_xf16_ && is_data_type_xf16(src_d_.data_type())) {
            accumulate_avx2_ne_xf16_vsum();
            return;
        }

        // Initialize saturation vector register
        io_.init_saturate_f32({dst_d_.data_type()});

        uni_vpxor(vsum, vsum, vsum); // flush to zero before accumulation

        axis_loop([&](int unroll, bool tail = false) {
            for (int i = 0; i < unroll; i++) {
                Vmm vreg_tmp_src = Vmm(i + 1);
                vtmp = Vmm(i + 2);
                io_[src_d_.data_type()]->load(
                        src_ptr(src_axis_stride_ * i), vreg_tmp_src, tail);
                uni_vsubps(vreg_tmp_src, vreg_tmp_src, vmax);
                if (is_logsoftmax_) { // store before applying exp
                    if (need_scratchpad_)
                        store(interim_ptr(interim_axis_stride_ * i),
                                vreg_tmp_src, data_type::f32, tail);
                    else
                        store(dst_ptr(dst_axis_stride_ * i), vreg_tmp_src,
                                dst_d_.data_type(), tail);
                }
                exp_injector_->compute_vector(vreg_tmp_src.getIdx());
                uni_vaddps_maybe_tail(vsum, vreg_tmp_src, vtmp, tail);
                if (is_softmax_) { // store after applying exp
                    if (need_scratchpad_)
                        store(interim_ptr(interim_axis_stride_ * i),
                                vreg_tmp_src, data_type::f32, tail);
                    else
                        store(dst_ptr(dst_axis_stride_ * i), vreg_tmp_src,
                                dst_d_.data_type(), tail);
                }
            }
        });

        // vsum := 1/sum (softmax) or log(sum) (logsoftmax) in every lane
        get_horizontal_op(vsum, vtmp = vmax, op_t::sum);
        if (is_softmax_) uni_vdivps(vsum, vone, vsum, vtmp = vmax);
        if (is_logsoftmax_) log_injector_->compute_vector(vsum.getIdx());
    }

    // Use ne_convert instruction to load xf16 even/odd elements from memory
    void compute_avx2_ne_xf16_dst() {
        axis_loop([&](int unroll, bool tail = false) {
            for (int i = 0; i < unroll; i += 2) {
                const bool can_load_two_simdw = unroll - i >= 2;
                Vmm vreg_tmp_src_even = Vmm(i + 1);
                Vmm vreg_tmp_src_odd = Vmm(i + 2);
                vtmp = Vmm(i + 3);
                if (can_load_two_simdw) {
                    io_[dst_d_.data_type()]->load_two_simdw_xf16(
                            dst_ptr(dst_axis_stride_ * i), vreg_tmp_src_even,
                            vreg_tmp_src_odd);
                    io_[dst_d_.data_type()]->merge_interleaved_to_plain(
                            vreg_tmp_src_even, vreg_tmp_src_odd, vtmp);
                } else
                    io_[dst_d_.data_type()]->load(dst_ptr(dst_axis_stride_ * i),
                            vreg_tmp_src_even, tail);
                for (int i_odd = 0; i_odd < 2 && i_odd + i < unroll; i_odd++) {
                    const auto vreg_tmp_src
                            = i_odd ? vreg_tmp_src_odd : vreg_tmp_src_even;
                    if (is_softmax_)
                        uni_vmulps(vreg_tmp_src, vreg_tmp_src, vsum);
                    if (is_logsoftmax_)
                        uni_vsubps(vreg_tmp_src, vreg_tmp_src, vsum);

                    store(dst_ptr(dst_axis_stride_ * (i + i_odd)), vreg_tmp_src,
                            dst_d_.data_type(), tail);
                }
            }
        });
    }

    // Pass 3 (fwd): normalize -- multiply by 1/sum (softmax) or subtract
    // log(sum) (logsoftmax) -- and write the final dst values.
    void compute_dst() {
        if (is_avx2_ne_xf16_ && is_data_type_xf16(dst_d_.data_type())) {
            compute_avx2_ne_xf16_dst();
            return;
        }

        axis_loop([&](int unroll, bool tail = false) {
            for (int i = 0; i < unroll; i++) {
                Vmm vreg_tmp_src = Vmm(i + 1);
                if (need_scratchpad_)
                    io_[data_type::f32]->load(
                            interim_ptr(interim_axis_stride_ * i), vreg_tmp_src,
                            tail);
                else
                    io_[dst_d_.data_type()]->load(
                            dst_ptr(dst_axis_stride_ * i), vreg_tmp_src, tail);

                if (is_softmax_) uni_vmulps(vreg_tmp_src, vreg_tmp_src, vsum);
                if (is_logsoftmax_)
                    uni_vsubps(vreg_tmp_src, vreg_tmp_src, vsum);

                // src/dst scales are applied on the avx512_core path only;
                // vmax is free to reuse at this point
                if (is_superset(isa, avx512_core)) {
                    Vmm vscale = vmax;
                    uni_vmovups(vscale, ptr[reg_src_scales]);
                    uni_vmulps(vreg_tmp_src, vreg_tmp_src, vscale);
                    // Reserved spot for post-ops injector
                    uni_vmovups(vscale, ptr[reg_dst_scales]);
                    uni_vmulps(vreg_tmp_src, vreg_tmp_src, vscale);
                }
                store(dst_ptr(dst_axis_stride_ * i), vreg_tmp_src,
                        dst_d_.data_type(), tail);
            }
        });
    }

    // Pass 1 (bwd): vsbr := sum over the axis of diff_dst (softmax multiplies
    // by dst first), broadcast to all lanes.
    void accumulate_vsbr() {
        uni_vpxor(vsbr, vsbr, vsbr); // flush to zero before accumulation

        axis_loop([&](int unroll, bool tail = false) {
            for (int i = 0; i < unroll; i++) {
                Vmm vreg_tmp_dst = Vmm(i * 2 + 1);
                Vmm vreg_tmp_diff_dst = Vmm(i * 2 + 2);
                io_[diff_dst_d_.data_type()]->load(
                        diff_dst_ptr(diff_dst_axis_stride_ * i),
                        vreg_tmp_diff_dst, tail);
                if (is_softmax_) {
                    io_[dst_d_.data_type()]->load(
                            dst_ptr(dst_axis_stride_ * i), vreg_tmp_dst, tail);
                    uni_vmulps(
                            vreg_tmp_diff_dst, vreg_tmp_diff_dst, vreg_tmp_dst);
                }
                uni_vaddps(vsbr, vsbr, vreg_tmp_diff_dst);
            }
        });

        get_horizontal_op(vsbr, vtmp = vmax, op_t::sum);
    }

    // Pass 2 (bwd): diff_src = dst * (diff_dst - vsbr) for softmax,
    // diff_src = diff_dst - exp(dst) * vsbr for logsoftmax.
    void compute_diff_src() {
        axis_loop([&](int unroll, bool tail = false) {
            for (int i = 0; i < unroll; i++) {
                Vmm vreg_tmp_dst = Vmm(i * 2 + 1);
                Vmm vreg_tmp_diff_dst = Vmm(i * 2 + 2);
                io_[dst_d_.data_type()]->load(
                        dst_ptr(dst_axis_stride_ * i), vreg_tmp_dst, tail);
                io_[diff_dst_d_.data_type()]->load(
                        diff_dst_ptr(diff_dst_axis_stride_ * i),
                        vreg_tmp_diff_dst, tail);
                if (is_softmax_) {
                    uni_vsubps(vreg_tmp_diff_dst, vreg_tmp_diff_dst, vsbr);
                    uni_vmulps(
                            vreg_tmp_diff_dst, vreg_tmp_dst, vreg_tmp_diff_dst);
                }
                if (is_logsoftmax_) {
                    exp_injector_->compute_vector(vreg_tmp_dst.getIdx());
                    uni_vfnmadd231ps(vreg_tmp_diff_dst, vreg_tmp_dst, vsbr);
                }
                store(diff_src_ptr(src_axis_stride_ * i), vreg_tmp_diff_dst,
                        src_d_.data_type(), tail);
            }
        });
    }

    void forward() {
        accumulate_vmax();
        accumulate_vsum();
        compute_dst();
    }

    void backward() {
        accumulate_vsbr();
        compute_diff_src();
    }

    // Injectors are created here rather than in the ctor because methods they
    // rely on are not defined at the moment of the base ctor initialization
    // (either this stub or duplication at each ctor).
    void generate() override {
        // bwd logsoftmax also needs exp (see compute_diff_src); log is
        // fwd-logsoftmax-only
        if (pd_->is_fwd() || is_logsoftmax_)
            exp_injector_.reset(new jit_uni_eltwise_injector_f32<isa>(this,
                    alg_kind::eltwise_exp, 0.0f, 0.0f, 1.0f, true,
                    reg_exp_injector_table, injector_mask));
        if (pd_->is_fwd() && is_logsoftmax_) {
            log_injector_.reset(new jit_uni_eltwise_injector_f32<isa>(this,
                    alg_kind::eltwise_log, 0.0f, 0.0f, 1.0f, true,
                    reg_log_injector_table, injector_mask));
        }

        compute_predefined_variables();
        preamble();
        io_.init_bf16();
        if (exp_injector_) exp_injector_->load_table_addr();
        if (log_injector_) log_injector_->load_table_addr();
        if (axis_simd_tail_) io_.prepare_tail_mask();
        load_common_params();
        if (pd_->is_fwd())
            forward();
        else
            backward();
        postamble();
        if (exp_injector_) exp_injector_->prepare_table();
        if (log_injector_) log_injector_->prepare_table();
    }

    jit_softmax_t(const softmax_pd_t *pd)
        : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa)
        , pd_(pd)
        // on bwd, src_d_ describes diff_src (kernel writes through `src` slot)
        , src_d_(pd_->is_fwd() ? pd_->src_md() : pd_->diff_src_md())
        , dst_d_(pd_->dst_md())
        , diff_dst_d_(pd_->diff_dst_md()) {
        is_bf16_ = utils::one_of(
                data_type::bf16, src_d_.data_type(), dst_d_.data_type());
        is_f16_ = utils::one_of(
                data_type::f16, src_d_.data_type(), dst_d_.data_type());
        simd_w_ = vlen / sizeof(float); // bf16 works on ymms
        is_avx2_ne_xf16_
                = isa == avx2 && mayiuse(avx2_vnni_2) && (is_bf16_ || is_f16_);
        axis_simd_full_ = pd_->axis_size() / simd_w_;
        axis_simd_tail_ = pd_->axis_size() % simd_w_;
        // int8 destinations keep f32 intermediates in a scratch buffer
        need_scratchpad_ = utils::one_of(
                dst_d_.data_type(), data_type::u8, data_type::s8);

        io::io_conf_t io_conf;
        io::io_tail_conf_t io_tail_conf(simd_w_, axis_simd_tail_,
                tail_opmask_idx_, tail_vmask.getIdx(), reg_tmp);
        io::io_emu_bf16_conf_t io_bf16_conf(bf16_emu_zmm_1_idx_,
                bf16_emu_zmm_2_idx_, bf16_emu_zmm_3_idx_, reg_tmp,
                bf16_emu_zmm_4_idx_);
        io::io_saturation_conf_t io_saturation_conf(
                vzero.getIdx(), vsaturation_ubound.getIdx(), reg_tmp);
        io_ = io::jit_io_multi_dt_helper_t<Vmm>(this, get_io_isa(),
                {src_d_.data_type(), dst_d_.data_type(),
                        data_type::f32 /* stats */},
                io_conf, io_tail_conf, io_bf16_conf,
                {{dst_d_.data_type(), io_saturation_conf}});
    }
};
702 | |
// Forward primitive: owns the driver wrapping the generated fwd kernel.
template <cpu_isa_t isa>
jit_uni_softmax_fwd_t<isa>::jit_uni_softmax_fwd_t(const pd_t *apd)
    : primitive_t(apd)
    , softmax_driver_(new softmax_impl::driver_t<isa>(pd())) {}
707 | |
template <cpu_isa_t isa>
jit_uni_softmax_fwd_t<isa>::~jit_uni_softmax_fwd_t() {
    // softmax_driver_ is a raw owning pointer allocated in the ctor
    delete softmax_driver_;
}
712 | |
713 | template <cpu_isa_t isa> |
714 | status_t jit_uni_softmax_fwd_t<isa>::init(engine_t *engine) { |
715 | return softmax_driver_->create_kernel(); |
716 | } |
717 | |
// Forward execution: decomposes the tensor into outer x axis x inner and
// runs the jit kernel once per (outer, inner) coordinate in parallel.
template <cpu_isa_t isa>
status_t jit_uni_softmax_fwd_t<isa>::execute(const exec_ctx_t &ctx) const {
    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
    // f32 scratch for int8 destinations; null when not reserved by the pd
    auto scratchpad_ptr = ctx.get_scratchpad_grantor().template get<char>(
            memory_tracking::names::key_softmax_interim_store);

    DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
    DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);

    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const auto src_data_type_size = src_d.data_type_size();
    const auto dst_data_type_size = dst_d.data_type_size();
    const auto &bd = src_d.blocking_desc();
    const auto axis = pd()->axis();

    const auto axis_size_padded = pd()->axis_size(true);
    // innermost block size for blocked layouts, 1 for plain ones
    const auto inner_stride
            = bd.inner_nblks ? bd.inner_blks[bd.inner_nblks - 1] : (dim_t)1;
    const auto inner_size = bd.strides[axis] / inner_stride;
    const auto process_n_elems = pd()->axis_size() * inner_size;
    const auto outer_stride = axis_size_padded * inner_size;
    const auto outer_size = src_d.nelems(true) / outer_stride;

    const int nthr = pd()->nthr_;

    parallel_nd_ext(nthr, outer_size, inner_size,
            [&](int ithr, int, dim_t ou, dim_t in) {
                dim_t offset = (ou * outer_stride + in * inner_stride);
                const char *src_ptr = src + offset * src_data_type_size;
                char *dst_ptr = dst + offset * dst_data_type_size;
                // each thread gets its own axis_size_padded floats of scratch
                char *interim_ptr = scratchpad_ptr ? scratchpad_ptr
                                + ithr * axis_size_padded * sizeof(float)
                                                   : nullptr;
                softmax_driver_->exec(src_ptr, dst_ptr, interim_ptr, src_scales,
                        dst_scales, process_n_elems);
            });

    return status::success;
}
759 | |
// Backward primitive: owns the driver wrapping the generated bwd kernel.
template <cpu_isa_t isa>
jit_uni_softmax_bwd_t<isa>::jit_uni_softmax_bwd_t(const pd_t *apd)
    : primitive_t(apd)
    , softmax_driver_(new softmax_impl::driver_t<isa>(pd())) {}
764 | |
template <cpu_isa_t isa>
jit_uni_softmax_bwd_t<isa>::~jit_uni_softmax_bwd_t() {
    // softmax_driver_ is a raw owning pointer allocated in the ctor
    delete softmax_driver_;
}
769 | |
770 | template <cpu_isa_t isa> |
771 | status_t jit_uni_softmax_bwd_t<isa>::init(engine_t *engine) { |
772 | return softmax_driver_->create_kernel(); |
773 | } |
774 | |
// Backward execution: same outer x axis x inner decomposition as forward;
// no scratchpad or scale buffers are involved on this path.
template <cpu_isa_t isa>
status_t jit_uni_softmax_bwd_t<isa>::execute(const exec_ctx_t &ctx) const {
    auto dst = CTX_IN_MEM(const char *, DNNL_ARG_DST);
    auto diff_dst = CTX_IN_MEM(const char *, DNNL_ARG_DIFF_DST);
    auto diff_src = CTX_OUT_MEM(char *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
    const auto dst_data_type_size = dst_d.data_type_size();
    const auto diff_dst_data_type_size = diff_dst_d.data_type_size();
    const auto diff_src_data_type_size = diff_src_d.data_type_size();
    const auto &bd = dst_d.blocking_desc();
    const auto axis = pd()->axis();

    // innermost block size for blocked layouts, 1 for plain ones
    const auto inner_stride
            = bd.inner_nblks ? bd.inner_blks[bd.inner_nblks - 1] : (dim_t)1;
    const auto inner_size = bd.strides[axis] / inner_stride;
    const auto process_n_elems = pd()->axis_size() * inner_size;
    const auto outer_stride = pd()->axis_size(true) * inner_size;
    const auto outer_size = dst_d.nelems(true) / outer_stride;

    parallel_nd(outer_size, inner_size, [&](dim_t ou, dim_t in) {
        dim_t offset = (ou * outer_stride + in * inner_stride);
        char *diff_src_ptr = diff_src + offset * diff_src_data_type_size;
        const char *dst_ptr = dst + offset * dst_data_type_size;
        const char *diff_dst_ptr = diff_dst + offset * diff_dst_data_type_size;
        softmax_driver_->exec(
                diff_src_ptr, dst_ptr, diff_dst_ptr, process_n_elems);
    });

    return status::success;
}
808 | |
809 | namespace softmax_impl { |
810 | |
811 | template <cpu_isa_t isa> |
812 | struct driver_t : public c_compatible { |
813 | |
814 | driver_t(const softmax_pd_t *pd) : pd_(pd), ker_(pd_) {} |
815 | |
816 | void exec(const void *src, void *dst, void *interim, const void *src_scales, |
817 | const void *dst_scales, const dim_t process_n_elems) { |
818 | typename jit_softmax_t<isa>::call_params_t p; |
819 | p.process_n_elems = process_n_elems; |
820 | p.src = src; |
821 | p.dst = dst; |
822 | p.interim = interim; |
823 | p.src_scales = src_scales; |
824 | p.dst_scales = dst_scales; |
825 | ker_(&p); |
826 | } |
827 | |
828 | void exec(void *diff_src, const void *dst, const void *diff_dst, |
829 | const dim_t process_n_elems) { |
830 | typename jit_softmax_t<isa>::call_params_t p; |
831 | p.process_n_elems = process_n_elems; |
832 | p.src = diff_src; |
833 | p.dst = dst; |
834 | p.diff_dst = diff_dst; |
835 | ker_(&p); |
836 | } |
837 | |
838 | status_t create_kernel() { return ker_.create_kernel(); } |
839 | |
840 | private: |
841 | const softmax_pd_t *pd_; |
842 | jit_softmax_t<isa> ker_; |
843 | }; |
844 | |
845 | } // namespace softmax_impl |
846 | |
/* struct instantiation: fwd is provided for sse41/avx2/avx512_core, while
 * bwd is provided only for avx512_core */
template struct jit_uni_softmax_fwd_t<sse41>;
template struct jit_uni_softmax_fwd_t<avx2>;
template struct jit_uni_softmax_fwd_t<avx512_core>;
template struct jit_uni_softmax_bwd_t<avx512_core>;
852 | |
853 | } // namespace x64 |
854 | } // namespace cpu |
855 | } // namespace impl |
856 | } // namespace dnnl |
857 | |