1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "cpu/x64/jit_brgemm_conv_comp_pad_kernel.hpp"
18#include "cpu/x64/jit_brgemm_conv_utils.hpp"
19
20namespace dnnl {
21namespace impl {
22namespace cpu {
23namespace x64 {
24
25using namespace dnnl::impl::utils;
26using namespace nstl;
27using namespace data_type;
28
29namespace jit_avx512_core_brgemm_conv_comp_pad_kernel {
30
31#define GET_OFF(field) offsetof(jit_brgemm_conv_comp_pad_call_s, field)
32
33jit_avx512_core_brgemm_conv_comp_pad_kernel_t::
34 jit_avx512_core_brgemm_conv_comp_pad_kernel_t(
35 const jit_brgemm_conv_conf_t &ajcp)
36 : jit_generator(jit_name())
37 , jcp_(ajcp)
38 , inp_dsz_(jcp_.wei_dsz)
39 , out_dsz_(jcp_.acc_dsz)
40 , nb_ic_(utils::div_up(jcp_.ic, 4))
41 , inp_ic_sz_(static_cast<size_t>(inp_dsz_) * jcp_.oc_block * 4)
42 , inp_kw_sz_(static_cast<size_t>(inp_dsz_) * jcp_.icp * jcp_.oc_block)
43 , inp_kh_sz_(static_cast<size_t>(jcp_.kw) * inp_kw_sz_)
44 , inp_kd_sz_(static_cast<size_t>(jcp_.kh) * inp_kh_sz_) {}
45
46size_t jit_avx512_core_brgemm_conv_comp_pad_kernel_t::out_oc_offset(
47 const int n) const {
48 return static_cast<size_t>(out_dsz_) * n * m_block2_;
49}
50size_t jit_avx512_core_brgemm_conv_comp_pad_kernel_t::inp_ic_offset(
51 const int m_block, const int icb, const int m, const int n) const {
52 return static_cast<size_t>(inp_dsz_) * n * m_block2_ * last_ic_block_
53 + ((icb * m_block) + m) * inp_ic_sz_;
54}
55Xbyak::Zmm jit_avx512_core_brgemm_conv_comp_pad_kernel_t::accum(
56 const int n_block, const int m, const int n) const {
57 return Xbyak::Zmm(m * n_block + n);
58}
59
60void jit_avx512_core_brgemm_conv_comp_pad_kernel_t::store_accumulators(
61 const int m_block, const int n_block) {
62 if (jcp_.src_zero_point) {
63 for_(int m = 0; m < m_block; m++)
64 for (int n = 0; n < n_block; n++) {
65 auto zmm = accum(n_block, m, n);
66 auto zmm_tmp = zmm_tmp_1();
67 const auto offset = out_oc_offset(n);
68 auto zp_addr = ptr[reg_zp_comp_out + offset];
69
70 vpmulld(zmm_tmp, zmm, zmm_zp_shift);
71 vpaddd(zmm_tmp, zmm_tmp, zp_addr);
72 vmovups(zp_addr, zmm_tmp);
73 }
74 }
75
76 if (jcp_.s8s8_avx512) {
77 for_(int m = 0; m < m_block; m++)
78 for (int n = 0; n < n_block; n++) {
79 auto zmm = accum(n_block, m, n);
80 auto zmm_tmp = zmm_tmp_1();
81 const auto offset = out_oc_offset(n);
82 auto cp_addr = ptr[reg_comp_out + offset];
83
84 vpmulld(zmm_tmp, zmm, zmm_cp_shift);
85 vpaddd(zmm_tmp, zmm_tmp, cp_addr);
86 vmovups(cp_addr, zmm_tmp);
87 }
88 }
89}
90
91void jit_avx512_core_brgemm_conv_comp_pad_kernel_t::zero_accumulators(
92 const int m_block, const int n_block) {
93 for_(int m = 0; m < m_block; m++)
94 for (int n = 0; n < n_block; n++) {
95 auto zmm = accum(n_block, m, n);
96 vpxord(zmm, zmm, zmm);
97 }
98}
99
100void jit_avx512_core_brgemm_conv_comp_pad_kernel_t::compute(const int ic_step,
101 const int m_block, const int n_block, const int m_tail,
102 const bool is_mb_tail) {
103
104 for_(int ic = 0; ic < ic_step; ++ic)
105 for (int m = 0; m < m_block; ++m) {
106 if (is_mb_tail && (ic * m_block + m) >= m_tail) break;
107 for (int n = 0; n < n_block; ++n) {
108 auto zmm = accum(n_block, m, n);
109 const auto oc_offset = inp_ic_offset(m_block, ic, m, n);
110 auto addr = EVEX_compress_addr(reg_aux_in, oc_offset);
111 vpdpbusd(zmm, zmm_one_bytes, addr);
112 }
113 }
114}
115
116void jit_avx512_core_brgemm_conv_comp_pad_kernel_t::icb_loop(const int icb,
117 const int icb_tail, const int ic_step, const int m_block,
118 const int mb_tail, const int n_block) {
119 Xbyak::Label label_icb_loop, label_loop_end;
120
121 mov(reg_aux_in, reg_aux_kw_in);
122 mov(reg_icb, icb);
123
124 L(label_icb_loop);
125 {
126 cmp(reg_icb, 0);
127 je(label_loop_end, T_NEAR);
128 compute(ic_step, m_block, n_block, 0, false);
129 add(reg_aux_in, ic_step * m_block * inp_ic_sz_);
130 dec(reg_icb);
131 jmp(label_icb_loop, T_NEAR);
132 }
133 L_aligned(label_loop_end);
134
135 if (icb_tail) compute(ic_step, mb_tail, n_block, icb_tail, true);
136}
137
138void jit_avx512_core_brgemm_conv_comp_pad_kernel_t::khw_loop(const int icb,
139 const int icb_tail, const int ic_step, const int m_block,
140 const int mb_tail, const int n_block) {
141 Xbyak::Label label_kw_loop, label_kw_end, label_kh_loop, label_kh_end;
142 mov(reg_kh_l, ptr[param1 + GET_OFF(kh_l)]);
143 mov(reg_aux_kh_in, reg_in);
144
145 L_aligned(label_kh_loop);
146 {
147 cmp(reg_kh_l, 0);
148 je(label_kh_end, T_NEAR);
149 mov(reg_kw_l, ptr[param1 + GET_OFF(kw_l)]);
150 mov(reg_aux_kw_in, reg_aux_kh_in);
151 L_aligned(label_kw_loop);
152 {
153 cmp(reg_kw_l, 0);
154 je(label_kw_end, T_NEAR);
155 icb_loop(icb, icb_tail, ic_step, m_block, mb_tail, n_block);
156 add(reg_aux_kw_in, inp_kw_sz_);
157 dec(reg_kw_l);
158 jmp(label_kw_loop, T_NEAR);
159 }
160 L_aligned(label_kw_end);
161
162 add(reg_aux_kh_in, inp_kh_sz_);
163 dec(reg_kh_l);
164 jmp(label_kh_loop, T_NEAR);
165 }
166 L_aligned(label_kh_end);
167}
168
169void jit_avx512_core_brgemm_conv_comp_pad_kernel_t::load_params() {
170 mov(reg_in, ptr[param1 + GET_OFF(ptr_in)]);
171 mov(reg_zp_comp_out, ptr[param1 + GET_OFF(ptr_zp_out)]);
172 mov(reg_comp_out, ptr[param1 + GET_OFF(ptr_cp_out)]);
173}
174
175int jit_avx512_core_brgemm_conv_comp_pad_kernel_t::compute_ic_step(
176 const int m_max_regs, const int m_block, const int n_block) const {
177 int best_ic_step = 1;
178 float best_block_eff = 0.f;
179
180 int max_ic_step
181 = nstl::min(static_cast<size_t>(m_block), div_up(nb_ic_, m_block));
182
183 // Introduce ic_step to increase kernel efficiency
184 // Compute the ic_step based on the optimal kernel efficiency
185 for (int ic_s = max_ic_step; ic_s >= 1; --ic_s) {
186 const auto blocks = ic_s * m_block;
187 const float block_disb
188 = static_cast<float>(nb_ic_) / rnd_up(nb_ic_, blocks);
189 const float eff = (static_cast<float>(n_block) * blocks)
190 / ((n_block + blocks) * max_ic_step);
191 const float block_eff = block_disb * eff;
192 float block_footprint = static_cast<float>(inp_dsz_) * blocks
193 * jcp_.oc_block * last_ic_block_;
194 if (block_footprint <= static_cast<float>(
195 platform::get_per_core_cache_size(1))
196 && (block_eff > best_block_eff)) {
197 best_ic_step = ic_s;
198 best_block_eff = block_eff;
199 }
200 }
201
202 return best_ic_step;
203}
204
205void jit_avx512_core_brgemm_conv_comp_pad_kernel_t::generate() {
206 preamble();
207
208 load_params();
209
210 // fill registers with byte ones
211 const auto reg32_scratch = reg_tmp.cvt32();
212 mov(reg32_scratch, 0x1010101);
213 vpbroadcastd(zmm_one_bytes, reg32_scratch);
214
215 // fill register with -128 && -1
216 mov(reg32_scratch, -128);
217 vpbroadcastd(zmm_cp_shift, reg32_scratch);
218
219 mov(reg32_scratch, -1);
220 vpbroadcastd(zmm_zp_shift, reg32_scratch);
221
222 const int max_regs = jcp_.s8s8_avx512 ? 28 : 29;
223 const int nb = div_up(nstl::min(jcp_.oc, jcp_.oc_block), m_block2_);
224 const int nb2 = nb / n_max_regs_;
225 const int nb2_tail = nb % n_block2_;
226 const int n_block = (nb2 == 0) ? nstl::max(1, nb2_tail) : 4;
227
228 const size_t m_max_regs = max_regs / n_block;
229 const int m_block = nstl::min(m_max_regs, nb_ic_);
230 const int ic_step = compute_ic_step(m_max_regs, m_block, n_block);
231
232 const auto blocks = m_block * ic_step;
233 const auto icb = nb_ic_ / blocks;
234 const auto icb_tail = nb_ic_ % blocks;
235 const auto mb_tail = div_up(icb_tail, ic_step);
236
237 Xbyak::Label label_kd_loop, label_loop_end;
238 mov(reg_kd_l, ptr[param1 + GET_OFF(kd_l)]);
239
240 zero_accumulators(m_block, n_block);
241
242 L_aligned(label_kd_loop);
243 {
244 cmp(reg_kd_l, 0);
245 je(label_loop_end, T_NEAR);
246 khw_loop(icb, icb_tail, ic_step, m_block, mb_tail, n_block);
247 add(reg_in, inp_kd_sz_);
248 dec(reg_kd_l);
249 jmp(label_kd_loop, T_NEAR);
250 }
251 L_aligned(label_loop_end);
252
253 store_accumulators(m_block, n_block);
254
255 postamble();
256}
257
258} // namespace jit_avx512_core_brgemm_conv_comp_pad_kernel
259
260} // namespace x64
261} // namespace cpu
262} // namespace impl
263} // namespace dnnl
264
265// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
266