1/*******************************************************************************
2* Copyright 2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "cpu/x64/jit_gemm_x8s8s32x_conv_zp_src_pad_comp.hpp"
18#include "cpu/x64/jit_generator.hpp"
19
20#include "cpu/gemm_convolution_utils.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace cpu {
25namespace x64 {
26namespace gemm_x8s8s32x_convolution_utils {
27
28jit_gemm_x8s8s32x_zp_pad_comp_helper::jit_gemm_x8s8s32x_zp_pad_comp_helper(
29 jit_generator *host, const conv_gemm_conf_t &jcp,
30 const Xbyak::Reg64 &reg_zp_pad_comp,
31 const Xbyak::Reg64 &reg_zp_pad_comp_temp,
32 const Xbyak::Reg8 &should_apply_zp_src_pad, const dim_t ndims)
33 : host_(host)
34 , jcp_(jcp)
35 , w_addr_(host->qword[host_->rsp])
36 , h_addr_(host->qword[host_->rsp + 8])
37 , w_size_addr_(host->qword[host_->rsp + 16])
38 , w_off_addr_(host->qword[host_->rsp + 24])
39 , zp_pad_com_h_(host->qword[host_->rsp + 32])
40 , zp_pad_com_w_(host->qword[host_->rsp + 40])
41 , zp_pad_com_base_(host->qword[host_->rsp + 48])
42 , g_oc_offset_prologue_(host->qword[host_->rsp + 56])
43 , g_oc_offset_(host->qword[host_->rsp + 64])
44 , zp_pad_com_d_offset_(host->qword[host_->rsp + 72])
45 , h_under_lower_bound_(host->byte[host_->rsp + 80])
46 , h_over_eq_upper_bound_(host->byte[host_->rsp + 81])
47 , w_under_lower_bound_(host->byte[host_->rsp + 82])
48 , w_over_eq_upper_bound_(host->byte[host_->rsp + 83])
49 , should_apply_zp_src_pad_comp_d_(host->byte[host_->rsp + 84])
50 , should_apply_zp_src_pad_(should_apply_zp_src_pad)
51 , lower_h_bound_(calculate_lower_bound_dim(jcp.zp.src_pad_comp.top_pad))
52 , upper_h_bound_(
53 calculate_upper_bound_dim(jcp.oh, jcp.zp.src_pad_comp.bottom_pad))
54 , lower_w_bound_(calculate_lower_bound_dim(jcp.zp.src_pad_comp.left_pad))
55 , upper_w_bound_(
56 calculate_upper_bound_dim(jcp.ow, jcp.zp.src_pad_comp.right_pad))
57 , lower_d_bound_(calculate_lower_bound_dim(jcp.zp.src_pad_comp.front_pad))
58 , upper_d_bound_(
59 calculate_upper_bound_dim(jcp.od, jcp.zp.src_pad_comp.back_pad))
60 , with_zp_pad_com_d_(ndims >= 5)
61 , with_zp_pad_com_h_(ndims >= 4)
62 , reg_zp_pad_comp_(reg_zp_pad_comp)
63 , reg_zp_pad_comp_tmp_(reg_zp_pad_comp_temp) {}
64
65void jit_gemm_x8s8s32x_zp_pad_comp_helper::init(const dim_t off_w,
66 const dim_t off_h, const dim_t off_w_size, const dim_t off_w_off,
67 const dim_t off_zp_pad_com_base_off,
68 const dim_t off_g_oc_offset_prologue, const dim_t off_g_oc_offset,
69 const dim_t off_zp_src_pad_com_d_offset,
70 const dim_t off_should_apply_zp_src_pad_comp_d) {
71
72 set_up_initial_args(off_w, off_h, off_w_size, off_w_off,
73 off_zp_pad_com_base_off, off_g_oc_offset_prologue, off_g_oc_offset,
74 off_zp_src_pad_com_d_offset, off_should_apply_zp_src_pad_comp_d);
75 should_apply_zp_src_pad();
76 load_zp_src_comp_pad_addr_if_needed(g_oc_offset_prologue_);
77}
78
79void jit_gemm_x8s8s32x_zp_pad_comp_helper::
80 load_next_point_zp_src_comp_pad_addr() {
81 next_point();
82 should_apply_zp_src_pad();
83 load_zp_src_comp_pad_addr_if_needed(g_oc_offset_);
84}
85
86void jit_gemm_x8s8s32x_zp_pad_comp_helper::zp_src_comp_pad_operation(
87 const std::function<void(const Xbyak::Reg64 &)> &op) {
88 if (op) {
89 Xbyak::Label end;
90 host_->cmp(should_apply_zp_src_pad_, 0);
91 host_->je(end, host_->T_NEAR);
92 op(reg_zp_pad_comp_);
93 host_->L(end);
94 }
95}
96
97jit_gemm_x8s8s32x_zp_pad_comp_helper::zp_src_pad_com_d
98jit_gemm_x8s8s32x_zp_pad_comp_helper::calculate_zp_src_pad_com_d(
99 const dim_t d_off) const {
100
101 dim_t zp_src_pad_com_d_off = 0;
102
103 if (!with_zp_pad_com_d_) { return {false, zp_src_pad_com_d_off}; }
104
105 const bool d_under_lower_bound = d_off < lower_d_bound_;
106 const bool d_over_eq_upper_bound = d_off >= upper_d_bound_;
107 const bool should_apply_zp_src_pad_comp_d
108 = d_under_lower_bound || d_over_eq_upper_bound;
109
110 dim_t zp_src_pad_com_d = 0;
111 if (d_under_lower_bound) {
112 zp_src_pad_com_d = d_off;
113 } else if (d_over_eq_upper_bound) {
114 zp_src_pad_com_d = jcp_.zp.src_pad_comp.front_pad
115 + jcp_.zp.src_pad_comp.mid_d
116 + (jcp_.zp.src_pad_comp.back_pad - (jcp_.od - d_off));
117 } else {
118 zp_src_pad_com_d = jcp_.zp.src_pad_comp.front_pad;
119 }
120
121 zp_src_pad_com_d_off = zp_src_pad_com_d * jcp_.zp.src_pad_comp.h
122 * jcp_.zp.src_pad_comp.w;
123
124 return {should_apply_zp_src_pad_comp_d, zp_src_pad_com_d_off};
125}
126
127void jit_gemm_x8s8s32x_zp_pad_comp_helper::fin() {
128 host_->add(host_->rsp, reserved_stack_size_);
129}
130
131dim_t jit_gemm_x8s8s32x_zp_pad_comp_helper::calculate_lower_bound_dim(
132 const dim_t begin_comp_pad) const noexcept {
133 return begin_comp_pad;
134}
135
136dim_t jit_gemm_x8s8s32x_zp_pad_comp_helper::calculate_upper_bound_dim(
137 const dim_t output_size, const dim_t end_comp_pad) const noexcept {
138 return output_size - end_comp_pad;
139}
140
141void jit_gemm_x8s8s32x_zp_pad_comp_helper::set_up_initial_args(
142 const dim_t off_w, const dim_t off_h, const dim_t off_w_size,
143 const dim_t off_w_off, const dim_t off_zp_pad_com_base_off,
144 const dim_t off_g_oc_offset_prologue, const dim_t off_g_oc_offset,
145 const dim_t off_zp_src_pad_com_d_offset,
146 const dim_t off_should_apply_zp_src_pad_comp_d) {
147 const auto push = [&](const dim_t src_off,
148 const Xbyak::Address &stack_addr) {
149 host_->mov(reg_zp_pad_comp_tmp_, host_->qword[abi_param1 + src_off]);
150 host_->mov(stack_addr, reg_zp_pad_comp_tmp_);
151 };
152
153 host_->sub(host_->rsp, reserved_stack_size_);
154 push(off_w, w_addr_);
155 check_bound(
156 reg_zp_pad_comp_tmp_, w_under_lower_bound_, lower_w_bound_, lower);
157 check_bound(reg_zp_pad_comp_tmp_, w_over_eq_upper_bound_, upper_w_bound_,
158 upper);
159
160 if (with_zp_pad_com_h_) {
161 push(off_h, h_addr_);
162 check_bound(reg_zp_pad_comp_tmp_, h_under_lower_bound_, lower_h_bound_,
163 lower);
164 check_bound(reg_zp_pad_comp_tmp_, h_over_eq_upper_bound_,
165 upper_h_bound_, upper);
166 }
167
168 push(off_w_size, w_size_addr_);
169 push(off_w_off, w_off_addr_);
170 push(off_zp_pad_com_base_off, zp_pad_com_base_);
171 push(off_g_oc_offset_prologue, g_oc_offset_prologue_);
172 push(off_g_oc_offset, g_oc_offset_);
173
174 if (with_zp_pad_com_d_)
175 push(off_zp_src_pad_com_d_offset, zp_pad_com_d_offset_);
176
177 const auto reg_zp_pad_comp_tmp_i8 = reg_zp_pad_comp_tmp_.cvt8();
178 host_->mov(reg_zp_pad_comp_tmp_i8,
179 host_->byte[abi_param1 + off_should_apply_zp_src_pad_comp_d]);
180 host_->mov(should_apply_zp_src_pad_comp_d_, reg_zp_pad_comp_tmp_i8);
181}
182
183void jit_gemm_x8s8s32x_zp_pad_comp_helper::check_bound(
184 const Xbyak::Reg64 &reg_dim, const Xbyak::Address &result_addr,
185 const dim_t bound_value, const bound bound_kind) {
186
187 host_->cmp(reg_dim, bound_value);
188 if (bound_kind == lower)
189 host_->setl(result_addr);
190 else
191 host_->setge(result_addr);
192}
193
194void jit_gemm_x8s8s32x_zp_pad_comp_helper::load_zp_src_comp_pad_addr_if_needed(
195 const Xbyak::Address &g_oc_offset) {
196 Xbyak::Label calc_zp_src_comp_pad_addr, end;
197 host_->cmp(should_apply_zp_src_pad_, 0);
198 host_->je(end, host_->T_NEAR);
199
200 host_->L(calc_zp_src_comp_pad_addr);
201 {
202 const auto &comp_pad = jcp_.zp.src_pad_comp;
203 if (with_zp_pad_com_h_) {
204 get_zp_pad_com_dim(h_under_lower_bound_, h_over_eq_upper_bound_,
205 comp_pad.top_pad, comp_pad.mid_h, comp_pad.bottom_pad,
206 jcp_.oh, h_addr_, zp_pad_com_h_);
207 }
208 get_zp_pad_com_dim(w_under_lower_bound_, w_over_eq_upper_bound_,
209 comp_pad.left_pad, comp_pad.mid_w, comp_pad.right_pad, jcp_.ow,
210 w_addr_, zp_pad_com_w_);
211 calculate_zp_src_comp_pad_effective_addr(g_oc_offset);
212 }
213
214 host_->L(end);
215}
216
217void jit_gemm_x8s8s32x_zp_pad_comp_helper::
218 calculate_zp_src_comp_pad_effective_addr(
219 const Xbyak::Address &g_oc_offset) {
220 // Calculation steps:
221 // comp_pad_offset = ((zp_pad_com_d * jcp.zp.src_pad_comp.h + zp_pad_com_h)
222 // * jcp.zp.src_pad_comp.w + zp_pad_com_w)
223 // * jcp.oc * jcp.ngroups + (g * jcp.oc + oc);
224 // zp_pad_comp = zp_pad_comp_base + comp_pad_offset
225 if (with_zp_pad_com_h_) {
226 host_->mov(reg_zp_pad_comp_tmp_, jcp_.zp.src_pad_comp.w);
227 host_->imul(reg_zp_pad_comp_tmp_, zp_pad_com_h_);
228 if (with_zp_pad_com_d_)
229 host_->add(reg_zp_pad_comp_tmp_, zp_pad_com_d_offset_);
230 host_->add(reg_zp_pad_comp_tmp_, zp_pad_com_w_);
231 } else {
232 host_->mov(reg_zp_pad_comp_tmp_, zp_pad_com_w_);
233 }
234
235 host_->imul(
236 reg_zp_pad_comp_tmp_, reg_zp_pad_comp_tmp_, jcp_.oc * jcp_.ngroups);
237 host_->add(reg_zp_pad_comp_tmp_, g_oc_offset);
238 host_->imul(reg_zp_pad_comp_tmp_, reg_zp_pad_comp_tmp_, sizeof(int32_t));
239 host_->mov(reg_zp_pad_comp_, zp_pad_com_base_);
240 host_->add(reg_zp_pad_comp_, reg_zp_pad_comp_tmp_);
241}
242
243void jit_gemm_x8s8s32x_zp_pad_comp_helper::get_zp_pad_com_dim(
244 const Xbyak::Address &dim_under_lower_bound,
245 const Xbyak::Address &dim_over_eq_upper_bound, const dim_t begin_pad,
246 dim_t mid_pad, const dim_t end_pad, const dim_t out_dim_size,
247 const Xbyak::Address &out_point_dim, const Xbyak::Address &result) {
248
249 Xbyak::Label end, lower_bound, upper_bound, mid_point;
250
251 host_->L(lower_bound);
252 {
253 host_->cmp(dim_under_lower_bound, 0);
254 host_->je(upper_bound, host_->T_NEAR);
255 host_->mov(reg_zp_pad_comp_tmp_, out_point_dim);
256 host_->mov(result, reg_zp_pad_comp_tmp_);
257 host_->jmp(end, host_->T_NEAR);
258 }
259 host_->L(upper_bound);
260 {
261 host_->cmp(dim_over_eq_upper_bound, 0);
262 host_->je(mid_point, host_->T_NEAR);
263 host_->mov(reg_zp_pad_comp_tmp_,
264 begin_pad + mid_pad + end_pad - out_dim_size);
265 host_->add(reg_zp_pad_comp_tmp_, out_point_dim);
266 host_->mov(result, reg_zp_pad_comp_tmp_);
267 host_->jmp(end, host_->T_NEAR);
268 }
269
270 host_->L(mid_point);
271 { host_->mov(result, begin_pad); }
272 host_->L(end);
273}
274
275void jit_gemm_x8s8s32x_zp_pad_comp_helper::should_apply_zp_src_pad() {
276 const Xbyak::Reg8 &reg_tmp8 = reg_zp_pad_comp_tmp_.cvt8();
277 host_->mov(reg_tmp8, w_under_lower_bound_);
278 host_->or_(reg_tmp8, w_over_eq_upper_bound_);
279 if (with_zp_pad_com_h_) {
280 host_->or_(reg_tmp8, h_over_eq_upper_bound_);
281 host_->or_(reg_tmp8, h_under_lower_bound_);
282 }
283 if (with_zp_pad_com_d_)
284 host_->or_(reg_tmp8, should_apply_zp_src_pad_comp_d_);
285 host_->setne(should_apply_zp_src_pad_);
286}
287
288void jit_gemm_x8s8s32x_zp_pad_comp_helper::next_point() {
289
290 Xbyak::Label inc_h, inc_w, row_begin, store_w;
291
292 const Xbyak::Reg64 &reg_w = reg_zp_pad_comp_tmp_;
293 const Xbyak::Reg64 &reg_h = reg_zp_pad_comp_;
294
295 host_->L(inc_w);
296 {
297 host_->mov(reg_w, w_addr_);
298 host_->add(reg_w, 1);
299 }
300
301 host_->cmp(reg_w, w_size_addr_);
302 host_->jl(store_w, host_->T_NEAR);
303
304 if (with_zp_pad_com_h_) {
305
306 host_->L(inc_h);
307 {
308 host_->mov(reg_h, h_addr_);
309 host_->add(reg_h, 1);
310 host_->mov(h_addr_, reg_h);
311 }
312
313 check_bound(reg_h, h_under_lower_bound_, lower_h_bound_, lower);
314 check_bound(reg_h, h_over_eq_upper_bound_, upper_h_bound_, upper);
315 }
316
317 host_->L(row_begin);
318 { host_->mov(reg_w, w_off_addr_); }
319
320 host_->L(store_w);
321 {
322 check_bound(reg_w, w_under_lower_bound_, lower_w_bound_, lower);
323 check_bound(reg_w, w_over_eq_upper_bound_, upper_w_bound_, upper);
324 }
325
326 host_->mov(w_addr_, reg_w);
327}
328
329} // namespace gemm_x8s8s32x_convolution_utils
330} // namespace x64
331} // namespace cpu
332} // namespace impl
333} // namespace dnnl
334