1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_UNI_BINARY_KERNEL_HPP |
18 | #define CPU_X64_UNI_BINARY_KERNEL_HPP |
19 | |
20 | #include <cassert> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | #include "common/utils.hpp" |
25 | |
26 | #include "cpu/x64/cpu_isa_traits.hpp" |
27 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
28 | #include "cpu/x64/jit_generator.hpp" |
29 | #include "cpu/x64/jit_primitive_conf.hpp" |
30 | #include "cpu/x64/utils/jit_io_helper.hpp" |
31 | |
32 | #include "cpu/cpu_binary_pd.hpp" |
33 | |
34 | namespace dnnl { |
35 | namespace impl { |
36 | namespace cpu { |
37 | namespace x64 { |
38 | |
39 | using namespace Xbyak; |
40 | |
41 | struct binary_kernel_t : public jit_generator { |
42 | using op_t = binary_op_t; |
43 | using bcast_t = binary_bcast_t; |
44 | |
45 | binary_kernel_t(const size_t vlen, const binary_pd_t *pd, |
46 | const jit_binary_conf_t conf, const char *name, |
47 | bool tail_kernel = false); |
48 | ~binary_kernel_t() override = default; |
49 | |
50 | void operator()(jit_binary_call_s *p) { jit_generator::operator()(p); } |
51 | |
52 | size_t simd_w() const noexcept { return simd_w_; } |
53 | size_t vlen() const noexcept { return vlen_; } |
54 | |
55 | protected: |
56 | size_t get_tail_size() const; |
57 | |
58 | const size_t vlen_; |
59 | const size_t simd_w_; |
60 | constexpr static int vmm_start_idx_ = 1; |
61 | const binary_pd_t *pd_; |
62 | const jit_binary_conf_t conf_; |
63 | const bool is_tail_kernel_; |
64 | const bool is_src1_outer_dims_tail_; |
65 | const size_t tail_size_; |
66 | const size_t padding_tail_size_; |
67 | }; |
68 | |
69 | template <cpu_isa_t isa, typename Vmm> |
70 | struct jit_uni_binary_kernel_t : public binary_kernel_t { |
71 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_binary_kernel_t) |
72 | |
73 | const AddressFrame &vmmword = (std::is_same<Vmm, Xmm>::value) |
74 | ? xword |
75 | : ((std::is_same<Vmm, Ymm>::value) ? yword : zword); |
76 | |
77 | const bool is_avx512 = is_superset(isa, avx512_core); |
78 | |
79 | const Reg64 reg_param_ = abi_param1; |
80 | const Reg64 reg_src0_ = r8; |
81 | const Reg64 reg_src1_ = r9; |
82 | const Reg64 reg_dst_ = r10; |
83 | const Reg64 reg_offt_src0_ = r11; |
84 | const Reg64 reg_outer_dims_range_ = r12; |
85 | const Reg64 reg_offt_src1_ = rax; |
86 | const Reg64 reg_src1_stride_range_ = r15; |
87 | const Reg64 reg_reverse_src1_stride_range_ = rax; |
88 | const Reg64 reg_reverse_spat_offt_ = r13; |
89 | const Reg64 reg_tmp_ = r14; |
90 | const Reg64 reg_tmp1_ = abi_not_param1; |
91 | const Reg64 reg_elt_inj_table_ = r15; |
92 | const Reg64 reg_off_rhs_postops_ = rdx; |
93 | const Reg64 reg_scales_src0_ = rbx; |
94 | const Reg64 reg_scales_src1_ = rbp; |
95 | const Reg64 reg_offt_dst_ = rdx; |
96 | const Opmask tail_opmask_ = k2; |
97 | const Opmask cmp_mask = k3; |
98 | const Opmask full_mask_ = k4; |
99 | const Vmm vmm_tail_vmask_ = Vmm(0); |
100 | const Vmm vreg_sum_scale_ = Vmm(is_avx512 ? 17 : 9); |
101 | const Xmm xreg_sum_scale_ = Xmm(9); |
102 | const Vmm vreg_zero_ = Vmm(is_avx512 ? 18 : 10); |
103 | const Vmm vreg_one_ = Vmm(is_avx512 ? 19 : 11); |
104 | const Vmm vreg_saturation_ubound_ = Vmm(is_avx512 ? 20 : 12); |
105 | const Vmm vreg_bcast_src1_ = Vmm(is_avx512 ? 21 : 13); |
106 | const Xmm xreg_bcast_src1_ = Xmm(13); |
107 | const Vmm vreg_scales_src0_ = Vmm(is_avx512 ? 22 : 14); |
108 | const Vmm vreg_scales_src1_ = Vmm(is_avx512 ? 23 : 15); |
109 | |
110 | const Zmm vreg_bf16_emu_1_ = Zmm(26); |
111 | const Zmm vreg_bf16_emu_2_ = Zmm(27); |
112 | const Zmm vreg_bf16_emu_3_ = Zmm(28); |
113 | const Zmm vreg_bf16_emu_4_ = Zmm(29); |
114 | |
115 | const Vmm vmm_full_mask_ = Vmm(is_avx512 ? 24 : 5); |
116 | const Vmm vmm_tmp_gather_ = Vmm(is_avx512 ? 25 : 6); |
117 | const Vmm vmm_indices_ = Vmm(is_avx512 ? 30 : 7); |
118 | const Vmm vmm_gathered_src_ = Vmm(is_avx512 ? 31 : 8); |
119 | |
120 | const size_t unroll_regs_ = is_avx512 ? 8 : 4; |
121 | const size_t offt_src0_; |
122 | const size_t offt_src1_; |
123 | |
124 | static constexpr cpu_isa_t inject_isa |
125 | = isa == avx512_core_bf16 ? avx512_core : isa; |
126 | io::jit_io_multi_dt_helper_t<Vmm> io_; |
127 | std::unique_ptr<injector::jit_uni_postops_injector_t<inject_isa, Vmm>> |
128 | postops_injector_; |
129 | const Opmask elt_inj_opmask_ = k1; |
130 | |
131 | void init(); |
132 | void init_post_ops_injector(); |
133 | void apply_postops(int unroll, bool tail); |
134 | void load_kernel_params(); |
135 | Address src0_ptr(size_t offt = 0); |
136 | Address src1_ptr(size_t offt = 0); |
137 | Address dst_ptr(size_t offt = 0); |
138 | unsigned int cmp_predicate(alg_kind_t alg); |
139 | void perform_op( |
140 | const Vmm &v0, const Vmm &v1, const Vmm &s_src0, const Vmm &s_src1); |
141 | void prepare_isa_kernel(); |
142 | void compute_bcast(bool tail); |
143 | void load_src1(const Vmm &vreg_src1, const int offt, bool tail); |
144 | void compute_dst(int unroll, bool tail); |
145 | void forward(); |
146 | void forward_over_outer_dims(); |
147 | void generate() override; |
148 | |
149 | jit_uni_binary_kernel_t(const binary_pd_t *pd, const jit_binary_conf_t conf, |
150 | bool tail_kernel = false); |
151 | ~jit_uni_binary_kernel_t() override = default; |
152 | |
153 | std::map<data_type_t, io::io_saturation_conf_t> |
154 | create_saturation_vmm_map() const; |
155 | }; |
156 | |
157 | } // namespace x64 |
158 | } // namespace cpu |
159 | } // namespace impl |
160 | } // namespace dnnl |
161 | |
162 | #endif |
163 | |