1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_AVX512_CORE_BF16_1X1_CONV_KERNEL_HPP |
18 | #define CPU_X64_JIT_AVX512_CORE_BF16_1X1_CONV_KERNEL_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | |
22 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
23 | #include "cpu/x64/jit_avx512_core_bf16cvt.hpp" |
24 | #include "cpu/x64/jit_generator.hpp" |
25 | #include "cpu/x64/jit_primitive_conf.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
32 | struct jit_avx512_core_bf16_1x1_conv_kernel : public jit_generator { |
33 | jit_avx512_core_bf16_1x1_conv_kernel(const jit_1x1_conv_conf_t &ajcp, |
34 | const primitive_attr_t &attr, const memory_desc_t &dst_md); |
35 | |
36 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_bf16_1x1_conv_kernel) |
37 | |
38 | static status_t init_conf(jit_1x1_conv_conf_t &jcp, |
39 | const convolution_desc_t &cd, const memory_desc_wrapper &src_d, |
40 | const memory_desc_wrapper &weights_d, |
41 | const memory_desc_wrapper &dst_d, primitive_attr_t &attr, |
42 | int nthreads, bool reduce_src); |
43 | |
44 | static status_t init_scratchpad(memory_tracking::registrar_t &scratchpad, |
45 | const jit_1x1_conv_conf_t &jcp); |
46 | |
47 | const jit_1x1_conv_conf_t &jcp; |
48 | const primitive_attr_t &attr_; |
49 | |
50 | private: |
51 | constexpr static int isa_simd_width_ |
52 | = cpu_isa_traits<avx512_core>::vlen / sizeof(float); |
53 | std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>> |
54 | postops_injector_; |
55 | |
56 | using reg64_t = const Xbyak::Reg64; |
57 | using zmm_t = const Xbyak::Zmm; |
58 | using mask_t = const Xbyak::Opmask; |
59 | enum { |
60 | ker_code_size = 1024 * 1024, |
61 | }; |
62 | |
63 | reg64_t aux_reg_load_data = r15; |
64 | reg64_t aux_reg_bcast_data = r14; |
65 | reg64_t reg_output_stride = rsi; |
66 | reg64_t reg_bias_data = r12; |
67 | reg64_t reg_reduce_loop_work = r11; |
68 | reg64_t reg_load_data = r10; |
69 | reg64_t reg_output_data = r9; |
70 | reg64_t reg_bcast_data = r8; |
71 | reg64_t reg_reduce_pos_flag = rax; |
72 | reg64_t aux1_reg_bcast_data = rbx; |
73 | reg64_t aux_reg_output_data = abi_not_param1; |
74 | reg64_t reg_bcast_loop_iter = rdx; |
75 | reg64_t reg_load_loop_work = r13; |
76 | reg64_t reduce_loop_iter = abi_param1; |
77 | reg64_t reg_load_dim_tail_mask = aux_reg_load_data; |
78 | |
79 | reg64_t imm_addr64 = aux_reg_load_data; |
80 | reg64_t reg_bcast_loop_work = aux1_reg_bcast_data; |
81 | reg64_t reg_trans_tmp = reg_reduce_pos_flag; |
82 | reg64_t reg_store_buf |
83 | = reg_output_stride; // reg_output_stride used only in BWD/WU |
84 | reg64_t aux_reg_store_buf = reg_load_loop_work; |
85 | reg64_t reg_tmp = r12; |
86 | |
87 | mask_t vmask = k7; |
88 | // Used for axb tail handling. |
89 | // k_load_dim_mask is dynamically updated with k_load_mask_tail_mask |
90 | // whenever tail is detected |
91 | mask_t k_load_dim_mask = Xbyak::Opmask(2); |
92 | mask_t k_load_dim_mask_extended = Xbyak::Opmask(3); |
93 | mask_t k_load_dim_tail_mask = Xbyak::Opmask(4); |
94 | mask_t k_load_dim_tail_mask_extended = Xbyak::Opmask(5); |
95 | |
96 | Xbyak::Xmm xmm_relu_ns = Xbyak::Xmm(30); |
97 | Xbyak::Zmm zmm_relu_ns = Xbyak::Zmm(30); |
98 | Xbyak::Zmm zmm_zero = Xbyak::Zmm(31); |
99 | Xbyak::Zmm vreg_bcast = Xbyak::Zmm(31); |
100 | |
101 | Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(25); |
102 | Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(26); |
103 | Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(27); |
104 | reg64_t bf16_emu_reserv_4 = imm_addr64; |
105 | Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(28); |
106 | Xbyak::Zmm bf16_emu_reserv_6 = Xbyak::Zmm(29); |
107 | |
108 | Xbyak::Zmm zmm_tmp2 = Xbyak::Zmm(30); |
109 | |
110 | Xbyak::Opmask full_mask = Xbyak::Opmask(7); |
111 | Xbyak::Opmask half_mask = Xbyak::Opmask(6); |
112 | Xbyak::Opmask half_mask_hi = Xbyak::Opmask(5); |
113 | Xbyak::Label dst_prm_table; |
114 | |
115 | constexpr static int reg64_size_ = sizeof(int64_t); |
116 | constexpr static int bcast_loop_work_offt = 0; |
117 | constexpr static int reg_load_loop_work_off = 1 * reg64_size_; |
118 | constexpr static int perm_reg_offset = 2 * reg64_size_; |
119 | constexpr static int broadcast_space = 3 * reg64_size_; |
120 | constexpr static int reg_binary_post_op_acc_off = 4 * reg64_size_; |
121 | constexpr static int reg_abi_param1_backup = 5 * reg64_size_; |
122 | constexpr static int stack_space_needed = 376; |
123 | |
124 | void bcast_loop(int load_loop_blk); |
125 | void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound); |
126 | void compute_diff_bias(int load_loop_blk); |
127 | |
128 | Xbyak::Address output_ptr(const int i_load, const int i_ur); |
129 | void apply_postops(const int load_loop_blk, const int ur); |
130 | void generate() override; |
131 | static void balance(jit_1x1_conv_conf_t &jcp, int nthreads); |
132 | inline bool is_bcast_layout_nxc() { |
133 | switch (jcp.prop_kind) { |
134 | case prop_kind::forward_training: |
135 | case prop_kind::forward_inference: |
136 | return utils::one_of(jcp.src_tag, format_tag::ndhwc, |
137 | format_tag::nhwc, format_tag::nwc); |
138 | case prop_kind::backward_data: |
139 | return utils::one_of(jcp.dst_tag, format_tag::ndhwc, |
140 | format_tag::nhwc, format_tag::nwc); |
141 | case prop_kind::backward_weights: |
142 | return jcp.uses_permw_transposition |
143 | && utils::one_of(jcp.src_tag, format_tag::ndhwc, |
144 | format_tag::nhwc, format_tag::nwc); |
145 | default: assert(!"invalid prop_kind" ); return false; |
146 | } |
147 | } |
148 | inline bool is_load_layout_nxc() { |
149 | return jcp.prop_kind == prop_kind::backward_weights |
150 | && jcp.uses_permw_transposition |
151 | && utils::one_of(jcp.dst_tag, format_tag::ndhwc, |
152 | format_tag::nhwc, format_tag::nwc); |
153 | } |
154 | inline bool is_out_layout_nxc() { |
155 | switch (jcp.prop_kind) { |
156 | case prop_kind::forward_training: |
157 | case prop_kind::forward_inference: |
158 | return utils::one_of(jcp.dst_tag, format_tag::ndhwc, |
159 | format_tag::nhwc, format_tag::nwc); |
160 | case prop_kind::backward_data: |
161 | return utils::one_of(jcp.src_tag, format_tag::ndhwc, |
162 | format_tag::nhwc, format_tag::nwc); |
163 | case prop_kind::backward_weights: return false; |
164 | default: assert(!"invalid prop_kind" ); return false; |
165 | } |
166 | } |
167 | |
168 | inline Xbyak::Zmm may_be_mask_zmm(Xbyak::Zmm zmm, bool mask_flag, |
169 | bool zero_mask, bool use_extended_mask = false) { |
170 | if (mask_flag) { |
171 | zmm = zmm |
172 | | (use_extended_mask ? k_load_dim_mask_extended |
173 | : k_load_dim_mask); |
174 | if (zero_mask) zmm = zmm | T_z; |
175 | } |
176 | return zmm; |
177 | } |
178 | |
179 | inline Xbyak::Ymm may_be_mask_ymm( |
180 | Xbyak::Ymm ymm, bool mask_flag, bool zero_mask = false) { |
181 | if (mask_flag) { |
182 | ymm = ymm | k_load_dim_mask; |
183 | if (zero_mask) ymm = ymm | T_z; |
184 | } |
185 | return ymm; |
186 | } |
187 | |
188 | inline size_t get_output_offset(const int i_load, const int i_ur) { |
189 | const bool is_output_layout_nxc = is_out_layout_nxc(); |
190 | const size_t i_load_shift = is_output_layout_nxc |
191 | ? jcp.load_block |
192 | : (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) * jcp.load_block; |
193 | const size_t i_ur_shift |
194 | = is_output_layout_nxc ? jcp.load_dim : jcp.load_block; |
195 | return jcp.typesize_out * (i_load * i_load_shift + i_ur * i_ur_shift); |
196 | } |
197 | |
198 | std::unique_ptr<bf16_emulation_t> bf16_emu_; |
199 | }; |
200 | } // namespace x64 |
201 | } // namespace cpu |
202 | } // namespace impl |
203 | } // namespace dnnl |
204 | |
205 | #endif |
206 | |