1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/memory.hpp" |
19 | #include "common/memory_tracking.hpp" |
20 | #include "common/nstl.hpp" |
21 | #include "common/type_helpers.hpp" |
22 | #include "common/utils.hpp" |
23 | |
24 | #include "cpu/platform.hpp" |
25 | #include "cpu/x64/injectors/injector_utils.hpp" |
26 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
27 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
28 | #include "cpu/x64/jit_avx512_core_x8s8s32x_conv_kernel.hpp" |
29 | |
30 | #define GET_OFF(field) offsetof(jit_conv_call_s, field) |
31 | |
32 | namespace dnnl { |
33 | namespace impl { |
34 | namespace cpu { |
35 | namespace x64 { |
36 | |
37 | using namespace dnnl::impl::memory_tracking::names; |
38 | using namespace dnnl::impl::utils; |
39 | using namespace dnnl::impl::data_type; |
40 | using namespace Xbyak; |
41 | |
42 | namespace { |
43 | void pick_loop_order(jit_conv_conf_t &jcp, int nthr) { |
44 | jcp.loop_order = loop_cwgn; |
45 | if (jcp.ngroups > 1) { |
46 | jcp.loop_order = loop_ngcw; |
47 | if (jcp.mb < nthr) |
48 | jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg; |
49 | } else if (jcp.mb >= nthr && jcp.ic_without_padding <= 16) { |
50 | jcp.loop_order = loop_ngcw; |
51 | } |
52 | } |
53 | } // namespace |
54 | |
55 | template <typename Vmm> |
56 | _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::_jit_avx512_core_x8s8s32x_fwd_kernel( |
57 | const jit_conv_conf_t &ajcp, const primitive_attr_t &attr, |
58 | const memory_desc_t &dst_md) |
59 | : jit_generator(jit_name()) |
60 | , jcp(ajcp) |
61 | , attr_(attr) |
62 | , postops_injector_(nullptr) { |
63 | if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { |
64 | using namespace binary_injector; |
65 | static constexpr bool preserve_gpr = true; |
66 | static constexpr bool preserve_vmm = false; |
67 | static constexpr size_t helper_vmm_idx = 31; |
68 | const size_t oc_block_tail = jcp.oc_block % isa_simd_width_; |
69 | const size_t tail_size = oc_block_tail |
70 | ? oc_block_tail |
71 | : jcp.oc_without_padding % isa_simd_width_; |
72 | static constexpr bool use_exact_tail_scalar_bcast = false; |
73 | |
74 | const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx, |
75 | r14, r15, r13, preserve_gpr, preserve_vmm, |
76 | GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig), |
77 | memory_desc_wrapper(dst_md), tail_size, postops_mask, |
78 | use_exact_tail_scalar_bcast}; |
79 | const static_params_t static_params { |
80 | this->param1, rhs_arg_static_params}; |
81 | |
82 | postops_injector_ = utils::make_unique< |
83 | injector::jit_uni_postops_injector_t<avx512_core, Vmm>>( |
84 | this, jcp.post_ops, static_params); |
85 | } |
86 | if (!isa_has_bf16(jcp.isa) && jcp.dst_dt == data_type::bf16) |
87 | bf16_emu_ = utils::make_unique<bf16_emulation_t>(this, |
88 | bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3, |
89 | bf16_emu_scratch, bf16_emu_reserv_4, bf16_emu_reserv_4); |
90 | } |
91 | |
92 | template <typename Vmm> |
93 | void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::prepare_output(int ur_w) { |
94 | int nb_oc_block |
95 | = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; |
96 | for (int k = 0; k < nb_oc_block; k++) |
97 | for (int j = 0; j < ur_w; j++) { |
98 | Vmm vmm = vmm_out(j, k); |
99 | vpxord(vmm, vmm, vmm); |
100 | } |
101 | if (jcp.signed_input) { |
102 | mov(reg_scratch, 128); |
103 | if (jcp.is_depthwise && !jcp.is_fast_depthwise) |
104 | vpbroadcastd(vmm_shift, reg_scratch.cvt32()); |
105 | else |
106 | vpbroadcastb(vmm_shift, reg_scratch.cvt8()); |
107 | } |
108 | } |
109 | |
110 | template <typename Vmm> |
111 | Vmm _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::vmm_mask( |
112 | const Vmm vmm_in, bool mask_flag, bool store) { |
113 | return vmm_in; |
114 | } |
115 | |
116 | template <> |
117 | Zmm _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::vmm_mask( |
118 | const Zmm zmm_in, bool mask_flag, bool store) { |
119 | return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z) |
120 | : zmm_in; |
121 | } |
122 | |
123 | template <typename Vmm> |
124 | void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::cvt2ps(data_type_t type_in, |
125 | const Vmm vmm_in, const Operand &op, bool mask_flag) { |
126 | using namespace data_type; |
127 | const Vmm vmm = vmm_mask(vmm_in, mask_flag); |
128 | switch (type_in) { |
129 | case f32: |
130 | case s32: vmovups(vmm, op); break; |
131 | case bf16: |
132 | vpmovzxwd(vmm, op); |
133 | vpslld(vmm_in, vmm_in, 16); |
134 | break; |
135 | case s8: vpmovsxbd(vmm, op); break; |
136 | case u8: vpmovzxbd(vmm, op); break; |
137 | default: assert(!"unsupported data type" ); |
138 | } |
139 | if (one_of(type_in, s32, s8, u8)) vcvtdq2ps(vmm_in, vmm_in); |
140 | } |
141 | |
// Calls f(mask_flag, k, j) for every OC block k in [0, nb_oc_block). Within
// each block it visits every output position j in [0, ur_w); k is the outer
// loop and j the inner loop.
// mask_flag is true on every call when force_masking is set. Otherwise it is
// true only for the final block, and only when last_oc_block_flag is set.
template <typename F>
static void iterate(const int nb_oc_block, const int ur_w,
        const bool last_oc_block_flag, const bool force_masking, const F &f) {
    for (int k = 0; k < nb_oc_block; k++) {
        const bool is_last_block = k + 1 == nb_oc_block;
        const bool mask_flag
                = force_masking || (last_oc_block_flag && is_last_block);
        int j = 0;
        while (j < ur_w) {
            f(mask_flag, k, j);
            ++j;
        }
    }
}
// Same as above, without forced masking.
template <typename F>
static void iterate(const int nb_oc_block, const int ur_w,
        const bool last_oc_block_flag, const F &f) {
    constexpr bool no_forced_mask = false;
    iterate(nb_oc_block, ur_w, last_oc_block_flag, no_forced_mask, f);
}
// Same as above with no masking at all.
template <typename F>
static void iterate(const int nb_oc_block, const int ur_w, const F &f) {
    iterate(nb_oc_block, ur_w, /*last_oc_block_flag=*/false, f);
}
161 | |
162 | template <typename Vmm> |
163 | void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::apply_sum(int ur_w, |
164 | bool last_oc_block_flag, const int nb_oc_block, const int oc_block, |
165 | const float *p_sum_scale, const int32_t *p_sum_zp) { |
166 | if (jcp.with_sum) { |
167 | const float sum_scale = *p_sum_scale; |
168 | const int32_t sum_zp = *p_sum_zp; |
169 | const auto sum_injector_lam = [this, oc_block, sum_scale, sum_zp]( |
170 | const bool mask_flag, const int k, |
171 | const int j) { |
172 | int aux_output_offset = jcp.typesize_out |
173 | * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups); |
174 | auto addr = EVEX_compress_addr(reg_out, aux_output_offset); |
175 | Vmm vmm = vmm_out(j, k); |
176 | cvt2ps(jcp.sum_dt, vmm_prev_dst, addr, mask_flag); |
177 | if (sum_zp != 0) vsubps(vmm_prev_dst, vmm_sum_zp); |
178 | if (sum_scale == 1.f) |
179 | vaddps(vmm, vmm_prev_dst); |
180 | else |
181 | vfmadd231ps(vmm, vmm_prev_dst, zword_b[reg_ptr_sum_scale]); |
182 | }; |
183 | const auto sum_injector = [=]() { |
184 | iterate(nb_oc_block, ur_w, last_oc_block_flag, sum_injector_lam); |
185 | }; |
186 | if (sum_scale != 1.f) |
187 | mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale)); |
188 | if (sum_zp != 0) { |
189 | mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp)); |
190 | vcvtdq2ps(vmm_sum_zp, ptr_b[reg_ptr_sum_zp]); |
191 | } |
192 | postops_injector_->set_lambda_injector( |
193 | primitive_kind::sum, sum_injector); |
194 | } |
195 | } |
196 | |
197 | template <typename Vmm> |
198 | void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::apply_postops(int ur_w, |
199 | bool last_oc_block_flag, const int nb_oc_block, const int oc_block, |
200 | const float *p_sum_scale, const int32_t *p_sum_zp) { |
201 | if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) { |
202 | apply_sum(ur_w, last_oc_block_flag, nb_oc_block, oc_block, p_sum_scale, |
203 | p_sum_zp); |
204 | |
205 | injector_utils::vmm_index_set_t vmm_idxs; |
206 | if (jcp.with_binary) { |
207 | binary_injector::rhs_arg_dynamic_params_t rhs_arg_params; |
208 | const bool oc_blk_is_smaller_than_vmm = oc_block < isa_simd_width_; |
209 | iterate(nb_oc_block, ur_w, last_oc_block_flag, |
210 | oc_blk_is_smaller_than_vmm, |
211 | [&](const bool mask_flag, const int k, const int j) { |
212 | const size_t aux_output_l_off = jcp.typesize_out |
213 | * (k * oc_block |
214 | + j * jcp.oc_without_padding |
215 | * jcp.ngroups); |
216 | const auto vmm_idx = vmm_out_idx(j, k); |
217 | vmm_idxs.emplace(vmm_idx); |
218 | |
219 | rhs_arg_params.vmm_idx_to_out_reg.emplace( |
220 | vmm_idx, reg_out); |
221 | rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace( |
222 | vmm_idx, aux_output_l_off); |
223 | if (mask_flag) |
224 | rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx); |
225 | }); |
226 | |
227 | postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params); |
228 | } else { |
229 | iterate(nb_oc_block, ur_w, |
230 | [&](const bool, const int k, const int j) { |
231 | vmm_idxs.emplace(vmm_out_idx(j, k)); |
232 | }); |
233 | postops_injector_->compute_vector_range(vmm_idxs); |
234 | } |
235 | } |
236 | } |
237 | |
// Finalizes the accumulators of one ur_w-wide output strip and writes them to
// memory at reg_out.
//
// Per OC/channel block k and output point j:
//   1. permute (fast depthwise only), add the signed-input compensation and
//      the source zero-point compensation while still in s32, convert to f32;
//   2. multiply by the scale (per-OC when jcp.is_oc_scale, else offset 0),
//      then add the bias converted to f32;
// then, for the whole strip:
//   3. post-ops (sum / eltwise / binary) through apply_postops;
//   4. dst scale multiply and dst zero-point add;
//   5. saturate and convert back to s32 for u8/s8/s32 destinations;
//   6. store as f32/s32/s8/u8/bf16.
// last_oc_block_flag enables opmask tail handling on the last block.
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::store_output(
        int ur_w, bool last_oc_block_flag) {
    int nb_oc_block
            = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
    int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block;

    // Per-call pointers from the kernel argument structure.
    mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
    if (jcp.signed_input)
        mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);

    if (jcp.src_zero_point) {
        mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
        mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
    }

    // Find the (optional) sum post-op. Its scale and zero point are passed by
    // address to the post-ops stage.
    const auto &p = attr_.post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    const float *p_sum_scale = nullptr;
    const int32_t *p_sum_zp = nullptr;
    if (sum_idx != -1) {
        const auto &p_entry = p.entry_[sum_idx];
        p_sum_scale = &p_entry.sum.scale;
        p_sum_zp = &p_entry.sum.zero_point;
    }

    for (int k = 0; k < nb_oc_block; k++) {
        // Tail masking applies only to the last block, and only when the
        // caller flags an OC tail.
        const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
        // With is_oc_scale == 0 the offset stays 0 (one common scale).
        int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block);
        if (jcp.with_bias) {
            int bias_offset = jcp.typesize_bia * k * oc_block;
            auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);

            cvt2ps(jcp.bia_dt, vmm_bias, bias_addr, mask_flag);
        }
        if (jcp.signed_input) {
            int comp_offset = sizeof(int32_t) * k * oc_block;
            Vmm vmm_comp_ = vmm_mask(vmm_comp, mask_flag);
            vmovups(vmm_comp_,
                    EVEX_compress_addr(reg_compensation, comp_offset));
        }
        if (jcp.src_zero_point) {
            // zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32
            // vmm_zp = zp_compensation * src_zero_point. The zero point is
            // loaded with a broadcast when it is common (zp_src_is_common).
            int zp_offset = sizeof(int32_t) * k * oc_block;
            Vmm vmm_zp_ = vmm_mask(vmm_zp, mask_flag);
            vmovups(vmm_zp_,
                    EVEX_compress_addr(reg_zp_compensation, zp_offset));
            vpmulld(vmm_zp_, vmm_zp_,
                    EVEX_compress_addr(
                            reg_src_zero_point, 0, jcp.zp_src_is_common));
        }
        /* add to zmm_accum: compensation, zero_point, bias and permute */
        for (int j = 0; j < ur_w; j++) {
            Vmm vmm = vmm_out(j, k);
            if (jcp.is_fast_depthwise)
                vpermd(zmm_out(j, k), zmm_permute, zmm_out(j, k));
            /* Add the compensations while still in s32: f32 represents
               integers exactly only up to 2^24, so adding them after the
               conversion could lose precision.
               TODO: do the same to bias */
            if (jcp.signed_input) vpaddd(vmm, vmm, vmm_comp);
            if (jcp.src_zero_point) vpaddd(vmm, vmm, vmm_zp);
            vcvtdq2ps(vmm, vmm);

            const Vmm vmm_k = vmm_mask(vmm, mask_flag);
            vmulps(vmm_k, vmm,
                    EVEX_compress_addr(reg_ptr_scales, scale_offset));

            if (jcp.with_bias) vaddps(vmm, vmm, vmm_bias);
        }
    }

    apply_postops(ur_w, last_oc_block_flag, nb_oc_block, oc_block, p_sum_scale,
            p_sum_zp);

    if (jcp.dst_scale) {
        // Only offset 0 is loaded: one vector of dst scale values is applied
        // to every block.
        mov(reg_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
        vmovups(vmm_dst_scale, EVEX_compress_addr(reg_dst_scale, 0));

        /* Apply dst scale to accumulator */
        for (int k = 0; k < nb_oc_block; k++) {
            const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
            for (int j = 0; j < ur_w; j++) {
                Vmm vmm = vmm_out(j, k);
                const Vmm vmm_k = vmm_mask(vmm, mask_flag);
                vmulps(vmm_k, vmm, vmm_dst_scale);
            }
        }
    }

    if (jcp.dst_zero_point) {
        // Broadcast-load a single s32 zero point and convert it to f32.
        // vmm_zp is reused here; the source zero-point values are no longer
        // needed at this point.
        mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
        vcvtdq2ps(vmm_zp, EVEX_compress_addr(reg_dst_zero_point, 0, true));

        /* Add dst zero_point to accumulator */
        for (int k = 0; k < nb_oc_block; k++) {
            for (int j = 0; j < ur_w; j++) {
                Vmm vmm = vmm_out(j, k);
                vaddps(vmm, vmm, vmm_zp);
            }
        }
    }

    // Properly saturate the accumulators for integer datatypes
    if (one_of(jcp.dst_dt, u8, s8, s32)) {
        init_saturate_f32(
                vmm_zero, vmm_saturation, aux_reg_saturation, f32, jcp.dst_dt);
        for (int k = 0; k < nb_oc_block; k++) {
            for (int j = 0; j < ur_w; j++) {
                Vmm vmm = vmm_out(j, k);
                saturate_f32(vmm, vmm_zero, vmm_saturation, jcp.dst_dt);
                vcvtps2dq(vmm, vmm);
            }
        }
    }

    // bf16 destination without native bf16 support: set up the emulation.
    if (!isa_has_bf16(jcp.isa) && jcp.dst_dt == data_type::bf16)
        bf16_emu_->init_vcvtneps2bf16();

    /* write out register to output_addr */
    if (jcp.dst_dt == data_type::bf16 && isa_has_bf16(jcp.isa)) {
        // Optimization: use single store instruction for pair of the
        // nearest vectors along OC dimension
        for (int j = 0; j < ur_w; j++) {
            int k = 0;
            for (; k < rnd_dn(nb_oc_block, 2); k += 2) {
                Vmm vmm = vmm_out(j, k);
                Vmm vmm_next = vmm_out(j, k + 1);

                int aux_output_offset = jcp.typesize_out
                        * (k * oc_block
                                + j * jcp.oc_without_padding * jcp.ngroups);
                auto addr = EVEX_compress_addr(reg_out, aux_output_offset);

                // Pack two f32 vectors into one bf16 vector.
                vcvtne2ps2bf16(vmm, vmm_next, vmm);
                // mask only needed for last oc_block
                const bool mask_flag
                        = last_oc_block_flag && k + 2 == nb_oc_block;

                vmovdqu16(addr, maybe_mask_vmm(vmm, mask_flag));
            }
            // Odd block count: convert the leftover vector into a half-width
            // register and store it on its own.
            if (nb_oc_block % 2 != 0) {
                Vmm vmm = vmm_out(j, k);
                auto vmm_down = Vmm_down_t(vmm.getIdx());
                int aux_output_offset = jcp.typesize_out
                        * (k * oc_block
                                + j * jcp.oc_without_padding * jcp.ngroups);
                auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
                vcvtneps2bf16(vmm_down, vmm);
                // for xmm, upper half is zero after conversion to
                // bf16, so mask always & mask for tails
                bool mask_flag = jcp.simd_w == 4 || last_oc_block_flag;
                vmovdqu16(addr, maybe_mask_vmm_down(vmm_down, mask_flag));
            }
        }
    } else {
        for (int k = 0; k < nb_oc_block; k++) {
            const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
            for (int j = 0; j < ur_w; j++) {
                int aux_output_offset = jcp.typesize_out
                        * (k * oc_block
                                + j * jcp.oc_without_padding * jcp.ngroups);
                auto addr = EVEX_compress_addr(reg_out, aux_output_offset);

                Vmm vmm = vmm_out(j, k);
                // Store variant of the mask (merge-masking, no {z}).
                const Vmm r_vmm = vmm_mask(vmm, mask_flag, true);

                switch (jcp.dst_dt) {
                    case data_type::f32:
                    case data_type::s32: vmovups(addr, r_vmm); break;
                    case data_type::s8: vpmovsdb(addr, r_vmm); break;
                    case data_type::u8: vpmovusdb(addr, r_vmm); break;
                    case data_type::bf16:
                        store_bf16(addr, vmm.getIdx(),
                                get_src_down_idx(nb_oc_block), mask_flag);
                        break;
                    default: assert(!"unknown dst_dt");
                }
            }
        }
    }
}
420 | |
421 | template <typename Vmm> |
422 | void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker_dw(int ur_w, |
423 | int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { |
424 | assert(!"invalid group blocking for depthwise convolution" ); |
425 | } |
426 | |
427 | template <> |
428 | void _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::compute_ker_dw(int ur_w, |
429 | int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) { |
430 | |
431 | const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input); |
432 | |
433 | if (jcp.src_zero_point) { |
434 | push(aux_reg_ker_d); |
435 | mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]); |
436 | } |
437 | |
438 | auto input_spatial_index = [=](int oi, int ki) { |
439 | return (ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l); |
440 | }; |
441 | |
442 | auto input_offset2 = [=](int ii, int ci) { |
443 | if (jcp.is_fused_conv) |
444 | return jcp.typesize_in |
445 | * (ii * jcp.dw_conv_buffer_oc + ci * jcp.ch_block); |
446 | else |
447 | return jcp.typesize_in * (ii * jcp.ngroups + ci * jcp.ch_block); |
448 | }; |
449 | |
450 | auto input_offset3 = [=](int oi, int ci, int ki) { |
451 | return jcp.typesize_in * input_offset2(input_spatial_index(oi, ki), ci); |
452 | }; |
453 | |
454 | auto kernel_offset = [=](int ci, int ki) { |
455 | return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block); |
456 | }; |
457 | |
458 | auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) { |
459 | // okay for depthwise since src is zero-extended |
460 | if (jcp.has_vnni) { |
461 | vpdpbusd(vreg_acc, vreg_src, vreg_wei); |
462 | } else { |
463 | vpmaddwd(zmm_tmp, vreg_src, vreg_wei); |
464 | vpaddd(vreg_acc, vreg_acc, zmm_tmp); |
465 | } |
466 | }; |
467 | |
468 | int ii_start = 0; |
469 | int ii_end = -1; |
470 | if (jcp.is_resrc_depthwise && !h_padded) { |
471 | // find bounds of input spatial indices |
472 | bool first = true; |
473 | for (int ki = 0; ki < jcp.kw; ki++) { |
474 | int oi_start = get_ow_start(ki, pad_l); |
475 | int oi_end = get_ow_end(ur_w, ki, pad_r); |
476 | for (int oi = oi_start; oi < oi_end; oi++) { |
477 | int ii = input_spatial_index(oi, ki); |
478 | if (first || ii < ii_start) ii_start = ii; |
479 | if (first || ii > ii_end) ii_end = ii; |
480 | first = false; |
481 | } |
482 | } |
483 | } |
484 | |
485 | if (jcp.signed_input) vmovups(zmm_shifted_zero, vmm_shift); |
486 | |
487 | for (int ci = 0; ci < jcp.nb_ch_blocking; ci++) { |
488 | const bool mask_flag = last_ic_block_flag != no_last_block |
489 | && ci == jcp.nb_ch_blocking - 1; |
490 | if (jcp.is_resrc_depthwise && !h_padded) { |
491 | // now we can load input once and reuse up to jcp.kw times |
492 | for (int ii = ii_start; ii <= ii_end; ii++) { |
493 | int aux_input_offset = input_offset2(ii, ci); |
494 | const Zmm zmm_inp_tmp = zmm_inp(ii, jcp.nb_ch_blocking); |
495 | const Zmm zmm_inp_msk = mask_flag |
496 | ? zmm_inp_tmp | ktail_mask | T_z |
497 | : zmm_inp_tmp; |
498 | if (jcp.is_fast_depthwise) { |
499 | assert(!mask_flag); |
500 | vbroadcasti32x4(zmm_inp_msk, |
501 | EVEX_compress_addr(aux_reg_inp, aux_input_offset)); |
502 | } else { |
503 | vpmovzxbd(zmm_inp_msk, |
504 | EVEX_compress_addr(aux_reg_inp, aux_input_offset)); |
505 | } |
506 | if (jcp.signed_input) |
507 | vpaddb(zmm_inp_tmp, zmm_inp_tmp, vmm_shift); |
508 | } |
509 | } |
510 | for (int ki = 0; ki < jcp.kw; ki++) { |
511 | int aux_kernel_offset = kernel_offset(ci, ki); |
512 | const int oi_start = get_ow_start(ki, pad_l); |
513 | const int oi_end = get_ow_end(ur_w, ki, pad_r); |
514 | if (compute_kernel) { |
515 | if (jcp.is_fast_depthwise) { |
516 | vbroadcasti32x4(zmm_wei, |
517 | EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); |
518 | vmovdqu8(zmm_wei | kblend_mask | T_z, zmm_wei); |
519 | } else { |
520 | vpmovsxbd(zmm_wei, |
521 | EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); |
522 | } |
523 | |
524 | if (h_padded) { |
525 | assert(jcp.signed_input); |
526 | for (int oi = 0; oi < ur_w; oi++) |
527 | compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero); |
528 | } else { |
529 | const Zmm r_zmm_src |
530 | = mask_flag ? zmm_src | ktail_mask : zmm_src; |
531 | int start_ = jcp.signed_input ? 0 : oi_start; |
532 | int end_ = jcp.signed_input ? ur_w : oi_end; |
533 | for (int oi = start_; oi < end_; oi++) { |
534 | if (oi >= oi_start && oi < oi_end) { |
535 | if (jcp.is_resrc_depthwise) { |
536 | int ii = input_spatial_index(oi, ki); |
537 | zmm_src = zmm_inp(ii, jcp.nb_ch_blocking); |
538 | } else { |
539 | int aux_input_offset |
540 | = input_offset3(oi, ci, ki); |
541 | if (jcp.is_fast_depthwise) { |
542 | assert(!mask_flag); |
543 | vbroadcasti32x4(r_zmm_src, |
544 | EVEX_compress_addr(aux_reg_inp, |
545 | aux_input_offset)); |
546 | } else { |
547 | vpmovzxbd(r_zmm_src, |
548 | EVEX_compress_addr(aux_reg_inp, |
549 | aux_input_offset)); |
550 | } |
551 | if (jcp.signed_input) |
552 | vpaddb(zmm_src, zmm_src, vmm_shift); |
553 | } |
554 | compute(zmm_out(oi, ci), zmm_wei, zmm_src); |
555 | } else { |
556 | assert(jcp.signed_input); |
557 | compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero); |
558 | } |
559 | } |
560 | } |
561 | } |
562 | if (jcp.src_zero_point) { |
563 | /* calculate src_zero_point padding as: |
564 | * (is_padding ? |
565 | * src_zero_point_s32 * conv(1, wei_s32) : 0) */ |
566 | if (jcp.is_fast_depthwise || !compute_kernel) { |
567 | vpmovsxbd(zmm_wei, |
568 | EVEX_compress_addr(aux_reg_ker, aux_kernel_offset)); |
569 | if (jcp.is_fast_depthwise) |
570 | vpermd(zmm_wei, zmm_permute, zmm_wei); |
571 | } // else: already loaded weights from previous block |
572 | int zp_offset = 0; |
573 | for (int oi = 0; oi < ur_w; oi++) { |
574 | if (oi < oi_start || oi >= oi_end || h_padded) { |
575 | vpmulld(vmm_zp_tmp, zmm_wei, |
576 | EVEX_compress_addr(reg_src_zero_point, |
577 | zp_offset, jcp.zp_src_is_common)); |
578 | vpaddd(zmm_out(oi, ci), zmm_out(oi, ci), vmm_zp_tmp); |
579 | } |
580 | } |
581 | } |
582 | } |
583 | } |
584 | if (jcp.src_zero_point) pop(aux_reg_ker_d); |
585 | } |
586 | |
// Emits the multiply-accumulate for one kernel row (all kw taps) of a
// non-depthwise convolution, accumulating into vmm_out(jj, ii).
// Depthwise configurations are forwarded to compute_ker_dw.
//
// ur_w               number of output points in this strip
// pad_l, pad_r       left/right padding; bounds the in-range output points
//                    per tap via get_ow_start/get_ow_end
// last_ic_block_flag != no_last_block reduces the IC trip count to the
//                    remaining channels; last_sp_block additionally loads
//                    the final partial IC group byte-wise
// h_padded           true when the whole kernel row lies in padding
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker(int ur_w, int pad_l,
        int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
    if (jcp.is_depthwise)
        return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded);

    // A padded row only contributes through the kernel when the signed input
    // is shifted by 128. Otherwise only the zero-point correction below runs.
    const bool compute_kernel = IMPLICATION(h_padded, jcp.signed_input);

    assert(IMPLICATION(h_padded, jcp.src_zero_point || jcp.signed_input));

    if (jcp.src_zero_point) {
        // aux_reg_ker_d is saved here and restored at the end of the
        // function. NOTE(review): presumably it shares a GPR with
        // reg_src_zero_point — confirm in the header.
        push(aux_reg_ker_d);
        mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
    }

    int kw = jcp.kw;
    int stride_w = jcp.stride_w;
    int ic_block = jcp.ic_block;
    int oc_block = jcp.oc_block;
    int ch_block_all = jcp.ch_block * ic_block * oc_block;

    int nb_oc_block = jcp.nb_oc_blocking;

    // Byte offset of the input for output point oi, IC sub-step ic, tap ki.
    auto input_offset = [=](int oi, int ic, int ki) {
        return jcp.typesize_in
                * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
                                * jcp.ic_without_padding * jcp.ngroups
                        + ic_sub_step * ic);
    };
    // Byte offset of the weights for OC block ii, IC sub-step ic, tap ki.
    auto kernel_offset = [=](int ii, int ic, int ki) {
        return jcp.typesize_in
                * ((ii * jcp.nb_ic * jcp.kd * jcp.kh * jcp.kw + ki)
                                * ch_block_all
                        + ic_sub_step * ic * oc_block);
    };
    // acc += u8 src . s8 wei, accumulated per 32-bit lane. VNNI does this in
    // one vpdpbusd. Otherwise vpmaddubsw + vpmaddwd with vmm_one + vpaddd is
    // used (vmm_one is presumably a vector of 16-bit ones).
    auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
        if (jcp.has_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei);
        } else {
            vpmaddubsw(vmm_tmp, vreg_src, vreg_wei);
            vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
            vpaddd(vreg_acc, vreg_acc, vmm_tmp);
        }
    };

    for (int ki = 0; ki < kw; ki++) {
        // [jj_start, jj_end): output points whose input for tap ki is in range.
        int jj_start = get_ow_start(ki, pad_l);
        int jj_end = get_ow_end(ur_w, ki, pad_r);
        int ic_tail_size = jcp.ic_without_padding % ic_sub_step;
        // With signed input every output point is visited, so that padded
        // taps add the shifted-zero contribution.
        int _start = jcp.signed_input ? 0 : jj_start;
        int _end = jcp.signed_input ? ur_w : jj_end;
        /* Skip the last loads of input
            if (ic%16)/ic_sub_step < ic_block/ic_sub_step */
        int icb = (last_ic_block_flag != no_last_block)
                ? div_up((jcp.ic_without_padding % ic_block), ic_sub_step)
                : ic_block / ic_sub_step;
        if (compute_kernel) {
            for (int ic = 0; ic < icb; ic++) {
                if (h_padded) {
                    // fill padded area with shifted value in first iteration
                    if (ic == 0) {
                        Vmm inp = vmm_inp(0, nb_oc_block);
                        vmovups(inp, vmm_shift); // bcast(128)
                    }
                } else {
                    for (int jj = _start; jj < _end; jj++) {
                        int aux_input_offset = input_offset(jj, ic, ki);
                        if (jj >= jj_start && jj < jj_end) {
                            // Partial IC group at the end of the last spatial
                            // block: load only ic_tail_size bytes, then
                            // broadcast the dword.
                            if (last_ic_block_flag == last_sp_block
                                    && ic_tail_size != 0 && ic == icb - 1) {
                                Xmm xmm_tmp = Xmm(
                                        vmm_inp(jj, nb_oc_block).getIdx());
                                load_bytes(xmm_tmp, aux_reg_inp,
                                        aux_input_offset, ic_tail_size);
                                vpbroadcastd(vmm_inp(jj, nb_oc_block), xmm_tmp);
                            } else {
                                vpbroadcastd(vmm_inp(jj, nb_oc_block),
                                        EVEX_compress_addr(
                                                aux_reg_inp, aux_input_offset));
                            }
                            if (jcp.signed_input)
                                vpaddb(vmm_inp(jj, nb_oc_block),
                                        vmm_inp(jj, nb_oc_block), vmm_shift);
                        } else {
                            // fill padded area with shifted value in
                            // first iteration
                            if (jcp.signed_input && ic == 0) {
                                Vmm inp = vmm_inp(jj, nb_oc_block);
                                vmovups(inp, vmm_shift);
                            }
                        }
                    }
                }
                for (int ii = 0; ii < nb_oc_block; ii++) {
                    int aux_kernel_offset = kernel_offset(ii, ic, ki);
                    vmovups(vmm_wei,
                            EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
                    for (int jj = _start; jj < _end; jj++) {
                        Vmm inp = h_padded ? vmm_inp(0, nb_oc_block)
                                           : vmm_inp(jj, nb_oc_block);
                        compute(vmm_out(jj, ii), vmm_wei, inp);
                    }
                }
            }
        }
        if (jcp.src_zero_point) {
            /* calculate src_zero_point padding as:
             *      (is_padding ? src_zero_point_s32 * conv(1, wei_s8) : 0) */
            // Local vmm_tmp (input register slot 0) shadows the member of the
            // same name for the non-VNNI path below.
            Vmm vmm_tmp = vmm_inp(0, nb_oc_block);
            for (int jj = 0; jj < ur_w; jj++) {
                if (jj < jj_start || jj >= jj_end || h_padded) {
                    for (int ii = 0; ii < nb_oc_block; ii++) {
                        // vmm_zp_tmp = sum over IC of wei (dot with u8 ones
                        // from vmm_zp_one).
                        vpxord(vmm_zp_tmp, vmm_zp_tmp, vmm_zp_tmp);
                        for (int ic = 0; ic < icb; ic++) {
                            int aux_kernel_offset = kernel_offset(ii, ic, ki);
                            if (jcp.has_vnni) {
                                vpdpbusd(vmm_zp_tmp, vmm_zp_one,
                                        EVEX_compress_addr(aux_reg_ker,
                                                aux_kernel_offset));
                            } else {
                                vpmaddubsw(vmm_tmp, vmm_zp_one,
                                        EVEX_compress_addr(aux_reg_ker,
                                                aux_kernel_offset));
                                vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
                                vpaddd(vmm_zp_tmp, vmm_zp_tmp, vmm_tmp);
                            }
                        }
                        int zp_offset = 0;
                        vpmulld(vmm_zp_tmp, vmm_zp_tmp,
                                EVEX_compress_addr(reg_src_zero_point,
                                        zp_offset, jcp.zp_src_is_common));
                        vpaddd(vmm_out(jj, ii), vmm_out(jj, ii), vmm_zp_tmp);
                    }
                }
            }
        }
    }

    if (jcp.src_zero_point) pop(aux_reg_ker_d);
}
727 | |
// Emits the loops over kernel depth (ndims == 5 only) and kernel height for
// one ur_w-wide output strip. compute_ker is invoked once per kernel row.
//
// When the input is signed (shifted by 128) or a source zero point is used,
// kernel rows that fall into padding still contribute to the result. They
// are processed with compute_ker(..., h_padded = true), with trip counts
// taken from the per-call f_overflow/back_overflow (depth) and
// t_overflow/b_overflow (height) fields. The regular loops run kd_padding /
// kh_padding times.
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::kh_loop(
        int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) {
    Label kd_label, kh_label, skip_kd_loop, skip_kh_loop;
    Label f_overflow_label, no_f_overflow_label, d_h_f_overflow_label,
            t_overflow_label, no_t_overflow_label, b_overflow_label,
            no_b_overflow_label, back_overflow_label, no_back_overflow_label,
            d_h_back_overflow_label;

    // Byte distance between consecutive kernel rows (kw taps of a block).
    int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
    int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all;
    // Byte distance between consecutive input rows.
    int shift_input_ptr
            = jcp.typesize_in * jcp.iw * jcp.ic_without_padding * jcp.ngroups;

    if (jcp.ndims == 5) {
        mov(aux_reg_ker_d, reg_ker);
        mov(aux_reg_inp_d, reg_inp);
        if (jcp.signed_input || jcp.src_zero_point) {
            // Front-padding depth slices: each covers all kh rows in padded
            // mode; only the kernel pointer advances.
            //TODO: May be avoided when f_pad=0 and dd0
            //TODO: Potential optimization by precomputing, when kd <<< od?
            mov(reg_ki, ptr[param1 + GET_OFF(f_overflow)]);
            cmp(reg_ki, 0);
            je(no_f_overflow_label, T_NEAR);
            L(f_overflow_label);
            {
                mov(aux_reg_ker, aux_reg_ker_d);
                mov(reg_kj, jcp.kh);
                L(d_h_f_overflow_label);
                {
                    compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
                    add(aux_reg_ker, shift_kernel_ptr);
                    dec(reg_kj);
                    jne(d_h_f_overflow_label);
                }
                add(aux_reg_ker_d, shift_kernel_ptr * jcp.kh);
                dec(reg_ki);
                jne(f_overflow_label);
            }
            L(no_f_overflow_label);
        }

        mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
        // Emit the zero-trip guard only in configurations where kd_padding
        // may be 0 at run time.
        if ((jcp.signed_input || jcp.src_zero_point) || (jcp.dilate_d >= jcp.id)
                || (!(jcp.signed_input || jcp.src_zero_point)
                        && (jcp.kd - 1) * (jcp.dilate_d + 1)
                                < nstl::max(jcp.f_pad, jcp.back_pad))) {
            cmp(reg_ki, 0);
            je(skip_kd_loop, T_NEAR);
        }
        // Body of the kd loop starts here and is closed after the kh loop
        // below (ndims == 5 branch at the end of the function).
        L(kd_label);
        mov(aux_reg_inp, aux_reg_inp_d);
        mov(aux_reg_ker, aux_reg_ker_d);
    } else {
        if (jcp.is_fused_conv) {
            mov(aux_reg_inp_buffer_ptr, reg_inp_buffer_ptr);
        } else {
            mov(aux_reg_inp, reg_inp);
        }
        mov(aux_reg_ker, reg_ker);
    }

    // Top-padding kernel rows (2D/3D only).
    if ((jcp.signed_input || jcp.src_zero_point) && jcp.ndims > 3) {
        mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
        cmp(reg_overflow, 0);
        je(no_t_overflow_label, T_NEAR);
        L(t_overflow_label);
        {
            compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);

            add(aux_reg_ker, shift_kernel_ptr);
            dec(reg_overflow);
            cmp(reg_overflow, 0);
            jg(t_overflow_label, T_NEAR);
        }
        L(no_t_overflow_label);
    }
    mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
    // Emit the zero-trip guard only in configurations where kh_padding may
    // be 0 at run time.
    if (jcp.signed_input || jcp.src_zero_point || (jcp.dilate_h >= jcp.ih)
            || (!(jcp.signed_input || jcp.src_zero_point)
                    && (jcp.kh - 1) * (jcp.dilate_h + 1)
                            < nstl::max(jcp.t_pad, jcp.b_pad))) {
        cmp(reg_kj, 0);
        je(skip_kh_loop, T_NEAR);
    }
    L(kh_label);
    {
        // Fused conv: the input row address is read from a table of row
        // pointers and offset by reg_inp.
        if (jcp.is_fused_conv) {
            mov(aux_reg_inp, ptr[aux_reg_inp_buffer_ptr]);
            add(aux_reg_inp, reg_inp);
        }
        compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);

        add(aux_reg_ker, shift_kernel_ptr);
        if (jcp.is_fused_conv) {
            add(aux_reg_inp_buffer_ptr, sizeof(void *));
        } else {
            add(aux_reg_inp, shift_input_ptr * (jcp.dilate_h + 1));
        }
        dec(reg_kj);
        cmp(reg_kj, 0);
        jg(kh_label, T_NEAR);
    }
    L(skip_kh_loop);
    // Bottom-padding kernel rows (2D/3D only).
    if ((jcp.signed_input || jcp.src_zero_point) && jcp.ndims > 3) {
        mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
        cmp(reg_overflow, 0);
        je(no_b_overflow_label, T_NEAR);
        L(b_overflow_label);
        {
            compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);

            add(aux_reg_ker, shift_kernel_ptr);
            dec(reg_overflow);
            cmp(reg_overflow, 0);
            jg(b_overflow_label, T_NEAR);
        }
        L(no_b_overflow_label);
    }

    if (jcp.ndims == 5) {
        // Close the kd loop opened above.
        add(aux_reg_inp_d, shift_input_ptr * jcp.ih * (jcp.dilate_d + 1));
        add(aux_reg_ker_d, shift_kernel_ptr * jcp.kh);
        dec(reg_ki);
        jne(kd_label, T_NEAR);

        L(skip_kd_loop);
        // Back-padding depth slices, handled like the front ones.
        if (jcp.signed_input || jcp.src_zero_point) {
            mov(reg_ki, ptr[param1 + GET_OFF(back_overflow)]);
            cmp(reg_ki, 0);
            je(no_back_overflow_label, T_NEAR);
            L(back_overflow_label);
            {
                mov(aux_reg_ker, aux_reg_ker_d);
                mov(reg_kj, jcp.kh);
                L(d_h_back_overflow_label);
                {
                    compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
                    add(aux_reg_ker, shift_kernel_ptr);
                    dec(reg_kj);
                    jne(d_h_back_overflow_label);
                }
                add(aux_reg_ker_d, shift_kernel_ptr * jcp.kh);
                dec(reg_ki);
                jne(back_overflow_label);
            }
            L(no_back_overflow_label);
        }
    }
}
877 | |
878 | template <typename Vmm> |
879 | void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::icb_loop( |
880 | int ur_w, int pad_l, int pad_r, bool is_last_sp_block) { |
881 | |
882 | if (jcp.src_zero_point && !jcp.is_depthwise) { |
883 | xor_(reg_scratch, reg_scratch); |
884 | Reg8 _t8 = reg_scratch.cvt8(); |
885 | mov(_t8, 0x1); |
886 | vpbroadcastb(vmm_zp_one, _t8); |
887 | } |
888 | |
889 | prepare_output(ur_w); |
890 | |
891 | // IC loop |
892 | Label icb_label; |
893 | mov(reg_icb, jcp.nb_ic); |
894 | L(icb_label); |
895 | const bool do_icb_loop |
896 | = jcp.is_depthwise ? jcp.nb_ch > jcp.nb_ch_blocking : jcp.nb_ic > 1; |
897 | if (jcp.ngroups % jcp.ch_block != 0 || jcp.ic_without_padding != jcp.ic) { |
898 | Label common_ker, end_ker; |
899 | if (do_icb_loop) { |
900 | if (jcp.is_depthwise) |
901 | cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking); |
902 | else |
903 | cmp(reg_icb, 1); // The last IC block |
904 | jne(common_ker, T_NEAR); |
905 | } |
906 | kh_loop(ur_w, pad_l, pad_r, |
907 | is_last_sp_block ? last_sp_block : last_ic_block); |
908 | if (do_icb_loop) { |
909 | jmp(end_ker, T_NEAR); |
910 | |
911 | L(common_ker); |
912 | kh_loop(ur_w, pad_l, pad_r, no_last_block); |
913 | |
914 | L(end_ker); |
915 | } |
916 | } else { |
917 | kh_loop(ur_w, pad_l, pad_r, no_last_block); |
918 | } |
919 | // End of IC Loop |
920 | if (do_icb_loop) { |
921 | int inp_step = jcp.ic_block; |
922 | const size_t ker_step = (size_t)jcp.kd * jcp.kh * jcp.kw * jcp.oc_block |
923 | * jcp.ic_block; |
924 | add(reg_inp, jcp.typesize_in * inp_step); |
925 | safe_add(reg_ker, jcp.typesize_in * ker_step, reg_ker_long_offt); |
926 | |
927 | dec(reg_icb); |
928 | cmp(reg_icb, 0); |
929 | jg(icb_label, T_NEAR); |
930 | |
931 | sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic); |
932 | safe_sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic, |
933 | reg_ker_long_offt); |
934 | } |
935 | |
936 | if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { |
937 | Label common_store, end_store; |
938 | |
939 | if (jcp.is_depthwise) |
940 | cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking); |
941 | else |
942 | cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking); |
943 | |
944 | jne(common_store, T_NEAR); |
945 | |
946 | store_output(ur_w, true); // last oc block |
947 | jmp(end_store, T_NEAR); |
948 | |
949 | L(common_store); |
950 | store_output(ur_w, false); |
951 | |
952 | L(end_store); |
953 | } else { |
954 | store_output(ur_w, false); |
955 | } |
956 | } |
957 | |
958 | template <typename Vmm> |
959 | void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::generate() { |
960 | Label permute_index_table; |
961 | int in_ic_shift = jcp.is_fused_conv ? jcp.dw_conv_buffer_oc |
962 | : jcp.ic_without_padding * jcp.ngroups; |
963 | const int urw_inp_stride = jcp.ur_w * jcp.stride_w; |
964 | const int n_urw_l_pad |
965 | = nstl::min(div_up(jcp.l_pad, urw_inp_stride), jcp.ow / jcp.ur_w); |
966 | const int inp_shift_pad = nstl::max(0, |
967 | jcp.typesize_in * (n_urw_l_pad * urw_inp_stride - jcp.l_pad) |
968 | * in_ic_shift); |
969 | int inp_shift = jcp.typesize_in * (jcp.ur_w * jcp.stride_w * in_ic_shift); |
970 | int out_shift = jcp.typesize_out |
971 | * (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups); |
972 | preamble(); |
973 | |
974 | if (jcp.is_depthwise) { |
975 | bool is_zero_point = jcp.src_zero_point || jcp.dst_zero_point; |
976 | // dst zero point and dst scale reuse the same register |
977 | int idx = jcp.max_regs_ur - 1 |
978 | + nstl::max(2 * is_zero_point, static_cast<int>(jcp.dst_scale)); |
979 | if (!jcp.is_resrc_depthwise) zmm_src = Zmm(++idx); |
980 | if (!jcp.has_vnni) zmm_tmp = Zmm(++idx); |
981 | if (jcp.is_fast_depthwise) zmm_permute = Zmm(++idx); |
982 | if (jcp.signed_input) zmm_shifted_zero = Zmm(++idx); |
983 | // due to extra register used for shifts and compensations |
984 | // and/or saturation, we increment by one more |
985 | if (jcp.signed_input || jcp.need_saturation) ++idx; |
986 | |
987 | assert(IMPLICATION(!jcp.dst_scale && !is_zero_point |
988 | && jcp.dst_dt != data_type::bf16, |
989 | idx == ker_dw_reg_base_idx)); |
990 | } |
991 | if (!jcp.is_depthwise && (!jcp.has_vnni)) { |
992 | xor_(reg_scratch, reg_scratch); |
993 | Reg16 _t16 = reg_scratch.cvt16(); |
994 | mov(_t16, 0x1); |
995 | vpbroadcastw(vmm_one, _t16); |
996 | } |
997 | if (jcp.is_fused_conv) { |
998 | mov(reg_inp_buffer_ptr, ptr[param1 + GET_OFF(src)]); |
999 | /* In case of fused depthwise convolution, `param.src` is not a pointer |
1000 | to input, instead it points to a buffer containing pointers to |
1001 | consecutive rows of input in format wc with c=jcp.dw_conv_buffer_oc. |
1002 | Example: [ptr_to_inp_row0, ptr_to_inp_row1, ptr_to_inp_row2]. |
1003 | Traverse the data as |
1004 | mov(reg_data, ptr[reg_input_buffer_ptr]) |
1005 | ... process row0 ... |
1006 | add(reg_input_buffer_ptr, sizeof(void*)) |
1007 | mov(reg_data, ptr[reg_input_buffer_ptr]) |
1008 | ... process row1 ... |
1009 | add(reg_input_buffer_ptr, sizeof(void*)) |
1010 | mov(reg_data, ptr[reg_input_buffer_ptr]) |
1011 | ... process row2 ... |
1012 | */ |
1013 | xor_(reg_inp, reg_inp); |
1014 | } else { |
1015 | mov(reg_inp, ptr[param1 + GET_OFF(src)]); |
1016 | } |
1017 | mov(reg_out, ptr[param1 + GET_OFF(dst)]); |
1018 | mov(reg_ker, ptr[param1 + GET_OFF(filt)]); |
1019 | |
1020 | if (jcp.simd_w == 4 && jcp.dst_dt == data_type::bf16) { |
1021 | auto reg_tail_32 = reg_oi.cvt32(); |
1022 | mov(reg_tail_32, (1 << jcp.simd_w) - 1); |
1023 | kmovb(ktail_mask, reg_tail_32); |
1024 | } |
1025 | if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) { |
1026 | int tail_size = jcp.is_depthwise |
1027 | ? jcp.ngroups % jcp.ch_block |
1028 | : jcp.oc_without_padding % jcp.oc_block; |
1029 | int mask = (1 << tail_size) - 1; |
1030 | mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]); |
1031 | Reg32 regw_tmp = reg_oi.cvt32(); |
1032 | mov(regw_tmp, mask); |
1033 | kmovw(ktail_mask, regw_tmp); |
1034 | kmovw(postops_mask, regw_tmp); |
1035 | |
1036 | // To account for special store optimization, where two oc_blocks are |
1037 | // combined with one single write, extend the mask for 32bits (32 bf16s) |
1038 | const int nb_block |
1039 | = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking; |
1040 | const bool need_extended_mask = jcp.dst_dt == data_type::bf16 |
1041 | && isa_has_bf16(jcp.isa) && nb_block > 1; |
1042 | if (need_extended_mask) { |
1043 | mov(regw_tmp, (1 << (tail_size + jcp.simd_w)) - 1); |
1044 | kmovd(ktail_mask_extended, regw_tmp); |
1045 | } |
1046 | } else if (jcp.with_binary) |
1047 | if (jcp.oc_block != isa_simd_width_) { |
1048 | const int mask = (1 << jcp.oc_block) - 1; |
1049 | const Reg32 regw_tmp = reg_oi.cvt32(); |
1050 | mov(regw_tmp, mask); |
1051 | kmovw(postops_mask, regw_tmp); |
1052 | } |
1053 | if (jcp.is_fast_depthwise) { |
1054 | // prepare mask register for blending weights |
1055 | mov(reg_scratch, 0x8888444422221111); |
1056 | kmovq(kblend_mask, reg_scratch); |
1057 | // load permute indices from data section |
1058 | mov(reg_scratch, permute_index_table); |
1059 | vmovdqu32(zmm_permute, ptr[reg_scratch]); |
1060 | } |
1061 | const int extended_filter_size |
1062 | = calculate_extended_filter_size(jcp.kw, jcp.dilate_w); |
1063 | const int r_pad = nstl::max(0, jcp.r_pad); |
1064 | const int ow_with_no_rpad = 1 |
1065 | + (jcp.iw + jcp.l_pad + nstl::min(0, jcp.r_pad) |
1066 | - extended_filter_size) |
1067 | / jcp.stride_w; |
1068 | const int n_urw_per_ow_block = jcp.ow_block / jcp.ur_w; |
1069 | const int max_safe_iw = nstl::max( |
1070 | 0, jcp.iw - div_up(ic_sub_step, jcp.ic_without_padding)); |
1071 | const int max_safe_ow = jcp.ic_without_padding % ic_sub_step == 0 |
1072 | ? jcp.ow |
1073 | : (max_safe_iw + jcp.l_pad - extended_filter_size) / jcp.stride_w; |
1074 | Label middle_block_label, done_compute; |
1075 | std::vector<Label> ow_block_jmp_table; |
1076 | |
1077 | // r_pad_fall_through is a special ow_block, where the block overlaps |
1078 | // both middle_block and r_pad/ur_w_tail region when it exists. |
1079 | // The number of ur_w's to compute in middle_block before executing |
1080 | // r_pad region is stored in r_pad_fall_through_n_urw and the ow_block |
1081 | // number is stored in r_pad_fall_through_ow_block. |
1082 | int r_pad_fall_through_ow_block = 0; |
1083 | int r_pad_fall_through_n_urw = 0; |
1084 | |
1085 | if (jcp.nb_ow > 1) { |
1086 | // Only one ow block is processed, per jit call. |
1087 | // Number of this ow block is passed as parameter owb, |
1088 | // and padding processing depends on this number. |
1089 | // |
1090 | // The compute block to run is determined by using a jmp-table. |
1091 | // jmp-table Layout: |
1092 | // idx -> addr |
1093 | // 0 -> [...l_pad_region label[0]...] |
1094 | // : : : : : : : : : : : : : : : |
1095 | // L -> [...l_pad_region label[L]...] |
1096 | // L+1 -> [...r_pad_region label[0]...] |
1097 | // : : : : : : : : : : : : : : : |
1098 | // L+R -> [...r_pad_region label[R]...] |
1099 | // |
1100 | // Note: Label for middle_block is not stored in the jmp-table. |
1101 | // |
1102 | // During jit call, the jump address is calculated as below: |
1103 | // if (owb < L) { |
1104 | // jmp([jmp_table + owb*sizeof(void*)]); |
1105 | // } else if (owb < X) { |
1106 | // // X is the number of ow_blocks before r_pad region (see below). |
1107 | // jmp(middle_block); |
1108 | // } else { |
1109 | // sub(owb, X); |
1110 | // jmp([jmp_table + owb*sizeof(void*) + L*sizeof(void)]); |
1111 | // } |
1112 | // |
1113 | // To configure the jmp-table, we need to determine some constants |
1114 | // (namely, r_pad_fall_through_n_urw, r_pad_fall_through_ow_block, |
1115 | // n_l_pad_labels, n_labels) ahead of writing the compute assembly. So, |
1116 | // we simulate the filter path without writing the assembly initially. |
1117 | // This makes the math for calculating the constants become simple and |
1118 | // self explanatory. |
1119 | |
1120 | // Begin simulation without writing assembly |
1121 | int n_l_pad_labels = 0; |
1122 | int n_labels = 0; |
1123 | int cur_ow = 0; |
1124 | |
1125 | // l_pad region: |
1126 | n_l_pad_labels = div_up(n_urw_l_pad, n_urw_per_ow_block); |
1127 | n_labels = n_l_pad_labels; |
1128 | cur_ow += n_urw_l_pad * jcp.ur_w; |
1129 | |
1130 | // middle_region: |
1131 | int n_urw_middle_block_loop = 0; |
1132 | int cur_r_pad = nstl::max(0, |
1133 | calculate_end_padding(jcp.l_pad, cur_ow + jcp.ur_w, jcp.iw, |
1134 | jcp.stride_w, extended_filter_size)); |
1135 | if (cur_ow + jcp.ur_w <= jcp.ow && cur_r_pad == 0) { |
1136 | n_urw_middle_block_loop |
1137 | = nstl::max(0, |
1138 | nstl::min(ow_with_no_rpad, max_safe_ow) - cur_ow) |
1139 | / jcp.ur_w; |
1140 | cur_ow += n_urw_middle_block_loop * jcp.ur_w; |
1141 | } |
1142 | r_pad_fall_through_n_urw = (cur_ow / jcp.ur_w) % n_urw_per_ow_block; |
1143 | r_pad_fall_through_ow_block = cur_ow / (n_urw_per_ow_block * jcp.ur_w); |
1144 | |
1145 | // r_pad or last_sp_block |
1146 | if (cur_ow + jcp.ur_w <= jcp.ow) { |
1147 | if (r_pad_fall_through_n_urw == 0) ++n_labels; |
1148 | const int n_urw_r_pad_region = (jcp.ow - cur_ow) / jcp.ur_w; |
1149 | n_labels += nstl::max(0, |
1150 | div_up(r_pad_fall_through_n_urw + n_urw_r_pad_region, |
1151 | n_urw_per_ow_block) |
1152 | - 1); |
1153 | } |
1154 | |
1155 | if (jcp.ur_w_tail != 0) { |
1156 | if (jcp.ow % jcp.ow_block == jcp.ur_w_tail) ++n_labels; |
1157 | } |
1158 | // End of simulation |
1159 | |
1160 | ow_block_jmp_table.resize(n_labels); |
1161 | |
1162 | // Begin jump-table logic |
1163 | Label ow_block_jmp_table_label; |
1164 | if (!ow_block_jmp_table.empty()) |
1165 | mov(reg_jmp_tbl_base, ow_block_jmp_table_label); |
1166 | mov(reg_oi, n_urw_per_ow_block); |
1167 | mov(reg_owb, ptr[param1 + GET_OFF(owb)]); |
1168 | if (jcp.l_pad > 0) { |
1169 | Label middle_or_rpad_check; |
1170 | cmp(reg_owb, n_l_pad_labels); |
1171 | jge(middle_or_rpad_check, T_NEAR); |
1172 | jmp(ptr[reg_jmp_tbl_base + reg_owb * sizeof(void *)]); |
1173 | L(middle_or_rpad_check); |
1174 | // harness passes shifted src pointer that does not take |
1175 | // left-padding into account. So, we must re-shift here. |
1176 | const int inp_shift_pad_middle_block = -1 * jcp.typesize_in |
1177 | * nstl::min(jcp.l_pad, n_urw_l_pad * urw_inp_stride) |
1178 | * in_ic_shift; |
1179 | add(reg_inp, inp_shift_pad_middle_block); |
1180 | } |
1181 | if (r_pad_fall_through_n_urw != 0) { |
1182 | mov(reg_scratch, r_pad_fall_through_n_urw); |
1183 | cmp(reg_owb, r_pad_fall_through_ow_block); |
1184 | cmove(reg_oi, reg_scratch); |
1185 | if (n_urw_middle_block_loop > 0) { |
1186 | sub(reg_owb, r_pad_fall_through_ow_block); |
1187 | // simple middle_block |
1188 | jle(middle_block_label, T_NEAR); |
1189 | dec(reg_owb); |
1190 | } else { |
1191 | sub(reg_owb, r_pad_fall_through_ow_block + 1); |
1192 | } |
1193 | } else { |
1194 | sub(reg_owb, r_pad_fall_through_ow_block); |
1195 | // simple middle_block |
1196 | if (n_urw_middle_block_loop) jl(middle_block_label, T_NEAR); |
1197 | } |
1198 | // r_pad-only region |
1199 | if (!ow_block_jmp_table.empty()) |
1200 | jmp(ptr[reg_jmp_tbl_base + reg_owb * sizeof(void *) |
1201 | + n_l_pad_labels * sizeof(void *)]); |
1202 | |
1203 | if (!ow_block_jmp_table.empty()) { |
1204 | align(8); |
1205 | L(ow_block_jmp_table_label); |
1206 | { |
1207 | for (size_t i = 0; i < ow_block_jmp_table.size(); ++i) { |
1208 | putL(ow_block_jmp_table[i]); |
1209 | } |
1210 | } |
1211 | } |
1212 | // End of jump-table logic |
1213 | } |
1214 | |
1215 | // Begin kernel |
1216 | int cur_ow = 0; |
1217 | int cur_n_oi = 0; // used only for jcp.nb_ow > 1 scenario |
1218 | int label_cntr = 0; |
1219 | int cur_l_pad = 0; |
1220 | if (jcp.l_pad > 0) { |
1221 | for (cur_l_pad = jcp.l_pad; |
1222 | cur_l_pad > 0 && cur_ow + jcp.ur_w <= jcp.ow; |
1223 | cur_l_pad -= urw_inp_stride) { |
1224 | if (jcp.nb_ow > 1 && cur_n_oi == 0) { |
1225 | // cur_n_oi == 0 signifies beginning of new ow_block |
1226 | // (or end of previous block) |
1227 | const dim_t inp_lpad_region_shift = -label_cntr * jcp.ow_block |
1228 | * jcp.stride_w * in_ic_shift; |
1229 | L(ow_block_jmp_table[label_cntr++]); |
1230 | // harness passes shifted src pointer that does not take |
1231 | // left-padding into account. So, we must re-shift here. |
1232 | add(reg_inp, inp_lpad_region_shift); |
1233 | } |
1234 | |
1235 | cur_ow += jcp.ur_w; |
1236 | int cur_r_pad = nstl::max(0, |
1237 | calculate_end_padding(jcp.l_pad, cur_ow, jcp.iw, |
1238 | jcp.stride_w, extended_filter_size)); |
1239 | icb_loop(jcp.ur_w, cur_l_pad, cur_r_pad, cur_ow > max_safe_ow); |
1240 | add(reg_out, out_shift); |
1241 | dec(reg_oi); |
1242 | |
1243 | if (jcp.nb_ow > 1 && ++cur_n_oi == n_urw_per_ow_block) { |
1244 | // We compute one owb per jit call. So, insert an |
1245 | // unconditional jmp, after computing one owb. |
1246 | jmp(done_compute, T_NEAR); |
1247 | cur_n_oi = 0; |
1248 | } |
1249 | } |
1250 | if (jcp.nb_ow == 1 || cur_n_oi != 0) { |
1251 | // Let it "fall-through" middle_block_label |
1252 | add(reg_inp, inp_shift_pad); |
1253 | } |
1254 | } |
1255 | |
1256 | // middle_block |
1257 | { |
1258 | int cur_r_pad = nstl::max(0, |
1259 | calculate_end_padding(jcp.l_pad, cur_ow + jcp.ur_w, jcp.iw, |
1260 | jcp.stride_w, extended_filter_size)); |
1261 | if (cur_r_pad == 0 && cur_ow + jcp.ur_w <= jcp.ow) { |
1262 | int n_oi_middle_block_loop |
1263 | = nstl::max(0, |
1264 | nstl::min(ow_with_no_rpad, max_safe_ow) - cur_ow) |
1265 | / jcp.ur_w; |
1266 | if (jcp.nb_ow == 1 && n_oi_middle_block_loop > 1) |
1267 | mov(reg_oi, n_oi_middle_block_loop); |
1268 | L(middle_block_label); |
1269 | if (n_oi_middle_block_loop > 0) { |
1270 | icb_loop(jcp.ur_w, 0, 0, false); |
1271 | add(reg_inp, inp_shift); |
1272 | add(reg_out, out_shift); |
1273 | if (n_oi_middle_block_loop > 1) { |
1274 | dec(reg_oi); |
1275 | jg(middle_block_label, T_NEAR); |
1276 | } |
1277 | } |
1278 | cur_ow += n_oi_middle_block_loop * jcp.ur_w; |
1279 | cur_n_oi = (cur_n_oi + n_oi_middle_block_loop) % n_urw_per_ow_block; |
1280 | } |
1281 | } |
1282 | |
1283 | // r_pad region or last_sp_block |
1284 | if (cur_ow + jcp.ur_w <= jcp.ow) { |
1285 | if (jcp.nb_ow > 1) { |
1286 | if (cur_n_oi == 0) { |
1287 | jmp(done_compute, T_NEAR); |
1288 | } else { |
1289 | // r_pad fall-through |
1290 | mov(reg_owb, ptr[param1 + GET_OFF(owb)]); |
1291 | cmp(reg_owb, r_pad_fall_through_ow_block); |
1292 | jne(done_compute, T_NEAR); |
1293 | } |
1294 | } |
1295 | |
1296 | while (cur_ow + jcp.ur_w <= jcp.ow) { |
1297 | if (jcp.nb_ow > 1 && cur_n_oi == 0) { |
1298 | L(ow_block_jmp_table[label_cntr++]); |
1299 | } |
1300 | cur_ow += jcp.ur_w; |
1301 | int cur_r_pad = calculate_end_padding(jcp.l_pad, cur_ow, jcp.iw, |
1302 | jcp.stride_w, extended_filter_size); |
1303 | assert(cur_r_pad > 0 || cur_ow > max_safe_ow); // else, why be here? |
1304 | icb_loop(jcp.ur_w, 0, cur_r_pad, cur_ow > max_safe_ow); |
1305 | add(reg_inp, inp_shift); |
1306 | add(reg_out, out_shift); |
1307 | |
1308 | if (jcp.nb_ow > 1 && ++cur_n_oi == n_urw_per_ow_block) { |
1309 | // We compute one owb per jit call. So, insert an |
1310 | // unconditional jmp, after computing one owb. |
1311 | jmp(done_compute, T_NEAR); |
1312 | cur_n_oi = 0; |
1313 | } |
1314 | } |
1315 | // Let it fall-through ur_w_tail |
1316 | } |
1317 | |
1318 | // ur_w_tail |
1319 | if (jcp.ur_w_tail != 0) { |
1320 | if (jcp.nb_ow > 1) { |
1321 | if (cur_n_oi == 0) { |
1322 | jmp(done_compute, T_NEAR); |
1323 | L(ow_block_jmp_table[label_cntr++]); |
1324 | } else { |
1325 | // In case, when there is no r_pad region, then there exists an |
1326 | // ambiguity btw middle_blocks and r_pad_fall_through_ow_block. |
1327 | // If not properly distinguished, there can be a race condition |
1328 | // as middle_blocks and r_pad_fall_through_ow_block both try to |
1329 | // compute ur_w_tail work at the end. |
1330 | mov(reg_owb, ptr[param1 + GET_OFF(owb)]); |
1331 | cmp(reg_owb, jcp.nb_ow - 1); // last ow_block? |
1332 | jne(done_compute, T_NEAR); |
1333 | } |
1334 | } |
1335 | icb_loop(jcp.ur_w_tail, nstl::max(0, cur_l_pad), r_pad, true); |
1336 | } |
1337 | L(done_compute); |
1338 | assert(ow_block_jmp_table.size() == static_cast<size_t>(label_cntr)); |
1339 | |
1340 | postamble(); |
1341 | |
1342 | if (jcp.with_eltwise) postops_injector_->prepare_table(); |
1343 | |
1344 | if (jcp.is_fast_depthwise) { |
1345 | align(64); |
1346 | L(permute_index_table); |
1347 | const uint32_t _idx[] |
1348 | = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; |
1349 | for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i) |
1350 | dd(_idx[i]); |
1351 | } |
1352 | } |
1353 | |
1354 | status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp, |
1355 | const convolution_desc_t &cd, memory_desc_t &src_md, |
1356 | memory_desc_t &weights_md, memory_desc_t &dst_md, |
1357 | memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) { |
1358 | using namespace prop_kind; |
1359 | |
1360 | const memory_desc_wrapper src_d(&src_md); |
1361 | const memory_desc_wrapper weights_d(&weights_md); |
1362 | const memory_desc_wrapper dst_d(&dst_md); |
1363 | const memory_desc_wrapper bias_d(&bias_md); |
1364 | |
1365 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
1366 | const int ndims = src_d.ndims(); |
1367 | const bool is_1d = ndims == 3; |
1368 | const bool is_2d = ndims == 4; |
1369 | const bool is_3d = ndims == 5; |
1370 | assert(is_1d || is_2d || is_3d); |
1371 | |
1372 | if (!(mayiuse(avx512_core) |
1373 | && one_of(src_d.data_type(), data_type::u8, data_type::s8) |
1374 | && weights_d.data_type() == data_type::s8 |
1375 | && one_of(dst_d.data_type(), data_type::f32, data_type::s32, |
1376 | data_type::s8, data_type::u8, data_type::bf16))) |
1377 | return status::unimplemented; |
1378 | |
1379 | jcp = zero<decltype(jcp)>(); |
1380 | jcp.nthr = nthreads; |
1381 | jcp.ndims = ndims; |
1382 | jcp.prop_kind = cd.prop_kind; |
1383 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
1384 | jcp.mb = src_d.dims()[0]; |
1385 | jcp.oc = dst_d.dims()[1] / jcp.ngroups; |
1386 | jcp.oc_without_padding = jcp.oc; |
1387 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
1388 | jcp.ic_without_padding = jcp.ic; |
1389 | jcp.id = is_3d ? src_d.dims()[2] : 1; |
1390 | jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; |
1391 | jcp.iw = src_d.dims()[ndims - 1]; |
1392 | jcp.od = is_3d ? dst_d.dims()[2] : 1; |
1393 | jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; |
1394 | jcp.ow = dst_d.dims()[ndims - 1]; |
1395 | jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1; |
1396 | jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; |
1397 | jcp.kw = weights_d.dims()[with_groups + ndims - 1]; |
1398 | jcp.f_pad = is_3d ? cd.padding[0][0] : 0; |
1399 | jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; |
1400 | jcp.l_pad = cd.padding[0][ndims - 3]; |
1401 | jcp.stride_d = is_3d ? cd.strides[0] : 1; |
1402 | jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; |
1403 | jcp.stride_w = cd.strides[ndims - 3]; |
1404 | jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; |
1405 | |
1406 | jcp.ur_h = 1; /* no code-unrolling by h so far */ |
1407 | jcp.dilate_d = is_3d ? cd.dilates[0] : 0; |
1408 | jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4]; |
1409 | jcp.dilate_w = cd.dilates[ndims - 3]; |
1410 | |
1411 | int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w); |
1412 | int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h); |
1413 | int ext_kd = calculate_extended_filter_size(jcp.kd, jcp.dilate_d); |
1414 | jcp.r_pad = calculate_end_padding( |
1415 | jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw); |
1416 | jcp.b_pad = calculate_end_padding( |
1417 | jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh); |
1418 | jcp.back_pad = calculate_end_padding( |
1419 | jcp.f_pad, jcp.od, jcp.id, jcp.stride_d, ext_kd); |
1420 | |
1421 | jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false; |
1422 | jcp.need_saturation = utils::one_of(dst_d.data_type(), u8, s8, s32); |
1423 | jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc); |
1424 | |
1425 | // Used for bfloat16 output |
1426 | jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16 |
1427 | : bf16_emulation_t::get_isa(); |
1428 | |
1429 | if (jcp.is_depthwise && is_3d) |
1430 | // NOTE: 3D depthwise is not currently supported here. |
1431 | return status::unimplemented; |
1432 | |
1433 | if (jcp.is_depthwise) { |
1434 | jcp.ch_block = 16; |
1435 | jcp.ic_block = 1; |
1436 | jcp.oc_block = 1; |
1437 | } else { |
1438 | jcp.ch_block = 1; |
1439 | jcp.ic_block = 16; |
1440 | jcp.oc_block = 16; |
1441 | |
1442 | if (jcp.ngroups == 1) { |
1443 | /* For non grouped convolutions, pad channels by 16 if needed */ |
1444 | jcp.oc = rnd_up(jcp.oc, jcp.oc_block); |
1445 | jcp.ic = rnd_up(jcp.ic, jcp.ic_block); |
1446 | } else if (jcp.ngroups != 1 |
1447 | && ((jcp.ic % jcp.ic_block != 0) |
1448 | || (jcp.oc % jcp.oc_block != 0))) { |
1449 | /* For grouped convolutions, oneDNN doesn't support padding. |
1450 | When channels per group is not multiple of 16: |
1451 | - Use Ymm when channels per group is multiple of 8, |
1452 | - Use Xmm when channels per group is multiple of 4, |
1453 | - Otherwise return unimplemented. */ |
1454 | jcp.ic_block = (jcp.ic % 8 == 0) && (jcp.oc % 8 == 0) ? 8 : 4; |
1455 | jcp.oc_block = jcp.ic_block; |
1456 | } |
1457 | if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0) |
1458 | return status::unimplemented; |
1459 | } |
1460 | |
1461 | jcp.simd_w = jcp.is_depthwise ? jcp.ch_block : jcp.ic_block; |
1462 | |
1463 | const auto zp = attr.zero_points_; |
1464 | jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST); |
1465 | jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC); |
1466 | jcp.zp_src_is_common |
1467 | = zp.common(DNNL_ARG_SRC); // otherwise, it's per-channel |
1468 | assert(IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common)); |
1469 | |
1470 | if ((jcp.dst_zero_point || jcp.src_zero_point) && jcp.is_fused_conv) |
1471 | return status::unimplemented; |
1472 | |
1473 | const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC); |
1474 | const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); |
1475 | const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); |
1476 | const int wei_mask_per_oc = 1 << (int)with_groups; |
1477 | jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc; |
1478 | jcp.dst_scale = !dst_scales.has_default_values(); |
1479 | |
1480 | // only common src & dst scales are supported |
1481 | // only common and per-oc-channel weight scales are supported |
1482 | const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc) |
1483 | && everyone_is(src_scales.mask_, dst_scales.mask_, 0); |
1484 | if (!scales_ok) return status::unimplemented; |
1485 | |
1486 | jcp.has_vnni = mayiuse(avx512_core_vnni); |
1487 | const bool bf16_req_extra_regs = cd.dst_desc.data_type == data_type::bf16 |
1488 | && !isa_has_bf16(jcp.isa); |
1489 | jcp.is_fast_depthwise = true && jcp.is_depthwise && jcp.has_vnni |
1490 | && jcp.ngroups % jcp.ch_block == 0 |
1491 | && !bf16_req_extra_regs; /* groups not multiple of |
1492 | ch_block (= 16) would require byte masking for load from src */ |
1493 | |
1494 | jcp.is_resrc_depthwise = jcp.is_depthwise && jcp.stride_w < jcp.kw |
1495 | && jcp.kw < 4 && jcp.dilate_w == 0; |
1496 | |
1497 | if (jcp.is_depthwise) { |
1498 | jcp.max_regs_ur = 31 - jcp.is_fast_depthwise - !jcp.is_resrc_depthwise |
1499 | - jcp.signed_input - (!jcp.has_vnni) |
1500 | - (jcp.signed_input || jcp.need_saturation) // both alias |
1501 | - (bf16_req_extra_regs ? 4 : 0); |
1502 | } else { |
1503 | jcp.max_regs_ur = bf16_req_extra_regs ? 26 : jcp.has_vnni ? 31 : 28; |
1504 | } |
1505 | |
1506 | // TODO: re-implement so that the JIT Kernel uses the least amount of |
1507 | // registers. Currently, there are issues because of compile and run time |
1508 | // definitions. |
1509 | if (jcp.dst_scale) jcp.max_regs_ur = 26; |
1510 | if (jcp.src_zero_point || jcp.dst_zero_point) jcp.max_regs_ur = 25; |
1511 | |
1512 | auto set_or_check_wei_format = [&]() { |
1513 | using namespace format_tag; |
1514 | using namespace memory_extra_flags; |
1515 | format_tag_t wei_tag; |
1516 | if (jcp.ic_block == 16 || jcp.ch_block == 16) { |
1517 | if (is_3d) { |
1518 | wei_tag = with_groups ? gOIdhw4i16o4i : OIdhw4i16o4i; |
1519 | } else if (is_1d) { |
1520 | wei_tag = with_groups ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i |
1521 | : OIw4i16o4i; |
1522 | } else { |
1523 | assert(is_2d); |
1524 | wei_tag = with_groups |
1525 | ? jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i |
1526 | : OIhw4i16o4i; |
1527 | } |
1528 | } else if (jcp.ic_block == 8) { |
1529 | assert(with_groups); |
1530 | wei_tag = is_3d ? gOIdhw2i8o4i : is_2d ? gOIhw2i8o4i : gOIw2i8o4i; |
1531 | } else { |
1532 | assert(with_groups && jcp.ic_block == 4); |
1533 | wei_tag = is_3d ? gOIdhw4o4i : is_2d ? gOIhw4o4i : gOIw4o4i; |
1534 | } |
1535 | |
1536 | memory_desc_t want_wei_md = weights_md; |
1537 | memory_desc_init_by_tag(want_wei_md, wei_tag); |
1538 | if (jcp.signed_input) { |
1539 | want_wei_md.extra.flags = 0 | compensation_conv_s8s8 | scale_adjust; |
1540 | want_wei_md.extra.compensation_mask = (1 << 0) |
1541 | + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); |
1542 | want_wei_md.extra.scale_adjust |
1543 | = mayiuse(avx512_core_vnni) ? 1.f : 0.5f; |
1544 | } |
1545 | if (jcp.src_zero_point) { |
1546 | want_wei_md.extra.flags |= compensation_conv_asymmetric_src; |
1547 | want_wei_md.extra.asymm_compensation_mask = (1 << 0) |
1548 | + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0); |
1549 | } |
1550 | |
1551 | if (weights_md.format_kind == format_kind::any) { |
1552 | weights_md = want_wei_md; |
1553 | return true; |
1554 | } |
1555 | |
1556 | return weights_md == want_wei_md; |
1557 | }; |
1558 | |
1559 | if (!set_or_check_wei_format()) return status::unimplemented; |
1560 | |
1561 | format_tag_t dat_tag = utils::pick( |
1562 | ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
1563 | |
1564 | if (src_d.format_kind() == format_kind::any) { |
1565 | CHECK(memory_desc_init_by_tag(src_md, dat_tag)); |
1566 | jcp.src_tag = dat_tag; |
1567 | } else { |
1568 | jcp.src_tag = src_d.matches_one_of_tag(dat_tag); |
1569 | } |
1570 | if (jcp.src_tag != dat_tag) return status::unimplemented; |
1571 | |
1572 | if (dst_d.format_kind() == format_kind::any) { |
1573 | CHECK(memory_desc_init_by_tag(dst_md, dat_tag)); |
1574 | jcp.dst_tag = dat_tag; |
1575 | } else { |
1576 | jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); |
1577 | } |
1578 | if (jcp.dst_tag != dat_tag) return status::unimplemented; |
1579 | |
1580 | if (jcp.with_bias) { |
1581 | if (bias_d.format_kind() == format_kind::any) |
1582 | CHECK(memory_desc_init_by_tag(bias_md, format_tag::x)); |
1583 | } |
1584 | jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; |
1585 | jcp.dst_dt = cd.dst_desc.data_type; |
1586 | |
1587 | CHECK(attr.set_default_formats(&dst_md)); |
1588 | |
1589 | const auto &post_ops = attr.post_ops_; |
1590 | const int eltwise_ind = post_ops.find(primitive_kind::eltwise); |
1591 | jcp.with_eltwise = eltwise_ind != -1; |
1592 | |
1593 | const int binary_ind = post_ops.find(primitive_kind::binary); |
1594 | jcp.with_binary = binary_ind != -1; |
1595 | |
1596 | const int sum_ind = post_ops.find(primitive_kind::sum); |
1597 | jcp.with_sum = sum_ind != -1; |
1598 | jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt); |
1599 | |
1600 | jcp.post_ops = post_ops; |
1601 | |
1602 | using namespace injector; |
1603 | static constexpr bool sum_at_pos_0_only = false; |
1604 | static constexpr bool sum_requires_scale_one = false; |
1605 | static constexpr bool sum_requires_zp_zero = false; |
1606 | const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum}, |
1607 | jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, |
1608 | sum_requires_zp_zero}); |
1609 | if (!post_ops_ok_) return status::unimplemented; |
1610 | |
1611 | jcp.typesize_in = types::data_type_size(src_d.data_type()); |
1612 | jcp.typesize_out = types::data_type_size(dst_d.data_type()); |
1613 | jcp.typesize_bia |
1614 | = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; |
1615 | |
1616 | jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block); |
1617 | jcp.nb_ic = jcp.ic / jcp.ic_block; |
1618 | jcp.nb_oc = jcp.oc / jcp.oc_block; |
1619 | |
1620 | // Try to use 4 channel-groups at a time to avoid false sharing (depthwise) |
1621 | int nb_ch_blocking = 4; |
1622 | for (/* init above */; nb_ch_blocking > 1; nb_ch_blocking--) |
1623 | if (jcp.nb_ch % nb_ch_blocking == 0) break; |
1624 | jcp.nb_ch_blocking = jcp.is_depthwise ? nb_ch_blocking : 1; |
1625 | |
1626 | // If OC blocking is incommensurate with the number of OC blocks (general |
1627 | // requirement for all convolutions), or if it results in an unrolling |
1628 | // factor smaller than the left padding (special requirement for SSD:fc6), |
1629 | // then search for a smaller OC blocking that satisfies both constraints. |
1630 | auto is_oc_blocking_ok = [&](int block) { |
1631 | int ur_w = nstl::min(jcp.ow, jcp.max_regs_ur / (block + 1)); |
1632 | return jcp.nb_oc % block == 0 && jcp.l_pad <= ur_w |
1633 | && jcp.ow % ur_w != 1; |
1634 | }; |
1635 | |
1636 | // choose nb_oc work chunk size for distribution within threads |
1637 | int max_threading_nb_oc_chunk = 4; |
1638 | // Performance improvements for googlenet_v3 and resnet_50 with mb = 1; |
1639 | // TODO: generalize this condition and rewrite it in appropriate manner |
1640 | int ncores_per_socket = (int)cpu().getNumCores( |
1641 | Xbyak::util::IntelCpuTopologyLevel::CoreLevel); |
1642 | if (jcp.has_vnni && jcp.mb == 1 && jcp.kh == 3 && jcp.kw == 3 |
1643 | && jcp.stride_w == 1 && jcp.ic % 64 == 0 |
1644 | && jcp.nthr <= ncores_per_socket) |
1645 | max_threading_nb_oc_chunk = 2; |
1646 | jcp.nb_oc_blocking_thr_chunk |
1647 | = nstl::min(max_threading_nb_oc_chunk, jcp.nb_oc); |
1648 | for (; jcp.nb_oc_blocking_thr_chunk > 1; jcp.nb_oc_blocking_thr_chunk--) { |
1649 | if (is_oc_blocking_ok(jcp.nb_oc_blocking_thr_chunk)) break; |
1650 | } |
1651 | |
1652 | // choose oc blocking for computational kernel |
1653 | jcp.nb_oc_blocking = jcp.nb_oc_blocking_thr_chunk; |
1654 | |
1655 | // Performance improvements for googlenet_v3 with mb = 1; |
1656 | // TODO: generalize this condition and rewrite it in appropriate manner |
1657 | const int size_treshold_for_nb_oc_blocking_reduction = 17; |
1658 | if (jcp.mb == 1 && jcp.ow <= size_treshold_for_nb_oc_blocking_reduction |
1659 | && jcp.stride_w == 1 && jcp.nthr <= ncores_per_socket |
1660 | && !(jcp.kh == 1 && jcp.kw == 3) |
1661 | && !(jcp.kh >= 7 && jcp.oc % 64 == 0)) { |
1662 | const int max_nb_oc_blocking = 2; |
1663 | jcp.nb_oc_blocking = nstl::min(max_nb_oc_blocking, jcp.nb_oc); |
1664 | for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--) |
1665 | if (jcp.nb_oc_blocking_thr_chunk % jcp.nb_oc_blocking == 0 |
1666 | && is_oc_blocking_ok(jcp.nb_oc_blocking)) |
1667 | break; |
1668 | } |
1669 | |
1670 | if (jcp.is_resrc_depthwise) |
1671 | jcp.ur_w = (jcp.max_regs_ur - jcp.kw + jcp.stride_w) |
1672 | / (jcp.nb_ch_blocking + jcp.stride_w); |
1673 | else |
1674 | jcp.ur_w = jcp.max_regs_ur |
1675 | / (jcp.is_depthwise ? jcp.nb_ch_blocking |
1676 | : jcp.nb_oc_blocking + 1); |
1677 | if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow; |
1678 | jcp.ur_w_tail = jcp.ow % jcp.ur_w; |
1679 | |
1680 | auto get_thr_eff = [=](int nb_ow, int nthr) { |
1681 | int base_work_amount = jcp.mb * jcp.nb_ch * jcp.od * jcp.oh |
1682 | * (jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk); |
1683 | auto work_amount = base_work_amount * nb_ow; |
1684 | return float(work_amount) / rnd_up(work_amount, nthr); |
1685 | }; |
1686 | |
1687 | auto get_ow_block = [=](int ur_w, int nthr) { |
1688 | int res_ow_block = jcp.ow; |
1689 | float best_thr_eff = get_thr_eff(1, nthr); |
1690 | float thr_eff; |
1691 | int max_nb_ow = div_up(jcp.ow, ur_w); |
1692 | for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) { |
1693 | int ow_block |
1694 | = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow); |
1695 | if (ow_block < jcp.nb_oc_blocking_thr_chunk * jcp.oc_block |
1696 | && best_thr_eff > 0.8f) |
1697 | break; |
1698 | if (div_up(jcp.ow, ow_block) != nb_ow) continue; |
1699 | thr_eff = get_thr_eff(nb_ow, nthr); |
1700 | if (ow_block >= ur_w && thr_eff > 1.1f * best_thr_eff) { |
1701 | res_ow_block = ow_block; |
1702 | best_thr_eff = thr_eff; |
1703 | } |
1704 | if (best_thr_eff > 0.9f) break; |
1705 | } |
1706 | return res_ow_block; |
1707 | }; |
1708 | |
1709 | jcp.ow_block = get_ow_block(jcp.ur_w, jcp.nthr); |
1710 | jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); |
1711 | float thr_eff = get_thr_eff(jcp.nb_ow, jcp.nthr); |
1712 | |
1713 | /* adjust the thread decomposition |
1714 | * to improve the thr_eff for small size problem |
1715 | * the threshold L1_cache_size is empirical */ |
1716 | size_t wei_size |
1717 | = sizeof(float) * jcp.ic * jcp.oc * jcp.kh * jcp.kw * jcp.kd; |
1718 | size_t out_size |
1719 | = jcp.mb * jcp.typesize_out * jcp.oc * jcp.oh * jcp.ow * jcp.od; |
1720 | size_t inp_size |
1721 | = jcp.mb * jcp.typesize_in * jcp.ic * jcp.ih * jcp.iw * jcp.id; |
1722 | size_t total_size = jcp.ngroups * (wei_size + out_size + inp_size); |
1723 | const unsigned int L1_cache_size = platform::get_per_core_cache_size(1); |
1724 | |
1725 | if (thr_eff < 0.9f && jcp.ngroups < jcp.nthr |
1726 | && (total_size < L1_cache_size)) { |
1727 | int ow_block = jcp.ow_block; |
1728 | float best_thr_eff = -1.0f; |
1729 | float eff = -1.0f; |
1730 | int end_nthr = with_groups ? jcp.ngroups : 1; |
1731 | for (int nthr = jcp.nthr / 2; nthr > end_nthr; nthr--) { |
1732 | ow_block = get_ow_block(jcp.ur_w, nthr); |
1733 | eff = get_thr_eff(div_up(jcp.ow, ow_block), nthr); |
1734 | if (eff > 1.1f * best_thr_eff) { |
1735 | best_thr_eff = eff; |
1736 | jcp.ow_block = ow_block; |
1737 | jcp.nb_ow = div_up(jcp.ow, jcp.ow_block); |
1738 | jcp.nthr = jcp.aligned_threads = nthr; |
1739 | if (best_thr_eff > 0.9f) break; |
1740 | } |
1741 | } |
1742 | } |
1743 | |
1744 | if (jcp.oc % jcp.oc_block != 0) return status::unimplemented; |
1745 | |
1746 | pick_loop_order(jcp, jcp.nthr); |
1747 | |
1748 | jcp.nb_ic_L2 = jcp.nb_ic; |
1749 | |
1750 | jcp.wei_adj_scale |
1751 | = (weights_d.extra().flags & memory_extra_flags::scale_adjust) |
1752 | ? weights_d.extra().scale_adjust |
1753 | : 1.f; |
1754 | |
1755 | return status::success; |
1756 | } |
1757 | |
1758 | void jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad( |
1759 | memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, |
1760 | const primitive_attr_t &attr) { |
1761 | const int wei_mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_; |
1762 | const dim_t scales_count = wei_mask == 0 ? 1 : jcp.oc * jcp.ngroups; |
1763 | dim_t count = wei_mask == 0 ? (dim_t)16 : scales_count; |
1764 | scratchpad.book<float>(key_conv_adjusted_scales, count); |
1765 | } |
1766 | |
// Explicit instantiations of the kernel class template for the Xbyak vector
// register types Zmm, Ymm and Xmm, so that the template's member definitions
// in this translation unit are emitted for each of these register types.
template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>;
template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Ymm>;
template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Xmm>;
1770 | } // namespace x64 |
1771 | } // namespace cpu |
1772 | } // namespace impl |
1773 | } // namespace dnnl |
1774 | |
1775 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
1776 | |