/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17#ifndef CPU_X64_UNI_BINARY_KERNEL_HPP
18#define CPU_X64_UNI_BINARY_KERNEL_HPP
19
20#include <cassert>
21
22#include "common/c_types_map.hpp"
23#include "common/type_helpers.hpp"
24#include "common/utils.hpp"
25
26#include "cpu/x64/cpu_isa_traits.hpp"
27#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
28#include "cpu/x64/jit_generator.hpp"
29#include "cpu/x64/jit_primitive_conf.hpp"
30#include "cpu/x64/utils/jit_io_helper.hpp"
31
32#include "cpu/cpu_binary_pd.hpp"
33
34namespace dnnl {
35namespace impl {
36namespace cpu {
37namespace x64 {
38
39using namespace Xbyak;
40
41struct binary_kernel_t : public jit_generator {
42 using op_t = binary_op_t;
43 using bcast_t = binary_bcast_t;
44
45 binary_kernel_t(const size_t vlen, const binary_pd_t *pd,
46 const jit_binary_conf_t conf, const char *name,
47 bool tail_kernel = false);
48 ~binary_kernel_t() override = default;
49
50 void operator()(jit_binary_call_s *p) { jit_generator::operator()(p); }
51
52 size_t simd_w() const noexcept { return simd_w_; }
53 size_t vlen() const noexcept { return vlen_; }
54
55protected:
56 size_t get_tail_size() const;
57
58 const size_t vlen_;
59 const size_t simd_w_;
60 constexpr static int vmm_start_idx_ = 1;
61 const binary_pd_t *pd_;
62 const jit_binary_conf_t conf_;
63 const bool is_tail_kernel_;
64 const bool is_src1_outer_dims_tail_;
65 const size_t tail_size_;
66 const size_t padding_tail_size_;
67};
68
69template <cpu_isa_t isa, typename Vmm>
70struct jit_uni_binary_kernel_t : public binary_kernel_t {
71 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_binary_kernel_t)
72
73 const AddressFrame &vmmword = (std::is_same<Vmm, Xmm>::value)
74 ? xword
75 : ((std::is_same<Vmm, Ymm>::value) ? yword : zword);
76
77 const bool is_avx512 = is_superset(isa, avx512_core);
78
79 const Reg64 reg_param_ = abi_param1;
80 const Reg64 reg_src0_ = r8;
81 const Reg64 reg_src1_ = r9;
82 const Reg64 reg_dst_ = r10;
83 const Reg64 reg_offt_src0_ = r11;
84 const Reg64 reg_outer_dims_range_ = r12;
85 const Reg64 reg_offt_src1_ = rax;
86 const Reg64 reg_src1_stride_range_ = r15;
87 const Reg64 reg_reverse_src1_stride_range_ = rax;
88 const Reg64 reg_reverse_spat_offt_ = r13;
89 const Reg64 reg_tmp_ = r14;
90 const Reg64 reg_tmp1_ = abi_not_param1;
91 const Reg64 reg_elt_inj_table_ = r15;
92 const Reg64 reg_off_rhs_postops_ = rdx;
93 const Reg64 reg_scales_src0_ = rbx;
94 const Reg64 reg_scales_src1_ = rbp;
95 const Reg64 reg_offt_dst_ = rdx;
96 const Opmask tail_opmask_ = k2;
97 const Opmask cmp_mask = k3;
98 const Opmask full_mask_ = k4;
99 const Vmm vmm_tail_vmask_ = Vmm(0);
100 const Vmm vreg_sum_scale_ = Vmm(is_avx512 ? 17 : 9);
101 const Xmm xreg_sum_scale_ = Xmm(9);
102 const Vmm vreg_zero_ = Vmm(is_avx512 ? 18 : 10);
103 const Vmm vreg_one_ = Vmm(is_avx512 ? 19 : 11);
104 const Vmm vreg_saturation_ubound_ = Vmm(is_avx512 ? 20 : 12);
105 const Vmm vreg_bcast_src1_ = Vmm(is_avx512 ? 21 : 13);
106 const Xmm xreg_bcast_src1_ = Xmm(13);
107 const Vmm vreg_scales_src0_ = Vmm(is_avx512 ? 22 : 14);
108 const Vmm vreg_scales_src1_ = Vmm(is_avx512 ? 23 : 15);
109
110 const Zmm vreg_bf16_emu_1_ = Zmm(26);
111 const Zmm vreg_bf16_emu_2_ = Zmm(27);
112 const Zmm vreg_bf16_emu_3_ = Zmm(28);
113 const Zmm vreg_bf16_emu_4_ = Zmm(29);
114
115 const Vmm vmm_full_mask_ = Vmm(is_avx512 ? 24 : 5);
116 const Vmm vmm_tmp_gather_ = Vmm(is_avx512 ? 25 : 6);
117 const Vmm vmm_indices_ = Vmm(is_avx512 ? 30 : 7);
118 const Vmm vmm_gathered_src_ = Vmm(is_avx512 ? 31 : 8);
119
120 const size_t unroll_regs_ = is_avx512 ? 8 : 4;
121 const size_t offt_src0_;
122 const size_t offt_src1_;
123
124 static constexpr cpu_isa_t inject_isa
125 = isa == avx512_core_bf16 ? avx512_core : isa;
126 io::jit_io_multi_dt_helper_t<Vmm> io_;
127 std::unique_ptr<injector::jit_uni_postops_injector_t<inject_isa, Vmm>>
128 postops_injector_;
129 const Opmask elt_inj_opmask_ = k1;
130
131 void init();
132 void init_post_ops_injector();
133 void apply_postops(int unroll, bool tail);
134 void load_kernel_params();
135 Address src0_ptr(size_t offt = 0);
136 Address src1_ptr(size_t offt = 0);
137 Address dst_ptr(size_t offt = 0);
138 unsigned int cmp_predicate(alg_kind_t alg);
139 void perform_op(
140 const Vmm &v0, const Vmm &v1, const Vmm &s_src0, const Vmm &s_src1);
141 void prepare_isa_kernel();
142 void compute_bcast(bool tail);
143 void load_src1(const Vmm &vreg_src1, const int offt, bool tail);
144 void compute_dst(int unroll, bool tail);
145 void forward();
146 void forward_over_outer_dims();
147 void generate() override;
148
149 jit_uni_binary_kernel_t(const binary_pd_t *pd, const jit_binary_conf_t conf,
150 bool tail_kernel = false);
151 ~jit_uni_binary_kernel_t() override = default;
152
153 std::map<data_type_t, io::io_saturation_conf_t>
154 create_saturation_vmm_map() const;
155};
156
157} // namespace x64
158} // namespace cpu
159} // namespace impl
160} // namespace dnnl
161
162#endif
163