1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/x64/jit_brgemm_conv_comp_pad_kernel.hpp" |
18 | #include "cpu/x64/jit_brgemm_conv_utils.hpp" |
19 | |
20 | namespace dnnl { |
21 | namespace impl { |
22 | namespace cpu { |
23 | namespace x64 { |
24 | |
25 | using namespace dnnl::impl::utils; |
26 | using namespace nstl; |
27 | using namespace data_type; |
28 | |
29 | namespace jit_avx512_core_brgemm_conv_comp_pad_kernel { |
30 | |
31 | #define GET_OFF(field) offsetof(jit_brgemm_conv_comp_pad_call_s, field) |
32 | |
// Kernel generating compensation values for zero-padded brgemm convolution
// (s8s8 weight shift and/or source zero-point compensation).
jit_avx512_core_brgemm_conv_comp_pad_kernel_t::
        jit_avx512_core_brgemm_conv_comp_pad_kernel_t(
                const jit_brgemm_conv_conf_t &ajcp)
    : jit_generator(jit_name())
    , jcp_(ajcp)
    , inp_dsz_(jcp_.wei_dsz) // input to this kernel is the weights tensor
    , out_dsz_(jcp_.acc_dsz) // output is written in accumulator precision
    , nb_ic_(utils::div_up(jcp_.ic, 4)) // ic is processed in groups of 4
    // byte strides within the weights layout:
    , inp_ic_sz_(static_cast<size_t>(inp_dsz_) * jcp_.oc_block * 4) // one ic group
    , inp_kw_sz_(static_cast<size_t>(inp_dsz_) * jcp_.icp * jcp_.oc_block) // one kw step
    , inp_kh_sz_(static_cast<size_t>(jcp_.kw) * inp_kw_sz_) // one kh step
    , inp_kd_sz_(static_cast<size_t>(jcp_.kh) * inp_kh_sz_) {} // one kd step
45 | |
46 | size_t jit_avx512_core_brgemm_conv_comp_pad_kernel_t::out_oc_offset( |
47 | const int n) const { |
48 | return static_cast<size_t>(out_dsz_) * n * m_block2_; |
49 | } |
50 | size_t jit_avx512_core_brgemm_conv_comp_pad_kernel_t::inp_ic_offset( |
51 | const int m_block, const int icb, const int m, const int n) const { |
52 | return static_cast<size_t>(inp_dsz_) * n * m_block2_ * last_ic_block_ |
53 | + ((icb * m_block) + m) * inp_ic_sz_; |
54 | } |
55 | Xbyak::Zmm jit_avx512_core_brgemm_conv_comp_pad_kernel_t::accum( |
56 | const int n_block, const int m, const int n) const { |
57 | return Xbyak::Zmm(m * n_block + n); |
58 | } |
59 | |
60 | void jit_avx512_core_brgemm_conv_comp_pad_kernel_t::store_accumulators( |
61 | const int m_block, const int n_block) { |
62 | if (jcp_.src_zero_point) { |
63 | for_(int m = 0; m < m_block; m++) |
64 | for (int n = 0; n < n_block; n++) { |
65 | auto zmm = accum(n_block, m, n); |
66 | auto zmm_tmp = zmm_tmp_1(); |
67 | const auto offset = out_oc_offset(n); |
68 | auto zp_addr = ptr[reg_zp_comp_out + offset]; |
69 | |
70 | vpmulld(zmm_tmp, zmm, zmm_zp_shift); |
71 | vpaddd(zmm_tmp, zmm_tmp, zp_addr); |
72 | vmovups(zp_addr, zmm_tmp); |
73 | } |
74 | } |
75 | |
76 | if (jcp_.s8s8_avx512) { |
77 | for_(int m = 0; m < m_block; m++) |
78 | for (int n = 0; n < n_block; n++) { |
79 | auto zmm = accum(n_block, m, n); |
80 | auto zmm_tmp = zmm_tmp_1(); |
81 | const auto offset = out_oc_offset(n); |
82 | auto cp_addr = ptr[reg_comp_out + offset]; |
83 | |
84 | vpmulld(zmm_tmp, zmm, zmm_cp_shift); |
85 | vpaddd(zmm_tmp, zmm_tmp, cp_addr); |
86 | vmovups(cp_addr, zmm_tmp); |
87 | } |
88 | } |
89 | } |
90 | |
91 | void jit_avx512_core_brgemm_conv_comp_pad_kernel_t::zero_accumulators( |
92 | const int m_block, const int n_block) { |
93 | for_(int m = 0; m < m_block; m++) |
94 | for (int n = 0; n < n_block; n++) { |
95 | auto zmm = accum(n_block, m, n); |
96 | vpxord(zmm, zmm, zmm); |
97 | } |
98 | } |
99 | |
100 | void jit_avx512_core_brgemm_conv_comp_pad_kernel_t::compute(const int ic_step, |
101 | const int m_block, const int n_block, const int m_tail, |
102 | const bool is_mb_tail) { |
103 | |
104 | for_(int ic = 0; ic < ic_step; ++ic) |
105 | for (int m = 0; m < m_block; ++m) { |
106 | if (is_mb_tail && (ic * m_block + m) >= m_tail) break; |
107 | for (int n = 0; n < n_block; ++n) { |
108 | auto zmm = accum(n_block, m, n); |
109 | const auto oc_offset = inp_ic_offset(m_block, ic, m, n); |
110 | auto addr = EVEX_compress_addr(reg_aux_in, oc_offset); |
111 | vpdpbusd(zmm, zmm_one_bytes, addr); |
112 | } |
113 | } |
114 | } |
115 | |
116 | void jit_avx512_core_brgemm_conv_comp_pad_kernel_t::icb_loop(const int icb, |
117 | const int icb_tail, const int ic_step, const int m_block, |
118 | const int mb_tail, const int n_block) { |
119 | Xbyak::Label label_icb_loop, label_loop_end; |
120 | |
121 | mov(reg_aux_in, reg_aux_kw_in); |
122 | mov(reg_icb, icb); |
123 | |
124 | L(label_icb_loop); |
125 | { |
126 | cmp(reg_icb, 0); |
127 | je(label_loop_end, T_NEAR); |
128 | compute(ic_step, m_block, n_block, 0, false); |
129 | add(reg_aux_in, ic_step * m_block * inp_ic_sz_); |
130 | dec(reg_icb); |
131 | jmp(label_icb_loop, T_NEAR); |
132 | } |
133 | L_aligned(label_loop_end); |
134 | |
135 | if (icb_tail) compute(ic_step, mb_tail, n_block, icb_tail, true); |
136 | } |
137 | |
138 | void jit_avx512_core_brgemm_conv_comp_pad_kernel_t::khw_loop(const int icb, |
139 | const int icb_tail, const int ic_step, const int m_block, |
140 | const int mb_tail, const int n_block) { |
141 | Xbyak::Label label_kw_loop, label_kw_end, label_kh_loop, label_kh_end; |
142 | mov(reg_kh_l, ptr[param1 + GET_OFF(kh_l)]); |
143 | mov(reg_aux_kh_in, reg_in); |
144 | |
145 | L_aligned(label_kh_loop); |
146 | { |
147 | cmp(reg_kh_l, 0); |
148 | je(label_kh_end, T_NEAR); |
149 | mov(reg_kw_l, ptr[param1 + GET_OFF(kw_l)]); |
150 | mov(reg_aux_kw_in, reg_aux_kh_in); |
151 | L_aligned(label_kw_loop); |
152 | { |
153 | cmp(reg_kw_l, 0); |
154 | je(label_kw_end, T_NEAR); |
155 | icb_loop(icb, icb_tail, ic_step, m_block, mb_tail, n_block); |
156 | add(reg_aux_kw_in, inp_kw_sz_); |
157 | dec(reg_kw_l); |
158 | jmp(label_kw_loop, T_NEAR); |
159 | } |
160 | L_aligned(label_kw_end); |
161 | |
162 | add(reg_aux_kh_in, inp_kh_sz_); |
163 | dec(reg_kh_l); |
164 | jmp(label_kh_loop, T_NEAR); |
165 | } |
166 | L_aligned(label_kh_end); |
167 | } |
168 | |
// Load the kernel arguments from the call-parameter structure:
// the weights pointer and the two compensation output buffers.
void jit_avx512_core_brgemm_conv_comp_pad_kernel_t::load_params() {
    mov(reg_in, ptr[param1 + GET_OFF(ptr_in)]);
    mov(reg_zp_comp_out, ptr[param1 + GET_OFF(ptr_zp_out)]);
    mov(reg_comp_out, ptr[param1 + GET_OFF(ptr_cp_out)]);
}
174 | |
175 | int jit_avx512_core_brgemm_conv_comp_pad_kernel_t::compute_ic_step( |
176 | const int m_max_regs, const int m_block, const int n_block) const { |
177 | int best_ic_step = 1; |
178 | float best_block_eff = 0.f; |
179 | |
180 | int max_ic_step |
181 | = nstl::min(static_cast<size_t>(m_block), div_up(nb_ic_, m_block)); |
182 | |
183 | // Introduce ic_step to increase kernel efficiency |
184 | // Compute the ic_step based on the optimal kernel efficiency |
185 | for (int ic_s = max_ic_step; ic_s >= 1; --ic_s) { |
186 | const auto blocks = ic_s * m_block; |
187 | const float block_disb |
188 | = static_cast<float>(nb_ic_) / rnd_up(nb_ic_, blocks); |
189 | const float eff = (static_cast<float>(n_block) * blocks) |
190 | / ((n_block + blocks) * max_ic_step); |
191 | const float block_eff = block_disb * eff; |
192 | float = static_cast<float>(inp_dsz_) * blocks |
193 | * jcp_.oc_block * last_ic_block_; |
194 | if (block_footprint <= static_cast<float>( |
195 | platform::get_per_core_cache_size(1)) |
196 | && (block_eff > best_block_eff)) { |
197 | best_ic_step = ic_s; |
198 | best_block_eff = block_eff; |
199 | } |
200 | } |
201 | |
202 | return best_ic_step; |
203 | } |
204 | |
205 | void jit_avx512_core_brgemm_conv_comp_pad_kernel_t::generate() { |
206 | preamble(); |
207 | |
208 | load_params(); |
209 | |
210 | // fill registers with byte ones |
211 | const auto reg32_scratch = reg_tmp.cvt32(); |
212 | mov(reg32_scratch, 0x1010101); |
213 | vpbroadcastd(zmm_one_bytes, reg32_scratch); |
214 | |
215 | // fill register with -128 && -1 |
216 | mov(reg32_scratch, -128); |
217 | vpbroadcastd(zmm_cp_shift, reg32_scratch); |
218 | |
219 | mov(reg32_scratch, -1); |
220 | vpbroadcastd(zmm_zp_shift, reg32_scratch); |
221 | |
222 | const int max_regs = jcp_.s8s8_avx512 ? 28 : 29; |
223 | const int nb = div_up(nstl::min(jcp_.oc, jcp_.oc_block), m_block2_); |
224 | const int nb2 = nb / n_max_regs_; |
225 | const int nb2_tail = nb % n_block2_; |
226 | const int n_block = (nb2 == 0) ? nstl::max(1, nb2_tail) : 4; |
227 | |
228 | const size_t m_max_regs = max_regs / n_block; |
229 | const int m_block = nstl::min(m_max_regs, nb_ic_); |
230 | const int ic_step = compute_ic_step(m_max_regs, m_block, n_block); |
231 | |
232 | const auto blocks = m_block * ic_step; |
233 | const auto icb = nb_ic_ / blocks; |
234 | const auto icb_tail = nb_ic_ % blocks; |
235 | const auto mb_tail = div_up(icb_tail, ic_step); |
236 | |
237 | Xbyak::Label label_kd_loop, label_loop_end; |
238 | mov(reg_kd_l, ptr[param1 + GET_OFF(kd_l)]); |
239 | |
240 | zero_accumulators(m_block, n_block); |
241 | |
242 | L_aligned(label_kd_loop); |
243 | { |
244 | cmp(reg_kd_l, 0); |
245 | je(label_loop_end, T_NEAR); |
246 | khw_loop(icb, icb_tail, ic_step, m_block, mb_tail, n_block); |
247 | add(reg_in, inp_kd_sz_); |
248 | dec(reg_kd_l); |
249 | jmp(label_kd_loop, T_NEAR); |
250 | } |
251 | L_aligned(label_loop_end); |
252 | |
253 | store_accumulators(m_block, n_block); |
254 | |
255 | postamble(); |
256 | } |
257 | |
258 | } // namespace jit_avx512_core_brgemm_conv_comp_pad_kernel |
259 | |
260 | } // namespace x64 |
261 | } // namespace cpu |
262 | } // namespace impl |
263 | } // namespace dnnl |
264 | |
265 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
266 | |