1 | /******************************************************************************* |
2 | * Copyright 2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "cpu/x64/jit_gemm_x8s8s32x_conv_zp_src_pad_comp.hpp" |
18 | #include "cpu/x64/jit_generator.hpp" |
19 | |
20 | #include "cpu/gemm_convolution_utils.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace cpu { |
25 | namespace x64 { |
26 | namespace gemm_x8s8s32x_convolution_utils { |
27 | |
28 | jit_gemm_x8s8s32x_zp_pad_comp_helper::jit_gemm_x8s8s32x_zp_pad_comp_helper( |
29 | jit_generator *host, const conv_gemm_conf_t &jcp, |
30 | const Xbyak::Reg64 ®_zp_pad_comp, |
31 | const Xbyak::Reg64 ®_zp_pad_comp_temp, |
32 | const Xbyak::Reg8 &should_apply_zp_src_pad, const dim_t ndims) |
33 | : host_(host) |
34 | , jcp_(jcp) |
35 | , w_addr_(host->qword[host_->rsp]) |
36 | , h_addr_(host->qword[host_->rsp + 8]) |
37 | , w_size_addr_(host->qword[host_->rsp + 16]) |
38 | , w_off_addr_(host->qword[host_->rsp + 24]) |
39 | , zp_pad_com_h_(host->qword[host_->rsp + 32]) |
40 | , zp_pad_com_w_(host->qword[host_->rsp + 40]) |
41 | , zp_pad_com_base_(host->qword[host_->rsp + 48]) |
42 | , g_oc_offset_prologue_(host->qword[host_->rsp + 56]) |
43 | , g_oc_offset_(host->qword[host_->rsp + 64]) |
44 | , zp_pad_com_d_offset_(host->qword[host_->rsp + 72]) |
45 | , h_under_lower_bound_(host->byte[host_->rsp + 80]) |
46 | , h_over_eq_upper_bound_(host->byte[host_->rsp + 81]) |
47 | , w_under_lower_bound_(host->byte[host_->rsp + 82]) |
48 | , w_over_eq_upper_bound_(host->byte[host_->rsp + 83]) |
49 | , should_apply_zp_src_pad_comp_d_(host->byte[host_->rsp + 84]) |
50 | , should_apply_zp_src_pad_(should_apply_zp_src_pad) |
51 | , lower_h_bound_(calculate_lower_bound_dim(jcp.zp.src_pad_comp.top_pad)) |
52 | , upper_h_bound_( |
53 | calculate_upper_bound_dim(jcp.oh, jcp.zp.src_pad_comp.bottom_pad)) |
54 | , lower_w_bound_(calculate_lower_bound_dim(jcp.zp.src_pad_comp.left_pad)) |
55 | , upper_w_bound_( |
56 | calculate_upper_bound_dim(jcp.ow, jcp.zp.src_pad_comp.right_pad)) |
57 | , lower_d_bound_(calculate_lower_bound_dim(jcp.zp.src_pad_comp.front_pad)) |
58 | , upper_d_bound_( |
59 | calculate_upper_bound_dim(jcp.od, jcp.zp.src_pad_comp.back_pad)) |
60 | , with_zp_pad_com_d_(ndims >= 5) |
61 | , with_zp_pad_com_h_(ndims >= 4) |
62 | , reg_zp_pad_comp_(reg_zp_pad_comp) |
63 | , reg_zp_pad_comp_tmp_(reg_zp_pad_comp_temp) {} |
64 | |
65 | void jit_gemm_x8s8s32x_zp_pad_comp_helper::init(const dim_t off_w, |
66 | const dim_t off_h, const dim_t off_w_size, const dim_t off_w_off, |
67 | const dim_t off_zp_pad_com_base_off, |
68 | const dim_t off_g_oc_offset_prologue, const dim_t off_g_oc_offset, |
69 | const dim_t off_zp_src_pad_com_d_offset, |
70 | const dim_t off_should_apply_zp_src_pad_comp_d) { |
71 | |
72 | set_up_initial_args(off_w, off_h, off_w_size, off_w_off, |
73 | off_zp_pad_com_base_off, off_g_oc_offset_prologue, off_g_oc_offset, |
74 | off_zp_src_pad_com_d_offset, off_should_apply_zp_src_pad_comp_d); |
75 | should_apply_zp_src_pad(); |
76 | load_zp_src_comp_pad_addr_if_needed(g_oc_offset_prologue_); |
77 | } |
78 | |
79 | void jit_gemm_x8s8s32x_zp_pad_comp_helper:: |
80 | load_next_point_zp_src_comp_pad_addr() { |
81 | next_point(); |
82 | should_apply_zp_src_pad(); |
83 | load_zp_src_comp_pad_addr_if_needed(g_oc_offset_); |
84 | } |
85 | |
86 | void jit_gemm_x8s8s32x_zp_pad_comp_helper::zp_src_comp_pad_operation( |
87 | const std::function<void(const Xbyak::Reg64 &)> &op) { |
88 | if (op) { |
89 | Xbyak::Label end; |
90 | host_->cmp(should_apply_zp_src_pad_, 0); |
91 | host_->je(end, host_->T_NEAR); |
92 | op(reg_zp_pad_comp_); |
93 | host_->L(end); |
94 | } |
95 | } |
96 | |
97 | jit_gemm_x8s8s32x_zp_pad_comp_helper::zp_src_pad_com_d |
98 | jit_gemm_x8s8s32x_zp_pad_comp_helper::calculate_zp_src_pad_com_d( |
99 | const dim_t d_off) const { |
100 | |
101 | dim_t zp_src_pad_com_d_off = 0; |
102 | |
103 | if (!with_zp_pad_com_d_) { return {false, zp_src_pad_com_d_off}; } |
104 | |
105 | const bool d_under_lower_bound = d_off < lower_d_bound_; |
106 | const bool d_over_eq_upper_bound = d_off >= upper_d_bound_; |
107 | const bool should_apply_zp_src_pad_comp_d |
108 | = d_under_lower_bound || d_over_eq_upper_bound; |
109 | |
110 | dim_t zp_src_pad_com_d = 0; |
111 | if (d_under_lower_bound) { |
112 | zp_src_pad_com_d = d_off; |
113 | } else if (d_over_eq_upper_bound) { |
114 | zp_src_pad_com_d = jcp_.zp.src_pad_comp.front_pad |
115 | + jcp_.zp.src_pad_comp.mid_d |
116 | + (jcp_.zp.src_pad_comp.back_pad - (jcp_.od - d_off)); |
117 | } else { |
118 | zp_src_pad_com_d = jcp_.zp.src_pad_comp.front_pad; |
119 | } |
120 | |
121 | zp_src_pad_com_d_off = zp_src_pad_com_d * jcp_.zp.src_pad_comp.h |
122 | * jcp_.zp.src_pad_comp.w; |
123 | |
124 | return {should_apply_zp_src_pad_comp_d, zp_src_pad_com_d_off}; |
125 | } |
126 | |
127 | void jit_gemm_x8s8s32x_zp_pad_comp_helper::fin() { |
128 | host_->add(host_->rsp, reserved_stack_size_); |
129 | } |
130 | |
131 | dim_t jit_gemm_x8s8s32x_zp_pad_comp_helper::calculate_lower_bound_dim( |
132 | const dim_t begin_comp_pad) const noexcept { |
133 | return begin_comp_pad; |
134 | } |
135 | |
136 | dim_t jit_gemm_x8s8s32x_zp_pad_comp_helper::calculate_upper_bound_dim( |
137 | const dim_t output_size, const dim_t end_comp_pad) const noexcept { |
138 | return output_size - end_comp_pad; |
139 | } |
140 | |
141 | void jit_gemm_x8s8s32x_zp_pad_comp_helper::set_up_initial_args( |
142 | const dim_t off_w, const dim_t off_h, const dim_t off_w_size, |
143 | const dim_t off_w_off, const dim_t off_zp_pad_com_base_off, |
144 | const dim_t off_g_oc_offset_prologue, const dim_t off_g_oc_offset, |
145 | const dim_t off_zp_src_pad_com_d_offset, |
146 | const dim_t off_should_apply_zp_src_pad_comp_d) { |
147 | const auto push = [&](const dim_t src_off, |
148 | const Xbyak::Address &stack_addr) { |
149 | host_->mov(reg_zp_pad_comp_tmp_, host_->qword[abi_param1 + src_off]); |
150 | host_->mov(stack_addr, reg_zp_pad_comp_tmp_); |
151 | }; |
152 | |
153 | host_->sub(host_->rsp, reserved_stack_size_); |
154 | push(off_w, w_addr_); |
155 | check_bound( |
156 | reg_zp_pad_comp_tmp_, w_under_lower_bound_, lower_w_bound_, lower); |
157 | check_bound(reg_zp_pad_comp_tmp_, w_over_eq_upper_bound_, upper_w_bound_, |
158 | upper); |
159 | |
160 | if (with_zp_pad_com_h_) { |
161 | push(off_h, h_addr_); |
162 | check_bound(reg_zp_pad_comp_tmp_, h_under_lower_bound_, lower_h_bound_, |
163 | lower); |
164 | check_bound(reg_zp_pad_comp_tmp_, h_over_eq_upper_bound_, |
165 | upper_h_bound_, upper); |
166 | } |
167 | |
168 | push(off_w_size, w_size_addr_); |
169 | push(off_w_off, w_off_addr_); |
170 | push(off_zp_pad_com_base_off, zp_pad_com_base_); |
171 | push(off_g_oc_offset_prologue, g_oc_offset_prologue_); |
172 | push(off_g_oc_offset, g_oc_offset_); |
173 | |
174 | if (with_zp_pad_com_d_) |
175 | push(off_zp_src_pad_com_d_offset, zp_pad_com_d_offset_); |
176 | |
177 | const auto reg_zp_pad_comp_tmp_i8 = reg_zp_pad_comp_tmp_.cvt8(); |
178 | host_->mov(reg_zp_pad_comp_tmp_i8, |
179 | host_->byte[abi_param1 + off_should_apply_zp_src_pad_comp_d]); |
180 | host_->mov(should_apply_zp_src_pad_comp_d_, reg_zp_pad_comp_tmp_i8); |
181 | } |
182 | |
183 | void jit_gemm_x8s8s32x_zp_pad_comp_helper::check_bound( |
184 | const Xbyak::Reg64 ®_dim, const Xbyak::Address &result_addr, |
185 | const dim_t bound_value, const bound bound_kind) { |
186 | |
187 | host_->cmp(reg_dim, bound_value); |
188 | if (bound_kind == lower) |
189 | host_->setl(result_addr); |
190 | else |
191 | host_->setge(result_addr); |
192 | } |
193 | |
194 | void jit_gemm_x8s8s32x_zp_pad_comp_helper::load_zp_src_comp_pad_addr_if_needed( |
195 | const Xbyak::Address &g_oc_offset) { |
196 | Xbyak::Label calc_zp_src_comp_pad_addr, end; |
197 | host_->cmp(should_apply_zp_src_pad_, 0); |
198 | host_->je(end, host_->T_NEAR); |
199 | |
200 | host_->L(calc_zp_src_comp_pad_addr); |
201 | { |
202 | const auto &comp_pad = jcp_.zp.src_pad_comp; |
203 | if (with_zp_pad_com_h_) { |
204 | get_zp_pad_com_dim(h_under_lower_bound_, h_over_eq_upper_bound_, |
205 | comp_pad.top_pad, comp_pad.mid_h, comp_pad.bottom_pad, |
206 | jcp_.oh, h_addr_, zp_pad_com_h_); |
207 | } |
208 | get_zp_pad_com_dim(w_under_lower_bound_, w_over_eq_upper_bound_, |
209 | comp_pad.left_pad, comp_pad.mid_w, comp_pad.right_pad, jcp_.ow, |
210 | w_addr_, zp_pad_com_w_); |
211 | calculate_zp_src_comp_pad_effective_addr(g_oc_offset); |
212 | } |
213 | |
214 | host_->L(end); |
215 | } |
216 | |
217 | void jit_gemm_x8s8s32x_zp_pad_comp_helper:: |
218 | calculate_zp_src_comp_pad_effective_addr( |
219 | const Xbyak::Address &g_oc_offset) { |
220 | // Calculation steps: |
221 | // comp_pad_offset = ((zp_pad_com_d * jcp.zp.src_pad_comp.h + zp_pad_com_h) |
222 | // * jcp.zp.src_pad_comp.w + zp_pad_com_w) |
223 | // * jcp.oc * jcp.ngroups + (g * jcp.oc + oc); |
224 | // zp_pad_comp = zp_pad_comp_base + comp_pad_offset |
225 | if (with_zp_pad_com_h_) { |
226 | host_->mov(reg_zp_pad_comp_tmp_, jcp_.zp.src_pad_comp.w); |
227 | host_->imul(reg_zp_pad_comp_tmp_, zp_pad_com_h_); |
228 | if (with_zp_pad_com_d_) |
229 | host_->add(reg_zp_pad_comp_tmp_, zp_pad_com_d_offset_); |
230 | host_->add(reg_zp_pad_comp_tmp_, zp_pad_com_w_); |
231 | } else { |
232 | host_->mov(reg_zp_pad_comp_tmp_, zp_pad_com_w_); |
233 | } |
234 | |
235 | host_->imul( |
236 | reg_zp_pad_comp_tmp_, reg_zp_pad_comp_tmp_, jcp_.oc * jcp_.ngroups); |
237 | host_->add(reg_zp_pad_comp_tmp_, g_oc_offset); |
238 | host_->imul(reg_zp_pad_comp_tmp_, reg_zp_pad_comp_tmp_, sizeof(int32_t)); |
239 | host_->mov(reg_zp_pad_comp_, zp_pad_com_base_); |
240 | host_->add(reg_zp_pad_comp_, reg_zp_pad_comp_tmp_); |
241 | } |
242 | |
243 | void jit_gemm_x8s8s32x_zp_pad_comp_helper::get_zp_pad_com_dim( |
244 | const Xbyak::Address &dim_under_lower_bound, |
245 | const Xbyak::Address &dim_over_eq_upper_bound, const dim_t begin_pad, |
246 | dim_t mid_pad, const dim_t end_pad, const dim_t out_dim_size, |
247 | const Xbyak::Address &out_point_dim, const Xbyak::Address &result) { |
248 | |
249 | Xbyak::Label end, lower_bound, upper_bound, mid_point; |
250 | |
251 | host_->L(lower_bound); |
252 | { |
253 | host_->cmp(dim_under_lower_bound, 0); |
254 | host_->je(upper_bound, host_->T_NEAR); |
255 | host_->mov(reg_zp_pad_comp_tmp_, out_point_dim); |
256 | host_->mov(result, reg_zp_pad_comp_tmp_); |
257 | host_->jmp(end, host_->T_NEAR); |
258 | } |
259 | host_->L(upper_bound); |
260 | { |
261 | host_->cmp(dim_over_eq_upper_bound, 0); |
262 | host_->je(mid_point, host_->T_NEAR); |
263 | host_->mov(reg_zp_pad_comp_tmp_, |
264 | begin_pad + mid_pad + end_pad - out_dim_size); |
265 | host_->add(reg_zp_pad_comp_tmp_, out_point_dim); |
266 | host_->mov(result, reg_zp_pad_comp_tmp_); |
267 | host_->jmp(end, host_->T_NEAR); |
268 | } |
269 | |
270 | host_->L(mid_point); |
271 | { host_->mov(result, begin_pad); } |
272 | host_->L(end); |
273 | } |
274 | |
275 | void jit_gemm_x8s8s32x_zp_pad_comp_helper::should_apply_zp_src_pad() { |
276 | const Xbyak::Reg8 ®_tmp8 = reg_zp_pad_comp_tmp_.cvt8(); |
277 | host_->mov(reg_tmp8, w_under_lower_bound_); |
278 | host_->or_(reg_tmp8, w_over_eq_upper_bound_); |
279 | if (with_zp_pad_com_h_) { |
280 | host_->or_(reg_tmp8, h_over_eq_upper_bound_); |
281 | host_->or_(reg_tmp8, h_under_lower_bound_); |
282 | } |
283 | if (with_zp_pad_com_d_) |
284 | host_->or_(reg_tmp8, should_apply_zp_src_pad_comp_d_); |
285 | host_->setne(should_apply_zp_src_pad_); |
286 | } |
287 | |
288 | void jit_gemm_x8s8s32x_zp_pad_comp_helper::next_point() { |
289 | |
290 | Xbyak::Label inc_h, inc_w, row_begin, store_w; |
291 | |
292 | const Xbyak::Reg64 ®_w = reg_zp_pad_comp_tmp_; |
293 | const Xbyak::Reg64 ®_h = reg_zp_pad_comp_; |
294 | |
295 | host_->L(inc_w); |
296 | { |
297 | host_->mov(reg_w, w_addr_); |
298 | host_->add(reg_w, 1); |
299 | } |
300 | |
301 | host_->cmp(reg_w, w_size_addr_); |
302 | host_->jl(store_w, host_->T_NEAR); |
303 | |
304 | if (with_zp_pad_com_h_) { |
305 | |
306 | host_->L(inc_h); |
307 | { |
308 | host_->mov(reg_h, h_addr_); |
309 | host_->add(reg_h, 1); |
310 | host_->mov(h_addr_, reg_h); |
311 | } |
312 | |
313 | check_bound(reg_h, h_under_lower_bound_, lower_h_bound_, lower); |
314 | check_bound(reg_h, h_over_eq_upper_bound_, upper_h_bound_, upper); |
315 | } |
316 | |
317 | host_->L(row_begin); |
318 | { host_->mov(reg_w, w_off_addr_); } |
319 | |
320 | host_->L(store_w); |
321 | { |
322 | check_bound(reg_w, w_under_lower_bound_, lower_w_bound_, lower); |
323 | check_bound(reg_w, w_over_eq_upper_bound_, upper_w_bound_, upper); |
324 | } |
325 | |
326 | host_->mov(w_addr_, reg_w); |
327 | } |
328 | |
329 | } // namespace gemm_x8s8s32x_convolution_utils |
330 | } // namespace x64 |
331 | } // namespace cpu |
332 | } // namespace impl |
333 | } // namespace dnnl |
334 | |