1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/memory_tracking.hpp" |
19 | #include "common/nstl.hpp" |
20 | #include "common/type_helpers.hpp" |
21 | #include "common/utils.hpp" |
22 | |
23 | #include "cpu/platform.hpp" |
24 | #include "cpu/scale_utils.hpp" |
25 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
26 | #include "cpu/x64/jit_avx512_core_amx_1x1_conv_kernel.hpp" |
27 | |
28 | #define GET_OFF(field) offsetof(jit_conv_call_s, field) |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | namespace x64 { |
34 | |
35 | using namespace dnnl::impl::memory_tracking::names; |
36 | using namespace dnnl::impl::data_type; |
37 | using namespace dnnl::impl::utils; |
38 | using namespace Xbyak; |
39 | |
// Kernel constructor: stores the conv config / attributes and, when any
// post-op is present, builds the post-ops injector used at code-emission time.
jit_avx512_core_amx_1x1_fwd_kernel_t::jit_avx512_core_amx_1x1_fwd_kernel_t(
        const jit_conv_conf_t &ajcp, const primitive_attr_t &attr,
        const memory_desc_t &dst_md)
    : jit_generator(jit_name(), nullptr, MAX_CODE_SIZE, true, avx512_core_amx)
    , jcp(ajcp)
    , attr_(attr) {
    if (jcp.with_eltwise || jcp.with_binary || jcp.with_sum) {
        using namespace binary_injector;
        // Scratch GPRs handed to the binary injector for rhs address math.
        const auto &rhs_addr_reg = bin_injector_helper_reg_1;
        const auto &rhs_helper_reg = bin_injector_helper_reg_2;
        const auto &rhs_addr_cache_reg = bin_injector_helper_reg_3;
        // The injector does not save/restore GPRs or vector registers itself;
        // call sites wrap injector use in conditional_register_preserve_guard_t
        // (see interleave_store() / store_output()).
        static constexpr bool preserve_gpr = false;
        static constexpr bool preserve_vmm = false;
        // Number of valid lanes in the last, partially-filled oc block.
        const size_t tail_size = jcp.oc_without_padding % isa_simd_width_;
        static constexpr bool use_exact_tail_scalar_bcast = true;

        // 31 is the vmm index reserved as the injector's helper register.
        const rhs_arg_static_params_t rhs_arg_static_params {31, rhs_addr_reg,
                rhs_helper_reg, rhs_addr_cache_reg, preserve_gpr, preserve_vmm,
                GET_OFF(post_ops_binary_rhs_arg_vec), GET_OFF(dst_orig),
                memory_desc_wrapper(dst_md), tail_size, ktail_mask,
                use_exact_tail_scalar_bcast};
        const static_params_t static_params {
                this->param1, rhs_arg_static_params};

        postops_injector_ = utils::make_unique<
                injector::jit_uni_postops_injector_t<avx512_core>>(
                this, jcp.post_ops, static_params);
    }
}
69 | |
70 | // Tile register decomposition |
71 | int jit_avx512_core_amx_1x1_fwd_kernel_t::get_out_tensor(int h, int i) const { |
72 | return C_BASE + h * jcp.nb_os_blocking + i; |
73 | } |
74 | int jit_avx512_core_amx_1x1_fwd_kernel_t::get_inp_tensor(int h) const { |
75 | return I_BASE + h; |
76 | } |
77 | int jit_avx512_core_amx_1x1_fwd_kernel_t::get_wei_tensor(int i) const { |
78 | return W_BASE + i; |
79 | } |
80 | |
81 | bool jit_avx512_core_amx_1x1_fwd_kernel_t::is_bf16() const { |
82 | return jcp.src_dt == data_type::bf16; |
83 | } |
84 | |
85 | // Code generation |
86 | void jit_avx512_core_amx_1x1_fwd_kernel_t::init_runtime_counters() { |
87 | row_count_ = 0; |
88 | buf_count_ = 0; |
89 | is_store_done_ = false; |
90 | is_buffer_empty_ = true; |
91 | } |
92 | |
93 | size_t jit_avx512_core_amx_1x1_fwd_kernel_t::out_h_shift() const { |
94 | return (size_t)jcp.ow * jcp.ngroups * jcp.oc_without_padding; |
95 | } |
96 | |
97 | size_t jit_avx512_core_amx_1x1_fwd_kernel_t::out_w_shift() const { |
98 | return (size_t)jcp.ngroups * jcp.oc_without_padding; |
99 | } |
100 | |
101 | size_t jit_avx512_core_amx_1x1_fwd_kernel_t::inp_offset( |
102 | int h, int w, int icb) const { |
103 | return (size_t)jcp.typesize_in |
104 | * (h * jcp.iw * jcp.ngroups * jcp.ic_without_padding |
105 | + w * jcp.ngroups * jcp.ic_without_padding |
106 | + icb * jcp.ic_block_int_np); |
107 | } |
108 | |
109 | size_t jit_avx512_core_amx_1x1_fwd_kernel_t::out_row_offset( |
110 | int h, int w, int ocb) const { |
111 | return (size_t)jcp.typesize_out |
112 | * (h * jcp.ow * jcp.ngroups * jcp.oc_without_padding |
113 | + w * jcp.ngroups * jcp.oc_without_padding |
114 | + ocb * jcp.oc_block); |
115 | } |
116 | |
// Point reg_postop at the half of the double-buffered accumulator workspace
// that the next batch of interleaved stores will read, then flip the parity.
void jit_avx512_core_amx_1x1_fwd_kernel_t::update_buffer_pointers() {
    // Parity of the current (shift == false) or next (shift == true) buffer.
    auto buffer_offset = [=](bool shift) { return ((buf_count_ + shift) % 2); };
    // Byte size of one half of the workspace.
    int wsp_shift = jcp.typesize_acc * (jcp.wsp_buffer_size / 2);

    int postop_shift = wsp_shift * buffer_offset(true);

    mov(reg_postop, wsp_ptr);
    add(reg_postop, postop_shift);

    buf_count_++;
}
128 | |
// Emit up to jcp.per_one_pstore buffered-row stores, interleaved with the
// compute loop (called from icb_loop's tdpbxxd sequence). Rows previously
// spilled to the workspace are loaded back, post-processed and written to
// the destination one at a time; row_count_ tracks progress across calls.
void jit_avx512_core_amx_1x1_fwd_kernel_t::interleave_store() {
    int scnd_dim = jcp.nb_os_blocking * jcp.tile_width;

    for (int c = 0;
            c < jcp.per_one_pstore && !is_store_done_ && !is_buffer_empty_;
            c++) {
        // Decompose the flat row counter into (ocb, osb, row-in-tile).
        int ocb = (row_count_ / scnd_dim);
        int osb = (row_count_ % scnd_dim) / jcp.tile_width;
        int row = (row_count_ % scnd_dim) % jcp.tile_width;

        const Zmm zmm_r = zmm_out(row);

        // Output spatial coordinates of this row.
        int oh = ((osb * jcp.tile_width + row) / jcp.ow);
        int ow = ((osb * jcp.tile_width + row) % jcp.ow);

        {
            // preserve registers used by binary post_ops injector
            const injector_utils::conditional_register_preserve_guard_t
                    cond_register_guard(jcp.with_binary, this,
                            {bin_injector_helper_reg_1,
                                    bin_injector_helper_reg_2,
                                    bin_injector_helper_reg_3});
            // Byte offset of this row inside the spilled-tile workspace.
            const int wsp_row_offset = jcp.typesize_acc
                    * (osb * jcp.nb_oc_blocking * jcp.max_width * jcp.oc_block
                            + ocb * jcp.max_width * jcp.oc_block
                            + row * jcp.oc_block);

            vmovups(zmm_r, ptr[reg_postop + wsp_row_offset]);
            store_output_vector(zmm_r, ocb, oh, ow);
            row_count_++;
        }

        // All rows of the buffered block stored: advance out_ptr by one full
        // os block and mark the store complete.
        int exp_row_count
                = jcp.tile_width * jcp.nb_oc_blocking * jcp.nb_os_blocking;
        if (row_count_ == exp_row_count) {
            int oh = ((jcp.nb_os_blocking * jcp.tile_width) / jcp.ow);
            int ow = ((jcp.nb_os_blocking * jcp.tile_width) % jcp.ow);
            size_t out_offset = jcp.typesize_out
                    * (oh * out_h_shift() + ow * out_w_shift());
            add(out_ptr, out_offset);
            row_count_ = 0;
            is_store_done_ = true;
        }
    }
}
174 | |
175 | Ymm jit_avx512_core_amx_1x1_fwd_kernel_t::ymm_mask( |
176 | const Ymm ymm_in, bool mask_flag, bool store) { |
177 | return mask_flag ? (store ? ymm_in | ktail_mask : ymm_in | ktail_mask | T_z) |
178 | : ymm_in; |
179 | } |
180 | |
181 | Zmm jit_avx512_core_amx_1x1_fwd_kernel_t::zmm_mask( |
182 | const Zmm zmm_in, bool mask_flag, bool store) { |
183 | return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z) |
184 | : zmm_in; |
185 | } |
186 | |
187 | void jit_avx512_core_amx_1x1_fwd_kernel_t::cvt2ps(data_type_t type_in, |
188 | const Zmm zmm_in, const Operand &op, bool mask_flag = false) { |
189 | using namespace dnnl::impl::data_type; |
190 | const Zmm zmm = zmm_mask(zmm_in, mask_flag); |
191 | switch (type_in) { |
192 | case bf16: |
193 | vpmovzxwd(zmm, op); |
194 | vpslld(zmm_in, zmm_in, 16); |
195 | break; |
196 | case f32: |
197 | case s32: vmovups(zmm, op); break; |
198 | case s8: vpmovsxbd(zmm, op); break; |
199 | case u8: vpmovzxbd(zmm, op); break; |
200 | default: assert(!"unsupported data type" ); |
201 | } |
202 | if (utils::one_of(type_in, s32, s8, u8)) vcvtdq2ps(zmm_in, zmm_in); |
203 | } |
204 | |
// Register the sum post-op with the injector: the lambda (executed by the
// injector when it reaches the sum entry) loads the previous destination
// value from `addr`, applies the sum zero-point/scale and accumulates into
// `zmm_out`. No code is emitted here directly.
void jit_avx512_core_amx_1x1_fwd_kernel_t::apply_sum(const Zmm zmm_out,
        const float *p_sum_scale, const int32_t *p_sum_zp,
        const Xbyak::Address &addr, const bool mask_flag) {
    if (p_sum_scale) {
        // Capture by value: the lambda runs later, after the locals die.
        const auto p_sum_scale_val = *p_sum_scale;
        const auto p_sum_zp_val = *p_sum_zp;
        const auto sum_injector = [&, zmm_out, p_sum_scale_val, p_sum_zp_val,
                                          mask_flag]() {
            cvt2ps(jcp.sum_dt, zmm_prev_dst, addr, mask_flag);
            if (p_sum_zp_val != 0) {
                // Subtract the sum zero-point (broadcast from memory).
                vcvtdq2ps(zmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
                vsubps(zmm_prev_dst, zmm_sum_zp);
            }
            if (p_sum_scale_val == 1.f)
                vaddps(zmm_out, zmm_prev_dst);
            else
                vfmadd231ps(zmm_out, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
        };
        postops_injector_->set_lambda_injector(
                primitive_kind::sum, sum_injector);
    }
}
227 | |
// Run the full post-ops chain (sum / eltwise / binary) on `zmm_out`.
// `addr`/`off` identify the destination element this vector maps to, which
// the binary injector needs to index its rhs operand.
void jit_avx512_core_amx_1x1_fwd_kernel_t::apply_postops(const Zmm zmm_out,
        const float *p_sum_scale, const int32_t *p_sum_zp,
        const Xbyak::Address &addr, const size_t off, const bool mask_flag) {
    if (jcp.with_eltwise || jcp.with_binary
            || (jcp.with_sum && p_sum_scale != nullptr)) {
        // Hook the sum lambda first so the injector applies post-ops in order.
        apply_sum(zmm_out, p_sum_scale, p_sum_zp, addr, mask_flag);

        const auto vmm_idx = zmm_out.getIdx();
        if (jcp.with_binary) {
            // Binary post-ops need the output pointer and element offset to
            // locate the matching rhs element.
            binary_injector::rhs_arg_dynamic_params_t rhs_arg_params;
            rhs_arg_params.vmm_idx_to_out_reg.emplace(vmm_idx, out_ptr);
            rhs_arg_params.vmm_idx_to_out_elem_off_val.emplace(vmm_idx, off);
            if (mask_flag) rhs_arg_params.vmm_tail_idx_.emplace(vmm_idx);

            postops_injector_->compute_vector(vmm_idx, rhs_arg_params);
        } else {
            postops_injector_->compute_vector(vmm_idx);
        }
    }
}
248 | |
249 | bool jit_avx512_core_amx_1x1_fwd_kernel_t::is_fast_postops( |
250 | const jit_conv_conf_t &jcp) { |
251 | const auto &p = jcp.post_ops; |
252 | auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); }; |
253 | auto is_sum = [&](int idx) { |
254 | const bool require_scale_one = jcp.src_dt == data_type::bf16; |
255 | return p.entry_[idx].is_sum(require_scale_one); |
256 | }; |
257 | switch (p.len()) { |
258 | case 0: return true; |
259 | case 1: return is_relu(0) || is_sum(0); |
260 | case 2: return is_sum(0) && is_relu(1); |
261 | default: return false; |
262 | } |
263 | return false; |
264 | } |
265 | |
// Down-convert the f32 values in Zmm(idx) to bf16 (16 x 16-bit in a Ymm)
// and store them to `addr`, masked by ktail_mask when `mask_flag` is set.
inline void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_ymm_bf16(
        const int idx, const Xbyak::Address &addr, const bool mask_flag) {
    Ymm ymm_out = Ymm(idx);
    vcvtneps2bf16(ymm_out, Zmm(idx));
    vmovdqu16(addr, ymm_mask(ymm_out, mask_flag, true));
}
272 | |
// Batched ("fast post-ops") store of one tile's rows for the int8 kernel:
// applies src zero-point compensation, scales, bias, sum, relu, dst scale,
// dst zero-point and saturation to all jcp.tile_width row vectors, then
// writes them to the destination in jcp.dst_dt.
void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vectors_int8(
        int ocb, int osb) {
    // Mask only applies to the last oc block when oc has a tail.
    const bool mask_flag
            = last_oc_block_flag_ && ocb == (jcp.nb_oc_blocking - 1);
    const auto &p = attr_.post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    const float *p_sum_scale = nullptr;
    const int32_t *p_sum_zp = nullptr;
    if (sum_idx != -1) {
        const auto &p_entry = p.entry_[sum_idx];
        p_sum_scale = &p_entry.sum.scale;
        p_sum_zp = &p_entry.sum.zero_point;
    }
    // Load sum scale / zero-point addresses only when they are non-trivial.
    if (p_sum_scale) {
        if (*p_sum_scale != 1.f)
            mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
        if (*p_sum_zp != 0)
            mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
    }

    // Source zero-point compensation: acc += src_zp * compensation[oc].
    if (jcp.src_zero_point) {
        const int zp_offset = sizeof(int32_t) * ocb * jcp.oc_block;
        const Zmm zmm_zp_m = zmm_mask(zmm_zp, mask_flag);
        vpmulld(zmm_zp_m, zmm_src_zp,
                EVEX_compress_addr(reg_zp_compensation, zp_offset));
        for (int j = 0; j < jcp.tile_width; j++) {
            const Zmm zmm_r = zmm_out(j);
            vpaddd(zmm_r, zmm_r, zmm_zp_m);
        }
    }

    // Convert the s32 accumulators to f32 for the remaining arithmetic.
    for (int j = 0; j < jcp.tile_width; j++) {
        const Zmm zmm_r = zmm_out(j);
        vcvtdq2ps(zmm_r, zmm_r);
    }

    // Apply (possibly per-oc) output scales.
    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
    for (int j = 0; j < jcp.tile_width; j++) {
        const int scale_offset
                = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block);
        const Zmm zmm_r = zmm_out(j);
        const Zmm zmm_r_msk = zmm_mask(zmm_r, mask_flag);
        vmulps(zmm_r_msk, zmm_r,
                EVEX_compress_addr(reg_ptr_scales, scale_offset));
    }

    // Add bias (converted to f32) to every row.
    if (jcp.with_bias) {
        mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
        int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
        auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
        cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
        for (int j = 0; j < jcp.tile_width; j++) {
            const Zmm zmm_r = zmm_out(j);
            vaddps(zmm_r, zmm_r, zmm_bias);
        }
    }

    // Sum post-op: acc += scale * (dst - sum_zp), inlined per row.
    if (p_sum_zp && *p_sum_zp != 0)
        vcvtdq2ps(zmm_sum_zp, ptr_b[reg_ptr_sum_zp]);
    if (jcp.with_sum && p_sum_scale != nullptr) {
        const auto p_sum_scale_val = *p_sum_scale;
        const auto p_sum_zp_val = *p_sum_zp;
        for (int j = 0; j < jcp.tile_width; j++) {
            int h = ((osb * jcp.tile_width + j) / jcp.ow);
            int w = ((osb * jcp.tile_width + j) % jcp.ow);

            const auto off = out_row_offset(h, w, ocb);
            const auto addr = EVEX_compress_addr(out_ptr, off);

            const Zmm zmm_r = zmm_out(j);
            cvt2ps(jcp.sum_dt, zmm_prev_dst, addr, mask_flag);
            if (p_sum_zp_val != 0) vsubps(zmm_prev_dst, zmm_sum_zp);
            if (p_sum_scale_val == 1.f)
                vaddps(zmm_r, zmm_prev_dst);
            else
                vfmadd231ps(zmm_r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
        }
    }
    // Eltwise on this fast path is relu with zero negative slope (max with 0);
    // see is_fast_postops().
    if (jcp.with_eltwise) {
        vxorps(zmm_zero, zmm_zero, zmm_zero);
        for (int j = 0; j < jcp.tile_width; j++) {
            const Zmm zmm_r = zmm_out(j);
            vmaxps(zmm_r, zmm_r, zmm_zero);
        }
    }

    // Destination scale (single value broadcast to all rows).
    if (jcp.dst_scale) {
        mov(reg_ptr_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
        for (int j = 0; j < jcp.tile_width; j++) {
            const Zmm zmm_r = zmm_out(j);
            const Zmm zmm_r_msk = zmm_mask(zmm_r, mask_flag);
            vmulps(zmm_r_msk, zmm_r, EVEX_compress_addr(reg_ptr_dst_scale, 0));
        }
    }

    // Destination zero-point (already converted to f32 in zmm_dst_zp).
    if (jcp.dst_zero_point) {
        for (int j = 0; j < jcp.tile_width; j++) {
            const Zmm zmm_r = zmm_out(j);
            vaddps(zmm_r, zmm_r, zmm_dst_zp);
        }
    }

    // Properly saturate the accumulators for integer datatypes
    if (one_of(jcp.dst_dt, u8, s8, s32)) {
        init_saturate_f32(
                zmm_zero, zmm_saturation, aux_reg_saturation, f32, jcp.dst_dt);
        for (int j = 0; j < jcp.tile_width; j++) {
            const Zmm zmm_r = zmm_out(j);
            saturate_f32(zmm_r, zmm_zero, zmm_saturation, jcp.dst_dt);
            vcvtps2dq(zmm_r, zmm_r);
        }
    }

    // Final down-conversion and store of each row in the destination type.
    for (int j = 0; j < jcp.tile_width; j++) {
        const int h = ((osb * jcp.tile_width + j) / jcp.ow);
        const int w = ((osb * jcp.tile_width + j) % jcp.ow);
        const auto off = out_row_offset(h, w, ocb);
        const auto addr = EVEX_compress_addr(out_ptr, off);

        const Zmm zmm_out_store = zmm_mask(zmm_out(j), mask_flag, true);
        switch (jcp.dst_dt) {
            case data_type::f32:
            case data_type::s32: vmovups(addr, zmm_out_store); break;
            case data_type::bf16:
                store_output_ymm_bf16(zmm_out_store.getIdx(), addr, mask_flag);
                break;
            case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
            case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
            default: assert(!"unknown dst_dt" );
        }
    }
}
405 | |
// Store a single accumulator row for the int8 kernel: src zero-point
// compensation, scales, bias, full post-ops chain (via the injector),
// dst scale / zero-point, saturation, and the final typed store.
void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vector_int8(
        const Zmm zmm_out, int ocb, int h, int w) {

    const auto off = out_row_offset(h, w, ocb);
    const auto addr = EVEX_compress_addr(out_ptr, off);

    // Mask only applies to the last oc block when oc has a tail.
    const bool mask_flag
            = last_oc_block_flag_ && ocb == (jcp.nb_oc_blocking - 1);
    const auto &p = attr_.post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    const float *p_sum_scale = nullptr;
    const int32_t *p_sum_zp = nullptr;
    if (sum_idx != -1) {
        const auto &p_entry = p.entry_[sum_idx];
        p_sum_scale = &p_entry.sum.scale;
        p_sum_zp = &p_entry.sum.zero_point;
    }

    // Load sum scale / zero-point addresses only when they are non-trivial.
    if (p_sum_scale) {
        if (*p_sum_scale != 1.f)
            mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
        if (*p_sum_zp != 0)
            mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
    }

    mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);

    int scale_offset = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block);
    if (jcp.with_bias) {
        int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
        auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
        cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
    }
    // acc += src_zp * compensation[oc]
    if (jcp.src_zero_point) {
        const int zp_offset = sizeof(int32_t) * ocb * jcp.oc_block;
        const Zmm zmm_zp_m = zmm_mask(zmm_zp, mask_flag);
        vpmulld(zmm_zp_m, zmm_src_zp,
                EVEX_compress_addr(reg_zp_compensation, zp_offset));
        vpaddd(zmm_out, zmm_out, zmm_zp_m);
    }
    /* add to zmm_accum: compensation, bias and permute */
    vcvtdq2ps(zmm_out, zmm_out);

    const Zmm zmm_out_msk = zmm_mask(zmm_out, mask_flag);
    vmulps(zmm_out_msk, zmm_out,
            EVEX_compress_addr(reg_ptr_scales, scale_offset));

    if (jcp.with_bias) vaddps(zmm_out_msk, zmm_out, zmm_bias);

    // Full post-ops chain (sum/eltwise/binary) through the injector.
    apply_postops(zmm_out, p_sum_scale, p_sum_zp, addr, off, mask_flag);

    if (jcp.dst_scale) {
        mov(reg_ptr_dst_scale, ptr[param1 + GET_OFF(dst_scale)]);
        vmulps(zmm_out, zmm_out, EVEX_compress_addr(reg_ptr_dst_scale, 0));
    }
    if (jcp.dst_zero_point) { vaddps(zmm_out, zmm_out, zmm_dst_zp); }

    // Properly saturate the accumulators for integer datatypes
    if (one_of(jcp.dst_dt, u8, s8, s32)) {
        init_saturate_f32(
                zmm_zero, zmm_saturation, aux_reg_saturation, f32, jcp.dst_dt);
        saturate_f32(zmm_out, zmm_zero, zmm_saturation, jcp.dst_dt);
        vcvtps2dq(zmm_out, zmm_out);
    }

    // Final down-conversion and store in the destination type.
    const Zmm zmm_out_store = zmm_mask(zmm_out, mask_flag, true);
    switch (jcp.dst_dt) {
        case data_type::f32:
        case data_type::s32: vmovups(addr, zmm_out_store); break;
        case data_type::bf16:
            store_output_ymm_bf16(zmm_out.getIdx(), addr, mask_flag);
            break;
        case data_type::s8: vpmovsdb(addr, zmm_out_store); break;
        case data_type::u8: vpmovusdb(addr, zmm_out_store); break;
        default: assert(!"unknown dst_dt" );
    }
}
484 | |
// Batched ("fast post-ops") store of one tile's rows for the bf16 kernel:
// bias, sum (scale 1 only on this path, see is_fast_postops()), relu, then
// the typed store of every jcp.tile_width row.
void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vectors_bf16(
        int ocb, int osb) {
    // Mask only applies to the last oc block when oc has a tail.
    const bool mask_flag
            = last_oc_block_flag_ && ocb == (jcp.nb_oc_blocking - 1);

    if (jcp.with_bias) {
        mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
        const int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
        const auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
        cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
        for (int j = 0; j < jcp.tile_width; j++) {
            const Zmm zmm_r = zmm_out(j);
            vaddps(zmm_r, zmm_r, zmm_bias);
        }
    }

    // Sum post-op: acc += dst (previous destination contents).
    if (jcp.with_sum) {
        for (int j = 0; j < jcp.tile_width; j++) {
            int h = ((osb * jcp.tile_width + j) / jcp.ow);
            int w = ((osb * jcp.tile_width + j) % jcp.ow);
            const auto off = out_row_offset(h, w, ocb);
            const auto addr = EVEX_compress_addr(out_ptr, off);
            const Zmm zmm_r = zmm_out(j);
            cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag);
            vaddps(zmm_r, zmm_prev_dst);
        }
    }
    // Eltwise on this fast path is relu (max with 0); see is_fast_postops().
    if (jcp.with_eltwise) {
        vxorps(zmm_zero, zmm_zero, zmm_zero);
        for (int j = 0; j < jcp.tile_width; j++) {
            const Zmm zmm_r = zmm_out(j);
            vmaxps(zmm_r, zmm_r, zmm_zero);
        }
    }

    // Store each row: down-convert to bf16 when required, else f32 store.
    for (int j = 0; j < jcp.tile_width; j++) {
        const int h = ((osb * jcp.tile_width + j) / jcp.ow);
        const int w = ((osb * jcp.tile_width + j) % jcp.ow);
        const auto off = out_row_offset(h, w, ocb);
        const auto addr = EVEX_compress_addr(out_ptr, off);
        const Zmm zmm_r = zmm_out(j);
        if (jcp.dst_dt == data_type::bf16) {
            store_output_ymm_bf16(zmm_r.getIdx(), addr, mask_flag);
        } else {
            vmovups(addr, zmm_mask(zmm_r, mask_flag, true));
        }
    }
}
533 | |
// Store a single accumulator row for the bf16 kernel: sum (handled inline
// here, so it is skipped in the injector), bias, remaining post-ops, and
// the typed store.
void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vector_bf16(
        const Zmm zmm_out, int ocb, int h, int w) {
    const auto off = out_row_offset(h, w, ocb);
    const auto addr = EVEX_compress_addr(out_ptr, off);

    // Mask only applies to the last oc block when oc has a tail.
    const bool mask_flag
            = last_oc_block_flag_ && ocb == (jcp.nb_oc_blocking - 1);

    const auto &p = attr_.post_ops_;

    mov(reg_bias, ptr[param1 + GET_OFF(bias)]);

    // Sum post-op: load previous dst (widening bf16 -> f32 if needed) and add.
    const int sum_idx = p.find(primitive_kind::sum);
    if (sum_idx != -1) {
        if (jcp.dst_dt == data_type::bf16) {
            vpmovzxwd(zmm_mask(zmm_prev_dst, mask_flag), addr);
            vpslld(zmm_prev_dst, zmm_prev_dst, 16);
            vaddps(zmm_out, zmm_prev_dst);
        } else {
            vmovups(zmm_mask(zmm_prev_dst, mask_flag), addr);
            vaddps(zmm_out, zmm_prev_dst);
        }
    }
    if (jcp.with_bias) {
        int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
        auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
        if (jcp.bia_dt == data_type::bf16) {
            // Widen bf16 bias to f32 before the add.
            vpmovzxwd(zmm_mask(zmm_bias, mask_flag), bias_addr);
            vpslld(zmm_bias, zmm_bias, 16);
            vaddps(zmm_out, zmm_bias);
        } else
            vaddps(zmm_mask(zmm_out, mask_flag), bias_addr);
    }

    // Sum was already applied above, so tell the injector to skip it.
    static constexpr auto skip_sum_in_injection = nullptr;
    apply_postops(zmm_out, skip_sum_in_injection, skip_sum_in_injection, addr,
            off, mask_flag);

    if (jcp.dst_dt == data_type::bf16) {
        store_output_ymm_bf16(zmm_out.getIdx(), addr, mask_flag);
    } else {
        vmovups(addr, zmm_mask(zmm_out, mask_flag, true));
    }
}
578 | |
579 | // Store all rows of a tile |
580 | void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vectors( |
581 | int ocb, int osb) { |
582 | if (is_bf16()) { |
583 | store_output_vectors_bf16(ocb, osb); |
584 | } else { |
585 | store_output_vectors_int8(ocb, osb); |
586 | } |
587 | } |
588 | |
589 | // Store single row |
590 | void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output_vector( |
591 | const Zmm zmm_out, int ocb, int h, int w) { |
592 | if (is_bf16()) { |
593 | store_output_vector_bf16(zmm_out, ocb, h, w); |
594 | } else { |
595 | store_output_vector_int8(zmm_out, ocb, h, w); |
596 | } |
597 | } |
598 | |
599 | void jit_avx512_core_amx_1x1_fwd_kernel_t::prepare_output() { |
600 | for (int osb = 0; osb < jcp.nb_os_blocking; osb++) |
601 | for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) |
602 | tilezero(Tmm(get_out_tensor(osb, ocb))); |
603 | } |
604 | |
// Spill the accumulator tiles to the workspace and, when `do_store` is set,
// post-process and write them to the destination. When `do_store` is false
// the rows stay buffered and are drained later by interleave_store().
void jit_avx512_core_amx_1x1_fwd_kernel_t::store_output(
        bool do_store, bool has_tail) {

    auto store_output_subblock = [=](int ocb, int osb) {
        // Spill the tile to its slot in the accumulation workspace.
        const int wsp_offset = jcp.typesize_acc
                * (osb * jcp.nb_oc_blocking * jcp.max_width * jcp.oc_block
                        + ocb * jcp.max_width * jcp.oc_block);
        tilestored(ptr[wsp_ptr + stride_seq + wsp_offset],
                Tmm(get_out_tensor(osb, ocb)));

        // preserve registers used by binary post_ops injector
        const injector_utils::conditional_register_preserve_guard_t
                cond_register_guard(jcp.with_binary, this,
                        {bin_injector_helper_reg_1, bin_injector_helper_reg_2});
        is_buffer_empty_ = false;
        is_store_done_ = (do_store) ? true : false;
        // Immediate (non-interleaved) store: load each buffered row back and
        // post-process/store it, one row at a time.
        for (int j = 0; j < jcp.tile_width && do_store; j++) {
            int oh_ = ((osb * jcp.tile_width + j) / jcp.ow);
            int ow_ = ((osb * jcp.tile_width + j) % jcp.ow);

            auto addr = ptr[wsp_ptr + jcp.typesize_acc * (j * jcp.oc_block)
                    + wsp_offset];
            const Zmm zmm_r = zmm_out(j);
            vmovups(zmm_r, addr);
            if (!jcp.is_fast_postops) store_output_vector(zmm_r, ocb, oh_, ow_);
        }
        // Fast post-ops path processes all rows of the tile in one batch.
        if (do_store && jcp.is_fast_postops) store_output_vectors(ocb, osb);
    };

    auto store_output_block = [=](int os_b = 1) {
        // Src/dst zero-point values are loaded once per block.
        if (jcp.src_zero_point) {
            mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
            mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
            vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0));
        }
        if (jcp.dst_zero_point) {
            mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
            vcvtdq2ps(zmm_dst_zp,
                    EVEX_compress_addr(reg_dst_zero_point, 0, true));
        }
        for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
            for (int osb = 0; osb < os_b; osb++)
                store_output_subblock(ocb, osb);
    };

    Label label_oc_store, label_done;

    // The last spatial sub-block (last_h flag) stores a single os block.
    if (check_last_sb_) {
        mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);
        cmp(reg_last_h, 1);
        je(label_oc_store, T_NEAR);
    }

    store_output_block(jcp.nb_os_blocking);
    jmp(label_done, T_NEAR);

    L(label_oc_store);
    store_output_block();

    L(label_done);
    update_buffer_pointers();
}
667 | |
// Emit the inner-channel-block (reduction) loop: load input/weight tiles,
// run the AMX dot-product, interleave pending stores, handle the ic tail
// (with tile reconfiguration), then store the results.
void jit_avx512_core_amx_1x1_fwd_kernel_t::icb_loop(bool do_store) {
    enum tiles_cfg_t { cfg_tiles, cfg_tiles_tail };
    enum restore_tiles_t { write_tiles, read_tiles };

    // Pick the AMX dot-product instruction matching the src/wei type combo.
    auto tdpbxxd = [=](const Tmm &x1, const Tmm &x2, const Tmm &x3) {
        if (jcp.src_dt == data_type::bf16 && jcp.wei_dt == data_type::bf16) {
            tdpbf16ps(x1, x2, x3);
        } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::u8) {
            tdpbuud(x1, x2, x3);
        } else if (jcp.src_dt == data_type::u8 && jcp.wei_dt == data_type::s8) {
            tdpbusd(x1, x2, x3);
        } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::u8) {
            tdpbsud(x1, x2, x3);
        } else if (jcp.src_dt == data_type::s8 && jcp.wei_dt == data_type::s8) {
            tdpbssd(x1, x2, x3);
        } else {
            assert(!"unsupported combination" );
        }
    };

    auto tileloadd_nt = [=](const Tmm &t1, int offset) {
        int ab_size = jcp.nb_os2_blocking * jcp.nb_os_blocking * jcp.tile_width
                * (jcp.nb_ic_int * jcp.ic_block_int_np
                        + jcp.nb_oc_blocking * jcp.oc_block);
        int c_size = (jcp.nb_ic_int * jcp.ic_block_int_np * jcp.nb_oc_blocking
                * jcp.oc_block);
        // If the size of src + wei used in the kernel cannot fit into L1 cache,
        // use non-temporal load of weights to help keep src in L1 cache
        if (static_cast<size_t>(jcp.typesize_in * (ab_size + c_size))
                >= platform::get_per_core_cache_size(1))
            tileloaddt1(t1, ptr[wei_ptr + offset + stride_seq]);
        else
            tileloadd(t1, ptr[wei_ptr + offset + stride_seq]);
    };

    // One reduction step: load the input tiles for every os block, then for
    // each oc block load the weight tile and accumulate, interleaving any
    // pending buffered stores between tdpbxxd ops.
    auto compute_block = [=](int icb, int os_b) {
        for (int osb = 0; osb < os_b; osb++) {
            int ih = ((osb * jcp.tile_width) / jcp.ow) * jcp.stride_h;
            int iw = ((osb * jcp.tile_width) % jcp.ow) * jcp.stride_w;
            tileloadd(Tmm(get_inp_tensor(osb)),
                    ptr[inp_ptr + stride_nhwc + inp_offset(ih, iw, icb)]);
        }
        for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
            const int wei_offset = jcp.typesize_in
                    * (ocb
                                    * utils::rnd_up(jcp.ic_without_padding,
                                            jcp.ic_block_int)
                                    * jcp.oc_block
                            + icb * jcp.ic_block_int_np * jcp.oc_block);
            tileloadd_nt(Tmm(get_wei_tensor(ocb)), wei_offset);
            for (int osb = 0; osb < os_b; osb++) {
                tdpbxxd(Tmm(get_out_tensor(osb, ocb)), Tmm(get_inp_tensor(osb)),
                        Tmm(get_wei_tensor(ocb)));
                interleave_store();
            }
        }
    };

    // Reload the palette from the runtime-provided tile configuration.
    auto reconfig_tiles = [=](tiles_cfg_t cfg) {
        tilerelease();
        if (cfg == cfg_tiles) {
            mov(reg_scratch, ptr[param1 + GET_OFF(tile_cfg)]);
        } else if (cfg == cfg_tiles_tail) {
            mov(reg_scratch, ptr[param1 + GET_OFF(tile_cfg_tail)]);
        }
        ldtilecfg(ptr[reg_scratch]);
    };

    // Spill (write_tiles) or reload (read_tiles) all accumulator tiles via
    // the scratch buffer — needed around a tile reconfiguration, which
    // clobbers tile contents.
    auto restore_output_tiles = [=](int os_b, restore_tiles_t restore) {
        mov(reg_tilebuff, ptr[param1 + GET_OFF(src_prf)]);
        for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++)
            for (int osb = 0; osb < os_b; osb++) {
                const int wsp_offset = jcp.typesize_acc
                        * (osb * jcp.nb_oc_blocking * jcp.max_width
                                        * jcp.oc_block
                                + ocb * jcp.max_width * jcp.oc_block);
                if (restore == write_tiles)
                    tilestored(ptr[reg_tilebuff + stride_seq + wsp_offset],
                            Tmm(get_out_tensor(osb, ocb)));
                else if (restore == read_tiles)
                    tileloadd(Tmm(get_out_tensor(osb, ocb)),
                            ptr[reg_tilebuff + stride_seq + wsp_offset]);
            }
    };

    // Switch between the main and tail tile configurations, preserving the
    // accumulators across the switch. Only needed when there are multiple
    // ic blocks.
    auto reset_tiles = [=](int os_b, bool tail) {
        if (jcp.nb_ic_int != 1) {
            restore_output_tiles(os_b, write_tiles);
            reconfig_tiles((tail) ? cfg_tiles_tail : cfg_tiles);
            restore_output_tiles(os_b, read_tiles);
        }
    };

    auto compute_icb_loop = [=](int os_b = 1) {
        // With an ic tail (and single os block), the last ic block is
        // processed separately under the tail tile configuration.
        int shift = (get_ic_tail() && os_b == 1) ? 1 : 0;
        int nb_ic_int = jcp.nb_ic_int - shift;

        if (jcp.src_zero_point) {
            mov(reg_zp_compensation, ptr[param1 + GET_OFF(zp_compensation)]);
            mov(reg_src_zero_point, ptr[param1 + GET_OFF(src_zero_point)]);
            vpbroadcastd(zmm_src_zp, EVEX_compress_addr(reg_src_zero_point, 0));
        }
        if (jcp.dst_zero_point) {
            mov(reg_dst_zero_point, ptr[param1 + GET_OFF(dst_zero_point)]);
            vcvtdq2ps(zmm_dst_zp,
                    EVEX_compress_addr(reg_dst_zero_point, 0, true));
        }

        for (int icb = 0; icb < nb_ic_int; icb++)
            compute_block(icb, os_b);

        // Tail processing
        if (get_ic_tail() && os_b == 1) {
            reset_tiles(os_b, true);
            compute_block(nb_ic_int, os_b);
            reset_tiles(os_b, false);
        }
    };

    Label label_last_os, label_compute_done, label_tail, label_done;

    // Runtime stride (bytes) between two consecutive input rows of a tile.
    int stride_nhwc_ = jcp.typesize_in * jcp.ngroups * jcp.ic_without_padding
            * jcp.stride_w;
    mov(stride_nhwc, stride_nhwc_);

    prepare_output();
    { // Compute
        if (check_last_sb_) {
            mov(reg_last_h, ptr[param1 + GET_OFF(last_h)]);
            cmp(reg_last_h, 1);
            je(label_last_os, T_NEAR);
        }
        compute_icb_loop(jcp.nb_os_blocking);

        jmp(label_compute_done, T_NEAR);

        L(label_last_os);
        compute_icb_loop();
    }
    L(label_compute_done);
    { // Store
        if (jcp.tile_tail && check_last_sb_)
            store_output(do_store, true);
        else
            store_output(do_store, false);
    }
}
815 | |
816 | void jit_avx512_core_amx_1x1_fwd_kernel_t::osb_loop(int nb_os) { |
817 | for (int osi = 0; osi < nb_os; osi++) { |
818 | bool do_store = IMPLICATION(jcp.per_one_pstore, (osi == nb_os - 1)); |
819 | check_last_sb_ = do_store; |
820 | |
821 | icb_loop(do_store); |
822 | |
823 | int oh = (((osi + 1) * jcp.nb_os_blocking * jcp.tile_width) / jcp.ow); |
824 | int ow = (((osi + 1) * jcp.nb_os_blocking * jcp.tile_width) % jcp.ow); |
825 | if (do_store) { |
826 | size_t out_offset = jcp.typesize_out |
827 | * (oh * out_h_shift() + ow * out_w_shift()); |
828 | add(out_ptr, out_offset); |
829 | } |
830 | |
831 | int ih = oh * jcp.stride_h; |
832 | int iw = ow * jcp.stride_w; |
833 | add(inp_ptr, inp_offset(ih, iw, 0)); |
834 | } |
835 | } |
836 | |
837 | int jit_avx512_core_amx_1x1_fwd_kernel_t::get_ic_tail() const { |
838 | return (jcp.ic_without_padding % jcp.ic_block_int_np); |
839 | } |
840 | |
// Top-level kernel emitter: sets up the output-channel tail mask, loads the
// run-time arguments from the jit_conv_call_s parameter block, and dispatches
// to either the blocked (osb) or the plain output-space loop depending on the
// per-call 'is_osb' flag.
void jit_avx512_core_amx_1x1_fwd_kernel_t::generate() {
    preamble();

    // A tail mask is only needed when oc was padded up to a full block.
    last_oc_block_flag_ = (jcp.oc_without_padding != jcp.oc);
    if (last_oc_block_flag_) {
        Xbyak::Label mask_is_set;

        // Use mask 0xF by default for all output data and post-ops
        // loads / stores with block index
        // ocb = occ * jcp.nb_oc_blocking + (jcp.nb_oc_blocking - 1)
        // TODO: use masked loads / stores for the last occ only
        int mask = (1 << jcp.oc_block) - 1;
        Xbyak::Reg32 regw_tmp = reg_tmp.cvt32();
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);
        // Only the last oc-block group gets the partial (tail) mask.
        mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
        cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
        jne(mask_is_set, T_NEAR);

        // Reset the mask
        mask = (1 << (jcp.oc_without_padding % jcp.oc_block)) - 1;
        mov(regw_tmp, mask);
        kmovw(ktail_mask, regw_tmp);

        L(mask_is_set);
    }

    // Load the kernel argument pointers from the call parameter block.
    mov(inp_ptr, ptr[param1 + GET_OFF(src)]);
    mov(wei_ptr, ptr[param1 + GET_OFF(filt)]);
    mov(out_ptr, ptr[param1 + GET_OFF(dst)]);
    mov(wsp_ptr, ptr[param1 + GET_OFF(acc_s32)]);

    mov(reg_is_osb, ptr[param1 + GET_OFF(is_osb)]);

    // Tiles are stored/loaded with a fixed 64-byte row stride in memory.
    constexpr int tile_mem_stride_in_bytes = 64;
    mov(stride_seq, tile_mem_stride_in_bytes);

    init_runtime_counters();
    update_buffer_pointers();

    Xbyak::Label label_no_osb, label_done;

    // Branch on the runtime flag: blocked vs. single-pass output-space loop.
    cmp(reg_is_osb, 0);
    je(label_no_osb, T_NEAR);

    osb_loop(jcp.nb_os2_blocking);
    jmp(label_done, T_NEAR);

    L(label_no_osb);
    osb_loop();

    L(label_done);
    postamble();

    // Eltwise post-ops need their lookup table emitted after the code body.
    if (jcp.with_eltwise) postops_injector_->prepare_table();
}
897 | |
898 | void jit_avx512_core_amx_1x1_fwd_kernel_t::tile_configure(char *tcfg_buff) { |
899 | |
900 | int tile_max_columns_in_bytes |
901 | = amx::get_max_column_bytes(amx::get_target_palette()); |
902 | const int max_palette_size_in_bytes = 64; |
903 | |
904 | auto cfg_tiles = [=](palette_config_t *buff, int Ac) { |
905 | char *_tc = (char *)buff; |
906 | for (int i = 0; i < max_palette_size_in_bytes; i++) |
907 | _tc[i] = 0; |
908 | |
909 | int Ar = jcp.tile_width; |
910 | int Br = Ac / jcp.typesize_acc; |
911 | int Cr = jcp.tile_width; |
912 | |
913 | int Bc = tile_max_columns_in_bytes; |
914 | int Cc = tile_max_columns_in_bytes; |
915 | |
916 | for (int s = 0; s < jcp.nb_os_blocking; s++) |
917 | tc_configure_tile(buff, get_inp_tensor(s), Ar, Ac); |
918 | for (int i = 0; i < jcp.nb_oc_blocking; i++) |
919 | tc_configure_tile(buff, get_wei_tensor(i), Br, Bc); |
920 | |
921 | for (int s = 0; s < jcp.nb_os_blocking; s++) |
922 | for (int i = 0; i < jcp.nb_oc_blocking; i++) { |
923 | tc_configure_tile(buff, get_out_tensor(s, i), Cr, Cc); |
924 | } |
925 | |
926 | buff->palette_id = amx::get_target_palette(); |
927 | }; |
928 | |
929 | int Ac = jcp.typesize_in |
930 | * ((jcp.nb_ic_int == 1 && get_ic_tail()) ? get_ic_tail() |
931 | : jcp.ic_block_int_np); |
932 | |
933 | cfg_tiles((palette_config_t *)tcfg_buff, Ac); |
934 | if (jcp.nb_ic_int > 1 && get_ic_tail()) { |
935 | int Ac = jcp.typesize_in * get_ic_tail(); |
936 | char *_t = tcfg_buff + max_palette_size_in_bytes; |
937 | cfg_tiles((palette_config_t *)(_t), Ac); |
938 | } |
939 | } |
940 | |
// Validates the convolution descriptor against this kernel's constraints and
// fills the jit_conv_conf_t with all blocking/tiling parameters. May rewrite
// 'any' memory descriptors (src/dst/weights/bias) to the formats the kernel
// requires. Returns status::success on acceptance, status::unimplemented when
// the shape/attributes are out of scope for this implementation.
status_t jit_avx512_core_amx_1x1_fwd_kernel_t::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, memory_desc_t &src_md,
        memory_desc_t &weights_md, memory_desc_t &dst_md,
        memory_desc_t &bias_md, primitive_attr_t &attr, int nthreads) {
    using namespace prop_kind;

    const memory_desc_wrapper src_d(&src_md);
    const memory_desc_wrapper weights_d(&weights_md);
    const memory_desc_wrapper dst_d(&dst_md);
    const memory_desc_wrapper bias_d(&bias_md);

    // Grouped weights carry an extra leading 'g' dimension.
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    int ndims = src_d.ndims();
    bool is_1d = ndims == 3;
    bool is_3d = ndims == 5;

    // Only bf16 and int8 data-type combinations are supported on AMX.
    const bool is_bf16_convolution
            = everyone_is(true, src_d.data_type() == data_type::bf16,
                    weights_d.data_type() == data_type::bf16,
                    one_of(dst_d.data_type(), data_type::bf16, data_type::f32));
    const bool is_int8_convolution = everyone_is(true,
            (src_d.data_type() == data_type::u8
                    || src_d.data_type() == data_type::s8),
            weights_d.data_type() == data_type::s8,
            one_of(dst_d.data_type(), data_type::f32, data_type::s32,
                    data_type::s8, data_type::u8, data_type::bf16));

    bool supported = mayiuse(avx512_core_amx)
            && (is_bf16_convolution || is_int8_convolution);
    if (!supported) return status::unimplemented;

    // Populate problem geometry from the descriptors.
    jcp = zero<decltype(jcp)>();
    jcp.isa = avx512_core_amx;
    jcp.ndims = ndims;
    jcp.prop_kind = cd.prop_kind;
    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ic_without_padding = jcp.ic;
    jcp.id = is_3d ? src_d.dims()[2] : 1;
    jcp.ih = !is_1d ? src_d.dims()[ndims - 2] : 1;
    jcp.iw = src_d.dims()[ndims - 1];
    jcp.od = is_3d ? dst_d.dims()[2] : 1;
    jcp.oh = !is_1d ? dst_d.dims()[ndims - 2] : 1;
    jcp.ow = dst_d.dims()[ndims - 1];
    jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
    jcp.kh = !is_1d ? weights_d.dims()[with_groups + ndims - 2] : 1;
    jcp.kw = weights_d.dims()[with_groups + ndims - 1];
    jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
    jcp.t_pad = !is_1d ? cd.padding[0][ndims - 4] : 0;
    jcp.l_pad = cd.padding[0][ndims - 3];
    jcp.stride_d = is_3d ? cd.strides[0] : 1;
    jcp.stride_h = !is_1d ? cd.strides[ndims - 4] : 1;
    jcp.stride_w = cd.strides[ndims - 3];
    jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;

    // This kernel handles only 1x1(x1) filters without padding.
    if (!(jcp.kd == 1 && jcp.kh == 1 && jcp.kw == 1))
        return status::unimplemented;

    if (!(jcp.f_pad == 0 && jcp.t_pad == 0 && jcp.l_pad == 0))
        return status::unimplemented;

    jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
    jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
    jcp.dilate_w = cd.dilates[ndims - 3];

    jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);

    // No dilation, depthwise, or multi-group support.
    if (jcp.dilate_d != 0 || jcp.dilate_h != 0 || jcp.dilate_w != 0)
        return status::unimplemented;
    if (jcp.is_depthwise)
        return status::unimplemented; // TODO: add support of DW convolution
    if (jcp.ngroups > 1)
        return status::unimplemented; // TODO: add support for non-unit groups

    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    jcp.dst_dt = cd.dst_desc.data_type;
    jcp.src_dt = cd.src_desc.data_type;
    jcp.wei_dt = cd.weights_desc.data_type;

    // Dispatch small shapes to VNNI for better performance
    const auto is_small_shape = jcp.od * jcp.oh * jcp.ow <= 4 && jcp.ic <= 512
            && jcp.mb * jcp.ngroups * jcp.ic * jcp.oc <= static_cast<int32_t>(
                       platform::get_per_core_cache_size(1) / 2);
    const auto is_3d_small_ic = jcp.ndims == 5 && jcp.ic * jcp.oc <= 32
            && jcp.od >= 128 && jcp.oh >= 128 && jcp.ow >= 128;
    if (is_small_shape || is_3d_small_ic) return status::unimplemented;

    // Zero points: only common (single) src zero point, int8 only.
    const auto zp = attr.zero_points_;
    jcp.dst_zero_point = !zp.has_default_values(DNNL_ARG_DST);
    jcp.src_zero_point = !zp.has_default_values(DNNL_ARG_SRC);
    jcp.zp_src_is_common = zp.common(
            DNNL_ARG_SRC); // otherwise, it's per-channel (not supported)
    if (!IMPLICATION(jcp.src_zero_point, jcp.zp_src_is_common)
            || !IMPLICATION(jcp.dst_zero_point || jcp.src_zero_point,
                    is_int8_convolution))
        return status::unimplemented;

    jcp.nthr = nthreads;

    // Channel blocking: 16 for compute blocks; the "int" block is the
    // dot-product granularity (32 for bf16, 64 for int8).
    jcp.ic_block = 16;
    jcp.ic_block_int = is_bf16_convolution ? 32 : 64;
    jcp.ic_block_int_np = jcp.ic_block_int;
    if (jcp.ic_block_int < jcp.ic_without_padding
            && jcp.ic_without_padding % jcp.ic_block_int != 0) {
        // Order of blocks comes from empirical observation
        static const int try_blocks[] = {32, 48, 40, 56};
        for (auto blk_size : try_blocks) {
            const int _blk_size = is_bf16_convolution ? blk_size / 2 : blk_size;
            if (jcp.ic_without_padding % _blk_size == 0) {
                jcp.ic_block_int_np = _blk_size;
                break;
            }
        }
    }
    jcp.oc_block = 16;

    // 1x1 with unit strides must map output 1:1 onto input spatial dims.
    bool args_ok = true && jcp.ic % 4 == 0
            && (jcp.ow == jcp.iw && jcp.stride_w == 1)
            && (jcp.oh == jcp.ih && jcp.stride_h == 1)
            && (jcp.od == jcp.id && jcp.stride_d == 1);
    if (!args_ok) return status::unimplemented;

    // Pad channels up to full blocks (single-group case only reaches here).
    if (jcp.ngroups == 1) {
        jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
        jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
    }

    // Pick (or verify) the blocked weights layout the kernel expects;
    // the vnni granularity differs between bf16 (2i) and int8 (4i).
    auto set_or_check_wei_format = [&]() {
        using namespace format_tag;
        using namespace memory_extra_flags;
        format_tag_t wei_tag;
        wei_tag = (is_bf16_convolution)
                ? pick(with_groups + 2 * (ndims - 3), OIw16i16o2i, gOIw16i16o2i,
                        OIhw16i16o2i, gOIhw16i16o2i, OIdhw16i16o2i,
                        gOIdhw16i16o2i)
                : pick(with_groups + 2 * (ndims - 3), OIw16i16o4i, gOIw16i16o4i,
                        OIhw16i16o4i, gOIhw16i16o4i, OIdhw16i16o4i,
                        gOIdhw16i16o4i);
        memory_desc_t want_wei_md = weights_md;
        memory_desc_init_by_tag(want_wei_md, wei_tag);

        // Asymmetric-src compensation is stored alongside the weights.
        if (jcp.src_zero_point) {
            want_wei_md.extra.flags |= compensation_conv_asymmetric_src;
            want_wei_md.extra.asymm_compensation_mask = (1 << 0)
                    + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
        }
        if (weights_md.format_kind == format_kind::any) {
            weights_md = want_wei_md;
            return true;
        }
        return weights_md == want_wei_md;
    };

    if (!set_or_check_wei_format()) { return status::unimplemented; }

    // Activations must be in channels-last (nwc/nhwc/ndhwc) layout.
    format_tag_t dat_tag = utils::pick(
            ndims - 3, format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);

    if (src_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(src_md, dat_tag));
        jcp.src_tag = dat_tag;
    } else {
        jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
    }
    if (jcp.src_tag != dat_tag) { return status::unimplemented; }

    if (dst_d.format_kind() == format_kind::any) {
        CHECK(memory_desc_init_by_tag(dst_md, dat_tag));
        jcp.dst_tag = dat_tag;
    } else {
        jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
    }
    if (jcp.dst_tag != dat_tag) { return status::unimplemented; }

    if (jcp.with_bias) {
        if (bias_d.format_kind() == format_kind::any)
            CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
    }

    CHECK(attr.set_default_formats(&dst_md));

    // Post-ops: record which kinds are present and validate the chain.
    const auto &p = attr.post_ops_;

    const int sum_ind = p.find(primitive_kind::sum);
    jcp.with_sum = sum_ind != -1;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;
    const int binary_ind = p.find(primitive_kind::binary);
    jcp.with_binary = binary_ind != -1;
    jcp.sum_dt = p.get_sum_dt(jcp.dst_dt);

    jcp.post_ops = p;
    jcp.is_fast_postops = is_fast_postops(jcp);

    using namespace injector;
    // bf16 path restricts sum to position 0 with unit scale and zero zp.
    const bool sum_at_pos_0_only = (jcp.src_dt == data_type::bf16);
    const bool sum_requires_scale_one = sum_at_pos_0_only;
    const bool sum_requires_zp_zero = sum_at_pos_0_only;
    const bool post_ops_ok_ = post_ops_ok({avx512_core, {eltwise, binary, sum},
            jcp.post_ops, &dst_d, sum_at_pos_0_only, sum_requires_scale_one,
            sum_requires_zp_zero});
    if (!post_ops_ok_) return status::unimplemented;

    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());
    jcp.typesize_bia
            = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
    jcp.typesize_acc = sizeof(int32_t);

    jcp.nb_ic = jcp.ic / jcp.ic_block;
    jcp.nb_oc = jcp.oc / jcp.oc_block;
    jcp.nb_ic_int = div_up(jcp.ic_without_padding, jcp.ic_block_int_np);

    jcp.max_width = amx::get_max_rows(amx::get_target_palette());
    if (jcp.max_width <= 0) return status::unimplemented;

    // Tile-width selection: prefer a width that evenly divides the spatial
    // extent (or the whole output space for small spatial dims).
    const int size_treshold = 32;
    const int min_width
            = 1; // TODO: Possible optimizations: do not use small values
    const int spatial = jcp.od * jcp.oh;
    const int os = jcp.od * jcp.oh * jcp.ow;

    jcp.tile_width = 1;
    for (int s_size = jcp.max_width; s_size >= min_width; s_size--) {
        if ((spatial >= size_treshold && spatial % s_size == 0)
                || (spatial < size_treshold && os % s_size == 0)) {
            jcp.tile_width = s_size;
            break;
        }
    }
    // No even divisor found: pick the width minimizing the leftover tail.
    if (jcp.tile_width == 1) {
        jcp.tile_width = nstl::min(jcp.max_width, os);
        jcp.tile_tail = os % jcp.max_width;
        for (int i = jcp.max_width; i >= min_width; i--) {
            int i_tail = os % i;
            if (i_tail > jcp.tile_tail || i_tail == 0) {
                jcp.tile_width = i;
                jcp.tile_tail = i_tail;
                if (i_tail == 0) break;
            }
        }
        if (jcp.tile_width < min_width && jcp.tile_tail < min_width)
            jcp.tile_tail = 0;
    }

    /* TODO: Add stride support !
    while ((jcp.stride_h != 1 || jcp.stride_w != 1)
            && (jcp.ow % jcp.tile_width != 0) || jcp.tile_width > 16) {
        jcp.tile_width = jcp.ow / 2;
    }
    */

    // TODO: Add support for spatial tails
    if (jcp.tile_tail != 0) return status::unimplemented;

    // TODO: Implement efficient tile tail processing. Now just go to common
    // case if we utilize half of tile or less.
    if (jcp.tile_width <= jcp.max_width / 2) return status::unimplemented;

    // Blocking factors over oc blocks, ic blocks and output-space tiles.
    jcp.nb_oc_blocking = (jcp.nb_oc % 2 == 0) ? 2 : 1;
    jcp.nb_ic_blocking = 1;
    jcp.nb_os_blocking = (os / jcp.tile_width > 2) ? 2 : 1;
    jcp.nb_os2_blocking = (jcp.nb_os_blocking > 1)
            ? ((jcp.nb_os_blocking * jcp.tile_width) % 2 == 0) ? 2 : 1
            : 1;
    jcp.nb_os = os / jcp.tile_width;

    // Per-thread int32 accumulation buffer (doubled for double-buffering).
    jcp.wsp_buffer_size = (size_t)2 * jcp.nb_os_blocking * jcp.nb_oc_blocking
            * jcp.max_width * jcp.oc_block;

    // Ratio used to interleave partial stores with compute ops; disabled
    // when the ratio gets too large to be profitable.
    int ops_tile_store
            = jcp.nb_oc_blocking * jcp.nb_os_blocking * jcp.tile_width;
    int avaliable_ops = jcp.nb_ic_int * jcp.nb_oc_blocking * jcp.nb_os_blocking;
    jcp.per_one_pstore
            = (avaliable_ops) ? ops_tile_store / avaliable_ops + 1 : 0;
    if (jcp.per_one_pstore > 12) jcp.per_one_pstore = 0;

    const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
    const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
    const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
    const int wei_mask_per_oc = 1 << (int)with_groups;
    jcp.is_oc_scale = wei_scales.mask_ == wei_mask_per_oc;
    jcp.dst_scale = !dst_scales.has_default_values();

    // only common src & dst scales are supported
    // only common and per-oc-channel weight scales are supported
    const bool scales_ok = one_of(wei_scales.mask_, 0, wei_mask_per_oc)
            && everyone_is(src_scales.mask_, dst_scales.mask_, 0);
    if (!scales_ok) return status::unimplemented;

    return status::success;
}
1236 | |
// Books all execution-time scratchpad areas: per-thread accumulation
// buffers, an ic-tail tile buffer when channels are not evenly blocked,
// padded bias storage, the tile configuration data, and precomputed scales.
void jit_avx512_core_amx_1x1_fwd_kernel_t::init_scratchpad(
        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
        const primitive_attr_t &attr) {
    // Per-thread int32 accumulation workspace.
    scratchpad.book(key_conv_amx_wsp_buffer, jcp.nthr * jcp.wsp_buffer_size,
            jcp.typesize_acc);
    // Extra buffer for the ic-tail path (see get_ic_tail()).
    if (jcp.ic_without_padding % jcp.ic_block_int_np)
        scratchpad.book(key_conv_amx_tile_buffer,
                jcp.nthr * (jcp.wsp_buffer_size / 2), jcp.typesize_acc);
    // Zero-padded bias when oc was rounded up to a full block.
    if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) {
        assert(jcp.ngroups == 1);
        scratchpad.book(key_conv_padded_bias, jcp.oc, jcp.typesize_bia);
    }
    scratchpad.book(key_conv_amx_tilecfg, 2, 64); // 2 whole cachelines
    book_precomputed_scales(
            scratchpad, attr.scales_, jcp.ngroups * jcp.oc_without_padding);
}
1253 | |
1254 | } // namespace x64 |
1255 | } // namespace cpu |
1256 | } // namespace impl |
1257 | } // namespace dnnl |
1258 | |
1259 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
1260 | |