1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/memory.hpp" |
21 | #include "common/memory_tracking.hpp" |
22 | #include "common/nstl.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | #include "common/utils.hpp" |
25 | |
26 | #include "cpu/platform.hpp" |
27 | |
28 | #include "cpu/x64/injectors/injector_utils.hpp" |
29 | #include "cpu/x64/injectors/jit_uni_binary_injector.hpp" |
30 | #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" |
31 | #include "cpu/x64/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp" |
32 | #include "cpu/x64/jit_uni_1x1_conv_utils.hpp" |
33 | |
34 | #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field) |
35 | |
36 | namespace dnnl { |
37 | namespace impl { |
38 | namespace cpu { |
39 | namespace x64 { |
40 | |
41 | using namespace dnnl::impl::utils; |
42 | using namespace dnnl::impl::data_type; |
43 | using namespace dnnl::impl::prop_kind; |
44 | using namespace Xbyak; |
45 | |
// Kernel constructor: caches the convolution problem descriptor (jcp) and the
// primitive attributes, then instantiates the optional helper emitters:
//  - a post-ops injector when eltwise/binary/sum post-ops are configured;
//  - a bf16 emulation helper when dst is bf16 but the ISA lacks native bf16
//    instructions.
template <typename Vmm>
_jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::
_jit_avx512_core_x8s8s32x_1x1_conv_kernel(
        const jit_1x1_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_t &dst_md)
    : jit_generator(jit_name())
    , jcp(ajcp)
    , attr_(attr)
    , postops_injector_(nullptr) {
    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
        using namespace binary_injector;
        static constexpr bool preserve_gpr = true;
        static constexpr bool preserve_vmm = false;
        // vmm31 is reserved as scratch for the injector.
        static constexpr unsigned helper_vmm_idx = 31;
        // Tail size for masked binary post-op accesses: prefer the remainder
        // of oc_block w.r.t. the SIMD width; when oc_block divides evenly,
        // fall back to the remainder of the unpadded OC count.
        const size_t oc_block_tail = jcp.oc_block % isa_simd_width_;
        const size_t tail_size = oc_block_tail
                ? oc_block_tail
                : jcp.oc_without_padding % isa_simd_width_;
        static constexpr bool use_exact_tail_scalar_bcast = true;

        // r14/r15/r13 are handed to the injector as scratch GPRs.
        const rhs_arg_static_params_t rhs_arg_static_params {helper_vmm_idx,
                r14, r15, r13, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md), tail_size, postops_mask,
                use_exact_tail_scalar_bcast};
        const static_params_t static_params {
                this->param1, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx512_core, Vmm>>(
                this, jcp.post_ops, static_params);
    }
    // NOTE(review): bf16_emu_reserv_5 is passed for the last two slots —
    // presumably the emulator tolerates this aliasing; confirm against
    // bf16_emulation_t's constructor contract.
    if (jcp.dst_dt == data_type::bf16 && !isa_has_bf16(jcp.isa))
        bf16_emu_ = utils::make_unique<bf16_emulation_t>(this,
                bf16_emu_reserv_1, bf16_emu_reserv_2, bf16_emu_reserv_3,
                bf16_emu_reserv_4, bf16_emu_reserv_5, bf16_emu_reserv_5);
}
83 | |
// Emits the loop over the broadcast (spatial/minibatch) dimension: rows are
// processed in chunks of jcp.bcast_block, each chunk split into substeps of
// jcp.ur rows handled by reduce_loop(); a remainder of jcp.ur_tail rows is
// handled last with tail masking enabled.
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::bcast_loop(
        int load_loop_blk) {
    mov(aux1_reg_bcast_data, reg_bcast_data);
    mov(aux_reg_bcast_data, reg_bcast_data);

    mov(aux_reg_output_data, reg_output_data);
    // The remaining bcast-row count was stashed on the stack by generate().
    mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_off));

    Label bcast_loop;
    Label bcast_loop_tail;

    // Skip the main loop entirely when fewer than one ur-step remains.
    cmp(bcast_loop_iter, jcp.ur);
    jl(bcast_loop_tail, T_NEAR);

    L(bcast_loop);
    {
        assert(jcp.bcast_block % jcp.ur == 0);
        int num_substeps = jcp.bcast_block / jcp.ur;
        assert(num_substeps > 0 && num_substeps < 10);
        for (int i = 0; i < num_substeps; i++) {
            reduce_loop(load_loop_blk, jcp.ur, false);
            if (i < num_substeps - 1) {
                // Advance input/output pointers by one substep within the
                // current bcast block.
                add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data, jcp.bcast_loop_output_substep);
            } else {
                // Last substep: advance to the next bcast block, undoing the
                // substep offsets accumulated above.
                add(aux1_reg_bcast_data,
                        jcp.bcast_loop_bcast_step
                                - (num_substeps - 1)
                                        * jcp.bcast_loop_bcast_substep);
                int output_offset = jcp.bcast_loop_output_step
                        - (num_substeps - 1) * jcp.bcast_loop_output_substep;

                add(aux_reg_output_data, output_offset);
            }
        }
        sub(bcast_loop_iter, jcp.bcast_block);
        cmp(bcast_loop_iter, jcp.bcast_block);
        jge(bcast_loop, T_NEAR);
    }

    L(bcast_loop_tail);
    if (jcp.ur_tail) {
        Label bcast_loop_tail_out;
        cmp(bcast_loop_iter, 0);
        jz(bcast_loop_tail_out, T_NEAR);
        // Remaining rows: run one masked reduce_loop.
        reduce_loop(load_loop_blk, jcp.ur_tail, true);
        L(bcast_loop_tail_out);
    }
}
134 | |
135 | template <typename Vmm> |
136 | void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::cvt2ps(data_type_t type_in, |
137 | const Vmm vmm_in, const Xbyak::Operand &op, bool mask_flag) { |
138 | using namespace data_type; |
139 | const Vmm vmm = mask_flag ? vmm_in | k_load_dim_mask | T_z : vmm_in; |
140 | switch (type_in) { |
141 | case f32: |
142 | case s32: vmovups(vmm, op); break; |
143 | case bf16: vpmovzxwd(vmm, op); break; |
144 | case s8: vpmovsxbd(vmm, op); break; |
145 | case u8: vpmovzxbd(vmm, op); break; |
146 | default: assert(!"unsupported data type" ); |
147 | } |
148 | if (one_of(type_in, s32, s8, u8)) |
149 | vcvtdq2ps(vmm_in, vmm_in); |
150 | else if (type_in == bf16) |
151 | vpslld(vmm_in, vmm_in, 16); |
152 | } |
153 | |
// Invokes f(mask_flag, i_load, i_ur) for every (load-block, unroll-row) pair.
// The mask flag is raised for all blocks when force_masking is set, or for
// the final load block when last_oc_block_flag is set.
template <typename F>
static void iterate(const int load_loop_blk, const int ur,
        const bool last_oc_block_flag, const bool force_masking, const F &f) {
    for (int blk = 0; blk < load_loop_blk; ++blk) {
        const bool is_final_blk = (blk + 1 == load_loop_blk);
        const bool needs_mask
                = force_masking || (last_oc_block_flag && is_final_blk);
        for (int row = 0; row < ur; ++row)
            f(needs_mask, blk, row);
    }
}
// Convenience overload without forced masking: only the final load block may
// be masked, and only when last_oc_block_flag is set.
template <typename F>
static void iterate(const int load_loop_blk, const int ur,
        const bool last_oc_block_flag, const F &f) {
    for (int blk = 0; blk < load_loop_blk; ++blk) {
        const bool needs_mask
                = last_oc_block_flag && (blk + 1 == load_loop_blk);
        for (int row = 0; row < ur; ++row)
            f(needs_mask, blk, row);
    }
}
// Convenience overload that visits every pair with masking disabled.
template <typename F>
static void iterate(const int load_loop_blk, const int ur, const F &f) {
    for (int blk = 0; blk < load_loop_blk; ++blk)
        for (int row = 0; row < ur; ++row)
            f(false, blk, row);
}
173 | |
174 | template <typename Vmm> |
175 | Address _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::output_ptr( |
176 | const int i_load, const int i_ur) { |
177 | const size_t ur_stride = jcp.with_dw_conv |
178 | ? jcp.nb_load_blocking * jcp.oc_block * i_ur |
179 | : jcp.oc_without_padding * jcp.ngroups * i_ur; |
180 | |
181 | return EVEX_compress_addr(aux_reg_output_data, |
182 | jcp.typesize_out * (ur_stride + i_load * jcp.load_block)); |
183 | }; |
184 | |
185 | template <typename Vmm> |
186 | int _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::vreg_accum_idx( |
187 | const int load_loop_blk, int i_load, int i_ur) const { |
188 | return (i_ur * load_loop_blk + i_load); |
189 | }; |
190 | |
191 | template <typename Vmm> |
192 | Vmm _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::vreg_accum( |
193 | const int load_loop_blk, int i_load, int i_ur) const { |
194 | return Vmm(vreg_accum_idx(load_loop_blk, i_load, i_ur)); |
195 | }; |
196 | |
// Registers a lambda with the post-ops injector implementing the `sum`
// post-op: the previous dst value is loaded and converted to f32, optionally
// shifted by the sum zero-point, then accumulated into each accumulator
// register, scaled by sum_scale when it is not 1.0.
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::apply_sum(
        const int load_loop_blk, const int ur, const bool mask_flag_in,
        const float *p_sum_scale, const int32_t *p_sum_zp) {
    if (jcp.with_sum) {
        const float sum_scale = *p_sum_scale;
        const int32_t sum_zp = *p_sum_zp;
        const auto sum_injector_lam
                = [this, sum_scale, sum_zp, load_loop_blk](const bool mask_flag,
                          const int i_load, const int i_ur) {
                      const auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                      cvt2ps(jcp.sum_dt, vmm_prev_dst, output_ptr(i_load, i_ur),
                              mask_flag);
                      // vmm_tmp holds the f32 sum zero-point (set below,
                      // before the injector runs this lambda).
                      if (sum_zp != 0) vsubps(vmm_prev_dst, vmm_tmp);
                      if (sum_scale == 1.f)
                          vaddps(r, vmm_prev_dst);
                      else
                          // Scale is read from memory as a broadcast element.
                          vfmadd231ps(
                                  r, vmm_prev_dst, zword_b[reg_ptr_sum_scale]);
                  };
        const auto sum_injector = [=]() {
            iterate(load_loop_blk, ur, mask_flag_in, sum_injector_lam);
        };
        // Preload the sum zero-point (stashed on the stack) as f32 broadcast.
        if (sum_zp != 0) vcvtdq2ps(vmm_tmp, ptr_b[rsp + reg_ptr_sum_zp_off]);
        postops_injector_->set_lambda_injector(
                primitive_kind::sum, sum_injector);
    }
}
225 | |
// Applies the whole post-op chain (sum, eltwise, binary) to all live
// accumulator registers via the post-ops injector. For binary post-ops the
// per-vmm output addresses and tail-mask membership are prepared up front;
// the tail variant is selected at runtime when this is the last load block
// of the last OC block.
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::apply_postops(
        const int load_loop_blk, const int ur, const bool mask_flag_in,
        const float *p_sum_scale, const int32_t *p_sum_zp) {
    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {

        apply_sum(load_loop_blk, ur, mask_flag_in, p_sum_scale, p_sum_zp);

        injector_utils::vmm_index_set_t vmm_idxs;
        if (jcp.with_binary) {
            binary_injector::rhs_arg_dynamic_params_t rhs_arg_params,
                    rhs_arg_params_tail;
            const auto mask_tail = jcp.oc_without_padding % jcp.load_block;
            const bool oc_blk_is_smaller_than_vmm
                    = jcp.oc_block < isa_simd_width_;
            iterate(load_loop_blk, ur, mask_tail, oc_blk_is_smaller_than_vmm,
                    [&](const bool mask_flag, const int i_load,
                            const int i_ur) {
                        // Same addressing as output_ptr(), expressed as a raw
                        // byte offset since the injector needs reg + offset.
                        const int ur_stride = jcp.with_dw_conv
                                ? jcp.nb_load_blocking * jcp.oc_block * i_ur
                                : jcp.oc_without_padding * jcp.ngroups * i_ur;
                        const size_t aux_output_l_off = jcp.typesize_out
                                * (ur_stride + i_load * jcp.load_block);
                        const auto vmm_idx
                                = vreg_accum_idx(load_loop_blk, i_load, i_ur);
                        vmm_idxs.emplace(vmm_idx);

                        rhs_arg_params_tail.vmm_idx_to_out_reg.emplace(
                                vmm_idx, aux_reg_output_data);
                        rhs_arg_params_tail.vmm_idx_to_out_elem_off_val.emplace(
                                vmm_idx, aux_output_l_off);
                        if (mask_flag)
                            rhs_arg_params_tail.vmm_tail_idx_.emplace(vmm_idx);
                    });
            // Non-tail parameters are identical except for the tail set.
            rhs_arg_params = rhs_arg_params_tail;
            rhs_arg_params.vmm_tail_idx_.clear();

            // Restore abi_param1 (clobbered during execution) — the binary
            // injector reads the kernel arguments through it.
            mov(abi_param1, EVEX_compress_addr(rsp, reg_abi_param1_backup));

            Label postops_done;
            if (mask_tail || oc_blk_is_smaller_than_vmm) {
                Label postops_no_tail;
                if (mask_tail) {
                    // Runtime tail check: take the tail path only on the last
                    // OC block with no load-loop work left.
                    test(reg_reduce_pos_flag, FLAG_OC_LAST);
                    jz(postops_no_tail, T_NEAR);
                    cmp(reg_load_loop_work, 0);
                    jg(postops_no_tail, T_NEAR);
                }
                postops_injector_->compute_vector_range(
                        vmm_idxs, rhs_arg_params_tail);
                jmp(postops_done, T_NEAR);
                L(postops_no_tail);
            }
            postops_injector_->compute_vector_range(vmm_idxs, rhs_arg_params);
            L(postops_done);

        } else {
            // No binary post-ops: just collect all accumulator indices.
            iterate(load_loop_blk, ur,
                    [&](const bool, const int i_load, const int i_ur) {
                        vmm_idxs.emplace(
                                vreg_accum_idx(load_loop_blk, i_load, i_ur));
                    });
            postops_injector_->compute_vector_range(vmm_idxs);
        }
    }
}
292 | |
// Emits the innermost 1x1-conv computation for one tile of `ur` bcast rows by
// `load_loop_blk` load (OC) blocks: accumulators are zeroed, the reduction
// (IC) dimension is walked in chunks of reduce_loop_unroll doing int8
// dot-product FMAs, and finally store() applies compensation terms, scales,
// post-ops, zero-points, saturation and writes results to dst.
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::reduce_loop(
        int load_loop_blk, int ur, bool wraparound) {
    // Weight registers live right above the accumulator block.
    auto vreg_load
            = [=](int i_load) { return Vmm(ur * load_loop_blk + i_load); };

    auto bias_ptr = [=](int i_load) {
        return EVEX_compress_addr(
                reg_bias_data, jcp.typesize_bia * jcp.oc_block * i_load);
    };

    auto comp_ptr = [=](int i_load) {
        return EVEX_compress_addr(
                reg_comp_data, sizeof(int32_t) * jcp.oc_block * i_load);
    };

    auto scale_ptr = [=](int i_load) {
        // is_oc_scale == 0 collapses all loads onto one common scale.
        return EVEX_compress_addr(reg_ptr_scales,
                jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load));
    };

    // Address of the (possibly broadcast) source element for a given
    // reduction index and unroll row.
    auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
        assert(i_ur < jcp.ur);
        assert(i_reduce <= jcp.reduce_loop_unroll);
        assert(jcp.reduce_loop_unroll == jcp.reduce_block);

        int offt = (jcp.ic_without_padding * i_ur * jcp.ngroups + i_reduce);

        return EVEX_compress_addr(
                aux_reg_bcast_data, jcp.typesize_in * offt, bcast);
    };

    // Address of the weights for a given reduction index and load block.
    auto load_ptr = [=](int i_reduce, int i_load) {
        int u0 = i_reduce % jcp.reduce_loop_unroll;
        int u1 = i_reduce / jcp.reduce_loop_unroll;

        int offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block;

        return EVEX_compress_addr(aux_reg_load_data,
                u1 * jcp.reduce_loop_load_step + jcp.typesize_in * offt);
    };

    // Zero all accumulators; prepare the 128-shift used to convert signed
    // input to unsigned for vpdpbusd/vpmaddubsw.
    auto init = [=]() {
        for (int i_load = 0; i_load < load_loop_blk; ++i_load)
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                vpxord(r, r, r);
            }
        if (jcp.signed_input) {
            mov(reg_scratch, -128);
            vpbroadcastb(vmm_shift, reg_scratch.cvt8());
        }
    };

    // Post-processing + store of the accumulator tile.
    auto store = [=](const bool mask_flag_in) {
        const auto &p = attr_.post_ops_;
        const int sum_idx = p.find(primitive_kind::sum);
        const float *p_sum_scale = nullptr;
        const int32_t *p_sum_zp = nullptr;
        if (sum_idx != -1) {
            p_sum_scale = &p.entry_[sum_idx].sum.scale;
            p_sum_zp = &p.entry_[sum_idx].sum.zero_point;
        }
        const auto p_sum_scale_val = p_sum_scale ? *p_sum_scale : 1.f;
        const auto p_sum_zp_val = p_sum_zp ? *p_sum_zp : 0;
        const bool is_scale_or_zp_sum
                = p_sum_zp_val != 0 || p_sum_scale_val != 1.f;
        // Spill bcast pointer; its register is reused below.
        mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
        mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
        if (is_scale_or_zp_sum) {
            // reg_load_data is borrowed to materialize sum scale/zp pointers.
            mov(EVEX_compress_addr(rsp, reg_load_data_off), reg_load_data);
            if (p_sum_zp_val != 0) {
                mov(reg_load_data, p_sum_zp_val);
                mov(ptr[rsp + reg_ptr_sum_zp_off], reg_load_data);
            }
            if (p_sum_scale_val != 1.f)
                mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
        }
        if (jcp.signed_input && (!jcp.has_vnni)) {
            // Weight adjustment scale used on the non-VNNI path.
            mov(reg_scratch, float2int(jcp.wei_adj_scale));
        }
        if (jcp.src_zero_point) {
            mov(reg_zp_compensation,
                    EVEX_compress_addr(rsp, reg_zp_compensation_off));
            mov(reg_src_zero_point,
                    EVEX_compress_addr(rsp, reg_src_zero_point_off));
        }
        for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
            // Only the last load block can be a (masked) OC tail.
            const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1;
            auto vmm_bias = vmm_tmp;
            auto vmm_comp = vmm_bcast;
            if (jcp.with_bias) {
                if (jcp.signed_input || jcp.dst_scale)
                    mov(reg_bias_data,
                            EVEX_compress_addr(rsp, reg_bias_data_off));
                cvt2ps(jcp.bia_dt, vmm_bias, bias_ptr(i_load), mask_flag);
            }
            if (jcp.signed_input) {
                mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off));
                cvt2ps(data_type::s32, vmm_comp, comp_ptr(i_load), mask_flag);
            }
            if (jcp.src_zero_point) {
                // zero_point: conv(src_x8, wei_s8) - src_shift_s32 * compensation_s32
                const int zp_offset = sizeof(int32_t) * i_load * jcp.load_block;
                vmovups(vmm_zp,
                        EVEX_compress_addr(reg_zp_compensation, zp_offset));
                vpmulld(vmm_zp, vmm_zp,
                        EVEX_compress_addr(
                                reg_src_zero_point, 0, jcp.zp_src_is_common));
                // upscale to f32
                const Vmm vmm_
                        = mask_flag ? vmm_zp | k_load_dim_mask | T_z : vmm_zp;
                vcvtdq2ps(vmm_, vmm_);
            }
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                // s32 accumulator -> f32, then add compensation terms.
                vcvtdq2ps(r, r);
                if (jcp.signed_input) vaddps(r, r, vmm_comp);
                if (jcp.src_zero_point) vaddps(r, r, vmm_zp);

                // Apply the output scale under the tail mask (zeroing the
                // masked-out lanes), then add bias.
                const Vmm mask_vmm = mask_flag ? r | k_load_dim_mask | T_z : r;
                vmulps(mask_vmm, r, scale_ptr(i_load));

                if (jcp.with_bias) vaddps(r, r, vmm_bias);
            }
        }

        apply_postops(load_loop_blk, ur, mask_flag_in, p_sum_scale, p_sum_zp);

        if (jcp.dst_scale) {
            mov(reg_ptr_dst_scale, EVEX_compress_addr(rsp, reg_dst_scale_off));
            vmovups(vmm_dst_scale, EVEX_compress_addr(reg_ptr_dst_scale, 0));

            /* Apply dst scale to accumulator */
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                const bool mask_flag
                        = mask_flag_in && i_load == load_loop_blk - 1;
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    const auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                    const Vmm mask_vmm
                            = mask_flag ? r | k_load_dim_mask | T_z : r;
                    vmulps(mask_vmm, r, vmm_dst_scale);
                }
            }
        }

        if (jcp.dst_zero_point) {
            mov(reg_dst_zero_point,
                    EVEX_compress_addr(rsp, reg_dst_zero_point_off));
            // Broadcast-load the zero-point and convert to f32.
            vcvtdq2ps(vmm_zp, EVEX_compress_addr(reg_dst_zero_point, 0, true));

            /* Add dst zero_point to accumulator */
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    const auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                    vaddps(r, r, vmm_zp);
                }
            }
        }

        // Properly saturate the accumulators for integer datatypes
        if (one_of(jcp.dst_dt, u8, s8, s32)) {
            init_saturate_f32(vmm_zero, vmm_saturation,
                    reg_ptr_saturation_ubound, f32, jcp.dst_dt);
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                    saturate_f32(r, vmm_zero, vmm_saturation, jcp.dst_dt);
                    vcvtps2dq(r, r);
                }
            }
        }

        if (jcp.dst_dt == data_type::bf16 && !isa_has_bf16(jcp.isa))
            bf16_emu_->init_vcvtneps2bf16();

        // store to the destination
        if (jcp.dst_dt == data_type::bf16 && isa_has_bf16(jcp.isa)) {
            // Optimization: use single store instruction for pair
            // of the nearest vectors along LOAD dimension
            for (int i_ur = 0; i_ur < ur; i_ur++) {
                int i_load = 0;
                for (; i_load < rnd_dn(load_loop_blk, 2); i_load += 2) {
                    auto vmm_dst = vreg_accum(load_loop_blk, i_load, i_ur);
                    auto vmm_dst_next
                            = vreg_accum(load_loop_blk, i_load + 1, i_ur);
                    // Packs two f32 vectors into one bf16 vector.
                    vcvtne2ps2bf16(vmm_dst, vmm_dst_next, vmm_dst);
                    bool mask_flag
                            = mask_flag_in && i_load + 2 == load_loop_blk;
                    vmovdqu16(output_ptr(i_load, i_ur),
                            maybe_mask_vmm(vmm_dst, mask_flag));
                }
                if (load_loop_blk % 2 != 0) {
                    // Odd trailing load block: narrow conversion + store.
                    auto vmm_accum = vreg_accum(load_loop_blk, i_load, i_ur);
                    auto vmm_down = Vmm_down_t(vmm_accum.getIdx());
                    vcvtneps2bf16(vmm_down, vmm_accum);
                    vmovdqu16(output_ptr(i_load, i_ur),
                            maybe_mask_vmm_down(vmm_down,
                                    jcp.ic_block == 4 || mask_flag_in));
                }
            }
        } else {
            for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                const bool mask_flag
                        = mask_flag_in && i_load == load_loop_blk - 1;
                for (int i_ur = 0; i_ur < ur; ++i_ur) {
                    auto r = vreg_accum(load_loop_blk, i_load, i_ur);
                    // Merging (non-zeroing) mask: untouched dst lanes keep
                    // their previous value in the register, not in memory.
                    const Vmm r_vmm = mask_flag ? r | k_load_dim_mask : r;

                    switch (jcp.dst_dt) {
                        case data_type::f32:
                        case data_type::s32:
                            vmovups(output_ptr(i_load, i_ur), r_vmm);
                            break;
                        case data_type::s8:
                            // Signed-saturating down-convert s32 -> s8.
                            vpmovsdb(output_ptr(i_load, i_ur), r_vmm);
                            break;
                        case data_type::u8:
                            // Unsigned-saturating down-convert s32 -> u8.
                            vpmovusdb(output_ptr(i_load, i_ur), r_vmm);
                            break;
                        case data_type::bf16: {
                            // bf16 without native ISA support: emulated.
                            bf16_emu_->vcvtneps2bf16(
                                    ymm_store, Zmm(r.getIdx()));
                            vmovdqu16(output_ptr(i_load, i_ur),
                                    maybe_mask_vmm_down(vmm_store(),
                                            jcp.ic_block == 4 || mask_flag));
                        } break;
                        default: assert(!"unknown dst_dt" );
                    }
                }
            }
        }
        // Restore registers spilled at the top of store().
        mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
        if (is_scale_or_zp_sum)
            mov(reg_load_data, EVEX_compress_addr(rsp, reg_load_data_off));
    };

    // One int8 dot-product step: VNNI uses vpdpbusd directly; otherwise the
    // classic vpmaddubsw + vpmaddwd(vmm_one) + vpaddd sequence.
    auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
        if (jcp.has_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei);
        } else {
            vpmaddubsw(vmm_tmp, vreg_src, vreg_wei);
            vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
            vpaddd(vreg_acc, vreg_acc, vmm_tmp);
        }
    };

    // One unrolled chunk of the reduction dimension; last_block enables
    // byte-granular handling of the IC tail.
    auto fma_block = [=](bool last_block) {
        int reduce_step = 4;
        int ic_tail_size = jcp.ic_without_padding % reduce_step;
        int loop_unroll = last_block && jcp.ic != jcp.ic_without_padding
                ? rnd_up(jcp.ic_without_padding % jcp.ic_block, reduce_step)
                : jcp.reduce_loop_unroll;
        for (int i_reduce = 0; i_reduce < loop_unroll;
                i_reduce += reduce_step) {
            for (int i_load = 0; i_load < load_loop_blk; ++i_load)
                vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load));
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                if (last_block && ic_tail_size != 0
                        && i_reduce == loop_unroll - reduce_step) {
                    // IC tail: load only the valid bytes, then broadcast.
                    Xmm xmm_bcast = Xmm(vmm_bcast.getIdx());
                    load_bytes(xmm_bcast, aux_reg_bcast_data,
                            jcp.ic_without_padding * i_ur + i_reduce,
                            ic_tail_size);
                    vpbroadcastd(vmm_bcast, xmm_bcast);
                } else {
                    vpbroadcastd(vmm_bcast, bcast_ptr(i_reduce, i_ur, false));
                }
                // Shift signed src to unsigned for the u8*s8 dot product.
                if (jcp.signed_input) vpsubb(vmm_bcast, vmm_bcast, vmm_shift);
                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    compute(vreg_accum(load_loop_blk, i_load, i_ur),
                            vreg_load(i_load), vmm_bcast);
                }
            }
        }
    };

    Label reduce_loop;
    Label reduce_loop_tail;

    mov(aux_reg_load_data, reg_load_data);

    mov(aux_reg_bcast_data, aux1_reg_bcast_data);
    init();

    mov(reduce_loop_iter, reg_reduce_loop_work);
    sub(reduce_loop_iter, jcp.reduce_loop_unroll);
    jle(reduce_loop_tail, T_NEAR);

    L(reduce_loop);
    {
        fma_block(false);
        add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jg(reduce_loop, T_NEAR);
    }

    L(reduce_loop_tail);
    // The tail chunk needs special IC handling only when IC is padded.
    if (jcp.ic != jcp.ic_without_padding) {
        fma_block(true);
    } else {
        fma_block(false);
    }

    if (jcp.oc_without_padding != jcp.oc) {
        Label end_store, common_store;
        mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);

        /*Check if it is the last load_loop_blk*/
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
        cmp(reg_load_loop_work, 0);
        jg(common_store, T_NEAR);

        /*Check if it is the last ocb*/
        test(reg_reduce_pos_flag, FLAG_OC_LAST);
        jz(common_store, T_NEAR);

        // Last block of the last ocb: store with OC-tail masking.
        store(true);
        jmp(end_store, T_NEAR);

        L(common_store);
        store(false);

        L(end_store);

        add(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
    } else {
        store(false);
    }
}
624 | |
// Top-level code generation: emits the prologue (constants, stack slots for
// spilled kernel arguments, tail masks) and a dispatch over the load-loop
// blocking factor — each case runs bcast_loop() with a different unroll and
// advances the load/bias/compensation/scale/output pointers between
// iterations.
template <typename Vmm>
void _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Vmm>::generate() {

    preamble();

    const int simd_w = jcp.ic_block;
    // vmm_one = broadcast 16-bit 0x0001, used by the non-VNNI vpmaddwd step.
    xor_(reg_scratch, reg_scratch);
    Reg16 _t = reg_scratch.cvt16();
    mov(_t, 0x1);
    vpbroadcastw(vmm_one, _t);

    // Reserve the stack area used to spill kernel arguments and loop state.
    sub(rsp, stack_space_needed);
    if (jcp.with_binary) {
        // Running OC offset consumed by the binary injector, plus a backup of
        // abi_param1 (restored in apply_postops).
        const auto zeroed_reg = r15;
        xor_(zeroed_reg, zeroed_reg);
        mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off), zeroed_reg);
        mov(EVEX_compress_addr(rsp, reg_abi_param1_backup), abi_param1);
    }

    // Load kernel arguments and spill those needed across register reuse.
    if (jcp.with_bias) mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
    if (jcp.signed_input) {
        mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
        mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]);
        mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data);
    }
    if (jcp.src_zero_point) {
        mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
        mov(EVEX_compress_addr(rsp, reg_zp_compensation_off),
                reg_zp_compensation);
        mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
        mov(EVEX_compress_addr(rsp, reg_src_zero_point_off),
                reg_src_zero_point);
    }
    if (jcp.dst_scale) {
        // The bias pointer spill may not have happened above (unsigned src).
        if (!jcp.signed_input)
            mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
        mov(reg_ptr_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
        mov(EVEX_compress_addr(rsp, reg_dst_scale_off), reg_ptr_dst_scale);
    }
    if (jcp.dst_zero_point) {
        mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
        mov(EVEX_compress_addr(rsp, reg_dst_zero_point_off),
                reg_dst_zero_point);
    }
    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
    mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
    mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
    mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
    mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);

    mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
    mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
    mov(EVEX_compress_addr(rsp, bcast_loop_work_off), reg_bcast_loop_work);
    mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
    mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);

    // Special case: 4-channel blocking with bf16 dst always stores masked.
    if (jcp.ic_block == 4 && jcp.dst_dt == data_type::bf16) {
        Reg32 reg_tail_32 = reg_load_dim_tail_mask.cvt32();
        mov(reg_tail_32, (1 << jcp.ic_block) - 1);
        kmovb(k_load_dim_tail_mask, reg_tail_32);
    }

    // OC (load-dim) tail size; forward passes use the unpadded OC count.
    const int load_dim_tail
            = (one_of(jcp.prop_kind, forward_training, forward_inference)
                      ? jcp.oc_without_padding
                      : jcp.load_dim)
            % jcp.load_block;
    // Paired bf16 stores cover two load blocks, so the mask is twice as wide.
    const bool use_extended_mask
            = jcp.dst_dt == data_type::bf16 && isa_has_bf16(jcp.isa);
    if (load_dim_tail) {
        Reg32 reg_tail_32 = reg_load_dim_tail_mask.cvt32();
        mov(reg_tail_32, (1 << load_dim_tail) - 1);
        kmovw(k_load_dim_tail_mask, reg_tail_32);
        kmovw(postops_mask, reg_tail_32);

        if (use_extended_mask) {
            mov(reg_tail_32.cvt32(),
                    (1 << (load_dim_tail + jcp.load_block)) - 1);
            kmovd(k_load_dim_tail_mask_extended, reg_tail_32.cvt32());
        }
    } else if (jcp.with_binary)
        // No tail, but the binary injector still needs a mask when oc_block
        // is narrower than a full vector.
        if (jcp.oc_block != isa_simd_width_) {
            const int mask = (1 << jcp.oc_block) - 1;
            const Reg32 reg_tail_32 = reg_load_dim_tail_mask.cvt32();
            mov(reg_tail_32, mask);
            kmovw(postops_mask, reg_tail_32);
        }

    // Body of one load-loop iteration for a given blocking factor: select the
    // runtime mask, run bcast_loop, then advance all per-OC-block pointers.
    auto load_loop_body = [=](int load_loop_blk) {
        if (load_dim_tail) {
            // Default mask = all ones; switch to the tail mask only when this
            // is the last OC block and no further load work remains.
            kxnorw(k_load_dim_mask, k_load_dim_mask, k_load_dim_mask);
            if (use_extended_mask)
                kxnord(k_load_dim_mask_extended, k_load_dim_mask_extended,
                        k_load_dim_mask_extended);
            Label no_update_mask;
            test(reg_reduce_pos_flag, FLAG_OC_LAST);
            jz(no_update_mask, T_NEAR);
            cmp(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
            jg(no_update_mask, T_NEAR);
            kmovw(k_load_dim_mask, k_load_dim_tail_mask);
            if (use_extended_mask)
                kmovd(k_load_dim_mask_extended, k_load_dim_tail_mask_extended);
            L(no_update_mask);
        } else if (jcp.ic_block == 4 && jcp.dst_dt == data_type::bf16) {
            kmovw(k_load_dim_mask, k_load_dim_tail_mask);
        }

        bcast_loop(load_loop_blk);
        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
        if (jcp.with_bias) {
            // The live bias pointer may be the stack copy; reload, bump, and
            // spill back so store() sees the advanced value.
            if (jcp.signed_input || jcp.dst_scale)
                mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_off));
            add(reg_bias_data,
                    load_loop_blk * jcp.load_block * jcp.typesize_bia);
            if (jcp.signed_input || jcp.dst_scale)
                mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
        }
        if (jcp.with_binary) {
            // Advance the injector's running OC offset on the stack.
            mov(reg_scratch,
                    EVEX_compress_addr(rsp, reg_binary_post_op_acc_off));
            add(reg_scratch, jcp.load_block * load_loop_blk);
            mov(EVEX_compress_addr(rsp, reg_binary_post_op_acc_off),
                    reg_scratch);
        }
        if (jcp.signed_input) {
            mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off));
            add(reg_comp_data,
                    load_loop_blk * jcp.load_block * sizeof(int32_t));
            mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data);
        }
        if (jcp.src_zero_point) {
            mov(reg_zp_compensation,
                    EVEX_compress_addr(rsp, reg_zp_compensation_off));
            add(reg_zp_compensation,
                    load_loop_blk * jcp.load_block * sizeof(int32_t));
            mov(EVEX_compress_addr(rsp, reg_zp_compensation_off),
                    reg_zp_compensation);
        }
        mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
        mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
        add(reg_ptr_scales,
                jcp.is_oc_scale * load_loop_blk * jcp.load_block
                        * sizeof(float));
        mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
        mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
        add(reg_output_data, load_loop_blk * jcp.load_block * jcp.typesize_out);
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
    };

    Label load_loop_blk[7];

    // ur thresholds that determine which blocking factors are viable (bigger
    // ur leaves fewer registers for load blocks).
    static const int ur_cases_fma_expl_bcast[] = {2, 5, 6, 9, 14, 32};
    const int size_ur_cases_fma = sizeof(ur_cases_fma_expl_bcast);
    const int *ur_cases_fma = ur_cases_fma_expl_bcast;
    const int *ur_cases = ur_cases_fma;
    const int num_ur_cases = (size_ur_cases_fma) / sizeof(*ur_cases);

    // Dispatch: jump to the largest blocking factor the remaining load work
    // allows.
    for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
        int label_idx = num_ur_cases - ur_idx - 1;
        if (jcp.ur <= ur_cases[ur_idx]) {
            cmp(reg_load_loop_work, simd_w * (label_idx + 1));
            jle(load_loop_blk[label_idx], T_NEAR);
        }
    }

    for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
        if (jcp.ur <= ur_cases[ur_idx]) {
            int label_idx = num_ur_cases - ur_idx - 1;
            L(load_loop_blk[label_idx]);
            {
                if (label_idx == 0) {
                    cmp(reg_load_loop_work, 0);
                    je(load_loop_blk[num_ur_cases], T_NEAR);
                }

                for (int _i = 1; _i <= label_idx + 1; _i++) {
                    prefetcht0(ptr[reg_load_data + _i * jcp.ic * jcp.oc_block]);
                    prefetcht1(ptr[reg_output_data + _i * jcp.oc_block]);
                }

                load_loop_body(label_idx + 1);
                if (label_idx - 1 > 0) {
                    cmp(reg_load_loop_work, 2 * label_idx * simd_w);
                    je(load_loop_blk[label_idx - 1], T_NEAR);
                }
                cmp(reg_load_loop_work, (label_idx + 1) * simd_w);
                jge(load_loop_blk[label_idx]);
            }
            // Fall through to a smaller blocking factor for the remainder.
            for (int idx = label_idx - 1; idx > 0; --idx) {
                cmp(reg_load_loop_work, simd_w * (idx + 1));
                je(load_loop_blk[idx], T_NEAR);
            }
            if (ur_idx < num_ur_cases - 2) {
                cmp(reg_load_loop_work, simd_w);
                jle(load_loop_blk[0], T_NEAR);
            }
        }
    }
    L(load_loop_blk[num_ur_cases]);

    add(rsp, stack_space_needed);

    postamble();

    // Eltwise lookup tables must be emitted after the kernel body.
    if (jcp.with_eltwise) postops_injector_->prepare_table();
}
831 | |
832 | status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf( |
833 | jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd, |
834 | const memory_desc_t *&src_md, memory_desc_t &weights_md, |
835 | memory_desc_t &dst_md, memory_desc_t &bias_md, |
836 | const primitive_attr_t &attr, int nthreads, bool reduce_src) { |
837 | |
838 | if (!mayiuse(avx512_core)) return status::unimplemented; |
839 | |
840 | // used for bf16 output |
841 | jcp.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16 |
842 | : bf16_emulation_t::get_isa(); |
843 | |
844 | const memory_desc_wrapper src_d(src_md); |
845 | const memory_desc_wrapper weights_d(&weights_md); |
846 | const memory_desc_wrapper dst_d(&dst_md); |
847 | const memory_desc_wrapper bias_d(&bias_md); |
848 | |
849 | const bool with_groups = weights_d.ndims() == src_d.ndims() + 1; |
850 | if (!one_of(src_d.data_type(), data_type::u8, data_type::s8) |
851 | || weights_d.data_type() != data_type::s8 |
852 | || !one_of(dst_d.data_type(), data_type::f32, data_type::s32, |
853 | data_type::s8, data_type::u8, data_type::bf16)) |
854 | return status::unimplemented; |
855 | |
856 | jcp.nthr = nthreads; |
857 | |
858 | jcp.has_vnni = mayiuse(avx512_core_vnni); |
859 | |
860 | int ndims = src_d.ndims(); |
861 | jcp.ndims = ndims; |
862 | |
863 | jcp.prop_kind = cd.prop_kind; |
864 | |
865 | jcp.ngroups = with_groups ? weights_d.dims()[0] : 1; |
866 | jcp.mb = src_d.dims()[0]; |
867 | jcp.oc = dst_d.dims()[1] / jcp.ngroups; |
868 | jcp.oc_without_padding = jcp.oc; |
869 | jcp.ic = src_d.dims()[1] / jcp.ngroups; |
870 | jcp.ic_without_padding = jcp.ic; |
871 | |
872 | const bool is_1d = ndims == 3; |
873 | const bool is_3d = ndims == 5; |
874 | |
875 | jcp.id = is_3d ? src_d.dims()[2] : 1; |
876 | jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2]; |
877 | jcp.iw = src_d.dims()[ndims - 1]; |
878 | jcp.od = is_3d ? dst_d.dims()[2] : 1; |
879 | jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2]; |
880 | jcp.ow = dst_d.dims()[ndims - 1]; |
881 | |
882 | jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1; |
883 | jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2]; |
884 | jcp.kw = weights_d.dims()[with_groups + ndims - 1]; |
885 | |
886 | jcp.f_pad = is_3d ? cd.padding[0][0] : 0; |
887 | jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4]; |
888 | jcp.l_pad = cd.padding[0][ndims - 3]; |
889 | |
890 | jcp.stride_d = is_3d ? cd.strides[0] : 1; |
891 | jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4]; |
892 | jcp.stride_w = cd.strides[ndims - 3]; |
893 | |
894 | jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef; |
895 | jcp.signed_input = (src_d.data_type() == data_type::s8); |
896 | |
897 | jcp.os = static_cast<dim_t>(jcp.od) * jcp.oh * jcp.ow; |
898 | jcp.is = static_cast<dim_t>(jcp.id) * jcp.ih * jcp.iw; |
899 | |
900 | if (jcp.os > INT_MAX || jcp.is > INT_MAX) return status::unimplemented; |
901 | |
902 | const auto &post_ops = attr.post_ops_; |
903 | const int dw_conv_ind = post_ops.find(primitive_kind::convolution); |
904 | jcp.with_dw_conv = dw_conv_ind != -1; |
905 | // Using dw_conv_ind as upper-bound below, as post-ops after it will be |
906 | // handled in depthwise convolution. |
907 | const int eltwise_ind |
908 | = post_ops.find(primitive_kind::eltwise, 0, dw_conv_ind); |
909 | jcp.with_eltwise = eltwise_ind != -1; |
910 | |
911 | const int binary_ind |
912 | = post_ops.find(primitive_kind::binary, 0, dw_conv_ind); |
913 | jcp.with_binary = binary_ind != -1; |
914 | |
915 | const int sum_ind = post_ops.find(primitive_kind::sum, 0, dw_conv_ind); |
916 | jcp.with_sum = sum_ind != -1; |
917 | |
918 | if (dw_conv_ind >= 0) { |
919 | // dw_conv and post_ops after it are handled externally, so skip them |
920 | jcp.post_ops.entry_.assign(post_ops.entry_.cbegin(), |
921 | post_ops.entry_.cbegin() + dw_conv_ind); |
922 | } else { |
923 | jcp.post_ops = post_ops; |
924 | } |
925 | |
926 | const auto zp = attr.zero_points_; |
927 | jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST); |
928 | jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC); |
929 | jcp.zp_src_is_common |
930 | = zp.common(DNNL_ARG_SRC); // otherwise, it's per-channel |
931 | assert(IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common)); |
932 | |
933 | if ((jcp.dst_zero_point || jcp.src_zero_point) && jcp.with_dw_conv) |
934 | return status::unimplemented; |
935 | |
936 | format_tag_t dat_tag = utils::pick( |
937 | ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc); |
938 | jcp.src_tag = src_d.matches_one_of_tag(dat_tag); |
939 | jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag); |
940 | |
941 | bool args_ok = jcp.src_tag == dat_tag && jcp.dst_tag == dat_tag; |
942 | if (!args_ok) return status::unimplemented; |
943 | |
944 | if (jcp.ngroups == 1) { |
945 | jcp.oc = rnd_up(jcp.oc, 16); |
946 | jcp.ic = rnd_up(jcp.ic, 16); |
947 | } |
948 | |
949 | using namespace injector; |
950 | static constexpr bool sum_at_pos_0_only = false; |
951 | static constexpr bool sum_requires_scale_one = false; |
952 | static constexpr bool sum_requires_zp_zero = false; |
953 | const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum}, |
954 | jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one, |
955 | sum_requires_zp_zero}); |
956 | if (!post_ops_ok_) return status::unimplemented; |
957 | |
958 | const int simd_w = (jcp.ic % 16 == 0 && jcp.oc % 16 == 0) |
959 | ? 16 |
960 | : (jcp.ic % 8 == 0 && jcp.oc % 8 == 0) ? 8 : 4; |
961 | |
962 | auto set_or_check_wei_format = [&]() -> bool { |
963 | using namespace format_tag; |
964 | using namespace memory_extra_flags; |
965 | const format_tag_t wei_tags[3][2][3] |
966 | = {{{OIw4i16o4i, OIhw4i16o4i, OIdhw4i16o4i}, |
967 | {gOIw4i16o4i, gOIhw4i16o4i, gOIdhw4i16o4i}}, |
968 | {{OIw2i8o4i, OIhw2i8o4i, OIdhw2i8o4i}, |
969 | {gOIw2i8o4i, gOIhw2i8o4i, gOIdhw2i8o4i}}, |
970 | {{OIw4o4i, OIhw4o4i, OIdhw4o4i}, |
971 | {gOIw4o4i, gOIhw4o4i, gOIdhw4o4i}}}; |
972 | |
973 | const int simd_idx = simd_w == 16 ? 0 : simd_w == 8 ? 1 : 2; |
974 | const auto wei_tag = wei_tags[simd_idx][with_groups][ndims - 3]; |
975 | memory_desc_t want_wei_md = weights_md; |
976 | memory_desc_init_by_tag(want_wei_md, wei_tag); |
977 | if (jcp.signed_input) { |
978 | want_wei_md.extra.flags = 0 | compensation_conv_s8s8 | scale_adjust; |
979 | want_wei_md.extra.compensation_mask |
980 | = (1 << 0) + (with_groups ? (1 << 1) : 0); |
981 | want_wei_md.extra.scale_adjust |
982 | = mayiuse(avx512_core_vnni) ? 1.f : 0.5f; |
983 | } |
984 | if (jcp.src_zero_point) { |
985 | want_wei_md.extra.flags |= compensation_conv_asymmetric_src; |
986 | want_wei_md.extra.asymm_compensation_mask |
987 | = (1 << 0) + (with_groups ? (1 << 1) : 0); |
988 | } |
989 | |
990 | if (weights_md.format_kind == format_kind::any) { |
991 | weights_md = want_wei_md; |
992 | return true; |
993 | } |
994 | |
995 | return weights_md == want_wei_md; |
996 | }; |
997 | |
998 | if (!set_or_check_wei_format()) return status::unimplemented; |
999 | |
1000 | args_ok = true && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0 |
1001 | && jcp.f_pad == 0 && jcp.t_pad == 0 && jcp.l_pad == 0 |
1002 | && jcp.stride_d == 1 && jcp.stride_h == 1 |
1003 | && jcp.stride_w == 1 // TODO: support some strides |
1004 | && jcp.od == jcp.id && jcp.oh == jcp.ih |
1005 | && jcp.ow == jcp.iw // enforce rpad = 0 |
1006 | && jcp.kd == 1 && jcp.kh == 1 && jcp.kw == 1; |
1007 | if (!args_ok) return status::unimplemented; |
1008 | |
1009 | jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef; |
1010 | jcp.dst_dt = cd.dst_desc.data_type; |
1011 | jcp.sum_dt = post_ops.get_sum_dt(jcp.dst_dt); |
1012 | |
1013 | jcp.ic_block = jcp.oc_block = simd_w; |
1014 | |
1015 | jcp.typesize_in = types::data_type_size(src_d.data_type()); |
1016 | jcp.typesize_out = types::data_type_size(dst_d.data_type()); |
1017 | jcp.typesize_bia |
1018 | = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0; |
1019 | |
1020 | const int SMALL_SPATIAL = 7 * 7; |
1021 | const int BIG_REDUCE_DIM = 1024; |
1022 | |
1023 | int load_blocking = 0; |
1024 | int load_blocking_max = 0; |
1025 | int bcast_blocking = 0; |
1026 | int bcast_blocking_max = 0; |
1027 | int reduce_blocking = 0; |
1028 | int reduce_blocking_max = 0; |
1029 | jcp.load_grp_count = 1; |
1030 | jcp.use_vmovntps = false; |
1031 | |
1032 | const int L2_size |
1033 | = platform::get_per_core_cache_size(2) / sizeof(jcp.typesize_in); |
1034 | const int L2_capacity = (L2_size * 3) / 4; |
1035 | |
1036 | const bool |
1037 | = jcp.dst_dt == data_type::bf16 && !isa_has_bf16(jcp.isa); |
1038 | int size_treshold = req_extra_bf16_regs ? 25 : 28; |
1039 | int max_regs = 0; |
1040 | int min_regs = 6; |
1041 | if (jcp.has_vnni && !req_extra_bf16_regs) |
1042 | max_regs = ((jcp.oh > size_treshold && jcp.ow > size_treshold) |
1043 | && (jcp.oc < 128 || jcp.ic < 128)) |
1044 | ? min_regs |
1045 | : 9; |
1046 | else |
1047 | max_regs = 8; |
1048 | jcp.expl_bcast = true; |
1049 | |
1050 | if (jcp.mb == 1 && jcp.ic > 128 |
1051 | && (jcp.oh <= size_treshold && jcp.ow <= size_treshold)) { |
1052 | if (jcp.os <= SMALL_SPATIAL && jcp.oc * jcp.ic < L2_size) |
1053 | max_regs = min_regs; // mobilenet_v2 performance improvement |
1054 | jcp.ur = nstl::min<dim_t>(max_regs, jcp.os); |
1055 | } else { |
1056 | const int spatial = jcp.od * jcp.oh; |
1057 | jcp.ur = 1; |
1058 | for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) { |
1059 | if ((spatial >= size_treshold && spatial % ur_w == 0) |
1060 | || (spatial < size_treshold && jcp.os % ur_w == 0)) { |
1061 | jcp.ur = ur_w; |
1062 | break; |
1063 | } |
1064 | } |
1065 | if (jcp.ur == 1) { |
1066 | jcp.ur = nstl::min<dim_t>(max_regs, jcp.os); |
1067 | int os_tail = jcp.os % max_regs; |
1068 | for (int i = max_regs; i >= min_regs; i--) { |
1069 | int i_tail = jcp.os % i; |
1070 | if (i_tail > os_tail || i_tail == 0) { |
1071 | jcp.ur = i; |
1072 | os_tail = i_tail; |
1073 | if (i_tail == 0) break; |
1074 | } |
1075 | } |
1076 | } |
1077 | } |
1078 | if (jcp.with_dw_conv) jcp.ur = nstl::min(jcp.ow, jcp.ur); |
1079 | |
1080 | jcp.reduce_dim = jcp.ic; |
1081 | jcp.reduce_block = jcp.ic_block; |
1082 | |
1083 | jcp.load_dim = jcp.oc; |
1084 | jcp.load_block = jcp.oc_block; |
1085 | |
1086 | jcp.bcast_dim = jcp.is; |
1087 | |
1088 | jcp.bcast_block = jcp.ur; |
1089 | |
1090 | jcp.reduce_loop_unroll = jcp.reduce_block; |
1091 | jcp.reduce_loop_bcast_step = jcp.reduce_loop_unroll * jcp.typesize_in; |
1092 | |
1093 | jcp.reduce_loop_load_step |
1094 | = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in; |
1095 | |
1096 | jcp.bcast_loop_output_step |
1097 | = jcp.ur * jcp.ngroups * jcp.oc_without_padding * jcp.typesize_out; |
1098 | jcp.bcast_loop_output_substep = -1; // unused |
1099 | jcp.bcast_loop_bcast_step |
1100 | = jcp.ur * jcp.ngroups * jcp.ic_without_padding * jcp.typesize_in; |
1101 | jcp.bcast_loop_bcast_substep = -1; // unused |
1102 | |
1103 | jcp.load_loop_load_step = jcp.reduce_dim * jcp.load_block * jcp.typesize_in; |
1104 | |
1105 | jcp.load_loop_iter_step = jcp.load_block; |
1106 | |
1107 | jcp.loop_order = reduce_src ? loop_blr : loop_lbr; |
1108 | |
1109 | int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); |
1110 | int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); |
1111 | |
1112 | reduce_blocking = nb_reduce; |
1113 | if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) |
1114 | reduce_blocking = 64; |
1115 | else if (jcp.bcast_dim > SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM) |
1116 | reduce_blocking = 16; |
1117 | reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true); |
1118 | reduce_blocking *= jcp.reduce_block; |
1119 | |
1120 | bool cmp_reduce = reduce_blocking <= jcp.reduce_dim; |
1121 | if (cmp_reduce) jcp.loop_order = reduce_src ? loop_rbl : loop_rlb; |
1122 | load_blocking = jcp.load_dim; |
1123 | |
1124 | jcp.load_grp_count = div_up(jcp.nthr, jcp.mb * jcp.ngroups * nb_bcast); |
1125 | jcp.load_grp_count = best_divider( |
1126 | jcp.nthr, jcp.load_grp_count, 2 * jcp.load_grp_count, false); |
1127 | |
1128 | if (jcp.bcast_dim <= SMALL_SPATIAL |
1129 | && jcp.load_dim * jcp.reduce_dim >= L2_size) { |
1130 | jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4); |
1131 | } else if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.mb <= jcp.nthr |
1132 | && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) { |
1133 | jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); // |
1134 | load_blocking = jcp.load_block; |
1135 | } |
1136 | |
1137 | bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast, |
1138 | div_up(jcp.nthr, jcp.load_grp_count)) |
1139 | * jcp.bcast_block; |
1140 | bcast_blocking = nstl::min<dim_t>(jcp.bcast_dim, bcast_blocking); |
1141 | bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block); |
1142 | |
1143 | int space_for_bcast = (L2_capacity - /* kernel_size - */ |
1144 | 2 * jcp.load_block * reduce_blocking - jcp.ur * reduce_blocking |
1145 | - 3 * 1024); |
1146 | if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity) space_for_bcast /= 2; |
1147 | |
1148 | int bcast_in_cache |
1149 | = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking); |
1150 | bcast_blocking = nstl::min( |
1151 | bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block)); |
1152 | |
1153 | load_blocking_max = load_blocking; |
1154 | bcast_blocking_max = bcast_blocking * 3 / 2; |
1155 | reduce_blocking_max = reduce_blocking; |
1156 | |
1157 | assert(load_blocking); |
1158 | assert(load_blocking_max); |
1159 | assert(bcast_blocking); |
1160 | assert(bcast_blocking_max); |
1161 | assert(reduce_blocking); |
1162 | assert(reduce_blocking_max); |
1163 | assert(load_blocking % jcp.load_block == 0); |
1164 | assert(reduce_blocking % jcp.reduce_block == 0); |
1165 | assert(load_blocking_max % jcp.load_block == 0); |
1166 | assert(reduce_blocking_max % jcp.reduce_block == 0); |
1167 | |
1168 | assert(jcp.reduce_loop_unroll % 4 == 0); |
1169 | assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0); |
1170 | |
1171 | assert(jcp.bcast_block % jcp.ur == 0); |
1172 | assert(jcp.reduce_dim % jcp.reduce_block == 0); |
1173 | |
1174 | jcp.ur_tail = (jcp.with_dw_conv ? jcp.ow : jcp.bcast_dim) % jcp.ur; |
1175 | |
1176 | jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block; |
1177 | jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block; |
1178 | jcp.nb_load_blocking = load_blocking / jcp.load_block; |
1179 | jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block; |
1180 | jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block; |
1181 | jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block; |
1182 | |
1183 | jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block); |
1184 | jcp.nb_load = div_up(jcp.load_dim, jcp.load_block); |
1185 | jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block); |
1186 | |
1187 | // miniumum size of load dim chunk for work distribution within threads |
1188 | jcp.nb_load_chunk = 1; |
1189 | // peformance improvements for googlenet_v3, mb=1; |
1190 | // TODO: generalize this condition and rewrite it in appropriate manner |
1191 | int ncores_per_socket = (int)cpu().getNumCores( |
1192 | Xbyak::util::IntelCpuTopologyLevel::CoreLevel); |
1193 | if (jcp.mb == 1 && jcp.nb_load % 4 == 0 && jcp.ic / jcp.oc >= 4 |
1194 | && jcp.ic * jcp.oc <= L2_size && jcp.nthr <= ncores_per_socket) { |
1195 | jcp.nb_load_chunk = 4; |
1196 | jcp.load_grp_count = nstl::max(jcp.nb_load / 4, jcp.load_grp_count); |
1197 | } |
1198 | |
1199 | /* adjust the thread decomposition |
1200 | * to improve the perf for small size problem |
1201 | * the threshold 8192 is empirical |
1202 | * simply set the thread to max of nb_load and nb_bcast now |
1203 | * TODO: add get_thr_eff func to compute optimal thread |
1204 | * TODO: Threshold can be increase when init stride > 1 */ |
1205 | auto bcast_size |
1206 | = (dim_t)jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim; |
1207 | if (jcp.typesize_in * bcast_size < 8192 && jcp.ngroups < jcp.nthr |
1208 | && jcp.nb_bcast * jcp.nb_load < jcp.nthr) { |
1209 | int nthr = nstl::max(jcp.nb_load, jcp.nb_bcast); |
1210 | jcp.nthr = nstl::min(jcp.nthr, nthr); |
1211 | } |
1212 | |
1213 | const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC); |
1214 | const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS); |
1215 | const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST); |
1216 | const int wei_mask_per_oc = 1 << (int)with_groups; |
1217 | jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc; |
1218 | jcp.dst_scale = !dst_scales.has_default_values(); |
1219 | |
1220 | // only common src & dst scales are supported |
1221 | // only common and per-oc-channel weight scales are supported |
1222 | const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc) |
1223 | && everyone_is(src_scales.mask_, dst_scales.mask_, 0); |
1224 | if (!scales_ok) return status::unimplemented; |
1225 | |
1226 | jcp.wei_adj_scale |
1227 | = (weights_d.extra().flags & memory_extra_flags::scale_adjust) |
1228 | ? weights_d.extra().scale_adjust |
1229 | : 1.f; |
1230 | |
1231 | return status::success; |
1232 | } |
1233 | |
1234 | void jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad( |
1235 | memory_tracking::registrar_t &scratchpad, |
1236 | const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { |
1237 | using namespace dnnl::impl::memory_tracking::names; |
1238 | |
1239 | const int wei_mask = attr.scales_.get(DNNL_ARG_WEIGHTS).mask_; |
1240 | const dim_t scales_count = wei_mask == 0 ? 1 : jcp.oc * jcp.ngroups; |
1241 | const dim_t count = nstl::max<dim_t>(scales_count, (dim_t)jcp.ic_block); |
1242 | scratchpad.book<float>(key_conv_adjusted_scales, count); |
1243 | } |
1244 | |
1245 | template struct _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Xbyak::Zmm>; |
1246 | template struct _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Xbyak::Ymm>; |
1247 | template struct _jit_avx512_core_x8s8s32x_1x1_conv_kernel<Xbyak::Xmm>; |
1248 | |
1249 | } // namespace x64 |
1250 | } // namespace cpu |
1251 | } // namespace impl |
1252 | } // namespace dnnl |
1253 | |