1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_CORE_BF16_1X1_CONV_KERNEL_HPP
18#define CPU_X64_JIT_AVX512_CORE_BF16_1X1_CONV_KERNEL_HPP
19
20#include "common/c_types_map.hpp"
21
22#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
23#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
24#include "cpu/x64/jit_generator.hpp"
25#include "cpu/x64/jit_primitive_conf.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30namespace x64 {
31
32struct jit_avx512_core_bf16_1x1_conv_kernel : public jit_generator {
33 jit_avx512_core_bf16_1x1_conv_kernel(const jit_1x1_conv_conf_t &ajcp,
34 const primitive_attr_t &attr, const memory_desc_t &dst_md);
35
36 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_bf16_1x1_conv_kernel)
37
38 static status_t init_conf(jit_1x1_conv_conf_t &jcp,
39 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
40 const memory_desc_wrapper &weights_d,
41 const memory_desc_wrapper &dst_d, primitive_attr_t &attr,
42 int nthreads, bool reduce_src);
43
44 static status_t init_scratchpad(memory_tracking::registrar_t &scratchpad,
45 const jit_1x1_conv_conf_t &jcp);
46
47 const jit_1x1_conv_conf_t &jcp;
48 const primitive_attr_t &attr_;
49
50private:
51 constexpr static int isa_simd_width_
52 = cpu_isa_traits<avx512_core>::vlen / sizeof(float);
53 std::unique_ptr<injector::jit_uni_postops_injector_t<avx512_core>>
54 postops_injector_;
55
56 using reg64_t = const Xbyak::Reg64;
57 using zmm_t = const Xbyak::Zmm;
58 using mask_t = const Xbyak::Opmask;
59 enum {
60 ker_code_size = 1024 * 1024,
61 };
62
63 reg64_t aux_reg_load_data = r15;
64 reg64_t aux_reg_bcast_data = r14;
65 reg64_t reg_output_stride = rsi;
66 reg64_t reg_bias_data = r12;
67 reg64_t reg_reduce_loop_work = r11;
68 reg64_t reg_load_data = r10;
69 reg64_t reg_output_data = r9;
70 reg64_t reg_bcast_data = r8;
71 reg64_t reg_reduce_pos_flag = rax;
72 reg64_t aux1_reg_bcast_data = rbx;
73 reg64_t aux_reg_output_data = abi_not_param1;
74 reg64_t reg_bcast_loop_iter = rdx;
75 reg64_t reg_load_loop_work = r13;
76 reg64_t reduce_loop_iter = abi_param1;
77 reg64_t reg_load_dim_tail_mask = aux_reg_load_data;
78
79 reg64_t imm_addr64 = aux_reg_load_data;
80 reg64_t reg_bcast_loop_work = aux1_reg_bcast_data;
81 reg64_t reg_trans_tmp = reg_reduce_pos_flag;
82 reg64_t reg_store_buf
83 = reg_output_stride; // reg_output_stride used only in BWD/WU
84 reg64_t aux_reg_store_buf = reg_load_loop_work;
85 reg64_t reg_tmp = r12;
86
87 mask_t vmask = k7;
88 // Used for axb tail handling.
89 // k_load_dim_mask is dynamically updated with k_load_mask_tail_mask
90 // whenever tail is detected
91 mask_t k_load_dim_mask = Xbyak::Opmask(2);
92 mask_t k_load_dim_mask_extended = Xbyak::Opmask(3);
93 mask_t k_load_dim_tail_mask = Xbyak::Opmask(4);
94 mask_t k_load_dim_tail_mask_extended = Xbyak::Opmask(5);
95
96 Xbyak::Xmm xmm_relu_ns = Xbyak::Xmm(30);
97 Xbyak::Zmm zmm_relu_ns = Xbyak::Zmm(30);
98 Xbyak::Zmm zmm_zero = Xbyak::Zmm(31);
99 Xbyak::Zmm vreg_bcast = Xbyak::Zmm(31);
100
101 Xbyak::Zmm bf16_emu_reserv_1 = Xbyak::Zmm(25);
102 Xbyak::Zmm bf16_emu_reserv_2 = Xbyak::Zmm(26);
103 Xbyak::Zmm bf16_emu_reserv_3 = Xbyak::Zmm(27);
104 reg64_t bf16_emu_reserv_4 = imm_addr64;
105 Xbyak::Zmm bf16_emu_reserv_5 = Xbyak::Zmm(28);
106 Xbyak::Zmm bf16_emu_reserv_6 = Xbyak::Zmm(29);
107
108 Xbyak::Zmm zmm_tmp2 = Xbyak::Zmm(30);
109
110 Xbyak::Opmask full_mask = Xbyak::Opmask(7);
111 Xbyak::Opmask half_mask = Xbyak::Opmask(6);
112 Xbyak::Opmask half_mask_hi = Xbyak::Opmask(5);
113 Xbyak::Label dst_prm_table;
114
115 constexpr static int reg64_size_ = sizeof(int64_t);
116 constexpr static int bcast_loop_work_offt = 0;
117 constexpr static int reg_load_loop_work_off = 1 * reg64_size_;
118 constexpr static int perm_reg_offset = 2 * reg64_size_;
119 constexpr static int broadcast_space = 3 * reg64_size_;
120 constexpr static int reg_binary_post_op_acc_off = 4 * reg64_size_;
121 constexpr static int reg_abi_param1_backup = 5 * reg64_size_;
122 constexpr static int stack_space_needed = 376;
123
124 void bcast_loop(int load_loop_blk);
125 void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound);
126 void compute_diff_bias(int load_loop_blk);
127
128 Xbyak::Address output_ptr(const int i_load, const int i_ur);
129 void apply_postops(const int load_loop_blk, const int ur);
130 void generate() override;
131 static void balance(jit_1x1_conv_conf_t &jcp, int nthreads);
132 inline bool is_bcast_layout_nxc() {
133 switch (jcp.prop_kind) {
134 case prop_kind::forward_training:
135 case prop_kind::forward_inference:
136 return utils::one_of(jcp.src_tag, format_tag::ndhwc,
137 format_tag::nhwc, format_tag::nwc);
138 case prop_kind::backward_data:
139 return utils::one_of(jcp.dst_tag, format_tag::ndhwc,
140 format_tag::nhwc, format_tag::nwc);
141 case prop_kind::backward_weights:
142 return jcp.uses_permw_transposition
143 && utils::one_of(jcp.src_tag, format_tag::ndhwc,
144 format_tag::nhwc, format_tag::nwc);
145 default: assert(!"invalid prop_kind"); return false;
146 }
147 }
148 inline bool is_load_layout_nxc() {
149 return jcp.prop_kind == prop_kind::backward_weights
150 && jcp.uses_permw_transposition
151 && utils::one_of(jcp.dst_tag, format_tag::ndhwc,
152 format_tag::nhwc, format_tag::nwc);
153 }
154 inline bool is_out_layout_nxc() {
155 switch (jcp.prop_kind) {
156 case prop_kind::forward_training:
157 case prop_kind::forward_inference:
158 return utils::one_of(jcp.dst_tag, format_tag::ndhwc,
159 format_tag::nhwc, format_tag::nwc);
160 case prop_kind::backward_data:
161 return utils::one_of(jcp.src_tag, format_tag::ndhwc,
162 format_tag::nhwc, format_tag::nwc);
163 case prop_kind::backward_weights: return false;
164 default: assert(!"invalid prop_kind"); return false;
165 }
166 }
167
168 inline Xbyak::Zmm may_be_mask_zmm(Xbyak::Zmm zmm, bool mask_flag,
169 bool zero_mask, bool use_extended_mask = false) {
170 if (mask_flag) {
171 zmm = zmm
172 | (use_extended_mask ? k_load_dim_mask_extended
173 : k_load_dim_mask);
174 if (zero_mask) zmm = zmm | T_z;
175 }
176 return zmm;
177 }
178
179 inline Xbyak::Ymm may_be_mask_ymm(
180 Xbyak::Ymm ymm, bool mask_flag, bool zero_mask = false) {
181 if (mask_flag) {
182 ymm = ymm | k_load_dim_mask;
183 if (zero_mask) ymm = ymm | T_z;
184 }
185 return ymm;
186 }
187
188 inline size_t get_output_offset(const int i_load, const int i_ur) {
189 const bool is_output_layout_nxc = is_out_layout_nxc();
190 const size_t i_load_shift = is_output_layout_nxc
191 ? jcp.load_block
192 : (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) * jcp.load_block;
193 const size_t i_ur_shift
194 = is_output_layout_nxc ? jcp.load_dim : jcp.load_block;
195 return jcp.typesize_out * (i_load * i_load_shift + i_ur * i_ur_shift);
196 }
197
198 std::unique_ptr<bf16_emulation_t> bf16_emu_;
199};
200} // namespace x64
201} // namespace cpu
202} // namespace impl
203} // namespace dnnl
204
205#endif
206