1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/memory.hpp" |
19 | #include "common/memory_tracking.hpp" |
20 | #include "common/nstl.hpp" |
21 | #include "common/type_helpers.hpp" |
22 | #include "common/utils.hpp" |
23 | |
24 | #include "cpu/platform.hpp" |
25 | #include "cpu/x64/injectors/injector_utils.hpp" |
26 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
27 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
28 | #include "cpu/x64/jit_uni_x8s8s32x_conv_kernel.hpp" |
29 | |
30 | #define GET_OFF(field) offsetof(jit_conv_call_s, field) |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace cpu { |
35 | namespace x64 { |
36 | |
37 | using namespace dnnl::impl::memory_tracking::names; |
38 | using namespace dnnl::impl::utils; |
39 | using namespace Xbyak; |
40 | using namespace injector_utils; |
41 | |
42 | namespace { |
43 | void pick_loop_order(jit_conv_conf_t &jcp) { |
44 | jcp.loop_order = loop_cwgn; |
45 | if (jcp.ngroups > 1) { |
46 | jcp.loop_order = loop_ngcw; |
47 | if (jcp.mb < jcp.nthr) |
48 | jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg; |
49 | } else if (jcp.mb >= jcp.nthr && jcp.ic_without_padding <= 8) { |
50 | jcp.loop_order = loop_ngcw; |
51 | } |
52 | } |
53 | } // namespace |
54 | |
55 | template <cpu_isa_t isa, typename Vmm> |
56 | _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::_jit_uni_x8s8s32x_fwd_kernel( |
57 | const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, |
58 | const memory_desc_t &dst_md) |
59 | : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, isa) |
60 | , jcp(ajcp) |
61 | , attr_(attr) { |
62 | if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { |
63 | using namespace binary_injector; |
64 | static constexpr bool preserve_gpr = true; |
65 | static constexpr bool preserve_vmm = false; |
66 | static constexpr size_t helper_vmm_idx = 15; |
67 | const size_t oc_block_tail = jcp.oc_block % isa_simd_width_; |
68 | const size_t tail_size = oc_block_tail |
69 | ? oc_block_tail |
70 | : jcp.oc_without_padding % isa_simd_width_; |
71 | |
72 | const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, |
73 | r13, r14, r15, preserve_gpr, preserve_vmm, |
74 | GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), |
75 | memory_desc_wrapper(dst_md), tail_size, true}; |
76 | const static_params_t static_params { |
77 | this->param1, rhs_arg_static_params}; |
78 | |
79 | postops_injector_ |
80 | = utils::make_unique<injector::jit_uni_postops_injector_t<isa>>( |
81 | this, jcp.post_ops, static_params); |
82 | } |
83 | } |
84 | |
85 | template <cpu_isa_t isa, typename Vmm> |
86 | void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::prepare_output(int ur_w) { |
87 | int nb_oc_block |
88 | = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; |
89 | for (int k = 0; k < nb_oc_block; ++k) |
90 | for (int j = 0; j < ur_w; ++j) { |
91 | Vmm vmm = vmm_out(j, k); |
92 | uni_vpxor(vmm, vmm, vmm); |
93 | } |
94 | if (jcp.signed_input) { |
95 | auto xmm_shift = Xbyak::Xmm(vmm_shift.getIdx()); |
96 | if (jcp.is_depthwise) |
97 | mov(reg_scratch, 128); |
98 | else |
99 | mov(reg_scratch, 0x80808080); |
100 | uni_vmovq(xmm_shift, reg_scratch); |
101 | uni_vpbroadcastd(vmm_shift, xmm_shift); |
102 | } |
103 | } |
104 | |
105 | template <cpu_isa_t isa, typename Vmm> |
106 | void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::cvt2ps(data_type_t type_in, |
107 | const Vmm &vmm_in, const Reg64 ®, int offset, int load_size) { |
108 | |
109 | load_data(type_in, vmm_in, reg, offset, load_size); |
110 | if (type_in != data_type::f32) uni_vcvtdq2ps(vmm_in, vmm_in); |
111 | } |
112 | |
/// Calls `f(mask_flag, k, j)` for every pair (k, j) with
/// 0 <= k < nb_oc_block and 0 <= j < ur_w, k being the outer index.
///
/// `mask_flag` is true for every call when `force_masking` is set;
/// otherwise it is true only for the last block (k == nb_oc_block - 1)
/// and only when `last_oc_block_flag` is set.
template <typename F>
void iterate(const int nb_oc_block, const int ur_w,
        const bool last_oc_block_flag, const bool force_masking, const F &f) {
    for (int blk = 0; blk < nb_oc_block; ++blk) {
        const bool is_tail_block
                = last_oc_block_flag && (blk + 1 == nb_oc_block);
        const bool masked = force_masking || is_tail_block;
        for (int col = 0; col < ur_w; ++col)
            f(masked, blk, col);
    }
}

/// Same as above with `force_masking` disabled.
template <typename F>
void iterate(const int nb_oc_block, const int ur_w,
        const bool last_oc_block_flag, const F &f) {
    const bool force_masking = false;
    iterate(nb_oc_block, ur_w, last_oc_block_flag, force_masking, f);
}

/// Same as above with masking fully disabled: `mask_flag` is always false.
template <typename F>
void iterate(const int nb_oc_block, const int ur_w, const F &f) {
    const bool last_oc_block_flag = false;
    const bool force_masking = false;
    iterate(nb_oc_block, ur_w, last_oc_block_flag, force_masking, f);
}
132 | |
133 | template <cpu_isa_t isa, typename Vmm> |
134 | void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::apply_sum(const int nb_oc_block, |
135 | const int ur_w, const bool last_oc_block_flag, const int oc_block, |
136 | const float *p_sum_scale, const int32_t *p_sum_zp) { |
137 | if (jcp.with_sum) { |
138 | assert(!utils::any_null(p_sum_scale, p_sum_zp) |
139 | && "p_sum_scale or p_sum_zp = nullptr" ); |
140 | const float sum_scale = *p_sum_scale; |
141 | const int32_t sum_zp = *p_sum_zp; |
142 | const auto sum_injector_lam = [this, oc_block, sum_scale, sum_zp]( |
143 | const bool mask_flag, const int k, |
144 | const int j) { |
145 | const int aux_output_offset = jcp.typesize_out |
146 | * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups); |
147 | cvt2ps(jcp.sum_dt, vmm_prev_dst, reg_out, aux_output_offset, |
148 | mask_flag ? get_tail_size() : get_blocking_size()); |
149 | const Vmm vmm = vmm_out(j, k); |
150 | if (sum_zp != 0) { |
151 | uni_vbroadcastss(vmm_tmp, ptr[reg_ptr_sum_zp]); |
152 | uni_vcvtdq2ps(vmm_tmp, vmm_tmp); |
153 | uni_vsubps(vmm_prev_dst, vmm_prev_dst, vmm_tmp); |
154 | } |
155 | if (sum_scale == 1.f) |
156 | uni_vaddps(vmm, vmm, vmm_prev_dst); |
157 | else { |
158 | uni_vbroadcastss(vmm_tmp, ptr[reg_ptr_sum_scale]); |
159 | uni_vfmadd231ps(vmm, vmm_prev_dst, vmm_tmp); |
160 | } |
161 | }; |
162 | const auto sum_injector = [=]() { |
163 | iterate(nb_oc_block, ur_w, last_oc_block_flag, sum_injector_lam); |
164 | }; |
165 | if (*p_sum_scale != 1.f) |
166 | mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale)); |
167 | if (*p_sum_zp != 0) |
168 | mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp)); |
169 | postops_injector_->set_lambda_injector( |
170 | primitive_kind::sum, sum_injector); |
171 | } |
172 | } |
173 | |
174 | template <cpu_isa_t isa, typename Vmm> |
175 | void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::apply_postops( |
176 | const int nb_oc_block, const int ur_w, const bool last_oc_block_flag, |
177 | const int oc_block, const float *p_sum_scale, const int32_t *p_sum_zp) { |
178 | if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { |
179 | if (jcp.with_sum && *p_sum_zp != 0) push(reg_ptr_sum_zp); |
180 | apply_sum(nb_oc_block, ur_w, last_oc_block_flag, oc_block, p_sum_scale, |
181 | p_sum_zp); |
182 | |
183 | vmm_index_set_t vmm_idxs; |
184 | if (jcp.with_binary) { |
185 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; |
186 | const bool oc_blk_is_smaller_than_vmm = oc_block < isa_simd_width_; |
187 | iterate(nb_oc_block, ur_w, last_oc_block_flag, |
188 | oc_blk_is_smaller_than_vmm, |
189 | [&](const bool mask_flag, const int k, const int j) { |
190 | const size_t aux_output_offset = jcp.typesize_out |
191 | * (k * oc_block |
192 | + j * jcp.oc_without_padding |
193 | * jcp.ngroups); |
194 | const auto vmm_idx = vmm_out_idx(j, k); |
195 | vmm_idxs.emplace(vmm_idx); |
196 | |
197 | rhs_arg_params.vmm_idx_to_out_reg.emplace( |
198 | vmm_idx, reg_out); |
199 | rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace( |
200 | vmm_idx, aux_output_offset); |
201 | |
202 | if (mask_flag) |
203 | rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); |
204 | }); |
205 | |
206 | postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); |
207 | } else { |
208 | iterate(nb_oc_block, ur_w, |
209 | [&](const bool, const int k, const int j) { |
210 | vmm_idxs.emplace(vmm_out_idx(j, k)); |
211 | }); |
212 | postops_injector_->compute_vector_range(vmm_idxs); |
213 | } |
214 | if (jcp.with_sum && *p_sum_zp != 0) pop(reg_ptr_sum_zp); |
215 | } |
216 | } |
217 | |
/// Emits the epilogue for one tile of accumulators: finalizes the s32
/// accumulators and writes them to the destination.
///
/// Emitted stages, in order:
///  1. per channel block: add compensation / zero-point compensation in s32,
///     convert to f32, multiply by scales, add bias;
///  2. post-ops (see apply_postops);
///  3. optional dst scale and dst zero point;
///  4. saturation and conversion to integer for u8 / s8 / s32 destinations;
///  5. store to `reg_out`.
///
/// @param ur_w               Number of accumulator columns per channel block.
/// @param last_oc_block_flag When true, the last channel block is handled
///                           with get_tail_size() elements instead of
///                           get_blocking_size().
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::store_output(
        int ur_w, bool last_oc_block_flag) {
    // Depthwise and regular convolutions use different blocking fields.
    int nb_oc_block
            = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
    int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block;

    // Fetch the pointers needed by the epilogue from the call arguments.
    mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
    if (jcp.signed_input)
        mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);

    if (jcp.src_zero_point) {
        mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
        mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
        uni_vpbroadcastd(vmm_zp, ptr[reg_src_zero_point]);
    }

    // Locate the sum post-op (if any); the pointers stay null otherwise.
    const auto &p = attr_.post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    const float *p_sum_scale = nullptr;
    const int32_t *p_sum_zp = nullptr;
    if (sum_idx != -1) {
        const auto &p_entry = p.entry_[sum_idx];
        p_sum_scale = &p_entry.sum.scale;
        p_sum_zp = &p_entry.sum.zero_point;
    }

    for (int k = 0; k < nb_oc_block; ++k) {
        // Only the last channel block may be a tail block.
        const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
        const int load_size = mask_flag ? get_tail_size() : get_blocking_size();
        // The offset collapses to 0 when jcp.is_oc_scale is 0.
        int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block);
        if (jcp.with_bias) {
            int bias_offset = jcp.typesize_bia * k * oc_block;
            cvt2ps(jcp.bia_dt, vmm_bias, reg_bias, bias_offset, load_size);
        }
        if (jcp.signed_input) {
            const int comp_offset = sizeof(int32_t) * k * oc_block;
            load_data(data_type::s32, vmm_comp, reg_compensation, comp_offset,
                    load_size);
        }
        if (jcp.src_zero_point) {
            // zp_comp = zp_compensation[k-th block] * src_zero_point
            const int zp_offset = sizeof(int32_t) * k * oc_block;
            load_data(data_type::s32, vmm_zp_comp, reg_zp_compensation,
                    zp_offset, load_size);
            uni_vpmulld(vmm_zp_comp, vmm_zp_comp, vmm_zp);
        }
        /* load the scales for this channel block; for a tail block the
         * register is zeroed first and only the tail elements are loaded.
         * NOTE(review): the tail load goes through load_data with s32 even
         * though the values are used as f32 scales below; presumably this
         * is a plain 32-bit element copy, confirm against load_data */
        if (mask_flag) {
            uni_vpxor(vmm_scale, vmm_scale, vmm_scale);
            load_data(data_type::s32, vmm_scale, reg_ptr_scales, scale_offset,
                    get_tail_size());
        } else {
            uni_vmovups(vmm_scale, ptr[reg_ptr_scales + scale_offset]);
        }

        for (int j = 0; j < ur_w; ++j) {
            const Vmm vmm = vmm_out(j, k);

            /* add comp in s32 to avoid loss of precision
               when convert s32 to f32 in integer (2^24)
               TODO: do the same to bias */
            if (jcp.signed_input) uni_vpaddd(vmm, vmm, vmm_comp);
            if (jcp.src_zero_point) uni_vpaddd(vmm, vmm, vmm_zp_comp);
            uni_vcvtdq2ps(vmm, vmm);

            uni_vmulps(vmm, vmm, vmm_scale);

            if (jcp.with_bias) uni_vaddps(vmm, vmm, vmm_bias);
        }
    }

    apply_postops(nb_oc_block, ur_w, last_oc_block_flag, oc_block, p_sum_scale,
            p_sum_zp);

    if (jcp.dst_scale) {
        mov(reg_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
        // NOTE(review): this is a full-vector load, not a broadcast, so it
        // reads a whole vector of floats at dst_scale; confirm the buffer
        // layout provided by the caller.
        uni_vmovups(vmm_dst_scale, ptr[reg_dst_scale]);

        /* Apply dst scale to accumulator */
        for (int k = 0; k < nb_oc_block; k++) {
            for (int j = 0; j < ur_w; j++) {
                const Vmm vmm = vmm_out(j, k);
                uni_vmulps(vmm, vmm, vmm_dst_scale);
            }
        }
    }

    if (jcp.dst_zero_point) {
        // vmm_zp is reused here for the destination zero point (as f32).
        mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
        uni_vpbroadcastd(vmm_zp, ptr[reg_dst_zero_point]);
        uni_vcvtdq2ps(vmm_zp, vmm_zp);

        /* Add dst zero_point to accumulator */
        for (int k = 0; k < nb_oc_block; k++) {
            for (int j = 0; j < ur_w; j++) {
                const Vmm vmm = vmm_out(j, k);
                uni_vaddps(vmm, vmm, vmm_zp);
            }
        }
    }

    // Properly saturate the accumulators for integer datatypes

    // No need to saturate on lower bound for signed integer types, as
    // the conversion to int would return INT_MIN, and then proper
    // saturation will happen in store_data
    if (jcp.dst_dt == data_type::u8) {
        // Lower bound for u8: clamp negative values to zero.
        uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
        for (int k = 0; k < nb_oc_block; ++k) {
            for (int j = 0; j < ur_w; ++j) {
                Vmm vmm = vmm_out(j, k);
                uni_vmaxps(vmm, vmm, vmm_zero);
            }
        }
    }
    if (utils::one_of(
                jcp.dst_dt, data_type::u8, data_type::s8, data_type::s32)) {
        // Upper bound: broadcast the largest value of the dst type (as f32
        // bits moved through a GPR) and clamp with min.
        float saturation_ubound = types::max_value<float>(jcp.dst_dt);
        Xmm xmm_saturation(vmm_saturation.getIdx());
        mov(reg_ptr_saturation_ubound, float2int(saturation_ubound));
        uni_vmovq(xmm_saturation, reg_ptr_saturation_ubound);
        uni_vbroadcastss(vmm_saturation, xmm_saturation);

        for (int k = 0; k < nb_oc_block; ++k) {
            for (int j = 0; j < ur_w; ++j) {
                Vmm vmm = vmm_out(j, k);
                uni_vminps(vmm, vmm, vmm_saturation);
            }
        }
    }

    // Convert float accumulators to int datatype if needed
    if (utils::one_of(jcp.dst_dt, data_type::u8, data_type::s8, data_type::s32))
        for (int k = 0; k < nb_oc_block; ++k)
            for (int j = 0; j < ur_w; ++j) {
                Vmm vmm = vmm_out(j, k);
                uni_vcvtps2dq(vmm, vmm);
            }

    /* write out register to output_addr */
    for (int k = 0; k < nb_oc_block; ++k) {
        const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
        for (int j = 0; j < ur_w; ++j) {
            Vmm r_vmm = vmm_out(j, k);
            // Byte offset of (width position j, channel block k) from reg_out.
            int aux_output_offset = jcp.typesize_out
                    * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups);
            store_data(jcp.dst_dt, r_vmm, reg_out, aux_output_offset,
                    mask_flag ? get_tail_size() : get_blocking_size());
        }
    }
}
370 | |
/// Emits the depthwise multiply-accumulate code for one call of the inner
/// kernel: for every channel block and every kernel width position, source
/// values and weights are accumulated into the output accumulators.
///
/// @param ur_w               Number of output width positions in the tile.
/// @param pad_l              Left padding applied to this tile.
/// @param pad_r              Right padding applied to this tile.
/// @param last_ic_block_flag Anything other than no_last_block makes the
///                           last channel block use get_tail_size().
/// @param h_padded           When true no input is loaded: with signed input
///                           the shift vector is used as the source, and
///                           with a source zero point every output position
///                           receives the zero-point padding term.
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::compute_ker_dw(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {

    // Only (avx2, Ymm) and (sse41, Xmm) instantiations are expected here.
    if (!(utils::one_of(isa, avx2) && std::is_same<Vmm, Xbyak::Ymm>::value)
            && !(utils::one_of(isa, sse41)
                    && std::is_same<Vmm, Xbyak::Xmm>::value))
        assert(!"invalid group blocking for depthwise convolution");

    // In a padded row the MAC part is emitted only for signed input.
    const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input);

    if (jcp.src_zero_point) {
        // NOTE(review): aux_reg_ker_d is saved here and restored at the end;
        // presumably it shares a physical register with reg_src_zero_point,
        // confirm in the kernel header.
        push(aux_reg_ker_d);
        mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
        uni_vpbroadcastd(vmm_zp, ptr[reg_src_zero_point]);
    }

    // Input width index read by output position `oi` at kernel position `ki`.
    auto input_spatial_index = [=](int oi, int ki) {
        return (ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l);
    };

    // Byte offset of (input width index ii, channel block ci).
    auto input_offset2 = [=](int ii, int ci) {
        if (jcp.is_fused_conv)
            return jcp.typesize_in
                    * (ii * jcp.dw_conv_buffer_oc + ci * jcp.ch_block);
        else
            return jcp.typesize_in * (ii * jcp.ngroups + ci * jcp.ch_block);
    };

    // NOTE(review): input_offset2 already multiplies by jcp.typesize_in, so
    // the factor is applied twice here. This is harmless only while
    // typesize_in == 1; confirm and drop one of the factors.
    auto input_offset3 = [=](int oi, int ci, int ki) {
        return jcp.typesize_in * input_offset2(input_spatial_index(oi, ki), ci);
    };

    // Byte offset of the weights for (channel block ci, kernel position ki).
    auto kernel_offset = [=](int ci, int ki) {
        return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block);
    };

    // One multiply-accumulate step: a single VNNI instruction when
    // available, otherwise multiply-add into a temporary plus an add.
    auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
        if (jcp.has_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei, VexEncoding);
        } else {
            // okay for depthwise since src is zero-extended
            uni_vpmaddwd(vmm_dw_tmp, vreg_src, vreg_wei);
            uni_vpaddd(vreg_acc, vreg_acc, vmm_dw_tmp);
        }
    };

    // [ii_start, ii_end] is empty by default (end < start).
    int ii_start = 0;
    int ii_end = -1;
    if (jcp.is_resrc_depthwise && !h_padded) {
        // find bounds of input spatial indices
        bool first = true;
        for (int ki = 0; ki < jcp.kw; ++ki) {
            int oi_start = get_ow_start(ki, pad_l);
            int oi_end = get_ow_end(ur_w, ki, pad_r);
            for (int oi = oi_start; oi < oi_end; ++oi) {
                int ii = input_spatial_index(oi, ki);
                if (first || ii < ii_start) ii_start = ii;
                if (first || ii > ii_end) ii_end = ii;
                first = false;
            }
        }
    }

    for (int ci = 0; ci < jcp.nb_ch_blocking; ++ci) {
        // Only the last channel block may need a tail-sized load.
        const bool mask_flag = last_ic_block_flag != no_last_block
                && ci == jcp.nb_ch_blocking - 1;
        if (jcp.is_resrc_depthwise && !h_padded) {
            // now we can load input once and reuse up to jcp.kw times
            for (int ii = ii_start; ii <= ii_end; ++ii) {
                int aux_input_offset = input_offset2(ii, ci);
                const Vmm vmm_inp_tmp = vmm_inp(ii, jcp.nb_ch_blocking);

                uni_vpxor(vmm_inp_tmp, vmm_inp_tmp, vmm_inp_tmp);
                load_data(data_type::u8, vmm_inp_tmp, aux_reg_inp,
                        aux_input_offset,
                        mask_flag ? get_tail_size() : get_blocking_size());
                if (jcp.signed_input)
                    uni_vpaddb(vmm_inp_tmp, vmm_inp_tmp, vmm_shift);
            }
        }
        for (int ki = 0; ki < jcp.kw; ++ki) {
            int aux_kernel_offset = kernel_offset(ci, ki);
            // Output positions in [oi_start, oi_end) read real input for
            // this kernel position; the others fall into width padding.
            int oi_start = get_ow_start(ki, pad_l);
            int oi_end = get_ow_end(ur_w, ki, pad_r);

            if (compute_kernel) {
                // Weights are sign-extended from 8 to 32 bits.
                uni_vpmovsxbd(vmm_wei, ptr[aux_reg_ker + aux_kernel_offset]);
                if (h_padded) {
                    // Whole row is padding: use the shift vector as source.
                    assert(jcp.signed_input);
                    for (int oi = 0; oi < ur_w; ++oi)
                        compute(vmm_out(oi, ci), vmm_wei, vmm_shift);
                } else {
                    // With signed input every output position is visited so
                    // that padded positions also accumulate the shift term.
                    int start = jcp.signed_input ? 0 : oi_start;
                    int end = jcp.signed_input ? ur_w : oi_end;
                    for (int oi = start; oi < end; ++oi) {
                        if (oi >= oi_start && oi < oi_end) {
                            if (jcp.is_resrc_depthwise) {
                                // Reuse the input register preloaded above.
                                int ii = input_spatial_index(oi, ki);
                                vmm_dw_src = vmm_inp(ii, jcp.nb_ch_blocking);
                            } else {
                                int aux_input_offset
                                        = input_offset3(oi, ci, ki);
                                load_data(data_type::u8, vmm_dw_src,
                                        aux_reg_inp, aux_input_offset,
                                        mask_flag ? get_tail_size()
                                                  : get_blocking_size());
                                if (jcp.signed_input)
                                    uni_vpaddb(
                                            vmm_dw_src, vmm_dw_src, vmm_shift);
                            }
                            compute(vmm_out(oi, ci), vmm_wei, vmm_dw_src);
                        } else {
                            assert(jcp.signed_input);
                            compute(vmm_out(oi, ci), vmm_wei, vmm_shift);
                        }
                    }
                }
            }
            if (jcp.src_zero_point) {
                /* calculate src_zero_point padding as:
                 *      (is_padding ?
                 *           src_zero_point_s32 * conv(1, wei_s32) : 0) */
                if (!compute_kernel) {
                    uni_vpmovsxbd(
                            vmm_wei, ptr[aux_reg_ker + aux_kernel_offset]);
                } // else: already loaded weights from previous block
                for (int oi = 0; oi < ur_w; oi++) {
                    if (oi < oi_start || oi >= oi_end || h_padded) {
                        uni_vpmulld(vmm_zp_dw_tmp, vmm_wei, vmm_zp);
                        uni_vpaddd(vmm_out(oi, ci), vmm_out(oi, ci),
                                vmm_zp_dw_tmp);
                    }
                }
            }
        }
    }

    if (jcp.src_zero_point) pop(aux_reg_ker_d);
}
511 | |
/// Emits the multiply-accumulate code for one call of the inner kernel.
/// Depthwise configurations are forwarded to compute_ker_dw.
///
/// For each kernel width position the input is broadcast in groups of
/// `ic_sub_step` bytes per output position and multiplied with the weights
/// of every output-channel block.
///
/// @param ur_w               Number of output width positions in the tile.
/// @param pad_l              Left padding applied to this tile.
/// @param pad_r              Right padding applied to this tile.
/// @param last_ic_block_flag Selects handling of a partial last input
///                           channel block (see `icb` below).
/// @param h_padded           When true no input is loaded; requires signed
///                           input or a source zero point (asserted).
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::compute_ker(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
    if (jcp.is_depthwise)
        return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded);

    int kw = jcp.kw;
    int stride_w = jcp.stride_w;
    int ic_block = jcp.ic_block;
    int oc_block = jcp.oc_block;
    int ch_block_all = jcp.ch_block * ic_block * oc_block;

    int nb_oc_block = jcp.nb_oc_blocking;

    // In a padded row the MAC part is emitted only for signed input.
    const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input);

    assert(IMPLICATION(h_padded, jcp.src_zero_point || jcp.signed_input));

    if (jcp.src_zero_point) {
        // NOTE(review): aux_reg_ker_d is saved here and restored at the end;
        // presumably it shares a physical register with reg_src_zero_point,
        // confirm in the kernel header.
        push(aux_reg_ker_d);
        mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
    }

    // Byte offset of the input for output position `oi`, input channel
    // sub-step `ic` and kernel position `ki`.
    auto input_offset = [=](int oi, int ic, int ki) {
        return jcp.typesize_in
                * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
                                * jcp.ic_without_padding * jcp.ngroups
                        + ic_sub_step * ic);
    };
    // Byte offset of the weights for output-channel block `ii`, input
    // channel sub-step `ic` and kernel position `ki`.
    auto kernel_offset = [=](int ii, int ic, int ki) {
        return jcp.typesize_in
                * ((ii * jcp.nb_ic * jcp.kd * jcp.kh * jcp.kw + ki)
                                * ch_block_all
                        + ic_sub_step * ic * oc_block);
    };
    // One multiply-accumulate step: a single VNNI instruction when
    // available, otherwise the three-instruction sequence using vmm_tmp and
    // vmm_one.
    auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
        if (jcp.has_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei, VexEncoding);
        } else {
            uni_vpmaddubsw(vmm_tmp, vreg_src, vreg_wei);
            uni_vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
            uni_vpaddd(vreg_acc, vreg_acc, vmm_tmp);
        }
    };

    for (int ki = 0; ki < kw; ++ki) {
        // Output positions in [ow_start, ow_end) read real input for this
        // kernel position; the others fall into width padding.
        const int ow_start = get_ow_start(ki, pad_l);
        const int ow_end = get_ow_end(ur_w, ki, pad_r);
        const int ic_tail_size = jcp.ic_without_padding % ic_sub_step;

        // With signed input every output position is visited so that padded
        // positions also accumulate the shift term.
        const int _start = jcp.signed_input ? 0 : ow_start;
        const int _end = jcp.signed_input ? ur_w : ow_end;

        /* Skip the last loads of input
            if (ic % 8) / ic_sub_step < ic_block / ic_sub_step */
        const int icb = (last_ic_block_flag != no_last_block)
                ? div_up((jcp.ic_without_padding % ic_block), ic_sub_step)
                : ic_block / ic_sub_step;

        if (compute_kernel) {
            for (int ic = 0; ic < icb; ++ic) {
                if (h_padded) {
                    // fill padded area with shifted value in first iteration
                    if (ic == 0) {
                        const Vmm inp = vmm_inp(0, nb_oc_block);
                        uni_vmovups(inp, vmm_shift);
                    }
                } else {
                    for (int jj = _start; jj < _end; ++jj) {
                        int aux_input_offset = input_offset(jj, ic, ki);
                        if (jj >= ow_start && jj < ow_end) {
                            // The last sub-step of the last spatial block may
                            // hold fewer than ic_sub_step valid bytes; load
                            // only those bytes before broadcasting.
                            const bool need_partial_ic_bcast = true
                                    && last_ic_block_flag == last_sp_block
                                    && ic_tail_size != 0 && ic == icb - 1;

                            if (need_partial_ic_bcast) {
                                const auto inp_bcastd_vmm
                                        = vmm_inp(jj, nb_oc_block);
                                const auto inp_bcastd
                                        = Xmm(inp_bcastd_vmm.getIdx());
                                load_bytes(inp_bcastd_vmm, aux_reg_inp,
                                        aux_input_offset, ic_tail_size);
                                uni_vpbroadcastd(
                                        vmm_inp(jj, nb_oc_block), inp_bcastd);
                            } else {
                                uni_vpbroadcastd(vmm_inp(jj, nb_oc_block),
                                        ptr[aux_reg_inp + aux_input_offset]);
                            }
                            if (jcp.signed_input)
                                uni_vpaddb(vmm_inp(jj, nb_oc_block),
                                        vmm_inp(jj, nb_oc_block), vmm_shift);
                        } else {
                            // fill padded area with shifted value in
                            // first iteration
                            if (jcp.signed_input && ic == 0) {
                                const Vmm inp = vmm_inp(jj, nb_oc_block);
                                uni_vmovups(inp, vmm_shift);
                            }
                        }
                    }
                }
                for (int ii = 0; ii < nb_oc_block; ++ii) {
                    const int aux_kernel_offset = kernel_offset(ii, ic, ki);
                    uni_vmovdqu(vmm_wei, ptr[aux_reg_ker + aux_kernel_offset]);
                    for (int jj = _start; jj < _end; ++jj) {
                        // In a padded row all positions share input slot 0.
                        const Vmm inp = vmm_inp(h_padded ? 0 : jj, nb_oc_block);
                        compute(vmm_out(jj, ii), vmm_wei, inp);
                    }
                }
            }
        }
        if (jcp.src_zero_point) {
            uni_vpbroadcastd(vmm_zp, ptr[reg_src_zero_point]);
            /* calculate src_zero_point padding as:
             *      (is_padding ? src_zero_point_s32 * conv(1, wei_s8) : 0) */
            // Input slot 0 is reused as a scratch register for the weight sum.
            const Vmm vmm_wacc = vmm_inp(0, nb_oc_block);
            for (int jj = 0; jj < ur_w; jj++) {
                if (jj < ow_start || jj >= ow_end || h_padded) {
                    for (int ii = 0; ii < nb_oc_block; ii++) {
                        // vmm_tmp accumulates the weights against
                        // vmm_zp_one over the input channel sub-steps.
                        uni_vpxor(vmm_tmp, vmm_tmp, vmm_tmp);
                        for (int ic = 0; ic < icb; ic++) {
                            const int aux_kernel_offset
                                    = kernel_offset(ii, ic, ki);
                            if (jcp.has_vnni) {
                                vpdpbusd(vmm_tmp, vmm_zp_one,
                                        ptr[aux_reg_ker + aux_kernel_offset],
                                        VexEncoding);
                            } else {
                                uni_vpmaddubsw(vmm_wacc, vmm_zp_one,
                                        ptr[aux_reg_ker + aux_kernel_offset]);
                                uni_vpmaddwd(vmm_wacc, vmm_wacc, vmm_one);
                                uni_vpaddd(vmm_tmp, vmm_tmp, vmm_wacc);
                            }
                        }
                        uni_vpmulld(vmm_tmp, vmm_tmp, vmm_zp);
                        uni_vpaddd(vmm_out(jj, ii), vmm_out(jj, ii), vmm_tmp);
                    }
                }
            }
        }
    }
    if (jcp.src_zero_point) pop(aux_reg_ker_d);
}
655 | |
/// Emits the loops over the kernel depth (5D only) and kernel height that
/// surround compute_ker.
///
/// With signed input or a source zero point, additional "overflow" loops
/// call compute_ker with h_padded = true. Their trip counts are read at run
/// time from the f_overflow / t_overflow / b_overflow / back_overflow call
/// arguments; the regular loops run kd_padding / kh_padding times.
///
/// @param ur_w               Number of output width positions in the tile.
/// @param pad_l              Left padding applied to this tile.
/// @param pad_r              Right padding applied to this tile.
/// @param last_ic_block_flag Forwarded to compute_ker.
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::kh_loop(
        int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) {

    Label kd_label, kh_label, skip_kd_loop, skip_kh_loop;
    Label f_overflow_label, no_f_overflow_label, d_h_f_overflow_label,
            t_overflow_label, no_t_overflow_label, b_overflow_label,
            no_b_overflow_label, back_overflow_label, no_back_overflow_label,
            d_h_back_overflow_label;

    // Pointer increments (in bytes) applied after each kernel-height step:
    // kw kernel positions of weights, and iw input positions of source.
    int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
    int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all;
    int shift_input_ptr
            = jcp.typesize_in * jcp.iw * jcp.ic_without_padding * jcp.ngroups;

    if (jcp.src_zero_point && !jcp.is_depthwise) {
        // Fill vmm_zp_one with the 32-bit pattern 0x01010101.
        const auto xmm_one = Xbyak::Xmm(vmm_zp_one.getIdx());
        mov(reg_scratch, 0x01010101);
        uni_vmovq(xmm_one, reg_scratch);
        uni_vpbroadcastd(vmm_zp_one, xmm_one);
    }

    if (jcp.ndims == 5) {
        mov(aux_reg_ker_d, reg_ker);
        mov(aux_reg_inp_d, reg_inp);
        if (jcp.signed_input || jcp.src_zero_point) {
            //TODO: May be avoided when f_pad=0 and dd0
            //TODO: Potential optimization by precomputing, when kd <<< od?
            // Front overflow: f_overflow iterations, each covering all kh
            // kernel rows in padded mode.
            mov(reg_ki, ptr[param1 + GET_OFF(f_overflow)]);
            cmp(reg_ki, 0);
            je(no_f_overflow_label, T_NEAR);
            L(f_overflow_label);
            {
                mov(aux_reg_ker, aux_reg_ker_d);
                mov(reg_kj, jcp.kh);
                L(d_h_f_overflow_label);
                {
                    compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
                    add(aux_reg_ker, shift_kernel_ptr);
                    dec(reg_kj);
                    jne(d_h_f_overflow_label);
                }
                add(aux_reg_ker_d, shift_kernel_ptr * jcp.kh);
                dec(reg_ki);
                jne(f_overflow_label);
            }
            L(no_f_overflow_label);
        }

        // Regular depth loop: kd_padding iterations. The zero-trip check is
        // emitted only for the configurations selected below.
        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        if ((jcp.signed_input || jcp.src_zero_point) || (jcp.dilate_d >= jcp.id)
                || (!(jcp.signed_input || jcp.src_zero_point)
                        && (jcp.kd - 1) * (jcp.dilate_d + 1)
                                < nstl::max(jcp.f_pad, jcp.back_pad))) {
            cmp(reg_ki, 0);
            je(skip_kd_loop, T_NEAR);
        }
        L(kd_label);
        mov(aux_reg_inp, aux_reg_inp_d);
        mov(aux_reg_ker, aux_reg_ker_d);
    } else {
        // Fused convolution reads its input rows through a pointer table.
        if (jcp.is_fused_conv) {
            mov(aux_reg_inp_buffer_ptr, reg_inp_buffer_ptr);
        } else {
            mov(aux_reg_inp, reg_inp);
        }
        mov(aux_reg_ker, reg_ker);
    }

    if ((jcp.signed_input || jcp.src_zero_point) && jcp.ndims > 3) {
        // Top overflow: t_overflow kernel rows handled in padded mode.
        mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
        cmp(reg_overflow, 0);
        je(no_t_overflow_label, T_NEAR);
        L(t_overflow_label);
        {
            compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);

            add(aux_reg_ker, shift_kernel_ptr);
            dec(reg_overflow);
            cmp(reg_overflow, 0);
            jg(t_overflow_label, T_NEAR);
        }
        L(no_t_overflow_label);
    }
    // Regular height loop: kh_padding iterations. The zero-trip check is
    // emitted only for the configurations selected below.
    mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
    if ((jcp.signed_input || jcp.src_zero_point) || (jcp.dilate_h >= jcp.ih)
            || (!(jcp.signed_input || jcp.src_zero_point)
                    && (jcp.kh - 1) * (jcp.dilate_h + 1)
                            < nstl::max(jcp.t_pad, jcp.b_pad))) {
        cmp(reg_kj, 0);
        je(skip_kh_loop, T_NEAR);
    }
    L(kh_label);
    {
        if (jcp.is_fused_conv) {
            // Row address = table entry + reg_inp.
            mov(aux_reg_inp, ptr[aux_reg_inp_buffer_ptr]);
            add(aux_reg_inp, reg_inp);
        }
        compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);

        add(aux_reg_ker, shift_kernel_ptr);
        if (jcp.is_fused_conv) {
            add(aux_reg_inp_buffer_ptr, sizeof(void *));
        } else {
            add(aux_reg_inp, shift_input_ptr * (jcp.dilate_h + 1));
        }
        dec(reg_kj);
        cmp(reg_kj, 0);
        jg(kh_label, T_NEAR);
    }
    L(skip_kh_loop);
    if ((jcp.signed_input || jcp.src_zero_point) && jcp.ndims > 3) {
        // Bottom overflow: b_overflow kernel rows handled in padded mode.
        mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
        cmp(reg_overflow, 0);
        je(no_b_overflow_label, T_NEAR);
        L(b_overflow_label);
        {
            compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);

            add(aux_reg_ker, shift_kernel_ptr);
            dec(reg_overflow);
            cmp(reg_overflow, 0);
            jg(b_overflow_label, T_NEAR);
        }
        L(no_b_overflow_label);
    }
    if (jcp.ndims == 5) {
        // Advance to the next kernel depth slice and close the depth loop.
        add(aux_reg_inp_d, shift_input_ptr * jcp.ih * (jcp.dilate_d + 1));
        add(aux_reg_ker_d, shift_kernel_ptr * jcp.kh);
        dec(reg_ki);
        jne(kd_label, T_NEAR);

        L(skip_kd_loop);
        if (jcp.signed_input || jcp.src_zero_point) {
            // Back overflow: mirrors the front overflow loop above.
            mov(reg_ki, ptr[param1 + GET_OFF(back_overflow)]);
            cmp(reg_ki, 0);
            je(no_back_overflow_label, T_NEAR);
            L(back_overflow_label);
            {
                mov(aux_reg_ker, aux_reg_ker_d);
                mov(reg_kj, jcp.kh);
                L(d_h_back_overflow_label);
                {
                    compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
                    add(aux_reg_ker, shift_kernel_ptr);
                    dec(reg_kj);
                    jne(d_h_back_overflow_label);
                }
                add(aux_reg_ker_d, shift_kernel_ptr * jcp.kh);
                dec(reg_ki);
                jne(back_overflow_label);
            }
            L(no_back_overflow_label);
        }
    }
}
812 | |
/// Emits the processing of one output tile: accumulator setup, the loop
/// over input channel blocks around kh_loop, and the final store.
///
/// Two variants of kh_loop and of store_output are emitted when the
/// configuration has a partial last block; the variant is selected at run
/// time by comparing `reg_icb` / `reg_oc_blocks` against the last block.
///
/// @param ur_w             Number of output width positions in the tile.
/// @param pad_l            Left padding applied to this tile.
/// @param pad_r            Right padding applied to this tile.
/// @param is_last_sp_block Selects last_sp_block instead of last_ic_block
///                         as the flag passed for the last input block.
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::icb_loop(
        int ur_w, int pad_l, int pad_r, bool is_last_sp_block) {
    prepare_output(ur_w);

    // IC loop
    Label icb_label;
    mov(reg_icb, jcp.nb_ic);
    L(icb_label);
    // The backward branch to icb_label is emitted only when there is more
    // than one block to process.
    const bool do_icb_loop
            = jcp.is_depthwise ? jcp.nb_ch > jcp.nb_ch_blocking : jcp.nb_ic > 1;
    if (jcp.ngroups % jcp.ch_block != 0 || jcp.ic_without_padding != jcp.ic) {
        // A partial last block exists: emit a dedicated kernel for it.
        Label common_ker, end_ker;
        if (do_icb_loop) {
            if (jcp.is_depthwise)
                cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
            else
                cmp(reg_icb, 1); // The last IC block
            jne(common_ker, T_NEAR);
        }
        kh_loop(ur_w, pad_l, pad_r,
                is_last_sp_block ? last_sp_block : last_ic_block);
        if (do_icb_loop) {
            jmp(end_ker, T_NEAR);

            L(common_ker);
            kh_loop(ur_w, pad_l, pad_r, no_last_block);

            L(end_ker);
        }
    } else {
        kh_loop(ur_w, pad_l, pad_r, no_last_block);
    }
    // End of IC Loop
    if (do_icb_loop) {
        // Advance source and weights pointers by one input channel block.
        int inp_step = jcp.ic_block;
        const size_t ker_step = (size_t)jcp.kd * jcp.kh * jcp.kw * jcp.oc_block
                * jcp.ic_block;
        add(reg_inp, jcp.typesize_in * inp_step);
        safe_add(reg_ker, jcp.typesize_in * ker_step, reg_ker_long_offt);

        dec(reg_icb);
        cmp(reg_icb, 0);
        jg(icb_label, T_NEAR);

        // Rewind both pointers to their values before the loop.
        sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic);
        safe_sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic,
                reg_ker_long_offt);
    }

    if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
        // A partial last output block exists: pick the tail-aware store at
        // run time when the current block is the last one.
        Label common_store, end_store;

        if (jcp.is_depthwise)
            cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
        else
            cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);

        jne(common_store, T_NEAR);

        store_output(ur_w, true); // last oc block
        jmp(end_store, T_NEAR);

        L(common_store);
        store_output(ur_w, false);

        L(end_store);
    } else {
        store_output(ur_w, false);
    }
}
884 | |
// Emits the complete kernel: preamble, register/pointer setup, the dispatch
// to the proper compute region (through a jump table when jcp.nb_ow > 1), the
// l_pad, middle, r_pad and ur_w_tail compute regions, and the postamble.
//
// When jcp.nb_ow > 1 a single ow block is computed per kernel call, so the
// layout of the jump table is first derived by a "dry run" over the output
// width and must agree with the labels bound while the regions are emitted
// (checked by the assert on label_cntr at the end).
template <cpu_isa_t isa, typename Vmm>
void _jit_uni_x8s8s32x_fwd_kernel<isa, Vmm>::generate() {
    Label permute_index_table;
    // Number of channel elements between two consecutive input columns.
    int in_ic_shift = jcp.is_fused_conv ? jcp.dw_conv_buffer_oc
                                        : jcp.ic_without_padding * jcp.ngroups;
    // Input columns covered by one unrolled step of ur_w output columns.
    const int urw_inp_stride = jcp.ur_w * jcp.stride_w;
    // Number of ur_w steps needed to cover the left padding, limited to the
    // number of full ur_w steps available in ow.
    const int n_urw_l_pad
            = nstl::min(div_up(jcp.l_pad, urw_inp_stride), jcp.ow / jcp.ur_w);
    // Input byte offset applied once when leaving the l_pad region.
    const int inp_shift_pad = nstl::max(0,
            jcp.typesize_in * (n_urw_l_pad * urw_inp_stride - jcp.l_pad)
                    * in_ic_shift);
    // Byte steps of the input / output pointers per ur_w step.
    int inp_shift = jcp.typesize_in * (jcp.ur_w * jcp.stride_w * in_ic_shift);
    int out_shift = jcp.typesize_out
            * (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups);
    preamble();

    if (jcp.is_depthwise) {
        // Pick the indices of the auxiliary depthwise vector registers,
        // counting down from the registers left after the accumulators.
        const bool is_zero_point = jcp.src_zero_point || jcp.dst_zero_point;
        // dst zero point and dst scale reuse the same register
        int idx = ker_max_reg + 1 - jcp.max_regs_ur
                - nstl::max(2 * is_zero_point, static_cast<int>(jcp.dst_scale));
        if (!jcp.is_resrc_depthwise) vmm_dw_src = Vmm(--idx);
        if (!jcp.has_vnni) vmm_dw_tmp = Vmm(--idx);
        if (jcp.signed_input) {
            --idx; // due to extra register used for compensations
        }
        assert(IMPLICATION(!jcp.dst_scale && !is_zero_point,
                idx == ker_max_reg - ker_dw_reg_base_idx));
    }

    if (!jcp.is_depthwise && (!jcp.has_vnni)) {
        // Broadcast 0x00010001 to every dword of vmm_one, i.e. each 16-bit
        // lane holds 1 (presumably the multiplier of the non-VNNI dot-product
        // sequence -- confirm against compute_ker).
        auto vmm_one_128 = Xbyak::Xmm(vmm_one.getIdx());
        mov(reg_scratch, 0x10001);
        uni_vmovq(vmm_one_128, reg_scratch);
        uni_vpbroadcastd(vmm_one, vmm_one_128);
    }

    if (jcp.is_fused_conv) {
        mov(reg_inp_buffer_ptr, ptr[param1 + GET_OFF(src)]);
        /* In case of fused depthwise convolution, `param.src` is not a pointer
        to input, instead it points to a buffer containing pointers to
        consecutive rows of input in format wc with c=jcp.dw_conv_buffer_oc.
        Example: [ptr_to_inp_row0, ptr_to_inp_row1, ptr_to_inp_row2].
        Traverse the data as
            mov(reg_data, ptr[reg_input_buffer_ptr])
            ... process row0 ...
            add(reg_input_buffer_ptr, sizeof(void*))
            mov(reg_data, ptr[reg_input_buffer_ptr])
            ... process row1 ...
            add(reg_input_buffer_ptr, sizeof(void*))
            mov(reg_data, ptr[reg_input_buffer_ptr])
            ... process row2 ...
        */
        xor_(reg_inp, reg_inp);
    } else {
        mov(reg_inp, ptr[param1 + GET_OFF(src)]);
    }
    mov(reg_out, ptr[param1 + GET_OFF(dst)]);
    mov(reg_ker, ptr[param1 + GET_OFF(filt)]);

    // reg_oc_blocks is only loaded when a channel tail exists; this is the
    // same condition under which icb_loop() compares against it for the store.
    if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
        mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
    }

    const int extended_filter_size
            = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
    const int r_pad = nstl::max(0, jcp.r_pad);
    // Number of output columns that can be computed without right padding.
    const int ow_with_no_rpad = 1
            + (jcp.iw + jcp.l_pad + nstl::min(0, jcp.r_pad)
                      - extended_filter_size)
                    / jcp.stride_w;
    const int n_urw_per_ow_block = jcp.ow_block / jcp.ur_w;
    // Output columns beyond max_safe_ow are emitted with the
    // is_last_sp_block flag of icb_loop() set (see the calls below).
    const int max_safe_iw = nstl::max(
            0, jcp.iw - div_up(ic_sub_step, jcp.ic_without_padding));
    const int max_safe_ow = jcp.ic_without_padding % ic_sub_step == 0
            ? jcp.ow
            : (max_safe_iw + jcp.l_pad - extended_filter_size) / jcp.stride_w;
    Label middle_block_label, done_compute;
    std::vector<Label> ow_block_jmp_table;

    // r_pad_fall_through is a special ow_block, where the block overlaps
    // both middle_block and r_pad/ur_w_tail region when it exists.
    // The number of ur_w's to compute in middle_block before executing
    // r_pad region is stored in r_pad_fall_through_n_urw and the ow_block
    // number is stored in r_pad_fall_through_ow_block.
    int r_pad_fall_through_ow_block = 0;
    int r_pad_fall_through_n_urw = 0;

    if (jcp.nb_ow > 1) {
        // Only one ow block is processed, per jit call.
        // Number of this ow block is passed as parameter owb,
        // and padding processing depends on this number.
        //
        // The compute block to run is determined by using a jmp-table.
        // jmp-table Layout:
        //  idx -> addr
        //  0   -> [...l_pad_region label[0]...]
        //         : : : : : : : : : : : : : : :
        //  L   -> [...l_pad_region label[L]...]
        //  L+1 -> [...r_pad_region label[0]...]
        //         : : : : : : : : : : : : : : :
        //  L+R -> [...r_pad_region label[R]...]
        //
        // Note: Label for middle_block is not stored in the jmp-table.
        //
        // During jit call, the jump address is calculated as below:
        // if (owb < n_l_pad_labels) {
        //   jmp([jmp_table + owb*sizeof(void*)]);
        // } else if (owb < X) {
        //   // X is the number of ow_blocks before r_pad region (see below).
        //   jmp(middle_block);
        // } else {
        //   sub(owb, X);
        //   jmp([jmp_table + owb*sizeof(void*)
        //           + n_l_pad_labels*sizeof(void*)]);
        // }
        //
        // To configure the jmp-table, we need to determine some constants
        // (namely, r_pad_fall_through_n_urw, r_pad_fall_through_ow_block,
        // n_l_pad_labels, n_labels) ahead of writing the compute assembly. So,
        // we simulate the filter path without writing the assembly initially.
        // This makes the math for calculating the constants become simple and
        // self explanatory.

        // Begin simulation without writing assembly
        int n_l_pad_labels = 0;
        int n_labels = 0;
        int cur_ow = 0;

        // l_pad region:
        n_l_pad_labels = div_up(n_urw_l_pad, n_urw_per_ow_block);
        n_labels = n_l_pad_labels;
        cur_ow += n_urw_l_pad * jcp.ur_w;

        // middle_region:
        int n_urw_middle_block_loop = 0;
        int cur_r_pad = nstl::max(0,
                calculate_end_padding(jcp.l_pad, cur_ow + jcp.ur_w, jcp.iw,
                        jcp.stride_w, extended_filter_size));
        if (cur_ow + jcp.ur_w <= jcp.ow && cur_r_pad == 0) {
            n_urw_middle_block_loop
                    = nstl::max(0,
                              nstl::min(ow_with_no_rpad, max_safe_ow) - cur_ow)
                    / jcp.ur_w;
            cur_ow += n_urw_middle_block_loop * jcp.ur_w;
        }
        // Position (ow block and ur_w step within it) where the middle
        // region ends.
        r_pad_fall_through_n_urw = (cur_ow / jcp.ur_w) % n_urw_per_ow_block;
        r_pad_fall_through_ow_block = cur_ow / (n_urw_per_ow_block * jcp.ur_w);

        // r_pad or last_sp_block
        if (cur_ow + jcp.ur_w <= jcp.ow) {
            if (r_pad_fall_through_n_urw == 0) ++n_labels;
            const int n_urw_r_pad_region = (jcp.ow - cur_ow) / jcp.ur_w;
            n_labels += nstl::max(0,
                    div_up(r_pad_fall_through_n_urw + n_urw_r_pad_region,
                            n_urw_per_ow_block)
                            - 1);
        }

        if (jcp.ur_w_tail != 0) {
            if (jcp.ow % jcp.ow_block == jcp.ur_w_tail) ++n_labels;
        }
        // End of simulation

        ow_block_jmp_table.resize(n_labels);

        // Begin jump-table logic
        Label ow_block_jmp_table_label;
        if (!ow_block_jmp_table.empty())
            mov(reg_jmp_tbl_base, ow_block_jmp_table_label);
        mov(reg_oi, n_urw_per_ow_block);
        mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
        if (jcp.l_pad > 0) {
            Label middle_or_rpad_check;
            cmp(reg_owb, n_l_pad_labels);
            jge(middle_or_rpad_check, T_NEAR);
            jmp(ptr[reg_jmp_tbl_base + reg_owb * sizeof(void *)]);
            L(middle_or_rpad_check);
            // harness passes shifted src pointer that does not take
            // left-padding into account. So, we must re-shift here.
            const int inp_shift_pad_middle_block = -1 * jcp.typesize_in
                    * nstl::min(jcp.l_pad, n_urw_l_pad * urw_inp_stride)
                    * in_ic_shift;
            add(reg_inp, inp_shift_pad_middle_block);
        }
        if (r_pad_fall_through_n_urw != 0) {
            // In the fall-through block only r_pad_fall_through_n_urw steps
            // of the middle loop are executed.
            mov(reg_scratch, r_pad_fall_through_n_urw);
            cmp(reg_owb, r_pad_fall_through_ow_block);
            cmove(reg_oi, reg_scratch);
            if (n_urw_middle_block_loop > 0) {
                sub(reg_owb, r_pad_fall_through_ow_block);
                // simple middle_block
                jle(middle_block_label, T_NEAR);
                dec(reg_owb);
            } else {
                sub(reg_owb, r_pad_fall_through_ow_block + 1);
            }
        } else {
            sub(reg_owb, r_pad_fall_through_ow_block);
            // simple middle_block
            if (n_urw_middle_block_loop) jl(middle_block_label, T_NEAR);
        }
        // r_pad-only region
        if (!ow_block_jmp_table.empty())
            jmp(ptr[reg_jmp_tbl_base + reg_owb * sizeof(void *)
                    + n_l_pad_labels * sizeof(void *)]);

        if (!ow_block_jmp_table.empty()) {
            // The table itself is placed in the code stream: one label
            // address per entry.
            align(8);
            L(ow_block_jmp_table_label);
            {
                for (size_t i = 0; i < ow_block_jmp_table.size(); ++i) {
                    putL(ow_block_jmp_table[i]);
                }
            }
        }
        // End of jump-table logic
    }

    // Begin kernel
    int cur_ow = 0;
    int cur_n_oi = 0; // used only for jcp.nb_ow > 1 scenario
    int label_cntr = 0;
    int cur_l_pad = 0;
    if (jcp.l_pad > 0) {
        // l_pad region: one icb_loop per ur_w step, with the remaining left
        // padding shrinking by urw_inp_stride on every step.
        for (cur_l_pad = jcp.l_pad;
                cur_l_pad > 0 && cur_ow + jcp.ur_w <= jcp.ow;
                cur_l_pad -= urw_inp_stride) {
            if (jcp.nb_ow > 1 && cur_n_oi == 0) {
                // cur_n_oi == 0 signifies beginning of new ow_block
                // (or end of previous block)
                const dim_t inp_lpad_region_shift = -label_cntr * jcp.ow_block
                        * jcp.stride_w * in_ic_shift;
                L(ow_block_jmp_table[label_cntr++]);
                // harness passes shifted src pointer that does not take
                // left-padding into account. So, we must re-shift here.
                add(reg_inp, inp_lpad_region_shift);
            }

            cur_ow += jcp.ur_w;
            int cur_r_pad = nstl::max(0,
                    calculate_end_padding(jcp.l_pad, cur_ow, jcp.iw,
                            jcp.stride_w, extended_filter_size));
            icb_loop(jcp.ur_w, cur_l_pad, cur_r_pad, cur_ow > max_safe_ow);
            add(reg_out, out_shift);
            dec(reg_oi);

            if (jcp.nb_ow > 1 && ++cur_n_oi == n_urw_per_ow_block) {
                // We compute one owb per jit call. So, insert an
                // unconditional jmp, after computing one owb.
                jmp(done_compute, T_NEAR);
                cur_n_oi = 0;
            }
        }
        if (jcp.nb_ow == 1 || cur_n_oi != 0) {
            // Let it "fall-through" middle_block_label
            add(reg_inp, inp_shift_pad);
        }
    }

    // middle_block
    {
        int cur_r_pad = nstl::max(0,
                calculate_end_padding(jcp.l_pad, cur_ow + jcp.ur_w, jcp.iw,
                        jcp.stride_w, extended_filter_size));
        if (cur_r_pad == 0 && cur_ow + jcp.ur_w <= jcp.ow) {
            // Same count as n_urw_middle_block_loop in the simulation above.
            int n_oi_middle_block_loop
                    = nstl::max(0,
                              nstl::min(ow_with_no_rpad, max_safe_ow) - cur_ow)
                    / jcp.ur_w;
            // With several ow blocks reg_oi was already set by the
            // jump-table logic.
            if (jcp.nb_ow == 1 && n_oi_middle_block_loop > 1)
                mov(reg_oi, n_oi_middle_block_loop);
            L(middle_block_label);
            if (n_oi_middle_block_loop > 0) {
                icb_loop(jcp.ur_w, 0, 0, false);
                add(reg_inp, inp_shift);
                add(reg_out, out_shift);
                if (n_oi_middle_block_loop > 1) {
                    dec(reg_oi);
                    jg(middle_block_label, T_NEAR);
                }
            }
            cur_ow += n_oi_middle_block_loop * jcp.ur_w;
            cur_n_oi = (cur_n_oi + n_oi_middle_block_loop) % n_urw_per_ow_block;
        }
    }

    // r_pad region or last_sp_block
    if (cur_ow + jcp.ur_w <= jcp.ow) {
        if (jcp.nb_ow > 1) {
            if (cur_n_oi == 0) {
                jmp(done_compute, T_NEAR);
            } else {
                // r_pad fall-through
                mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
                cmp(reg_owb, r_pad_fall_through_ow_block);
                jne(done_compute, T_NEAR);
            }
        }

        while (cur_ow + jcp.ur_w <= jcp.ow) {
            if (jcp.nb_ow > 1 && cur_n_oi == 0) {
                L(ow_block_jmp_table[label_cntr++]);
            }
            cur_ow += jcp.ur_w;
            int cur_r_pad = calculate_end_padding(jcp.l_pad, cur_ow, jcp.iw,
                    jcp.stride_w, extended_filter_size);
            assert(cur_r_pad > 0 || cur_ow > max_safe_ow); // else, why be here?
            icb_loop(jcp.ur_w, 0, cur_r_pad, cur_ow > max_safe_ow);
            add(reg_inp, inp_shift);
            add(reg_out, out_shift);

            if (jcp.nb_ow > 1 && ++cur_n_oi == n_urw_per_ow_block) {
                // We compute one owb per jit call. So, insert an
                // unconditional jmp, after computing one owb.
                jmp(done_compute, T_NEAR);
                cur_n_oi = 0;
            }
        }
        // Let it fall-through ur_w_tail
    }

    // ur_w_tail
    if (jcp.ur_w_tail != 0) {
        if (jcp.nb_ow > 1) {
            if (cur_n_oi == 0) {
                jmp(done_compute, T_NEAR);
                L(ow_block_jmp_table[label_cntr++]);
            } else {
                // In case, when there is no r_pad region, then there exists an
                // ambiguity btw middle_blocks and r_pad_fall_through_ow_block.
                // If not properly distinguished, there can be a race condition
                // as middle_blocks and r_pad_fall_through_ow_block both try to
                // compute ur_w_tail work at the end.
                mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
                cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
                jne(done_compute, T_NEAR);
            }
        }
        icb_loop(jcp.ur_w_tail, nstl::max(0, cur_l_pad), r_pad, true);
    }
    L(done_compute);
    // Every jump-table entry reserved by the simulation must have been bound.
    assert(ow_block_jmp_table.size() == static_cast<size_t>(label_cntr));

    postamble();

    if (jcp.with_eltwise) postops_injector_->prepare_table();
}
1232 | |
1233 | template <cpu_isa_t isa> |
1234 | status_t jit_uni_x8s8s32x_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp, |
1235 | const convolution_desc_t &cd, memory_desc_t &src_md, |
1236 | memory_desc_t &weights_md, memory_desc_t &dst_md, |
1237 | memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) { |
1238 | using namespace prop_kind; |
1239 | |
1240 | const memory_desc_wrapper src_d(&src_md); |
1241 | const memory_desc_wrapper weights_d(&weights_md); |
1242 | const memory_desc_wrapper dst_d(&dst_md); |
1243 | const memory_desc_wrapper bias_d(&bias_md); |
1244 | |
1245 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
1246 | const int ndims = src_d.ndims(); |
1247 | const bool is_1d = ndims == 3; |
1248 | const bool is_2d = ndims == 4; |
1249 | const bool is_3d = ndims == 5; |
1250 | const bool is_avx2 = isa == avx2; |
1251 | assert(is_1d || is_2d || is_3d); |
1252 | |
1253 | if (!(mayiuse(isa) |
1254 | && one_of(src_d.data_type(), data_type::u8, data_type::s8) |
1255 | && weights_d.data_type() == data_type::s8 |
1256 | && one_of(dst_d.data_type(), data_type::f32, data_type::s32, |
1257 | data_type::s8, data_type::u8))) |
1258 | return status::unimplemented; |
1259 | |
1260 | jcp = zero<decltype(jcp)>(); |
1261 | jcp.nthr = nthreads; |
1262 | jcp.ndims = ndims; |
1263 | jcp.prop_kind = cd.prop_kind; |
1264 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
1265 | jcp.mb = src_d.dims()[0]; |
1266 | jcp.oc = dst_d.dims()[1] / jcp.ngroups; |
1267 | jcp.oc_without_padding = jcp.oc; |
1268 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
1269 | jcp.ic_without_padding = jcp.ic; |
1270 | jcp.id = is_3d ? src_d.dims()[2] : 1; |
1271 | jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; |
1272 | jcp.iw = src_d.dims()[ndims - 1]; |
1273 | jcp.od = is_3d ? dst_d.dims()[2] : 1; |
1274 | jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; |
1275 | jcp.ow = dst_d.dims()[ndims - 1]; |
1276 | jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1; |
1277 | jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; |
1278 | jcp.kw = weights_d.dims()[with_groups + ndims - 1]; |
1279 | jcp.f_pad = is_3d ? cd.padding[0][0] : 0; |
1280 | jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; |
1281 | jcp.l_pad = cd.padding[0][ndims - 3]; |
1282 | jcp.stride_d = is_3d ? cd.strides[0] : 1; |
1283 | jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; |
1284 | jcp.stride_w = cd.strides[ndims - 3]; |
1285 | jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; |
1286 | |
1287 | jcp.ur_h = 1; /* no code-unrolling by h so far */ |
1288 | |
1289 | jcp.dilate_d = is_3d ? cd.dilates[0] : 0; |
1290 | jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; |
1291 | jcp.dilate_w = cd.dilates[ndims - 3]; |
1292 | |
1293 | int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w); |
1294 | int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h); |
1295 | int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d); |
1296 | jcp.r_pad = calculate_end_padding( |
1297 | jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw); |
1298 | jcp.b_pad = calculate_end_padding( |
1299 | jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh); |
1300 | jcp.back_pad = calculate_end_padding( |
1301 | jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd); |
1302 | |
1303 | jcp.has_vnni = mayiuse(avx2_vnni); |
1304 | |
1305 | jcp.signed_input = src_d.data_type() == data_type::s8; |
1306 | jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc); |
1307 | |
1308 | const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC); |
1309 | const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); |
1310 | const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); |
1311 | |
1312 | const int wei_mask_per_oc = 1 << (int)with_groups; |
1313 | jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc; |
1314 | jcp.dst_scale = !dst_scales.has_default_values(); |
1315 | |
1316 | const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc) |
1317 | && everyone_is(src_scales.mask_, dst_scales.mask_, 0); |
1318 | if (!scales_ok) return status::unimplemented; |
1319 | |
1320 | const auto zp = attr.zero_points_; |
1321 | jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST); |
1322 | jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC); |
1323 | jcp.zp_src_is_common |
1324 | = zp.common(DNNL_ARG_SRC); // otherwise, it's per-channel |
1325 | assert(IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common)); |
1326 | |
1327 | if ((jcp.dst_zero_point || jcp.src_zero_point) && jcp.is_fused_conv) |
1328 | return status::unimplemented; |
1329 | |
1330 | if (is_3d && jcp.is_depthwise) return status::unimplemented; |
1331 | |
1332 | if (jcp.is_depthwise) { |
1333 | jcp.ch_block = is_avx2 ? 8 : 4; |
1334 | jcp.ic_block = 1; |
1335 | jcp.oc_block = 1; |
1336 | } else { |
1337 | jcp.ch_block = 1; |
1338 | jcp.ic_block = is_avx2 ? 8 : 4; |
1339 | jcp.oc_block = is_avx2 ? 8 : 4; |
1340 | |
1341 | if (jcp.ngroups == 1) { |
1342 | /* For non grouped convolutions, pad channels by 8 if needed */ |
1343 | jcp.oc = rnd_up(jcp.oc, jcp.oc_block); |
1344 | jcp.ic = rnd_up(jcp.ic, jcp.ic_block); |
1345 | } else if (jcp.ngroups != 1 |
1346 | && ((jcp.ic % jcp.ic_block != 0) |
1347 | || (jcp.oc % jcp.oc_block != 0))) { |
1348 | /* For grouped convolutions, oneDNN doesn't support padding. |
1349 | * When channels per group is not multiple of 8: |
1350 | * - Use Xmm when channels per group is multiple of 4. |
1351 | * - Otherwise return unimplemented */ |
1352 | jcp.oc_block = jcp.ic_block = 4; |
1353 | } |
1354 | if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0) |
1355 | return status::unimplemented; |
1356 | } |
1357 | |
1358 | jcp.is_resrc_depthwise = true && jcp.is_depthwise && jcp.stride_w < jcp.kw |
1359 | && jcp.kw < 4 && jcp.dilate_w == 0; |
1360 | |
1361 | if (jcp.is_depthwise) { |
1362 | jcp.max_regs_ur = 14 - !jcp.is_resrc_depthwise - jcp.signed_input |
1363 | + (jcp.has_vnni); |
1364 | } else { |
1365 | jcp.max_regs_ur = jcp.has_vnni ? 15 - jcp.signed_input : 12; |
1366 | } |
1367 | |
1368 | if (jcp.dst_scale) jcp.max_regs_ur = 10; |
1369 | if (jcp.src_zero_point || jcp.dst_zero_point) jcp.max_regs_ur = 9; |
1370 | |
1371 | auto set_or_check_wei_format = [&]() { |
1372 | using namespace format_tag; |
1373 | using namespace memory_extra_flags; |
1374 | const int c_mask = 0x1, |
1375 | g_mask = 0x3; // mask for i/o-channel and ngroups |
1376 | format_tag_t wei_tag; |
1377 | if (jcp.ic_block == 8 || jcp.ch_block == 8) { |
1378 | if (is_1d) { |
1379 | wei_tag = with_groups ? jcp.is_depthwise ? Goiw8g : gOIw2i8o4i |
1380 | : OIw2i8o4i; |
1381 | } else if (is_2d) { |
1382 | wei_tag = with_groups ? jcp.is_depthwise ? Goihw8g : gOIhw2i8o4i |
1383 | : OIhw2i8o4i; |
1384 | } else { |
1385 | wei_tag = with_groups ? gOIdhw2i8o4i : OIdhw2i8o4i; |
1386 | } |
1387 | } else { |
1388 | if (is_avx2) { |
1389 | assert(with_groups && jcp.ic_block == 4); |
1390 | wei_tag = is_3d ? gOIdhw4o4i : is_2d ? gOIhw4o4i : gOIw4o4i; |
1391 | } else { |
1392 | if (is_1d) { |
1393 | wei_tag = with_groups ? jcp.is_depthwise ? Goiw4g : gOIw4o4i |
1394 | : OIw4o4i; |
1395 | } else if (is_2d) { |
1396 | wei_tag = with_groups |
1397 | ? jcp.is_depthwise ? Goihw4g : gOIhw4o4i |
1398 | : OIhw4o4i; |
1399 | } else { |
1400 | wei_tag = with_groups ? gOIdhw4o4i : OIdhw4o4i; |
1401 | } |
1402 | } |
1403 | } |
1404 | memory_desc_t want_wei_md = weights_md; |
1405 | memory_desc_init_by_tag(want_wei_md, wei_tag); |
1406 | |
1407 | if (jcp.signed_input) { |
1408 | want_wei_md.extra.flags = 0 |
1409 | | memory_extra_flags::compensation_conv_s8s8 |
1410 | | memory_extra_flags::scale_adjust; |
1411 | want_wei_md.extra.compensation_mask |
1412 | = (with_groups && !jcp.is_depthwise) ? g_mask : c_mask; |
1413 | want_wei_md.extra.scale_adjust = (jcp.has_vnni) ? 1.f : 0.5f; |
1414 | } |
1415 | if (jcp.src_zero_point) { |
1416 | want_wei_md.extra.flags |= compensation_conv_asymmetric_src; |
1417 | want_wei_md.extra.asymm_compensation_mask |
1418 | = (with_groups && !jcp.is_depthwise) ? g_mask : c_mask; |
1419 | } |
1420 | |
1421 | if (weights_md.format_kind == format_kind::any) { |
1422 | weights_md = want_wei_md; |
1423 | return true; |
1424 | } |
1425 | |
1426 | return weights_md == want_wei_md; |
1427 | }; |
1428 | |
1429 | if (!set_or_check_wei_format()) return status::unimplemented; |
1430 | format_tag_t dat_tag = utils::pick( |
1431 | ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
1432 | |
1433 | if (src_d.format_kind() == format_kind::any) { |
1434 | CHECK(memory_desc_init_by_tag(src_md, dat_tag)); |
1435 | jcp.src_tag = dat_tag; |
1436 | } else { |
1437 | jcp.src_tag = src_d.matches_one_of_tag(dat_tag); |
1438 | } |
1439 | if (jcp.src_tag != dat_tag) return status::unimplemented; |
1440 | |
1441 | if (dst_d.format_kind() == format_kind::any) { |
1442 | CHECK(memory_desc_init_by_tag(dst_md, dat_tag)); |
1443 | jcp.dst_tag = dat_tag; |
1444 | } else { |
1445 | jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); |
1446 | } |
1447 | if (jcp.dst_tag != dat_tag) return status::unimplemented; |
1448 | |
1449 | if (jcp.with_bias) { |
1450 | if (bias_d.format_kind() == format_kind::any) |
1451 | CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); |
1452 | } |
1453 | jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; |
1454 | jcp.dst_dt = cd.dst_desc.data_type; |
1455 | |
1456 | CHECK(attr.set_default_formats(&dst_md)); |
1457 | |
1458 | const auto &post_ops = attr.post_ops_; |
1459 | |
1460 | const int eltwise_ind = post_ops.find(primitive_kind::eltwise); |
1461 | jcp.with_eltwise = eltwise_ind != -1; |
1462 | const int binary_ind = post_ops.find(primitive_kind::binary); |
1463 | jcp.with_binary = binary_ind != -1; |
1464 | const int sum_ind = post_ops.find(primitive_kind::sum); |
1465 | jcp.with_sum = sum_ind != -1; |
1466 | jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt); |
1467 | |
1468 | jcp.post_ops = post_ops; |
1469 | |
1470 | using namespace injector; |
1471 | |
1472 | const bool post_ops_ok_ = post_ops_ok({isa, {eltwise, binary, sum}, |
1473 | jcp.post_ops, &dst_d, false, false, false}); |
1474 | if (!post_ops_ok_) return status::unimplemented; |
1475 | |
1476 | jcp.typesize_in = types::data_type_size(src_d.data_type()); |
1477 | jcp.typesize_out = types::data_type_size(dst_d.data_type()); |
1478 | jcp.typesize_bia |
1479 | = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; |
1480 | |
1481 | jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); |
1482 | jcp.nb_ic = jcp.ic / jcp.ic_block; |
1483 | jcp.nb_oc = jcp.oc / jcp.oc_block; |
1484 | |
1485 | // Try to use 4 channel-groups at a time to avoid false sharing (depthwise) |
1486 | int nb_ch_blocking = 4; |
1487 | for (/* init above */; nb_ch_blocking > 1; --nb_ch_blocking) |
1488 | if (jcp.nb_ch % nb_ch_blocking == 0) break; |
1489 | jcp.nb_ch_blocking = jcp.is_depthwise ? nb_ch_blocking : 1; |
1490 | |
1491 | // If OC blocking is incommensurate with the number of OC blocks (general |
1492 | // requirement for all convolutions), or if it results in an unrolling |
1493 | // factor smaller than the left padding (special requirement for SSD:fc6), |
1494 | // then search for a smaller OC blocking that satisfies both constraints. |
1495 | auto is_oc_blocking_ok = [&](int block) { |
1496 | int ur_w = nstl::min(jcp.ow, jcp.max_regs_ur / (block + 1)); |
1497 | return jcp.nb_oc % block == 0 && jcp.l_pad <= ur_w |
1498 | && jcp.ow % ur_w != 1; |
1499 | }; |
1500 | |
1501 | // choose nb_oc work chunk size for distribution within threads |
1502 | int max_threading_nb_oc_chunk = 4; |
1503 | |
1504 | jcp.nb_oc_blocking_thr_chunk |
1505 | = nstl::min(max_threading_nb_oc_chunk, jcp.nb_oc); |
1506 | for (; jcp.nb_oc_blocking_thr_chunk > 1; --jcp.nb_oc_blocking_thr_chunk) { |
1507 | if (is_oc_blocking_ok(jcp.nb_oc_blocking_thr_chunk)) break; |
1508 | } |
1509 | |
1510 | auto get_thr_eff = [=](int nb_ow, int nthr) { |
1511 | int base_work_amount = jcp.mb * jcp.nb_ch * jcp.od * jcp.oh |
1512 | * (jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk); |
1513 | auto work_amount = base_work_amount * nb_ow; |
1514 | return float(work_amount) / rnd_up(work_amount, nthr); |
1515 | }; |
1516 | |
1517 | auto get_ow_block = [=](int ur_w, int nthr) { |
1518 | int res_ow_block = jcp.ow; |
1519 | float best_thr_eff = get_thr_eff(1, nthr); |
1520 | float thr_eff; |
1521 | |
1522 | int max_nb_ow = div_up(jcp.ow, ur_w); |
1523 | for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) { |
1524 | int ow_block |
1525 | = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow); |
1526 | if (ow_block < jcp.nb_oc_blocking_thr_chunk * jcp.oc_block |
1527 | && best_thr_eff > 0.8f) |
1528 | break; |
1529 | if (div_up(jcp.ow, ow_block) != nb_ow) continue; |
1530 | thr_eff = get_thr_eff(nb_ow, nthr); |
1531 | if (ow_block >= ur_w && thr_eff > 1.1f * best_thr_eff) { |
1532 | res_ow_block = ow_block; |
1533 | best_thr_eff = thr_eff; |
1534 | } |
1535 | if (best_thr_eff > 0.9f) break; |
1536 | } |
1537 | return res_ow_block; |
1538 | }; |
1539 | |
1540 | // choose oc blocking for computational kernel |
1541 | jcp.nb_oc_blocking = jcp.nb_oc_blocking_thr_chunk; |
1542 | |
1543 | if (jcp.is_resrc_depthwise) |
1544 | jcp.ur_w = (jcp.max_regs_ur - jcp.kw + jcp.stride_w) |
1545 | / (jcp.nb_ch_blocking + jcp.stride_w); |
1546 | else |
1547 | jcp.ur_w = jcp.max_regs_ur |
1548 | / (jcp.is_depthwise ? jcp.nb_ch_blocking |
1549 | : jcp.nb_oc_blocking + 1); |
1550 | if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow; |
1551 | jcp.ur_w_tail = jcp.ow % jcp.ur_w; |
1552 | |
1553 | jcp.ow_block = jcp.ow; |
1554 | |
1555 | jcp.ow_block = get_ow_block(jcp.ur_w, jcp.nthr); |
1556 | jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); |
1557 | float thr_eff = get_thr_eff(jcp.nb_ow, jcp.nthr); |
1558 | |
1559 | /* adjust the thread decomposition |
1560 | * to improve the thr_eff for small size problem |
1561 | * the threshold L1_cache_size is empirical */ |
1562 | size_t wei_size |
1563 | = sizeof(float) * jcp.ic * jcp.oc * jcp.kh * jcp.kw * jcp.kd; |
1564 | size_t out_size |
1565 | = jcp.mb * jcp.typesize_out * jcp.oc * jcp.oh * jcp.ow * jcp.od; |
1566 | size_t inp_size |
1567 | = jcp.mb * jcp.typesize_in * jcp.ic * jcp.ih * jcp.iw * jcp.id; |
1568 | size_t total_size = jcp.ngroups * (wei_size + out_size + inp_size); |
1569 | const unsigned int L1_cache_size = platform::get_per_core_cache_size(1); |
1570 | |
1571 | if (jcp.ngroups < jcp.nthr && (total_size < L1_cache_size)) { |
1572 | int ow_block = jcp.ow_block; |
1573 | float best_thr_eff = thr_eff; |
1574 | float eff = -1.0f; |
1575 | int end_nthr = with_groups ? jcp.ngroups : 1; |
1576 | for (int nthr = jcp.nthr / 2; nthr >= end_nthr; nthr--) { |
1577 | ow_block = get_ow_block(jcp.ur_w, nthr); |
1578 | eff = get_thr_eff(div_up(jcp.ow, ow_block), nthr); |
1579 | if (eff > 1.1f * best_thr_eff) { |
1580 | best_thr_eff = eff; |
1581 | jcp.ow_block = ow_block; |
1582 | jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); |
1583 | jcp.nthr = jcp.aligned_threads = nthr; |
1584 | if (best_thr_eff > 0.9f) break; |
1585 | } |
1586 | } |
1587 | } |
1588 | |
1589 | bool args_ok = true && jcp.oc % jcp.oc_block == 0 |
1590 | && IMPLICATION(!jcp.is_1stconv, jcp.ic % jcp.ic_block == 0); |
1591 | if (!args_ok) return status::unimplemented; |
1592 | |
1593 | pick_loop_order(jcp); |
1594 | |
1595 | jcp.nb_ic_L2 = jcp.nb_ic; |
1596 | |
1597 | jcp.wei_adj_scale |
1598 | = (weights_d.extra().flags & memory_extra_flags::scale_adjust) |
1599 | ? weights_d.extra().scale_adjust |
1600 | : 1.f; |
1601 | |
1602 | return status::success; |
1603 | } |
1604 | |
1605 | template <cpu_isa_t isa> |
1606 | void jit_uni_x8s8s32x_fwd_kernel<isa>::init_scratchpad( |
1607 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, |
1608 | const primitive_attr_t &attr) { |
1609 | |
1610 | const int mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_; |
1611 | const dim_t scales_count = mask == 0 ? 1 : jcp.oc * jcp.ngroups; |
1612 | dim_t count = scales_count == 1 ? (dim_t)8 : scales_count; |
1613 | scratchpad.book<float>(key_conv_adjusted_scales, count); |
1614 | } |
1615 | |
// Explicit instantiations of the kernel templates for the ISA / vector
// register combinations provided by this translation unit.
template struct _jit_uni_x8s8s32x_fwd_kernel<avx2, Ymm>;
template struct _jit_uni_x8s8s32x_fwd_kernel<avx2, Xmm>;
template struct _jit_uni_x8s8s32x_fwd_kernel<sse41, Xmm>;
template struct jit_uni_x8s8s32x_fwd_kernel<avx2>;
template struct jit_uni_x8s8s32x_fwd_kernel<sse41>;
1621 | |
1622 | } // namespace x64 |
1623 | } // namespace cpu |
1624 | } // namespace impl |
1625 | } // namespace dnnl |
1626 | |